Skip to content

Commit 2b1f6db

Browse files
authored
Catch and wrap aiohttp and mashumaro errors (#8)
1 parent a1ab9c6 commit 2b1f6db

3 files changed

Lines changed: 53 additions & 18 deletions

File tree

go2rtc_client/exceptions.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,45 @@
22

33
from __future__ import annotations
44

5+
from functools import wraps
6+
from typing import TYPE_CHECKING, Any
7+
8+
from aiohttp import ClientError
9+
from mashumaro.exceptions import (
10+
ExtraKeysError,
11+
InvalidFieldValue,
12+
MissingDiscriminatorError,
13+
MissingField,
14+
SuitableVariantNotFoundError,
15+
UnserializableDataError,
16+
)
17+
18+
if TYPE_CHECKING:
19+
from collections.abc import Callable, Coroutine
20+
521

622
class Go2RtcClientError(Exception):
723
"""Base exception for go2rtc client."""
24+
25+
26+
def handle_error[**_P, _R](
27+
func: Callable[_P, Coroutine[Any, Any, _R]],
28+
) -> Callable[_P, Coroutine[Any, Any, _R]]:
29+
"""Wrap aiohttp and mashumaro errors."""
30+
31+
@wraps(func)
32+
async def _func(*args: _P.args, **kwargs: _P.kwargs) -> _R:
33+
try:
34+
return await func(*args, **kwargs)
35+
except (
36+
ClientError,
37+
ExtraKeysError,
38+
InvalidFieldValue,
39+
MissingDiscriminatorError,
40+
MissingField,
41+
SuitableVariantNotFoundError,
42+
UnserializableDataError,
43+
) as exc:
44+
raise Go2RtcClientError from exc
45+
46+
return _func

go2rtc_client/rest.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from mashumaro.mixins.dict import DataClassDictMixin
1212
from yarl import URL
1313

14+
from .exceptions import handle_error
1415
from .models import Stream, WebRTCSdpAnswer, WebRTCSdpOffer
1516

1617
if TYPE_CHECKING:
@@ -78,6 +79,7 @@ async def _forward_sdp_offer(
7879
)
7980
return WebRTCSdpAnswer.from_dict(await resp.json())
8081

82+
@handle_error
8183
async def forward_whep_sdp_offer(
8284
self, source_name: str, offer: WebRTCSdpOffer
8385
) -> WebRTCSdpAnswer:
@@ -99,11 +101,13 @@ def __init__(self, client: _BaseClient) -> None:
99101
"""Initialize Client."""
100102
self._client = client
101103

104+
@handle_error
102105
async def list(self) -> dict[str, Stream]:
103106
"""List streams registered with the server."""
104107
resp = await self._client.request("GET", self.PATH)
105108
return _GET_STREAMS_DECODER.decode(await resp.json())
106109

110+
@handle_error
107111
async def add(self, name: str, source: str) -> None:
108112
"""Add a stream to the server."""
109113
await self._client.request(

go2rtc_client/ws/client.py

Lines changed: 10 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -6,16 +6,11 @@
66
from typing import TYPE_CHECKING, Any
77
from urllib.parse import urljoin
88

9-
from aiohttp import (
10-
ClientError,
11-
ClientSession,
12-
ClientWebSocketResponse,
13-
WSMsgType,
14-
WSServerHandshakeError,
15-
)
9+
from aiohttp import ClientSession, ClientWebSocketResponse, WSMsgType
1610

17-
from go2rtc_client.exceptions import Go2RtcClientError
18-
from go2rtc_client.ws.messages import BaseMessage
11+
from go2rtc_client.exceptions import handle_error
12+
13+
from .messages import BaseMessage
1914

2015
_LOGGER = logging.getLogger(__name__)
2116

@@ -56,26 +51,22 @@ def connected(self) -> bool:
5651
"""Return if we're currently connected."""
5752
return self._client is not None and not self._client.closed
5853

54+
@handle_error
5955
async def connect(self) -> None:
6056
"""Connect to device."""
6157
async with self._connect_lock:
6258
if self.connected:
6359
return
6460

6561
_LOGGER.debug("Trying to connect to %s", self._server_url)
66-
try:
67-
self._client = await self._session.ws_connect(
68-
urljoin(self._server_url, "/api/ws"), params=self._params
69-
)
70-
except (
71-
WSServerHandshakeError,
72-
ClientError,
73-
) as err:
74-
raise Go2RtcClientError(err) from err
62+
self._client = await self._session.ws_connect(
63+
urljoin(self._server_url, "/api/ws"), params=self._params
64+
)
7565

7666
self._rx_task = asyncio.create_task(self._receive_messages())
7767
_LOGGER.info("Connected to %s", self._server_url)
7868

69+
@handle_error
7970
async def close(self) -> None:
8071
"""Close connection."""
8172
if self.connected:
@@ -85,6 +76,7 @@ async def close(self) -> None:
8576
self._client = None
8677
await client.close()
8778

79+
@handle_error
8880
async def send(self, message: BaseMessage) -> None:
8981
"""Send a message."""
9082
if not self.connected:

0 commit comments

Comments
 (0)