Skip to content

Commit ad07878

Browse files
q10facebook-github-bot
authored andcommitted
More if-constexpr (#4054)
Summary: Pull Request resolved: #4054 X-link: facebookresearch/FBGEMM#1139 - Update more locations to if-constexpr to accomodate D73473546 Reviewed By: spcyppt Differential Revision: D73873544 fbshipit-source-id: 8fbdb7f63cc66c75b775ae33b4ec69f88ca6ad2c
1 parent 58f4152 commit ad07878

File tree

4 files changed

+4
-4
lines changed

4 files changed

+4
-4
lines changed

fbgemm_gpu/codegen/training/backward/embedding_backward_split_indice_weights_template.cu

+1-1
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ __global__ __launch_bounds__(kForwardMaxThreads) void
122122
const auto D_end = D_offsets[t + 1];
123123
const auto D = D_end - D_start;
124124
auto D_emb = D;
125-
if (std::is_same<emb_t, uint8_t>::value) {
125+
if constexpr (std::is_same_v<emb_t, uint8_t>) {
126126
D_emb += kINT8QparamsBytes;
127127
}
128128
const auto indices_start = offsets[b_t];

fbgemm_gpu/codegen/training/forward/embedding_forward_split_kernel_nobag_small_template.cu

+1-1
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ batch_index_select_dim0_codegen_forward_small_kernel(
124124
{%- endif %}
125125

126126
int32_t D_emb = D;
127-
if (std::is_same<emb_t, uint8_t>::value) {
127+
if constexpr (std::is_same_v<emb_t, uint8_t>) {
128128
D_emb += kINT8QparamsBytes;
129129
}
130130

fbgemm_gpu/codegen/training/forward/embedding_forward_split_kernel_template.cu

+1-1
Original file line numberDiff line numberDiff line change
@@ -730,7 +730,7 @@ batch_index_select_dim0_codegen_forward_kernel(
730730
// D is computed in the bag case or provided as function arg in the nobag case
731731
// (nobag only supports the case where the embedding dimensions are the same for all tables)
732732
int32_t D_emb = D;
733-
if (std::is_same<emb_t, uint8_t>::value) {
733+
if constexpr (std::is_same_v<emb_t, uint8_t>) {
734734
D_emb += kINT8QparamsBytes;
735735
}
736736

fbgemm_gpu/src/split_embeddings_cache/lxu_cache.cu

+1-1
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ __global__ __launch_bounds__(kMaxThreads) void lxu_cache_flush_kernel(
6666
threadIdx.x);
6767

6868
float2 qparams;
69-
if (std::is_same<emb_t, uint8_t>::value) {
69+
if constexpr (std::is_same_v<emb_t, uint8_t>) {
7070
qparams =
7171
thrust_find_qparams<cache_t>(&lxu_cache_weights[b][0], D_current);
7272
if (threadIdx.x == 0) {

0 commit comments

Comments
 (0)