Skip to content

Commit 66c4a3e

Browse files
committed
feat: Enhance TinkerBackend with training client arguments
- Added TinkerTrainingClientArgs to the TinkerBackend service configuration. - Updated the model service initialization to include training client arguments for improved training management.
1 parent e164f73 commit 66c4a3e

1 file changed

Lines changed: 5 additions & 2 deletions

File tree

src/art/tinker/backend.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ def __init__(
2626

2727
async def _get_service(self, model: TrainableModel) -> ModelService:
2828
from ..dev.get_model_config import get_model_config
29-
from ..dev.model import TinkerArgs
29+
from ..dev.model import TinkerArgs, TinkerTrainingClientArgs
3030
from .service import TinkerService
3131

3232
if model.name not in self._services:
@@ -36,7 +36,10 @@ async def _get_service(self, model: TrainableModel) -> ModelService:
3636
config=model._internal_config,
3737
)
3838
config["tinker_args"] = config.get("tinker_args") or TinkerArgs(
39-
renderer_name=get_renderer_name(model.base_model)
39+
renderer_name=get_renderer_name(model.base_model),
40+
training_client_args=TinkerTrainingClientArgs(
41+
rank=8,
42+
),
4043
)
4144
self._services[model.name] = TinkerService(
4245
model_name=model.name,

0 commit comments

Comments
 (0)