Skip to content

Commit aed1ee7

Browse files
authored
Clean up unused adapters before saving checkpoint (#605)
Removes reference adapters (loaded for KL/logprob computation) from the PEFT model before saving, freeing GPU/CPU memory. Also adds a gc_and_empty_cuda_cache call after saving. Made-with: Cursor
1 parent 43a6ed0 commit aed1ee7

1 file changed

Lines changed: 29 additions & 0 deletions

File tree

src/art/unsloth/service.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,13 +151,42 @@ def save_checkpoint(
151151
verbose: bool = False,
152152
) -> str:
153153
"""Save a checkpoint and return the checkpoint directory path."""
154+
# _use_adapter() may load reference adapters for KL/logprob computation and
155+
# keep them attached to the PEFT model. Before saving, keep only active
156+
# adapter(s) and drop the rest to release GPU/CPU memory.
157+
try:
158+
peft_model = trainer.accelerator.unwrap_model( # type: ignore[attr-defined]
159+
trainer.model, keep_fp32_wrapper=False
160+
)
161+
active_adapters = peft_model.active_adapter
162+
if isinstance(active_adapters, str):
163+
keep_adapters = {active_adapters}
164+
else:
165+
keep_adapters = set(active_adapters)
166+
167+
before_adapters = list(peft_model.peft_config.keys())
168+
print(f"Adapters before cleanup: {before_adapters}")
169+
print(f"Keeping active adapter(s): {sorted(keep_adapters)}")
170+
171+
for adapter_name in before_adapters:
172+
if adapter_name not in keep_adapters:
173+
peft_model.delete_adapter(adapter_name)
174+
print(f"Deleted unused adapter: {adapter_name}")
175+
176+
after_adapters = list(peft_model.peft_config.keys())
177+
print(f"Adapters after cleanup: {after_adapters}")
178+
except Exception as e:
179+
print(f"Warning: failed to cleanup unused adapters: {e}")
180+
154181
if verbose:
155182
print("Saving new LoRA adapter...")
156183
next_step = get_step_from_dir(output_dir) + 1
157184
checkpoint_dir = get_step_checkpoint_dir(output_dir, next_step)
158185
os.makedirs(checkpoint_dir, exist_ok=True)
159186
trainer.save_model(checkpoint_dir)
160187
convert_checkpoint_if_needed(checkpoint_dir)
188+
189+
gc_and_empty_cuda_cache()
161190
return checkpoint_dir
162191

163192

0 commit comments

Comments (0)