Skip to content

Commit 0a3752e

Browse files
authored
Require model@step inference names and migrate call sites (#554)
1 parent e57ab66 commit 0a3752e

26 files changed

Lines changed: 143 additions & 106 deletions

dev/math-vista/math-vista.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@
128128
" }\n",
129129
" ]\n",
130130
" chat_completion = await client.chat.completions.create(\n",
131-
" model=model.name, messages=trajectory.messages()\n",
131+
" model=model.get_inference_name(), messages=trajectory.messages()\n",
132132
" )\n",
133133
" choice = chat_completion.choices[0]\n",
134134
" trajectory.messages_and_choices.append(choice)\n",

dev/math-vista/math-vista.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ async def rollout(scenario: Scenario) -> art.Trajectory:
6161
]
6262

6363
chat_completion = await client.chat.completions.create(
64-
model=model.name, messages=trajectory.messages()
64+
model=model.get_inference_name(), messages=trajectory.messages()
6565
)
6666
choice = chat_completion.choices[0]
6767
trajectory.messages_and_choices.append(choice)

dev/new_models/benchmark_inference.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,13 @@ async def main():
7777
iteration_start = time.perf_counter()
7878
# launch concurrent requests and time each individually
7979
tasks = [
80-
timed_request(client, model.name, prompt, max_tokens, temperature)
80+
timed_request(
81+
client,
82+
model.get_inference_name(),
83+
prompt,
84+
max_tokens,
85+
temperature,
86+
)
8187
for _ in range(concurrency)
8288
]
8389
# Wait for all responses

dev/new_models/gemma3.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ async def rollout(model: art.TrainableModel, prompt: str) -> art.Trajectory:
1919
client = model.openai_client()
2020
chat_completion = await client.chat.completions.create(
2121
messages=messages,
22-
model=model.name,
22+
model=model.get_inference_name(),
2323
max_tokens=100,
2424
timeout=100,
2525
)

dev/new_models/qwen3_try.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
" client = model.openai_client()\n",
3333
" chat_completion = await client.chat.completions.create(\n",
3434
" messages=messages,\n",
35-
" model=model.name,\n",
35+
" model=model.get_inference_name(),\n",
3636
" max_tokens=100,\n",
3737
" timeout=100,\n",
3838
" extra_body={\"chat_template_kwargs\": {\"enable_thinking\": False}},\n",

dev/new_models/qwen3_try.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ async def rollout(model: art.TrainableModel, prompt: str) -> art.Trajectory:
1919
client = model.openai_client()
2020
chat_completion = await client.chat.completions.create(
2121
messages=messages,
22-
model=model.name,
22+
model=model.get_inference_name(),
2323
max_tokens=100,
2424
timeout=100,
2525
extra_body={"chat_template_kwargs": {"enable_thinking": False}},

dev/yes-no-maybe-megatron.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,8 @@ async def main():
6464
train_groups = await art.gather_trajectory_groups(
6565
(
6666
art.TrajectoryGroup(
67-
rollout(openai_client, model.name, prompt) for _ in range(32)
67+
rollout(openai_client, model.get_inference_name(), prompt)
68+
for _ in range(32)
6869
)
6970
for prompt in prompts
7071
)

dev/yes-no-maybe-vision/train.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@
6060
" }\n",
6161
" ]\n",
6262
" chat_completion = await client.chat.completions.create(\n",
63-
" model=model.name, messages=messages, max_tokens=100, timeout=100\n",
63+
" model=model.get_inference_name(), messages=messages, max_tokens=100, timeout=100\n",
6464
" )\n",
6565
" choice = chat_completion.choices[0]\n",
6666
" content = choice.message.content\n",

dev/yes-no-maybe.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@
6565
" }\n",
6666
" ]\n",
6767
" chat_completion = await client.chat.completions.create(\n",
68-
" messages=messages, model=model.name, max_tokens=100, timeout=100\n",
68+
" messages=messages, model=model.get_inference_name(), max_tokens=100, timeout=100\n",
6969
" )\n",
7070
" choice = chat_completion.choices[0]\n",
7171
" content = choice.message.content\n",

dev/yes-no-maybe.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ async def rollout(client: openai.AsyncOpenAI, prompt: str) -> art.Trajectory:
1717
}
1818
]
1919
chat_completion = await client.chat.completions.create(
20-
messages=messages, model=model.name, max_tokens=100, timeout=100
20+
messages=messages, model=model.get_inference_name(), max_tokens=100, timeout=100
2121
)
2222
choice = chat_completion.choices[0]
2323
content = choice.message.content

0 commit comments

Comments
 (0)