Migrate TBE backward kernels to FBGEMM_LAUNCH_KERNEL #4076

Status: Open · wants to merge 1 commit into base: main
@@ -21,6 +21,7 @@
 #include "fbgemm_gpu/utils/ops_utils.h"
 #include "fbgemm_gpu/utils/tensor_utils.h"
 #include "fbgemm_gpu/utils/assert_macros.h"
+#include "fbgemm_gpu/utils/kernel_launcher.cuh"

 using Tensor = at::Tensor;
 using namespace fbgemm_gpu;
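
For readers unfamiliar with the macro, `fbgemm_gpu/utils/kernel_launcher.cuh` provides `FBGEMM_LAUNCH_KERNEL`, which takes the kernel, the launch configuration (grid, block, shared memory, stream), and the kernel arguments. A minimal sketch of the before/after pattern this PR applies, using a hypothetical `toy_scale_kernel` (the kernel, its arguments, and the launch geometry are illustrative only; `PTA_B` and the `pta::` accessor types are assumed to be available from the fbgemm_gpu utility headers already included in these templates):

```cpp
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>

#include "fbgemm_gpu/utils/kernel_launcher.cuh"

// Hypothetical kernel: scales every element of a 1-D float tensor in place.
__global__ void toy_scale_kernel(
    pta::PackedTensorAccessor32<float, 1, at::RestrictPtrTraits> data,
    float scale) {
  const auto i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < data.size(0)) {
    data[i] *= scale;
  }
}

void scale_inplace(at::Tensor data, float scale) {
  constexpr int kThreads = 256;  // illustrative block size
  // Old style (what the removed lines in this diff look like):
  //   toy_scale_kernel<<<blocks, kThreads, 0, at::cuda::getCurrentCUDAStream()>>>(
  //       MAKE_PTA_WITH_NAME(func_name, data, float, 1, 32), scale);
  //   C10_CUDA_KERNEL_LAUNCH_CHECK();
  //
  // New style: the launcher macro receives the kernel, the launch config, and
  // the arguments; PTA_B builds the bounds-checked accessor without a
  // hand-written func_name, and the post-launch check is handled by the
  // launcher, which is why the explicit C10_CUDA_KERNEL_LAUNCH_CHECK() calls
  // disappear throughout this diff.
  FBGEMM_LAUNCH_KERNEL(
      toy_scale_kernel,
      (data.numel() + kThreads - 1) / kThreads,
      kThreads,
      0,
      at::cuda::getCurrentCUDAStream(),
      PTA_B(data, float, 1, 32),
      scale);
}
```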
@@ -405,53 +406,49 @@ Tensor {{ mdesc }}_embedding_codegen_grad_indice_weights{{ vdesc }}_cuda(
         {%- set kernel_name =
             "{}_embedding_codegen_grad_indice_weights{}_{}kernel".format(
                 mdesc, vdesc, vbdesc)
         %}
-#ifdef FBGEMM_GPU_MEMCHECK
-        const auto func_name = "{{ kernel_name }}";
-#endif
-        {{ kernel_name }}<
-          emb_t,
-          grad_t,
-          cache_t,
-          index_t,
-          kFixedMaxVecsPerThread><<<
+        FBGEMM_LAUNCH_KERNEL(
+          ({{ kernel_name }}<
+            emb_t,
+            grad_t,
+            cache_t,
+            index_t,
+            kFixedMaxVecsPerThread>),
           div_round_up(total_B, kForwardMaxThreads / kWarpSize),
           dim3(kWarpSize, kForwardMaxThreads / kWarpSize),
           0,
-          at::cuda::getCurrentCUDAStream()>>>(
-          MAKE_PTA_WITH_NAME(func_name, grad_output_reshaped, grad_t, 2, 64),
-          MAKE_PTA_WITH_NAME(func_name, dev_weights, emb_t, 1, 64),
+          at::cuda::getCurrentCUDAStream(),
+          PTA_B(grad_output_reshaped, grad_t, 2, 64),
+          PTA_B(dev_weights, emb_t, 1, 64),
           {%- if not dense %}
-          MAKE_PTA_WITH_NAME(func_name, uvm_weights, emb_t, 1, 64),
-          MAKE_PTA_WITH_NAME(func_name, lxu_cache_weights, cache_t, 2, 64),
-          MAKE_PTA_WITH_NAME(func_name, weights_placements, int32_t, 1, 32),
+          PTA_B(uvm_weights, emb_t, 1, 64),
+          PTA_B(lxu_cache_weights, cache_t, 2, 64),
+          PTA_B(weights_placements, int32_t, 1, 32),
          {%- endif %}
-          MAKE_PTA_WITH_NAME(func_name, weights_offsets, int64_t, 1, 32),
-          MAKE_PTA_WITH_NAME(func_name, D_offsets, int32_t, 1, 32),
-          MAKE_PTA_WITH_NAME(func_name, indices, index_t, 1, 32),
-          MAKE_PTA_WITH_NAME(func_name, offsets, index_t, 1, 32),
+          PTA_B(weights_offsets, int64_t, 1, 32),
+          PTA_B(D_offsets, int32_t, 1, 32),
+          PTA_B(indices, index_t, 1, 32),
+          PTA_B(offsets, index_t, 1, 32),
          {%- if not dense %}
-          MAKE_PTA_WITH_NAME(func_name, {{ locs_or_addrs_tensor }}, {{ locs_or_addrs_type }}, 1, 32),
+          PTA_B({{ locs_or_addrs_tensor }}, {{ locs_or_addrs_type }}, 1, 32),
          {%- endif %}
-          MAKE_PTA_WITH_NAME(func_name, feature_requires_grad_, int32_t, 1, 32),
-          MAKE_PTA_ACC_WITH_NAME(func_name, grad_indice_weights, grad_t, 1, 32),
+          PTA_B(feature_requires_grad_, int32_t, 1, 32),
+          PTA_ACC_B(grad_indice_weights, grad_t, 1, 32),
          {%- if vbe %}
-          MAKE_PTA_WITH_NAME(func_name, vbe_row_output_offsets, int64_t, 1, 32),
-          MAKE_PTA_WITH_NAME(func_name, vbe_b_t_map, int32_t, 1, 32),
+          PTA_B(vbe_row_output_offsets, int64_t, 1, 32),
+          PTA_B(vbe_b_t_map, int32_t, 1, 32),
           info_B_num_bits,
           info_B_mask
          {%- else %}
           FixedDivisor(total_B / T)
          {%- endif %}
         );
-        C10_CUDA_KERNEL_LAUNCH_CHECK();
         return;
       });
     {%- endfor %} {# /* for use_vec_blocking */ #}
   });
 });

-  C10_CUDA_KERNEL_LAUNCH_CHECK();
   return grad_indice_weights;
 }
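
One detail in the hunk above: the kernel is passed as `({{ kernel_name }}<emb_t, grad_t, cache_t, index_t, kFixedMaxVecsPerThread>)`, i.e. wrapped in an extra pair of parentheses. Because `FBGEMM_LAUNCH_KERNEL` is a preprocessor macro, the commas inside a template argument list would otherwise be treated as macro-argument separators. A hypothetical sketch of the same idiom (the `add_one_kernel` below is illustrative, not part of this PR):

```cpp
#include <ATen/cuda/CUDAContext.h>
#include "fbgemm_gpu/utils/kernel_launcher.cuh"

// Hypothetical templated kernel; the template argument list contains a comma.
template <typename index_t, typename val_t>
__global__ void add_one_kernel(index_t n, val_t* data) {
  const index_t i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    data[i] += static_cast<val_t>(1);
  }
}

void add_one(int64_t n, float* data) {
  // FBGEMM_LAUNCH_KERNEL(add_one_kernel<int64_t, float>, ...) would not
  // preprocess cleanly: the comma inside <int64_t, float> splits the first
  // macro argument in two.  Parenthesizing the kernel expression keeps it as
  // a single argument, which is why the templated launches in this PR are
  // written as ({{ kernel_name }}<...>).
  FBGEMM_LAUNCH_KERNEL(
      (add_one_kernel<int64_t, float>),
      (n + 255) / 256,
      256,
      0,
      at::cuda::getCurrentCUDAStream(),
      n,
      data);
}
```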

@@ -794,34 +794,22 @@ Tensor {{ embedding_cuda_op }}(
     // {{ locs_or_addrs_tensor }} run ids and sorted_linear_indices run ids.
     auto dev_or_uvm_unique_indices = at::zeros_like(weights_placements);

-    DEBUG_KERNEL_BARRIER_ISOLATE([&] {
-#ifdef FBGEMM_GPU_MEMCHECK
-    const auto func_name = "split_embedding_backward_count_unique_indices_kernel";
-#endif
-    split_embedding_backward_count_unique_indices_kernel<
-        {{ "int64_t" if nobag else "int32_t" }},
-        {{ "int64_t" if nobag else "uint32_t" }},
-        {{ "true" if nobag else "false" }}
-    ><<<
-        div_round_up(total_unique_indices, kMaxThreads),
-        kMaxThreads,
-        0,
-        at::cuda::getCurrentCUDAStream()
-    >>>(
-        MAKE_PTA_WITH_NAME(
-            func_name, sorted_linear_indices_num_runs, int32_t, 1, 32),
-        MAKE_PTA_WITH_NAME(
-            func_name, sorted_linear_indices_cumulative_run_lengths, int32_t, 1, 32),
-        MAKE_PTA_WITH_NAME(
-            func_name, infos_sorted, {{ "int64_t" if nobag else "int32_t" }}, 1, 32),
-        MAKE_PTA_WITH_NAME(
-            func_name, weights_placements, int32_t, 1, 32),
-        MAKE_PTA_WITH_NAME(
-            func_name, dev_or_uvm_unique_indices, int32_t, 1, 32),
-        info_B_num_bits
-    );
-    C10_CUDA_KERNEL_LAUNCH_CHECK();
-    }); // DEBUG_KERNEL_BARRIER_ISOLATE
+    FBGEMM_LAUNCH_KERNEL(
+        (split_embedding_backward_count_unique_indices_kernel<
+            {{ "int64_t" if nobag else "int32_t" }},
+            {{ "int64_t" if nobag else "uint32_t" }},
+            {{ "true" if nobag else "false" }}
+        >),
+        div_round_up(total_unique_indices, kMaxThreads),
+        kMaxThreads,
+        0,
+        at::cuda::getCurrentCUDAStream(),
+        PTA_B(sorted_linear_indices_num_runs, int32_t, 1, 32),
+        PTA_B(sorted_linear_indices_cumulative_run_lengths, int32_t, 1, 32),
+        PTA_B(infos_sorted, {{ "int64_t" if nobag else "int32_t" }}, 1, 32),
+        PTA_B(weights_placements, int32_t, 1, 32),
+        PTA_B(dev_or_uvm_unique_indices, int32_t, 1, 32), info_B_num_bits
+    );

     table_unique_indices_offsets =
         fbgemm_gpu::asynchronous_complete_cumsum_gpu(dev_or_uvm_unique_indices).to(at::kInt);
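
The launch configuration itself is unchanged by the migration: `div_round_up(total_unique_indices, kMaxThreads)` blocks of `kMaxThreads` threads, exactly as before. Assuming `div_round_up` has the usual ceiling-division semantics, a quick sanity check of that grid sizing (the value of `kMaxThreads` below is illustrative):

```cpp
#include <cstdint>

// Presumed semantics of the div_round_up helper: ceiling division.
constexpr int64_t div_round_up(int64_t n, int64_t d) {
  return (n + d - 1) / d;
}

constexpr int64_t kMaxThreads = 1024;  // illustrative block size

// Just enough blocks so that blocks * kMaxThreads covers every work item.
static_assert(div_round_up(5000, kMaxThreads) == 5, "ceil(5000 / 1024) == 5");
static_assert(div_round_up(5000, kMaxThreads) * kMaxThreads >= 5000,
              "grid covers all total_unique_indices work items");
```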
@@ -944,34 +932,27 @@ Tensor {{ embedding_cuda_op }}(
     Tensor grad_output_mean;
     if (static_cast<PoolingMode>(pooling_mode) == PoolingMode::MEAN) {
       grad_output_mean = at::empty_like(grad_output_reshaped);
     {%- if not dense or not vbe %}
-
-    DEBUG_KERNEL_BARRIER_ISOLATE([&] {
-#ifdef FBGEMM_GPU_MEMCHECK
-    const auto func_name1 = "grad_mean{{ vdesc }}_kernel";
-#endif
-    grad_mean{{ vdesc }}_kernel<<<
-        div_round_up(total_B, kMaxThreads / kWarpSize),
-        dim3(kWarpSize, kMaxThreads / kWarpSize),
-        0,
-        at::cuda::getCurrentCUDAStream()>>>
-        (
-        MAKE_PTA_WITH_NAME(func_name1, grad_output_mean, grad_t, 2, 64),
-        MAKE_PTA_WITH_NAME(func_name1, grad_output_reshaped, grad_t, 2, 64),
-        MAKE_PTA_WITH_NAME(func_name1, D_offsets, int32_t, 1, 32),
-        MAKE_PTA_WITH_NAME(func_name1, offsets, index_t, 1, 32),
-        {%- if vbe %}
-        MAKE_PTA_WITH_NAME(func_name1, vbe_row_output_offsets, int64_t, 1, 32),
-        MAKE_PTA_WITH_NAME(func_name1, vbe_b_t_map, int32_t, 1, 32),
-        info_B_num_bits,
-        info_B_mask
-        {%- else %}
-        FixedDivisor(total_B / T)
-        {%- endif %}
-        );
-
-    C10_CUDA_KERNEL_LAUNCH_CHECK();
-    }); // DEBUG_KERNEL_BARRIER_ISOLATE
+    FBGEMM_LAUNCH_KERNEL(
+        (grad_mean{{ vdesc }}_kernel<grad_t, index_t>),
+        div_round_up(total_B, kMaxThreads / kWarpSize),
+        dim3(kWarpSize, kMaxThreads / kWarpSize),
+        0,
+        at::cuda::getCurrentCUDAStream(),
+        PTA_B(grad_output_mean, grad_t, 2, 64),
+        PTA_B(grad_output_reshaped, grad_t, 2, 64),
+        PTA_B(D_offsets, int32_t, 1, 32),
+        PTA_B(offsets, index_t, 1, 32),
+        {%- if vbe %}
+        PTA_B(vbe_row_output_offsets, int64_t, 1, 32),
+        PTA_B(vbe_b_t_map, int32_t, 1, 32),
+        info_B_num_bits,
+        info_B_mask
+        {%- else %}
+        FixedDivisor(total_B / T)
+        {%- endif %}
+        );
     {%- endif %} // if not dense or not vbe

     grad_output_accessor = MAKE_PTA_WITH_NAME("{{ embedding_cuda_op }}.2", grad_output_mean, grad_t, 2, 64);
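
A subtle change in the hunk above: the old `<<<...>>>` launch spelled the kernel as plain `grad_mean{{ vdesc }}_kernel` and let the compiler deduce the template arguments from the call, whereas the new call passes `(grad_mean{{ vdesc }}_kernel<grad_t, index_t>)` with the instantiation written out. When a kernel template is handed to the launcher macro as an ordinary expression, deduction is no longer available, so the template arguments must be explicit. A hypothetical illustration (the `mean_copy_kernel` below is illustrative only):

```cpp
#include <ATen/cuda/CUDAContext.h>
#include "fbgemm_gpu/utils/kernel_launcher.cuh"

// Hypothetical kernel template used only to illustrate the launch syntax.
template <typename grad_t, typename index_t>
__global__ void mean_copy_kernel(grad_t* out, const grad_t* in, index_t n) {
  const index_t i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    out[i] = in[i];
  }
}

void launch_mean_copy(float* out, const float* in, int64_t n) {
  // With the triple-chevron syntax, mean_copy_kernel<<<grid, block>>>(out, in, n)
  // could deduce <float, int64_t> from the call arguments.  Passed as a value
  // to FBGEMM_LAUNCH_KERNEL, the template cannot be deduced, so the
  // instantiation is spelled out (and parenthesized because of the comma).
  FBGEMM_LAUNCH_KERNEL(
      (mean_copy_kernel<float, int64_t>),
      (n + 255) / 256,
      256,
      0,
      at::cuda::getCurrentCUDAStream(),
      out,
      in,
      n);
}
```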
@@ -1012,29 +993,23 @@ Tensor {{ embedding_cuda_op }}(
         use_deterministic_algorithms ? 0 : (indices.numel() / max_segment_length_per_cta),
         indices.options().dtype(at::kInt));

-    DEBUG_KERNEL_BARRIER_ISOLATE([&] {
-#ifdef FBGEMM_GPU_MEMCHECK
-    const auto func_name2 = "split_embedding_backward_codegen_find_long_segments";
-#endif
-
-    split_embedding_backward_codegen_find_long_segments<<<
-        div_round_up(total_unique_indices, kMaxThreads),
-        kMaxThreads,
-        0,
-        at::cuda::getCurrentCUDAStream()
-    >>>(
-        MAKE_PTA_WITH_NAME(func_name2, sorted_linear_indices_num_runs, int32_t, 1, 32),
-        MAKE_PTA_WITH_NAME(func_name2, sorted_linear_indices_run_lengths, int32_t, 1, 32),
-        MAKE_PTA_WITH_NAME(func_name2, long_run_ids, int32_t, 1, 32),
-        MAKE_PTA_WITH_NAME(func_name2, num_long_run_ids, int32_t, 1, 32),
-        MAKE_PTA_WITH_NAME(func_name2, long_run_id_to_really_long_run_ids, int32_t, 1, 32),
-        MAKE_PTA_WITH_NAME(func_name2, num_really_long_run_ids, int32_t, 1, 32),
-        MAKE_PTA_WITH_NAME(func_name2, grad_accum_counter, int32_t, 1, 32),
-        max_segment_length_per_warp,
-        max_segment_length_per_cta,
-        use_deterministic_algorithms);
-    C10_CUDA_KERNEL_LAUNCH_CHECK();
-    }); // DEBUG_KERNEL_BARRIER_ISOLATE
+    FBGEMM_LAUNCH_KERNEL(
+        split_embedding_backward_codegen_find_long_segments,
+        div_round_up(total_unique_indices, kMaxThreads),
+        kMaxThreads,
+        0,
+        at::cuda::getCurrentCUDAStream(),
+        PTA_B(sorted_linear_indices_num_runs, int32_t, 1, 32),
+        PTA_B(sorted_linear_indices_run_lengths, int32_t, 1, 32),
+        PTA_B(long_run_ids, int32_t, 1, 32),
+        PTA_B(num_long_run_ids, int32_t, 1, 32),
+        PTA_B(long_run_id_to_really_long_run_ids, int32_t, 1, 32),
+        PTA_B(num_really_long_run_ids, int32_t, 1, 32),
+        PTA_B(grad_accum_counter, int32_t, 1, 32),
+        max_segment_length_per_warp,
+        max_segment_length_per_cta,
+        use_deterministic_algorithms
+    );

     // A temp buffer to accumulate gradients with atomics.
     auto temp_grad_accum = at::zeros(
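
In contrast to the templated launches above, `split_embedding_backward_codegen_find_long_segments` is a non-templated kernel, so its name is passed to the macro bare, and plain scalars such as `max_segment_length_per_warp` are forwarded unchanged next to the `PTA_B` accessor arguments. A small hypothetical sketch of that call shape (the kernel, tensors, and threshold below are illustrative; `PTA_B` and the `pta::` accessor types are assumed to come from the fbgemm_gpu utility headers):

```cpp
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include "fbgemm_gpu/utils/kernel_launcher.cuh"

// Hypothetical non-templated kernel mixing accessor arguments with a scalar.
__global__ void count_long_runs_kernel(
    pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> run_lengths,
    pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> num_long_runs,
    int32_t max_segment_length) {
  const auto i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < run_lengths.size(0) && run_lengths[i] > max_segment_length) {
    atomicAdd(&num_long_runs[0], 1);
  }
}

void count_long_runs(
    const at::Tensor& run_lengths,
    at::Tensor& num_long_runs,
    int32_t max_segment_length) {
  constexpr int kMaxThreads = 1024;  // illustrative block size
  // A non-templated kernel needs no extra parentheses, and scalar arguments
  // ride along as-is next to the PTA_B accessor builders.
  FBGEMM_LAUNCH_KERNEL(
      count_long_runs_kernel,
      (run_lengths.numel() + kMaxThreads - 1) / kMaxThreads,
      kMaxThreads,
      0,
      at::cuda::getCurrentCUDAStream(),
      PTA_B(run_lengths, int32_t, 1, 32),
      PTA_B(num_long_runs, int32_t, 1, 32),
      max_segment_length);
}
```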