Skip to content

Commit 524feb0

Browse files
authored
fix: resolve pre-existing CI failures (#578)
1 parent 6b9f5c7 commit 524feb0

2 files changed

Lines changed: 27 additions & 22 deletions

File tree

src/art/cli.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -245,9 +245,9 @@ def migrate(
245245
model_dir,
246246
delete_originals=not keep_jsonl,
247247
dry_run=dry_run,
248-
progress_callback=lambda f: typer.echo(f" {f}")
249-
if verbose
250-
else None,
248+
progress_callback=lambda f: (
249+
typer.echo(f" {f}") if verbose else None
250+
),
251251
)
252252
result = result + model_result
253253
else:

src/art/tinker/server.py

Lines changed: 24 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import os
55
import socket
66
import time
7-
from typing import Annotated
7+
from typing import Annotated, cast
88
import uuid
99

1010
from fastapi import FastAPI, HTTPException, Request
@@ -62,10 +62,13 @@ async def prompt_tokens(
6262
messages: list[ChatCompletionMessageParam],
6363
tools: list[ChatCompletionToolUnionParam] | None,
6464
) -> list[int]:
65-
return self._get_renderer(base_model).tokenizer.apply_chat_template(
66-
messages, # type: ignore
67-
tools=tools, # type: ignore
68-
add_generation_prompt=True,
65+
return cast(
66+
list[int],
67+
self._get_renderer(base_model).tokenizer.apply_chat_template(
68+
messages, # type: ignore
69+
tools=tools, # type: ignore
70+
add_generation_prompt=True,
71+
),
6972
)
7073

7174
async def chat_completion_and_token_discrepancies(
@@ -80,9 +83,9 @@ async def chat_completion_and_token_discrepancies(
8083
token_discrepancies: list[tuple[list[int], list[int]]] = []
8184
for i, sequence in enumerate(sample_response.sequences):
8285
assert sequence.logprobs is not None, "Logprobs are required"
83-
assert len(sequence.tokens) == len(
84-
sequence.logprobs
85-
), "Tokens and logprobs must have the same length"
86+
assert len(sequence.tokens) == len(sequence.logprobs), (
87+
"Tokens and logprobs must have the same length"
88+
)
8689
rendered_response_tokens = renderer.tokenizer.encode(
8790
renderer.tokenizer.decode(sequence.tokens)
8891
)
@@ -222,10 +225,11 @@ async def chat_completions(
222225
detail="Missing or invalid Authorization header",
223226
headers={"WWW-Authenticate": "Bearer"},
224227
)
225-
sampling_client, base_model = (
226-
await self._get_sampling_client_and_base_model(
227-
api_key, self.models.get(body["model"], body["model"])
228-
)
228+
(
229+
sampling_client,
230+
base_model,
231+
) = await self._get_sampling_client_and_base_model(
232+
api_key, self.models.get(body["model"], body["model"])
229233
)
230234
rendered_prompt_tokens = await worker.prompt_tokens(
231235
base_model=base_model,
@@ -265,13 +269,14 @@ async def chat_completions(
265269
else:
266270
detail = str(e)
267271
raise HTTPException(status_code=e.status_code, detail=detail) from e
268-
chat_completion, token_discrepancies = (
269-
await worker.chat_completion_and_token_discrepancies(
270-
base_model=base_model,
271-
sample_response=sample_response,
272-
model_name=body["model"],
273-
prompt_tokens=len(prompt_tokens),
274-
)
272+
(
273+
chat_completion,
274+
token_discrepancies,
275+
) = await worker.chat_completion_and_token_discrepancies(
276+
base_model=base_model,
277+
sample_response=sample_response,
278+
model_name=body["model"],
279+
prompt_tokens=len(prompt_tokens),
275280
)
276281
for rendered_response_tokens, raw_response_tokens in token_discrepancies:
277282
self._prefix_cache.insert(

0 commit comments

Comments (0)