Skip to content

Commit 4c08d24

Browse files
committed
refactor: Simplify Trajectory initialization and enhance type handling
- Updated the `Trajectory` class to initialize `messages_and_choices` with an empty list by default, streamlining the constructor.
- Removed the unnecessary constructor in `Trajectory` to leverage Pydantic's default behavior.
- Improved type handling in the `get_messages` function for better clarity and type safety.
- Adjusted the `Models` class to use `model_dump(mode="json")` for trajectory groups.
- Fixed type hints in unit tests for better compatibility with type checkers.
1 parent 73c8cdd commit 4c08d24

4 files changed

Lines changed: 20 additions & 25 deletions

File tree

src/art/auto_trajectory.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -69,10 +69,7 @@ async def capture_auto_trajectory(coroutine: Coroutine[Any, Any, Any]) -> Trajec
6969

7070
class AutoTrajectoryContext:
7171
def __init__(self) -> None:
72-
self.trajectory = Trajectory(
73-
messages_and_choices=[],
74-
reward=0.0,
75-
)
72+
self.trajectory = Trajectory()
7673

7774
def __enter__(self) -> Trajectory:
7875
self.token = auto_trajectory_context_var.set(self)

src/art/serverless/client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ async def log(
126126
body={
127127
"model_id": model_id,
128128
"trajectory_groups": [
129-
trajectory_group.model_dump()
129+
trajectory_group.model_dump(mode="json")
130130
for trajectory_group in trajectory_groups
131131
],
132132
"split": split,

src/art/trajectories.py

Lines changed: 16 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
Any,
88
AsyncGenerator,
99
Awaitable,
10-
Coroutine,
1110
Iterable,
1211
Iterator,
1312
cast,
@@ -17,7 +16,7 @@
1716
from openai.types.chat.chat_completion import Choice
1817
import pydantic
1918

20-
from .types import Messages, MessagesAndChoices, Tools
19+
from .types import Message, Messages, MessagesAndChoices, Tools
2120

2221
MetadataValue = float | int | str | bool | None
2322

@@ -37,22 +36,17 @@ def messages(self) -> Messages:
3736

3837

3938
class Trajectory(pydantic.BaseModel):
40-
messages_and_choices: MessagesAndChoices
39+
messages_and_choices: MessagesAndChoices = []
4140
tools: Tools | None = None
4241
additional_histories: list[History] = []
4342
reward: float = 0.0
4443
initial_policy_version: int | None = None
4544
final_policy_version: int | None = None
4645
metrics: dict[str, float | int | bool] = {}
47-
auto_metrics: dict[str, float | int | bool] = {}
4846
metadata: dict[str, MetadataValue] = {}
4947
logs: list[str] = []
5048
start_time: datetime = pydantic.Field(default_factory=datetime.now, exclude=True)
5149

52-
def __init__(self, **data: Any):
53-
super().__init__(**data)
54-
self.start_time = datetime.now()
55-
5650
def log(self, message: str) -> None:
5751
self.logs.append(message)
5852

@@ -79,7 +73,7 @@ def messages(self) -> Messages:
7973

8074
# Used for logging to console
8175
def for_logging(self) -> dict[str, Any]:
82-
loggable_dict = {
76+
loggable_dict: dict[str, Any] = {
8377
"reward": self.reward,
8478
"initial_policy_version": self.initial_policy_version,
8579
"final_policy_version": self.final_policy_version,
@@ -90,11 +84,13 @@ def for_logging(self) -> dict[str, Any]:
9084
"logs": self.logs,
9185
}
9286
for message_or_choice in self.messages_and_choices:
93-
trainable = isinstance(message_or_choice, Choice)
94-
message = (
95-
message_or_choice.message.to_dict() if trainable else message_or_choice # ty:ignore[possibly-missing-attribute]
96-
)
97-
loggable_dict["messages"].append({**message, "trainable": trainable}) # ty:ignore[invalid-argument-type, possibly-missing-attribute]
87+
if isinstance(message_or_choice, Choice):
88+
trainable = True
89+
message: dict[str, Any] = message_or_choice.message.to_dict()
90+
else:
91+
trainable = False
92+
message = cast(dict[str, Any], message_or_choice)
93+
loggable_dict["messages"].append({**message, "trainable": trainable})
9894
return loggable_dict
9995

10096

@@ -104,7 +100,8 @@ def get_messages(messages_and_choices: MessagesAndChoices) -> Messages:
104100
if isinstance(message_or_choice, Choice):
105101
content = message_or_choice.message.content or ""
106102
tool_calls = message_or_choice.message.tool_calls or []
107-
messages.append(
103+
assistant_message: Message = cast(
104+
Message,
108105
{
109106
"role": "assistant",
110107
"content": content,
@@ -118,8 +115,9 @@ def get_messages(messages_and_choices: MessagesAndChoices) -> Messages:
118115
if tool_calls
119116
else {}
120117
),
121-
}
118+
},
122119
)
120+
messages.append(assistant_message)
123121
else:
124122
# Ensure content is always a string for tokenizer chat templates
125123
msg = dict(message_or_choice)
@@ -251,7 +249,7 @@ def __new__(
251249
metadata: dict[str, MetadataValue] | None = None,
252250
metrics: dict[str, float | int | bool] | None = None,
253251
logs: list[str] | None = None,
254-
) -> Coroutine[Any, Any, "TrajectoryGroup"]: ...
252+
) -> Awaitable["TrajectoryGroup"]: ...
255253

256254
def __new__(
257255
cls,
@@ -263,7 +261,7 @@ def __new__(
263261
metadata: dict[str, MetadataValue] | None = None,
264262
metrics: dict[str, float | int | bool] | None = None,
265263
logs: list[str] | None = None,
266-
) -> "TrajectoryGroup | Coroutine[Any, Any, TrajectoryGroup]":
264+
) -> "TrajectoryGroup | Awaitable[TrajectoryGroup]":
267265
ts = list(trajectories)
268266
if any(hasattr(t, "__await__") for t in ts):
269267

tests/unit/test_trajectory_parquet.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -235,8 +235,8 @@ def test_choice_format(self, tmp_path: Path):
235235
"content": "Hi!",
236236
"tool_calls": None,
237237
},
238-
},
239-
],
238+
}, # type: ignore
239+
], # ty:ignore[invalid-argument-type]
240240
logs=[],
241241
)
242242
],

0 commit comments

Comments
 (0)