
Commit 0911c94

q10 authored and facebook-github-bot committed
Use CUDAStream instead of cudaStream_t in kernel_launcher (#4071)
Summary:
Pull Request resolved: #4071

Use CUDAStream, which provides the stream ID and device_idx without the use of special APIs. at::cuda::getCurrentCUDAStream() at callsites returns a CUDAStream that is then implicitly converted to cudaStream_t, but there is no reason to do that.

Reviewed By: ngimel

Differential Revision: D74051709

fbshipit-source-id: 52ce540d1e86e019aa369e9b8910bd71f4316c51
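As an illustration of the point above (a minimal sketch, not part of this commit; the function name stream_metadata_example is hypothetical), a CUDAStream already carries both pieces of metadata and still converts to a raw handle wherever one is required:

#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAStream.h>
#include <cuda_runtime.h>

inline void stream_metadata_example() {
  // Returns a c10::cuda::CUDAStream, not a raw cudaStream_t.
  const c10::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream();

  // Stream ID and device index are plain member reads; no special
  // CUDA runtime queries or memoization are needed.
  const int64_t stream_id = stream.id();
  const auto device_idx = stream.device_index();

  // The implicit conversion to cudaStream_t still works wherever a raw
  // handle is expected.
  const cudaStream_t raw = stream;
  C10_CUDA_CHECK(cudaStreamSynchronize(raw));

  (void)stream_id;
  (void)device_idx;
}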
1 parent a3bd709 commit 0911c94

3 files changed (+4, −76 lines)

fbgemm_gpu/include/fbgemm_gpu/utils/device_properties.cuh

-72
@@ -52,76 +52,4 @@ inline auto get_device_properties(const int device) {
   }
 }

-////////////////////////////////////////////////////////////////////////////////
-// Get CUDA Device From Stream
-//
-// Given a CUDA stream, fetch the device ID that the stream is associated with.
-// This function is memoized since the operation may be expensive
-////////////////////////////////////////////////////////////////////////////////
-
-inline auto get_device_for_stream(const cudaStream_t& stream) {
-  // Keep as thread local to avoid race conditions
-  static thread_local std::unordered_map<cudaStream_t, int> table;
-
-  if (const auto search = table.find(stream); search != table.end()) {
-    return search->second;
-
-  } else {
-    int device = 0;
-
-    // CUDA 12.8+ introduced cudaStreamGetDevice() to straightforwardly fetch
-    // the device from a given stream, but since the runtime drivers may not be
-    // at the latest, it will not support the API. As such, the device ID is
-    // fetched by context capture instead.
-
-    // Save the current device
-    int current_device;
-    C10_CUDA_CHECK(cudaGetDevice(&current_device));
-
-    // Force stream association by capturing dummy work
-    cudaStreamCaptureStatus status;
-    C10_CUDA_CHECK(cudaStreamIsCapturing(stream, &status));
-
-    // Save the device associated with the stream, and revert back to the
-    // current device
-    C10_CUDA_CHECK(cudaGetDevice(&device));
-    C10_CUDA_CHECK(cudaSetDevice(current_device));
-
-    table.insert({stream, device});
-    return device;
-  }
-}
-
-inline auto get_stream_id(const cudaStream_t& stream) {
-#if defined(__HIPCC__) || (defined(CUDA_VERSION) && (CUDA_VERSION < 12060))
-  // cudaStreamGetId is not available in HIP, and is only available in
-  // CUDA 12.6+. Since streams are unique, we use its pointer value as the
-  // effective stream ID here.
-  return reinterpret_cast<unsigned long long>(stream);
-
-#else
-  // Keep as thread local to avoid race conditions
-  static thread_local std::unordered_map<cudaStream_t, unsigned long long>
-      table;
-
-  if (const auto search = table.find(stream); search != table.end()) {
-    return search->second;
-
-  } else {
-    unsigned long long streamId = 0;
-
-    if (auto [_, driver_version] = get_compute_versions();
-        driver_version <= 12060) {
-      streamId = reinterpret_cast<unsigned long long>(stream);
-
-    } else {
-      C10_CUDA_CHECK(cudaStreamGetId(stream, &streamId));
-    }
-
-    table.insert({stream, streamId});
-    return streamId;
-  }
-#endif
-}
-
} // namespace fbgemm_gpu::utils
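For context on what the removed helpers were working around (a hedged sketch, not part of this change; it assumes a CUDA 12.8+ toolkit and driver, and assumes the runtime signatures cudaStreamGetDevice(cudaStream_t, int*) and cudaStreamGetId(cudaStream_t, unsigned long long*); raw_stream_metadata is a hypothetical helper name):

#include <c10/cuda/CUDAException.h>
#include <cuda_runtime.h>

inline void raw_stream_metadata(const cudaStream_t stream) {
  // On new enough toolkits, these two runtime calls would replace the
  // context-capture and memoization tricks deleted above.
  // Note: the cudaStreamGetDevice() signature here is an assumption.
  int device = 0;
  unsigned long long stream_id = 0;
  C10_CUDA_CHECK(cudaStreamGetDevice(stream, &device));
  C10_CUDA_CHECK(cudaStreamGetId(stream, &stream_id));

  // When the stream comes from PyTorch, c10::cuda::CUDAStream::device_index()
  // and ::id() make even these queries unnecessary.
  (void)device;
  (void)stream_id;
}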

fbgemm_gpu/include/fbgemm_gpu/utils/kernel_launcher.cuh

+3 −3
@@ -212,12 +212,12 @@ struct KernelLauncher {
       const dim3 grid,
       const dim3 block,
       const size_t shared_mem_per_block,
-      const cudaStream_t stream,
+      const c10::cuda::CUDAStream stream,
       Args&&... args) const {
     // Fetch device properties from the stream information
-    const auto device = get_device_for_stream(stream);
+    const auto device = stream.device_index();
     const auto properties = get_device_properties(device);
-    const auto streamId = get_stream_id(stream);
+    const auto streamId = stream.id();

     // Check that the grid sizes are within the range per the device associated
     // with the compute stream
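To make the kernel_launcher.cuh change concrete (a sketch of the assumed shape only, not the actual KernelLauncher implementation; launch_on_stream and its parameters are hypothetical), a launch path that accepts a c10::cuda::CUDAStream can read its metadata directly and still hand the raw handle to the <<<>>> launch:

#include <c10/cuda/CUDAStream.h>
#include <cuda_runtime.h>
#include <utility>

template <typename KernelFunc, typename... Args>
void launch_on_stream(
    KernelFunc kernel,
    const dim3 grid,
    const dim3 block,
    const size_t shared_mem_per_block,
    const c10::cuda::CUDAStream stream,
    Args&&... args) {
  // Metadata comes straight from the wrapper, replacing
  // get_device_for_stream(stream) and get_stream_id(stream).
  const auto device = stream.device_index();
  const auto stream_id = stream.id();
  (void)device;
  (void)stream_id;

  // stream.stream() (or the implicit conversion) yields the cudaStream_t
  // needed for the actual launch.
  kernel<<<grid, block, shared_mem_per_block, stream.stream()>>>(
      std::forward<Args>(args)...);
}

Existing callsites that already pass at::cuda::getCurrentCUDAStream() keep compiling unchanged, since the returned CUDAStream now binds directly to the parameter instead of being implicitly narrowed to cudaStream_t.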

fbgemm_gpu/test/utils/kernel_launcher_test.cu

+1 −1
@@ -233,7 +233,7 @@ TEST(KernelLauncherTest, kernel_launch_checks) {
   at::Tensor A, B, C;
   std::tie(A, B, C) = sample_tensors(size);

-  const auto device = get_device_for_stream(at::cuda::getCurrentCUDAStream());
+  const auto device = at::cuda::getCurrentCUDAStream().device_index();
   const auto properties = get_device_properties(device);
   const auto grid_max = properties.maxGridSize;
   const auto block_max = properties.maxThreadsDim;

0 commit comments