Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions backend/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,11 @@ COPY pyproject.toml ./
ARG CLAWITH_PIP_INDEX_URL
ARG CLAWITH_PIP_TRUSTED_HOST
RUN if [ -n "$CLAWITH_PIP_INDEX_URL" ] && [ -n "$CLAWITH_PIP_TRUSTED_HOST" ]; then \
pip install --no-cache-dir --index-url "$CLAWITH_PIP_INDEX_URL" --trusted-host "$CLAWITH_PIP_TRUSTED_HOST" .; \
pip install --no-cache-dir --index-url "$CLAWITH_PIP_INDEX_URL" --trusted-host "$CLAWITH_PIP_TRUSTED_HOST" . uv; \
elif [ -n "$CLAWITH_PIP_INDEX_URL" ]; then \
pip install --no-cache-dir --index-url "$CLAWITH_PIP_INDEX_URL" .; \
pip install --no-cache-dir --index-url "$CLAWITH_PIP_INDEX_URL" . uv; \
else \
pip install --no-cache-dir .; \
pip install --no-cache-dir . uv; \
fi

# ─── Production ─────────────────────────────────────────
Expand Down
50 changes: 31 additions & 19 deletions backend/alembic/versions/056_add_user_tenant_onboarding.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,25 +19,37 @@


def upgrade() -> None:
op.create_table(
"user_tenant_onboardings",
sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True, nullable=False),
sa.Column("user_id", postgresql.UUID(as_uuid=True), sa.ForeignKey("users.id", ondelete="CASCADE"), nullable=False),
sa.Column("tenant_id", postgresql.UUID(as_uuid=True), sa.ForeignKey("tenants.id", ondelete="CASCADE"), nullable=False),
sa.Column("status", sa.String(length=32), nullable=False, server_default="in_progress"),
sa.Column("current_step", sa.String(length=32), nullable=False, server_default="assistant"),
sa.Column("entry_mode", sa.String(length=32), nullable=False, server_default="create"),
sa.Column("personal_assistant_agent_id", postgresql.UUID(as_uuid=True), sa.ForeignKey("agents.id", ondelete="SET NULL"), nullable=True),
sa.Column("started_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False),
sa.Column("completed_at", sa.DateTime(timezone=True), nullable=True),
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False),
sa.UniqueConstraint("user_id", "tenant_id", name="uq_user_tenant_onboarding"),
)
op.create_index("ix_user_tenant_onboardings_user_id", "user_tenant_onboardings", ["user_id"])
op.create_index("ix_user_tenant_onboardings_tenant_id", "user_tenant_onboardings", ["tenant_id"])
bind = op.get_bind()
inspector = sa.inspect(bind)

if "user_tenant_onboardings" not in inspector.get_table_names():
op.create_table(
"user_tenant_onboardings",
sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True, nullable=False),
sa.Column("user_id", postgresql.UUID(as_uuid=True), sa.ForeignKey("users.id", ondelete="CASCADE"), nullable=False),
sa.Column("tenant_id", postgresql.UUID(as_uuid=True), sa.ForeignKey("tenants.id", ondelete="CASCADE"), nullable=False),
sa.Column("status", sa.String(length=32), nullable=False, server_default="in_progress"),
sa.Column("current_step", sa.String(length=32), nullable=False, server_default="assistant"),
sa.Column("entry_mode", sa.String(length=32), nullable=False, server_default="create"),
sa.Column("personal_assistant_agent_id", postgresql.UUID(as_uuid=True), sa.ForeignKey("agents.id", ondelete="SET NULL"), nullable=True),
sa.Column("started_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False),
sa.Column("completed_at", sa.DateTime(timezone=True), nullable=True),
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False),
sa.UniqueConstraint("user_id", "tenant_id", name="uq_user_tenant_onboarding"),
)

op.create_index("ix_user_tenant_onboardings_user_id", "user_tenant_onboardings", ["user_id"])
op.create_index("ix_user_tenant_onboardings_tenant_id", "user_tenant_onboardings", ["tenant_id"])


def downgrade() -> None:
op.drop_index("ix_user_tenant_onboardings_tenant_id", table_name="user_tenant_onboardings")
op.drop_index("ix_user_tenant_onboardings_user_id", table_name="user_tenant_onboardings")
op.drop_table("user_tenant_onboardings")
bind = op.get_bind()
inspector = sa.inspect(bind)

