diff --git a/transformer_engine/common/hadamard_transform/graph_safe_group_hadamard_transform.cu b/transformer_engine/common/hadamard_transform/graph_safe_group_hadamard_transform.cu
index 04e965a9da..3ef14c6b11 100644
--- a/transformer_engine/common/hadamard_transform/graph_safe_group_hadamard_transform.cu
+++ b/transformer_engine/common/hadamard_transform/graph_safe_group_hadamard_transform.cu
@@ -65,46 +65,19 @@ __device__ __forceinline__ size_t get_current_tensor_id(
 template
 __device__ __forceinline__ void ComputeKernel(uint32_t b_frag_i[4], uint32_t b_frag_t[4],
-                                              IType* in_sh_ptr, uint32_t& local_pre_rht_amax_reg,
+                                              IType* in_sh_ptr, int swizzle_idx,
+                                              uint32_t& local_pre_rht_amax_reg,
                                               uint32_t& local_amax_reg, uint32_t& local_amax_t_reg) {
   uint32_t a_frag[4];  // A matrix fragment
   uint32_t c_frag[4];  // Result fragment

-  int warp_id = threadIdx.x / kThreadsPerWarp;
-  int local_rank = (threadIdx.x % kThreadsPerWarp);
-
-  int ld_row_idx = local_rank % kHadamardDimension;
-  int ld_col_idx = local_rank / kHadamardDimension + warp_id * 2;
-  int swizzle_idx = swizzle_128B_atom_32B(ld_row_idx, ld_col_idx);
-
   uint32_t temp_amax_reg;
   uint32_t temp_amax_t_reg;

-  if (kReturnIdentityAmax) {
-    ldmatrix_x4_m8n8_shared_b16(a_frag[0], a_frag[1], a_frag[2], a_frag[3],
-                                reinterpret_cast(in_sh_ptr) + swizzle_idx);
-
-    mma_m16_n16_k16_b16_b16_b16_noacc(
-        a_frag[0], a_frag[1], a_frag[2], a_frag[3], b_frag_i[0], b_frag_i[1], b_frag_i[2],
-        b_frag_i[3], c_frag[0], c_frag[1], c_frag[2], c_frag[3], temp_amax_reg);
-    asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;\n\t"
-                 : "=r"(local_amax_reg)
-                 : "r"(local_amax_reg), "r"(temp_amax_reg));
-  }
-
   if (kReturnTransposedAmax) {
-    // TODO(Frank): This is not efficient, since we could directly load the
-    // matrix in transposed layout.
-    if (!kReturnIdentityAmax) {
-      ldmatrix_x4_m8n8_shared_b16(a_frag[0], a_frag[1], a_frag[2], a_frag[3],
-                                  reinterpret_cast(in_sh_ptr) + swizzle_idx);
-    }
-
-    matrix_transpose_m8_n8_b16_inplace(a_frag[0]);
-    matrix_transpose_m8_n8_b16_inplace(a_frag[1]);
-    matrix_transpose_m8_n8_b16_inplace(a_frag[2]);
-    matrix_transpose_m8_n8_b16_inplace(a_frag[3]);
+    ldmatrix_x4_m8n8_shared_b16(a_frag[0], a_frag[1], a_frag[2], a_frag[3],
+                                reinterpret_cast(in_sh_ptr) + swizzle_idx);

     mma_m16_n16_k16_b16_b16_b16_noacc(
         a_frag[0], a_frag[2], a_frag[1], a_frag[3], b_frag_t[0], b_frag_t[1], b_frag_t[2],
@@ -115,7 +88,7 @@ __device__ __forceinline__ void ComputeKernel(uint32_t b_frag_i[4], uint32_t b_f
   }

   if (kReturnPreRhtAmax) {
-    if (!kReturnIdentityAmax && !kReturnTransposedAmax) {
+    if (!kReturnTransposedAmax) {
       ldmatrix_x4_m8n8_shared_b16(a_frag[0], a_frag[1], a_frag[2], a_frag[3],
                                   reinterpret_cast(in_sh_ptr) + swizzle_idx);
     }
@@ -133,6 +106,18 @@ __device__ __forceinline__ void ComputeKernel(uint32_t b_frag_i[4], uint32_t b_f
                  : "=r"(local_pre_rht_amax_reg)
                  : "r"(a_frag[0]), "r"(local_pre_rht_amax_reg));
   }
+
+  if (kReturnIdentityAmax) {
+    ldmatrix_x4_m8n8_shared_b16(a_frag[0], a_frag[1], a_frag[2], a_frag[3],
+                                reinterpret_cast(in_sh_ptr) + swizzle_idx);
+
+    mma_m16_n16_k16_b16_b16_b16_noacc(
+        a_frag[0], a_frag[1], a_frag[2], a_frag[3], b_frag_i[0], b_frag_i[1], b_frag_i[2],
+        b_frag_i[3], c_frag[0], c_frag[1], c_frag[2], c_frag[3], temp_amax_reg);
+    asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;\n\t"
+                 : "=r"(local_amax_reg)
+                 : "r"(local_amax_reg), "r"(temp_amax_reg));
+  }
 }

 template
@@ -322,6 +307,12 @@ __global__ void GraphSafeGroupHadamardAmaxTmaKernel(
   uint32_t local_amax_reg = *reinterpret_cast(&local_amax);
   uint32_t local_amax_t_reg = *reinterpret_cast(&local_amax_t);

+  const int warp_id = threadIdx.x / kThreadsPerWarp;
+  const int local_rank = threadIdx.x % kThreadsPerWarp;
+  const int ld_row_idx = local_rank % kHadamardDimension;
+  const int ld_col_idx = local_rank / kHadamardDimension + warp_id * 2;
+  const int swizzle_idx = swizzle_128B_atom_32B(ld_row_idx, ld_col_idx);
+
   for (int stage_y = 0; stage_y < STAGES_Y; ++stage_y) {
     for (int stage_x = 0; stage_x < STAGES_X; ++stage_x) {
       int stage = STAGES_X * stage_y + stage_x;
@@ -364,7 +355,7 @@ __global__ void GraphSafeGroupHadamardAmaxTmaKernel(
           had_frag_i, had_frag_t,
           in_sh_ptr + in_row_offset +
              (compute_stage_x * kHadamardDimension * (THREADS_PER_CHUNK / kThreadsPerWarp)),
-          local_pre_rht_amax_reg, local_amax_reg, local_amax_t_reg);
+          swizzle_idx, local_pre_rht_amax_reg, local_amax_reg, local_amax_t_reg);
     }

     // Ensure all threads have finished their computation before new data over-writes the shared
diff --git a/transformer_engine/common/hadamard_transform/group_hadamard_transform.cu b/transformer_engine/common/hadamard_transform/group_hadamard_transform.cu
index 07813be059..56297814f5 100644
--- a/transformer_engine/common/hadamard_transform/group_hadamard_transform.cu
+++ b/transformer_engine/common/hadamard_transform/group_hadamard_transform.cu
@@ -41,46 +41,19 @@ constexpr int kThreadsPerWarp = 32;
 template
 __device__ __forceinline__ void ComputeKernel(uint32_t b_frag_i[4], uint32_t b_frag_t[4],
-                                              IType* in_sh_ptr, uint32_t& local_pre_rht_amax_reg,
+                                              IType* in_sh_ptr, int swizzle_idx,
+                                              uint32_t& local_pre_rht_amax_reg,
                                               uint32_t& local_amax_reg, uint32_t& local_amax_t_reg) {
   uint32_t a_frag[4];  // A matrix fragment
   uint32_t c_frag[4];  // Result fragment

-  int warp_id = threadIdx.x / kThreadsPerWarp;
-  int local_rank = (threadIdx.x % kThreadsPerWarp);
-
-  int ld_row_idx = local_rank % kHadamardDimension;
-  int ld_col_idx = local_rank / kHadamardDimension + warp_id * 2;
-  int swizzle_idx = swizzle_128B_atom_32B(ld_row_idx, ld_col_idx);
-
   uint32_t temp_amax_reg;
   uint32_t temp_amax_t_reg;

-  if (kReturnIdentityAmax) {
-    ldmatrix_x4_m8n8_shared_b16(a_frag[0], a_frag[1], a_frag[2], a_frag[3],
-                                reinterpret_cast(in_sh_ptr) + swizzle_idx);
-
-    mma_m16_n16_k16_b16_b16_b16_noacc(
-        a_frag[0], a_frag[1], a_frag[2], a_frag[3], b_frag_i[0], b_frag_i[1], b_frag_i[2],
-        b_frag_i[3], c_frag[0], c_frag[1], c_frag[2], c_frag[3], temp_amax_reg);
-    asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;\n\t"
-                 : "=r"(local_amax_reg)
-                 : "r"(local_amax_reg), "r"(temp_amax_reg));
-  }
-
   if (kReturnTransposedAmax) {
-    // TODO(Frank): This is not efficient, since we could directly load the
-    // matrix in transposed layout.
-    if (!kReturnIdentityAmax) {
-      ldmatrix_x4_m8n8_shared_b16(a_frag[0], a_frag[1], a_frag[2], a_frag[3],
-                                  reinterpret_cast(in_sh_ptr) + swizzle_idx);
-    }
-
-    matrix_transpose_m8_n8_b16_inplace(a_frag[0]);
-    matrix_transpose_m8_n8_b16_inplace(a_frag[1]);
-    matrix_transpose_m8_n8_b16_inplace(a_frag[2]);
-    matrix_transpose_m8_n8_b16_inplace(a_frag[3]);
+    ldmatrix_x4_m8n8_shared_b16(a_frag[0], a_frag[1], a_frag[2], a_frag[3],
+                                reinterpret_cast(in_sh_ptr) + swizzle_idx);

     mma_m16_n16_k16_b16_b16_b16_noacc(
         a_frag[0], a_frag[2], a_frag[1], a_frag[3], b_frag_t[0], b_frag_t[1], b_frag_t[2],
@@ -91,7 +64,7 @@ __device__ __forceinline__ void ComputeKernel(uint32_t b_frag_i[4], uint32_t b_f
   }

   if (kReturnPreRhtAmax) {
-    if (!kReturnIdentityAmax && !kReturnTransposedAmax) {
+    if (!kReturnTransposedAmax) {
       ldmatrix_x4_m8n8_shared_b16(a_frag[0], a_frag[1], a_frag[2], a_frag[3],
                                   reinterpret_cast(in_sh_ptr) + swizzle_idx);
     }
@@ -109,6 +82,18 @@ __device__ __forceinline__ void ComputeKernel(uint32_t b_frag_i[4], uint32_t b_f
                  : "=r"(local_pre_rht_amax_reg)
                  : "r"(a_frag[0]), "r"(local_pre_rht_amax_reg));
   }
+
+  if (kReturnIdentityAmax) {
+    ldmatrix_x4_m8n8_shared_b16(a_frag[0], a_frag[1], a_frag[2], a_frag[3],
+                                reinterpret_cast(in_sh_ptr) + swizzle_idx);
+
+    mma_m16_n16_k16_b16_b16_b16_noacc(
+        a_frag[0], a_frag[1], a_frag[2], a_frag[3], b_frag_i[0], b_frag_i[1], b_frag_i[2],
+        b_frag_i[3], c_frag[0], c_frag[1], c_frag[2], c_frag[3], temp_amax_reg);
+    asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;\n\t"
+                 : "=r"(local_amax_reg)
+                 : "r"(local_amax_reg), "r"(temp_amax_reg));
+  }
 }

 template
@@ -305,6 +290,12 @@ __global__ void GroupHadamardAmaxTmaKernel(const __grid_constant__ CUtensorMap t
   uint32_t local_amax_reg = *reinterpret_cast(&local_amax);
   uint32_t local_amax_t_reg = *reinterpret_cast(&local_amax_t);

+  const int warp_id = threadIdx.x / kThreadsPerWarp;
+  const int local_rank = threadIdx.x % kThreadsPerWarp;
+  const int ld_row_idx = local_rank % kHadamardDimension;
+  const int ld_col_idx = local_rank / kHadamardDimension + warp_id * 2;
+  const int swizzle_idx = swizzle_128B_atom_32B(ld_row_idx, ld_col_idx);
+
   for (int stage_y = 0; stage_y < STAGES_Y; ++stage_y) {
     for (int stage_x = 0; stage_x < STAGES_X; ++stage_x) {
       int stage = STAGES_X * stage_y + stage_x;
@@ -347,7 +338,7 @@ __global__ void GroupHadamardAmaxTmaKernel(const __grid_constant__ CUtensorMap t
           had_frag_i, had_frag_t,
           in_sh_ptr + in_row_offset +
              (compute_stage_x * kHadamardDimension * (THREADS_PER_CHUNK / kThreadsPerWarp)),
-          local_pre_rht_amax_reg, local_amax_reg, local_amax_t_reg);
+          swizzle_idx, local_pre_rht_amax_reg, local_amax_reg, local_amax_t_reg);
     }

     // Ensure all threads have finished their computation before new data over-writes the shared
diff --git a/transformer_engine/common/hadamard_transform/hadamard_transform.cu b/transformer_engine/common/hadamard_transform/hadamard_transform.cu
index 4adc836886..7a8db9d85c 100644
--- a/transformer_engine/common/hadamard_transform/hadamard_transform.cu
+++ b/transformer_engine/common/hadamard_transform/hadamard_transform.cu
@@ -26,46 +26,19 @@ constexpr int kThreadsPerWarp = 32;
 template
 __device__ __forceinline__ void ComputeKernel(uint32_t b_frag_i[4], uint32_t b_frag_t[4],
-                                              IType* in_sh_ptr, uint32_t& local_pre_rht_amax_reg,
+                                              IType* in_sh_ptr, int swizzle_idx,
+                                              uint32_t& local_pre_rht_amax_reg,
                                               uint32_t& local_amax_reg, uint32_t& local_amax_t_reg) {
   uint32_t a_frag[4];  // A matrix fragment
   uint32_t c_frag[4];  // Result fragment

-  int warp_id = threadIdx.x / kThreadsPerWarp;
-  int local_rank = (threadIdx.x % kThreadsPerWarp);
-
-  int ld_row_idx = local_rank % kHadamardDimension;
-  int ld_col_idx = local_rank / kHadamardDimension + warp_id * 2;
-  int swizzle_idx = swizzle_128B_atom_32B(ld_row_idx, ld_col_idx);
-
   uint32_t temp_amax_reg;
   uint32_t temp_amax_t_reg;

-  if (kReturnIdentityAmax) {
-    ldmatrix_x4_m8n8_shared_b16(a_frag[0], a_frag[1], a_frag[2], a_frag[3],
-                                reinterpret_cast(in_sh_ptr) + swizzle_idx);
-
-    mma_m16_n16_k16_b16_b16_b16_noacc(
-        a_frag[0], a_frag[1], a_frag[2], a_frag[3], b_frag_i[0], b_frag_i[1], b_frag_i[2],
-        b_frag_i[3], c_frag[0], c_frag[1], c_frag[2], c_frag[3], temp_amax_reg);
-    asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;\n\t"
-                 : "=r"(local_amax_reg)
-                 : "r"(local_amax_reg), "r"(temp_amax_reg));
-  }
-
   if (kReturnTransposedAmax) {
-    // TODO(Frank): This is not efficient, since we could directly load the
-    // matrix in transposed layout.
-    if (!kReturnIdentityAmax) {
-      ldmatrix_x4_m8n8_shared_b16(a_frag[0], a_frag[1], a_frag[2], a_frag[3],
-                                  reinterpret_cast(in_sh_ptr) + swizzle_idx);
-    }
-
-    matrix_transpose_m8_n8_b16_inplace(a_frag[0]);
-    matrix_transpose_m8_n8_b16_inplace(a_frag[1]);
-    matrix_transpose_m8_n8_b16_inplace(a_frag[2]);
-    matrix_transpose_m8_n8_b16_inplace(a_frag[3]);
+    ldmatrix_x4_m8n8_shared_b16(a_frag[0], a_frag[1], a_frag[2], a_frag[3],
+                                reinterpret_cast(in_sh_ptr) + swizzle_idx);

     mma_m16_n16_k16_b16_b16_b16_noacc(
         a_frag[0], a_frag[2], a_frag[1], a_frag[3], b_frag_t[0], b_frag_t[1], b_frag_t[2],
@@ -76,7 +49,7 @@ __device__ __forceinline__ void ComputeKernel(uint32_t b_frag_i[4], uint32_t b_f
   }

   if (kReturnPreRhtAmax) {
-    if (!kReturnIdentityAmax && !kReturnTransposedAmax) {
+    if (!kReturnTransposedAmax) {
       ldmatrix_x4_m8n8_shared_b16(a_frag[0], a_frag[1], a_frag[2], a_frag[3],
                                   reinterpret_cast(in_sh_ptr) + swizzle_idx);
     }
@@ -94,6 +67,18 @@ __device__ __forceinline__ void ComputeKernel(uint32_t b_frag_i[4], uint32_t b_f
                  : "=r"(local_pre_rht_amax_reg)
                  : "r"(a_frag[0]), "r"(local_pre_rht_amax_reg));
   }
+
+  if (kReturnIdentityAmax) {
+    ldmatrix_x4_m8n8_shared_b16(a_frag[0], a_frag[1], a_frag[2], a_frag[3],
+                                reinterpret_cast(in_sh_ptr) + swizzle_idx);
+
+    mma_m16_n16_k16_b16_b16_b16_noacc(
+        a_frag[0], a_frag[1], a_frag[2], a_frag[3], b_frag_i[0], b_frag_i[1], b_frag_i[2],
+        b_frag_i[3], c_frag[0], c_frag[1], c_frag[2], c_frag[3], temp_amax_reg);
+    asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;\n\t"
+                 : "=r"(local_amax_reg)
+                 : "r"(local_amax_reg), "r"(temp_amax_reg));
+  }
 }

 template
@@ -248,6 +233,12 @@ __global__ void HadamardAmaxTmaKernel(const __grid_constant__ CUtensorMap tensor
   uint32_t local_amax_reg = *reinterpret_cast(&local_amax);
   uint32_t local_amax_t_reg = *reinterpret_cast(&local_amax_t);

+  const int warp_id = threadIdx.x / kThreadsPerWarp;
+  const int local_rank = threadIdx.x % kThreadsPerWarp;
+  const int ld_row_idx = local_rank % kHadamardDimension;
+  const int ld_col_idx = local_rank / kHadamardDimension + warp_id * 2;
+  const int swizzle_idx = swizzle_128B_atom_32B(ld_row_idx, ld_col_idx);
+
   for (int stage_y = 0; stage_y < STAGES_Y; ++stage_y) {
     for (int stage_x = 0; stage_x < STAGES_X; ++stage_x) {
       int stage = STAGES_X * stage_y + stage_x;
@@ -290,7 +281,7 @@ __global__ void HadamardAmaxTmaKernel(const __grid_constant__ CUtensorMap tensor
           had_frag_i, had_frag_t,
           in_sh_ptr + in_row_offset +
              (compute_stage_x * kHadamardDimension * (THREADS_PER_CHUNK / kThreadsPerWarp)),
-          local_pre_rht_amax_reg, local_amax_reg, local_amax_t_reg);
+          swizzle_idx, local_pre_rht_amax_reg, local_amax_reg, local_amax_t_reg);
     }

     // Ensure all threads have finished their computation before new data over-writes the shared