|
1 | 1 | import asyncio |
| 2 | +import hashlib |
| 3 | +import json |
2 | 4 | import logging |
3 | | -import ssl |
4 | | -from typing import Any, Dict, List, Optional |
| 5 | +from typing import Any, Callable, Dict, List, Optional |
5 | 6 |
|
6 | | -from gql import Client, gql |
7 | | -from gql.transport.aiohttp import AIOHTTPTransport |
8 | | -from gql.transport.websockets import WebsocketsTransport |
| 7 | +import httpx |
| 8 | +from httpx_sse import aconnect_sse |
9 | 9 |
|
10 | 10 | from pyhilo import API |
11 | 11 | from pyhilo.const import LOG, PLATFORM_HOST |
@@ -533,91 +533,158 @@ async def async_init(self) -> None: |
533 | 533 | async def call_get_location_query(self, location_hilo_id: str) -> None: |
534 | 534 | """This functions calls the digital-twin and requests location id""" |
535 | 535 | access_token = await self._get_access_token() |
536 | | - transport = AIOHTTPTransport( |
537 | | - url=f"https://{PLATFORM_HOST}/api/digital-twin/v3/graphql", |
538 | | - headers={"Authorization": f"Bearer {access_token}"}, |
539 | | - ) |
540 | | - client = Client(transport=transport, fetch_schema_from_transport=True) |
541 | | - query = gql(self.QUERY_GET_LOCATION) |
| 536 | + url = f"https://{PLATFORM_HOST}/api/digital-twin/v3/graphql" |
| 537 | + headers = {"Authorization": f"Bearer {access_token}"} |
| 538 | + |
| 539 | + query = self.QUERY_GET_LOCATION |
| 540 | + query_hash = hashlib.sha256(query.encode("utf-8")).hexdigest() |
| 541 | + |
| 542 | + payload = { |
| 543 | + "extensions": { |
| 544 | + "persistedQuery": { |
| 545 | + "version": 1, |
| 546 | + "sha256Hash": query_hash, |
| 547 | + } |
| 548 | + }, |
| 549 | + "variables": {"locationHiloId": location_hilo_id}, |
| 550 | + } |
| 551 | + |
| 552 | + async with httpx.AsyncClient(http2=True) as client: |
| 553 | + try: |
| 554 | + response = await client.post(url, json=payload, headers=headers) |
| 555 | + response.raise_for_status() |
| 556 | + response_json = response.json() |
| 557 | + except Exception as e: |
| 558 | + LOG.error("Error parsing response: %s", e) |
| 559 | + return |
| 560 | + |
| 561 | + if "errors" in response_json: |
| 562 | + for error in response_json["errors"]: |
| 563 | + if error.get("message") == "PersistedQueryNotFound": |
| 564 | + payload["query"] = query |
| 565 | + try: |
| 566 | + response = await client.post( |
| 567 | + url, json=payload, headers=headers |
| 568 | + ) |
| 569 | + response.raise_for_status() |
| 570 | + response_json = response.json() |
| 571 | + except Exception as e: |
| 572 | + LOG.error("Error parsing response on retry: %s", e) |
| 573 | + return |
| 574 | + break |
542 | 575 |
|
543 | | - async with client as session: |
544 | | - result = await session.execute( |
545 | | - query, variable_values={"locationHiloId": location_hilo_id} |
546 | | - ) |
547 | | - self._handle_query_result(result) |
| 576 | + if "errors" in response_json: |
| 577 | + LOG.error("GraphQL errors: %s", response_json["errors"]) |
| 578 | + return |
| 579 | + |
| 580 | + if "data" in response_json: |
| 581 | + self._handle_query_result(response_json["data"]) |
548 | 582 |
|
549 | 583 | async def subscribe_to_device_updated( |
550 | 584 | self, location_hilo_id: str, callback: callable = None |
551 | 585 | ) -> None: |
552 | 586 | LOG.debug("subscribe_to_device_updated called") |
| 587 | + await self._listen_to_sse( |
| 588 | + self.SUBSCRIPTION_DEVICE_UPDATED, |
| 589 | + {"locationHiloId": location_hilo_id}, |
| 590 | + self._handle_device_subscription_result, |
| 591 | + callback, |
| 592 | + location_hilo_id, |
| 593 | + ) |
| 594 | + |
| 595 | + async def subscribe_to_location_updated( |
| 596 | + self, location_hilo_id: str, callback: callable = None |
| 597 | + ) -> None: |
| 598 | + LOG.debug("subscribe_to_location_updated called") |
| 599 | + await self._listen_to_sse( |
| 600 | + self.SUBSCRIPTION_LOCATION_UPDATED, |
| 601 | + {"locationHiloId": location_hilo_id}, |
| 602 | + self._handle_location_subscription_result, |
| 603 | + callback, |
| 604 | + location_hilo_id, |
| 605 | + ) |
553 | 606 |
|
554 | | - # Setting log level to suppress keepalive messages on gql transport |
555 | | - logging.getLogger("gql.transport.websockets").setLevel(logging.WARNING) |
556 | | - |
557 | | - # |
558 | | - loop = asyncio.get_event_loop() |
559 | | - ssl_context = await loop.run_in_executor(None, ssl.create_default_context) |
560 | | - |
561 | | - while True: # Loop to reconnect if the connection is lost |
562 | | - LOG.debug("subscribe_to_device_updated while true") |
563 | | - access_token = await self._get_access_token() |
564 | | - transport = WebsocketsTransport( |
565 | | - url=f"wss://{PLATFORM_HOST}/api/digital-twin/v3/graphql?access_token={access_token}", |
566 | | - ssl=ssl_context, |
567 | | - ) |
568 | | - client = Client(transport=transport, fetch_schema_from_transport=True) |
569 | | - query = gql(self.SUBSCRIPTION_DEVICE_UPDATED) |
| 607 | + async def _listen_to_sse( |
| 608 | + self, |
| 609 | + query: str, |
| 610 | + variables: Dict[str, Any], |
| 611 | + handler: Callable[[Dict[str, Any]], str], |
| 612 | + callback: Optional[Callable[[str], None]] = None, |
| 613 | + location_hilo_id: str = None, |
| 614 | + ) -> None: |
| 615 | + query_hash = hashlib.sha256(query.encode("utf-8")).hexdigest() |
| 616 | + payload = { |
| 617 | + "extensions": { |
| 618 | + "persistedQuery": { |
| 619 | + "version": 1, |
| 620 | + "sha256Hash": query_hash, |
| 621 | + } |
| 622 | + }, |
| 623 | + "variables": variables, |
| 624 | + } |
| 625 | + |
| 626 | + while True: |
570 | 627 | try: |
571 | | - async with client as session: |
572 | | - async for result in session.subscribe( |
573 | | - query, variable_values={"locationHiloId": location_hilo_id} |
574 | | - ): |
575 | | - LOG.debug( |
576 | | - "subscribe_to_device_updated: Received subscription result %s", |
577 | | - result, |
578 | | - ) |
579 | | - device_hilo_id = self._handle_device_subscription_result(result) |
580 | | - if callback: |
581 | | - callback(device_hilo_id) |
| 628 | + access_token = await self._get_access_token() |
| 629 | + url = f"https://{PLATFORM_HOST}/api/digital-twin/v3/graphql" |
| 630 | + headers = {"Authorization": f"Bearer {access_token}"} |
| 631 | + |
| 632 | + retry_with_full_query = False |
| 633 | + |
| 634 | + async with httpx.AsyncClient(http2=True, timeout=None) as client: |
| 635 | + async with aconnect_sse( |
| 636 | + client, "POST", url, json=payload, headers=headers |
| 637 | + ) as event_source: |
| 638 | + async for sse in event_source.aiter_sse(): |
| 639 | + if not sse.data: |
| 640 | + continue |
| 641 | + try: |
| 642 | + data = json.loads(sse.data) |
| 643 | + except json.JSONDecodeError: |
| 644 | + continue |
| 645 | + |
| 646 | + if "errors" in data: |
| 647 | + if any( |
| 648 | + e.get("message") == "PersistedQueryNotFound" |
| 649 | + for e in data["errors"] |
| 650 | + ): |
| 651 | + retry_with_full_query = True |
| 652 | + break |
| 653 | + LOG.error( |
| 654 | + "GraphQL Subscription Errors: %s", data["errors"] |
| 655 | + ) |
| 656 | + continue |
| 657 | + |
| 658 | + if "data" in data: |
| 659 | + LOG.debug( |
| 660 | + "Received subscription result %s", data["data"] |
| 661 | + ) |
| 662 | + result = handler(data["data"]) |
| 663 | + if callback: |
| 664 | + callback(result) |
| 665 | + |
| 666 | + if retry_with_full_query: |
| 667 | + payload["query"] = query |
| 668 | + continue |
| 669 | + |
582 | 670 | except Exception as e: |
583 | 671 | LOG.debug( |
584 | | - "subscribe_to_device_updated: Connection lost: %s. Reconnecting in 5 seconds...", |
585 | | - e, |
| 672 | + "Subscription connection lost: %s. Reconnecting in 5 seconds...", e |
586 | 673 | ) |
587 | 674 | await asyncio.sleep(5) |
588 | | - try: |
589 | | - await self.call_get_location_query(location_hilo_id) |
590 | | - LOG.debug( |
591 | | - "subscribe_to_device_updated, call_get_location_query success" |
592 | | - ) |
593 | | - |
594 | | - except Exception as e2: |
595 | | - LOG.error( |
596 | | - "subscribe_to_device_updated, exception while reconnecting, retrying: %s", |
597 | | - e2, |
598 | | - ) |
| 675 | + # Reset payload to APQ only on reconnect |
| 676 | + if "query" in payload: |
| 677 | + del payload["query"] |
599 | 678 |
|
600 | | - async def subscribe_to_location_updated( |
601 | | - self, location_hilo_id: str, callback: callable = None |
602 | | - ) -> None: |
603 | | - access_token = await self._get_access_token() |
604 | | - transport = WebsocketsTransport( |
605 | | - url=f"wss://{PLATFORM_HOST}/api/digital-twin/v3/graphql?access_token={access_token}" |
606 | | - ) |
607 | | - client = Client(transport=transport, fetch_schema_from_transport=True) |
608 | | - query = gql(self.SUBSCRIPTION_LOCATION_UPDATED) |
609 | | - try: |
610 | | - async with client as session: |
611 | | - async for result in session.subscribe( |
612 | | - query, variable_values={"locationHiloId": location_hilo_id} |
613 | | - ): |
614 | | - LOG.debug("Received subscription result %s", result) |
615 | | - device_hilo_id = self._handle_location_subscription_result(result) |
616 | | - callback(device_hilo_id) |
617 | | - except asyncio.CancelledError: |
618 | | - LOG.debug("Subscription cancelled.") |
619 | | - asyncio.sleep(1) |
620 | | - await self.subscribe_to_location_updated(location_hilo_id) |
| 679 | + if location_hilo_id: |
| 680 | + try: |
| 681 | + await self.call_get_location_query(location_hilo_id) |
| 682 | + LOG.debug("call_get_location_query success after reconnect") |
| 683 | + except Exception as e2: |
| 684 | + LOG.error( |
| 685 | + "exception while RE-connecting, retrying: %s", |
| 686 | + e2, |
| 687 | + ) |
621 | 688 |
|
622 | 689 | async def _get_access_token(self) -> str: |
623 | 690 | """Get the access token.""" |
|
0 commit comments