From cca10c2f020524beae78aceb0ad0c7a6d9a7844e Mon Sep 17 00:00:00 2001 From: pragnyanramtha Date: Sat, 16 May 2026 00:41:16 +0000 Subject: [PATCH] fix streamable http post error isolation --- src/mcp/client/streamable_http.py | 16 ++- tests/shared/test_streamable_http.py | 148 +++++++++++++++++++++++++++ 2 files changed, 163 insertions(+), 1 deletion(-) diff --git a/src/mcp/client/streamable_http.py b/src/mcp/client/streamable_http.py index ed28fcc275..17a445a6a9 100644 --- a/src/mcp/client/streamable_http.py +++ b/src/mcp/client/streamable_http.py @@ -28,6 +28,7 @@ ) from mcp.shared.message import ClientMessageMetadata, SessionMessage from mcp.types import ( + INTERNAL_ERROR, ErrorData, InitializeResult, JSONRPCError, @@ -355,7 +356,20 @@ async def _handle_post_request(self, ctx: RequestContext) -> None: ) # pragma: no cover return # pragma: no cover - response.raise_for_status() + if response.status_code >= 400: + if isinstance(message.root, JSONRPCRequest): + jsonrpc_error = JSONRPCError( + jsonrpc="2.0", + id=message.root.id, + error=ErrorData( + code=INTERNAL_ERROR, + message=f"Server returned HTTP {response.status_code}", + data={"status_code": response.status_code}, + ), + ) + await ctx.read_stream_writer.send(SessionMessage(JSONRPCMessage(jsonrpc_error))) + return + if is_initialization: self._maybe_extract_session_id_from_response(response) diff --git a/tests/shared/test_streamable_http.py b/tests/shared/test_streamable_http.py index 731dd20dd3..43864404d6 100644 --- a/tests/shared/test_streamable_http.py +++ b/tests/shared/test_streamable_http.py @@ -1043,6 +1043,154 @@ async def test_streamable_http_client_error_handling(initialized_client_session: assert "Unknown resource: unknown://test-error" in exc_info.value.error.message +@pytest.mark.anyio +async def test_streamable_http_client_http_error_does_not_cancel_concurrent_request(): + """Test that one POST HTTP error does not tear down an unrelated request.""" + good_request_started = anyio.Event() + allow_good_response = anyio.Event() + + async def handler(request: httpx.Request) -> httpx.Response: + payload = json.loads(request.content) + request_id = payload["id"] + uri = payload["params"]["uri"] + + if uri == "foobar://bad": + with anyio.fail_after(5): + await good_request_started.wait() + return httpx.Response(400, request=request, json={"error": "boom"}) + + assert uri == "foobar://good" + good_request_started.set() + with anyio.fail_after(5): + await allow_good_response.wait() + return httpx.Response( + 200, + request=request, + headers={"content-type": "application/json"}, + json={ + "jsonrpc": "2.0", + "id": request_id, + "result": { + "contents": [ + { + "uri": uri, + "mimeType": "text/plain", + "text": "good response", + } + ] + }, + }, + ) + + good_result: types.ReadResourceResult | None = None + bad_error: Exception | None = None + bad_request_failed = anyio.Event() + + async def run_good_request(session: ClientSession) -> None: + nonlocal good_result + good_result = await session.send_request( + types.ClientRequest( + types.ReadResourceRequest( + params=types.ReadResourceRequestParams(uri=AnyUrl("foobar://good")), + ) + ), + types.ReadResourceResult, + ) + + async def run_bad_request(session: ClientSession) -> None: + nonlocal bad_error + try: + await session.send_request( + types.ClientRequest( + types.ReadResourceRequest( + params=types.ReadResourceRequestParams(uri=AnyUrl("foobar://bad")), + ) + ), + types.ReadResourceResult, + ) + except Exception as exc: + bad_error = exc + bad_request_failed.set() + + transport = httpx.MockTransport(handler) + async with httpx.AsyncClient(transport=transport) as http_client: + async with streamable_http_client("http://test/mcp", http_client=http_client) as streams: # pragma: no branch + read_stream, write_stream, _ = streams + async with ClientSession(read_stream, write_stream) as session: # pragma: no branch + async with anyio.create_task_group() as tg: + tg.start_soon(run_good_request, session) + with anyio.fail_after(5): + await good_request_started.wait() + + tg.start_soon(run_bad_request, session) + + with anyio.fail_after(5): + await bad_request_failed.wait() + + allow_good_response.set() + + assert isinstance(bad_error, McpError) + assert bad_error.error.code == types.INTERNAL_ERROR + assert bad_error.error.message == "Server returned HTTP 400" + assert bad_error.error.data == {"status_code": 400} + assert good_result is not None + assert isinstance(good_result.contents[0], types.TextResourceContents) + assert good_result.contents[0].text == "good response" + + +@pytest.mark.anyio +async def test_streamable_http_client_notification_http_error_does_not_cancel_transport(): + """Test POST HTTP errors for notifications do not synthesize responses.""" + notification_seen = anyio.Event() + + async def handler(request: httpx.Request) -> httpx.Response: + payload = json.loads(request.content) + + if "id" not in payload: + notification_seen.set() + return httpx.Response(500, request=request, json={"error": "boom"}) + + return httpx.Response( + 200, + request=request, + headers={"content-type": "application/json"}, + json={ + "jsonrpc": "2.0", + "id": payload["id"], + "result": { + "contents": [ + { + "uri": "foobar://good", + "mimeType": "text/plain", + "text": "good response", + } + ] + }, + }, + ) + + transport = httpx.MockTransport(handler) + async with httpx.AsyncClient(transport=transport) as http_client: + async with streamable_http_client("http://test/mcp", http_client=http_client) as streams: # pragma: no branch + read_stream, write_stream, _ = streams + async with ClientSession(read_stream, write_stream) as session: # pragma: no branch + await session.send_notification(types.ClientNotification(types.RootsListChangedNotification())) + with anyio.fail_after(5): + await notification_seen.wait() + + result = await session.send_request( + types.ClientRequest( + types.ReadResourceRequest( + params=types.ReadResourceRequestParams(uri=AnyUrl("foobar://good")), + ) + ), + types.ReadResourceResult, + ) + + assert isinstance(result.contents[0], types.TextResourceContents) + assert result.contents[0].text == "good response" + + @pytest.mark.anyio async def test_streamable_http_client_session_persistence(basic_server: None, basic_server_url: str): """Test that session ID persists across requests."""