Describe the bug
ONNX export fails for a standard transformer encoder in FP8 precision. Narrowing it down, the problem appears to be in the MultiheadAttention (MHA) layer.
Simple repro:
```python
import torch
from torch import nn
import transformer_engine.pytorch as te
from transformer_engine.pytorch.attention.multi_head_attention import MultiheadAttention as TF_MHA
from transformer_engine.pytorch.export import te_translation_table
from transformer_engine.common.recipe import Format, DelayedScaling

num_attention_heads = 2
hidden_dim = 128 * num_attention_heads
sequence_length = 8192
batch_dim = 1
num_layers = 3

x = torch.randn([sequence_length, batch_dim, hidden_dim], device="cuda")


class TestMHAModel(torch.nn.Module):
    def __init__(self, hidden_dim, num_attention_heads, num_layers):
        super().__init__()
        # Stack of self-attention blocks, each with its own input LayerNorm
        self.layers = nn.Sequential(
            *[
                TF_MHA(
                    hidden_size=hidden_dim,
                    num_attention_heads=num_attention_heads,
                    attention_dropout=0.0,
                    layer_number=i + 1,
                    attn_mask_type="no_mask",
                    window_size=(-1, -1),
                    attention_type="self",
                    normalization="LayerNorm",
                    seq_length=sequence_length,
                )
                for i in range(num_layers)
            ]
        )

    def forward(self, x):
        return self.layers(x)


def export(fname, x, recipe):
    with te.autocast(enabled=True, recipe=recipe):
        model = TestMHAModel(hidden_dim, num_attention_heads, num_layers).cuda().eval()
        with torch.inference_mode():
            model(x)  # warm-up forward pass
        with te.onnx_export(enabled=True):
            model(x)  # forward pass with TE's ONNX export mode enabled
        with te.onnx_export(enabled=True):
            torch.onnx.export(
                model,
                x,
                fname,
                output_names=["output"],
                dynamo=True,
                custom_translation_table=te_translation_table,
            )


# This works, but it doesn't introduce the right Q/DQ operators,
# so MHA runs in fp32 precision with TensorRT.
recipe = DelayedScaling(fp8_mha=False, fp8_dpa=False)
export("mha1.onnx", x, recipe)

# Fails in FP8EmulationFunc.apply
recipe = DelayedScaling(fp8_mha=True, fp8_dpa=True)
export("mha2.onnx", x, recipe)
```
I had to run the script as `NVTE_UnfusedDPA_Emulate_FP8=1 python export_mha_bug.py`; otherwise I get:
"No dot product attention backend is available for the provided inputs"
With that variable set, the second export (`fp8_mha=True, fp8_dpa=True`) fails with the following stack trace:
```
    query_layer, key_layer, value_layer = FP8EmulationFunc.apply(
                                          ^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/autograd/function.py", line 582, in apply
    return super().apply(*args, **kwargs) # type: ignore[misc]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/transformer_engine/pytorch/attention/dot_product_attention/backends.py", line 171, in forward
    q_fp8, k_fp8, v_fp8 = combine_and_quantize(
                          ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/transformer_engine/pytorch/attention/dot_product_attention/utils.py", line 2190, in combine_and_quantize
    qkv_fp8 = qkv_quantizer(qkv)
              ^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/transformer_engine/pytorch/quantized_tensor.py", line 262, in __call__
    return self.quantize(tensor)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/transformer_engine/pytorch/quantized_tensor.py", line 245, in quantize
    return _QuantizeFunc.forward(None, tensor, self.quantize_impl)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/transformer_engine/pytorch/tensor/_quantization_helpers.py", line 29, in forward
    return quantize_impl(tensor)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/transformer_engine/pytorch/tensor/float8_tensor.py", line 109, in quantize_impl
    return tex.quantize(tensor, self)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Cannot access data pointer of Tensor (e.g. FakeTensor, FunctionalTensor). If you're using torch.compile/export/fx, it is likely that we are erroneously tracing into a custom kernel. To fix this, please wrap the custom kernel into an opaque custom op. Please see the following for details: https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html
```
Expected behavior
An ONNX file with Q/DQ operators around the attention matmuls, so that TensorRT can fuse MHA and run it in FP8 precision.
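For reference, each Q/DQ pair expected in the exported graph quantizes a tensor to FP8 with a per-tensor scale and immediately dequantizes it, so downstream ops see FP8-rounded values. The pure-Python sketch below emulates that round trip for FP8 E4M3, assuming round-to-nearest-even with saturation at ±448, and a scale computed as `fp8_max / amax / 2**margin` over an amax history, which is my understanding of how `DelayedScaling` derives its scale. `round_to_e4m3`, `delayed_scale`, and `quantize_dequantize` are hypothetical helpers for illustration only, not TransformerEngine, ONNX, or TensorRT APIs.

```python
import math

FP8_E4M3_MAX = 448.0  # largest finite E4M3 value (the E4M3FN variant has no infinities)


def round_to_e4m3(v: float) -> float:
    """Round v to the nearest FP8 E4M3 value (ties to even), saturating at +/-448."""
    if math.isnan(v):
        return v
    v = max(-FP8_E4M3_MAX, min(FP8_E4M3_MAX, v))  # saturate before rounding
    if v == 0.0:
        return 0.0
    # Exponent of the leading bit, clamped to the minimum normal exponent (-6)
    # so that tiny values fall on the fixed subnormal grid of 2**-9.
    exp = max(math.frexp(abs(v))[1] - 1, -6)
    step = 2.0 ** (exp - 3)  # 3 mantissa bits: spacing within a binade is 2**(exp - 3)
    return round(v / step) * step  # Python's round() breaks ties to even


def delayed_scale(amax_history, prev_scale=1.0, margin=0):
    """Hypothetical per-tensor scale in the delayed-scaling style: fp8_max / amax / 2**margin."""
    amax = max(amax_history)  # 'max' reduction over the recorded amax history
    if amax > 0.0 and math.isfinite(amax):
        return FP8_E4M3_MAX / amax / (2.0 ** margin)
    return prev_scale  # keep the previous scale when amax is zero or non-finite


def quantize_dequantize(values, scale):
    """Emulate one Q/DQ pair: quantize v * scale to E4M3, then divide by scale.

    In ONNX QuantizeLinear/DequantizeLinear terms, y_scale corresponds to 1 / scale.
    """
    return [round_to_e4m3(v * scale) / scale for v in values]


x = [0.1, -1.3, 2.5, 7.9]
s = delayed_scale([abs(v) for v in x])
print(s, quantize_dequantize(x, s))
```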
Environment overview
```shell
docker run --gpus all -it --rm \
    -v $(pwd)/mount:/mount \
    nvcr.io/nvidia/pytorch:25.11-py3
```
I also tried pulling the latest TransformerEngine (built with pip install inside the container above).
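The `NVTE_UnfusedDPA_Emulate_FP8=1` setting from the run command can also be applied inside the script itself. A minimal sketch, assuming TransformerEngine reads the variable from the process environment (which the shell prefix already relies on); it is set before `transformer_engine` is imported, since exactly when the variable is read has not been checked here.

```python
import os

# Equivalent to prefixing the command with NVTE_UnfusedDPA_Emulate_FP8=1.
# setdefault keeps any value that was already exported in the shell.
os.environ.setdefault("NVTE_UnfusedDPA_Emulate_FP8", "1")

# Import TransformerEngine only after the variable is in place, e.g.:
# import transformer_engine.pytorch as te
```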