Skip to content

Commit 5a3a029

Browse files
committed
refactor: Validate TinkerNative KL source before state lookup
1 parent 6973479 commit 5a3a029

File tree

2 files changed

+9
-29
lines changed

2 files changed

+9
-29
lines changed

src/art/tinker_native/backend.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -314,6 +314,10 @@ async def train( # type: ignore[override]
314314
kl_penalty_reference_step: int | None = None,
315315
kl_penalty_source: Literal["sample"] = "sample",
316316
) -> TrainResult:
317+
assert kl_penalty_source == "sample", (
318+
"TinkerNativeBackend only supports kl_penalty_source='sample'."
319+
)
320+
317321
state = self._model_state[model.name]
318322
groups_list = list(trajectory_groups)
319323
summary = summarize_trajectory_groups(groups_list)
@@ -333,10 +337,6 @@ async def train( # type: ignore[override]
333337
"data/step_num_datums": float(len(datums)),
334338
}
335339

336-
assert kl_penalty_source == "sample", (
337-
"TinkerNativeBackend only supports kl_penalty_source='sample'."
338-
)
339-
340340
if not datums:
341341
return TrainResult(step=state.current_step, metrics=metrics)
342342

tests/unit/test_tinker_native_kl.py

Lines changed: 5 additions & 25 deletions
Original file line number | Diff line number | Diff line change
@@ -1,18 +1,12 @@
1-
from typing import Any, cast
2-
31
import pytest
42
import tinker
53

64
from art import TrainableModel
7-
from art.tinker_native.backend import (
8-
ModelState,
9-
TinkerNativeBackend,
10-
_apply_kl_penalty,
11-
)
5+
from art.tinker_native.backend import TinkerNativeBackend, _apply_kl_penalty
126
from art.tinker_native.data import build_datum
137

148

15-
class FakeSamplingClient:
9+
class FakeSamplingClient(tinker.SamplingClient):
1610
def __init__(self, responses: dict[tuple[int, ...], list[float | None]]) -> None:
1711
self._responses = responses
1812

@@ -48,7 +42,7 @@ async def test_incorporate_kl_penalty_rewrites_advantages_in_place() -> None:
4842

4943
metrics = await _apply_kl_penalty(
5044
[datum_a, datum_b],
51-
sampling_client, # type: ignore[arg-type]
45+
sampling_client,
5246
kl_penalty_coef=2.0,
5347
)
5448

@@ -70,28 +64,14 @@ async def test_tinker_native_backend_rejects_current_learner_kl_source(
7064
base_model="test-model",
7165
base_path=str(tmp_path),
7266
)
73-
backend._model_state[model.name] = ModelState(
74-
service_client=cast(Any, object()),
75-
rest_client=cast(Any, object()),
76-
training_client=cast(Any, object()),
77-
sampler_clients={},
78-
sampler_checkpoint_paths={},
79-
training_checkpoint_paths={},
80-
current_step=0,
81-
renderer=cast(Any, object()),
82-
tokenizer=cast(Any, object()),
83-
output_dir=str(tmp_path),
84-
tinker_run_ids=[],
85-
model_name=model.name,
86-
)
8767

8868
with pytest.raises(
8969
AssertionError,
9070
match="only supports kl_penalty_source='sample'",
9171
):
92-
await cast(Any, backend).train(
72+
await backend.train(
9373
model,
9474
[],
9575
kl_penalty_coef=0.25,
96-
kl_penalty_source="current_learner",
76+
kl_penalty_source="current_learner", # ty:ignore[invalid-argument-type]
9777
)

0 commit comments

Comments
 (0)