Add keep_orig_idx_per_feature parameter to block_bucketize_sparse_features kernel #4027

Closed
wants to merge 1 commit into from
25 changes: 24 additions & 1 deletion fbgemm_gpu/fbgemm_gpu/docs/sparse_ops.py
@@ -323,7 +323,7 @@
add_docs(
torch.ops.fbgemm.block_bucketize_sparse_features,
"""
block_bucketize_sparse_features(lengths, indices, bucketize_pos, sequence, block_sizes, my_size, weights=None, batch_size_per_feature=None, max_B= -1, block_bucketize_pos=None, keep_orig_idx=False) -> Tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor]]
block_bucketize_sparse_features(lengths, indices, bucketize_pos, sequence, block_sizes, my_size, weights=None, batch_size_per_feature=None, max_B= -1, block_bucketize_pos=None, keep_orig_idx=False, total_num_blocks=None, keep_orig_idx_per_feature=None) -> Tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor]]

Preprocess sparse features by partitioning sparse features into multiple
buckets. Every feature is split into the same number of buckets, but the bucket
@@ -387,6 +387,11 @@
where we scale up/down the number of GPUs but want to maintain
same numerical behavior.

keep_orig_idx_per_feature (Optional[Tensor] = None): An optional boolean tensor
that specifies, per feature, whether to keep the original indices. If
provided, the operator uses its entries to decide whether to keep the
original indices for each feature; if None, it falls back to `keep_orig_idx`
for all features.

Return:
A tuple of tensors containing

@@ -448,6 +453,24 @@
dtype=torch.int32),
tensor([ 0, 1, 5, 2, 6, 7, 3, 8, 9, 10, 11, 4, 12, 13, 14],
device='cuda:0', dtype=torch.int32))
>>> # Invoke with keep_orig_idx_per_feature
>>> keep_orig_idx_per_feature = torch.tensor([False, True, False, True], dtype=torch.bool)
>>> torch.ops.fbgemm.block_bucketize_sparse_features(
>>> lengths,
>>> indices,
>>> bucketize_pos=False,
>>> sequence=False,
>>> block_sizes=block_sizes,
>>> my_size=my_size,
>>> keep_orig_idx=False,
>>> keep_orig_idx_per_feature=keep_orig_idx_per_feature)
(tensor([0, 0, 0, 1, 1, 1, 2, 1, 0, 2, 1, 2, 1, 2, 1, 0], device='cuda:0',
dtype=torch.int32),
tensor([ 3, 4, 11, 1, 11, 15, 28, 29, 0, 1, 2, 3, 22, 20, 20],
device='cuda:0', dtype=torch.int32),
None,
None,
None)
>>> # Invoke with block_bucketize_pos
>>> block_bucketize_pos = [
>>> torch.tensor([0, 2, 8], dtype=torch.int),
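To make the override rule of the new argument concrete, here is a small pure-Python sketch. The helper `effective_keep_orig_idx` is hypothetical (not part of FBGEMM); it only illustrates the documented behavior that an entry of `keep_orig_idx_per_feature`, when the tensor is provided, takes precedence over the global `keep_orig_idx` for that feature.

```python
from typing import Optional

import torch


def effective_keep_orig_idx(
    keep_orig_idx: bool,
    keep_orig_idx_per_feature: Optional[torch.Tensor],
    feature_idx: int,
) -> bool:
    # Hypothetical helper: when the per-feature tensor is provided, its entry
    # overrides the global keep_orig_idx flag for that feature.
    if keep_orig_idx_per_feature is not None:
        return bool(keep_orig_idx_per_feature[feature_idx].item())
    return keep_orig_idx


# Matching the docstring example above: keep_orig_idx=False combined with
# per-feature flags [False, True, False, True] keeps the original indices
# only for features 1 and 3.
flags = torch.tensor([False, True, False, True], dtype=torch.bool)
print([effective_keep_orig_idx(False, flags, t) for t in range(4)])
# -> [False, True, False, True]
```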
1 change: 1 addition & 0 deletions fbgemm_gpu/fbgemm_gpu/sparse_ops.py
@@ -463,6 +463,7 @@ def block_bucketize_sparse_features_meta(
block_bucketize_pos: Optional[torch.Tensor] = None,
keep_orig_idx: bool = False,
total_num_blocks: Optional[torch.Tensor] = None,
keep_orig_idx_per_feature: Optional[torch.Tensor] = None,
) -> Tuple[
torch.Tensor,
torch.Tensor,
12 changes: 8 additions & 4 deletions fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h
@@ -192,7 +192,8 @@ block_bucketize_sparse_features_cuda(
const int64_t max_batch_size,
const std::optional<std::vector<at::Tensor>>& block_bucketize_pos,
const bool keep_orig_idx,
const std::optional<at::Tensor>& total_num_blocks);
const std::optional<at::Tensor>& total_num_blocks,
const std::optional<at::Tensor>& keep_orig_idx_per_feature);

std::tuple<
at::Tensor,
@@ -214,7 +215,8 @@ block_bucketize_sparse_features_cpu(
const int64_t max_batch_size,
const std::optional<std::vector<at::Tensor>>& block_bucketize_pos,
const bool keep_orig_idx,
const std::optional<at::Tensor>& total_num_blocks);
const std::optional<at::Tensor>& total_num_blocks,
const std::optional<at::Tensor>& keep_orig_idx_per_feature);

std::tuple<
at::Tensor,
@@ -237,7 +239,8 @@ block_bucketize_sparse_features_inference_cuda(
const std::optional<std::vector<at::Tensor>>& block_bucketize_pos,
const bool return_bucket_mapping,
const bool keep_orig_idx,
const std::optional<at::Tensor>& total_num_blocks);
const std::optional<at::Tensor>& total_num_blocks,
const std::optional<at::Tensor>& keep_orig_idx_per_feature);

///@ingroup sparse-data-cuda
at::Tensor populate_bucketized_permute_cuda(
@@ -267,7 +270,8 @@ block_bucketize_sparse_features_inference_cpu(
const std::optional<std::vector<at::Tensor>>& block_bucketize_pos,
const bool return_bucket_mapping,
const bool keep_orig_idx,
const std::optional<at::Tensor>& total_num_blocks);
const std::optional<at::Tensor>& total_num_blocks,
const std::optional<at::Tensor>& keep_orig_idx_per_feature);

///@ingroup sparse-data-cpu
at::Tensor populate_bucketized_permute_cpu(
76 changes: 55 additions & 21 deletions fbgemm_gpu/src/sparse_ops/sparse_block_bucketize_features.cu
@@ -191,7 +191,8 @@ __launch_bounds__(kMaxThreads) void _block_bucketize_pooled_sparse_features_cuda
const index_t* const __restrict__ block_bucketize_pos_concat,
const index_t* const __restrict__ block_bucketize_pos_offsets,
const index_t* const __restrict__ indices_to_lb,
const bool keep_orig_idx) {
const bool keep_orig_idx,
const bool* const __restrict__ keep_orig_idx_per_feature) {
using uindex_t = std::make_unsigned_t<index_t>;
const auto bt_start = blockIdx.x * blockDim.y + threadIdx.y;
const auto stride = gridDim.x * blockDim.y;
@@ -220,6 +221,12 @@ __launch_bounds__(kMaxThreads) void _block_bucketize_pooled_sparse_features_cuda
total_num_blocks == nullptr ? my_size : total_num_blocks[t];
const index_t global_idx_size = blk_size * global_num_blks;
const index_t local_idx_size = blk_size * local_num_blks;
auto keep_idx = keep_orig_idx;
if (keep_orig_idx_per_feature != nullptr) {
// When keep_orig_idx_per_feature is set, override the global
// keep_orig_idx setting for this feature
keep_idx = keep_orig_idx_per_feature[t];
}
for (auto i = rowstart + threadIdx.x; i < rowend; i += blockDim.x) {
// We have use cases using none-hashed raw indices that can be either
// negative or larger than embedding table hash_size (blk_size *
@@ -233,7 +240,7 @@ __launch_bounds__(kMaxThreads) void _block_bucketize_pooled_sparse_features_cuda
if (!use_block_bucketize_pos) { // uniform bucket sizes
p = idx < global_idx_size ? idx / local_idx_size
: (idx % global_num_blks) / local_num_blks;
if (keep_orig_idx) {
if (keep_idx) {
new_idx = idx;
} else if (idx < global_idx_size) {
new_idx = idx % local_idx_size;
@@ -243,7 +250,7 @@ __launch_bounds__(kMaxThreads) void _block_bucketize_pooled_sparse_features_cuda
} else { // variable bucket sizes
uindex_t lb = indices_to_lb[i];
p = lb < my_size ? lb : idx % my_size;
if (keep_orig_idx) {
if (keep_idx) {
new_idx = idx;
} else if (blk_size == 0) {
new_idx = idx / global_num_blks;
@@ -307,7 +314,8 @@ __launch_bounds__(kMaxThreads) void _block_bucketize_sequence_sparse_features_cu
const index_t* const __restrict__ block_bucketize_pos_concat,
const index_t* const __restrict__ block_bucketize_pos_offsets,
const index_t* const __restrict__ indices_to_lb,
const bool keep_orig_idx) {
const bool keep_orig_idx,
const bool* const __restrict__ keep_orig_idx_per_feature) {
using uindex_t = std::make_unsigned_t<index_t>;
using uoffset_t = std::make_unsigned_t<offset_t>;
CUDA_KERNEL_LOOP(b_t, lengths_size) {
@@ -324,6 +332,12 @@ __launch_bounds__(kMaxThreads) void _block_bucketize_sequence_sparse_features_cu
offset_t rowend = offsets_data[b_t];
const auto use_block_bucketize_pos =
(block_bucketize_pos_concat != nullptr);
auto keep_idx = keep_orig_idx;
if (keep_orig_idx_per_feature != nullptr) {
// When keep_orig_idx_per_feature is set, override the global
// keep_orig_idx setting for this feature
keep_idx = keep_orig_idx_per_feature[t];
}
for (index_t i = rowstart; i < rowend; ++i) {
// We have use cases using none-hashed raw indices that can be either
// negative or larger than embedding table hash_size (blk_size *
@@ -337,7 +351,7 @@ __launch_bounds__(kMaxThreads) void _block_bucketize_sequence_sparse_features_cu
if (!use_block_bucketize_pos) {
p = idx < global_idx_size ? idx / local_idx_size
: (idx % global_num_blks) / local_num_blks;
if (keep_orig_idx) {
if (keep_idx) {
new_idx = idx;
} else if (idx < global_idx_size) {
new_idx = idx % local_idx_size;
@@ -347,7 +361,7 @@ __launch_bounds__(kMaxThreads) void _block_bucketize_sequence_sparse_features_cu
} else {
uindex_t lb = indices_to_lb[i];
p = lb < my_size ? lb : idx % my_size;
if (keep_orig_idx) {
if (keep_idx) {
new_idx = idx;
} else if (blk_size == 0) {
new_idx = idx / global_num_blks;
@@ -455,7 +469,10 @@ __launch_bounds__(kMaxThreads) void _populate_bucketized_permute_cuda_kernel(
block_bucketize_pos.has_value() \
? indices_to_lb.data_ptr<index_t>() \
: static_cast<index_t*>(nullptr), \
keep_orig_idx); \
keep_orig_idx, \
keep_orig_idx_per_feature.has_value() \
? keep_orig_idx_per_feature->data_ptr<bool>() \
: static_cast<bool*>(nullptr)); \
C10_CUDA_KERNEL_LAUNCH_CHECK(); \
}); \
}); \
@@ -514,7 +531,10 @@ __launch_bounds__(kMaxThreads) void _populate_bucketized_permute_cuda_kernel(
block_bucketize_pos.has_value() \
? indices_to_lb.data_ptr<index_t>() \
: static_cast<index_t*>(nullptr), \
keep_orig_idx); \
keep_orig_idx, \
keep_orig_idx_per_feature.has_value() \
? keep_orig_idx_per_feature->data_ptr<bool>() \
: static_cast<bool*>(nullptr)); \
C10_CUDA_KERNEL_LAUNCH_CHECK(); \
}); \
});
@@ -577,7 +597,10 @@ __launch_bounds__(kMaxThreads) void _populate_bucketized_permute_cuda_kernel(
block_bucketize_pos.has_value() \
? indices_to_lb.data_ptr<index_t>() \
: static_cast<index_t*>(nullptr), \
keep_orig_idx); \
keep_orig_idx, \
keep_orig_idx_per_feature.has_value() \
? keep_orig_idx_per_feature->data_ptr<bool>() \
: static_cast<bool*>(nullptr)); \
C10_CUDA_KERNEL_LAUNCH_CHECK(); \
}); \
}); \
@@ -637,13 +660,17 @@ __launch_bounds__(kMaxThreads) void _populate_bucketized_permute_cuda_kernel(
block_bucketize_pos.has_value() \
? indices_to_lb.data_ptr<index_t>() \
: static_cast<index_t*>(nullptr), \
keep_orig_idx); \
keep_orig_idx, \
keep_orig_idx_per_feature.has_value() \
? keep_orig_idx_per_feature->data_ptr<bool>() \
: static_cast<bool*>(nullptr)); \
C10_CUDA_KERNEL_LAUNCH_CHECK(); \
}); \
});

// This function partitions sparse features
// continuously along the sparse dimension into my_size blocks
// continuously along the sparse dimension into
// my_size blocks
std::tuple<
Tensor,
Tensor,
@@ -664,7 +691,8 @@ _block_bucketize_sparse_features_cuda(
const std::optional<std::vector<at::Tensor>>& block_bucketize_pos,
const bool return_bucket_mapping,
const bool keep_orig_idx) {
const bool keep_orig_idx,
const std::optional<Tensor>& keep_orig_idx_per_feature) {
TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL(lengths, indices);

CUDA_DEVICE_GUARD(lengths);
Expand Down Expand Up @@ -740,8 +768,9 @@ _block_bucketize_sparse_features_cuda(
at::Tensor sizes_vec =
at::tensor(sizes_, at::TensorOptions().dtype(indices_contig.dtype()));
block_bucketize_pos_offsets = asynchronous_exclusive_cumsum_cpu(
sizes_vec); // expect sizes_vec to be a small tensor, using cpu instead
// of gpu for cumsum
sizes_vec); // expect sizes_vec to be a
// small tensor, using cpu
// instead of gpu for cumsum
block_bucketize_pos_offsets = block_bucketize_pos_offsets.to(
block_bucketize_pos_concat.device(), true);
}
Expand Down Expand Up @@ -896,8 +925,8 @@ _block_bucketize_sparse_features_cuda(
#undef LAUNCH_BLOCK_BUCKETIZE_POOLED_SPARSE_FEATURES_CUDA_KERNEL_2_WITHOUT_WEIGHT

// This function partitions sparse features
// continuously along the sparse dimension into my_size
// blocks
// continuously along the sparse dimension into
// my_size blocks
DLL_PUBLIC std::tuple<
Tensor,
Tensor,
@@ -916,7 +945,8 @@ block_bucketize_sparse_features_cuda(
const std::optional<std::vector<at::Tensor>>& block_bucketize_pos,
const bool keep_orig_idx,
const std::optional<Tensor>& total_num_blocks) {
const std::optional<Tensor>& total_num_blocks,
const std::optional<at::Tensor>& keep_orig_idx_per_feature) {
Tensor new_lengths;
Tensor new_indices;
std::optional<Tensor> new_weights;
Expand All @@ -942,12 +972,14 @@ block_bucketize_sparse_features_cuda(
max_B,
block_bucketize_pos,
false,
keep_orig_idx);
keep_orig_idx,
keep_orig_idx_per_feature);
return {new_lengths, new_indices, new_weights, new_pos, unbucketize_permute};
}

// This function partitions sparse features
// continuously along the sparse dimension into my_size blocks
// continuously along the sparse dimension into
// my_size blocks
DLL_PUBLIC std::tuple<
Tensor,
Tensor,
@@ -968,7 +1000,8 @@ block_bucketize_sparse_features_inference_cuda(
const bool return_bucket_mapping,
const bool keep_orig_idx,
const std::optional<Tensor>& total_num_blocks) {
const std::optional<Tensor>& total_num_blocks,
const std::optional<at::Tensor>& keep_orig_idx_per_feature) {
return _block_bucketize_sparse_features_cuda(
lengths,
indices,
Expand All @@ -982,7 +1015,8 @@ block_bucketize_sparse_features_inference_cuda(
max_B,
block_bucketize_pos,
return_bucket_mapping,
keep_orig_idx);
keep_orig_idx,
keep_orig_idx_per_feature);
}

DLL_PUBLIC Tensor populate_bucketized_permute_cuda(
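For reference, the index math of the uniform-bucket branch touched above can be sketched in plain Python. This is an illustrative reconstruction, not the kernel itself: it assumes `total_num_blocks` is None, so `global_num_blks == my_size` and (as an assumption about the default path) `local_num_blks == 1`, and it omits the unsigned reinterpretation the CUDA code applies to negative raw indices.

```python
from typing import Optional, Sequence, Tuple


def bucketize_index_uniform(
    idx: int,
    blk_size: int,  # block_sizes[t] for feature t
    my_size: int,   # number of buckets / ranks
    keep_orig_idx: bool,
    keep_orig_idx_per_feature: Optional[Sequence[bool]],
    t: int,         # feature id
) -> Tuple[int, int]:
    """Return (bucket p, bucketized index) for one index of feature t.

    Sketch of the uniform-bucket branch, assuming total_num_blocks is None
    (global_num_blks == my_size, local_num_blks == 1); negative-index
    handling is omitted.
    """
    global_num_blks = my_size
    local_num_blks = 1
    global_idx_size = blk_size * global_num_blks
    local_idx_size = blk_size * local_num_blks

    # The per-feature flag, when given, overrides the global setting.
    keep_idx = keep_orig_idx
    if keep_orig_idx_per_feature is not None:
        keep_idx = keep_orig_idx_per_feature[t]

    # Destination bucket.
    if idx < global_idx_size:
        p = idx // local_idx_size
    else:
        p = (idx % global_num_blks) // local_num_blks

    # Bucketized index: kept as-is or remapped into the local range.
    if keep_idx:
        new_idx = idx
    elif idx < global_idx_size:
        new_idx = idx % local_idx_size
    else:
        new_idx = idx // global_num_blks
    return p, new_idx


# With blk_size=10 and my_size=2, index 13 lands in bucket 1; it is
# re-indexed to 3 for a feature whose flag is False, but kept as 13 for a
# feature whose flag is True.
print(bucketize_index_uniform(13, 10, 2, False, [False, True], t=0))  # (1, 3)
print(bucketize_index_uniform(13, 10, 2, False, [False, True], t=1))  # (1, 13)
```

When `keep_orig_idx_per_feature` is None, the sketch reduces to the pre-existing behavior controlled solely by `keep_orig_idx`.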