-
Notifications
You must be signed in to change notification settings - Fork 804
Expand file tree
/
Copy pathbackend.py
More file actions
1093 lines (978 loc) · 40.8 KB
/
backend.py
File metadata and controls
1093 lines (978 loc) · 40.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
from __future__ import annotations
import asyncio
from dataclasses import dataclass
import json
import os
import re
import time
from typing import Any, Awaitable, Iterable, Literal, TypeVar, cast
import uuid
from fastapi import FastAPI, HTTPException
from openai import AsyncOpenAI
from openai.types.chat.chat_completion import ChatCompletion, Choice, ChoiceLogprobs
from openai.types.chat.chat_completion_message import ChatCompletionMessage
from openai.types.chat.chat_completion_message_function_tool_call import (
ChatCompletionMessageFunctionToolCall,
Function,
)
from openai.types.chat.chat_completion_message_tool_call import (
ChatCompletionMessageToolCallUnion,
)
from openai.types.chat.chat_completion_token_logprob import ChatCompletionTokenLogprob
from openai.types.chat.completion_create_params import CompletionCreateParams
from openai.types.completion_usage import CompletionUsage
import tinker
import torch
import uvicorn
from art.tinker.cookbook_v import renderers, tokenizer_utils
from .. import dev
from ..backend import Backend
from ..costs import build_cost_calculator, compute_train_cost, get_model_pricing
from ..metrics_taxonomy import (
build_training_summary_metrics,
summarize_trajectory_groups,
)
from ..model import Model, TrainableModel
from ..tinker.backend import get_renderer_name
from ..tinker.server import get_free_port
from ..trajectories import TrajectoryGroup
from ..types import TrainResult
from ..utils.output_dirs import get_model_dir
from ..utils.trajectory_migration import auto_migrate_on_register
from .data import (
convert_openai_messages_to_renderer_format,
parse_completion_to_openai_message,
trajectory_groups_to_datums,
)
# Keys this backend writes into the model's persisted state
# (via ``model.merge_state``) and reads back on startup.
STATE_KEY_RUN_IDS = "tinker_run_ids"
STATE_KEY_LATEST_STEP = "latest_step"

# Result type of awaitables passed through the Tinker call wrappers.
T = TypeVar("T")

# Upstream Tinker train-metric name -> canonical ART metric key, consulted by
# ``_canonicalize_upstream_metric_key``. Insertion order is kept stable.
_UPSTREAM_TRAIN_METRIC_KEYS = {
    # Reward statistics already use their canonical names.
    **{name: name for name in ("reward", "reward_std_dev", "exception_rate")},
    # Loss / optimizer diagnostics are grouped under "loss/".
    **dict(
        policy_loss="loss/train",
        loss="loss/train",
        entropy="loss/entropy",
        kl_div="loss/kl_div",
        kl_policy_ref="loss/kl_policy_ref",
        grad_norm="loss/grad_norm",
        learning_rate="loss/learning_rate",
    ),
    # Per-step batch composition counters are grouped under "data/".
    **dict(
        num_groups_submitted="data/step_num_groups_submitted",
        num_groups_trainable="data/step_num_groups_trainable",
        num_trajectories="data/step_num_trajectories",
        num_trainable_tokens="data/step_trainer_tokens",
        train_tokens="data/step_trainer_tokens",
        num_datums="data/step_num_datums",
    ),
}
def _canonicalize_upstream_metric_key(metric: str) -> str:
    """Translate an upstream Tinker metric name into ART's metric namespace.

    * Names that already contain ``/`` are treated as canonical.
    * ``tokens_per_second`` maps to ``""``, which callers treat as "drop".
    * ``group_metric_<x>`` becomes ``group_<x>``.
    * Anything else is looked up in ``_UPSTREAM_TRAIN_METRIC_KEYS`` and falls
      back to the unchanged name.
    """
    if "/" in metric:
        return metric
    if metric == "tokens_per_second":
        return ""
    suffix = metric.removeprefix("group_metric_")
    if suffix != metric:
        return f"group_{suffix}"
    return _UPSTREAM_TRAIN_METRIC_KEYS.get(metric, metric)
