From 060ae2e71ae214b1798eb2c671d0956ac7b36cee Mon Sep 17 00:00:00 2001 From: Benson Ma Date: Tue, 29 Apr 2025 18:43:18 -0700 Subject: [PATCH] Fix `int32_t` to `auto` for code around `WeightRow` Summary: - Fix `int32_t` to `auto` for code around `WeightRow` - Fix `kINT8QparamsBytes` from `float` to `int32_t` Reviewed By: spcyppt, sryap Differential Revision: D73690651 --- .../embedding_forward_split_meta_template.cpp | 2 +- .../include/fbgemm_gpu/utils/cuda_prelude.cuh | 9 ++++----- .../include/fbgemm_gpu/utils/weight_row.cuh | 20 +++++++++---------- .../lfu_cache_populate.cu | 14 ++++++------- .../lru_cache_populate.cu | 14 ++++++------- .../src/split_embeddings_cache/lxu_cache.cu | 14 ++++++------- 6 files changed, 36 insertions(+), 37 deletions(-) diff --git a/fbgemm_gpu/codegen/training/forward/embedding_forward_split_meta_template.cpp b/fbgemm_gpu/codegen/training/forward/embedding_forward_split_meta_template.cpp index a6668bc606..d8cb839536 100644 --- a/fbgemm_gpu/codegen/training/forward/embedding_forward_split_meta_template.cpp +++ b/fbgemm_gpu/codegen/training/forward/embedding_forward_split_meta_template.cpp @@ -33,7 +33,7 @@ using namespace fbgemm_gpu; using Tensor = at::Tensor; -[[maybe_unused]] static constexpr float kINT8QparamsBytes = 8; +[[maybe_unused]] static constexpr int32_t kINT8QparamsBytes = 8; //////////////////////////////////////////////////////////////////////////////// // Kernel Definitions diff --git a/fbgemm_gpu/include/fbgemm_gpu/utils/cuda_prelude.cuh b/fbgemm_gpu/include/fbgemm_gpu/utils/cuda_prelude.cuh index 9b33c895c6..bf5a5fdacc 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/utils/cuda_prelude.cuh +++ b/fbgemm_gpu/include/fbgemm_gpu/utils/cuda_prelude.cuh @@ -81,11 +81,10 @@ static constexpr uint32_t kFullWarpMask = 0xff'ff'ff'ff; static constexpr float kQParamEps = 1e-8f; -/* For rowwise int8 quantization, two quantization parameters (qparams) -will be stored at the end of each row in FP32 formats, appending a total of -8 bytes to each row. -*/ -static constexpr float kINT8QparamsBytes = 8; +// For rowwise int8 quantization, two quantization parameters (qparams) will be +// stored at the end of each row in FP32 formats, appending a total of 8 bytes +// to each row. 
+static constexpr int32_t kINT8QparamsBytes = 8; template DEVICE_INLINE T shfl_xor( diff --git a/fbgemm_gpu/include/fbgemm_gpu/utils/weight_row.cuh b/fbgemm_gpu/include/fbgemm_gpu/utils/weight_row.cuh index ab948e0284..888e1b81b3 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/utils/weight_row.cuh +++ b/fbgemm_gpu/include/fbgemm_gpu/utils/weight_row.cuh @@ -214,12 +214,12 @@ struct WeightRow { DEVICE_INLINE void warp_copy_to_cache( cache_t* dst_row, - const int32_t dim_length, - const int32_t num_lanes, - const int32_t lane_id) { + const uint32_t dim_length, + const uint32_t num_lanes, + const uint32_t lane_id) { if constexpr (std::is_same_v) { // No conversion required when emb_t and cache_t are the same type - for (int32_t d = lane_id * 4; d < dim_length; d += num_lanes * 4) { + for (auto d = lane_id * 4; d < dim_length; d += num_lanes * 4) { same_type_vector_copy( dst_row + d, reinterpret_cast(row_ + d)); } @@ -229,7 +229,7 @@ struct WeightRow { // Copy over for each warp-sized slice of Vec4's // Does 2-step conversion: weight_t -> FP32 -> cache_t - for (int32_t d = lane_id * 4; d < dim_length; d += num_lanes * 4) { + for (auto d = lane_id * 4; d < dim_length; d += num_lanes * 4) { const auto slice = load(d, qparams); quantize_store(dst_row + d, slice, stoc_rounding_state_ptr_, qparams); } @@ -237,9 +237,9 @@ struct WeightRow { } DEVICE_INLINE void warp_evict_cache( - const int32_t dim_length, - const int32_t num_lanes, - const int32_t lane_id) { + const uint32_t dim_length, + const uint32_t num_lanes, + const uint32_t lane_id) { float2 qparams; if constexpr (std::is_same_v) { @@ -248,7 +248,7 @@ struct WeightRow { std::numeric_limits>::lowest(); // Compute the qparams from the cache row (not embedding row) weights - for (int32_t d = lane_id; d * 4 < dim_length; d += num_lanes) { + for (auto d = lane_id; d * 4 < dim_length; d += num_lanes) { const auto cache_slice = load(d * 4, qparams); // qparams not used local_max = max(local_max, cache_slice.vmax()); local_min = min(local_min, cache_slice.vmin()); @@ -263,7 +263,7 @@ struct WeightRow { } } - for (int32_t d = lane_id * 4; d < dim_length; d += num_lanes * 4) { + for (auto d = lane_id * 4; d < dim_length; d += num_lanes * 4) { // Evict the slice into the embedding row evict_cache(d, qparams); } diff --git a/fbgemm_gpu/src/split_embeddings_cache/lfu_cache_populate.cu b/fbgemm_gpu/src/split_embeddings_cache/lfu_cache_populate.cu index 370f30cb3f..6f570f1126 100644 --- a/fbgemm_gpu/src/split_embeddings_cache/lfu_cache_populate.cu +++ b/fbgemm_gpu/src/split_embeddings_cache/lfu_cache_populate.cu @@ -37,7 +37,7 @@ __global__ __launch_bounds__(kCacheMaxThreads) void lfu_cache_insert_kernel( bool stochastic_rounding, at::PhiloxCudaState stochastic_rounding_philox_args) { const int32_t C = lxu_cache_state.size(0); - for (int32_t n = blockIdx.x * blockDim.y + threadIdx.y; n < *N_unique; + for (auto n = blockIdx.x * blockDim.y + threadIdx.y; n < *N_unique; n += gridDim.x * blockDim.y) { // check if this warp is responsible for this whole segment. const bool segment_start = @@ -64,21 +64,21 @@ __global__ __launch_bounds__(kCacheMaxThreads) void lfu_cache_insert_kernel( // now, we need to insert the (unique!) values in indices[n:n + SL] into // our slots. - const int32_t slot = threadIdx.x; + const auto slot = threadIdx.x; const int64_t current_idx = lxu_cache_state[cache_set][slot]; const int64_t current_lfu_cost = (current_idx != static_cast(kCacheStateInvalid)) ? 
lfu_state[current_idx] : -1; int64_t costs[1] = {current_lfu_cost}; - int32_t slots[1] = {slot}; + uint32_t slots[1] = {slot}; - BitonicSort>::sort(costs, slots); - const int32_t sorted_slot = slots[0]; - const int64_t sorted_lfu_cost = costs[0]; + BitonicSort>::sort(costs, slots); + const auto sorted_slot = slots[0]; + const auto sorted_lfu_cost = costs[0]; for (int32_t l = 0; l < min(SL, kWarpSize); ++l) { - const int32_t insert_slot = shfl_sync(sorted_slot, l); + const auto insert_slot = shfl_sync(sorted_slot, l); const int64_t insert_current_lfu_cost = shfl_sync(sorted_lfu_cost, l); const int64_t insert_idx = cache_set_sorted_indices[n + l]; const int64_t insert_lfu_cost = lfu_state[insert_idx]; diff --git a/fbgemm_gpu/src/split_embeddings_cache/lru_cache_populate.cu b/fbgemm_gpu/src/split_embeddings_cache/lru_cache_populate.cu index fde27d8211..3dd3ce897a 100644 --- a/fbgemm_gpu/src/split_embeddings_cache/lru_cache_populate.cu +++ b/fbgemm_gpu/src/split_embeddings_cache/lru_cache_populate.cu @@ -45,7 +45,7 @@ __global__ __launch_bounds__(kMaxThreads) void lru_cache_insert_kernel( lxu_cache_locking_counter) { const int32_t C = lxu_cache_state.size(0); int32_t n_conflict_misses = 0; - for (int32_t n = blockIdx.x * blockDim.y + threadIdx.y; n < *N_unique; + for (auto n = blockIdx.x * blockDim.y + threadIdx.y; n < *N_unique; n += gridDim.x * blockDim.y) { // check if this warp is responsible for this whole segment. const bool segment_start = @@ -70,20 +70,20 @@ __global__ __launch_bounds__(kMaxThreads) void lru_cache_insert_kernel( // now, we need to insert the (unique!) values in indices[n:n + SL] into // our slots. - const int32_t slot = threadIdx.x; + const auto slot = threadIdx.x; const int64_t slot_time = lru_state[cache_set][slot]; int64_t costs[1] = {slot_time}; - int32_t slots[1] = {slot}; + uint32_t slots[1] = {slot}; - BitonicSort>::sort(costs, slots); - const int32_t sorted_slot = slots[0]; - const int64_t sorted_lru_cost = costs[0]; + BitonicSort>::sort(costs, slots); + const auto sorted_slot = slots[0]; + const auto sorted_lru_cost = costs[0]; const auto stoc_rounding_salt = kWarpSize * (blockIdx.x * blockDim.x * blockDim.y + threadIdx.y * blockDim.x + threadIdx.x); for (int32_t l = 0; l < min(SL, kWarpSize); ++l) { - const int32_t insert_slot = shfl_sync(sorted_slot, l); + const auto insert_slot = shfl_sync(sorted_slot, l); if (lock_cache_line) { auto count = lxu_cache_locking_counter[cache_set][insert_slot]; if (count > 0) { diff --git a/fbgemm_gpu/src/split_embeddings_cache/lxu_cache.cu b/fbgemm_gpu/src/split_embeddings_cache/lxu_cache.cu index daeb6b0be4..8355cf8a74 100644 --- a/fbgemm_gpu/src/split_embeddings_cache/lxu_cache.cu +++ b/fbgemm_gpu/src/split_embeddings_cache/lxu_cache.cu @@ -35,7 +35,7 @@ __global__ __launch_bounds__(kMaxThreads) void lxu_cache_flush_kernel( bool stochastic_rounding, at::PhiloxCudaState stochastic_rounding_philox_args) { const int32_t B = lxu_cache_weights.size(0); - const int32_t b = blockIdx.x * blockDim.y + threadIdx.y; + const auto b = blockIdx.x * blockDim.y + threadIdx.y; if (b >= B) { return; } @@ -55,7 +55,7 @@ __global__ __launch_bounds__(kMaxThreads) void lxu_cache_flush_kernel( if constexpr (std::is_same_v) { D_emb += kINT8QparamsBytes; } - StochasticRoundingRNGState state; + auto weight_row = WeightRow>( &weights[weights_offset_current + idx_current * D_emb + 0], &lxu_cache_weights[b][0], @@ -73,7 +73,7 @@ __global__ __launch_bounds__(kMaxThreads) void lxu_cache_flush_kernel( weight_row.store_qparams(qparams); } } - for (int32_t 
d = threadIdx.x * 4; d < D_current; d += blockDim.x * 4) { + for (auto d = threadIdx.x * 4; d < D_current; d += blockDim.x * 4) { weight_row.evict_cache(d, qparams); } } @@ -175,7 +175,7 @@ __launch_bounds__(kMaxThreads) void lxu_cache_locking_counter_decrement_kernel( lxu_cache_locking_counter, pta::PackedTensorAccessor32 count) { const int32_t C = lxu_cache_locking_counter.size(0); - for (int32_t i = blockIdx.x * blockDim.y + threadIdx.y; i < C; + for (auto i = blockIdx.x * blockDim.y + threadIdx.y; i < C; i += gridDim.x * blockDim.y) { const auto j = threadIdx.x; if (count[i][j] > 0) { @@ -259,7 +259,7 @@ __global__ __launch_bounds__(kMaxThreads) void lxu_cache_lookup_kernel( const int32_t C = lxu_cache_state.size(0); const int32_t N = N_unique == nullptr ? linear_cache_indices.size(0) : *N_unique; - const int32_t n0 = + const auto n0 = blockIdx.x * blockDim.y * blockDim.x + threadIdx.y * blockDim.x; if (n0 >= N) { return; @@ -270,7 +270,7 @@ __global__ __launch_bounds__(kMaxThreads) void lxu_cache_lookup_kernel( int32_t n_hits = 0; const auto slot = threadIdx.x; for (int i = 0; i < blockDim.x; ++i) { - int32_t n = n0 + i; + const auto n = n0 + i; if (n >= N) { continue; } @@ -303,7 +303,7 @@ __global__ __launch_bounds__(kMaxThreads) void lxu_cache_lookup_kernel( } } - const int32_t n = n0 + threadIdx.x; + const auto n = n0 + threadIdx.x; if (n < N) { lxu_cache_locations[n] = cache_location; }
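
Note for context (not part of the patch): a minimal standalone sketch of the two patterns this change targets. CUDA built-in index expressions such as `blockIdx.x * blockDim.x + threadIdx.x` deduce to `unsigned int`, so binding them with `auto` keeps the loop index, bound, and stride in one unsigned integer domain instead of narrowing to `int32_t`; likewise, declaring the 8-byte qparams footer as an integral constant keeps row-size arithmetic in integers. This is not FBGEMM code; the kernel and constant names (`toy_grid_stride_kernel`, `kToyInt8QparamsBytes`) are hypothetical.

#include <cstdint>
#include <cstdio>
#include <cuda_runtime.h>

// Sketch only: mirrors the grid-stride pattern touched by this patch.
// `blockIdx.x * blockDim.x + threadIdx.x` has type `unsigned int`, so `auto`
// keeps index, bound, and stride unsigned instead of narrowing to int32_t.
__global__ void toy_grid_stride_kernel(float* out, uint32_t n) {
  for (auto i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
       i += gridDim.x * blockDim.x) {
    out[i] = static_cast<float>(i);
  }
}

// Byte count for a rowwise-int8 qparams footer; an integral constant (as in
// the patch) avoids implicit float arithmetic when added to a row width.
static constexpr int32_t kToyInt8QparamsBytes = 8;

int main() {
  constexpr uint32_t kN = 1024;
  float* out = nullptr;
  cudaMalloc(reinterpret_cast<void**>(&out), kN * sizeof(float));

  toy_grid_stride_kernel<<<4, 256>>>(out, kN);
  cudaDeviceSynchronize();

  // Row width in bytes for a 128-dim float row plus the int8 qparams footer.
  int32_t row_bytes = 128 * static_cast<int32_t>(sizeof(float));
  row_bytes += kToyInt8QparamsBytes;  // stays integer arithmetic
  std::printf("row bytes incl. qparams footer: %d\n", row_bytes);

  cudaFree(out);
  return 0;
}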