diff --git a/CMakeLists.txt b/CMakeLists.txt index 19bd7b623..81326ffcf 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -54,8 +54,7 @@ endif() # Define included source files set(CPP_FILES csrc/cpu_ops.cpp csrc/pythonInterface.cpp) -set(CUDA_FILES csrc/ops.cu csrc/kernels.cu) -set(HIP_FILES csrc/ops.hip csrc/kernels.hip) +set(GPU_FILES csrc/ops.cu csrc/kernels.cu) set(MPS_FILES csrc/mps_ops.mm) set(METAL_FILES csrc/mps_kernels.metal) set(XPU_FILES csrc/xpu_ops.cpp csrc/xpu_kernels.cpp) @@ -225,7 +224,7 @@ if(BUILD_CUDA) message(STATUS "CUDA Targets: ${CMAKE_CUDA_ARCHITECTURES}") message(STATUS "CUDA NVCC Flags: ${CMAKE_CUDA_FLAGS}") - list(APPEND SRC_FILES ${CUDA_FILES}) + list(APPEND SRC_FILES ${GPU_FILES}) string(APPEND BNB_OUTPUT_NAME "_cuda${CUDA_VERSION_SHORT}") add_compile_definitions(BUILD_CUDA) @@ -244,7 +243,7 @@ elseif(BUILD_HIP) message(STATUS "HIP Compiler: ${CMAKE_HIP_COMPILER}") message(STATUS "HIP Targets: ${CMAKE_HIP_ARCHITECTURES}") - list(APPEND SRC_FILES ${HIP_FILES}) + list(APPEND SRC_FILES ${GPU_FILES}) string(APPEND BNB_OUTPUT_NAME "_rocm") @@ -389,7 +388,7 @@ if(BUILD_HIP) endif() target_compile_definitions(bitsandbytes PUBLIC BNB_USE_HIP) - set_source_files_properties(${HIP_FILES} PROPERTIES LANGUAGE HIP) + set_source_files_properties(${GPU_FILES} PROPERTIES LANGUAGE HIP) set_target_properties(bitsandbytes PROPERTIES LINKER_LANGUAGE CXX) if(HIP_VERSION VERSION_LESS "6.1") diff --git a/csrc/common.cuh b/csrc/common.cuh index 9e245fcd6..61bef3c27 100644 --- a/csrc/common.cuh +++ b/csrc/common.cuh @@ -1,6 +1,32 @@ +// common.cuh — Architecture constants and feature detection + #pragma once -// TODO: Let's make some of these constexpr and put in a namespace. +#include "compat.cuh" + +// Warp size + +#if BNB_HIP +// CDNA (gfx9xx) = 64, RDNA = 32. +#ifdef __AMDGCN_WAVEFRONT_SIZE +#define BNB_WARP_SIZE __AMDGCN_WAVEFRONT_SIZE +#else +#define BNB_WARP_SIZE 64 // Safe default for HIP (matches CDNA) +#endif +#else +#define BNB_WARP_SIZE 32 +#endif + +// BF16 availability + +#if BNB_HIP +// BF16 is available on all currently-supported ROCm architectures (CDNA2+, RDNA3+) +#define BNB_BF16_AVAILABLE true +#else +#define BNB_BF16_AVAILABLE (__CUDA_ARCH__ >= BNB_CC_AMPERE) +#endif + +// Compute capability constants #define BNB_CC_PASCAL 600 #define BNB_CC_PASCAL_X2 620 @@ -14,18 +40,27 @@ #define BNB_CC_HOPPER 900 #define BNB_CC_BLACKWELL 1000 +// Feature availability based on arch + +#if BNB_HIP +// HIP: MMA not supported via mma.h; FP8 support varies by arch +#define BNB_FP16_MMA_AVAILABLE 0 +#define BNB_INT8_MMA_AVAILABLE 0 +#define BNB_FP8_AVAILABLE 0 +#else #define BNB_FP16_MMA_AVAILABLE (__CUDA_ARCH__ >= BNB_CC_VOLTA) #define BNB_INT8_MMA_AVAILABLE (__CUDA_ARCH__ >= BNB_CC_VOLTA_XAVIER) -#define BNB_BF16_AVAILABLE (__CUDA_ARCH__ >= BNB_CC_AMPERE) #define BNB_FP8_AVAILABLE (__CUDA_ARCH__ >= BNB_CC_ADA) +#endif -#define BNB_WARP_SIZE 32 +// Maximum threads per SM/CU -// The maximum number of resident threads per SM varies by arch. -// For A100/H100 and all prior to Turing, it is 2048, which allows -// for 2 full blocks of 1024 threads per SM. -// Reference: -// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#features-and-technical-specifications-technical-specifications-per-compute-capability +#if BNB_HIP +// For currently supported ROCm architectures (CDNA2, RDNA3) +#define BNB_MAX_THREADS_PER_SM 2048 +#else +// The maximum number of resident threads per SM varies by NVIDIA arch. +// Reference: CUDA Programming Guide, Technical Specifications per Compute Capability #if __CUDA_ARCH__ == 750 #define BNB_MAX_THREADS_PER_SM 1024 #elif __CUDA_ARCH__ >= 860 && __CUDA_ARCH__ <= 890 @@ -33,12 +68,13 @@ #else #define BNB_MAX_THREADS_PER_SM 2048 #endif +#endif -// Maximum resident warps per SM is always directly related to the number of threads. +// Maximum resident warps per SM/CU #define BNB_MAX_WARPS_PER_SM ((BNB_MAX_THREADS_PER_SM) / (BNB_WARP_SIZE)) -// Maximum resident blocks per SM may vary. -#if __CUDA_ARCH__ == 860 || __CUDA_ARCH__ == 870 +// Maximum resident blocks per SM/CU +#if !BNB_HIP && (defined(__CUDA_ARCH__)) && (__CUDA_ARCH__ == 860 || __CUDA_ARCH__ == 870) #define BNB_MAX_BLOCKS_PER_SM 16 #else #define BNB_MAX_BLOCKS_PER_SM ((BNB_MAX_WARPS_PER_SM) / 2) diff --git a/csrc/common_hip.cuh b/csrc/common_hip.cuh deleted file mode 100644 index 7ecb59df7..000000000 --- a/csrc/common_hip.cuh +++ /dev/null @@ -1,11 +0,0 @@ -#pragma once - -#ifdef __GFX9__ -#define BNB_WARP_SIZE 64 -#else -#define BNB_WARP_SIZE 32 -#endif - -// These are set based on current BNB support for CDNA 2 & RDNA 3. Update as needed for future archs -#define BNB_MAX_THREADS_PER_CU 2048 -#define BNB_BF16_AVAILABLE true diff --git a/csrc/compat.cuh b/csrc/compat.cuh new file mode 100644 index 000000000..18c1a54d2 --- /dev/null +++ b/csrc/compat.cuh @@ -0,0 +1,181 @@ +// compat.cuh — Platform abstraction layer for CUDA/HIP portability +// +// This header resolves ALL mechanical differences between CUDA and HIP. +// Kernel code should include this header and use the bnb_* types/macros +// instead of cuda*/hip* identifiers directly. +// +// The guard macro is BNB_HIP, which is defined when compiling for ROCm/HIP +// (set via CMakeLists.txt's add_compile_definitions(__HIP_PLATFORM_AMD__)). + +#pragma once + +// Platform detection + +#if defined(__HIP_PLATFORM_AMD__) || defined(__HIPCC__) +#define BNB_HIP 1 +#else +#define BNB_HIP 0 +#endif + +// Runtime and FP16/BF16 headers + +#if BNB_HIP + +#include +#include +#include +#include +#include +#include + +#else // CUDA + +#include +#include +#include + +#endif + +// Stream and error types + +#if BNB_HIP + +using bnb_stream_t = hipStream_t; +using bnb_error_t = hipError_t; + +#define BNB_SUCCESS hipSuccess +#define BNB_PEEK_LAST_ERROR() hipPeekAtLastError() +#define BNB_GET_ERROR_STRING(e) hipGetErrorString(e) +#define BNB_DEVICE_MALLOC(p, s) hipMalloc(p, s) +#define BNB_DEVICE_FREE(p) hipFree(p) +#define BNB_DEVICE_MEMSET(p, v, s) hipMemset(p, v, s) + +#else // CUDA + +using bnb_stream_t = cudaStream_t; +using bnb_error_t = cudaError_t; + +#define BNB_SUCCESS cudaSuccess +#define BNB_PEEK_LAST_ERROR() cudaPeekAtLastError() +#define BNB_GET_ERROR_STRING(e) cudaGetErrorString(e) +#define BNB_DEVICE_MALLOC(p, s) cudaMalloc(p, s) +#define BNB_DEVICE_FREE(p) cudaFree(p) +#define BNB_DEVICE_MEMSET(p, v, s) cudaMemset(p, v, s) + +#endif + +// Error checking + +#define BNB_CHECK_RETURN(value) \ + { \ + bnb_error_t _bnb_stat = value; \ + if (_bnb_stat != BNB_SUCCESS) { \ + fprintf(stderr, "Error %s at line %d in file %s\n", BNB_GET_ERROR_STRING(_bnb_stat), __LINE__, __FILE__); \ + exit(1); \ + } \ + } + +// Keep backward compat for existing code during migration +#define CUDA_CHECK_RETURN(value) BNB_CHECK_RETURN(value) + +// Warp synchronization +// +// HIP warps are always in lockstep (no independent thread scheduling), +// so __syncwarp() is a no-op. CUDA needs it for warp convergence. + +#if BNB_HIP +#define __syncwarp() \ + do { \ + } while (0) +#endif + +// BFloat16 type alias + +#if BNB_HIP +using bnb_bfloat16 = hip_bfloat16; +#else +using bnb_bfloat16 = __nv_bfloat16; +#endif + +// Data type enum aliases for BLAS libraries + +#if BNB_HIP + +#define BNB_R_16F HIP_R_16F +#define BNB_R_32F HIP_R_32F +#define BNB_R_8I HIP_R_8I +#define BNB_R_32I HIP_R_32I + +#else // CUDA + +#define BNB_R_16F CUDA_R_16F +#define BNB_R_32F CUDA_R_32F +#define BNB_R_8I CUDA_R_8I +#define BNB_R_32I CUDA_R_32I + +#endif + +// BLAS Lt types and functions + +#if BNB_HIP + +#ifndef NO_HIPBLASLT +#include +#endif + +using bnb_blasLt_handle_t = hipblasLtHandle_t; +using bnb_blasLt_matmul_desc_t = hipblasLtMatmulDesc_t; +using bnb_blasLt_layout_t = hipblasLtMatrixLayout_t; +using bnb_blasLt_preference_t = hipblasLtMatmulPreference_t; + +#define BNB_BLASLT_OP_T HIPBLAS_OP_T +#define BNB_BLASLT_COMPUTE_32I HIPBLAS_COMPUTE_32I + +#define bnb_blasLtCreate hipblasLtCreate +#define bnb_blasLtMatmulDescCreate hipblasLtMatmulDescCreate +#define bnb_blasLtMatmulDescSetAttr hipblasLtMatmulDescSetAttribute +#define bnb_blasLtLayoutCreate hipblasLtMatrixLayoutCreate +#define bnb_blasLtLayoutDestroy hipblasLtMatrixLayoutDestroy +#define bnb_blasLtMatmulDescDestroy hipblasLtMatmulDescDestroy +#define bnb_blasLtMatmul hipblasLtMatmul +#define bnb_blasLtPrefCreate hipblasLtMatmulPreferenceCreate +#define bnb_blasLtPrefSetAttr hipblasLtMatmulPreferenceSetAttribute +#define bnb_blasLtAlgoGetHeuristic hipblasLtMatmulAlgoGetHeuristic + +#define BNB_BLASLT_DESC_TRANSA HIPBLASLT_MATMUL_DESC_TRANSA +#define BNB_BLASLT_DESC_POINTER_MODE HIPBLASLT_MATMUL_DESC_POINTER_MODE +#define BNB_BLASLT_PREF_MAX_WORKSPACE HIPBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES +#define BNB_BLASLT_PTR_MODE_ALPHA_VEC HIPBLASLT_POINTER_MODE_ALPHA_DEVICE_VECTOR_BETA_HOST + +using bnb_blasLt_heuristic_t = hipblasLtMatmulHeuristicResult_t; +using bnb_blas_status_t = hipblasStatus_t; +#define BNB_BLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS + +#else // CUDA + +#include +#include + +using bnb_blasLt_handle_t = cublasLtHandle_t; +using bnb_blasLt_matmul_desc_t = cublasLtMatmulDesc_t; +using bnb_blasLt_layout_t = cublasLtMatrixLayout_t; + +#define BNB_BLASLT_OP_T CUBLAS_OP_T +#define BNB_BLASLT_COMPUTE_32I CUBLAS_COMPUTE_32I + +#define bnb_blasLtCreate cublasLtCreate +#define bnb_blasLtMatmulDescCreate cublasLtMatmulDescCreate +#define bnb_blasLtMatmulDescSetAttr cublasLtMatmulDescSetAttribute +#define bnb_blasLtLayoutCreate cublasLtMatrixLayoutCreate +#define bnb_blasLtLayoutDestroy cublasLtMatrixLayoutDestroy +#define bnb_blasLtMatmulDescDestroy cublasLtMatmulDescDestroy +#define bnb_blasLtMatmul cublasLtMatmul + +#define BNB_BLASLT_DESC_TRANSA CUBLASLT_MATMUL_DESC_TRANSA +#define BNB_BLASLT_DESC_POINTER_MODE CUBLASLT_MATMUL_DESC_POINTER_MODE +#define BNB_BLASLT_PTR_MODE_ALPHA_VEC CUBLASLT_POINTER_MODE_ALPHA_DEVICE_VECTOR_BETA_ZERO + +using bnb_blas_status_t = cublasStatus_t; +#define BNB_BLAS_STATUS_SUCCESS CUBLAS_STATUS_SUCCESS + +#endif diff --git a/csrc/compat_device.cuh b/csrc/compat_device.cuh new file mode 100644 index 000000000..8419f1485 --- /dev/null +++ b/csrc/compat_device.cuh @@ -0,0 +1,51 @@ +// compat_device.cuh — Device-only portability layer (CUB, reduction ops, MMA) +// +// Include this from .cu kernel files only (compiled by nvcc/hipcc). +// Do NOT include from .cpp files — use compat.cuh instead for host-safe types. + +#pragma once + +#include "compat.cuh" + +// CUB / hipCUB — namespace alias + +#if BNB_HIP + +#include +namespace bnb_cub = hipcub; + +#else // CUDA + +#include +#include +#include +#include +#include +#include +#include +#include +#include +namespace bnb_cub = cub; + +#endif + +// Reduction operators + +#if BNB_HIP + +#define BNB_MAX_OP hipcub::Max() +#define BNB_SUM_OP hipcub::Sum() + +#else // CUDA + +// CCCL 2.8.2+ moved to cuda::maximum<>{}, older versions use cub::Max() +#if defined(CCCL_VERSION) && CCCL_VERSION >= 2008002 +#include +#define BNB_MAX_OP \ + cuda::maximum<> {} +#else +#define BNB_MAX_OP cub::Max() +#endif +#define BNB_SUM_OP cub::Sum() + +#endif diff --git a/csrc/kernels.cu b/csrc/kernels.cu index dac6a2dc4..cff242316 100644 --- a/csrc/kernels.cu +++ b/csrc/kernels.cu @@ -4,25 +4,8 @@ // LICENSE file in the root directory of this source tree. #include "common.cuh" +#include "compat_device.cuh" #include "kernels.cuh" -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#if CCCL_VERSION >= 2008002 -#include -#define CUB_REDUCTIONOP_MAX \ - cuda::maximum<> {} -#else -#define CUB_REDUCTIONOP_MAX cub::Max() -#endif #define HLF_MAX 65504 #define TH 1024 @@ -60,6 +43,8 @@ __device__ static float nf4_dequantization_lut[16] = { }; // source: https://stackoverflow.com/questions/17399119/how-do-i-use-atomicmax-on-floating-point-values-in-cuda +// HIP has native atomicMax for float; CUDA needs a CAS loop +#if !BNB_HIP __device__ float atomicMax(float* address, float val) { int* address_as_i = reinterpret_cast(address); int old = *address_as_i, assumed; @@ -69,6 +54,7 @@ __device__ float atomicMax(float* address, float val) { } while (assumed != old); return __int_as_float(old); } +#endif // !BNB_HIP __device__ __forceinline__ float dDequantizeFP4Tree(unsigned char val) { float sign = 1.0f - 2 * ((val & 0b1000) >> 3); @@ -299,13 +285,20 @@ __global__ void kQuantizeBlockwise( float local_abs_max = 0.0f; int local_rand_idx = 0; - typedef cub::BlockLoad LoadT; - typedef cub::BlockStore< - unsigned char, BLOCK_SIZE / NUM_PER_TH, (DATA_TYPE > 0) ? NUM_PER_TH / 2 : NUM_PER_TH, - cub::BLOCK_STORE_WARP_TRANSPOSE> + // WARP_TRANSPOSE requires block_dim >= warp_size. On CDNA (warp=64), + // block_dim=32 (from BLOCK_SIZE=64/NUM_PER_TH=2) is too small. Fall back + // to DIRECT load/store in that case. + static constexpr int THREADS = BLOCK_SIZE / NUM_PER_TH; + static constexpr auto LOAD_ALGO = + (THREADS >= BNB_WARP_SIZE) ? bnb_cub::BLOCK_LOAD_WARP_TRANSPOSE : bnb_cub::BLOCK_LOAD_DIRECT; + static constexpr auto STORE_ALGO = + (THREADS >= BNB_WARP_SIZE) ? bnb_cub::BLOCK_STORE_WARP_TRANSPOSE : bnb_cub::BLOCK_STORE_DIRECT; + + typedef bnb_cub::BlockLoad LoadT; + typedef bnb_cub::BlockStore 0) ? NUM_PER_TH / 2 : NUM_PER_TH, STORE_ALGO> StoreChar; - typedef cub::BlockReduce BlockReduce; - typedef cub::BlockLoad LoadFloat; + typedef bnb_cub::BlockReduce BlockReduce; + typedef bnb_cub::BlockLoad LoadFloat; __shared__ typename LoadT::TempStorage loadt; __shared__ typename LoadFloat::TempStorage loadf; @@ -333,7 +326,7 @@ __global__ void kQuantizeBlockwise( for (int j = 0; j < NUM_PER_TH; j++) local_abs_max = fmaxf(local_abs_max, fabsf((float)vals[j])); - local_abs_max = BlockReduce(reduce).Reduce(local_abs_max, CUB_REDUCTIONOP_MAX, valid_items); + local_abs_max = BlockReduce(reduce).Reduce(local_abs_max, BNB_MAX_OP, valid_items); if (threadIdx.x == 0) { smem_absmax_value[0] = 1.0f / local_abs_max; @@ -381,20 +374,22 @@ __global__ void kQuantizeBlockwise( } } -// Specialized kernel for blocksize=32 with 4-bit quantization -// Processes 2 blocks of 32 values per warp to maintain full thread utilization -// Uses 32 threads total: threads 0-15 handle block 0, threads 16-31 handle block 1 +// Unified small-blocksize kernel for 4-bit quantization +// Processes 2 blocks of BNB_WARP_SIZE values per thread block +// On CUDA (warp=32): blocksize=32, 32 threads, WarpReduce<16> +// On HIP (warp=64): blocksize=64, 64 threads, WarpReduce<32> +// On HIP (warp=32): blocksize=32, 32 threads, WarpReduce<16> template -__global__ void kQuantizeBlockwise32( +__global__ void kQuantizeBlockwiseSmall( float* code, T* __restrict__ const A, float* absmax, unsigned char* out, float* __restrict__ const rand, const int rand_offset, const int n ) { - constexpr int BLOCK_SIZE = 32; // Size of each quantization block - constexpr int NUM_PER_TH = 2; // Values per thread (for 4-bit packing) - constexpr int THREADS = 32; // Total threads (full warp) - constexpr int THREADS_PER_BLOCK = 16; // Threads handling each quantization block + constexpr int BLOCK_SIZE = BNB_WARP_SIZE; // Size of each quantization block + constexpr int NUM_PER_TH = 2; // Values per thread (for 4-bit packing) + constexpr int THREADS = BNB_WARP_SIZE; // Total threads (one full warp) + constexpr int THREADS_PER_BLOCK = BNB_WARP_SIZE / 2; // Half-warp per quantization block - const int base_idx = blockIdx.x * BLOCK_SIZE * 2; // 2 blocks per CUDA block + const int base_idx = blockIdx.x * BLOCK_SIZE * 2; // 2 blocks per thread block T vals[NUM_PER_TH]; unsigned char qvals[NUM_PER_TH / 2]; // For 4-bit: 2 values per byte @@ -403,10 +398,10 @@ __global__ void kQuantizeBlockwise32( const int block_id = threadIdx.x / THREADS_PER_BLOCK; // 0 for threads 0-15, 1 for threads 16-31 const int local_thread_id = threadIdx.x % THREADS_PER_BLOCK; // Thread ID within the block (0-15) - typedef cub::BlockLoad LoadT; - typedef cub::BlockStore StoreChar; - typedef cub::WarpReduce - WarpReduce; // Logical warp size of 16: threads 0-15 and 16-31 reduce independently + typedef bnb_cub::BlockLoad LoadT; + typedef bnb_cub::BlockStore StoreChar; + typedef bnb_cub::WarpReduce + WarpReduce; // Half-warp logical reduction: each half reduces independently __shared__ typename LoadT::TempStorage loadt; __shared__ typename StoreChar::TempStorage storec; @@ -429,7 +424,7 @@ __global__ void kQuantizeBlockwise32( local_abs_max = fmaxf(local_abs_max, fabsf((float)vals[j])); // Reduce within each logical warp of 16 threads independently - local_abs_max = WarpReduce(warp_reduce[block_id]).Reduce(local_abs_max, CUB_REDUCTIONOP_MAX); + local_abs_max = WarpReduce(warp_reduce[block_id]).Reduce(local_abs_max, BNB_MAX_OP); if (local_thread_id == 0) { if (block_valid) { @@ -478,8 +473,9 @@ __global__ void unsigned char qvals[NUM_PER_TH]; float local_abs_max = -FLT_MAX; - typedef cub::BlockLoad LoadChar; - typedef cub::BlockStore 0) ? 2 : 1), cub::BLOCK_STORE_WARP_TRANSPOSE> StoreT; + typedef bnb_cub::BlockLoad LoadChar; + typedef bnb_cub::BlockStore 0) ? 2 : 1), bnb_cub::BLOCK_STORE_WARP_TRANSPOSE> + StoreT; __shared__ typename LoadChar::TempStorage loadchar; __shared__ typename StoreT::TempStorage storet; @@ -548,9 +544,9 @@ __launch_bounds__(BLOCK_SIZE / NUM_VALS, 1) __global__ void kPreconditionOptimiz const float correction1 = 1.0f / (1.0f - powf(beta1, step)); const float correction2 = 1.0f / (1.0f - powf(beta2, step)); - typedef cub::BlockLoad Load; - typedef cub::BlockLoad LoadFloat; - typedef cub::BlockReduce BlockReduce; + typedef bnb_cub::BlockLoad Load; + typedef bnb_cub::BlockLoad LoadFloat; + typedef bnb_cub::BlockReduce BlockReduce; __shared__ union { typename Load::TempStorage load; @@ -643,11 +639,11 @@ __launch_bounds__(TH, 1) __global__ void kOptimizer32bit2State( update_scale = 1.0f; } - typedef cub::BlockLoad Load; - typedef cub::BlockStore Store; + typedef bnb_cub::BlockLoad Load; + typedef bnb_cub::BlockStore Store; - typedef cub::BlockLoad LoadFloat; - typedef cub::BlockStore StoreFloat; + typedef bnb_cub::BlockLoad LoadFloat; + typedef bnb_cub::BlockStore StoreFloat; __shared__ union { typename Load::TempStorage load; @@ -742,9 +738,9 @@ __launch_bounds__(BLOCK_SIZE / NUM_VALS, 1) __global__ void kPreconditionOptimiz float s1_vals[NUM_VALS]; - typedef cub::BlockLoad Load; - typedef cub::BlockLoad LoadFloat; - typedef cub::BlockReduce BlockReduce; + typedef bnb_cub::BlockLoad Load; + typedef bnb_cub::BlockLoad LoadFloat; + typedef bnb_cub::BlockReduce BlockReduce; __shared__ union { typename Load::TempStorage load; @@ -834,11 +830,11 @@ __launch_bounds__(TH, 1) __global__ void kOptimizer32bit1State( float s1_vals[NUM_PER_THREAD]; - typedef cub::BlockLoad Load; - typedef cub::BlockStore Store; + typedef bnb_cub::BlockLoad Load; + typedef bnb_cub::BlockStore Store; - typedef cub::BlockLoad LoadFloat; - typedef cub::BlockStore StoreFloat; + typedef bnb_cub::BlockLoad LoadFloat; + typedef bnb_cub::BlockStore StoreFloat; __shared__ union { typename Load::TempStorage load; @@ -939,17 +935,19 @@ __launch_bounds__(256, 3) __global__ void kOptimizerStatic8bit2StateBlockwise( T g_vals[N_PER_TH]; T p_vals[N_PER_TH]; - typedef cub::BlockLoad LoadT; - typedef cub::BlockLoad LoadChar; + typedef bnb_cub::BlockLoad LoadT; + typedef bnb_cub::BlockLoad + LoadChar; - typedef cub::BlockStore StoreChar; - typedef cub::BlockStore StoreT; + typedef bnb_cub::BlockStore + StoreChar; + typedef bnb_cub::BlockStore StoreT; __shared__ float smem_quantiles1[LANES][257]; __shared__ float smem_quantiles2[LANES][257]; - typedef cub::BlockReduce BlockReduce1; - typedef cub::BlockReduce BlockReduce2; - typedef cub::BlockReduce BlockReduce3; + typedef bnb_cub::BlockReduce BlockReduce1; + typedef bnb_cub::BlockReduce BlockReduce2; + typedef bnb_cub::BlockReduce BlockReduce3; __shared__ typename BlockReduce1::TempStorage reduce1; __shared__ typename BlockReduce2::TempStorage reduce2; __shared__ typename BlockReduce2::TempStorage reduce3; @@ -1041,11 +1039,11 @@ __launch_bounds__(256, 3) __global__ void kOptimizerStatic8bit2StateBlockwise( } // reduce: 2.51/1.60 -> 2.67/1.69 - new_local_abs_max1 = BlockReduce1(reduce1).Reduce(new_local_abs_max1, CUB_REDUCTIONOP_MAX); - new_local_abs_max2 = BlockReduce2(reduce2).Reduce(new_local_abs_max2, CUB_REDUCTIONOP_MAX); + new_local_abs_max1 = BlockReduce1(reduce1).Reduce(new_local_abs_max1, BNB_MAX_OP); + new_local_abs_max2 = BlockReduce2(reduce2).Reduce(new_local_abs_max2, BNB_MAX_OP); if (OPTIMIZER == ADEMAMIX) { - new_local_abs_max3 = BlockReduce3(reduce3).Reduce(new_local_abs_max3, CUB_REDUCTIONOP_MAX); + new_local_abs_max3 = BlockReduce3(reduce3).Reduce(new_local_abs_max3, BNB_MAX_OP); } if (threadIdx.x == 0) { @@ -1163,14 +1161,16 @@ __launch_bounds__(256, 3) __global__ void kOptimizerStatic8bit1StateBlockwise( T g_vals[N_PER_TH]; T p_vals[N_PER_TH]; - typedef cub::BlockLoad LoadT; - typedef cub::BlockLoad LoadChar; + typedef bnb_cub::BlockLoad LoadT; + typedef bnb_cub::BlockLoad + LoadChar; - typedef cub::BlockStore StoreChar; - typedef cub::BlockStore StoreT; + typedef bnb_cub::BlockStore + StoreChar; + typedef bnb_cub::BlockStore StoreT; __shared__ float smem_quantiles1[LANES][257]; - typedef cub::BlockReduce BlockReduce1; + typedef bnb_cub::BlockReduce BlockReduce1; __shared__ typename BlockReduce1::TempStorage reduce1; __shared__ float smem_exchange1[1]; @@ -1254,7 +1254,7 @@ __launch_bounds__(256, 3) __global__ void kOptimizerStatic8bit1StateBlockwise( } // reduce: 2.51/1.60 -> 2.67/1.69 - new_local_abs_max1 = BlockReduce1(reduce1).Reduce(new_local_abs_max1, CUB_REDUCTIONOP_MAX); + new_local_abs_max1 = BlockReduce1(reduce1).Reduce(new_local_abs_max1, BNB_MAX_OP); if (threadIdx.x == 0) smem_exchange1[0] = new_local_abs_max1; @@ -1322,7 +1322,7 @@ template __launch_bounds__(1024, BNB_MAX_THREADS_PER_SM / 1024) __global__ void kInt8VectorQuant(T* __restrict__ A, int8_t* out, float* rowStats, float threshold, int rows, int cols) { - using BlockReduceT = cub::BlockReduce; + using BlockReduceT = bnb_cub::BlockReduce; // One block per row. // Threads load column values in a striped arrangement. @@ -1352,7 +1352,7 @@ __launch_bounds__(1024, BNB_MAX_THREADS_PER_SM / 1024) __global__ } // Reduce thread-local absmax across the block. - const T row_absmax = BlockReduceT(temp_storage).Reduce(row_local_absmax, CUB_REDUCTIONOP_MAX, cols); + const T row_absmax = BlockReduceT(temp_storage).Reduce(row_local_absmax, BNB_MAX_OP, cols); if (threadIdx.x == 0) { // Save our block's absmax to shared memory for the quantization step. rowStats[row_id] = smem_row_absmax = row_absmax; @@ -1400,7 +1400,7 @@ __global__ void kdequant_mm_int32_fp16( float local_colStats[ITEMS_PER_THREAD]; float local_biasValue[ITEMS_PER_THREAD]; - typedef cub::BlockLoad LoadInt32; + typedef bnb_cub::BlockLoad LoadInt32; __shared__ typename LoadInt32::TempStorage loadint32; int row_idx, col_idx; @@ -1449,7 +1449,7 @@ __global__ void kgemm_4bit_inference_naive( // load step-by-step in chunks of [32,warps]: 1x32 * [32,warps] -> [1,warps] // 4 warps -> 4 loads per iter // 1x32 * 32x4 -> 1x4 outputs per thread block - typedef cub::WarpReduce WarpReduce; + typedef bnb_cub::WarpReduce WarpReduce; __shared__ typename WarpReduce::TempStorage temp_storage[THREADS / 32]; const int warp_idx = threadIdx.x / 32; @@ -1585,9 +1585,9 @@ template __global__ void kgemm_4bit_inference_naive( int M, int N, int K, half* __restrict__ const A, unsigned char* B, float* absmax, const float* datatype, half* out, int lda, int ldb, int ldc, int blocksize ); -template __global__ void kgemm_4bit_inference_naive<__nv_bfloat16, 128, 16>( - int M, int N, int K, __nv_bfloat16* __restrict__ const A, unsigned char* B, float* absmax, const float* datatype, - __nv_bfloat16* out, int lda, int ldb, int ldc, int blocksize +template __global__ void kgemm_4bit_inference_naive( + int M, int N, int K, bnb_bfloat16* __restrict__ const A, unsigned char* B, float* absmax, const float* datatype, + bnb_bfloat16* out, int lda, int ldb, int ldc, int blocksize ); template __global__ void kgemm_4bit_inference_naive( int M, int N, int K, float* __restrict__ const A, unsigned char* B, float* absmax, const float* datatype, @@ -1610,16 +1610,16 @@ template __device__ unsigned char dQuantize<1>(float* smem_code, const float ran MAKE_PreconditionOptimizer32bit1State(MOMENTUM, half) MAKE_PreconditionOptimizer32bit1State(MOMENTUM, float) -MAKE_PreconditionOptimizer32bit1State(MOMENTUM, __nv_bfloat16) +MAKE_PreconditionOptimizer32bit1State(MOMENTUM, bnb_bfloat16) MAKE_PreconditionOptimizer32bit1State(RMSPROP, half) MAKE_PreconditionOptimizer32bit1State(RMSPROP, float) -MAKE_PreconditionOptimizer32bit1State(RMSPROP, __nv_bfloat16) +MAKE_PreconditionOptimizer32bit1State(RMSPROP, bnb_bfloat16) MAKE_PreconditionOptimizer32bit1State(LION, half) MAKE_PreconditionOptimizer32bit1State(LION, float) -MAKE_PreconditionOptimizer32bit1State(LION, __nv_bfloat16) +MAKE_PreconditionOptimizer32bit1State(LION, bnb_bfloat16) MAKE_PreconditionOptimizer32bit1State(ADAGRAD, half) MAKE_PreconditionOptimizer32bit1State(ADAGRAD, float) -MAKE_PreconditionOptimizer32bit1State(ADAGRAD, __nv_bfloat16) +MAKE_PreconditionOptimizer32bit1State(ADAGRAD, bnb_bfloat16) #define MAKE_Optimizer32bit1State(oname, gtype) \ template __global__ void kOptimizer32bit1State( \ @@ -1630,16 +1630,16 @@ MAKE_PreconditionOptimizer32bit1State(ADAGRAD, __nv_bfloat16) MAKE_Optimizer32bit1State(MOMENTUM, half) MAKE_Optimizer32bit1State(MOMENTUM, float) -MAKE_Optimizer32bit1State(MOMENTUM, __nv_bfloat16) +MAKE_Optimizer32bit1State(MOMENTUM, bnb_bfloat16) MAKE_Optimizer32bit1State(RMSPROP, half) MAKE_Optimizer32bit1State(RMSPROP, float) -MAKE_Optimizer32bit1State(RMSPROP, __nv_bfloat16) +MAKE_Optimizer32bit1State(RMSPROP, bnb_bfloat16) MAKE_Optimizer32bit1State(LION, half) MAKE_Optimizer32bit1State(LION, float) -MAKE_Optimizer32bit1State(LION, __nv_bfloat16) +MAKE_Optimizer32bit1State(LION, bnb_bfloat16) MAKE_Optimizer32bit1State(ADAGRAD, half) MAKE_Optimizer32bit1State(ADAGRAD, float) -MAKE_Optimizer32bit1State(ADAGRAD, __nv_bfloat16) +MAKE_Optimizer32bit1State(ADAGRAD, bnb_bfloat16) #define MAKE_PreconditionOptimizer32bit2State(oname, gtype) \ template __global__ void kPreconditionOptimizer32bit2State( \ @@ -1650,10 +1650,10 @@ MAKE_Optimizer32bit1State(ADAGRAD, __nv_bfloat16) MAKE_PreconditionOptimizer32bit2State(ADAM, float) MAKE_PreconditionOptimizer32bit2State(ADAM, half) -MAKE_PreconditionOptimizer32bit2State(ADAM, __nv_bfloat16) +MAKE_PreconditionOptimizer32bit2State(ADAM, bnb_bfloat16) MAKE_PreconditionOptimizer32bit2State(ADEMAMIX, float) MAKE_PreconditionOptimizer32bit2State(ADEMAMIX, half) -MAKE_PreconditionOptimizer32bit2State(ADEMAMIX, __nv_bfloat16) +MAKE_PreconditionOptimizer32bit2State(ADEMAMIX, bnb_bfloat16) template __global__ void kOptimizer32bit2State( float* g, float* p, float* state1, float* state2, float* unorm, const float max_unorm, const float param_norm, @@ -1667,8 +1667,8 @@ template __global__ void kOptimizer32bit2State( const float weight_decay, const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n ); -template __global__ void kOptimizer32bit2State<__nv_bfloat16, ADAM>( - __nv_bfloat16* g, __nv_bfloat16* p, float* state1, float* state2, float* unorm, const float max_unorm, +template __global__ void kOptimizer32bit2State( + bnb_bfloat16* g, bnb_bfloat16* p, float* state1, float* state2, float* unorm, const float max_unorm, const float param_norm, const float beta1, const float beta2, const float beta3, const float alpha, const float eps, const float weight_decay, const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n @@ -1685,8 +1685,8 @@ template __global__ void kOptimizer32bit2State( const float weight_decay, const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n ); -template __global__ void kOptimizer32bit2State<__nv_bfloat16, ADEMAMIX>( - __nv_bfloat16* g, __nv_bfloat16* p, float* state1, float* state2, float* unorm, const float max_unorm, +template __global__ void kOptimizer32bit2State( + bnb_bfloat16* g, bnb_bfloat16* p, float* state1, float* state2, float* unorm, const float max_unorm, const float param_norm, const float beta1, const float beta2, const float beta3, const float alpha, const float eps, const float weight_decay, const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n @@ -1743,42 +1743,44 @@ MAKE_kQuantizeBlockwise(float, 256, 2, 0, NF4) MAKE_kQuantizeBlockwise(float, 128, 2, 0, NF4) MAKE_kQuantizeBlockwise(float, 64, 2, 0, NF4) -MAKE_kQuantizeBlockwise(__nv_bfloat16, 4096, 4, 0, General8bit) -MAKE_kQuantizeBlockwise(__nv_bfloat16, 4096, 4, 1, General8bit) -MAKE_kQuantizeBlockwise(__nv_bfloat16, 2048, 4, 0, General8bit) -MAKE_kQuantizeBlockwise(__nv_bfloat16, 1024, 4, 0, General8bit) -MAKE_kQuantizeBlockwise(__nv_bfloat16, 512, 2, 0, General8bit) -MAKE_kQuantizeBlockwise(__nv_bfloat16, 256, 2, 0, General8bit) -MAKE_kQuantizeBlockwise(__nv_bfloat16, 128, 2, 0, General8bit) -MAKE_kQuantizeBlockwise(__nv_bfloat16, 64, 2, 0, General8bit) -MAKE_kQuantizeBlockwise(__nv_bfloat16, 4096, 4, 0, FP4) -MAKE_kQuantizeBlockwise(__nv_bfloat16, 2048, 4, 0, FP4) -MAKE_kQuantizeBlockwise(__nv_bfloat16, 1024, 4, 0, FP4) -MAKE_kQuantizeBlockwise(__nv_bfloat16, 512, 2, 0, FP4) -MAKE_kQuantizeBlockwise(__nv_bfloat16, 256, 2, 0, FP4) -MAKE_kQuantizeBlockwise(__nv_bfloat16, 128, 2, 0, FP4) -MAKE_kQuantizeBlockwise(__nv_bfloat16, 64, 2, 0, FP4) -MAKE_kQuantizeBlockwise(__nv_bfloat16, 4096, 4, 0, NF4) -MAKE_kQuantizeBlockwise(__nv_bfloat16, 2048, 4, 0, NF4) -MAKE_kQuantizeBlockwise(__nv_bfloat16, 1024, 4, 0, NF4) -MAKE_kQuantizeBlockwise(__nv_bfloat16, 512, 2, 0, NF4) -MAKE_kQuantizeBlockwise(__nv_bfloat16, 256, 2, 0, NF4) -MAKE_kQuantizeBlockwise(__nv_bfloat16, 128, 2, 0, NF4) -MAKE_kQuantizeBlockwise(__nv_bfloat16, 64, 2, 0, NF4) +MAKE_kQuantizeBlockwise(bnb_bfloat16, 4096, 4, 0, General8bit) +MAKE_kQuantizeBlockwise(bnb_bfloat16, 4096, 4, 1, General8bit) +MAKE_kQuantizeBlockwise(bnb_bfloat16, 2048, 4, 0, General8bit) +MAKE_kQuantizeBlockwise(bnb_bfloat16, 1024, 4, 0, General8bit) +MAKE_kQuantizeBlockwise(bnb_bfloat16, 512, 2, 0, General8bit) +MAKE_kQuantizeBlockwise(bnb_bfloat16, 256, 2, 0, General8bit) +MAKE_kQuantizeBlockwise(bnb_bfloat16, 128, 2, 0, General8bit) +MAKE_kQuantizeBlockwise(bnb_bfloat16, 64, 2, 0, General8bit) +MAKE_kQuantizeBlockwise(bnb_bfloat16, 4096, 4, 0, FP4) +MAKE_kQuantizeBlockwise(bnb_bfloat16, 2048, 4, 0, FP4) +MAKE_kQuantizeBlockwise(bnb_bfloat16, 1024, 4, 0, FP4) +MAKE_kQuantizeBlockwise(bnb_bfloat16, 512, 2, 0, FP4) +MAKE_kQuantizeBlockwise(bnb_bfloat16, 256, 2, 0, FP4) +MAKE_kQuantizeBlockwise(bnb_bfloat16, 128, 2, 0, FP4) +MAKE_kQuantizeBlockwise(bnb_bfloat16, 64, 2, 0, FP4) +MAKE_kQuantizeBlockwise(bnb_bfloat16, 4096, 4, 0, NF4) +MAKE_kQuantizeBlockwise(bnb_bfloat16, 2048, 4, 0, NF4) +MAKE_kQuantizeBlockwise(bnb_bfloat16, 1024, 4, 0, NF4) +MAKE_kQuantizeBlockwise(bnb_bfloat16, 512, 2, 0, NF4) +MAKE_kQuantizeBlockwise(bnb_bfloat16, 256, 2, 0, NF4) +MAKE_kQuantizeBlockwise(bnb_bfloat16, 128, 2, 0, NF4) +MAKE_kQuantizeBlockwise(bnb_bfloat16, 64, 2, 0, NF4) // Template instantiations for blocksize=32 specialized kernel (4-bit only) -#define MAKE_kQuantizeBlockwise32(dtype, data_type_name) \ - template __global__ void kQuantizeBlockwise32( \ +#define MAKE_kQuantizeBlockwiseSmall(dtype, data_type_name) \ + template __global__ void kQuantizeBlockwiseSmall( \ float* code, dtype* __restrict__ const A, float* absmax, unsigned char* out, float* __restrict__ const rand, \ const int rand_offset, const int n \ ); // FP4 instantiations for blocksize=32 -MAKE_kQuantizeBlockwise32(half, FP4) MAKE_kQuantizeBlockwise32(float, FP4) MAKE_kQuantizeBlockwise32(__nv_bfloat16, FP4) +MAKE_kQuantizeBlockwiseSmall(half, FP4) MAKE_kQuantizeBlockwiseSmall(float, FP4) MAKE_kQuantizeBlockwiseSmall( + bnb_bfloat16, FP4 +) // NF4 instantiations for blocksize=32 - MAKE_kQuantizeBlockwise32(half, NF4) MAKE_kQuantizeBlockwise32(float, NF4) MAKE_kQuantizeBlockwise32( - __nv_bfloat16, NF4 + MAKE_kQuantizeBlockwiseSmall(half, NF4) MAKE_kQuantizeBlockwiseSmall(float, NF4) MAKE_kQuantizeBlockwiseSmall( + bnb_bfloat16, NF4 ) template __global__ void kDequantizeBlockwise( @@ -1799,14 +1801,14 @@ template __global__ void kDequantizeBlockwise( template __global__ void kDequantizeBlockwise( float* code, unsigned char* A, float* absmax, float* out, const int blocksize, const int n ); -template __global__ void kDequantizeBlockwise<__nv_bfloat16, 512, 64, 8, FP4>( - float* code, unsigned char* A, float* absmax, __nv_bfloat16* out, const int blocksize, const int n +template __global__ void kDequantizeBlockwise( + float* code, unsigned char* A, float* absmax, bnb_bfloat16* out, const int blocksize, const int n ); -template __global__ void kDequantizeBlockwise<__nv_bfloat16, 512, 64, 8, General8bit>( - float* code, unsigned char* A, float* absmax, __nv_bfloat16* out, const int blocksize, const int n +template __global__ void kDequantizeBlockwise( + float* code, unsigned char* A, float* absmax, bnb_bfloat16* out, const int blocksize, const int n ); -template __global__ void kDequantizeBlockwise<__nv_bfloat16, 512, 64, 8, NF4>( - float* code, unsigned char* A, float* absmax, __nv_bfloat16* out, const int blocksize, const int n +template __global__ void kDequantizeBlockwise( + float* code, unsigned char* A, float* absmax, bnb_bfloat16* out, const int blocksize, const int n ); #define MAKE_OptimizerStatic8bit2StateBlockwise(oname, gtype, block_size, num_per_thread) \ @@ -1819,10 +1821,10 @@ template __global__ void kDequantizeBlockwise<__nv_bfloat16, 512, 64, 8, NF4>( MAKE_OptimizerStatic8bit2StateBlockwise(ADAM, float, 256, 1) MAKE_OptimizerStatic8bit2StateBlockwise(ADAM, half, 256, 1) -MAKE_OptimizerStatic8bit2StateBlockwise(ADAM, __nv_bfloat16, 256, 1) +MAKE_OptimizerStatic8bit2StateBlockwise(ADAM, bnb_bfloat16, 256, 1) MAKE_OptimizerStatic8bit2StateBlockwise(ADEMAMIX, float, 256, 1) MAKE_OptimizerStatic8bit2StateBlockwise(ADEMAMIX, half, 256, 1) -MAKE_OptimizerStatic8bit2StateBlockwise(ADEMAMIX, __nv_bfloat16, 256, 1) +MAKE_OptimizerStatic8bit2StateBlockwise(ADEMAMIX, bnb_bfloat16, 256, 1) #define MAKE_OptimizerStatic8bit1StateBlockwise(oname, gtype, block_size, num_per_thread) \ template __global__ void kOptimizerStatic8bit1StateBlockwise( \ @@ -1833,13 +1835,13 @@ MAKE_OptimizerStatic8bit2StateBlockwise(ADEMAMIX, __nv_bfloat16, 256, 1) MAKE_OptimizerStatic8bit1StateBlockwise(MOMENTUM, float, 256, 1) MAKE_OptimizerStatic8bit1StateBlockwise(MOMENTUM, half, 256, 1) -MAKE_OptimizerStatic8bit1StateBlockwise(MOMENTUM, __nv_bfloat16, 256, 1) +MAKE_OptimizerStatic8bit1StateBlockwise(MOMENTUM, bnb_bfloat16, 256, 1) MAKE_OptimizerStatic8bit1StateBlockwise(RMSPROP, float, 256, 1) MAKE_OptimizerStatic8bit1StateBlockwise(RMSPROP, half, 256, 1) -MAKE_OptimizerStatic8bit1StateBlockwise(RMSPROP, __nv_bfloat16, 256, 1) +MAKE_OptimizerStatic8bit1StateBlockwise(RMSPROP, bnb_bfloat16, 256, 1) MAKE_OptimizerStatic8bit1StateBlockwise(LION, float, 256, 1) MAKE_OptimizerStatic8bit1StateBlockwise(LION, half, 256, 1) -MAKE_OptimizerStatic8bit1StateBlockwise(LION, __nv_bfloat16, 256, 1) +MAKE_OptimizerStatic8bit1StateBlockwise(LION, bnb_bfloat16, 256, 1) MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, float, 256, 1) MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, half, 256, 1) -MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, __nv_bfloat16, 256, 1) +MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, bnb_bfloat16, 256, 1) diff --git a/csrc/kernels.cuh b/csrc/kernels.cuh index 6de55f2e8..6ac6732fc 100644 --- a/csrc/kernels.cuh +++ b/csrc/kernels.cuh @@ -15,7 +15,7 @@ __global__ void kQuantizeBlockwise( const int rand_offset, const int n ); template -__global__ void kQuantizeBlockwise32( +__global__ void kQuantizeBlockwiseSmall( float* code, T* __restrict__ const A, float* absmax, unsigned char* out, float* __restrict__ const rand, const int rand_offset, const int n ); diff --git a/csrc/kernels.hip b/csrc/kernels.hip deleted file mode 100644 index 691f6e07c..000000000 --- a/csrc/kernels.hip +++ /dev/null @@ -1,1923 +0,0 @@ -// !!! This is a file automatically generated by hipify!!! -#include "hip/hip_runtime.h" -// Copyright (c) Facebook, Inc. and its affiliates. -// -// This source code is licensed under the MIT license found in the -// LICENSE file in the root directory of this source tree. - -#include "kernels_hip.cuh" -#include "common_hip.cuh" -#include -#include -#include - -//#include - - -#define HLF_MAX 65504 -#define TH 1024 -#define NUM 4 -#define NUM_BLOCK 4096 - -__device__ static float fp4_dequantization_lut[8] = { - 0.0f, // 0b000 - 0.005208333333f, // 0b001 - 0.66666667f, // 0b010 - 1.0f, // 0b011 - 0.33333333f, // 0b100 - 0.5f, // 0b101 - 0.16666667f, // 0b110 - 0.25f // 0b111 -}; - -__device__ static float nf4_dequantization_lut[16] = { - -1.0f, // 0b0000 - -0.6961928009986877f, // 0b0001 - -0.5250730514526367f, // 0b0010 - -0.39491748809814453f, // 0b0011 - -0.28444138169288635f, // 0b0100 - -0.18477343022823334f, // 0b0101 - -0.09105003625154495f, // 0b0110 - 0.0f, // 0b0111 - 0.07958029955625534f, // 0b1000 - 0.16093020141124725f, // 0b1001 - 0.24611230194568634f, // 0b1010 - 0.33791524171829224f, // 0b1011 - 0.44070982933044434f, // 0b1100 - 0.5626170039176941f, // 0b1101 - 0.7229568362236023f, // 0b1110 - 1.0f // 0b1111 -}; - -// source: https://stackoverflow.com/questions/17399119/how-do-i-use-atomicmax-on-floating-point-values-in-cuda -// Luckily we have atomicmax and atomicmin in ROCm - -__device__ __forceinline__ float dDequantizeFP4Tree(unsigned char val) { - float sign = 1.0f - 2 * ((val & 0b1000) >> 3); - return fp4_dequantization_lut[val & 0b111] * sign; -} - -__device__ unsigned char dQuantizeFP4(float x) -{ - // FP4 with bias of 3 - // first bit is a sign - // subnormals - // 0b000 = 0 - // 0b001 = 0.0625 - // 0b110 = 2 - // 0b111 = 3 - // 0b100 = 4 - // 0b101 = 6 - // 0b010 = 8 - // 0b011 = 12 - - - // we do a binary search - // the pivots are divided by 12 (the FP4 absmax) - // since we assume input data is in [-1.0, 1.0] - - // !be careful here, its easy to make a mistake - // that is difficult to notice if you add an extra - // zero somewhere! - - int sign = x < 0 ? 0b1000 : 0b0000; - x = fabsf(x); - if(x > 0.29166667f) - if( x > 0.583333f) - if( x > 0.8333333f) - return 0b0011+sign; - else - return 0b0010+sign; - else - if(x > 0.4166667f) - return 0b101+sign; - else - return 0b100+sign; - else - if(x > 0.0859375f) - if(x > 0.20833333f) - return 0b0111+sign; - else - return 0b0110+sign; - else - if(x > 0.00260417f) - return 0b0001+sign; - else - return 0b0000+sign; -} - -__device__ __forceinline__ float dDequantizeNF4(unsigned char val) { return nf4_dequantization_lut[val & 0x0F]; } - -__device__ unsigned char dQuantizeNF4(float x) -{ - - // the values for this tree was generated by test_normal_map_tree - // in the file tests/test_functional.py - if(x > 0.03979014977812767f) - if(x > 0.3893125355243683f) // 1 - if(x > 0.6427869200706482f) // 11 - if(x > 0.8614784181118011f) // 111 - return 0b1111; - else - return 0b1110; - else - if(x > 0.5016634166240692f) // 110 - return 0b1101; - else - return 0b1100; - else - if(x > 0.2035212516784668f) // 10 - if(x > 0.2920137718319893f) // 101 - return 0b1011; - else - return 0b1010; - else - if(x > 0.1202552504837513f) // 100 - return 0b1001; - else - return 0b1000; - else - if(x > -0.33967943489551544f) // 0 - if(x > -0.13791173323988914f) // 01 - if(x > -0.045525018125772476f) // 011 - return 0b0111; - else - return 0b0110; - else - if(x > -0.23460740596055984f) // 010 - return 0b0101; - else - return 0b0100; - else - if(x > -0.6106329262256622f) // 00 - if(x > -0.4599952697753906f) // 001 - return 0b0011; - else - return 0b0010; - else - if(x > -0.8480964004993439f) // 000 - return 0b0001; - else - return 0b0000; -} -// sign function for lion -// taken from https://stackoverflow.com/a/4609795, but not sure if there's a proper way to do this in CUDA - -template __device__ int sgn(T val) -{ - return (T(0) < val) - (val < T(0)); -} - -template -__device__ unsigned char dQuantize(float* smem_code, const float rand, float x) -{ - int pivot = 127; - int upper_pivot = 255; - int lower_pivot = 0; - - float lower = -1.0f; - float upper = 1.0f; - - float val = smem_code[pivot]; - // i>>=1 = {32, 16, 8, 4, 2, 1} - for(int i = 64; i > 0; i>>=1) - { - if(x > val) - { - lower_pivot = pivot; - lower = val; - pivot+=i; - } - else - { - upper_pivot = pivot; - upper = val; - pivot-=i; - } - val = smem_code[pivot]; - } - - if(upper_pivot == 255) - upper = smem_code[upper_pivot]; - if(lower_pivot == 0) - lower = smem_code[lower_pivot]; - - if(!STOCHASTIC) - { - if(x > val) - { - float midpoint = (upper+val)*0.5f; - if(x > midpoint) - { - return upper_pivot; - } - else - return pivot; - } - else - { - float midpoint = (lower+val)*0.5f; - if(x < midpoint) - return lower_pivot; - else - return pivot; - } - } - else - { - if(x > val) - { - float dist_to_upper = fabsf(upper-x); - float dist_full = upper-val; - if(rand >= dist_to_upper/dist_full) return upper_pivot; - else return pivot; - } - else - { - float dist_to_lower = fabsf(lower-x); - float dist_full = val-lower; - if(rand >= dist_to_lower/dist_full) return lower_pivot; - else return pivot; - } - } -} - -template -__device__ __forceinline__ unsigned char quantize_2D(float *__restrict__ quadrants, float *__restrict__ const smem_code, float x) -{ - int pivot = 127; - int upper_pivot = 255; - int lower_pivot = 0; - - float lower = SIGNED ? -1.0f : 0.0f; - float upper = 1.0f; - float midpoint; - float val = quadrants[1]; - int local_pivot = 1; - int offset = 1; - - // i>>=1 = {32, 16, 8, 4, 2, 1} - for(int i = 64; i > 0; i>>=1) - { - if(x > val) - { - lower_pivot = pivot; - lower = val; - pivot+=i; - //val = i == 64 ? quadrants[2] : smem_code[pivot]; - local_pivot += offset; - } - else - { - upper_pivot = pivot; - upper = val; - pivot-=i; - //val = i == 64 ? quadrants[0] : smem_code[pivot]; - local_pivot -= offset; - } - val = i >= 64 ? quadrants[local_pivot] : smem_code[pivot]; - offset -= 1; - } - - if(x > val) - { - midpoint = (upper+val)*0.5f; - if(x > midpoint) - return upper_pivot; - else - return pivot; - } - else - { - midpoint = (lower+val)*0.5f; - if(x < midpoint) - return lower_pivot; - else - return pivot; - } -} - -template -//__launch_bounds__(TH, 4) -__global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n) -{ - // This can overflow, so we clamp to INT32_MAX. We won't have more elements than this. - const int n_full = min(gridDim.x * BLOCK_SIZE, INT32_MAX); - int valid_items = 0; - const int base_idx = blockIdx.x * BLOCK_SIZE; - - T vals[NUM_PER_TH]; - float rand_vals[NUM_PER_TH]; - unsigned char qvals[(DATA_TYPE > 0) ? NUM_PER_TH / 2 : NUM_PER_TH]; - - float local_abs_max = 0.0f; - int local_rand_idx = 0; - - typedef hipcub::BlockLoad LoadT; - typedef hipcub::BlockStore 0) ? NUM_PER_TH/2 : NUM_PER_TH, hipcub::BLOCK_STORE_WARP_TRANSPOSE> StoreChar; - typedef hipcub::BlockReduce BlockReduce; - typedef hipcub::BlockLoad LoadFloat; - - __shared__ typename LoadT::TempStorage loadt; - __shared__ typename LoadFloat::TempStorage loadf; - __shared__ typename StoreChar::TempStorage storec; - __shared__ typename BlockReduce::TempStorage reduce; - __shared__ float smem_code[256]; - __shared__ float smem_absmax_value[1]; - - if(DATA_TYPE == General8bit) - for(int i = threadIdx.x; i < 256; i+=blockDim.x) - smem_code[i] = code[i]; - - - for (int64_t i = base_idx; i < n_full; i += gridDim.x * BLOCK_SIZE) { - valid_items = min(BLOCK_SIZE, static_cast(n - i)); - local_abs_max = -FLT_MAX; - - __syncthreads(); - LoadT(loadt).Load(&(A[i]), vals, valid_items, (T)0.0f); - - // 1. compute local max - // 2. broadcast local max - // 3. normalize inputs and quantize - - #pragma unroll NUM_PER_TH - for(int j = 0; j < NUM_PER_TH; j++) - local_abs_max = fmaxf(local_abs_max, fabsf((float)vals[j])); - - local_abs_max = BlockReduce(reduce).Reduce(local_abs_max, hipcub::Max(), valid_items); - - if(threadIdx.x == 0) { - smem_absmax_value[0] = 1.0f / local_abs_max; - absmax[i / BLOCK_SIZE] = local_abs_max; - } - __syncthreads(); - - local_abs_max = smem_absmax_value[0]; - - if(STOCHASTIC) - { - local_rand_idx = ((blockIdx.x*NUM_BLOCK) + (threadIdx.x*NUM) + rand_offset) % (1024-4); - LoadFloat(loadf).Load(&rand[local_rand_idx], rand_vals, BLOCK_SIZE, 0); - } - - switch(DATA_TYPE) - { - case General8bit: - #pragma unroll NUM_PER_TH - for(int j = 0; j < NUM_PER_TH; j++) - { - if(!STOCHASTIC) - qvals[j] = dQuantize<0>(smem_code, 0.0f, ((float)vals[j])*local_abs_max); - else - qvals[j] = dQuantize<1>(smem_code, rand_vals[j], ((float)vals[j])*local_abs_max); - } - break; - case FP4: - #pragma unroll NUM_PER_TH - for(int j = 0; j < NUM_PER_TH/2; j++) - { - qvals[j] = dQuantizeFP4(((float)vals[2 * j]) * local_abs_max) << 4; - qvals[j] |= dQuantizeFP4(((float)vals[2 * j + 1]) * local_abs_max); - } - break; - case NF4: - #pragma unroll NUM_PER_TH - for(int j = 0; j < NUM_PER_TH/2; j++) - { - qvals[j] = dQuantizeNF4(((float)vals[2 * j]) * local_abs_max) << 4; - qvals[j] |= dQuantizeNF4(((float)vals[2 * j + 1]) * local_abs_max); - } - break; - } - - __syncthreads(); - StoreChar(storec).Store(&(out[(DATA_TYPE > 0) ? i/2 : i]), qvals, (DATA_TYPE > 0) ? (valid_items+1)/2 : valid_items); - } -} - -// Specialized kernel for blocksize=64 with 4-bit quantization -// Works on both warp32 and warp64 hardware -// Processes 2 blocks of 64 values per thread block using 64 threads -// Uses logical warps of 32: threads 0-31 handle block 0, threads 32-63 handle block 1 -// - warp32: 2 hardware warps, each reduces naturally -// - warp64: 1 hardware warp split into 2 logical warps of 32 -template -__global__ void kQuantizeBlockwise64( - float* code, T* __restrict__ const A, float* absmax, unsigned char* out, float* __restrict__ const rand, - const int rand_offset, const int n -) { - constexpr int BLOCK_SIZE = 64; // Size of each quantization block - constexpr int NUM_PER_TH = 2; // Values per thread (for 4-bit packing) - constexpr int THREADS = 64; // Total threads per HIP block - constexpr int THREADS_PER_BLOCK = 32; // Threads handling each quantization block - - const int base_idx = blockIdx.x * BLOCK_SIZE * 2; // 2 quantization blocks per HIP block - - T vals[NUM_PER_TH]; - unsigned char qvals[NUM_PER_TH / 2]; // For 4-bit: 2 values per byte - float local_abs_max = 0.0f; - - const int block_id = threadIdx.x / THREADS_PER_BLOCK; // 0 for threads 0-31, 1 for threads 32-63 - const int local_thread_id = threadIdx.x % THREADS_PER_BLOCK; // Thread ID within the quantization block (0-31) - - typedef hipcub::BlockLoad LoadT; - typedef hipcub::BlockStore StoreChar; - // Logical warp size of 32: on warp32 this matches hardware warps, - // on warp64 this splits the single hardware warp into two independent reductions - typedef hipcub::WarpReduce WarpReduce; - - __shared__ typename LoadT::TempStorage loadt; - __shared__ typename StoreChar::TempStorage storec; - __shared__ typename WarpReduce::TempStorage warp_reduce[2]; // One per logical warp - __shared__ float smem_absmax_value[2]; - - const int i = base_idx + block_id * BLOCK_SIZE; - // Use a flag instead of early return: BlockLoad/BlockStore/__syncthreads are cooperative - // operations that require ALL 64 threads to participate - const bool block_valid = (i < n); - - // All 64 threads participate in the load (out-of-bounds threads get 0.0f) - __syncthreads(); - LoadT(loadt).Load(&(A[base_idx]), vals, min(BLOCK_SIZE * 2, n - base_idx), (T)0.0f); - - // Each thread computes max of its values - local_abs_max = -FLT_MAX; -#pragma unroll NUM_PER_TH - for (int j = 0; j < NUM_PER_TH; j++) - local_abs_max = fmaxf(local_abs_max, fabsf((float)vals[j])); - - // Reduce within each logical warp of 32 threads independently - local_abs_max = WarpReduce(warp_reduce[block_id]).Reduce(local_abs_max, hipcub::Max()); - - if (local_thread_id == 0) { - if (block_valid) { - smem_absmax_value[block_id] = 1.0f / local_abs_max; - absmax[blockIdx.x * 2 + block_id] = local_abs_max; - } else { - smem_absmax_value[block_id] = 0.0f; - } - } - __syncthreads(); - - local_abs_max = smem_absmax_value[block_id]; - - switch (DATA_TYPE) { - case FP4: -#pragma unroll NUM_PER_TH - for (int j = 0; j < NUM_PER_TH / 2; j++) { - qvals[j] = dQuantizeFP4(((float)vals[2 * j]) * local_abs_max) << 4; - qvals[j] |= dQuantizeFP4(((float)vals[2 * j + 1]) * local_abs_max); - } - break; - case NF4: -#pragma unroll NUM_PER_TH - for (int j = 0; j < NUM_PER_TH / 2; j++) { - qvals[j] = dQuantizeNF4(((float)vals[2 * j]) * local_abs_max) << 4; - qvals[j] |= dQuantizeNF4(((float)vals[2 * j + 1]) * local_abs_max); - } - break; - } - - // All 64 threads participate in the store (valid_items limits the actual writes) - __syncthreads(); - StoreChar(storec).Store(&(out[base_idx / 2]), qvals, min((BLOCK_SIZE * 2 + 1) / 2, (n - base_idx + 1) / 2)); -} - -template -__global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, T *out, const int blocksize, const int n) -{ - - const int n_load = (gridDim.x * TILE_SIZE); - int valid_items_load = 0; - int valid_items_store = 0; - const int base_idx = (blockIdx.x * TILE_SIZE); - - T vals[NUM_PER_TH*((DATA_TYPE > 0) ? 2 : 1)]; - unsigned char qvals[NUM_PER_TH]; - float local_abs_max = -FLT_MAX; - - typedef hipcub::BlockLoad LoadChar; - typedef hipcub::BlockStore 0) ? 2 : 1), hipcub::BLOCK_STORE_WARP_TRANSPOSE> StoreT; - - __shared__ typename LoadChar::TempStorage loadchar; - __shared__ typename StoreT::TempStorage storet; - - for (int i = base_idx; i < n_load; i += gridDim.x*TILE_SIZE) - { - if (DATA_TYPE > 0) - { - // Cast n to int64_t to avoid overflow for large n - valid_items_load = min(TILE_SIZE, static_cast((static_cast(n) + 1) / 2) - i); - valid_items_store = min(TILE_SIZE * 2, n - i * 2); - } - else - { - valid_items_load = min(TILE_SIZE, n - i); - valid_items_store = valid_items_load; - } - - // Since blocksize will always be a power-of-2, we avoid more expensive - // division by the blocksize and instead use a shift operation. - // This is equivalent to (i+threadId.x*NUM_PER_TH)/blocksize. - local_abs_max = __ldg(&absmax[(i+threadIdx.x*NUM_PER_TH) >> (31 - __clz(blocksize))]); - - __syncthreads(); - LoadChar(loadchar).Load(&(A[i]), qvals, valid_items_load, 128); - - switch (DATA_TYPE) - { - case General8bit: - // load code through read-only cache via __ldg - #pragma unroll NUM_PER_TH - for(int j = 0; j < NUM_PER_TH; j++) - vals[j] = __ldg(&code[qvals[j]])*local_abs_max; - break; - case FP4: - #pragma unroll NUM_PER_TH - for(int j = 0; j < NUM_PER_TH; j++) - { - vals[j * 2] = dDequantizeFP4Tree(qvals[j] >> 4) * local_abs_max; - vals[j * 2 + 1] = dDequantizeFP4Tree(qvals[j] & 0x0F) * local_abs_max; - } - break; - case NF4: - #pragma unroll NUM_PER_TH - for(int j = 0; j < NUM_PER_TH; j++) - { - vals[j*2] = dDequantizeNF4(qvals[j] >> 4)* local_abs_max; - vals[j*2 + 1] = dDequantizeNF4(qvals[j] & 0x0F)* local_abs_max; - } - break; - } - - __syncthreads(); - StoreT(storet).Store(&(out[(DATA_TYPE > 0) ? i*2 : i]), vals, valid_items_store); - } -} - -template -__launch_bounds__(BLOCK_SIZE/NUM_VALS, 1) -__global__ void kPreconditionOptimizer32bit2State(T* g, T* p, - float* state1, float* state2, float *unorm, - const float beta1, const float beta2, const float eps, const float weight_decay, - const int step, const float lr, const float gnorm_scale, const int n) -{ - - const int n_full = (BLOCK_SIZE*(n/BLOCK_SIZE)) + (n % BLOCK_SIZE == 0 ? 0 : BLOCK_SIZE); - const int base_idx = (blockIdx.x * blockDim.x * NUM_VALS); - int valid_items = 0; - - T g_vals[NUM_VALS]; - - float s1_vals[NUM_VALS]; - float s2_vals[NUM_VALS]; - - const float correction1 = 1.0f/(1.0f - powf(beta1, step)); - const float correction2 = 1.0f/(1.0f - powf(beta2, step)); - - typedef hipcub::BlockLoad Load; - typedef hipcub::BlockLoad LoadFloat; - typedef hipcub::BlockReduce BlockReduce; - - __shared__ union { - typename Load::TempStorage load; - typename LoadFloat::TempStorage loadf; - typename BlockReduce::TempStorage reduce; - } temp_storage; - - for (unsigned int i = base_idx; i < n_full; i += gridDim.x*BLOCK_SIZE) - { - valid_items = n - i >= (BLOCK_SIZE) ? (BLOCK_SIZE) : n - i; - - __syncthreads(); - Load(temp_storage.load).Load(&(g[i]), g_vals, valid_items, 0.0f); - __syncthreads(); - LoadFloat(temp_storage.loadf).Load(&(state1[i]), s1_vals, valid_items, 0.0f); - __syncthreads(); - LoadFloat(temp_storage.loadf).Load(&(state2[i]), s2_vals, valid_items, 0.0f); - - # pragma unroll NUM_VALS - for(unsigned int j = 0; j < NUM_VALS; j++) - g_vals[j] = gnorm_scale*((float)g_vals[j]); - - # pragma unroll NUM_VALS - for(unsigned int j = 0; j < NUM_VALS; j++) - { - switch(OPTIMIZER) - { - case ADAM: - s1_vals[j] = s1_vals[j]*beta1 + ((1.0f -beta1)*((float)g_vals[j])); - s2_vals[j] = s2_vals[j]*beta2 + ((1.0f -beta2)*(((float)g_vals[j])*((float)g_vals[j]))); - s1_vals[j] *= correction1; - s2_vals[j] *= correction2; - s1_vals[j] = s1_vals[j]/(sqrtf(s2_vals[j])+eps); // update - s1_vals[j] *= s1_vals[j]; // update l2 norm (update*update) - break; - case ADEMAMIX: - break; - } - } - - # pragma unroll NUM_VALS-1 - for(unsigned int j = 1; j < NUM_VALS; j++) - s1_vals[0] += s1_vals[j]; - - __syncthreads(); - s1_vals[0] = BlockReduce(temp_storage.reduce).Sum(s1_vals[0]); - - if(threadIdx.x == 0) - atomicAdd(&unorm[0], s1_vals[0]); - - //__syncwarp(); - } -} - - - -#define NUM_PER_THREAD 4 - -template -__launch_bounds__(TH, 1) -__global__ void kOptimizer32bit2State(T* g, T* p, - float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm, - const float beta1, const float beta2, const float beta3, const float alpha, const float eps, const float weight_decay, - const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n) -{ - - const int n_full = ((TH*NUM_PER_THREAD)*(n/(TH*NUM_PER_THREAD))) + (n % (TH*NUM_PER_THREAD) == 0 ? 0 : (TH*NUM_PER_THREAD)); - const int base_idx = (blockIdx.x * blockDim.x * NUM_PER_THREAD); - int valid_items = 0; - float update_scale = 0.0f; - T g_vals[NUM_PER_THREAD]; - T p_vals[NUM_PER_THREAD]; - - - float s1_vals[NUM_PER_THREAD]; - float s2_vals[NUM_PER_THREAD]; - - // AdEMAMix has an additional state buffer, which we packed - // into state1. We need thread-local storage here for these. - // TODO: Mark with [[maybe_unused]] after upgrade to min compiler. - float s3_vals[NUM_PER_THREAD]; - - const float correction1 = 1.0f - powf(beta1, step); - const float correction2 = sqrtf(1.0f - powf(beta2, step)); - const float step_size = -lr*correction2/correction1; - - if(max_unorm > 0.0f) - { - update_scale = max_unorm > 0.0f ? sqrtf(unorm[0]) : 1.0f; - if(update_scale > max_unorm*param_norm){ update_scale = (max_unorm*param_norm)/update_scale; } - else{ update_scale = 1.0f; } - } - else{ update_scale = 1.0f; } - - typedef hipcub::BlockLoad Load; - typedef hipcub::BlockStore Store; - - typedef hipcub::BlockLoad LoadFloat; - typedef hipcub::BlockStore StoreFloat; - - __shared__ union { - typename Load::TempStorage load; - typename Store::TempStorage store; - typename LoadFloat::TempStorage loadf; - typename StoreFloat::TempStorage storef; - } temp_storage; - - for (unsigned int i = base_idx; i < n_full; i += gridDim.x*TH*NUM_PER_THREAD) - { - valid_items = n - i >= (TH*NUM_PER_THREAD) ? (TH*NUM_PER_THREAD) : n - i; - - __syncthreads(); - Load(temp_storage.load).Load(&(g[i]), g_vals, valid_items); - __syncthreads(); - LoadFloat(temp_storage.loadf).Load(&(state1[i]), s1_vals, valid_items); - __syncthreads(); - LoadFloat(temp_storage.loadf).Load(&(state2[i]), s2_vals, valid_items); - __syncthreads(); - Load(temp_storage.load).Load(&(p[i]), p_vals, valid_items); - - // Load additional state1 data for AdEMAMix - // TODO: Make constexpr after updating min compiler - if (OPTIMIZER == ADEMAMIX) { - __syncthreads(); - LoadFloat(temp_storage.loadf).Load(&(state1[n + i]), s3_vals, valid_items); - } - - # pragma unroll 4 - for(unsigned int j = 0; j < NUM_PER_THREAD; j++) - g_vals[j] = gnorm_scale*((float)g_vals[j]); - - # pragma unroll 4 - for(unsigned int j = 0; j < NUM_PER_THREAD; j++) - { - switch(OPTIMIZER) - { - case ADEMAMIX: - // m1 update: m1 = beta1 * m1 + (1-beta1) * g - s1_vals[j] = (s1_vals[j] * beta1) + ((1.0f - beta1) * (float)g_vals[j]); - - // m2 update: m2 = m2 * beta3 + (1-beta3) * g - s3_vals[j] = (s3_vals[j] * beta3) + ((1.0f - beta3) * (float)g_vals[j]); - - // nu update: nu = beta2 * nu + (1-beta2) * g^2 - s2_vals[j] = (s2_vals[j] * beta2) + ((1.0f - beta2) * (float)g_vals[j] * (float)g_vals[j]); - - p_vals[j] = (float)p_vals[j] - lr * ( - ((s1_vals[j] / correction1) + (alpha * s3_vals[j])) / ( - (sqrtf(s2_vals[j]) / correction2) + eps - ) - ); - - if (weight_decay > 0.0f) - p_vals[j] = ((float)p_vals[j]) * (1.0f - (lr * weight_decay)); - - break; - case ADAM: - if(!skip_zeros || (skip_zeros && ((float)g_vals[j] != 0.0f))) - { - s1_vals[j] = s1_vals[j]*beta1 + ((1.0f -beta1)*((float)g_vals[j])); - s2_vals[j] = s2_vals[j]*beta2 + ((1.0f -beta2)*(((float)g_vals[j])*((float)g_vals[j]))); - p_vals[j] = ((float)p_vals[j]) + (update_scale*step_size*(s1_vals[j]/(sqrtf(s2_vals[j])+(eps*correction2)))); - - if(weight_decay > 0.0f) - p_vals[j] = ((float)p_vals[j])*(1.0f-(lr*weight_decay)); - } - break; - } - } - - __syncthreads(); - Store(temp_storage.store).Store(&(p[i]), p_vals, valid_items); - __syncthreads(); - StoreFloat(temp_storage.storef).Store(&(state1[i]), s1_vals, valid_items); - __syncthreads(); - StoreFloat(temp_storage.storef).Store(&(state2[i]), s2_vals, valid_items); - - if (OPTIMIZER == ADEMAMIX) { - __syncthreads(); - StoreFloat(temp_storage.storef).Store(&(state1[n + i]), s3_vals, valid_items); - } - } -} - -template -__launch_bounds__(BLOCK_SIZE/NUM_VALS, 1) -__global__ void kPreconditionOptimizer32bit1State(T* g, T* p, - float* state1, float *unorm, - const float beta1, const float beta2, const float eps, const float weight_decay, - const int step, const float lr, const float gnorm_scale, const int n) -{ - - const int n_full = (BLOCK_SIZE*(n/BLOCK_SIZE)) + (n % BLOCK_SIZE == 0 ? 0 : BLOCK_SIZE); - const int base_idx = (blockIdx.x * blockDim.x * NUM_VALS); - int valid_items = 0; - - T g_vals[NUM_VALS]; - - float s1_vals[NUM_VALS]; - - typedef hipcub::BlockLoad Load; - typedef hipcub::BlockLoad LoadFloat; - typedef hipcub::BlockReduce BlockReduce; - - __shared__ union { - typename Load::TempStorage load; - typename LoadFloat::TempStorage loadf; - typename BlockReduce::TempStorage reduce; - } temp_storage; - - for (unsigned int i = base_idx; i < n_full; i += gridDim.x*BLOCK_SIZE) - { - valid_items = n - i >= (BLOCK_SIZE) ? (BLOCK_SIZE) : n - i; - - __syncthreads(); - Load(temp_storage.load).Load(&(g[i]), g_vals, valid_items, 0.0f); - __syncthreads(); - LoadFloat(temp_storage.loadf).Load(&(state1[i]), s1_vals, valid_items, 0.0f); - - # pragma unroll NUM_VALS - for(unsigned int j = 0; j < NUM_VALS; j++) - g_vals[j] = gnorm_scale*((float)g_vals[j]); - - # pragma unroll NUM_VALS - for(unsigned int j = 0; j < NUM_VALS; j++) - { - switch(OPTIMIZER) - { - case MOMENTUM: - if(step == 1) - s1_vals[j] = (float)g_vals[j]; // state update - else - s1_vals[j] = s1_vals[j]*beta1 + ((float)g_vals[j]); // state update - s1_vals[j] = s1_vals[j]*s1_vals[j]; // update norm - break; - case LION: - s1_vals[j] = s1_vals[j]*beta2 + ((1.0f-beta2)*(float)g_vals[j]); // state update - break; - case RMSPROP: - s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*((float)g_vals[j])*((float)g_vals[j])); // state update - s1_vals[j] = __fdividef((float)g_vals[j],sqrtf(s1_vals[j])+eps); // update value - s1_vals[j] = s1_vals[j]*s1_vals[j]; // update norm - break; - case ADAGRAD: - s1_vals[j] = s1_vals[j] + ((float)g_vals[j])*((float)g_vals[j]); // state update - s1_vals[j] = __fdividef((float)g_vals[j],sqrtf(s1_vals[j])+eps); // update value - s1_vals[j] = s1_vals[j]*s1_vals[j]; // update norm - break; - } - } - - # pragma unroll - for(unsigned int j = 1; j < NUM_VALS; j++) - s1_vals[0] += s1_vals[j]; - - __syncthreads(); - s1_vals[0] = BlockReduce(temp_storage.reduce).Sum(s1_vals[0], valid_items); - - if(threadIdx.x == 0) - atomicAdd(&unorm[0], s1_vals[0]); - - //__syncwarp(); - } -} - -template -__launch_bounds__(TH, 1) -__global__ void kOptimizer32bit1State(T *g, T *p, - float *state1, float *unorm, const float max_unorm, const float param_norm, - const float beta1, const float beta2, const float eps, const float weight_decay, - const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n) -{ - - const int n_full = ((TH*NUM_PER_THREAD)*(n/(TH*NUM_PER_THREAD))) + (n % (TH*NUM_PER_THREAD) == 0 ? 0 : (TH*NUM_PER_THREAD)); - const int base_idx = (blockIdx.x * blockDim.x * NUM_PER_THREAD); - int valid_items = 0; - float update_scale = 0.0f; - - if(max_unorm > 0.0f) - { - update_scale = max_unorm > 0.0f ? sqrtf(unorm[0]) : 1.0f; - if(update_scale > max_unorm*param_norm+eps){ update_scale = (max_unorm*param_norm+eps)/update_scale; } - else{ update_scale = 1.0f; } - } - else{ update_scale = 1.0f; } - - T g_vals[NUM_PER_THREAD]; - T p_vals[NUM_PER_THREAD]; - - float s1_vals[NUM_PER_THREAD]; - - typedef hipcub::BlockLoad Load; - typedef hipcub::BlockStore Store; - - typedef hipcub::BlockLoad LoadFloat; - typedef hipcub::BlockStore StoreFloat; - - __shared__ union { - typename Load::TempStorage load; - typename Store::TempStorage store; - typename LoadFloat::TempStorage loadf; - typename StoreFloat::TempStorage storef; - } temp_storage; - - for (unsigned int i = base_idx; i < n_full; i += gridDim.x*TH*NUM_PER_THREAD) - { - valid_items = n - i >= (TH*NUM_PER_THREAD) ? (TH*NUM_PER_THREAD) : n - i; - - __syncthreads(); - Load(temp_storage.load).Load(&(g[i]), g_vals, valid_items); - __syncthreads(); - LoadFloat(temp_storage.loadf).Load(&(state1[i]), s1_vals, valid_items); - __syncthreads(); - Load(temp_storage.load).Load(&(p[i]), p_vals, valid_items); - - # pragma unroll 4 - for(unsigned int j = 0; j < NUM_PER_THREAD; j++) - { - g_vals[j] = gnorm_scale*((float)g_vals[j]); - if(weight_decay > 0.0f) - g_vals[j] = (float)g_vals[j] + (((float)p_vals[j])*weight_decay); - } - - # pragma unroll 4 - for(unsigned int j = 0; j < NUM_PER_THREAD; j++) - { - if(!skip_zeros || (skip_zeros && ((float)g_vals[j] != 0.0f))) - { - switch(OPTIMIZER) - { - case MOMENTUM: - if(step == 1) - s1_vals[j] = (float)g_vals[j]; - else - s1_vals[j] = s1_vals[j]*beta1 + ((float)g_vals[j]); - - p_vals[j] = ((float)p_vals[j]) + update_scale*(-lr*(s1_vals[j])); - break; - case LION: - p_vals[j] = ((float)p_vals[j]) - update_scale*(lr*sgn(((float)s1_vals[j])*beta1 + ((1.0f-beta1)*((float)g_vals[j])))); - s1_vals[j] = s1_vals[j]*beta2 + ((1.0f-beta2)*((float)g_vals[j])); - break; - case RMSPROP: - s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*((float)g_vals[j])*((float)g_vals[j])); - p_vals[j] = ((float)p_vals[j]) - update_scale*(lr*__fdividef((float)g_vals[j],sqrtf((float)s1_vals[j])+eps)); - break; - case ADAGRAD: - s1_vals[j] = s1_vals[j] + ((float)g_vals[j])*((float)g_vals[j]); - p_vals[j] = ((float)p_vals[j]) - lr*__fdividef((float)g_vals[j],sqrtf((float)s1_vals[j])+eps); - break; - } - } - } - - __syncthreads(); - Store(temp_storage.store).Store(&(p[i]), p_vals, valid_items); - __syncthreads(); - StoreFloat(temp_storage.storef).Store(&(state1[i]), s1_vals, valid_items); - } -} - - -#define LANES 2 -#define QUAD 3 -template -__launch_bounds__(256, 3) -__global__ void -kOptimizerStatic8bit2StateBlockwise( - T* p, - T* __restrict__ const g, - unsigned char* state1, - unsigned char* state2, - const float beta1, - const float beta2, - const float beta3, - const float alpha, - const float eps, - const int step, - const float lr, - float* __restrict__ const quantiles1, - float* __restrict__ const quantiles2, - float* absmax1, - float* absmax2, - float weight_decay, - const float gnorm_scale, - const bool skip_zeros, - const int n -) { - - //const int n_full = n + (n%BLOCK_SIZE); - const int n_full = gridDim.x * BLOCK_SIZE; - const int base_idx = (blockIdx.x * BLOCK_SIZE); - int valid_items = 0; - float g_val = 0.0f; - float s1_vals[N_PER_TH]; - float s2_vals[N_PER_TH]; - float s3_vals[N_PER_TH]; - - // 2-5% - const float correction1 = 1.0f - __powf(beta1, step); - const float correction2 = sqrtf(1.0f -__powf(beta2, step)); - const float step_size = __fdividef(-lr*correction2,correction1); - const int lane_id = threadIdx.x % LANES; - float new_local_abs_max1 = -FLT_MAX; - float new_local_abs_max2 = -FLT_MAX; - float new_local_abs_max3 = -FLT_MAX; - float quadrants1[QUAD]; - float quadrants2[QUAD]; - - unsigned char c1s[N_PER_TH]; - unsigned char c2s[N_PER_TH]; - unsigned char c3s[N_PER_TH]; - - T g_vals[N_PER_TH]; - T p_vals[N_PER_TH]; - typedef hipcub::BlockLoad LoadT; - typedef hipcub::BlockLoad LoadChar; - - typedef hipcub::BlockStore StoreChar; - typedef hipcub::BlockStore StoreT; - - __shared__ float smem_quantiles1[LANES][257]; - __shared__ float smem_quantiles2[LANES][257]; - typedef hipcub::BlockReduce BlockReduce1; - typedef hipcub::BlockReduce BlockReduce2; - typedef hipcub::BlockReduce BlockReduce3; - __shared__ typename BlockReduce1::TempStorage reduce1; - __shared__ typename BlockReduce2::TempStorage reduce2; - __shared__ typename BlockReduce2::TempStorage reduce3; - __shared__ float smem_exchange1[1]; - __shared__ float smem_exchange2[1]; - __shared__ float smem_exchange3[1]; // [[maybe_unused]] - - __shared__ union { - typename LoadT::TempStorage loadh; - typename LoadChar::TempStorage loadc; - typename StoreChar::TempStorage storec; - typename StoreT::TempStorage storeh; - } temp_storage; - // init: 0.2 -> 0.23 - - // 0.23 -> 0.23 - smem_quantiles1[0][threadIdx.x] = quantiles1[threadIdx.x]; - smem_quantiles2[0][threadIdx.x] = quantiles2[threadIdx.x]; - # pragma unroll - for(unsigned int j = 1; j < LANES; j++) - { - smem_quantiles1[j][threadIdx.x] = smem_quantiles1[0][threadIdx.x]; - smem_quantiles2[j][threadIdx.x] = smem_quantiles2[0][threadIdx.x]; - } - - __syncthreads(); - - #pragma unroll - for(int k = 0; k < QUAD; k++) - { - quadrants1[k] = smem_quantiles1[lane_id][(k*256/(QUAD+1)) + (256/(QUAD+1)-1)]; - quadrants2[k] = smem_quantiles2[lane_id][(k*256/(QUAD+1)) + (256/(QUAD+1)-1)]; - } - - - for (unsigned int i = base_idx; i < n_full; i += gridDim.x*BLOCK_SIZE) - { - // loads: 0.23 -> 0.85/1.44 - valid_items = n - i >= BLOCK_SIZE ? BLOCK_SIZE : n - i; - __syncthreads(); - LoadT(temp_storage.loadh).Load(&(g[i]), g_vals, valid_items, (T)0.0f); - __syncthreads(); - LoadChar(temp_storage.loadc).Load(&(state1[i]), c1s, valid_items, 128); - __syncthreads(); - LoadChar(temp_storage.loadc).Load(&(state2[i]), c2s, valid_items, 0); - - // AdEMAMix has an additional state packed into state1. - if (OPTIMIZER == ADEMAMIX) { - __syncthreads(); - LoadChar(temp_storage.loadc).Load(&(state1[n + i]), c3s, valid_items, 128); - } - - new_local_abs_max1 = -FLT_MAX; - new_local_abs_max2 = -FLT_MAX; - new_local_abs_max3 = -FLT_MAX; - - // update: 2.48/1.57 -> 2.51/1.60 - # pragma unroll N_PER_TH - for(unsigned int j = 0; j < N_PER_TH; j++) - { - if(!isnan((float)g_vals[j]) && !isinf((float)g_vals[j])) - { - s2_vals[j] = smem_quantiles2[lane_id][c2s[j]]*absmax2[i/BLOCK_SIZE]; - g_val = g_vals[j]; - //float ratio = (g_val*g_val)/fmaxf(s2_vals[j], eps*eps); - //g_val = ratio > 2.0f ? 2.0f*g_val/ratio : g_val; - g_val *= gnorm_scale; - - s2_vals[j] = (s2_vals[j]*beta2) + (((1.0f-beta2)*g_val*g_val)); - - s1_vals[j] = smem_quantiles1[lane_id][c1s[j]]*absmax1[i/BLOCK_SIZE]; - s1_vals[j] = (s1_vals[j]*beta1) + (((1.0f-beta1)*g_val)); - - if (OPTIMIZER == ADEMAMIX) { - // The absmax for the third state is appended to absmax1 - s3_vals[j] = smem_quantiles1[lane_id][c3s[j]] * absmax1[(n + i)/BLOCK_SIZE]; - s3_vals[j] = (s3_vals[j] * beta3) + (((1.0f - beta3) * g_val)); - } - } - else - { - s1_vals[j] = 0.0f; - s2_vals[j] = 0.0f; - - if (OPTIMIZER == ADEMAMIX) { - s3_vals[j] = 0.0f; - } - } - - new_local_abs_max1 = fmaxf(new_local_abs_max1, fabsf(s1_vals[j])); - new_local_abs_max2 = fmaxf(new_local_abs_max2, fabsf(s2_vals[j])); - - if (OPTIMIZER == ADEMAMIX) { - new_local_abs_max3 = fmaxf(new_local_abs_max3, fabsf(s3_vals[j])); - } - } - - - // reduce: 2.51/1.60 -> 2.67/1.69 - new_local_abs_max1 = BlockReduce1(reduce1).Reduce(new_local_abs_max1, hipcub::Max()); - new_local_abs_max2 = BlockReduce2(reduce2).Reduce(new_local_abs_max2, hipcub::Max()); - - if (OPTIMIZER == ADEMAMIX) { - new_local_abs_max3 = BlockReduce3(reduce3).Reduce(new_local_abs_max3, hipcub::Max()); - } - - if(threadIdx.x == 0) - { - smem_exchange1[0] = new_local_abs_max1; - smem_exchange2[0] = new_local_abs_max2; - - if (OPTIMIZER == ADEMAMIX) { - smem_exchange3[0] = new_local_abs_max3; - } - } - - __syncthreads(); - - if(threadIdx.x == 0) - { - absmax1[i/BLOCK_SIZE] = new_local_abs_max1; - absmax2[i/BLOCK_SIZE] = new_local_abs_max2; - - if (OPTIMIZER == ADEMAMIX) { - absmax1[(n + i)/BLOCK_SIZE] = new_local_abs_max3; - } - } - else - { - new_local_abs_max1 = smem_exchange1[0]; - new_local_abs_max2 = smem_exchange2[0]; - - if (OPTIMIZER == ADEMAMIX) { - new_local_abs_max3 = smem_exchange3[0]; - } - } - - __syncthreads(); - LoadT(temp_storage.loadh).Load(&(p[i]), p_vals, valid_items, (T)0.0f); - // reduce: 2.67/1.69 -> 2.67/1.70 - # pragma unroll N_PER_TH - for(unsigned int j = 0; j < N_PER_TH; j++) - { - //if(!skip_zeros || (skip_zeros && ((float)g_vals[j] != 0.0f))) - if(!isnan((float)g_vals[j]) && !isinf((float)g_vals[j])) - { - if (OPTIMIZER == ADEMAMIX) { - p_vals[j] = T((float)p_vals[j] - lr * ( - ((s1_vals[j] / correction1) + (alpha * s3_vals[j])) / ( - (sqrtf(s2_vals[j]) / correction2) + eps - ) - )); - } else { - p_vals[j] = (T)(((float)p_vals[j]) + ((step_size*(__fdividef(s1_vals[j],(sqrtf(s2_vals[j])+(correction2*eps))))))); - } - - if(weight_decay > 0.0f) - p_vals[j] = ((float)p_vals[j])*(1.0f-(lr*weight_decay)); - } - } - - // store: 0.85/1.44 -> 2.48/1.57 - __syncthreads(); - StoreT(temp_storage.storeh).Store(&(p[i]), p_vals, valid_items); - - // quantizaztion: 2.67/1.70 -> 3.4/3.3 - # pragma unroll N_PER_TH - for(unsigned int j = 0; j < N_PER_TH; j++) - { - c1s[j] = quantize_2D<1>(quadrants1, smem_quantiles1[lane_id], __fdividef(s1_vals[j],new_local_abs_max1)); - c2s[j] = quantize_2D<0>(quadrants2, smem_quantiles2[lane_id], __fdividef(s2_vals[j],new_local_abs_max2)); - - // make sure state1 term has still the same sign after quantization - // (not needed for state2 term which has only positive values) - if(signbit(smem_quantiles1[lane_id][c1s[j]]) != signbit(s1_vals[j])) - { - if(s1_vals[j] > 0.0f) - c1s[j] += 1; - else - c1s[j] -= 1; - } - - if (OPTIMIZER == ADEMAMIX) { - c3s[j] = quantize_2D<1>(quadrants1, smem_quantiles1[lane_id], __fdividef(s3_vals[j],new_local_abs_max3)); - - if (signbit(smem_quantiles1[lane_id][c3s[j]]) != signbit(s3_vals[j])) { - c3s[j] += (s3_vals[j] > 0.0f) ? 1 : -1; - } - } - } - - __syncthreads(); - StoreChar(temp_storage.storec).Store(&(state1[i]), c1s, valid_items); - __syncthreads(); - StoreChar(temp_storage.storec).Store(&(state2[i]), c2s, valid_items); - - if (OPTIMIZER == ADEMAMIX) { - __syncthreads(); - StoreChar(temp_storage.storec).Store(&(state1[n + i]), c3s, valid_items); - } - } -} - - -#define LANES 2 -#define QUAD 3 -template -__launch_bounds__(256, 3) -__global__ void -kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char* state1, - const float beta1, const float beta2, - const float eps, const int step, const float lr, - float* __restrict__ const quantiles1, - float* absmax1, - float weight_decay, - const float gnorm_scale, const bool skip_zeros, const int n) -{ - - //const int n_full = n + (n%BLOCK_SIZE); - const int n_full = gridDim.x * BLOCK_SIZE; - const int base_idx = (blockIdx.x * BLOCK_SIZE); - int valid_items = 0; - float g_val = 0.0f; - float s1_vals[N_PER_TH]; - // 2-5% - const int lane_id = threadIdx.x % LANES; - float new_local_abs_max1 = -FLT_MAX; - float quadrants1[QUAD]; - - unsigned char c1s[N_PER_TH]; - T g_vals[N_PER_TH]; - T p_vals[N_PER_TH]; - - typedef hipcub::BlockLoad LoadT; - typedef hipcub::BlockLoad LoadChar; - - typedef hipcub::BlockStore StoreChar; - typedef hipcub::BlockStore StoreT; - - __shared__ float smem_quantiles1[LANES][257]; - typedef hipcub::BlockReduce BlockReduce1; - __shared__ typename BlockReduce1::TempStorage reduce1; - __shared__ float smem_exchange1[1]; - - __shared__ union { - typename LoadT::TempStorage loadh; - typename LoadChar::TempStorage loadc; - typename StoreChar::TempStorage storec; - typename StoreT::TempStorage storeh; - } temp_storage; - // init: 0.2 -> 0.23 - - // 0.23 -> 0.23 - smem_quantiles1[0][threadIdx.x] = quantiles1[threadIdx.x]; - # pragma unroll - for(unsigned int j = 1; j < LANES; j++) - smem_quantiles1[j][threadIdx.x] = smem_quantiles1[0][threadIdx.x]; - - __syncthreads(); - - #pragma unroll - for(int k = 0; k < QUAD; k++) - quadrants1[k] = smem_quantiles1[lane_id][(k*256/(QUAD+1)) + (256/(QUAD+1)-1)]; - - for (unsigned int i = base_idx; i < n_full; i += gridDim.x*BLOCK_SIZE) - { - // loads: 0.23 -> 0.85/1.44 - valid_items = n - i >= BLOCK_SIZE ? BLOCK_SIZE : n - i; - __syncthreads(); - LoadT(temp_storage.loadh).Load(&(g[i]), g_vals, valid_items, (T)0.0f); - __syncthreads(); - LoadChar(temp_storage.loadc).Load(&(state1[i]), c1s, valid_items, 128); - __syncthreads(); - LoadT(temp_storage.loadh).Load(&(p[i]), p_vals, valid_items, (T)0.0f); - - new_local_abs_max1 = -FLT_MAX; - - // update: 2.48/1.57 -> 2.51/1.60 - # pragma unroll N_PER_TH - for(unsigned int j = 0; j < N_PER_TH; j++) - { - g_val = float(g_vals[j]); - g_val *= gnorm_scale; - if(!skip_zeros || (skip_zeros && ((float)g_vals[j] != 0.0f))) - { - if(weight_decay > 0.0f) { - switch(OPTIMIZER) { - case MOMENTUM: - case ADAGRAD: - case RMSPROP: - g_val += ((float)p_vals[j])*weight_decay; - break; - case LION: - p_vals[j] = ((float)p_vals[j])*(1.0f-lr*weight_decay); - break; - } - } - - s1_vals[j] = smem_quantiles1[lane_id][c1s[j]]*absmax1[i/BLOCK_SIZE]; - - switch(OPTIMIZER) - { - case MOMENTUM: - if(step == 1) - s1_vals[j] = g_val; - else - s1_vals[j] = (s1_vals[j]*beta1) + g_val; - break; - case LION: - // here, using gvals[j] to store the gradient smoothed by beta1 for the following parameter update, before the momentum is updated by beta2 - g_vals[j] = lr*sgn(((float)s1_vals[j])*beta1 + ((1.0f-beta1)*g_val)); - s1_vals[j] = s1_vals[j]*beta2 + ((1.0f-beta2)*g_val); - break; - case RMSPROP: - s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*(g_val*g_val)); - break; - case ADAGRAD: - s1_vals[j] = s1_vals[j] + (g_val*g_val); - break; - } - } - - new_local_abs_max1 = fmaxf(new_local_abs_max1, fabsf(s1_vals[j])); - } - - - // reduce: 2.51/1.60 -> 2.67/1.69 - new_local_abs_max1 = BlockReduce1(reduce1).Reduce(new_local_abs_max1, hipcub::Max()); - - if(threadIdx.x == 0) - smem_exchange1[0] = new_local_abs_max1; - - __syncthreads(); - - if(threadIdx.x == 0) - absmax1[i/BLOCK_SIZE] = new_local_abs_max1; - else - new_local_abs_max1 = smem_exchange1[0]; - - // reduce: 2.67/1.69 -> 2.67/1.70 - # pragma unroll N_PER_TH - for(unsigned int j = 0; j < N_PER_TH; j++) - { - if(!skip_zeros || (skip_zeros && ((float)g_vals[j] != 0.0f))) - { - switch(OPTIMIZER) - { - case MOMENTUM: - p_vals[j] = ((float)p_vals[j]) - lr*(s1_vals[j]); - break; - case LION: - p_vals[j] = ((float)p_vals[j]) - ((float)g_vals[j]); - break; - case RMSPROP: - g_val = g_vals[j]; - p_vals[j] = ((float)p_vals[j]) - lr*(__fdividef(g_val, sqrtf(s1_vals[j])+eps)); - break; - case ADAGRAD: - g_val = g_vals[j]; - p_vals[j] = ((float)p_vals[j]) - lr*(__fdividef(g_val, sqrtf(s1_vals[j])+eps)); - break; - } - } - } - - // store: 0.85/1.44 -> 2.48/1.57 - __syncthreads(); - StoreT(temp_storage.storeh).Store(&(p[i]), p_vals, valid_items); - - // quantizaztion: 2.67/1.70 -> 3.4/3.3 - # pragma unroll N_PER_TH - for(unsigned int j = 0; j < N_PER_TH; j++) - { - c1s[j] = quantize_2D<1>(quadrants1, smem_quantiles1[lane_id], __fdividef(s1_vals[j],new_local_abs_max1)); - - // make sure state1 term has still the same sign after quantization - // (not needed for state2 term which has only positive values) - if(signbit(smem_quantiles1[lane_id][c1s[j]]) != signbit(s1_vals[j])) - { - if(s1_vals[j] > 0.0f) - c1s[j] += 1; - else - c1s[j] -= 1; - } - } - - __syncthreads(); - StoreChar(temp_storage.storec).Store(&(state1[i]), c1s, valid_items); - } -} - -// Inputs: -// A [rows, cols] -// Outputs: -// rowStats [rows] -// out [rows, cols] -template -__launch_bounds__(1024, BNB_MAX_THREADS_PER_CU / 1024) -__global__ void kInt8VectorQuant(T * __restrict__ A, int8_t* out, float* rowStats, float threshold, int rows, int cols) { - - // For sm50/sm52 and CUDA < 12.2 we need to do the reduction in fp32. - // Otherwise `T` is `fp16`. This can be removed when Maxwell is dropped. -#if (__CUDACC_VER_MAJOR__ >= 12 && __CUDACC_VER_MINOR >= 2) || BNB_FP16_AVAILABLE - using TReduction = T; -#else - using TReduction = float; -#endif - - using BlockReduceT = hipcub::BlockReduce; - - // One block per row. - // Threads load column values in a striped arrangement. - // e.g. t0 reads row[0], row[0+nthreads], .. - // and t1 reads row[1], row[1+nthreads], .. - // Each thread will determine its local absmax. - // We then do a blockwise reduction to determine the row's absmax. - - __shared__ typename BlockReduceT::TempStorage temp_storage; - __shared__ TReduction smem_row_absmax; - - const int row_id = blockIdx.x; - const T* row_data = A + (row_id * cols); - - // Threads will read the row values in a striped access pattern and find a local absmax. - TReduction row_local_absmax = -FLT_MIN; - for (int i = threadIdx.x; i < cols; i += THREADS) { - const TReduction absval = fabsf(__ldcs(&(row_data[i]))); - - // For sparse decomposition, values outside of the threshold are not to be - // included when calculating the row's absmax. - if constexpr (SPARSE_DECOMP) { - row_local_absmax = fmaxf(row_local_absmax, absval < TReduction(threshold) ? absval : row_local_absmax); - } else { - row_local_absmax = fmaxf(row_local_absmax, absval); - } - } - - // Reduce thread-local absmax across the block. - const TReduction row_absmax = BlockReduceT(temp_storage).Reduce(row_local_absmax, hipcub::Max(), cols); - if (threadIdx.x == 0) { - // Save our block's absmax to shared memory for the quantization step. - rowStats[row_id] = smem_row_absmax = row_absmax; - } - __syncthreads(); - - // Quantize row-wise. - const float scale = __fdividef(127.0f, smem_row_absmax); - for (int i = threadIdx.x; i < cols; i += THREADS) { - float val = row_data[i]; - - if constexpr (SPARSE_DECOMP) { - // For sparse decomposition, we do not want to quantize the outliers. - // Instead they're zeroed out. - out[row_id * cols + i] = fabs(val) < threshold ? __float2int_rn(val * scale) : 0; - } else { - out[row_id * cols + i] = __float2int_rn(val * scale); - } - } -} - -template __global__ void kInt8VectorQuant(half * __restrict__ A, int8_t *out, float *rowStats, float threshold, int rows, int cols); -template __global__ void kInt8VectorQuant(half * __restrict__ A, int8_t *out, float *rowStats, float threshold, int rows, int cols); - - -#define MM_DEQUANT_CONST 6.200012e-05f //1.0f/(127.0f*127.0f) - -template -__global__ void kdequant_mm_int32_fp16( - int* __restrict__ const A, - float *__restrict__ const rowStats, - float *__restrict__ const colStats, - half *out, - half *__restrict__ const bias, - const int numRows, - const int numCols, - const int n -) { - const int n_out = numRows * numCols; - - int block_offset = blockIdx.x * THREADS * ITEMS_PER_THREAD; - int thread_offset = threadIdx.x * ITEMS_PER_THREAD; - - int local_values[ITEMS_PER_THREAD]; - half local_output[ITEMS_PER_THREAD]; - - float local_rowStats[ITEMS_PER_THREAD]; - float local_colStats[ITEMS_PER_THREAD]; - float local_biasValue[ITEMS_PER_THREAD]; - - typedef hipcub::BlockLoad LoadInt32; - __shared__ typename LoadInt32::TempStorage loadint32; - - int row_idx, col_idx; - - #pragma unroll ITEMS_PER_THREAD - for(int j = 0; j < ITEMS_PER_THREAD; j++) - { - row_idx = (block_offset + thread_offset + j) / numCols; - col_idx = (block_offset + thread_offset + j) % numCols; - - local_colStats[j] = col_idx >= numCols ? 0.0f : colStats[col_idx]; - local_rowStats[j] = row_idx >= numRows ? 0.0f : rowStats[row_idx]; - local_biasValue[j] = ((bias == nullptr) || (col_idx >= numCols)) ? 0.0f : __half2float(bias[col_idx]); - } - - // Each block loads THREADS * ITEMS_PER_THREAD values from A - int valid_items = block_offset + THREADS * ITEMS_PER_THREAD < n_out - ? THREADS * ITEMS_PER_THREAD - : n_out - block_offset; - LoadInt32(loadint32).Load(&(A[block_offset]), local_values, valid_items, 0); - - #pragma unroll ITEMS_PER_THREAD - for (int j = 0; j < ITEMS_PER_THREAD; ++j) { - local_output[j] = __float2half( - fmaf(local_values[j] * local_rowStats[j] * local_colStats[j], MM_DEQUANT_CONST, local_biasValue[j]) - ); - } - - #pragma unroll ITEMS_PER_THREAD - for (int j = 0; j < ITEMS_PER_THREAD; j++) { - int outIdx = block_offset + thread_offset + j; - if (outIdx < n_out) { - out[outIdx] = local_output[j]; - } - } -} - - -// No of 4bit values processed by each thread -#define num_values_4bit 32 -template __global__ void kgemm_4bit_inference_naive(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, T * out, int lda, int ldb, int ldc, int blocksize) -{ - - // per threadblock: - // load step-by-step in chunks of [warp_size,warps]: 1xwarp_size * [warp_size,warps] -> [1,warps] - // 4 warps -> 4 loads per iter - // 1xwarp_size * warp_sizex4 -> 1x4 outputs per thread block - typedef hipcub::WarpReduce WarpReduce; - __shared__ typename WarpReduce::TempStorage temp_storage[THREADS/BNB_WARP_SIZE]; - - const int warp_idx = threadIdx.x / BNB_WARP_SIZE; - const int warp_lane = threadIdx.x % BNB_WARP_SIZE; - const int row_B = (THREADS/BNB_WARP_SIZE)*blockIdx.x + warp_idx; - const int offset_B = ldb * row_B; - const int num_values_8bit = num_values_4bit/2; - float local_C = 0.0f; - - unsigned char local_B_4bit[num_values_8bit]; - T local_B[num_values_4bit/4]; - T local_A[num_values_4bit/4]; - __shared__ T quant_map[16]; - T local_absmax = T(0.0f); - - if (threadIdx.x < 16) - quant_map[threadIdx.x] = T(__ldg(&datatype[threadIdx.x])); - //for(int i = threadIdx.x; i < 16; i++) - //quant_map[i] = T(datatype[i]); - __syncthreads(); - - // A: [1, K] - // B: [M, K] - for(int inner_idx = warp_lane*num_values_4bit; inner_idx < K; inner_idx += BNB_WARP_SIZE*num_values_4bit) - { - const int inner_idx_halved = inner_idx/2; - - // Since blocksize will always be a power-of-2, we avoid more expensive - // division by the blocksize and instead use a shift operation. - // This is equivalent to (i+threadId.x*NUM_PER_TH)/blocksize. - const int absidx = ((2*offset_B)+inner_idx) >> (31 - __clz(blocksize)); - - local_absmax = __ldg(&(absmax[absidx])); - - if(row_B < M) - { - if((inner_idx_halved + num_values_8bit) < (K/2)) - { - // this is the most important for performance considerations - reinterpret_cast(local_B_4bit)[0] = reinterpret_cast(B)[(offset_B+(inner_idx_halved))/(num_values_8bit)]; - } - else - { - #pragma unroll - for(int j = 0; j < (num_values_8bit); j++) - if((inner_idx_halved) + j < (K/2)) - local_B_4bit[j] = B[offset_B+inner_idx_halved + j]; - else - local_B_4bit[j] = 0b01110111; - } - } - else - { - #pragma unroll - for(int j = 0; j < (num_values_8bit); j++) - local_B_4bit[j] = 0b01110111; - } - - for(int i = 0; i < 4; i++) - { - #pragma unroll - for(int k = 0; k < num_values_8bit/4; k++) - { - #if BNB_BF16_AVAILABLE - local_B[k*2] = quant_map[local_B_4bit[(i*num_values_8bit/4) + k] >> 4]*local_absmax; - local_B[k*2 + 1] = quant_map[local_B_4bit[(i*num_values_8bit/4) + k] & 0x0F]*local_absmax; - #else - // bf16 multipliation not supported - local_B[k*2] = T((float)quant_map[local_B_4bit[(i*num_values_8bit/4) + k] >> 4]*(float)local_absmax); - local_B[k*2 + 1] = T((float)quant_map[local_B_4bit[(i*num_values_8bit/4) + k] & 0x0F]*(float)local_absmax); - #endif - } - - if(inner_idx+(num_values_4bit/4) + (i*num_values_4bit/4) < K) - { - // this is also relatively important for performance - if(BITS==16) - { - reinterpret_cast(local_A)[0] = reinterpret_cast(A)[inner_idx/(num_values_4bit/4) + i]; - } - else - { - reinterpret_cast(local_A)[0] = reinterpret_cast(A)[inner_idx/(num_values_4bit/8) + (2*i) + 0]; - reinterpret_cast(local_A)[1] = reinterpret_cast(A)[inner_idx/(num_values_4bit/8) + (2*i) + 1]; - } - - } - else - #pragma unroll - for(int k = 0; k < num_values_4bit/4; k++) - if(inner_idx + (i*num_values_4bit/4) + k < K) - local_A[k] = A[inner_idx + k + (i*num_values_4bit/4)]; - else - local_A[k] = T(0.0f); - - - // accumulate in float; small performance hit for Ampere, but lower error for outputs - #pragma unroll - for(int k = 0; k < num_values_4bit/4; k++) - { - #if BNB_BF16_AVAILABLE - local_C += (float)(local_A[k]*local_B[k]); - #else - // bf16 multipliation not supported - local_C += ((float)local_A[k]*(float)local_B[k]); - #endif - } - } - } - - local_C = WarpReduce(temp_storage[warp_idx]).Sum(local_C); - - if(row_B < M && warp_lane == 0) - out[row_B] = T(local_C); - -} - - -template __global__ void kfunc(T *A, T *B, T value, long n) -{ - for(long i = (blockDim.x*blockIdx.x) + threadIdx.x; i < n; i+=(blockDim.x*gridDim.x)) - { - switch(FUNC) - { - case FILL: - A[i] = (T)value; - break; - case ARANGE: - A[i] = (T)i; - break; - case _MUL: - A[i] = A[i]*B[i]; - break; - } - } -} - - -//============================================================== -// TEMPLATE DEFINITIONS -//============================================================== - -template __global__ void kfunc(float *A, float *B, float value, long n); -template __global__ void kfunc(unsigned char *A, unsigned char *B, unsigned char value, long n); -template __global__ void kfunc(float *A, float *B, float value, long n); -template __global__ void kfunc(float *A, float *B, float value, long n); - -template __global__ void kgemm_4bit_inference_naive(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, half * out, int lda, int ldb, int ldc, int blocksize); -template __global__ void kgemm_4bit_inference_naive(int M, int N, int K, hip_bfloat16 * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, hip_bfloat16 * out, int lda, int ldb, int ldc, int blocksize); -template __global__ void kgemm_4bit_inference_naive(int M, int N, int K, float * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, float * out, int lda, int ldb, int ldc, int blocksize); - - -template __global__ void kdequant_mm_int32_fp16<4, 512>(int *__restrict__ const A, float *__restrict__ const rowStats, float *__restrict__ const colStats, half *out, half * __restrict__ const bias, const int numRows, const int numCols, const int n); - -template __device__ unsigned char dQuantize<0>(float* smem_code, const float rand, float x); -template __device__ unsigned char dQuantize<1>(float* smem_code, const float rand, float x); - -#define MAKE_PreconditionOptimizer32bit1State(oname, gtype) \ -template __global__ void kPreconditionOptimizer32bit1State(gtype* g, gtype* p, \ - float* state1, float *unorm, \ - const float beta1, const float beta2, const float eps, const float weight_decay, \ - const int step, const float lr, const float gnorm_scale, const int n); \ - -MAKE_PreconditionOptimizer32bit1State(MOMENTUM, half) -MAKE_PreconditionOptimizer32bit1State(MOMENTUM, float) -MAKE_PreconditionOptimizer32bit1State(MOMENTUM, hip_bfloat16) -MAKE_PreconditionOptimizer32bit1State(RMSPROP, half) -MAKE_PreconditionOptimizer32bit1State(RMSPROP, float) -MAKE_PreconditionOptimizer32bit1State(RMSPROP, hip_bfloat16) -MAKE_PreconditionOptimizer32bit1State(LION, half) -MAKE_PreconditionOptimizer32bit1State(LION, float) -MAKE_PreconditionOptimizer32bit1State(LION, hip_bfloat16) -MAKE_PreconditionOptimizer32bit1State(ADAGRAD, half) -MAKE_PreconditionOptimizer32bit1State(ADAGRAD, float) -MAKE_PreconditionOptimizer32bit1State(ADAGRAD, hip_bfloat16) - -#define MAKE_Optimizer32bit1State(oname, gtype) \ -template __global__ void kOptimizer32bit1State(gtype* g, gtype* p, float* state1, float *unorm, const float max_unorm, const float param_norm, \ - const float beta1, const float beta2, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); \ - -MAKE_Optimizer32bit1State(MOMENTUM, half) -MAKE_Optimizer32bit1State(MOMENTUM, float) -MAKE_Optimizer32bit1State(MOMENTUM, hip_bfloat16) -MAKE_Optimizer32bit1State(RMSPROP, half) -MAKE_Optimizer32bit1State(RMSPROP, float) -MAKE_Optimizer32bit1State(RMSPROP, hip_bfloat16) -MAKE_Optimizer32bit1State(LION, half) -MAKE_Optimizer32bit1State(LION, float) -MAKE_Optimizer32bit1State(LION, hip_bfloat16) -MAKE_Optimizer32bit1State(ADAGRAD, half) -MAKE_Optimizer32bit1State(ADAGRAD, float) -MAKE_Optimizer32bit1State(ADAGRAD, hip_bfloat16) - -#define MAKE_PreconditionOptimizer32bit2State(oname, gtype) \ -template __global__ void kPreconditionOptimizer32bit2State(gtype* g, gtype* p, \ - float* state1, float* state2, float *unorm, \ - const float beta1, const float beta2, const float eps, const float weight_decay, \ - const int step, const float lr, const float gnorm_scale, const int n); \ - -MAKE_PreconditionOptimizer32bit2State(ADAM, float) -MAKE_PreconditionOptimizer32bit2State(ADAM, half) -MAKE_PreconditionOptimizer32bit2State(ADAM, hip_bfloat16) -MAKE_PreconditionOptimizer32bit2State(ADEMAMIX, float) -MAKE_PreconditionOptimizer32bit2State(ADEMAMIX, half) -MAKE_PreconditionOptimizer32bit2State(ADEMAMIX, hip_bfloat16) - -template __global__ void kOptimizer32bit2State(float* g, float* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm, - const float beta1, const float beta2, const float beta3, const float alpha, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); -template __global__ void kOptimizer32bit2State(half* g, half* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm, - const float beta1, const float beta2, const float beta3, const float alpha, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); -template __global__ void kOptimizer32bit2State(hip_bfloat16* g, hip_bfloat16* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm, - const float beta1, const float beta2, const float beta3, const float alpha, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); -template __global__ void kOptimizer32bit2State(float* g, float* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm, - const float beta1, const float beta2, const float beta3, const float alpha, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); -template __global__ void kOptimizer32bit2State(half* g, half* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm, - const float beta1, const float beta2, const float beta3, const float alpha, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); -template __global__ void kOptimizer32bit2State(hip_bfloat16* g, hip_bfloat16* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm, - const float beta1, const float beta2, const float beta3, const float alpha, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); - - -#define MAKE_kQuantizeBlockwise(dtype, blocksize, num_per_thread, stochastic, data_type_name) \ -template __global__ void kQuantizeBlockwise(float * code, dtype * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); \ - -MAKE_kQuantizeBlockwise(half, 4096, 4, 0, General8bit) -MAKE_kQuantizeBlockwise(half, 4096, 4, 1, General8bit) -MAKE_kQuantizeBlockwise(half, 2048, 4, 0, General8bit) -MAKE_kQuantizeBlockwise(half, 1024, 4, 0, General8bit) -MAKE_kQuantizeBlockwise(half, 512, 2, 0, General8bit) -MAKE_kQuantizeBlockwise(half, 256, 2, 0, General8bit) -MAKE_kQuantizeBlockwise(half, 128, 2, 0, General8bit) -#if BNB_WARP_SIZE == 32 - MAKE_kQuantizeBlockwise(half, 64, 2, 0, General8bit) -#endif - -MAKE_kQuantizeBlockwise(half, 4096, 4, 0, FP4) -MAKE_kQuantizeBlockwise(half, 2048, 4, 0, FP4) -MAKE_kQuantizeBlockwise(half, 1024, 4, 0, FP4) -MAKE_kQuantizeBlockwise(half, 512, 2, 0, FP4) -MAKE_kQuantizeBlockwise(half, 256, 2, 0, FP4) -MAKE_kQuantizeBlockwise(half, 128, 2, 0, FP4) -#if BNB_WARP_SIZE == 32 - MAKE_kQuantizeBlockwise(half, 64, 2, 0, FP4) -#endif - -MAKE_kQuantizeBlockwise(half, 4096, 4, 0, NF4) -MAKE_kQuantizeBlockwise(half, 2048, 4, 0, NF4) -MAKE_kQuantizeBlockwise(half, 1024, 4, 0, NF4) -MAKE_kQuantizeBlockwise(half, 512, 2, 0, NF4) -MAKE_kQuantizeBlockwise(half, 256, 2, 0, NF4) -MAKE_kQuantizeBlockwise(half, 128, 2, 0, NF4) -#if BNB_WARP_SIZE == 32 - MAKE_kQuantizeBlockwise(half, 64, 2, 0, NF4) -#endif - -MAKE_kQuantizeBlockwise(float, 4096, 4, 0, General8bit) -MAKE_kQuantizeBlockwise(float, 4096, 4, 1, General8bit) -MAKE_kQuantizeBlockwise(float, 2048, 4, 0, General8bit) -MAKE_kQuantizeBlockwise(float, 1024, 4, 0, General8bit) -MAKE_kQuantizeBlockwise(float, 512, 2, 0, General8bit) -MAKE_kQuantizeBlockwise(float, 256, 2, 0, General8bit) -MAKE_kQuantizeBlockwise(float, 128, 2, 0, General8bit) -#if BNB_WARP_SIZE == 32 - MAKE_kQuantizeBlockwise(float, 64, 2, 0, General8bit) -#endif - -MAKE_kQuantizeBlockwise(float, 4096, 4, 0, FP4) -MAKE_kQuantizeBlockwise(float, 2048, 4, 0, FP4) -MAKE_kQuantizeBlockwise(float, 1024, 4, 0, FP4) -MAKE_kQuantizeBlockwise(float, 512, 2, 0, FP4) -MAKE_kQuantizeBlockwise(float, 256, 2, 0, FP4) -MAKE_kQuantizeBlockwise(float, 128, 2, 0, FP4) -#if BNB_WARP_SIZE == 32 - MAKE_kQuantizeBlockwise(float, 64, 2, 0, FP4) -#endif - -MAKE_kQuantizeBlockwise(float, 4096, 4, 0, NF4) -MAKE_kQuantizeBlockwise(float, 2048, 4, 0, NF4) -MAKE_kQuantizeBlockwise(float, 1024, 4, 0, NF4) -MAKE_kQuantizeBlockwise(float, 512, 2, 0, NF4) -MAKE_kQuantizeBlockwise(float, 256, 2, 0, NF4) -MAKE_kQuantizeBlockwise(float, 128, 2, 0, NF4) -#if BNB_WARP_SIZE == 32 - MAKE_kQuantizeBlockwise(float, 64, 2, 0, NF4) -#endif - -MAKE_kQuantizeBlockwise(hip_bfloat16, 4096, 4, 0, General8bit) -MAKE_kQuantizeBlockwise(hip_bfloat16, 4096, 4, 1, General8bit) -MAKE_kQuantizeBlockwise(hip_bfloat16, 2048, 4, 0, General8bit) -MAKE_kQuantizeBlockwise(hip_bfloat16, 1024, 4, 0, General8bit) -MAKE_kQuantizeBlockwise(hip_bfloat16, 512, 2, 0, General8bit) -MAKE_kQuantizeBlockwise(hip_bfloat16, 256, 2, 0, General8bit) -MAKE_kQuantizeBlockwise(hip_bfloat16, 128, 2, 0, General8bit) -#if BNB_WARP_SIZE == 32 - MAKE_kQuantizeBlockwise(hip_bfloat16, 64, 2, 0, General8bit) -#endif - -MAKE_kQuantizeBlockwise(hip_bfloat16, 4096, 4, 0, FP4) -MAKE_kQuantizeBlockwise(hip_bfloat16, 2048, 4, 0, FP4) -MAKE_kQuantizeBlockwise(hip_bfloat16, 1024, 4, 0, FP4) -MAKE_kQuantizeBlockwise(hip_bfloat16, 512, 2, 0, FP4) -MAKE_kQuantizeBlockwise(hip_bfloat16, 256, 2, 0, FP4) -MAKE_kQuantizeBlockwise(hip_bfloat16, 128, 2, 0, FP4) -#if BNB_WARP_SIZE == 32 - MAKE_kQuantizeBlockwise(hip_bfloat16, 64, 2, 0, FP4) -#endif - -MAKE_kQuantizeBlockwise(hip_bfloat16, 4096, 4, 0, NF4) -MAKE_kQuantizeBlockwise(hip_bfloat16, 2048, 4, 0, NF4) -MAKE_kQuantizeBlockwise(hip_bfloat16, 1024, 4, 0, NF4) -MAKE_kQuantizeBlockwise(hip_bfloat16, 512, 2, 0, NF4) -MAKE_kQuantizeBlockwise(hip_bfloat16, 256, 2, 0, NF4) -MAKE_kQuantizeBlockwise(hip_bfloat16, 128, 2, 0, NF4) -#if BNB_WARP_SIZE == 32 - MAKE_kQuantizeBlockwise(hip_bfloat16, 64, 2, 0, NF4) -#endif - -// Specialized blocksize=64 4-bit quantization kernel instantiations for ROCm -#define MAKE_kQuantizeBlockwise64(dtype, data_type_name) \ -template __global__ void kQuantizeBlockwise64(float * code, dtype * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); - -// FP4 instantiations -MAKE_kQuantizeBlockwise64(half, FP4) -MAKE_kQuantizeBlockwise64(float, FP4) -MAKE_kQuantizeBlockwise64(hip_bfloat16, FP4) - -// NF4 instantiations -MAKE_kQuantizeBlockwise64(half, NF4) -MAKE_kQuantizeBlockwise64(float, NF4) -MAKE_kQuantizeBlockwise64(hip_bfloat16, NF4) - -template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, half *out, const int blocksize, const int n); -template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, half *out, const int blocksize, const int n); -template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, half *out, const int blocksize, const int n); -template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, float *out, const int blocksize, const int n); -template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, float *out, const int blocksize, const int n); -template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, float *out, const int blocksize, const int n); -template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, hip_bfloat16 *out, const int blocksize, const int n); -template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, hip_bfloat16 *out, const int blocksize, const int n); -template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, hip_bfloat16 *out, const int blocksize, const int n); - -#define MAKE_OptimizerStatic8bit2StateBlockwise(oname, gtype, block_size, num_per_thread) \ -template __global__ void kOptimizerStatic8bit2StateBlockwise(gtype* p, gtype* __restrict__ const g, unsigned char* state1, unsigned char* state2, \ - const float beta1, const float beta2, const float beta3, const float alpha, \ - const float eps, const int step, const float lr, \ - float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, \ - float* absmax1, float* absmax2, \ - float weight_decay, \ - const float gnorm_scale, const bool skip_zeros, const int n); \ - -MAKE_OptimizerStatic8bit2StateBlockwise(ADAM, float, 256, 1) -MAKE_OptimizerStatic8bit2StateBlockwise(ADAM, half, 256, 1) -MAKE_OptimizerStatic8bit2StateBlockwise(ADAM, hip_bfloat16, 256, 1) -MAKE_OptimizerStatic8bit2StateBlockwise(ADEMAMIX, float, 256, 1) -MAKE_OptimizerStatic8bit2StateBlockwise(ADEMAMIX, half, 256, 1) -MAKE_OptimizerStatic8bit2StateBlockwise(ADEMAMIX, hip_bfloat16, 256, 1) - -#define MAKE_OptimizerStatic8bit1StateBlockwise(oname, gtype, block_size, num_per_thread) \ -template __global__ void kOptimizerStatic8bit1StateBlockwise( \ - gtype* p, gtype* __restrict__ const g, unsigned char* state1, \ - const float beta1, const float beta2, \ - const float eps, const int step, const float lr, \ - float* __restrict__ const quantiles1, \ - float* absmax1, \ - float weight_decay, \ - const float gnorm_scale, const bool skip_zeros, const int n); \ - -MAKE_OptimizerStatic8bit1StateBlockwise(MOMENTUM, float, 256, 1) -MAKE_OptimizerStatic8bit1StateBlockwise(MOMENTUM, half, 256, 1) -MAKE_OptimizerStatic8bit1StateBlockwise(MOMENTUM, hip_bfloat16, 256, 1) -MAKE_OptimizerStatic8bit1StateBlockwise(RMSPROP, float, 256, 1) -MAKE_OptimizerStatic8bit1StateBlockwise(RMSPROP, half, 256, 1) -MAKE_OptimizerStatic8bit1StateBlockwise(RMSPROP, hip_bfloat16, 256, 1) -MAKE_OptimizerStatic8bit1StateBlockwise(LION, float, 256, 1) -MAKE_OptimizerStatic8bit1StateBlockwise(LION, half, 256, 1) -MAKE_OptimizerStatic8bit1StateBlockwise(LION, hip_bfloat16, 256, 1) -MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, float, 256, 1) -MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, half, 256, 1) -MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, hip_bfloat16, 256, 1) diff --git a/csrc/kernels_hip.cuh b/csrc/kernels_hip.cuh deleted file mode 100644 index 0e2885693..000000000 --- a/csrc/kernels_hip.cuh +++ /dev/null @@ -1,87 +0,0 @@ -// !!! This is a file automatically generated by hipify!!! -#include "hip/hip_runtime.h" -// Copyright (c) Facebook, Inc. and its affiliates. -// -// This source code is licensed under the MIT license found in the -// LICENSE file in the root directory of this source tree. - -#include -#include - -#ifndef kernels -#define kernels - -template -__global__ void kQuantizeBlockwise( - float* code, T* __restrict__ const A, float* absmax, unsigned char* out, float* __restrict__ const rand, - const int rand_offset, const int n -); -template -__global__ void kQuantizeBlockwise64( - float* code, T* __restrict__ const A, float* absmax, unsigned char* out, float* __restrict__ const rand, - const int rand_offset, const int n -); -template -__global__ void - kDequantizeBlockwise(float* code, unsigned char* A, float* absmax, T* out, const int blocksize, const int n); - -template -__global__ void kPreconditionOptimizer32bit2State( - T* g, T* p, float* state1, float* state2, float* unorm, const float beta1, const float beta2, const float eps, - const float weight_decay, const int step, const float lr, const float gnorm_scale, const int n -); - -template -__global__ void kOptimizer32bit2State( - T* g, T* p, float* state1, float* state2, float* unorm, const float max_unorm, const float param_norm, - const float beta1, const float beta2, const float beta3, const float alpha, const float eps, - const float weight_decay, const int step, const float lr, const float gnorm_scale, const bool skip_zeros, - const int n -); - -template -__global__ void kPreconditionOptimizer32bit1State( - T* g, T* p, float* state1, float* unorm, const float beta1, const float beta2, const float eps, - const float weight_decay, const int step, const float lr, const float gnorm_scale, const int n -); - -template -__global__ void kOptimizer32bit1State( - T* g, T* p, float* state1, float* unorm, const float max_unorm, const float param_norm, const float beta1, - const float beta2, const float eps, const float weight_decay, const int step, const float lr, - const float gnorm_scale, const bool skip_zeros, const int n -); - -template -__global__ void kOptimizerStatic8bit2StateBlockwise( - T* p, T* __restrict__ const g, unsigned char* state1, unsigned char* state2, const float beta1, const float beta2, - const float beta3, const float alpha, const float eps, const int step, const float lr, - float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, float* absmax1, float* absmax2, - float weight_decay, const float gnorm_scale, const bool skip_zeros, const int n -); - -template -__global__ void kOptimizerStatic8bit1StateBlockwise( - T* p, T* __restrict__ const g, unsigned char* state1, const float beta1, const float beta2, const float eps, - const int step, const float lr, float* __restrict__ const quantiles1, float* absmax1, float weight_decay, - const float gnorm_scale, const bool skip_zeros, const int n -); - -template -__global__ void kdequant_mm_int32_fp16( - int* __restrict__ const A, float* __restrict__ const rowStats, float* __restrict__ const colStats, half* out, - half* __restrict__ const bias, const int numRows, const int numCols, const int n -); - -template -__global__ void kInt8VectorQuant(T* __restrict__ A, int8_t* out, float* rowStats, float threshold, int rows, int cols); - -template -__global__ void kgemm_4bit_inference_naive( - int M, int N, int K, T* __restrict__ const A, unsigned char* B, float* absmax, const float* datatype, T* out, - int lda, int ldb, int ldc, int blocksize -); - -template __global__ void kfunc(T* A, T* B, T value, long n); - -#endif diff --git a/csrc/ops.cu b/csrc/ops.cu index 88bb675a3..ef13678e4 100644 --- a/csrc/ops.cu +++ b/csrc/ops.cu @@ -4,7 +4,6 @@ // LICENSE file in the root directory of this source tree. #include -#include #include #include #include @@ -34,23 +33,34 @@ void quantizeBlockwise( kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n); else if (blocksize == 128) kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n); - else if (blocksize == 64) + else if (blocksize == 64) { +#if BNB_HIP + // On HIP with 64-wide warps (CDNA), use specialized kernel for 4-bit types + if constexpr (DATA_TYPE > 0) { + kQuantizeBlockwiseSmall + <<<(num_blocks + 1) / 2, 64>>>(code, A, absmax, out, rand, rand_offset, n); + } else { + kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n); + } +#else kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n); - else if (blocksize == 32) { - // For 4-bit: use specialized kernel (kQuantizeBlockwise32) that processes 2 blocks per warp +#endif + } else if (blocksize == 32) { + // For 4-bit: use specialized kernel that processes 2 blocks per warp // Each CUDA block handles 2 quantization blocks, so divide num_blocks by 2 - if (DATA_TYPE > 0) { + if constexpr (DATA_TYPE > 0) { int num_blocks_adjusted = (num_blocks + 1) / 2; - kQuantizeBlockwise32<<>>(code, A, absmax, out, rand, rand_offset, n); + kQuantizeBlockwiseSmall + <<>>(code, A, absmax, out, rand, rand_offset, n); } } - CUDA_CHECK_RETURN(cudaPeekAtLastError()); + BNB_CHECK_RETURN(BNB_PEEK_LAST_ERROR()); } template void dequantizeBlockwise( - float* code, unsigned char* A, float* absmax, T* out, int blocksize, const int n, cudaStream_t stream + float* code, unsigned char* A, float* absmax, T* out, int blocksize, const int n, bnb_stream_t stream ) { constexpr int tile_size = (DATA_TYPE > 0) ? 1024 : 512; @@ -64,7 +74,7 @@ void dequantizeBlockwise( kDequantizeBlockwise <<>>(code, A, absmax, out, blocksize, n); - CUDA_CHECK_RETURN(cudaPeekAtLastError()); + BNB_CHECK_RETURN(BNB_PEEK_LAST_ERROR()); } template @@ -79,33 +89,33 @@ void optimizer32bit( case ADAM: case ADEMAMIX: if (max_unorm > 0.0f) { - CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1 * sizeof(float))); + BNB_CHECK_RETURN(BNB_DEVICE_MEMSET(unorm, 0, 1 * sizeof(float))); kPreconditionOptimizer32bit2State<<>>( g, p, state1, state2, unorm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n ); - CUDA_CHECK_RETURN(cudaPeekAtLastError()); + BNB_CHECK_RETURN(BNB_PEEK_LAST_ERROR()); } kOptimizer32bit2State<<>>( g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, beta3, alpha, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n ); - CUDA_CHECK_RETURN(cudaPeekAtLastError()); + BNB_CHECK_RETURN(BNB_PEEK_LAST_ERROR()); break; case MOMENTUM: case RMSPROP: case ADAGRAD: if (max_unorm > 0.0f) { - CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1 * sizeof(float))); + BNB_CHECK_RETURN(BNB_DEVICE_MEMSET(unorm, 0, 1 * sizeof(float))); kPreconditionOptimizer32bit1State <<>>(g, p, state1, unorm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n); - CUDA_CHECK_RETURN(cudaPeekAtLastError()); + BNB_CHECK_RETURN(BNB_PEEK_LAST_ERROR()); } kOptimizer32bit1State<<>>( g, p, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n ); - CUDA_CHECK_RETURN(cudaPeekAtLastError()); + BNB_CHECK_RETURN(BNB_PEEK_LAST_ERROR()); break; case LION: // in lion, the momentum update after the parameter update @@ -113,13 +123,13 @@ void optimizer32bit( g, p, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n ); - CUDA_CHECK_RETURN(cudaPeekAtLastError()); + BNB_CHECK_RETURN(BNB_PEEK_LAST_ERROR()); if (max_unorm > 0.0f) { - CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1 * sizeof(float))); + BNB_CHECK_RETURN(BNB_DEVICE_MEMSET(unorm, 0, 1 * sizeof(float))); kPreconditionOptimizer32bit1State <<>>(g, p, state1, unorm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n); - CUDA_CHECK_RETURN(cudaPeekAtLastError()); + BNB_CHECK_RETURN(BNB_PEEK_LAST_ERROR()); } break; } @@ -148,7 +158,7 @@ void optimizerStatic8bitBlockwise( p, g, state1, state2, beta1, beta2, beta3, alpha, eps, step, lr, quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, skip_zeros, n ); - CUDA_CHECK_RETURN(cudaPeekAtLastError()); + BNB_CHECK_RETURN(BNB_PEEK_LAST_ERROR()); break; case MOMENTUM: case RMSPROP: @@ -160,7 +170,7 @@ void optimizerStatic8bitBlockwise( <<>>( p, g, state1, beta1, beta2, eps, step, lr, quantiles1, absmax1, weight_decay, gnorm_scale, skip_zeros, n ); - CUDA_CHECK_RETURN(cudaPeekAtLastError()); + BNB_CHECK_RETURN(BNB_PEEK_LAST_ERROR()); break; } } @@ -173,6 +183,27 @@ void gemmex( const int fbeta = 0; const void* alpha = &falpha; const void* beta = &fbeta; + +#if BNB_HIP + hipblasStatus_t status; + +#if hipblasVersionMajor >= 3 + status = hipblasGemmEx( + context->m_handle, transposeA ? HIPBLAS_OP_T : HIPBLAS_OP_N, transposeB ? HIPBLAS_OP_T : HIPBLAS_OP_N, m, n, k, + alpha, A, HIP_R_8I, lda, B, HIP_R_8I, ldb, beta, C, HIP_R_32I, ldc, HIPBLAS_COMPUTE_32I, HIPBLAS_GEMM_DEFAULT + ); +#else + status = hipblasGemmEx( + context->m_handle, transposeA ? HIPBLAS_OP_T : HIPBLAS_OP_N, transposeB ? HIPBLAS_OP_T : HIPBLAS_OP_N, m, n, k, + alpha, A, HIPBLAS_R_8I, lda, B, HIPBLAS_R_8I, ldb, beta, C, HIPBLAS_R_32I, ldc, HIPBLAS_R_32I, + HIPBLAS_GEMM_DEFAULT + ); +#endif + + if (status != HIPBLAS_STATUS_SUCCESS) { + std::cout << "HIPBLAS ERROR: Status " << status << std::endl; + } +#else cublasStatus_t status; status = cublasGemmEx( @@ -183,6 +214,7 @@ void gemmex( if (status != CUBLAS_STATUS_SUCCESS) { std::cout << "CUBLAS ERROR: Status " << status << std::endl; } +#endif } void strided_gemmex( @@ -193,13 +225,29 @@ void strided_gemmex( const int fbeta = 0; const void* alpha = &falpha; const void* beta = &fbeta; - cublasStatus_t status; - // cout << transposeA << transposeB << endl; - // printf("%i %i %i\n", m,n,k); - // printf("%i %i %i\n", lda,ldb,ldc); - // printf("%i %i %i\n", strideA, strideB, strideC); - // printf("%i\n", batchCount); +#if BNB_HIP + hipblasStatus_t status; + +#if hipblasVersionMajor >= 3 + status = hipblasGemmStridedBatchedEx( + context->m_handle, transposeA ? HIPBLAS_OP_T : HIPBLAS_OP_N, transposeB ? HIPBLAS_OP_T : HIPBLAS_OP_N, m, n, k, + alpha, A, HIP_R_8I, lda, (long long int)strideA, B, HIP_R_8I, ldb, (long long int)strideB, beta, C, HIP_R_32I, + ldc, (long long int)strideC, batchCount, HIPBLAS_COMPUTE_32I, HIPBLAS_GEMM_DEFAULT + ); +#else + status = hipblasGemmStridedBatchedEx( + context->m_handle, transposeA ? HIPBLAS_OP_T : HIPBLAS_OP_N, transposeB ? HIPBLAS_OP_T : HIPBLAS_OP_N, m, n, k, + alpha, A, HIPBLAS_R_8I, lda, (long long int)strideA, B, HIPBLAS_R_8I, ldb, (long long int)strideB, beta, C, + HIPBLAS_R_32I, ldc, (long long int)strideC, batchCount, HIPBLAS_R_32I, HIPBLAS_GEMM_DEFAULT + ); +#endif + + if (status != HIPBLAS_STATUS_SUCCESS) { + std::cout << "HIPBLAS ERROR: Status " << status << std::endl; + } +#else + cublasStatus_t status; status = cublasGemmStridedBatchedEx( context->m_handle, transposeA ? CUBLAS_OP_T : CUBLAS_OP_N, transposeB ? CUBLAS_OP_T : CUBLAS_OP_N, m, n, k, @@ -210,16 +258,21 @@ void strided_gemmex( if (status != CUBLAS_STATUS_SUCCESS) { std::cout << "CUBLAS ERROR: Status " << status << std::endl; } +#endif } int roundoff(int v, int d) { return (v + d - 1) / d * d; } template int igemmlt( - cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t* A, const int8_t* B, void* C, float* row_scale, - int lda, int ldb, int ldc, cudaStream_t stream + bnb_blasLt_handle_t ltHandle, int m, int n, int k, const int8_t* A, const int8_t* B, void* C, float* row_scale, + int lda, int ldb, int ldc, bnb_stream_t stream ) { +#if BNB_HIP && defined(NO_HIPBLASLT) + return ERR_NOT_IMPLEMENTED; +#else + // Calculate C = A^T @ B, in col-major layout. // // Use the IMMA kernels requires: @@ -229,62 +282,92 @@ int igemmlt( int has_error = 0; - cublasLtMatmulDesc_t matmulDesc; - cublasLtMatrixLayout_t aDesc, bDesc, cDesc; - cublasOperation_t opT = CUBLAS_OP_T; + bnb_blasLt_matmul_desc_t matmulDesc; + bnb_blasLt_layout_t aDesc, bDesc, cDesc; + auto opT = BNB_BLASLT_OP_T; - cudaDataType_t outType = DTYPE_OUT == 32 ? CUDA_R_32I : CUDA_R_8I; - cudaDataType_t scaleType = DTYPE_OUT == 32 ? CUDA_R_32I : CUDA_R_32F; + auto outType = DTYPE_OUT == 32 ? BNB_R_32I : BNB_R_8I; + auto scaleType = DTYPE_OUT == 32 ? BNB_R_32I : BNB_R_32F; - cublasLtPointerMode_t pointerMode = CUBLASLT_POINTER_MODE_ALPHA_DEVICE_VECTOR_BETA_ZERO; + auto pointerMode = BNB_BLASLT_PTR_MODE_ALPHA_VEC; - has_error |= checkCublasStatus(cublasLtMatrixLayoutCreate(&aDesc, CUDA_R_8I, m, k, lda)); - has_error |= checkCublasStatus(cublasLtMatrixLayoutCreate(&bDesc, CUDA_R_8I, m, n, ldb)); - has_error |= checkCublasStatus(cublasLtMatrixLayoutCreate(&cDesc, outType, k, n, ldc)); + has_error |= checkBlasLtStatus(bnb_blasLtLayoutCreate(&aDesc, BNB_R_8I, m, k, lda)); + has_error |= checkBlasLtStatus(bnb_blasLtLayoutCreate(&bDesc, BNB_R_8I, m, n, ldb)); + has_error |= checkBlasLtStatus(bnb_blasLtLayoutCreate(&cDesc, outType, k, n, ldc)); // Default layout order is col major - has_error |= checkCublasStatus(cublasLtMatmulDescCreate(&matmulDesc, CUBLAS_COMPUTE_32I, scaleType)); - has_error |= - checkCublasStatus(cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_TRANSA, &opT, sizeof(opT))); + has_error |= checkBlasLtStatus(bnb_blasLtMatmulDescCreate(&matmulDesc, BNB_BLASLT_COMPUTE_32I, scaleType)); + has_error |= checkBlasLtStatus(bnb_blasLtMatmulDescSetAttr(matmulDesc, BNB_BLASLT_DESC_TRANSA, &opT, sizeof(opT))); if (DTYPE_OUT == 32) { +#if BNB_HIP + // HIP requires heuristic algo selection + const int64_t max_workspace_size = 0; // set to 0 to avoid choosing GSU kernel + + bnb_blasLt_preference_t pref; + checkBlasLtStatus(bnb_blasLtPrefCreate(&pref)); + checkBlasLtStatus( + bnb_blasLtPrefSetAttr(pref, BNB_BLASLT_PREF_MAX_WORKSPACE, &max_workspace_size, sizeof(max_workspace_size)) + ); + + const int request_solutions = 1; + bnb_blasLt_heuristic_t heuristicResult[request_solutions]; + int returnedAlgoCount = 0; + checkBlasLtStatus(bnb_blasLtAlgoGetHeuristic( + ltHandle, matmulDesc, aDesc, bDesc, cDesc, cDesc, pref, request_solutions, heuristicResult, + &returnedAlgoCount + )); + + if (returnedAlgoCount == 0) { + has_error = 1; + fprintf(stderr, "Error: Matmul Algo Heuristic didn't return algorithms\n"); + } else { + int alpha = 1, beta = 0; + has_error |= checkBlasLtStatus(bnb_blasLtMatmul( + ltHandle, matmulDesc, &alpha, A, aDesc, B, bDesc, &beta, (int32_t*)C, cDesc, (int32_t*)C, cDesc, + &heuristicResult[0].algo, NULL, 0, stream + )); + } +#else int alpha = 1, beta = 0; - has_error |= checkCublasStatus(cublasLtMatmul( + has_error |= checkBlasLtStatus(bnb_blasLtMatmul( ltHandle, matmulDesc, &alpha, A, aDesc, B, bDesc, &beta, (int32_t*)C, cDesc, (int32_t*)C, cDesc, NULL, NULL, 0, stream )); +#endif } else { // This path is unlikely to be used, as 8-bit accumulation can lead to likely overflows. if (!SCALE_ROWS) { float alpha = 1.0f, beta = 0.0f; - has_error |= checkCublasStatus(cublasLtMatmul( + has_error |= checkBlasLtStatus(bnb_blasLtMatmul( ltHandle, matmulDesc, &alpha, A, aDesc, B, bDesc, &beta, (int8_t*)C, cDesc, (int8_t*)C, cDesc, NULL, NULL, 0, stream )); } else { - cublasLtPointerMode_t alphaVec = CUBLASLT_POINTER_MODE_ALPHA_DEVICE_VECTOR_BETA_HOST; + auto alphaVec = BNB_BLASLT_PTR_MODE_ALPHA_VEC; float beta = 0.0f; - has_error |= checkCublasStatus(cublasLtMatmulDescSetAttribute( - matmulDesc, CUBLASLT_MATMUL_DESC_POINTER_MODE, &pointerMode, sizeof(alphaVec) - )); - has_error |= checkCublasStatus(cublasLtMatmul( + has_error |= checkBlasLtStatus( + bnb_blasLtMatmulDescSetAttr(matmulDesc, BNB_BLASLT_DESC_POINTER_MODE, &pointerMode, sizeof(alphaVec)) + ); + has_error |= checkBlasLtStatus(bnb_blasLtMatmul( ltHandle, matmulDesc, row_scale, A, aDesc, B, bDesc, &beta, (int8_t*)C, cDesc, (int8_t*)C, cDesc, NULL, NULL, 0, stream )); } } - has_error |= checkCublasStatus(cublasLtMatrixLayoutDestroy(cDesc)); - has_error |= checkCublasStatus(cublasLtMatrixLayoutDestroy(bDesc)); - has_error |= checkCublasStatus(cublasLtMatrixLayoutDestroy(aDesc)); - has_error |= checkCublasStatus(cublasLtMatmulDescDestroy(matmulDesc)); + has_error |= checkBlasLtStatus(bnb_blasLtLayoutDestroy(cDesc)); + has_error |= checkBlasLtStatus(bnb_blasLtLayoutDestroy(bDesc)); + has_error |= checkBlasLtStatus(bnb_blasLtLayoutDestroy(aDesc)); + has_error |= checkBlasLtStatus(bnb_blasLtMatmulDescDestroy(matmulDesc)); if (has_error == 1) printf("error detected"); return has_error; +#endif // NO_HIPBLASLT } int fill_up_to_nearest_multiple(int value, int multiple) { @@ -292,7 +375,7 @@ int fill_up_to_nearest_multiple(int value, int multiple) { } void dequant_mm_int32_fp16( - int* A, float* rowStats, float* colStats, half* out, half* bias, int numRows, int numCols, cudaStream_t stream + int* A, float* rowStats, float* colStats, half* out, half* bias, int numRows, int numCols, bnb_stream_t stream ) { const int threads = 512; const int num_per_thread = 4; @@ -302,30 +385,37 @@ void dequant_mm_int32_fp16( kdequant_mm_int32_fp16 <<>>(A, rowStats, colStats, out, bias, numRows, numCols, n); - CUDA_CHECK_RETURN(cudaPeekAtLastError()); + BNB_CHECK_RETURN(BNB_PEEK_LAST_ERROR()); } void int8VectorQuant( - half* __restrict__ A, int8_t* out, float* rowStats, float threshold, int rows, int cols, cudaStream_t stream + half* __restrict__ A, int8_t* out, float* rowStats, float threshold, int rows, int cols, bnb_stream_t stream ) { if (threshold == 0.0) { kInt8VectorQuant<<>>(A, out, rowStats, threshold, rows, cols); } else { kInt8VectorQuant<<>>(A, out, rowStats, threshold, rows, cols); } - CUDA_CHECK_RETURN(cudaPeekAtLastError()); + BNB_CHECK_RETURN(BNB_PEEK_LAST_ERROR()); } template void gemm_4bit_inference_naive( int m, int n, int k, T* A, unsigned char* B, float* absmax, float* datatype, T* out, int lda, int ldb, int ldc, - int blocksize, cudaStream_t stream + int blocksize, bnb_stream_t stream ) { int num_blocks = (m + 3) / 4; +#if BNB_HIP + // On 64-wide warp architectures, each warp processes 2 rows instead of 4 + if (BNB_WARP_SIZE == 64) { + num_blocks = (m + 1) / 2; + } +#endif + kgemm_4bit_inference_naive <<>>(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize); - CUDA_CHECK_RETURN(cudaPeekAtLastError()); + BNB_CHECK_RETURN(BNB_PEEK_LAST_ERROR()); } template void func(T* A, T* B, T value, long n) { @@ -334,7 +424,7 @@ template void func(T* A, T* B, T value, long n) { blocks = n % threads == 0 ? blocks : blocks + 1; blocks = blocks > 65535 ? 65535 : blocks; kfunc<<>>(A, B, value, n); - CUDA_CHECK_RETURN(cudaPeekAtLastError()); + BNB_CHECK_RETURN(BNB_PEEK_LAST_ERROR()); } //============================================================== @@ -348,28 +438,28 @@ template void func(float* A, float* B, float value, long n); template void gemm_4bit_inference_naive( int m, int n, int k, half* A, unsigned char* B, float* absmax, float* datatype, half* out, int lda, int ldb, - int ldc, int blocksize, cudaStream_t stream + int ldc, int blocksize, bnb_stream_t stream ); -template void gemm_4bit_inference_naive<__nv_bfloat16, 16>( - int m, int n, int k, __nv_bfloat16* A, unsigned char* B, float* absmax, float* datatype, __nv_bfloat16* out, - int lda, int ldb, int ldc, int blocksize, cudaStream_t stream +template void gemm_4bit_inference_naive( + int m, int n, int k, bnb_bfloat16* A, unsigned char* B, float* absmax, float* datatype, bnb_bfloat16* out, int lda, + int ldb, int ldc, int blocksize, bnb_stream_t stream ); template void gemm_4bit_inference_naive( int m, int n, int k, float* A, unsigned char* B, float* absmax, float* datatype, float* out, int lda, int ldb, - int ldc, int blocksize, cudaStream_t stream + int ldc, int blocksize, bnb_stream_t stream ); template int igemmlt<32, 0>( - cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t* A, const int8_t* B, void* C, float* row_scale, - int lda, int ldb, int ldc, cudaStream_t stream + bnb_blasLt_handle_t ltHandle, int m, int n, int k, const int8_t* A, const int8_t* B, void* C, float* row_scale, + int lda, int ldb, int ldc, bnb_stream_t stream ); template int igemmlt<8, 0>( - cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t* A, const int8_t* B, void* C, float* row_scale, - int lda, int ldb, int ldc, cudaStream_t stream + bnb_blasLt_handle_t ltHandle, int m, int n, int k, const int8_t* A, const int8_t* B, void* C, float* row_scale, + int lda, int ldb, int ldc, bnb_stream_t stream ); template int igemmlt<8, 1>( - cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t* A, const int8_t* B, void* C, float* row_scale, - int lda, int ldb, int ldc, cudaStream_t stream + bnb_blasLt_handle_t ltHandle, int m, int n, int k, const int8_t* A, const int8_t* B, void* C, float* row_scale, + int lda, int ldb, int ldc, bnb_stream_t stream ); template void quantizeBlockwise( @@ -396,49 +486,49 @@ template void quantizeBlockwise( template void quantizeBlockwise( float* code, float* A, float* absmax, unsigned char* out, float* rand, int rand_offset, int blocksize, const int n ); -template void quantizeBlockwise<__nv_bfloat16, 1, General8bit>( - float* code, __nv_bfloat16* A, float* absmax, unsigned char* out, float* rand, int rand_offset, int blocksize, +template void quantizeBlockwise( + float* code, bnb_bfloat16* A, float* absmax, unsigned char* out, float* rand, int rand_offset, int blocksize, const int n ); -template void quantizeBlockwise<__nv_bfloat16, 0, General8bit>( - float* code, __nv_bfloat16* A, float* absmax, unsigned char* out, float* rand, int rand_offset, int blocksize, +template void quantizeBlockwise( + float* code, bnb_bfloat16* A, float* absmax, unsigned char* out, float* rand, int rand_offset, int blocksize, const int n ); -template void quantizeBlockwise<__nv_bfloat16, 0, FP4>( - float* code, __nv_bfloat16* A, float* absmax, unsigned char* out, float* rand, int rand_offset, int blocksize, +template void quantizeBlockwise( + float* code, bnb_bfloat16* A, float* absmax, unsigned char* out, float* rand, int rand_offset, int blocksize, const int n ); -template void quantizeBlockwise<__nv_bfloat16, 0, NF4>( - float* code, __nv_bfloat16* A, float* absmax, unsigned char* out, float* rand, int rand_offset, int blocksize, +template void quantizeBlockwise( + float* code, bnb_bfloat16* A, float* absmax, unsigned char* out, float* rand, int rand_offset, int blocksize, const int n ); template void dequantizeBlockwise( - float* code, unsigned char* A, float* absmax, float* out, int blocksize, const int n, cudaStream_t stream + float* code, unsigned char* A, float* absmax, float* out, int blocksize, const int n, bnb_stream_t stream ); template void dequantizeBlockwise( - float* code, unsigned char* A, float* absmax, float* out, int blocksize, const int n, cudaStream_t stream + float* code, unsigned char* A, float* absmax, float* out, int blocksize, const int n, bnb_stream_t stream ); template void dequantizeBlockwise( - float* code, unsigned char* A, float* absmax, float* out, int blocksize, const int n, cudaStream_t stream + float* code, unsigned char* A, float* absmax, float* out, int blocksize, const int n, bnb_stream_t stream ); template void dequantizeBlockwise( - float* code, unsigned char* A, float* absmax, half* out, int blocksize, const int n, cudaStream_t stream + float* code, unsigned char* A, float* absmax, half* out, int blocksize, const int n, bnb_stream_t stream ); template void dequantizeBlockwise( - float* code, unsigned char* A, float* absmax, half* out, int blocksize, const int n, cudaStream_t stream + float* code, unsigned char* A, float* absmax, half* out, int blocksize, const int n, bnb_stream_t stream ); template void dequantizeBlockwise( - float* code, unsigned char* A, float* absmax, half* out, int blocksize, const int n, cudaStream_t stream + float* code, unsigned char* A, float* absmax, half* out, int blocksize, const int n, bnb_stream_t stream ); -template void dequantizeBlockwise<__nv_bfloat16, General8bit>( - float* code, unsigned char* A, float* absmax, __nv_bfloat16* out, int blocksize, const int n, cudaStream_t stream +template void dequantizeBlockwise( + float* code, unsigned char* A, float* absmax, bnb_bfloat16* out, int blocksize, const int n, bnb_stream_t stream ); -template void dequantizeBlockwise<__nv_bfloat16, FP4>( - float* code, unsigned char* A, float* absmax, __nv_bfloat16* out, int blocksize, const int n, cudaStream_t stream +template void dequantizeBlockwise( + float* code, unsigned char* A, float* absmax, bnb_bfloat16* out, int blocksize, const int n, bnb_stream_t stream ); -template void dequantizeBlockwise<__nv_bfloat16, NF4>( - float* code, unsigned char* A, float* absmax, __nv_bfloat16* out, int blocksize, const int n, cudaStream_t stream +template void dequantizeBlockwise( + float* code, unsigned char* A, float* absmax, bnb_bfloat16* out, int blocksize, const int n, bnb_stream_t stream ); #define MAKE_optimizer32bit(name, gtype) \ @@ -449,9 +539,9 @@ template void dequantizeBlockwise<__nv_bfloat16, NF4>( const int n \ ); -MAKE_optimizer32bit(ADAM, half) MAKE_optimizer32bit(ADAM, float) MAKE_optimizer32bit(ADAM, __nv_bfloat16) MAKE_optimizer32bit(MOMENTUM, half) MAKE_optimizer32bit(MOMENTUM, float) MAKE_optimizer32bit( - MOMENTUM, __nv_bfloat16 -) MAKE_optimizer32bit(RMSPROP, half) MAKE_optimizer32bit(RMSPROP, float) MAKE_optimizer32bit(RMSPROP, __nv_bfloat16) MAKE_optimizer32bit(LION, half) MAKE_optimizer32bit(LION, float) MAKE_optimizer32bit(LION, __nv_bfloat16) MAKE_optimizer32bit(ADAGRAD, half) MAKE_optimizer32bit(ADAGRAD, float) MAKE_optimizer32bit(ADAGRAD, __nv_bfloat16) MAKE_optimizer32bit(ADEMAMIX, half) MAKE_optimizer32bit(ADEMAMIX, __nv_bfloat16) MAKE_optimizer32bit(ADEMAMIX, float) +MAKE_optimizer32bit(ADAM, half) MAKE_optimizer32bit(ADAM, float) MAKE_optimizer32bit(ADAM, bnb_bfloat16) MAKE_optimizer32bit(MOMENTUM, half) MAKE_optimizer32bit(MOMENTUM, float) MAKE_optimizer32bit(MOMENTUM, bnb_bfloat16) MAKE_optimizer32bit(RMSPROP, half) MAKE_optimizer32bit(RMSPROP, float) MAKE_optimizer32bit(RMSPROP, bnb_bfloat16) MAKE_optimizer32bit( + LION, half +) MAKE_optimizer32bit(LION, float) MAKE_optimizer32bit(LION, bnb_bfloat16) MAKE_optimizer32bit(ADAGRAD, half) MAKE_optimizer32bit(ADAGRAD, float) MAKE_optimizer32bit(ADAGRAD, bnb_bfloat16) MAKE_optimizer32bit(ADEMAMIX, half) MAKE_optimizer32bit(ADEMAMIX, bnb_bfloat16) MAKE_optimizer32bit(ADEMAMIX, float) #define MAKE_optimizerStatic8bitBlockwise(gtype, optim_name) \ template void optimizerStatic8bitBlockwise( \ @@ -462,19 +552,19 @@ MAKE_optimizer32bit(ADAM, half) MAKE_optimizer32bit(ADAM, float) MAKE_optimizer3 MAKE_optimizerStatic8bitBlockwise(half, ADAM); MAKE_optimizerStatic8bitBlockwise(float, ADAM); -MAKE_optimizerStatic8bitBlockwise(__nv_bfloat16, ADAM); +MAKE_optimizerStatic8bitBlockwise(bnb_bfloat16, ADAM); MAKE_optimizerStatic8bitBlockwise(half, MOMENTUM); MAKE_optimizerStatic8bitBlockwise(float, MOMENTUM); -MAKE_optimizerStatic8bitBlockwise(__nv_bfloat16, MOMENTUM); +MAKE_optimizerStatic8bitBlockwise(bnb_bfloat16, MOMENTUM); MAKE_optimizerStatic8bitBlockwise(half, RMSPROP); MAKE_optimizerStatic8bitBlockwise(float, RMSPROP); -MAKE_optimizerStatic8bitBlockwise(__nv_bfloat16, RMSPROP); +MAKE_optimizerStatic8bitBlockwise(bnb_bfloat16, RMSPROP); MAKE_optimizerStatic8bitBlockwise(half, LION); MAKE_optimizerStatic8bitBlockwise(float, LION); -MAKE_optimizerStatic8bitBlockwise(__nv_bfloat16, LION); +MAKE_optimizerStatic8bitBlockwise(bnb_bfloat16, LION); MAKE_optimizerStatic8bitBlockwise(half, ADAGRAD); MAKE_optimizerStatic8bitBlockwise(float, ADAGRAD); -MAKE_optimizerStatic8bitBlockwise(__nv_bfloat16, ADAGRAD); +MAKE_optimizerStatic8bitBlockwise(bnb_bfloat16, ADAGRAD); MAKE_optimizerStatic8bitBlockwise(half, ADEMAMIX); -MAKE_optimizerStatic8bitBlockwise(__nv_bfloat16, ADEMAMIX); +MAKE_optimizerStatic8bitBlockwise(bnb_bfloat16, ADEMAMIX); MAKE_optimizerStatic8bitBlockwise(float, ADEMAMIX); diff --git a/csrc/ops.cuh b/csrc/ops.cuh index 4d3af547f..c7114bcaa 100644 --- a/csrc/ops.cuh +++ b/csrc/ops.cuh @@ -8,42 +8,34 @@ #include #include +#include #include #include +#include +#include "common.cuh" +#include "compat.cuh" #include -#include -#include -#include -#include -#include -#include -#define CUDA_CHECK_RETURN(value) \ - { \ - cudaError_t _m_cudaStat = value; \ - if (_m_cudaStat != cudaSuccess) { \ - fprintf(stderr, "Error %s at line %d in file %s\n", cudaGetErrorString(_m_cudaStat), __LINE__, __FILE__); \ - exit(1); \ - } \ - } +// Error checking helpers -inline void checkCudaStatus(cudaError_t status) { - if (status != cudaSuccess) { - printf("cuda API failed with status %d: %s\n", status, cudaGetErrorString(status)); - throw std::logic_error("cuda API failed"); +inline void checkDeviceStatus(bnb_error_t status) { + if (status != BNB_SUCCESS) { + printf("Device API failed with status %d: %s\n", status, BNB_GET_ERROR_STRING(status)); + throw std::logic_error("Device API failed"); } } -inline int checkCublasStatus(cublasStatus_t status) { - if (status != CUBLAS_STATUS_SUCCESS) { - printf("cuBLAS API failed with status %d\n", status); - // throw std::logic_error("cuBLAS API failed"); +inline int checkBlasLtStatus(bnb_blas_status_t status) { + if (status != BNB_BLAS_STATUS_SUCCESS) { + printf("BLAS Lt API failed with status %d\n", status); return 1; } return 0; } +// Enums + typedef enum Operations_t { ksmul = 0, } Operations_t; @@ -55,7 +47,7 @@ typedef enum Optimizer_t { LARS = 3, ADAGRAD = 4, LION = 5, - ADEMAMIX = 6 + ADEMAMIX = 6, } Optimizer_t; typedef enum Funcs_t { @@ -64,8 +56,19 @@ typedef enum Funcs_t { _MUL = 2, } Funcs_t; +// Context classes + class Context { public: +#if BNB_HIP + rocblas_handle m_handle; + + Context() { + rocblas_handle handle; + rocblas_create_handle(&handle); + m_handle = handle; + } +#else cublasHandle_t m_handle; Context() { @@ -73,26 +76,29 @@ class Context { cublasCreate_v2(&handle); m_handle = handle; } +#endif }; class ContextLt { public: - cublasLtHandle_t m_handle; + bnb_blasLt_handle_t m_handle; ContextLt() { - cublasLtHandle_t handle; - cublasLtCreate(&handle); + bnb_blasLt_handle_t handle; + bnb_blasLtCreate(&handle); m_handle = handle; } }; +// Function declarations + template void quantizeBlockwise( float* code, T* A, float* absmax, unsigned char* out, float* rand, int rand_offset, int blocksize, const int n ); template void dequantizeBlockwise( - float* code, unsigned char* A, float* absmax, T* out, int block_size, const int n, cudaStream_t stream + float* code, unsigned char* A, float* absmax, T* out, int block_size, const int n, bnb_stream_t stream ); template @@ -120,24 +126,24 @@ void strided_gemmex( template int igemmlt( - cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t* A, const int8_t* B, void* C, float* row_scale, - int lda, int ldb, int ldc, cudaStream_t stream + bnb_blasLt_handle_t ltHandle, int m, int n, int k, const int8_t* A, const int8_t* B, void* C, float* row_scale, + int lda, int ldb, int ldc, bnb_stream_t stream ); void cutlass_igemm( bool transposeA, bool transposeB, int m, int n, int k, void* A, void* B, void* C, int lda, int ldb, int ldc ); void dequant_mm_int32_fp16( - int* A, float* rowStats, float* colStats, half* out, half* bias, int numRows, int numCols, cudaStream_t stream + int* A, float* rowStats, float* colStats, half* out, half* bias, int numRows, int numCols, bnb_stream_t stream ); void int8VectorQuant( - half* __restrict__ A, int8_t* out, float* rowStats, float threshold, int rows, int cols, cudaStream_t stream + half* __restrict__ A, int8_t* out, float* rowStats, float threshold, int rows, int cols, bnb_stream_t stream ); template void gemm_4bit_inference_naive( int m, int n, int k, T* A, unsigned char* B, float* absmax, float* datatype, T* out, int lda, int ldb, int ldc, - int blocksize, cudaStream_t stream + int blocksize, bnb_stream_t stream ); template void func(T* A, T* B, T value, long n); diff --git a/csrc/ops.hip b/csrc/ops.hip deleted file mode 100644 index 937f8f249..000000000 --- a/csrc/ops.hip +++ /dev/null @@ -1,569 +0,0 @@ -// !!! This is a file automatically generated by hipify!!! -#include "hip/hip_runtime.h" -// Copyright (c) Facebook, Inc. and its affiliates. -// -// This source code is licensed under the MIT license found in the -// LICENSE file in the root directory of this source tree. - -#include -#include -#include -#include -#include -#ifndef NO_HIPBLASLT -#include -#endif -#include -#include -#include - -#define ERR_NOT_IMPLEMENTED 100 - -using std::cout; -using std::endl; - -template void quantizeBlockwise(float * code, T *A, float *absmax, unsigned char *out, float *rand, int rand_offset, int blocksize, const int n) -{ - int num_blocks = n/blocksize; - num_blocks = n % blocksize == 0 ? num_blocks : num_blocks + 1; - - if(blocksize == 4096) - hipLaunchKernelGGL(( kQuantizeBlockwise), dim3(num_blocks), dim3(1024), 0, 0, code, A, absmax, out, rand, rand_offset, n); - else if(blocksize == 2048) - hipLaunchKernelGGL(( kQuantizeBlockwise), dim3(num_blocks), dim3(512), 0, 0, code, A, absmax, out, rand, rand_offset, n); - else if(blocksize == 1024) - hipLaunchKernelGGL(( kQuantizeBlockwise), dim3(num_blocks), dim3(256), 0, 0, code, A, absmax, out, rand, rand_offset, n); - else if(blocksize == 512) - hipLaunchKernelGGL(( kQuantizeBlockwise), dim3(num_blocks), dim3(256), 0, 0, code, A, absmax, out, rand, rand_offset, n); - else if(blocksize == 256) - hipLaunchKernelGGL(( kQuantizeBlockwise), dim3(num_blocks), dim3(128), 0, 0, code, A, absmax, out, rand, rand_offset, n); - else if(blocksize == 128) - hipLaunchKernelGGL(( kQuantizeBlockwise), dim3(num_blocks), dim3(64), 0, 0, code, A, absmax, out, rand, rand_offset, n); - else if(blocksize == 64) { - // For 4-bit (FP4/NF4): use specialized kernel that processes 2 blocks of 64 per thread block - // Works on all warp sizes (32 and 64) by using logical warps of 32 - if constexpr(DATA_TYPE > 0) - hipLaunchKernelGGL(( kQuantizeBlockwise64), dim3((num_blocks + 1) / 2), dim3(64), 0, 0, code, A, absmax, out, rand, rand_offset, n); - else - hipLaunchKernelGGL(( kQuantizeBlockwise), dim3(num_blocks), dim3(32), 0, 0, code, A, absmax, out, rand, rand_offset, n); - } - - CUDA_CHECK_RETURN(hipPeekAtLastError()); -} - -template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, T *out, int blocksize, const int n, hipStream_t stream) -{ - int tile_size = (DATA_TYPE > 0) ? 1024 : 512; - - // Upcast to int64 to avoid overflow for large n - int grid_blocks = ((int64_t)n + tile_size - 1) / tile_size; - - if(DATA_TYPE > 0) - hipLaunchKernelGGL(( kDequantizeBlockwise), dim3(grid_blocks), dim3(64), 0, stream, code, A, absmax, out, blocksize / 2, n); - else - hipLaunchKernelGGL(( kDequantizeBlockwise), dim3(grid_blocks), dim3(64), 0, stream, code, A, absmax, out, blocksize, n); - - CUDA_CHECK_RETURN(hipPeekAtLastError()); -} - - - -template void optimizer32bit(T* g, T* p, - float* state1, float* state2, float *unorm, float max_unorm, float param_norm, - const float beta1, const float beta2, const float beta3, const float alpha, - const float eps, const float weight_decay, - const int step, const float lr, const float gnorm_scale, bool skip_zeros, const int n) -{ - int num_blocks = n/4096; - num_blocks = n % 4096 == 0 ? num_blocks : num_blocks + 1; - switch(OPTIMIZER) - { - case ADAM: - case ADEMAMIX: - if(max_unorm > 0.0f) - { - CUDA_CHECK_RETURN(hipMemset(unorm, 0, 1*sizeof(float))); - hipLaunchKernelGGL(( kPreconditionOptimizer32bit2State), dim3(num_blocks), dim3(512), 0, 0, g, p, state1, state2, unorm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n); - CUDA_CHECK_RETURN(hipPeekAtLastError()); - } - hipLaunchKernelGGL(( kOptimizer32bit2State), dim3(num_blocks), dim3(1024), 0, 0, g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, beta3, alpha, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n); - CUDA_CHECK_RETURN(hipPeekAtLastError()); - break; - case MOMENTUM: - case RMSPROP: - case ADAGRAD: - if(max_unorm > 0.0f) - { - CUDA_CHECK_RETURN(hipMemset(unorm, 0, 1*sizeof(float))); - hipLaunchKernelGGL(( kPreconditionOptimizer32bit1State), dim3(num_blocks), dim3(512), 0, 0, g, p, state1, unorm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n); - CUDA_CHECK_RETURN(hipPeekAtLastError()); - } - - hipLaunchKernelGGL(( kOptimizer32bit1State), dim3(num_blocks), dim3(1024), 0, 0, g, p, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n); - CUDA_CHECK_RETURN(hipPeekAtLastError()); - break; - case LION: - // in lion, the momentum update after the parameter update - hipLaunchKernelGGL(( kOptimizer32bit1State), dim3(num_blocks), dim3(1024), 0, 0, g, p, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n); - CUDA_CHECK_RETURN(hipPeekAtLastError()); - - if(max_unorm > 0.0f) - { - CUDA_CHECK_RETURN(hipMemset(unorm, 0, 1*sizeof(float))); - hipLaunchKernelGGL(( kPreconditionOptimizer32bit1State), dim3(num_blocks), dim3(512), 0, 0, g, p, state1, unorm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n); - CUDA_CHECK_RETURN(hipPeekAtLastError()); - } - break; - } -} - -#define BLOCKSIZE_2STATE 256 -#define NUM_2STATE 1 -#define BLOCKSIZE_1STATE 256 -#define NUM_1STATE 1 - -template void optimizerStatic8bitBlockwise( - T* p, - T* g, - unsigned char* state1, - unsigned char* state2, - float beta1, - float beta2, - float beta3, - float alpha, - float eps, - int step, - float lr, - float* quantiles1, - float* quantiles2, - float* absmax1, - float* absmax2, - float weight_decay, - const float gnorm_scale, - bool skip_zeros, - int n -) { - - int num_blocks = 0; - switch(OPTIMIZER) - { - case ADAM: - case ADEMAMIX: - num_blocks = n/BLOCKSIZE_2STATE; - num_blocks = n % BLOCKSIZE_2STATE == 0 ? num_blocks : num_blocks + 1; - hipLaunchKernelGGL(( kOptimizerStatic8bit2StateBlockwise), dim3(num_blocks), dim3(BLOCKSIZE_2STATE/NUM_2STATE), 0, 0, p, g, state1, state2, beta1, beta2, beta3, alpha, eps, step, lr, - quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, skip_zeros, n); - CUDA_CHECK_RETURN(hipPeekAtLastError()); - break; - case MOMENTUM: - case RMSPROP: - case ADAGRAD: - case LION: - num_blocks = n/BLOCKSIZE_1STATE; - num_blocks = n % BLOCKSIZE_1STATE == 0 ? num_blocks : num_blocks + 1; - hipLaunchKernelGGL(( kOptimizerStatic8bit1StateBlockwise), dim3(num_blocks), dim3(BLOCKSIZE_1STATE/NUM_1STATE), 0, 0, p, g, state1, beta1, beta2, eps, step, lr, - quantiles1, absmax1, weight_decay, gnorm_scale, skip_zeros, n); - CUDA_CHECK_RETURN(hipPeekAtLastError()); - break; - } -} - -void gemmex(Context *context, bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc) -{ - const int falpha = 1; - const int fbeta = 0; - const void * alpha = &falpha; - const void * beta = &fbeta; - hipblasStatus_t status; - -#if hipblasVersionMajor >= 3 - status = hipblasGemmEx(context->m_handle, - transposeA ? HIPBLAS_OP_T : HIPBLAS_OP_N, - transposeB ? HIPBLAS_OP_T : HIPBLAS_OP_N, - m, n, k, - alpha, A, HIP_R_8I, lda, B, HIP_R_8I, ldb, beta, - C, HIP_R_32I, ldc, - HIPBLAS_COMPUTE_32I, HIPBLAS_GEMM_DEFAULT); -#else - status = hipblasGemmEx(context->m_handle, - transposeA ? HIPBLAS_OP_T : HIPBLAS_OP_N, - transposeB ? HIPBLAS_OP_T : HIPBLAS_OP_N, - m, n, k, - alpha, A, HIPBLAS_R_8I, lda, B, HIPBLAS_R_8I, ldb, beta, - C, HIPBLAS_R_32I, ldc, - HIPBLAS_R_32I, HIPBLAS_GEMM_DEFAULT); -#endif - - if (status != HIPBLAS_STATUS_SUCCESS) - { - std::cout << "HIPBLAS ERROR: Status " << status << std::endl; - } - -} - -void strided_gemmex(Context *context, bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc, - long long int strideA, long long int strideB, long long int strideC, int batchCount) -{ - const int falpha = 1; - const int fbeta = 0; - const void * alpha = &falpha; - const void * beta = &fbeta; - hipblasStatus_t status; - - //cout << transposeA << transposeB << endl; - //printf("%i %i %i\n", m,n,k); - //printf("%i %i %i\n", lda,ldb,ldc); - //printf("%i %i %i\n", strideA, strideB, strideC); - //printf("%i\n", batchCount); - -#if hipblasVersionMajor >= 3 - status = hipblasGemmStridedBatchedEx(context->m_handle, - transposeA ? HIPBLAS_OP_T : HIPBLAS_OP_N, - transposeB ? HIPBLAS_OP_T : HIPBLAS_OP_N, - m, n, k, - alpha, A, HIP_R_8I, lda, (long long int)strideA, B, HIP_R_8I, ldb, (long long int)strideB, beta, - C, HIP_R_32I, ldc, (long long int)strideC, batchCount, - HIPBLAS_COMPUTE_32I, HIPBLAS_GEMM_DEFAULT); -#else - status = hipblasGemmStridedBatchedEx(context->m_handle, - transposeA ? HIPBLAS_OP_T : HIPBLAS_OP_N, - transposeB ? HIPBLAS_OP_T : HIPBLAS_OP_N, - m, n, k, - alpha, A, HIPBLAS_R_8I, lda, (long long int)strideA, B, HIPBLAS_R_8I, ldb, (long long int)strideB, beta, - C, HIPBLAS_R_32I, ldc, (long long int)strideC, batchCount, - HIPBLAS_R_32I, HIPBLAS_GEMM_DEFAULT); -#endif - - if (status != HIPBLAS_STATUS_SUCCESS) - { - std::cout << "HIPBLAS ERROR: Status " << status << std::endl; - } - -} - -int roundoff(int v, int d) { - return (v + d - 1) / d * d; -} - -static std::string hipError_to_string(const hipError_t ret) -{ - switch(ret) - { - case hipSuccess: - return "hipSuccess"; - case hipErrorInvalidContext: - return "hipErrorInvalidContext"; - case hipErrorInvalidKernelFile: - return "hipErrorInvalidKernelFile"; - case hipErrorMemoryAllocation: - return "hipErrorMemoryAllocation"; - case hipErrorInitializationError: - return "hipErrorInitializationError"; - case hipErrorLaunchFailure: - return "hipErrorLaunchFailure"; - case hipErrorLaunchOutOfResources: - return "hipErrorLaunchOutOfResources"; - case hipErrorInvalidDevice: - return "hipErrorInvalidDevice"; - case hipErrorInvalidValue: - return "hipErrorInvalidValue"; - case hipErrorInvalidDevicePointer: - return "hipErrorInvalidDevicePointer"; - case hipErrorInvalidMemcpyDirection: - return "hipErrorInvalidMemcpyDirection"; - case hipErrorUnknown: - return "hipErrorUnknown"; - case hipErrorInvalidResourceHandle: - return "hipErrorInvalidResourceHandle"; - case hipErrorNotReady: - return "hipErrorNotReady"; - case hipErrorNoDevice: - return "hipErrorNoDevice"; - case hipErrorPeerAccessAlreadyEnabled: - return "hipErrorPeerAccessAlreadyEnabled"; - case hipErrorPeerAccessNotEnabled: - return "hipErrorPeerAccessNotEnabled"; - case hipErrorRuntimeMemory: - return "hipErrorRuntimeMemory"; - case hipErrorRuntimeOther: - return "hipErrorRuntimeOther"; - case hipErrorHostMemoryAlreadyRegistered: - return "hipErrorHostMemoryAlreadyRegistered"; - case hipErrorHostMemoryNotRegistered: - return "hipErrorHostMemoryNotRegistered"; - case hipErrorMapBufferObjectFailed: - return "hipErrorMapBufferObjectFailed"; - case hipErrorTbd: - return "hipErrorTbd"; - default: - throw std::runtime_error("unknown hipError"); - } -} - -template int igemmlt( - hipblasLtHandle_t ltHandle, - int m, int n, int k, - const int8_t *A, - const int8_t *B, - void *C, - float *row_scale, - int lda, int ldb, int ldc, - hipStream_t stream -) { -#ifdef NO_HIPBLASLT - return ERR_NOT_IMPLEMENTED; -#else - - // Calculate C = A^T @ B, in col-major layout. - // - // Use the IMMA kernels requires: - // * A must be transposed and B must be non-transposed. - // * Dimensions m and k must be multiples of 4. - // * All pointers must be 4-byte aligned; 16-byte alignment preferred. - - int has_error = 0; - const int64_t max_workspace_size = 0;//set to 0 to avoid choosing GSU kernel - - hipblasLtMatmulDesc_t matmulDesc; - hipblasLtMatrixLayout_t aDesc, bDesc, cDesc; - hipblasOperation_t opT = HIPBLAS_OP_T; - - hipDataType outType = DTYPE_OUT == 32 ? HIP_R_32I : HIP_R_8I; - hipDataType scaleType = DTYPE_OUT == 32 ? HIP_R_32I : HIP_R_32F; - - hipblasLtPointerMode_t pointerMode = HIPBLASLT_POINTER_MODE_ALPHA_DEVICE_VECTOR_BETA_HOST; - - has_error |= checkHipblasStatus(hipblasLtMatrixLayoutCreate(&aDesc, HIP_R_8I, m, k, lda)); - has_error |= checkHipblasStatus(hipblasLtMatrixLayoutCreate(&bDesc, HIP_R_8I, m, n, ldb)); - has_error |= checkHipblasStatus(hipblasLtMatrixLayoutCreate(&cDesc, outType, k, n, ldc)); - - // Default layout order is col major - - has_error |= checkHipblasStatus(hipblasLtMatmulDescCreate(&matmulDesc, HIPBLAS_COMPUTE_32I, scaleType)); - has_error |= checkHipblasStatus(hipblasLtMatmulDescSetAttribute(matmulDesc, HIPBLASLT_MATMUL_DESC_TRANSA, &opT, sizeof(opT))); - - if (DTYPE_OUT == 32) { - - /* Algo and workspace TODO: need to rework to not be duplicated */ - // Set User Preference attributes - hipblasLtMatmulPreference_t pref; - checkHipblasStatus(hipblasLtMatmulPreferenceCreate(&pref)); - checkHipblasStatus( - hipblasLtMatmulPreferenceSetAttribute(pref, - HIPBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, - &max_workspace_size, - sizeof(max_workspace_size))); - - const int request_solutions = 1; - hipblasLtMatmulHeuristicResult_t heuristicResult[request_solutions]; - int returnedAlgoCount = 0; - checkHipblasStatus(hipblasLtMatmulAlgoGetHeuristic(ltHandle, - matmulDesc, - aDesc, - bDesc, - cDesc, - cDesc, - pref, - request_solutions, - heuristicResult, - &returnedAlgoCount)); - - if (returnedAlgoCount == 0) - { - has_error = 1; - fprintf(stderr, "Error: Matmul Algo Heuristic didn't return algorithms\n"); - } else { - int alpha = 1, beta = 0; - has_error |= checkHipblasStatus(hipblasLtMatmul( - ltHandle, matmulDesc, - &alpha, A, aDesc, - B, bDesc, &beta, - (int32_t*)C, cDesc, - (int32_t*)C, cDesc, - &heuristicResult[0].algo, NULL, 0, stream - )); - } - } else { - // This path is unlikely to be used, as 8-bit accumulation can lead to likely overflows. - - if (!SCALE_ROWS) { - float alpha = 1.0f, beta = 0.0f; - has_error |= checkHipblasStatus(hipblasLtMatmul( - ltHandle, matmulDesc, - &alpha, A, aDesc, - B, bDesc, &beta, - (int8_t*)C, cDesc, - (int8_t*)C, cDesc, - NULL, NULL, 0, stream - )); - } else { - hipblasLtPointerMode_t alphaVec = HIPBLASLT_POINTER_MODE_ALPHA_DEVICE_VECTOR_BETA_HOST; - float beta = 0.0f; - has_error |= checkHipblasStatus(hipblasLtMatmulDescSetAttribute( - matmulDesc, - HIPBLASLT_MATMUL_DESC_POINTER_MODE, - &pointerMode, - sizeof(alphaVec) - )); - has_error |= checkHipblasStatus(hipblasLtMatmul( - ltHandle, matmulDesc, - row_scale, A, aDesc, - B, bDesc, &beta, - (int8_t*)C, cDesc, - (int8_t*)C, cDesc, - NULL, NULL, 0, stream - )); - } - } - - has_error |= checkHipblasStatus(hipblasLtMatrixLayoutDestroy(cDesc)); - has_error |= checkHipblasStatus(hipblasLtMatrixLayoutDestroy(bDesc)); - has_error |= checkHipblasStatus(hipblasLtMatrixLayoutDestroy(aDesc)); - has_error |= checkHipblasStatus(hipblasLtMatmulDescDestroy(matmulDesc)); - - if(has_error == 1) - printf("error detected"); - - return has_error; -#endif // NO_HIPBLASLT -} - -int fill_up_to_nearest_multiple(int value, int multiple) -{ - return value + (value % multiple == 0 ? 0 : (multiple - (value % multiple))); -} - -void dequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out, half *bias, int numRows, int numCols, hipStream_t stream) -{ - const int threads = 512; - const int num_per_thread = 4; - const int num_per_block = threads * num_per_thread; - const int n = numRows*numCols; - const int num_blocks = (n + num_per_block - 1) / num_per_block; - - hipLaunchKernelGGL(( kdequant_mm_int32_fp16), dim3(num_blocks), dim3(threads), 0, stream, A, rowStats, colStats, out, bias, numRows, numCols, n); - CUDA_CHECK_RETURN(hipPeekAtLastError()); -} - -void int8VectorQuant(half * __restrict__ A, int8_t *out, float *rowStats, float threshold, int rows, int cols, hipStream_t stream) { if (threshold == 0.0) { - kInt8VectorQuant<<>>(A, out, rowStats, threshold, rows, cols); - } else { - kInt8VectorQuant<<>>(A, out, rowStats, threshold, rows, cols); - } - CUDA_CHECK_RETURN(hipPeekAtLastError()); -} - - -template void gemm_4bit_inference_naive(int m, int n, int k, T * A, unsigned char* B, float *absmax, float *datatype, T * out, int lda, int ldb, int ldc, int blocksize, hipStream_t stream) -{ - - //warpsize - 32 - int num_blocks = (m+3)/4; - //warpsize - 64 - if (BNB_WARP_SIZE == 64) { - num_blocks = (m+1)/2; - } - - hipLaunchKernelGGL(( kgemm_4bit_inference_naive), dim3(num_blocks), dim3(128), 0, stream, m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize); - CUDA_CHECK_RETURN(hipPeekAtLastError()); -} - -template void func(T *A, T *B, T value, long n) -{ - int threads = 512; - int blocks = n/threads; - blocks = n % threads == 0 ? blocks : blocks + 1; - blocks = blocks > 65535 ? 65535 : blocks; - hipLaunchKernelGGL(( kfunc), dim3(blocks), dim3(512), 0, 0, A, B, value, n); - CUDA_CHECK_RETURN(hipPeekAtLastError()); -} - -//============================================================== -// TEMPLATE DEFINITIONS -//============================================================== - -template void func(float *A, float *B, float value, long n); -template void func(unsigned char *A, unsigned char *B, unsigned char value, long n); -template void func(float *A, float *B, float value, long n); -template void func(float *A, float *B, float value, long n); - -template void gemm_4bit_inference_naive(int m, int n, int k, half * A, unsigned char* B, float *absmax, float *datatype, half * out, int lda, int ldb, int ldc, int blocksize, hipStream_t stream); -template void gemm_4bit_inference_naive(int m, int n, int k, hip_bfloat16 * A, unsigned char* B, float *absmax, float *datatype, hip_bfloat16 * out, int lda, int ldb, int ldc, int blocksize, hipStream_t stream); -template void gemm_4bit_inference_naive(int m, int n, int k, float * A, unsigned char* B, float *absmax, float *datatype, float * out, int lda, int ldb, int ldc, int blocksize, hipStream_t stream); - - -template int igemmlt<32, 0>(hipblasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc, hipStream_t stream); -template int igemmlt<8, 0>(hipblasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc, hipStream_t stream); -template int igemmlt<8, 1>(hipblasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc, hipStream_t stream); - -template void quantizeBlockwise(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); -template void quantizeBlockwise(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); -template void quantizeBlockwise(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); -template void quantizeBlockwise(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); -template void quantizeBlockwise(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); -template void quantizeBlockwise(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); -template void quantizeBlockwise(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); -template void quantizeBlockwise(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); -template void quantizeBlockwise(float * code, hip_bfloat16 *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); -template void quantizeBlockwise(float * code, hip_bfloat16 *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); -template void quantizeBlockwise(float * code, hip_bfloat16 *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); -template void quantizeBlockwise(float * code, hip_bfloat16 *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); - -template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n, hipStream_t stream); -template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n, hipStream_t stream); -template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n, hipStream_t stream); -template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n, hipStream_t stream); -template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n, hipStream_t stream); -template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n, hipStream_t stream); -template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, hip_bfloat16 *out, int blocksize, const int n, hipStream_t stream); -template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, hip_bfloat16 *out, int blocksize, const int n, hipStream_t stream); -template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, hip_bfloat16 *out, int blocksize, const int n, hipStream_t stream); - -#define MAKE_optimizer32bit(name, gtype) \ -template void optimizer32bit(gtype* g, gtype* p, \ - float* state1, float* state2, float* unorm, float max_unorm, float param_norm, \ - const float beta1, const float beta2, const float beta3, const float alpha, const float eps, const float weight_decay, \ - const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); - -MAKE_optimizer32bit(ADAM, half) -MAKE_optimizer32bit(ADAM, float) -MAKE_optimizer32bit(ADAM, hip_bfloat16) -MAKE_optimizer32bit(MOMENTUM, half) -MAKE_optimizer32bit(MOMENTUM, float) -MAKE_optimizer32bit(MOMENTUM, hip_bfloat16) -MAKE_optimizer32bit(RMSPROP, half) -MAKE_optimizer32bit(RMSPROP, float) -MAKE_optimizer32bit(RMSPROP, hip_bfloat16) -MAKE_optimizer32bit(LION, half) -MAKE_optimizer32bit(LION, float) -MAKE_optimizer32bit(LION, hip_bfloat16) -MAKE_optimizer32bit(ADAGRAD, half) -MAKE_optimizer32bit(ADAGRAD, float) -MAKE_optimizer32bit(ADAGRAD, hip_bfloat16) -MAKE_optimizer32bit(ADEMAMIX, half) -MAKE_optimizer32bit(ADEMAMIX, hip_bfloat16) -MAKE_optimizer32bit(ADEMAMIX, float) - -#define MAKE_optimizerStatic8bitBlockwise(gtype, optim_name) \ -template void optimizerStatic8bitBlockwise(gtype* p, gtype* g, \ - unsigned char* state1, unsigned char* state2, float beta1, float beta2, float beta3, float alpha, float eps, int step, float lr, \ - float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, bool skip_zeros, int n); \ - -MAKE_optimizerStatic8bitBlockwise(half, ADAM); -MAKE_optimizerStatic8bitBlockwise(float, ADAM); -MAKE_optimizerStatic8bitBlockwise(hip_bfloat16, ADAM); -MAKE_optimizerStatic8bitBlockwise(half, MOMENTUM); -MAKE_optimizerStatic8bitBlockwise(float, MOMENTUM); -MAKE_optimizerStatic8bitBlockwise(hip_bfloat16, MOMENTUM); -MAKE_optimizerStatic8bitBlockwise(half, RMSPROP); -MAKE_optimizerStatic8bitBlockwise(float, RMSPROP); -MAKE_optimizerStatic8bitBlockwise(hip_bfloat16, RMSPROP); -MAKE_optimizerStatic8bitBlockwise(half, LION); -MAKE_optimizerStatic8bitBlockwise(float, LION); -MAKE_optimizerStatic8bitBlockwise(hip_bfloat16, LION); -MAKE_optimizerStatic8bitBlockwise(half, ADAGRAD); -MAKE_optimizerStatic8bitBlockwise(float, ADAGRAD); -MAKE_optimizerStatic8bitBlockwise(hip_bfloat16, ADAGRAD); -MAKE_optimizerStatic8bitBlockwise(half, ADEMAMIX); -MAKE_optimizerStatic8bitBlockwise(hip_bfloat16, ADEMAMIX); -MAKE_optimizerStatic8bitBlockwise(float, ADEMAMIX); diff --git a/csrc/ops_hip.cuh b/csrc/ops_hip.cuh deleted file mode 100644 index 6e884df00..000000000 --- a/csrc/ops_hip.cuh +++ /dev/null @@ -1,154 +0,0 @@ -// !!! This is a file automatically generated by hipify!!! -// Copyright (c) Facebook, Inc. and its affiliates. -// -// This source code is licensed under the MIT license found in the -// LICENSE file in the root directory of this source tree. - -#ifndef ops_H -#define ops_H - -#include -#include -#include -#include - -#ifdef _WIN32 -#include -#include -#include -#else -#include -#endif - -#include -#include -#include -#include -#include -#include -#include - -#define CUDA_CHECK_RETURN(value) \ - { \ - hipError_t _m_cudaStat = value; \ - if (_m_cudaStat != hipSuccess) { \ - fprintf(stderr, "Error %s at line %d in file %s\n", hipGetErrorString(_m_cudaStat), __LINE__, __FILE__); \ - exit(1); \ - } \ - } - -inline void checkHipStatus(hipError_t status) { - if (status != hipSuccess) { - printf("hip API failed with status %d: %s\n", status, hipGetErrorString(status)); - throw std::logic_error("hip API failed"); - } -} - -inline int checkHipblasStatus(hipblasStatus_t status) { - if (status != HIPBLAS_STATUS_SUCCESS) { - printf("hipBLAS API failed with status %d\n", status); - // throw std::logic_error("cuBLAS API failed"); - return 1; - } - return 0; -} - -typedef enum Operations_t { - ksmul = 0, -} Operations_t; - -typedef enum Optimizer_t { - ADAM = 0, - MOMENTUM = 1, - RMSPROP = 2, - LARS = 3, - ADAGRAD = 4, - LION = 5, - ADEMAMIX = 6, -} Optimizer_t; - -typedef enum Funcs_t { - FILL = 0, - ARANGE = 1, - _MUL = 2, -} Funcs_t; - -class Context { - public: - rocblas_handle m_handle; - - Context() { - rocblas_handle handle; - rocblas_create_handle(&handle); - m_handle = handle; - } -}; - -class ContextLt { - public: - hipblasLtHandle_t m_handle; - - ContextLt() { - hipblasLtHandle_t handle; - hipblasLtCreate(&handle); - m_handle = handle; - } -}; - -template -void quantizeBlockwise( - float* code, T* A, float* absmax, unsigned char* out, float* rand, int rand_offset, int blocksize, const int n -); -template -void dequantizeBlockwise( - float* code, unsigned char* A, float* absmax, T* out, int block_size, const int n, hipStream_t stream -); - -template -void optimizer32bit( - T* g, T* p, float* state1, float* state2, float* unorm, float max_unorm, float param_norm, float beta1, float beta2, - float beta3, float alpha, float eps, float weight_decay, int step, float lr, const float gnorm_scale, - bool skip_zeros, int n -); - -template -void optimizerStatic8bitBlockwise( - T* p, T* g, unsigned char* state1, unsigned char* state2, float beta1, float beta2, float beta3, float alpha, - float eps, int step, float lr, float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, - float weight_decay, const float gnorm_scale, bool skip_zeros, int n -); - -void gemmex( - Context* context, bool transposeA, bool transposeB, int m, int n, int k, void* A, void* B, void* C, int lda, - int ldb, int ldc -); -void strided_gemmex( - Context* context, bool transposeA, bool transposeB, int m, int n, int k, void* A, void* B, void* C, int lda, - int ldb, int ldc, long long int strideA, long long int strideB, long long int strideC, int batchCount -); - -template -int igemmlt( - hipblasLtHandle_t ltHandle, int m, int n, int k, const int8_t* A, const int8_t* B, void* C, float* row_scale, - int lda, int ldb, int ldc, hipStream_t stream -); - -void cutlass_igemm( - bool transposeA, bool transposeB, int m, int n, int k, void* A, void* B, void* C, int lda, int ldb, int ldc -); -void dequant_mm_int32_fp16( - int* A, float* rowStats, float* colStats, half* out, half* bias, int numRows, int numCols, hipStream_t stream -); -void int8VectorQuant( - half* __restrict__ A, int8_t* out, float* rowStats, float threshold, int rows, int cols, hipStream_t stream -); - -template -void gemm_4bit_inference_naive( - int m, int n, int k, T* A, unsigned char* B, float* absmax, float* datatype, T* out, int lda, int ldb, int ldc, - int blocksize, hipStream_t stream -); - -template void func(T* A, T* B, T value, long n); - -#endif diff --git a/csrc/pythonInterface.cpp b/csrc/pythonInterface.cpp index aee7a4d25..7493574f0 100644 --- a/csrc/pythonInterface.cpp +++ b/csrc/pythonInterface.cpp @@ -8,7 +8,7 @@ #include #endif #if BUILD_HIP -#include +#include #endif #if BUILD_MPS // #include