Skip to content

Commit db56b8f

Browse files
authored
Update jax gemm.py
1 parent 316dffb commit db56b8f

1 file changed

Lines changed: 7 additions & 4 deletions

File tree

  • transformer_engine/jax/cpp_extensions

transformer_engine/jax/cpp_extensions/gemm.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,19 +20,22 @@
2020
from jax.sharding import NamedSharding, PartitionSpec
2121
from jax.experimental.custom_partitioning import SdyShardingRule
2222

23+
from ..util import is_hip_extension, get_jnp_float8_e4m3_type, get_jnp_float8_e5m2_type
24+
2325
from transformer_engine_jax import (
2426
get_num_compute_streams,
2527
JAXX_Collective_Op,
2628
get_device_compute_capability,
27-
#initialize_cgemm_communicator,
28-
#get_cgemm_num_max_streams,
2929
)
30+
if not is_hip_extension():
31+
from transformer_engine_jax import (
32+
initialize_cgemm_communicator,
33+
get_cgemm_num_max_streams,
34+
)
3035

3136
from .base import BasePrimitive, register_primitive
3237
from .quantization import grouped_quantize
3338

34-
from ..util import is_hip_extension, get_jnp_float8_e4m3_type, get_jnp_float8_e5m2_type
35-
3639
from ..quantize import (
3740
AbstractBaseTensor,
3841
NoScaleTensor,

0 commit comments

Comments
 (0)