diff --git a/3rdparty/aiter b/3rdparty/aiter index a64fa18e6..a52d98bad 160000 --- a/3rdparty/aiter +++ b/3rdparty/aiter @@ -1 +1 @@ -Subproject commit a64fa18e60235994e4cbfd7059cc2f60d06e743f +Subproject commit a52d98bad74478202fb19e4c8cb065be0f1ec8c6 diff --git a/transformer_engine/common/ck_fused_attn/CMakeLists.txt b/transformer_engine/common/ck_fused_attn/CMakeLists.txt index 0327a9f85..b3ec4db73 100644 --- a/transformer_engine/common/ck_fused_attn/CMakeLists.txt +++ b/transformer_engine/common/ck_fused_attn/CMakeLists.txt @@ -11,6 +11,30 @@ set(AITER_MHA_INSTALL_PREFIX "transformer_engine" CACHE STRING "aiter mha shared set(__AITER_SOURCE_DIR "${CMAKE_CURRENT_LIST_DIR}/../../../3rdparty/aiter") set(__CK_SOURCE_DIR "${__AITER_SOURCE_DIR}/3rdparty/composable_kernel") +if(NOT Python_EXECUTABLE) + find_package(Python COMPONENTS Interpreter QUIET) +endif() + +if(Python_EXECUTABLE) + execute_process( + COMMAND ${Python_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/check_aiter_mha_args.py --mode both --te-dir "${CMAKE_CURRENT_LIST_DIR}/../../.." + RESULT_VARIABLE AITER_ARG_CHECK_RESULT + OUTPUT_VARIABLE AITER_ARG_CHECK_OUTPUT + ERROR_VARIABLE AITER_ARG_CHECK_ERROR + OUTPUT_STRIP_TRAILING_WHITESPACE + ERROR_STRIP_TRAILING_WHITESPACE + ) + + if(NOT AITER_ARG_CHECK_RESULT EQUAL 0) + message(FATAL_ERROR + "AITER API validation failed in check_aiter_mha_args.py.\n" + "${AITER_ARG_CHECK_OUTPUT}\n${AITER_ARG_CHECK_ERROR}") + endif() + message(STATUS "AITER API validation passed via check_aiter_mha_args.py") +else() + message(WARNING "Python interpreter not found; skipping AITER API validation.") +endif() + # so far, there are only gfx942 and gfx950 v3 kernels SET(V3_ASM_ARCHS_SUPPORTED "gfx942;gfx950") @@ -107,3 +131,13 @@ set_target_properties(ck_fused_attn PROPERTIES INSTALL_RPATH "$ORIGIN") install(FILES ${__AITER_MHA_PATH}/libmha_fwd.so ${__AITER_MHA_PATH}/libmha_bwd.so DESTINATION ${CMAKE_INSTALL_PREFIX}/${AITER_MHA_INSTALL_PREFIX}/lib) install(TARGETS ck_fused_attn DESTINATION ${CMAKE_INSTALL_PREFIX}/${AITER_MHA_INSTALL_PREFIX}/lib) +# copy v3 kernels to destination +foreach(ARCH IN LISTS V3_ASM_ARCHS) + foreach(KERNEL_TYPE IN ITEMS fmha_v3_fwd fmha_v3_bwd) + file(REMOVE_RECURSE ${CMAKE_INSTALL_PREFIX}/${AITER_MHA_INSTALL_PREFIX}/lib/aiter/${ARCH}/${KERNEL_TYPE}) + install(DIRECTORY + ${__AITER_SOURCE_DIR}/hsa/${ARCH}/${KERNEL_TYPE} + DESTINATION ${CMAKE_INSTALL_PREFIX}/${AITER_MHA_INSTALL_PREFIX}/lib/aiter/${ARCH}/ + PATTERN "codegen.py" EXCLUDE) + endforeach() +endforeach() diff --git a/transformer_engine/common/ck_fused_attn/check_aiter_mha_args.py b/transformer_engine/common/ck_fused_attn/check_aiter_mha_args.py new file mode 100644 index 000000000..2e9831f1a --- /dev/null +++ b/transformer_engine/common/ck_fused_attn/check_aiter_mha_args.py @@ -0,0 +1,109 @@ +# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. +# +# See LICENSE for license information. + + +""" +This script is run during setup through setup.py, and can be run independently +to check that the fields defined in the mha_{fwd,bwd}_args structs in the AITER +headers are correctly referenced in the source code. +""" + +import argparse +import re +from pathlib import Path +from typing import List, Set +import sys + +def parse_with_skip_comments(buffer, line, regex, outputs): + # skip comments + stripped = line.strip() + if not stripped or stripped.startswith("//"): + return + line_no_comment = re.sub(r"//.*", "", line) + buffer[0] += " " + line_no_comment.strip() + if ";" not in line_no_comment: + return + match = regex.search(buffer[0]) + if match: + outputs.append(match.group(1)) + buffer[0] = "" + + +def extract_fields_from_header(text: str, struct_name: str) -> List[str]: + struct_field_re = re.compile(r"([A-Za-z_][A-Za-z0-9_]*)\s*(?:=[^;]*)?;\s*$") + struct_end_re = re.compile(r"^\s*};\s*$") + + struct_start_re = re.compile(rf"\bstruct\s+{re.escape(struct_name)}\b") + lines = text.splitlines() + in_struct = False + fields: List[str] = [] + buffer = [""] + for line in lines: + if not in_struct: + if struct_start_re.search(line): + in_struct = True + continue + if struct_end_re.search(line): + break + parse_with_skip_comments(buffer, line, struct_field_re, fields) + return fields + + +def extract_usage_from_source(text: str, var_name: str) -> Set[str]: + assign_re = re.compile(rf"\b{re.escape(var_name)}\.([A-Za-z_][A-Za-z0-9_]*)\b\s*=") + assignments = [] + lines = text.splitlines() + buffer = [""] + for line in lines: + parse_with_skip_comments(buffer, line, assign_re, assignments) + return set(assignments) + + +def main() -> int: + parser = argparse.ArgumentParser(description="Check aiter args usage vs header definition") + parser.add_argument("--mode", choices=["fwd", "bwd", "both"], default="both", help="Mode: fwd, bwd, or both") + parser.add_argument("--te-dir", type=Path, default=Path(__file__).parent.parent.parent.parent, help="Root directory of TransformerEngine") + args = parser.parse_args() + modes = ["fwd", "bwd"] if args.mode == "both" else [args.mode] + mismatch = 0 + for mode in modes: + header_path = args.te_dir / f"3rdparty/aiter/csrc/include/mha_{mode}.h" + source_path = args.te_dir / f"transformer_engine/common/ck_fused_attn/src/ck_fused_attn_{mode}.cpp" + header_text = header_path.read_text(encoding="utf-8") + source_text = source_path.read_text(encoding="utf-8") + + header_fields = extract_fields_from_header(header_text, f"mha_{mode}_args") + header_set = set(header_fields) + used_fields = extract_usage_from_source(source_text, f"fmha_args") + + missing_in_usage = sorted(header_set - used_fields) + unknown_in_header = sorted(used_fields - header_set) + mismatch += len(missing_in_usage) + len(unknown_in_header) + + print(f"\nAnalyzing mha_{mode}_args\n") + print(f"mha_{mode}_args fields in header:", len(header_set)) + print(f"mha_{mode}_args fields referenced in source:", len(used_fields)) + + if missing_in_usage: + print("\nFields present in header but not referenced in source:") + for name in missing_in_usage: + print(f" - {name}") + else: + print("\nAll header fields are referenced in source.") + + if unknown_in_header: + print("\nFields referenced in source but not in header:") + for name in unknown_in_header: + print(f" - {name}") + else: + print("\nNo unknown fields referenced in source.") + + if mismatch: + print(f"\nTotal mismatched fields: {mismatch}") + return 1 + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp index c349b9681..632baa142 100644 --- a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp +++ b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp @@ -9,7 +9,6 @@ #include #include #include "ck_fused_attn/ck_fused_attn.hpp" -#include "ck_tile/host.hpp" #include "mha_bwd.h" #include "ck_fused_attn_utils.hpp" @@ -44,12 +43,12 @@ __global__ void dk_dv_reduce( DataType *dv, //k,v, dk, dv guaranteed to have the same stride uint64_t stride_b_dkv, uint64_t stride_h_dkv, uint64_t stride_s_dkv){ - + uint64_t batch_idx = blockIdx.x; uint64_t seqlen_idx = blockIdx.y; uint64_t head_k_idx = blockIdx.z; uint64_t hdim_idx = threadIdx.x; - + // h guaranteed to be multiples of hg uint64_t head_idx_offset = h / hg; @@ -59,7 +58,7 @@ __global__ void dk_dv_reduce( assert(hdim_idx){ @@ -91,12 +90,12 @@ __global__ void dk_or_dv_reduce( DataType *dk_or_dv, //k,v, dk, dv guaranteed to have the same stride uint64_t stride_b_dk_or_dv, uint64_t stride_h_dk_or_dv, uint64_t stride_s_dk_or_dv){ - + uint64_t batch_idx = blockIdx.x; uint64_t seqlen_idx = blockIdx.y; uint64_t head_k_or_v_idx = blockIdx.z; uint64_t hdim_idx = threadIdx.x; - + // h guaranteed to be multiples of hg uint64_t head_idx_offset = h / hg; @@ -105,7 +104,7 @@ __global__ void dk_or_dv_reduce( assert(hdim_idx){ @@ -141,9 +140,9 @@ __global__ void dk_dv_reduce_thd( uint64_t seqlen_idx = blockIdx.x; uint64_t head_k_idx = blockIdx.y; uint64_t hdim_idx = threadIdx.x; - + assert(hdim_idx= *((cu_seqlen_kv_padded_ptr? cu_seqlen_kv_padded_ptr: cu_seqlen_kv_ptr)+b)){ return; } @@ -163,7 +162,7 @@ __global__ void dk_dv_reduce_thd( uint64_t read_idx = head_k_idx*head_idx_offset*stride_h_dkv_expanded + seqlen_idx*stride_s_dkv_expanded + hdim_idx; uint64_t write_idx = head_k_idx*stride_h_dkv + seqlen_idx* stride_s_dkv + hdim_idx; - + for(uint64_t ii = 0; ii < head_idx_offset; ii++){ // bf16 requires special casting in CK if constexpr (std::is_same_v){ @@ -201,7 +200,7 @@ __global__ void dk_or_dv_reduce_thd( uint64_t seqlen_idx = blockIdx.x; uint64_t head_k_or_v_idx = blockIdx.y; uint64_t hdim_idx = threadIdx.x; - + assert(hdim_idx= *((cu_seqlen_kv_padded_ptr? cu_seqlen_kv_padded_ptr: cu_seqlen_kv_ptr)+b)){ @@ -221,7 +220,7 @@ __global__ void dk_or_dv_reduce_thd( uint64_t read_idx = head_k_or_v_idx*head_idx_offset*stride_h_dk_or_dv_expanded + seqlen_idx*stride_s_dk_or_dv_expanded + hdim_idx; uint64_t write_idx = head_k_or_v_idx*stride_h_dk_or_dv + seqlen_idx* stride_s_dk_or_dv + hdim_idx; - + for(uint64_t ii = 0; ii < head_idx_offset; ii++){ // bf16 requires special casting in CK if constexpr (std::is_same_v){ @@ -247,7 +246,7 @@ __global__ void dbias_reduce_11ss( uint64_t b, uint64_t h, uint64_t s_q, uint64_t s_kv, const DataType *dbias_expanded, DataType *dbias){ - + const uint64_t stride_h = s_q*s_kv; const uint64_t stride_b = h*s_q*s_kv; for(uint64_t ss_idx = blockIdx.x*blockDim.x + threadIdx.x; ss_idx < s_q*s_kv; ss_idx += blockDim.x * gridDim.x){ @@ -277,7 +276,7 @@ __global__ void dbias_reduce_1hss( uint64_t b, uint64_t h, uint64_t s_q, uint64_t s_kv, const DataType *dbias_expanded, DataType *dbias){ - + const uint64_t stride_h = s_q*s_kv; const uint64_t stride_b = h*s_q*s_kv; for(uint64_t ss_idx = blockIdx.x*blockDim.x + threadIdx.x; ss_idx < s_q*s_kv; ss_idx += blockDim.x * gridDim.x){ @@ -307,7 +306,7 @@ __global__ void dbias_reduce_b1ss( uint64_t b, uint64_t h, uint64_t s_q, uint64_t s_kv, const DataType *dbias_expanded, DataType *dbias){ - + const uint64_t stride_h = s_q*s_kv; const uint64_t stride_b = h*s_q*s_kv; for(uint64_t ss_idx = blockIdx.x*blockDim.x + threadIdx.x; ss_idx < s_q*s_kv; ss_idx += blockDim.x * gridDim.x){ @@ -332,119 +331,108 @@ __global__ void dbias_reduce_b1ss( } // print the fmha_traits and args passed into ck apis -void log_bwd_config( - std::ostream* log_file, - const char* func_name, - const std::string data_type_str, - const bool is_group_mode, - const mask_enum mask_type, - const bias_enum bias_type, - const bool has_dbias, - const bool has_dropout, - const bool is_store_randval, - const bool is_deterministic, - const bool uses_bwd_v3, - const bool is_v3_atomic_fp32, - const int how_v3_bf16_cvt, - const fmha_bwd_args& fmha_args -){ - *log_file << "\n" << func_name << "\n"; +void log_bwd_config(const char* func_name, const aiter::mha_bwd_args& fmha_args){ + + std::ostream* log_file = get_ck_log_stream(); + (*log_file) << "\n" << func_name << "\n"; // fmha_traits debug - *log_file << "\n" << "fmha_traits: " << "\n"; - *log_file << "hdim_q: " << fmha_args.hdim_q << "\n"; - *log_file << "hdim_v: " << fmha_args.hdim_v << "\n"; - *log_file << "data_type: " << data_type_str << "\n"; - *log_file << "is_group_mode: " << is_group_mode << "\n"; - *log_file << "mask_type: " << static_cast::type>(mask_type) << "\n"; - *log_file << "bias_type: " << static_cast::type>(bias_type) << "\n"; - *log_file << "has_dbias: " << has_dbias << "\n"; - *log_file << "has_dropout: " << has_dropout << "\n"; - *log_file << "is_store_randval: " << is_store_randval << "\n"; - *log_file << "is_deterministic: " << is_deterministic << "\n"; - *log_file << "uses_bwd_v3: " << uses_bwd_v3 << "\n"; - *log_file << "is_v3_atomic_fp32: " << is_v3_atomic_fp32 << "\n"; - *log_file << "how_v3_bf16_cvt: " << how_v3_bf16_cvt << "\n"; + (*log_file) << "\nfmha_traits: \n"; + log_value(log_file, "hdim_q", fmha_args.hdim_q); + log_value(log_file, "hdim_v", fmha_args.hdim_v); + log_value(log_file, "data_type", fmha_args.data_type); + log_value(log_file, "is_group_mode", fmha_args.is_group_mode); + log_value(log_file, "has_dbias", fmha_args.has_dbias); + log_value(log_file, "has_dropout", fmha_args.has_dropout); + log_value(log_file, "is_store_randval", fmha_args.is_store_randval); + log_value(log_file, "is_deterministic", fmha_args.is_deterministic); + log_value(log_file, "use_asm_v3", fmha_args.use_asm_v3); + log_value(log_file, "v3_atomic_fp32", fmha_args.v3_atomic_fp32); + log_value(log_file, "v3_bf16_cvt", fmha_args.v3_bf16_cvt); // fmha_args debug - *log_file << "\n" << "fmha_args: " << "\n"; - *log_file << "q_ptr: " << fmha_args.q_ptr << "\n"; - *log_file << "k_ptr: " << fmha_args.k_ptr << "\n"; - *log_file << "v_ptr: " << fmha_args.v_ptr << "\n"; - *log_file << "bias_ptr: " << fmha_args.bias_ptr << "\n"; - *log_file << "o_ptr: " << fmha_args.o_ptr << "\n"; - *log_file << "lse_ptr: " << fmha_args.lse_ptr << "\n"; - *log_file << "do_ptr: " << fmha_args.do_ptr << "\n"; - *log_file << "d_ptr: " << fmha_args.d_ptr << "\n"; - *log_file << "rand_val_ptr: " << fmha_args.rand_val_ptr << "\n"; - *log_file << "dq_ptr: " << fmha_args.dq_ptr << "\n"; - *log_file << "dk_ptr: " << fmha_args.dk_ptr << "\n"; - *log_file << "dv_ptr: " << fmha_args.dv_ptr << "\n"; - *log_file << "dbias_ptr: " << fmha_args.dbias_ptr << "\n"; - *log_file << "dq_acc_ptr: " << fmha_args.dq_acc_ptr << "\n"; - - *log_file << "seqstart_q_ptr: " << fmha_args.seqstart_q_ptr << "\n"; - *log_file << "seqstart_k_ptr: " << fmha_args.seqstart_k_ptr << "\n"; - *log_file << "seqlen_q_ptr: " << fmha_args.seqlen_q_ptr << "\n"; - *log_file << "seqlen_k_ptr: " << fmha_args.seqlen_k_ptr << "\n"; - *log_file << "cu_seqlen_q_ptr: " << fmha_args.cu_seqlen_q_ptr << "\n"; - *log_file << "cu_seqlen_k_ptr: " << fmha_args.cu_seqlen_k_ptr << "\n"; - - *log_file << "seqlen_q: " << fmha_args.seqlen_q << "\n"; - *log_file << "seqlen_k: " << fmha_args.seqlen_k << "\n"; - *log_file << "batch: " << fmha_args.batch << "\n"; - *log_file << "max_seqlen_q: " << fmha_args.max_seqlen_q << "\n"; - *log_file << "max_seqlen_k: " << fmha_args.max_seqlen_k << "\n"; - *log_file << "hdim_q: " << fmha_args.hdim_q << "\n"; - *log_file << "hdim_v: " << fmha_args.hdim_v << "\n"; - *log_file << "nhead_q: " << fmha_args.nhead_q << "\n"; - *log_file << "nhead_k: " << fmha_args.nhead_k << "\n"; - *log_file << "scale: " << fmha_args.scale << "\n"; - *log_file << "stride_q: " << fmha_args.stride_q << "\n"; - *log_file << "stride_k: " << fmha_args.stride_k << "\n"; - *log_file << "stride_v: " << fmha_args.stride_v << "\n"; - *log_file << "stride_bias: " << fmha_args.stride_bias << "\n"; - *log_file << "stride_o: " << fmha_args.stride_o << "\n"; - *log_file << "stride_randval: " << fmha_args.stride_randval << "\n"; - *log_file << "stride_do: " << fmha_args.stride_do << "\n"; - *log_file << "stride_dq_acc: " << fmha_args.stride_dq_acc << "\n"; - *log_file << "stride_dq: " << fmha_args.stride_dq << "\n"; - *log_file << "stride_dk: " << fmha_args.stride_dk << "\n"; - *log_file << "stride_dv: " << fmha_args.stride_dv << "\n"; - *log_file << "stride_dbias: " << fmha_args.stride_dbias << "\n"; - *log_file << "nhead_stride_q: " << fmha_args.nhead_stride_q << "\n"; - *log_file << "nhead_stride_k: " << fmha_args.nhead_stride_k << "\n"; - *log_file << "nhead_stride_v: " << fmha_args.nhead_stride_v << "\n"; - *log_file << "nhead_stride_bias: " << fmha_args.nhead_stride_bias << "\n"; - *log_file << "nhead_stride_o: " << fmha_args.nhead_stride_o << "\n"; - *log_file << "nhead_stride_randval: " << fmha_args.nhead_stride_randval << "\n"; - *log_file << "nhead_stride_do: " << fmha_args.nhead_stride_do << "\n"; - *log_file << "nhead_stride_lsed: " << fmha_args.nhead_stride_lsed << "\n"; - *log_file << "nhead_stride_dq_acc: " << fmha_args.nhead_stride_dq_acc << "\n"; - *log_file << "nhead_stride_dq: " << fmha_args.nhead_stride_dq << "\n"; - *log_file << "nhead_stride_dk: " << fmha_args.nhead_stride_dk << "\n"; - *log_file << "nhead_stride_dv: " << fmha_args.nhead_stride_dv << "\n"; - *log_file << "nhead_stride_dbias: " << fmha_args.nhead_stride_dbias << "\n"; - *log_file << "batch_stride_q: " << fmha_args.batch_stride_q << "\n"; - *log_file << "batch_stride_k: " << fmha_args.batch_stride_k << "\n"; - *log_file << "batch_stride_v: " << fmha_args.batch_stride_v << "\n"; - *log_file << "batch_stride_bias: " << fmha_args.batch_stride_bias << "\n"; - *log_file << "batch_stride_o: " << fmha_args.batch_stride_o << "\n"; - *log_file << "batch_stride_randval: " << fmha_args.batch_stride_randval << "\n"; - *log_file << "batch_stride_do: " << fmha_args.batch_stride_do << "\n"; - *log_file << "batch_stride_lsed: " << fmha_args.batch_stride_lsed << "\n"; - *log_file << "batch_stride_dq_acc: " << fmha_args.batch_stride_dq_acc << "\n"; - *log_file << "batch_stride_dq: " << fmha_args.batch_stride_dq << "\n"; - *log_file << "batch_stride_dk: " << fmha_args.batch_stride_dk << "\n"; - *log_file << "batch_stride_dv: " << fmha_args.batch_stride_dv << "\n"; - *log_file << "batch_stride_dbias: " << fmha_args.batch_stride_dbias << "\n"; - *log_file << "window_size_left: " << fmha_args.window_size_left << "\n"; - *log_file << "window_size_right: " << fmha_args.window_size_right << "\n"; - *log_file << "mask_type: " << fmha_args.mask_type << "\n"; - *log_file << "p_drop: " << fmha_args.p_drop << "\n"; - *log_file << "p_undrop: " << fmha_args.p_undrop << "\n"; - *log_file << "dropout_seed_ptr: " << std::get<0>(std::get>(fmha_args.drop_seed_offset)) << "\n"; - *log_file << "dropout_offset_ptr: " << std::get<1>(std::get>(fmha_args.drop_seed_offset)) << "\n"; + (*log_file) << "\nfmha_args: \n"; + log_value(log_file, "q_ptr", fmha_args.q_ptr); + log_value(log_file, "k_ptr", fmha_args.k_ptr); + log_value(log_file, "v_ptr", fmha_args.v_ptr); + log_value(log_file, "bias_ptr", fmha_args.bias_ptr); + log_value(log_file, "o_ptr", fmha_args.o_ptr); + log_value(log_file, "lse_ptr", fmha_args.lse_ptr); + log_value(log_file, "do_ptr", fmha_args.do_ptr); + log_value(log_file, "d_ptr", fmha_args.d_ptr); + log_value(log_file, "rand_val_ptr", fmha_args.rand_val_ptr); + log_value(log_file, "dq_ptr", fmha_args.dq_ptr); + log_value(log_file, "dk_ptr", fmha_args.dk_ptr); + log_value(log_file, "dv_ptr", fmha_args.dv_ptr); + log_value(log_file, "dbias_ptr", fmha_args.dbias_ptr); + log_value(log_file, "dq_acc_ptr", fmha_args.dq_acc_ptr); + + log_value(log_file, "seqstart_q_ptr", fmha_args.seqstart_q_ptr); + log_value(log_file, "seqstart_k_ptr", fmha_args.seqstart_k_ptr); + log_value(log_file, "seqlen_q_ptr", fmha_args.seqlen_q_ptr); + log_value(log_file, "seqlen_k_ptr", fmha_args.seqlen_k_ptr); + log_value(log_file, "cu_seqlen_q_ptr", fmha_args.cu_seqlen_q_ptr); + log_value(log_file, "cu_seqlen_k_ptr", fmha_args.cu_seqlen_k_ptr); + log_value(log_file, "seqlen_q", fmha_args.seqlen_q); + log_value(log_file, "seqlen_k", fmha_args.seqlen_k); + log_value(log_file, "batch", fmha_args.batch); + log_value(log_file, "max_seqlen_q", fmha_args.max_seqlen_q); + log_value(log_file, "max_seqlen_k", fmha_args.max_seqlen_k); + log_value(log_file, "hdim_q", fmha_args.hdim_q); + log_value(log_file, "hdim_v", fmha_args.hdim_v); + log_value(log_file, "nhead_q", fmha_args.nhead_q); + log_value(log_file, "nhead_k", fmha_args.nhead_k); + log_value(log_file, "scale", fmha_args.scale); + log_value(log_file, "stride_q", fmha_args.stride_q); + log_value(log_file, "stride_k", fmha_args.stride_k); + log_value(log_file, "stride_v", fmha_args.stride_v); + log_value(log_file, "stride_bias", fmha_args.stride_bias); + log_value(log_file, "stride_o", fmha_args.stride_o); + log_value(log_file, "stride_randval", fmha_args.stride_randval); + log_value(log_file, "stride_do", fmha_args.stride_do); + log_value(log_file, "stride_dq_acc", fmha_args.stride_dq_acc); + log_value(log_file, "stride_dq", fmha_args.stride_dq); + log_value(log_file, "stride_dk", fmha_args.stride_dk); + log_value(log_file, "stride_dv", fmha_args.stride_dv); + log_value(log_file, "stride_dbias", fmha_args.stride_dbias); + log_value(log_file, "nhead_stride_q", fmha_args.nhead_stride_q); + log_value(log_file, "nhead_stride_k", fmha_args.nhead_stride_k); + log_value(log_file, "nhead_stride_v", fmha_args.nhead_stride_v); + log_value(log_file, "nhead_stride_bias", fmha_args.nhead_stride_bias); + log_value(log_file, "nhead_stride_o", fmha_args.nhead_stride_o); + log_value(log_file, "nhead_stride_randval", fmha_args.nhead_stride_randval); + log_value(log_file, "nhead_stride_do", fmha_args.nhead_stride_do); + log_value(log_file, "nhead_stride_lsed", fmha_args.nhead_stride_lsed); + log_value(log_file, "nhead_stride_dq_acc", fmha_args.nhead_stride_dq_acc); + log_value(log_file, "nhead_stride_dq", fmha_args.nhead_stride_dq); + log_value(log_file, "nhead_stride_dk", fmha_args.nhead_stride_dk); + log_value(log_file, "nhead_stride_dv", fmha_args.nhead_stride_dv); + log_value(log_file, "nhead_stride_dbias", fmha_args.nhead_stride_dbias); + log_value(log_file, "batch_stride_q", fmha_args.batch_stride_q); + log_value(log_file, "batch_stride_k", fmha_args.batch_stride_k); + log_value(log_file, "batch_stride_v", fmha_args.batch_stride_v); + log_value(log_file, "batch_stride_bias", fmha_args.batch_stride_bias); + log_value(log_file, "batch_stride_o", fmha_args.batch_stride_o); + log_value(log_file, "batch_stride_randval", fmha_args.batch_stride_randval); + log_value(log_file, "batch_stride_do", fmha_args.batch_stride_do); + log_value(log_file, "batch_stride_lsed", fmha_args.batch_stride_lsed); + log_value(log_file, "batch_stride_dq_acc", fmha_args.batch_stride_dq_acc); + log_value(log_file, "batch_stride_dq", fmha_args.batch_stride_dq); + log_value(log_file, "batch_stride_dk", fmha_args.batch_stride_dk); + log_value(log_file, "batch_stride_dv", fmha_args.batch_stride_dv); + log_value(log_file, "batch_stride_dbias", fmha_args.batch_stride_dbias); + log_value(log_file, "window_size_left", fmha_args.window_size_left); + log_value(log_file, "window_size_right", fmha_args.window_size_right); + log_value(log_file, "mask_type", fmha_args.mask_type); + log_value(log_file, "bias_type", fmha_args.bias_type); + log_value(log_file, "p_drop", fmha_args.p_drop); + log_value(log_file, "p_undrop", fmha_args.p_undrop); + log_value(log_file, "dropout_seed_ptr", + std::get<0>(std::get>(fmha_args.drop_seed_offset)) + ); + log_value(log_file, "dropout_offset_ptr", + std::get<1>(std::get>(fmha_args.drop_seed_offset)) + ); } void dump_bwd_timings(const char* dump_path, float average_runtime){ @@ -453,37 +441,42 @@ void dump_bwd_timings(const char* dump_path, float average_runtime){ file << average_runtime << "\n"; } -hipError_t ck_attn_bwd( +hipError_t _ck_attn_bwd_impl( DType dtype, - uint64_t b, uint64_t h, uint64_t hg, uint64_t s_q, uint64_t s_kv, uint64_t d_qk, uint64_t d_v, uint64_t bias_b, uint64_t bias_h, - const void* q_ptr, + uint64_t b, uint64_t h, uint64_t hg, uint64_t s_q, uint64_t s_kv, + uint64_t d_qk, uint64_t d_v, + uint64_t bias_b, uint64_t bias_h, + uint64_t max_tokens_q, uint64_t max_tokens_kv, + const void* q_ptr, uint64_t stride_b_q, uint64_t stride_h_q, uint64_t stride_s_q, - const void* k_ptr, + const void* k_ptr, uint64_t stride_b_k, uint64_t stride_h_k, uint64_t stride_s_k, - const void* v_ptr, + const void* v_ptr, uint64_t stride_b_v, uint64_t stride_h_v, uint64_t stride_s_v, const void* bias_ptr, const void* alibi_slope_ptr, - const void* o_ptr, + const void* cu_seqlen_q_ptr, const void* cu_seqlen_kv_ptr, + const void* cu_seqlen_q_padded_ptr, const void* cu_seqlen_kv_padded_ptr, + const void* o_ptr, uint64_t stride_b_o, uint64_t stride_h_o, uint64_t stride_s_o, - const void* lse_ptr, - const void* do_ptr, + const void* lse_ptr, + const void* do_ptr, uint64_t stride_b_do, uint64_t stride_h_do, uint64_t stride_s_do, float scaling_factor, float dropout_probability, void* philox_seed_ptr, void* philox_offset_ptr, BiasType attn_bias_type, MaskType attn_mask_type, int64_t window_size_left, int64_t window_size_right, - void* dq_ptr, + void* dq_ptr, uint64_t stride_b_dq, uint64_t stride_h_dq, uint64_t stride_s_dq, void* dq_acc_ptr, void* dk_expanded_ptr, void* dv_expanded_ptr, uint64_t stride_b_dk_expanded, uint64_t stride_h_dk_expanded, uint64_t stride_s_dk_expanded, uint64_t stride_b_dv_expanded, uint64_t stride_h_dv_expanded, uint64_t stride_s_dv_expanded, - void* dk_ptr, + void* dk_ptr, uint64_t stride_b_dk, uint64_t stride_h_dk, uint64_t stride_s_dk, - void* dv_ptr, + void* dv_ptr, uint64_t stride_b_dv, uint64_t stride_h_dv, uint64_t stride_s_dv, void* dbias_expanded_ptr, void* dbias_ptr, @@ -492,10 +485,13 @@ hipError_t ck_attn_bwd( bool uses_bwd_v3, bool is_v3_atomic_fp32, int how_v3_bf16_cvt, + bool is_group_mode, + const char* func_name, + bool ck_log_config, hipStream_t stream){ bool has_dropout = (dropout_probability > 0.f); - bool has_dbias = dbias_ptr!=nullptr; + bool has_dbias = dbias_ptr != nullptr; bool is_mqa_gqa = (h > hg); /* CK input parameters */ @@ -511,187 +507,274 @@ hipError_t ck_attn_bwd( float scale_s = scaling_factor; float p_drop = dropout_probability; float p_undrop = 1.0 - p_drop; - bool is_group_mode = false; bool s_randval = false; - bias_enum bias_type; - BiasShape bias_shape; - std::tie(bias_type, bias_shape) = get_ck_bias_type_shape(attn_bias_type, b, h, bias_b, bias_h); ck_tile::index_t left, right; left = window_size_left; right = window_size_right; - mask_enum mask_type = static_cast(attn_mask_type); const char* dump_path = std::getenv("NVTE_DUMP_AITER_RT"); // print kernel name on verbose mode ck_tile::stream_config stream_config{stream, dump_path!=nullptr, get_ck_log_stream() != nullptr}; - ck_tile::index_t shape_seqlen_q = seqlen_q; - ck_tile::index_t shape_seqlen_k = seqlen_k; - std::string data_type_str = get_data_type_str(dtype); - auto fmha_args = [&]() { - // setup stride_* arguments - const ck_tile::index_t stride_q = stride_s_q; - const ck_tile::index_t stride_k = stride_s_k; - const ck_tile::index_t stride_v = stride_s_v; - // bias of shape (bias_b, bias_h, s_q, s_kv) - const ck_tile::index_t stride_bias = max_seqlen_k; - const ck_tile::index_t stride_o = stride_s_o; - const ck_tile::index_t stride_randval = max_seqlen_k; - const ck_tile::index_t stride_do = stride_s_do; - const ck_tile::index_t stride_dq = stride_s_dq; - const ck_tile::index_t stride_dk = stride_s_dk; - const ck_tile::index_t stride_dv = stride_s_dv; - const ck_tile::index_t stride_dk_expanded = stride_s_dk_expanded; - const ck_tile::index_t stride_dv_expanded = stride_s_dv_expanded; - const ck_tile::index_t stride_dq_acc = d_qk; //dq_acc of shape (nsplits, B, H, S, D) - // dbias is of the same shape as bias - // but ck only take dbias with BHSS - const ck_tile::index_t stride_dbias = max_seqlen_k; - // setup nhead_stride_* arguments - const ck_tile::index_t nhead_stride_q = stride_h_q; - const ck_tile::index_t nhead_stride_k = stride_h_k; - const ck_tile::index_t nhead_stride_v = stride_h_v; - // bias input can be of different shapes (11SS, 1HSS, B1SS, and BHSS), but dbias must be of BHSS - const ck_tile::index_t nhead_stride_bias = (bias_shape==BiasShape::k1HSS || bias_shape==BiasShape::kBHSS) ? max_seqlen_q * max_seqlen_k: 0; - const ck_tile::index_t nhead_stride_o = stride_h_o; - const ck_tile::index_t nhead_stride_randval = - shape_seqlen_q * max_seqlen_k; - const ck_tile::index_t nhead_stride_do = stride_h_do; - const ck_tile::index_t nhead_stride_lsed = max_seqlen_q; - const ck_tile::index_t nhead_stride_dq = stride_h_dq; - const ck_tile::index_t nhead_stride_dk = stride_h_dk; - const ck_tile::index_t nhead_stride_dv = stride_h_dv; - const ck_tile::index_t nhead_stride_dk_expanded = stride_h_dk_expanded; - const ck_tile::index_t nhead_stride_dv_expanded = stride_h_dv_expanded; - // dbias can only be of BHSS - const ck_tile::index_t nhead_stride_dbias = max_seqlen_q * max_seqlen_k; - const ck_tile::index_t nhead_stride_dq_acc = s_q*d_qk; //dq_acc of shape (nsplits, B, H, S, D) - // setup batch_stride_* arguments - const ck_tile::index_t batch_stride_q = stride_b_q; - const ck_tile::index_t batch_stride_k = stride_b_k; - const ck_tile::index_t batch_stride_v = stride_b_v; - // bias input can be of different shapes (11SS, 1HSS, B1SS, and BHSS), but dbias must be of BHSS - // for B1SS and BHSS, batch stride for bias are both bias_h x s_q x s_kv (bias_h==1 for B1SS and bias_h == h for BHSS) - const ck_tile::index_t batch_stride_bias = (bias_shape==BiasShape::k11SS || bias_shape==BiasShape::k1HSS) ? 0: bias_h* max_seqlen_q * max_seqlen_k; - const ck_tile::index_t batch_stride_o = stride_b_o; - const ck_tile::index_t batch_stride_randval = - nhead * shape_seqlen_q * max_seqlen_k; - const ck_tile::index_t batch_stride_do = stride_b_do; - const ck_tile::index_t batch_stride_lsed = nhead * max_seqlen_q; - const ck_tile::index_t batch_stride_dq = stride_b_dq; - const ck_tile::index_t batch_stride_dk = stride_b_dk; - const ck_tile::index_t batch_stride_dv = stride_b_dv; - const ck_tile::index_t batch_stride_dk_expanded = stride_b_dk_expanded; - const ck_tile::index_t batch_stride_dv_expanded = stride_b_dv_expanded; - // for dbias, use h since h can be different from bias_h - const ck_tile::index_t batch_stride_dbias = h* max_seqlen_q * max_seqlen_k; - const ck_tile::index_t batch_stride_dq_acc = h*s_q*d_qk; //dq_acc of shape (nsplits, B, H, S, D) - const ck_tile::index_t split_stride_dq_acc = b * h * s_q * d_qk; - - return fmha_bwd_args{q_ptr, - k_ptr, - v_ptr, - bias_type==bias_enum::no_bias? nullptr : (bias_type==bias_enum::alibi? alibi_slope_ptr :bias_ptr), - o_ptr, - lse_ptr, - do_ptr, - lse_workspace_ptr, - nullptr, - dq_ptr, - is_mqa_gqa? dk_expanded_ptr:dk_ptr, - is_mqa_gqa? dv_expanded_ptr:dv_ptr, - has_dbias? (bias_shape==BiasShape::kBHSS ? dbias_ptr: dbias_expanded_ptr): nullptr, - dq_acc_ptr, //dq_acc_buf - nullptr,//seqstart_q_ptr - nullptr,//seqstart_k_ptr - nullptr, /* seqlen_q_ptr */ - nullptr, /* seqlen_k_ptr */ - nullptr, //cu_seqlen_q_ptr - nullptr, //cu_seqlen_k_ptr - shape_seqlen_q, - shape_seqlen_k, - batch, - max_seqlen_q, - max_seqlen_k, - hdim_q, - hdim_v, - nhead, - nhead_k, - scale_s, - stride_q, - stride_k, - stride_v, - bias_type==bias_enum::alibi? 0: stride_bias, - stride_o, - stride_randval, - stride_do, - stride_dq_acc,//stride_dq_acc - stride_dq,//stride_dq - is_mqa_gqa? stride_dk_expanded:stride_dk, - is_mqa_gqa? stride_dv_expanded:stride_dv, - stride_dbias, - nhead_stride_q, - nhead_stride_k, - nhead_stride_v, - nhead_stride_bias, - nhead_stride_o, - nhead_stride_randval, - nhead_stride_do, - nhead_stride_lsed, - nhead_stride_dq_acc, //nhead_stride_dq_acc - nhead_stride_dq, - is_mqa_gqa? nhead_stride_dk_expanded:nhead_stride_dk, - is_mqa_gqa? nhead_stride_dv_expanded:nhead_stride_dv, - nhead_stride_dbias, - batch_stride_q, - batch_stride_k, - batch_stride_v, - batch_stride_bias, - batch_stride_o, - batch_stride_randval, - batch_stride_do, - batch_stride_lsed, - batch_stride_dq_acc, //batch_stride_dq_acc - batch_stride_dq, - is_mqa_gqa? batch_stride_dk_expanded:batch_stride_dk, - is_mqa_gqa? batch_stride_dv_expanded:batch_stride_dv, - batch_stride_dbias, - split_stride_dq_acc, - left, - right, - static_cast(mask_type), - p_drop, - p_undrop, - std::pair{philox_seed_ptr, philox_offset_ptr}}; - }(); + bias_enum bias_type = bias_enum::no_bias; + BiasShape bias_shape = BiasShape::k11SS; + if (!is_group_mode) { + std::tie(bias_type, bias_shape) = get_ck_bias_type_shape(attn_bias_type, b, h, bias_b, bias_h); + } + + aiter::mha_bwd_args fmha_args{}; + fmha_args.mask_type = static_cast(mask_type); + fmha_args.use_asm_v3 = uses_bwd_v3; + fmha_args.v3_atomic_fp32 = is_v3_atomic_fp32; + fmha_args.v3_bf16_cvt = how_v3_bf16_cvt; + fmha_args.v3_api_check = false; - // print ck traits and args when needed - if (auto* log_file = get_ck_log_stream()) { - log_bwd_config(log_file, __FUNCTION__, data_type_str, is_group_mode, mask_type, bias_type, has_dbias, has_dropout, s_randval, deterministic, uses_bwd_v3, is_v3_atomic_fp32, how_v3_bf16_cvt, fmha_args); + fmha_args.hdim_q = hdim_q; + fmha_args.hdim_v = hdim_v; + fmha_args.data_type = data_type_str; + fmha_args.is_group_mode = is_group_mode; + fmha_args.bias_type = static_cast(bias_type); + fmha_args.has_dbias = (!is_group_mode) && has_dbias; + fmha_args.has_dropout = has_dropout; + fmha_args.is_store_randval = s_randval; + fmha_args.is_deterministic = deterministic; + + fmha_args.q_ptr = q_ptr; + fmha_args.k_ptr = k_ptr; + fmha_args.v_ptr = v_ptr; + fmha_args.bias_ptr = (bias_type==bias_enum::no_bias || is_group_mode) ? nullptr + : (bias_type==bias_enum::alibi? alibi_slope_ptr : bias_ptr); + fmha_args.o_ptr = o_ptr; + fmha_args.lse_ptr = lse_ptr; + fmha_args.do_ptr = do_ptr; + fmha_args.d_ptr = lse_workspace_ptr; + fmha_args.rand_val_ptr = nullptr; + fmha_args.dq_ptr = dq_ptr; + fmha_args.dk_ptr = is_mqa_gqa? dk_expanded_ptr:dk_ptr; + fmha_args.dv_ptr = is_mqa_gqa? dv_expanded_ptr:dv_ptr; + fmha_args.dbias_ptr = ((!is_group_mode) && has_dbias) + ? (bias_shape==BiasShape::kBHSS ? dbias_ptr: dbias_expanded_ptr) + : nullptr; + fmha_args.dq_acc_ptr = dq_acc_ptr; + + if (is_group_mode) { + fmha_args.seqstart_q_ptr = cu_seqlen_q_padded_ptr==nullptr? cu_seqlen_q_ptr: cu_seqlen_q_padded_ptr; + fmha_args.seqstart_k_ptr = cu_seqlen_kv_padded_ptr==nullptr? cu_seqlen_kv_ptr: cu_seqlen_kv_padded_ptr; + fmha_args.seqlen_q_ptr = nullptr; + fmha_args.seqlen_k_ptr = nullptr; + fmha_args.cu_seqlen_q_ptr = cu_seqlen_q_ptr; + fmha_args.cu_seqlen_k_ptr = cu_seqlen_kv_ptr; + } else { + fmha_args.seqstart_q_ptr = nullptr; + fmha_args.seqstart_k_ptr = nullptr; + fmha_args.seqlen_q_ptr = nullptr; + fmha_args.seqlen_k_ptr = nullptr; + fmha_args.cu_seqlen_q_ptr = nullptr; + fmha_args.cu_seqlen_k_ptr = nullptr; } - float average_runtime = aiter::mha_bwd(fmha_args, - stream_config, - data_type_str, - is_group_mode, - mask_type, - bias_type, - has_dbias, - s_randval, - deterministic, - uses_bwd_v3, - is_v3_atomic_fp32, - how_v3_bf16_cvt); - if(dump_path){ - dump_bwd_timings(dump_path, average_runtime); + + fmha_args.seqlen_q = is_group_mode ? max_seqlen_q : seqlen_q; + fmha_args.seqlen_k = is_group_mode ? max_seqlen_k : seqlen_k; + fmha_args.batch = batch; + fmha_args.max_seqlen_q = max_seqlen_q; + fmha_args.max_seqlen_k = max_seqlen_k; + fmha_args.nhead_q = nhead; + fmha_args.nhead_k = nhead_k; + fmha_args.scale = scale_s; + + // setup stride_* arguments + fmha_args.stride_q = stride_s_q; + fmha_args.stride_k = stride_s_k; + fmha_args.stride_v = stride_s_v; + // bias of shape (bias_b, bias_h, s_q, s_kv) + fmha_args.stride_bias = (!is_group_mode && bias_type!=bias_enum::alibi) ? max_seqlen_k : 0; + fmha_args.stride_o = stride_s_o; + fmha_args.stride_randval = max_seqlen_k; + fmha_args.stride_do = stride_s_do; + //dq_acc of shape (nsplits, B, H, S, D) + fmha_args.stride_dq_acc = d_qk; + fmha_args.stride_dq = stride_s_dq; + fmha_args.stride_dk = is_mqa_gqa? stride_s_dk_expanded:stride_s_dk; + fmha_args.stride_dv = is_mqa_gqa? stride_s_dv_expanded:stride_s_dv; + // dbias is of the same shape as bias + // but ck only take dbias with BHSS + fmha_args.stride_dbias = (!is_group_mode && bias_type!=bias_enum::alibi) ? max_seqlen_k : 0; + + // setup nhead_stride_* arguments + fmha_args.nhead_stride_q = stride_h_q; + fmha_args.nhead_stride_k = stride_h_k; + fmha_args.nhead_stride_v = stride_h_v; + // bias input can be of different shapes (11SS, 1HSS, B1SS, and BHSS), but dbias must be of BHSS + fmha_args.nhead_stride_bias = get_nhead_stride_bias(bias_shape, max_seqlen_q, max_seqlen_k, is_group_mode); + fmha_args.nhead_stride_o = stride_h_o; + fmha_args.nhead_stride_randval = is_group_mode ? 0 : seqlen_q * max_seqlen_k; + fmha_args.nhead_stride_do = stride_h_do; + fmha_args.nhead_stride_lsed = is_group_mode ? max_tokens_q : max_seqlen_q; + fmha_args.nhead_stride_dq_acc = static_cast((is_group_mode ? max_tokens_q : s_q) * d_qk); + fmha_args.nhead_stride_dq = stride_h_dq; + fmha_args.nhead_stride_dk = is_mqa_gqa? stride_h_dk_expanded:stride_h_dk; + fmha_args.nhead_stride_dv = is_mqa_gqa? stride_h_dv_expanded:stride_h_dv; + // dbias can only be of BHSS + fmha_args.nhead_stride_dbias = is_group_mode? 0: max_seqlen_q * max_seqlen_k; + + // setup batch_stride_* arguments + fmha_args.batch_stride_q = is_group_mode ? 0 : stride_b_q; + fmha_args.batch_stride_k = is_group_mode ? 0 : stride_b_k; + fmha_args.batch_stride_v = is_group_mode ? 0 : stride_b_v; + fmha_args.batch_stride_bias = get_batch_stride_bias(bias_h, bias_shape, max_seqlen_q, max_seqlen_k, is_group_mode, false); + fmha_args.batch_stride_o = is_group_mode ? 0 : stride_b_o; + fmha_args.batch_stride_randval = is_group_mode ? 0 : nhead * seqlen_q * max_seqlen_k; + fmha_args.batch_stride_do = is_group_mode ? 0 : stride_b_do; + fmha_args.batch_stride_lsed = is_group_mode ? 0 : nhead * max_seqlen_q; + fmha_args.batch_stride_dq_acc = is_group_mode ? 0 : static_cast(h * s_q * d_qk); + fmha_args.batch_stride_dq = is_group_mode ? 0 : stride_b_dq; + fmha_args.batch_stride_dk = is_group_mode ? 0 : (is_mqa_gqa? stride_b_dk_expanded:stride_b_dk); + fmha_args.batch_stride_dv = is_group_mode ? 0 : (is_mqa_gqa? stride_b_dv_expanded:stride_b_dv); + // for dbias, use h since h can be different from bias_h + fmha_args.batch_stride_dbias = is_group_mode ? 0 : h * max_seqlen_q * max_seqlen_k; + fmha_args.split_stride_dq_acc = static_cast(is_group_mode ? (max_tokens_q * h * d_qk) : (b * h * s_q * d_qk)); + + fmha_args.window_size_left = left; + fmha_args.window_size_right = right; + fmha_args.p_drop = p_drop; + fmha_args.p_undrop = p_undrop; + fmha_args.drop_seed_offset = std::pair{philox_seed_ptr, philox_offset_ptr}; + + // modify the max_seqlen_q for better performance in 0-length cases + // lse_thd_ptr used as buffer + if(const char* env_p = std::getenv("NVTE_CK_RUNTIME_MAX_SEQLEN")) { + if(is_group_mode && std::string(env_p) == "1"){ + if(ck_log_config){ + std::cout << "attn_bwd(ck): Enabling runtime max_seqlen calculation for small seqlen optimization."; + } + fmha_args.max_seqlen_q = get_runtime_max_seqlen(b, cu_seqlen_q_ptr, nullptr, lse_workspace_ptr, stream); + fmha_args.max_seqlen_k = get_runtime_max_seqlen(b, cu_seqlen_kv_ptr, nullptr, lse_workspace_ptr, stream); + } } + + // print ck traits and args when needed + if(ck_log_config){ + log_bwd_config(func_name, fmha_args); + } + float average_runtime = aiter::mha_bwd(fmha_args, stream_config); if(average_runtime < 0){ //TODO: better error out system throw std::runtime_error("fused attn configs not supported in ck_fused_attn bwd pass."); } + if(dump_path){ + dump_bwd_timings(dump_path, average_runtime); + } + return hipSuccess; +} + +hipError_t ck_attn_bwd( + DType dtype, + uint64_t b, uint64_t h, uint64_t hg, uint64_t s_q, uint64_t s_kv, uint64_t d_qk, uint64_t d_v, uint64_t bias_b, uint64_t bias_h, + const void* q_ptr, + uint64_t stride_b_q, uint64_t stride_h_q, uint64_t stride_s_q, + const void* k_ptr, + uint64_t stride_b_k, uint64_t stride_h_k, uint64_t stride_s_k, + const void* v_ptr, + uint64_t stride_b_v, uint64_t stride_h_v, uint64_t stride_s_v, + const void* bias_ptr, + const void* alibi_slope_ptr, + const void* o_ptr, + uint64_t stride_b_o, uint64_t stride_h_o, uint64_t stride_s_o, + const void* lse_ptr, + const void* do_ptr, + uint64_t stride_b_do, uint64_t stride_h_do, uint64_t stride_s_do, + float scaling_factor, float dropout_probability, + void* philox_seed_ptr, void* philox_offset_ptr, + BiasType attn_bias_type, + MaskType attn_mask_type, + int64_t window_size_left, int64_t window_size_right, + void* dq_ptr, + uint64_t stride_b_dq, uint64_t stride_h_dq, uint64_t stride_s_dq, + void* dq_acc_ptr, + void* dk_expanded_ptr, + void* dv_expanded_ptr, + uint64_t stride_b_dk_expanded, uint64_t stride_h_dk_expanded, uint64_t stride_s_dk_expanded, + uint64_t stride_b_dv_expanded, uint64_t stride_h_dv_expanded, uint64_t stride_s_dv_expanded, + void* dk_ptr, + uint64_t stride_b_dk, uint64_t stride_h_dk, uint64_t stride_s_dk, + void* dv_ptr, + uint64_t stride_b_dv, uint64_t stride_h_dv, uint64_t stride_s_dv, + void* dbias_expanded_ptr, + void* dbias_ptr, + void* lse_workspace_ptr, + bool deterministic, + bool uses_bwd_v3, + bool is_v3_atomic_fp32, + int how_v3_bf16_cvt, + hipStream_t stream){ + + bool has_dropout = (dropout_probability > 0.f); + bool has_dbias = dbias_ptr!=nullptr; + bool is_mqa_gqa = (h > hg); + bias_enum bias_type; + BiasShape bias_shape; + std::tie(bias_type, bias_shape) = get_ck_bias_type_shape(attn_bias_type, b, h, bias_b, bias_h); + + bool ck_log_config = false; + if (const char* env_p = std::getenv("CK_FUSED_ATTN_LOG_CONFIG") ) { + if (env_p != nullptr && std::string(env_p) == "1") + ck_log_config = true; + } + + hipError_t impl_status = _ck_attn_bwd_impl( + dtype, + b, h, hg, s_q, s_kv, d_qk, d_v, + bias_b, bias_h, + s_q, s_kv, + q_ptr, + stride_b_q, stride_h_q, stride_s_q, + k_ptr, + stride_b_k, stride_h_k, stride_s_k, + v_ptr, + stride_b_v, stride_h_v, stride_s_v, + bias_ptr, + alibi_slope_ptr, + nullptr, nullptr, + nullptr, nullptr, + o_ptr, + stride_b_o, stride_h_o, stride_s_o, + lse_ptr, + do_ptr, + stride_b_do, stride_h_do, stride_s_do, + scaling_factor, dropout_probability, + philox_seed_ptr, philox_offset_ptr, + attn_bias_type, + attn_mask_type, + window_size_left, window_size_right, + dq_ptr, + stride_b_dq, stride_h_dq, stride_s_dq, + dq_acc_ptr, + dk_expanded_ptr, + dv_expanded_ptr, + stride_b_dk_expanded, stride_h_dk_expanded, stride_s_dk_expanded, + stride_b_dv_expanded, stride_h_dv_expanded, stride_s_dv_expanded, + dk_ptr, + stride_b_dk, stride_h_dk, stride_s_dk, + dv_ptr, + stride_b_dv, stride_h_dv, stride_s_dv, + dbias_expanded_ptr, + dbias_ptr, + lse_workspace_ptr, + deterministic, + uses_bwd_v3, + is_v3_atomic_fp32, + how_v3_bf16_cvt, + false, + __FUNCTION__, + ck_log_config, + stream); + if (impl_status != hipSuccess) { + return impl_status; + } if(is_mqa_gqa){ dim3 grid(b, s_kv, hg); if (d_qk == d_v) { @@ -780,7 +863,7 @@ hipError_t ck_attn_bwd( dbias_reduce_11ss, grid, block, 0, stream, b, h, s_q, s_kv, static_cast(dbias_expanded_ptr), - static_cast(dbias_ptr));); + static_cast(dbias_ptr));); }else if(bias_shape==BiasShape::k1HSS){ if (auto* log_file = get_ck_log_stream()) { *log_file << "\n" << "run dbias_reduce_1HSS: " << "\n"; @@ -792,7 +875,7 @@ hipError_t ck_attn_bwd( dbias_reduce_1hss, grid, block, 0, stream, b, h, s_q, s_kv, static_cast(dbias_expanded_ptr), - static_cast(dbias_ptr));); + static_cast(dbias_ptr));); }else if(bias_shape==BiasShape::kB1SS){ if (auto* log_file = get_ck_log_stream()) { *log_file << "\n" << "run dbias_reduce_B1SS: " << "\n"; @@ -804,43 +887,43 @@ hipError_t ck_attn_bwd( dbias_reduce_b1ss, grid, block, 0, stream, b, h, s_q, s_kv, static_cast(dbias_expanded_ptr), - static_cast(dbias_ptr));); + static_cast(dbias_ptr));); } } return hipSuccess; } -hipError_t ck_attn_varlen_bwd( +hipError_t ck_attn_varlen_bwd( DType dtype, uint64_t b, uint64_t h, uint64_t hg, uint64_t s_q, uint64_t s_kv, uint64_t d_qk, uint64_t d_v, uint64_t max_tokens_q, uint64_t max_tokens_kv, - const void* q_ptr, + const void* q_ptr, uint64_t stride_h_q, uint64_t stride_s_q, - const void* k_ptr, + const void* k_ptr, uint64_t stride_h_k, uint64_t stride_s_k, - const void* v_ptr, + const void* v_ptr, uint64_t stride_h_v, uint64_t stride_s_v, const void* cu_seqlen_q_ptr, const void* cu_seqlen_kv_ptr, const void* cu_seqlen_q_padded_ptr, const void* cu_seqlen_kv_padded_ptr, - const void* o_ptr, + const void* o_ptr, uint64_t stride_h_o, uint64_t stride_s_o, - const void* lse_thd_ptr, - const void* do_ptr, + const void* lse_thd_ptr, + const void* do_ptr, uint64_t stride_h_do, uint64_t stride_s_do, float scaling_factor, float dropout_probability, void* philox_seed_ptr, void* philox_offset_ptr, MaskType attn_mask_type, int64_t window_size_left, int64_t window_size_right, - void* dq_ptr, + void* dq_ptr, uint64_t stride_h_dq, uint64_t stride_s_dq, void* dq_acc_ptr, void* dk_expanded_ptr, void* dv_expanded_ptr, uint64_t stride_h_dk_expanded, uint64_t stride_s_dk_expanded, uint64_t stride_h_dv_expanded, uint64_t stride_s_dv_expanded, - void* dk_ptr, + void* dk_ptr, uint64_t stride_h_dk, uint64_t stride_s_dk, - void* dv_ptr, + void* dv_ptr, uint64_t stride_h_dv, uint64_t stride_s_dv, void* lse_workspace_ptr, bool deterministic, @@ -848,208 +931,63 @@ hipError_t ck_attn_varlen_bwd( bool is_v3_atomic_fp32, int how_v3_bf16_cvt, hipStream_t stream){ - - bool has_dropout = (dropout_probability > 0.f); - bool has_dbias = false; bool is_mqa_gqa = (h > hg); - /* CK input parameters */ - ck_tile::index_t batch = b; - ck_tile::index_t nhead = h; - ck_tile::index_t hdim_q = d_qk; - ck_tile::index_t nhead_k = hg; - ck_tile::index_t hdim_v = d_v; - ck_tile::index_t max_seqlen_q = s_q; - ck_tile::index_t max_seqlen_k = s_kv; - float scale_s = scaling_factor; - float p_drop = dropout_probability; - float p_undrop = 1.0 - p_drop; - bool is_group_mode = true; - bool s_randval = false; - - // THD does not work with bias - - ck_tile::index_t left, right; - left = window_size_left; - right = window_size_right; - mask_enum mask_type = static_cast(attn_mask_type); - - const char* dump_path = std::getenv("NVTE_DUMP_AITER_RT"); - // print kernel name on verbose mode - ck_tile::stream_config stream_config{stream, dump_path!=nullptr, get_ck_log_stream() != nullptr}; - - std::string data_type_str = get_data_type_str(dtype); - - auto fmha_args = [&]() { - // setup stride_* arguments - const ck_tile::index_t stride_q = stride_s_q; - const ck_tile::index_t stride_k = stride_s_k; - const ck_tile::index_t stride_v = stride_s_v; - // bias not used in THD qkv layout - const ck_tile::index_t stride_bias = 0; - const ck_tile::index_t stride_o = stride_s_o; - const ck_tile::index_t stride_randval = max_seqlen_k; - const ck_tile::index_t stride_do = stride_s_do; - const ck_tile::index_t stride_dq = stride_s_dq; - const ck_tile::index_t stride_dk = stride_s_dk; - const ck_tile::index_t stride_dv = stride_s_dv; - const ck_tile::index_t stride_dk_expanded = stride_s_dk_expanded; - const ck_tile::index_t stride_dv_expanded = stride_s_dv_expanded; - const ck_tile::index_t stride_dq_acc = d_qk; //dq_acc of shape (nsplits, H, max_tokens_q, D_qk) - // bias not used in THD qkv layout - const ck_tile::index_t stride_dbias = 0; - // setup nhead_stride_* arguments - const ck_tile::index_t nhead_stride_q = stride_h_q; - const ck_tile::index_t nhead_stride_k = stride_h_k; - const ck_tile::index_t nhead_stride_v = stride_h_v; - // bias not used in THD qkv layout - const ck_tile::index_t nhead_stride_bias = 0; - const ck_tile::index_t nhead_stride_o = stride_h_o; - const ck_tile::index_t nhead_stride_randval = 0; - const ck_tile::index_t nhead_stride_do = stride_h_do; - // use packed lse - const ck_tile::index_t nhead_stride_lsed = max_tokens_q; - const ck_tile::index_t nhead_stride_dq = stride_h_dq; - const ck_tile::index_t nhead_stride_dk = stride_h_dk; - const ck_tile::index_t nhead_stride_dv = stride_h_dv; - const ck_tile::index_t nhead_stride_dk_expanded = stride_h_dk_expanded; - const ck_tile::index_t nhead_stride_dv_expanded = stride_h_dv_expanded; - // bias not used in THD qkv layout - const ck_tile::index_t nhead_stride_dbias = 0; - const ck_tile::index_t nhead_stride_dq_acc = max_tokens_q*d_qk; //dq_acc of shape (nsplits, H, max_tokens_q, D_qk) - // setup batch_stride_* arguments - const ck_tile::index_t batch_stride_q = 0; - const ck_tile::index_t batch_stride_k = 0; - const ck_tile::index_t batch_stride_v = 0; - // bias not used in THD qkv layout - const ck_tile::index_t batch_stride_bias = 0; - const ck_tile::index_t batch_stride_o = 0; - const ck_tile::index_t batch_stride_randval = 0; - const ck_tile::index_t batch_stride_do = 0; - const ck_tile::index_t batch_stride_lsed = 0; - const ck_tile::index_t batch_stride_dq = 0; - const ck_tile::index_t batch_stride_dk = 0; - const ck_tile::index_t batch_stride_dv = 0; - const ck_tile::index_t batch_stride_dk_expanded = 0; - const ck_tile::index_t batch_stride_dv_expanded = 0; - // bias not used in THD qkv layout - const ck_tile::index_t batch_stride_dbias = 0; - const ck_tile::index_t batch_stride_dq_acc = 0; //dq_acc of shape (nsplits, T, H, D) - const ck_tile::index_t split_stride_dq_acc = max_tokens_q*h*d_qk; - - return fmha_bwd_args{q_ptr, - k_ptr, - v_ptr, - nullptr, - o_ptr, - lse_thd_ptr, - do_ptr, - lse_workspace_ptr, - nullptr, - dq_ptr, - is_mqa_gqa? dk_expanded_ptr:dk_ptr, - is_mqa_gqa? dv_expanded_ptr:dv_ptr, - nullptr, //dbias_ptr - dq_acc_ptr, //dq_acc_buf - cu_seqlen_q_padded_ptr==nullptr? cu_seqlen_q_ptr: cu_seqlen_q_padded_ptr, //seqstart_q_ptr - cu_seqlen_kv_padded_ptr==nullptr? cu_seqlen_kv_ptr: cu_seqlen_kv_padded_ptr, //seqstart_k_ptr - nullptr, /* seqlen_q_ptr */ - nullptr, /* seqlen_k_ptr */ - cu_seqlen_q_ptr, //cu_seqlen_q_ptr - cu_seqlen_kv_ptr, //cu_seqlen_k_ptr - max_seqlen_q, //seqlen_q, unused in group mode - max_seqlen_k, //seqlen_kv, unused in group mode - batch, - max_seqlen_q, - max_seqlen_k, - hdim_q, - hdim_v, - nhead, - nhead_k, - scale_s, - stride_q, - stride_k, - stride_v, - stride_bias, - stride_o, - stride_randval, - stride_do, - stride_dq_acc,//stride_dq_acc - stride_dq,//stride_dq - is_mqa_gqa? stride_dk_expanded:stride_dk, - is_mqa_gqa? stride_dv_expanded:stride_dv, - stride_dbias, - nhead_stride_q, - nhead_stride_k, - nhead_stride_v, - nhead_stride_bias, - nhead_stride_o, - nhead_stride_randval, - nhead_stride_do, - nhead_stride_lsed, - nhead_stride_dq_acc, //nhead_stride_dq_acc - nhead_stride_dq, - is_mqa_gqa? nhead_stride_dk_expanded:nhead_stride_dk, - is_mqa_gqa? nhead_stride_dv_expanded:nhead_stride_dv, - nhead_stride_dbias, - batch_stride_q, - batch_stride_k, - batch_stride_v, - batch_stride_bias, - batch_stride_o, - batch_stride_randval, - batch_stride_do, - batch_stride_lsed, - batch_stride_dq_acc, //batch_stride_dq_acc - batch_stride_dq, - is_mqa_gqa? batch_stride_dk_expanded:batch_stride_dk, - is_mqa_gqa? batch_stride_dv_expanded:batch_stride_dv, - batch_stride_dbias, - split_stride_dq_acc, - left, - right, - static_cast(mask_type), - p_drop, - p_undrop, - std::pair{philox_seed_ptr, philox_offset_ptr}}; - }(); - - // modify the max_seqlen_q for better performance in 0-length cases - // lse_thd_ptr used as buffer - if(const char* env_p = std::getenv("NVTE_CK_RUNTIME_MAX_SEQLEN")) { - if(std::string(env_p) == "1"){ - if (auto* log_file = get_ck_log_stream()) { - *log_file - << "attn_bwd(ck): Enabling runtime max_seqlen calculation for small seqlen optimization.\n"; - } - fmha_args.max_seqlen_q = get_runtime_max_seqlen(b, cu_seqlen_q_ptr, nullptr, lse_workspace_ptr, stream); - fmha_args.max_seqlen_k = get_runtime_max_seqlen(b, cu_seqlen_kv_ptr, nullptr, lse_workspace_ptr, stream); - } + bool ck_log_config = false; + if (const char* env_p = std::getenv("CK_FUSED_ATTN_LOG_CONFIG") ) { + if (env_p != nullptr && std::string(env_p) == "1") + ck_log_config = true; } - // print ck traits and args when needed - if (auto* log_file = get_ck_log_stream()) { - log_bwd_config(log_file, __FUNCTION__, data_type_str, is_group_mode, mask_type, bias_enum::no_bias, has_dbias, has_dropout, s_randval, deterministic, uses_bwd_v3, is_v3_atomic_fp32, how_v3_bf16_cvt, fmha_args); - } - - float average_runtime = aiter::mha_bwd(fmha_args, - stream_config, - data_type_str, - is_group_mode, - mask_type, - bias_enum::no_bias, - has_dbias, - s_randval, + hipError_t impl_status = _ck_attn_bwd_impl( + dtype, + b, h, hg, s_q, s_kv, d_qk, d_v, + 0, 0, + max_tokens_q, max_tokens_kv, + q_ptr, + 0, stride_h_q, stride_s_q, + k_ptr, + 0, stride_h_k, stride_s_k, + v_ptr, + 0, stride_h_v, stride_s_v, + nullptr, + nullptr, + cu_seqlen_q_ptr, cu_seqlen_kv_ptr, + cu_seqlen_q_padded_ptr, cu_seqlen_kv_padded_ptr, + o_ptr, + 0, stride_h_o, stride_s_o, + lse_thd_ptr, + do_ptr, + 0, stride_h_do, stride_s_do, + scaling_factor, dropout_probability, + philox_seed_ptr, philox_offset_ptr, + BiasType::no_bias, + attn_mask_type, + window_size_left, window_size_right, + dq_ptr, + 0, stride_h_dq, stride_s_dq, + dq_acc_ptr, + dk_expanded_ptr, + dv_expanded_ptr, + 0, stride_h_dk_expanded, stride_s_dk_expanded, + 0, stride_h_dv_expanded, stride_s_dv_expanded, + dk_ptr, + 0, stride_h_dk, stride_s_dk, + dv_ptr, + 0, stride_h_dv, stride_s_dv, + nullptr, + nullptr, + lse_workspace_ptr, deterministic, uses_bwd_v3, is_v3_atomic_fp32, - how_v3_bf16_cvt); - if(dump_path){ - dump_bwd_timings(dump_path, average_runtime); - } - if(average_runtime < 0){ - //TODO: better error out system - throw std::runtime_error("fused attn configs not supported in ck_fused_attn bwd pass."); + how_v3_bf16_cvt, + true, + __FUNCTION__, + ck_log_config, + stream); + if (impl_status != hipSuccess) { + return impl_status; } if(is_mqa_gqa){ dim3 grid(max_tokens_kv, hg); diff --git a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_fwd.cpp b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_fwd.cpp index f9e93515d..356413706 100644 --- a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_fwd.cpp +++ b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_fwd.cpp @@ -9,110 +9,88 @@ #include #include #include "ck_fused_attn/ck_fused_attn.hpp" -#include "ck_tile/host.hpp" #include "mha_fwd.h" #include "ck_fused_attn_utils.hpp" namespace ck_fused_attn{ -// print the fmha traits and args when calling ck apis -void log_fwd_config( - std::ostream* log_file, - const char* func_name, - const std::string data_type_str, - const bool is_group_mode, - const bool has_logits_soft_cap, - const mask_enum mask_type, - const bias_enum bias_type, - const bool has_lse, - const bool has_dropout, - const bool is_v_rowmajor, - const bool do_fp8_static_quant, - const bool uses_fwd_v3, - const bool how_v3_bf16_cvt, - const fmha_fwd_args& fmha_args -){ - *log_file << "\n" << func_name << "\n"; +// print the fmha traits and fmha_args when calling ck apis +void log_fwd_config(const char* func_name, bool has_dropout, const aiter::mha_fwd_args& fmha_args){ - // debug fmha_traits - *log_file << "\n" << "fmha_traits: " << "\n"; - *log_file << "hdim_q: " << fmha_args.hdim_q << "\n"; - *log_file << "hdim_v: " << fmha_args.hdim_v << "\n"; - *log_file << "data_type: " << data_type_str << "\n"; - *log_file << "is_group_mode: " << is_group_mode << "\n"; - *log_file << "is_v_rowmajor: " << is_v_rowmajor << "\n"; - *log_file << "has_logits_soft_cap: " << has_logits_soft_cap << "\n"; - *log_file << "mask_type: " << static_cast::type>(mask_type) << "\n"; - *log_file << "bias_type: " << static_cast::type>(bias_type) << "\n"; - *log_file << "has_lse: " << has_lse << "\n"; - *log_file << "has_dropout: " << has_dropout << "\n"; - *log_file << "do_fp8_static_quant: " << do_fp8_static_quant << "\n"; - *log_file << "skip_min_seqlen_q: " << (fmha_args.min_seqlen_q != 0) << "\n"; - *log_file << "uses_fwd_v3: " << uses_fwd_v3 << "\n"; - *log_file << "how_v3_bf16_cvt: " << how_v3_bf16_cvt << "\n"; + std::ostream* log_file = get_ck_log_stream(); + (*log_file) << "\n" << func_name << "\n"; + // debug fmha_traits + (*log_file) << "\nfmha_traits: \n"; + log_value(log_file, "hdim_q", fmha_args.hdim_q); + log_value(log_file, "hdim_v", fmha_args.hdim_v); + log_value(log_file, "data_type", fmha_args.data_type); + log_value(log_file, "is_group_mode", fmha_args.is_group_mode); + log_value(log_file, "has_lse", fmha_args.has_lse); + log_value(log_file, "has_dropout", has_dropout); + log_value(log_file, "skip_min_seqlen_q", (fmha_args.min_seqlen_q != 0)); + log_value(log_file, "use_asm_v3", fmha_args.use_asm_v3); + log_value(log_file, "how_v3_bf16_cvt", fmha_args.how_v3_bf16_cvt); // debug fmha_args - *log_file << "\n" << "fmha_args: " << "\n"; - - *log_file << "q_ptr: " << fmha_args.q_ptr << "\n"; - *log_file << "k_ptr: " << fmha_args.k_ptr << "\n"; - *log_file << "v_ptr: " << fmha_args.v_ptr << "\n"; - *log_file << "bias_ptr: " << fmha_args.bias_ptr << "\n"; - *log_file << "rand_val_ptr: " << fmha_args.rand_val_ptr << "\n"; - *log_file << "lse_ptr: " << fmha_args.lse_ptr << "\n"; - *log_file << "o_ptr: " << fmha_args.o_ptr << "\n"; - - *log_file << "seqstart_q_ptr: " << fmha_args.seqstart_q_ptr << "\n"; - *log_file << "seqstart_k_ptr: " << fmha_args.seqstart_k_ptr << "\n"; - *log_file << "seqlen_q_ptr: " << fmha_args.seqlen_q_ptr << "\n"; - *log_file << "seqlen_k_ptr: " << fmha_args.seqlen_k_ptr << "\n"; - *log_file << "cu_seqlen_q_ptr: " << fmha_args.cu_seqlen_q_ptr << "\n"; - *log_file << "cu_seqlen_k_ptr: " << fmha_args.cu_seqlen_k_ptr << "\n"; - - *log_file << "seqlen_q: " << fmha_args.seqlen_q << "\n"; - *log_file << "seqlen_k: " << fmha_args.seqlen_k << "\n"; - *log_file << "batch: " << fmha_args.batch << "\n"; - *log_file << "max_seqlen_q: " << fmha_args.max_seqlen_q << "\n"; - *log_file << "hdim_q: " << fmha_args.hdim_q << "\n"; - *log_file << "hdim_v: " << fmha_args.hdim_v << "\n"; - *log_file << "nhead_q: " << fmha_args.nhead_q << "\n"; - *log_file << "nhead_k: " << fmha_args.nhead_k << "\n"; - - *log_file << "scale_s: " << fmha_args.scale_s << "\n"; - - *log_file << "logits_soft_cap: " << fmha_args.logits_soft_cap << "\n"; - - *log_file << "stride_q: " << fmha_args.stride_q << "\n"; - *log_file << "stride_k: " << fmha_args.stride_k << "\n"; - *log_file << "stride_v: " << fmha_args.stride_v << "\n"; - *log_file << "stride_bias: " << fmha_args.stride_bias << "\n"; - *log_file << "stride_randval: " << fmha_args.stride_randval << "\n"; - *log_file << "stride_o: " << fmha_args.stride_o << "\n"; - *log_file << "nhead_stride_q: " << fmha_args.nhead_stride_q << "\n"; - *log_file << "nhead_stride_k: " << fmha_args.nhead_stride_k << "\n"; - *log_file << "nhead_stride_v: " << fmha_args.nhead_stride_v << "\n"; - *log_file << "nhead_stride_bias: " << fmha_args.nhead_stride_bias << "\n"; - *log_file << "nhead_stride_randval: " << fmha_args.nhead_stride_randval << "\n"; - *log_file << "nhead_stride_lse: " << fmha_args.nhead_stride_lse << "\n"; - *log_file << "nhead_stride_o: " << fmha_args.nhead_stride_o << "\n"; - *log_file << "batch_stride_q: " << fmha_args.batch_stride_q << "\n"; - *log_file << "batch_stride_k: " << fmha_args.batch_stride_k << "\n"; - *log_file << "batch_stride_v: " << fmha_args.batch_stride_v << "\n"; - *log_file << "batch_stride_bias: " << fmha_args.batch_stride_bias << "\n"; - *log_file << "batch_stride_randval: " << fmha_args.batch_stride_randval << "\n"; - *log_file << "batch_stride_lse: " << fmha_args.batch_stride_lse << "\n"; - *log_file << "batch_stride_o: " << fmha_args.batch_stride_o << "\n"; - - *log_file << "window_size_left: " << fmha_args.window_size_left << "\n"; - *log_file << "window_size_right: " << fmha_args.window_size_right << "\n"; - *log_file << "mask_type: " << fmha_args.mask_type << "\n"; - *log_file << "min_seqlen_q: " << fmha_args.min_seqlen_q << "\n"; - - *log_file << "p_drop: " << fmha_args.p_drop << "\n"; - *log_file << "s_randval: " << fmha_args.s_randval << "\n"; - - *log_file << "dropout_seed_ptr: " << std::get<0>(std::get>(fmha_args.drop_seed_offset)) << "\n"; - *log_file << "dropout_offset_ptr: " << std::get<1>(std::get>(fmha_args.drop_seed_offset)) << "\n"; + (*log_file) << "\nfmha_args: \n"; + + log_value(log_file, "q_ptr", fmha_args.q_ptr); + log_value(log_file, "k_ptr", fmha_args.k_ptr); + log_value(log_file, "v_ptr", fmha_args.v_ptr); + log_value(log_file, "bias_ptr", fmha_args.bias_ptr); + log_value(log_file, "rand_val_ptr", fmha_args.rand_val_ptr); + log_value(log_file, "lse_ptr", fmha_args.lse_ptr); + log_value(log_file, "o_ptr", fmha_args.o_ptr); + + log_value(log_file, "seqstart_q_ptr", fmha_args.seqstart_q_ptr); + log_value(log_file, "seqstart_k_ptr", fmha_args.seqstart_k_ptr); + log_value(log_file, "seqlen_q_ptr", fmha_args.seqlen_q_ptr); + log_value(log_file, "seqlen_k_ptr", fmha_args.seqlen_k_ptr); + log_value(log_file, "cu_seqlen_q_ptr", fmha_args.cu_seqlen_q_ptr); + log_value(log_file, "cu_seqlen_k_ptr", fmha_args.cu_seqlen_k_ptr); + + log_value(log_file, "seqlen_q", fmha_args.seqlen_q); + log_value(log_file, "seqlen_k", fmha_args.seqlen_k); + log_value(log_file, "batch", fmha_args.batch); + log_value(log_file, "max_seqlen_q", fmha_args.max_seqlen_q); + log_value(log_file, "hdim_q", fmha_args.hdim_q); + log_value(log_file, "hdim_v", fmha_args.hdim_v); + log_value(log_file, "nhead_q", fmha_args.nhead_q); + log_value(log_file, "nhead_k", fmha_args.nhead_k); + log_value(log_file, "scale_s", fmha_args.scale_s); + log_value(log_file, "logits_soft_cap", fmha_args.logits_soft_cap); + + log_value(log_file, "stride_q", fmha_args.stride_q); + log_value(log_file, "stride_k", fmha_args.stride_k); + log_value(log_file, "stride_v", fmha_args.stride_v); + log_value(log_file, "stride_bias", fmha_args.stride_bias); + log_value(log_file, "stride_randval", fmha_args.stride_randval); + log_value(log_file, "stride_o", fmha_args.stride_o); + log_value(log_file, "nhead_stride_q", fmha_args.nhead_stride_q); + log_value(log_file, "nhead_stride_k", fmha_args.nhead_stride_k); + log_value(log_file, "nhead_stride_v", fmha_args.nhead_stride_v); + log_value(log_file, "nhead_stride_bias", fmha_args.nhead_stride_bias); + log_value(log_file, "nhead_stride_randval", fmha_args.nhead_stride_randval); + log_value(log_file, "nhead_stride_lse", fmha_args.nhead_stride_lse); + log_value(log_file, "nhead_stride_o", fmha_args.nhead_stride_o); + log_value(log_file, "batch_stride_q", fmha_args.batch_stride_q); + log_value(log_file, "batch_stride_k", fmha_args.batch_stride_k); + log_value(log_file, "batch_stride_v", fmha_args.batch_stride_v); + log_value(log_file, "batch_stride_bias", fmha_args.batch_stride_bias); + log_value(log_file, "batch_stride_randval", fmha_args.batch_stride_randval); + log_value(log_file, "batch_stride_lse", fmha_args.batch_stride_lse); + log_value(log_file, "batch_stride_o", fmha_args.batch_stride_o); + + log_value(log_file, "window_size_left", fmha_args.window_size_left); + log_value(log_file, "window_size_right", fmha_args.window_size_right); + log_value(log_file, "mask_type", fmha_args.mask_type); + log_value(log_file, "bias_type", fmha_args.bias_type); + log_value(log_file, "min_seqlen_q", fmha_args.min_seqlen_q); + log_value(log_file, "p_drop", fmha_args.p_drop); + log_value(log_file, "s_randval", fmha_args.s_randval); + + log_value(log_file, "dropout_seed_ptr", std::get<0>(std::get>(fmha_args.drop_seed_offset))); + log_value(log_file, "dropout_offset_ptr", std::get<1>(std::get>(fmha_args.drop_seed_offset))); } void dump_fwd_timings(const char* dump_path, float average_runtime){ @@ -121,17 +99,20 @@ void dump_fwd_timings(const char* dump_path, float average_runtime){ file << average_runtime << "\n"; } -hipError_t ck_attn_fwd( +hipError_t _ck_attn_fwd_impl( DType dtype, uint64_t b, uint64_t h, uint64_t hg, uint64_t s_q, uint64_t s_kv, uint64_t d_qk, uint64_t d_v, uint64_t bias_b, uint64_t bias_h, - const void* q_ptr, + uint64_t max_tokens_q, + const void* q_ptr, uint64_t stride_b_q, uint64_t stride_h_q, uint64_t stride_s_q, - const void* k_ptr, + const void* k_ptr, uint64_t stride_b_k, uint64_t stride_h_k, uint64_t stride_s_k, - const void* v_ptr, + const void* v_ptr, uint64_t stride_b_v, uint64_t stride_h_v, uint64_t stride_s_v, const void* bias_ptr, const void* alibi_slope_ptr, + const void* cu_seqlen_q_ptr, const void* cu_seqlen_kv_ptr, + const void* cu_seqlen_q_padded_ptr, const void* cu_seqlen_kv_padded_ptr, bool is_training, float scaling_factor, float dropout_probability, @@ -139,11 +120,13 @@ hipError_t ck_attn_fwd( BiasType attn_bias_type, MaskType attn_mask_type, int64_t window_size_left, int64_t window_size_right, - void* o_ptr, + void* o_ptr, uint64_t stride_b_o, uint64_t stride_h_o, uint64_t stride_s_o, void* lse_ptr, bool uses_fwd_v3, int how_v3_bf16_cvt, + bool is_group_mode, + const char* func_name, hipStream_t stream){ bool has_dropout = (is_training && dropout_probability > 0.f); @@ -157,150 +140,216 @@ hipError_t ck_attn_fwd( ck_tile::index_t hdim_v = d_v; ck_tile::index_t max_seqlen_q = s_q; ck_tile::index_t max_seqlen_k = s_kv; + float scale_s = scaling_factor; float logits_soft_cap = 0.f; float p_drop = dropout_probability; - bool is_group_mode = false; - bool is_v_rowmajor = true; - bool has_logits_soft_cap = 0.f < logits_soft_cap; - bool do_fp8_static_quant = false; - - bias_enum bias_type; - BiasShape bias_shape; - std::tie(bias_type, bias_shape) = get_ck_bias_type_shape(attn_bias_type, b, h, bias_b, bias_h); - + ck_tile::index_t left, right; left = window_size_left; right = window_size_right; mask_enum mask_type = static_cast(attn_mask_type); + bool ck_log_config = false; + if (const char* env_p = std::getenv("CK_FUSED_ATTN_LOG_CONFIG") ) { + if (env_p != nullptr && std::string(env_p) == "1") + ck_log_config = true; + } const char* dump_path = std::getenv("NVTE_DUMP_AITER_RT"); // print kernel name on verbose mode ck_tile::stream_config stream_config{stream, dump_path!=nullptr, get_ck_log_stream() != nullptr}; - std::string data_type_str = get_data_type_str(dtype); - - auto fmha_args = [&]() { - // setup stride_* arguments - const ck_tile::index_t stride_q = stride_s_q; - const ck_tile::index_t stride_k = stride_s_k; - const ck_tile::index_t stride_v = stride_s_v; - // bias is of shape [b, h , s_q, s_kv] - const ck_tile::index_t stride_bias = max_seqlen_k; - const ck_tile::index_t stride_randval = max_seqlen_k; - const ck_tile::index_t stride_o = stride_s_o; - // setup nhead_stride_* arguments - const ck_tile::index_t nhead_stride_q = stride_h_q; - const ck_tile::index_t nhead_stride_k = stride_h_k; - const ck_tile::index_t nhead_stride_v = stride_h_v; - const ck_tile::index_t nhead_stride_bias = (bias_shape==BiasShape::k1HSS || bias_shape==BiasShape::kBHSS) ? max_seqlen_q * max_seqlen_k: 0; - //TODO: randval never used, can we remove it - const ck_tile::index_t nhead_stride_randval = 0; - // softmax_lse is of shape [b, h, s_q] - const ck_tile::index_t nhead_stride_lse = max_seqlen_q; - const ck_tile::index_t nhead_stride_o = stride_h_o; - // setup batch_stride_* arguments - const ck_tile::index_t batch_stride_q = stride_b_q; - const ck_tile::index_t batch_stride_k = stride_b_k; - const ck_tile::index_t batch_stride_v = stride_b_v; - const ck_tile::index_t batch_stride_bias = (bias_shape==BiasShape::k11SS || bias_shape==BiasShape::k1HSS) ? 0: (bias_shape==BiasShape::kBHSS? bias_h* max_seqlen_q * max_seqlen_k: max_seqlen_q*max_seqlen_k); - //TODO: randval never used, can we remove it - const ck_tile::index_t batch_stride_randval = 0; - // softmax_lse is of shape [b, h, s_q] - const ck_tile::index_t batch_stride_lse = nhead * max_seqlen_q; - const ck_tile::index_t batch_stride_o = stride_b_o; - - return fmha_fwd_args{q_ptr, - k_ptr, - v_ptr, - bias_type==bias_enum::alibi? alibi_slope_ptr :bias_ptr, - nullptr, //q_descale_ptr - nullptr, //k_descale_ptr - nullptr, //v_descale_ptr - nullptr,//rand_val_ptr - lse_ptr, - o_ptr, - nullptr, //seqstart_q_ptr - nullptr, //seqstart_k_ptr - nullptr, //seqlen_q_ptr - nullptr, //seqlen_k_ptr - nullptr, //cu_padded_q_ptr - nullptr, //cu_padded_k_ptr - max_seqlen_q, - max_seqlen_k, - batch, - max_seqlen_q, - hdim_q, - hdim_v, - nhead, - nhead_k, - scale_s, - logits_soft_cap, - stride_q, - stride_k, - stride_v, - bias_type==bias_enum::alibi? 0: stride_bias, // upstream TE only requires standard (vanilla) alibi slopes - stride_randval, - stride_o, - nhead_stride_q, - nhead_stride_k, - nhead_stride_v, - nhead_stride_bias, - nhead_stride_randval, - nhead_stride_lse, - nhead_stride_o, - batch_stride_q, - batch_stride_k, - batch_stride_v, - batch_stride_bias, - batch_stride_randval, - batch_stride_lse, - batch_stride_o, - left, - right, - 0, // sink_size - static_cast(mask_type), - 0, // min_seqlen_q - p_drop, - false, - std::pair{philox_seed_ptr, philox_offset_ptr}}; - }(); - - // print ck traits and args when needed - if (auto* log_file = get_ck_log_stream()) { - log_fwd_config(log_file, __FUNCTION__, data_type_str, is_group_mode, has_logits_soft_cap, mask_type, bias_type, has_lse, has_dropout, is_v_rowmajor, do_fp8_static_quant, uses_fwd_v3, how_v3_bf16_cvt, fmha_args); + bias_enum bias_type = bias_enum::no_bias; + BiasShape bias_shape = BiasShape::k11SS; + + aiter::mha_fwd_args fmha_args{}; + fmha_args.q_ptr = q_ptr; + fmha_args.k_ptr = k_ptr; + fmha_args.v_ptr = v_ptr; + + fmha_args.batch = batch; + fmha_args.seqlen_q = max_seqlen_q; // unused in group mode + fmha_args.hdim_q = hdim_q; + fmha_args.hdim_v = hdim_v; + fmha_args.nhead_q = nhead; + fmha_args.nhead_k = nhead_k; + + fmha_args.stride_q = stride_s_q; + fmha_args.stride_k = stride_s_k; + fmha_args.stride_v = stride_s_v; + fmha_args.nhead_stride_q = stride_h_q; + fmha_args.nhead_stride_k = stride_h_k; + fmha_args.nhead_stride_v = stride_h_v; + fmha_args.batch_stride_q = stride_b_q; + fmha_args.batch_stride_k = stride_b_k; + fmha_args.batch_stride_v = stride_b_v; + + if (is_group_mode) { + fmha_args.seqstart_q_ptr = cu_seqlen_q_padded_ptr==nullptr? cu_seqlen_q_ptr: cu_seqlen_q_padded_ptr; + fmha_args.seqstart_k_ptr = cu_seqlen_kv_padded_ptr==nullptr? cu_seqlen_kv_ptr: cu_seqlen_kv_padded_ptr; + fmha_args.seqlen_q_ptr = nullptr; + fmha_args.seqlen_k_ptr = nullptr; + fmha_args.cu_seqlen_q_ptr = cu_seqlen_q_ptr; + fmha_args.cu_seqlen_k_ptr = cu_seqlen_kv_ptr; + } else { + fmha_args.seqstart_q_ptr = nullptr; + fmha_args.seqstart_k_ptr = nullptr; + fmha_args.seqlen_q_ptr = nullptr; + fmha_args.seqlen_k_ptr = nullptr; + fmha_args.cu_seqlen_q_ptr = nullptr; + fmha_args.cu_seqlen_k_ptr = nullptr; + std::tie(bias_type, bias_shape) = get_ck_bias_type_shape(attn_bias_type, b, h, bias_b, bias_h); } - float average_runtime = aiter::mha_fwd(fmha_args, - stream_config, - data_type_str, - is_group_mode, - mask_type, - bias_type, - has_lse, - quant_scale_enum::no_scale, - uses_fwd_v3, - false,//has_sink - how_v3_bf16_cvt); - if(dump_path){ - dump_fwd_timings(dump_path, average_runtime); + fmha_args.bias_ptr = bias_type==bias_enum::alibi? alibi_slope_ptr :bias_ptr; + fmha_args.lse_ptr = lse_ptr; + fmha_args.o_ptr = o_ptr; + + fmha_args.block_scale_seqstart_q_ptr = nullptr; + fmha_args.block_scale_seqstart_k_ptr = nullptr; + fmha_args.sink_ptr = nullptr; + fmha_args.seqlen_k = max_seqlen_k; // unused in group mode (or kvcache enabled) + fmha_args.max_seqlen_q = max_seqlen_q; + + fmha_args.scale_s = scale_s; + + fmha_args.logits_soft_cap = logits_soft_cap; + + // bias is of shape [b, h , s_q, s_kv] + fmha_args.stride_bias = is_group_mode? 0 : (bias_type==bias_enum::alibi? 0: max_seqlen_k); + fmha_args.stride_o = stride_s_o; + fmha_args.nhead_stride_bias = get_nhead_stride_bias(bias_shape, max_seqlen_q, max_seqlen_k, is_group_mode); + fmha_args.batch_stride_bias = get_batch_stride_bias(bias_h, bias_shape, max_seqlen_q, max_seqlen_k, is_group_mode, true); + // softmax_lse is of shape [b, h, s_q] + fmha_args.nhead_stride_lse = is_group_mode? max_tokens_q : max_seqlen_q; + fmha_args.batch_stride_lse = is_group_mode? 0 : nhead * max_seqlen_q; + fmha_args.nhead_stride_o = stride_h_o; + fmha_args.batch_stride_o = stride_b_o; + + fmha_args.window_size_left = left; + fmha_args.window_size_right = right; + fmha_args.mask_type = static_cast(mask_type); + + fmha_args.rand_val_ptr = nullptr; + + fmha_args.stride_randval = max_seqlen_k; + // Unused + fmha_args.nhead_stride_randval = 0; + fmha_args.batch_stride_randval = 0; + fmha_args.nhead_stride_q_descale = 0; + fmha_args.nhead_stride_k_descale = 0; + fmha_args.nhead_stride_v_descale = 0; + fmha_args.batch_stride_q_descale = 0; + fmha_args.batch_stride_k_descale = 0; + fmha_args.batch_stride_v_descale = 0; + + fmha_args.p_drop = p_drop; + fmha_args.s_randval = 0; + fmha_args.drop_seed_offset = std::pair{philox_seed_ptr, philox_offset_ptr}; + fmha_args.use_asm_v3 = uses_fwd_v3; + fmha_args.how_v3_bf16_cvt = how_v3_bf16_cvt; + fmha_args.v3_api_check = false; + fmha_args.data_type = get_data_type_str(dtype); + fmha_args.is_group_mode = is_group_mode; + fmha_args.bias_type = static_cast(bias_type); + fmha_args.has_lse = lse_ptr!=nullptr; + fmha_args.qscale_type = static_cast(quant_scale_enum::no_scale); + fmha_args.has_sink = false; + fmha_args.q_descale_ptr = nullptr; + fmha_args.k_descale_ptr = nullptr; + fmha_args.v_descale_ptr = nullptr; + fmha_args.sink_size = 0; + fmha_args.min_seqlen_q = 0; + fmha_args.block_scale_size_q = 0; + fmha_args.block_scale_size_kv = 0; + + if(const char* env_p = std::getenv("NVTE_CK_RUNTIME_MAX_SEQLEN")){ + if(is_group_mode && std::string(env_p) == "1"){ + if(ck_log_config){ + std::cout << "attn_fwd(ck): Enabling runtime max_seqlen calculation for small seqlen optimization."; + } + fmha_args.max_seqlen_q = get_runtime_max_seqlen(b, cu_seqlen_q_ptr, cu_seqlen_q_padded_ptr, lse_ptr, stream); + } } + + // print ck traits and fmha_args when needed + if(ck_log_config){ + log_fwd_config(func_name, has_dropout, fmha_args); + } + float average_runtime = aiter::mha_fwd(fmha_args, stream_config); if(average_runtime < 0){ //TODO: better error out system throw std::runtime_error("fused attn configs not supported in ck_fused_attn fwd pass."); } + if(dump_path){ + dump_fwd_timings(dump_path, average_runtime); + } return hipSuccess; } +hipError_t ck_attn_fwd( + DType dtype, + uint64_t b, uint64_t h, uint64_t hg, uint64_t s_q, uint64_t s_kv, uint64_t d_qk, uint64_t d_v, uint64_t bias_b, uint64_t bias_h, + const void* q_ptr, + uint64_t stride_b_q, uint64_t stride_h_q, uint64_t stride_s_q, + const void* k_ptr, + uint64_t stride_b_k, uint64_t stride_h_k, uint64_t stride_s_k, + const void* v_ptr, + uint64_t stride_b_v, uint64_t stride_h_v, uint64_t stride_s_v, + const void* bias_ptr, + const void* alibi_slope_ptr, + bool is_training, + float scaling_factor, + float dropout_probability, + void* philox_seed_ptr, void* philox_offset_ptr, + BiasType attn_bias_type, + MaskType attn_mask_type, + int64_t window_size_left, int64_t window_size_right, + void* o_ptr, + uint64_t stride_b_o, uint64_t stride_h_o, uint64_t stride_s_o, + void* lse_ptr, + bool uses_fwd_v3, + int how_v3_bf16_cvt, + hipStream_t stream){ + + return _ck_attn_fwd_impl( + dtype, + b, h, hg, s_q, s_kv, d_qk, d_v, bias_b, bias_h, + 0, + q_ptr, stride_b_q, stride_h_q, stride_s_q, + k_ptr, stride_b_k, stride_h_k, stride_s_k, + v_ptr, stride_b_v, stride_h_v, stride_s_v, + bias_ptr, + alibi_slope_ptr, + nullptr, nullptr, // cu_seqlen_q_ptr, cu_seqlen_kv_ptr, + nullptr, nullptr, // cu_seqlen_q_padded_ptr, cu_seqlen_kv_padded_ptr + is_training, + scaling_factor, + dropout_probability, + philox_seed_ptr, philox_offset_ptr, + attn_bias_type, + attn_mask_type, + window_size_left, window_size_right, + o_ptr, + stride_b_o, stride_h_o, stride_s_o, + lse_ptr, + uses_fwd_v3, + how_v3_bf16_cvt, + false, + __FUNCTION__, // func_name + stream + ); +} + hipError_t ck_attn_varlen_fwd( DType dtype, uint64_t b, uint64_t h, uint64_t hg, uint64_t s_q, uint64_t s_kv, uint64_t d_qk, uint64_t d_v, uint64_t max_tokens_q, - const void* q_ptr, + const void* q_ptr, uint64_t stride_h_q, uint64_t stride_s_q, - const void* k_ptr, + const void* k_ptr, uint64_t stride_h_k, uint64_t stride_s_k, - const void* v_ptr, + const void* v_ptr, uint64_t stride_h_v, uint64_t stride_s_v, const void* cu_seqlen_q_ptr, const void* cu_seqlen_kv_ptr, const void* cu_seqlen_q_padded_ptr, const void* cu_seqlen_kv_padded_ptr, @@ -310,172 +359,40 @@ hipError_t ck_attn_varlen_fwd( void* philox_seed_ptr, void* philox_offset_ptr, MaskType attn_mask_type, int64_t window_size_left, int64_t window_size_right, - void* o_ptr, + void* o_ptr, uint64_t stride_h_o, uint64_t stride_s_o, void* lse_thd_ptr, bool uses_fwd_v3, int how_v3_bf16_cvt, hipStream_t stream){ - bool has_dropout = (is_training && dropout_probability > 0.f); - bool has_lse = (lse_thd_ptr != nullptr); - - /* CK input parameters */ - ck_tile::index_t batch = b; - ck_tile::index_t nhead = h; - ck_tile::index_t hdim_q = d_qk; - ck_tile::index_t nhead_k = hg; - ck_tile::index_t hdim_v = d_v; - ck_tile::index_t max_seqlen_q = s_q; - ck_tile::index_t max_seqlen_kv = s_kv; - - float scale_s = scaling_factor; - float logits_soft_cap = 0.f; - float p_drop = dropout_probability; - bool is_group_mode = true; - bool is_v_rowmajor = true; - bool has_logits_soft_cap = 0.f < logits_soft_cap; - bool do_fp8_static_quant = false; - - // THD does not work with bias - - ck_tile::index_t left, right; - left = window_size_left; - right = window_size_right; - mask_enum mask_type = static_cast(attn_mask_type); - - bias_enum bias_type = bias_enum::no_bias; - - const char* dump_path = std::getenv("NVTE_DUMP_AITER_RT"); - // print kernel name on verbose mode - ck_tile::stream_config stream_config{stream, dump_path!=nullptr, get_ck_log_stream() != nullptr}; - - - std::string data_type_str = get_data_type_str(dtype); - - auto fmha_args = [&]() { - // setup stride_* arguments - const ck_tile::index_t stride_q = stride_s_q; - const ck_tile::index_t stride_k = stride_s_k; - const ck_tile::index_t stride_v = stride_s_v; - // bias not used in THD qkv layout - const ck_tile::index_t stride_bias = 0; - // randval not used - const ck_tile::index_t stride_randval = 0; - const ck_tile::index_t stride_o = stride_s_o; - // setup nhead_stride_* arguments - const ck_tile::index_t nhead_stride_q = stride_h_q; - const ck_tile::index_t nhead_stride_k = stride_h_k; - const ck_tile::index_t nhead_stride_v = stride_h_v; - // bias not used in THD qkv layout - const ck_tile::index_t nhead_stride_bias = 0; - //TODO: randval never used, can we remove it - const ck_tile::index_t nhead_stride_randval = 0; - // use packed lse of shape [h, max_tokens_q] - const ck_tile::index_t nhead_stride_lse = max_tokens_q; - const ck_tile::index_t nhead_stride_o = stride_h_o; - // setup batch_stride_* arguments - const ck_tile::index_t batch_stride_q = 0; - const ck_tile::index_t batch_stride_k = 0; - const ck_tile::index_t batch_stride_v = 0; - // bias not used in THD qkv layout - const ck_tile::index_t batch_stride_bias = 0; - //TODO: randval never used, can we remove it - const ck_tile::index_t batch_stride_randval = 0; - const ck_tile::index_t batch_stride_lse = 0; - const ck_tile::index_t batch_stride_o = 0; - - return fmha_fwd_args{q_ptr, - k_ptr, - v_ptr, - nullptr,//bias_ptr - nullptr, //q_descale_ptr - nullptr, //k_descale_ptr - nullptr, //v_descale_ptr - nullptr,//rand_val_ptr - lse_thd_ptr, - o_ptr, - cu_seqlen_q_padded_ptr==nullptr? cu_seqlen_q_ptr: cu_seqlen_q_padded_ptr, //seqstart_q_ptr - cu_seqlen_kv_padded_ptr==nullptr? cu_seqlen_kv_ptr: cu_seqlen_kv_padded_ptr, //seqstart_k_ptr - nullptr, //seqlen_q_ptr - nullptr, //seqlen_k_ptr - cu_seqlen_q_ptr, //cu_seqlen_q_ptr - cu_seqlen_kv_ptr, //cu_seqlen_k_ptr - max_seqlen_q, //seqlen_q, unused in group mode - max_seqlen_kv, //seqlen_kv, unused in group mode - batch, - max_seqlen_q, - hdim_q, - hdim_v, - nhead, - nhead_k, - scale_s, - logits_soft_cap, - stride_q, - stride_k, - stride_v, - stride_bias, - stride_randval, - stride_o, - nhead_stride_q, - nhead_stride_k, - nhead_stride_v, - nhead_stride_bias, - nhead_stride_randval, - nhead_stride_lse, - nhead_stride_o, - batch_stride_q, - batch_stride_k, - batch_stride_v, - batch_stride_bias, - batch_stride_randval, - batch_stride_lse, - batch_stride_o, - left, - right, - 0, // sink_size - static_cast(mask_type), - 0, // min_seqlen_q - p_drop, - false, - std::pair{philox_seed_ptr, philox_offset_ptr}}; - }(); - // modify the max_seqlen_q for better performance in 0-length cases - // lse_thd_ptr used as buffer - if(const char* env_p = std::getenv("NVTE_CK_RUNTIME_MAX_SEQLEN")){ - if(std::string(env_p) == "1"){ - if (auto* log_file = get_ck_log_stream()) { - *log_file - << "attn_fwd(ck): Enabling runtime max_seqlen calculation for small seqlen optimization.\n"; - } - fmha_args.max_seqlen_q = get_runtime_max_seqlen(b, cu_seqlen_q_ptr, cu_seqlen_q_padded_ptr, lse_thd_ptr, stream); - } - } - // print ck traits and args when needed - if (auto* log_file = get_ck_log_stream()) { - log_fwd_config(log_file, __FUNCTION__, data_type_str, is_group_mode, has_logits_soft_cap, mask_type, bias_type, has_lse, has_dropout, is_v_rowmajor, do_fp8_static_quant, uses_fwd_v3, how_v3_bf16_cvt, fmha_args); - } - - float average_runtime = aiter::mha_fwd( - fmha_args, - stream_config, - data_type_str, - is_group_mode, - mask_type, - bias_type, - has_lse, - quant_scale_enum::no_scale, - uses_fwd_v3, - false,//has_sink - how_v3_bf16_cvt); - if(dump_path){ - dump_fwd_timings(dump_path, average_runtime); - } - if(average_runtime < 0){ - //TODO: better error out system - throw std::runtime_error("fused attn configs not supported in ck_fused_attn fwd pass."); - } - return hipSuccess; + return _ck_attn_fwd_impl( + dtype, + b, h, hg, s_q, s_kv, d_qk, d_v, 0, 0, + max_tokens_q, + q_ptr, 0, stride_h_q, stride_s_q, + k_ptr, 0, stride_h_k, stride_s_k, + v_ptr, 0, stride_h_v, stride_s_v, + nullptr, // bias_ptr, + nullptr, // alibi_slope_ptr + cu_seqlen_q_ptr, cu_seqlen_kv_ptr, + cu_seqlen_q_padded_ptr, cu_seqlen_kv_padded_ptr, + is_training, + scaling_factor, + dropout_probability, + philox_seed_ptr, philox_offset_ptr, + BiasType::no_bias, + attn_mask_type, + window_size_left, window_size_right, + o_ptr, + 0, stride_h_o, stride_s_o, + lse_thd_ptr, + uses_fwd_v3, + how_v3_bf16_cvt, + true, + __FUNCTION__, // func_name + stream + ); } }//namespace ck_fused_attn diff --git a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_utils.cpp b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_utils.cpp index 72a0122d9..558fd5fdc 100644 --- a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_utils.cpp +++ b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_utils.cpp @@ -5,6 +5,12 @@ ************************************************************************/ #include +#include +#include +#include //once_flag + +#include + #include #include #include @@ -17,6 +23,89 @@ namespace ck_fused_attn{ +ck_tile::index_t get_batch_stride_bias( + ck_tile::index_t bias_h, + BiasShape bias_shape, + ck_tile::index_t max_seqlen_q, + ck_tile::index_t max_seqlen_k, + bool is_group_mode, + bool is_fwd +){ + if(is_group_mode){ + return 0; + } + switch (bias_shape) { + case BiasShape::k11SS: + case BiasShape::k1HSS: + return 0; + case BiasShape::kB1SS: + // dbias must be BHSS + if(is_fwd){ + return max_seqlen_q * max_seqlen_k; + } + case BiasShape::kBHSS: + return bias_h * max_seqlen_q * max_seqlen_k; + default: + throw std::runtime_error("Invalid bias shape"); + } +} +// for B1SS and BHSS, batch stride for bias are both +// bias_h x s_q x s_kv (bias_h==1 for B1SS and bias_h == h for BHSS) +ck_tile::index_t get_nhead_stride_bias( + BiasShape bias_shape, + ck_tile::index_t max_seqlen_q, + ck_tile::index_t max_seqlen_k, + bool is_group_mode +){ + if(is_group_mode){ + return 0; + } + switch (bias_shape) { + case BiasShape::k1HSS: + case BiasShape::kBHSS: + return max_seqlen_q * max_seqlen_k; + case BiasShape::k11SS: + case BiasShape::kB1SS: + return 0; + default: + throw std::runtime_error("Invalid bias shape"); + } +} + +void set_aiter_asm_dir() { + static std::once_flag aiter_asm_dir_once; + std::call_once(aiter_asm_dir_once, []() { + Dl_info info; + dladdr((void*)set_aiter_asm_dir, &info); + const char* log_ck_config_env = std::getenv("NVTE_LOG_CK_CONFIG"); + bool log_ck_config = log_ck_config_env && std::string(log_ck_config_env) == "1"; + // Check if user has set AITER_ASM_DIR, if yes, skip auto setting and log + // the value if NVTE_LOG_CK_CONFIG is set + const char* aiter_asm_dir = std::getenv("AITER_ASM_DIR"); + if (aiter_asm_dir) { + if (log_ck_config) { + std::cout << "AITER_ASM_DIR is set by user to: " << aiter_asm_dir << std::endl; + } + return; + } + // Check standard path + auto install_lib_path = std::filesystem::path(info.dli_fname).parent_path() / "aiter"; + if(std::filesystem::exists(install_lib_path)) { + setenv("AITER_ASM_DIR", install_lib_path.c_str(), 1); + if (log_ck_config) { + std::cout << "AITER_ASM_DIR set to: " << getenv("AITER_ASM_DIR") << std::endl; + } + return; + } + if(log_ck_config) { + std::cout << "Checked AITER_ASM_DIR path: " << install_lib_path << " does not exist." << std::endl; + } + }); +} + + +const bool aiterAsmDirInitialized = (set_aiter_asm_dir(), true); + bool open_ck_fused_attn_log_file(std::ofstream& log_file, const char* file_prefix, const std::string& log_dir_str) { // Explicitly use std::cout as a fallback std::filesystem::path log_dir(log_dir_str); diff --git a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_utils.hpp b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_utils.hpp index 3270ec32f..4056e5962 100644 --- a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_utils.hpp +++ b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_utils.hpp @@ -9,6 +9,7 @@ #include #include +#include "ck_tile/host.hpp" //forward declaration for ck_tile enum enum class mask_enum; @@ -55,6 +56,27 @@ std::pair get_ck_bias_type_shape(BiasType attn_bias_type, uint64_t get_runtime_max_seqlen(uint64_t b, const void* cu_seqlen_ptr, const void* cu_seqlen_padded_ptr, void* workspace, hipStream_t stream); +// This helper merely standardizes the logging to make it a bit easier to parse +// through it at a glance while guaranteeing uniformity. +template +void log_value(std::ostream* log_file, const char* label, const T& value) { + (*log_file) << label << ": " << value << "\n"; +} + +ck_tile::index_t get_batch_stride_bias( + ck_tile::index_t bias_h, + BiasShape bias_shape, + ck_tile::index_t max_seqlen_q, + ck_tile::index_t max_seqlen_k, + bool is_group_mode, + bool is_fwd +); +ck_tile::index_t get_nhead_stride_bias( + BiasShape bias_shape, + ck_tile::index_t max_seqlen_q, + ck_tile::index_t max_seqlen_k, + bool is_group_mode +); std::ostream* get_ck_log_stream(); }//namespace ck_fused_attn diff --git a/transformer_engine/common/gemm/ck_grouped_gemm.cpp b/transformer_engine/common/gemm/ck_grouped_gemm.cpp index 61c08f0de..69ff4f053 100644 --- a/transformer_engine/common/gemm/ck_grouped_gemm.cpp +++ b/transformer_engine/common/gemm/ck_grouped_gemm.cpp @@ -110,7 +110,7 @@ struct Runner{ Partitioner::MPerBlock, Partitioner::NPerBlock, TileCfg::M_Warp, TileCfg::N_Warp, TileCfg::M_Warp_Tile, TileCfg::N_Warp_Tile, TileCfg::K_Warp_Tile, - Problem::TransposeC, MemOp>>; + Problem::TransposeC>>; using Kernel = ck_tile::GroupedGemmKernel; }; @@ -131,7 +131,8 @@ static bool run_grouped_impl(const NVTETensor* A_use, const size_t needed = Kernel::GetWorkSpaceSize(group_num); if (!workspace || workspace_bytes < needed) { - NVTE_ERROR("ck_tile_grouped_gemm: insufficient workspace. Needed bytes=", needed); + NVTE_WARN("ck_tile_grouped_gemm: insufficient workspace for CK path. Needed bytes=", needed, + ", available bytes=", workspace_bytes, ". Falling back."); return false; } @@ -197,7 +198,8 @@ static bool run_grouped_impl(const NVTETensor* A_use, const dim3 grids = Kernel::GridSize(descs); auto kargs = Kernel::MakeKargs(descs); if (!Kernel::IsSupportedArgument(kargs)) { - NVTE_ERROR("ck_tile_grouped_gemm: CK_Tile kernel arguments not supported for this config."); + NVTE_WARN("ck_tile_grouped_gemm: CK_Tile kernel arguments not supported for this config. " + "Falling back."); return false; } @@ -235,6 +237,12 @@ bool ck_tile_grouped_gemm(const NVTETensor* A, if (group_num <= 0) return true; + // The current CK grouped GEMM path uses CShuffleEpilogueProblem without an explicit + // memory-operation template argument, so D accumulation semantics are not guaranteed. + // Fall back for accumulate=true to preserve numerics. + if (accumulate) + return false; + using namespace transformer_engine; using namespace transformer_engine::grouped_gemm;