@@ -63,8 +63,8 @@ def freeze_model(model_chunks: list[MegatronModule]) -> list[MegatronModule]:
6363 data_parallel_random_init=False,
6464)
6565
66- rank = torch.distributed.get_rank()
67- world_size = torch.distributed.get_world_size()
66+ rank = torch.distributed.get_rank()  # ty:ignore[possibly-missing-attribute]
67+ world_size = torch.distributed.get_world_size()  # ty:ignore[possibly-missing-attribute]
6868
6969if rank == 0:
7070 print("TORCHINDUCTOR_CACHE_DIR:", os.environ["TORCHINDUCTOR_CACHE_DIR"])
@@ -141,7 +141,7 @@ def print0(*values: Any) -> None:
141141offload_to_cpu(model, optimizer, rank, offload_state)
142142
143143while True:
144- torch.distributed.barrier()
144+ torch.distributed.barrier()  # ty:ignore[possibly-missing-attribute]
145145 jobs_dir = "/tmp/megatron_training_jobs"
146146 os.makedirs(jobs_dir, exist_ok=True)
147147 job_names = sorted(
@@ -259,9 +259,9 @@ def print0(*values: Any) -> None:
259259 for param in chunk.parameters():
260260 if param.grad is None:
261261 continue
262- torch.distributed.all_reduce(
262+ torch.distributed.all_reduce(  # ty:ignore[possibly-missing-attribute]
263263 param.grad,
264- op=torch.distributed.ReduceOp.AVG,
264+ op=torch.distributed.ReduceOp.AVG,  # ty:ignore[possibly-missing-attribute]
265265 group=ps.get_data_parallel_group(),
266266 )
267267 num_grads += 1
@@ -276,7 +276,7 @@ def print0(*values: Any) -> None:
276276 optimizer.zero_grad()
277277
278278 # Mean reduce loss across all ranks for logging
279- torch.distributed.all_reduce(loss, op=torch.distributed.ReduceOp.AVG)
279+ torch.distributed.all_reduce(loss, op=torch.distributed.ReduceOp.AVG)  # ty:ignore[possibly-missing-attribute]
280280
281281 if rank == 0:
282282 with open("/tmp/megatron_training_log.jsonl", "a+") as log_file:
@@ -322,7 +322,7 @@ def print0(*values: Any) -> None:
322322 gc.collect()
323323 torch.cuda.empty_cache()
324324 # Ensure all ranks have finished saving before signaling completion
325- torch.distributed.barrier()
325+ torch.distributed.barrier()  # ty:ignore[possibly-missing-attribute]
326326 if rank == 0:
327327 os.remove(job_path)
328328 with open("/tmp/megatron_training_log.jsonl", "a+") as log_file:
0 commit comments