Add early exit to sparse_segment_sum_csr_cuda op (pytorch#2277)
Summary:
Pull Request resolved: pytorch#2277

Add early exit to sparse_segment_sum_csr_cuda op in case of empty input. Add check for invalid input size.

Reviewed By: sryap, jasonjk-park

Differential Revision: D52963209

fbshipit-source-id: bec0192793e9be49018d47a751aae1dcf0ac8425
Mark Eremeev authored and facebook-github-bot committed Feb 4, 2024
1 parent aabb4ae commit dad9720
Showing 2 changed files with 25 additions and 0 deletions.
7 changes: 7 additions & 0 deletions fbgemm_gpu/src/sparse_ops/sparse_segment_sum_csr.cu
@@ -57,7 +57,14 @@ DLL_PUBLIC Tensor segment_sum_csr_cuda(

  CUDA_DEVICE_GUARD(values);

  TORCH_CHECK(csr_seg.numel() >= 1, "The csr_seg tensor should not be empty")

  auto output = at::empty(csr_seg.numel() - 1, values.options());

  if (csr_seg.numel() == 1) {
    return output;
  }

  constexpr uint32_t threads_per_block = 256;
  const uint32_t num_blocks = csr_seg.numel() - 1;
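
The guard and early exit in this hunk can be illustrated with a small pure-Python reference. This is only a sketch, not the FBGEMM implementation: the name `segment_sum_csr_reference` is hypothetical, and the segment layout (segment `i` covers `values[csr_seg[i] * batch_size : csr_seg[i + 1] * batch_size]`) is an assumption about the op's semantics that this diff does not itself show.

```python
from typing import List, Sequence


def segment_sum_csr_reference(
    batch_size: int, csr_seg: Sequence[int], values: Sequence[float]
) -> List[float]:
    """Hypothetical reference for the control flow added in this commit."""
    # Mirrors the TORCH_CHECK: csr_seg needs at least one offset.
    if len(csr_seg) < 1:
        raise ValueError("The csr_seg tensor should not be empty")

    num_segments = len(csr_seg) - 1

    # Mirrors the early exit: a single offset means zero segments,
    # so an empty output is returned before any per-segment work.
    if num_segments == 0:
        return []

    output: List[float] = []
    for i in range(num_segments):
        # Assumed layout: offsets are expressed in units of batch_size.
        start = csr_seg[i] * batch_size
        end = csr_seg[i + 1] * batch_size
        output.append(float(sum(values[start:end])))
    return output


empty_result = segment_sum_csr_reference(0, [0], [])
small_result = segment_sum_csr_reference(2, [0, 1, 3], [1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
```

In the CUDA op, `num_blocks` is `csr_seg.numel() - 1`, so without the early return a single-element `csr_seg` would reach the kernel launch with a grid size of zero; returning the empty `output` first avoids that path.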

18 changes: 18 additions & 0 deletions fbgemm_gpu/test/sparse/misc_ops_test.py
@@ -174,6 +174,24 @@ def test_segment_sum_csr(self) -> None:
                segment_sum_cuda.cpu(), torch.Tensor([10.0, 11.0, 34.0]), rtol=0, atol=0
            )

    def test_segment_sum_csr_empty_input(self) -> None:
        segment_sum_cpu = torch.ops.fbgemm.segment_sum_csr(
            0,
            torch.IntTensor([0]),
            torch.Tensor([]),
        )
        torch.testing.assert_close(segment_sum_cpu.numel(), 0, rtol=0, atol=0)

        if torch.cuda.is_available():
            segment_sum_cuda = torch.ops.fbgemm.segment_sum_csr(
                0,
                torch.IntTensor([0]).cuda(),
                torch.Tensor([]).cuda(),
            )
            torch.testing.assert_close(
                segment_sum_cuda.cpu().numel(), 0, rtol=0, atol=0
            )

    @given(
        batch_size=st.just(2),
        m=st.just(3),
