@@ -1450,22 +1450,24 @@ def ref_func(x, w, data_layout):
         assert_allclose(primitive_w_grad, ref_w_grad, dtype=jnp.bfloat16)

     @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
-    @pytest_parametrize_wrapper("m,n,k", TEST_SHAPES)
-    @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
+    @pytest_parametrize_wrapper("m,n,k", [(64, 128, 128), (128, 256, 256)])
+    @pytest_parametrize_wrapper("recipe", supported_recipes)
     @pytest_parametrize_wrapper("with_jax_gemm", [False, True])
     @pytest_parametrize_wrapper("use_bias", [False, True] if is_hip_extension() else [True])
-    def test_dense_grad_fp8(self, m, n, k, scaling_mode, with_jax_gemm, use_bias):
+    def test_dense_grad_fp8_and_fp4(self, m, n, k, recipe, with_jax_gemm, use_bias):
         data_layout = "NN"
         x, w, contracting_dims = self._generate_gemm_input(m, n, k, data_layout)

         key = jax.random.PRNGKey(1)
         bias = jax.random.uniform(key, n, dtype=jnp.bfloat16) if use_bias else None

-        if scaling_mode.is_1d_block_scaling():
+        if recipe.__class__.__name__ == "MXFP8BlockScaling":
             # Check for first GEMM
             _check_mxfp8_gemm_support(with_jax_gemm, m, n, k, use_bias)
             # Check for second GEMM
             _check_mxfp8_gemm_support(with_jax_gemm, m, k, n, use_bias)
+            # Check for third GEMM
+            _check_mxfp8_gemm_support(with_jax_gemm, k, n, m, use_bias)


         def primitive_func(x, w, bias, contracting_dims, quantizer_set):
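The rename to test_dense_grad_fp8_and_fp4 tracks the switch from scaling-mode enums to recipe objects, which presumably now span FP8 and FP4 recipes, and the added third call covers the wgrad GEMM that the previous two checks missed. A minimal sketch of what a helper like _check_mxfp8_gemm_support might do, assuming MXFP8's 32-wide scaling-block granularity; the real helper is defined elsewhere in this file and may enforce more than divisibility:

import pytest

MXFP8_BLOCK_SIZE = 32  # assumed 1x32 scaling-block width; illustrative only

def _check_mxfp8_gemm_support_sketch(with_jax_gemm, m, n, k, use_bias=False):
    # Skip (rather than fail) configurations whose GEMM dims cannot be tiled
    # into whole MXFP8 scaling blocks. The real helper presumably also keys
    # off with_jax_gemm and use_bias.
    if any(dim % MXFP8_BLOCK_SIZE != 0 for dim in (m, n, k)):
        pytest.skip(f"MXFP8 requires dims divisible by {MXFP8_BLOCK_SIZE}, got {(m, n, k)}")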
@@ -1530,19 +1532,21 @@ def _ref_jax_norm_impl(x, gamma, beta, norm_type, zero_centered_gamma, eps, quan

 class TestFusedDense:
     @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
-    @pytest.mark.parametrize("m,n,k", [(64, 128, 128)])
+    @pytest.mark.parametrize("m,n,k", [(64, 128, 128), (128, 256, 256)])
     @pytest_parametrize_wrapper("recipe", supported_recipes)
     @pytest.mark.parametrize("norm_type", ["layernorm", "rmsnorm"])
     @pytest_parametrize_wrapper("with_jax_gemm", [False, True])
     def test_layernorm_dense_grad(self, m, n, k, recipe, norm_type, with_jax_gemm):
         """
         Test layernorm_dense VJP Rule
         """
-        if scaling_mode.is_1d_block_scaling():
+        if recipe.__class__.__name__ == "MXFP8BlockScaling":
             # Check for fwd GEMM
             _check_mxfp8_gemm_support(with_jax_gemm, m, n, k)
-            # Check for bwd GEMM
+            # Check for first bwd GEMM
             _check_mxfp8_gemm_support(with_jax_gemm, m, k, n)
+            # Check for second bwd GEMM
+            _check_mxfp8_gemm_support(with_jax_gemm, k, n, m)
         # zero_centered_gamma is already tested in TestNorm
         zero_centered_gamma = False
         eps = 1e-6
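The three (m, n, k) permutations checked above correspond to the three GEMMs in the dense VJP. A worked shape sketch, assuming the NN data layout these tests use (x: (m, k), w: (k, n)) and that the helper's arguments mean (output rows, output cols, contracted dim), which matches the fwd call:

import jax.numpy as jnp

m, n, k = 64, 128, 128
x = jnp.zeros((m, k), jnp.bfloat16)   # activations
w = jnp.zeros((k, n), jnp.bfloat16)   # weights
dy = jnp.zeros((m, n), jnp.bfloat16)  # incoming output gradient

y = x @ w       # fwd GEMM:   out (m, n), contracts k -> check (m, n, k)
dx = dy @ w.T   # dgrad GEMM: out (m, k), contracts n -> check (m, k, n)
dw = x.T @ dy   # wgrad GEMM: out (k, n), contracts m -> check (k, n, m)
assert y.shape == (m, n) and dx.shape == (m, k) and dw.shape == (k, n)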
@@ -1614,7 +1618,7 @@ def ref_func(x, w, gamma, beta):
         assert_allclose(prim_beta_grad, ref_beta_grad, dtype=quantizer_set.dgrad.q_dtype)

     @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
-    @pytest.mark.parametrize("m,n,k", [(64, 128, 128)])
+    @pytest.mark.parametrize("m,n,k", [(64, 128, 128), (128, 256, 256)])
     @pytest.mark.parametrize("activation_type", [("gelu",), ("gelu", "linear")])
     @pytest_parametrize_wrapper("recipe", supported_recipes)
     @pytest.mark.parametrize("norm_type", ["layernorm", "rmsnorm"])
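On the activation_type values parametrized above: a single-element tuple like ("gelu",) reads as a plain activation, while a pair like ("gelu", "linear") conventionally denotes a gated unit (GeGLU-style) whose two branches come from a doubled FFN projection. A rough sketch of the gated case; this is an assumed reading of the tuple convention, not this file's reference implementation:

import jax
import jax.numpy as jnp

def gated_gelu_sketch(z):
    # z holds both branches along the last axis: activate one, keep the
    # other linear, and combine by elementwise product.
    a, b = jnp.split(z, 2, axis=-1)
    return jax.nn.gelu(a) * b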
@@ -1626,11 +1630,13 @@ def test_layernorm_mlp_grad(
         """
         Test layernorm_mlp VJP Rule
         """
-        if scaling_mode.is_1d_block_scaling():
+        if recipe.__class__.__name__ == "MXFP8BlockScaling":
             # Check for first GEMM
             _check_mxfp8_gemm_support(with_jax_gemm, m, n, k, use_bias)
             # Check for second GEMM
             _check_mxfp8_gemm_support(with_jax_gemm, m, k, n, use_bias)
+            # Check for third GEMM
+            _check_mxfp8_gemm_support(with_jax_gemm, k, n, m, use_bias)

         # zero_centered_gamma is already tested in TestNorm
         zero_centered_gamma = False
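A note on the dispatch idiom recipe.__class__.__name__ == "MXFP8BlockScaling" used throughout: matching on the class name rather than isinstance means the check works without importing the recipe class at each call site. A hedged sketch of how supported_recipes might be assembled, with recipe class names assumed from transformer_engine.common.recipe and the hardware filtering purely illustrative:

from transformer_engine.common import recipe as te_recipe

# Illustrative only: the real supported_recipes in this file is presumably
# filtered by what the local GPU supports (FP8 delayed scaling, MXFP8, etc.).
supported_recipes_sketch = [te_recipe.DelayedScaling(), te_recipe.MXFP8BlockScaling()]

def uses_mxfp8(r):
    # Class-name match mirrors the tests above; isinstance would also work
    # when the class is imported directly.
    return r.__class__.__name__ == "MXFP8BlockScaling"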