async def _apply_kl_penalty(
    datums: list[tinker.Datum],
    reference_sampling_client: tinker.SamplingClient,
    kl_penalty_coef: float,
) -> dict[str, float]:
    """Fold a sampled-token KL penalty against a reference policy into advantages.

    For every datum, the reference model's logprobs are computed for the full
    sequence (model input plus the final target token). The per-token
    difference ``sampled_logprob - reference_logprob`` (masked) is a log-ratio
    estimate between the sampling policy and the reference. Its mean over all
    masked tokens in the batch is ``avg``. Each datum's advantages are then
    replaced in place with::

        advantages + kl_penalty_coef * (avg - logprob_diff) * mask

    Tokens whose log-ratio exceeds the batch mean are therefore penalized, and
    tokens below it are rewarded.

    Args:
        datums: Training datums. Each must carry ``target_tokens``,
            ``logprobs``, ``mask`` and ``advantages`` in ``loss_fn_inputs``.
            Their ``advantages`` entries are overwritten.
        reference_sampling_client: Client used to score sequences under the
            reference policy.
        kl_penalty_coef: Penalty strength; must be > 0.

    Returns:
        ``{"loss/kl_policy_ref": avg}``.

    Raises:
        AssertionError: On any of the following:
            - empty ``datums``;
            - a non-positive coefficient;
            - an empty ``target_tokens``;
            - a reference/sampled length mismatch;
            - a missing reference logprob;
            - no masked tokens in the batch.
    """
    assert datums
    assert kl_penalty_coef > 0.0
    full_sequences: list[tinker.ModelInput] = []
    sampled_logprobs_by_datum: list[torch.Tensor] = []
    masks_by_datum: list[torch.Tensor] = []
    advantages_by_datum: list[torch.Tensor] = []
    for datum in datums:
        target_tokens = datum.loss_fn_inputs["target_tokens"].to_torch()
        assert target_tokens.numel() > 0
        # Re-append the last target token so the reference model scores every
        # target position. Presumably model_input holds tokens[:-1] and
        # target_tokens holds tokens[1:]; confirm against
        # trajectory_groups_to_datums.
        full_sequences.append(
            datum.model_input.append_int(int(target_tokens[-1].item()))
        )
        sampled_logprobs_by_datum.append(datum.loss_fn_inputs["logprobs"].to_torch())
        masks_by_datum.append(datum.loss_fn_inputs["mask"].to_torch().float())
        advantages_by_datum.append(datum.loss_fn_inputs["advantages"].to_torch())
    # Score every full sequence under the reference policy concurrently.
    reference_logprobs_by_datum = await asyncio.gather(
        *[
            reference_sampling_client.compute_logprobs_async(full_sequence)
            for full_sequence in full_sequences
        ]
    )
    logprob_diffs_by_datum: list[torch.Tensor] = []
    for reference_logprobs, sampled_logprobs, mask in zip(
        reference_logprobs_by_datum,
        sampled_logprobs_by_datum,
        masks_by_datum,
        strict=True,
    ):
        # Drop position 0. It has no preceding context (presumably None from
        # compute_logprobs), which leaves one reference logprob per target.
        reference_values = reference_logprobs[1:]
        assert len(reference_values) == sampled_logprobs.numel()
        assert all(value is not None for value in reference_values)
        reference_logprobs_tensor = torch.tensor(
            reference_values,
            dtype=sampled_logprobs.dtype,
        )
        # Sampled minus reference logprob per token, zeroed outside the mask.
        logprob_diffs_by_datum.append(
            (sampled_logprobs - reference_logprobs_tensor) * mask
        )
    total_tokens = torch.stack([mask.sum() for mask in masks_by_datum]).sum()
    assert total_tokens.item() > 0
    # Batch-wide mean log-ratio over masked tokens; this is the reported KL estimate.
    avg_logprob_diff = (
        torch.stack(
            [logprob_diff.sum() for logprob_diff in logprob_diffs_by_datum]
        ).sum()
        / total_tokens
    )
    for datum, advantages, mask, logprob_diff in zip(
        datums,
        advantages_by_datum,
        masks_by_datum,
        logprob_diffs_by_datum,
        strict=True,
    ):
        # Above-mean divergence lowers the advantage and below-mean raises it.
        # Unmasked positions are left unchanged.
        datum.loss_fn_inputs["advantages"] = tinker.TensorData.from_torch(
            advantages + kl_penalty_coef * (avg_logprob_diff - logprob_diff) * mask
        )
    return {"loss/kl_policy_ref": float(avg_logprob_diff)}
@dataclass
class ModelState:
    """Per-model Tinker clients, checkpoint bookkeeping and inference-server handle."""

    # Root Tinker client; used to create REST, training and sampling clients.
    service_client: tinker.ServiceClient
    # Result of `service_client.create_rest_client()`; used to list checkpoints.
    rest_client: Any
    # LoRA training client that forward_backward / optim_step run on.
    training_client: tinker.TrainingClient
    # Cache of sampling clients keyed by training step.
    sampler_clients: dict[int, tinker.SamplingClient]
    # step -> Tinker path of the sampler weights saved for that step.
    sampler_checkpoint_paths: dict[int, str]
    # step -> Tinker path of the full (resumable) training state for that step.
    training_checkpoint_paths: dict[int, str]
    # Latest completed training step (0 when starting from scratch).
    current_step: int
    # Chat renderer and tokenizer for the base model.
    renderer: Any
    tokenizer: Any
    # Local output directory for this model.
    output_dir: str
    # Tinker training-run IDs whose checkpoints belong to this model.
    tinker_run_ids: list[str]
    # Inference model name with any "@step" suffix removed.
    model_name: str
    # OpenAI-compatible server task and the address/API key it was started
    # with; all None until the server is started.
    server_task: asyncio.Task[None] | None = None
    server_host: str | None = None
    server_port: int | None = None
    server_api_key: str | None = None
@dataclass
class TinkerNativeModelConfig:
    """Resolved per-model Tinker settings (built by ``_resolve_model_config``)."""

    # Name passed to ``renderers.get_renderer`` to build the chat renderer.
    renderer_name: str
    # Keyword arguments forwarded to ``create_lora_training_client_async``
    # (e.g. ``rank``, ``train_unembed``).
    training_client_args: dict[str, Any]
