Commit 2a3e87e

emlin authored and facebook-github-bot committed
Add keep_orig_idx_per_feature parameter to block_bucketize_sparse_features kernel
Summary:
X-link: facebookresearch/FBGEMM#1112

**Context**

Enhance the block_bucketize_sparse_features and block_bucketize_sparse_features_inference kernels to support mixed-format embedding tables. Previously, the keep_orig_idx parameter was a single boolean flag applied uniformly across all features, determining whether to retain the original index. With the introduction of [the Flexible Collision-Free Embedding Table](https://github.com/pytorch/torchrec/blob/main/rfc/RFC-0002-Flexible-Collision-Free-Embedding-Table.md), one embedding collection may contain both collision-free and collision tables. This update lets the kernels handle mixed formats by supporting feature-wise control over index retention. For collision-free tables, a large table size of 2^50 is set, keeping parameters as id-value pairs and preserving the original global id. This change makes mixed-style embedding tables practical to use.

Spec:
- keep_orig_idx_per_feature is an optional parameter carrying per-feature settings.
- If keep_orig_idx_per_feature is not None, its per-feature values override the global keep_orig_idx flag, regardless of whether that flag is true or false.
- If keep_orig_idx_per_feature is None, the kernels fall back to the keep_orig_idx flag.

Note: Adding a separate keep_orig_idx_per_feature parameter, instead of changing keep_orig_idx directly, avoids backward-compatibility issues.

Differential Revision: D73606958
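A minimal usage sketch of the new parameter, following the operator schema registered in this commit. It is not part of the diff: it assumes fbgemm_gpu is installed so the operator is available under torch.ops.fbgemm, and the tensor values are purely illustrative.

```python
import torch
import fbgemm_gpu  # noqa: F401  # loads/registers the fbgemm operators

# Two features, batch size 2 per feature, bucketized across my_size = 2 ranks.
lengths = torch.tensor([2, 1, 1, 2], dtype=torch.int32)
indices = torch.tensor([3, 17, 9, 4, 0, 21], dtype=torch.int64)
# Feature 1 uses a very large block size to model a collision-free table.
block_sizes = torch.tensor([8, 2**50], dtype=torch.int64)

# Per-feature override: keep the original global id only for feature 1.
keep_orig_idx_per_feature = torch.tensor([False, True])

new_lengths, new_indices, _, _, _ = torch.ops.fbgemm.block_bucketize_sparse_features(
    lengths,
    indices,
    False,  # bucketize_pos
    False,  # sequence
    block_sizes,
    2,      # my_size
    keep_orig_idx=False,  # global flag; overridden per feature by the tensor below
    keep_orig_idx_per_feature=keep_orig_idx_per_feature,
)
```

When keep_orig_idx_per_feature is omitted (or None), the call behaves exactly as before and the global keep_orig_idx flag applies to every feature.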
1 parent 0f00a8a commit 2a3e87e

File tree

4 files changed: +281 −31 lines

fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h

+4

@@ -186,6 +186,7 @@ block_bucketize_sparse_features_cuda(
     const int64_t max_batch_size,
     const std::optional<std::vector<at::Tensor>>& block_bucketize_pos,
     const bool keep_orig_idx,
+    const std::optional<at::Tensor>& keep_orig_idx_per_feature,
     const std::optional<at::Tensor>& total_num_blocks);

 std::tuple<
@@ -208,6 +209,7 @@ block_bucketize_sparse_features_cpu(
     const int64_t max_batch_size,
     const std::optional<std::vector<at::Tensor>>& block_bucketize_pos,
     const bool keep_orig_idx,
+    const std::optional<at::Tensor>& keep_orig_idx_per_feature,
     const std::optional<at::Tensor>& total_num_blocks);

 std::tuple<
@@ -231,6 +233,7 @@ block_bucketize_sparse_features_inference_cuda(
     const std::optional<std::vector<at::Tensor>>& block_bucketize_pos,
     const bool return_bucket_mapping,
     const bool keep_orig_idx,
+    const std::optional<at::Tensor>& keep_orig_idx_per_feature,
     const std::optional<at::Tensor>& total_num_blocks);

 ///@ingroup sparse-data-cuda
@@ -261,6 +264,7 @@ block_bucketize_sparse_features_inference_cpu(
     const std::optional<std::vector<at::Tensor>>& block_bucketize_pos,
     const bool return_bucket_mapping,
     const bool keep_orig_idx,
+    const std::optional<at::Tensor>& keep_orig_idx_per_feature,
     const std::optional<at::Tensor>& total_num_blocks);

 ///@ingroup sparse-data-cpu

fbgemm_gpu/src/sparse_ops/sparse_block_bucketize_features.cu

+53 −19

@@ -191,7 +191,8 @@ __launch_bounds__(kMaxThreads) void _block_bucketize_pooled_sparse_features_cuda
     const index_t* const __restrict__ block_bucketize_pos_concat,
     const index_t* const __restrict__ block_bucketize_pos_offsets,
     const index_t* const __restrict__ indices_to_lb,
-    const bool keep_orig_idx) {
+    const bool keep_orig_idx,
+    const bool* const __restrict__ keep_orig_idx_per_feature) {
   using uindex_t = std::make_unsigned_t<index_t>;
   const auto bt_start = blockIdx.x * blockDim.y + threadIdx.y;
   const auto stride = gridDim.x * blockDim.y;
@@ -220,6 +221,12 @@ __launch_bounds__(kMaxThreads) void _block_bucketize_pooled_sparse_features_cuda
         total_num_blocks == nullptr ? my_size : total_num_blocks[t];
     const index_t global_idx_size = blk_size * global_num_blks;
     const index_t local_idx_size = blk_size * local_num_blks;
+    auto keep_idx = keep_orig_idx;
+    if (keep_orig_idx_per_feature != nullptr) {
+      // When keep_orig_idx_per_feature is set, override global
+      // keep_orig_idx settings
+      keep_idx = keep_orig_idx_per_feature[t];
+    }
     for (auto i = rowstart + threadIdx.x; i < rowend; i += blockDim.x) {
       // We have use cases using none-hashed raw indices that can be either
       // negative or larger than embedding table hash_size (blk_size *
@@ -233,7 +240,7 @@ __launch_bounds__(kMaxThreads) void _block_bucketize_pooled_sparse_features_cuda
       if (!use_block_bucketize_pos) { // uniform bucket sizes
         p = idx < global_idx_size ? idx / local_idx_size
                                   : (idx % global_num_blks) / local_num_blks;
-        if (keep_orig_idx) {
+        if (keep_idx) {
           new_idx = idx;
         } else if (idx < global_idx_size) {
           new_idx = idx % local_idx_size;
@@ -243,7 +250,7 @@ __launch_bounds__(kMaxThreads) void _block_bucketize_pooled_sparse_features_cuda
       } else { // variable bucket sizes
         uindex_t lb = indices_to_lb[i];
         p = lb < my_size ? lb : idx % my_size;
-        if (keep_orig_idx) {
+        if (keep_idx) {
           new_idx = idx;
         } else if (blk_size == 0) {
           new_idx = idx / global_num_blks;
@@ -307,7 +314,8 @@ __launch_bounds__(kMaxThreads) void _block_bucketize_sequence_sparse_features_cu
     const index_t* const __restrict__ block_bucketize_pos_concat,
     const index_t* const __restrict__ block_bucketize_pos_offsets,
     const index_t* const __restrict__ indices_to_lb,
-    const bool keep_orig_idx) {
+    const bool keep_orig_idx,
+    const bool* const __restrict__ keep_orig_idx_per_feature) {
   using uindex_t = std::make_unsigned_t<index_t>;
   using uoffset_t = std::make_unsigned_t<offset_t>;
   CUDA_KERNEL_LOOP(b_t, lengths_size) {
@@ -324,6 +332,12 @@ __launch_bounds__(kMaxThreads) void _block_bucketize_sequence_sparse_features_cu
     offset_t rowend = offsets_data[b_t];
     const auto use_block_bucketize_pos =
         (block_bucketize_pos_concat != nullptr);
+    auto keep_idx = keep_orig_idx;
+    if (keep_orig_idx_per_feature != nullptr) {
+      // When keep_orig_idx_per_feature is set, override global
+      // keep_orig_idx settings
+      keep_idx = keep_orig_idx_per_feature[t];
+    }
     for (index_t i = rowstart; i < rowend; ++i) {
       // We have use cases using none-hashed raw indices that can be either
       // negative or larger than embedding table hash_size (blk_size *
@@ -337,7 +351,7 @@ __launch_bounds__(kMaxThreads) void _block_bucketize_sequence_sparse_features_cu
       if (!use_block_bucketize_pos) {
         p = idx < global_idx_size ? idx / local_idx_size
                                   : (idx % global_num_blks) / local_num_blks;
-        if (keep_orig_idx) {
+        if (keep_idx) {
          new_idx = idx;
        } else if (idx < global_idx_size) {
          new_idx = idx % local_idx_size;
@@ -347,7 +361,7 @@ __launch_bounds__(kMaxThreads) void _block_bucketize_sequence_sparse_features_cu
       } else {
         uindex_t lb = indices_to_lb[i];
         p = lb < my_size ? lb : idx % my_size;
-        if (keep_orig_idx) {
+        if (keep_idx) {
          new_idx = idx;
        } else if (blk_size == 0) {
          new_idx = idx / global_num_blks;
@@ -455,7 +469,10 @@ __launch_bounds__(kMaxThreads) void _populate_bucketized_permute_cuda_kernel(
           block_bucketize_pos.has_value() \
               ? indices_to_lb.data_ptr<index_t>() \
               : static_cast<index_t*>(nullptr), \
-          keep_orig_idx); \
+          keep_orig_idx, \
+          keep_orig_idx_per_feature.has_value() \
+              ? keep_orig_idx_per_feature->data_ptr<bool>() \
+              : static_cast<bool*>(nullptr)); \
       C10_CUDA_KERNEL_LAUNCH_CHECK(); \
     }); \
   }); \
@@ -514,7 +531,10 @@ __launch_bounds__(kMaxThreads) void _populate_bucketized_permute_cuda_kernel(
           block_bucketize_pos.has_value() \
               ? indices_to_lb.data_ptr<index_t>() \
               : static_cast<index_t*>(nullptr), \
-          keep_orig_idx); \
+          keep_orig_idx, \
+          keep_orig_idx_per_feature.has_value() \
+              ? keep_orig_idx_per_feature->data_ptr<bool>() \
+              : static_cast<bool*>(nullptr)); \
      C10_CUDA_KERNEL_LAUNCH_CHECK(); \
    }); \
  });
@@ -577,7 +597,10 @@ __launch_bounds__(kMaxThreads) void _populate_bucketized_permute_cuda_kernel(
           block_bucketize_pos.has_value() \
               ? indices_to_lb.data_ptr<index_t>() \
               : static_cast<index_t*>(nullptr), \
-          keep_orig_idx); \
+          keep_orig_idx, \
+          keep_orig_idx_per_feature.has_value() \
+              ? keep_orig_idx_per_feature->data_ptr<bool>() \
+              : static_cast<bool*>(nullptr)); \
       C10_CUDA_KERNEL_LAUNCH_CHECK(); \
     }); \
   }); \
@@ -637,13 +660,17 @@ __launch_bounds__(kMaxThreads) void _populate_bucketized_permute_cuda_kernel(
           block_bucketize_pos.has_value() \
               ? indices_to_lb.data_ptr<index_t>() \
               : static_cast<index_t*>(nullptr), \
-          keep_orig_idx); \
+          keep_orig_idx, \
+          keep_orig_idx_per_feature.has_value() \
+              ? keep_orig_idx_per_feature->data_ptr<bool>() \
+              : static_cast<bool*>(nullptr)); \
       C10_CUDA_KERNEL_LAUNCH_CHECK(); \
     }); \
   });

 // This function partitions sparse features
-// continuously along the sparse dimension into my_size blocks
+// continuously along the sparse dimension into
+// my_size blocks
 std::tuple<
     Tensor,
     Tensor,
@@ -664,7 +691,8 @@ _block_bucketize_sparse_features_cuda(
     const int64_t max_B,
     const std::optional<std::vector<at::Tensor>>& block_bucketize_pos,
     const bool return_bucket_mapping,
-    const bool keep_orig_idx) {
+    const bool keep_orig_idx,
+    const std::optional<Tensor>& keep_orig_idx_per_feature) {
   TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL(lengths, indices);

   CUDA_DEVICE_GUARD(lengths);
@@ -740,8 +768,9 @@ _block_bucketize_sparse_features_cuda(
     at::Tensor sizes_vec =
         at::tensor(sizes_, at::TensorOptions().dtype(indices_contig.dtype()));
     block_bucketize_pos_offsets = asynchronous_exclusive_cumsum_cpu(
-        sizes_vec); // expect sizes_vec to be a small tensor, using cpu instead
-                    // of gpu for cumsum
+        sizes_vec); // expect sizes_vec to be a
+                    // small tensor, using cpu
+                    // instead of gpu for cumsum
     block_bucketize_pos_offsets = block_bucketize_pos_offsets.to(
         block_bucketize_pos_concat.device(), true);
   }
@@ -896,8 +925,8 @@ _block_bucketize_sparse_features_cuda(
 #undef LAUNCH_BLOCK_BUCKETIZE_POOLED_SPARSE_FEATURES_CUDA_KERNEL_2_WITHOUT_WEIGHT

 // This function partitions sparse features
-// continuously along the sparse dimension into my_size
-// blocks
+// continuously along the sparse dimension into
+// my_size blocks
 DLL_PUBLIC std::tuple<
     Tensor,
     Tensor,
@@ -916,6 +945,7 @@ block_bucketize_sparse_features_cuda(
     const int64_t max_B,
     const std::optional<std::vector<at::Tensor>>& block_bucketize_pos,
     const bool keep_orig_idx,
+    const std::optional<at::Tensor>& keep_orig_idx_per_feature,
     const std::optional<Tensor>& total_num_blocks) {
   Tensor new_lengths;
   Tensor new_indices;
@@ -942,12 +972,14 @@ block_bucketize_sparse_features_cuda(
       max_B,
       block_bucketize_pos,
       false,
-      keep_orig_idx);
+      keep_orig_idx,
+      keep_orig_idx_per_feature);
   return {new_lengths, new_indices, new_weights, new_pos, unbucketize_permute};
 }

 // This function partitions sparse features
-// continuously along the sparse dimension into my_size blocks
+// continuously along the sparse dimension into
+// my_size blocks
 DLL_PUBLIC std::tuple<
     Tensor,
     Tensor,
@@ -968,6 +1000,7 @@ block_bucketize_sparse_features_inference_cuda(
     const std::optional<std::vector<at::Tensor>>& block_bucketize_pos,
     const bool return_bucket_mapping,
     const bool keep_orig_idx,
+    const std::optional<at::Tensor>& keep_orig_idx_per_feature,
     const std::optional<Tensor>& total_num_blocks) {
   return _block_bucketize_sparse_features_cuda(
       lengths,
@@ -982,7 +1015,8 @@ block_bucketize_sparse_features_inference_cuda(
       max_B,
       block_bucketize_pos,
       return_bucket_mapping,
-      keep_orig_idx);
+      keep_orig_idx,
+      keep_orig_idx_per_feature);
 }

 DLL_PUBLIC Tensor populate_bucketized_permute_cuda(

fbgemm_gpu/src/sparse_ops/sparse_ops_cpu.cpp

+27 −12

@@ -365,7 +365,8 @@ void _block_bucketize_sparse_features_cpu_kernel(
     const std::optional<Tensor>& batch_size_per_feature,
     const std::optional<std::vector<at::Tensor>>& block_bucketize_pos,
     const std::optional<Tensor>& bucket_mapping,
-    const bool keep_orig_idx) {
+    const bool keep_orig_idx,
+    const std::optional<Tensor>& keep_orig_idx_per_feature = std::nullopt) {
   // allocate tensors and buffers
   const auto lengths_size = lengths.numel();
   const auto new_lengths_size = lengths_size * my_size;
@@ -487,6 +488,13 @@ void _block_bucketize_sparse_features_cpu_kernel(
         : my_size;
     const index_t global_idx_size = blk_size * global_num_blks;
     const index_t local_idx_size = blk_size * local_num_blks;
+    auto keep_idx = keep_orig_idx;
+    if (keep_orig_idx_per_feature.has_value()) {
+      // When keep_orig_idx_per_feature is set, override global
+      // keep_orig_idx settings
+      keep_idx = keep_orig_idx_per_feature.value().data_ptr<bool>()[t];
+    }
+
     for (const auto b : c10::irange(cur_batch_size)) {
       const auto b_t = (variable_batch_size ? cur_offset : t * B) + b;
       const offset_t rowstart = offsets_data[b_t];
@@ -503,7 +511,7 @@ void _block_bucketize_sparse_features_cpu_kernel(
         if (variable_bucket_sizes) {
           int64_t lb = lower_bounds[i];
           p = lb < my_size ? lb : idx % my_size;
-          if (keep_orig_idx) {
+          if (keep_idx) {
            new_idx = idx;
          } else if (blk_size == 0) {
            new_idx = idx / global_num_blks;
@@ -517,7 +525,7 @@ void _block_bucketize_sparse_features_cpu_kernel(
           const uindex_t ub = static_cast<uindex_t>(global_idx_size);
           p = idx < ub ? idx / local_idx_size
                        : (idx % global_num_blks) / local_num_blks;
-          if (keep_orig_idx) {
+          if (keep_idx) {
            new_idx = idx;
          } else if (idx < ub) {
            new_idx = idx % local_idx_size;
@@ -1134,7 +1142,8 @@ _block_bucketize_sparse_features_cpu(
     const int64_t /* max_batch_size */, // Only used in GPU variant
     const std::optional<std::vector<at::Tensor>>& block_bucketize_pos,
     const bool return_bucket_mapping,
-    const bool keep_orig_idx) {
+    const bool keep_orig_idx,
+    const std::optional<Tensor>& keep_orig_idx_per_feature = std::nullopt) {
   const auto lengths_size = lengths.numel();
   const auto new_lengths_size = lengths_size * my_size;
   auto new_lengths = at::zeros({new_lengths_size}, lengths.options());
@@ -1183,7 +1192,8 @@ _block_bucketize_sparse_features_cpu(
           batch_size_per_feature, \
           block_bucketize_pos, \
           bucket_mapping, \
-          keep_orig_idx); \
+          keep_orig_idx, \
+          keep_orig_idx_per_feature); \
     }); \
   }); \
 });
@@ -1219,7 +1229,8 @@ _block_bucketize_sparse_features_cpu(
           batch_size_per_feature, \
           block_bucketize_pos, \
           bucket_mapping, \
-          keep_orig_idx); \
+          keep_orig_idx, \
+          keep_orig_idx_per_feature); \
     }); \
   });
   const auto lengths_sum = indices.numel();
@@ -1289,7 +1300,8 @@ block_bucketize_sparse_features_cpu(
     const int64_t /* max_batch_size */, // Only used in GPU variant
     const std::optional<std::vector<at::Tensor>>& block_bucketize_pos,
     const bool keep_orig_idx,
-    const std::optional<Tensor>& total_num_blocks) {
+    const std::optional<Tensor>& keep_orig_idx_per_feature = std::nullopt,
+    const std::optional<Tensor>& total_num_blocks = std::nullopt) {
   Tensor new_lengths;
   Tensor new_indices;
   std::optional<Tensor> new_weights;
@@ -1315,7 +1327,8 @@ block_bucketize_sparse_features_cpu(
       -1, /* placeholder for max_batch_size */
       block_bucketize_pos,
       false,
-      keep_orig_idx);
+      keep_orig_idx,
+      keep_orig_idx_per_feature);
   return {new_lengths, new_indices, new_weights, new_pos, unbucketize_permute};
 }

@@ -1339,7 +1352,8 @@ block_bucketize_sparse_features_inference_cpu(
     const std::optional<std::vector<at::Tensor>>& block_bucketize_pos,
     const bool return_bucket_mapping,
     const bool keep_orig_idx,
-    const std::optional<Tensor>& total_num_blocks) {
+    const std::optional<Tensor>& keep_orig_idx_per_feature = std::nullopt,
+    const std::optional<Tensor>& total_num_blocks = std::nullopt) {
   return _block_bucketize_sparse_features_cpu(
       lengths,
       indices,
@@ -1353,7 +1367,8 @@ block_bucketize_sparse_features_inference_cpu(
       -1, /* placeholder for max_batch_size */
       block_bucketize_pos,
       return_bucket_mapping,
-      keep_orig_idx);
+      keep_orig_idx,
+      keep_orig_idx_per_feature);
 }

 // This function partitions sparse features
@@ -3401,9 +3416,9 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
   m.def(
       "populate_bucketized_permute(Tensor lengths, Tensor bucketized_lengths, Tensor bucket_mapping) -> Tensor");
   m.def(
-      "block_bucketize_sparse_features(Tensor lengths, Tensor indices, bool bucketize_pos, bool sequence, Tensor block_sizes, SymInt my_size, Tensor? weights=None, Tensor? batch_size_per_feature=None, SymInt max_B= -1, Tensor[]? block_bucketize_pos=None, bool keep_orig_idx=False, Tensor? total_num_blocks=None) -> (Tensor, Tensor, Tensor?, Tensor?, Tensor?)");
+      "block_bucketize_sparse_features(Tensor lengths, Tensor indices, bool bucketize_pos, bool sequence, Tensor block_sizes, SymInt my_size, Tensor? weights=None, Tensor? batch_size_per_feature=None, SymInt max_B= -1, Tensor[]? block_bucketize_pos=None, bool keep_orig_idx=False, Tensor? keep_orig_idx_per_feature=None, Tensor? total_num_blocks=None) -> (Tensor, Tensor, Tensor?, Tensor?, Tensor?)");
   m.def(
-      "block_bucketize_sparse_features_inference(Tensor lengths, Tensor indices, bool bucketize_pos, bool sequence, Tensor block_sizes, SymInt my_size, Tensor? weights=None, Tensor? batch_size_per_feature=None, SymInt max_B= -1, Tensor[]? block_bucketize_pos=None, bool return_bucket_mapping=False, bool keep_orig_idx=False, Tensor? total_num_blocks=None) -> (Tensor, Tensor, Tensor?, Tensor?, Tensor?, Tensor?)");
+      "block_bucketize_sparse_features_inference(Tensor lengths, Tensor indices, bool bucketize_pos, bool sequence, Tensor block_sizes, SymInt my_size, Tensor? weights=None, Tensor? batch_size_per_feature=None, SymInt max_B= -1, Tensor[]? block_bucketize_pos=None, bool return_bucket_mapping=False, bool keep_orig_idx=False, Tensor? keep_orig_idx_per_feature=None, Tensor? total_num_blocks=None) -> (Tensor, Tensor, Tensor?, Tensor?, Tensor?, Tensor?)");
   m.def(
       "bucketize_sparse_features(Tensor lengths, Tensor indices, bool bucketize_pos, SymInt my_size, Tensor? weights=None) -> (Tensor, Tensor, Tensor?, Tensor?)");
   m.def(

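For completeness, a similarly hedged sketch of the inference-side call, following the block_bucketize_sparse_features_inference schema updated above. It reuses the tensors from the earlier example, requests the bucket mapping, and is illustrative only.

```python
# Returns (new_lengths, new_indices, weights?, pos?, unbucketize_permute?, bucket_mapping?)
# per the registered schema; the per-feature override behaves the same as in training.
outputs = torch.ops.fbgemm.block_bucketize_sparse_features_inference(
    lengths,
    indices,
    False,  # bucketize_pos
    False,  # sequence
    block_sizes,
    2,      # my_size
    return_bucket_mapping=True,
    keep_orig_idx=False,
    keep_orig_idx_per_feature=keep_orig_idx_per_feature,
)
new_lengths, new_indices, _, _, _, bucket_mapping = outputs
```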