Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions src/art/megatron/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,8 @@ def freeze_model(model_chunks: list[MegatronModule]) -> list[MegatronModule]:
data_parallel_random_init=False,
)

rank = torch.distributed.get_rank()
world_size = torch.distributed.get_world_size()
rank = torch.distributed.get_rank() # ty:ignore[possibly-missing-attribute]
world_size = torch.distributed.get_world_size() # ty:ignore[possibly-missing-attribute]

if rank == 0:
print("TORCHINDUCTOR_CACHE_DIR:", os.environ["TORCHINDUCTOR_CACHE_DIR"])
Expand Down Expand Up @@ -141,7 +141,7 @@ def print0(*values: Any) -> None:
offload_to_cpu(model, optimizer, rank, offload_state)

while True:
torch.distributed.barrier()
torch.distributed.barrier() # ty:ignore[possibly-missing-attribute]
jobs_dir = "/tmp/megatron_training_jobs"
os.makedirs(jobs_dir, exist_ok=True)
job_names = sorted(
Expand Down Expand Up @@ -259,9 +259,9 @@ def print0(*values: Any) -> None:
for param in chunk.parameters():
if param.grad is None:
continue
torch.distributed.all_reduce(
torch.distributed.all_reduce( # ty:ignore[possibly-missing-attribute]
param.grad,
op=torch.distributed.ReduceOp.AVG,
op=torch.distributed.ReduceOp.AVG, # ty:ignore[possibly-missing-attribute]
group=ps.get_data_parallel_group(),
)
num_grads += 1
Expand All @@ -276,7 +276,7 @@ def print0(*values: Any) -> None:
optimizer.zero_grad()

# Mean reduce loss across all ranks for logging
torch.distributed.all_reduce(loss, op=torch.distributed.ReduceOp.AVG)
torch.distributed.all_reduce(loss, op=torch.distributed.ReduceOp.AVG) # ty:ignore[possibly-missing-attribute]

if rank == 0:
with open("/tmp/megatron_training_log.jsonl", "a+") as log_file:
Expand Down Expand Up @@ -322,7 +322,7 @@ def print0(*values: Any) -> None:
gc.collect()
torch.cuda.empty_cache()
# Ensure all ranks have finished saving before signaling completion
torch.distributed.barrier()
torch.distributed.barrier() # ty:ignore[possibly-missing-attribute]
if rank == 0:
os.remove(job_path)
with open("/tmp/megatron_training_log.jsonl", "a+") as log_file:
Expand Down
2 changes: 1 addition & 1 deletion src/art/preprocessing/tokenize.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,7 +347,7 @@ def tokenize_trajectory(
return TokenizedResult(
advantage=advantage,
chat=chat,
tokens=[cast(str, tokenizer.decode(token_id)) for token_id in token_ids],
tokens=[tokenizer.decode(token_id) for token_id in token_ids],
token_ids=token_ids,
input_pos=list(range(len(token_ids))),
assistant_mask=assistant_mask,
Expand Down
2 changes: 1 addition & 1 deletion src/art/serverless/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ async def create(

@cached_property
def events(self) -> "TrainingJobEvents":
return TrainingJobEvents(cast(AsyncOpenAI, self._client))
return TrainingJobEvents(cast(AsyncOpenAI, self._client)) # ty:ignore[redundant-cast]


class Client(AsyncAPIClient):
Expand Down
4 changes: 2 additions & 2 deletions src/art/tinker/cookbook_v/renderers/qwen3.py
Original file line number Diff line number Diff line change
Expand Up @@ -448,9 +448,9 @@ def _preprocess_message_parts(
base_parts: list[ImagePart | TextPart] = []
for p in content:
if p["type"] == "text":
base_parts.append(cast(TextPart, p))
base_parts.append(p)
elif p["type"] == "image":
base_parts.append(cast(ImagePart, p))
base_parts.append(p)
elif p["type"] == "thinking":
if not strip_thinking:
# Render thinking as <think>...</think> text
Expand Down
2 changes: 1 addition & 1 deletion src/art/tinker/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,7 @@ async def chat_completions(

def _default_num_workers(self) -> int:
try:
return max(1, len(os.sched_getaffinity(0)))
return max(1, len(os.sched_getaffinity(0))) # ty:ignore[unresolved-attribute]
except (AttributeError, OSError):
return os.cpu_count() or 1

Expand Down
Loading