diff --git a/.github/workflows/integration.yaml b/.github/workflows/integration.yaml index cbcae0c..fb9b5fe 100644 --- a/.github/workflows/integration.yaml +++ b/.github/workflows/integration.yaml @@ -140,7 +140,7 @@ jobs: run: docker pull ${{ inputs.serviceImage }} - name: Run test tool - uses: restatedev/e2e/sdk-tests@v1.0 + uses: restatedev/e2e/sdk-tests@v2.2 with: restateContainerImage: ${{ inputs.restateCommit != '' && 'localhost/restatedev/restate-commit-download:latest' || (inputs.restateImage != '' && inputs.restateImage || 'ghcr.io/restatedev/restate:main') }} serviceContainerImage: ${{ inputs.serviceImage != '' && inputs.serviceImage || 'restatedev/test-services-python' }} diff --git a/Cargo.lock b/Cargo.lock index eda9ea2..2c85246 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -436,8 +436,9 @@ dependencies = [ [[package]] name = "restate-sdk-shared-core" -version = "0.10.0" -source = "git+https://github.com/restatedev/sdk-shared-core.git?rev=5127f0291bff456a515f2b8d572c4090e8ff450e#5127f0291bff456a515f2b8d572c4090e8ff450e" +version = "7.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9bf02583379fc28d6386bd864629ba6553ce942124956b862fa9e95a4d836558" dependencies = [ "base64", "bs58", diff --git a/Cargo.toml b/Cargo.toml index 5e6e675..ae77416 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,4 +14,4 @@ doc = false [dependencies] pyo3 = { version = "0.25.1", features = ["extension-module"] } tracing-subscriber = { version = "0.3", features = ["fmt", "env-filter"] } -restate-sdk-shared-core = { git = "https://github.com/restatedev/sdk-shared-core.git", rev = "5127f0291bff456a515f2b8d572c4090e8ff450e", features = ["request_identity", "sha2_random_seed"] } +restate-sdk-shared-core = { version = "7.0.0", features = ["request_identity", "sha2_random_seed"] } diff --git a/python/restate/__init__.py b/python/restate/__init__.py index 84173d5..16935e3 100644 --- a/python/restate/__init__.py +++ b/python/restate/__init__.py @@ -34,6 +34,7 @@ RestateDurableFuture, RestateDurableCallFuture, RestateDurableSleepFuture, + ScopedContext, SendHandle, RunOptions, ) @@ -101,6 +102,7 @@ async def create_client( "RestateDurableCallFuture", "RestateDurableSleepFuture", "SendHandle", + "ScopedContext", "RunOptions", "TerminalError", "app", diff --git a/python/restate/client.py b/python/restate/client.py index 2e8eda5..f78c336 100644 --- a/python/restate/client.py +++ b/python/restate/client.py @@ -17,7 +17,7 @@ import typing from contextlib import asynccontextmanager -from .client_types import RestateClient, RestateClientSendHandle, HttpError +from .client_types import RestateClient, RestateClientSendHandle, RestateScopedClient, HttpError from .context import HandlerType from .serde import BytesSerde, JsonSerde, Serde @@ -36,6 +36,9 @@ def __init__(self, client: httpx.AsyncClient, headers: typing.Optional[dict] = N self.headers = headers or {} self.client = client + def scope(self, scope: str) -> RestateScopedClient: + return ScopedClient(self, scope) + async def do_call( self, tpe: HandlerType[I, O], @@ -46,6 +49,8 @@ async def do_call( idempotency_key: str | None = None, headers: typing.Dict[str, str] | None = None, force_json_output: bool = False, + scope: str | None = None, + limit_key: str | None = None, ) -> O: """Make an RPC call to the given handler""" target_handler = handler_from_callable(tpe) @@ -77,6 +82,8 @@ async def do_call( send=send, idempotency_key=idempotency_key, headers=headers, + scope=scope, + limit_key=limit_key, ) async def do_raw_call( @@ -91,6 +98,8 @@ async def do_raw_call( send: bool = False, idempotency_key: str | None = None, headers: typing.Dict[str, str] | None = None, + scope: str | None = None, + limit_key: str | None = None, ) -> O: """Make an RPC call to the given handler""" parameter = input_serde.serialize(input_param) @@ -112,6 +121,8 @@ async def do_raw_call( key=key, delay=ms, idempotency_key=idempotency_key, + scope=scope, + limit_key=limit_key, ) return output_serde.deserialize(res) # type: ignore @@ -126,21 +137,37 @@ async def post( key: str | None = None, delay: int | None = None, idempotency_key: str | None = None, + scope: str | None = None, + limit_key: str | None = None, ) -> bytes: """ Send a POST request to the Restate service. """ - endpoint = service - if key: - endpoint += f"/{key}" - endpoint += f"/{handler}" - if send: - endpoint += "/send" - if delay is not None: + if scope is not None: + # Scoped invocations use the dedicated ingress path: + # restate/scope/{scope}/call/{service}[/{key}]/{handler} + # restate/scope/{scope}/send/{service}[/{key}]/{handler} + verb = "send" if send else "call" + endpoint = f"restate/scope/{scope}/{verb}/{service}" + if key: + endpoint += f"/{key}" + endpoint += f"/{handler}" + if send and delay is not None: endpoint = endpoint + f"?delay={delay}" + else: + endpoint = service + if key: + endpoint += f"/{key}" + endpoint += f"/{handler}" + if send: + endpoint += "/send" + if delay is not None: + endpoint = endpoint + f"?delay={delay}" dict_headers = dict(headers) if headers is not None else {} if idempotency_key is not None: dict_headers["Idempotency-Key"] = idempotency_key + if limit_key is not None: + dict_headers["x-restate-limit-key"] = limit_key res = await self.client.post(endpoint, headers=dict_headers, content=content) if res.status_code >= 400: raise HttpError(res.status_code, res.reason_phrase, res.text) @@ -250,6 +277,8 @@ async def generic_call( key: str | None = None, idempotency_key: str | None = None, headers: typing.Dict[str, str] | None = None, + scope: str | None = None, + limit_key: str | None = None, ) -> bytes: serde = BytesSerde() call_handle = await self.do_raw_call( @@ -261,6 +290,8 @@ async def generic_call( key=key, idempotency_key=idempotency_key, headers=headers, + scope=scope, + limit_key=limit_key, ) return call_handle @@ -273,6 +304,8 @@ async def generic_send( send_delay: timedelta | None = None, idempotency_key: str | None = None, headers: typing.Dict[str, str] | None = None, + scope: str | None = None, + limit_key: str | None = None, ) -> RestateClientSendHandle: serde = BytesSerde() output_serde: Serde[dict] = JsonSerde() @@ -288,11 +321,110 @@ async def generic_send( send=True, idempotency_key=idempotency_key, headers=headers, + scope=scope, + limit_key=limit_key, ) return RestateClientSendHandle(send_handle_json.get("invocationId", ""), 200) # TODO: verify +class ScopedClient(RestateScopedClient): + """ + A scoped client returned by ``client.scope(scope_key)``. + + Re-dispatches to the underlying :class:`Client` with the captured scope and a + per-call ``limit_key``. + """ + + def __init__(self, client: Client, scope_key: str): + self.client = client + self.scope_key = scope_key + + async def service_call( + self, + tpe: HandlerType[I, O], + arg: I, + limit_key: str | None = None, + idempotency_key: str | None = None, + headers: typing.Dict[str, str] | None = None, + ) -> O: + return await self.client.do_call( + tpe, + arg, + idempotency_key=idempotency_key, + headers=headers, + scope=self.scope_key, + limit_key=limit_key, + ) + + async def service_send( + self, + tpe: HandlerType[I, O], + arg: I, + send_delay: typing.Optional[timedelta] = None, + limit_key: str | None = None, + idempotency_key: str | None = None, + headers: typing.Dict[str, str] | None = None, + ) -> RestateClientSendHandle: + send_handle = await self.client.do_call( + tpe, + parameter=arg, + send=True, + send_delay=send_delay, + idempotency_key=idempotency_key, + headers=headers, + force_json_output=True, + scope=self.scope_key, + limit_key=limit_key, + ) + send = typing.cast(typing.Dict[str, str], send_handle) + return RestateClientSendHandle(send.get("invocationId", ""), 200) + + async def workflow_call( + self, + tpe: HandlerType[I, O], + key: str, + arg: I, + limit_key: str | None = None, + idempotency_key: str | None = None, + headers: typing.Dict[str, str] | None = None, + ) -> O: + return await self.client.do_call( + tpe, + arg, + key, + idempotency_key=idempotency_key, + headers=headers, + scope=self.scope_key, + limit_key=limit_key, + ) + + async def workflow_send( + self, + tpe: HandlerType[I, O], + key: str, + arg: I, + send_delay: typing.Optional[timedelta] = None, + limit_key: str | None = None, + idempotency_key: str | None = None, + headers: typing.Dict[str, str] | None = None, + ) -> RestateClientSendHandle: + send_handle = await self.client.do_call( + tpe, + parameter=arg, + key=key, + send=True, + send_delay=send_delay, + idempotency_key=idempotency_key, + headers=headers, + force_json_output=True, + scope=self.scope_key, + limit_key=limit_key, + ) + send = typing.cast(typing.Dict[str, str], send_handle) + return RestateClientSendHandle(send.get("invocationId", ""), 200) + + @asynccontextmanager async def create_client( ingress: str, headers: typing.Optional[dict] = None diff --git a/python/restate/client_types.py b/python/restate/client_types.py index 8d1bb23..4613fc5 100644 --- a/python/restate/client_types.py +++ b/python/restate/client_types.py @@ -46,12 +46,118 @@ def __init__(self, status_code: int, message: str, body: str | None = None): self.body = body +class RestateScopedClient(abc.ABC): + """ + An ingress client for making RPC calls within a specific scope. + + **NOTE:** This API is in preview and is not enabled by default. + To use it in restate-server 1.7, enable the flow control and protocol v7 experimental features, + via ``RESTATE_EXPERIMENTAL_ENABLE_PROTOCOL_V7=true`` and ``RESTATE_EXPERIMENTAL_ENABLE_VQUEUES=true``. + These can be enabled only on **new clusters**, for more info check out https://docs.restate.dev/services/flow-control#enabling-flow-control. + If these experimental features aren't enabled, the invocation isn't ingested and the client request fails. + + Returned by ``client.scope(scope_key)``: calls and sends made through this client + carry the captured scope, and each method additionally accepts an optional + ``limit_key``. + + The limit key enforces hierarchical concurrency limits on invocations sharing the same scope. + It can have one or two levels separated by ``/`` (e.g. ``"tenant1"`` or ``"tenant1/user42"``). + Each level must consist only of ``[a-zA-Z0-9_.-]`` characters, and 1 <= length <= 36. + + The limit key is **not** part of the request identity: two calls to the same target with the + same scope and object key but different limit keys refer to the **same** resource instance. + The limit key only affects concurrency limits, not resource identity. + """ + + @abc.abstractmethod + async def service_call( + self, + tpe: HandlerType[I, O], + arg: I, + limit_key: str | None = None, + idempotency_key: str | None = None, + headers: typing.Dict[str, str] | None = None, + ) -> O: + """Make an RPC call to the given handler, within this scope""" + pass + + @abc.abstractmethod + async def service_send( + self, + tpe: HandlerType[I, O], + arg: I, + send_delay: typing.Optional[timedelta] = None, + limit_key: str | None = None, + idempotency_key: str | None = None, + headers: typing.Dict[str, str] | None = None, + ) -> RestateClientSendHandle: + """Make a send operation to the given handler, within this scope""" + pass + + @abc.abstractmethod + async def workflow_call( + self, + tpe: HandlerType[I, O], + key: str, + arg: I, + limit_key: str | None = None, + idempotency_key: str | None = None, + headers: typing.Dict[str, str] | None = None, + ) -> O: + """Make an RPC call to the given workflow handler, within this scope""" + pass + + @abc.abstractmethod + async def workflow_send( + self, + tpe: HandlerType[I, O], + key: str, + arg: I, + send_delay: typing.Optional[timedelta] = None, + limit_key: str | None = None, + idempotency_key: str | None = None, + headers: typing.Dict[str, str] | None = None, + ) -> RestateClientSendHandle: + """Make a send operation to the given workflow handler, within this scope""" + pass + + class RestateClient(abc.ABC): """ An abstract base class for a Restate client. This class defines the interface for a Restate client. """ + @abc.abstractmethod + def scope(self, scope: str) -> RestateScopedClient: + """ + Returns a ``RestateScopedClient`` that routes all calls within the given scope. + + **NOTE:** This API is in preview and is not enabled by default. + To use it in restate-server 1.7, enable the flow control and protocol v7 experimental features, + via ``RESTATE_EXPERIMENTAL_ENABLE_PROTOCOL_V7=true`` and ``RESTATE_EXPERIMENTAL_ENABLE_VQUEUES=true``. + These can be enabled only on **new clusters**, for more info check out https://docs.restate.dev/services/flow-control#enabling-flow-control. + If these experimental features aren't enabled, the invocation won't be ingested and the client request fails. + + A scope is a sub-grouping of resources (invocations, workflow instances, concurrency limits) within the Restate cluster. + It becomes part of the target identity tuple: + - ``scope, service, handler, idempotency_key?`` + - ``scope, workflow, workflow_key, handler`` + + Under the hood, the scope contributes to the partition key, so all resources in a scope get co-located by the restate-server. + + Omitting the scope (i.e. using the regular ``service_call`` / ``workflow_call`` methods) + is equivalent to calling with no scope, which is the existing behavior. + + The scope must consist only of ``[a-zA-Z0-9_.-]`` characters, with 1 <= length <= 36 chars. + + Args: + scope: the scope identifier + + See also: https://docs.restate.dev/services/flow-control + """ + pass + @abc.abstractmethod async def service_call( self, @@ -134,8 +240,15 @@ async def generic_call( key: str | None = None, idempotency_key: str | None = None, headers: typing.Dict[str, str] | None = None, + scope: str | None = None, + limit_key: str | None = None, ) -> bytes: - """Make a generic RPC call to the given service and handler""" + """ + Make a generic RPC call to the given service and handler. + + ``scope`` optionally routes the call within the given scope (see ``RestateClient.scope``); + ``limit_key`` is an optional concurrency limit key within the scope and requires ``scope`` to be set. + """ pass @abc.abstractmethod @@ -148,6 +261,13 @@ async def generic_send( send_delay: timedelta | None = None, idempotency_key: str | None = None, headers: typing.Dict[str, str] | None = None, + scope: str | None = None, + limit_key: str | None = None, ) -> RestateClientSendHandle: - """Make a generic send operation to the given service and handler""" + """ + Make a generic send operation to the given service and handler. + + ``scope`` optionally routes the send within the given scope (see ``RestateClient.scope``); + ``limit_key`` is an optional concurrency limit key within the scope and requires ``scope`` to be set. + """ pass diff --git a/python/restate/context.py b/python/restate/context.py index 98f2239..642366f 100644 --- a/python/restate/context.py +++ b/python/restate/context.py @@ -151,6 +151,9 @@ class Request: attempt_headers (dict[str, str]): The attempt headers of the request. body (bytes): The body of the request. attempt_finished_event (AttemptFinishedEvent): The teardown event of the request. + scope (Optional[str]): The scope key with which this invocation was submitted, if any. + limit_key (Optional[str]): The limit key with which this invocation was submitted, if any. + idempotency_key (Optional[str]): The idempotency key with which this invocation was submitted, if any. """ id: str @@ -158,6 +161,9 @@ class Request: attempt_headers: Dict[str, str] body: bytes attempt_finished_event: AttemptFinishedEvent + scope: Optional[str] = None + limit_key: Optional[str] = None + idempotency_key: Optional[str] = None class KeyValueStore(abc.ABC): @@ -225,6 +231,86 @@ async def cancel_invocation(self) -> None: """ +class ScopedContext(abc.ABC): + """ + A context for making RPC calls within a specific scope. + + **NOTE:** This API is in preview and is not enabled by default. + To use it in restate-server 1.7, enable the flow control and protocol v7 experimental features, + via ``RESTATE_EXPERIMENTAL_ENABLE_PROTOCOL_V7=true`` and ``RESTATE_EXPERIMENTAL_ENABLE_VQUEUES=true``. + These can be enabled only on **new clusters**, for more info check out https://docs.restate.dev/services/flow-control#enabling-flow-control. + When the experimental features are disabled, this method fails the invocation with a retryable error, causing the invocation to be retried until fixed. + + Returned by ``ctx.scope(scope_key)``: calls and sends made through this context + carry the captured scope, and each method additionally accepts an optional + ``limit_key``. + + The limit key enforces hierarchical concurrency limits on invocations sharing the same scope. + It can have one or two levels separated by ``/`` (e.g. ``"tenant1"`` or ``"tenant1/user42"``). + Each level must consist only of ``[a-zA-Z0-9_.-]`` characters, and 1 <= length <= 36. + + The limit key is **not** part of the request identity: two calls to the same target with the + same scope and object key but different limit keys refer to the **same** resource instance. + The limit key only affects concurrency limits, not resource identity. + """ + + @abc.abstractmethod + def service_call( + self, + tpe: HandlerType[I, O], + arg: I, + limit_key: str | None = None, + idempotency_key: str | None = None, + headers: typing.Dict[str, str] | None = None, + ) -> RestateDurableCallFuture[O]: + """ + Invokes the given service with the given argument, within this scope. + """ + + @abc.abstractmethod + def service_send( + self, + tpe: HandlerType[I, O], + arg: I, + send_delay: Optional[timedelta] = None, + limit_key: str | None = None, + idempotency_key: str | None = None, + headers: typing.Dict[str, str] | None = None, + ) -> SendHandle: + """ + Invokes the given service with the given argument, within this scope. + """ + + @abc.abstractmethod + def workflow_call( + self, + tpe: HandlerType[I, O], + key: str, + arg: I, + limit_key: str | None = None, + idempotency_key: str | None = None, + headers: typing.Dict[str, str] | None = None, + ) -> RestateDurableCallFuture[O]: + """ + Invokes the given workflow with the given argument, within this scope. + """ + + @abc.abstractmethod + def workflow_send( + self, + tpe: HandlerType[I, O], + key: str, + arg: I, + send_delay: Optional[timedelta] = None, + limit_key: str | None = None, + idempotency_key: str | None = None, + headers: typing.Dict[str, str] | None = None, + ) -> SendHandle: + """ + Send a message to a workflow with the given argument, within this scope. + """ + + class Context(abc.ABC): """ Represents the context of the current invocation. @@ -236,6 +322,73 @@ def request(self) -> Request: Returns the request object. """ + @abc.abstractmethod + def scope(self, scope: str) -> ScopedContext: + """ + Returns a ``ScopedContext`` that routes all outgoing calls within the given scope. + + **NOTE:** This API is in preview and is not enabled by default. + To use it in restate-server 1.7, enable the flow control and protocol v7 experimental features, + via ``RESTATE_EXPERIMENTAL_ENABLE_PROTOCOL_V7=true`` and ``RESTATE_EXPERIMENTAL_ENABLE_VQUEUES=true``. + These can be enabled only on **new clusters**, for more info check out https://docs.restate.dev/services/flow-control#enabling-flow-control. + If these experimental features aren't enabled, the call fails with a retryable error and keeps retrying until they are. + + A scope is a sub-grouping of resources (invocations, workflow instances, concurrency limits) within the Restate cluster. + It becomes part of the target identity tuple: + - ``scope, service, handler, idempotency_key?`` + - ``scope, workflow, workflow_key, handler`` + + Under the hood, the scope contributes to the partition key, so all resources in a scope get co-located by the restate-server. + + Omitting the scope (i.e. using the regular ``service_call`` / ``workflow_call`` methods) + is equivalent to calling with no scope, which is the existing behavior. + + The scope must consist only of ``[a-zA-Z0-9_.-]`` characters, with 1 <= length <= 36 chars. + + Args: + scope: the scope identifier + + See also: https://docs.restate.dev/services/flow-control + """ + + @abc.abstractmethod + def signal( + self, name: str, serde: Serde[T] = DefaultSerde(), type_hint: Optional[typing.Type[T]] = None + ) -> RestateDurableFuture[T]: + """ + Awaits a named signal on the current invocation, resolving when the signal arrives. + + Args: + name: The signal name. + serde: The serialization/deserialization mechanism. Defaults to DefaultSerde. + type_hint: The type hint of the signal value, used to pick the serializer. + """ + + @abc.abstractmethod + def resolve_signal(self, invocation_id: str, name: str, value: I, serde: Serde[I] = DefaultSerde()) -> None: + """ + Resolves a named signal on the target invocation with the given value. + + Args: + invocation_id: The id of the target invocation. + name: The signal name. + value: The value to resolve the signal with. + serde: The serialization mechanism. Defaults to DefaultSerde. + """ + + @abc.abstractmethod + def reject_signal(self, invocation_id: str, name: str, failure_message: str, failure_code: int = 500) -> None: + """ + Rejects a named signal on the target invocation. The handler awaiting the + signal will observe a terminal error with the given message and code. + + Args: + invocation_id: The id of the target invocation. + name: The signal name. + failure_message: The failure message. + failure_code: The failure code. Defaults to 500. + """ + @abc.abstractmethod def random(self) -> Random: """ @@ -519,9 +672,15 @@ def generic_call( key: Optional[str] = None, idempotency_key: str | None = None, headers: typing.Dict[str, str] | None = None, + scope: str | None = None, + limit_key: str | None = None, ) -> RestateDurableCallFuture[bytes]: """ Invokes the given generic service/handler with the given argument. + + Args: + scope: Optional scope to route the call within. See ``Context.scope``. Since restate-server 1.7. + limit_key: Optional concurrency limit key within the scope. Requires ``scope`` to be set. Since restate-server 1.7. """ @abc.abstractmethod @@ -534,9 +693,15 @@ def generic_send( send_delay: Optional[timedelta] = None, idempotency_key: str | None = None, headers: typing.Dict[str, str] | None = None, + scope: str | None = None, + limit_key: str | None = None, ) -> SendHandle: """ Send a message to a generic service/handler with the given argument. + + Args: + scope: Optional scope to route the send within. See ``Context.scope``. Since restate-server 1.7. + limit_key: Optional concurrency limit key within the scope. Requires ``scope`` to be set. Since restate-server 1.7. """ @abc.abstractmethod diff --git a/python/restate/server_context.py b/python/restate/server_context.py index a45962a..cbaf4ea 100644 --- a/python/restate/server_context.py +++ b/python/restate/server_context.py @@ -41,6 +41,7 @@ RestateDurableCallFuture, RestateDurableFuture, RunAction, + ScopedContext, SendHandle, RestateDurableSleepFuture, RunOptions, @@ -210,6 +211,101 @@ async def cancel_invocation(self) -> None: self.context.cancel_invocation(invocation_id) +class ServerScopedContext(ScopedContext): + """This class implements the scoped context returned by ctx.scope(scope).""" + + def __init__(self, context: "ServerInvocationContext", scope: str) -> None: + super().__init__() + self.context = context + self.scope = scope + + def service_call( + self, + tpe: HandlerType[I, O], + arg: I, + limit_key: str | None = None, + idempotency_key: str | None = None, + headers: typing.Dict[str, str] | None = None, + ) -> RestateDurableCallFuture[O]: + coro = self.context.do_call( + tpe, + arg, + idempotency_key=idempotency_key, + headers=headers, + scope=self.scope, + limit_key=limit_key, + ) + assert not isinstance(coro, SendHandle) + return coro + + def service_send( + self, + tpe: HandlerType[I, O], + arg: I, + send_delay: Optional[timedelta] = None, + limit_key: str | None = None, + idempotency_key: str | None = None, + headers: typing.Dict[str, str] | None = None, + ) -> SendHandle: + send = self.context.do_call( + tpe=tpe, + parameter=arg, + send_delay=send_delay, + send=True, + idempotency_key=idempotency_key, + headers=headers, + scope=self.scope, + limit_key=limit_key, + ) + assert isinstance(send, SendHandle) + return send + + def workflow_call( + self, + tpe: HandlerType[I, O], + key: str, + arg: I, + limit_key: str | None = None, + idempotency_key: str | None = None, + headers: typing.Dict[str, str] | None = None, + ) -> RestateDurableCallFuture[O]: + coro = self.context.do_call( + tpe, + arg, + key, + idempotency_key=idempotency_key, + headers=headers, + scope=self.scope, + limit_key=limit_key, + ) + assert not isinstance(coro, SendHandle) + return coro + + def workflow_send( + self, + tpe: HandlerType[I, O], + key: str, + arg: I, + send_delay: Optional[timedelta] = None, + limit_key: str | None = None, + idempotency_key: str | None = None, + headers: typing.Dict[str, str] | None = None, + ) -> SendHandle: + send = self.context.do_call( + tpe=tpe, + key=key, + parameter=arg, + send_delay=send_delay, + send=True, + idempotency_key=idempotency_key, + headers=headers, + scope=self.scope, + limit_key=limit_key, + ) + assert isinstance(send, SendHandle) + return send + + async def async_value(n: Callable[[], T]) -> T: """convert a simple value to a coroutine.""" return n() @@ -453,7 +549,7 @@ async def leave(self): """Leave the context.""" while True: chunk = self.vm.take_output() - if chunk is None: + if not chunk: break await self.send( { @@ -648,8 +744,36 @@ def request(self) -> Request: attempt_headers=self.attempt_headers, body=self.invocation.input_buffer, attempt_finished_event=ServerTeardownEvent(self.request_finished_event), + scope=self.invocation.scope, + limit_key=self.invocation.limit_key, + idempotency_key=self.invocation.idempotency_key, ) + def scope(self, scope: str) -> ScopedContext: + return ServerScopedContext(self, scope) + + def signal( + self, name: str, serde: Serde[T] = DefaultSerde(), type_hint: Optional[typing.Type[T]] = None + ) -> RestateDurableFuture[T]: + if isinstance(serde, DefaultSerde): + serde = serde.with_maybe_type(type_hint) + handle = self.vm.sys_signal(name) + update_restate_context_is_replaying(self.vm) + return self.create_future(handle, serde) + + def resolve_signal(self, invocation_id: str, name: str, value: I, serde: Serde[I] = DefaultSerde()) -> None: + """Resolve a named signal on a target invocation.""" + if isinstance(serde, DefaultSerde): + serde = serde.with_maybe_type(type(value)) + buf = serde.serialize(value) + self.vm.sys_resolve_signal(invocation_id, name, buf) + update_restate_context_is_replaying(self.vm) + + def reject_signal(self, invocation_id: str, name: str, failure_message: str, failure_code: int = 500) -> None: + """Reject a named signal on a target invocation.""" + self.vm.sys_reject_signal(invocation_id, name, Failure(code=failure_code, message=failure_message)) + update_restate_context_is_replaying(self.vm) + def random(self) -> Random: return self.random_instance @@ -817,6 +941,8 @@ def do_call( send: bool = False, idempotency_key: str | None = None, headers: typing.Dict[str, str] | None = None, + scope: str | None = None, + limit_key: str | None = None, ) -> RestateDurableCallFuture[O] | SendHandle: """Make an RPC call to the given handler""" target_handler = handler_from_callable(tpe) @@ -825,7 +951,18 @@ def do_call( input_serde = target_handler.handler_io.input_serde output_serde = target_handler.handler_io.output_serde return self.do_raw_call( - service, handler, parameter, input_serde, output_serde, key, send_delay, send, idempotency_key, headers + service, + handler, + parameter, + input_serde, + output_serde, + key, + send_delay, + send, + idempotency_key, + headers, + scope, + limit_key, ) def do_raw_call( @@ -840,6 +977,8 @@ def do_raw_call( send: bool = False, idempotency_key: str | None = None, headers: typing.Dict[str, str] | None = None, + scope: str | None = None, + limit_key: str | None = None, ) -> RestateDurableCallFuture[O] | SendHandle: """Make an RPC call to the given handler""" parameter = input_serde.serialize(input_param) @@ -850,13 +989,28 @@ def do_raw_call( if send_delay: ms = int(send_delay.total_seconds() * 1000) send_handle = self.vm.sys_send( - service, handler, parameter, key, delay=ms, idempotency_key=idempotency_key, headers=headers_kvs + service, + handler, + parameter, + key, + delay=ms, + idempotency_key=idempotency_key, + headers=headers_kvs, + scope=scope, + limit_key=limit_key, ) update_restate_context_is_replaying(self.vm) return ServerSendHandle(self, send_handle) if send: send_handle = self.vm.sys_send( - service, handler, parameter, key, idempotency_key=idempotency_key, headers=headers_kvs + service, + handler, + parameter, + key, + idempotency_key=idempotency_key, + headers=headers_kvs, + scope=scope, + limit_key=limit_key, ) update_restate_context_is_replaying(self.vm) return ServerSendHandle(self, send_handle) @@ -868,6 +1022,8 @@ def do_raw_call( key=key, idempotency_key=idempotency_key, headers=headers_kvs, + scope=scope, + limit_key=limit_key, ) update_restate_context_is_replaying(self.vm) @@ -964,6 +1120,8 @@ def generic_call( key: str | None = None, idempotency_key: str | None = None, headers: typing.Dict[str, str] | None = None, + scope: str | None = None, + limit_key: str | None = None, ) -> RestateDurableCallFuture[bytes]: serde = BytesSerde() call_handle = self.do_raw_call( @@ -975,6 +1133,8 @@ def generic_call( key=key, idempotency_key=idempotency_key, headers=headers, + scope=scope, + limit_key=limit_key, ) assert not isinstance(call_handle, SendHandle) return call_handle @@ -988,6 +1148,8 @@ def generic_send( send_delay: timedelta | None = None, idempotency_key: str | None = None, headers: typing.Dict[str, str] | None = None, + scope: str | None = None, + limit_key: str | None = None, ) -> SendHandle: serde = BytesSerde() send_handle = self.do_raw_call( @@ -1001,6 +1163,8 @@ def generic_send( send=True, idempotency_key=idempotency_key, headers=headers, + scope=scope, + limit_key=limit_key, ) assert isinstance(send_handle, SendHandle) return send_handle diff --git a/python/restate/vm.py b/python/restate/vm.py index d3345f4..f1e91af 100644 --- a/python/restate/vm.py +++ b/python/restate/vm.py @@ -50,6 +50,9 @@ class Invocation: headers: typing.List[typing.Tuple[str, str]] input_buffer: bytes key: str + scope: typing.Optional[str] = None + limit_key: typing.Optional[str] = None + idempotency_key: typing.Optional[str] = None @dataclass @@ -229,8 +232,8 @@ def notify_error(self, error: str, stacktrace: str, delay_override: Optional[tim error, stacktrace, int(delay_override.total_seconds() * 1000) if delay_override is not None else None ) - def take_output(self) -> typing.Optional[bytes]: - """Take the output from the virtual machine.""" + def take_output(self) -> bytes: + """Take the buffered output from the virtual machine, possibly empty when there's nothing buffered.""" return self.vm.take_output() def is_ready_to_execute(self) -> bool: @@ -301,9 +304,19 @@ def sys_input(self) -> Invocation: headers: typing.List[typing.Tuple[str, str]] = [(h.key, h.value) for h in inp.headers] input_buffer: bytes = bytes(inp.input) key: str = inp.key + scope: typing.Optional[str] = inp.scope + limit_key: typing.Optional[str] = inp.limit_key + idempotency_key: typing.Optional[str] = inp.idempotency_key return Invocation( - invocation_id=invocation_id, random_seed=random_seed, headers=headers, input_buffer=input_buffer, key=key + invocation_id=invocation_id, + random_seed=random_seed, + headers=headers, + input_buffer=input_buffer, + key=key, + scope=scope, + limit_key=limit_key, + idempotency_key=idempotency_key, ) def sys_write_output_success(self, output: bytes): @@ -386,10 +399,12 @@ def sys_call( key: typing.Optional[str] = None, idempotency_key: typing.Optional[str] = None, headers: typing.Optional[typing.List[typing.Tuple[str, str]]] = None, + scope: typing.Optional[str] = None, + limit_key: typing.Optional[str] = None, ): """Call a service""" py_headers = [PyHeader(key=h[0], value=h[1]) for h in headers] if headers else None - return self.vm.sys_call(service, handler, parameter, key, idempotency_key, py_headers) + return self.vm.sys_call(service, handler, parameter, key, idempotency_key, py_headers, scope, limit_key) # pylint: disable=too-many-arguments def sys_send( @@ -401,13 +416,15 @@ def sys_send( delay: typing.Optional[int] = None, idempotency_key: typing.Optional[str] = None, headers: typing.Optional[typing.List[typing.Tuple[str, str]]] = None, + scope: typing.Optional[str] = None, + limit_key: typing.Optional[str] = None, ) -> int: """ send an invocation to a service, and return the handle to the promise that will resolve with the invocation id """ py_headers = [PyHeader(key=h[0], value=h[1]) for h in headers] if headers else None - return self.vm.sys_send(service, handler, parameter, key, delay, idempotency_key, py_headers) + return self.vm.sys_send(service, handler, parameter, key, delay, idempotency_key, py_headers, scope, limit_key) def sys_run(self, name: str) -> Run: """ @@ -435,6 +452,25 @@ def sys_reject_awakeable(self, name: str, failure: Failure): py_failure = PyFailure(failure.code, failure.message) self.vm.sys_complete_awakeable_failure(name, py_failure) + def sys_signal(self, signal_name: str) -> int: + """ + Create a handle to await a named signal on the current invocation. + """ + return self.vm.sys_signal(signal_name) + + def sys_resolve_signal(self, invocation_id: str, signal_name: str, value: bytes): + """ + Resolve a named signal on a target invocation. + """ + self.vm.sys_complete_signal_success(invocation_id, signal_name, value) + + def sys_reject_signal(self, invocation_id: str, signal_name: str, failure: Failure): + """ + Reject a named signal on a target invocation. + """ + py_failure = PyFailure(failure.code, failure.message) + self.vm.sys_complete_signal_failure(invocation_id, signal_name, py_failure) + def propose_run_completion_success(self, handle: int, output: bytes) -> None: """ Exit a side effect with a success value. diff --git a/src/lib.rs b/src/lib.rs index 83a9440..c128bda 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -5,7 +5,7 @@ use restate_sdk_shared_core::fmt::{set_error_formatter, ErrorFormatter}; use restate_sdk_shared_core::{ AwaitResponse, AwakeableHandle, CallHandle, CoreVM, Error, Header, IdentityVerifier, Input, NonEmptyValue, NotificationHandle, OnMaxAttempts, ResponseHead, RetryPolicy, RunExitResult, - RunHandle, TakeOutputResult, Target, TerminalFailure, UnresolvedFuture, VMOptions, Value, + RunHandle, Target, TerminalFailure, UnresolvedFuture, VMOptions, Value, CANCEL_NOTIFICATION_HANDLE, VM, }; use std::fmt; @@ -72,15 +72,6 @@ impl From for PyResponseHead { } } -fn take_output_result_into_py<'py>( - py: Python<'py>, - take_output_result: TakeOutputResult, -) -> Bound<'py, PyAny> { - match take_output_result { - TakeOutputResult::Buffer(b) => PyBytes::new(py, &b).into_any(), - TakeOutputResult::EOF => PyNone::get(py).to_owned().into_any(), - } -} type PyNotificationHandle = u32; @@ -239,6 +230,12 @@ pub struct PyInput { headers: Vec, #[pyo3(get, set)] input: Vec, + #[pyo3(get, set)] + scope: Option, + #[pyo3(get, set)] + limit_key: Option, + #[pyo3(get, set)] + idempotency_key: Option, } impl From for PyInput { @@ -249,6 +246,9 @@ impl From for PyInput { key: value.key, headers: value.headers.into_iter().map(Into::into).collect(), input: value.input.into(), + scope: value.scope, + limit_key: value.limit_key, + idempotency_key: value.idempotency_key, } } } @@ -411,9 +411,11 @@ impl PyVM { // Take(s) - /// Returns either bytes or None, indicating EOF - fn take_output(mut self_: PyRefMut<'_, Self>) -> Bound<'_, PyAny> { - take_output_result_into_py(self_.py(), self_.vm.take_output()) + /// Returns the buffered output as bytes, possibly empty when there's nothing buffered. + /// The caller (server_context) decides what to do with an empty buffer. + fn take_output(mut self_: PyRefMut<'_, Self>) -> Bound<'_, PyBytes> { + let output = self_.vm.take_output(); + PyBytes::new(self_.py(), &output) } fn is_ready_to_execute(self_: PyRef<'_, Self>) -> Result { @@ -552,7 +554,8 @@ impl PyVM { .map_err(Into::into) } - #[pyo3(signature = (service, handler, buffer, key=None, idempotency_key=None, headers=None))] + #[pyo3(signature = (service, handler, buffer, key=None, idempotency_key=None, headers=None, scope=None, limit_key=None))] + #[allow(clippy::too_many_arguments)] fn sys_call( mut self_: PyRefMut<'_, Self>, service: String, @@ -561,6 +564,8 @@ impl PyVM { key: Option, idempotency_key: Option, headers: Option>, + scope: Option, + limit_key: Option, ) -> Result { self_ .vm @@ -570,8 +575,8 @@ impl PyVM { handler, key, idempotency_key, - scope: None, - limit_key: None, + scope, + limit_key, headers: headers .unwrap_or_default() .into_iter() @@ -586,7 +591,7 @@ impl PyVM { .map_err(Into::into) } - #[pyo3(signature = (service, handler, buffer, key=None, delay=None, idempotency_key=None, headers=None))] + #[pyo3(signature = (service, handler, buffer, key=None, delay=None, idempotency_key=None, headers=None, scope=None, limit_key=None))] #[allow(clippy::too_many_arguments)] fn sys_send( mut self_: PyRefMut<'_, Self>, @@ -597,6 +602,8 @@ impl PyVM { delay: Option, idempotency_key: Option, headers: Option>, + scope: Option, + limit_key: Option, ) -> Result { self_ .vm @@ -606,8 +613,8 @@ impl PyVM { handler, key, idempotency_key, - scope: None, - limit_key: None, + scope, + limit_key, headers: headers .unwrap_or_default() .into_iter() @@ -664,6 +671,49 @@ impl PyVM { .map_err(Into::into) } + fn sys_signal( + mut self_: PyRefMut<'_, Self>, + signal_name: String, + ) -> Result { + self_ + .vm + .create_signal_handle(signal_name) + .map(Into::into) + .map_err(Into::into) + } + + fn sys_complete_signal_success( + mut self_: PyRefMut<'_, Self>, + invocation_id: String, + signal_name: String, + buffer: &Bound<'_, PyBytes>, + ) -> Result<(), PyVMError> { + self_ + .vm + .sys_complete_signal( + invocation_id, + signal_name, + NonEmptyValue::Success(buffer.as_bytes().to_vec().into()), + ) + .map_err(Into::into) + } + + fn sys_complete_signal_failure( + mut self_: PyRefMut<'_, Self>, + invocation_id: String, + signal_name: String, + value: PyFailure, + ) -> Result<(), PyVMError> { + self_ + .vm + .sys_complete_signal( + invocation_id, + signal_name, + NonEmptyValue::Failure(value.into()), + ) + .map_err(Into::into) + } + fn sys_get_promise( mut self_: PyRefMut<'_, Self>, key: String, diff --git a/test-services/services/proxy.py b/test-services/services/proxy.py index f0c47a8..b581350 100644 --- a/test-services/services/proxy.py +++ b/test-services/services/proxy.py @@ -14,6 +14,7 @@ from datetime import timedelta from restate import Service, Context +from restate.context import RestateDurableCallFuture, SendHandle from typing import TypedDict, Optional, Iterable proxy = Service("Proxy") @@ -26,34 +27,50 @@ class ProxyRequest(TypedDict): message: Iterable[int] delayMillis: Optional[int] idempotencyKey: Optional[str] + scope: Optional[str] + limitKey: Optional[str] -@proxy.handler() -async def call(ctx: Context, req: ProxyRequest) -> Iterable[int]: - response = await ctx.generic_call( +def do_call(ctx: Context, req: ProxyRequest) -> RestateDurableCallFuture[bytes]: + """Issue the outgoing generic call described by req, forwarding scope/limitKey when set.""" + return ctx.generic_call( req["serviceName"], req["handlerName"], bytes(req["message"]), req.get("virtualObjectKey"), - req.get("idempotencyKey"), + idempotency_key=req.get("idempotencyKey"), + scope=req.get("scope"), + limit_key=req.get("limitKey"), ) - return list(response) -@proxy.handler(name="oneWayCall") -async def one_way_call(ctx: Context, req: ProxyRequest) -> str: +def do_send(ctx: Context, req: ProxyRequest) -> SendHandle: + """Issue the outgoing generic send described by req, forwarding scope/limitKey when set.""" send_delay = None - delayMillis = req.get("delayMillis") - if delayMillis is not None: - send_delay = timedelta(milliseconds=delayMillis) - handle = ctx.generic_send( + delay_millis = req.get("delayMillis") + if delay_millis is not None: + send_delay = timedelta(milliseconds=delay_millis) + return ctx.generic_send( req["serviceName"], req["handlerName"], bytes(req["message"]), req.get("virtualObjectKey"), send_delay=send_delay, idempotency_key=req.get("idempotencyKey"), + scope=req.get("scope"), + limit_key=req.get("limitKey"), ) + + +@proxy.handler() +async def call(ctx: Context, req: ProxyRequest) -> Iterable[int]: + response = await do_call(ctx, req) + return list(response) + + +@proxy.handler(name="oneWayCall") +async def one_way_call(ctx: Context, req: ProxyRequest) -> str: + handle = do_send(ctx, req) invocation_id = await handle.invocation_id() return invocation_id @@ -70,26 +87,9 @@ async def many_calls(ctx: Context, requests: Iterable[ManyCallRequest]): for req in requests: if req["oneWayCall"]: - send_delay = None - delayMillis = req["proxyRequest"].get("delayMillis") - if delayMillis is not None: - send_delay = timedelta(milliseconds=delayMillis) - ctx.generic_send( - req["proxyRequest"]["serviceName"], - req["proxyRequest"]["handlerName"], - bytes(req["proxyRequest"]["message"]), - req["proxyRequest"].get("virtualObjectKey"), - send_delay=send_delay, - idempotency_key=req["proxyRequest"].get("idempotencyKey"), - ) + do_send(ctx, req["proxyRequest"]) else: - awaitable = ctx.generic_call( - req["proxyRequest"]["serviceName"], - req["proxyRequest"]["handlerName"], - bytes(req["proxyRequest"]["message"]), - req["proxyRequest"].get("virtualObjectKey"), - idempotency_key=req["proxyRequest"].get("idempotencyKey"), - ) + awaitable = do_call(ctx, req["proxyRequest"]) if req["awaitAtTheEnd"]: to_await.append(awaitable) diff --git a/test-services/services/test_utils.py b/test-services/services/test_utils.py index 55ef991..eb29363 100644 --- a/test-services/services/test_utils.py +++ b/test-services/services/test_utils.py @@ -12,8 +12,7 @@ # pylint: disable=C0116 # pylint: disable=W0613 -from datetime import timedelta -from typing import Dict, List +from typing import Dict, TypedDict from restate import Service, Context from restate.serde import BytesSerde @@ -46,14 +45,6 @@ async def raw_echo(context: Context, input: bytes) -> bytes: return input -@test_utils.handler(name="sleepConcurrently") -async def sleep_concurrently(context: Context, millis_duration: List[int]) -> None: - timers = [context.sleep(timedelta(milliseconds=duration)) for duration in millis_duration] - - for timer in timers: - await timer - - @test_utils.handler(name="countExecutedSideEffects") async def count_executed_side_effects(context: Context, increments: int) -> int: invoked_side_effects = 0 @@ -71,3 +62,25 @@ def effect(): @test_utils.handler(name="cancelInvocation") async def cancel_invocation(context: Context, invocation_id: str) -> None: context.cancel_invocation(invocation_id) + + +class ResolveSignalRequest(TypedDict): + invocationId: str + signalName: str + value: str + + +@test_utils.handler(name="resolveSignal") +async def resolve_signal(context: Context, req: ResolveSignalRequest) -> None: + context.resolve_signal(req["invocationId"], req["signalName"], req["value"]) + + +class RejectSignalRequest(TypedDict): + invocationId: str + signalName: str + reason: str + + +@test_utils.handler(name="rejectSignal") +async def reject_signal(context: Context, req: RejectSignalRequest) -> None: + context.reject_signal(req["invocationId"], req["signalName"], req["reason"]) diff --git a/test-services/services/virtual_object_command_interpreter.py b/test-services/services/virtual_object_command_interpreter.py index b6015cd..24eb4b5 100644 --- a/test-services/services/virtual_object_command_interpreter.py +++ b/test-services/services/virtual_object_command_interpreter.py @@ -13,10 +13,11 @@ # pylint: disable=W0613 import os +import asyncio from datetime import timedelta from typing import Iterable, List, Union, TypedDict, Literal, Any from restate import VirtualObject, ObjectSharedContext, ObjectContext, RestateDurableFuture, RestateDurableSleepFuture -from restate import select, wait_completed, as_completed +from restate import select, wait_completed, as_completed, gather from restate.exceptions import TerminalError virtual_object_command_interpreter = VirtualObject("VirtualObjectCommandInterpreter") @@ -50,7 +51,17 @@ class RunThrowTerminalException(TypedDict): reason: str -AwaitableCommand = Union[CreateAwakeable, Sleep, RunThrowTerminalException] +class CreateSignal(TypedDict): + type: Literal["createSignal"] + signalName: str + + +class RunReturns(TypedDict): + type: Literal["runReturns"] + value: str + + +AwaitableCommand = Union[CreateAwakeable, Sleep, RunThrowTerminalException, CreateSignal, RunReturns] class AwaitOne(TypedDict): @@ -68,6 +79,26 @@ class AwaitAny(TypedDict): commands: List[AwaitableCommand] +class AwaitFirstSucceededOrAllFailed(TypedDict): + type: Literal["awaitFirstSucceededOrAllFailed"] + commands: List[AwaitableCommand] + + +class AwaitFirstCompleted(TypedDict): + type: Literal["awaitFirstCompleted"] + commands: List[AwaitableCommand] + + +class AwaitAllSucceededOrFirstFailed(TypedDict): + type: Literal["awaitAllSucceededOrFirstFailed"] + commands: List[AwaitableCommand] + + +class AwaitAllCompleted(TypedDict): + type: Literal["awaitAllCompleted"] + commands: List[AwaitableCommand] + + class AwaitAwakeableOrTimeout(TypedDict): type: Literal["awaitAwakeableOrTimeout"] awakeableKey: str @@ -92,7 +123,17 @@ class GetEnvVariable(TypedDict): Command = Union[ - AwaitOne, AwaitAny, AwaitAnySuccessful, AwaitAwakeableOrTimeout, ResolveAwakeable, RejectAwakeable, GetEnvVariable + AwaitOne, + AwaitAny, + AwaitAnySuccessful, + AwaitFirstSucceededOrAllFailed, + AwaitFirstCompleted, + AwaitAllSucceededOrFirstFailed, + AwaitAllCompleted, + AwaitAwakeableOrTimeout, + ResolveAwakeable, + RejectAwakeable, + GetEnvVariable, ] @@ -130,6 +171,25 @@ def side_effect(reason: str): res = ctx.run_typed("run should fail command", side_effect, reason=cmd["reason"]) return res + elif cmd["type"] == "createSignal": + return ctx.signal(cmd["signalName"], type_hint=str) + elif cmd["type"] == "runReturns": + + async def run_returns(value: str) -> str: + # genuinely async: suspend inside the run block rather than returning synchronously + await asyncio.sleep(0) + return value + + return ctx.run_typed("runReturns", run_returns, value=cmd["value"]) + + +async def resolve_command_result(fut: RestateDurableFuture[Any]) -> str: + """Await a single command future, mapping a sleep future to the literal "sleep".""" + # We need this dance because the Python SDK doesn't support .map on futures + if isinstance(fut, RestateDurableSleepFuture): + await fut + return "sleep" + return await fut @virtual_object_command_interpreter.handler(name="interpretCommands") @@ -160,35 +220,41 @@ def side_effect(env_name: str): result = await ctx.run_typed("get_env", side_effect, env_name=env_name) elif cmd["type"] == "awaitOne": awaitable = to_durable_future(ctx, cmd["command"]) - # We need this dance because the Python SDK doesn't support .map on futures - if isinstance(awaitable, RestateDurableSleepFuture): - await awaitable - result = "sleep" - else: - result = await awaitable - elif cmd["type"] == "awaitAny": + result = await resolve_command_result(awaitable) + elif cmd["type"] in ("awaitAny", "awaitFirstCompleted"): + # Promise.race: settle with whatever the first command to complete does. futures = [to_durable_future(ctx, c) for c in cmd["commands"]] done, _ = await wait_completed(*futures) - done_fut = done[0] - # We need this dance because the Python SDK doesn't support .map on futures - if isinstance(done_fut, RestateDurableSleepFuture): - await done_fut - result = "sleep" - else: - result = await done_fut - elif cmd["type"] == "awaitAnySuccessful": + result = await resolve_command_result(done[0]) + elif cmd["type"] in ("awaitAnySuccessful", "awaitFirstSucceededOrAllFailed"): + # Promise.any: resolve with the first success; if all fail, raise the last error. futures = [to_durable_future(ctx, c) for c in cmd["commands"]] + last_error: TerminalError | None = None async for done_fut in as_completed(*futures): try: - # We need this dance because the Python SDK doesn't support .map on futures - if isinstance(done_fut, RestateDurableSleepFuture): - await done_fut - result = "sleep" - break - result = await done_fut + result = await resolve_command_result(done_fut) break - except TerminalError: - pass + except TerminalError as err: + last_error = err + else: + assert last_error is not None + raise last_error + elif cmd["type"] == "awaitAllSucceededOrFirstFailed": + # Promise.all: wait for all to succeed, raise on the first failure (input order). + futures = [to_durable_future(ctx, c) for c in cmd["commands"]] + await gather(*futures) + result = "|".join([await resolve_command_result(f) for f in futures]) + elif cmd["type"] == "awaitAllCompleted": + # Promise.allSettled: wait for all to settle, never raise. + futures = [to_durable_future(ctx, c) for c in cmd["commands"]] + await gather(*futures) + parts = [] + for f in futures: + try: + parts.append("ok:" + await resolve_command_result(f)) + except TerminalError as err: + parts.append("err:" + err.message) + result = "|".join(parts) last_results = await get_results(ctx) last_results.append(result)