class TinkerNativeBackend(Backend):
_tinker_train_log_env = "ART_TINKER_TRAIN_LOG"
_tinker_sample_log_env = "ART_TINKER_SAMPLE_LOG"
def __init__(
    self,
    *,
    tinker_api_key: str | None = None,
    path: str | None = None,
) -> None:
    """Create a backend that trains and serves models through Tinker.

    Args:
        tinker_api_key: Tinker API key. When given, it is exported as
            ``TINKER_API_KEY``, overriding any existing value. When omitted,
            ``TINKER_API_KEY`` must already be set in the environment.
        path: Local directory for model metadata and state. Defaults to
            ``.art`` and is created if missing.

    Raises:
        ValueError: If no API key is given and ``TINKER_API_KEY`` is unset.
    """
    if tinker_api_key is not None:
        # Never echo the secret itself: stdout commonly ends up in logs.
        print("Setting TINKER_API_KEY in environment")
        os.environ["TINKER_API_KEY"] = tinker_api_key
    elif "TINKER_API_KEY" not in os.environ:
        # An explicit exception instead of `assert`, which is stripped under -O.
        raise ValueError(
            "TINKER_API_KEY is not set and no tinker_api_key was provided"
        )
    self._path = path or ".art"
    os.makedirs(self._path, exist_ok=True)
    self._model_state: dict[str, ModelState] = {}
def _env_enabled(self, env_name: str) -> bool:
value = os.getenv(env_name)
if value is None:
return False
return value.lower() not in ("", "0", "false", "no")
def _timestamp(self) -> str:
return time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
async def _tinker_call(
self,
label: str,
awaitable: Awaitable[T],
*,
env_name: str,
prefix: str,
) -> T:
if not self._env_enabled(env_name):
return await awaitable
start = time.perf_counter()
print(f"[tinker:{prefix}] {label} start {self._timestamp()}")
try:
return await awaitable
finally:
elapsed = time.perf_counter() - start
print(
f"[tinker:{prefix}] {label} done in {elapsed:.2f}s "
f"at {self._timestamp()}"
)
async def _tinker_train_call(self, label: str, awaitable: Awaitable[T]) -> T:
return await self._tinker_call(
label,
awaitable,
env_name=self._tinker_train_log_env,
prefix="train",
)
async def _tinker_sample_call(self, label: str, awaitable: Awaitable[T]) -> T:
return await self._tinker_call(
label,
awaitable,
env_name=self._tinker_sample_log_env,
prefix="sample",
)
async def close(self) -> None:
    """Stop every OpenAI-compatible server started by this backend.

    Each running server task is cancelled and then awaited, so the server has
    actually stopped by the time this returns. Previously the tasks were only
    cancelled and never awaited. ``server_task`` is reset to ``None`` so a
    later ``_prepare_backend_for_training`` can start a fresh server.
    """
    stopping: list[asyncio.Task[None]] = []
    for state in self._model_state.values():
        if state.server_task is None:
            continue
        state.server_task.cancel()
        stopping.append(state.server_task)
        state.server_task = None
    if stopping:
        # return_exceptions=True: cancelled tasks show up as CancelledError
        # results instead of aborting the wait for the remaining servers.
        await asyncio.gather(*stopping, return_exceptions=True)
async def register(self, model: Model) -> None:
    """Register ``model`` with this backend.

    This method:
      * points ``model.base_path`` at this backend's directory;
      * writes ``model.json`` into the model's output directory;
      * calls ``auto_migrate_on_register`` on that directory.

    For trainable models it also attaches a cost calculator (when pricing for
    the base model is known). It then builds the Tinker client state used by
    ``train()`` and the inference server.
    """
    model.base_path = self._path
    output_dir = get_model_dir(model=model, art_path=self._path)
    os.makedirs(output_dir, exist_ok=True)
    with open(f"{output_dir}/model.json", "w", encoding="utf-8") as f:
        json.dump(model.model_dump(), f)
    auto_migrate_on_register(output_dir)
    if not model.trainable:
        return
    trainable_model = cast(TrainableModel, model)
    pricing = get_model_pricing(trainable_model.base_model)
    if pricing is not None:
        trainable_model.set_cost_calculator(build_cost_calculator(pricing))
    state = await self._build_model_state(trainable_model)
    self._model_state[model.name] = state