if "user_tenant_onboardings" in inspector.get_table_names():
indexes = [ix["name"] for ix in inspector.get_indexes("user_tenant_onboardings")]
if "ix_user_tenant_onboardings_tenant_id" in indexes:
op.drop_index("ix_user_tenant_onboardings_tenant_id", table_name="user_tenant_onboardings")
if "ix_user_tenant_onboardings_user_id" in indexes:
op.drop_index("ix_user_tenant_onboardings_user_id", table_name="user_tenant_onboardings")
op.drop_table("user_tenant_onboardings")
4 changes: 2 additions & 2 deletions backend/app/services/agent_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,6 @@ async def _get_tool_config(agent_id: Optional[uuid.UUID], tool_name: str) -> Opt
base_config = global_config or {}
tenant_config = {}
if tool_source == "builtin":
base_config = {}
tenant_config = await get_tenant_tool_config(db, agent_tenant_id, db_tool_name, config_schema)
# Merge: agent overrides global
merged = {**base_config, **tenant_config, **(agent_config or {})}
Expand All @@ -206,7 +205,7 @@ async def _get_tool_config(agent_id: Optional[uuid.UUID], tool_name: str) -> Opt
tenant_config = {}
if tool.source == "builtin":
tenant_config = await get_tenant_tool_config(db, agent_tenant_id, tool.name, tool.config_schema)
base_config = {} if tool.source == "builtin" else (tool.config or {})
base_config = tool.config or {}
merged = {**base_config, **tenant_config}
else:
merged = {}
Expand Down Expand Up @@ -8047,6 +8046,7 @@ async def _execute_code(
timeout=timeout,
work_dir=str(work_dir),
on_output=on_output,
agent_id=agent_id,
)

# Format result for user display
Expand Down
58 changes: 46 additions & 12 deletions backend/app/services/heartbeat.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,7 @@ async def _execute_heartbeat(agent_id: uuid.UUID):
plaza_posts_made = 0 # hard limit: 1 new post per heartbeat
plaza_comments_made = 0 # hard limit: 2 comments per heartbeat
_hb_accumulated_usage = None
_hb_unsaved_usage = None

# Token tracking helpers
from app.services.token_tracker import (
Expand All @@ -304,6 +305,7 @@ async def _execute_heartbeat(agent_id: uuid.UUID):
estimate_token_usage_from_chars,
)
_hb_accumulated_usage = TokenUsage()
_hb_unsaved_usage = TokenUsage()

# Convert messages to LLMMessage format
llm_messages = [
Expand All @@ -312,6 +314,21 @@ async def _execute_heartbeat(agent_id: uuid.UUID):
]

for round_i in range(20): # More rounds for search + write + plaza
# Check token usage limit mid-loop (every 3 rounds)
if round_i > 0 and round_i % 3 == 0:
if agent_id and _hb_unsaved_usage.total_tokens > 0:
async with async_session() as db:
await record_token_usage(agent_id, _hb_unsaved_usage)
await db.commit()
_hb_unsaved_usage = TokenUsage()
from app.services.llm.caller import _get_agent_config
_, _token_limit_msg = await _get_agent_config(agent_id)
if _token_limit_msg:
logger.warning(f"[Heartbeat] Token limit exceeded mid-loop: {_token_limit_msg}")
await client.close()
reply = _token_limit_msg
break

