
Request for batched general_gemm() (or FP8-aware torch.bmm) for non-Linear GEMM workloads #2846

@jomitchellnv

Description


Is your feature request related to a problem? Please describe.

We’re accelerating triangular multiplication in a protein structure prediction model (AlphaFold-style tri-mul). The core operation is two large einsums over 4D pair representations that we’ve reshaped into batched matmuls:

# Input: (B, N, N, D) where N = 2048 (sequence length), D = 128
# After chunk, permute, reshape: (B*32, 2048, 2048)
x1 = torch.bmm(a, b.transpose(1, 2))  # B*32 independent N×N GEMMs

At N = 2048, this accounts for roughly 40% of the tri-mul compute and is heavily memory-bandwidth-bound. Currently we run in FP32 (4 bytes/element) or BF16 (2 bytes/element). MXFP8 inputs (1 byte/element) with FP32 accumulation would provide up to a 4× reduction in HBM reads, which is the dominant cost at these sizes.
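As a back-of-the-envelope check on the 4× figure, here is a minimal read-traffic model (an assumption for illustration: each input operand is read from HBM exactly once, and output writes are ignored since they are the same size regardless of input dtype):

```python
# Rough HBM-read model for a batched NxN GEMM: each of the two input
# operands is read once per GEMM. Sizes match the tri-mul shapes above.
def gemm_read_bytes(batch, n, bytes_per_elem):
    return batch * 2 * n * n * bytes_per_elem

batch, n = 32, 2048  # B = 1 for simplicity
fp32 = gemm_read_bytes(batch, n, 4)
bf16 = gemm_read_bytes(batch, n, 2)
fp8 = gemm_read_bytes(batch, n, 1)
print(fp32 / 2**30, bf16 / 2**30, fp8 / 2**30)  # GiB read per einsum: 1.0 0.5 0.25
```

Under this model FP8 inputs read exactly one quarter of the bytes that FP32 does, which is where the 4× comes from.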

However, there is currently no way to run FP8 batched matrix multiplication through TE:

  • te.autocast() only intercepts TE modules, not torch.bmm
  • Float8Tensor / MXFP8Tensor passed to torch.bmm silently dequantize to full precision
  • general_gemm() supports FP8 × FP8 with use_split_accumulator=True, but only accepts 2D inputs — looping over B*32 slices would likely negate the bandwidth savings

Related: #1910 describes the same gap for FP8 GEMM beyond te.Linear.

Describe the solution you’d like

A batched variant of general_gemm() that accepts 3D inputs and runs FP8 GEMMs across the batch dimension with FP32 accumulation:

from transformer_engine.pytorch.cpp_extensions import batched_general_gemm

# Quantize inputs to FP8
a_fp8 = mxfp8_quantizer.quantize(a_3d)  # (B*32, N, N)
b_fp8 = mxfp8_quantizer.quantize(b_3d)  # (B*32, N, N)

# Batched FP8 GEMM with FP32 accumulation
output = batched_general_gemm(
    a_fp8,
    b_fp8,
    out_dtype=torch.bfloat16,
    layout="NN",
    use_split_accumulator=True,  # FP8×FP8 multiply, FP32 accumulate
)
# output: (B*32, N, N) in BF16

Alternatively, making Float8Tensor / MXFP8Tensor dispatch torch.bmm to real FP8 tensor core GEMMs, instead of dequantizing, would also solve this.
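To illustrate the dispatch mechanism such an alternative could use (not TE's implementation — `FakeFP8Tensor` is a hypothetical toy class), PyTorch's `__torch_function__` protocol lets a tensor subclass intercept `torch.bmm` and route it to a custom kernel instead of dequantizing:

```python
import torch

class FakeFP8Tensor(torch.Tensor):
    # Toy tensor subclass showing how torch.bmm can be intercepted via
    # __torch_function__. A real FP8 tensor would call an FP8 tensor-core
    # GEMM here; this sketch just forwards to the plain-tensor bmm.
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func is torch.bmm:
            a, b = args
            # Unwrap to plain Tensors to avoid re-entering this handler.
            return torch.bmm(a.as_subclass(torch.Tensor),
                             b.as_subclass(torch.Tensor))
        return super().__torch_function__(func, types, args, kwargs)

a = torch.randn(2, 4, 4).as_subclass(FakeFP8Tensor)
b = torch.randn(2, 4, 4).as_subclass(FakeFP8Tensor)
out = torch.bmm(a, b)  # routed through __torch_function__, shape (2, 4, 4)
```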

Describe alternatives you’ve considered

  • GroupedLinear: Suggested in #1910, but it is designed for MoE-style use cases with different weights per group. Our use case multiplies two arbitrary input tensors, not input × stored weight. It was also noted there that GroupedLinear may carry significant overhead.
  • Looping general_gemm() over batch slices: Functionally possible, but Python loop overhead and the lack of kernel batching would likely wipe out the memory-bandwidth gains from FP8.
  • Skipping .float() and running torch.bmm in BF16: This is our current workaround. It gives a 2× memory reduction versus FP32, but still leaves another 2× on the table compared with FP8.
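The BF16 workaround in the last bullet can be sketched as follows (toy shapes standing in for (B*32, 2048, 2048)):

```python
import torch

# Illustrative inputs; in the real model these come from chunk/permute/reshape
# of the (B, N, N, D) pair representation.
a = torch.randn(4, 64, 64)
b = torch.randn(4, 64, 64)

# Skip the .float() upcast and run the batched GEMM directly in BF16,
# halving the bytes read from memory relative to FP32.
x1 = torch.bmm(a.to(torch.bfloat16), b.transpose(1, 2).to(torch.bfloat16))
```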

Additional context

  • Targeting Blackwell (MXFP8BlockScaling) and Hopper (DelayedScaling / CurrentScaling)
  • Training workload, so backward-pass support is needed
  • This batched FP8 GEMM pattern would also help other workloads with non-Linear matmuls, including attention (unfused path), structure prediction, graph neural networks, and any model with einsum contractions reshaped to bmm
  • TE v1.12+

Happy to provide a minimal repro or benchmark if helpful.
