-
Notifications
You must be signed in to change notification settings - Fork 804
Expand file tree
/
Copy pathtypes.py
More file actions
74 lines (52 loc) · 2.23 KB
/
types.py
File metadata and controls
74 lines (52 loc) · 2.23 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
from dataclasses import dataclass, field
from typing import Annotated, Literal
from openai.types.chat.chat_completion import Choice
from openai.types.chat.chat_completion_message_param import ChatCompletionMessageParam
from openai.types.chat.chat_completion_tool_param import ChatCompletionToolParam
import pydantic
from pydantic import SkipValidation
# A single chat message in OpenAI ``ChatCompletionMessageParam`` form.
# Annotated with ``SkipValidation`` so pydantic does not validate it.
Message = Annotated[ChatCompletionMessageParam, SkipValidation]
Messages = list[Message]

# A conversation entry that is either a request-side message or an OpenAI
# chat-completion ``Choice`` object.
MessageOrChoice = Message | Choice
MessagesAndChoices = list[MessageOrChoice]

# Tool definitions in OpenAI ``ChatCompletionToolParam`` form.
Tools = list[ChatCompletionToolParam]
class TrainConfig(pydantic.BaseModel):
    """Pydantic-validated configuration for a training call.

    Fields:
        learning_rate: Learning rate used for training (default ``5e-6``).
        kl_penalty_coef: Coefficient applied to a KL penalty term
            (default ``0.0``; NOTE(review): presumably 0.0 disables the
            penalty — confirm in the trainer).
        kl_penalty_source: Either ``"current_learner"`` (default) or
            ``"sample"``; the consumer of this config defines what each
            option means.
    """

    learning_rate: float = pydantic.Field(default=5e-6)
    kl_penalty_coef: float = pydantic.Field(default=0.0)
    kl_penalty_source: Literal["current_learner", "sample"] = pydantic.Field(
        default="current_learner"
    )
class TrainSFTConfig(pydantic.BaseModel):
    """Pydantic-validated configuration for a supervised fine-tuning call.

    Fields:
        learning_rate: Either one learning rate, or a list holding one
            learning rate per batch (default ``5e-5``).
        batch_size: An explicit integer batch size, or the string
            ``"auto"`` (default). NOTE(review): how ``"auto"`` is resolved
            is decided by the consumer of this config.
    """

    # Single value or per-batch list.
    learning_rate: float | list[float] = pydantic.Field(default=5e-5)
    batch_size: int | Literal["auto"] = pydantic.Field(default="auto")
# Permitted verbosity levels: exactly the integers 0, 1 and 2.
# NOTE(review): presumably larger values mean more output — confirm at call sites.
Verbosity = Literal[0, 1, 2]
# ---------------------------------------------------------------------------
# TrainResult classes
# ---------------------------------------------------------------------------
@dataclass
class TrainResult:
    """Common outcome of a ``backend.train()`` call.

    Backend-specific result classes subclass this and add extra fields.

    Attributes:
        step: Training step reached once this training call completed.
        metrics: Aggregated training metrics (e.g. loss, gradient norms).
            Each instance gets its own fresh empty dict by default.
    """

    step: int
    # default_factory avoids sharing one mutable dict across instances.
    metrics: dict[str, float] = field(default_factory=dict)
@dataclass
class LocalTrainResult(TrainResult):
    """Outcome of a ``LocalBackend.train()`` call.

    Attributes:
        step: Training step reached once this training call completed.
        metrics: Aggregated training metrics (e.g. loss, gradient norms).
        checkpoint_path: Directory of the checkpoint saved by this call, or
            ``None`` when no checkpoint was saved.
    """

    checkpoint_path: str | None = None
@dataclass
class ServerlessTrainResult(TrainResult):
    """Outcome of a ``ServerlessBackend.train()`` call.

    Attributes:
        step: Training step reached once this training call completed.
        metrics: Aggregated training metrics (e.g. loss, gradient norms).
        artifact_name: W&B artifact name of the checkpoint, for example
            ``"entity/project/model:step5"``; ``None`` by default.
    """

    artifact_name: str | None = None