Skip to content
1 change: 1 addition & 0 deletions qa/L1_pytorch_distributed_unittest/test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_torch_fsdp2.xml $TE_
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_comm_gemm_overlap.xml $TE_PATH/tests/pytorch/distributed/test_comm_gemm_overlap.py || test_fail "test_comm_gemm_overlap.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops_with_userbuffers.xml $TE_PATH/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py || test_fail "test_fusible_ops_with_userbuffers.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_attention_with_cp.xml $TE_PATH/tests/pytorch/attention/test_attention_with_cp.py || test_fail "test_attention_with_cp.py"
NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_attention_deterministic_with_cp.xml $TE_PATH/tests/pytorch/attention/test_attention_with_cp.py || test_fail "NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 test_attention_with_cp.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cp_utils.xml $TE_PATH/tests/pytorch/attention/test_cp_utils.py || test_fail "test_cp_utils.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cast_master_weights_to_fp8.xml $TE_PATH/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py || test_fail "test_cast_master_weights_to_fp8.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_newton_schulz.xml $TE_PATH/tests/pytorch/distributed/test_newton_schulz.py || test_fail "test_newton_schulz.py"
Expand Down
50 changes: 47 additions & 3 deletions qa/L3_pytorch_FA_versions_test/test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,25 @@
#
# See LICENSE for license information.

# Abort immediately when any unguarded command fails; individual sub-tests
# below are guarded with `|| test_fail`, so their failures do not trip this.
set -e
# Fatal-error helper: report the reason and stop the whole script with
# a nonzero status (used for setup failures, not for sub-test failures).
error_exit() {
    echo "Error: $1"
    exit 1
}

# Non-fatal failure recorder: mark the run as failed and remember which
# sub-test broke so the final summary can list every failure at once.
test_fail() {
    echo "Error: sub-test failed: $1"
    RET=1
    FAILED_CASES="${FAILED_CASES} $1"
}

# Aggregate failure state consumed by test_fail and the end-of-script summary.
RET=0
FAILED_CASES=""

# Default locations; both are overridable from the environment.
: ${TE_PATH:=/opt/transformerengine}
: ${XML_LOG_DIR:=/logs}
mkdir -p "$XML_LOG_DIR"

# Pin the pytest version; a failed install is fatal (error_exit), not a
# collected sub-test failure. The diff render had duplicated this line
# (old unguarded + new guarded form); only the guarded form is kept.
pip3 install pytest==8.2.1 || error_exit "Failed to install pytest"

# Limit parallel build jobs to avoid overwhelming system resources
export MAX_JOBS=32
Expand Down Expand Up @@ -40,7 +52,39 @@ do
cd ../../
fi

# Ensure local test utils is found before nvidia-cutlass-dsl's utils package
export PYTHONPATH=$TE_PATH/tests/pytorch:${PYTHONPATH:-}

# Run tests
NVTE_TORCH_COMPILE=0 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest.xml $TE_PATH/tests/pytorch/attention/test_attention.py
NUM_GPUS=$(nvidia-smi -L | wc -l)
echo "Detected $NUM_GPUS GPU(s)"
if [ "$NUM_GPUS" -ge 5 ]; then
CP_NUM_GPUS=$(( NUM_GPUS - 1 > 4 ? 4 : NUM_GPUS - 1 ))
CP_GPUS=$(seq -s, 1 $CP_NUM_GPUS)
echo "Running tests in parallel: test_attention.py on GPU 0, test_attention_with_cp.py on GPUs $CP_GPUS ($CP_NUM_GPUS GPUs)"

CUDA_VISIBLE_DEVICES=0 NVTE_TORCH_COMPILE=0 python3 -m pytest -v -s \
--junitxml=$XML_LOG_DIR/pytest.xml \
$TE_PATH/tests/pytorch/attention/test_attention.py &
PID_ATTN=$!

