Skip to content

Commit 470f153

Browse files
committed
Cleanup and RS flag race condition fix
1 parent 36eebf2 commit 470f153

12 files changed

Lines changed: 178 additions & 219 deletions

File tree

build_tools/pytorch.py

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,14 @@ def setup_pytorch_extension(
8585
if version < (12, 0):
8686
raise RuntimeError("Transformer Engine requires CUDA 12.0 or newer")
8787

88+
if bool(int(os.getenv("NVTE_UB_WITH_MPI", "0"))):
89+
assert (
90+
os.getenv("MPI_HOME") is not None
91+
), "MPI_HOME=/path/to/mpi must be set when compiling with NVTE_UB_WITH_MPI=1!"
92+
mpi_path = Path(os.getenv("MPI_HOME"))
93+
include_dirs.append(mpi_path / "include")
94+
cxx_flags.append("-DNVTE_UB_WITH_MPI")
95+
8896
library_dirs = []
8997
libraries = []
9098
if bool(int(os.getenv("NVTE_ENABLE_NVSHMEM", 0))):
@@ -104,17 +112,6 @@ def setup_pytorch_extension(
104112
libraries.append("mpi")
105113
cxx_flags.extend(["-DNVTE_ENABLE_ROCSHMEM", "-DOMPI_SKIP_MPICXX"])
106114

107-
extra_link_args = []
108-
if bool(int(os.getenv("NVTE_UB_WITH_MPI", "0"))):
109-
assert (
110-
os.getenv("MPI_HOME") is not None
111-
), "MPI_HOME=/path/to/mpi must be set when compiling with NVTE_UB_WITH_MPI=1!"
112-
mpi_path = Path(os.getenv("MPI_HOME", "/usr/lib/x86_64-linux-gnu/openmpi"))
113-
include_dirs.append(mpi_path / "include")
114-
library_dirs.append(mpi_path / "lib")
115-
libraries.append("mpi")
116-
cxx_flags.extend(["-DNVTE_UB_WITH_MPI", "-DOMPI_SKIP_MPICXX"])
117-
118115
# Construct PyTorch CUDA extension
119116
sources = [str(path) for path in sources]
120117
include_dirs = [str(path) for path in include_dirs]
@@ -127,5 +124,4 @@ def setup_pytorch_extension(
127124
extra_compile_args={"cxx": cxx_flags},
128125
libraries=[str(lib) for lib in libraries],
129126
library_dirs=[str(lib_dir) for lib_dir in library_dirs],
130-
extra_link_args=[str(arg) for arg in extra_link_args],
131127
)

ci/pytorch.sh

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,6 @@ run_test_config_mgpu(){
9292
#run in parallel on CI and it affects timing
9393
run_default_fa 1 test_gemm_sm_count.py
9494
run_default_fa 3 test_sanity_import.py
95-
run_default_fa 3 distributed/test_fusible_ops_with_userbuffers.py
9695
run_default_fa 3 distributed/test_comm_gemm_overlap.py
9796
run_default_fa 2 distributed/test_fusible_ops.py
9897
run_default_fa 2 distributed/test_numerics.py

tests/pytorch/distributed/run_layer_with_overlap.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -28,14 +28,17 @@
2828
MXFP8BlockScaling,
2929
)
3030

31+
from torch.utils.cpp_extension import IS_HIP_EXTENSION
32+
3133
warnings.filterwarnings("ignore", category=DeprecationWarning)
3234
warnings.filterwarnings("ignore", category=FutureWarning)
3335
warnings.filterwarnings("ignore", category=UserWarning)
3436

35-
import transformer_engine.pytorch.cpp_extensions as tex
36-
os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"
37-
if not tex.device_supports_multicast():
38-
os.environ["UB_SKIPMC"] = "1"
37+
if IS_HIP_EXTENSION:
38+
import transformer_engine.pytorch.cpp_extensions as tex
39+
os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"
40+
if not tex.device_supports_multicast():
41+
os.environ["UB_SKIPMC"] = "1"
3942

4043

4144
class multi_module_model(torch.nn.Module):
@@ -118,6 +121,7 @@ def _get_layer_args(config, tp_group, tp_size, num_layers, reference=False):
118121
kwargs["input_layernorm"] = True
119122
else:
120123
kwargs["ub_tp_comm_overlap"] = not reference
124+
# Disable forward pass overlaps on HIP to isolate backward RS overlap
121125
kwargs["hidden_dropout"] = 0.0
122126
kwargs["set_parallel_mode"] = True
123127
kwargs["ub_overlap_rs_dgrad"] = config.overlap_rs_dgrad and not reference
@@ -557,8 +561,8 @@ def run_fwd_bwd(model, x):
557561
# Now validate accuracy
558562
if not bool(numerics_failed.item()):
559563
for i, (test_g, ref_g) in enumerate(zip(test_grads, ref_grads)):
560-
rtol = 0.125 if opts.fp8 else 0.025
561-
atol = 0.0625 if opts.fp8 else 0.00125
564+
rtol = 0.125 if opts.fp8 else 0.025 if not IS_HIP_EXTENSION else 5e-2
565+
atol = 0.0625 if opts.fp8 else 0.00125 if not IS_HIP_EXTENSION else 1e-2
562566
grad_failed, grad_info = _compare_tensors(names[i], test_g, ref_g, rtol, atol)
563567
dist_print(grad_info, src=WORLD_RANK, error=grad_failed)
564568
numerics_failed[0] = int(grad_failed)

tests/pytorch/distributed/test_comm_gemm_overlap.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ def _run_gemm_with_overlap(comm_type, bulk, p2p, atomic, aggregate, quantization
7373
test_cmd.append("--bulk-overlap")
7474
else:
7575
if IS_HIP_EXTENSION and not p2p:
76-
pytest.skip("HIP only supports A2A operations.")
76+
pytest.skip("HIP only supports P2P operations.")
7777
if quantization == "fp8" and not fp8_available:
7878
pytest.skip(reason_for_no_fp8)
7979
if quantization == "mxfp8" and not mxfp8_available:
@@ -100,6 +100,9 @@ def _run_gemm_with_overlap(comm_type, bulk, p2p, atomic, aggregate, quantization
100100
def _run_layer_with_overlap(
101101
layer_type, linear_parallel_mode, overlap_rs_dgrad, fp8, quantization, num_layers=1
102102
):
103+
# Skip BULK overlap tests on HIP (column parallel or None with overlap_rs_dgrad=False)
104+
if IS_HIP_EXTENSION and not overlap_rs_dgrad and linear_parallel_mode in ("column", None):
105+
pytest.skip("Bulk overlap is not yet supported on HIP/ROCm.")
103106
test_path = TEST_ROOT / "run_layer_with_overlap.py"
104107
test_cmd = LAUNCH_CMD + [
105108
str(test_path),
@@ -163,6 +166,7 @@ def test_split_reduce_scatter_overlaps(quantization, p2p):
163166
_run_gemm_with_overlap("RS", False, p2p, False, False, quantization)
164167

165168

169+
@pytest.mark.skipif(IS_HIP_EXTENSION, reason="Bulk overlap is not yet supported on ROCm.")
166170
@pytest.mark.parametrize(
167171
"comm_type, quantization, connections",
168172
[
@@ -192,8 +196,6 @@ def test_bulk_overlaps(comm_type, quantization, connections):
192196
"CUDA_DEVICE_MAX_CONNECTIONS=8 test only applies to devices with compute capability"
193197
" 9.0 (HOPPER ARCH)."
194198
)
195-
if IS_HIP_EXTENSION:
196-
pytest.skip("HIP Does not support bulk overlaps with 8 connections.")
197199
os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "8"
198200
_run_gemm_with_overlap(comm_type, True, False, False, False, quantization)
199201
os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"
@@ -228,7 +230,7 @@ def test_bulk_overlaps(comm_type, quantization, connections):
228230
ids=[
229231
f" {te.Linear.__name__} - ROW-PARALLEL ",
230232
f" {te.Linear.__name__} - COL-PARALLEL - BULK DGRAD/WGRAD ",
231-
f" {te.Linear.__name__} - COL-PARLALEL - DGRAD+RS ",
233+
f" {te.Linear.__name__} - COL-PARALLEL - DGRAD+RS ",
232234
f" {te.LayerNormLinear.__name__} - ROW-PARALLEL ",
233235
f" {te.LayerNormLinear.__name__} - COL-PARALLEL - BULK DGRAD/WGRAD ",
234236
f" {te.LayerNormLinear.__name__} - COL-PARALLEL - DGRAD+RS ",

transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,7 @@ TensorWrapper CommOverlapCore::get_tensor_chunk(const TensorWrapper &source, siz
206206
NVTE_DIM_CHECK(chunk_height > 0 && chunk_width > 0, "Attempted to get empty tensor chunk");
207207
NVTE_DIM_CHECK(chunk_height <= height && chunk_width <= width,
208208
"Attempted to get out-of-bounds tensor chunk");
209+
#ifndef __HIP_PLATFORM_AMD__
209210
if (scaling_mode == NVTEScalingMode::NVTE_MXFP8_1D_SCALING) {
210211
// MXFP8 scale-inverses are padded to a 2D matrix with dims that
211212
// are divisible by 128. UB doesn't handle this padding yet.
@@ -214,6 +215,7 @@ TensorWrapper CommOverlapCore::get_tensor_chunk(const TensorWrapper &source, siz
214215
NVTE_DIM_CHECK(chunk_height % 128 == 0 && chunk_width % 128 == 0,
215216
"Userbuffers requires MXFP8 tensor chunk dims that are divisible by 128");
216217
}
218+
#endif
217219
#undef NVTE_DIM_CHECK
218220

219221
// Construct tensor chunk
@@ -726,12 +728,12 @@ void CommOverlapP2PBase::initialize(const std::vector<size_t> &buffer_shape, DTy
726728
NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(&stream, cudaStreamNonBlocking, _comm_priority));
727729
_stream_send.push_back(std::move(stream));
728730
}
729-
for (int i = 0; i < 7; i++) {
731+
for (int i = 0; i < NVTE_ROCM_MAX_RINGS; i++) {
730732
cudaStream_t stream;
731733
NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(&stream, cudaStreamNonBlocking, _comm_priority));
732734
l_stream_send.push_back(std::move(stream));
733735
}
734-
for (int i = 0; i < 7; i++) {
736+
for (int i = 0; i < NVTE_ROCM_MAX_RINGS; i++) {
735737
cudaStream_t stream;
736738
NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(&stream, cudaStreamNonBlocking, _comm_priority));
737739
l_stream_recv.push_back(std::move(stream));
@@ -740,7 +742,7 @@ void CommOverlapP2PBase::initialize(const std::vector<size_t> &buffer_shape, DTy
740742
cudaStreamCreateWithPriority(&_stream_recv, cudaStreamNonBlocking, _comm_priority));
741743
NVTE_CHECK_CUDA(cudaEventCreateWithFlags(&_stop_send, 0));
742744
NVTE_CHECK_CUDA(cudaEventCreateWithFlags(&_stop_recv, 0));
743-
for (int i = 0; i < 7; i++) {
745+
for (int i = 0; i < NVTE_ROCM_MAX_RINGS; i++) {
744746
NVTE_CHECK_CUDA(cudaEventCreateWithFlags(&l_stop_recv[i], 0));
745747
}
746748
}
@@ -752,7 +754,7 @@ CommOverlapP2PBase::~CommOverlapP2PBase() {
752754
for (size_t i = 0; i < _stream_send.size(); i++) {
753755
cudaStreamDestroy(_stream_send[i]);
754756
}
755-
for (int i = 0; i < 7; i++) {
757+
for (int i = 0; i < NVTE_ROCM_MAX_RINGS; i++) {
756758
cudaStreamDestroy(l_stream_recv[i]);
757759
cudaStreamDestroy(l_stream_send[i]);
758760
cudaEventDestroy(l_stop_recv[i]);

0 commit comments

Comments
 (0)