
Commit 5f49d48

emlin authored and facebook-github-bot committed
Add keep_orig_idx_per_feature parameter to block_bucketize_sparse_features kernel (#4027)
Summary:
X-link: facebookresearch/FBGEMM#1112

**Context**

Enhance the block_bucketize_sparse_features and block_bucketize_sparse_features_inference kernels to support mixed-format embedding tables. Previously, keep_orig_idx was a boolean flag applied uniformly across all features, determining whether to retain the original index. With the introduction of [the Flexible Collision-Free Embedding Table](https://github.com/pytorch/torchrec/blob/main/rfc/RFC-0002-Flexible-Collision-Free-Embedding-Table.md), one embedding collection may include both collision-free and collision tables. This update lets the kernels handle mixed formats by supporting feature-wise control over index retention. For collision-free tables, a large table size of 2^50 is set, parameters are maintained as id-value pairs, and the original global id is preserved; this change makes such mixed-style embedding tables practical to use.

Spec:
- keep_orig_idx_per_feature is an optional parameter with per-feature settings.
- If keep_orig_idx_per_feature is not None, its values override the global keep_orig_idx flag, regardless of whether that flag is true or false.
- If keep_orig_idx_per_feature is None, the kernels fall back to keep_orig_idx.

Note: adding the new keep_orig_idx_per_feature parameter, rather than changing the behavior of keep_orig_idx directly, avoids backward-compatibility issues.

Differential Revision: D73606958
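For illustration only, a minimal call sketch of the new parameter (assumes fbgemm_gpu is installed so the operator is registered and a CUDA device is available; the tensor values below are invented for the example and do not come from this diff):

```python
import torch
import fbgemm_gpu  # noqa: F401 -- importing registers the torch.ops.fbgemm operators

# Two features, batch size 1 each. Feature 0 remaps indices into its local
# bucket range; feature 1 keeps the original (global) ids, e.g. because it
# is backed by a collision-free table.
lengths = torch.tensor([2, 2], dtype=torch.int32).cuda()
indices = torch.tensor([3, 7, 11, 15], dtype=torch.int32).cuda()
block_sizes = torch.tensor([4, 8], dtype=torch.int32).cuda()
keep_orig_idx_per_feature = torch.tensor([False, True], dtype=torch.bool).cuda()

new_lengths, new_indices, _, _, _ = torch.ops.fbgemm.block_bucketize_sparse_features(
    lengths,
    indices,
    bucketize_pos=False,
    sequence=False,
    block_sizes=block_sizes,
    my_size=2,
    keep_orig_idx=False,  # global default for all features...
    keep_orig_idx_per_feature=keep_orig_idx_per_feature,  # ...overridden per feature
)
```

Passing keep_orig_idx_per_feature=None (or omitting it) reproduces the old behavior, which is why the flag was added as a new parameter instead of changing keep_orig_idx.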
1 parent 0911c94 commit 5f49d48

File tree

7 files changed: +320 -38 lines


fbgemm_gpu/fbgemm_gpu/docs/sparse_ops.py

+24 -1
@@ -323,7 +323,7 @@
 add_docs(
     torch.ops.fbgemm.block_bucketize_sparse_features,
     """
-block_bucketize_sparse_features(lengths, indices, bucketize_pos, sequence, block_sizes, my_size, weights=None, batch_size_per_feature=None, max_B= -1, block_bucketize_pos=None, keep_orig_idx=False) -> Tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor]]
+block_bucketize_sparse_features(lengths, indices, bucketize_pos, sequence, block_sizes, my_size, weights=None, batch_size_per_feature=None, max_B= -1, block_bucketize_pos=None, keep_orig_idx=False, total_num_blocks=None, keep_orig_idx_per_feature=None) -> Tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor]]

 Preprocess sparse features by partitioning sparse features into multiple
 buckets. Every feature is split into the same number of buckets, but the bucket
@@ -387,6 +387,11 @@
         where we scale up/down the number of GPUs but want to maintain
         same numerical behavior.

+    keep_orig_idx_per_feature (Optional[Tensor] = None): An optional tensor that
+        contains whether to keep original indices for each feature. If not None,
+        the operator will use this tensor to determine whether to keep original
+        indices for each feature. If None, will fall back to `keep_orig_idx`.
+
 Return:
     A tuple of tensors containing
@@ -448,6 +453,24 @@
            dtype=torch.int32),
     tensor([ 0, 1, 5, 2, 6, 7, 3, 8, 9, 10, 11, 4, 12, 13, 14],
            device='cuda:0', dtype=torch.int32))
+    >>> # Invoke with keep_orig_idx_per_feature
+    >>> keep_orig_idx_per_feature = torch.tensor([False, True, False, True], dtype=torch.bool)
+    >>> torch.ops.fbgemm.block_bucketize_sparse_features(
+    >>>     lengths,
+    >>>     indices,
+    >>>     bucketize_pos=False,
+    >>>     sequence=False,
+    >>>     block_sizes=block_sizes,
+    >>>     my_size=my_size,
+    >>>     keep_orig_idx=False,
+    >>>     keep_orig_idx_per_feature=keep_orig_idx_per_feature)
+    (tensor([0, 0, 0, 1, 1, 1, 2, 1, 0, 2, 1, 2, 1, 2, 1, 0], device='cuda:0',
+            dtype=torch.int32),
+     tensor([ 3, 4, 11, 1, 11, 15, 28, 29, 0, 1, 2, 3, 22, 20, 20],
+            device='cuda:0', dtype=torch.int32),
+     None,
+     None,
+     None)
     >>> # Invoke with block_bucketize_pos
     >>> block_bucketize_pos = [
     >>>     torch.tensor([0, 2, 8], dtype=torch.int),

fbgemm_gpu/fbgemm_gpu/sparse_ops.py

+1
@@ -463,6 +463,7 @@ def block_bucketize_sparse_features_meta(
     block_bucketize_pos: Optional[torch.Tensor] = None,
     keep_orig_idx: bool = False,
     total_num_blocks: Optional[torch.Tensor] = None,
+    keep_orig_idx_per_feature: Optional[torch.Tensor] = None,
 ) -> Tuple[
     torch.Tensor,
     torch.Tensor,

fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h

+8 -4
@@ -192,7 +192,8 @@ block_bucketize_sparse_features_cuda(
     const int64_t max_batch_size,
     const std::optional<std::vector<at::Tensor>>& block_bucketize_pos,
     const bool keep_orig_idx,
-    const std::optional<at::Tensor>& total_num_blocks);
+    const std::optional<at::Tensor>& total_num_blocks,
+    const std::optional<at::Tensor>& keep_orig_idx_per_feature);

 std::tuple<
     at::Tensor,
@@ -214,7 +215,8 @@ block_bucketize_sparse_features_cpu(
     const int64_t max_batch_size,
     const std::optional<std::vector<at::Tensor>>& block_bucketize_pos,
     const bool keep_orig_idx,
-    const std::optional<at::Tensor>& total_num_blocks);
+    const std::optional<at::Tensor>& total_num_blocks,
+    const std::optional<at::Tensor>& keep_orig_idx_per_feature);

 std::tuple<
     at::Tensor,
@@ -237,7 +239,8 @@ block_bucketize_sparse_features_inference_cuda(
     const std::optional<std::vector<at::Tensor>>& block_bucketize_pos,
     const bool return_bucket_mapping,
     const bool keep_orig_idx,
-    const std::optional<at::Tensor>& total_num_blocks);
+    const std::optional<at::Tensor>& total_num_blocks,
+    const std::optional<at::Tensor>& keep_orig_idx_per_feature);

 ///@ingroup sparse-data-cuda
 at::Tensor populate_bucketized_permute_cuda(
@@ -267,7 +270,8 @@ block_bucketize_sparse_features_inference_cpu(
     const std::optional<std::vector<at::Tensor>>& block_bucketize_pos,
     const bool return_bucket_mapping,
     const bool keep_orig_idx,
-    const std::optional<at::Tensor>& total_num_blocks);
+    const std::optional<at::Tensor>& total_num_blocks,
+    const std::optional<at::Tensor>& keep_orig_idx_per_feature);

 ///@ingroup sparse-data-cpu
 at::Tensor populate_bucketized_permute_cpu(

fbgemm_gpu/src/sparse_ops/sparse_block_bucketize_features.cu

+55 -21
@@ -191,7 +191,8 @@ __launch_bounds__(kMaxThreads) void _block_bucketize_pooled_sparse_features_cuda
     const index_t* const __restrict__ block_bucketize_pos_concat,
     const index_t* const __restrict__ block_bucketize_pos_offsets,
     const index_t* const __restrict__ indices_to_lb,
-    const bool keep_orig_idx) {
+    const bool keep_orig_idx,
+    const bool* const __restrict__ keep_orig_idx_per_feature) {
   using uindex_t = std::make_unsigned_t<index_t>;
   const auto bt_start = blockIdx.x * blockDim.y + threadIdx.y;
   const auto stride = gridDim.x * blockDim.y;
@@ -220,6 +221,12 @@ __launch_bounds__(kMaxThreads) void _block_bucketize_pooled_sparse_features_cuda
         total_num_blocks == nullptr ? my_size : total_num_blocks[t];
     const index_t global_idx_size = blk_size * global_num_blks;
     const index_t local_idx_size = blk_size * local_num_blks;
+    auto keep_idx = keep_orig_idx;
+    if (keep_orig_idx_per_feature != nullptr) {
+      // When keep_orig_idx_per_feature is set, override global
+      // keep_orig_idx settings
+      keep_idx = keep_orig_idx_per_feature[t];
+    }
     for (auto i = rowstart + threadIdx.x; i < rowend; i += blockDim.x) {
       // We have use cases using none-hashed raw indices that can be either
       // negative or larger than embedding table hash_size (blk_size *
@@ -233,7 +240,7 @@ __launch_bounds__(kMaxThreads) void _block_bucketize_pooled_sparse_features_cuda
       if (!use_block_bucketize_pos) { // uniform bucket sizes
        p = idx < global_idx_size ? idx / local_idx_size
                                  : (idx % global_num_blks) / local_num_blks;
-        if (keep_orig_idx) {
+        if (keep_idx) {
          new_idx = idx;
        } else if (idx < global_idx_size) {
          new_idx = idx % local_idx_size;
@@ -243,7 +250,7 @@ __launch_bounds__(kMaxThreads) void _block_bucketize_pooled_sparse_features_cuda
      } else { // variable bucket sizes
        uindex_t lb = indices_to_lb[i];
        p = lb < my_size ? lb : idx % my_size;
-        if (keep_orig_idx) {
+        if (keep_idx) {
          new_idx = idx;
        } else if (blk_size == 0) {
          new_idx = idx / global_num_blks;
@@ -307,7 +314,8 @@ __launch_bounds__(kMaxThreads) void _block_bucketize_sequence_sparse_features_cu
     const index_t* const __restrict__ block_bucketize_pos_concat,
     const index_t* const __restrict__ block_bucketize_pos_offsets,
     const index_t* const __restrict__ indices_to_lb,
-    const bool keep_orig_idx) {
+    const bool keep_orig_idx,
+    const bool* const __restrict__ keep_orig_idx_per_feature) {
   using uindex_t = std::make_unsigned_t<index_t>;
   using uoffset_t = std::make_unsigned_t<offset_t>;
   CUDA_KERNEL_LOOP(b_t, lengths_size) {
@@ -324,6 +332,12 @@ __launch_bounds__(kMaxThreads) void _block_bucketize_sequence_sparse_features_cu
     offset_t rowend = offsets_data[b_t];
     const auto use_block_bucketize_pos =
         (block_bucketize_pos_concat != nullptr);
+    auto keep_idx = keep_orig_idx;
+    if (keep_orig_idx_per_feature != nullptr) {
+      // When keep_orig_idx_per_feature is set, override global
+      // keep_orig_idx settings
+      keep_idx = keep_orig_idx_per_feature[t];
+    }
     for (index_t i = rowstart; i < rowend; ++i) {
       // We have use cases using none-hashed raw indices that can be either
      // negative or larger than embedding table hash_size (blk_size *
@@ -337,7 +351,7 @@ __launch_bounds__(kMaxThreads) void _block_bucketize_sequence_sparse_features_cu
      if (!use_block_bucketize_pos) {
        p = idx < global_idx_size ? idx / local_idx_size
                                  : (idx % global_num_blks) / local_num_blks;
-        if (keep_orig_idx) {
+        if (keep_idx) {
          new_idx = idx;
        } else if (idx < global_idx_size) {
          new_idx = idx % local_idx_size;
@@ -347,7 +361,7 @@ __launch_bounds__(kMaxThreads) void _block_bucketize_sequence_sparse_features_cu
      } else {
        uindex_t lb = indices_to_lb[i];
        p = lb < my_size ? lb : idx % my_size;
-        if (keep_orig_idx) {
+        if (keep_idx) {
          new_idx = idx;
        } else if (blk_size == 0) {
          new_idx = idx / global_num_blks;
@@ -455,7 +469,10 @@ __launch_bounds__(kMaxThreads) void _populate_bucketized_permute_cuda_kernel(
            block_bucketize_pos.has_value() \
                ? indices_to_lb.data_ptr<index_t>() \
                : static_cast<index_t*>(nullptr), \
-            keep_orig_idx); \
+            keep_orig_idx, \
+            keep_orig_idx_per_feature.has_value() \
+                ? keep_orig_idx_per_feature->data_ptr<bool>() \
+                : static_cast<bool*>(nullptr)); \
        C10_CUDA_KERNEL_LAUNCH_CHECK(); \
      }); \
    }); \
@@ -514,7 +531,10 @@ __launch_bounds__(kMaxThreads) void _populate_bucketized_permute_cuda_kernel(
            block_bucketize_pos.has_value() \
                ? indices_to_lb.data_ptr<index_t>() \
                : static_cast<index_t*>(nullptr), \
-            keep_orig_idx); \
+            keep_orig_idx, \
+            keep_orig_idx_per_feature.has_value() \
+                ? keep_orig_idx_per_feature->data_ptr<bool>() \
+                : static_cast<bool*>(nullptr)); \
        C10_CUDA_KERNEL_LAUNCH_CHECK(); \
      }); \
    });
@@ -577,7 +597,10 @@ __launch_bounds__(kMaxThreads) void _populate_bucketized_permute_cuda_kernel(
            block_bucketize_pos.has_value() \
                ? indices_to_lb.data_ptr<index_t>() \
                : static_cast<index_t*>(nullptr), \
-            keep_orig_idx); \
+            keep_orig_idx, \
+            keep_orig_idx_per_feature.has_value() \
+                ? keep_orig_idx_per_feature->data_ptr<bool>() \
+                : static_cast<bool*>(nullptr)); \
        C10_CUDA_KERNEL_LAUNCH_CHECK(); \
      }); \
    }); \
@@ -637,13 +660,17 @@ __launch_bounds__(kMaxThreads) void _populate_bucketized_permute_cuda_kernel(
            block_bucketize_pos.has_value() \
                ? indices_to_lb.data_ptr<index_t>() \
                : static_cast<index_t*>(nullptr), \
-            keep_orig_idx); \
+            keep_orig_idx, \
+            keep_orig_idx_per_feature.has_value() \
+                ? keep_orig_idx_per_feature->data_ptr<bool>() \
+                : static_cast<bool*>(nullptr)); \
        C10_CUDA_KERNEL_LAUNCH_CHECK(); \
      }); \
    });

 // This function partitions sparse features
-// continuously along the sparse dimension into my_size blocks
+// continuously along the sparse dimension into
+// my_size blocks
 std::tuple<
     Tensor,
     Tensor,
@@ -664,7 +691,8 @@ _block_bucketize_sparse_features_cuda(
     const int64_t max_B,
     const std::optional<std::vector<at::Tensor>>& block_bucketize_pos,
     const bool return_bucket_mapping,
-    const bool keep_orig_idx) {
+    const bool keep_orig_idx,
+    const std::optional<Tensor>& keep_orig_idx_per_feature) {
   TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL(lengths, indices);

   CUDA_DEVICE_GUARD(lengths);
@@ -740,8 +768,9 @@ _block_bucketize_sparse_features_cuda(
    at::Tensor sizes_vec =
        at::tensor(sizes_, at::TensorOptions().dtype(indices_contig.dtype()));
    block_bucketize_pos_offsets = asynchronous_exclusive_cumsum_cpu(
-        sizes_vec); // expect sizes_vec to be a small tensor, using cpu instead
-                    // of gpu for cumsum
+        sizes_vec); // expect sizes_vec to be a
+                    // small tensor, using cpu
+                    // instead of gpu for cumsum
    block_bucketize_pos_offsets = block_bucketize_pos_offsets.to(
        block_bucketize_pos_concat.device(), true);
  }
@@ -896,8 +925,8 @@
 #undef LAUNCH_BLOCK_BUCKETIZE_POOLED_SPARSE_FEATURES_CUDA_KERNEL_2_WITHOUT_WEIGHT

 // This function partitions sparse features
-// continuously along the sparse dimension into my_size
-// blocks
+// continuously along the sparse dimension into
+// my_size blocks
 DLL_PUBLIC std::tuple<
     Tensor,
     Tensor,
@@ -916,7 +945,8 @@ block_bucketize_sparse_features_cuda(
     const int64_t max_B,
     const std::optional<std::vector<at::Tensor>>& block_bucketize_pos,
     const bool keep_orig_idx,
-    const std::optional<Tensor>& total_num_blocks) {
+    const std::optional<Tensor>& total_num_blocks,
+    const std::optional<at::Tensor>& keep_orig_idx_per_feature) {
   Tensor new_lengths;
   Tensor new_indices;
   std::optional<Tensor> new_weights;
@@ -942,12 +972,14 @@ block_bucketize_sparse_features_cuda(
      max_B,
      block_bucketize_pos,
      false,
-      keep_orig_idx);
+      keep_orig_idx,
+      keep_orig_idx_per_feature);
   return {new_lengths, new_indices, new_weights, new_pos, unbucketize_permute};
 }

 // This function partitions sparse features
-// continuously along the sparse dimension into my_size blocks
+// continuously along the sparse dimension into
+// my_size blocks
 DLL_PUBLIC std::tuple<
     Tensor,
     Tensor,
@@ -968,7 +1000,8 @@ block_bucketize_sparse_features_inference_cuda(
     const std::optional<std::vector<at::Tensor>>& block_bucketize_pos,
     const bool return_bucket_mapping,
     const bool keep_orig_idx,
-    const std::optional<Tensor>& total_num_blocks) {
+    const std::optional<Tensor>& total_num_blocks,
+    const std::optional<at::Tensor>& keep_orig_idx_per_feature) {
   return _block_bucketize_sparse_features_cuda(
       lengths,
       indices,
@@ -982,7 +1015,8 @@ block_bucketize_sparse_features_inference_cuda(
      max_B,
      block_bucketize_pos,
      return_bucket_mapping,
-      keep_orig_idx);
+      keep_orig_idx,
+      keep_orig_idx_per_feature);
 }

 DLL_PUBLIC Tensor populate_bucketized_permute_cuda(
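To summarize the kernel change above: for each feature `t`, the effective flag is `keep_orig_idx_per_feature[t]` when the per-feature tensor is passed, and the global `keep_orig_idx` otherwise. A rough Python equivalent of that selection and of the uniform-bucket remapping branch is sketched below (illustrative names only, not an FBGEMM API; out-of-range and negative raw indices are handled differently in the real kernels and are omitted here):

```python
def bucketize_index(idx, t, blk_size, my_size, total_num_blocks,
                    keep_orig_idx, keep_orig_idx_per_feature=None):
    """Return (bucket, new_idx) for one in-range index of feature t."""
    # Per-feature override, mirroring the keep_idx selection in the kernels.
    keep_idx = keep_orig_idx
    if keep_orig_idx_per_feature is not None:
        keep_idx = keep_orig_idx_per_feature[t]

    local_num_blks = total_num_blocks // my_size
    local_idx_size = blk_size * local_num_blks

    p = idx // local_idx_size  # which rank/bucket this index is routed to
    new_idx = idx if keep_idx else idx % local_idx_size  # keep or fold into local range
    return p, new_idx

# Feature 0 remaps into the local range, feature 1 keeps its global id.
flags = [False, True]
print(bucketize_index(5, 0, 4, 2, 2, False, flags))  # (1, 1)
print(bucketize_index(5, 1, 4, 2, 2, False, flags))  # (1, 5)
```

The same keep_idx selection is applied in both the pooled and the sequence kernels, so the per-feature flag behaves identically regardless of pooling.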

0 commit comments
