Skip to content

Commit e85e8fd

Browse files
edenhausjoostlek
andauthored
As ws client (#2)
Co-authored-by: Joost Lekkerkerker <joostlek@outlook.com>
1 parent 047e12a commit e85e8fd

13 files changed

Lines changed: 658 additions & 17 deletions

File tree

go2rtc_client/__init__.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""go2rtc client."""
22

3-
from .client import Go2RtcClient
3+
from . import ws
44
from .models import Stream, WebRTCSdpAnswer, WebRTCSdpOffer
5+
from .rest import Go2RtcRestClient
56

6-
__all__ = ["Go2RtcClient", "Stream", "WebRTCSdpAnswer", "WebRTCSdpOffer"]
7+
__all__ = ["Go2RtcRestClient", "Stream", "WebRTCSdpAnswer", "WebRTCSdpOffer", "ws"]

go2rtc_client/exceptions.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
"""Go2rtc client exceptions."""
2+
3+
from __future__ import annotations
4+
5+
6+
class Go2RtcClientError(Exception):
7+
"""Base exception for go2rtc client."""
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -117,8 +117,8 @@ async def add(self, name: str, source: str) -> None:
117117
)
118118

119119

120-
class Go2RtcClient:
121-
"""Client for go2rtc server."""
120+
class Go2RtcRestClient:
121+
"""Rest client for go2rtc server."""
122122

123123
def __init__(self, websession: ClientSession, server_url: str) -> None:
124124
"""Initialize Client."""

go2rtc_client/ws/__init__.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
"""Websocket module."""
2+
3+
from .client import Go2RtcWsClient
4+
from .messages import (
5+
ReceiveMessages,
6+
SendMessages,
7+
WebRTCAnswer,
8+
WebRTCCandidate,
9+
WebRTCOffer,
10+
WsError,
11+
)
12+
13+
__all__ = [
14+
"ReceiveMessages",
15+
"SendMessages",
16+
"Go2RtcWsClient",
17+
"WebRTCCandidate",
18+
"WebRTCOffer",
19+
"WebRTCAnswer",
20+
"WsError",
21+
]

go2rtc_client/ws/client.py

Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
"""Websocket client for go2rtc server."""
2+
3+
import asyncio
4+
from collections.abc import Callable
5+
import logging
6+
from typing import TYPE_CHECKING, Any
7+
from urllib.parse import urljoin
8+
9+
from aiohttp import (
10+
ClientError,
11+
ClientSession,
12+
ClientWebSocketResponse,
13+
WSMsgType,
14+
WSServerHandshakeError,
15+
)
16+
17+
from go2rtc_client.exceptions import Go2RtcClientError
18+
from go2rtc_client.ws.messages import BaseMessage
19+
20+
_LOGGER = logging.getLogger(__name__)
21+
22+
23+
class Go2RtcWsClient:
24+
"""Websocket client for go2rtc server."""
25+
26+
def __init__(
27+
self,
28+
session: ClientSession,
29+
server_url: str,
30+
*,
31+
source: str | None = None,
32+
destination: str | None = None,
33+
) -> None:
34+
"""Initialize Client."""
35+
if source:
36+
if destination:
37+
msg = "Source and destination cannot be set at the same time"
38+
raise ValueError(msg)
39+
params = {"src": source}
40+
elif destination:
41+
params = {"dst": destination}
42+
else:
43+
msg = "Source or destination must be set"
44+
raise ValueError(msg)
45+
46+
self._server_url = server_url
47+
self._session = session
48+
self._params = params
49+
self._client: ClientWebSocketResponse | None = None
50+
self._rx_task: asyncio.Task[None] | None = None
51+
self._subscribers: list[Callable[[BaseMessage], None]] = []
52+
self._connect_lock = asyncio.Lock()
53+
54+
@property
55+
def connected(self) -> bool:
56+
"""Return if we're currently connected."""
57+
return self._client is not None and not self._client.closed
58+
59+
async def connect(self) -> None:
60+
"""Connect to device."""
61+
async with self._connect_lock:
62+
if self.connected:
63+
return
64+
65+
_LOGGER.debug("Trying to connect to %s", self._server_url)
66+
try:
67+
self._client = await self._session.ws_connect(
68+
urljoin(self._server_url, "/api/ws"), params=self._params
69+
)
70+
except (
71+
WSServerHandshakeError,
72+
ClientError,
73+
) as err:
74+
raise Go2RtcClientError(err) from err
75+
76+
self._rx_task = asyncio.create_task(self._receive_messages())
77+
_LOGGER.info("Connected to %s", self._server_url)
78+
79+
async def close(self) -> None:
80+
"""Close connection."""
81+
if self.connected:
82+
if TYPE_CHECKING:
83+
assert self._client is not None
84+
client = self._client
85+
self._client = None
86+
await client.close()
87+
88+
async def send(self, message: BaseMessage) -> None:
89+
"""Send a message."""
90+
if not self.connected:
91+
await self.connect()
92+
93+
if TYPE_CHECKING:
94+
assert self._client is not None
95+
96+
await self._client.send_str(message.to_json())
97+
98+
def _process_text_message(self, data: Any) -> None:
99+
"""Process text message."""
100+
try:
101+
message = BaseMessage.from_json(data)
102+
except Exception: # pylint: disable=broad-except
103+
_LOGGER.exception("Invalid message received: %s", data)
104+
else:
105+
for subscriber in self._subscribers:
106+
try:
107+
subscriber(message)
108+
except Exception: # pylint: disable=broad-except
109+
_LOGGER.exception("Error on subscriber callback")
110+
111+
async def _receive_messages(self) -> None:
112+
"""Receive messages."""
113+
if TYPE_CHECKING:
114+
assert self._client
115+
116+
try:
117+
while self.connected:
118+
msg = await self._client.receive()
119+
match msg.type:
120+
case (
121+
WSMsgType.CLOSE
122+
| WSMsgType.CLOSED
123+
| WSMsgType.CLOSING
124+
| WSMsgType.PING
125+
| WSMsgType.PONG
126+
):
127+
break
128+
case WSMsgType.ERROR:
129+
_LOGGER.error("Error received: %s", msg.data)
130+
case WSMsgType.TEXT:
131+
self._process_text_message(msg.data)
132+
case _:
133+
_LOGGER.warning("Received unknown message: %s", msg)
134+
except Exception:
135+
_LOGGER.exception("Unexpected error while receiving message")
136+
raise
137+
finally:
138+
_LOGGER.debug(
139+
"Websocket client connection from %s closed", self._server_url
140+
)
141+
142+
if self.connected:
143+
await self.close()
144+
145+
def subscribe(self, callback: Callable[[BaseMessage], None]) -> Callable[[], None]:
146+
"""Subscribe to messages."""
147+
148+
def _unsubscribe() -> None:
149+
self._subscribers.remove(callback)
150+
151+
self._subscribers.append(callback)
152+
return _unsubscribe

go2rtc_client/ws/messages.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
"""Go2rtc websocket messages."""
2+
3+
from __future__ import annotations
4+
5+
from dataclasses import dataclass, field
6+
from typing import Any, ClassVar
7+
8+
from mashumaro import field_options
9+
from mashumaro.config import BaseConfig
10+
from mashumaro.mixins.orjson import DataClassORJSONMixin
11+
from mashumaro.types import Discriminator
12+
13+
14+
@dataclass(frozen=True)
15+
class BaseMessage(DataClassORJSONMixin):
16+
"""Base message class."""
17+
18+
TYPE: ClassVar[str]
19+
20+
class Config(BaseConfig):
21+
"""Config for BaseMessage."""
22+
23+
serialize_by_alias = True
24+
discriminator = Discriminator(
25+
field="type",
26+
include_subtypes=True,
27+
variant_tagger_fn=lambda cls: cls.TYPE,
28+
)
29+
30+
def __post_serialize__(self, d: dict[Any, Any]) -> dict[Any, Any]:
31+
"""Add type to serialized dict."""
32+
# ClassVar will not serialize by default
33+
d["type"] = self.TYPE
34+
return d
35+
36+
37+
@dataclass(frozen=True)
38+
class WebRTCCandidate(BaseMessage):
39+
"""WebRTC ICE candidate message."""
40+
41+
TYPE = "webrtc/candidate"
42+
candidate: str = field(metadata=field_options(alias="value"))
43+
44+
45+
@dataclass(frozen=True)
46+
class WebRTCOffer(BaseMessage):
47+
"""WebRTC offer message."""
48+
49+
TYPE = "webrtc/offer"
50+
offer: str = field(metadata=field_options(alias="value"))
51+
52+
53+
@dataclass(frozen=True)
54+
class WebRTCAnswer(BaseMessage):
55+
"""WebRTC answer message."""
56+
57+
TYPE = "webrtc/answer"
58+
answer: str = field(metadata=field_options(alias="value"))
59+
60+
61+
@dataclass(frozen=True)
62+
class WsError(BaseMessage):
63+
"""Error message."""
64+
65+
TYPE = "error"
66+
error: str = field(metadata=field_options(alias="value"))
67+
68+
69+
ReceiveMessages = WebRTCAnswer | WebRTCCandidate | WsError
70+
SendMessages = WebRTCCandidate | WebRTCOffer

pyproject.toml

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ dependencies = [
2525
version = "0.0.0"
2626

2727
[project.urls]
28-
"Homepage" = "https://deebot.readthedocs.io/"
28+
"Homepage" = "https://pypi.org/project/go2rtc-client"
2929
"Source Code" = "https://github.com/home-assistant-libs/python-go2rtc-client"
3030
"Bug Reports" = "https://github.com/home-assistant-libs/python-go2rtc-client/issues"
3131

@@ -35,7 +35,9 @@ dev-dependencies = [
3535
"covdefaults>=2.3.0",
3636
"mypy==1.11.2",
3737
"pre-commit==3.8.0",
38+
"pylint-per-file-ignores>=1.3.2",
3839
"pylint==3.2.7",
40+
"pytest-aiohttp>=1.0.5",
3941
"pytest-asyncio==0.24.0",
4042
"pytest-cov==5.0.0",
4143
"pytest-timeout==2.3.1",
@@ -118,6 +120,11 @@ good-names = [
118120
[tool.pylint.DESIGN]
119121
max-attributes = 8
120122

123+
[tool.pylint.MASTER]
124+
load-plugins=[
125+
"pylint_per_file_ignores",
126+
]
127+
121128
[tool.pylint."MESSAGES CONTROL"]
122129
disable = [
123130
"duplicate-code",
@@ -130,6 +137,12 @@ disable = [
130137
"wrong-import-order",
131138
]
132139

140+
per-file-ignores = [
141+
# redefined-outer-name: Tests reference fixtures in the test function
142+
"/tests/:redefined-outer-name",
143+
]
144+
145+
133146
[tool.pylint.SIMILARITIES]
134147
ignore-imports = true
135148

tests/conftest.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import pytest
88
from syrupy import SnapshotAssertion
99

10-
from go2rtc_client import Go2RtcClient
10+
from go2rtc_client import Go2RtcRestClient
1111

1212
from . import URL
1313
from .syrupy import Go2RtcSnapshotExtension
@@ -20,12 +20,12 @@ def snapshot_assertion(snapshot: SnapshotAssertion) -> SnapshotAssertion:
2020

2121

2222
@pytest.fixture
23-
async def client() -> AsyncGenerator[Go2RtcClient, None]:
24-
"""Return a go2rtc client."""
23+
async def rest_client() -> AsyncGenerator[Go2RtcRestClient, None]:
24+
"""Return a go2rtc rest client."""
2525
async with (
2626
aiohttp.ClientSession() as session,
2727
):
28-
client_ = Go2RtcClient(
28+
client_ = Go2RtcRestClient(
2929
session,
3030
URL,
3131
)

0 commit comments

Comments
 (0)