Skip to content

Commit ca77e97

Browse files
authored
feat: Add W&B run config API (#615)
1 parent 4888500 commit ca77e97

File tree

2 files changed

+159
-0
lines changed

2 files changed

+159
-0
lines changed

src/art/model.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,7 @@ class Model(
114114
_openai_client: AsyncOpenAI | None = None
115115
_wandb_run: Optional["Run"] = None # Private, for lazy wandb initialization
116116
_wandb_defined_metrics: set[str]
117+
_wandb_config: dict[str, Any]
117118
_run_start_time: float
118119
_run_start_monotonic: float
119120
_last_local_train_log_monotonic: float
@@ -150,6 +151,7 @@ def __init__(
150151
**kwargs,
151152
)
152153
object.__setattr__(self, "_wandb_defined_metrics", set())
154+
object.__setattr__(self, "_wandb_config", {})
153155
object.__setattr__(self, "_run_start_time", time.time())
154156
object.__setattr__(self, "_run_start_monotonic", time.monotonic())
155157
object.__setattr__(
@@ -371,6 +373,34 @@ def _deep_merge_dicts(
371373
merged[key] = value
372374
return merged
373375

376+
@staticmethod
377+
def _merge_wandb_config(
378+
existing: dict[str, Any],
379+
updates: dict[str, Any],
380+
*,
381+
path: str = "",
382+
) -> dict[str, Any]:
383+
merged = dict(existing)
384+
for key, value in updates.items():
385+
key_path = f"{path}.{key}" if path else key
386+
if key not in merged:
387+
merged[key] = value
388+
continue
389+
existing_value = merged[key]
390+
if isinstance(existing_value, dict) and isinstance(value, dict):
391+
merged[key] = Model._merge_wandb_config(
392+
existing_value,
393+
value,
394+
path=key_path,
395+
)
396+
continue
397+
if existing_value != value:
398+
raise ValueError(
399+
"W&B config is immutable once set. "
400+
f"Conflicting value for '{key_path}'."
401+
)
402+
return merged
403+
374404
def read_state(self) -> StateType | None:
375405
"""Read persistent state from the model directory.
376406
@@ -390,6 +420,43 @@ def read_state(self) -> StateType | None:
390420
with open(state_path, "r") as f:
391421
return json.load(f)
392422

423+
def update_wandb_config(
424+
self,
425+
config: dict[str, Any],
426+
) -> None:
427+
"""Merge configuration into the W&B run config for this model.
428+
429+
This can be called before the W&B run exists, in which case the config is
430+
passed to `wandb.init(...)` when ART first creates the run. If the run is
431+
already active, ART updates the run config immediately.
432+
433+
Args:
434+
config: JSON-serializable configuration to store on the W&B run.
435+
"""
436+
if not isinstance(config, dict):
437+
raise TypeError("config must be a dict[str, Any]")
438+
439+
merged = self._merge_wandb_config(self._wandb_config, config)
440+
object.__setattr__(self, "_wandb_config", merged)
441+
442+
if self._wandb_run is not None and not self._wandb_run._is_finished:
443+
self._sync_wandb_config(self._wandb_run)
444+
445+
def _sync_wandb_config(
446+
self,
447+
run: "Run",
448+
) -> None:
449+
if not self._wandb_config:
450+
return
451+
452+
run_config = getattr(run, "config", None)
453+
if run_config is None or not hasattr(run_config, "update"):
454+
return
455+
456+
run_config.update(
457+
self._wandb_config,
458+
)
459+
393460
def _get_wandb_run(self) -> Optional["Run"]:
394461
"""Get or create the wandb run for this model."""
395462
import wandb
@@ -401,6 +468,7 @@ def _get_wandb_run(self) -> Optional["Run"]:
401468
project=self.project,
402469
name=self.name,
403470
id=self.name,
471+
config=self._wandb_config or None,
404472
resume="allow",
405473
settings=wandb.Settings(
406474
x_stats_open_metrics_endpoints={
@@ -436,6 +504,7 @@ def _get_wandb_run(self) -> Optional["Run"]:
436504
wandb.define_metric("val/*", step_metric="training_step")
437505
wandb.define_metric("test/*", step_metric="training_step")
438506
wandb.define_metric("discarded/*", step_metric="training_step")
507+
self._sync_wandb_config(run)
439508
return self._wandb_run
440509

441510
def _log_metrics(

tests/unit/test_metric_routing.py

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
import types
55
from unittest.mock import MagicMock, patch
66

7+
import pytest
8+
79
from art import Model
810

911

@@ -80,6 +82,7 @@ def test_log_metrics_defines_nested_cost_keys_with_training_step(
8082
) -> None:
8183
fake_run = MagicMock()
8284
fake_run._is_finished = False
85+
fake_run.config = MagicMock()
8386

8487
fake_wandb = types.SimpleNamespace()
8588
fake_wandb.init = MagicMock(return_value=fake_run)
@@ -121,3 +124,90 @@ def test_log_metrics_defines_nested_cost_keys_with_training_step(
121124
assert logged_metrics["training_step"] == 1
122125
assert "time/wall_clock_sec" in logged_metrics
123126
assert fake_run.log.call_args.kwargs == {}
127+
128+
def test_update_wandb_config_seeds_wandb_init(self, tmp_path: Path) -> None:
129+
fake_run = MagicMock()
130+
fake_run._is_finished = False
131+
fake_run.config = MagicMock()
132+
133+
fake_wandb = types.SimpleNamespace()
134+
fake_wandb.init = MagicMock(return_value=fake_run)
135+
fake_wandb.define_metric = MagicMock()
136+
fake_wandb.Settings = lambda **kwargs: kwargs
137+
138+
payload = {
139+
"experiment": {"learning_rate": 1e-5, "batch_size": 4},
140+
"dataset": {"split": "train"},
141+
}
142+
143+
with patch.dict(os.environ, {"WANDB_API_KEY": "test-key"}, clear=False):
144+
with patch.dict("sys.modules", {"wandb": fake_wandb}):
145+
model = Model(
146+
name="test-model",
147+
project="test-project",
148+
base_path=str(tmp_path),
149+
)
150+
model.update_wandb_config(payload)
151+
run = model._get_wandb_run()
152+
153+
assert run is fake_run
154+
init_kwargs = fake_wandb.init.call_args.kwargs
155+
assert init_kwargs["config"] == payload
156+
assert "allow_val_change" not in init_kwargs
157+
fake_run.config.update.assert_called_once_with(payload)
158+
159+
def test_update_wandb_config_updates_active_run(self, tmp_path: Path) -> None:
160+
fake_run = MagicMock()
161+
fake_run._is_finished = False
162+
fake_run.config = MagicMock()
163+
164+
fake_wandb = types.SimpleNamespace()
165+
fake_wandb.init = MagicMock(return_value=fake_run)
166+
fake_wandb.define_metric = MagicMock()
167+
fake_wandb.Settings = lambda **kwargs: kwargs
168+
169+
with patch.dict(os.environ, {"WANDB_API_KEY": "test-key"}, clear=False):
170+
with patch.dict("sys.modules", {"wandb": fake_wandb}):
171+
model = Model(
172+
name="test-model",
173+
project="test-project",
174+
base_path=str(tmp_path),
175+
)
176+
model.update_wandb_config({"experiment": {"learning_rate": 1e-5}})
177+
_ = model._get_wandb_run()
178+
fake_run.config.update.reset_mock()
179+
180+
model.update_wandb_config(
181+
{"experiment": {"learning_rate": 1e-5, "batch_size": 8}},
182+
)
183+
model.update_wandb_config(
184+
{"dataset": {"split": "train"}},
185+
)
186+
187+
assert fake_run.config.update.call_count == 2
188+
assert fake_run.config.update.call_args_list[0].args == (
189+
{"experiment": {"learning_rate": 1e-5, "batch_size": 8}},
190+
)
191+
assert fake_run.config.update.call_args_list[1].args == (
192+
{
193+
"experiment": {"learning_rate": 1e-5, "batch_size": 8},
194+
"dataset": {"split": "train"},
195+
},
196+
)
197+
198+
def test_update_wandb_config_rejects_conflicting_values(
199+
self, tmp_path: Path
200+
) -> None:
201+
model = Model(
202+
name="test-model",
203+
project="test-project",
204+
base_path=str(tmp_path),
205+
)
206+
207+
model.update_wandb_config({"experiment": {"learning_rate": 1e-5}})
208+
209+
with pytest.raises(
210+
ValueError,
211+
match="Conflicting value for 'experiment.learning_rate'",
212+
):
213+
model.update_wandb_config({"experiment": {"learning_rate": 2e-5}})

0 commit comments

Comments
 (0)