Skip to content

Commit 5593918

Browse files
arcticfly and claude authored
Fix ty type checker errors and warnings (#606)
Suppress false-positive torch.distributed warnings, remove redundant casts, and suppress platform-specific os.sched_getaffinity error. Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
1 parent aed1ee7 commit 5593918

5 files changed

Lines changed: 12 additions & 12 deletions

File tree

src/art/megatron/train.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,8 @@ def freeze_model(model_chunks: list[MegatronModule]) -> list[MegatronModule]:
6363
data_parallel_random_init=False,
6464
)
6565

66-
rank = torch.distributed.get_rank()
67-
world_size = torch.distributed.get_world_size()
66+
rank = torch.distributed.get_rank() # ty:ignore[possibly-missing-attribute]
67+
world_size = torch.distributed.get_world_size() # ty:ignore[possibly-missing-attribute]
6868

6969
if rank == 0:
7070
print("TORCHINDUCTOR_CACHE_DIR:", os.environ["TORCHINDUCTOR_CACHE_DIR"])
@@ -141,7 +141,7 @@ def print0(*values: Any) -> None:
141141
offload_to_cpu(model, optimizer, rank, offload_state)
142142

143143
while True:
144-
torch.distributed.barrier()
144+
torch.distributed.barrier() # ty:ignore[possibly-missing-attribute]
145145
jobs_dir = "/tmp/megatron_training_jobs"
146146
os.makedirs(jobs_dir, exist_ok=True)
147147
job_names = sorted(
@@ -259,9 +259,9 @@ def print0(*values: Any) -> None:
259259
for param in chunk.parameters():
260260
if param.grad is None:
261261
continue
262-
torch.distributed.all_reduce(
262+
torch.distributed.all_reduce( # ty:ignore[possibly-missing-attribute]
263263
param.grad,
264-
op=torch.distributed.ReduceOp.AVG,
264+
op=torch.distributed.ReduceOp.AVG, # ty:ignore[possibly-missing-attribute]
265265
group=ps.get_data_parallel_group(),
266266
)
267267
num_grads += 1
@@ -276,7 +276,7 @@ def print0(*values: Any) -> None:
276276
optimizer.zero_grad()
277277

278278
# Mean reduce loss across all ranks for logging
279-
torch.distributed.all_reduce(loss, op=torch.distributed.ReduceOp.AVG)
279+
torch.distributed.all_reduce(loss, op=torch.distributed.ReduceOp.AVG) # ty:ignore[possibly-missing-attribute]
280280

281281
if rank == 0:
282282
with open("/tmp/megatron_training_log.jsonl", "a+") as log_file:
@@ -322,7 +322,7 @@ def print0(*values: Any) -> None:
322322
gc.collect()
323323
torch.cuda.empty_cache()
324324
# Ensure all ranks have finished saving before signaling completion
325-
torch.distributed.barrier()
325+
torch.distributed.barrier() # ty:ignore[possibly-missing-attribute]
326326
if rank == 0:
327327
os.remove(job_path)
328328
with open("/tmp/megatron_training_log.jsonl", "a+") as log_file:

src/art/preprocessing/tokenize.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -347,7 +347,7 @@ def tokenize_trajectory(
347347
return TokenizedResult(
348348
advantage=advantage,
349349
chat=chat,
350-
tokens=[cast(str, tokenizer.decode(token_id)) for token_id in token_ids],
350+
tokens=[tokenizer.decode(token_id) for token_id in token_ids],
351351
token_ids=token_ids,
352352
input_pos=list(range(len(token_ids))),
353353
assistant_mask=assistant_mask,

src/art/serverless/client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -250,7 +250,7 @@ async def create(
250250

251251
@cached_property
252252
def events(self) -> "TrainingJobEvents":
253-
return TrainingJobEvents(cast(AsyncOpenAI, self._client))
253+
return TrainingJobEvents(cast(AsyncOpenAI, self._client)) # ty:ignore[redundant-cast]
254254

255255

256256
class Client(AsyncAPIClient):

src/art/tinker/cookbook_v/renderers/qwen3.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -448,9 +448,9 @@ def _preprocess_message_parts(
448448
base_parts: list[ImagePart | TextPart] = []
449449
for p in content:
450450
if p["type"] == "text":
451-
base_parts.append(cast(TextPart, p))
451+
base_parts.append(p)
452452
elif p["type"] == "image":
453-
base_parts.append(cast(ImagePart, p))
453+
base_parts.append(p)
454454
elif p["type"] == "thinking":
455455
if not strip_thinking:
456456
# Render thinking as <think>...</think> text

src/art/tinker/server.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -296,7 +296,7 @@ async def chat_completions(
296296

297297
def _default_num_workers(self) -> int:
298298
try:
299-
return max(1, len(os.sched_getaffinity(0)))
299+
return max(1, len(os.sched_getaffinity(0))) # ty:ignore[unresolved-attribute]
300300
except (AttributeError, OSError):
301301
return os.cpu_count() or 1
302302

0 commit comments

Comments (0)