@@ -191,7 +191,8 @@ __launch_bounds__(kMaxThreads) void _block_bucketize_pooled_sparse_features_cuda
     const index_t* const __restrict__ block_bucketize_pos_concat,
     const index_t* const __restrict__ block_bucketize_pos_offsets,
     const index_t* const __restrict__ indices_to_lb,
-    const bool keep_orig_idx) {
+    const bool keep_orig_idx,
+    const bool* const __restrict__ keep_orig_idx_per_feature) {
   using uindex_t = std::make_unsigned_t<index_t>;
   const auto bt_start = blockIdx.x * blockDim.y + threadIdx.y;
   const auto stride = gridDim.x * blockDim.y;
@@ -220,6 +221,12 @@ __launch_bounds__(kMaxThreads) void _block_bucketize_pooled_sparse_features_cuda
         total_num_blocks == nullptr ? my_size : total_num_blocks[t];
     const index_t global_idx_size = blk_size * global_num_blks;
     const index_t local_idx_size = blk_size * local_num_blks;
+    auto keep_idx = keep_orig_idx;
+    if (keep_orig_idx_per_feature != nullptr) {
+      // When keep_orig_idx_per_feature is set, override global
+      // keep_orig_idx settings
+      keep_idx = keep_orig_idx_per_feature[t];
+    }
     for (auto i = rowstart + threadIdx.x; i < rowend; i += blockDim.x) {
       // We have use cases using none-hashed raw indices that can be either
       // negative or larger than embedding table hash_size (blk_size *
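For reference, a minimal host-side C++ sketch (not part of this diff; all names illustrative) of how the resolved keep_idx feeds the uniform-bucket arithmetic shown in the surrounding hunks. The block-size quantities are taken as given inputs rather than derived, and the out-of-range remapping branch that the hunk does not show is left out:

// Sketch only: mirrors the uniform-bucket branch from the kernels above.
// keep_idx is the per-feature value resolved as in the hunk above
// (keep_orig_idx_per_feature[t] if provided, otherwise keep_orig_idx).
#include <cstdint>

struct Bucketized {
  int64_t p; // destination bucket (rank)
  int64_t new_idx; // index remapped into that bucket
};

inline Bucketized bucketize_uniform(
    int64_t idx,
    int64_t global_idx_size,
    int64_t local_idx_size,
    int64_t global_num_blks,
    int64_t local_num_blks,
    bool keep_idx) {
  Bucketized out;
  out.p = idx < global_idx_size ? idx / local_idx_size
                                : (idx % global_num_blks) / local_num_blks;
  out.new_idx = idx; // default: keep the original index
  if (!keep_idx && idx < global_idx_size) {
    out.new_idx = idx % local_idx_size; // remap into the local block range
  }
  // (the remapping of out-of-range indices is not shown in this hunk)
  return out;
}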
@@ -233,7 +240,7 @@ __launch_bounds__(kMaxThreads) void _block_bucketize_pooled_sparse_features_cuda
       if (!use_block_bucketize_pos) { // uniform bucket sizes
         p = idx < global_idx_size ? idx / local_idx_size
                                   : (idx % global_num_blks) / local_num_blks;
-        if (keep_orig_idx) {
+        if (keep_idx) {
           new_idx = idx;
         } else if (idx < global_idx_size) {
           new_idx = idx % local_idx_size;
@@ -243,7 +250,7 @@ __launch_bounds__(kMaxThreads) void _block_bucketize_pooled_sparse_features_cuda
       } else { // variable bucket sizes
         uindex_t lb = indices_to_lb[i];
         p = lb < my_size ? lb : idx % my_size;
-        if (keep_orig_idx) {
+        if (keep_idx) {
           new_idx = idx;
         } else if (blk_size == 0) {
           new_idx = idx / global_num_blks;
@@ -307,7 +314,8 @@ __launch_bounds__(kMaxThreads) void _block_bucketize_sequence_sparse_features_cu
     const index_t* const __restrict__ block_bucketize_pos_concat,
     const index_t* const __restrict__ block_bucketize_pos_offsets,
     const index_t* const __restrict__ indices_to_lb,
-    const bool keep_orig_idx) {
+    const bool keep_orig_idx,
+    const bool* const __restrict__ keep_orig_idx_per_feature) {
   using uindex_t = std::make_unsigned_t<index_t>;
   using uoffset_t = std::make_unsigned_t<offset_t>;
   CUDA_KERNEL_LOOP(b_t, lengths_size) {
@@ -324,6 +332,12 @@ __launch_bounds__(kMaxThreads) void _block_bucketize_sequence_sparse_features_cu
     offset_t rowend = offsets_data[b_t];
     const auto use_block_bucketize_pos =
         (block_bucketize_pos_concat != nullptr);
+    auto keep_idx = keep_orig_idx;
+    if (keep_orig_idx_per_feature != nullptr) {
+      // When keep_orig_idx_per_feature is set, override global
+      // keep_orig_idx settings
+      keep_idx = keep_orig_idx_per_feature[t];
+    }
     for (index_t i = rowstart; i < rowend; ++i) {
       // We have use cases using none-hashed raw indices that can be either
       // negative or larger than embedding table hash_size (blk_size *
@@ -337,7 +351,7 @@ __launch_bounds__(kMaxThreads) void _block_bucketize_sequence_sparse_features_cu
       if (!use_block_bucketize_pos) {
         p = idx < global_idx_size ? idx / local_idx_size
                                   : (idx % global_num_blks) / local_num_blks;
-        if (keep_orig_idx) {
+        if (keep_idx) {
           new_idx = idx;
         } else if (idx < global_idx_size) {
           new_idx = idx % local_idx_size;
@@ -347,7 +361,7 @@ __launch_bounds__(kMaxThreads) void _block_bucketize_sequence_sparse_features_cu
       } else {
         uindex_t lb = indices_to_lb[i];
         p = lb < my_size ? lb : idx % my_size;
-        if (keep_orig_idx) {
+        if (keep_idx) {
           new_idx = idx;
         } else if (blk_size == 0) {
           new_idx = idx / global_num_blks;
@@ -455,7 +469,10 @@ __launch_bounds__(kMaxThreads) void _populate_bucketized_permute_cuda_kernel(
           block_bucketize_pos.has_value() \
              ? indices_to_lb.data_ptr<index_t>() \
              : static_cast<index_t*>(nullptr), \
-          keep_orig_idx); \
+          keep_orig_idx, \
+          keep_orig_idx_per_feature.has_value() \
+              ? keep_orig_idx_per_feature->data_ptr<bool>() \
+              : static_cast<bool*>(nullptr)); \
       C10_CUDA_KERNEL_LAUNCH_CHECK(); \
     }); \
   }); \
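The same optional-tensor-to-pointer conversion for keep_orig_idx_per_feature recurs in each launch macro below. As a sketch only (not part of this diff; data_ptr_or_null is a hypothetical name), the pattern in isolation:

// Hypothetical helper for the pattern used in the launch macros: pass a typed
// data pointer when the optional tensor is present, otherwise a typed nullptr
// so the kernel can branch on it.
#include <optional>
#include <ATen/ATen.h>

template <typename T>
T* data_ptr_or_null(const std::optional<at::Tensor>& t) {
  return t.has_value() ? t->data_ptr<T>() : static_cast<T*>(nullptr);
}

// e.g. data_ptr_or_null<bool>(keep_orig_idx_per_feature) mirrors the
// keep_orig_idx_per_feature argument passed to the kernels above.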
@@ -514,7 +531,10 @@ __launch_bounds__(kMaxThreads) void _populate_bucketized_permute_cuda_kernel(
           block_bucketize_pos.has_value() \
              ? indices_to_lb.data_ptr<index_t>() \
              : static_cast<index_t*>(nullptr), \
-          keep_orig_idx); \
+          keep_orig_idx, \
+          keep_orig_idx_per_feature.has_value() \
+              ? keep_orig_idx_per_feature->data_ptr<bool>() \
+              : static_cast<bool*>(nullptr)); \
       C10_CUDA_KERNEL_LAUNCH_CHECK(); \
     }); \
   });
@@ -577,7 +597,10 @@ __launch_bounds__(kMaxThreads) void _populate_bucketized_permute_cuda_kernel(
           block_bucketize_pos.has_value() \
              ? indices_to_lb.data_ptr<index_t>() \
              : static_cast<index_t*>(nullptr), \
-          keep_orig_idx); \
+          keep_orig_idx, \
+          keep_orig_idx_per_feature.has_value() \
+              ? keep_orig_idx_per_feature->data_ptr<bool>() \
+              : static_cast<bool*>(nullptr)); \
       C10_CUDA_KERNEL_LAUNCH_CHECK(); \
     }); \
   }); \
@@ -637,13 +660,17 @@ __launch_bounds__(kMaxThreads) void _populate_bucketized_permute_cuda_kernel(
           block_bucketize_pos.has_value() \
              ? indices_to_lb.data_ptr<index_t>() \
              : static_cast<index_t*>(nullptr), \
-          keep_orig_idx); \
+          keep_orig_idx, \
+          keep_orig_idx_per_feature.has_value() \
+              ? keep_orig_idx_per_feature->data_ptr<bool>() \
+              : static_cast<bool*>(nullptr)); \
       C10_CUDA_KERNEL_LAUNCH_CHECK(); \
     }); \
   });
 
 // This function partitions sparse features
-// continuously along the sparse dimension into my_size blocks
+// continuously along the sparse dimension into
+// my_size blocks
 std::tuple<
     Tensor,
     Tensor,
@@ -664,7 +691,8 @@ _block_bucketize_sparse_features_cuda(
     const int64_t max_B,
     const std::optional<std::vector<at::Tensor>>& block_bucketize_pos,
     const bool return_bucket_mapping,
-    const bool keep_orig_idx) {
+    const bool keep_orig_idx,
+    const std::optional<Tensor>& keep_orig_idx_per_feature) {
   TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL(lengths, indices);
 
   CUDA_DEVICE_GUARD(lengths);
@@ -740,8 +768,9 @@ _block_bucketize_sparse_features_cuda(
     at::Tensor sizes_vec =
         at::tensor(sizes_, at::TensorOptions().dtype(indices_contig.dtype()));
     block_bucketize_pos_offsets = asynchronous_exclusive_cumsum_cpu(
-        sizes_vec); // expect sizes_vec to be a small tensor, using cpu instead
-                    // of gpu for cumsum
+        sizes_vec); // expect sizes_vec to be a
+                    // small tensor, using cpu
+                    // instead of gpu for cumsum
     block_bucketize_pos_offsets = block_bucketize_pos_offsets.to(
         block_bucketize_pos_concat.device(), true);
   }
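For intuition (not taken from this diff), the cumsum above turns the per-feature sizes of the concatenated block_bucketize_pos tensors into start offsets. A plain C++ equivalent of an exclusive cumsum, noting that the exact output layout of asynchronous_exclusive_cumsum_cpu may differ:

// Sketch of an exclusive cumsum over per-feature bucket-boundary counts;
// offsets[i] is where feature i's boundaries start in the concatenated
// buffer. E.g. sizes {3, 2, 4} -> offsets {0, 3, 5}.
#include <cstdint>
#include <vector>

std::vector<int64_t> exclusive_cumsum(const std::vector<int64_t>& sizes) {
  std::vector<int64_t> offsets(sizes.size(), 0);
  int64_t running = 0;
  for (size_t i = 0; i < sizes.size(); ++i) {
    offsets[i] = running;
    running += sizes[i];
  }
  return offsets;
}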
@@ -896,8 +925,8 @@ _block_bucketize_sparse_features_cuda(
 #undef LAUNCH_BLOCK_BUCKETIZE_POOLED_SPARSE_FEATURES_CUDA_KERNEL_2_WITHOUT_WEIGHT
 
 // This function partitions sparse features
-// continuously along the sparse dimension into my_size
-// blocks
+// continuously along the sparse dimension into
+// my_size blocks
 DLL_PUBLIC std::tuple<
     Tensor,
     Tensor,
@@ -916,7 +945,8 @@ block_bucketize_sparse_features_cuda(
     const int64_t max_B,
     const std::optional<std::vector<at::Tensor>>& block_bucketize_pos,
     const bool keep_orig_idx,
-    const std::optional<Tensor>& total_num_blocks) {
+    const std::optional<Tensor>& total_num_blocks,
+    const std::optional<at::Tensor>& keep_orig_idx_per_feature) {
   Tensor new_lengths;
   Tensor new_indices;
   std::optional<Tensor> new_weights;
@@ -942,12 +972,14 @@ block_bucketize_sparse_features_cuda(
       max_B,
       block_bucketize_pos,
       false,
-      keep_orig_idx);
+      keep_orig_idx,
+      keep_orig_idx_per_feature);
   return {new_lengths, new_indices, new_weights, new_pos, unbucketize_permute};
 }
 
 // This function partitions sparse features
-// continuously along the sparse dimension into my_size blocks
+// continuously along the sparse dimension into
+// my_size blocks
 DLL_PUBLIC std::tuple<
     Tensor,
     Tensor,
@@ -968,7 +1000,8 @@ block_bucketize_sparse_features_inference_cuda(
     const std::optional<std::vector<at::Tensor>>& block_bucketize_pos,
     const bool return_bucket_mapping,
     const bool keep_orig_idx,
-    const std::optional<Tensor>& total_num_blocks) {
+    const std::optional<Tensor>& total_num_blocks,
+    const std::optional<at::Tensor>& keep_orig_idx_per_feature) {
   return _block_bucketize_sparse_features_cuda(
       lengths,
       indices,
@@ -982,7 +1015,8 @@ block_bucketize_sparse_features_inference_cuda(
       max_B,
       block_bucketize_pos,
       return_bucket_mapping,
-      keep_orig_idx);
+      keep_orig_idx,
+      keep_orig_idx_per_feature);
 }
 
 DLL_PUBLIC Tensor populate_bucketized_permute_cuda(
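The new parameter threads through three layers in the hunks above: the public entry points (block_bucketize_sparse_features_cuda and block_bucketize_sparse_features_inference_cuda) accept an optional keep_orig_idx_per_feature tensor, forward it to _block_bucketize_sparse_features_cuda, and the launch macros convert it to a raw bool pointer for the kernels, where element t overrides the global keep_orig_idx flag for feature t. A hedged caller-side sketch (names hypothetical, not from this change) of building that optional tensor:

// Caller-side sketch: construct the per-feature override. When std::nullopt
// is passed, the global keep_orig_idx flag applies to every feature. The
// caller would move the tensor to the same CUDA device as the indices before
// invoking the op (not shown here).
#include <ATen/ATen.h>
#include <optional>
#include <vector>

std::optional<at::Tensor> make_keep_orig_idx_per_feature(
    const std::vector<bool>& per_feature_flags) {
  if (per_feature_flags.empty()) {
    return std::nullopt; // fall back to the global keep_orig_idx behavior
  }
  auto t = at::empty(
      {static_cast<int64_t>(per_feature_flags.size())},
      at::TensorOptions().dtype(at::kBool));
  auto* data = t.data_ptr<bool>();
  for (size_t i = 0; i < per_feature_flags.size(); ++i) {
    data[i] = per_feature_flags[i]; // one flag per sparse feature
  }
  return t;
}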