Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 67 additions & 26 deletions decart/realtime/client.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
from typing import Callable, Optional, Union
from typing import Callable, Optional, Union, TYPE_CHECKING
import asyncio
import base64
import json
import logging
from pathlib import Path
from urllib.parse import urlparse, quote
import aiohttp
from aiortc import MediaStreamTrack
from pydantic import BaseModel

from .webrtc_manager import WebRTCManager, WebRTCConfiguration
from .messages import PromptMessage, SessionIdMessage, GenerationTickMessage
from .livekit_manager import LiveKitManager, LiveKitConfiguration
from .messages import PromptMessage, LiveKitRoomInfoMessage, GenerationTickMessage
from .subscribe import (
SubscribeClient,
SubscribeOptions,
Expand All @@ -21,6 +21,9 @@
from ..errors import DecartSDKError, InvalidInputError, WebRTCError
from ..process.request import file_input_to_bytes

if TYPE_CHECKING:
from livekit.rtc import LocalVideoTrack

logger = logging.getLogger(__name__)

PROMPT_TIMEOUT_S = 15.0
Expand Down Expand Up @@ -64,10 +67,46 @@ async def _image_to_base64(
return image


def _realtime_base_to_http(base_url: str) -> str:
if base_url.startswith("wss://"):
return "https://" + base_url[len("wss://") :]
if base_url.startswith("ws://"):
return "http://" + base_url[len("ws://") :]
return base_url


async def _fetch_watch_stream_credentials(
base_url: str,
api_key: str,
room_name: str,
) -> LiveKitRoomInfoMessage:
http_base_url = _realtime_base_to_http(base_url).rstrip("/")
url = f"{http_base_url}/watch-stream/{quote(room_name)}"
async with aiohttp.ClientSession() as session:
async with session.post(url, headers={"x-api-key": api_key}) as response:
body = await response.text()
if response.status >= 400:
raise WebRTCError(
f"watch-stream request failed ({response.status}): {body or response.reason}"
)
data = json.loads(body)

try:
return LiveKitRoomInfoMessage(
type="livekit_room_info",
livekit_url=data["livekit_url"],
token=data["token"],
room_name=data["room_name"],
session_id=data.get("session_id", data["room_name"]),
)
except KeyError as e:
raise WebRTCError(f"watch-stream response missing required field: {e.args[0]}") from e


class RealtimeClient:
def __init__(
self,
manager: WebRTCManager,
manager: LiveKitManager,
http_session: Optional[aiohttp.ClientSession] = None,
):
self._manager = manager
Expand All @@ -88,51 +127,53 @@ def session_id(self) -> Optional[str]:
def subscribe_token(self) -> Optional[str]:
return self._subscribe_token

def _handle_session_id(self, msg: SessionIdMessage) -> None:
def _handle_session_started(self, msg: LiveKitRoomInfoMessage) -> None:
self._session_id = msg.session_id
self._subscribe_token = encode_subscribe_token(
msg.session_id, msg.server_ip, msg.server_port
)
self._subscribe_token = encode_subscribe_token(msg.room_name)

@classmethod
async def connect(
cls,
base_url: str,
api_key: str,
local_track: Optional[MediaStreamTrack],
local_track: Optional["LocalVideoTrack"],
options: RealtimeConnectOptions,
integration: Optional[str] = None,
) -> "RealtimeClient":
ws_url = f"{base_url}{options.model.url_path}"
ws_url += f"?api_key={quote(api_key)}&model={quote(options.model.name)}"
ws_url += (
f"?api_key={quote(api_key)}"
f"&model={quote(options.model.name)}"
"&livekit_early_room_info=true"
)
if options.resolution is not None:
ws_url += f"&resolution={quote(options.resolution)}"

config = WebRTCConfiguration(
webrtc_url=ws_url,
config = LiveKitConfiguration(
livekit_url=ws_url,
api_key=api_key,
session_id="",
fps=options.model.fps,
on_remote_stream=options.on_remote_stream,
on_connection_state_change=None,
on_error=None,
on_session_id=None,
on_session_started=None,
initial_state=options.initial_state,
customize_offer=options.customize_offer,
preferred_video_codec=options.preferred_video_codec,
integration=integration,
)

http_session = aiohttp.ClientSession()

manager = WebRTCManager(config)
manager = LiveKitManager(config)
client = cls(
manager=manager,
http_session=http_session,
)

config.on_connection_state_change = client._emit_connection_change
config.on_error = lambda error: client._emit_error(WebRTCError(str(error), cause=error))
config.on_session_id = client._handle_session_id
config.on_session_started = client._handle_session_started
config.on_generation_tick = client._emit_generation_tick

try:
Expand Down Expand Up @@ -169,25 +210,25 @@ async def subscribe(
integration: Optional[str] = None,
) -> SubscribeClient:
token_data = decode_subscribe_token(options.token)
subscribe_url = (
f"{base_url}/subscribe/{quote(token_data.sid)}"
f"?IP={quote(token_data.ip)}"
f"&port={quote(str(token_data.port))}"
f"&api_key={quote(api_key)}"
room_info = await _fetch_watch_stream_credentials(
base_url=base_url,
api_key=api_key,
room_name=token_data.room_name,
)

config = WebRTCConfiguration(
webrtc_url=subscribe_url,
config = LiveKitConfiguration(
livekit_url="",
api_key=api_key,
session_id=token_data.sid,
session_id=token_data.room_name,
fps=0,
on_remote_stream=options.on_remote_stream,
on_connection_state_change=None,
on_error=None,
room_info=room_info,
integration=integration,
)

manager = WebRTCManager(config)
manager = LiveKitManager(config)
sub_client = SubscribeClient(manager)

config.on_connection_state_change = sub_client._emit_connection_change
Expand Down
Loading
Loading