CUDA_VISIBLE_DEVICES=$CP_GPUS NVTE_TORCH_COMPILE=0 python3 -m pytest -v -s \
--junitxml=$XML_LOG_DIR/pytest_test_attention_with_cp.xml \
$TE_PATH/tests/pytorch/attention/test_attention_with_cp.py &
# NOTE(review): thread here was resolved by sudhakarsingh27.
PID_CP=$!

wait $PID_ATTN || test_fail "test_attention.py"
wait $PID_CP || test_fail "test_attention_with_cp.py"
else
echo "Running tests sequentially: need >=5 GPUs for parallel execution (1 for test_attention + 4 for test_attention_with_cp)"
NVTE_TORCH_COMPILE=0 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest.xml $TE_PATH/tests/pytorch/attention/test_attention.py || test_fail "test_attention.py"
NVTE_TORCH_COMPILE=0 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_attention_with_cp.xml $TE_PATH/tests/pytorch/attention/test_attention_with_cp.py || test_fail "test_attention_with_cp.py"
fi
done

# Final summary: if any sub-test was recorded by test_fail, list the
# failed cases and exit nonzero so CI marks the job as failed.
if [ "$RET" -ne 0 ]; then
echo "Error in the following test cases:$FAILED_CASES"
exit 1
fi
echo "All tests passed"
exit 0
96 changes: 56 additions & 40 deletions tests/pytorch/attention/run_attention_with_cp.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def generate_input_shapes(
config: ModelConfig,
world_size: int,
kernel_backend: str,
fa_pad_between_seqs: str = "False",
):
if qkv_format == "bshd":
q_input_shape = (
Expand Down Expand Up @@ -105,9 +106,12 @@ def generate_input_shapes(
).cuda()
cu_seqlens_q = torch.clone(cu_seqlens_q_padded)

# Since FlashAttention doesn't support pad b/w sequences, and FusedAttention does,
# cu_seqlens_q is updated to reflect non-padded lengths for FusedAttention only.
if kernel_backend == "FusedAttention":
# Generate padded data (cu_seqlens_q reflects non-padded lengths, so it
# differs from cu_seqlens_q_padded) for FusedAttention always, and for
# FlashAttention only when its test param requests it. DPA auto-detects
# pad_between_seqs downstream from the cu_seqlens_q vs cu_seqlens_q_padded
# mismatch.
if kernel_backend == "FusedAttention" or fa_pad_between_seqs == "True":
cu_seqlens_q[1:] = seqlens_q.cumsum(0, dtype=torch.int32).cuda()

# NOTE: In case of Cross-Attention, `cu_seqlens_kv` and `cu_seqlens_kv_padded`
Expand Down Expand Up @@ -186,6 +190,7 @@ def run_dpa_with_cp(
scaling_mode="delayed",
f16_O="False",
is_training="True",
fa_pad_between_seqs="False",
deterministic="False",
log_level=logging.WARNING,
):
Expand Down Expand Up @@ -288,7 +293,7 @@ def run_dpa_with_cp(
cu_seqlens_kv,
cu_seqlens_q_padded,
cu_seqlens_kv_padded,
) = generate_input_shapes(qkv_format, config, world_size, kernel_backend)
) = generate_input_shapes(qkv_format, config, world_size, kernel_backend, fa_pad_between_seqs)
q_orig = torch.clamp(torch.randn(q_input_shape, dtype=dtypes[dtype]), min=-1, max=1).cuda()
k_orig = torch.clamp(torch.randn(k_input_shape, dtype=dtypes[dtype]), min=-1, max=1).cuda()
v_orig = torch.clamp(torch.randn(v_input_shape, dtype=dtypes[dtype]), min=-1, max=1).cuda()
Expand Down Expand Up @@ -531,11 +536,11 @@ def run_dpa_with_cp(
tensors_to_deq[i] = tensor.dequantize()
if not fp8_bwd:
tensors[0], tensors[5] = tensors_to_deq
for i, tensor in enumerate(tensors):
for tensor, name in zip(tensors, names):
# dbias/dbias_ could be None, so skip check for it
if tensor is not None:
assert torch.all(~torch.isnan(tensor)), f"{names[i]} contains NaN"
assert torch.all(~torch.isinf(tensor)), f"{names[i]} contains Inf"
assert torch.all(~torch.isnan(tensor)), f"{name} has nan values"
assert torch.all(~torch.isinf(tensor)), f"{name} has inf values"
out, dq, dk, dv, dbias, out_, dq_, dk_, dv_, dbias_ = tensors

############ compare results between CP and no-CP ############
Expand Down Expand Up @@ -588,49 +593,60 @@ def run_dpa_with_cp(
if is_training:
dq, out = [x.index_select(0, seq_idx_q).contiguous() for x in [dq, out]]
dk, dv = [x.index_select(0, seq_idx_kv).contiguous() for x in [dk, dv]]
dq_, dk_, dv_, out_ = [dq_, dk_, dv_, out_]
cu_seqlens_q_padded = cu_seqlens_q_padded // world_size
cu_seqlens_q = get_cu_seqlens_on_cp_rank(
cu_seqlens_q, cu_seqlens_q_padded, world_size, rank, True, True
)
cu_pads_q = cu_seqlens_q_padded - cu_seqlens_q
num_pads_q = cu_pads_q[1:] - cu_pads_q[:-1]
for x in [dq, out, dq_, out_]:
assert torch.count_nonzero(x[cu_seqlens_q_padded[-1] :]).item() == 0
for b in range(config.batch_size):
assert (
num_pads_q[b] == 0
or torch.count_nonzero(
x[
(cu_seqlens_q_padded[b + 1] - num_pads_q[b]) : cu_seqlens_q_padded[
b + 1
]
]
).item()
== 0
)
num_pads_q = (cu_seqlens_q_padded - cu_seqlens_q)[1:] - (
cu_seqlens_q_padded - cu_seqlens_q
)[:-1]
cu_seqlens_kv_padded = cu_seqlens_kv_padded // world_size
cu_seqlens_kv = get_cu_seqlens_on_cp_rank(
cu_seqlens_kv, cu_seqlens_kv_padded, world_size, rank, True, True
)
cu_pads_kv = cu_seqlens_kv_padded - cu_seqlens_kv
num_pads_kv = cu_pads_kv[1:] - cu_pads_kv[:-1]
for x in [dk, dv, dk_, dv_]:
assert torch.count_nonzero(x[cu_seqlens_kv_padded[-1] :]).item() == 0
for b in range(config.batch_size):
assert (
num_pads_kv[b] == 0
or torch.count_nonzero(
x[
(
cu_seqlens_kv_padded[b + 1] - num_pads_kv[b]
) : cu_seqlens_kv_padded[b + 1]
]
).item()
== 0
num_pads_kv = (cu_seqlens_kv_padded - cu_seqlens_kv)[1:] - (
cu_seqlens_kv_padded - cu_seqlens_kv
)[:-1]
# FA3 leaves garbage at padding positions despite seqused_q/k (tile spillover).
# Forward out_ can't be pre-zeroed because FA3's custom op returns out_ as an
# output rather than mutating it in-place, triggering PyTorch's aliasing constraint.
# Backward dq/dk/dv CAN be pre-zeroed because FA3 marks them as mutated inputs.
if fa_pad_between_seqs == "True":
# out_ is a view inside the CP custom autograd Function, so in-place
# zeroing is blocked by PyTorch. Clone to break the view relationship.
out_ = out_.clone()
for x in [out, out_, dq]:
for b in range(config.batch_size):
x[
cu_seqlens_q_padded[b + 1] - num_pads_q[b] : cu_seqlens_q_padded[b + 1]
] = 0.0
x[cu_seqlens_q_padded[-1] :] = 0.0
for x in [dk, dv]:
for b in range(config.batch_size):
x[
cu_seqlens_kv_padded[b + 1]
- num_pads_kv[b] : cu_seqlens_kv_padded[b + 1]
] = 0.0
x[cu_seqlens_kv_padded[-1] :] = 0.0
# Verify CP backward tensors have clean padding (pre-zeroed in context_parallel.py).
for xname, x, cu, np_ in [
("dq_", dq_, cu_seqlens_q_padded, num_pads_q),
("dk_", dk_, cu_seqlens_kv_padded, num_pads_kv),
("dv_", dv_, cu_seqlens_kv_padded, num_pads_kv),
]:
nnz = torch.count_nonzero(x[cu[-1] :]).item()
assert nnz == 0, (
f"{xname} has {nnz} nonzero values in tail padding — "
"context_parallel.py should zero padding positions"
)
for b in range(config.batch_size):
if np_[b] > 0:
nnz = torch.count_nonzero(x[cu[b + 1] - np_[b] : cu[b + 1]]).item()
assert nnz == 0, (
f"{xname} has {nnz} nonzero values in batch {b} padding — "
"context_parallel.py should zero padding positions"
)
else:
# Forward-only: reshape only out/out_ for comparison
out = out.index_select(0, seq_idx_q).contiguous()
out_ = out_

Expand Down
34 changes: 14 additions & 20 deletions tests/pytorch/attention/test_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def reset_global_fp8_state():
@pytest.mark.parametrize("workspace_opt", [True, False])
@pytest.mark.parametrize("qkv_layout", [None])
@pytest.mark.parametrize("swa", [False])
@pytest.mark.parametrize("pad_between_seqs", [False])
@pytest.mark.parametrize("pad_between_seqs", [False, True])
def test_dot_product_attention(
dtype,
model_configs,
Expand Down Expand Up @@ -157,6 +157,8 @@ def test_dot_product_attention(

config.window_size = check_set_window_size(config.attn_mask_type, config.window_size)
qkv_format = qkv_layout.replace("3", "").replace("2", "").split("_")[0]
if pad_between_seqs and qkv_format != "thd":
pytest.skip("pad_between_seqs only applies to THD format!")
if qkv_format == "thd" and "padding" not in config.attn_mask_type:
config.attn_mask_type = (
"padding_" + config.attn_mask_type if config.attn_mask_type != "no_mask" else "padding"
Expand Down Expand Up @@ -195,19 +197,6 @@ def test_dot_product_attention(
)
flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends

# FlashAttention does not support pad_between_seqs, but _run_dot_product_attention
# mannually pads and unpads the input and output of FlashAttention for testing purposes
if (
pad_between_seqs
and FlashAttentionUtils.is_installed
and not (
config.max_seqlen_q != config.max_seqlen_kv
and config.attn_mask_type in ["causal", "padding_causal"]
)
and (config.window_size[0] == -1 or FlashAttentionUtils.v2_3_plus)
):
flash_attn_supported = True

# Skip if only unfused backend is supported
if (len(fused_attn_backends) + flash_attn_supported + unfused_attn_supported) < 2:
pytest.skip("Less than two backends to compare.")
Expand Down Expand Up @@ -1330,12 +1319,12 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
block.softmax_offset.requires_grad = True

# Run a forward and backward pass
if backend in ["FlashAttention", "UnfusedDotProductAttention"]:
if backend in ["UnfusedDotProductAttention"]:
q = inp_orig[0]
k = inp_orig[1]
v = inp_orig[2]
d_out = out_grad_orig
if backend == "FusedAttention":
if backend in ["FusedAttention", "FlashAttention"]:
q = inp[0]
k = inp[1]
v = inp[2]
Expand All @@ -1351,14 +1340,19 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
max_seqlen_kv=config.max_seqlen_kv,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_kv,
cu_seqlens_q_padded=cu_seqlens_q_after_pad if backend == "FusedAttention" else None,
cu_seqlens_kv_padded=cu_seqlens_kv_after_pad if backend == "FusedAttention" else None,
cu_seqlens_q_padded=(
cu_seqlens_q_after_pad if backend in ["FusedAttention", "FlashAttention"] else None
),
cu_seqlens_kv_padded=(
cu_seqlens_kv_after_pad if backend in ["FusedAttention", "FlashAttention"] else None
),
attn_mask_type=config.attn_mask_type,
checkpoint_core_attention=ckpt_attn,
core_attention_bias_type=config.attn_bias_type,
core_attention_bias=bias,
alibi_slopes=alibi_slopes,
fast_zero_fill=True,
pad_between_seqs=pad_between_seqs,
# Only pass num_splits when exercising the FlashAttention path
num_splits=config.num_splits if backend == "FlashAttention" else 1,
)
Expand All @@ -1372,12 +1366,12 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
if is_training and config.softmax_type != "vanilla":
d_softmax_offset = block.softmax_offset.grad

if backend in ["FlashAttention", "UnfusedDotProductAttention"]:
if backend in ["UnfusedDotProductAttention"]:
if is_training:
return out, max_logit, (q.grad, k.grad, v.grad, d_softmax_offset)
else:
return out, max_logit, (None, None, None, d_softmax_offset)
if backend == "FusedAttention":
if backend in ["FusedAttention", "FlashAttention"]:
if qkv_format == "thd" and pad_between_seqs:
out_orig = torch.Tensor([]).to(device="cuda", dtype=dtype)
if is_training:
Expand Down
22 changes: 21 additions & 1 deletion tests/pytorch/attention/test_attention_with_cp.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,11 +91,20 @@ def get_bash_arguments(num_gpus_per_node, **kwargs):
@pytest.mark.parametrize("model", model_configs_flash_attn.keys())
@pytest.mark.parametrize("qkv_format", qkv_formats)
@pytest.mark.parametrize("cp_comm_type", cp_comm_types)
def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type):
@pytest.mark.parametrize("pad_between_seqs", [False, True])
def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type, pad_between_seqs):
num_gpus = 4 if cp_comm_type == "a2a+p2p" else 2
if num_gpus > torch.cuda.device_count():
pytest.skip(f"Test requires {num_gpus} GPUs, but found {torch.cuda.device_count()}")

if pad_between_seqs:
if qkv_format != "thd":
pytest.skip("pad_between_seqs only applies to THD format!")
if not FlashAttentionUtils.v3_is_installed:
pytest.skip("pad_between_seqs with CP requires Flash Attention v3!")
if cp_comm_type == "a2a+p2p":
pytest.skip("pad_between_seqs is not yet supported with A2A+P2P CP comm type!")
    # NOTE(review, open question): should the AG (all-gather) CP comm type
    # also be skipped for pad_between_seqs, like "a2a+p2p" is above? — confirm
config = model_configs_flash_attn[model]
config.context_parallel = True
config.cp_comm_type = cp_comm_type
Expand Down Expand Up @@ -148,6 +157,7 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type):
qkv_format=qkv_format,
kernel_backend="FlashAttention",
cp_comm_type=cp_comm_type,
fa_pad_between_seqs=pad_between_seqs,
log_level=pytest_logging_level,
),
)
Expand Down Expand Up @@ -386,6 +396,7 @@ def test_cp_with_fused_attention(
is_training=is_training,
deterministic=_deterministic,
)

_, fused_attn_supported, _ = available_backends
if fused_attn_supported and config.attn_mask_type in ["causal", "padding_causal"]:
config_copy = copy.deepcopy(config)
Expand All @@ -404,6 +415,15 @@ def test_cp_with_fused_attention(
if not fused_attn_supported:
pytest.skip("No attention backend available.")

deterministic = not bool(int(os.getenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1")))
if deterministic:
if config.softmax_type != "vanilla":
pytest.skip(
"Deterministic mode does not support non-vanilla softmax with FusedAttention"
)
if config.attn_bias_type == "post_scale_bias" and is_training:
pytest.skip("Deterministic mode does not support post_scale_bias with requires_grad")

run_distributed(
get_bash_arguments(
num_gpus_per_node=num_gpus,
Expand Down
Loading
Loading