try:
response = await client.complete(
messages=llm_messages,
Expand All @@ -330,11 +347,11 @@ async def _execute_heartbeat(agent_id: uuid.UUID):

# Track tokens for this round
usage = extract_token_usage(response.usage)
if usage:
_hb_accumulated_usage.add(usage)
else:
if not usage:
round_chars = sum(len(m.content or '') for m in llm_messages) + len(response.content or '')
_hb_accumulated_usage.add(estimate_token_usage_from_chars(round_chars))
usage = estimate_token_usage_from_chars(round_chars)
_hb_accumulated_usage.add(usage)
_hb_unsaved_usage.add(usage)

if response.tool_calls:
# Add assistant message with tool calls
Expand Down Expand Up @@ -423,8 +440,8 @@ async def _execute_heartbeat(agent_id: uuid.UUID):
# ── Phase 3: Write results back to DB (short transaction) ──
async with async_session() as db:
# Record accumulated heartbeat token usage
if _hb_accumulated_usage and _hb_accumulated_usage.total_tokens > 0:
await record_token_usage(agent_id, _hb_accumulated_usage)
if _hb_unsaved_usage and _hb_unsaved_usage.total_tokens > 0:
await record_token_usage(agent_id, _hb_unsaved_usage)
await db.commit()

# Log activity if not empty
Expand Down Expand Up @@ -687,8 +704,25 @@ async def run_agent_oneshot(

reply = ""
accumulated_usage = TokenUsage()
unsaved_usage = TokenUsage()

for round_i in range(max_rounds):
# Check token usage limit mid-loop (every 3 rounds)
if round_i > 0 and round_i % 3 == 0:
if agent_id and unsaved_usage.total_tokens > 0:
try:
await record_token_usage(agent_id, unsaved_usage)
except Exception as e:
logger.warning(f"[Oneshot] Failed to record token usage mid-loop: {e}")
unsaved_usage = TokenUsage()
from app.services.llm.caller import _get_agent_config
_, _token_limit_msg = await _get_agent_config(agent_id)
if _token_limit_msg:
logger.warning(f"[Oneshot] Token limit exceeded mid-loop: {_token_limit_msg}")
await client.close()
reply = _token_limit_msg
break

try:
response = await client.complete(
messages=llm_messages,
Expand All @@ -713,11 +747,11 @@ async def run_agent_oneshot(

# Track token usage
usage = extract_token_usage(response.usage)
if usage:
accumulated_usage.add(usage)
else:
if not usage:
round_chars = sum(len(m.content or "") for m in llm_messages) + len(response.content or "")
accumulated_usage.add(estimate_token_usage_from_chars(round_chars))
usage = estimate_token_usage_from_chars(round_chars)
accumulated_usage.add(usage)
unsaved_usage.add(usage)

if response.tool_calls:
llm_messages.append(LLMMessage(
Expand Down Expand Up @@ -769,9 +803,9 @@ async def run_agent_oneshot(
await client.close()

# ── Phase 3: Record token usage (best-effort) ───────────────────────────
if accumulated_usage.total_tokens > 0:
if unsaved_usage.total_tokens > 0:
try:
await record_token_usage(agent_id, accumulated_usage)
await record_token_usage(agent_id, unsaved_usage)
except Exception as e:
logger.warning(f"[Oneshot] Failed to record token usage: {e}")

Expand Down
71 changes: 51 additions & 20 deletions backend/app/services/llm/caller.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,14 +93,17 @@ def _sanitize_tool_calls_for_context(tool_calls: list[dict]) -> tuple[list[dict]
"Retry the tool call with `function.arguments` as one valid JSON object string."
)

sanitized.append({
new_tc = {
"id": tc.get("id", ""),
"type": tc.get("type") or "function",
"function": {
"name": tool_name,
"arguments": args_str,
},
})
}
if "_gemini_extra" in tc:
new_tc["_gemini_extra"] = tc["_gemini_extra"]
sanitized.append(new_tc)

return sanitized, None

Expand Down Expand Up @@ -497,6 +500,7 @@ async def _default_on_tool_call(data: dict):

max_tokens = get_max_tokens(model.provider, model.model, getattr(model, 'max_output_tokens', None))
_accumulated_usage = TokenUsage()
_unsaved_usage = TokenUsage()

# Tool-calling loop
for round_i in range(_max_tool_rounds):
Expand All @@ -518,6 +522,17 @@ async def _default_on_tool_call(data: dict):
content="🚨 仅剩 2 轮工具调用。请立即使用 upsert_focus_item 保存进度并设置续接触发器。",
))

# Check token usage limit mid-loop (every 3 rounds)
if round_i > 0 and round_i % 3 == 0:
if agent_id and _unsaved_usage.total_tokens > 0:
await record_token_usage(agent_id, _unsaved_usage)
_unsaved_usage = TokenUsage()
_, _token_limit_msg = await _get_agent_config(agent_id)
if _token_limit_msg:
logger.warning(f"[LLM] Token limit exceeded mid-loop: {_token_limit_msg}")
await client.close()
return _token_limit_msg

try:
# Use streaming API for real-time responses
async def _buffer_chunk(_text: str) -> None:
Expand All @@ -535,19 +550,21 @@ async def _buffer_chunk(_text: str) -> None:
)
except LLMError as e:
logger.error(f"[LLM] LLMError: provider={getattr(model, 'provider', '?')} model={getattr(model, 'model', '?')} {e}")
if agent_id and _accumulated_usage.total_tokens > 0:
await record_token_usage(agent_id, _accumulated_usage)
if agent_id and _unsaved_usage.total_tokens > 0:
await record_token_usage(agent_id, _unsaved_usage)
await client.close()
return f"[LLM Error] {e}"
except Exception as e:
logger.exception(f"[LLM] Unexpected error: {type(e).__name__}: {str(e)[:300]}")
if agent_id and _accumulated_usage.total_tokens > 0:
await record_token_usage(agent_id, _accumulated_usage)
if agent_id and _unsaved_usage.total_tokens > 0:
await record_token_usage(agent_id, _unsaved_usage)
await client.close()
return f"[LLM call error] {type(e).__name__}: {str(e)[:200]}"

# Track tokens for this round
_accumulated_usage.add(_usage_from_response_or_estimate(response, api_messages))
_usage_this_round = _usage_from_response_or_estimate(response, api_messages)
_accumulated_usage.add(_usage_this_round)
_unsaved_usage.add(_usage_this_round)

# Plain assistant text is not a stop condition. The model must finish
# explicitly via finish(content=...).
Expand All @@ -567,8 +584,8 @@ async def _buffer_chunk(_text: str) -> None:
finish_call = find_finish_call(sanitized_tool_calls)
if finish_call:
if finish_call.valid:
if agent_id and _accumulated_usage.total_tokens > 0:
await record_token_usage(agent_id, _accumulated_usage)
if agent_id and _unsaved_usage.total_tokens > 0:
await record_token_usage(agent_id, _unsaved_usage)
await client.close()
return finish_call.content

Expand Down Expand Up @@ -616,8 +633,8 @@ async def _buffer_chunk(_text: str) -> None:
))

# Record tokens even on "too many rounds" exit
if agent_id and _accumulated_usage.total_tokens > 0:
await record_token_usage(agent_id, _accumulated_usage)
if agent_id and _unsaved_usage.total_tokens > 0:
await record_token_usage(agent_id, _unsaved_usage)
await client.close()
return "[Error] Too many tool call rounds"

Expand Down Expand Up @@ -882,6 +899,7 @@ async def call_agent_llm_with_tools(
async def _try_model(model: LLMModel) -> tuple[str, bool, bool]:
"""Try to complete with a model. Returns (response, success, tool_executed)."""
_accumulated_usage = TokenUsage()
_unsaved_usage = TokenUsage()
tool_executed = False
try:
client = create_llm_client(
Expand All @@ -900,6 +918,17 @@ async def _try_model(model: LLMModel) -> tuple[str, bool, bool]:
# Tool-calling loop
api_messages = list(messages)
for round_i in range(max_rounds):
# Check token usage limit mid-loop (every 3 rounds)
if round_i > 0 and round_i % 3 == 0:
if agent_id and _unsaved_usage.total_tokens > 0:
await record_token_usage(agent_id, _unsaved_usage)
_unsaved_usage = TokenUsage()
_, _token_limit_msg = await _get_agent_config(agent_id)
if _token_limit_msg:
logger.warning(f"[call_agent_llm_with_tools] Token limit exceeded mid-loop: {_token_limit_msg}")
await client.close()
return _token_limit_msg, False, tool_executed

try:
response = await client.complete(
messages=api_messages,
Expand All @@ -910,12 +939,14 @@ async def _try_model(model: LLMModel) -> tuple[str, bool, bool]:
except Exception as e:
logger.error(f"[call_agent_llm_with_tools] Agent {agent_id}: LLM call error: {e}")
await client.close()
if agent_id and _accumulated_usage.total_tokens > 0:
await record_token_usage(agent_id, _accumulated_usage)
if agent_id and _unsaved_usage.total_tokens > 0:
await record_token_usage(agent_id, _unsaved_usage)
raise

# Track tokens for this round
_accumulated_usage.add(_usage_from_response_or_estimate(response, api_messages))
_usage_this_round = _usage_from_response_or_estimate(response, api_messages)
_accumulated_usage.add(_usage_this_round)
_unsaved_usage.add(_usage_this_round)

if not response.tool_calls:
if response.content:
Expand All @@ -932,8 +963,8 @@ async def _try_model(model: LLMModel) -> tuple[str, bool, bool]:
finish_call = find_finish_call(sanitized_tool_calls)
if finish_call:
if finish_call.valid:
if agent_id and _accumulated_usage.total_tokens > 0:
await record_token_usage(agent_id, _accumulated_usage)
if agent_id and _unsaved_usage.total_tokens > 0:
await record_token_usage(agent_id, _unsaved_usage)
await client.close()
return finish_call.content, True, tool_executed
api_messages.append(LLMMessage(
Expand Down Expand Up @@ -982,14 +1013,14 @@ async def _try_model(model: LLMModel) -> tuple[str, bool, bool]:
content=str(result),
))

if agent_id and _accumulated_usage.total_tokens > 0:
await record_token_usage(agent_id, _accumulated_usage)
if agent_id and _unsaved_usage.total_tokens > 0:
await record_token_usage(agent_id, _unsaved_usage)
await client.close()
return "[Error] Too many tool call rounds", False, tool_executed

except Exception as e:
if agent_id and _accumulated_usage.total_tokens > 0:
await record_token_usage(agent_id, _accumulated_usage)
if agent_id and _unsaved_usage.total_tokens > 0:
await record_token_usage(agent_id, _unsaved_usage)
return f"[Error] {e}", False, tool_executed

# Try primary model
Expand Down
Loading