async def _prepare_backend_for_training(
self,
model: TrainableModel,
config: dev.OpenAIServerConfig | None = None,
) -> tuple[str, str]:
state = self._model_state[model.name]
raw_config: dict[str, Any] = cast(dict[str, Any], config) if config else {}
server_args = cast(dict[str, Any], raw_config.get("server_args", {}))
host = server_args.get("host", raw_config.get("host", "0.0.0.0"))
port = server_args.get("port", raw_config.get("port"))
if port is None:
port = get_free_port()
api_key = server_args.get("api_key", raw_config.get("api_key")) or "default"
if state.server_task is None:
state.server_host = host
state.server_port = port
state.server_api_key = api_key
state.server_task = asyncio.create_task(
self._run_openai_server(state, host=host, port=port)
)
state.server_task.add_done_callback(self._crash_on_server_exit)
base_url = f"http://{host}:{port}/v1"
await self._wait_for_server_ready(base_url, api_key, model)
return base_url, api_key
async def train(  # type: ignore[override]
    self,
    model: TrainableModel,
    trajectory_groups: Iterable[TrajectoryGroup],
    *,
    learning_rate: float = 1e-5,
    loss_fn: Literal["cispo", "ppo", "importance_sampling", "dro"] = "cispo",
    normalize_advantages: bool = True,
    save_checkpoint: bool = False,
    loss_fn_config: dict | None = None,
    adam_params: tinker.AdamParams | None = None,
    kl_penalty_coef: float = 0.0,
    kl_penalty_reference_step: int | None = None,
    kl_penalty_source: Literal["sample"] = "sample",
) -> TrainResult:
    """Run one training step on Tinker and publish the resulting weights.

    Trajectory groups are converted to Tinker datums. If no datums are
    produced, only summary metrics are returned and the step does not
    advance. Otherwise the method:

      1. optionally folds a sampled KL penalty into the advantages (when
         ``kl_penalty_coef > 0``). The reference is the base model, or the
         checkpoint at ``kl_penalty_reference_step`` when one is given;
      2. runs ``forward_backward`` with ``loss_fn``, then ``optim_step``
         (default Adam: the given learning rate, betas 0.9/0.95, eps 1e-8);
      3. saves sampler weights as ``step_<NNNNNN>``, plus the full training
         state when ``save_checkpoint`` is set;
      4. creates a sampling client for the new step and persists the step.

    Returns:
        ``TrainResult`` with the new (or unchanged) step. Its metrics merge
        the local summary, the KL penalty and Tinker's upstream outputs.
    """
    assert kl_penalty_source == "sample", (
        "TinkerNativeBackend only supports kl_penalty_source='sample'."
    )
    state = self._model_state[model.name]
    groups_list = list(trajectory_groups)
    summary = summarize_trajectory_groups(groups_list)
    datums = trajectory_groups_to_datums(
        groups_list,
        state.renderer,
        state.tokenizer,
        normalize_advantages,
    )
    metrics: dict[str, float] = {
        **build_training_summary_metrics(
            summary,
            include_trainable_groups=True,
        ),
        "data/step_num_datums": float(len(datums)),
    }
    if not datums:
        return TrainResult(step=state.current_step, metrics=metrics)

    train_tokens = sum(len(datum.model_input.to_ints()) for datum in datums)
    metrics["data/step_trainer_tokens"] = float(train_tokens)
    pricing = get_model_pricing(model.base_model)
    if pricing is not None:
        metrics["costs/train/tinker_train"] = compute_train_cost(
            train_tokens, pricing
        )

    trainer_started = time.monotonic()
    sampled_kl_policy_ref: float | None = None
    if kl_penalty_coef > 0:
        reference_client = await self._get_kl_reference_sampling_client(
            state,
            model.base_model,
            kl_penalty_reference_step,
        )
        kl_metrics = await self._tinker_sample_call(
            "apply_kl_penalty",
            _apply_kl_penalty(datums, reference_client, kl_penalty_coef),
        )
        sampled_kl_policy_ref = kl_metrics["loss/kl_policy_ref"]
        metrics.update(kl_metrics)

    if adam_params is None:
        adam_params = tinker.AdamParams(
            learning_rate=learning_rate,
            beta1=0.9,
            beta2=0.95,
            eps=1e-8,
        )

    def remove_mask(datum: tinker.Datum) -> tinker.Datum:
        # "mask" is only needed locally (KL penalty); strip it before
        # sending datums to the loss function.
        if "mask" not in datum.loss_fn_inputs:
            return datum
        loss_fn_inputs = {
            key: value
            for key, value in datum.loss_fn_inputs.items()
            if key != "mask"
        }
        return tinker.Datum(
            model_input=datum.model_input, loss_fn_inputs=loss_fn_inputs
        )

    forward_output = await self._tinker_train_call(
        "forward_backward",
        state.training_client.forward_backward(
            [remove_mask(datum) for datum in datums],
            loss_fn=loss_fn,
            loss_fn_config=loss_fn_config,
        ),
    )
    optim_output = await self._tinker_train_call(
        "optim_step", state.training_client.optim_step(adam_params)
    )

    # Merge upstream metrics. Forward/backward is processed first, so
    # optim_step values win on key collisions.
    for upstream_metrics in (forward_output.metrics, optim_output.metrics):
        if not upstream_metrics:
            continue
        for key, value in upstream_metrics.items():
            if value is None:
                continue
            canonical_key = _canonicalize_upstream_metric_key(key)
            if not canonical_key:
                # Empty key means the metric is intentionally dropped.
                continue
            if (
                sampled_kl_policy_ref is not None
                and canonical_key == "loss/kl_policy_ref"
            ):
                # Keep the locally computed sampled KL estimate.
                continue
            metrics[canonical_key] = float(value)

    next_step = state.current_step + 1
    checkpoint_name = f"step_{next_step:06d}"
    if save_checkpoint:
        state_response, sampler_response = await self._save_checkpoint(
            state.training_client, checkpoint_name
        )
        state.training_checkpoint_paths[next_step] = state_response.path
    else:
        sampler_response = await self._save_sampler_weights(
            state.training_client, checkpoint_name
        )
    sampler_client = await self._tinker_train_call(
        "create_sampling_client_async",
        state.training_client.create_sampling_client_async(
            model_path=sampler_response.path
        ),
    )
    state.sampler_clients[next_step] = sampler_client
    state.sampler_checkpoint_paths[next_step] = sampler_response.path
    state.current_step = next_step
    self._persist_model_state(model, state)
    metrics["time/step_trainer_s"] = time.monotonic() - trainer_started
    return TrainResult(step=state.current_step, metrics=metrics)
