77 Any ,
88 AsyncGenerator ,
99 Awaitable ,
10- Coroutine ,
1110 Iterable ,
1211 Iterator ,
1312 cast ,
1716from openai .types .chat .chat_completion import Choice
1817import pydantic
1918
20- from .types import Messages , MessagesAndChoices , Tools
19+ from .types import Message , Messages , MessagesAndChoices , Tools
2120
2221MetadataValue = float | int | str | bool | None
2322
@@ -37,22 +36,17 @@ def messages(self) -> Messages:
3736
3837
3938class Trajectory (pydantic .BaseModel ):
40- messages_and_choices : MessagesAndChoices
39+ messages_and_choices : MessagesAndChoices = []
4140 tools : Tools | None = None
4241 additional_histories : list [History ] = []
4342 reward : float = 0.0
4443 initial_policy_version : int | None = None
4544 final_policy_version : int | None = None
4645 metrics : dict [str , float | int | bool ] = {}
47- auto_metrics : dict [str , float | int | bool ] = {}
4846 metadata : dict [str , MetadataValue ] = {}
4947 logs : list [str ] = []
5048 start_time : datetime = pydantic .Field (default_factory = datetime .now , exclude = True )
5149
52- def __init__ (self , ** data : Any ):
53- super ().__init__ (** data )
54- self .start_time = datetime .now ()
55-
5650 def log (self , message : str ) -> None :
5751 self .logs .append (message )
5852
@@ -79,7 +73,7 @@ def messages(self) -> Messages:
7973
8074 # Used for logging to console
8175 def for_logging (self ) -> dict [str , Any ]:
82- loggable_dict = {
76+ loggable_dict : dict [ str , Any ] = {
8377 "reward" : self .reward ,
8478 "initial_policy_version" : self .initial_policy_version ,
8579 "final_policy_version" : self .final_policy_version ,
@@ -90,11 +84,13 @@ def for_logging(self) -> dict[str, Any]:
9084 "logs" : self .logs ,
9185 }
9286 for message_or_choice in self .messages_and_choices :
93- trainable = isinstance (message_or_choice , Choice )
94- message = (
95- message_or_choice .message .to_dict () if trainable else message_or_choice # ty:ignore[possibly-missing-attribute]
96- )
97- loggable_dict ["messages" ].append ({** message , "trainable" : trainable }) # ty:ignore[invalid-argument-type, possibly-missing-attribute]
87+ if isinstance (message_or_choice , Choice ):
88+ trainable = True
89+ message : dict [str , Any ] = message_or_choice .message .to_dict ()
90+ else :
91+ trainable = False
92+ message = cast (dict [str , Any ], message_or_choice )
93+ loggable_dict ["messages" ].append ({** message , "trainable" : trainable })
9894 return loggable_dict
9995
10096
@@ -104,7 +100,8 @@ def get_messages(messages_and_choices: MessagesAndChoices) -> Messages:
104100 if isinstance (message_or_choice , Choice ):
105101 content = message_or_choice .message .content or ""
106102 tool_calls = message_or_choice .message .tool_calls or []
107- messages .append (
103+ assistant_message : Message = cast (
104+ Message ,
108105 {
109106 "role" : "assistant" ,
110107 "content" : content ,
@@ -118,8 +115,9 @@ def get_messages(messages_and_choices: MessagesAndChoices) -> Messages:
118115 if tool_calls
119116 else {}
120117 ),
121- }
118+ },
122119 )
120+ messages .append (assistant_message )
123121 else :
124122 # Ensure content is always a string for tokenizer chat templates
125123 msg = dict (message_or_choice )
@@ -251,7 +249,7 @@ def __new__(
251249 metadata : dict [str , MetadataValue ] | None = None ,
252250 metrics : dict [str , float | int | bool ] | None = None ,
253251 logs : list [str ] | None = None ,
254- ) -> Coroutine [ Any , Any , "TrajectoryGroup" ]: ...
252+ ) -> Awaitable [ "TrajectoryGroup" ]: ...
255253
256254 def __new__ (
257255 cls ,
@@ -263,7 +261,7 @@ def __new__(
263261 metadata : dict [str , MetadataValue ] | None = None ,
264262 metrics : dict [str , float | int | bool ] | None = None ,
265263 logs : list [str ] | None = None ,
266- ) -> "TrajectoryGroup | Coroutine[Any, Any, TrajectoryGroup]" :
264+ ) -> "TrajectoryGroup | Awaitable[ TrajectoryGroup]" :
267265 ts = list (trajectories )
268266 if any (hasattr (t , "__await__" ) for t in ts ):
269267
0 commit comments