2424from openai .types .chat .completion_create_params import CompletionCreateParams
2525from openai .types .completion_usage import CompletionUsage
2626import tinker
27+ import torch
2728import uvicorn
2829
2930from art .tinker .cookbook_v import renderers , tokenizer_utils
@@ -82,6 +83,76 @@ def _canonicalize_upstream_metric_key(metric: str) -> str:
8283 return _UPSTREAM_TRAIN_METRIC_KEYS .get (metric , metric )
8384
8485
async def _apply_kl_penalty(
    datums: list[tinker.Datum],
    reference_sampling_client: tinker.SamplingClient,
    kl_penalty_coef: float,
) -> dict[str, float]:
    """Fold a KL(policy || reference) penalty into each datum's advantages, in place.

    For every datum, the per-token difference between the sampled (policy)
    logprobs and the reference model's logprobs is computed over the masked
    tokens.  The mask-weighted mean of that difference is subtracted from the
    per-token difference and scaled by ``kl_penalty_coef`` before being added
    to the stored advantages, so the penalty is mean-centered across the batch.

    Args:
        datums: Non-empty batch whose ``loss_fn_inputs`` hold ``target_tokens``,
            ``logprobs``, ``mask`` and ``advantages``; ``advantages`` is
            overwritten with the penalized values.
        reference_sampling_client: Client used to score the full sequences
            under the reference policy.
        kl_penalty_coef: Strictly positive penalty coefficient.

    Returns:
        A metrics dict with the batch-average logprob difference under
        ``"loss/kl_policy_ref"``.
    """
    assert datums
    assert kl_penalty_coef > 0.0

    full_inputs: list[tinker.ModelInput] = []
    sampled_lps: list[torch.Tensor] = []
    masks: list[torch.Tensor] = []
    advantages_list: list[torch.Tensor] = []
    for datum in datums:
        targets = datum.loss_fn_inputs["target_tokens"].to_torch()
        assert targets.numel() > 0
        # Append the final target token so the reference model scores the
        # complete sequence; its logprobs are shifted by one below.
        last_target = int(targets[-1].item())
        full_inputs.append(datum.model_input.append_int(last_target))
        sampled_lps.append(datum.loss_fn_inputs["logprobs"].to_torch())
        masks.append(datum.loss_fn_inputs["mask"].to_torch().float())
        advantages_list.append(datum.loss_fn_inputs["advantages"].to_torch())

    # Score every full sequence under the reference policy concurrently.
    reference_lps = await asyncio.gather(
        *(
            reference_sampling_client.compute_logprobs_async(model_input)
            for model_input in full_inputs
        )
    )

    diffs: list[torch.Tensor] = []
    for ref_lp, sampled_lp, mask in zip(
        reference_lps, sampled_lps, masks, strict=True
    ):
        # Drop the first position: reference logprobs are for the *next*
        # token, so index 1 onward lines up with the sampled target tokens.
        ref_values = ref_lp[1:]
        assert len(ref_values) == sampled_lp.numel()
        assert all(value is not None for value in ref_values)
        ref_tensor = torch.tensor(ref_values, dtype=sampled_lp.dtype)
        diffs.append((sampled_lp - ref_tensor) * mask)

    total_tokens = torch.stack([mask.sum() for mask in masks]).sum()
    assert total_tokens.item() > 0
    avg_diff = torch.stack([diff.sum() for diff in diffs]).sum() / total_tokens

    # Rewrite each datum's advantages with the mean-centered KL penalty.
    for datum, advantages, mask, diff in zip(
        datums, advantages_list, masks, diffs, strict=True
    ):
        penalized = advantages + kl_penalty_coef * (avg_diff - diff) * mask
        datum.loss_fn_inputs["advantages"] = tinker.TensorData.from_torch(penalized)

    return {"loss/kl_policy_ref": float(avg_diff)}
155+
85156@dataclass
86157class ModelState :
87158 service_client : tinker .ServiceClient
@@ -239,6 +310,9 @@ async def train( # type: ignore[override]
239310 save_checkpoint : bool = False ,
240311 loss_fn_config : dict | None = None ,
241312 adam_params : tinker .AdamParams | None = None ,
313+ kl_penalty_coef : float = 0.0 ,
314+ kl_penalty_reference_step : int | None = None ,
315+ kl_penalty_source : Literal ["sample" ] = "sample" ,
242316 ) -> TrainResult :
243317 state = self ._model_state [model .name ]
244318 groups_list = list (trajectory_groups )
@@ -259,6 +333,10 @@ async def train( # type: ignore[override]
259333 "data/step_num_datums" : float (len (datums )),
260334 }
261335
336+ assert kl_penalty_source == "sample" , (
337+ "TinkerNativeBackend only supports kl_penalty_source='sample'."
338+ )
339+
262340 if not datums :
263341 return TrainResult (step = state .current_step , metrics = metrics )
264342
@@ -273,6 +351,23 @@ async def train( # type: ignore[override]
273351 )
274352 trainer_started = time .monotonic ()
275353
354+ if kl_penalty_coef > 0 :
355+ reference_sampling_client = await self ._get_kl_reference_sampling_client (
356+ state ,
357+ model .base_model ,
358+ kl_penalty_reference_step ,
359+ )
360+ metrics .update (
361+ await self ._tinker_sample_call (
362+ "apply_kl_penalty" ,
363+ _apply_kl_penalty (
364+ datums ,
365+ reference_sampling_client ,
366+ kl_penalty_coef ,
367+ ),
368+ )
369+ )
370+
276371 if adam_params is None :
277372 adam_params = tinker .AdamParams (
278373 learning_rate = learning_rate ,
@@ -697,6 +792,19 @@ async def _get_sampler_client(
697792 state .sampler_clients [actual_step ] = sampler_client
698793 return sampler_client
699794
795+ async def _get_kl_reference_sampling_client (
796+ self ,
797+ state : ModelState ,
798+ base_model : str ,
799+ step : int | None ,
800+ ) -> tinker .SamplingClient :
801+ if step is not None :
802+ return await self ._get_sampler_client (state , step )
803+ return await self ._tinker_sample_call (
804+ "create_sampling_client_async" ,
805+ state .service_client .create_sampling_client_async (base_model = base_model ),
806+ )
807+
700808 def _normalize_messages (self , messages : Iterable [Any ]) -> list [dict [str , Any ]]:
701809 normalized : list [dict [str , Any ]] = []
702810 for message in messages :
0 commit comments