async def _get_step(self, model: TrainableModel) -> int:
if model.name in self._model_state:
return self._model_state[model.name].current_step
state = model.read_state() or {}
return int(state.get(STATE_KEY_LATEST_STEP, 0))
async def _delete_checkpoint_files(
self,
model: TrainableModel,
steps_to_keep: list[int],
) -> None:
print("Checkpoint deletion is not yet implemented for TinkerNativeBackend.")
def _model_inference_name(self, model: Model, step: int | None = None) -> str:
base_name = model.inference_model_name or model.name
if "@" in base_name:
base_name = base_name.split("@", 1)[0]
if step is None:
state = self._model_state.get(model.name)
step = state.current_step if state is not None else 0
return f"{base_name}@{step}"
async def _run_openai_server(
    self,
    state: ModelState,
    host: str,
    port: int,
) -> None:
    """Serve an OpenAI-compatible chat-completions endpoint backed by Tinker.

    Builds a FastAPI app exposing ``POST /v1/chat/completions`` and runs it
    with uvicorn on ``host:port`` until the server exits or this coroutine is
    cancelled.

    Request handling:
      * The ``model`` field must be ``"<name>@<step>"``. The step selects the
        sampler checkpoint. A missing or invalid suffix gives HTTP 400, and
        an unknown step gives HTTP 404.
      * Messages and optional tools are rendered into a prompt with
        ``state.renderer`` and sampled through that step's sampling client.
      * Each sampled sequence is parsed back into an assistant message with
        text content and/or function tool calls.
      * Per-token logprobs are always returned. Tokens are reported as
        ``"token_id:<id>"`` strings rather than decoded text.
    """
    app = FastAPI()

    @app.post("/v1/chat/completions")
    async def chat_completions(body: CompletionCreateParams) -> ChatCompletion:
        model_name = body.get("model")
        # NOTE(review): only the step picks the sampler. The parsed base name is
        # echoed back but never compared with state.model_name.
        parsed_model_name, step = self._parse_model_name(model_name)
        sampler_client = await self._get_sampler_client(state, step)
        messages = self._normalize_messages(body["messages"])
        tools = self._normalize_tools(body.get("tools"))
        renderer_messages = convert_openai_messages_to_renderer_format(
            messages=messages,
            tools=tools,
            renderer=state.renderer,
        )
        prompt = state.renderer.build_generation_prompt(renderer_messages)
        prompt_tokens = list(prompt.to_ints())
        # Prefer the newer `max_completion_tokens`; fall back to `max_tokens`.
        max_tokens = body.get("max_completion_tokens")
        if max_tokens is None:
            max_tokens = body.get("max_tokens")
        temperature = body.get("temperature")
        top_k = body.get("top_k")
        top_p = body.get("top_p")
        # Unset knobs get neutral defaults: temperature 1.0, top_p 1.0 and
        # top_k -1 (presumably "no top-k limit" in Tinker; confirm).
        sampling_params = tinker.SamplingParams(
            max_tokens=max_tokens,
            seed=body.get("seed"),
            temperature=temperature if temperature is not None else 1.0,
            top_k=top_k if top_k is not None else -1,
            top_p=top_p if top_p is not None else 1.0,
            stop=state.renderer.get_stop_sequences(),
        )
        sample_response = await self._tinker_sample_call(
            "sample_async",
            sampler_client.sample_async(
                prompt=prompt,
                num_samples=body.get("n") or 1,
                sampling_params=sampling_params,
            ),
        )
        choices: list[Choice] = []
        for i, sequence in enumerate(sample_response.sequences):
            # Logprobs must exist and align 1:1 with tokens, because they are
            # returned per token in `logprobs.content` below.
            if sequence.logprobs is None:
                raise HTTPException(status_code=400, detail="Logprobs are required")
            if len(sequence.tokens) != len(sequence.logprobs):
                raise HTTPException(
                    status_code=400,
                    detail="Tokens and logprobs must have the same length",
                )
            parsed_message = parse_completion_to_openai_message(
                list(sequence.tokens), state.renderer
            )
            content = parsed_message.get("content")
            tool_calls: list[ChatCompletionMessageToolCallUnion] | None = None
            if parsed_message.get("tool_calls"):
                # Convert parsed tool calls into OpenAI function tool calls.
                # A missing id is synthesized as "call_<idx>", and non-string
                # arguments are JSON-encoded.
                tool_calls = [
                    ChatCompletionMessageFunctionToolCall(
                        type="function",
                        id=tool_call.get("id") or f"call_{idx}",
                        function=Function(
                            name=tool_call["function"]["name"],
                            arguments=(
                                tool_call["function"]["arguments"]
                                if isinstance(
                                    tool_call["function"]["arguments"], str
                                )
                                else json.dumps(tool_call["function"]["arguments"])
                            ),
                        ),
                    )
                    for idx, tool_call in enumerate(parsed_message["tool_calls"])
                ]
            choices.append(
                Choice(
                    finish_reason=sequence.stop_reason,
                    index=i,
                    message=ChatCompletionMessage(
                        # Empty text is reported as null content.
                        content=content or None,
                        role="assistant",
                        tool_calls=tool_calls,
                    ),
                    logprobs=ChoiceLogprobs(
                        content=[
                            # Token text is not decoded; the raw token id is
                            # encoded as "token_id:<id>".
                            ChatCompletionTokenLogprob(
                                token=f"token_id:{token}",
                                logprob=logprob,
                                top_logprobs=[],
                            )
                            for token, logprob in zip(
                                sequence.tokens, sequence.logprobs
                            )
                        ]
                    ),
                )
            )
        # Usage counts completion tokens summed across all `n` samples.
        completion_tokens = sum(
            len(sequence.tokens) for sequence in sample_response.sequences
        )
        return ChatCompletion(
            id=str(uuid.uuid4()),
            choices=choices,
            created=int(time.time()),
            model=self._format_response_model(parsed_model_name, step),
            object="chat.completion",
            usage=CompletionUsage(
                completion_tokens=completion_tokens,
                prompt_tokens=len(prompt_tokens),
                total_tokens=completion_tokens + len(prompt_tokens),
            ),
        )

    server_config = uvicorn.Config(app, host=host, port=port, log_level="error")
    server = uvicorn.Server(server_config)
    await server.serve()
