File tree Expand file tree Collapse file tree
transformer_engine/jax/cpp_extensions Expand file tree Collapse file tree Original file line number Diff line number Diff line change 2020from jax .sharding import NamedSharding , PartitionSpec
2121from jax .experimental .custom_partitioning import SdyShardingRule
2222
23+ from ..util import is_hip_extension , get_jnp_float8_e4m3_type , get_jnp_float8_e5m2_type
24+
2325from transformer_engine_jax import (
2426 get_num_compute_streams ,
2527 JAXX_Collective_Op ,
2628 get_device_compute_capability ,
27- #initialize_cgemm_communicator,
28- #get_cgemm_num_max_streams,
2929)
30+ if not is_hip_extension ():
31+ from transformer_engine_jax import (
32+ initialize_cgemm_communicator ,
33+ get_cgemm_num_max_streams ,
34+ )
3035
3136from .base import BasePrimitive , register_primitive
3237from .quantization import grouped_quantize
3338
34- from ..util import is_hip_extension , get_jnp_float8_e4m3_type , get_jnp_float8_e5m2_type
35-
3639from ..quantize import (
3740 AbstractBaseTensor ,
3841 NoScaleTensor ,
You can’t perform that action at this time.
0 commit comments