
[Features] Add NMS Kernel support with Triton Implementation #8746

Open

Stonepia wants to merge 9 commits into main
Conversation

@Stonepia commented Nov 25, 2024

Motivation

This PR follows RFC #8679, which proposes adding Triton kernel implementations for torchvision custom ops.

Implementation Method

The Triton kernels largely follow the native CUDA kernels. As shown below, each native CUDA kernel is mapped to a Triton kernel. Logic that cannot run in parallel is implemented in Python or with C++ (ATen) ops instead.

[Figure: Kernel Mapping, showing how the native CUDA kernels map to Triton kernels]

This PR contains the following parts:

  1. Kernel implementation: mostly in the torchvision/ops/triton/ folder. This contains the common logic that can be expressed in Triton.
  2. Op registration: in torchvision/ops/xpu. This registers the op and combines the non-Triton pieces with the Triton kernel into one op (a registration sketch follows this list).
  3. Tests: the same as the existing tests.
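
For illustration, here is a minimal sketch of what the registration in part 2 could look like, assuming the existing torchvision::nms schema and PyTorch's torch.library.register_kernel API (available in recent PyTorch). The placeholder body simply defers to the reference CPU kernel; it is not the PR's implementation, which combines the ATen pre/post-processing with the Triton IoU kernel.

import torch
import torchvision  # loads the torchvision::nms op schema


# Hypothetical registration sketch: bind a Python kernel to the XPU backend.
@torch.library.register_kernel("torchvision::nms", "xpu")
def _nms_xpu(boxes: torch.Tensor, scores: torch.Tensor, iou_threshold: float) -> torch.Tensor:
    # Placeholder body: run the reference CPU kernel and move the result back.
    keep = torch.ops.torchvision.nms(boxes.cpu(), scores.cpu(), iou_threshold)
    return keep.to(boxes.device)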

Kernel Implementation Structure

The NMS op consists of three parts; see torchvision/ops/xpu/nms.py for details. A runnable sketch of the overall structure follows this list:

  1. Pre-processing: some logic, such as argsort, is handled with PyTorch ATen ops.
  2. Triton kernel: computes the IoU mask matrix. This lives in torchvision/ops/triton/nms.py and is device-agnostic, so it can be shared across devices.
  3. Post-processing: a serialized step with data dependencies; there is no benefit to implementing it in Triton, so it falls back to the ATen implementation.
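
To make this structure concrete, here is a minimal, self-contained sketch of the three parts. It is not the PR's code: the Triton IoU kernel (part 2) is replaced by a plain PyTorch IoU computation so the snippet runs on its own, and the mask is kept as an unpacked boolean [N, N] matrix rather than the packed bit mask described in the next section.

import torch
from torchvision.ops import box_iou


def nms_sketch(boxes: torch.Tensor, scores: torch.Tensor, iou_threshold: float) -> torch.Tensor:
    # 1. Pre-processing (ATen ops): sort boxes by descending score.
    order = scores.argsort(descending=True)
    boxes = boxes[order]
    n = boxes.shape[0]

    # 2. Mask computation: the PR does this in a Triton kernel
    #    (torchvision/ops/triton/nms.py); a plain IoU matrix stands in here.
    iou_keep_out_mask = (box_iou(boxes, boxes) > iou_threshold).cpu()

    # 3. Post-processing (serialized, on CPU): greedily keep boxes and
    #    suppress later boxes that overlap an already-kept box.
    order = order.cpu()
    remove_box = torch.zeros(n, dtype=torch.bool)
    picked = []
    for i in range(n):
        if remove_box[i]:
            continue
        picked.append(order[i].item())
        remove_box[i:] |= iou_keep_out_mask[i][i:]
    return torch.as_tensor(picked, dtype=torch.int64, device=scores.device)

Called as nms_sketch(boxes, scores, 0.5), this returns the indices of the kept boxes in descending score order, matching the contract of torchvision.ops.nms.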

Kernel Implementation Details

  1. The Triton kernel computes a mask matrix for the input boxes based on their intersection-over-union (IoU) scores. The output indicates, for each pair of boxes, whether box j should be suppressed once box i has been chosen. A naive implementation would produce an [N, N] matrix; for performance, the kernel packs the bit mask into 32-bit integers, so the output has shape [N, N // 32]. A simplified Triton sketch follows this list.
  2. After the mask matrix is computed, a serialized post-processing function iterates over its rows. Choosing row i means choosing box i, which in turn excludes some boxes j; applying that exclusion is what the post-processing function does. To keep it device-agnostic, this serialized step runs on the CPU.
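
A simplified Triton sketch of the bit-mask kernel (item 1) and its host-side launcher is shown below. This is not the PR's kernel in torchvision/ops/triton/nms.py; it assumes boxes in (x1, y1, x2, y2) layout, contiguous float32 storage, and a fixed 32-lane block so each program instance writes exactly one packed int32.

import torch
import triton
import triton.language as tl


@triton.jit
def _iou_bitmask_kernel(boxes_ptr, mask_ptr, n_boxes, iou_threshold,
                        BLOCK_SIZE: tl.constexpr):
    # Program (row, block) compares box `row` against the 32 boxes of column
    # block `block` and packs the 32 comparison bits into one int32.
    row = tl.program_id(0)
    block = tl.program_id(1)
    cols = block * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    valid = cols < n_boxes

    # Box `row` coordinates (x1, y1, x2, y2).
    x1_i = tl.load(boxes_ptr + row * 4 + 0)
    y1_i = tl.load(boxes_ptr + row * 4 + 1)
    x2_i = tl.load(boxes_ptr + row * 4 + 2)
    y2_i = tl.load(boxes_ptr + row * 4 + 3)

    # A block of candidate boxes j.
    x1_j = tl.load(boxes_ptr + cols * 4 + 0, mask=valid, other=0.0)
    y1_j = tl.load(boxes_ptr + cols * 4 + 1, mask=valid, other=0.0)
    x2_j = tl.load(boxes_ptr + cols * 4 + 2, mask=valid, other=0.0)
    y2_j = tl.load(boxes_ptr + cols * 4 + 3, mask=valid, other=0.0)

    # IoU of box `row` with each box j in the block.
    inter_w = tl.maximum(tl.minimum(x2_i, x2_j) - tl.maximum(x1_i, x1_j), 0.0)
    inter_h = tl.maximum(tl.minimum(y2_i, y2_j) - tl.maximum(y1_i, y1_j), 0.0)
    inter = inter_w * inter_h
    area_i = (x2_i - x1_i) * (y2_i - y1_i)
    area_j = (x2_j - x1_j) * (y2_j - y1_j)
    iou = inter / (area_i + area_j - inter)

    # Pack the "IoU above threshold" bits (skipping the self-comparison)
    # into a single int32 covering 32 columns.
    keep_out = valid & (cols != row) & (iou > iou_threshold)
    bits = tl.where(keep_out, 1, 0).to(tl.int32)
    packed = tl.sum(bits << (cols % BLOCK_SIZE), axis=0)
    n_col_blocks = tl.cdiv(n_boxes, BLOCK_SIZE)
    tl.store(mask_ptr + row * n_col_blocks + block, packed)


def iou_bitmask(boxes: torch.Tensor, iou_threshold: float) -> torch.Tensor:
    # Host-side launcher: the output has shape [N, ceil(N / 32)], i.e. the
    # packed [N, N // 32] layout described above.
    n = boxes.shape[0]
    n_col_blocks = triton.cdiv(n, 32)
    mask = torch.zeros((n, n_col_blocks), dtype=torch.int32, device=boxes.device)
    _iou_bitmask_kernel[(n, n_col_blocks)](
        boxes.contiguous(), mask, n, iou_threshold, BLOCK_SIZE=32
    )
    return mask

Before the serialized CPU scan in item 2, the packed mask can be unpacked back to a boolean [N, N] matrix after moving it to the host, for example with ((mask.cpu().unsqueeze(-1) >> torch.arange(32)) & 1).reshape(n, -1)[:, :n].bool().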

cc: @EikanWang

pytorch-bot bot commented Nov 25, 2024

🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/vision/8746

picked.append(order[i])
remove_box[i:] |= iou_keep_out_mask[i][i:]

return torch.as_tensor(picked)


Should this also respect the device of the boxes? (remove_box is allocated on boxes.device, while the return value is always on CPU.)

Author

Thanks for the reminder! Yes, this should be on boxes.device. I will update it.

@Stonepia (Author)

Also attaching a performance comparison with the native CUDA implementation on an A100:

[Plot: nms_performance, showing throughput (GB/s) vs. input size for the Triton, Torch Compile, Native CUDA, and Python implementations]

From this plot, the Triton implementation reaches performance competitive with native CUDA. The torch.compile path reaches the peak because it reduces the Python overhead and kernel launch overhead.

One interesting finding, however, is that at very large sizes all of the implementations show a large performance drop. This likely hits a memory bottleneck; optimizing it could be the next step.

import torch
import torchvision
import triton
import triton.testing

# `threshold`, `_create_tensors_with_iou`, `_reference_nms`, and
# `custom_nms_triton_kernel` are defined elsewhere in the benchmark script
# (the helper functions presumably mirror those in torchvision's test_ops.py).

@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=['size'],  # Argument names to use as an x-axis for the plot.
        x_vals=[2**i for i in range(2, 13, 1)],  # Different possible values for `x_name`.
        x_log=True,  # x axis is logarithmic.
        line_arg='provider',  # Argument name whose value corresponds to a different line in the plot.
        line_vals=['triton', 'torch_compile', 'native_cuda', 'python'],  # Possible values for `line_arg`.
        line_names=['Triton', 'Torch Compile', 'Native CUDA', 'Python Op'],  # Label name for the lines.
        styles=[('blue', '-'), ('pink', '-'), ('green', '-'), ('yellow', '-')],  # Line styles.
        ylabel='GB/s',  # Label name for the y-axis.
        plot_name='nms_performance',  # Name for the plot. Used also as a file name for saving the plot.
        args={},  # Values for function arguments not in `x_names` and `y_name`.
    ))
def benchmark(size, provider):
    boxes, scores = _create_tensors_with_iou(size, threshold)
    quantiles = [0.5, 0.2, 0.8]
    compiled_nms = torch.compile(custom_nms_triton_kernel)
    if provider == 'native_cuda':
        ms, min_ms, max_ms = triton.testing.do_bench(lambda: torchvision.ops.nms(boxes, scores, threshold), quantiles=quantiles)
    if provider == 'triton':
        ms, min_ms, max_ms = triton.testing.do_bench(lambda: custom_nms_triton_kernel(boxes, scores, threshold), quantiles=quantiles)
    if provider == 'python':
        ms, min_ms, max_ms = triton.testing.do_bench(lambda: _reference_nms(boxes, scores, threshold), quantiles=quantiles)
    if provider == 'torch_compile':
        ms, min_ms, max_ms = triton.testing.do_bench(lambda: compiled_nms(boxes, scores, threshold), quantiles=quantiles)
    gbps = lambda ms: boxes.numel() * boxes.element_size() * 1e-9 / (ms * 1e-3)  # bytes of box data / elapsed time -> GB/s
    return gbps(ms), gbps(max_ms), gbps(min_ms)

benchmark.run(print_data=True, show_plots=True, save_path='.')

@Stonepia marked this pull request as ready for review December 16, 2024 10:05
@Stonepia changed the title [Draft] [Features] Add NMS Kernel support with Triton Implementation [Features] Add NMS Kernel support with Triton Implementation Dec 16, 2024