[Common] Tuned NVFP4 cast kernel #2412
Conversation
Greptile Summary
This PR adds a specialized NVFP4 quantization kernel for Blackwell GPUs, achieving 6.4 TB/s for round-to-nearest and 4.5 TB/s for stochastic rounding. The implementation leverages TMA, mbarrier, and cluster launch control instructions.
Key Changes
Issues Found
Confidence Score: 3/5
Important Files Changed
Sequence Diagram
```mermaid
sequenceDiagram
participant User
participant Dispatcher as quantize_transpose()
participant TunedKernel as quantize_transpose_tuned_1D()
participant CUDAKernel as quantize_transpose_nvfp4_tuned_1D_kernel
participant PTX as PTX Instructions
participant GPU
User->>Dispatcher: Call with BF16 input, 1D quantization
Dispatcher->>Dispatcher: Check !use_2d_quantization && BF16
Note over Dispatcher: ⚠️ Missing is_supported_by_CC_100() check
Dispatcher->>TunedKernel: quantize_transpose_tuned_1D()
TunedKernel->>TunedKernel: Validate tensors (rows%32==0, cols%32==0)
TunedKernel->>TunedKernel: Create TMA tensor maps
TunedKernel->>TunedKernel: Configure shared memory
TunedKernel->>CUDAKernel: Launch kernel<<<grid, block, shmem>>>
alt Device is sm_100+
CUDAKernel->>PTX: mbarrier_init (TMA barriers)
CUDAKernel->>PTX: cp_async_bulk_tensor_2d (TMA prefetch)
loop For each stage
CUDAKernel->>PTX: mbarrier_wait (wait for data)
CUDAKernel->>CUDAKernel: rowwise_scaling (find amax, quantize)
CUDAKernel->>PTX: mul_cvt_bf16_to_fp4_8x (BF16→FP4)
opt If transpose requested
CUDAKernel->>CUDAKernel: colwise_scaling
end
CUDAKernel->>PTX: cp_async_bulk_tensor_2d (write results)
CUDAKernel->>PTX: cp_async_bulk_tensor_2d (write transpose)
end
CUDAKernel->>GPU: Store scales to global memory
GPU-->>User: Success
else Device is pre-sm_100
CUDAKernel->>CUDAKernel: NVTE_DEVICE_ERROR
GPU-->>User: Runtime error
end
```
Additional Comments (1)
- transformer_engine/common/cast/dispatch/quantize.cuh, line 250: logic: missing is_supported_by_CC_100() check. The forward pass (line 104) has this check, but the backward pass doesn't. Without it, the persistent kernel can be dispatched on non-Blackwell hardware and hit NVTE_DEVICE_ERROR at runtime.
5 files reviewed, 1 comment
Additional Comments (1)
- transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh, lines 1171-1174: logic: Missing runtime architecture check before dispatching to the persistent kernel. The tuned 1D kernel uses Blackwell-specific PTX instructions (mbarrier variants, cluster launch control, etc.) and requires sm_100+. Should be guarded with an is_supported_by_CC_100() check, similar to other TMA kernel dispatches in the codebase (see quantize_fp8.cuh:535, gated.cuh:48).
4 files reviewed, 1 comment
Additional Comments (1)
- transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh, lines 1171-1174: logic: Missing runtime check for sm_100+ before dispatching to the tuned kernel. The tuned kernel uses Blackwell-specific features (TMA, cluster launch control) that will fail on non-sm_100 devices.
5 files reviewed, 1 comment
Additional Comments (1)
- tests/cpp/operator/test_cast_nvfp4_transpose.cu, lines 125-126: syntax: typo in comment: 'Numercial' should be 'Numerical'.
5 files reviewed, 1 comment
Additional Comments (1)
- transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh, lines 1171-1173: logic: Add a runtime architecture check before dispatching to the tuned kernel. The tuned kernel requires sm_100+, but this dispatch lacks the is_supported_by_CC_100() check that other sm_100+ kernels use (see cast/dispatch/gated.cuh:48, cast/fp8/quantize_fp8.cuh:535). Without this check, the kernel will launch on non-Blackwell devices and hit NVTE_DEVICE_ERROR at runtime.
5 files reviewed, 1 comment
/te-ci
/te-ci
| - const double atol = 0.05;
| - const double rtol = 0.1;
| + const double atol = 1.0E-6;
| + const double rtol = 1.0E-6;
Why is this changing? It looks like it got stricter, so if it passes, that's better. But I don't understand why it was looser before.
| constexpr int THREADS_X_ROWWISE = TILE_DIM_X / ELTS_PER_THREAD;
| constexpr int THREADS_Y_ROWWISE = THREADS_NUM / THREADS_X_ROWWISE;
|
| constexpr int THREADS_X_TRANSP = TILE_DIM_X / 2;
Can we use one naming convention for the transpose macros? Earlier you used THREADS_X_T and BUFF_OUT_T_DIM_Y (the "_T" suffix), but there is also "_TRANSP" and "_TR" (as in BUFFS_NUM_OUT_TR). Just pick one.
Thanks, I’m fine to standardize. Please consolidate any remaining feedback into a single final pass by EOD today so we can merge without another round of incremental comments. This PR has been open since late November and was originally requested as time-sensitive.
| inline void quantize_transpose_tuned_1D(const Tensor &input, const Tensor *noop, Tensor *output,
|                                          const QuantizationConfig *quant_config,
|                                          cudaStream_t stream) {
| #if FP4_TYPE_SUPPORTED
But shouldn't the TunableConfig struct be for exactly that purpose? It would keep these tuned values separate from the common part of the 2 kernels, with the common part in one place only. Then there would be no tight coupling.
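For illustration, a minimal sketch (with hypothetical names and assumed values, not the actual struct in this PR) of how a config struct like that could separate the tuned constants from a shared kernel body:

```cuda
#include <cuda_runtime.h>

// Hypothetical sketch only; the actual TunableConfig in this PR may be shaped
// differently. The idea: per-kernel tuned constants live in one small struct,
// while the common kernel body is written once against those parameters.
struct TunedConfig1D {
  static constexpr int THREADS_NUM = 512;     // assumed value, for illustration
  static constexpr int TILE_DIM_X = 128;      // assumed value
  static constexpr int ELTS_PER_THREAD = 8;   // assumed value
};

template <typename Config>
__global__ void quantize_transpose_common_body() {
  // Derived thread-mapping constants come from the config, so tuned and
  // general kernels can share this body instead of duplicating it.
  constexpr int threads_x_rowwise = Config::TILE_DIM_X / Config::ELTS_PER_THREAD;
  constexpr int threads_y_rowwise = Config::THREADS_NUM / threads_x_rowwise;
  (void)threads_x_rowwise;
  (void)threads_y_rowwise;
  // ... common quantization / transpose logic goes here ...
}

// The dispatcher would pick the config for the tuned path.
inline void launch_tuned_1d(dim3 grid, dim3 block, cudaStream_t stream) {
  quantize_transpose_common_body<TunedConfig1D><<<grid, block, 0, stream>>>();
}
```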
5 files reviewed, 1 comment
| if (!use_2d_quantization && (input.dtype() == DType::kBFloat16)) {
|   quantize_transpose_tuned_1D(input, noop, output, quant_config, stream);
|   return;
logic: add runtime check for sm_100+ before dispatching to tuned kernel
Suggested change:
| - if (!use_2d_quantization && (input.dtype() == DType::kBFloat16)) {
| -   quantize_transpose_tuned_1D(input, noop, output, quant_config, stream);
| -   return;
| + if (!use_2d_quantization && (input.dtype() == DType::kBFloat16) && is_supported_by_CC_100()) {
| +   quantize_transpose_tuned_1D(input, noop, output, quant_config, stream);
| +   return;
| + }
The tuned kernel uses Blackwell-specific instructions (TMA, mbarrier, cluster launch control) that will cause runtime errors on older GPUs. FP4_TYPE_SUPPORTED is only a compile-time check for CUDA version, not device capability.
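For reference, a minimal sketch of what such a runtime compute-capability guard could look like using the CUDA runtime API; the actual is_supported_by_CC_100() helper in Transformer Engine may be implemented differently (e.g. cached per device):

```cuda
#include <cuda_runtime.h>

// Hypothetical sketch: query the current device's compute capability and only
// allow the tuned path on sm_100 or newer.
inline bool device_is_sm100_or_newer() {
  int device = 0;
  if (cudaGetDevice(&device) != cudaSuccess) return false;
  int cc_major = 0;
  if (cudaDeviceGetAttribute(&cc_major, cudaDevAttrComputeCapabilityMajor, device) !=
      cudaSuccess) {
    return false;
  }
  return cc_major >= 10;  // compute capability 10.x == sm_100 family (Blackwell)
}
```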
/te-ci
tdophung left a comment
I have no other issue with this PR. We should merge. Thanks for the change!
5 files reviewed, 1 comment
| if (!use_2d_quantization && (input.dtype() == DType::kBFloat16)) {
|   quantize_transpose_tuned_1D(input, noop, output, quant_config, stream);
|   return;
| }
logic: add runtime check for sm_100+ before dispatching to tuned kernel
Suggested change:
| - if (!use_2d_quantization && (input.dtype() == DType::kBFloat16)) {
| -   quantize_transpose_tuned_1D(input, noop, output, quant_config, stream);
| -   return;
| - }
| + if (!use_2d_quantization && (input.dtype() == DType::kBFloat16) && is_supported_by_CC_100()) {
| +   quantize_transpose_tuned_1D(input, noop, output, quant_config, stream);
| +   return;
| + }
The tuned kernel uses Blackwell-specific instructions (TMA, mbarrier, cluster launch control) that require sm_100+. Without this check, older GPUs will hit the runtime error inside the kernel at line 651.
* Implemented persistent nvfp4 kernel Signed-off-by: Oleg Goncharov <[email protected]>
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* Fix FP4 guard in ptx Signed-off-by: Oleg Goncharov <[email protected]>
* Fix Signed-off-by: Oleg Goncharov <[email protected]>
* Fix in ptx. reduxf32 guard Signed-off-by: Oleg Goncharov <[email protected]>
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* Fix Signed-off-by: Oleg Goncharov <[email protected]>
* Fixes per PR review Signed-off-by: Oleg Goncharov <[email protected]>
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* Fixes per PR review. Added parameter to turn off the persistency Signed-off-by: Oleg Goncharov <[email protected]>
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* Modified reference CPU implementation in C++ unit tests to match GPU (numerical truncation). Tightened the numerical tolerance Signed-off-by: Oleg Goncharov <[email protected]>
* Disabled persistency by default, as the non-persistent kernel is more performant when inputs are large Signed-off-by: Oleg Goncharov <[email protected]>
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* Use the tuned kernel also for the rowwise-only quantization Signed-off-by: Oleg Goncharov <[email protected]>
* Fixed typo Signed-off-by: Oleg Goncharov <[email protected]>
* Addressed comments from the PR review Signed-off-by: Oleg Goncharov <[email protected]>
* Resolved conflicts Signed-off-by: Oleg Goncharov <[email protected]>
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* Macros renaming Signed-off-by: Oleg Goncharov <[email protected]>
---------
Signed-off-by: Oleg Goncharov <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Description
This PR introduces a specialized CUDA kernel optimized for NVFP4 quantization of BF16 inputs on the Blackwell architecture (sm100f family). The implementation achieves its performance improvements by leveraging architecture-specific features (TMA, mbarrier, and cluster launch control instructions). Achieved bandwidth:
RN (round-to-nearest): 6.4 TB/s (rowwise only: 7.2 TB/s)
SR (stochastic rounding): 4.5 TB/s (rowwise only: 7.0 TB/s)
Below are the performance measurements for quantizing tensors with dimensions representative of DSv3 [8192×8, 7168] on an internal cluster (B300).
[Benchmark plots: a) round-to-nearest, b) stochastic rounding; each comparing the Rowwise + Colwise (transpose) and Rowwise only variants.]
Using --fast-math can improve performance of the kernel with stochastic rounding (RNG) by up to ~10%.
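For background on the architecture-specific features mentioned above: the kernel's pipeline follows the familiar producer/consumer pattern of asynchronous bulk copies into shared memory synchronized by barriers. Below is a generic sketch of that pattern using the libcu++ cuda::barrier API; the PR itself works at the PTX level (cp.async.bulk.tensor with mbarrier objects and TMA tensor maps), so this is only an analogy, and all names here are illustrative:

```cuda
#include <cooperative_groups.h>
#include <cuda/barrier>

namespace cg = cooperative_groups;

// Generic async-copy + barrier sketch (assumes n is a multiple of blockDim.x
// and that blockDim.x * sizeof(float) bytes of dynamic shared memory are
// provided at launch). Hardware-accelerated async copies require sm_80+.
__global__ void scale_by_two(const float* __restrict__ in, float* __restrict__ out, int n) {
  extern __shared__ float tile[];
  auto block = cg::this_thread_block();

  // Block-scoped barrier living in shared memory, one arrival per thread.
  __shared__ cuda::barrier<cuda::thread_scope_block> ready;
  if (block.thread_rank() == 0) init(&ready, block.size());
  block.sync();

  // Producer side: start the bulk copy of this block's tile into shared memory.
  cuda::memcpy_async(block, tile, in + blockIdx.x * blockDim.x,
                     sizeof(float) * blockDim.x, ready);

  // Consumer side: wait for the copy to complete, then process the tile.
  ready.arrive_and_wait();
  const int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) out[i] = 2.0f * tile[threadIdx.x];
}
```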
Threads to data mapping (colwise case)
To reduce shared memory bank conflicts, the following mapping is used when reading from and writing to shmem buffers:
where SCALE_DIM = 16. The arrows in the figure below illustrate how thread indices increment, forming a zigzag pattern.
a) Reads from SHMEM Input Buffer
b) Writes to SHMEM Output Transpose Buffer
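The figures above are not reproduced here, but as a rough illustration of the general technique (with assumed tile sizes, and not the kernel's exact zigzag mapping): a skewed shared-memory layout rotates the physical column by the row index so that threads walking a logical column land on distinct banks.

```cuda
// Hypothetical sketch of the general bank-conflict-avoidance idea, not the
// kernel's actual mapping (which is parameterized by SCALE_DIM = 16 and shown
// in the figures referenced above). Tile sizes here are assumed values.
constexpr int TILE_DIM_Y = 32;  // assumed tile height
constexpr int TILE_DIM_X = 32;  // assumed tile width, matching 32 banks of 4-byte words

// Rotate the physical column by the row index, so that a warp walking down a
// logical column touches 32 distinct banks instead of hammering a single one.
__device__ __forceinline__ int skewed_col(int row, int col) {
  return (col + row) % TILE_DIM_X;
}

// Writes (rowwise pass) and reads (colwise/transpose pass) must apply the same
// mapping for the data to line up.
__device__ void write_elt(float (&buff)[TILE_DIM_Y][TILE_DIM_X], int row, int col, float v) {
  buff[row][skewed_col(row, col)] = v;
}

__device__ float read_elt(const float (&buff)[TILE_DIM_Y][TILE_DIM_X], int row, int col) {
  return buff[row][skewed_col(row, col)];
}
```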
Type of change
Changes
Checklist: