From 0219a3178e1a4868ca2ec83dc1ab3d1567d3acfb Mon Sep 17 00:00:00 2001 From: A Vertex SDK engineer Date: Sat, 9 May 2026 19:56:03 -0700 Subject: [PATCH] fix: update agent engine utils in vertex_ai and agentplatform to support python-a2a sdk 1.0 PiperOrigin-RevId: 913139425 --- agentplatform/_genai/_agent_engines_utils.py | 135 +++++-------- .../test_agent_engine_a2a_v1_methods.py | 102 ++++++++++ vertexai/_genai/_agent_engines_utils.py | 188 ++++++++++++++++-- 3 files changed, 327 insertions(+), 98 deletions(-) create mode 100644 tests/unit/agentplatform/genai/replays/test_agent_engine_a2a_v1_methods.py diff --git a/agentplatform/_genai/_agent_engines_utils.py b/agentplatform/_genai/_agent_engines_utils.py index 170773088c..124123f0e1 100644 --- a/agentplatform/_genai/_agent_engines_utils.py +++ b/agentplatform/_genai/_agent_engines_utils.py @@ -111,30 +111,18 @@ try: - from a2a.types import ( - AgentCard, - TransportProtocol, - Message, - TaskIdParams, - TaskQueryParams, - ) + from a2a.types import AgentCard from a2a.client import ClientConfig, ClientFactory - - AgentCard = AgentCard - TransportProtocol = TransportProtocol - Message = Message - ClientConfig = ClientConfig - ClientFactory = ClientFactory - TaskIdParams = TaskIdParams - TaskQueryParams = TaskQueryParams + from a2a.utils.constants import TransportProtocol except (ImportError, AttributeError): AgentCard = None TransportProtocol = None - Message = None ClientConfig = None ClientFactory = None - TaskIdParams = None - TaskQueryParams = None + SendMessageRequest = None + GetTaskRequest = None + CancelTaskRequest = None + GetExtendedAgentCardRequest = None try: from autogen.agentchat import chat @@ -1807,79 +1795,53 @@ def _wrap_a2a_operation(method_name: str, agent_card: str) -> Callable[..., list Args: method_name: The name of the Agent Engine method to call. agent_card: The agent card to use for the A2A API call. - Example: - {'additionalInterfaces': None, - 'capabilities': {'extensions': None, - 'pushNotifications': None, - 'stateTransitionHistory': None, - 'streaming': False}, - 'defaultInputModes': ['text'], - 'defaultOutputModes': ['text'], - 'description': ( - 'A helpful assistant agent that can answer questions.' - ), - 'documentationUrl': None, - 'iconUrl': None, - 'name': 'Q&A Agent', - 'preferredTransport': 'JSONRPC', - 'protocolVersion': '0.3.0', - 'provider': None, - 'security': None, - 'securitySchemes': None, - 'signatures': None, - 'skills': [{ - 'description': ( - 'A helpful assistant agent that can answer questions.' - ), - 'examples': ['Who is leading 2025 F1 Standings?', - 'Where can i find an active volcano?'], - 'id': 'question_answer', - 'inputModes': None, - 'name': 'Q&A Agent', - 'outputModes': None, - 'security': None, - 'tags': ['Question-Answer']}], - 'supportsAuthenticatedExtendedCard': True, - 'url': 'http://localhost:8080/', - 'version': '1.0.0'} + Example: { 'name': 'Sample Agent', 'description': ( 'A helpful + assistant agent that can answer questions.' ), + 'supportedInterfaces': [{ 'url': 'http://localhost:8080/a2a/rest/', + 'protocolBinding': 'HTTP+JSON', 'protocolVersion': '1.0', }], + 'version': '1.0.0', 'capabilities': { 'streaming': True, + 'pushNotifications': False, 'extendedAgentCard': True, }, + 'defaultInputModes': ['text'], 'defaultOutputModes': ['text'], + 'skills': [{ 'id': 'question_answer', 'name': 'Q&A Agent', + 'description': ( 'A helpful assistant agent that can answer + questions.' ), 'tags': ['Question-Answer'], 'examples': [ 'Who is + leading 2025 F1 Standings?', 'Where can i find an active volcano?', + ], 'inputModes': ['text'], 'outputModes': ['text'], }], } + Returns: A callable object that executes the method on the Agent Engine via the A2A API. """ async def _method(self, **kwargs) -> Any: # type: ignore[no-untyped-def] - """Wraps an Agent Engine method, creating a callable for A2A API.""" if not self.api_client: raise ValueError("api_client is not initialized.") if not self.api_resource: raise ValueError("api_resource is not initialized.") - a2a_agent_card = AgentCard(**json.loads(agent_card)) - # A2A + AE integration currently only supports Rest API. - if ( - a2a_agent_card.preferred_transport - and a2a_agent_card.preferred_transport != TransportProtocol.http_json - ): - raise ValueError( - "Only HTTP+JSON is supported for preferred transport on agent card " - ) - # Set preferred transport to HTTP+JSON if not set. - if not hasattr(a2a_agent_card, "preferred_transport"): - a2a_agent_card.preferred_transport = TransportProtocol.http_json + a2a_agent_card = AgentCard() + json_format.ParseDict( + json.loads(agent_card), a2a_agent_card, ignore_unknown_fields=True + ) - if not hasattr(a2a_agent_card.capabilities, "streaming"): - a2a_agent_card.capabilities.streaming = False + if a2a_agent_card.supported_interfaces: + interface = a2a_agent_card.supported_interfaces[0] + if interface.protocol_binding != TransportProtocol.HTTP_JSON: + raise ValueError( + "Only HTTP+JSON is supported for preferred transport on agent card" + ) + else: + raise ValueError("Agent card does not define any supported interfaces.") - # agent_card is set on the class_methods before set_up is invoked. - # Ensure that the agent_card url is set correctly before the client is created. base_url = self.api_client._api_client._http_options.base_url.rstrip("/") api_version = self.api_client._api_client._http_options.api_version - a2a_agent_card.url = f"{base_url}/{api_version}/{self.api_resource.name}/a2a" + a2a_agent_card.supported_interfaces[0].url = ( + f"{base_url}/{api_version}/{self.api_resource.name}/a2a" + ) - # Using a2a client, inject the auth token from the global config. config = ClientConfig( - supported_transports=[ - TransportProtocol.http_json, + supported_protocol_bindings=[ + TransportProtocol.HTTP_JSON, ], use_client_preference=True, httpx_client=httpx.AsyncClient( @@ -1898,23 +1860,34 @@ async def _method(self, **kwargs) -> Any: # type: ignore[no-untyped-def] factory = ClientFactory(config) client = factory.create(a2a_agent_card) + context = kwargs.pop("context", None) + if context is not None: + from a2a.client.client import ClientCallContext + + if not isinstance(context, ClientCallContext): + actual_context = ClientCallContext() + if hasattr(context, "state"): + actual_context.state = context.state + elif isinstance(context, dict): + actual_context.state = context + context = actual_context + + req = kwargs["request"] if method_name == "on_message_send": - response = client.send_message(Message(**kwargs)) + response = client.send_message(req, context=context) chunks = [] async for chunk in response: chunks.append(chunk) return chunks elif method_name == "on_get_task": - response = await client.get_task(TaskQueryParams(**kwargs)) + return await client.get_task(req, context=context) elif method_name == "on_cancel_task": - response = await client.cancel_task(TaskIdParams(**kwargs)) - elif method_name == "handle_authenticated_agent_card": - response = await client.get_card() + return await client.cancel_task(req, context=context) + elif method_name == "on_get_extended_agent_card": + return await client.get_extended_agent_card(req, context=context) else: raise ValueError(f"Unknown method name: {method_name}") - return response - return _method # type: ignore[return-value] diff --git a/tests/unit/agentplatform/genai/replays/test_agent_engine_a2a_v1_methods.py b/tests/unit/agentplatform/genai/replays/test_agent_engine_a2a_v1_methods.py new file mode 100644 index 0000000000..200e590fc3 --- /dev/null +++ b/tests/unit/agentplatform/genai/replays/test_agent_engine_a2a_v1_methods.py @@ -0,0 +1,102 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# pylint: disable=protected-access,bad-continuation,missing-function-docstring + +from unittest import mock + +from tests.unit.agentplatform.genai.replays import pytest_helper +from agentplatform._genai import types +from google.genai import _api_client +import httpx +import pytest + + +# These tests target a2a-sdk 1.0, where the request/response types are protobuf +# messages and client errors are surfaced as `A2AClientError`. +pytest.importorskip( + "a2a.client.errors", reason="a2a-sdk not installed, skipping Agent Engine A2A tests" +) +from a2a.client import ( # noqa: E402 # pylint: disable=g-import-not-at-top,g-bad-import-order + errors as a2a_errors, +) +from a2a import types as a2a_types # noqa: E402 + +pytest_plugins = ("pytest_asyncio",) + + +def _build_send_message_request() -> "a2a_types.SendMessageRequest": + """Builds an a2a 1.0 SendMessageRequest proto for on_message_send.""" + return a2a_types.SendMessageRequest( + message=a2a_types.Message( + message_id="msg-123", + role=a2a_types.Role.ROLE_USER, + parts=[a2a_types.Part(text="Where will be the Super Bowl held in 2026?")], + ) + ) + + +@pytest.mark.asyncio +async def test_timeout_is_set(client): + agent_engine = client.agent_engines.get( + name="projects/964831358985/locations/us-central1/reasoningEngines/6859679872613089280", + ) + assert isinstance(agent_engine, types.AgentEngine) + + with mock.patch( + "httpx.AsyncClient", spec=httpx.AsyncClient + ) as mock_async_client_factory: + # Replay mode does not capture A2A calls so instead of relying on the + # real service, we simulate a failed call. + mock_response = httpx.Response( + 401, + request=httpx.Request("POST", "url"), + json={ + "error": { + "code": "UNAUTHENTICATED", + "message": "Authentication failed: Missing or invalid API key.", + } + }, + ) + mock_async_client_instance = mock_async_client_factory.return_value + mock_async_client_instance.post.return_value = mock_response + mock_async_client_instance.send.return_value = mock_response + + # These credentials are missing in replay mode, so we need to set a fake + # value. (This is not necessary in record mode.) + class FakeCredentials: + token = "fake-token" + + agent_engine.api_client._api_client._credentials = FakeCredentials() + + # In a2a 1.0 the wrapped operation forwards the `request` kwarg directly + # to `client.send_message(request)`, and HTTP failures surface as an + # `A2AClientError` (the legacy `A2AClientHTTPError` no longer exists). + with pytest.raises(a2a_errors.A2AClientError) as exc_info: + await agent_engine.on_message_send(request=_build_send_message_request()) + + # Make sure the authentication failure was propagated, otherwise the + # test is not validating the request path. + assert "401" in str(exc_info.value) + + mock_async_client_factory.assert_called_once() + assert mock_async_client_factory.call_args.kwargs["timeout"] == 99.0 + + +pytestmark = pytest_helper.setup( + file=__file__, + globals_for_file=globals(), + test_method="agent_engines.get", + http_options=_api_client.HttpOptions(timeout=99000), +) diff --git a/vertexai/_genai/_agent_engines_utils.py b/vertexai/_genai/_agent_engines_utils.py index 92d91addbb..f876238401 100644 --- a/vertexai/_genai/_agent_engines_utils.py +++ b/vertexai/_genai/_agent_engines_utils.py @@ -110,23 +110,39 @@ try: - from a2a.types import ( - AgentCard, - TransportProtocol, - Message, - TaskIdParams, - TaskQueryParams, - ) - from a2a.client import ClientConfig, ClientFactory - - AgentCard = AgentCard - TransportProtocol = TransportProtocol - Message = Message - ClientConfig = ClientConfig - ClientFactory = ClientFactory - TaskIdParams = TaskIdParams - TaskQueryParams = TaskQueryParams + from a2a.utils.constants import TransportProtocol as _A2aVersionTest # noqa: F401 + + _A2A_SDK_VERSION: Optional[str] = "1.0" +except ImportError: + try: + from a2a.types import TransportProtocol as _A2aVersionTest # noqa: F401 + + _A2A_SDK_VERSION = "0.3" + except ImportError: + _A2A_SDK_VERSION = None + +try: + if _A2A_SDK_VERSION == "1.0": + from a2a.types import ( + AgentCard, + Message, + ) + from a2a.client import ClientConfig, ClientFactory + from a2a.utils.constants import TransportProtocol + from a2a.compat.v0_3.types import TaskIdParams, TaskQueryParams + elif _A2A_SDK_VERSION == "0.3": + from a2a.types import ( + AgentCard, + TransportProtocol, + Message, + TaskIdParams, + TaskQueryParams, + ) + from a2a.client import ClientConfig, ClientFactory + else: + raise ImportError except (ImportError, AttributeError): + _A2A_SDK_VERSION = None AgentCard = None TransportProtocol = None Message = None @@ -134,6 +150,10 @@ ClientFactory = None TaskIdParams = None TaskQueryParams = None + SendMessageRequest = None + GetTaskRequest = None + CancelTaskRequest = None + GetExtendedAgentCardRequest = None _ACTIONS_KEY = "actions" _ACTION_APPEND = "append" @@ -1737,7 +1757,9 @@ async def _method(self: genai_types.AgentEngine, **kwargs) -> AsyncIterator[Any] return _method -def _wrap_a2a_operation(method_name: str, agent_card: str) -> Callable[..., list[Any]]: +def _wrap_a2a_operation_v03( + method_name: str, agent_card: str +) -> Callable[..., list[Any]]: """Wraps an Agent Engine method, creating a callable for A2A API. Args: @@ -1854,6 +1876,138 @@ async def _method(self, **kwargs) -> Any: # type: ignore[no-untyped-def] return _method # type: ignore[return-value] +def _wrap_a2a_operation_v10( + method_name: str, agent_card: str +) -> Callable[..., list[Any]]: + """Wraps an Agent Engine method, creating a callable for A2A API (v1.0.0+). + + Args: + method_name: The name of the Agent Engine method to call. + agent_card: The agent card JSON string to use for the A2A API call. + Example: + { + 'name': 'Sample Agent', + 'description': ( + 'A helpful assistant agent that can answer questions.' + ), + 'supportedInterfaces': [{ + 'url': 'http://localhost:8080/a2a/rest/', + 'protocolBinding': 'HTTP+JSON', + 'protocolVersion': '1.0', + }], + 'version': '1.0.0', + 'capabilities': { + 'streaming': True, + 'pushNotifications': False, + 'extendedAgentCard': True, + }, + 'defaultInputModes': ['text'], + 'defaultOutputModes': ['text'], + 'skills': [{ + 'id': 'question_answer', + 'name': 'Q&A Agent', + 'description': ( + 'A helpful assistant agent that can answer questions.' + ), + 'tags': ['Question-Answer'], + 'examples': [ + 'Who is leading 2025 F1 Standings?', + 'Where can i find an active volcano?', + ], + 'inputModes': ['text'], + 'outputModes': ['text'], + }], + } + + Returns: + A callable object that executes the method on the Agent Engine via + the A2A API. + """ + + async def _method(self, **kwargs) -> Any: # type: ignore[no-untyped-def] + if not self.api_client: + raise ValueError("api_client is not initialized.") + if not self.api_resource: + raise ValueError("api_resource is not initialized.") + + a2a_agent_card = AgentCard() + json_format.ParseDict( + json.loads(agent_card), a2a_agent_card, ignore_unknown_fields=True + ) + + if a2a_agent_card.supported_interfaces: + interface = a2a_agent_card.supported_interfaces[0] + if interface.protocol_binding != TransportProtocol.HTTP_JSON: + raise ValueError( + "Only HTTP+JSON is supported for preferred transport on agent card" + ) + else: + raise ValueError("Agent card does not define any supported interfaces.") + + base_url = self.api_client._api_client._http_options.base_url.rstrip("/") + api_version = self.api_client._api_client._http_options.api_version + a2a_agent_card.supported_interfaces[0].url = ( + f"{base_url}/{api_version}/{self.api_resource.name}/a2a" + ) + + config = ClientConfig( + supported_protocol_bindings=[ + TransportProtocol.HTTP_JSON, + ], + use_client_preference=True, + httpx_client=httpx.AsyncClient( + headers={ + "Authorization": ( + f"Bearer {self.api_client._api_client._credentials.token}" + ) + }, + timeout=( + self.api_client._api_client._http_options.timeout / 1000.0 + if self.api_client._api_client._http_options.timeout + else None + ), + ), + ) + factory = ClientFactory(config) + client = factory.create(a2a_agent_card) + + context = kwargs.pop("context", None) + if context is not None: + from a2a.client.client import ClientCallContext + + if not isinstance(context, ClientCallContext): + actual_context = ClientCallContext() + if hasattr(context, "state"): + actual_context.state = context.state + elif isinstance(context, dict): + actual_context.state = context + context = actual_context + + req = kwargs["request"] + if method_name == "on_message_send": + response = client.send_message(req, context=context) + chunks = [] + async for chunk in response: + chunks.append(chunk) + return chunks + elif method_name == "on_get_task": + return await client.get_task(req, context=context) + elif method_name == "on_cancel_task": + return await client.cancel_task(req, context=context) + elif method_name == "on_get_extended_agent_card": + return await client.get_extended_agent_card(req, context=context) + else: + raise ValueError(f"Unknown method name: {method_name}") + + return _method # type: ignore[return-value] + + +if _A2A_SDK_VERSION == "1.0": + _wrap_a2a_operation = _wrap_a2a_operation_v10 +else: + _wrap_a2a_operation = _wrap_a2a_operation_v03 + + def _yield_parsed_json(http_response: google_genai_types.HttpResponse) -> Iterator[Any]: """Converts the body of the HTTP Response message to JSON format.