Add nsys integration
Summary: Add new metric `--metric nsys` to collect nsys trace.

Reviewed By: htyu

Differential Revision: D63274918

fbshipit-source-id: 0536310df6290ea5f5a02d85cc0ad6d342d45dbd
xuzhao9 authored and facebook-github-bot committed Sep 26, 2024
1 parent a31c3fe commit 2edf80c
Showing 2 changed files with 114 additions and 17 deletions.
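For context, a rough example of how the new metric would be requested once this lands. The `run_benchmark.py triton` entry point and the `flash_attention` operator are assumptions for illustration; `nsys_rep` is the metric name the code below actually checks for:

    python run_benchmark.py triton --op flash_attention --metrics nsys_rep

Under the hood the harness re-launches itself under `nsys profile -c cudaProfilerApi` with `--metrics _nsys_rep_in_task`, so the capture covers only the profiled call (see the new `nsys_rep` method in `triton_op.py` below).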
56 changes: 41 additions & 15 deletions torchbenchmark/_components/ncu/__init__.py
@@ -1,11 +1,28 @@
from typing import Callable

import torch

def do_bench_ncu_in_task(

class cuda_profiler_range:
    def __init__(self, use_cuda_profiler_range):
        self.use_cuda_profiler_range = use_cuda_profiler_range

    def __enter__(self):
        if self.use_cuda_profiler_range:
            torch.cuda.cudart().cudaProfilerStart()

    def __exit__(self, *exc_info):
        if self.use_cuda_profiler_range:
            torch.cuda.cudart().cudaProfilerStop()


def do_bench_in_task(
    fn: Callable,
    grad_to_none=None,
    fast_flush=True,
    range_name: str = "",
    warmup: bool = False,
    warmup_time: int = 25,
    use_cuda_profiler_range: bool = False,
) -> None:
    """
    Benchmark the runtime of the provided function. By default, return the median runtime of :code:`fn` along with
@@ -15,23 +32,30 @@ def do_bench_ncu_in_task(
    :type fn: Callable
    :param grad_to_none: Reset the gradient of the provided tensor to None
    :type grad_to_none: torch.tensor, optional
    :param fast_flush: Use faster kernel to flush L2 between measurements
    :type fast_flush: bool
    :param output_dir: Output directory to store the trace
    :type output_dir: str, optional
    """
    import torch

    fn()
    torch.cuda.synchronize()

    # We maintain a buffer of 256 MB that we clear
    # before each kernel call to make sure that the L2
    # doesn't contain any input data before the run
    if fast_flush:
        cache = torch.empty(int(256e6 // 4), dtype=torch.int, device="cuda")
    else:
        cache = torch.empty(int(256e6), dtype=torch.int8, device="cuda")
    cache = torch.empty(int(256e6 // 4), dtype=torch.int, device="cuda")

    if warmup:
        # Estimate the runtime of the function
        start_event = torch.cuda.Event(enable_timing=True)
        end_event = torch.cuda.Event(enable_timing=True)
        start_event.record()
        for _ in range(5):
            cache.zero_()
            fn()
        end_event.record()
        torch.cuda.synchronize()
        estimate_ms = start_event.elapsed_time(end_event) / 5

        # compute number of warmup and repeat
        n_warmup = max(1, int(warmup_time / estimate_ms))
        # Warm-up
        for _ in range(n_warmup):
            fn()

    # we don't want `fn` to accumulate gradient values
    # if it contains a backward pass. So we clear the
@@ -41,5 +65,7 @@ def do_bench_ncu_in_task(
            x.grad = None
    # we clear the L2 cache before run
    cache.zero_()
    with torch.cuda.nvtx.range(range_name):
    with cuda_profiler_range(use_cuda_profiler_range), torch.cuda.nvtx.range(
        range_name
    ):
        fn()
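
A minimal sketch of how the reworked helper is meant to be driven inside the child process; `_toy_kernel` is a stand-in workload, and the sketch assumes the process was itself launched under `nsys profile -c cudaProfilerApi` so that capture spans only the wrapped call:

import torch

from torchbenchmark._components.ncu import do_bench_in_task


def _toy_kernel():
    # Stand-in workload; any CUDA kernel launch would do here.
    a = torch.randn(4096, 4096, device="cuda")
    return torch.mm(a, a)


# warmup=True exercises the new estimation/warm-up path above, and
# use_cuda_profiler_range=True brackets the measured call with
# cudaProfilerStart()/cudaProfilerStop(), which is the signal that
# `nsys profile -c cudaProfilerApi` keys its capture range on.
do_bench_in_task(
    fn=_toy_kernel,
    range_name="toy_kernel",
    warmup=True,
    use_cuda_profiler_range=True,
)
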
75 changes: 73 additions & 2 deletions torchbenchmark/util/triton_op.py
@@ -177,6 +177,8 @@ class BenchmarkOperatorMetrics:
    ncu_rep: Optional[str] = None
    # ncu replay file with TTGIR line numbers
    ncu_rep_ir: Optional[str] = None
    # nsys replay file
    nsys_rep: Optional[str] = None
    # kineto trace file
    kineto_trace: Optional[str] = None
    # cpu peak memory
Expand Down Expand Up @@ -859,6 +861,8 @@ def _init_extra_metrics() -> Dict[str, Any]:
metrics.ncu_rep_ir = self.ncu_trace(
input_id, fn_name, replay=True, profile_ir=True
)
if "nsys_rep" in self.required_metrics:
metrics.nsys_rep = self.nsys_rep(input_id, fn_name)
if "kineto_trace" in self.required_metrics:
metrics.kineto_trace = self.kineto_trace(input_id, fn)
if "best_config" in self.required_metrics:
@@ -886,14 +890,33 @@ def _init_extra_metrics() -> Dict[str, Any]:
                    "_ncu_trace_in_task must be measured by itself. "
                    f"required_metrics: {self.required_metrics}, _only: {self._only}, _input_id: {self._input_id}"
                )
                from torchbenchmark._components.ncu import do_bench_ncu_in_task
                from torchbenchmark._components.ncu import do_bench_in_task

                do_bench_ncu_in_task(
                do_bench_in_task(
                    fn=fn,
                    grad_to_none=self.get_grad_to_none(self.example_inputs),
                    range_name=_RANGE_NAME,
                )
                metrics.extra_metrics["_ncu_trace_in_task"] = "success"
            if "_nsys_rep_in_task" in self.required_metrics:
                assert (
                    self.required_metrics == ["_nsys_rep_in_task"]
                    and len(self._only) == 1
                    and (self._input_id is not None)
                ), (
                    "_nsys_rep_in_task must be measured by itself. "
                    f"required_metrics: {self.required_metrics}, _only: {self._only}, _input_id: {self._input_id}"
                )
                from torchbenchmark._components.ncu import do_bench_in_task

                do_bench_in_task(
                    fn=fn,
                    grad_to_none=self.get_grad_to_none(self.example_inputs),
                    range_name=_RANGE_NAME,
                    warmup=True,
                    use_cuda_profiler_range=True,
                )
                metrics.extra_metrics["_nsys_rep_in_task"] = "success"
            # generate customized metrics
            if self.name in REGISTERED_METRICS:
                for metric_name in REGISTERED_METRICS[self.name]:
@@ -925,6 +948,54 @@ def get_peak_mem(
            metrics_gpu_backend="nvml",
        )

    def nsys_rep(self, input_id: int, fn_name: str) -> str:
        import subprocess
        import sys

        op_task_args = [] if IS_FBCODE else [sys.executable]
        op_task_args.extend(copy.deepcopy(sys.argv))
        for override_option in ["--only", "--input-id", "--num-inputs", "--metrics"]:
            op_task_args = _remove_params(
                op_task_args, _find_param_loc(op_task_args, override_option)
            )
        op_task_args.extend(
            [
                "--only",
                fn_name,
                "--num-inputs",
                str(1),
                "--input-id",
                str(input_id),
                "--metrics",
                "_nsys_rep_in_task",
            ]
        )
        nsys_output_dir = self.get_temp_path(f"nsys_traces/{fn_name}_{input_id}")
        nsys_output_dir.mkdir(parents=True, exist_ok=True)
        ext = ".nsys-rep"
        nsys_output_file = nsys_output_dir.joinpath(f"nsys_output{ext}").resolve()
        nsys_trace_cmd = [
            "nsys",
            "profile",
            "-c",
            "cudaProfilerApi",
            "-t",
            "nvtx,osrt,cuda,cudnn,cublas",
            "-w",
            "true",
            "-f",
            "true",
            "-o",
            nsys_output_file,
        ]
        nsys_trace_cmd.extend(op_task_args)
        try:
            subprocess.check_call(nsys_trace_cmd)
        except subprocess.CalledProcessError:
            # FIXME: calling nsys on Tritonbench will throw SIGTERM with error code 143
            pass
        return str(nsys_output_file.resolve())

    def ncu_trace(
        self, input_id: int, fn_name: str, replay: bool = False, profile_ir=False
    ) -> str:
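
Once a run finishes, the path returned by `nsys_rep` can be inspected with stock Nsight Systems tooling. A small sketch, assuming a report path of the shape produced by `get_temp_path()` (the exact directory is an assumption, not from this commit):

import subprocess

# Illustrative path only; the real one is returned as the `nsys_rep` metric value.
report = "/tmp/tritonbench/nsys_traces/my_kernel_0/nsys_output.nsys-rep"

# `nsys stats` summarizes the report (CUDA kernel/API time tables, NVTX ranges)
# for the capture range recorded between cudaProfilerStart()/cudaProfilerStop().
subprocess.check_call(["nsys", "stats", report])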