def _crash_on_server_exit(self, task: asyncio.Task[None]) -> None:
try:
task.result()
except asyncio.CancelledError:
return
except Exception as exc:
print(f"OpenAI server crashed: {exc}")
else:
print("OpenAI server exited unexpectedly.")
os._exit(1)
async def _wait_for_server_ready(
    self, base_url: str, api_key: str, model: TrainableModel
) -> None:
    """Block until the OpenAI-compatible server at ``base_url`` serves requests.

    Sends a 1-token chat completion for the model's current inference name,
    retrying every 0.1s until one succeeds. The deadline is
    ``ART_SERVER_TIMEOUT`` seconds (default 300), measured on a monotonic
    clock. The HTTP client is always closed before returning.

    Raises:
        TimeoutError: If no request succeeds before the deadline. The last
            request error is chained as ``__cause__``.
    """
    with_timeout = float(os.environ.get("ART_SERVER_TIMEOUT", 300.0))
    client = AsyncOpenAI(base_url=base_url, api_key=api_key)
    # monotonic() is immune to wall-clock jumps that could shorten or
    # stretch the wait.
    start = time.monotonic()
    last_error: Exception | None = None
    try:
        while True:
            if time.monotonic() - start > with_timeout:
                raise TimeoutError(
                    f"Unable to reach OpenAI-compatible server within {with_timeout} seconds."
                ) from last_error
            try:
                await client.chat.completions.create(
                    model=self._model_inference_name(model),
                    messages=[{"role": "user", "content": "Hello, world!"}],
                    max_completion_tokens=1,
                )
                return
            except Exception as exc:
                # Server not ready yet; remember why and retry.
                last_error = exc
                await asyncio.sleep(0.1)
    finally:
        # Release the client's HTTP connection pool (previously leaked).
        await client.close()
