Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions tee_gateway/controllers/chat_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -671,6 +671,44 @@ def generate():
"model": chat_request.model,
}
yield f"data: {json.dumps(data)}\n\n"
else:
# Buffered providers (OpenAI/Anthropic) hold the
# full tool call until the stream ends so the
# signature can cover the complete arguments.
# Emitting nothing until then leaves the client
# with no "tool starting" signal — a large
# write_file looks frozen for minutes — and lets
# the connection idle out. Emit a lightweight
# progress frame carrying just the tool
# name/index (never the accumulating arguments),
# so the UI can show the tool immediately and
# bytes keep flowing. The authoritative tool call
# (with arguments) is still flushed once after the
# loop, ahead of the signed final frame.
progress_tool_call = {
"index": tc_index,
"type": "function",
}
if tc_chunk.get("id"):
progress_tool_call["id"] = tc_chunk["id"]
if tc_chunk.get("name"):
progress_tool_call["function"] = {
"name": tc_chunk["name"]
}
data = {
"choices": [
{
"delta": {
"role": "assistant",
"tool_calls": [progress_tool_call],
},
"index": 0,
"finish_reason": None,
}
],
"model": chat_request.model,
}
yield f"data: {json.dumps(data)}\n\n"

# --- Usage metadata ---
# Accumulate deltas rather than replacing: Gemini returns cumulative
Expand Down
122 changes: 122 additions & 0 deletions tee_gateway/test/test_tool_forwarding.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
import unittest
from unittest.mock import patch, Mock

Expand Down Expand Up @@ -30,6 +31,15 @@ def __init__(self, content="", tool_calls=None, usage=None):
}


class _MockStreamChunk:
"""Minimal stand-in for a LangChain AIMessageChunk yielded by model.stream()."""

def __init__(self, content="", tool_call_chunks=None, usage=None):
self.content = content
self.tool_call_chunks = tool_call_chunks or []
self.usage_metadata = usage


def _make_mock_model(response: _MockLangChainResponse) -> Mock:
"""Return a mock LangChain chat model whose invoke() returns *response*."""
mock_model = Mock()
Expand Down Expand Up @@ -275,6 +285,118 @@ def test_tee_metadata_in_response(
# tee_id must carry the 0x prefix
self.assertTrue(result["tee_id"].startswith("0x"))

@patch("tee_gateway.controllers.chat_controller.get_tee_keys")
@patch("tee_gateway.controllers.chat_controller.get_chat_model_cached")
@patch("tee_gateway.controllers.chat_controller.connexion")
def test_streaming_tool_call_emits_early_progress_frame(
    self, mock_connexion, mock_get_model, mock_get_tee_keys
):
    """For buffered providers (OpenAI/Anthropic) the streaming response must
    emit a lightweight progress frame carrying just the tool name as soon as
    it is known — so the client can show "writing file…" immediately and the
    connection keeps producing bytes — while the complete, argument-bearing
    tool call is still flushed (and signed) at the end of the stream.

    Without this, a large write_file tool call streams nothing to the client
    until the whole file has been generated (minutes), which both hides the
    activity and lets the read timeout trip.
    """
    mock_connexion.request.is_json = True
    mock_connexion.request.get_json.return_value = {
        "model": "gpt-4.1",
        "messages": [{"role": "user", "content": "Write a file."}],
        "tools": [
            {
                "type": "function",
                "function": {
                    "name": "write_file",
                    "description": "Write a file",
                    "parameters": {"type": "object", "properties": {}},
                },
            }
        ],
        "stream": True,
    }

    # Real providers stream the name in the first fragment and the arguments
    # across many later fragments; mirror that so the test exercises the
    # "name known early, args trickle in" ordering.
    chunks = [
        _MockStreamChunk(
            tool_call_chunks=[
                {"index": 0, "id": "call_1", "name": "write_file", "args": ""}
            ]
        ),
        _MockStreamChunk(
            tool_call_chunks=[{"index": 0, "args": '{"path": "app.js",'}]
        ),
        _MockStreamChunk(
            tool_call_chunks=[{"index": 0, "args": ' "content": "hi"}'}]
        ),
        _MockStreamChunk(
            usage={
                "input_tokens": 5,
                "output_tokens": 7,
                "total_tokens": 12,
            }
        ),
    ]
    mock_model = Mock()
    mock_model.stream.return_value = iter(chunks)
    # bind_tools() returns the same mock so stream() is reached either way.
    mock_model.bind_tools.return_value = mock_model
    mock_get_model.return_value = mock_model
    mock_get_tee_keys.return_value = _make_mock_tee_keys()

    response = create_chat_completion(None)
    # The response iterable may hand back SSE frames as str or bytes;
    # normalise to str before joining.
    body = "".join(
        part.decode("utf-8") if isinstance(part, bytes) else part
        for part in response.response
    )

    # Parse the SSE frames in arrival order, stopping at the terminator.
    frames = []
    for line in body.split("\n"):
        if not line.startswith("data: "):
            continue
        payload = line[len("data: ") :]
        if payload.strip() == "[DONE]":
            break
        frames.append(json.loads(payload))

    # Collect (frame_index, tool_call_delta) pairs in order. Frames without
    # choices, and deltas whose tool_calls is missing or null, contribute
    # nothing instead of raising.
    tool_deltas = []
    for i, frame in enumerate(frames):
        choices = frame.get("choices") or [{}]
        delta = choices[0].get("delta") or {}
        for tc in delta.get("tool_calls") or []:
            tool_deltas.append((i, tc))
    self.assertGreaterEqual(
        len(tool_deltas),
        2,
        "expected an early progress frame plus the final flushed tool call",
    )

    # First tool delta announces the name but carries no arguments yet.
    first_tool_frame_index, first = tool_deltas[0]
    self.assertEqual(first.get("function", {}).get("name"), "write_file")
    self.assertNotIn("arguments", first.get("function", {}))

    # The complete arguments are flushed in a later frame. Check every
    # field so a truncated or partially concatenated payload cannot pass.
    with_args = [
        (i, tc)
        for i, tc in tool_deltas
        if tc.get("function", {}).get("arguments")
    ]
    self.assertTrue(with_args, "expected a frame with the complete arguments")
    args_frame_index, final_tool_call = with_args[-1]
    args = json.loads(final_tool_call["function"]["arguments"])
    self.assertEqual(args["path"], "app.js")
    self.assertEqual(args["content"], "hi")

    # The progress frame must precede the argument-bearing frame; otherwise
    # the client learns nothing before generation finishes.
    self.assertLess(first_tool_frame_index, args_frame_index)

    # The progress frame(s) arrive strictly before the final signed frame, so
    # the client sees activity well before generation finishes. A default
    # keeps a missing signed frame an assertion failure rather than a bare
    # StopIteration error.
    signed_frame_index = next(
        (i for i, frame in enumerate(frames) if "tee_signature" in frame),
        None,
    )
    self.assertIsNotNone(
        signed_frame_index, "expected a final frame carrying tee_signature"
    )
    self.assertLess(first_tool_frame_index, signed_frame_index)

@patch("tee_gateway.controllers.chat_controller.get_tee_keys")
@patch("tee_gateway.controllers.chat_controller.get_chat_model_cached")
@patch("tee_gateway.controllers.chat_controller.connexion")
Expand Down
Loading