Request for batched general_gemm() (or FP8-aware torch.bmm) for non-Linear GEMM workloads #2846
Description
Is your feature request related to a problem? Please describe.
We’re accelerating triangular multiplication in a protein structure prediction model (AlphaFold-style tri-mul). The core operation is two large einsums over 4D pair representations that we’ve reshaped into batched matmuls:
```python
# Input: (B, N, N, D) where N = 2048 (sequence length), D = 128
# After chunk, permute, reshape: (B*32, 2048, 2048)
x1 = torch.bmm(a, b.transpose(1, 2))  # B*32 independent N×N GEMMs
```

At N = 2048, this accounts for roughly 40% of the tri-mul compute and is heavily memory-bandwidth-bound. Currently we run in FP32 (4 bytes/element) or BF16 (2 bytes/element). MXFP8 inputs (1 byte/element) with FP32 accumulation would provide up to a 4× reduction in HBM reads, which is the dominant cost at these sizes.
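The 4× figure is just operand-size arithmetic; a minimal back-of-the-envelope check (counting only one read of each input operand, and ignoring MXFP8 scale-factor overhead):

```python
# HBM traffic for the input operands of one batch of B = 32 NxN GEMMs.
B, N = 32, 2048

def input_bytes(bytes_per_elem: int) -> int:
    # Two NxN input operands per GEMM, B GEMMs in the batch.
    return 2 * B * N * N * bytes_per_elem

fp32 = input_bytes(4)  # 4 bytes/element
bf16 = input_bytes(2)  # 2 bytes/element
fp8 = input_bytes(1)   # 1 byte/element (MXFP8, ignoring scale factors)

print(fp32 // 2**20, bf16 // 2**20, fp8 // 2**20)  # -> 1024 512 256 (MiB)
print(fp32 / fp8)  # -> 4.0
```

At these sizes the inputs alone are a gigabyte per tri-mul invocation in FP32, which is why the operation is bandwidth-bound rather than compute-bound.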
However, there is currently no way to run FP8 batched matrix multiplication through TE:
- `te.autocast()` only intercepts TE modules, not `torch.bmm`
- `Float8Tensor` / `MXFP8Tensor` passed to `torch.bmm` silently dequantize to full precision
- `general_gemm()` supports FP8 × FP8 with `use_split_accumulator=True`, but only accepts 2D inputs; looping over `B*32` slices would likely negate the bandwidth savings
Related: #1910 describes the same gap for FP8 GEMM beyond te.Linear.
Describe the solution you’d like
A batched variant of general_gemm() that accepts 3D inputs and runs FP8 GEMMs across the batch dimension with FP32 accumulation:
```python
from transformer_engine.pytorch.cpp_extensions import batched_general_gemm

# Quantize inputs to FP8
a_fp8 = mxfp8_quantizer.quantize(a_3d)  # (B*32, N, N)
b_fp8 = mxfp8_quantizer.quantize(b_3d)  # (B*32, N, N)

# Batched FP8 GEMM with FP32 accumulation
output = batched_general_gemm(
    a_fp8,
    b_fp8,
    out_dtype=torch.bfloat16,
    layout="NN",
    use_split_accumulator=True,  # FP8×FP8 multiply, FP32 accumulate
)
# output: (B*32, N, N) in BF16
```

Alternatively, making `Float8Tensor` / `MXFP8Tensor` dispatch `torch.bmm` to real FP8 tensor core GEMMs, instead of dequantizing, would also solve this.
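For the dispatch alternative, a toy sketch of the mechanism using a plain tensor subclass (not TE's actual `Float8Tensor`; `FakeFP8Tensor` is invented here, and a real implementation would launch a batched FP8 cuBLAS GEMM instead of recording the call and falling through):

```python
import torch

class FakeFP8Tensor(torch.Tensor):
    """Toy stand-in for Float8Tensor, illustrating torch.bmm interception."""
    handled = []  # records which ops were intercepted

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func is torch.bmm:
            cls.handled.append("bmm")
            # A real Float8Tensor would dispatch to a batched FP8 tensor
            # core GEMM here; this sketch just falls through.
        with torch._C.DisableTorchFunctionSubclass():
            return func(*args, **kwargs)

a = torch.randn(4, 8, 8).as_subclass(FakeFP8Tensor)
b = torch.randn(4, 8, 8).as_subclass(FakeFP8Tensor)
out = torch.bmm(a, b)
assert FakeFP8Tensor.handled == ["bmm"]
```

The point is that `torch.bmm` already routes through `__torch_function__` for tensor subclasses, so user code would not need to change at all; only the subclass's handler would.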
Describe alternatives you’ve considered
- `GroupedLinear`: Suggested in #1910, but it is designed for MoE-style use cases with different weights per group. Our use case is two arbitrary input tensors, not input × stored weight. It was also noted there may be significant overhead.
- Looping `general_gemm()` over batch slices: Functionally possible, but Python loop overhead and the lack of kernel batching would likely wipe out the memory-bandwidth gains from FP8.
- Skipping `.float()` and running `torch.bmm` in BF16: This is our current workaround. It gives a 2× memory reduction versus FP32, but still leaves another 2× on the table compared with FP8.
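For concreteness, the current BF16 workaround looks roughly like this (a sketch with small illustrative shapes; the real case is `B*32 = 32`, `N = 2048`):

```python
import torch

B, N = 4, 64  # small for illustration

# Keep the batched matmul in BF16 instead of upcasting to FP32.
# Halves HBM traffic versus FP32; FP8 inputs would halve it again.
a = torch.randn(B, N, N, dtype=torch.bfloat16)
b = torch.randn(B, N, N, dtype=torch.bfloat16)

x1 = torch.bmm(a, b.transpose(1, 2))  # BF16 inputs, BF16 output
assert x1.dtype == torch.bfloat16 and x1.shape == (B, N, N)
```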
Additional context
- Targeting Blackwell (`MXFP8BlockScaling`) and Hopper (`DelayedScaling`/`CurrentScaling`)
- Training workload, so backward-pass support is needed
- This batched FP8 GEMM pattern would also help other workloads with non-`Linear` matmuls, including attention (unfused path), structure prediction, graph neural networks, and any model with einsum contractions reshaped to `bmm`
- TE v1.12+
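On the backward-pass point: for `Y = A @ B`, the gradients `dA = dY @ Bᵀ` and `dB = Aᵀ @ dY` are themselves batched GEMMs, so the same kernel would cover training. A sketch of how the requested `batched_general_gemm` could be wired into autograd (stubbed with `torch.bmm` here so it runs; a real version would consume FP8 operands and accumulate in FP32):

```python
import torch

def batched_gemm_stub(a, b):
    # Stand-in for the requested batched_general_gemm().
    return torch.bmm(a, b)

class BatchedFP8Matmul(torch.autograd.Function):
    @staticmethod
    def forward(ctx, a, b):
        ctx.save_for_backward(a, b)
        return batched_gemm_stub(a, b)

    @staticmethod
    def backward(ctx, grad_out):
        a, b = ctx.saved_tensors
        # Both gradient contractions are batched GEMMs too, so the
        # forward FP8 kernel covers the backward pass as well.
        grad_a = batched_gemm_stub(grad_out, b.transpose(1, 2))
        grad_b = batched_gemm_stub(a.transpose(1, 2), grad_out)
        return grad_a, grad_b

a = torch.randn(2, 8, 8, requires_grad=True)
b = torch.randn(2, 8, 8, requires_grad=True)
BatchedFP8Matmul.apply(a, b).sum().backward()
assert a.grad.shape == (2, 8, 8) and b.grad.shape == (2, 8, 8)
```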
Happy to provide a minimal repro or benchmark if helpful.