@@ -65,46 +65,19 @@ __device__ __forceinline__ size_t get_current_tensor_id(
template <typename IType, int kHadamardDimension, int BUFF_DIM_Y, int BUFF_DIM_X,
bool kReturnPreRhtAmax, bool kReturnIdentityAmax, bool kReturnTransposedAmax>
__device__ __forceinline__ void ComputeKernel(uint32_t b_frag_i[4], uint32_t b_frag_t[4],
IType* in_sh_ptr, uint32_t& local_pre_rht_amax_reg,
IType* in_sh_ptr, int swizzle_idx,
uint32_t& local_pre_rht_amax_reg,
uint32_t& local_amax_reg,
uint32_t& local_amax_t_reg) {
uint32_t a_frag[4]; // A matrix fragment
uint32_t c_frag[4]; // Result fragment

int warp_id = threadIdx.x / kThreadsPerWarp;
int local_rank = (threadIdx.x % kThreadsPerWarp);

int ld_row_idx = local_rank % kHadamardDimension;
int ld_col_idx = local_rank / kHadamardDimension + warp_id * 2;
int swizzle_idx = swizzle_128B_atom_32B(ld_row_idx, ld_col_idx);

uint32_t temp_amax_reg;
uint32_t temp_amax_t_reg;

if (kReturnIdentityAmax) {
ldmatrix_x4_m8n8_shared_b16<false>(a_frag[0], a_frag[1], a_frag[2], a_frag[3],
reinterpret_cast<uint4*>(in_sh_ptr) + swizzle_idx);

mma_m16_n16_k16_b16_b16_b16_noacc<kReturnIdentityAmax>(
a_frag[0], a_frag[1], a_frag[2], a_frag[3], b_frag_i[0], b_frag_i[1], b_frag_i[2],
b_frag_i[3], c_frag[0], c_frag[1], c_frag[2], c_frag[3], temp_amax_reg);
asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;\n\t"
: "=r"(local_amax_reg)
: "r"(local_amax_reg), "r"(temp_amax_reg));
}

if (kReturnTransposedAmax) {
// TODO(Frank): This is not efficient, since we could directly load the
// matrix in transposed layout.
if (!kReturnIdentityAmax) {
ldmatrix_x4_m8n8_shared_b16<false>(a_frag[0], a_frag[1], a_frag[2], a_frag[3],
reinterpret_cast<uint4*>(in_sh_ptr) + swizzle_idx);
}

matrix_transpose_m8_n8_b16_inplace(a_frag[0]);
matrix_transpose_m8_n8_b16_inplace(a_frag[1]);
matrix_transpose_m8_n8_b16_inplace(a_frag[2]);
matrix_transpose_m8_n8_b16_inplace(a_frag[3]);
ldmatrix_x4_m8n8_shared_b16<true>(a_frag[0], a_frag[1], a_frag[2], a_frag[3],
reinterpret_cast<uint4*>(in_sh_ptr) + swizzle_idx);

mma_m16_n16_k16_b16_b16_b16_noacc<kReturnTransposedAmax>(
a_frag[0], a_frag[2], a_frag[1], a_frag[3], b_frag_t[0], b_frag_t[1], b_frag_t[2],
@@ -115,7 +88,7 @@ __device__ __forceinline__ void ComputeKernel(uint32_t b_frag_i[4], uint32_t b_f
}

if (kReturnPreRhtAmax) {
if (!kReturnIdentityAmax && !kReturnTransposedAmax) {
if (!kReturnTransposedAmax) {
ldmatrix_x4_m8n8_shared_b16<false>(a_frag[0], a_frag[1], a_frag[2], a_frag[3],
reinterpret_cast<uint4*>(in_sh_ptr) + swizzle_idx);
}
@@ -133,6 +106,18 @@ __device__ __forceinline__ void ComputeKernel(uint32_t b_frag_i[4], uint32_t b_f
: "=r"(local_pre_rht_amax_reg)
: "r"(a_frag[0]), "r"(local_pre_rht_amax_reg));
}

if (kReturnIdentityAmax) {
ldmatrix_x4_m8n8_shared_b16<false>(a_frag[0], a_frag[1], a_frag[2], a_frag[3],
reinterpret_cast<uint4*>(in_sh_ptr) + swizzle_idx);

mma_m16_n16_k16_b16_b16_b16_noacc<kReturnIdentityAmax>(
a_frag[0], a_frag[1], a_frag[2], a_frag[3], b_frag_i[0], b_frag_i[1], b_frag_i[2],
b_frag_i[3], c_frag[0], c_frag[1], c_frag[2], c_frag[3], temp_amax_reg);
asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;\n\t"
: "=r"(local_amax_reg)
: "r"(local_amax_reg), "r"(temp_amax_reg));
}
}

template <int kN>
@@ -322,6 +307,12 @@ __global__ void GraphSafeGroupHadamardAmaxTmaKernel(
uint32_t local_amax_reg = *reinterpret_cast<uint32_t*>(&local_amax);
uint32_t local_amax_t_reg = *reinterpret_cast<uint32_t*>(&local_amax_t);

const int warp_id = threadIdx.x / kThreadsPerWarp;
const int local_rank = threadIdx.x % kThreadsPerWarp;
const int ld_row_idx = local_rank % kHadamardDimension;
const int ld_col_idx = local_rank / kHadamardDimension + warp_id * 2;
const int swizzle_idx = swizzle_128B_atom_32B(ld_row_idx, ld_col_idx);

for (int stage_y = 0; stage_y < STAGES_Y; ++stage_y) {
for (int stage_x = 0; stage_x < STAGES_X; ++stage_x) {
int stage = STAGES_X * stage_y + stage_x;
@@ -364,7 +355,7 @@ __global__ void GraphSafeGroupHadamardAmaxTmaKernel(
had_frag_i, had_frag_t,
in_sh_ptr + in_row_offset +
(compute_stage_x * kHadamardDimension * (THREADS_PER_CHUNK / kThreadsPerWarp)),
local_pre_rht_amax_reg, local_amax_reg, local_amax_t_reg);
swizzle_idx, local_pre_rht_amax_reg, local_amax_reg, local_amax_t_reg);
}

// Ensure all threads have finished their computation before new data over-writes the shared
@@ -41,46 +41,19 @@ constexpr int kThreadsPerWarp = 32;
template <typename IType, int kHadamardDimension, int BUFF_DIM_Y, int BUFF_DIM_X,
bool kReturnPreRhtAmax, bool kReturnIdentityAmax, bool kReturnTransposedAmax>
__device__ __forceinline__ void ComputeKernel(uint32_t b_frag_i[4], uint32_t b_frag_t[4],
IType* in_sh_ptr, uint32_t& local_pre_rht_amax_reg,
IType* in_sh_ptr, int swizzle_idx,
uint32_t& local_pre_rht_amax_reg,
uint32_t& local_amax_reg,
uint32_t& local_amax_t_reg) {
uint32_t a_frag[4]; // A matrix fragment
uint32_t c_frag[4]; // Result fragment

int warp_id = threadIdx.x / kThreadsPerWarp;
int local_rank = (threadIdx.x % kThreadsPerWarp);

int ld_row_idx = local_rank % kHadamardDimension;
int ld_col_idx = local_rank / kHadamardDimension + warp_id * 2;
int swizzle_idx = swizzle_128B_atom_32B(ld_row_idx, ld_col_idx);

uint32_t temp_amax_reg;
uint32_t temp_amax_t_reg;

if (kReturnIdentityAmax) {
ldmatrix_x4_m8n8_shared_b16<false>(a_frag[0], a_frag[1], a_frag[2], a_frag[3],
reinterpret_cast<uint4*>(in_sh_ptr) + swizzle_idx);

mma_m16_n16_k16_b16_b16_b16_noacc<kReturnIdentityAmax>(
a_frag[0], a_frag[1], a_frag[2], a_frag[3], b_frag_i[0], b_frag_i[1], b_frag_i[2],
b_frag_i[3], c_frag[0], c_frag[1], c_frag[2], c_frag[3], temp_amax_reg);
asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;\n\t"
: "=r"(local_amax_reg)
: "r"(local_amax_reg), "r"(temp_amax_reg));
}

if (kReturnTransposedAmax) {
// TODO(Frank): This is not efficient, since we could directly load the
// matrix in transposed layout.
if (!kReturnIdentityAmax) {
ldmatrix_x4_m8n8_shared_b16<false>(a_frag[0], a_frag[1], a_frag[2], a_frag[3],
reinterpret_cast<uint4*>(in_sh_ptr) + swizzle_idx);
}

matrix_transpose_m8_n8_b16_inplace(a_frag[0]);
matrix_transpose_m8_n8_b16_inplace(a_frag[1]);
matrix_transpose_m8_n8_b16_inplace(a_frag[2]);
matrix_transpose_m8_n8_b16_inplace(a_frag[3]);
ldmatrix_x4_m8n8_shared_b16<true>(a_frag[0], a_frag[1], a_frag[2], a_frag[3],
reinterpret_cast<uint4*>(in_sh_ptr) + swizzle_idx);

mma_m16_n16_k16_b16_b16_b16_noacc<kReturnTransposedAmax>(
a_frag[0], a_frag[2], a_frag[1], a_frag[3], b_frag_t[0], b_frag_t[1], b_frag_t[2],
@@ -91,7 +64,7 @@ __device__ __forceinline__ void ComputeKernel(uint32_t b_frag_i[4], uint32_t b_f
}

if (kReturnPreRhtAmax) {
if (!kReturnIdentityAmax && !kReturnTransposedAmax) {
if (!kReturnTransposedAmax) {
ldmatrix_x4_m8n8_shared_b16<false>(a_frag[0], a_frag[1], a_frag[2], a_frag[3],
reinterpret_cast<uint4*>(in_sh_ptr) + swizzle_idx);
}
@@ -109,6 +82,18 @@ __device__ __forceinline__ void ComputeKernel(uint32_t b_frag_i[4], uint32_t b_f
: "=r"(local_pre_rht_amax_reg)
: "r"(a_frag[0]), "r"(local_pre_rht_amax_reg));
}

if (kReturnIdentityAmax) {
ldmatrix_x4_m8n8_shared_b16<false>(a_frag[0], a_frag[1], a_frag[2], a_frag[3],
reinterpret_cast<uint4*>(in_sh_ptr) + swizzle_idx);

mma_m16_n16_k16_b16_b16_b16_noacc<kReturnIdentityAmax>(
a_frag[0], a_frag[1], a_frag[2], a_frag[3], b_frag_i[0], b_frag_i[1], b_frag_i[2],
b_frag_i[3], c_frag[0], c_frag[1], c_frag[2], c_frag[3], temp_amax_reg);
asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;\n\t"
: "=r"(local_amax_reg)
: "r"(local_amax_reg), "r"(temp_amax_reg));
}
}

template <int kN>
@@ -305,6 +290,12 @@ __global__ void GroupHadamardAmaxTmaKernel(const __grid_constant__ CUtensorMap t
uint32_t local_amax_reg = *reinterpret_cast<uint32_t*>(&local_amax);
uint32_t local_amax_t_reg = *reinterpret_cast<uint32_t*>(&local_amax_t);

const int warp_id = threadIdx.x / kThreadsPerWarp;
const int local_rank = threadIdx.x % kThreadsPerWarp;
const int ld_row_idx = local_rank % kHadamardDimension;
const int ld_col_idx = local_rank / kHadamardDimension + warp_id * 2;
const int swizzle_idx = swizzle_128B_atom_32B(ld_row_idx, ld_col_idx);

for (int stage_y = 0; stage_y < STAGES_Y; ++stage_y) {
for (int stage_x = 0; stage_x < STAGES_X; ++stage_x) {
int stage = STAGES_X * stage_y + stage_x;
@@ -347,7 +338,7 @@ __global__ void GroupHadamardAmaxTmaKernel(const __grid_constant__ CUtensorMap t
had_frag_i, had_frag_t,
in_sh_ptr + in_row_offset +
(compute_stage_x * kHadamardDimension * (THREADS_PER_CHUNK / kThreadsPerWarp)),
local_pre_rht_amax_reg, local_amax_reg, local_amax_t_reg);
swizzle_idx, local_pre_rht_amax_reg, local_amax_reg, local_amax_t_reg);
}

// Ensure all threads have finished their computation before new data over-writes the shared
57 changes: 24 additions & 33 deletions transformer_engine/common/hadamard_transform/hadamard_transform.cu
@@ -26,46 +26,19 @@ constexpr int kThreadsPerWarp = 32;
template <typename IType, int kHadamardDimension, int BUFF_DIM_Y, int BUFF_DIM_X,
bool kReturnPreRhtAmax, bool kReturnIdentityAmax, bool kReturnTransposedAmax>
__device__ __forceinline__ void ComputeKernel(uint32_t b_frag_i[4], uint32_t b_frag_t[4],
IType* in_sh_ptr, uint32_t& local_pre_rht_amax_reg,
IType* in_sh_ptr, int swizzle_idx,
uint32_t& local_pre_rht_amax_reg,
uint32_t& local_amax_reg,
uint32_t& local_amax_t_reg) {
uint32_t a_frag[4]; // A matrix fragment
uint32_t c_frag[4]; // Result fragment

int warp_id = threadIdx.x / kThreadsPerWarp;
int local_rank = (threadIdx.x % kThreadsPerWarp);

int ld_row_idx = local_rank % kHadamardDimension;
int ld_col_idx = local_rank / kHadamardDimension + warp_id * 2;
int swizzle_idx = swizzle_128B_atom_32B(ld_row_idx, ld_col_idx);

uint32_t temp_amax_reg;
uint32_t temp_amax_t_reg;

if (kReturnIdentityAmax) {
ldmatrix_x4_m8n8_shared_b16<false>(a_frag[0], a_frag[1], a_frag[2], a_frag[3],
reinterpret_cast<uint4*>(in_sh_ptr) + swizzle_idx);

mma_m16_n16_k16_b16_b16_b16_noacc<kReturnIdentityAmax>(
a_frag[0], a_frag[1], a_frag[2], a_frag[3], b_frag_i[0], b_frag_i[1], b_frag_i[2],
b_frag_i[3], c_frag[0], c_frag[1], c_frag[2], c_frag[3], temp_amax_reg);
asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;\n\t"
: "=r"(local_amax_reg)
: "r"(local_amax_reg), "r"(temp_amax_reg));
}

if (kReturnTransposedAmax) {
// TODO(Frank): This is not efficient, since we could directly load the
// matrix in transposed layout.
if (!kReturnIdentityAmax) {
ldmatrix_x4_m8n8_shared_b16<false>(a_frag[0], a_frag[1], a_frag[2], a_frag[3],
reinterpret_cast<uint4*>(in_sh_ptr) + swizzle_idx);
}

matrix_transpose_m8_n8_b16_inplace(a_frag[0]);
matrix_transpose_m8_n8_b16_inplace(a_frag[1]);
matrix_transpose_m8_n8_b16_inplace(a_frag[2]);
matrix_transpose_m8_n8_b16_inplace(a_frag[3]);
ldmatrix_x4_m8n8_shared_b16<true>(a_frag[0], a_frag[1], a_frag[2], a_frag[3],
reinterpret_cast<uint4*>(in_sh_ptr) + swizzle_idx);

mma_m16_n16_k16_b16_b16_b16_noacc<kReturnTransposedAmax>(
a_frag[0], a_frag[2], a_frag[1], a_frag[3], b_frag_t[0], b_frag_t[1], b_frag_t[2],
@@ -76,7 +49,7 @@ __device__ __forceinline__ void ComputeKernel(uint32_t b_frag_i[4], uint32_t b_f
}

if (kReturnPreRhtAmax) {
if (!kReturnIdentityAmax && !kReturnTransposedAmax) {
if (!kReturnTransposedAmax) {
ldmatrix_x4_m8n8_shared_b16<false>(a_frag[0], a_frag[1], a_frag[2], a_frag[3],
reinterpret_cast<uint4*>(in_sh_ptr) + swizzle_idx);
}
@@ -94,6 +67,18 @@ __device__ __forceinline__ void ComputeKernel(uint32_t b_frag_i[4], uint32_t b_f
: "=r"(local_pre_rht_amax_reg)
: "r"(a_frag[0]), "r"(local_pre_rht_amax_reg));
}

if (kReturnIdentityAmax) {
ldmatrix_x4_m8n8_shared_b16<false>(a_frag[0], a_frag[1], a_frag[2], a_frag[3],
reinterpret_cast<uint4*>(in_sh_ptr) + swizzle_idx);

mma_m16_n16_k16_b16_b16_b16_noacc<kReturnIdentityAmax>(
a_frag[0], a_frag[1], a_frag[2], a_frag[3], b_frag_i[0], b_frag_i[1], b_frag_i[2],
b_frag_i[3], c_frag[0], c_frag[1], c_frag[2], c_frag[3], temp_amax_reg);
asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;\n\t"
: "=r"(local_amax_reg)
: "r"(local_amax_reg), "r"(temp_amax_reg));
}
}

template <int kN>
@@ -248,6 +233,12 @@ __global__ void HadamardAmaxTmaKernel(const __grid_constant__ CUtensorMap tensor
uint32_t local_amax_reg = *reinterpret_cast<uint32_t*>(&local_amax);
uint32_t local_amax_t_reg = *reinterpret_cast<uint32_t*>(&local_amax_t);

const int warp_id = threadIdx.x / kThreadsPerWarp;
const int local_rank = threadIdx.x % kThreadsPerWarp;
const int ld_row_idx = local_rank % kHadamardDimension;
const int ld_col_idx = local_rank / kHadamardDimension + warp_id * 2;
const int swizzle_idx = swizzle_128B_atom_32B(ld_row_idx, ld_col_idx);

for (int stage_y = 0; stage_y < STAGES_Y; ++stage_y) {
for (int stage_x = 0; stage_x < STAGES_X; ++stage_x) {
int stage = STAGES_X * stage_y + stage_x;
@@ -290,7 +281,7 @@ __global__ void HadamardAmaxTmaKernel(const __grid_constant__ CUtensorMap tensor
had_frag_i, had_frag_t,
in_sh_ptr + in_row_offset +
(compute_stage_x * kHadamardDimension * (THREADS_PER_CHUNK / kThreadsPerWarp)),
local_pre_rht_amax_reg, local_amax_reg, local_amax_t_reg);
swizzle_idx, local_pre_rht_amax_reg, local_amax_reg, local_amax_t_reg);
}

// Ensure all threads have finished their computation before new data over-writes the shared
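Across all three kernels (GraphSafeGroupHadamardAmaxTmaKernel, GroupHadamardAmaxTmaKernel, HadamardAmaxTmaKernel) the hunks make the same change: the per-thread swizzled shared-memory index (`swizzle_idx`) is computed once in the kernel body and passed into `ComputeKernel`, rather than being rederived on every pipeline stage, and the transposed-amax path now loads its fragments with the transposed `ldmatrix` variant instead of transposing registers in place. The sketch below illustrates only the hoisting pattern in a self-contained form; the kernel, `toy_swizzle`, and all constants are illustrative stand-ins, not the Transformer Engine code.

```cuda
// Illustrative sketch only (not the Transformer Engine kernels): a loop-invariant,
// per-thread offset is computed once in the kernel body and handed to the
// per-stage helper, mirroring how swizzle_idx is now passed into ComputeKernel.
#include <cstdio>
#include <vector>
#include <cuda_runtime.h>

constexpr int kThreadsPerWarp = 32;
constexpr int kStages = 4;       // stand-in for STAGES_X * STAGES_Y
constexpr int kTileElems = 256;  // stand-in for one staging buffer

// Stand-in for swizzle_128B_atom_32B: any pure function of (row, col).
__device__ __forceinline__ int toy_swizzle(int row, int col) {
  return (row ^ (col & 7)) + col * 16;
}

// The helper no longer reads threadIdx.x to rebuild the offset; it receives it.
__device__ __forceinline__ void ComputeStage(const float* tile, int swizzle_idx,
                                             float& local_amax) {
  local_amax = fmaxf(local_amax, fabsf(tile[swizzle_idx]));
}

__global__ void ToyAmaxKernel(const float* in, float* out) {
  __shared__ float tile[kTileElems];

  // Hoisted: computed once per thread, outside the stage loop.
  const int warp_id = threadIdx.x / kThreadsPerWarp;
  const int lane = threadIdx.x % kThreadsPerWarp;
  const int swizzle_idx = toy_swizzle(lane % 16, lane / 16 + warp_id * 2);

  float local_amax = 0.0f;
  for (int stage = 0; stage < kStages; ++stage) {
    // Plain copy stands in for the TMA load into the staging buffer.
    for (int i = threadIdx.x; i < kTileElems; i += blockDim.x) {
      tile[i] = in[stage * kTileElems + i];
    }
    __syncthreads();
    // Each thread reduces over the elements it visits across stages.
    ComputeStage(tile, swizzle_idx, local_amax);
    __syncthreads();  // buffer is overwritten by the next stage
  }
  // Values are non-negative, so an int-view atomicMax preserves float ordering.
  atomicMax(reinterpret_cast<int*>(out), __float_as_int(local_amax));
}

int main() {
  const int n = kStages * kTileElems;
  float *d_in, *d_out;
  cudaMalloc(&d_in, n * sizeof(float));
  cudaMalloc(&d_out, sizeof(float));
  cudaMemset(d_out, 0, sizeof(float));

  std::vector<float> h_in(n);
  for (int i = 0; i < n; ++i) h_in[i] = 0.01f * i - 5.0f;
  cudaMemcpy(d_in, h_in.data(), n * sizeof(float), cudaMemcpyHostToDevice);

  ToyAmaxKernel<<<1, 64>>>(d_in, d_out);
  float h_out = 0.0f;
  cudaMemcpy(&h_out, d_out, sizeof(float), cudaMemcpyDeviceToHost);
  printf("amax over visited elements = %f\n", h_out);

  cudaFree(d_in);
  cudaFree(d_out);
  return 0;
}
```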