1+ import asyncio
12from datetime import datetime
23import json
34import os
1112from typing_extensions import Never , TypeVar
1213
1314from . import dev
15+ from .costs import CostCalculator
1416from .trajectories import Trajectory , TrajectoryGroup
1517from .types import TrainConfig
1618from .utils .old_benchmarking .calculate_step_metrics import calculate_step_std_dev
2527ModelConfig = TypeVar ("ModelConfig" , bound = BaseModel | None )
2628StateType = TypeVar ("StateType" , bound = dict [str , Any ], default = dict [str , Any ])
2729
# Key under which cumulative cost data is stored in the persisted model state.
COSTS_STATE_KEY = "_costs"
# Metrics whose name starts with this prefix are treated as cost components
# and accumulated instead of being logged as ordinary metrics.
COSTS_METRIC_PREFIX = "costs_"
# Reserved metric name ("costs_total"); callers that log it directly are
# rejected. No whitespace may appear between the prefix and "total", otherwise
# the equality check against user-supplied metric names can never match.
COSTS_TOTAL_KEY = f"{COSTS_METRIC_PREFIX}total"
33+
2834
2935class Model (
3036 BaseModel ,
@@ -87,6 +93,8 @@ class Model(
8793 _s3_prefix : str | None = None
8894 _openai_client : AsyncOpenAI | None = None
8995 _wandb_run : Optional ["Run" ] = None # Private, for lazy wandb initialization
96+ _costs_lock : asyncio .Lock
97+ _cost_calculator : CostCalculator
9098
9199 def __init__ (
92100 self ,
@@ -374,6 +382,7 @@ def _get_wandb_run(self) -> Optional["Run"]:
374382 wandb .define_metric ("training_step" )
375383 wandb .define_metric ("train/*" , step_metric = "training_step" )
376384 wandb .define_metric ("val/*" , step_metric = "training_step" )
385+ wandb .define_metric ("costs/*" , step_metric = "training_step" )
377386 return self ._wandb_run
378387
379388 def _log_metrics (
@@ -406,6 +415,64 @@ def _log_metrics(
406415 if run := self ._get_wandb_run ():
407416 run .log ({"training_step" : step , ** prefixed })
408417
async def _record_costs(
    self,
    split: str,
    step: int,
    *,
    cost_components: dict[str, float],
    cost_total_direct: float,
    cost_seen: bool,
) -> None:
    """Fold one step's costs into the persisted cumulative totals and log them.

    Args:
        split: Name of the split the costs belong to; used as a key and as
            the prefix of the per-component keys (``f"{split}_{component}"``).
        step: Step the costs are attributed to.
        cost_components: Per-component cost for this step. Zero-valued
            components are ignored.
        cost_total_direct: Fallback step total, used only when the component
            sum is not positive.
        cost_seen: Whether the caller encountered any cost metric at all.

    Returns:
        None. Does nothing when ``cost_seen`` is false or the step total is
        not positive. Otherwise the cumulative costs are written back with
        ``merge_state`` under ``COSTS_STATE_KEY`` and logged through
        ``_log_metrics`` with the ``"costs"`` split.
    """
    component_total = sum(cost_components.values())
    step_total = component_total if component_total > 0 else cost_total_direct
    if not cost_seen or step_total <= 0:
        return

    # The lock attribute is declared without a default, so an instance whose
    # constructor never assigned it would fail here. Create it on first use;
    # there is no await between the lookup and the assignment, so concurrent
    # tasks on the same event loop cannot end up with two different locks.
    try:
        costs_lock = self._costs_lock
    except AttributeError:
        costs_lock = asyncio.Lock()
        object.__setattr__(self, "_costs_lock", costs_lock)

    async with costs_lock:
        existing_state = self.read_state() or {}
        raw_costs = existing_state.get(COSTS_STATE_KEY) or {}
        # Only numeric entries are cumulative costs; this also filters out
        # the nested "_last_steps" bookkeeping dict.
        cumulative = {
            key: float(value)
            for key, value in raw_costs.items()
            if isinstance(value, (int, float))
        }
        last_steps = raw_costs.get("_last_steps")
        if not isinstance(last_steps, dict):
            last_steps = {}
        last_step = last_steps.get(split)

        # A step at or before the last one recorded for this split is treated
        # as a repeat: keep the larger value instead of adding again.
        is_repeat = isinstance(last_step, (int, float)) and int(last_step) >= step

        for component, value in cost_components.items():
            if value == 0:
                continue
            cumulative_key = f"{split}_{component}"
            previous = cumulative.get(cumulative_key, 0.0)
            cumulative[cumulative_key] = (
                max(previous, value) if is_repeat else previous + value
            )

        if is_repeat:
            cumulative[split] = max(cumulative.get(split, 0.0), step_total)
            cumulative["total"] = max(
                cumulative.get("total", 0.0), cumulative.get(split, 0.0)
            )
        else:
            cumulative[split] = cumulative.get(split, 0.0) + step_total
            cumulative["total"] = cumulative.get("total", 0.0) + step_total
            last_steps[split] = step

        self.merge_state(
            {COSTS_STATE_KEY: {**cumulative, "_last_steps": last_steps}}
        )
        self._log_metrics(cumulative, "costs", step)
475+
409476 async def log (
410477 self ,
411478 trajectories : (
@@ -439,7 +506,42 @@ async def log(
439506 # If only metrics provided (no trajectories), just log them and return
440507 if trajectories is None :
441508 if metrics is not None :
442- self ._log_metrics (metrics , split , step )
509+ cost_step = await self .get_step ()
510+ cost_components : dict [str , float ] = {}
511+ cost_total_direct = 0.0
512+ cost_seen = False
513+
514+ for metric , value in metrics .items ():
515+ if not isinstance (value , (int , float )):
516+ continue
517+ if metric == COSTS_TOTAL_KEY :
518+ raise ValueError (
519+ "Do not log 'costs_total' directly. Log costs_* components "
520+ "(e.g., costs_prefill, costs_sample) and totals are derived."
521+ )
522+ elif metric .startswith (COSTS_METRIC_PREFIX ):
523+ component = metric [len (COSTS_METRIC_PREFIX ) :]
524+ if component :
525+ cost_components [component ] = cost_components .get (
526+ component , 0.0
527+ ) + float (value )
528+ cost_seen = True
529+
530+ metrics_without_costs = {
531+ key : value
532+ for key , value in metrics .items ()
533+ if not key .startswith (COSTS_METRIC_PREFIX )
534+ }
535+ if metrics_without_costs :
536+ self ._log_metrics (metrics_without_costs , split , step )
537+
538+ await self ._record_costs (
539+ split ,
540+ cost_step ,
541+ cost_components = cost_components ,
542+ cost_total_direct = cost_total_direct ,
543+ cost_seen = cost_seen ,
544+ )
443545 return
444546
445547 # Convert to list[TrajectoryGroup]
@@ -465,13 +567,39 @@ async def log(
465567 trajectory_groups , f"{ trajectories_dir } /{ file_name } "
466568 )
467569
468- # 2. Calculate aggregate metrics
570+ # 2. Calculate aggregate metrics (excluding additive costs)
571+ cost_step = await self .get_step ()
469572 all_metrics : dict [str , list [float ]] = {"reward" : [], "exception_rate" : []}
470573 group_metrics : dict [str , list [float ]] = {}
574+ cost_components : dict [str , float ] = {}
575+ cost_total_direct = 0.0
576+ cost_seen = False
577+
578+ def _add_costs (metrics_dict : dict [str , float | int | bool ]) -> None :
579+ nonlocal cost_total_direct , cost_seen
580+ for metric , value in metrics_dict .items ():
581+ if not isinstance (value , (int , float )):
582+ continue
583+ if metric == COSTS_TOTAL_KEY :
584+ raise ValueError (
585+ "Do not log 'costs_total' directly. Log costs_* components "
586+ "(e.g., costs_prefill, costs_sample) and totals are derived."
587+ )
588+ elif metric .startswith (COSTS_METRIC_PREFIX ):
589+ component = metric [len (COSTS_METRIC_PREFIX ) :]
590+ if component :
591+ cost_components [component ] = cost_components .get (
592+ component , 0.0
593+ ) + float (value )
594+ cost_seen = True
471595
472596 for group in trajectory_groups :
597+ if group .metrics :
598+ _add_costs (group .metrics )
473599 if group .trajectories :
474600 for metric , value in group .metrics .items ():
601+ if metric .startswith (COSTS_METRIC_PREFIX ):
602+ continue
475603 if metric not in group_metrics :
476604 group_metrics [metric ] = []
477605 group_metrics [metric ].append (float (value ))
@@ -486,9 +614,13 @@ async def log(
486614
487615 # Collect other custom metrics
488616 for metric , value in trajectory .metrics .items ():
617+ if metric .startswith (COSTS_METRIC_PREFIX ):
618+ continue
489619 if metric not in all_metrics :
490620 all_metrics [metric ] = []
491621 all_metrics [metric ].append (float (value ))
622+ if trajectory .metrics :
623+ _add_costs (trajectory .metrics )
492624
493625 # Calculate averages for all metrics
494626 averages : dict [str , float ] = {}
@@ -506,11 +638,26 @@ async def log(
506638
507639 # Merge in any additional metrics passed directly
508640 if metrics is not None :
509- averages .update (metrics )
641+ _add_costs (metrics )
642+ metrics_without_costs = {
643+ key : value
644+ for key , value in metrics .items ()
645+ if not key .startswith (COSTS_METRIC_PREFIX )
646+ }
647+ averages .update (metrics_without_costs )
510648
511649 # 3. Log metrics (writes to history.jsonl and wandb)
512650 self ._log_metrics (averages , split , step )
513651
652+ # 4. Log cumulative costs (additive)
653+ await self ._record_costs (
654+ split ,
655+ cost_step ,
656+ cost_components = cost_components ,
657+ cost_total_direct = cost_total_direct ,
658+ cost_seen = cost_seen ,
659+ )
660+
514661 async def get_step (self ) -> int :
515662 """
516663 Get the model's current training step. For non-trainable models, returns 0.
@@ -559,6 +706,25 @@ def __init__(
559706 report_metrics = report_metrics ,
560707 ** kwargs ,
561708 )
709+ object .__setattr__ (self , "_costs_lock" , asyncio .Lock ())
710+ object .__setattr__ (self , "_cost_calculator" , self ._noop_cost_calculator )
711+
@property
def cost_calculator(self) -> CostCalculator:
    """Return the cost calculator currently stored on this instance."""
    active_calculator = self._cost_calculator
    return active_calculator
715+
def set_cost_calculator(self, calculator: CostCalculator | None) -> None:
    """Store ``calculator`` as the active cost calculator.

    Passing ``None`` stores ``self._noop_cost_calculator`` instead.
    """
    if calculator is None:
        replacement = self._noop_cost_calculator
    else:
        replacement = calculator
    # Written through object.__setattr__ rather than normal attribute
    # assignment, as is done for the other private attributes in this file.
    object.__setattr__(self, "_cost_calculator", replacement)
722+
723+ @staticmethod
724+ def _noop_cost_calculator (
725+ _prompt_tokens : int | None , _completion_tokens : int | None
726+ ) -> dict [str , float ]:
727+ return {}
562728 if _internal_config is not None :
563729 # Bypass BaseModel __setattr__ to allow setting private attr
564730 object .__setattr__ (self , "_internal_config" , _internal_config )
0 commit comments