
Commit eb6bb3c

Fix L3 test failures after merge
1 parent 6eb2707 commit eb6bb3c

8 files changed

Lines changed: 32 additions & 32 deletions


tests/jax/test_custom_call_compute.py

Lines changed: 15 additions & 9 deletions
@@ -1450,22 +1450,24 @@ def ref_func(x, w, data_layout):
         assert_allclose(primitive_w_grad, ref_w_grad, dtype=jnp.bfloat16)

     @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
-    @pytest_parametrize_wrapper("m,n,k", TEST_SHAPES)
-    @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
+    @pytest_parametrize_wrapper("m,n,k", [(64, 128, 128), (128, 256, 256)])
+    @pytest_parametrize_wrapper("recipe", supported_recipes)
     @pytest_parametrize_wrapper("with_jax_gemm", [False, True])
     @pytest_parametrize_wrapper("use_bias", [False, True] if is_hip_extension() else [True])
-    def test_dense_grad_fp8(self, m, n, k, scaling_mode, with_jax_gemm, use_bias):
+    def test_dense_grad_fp8_and_fp4(self, m, n, k, recipe, with_jax_gemm, use_bias):
         data_layout = "NN"
         x, w, contracting_dims = self._generate_gemm_input(m, n, k, data_layout)

         key = jax.random.PRNGKey(1)
         bias = jax.random.uniform(key, n, dtype=jnp.bfloat16) if use_bias else None

-        if scaling_mode.is_1d_block_scaling():
+        if recipe.__class__.__name__ == "MXFP8BlockScaling":
             # Check for first GEMM
             _check_mxfp8_gemm_support(with_jax_gemm, m, n, k, use_bias)
             # Check for second GEMM
             _check_mxfp8_gemm_support(with_jax_gemm, m, k, n, use_bias)
+            # Check for third GEMM
+            _check_mxfp8_gemm_support(with_jax_gemm, k, n, m, use_bias)


         def primitive_func(x, w, bias, contracting_dims, quantizer_set):
@@ -1530,19 +1532,21 @@ def _ref_jax_norm_impl(x, gamma, beta, norm_type, zero_centered_gamma, eps, quan

 class TestFusedDense:
     @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
-    @pytest.mark.parametrize("m,n,k", [(64, 128, 128)])
+    @pytest.mark.parametrize("m,n,k", [(64, 128, 128), (128, 256, 256)])
     @pytest_parametrize_wrapper("recipe", supported_recipes)
     @pytest.mark.parametrize("norm_type", ["layernorm", "rmsnorm"])
     @pytest_parametrize_wrapper("with_jax_gemm", [False, True])
     def test_layernorm_dense_grad(self, m, n, k, recipe, norm_type, with_jax_gemm):
         """
         Test layernorm_dense VJP Rule
         """
-        if scaling_mode.is_1d_block_scaling():
+        if recipe.__class__.__name__ == "MXFP8BlockScaling":
             # Check for fwd GEMM
             _check_mxfp8_gemm_support(with_jax_gemm, m, n, k)
-            # Check for bwd GEMM
+            # Check for first bwd GEMM
             _check_mxfp8_gemm_support(with_jax_gemm, m, k, n)
+            # Check for second bwd GEMM
+            _check_mxfp8_gemm_support(with_jax_gemm, k, n, m)
         # zero_centered_gamma is already tested in TestNorm
         zero_centered_gamma = False
         eps = 1e-6
@@ -1614,7 +1618,7 @@ def ref_func(x, w, gamma, beta):
         assert_allclose(prim_beta_grad, ref_beta_grad, dtype=quantizer_set.dgrad.q_dtype)

     @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
-    @pytest.mark.parametrize("m,n,k", [(64, 128, 128)])
+    @pytest.mark.parametrize("m,n,k", [(64, 128, 128), (128, 256, 256)])
     @pytest.mark.parametrize("activation_type", [("gelu",), ("gelu", "linear")])
     @pytest_parametrize_wrapper("recipe", supported_recipes)
     @pytest.mark.parametrize("norm_type", ["layernorm", "rmsnorm"])
@@ -1626,11 +1630,13 @@ def test_layernorm_mlp_grad(
         """
         Test layernorm_mlp VJP Rule
         """
-        if scaling_mode.is_1d_block_scaling():
+        if recipe.__class__.__name__ == "MXFP8BlockScaling":
             # Check for first GEMM
             _check_mxfp8_gemm_support(with_jax_gemm, m, n, k, use_bias)
             # Check for second GEMM
             _check_mxfp8_gemm_support(with_jax_gemm, m, k, n, use_bias)
+            # Check for third GEMM
+            _check_mxfp8_gemm_support(with_jax_gemm, k, n, m, use_bias)

         # zero_centered_gamma is already tested in TestNorm
         zero_centered_gamma = False
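For context on the new third check: a dense layer's forward and backward pass runs three GEMMs, and the (m, n, k) triples passed to _check_mxfp8_gemm_support appear to name each GEMM's output rows, output columns, and reduction dimension (an assumption based on the call sites, not stated in the diff). A minimal shape-only sketch:

    import jax.numpy as jnp

    m, n, k = 64, 128, 128
    x = jnp.zeros((m, k), dtype=jnp.bfloat16)    # activations
    w = jnp.zeros((k, n), dtype=jnp.bfloat16)    # weights
    dy = jnp.zeros((m, n), dtype=jnp.bfloat16)   # incoming output gradient

    y = x @ w      # forward GEMM: (m, k) x (k, n) -> (m, n), dims (m, n, k)
    dx = dy @ w.T  # dgrad GEMM:   (m, n) x (n, k) -> (m, k), dims (m, k, n)
    dw = x.T @ dy  # wgrad GEMM:   (k, m) x (m, n) -> (k, n), dims (k, n, m)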

tests/jax/test_distributed_layernorm_mlp.py

Lines changed: 3 additions & 3 deletions
@@ -217,11 +217,11 @@ def _test_layernorm_mlp_grad(
         is_hip_extension()
         and (not with_jax_gemm)
         and use_bias
-        and (fp8_recipe is None)
+        and (quantization_recipe is None)
         and (dtype == jnp.bfloat16)
     ):
         pytest.xfail("Skip known failure case.")
-    if isinstance(fp8_recipe, recipe.MXFP8BlockScaling):
+    if isinstance(quantization_recipe, recipe.MXFP8BlockScaling):
         _check_mxfp8_layernorm_mlp_grad_support(
             input_shape[0]*input_shape[1],
             INTERMEDIATE,
@@ -410,7 +410,7 @@ def _test_layernorm_mlp(
     use_shardy,
     with_jax_gemm,
 ):
-    if isinstance(fp8_recipe, recipe.MXFP8BlockScaling):
+    if isinstance(quantization_recipe, recipe.MXFP8BlockScaling):
         _check_mxfp8_layernorm_mlp_support(
             input_shape[0]*input_shape[1],
             INTERMEDIATE,

tests/pytorch/distributed/run_fsdp2_fp8_model.py

Lines changed: 2 additions & 2 deletions
@@ -18,7 +18,7 @@
 from torch.distributed._composable.fsdp import fully_shard
 from torch.distributed.device_mesh import init_device_mesh
 from transformer_engine.pytorch import torch_version
-from transformer_engine.pytorch.fp8 import fp8_model_init
+from transformer_engine.pytorch.quantization import quantized_model_init
 from torch.nn.parallel import DistributedDataParallel as DDP
 from pathlib import Path

@@ -171,7 +171,7 @@ def _train(args):
         torch.cuda.memory._record_memory_history(enabled='all', context='all', stacks='all')
     if args.fp8_init:
         # Build the model with the specified context
-        with fp8_model_init(enabled = True):
+        with quantized_model_init(enabled=True):
             model = SimpleNet(args.input_size, args.hidden_size, args.output_size, use_fsdp2=args.use_fsdp2)
     else:
         model = SimpleNet(args.input_size, args.hidden_size, args.output_size, use_fsdp2=args.use_fsdp2)

tests/pytorch/distributed/test_torch_fsdp2_fp8.py

Lines changed: 8 additions & 13 deletions
@@ -1,5 +1,5 @@
 #!/usr/bin/python3
-# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
+# Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved.
 # See LICENSE for license information.

 import os
@@ -8,7 +8,7 @@
 import subprocess
 from pathlib import Path
 from transformer_engine.pytorch import torch_version
-from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
+from transformer_engine.pytorch.quantization import FP8GlobalStateManager
 import torch
 from run_fsdp2_fp8_model import SimpleNet

@@ -17,20 +17,15 @@

 NUM_PROCS: int = torch.cuda.device_count()

-def assert_allclose(
-    l1: List[torch.Tensor], l2: List[torch.Tensor], atol: float, rtol: float = None
-) -> bool:
-    """Ensures two lists are equal."""
+def assertEqual(
+    l1: List[torch.Tensor], l2: List[torch.Tensor]) -> bool:
+    """Ensures two lists are exactly equal."""
     assert len(l1) == len(l2), "Unequal number of outputs."
     for i, (t1, t2) in enumerate(zip(l1, l2)):
-        tols = dict(atol=atol)
-        if rtol is not None:
-            tols["rtol"] = rtol
-        result = torch.allclose(t1, t2, **tols)
+        result = torch.allclose(t1, t2, atol=0, rtol=0)
         if not result:
             diff = torch.abs(t1 - t2)
-            tol = atol + (rtol * torch.abs(t2))
-            exceed_mask = diff > tol
+            exceed_mask = diff > 0
             if exceed_mask.any():
                 indices = torch.nonzero(exceed_mask, as_tuple=True)
                 max_diff = diff[exceed_mask].max()
@@ -64,7 +59,7 @@ def _run_test(fp_init, recipe):
     for idx, (te_output_no_cache, te_output_cache) in enumerate(zip(output_fsdp, output_dp)):

         print(f"Comparing FSDP {te_output_no_cache[0]}, DDP {te_output_cache[0]} at index {idx}...")
-        assert_allclose(te_output_no_cache[1], te_output_cache[1], atol=0, rtol=0)
+        assertEqual(te_output_no_cache[1], te_output_cache[1])  # expects exact match
         print(f"Tensor at index {idx} passed comparison.")
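The renamed assertEqual helper relies on torch.allclose with both tolerances set to zero. Because allclose tests |t1 - t2| <= atol + rtol * |t2|, zero tolerances reduce it to element-wise equality, so the FSDP and DDP outputs must match exactly. A standalone illustration (not part of the test itself):

    import torch

    t1 = torch.tensor([1.0, 2.0, 3.0])
    t2 = t1.clone()
    assert torch.allclose(t1, t2, atol=0, rtol=0)  # zero tolerances: exact match required
    assert torch.equal(t1, t2)                     # an equivalent exact-equality check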

transformer_engine/common/fused_attn_rocm/fused_attn.cpp

Lines changed: 1 addition & 1 deletion
@@ -283,7 +283,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
   using namespace transformer_engine;

   // TODO: Add return_max_logit support
-  if (return_max_logit || cuda_graph) return NVTE_Fused_Attn_Backend::NVTE_No_Backend;
+  if (return_max_logit) return NVTE_Fused_Attn_Backend::NVTE_No_Backend;

   // by default, fused attn is enabled
   bool nvte_fused_attn = true;

transformer_engine/jax/quantize/scaling_modes.py

Lines changed: 2 additions & 2 deletions
@@ -26,7 +26,7 @@
 from transformer_engine_jax import JAXX_Scaling_Mode
 from .misc import QuantizeLayout
 from .device_utils import is_fp8_gemm_with_all_layouts_supported
-from ..util import is_hip_extension
+from ..util import is_hip_extension, get_jnp_float8_e4m3_type, get_jnp_float8_e5m2_type


 __all__ = [
@@ -1038,7 +1038,7 @@ def get_compatible_q_dtypes(self) -> set[jnp.dtype]:
             ScalingMode.CURRENT_TENSOR_SCALING,
             ScalingMode.MXFP8_1D_SCALING,
         ):
-            return {jnp.float8_e5m2, jnp.float8_e4m3fn}
+            return {get_jnp_float8_e5m2_type(), get_jnp_float8_e4m3_type()}
         if self in (ScalingMode.NVFP4_1D_SCALING, ScalingMode.NVFP4_2D_SCALING):
             return {jnp.float4_e2m1fn}
         if self == ScalingMode.NO_SCALING:
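The hard-coded OCP FP8 dtypes are replaced by helpers imported from ..util; their bodies are not part of this commit. A plausible sketch, assuming they simply switch between the FNUZ FP8 encodings used by ROCm MI300-class devices and the OCP encodings used elsewhere (hypothetical implementation, for illustration only):

    # Hypothetical sketch -- the real helpers live in transformer_engine/jax/util.py
    # and are not shown in this diff.
    import jax.numpy as jnp
    from transformer_engine.jax.util import is_hip_extension

    def get_jnp_float8_e4m3_type():
        # ROCm MI300-class GPUs use the FNUZ FP8 formats; CUDA uses the OCP ones.
        return jnp.float8_e4m3fnuz if is_hip_extension() else jnp.float8_e4m3fn

    def get_jnp_float8_e5m2_type():
        return jnp.float8_e5m2fnuz if is_hip_extension() else jnp.float8_e5m2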

transformer_engine/pytorch/cpp_extensions/fused_attn.py

Lines changed: 0 additions & 1 deletion
@@ -266,7 +266,6 @@ def fused_attn_fwd(

     if IS_HIP_EXTENSION:
         assert not return_max_logit, "ROCm does not support return_max_logit yet."
-        assert not cuda_graph, "ROCm does not support cuda_graph."

     if attn_scale is None:
         d = q.size(-1)

transformer_engine/pytorch/quantization.py

Lines changed: 1 addition & 1 deletion
@@ -623,7 +623,7 @@ def autocast_exit(cls, enabled: bool, _graph: bool) -> None:
         # Reduce only the non-FP8 weight modules here.
         # FP8 weight modules are reduced at the end of the optimizer
         # step after the weight amax is populated.
-        if enabled and cls.AUTOCAST_DEPTH == 0 and not _graph and torch.is_grad_enabled():
+        if not cls.SKIP_FP8_REDUCTION_FOR_FSDP2 and enabled and cls.AUTOCAST_DEPTH == 0 and not _graph and torch.is_grad_enabled():
             # delayed scaling only function, for other recipes (current scaling with any granularity),
             # this is noop for other recipes because cls.global_amax_buffer is empty list
             cls.reduce_and_update_fp8_tensors(forward=True)
