Skip to content

Commit 32ff748

Browse files
committed
fix: Refactor temporal clue example to enhance S3 backup handling and streamline training loop
- Updated the `temporal-clue.py` script to check for the `BACKUP_BUCKET` environment variable before pulling from and pushing to S3, improving robustness.
- Simplified the training loop structure for better readability and maintainability.
- Adjusted the `LocalBackend` class to handle potential import errors gracefully when deleting checkpoints.

This change enhances the overall functionality and error handling of the training process.
1 parent 5207db0 commit 32ff748

2 files changed

Lines changed: 48 additions & 32 deletions

File tree

examples/temporal_clue/temporal-clue.py

Lines changed: 38 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -60,35 +60,47 @@ async def main():
6060
base_model="Qwen/Qwen2.5-7B-Instruct",
6161
_internal_config={"init_args": {"gpu_memory_utilization": 0.775}},
6262
)
63-
backend = LocalBackend()
64-
await backend._experimental_pull_from_s3(model)
65-
await model.register(backend)
63+
with LocalBackend() as backend:
64+
if "BACKUP_BUCKET" in os.environ:
65+
await backend._experimental_pull_from_s3(
66+
model,
67+
s3_bucket=os.environ["BACKUP_BUCKET"],
68+
verbose=True,
69+
)
70+
else:
71+
print("BACKUP_BUCKET not found in environment variables")
72+
await model.register(backend)
6673

67-
stride = 4
68-
for i in range(await model.get_step(), 1_000):
69-
val_groups, train_groups = await asyncio.gather(
70-
art.gather_trajectory_groups(
71-
(
72-
art.TrajectoryGroup(rollout(model, puzzle) for _ in range(2))
73-
for puzzle in val_puzzles
74+
stride = 4
75+
for i in range(await model.get_step(), 1_000):
76+
val_groups, train_groups = await asyncio.gather(
77+
art.gather_trajectory_groups(
78+
(
79+
art.TrajectoryGroup(rollout(model, puzzle) for _ in range(2))
80+
for puzzle in val_puzzles
81+
),
82+
pbar_desc="val",
7483
),
75-
pbar_desc="val",
76-
),
77-
art.gather_trajectory_groups(
78-
(
79-
art.TrajectoryGroup(rollout(model, puzzle) for _ in range(50))
80-
for puzzle in train_puzzles[i * stride : (i + 1) * stride]
84+
art.gather_trajectory_groups(
85+
(
86+
art.TrajectoryGroup(rollout(model, puzzle) for _ in range(50))
87+
for puzzle in train_puzzles[i * stride : (i + 1) * stride]
88+
),
89+
pbar_desc="train",
8190
),
82-
pbar_desc="train",
83-
),
84-
)
85-
await model.log(val_groups)
86-
await model.delete_checkpoints()
87-
await backend._experimental_push_to_s3(model)
88-
result = await backend.train(model, train_groups, learning_rate=5e-5)
89-
await model.log(
90-
train_groups, metrics=result.metrics, step=result.step, split="train"
91-
)
91+
)
92+
await model.log(val_groups)
93+
await model.delete_checkpoints()
94+
if "BACKUP_BUCKET" in os.environ:
95+
await backend._experimental_push_to_s3(
96+
model,
97+
s3_bucket=os.environ["BACKUP_BUCKET"],
98+
verbose=True,
99+
)
100+
result = await backend.train(model, train_groups, learning_rate=5e-5)
101+
await model.log(
102+
train_groups, metrics=result.metrics, step=result.step, split="train"
103+
)
92104

93105

94106
if __name__ == "__main__":

src/art/local/backend.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ def _model_inference_name(self, model: Model, step: int | None = None) -> str:
148148
# Default to step 0 when not specified (the initial checkpoint created at registration)
149149
if step is not None:
150150
actual_step = step
151-
elif model.name in self._services:
151+
elif model.name in self._services and self._in_process:
152152
# In dedicated mode the service tracks which adapter vLLM has
153153
# actually loaded. Reading the filesystem would race: the
154154
# checkpoint directory appears before the HTTP reload completes.
@@ -294,14 +294,18 @@ async def _delete_checkpoint_files(
294294
steps_to_keep: list[int],
295295
) -> None:
296296
"""Delete checkpoint files, keeping only the specified steps."""
297-
from ..tinker.service import TinkerService
298297

299298
output_dir = get_model_dir(model=model, art_path=self._path)
300299
service = await self._get_service(model)
301-
if isinstance(service, TinkerService):
302-
await service.delete_checkpoints(steps_to_keep)
303-
else:
304-
delete_checkpoints(output_dir, steps_to_keep)
300+
try:
301+
from ..tinker.service import TinkerService
302+
303+
if isinstance(service, TinkerService):
304+
await service.delete_checkpoints(steps_to_keep)
305+
return
306+
except ImportError:
307+
pass
308+
delete_checkpoints(output_dir, steps_to_keep)
305309

306310
async def _prepare_backend_for_training(
307311
self,

0 commit comments

Comments (0)