Skip to content

Commit 13f9d3b

Browse files
committed
refactor: Clean up assertion formatting and message handling in LocalBackend
- Simplified the message assignment logic in LocalBackend for better readability. - Improved the formatting of assertions to enhance clarity. These changes contribute to cleaner and more maintainable code in the LocalBackend class.
1 parent a437c9d commit 13f9d3b

1 file changed

Lines changed: 4 additions & 6 deletions

File tree

src/art/local/backend.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -416,9 +416,7 @@ def _trajectory_log(self, trajectory: Trajectory) -> str:
416416
if isinstance(message_or_choice, dict):
417417
message = message_or_choice
418418
else:
419-
message = cast(
420-
Message, message_or_choice.message.model_dump()
421-
) # ty:ignore[possibly-missing-attribute]
419+
message = cast(Message, message_or_choice.message.model_dump()) # ty:ignore[possibly-missing-attribute]
422420
formatted_messages.append(format_message(message))
423421
return header + "\n".join(formatted_messages)
424422

@@ -704,9 +702,9 @@ async def _train_model(
704702
num_gradient_steps = int(
705703
result.pop("num_gradient_steps", estimated_gradient_steps)
706704
)
707-
assert (
708-
num_gradient_steps == estimated_gradient_steps
709-
), f"num_gradient_steps {num_gradient_steps} != estimated_gradient_steps {estimated_gradient_steps}"
705+
assert num_gradient_steps == estimated_gradient_steps, (
706+
f"num_gradient_steps {num_gradient_steps} != estimated_gradient_steps {estimated_gradient_steps}"
707+
)
710708
results.append(result)
711709
yield {**result, "num_gradient_steps": num_gradient_steps}
712710
pbar.update(1)

0 commit comments

Comments
 (0)