Skip to content

Commit 31e4fd9

Browse files
authored
feat: group-level trajectory metadata (#539)
* feat: add group-level trajectory metadata — persist group-level metadata/metrics/logs in parquet and expose them in loaders and aggregate metrics so history/wandb can report group-level stats.
* fix: align image processor typing
* chore: drop local backend import workaround
1 parent fc4ffb7 commit 31e4fd9

8 files changed

Lines changed: 272 additions & 7 deletions

File tree

src/art/model.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -467,8 +467,14 @@ async def log(
467467

468468
# 2. Calculate aggregate metrics
469469
all_metrics: dict[str, list[float]] = {"reward": [], "exception_rate": []}
470+
group_metrics: dict[str, list[float]] = {}
470471

471472
for group in trajectory_groups:
473+
if group.trajectories:
474+
for metric, value in group.metrics.items():
475+
if metric not in group_metrics:
476+
group_metrics[metric] = []
477+
group_metrics[metric].append(float(value))
472478
for trajectory in group:
473479
if isinstance(trajectory, BaseException):
474480
all_metrics["exception_rate"].append(1)
@@ -490,6 +496,11 @@ async def log(
490496
if len(values) > 0:
491497
averages[metric] = sum(values) / len(values)
492498

499+
# Aggregate group-level metrics once per group
500+
for metric, values in group_metrics.items():
501+
if len(values) > 0:
502+
averages[f"group_metric_{metric}"] = sum(values) / len(values)
503+
493504
# Calculate average standard deviation of rewards within groups
494505
averages["reward_std_dev"] = calculate_step_std_dev(trajectory_groups)
495506

src/art/pipeline_trainer/trainer.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -610,8 +610,7 @@ def _apply_scenario_metadata(
610610
continue
611611
if not self._is_scalar_metadata(value):
612612
continue
613-
for trajectory in group.trajectories:
614-
trajectory.metadata[f"scenario_{key}"] = value
613+
group.metadata[f"scenario_{key}"] = value
615614

616615
def _is_group_stale(self, group: TrajectoryGroup, min_version: int) -> bool:
617616
group_version = self._group_initial_version(group)

src/art/trajectories.py

Lines changed: 42 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,9 @@ def get_messages(messages_and_choices: MessagesAndChoices) -> Messages:
131131
class TrajectoryGroup(pydantic.BaseModel):
132132
trajectories: list[Trajectory]
133133
exceptions: list[PydanticException] = []
134+
metadata: dict[str, MetadataValue] = {}
135+
metrics: dict[str, float | int | bool] = {}
136+
logs: list[str] = []
134137

135138
def __init__(
136139
self,
@@ -139,6 +142,9 @@ def __init__(
139142
),
140143
*,
141144
exceptions: list[BaseException] = [],
145+
metadata: dict[str, MetadataValue] | None = None,
146+
metrics: dict[str, float | int | bool] | None = None,
147+
logs: list[str] | None = None,
142148
) -> None:
143149
super().__init__(
144150
trajectories=[
@@ -166,6 +172,11 @@ def __init__(
166172
+ exceptions
167173
)
168174
],
175+
metadata=metadata
176+
if metadata is not None
177+
else getattr(self, "metadata", {}),
178+
metrics=metrics if metrics is not None else getattr(self, "metrics", {}),
179+
logs=logs if logs is not None else getattr(self, "logs", []),
169180
)
170181

171182
def __copy__(self):
@@ -176,6 +187,9 @@ def __copy__(self):
176187
new_instance = self.__class__(
177188
trajectories=self.trajectories[:], # Shallow copy of list
178189
exceptions=[], # Will be set below
190+
metadata=self.metadata.copy(),
191+
metrics=self.metrics.copy(),
192+
logs=self.logs[:],
179193
)
180194
# Manually copy exceptions since they're PydanticException objects
181195
new_instance.exceptions = self.exceptions[:]
@@ -197,13 +211,19 @@ def __deepcopy__(self, memo: dict[int, Any] | None = None):
197211
new_instance = self.__class__(
198212
trajectories=copy.deepcopy(self.trajectories, memo),
199213
exceptions=[], # Will be set below
214+
metadata=copy.deepcopy(self.metadata, memo),
215+
metrics=copy.deepcopy(self.metrics, memo),
216+
logs=copy.deepcopy(self.logs, memo),
200217
)
201218
# Register in memo before deep copying attributes to handle circular refs
202219
memo[id(self)] = new_instance
203220
# Deep copy exceptions
204221
new_instance.exceptions = copy.deepcopy(self.exceptions, memo)
205222
return new_instance
206223

224+
def log(self, message: str) -> None:
225+
self.logs.append(message)
226+
207227
def __iter__(self) -> Iterator[Trajectory]: # type: ignore[override]
208228
return iter(self.trajectories)
209229

@@ -216,6 +236,9 @@ def __new__(
216236
trajectories: Iterable[Trajectory | BaseException],
217237
*,
218238
exceptions: list[BaseException] = [],
239+
metadata: dict[str, MetadataValue] | None = None,
240+
metrics: dict[str, float | int | bool] | None = None,
241+
logs: list[str] | None = None,
219242
) -> "TrajectoryGroup": ...
220243

221244
@overload
@@ -224,6 +247,9 @@ def __new__(
224247
trajectories: Iterable[Awaitable[Trajectory]],
225248
*,
226249
exceptions: list[BaseException] = [],
250+
metadata: dict[str, MetadataValue] | None = None,
251+
metrics: dict[str, float | int | bool] | None = None,
252+
logs: list[str] | None = None,
227253
) -> Awaitable["TrajectoryGroup"]: ...
228254

229255
def __new__(
@@ -233,11 +259,19 @@ def __new__(
233259
),
234260
*,
235261
exceptions: list[BaseException] = [],
262+
metadata: dict[str, MetadataValue] | None = None,
263+
metrics: dict[str, float | int | bool] | None = None,
264+
logs: list[str] | None = None,
236265
) -> "TrajectoryGroup | Awaitable[TrajectoryGroup]":
237266
ts = list(trajectories)
238267
if any(hasattr(t, "__await__") for t in ts):
239268

240-
async def _(exceptions: list[BaseException]):
269+
async def _(
270+
exceptions: list[BaseException],
271+
metadata: dict[str, MetadataValue] | None,
272+
metrics: dict[str, float | int | bool] | None,
273+
logs: list[str] | None,
274+
):
241275
from .gather import get_gather_context, record_metrics
242276

243277
context = get_gather_context()
@@ -259,6 +293,9 @@ async def _(exceptions: list[BaseException]):
259293
return TrajectoryGroup(
260294
trajectories=trajectories,
261295
exceptions=exceptions,
296+
metadata=metadata,
297+
metrics=metrics,
298+
logs=logs,
262299
)
263300

264301
class CoroutineWithMetadata:
@@ -269,12 +306,15 @@ def __init__(self, coro, num_trajectories):
269306
def __await__(self):
270307
return self.coro.__await__()
271308

272-
coro = _(exceptions.copy())
309+
coro = _(exceptions.copy(), metadata, metrics, logs)
273310
return CoroutineWithMetadata(coro, len(ts))
274311
else:
275312
group = super().__new__(cls)
276313
group.__init__(
277314
trajectories=cast(list[Trajectory | BaseException], ts),
278315
exceptions=exceptions,
316+
metadata=metadata,
317+
metrics=metrics,
318+
logs=logs,
279319
)
280320
return group

src/art/utils/benchmarking/load_trajectories.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,10 @@ async def load_trajectories(
6969
One column for every distinct metric key found in the dataset.
7070
metadata_* : str
7171
One column for every distinct metadata key.
72+
group_metric_* : float
73+
One column for every distinct group-level metric key.
74+
group_metadata_* : str
75+
One column for every distinct group-level metadata key.
7276
7377
Parameters
7478
----------
@@ -144,6 +148,8 @@ async def load_trajectories(
144148
rows: list[dict] = []
145149
metric_cols: set[str] = set()
146150
metadata_cols: set[str] = set()
151+
group_metric_cols: set[str] = set()
152+
group_metadata_cols: set[str] = set()
147153
# Map (model, split, step, group_index) -> unique group_number
148154
group_key_to_number: dict[tuple[str, str, int, int], int] = {}
149155
next_group_number = 1
@@ -195,11 +201,35 @@ async def load_trajectories(
195201
except (json.JSONDecodeError, TypeError):
196202
pass
197203

204+
# Parse group metrics from JSON (duplicated across group rows)
205+
group_metrics = {}
206+
if row_dict.get("group_metrics"):
207+
try:
208+
group_metrics = json.loads(row_dict["group_metrics"])
209+
except (json.JSONDecodeError, TypeError):
210+
pass
211+
212+
# Parse group metadata from JSON (duplicated across group rows)
213+
group_metadata = {}
214+
if row_dict.get("group_metadata"):
215+
try:
216+
group_metadata = json.loads(row_dict["group_metadata"])
217+
except (json.JSONDecodeError, TypeError):
218+
pass
219+
198220
# Prepare metrics and metadata columns
199221
prepped_metrics = {f"metric_{k}": v for k, v in metrics.items()}
200222
prepped_metadata = {f"metadata_{k}": str(v) for k, v in metadata.items()}
223+
prepped_group_metrics = {
224+
f"group_metric_{k}": v for k, v in group_metrics.items()
225+
}
226+
prepped_group_metadata = {
227+
f"group_metadata_{k}": str(v) for k, v in group_metadata.items()
228+
}
201229
metric_cols.update(prepped_metrics.keys())
202230
metadata_cols.update(prepped_metadata.keys())
231+
group_metric_cols.update(prepped_group_metrics.keys())
232+
group_metadata_cols.update(prepped_group_metadata.keys())
203233

204234
# Process messages
205235
messages = []
@@ -255,6 +285,8 @@ async def load_trajectories(
255285
"logs": row_dict.get("logs"),
256286
**prepped_metrics,
257287
**prepped_metadata,
288+
**prepped_group_metrics,
289+
**prepped_group_metadata,
258290
}
259291

260292
rows.append(row_data)
@@ -295,6 +327,8 @@ async def load_trajectories(
295327
}
296328
| {k: pl.Float64 for k in metric_cols}
297329
| {k: pl.Utf8 for k in metadata_cols}
330+
| {k: pl.Float64 for k in group_metric_cols}
331+
| {k: pl.Utf8 for k in group_metadata_cols}
298332
)
299333

300334
return pl.DataFrame(rows, schema=schema)

src/art/utils/trajectory_logging.py

Lines changed: 52 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
import json
1111
from pathlib import Path
12-
from typing import Any
12+
from typing import Any, cast
1313

1414
from litellm.types.utils import Choices
1515
from openai.types.chat.chat_completion import Choice
@@ -72,6 +72,9 @@ def write_trajectory_groups_parquet(
7272

7373
rows = []
7474
for group_index, group in enumerate(trajectory_groups):
75+
group_metadata = json.dumps(group.metadata) if group.metadata else None
76+
group_metrics = json.dumps(group.metrics) if group.metrics else None
77+
group_logs = group.logs if group.logs else None
7578
for trajectory in group.trajectories:
7679
if not isinstance(trajectory, Trajectory):
7780
continue
@@ -96,6 +99,9 @@ def write_trajectory_groups_parquet(
9699
rows.append(
97100
{
98101
"group_index": group_index,
102+
"group_metadata": group_metadata,
103+
"group_metrics": group_metrics,
104+
"group_logs": group_logs,
99105
"reward": trajectory.reward,
100106
"metrics": json.dumps(trajectory.metrics)
101107
if trajectory.metrics
@@ -123,6 +129,9 @@ def write_trajectory_groups_parquet(
123129
schema = pa.schema(
124130
[
125131
("group_index", pa.int64()),
132+
("group_metadata", pa.string()),
133+
("group_metrics", pa.string()),
134+
("group_logs", pa.list_(pa.string())),
126135
("reward", pa.float64()),
127136
("metrics", pa.string()),
128137
("metadata", pa.string()),
@@ -158,6 +167,23 @@ def read_trajectory_groups_parquet(path: str | Path) -> list[TrajectoryGroup]:
158167
columns = [desc[0] for desc in con.description]
159168

160169
groups_dict: dict[int, list[Trajectory]] = {}
170+
group_metadata_by_index: dict[int, dict[str, Any]] = {}
171+
group_metrics_by_index: dict[int, dict[str, Any]] = {}
172+
group_logs_by_index: dict[int, list[str]] = {}
173+
174+
def _load_json_payload(payload: object | None) -> dict[str, Any]:
175+
if payload is None:
176+
return {}
177+
if isinstance(payload, dict):
178+
return cast(dict[str, Any], payload)
179+
if isinstance(payload, (str, bytes, bytearray)):
180+
if not payload:
181+
return {}
182+
try:
183+
return json.loads(payload)
184+
except (json.JSONDecodeError, TypeError):
185+
return {}
186+
return {}
161187

162188
for row in rows:
163189
row_dict = dict(zip(columns, row))
@@ -166,6 +192,24 @@ def read_trajectory_groups_parquet(path: str | Path) -> list[TrajectoryGroup]:
166192
continue
167193

168194
group_index = row_dict.get("group_index", 0)
195+
if group_index not in group_metadata_by_index:
196+
group_metadata_by_index[group_index] = _load_json_payload(
197+
row_dict.get("group_metadata")
198+
)
199+
if group_index not in group_metrics_by_index:
200+
group_metrics_by_index[group_index] = _load_json_payload(
201+
row_dict.get("group_metrics")
202+
)
203+
if group_index not in group_logs_by_index:
204+
raw_group_logs = row_dict.get("group_logs")
205+
if isinstance(raw_group_logs, (list, tuple)):
206+
group_logs_by_index[group_index] = [
207+
str(item) for item in raw_group_logs
208+
]
209+
elif raw_group_logs is None:
210+
group_logs_by_index[group_index] = []
211+
else:
212+
group_logs_by_index[group_index] = [str(raw_group_logs)]
169213

170214
# Convert messages
171215
messages_and_choices = []
@@ -196,6 +240,12 @@ def read_trajectory_groups_parquet(path: str | Path) -> list[TrajectoryGroup]:
196240
groups_dict[group_index].append(trajectory)
197241

198242
return [
199-
TrajectoryGroup(trajectories=groups_dict[idx], exceptions=[])
243+
TrajectoryGroup(
244+
trajectories=groups_dict[idx],
245+
exceptions=[],
246+
metadata=group_metadata_by_index.get(idx, {}),
247+
metrics=group_metrics_by_index.get(idx, {}),
248+
logs=group_logs_by_index.get(idx, []),
249+
)
200250
for idx in sorted(groups_dict.keys())
201251
]
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
import pytest
2+
3+
from art import Trajectory, TrajectoryGroup
4+
from art.utils.benchmarking.load_trajectories import load_trajectories
5+
from art.utils.trajectory_logging import write_trajectory_groups_parquet
6+
7+
8+
@pytest.mark.asyncio
9+
async def test_load_trajectories_group_columns(tmp_path):
10+
project_name = "proj"
11+
model_name = "model"
12+
traj_dir = tmp_path / project_name / "models" / model_name / "trajectories" / "val"
13+
traj_dir.mkdir(parents=True)
14+
15+
groups = [
16+
TrajectoryGroup(
17+
trajectories=[
18+
Trajectory(
19+
reward=1.0,
20+
messages_and_choices=[{"role": "user", "content": "hi"}],
21+
)
22+
],
23+
metadata={"scenario_id": "abc"},
24+
metrics={"judge_score": 0.9},
25+
logs=["group log"],
26+
exceptions=[],
27+
)
28+
]
29+
write_trajectory_groups_parquet(groups, traj_dir / "0000.parquet")
30+
31+
df = await load_trajectories(
32+
project_name=project_name,
33+
models=[model_name],
34+
art_path=str(tmp_path),
35+
)
36+
37+
assert "group_metric_judge_score" in df.columns
38+
assert "group_metadata_scenario_id" in df.columns
39+
assert df["group_metric_judge_score"][0] == 0.9
40+
assert df["group_metadata_scenario_id"][0] == "abc"

0 commit comments

Comments
 (0)