Skip to content

Commit d48764f

Browse files
authored
fix: auto-log inference tinker costs from managed openai client (#620)
1 parent abecd93 commit d48764f

File tree

3 files changed

+247
-1
lines changed

3 files changed

+247
-1
lines changed

src/art/local/backend.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444

4545
from .. import dev
4646
from ..backend import AnyTrainableModel, Backend
47+
from ..costs import build_cost_calculator, get_model_pricing
4748
from ..metrics_taxonomy import (
4849
TRAIN_GRADIENT_STEPS_KEY,
4950
average_metric_samples,
@@ -206,6 +207,11 @@ async def register(
206207
# (wandb initialization is now handled by the model's _get_wandb_run method)
207208
if model.trainable and "WANDB_API_KEY" in os.environ:
208209
_ = model._get_wandb_run()
210+
if model.trainable:
211+
trainable_model = cast(TrainableModel, model)
212+
pricing = get_model_pricing(trainable_model.base_model)
213+
if pricing is not None:
214+
trainable_model.set_cost_calculator(build_cost_calculator(pricing))
209215

210216
def _model_inference_name(self, model: Model, step: int | None = None) -> str:
211217
"""Return the inference name for a model checkpoint.

src/art/model.py

Lines changed: 87 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,47 @@
3636
StateType = TypeVar("StateType", bound=dict[str, Any], default=dict[str, Any])
3737

3838
METRICS_BUILDER_STATE_KEY = "_metrics_builder_state"
39+
40+
41+
class _OpenAIChatCompletionsProxy:
42+
def __init__(self, completions: Any, record_costs: Any) -> None:
43+
self._completions = completions
44+
self._record_costs = record_costs
45+
46+
async def create(self, *args: Any, **kwargs: Any) -> Any:
47+
response = await self._completions.create(*args, **kwargs)
48+
self._record_costs(response)
49+
return response
50+
51+
def __getattr__(self, name: str) -> Any:
52+
return getattr(self._completions, name)
53+
54+
55+
class _OpenAIChatProxy:
    """Delegating wrapper around an OpenAI ``chat`` resource.

    Exposes a cost-recording ``completions`` proxy while forwarding every
    other attribute to the wrapped ``chat`` object.
    """

    def __init__(self, chat: Any, record_costs: Any) -> None:
        self.completions = _OpenAIChatCompletionsProxy(chat.completions, record_costs)
        self._chat = chat

    def __getattr__(self, name: str) -> Any:
        # Fall through to the real ``chat`` resource for everything else.
        return getattr(self._chat, name)
62+
63+
64+
class _OpenAIClientProxy:
    """Delegating wrapper around an ``AsyncOpenAI`` client.

    Routes ``chat.completions.create`` through a cost-recording proxy and
    keeps that wrapping intact across ``with_options`` copies; every other
    attribute is forwarded to the wrapped client.
    """

    def __init__(self, client: Any, record_costs: Any) -> None:
        self._client = client
        self._record_costs = record_costs
        self.chat = _OpenAIChatProxy(client.chat, record_costs)

    def with_options(self, *args: Any, **kwargs: Any) -> "_OpenAIClientProxy":
        # Re-wrap the derived client so cost recording survives option copies.
        derived = self._client.with_options(*args, **kwargs)
        return _OpenAIClientProxy(derived, self._record_costs)

    def __getattr__(self, name: str) -> Any:
        return getattr(self._client, name)
78+
79+
3980
METRIC_SECTIONS = frozenset(
4081
{
4182
"reward",
@@ -233,6 +274,12 @@ async def register(self, backend: "Backend") -> None:
233274
def openai_client(
234275
self,
235276
) -> AsyncOpenAI:
277+
"""Return ART's managed inference client.
278+
279+
For trainable models with configured pricing, chat completion calls made
280+
through this client automatically emit Tinker inference costs when an
281+
ART metrics context is active.
282+
"""
236283
if self._openai_client is not None:
237284
return self._openai_client
238285

@@ -245,7 +292,7 @@ def openai_client(
245292
raise ValueError(
246293
"In order to create an OpenAI client you must provide an `inference_api_key` and `inference_base_url`."
247294
)
248-
self._openai_client = AsyncOpenAI(
295+
raw_client = AsyncOpenAI(
249296
base_url=self.inference_base_url,
250297
api_key=self.inference_api_key,
251298
http_client=DefaultAsyncHttpxClient(
@@ -255,6 +302,13 @@ def openai_client(
255302
),
256303
),
257304
)
305+
# Wrap the raw OpenAI client so ART-owned inference calls can add
306+
# split-scoped Tinker costs without rollout code needing to do it
307+
# manually.
308+
self._openai_client = cast(
309+
AsyncOpenAI,
310+
_OpenAIClientProxy(raw_client, self._record_openai_completion_costs),
311+
)
258312
return self._openai_client
259313

260314
def litellm_completion_params(self, step: int | None = None) -> dict:
@@ -304,6 +358,10 @@ def get_inference_name(self, step: int | None = None) -> str:
304358
return f"{base_name}@{step}"
305359
return base_name
306360

361+
def _record_openai_completion_costs(self, _response: Any) -> None:
362+
"""Hook for subclasses that want to auto-log managed inference costs."""
363+
return
364+
307365
def _get_output_dir(self) -> str:
308366
"""Get the output directory for this model."""
309367
return f"{self.base_path}/{self.project}/models/{self.name}"
@@ -946,6 +1004,34 @@ def _noop_cost_calculator(
9461004
) -> dict[str, float]:
9471005
return {}
9481006

1007+
def _record_openai_completion_costs(self, _response: Any) -> None:
    """Auto-log Tinker inference costs for one managed chat completion.

    Does nothing unless an ART metrics context is active and carries a
    non-empty cost context. Prompt tokens are billed once per returned
    choice (tests pin this as once-per-sample prefill), and only
    calculator entries under the ``costs/`` namespace are forwarded to
    the builder.
    """
    try:
        builder = MetricsBuilder.get_active()
    except LookupError:
        # No active metrics context: nowhere to attribute costs.
        return

    usage = getattr(_response, "usage", None)
    prompt_tokens = int(getattr(usage, "prompt_tokens", 0) or 0)
    completion_tokens = int(getattr(usage, "completion_tokens", 0) or 0)
    choices = getattr(_response, "choices", None) or []
    # Prompt tokens count once per sampled choice (and at least once).
    effective_prompt_tokens = prompt_tokens * max(len(choices), 1)

    scope = builder.cost_context.strip("/")
    if not scope:
        return

    computed = self._cost_calculator(
        effective_prompt_tokens,
        completion_tokens,
        scope,
    )
    if not computed:
        return

    prefix = "costs/"
    for metric_key, amount in computed.items():
        if metric_key.startswith(prefix):
            builder.add_cost(metric_key.removeprefix(prefix), float(amount))
1034+
9491035
@overload
9501036
def __new__(
9511037
cls,
Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,154 @@
1+
import importlib
2+
from typing import Any
3+
4+
import pytest
5+
6+
from art import TrainableModel
7+
from art.costs import build_cost_calculator, get_model_pricing
8+
9+
10+
class _FakeUsage:
11+
def __init__(self, prompt_tokens: int, completion_tokens: int) -> None:
12+
self.prompt_tokens = prompt_tokens
13+
self.completion_tokens = completion_tokens
14+
15+
16+
class _FakeResponse:
    """Minimal chat-completion response stand-in with usage and choices."""

    def __init__(
        self,
        prompt_tokens: int,
        completion_tokens: int,
        *,
        num_choices: int = 1,
    ) -> None:
        self.usage = _FakeUsage(prompt_tokens, completion_tokens)
        # Only the number of choices matters to cost accounting, so the
        # entries themselves are opaque placeholders.
        self.choices = [object() for _ in range(num_choices)]
26+
27+
28+
class _FakeCompletions:
    """Async ``chat.completions`` stub that always yields one canned response."""

    def __init__(self, response: _FakeResponse) -> None:
        self._response = response

    async def create(self, *args: Any, **kwargs: Any) -> _FakeResponse:
        # All arguments are accepted and ignored; the canned response wins.
        return self._response
34+
35+
36+
def _patch_async_openai(
    monkeypatch: pytest.MonkeyPatch, response: _FakeResponse
) -> None:
    """Replace ``art.model.AsyncOpenAI`` with a stub that returns *response*."""
    model_module = importlib.import_module("art.model")

    class _FakeAsyncOpenAI:
        def __init__(self, *args: Any, **kwargs: Any) -> None:
            chat_cls = type(
                "FakeChat",
                (),
                {"completions": _FakeCompletions(response)},
            )
            self.chat = chat_cls()

        def with_options(self, *args: Any, **kwargs: Any) -> "_FakeAsyncOpenAI":
            # Options are irrelevant for the stub; reuse the same instance.
            return self

    monkeypatch.setattr(model_module, "AsyncOpenAI", _FakeAsyncOpenAI)
53+
54+
55+
def _build_model() -> TrainableModel:
    """Create a trainable model wired with gpt-oss-20b Tinker pricing."""
    pricing = get_model_pricing("openai/gpt-oss-20b")
    assert pricing is not None

    model = TrainableModel(
        base_model="openai/gpt-oss-20b",
        name="test-run",
        project="test-project",
    )
    # The endpoint is never contacted: the OpenAI client is stubbed in tests.
    model.inference_api_key = "test-key"
    model.inference_base_url = "http://example.test/v1"
    model.set_cost_calculator(build_cost_calculator(pricing))
    return model
68+
69+
70+
class TestModelOpenAIClientCosts:
    """Cost auto-logging behavior of the managed OpenAI client."""

    @pytest.mark.asyncio
    async def test_openai_client_automatically_logs_train_tinker_costs(
        self,
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        _patch_async_openai(monkeypatch, _FakeResponse(1_000, 2_000))
        model = _build_model()
        builder = model.metrics_builder("train")
        client = model.openai_client()

        with builder.activate_context():
            await client.chat.completions.create(
                model=model.get_inference_name(),
                messages=[{"role": "user", "content": "hello"}],
            )

        flushed = await builder.flush()
        assert flushed["costs/train/tinker_prefill"] == pytest.approx(0.00012)
        assert flushed["costs/train/tinker_sample"] == pytest.approx(0.0006)
        assert flushed["costs/train"] == pytest.approx(0.00072)

    @pytest.mark.asyncio
    async def test_openai_client_automatically_logs_eval_tinker_costs(
        self,
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        _patch_async_openai(monkeypatch, _FakeResponse(500, 250))
        model = _build_model()
        builder = model.metrics_builder("eval")
        client = model.openai_client()

        with builder.activate_context():
            await client.chat.completions.create(
                model=model.get_inference_name(),
                messages=[{"role": "user", "content": "hello"}],
            )

        flushed = await builder.flush()
        assert flushed["costs/eval/tinker_prefill"] == pytest.approx(0.00006)
        assert flushed["costs/eval/tinker_sample"] == pytest.approx(0.000075)
        assert flushed["costs/eval"] == pytest.approx(0.000135)

    @pytest.mark.asyncio
    async def test_openai_client_does_not_log_costs_without_active_metrics_context(
        self,
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        _patch_async_openai(monkeypatch, _FakeResponse(1_000, 2_000))
        model = _build_model()
        builder = model.metrics_builder("train")

        # No ``activate_context`` here: the completion must emit nothing.
        await model.openai_client().chat.completions.create(
            model=model.get_inference_name(),
            messages=[{"role": "user", "content": "hello"}],
        )

        assert await builder.flush() == {}

    @pytest.mark.asyncio
    async def test_multiple_choices_scale_prefill_cost_once_per_sample(
        self,
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        _patch_async_openai(monkeypatch, _FakeResponse(1_000, 2_000, num_choices=3))
        model = _build_model()
        builder = model.metrics_builder("train")

        with builder.activate_context():
            await model.openai_client().chat.completions.create(
                model=model.get_inference_name(),
                messages=[{"role": "user", "content": "hello"}],
                n=3,
            )

        flushed = await builder.flush()
        # Prefill scales with the number of choices; sampling does not.
        assert flushed["costs/train/tinker_prefill"] == pytest.approx(0.00036)
        assert flushed["costs/train/tinker_sample"] == pytest.approx(0.0006)

    def test_manual_cost_calculator_still_returns_tinker_metrics(self) -> None:
        model = _build_model()

        result = model.cost_calculator(1_000, 2_000, "train")

        assert result["costs/train/tinker_prefill"] == pytest.approx(0.00012)
        assert result["costs/train/tinker_sample"] == pytest.approx(0.0006)

0 commit comments

Comments
 (0)