Skip to content

Commit fb94450

Browse files
authored
refactor: streamline AutoTrajectoryContext management
- Updated the context manager to return the trajectory directly upon entering, simplifying the capture_auto_trajectory function. - Ensured the trajectory is properly finished upon exiting the context, improving resource management.
1 parent 31e4fd9 commit fb94450

2 files changed

Lines changed: 9 additions & 8 deletions

File tree

src/art/auto_trajectory.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -62,10 +62,8 @@ def auto_trajectory(*, required: bool = False) -> Trajectory | None:
6262

6363

6464
async def capture_auto_trajectory(coroutine: Coroutine[Any, Any, Any]) -> Trajectory:
65-
with AutoTrajectoryContext():
65+
with AutoTrajectoryContext() as trajectory:
6666
await coroutine
67-
trajectory = auto_trajectory_context_var.get().trajectory
68-
trajectory.finish()
6967
return trajectory
7068

7169

@@ -76,11 +74,13 @@ def __init__(self) -> None:
7674
reward=0.0,
7775
)
7876

79-
def __enter__(self) -> None:
77+
def __enter__(self) -> Trajectory:
8078
self.token = auto_trajectory_context_var.set(self)
79+
return self.trajectory
8180

8281
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
8382
auto_trajectory_context_var.reset(self.token)
83+
self.trajectory.finish()
8484

8585
def handle_httpx_response(self, response: httpx._models.Response) -> None:
8686
# Get buffered content (set by patched aiter_bytes/iter_bytes)

src/art/trajectories.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,21 @@
11
import asyncio
2-
from contextlib import asynccontextmanager
3-
from datetime import datetime
42
import time
53
import traceback
4+
from contextlib import asynccontextmanager
5+
from datetime import datetime
66
from typing import (
77
Any,
88
AsyncGenerator,
99
Awaitable,
10+
Coroutine,
1011
Iterable,
1112
Iterator,
1213
cast,
1314
overload,
1415
)
1516

16-
from openai.types.chat.chat_completion import Choice
1717
import pydantic
18+
from openai.types.chat.chat_completion import Choice
1819

1920
from .types import Messages, MessagesAndChoices, Tools
2021

@@ -262,7 +263,7 @@ def __new__(
262263
metadata: dict[str, MetadataValue] | None = None,
263264
metrics: dict[str, float | int | bool] | None = None,
264265
logs: list[str] | None = None,
265-
) -> "TrajectoryGroup | Awaitable[TrajectoryGroup]":
266+
) -> "TrajectoryGroup | Coroutine[Any, Any, TrajectoryGroup]":
266267
ts = list(trajectories)
267268
if any(hasattr(t, "__await__") for t in ts):
268269

0 commit comments

Comments
 (0)