44import os
55import socket
66import time
7- from typing import Annotated
7+ from typing import Annotated , cast
88import uuid
99
1010from fastapi import FastAPI , HTTPException , Request
@@ -62,10 +62,13 @@ async def prompt_tokens(
6262 messages : list [ChatCompletionMessageParam ],
6363 tools : list [ChatCompletionToolUnionParam ] | None ,
6464 ) -> list [int ]:
65- return self ._get_renderer (base_model ).tokenizer .apply_chat_template (
66- messages , # type: ignore
67- tools = tools , # type: ignore
68- add_generation_prompt = True ,
65+ return cast (
66+ list [int ],
67+ self ._get_renderer (base_model ).tokenizer .apply_chat_template (
68+ messages , # type: ignore
69+ tools = tools , # type: ignore
70+ add_generation_prompt = True ,
71+ ),
6972 )
7073
7174 async def chat_completion_and_token_discrepancies (
@@ -80,9 +83,9 @@ async def chat_completion_and_token_discrepancies(
8083 token_discrepancies : list [tuple [list [int ], list [int ]]] = []
8184 for i , sequence in enumerate (sample_response .sequences ):
8285 assert sequence .logprobs is not None , "Logprobs are required"
83- assert len (sequence .tokens ) == len (
84- sequence . logprobs
85- ), "Tokens and logprobs must have the same length"
86+ assert len (sequence .tokens ) == len (sequence . logprobs ), (
87+ "Tokens and logprobs must have the same length"
88+ )
8689 rendered_response_tokens = renderer .tokenizer .encode (
8790 renderer .tokenizer .decode (sequence .tokens )
8891 )
@@ -222,10 +225,11 @@ async def chat_completions(
222225 detail = "Missing or invalid Authorization header" ,
223226 headers = {"WWW-Authenticate" : "Bearer" },
224227 )
225- sampling_client , base_model = (
226- await self ._get_sampling_client_and_base_model (
227- api_key , self .models .get (body ["model" ], body ["model" ])
228- )
228+ (
229+ sampling_client ,
230+ base_model ,
231+ ) = await self ._get_sampling_client_and_base_model (
232+ api_key , self .models .get (body ["model" ], body ["model" ])
229233 )
230234 rendered_prompt_tokens = await worker .prompt_tokens (
231235 base_model = base_model ,
@@ -265,13 +269,14 @@ async def chat_completions(
265269 else :
266270 detail = str (e )
267271 raise HTTPException (status_code = e .status_code , detail = detail ) from e
268- chat_completion , token_discrepancies = (
269- await worker .chat_completion_and_token_discrepancies (
270- base_model = base_model ,
271- sample_response = sample_response ,
272- model_name = body ["model" ],
273- prompt_tokens = len (prompt_tokens ),
274- )
272+ (
273+ chat_completion ,
274+ token_discrepancies ,
275+ ) = await worker .chat_completion_and_token_discrepancies (
276+ base_model = base_model ,
277+ sample_response = sample_response ,
278+ model_name = body ["model" ],
279+ prompt_tokens = len (prompt_tokens ),
275280 )
276281 for rendered_response_tokens , raw_response_tokens in token_discrepancies :
277282 self ._prefix_cache .insert (
0 commit comments