Skip to content

Commit 97938fb

Browse files
authored
Move cost calculation to backend (#544)
1 parent 2ee5b56 commit 97938fb

8 files changed

Lines changed: 348 additions & 15 deletions

File tree

src/art/costs.py

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
"""Cost utilities for ART training and evaluation."""
2+
3+
from __future__ import annotations
4+
5+
from dataclasses import dataclass
6+
from typing import Callable, TypeAlias
7+
8+
9+
@dataclass(frozen=True)
10+
class ModelPricing:
11+
"""Per-million-token pricing for a model."""
12+
13+
prefill: float # $/1M tokens for prompt/prefill
14+
sample: float # $/1M tokens for sampling/generation
15+
train: float # $/1M tokens for training
16+
17+
18+
# Shared type aliases for the cost helpers below.
TokenCount: TypeAlias = int | None
CostCalculator: TypeAlias = Callable[[TokenCount, TokenCount], dict[str, float]]

# Pricing per model ($/1M tokens). Keep in sync with infra pricing.
MODEL_PRICING: dict[str, ModelPricing] = {
    # Qwen models
    "Qwen/Qwen3-4B-Instruct-2507": ModelPricing(prefill=0.07, sample=0.22, train=0.22),
    "Qwen/Qwen3-8B": ModelPricing(prefill=0.13, sample=0.40, train=0.40),
    "Qwen/Qwen3-8B-Base": ModelPricing(prefill=0.13, sample=0.40, train=0.40),
    "Qwen/Qwen3-30B-A3B": ModelPricing(prefill=0.12, sample=0.30, train=0.36),
    "Qwen/Qwen3-30B-A3B-Base": ModelPricing(prefill=0.12, sample=0.30, train=0.36),
    "Qwen/Qwen3-30B-A3B-Instruct-2507": ModelPricing(prefill=0.12, sample=0.30, train=0.36),
    "Qwen/Qwen3-32B": ModelPricing(prefill=0.49, sample=1.47, train=1.47),
    "Qwen/Qwen3-235B-A22B-Instruct-2507": ModelPricing(prefill=0.68, sample=1.70, train=2.04),
    "Qwen/Qwen3-VL-30B-A3B-Instruct": ModelPricing(prefill=0.18, sample=0.44, train=0.53),
    "Qwen/Qwen3-VL-235B-A22B-Instruct": ModelPricing(prefill=1.02, sample=2.56, train=3.07),
    # Meta Llama models
    "meta-llama/Llama-3.2-1B": ModelPricing(prefill=0.03, sample=0.09, train=0.09),
    "meta-llama/Llama-3.2-3B": ModelPricing(prefill=0.06, sample=0.18, train=0.18),
    "meta-llama/Llama-3.1-8B": ModelPricing(prefill=0.13, sample=0.40, train=0.40),
    "meta-llama/Llama-3.1-8B-Instruct": ModelPricing(prefill=0.13, sample=0.40, train=0.40),
    "meta-llama/Llama-3.1-70B": ModelPricing(prefill=1.05, sample=3.16, train=3.16),
    "meta-llama/Llama-3.3-70B-Instruct": ModelPricing(prefill=1.05, sample=3.16, train=3.16),
    # DeepSeek models
    "deepseek-ai/DeepSeek-V3.1": ModelPricing(prefill=1.13, sample=2.81, train=3.38),
    "deepseek-ai/DeepSeek-V3.1-Base": ModelPricing(prefill=1.13, sample=2.81, train=3.38),
    # OpenAI models
    "openai/gpt-oss-120b": ModelPricing(prefill=0.18, sample=0.44, train=0.52),
    "openai/gpt-oss-20b": ModelPricing(prefill=0.12, sample=0.30, train=0.36),
    # Moonshot models
    "moonshotai/Kimi-K2-Thinking": ModelPricing(prefill=0.98, sample=2.44, train=2.93),
}
64+
65+
66+
def get_model_pricing(
    model_name: str | None, *, strict: bool = False
) -> ModelPricing | None:
    """Look up per-token pricing for *model_name*.

    Returns ``None`` when no name is given or the model has no pricing
    entry; with ``strict=True`` an unknown model raises ``ValueError``
    instead of returning ``None``.
    """
    if model_name is None:
        return None
    entry = MODEL_PRICING.get(model_name)
    if entry is not None or not strict:
        return entry
    raise ValueError(
        f"No pricing configured for model '{model_name}'. "
        f"Add pricing to art.costs.MODEL_PRICING. "
        f"Available models: {list(MODEL_PRICING.keys())}"
    )
80+
81+
82+
def tokens_to_cost(num_tokens: float, price_per_million: float) -> float:
    """Return the dollar cost of *num_tokens* at *price_per_million* ($/1M)."""
    # Multiply first, then scale — same evaluation order as a single
    # expression, so results are bit-identical.
    total = float(num_tokens) * price_per_million
    return total / 1_000_000
85+
86+
87+
def compute_sample_costs(
    *,
    prompt_tokens: int | None,
    completion_tokens: int | None,
    pricing: ModelPricing,
) -> dict[str, float]:
    """Compute prompt+completion dollar costs for a single API call.

    A ``None`` token count (e.g. usage not reported by the API) is
    treated as zero tokens.
    """
    # Inline $/1M conversion: tokens * rate / 1_000_000.
    return {
        "costs_prefill": float(prompt_tokens or 0) * pricing.prefill / 1_000_000,
        "costs_sample": float(completion_tokens or 0) * pricing.sample / 1_000_000,
    }
102+
103+
104+
def build_cost_calculator(pricing: ModelPricing) -> CostCalculator:
    """Bind *pricing* into a reusable ``(prompt, completion) -> costs`` callable."""

    def calculate(
        prompt_tokens: int | None, completion_tokens: int | None
    ) -> dict[str, float]:
        # Delegate to the shared helper so cost math lives in one place.
        return compute_sample_costs(
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            pricing=pricing,
        )

    return calculate
117+
118+
119+
def compute_train_cost(train_tokens: float, pricing: ModelPricing) -> float:
    """Return the dollar cost of training on *train_tokens* tokens."""
    # Same $/1M conversion as tokens_to_cost, applied to the train rate.
    return float(train_tokens) * pricing.train / 1_000_000

src/art/model.py

Lines changed: 169 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import asyncio
12
from datetime import datetime
23
import json
34
import os
@@ -11,6 +12,7 @@
1112
from typing_extensions import Never, TypeVar
1213

1314
from . import dev
15+
from .costs import CostCalculator
1416
from .trajectories import Trajectory, TrajectoryGroup
1517
from .types import TrainConfig
1618
from .utils.old_benchmarking.calculate_step_metrics import calculate_step_std_dev
@@ -25,6 +27,10 @@
2527
ModelConfig = TypeVar("ModelConfig", bound=BaseModel | None)
2628
StateType = TypeVar("StateType", bound=dict[str, Any], default=dict[str, Any])
2729

30+
COSTS_STATE_KEY = "_costs"
31+
COSTS_METRIC_PREFIX = "costs_"
32+
COSTS_TOTAL_KEY = f"{COSTS_METRIC_PREFIX}total"
33+
2834

2935
class Model(
3036
BaseModel,
@@ -87,6 +93,8 @@ class Model(
8793
_s3_prefix: str | None = None
8894
_openai_client: AsyncOpenAI | None = None
8995
_wandb_run: Optional["Run"] = None # Private, for lazy wandb initialization
96+
_costs_lock: asyncio.Lock
97+
_cost_calculator: CostCalculator
9098

9199
def __init__(
92100
self,
@@ -374,6 +382,7 @@ def _get_wandb_run(self) -> Optional["Run"]:
374382
wandb.define_metric("training_step")
375383
wandb.define_metric("train/*", step_metric="training_step")
376384
wandb.define_metric("val/*", step_metric="training_step")
385+
wandb.define_metric("costs/*", step_metric="training_step")
377386
return self._wandb_run
378387

379388
def _log_metrics(
@@ -406,6 +415,64 @@ def _log_metrics(
406415
if run := self._get_wandb_run():
407416
run.log({"training_step": step, **prefixed})
408417

418+
async def _record_costs(
419+
self,
420+
split: str,
421+
step: int,
422+
*,
423+
cost_components: dict[str, float],
424+
cost_total_direct: float,
425+
cost_seen: bool,
426+
) -> None:
427+
component_total = sum(cost_components.values())
428+
step_total = component_total if component_total > 0 else cost_total_direct
429+
if not cost_seen or step_total <= 0:
430+
return
431+
432+
async with self._costs_lock:
433+
existing_state = self.read_state() or {}
434+
raw_costs = existing_state.get(COSTS_STATE_KEY) or {}
435+
cumulative = {
436+
key: float(value)
437+
for key, value in raw_costs.items()
438+
if isinstance(value, (int, float))
439+
}
440+
last_steps = raw_costs.get("_last_steps")
441+
if not isinstance(last_steps, dict):
442+
last_steps = {}
443+
last_step = last_steps.get(split)
444+
445+
if isinstance(last_step, (int, float)) and int(last_step) >= step:
446+
for component, value in cost_components.items():
447+
if value == 0:
448+
continue
449+
cumulative_key = f"{split}_{component}"
450+
cumulative[cumulative_key] = max(
451+
cumulative.get(cumulative_key, 0.0), value
452+
)
453+
cumulative[split] = max(cumulative.get(split, 0.0), step_total)
454+
cumulative["total"] = max(
455+
cumulative.get("total", 0.0), cumulative.get(split, 0.0)
456+
)
457+
self.merge_state(
458+
{COSTS_STATE_KEY: {**cumulative, "_last_steps": last_steps}}
459+
)
460+
self._log_metrics(cumulative, "costs", step)
461+
return
462+
463+
for component, value in cost_components.items():
464+
if value == 0:
465+
continue
466+
cumulative_key = f"{split}_{component}"
467+
cumulative[cumulative_key] = cumulative.get(cumulative_key, 0.0) + value
468+
cumulative[split] = cumulative.get(split, 0.0) + step_total
469+
cumulative["total"] = cumulative.get("total", 0.0) + step_total
470+
last_steps[split] = step
471+
self.merge_state(
472+
{COSTS_STATE_KEY: {**cumulative, "_last_steps": last_steps}}
473+
)
474+
self._log_metrics(cumulative, "costs", step)
475+
409476
async def log(
410477
self,
411478
trajectories: (
@@ -439,7 +506,42 @@ async def log(
439506
# If only metrics provided (no trajectories), just log them and return
440507
if trajectories is None:
441508
if metrics is not None:
442-
self._log_metrics(metrics, split, step)
509+
cost_step = await self.get_step()
510+
cost_components: dict[str, float] = {}
511+
cost_total_direct = 0.0
512+
cost_seen = False
513+
514+
for metric, value in metrics.items():
515+
if not isinstance(value, (int, float)):
516+
continue
517+
if metric == COSTS_TOTAL_KEY:
518+
raise ValueError(
519+
"Do not log 'costs_total' directly. Log costs_* components "
520+
"(e.g., costs_prefill, costs_sample) and totals are derived."
521+
)
522+
elif metric.startswith(COSTS_METRIC_PREFIX):
523+
component = metric[len(COSTS_METRIC_PREFIX) :]
524+
if component:
525+
cost_components[component] = cost_components.get(
526+
component, 0.0
527+
) + float(value)
528+
cost_seen = True
529+
530+
metrics_without_costs = {
531+
key: value
532+
for key, value in metrics.items()
533+
if not key.startswith(COSTS_METRIC_PREFIX)
534+
}
535+
if metrics_without_costs:
536+
self._log_metrics(metrics_without_costs, split, step)
537+
538+
await self._record_costs(
539+
split,
540+
cost_step,
541+
cost_components=cost_components,
542+
cost_total_direct=cost_total_direct,
543+
cost_seen=cost_seen,
544+
)
443545
return
444546

445547
# Convert to list[TrajectoryGroup]
@@ -465,13 +567,39 @@ async def log(
465567
trajectory_groups, f"{trajectories_dir}/{file_name}"
466568
)
467569

468-
# 2. Calculate aggregate metrics
570+
# 2. Calculate aggregate metrics (excluding additive costs)
571+
cost_step = await self.get_step()
469572
all_metrics: dict[str, list[float]] = {"reward": [], "exception_rate": []}
470573
group_metrics: dict[str, list[float]] = {}
574+
cost_components: dict[str, float] = {}
575+
cost_total_direct = 0.0
576+
cost_seen = False
577+
578+
def _add_costs(metrics_dict: dict[str, float | int | bool]) -> None:
579+
nonlocal cost_total_direct, cost_seen
580+
for metric, value in metrics_dict.items():
581+
if not isinstance(value, (int, float)):
582+
continue
583+
if metric == COSTS_TOTAL_KEY:
584+
raise ValueError(
585+
"Do not log 'costs_total' directly. Log costs_* components "
586+
"(e.g., costs_prefill, costs_sample) and totals are derived."
587+
)
588+
elif metric.startswith(COSTS_METRIC_PREFIX):
589+
component = metric[len(COSTS_METRIC_PREFIX) :]
590+
if component:
591+
cost_components[component] = cost_components.get(
592+
component, 0.0
593+
) + float(value)
594+
cost_seen = True
471595

472596
for group in trajectory_groups:
597+
if group.metrics:
598+
_add_costs(group.metrics)
473599
if group.trajectories:
474600
for metric, value in group.metrics.items():
601+
if metric.startswith(COSTS_METRIC_PREFIX):
602+
continue
475603
if metric not in group_metrics:
476604
group_metrics[metric] = []
477605
group_metrics[metric].append(float(value))
@@ -486,9 +614,13 @@ async def log(
486614

487615
# Collect other custom metrics
488616
for metric, value in trajectory.metrics.items():
617+
if metric.startswith(COSTS_METRIC_PREFIX):
618+
continue
489619
if metric not in all_metrics:
490620
all_metrics[metric] = []
491621
all_metrics[metric].append(float(value))
622+
if trajectory.metrics:
623+
_add_costs(trajectory.metrics)
492624

493625
# Calculate averages for all metrics
494626
averages: dict[str, float] = {}
@@ -506,11 +638,26 @@ async def log(
506638

507639
# Merge in any additional metrics passed directly
508640
if metrics is not None:
509-
averages.update(metrics)
641+
_add_costs(metrics)
642+
metrics_without_costs = {
643+
key: value
644+
for key, value in metrics.items()
645+
if not key.startswith(COSTS_METRIC_PREFIX)
646+
}
647+
averages.update(metrics_without_costs)
510648

511649
# 3. Log metrics (writes to history.jsonl and wandb)
512650
self._log_metrics(averages, split, step)
513651

652+
# 4. Log cumulative costs (additive)
653+
await self._record_costs(
654+
split,
655+
cost_step,
656+
cost_components=cost_components,
657+
cost_total_direct=cost_total_direct,
658+
cost_seen=cost_seen,
659+
)
660+
514661
async def get_step(self) -> int:
515662
"""
516663
Get the model's current training step. For non-trainable models, returns 0.
@@ -559,6 +706,25 @@ def __init__(
559706
report_metrics=report_metrics,
560707
**kwargs,
561708
)
709+
object.__setattr__(self, "_costs_lock", asyncio.Lock())
710+
object.__setattr__(self, "_cost_calculator", self._noop_cost_calculator)
711+
712+
@property
713+
def cost_calculator(self) -> CostCalculator:
714+
return self._cost_calculator
715+
716+
def set_cost_calculator(self, calculator: CostCalculator | None) -> None:
717+
object.__setattr__(
718+
self,
719+
"_cost_calculator",
720+
calculator if calculator is not None else self._noop_cost_calculator,
721+
)
722+
723+
@staticmethod
724+
def _noop_cost_calculator(
725+
_prompt_tokens: int | None, _completion_tokens: int | None
726+
) -> dict[str, float]:
727+
return {}
562728
if _internal_config is not None:
563729
# Bypass BaseModel __setattr__ to allow setting private attr
564730
object.__setattr__(self, "_internal_config", _internal_config)

0 commit comments

Comments
 (0)