diff --git a/tee_gateway/controllers/chat_controller.py b/tee_gateway/controllers/chat_controller.py index c48d7a6..608f698 100644 --- a/tee_gateway/controllers/chat_controller.py +++ b/tee_gateway/controllers/chat_controller.py @@ -671,6 +671,44 @@ def generate(): "model": chat_request.model, } yield f"data: {json.dumps(data)}\n\n" + else: + # Buffered providers (OpenAI/Anthropic) hold the + # full tool call until the stream ends so the + # signature can cover the complete arguments. + # Emitting nothing until then leaves the client + # with no "tool starting" signal — a large + # write_file looks frozen for minutes — and lets + # the connection idle out. Emit a lightweight + # progress frame carrying just the tool + # name/index (never the accumulating arguments), + # so the UI can show the tool immediately and + # bytes keep flowing. The authoritative tool call + # (with arguments) is still flushed once after the + # loop, ahead of the signed final frame. + progress_tool_call = { + "index": tc_index, + "type": "function", + } + if tc_chunk.get("id"): + progress_tool_call["id"] = tc_chunk["id"] + if tc_chunk.get("name"): + progress_tool_call["function"] = { + "name": tc_chunk["name"] + } + data = { + "choices": [ + { + "delta": { + "role": "assistant", + "tool_calls": [progress_tool_call], + }, + "index": 0, + "finish_reason": None, + } + ], + "model": chat_request.model, + } + yield f"data: {json.dumps(data)}\n\n" # --- Usage metadata --- # Accumulate deltas rather than replacing: Gemini returns cumulative diff --git a/tee_gateway/test/test_tool_forwarding.py b/tee_gateway/test/test_tool_forwarding.py index f4ca16f..c0cedb2 100644 --- a/tee_gateway/test/test_tool_forwarding.py +++ b/tee_gateway/test/test_tool_forwarding.py @@ -1,3 +1,4 @@ +import json import unittest from unittest.mock import patch, Mock @@ -30,6 +31,15 @@ def __init__(self, content="", tool_calls=None, usage=None): } +class _MockStreamChunk: + """Minimal stand-in for a LangChain AIMessageChunk 
@patch("tee_gateway.controllers.chat_controller.get_tee_keys")
@patch("tee_gateway.controllers.chat_controller.get_chat_model_cached")
@patch("tee_gateway.controllers.chat_controller.connexion")
def test_streaming_tool_call_emits_early_progress_frame(
    self, mock_connexion, mock_get_model, mock_get_tee_keys
):
    """Streaming tool calls must be announced before their arguments finish.

    For buffered providers (OpenAI/Anthropic) the streaming response must
    emit a lightweight progress frame carrying just the tool name as soon as
    it is known, so the client can show "writing file..." immediately and the
    connection keeps producing bytes, while the complete, argument-bearing
    tool call is still flushed (and signed) at the end of the stream.

    Without this, a large write_file tool call streams nothing to the client
    until the whole file has been generated (minutes), which both hides the
    activity and lets the read timeout trip.

    The mocks are injected by the ``@patch`` decorators in bottom-up order:
    ``connexion``, ``get_chat_model_cached``, then ``get_tee_keys``.
    """
    mock_connexion.request.is_json = True
    mock_connexion.request.get_json.return_value = {
        "model": "gpt-4.1",
        "messages": [{"role": "user", "content": "Write a file."}],
        "tools": [
            {
                "type": "function",
                "function": {
                    "name": "write_file",
                    "description": "Write a file",
                    "parameters": {"type": "object", "properties": {}},
                },
            }
        ],
        "stream": True,
    }

    # Real providers stream the name in the first fragment and the arguments
    # across many later fragments; mirror that so the test exercises the
    # "name known early, args trickle in" ordering.
    chunks = [
        _MockStreamChunk(
            tool_call_chunks=[
                {"index": 0, "id": "call_1", "name": "write_file", "args": ""}
            ]
        ),
        _MockStreamChunk(
            tool_call_chunks=[{"index": 0, "args": '{"path": "app.js",'}]
        ),
        _MockStreamChunk(
            tool_call_chunks=[{"index": 0, "args": ' "content": "hi"}'}]
        ),
        _MockStreamChunk(
            usage={
                "input_tokens": 5,
                "output_tokens": 7,
                "total_tokens": 12,
            }
        ),
    ]
    mock_model = Mock()
    mock_model.stream.return_value = iter(chunks)
    # bind_tools() returns the same mock so stream() is reached either way.
    mock_model.bind_tools.return_value = mock_model
    mock_get_model.return_value = mock_model
    mock_get_tee_keys.return_value = _make_mock_tee_keys()

    response = create_chat_completion(None)
    # The generator yields SSE frames as str; join them directly (bytes are
    # decoded defensively in case the response wrapper encodes them).
    body = "".join(
        part.decode("utf-8") if isinstance(part, bytes) else part
        for part in response.response
    )

    # Parse the SSE frames in arrival order, stopping at the terminator.
    frames = []
    for line in body.split("\n"):
        if not line.startswith("data: "):
            continue
        payload = line[len("data: ") :]
        if payload.strip() == "[DONE]":
            break
        frames.append(json.loads(payload))

    # Collect (frame_index, tool_call_delta) pairs in order.
    tool_deltas = [
        (i, tc)
        for i, frame in enumerate(frames)
        for tc in (frame["choices"][0].get("delta") or {}).get("tool_calls", [])
    ]
    self.assertGreaterEqual(
        len(tool_deltas),
        2,
        "expected an early progress frame plus the final flushed tool call",
    )

    # First tool delta announces the name but carries no arguments yet.
    first_tool_frame_index, first = tool_deltas[0]
    self.assertEqual(first.get("function", {}).get("name"), "write_file")
    self.assertNotIn("arguments", first.get("function", {}))

    # The complete arguments are flushed in a single later frame; progress
    # frames must never leak the partially accumulated arguments, so exactly
    # one delta may carry them.
    with_args = [
        tc for _, tc in tool_deltas if tc.get("function", {}).get("arguments")
    ]
    self.assertEqual(
        len(with_args),
        1,
        "expected exactly one frame with the complete arguments",
    )
    args = json.loads(with_args[0]["function"]["arguments"])
    self.assertEqual(args, {"path": "app.js", "content": "hi"})

    # The progress frame(s) arrive strictly before the final signed frame, so
    # the client sees activity well before generation finishes. Use a default
    # with next() so a missing signature is reported as an assertion failure
    # instead of an unhandled StopIteration.
    signed_frame_index = next(
        (i for i, frame in enumerate(frames) if "tee_signature" in frame),
        None,
    )
    self.assertIsNotNone(
        signed_frame_index,
        "expected a final frame carrying tee_signature",
    )
    self.assertLess(first_tool_frame_index, signed_frame_index)