async def _build_model_state(self, model: TrainableModel) -> ModelState:
    """Create or resume the Tinker clients and bookkeeping for ``model``.

    Checkpoints are discovered across every Tinker run ID recorded in the
    model's persisted state.

    * If any training checkpoint exists, a training client is restored
      (including optimizer state) from the highest step.
    * Otherwise a fresh LoRA training client starts at step 0 and always
      saves its own step-0 sampler weights.

    Sampler weights are saved for the current step if none are listed. A
    sampling client is then created for that step, and the run IDs and step
    are persisted before returning.
    """
    config = self._resolve_model_config(model)
    service_client = tinker.ServiceClient()
    rest_client = service_client.create_rest_client()
    tokenizer = tokenizer_utils.get_tokenizer(model.base_model)
    renderer = renderers.get_renderer(
        name=config.renderer_name,
        tokenizer=tokenizer,
        model_name=model.base_model,
    )

    persisted = model.read_state() or {}
    run_ids = list(persisted.get(STATE_KEY_RUN_IDS, []))
    training_paths, sampler_paths = await self._list_checkpoints(
        rest_client, run_ids
    )

    training_client: tinker.TrainingClient
    if training_paths:
        current_step = max(training_paths)
        training_client = await self._create_training_client_from_checkpoint(
            service_client=service_client,
            checkpoint_state_path=training_paths[current_step],
            base_model=model.base_model,
            training_client_args=config.training_client_args,
            reset_optimizer=False,
        )
        must_save_sampler = current_step not in sampler_paths
    else:
        current_step = 0
        training_client = await self._tinker_train_call(
            "create_lora_training_client_async",
            service_client.create_lora_training_client_async(
                model.base_model, **config.training_client_args
            ),
        )
        # A brand-new client always publishes its own step-0 sampler weights.
        must_save_sampler = True

    if must_save_sampler:
        saved = await self._save_sampler_weights(
            training_client, f"step_{current_step:06d}"
        )
        sampler_paths[current_step] = saved.path

    run_id = training_client.model_id
    if run_id not in run_ids:
        run_ids.append(run_id)

    sampler_client = await self._tinker_train_call(
        "create_sampling_client_async",
        training_client.create_sampling_client_async(
            model_path=sampler_paths[current_step]
        ),
    )
    state = ModelState(
        service_client=service_client,
        rest_client=rest_client,
        training_client=training_client,
        sampler_clients={current_step: sampler_client},
        sampler_checkpoint_paths=sampler_paths,
        training_checkpoint_paths=training_paths,
        current_step=current_step,
        renderer=renderer,
        tokenizer=tokenizer,
        output_dir=get_model_dir(model=model, art_path=self._path),
        tinker_run_ids=run_ids,
        model_name=(model.inference_model_name or model.name).partition("@")[0],
    )
    self._persist_model_state(model, state)
    return state
def _resolve_model_config(self, model: TrainableModel) -> TinkerNativeModelConfig:
    """Resolve the renderer and training-client kwargs for ``model``.

    Settings come from ``model._internal_config["tinker_native_args"]``, with
    these fallbacks:
      * the renderer from ``get_renderer_name(model.base_model)``;
      * ``rank=8``;
      * ``train_unembed=False``.

    The caller's ``training_client_args`` mapping is copied, never mutated.
    """
    internal_config = model._internal_config or {}
    native_args = cast(
        dev.TinkerNativeArgs | None,
        internal_config.get("tinker_native_args"),
    )
    renderer_name = None if native_args is None else native_args.get("renderer_name")
    if renderer_name is None:
        renderer_name = get_renderer_name(model.base_model)
    client_args: dict[str, Any] = (
        {}
        if native_args is None
        else dict(native_args.get("training_client_args", {}))
    )
    client_args.setdefault("rank", 8)
    client_args.setdefault("train_unembed", False)
    return TinkerNativeModelConfig(
        renderer_name=renderer_name,
        training_client_args=client_args,
    )
