Userbuffer epic#367

Open
alextmagro wants to merge 4 commits into dev from userbuffer_epic

Conversation

@alextmagro
Contributor

This is the userbuffer_epic branch, to be merged only once all epic tasks have been completed. PRs for epic tasks will be opened against this branch.

  parser.add_argument("--seed", type=int, default=1234, help="RNG seed.")
  parser.add_argument(
-     "--fp8", action="store_true", default=False, help="Enables the te.fp8_autocast() context."
+     "--fp8", action="store_true", default=False, help="Enables the te.autocast() context."
Collaborator

Up to TE v2.8, I think it's still fp8_autocast. Were you targeting higher versions?

Contributor Author

I think you had a few comments on this, so I will address them here quickly. I moved the UB code up to release 2.10, since there were a few bugs and inefficiencies that NV fixed. Most of the changes that aren't guarded in the files are NV upstream changes.

I am fixing up the te_layer_with_overlap differences, and working on integrating the benchmark script into the file directly.
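Since the branch targets TE 2.10 while older releases still ship fp8_autocast, a version-tolerant lookup can bridge the rename while both names are in circulation. This is a hypothetical sketch (resolve_api is not a TE helper, and the demo module is a stand-in); it assumes only that the old and new attribute names differ:

```python
from types import SimpleNamespace

def resolve_api(module, *candidates):
    """Return the first attribute of `module` that exists, trying the newest name first."""
    for name in candidates:
        fn = getattr(module, name, None)
        if fn is not None:
            return fn
    raise AttributeError(f"none of {candidates} found on {module!r}")

# Stand-in demo: an "older" module object that only exposes fp8_autocast.
old_te = SimpleNamespace(fp8_autocast=lambda: "fp8_autocast")
assert resolve_api(old_te, "autocast", "fp8_autocast")() == "fp8_autocast"
```

In real use this would be something like `te_autocast = resolve_api(transformer_engine.pytorch, "autocast", "fp8_autocast")`, keeping the benchmark script importable on both sides of the rename.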


# This file was modified for portability to AMDGPU
# Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved.
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
Collaborator

Was this file sharing a lot of code with examples/pytorch/comm_gemm_overlap/te_layer_with_overlap.py? Is it possible to consolidate the two files?

import transformer_engine.pytorch.cpp_extensions as tex
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager

from transformer_engine.jax.cpp_extensions.misc import is_hip_extension
Collaborator

Let's not import jax specific code into pytorch side. Use this instead:

from torch.utils.cpp_extension import IS_HIP_EXTENSION

Contributor Author

Good catch, this is a mistake. Will fix.
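For reference, torch.utils.cpp_extension really does export IS_HIP_EXTENSION, so the suggested import needs no JAX dependency. The try/except fallback below is an extra hedge of mine (so the module still imports where torch is absent), not part of the suggestion:

```python
# Prefer torch's build-time flag; fall back when torch is unavailable.
try:
    from torch.utils.cpp_extension import IS_HIP_EXTENSION
except ImportError:
    # Assumption: default to a CUDA (non-HIP) build when torch is missing.
    IS_HIP_EXTENSION = False
```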

- if (_ub_comm->myrank == 0) printf("!!! [UB] Register UBuf %d\n", _ub_reg);
+ if (_ub_comm->myrank == 0) {
+   printf("!!! [UB] Register UBuf %d\n", _ub_reg);
+ }
Collaborator

I would prefer aligning the coding style with NV upstream so it's easier for us to maintain/IFU later

allgather_handle, barrier_handle, tp_size, num_max_streams, comm_cga_size,
gemm_priority, comm_priority, num_comm_sm, set_sm_margin, use_ce,
atomic_gemm) {
initialize(buffer_shape, buffer_dtype, comm_type, aggregate);
Collaborator

Same question here about the motivation for this initialize function in the constructor.

NVTE_CHECK_CUDA(cudaMemset((*comm)->flags_baseptr, 0, 2 * GPU_PAGE_SIZE));
(*comm)->flags = reinterpret_cast<int *>(
#ifdef __HIP_PLATFORM_AMD__
(reinterpret_cast<uintptr_t>((*comm)->flags) + GPU_PAGE_SIZE - 1) & GPU_PAGE_MASK);
Collaborator

Should it be (*comm)->flags_baseptr, as in the NV upstream below? (*comm)->flags is not allocated/assigned above.
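For context, the expression rounds an address up to the next GPU page boundary, which is why it must be applied to the freshly memset flags_baseptr rather than the still-unassigned flags. A Python model of the arithmetic (the 64 KiB page size is an assumption; the real code uses its own GPU_PAGE_SIZE/GPU_PAGE_MASK constants):

```python
GPU_PAGE_SHIFT = 16                      # assumption: 64 KiB GPU page
GPU_PAGE_SIZE = 1 << GPU_PAGE_SHIFT
GPU_PAGE_MASK = ~(GPU_PAGE_SIZE - 1)

def round_up_to_page(ptr):
    # Mirrors (ptr + GPU_PAGE_SIZE - 1) & GPU_PAGE_MASK from the C++ snippet:
    # addresses already on a boundary are unchanged, all others move up.
    return (ptr + GPU_PAGE_SIZE - 1) & GPU_PAGE_MASK

assert round_up_to_page(0) == 0
assert round_up_to_page(1) == GPU_PAGE_SIZE
assert round_up_to_page(GPU_PAGE_SIZE) == GPU_PAGE_SIZE
```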


  __syncthreads();
- if (threadIdx.x == 0) __threadfence_system();
+ if (threadIdx.x == 0) __threadfence();

  void userbuffers_send(const int srchandler, const size_t srcoffset, const int dsthandler,
                        const size_t dstoffset, const size_t bytes, communicator *comm,
-                       const int peer, cudaStream_t stream) {
+                       const int peer, cudaStream_t stream, int ring_id) {
Collaborator

Hmm, I guess my question then would be: why does NV upstream not need a ring_id? Is it because we have a different implementation, i.e. the NVTE_ROCM_MAX_RINGS rings?

    _comm_priority = comm_priority;
  }
- for (int i = 0; i < std::min(num_max_streams, num_splits); i++) {
+ for (int i = 0; i < std::max(num_max_streams, num_splits); i++) {
Collaborator

In fact, do we need more streams than the min of num_max_streams and num_splits?
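To illustrate the point: with round-robin assignment, min(num_max_streams, num_splits) streams already give every chunk a stream, so a max would allocate streams that either exceed the cap or never get a chunk. A sketch (the round-robin policy is my assumption about how chunks map to streams here):

```python
def chunk_stream_ids(num_max_streams, num_splits):
    """Round-robin num_splits chunks over at most min(num_max_streams, num_splits) streams."""
    num_streams = min(num_max_streams, num_splits)
    return [chunk % num_streams for chunk in range(num_splits)]

# 8 chunks over a 3-stream pool: every stream is reused, none sits idle.
assert chunk_stream_ids(3, 8) == [0, 1, 2, 0, 1, 2, 0, 1]
# 2 chunks never need more than 2 streams, regardless of the cap.
assert chunk_stream_ids(8, 2) == [0, 1]
```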

NVTE_DIM_CHECK(chunk_height > 0 && chunk_width > 0, "Attempted to get empty tensor chunk");
NVTE_DIM_CHECK(chunk_height <= height && chunk_width <= width,
"Attempted to get out-of-bounds tensor chunk");
#ifndef __HIP_PLATFORM_AMD__
Collaborator

Since we already support MXFP8, add a TODO comment so that we won't forget to turn it on later.


// Input data
const size_t source_size = source.numel();
const void *src_ptr = (rowwise) ? source.dptr() : source.columnwise_dptr();
Collaborator

Well, what if we need both row-wise and column-wise? How about the other fields of a tensor, for example scale_inv?
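The concern above is that selecting rowwise XOR columnwise data drops buffers a quantized tensor may also carry. A hypothetical Python stand-in for the C++ tensor view (field names are illustrative, not TE's actual API) makes the alternative concrete:

```python
from dataclasses import dataclass, fields
from typing import Optional

@dataclass
class QuantizedTensorView:
    # Hypothetical mirror of the quantized-tensor buffers discussed above.
    rowwise_data: Optional[bytes] = None
    columnwise_data: Optional[bytes] = None
    rowwise_scale_inv: Optional[bytes] = None
    columnwise_scale_inv: Optional[bytes] = None

def buffers_to_copy(t: QuantizedTensorView):
    """Collect every populated buffer instead of picking rowwise XOR columnwise."""
    return [f.name for f in fields(t) if getattr(t, f.name) is not None]

t = QuantizedTensorView(rowwise_data=b"d", rowwise_scale_inv=b"s")
assert buffers_to_copy(t) == ["rowwise_data", "rowwise_scale_inv"]
```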

  "num_sm": 1 if method == "ring_exchange" else 16,
  "cga_size": 1 if method == "ring_exchange" else 2,
- "set_sm_margin": not method == "ring_exchange",
+ "set_sm_margin": not method == "ring_exchange" and not IS_HIP_EXTENSION,
Collaborator

Ilya already had the sm_margin feature supported on ROCm.

if IS_HIP_EXTENSION and user_ub_cfg is not None:
for name, cfg in user_ub_cfg.items():
assert cfg.get("method") != "bulk", (
f"Bulk overlap method for '{name}' is not supported on HIP/ROCm. "
Collaborator

I recall we supported bulk overlap, but the performance was not great?

  "<nvtx3/nvToolsExt.h>" : "<roctracer/roctx.h>",
- "cudaFuncSetAttribute(" : "hipFuncSetAttribute((const void*)"
+ "cudaFuncSetAttribute(" : "hipFuncSetAttribute((const void*)",
+ "cudaLaunchKernel": "hipLaunchKernel",
Collaborator

cudaLaunchKernel cannot be hipified?
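For context on the question: the diff extends a literal substitution table, which behaves like a plain string replace over the source. A minimal sketch of the mechanism (this excerpt and the hipify function are illustrative; whether the standard hipify pass already covers cudaLaunchKernel is exactly the open question above):

```python
# Hypothetical excerpt of the custom substitution table; the real table in the
# repo is larger, and ordering can matter when keys overlap as substrings.
CUSTOM_MAP = {
    "<nvtx3/nvToolsExt.h>": "<roctracer/roctx.h>",
    "cudaFuncSetAttribute(": "hipFuncSetAttribute((const void*)",
    "cudaLaunchKernel": "hipLaunchKernel",
}

def hipify(source: str) -> str:
    """Apply each CUDA->HIP substitution in table order."""
    for cuda_sym, hip_sym in CUSTOM_MAP.items():
        source = source.replace(cuda_sym, hip_sym)
    return source

assert hipify("cudaLaunchKernel(kern, grid, blk, args, 0, s);") == \
       "hipLaunchKernel(kern, grid, blk, args, 0, s);"
```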
