
Commit 6c3be3f

refactor: Refactor token handling in LocalBackend and TokenizedResult
- Updated LocalBackend to use `token_ids` instead of `tokens` for max token calculation.
- Introduced a cached property in TokenizedResult to dynamically generate `tokens` from `token_ids`.
- Cleaned up assertion formatting in LocalBackend for better readability.

These changes enhance the clarity and efficiency of token management in the codebase.
1 parent: d69345e

2 files changed: 16 additions & 9 deletions


src/art/local/backend.py

Lines changed: 7 additions & 5 deletions
```diff
@@ -244,7 +244,7 @@ def _get_packed_tensors(
     )
     if not tokenized_results:
         return None
-    max_tokens = max(len(result.tokens) for result in tokenized_results)
+    max_tokens = max(len(result.token_ids) for result in tokenized_results)
     # Round up max_tokens to the nearest multiple of 2048
     sequence_length = math.ceil(max_tokens / 2048) * 2048
     # Cap sequence length at the model's max sequence length
@@ -416,7 +416,9 @@ def _trajectory_log(self, trajectory: Trajectory) -> str:
         if isinstance(message_or_choice, dict):
             message = message_or_choice
         else:
-            message = cast(Message, message_or_choice.message.model_dump())  # ty:ignore[possibly-missing-attribute]
+            message = cast(
+                Message, message_or_choice.message.model_dump()
+            )  # ty:ignore[possibly-missing-attribute]
         formatted_messages.append(format_message(message))
     return header + "\n".join(formatted_messages)
@@ -702,9 +704,9 @@ async def _train_model(
         num_gradient_steps = int(
             result.pop("num_gradient_steps", estimated_gradient_steps)
         )
-        assert num_gradient_steps == estimated_gradient_steps, (
-            f"num_gradient_steps {num_gradient_steps} != estimated_gradient_steps {estimated_gradient_steps}"
-        )
+        assert (
+            num_gradient_steps == estimated_gradient_steps
+        ), f"num_gradient_steps {num_gradient_steps} != estimated_gradient_steps {estimated_gradient_steps}"
         results.append(result)
         yield {**result, "num_gradient_steps": num_gradient_steps}
         pbar.update(1)
```
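Why `token_ids` here: with `tokens` becoming a lazily decoded property (see `tokenize.py` below), measuring length via `token_ids` avoids decoding every result just to size the packed tensors, and the two lists have the same length by construction. For reference, a minimal sketch of the round-up-and-cap computation the updated line feeds into; the helper name and `model_max_seq_len` parameter are illustrative, not from the repository:

```python
import math


def packed_sequence_length(max_tokens: int, model_max_seq_len: int) -> int:
    """Illustrative sketch: round max_tokens up to a multiple of 2048, then cap at the model limit."""
    sequence_length = math.ceil(max_tokens / 2048) * 2048
    return min(sequence_length, model_max_seq_len)


# 3100 tokens round up to 4096; 10_000 tokens are capped at an 8192-token model limit.
assert packed_sequence_length(3100, 8192) == 4096
assert packed_sequence_length(10_000, 8192) == 8192
```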

src/art/preprocessing/tokenize.py

Lines changed: 9 additions & 4 deletions
```diff
@@ -1,4 +1,5 @@
-from dataclasses import dataclass
+from dataclasses import dataclass, field
+from functools import cached_property
 from itertools import takewhile
 import math
 import random
@@ -16,30 +17,34 @@
 class TokenizedResult:
     advantage: float
     chat: str
-    tokens: list[str]
     token_ids: list[int]
     input_pos: list[int]
     assistant_mask: list[int]
     logprobs: list[float]
     pixel_values: torch.Tensor | None
     image_grid_thw: torch.Tensor | None
     trajectory: Trajectory
+    _tokenizer: "PreTrainedTokenizerBase" = field(repr=False, compare=False)
     weight: float = 0.0
     prompt_id: int = 0
     prompt_length: int = 0

+    @cached_property
+    def tokens(self) -> list[str]:
+        return [self._tokenizer.decode(token_id) for token_id in self.token_ids]
+
     def without_prompt(self) -> "TokenizedResult":
         return TokenizedResult(
             advantage=self.advantage,
             chat=self.chat,
-            tokens=self.tokens[self.prompt_length :],
             token_ids=self.token_ids[self.prompt_length :],
             input_pos=self.input_pos[self.prompt_length :],
             assistant_mask=self.assistant_mask[self.prompt_length :],
             logprobs=self.logprobs[self.prompt_length :],
             pixel_values=None,
             image_grid_thw=None,
             trajectory=self.trajectory,
+            _tokenizer=self._tokenizer,
             weight=self.weight,
             prompt_id=self.prompt_id,
             prompt_length=0,
@@ -347,14 +352,14 @@ def tokenize_trajectory(
     return TokenizedResult(
         advantage=advantage,
         chat=chat,
-        tokens=[tokenizer.decode(token_id) for token_id in token_ids],
         token_ids=token_ids,
         input_pos=list(range(len(token_ids))),
         assistant_mask=assistant_mask,
         logprobs=logprobs,
         pixel_values=pixel_values,
         image_grid_thw=image_grid_thw,
         trajectory=trajectory,
+        _tokenizer=tokenizer,
     )
```
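For context, a self-contained toy sketch of the pattern this diff introduces: `tokens` is no longer stored as a field but computed by a `cached_property` that decodes `token_ids` on first access, with the tokenizer carried as a dataclass field excluded from `repr` and comparison. `ToyTokenizer` and `ToyResult` are illustrative stand-ins, not names from the repository:

```python
from dataclasses import dataclass, field
from functools import cached_property


class ToyTokenizer:
    """Stand-in for a real tokenizer: decode() maps one token id back to text."""

    def decode(self, token_id: int) -> str:
        return f"<tok{token_id}>"


@dataclass
class ToyResult:
    token_ids: list[int]
    # Excluded from repr/eq so the tokenizer handle does not affect dataclass semantics,
    # mirroring the field(repr=False, compare=False) declaration in the diff.
    _tokenizer: ToyTokenizer = field(repr=False, compare=False)

    @cached_property
    def tokens(self) -> list[str]:
        # Decoded lazily on first access and cached on the instance afterwards.
        return [self._tokenizer.decode(token_id) for token_id in self.token_ids]


result = ToyResult(token_ids=[1, 2, 3], _tokenizer=ToyTokenizer())
assert result.tokens == ["<tok1>", "<tok2>", "<tok3>"]
```

One consequence worth noting: results whose `.tokens` is never accessed never pay the decode cost, which is why the `LocalBackend` change above measures length with `token_ids` instead.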