async def _list_checkpoints(
self, rest_client: Any, tinker_run_ids: list[str]
) -> tuple[dict[int, str], dict[int, str]]:
training_paths: dict[int, str] = {}
sampler_paths: dict[int, str] = {}
step_pattern = re.compile(r"(?:weights/)?step_(\d+)$")
for run_id in tinker_run_ids:
try:
response = await self._tinker_train_call(
f"list_checkpoints_async {run_id}",
rest_client.list_checkpoints_async(run_id),
)
except Exception as exc:
print(f"Warning: Could not list checkpoints for {run_id}: {exc}")
continue
for checkpoint in response.checkpoints:
match = step_pattern.match(checkpoint.checkpoint_id)
if not match:
continue
step = int(match.group(1))
path = f"tinker://{run_id}/{checkpoint.checkpoint_id}"
if checkpoint.checkpoint_type == "training":
training_paths[step] = path
elif checkpoint.checkpoint_type == "sampler":
sampler_paths[step] = path
return training_paths, sampler_paths
async def _get_sampler_client(
self, state: ModelState, step: int | None
) -> tinker.SamplingClient:
actual_step = step if step is not None else state.current_step
if actual_step in state.sampler_clients:
return state.sampler_clients[actual_step]
if actual_step not in state.sampler_checkpoint_paths:
training_paths, sampler_paths = await self._list_checkpoints(
state.rest_client, state.tinker_run_ids
)
state.training_checkpoint_paths.update(training_paths)
state.sampler_checkpoint_paths.update(sampler_paths)
if actual_step not in state.sampler_checkpoint_paths:
available = sorted(state.sampler_checkpoint_paths.keys())
raise HTTPException(
status_code=404,
detail=f"No sampler checkpoint for step {actual_step}. Available: {available}",
)
sampler_client = await self._tinker_train_call(
"create_sampling_client_async",
state.training_client.create_sampling_client_async(
model_path=state.sampler_checkpoint_paths[actual_step]
),
)
state.sampler_clients[actual_step] = sampler_client
return sampler_client
async def _get_kl_reference_sampling_client(
self,
state: ModelState,
base_model: str,
step: int | None,
) -> tinker.SamplingClient:
if step is not None:
return await self._get_sampler_client(state, step)
return await self._tinker_sample_call(
"create_sampling_client_async",
state.service_client.create_sampling_client_async(base_model=base_model),
)
def _normalize_messages(self, messages: Iterable[Any]) -> list[dict[str, Any]]:
normalized: list[dict[str, Any]] = []
for message in messages:
if hasattr(message, "model_dump"):
normalized.append(message.model_dump())
else:
normalized.append(dict(message))
return normalized
def _normalize_tools(
self, tools: Iterable[Any] | None
) -> list[dict[str, Any]] | None:
if tools is None:
return None
normalized: list[dict[str, Any]] = []
for tool in tools:
if hasattr(tool, "model_dump"):
normalized.append(tool.model_dump())
else:
normalized.append(dict(tool))
return normalized
def _parse_model_name(self, model_name: str | None) -> tuple[str, int]:
if not model_name:
raise HTTPException(
status_code=400,
detail="Model name is required and must include an '@step' suffix. Use model.get_inference_name().",
)
if "@" not in model_name:
raise HTTPException(
status_code=400,
detail=(
f"Model '{model_name}' is missing an '@step' suffix. "
"Use model.get_inference_name()."
),
)
base_name, step_str = model_name.rsplit("@", 1)
try:
return base_name, int(step_str)
except ValueError as exc:
raise HTTPException(
status_code=400, detail=f"Invalid model step: {model_name}"
) from exc
def _format_response_model(self, model_name: str, step: int) -> str:
# Echo back the explicit model@step used for this completion.
return f"{model_name}@{step}"
async def _create_training_client_from_checkpoint(
self,
service_client: tinker.ServiceClient,
checkpoint_state_path: str,
base_model: str,
training_client_args: dict[str, Any],
reset_optimizer: bool = False,
) -> tinker.TrainingClient:
training_client = await self._tinker_train_call(
"create_lora_training_client_async",
service_client.create_lora_training_client_async(
base_model, **training_client_args
),
)
if reset_optimizer:
load_future = await self._tinker_train_call(
"load_state_async",
training_client.load_state_async(checkpoint_state_path),
)
await self._tinker_train_call(
"load_state_result_async", load_future.result_async()
)
else:
load_future = await self._tinker_train_call(
"load_state_with_optimizer_async",
training_client.load_state_with_optimizer_async(checkpoint_state_path),
)
await self._tinker_train_call(
"load_state_with_optimizer_result_async", load_future.result_async()
)
return training_client
async def _save_checkpoint(
self,
training_client: tinker.TrainingClient,
checkpoint_name: str,
) -> tuple[Any, Any]:
state_future, sampler_future = await asyncio.gather(
self._tinker_train_call(
"save_state_async",
training_client.save_state_async(checkpoint_name),
),
self._tinker_train_call(
"save_weights_for_sampler_async",
training_client.save_weights_for_sampler_async(checkpoint_name),
),
)
state_result = await self._tinker_train_call(
"save_state_result_async", state_future.result_async()
)
sampler_result = await self._tinker_train_call(
"save_weights_for_sampler_result_async", sampler_future.result_async()
)
return state_result, sampler_result
async def _save_sampler_weights(
self,
training_client: tinker.TrainingClient,
checkpoint_name: str,
) -> Any:
sampler_future = await self._tinker_train_call(
"save_weights_for_sampler_async",
training_client.save_weights_for_sampler_async(checkpoint_name),
)
return await self._tinker_train_call(
"save_weights_for_sampler_result_async", sampler_future.result_async()
)
async def _save_training_state(
self,
training_client: tinker.TrainingClient,
checkpoint_name: str,
) -> Any:
state_future = await self._tinker_train_call(
"save_state_async",
training_client.save_state_async(checkpoint_name),
)
return await self._tinker_train_call(
"save_state_result_async", state_future.result_async()
)
def _persist_model_state(self, model: TrainableModel, state: ModelState) -> None:
    """Merge the Tinker run IDs and latest step from ``state`` into the model's saved state."""
    snapshot = {
        STATE_KEY_RUN_IDS: state.tinker_run_ids,
        STATE_KEY_LATEST_STEP: state.current_step,
    }
    model.merge_state(snapshot)
async def _experimental_fork_checkpoint(
self,
model: Model,
from_model: str,
from_project: str | None = None,
from_s3_bucket: str | None = None,
not_after_step: int | None = None,
verbose: bool = False,
prefix: str | None = None,
) -> None:
"""Fork a checkpoint from another TinkerNative model to initialize this model.
Loads the source model's training checkpoint into the destination model's
training client directly via tinker:// paths. No local download needed.
Args:
model: The destination model to fork to (must already be registered).
from_model: The name of the source model to fork from.
from_project: The project of the source model. Defaults to model.project.
from_s3_bucket: Not supported for TinkerNativeBackend.
not_after_step: If provided, uses the latest checkpoint <= this step.
verbose: Whether to print verbose output.
prefix: Not applicable for TinkerNativeBackend.
"""
if from_s3_bucket is not None:
raise NotImplementedError(
"from_s3_bucket is not supported for TinkerNativeBackend. "
"Tinker checkpoints are stored on Tinker infrastructure, not S3."
)
trainable_model = cast(TrainableModel, model)
if trainable_model.name not in self._model_state:
raise RuntimeError(
f"Model '{trainable_model.name}' is not registered. "
"Call register() before forking."
)
from_project = from_project or model.project
# Read the source model's state.json to get its tinker_run_ids
source_state_dir = get_model_dir(
Model(name=from_model, project=from_project),