Skip to content

Commit 59778f0

Browse files
committed
fix: update step handling in LocalBackend and improve error handling
- Changed the default step assignment in get_inference_name() to use a method for better accuracy. - Enhanced error handling when importing UnslothService to prevent crashes if the module is not found. - Refactored assertion for clarity in gradient step validation.
1 parent 8d18989 commit 59778f0

1 file changed

Lines changed: 15 additions & 9 deletions

File tree

src/art/local/backend.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -128,9 +128,10 @@ def _model_inference_name(self, model: Model, step: int | None = None) -> str:
128128
step: If provided, returns name for specific checkpoint.
129129
If None, returns name for latest checkpoint (step 0 initially).
130130
"""
131+
131132
# For LocalBackend, vLLM always serves LoRA adapters with @step suffix
132133
# Default to step 0 when not specified (the initial checkpoint created at registration)
133-
actual_step = step if step is not None else 0
134+
actual_step = step if step is not None else self.__get_step(model)
134135
return f"{model.name}@{actual_step}"
135136

136137
async def _get_service(self, model: TrainableModel) -> ModelService:
@@ -573,12 +574,17 @@ async def _train_model(
573574
f"Advanced step from {current_step} to {next_step} (no training occurred)"
574575
)
575576

576-
# Register the renamed checkpoint as a new LoRA adapter
577-
# so it's available for inference at the new step
578-
from ..unsloth.service import UnslothService
577+
try:
578+
# Register the renamed checkpoint as a new LoRA adapter
579+
# so it's available for inference at the new step
580+
from ..unsloth.service import UnslothService
579581

580-
if isinstance(service, UnslothService):
581-
await service.register_lora_for_step(next_step, next_checkpoint_dir)
582+
if isinstance(service, UnslothService):
583+
await service.register_lora_for_step(
584+
next_step, next_checkpoint_dir
585+
)
586+
except ModuleNotFoundError:
587+
pass # Unsloth is not installed
582588

583589
# Yield metrics showing no groups were trainable
584590
# (the frontend will handle logging)
@@ -601,9 +607,9 @@ async def _train_model(
601607
num_gradient_steps = int(
602608
result.pop("num_gradient_steps", estimated_gradient_steps)
603609
)
604-
assert num_gradient_steps == estimated_gradient_steps, (
605-
f"num_gradient_steps {num_gradient_steps} != estimated_gradient_steps {estimated_gradient_steps}"
606-
)
610+
assert (
611+
num_gradient_steps == estimated_gradient_steps
612+
), f"num_gradient_steps {num_gradient_steps} != estimated_gradient_steps {estimated_gradient_steps}"
607613
results.append(result)
608614
yield {**result, "num_gradient_steps": num_gradient_steps}
609615
pbar.update(1)

0 commit comments

Comments
 (0)