Skip to content

Commit 2bdab06

Browse files
committed
refactor: Simplify training client arguments handling in TinkerBackend and TinkerService
- Removed the TinkerTrainingClientArgs import and changed the configuration to default the training client arguments to an empty dictionary.
- Ensured that the default values for 'rank' and 'train_unembed' are set within TinkerService when it creates the training client.
1 parent 0e567fb commit 2bdab06

2 files changed

Lines changed: 9 additions & 6 deletions

File tree

src/art/tinker/backend.py

Lines changed: 3 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -26,7 +26,7 @@ def __init__(
2626

2727
async def _get_service(self, model: TrainableModel) -> ModelService:
2828
from ..dev.get_model_config import get_model_config
29-
from ..dev.model import TinkerArgs, TinkerTrainingClientArgs
29+
from ..dev.model import TinkerArgs
3030
from .service import TinkerService
3131

3232
if model.name not in self._services:
@@ -38,10 +38,8 @@ async def _get_service(self, model: TrainableModel) -> ModelService:
3838
config["tinker_args"] = config.get("tinker_args") or TinkerArgs(
3939
renderer_name=get_renderer_name(model.base_model)
4040
)
41-
config["tinker_args"]["training_client_args"] = config["tinker_args"].get(
42-
"training_client_args"
43-
) or TinkerTrainingClientArgs(
44-
rank=8,
41+
config["tinker_args"]["training_client_args"] = (
42+
config["tinker_args"].get("training_client_args") or {}
4543
)
4644
self._services[model.name] = TinkerService(
4745
model_name=model.name,

src/art/tinker/service.py

Lines changed: 6 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -213,10 +213,15 @@ async def _get_state(self) -> "TinkerState":
213213
)
214214
else:
215215
with log_timing("Creating Tinker training client"):
216+
training_client_args = config.get("training_client_args", {})
217+
if "rank" not in training_client_args:
218+
training_client_args["rank"] = 8
219+
if "train_unembed" not in training_client_args:
220+
training_client_args["train_unembed"] = False
216221
training_client = (
217222
await service_client.create_lora_training_client_async(
218223
base_model=self.base_model,
219-
**config.get("training_client_args", {}),
224+
**training_client_args,
220225
)
221226
)
222227
sampler_client = await self._save_checkpoint(

0 commit comments

Comments
 (0)