11import asyncio
22import hashlib
3+ import json
34import logging
4- import ssl
5- from typing import Any , Dict , List , Optional
5+ from typing import Any , Callable , Dict , List , Optional
66
7- import aiohttp
8- from gql import Client , gql
9- from gql .transport .aiohttp import AIOHTTPTransport
10- from gql .transport .websockets import WebsocketsTransport
7+ import httpx
8+ from httpx_sse import aconnect_sse
119
1210from pyhilo import API
1311from pyhilo .const import LOG , PLATFORM_HOST
@@ -551,24 +549,28 @@ async def call_get_location_query(self, location_hilo_id: str) -> None:
551549 "variables" : {"locationHiloId" : location_hilo_id },
552550 }
553551
554- async with aiohttp .ClientSession (headers = headers ) as session :
555- async with session .post (url , json = payload ) as response :
556- try :
557- response_json = await response .json ()
558- except Exception as e :
559- LOG .error ("Error parsing response: %s" , e )
560- return
552+ async with httpx .AsyncClient (http2 = True ) as client :
553+ try :
554+ response = await client .post (url , json = payload , headers = headers )
555+ response .raise_for_status ()
556+ response_json = response .json ()
557+ except Exception as e :
558+ LOG .error ("Error parsing response: %s" , e )
559+ return
561560
562561 if "errors" in response_json :
563562 for error in response_json ["errors" ]:
564563 if error .get ("message" ) == "PersistedQueryNotFound" :
565564 payload ["query" ] = query
566- async with session .post (url , json = payload ) as response :
567- try :
568- response_json = await response .json ()
569- except Exception as e :
570- LOG .error ("Error parsing response on retry: %s" , e )
571- return
565+ try :
566+ response = await client .post (
567+ url , json = payload , headers = headers
568+ )
569+ response .raise_for_status ()
570+ response_json = response .json ()
571+ except Exception as e :
572+ LOG .error ("Error parsing response on retry: %s" , e )
573+ return
572574 break
573575
574576 if "errors" in response_json :
@@ -582,74 +584,107 @@ async def subscribe_to_device_updated(
582584 self , location_hilo_id : str , callback : callable = None
583585 ) -> None :
584586 LOG .debug ("subscribe_to_device_updated called" )
587+ await self ._listen_to_sse (
588+ self .SUBSCRIPTION_DEVICE_UPDATED ,
589+ {"locationHiloId" : location_hilo_id },
590+ self ._handle_device_subscription_result ,
591+ callback ,
592+ location_hilo_id ,
593+ )
594+
595+ async def subscribe_to_location_updated (
596+ self , location_hilo_id : str , callback : callable = None
597+ ) -> None :
598+ LOG .debug ("subscribe_to_location_updated called" )
599+ await self ._listen_to_sse (
600+ self .SUBSCRIPTION_LOCATION_UPDATED ,
601+ {"locationHiloId" : location_hilo_id },
602+ self ._handle_location_subscription_result ,
603+ callback ,
604+ location_hilo_id ,
605+ )
585606
586- # Setting log level to suppress keepalive messages on gql transport
587- logging .getLogger ("gql.transport.websockets" ).setLevel (logging .WARNING )
588-
589- #
590- loop = asyncio .get_event_loop ()
591- ssl_context = await loop .run_in_executor (None , ssl .create_default_context )
592-
593- while True : # Loop to reconnect if the connection is lost
594- LOG .debug ("subscribe_to_device_updated while true" )
595- access_token = await self ._get_access_token ()
596- transport = WebsocketsTransport (
597- url = f"wss://{ PLATFORM_HOST } /api/digital-twin/v3/graphql?access_token={ access_token } " ,
598- ssl = ssl_context ,
599- )
600- client = Client (transport = transport , fetch_schema_from_transport = True )
601- query = gql (self .SUBSCRIPTION_DEVICE_UPDATED )
607+ async def _listen_to_sse (
608+ self ,
609+ query : str ,
610+ variables : Dict [str , Any ],
611+ handler : Callable [[Dict [str , Any ]], str ],
612+ callback : Optional [Callable [[str ], None ]] = None ,
613+ location_hilo_id : str = None ,
614+ ) -> None :
615+ query_hash = hashlib .sha256 (query .encode ("utf-8" )).hexdigest ()
616+ payload = {
617+ "extensions" : {
618+ "persistedQuery" : {
619+ "version" : 1 ,
620+ "sha256Hash" : query_hash ,
621+ }
622+ },
623+ "variables" : variables ,
624+ }
625+
626+ while True :
602627 try :
603- async with client as session :
604- async for result in session .subscribe (
605- query , variable_values = {"locationHiloId" : location_hilo_id }
606- ):
607- LOG .debug (
608- "subscribe_to_device_updated: Received subscription result %s" ,
609- result ,
610- )
611- device_hilo_id = self ._handle_device_subscription_result (result )
612- if callback :
613- callback (device_hilo_id )
628+ access_token = await self ._get_access_token ()
629+ url = f"https://{ PLATFORM_HOST } /api/digital-twin/v3/graphql"
630+ headers = {"Authorization" : f"Bearer { access_token } " }
631+
632+ retry_with_full_query = False
633+
634+ async with httpx .AsyncClient (http2 = True , timeout = None ) as client :
635+ async with aconnect_sse (
636+ client , "POST" , url , json = payload , headers = headers
637+ ) as event_source :
638+ async for sse in event_source .aiter_sse ():
639+ if not sse .data :
640+ continue
641+ try :
642+ data = json .loads (sse .data )
643+ except json .JSONDecodeError :
644+ continue
645+
646+ if "errors" in data :
647+ if any (
648+ e .get ("message" ) == "PersistedQueryNotFound"
649+ for e in data ["errors" ]
650+ ):
651+ retry_with_full_query = True
652+ break
653+ LOG .error (
654+ "GraphQL Subscription Errors: %s" , data ["errors" ]
655+ )
656+ continue
657+
658+ if "data" in data :
659+ LOG .debug (
660+ "Received subscription result %s" , data ["data" ]
661+ )
662+ result = handler (data ["data" ])
663+ if callback :
664+ callback (result )
665+
666+ if retry_with_full_query :
667+ payload ["query" ] = query
668+ continue
669+
614670 except Exception as e :
615671 LOG .debug (
616- "subscribe_to_device_updated: Connection lost: %s. Reconnecting in 5 seconds..." ,
617- e ,
672+ "Subscription connection lost: %s. Reconnecting in 5 seconds..." , e
618673 )
619674 await asyncio .sleep (5 )
620- try :
621- await self .call_get_location_query (location_hilo_id )
622- LOG .debug (
623- "subscribe_to_device_updated, call_get_location_query success"
624- )
625-
626- except Exception as e2 :
627- LOG .error (
628- "subscribe_to_device_updated, exception while reconnecting, retrying: %s" ,
629- e2 ,
630- )
675+ # Reset payload to APQ only on reconnect
676+ if "query" in payload :
677+ del payload ["query" ]
631678
632- async def subscribe_to_location_updated (
633- self , location_hilo_id : str , callback : callable = None
634- ) -> None :
635- access_token = await self ._get_access_token ()
636- transport = WebsocketsTransport (
637- url = f"wss://{ PLATFORM_HOST } /api/digital-twin/v3/graphql?access_token={ access_token } "
638- )
639- client = Client (transport = transport , fetch_schema_from_transport = True )
640- query = gql (self .SUBSCRIPTION_LOCATION_UPDATED )
641- try :
642- async with client as session :
643- async for result in session .subscribe (
644- query , variable_values = {"locationHiloId" : location_hilo_id }
645- ):
646- LOG .debug ("Received subscription result %s" , result )
647- device_hilo_id = self ._handle_location_subscription_result (result )
648- callback (device_hilo_id )
649- except asyncio .CancelledError :
650- LOG .debug ("Subscription cancelled." )
651- asyncio .sleep (1 )
652- await self .subscribe_to_location_updated (location_hilo_id )
679+ if location_hilo_id :
680+ try :
681+ await self .call_get_location_query (location_hilo_id )
682+ LOG .debug ("call_get_location_query success after reconnect" )
683+ except Exception as e2 :
684+ LOG .error (
685+ "exception while RE-connecting, retrying: %s" ,
686+ e2 ,
687+ )
653688
654689 async def _get_access_token (self ) -> str :
655690 """Get the access token."""
0 commit comments