1- from typing import Any , cast
2-
31import pytest
42import tinker
53
64from art import TrainableModel
7- from art .tinker_native .backend import (
8- ModelState ,
9- TinkerNativeBackend ,
10- _apply_kl_penalty ,
11- )
5+ from art .tinker_native .backend import TinkerNativeBackend , _apply_kl_penalty
126from art .tinker_native .data import build_datum
137
148
15- class FakeSamplingClient :
9+ class FakeSamplingClient ( tinker . SamplingClient ) :
1610 def __init__ (self , responses : dict [tuple [int , ...], list [float | None ]]) -> None :
1711 self ._responses = responses
1812
@@ -48,7 +42,7 @@ async def test_incorporate_kl_penalty_rewrites_advantages_in_place() -> None:
4842
4943 metrics = await _apply_kl_penalty (
5044 [datum_a , datum_b ],
51- sampling_client , # type: ignore[arg-type]
45+ sampling_client ,
5246 kl_penalty_coef = 2.0 ,
5347 )
5448
@@ -70,28 +64,14 @@ async def test_tinker_native_backend_rejects_current_learner_kl_source(
7064 base_model = "test-model" ,
7165 base_path = str (tmp_path ),
7266 )
73- backend ._model_state [model .name ] = ModelState (
74- service_client = cast (Any , object ()),
75- rest_client = cast (Any , object ()),
76- training_client = cast (Any , object ()),
77- sampler_clients = {},
78- sampler_checkpoint_paths = {},
79- training_checkpoint_paths = {},
80- current_step = 0 ,
81- renderer = cast (Any , object ()),
82- tokenizer = cast (Any , object ()),
83- output_dir = str (tmp_path ),
84- tinker_run_ids = [],
85- model_name = model .name ,
86- )
8767
8868 with pytest .raises (
8969 AssertionError ,
9070 match = "only supports kl_penalty_source='sample'" ,
9171 ):
92- await cast ( Any , backend ) .train (
72+ await backend .train (
9373 model ,
9474 [],
9575 kl_penalty_coef = 0.25 ,
96- kl_penalty_source = "current_learner" ,
76+ kl_penalty_source = "current_learner" , # ty:ignore[invalid-argument-type]
9777 )
0 commit comments