@@ -131,6 +131,9 @@ def get_messages(messages_and_choices: MessagesAndChoices) -> Messages:
131131class TrajectoryGroup (pydantic .BaseModel ):
132132 trajectories : list [Trajectory ]
133133 exceptions : list [PydanticException ] = []
134+ metadata : dict [str , MetadataValue ] = {}
135+ metrics : dict [str , float | int | bool ] = {}
136+ logs : list [str ] = []
134137
135138 def __init__ (
136139 self ,
@@ -139,6 +142,9 @@ def __init__(
139142 ),
140143 * ,
141144 exceptions : list [BaseException ] = [],
145+ metadata : dict [str , MetadataValue ] | None = None ,
146+ metrics : dict [str , float | int | bool ] | None = None ,
147+ logs : list [str ] | None = None ,
142148 ) -> None :
143149 super ().__init__ (
144150 trajectories = [
@@ -166,6 +172,11 @@ def __init__(
166172 + exceptions
167173 )
168174 ],
175+ metadata = metadata
176+ if metadata is not None
177+ else getattr (self , "metadata" , {}),
178+ metrics = metrics if metrics is not None else getattr (self , "metrics" , {}),
179+ logs = logs if logs is not None else getattr (self , "logs" , []),
169180 )
170181
171182 def __copy__ (self ):
@@ -176,6 +187,9 @@ def __copy__(self):
176187 new_instance = self .__class__ (
177188 trajectories = self .trajectories [:], # Shallow copy of list
178189 exceptions = [], # Will be set below
190+ metadata = self .metadata .copy (),
191+ metrics = self .metrics .copy (),
192+ logs = self .logs [:],
179193 )
180194 # Manually copy exceptions since they're PydanticException objects
181195 new_instance .exceptions = self .exceptions [:]
@@ -197,13 +211,19 @@ def __deepcopy__(self, memo: dict[int, Any] | None = None):
197211 new_instance = self .__class__ (
198212 trajectories = copy .deepcopy (self .trajectories , memo ),
199213 exceptions = [], # Will be set below
214+ metadata = copy .deepcopy (self .metadata , memo ),
215+ metrics = copy .deepcopy (self .metrics , memo ),
216+ logs = copy .deepcopy (self .logs , memo ),
200217 )
201218 # Register in memo before deep copying attributes to handle circular refs
202219 memo [id (self )] = new_instance
203220 # Deep copy exceptions
204221 new_instance .exceptions = copy .deepcopy (self .exceptions , memo )
205222 return new_instance
206223
224+ def log (self , message : str ) -> None :
225+ self .logs .append (message )
226+
207227 def __iter__ (self ) -> Iterator [Trajectory ]: # type: ignore[override]
208228 return iter (self .trajectories )
209229
@@ -216,6 +236,9 @@ def __new__(
216236 trajectories : Iterable [Trajectory | BaseException ],
217237 * ,
218238 exceptions : list [BaseException ] = [],
239+ metadata : dict [str , MetadataValue ] | None = None ,
240+ metrics : dict [str , float | int | bool ] | None = None ,
241+ logs : list [str ] | None = None ,
219242 ) -> "TrajectoryGroup" : ...
220243
221244 @overload
@@ -224,6 +247,9 @@ def __new__(
224247 trajectories : Iterable [Awaitable [Trajectory ]],
225248 * ,
226249 exceptions : list [BaseException ] = [],
250+ metadata : dict [str , MetadataValue ] | None = None ,
251+ metrics : dict [str , float | int | bool ] | None = None ,
252+ logs : list [str ] | None = None ,
227253 ) -> Awaitable ["TrajectoryGroup" ]: ...
228254
229255 def __new__ (
@@ -233,11 +259,19 @@ def __new__(
233259 ),
234260 * ,
235261 exceptions : list [BaseException ] = [],
262+ metadata : dict [str , MetadataValue ] | None = None ,
263+ metrics : dict [str , float | int | bool ] | None = None ,
264+ logs : list [str ] | None = None ,
236265 ) -> "TrajectoryGroup | Awaitable[TrajectoryGroup]" :
237266 ts = list (trajectories )
238267 if any (hasattr (t , "__await__" ) for t in ts ):
239268
240- async def _ (exceptions : list [BaseException ]):
269+ async def _ (
270+ exceptions : list [BaseException ],
271+ metadata : dict [str , MetadataValue ] | None ,
272+ metrics : dict [str , float | int | bool ] | None ,
273+ logs : list [str ] | None ,
274+ ):
241275 from .gather import get_gather_context , record_metrics
242276
243277 context = get_gather_context ()
@@ -259,6 +293,9 @@ async def _(exceptions: list[BaseException]):
259293 return TrajectoryGroup (
260294 trajectories = trajectories ,
261295 exceptions = exceptions ,
296+ metadata = metadata ,
297+ metrics = metrics ,
298+ logs = logs ,
262299 )
263300
264301 class CoroutineWithMetadata :
@@ -269,12 +306,15 @@ def __init__(self, coro, num_trajectories):
269306 def __await__ (self ):
270307 return self .coro .__await__ ()
271308
272- coro = _ (exceptions .copy ())
309+ coro = _ (exceptions .copy (), metadata , metrics , logs )
273310 return CoroutineWithMetadata (coro , len (ts ))
274311 else :
275312 group = super ().__new__ (cls )
276313 group .__init__ (
277314 trajectories = cast (list [Trajectory | BaseException ], ts ),
278315 exceptions = exceptions ,
316+ metadata = metadata ,
317+ metrics = metrics ,
318+ logs = logs ,
279319 )
280320 return group
0 commit comments