Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
47 commits
Select commit Hold shift + click to select a range
f02c48e
Initial inclusion of new API in fwd as well as part 1 of refactor
Micky774 Feb 2, 2026
0b0ad93
Initial implementation of refactor/API update across ALL CK funcs
Micky774 Feb 3, 2026
c198cbd
Updated logging
Micky774 Feb 6, 2026
a52bb32
Add script for comparing AITER/TE API
Micky774 Feb 6, 2026
77f0a05
Reconcile new AITER mask type
Micky774 Feb 9, 2026
1637266
Updated API helper tool
Micky774 Feb 9, 2026
568e9b5
Merge branch 'dev' into zain/aiter-api
Micky774 Feb 9, 2026
2cb6d82
Formatting
Micky774 Feb 9, 2026
cf4aa9e
Added sys exit to script
Micky774 Feb 9, 2026
e25cea8
Slightly better error message
Micky774 Feb 9, 2026
2122479
Updated AITER_ASM_DIR implementation
Micky774 Feb 11, 2026
4817e72
Update AITER
Micky774 Feb 11, 2026
837b827
Updated AITER_ASM_DIR logic to allow for hip-free use
Micky774 Feb 12, 2026
68ca0fe
Re-introduce setup AITER API check
Micky774 Feb 16, 2026
ae688ab
Update AITER to custom feature branch
Micky774 Feb 16, 2026
762b91b
Reduce AITER build verbosity
Micky774 Feb 16, 2026
0a7187d
Updated API
Micky774 Feb 16, 2026
39b27bc
Address PR comments
Micky774 Feb 17, 2026
29878cf
Updated bias stride calculations
Micky774 Feb 17, 2026
6846a27
Merge branch 'dev' into zain/aiter-api
Micky774 Feb 18, 2026
47592ac
Reverted AITER feature branch use due to verbosity changes
Micky774 Feb 18, 2026
357b5ce
PR review comments
Micky774 Feb 18, 2026
1f080c1
Reintroduced warning suppression in AITER
Micky774 Feb 18, 2026
a657bdd
Removes auto-setting of AITER_LOG_MORE, corrects batch stride impl
Micky774 Feb 18, 2026
c225448
Removes AITER_LOG_MORE from CI runs
Micky774 Feb 18, 2026
4193158
Minor corrections
Micky774 Feb 18, 2026
dbb6106
PR feedback
Micky774 Feb 19, 2026
899162e
Formatting
Micky774 Feb 19, 2026
1081c5e
Copyright
Micky774 Feb 19, 2026
9514855
Merge branch 'dev' into zain/aiter-api
Micky774 Feb 19, 2026
78f1d69
Updated ck_fused_attn lib build to include copying HSA
Micky774 Feb 20, 2026
f935956
Corrected AITER bug and moved to TE feature branch
Micky774 Mar 3, 2026
b90da33
Merge branch 'dev' into zain/aiter-api
Micky774 Mar 4, 2026
0475f85
Added back dropped code from merge conflict
Micky774 Mar 4, 2026
5db08ea
Downgrade to more conservative AITER commit for compat
Micky774 Mar 4, 2026
d5e5ec6
Removed python-level args check
Micky774 Mar 4, 2026
151e9ca
Removed old tools
Micky774 Mar 4, 2026
5a18d16
Corrected arg_size types manually
Micky774 Mar 5, 2026
db34177
Updated AITER commit and fixed API mismatch in group gemm
Micky774 Mar 5, 2026
9cd2833
Added build-time AITER API usage check
Micky774 Mar 5, 2026
a6b831e
PR review comments
Micky774 Mar 5, 2026
d1bd569
Undo extra import removal
Micky774 Mar 5, 2026
686e6a2
Adjusted python requirement in cmakelist
Micky774 Mar 5, 2026
52e1a1a
Updated group-gemm dispatch
Micky774 Mar 6, 2026
48c5839
Made AITER API check earlier
Micky774 Mar 6, 2026
9b2166e
Update AITER w/ Xinya's patch
Micky774 Mar 16, 2026
3dacb2a
Merge branch 'dev' into zain/aiter-api
Micky774 Mar 16, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion 3rdparty/aiter
Submodule aiter updated 1369 files
34 changes: 34 additions & 0 deletions transformer_engine/common/ck_fused_attn/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,30 @@ set(AITER_MHA_INSTALL_PREFIX "transformer_engine" CACHE STRING "aiter mha shared
# Root of the AITER checkout: 3rdparty/aiter, three directory levels above this list file.
set(__AITER_SOURCE_DIR "${CMAKE_CURRENT_LIST_DIR}/../../../3rdparty/aiter")
# Composable Kernel sources, taken from AITER's own 3rdparty directory.
set(__CK_SOURCE_DIR "${__AITER_SOURCE_DIR}/3rdparty/composable_kernel")

# Configure-time guard: run check_aiter_mha_args.py against this source tree.
# A non-zero exit status aborts configuration with the script's captured output;
# when no Python interpreter is available the check is skipped with a warning.
if(NOT Python_EXECUTABLE)
  find_package(Python COMPONENTS Interpreter QUIET)
endif()

if(NOT Python_EXECUTABLE)
  message(WARNING "Python interpreter not found; skipping AITER API validation.")
else()
  execute_process(
    COMMAND
      ${Python_EXECUTABLE}
      ${CMAKE_CURRENT_LIST_DIR}/check_aiter_mha_args.py
      --mode both
      --te-dir "${CMAKE_CURRENT_LIST_DIR}/../../.."
    RESULT_VARIABLE AITER_ARG_CHECK_RESULT
    OUTPUT_VARIABLE AITER_ARG_CHECK_OUTPUT
    ERROR_VARIABLE AITER_ARG_CHECK_ERROR
    OUTPUT_STRIP_TRAILING_WHITESPACE
    ERROR_STRIP_TRAILING_WHITESPACE
  )

  if(AITER_ARG_CHECK_RESULT EQUAL 0)
    message(STATUS "AITER API validation passed via check_aiter_mha_args.py")
  else()
    # Both captured streams are surfaced so the mismatch report is visible in the configure log.
    message(FATAL_ERROR
      "AITER API validation failed in check_aiter_mha_args.py.\n"
      "${AITER_ARG_CHECK_OUTPUT}\n${AITER_ARG_CHECK_ERROR}")
  endif()
endif()

# Only gfx942 and gfx950 currently have v3 kernels.
set(V3_ASM_ARCHS_SUPPORTED "gfx942;gfx950")

Expand Down Expand Up @@ -107,3 +131,13 @@ set_target_properties(ck_fused_attn PROPERTIES INSTALL_RPATH "$ORIGIN")

# Install the two AITER MHA shared libraries and the ck_fused_attn target
# into the same lib directory under the install prefix.
install(
  FILES
    ${__AITER_MHA_PATH}/libmha_fwd.so
    ${__AITER_MHA_PATH}/libmha_bwd.so
  DESTINATION ${CMAKE_INSTALL_PREFIX}/${AITER_MHA_INSTALL_PREFIX}/lib)
install(
  TARGETS ck_fused_attn
  DESTINATION ${CMAKE_INSTALL_PREFIX}/${AITER_MHA_INSTALL_PREFIX}/lib)
# Copy the v3 kernel directories (fwd and bwd) for every architecture in
# V3_ASM_ARCHS to <prefix>/<AITER_MHA_INSTALL_PREFIX>/lib/aiter/<arch>/.
#
# Any previously installed copy of a kernel directory is removed first
# (presumably so files dropped by a newer AITER revision do not linger).
# The removal is registered as an install rule rather than executed with a
# bare file() command: a bare file(REMOVE_RECURSE) runs while configuring,
# which would delete files under the install prefix even if nothing is ever
# installed, and would ignore DESTDIR. Install rules in one directory run in
# the order they are declared, so the removal precedes the matching copy.
foreach(ARCH IN LISTS V3_ASM_ARCHS)
  foreach(KERNEL_TYPE IN ITEMS fmha_v3_fwd fmha_v3_bwd)
    # \$ENV{DESTDIR} is escaped so it is read when the install script runs;
    # the remaining variables are expanded now, giving the same absolute
    # destination that install(DIRECTORY) below is handed.
    install(CODE
      "file(REMOVE_RECURSE \"\$ENV{DESTDIR}${CMAKE_INSTALL_PREFIX}/${AITER_MHA_INSTALL_PREFIX}/lib/aiter/${ARCH}/${KERNEL_TYPE}\")")
    install(DIRECTORY
      "${__AITER_SOURCE_DIR}/hsa/${ARCH}/${KERNEL_TYPE}"
      DESTINATION "${CMAKE_INSTALL_PREFIX}/${AITER_MHA_INSTALL_PREFIX}/lib/aiter/${ARCH}/"
      PATTERN "codegen.py" EXCLUDE)
  endforeach()
endforeach()
109 changes: 109 additions & 0 deletions transformer_engine/common/ck_fused_attn/check_aiter_mha_args.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved.
#
# See LICENSE for license information.


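"""
This script is run during setup: the ck_fused_attn CMakeLists.txt invokes it at
configure time. It can also be run independently. It checks that every field
defined in the mha_{fwd,bwd}_args structs in the AITER headers is assigned in
the corresponding ck_fused_attn source file, and that the source assigns no
field the headers do not define.
"""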

import argparse
import re
from pathlib import Path
from typing import List, Set
import sys

def parse_with_skip_comments(buffer, line, regex, outputs):
    """Feed one source line into a statement accumulator and collect matches.

    Lines are accumulated until one containing ``;`` arrives; the accumulated
    text is then scanned with ``regex`` and the accumulator is reset. This lets
    statements that span several lines be matched as a whole.

    Args:
        buffer: Single-element list whose item 0 holds the text accumulated so
            far. It is mutated in place (a list is used so the caller's state
            survives across calls).
        line: The next raw line of C/C++ source.
        regex: Compiled pattern with at least one capture group; group 1 of
            every match is recorded.
        outputs: List that receives the captured group-1 strings, in order.

    Only ``//`` comments are understood; ``/* ... */`` comments are not
    stripped.
    """
    stripped = line.strip()
    # Blank lines and whole-line ``//`` comments contribute nothing.
    if not stripped or stripped.startswith("//"):
        return
    # Remove a trailing ``//`` comment so a ';' inside it cannot end the statement.
    line_no_comment = re.sub(r"//.*", "", line)
    buffer[0] += " " + line_no_comment.strip()
    if ";" not in line_no_comment:
        # Statement continues on a following line.
        return
    # Collect every match, not only the first: several statements may share a
    # line (e.g. ``a.x = 1; a.y = 2;``) and each must be recorded.
    outputs.extend(found.group(1) for found in regex.finditer(buffer[0]))
    buffer[0] = ""


def extract_fields_from_header(text: str, struct_name: str) -> List[str]:
    """Return the member names declared in ``struct <struct_name>`` within ``text``.

    Scanning starts on the line after the first one that mentions
    ``struct <struct_name>`` and stops at the first line consisting solely of
    ``};``. Each complete statement in between contributes the identifier that
    immediately precedes its terminating ``;`` (or precedes an ``= default``
    initializer). An empty list is returned when the struct is not found.
    """
    opening_re = re.compile(rf"\bstruct\s+{re.escape(struct_name)}\b")
    closing_re = re.compile(r"^\s*};\s*$")
    member_re = re.compile(r"([A-Za-z_][A-Za-z0-9_]*)\s*(?:=[^;]*)?;\s*$")

    members: List[str] = []
    pending = [""]
    rows = iter(text.splitlines())

    # Phase 1: consume lines up to and including the struct's opening line.
    for row in rows:
        if opening_re.search(row):
            break

    # Phase 2: harvest members until the closing line (or the end of the text).
    for row in rows:
        if closing_re.search(row):
            break
        parse_with_skip_comments(pending, row, member_re, members)

    return members


def extract_usage_from_source(text: str, var_name: str) -> Set[str]:
    """Return the set of fields of ``var_name`` that ``text`` assigns to.

    A field counts as assigned when the source contains
    ``<var_name>.<field> =``. Comparisons (``<var_name>.<field> == ...``) are
    not assignments and are excluded.

    Args:
        text: Full contents of a C/C++ source file.
        var_name: Name of the struct variable whose member assignments are
            collected.

    Returns:
        The distinct field names found on the left-hand side of an assignment.
    """
    # ``(?!=)`` rejects ``==`` so an equality test is not mistaken for an assignment.
    assign_re = re.compile(
        rf"\b{re.escape(var_name)}\.([A-Za-z_][A-Za-z0-9_]*)\b\s*=(?!=)"
    )
    assignments = []
    buffer = [""]
    for line in text.splitlines():
        parse_with_skip_comments(buffer, line, assign_re, assignments)
    return set(assignments)


def main() -> int:
    """Compare AITER ``mha_{fwd,bwd}_args`` header fields with their use in TE.

    For each selected mode the AITER header ``mha_<mode>.h`` and the TE source
    ``ck_fused_attn_<mode>.cpp`` are read; the struct's declared fields are
    compared with the fields assigned on ``fmha_args`` in the source, and a
    report is printed.

    Command-line options:
        --mode    ``fwd``, ``bwd`` or ``both`` (default ``both``).
        --te-dir  Root directory of TransformerEngine (defaults to four levels
                  above this file).

    Returns:
        0 when every checked mode matches exactly; 1 when any field is
        missing/unknown or when an input file cannot be read.
    """
    parser = argparse.ArgumentParser(description="Check aiter args usage vs header definition")
    parser.add_argument("--mode", choices=["fwd", "bwd", "both"], default="both", help="Mode: fwd, bwd, or both")
    parser.add_argument("--te-dir", type=Path, default=Path(__file__).parent.parent.parent.parent, help="Root directory of TransformerEngine")
    args = parser.parse_args()
    modes = ["fwd", "bwd"] if args.mode == "both" else [args.mode]
    mismatch = 0
    unreadable = 0
    for mode in modes:
        header_path = args.te_dir / f"3rdparty/aiter/csrc/include/mha_{mode}.h"
        source_path = args.te_dir / f"transformer_engine/common/ck_fused_attn/src/ck_fused_attn_{mode}.cpp"
        try:
            header_text = header_path.read_text(encoding="utf-8")
            source_text = source_path.read_text(encoding="utf-8")
        except OSError as exc:
            # Report a missing/unreadable input plainly instead of a traceback;
            # the check still fails, so the exit status stays non-zero.
            print(
                f"\nCannot read input for mha_{mode}_args check: {exc}\n"
                "Is the 3rdparty/aiter submodule checked out and --te-dir correct?",
                file=sys.stderr,
            )
            unreadable += 1
            continue

        header_fields = extract_fields_from_header(header_text, f"mha_{mode}_args")
        header_set = set(header_fields)
        used_fields = extract_usage_from_source(source_text, "fmha_args")

        missing_in_usage = sorted(header_set - used_fields)
        unknown_in_header = sorted(used_fields - header_set)
        mismatch += len(missing_in_usage) + len(unknown_in_header)

        print(f"\nAnalyzing mha_{mode}_args\n")
        print(f"mha_{mode}_args fields in header:", len(header_set))
        print(f"mha_{mode}_args fields referenced in source:", len(used_fields))

        if missing_in_usage:
            print("\nFields present in header but not referenced in source:")
            for name in missing_in_usage:
                print(f" - {name}")
        else:
            print("\nAll header fields are referenced in source.")

        if unknown_in_header:
            print("\nFields referenced in source but not in header:")
            for name in unknown_in_header:
                print(f" - {name}")
        else:
            print("\nNo unknown fields referenced in source.")

    if mismatch:
        print(f"\nTotal mismatched fields: {mismatch}")
    if mismatch or unreadable:
        return 1
    return 0


# Script entry point: the integer returned by main() becomes the process exit status.
if __name__ == "__main__":
    raise SystemExit(main())
Loading