From f0f7774beb47ffa48e0d1621ec997e4455db006e Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 6 Nov 2025 03:53:03 +0000 Subject: [PATCH 01/17] update budgets for a100 hardware weightclass --- algoperf/workloads/criteo1tb/workload.py | 4 ++-- algoperf/workloads/fastmri/workload.py | 4 ++-- .../workloads/imagenet_resnet/workload.py | 4 ++-- algoperf/workloads/imagenet_vit/workload.py | 4 ++-- .../librispeech_conformer/workload.py | 4 ++-- .../librispeech_jax/workload.py | 6 +++++- .../librispeech_pytorch/workload.py | 6 +++++- algoperf/workloads/ogbg/workload.py | 4 ++-- algoperf/workloads/wmt/workload.py | 4 ++-- docker/build_docker_images.sh | 14 ++++++------- scoring/performance_profile.py | 1 + scoring/score_submissions.py | 4 +++- scoring/scoring_utils.py | 20 +++++++++++++++++++ scoring/utils/run_workloads.py | 7 ++++++- .../workload_metadata_external_tuning.json | 2 +- 15 files changed, 62 insertions(+), 26 deletions(-) diff --git a/algoperf/workloads/criteo1tb/workload.py b/algoperf/workloads/criteo1tb/workload.py index 2cb7e5450..fb38eacc3 100644 --- a/algoperf/workloads/criteo1tb/workload.py +++ b/algoperf/workloads/criteo1tb/workload.py @@ -95,11 +95,11 @@ def train_stddev(self): @property def max_allowed_runtime_sec(self) -> int: - return 7_703 # ~2.1 hours. + return 8915 # ~2.4 hours. @property def eval_period_time_sec(self) -> int: - return 2 * 60 # 2 mins. + return 356 # approx 25 evals def _build_input_queue( self, diff --git a/algoperf/workloads/fastmri/workload.py b/algoperf/workloads/fastmri/workload.py index 0b1ecfaa1..5a8afa2e9 100644 --- a/algoperf/workloads/fastmri/workload.py +++ b/algoperf/workloads/fastmri/workload.py @@ -95,11 +95,11 @@ def accelerations(self): @property def max_allowed_runtime_sec(self) -> int: - return 4_430 # ~1.2 hours + return 2745 # ~0.7 hours @property def eval_period_time_sec(self) -> int: - return 80 + return 110 # approx 25 evals @property def step_hint(self) -> int: diff --git a/algoperf/workloads/imagenet_resnet/workload.py b/algoperf/workloads/imagenet_resnet/workload.py index ef696e328..b5263e0a6 100644 --- a/algoperf/workloads/imagenet_resnet/workload.py +++ b/algoperf/workloads/imagenet_resnet/workload.py @@ -103,11 +103,11 @@ def resize_size(self) -> int: @property def max_allowed_runtime_sec(self) -> int: - return 66_159 # ~18.4 hours + return 49918 # ~13.8 hours @property def eval_period_time_sec(self) -> int: - return 510 # 8.5 minutes. + return 1996 # approx 25 evals def _build_dataset( self, diff --git a/algoperf/workloads/imagenet_vit/workload.py b/algoperf/workloads/imagenet_vit/workload.py index 2a0070ba4..f8f4f2659 100644 --- a/algoperf/workloads/imagenet_vit/workload.py +++ b/algoperf/workloads/imagenet_vit/workload.py @@ -88,11 +88,11 @@ def eval_batch_size(self) -> int: @property def max_allowed_runtime_sec(self) -> int: - return 69_768 # ~19.4 hours + return 64_292 # ~17.8 hours @property def eval_period_time_sec(self) -> int: - return 7 * 60 # 7 mins. + return 2571 # 7 mins. def _build_dataset( self, diff --git a/algoperf/workloads/librispeech_conformer/workload.py b/algoperf/workloads/librispeech_conformer/workload.py index 791270719..327e8bc39 100644 --- a/algoperf/workloads/librispeech_conformer/workload.py +++ b/algoperf/workloads/librispeech_conformer/workload.py @@ -80,11 +80,11 @@ def train_stddev(self): @property def max_allowed_runtime_sec(self) -> int: - return 58_015 # ~16.1 hours + return 43680 # ~16.1 hours @property def eval_period_time_sec(self) -> int: - return 24 * 60 + return 1747 # approx 25 evals @property def step_hint(self) -> int: diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py index 3a320b0dd..2a8fd29d0 100644 --- a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py +++ b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py @@ -100,7 +100,11 @@ def step_hint(self) -> int: @property def max_allowed_runtime_sec(self) -> int: - return 44_405 # ~12.3 hours + return 36_949 # ~12.3 hours + + @property + def eval_period_time_sec(self) -> int: + return 1447 # approx 25 evals @property def use_tanh(self) -> bool: diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py b/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py index 672f3440f..119049b34 100644 --- a/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py +++ b/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py @@ -96,7 +96,11 @@ def step_hint(self) -> int: @property def max_allowed_runtime_sec(self) -> int: - return 44_405 # ~12.3 hours + return 36949 # 10.3 hours + + @property + def eval_period_time_sec(self) -> int: + return 1447 # approx 25 evals @property def use_tanh(self) -> bool: diff --git a/algoperf/workloads/ogbg/workload.py b/algoperf/workloads/ogbg/workload.py index 8717e46d6..53206200f 100644 --- a/algoperf/workloads/ogbg/workload.py +++ b/algoperf/workloads/ogbg/workload.py @@ -88,11 +88,11 @@ def train_stddev(self): @property def max_allowed_runtime_sec(self) -> int: - return 12_011 # ~3.3 hours + return 11303 # ~3.1 hours @property def eval_period_time_sec(self) -> int: - return 4 * 60 + return 452. # approx 25 evals def _build_input_queue( self, diff --git a/algoperf/workloads/wmt/workload.py b/algoperf/workloads/wmt/workload.py index 40e4262dd..d972a5486 100644 --- a/algoperf/workloads/wmt/workload.py +++ b/algoperf/workloads/wmt/workload.py @@ -89,11 +89,11 @@ def train_stddev(self): @property def max_allowed_runtime_sec(self) -> int: - return 43_336 # ~12.0 hours + return 16114 # ~12.0 hours @property def eval_period_time_sec(self) -> int: - return 14 * 60 + return 644 @property def step_hint(self) -> int: diff --git a/docker/build_docker_images.sh b/docker/build_docker_images.sh index 6b5e67ceb..22590b9fd 100644 --- a/docker/build_docker_images.sh +++ b/docker/build_docker_images.sh @@ -27,7 +27,7 @@ then GIT_BRANCH='main' # Set default argument fi -FRAMEWORKS=( "jax" "pythorch" "both" ) +FRAMEWORKS=( "jax" "pytorch") if [[ -n "$FRAMEWORK" ]]; then @@ -45,10 +45,10 @@ do echo "On branch: ${GIT_BRANCH}" echo $DOCKER_BUILD_COMMAND eval $DOCKER_BUILD_COMMAND - echo $DOCKER_TAG_COMMAND - eval $DOCKER_TAG_COMMAND - echo $DOCKER_PUSH_COMMAND - eval $DOCKER_PUSH_COMMAND - echo "To pull container run: " - echo $DOCKER_PULL_COMMAND + # echo $DOCKER_TAG_COMMAND + # eval $DOCKER_TAG_COMMAND + # echo $DOCKER_PUSH_COMMAND + # eval $DOCKER_PUSH_COMMAND + # echo "To pull container run: " + # echo $DOCKER_PULL_COMMAND done diff --git a/scoring/performance_profile.py b/scoring/performance_profile.py index 4f2ae9c57..b200c6865 100644 --- a/scoring/performance_profile.py +++ b/scoring/performance_profile.py @@ -71,6 +71,7 @@ 'wer', 'l1_loss', 'loss', + 'ppl' ] MAX_EVAL_METRICS = ['mean_average_precision', 'ssim', 'accuracy', 'bleu'] diff --git a/scoring/score_submissions.py b/scoring/score_submissions.py index 3423df2e1..4b7bed2b5 100644 --- a/scoring/score_submissions.py +++ b/scoring/score_submissions.py @@ -123,6 +123,8 @@ def get_summary_df(workload, workload_df, include_test_split=False): workload_df['accumulated_submission_time'] / workload_df['global_step'] ).iloc[-1][-1] + summary_df['step_hint'] = scoring_utils.get_workload_stephint(workload) + # test metrics if include_test_split: test_metric, test_target = scoring_utils.get_workload_metrics_and_targets( @@ -157,7 +159,7 @@ def get_summary_df(workload, workload_df, include_test_split=False): return summary_df -def get_submission_summary(df, include_test_split=True): +def get_submission_summary(df, include_test_split=False): """Summarizes the submission results into metric and time tables organized by workload. """ diff --git a/scoring/scoring_utils.py b/scoring/scoring_utils.py index 5be6c790c..cb63eab4b 100644 --- a/scoring/scoring_utils.py +++ b/scoring/scoring_utils.py @@ -240,3 +240,23 @@ def get_workload_metrics_and_targets(workload, split='validation'): metric = f'test/{metric_name}' target = workload_obj.test_target_value return metric, target + + +def get_workload_stephint(workload): + workload_name = re.match(WORKLOAD_NAME_PATTERN, workload).group(1) + framework = re.match(WORKLOAD_NAME_PATTERN, workload).group(2) + workload_metadata = copy.copy(WORKLOADS[workload_name]) + + # Extend path according to framework. + workload_metadata['workload_path'] = os.path.join( + BASE_WORKLOADS_DIR, + workload_metadata['workload_path'] + f'{framework}', + 'workload.py', + ) + workload_init_kwargs = {} + workload_obj = workloads_registry.import_workload( + workload_path=workload_metadata['workload_path'], + workload_class_name=workload_metadata['workload_class_name'], + workload_init_kwargs=workload_init_kwargs, + ) + return workload_obj.step_hint diff --git a/scoring/utils/run_workloads.py b/scoring/utils/run_workloads.py index 273881c5a..c6764e9de 100644 --- a/scoring/utils/run_workloads.py +++ b/scoring/utils/run_workloads.py @@ -241,7 +241,8 @@ def main(_): # For each runnable workload check if there are any containers running and if not launch next container command for workload in workloads: - run_key = prng.fold_in(rng_subkey, hash(workload)) + workload_foldin = hash(workload) % 9 + run_key = prng.fold_in(rng_subkey, workload_foldin) run_seed = run_key[0] # arbitrary base_workload_name = get_base_workload_name(workload) wait_until_container_not_running() @@ -270,6 +271,10 @@ def main(_): 'docker run -t -d -v /home/kasimbeg/data/:/data/ ' '-v /home/kasimbeg/experiment_runs/:/experiment_runs ' '-v /home/kasimbeg/experiment_runs/logs:/logs ' +<<<<<<< Updated upstream +======= + '-v /home/kasimbeg/algorithmic-efficiency:/algorithmic-efficiency ' +>>>>>>> Stashed changes f'{mount_repo_flag}' '--gpus all --ipc=host ' f'{docker_image_url} ' diff --git a/scoring/utils/workload_metadata_external_tuning.json b/scoring/utils/workload_metadata_external_tuning.json index c7d4ae195..3d9f78ca1 100644 --- a/scoring/utils/workload_metadata_external_tuning.json +++ b/scoring/utils/workload_metadata_external_tuning.json @@ -24,7 +24,7 @@ "dataset": "librispeech" }, "criteo1tb": { - "max_steps": 10666, + "max_steps": 15666, "dataset": "criteo1tb" }, "librispeech_conformer": { From b93eb3ca97871ed30ddd6b08806a3bbc1ca0bdae Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 6 Nov 2025 03:56:48 +0000 Subject: [PATCH 02/17] formatting --- algoperf/workloads/criteo1tb/workload.py | 2 +- algoperf/workloads/fastmri/workload.py | 2 +- algoperf/workloads/imagenet_resnet/workload.py | 4 ++-- algoperf/workloads/imagenet_vit/workload.py | 2 +- algoperf/workloads/librispeech_conformer/workload.py | 2 +- .../librispeech_deepspeech/librispeech_pytorch/workload.py | 2 +- algoperf/workloads/ogbg/workload.py | 4 ++-- algoperf/workloads/wmt/workload.py | 2 +- 8 files changed, 10 insertions(+), 10 deletions(-) diff --git a/algoperf/workloads/criteo1tb/workload.py b/algoperf/workloads/criteo1tb/workload.py index fb38eacc3..4d2196cd5 100644 --- a/algoperf/workloads/criteo1tb/workload.py +++ b/algoperf/workloads/criteo1tb/workload.py @@ -95,7 +95,7 @@ def train_stddev(self): @property def max_allowed_runtime_sec(self) -> int: - return 8915 # ~2.4 hours. + return 8_915 # ~2.4 hours. @property def eval_period_time_sec(self) -> int: diff --git a/algoperf/workloads/fastmri/workload.py b/algoperf/workloads/fastmri/workload.py index 5a8afa2e9..b87dfc755 100644 --- a/algoperf/workloads/fastmri/workload.py +++ b/algoperf/workloads/fastmri/workload.py @@ -95,7 +95,7 @@ def accelerations(self): @property def max_allowed_runtime_sec(self) -> int: - return 2745 # ~0.7 hours + return 2_745 # ~0.7 hours @property def eval_period_time_sec(self) -> int: diff --git a/algoperf/workloads/imagenet_resnet/workload.py b/algoperf/workloads/imagenet_resnet/workload.py index b5263e0a6..de8458c92 100644 --- a/algoperf/workloads/imagenet_resnet/workload.py +++ b/algoperf/workloads/imagenet_resnet/workload.py @@ -103,11 +103,11 @@ def resize_size(self) -> int: @property def max_allowed_runtime_sec(self) -> int: - return 49918 # ~13.8 hours + return 49_918 # ~13.8 hours @property def eval_period_time_sec(self) -> int: - return 1996 # approx 25 evals + return 1_996 # approx 25 evals def _build_dataset( self, diff --git a/algoperf/workloads/imagenet_vit/workload.py b/algoperf/workloads/imagenet_vit/workload.py index f8f4f2659..4da02614f 100644 --- a/algoperf/workloads/imagenet_vit/workload.py +++ b/algoperf/workloads/imagenet_vit/workload.py @@ -92,7 +92,7 @@ def max_allowed_runtime_sec(self) -> int: @property def eval_period_time_sec(self) -> int: - return 2571 # 7 mins. + return 2_571 # 7 mins. def _build_dataset( self, diff --git a/algoperf/workloads/librispeech_conformer/workload.py b/algoperf/workloads/librispeech_conformer/workload.py index 327e8bc39..5a0a546e4 100644 --- a/algoperf/workloads/librispeech_conformer/workload.py +++ b/algoperf/workloads/librispeech_conformer/workload.py @@ -80,7 +80,7 @@ def train_stddev(self): @property def max_allowed_runtime_sec(self) -> int: - return 43680 # ~16.1 hours + return 43_680 # ~16.1 hours @property def eval_period_time_sec(self) -> int: diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py b/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py index 119049b34..c6bb149f7 100644 --- a/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py +++ b/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py @@ -96,7 +96,7 @@ def step_hint(self) -> int: @property def max_allowed_runtime_sec(self) -> int: - return 36949 # 10.3 hours + return 36_949 # 10.3 hours @property def eval_period_time_sec(self) -> int: diff --git a/algoperf/workloads/ogbg/workload.py b/algoperf/workloads/ogbg/workload.py index 53206200f..002576268 100644 --- a/algoperf/workloads/ogbg/workload.py +++ b/algoperf/workloads/ogbg/workload.py @@ -88,11 +88,11 @@ def train_stddev(self): @property def max_allowed_runtime_sec(self) -> int: - return 11303 # ~3.1 hours + return 11_303 # ~3.1 hours @property def eval_period_time_sec(self) -> int: - return 452. # approx 25 evals + return 452 # approx 25 evals def _build_input_queue( self, diff --git a/algoperf/workloads/wmt/workload.py b/algoperf/workloads/wmt/workload.py index d972a5486..2e232214e 100644 --- a/algoperf/workloads/wmt/workload.py +++ b/algoperf/workloads/wmt/workload.py @@ -89,7 +89,7 @@ def train_stddev(self): @property def max_allowed_runtime_sec(self) -> int: - return 16114 # ~12.0 hours + return 16_114 # ~12.0 hours @property def eval_period_time_sec(self) -> int: From 88b0e47fe9694d35d651a77c1acf8ea9491df5ab Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 6 Nov 2025 03:57:34 +0000 Subject: [PATCH 03/17] revert changes to docker build shell script --- docker/build_docker_images.sh | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/docker/build_docker_images.sh b/docker/build_docker_images.sh index 22590b9fd..aa94222ea 100644 --- a/docker/build_docker_images.sh +++ b/docker/build_docker_images.sh @@ -45,10 +45,10 @@ do echo "On branch: ${GIT_BRANCH}" echo $DOCKER_BUILD_COMMAND eval $DOCKER_BUILD_COMMAND - # echo $DOCKER_TAG_COMMAND - # eval $DOCKER_TAG_COMMAND - # echo $DOCKER_PUSH_COMMAND - # eval $DOCKER_PUSH_COMMAND - # echo "To pull container run: " - # echo $DOCKER_PULL_COMMAND + echo $DOCKER_TAG_COMMAND + eval $DOCKER_TAG_COMMAND + echo $DOCKER_PUSH_COMMAND + eval $DOCKER_PUSH_COMMAND + echo "To pull container run: " + echo $DOCKER_PULL_COMMAND done From fa946d861aab6d803d88046edb05caa27a79c4ab Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 6 Nov 2025 04:00:09 +0000 Subject: [PATCH 04/17] fix merge conflict --- scoring/utils/run_workloads.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/scoring/utils/run_workloads.py b/scoring/utils/run_workloads.py index c6764e9de..d8e0172fa 100644 --- a/scoring/utils/run_workloads.py +++ b/scoring/utils/run_workloads.py @@ -271,10 +271,7 @@ def main(_): 'docker run -t -d -v /home/kasimbeg/data/:/data/ ' '-v /home/kasimbeg/experiment_runs/:/experiment_runs ' '-v /home/kasimbeg/experiment_runs/logs:/logs ' -<<<<<<< Updated upstream -======= '-v /home/kasimbeg/algorithmic-efficiency:/algorithmic-efficiency ' ->>>>>>> Stashed changes f'{mount_repo_flag}' '--gpus all --ipc=host ' f'{docker_image_url} ' From 4e564d5438398ab40da419413d7cac603dd96261 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 20 Nov 2025 23:17:17 +0000 Subject: [PATCH 05/17] update pytorch --- pyproject.toml | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index e4de98f89..e1fc84987 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -105,7 +105,6 @@ jax_cpu = [ jax_gpu = [ "jax[cuda12]==0.7.0", "algoperf[jax_core_deps]", - "nvidia-cudnn-cu12==9.10.2.21", # temporary workaround for https://github.com/jax-ml/jax/issues/30663 ] pytorch_cpu = [ @@ -113,8 +112,8 @@ pytorch_cpu = [ "torchvision==0.20.1" ] pytorch_gpu = [ - "torch==2.5.1", - "torchvision==0.20.1", + "torch==2.9.0", + "torchvision==0.24.0", ] # Note: omit the cuda suffix and installing from the appropriate wheel will result in using locally installed CUDA. ############################################################################### From 6f7d638adc190d9bce3f30ba3314c27dac1a8cc5 Mon Sep 17 00:00:00 2001 From: rka97 Date: Mon, 1 Dec 2025 04:33:49 +0000 Subject: [PATCH 06/17] ImageNet and CIFAR mixed-precision support, need to debug slow pytorch - Introduced DTYPE enum to standardize data types (FLOAT32, FLOAT16, BFLOAT16) for JAX and PyTorch. - Updated input pipelines and model definitions in CIFAR and ImageNet workloads to utilize mixed precision. - Implemented casting policies for parameters and inputs using jmp and torch.autocast. --- algoperf/spec.py | 23 ++++++ .../cifar/cifar_jax/input_pipeline.py | 2 - algoperf/workloads/cifar/cifar_jax/models.py | 8 ++- .../workloads/cifar/cifar_jax/workload.py | 29 ++++++-- .../workloads/cifar/cifar_pytorch/models.py | 24 ++++++- .../workloads/cifar/cifar_pytorch/workload.py | 9 ++- algoperf/workloads/cifar/workload.py | 2 + .../imagenet_resnet/imagenet_jax/models.py | 8 ++- .../imagenet_resnet/imagenet_jax/workload.py | 28 ++++++-- .../imagenet_pytorch/models.py | 52 +++++++++++--- .../imagenet_pytorch/workload.py | 9 ++- .../workloads/imagenet_resnet/workload.py | 2 + .../imagenet_vit/imagenet_jax/models.py | 48 +++++++++---- .../imagenet_vit/imagenet_jax/workload.py | 10 ++- .../imagenet_vit/imagenet_pytorch/models.py | 71 +++++++++++++------ .../imagenet_vit/imagenet_pytorch/workload.py | 12 ++-- algoperf/workloads/ogbg/workload.py | 2 +- .../external_tuning/jax_nadamw_full_budget.py | 2 + .../pytorch_nadamw_full_budget.py | 10 +-- scoring/performance_profile.py | 2 +- submission_runner.py | 25 +++++-- 21 files changed, 288 insertions(+), 90 deletions(-) diff --git a/algoperf/spec.py b/algoperf/spec.py index b86e55954..8dd00345c 100644 --- a/algoperf/spec.py +++ b/algoperf/spec.py @@ -6,11 +6,34 @@ from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, Union import jax +import jax.numpy as jnp +import torch import torch.nn.functional as F from absl import logging from torch import nn +class DTYPE(enum.Enum): + FLOAT32 = 0 + FLOAT16 = 1 + BFLOAT16 = 2 + + +# Mapping from DTYPE enum to JAX dtypes +JAX_DTYPE_MAP = { + DTYPE.FLOAT32: jnp.float32, + DTYPE.FLOAT16: jnp.float16, + DTYPE.BFLOAT16: jnp.bfloat16, +} + +# Mapping from DTYPE enum to PyTorch dtypes +PYTORCH_DTYPE_MAP = { + DTYPE.FLOAT32: torch.float32, + DTYPE.FLOAT16: torch.float16, + DTYPE.BFLOAT16: torch.bfloat16, +} + + class LossType(enum.Enum): SOFTMAX_CROSS_ENTROPY = 0 SIGMOID_CROSS_ENTROPY = 1 diff --git a/algoperf/workloads/cifar/cifar_jax/input_pipeline.py b/algoperf/workloads/cifar/cifar_jax/input_pipeline.py index 7fbc95bc6..307e9e705 100644 --- a/algoperf/workloads/cifar/cifar_jax/input_pipeline.py +++ b/algoperf/workloads/cifar/cifar_jax/input_pipeline.py @@ -11,7 +11,6 @@ import jax import tensorflow as tf import tensorflow_datasets as tfds -from flax import jax_utils from algoperf import spec from algoperf.data_utils import shard_and_maybe_pad_np @@ -186,5 +185,4 @@ def create_input_iter( ), ds, ) - it = jax_utils.prefetch_to_device(it, 2) return it diff --git a/algoperf/workloads/cifar/cifar_jax/models.py b/algoperf/workloads/cifar/cifar_jax/models.py index 95238c997..9a4f7fd96 100644 --- a/algoperf/workloads/cifar/cifar_jax/models.py +++ b/algoperf/workloads/cifar/cifar_jax/models.py @@ -31,7 +31,7 @@ def __call__( update_batch_norm: bool = True, use_running_average_bn: bool = None, ) -> spec.Tensor: - conv = functools.partial(nn.Conv, use_bias=False, dtype=self.dtype) + conv = functools.partial(nn.Conv, use_bias=False, param_dtype=self.dtype) # Preserve default behavior for backwards compatibility if use_running_average_bn is None: @@ -41,7 +41,7 @@ def __call__( use_running_average=use_running_average_bn, momentum=0.9, epsilon=1e-5, - dtype=self.dtype, + param_dtype=self.dtype, ) x = conv( @@ -66,7 +66,9 @@ def __call__( x = nn.avg_pool(x, (4, 4), strides=(4, 4)) x = jnp.mean(x, axis=(1, 2)) x = nn.Dense( - self.num_classes, kernel_init=nn.initializers.normal(), dtype=self.dtype + self.num_classes, + kernel_init=nn.initializers.normal(), + param_dtype=self.dtype, )(x) return x diff --git a/algoperf/workloads/cifar/cifar_jax/workload.py b/algoperf/workloads/cifar/cifar_jax/workload.py index defc30121..e6bc5b419 100644 --- a/algoperf/workloads/cifar/cifar_jax/workload.py +++ b/algoperf/workloads/cifar/cifar_jax/workload.py @@ -5,6 +5,7 @@ import jax import jax.numpy as jnp +import jmp import optax import tensorflow_datasets as tfds from flax import linen as nn @@ -18,6 +19,17 @@ class CifarWorkload(BaseCifarWorkload): + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + compute_dtype = spec.JAX_DTYPE_MAP[self._compute_dtype] + param_dtype = spec.JAX_DTYPE_MAP[self._param_dtype] + output_dtype = compute_dtype + self._mp_policy = jmp.Policy( + compute_dtype=compute_dtype, + param_dtype=param_dtype, + output_dtype=output_dtype, + ) + def _build_cifar_dataset( self, data_rng: spec.RandomState, @@ -80,7 +92,8 @@ def sync_batch_stats( def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState: """Dropout is unused.""" model_cls = getattr(models, 'ResNet18') - model = model_cls(num_classes=self._num_classes, dtype=jnp.float32) + param_dtype = spec.JAX_DTYPE_MAP[self._param_dtype] + model = model_cls(num_classes=self._num_classes, dtype=param_dtype) self._model = model input_shape = (1, 32, 32, 3) variables = jax.jit(model.init)( @@ -89,7 +102,7 @@ def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState: model_state, params = pop(variables, 'params') self._param_shapes = param_utils.jax_param_shapes(params) self._param_types = param_utils.jax_param_types(self._param_shapes) - model_state = jax_sharding_utils.replicate(params) + model_state = jax_sharding_utils.replicate(model_state) params = jax_sharding_utils.replicate(params) return params, model_state @@ -110,24 +123,32 @@ def model_fn( del mode del rng del dropout_rate + # Cast params and inputs to compute dtype + params, inputs = self._mp_policy.cast_to_compute( + (params, augmented_and_preprocessed_input_batch['inputs']) + ) variables = {'params': params, **model_state} if update_batch_norm: logits, new_model_state = self._model.apply( variables, - augmented_and_preprocessed_input_batch['inputs'], + inputs, update_batch_norm=update_batch_norm, mutable=['batch_stats'], use_running_average_bn=use_running_average_bn, ) + # Cast logits to output dtype + logits = self._mp_policy.cast_to_output(logits) return logits, new_model_state else: logits = self._model.apply( variables, - augmented_and_preprocessed_input_batch['inputs'], + inputs, update_batch_norm=update_batch_norm, mutable=False, use_running_average_bn=use_running_average_bn, ) + # Cast logits to output dtype + logits = self._mp_policy.cast_to_output(logits) return logits, model_state # Does NOT apply regularization, which is left to the submitter to do in diff --git a/algoperf/workloads/cifar/cifar_pytorch/models.py b/algoperf/workloads/cifar/cifar_pytorch/models.py index 0e08f5c5a..b2b37c001 100644 --- a/algoperf/workloads/cifar/cifar_pytorch/models.py +++ b/algoperf/workloads/cifar/cifar_pytorch/models.py @@ -29,11 +29,13 @@ def __init__( width_per_group: int = 64, replace_stride_with_dilation: Optional[List[bool]] = None, norm_layer: Optional[Callable[..., nn.Module]] = None, + dtype: torch.dtype = torch.float32, ) -> None: super().__init__() if norm_layer is None: norm_layer = nn.BatchNorm2d self._norm_layer = norm_layer + self.dtype = dtype self.inplanes = 64 self.dilation = 1 @@ -49,7 +51,13 @@ def __init__( self.groups = groups self.base_width = width_per_group self.conv1 = nn.Conv2d( - 3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False + 3, + self.inplanes, + kernel_size=3, + stride=1, + padding=1, + bias=False, + dtype=dtype, ) self.bn1 = norm_layer(self.inplanes) self.relu = nn.ReLU(inplace=True) @@ -63,7 +71,7 @@ def __init__( self.layer4 = self._make_layer( block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2] ) - self.fc = nn.Linear(512 * block.expansion, num_classes) + self.fc = nn.Linear(512 * block.expansion, num_classes, dtype=dtype) self.reset_parameters() def reset_parameters(self) -> None: @@ -105,7 +113,15 @@ def _make_layer( downsample = torch.nn.Sequential( collections.OrderedDict( [ - ('conv', conv1x1(self.inplanes, planes * block.expansion, stride)), + ( + 'conv', + conv1x1( + self.inplanes, + planes * block.expansion, + stride, + dtype=self.dtype, + ), + ), ('bn', norm_layer(planes * block.expansion)), ] ) @@ -122,6 +138,7 @@ def _make_layer( self.base_width, previous_dilation, norm_layer, + dtype=self.dtype, ) ) self.inplanes = planes * block.expansion @@ -134,6 +151,7 @@ def _make_layer( base_width=self.base_width, dilation=self.dilation, norm_layer=norm_layer, + dtype=self.dtype, ) ) diff --git a/algoperf/workloads/cifar/cifar_pytorch/workload.py b/algoperf/workloads/cifar/cifar_pytorch/workload.py index a6e8569cc..141bef922 100644 --- a/algoperf/workloads/cifar/cifar_pytorch/workload.py +++ b/algoperf/workloads/cifar/cifar_pytorch/workload.py @@ -25,6 +25,8 @@ def __init__(self, *args, **kwargs) -> None: # Is set in submission_runner.py for workloads with PyTorch evaluation # data loaders via the `eval_num_workers` property. self._eval_num_workers = None + self._param_dtype_pt = spec.PYTORCH_DTYPE_MAP[self._param_dtype] + self._compute_dtype_pt = spec.PYTORCH_DTYPE_MAP[self._compute_dtype] @property def eval_num_workers(self) -> int: @@ -128,7 +130,9 @@ def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState: return self._model, None torch.random.manual_seed(rng[0]) - self._model = resnet18(num_classes=self._num_classes) + self._model = resnet18( + num_classes=self._num_classes, dtype=self._param_dtype_pt + ) self._param_shapes = param_utils.pytorch_param_shapes(self._model) self._param_types = param_utils.pytorch_param_types(self._param_shapes) self._model.to(DEVICE) @@ -175,7 +179,8 @@ def model_fn( spec.ForwardPassMode.TRAIN: contextlib.nullcontext, } with contexts[mode](): - logits_batch = model(augmented_and_preprocessed_input_batch['inputs']) + with torch.autocast(device_type='cuda', dtype=self._compute_dtype_pt): + logits_batch = model(augmented_and_preprocessed_input_batch['inputs']) return logits_batch, None # Does NOT apply regularization, which is left to the submitter to do in diff --git a/algoperf/workloads/cifar/workload.py b/algoperf/workloads/cifar/workload.py index 31636807c..6866bc918 100644 --- a/algoperf/workloads/cifar/workload.py +++ b/algoperf/workloads/cifar/workload.py @@ -16,6 +16,8 @@ class BaseCifarWorkload(spec.Workload): _num_classes: int = 10 + _compute_dtype: spec.DTYPE = spec.DTYPE.BFLOAT16 + _param_dtype: spec.DTYPE = spec.DTYPE.FLOAT32 @property def target_metric_name(self) -> str: diff --git a/algoperf/workloads/imagenet_resnet/imagenet_jax/models.py b/algoperf/workloads/imagenet_resnet/imagenet_jax/models.py index ee1ddf427..41551d4d2 100644 --- a/algoperf/workloads/imagenet_resnet/imagenet_jax/models.py +++ b/algoperf/workloads/imagenet_resnet/imagenet_jax/models.py @@ -90,7 +90,7 @@ def __call__( update_batch_norm: bool = True, use_running_average_bn: Optional[bool] = None, ) -> spec.Tensor: - conv = functools.partial(nn.Conv, use_bias=False, dtype=self.dtype) + conv = functools.partial(nn.Conv, use_bias=False, param_dtype=self.dtype) # Preserve default behavior for backwards compatibility if use_running_average_bn is None: use_running_average_bn = not update_batch_norm @@ -99,7 +99,7 @@ def __call__( use_running_average=use_running_average_bn, momentum=0.9, epsilon=1e-5, - dtype=self.dtype, + param_dtype=self.dtype, ) x = conv( @@ -125,7 +125,9 @@ def __call__( )(x) x = jnp.mean(x, axis=(1, 2)) x = nn.Dense( - self.num_classes, kernel_init=nn.initializers.normal(), dtype=self.dtype + self.num_classes, + kernel_init=nn.initializers.normal(), + param_dtype=self.dtype, )(x) return x diff --git a/algoperf/workloads/imagenet_resnet/imagenet_jax/workload.py b/algoperf/workloads/imagenet_resnet/imagenet_jax/workload.py index f73a1b26e..d7a8ede67 100644 --- a/algoperf/workloads/imagenet_resnet/imagenet_jax/workload.py +++ b/algoperf/workloads/imagenet_resnet/imagenet_jax/workload.py @@ -11,6 +11,7 @@ import jax import jax.numpy as jnp +import jmp import optax import tensorflow_datasets as tfds from flax import linen as nn @@ -29,6 +30,17 @@ class ImagenetResNetWorkload(BaseImagenetResNetWorkload): + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + compute_dtype = spec.JAX_DTYPE_MAP[self._compute_dtype] + param_dtype = spec.JAX_DTYPE_MAP[self._param_dtype] + output_dtype = compute_dtype + self._mp_policy = jmp.Policy( + compute_dtype=compute_dtype, + param_dtype=param_dtype, + output_dtype=output_dtype, + ) + def _build_dataset( self, data_rng: spec.RandomState, @@ -89,11 +101,12 @@ def init_model_fn( else: act_fnc = nn.relu + param_dtype = spec.JAX_DTYPE_MAP[self._param_dtype] model = model_cls( num_classes=self._num_classes, act=act_fnc, bn_init_scale=self.bn_init_scale, - dtype=jnp.float32, + dtype=param_dtype, ) self._model = model input_shape = (1, 224, 224, 3) @@ -159,25 +172,28 @@ def model_fn( del mode del rng del dropout_rate + params, inputs = self._mp_policy.cast_to_compute( + (params, augmented_and_preprocessed_input_batch['inputs']) + ) variables = {'params': params, **model_state} if update_batch_norm: - logits, new_model_state = self._model.apply( + logits, model_state = self._model.apply( variables, - augmented_and_preprocessed_input_batch['inputs'], + inputs, update_batch_norm=update_batch_norm, mutable=['batch_stats'], use_running_average_bn=use_running_average_bn, ) - return logits, new_model_state else: logits = self._model.apply( variables, - augmented_and_preprocessed_input_batch['inputs'], + inputs, update_batch_norm=update_batch_norm, mutable=False, use_running_average_bn=use_running_average_bn, ) - return logits, model_state + logits = self._mp_policy.cast_to_output(logits) + return logits, model_state # Does NOT apply regularization, which is left to the submitter to do in # `update_params`. diff --git a/algoperf/workloads/imagenet_resnet/imagenet_pytorch/models.py b/algoperf/workloads/imagenet_resnet/imagenet_pytorch/models.py index c980faa06..f24ba66b9 100644 --- a/algoperf/workloads/imagenet_resnet/imagenet_pytorch/models.py +++ b/algoperf/workloads/imagenet_resnet/imagenet_pytorch/models.py @@ -20,6 +20,7 @@ def conv3x3( stride: int = 1, groups: int = 1, dilation: int = 1, + dtype: torch.dtype = torch.float32, ) -> nn.Conv2d: """3x3 convolution with padding.""" return nn.Conv2d( @@ -31,13 +32,24 @@ def conv3x3( groups=groups, bias=False, dilation=dilation, + dtype=dtype, ) -def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d: +def conv1x1( + in_planes: int, + out_planes: int, + stride: int = 1, + dtype: torch.dtype = torch.float32, +) -> nn.Conv2d: """1x1 convolution.""" return nn.Conv2d( - in_planes, out_planes, kernel_size=1, stride=stride, bias=False + in_planes, + out_planes, + kernel_size=1, + stride=stride, + bias=False, + dtype=dtype, ) @@ -57,6 +69,7 @@ def __init__( dilation: int = 1, norm_layer: Optional[Callable[..., nn.Module]] = None, act_fnc: nn.Module = nn.ReLU(inplace=True), + dtype: torch.dtype = torch.float32, ) -> None: super().__init__() if norm_layer is None: @@ -67,10 +80,10 @@ def __init__( raise NotImplementedError('Dilation > 1 not supported in BasicBlock') # Both self.conv1 and self.downsample layers downsample # the input when stride != 1. - self.conv1 = conv3x3(inplanes, planes, stride) + self.conv1 = conv3x3(inplanes, planes, stride, dtype=dtype) self.bn1 = norm_layer(planes) self.act_fnc = act_fnc - self.conv2 = conv3x3(planes, planes) + self.conv2 = conv3x3(planes, planes, dtype=dtype) self.bn2 = norm_layer(planes) self.downsample = downsample self.stride = stride @@ -110,6 +123,7 @@ def __init__( dilation: int = 1, norm_layer: Optional[Callable[..., nn.Module]] = None, act_fnc: nn.Module = nn.ReLU(inplace=True), + dtype: torch.dtype = torch.float32, ) -> None: super().__init__() if norm_layer is None: @@ -117,11 +131,11 @@ def __init__( width = int(planes * (base_width / 64.0)) * groups # Both self.conv2 and self.downsample layers downsample # the input when stride != 1. - self.conv1 = conv1x1(inplanes, width) + self.conv1 = conv1x1(inplanes, width, dtype=dtype) self.bn1 = norm_layer(width) - self.conv2 = conv3x3(width, width, stride, groups, dilation) + self.conv2 = conv3x3(width, width, stride, groups, dilation, dtype=dtype) self.bn2 = norm_layer(width) - self.conv3 = conv1x1(width, planes * self.expansion) + self.conv3 = conv1x1(width, planes * self.expansion, dtype=dtype) self.bn3 = norm_layer(planes * self.expansion) self.act_fnc = act_fnc self.downsample = downsample @@ -163,11 +177,13 @@ def __init__( norm_layer: Optional[Callable[..., nn.Module]] = None, act_fnc: nn.Module = nn.ReLU(inplace=True), bn_init_scale: float = 0.0, + dtype: torch.dtype = torch.float32, ) -> None: super().__init__() if norm_layer is None: norm_layer = nn.BatchNorm2d self._norm_layer = norm_layer + self.dtype = dtype self.inplanes = 64 self.dilation = 1 @@ -183,7 +199,13 @@ def __init__( self.groups = groups self.base_width = width_per_group self.conv1 = nn.Conv2d( - 3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False + 3, + self.inplanes, + kernel_size=7, + stride=2, + padding=3, + bias=False, + dtype=dtype, ) self.bn1 = norm_layer(self.inplanes) self.act_fnc = act_fnc @@ -214,7 +236,7 @@ def __init__( dilate=replace_stride_with_dilation[2], ) self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) - self.fc = nn.Linear(512 * block.expansion, num_classes) + self.fc = nn.Linear(512 * block.expansion, num_classes, dtype=dtype) for m in self.modules(): if isinstance(m, nn.Conv2d): @@ -256,7 +278,15 @@ def _make_layer( downsample = torch.nn.Sequential( collections.OrderedDict( [ - ('conv', conv1x1(self.inplanes, planes * block.expansion, stride)), + ( + 'conv', + conv1x1( + self.inplanes, + planes * block.expansion, + stride, + dtype=self.dtype, + ), + ), ('bn', norm_layer(planes * block.expansion)), ] ) @@ -274,6 +304,7 @@ def _make_layer( previous_dilation, norm_layer, act_fnc, + dtype=self.dtype, ) ) self.inplanes = planes * block.expansion @@ -287,6 +318,7 @@ def _make_layer( dilation=self.dilation, norm_layer=norm_layer, act_fnc=act_fnc, + dtype=self.dtype, ) ) diff --git a/algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py b/algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py index d5366c60d..3a88245ae 100644 --- a/algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py +++ b/algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py @@ -178,7 +178,10 @@ def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState: else: act_fnc = torch.nn.ReLU(inplace=True) - model = resnet50(act_fnc=act_fnc, bn_init_scale=self.bn_init_scale) + param_dtype = spec.PYTORCH_DTYPE_MAP[self._param_dtype] + model = resnet50( + act_fnc=act_fnc, bn_init_scale=self.bn_init_scale, dtype=param_dtype + ) self._param_shapes = param_utils.pytorch_param_shapes(model) self._param_types = param_utils.pytorch_param_types(self._param_shapes) model.to(DEVICE) @@ -229,8 +232,10 @@ def model_fn( spec.ForwardPassMode.TRAIN: contextlib.nullcontext, } + compute_dtype = spec.PYTORCH_DTYPE_MAP[self._compute_dtype] with contexts[mode](): - logits_batch = model(augmented_and_preprocessed_input_batch['inputs']) + with torch.autocast(device_type='cuda', dtype=compute_dtype): + logits_batch = model(augmented_and_preprocessed_input_batch['inputs']) return logits_batch, None diff --git a/algoperf/workloads/imagenet_resnet/workload.py b/algoperf/workloads/imagenet_resnet/workload.py index de8458c92..bc5982f1d 100644 --- a/algoperf/workloads/imagenet_resnet/workload.py +++ b/algoperf/workloads/imagenet_resnet/workload.py @@ -8,6 +8,8 @@ class BaseImagenetResNetWorkload(spec.Workload): _num_classes: int = 1000 + _compute_dtype: spec.DTYPE = spec.DTYPE.BFLOAT16 + _param_dtype: spec.DTYPE = spec.DTYPE.FLOAT32 @property def target_metric_name(self) -> str: diff --git a/algoperf/workloads/imagenet_vit/imagenet_jax/models.py b/algoperf/workloads/imagenet_vit/imagenet_jax/models.py index e86233011..2e4630701 100644 --- a/algoperf/workloads/imagenet_vit/imagenet_jax/models.py +++ b/algoperf/workloads/imagenet_vit/imagenet_jax/models.py @@ -42,6 +42,7 @@ class MlpBlock(nn.Module): mlp_dim: Optional[int] = None # Defaults to 4x input dim. use_glu: bool = False dropout_rate: float = DROPOUT_RATE + dtype: jnp.dtype = jnp.float32 @nn.compact def __call__( @@ -54,15 +55,15 @@ def __call__( } d = x.shape[2] - x = nn.Dense(self.mlp_dim or 4 * d, **inits)(x) + x = nn.Dense(self.mlp_dim or 4 * d, param_dtype=self.dtype, **inits)(x) x = nn.gelu(x) if self.use_glu: - y = nn.Dense(self.mlp_dim, **inits)(x) + y = nn.Dense(self.mlp_dim, param_dtype=self.dtype, **inits)(x) x = x * y x = Dropout(dropout_rate)(x, train, rate=dropout_rate) - x = nn.Dense(d, **inits)(x) + x = nn.Dense(d, param_dtype=self.dtype, **inits)(x) return x @@ -74,25 +75,30 @@ class Encoder1DBlock(nn.Module): use_glu: bool = False use_post_layer_norm: bool = False dropout_rate: float = 0.0 + dtype: jnp.dtype = jnp.float32 @nn.compact def __call__( self, x: spec.Tensor, train: bool = True, dropout_rate=dropout_rate ) -> spec.Tensor: if not self.use_post_layer_norm: - y = nn.LayerNorm(name='LayerNorm_0')(x) + y = nn.LayerNorm(name='LayerNorm_0', param_dtype=self.dtype)(x) y = nn.MultiHeadDotProductAttention( num_heads=self.num_heads, kernel_init=nn.initializers.xavier_uniform(), deterministic=train, name='MultiHeadDotProductAttention_1', + param_dtype=self.dtype, )(y) y = Dropout(dropout_rate)(y, train, rate=dropout_rate) x = x + y - y = nn.LayerNorm(name='LayerNorm_2')(x) + y = nn.LayerNorm(name='LayerNorm_2', param_dtype=self.dtype)(x) y = MlpBlock( - mlp_dim=self.mlp_dim, use_glu=self.use_glu, name='MlpBlock_3' + mlp_dim=self.mlp_dim, + use_glu=self.use_glu, + dtype=self.dtype, + name='MlpBlock_3', )(y, train, dropout_rate=dropout_rate) y = Dropout(dropout_rate)(y, train, rate=dropout_rate) x = x + y @@ -103,21 +109,23 @@ def __call__( kernel_init=nn.initializers.xavier_uniform(), deterministic=train, name='MultiHeadDotProductAttention_1', + param_dtype=self.dtype, )(y) y = Dropout(dropout_rate)(y, train, rate=dropout_rate) x = x + y - x = nn.LayerNorm(name='LayerNorm_0')(x) + x = nn.LayerNorm(name='LayerNorm_0', param_dtype=self.dtype)(x) y = x y = MlpBlock( mlp_dim=self.mlp_dim, use_glu=self.use_glu, + dtype=self.dtype, name='MlpBlock_3', dropout_rate=dropout_rate, )(y, train, dropout_rate=dropout_rate) y = Dropout(dropout_rate)(y, train)(rate=dropout_rate) x = x + y - x = nn.LayerNorm(name='LayerNorm_2')(x) + x = nn.LayerNorm(name='LayerNorm_2', param_dtype=self.dtype)(x) return x @@ -130,6 +138,7 @@ class Encoder(nn.Module): num_heads: int = 12 use_glu: bool = False use_post_layer_norm: bool = False + dtype: jnp.dtype = jnp.float32 @nn.compact def __call__( @@ -143,9 +152,10 @@ def __call__( num_heads=self.num_heads, use_glu=self.use_glu, use_post_layer_norm=self.use_post_layer_norm, + dtype=self.dtype, )(x, train=train, dropout_rate=dropout_rate) if not self.use_post_layer_norm: - return nn.LayerNorm(name='encoder_layernorm')(x) + return nn.LayerNorm(name='encoder_layernorm', param_dtype=self.dtype)(x) else: return x @@ -156,12 +166,13 @@ class MAPHead(nn.Module): mlp_dim: Optional[int] = None # Defaults to 4x input dim num_heads: int = 12 dropout_rate: float = 0.0 + dtype: jnp.dtype = jnp.float32 @nn.compact def __call__(self, x, dropout_rate=DROPOUT_RATE): n, _, d = x.shape probe = self.param( - 'probe', nn.initializers.xavier_uniform(), (1, 1, d), x.dtype + 'probe', nn.initializers.xavier_uniform(), (1, 1, d), self.dtype ) probe = jnp.tile(probe, [n, 1, 1]) @@ -169,10 +180,13 @@ def __call__(self, x, dropout_rate=DROPOUT_RATE): num_heads=self.num_heads, use_bias=True, kernel_init=nn.initializers.xavier_uniform(), + param_dtype=self.dtype, )(probe, x) - y = nn.LayerNorm()(x) - x = x + MlpBlock(mlp_dim=self.mlp_dim, dropout_rate=dropout_rate)(y) + y = nn.LayerNorm(param_dtype=self.dtype)(x) + x = x + MlpBlock( + mlp_dim=self.mlp_dim, dropout_rate=dropout_rate, dtype=self.dtype + )(y) return x[:, 0] @@ -192,6 +206,7 @@ class ViT(nn.Module): use_glu: bool = False use_post_layer_norm: bool = False use_map: bool = False + dtype: jnp.dtype = jnp.float32 def get_posemb( self, seqshape: tuple, width: int, dtype: jnp.dtype = jnp.float32 @@ -209,6 +224,7 @@ def __call__( strides=self.patch_size, padding='VALID', name='conv_patch_extract', + param_dtype=self.dtype, )(x) n, h, w, c = x.shape @@ -225,6 +241,7 @@ def __call__( num_heads=self.num_heads, use_glu=self.use_glu, use_post_layer_norm=self.use_post_layer_norm, + dtype=self.dtype, name='Transformer', )(x, train=not train, dropout_rate=dropout_rate) @@ -233,18 +250,21 @@ def __call__( num_heads=self.num_heads, mlp_dim=self.mlp_dim, dropout_rate=dropout_rate, + dtype=self.dtype, )(x, dropout_rate=dropout_rate) else: x = jnp.mean(x, axis=1) if self.rep_size: rep_size = self.width if self.rep_size is True else self.rep_size - hid = nn.Dense(rep_size, name='pre_logits') + hid = nn.Dense(rep_size, name='pre_logits', param_dtype=self.dtype) x = nn.tanh(hid(x)) if self.num_classes: kw = {'kernel_init': nn.initializers.zeros} if self.head_zeroinit else {} - head = nn.Dense(self.num_classes, name='head', **kw) + head = nn.Dense( + self.num_classes, name='head', param_dtype=self.dtype, **kw + ) x = head(x) return x diff --git a/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py b/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py index 8a33aeb47..6819a4862 100644 --- a/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py +++ b/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py @@ -32,11 +32,13 @@ def initialized( return params, model_state def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState: + param_dtype = spec.JAX_DTYPE_MAP[self._param_dtype] self._model = models.ViT( num_classes=self._num_classes, use_glu=self.use_glu, use_post_layer_norm=self.use_post_layer_norm, use_map=self.use_map, + dtype=param_dtype, **decode_variant('S/16'), ) params, model_state = self.initialized(rng, self._model) @@ -62,15 +64,19 @@ def model_fn( ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: del model_state del update_batch_norm - del use_running_average_bn + # Cast params and inputs to compute dtype + params, inputs = self._mp_policy.cast_to_compute( + (params, augmented_and_preprocessed_input_batch['inputs']) + ) train = mode == spec.ForwardPassMode.TRAIN logits = self._model.apply( {'params': params}, - augmented_and_preprocessed_input_batch['inputs'], + inputs, rngs={'dropout': rng}, train=train, dropout_rate=dropout_rate, ) + logits = self._mp_policy.cast_to_output(logits) return logits, None def _eval_model_on_split( diff --git a/algoperf/workloads/imagenet_vit/imagenet_pytorch/models.py b/algoperf/workloads/imagenet_vit/imagenet_pytorch/models.py index fc2a3cd46..6dfb5fddf 100644 --- a/algoperf/workloads/imagenet_vit/imagenet_pytorch/models.py +++ b/algoperf/workloads/imagenet_vit/imagenet_pytorch/models.py @@ -46,22 +46,24 @@ def __init__( width: int, mlp_dim: Optional[int] = None, # Defaults to 4x input dim. use_glu: bool = False, + dtype: Any = torch.float32, ) -> None: super().__init__() self.width = width self.mlp_dim = mlp_dim or 4 * width self.use_glu = use_glu + self.dtype = dtype - self.linear1 = nn.Linear(self.width, self.mlp_dim) + self.linear1 = nn.Linear(self.width, self.mlp_dim, dtype=self.dtype) self.act_fnc = nn.GELU(approximate='tanh') if self.use_glu: - self.glu_linear = nn.Linear(self.mlp_dim, self.mlp_dim) + self.glu_linear = nn.Linear(self.mlp_dim, self.mlp_dim, dtype=self.dtype) else: self.glu_linear = None - self.linear2 = nn.Linear(self.mlp_dim, self.width) + self.linear2 = nn.Linear(self.mlp_dim, self.width, dtype=self.dtype) self.reset_parameters() @@ -85,14 +87,18 @@ def forward(self, x: spec.Tensor, dropout_rate: float) -> spec.Tensor: return x +# TODO(rka97): switch this to built-in attention with cudnn class SelfAttention(nn.Module): """Self-attention special case of multi-head dot-product attention.""" - def __init__(self, width: int, num_heads: int = 8) -> None: + def __init__( + self, width: int, num_heads: int = 8, dtype: Any = torch.float32 + ) -> None: super().__init__() self.width = width self.num_heads = num_heads + self.dtype = dtype assert width % num_heads == 0, ( 'Memory dimension must be divisible by number of heads.' @@ -101,10 +107,10 @@ def __init__(self, width: int, num_heads: int = 8) -> None: self.head_dim = int(width / num_heads) self.all_head_dim = self.num_heads * self.head_dim - self.query = nn.Linear(self.width, self.all_head_dim) - self.key = nn.Linear(self.width, self.all_head_dim) - self.value = nn.Linear(self.width, self.all_head_dim) - self.out = nn.Linear(self.width, self.width) + self.query = nn.Linear(self.width, self.all_head_dim, dtype=self.dtype) + self.key = nn.Linear(self.width, self.all_head_dim, dtype=self.dtype) + self.value = nn.Linear(self.width, self.all_head_dim, dtype=self.dtype) + self.out = nn.Linear(self.width, self.width, dtype=self.dtype) self.reset_parameters() def reset_parameters(self) -> None: @@ -150,6 +156,7 @@ def __init__( num_heads: int = 12, use_glu: bool = False, use_post_layer_norm: bool = False, + dtype: Any = torch.float32, ) -> None: super().__init__() @@ -158,12 +165,18 @@ def __init__( self.num_heads = num_heads self.use_glu = use_glu self.use_post_layer_norm = use_post_layer_norm + self.dtype = dtype - self.layer_norm0 = nn.LayerNorm(self.width, eps=1e-6) - self.self_attention1 = SelfAttention(self.width, self.num_heads) - self.layer_norm2 = nn.LayerNorm(self.width, eps=1e-6) + self.layer_norm0 = nn.LayerNorm(self.width, eps=1e-6, dtype=self.dtype) + self.self_attention1 = SelfAttention( + self.width, self.num_heads, dtype=self.dtype + ) + self.layer_norm2 = nn.LayerNorm(self.width, eps=1e-6, dtype=self.dtype) self.mlp3 = MlpBlock( - width=self.width, mlp_dim=self.mlp_dim, use_glu=self.use_glu + width=self.width, + mlp_dim=self.mlp_dim, + use_glu=self.use_glu, + dtype=self.dtype, ) def forward(self, x: spec.Tensor, dropout_rate: float) -> spec.Tensor: @@ -203,6 +216,7 @@ def __init__( num_heads: int = 12, use_glu: bool = False, use_post_layer_norm: bool = False, + dtype: Any = torch.float32, ) -> None: super().__init__() @@ -212,6 +226,7 @@ def __init__( self.num_heads = num_heads self.use_glu = use_glu self.use_post_layer_norm = use_post_layer_norm + self.dtype = dtype self.net = nn.ModuleList( [ @@ -221,13 +236,14 @@ def __init__( self.num_heads, self.use_glu, self.use_post_layer_norm, + dtype=self.dtype, ) for _ in range(depth) ] ) if not self.use_post_layer_norm: - self.encoder_norm = nn.LayerNorm(self.width, eps=1e-6) + self.encoder_norm = nn.LayerNorm(self.width, eps=1e-6, dtype=self.dtype) else: self.encoder_norm = None @@ -245,21 +261,32 @@ class MAPHead(nn.Module): """Multihead Attention Pooling.""" def __init__( - self, width: int, mlp_dim: Optional[int] = None, num_heads: int = 12 + self, + width: int, + mlp_dim: Optional[int] = None, + num_heads: int = 12, + dtype: torch.dtype = torch.float32, ): super().__init__() self.width = width self.mlp_dim = mlp_dim self.num_heads = num_heads + self.dtype = dtype self.probe = nn.Parameter(torch.zeros((1, 1, self.width))) nn.init.xavier_uniform_(self.probe.data) self.mha = MultiheadAttention( - self.width, num_heads=self.num_heads, self_attn=False, bias=True + self.width, + num_heads=self.num_heads, + self_attn=False, + bias=True, + dtype=self.dtype, + ) + self.layer_norm = nn.LayerNorm(self.width, eps=1e-6, dtype=self.dtype) + self.mlp = MlpBlock( + width=self.width, mlp_dim=self.mlp_dim, dtype=self.dtype ) - self.layer_norm = nn.LayerNorm(self.width, eps=1e-6) - self.mlp = MlpBlock(width=self.width, mlp_dim=self.mlp_dim) def forward(self, x: spec.Tensor, dropout_rate: float) -> spec.Tensor: n, _, _ = x.shape @@ -310,7 +337,7 @@ def __init__( if self.rep_size: rep_size = self.width if self.rep_size is True else self.rep_size - self.pre_logits = nn.Linear(self.width, rep_size) + self.pre_logits = nn.Linear(self.width, rep_size, dtype=self.dtype) self.conv_patch_extract = nn.Conv2d( self.channels, @@ -318,6 +345,7 @@ def __init__( self.patch_size, stride=self.patch_size, padding='valid', + dtype=self.dtype, ) self.encoder = Encoder( @@ -327,13 +355,16 @@ def __init__( num_heads=self.num_heads, use_glu=self.use_glu, use_post_layer_norm=self.use_post_layer_norm, + dtype=self.dtype, ) if self.num_classes: - self.head = nn.Linear(self.width, self.num_classes) + self.head = nn.Linear(self.width, self.num_classes, dtype=self.dtype) if self.use_map: - self.map = MAPHead(self.width, self.mlp_dim, self.num_heads) + self.map = MAPHead( + self.width, self.mlp_dim, self.num_heads, dtype=self.dtype + ) else: self.map = None diff --git a/algoperf/workloads/imagenet_vit/imagenet_pytorch/workload.py b/algoperf/workloads/imagenet_vit/imagenet_pytorch/workload.py index 9c6faf70b..bfef3e0a9 100644 --- a/algoperf/workloads/imagenet_vit/imagenet_pytorch/workload.py +++ b/algoperf/workloads/imagenet_vit/imagenet_pytorch/workload.py @@ -23,11 +23,13 @@ class ImagenetVitWorkload(BaseImagenetVitWorkload, ImagenetResNetWorkload): def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState: torch.random.manual_seed(rng[0]) + param_dtype = spec.PYTORCH_DTYPE_MAP[self._param_dtype] model = models.ViT( num_classes=self._num_classes, use_glu=self.use_glu, use_post_layer_norm=self.use_post_layer_norm, use_map=self.use_map, + dtype=param_dtype, **decode_variant('S/16'), ) self._param_shapes = param_utils.pytorch_param_shapes(model) @@ -70,11 +72,13 @@ def model_fn( spec.ForwardPassMode.TRAIN: contextlib.nullcontext, } + compute_dtype = spec.PYTORCH_DTYPE_MAP[self._compute_dtype] with contexts[mode](): - logits_batch = model( - augmented_and_preprocessed_input_batch['inputs'], - dropout_rate=dropout_rate, - ) + with torch.autocast(device_type='cuda', dtype=compute_dtype): + logits_batch = model( + augmented_and_preprocessed_input_batch['inputs'], + dropout_rate=dropout_rate, + ) return logits_batch, None diff --git a/algoperf/workloads/ogbg/workload.py b/algoperf/workloads/ogbg/workload.py index 002576268..771b103a0 100644 --- a/algoperf/workloads/ogbg/workload.py +++ b/algoperf/workloads/ogbg/workload.py @@ -92,7 +92,7 @@ def max_allowed_runtime_sec(self) -> int: @property def eval_period_time_sec(self) -> int: - return 452 # approx 25 evals + return 452 # approx 25 evals def _build_input_queue( self, diff --git a/algorithms/baselines/external_tuning/jax_nadamw_full_budget.py b/algorithms/baselines/external_tuning/jax_nadamw_full_budget.py index 0577cd4e0..a6f36fd30 100644 --- a/algorithms/baselines/external_tuning/jax_nadamw_full_budget.py +++ b/algorithms/baselines/external_tuning/jax_nadamw_full_budget.py @@ -396,6 +396,8 @@ def get_batch_size(workload_name): return 128 elif workload_name == 'mnist': return 16 + elif workload_name == 'cifar': + return 16384 else: raise ValueError(f'Unsupported workload name: {workload_name}.') diff --git a/algorithms/baselines/external_tuning/pytorch_nadamw_full_budget.py b/algorithms/baselines/external_tuning/pytorch_nadamw_full_budget.py index 0b32199ba..285727885 100644 --- a/algorithms/baselines/external_tuning/pytorch_nadamw_full_budget.py +++ b/algorithms/baselines/external_tuning/pytorch_nadamw_full_budget.py @@ -5,7 +5,6 @@ import torch import torch.distributed.nn as dist_nn -from absl import logging from torch import Tensor from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR @@ -315,13 +314,6 @@ def update_params( }, global_step, ) - logging.info( - '%d) loss = %0.3f, grad_norm = %0.3f', - global_step, - loss.item(), - grad_norm.item(), - ) - return (optimizer_state, current_param_container, new_model_state) @@ -372,6 +364,8 @@ def get_batch_size(workload_name): return 128 elif workload_name == 'mnist': return 16 + elif workload_name == 'cifar': + return 16384 else: raise ValueError(f'Unsupported workload name: {workload_name}.') diff --git a/scoring/performance_profile.py b/scoring/performance_profile.py index b200c6865..043a65791 100644 --- a/scoring/performance_profile.py +++ b/scoring/performance_profile.py @@ -71,7 +71,7 @@ 'wer', 'l1_loss', 'loss', - 'ppl' + 'ppl', ] MAX_EVAL_METRICS = ['mean_average_precision', 'ssim', 'accuracy', 'bleu'] diff --git a/submission_runner.py b/submission_runner.py index 552c99b79..84ae3307b 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -266,6 +266,7 @@ def train_once( 'librispeech_deepspeech', 'ogbg', 'wmt', + 'cifar', ] base_workload = workloads.get_base_workload_name(workload_name) if base_workload in compile_error_workloads: @@ -409,10 +410,15 @@ def train_once( train_state['training_complete'] = True train_step_end_time = get_time() - - train_state['accumulated_submission_time'] += ( - train_step_end_time - train_state['last_step_end_time'] - ) + step_time = train_step_end_time - train_state['last_step_end_time'] + train_state['accumulated_submission_time'] += step_time + # Log training progress periodically + if global_step % 10 == 0: + logging.info( + f'Step: {global_step}, ' + f'\tLast step time: {step_time:.4f}s, ' + f'\tTotal time: {train_state["accumulated_submission_time"]:.2f}s' + ) # Check if submission is eligible for an untimed eval. if ( @@ -512,10 +518,19 @@ def train_once( latest_eval_result['accumulated_logging_time'] = train_state[ 'accumulated_logging_time' ] + # Calculate average per-step time + avg_per_step_time = ( + train_state['accumulated_submission_time'] / global_step + if global_step > 0 + else 0.0 + ) + latest_eval_result['avg_per_step_time'] = avg_per_step_time time_since_start = latest_eval_result['total_duration'] logging.info( f'Time since start: {time_since_start:.2f}s, ' - f'\tStep: {global_step}, \t{latest_eval_result}' + f'\tStep: {global_step}, ' + f'\tAvg per-step time: {avg_per_step_time:.4f}s, ' + f'\t{latest_eval_result}' ) eval_results.append((global_step, latest_eval_result)) From 68060195b8d3aa79848d32bd5d0ea8040a634b18 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 11 Dec 2025 03:15:03 +0000 Subject: [PATCH 07/17] Revert "ImageNet and CIFAR mixed-precision support, need to debug slow pytorch" This reverts commit 6f7d638adc190d9bce3f30ba3314c27dac1a8cc5. --- algoperf/spec.py | 23 ------ .../cifar/cifar_jax/input_pipeline.py | 2 + algoperf/workloads/cifar/cifar_jax/models.py | 8 +-- .../workloads/cifar/cifar_jax/workload.py | 29 ++------ .../workloads/cifar/cifar_pytorch/models.py | 24 +------ .../workloads/cifar/cifar_pytorch/workload.py | 9 +-- algoperf/workloads/cifar/workload.py | 2 - .../imagenet_resnet/imagenet_jax/models.py | 8 +-- .../imagenet_resnet/imagenet_jax/workload.py | 28 ++------ .../imagenet_pytorch/models.py | 52 +++----------- .../imagenet_pytorch/workload.py | 9 +-- .../workloads/imagenet_resnet/workload.py | 2 - .../imagenet_vit/imagenet_jax/models.py | 48 ++++--------- .../imagenet_vit/imagenet_jax/workload.py | 10 +-- .../imagenet_vit/imagenet_pytorch/models.py | 71 ++++++------------- .../imagenet_vit/imagenet_pytorch/workload.py | 12 ++-- algoperf/workloads/ogbg/workload.py | 2 +- .../external_tuning/jax_nadamw_full_budget.py | 2 - .../pytorch_nadamw_full_budget.py | 10 ++- scoring/performance_profile.py | 2 +- submission_runner.py | 25 ++----- 21 files changed, 90 insertions(+), 288 deletions(-) diff --git a/algoperf/spec.py b/algoperf/spec.py index 8dd00345c..b86e55954 100644 --- a/algoperf/spec.py +++ b/algoperf/spec.py @@ -6,34 +6,11 @@ from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, Union import jax -import jax.numpy as jnp -import torch import torch.nn.functional as F from absl import logging from torch import nn -class DTYPE(enum.Enum): - FLOAT32 = 0 - FLOAT16 = 1 - BFLOAT16 = 2 - - -# Mapping from DTYPE enum to JAX dtypes -JAX_DTYPE_MAP = { - DTYPE.FLOAT32: jnp.float32, - DTYPE.FLOAT16: jnp.float16, - DTYPE.BFLOAT16: jnp.bfloat16, -} - -# Mapping from DTYPE enum to PyTorch dtypes -PYTORCH_DTYPE_MAP = { - DTYPE.FLOAT32: torch.float32, - DTYPE.FLOAT16: torch.float16, - DTYPE.BFLOAT16: torch.bfloat16, -} - - class LossType(enum.Enum): SOFTMAX_CROSS_ENTROPY = 0 SIGMOID_CROSS_ENTROPY = 1 diff --git a/algoperf/workloads/cifar/cifar_jax/input_pipeline.py b/algoperf/workloads/cifar/cifar_jax/input_pipeline.py index 307e9e705..7fbc95bc6 100644 --- a/algoperf/workloads/cifar/cifar_jax/input_pipeline.py +++ b/algoperf/workloads/cifar/cifar_jax/input_pipeline.py @@ -11,6 +11,7 @@ import jax import tensorflow as tf import tensorflow_datasets as tfds +from flax import jax_utils from algoperf import spec from algoperf.data_utils import shard_and_maybe_pad_np @@ -185,4 +186,5 @@ def create_input_iter( ), ds, ) + it = jax_utils.prefetch_to_device(it, 2) return it diff --git a/algoperf/workloads/cifar/cifar_jax/models.py b/algoperf/workloads/cifar/cifar_jax/models.py index 9a4f7fd96..95238c997 100644 --- a/algoperf/workloads/cifar/cifar_jax/models.py +++ b/algoperf/workloads/cifar/cifar_jax/models.py @@ -31,7 +31,7 @@ def __call__( update_batch_norm: bool = True, use_running_average_bn: bool = None, ) -> spec.Tensor: - conv = functools.partial(nn.Conv, use_bias=False, param_dtype=self.dtype) + conv = functools.partial(nn.Conv, use_bias=False, dtype=self.dtype) # Preserve default behavior for backwards compatibility if use_running_average_bn is None: @@ -41,7 +41,7 @@ def __call__( use_running_average=use_running_average_bn, momentum=0.9, epsilon=1e-5, - param_dtype=self.dtype, + dtype=self.dtype, ) x = conv( @@ -66,9 +66,7 @@ def __call__( x = nn.avg_pool(x, (4, 4), strides=(4, 4)) x = jnp.mean(x, axis=(1, 2)) x = nn.Dense( - self.num_classes, - kernel_init=nn.initializers.normal(), - param_dtype=self.dtype, + self.num_classes, kernel_init=nn.initializers.normal(), dtype=self.dtype )(x) return x diff --git a/algoperf/workloads/cifar/cifar_jax/workload.py b/algoperf/workloads/cifar/cifar_jax/workload.py index e6bc5b419..defc30121 100644 --- a/algoperf/workloads/cifar/cifar_jax/workload.py +++ b/algoperf/workloads/cifar/cifar_jax/workload.py @@ -5,7 +5,6 @@ import jax import jax.numpy as jnp -import jmp import optax import tensorflow_datasets as tfds from flax import linen as nn @@ -19,17 +18,6 @@ class CifarWorkload(BaseCifarWorkload): - def __init__(self, *args, **kwargs) -> None: - super().__init__(*args, **kwargs) - compute_dtype = spec.JAX_DTYPE_MAP[self._compute_dtype] - param_dtype = spec.JAX_DTYPE_MAP[self._param_dtype] - output_dtype = compute_dtype - self._mp_policy = jmp.Policy( - compute_dtype=compute_dtype, - param_dtype=param_dtype, - output_dtype=output_dtype, - ) - def _build_cifar_dataset( self, data_rng: spec.RandomState, @@ -92,8 +80,7 @@ def sync_batch_stats( def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState: """Dropout is unused.""" model_cls = getattr(models, 'ResNet18') - param_dtype = spec.JAX_DTYPE_MAP[self._param_dtype] - model = model_cls(num_classes=self._num_classes, dtype=param_dtype) + model = model_cls(num_classes=self._num_classes, dtype=jnp.float32) self._model = model input_shape = (1, 32, 32, 3) variables = jax.jit(model.init)( @@ -102,7 +89,7 @@ def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState: model_state, params = pop(variables, 'params') self._param_shapes = param_utils.jax_param_shapes(params) self._param_types = param_utils.jax_param_types(self._param_shapes) - model_state = jax_sharding_utils.replicate(model_state) + model_state = jax_sharding_utils.replicate(params) params = jax_sharding_utils.replicate(params) return params, model_state @@ -123,32 +110,24 @@ def model_fn( del mode del rng del dropout_rate - # Cast params and inputs to compute dtype - params, inputs = self._mp_policy.cast_to_compute( - (params, augmented_and_preprocessed_input_batch['inputs']) - ) variables = {'params': params, **model_state} if update_batch_norm: logits, new_model_state = self._model.apply( variables, - inputs, + augmented_and_preprocessed_input_batch['inputs'], update_batch_norm=update_batch_norm, mutable=['batch_stats'], use_running_average_bn=use_running_average_bn, ) - # Cast logits to output dtype - logits = self._mp_policy.cast_to_output(logits) return logits, new_model_state else: logits = self._model.apply( variables, - inputs, + augmented_and_preprocessed_input_batch['inputs'], update_batch_norm=update_batch_norm, mutable=False, use_running_average_bn=use_running_average_bn, ) - # Cast logits to output dtype - logits = self._mp_policy.cast_to_output(logits) return logits, model_state # Does NOT apply regularization, which is left to the submitter to do in diff --git a/algoperf/workloads/cifar/cifar_pytorch/models.py b/algoperf/workloads/cifar/cifar_pytorch/models.py index b2b37c001..0e08f5c5a 100644 --- a/algoperf/workloads/cifar/cifar_pytorch/models.py +++ b/algoperf/workloads/cifar/cifar_pytorch/models.py @@ -29,13 +29,11 @@ def __init__( width_per_group: int = 64, replace_stride_with_dilation: Optional[List[bool]] = None, norm_layer: Optional[Callable[..., nn.Module]] = None, - dtype: torch.dtype = torch.float32, ) -> None: super().__init__() if norm_layer is None: norm_layer = nn.BatchNorm2d self._norm_layer = norm_layer - self.dtype = dtype self.inplanes = 64 self.dilation = 1 @@ -51,13 +49,7 @@ def __init__( self.groups = groups self.base_width = width_per_group self.conv1 = nn.Conv2d( - 3, - self.inplanes, - kernel_size=3, - stride=1, - padding=1, - bias=False, - dtype=dtype, + 3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False ) self.bn1 = norm_layer(self.inplanes) self.relu = nn.ReLU(inplace=True) @@ -71,7 +63,7 @@ def __init__( self.layer4 = self._make_layer( block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2] ) - self.fc = nn.Linear(512 * block.expansion, num_classes, dtype=dtype) + self.fc = nn.Linear(512 * block.expansion, num_classes) self.reset_parameters() def reset_parameters(self) -> None: @@ -113,15 +105,7 @@ def _make_layer( downsample = torch.nn.Sequential( collections.OrderedDict( [ - ( - 'conv', - conv1x1( - self.inplanes, - planes * block.expansion, - stride, - dtype=self.dtype, - ), - ), + ('conv', conv1x1(self.inplanes, planes * block.expansion, stride)), ('bn', norm_layer(planes * block.expansion)), ] ) @@ -138,7 +122,6 @@ def _make_layer( self.base_width, previous_dilation, norm_layer, - dtype=self.dtype, ) ) self.inplanes = planes * block.expansion @@ -151,7 +134,6 @@ def _make_layer( base_width=self.base_width, dilation=self.dilation, norm_layer=norm_layer, - dtype=self.dtype, ) ) diff --git a/algoperf/workloads/cifar/cifar_pytorch/workload.py b/algoperf/workloads/cifar/cifar_pytorch/workload.py index 141bef922..a6e8569cc 100644 --- a/algoperf/workloads/cifar/cifar_pytorch/workload.py +++ b/algoperf/workloads/cifar/cifar_pytorch/workload.py @@ -25,8 +25,6 @@ def __init__(self, *args, **kwargs) -> None: # Is set in submission_runner.py for workloads with PyTorch evaluation # data loaders via the `eval_num_workers` property. self._eval_num_workers = None - self._param_dtype_pt = spec.PYTORCH_DTYPE_MAP[self._param_dtype] - self._compute_dtype_pt = spec.PYTORCH_DTYPE_MAP[self._compute_dtype] @property def eval_num_workers(self) -> int: @@ -130,9 +128,7 @@ def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState: return self._model, None torch.random.manual_seed(rng[0]) - self._model = resnet18( - num_classes=self._num_classes, dtype=self._param_dtype_pt - ) + self._model = resnet18(num_classes=self._num_classes) self._param_shapes = param_utils.pytorch_param_shapes(self._model) self._param_types = param_utils.pytorch_param_types(self._param_shapes) self._model.to(DEVICE) @@ -179,8 +175,7 @@ def model_fn( spec.ForwardPassMode.TRAIN: contextlib.nullcontext, } with contexts[mode](): - with torch.autocast(device_type='cuda', dtype=self._compute_dtype_pt): - logits_batch = model(augmented_and_preprocessed_input_batch['inputs']) + logits_batch = model(augmented_and_preprocessed_input_batch['inputs']) return logits_batch, None # Does NOT apply regularization, which is left to the submitter to do in diff --git a/algoperf/workloads/cifar/workload.py b/algoperf/workloads/cifar/workload.py index 6866bc918..31636807c 100644 --- a/algoperf/workloads/cifar/workload.py +++ b/algoperf/workloads/cifar/workload.py @@ -16,8 +16,6 @@ class BaseCifarWorkload(spec.Workload): _num_classes: int = 10 - _compute_dtype: spec.DTYPE = spec.DTYPE.BFLOAT16 - _param_dtype: spec.DTYPE = spec.DTYPE.FLOAT32 @property def target_metric_name(self) -> str: diff --git a/algoperf/workloads/imagenet_resnet/imagenet_jax/models.py b/algoperf/workloads/imagenet_resnet/imagenet_jax/models.py index 41551d4d2..ee1ddf427 100644 --- a/algoperf/workloads/imagenet_resnet/imagenet_jax/models.py +++ b/algoperf/workloads/imagenet_resnet/imagenet_jax/models.py @@ -90,7 +90,7 @@ def __call__( update_batch_norm: bool = True, use_running_average_bn: Optional[bool] = None, ) -> spec.Tensor: - conv = functools.partial(nn.Conv, use_bias=False, param_dtype=self.dtype) + conv = functools.partial(nn.Conv, use_bias=False, dtype=self.dtype) # Preserve default behavior for backwards compatibility if use_running_average_bn is None: use_running_average_bn = not update_batch_norm @@ -99,7 +99,7 @@ def __call__( use_running_average=use_running_average_bn, momentum=0.9, epsilon=1e-5, - param_dtype=self.dtype, + dtype=self.dtype, ) x = conv( @@ -125,9 +125,7 @@ def __call__( )(x) x = jnp.mean(x, axis=(1, 2)) x = nn.Dense( - self.num_classes, - kernel_init=nn.initializers.normal(), - param_dtype=self.dtype, + self.num_classes, kernel_init=nn.initializers.normal(), dtype=self.dtype )(x) return x diff --git a/algoperf/workloads/imagenet_resnet/imagenet_jax/workload.py b/algoperf/workloads/imagenet_resnet/imagenet_jax/workload.py index d7a8ede67..f73a1b26e 100644 --- a/algoperf/workloads/imagenet_resnet/imagenet_jax/workload.py +++ b/algoperf/workloads/imagenet_resnet/imagenet_jax/workload.py @@ -11,7 +11,6 @@ import jax import jax.numpy as jnp -import jmp import optax import tensorflow_datasets as tfds from flax import linen as nn @@ -30,17 +29,6 @@ class ImagenetResNetWorkload(BaseImagenetResNetWorkload): - def __init__(self, *args, **kwargs) -> None: - super().__init__(*args, **kwargs) - compute_dtype = spec.JAX_DTYPE_MAP[self._compute_dtype] - param_dtype = spec.JAX_DTYPE_MAP[self._param_dtype] - output_dtype = compute_dtype - self._mp_policy = jmp.Policy( - compute_dtype=compute_dtype, - param_dtype=param_dtype, - output_dtype=output_dtype, - ) - def _build_dataset( self, data_rng: spec.RandomState, @@ -101,12 +89,11 @@ def init_model_fn( else: act_fnc = nn.relu - param_dtype = spec.JAX_DTYPE_MAP[self._param_dtype] model = model_cls( num_classes=self._num_classes, act=act_fnc, bn_init_scale=self.bn_init_scale, - dtype=param_dtype, + dtype=jnp.float32, ) self._model = model input_shape = (1, 224, 224, 3) @@ -172,28 +159,25 @@ def model_fn( del mode del rng del dropout_rate - params, inputs = self._mp_policy.cast_to_compute( - (params, augmented_and_preprocessed_input_batch['inputs']) - ) variables = {'params': params, **model_state} if update_batch_norm: - logits, model_state = self._model.apply( + logits, new_model_state = self._model.apply( variables, - inputs, + augmented_and_preprocessed_input_batch['inputs'], update_batch_norm=update_batch_norm, mutable=['batch_stats'], use_running_average_bn=use_running_average_bn, ) + return logits, new_model_state else: logits = self._model.apply( variables, - inputs, + augmented_and_preprocessed_input_batch['inputs'], update_batch_norm=update_batch_norm, mutable=False, use_running_average_bn=use_running_average_bn, ) - logits = self._mp_policy.cast_to_output(logits) - return logits, model_state + return logits, model_state # Does NOT apply regularization, which is left to the submitter to do in # `update_params`. diff --git a/algoperf/workloads/imagenet_resnet/imagenet_pytorch/models.py b/algoperf/workloads/imagenet_resnet/imagenet_pytorch/models.py index f24ba66b9..c980faa06 100644 --- a/algoperf/workloads/imagenet_resnet/imagenet_pytorch/models.py +++ b/algoperf/workloads/imagenet_resnet/imagenet_pytorch/models.py @@ -20,7 +20,6 @@ def conv3x3( stride: int = 1, groups: int = 1, dilation: int = 1, - dtype: torch.dtype = torch.float32, ) -> nn.Conv2d: """3x3 convolution with padding.""" return nn.Conv2d( @@ -32,24 +31,13 @@ def conv3x3( groups=groups, bias=False, dilation=dilation, - dtype=dtype, ) -def conv1x1( - in_planes: int, - out_planes: int, - stride: int = 1, - dtype: torch.dtype = torch.float32, -) -> nn.Conv2d: +def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d: """1x1 convolution.""" return nn.Conv2d( - in_planes, - out_planes, - kernel_size=1, - stride=stride, - bias=False, - dtype=dtype, + in_planes, out_planes, kernel_size=1, stride=stride, bias=False ) @@ -69,7 +57,6 @@ def __init__( dilation: int = 1, norm_layer: Optional[Callable[..., nn.Module]] = None, act_fnc: nn.Module = nn.ReLU(inplace=True), - dtype: torch.dtype = torch.float32, ) -> None: super().__init__() if norm_layer is None: @@ -80,10 +67,10 @@ def __init__( raise NotImplementedError('Dilation > 1 not supported in BasicBlock') # Both self.conv1 and self.downsample layers downsample # the input when stride != 1. - self.conv1 = conv3x3(inplanes, planes, stride, dtype=dtype) + self.conv1 = conv3x3(inplanes, planes, stride) self.bn1 = norm_layer(planes) self.act_fnc = act_fnc - self.conv2 = conv3x3(planes, planes, dtype=dtype) + self.conv2 = conv3x3(planes, planes) self.bn2 = norm_layer(planes) self.downsample = downsample self.stride = stride @@ -123,7 +110,6 @@ def __init__( dilation: int = 1, norm_layer: Optional[Callable[..., nn.Module]] = None, act_fnc: nn.Module = nn.ReLU(inplace=True), - dtype: torch.dtype = torch.float32, ) -> None: super().__init__() if norm_layer is None: @@ -131,11 +117,11 @@ def __init__( width = int(planes * (base_width / 64.0)) * groups # Both self.conv2 and self.downsample layers downsample # the input when stride != 1. - self.conv1 = conv1x1(inplanes, width, dtype=dtype) + self.conv1 = conv1x1(inplanes, width) self.bn1 = norm_layer(width) - self.conv2 = conv3x3(width, width, stride, groups, dilation, dtype=dtype) + self.conv2 = conv3x3(width, width, stride, groups, dilation) self.bn2 = norm_layer(width) - self.conv3 = conv1x1(width, planes * self.expansion, dtype=dtype) + self.conv3 = conv1x1(width, planes * self.expansion) self.bn3 = norm_layer(planes * self.expansion) self.act_fnc = act_fnc self.downsample = downsample @@ -177,13 +163,11 @@ def __init__( norm_layer: Optional[Callable[..., nn.Module]] = None, act_fnc: nn.Module = nn.ReLU(inplace=True), bn_init_scale: float = 0.0, - dtype: torch.dtype = torch.float32, ) -> None: super().__init__() if norm_layer is None: norm_layer = nn.BatchNorm2d self._norm_layer = norm_layer - self.dtype = dtype self.inplanes = 64 self.dilation = 1 @@ -199,13 +183,7 @@ def __init__( self.groups = groups self.base_width = width_per_group self.conv1 = nn.Conv2d( - 3, - self.inplanes, - kernel_size=7, - stride=2, - padding=3, - bias=False, - dtype=dtype, + 3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False ) self.bn1 = norm_layer(self.inplanes) self.act_fnc = act_fnc @@ -236,7 +214,7 @@ def __init__( dilate=replace_stride_with_dilation[2], ) self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) - self.fc = nn.Linear(512 * block.expansion, num_classes, dtype=dtype) + self.fc = nn.Linear(512 * block.expansion, num_classes) for m in self.modules(): if isinstance(m, nn.Conv2d): @@ -278,15 +256,7 @@ def _make_layer( downsample = torch.nn.Sequential( collections.OrderedDict( [ - ( - 'conv', - conv1x1( - self.inplanes, - planes * block.expansion, - stride, - dtype=self.dtype, - ), - ), + ('conv', conv1x1(self.inplanes, planes * block.expansion, stride)), ('bn', norm_layer(planes * block.expansion)), ] ) @@ -304,7 +274,6 @@ def _make_layer( previous_dilation, norm_layer, act_fnc, - dtype=self.dtype, ) ) self.inplanes = planes * block.expansion @@ -318,7 +287,6 @@ def _make_layer( dilation=self.dilation, norm_layer=norm_layer, act_fnc=act_fnc, - dtype=self.dtype, ) ) diff --git a/algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py b/algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py index 3a88245ae..d5366c60d 100644 --- a/algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py +++ b/algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py @@ -178,10 +178,7 @@ def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState: else: act_fnc = torch.nn.ReLU(inplace=True) - param_dtype = spec.PYTORCH_DTYPE_MAP[self._param_dtype] - model = resnet50( - act_fnc=act_fnc, bn_init_scale=self.bn_init_scale, dtype=param_dtype - ) + model = resnet50(act_fnc=act_fnc, bn_init_scale=self.bn_init_scale) self._param_shapes = param_utils.pytorch_param_shapes(model) self._param_types = param_utils.pytorch_param_types(self._param_shapes) model.to(DEVICE) @@ -232,10 +229,8 @@ def model_fn( spec.ForwardPassMode.TRAIN: contextlib.nullcontext, } - compute_dtype = spec.PYTORCH_DTYPE_MAP[self._compute_dtype] with contexts[mode](): - with torch.autocast(device_type='cuda', dtype=compute_dtype): - logits_batch = model(augmented_and_preprocessed_input_batch['inputs']) + logits_batch = model(augmented_and_preprocessed_input_batch['inputs']) return logits_batch, None diff --git a/algoperf/workloads/imagenet_resnet/workload.py b/algoperf/workloads/imagenet_resnet/workload.py index bc5982f1d..de8458c92 100644 --- a/algoperf/workloads/imagenet_resnet/workload.py +++ b/algoperf/workloads/imagenet_resnet/workload.py @@ -8,8 +8,6 @@ class BaseImagenetResNetWorkload(spec.Workload): _num_classes: int = 1000 - _compute_dtype: spec.DTYPE = spec.DTYPE.BFLOAT16 - _param_dtype: spec.DTYPE = spec.DTYPE.FLOAT32 @property def target_metric_name(self) -> str: diff --git a/algoperf/workloads/imagenet_vit/imagenet_jax/models.py b/algoperf/workloads/imagenet_vit/imagenet_jax/models.py index 2e4630701..e86233011 100644 --- a/algoperf/workloads/imagenet_vit/imagenet_jax/models.py +++ b/algoperf/workloads/imagenet_vit/imagenet_jax/models.py @@ -42,7 +42,6 @@ class MlpBlock(nn.Module): mlp_dim: Optional[int] = None # Defaults to 4x input dim. use_glu: bool = False dropout_rate: float = DROPOUT_RATE - dtype: jnp.dtype = jnp.float32 @nn.compact def __call__( @@ -55,15 +54,15 @@ def __call__( } d = x.shape[2] - x = nn.Dense(self.mlp_dim or 4 * d, param_dtype=self.dtype, **inits)(x) + x = nn.Dense(self.mlp_dim or 4 * d, **inits)(x) x = nn.gelu(x) if self.use_glu: - y = nn.Dense(self.mlp_dim, param_dtype=self.dtype, **inits)(x) + y = nn.Dense(self.mlp_dim, **inits)(x) x = x * y x = Dropout(dropout_rate)(x, train, rate=dropout_rate) - x = nn.Dense(d, param_dtype=self.dtype, **inits)(x) + x = nn.Dense(d, **inits)(x) return x @@ -75,30 +74,25 @@ class Encoder1DBlock(nn.Module): use_glu: bool = False use_post_layer_norm: bool = False dropout_rate: float = 0.0 - dtype: jnp.dtype = jnp.float32 @nn.compact def __call__( self, x: spec.Tensor, train: bool = True, dropout_rate=dropout_rate ) -> spec.Tensor: if not self.use_post_layer_norm: - y = nn.LayerNorm(name='LayerNorm_0', param_dtype=self.dtype)(x) + y = nn.LayerNorm(name='LayerNorm_0')(x) y = nn.MultiHeadDotProductAttention( num_heads=self.num_heads, kernel_init=nn.initializers.xavier_uniform(), deterministic=train, name='MultiHeadDotProductAttention_1', - param_dtype=self.dtype, )(y) y = Dropout(dropout_rate)(y, train, rate=dropout_rate) x = x + y - y = nn.LayerNorm(name='LayerNorm_2', param_dtype=self.dtype)(x) + y = nn.LayerNorm(name='LayerNorm_2')(x) y = MlpBlock( - mlp_dim=self.mlp_dim, - use_glu=self.use_glu, - dtype=self.dtype, - name='MlpBlock_3', + mlp_dim=self.mlp_dim, use_glu=self.use_glu, name='MlpBlock_3' )(y, train, dropout_rate=dropout_rate) y = Dropout(dropout_rate)(y, train, rate=dropout_rate) x = x + y @@ -109,23 +103,21 @@ def __call__( kernel_init=nn.initializers.xavier_uniform(), deterministic=train, name='MultiHeadDotProductAttention_1', - param_dtype=self.dtype, )(y) y = Dropout(dropout_rate)(y, train, rate=dropout_rate) x = x + y - x = nn.LayerNorm(name='LayerNorm_0', param_dtype=self.dtype)(x) + x = nn.LayerNorm(name='LayerNorm_0')(x) y = x y = MlpBlock( mlp_dim=self.mlp_dim, use_glu=self.use_glu, - dtype=self.dtype, name='MlpBlock_3', dropout_rate=dropout_rate, )(y, train, dropout_rate=dropout_rate) y = Dropout(dropout_rate)(y, train)(rate=dropout_rate) x = x + y - x = nn.LayerNorm(name='LayerNorm_2', param_dtype=self.dtype)(x) + x = nn.LayerNorm(name='LayerNorm_2')(x) return x @@ -138,7 +130,6 @@ class Encoder(nn.Module): num_heads: int = 12 use_glu: bool = False use_post_layer_norm: bool = False - dtype: jnp.dtype = jnp.float32 @nn.compact def __call__( @@ -152,10 +143,9 @@ def __call__( num_heads=self.num_heads, use_glu=self.use_glu, use_post_layer_norm=self.use_post_layer_norm, - dtype=self.dtype, )(x, train=train, dropout_rate=dropout_rate) if not self.use_post_layer_norm: - return nn.LayerNorm(name='encoder_layernorm', param_dtype=self.dtype)(x) + return nn.LayerNorm(name='encoder_layernorm')(x) else: return x @@ -166,13 +156,12 @@ class MAPHead(nn.Module): mlp_dim: Optional[int] = None # Defaults to 4x input dim num_heads: int = 12 dropout_rate: float = 0.0 - dtype: jnp.dtype = jnp.float32 @nn.compact def __call__(self, x, dropout_rate=DROPOUT_RATE): n, _, d = x.shape probe = self.param( - 'probe', nn.initializers.xavier_uniform(), (1, 1, d), self.dtype + 'probe', nn.initializers.xavier_uniform(), (1, 1, d), x.dtype ) probe = jnp.tile(probe, [n, 1, 1]) @@ -180,13 +169,10 @@ def __call__(self, x, dropout_rate=DROPOUT_RATE): num_heads=self.num_heads, use_bias=True, kernel_init=nn.initializers.xavier_uniform(), - param_dtype=self.dtype, )(probe, x) - y = nn.LayerNorm(param_dtype=self.dtype)(x) - x = x + MlpBlock( - mlp_dim=self.mlp_dim, dropout_rate=dropout_rate, dtype=self.dtype - )(y) + y = nn.LayerNorm()(x) + x = x + MlpBlock(mlp_dim=self.mlp_dim, dropout_rate=dropout_rate)(y) return x[:, 0] @@ -206,7 +192,6 @@ class ViT(nn.Module): use_glu: bool = False use_post_layer_norm: bool = False use_map: bool = False - dtype: jnp.dtype = jnp.float32 def get_posemb( self, seqshape: tuple, width: int, dtype: jnp.dtype = jnp.float32 @@ -224,7 +209,6 @@ def __call__( strides=self.patch_size, padding='VALID', name='conv_patch_extract', - param_dtype=self.dtype, )(x) n, h, w, c = x.shape @@ -241,7 +225,6 @@ def __call__( num_heads=self.num_heads, use_glu=self.use_glu, use_post_layer_norm=self.use_post_layer_norm, - dtype=self.dtype, name='Transformer', )(x, train=not train, dropout_rate=dropout_rate) @@ -250,21 +233,18 @@ def __call__( num_heads=self.num_heads, mlp_dim=self.mlp_dim, dropout_rate=dropout_rate, - dtype=self.dtype, )(x, dropout_rate=dropout_rate) else: x = jnp.mean(x, axis=1) if self.rep_size: rep_size = self.width if self.rep_size is True else self.rep_size - hid = nn.Dense(rep_size, name='pre_logits', param_dtype=self.dtype) + hid = nn.Dense(rep_size, name='pre_logits') x = nn.tanh(hid(x)) if self.num_classes: kw = {'kernel_init': nn.initializers.zeros} if self.head_zeroinit else {} - head = nn.Dense( - self.num_classes, name='head', param_dtype=self.dtype, **kw - ) + head = nn.Dense(self.num_classes, name='head', **kw) x = head(x) return x diff --git a/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py b/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py index 6819a4862..8a33aeb47 100644 --- a/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py +++ b/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py @@ -32,13 +32,11 @@ def initialized( return params, model_state def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState: - param_dtype = spec.JAX_DTYPE_MAP[self._param_dtype] self._model = models.ViT( num_classes=self._num_classes, use_glu=self.use_glu, use_post_layer_norm=self.use_post_layer_norm, use_map=self.use_map, - dtype=param_dtype, **decode_variant('S/16'), ) params, model_state = self.initialized(rng, self._model) @@ -64,19 +62,15 @@ def model_fn( ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: del model_state del update_batch_norm - # Cast params and inputs to compute dtype - params, inputs = self._mp_policy.cast_to_compute( - (params, augmented_and_preprocessed_input_batch['inputs']) - ) + del use_running_average_bn train = mode == spec.ForwardPassMode.TRAIN logits = self._model.apply( {'params': params}, - inputs, + augmented_and_preprocessed_input_batch['inputs'], rngs={'dropout': rng}, train=train, dropout_rate=dropout_rate, ) - logits = self._mp_policy.cast_to_output(logits) return logits, None def _eval_model_on_split( diff --git a/algoperf/workloads/imagenet_vit/imagenet_pytorch/models.py b/algoperf/workloads/imagenet_vit/imagenet_pytorch/models.py index 6dfb5fddf..fc2a3cd46 100644 --- a/algoperf/workloads/imagenet_vit/imagenet_pytorch/models.py +++ b/algoperf/workloads/imagenet_vit/imagenet_pytorch/models.py @@ -46,24 +46,22 @@ def __init__( width: int, mlp_dim: Optional[int] = None, # Defaults to 4x input dim. use_glu: bool = False, - dtype: Any = torch.float32, ) -> None: super().__init__() self.width = width self.mlp_dim = mlp_dim or 4 * width self.use_glu = use_glu - self.dtype = dtype - self.linear1 = nn.Linear(self.width, self.mlp_dim, dtype=self.dtype) + self.linear1 = nn.Linear(self.width, self.mlp_dim) self.act_fnc = nn.GELU(approximate='tanh') if self.use_glu: - self.glu_linear = nn.Linear(self.mlp_dim, self.mlp_dim, dtype=self.dtype) + self.glu_linear = nn.Linear(self.mlp_dim, self.mlp_dim) else: self.glu_linear = None - self.linear2 = nn.Linear(self.mlp_dim, self.width, dtype=self.dtype) + self.linear2 = nn.Linear(self.mlp_dim, self.width) self.reset_parameters() @@ -87,18 +85,14 @@ def forward(self, x: spec.Tensor, dropout_rate: float) -> spec.Tensor: return x -# TODO(rka97): switch this to built-in attention with cudnn class SelfAttention(nn.Module): """Self-attention special case of multi-head dot-product attention.""" - def __init__( - self, width: int, num_heads: int = 8, dtype: Any = torch.float32 - ) -> None: + def __init__(self, width: int, num_heads: int = 8) -> None: super().__init__() self.width = width self.num_heads = num_heads - self.dtype = dtype assert width % num_heads == 0, ( 'Memory dimension must be divisible by number of heads.' @@ -107,10 +101,10 @@ def __init__( self.head_dim = int(width / num_heads) self.all_head_dim = self.num_heads * self.head_dim - self.query = nn.Linear(self.width, self.all_head_dim, dtype=self.dtype) - self.key = nn.Linear(self.width, self.all_head_dim, dtype=self.dtype) - self.value = nn.Linear(self.width, self.all_head_dim, dtype=self.dtype) - self.out = nn.Linear(self.width, self.width, dtype=self.dtype) + self.query = nn.Linear(self.width, self.all_head_dim) + self.key = nn.Linear(self.width, self.all_head_dim) + self.value = nn.Linear(self.width, self.all_head_dim) + self.out = nn.Linear(self.width, self.width) self.reset_parameters() def reset_parameters(self) -> None: @@ -156,7 +150,6 @@ def __init__( num_heads: int = 12, use_glu: bool = False, use_post_layer_norm: bool = False, - dtype: Any = torch.float32, ) -> None: super().__init__() @@ -165,18 +158,12 @@ def __init__( self.num_heads = num_heads self.use_glu = use_glu self.use_post_layer_norm = use_post_layer_norm - self.dtype = dtype - self.layer_norm0 = nn.LayerNorm(self.width, eps=1e-6, dtype=self.dtype) - self.self_attention1 = SelfAttention( - self.width, self.num_heads, dtype=self.dtype - ) - self.layer_norm2 = nn.LayerNorm(self.width, eps=1e-6, dtype=self.dtype) + self.layer_norm0 = nn.LayerNorm(self.width, eps=1e-6) + self.self_attention1 = SelfAttention(self.width, self.num_heads) + self.layer_norm2 = nn.LayerNorm(self.width, eps=1e-6) self.mlp3 = MlpBlock( - width=self.width, - mlp_dim=self.mlp_dim, - use_glu=self.use_glu, - dtype=self.dtype, + width=self.width, mlp_dim=self.mlp_dim, use_glu=self.use_glu ) def forward(self, x: spec.Tensor, dropout_rate: float) -> spec.Tensor: @@ -216,7 +203,6 @@ def __init__( num_heads: int = 12, use_glu: bool = False, use_post_layer_norm: bool = False, - dtype: Any = torch.float32, ) -> None: super().__init__() @@ -226,7 +212,6 @@ def __init__( self.num_heads = num_heads self.use_glu = use_glu self.use_post_layer_norm = use_post_layer_norm - self.dtype = dtype self.net = nn.ModuleList( [ @@ -236,14 +221,13 @@ def __init__( self.num_heads, self.use_glu, self.use_post_layer_norm, - dtype=self.dtype, ) for _ in range(depth) ] ) if not self.use_post_layer_norm: - self.encoder_norm = nn.LayerNorm(self.width, eps=1e-6, dtype=self.dtype) + self.encoder_norm = nn.LayerNorm(self.width, eps=1e-6) else: self.encoder_norm = None @@ -261,32 +245,21 @@ class MAPHead(nn.Module): """Multihead Attention Pooling.""" def __init__( - self, - width: int, - mlp_dim: Optional[int] = None, - num_heads: int = 12, - dtype: torch.dtype = torch.float32, + self, width: int, mlp_dim: Optional[int] = None, num_heads: int = 12 ): super().__init__() self.width = width self.mlp_dim = mlp_dim self.num_heads = num_heads - self.dtype = dtype self.probe = nn.Parameter(torch.zeros((1, 1, self.width))) nn.init.xavier_uniform_(self.probe.data) self.mha = MultiheadAttention( - self.width, - num_heads=self.num_heads, - self_attn=False, - bias=True, - dtype=self.dtype, - ) - self.layer_norm = nn.LayerNorm(self.width, eps=1e-6, dtype=self.dtype) - self.mlp = MlpBlock( - width=self.width, mlp_dim=self.mlp_dim, dtype=self.dtype + self.width, num_heads=self.num_heads, self_attn=False, bias=True ) + self.layer_norm = nn.LayerNorm(self.width, eps=1e-6) + self.mlp = MlpBlock(width=self.width, mlp_dim=self.mlp_dim) def forward(self, x: spec.Tensor, dropout_rate: float) -> spec.Tensor: n, _, _ = x.shape @@ -337,7 +310,7 @@ def __init__( if self.rep_size: rep_size = self.width if self.rep_size is True else self.rep_size - self.pre_logits = nn.Linear(self.width, rep_size, dtype=self.dtype) + self.pre_logits = nn.Linear(self.width, rep_size) self.conv_patch_extract = nn.Conv2d( self.channels, @@ -345,7 +318,6 @@ def __init__( self.patch_size, stride=self.patch_size, padding='valid', - dtype=self.dtype, ) self.encoder = Encoder( @@ -355,16 +327,13 @@ def __init__( num_heads=self.num_heads, use_glu=self.use_glu, use_post_layer_norm=self.use_post_layer_norm, - dtype=self.dtype, ) if self.num_classes: - self.head = nn.Linear(self.width, self.num_classes, dtype=self.dtype) + self.head = nn.Linear(self.width, self.num_classes) if self.use_map: - self.map = MAPHead( - self.width, self.mlp_dim, self.num_heads, dtype=self.dtype - ) + self.map = MAPHead(self.width, self.mlp_dim, self.num_heads) else: self.map = None diff --git a/algoperf/workloads/imagenet_vit/imagenet_pytorch/workload.py b/algoperf/workloads/imagenet_vit/imagenet_pytorch/workload.py index bfef3e0a9..9c6faf70b 100644 --- a/algoperf/workloads/imagenet_vit/imagenet_pytorch/workload.py +++ b/algoperf/workloads/imagenet_vit/imagenet_pytorch/workload.py @@ -23,13 +23,11 @@ class ImagenetVitWorkload(BaseImagenetVitWorkload, ImagenetResNetWorkload): def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState: torch.random.manual_seed(rng[0]) - param_dtype = spec.PYTORCH_DTYPE_MAP[self._param_dtype] model = models.ViT( num_classes=self._num_classes, use_glu=self.use_glu, use_post_layer_norm=self.use_post_layer_norm, use_map=self.use_map, - dtype=param_dtype, **decode_variant('S/16'), ) self._param_shapes = param_utils.pytorch_param_shapes(model) @@ -72,13 +70,11 @@ def model_fn( spec.ForwardPassMode.TRAIN: contextlib.nullcontext, } - compute_dtype = spec.PYTORCH_DTYPE_MAP[self._compute_dtype] with contexts[mode](): - with torch.autocast(device_type='cuda', dtype=compute_dtype): - logits_batch = model( - augmented_and_preprocessed_input_batch['inputs'], - dropout_rate=dropout_rate, - ) + logits_batch = model( + augmented_and_preprocessed_input_batch['inputs'], + dropout_rate=dropout_rate, + ) return logits_batch, None diff --git a/algoperf/workloads/ogbg/workload.py b/algoperf/workloads/ogbg/workload.py index 771b103a0..002576268 100644 --- a/algoperf/workloads/ogbg/workload.py +++ b/algoperf/workloads/ogbg/workload.py @@ -92,7 +92,7 @@ def max_allowed_runtime_sec(self) -> int: @property def eval_period_time_sec(self) -> int: - return 452 # approx 25 evals + return 452 # approx 25 evals def _build_input_queue( self, diff --git a/algorithms/baselines/external_tuning/jax_nadamw_full_budget.py b/algorithms/baselines/external_tuning/jax_nadamw_full_budget.py index a6f36fd30..0577cd4e0 100644 --- a/algorithms/baselines/external_tuning/jax_nadamw_full_budget.py +++ b/algorithms/baselines/external_tuning/jax_nadamw_full_budget.py @@ -396,8 +396,6 @@ def get_batch_size(workload_name): return 128 elif workload_name == 'mnist': return 16 - elif workload_name == 'cifar': - return 16384 else: raise ValueError(f'Unsupported workload name: {workload_name}.') diff --git a/algorithms/baselines/external_tuning/pytorch_nadamw_full_budget.py b/algorithms/baselines/external_tuning/pytorch_nadamw_full_budget.py index 285727885..0b32199ba 100644 --- a/algorithms/baselines/external_tuning/pytorch_nadamw_full_budget.py +++ b/algorithms/baselines/external_tuning/pytorch_nadamw_full_budget.py @@ -5,6 +5,7 @@ import torch import torch.distributed.nn as dist_nn +from absl import logging from torch import Tensor from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR @@ -314,6 +315,13 @@ def update_params( }, global_step, ) + logging.info( + '%d) loss = %0.3f, grad_norm = %0.3f', + global_step, + loss.item(), + grad_norm.item(), + ) + return (optimizer_state, current_param_container, new_model_state) @@ -364,8 +372,6 @@ def get_batch_size(workload_name): return 128 elif workload_name == 'mnist': return 16 - elif workload_name == 'cifar': - return 16384 else: raise ValueError(f'Unsupported workload name: {workload_name}.') diff --git a/scoring/performance_profile.py b/scoring/performance_profile.py index 043a65791..b200c6865 100644 --- a/scoring/performance_profile.py +++ b/scoring/performance_profile.py @@ -71,7 +71,7 @@ 'wer', 'l1_loss', 'loss', - 'ppl', + 'ppl' ] MAX_EVAL_METRICS = ['mean_average_precision', 'ssim', 'accuracy', 'bleu'] diff --git a/submission_runner.py b/submission_runner.py index 84ae3307b..552c99b79 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -266,7 +266,6 @@ def train_once( 'librispeech_deepspeech', 'ogbg', 'wmt', - 'cifar', ] base_workload = workloads.get_base_workload_name(workload_name) if base_workload in compile_error_workloads: @@ -410,15 +409,10 @@ def train_once( train_state['training_complete'] = True train_step_end_time = get_time() - step_time = train_step_end_time - train_state['last_step_end_time'] - train_state['accumulated_submission_time'] += step_time - # Log training progress periodically - if global_step % 10 == 0: - logging.info( - f'Step: {global_step}, ' - f'\tLast step time: {step_time:.4f}s, ' - f'\tTotal time: {train_state["accumulated_submission_time"]:.2f}s' - ) + + train_state['accumulated_submission_time'] += ( + train_step_end_time - train_state['last_step_end_time'] + ) # Check if submission is eligible for an untimed eval. if ( @@ -518,19 +512,10 @@ def train_once( latest_eval_result['accumulated_logging_time'] = train_state[ 'accumulated_logging_time' ] - # Calculate average per-step time - avg_per_step_time = ( - train_state['accumulated_submission_time'] / global_step - if global_step > 0 - else 0.0 - ) - latest_eval_result['avg_per_step_time'] = avg_per_step_time time_since_start = latest_eval_result['total_duration'] logging.info( f'Time since start: {time_since_start:.2f}s, ' - f'\tStep: {global_step}, ' - f'\tAvg per-step time: {avg_per_step_time:.4f}s, ' - f'\t{latest_eval_result}' + f'\tStep: {global_step}, \t{latest_eval_result}' ) eval_results.append((global_step, latest_eval_result)) From c9899cfd25f57a8fd9ea32f7d74006ea2548ebf3 Mon Sep 17 00:00:00 2001 From: rka97 Date: Thu, 11 Dec 2025 23:47:19 +0000 Subject: [PATCH 08/17] Use tf32 in pytorch --- algoperf/pytorch_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/algoperf/pytorch_utils.py b/algoperf/pytorch_utils.py index af09e67fc..937001b87 100644 --- a/algoperf/pytorch_utils.py +++ b/algoperf/pytorch_utils.py @@ -20,6 +20,7 @@ def pytorch_setup() -> Tuple[bool, int, torch.device, int]: + torch.set_float32_matmul_precision('high') use_pytorch_ddp = 'LOCAL_RANK' in os.environ rank = int(os.environ['LOCAL_RANK']) if use_pytorch_ddp else 0 device = torch.device(f'cuda:{rank}' if torch.cuda.is_available() else 'cpu') From 2f865a15336153b6f1162e6d13325cf1d72bf118 Mon Sep 17 00:00:00 2001 From: rka97 Date: Fri, 12 Dec 2025 01:47:03 +0000 Subject: [PATCH 09/17] ImageNet caching for faster dataset access PyTorch --- .../imagenet_pytorch/workload.py | 111 +++++++++++++++++- algoperf/workloads/ogbg/workload.py | 2 +- scoring/performance_profile.py | 2 +- 3 files changed, 109 insertions(+), 6 deletions(-) diff --git a/algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py b/algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py index d5366c60d..07f84975e 100644 --- a/algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py +++ b/algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py @@ -3,10 +3,13 @@ import contextlib import functools import itertools +import json import math import os import random -from typing import Dict, Iterator, Optional, Tuple +import time +from pathlib import Path +from typing import Any, Callable, Dict, Iterator, Optional, Tuple, Union import numpy as np import torch @@ -14,7 +17,11 @@ import torch.nn.functional as F from torch.nn.parallel import DistributedDataParallel as DDP from torchvision import transforms -from torchvision.datasets.folder import ImageFolder +from torchvision.datasets.folder import ( + IMG_EXTENSIONS, + ImageFolder, + default_loader, +) import algoperf.random_utils as prng from algoperf import data_utils, param_utils, pytorch_utils, spec @@ -28,6 +35,100 @@ USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_utils.pytorch_setup() +class CachedImageFolder(ImageFolder): + """ImageFolder that caches the file listing to avoid repeated filesystem scans.""" + + def __init__( + self, + root: Union[str, Path], + cache_file: Optional[Union[str, Path]] = None, + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + loader: Callable[[str], Any] = default_loader, + is_valid_file: Optional[Callable[[str], bool]] = None, + allow_empty: bool = False, + rebuild_cache: bool = False, + cache_build_timeout_minutes: int = 30, + ): + self.root = os.path.expanduser(root) + self.transform = transform + self.target_transform = target_transform + self.loader = loader + self.extensions = IMG_EXTENSIONS if is_valid_file is None else None + + # Default cache location: .cache_index.json in the root directory + if cache_file is None: + cache_file = os.path.join(self.root, '.cache_index.json') + self.cache_file = cache_file + + is_distributed = dist.is_available() and dist.is_initialized() + rank = dist.get_rank() if is_distributed else 0 + + cache_exists = os.path.exists(self.cache_file) + needs_rebuild = rebuild_cache or not cache_exists + + if needs_rebuild: + # We only want one process to build the cache + # and others to wait for it to finish. + if rank == 0: + self._build_and_save_cache(is_valid_file, allow_empty) + if is_distributed: + self._wait_for_cache(timeout_minutes=cache_build_timeout_minutes) + dist.barrier() + + self._load_from_cache() + + self.targets = [s[1] for s in self.samples] + self.imgs = self.samples + + def _wait_for_cache(self, timeout_minutes: int): + """Poll for cache file to exist.""" + timeout_seconds = timeout_minutes * 60 + poll_interval = 5 + elapsed = 0 + + while not os.path.exists(self.cache_file): + if elapsed >= timeout_seconds: + raise TimeoutError( + f'Timed out waiting for cache file after {timeout_minutes} minutes: {self.cache_file}' + ) + time.sleep(poll_interval) + elapsed += poll_interval + + def _load_from_cache(self): + """Load classes and samples from cache file.""" + with open(os.path.abspath(self.cache_file), 'r') as f: + cache = json.load(f) + self.classes = cache['classes'] + self.class_to_idx = cache['class_to_idx'] + # Convert relative paths back to absolute + self.samples = [ + (os.path.join(self.root, rel_path), idx) + for rel_path, idx in cache['samples'] + ] + + def _build_and_save_cache(self, is_valid_file, allow_empty): + """Scan filesystem, build index, and save to cache.""" + self.classes, self.class_to_idx = self.find_classes(self.root) + self.samples = self.make_dataset( + self.root, + class_to_idx=self.class_to_idx, + extensions=self.extensions, + is_valid_file=is_valid_file, + allow_empty=allow_empty, + ) + + cache = { + 'classes': self.classes, + 'class_to_idx': self.class_to_idx, + 'samples': [ + (os.path.relpath(path, self.root), idx) for path, idx in self.samples + ], + } + with open(os.path.abspath(self.cache_file), 'w') as f: + json.dump(cache, f) + + def imagenet_v2_to_torch( batch: Dict[str, spec.Tensor], ) -> Dict[str, spec.Tensor]: @@ -119,8 +220,10 @@ def _build_dataset( ) folder = 'train' if 'train' in split else 'val' - dataset = ImageFolder( - os.path.join(data_dir, folder), transform=transform_config + dataset = CachedImageFolder( + os.path.join(data_dir, folder), + transform=transform_config, + cache_file='.imagenet_cache_index.json', ) if split == 'eval_train': diff --git a/algoperf/workloads/ogbg/workload.py b/algoperf/workloads/ogbg/workload.py index 002576268..771b103a0 100644 --- a/algoperf/workloads/ogbg/workload.py +++ b/algoperf/workloads/ogbg/workload.py @@ -92,7 +92,7 @@ def max_allowed_runtime_sec(self) -> int: @property def eval_period_time_sec(self) -> int: - return 452 # approx 25 evals + return 452 # approx 25 evals def _build_input_queue( self, diff --git a/scoring/performance_profile.py b/scoring/performance_profile.py index b200c6865..043a65791 100644 --- a/scoring/performance_profile.py +++ b/scoring/performance_profile.py @@ -71,7 +71,7 @@ 'wer', 'l1_loss', 'loss', - 'ppl' + 'ppl', ] MAX_EVAL_METRICS = ['mean_average_precision', 'ssim', 'accuracy', 'bleu'] From 38fa915bc9073a4c073526e209893815432adbe1 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 8 Jan 2026 05:10:42 +0000 Subject: [PATCH 10/17] modify pytorch run command in startup script for docker --- docker/scripts/startup.sh | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docker/scripts/startup.sh b/docker/scripts/startup.sh index 35ac30461..1cd676d2a 100644 --- a/docker/scripts/startup.sh +++ b/docker/scripts/startup.sh @@ -174,7 +174,7 @@ fi # Check if arguments are valid VALID_DATASETS=("criteo1tb" "imagenet" "fastmri" "ogbg" "librispeech" \ - "wmt" "mnist") + "wmt" "mnist" "fineweb_edu_10B") VALID_WORKLOADS=("criteo1tb" "imagenet_resnet" "imagenet_resnet_silu" "imagenet_resnet_gelu" \ "imagenet_resnet_large_bn_init" "imagenet_vit" "imagenet_vit_glu" \ "imagenet_vit_post_ln" "imagenet_vit_map" "fastmri" "ogbg" \ @@ -185,7 +185,7 @@ VALID_WORKLOADS=("criteo1tb" "imagenet_resnet" "imagenet_resnet_silu" "imagenet_ "librispeech_conformer_gelu" "fastmri_model_size" "fastmri_tanh" \ "librispeech_deepspeech_tanh" \ "librispeech_deepspeech_no_resnet" "librispeech_deepspeech_norm_and_spec_aug" - "fastmri_layernorm" "ogbg_gelu" "ogbg_silu" "ogbg_model_size") + "fastmri_layernorm" "ogbg_gelu" "ogbg_silu" "ogbg_model_size" "finewebedu_lm") VALID_RULESETS=("self" "external") # Set data and experiment paths @@ -221,7 +221,7 @@ TUNING_RULESET_FLAG="--tuning_ruleset=${TUNING_RULESET}" if [[ "${FRAMEWORK}" == "jax" ]]; then COMMAND_PREFIX="python" else - COMMAND_PREFIX="torchrun --redirects 1:0,2:0,3:0,4:0,5:0,6:0,7:0 --standalone --nnodes=1 --nproc_per_node=8" + COMMAND_PREFIX="torchrun --redirects 1:0,2:0,3:0 --standalone --nnodes=1 --nproc_per_node=4" fi # Set data directory and bucket (bucket is only relevant in internal mode) From 9c93fc2fd03f5193806525da8b5aef3d995d6fc3 Mon Sep 17 00:00:00 2001 From: rka97 Date: Mon, 12 Jan 2026 00:04:34 +0000 Subject: [PATCH 11/17] some benchmarking steps --- .../external_tuning/jax_nadamw_full_budget.py | 6 - .../pytorch_nadamw_full_budget.py | 22 -- benchmark_step_times.py | 274 ++++++++++++++++++ submission_runner.py | 18 +- 4 files changed, 291 insertions(+), 29 deletions(-) create mode 100755 benchmark_step_times.py diff --git a/algorithms/baselines/external_tuning/jax_nadamw_full_budget.py b/algorithms/baselines/external_tuning/jax_nadamw_full_budget.py index 0577cd4e0..cf431de24 100644 --- a/algorithms/baselines/external_tuning/jax_nadamw_full_budget.py +++ b/algorithms/baselines/external_tuning/jax_nadamw_full_budget.py @@ -340,12 +340,6 @@ def update_params( dropout_rate, ) ) - - # Log loss, grad_norm. - if global_step % 100 == 0 and workload.metrics_logger is not None: - workload.metrics_logger.append_scalar_metrics( - {'loss': loss.item(), 'grad_norm': grad_norm.item()}, global_step - ) return (new_optimizer_state, opt_update_fn), new_params, new_model_state diff --git a/algorithms/baselines/external_tuning/pytorch_nadamw_full_budget.py b/algorithms/baselines/external_tuning/pytorch_nadamw_full_budget.py index 0b32199ba..10f481d5f 100644 --- a/algorithms/baselines/external_tuning/pytorch_nadamw_full_budget.py +++ b/algorithms/baselines/external_tuning/pytorch_nadamw_full_budget.py @@ -300,28 +300,6 @@ def update_params( optimizer_state['optimizer'].step() optimizer_state['scheduler'].step() - # Log training metrics - loss, grad_norm, batch_size. - if global_step <= 100 or global_step % 500 == 0: - with torch.no_grad(): - parameters = [p for p in current_model.parameters() if p.grad is not None] - grad_norm = torch.norm( - torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2 - ) - if workload.metrics_logger is not None: - workload.metrics_logger.append_scalar_metrics( - { - 'loss': loss.item(), - 'grad_norm': grad_norm.item(), - }, - global_step, - ) - logging.info( - '%d) loss = %0.3f, grad_norm = %0.3f', - global_step, - loss.item(), - grad_norm.item(), - ) - return (optimizer_state, current_param_container, new_model_state) diff --git a/benchmark_step_times.py b/benchmark_step_times.py new file mode 100755 index 000000000..363856734 --- /dev/null +++ b/benchmark_step_times.py @@ -0,0 +1,274 @@ +#!/usr/bin/env python3 +"""Benchmark step times for JAX and PyTorch across all workloads. + +This script runs each workload for 101 steps with both JAX and PyTorch, +captures the step_time_ms metric, and produces a comparison table. +""" + +import argparse +import re +import subprocess +from pathlib import Path + +# Base workloads to benchmark +WORKLOADS = [ + 'imagenet_resnet', +] + +FRAMEWORKS = ['jax', 'pytorch'] +MAX_STEPS = 201 +OUTPUT_DIR = Path('/home/ak4605/aef2/benchmark_outputs') + + +def get_data_dir(workload: str, framework: str) -> str: + """Map workload to its data directory.""" + if workload in ['imagenet_resnet', 'imagenet_vit']: + return '/opt/data/imagenet/' + framework + elif workload in ['librispeech_conformer', 'librispeech_deepspeech']: + return '/opt/data/librispeech' + elif workload == 'criteo1tb': + return '/opt/data/criteo1tb' + elif workload == 'fastmri': + return '/opt/data/fastmri' + elif workload == 'ogbg': + return '/opt/data/ogbg' + elif workload == 'wmt': + return '/opt/data/wmt' + else: + return '/opt/' + + +def run_workload(workload: str, framework: str, output_file: Path) -> bool: + """Run a workload and capture output to file.""" + data_dir = get_data_dir(workload, framework) + experiment_dir = '/home/ak4605/experiments' + + # Clean up previous experiment directories + for item in Path(experiment_dir).glob(f'{workload}*'): + if item.is_dir(): + subprocess.run(['rm', '-rf', str(item)], check=True) + + # Build command based on framework + submission_path = ( + f'algorithms/baselines/external_tuning/{framework}_nadamw_full_budget.py' + ) + tuning_search_space = ( + 'algorithms/baselines/external_tuning/tuning_search_space.json' + ) + + if framework == 'jax': + cmd = [ + 'python', + 'submission_runner.py', + f'--framework={framework}', + f'--workload={workload}', + f'--data_dir={data_dir}', + f'--experiment_dir={experiment_dir}', + f'--experiment_name={workload}_benchmark', + f'--submission_path={submission_path}', + f'--tuning_search_space={tuning_search_space}', + f'--max_global_steps={MAX_STEPS}', + '--skip_evals', + '--nosave_checkpoints', + '--nosave_intermediate_checkpoints', + ] + # For JAX, activate the jax conda environment + activate_cmd = 'source $(conda info --base)/etc/profile.d/conda.sh && conda activate ap11_jax && ' + else: + cmd = [ + 'torchrun', + '--nproc_per_node=4', + '--standalone', + 'submission_runner.py', + f'--framework={framework}', + f'--workload={workload}', + f'--data_dir={data_dir}', + f'--experiment_dir={experiment_dir}', + f'--experiment_name={workload}_benchmark', + f'--submission_path={submission_path}', + f'--tuning_search_space={tuning_search_space}', + f'--max_global_steps={MAX_STEPS}', + '--skip_evals', + '--nosave_checkpoints', + '--nosave_intermediate_checkpoints', + ] + # For PyTorch, activate the torch conda environment + activate_cmd = 'source $(conda info --base)/etc/profile.d/conda.sh && conda activate ap11_torch_latest && ' + + # Run the command with shell to handle conda activation + full_cmd = activate_cmd + ' '.join(cmd) + print(f'Running: {workload} with {framework}') + print(f'Output will be saved to: {output_file}') + + with open(output_file, 'w') as f: + result = subprocess.run( + full_cmd, + shell=True, + executable='/bin/bash', + stdout=f, + stderr=subprocess.STDOUT, + cwd='/home/ak4605/aef2/', + ) + + return result.returncode == 0 + + +def parse_step_time(output_file: Path) -> float | None: + """Parse the last step_time_ms from output file.""" + if not output_file.exists(): + return None + + with open(output_file, 'r') as f: + content = f.read() + + # Find all step_time_ms values + # Pattern matches: step_time_ms=123.456 or 'step_time_ms': 123.456 + pattern = r'step_time_ms[=:]\s*([\d.]+)' + matches = re.findall(pattern, content) + + if matches: + # Return the last value (most recent EMA) + return float(matches[-1]) + return None + + +def parse_args(): + parser = argparse.ArgumentParser( + description='Benchmark step times for JAX and PyTorch across workloads.' + ) + group = parser.add_mutually_exclusive_group() + group.add_argument( + '--torch-only', + action='store_true', + help='Only run PyTorch experiments; read existing JAX results from files.', + ) + group.add_argument( + '--jax-only', + action='store_true', + help='Only run JAX experiments; read existing PyTorch results from files.', + ) + group.add_argument( + '--just-read', + action='store_true', + help='Do not run any experiments; just read and compare existing outputs.', + ) + return parser.parse_args() + + +def main(): + args = parse_args() + + # Create output directory + OUTPUT_DIR.mkdir(parents=True, exist_ok=True) + + results = {} + + # Determine which frameworks to run vs read from files + if args.just_read: + frameworks_to_run = [] + frameworks_to_read = FRAMEWORKS + elif args.torch_only: + frameworks_to_run = ['pytorch'] + frameworks_to_read = ['jax'] + elif args.jax_only: + frameworks_to_run = ['jax'] + frameworks_to_read = ['pytorch'] + else: + frameworks_to_run = FRAMEWORKS + frameworks_to_read = [] + + # Run all workloads + for workload in WORKLOADS: + results[workload] = {} + + # Read existing results from files + for framework in frameworks_to_read: + output_file = OUTPUT_DIR / f'{workload}_{framework}.out' + step_time = parse_step_time(output_file) + results[workload][framework] = step_time + if step_time: + print(f'\nLoaded existing {framework.upper()} result for {workload}: {step_time:.2f} ms') + else: + print(f'\nNo existing {framework.upper()} result found for {workload}') + + # Run experiments for specified frameworks + for framework in frameworks_to_run: + output_file = OUTPUT_DIR / f'{workload}_{framework}.out' + + print(f'\n{"=" * 60}') + print(f'Benchmarking {workload} with {framework}') + print(f'{"=" * 60}') + + success = run_workload(workload, framework, output_file) + + if success: + step_time = parse_step_time(output_file) + results[workload][framework] = step_time + print( + f'Step time: {step_time:.2f} ms' if step_time else 'Step time: N/A' + ) + else: + results[workload][framework] = None + print(f'Failed to run {workload} with {framework}') + + # Print results table + print('\n\n') + print('=' * 80) + print('STEP TIME COMPARISON (ms)') + print('=' * 80) + print( + f'{"Workload":<30} {"JAX (ms)":<15} {"PyTorch (ms)":<15} {"Ratio (PT/JAX)":<15}' + ) + print('-' * 80) + + for workload in WORKLOADS: + jax_time = results[workload].get('jax') + pytorch_time = results[workload].get('pytorch') + + jax_str = f'{jax_time:.2f}' if jax_time else 'N/A' + pytorch_str = f'{pytorch_time:.2f}' if pytorch_time else 'N/A' + + if jax_time and pytorch_time: + ratio = pytorch_time / jax_time + ratio_str = f'{ratio:.2f}x' + else: + ratio_str = 'N/A' + + print(f'{workload:<30} {jax_str:<15} {pytorch_str:<15} {ratio_str:<15}') + + print('=' * 80) + + # Save results to file + results_file = OUTPUT_DIR / 'results.txt' + with open(results_file, 'w') as f: + f.write('STEP TIME COMPARISON (ms)\n') + f.write('=' * 80 + '\n') + f.write( + f'{"Workload":<30} {"JAX (ms)":<15} {"PyTorch (ms)":<15} {"Ratio (PT/JAX)":<15}\n' + ) + f.write('-' * 80 + '\n') + + for workload in WORKLOADS: + jax_time = results[workload].get('jax') + pytorch_time = results[workload].get('pytorch') + + jax_str = f'{jax_time:.2f}' if jax_time else 'N/A' + pytorch_str = f'{pytorch_time:.2f}' if pytorch_time else 'N/A' + + if jax_time and pytorch_time: + ratio = pytorch_time / jax_time + ratio_str = f'{ratio:.2f}x' + else: + ratio_str = 'N/A' + + f.write( + f'{workload:<30} {jax_str:<15} {pytorch_str:<15} {ratio_str:<15}\n' + ) + + f.write('=' * 80 + '\n') + + print(f'\nResults saved to: {results_file}') + + +if __name__ == '__main__': + main() diff --git a/submission_runner.py b/submission_runner.py index 552c99b79..bee6ea539 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -352,7 +352,7 @@ def train_once( log_dir, flags.FLAGS, hyperparameters ) workload.attach_metrics_logger(metrics_logger) - + step_10_end_time = None global_start_time = get_time() train_state['last_step_end_time'] = global_start_time @@ -409,6 +409,22 @@ def train_once( train_state['training_complete'] = True train_step_end_time = get_time() + if global_step == 11: + step_10_end_time = train_step_end_time + + # Log step time every 100 steps + # Note: global_step was incremented, so use (global_step - 1) to match + if (global_step - 1) % 100 == 0 and workload.metrics_logger is not None: + if step_10_end_time is not None and global_step > 11: + elapsed_time_ms = (train_step_end_time - step_10_end_time) * 1000.0 + elapsed_steps = global_step - 11 + avg_step_time_ms = elapsed_time_ms / elapsed_steps + else: + avg_step_time_ms = 0.0 + workload.metrics_logger.append_scalar_metrics( + {'step_time_ms': avg_step_time_ms}, + global_step - 1, + ) train_state['accumulated_submission_time'] += ( train_step_end_time - train_state['last_step_end_time'] From f6974ebb8eb776a77f48487f2413c1cc4af63bf9 Mon Sep 17 00:00:00 2001 From: rka97 Date: Mon, 12 Jan 2026 00:55:46 +0000 Subject: [PATCH 12/17] Change num_workers for imagenet, add validation tests for step times --- .../workloads/cifar/cifar_pytorch/workload.py | 4 +- .../imagenet_pytorch/workload.py | 4 +- .../imagenet_vit/imagenet_pytorch/models.py | 14 +- .../pytorch_nadamw_full_budget.py | 1 - benchmark_step_times.py | 274 ------------------ submission_runner.py | 5 +- tests/test_step_times.py | 199 +++++++++++++ 7 files changed, 212 insertions(+), 289 deletions(-) delete mode 100755 benchmark_step_times.py create mode 100644 tests/test_step_times.py diff --git a/algoperf/workloads/cifar/cifar_pytorch/workload.py b/algoperf/workloads/cifar/cifar_pytorch/workload.py index a6e8569cc..f053fd828 100644 --- a/algoperf/workloads/cifar/cifar_pytorch/workload.py +++ b/algoperf/workloads/cifar/cifar_pytorch/workload.py @@ -110,12 +110,12 @@ def _build_dataset( batch_size=ds_iter_batch_size, shuffle=not USE_PYTORCH_DDP and is_train, sampler=sampler, - num_workers=4 if is_train else self.eval_num_workers, + num_workers=2 * N_GPUS if is_train else self.eval_num_workers, pin_memory=True, drop_last=is_train, ) - dataloader = data_utils.PrefetchedWrapper(dataloader, DEVICE) dataloader = data_utils.cycle(dataloader, custom_sampler=USE_PYTORCH_DDP) + dataloader = data_utils.dataloader_iterator_wrapper(dataloader, DEVICE) return dataloader def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState: diff --git a/algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py b/algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py index 07f84975e..b31998822 100644 --- a/algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py +++ b/algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py @@ -254,10 +254,11 @@ def _build_dataset( batch_size=ds_iter_batch_size, shuffle=not USE_PYTORCH_DDP and is_train, sampler=sampler, - num_workers=4 if is_train else self.eval_num_workers, + num_workers=5 * N_GPUS if is_train else self.eval_num_workers, pin_memory=True, drop_last=is_train, persistent_workers=is_train, + prefetch_factor=N_GPUS, ) dataloader = data_utils.PrefetchedWrapper(dataloader, DEVICE) dataloader = data_utils.cycle( @@ -266,7 +267,6 @@ def _build_dataset( use_mixup=use_mixup, mixup_alpha=0.2, ) - return dataloader def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState: diff --git a/algoperf/workloads/imagenet_vit/imagenet_pytorch/models.py b/algoperf/workloads/imagenet_vit/imagenet_pytorch/models.py index fc2a3cd46..06df7ea75 100644 --- a/algoperf/workloads/imagenet_vit/imagenet_pytorch/models.py +++ b/algoperf/workloads/imagenet_vit/imagenet_pytorch/models.py @@ -5,7 +5,6 @@ and https://github.com/lucidrains/vit-pytorch. """ -import math from typing import Any, Optional, Tuple, Union import torch @@ -126,13 +125,14 @@ def forward(self, x: spec.Tensor, dropout_rate: float) -> spec.Tensor: value_layer = self.transpose_for_scores(self.value(x)) query_layer = self.transpose_for_scores(mixed_query_layer) - attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) - attention_scores = attention_scores / math.sqrt(self.head_dim) - - attention_probs = F.softmax(attention_scores, dim=-1) - attention_probs = F.dropout(attention_probs, dropout_rate, self.training) + # Use built-in scaled_dot_product_attention (Flash Attention when available) + context_layer = F.scaled_dot_product_attention( + query_layer, + key_layer, + value_layer, + dropout_p=dropout_rate if self.training else 0.0, + ) - context_layer = torch.matmul(attention_probs, value_layer) context_layer = context_layer.permute(0, 2, 1, 3).contiguous() new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_dim,) context_layer = context_layer.view(new_context_layer_shape) diff --git a/algorithms/baselines/external_tuning/pytorch_nadamw_full_budget.py b/algorithms/baselines/external_tuning/pytorch_nadamw_full_budget.py index 10f481d5f..494ada4c8 100644 --- a/algorithms/baselines/external_tuning/pytorch_nadamw_full_budget.py +++ b/algorithms/baselines/external_tuning/pytorch_nadamw_full_budget.py @@ -5,7 +5,6 @@ import torch import torch.distributed.nn as dist_nn -from absl import logging from torch import Tensor from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR diff --git a/benchmark_step_times.py b/benchmark_step_times.py deleted file mode 100755 index 363856734..000000000 --- a/benchmark_step_times.py +++ /dev/null @@ -1,274 +0,0 @@ -#!/usr/bin/env python3 -"""Benchmark step times for JAX and PyTorch across all workloads. - -This script runs each workload for 101 steps with both JAX and PyTorch, -captures the step_time_ms metric, and produces a comparison table. -""" - -import argparse -import re -import subprocess -from pathlib import Path - -# Base workloads to benchmark -WORKLOADS = [ - 'imagenet_resnet', -] - -FRAMEWORKS = ['jax', 'pytorch'] -MAX_STEPS = 201 -OUTPUT_DIR = Path('/home/ak4605/aef2/benchmark_outputs') - - -def get_data_dir(workload: str, framework: str) -> str: - """Map workload to its data directory.""" - if workload in ['imagenet_resnet', 'imagenet_vit']: - return '/opt/data/imagenet/' + framework - elif workload in ['librispeech_conformer', 'librispeech_deepspeech']: - return '/opt/data/librispeech' - elif workload == 'criteo1tb': - return '/opt/data/criteo1tb' - elif workload == 'fastmri': - return '/opt/data/fastmri' - elif workload == 'ogbg': - return '/opt/data/ogbg' - elif workload == 'wmt': - return '/opt/data/wmt' - else: - return '/opt/' - - -def run_workload(workload: str, framework: str, output_file: Path) -> bool: - """Run a workload and capture output to file.""" - data_dir = get_data_dir(workload, framework) - experiment_dir = '/home/ak4605/experiments' - - # Clean up previous experiment directories - for item in Path(experiment_dir).glob(f'{workload}*'): - if item.is_dir(): - subprocess.run(['rm', '-rf', str(item)], check=True) - - # Build command based on framework - submission_path = ( - f'algorithms/baselines/external_tuning/{framework}_nadamw_full_budget.py' - ) - tuning_search_space = ( - 'algorithms/baselines/external_tuning/tuning_search_space.json' - ) - - if framework == 'jax': - cmd = [ - 'python', - 'submission_runner.py', - f'--framework={framework}', - f'--workload={workload}', - f'--data_dir={data_dir}', - f'--experiment_dir={experiment_dir}', - f'--experiment_name={workload}_benchmark', - f'--submission_path={submission_path}', - f'--tuning_search_space={tuning_search_space}', - f'--max_global_steps={MAX_STEPS}', - '--skip_evals', - '--nosave_checkpoints', - '--nosave_intermediate_checkpoints', - ] - # For JAX, activate the jax conda environment - activate_cmd = 'source $(conda info --base)/etc/profile.d/conda.sh && conda activate ap11_jax && ' - else: - cmd = [ - 'torchrun', - '--nproc_per_node=4', - '--standalone', - 'submission_runner.py', - f'--framework={framework}', - f'--workload={workload}', - f'--data_dir={data_dir}', - f'--experiment_dir={experiment_dir}', - f'--experiment_name={workload}_benchmark', - f'--submission_path={submission_path}', - f'--tuning_search_space={tuning_search_space}', - f'--max_global_steps={MAX_STEPS}', - '--skip_evals', - '--nosave_checkpoints', - '--nosave_intermediate_checkpoints', - ] - # For PyTorch, activate the torch conda environment - activate_cmd = 'source $(conda info --base)/etc/profile.d/conda.sh && conda activate ap11_torch_latest && ' - - # Run the command with shell to handle conda activation - full_cmd = activate_cmd + ' '.join(cmd) - print(f'Running: {workload} with {framework}') - print(f'Output will be saved to: {output_file}') - - with open(output_file, 'w') as f: - result = subprocess.run( - full_cmd, - shell=True, - executable='/bin/bash', - stdout=f, - stderr=subprocess.STDOUT, - cwd='/home/ak4605/aef2/', - ) - - return result.returncode == 0 - - -def parse_step_time(output_file: Path) -> float | None: - """Parse the last step_time_ms from output file.""" - if not output_file.exists(): - return None - - with open(output_file, 'r') as f: - content = f.read() - - # Find all step_time_ms values - # Pattern matches: step_time_ms=123.456 or 'step_time_ms': 123.456 - pattern = r'step_time_ms[=:]\s*([\d.]+)' - matches = re.findall(pattern, content) - - if matches: - # Return the last value (most recent EMA) - return float(matches[-1]) - return None - - -def parse_args(): - parser = argparse.ArgumentParser( - description='Benchmark step times for JAX and PyTorch across workloads.' - ) - group = parser.add_mutually_exclusive_group() - group.add_argument( - '--torch-only', - action='store_true', - help='Only run PyTorch experiments; read existing JAX results from files.', - ) - group.add_argument( - '--jax-only', - action='store_true', - help='Only run JAX experiments; read existing PyTorch results from files.', - ) - group.add_argument( - '--just-read', - action='store_true', - help='Do not run any experiments; just read and compare existing outputs.', - ) - return parser.parse_args() - - -def main(): - args = parse_args() - - # Create output directory - OUTPUT_DIR.mkdir(parents=True, exist_ok=True) - - results = {} - - # Determine which frameworks to run vs read from files - if args.just_read: - frameworks_to_run = [] - frameworks_to_read = FRAMEWORKS - elif args.torch_only: - frameworks_to_run = ['pytorch'] - frameworks_to_read = ['jax'] - elif args.jax_only: - frameworks_to_run = ['jax'] - frameworks_to_read = ['pytorch'] - else: - frameworks_to_run = FRAMEWORKS - frameworks_to_read = [] - - # Run all workloads - for workload in WORKLOADS: - results[workload] = {} - - # Read existing results from files - for framework in frameworks_to_read: - output_file = OUTPUT_DIR / f'{workload}_{framework}.out' - step_time = parse_step_time(output_file) - results[workload][framework] = step_time - if step_time: - print(f'\nLoaded existing {framework.upper()} result for {workload}: {step_time:.2f} ms') - else: - print(f'\nNo existing {framework.upper()} result found for {workload}') - - # Run experiments for specified frameworks - for framework in frameworks_to_run: - output_file = OUTPUT_DIR / f'{workload}_{framework}.out' - - print(f'\n{"=" * 60}') - print(f'Benchmarking {workload} with {framework}') - print(f'{"=" * 60}') - - success = run_workload(workload, framework, output_file) - - if success: - step_time = parse_step_time(output_file) - results[workload][framework] = step_time - print( - f'Step time: {step_time:.2f} ms' if step_time else 'Step time: N/A' - ) - else: - results[workload][framework] = None - print(f'Failed to run {workload} with {framework}') - - # Print results table - print('\n\n') - print('=' * 80) - print('STEP TIME COMPARISON (ms)') - print('=' * 80) - print( - f'{"Workload":<30} {"JAX (ms)":<15} {"PyTorch (ms)":<15} {"Ratio (PT/JAX)":<15}' - ) - print('-' * 80) - - for workload in WORKLOADS: - jax_time = results[workload].get('jax') - pytorch_time = results[workload].get('pytorch') - - jax_str = f'{jax_time:.2f}' if jax_time else 'N/A' - pytorch_str = f'{pytorch_time:.2f}' if pytorch_time else 'N/A' - - if jax_time and pytorch_time: - ratio = pytorch_time / jax_time - ratio_str = f'{ratio:.2f}x' - else: - ratio_str = 'N/A' - - print(f'{workload:<30} {jax_str:<15} {pytorch_str:<15} {ratio_str:<15}') - - print('=' * 80) - - # Save results to file - results_file = OUTPUT_DIR / 'results.txt' - with open(results_file, 'w') as f: - f.write('STEP TIME COMPARISON (ms)\n') - f.write('=' * 80 + '\n') - f.write( - f'{"Workload":<30} {"JAX (ms)":<15} {"PyTorch (ms)":<15} {"Ratio (PT/JAX)":<15}\n' - ) - f.write('-' * 80 + '\n') - - for workload in WORKLOADS: - jax_time = results[workload].get('jax') - pytorch_time = results[workload].get('pytorch') - - jax_str = f'{jax_time:.2f}' if jax_time else 'N/A' - pytorch_str = f'{pytorch_time:.2f}' if pytorch_time else 'N/A' - - if jax_time and pytorch_time: - ratio = pytorch_time / jax_time - ratio_str = f'{ratio:.2f}x' - else: - ratio_str = 'N/A' - - f.write( - f'{workload:<30} {jax_str:<15} {pytorch_str:<15} {ratio_str:<15}\n' - ) - - f.write('=' * 80 + '\n') - - print(f'\nResults saved to: {results_file}') - - -if __name__ == '__main__': - main() diff --git a/submission_runner.py b/submission_runner.py index bee6ea539..1432b8509 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -256,7 +256,6 @@ def train_once( 'librispeech_conformer', 'ogbg', 'criteo1tb', - 'imagenet_vit', 'librispeech_deepspeech', ] eager_backend_workloads = [] @@ -266,6 +265,7 @@ def train_once( 'librispeech_deepspeech', 'ogbg', 'wmt', + 'imagenet_vit', ] base_workload = workloads.get_base_workload_name(workload_name) if base_workload in compile_error_workloads: @@ -411,9 +411,8 @@ def train_once( train_step_end_time = get_time() if global_step == 11: step_10_end_time = train_step_end_time - + # Log step time every 100 steps - # Note: global_step was incremented, so use (global_step - 1) to match if (global_step - 1) % 100 == 0 and workload.metrics_logger is not None: if step_10_end_time is not None and global_step > 11: elapsed_time_ms = (train_step_end_time - step_10_end_time) * 1000.0 diff --git a/tests/test_step_times.py b/tests/test_step_times.py new file mode 100644 index 000000000..22868d67d --- /dev/null +++ b/tests/test_step_times.py @@ -0,0 +1,199 @@ +"""Tests that JAX and PyTorch step times are within 20% of each other. + +This test runs each workload for a number of steps with both JAX and PyTorch, +captures the step_time_ms metric, and asserts they are within 20%. +""" + +import re +import subprocess +import sys +import tempfile +from pathlib import Path + +from absl import flags, logging +from absl.testing import absltest, parameterized + +FLAGS = flags.FLAGS +FLAGS(sys.argv) + +MAX_STEPS = 101 +TOLERANCE = 0.25 + +WORKLOADS = [ + 'imagenet_vit', +] + +DATA_DIRS = { + 'imagenet_resnet': '/opt/data/imagenet/', + 'imagenet_vit': '/opt/data/imagenet/', + 'librispeech_conformer': '/opt/data/librispeech', + 'librispeech_deepspeech': '/opt/data/librispeech', + 'criteo1tb': '/opt/data/criteo1tb', + 'fastmri': '/opt/data/fastmri', + 'ogbg': '/opt/data/ogbg', + 'wmt': '/opt/data/wmt', +} + +CONDA_ENVS = { + 'jax': 'ap11_jax', + 'pytorch': 'ap11_torch_latest', +} + + +def get_data_dir(workload: str, framework: str) -> str: + """Map workload to its data directory.""" + base_dir = DATA_DIRS.get(workload, '/opt/data') + if workload in ['imagenet_resnet', 'imagenet_vit']: + return base_dir + framework + return base_dir + + +def run_workload(workload: str, framework: str, output_file: Path) -> bool: + """Run a workload and capture output to file.""" + data_dir = get_data_dir(workload, framework) + experiment_dir = tempfile.mkdtemp(prefix=f'{workload}_{framework}_') + + submission_path = ( + f'algorithms/baselines/external_tuning/{framework}_nadamw_full_budget.py' + ) + tuning_search_space = ( + 'algorithms/baselines/external_tuning/tuning_search_space.json' + ) + + if framework == 'jax': + cmd = [ + 'python', + 'submission_runner.py', + f'--framework={framework}', + f'--workload={workload}', + f'--data_dir={data_dir}', + f'--experiment_dir={experiment_dir}', + f'--experiment_name={workload}_benchmark', + f'--submission_path={submission_path}', + f'--tuning_search_space={tuning_search_space}', + f'--max_global_steps={MAX_STEPS}', + '--skip_evals', + '--nosave_checkpoints', + '--nosave_intermediate_checkpoints', + ] + else: + cmd = [ + 'torchrun', + '--nproc_per_node=4', + '--standalone', + 'submission_runner.py', + f'--framework={framework}', + f'--workload={workload}', + f'--data_dir={data_dir}', + f'--experiment_dir={experiment_dir}', + f'--experiment_name={workload}_benchmark', + f'--submission_path={submission_path}', + f'--tuning_search_space={tuning_search_space}', + f'--max_global_steps={MAX_STEPS}', + '--skip_evals', + '--nosave_checkpoints', + '--nosave_intermediate_checkpoints', + ] + + conda_env = CONDA_ENVS[framework] + activate_cmd = ( + f'source $(conda info --base)/etc/profile.d/conda.sh && ' + f'conda activate {conda_env} && ' + ) + full_cmd = activate_cmd + ' '.join(cmd) + + logging.info(f'Running: {workload} with {framework}') + logging.info(f'Output will be saved to: {output_file}') + + with open(output_file, 'w') as f: + result = subprocess.run( + full_cmd, + shell=True, + executable='/bin/bash', + stdout=f, + stderr=subprocess.STDOUT, + cwd=str(Path(__file__).parent.parent), + ) + + return result.returncode == 0 + + +def parse_step_time(output_file: Path) -> float | None: + """Parse the last step_time_ms from output file.""" + if not output_file.exists(): + return None + + with open(output_file, 'r') as f: + content = f.read() + + # Find all step_time_ms values + # Pattern matches: step_time_ms=123.456 or 'step_time_ms': 123.456 + pattern = r'step_time_ms[=:]\s*([\d.]+)' + matches = re.findall(pattern, content) + + if matches: + # Return the last value (most recent EMA) + return float(matches[-1]) + return None + + +named_parameters = [ + dict(testcase_name=workload, workload=workload) for workload in WORKLOADS +] + + +class StepTimeTest(parameterized.TestCase): + """Tests that JAX and PyTorch step times are within tolerance.""" + + @parameterized.named_parameters(*named_parameters) + def test_step_times_within_tolerance(self, workload): + """Test that JAX and PyTorch step times are within 20% of each other.""" + results = {} + + with tempfile.TemporaryDirectory() as tmpdir: + tmpdir = Path(tmpdir) + + for framework in ['jax', 'pytorch']: + output_file = tmpdir / f'{workload}_{framework}.out' + + success = run_workload(workload, framework, output_file) + self.assertTrue(success, f'Failed to run {workload} with {framework}') + + step_time = parse_step_time(output_file) + self.assertIsNotNone( + step_time, + f'Could not parse step_time_ms for {workload} with {framework}', + ) + + results[framework] = step_time + logging.info(f'{workload} {framework}: {step_time:.2f} ms') + + jax_time = results['jax'] + pytorch_time = results['pytorch'] + ratio = pytorch_time / jax_time + + logging.info( + f'{workload}: JAX={jax_time:.2f}ms, PyTorch={pytorch_time:.2f}ms, ' + f'ratio={ratio:.2f}' + ) + + # Check that ratio is within tolerance (0.8 to 1.2 for 20% tolerance) + lower_bound = 1.0 - TOLERANCE + upper_bound = 1.0 + TOLERANCE + + self.assertGreaterEqual( + ratio, + lower_bound, + f'{workload}: PyTorch is more than {TOLERANCE * 100:.0f}% faster than JAX ' + f'(ratio={ratio:.2f}, expected >= {lower_bound:.2f})', + ) + self.assertLessEqual( + ratio, + upper_bound, + f'{workload}: PyTorch is more than {TOLERANCE * 100:.0f}% slower than JAX ' + f'(ratio={ratio:.2f}, expected <= {upper_bound:.2f})', + ) + + +if __name__ == '__main__': + absltest.main() From b4d742c7cf20c69df02234b38b1e7a6b78b4d637 Mon Sep 17 00:00:00 2001 From: rka97 Date: Mon, 12 Jan 2026 00:55:46 +0000 Subject: [PATCH 13/17] Change num_workers for imagenet, add validation tests for step times --- .../workloads/cifar/cifar_pytorch/workload.py | 4 +- .../imagenet_pytorch/workload.py | 9 +- .../imagenet_vit/imagenet_pytorch/models.py | 14 +- .../pytorch_nadamw_full_budget.py | 1 - benchmark_step_times.py | 274 ------------------ submission_runner.py | 5 +- tests/test_step_times.py | 199 +++++++++++++ 7 files changed, 214 insertions(+), 292 deletions(-) delete mode 100755 benchmark_step_times.py create mode 100644 tests/test_step_times.py diff --git a/algoperf/workloads/cifar/cifar_pytorch/workload.py b/algoperf/workloads/cifar/cifar_pytorch/workload.py index a6e8569cc..f053fd828 100644 --- a/algoperf/workloads/cifar/cifar_pytorch/workload.py +++ b/algoperf/workloads/cifar/cifar_pytorch/workload.py @@ -110,12 +110,12 @@ def _build_dataset( batch_size=ds_iter_batch_size, shuffle=not USE_PYTORCH_DDP and is_train, sampler=sampler, - num_workers=4 if is_train else self.eval_num_workers, + num_workers=2 * N_GPUS if is_train else self.eval_num_workers, pin_memory=True, drop_last=is_train, ) - dataloader = data_utils.PrefetchedWrapper(dataloader, DEVICE) dataloader = data_utils.cycle(dataloader, custom_sampler=USE_PYTORCH_DDP) + dataloader = data_utils.dataloader_iterator_wrapper(dataloader, DEVICE) return dataloader def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState: diff --git a/algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py b/algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py index 07f84975e..7aa62b961 100644 --- a/algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py +++ b/algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py @@ -50,7 +50,7 @@ def __init__( rebuild_cache: bool = False, cache_build_timeout_minutes: int = 30, ): - self.root = os.path.expanduser(root) + self.root = os.path.abspath(root) self.transform = transform self.target_transform = target_transform self.loader = loader @@ -223,7 +223,7 @@ def _build_dataset( dataset = CachedImageFolder( os.path.join(data_dir, folder), transform=transform_config, - cache_file='.imagenet_cache_index.json', + cache_file='.imagenet_{}_cache_index.json'.format(split), ) if split == 'eval_train': @@ -248,16 +248,16 @@ def _build_dataset( sampler = data_utils.DistributedEvalSampler( dataset, num_replicas=N_GPUS, rank=RANK, shuffle=False ) - dataloader = torch.utils.data.DataLoader( dataset, batch_size=ds_iter_batch_size, shuffle=not USE_PYTORCH_DDP and is_train, sampler=sampler, - num_workers=4 if is_train else self.eval_num_workers, + num_workers=5 * N_GPUS if is_train else self.eval_num_workers, pin_memory=True, drop_last=is_train, persistent_workers=is_train, + prefetch_factor=N_GPUS if is_train else None, ) dataloader = data_utils.PrefetchedWrapper(dataloader, DEVICE) dataloader = data_utils.cycle( @@ -266,7 +266,6 @@ def _build_dataset( use_mixup=use_mixup, mixup_alpha=0.2, ) - return dataloader def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState: diff --git a/algoperf/workloads/imagenet_vit/imagenet_pytorch/models.py b/algoperf/workloads/imagenet_vit/imagenet_pytorch/models.py index fc2a3cd46..06df7ea75 100644 --- a/algoperf/workloads/imagenet_vit/imagenet_pytorch/models.py +++ b/algoperf/workloads/imagenet_vit/imagenet_pytorch/models.py @@ -5,7 +5,6 @@ and https://github.com/lucidrains/vit-pytorch. """ -import math from typing import Any, Optional, Tuple, Union import torch @@ -126,13 +125,14 @@ def forward(self, x: spec.Tensor, dropout_rate: float) -> spec.Tensor: value_layer = self.transpose_for_scores(self.value(x)) query_layer = self.transpose_for_scores(mixed_query_layer) - attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) - attention_scores = attention_scores / math.sqrt(self.head_dim) - - attention_probs = F.softmax(attention_scores, dim=-1) - attention_probs = F.dropout(attention_probs, dropout_rate, self.training) + # Use built-in scaled_dot_product_attention (Flash Attention when available) + context_layer = F.scaled_dot_product_attention( + query_layer, + key_layer, + value_layer, + dropout_p=dropout_rate if self.training else 0.0, + ) - context_layer = torch.matmul(attention_probs, value_layer) context_layer = context_layer.permute(0, 2, 1, 3).contiguous() new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_dim,) context_layer = context_layer.view(new_context_layer_shape) diff --git a/algorithms/baselines/external_tuning/pytorch_nadamw_full_budget.py b/algorithms/baselines/external_tuning/pytorch_nadamw_full_budget.py index 10f481d5f..494ada4c8 100644 --- a/algorithms/baselines/external_tuning/pytorch_nadamw_full_budget.py +++ b/algorithms/baselines/external_tuning/pytorch_nadamw_full_budget.py @@ -5,7 +5,6 @@ import torch import torch.distributed.nn as dist_nn -from absl import logging from torch import Tensor from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR diff --git a/benchmark_step_times.py b/benchmark_step_times.py deleted file mode 100755 index 363856734..000000000 --- a/benchmark_step_times.py +++ /dev/null @@ -1,274 +0,0 @@ -#!/usr/bin/env python3 -"""Benchmark step times for JAX and PyTorch across all workloads. - -This script runs each workload for 101 steps with both JAX and PyTorch, -captures the step_time_ms metric, and produces a comparison table. -""" - -import argparse -import re -import subprocess -from pathlib import Path - -# Base workloads to benchmark -WORKLOADS = [ - 'imagenet_resnet', -] - -FRAMEWORKS = ['jax', 'pytorch'] -MAX_STEPS = 201 -OUTPUT_DIR = Path('/home/ak4605/aef2/benchmark_outputs') - - -def get_data_dir(workload: str, framework: str) -> str: - """Map workload to its data directory.""" - if workload in ['imagenet_resnet', 'imagenet_vit']: - return '/opt/data/imagenet/' + framework - elif workload in ['librispeech_conformer', 'librispeech_deepspeech']: - return '/opt/data/librispeech' - elif workload == 'criteo1tb': - return '/opt/data/criteo1tb' - elif workload == 'fastmri': - return '/opt/data/fastmri' - elif workload == 'ogbg': - return '/opt/data/ogbg' - elif workload == 'wmt': - return '/opt/data/wmt' - else: - return '/opt/' - - -def run_workload(workload: str, framework: str, output_file: Path) -> bool: - """Run a workload and capture output to file.""" - data_dir = get_data_dir(workload, framework) - experiment_dir = '/home/ak4605/experiments' - - # Clean up previous experiment directories - for item in Path(experiment_dir).glob(f'{workload}*'): - if item.is_dir(): - subprocess.run(['rm', '-rf', str(item)], check=True) - - # Build command based on framework - submission_path = ( - f'algorithms/baselines/external_tuning/{framework}_nadamw_full_budget.py' - ) - tuning_search_space = ( - 'algorithms/baselines/external_tuning/tuning_search_space.json' - ) - - if framework == 'jax': - cmd = [ - 'python', - 'submission_runner.py', - f'--framework={framework}', - f'--workload={workload}', - f'--data_dir={data_dir}', - f'--experiment_dir={experiment_dir}', - f'--experiment_name={workload}_benchmark', - f'--submission_path={submission_path}', - f'--tuning_search_space={tuning_search_space}', - f'--max_global_steps={MAX_STEPS}', - '--skip_evals', - '--nosave_checkpoints', - '--nosave_intermediate_checkpoints', - ] - # For JAX, activate the jax conda environment - activate_cmd = 'source $(conda info --base)/etc/profile.d/conda.sh && conda activate ap11_jax && ' - else: - cmd = [ - 'torchrun', - '--nproc_per_node=4', - '--standalone', - 'submission_runner.py', - f'--framework={framework}', - f'--workload={workload}', - f'--data_dir={data_dir}', - f'--experiment_dir={experiment_dir}', - f'--experiment_name={workload}_benchmark', - f'--submission_path={submission_path}', - f'--tuning_search_space={tuning_search_space}', - f'--max_global_steps={MAX_STEPS}', - '--skip_evals', - '--nosave_checkpoints', - '--nosave_intermediate_checkpoints', - ] - # For PyTorch, activate the torch conda environment - activate_cmd = 'source $(conda info --base)/etc/profile.d/conda.sh && conda activate ap11_torch_latest && ' - - # Run the command with shell to handle conda activation - full_cmd = activate_cmd + ' '.join(cmd) - print(f'Running: {workload} with {framework}') - print(f'Output will be saved to: {output_file}') - - with open(output_file, 'w') as f: - result = subprocess.run( - full_cmd, - shell=True, - executable='/bin/bash', - stdout=f, - stderr=subprocess.STDOUT, - cwd='/home/ak4605/aef2/', - ) - - return result.returncode == 0 - - -def parse_step_time(output_file: Path) -> float | None: - """Parse the last step_time_ms from output file.""" - if not output_file.exists(): - return None - - with open(output_file, 'r') as f: - content = f.read() - - # Find all step_time_ms values - # Pattern matches: step_time_ms=123.456 or 'step_time_ms': 123.456 - pattern = r'step_time_ms[=:]\s*([\d.]+)' - matches = re.findall(pattern, content) - - if matches: - # Return the last value (most recent EMA) - return float(matches[-1]) - return None - - -def parse_args(): - parser = argparse.ArgumentParser( - description='Benchmark step times for JAX and PyTorch across workloads.' - ) - group = parser.add_mutually_exclusive_group() - group.add_argument( - '--torch-only', - action='store_true', - help='Only run PyTorch experiments; read existing JAX results from files.', - ) - group.add_argument( - '--jax-only', - action='store_true', - help='Only run JAX experiments; read existing PyTorch results from files.', - ) - group.add_argument( - '--just-read', - action='store_true', - help='Do not run any experiments; just read and compare existing outputs.', - ) - return parser.parse_args() - - -def main(): - args = parse_args() - - # Create output directory - OUTPUT_DIR.mkdir(parents=True, exist_ok=True) - - results = {} - - # Determine which frameworks to run vs read from files - if args.just_read: - frameworks_to_run = [] - frameworks_to_read = FRAMEWORKS - elif args.torch_only: - frameworks_to_run = ['pytorch'] - frameworks_to_read = ['jax'] - elif args.jax_only: - frameworks_to_run = ['jax'] - frameworks_to_read = ['pytorch'] - else: - frameworks_to_run = FRAMEWORKS - frameworks_to_read = [] - - # Run all workloads - for workload in WORKLOADS: - results[workload] = {} - - # Read existing results from files - for framework in frameworks_to_read: - output_file = OUTPUT_DIR / f'{workload}_{framework}.out' - step_time = parse_step_time(output_file) - results[workload][framework] = step_time - if step_time: - print(f'\nLoaded existing {framework.upper()} result for {workload}: {step_time:.2f} ms') - else: - print(f'\nNo existing {framework.upper()} result found for {workload}') - - # Run experiments for specified frameworks - for framework in frameworks_to_run: - output_file = OUTPUT_DIR / f'{workload}_{framework}.out' - - print(f'\n{"=" * 60}') - print(f'Benchmarking {workload} with {framework}') - print(f'{"=" * 60}') - - success = run_workload(workload, framework, output_file) - - if success: - step_time = parse_step_time(output_file) - results[workload][framework] = step_time - print( - f'Step time: {step_time:.2f} ms' if step_time else 'Step time: N/A' - ) - else: - results[workload][framework] = None - print(f'Failed to run {workload} with {framework}') - - # Print results table - print('\n\n') - print('=' * 80) - print('STEP TIME COMPARISON (ms)') - print('=' * 80) - print( - f'{"Workload":<30} {"JAX (ms)":<15} {"PyTorch (ms)":<15} {"Ratio (PT/JAX)":<15}' - ) - print('-' * 80) - - for workload in WORKLOADS: - jax_time = results[workload].get('jax') - pytorch_time = results[workload].get('pytorch') - - jax_str = f'{jax_time:.2f}' if jax_time else 'N/A' - pytorch_str = f'{pytorch_time:.2f}' if pytorch_time else 'N/A' - - if jax_time and pytorch_time: - ratio = pytorch_time / jax_time - ratio_str = f'{ratio:.2f}x' - else: - ratio_str = 'N/A' - - print(f'{workload:<30} {jax_str:<15} {pytorch_str:<15} {ratio_str:<15}') - - print('=' * 80) - - # Save results to file - results_file = OUTPUT_DIR / 'results.txt' - with open(results_file, 'w') as f: - f.write('STEP TIME COMPARISON (ms)\n') - f.write('=' * 80 + '\n') - f.write( - f'{"Workload":<30} {"JAX (ms)":<15} {"PyTorch (ms)":<15} {"Ratio (PT/JAX)":<15}\n' - ) - f.write('-' * 80 + '\n') - - for workload in WORKLOADS: - jax_time = results[workload].get('jax') - pytorch_time = results[workload].get('pytorch') - - jax_str = f'{jax_time:.2f}' if jax_time else 'N/A' - pytorch_str = f'{pytorch_time:.2f}' if pytorch_time else 'N/A' - - if jax_time and pytorch_time: - ratio = pytorch_time / jax_time - ratio_str = f'{ratio:.2f}x' - else: - ratio_str = 'N/A' - - f.write( - f'{workload:<30} {jax_str:<15} {pytorch_str:<15} {ratio_str:<15}\n' - ) - - f.write('=' * 80 + '\n') - - print(f'\nResults saved to: {results_file}') - - -if __name__ == '__main__': - main() diff --git a/submission_runner.py b/submission_runner.py index bee6ea539..1432b8509 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -256,7 +256,6 @@ def train_once( 'librispeech_conformer', 'ogbg', 'criteo1tb', - 'imagenet_vit', 'librispeech_deepspeech', ] eager_backend_workloads = [] @@ -266,6 +265,7 @@ def train_once( 'librispeech_deepspeech', 'ogbg', 'wmt', + 'imagenet_vit', ] base_workload = workloads.get_base_workload_name(workload_name) if base_workload in compile_error_workloads: @@ -411,9 +411,8 @@ def train_once( train_step_end_time = get_time() if global_step == 11: step_10_end_time = train_step_end_time - + # Log step time every 100 steps - # Note: global_step was incremented, so use (global_step - 1) to match if (global_step - 1) % 100 == 0 and workload.metrics_logger is not None: if step_10_end_time is not None and global_step > 11: elapsed_time_ms = (train_step_end_time - step_10_end_time) * 1000.0 diff --git a/tests/test_step_times.py b/tests/test_step_times.py new file mode 100644 index 000000000..22868d67d --- /dev/null +++ b/tests/test_step_times.py @@ -0,0 +1,199 @@ +"""Tests that JAX and PyTorch step times are within 20% of each other. + +This test runs each workload for a number of steps with both JAX and PyTorch, +captures the step_time_ms metric, and asserts they are within 20%. +""" + +import re +import subprocess +import sys +import tempfile +from pathlib import Path + +from absl import flags, logging +from absl.testing import absltest, parameterized + +FLAGS = flags.FLAGS +FLAGS(sys.argv) + +MAX_STEPS = 101 +TOLERANCE = 0.25 + +WORKLOADS = [ + 'imagenet_vit', +] + +DATA_DIRS = { + 'imagenet_resnet': '/opt/data/imagenet/', + 'imagenet_vit': '/opt/data/imagenet/', + 'librispeech_conformer': '/opt/data/librispeech', + 'librispeech_deepspeech': '/opt/data/librispeech', + 'criteo1tb': '/opt/data/criteo1tb', + 'fastmri': '/opt/data/fastmri', + 'ogbg': '/opt/data/ogbg', + 'wmt': '/opt/data/wmt', +} + +CONDA_ENVS = { + 'jax': 'ap11_jax', + 'pytorch': 'ap11_torch_latest', +} + + +def get_data_dir(workload: str, framework: str) -> str: + """Map workload to its data directory.""" + base_dir = DATA_DIRS.get(workload, '/opt/data') + if workload in ['imagenet_resnet', 'imagenet_vit']: + return base_dir + framework + return base_dir + + +def run_workload(workload: str, framework: str, output_file: Path) -> bool: + """Run a workload and capture output to file.""" + data_dir = get_data_dir(workload, framework) + experiment_dir = tempfile.mkdtemp(prefix=f'{workload}_{framework}_') + + submission_path = ( + f'algorithms/baselines/external_tuning/{framework}_nadamw_full_budget.py' + ) + tuning_search_space = ( + 'algorithms/baselines/external_tuning/tuning_search_space.json' + ) + + if framework == 'jax': + cmd = [ + 'python', + 'submission_runner.py', + f'--framework={framework}', + f'--workload={workload}', + f'--data_dir={data_dir}', + f'--experiment_dir={experiment_dir}', + f'--experiment_name={workload}_benchmark', + f'--submission_path={submission_path}', + f'--tuning_search_space={tuning_search_space}', + f'--max_global_steps={MAX_STEPS}', + '--skip_evals', + '--nosave_checkpoints', + '--nosave_intermediate_checkpoints', + ] + else: + cmd = [ + 'torchrun', + '--nproc_per_node=4', + '--standalone', + 'submission_runner.py', + f'--framework={framework}', + f'--workload={workload}', + f'--data_dir={data_dir}', + f'--experiment_dir={experiment_dir}', + f'--experiment_name={workload}_benchmark', + f'--submission_path={submission_path}', + f'--tuning_search_space={tuning_search_space}', + f'--max_global_steps={MAX_STEPS}', + '--skip_evals', + '--nosave_checkpoints', + '--nosave_intermediate_checkpoints', + ] + + conda_env = CONDA_ENVS[framework] + activate_cmd = ( + f'source $(conda info --base)/etc/profile.d/conda.sh && ' + f'conda activate {conda_env} && ' + ) + full_cmd = activate_cmd + ' '.join(cmd) + + logging.info(f'Running: {workload} with {framework}') + logging.info(f'Output will be saved to: {output_file}') + + with open(output_file, 'w') as f: + result = subprocess.run( + full_cmd, + shell=True, + executable='/bin/bash', + stdout=f, + stderr=subprocess.STDOUT, + cwd=str(Path(__file__).parent.parent), + ) + + return result.returncode == 0 + + +def parse_step_time(output_file: Path) -> float | None: + """Parse the last step_time_ms from output file.""" + if not output_file.exists(): + return None + + with open(output_file, 'r') as f: + content = f.read() + + # Find all step_time_ms values + # Pattern matches: step_time_ms=123.456 or 'step_time_ms': 123.456 + pattern = r'step_time_ms[=:]\s*([\d.]+)' + matches = re.findall(pattern, content) + + if matches: + # Return the last value (most recent EMA) + return float(matches[-1]) + return None + + +named_parameters = [ + dict(testcase_name=workload, workload=workload) for workload in WORKLOADS +] + + +class StepTimeTest(parameterized.TestCase): + """Tests that JAX and PyTorch step times are within tolerance.""" + + @parameterized.named_parameters(*named_parameters) + def test_step_times_within_tolerance(self, workload): + """Test that JAX and PyTorch step times are within 20% of each other.""" + results = {} + + with tempfile.TemporaryDirectory() as tmpdir: + tmpdir = Path(tmpdir) + + for framework in ['jax', 'pytorch']: + output_file = tmpdir / f'{workload}_{framework}.out' + + success = run_workload(workload, framework, output_file) + self.assertTrue(success, f'Failed to run {workload} with {framework}') + + step_time = parse_step_time(output_file) + self.assertIsNotNone( + step_time, + f'Could not parse step_time_ms for {workload} with {framework}', + ) + + results[framework] = step_time + logging.info(f'{workload} {framework}: {step_time:.2f} ms') + + jax_time = results['jax'] + pytorch_time = results['pytorch'] + ratio = pytorch_time / jax_time + + logging.info( + f'{workload}: JAX={jax_time:.2f}ms, PyTorch={pytorch_time:.2f}ms, ' + f'ratio={ratio:.2f}' + ) + + # Check that ratio is within tolerance (0.8 to 1.2 for 20% tolerance) + lower_bound = 1.0 - TOLERANCE + upper_bound = 1.0 + TOLERANCE + + self.assertGreaterEqual( + ratio, + lower_bound, + f'{workload}: PyTorch is more than {TOLERANCE * 100:.0f}% faster than JAX ' + f'(ratio={ratio:.2f}, expected >= {lower_bound:.2f})', + ) + self.assertLessEqual( + ratio, + upper_bound, + f'{workload}: PyTorch is more than {TOLERANCE * 100:.0f}% slower than JAX ' + f'(ratio={ratio:.2f}, expected <= {upper_bound:.2f})', + ) + + +if __name__ == '__main__': + absltest.main() From 400640a4553dafe8e880de22d7dc5b4ddb30ba37 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 29 Jan 2026 03:30:17 +0000 Subject: [PATCH 14/17] update pytorch package --- pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index e1fc84987..534f5d678 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -108,8 +108,8 @@ jax_gpu = [ ] pytorch_cpu = [ - "torch==2.5.1", - "torchvision==0.20.1" + "torch==2.9.0", + "torchvision==0.24.0" ] pytorch_gpu = [ "torch==2.9.0", From 6e2bad461c7de9d0554b48101a8d178a145e86d0 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 29 Jan 2026 03:31:39 +0000 Subject: [PATCH 15/17] remove logging from submission runner --- submission_runner.py | 16 ---------------- 1 file changed, 16 deletions(-) diff --git a/submission_runner.py b/submission_runner.py index 1432b8509..d15bda74b 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -352,7 +352,6 @@ def train_once( log_dir, flags.FLAGS, hyperparameters ) workload.attach_metrics_logger(metrics_logger) - step_10_end_time = None global_start_time = get_time() train_state['last_step_end_time'] = global_start_time @@ -409,21 +408,6 @@ def train_once( train_state['training_complete'] = True train_step_end_time = get_time() - if global_step == 11: - step_10_end_time = train_step_end_time - - # Log step time every 100 steps - if (global_step - 1) % 100 == 0 and workload.metrics_logger is not None: - if step_10_end_time is not None and global_step > 11: - elapsed_time_ms = (train_step_end_time - step_10_end_time) * 1000.0 - elapsed_steps = global_step - 11 - avg_step_time_ms = elapsed_time_ms / elapsed_steps - else: - avg_step_time_ms = 0.0 - workload.metrics_logger.append_scalar_metrics( - {'step_time_ms': avg_step_time_ms}, - global_step - 1, - ) train_state['accumulated_submission_time'] += ( train_step_end_time - train_state['last_step_end_time'] From a13e8b668b46eb10c00664c2452989bcb35c05f7 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 29 Jan 2026 03:40:36 +0000 Subject: [PATCH 16/17] update documentation --- README.md | 2 +- docs/DOCUMENTATION.md | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index f8e6763b4..71595e11b 100644 --- a/README.md +++ b/README.md @@ -31,7 +31,7 @@ The MLCommons™ **AlgoPerf: Training Algorithms benchmark** is designed to find When training neural nets, practitioners face many critical yet often opaque decisions: What optimizer to choose? How should its learning rate be tuned? What learning rate schedule should be used? These choices can make or break training, yet the community has lacked a clear, standardized way to identify the state of the art. Unlike benchmarks focused on hardware or model architecture, AlgoPerf isolates the **training algorithm** itself, which includes the optimizer, regularization, data selection, and hyperparameters like the learning rate schedule. By standardizing the benchmark process, AlgoPerf offers a meaningful apples-to-apples comparison of training algorithms and follows the following **key principles**: -- 🎯 **Fixed Target, Model & Hardware:** Submitted training algorithms must train a set of [**fixed models**](/docs/DOCUMENTATION.md#workloads) to a pre-defined validation performance target as fast as possible. All submissions use the same model architecture and are run on the same [**standardized hardware**](/docs/DOCUMENTATION.md#benchmarking-hardware) (8x NVIDIA V100 GPUs). This isolates the training algorithm's performance and allows a fair apples-to-apples comparison. +- 🎯 **Fixed Target, Model & Hardware:** Submitted training algorithms must train a set of [**fixed models**](/docs/DOCUMENTATION.md#workloads) to a pre-defined validation performance target as fast as possible. All submissions use the same model architecture and are run on the same [**standardized hardware**](/docs/DOCUMENTATION.md#benchmarking-hardware) (4x A100 (40GB) GPUs). This isolates the training algorithm's performance and allows a fair apples-to-apples comparison. - ⏱️ **Time-To-Result:** Submissions are evaluated based on the total wall-clock time required to reach the target, rewarding practical and efficient algorithms. - 🧠 **Diverse Workloads:** The benchmark includes [**8 diverse deep learning workloads**](/docs/DOCUMENTATION.md#workloads) across domains like image classification, speech recognition, and machine translation. A submission's score is computed by aggregating its performance, using [**performance profiles**](/docs/DOCUMENTATION.md#benchmark-score-using-performance-profiles), across all workloads to ensure general-purpose algorithms. - 📦 **Fully-Specified Algorithms:** Submissions must be complete procedures and thus hyperparameter tuning is treated as part of the algorithm. Submissions can either provide a search space for automated tuning ([**External tuning ruleset**](/docs/DOCUMENTATION.md#external-tuning-ruleset)) or be hyperparameter-free ([**Self-tuning ruleset**](/docs/DOCUMENTATION.md#self-tuning-ruleset)) with any tuning done automatically and "on the clock". This measures an algorithm's _total_ practical cost and provides practitioners with a complete method, eliminating the guesswork of how to apply it. diff --git a/docs/DOCUMENTATION.md b/docs/DOCUMENTATION.md index f7ac5e659..49e738408 100644 --- a/docs/DOCUMENTATION.md +++ b/docs/DOCUMENTATION.md @@ -55,7 +55,7 @@ The **AlgoPerf: Training Algorithms benchmark** challenges participants to submi The benchmarking process follows these **key principles**: -- 🎯 **Fixed Target, Model & Hardware:** Submitted training algorithms must train a set of [**fixed models**](#workloads) to a pre-defined validation performance target as fast as possible. All submissions use the same model architecture and are run on the same [**standardized hardware**](#benchmarking-hardware) (currently `8x NVIDIA V100 GPUs`). This isolates the training algorithm's performance and allows a fair apples-to-apples comparison. +- 🎯 **Fixed Target, Model & Hardware:** Submitted training algorithms must train a set of [**fixed models**](#workloads) to a pre-defined validation performance target as fast as possible. All submissions use the same model architecture and are run on the same [**standardized hardware**](#benchmarking-hardware) (currently `4x NVIDIA A100 GPUs`). This isolates the training algorithm's performance and allows a fair apples-to-apples comparison. - ⏱️ **Time-To-Result:** Submissions are evaluated based on the total wall-clock time required to reach the target, rewarding practical and efficient algorithms. - 🧠 **Diverse Workloads:** The benchmark includes [**8 diverse deep learning workloads**](#workloads) across domains like image classification, speech recognition, and machine translation. A submission's score is computed by aggregating its performance across all workloads, using [**performance profiles**](#algoperf-benchmark-score-via-integrated-performance-profiles), to ensure general-purpose algorithms. - 📦 **Fully-Specified Algorithms:** Submissions must be [**complete procedures**](#submission-api) and thus hyperparameter tuning is treated as part of the algorithm. Depending on the [**ruleset**](#tuning-rulesets), submissions may use parallel tuning resources. This ensures that the benchmark measures the _total_ practical cost of a training algorithm and provides practitioners with a complete method, eliminating the guesswork of how to apply it. @@ -542,7 +542,7 @@ All officially scored runs will be performed on the same benchmarking hardware t This benchmarking hardware is chosen to be easily accessible via common cloud computing providers and will likely change with each iteration of the benchmark. The specs of the benchmarking hardware for this iteration of the benchmark are: -- 8× NVIDIA V100 (16 GB) GPUs +- 4× NVIDIA A100 (40 GB) GPUs - 240 GB in RAM - 2 TB in storage (for datasets). @@ -595,7 +595,7 @@ Furthermore, all submitters must sign the following agreements:
My machine only has one GPU. How can I use this repo? -> You can run this repo on a machine with an arbitrary number of GPUs. However, the default batch sizes of our algorithms collection (e.g. `algorithms/`) are tuned for a machine with 8× NVIDIA V100 (16 GB) GPUs. You may run into OOMs if you run these algorithms with fewer than 8 GPUs. If you run into these issues because you are using a machine with less total GPU memory, please reduce the batch sizes for the submission. Note that your final submission must 'fit' on the [**benchmarking hardware**](#benchmarking-hardware), so if you are using fewer GPUs with higher per-GPU memory, please monitor your memory usage to make sure it will fit on 8× NVIDIA V100 GPUs with 16 GB of VRAM per card. +> You can run this repo on a machine with an arbitrary number of GPUs. However, the default batch sizes of our algorithms collection (e.g. `algorithms/`) are tuned for a machine with 4× NVIDIA A100 (40 GB) GPUs. You may run into OOMs if you run these algorithms with fewer than 8 GPUs. If you run into these issues because you are using a machine with less total GPU memory, please reduce the batch sizes for the submission. Note that your final submission must 'fit' on the [**benchmarking hardware**](#benchmarking-hardware), so if you are using fewer GPUs with higher per-GPU memory, please monitor your memory usage to make sure it will fit on 4× NVIDIA A100 GPUs with 40 GB of VRAM per card.
From f7ce628dabe8f48869b08df4627dabb22ae79d2c Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 29 Jan 2026 03:44:51 +0000 Subject: [PATCH 17/17] fix --- algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py | 1 - 1 file changed, 1 deletion(-) diff --git a/algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py b/algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py index 8cfc4769f..289136bfb 100644 --- a/algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py +++ b/algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py @@ -254,7 +254,6 @@ def _build_dataset( shuffle=not USE_PYTORCH_DDP and is_train, sampler=sampler, num_workers=5 * N_GPUS if is_train else self.eval_num_workers, - num_workers=5 * N_GPUS if is_train else self.eval_num_workers, pin_memory=True, drop_last=is_train, persistent_workers=is_train,