diff --git a/src/art/megatron/train.py b/src/art/megatron/train.py index 876611a65..480a03be4 100644 --- a/src/art/megatron/train.py +++ b/src/art/megatron/train.py @@ -63,8 +63,8 @@ def freeze_model(model_chunks: list[MegatronModule]) -> list[MegatronModule]: data_parallel_random_init=False, ) -rank = torch.distributed.get_rank() -world_size = torch.distributed.get_world_size() +rank = torch.distributed.get_rank() # ty:ignore[possibly-missing-attribute] +world_size = torch.distributed.get_world_size() # ty:ignore[possibly-missing-attribute] if rank == 0: print("TORCHINDUCTOR_CACHE_DIR:", os.environ["TORCHINDUCTOR_CACHE_DIR"]) @@ -141,7 +141,7 @@ def print0(*values: Any) -> None: offload_to_cpu(model, optimizer, rank, offload_state) while True: - torch.distributed.barrier() + torch.distributed.barrier() # ty:ignore[possibly-missing-attribute] jobs_dir = "/tmp/megatron_training_jobs" os.makedirs(jobs_dir, exist_ok=True) job_names = sorted( @@ -259,9 +259,9 @@ def print0(*values: Any) -> None: for param in chunk.parameters(): if param.grad is None: continue - torch.distributed.all_reduce( + torch.distributed.all_reduce( # ty:ignore[possibly-missing-attribute] param.grad, - op=torch.distributed.ReduceOp.AVG, + op=torch.distributed.ReduceOp.AVG, # ty:ignore[possibly-missing-attribute] group=ps.get_data_parallel_group(), ) num_grads += 1 @@ -276,7 +276,7 @@ def print0(*values: Any) -> None: optimizer.zero_grad() # Mean reduce loss across all ranks for logging - torch.distributed.all_reduce(loss, op=torch.distributed.ReduceOp.AVG) + torch.distributed.all_reduce(loss, op=torch.distributed.ReduceOp.AVG) # ty:ignore[possibly-missing-attribute] if rank == 0: with open("/tmp/megatron_training_log.jsonl", "a+") as log_file: @@ -322,7 +322,7 @@ def print0(*values: Any) -> None: gc.collect() torch.cuda.empty_cache() # Ensure all ranks have finished saving before signaling completion - torch.distributed.barrier() + torch.distributed.barrier() # ty:ignore[possibly-missing-attribute] if rank == 0: os.remove(job_path) with open("/tmp/megatron_training_log.jsonl", "a+") as log_file: diff --git a/src/art/preprocessing/tokenize.py b/src/art/preprocessing/tokenize.py index f4d5694e9..df0322f55 100644 --- a/src/art/preprocessing/tokenize.py +++ b/src/art/preprocessing/tokenize.py @@ -347,7 +347,7 @@ def tokenize_trajectory( return TokenizedResult( advantage=advantage, chat=chat, - tokens=[cast(str, tokenizer.decode(token_id)) for token_id in token_ids], + tokens=[tokenizer.decode(token_id) for token_id in token_ids], token_ids=token_ids, input_pos=list(range(len(token_ids))), assistant_mask=assistant_mask, diff --git a/src/art/serverless/client.py b/src/art/serverless/client.py index aa710861e..e7b0263d0 100644 --- a/src/art/serverless/client.py +++ b/src/art/serverless/client.py @@ -250,7 +250,7 @@ async def create( @cached_property def events(self) -> "TrainingJobEvents": - return TrainingJobEvents(cast(AsyncOpenAI, self._client)) + return TrainingJobEvents(cast(AsyncOpenAI, self._client)) # ty:ignore[redundant-cast] class Client(AsyncAPIClient): diff --git a/src/art/tinker/cookbook_v/renderers/qwen3.py b/src/art/tinker/cookbook_v/renderers/qwen3.py index 0e9479046..343a6ece9 100644 --- a/src/art/tinker/cookbook_v/renderers/qwen3.py +++ b/src/art/tinker/cookbook_v/renderers/qwen3.py @@ -448,9 +448,9 @@ def _preprocess_message_parts( base_parts: list[ImagePart | TextPart] = [] for p in content: if p["type"] == "text": - base_parts.append(cast(TextPart, p)) + base_parts.append(p) elif p["type"] == "image": - base_parts.append(cast(ImagePart, p)) + base_parts.append(p) elif p["type"] == "thinking": if not strip_thinking: # Render thinking as ... text diff --git a/src/art/tinker/server.py b/src/art/tinker/server.py index 8a5534094..82a8ab2af 100644 --- a/src/art/tinker/server.py +++ b/src/art/tinker/server.py @@ -296,7 +296,7 @@ async def chat_completions( def _default_num_workers(self) -> int: try: - return max(1, len(os.sched_getaffinity(0))) + return max(1, len(os.sched_getaffinity(0))) # ty:ignore[unresolved-attribute] except (AttributeError, OSError): return os.cpu_count() or 1