Skip to content

Commit 176632c

Browse files
authored
Serverless fork (#550)
* Allow forking for ServerlessBackend
* Add forking to TinkerNativeBackend
1 parent 6e45e60 commit 176632c

4 files changed

Lines changed: 709 additions & 2 deletions

File tree

src/art/serverless/backend.py

Lines changed: 216 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,22 @@
1313
from ..types import ServerlessTrainResult, TrainConfig
1414

1515
if TYPE_CHECKING:
16+
import wandb
17+
1618
from ..model import Model, TrainableModel
1719

1820

21+
def _extract_step_from_wandb_artifact(artifact: "wandb.Artifact") -> int | None:
22+
"""Extract step number from a W&B artifact's aliases."""
23+
for alias in artifact.aliases:
24+
if alias.startswith("step"):
25+
try:
26+
return int(alias[4:])
27+
except ValueError:
28+
pass
29+
return None
30+
31+
1932
class ServerlessBackend(Backend):
2033
def __init__(
2134
self, *, api_key: str | None = None, base_url: str | None = None
async def _experimental_push_to_s3(
    self,
    # NOTE(review): the diff elides the leading parameters; `model`,
    # `s3_bucket`, and `prefix` are reconstructed from the docstring —
    # confirm against the full signature in the repository.
    model: "Model | TrainableModel",
    s3_bucket: str | None = None,
    prefix: str | None = None,
    verbose: bool = False,
    delete: bool = False,
) -> None:
    """Push model checkpoints from W&B artifacts to S3.

    Downloads checkpoint(s) from W&B and uploads them to S3.

    Args:
        model: The model whose checkpoints to push.
        s3_bucket: S3 bucket name. If None, uses BACKUP_BUCKET env var.
        prefix: Optional S3 prefix path.
        verbose: Whether to print verbose output.
        delete: Whether to delete files from S3 that don't exist in source.
    """
    from art.utils.s3 import build_s3_path, ensure_bucket_exists, s3_sync

    assert model.id is not None, "Model ID is required"

    # Get all checkpoint steps, oldest first, from the serverless API.
    steps: list[int] = []
    async for checkpoint in self._client.models.checkpoints.list(  # ty:ignore[possibly-missing-attribute]
        model_id=model.id, order="asc"
    ):
        steps.append(checkpoint.step)

    if not steps:
        if verbose:
            print("No checkpoints found to push.")
        return

    await ensure_bucket_exists(s3_bucket)

    for step in steps:
        if verbose:
            print(f"Pushing checkpoint step {step} to S3...")

        # Pull from W&B to local temp dir
        checkpoint_dir = await self._experimental_pull_model_checkpoint(
            model,  # type: ignore[arg-type]
            step=step,
            verbose=verbose,
        )

        # Push to S3
        s3_path = build_s3_path(
            model_name=model.name,
            project=model.project,
            step=step,
            s3_bucket=s3_bucket,
            prefix=prefix,
        )
        await s3_sync(checkpoint_dir, s3_path, verbose=verbose, delete=delete)

    if verbose:
        print(f"Successfully pushed {len(steps)} checkpoint(s) to S3.")

422486
async def _experimental_fork_checkpoint(
    self,
    # NOTE(review): the diff elides the middle parameters; they are
    # reconstructed from the docstring and the matching TinkerNativeBackend
    # signature — confirm against the full signature in the repository.
    model: "Model",
    from_model: str,
    from_project: str | None = None,
    from_s3_bucket: str | None = None,
    not_after_step: int | None = None,
    verbose: bool = False,
    prefix: str | None = None,
) -> None:
    """Fork a checkpoint from another model to initialize this model.

    Pulls the source checkpoint from W&B artifacts (or S3 if from_s3_bucket
    is provided) and uploads it as a W&B artifact for the destination model.

    Note: This uploads the artifact directly to W&B. The ServerlessBackend's
    checkpoint tracking may not immediately reflect the forked checkpoint
    until the next training step.

    Args:
        model: The destination model to fork to.
        from_model: The name of the source model to fork from.
        from_project: The project of the source model. Defaults to model.project.
        from_s3_bucket: Optional S3 bucket to pull the checkpoint from.
        not_after_step: If provided, uses the latest checkpoint <= this step.
        verbose: Whether to print verbose output.
        prefix: Optional S3 prefix for bucket operations.
    """
    import os
    import tempfile

    import wandb

    from_project = from_project or model.project

    if from_s3_bucket is not None:
        # Pull from S3
        from art.utils.s3 import build_s3_path, ensure_bucket_exists, s3_sync
        from art.utils.s3_checkpoint_utils import (
            get_checkpoint_step_not_after_from_s3,
            get_latest_checkpoint_step_from_s3,
        )

        # Resolve which step to fork: the latest one overall, or the latest
        # not exceeding `not_after_step`.
        if not_after_step is None:
            target_step = await get_latest_checkpoint_step_from_s3(
                model_name=from_model,
                project=from_project,
                s3_bucket=from_s3_bucket,
                prefix=prefix,
            )
        else:
            target_step = await get_checkpoint_step_not_after_from_s3(
                model_name=from_model,
                project=from_project,
                not_after_step=not_after_step,
                s3_bucket=from_s3_bucket,
                prefix=prefix,
            )

        if target_step is None:
            raise ValueError(
                f"No suitable checkpoint found in S3 for model {from_model}"
            )

        if verbose:
            print(f"Pulling checkpoint step {target_step} from S3...")

        # Stage the checkpoint under the system temp dir, keyed by
        # project/model/step so repeated forks reuse a stable path.
        checkpoint_dir = os.path.join(
            tempfile.gettempdir(),
            "art_fork_checkpoints",
            from_project,
            from_model,
            f"{target_step:04d}",
        )
        os.makedirs(checkpoint_dir, exist_ok=True)

        s3_path = build_s3_path(
            model_name=from_model,
            project=from_project,
            step=target_step,
            s3_bucket=from_s3_bucket,
            prefix=prefix,
        )
        await ensure_bucket_exists(from_s3_bucket)
        await s3_sync(s3_path, checkpoint_dir, verbose=verbose)
        selected_step = target_step
    else:
        # Pull from W&B artifacts
        api = wandb.Api(api_key=self._client.api_key)  # ty:ignore[possibly-missing-attribute]
        from_entity = model.entity or api.default_entity

        # Iterate all artifact versions to find the best step.
        # We avoid relying on the W&B `:latest` alias because it
        # may not correspond to the highest training step.
        collection_path = f"{from_entity}/{from_project}/{from_model}"
        versions = api.artifacts("lora", collection_path)

        best_step: int | None = None
        best_artifact = None
        for version in versions:
            step_num = _extract_step_from_wandb_artifact(version)
            if step_num is None:
                continue
            if not_after_step is not None and step_num > not_after_step:
                continue
            if best_step is None or step_num > best_step:
                best_step = step_num
                best_artifact = version

        if best_step is None or best_artifact is None:
            if not_after_step is not None:
                raise ValueError(
                    f"No checkpoints found not after step {not_after_step} "
                    f"for model {from_model}"
                )
            raise ValueError(f"No checkpoints found for model {from_model}")
        selected_step = best_step
        artifact = best_artifact

        checkpoint_dir = os.path.join(
            tempfile.gettempdir(),
            "art_fork_checkpoints",
            from_project,
            from_model,
            f"{selected_step:04d}" if selected_step is not None else "latest",
        )
        os.makedirs(checkpoint_dir, exist_ok=True)
        artifact.download(root=checkpoint_dir)

        if verbose:
            print(f"Downloaded source checkpoint step {selected_step} from W&B")

    # Upload as W&B artifact for the destination model
    assert model.entity is not None, "Model entity is required"

    if verbose:
        print(f"Uploading forked checkpoint as W&B artifact for {model.name}...")

    wandb.login(key=self._client.api_key)  # ty:ignore[possibly-missing-attribute]
    # A short-lived run is needed so the artifact upload is attributed
    # to the destination model's project/entity.
    run = wandb.init(
        project=model.project,
        entity=model.entity,
        job_type="checkpoint-fork",
        name=f"fork-{from_model}-to-{model.name}",
        settings=wandb.Settings(silent=True),
    )
    assert run is not None

    dest_artifact = wandb.Artifact(name=model.name, type="lora")
    dest_artifact.add_dir(checkpoint_dir)
    # "step<N>" goes first so downstream alias scans find it before "latest".
    aliases = ["latest"]
    if selected_step is not None:
        aliases.insert(0, f"step{selected_step}")
    run.log_artifact(dest_artifact, aliases=aliases)
    run.finish()

    if verbose:
        print(
            f"Successfully forked checkpoint from {from_model} "
            f"(step {selected_step}) to {model.name}"
        )

src/art/tinker_native/backend.py

Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -778,3 +778,140 @@ def _persist_model_state(self, model: TrainableModel, state: ModelState) -> None
778778
STATE_KEY_LATEST_STEP: state.current_step,
779779
}
780780
)
781+
782+
async def _experimental_fork_checkpoint(
    self,
    model: Model,
    from_model: str,
    from_project: str | None = None,
    from_s3_bucket: str | None = None,
    not_after_step: int | None = None,
    verbose: bool = False,
    prefix: str | None = None,
) -> None:
    """Fork a checkpoint from another TinkerNative model to initialize this model.

    Loads the source model's training checkpoint into the destination model's
    training client directly via tinker:// paths. No local download needed.

    Args:
        model: The destination model to fork to (must already be registered).
        from_model: The name of the source model to fork from.
        from_project: The project of the source model. Defaults to model.project.
        from_s3_bucket: Not supported for TinkerNativeBackend.
        not_after_step: If provided, uses the latest checkpoint <= this step.
        verbose: Whether to print verbose output.
        prefix: Not applicable for TinkerNativeBackend.

    Raises:
        NotImplementedError: If ``from_s3_bucket`` is provided.
        RuntimeError: If the destination model has not been registered.
        FileNotFoundError: If the source model's state.json does not exist.
        ValueError: If the source model has no tinker run IDs, no training
            checkpoints, or no checkpoint at or before ``not_after_step``.
    """
    # Hoisted from mid-function: imports belong at the top of their scope.
    import json

    if from_s3_bucket is not None:
        raise NotImplementedError(
            "from_s3_bucket is not supported for TinkerNativeBackend. "
            "Tinker checkpoints are stored on Tinker infrastructure, not S3."
        )

    trainable_model = cast(TrainableModel, model)

    if trainable_model.name not in self._model_state:
        raise RuntimeError(
            f"Model '{trainable_model.name}' is not registered. "
            "Call register() before forking."
        )

    from_project = from_project or model.project

    # Read the source model's state.json to get its tinker_run_ids
    source_state_dir = get_model_dir(
        Model(name=from_model, project=from_project),
        art_path=self._path,
    )
    source_state_path = f"{source_state_dir}/state.json"

    if not os.path.exists(source_state_path):
        raise FileNotFoundError(
            f"Source model state not found at {source_state_path}. "
            f"Ensure the source model '{from_model}' has been trained "
            f"with this backend."
        )
    with open(source_state_path, "r") as f:
        source_state = json.load(f)

    source_run_ids = list(source_state.get(STATE_KEY_RUN_IDS, []))
    if not source_run_ids:
        raise ValueError(
            f"Source model '{from_model}' has no tinker run IDs in its state."
        )

    # List source model's checkpoints. Sampler checkpoint paths are not
    # needed for forking, so the second element is discarded.
    dest_state = self._model_state[trainable_model.name]
    training_paths, _ = await self._list_checkpoints(
        dest_state.rest_client, source_run_ids
    )

    if not training_paths:
        raise ValueError(
            f"No training checkpoints found for source model '{from_model}'."
        )

    # Select the target step: the latest overall, or the latest at or
    # before `not_after_step` when that bound is given.
    available_steps = sorted(training_paths.keys())
    if not_after_step is not None:
        eligible_steps = [s for s in available_steps if s <= not_after_step]
        if not eligible_steps:
            raise ValueError(
                f"No checkpoint found at or before step {not_after_step}. "
                f"Available steps: {available_steps}"
            )
        target_step = max(eligible_steps)
    else:
        target_step = max(available_steps)

    source_checkpoint_path = training_paths[target_step]
    if verbose:
        print(
            f"Forking from '{from_model}' step {target_step} "
            f"(checkpoint: {source_checkpoint_path})"
        )

    # Load the source checkpoint into a new training client. The optimizer
    # state is reset so the fork starts training fresh from the weights.
    config = self._resolve_model_config(trainable_model)
    new_training_client = await self._create_training_client_from_checkpoint(
        service_client=dest_state.service_client,
        checkpoint_state_path=source_checkpoint_path,
        base_model=trainable_model.base_model,
        training_client_args=config.training_client_args,
        reset_optimizer=True,
    )

    # Save new sampler weights
    checkpoint_name = f"step_{target_step:06d}"
    sampler_response = await self._save_sampler_weights(
        new_training_client, checkpoint_name
    )

    # Create a sampler client from the new weights
    sampler_client = await self._tinker_train_call(
        "create_sampling_client_async",
        new_training_client.create_sampling_client_async(
            model_path=sampler_response.path
        ),
    )

    # Update the destination model's state so subsequent training and
    # sampling resume from the forked step.
    new_run_id = new_training_client.model_id
    if new_run_id not in dest_state.tinker_run_ids:
        dest_state.tinker_run_ids.append(new_run_id)

    dest_state.training_client = new_training_client
    dest_state.current_step = target_step
    dest_state.sampler_clients[target_step] = sampler_client
    dest_state.sampler_checkpoint_paths[target_step] = sampler_response.path
    dest_state.training_checkpoint_paths[target_step] = source_checkpoint_path

    self._persist_model_state(trainable_model, dest_state)

    if verbose:
        print(
            f"Fork complete. Model '{trainable_model.name}' is now at "
            f"step {target_step}."
        )

0 commit comments

Comments (0)