Skip to content

Commit

Permalink
Add nsys report analyzer (#65)
Browse files Browse the repository at this point in the history
Summary:
This PR adds an nsys report analyzer that provides the following metrics:
```python
nsys_metrics_to_reports = {
    # the sum of kernel execution time
    "nsys_gpu_kernel_sum": ["cuda_gpu_kern_sum", "nvtx_sum"],
    # the overhead of kernel launch
    "nsys_launch_overhead": ["cuda_gpu_kern_sum", "nvtx_sum"],
    # the names of kernels
    "nsys_kernel_names": ["cuda_gpu_kern_sum"],
    # the durations of kernels
    "nsys_kernel_durations": ["cuda_gpu_kern_sum"],
    # the duration of nvtx range
    "nsys_nvtx_range_duration": ["nvtx_sum"],
    # the number of kernels
    "nsys_num_of_kernels": ["cuda_gpu_kern_sum"],
}
```
`nsys_gpu_kernel_sum` is the total GPU kernel execution time, `nsys_nvtx_range_duration` is the total execution time of the operator, and `nsys_launch_overhead` is their difference, which indicates the launch overhead. This is one of the ways to measure execution time mentioned in #50.

Fix #67

Pull Request resolved: #65

Test Plan:
```
% python run.py --op rope  --num-inputs 1  --metrics nsys_gpu_kernel_sum,nsys_launch_overhead,nsys_kernel_names,nsys_kernel_durations,nsys_nvtx_range_duration,nsys_num_of_kernels --csv --dump-csv
  0%|                                                                                                                                                         | 0/1 [00:00<?, ?it/s]`LlamaRotaryEmbedding` can now be fully parameterized by passing the model config through the `config` argument. All other arguments will be removed in v4.46
  0%|          | 0/1 [00:00<?, ?it/s]`LlamaRotaryEmbedding` can now be fully parameterized by passing the model config through the `config` argument. All other arguments will be removed in v4.46
Capture range started in the application.
Capture range ended in the application.
Generating '/tmp/nsys-report-531e.qdstrm'
[1/1] [0%                          ] nsys_output.nsys-repProcessing events...
[1/1] [========================100%] nsys_output.nsys-rep
Generated:
    /tmp/tritonbench/rope/nsys_traces/apply_rotary_pos_emb_0/nsys_output.nsys-rep
  0%|          | 0/1 [00:00<?, ?it/s]`LlamaRotaryEmbedding` can now be fully parameterized by passing the model config through the `config` argument. All other arguments will be removed in v4.46
Capture range started in the application.
Capture range ended in the application.
Generating '/tmp/nsys-report-39ea.qdstrm'
[1/1] [0%                          ] nsys_output.nsys-repProcessing events...
[1/1] [========================100%] nsys_output.nsys-rep
Generated:
    /tmp/tritonbench/rope/nsys_traces/liger_rotary_pos_emb_0/nsys_output.nsys-rep
  0%|          | 0/1 [00:00<?, ?it/s]`LlamaRotaryEmbedding` can now be fully parameterized by passing the model config through the `config` argument. All other arguments will be removed in v4.46
Capture range started in the application.
Capture range ended in the application.
Generating '/tmp/nsys-report-e8bf.qdstrm'
[1/1] [0%                          ] nsys_output.nsys-repProcessing events...
[1/1] [========================100%] nsys_output.nsys-rep
Generated:
    /tmp/tritonbench/rope/nsys_traces/inductor_rotary_pos_emb_full_op_0/nsys_output.nsys-rep
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:52<00:00, 52.40s/it]
(H, T);apply_rotary_pos_emb-nsys_kernel_names;apply_rotary_pos_emb-nsys_kernel_durations;apply_rotary_pos_emb-nsys_gpu_kernel_sum;apply_rotary_pos_emb-nsys_num_of_kernels;apply_rotary_pos_emb-nsys_launch_overhead;apply_rotary_pos_emb-nsys_nvtx_range_duration;liger_rotary_pos_emb-nsys_kernel_names;liger_rotary_pos_emb-nsys_kernel_durations;liger_rotary_pos_emb-nsys_gpu_kernel_sum;liger_rotary_pos_emb-nsys_num_of_kernels;liger_rotary_pos_emb-nsys_launch_overhead;liger_rotary_pos_emb-nsys_nvtx_range_duration;inductor_rotary_pos_emb_full_op-nsys_kernel_names;inductor_rotary_pos_emb_full_op-nsys_kernel_durations;inductor_rotary_pos_emb_full_op-nsys_gpu_kernel_sum;inductor_rotary_pos_emb_full_op-nsys_num_of_kernels;inductor_rotary_pos_emb_full_op-nsys_launch_overhead;inductor_rotary_pos_emb_full_op-nsys_nvtx_range_duration
(8192, 1024);['void at::native::elementwise_kernel<(int)128, (int)2, void at::native::gpu_kernel_impl_nocast<at::native::BinaryFunctor<float, float, float, at::native::binary_internal::MulFunctor<float>>>(at::TensorIteratorBase &, const T1 &)::[lambda(int) (instance 1)]>(int, T3)', 'void at::native::<unnamed>::CatArrayBatchedCopy<at::native::<unnamed>::OpaqueType<(unsigned int)4>, unsigned int, (int)4, (int)64, (int)64>(T1 *, at::native::<unnamed>::CatArrInputTensorMetadata<T1, T2, T4, T5>, at::native::<unnamed>::TensorSizeStride<T2, (unsigned int)4>, int, T2)', 'void at::native::elementwise_kernel<(int)128, (int)2, void at::native::gpu_kernel_impl_nocast<at::native::CUDAFunctor_add<float>>(at::TensorIteratorBase &, const T1 &)::[lambda(int) (instance 1)]>(int, T3)', 'void at::native::elementwise_kernel<(int)128, (int)2, void at::native::gpu_kernel_impl_nocast<at::native::neg_kernel_cuda(at::TensorIteratorBase &)::[lambda() (instance 2)]::operator ()() const::[lambda() (instance 7)]::operator ()() const::[lambda(float) (instance 1)]>(at::TensorIteratorBase &, const T1 &)::[lambda(int) (instance 1)]>(int, T3)'];0.090065;0.351364;4;0.4534;0.804764;['_triton_rope'];0.049281;0.049281;1;0.176437;0.225718;['triton_poi_fused_add_cat_mul_0', 'triton_poi_fused_add_cat_mul_1'];0.0266885;0.053377;2;0.444969;0.498346
[TritonBench] Dumped csv to /tmp/tritonbench/op_rope__z_yqmrz.csv
```

Reviewed By: xuzhao9

Differential Revision: D66311127

Pulled By: FindHao

fbshipit-source-id: 085454e34a3e9aadb360309cc69885684a8a1758
  • Loading branch information
FindHao authored and facebook-github-bot committed Nov 27, 2024
1 parent bec6d9f commit deb9183
Show file tree
Hide file tree
Showing 3 changed files with 160 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ def get_arithmetic_intensity(kernel):
def read_ncu_report(report_path: str, required_metrics: List[str]):
assert os.path.exists(
report_path
), f"The NCU report at {report_path} does not exist. Ensure you add --metrics ncu_rep to your benchmark run."
), f"The NCU report at {report_path} does not exist."
import_ncu_python_path()
import ncu_report

Expand Down
105 changes: 105 additions & 0 deletions tritonbench/components/ncu/nsys_analyzer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
import csv
import os
import subprocess
from typing import Dict, List

# Names of the `nsys stats` reports the analyzer consumes.
_KERNEL_SUMMARY_REPORT = "nvtx_kern_sum"
_NVTX_SUMMARY_REPORT = "nvtx_sum"

# Maps each supported nsys metric to the list of nsys reports that must be
# exported in order to compute it.
nsys_metrics_to_reports: Dict[str, List[str]] = {
    # Sum of kernel execution time.
    "nsys_gpu_kernel_sum": [_KERNEL_SUMMARY_REPORT, _NVTX_SUMMARY_REPORT],
    # Kernel launch overhead.
    "nsys_launch_overhead": [_KERNEL_SUMMARY_REPORT, _NVTX_SUMMARY_REPORT],
    # Kernel names.
    "nsys_kernel_names": [_KERNEL_SUMMARY_REPORT],
    # Per-kernel durations.
    "nsys_kernel_durations": [_KERNEL_SUMMARY_REPORT],
    # Duration of the nvtx range.
    "nsys_nvtx_range_duration": [_NVTX_SUMMARY_REPORT],
    # Number of kernels.
    "nsys_num_of_kernels": [_KERNEL_SUMMARY_REPORT],
}


def read_nsys_report(
    report_path: str, required_metrics: List[str]
) -> Dict[str, object]:
    """Export summary CSVs from an nsys report and derive the requested metrics.

    Runs ``nsys stats`` on ``report_path`` for every report listed in
    ``nsys_metrics_to_reports`` for the requested metrics, reads the CSV files
    expected next to the report, and computes the metric values from them.

    Args:
        report_path: Path to an existing nsys report file.
        required_metrics: Metric names to compute. Names that are not keys of
            ``nsys_metrics_to_reports`` are silently ignored.

    Returns:
        A dict with one entry per recognized requested metric. Durations are
        in milliseconds (the "Total Time (ns)" column divided by 1e6).
        ``nsys_kernel_durations`` is a list of strings,
        ``nsys_kernel_names`` a list of kernel names,
        ``nsys_num_of_kernels`` an int, and the remaining metrics are numbers.

    Raises:
        AssertionError: If the report does not exist, no requested metric maps
            to an nsys report, the nvtx summary does not have exactly one row,
            or the internal metric table is out of sync with
            ``nsys_metrics_to_reports``.
        subprocess.CalledProcessError: If ``nsys stats`` exits with an error.
        RuntimeError: If an expected CSV file was not produced.
    """
    assert os.path.exists(
        report_path
    ), f"The nsys report at {report_path} does not exist."

    # De-duplicate while keeping first-seen order so the generated command
    # line is deterministic (a set would shuffle the report order).
    reports_required = list(
        dict.fromkeys(
            report
            for metric in required_metrics
            if metric in nsys_metrics_to_reports
            for report in nsys_metrics_to_reports[metric]
        )
    )
    assert reports_required, "No nsys reports required"

    # Build the command as an argument list so that a report path containing
    # whitespace is passed to nsys as a single argument.
    # NOTE(review): `--output .` is expected to place the CSVs next to the
    # report file; the paths read below rely on that -- confirm against the
    # nsys version in use.
    cmd = [
        "nsys",
        "stats",
        "--report",
        ",".join(reports_required),
        "--force-export=true",
        "--format",
        "csv",
        "--output",
        ".",
        "--force-overwrite=true",
        report_path,
    ]
    try:
        # subprocess.run drains the stderr pipe; check_call with
        # stderr=PIPE never reads it, which can block the child once the
        # pipe buffer fills and discards the error text.
        subprocess.run(
            cmd,
            stdout=subprocess.DEVNULL,
            stderr=subprocess.PIPE,
            check=True,
        )
    except subprocess.CalledProcessError as e:
        stderr_text = e.stderr.decode(errors="replace") if e.stderr else ""
        cmd_str = " ".join(cmd)
        print(
            f"Failed to run nsys command: {cmd_str}\nError: {e}\n"
            f"stderr: {stderr_text}"
        )
        raise

    # Get the base path and filename without extension.
    base_path = os.path.dirname(report_path)
    base_name = os.path.splitext(os.path.basename(report_path))[0]

    csv_contents = {}
    for report in reports_required:
        csv_path = os.path.join(base_path, f"{base_name}_{report}.csv")
        if not os.path.exists(csv_path):
            raise RuntimeError(f"Expected CSV report not found at {csv_path}")

        # newline="" lets the csv module handle line endings itself, as the
        # csv documentation recommends for file objects.
        with open(csv_path, "r", newline="") as f:
            csv_contents[report] = list(csv.DictReader(f))

    kernel_duration = []
    kernel_names = []
    sum_kernel_duration = 0
    nvtx_range_duration = 0
    if "nvtx_kern_sum" in csv_contents:
        # GPU kernel execution time summary: one row per kernel.
        for row in csv_contents["nvtx_kern_sum"]:
            # Use ms as the unit.
            kernel_duration.append(float(row["Total Time (ns)"]) / 1_000_000)
            kernel_names.append(row["Kernel Name"])
        sum_kernel_duration = sum(kernel_duration)
    if "nvtx_sum" in csv_contents:
        # It is supposed to be only one row. The nvtx range is
        # `:tritonbench_range`.
        assert len(csv_contents["nvtx_sum"]) == 1
        # @TODO: nsys has a bug where the unit of the nvtx range duration is
        # sometimes ms. Waiting for NVIDIA's reply.
        nvtx_range_duration = (
            float(csv_contents["nvtx_sum"][0]["Total Time (ns)"]) / 1_000_000
        )

    # Mapping of metrics to their values. The keys must match
    # nsys_metrics_to_reports.
    metrics_map = {
        # Because tritonbench takes the median of numerical values, we need
        # to convert the list of floats to a list of strings.
        "nsys_kernel_durations": [str(duration) for duration in kernel_duration],
        "nsys_kernel_names": kernel_names,
        "nsys_gpu_kernel_sum": sum_kernel_duration,
        "nsys_nvtx_range_duration": nvtx_range_duration,
        "nsys_launch_overhead": nvtx_range_duration - sum_kernel_duration,
        "nsys_num_of_kernels": len(kernel_names),
    }
    # Verify that metrics_map keys match nsys_metrics_to_reports keys.
    assert set(metrics_map.keys()) == set(nsys_metrics_to_reports.keys()), (
        f"Mismatch between metrics_map keys and nsys_metrics_to_reports keys.\n"
        f"metrics_map keys: {set(metrics_map.keys())}\n"
        f"nsys_metrics_to_reports keys: {set(nsys_metrics_to_reports.keys())}"
    )

    # Return only the requested metrics.
    return {
        metric: metrics_map[metric]
        for metric in required_metrics
        if metric in metrics_map
    }
59 changes: 54 additions & 5 deletions tritonbench/utils/triton_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
import torch
import triton

from tritonbench.components.ncu import analyzer as ncu_analyzer
from tritonbench.components.ncu import ncu_analyzer, nsys_analyzer
from tritonbench.utils.env_utils import (
apply_precision,
fresh_triton_cache,
Expand Down Expand Up @@ -68,7 +68,12 @@ class BenchmarkOperatorBackend:
REGISTERED_METRICS: Dict[str, List[str]] = {}
REGISTERED_X_VALS: Dict[str, str] = {}
BASELINE_BENCHMARKS: Dict[str, str] = {}
BASELINE_SKIP_METRICS = {"speedup", "accuracy", "mem_footprint_compression_ratio"}
BASELINE_SKIP_METRICS = {
"speedup",
"accuracy",
"mem_footprint_compression_ratio",
"nsys_gpu_speedup",
}
X_ONLY_METRICS = set(["hw_roofline"])
PRECISION_DTYPE_MAPPING = {
"fp32": torch.float32,
Expand Down Expand Up @@ -222,6 +227,8 @@ class BenchmarkOperatorMetrics:
mem_footprint_compression_ratio: Optional[float] = None
# gbps
gbps: Optional[float] = None
# speedup for the summary of kernel GPU time only
nsys_gpu_speedup: Optional[float] = None


BUILTIN_METRICS = {x.name for x in fields(BenchmarkOperatorMetrics)} - {"extra_metrics"}
Expand Down Expand Up @@ -307,9 +314,25 @@ def select_metric(backend, m):
)
metric_val = _metrics_dict.get(metric, None)
if isinstance(metric_val, list):
row.append(numpy.median(metric_val))
# Check if all elements are numbers before calculating median
if all(isinstance(x, Number) for x in metric_val):
row.append(numpy.median(metric_val))
else:
# For non-numeric lists, convert to string representation
metric_val_str = str(metric_val)
if ";" in metric_val_str:
logger.warning(
f"Metric value '{metric_val_str}' contains semicolon which may cause CSV parsing issues"
)
row.append(metric_val_str)
elif isinstance(metric_val, bool):
row.append(1.0 if metric_val else 0.0)
elif isinstance(metric_val, str):
if ";" in metric_val:
logger.warning(
f"Metric value '{metric_val}' contains semicolon which may cause CSV parsing issues"
)
row.append(metric_val)
else:
row.append(metric_val)
table.append(row)
Expand Down Expand Up @@ -1065,8 +1088,34 @@ def _init_extra_metrics() -> Dict[str, Any]:
metrics.ncu_rep_ir = self.ncu_trace(
input_id, fn_name, replay=True, profile_ir=True
)
if "nsys_rep" in self.required_metrics:
metrics.nsys_rep = self.nsys_rep(input_id, fn_name)
nsys_metrics = []
for metric_name in nsys_analyzer.nsys_metrics_to_reports.keys():
if metric_name in self.required_metrics:
nsys_metrics.append(metric_name)

if "nsys_rep" in self.required_metrics or nsys_metrics:
nsys_rep_path = self.nsys_rep(input_id, fn_name)
metrics.nsys_rep = nsys_rep_path
if nsys_metrics:
nsys_analyzer_results = nsys_analyzer.read_nsys_report(
nsys_rep_path, nsys_metrics
)
for metric_name, metric_value in nsys_analyzer_results.items():
metrics.extra_metrics[metric_name] = metric_value
if "nsys_gpu_speedup" in self.required_metrics:
baseline_nsys_gpu_kernel_sum = (
self.baseline_metrics.extra_metrics.get("nsys_gpu_kernel_sum", None)
if self.baseline_metrics
else None
)
current_nsys_gpu_kernel_sum = metrics.extra_metrics.get(
"nsys_gpu_kernel_sum", None
)
metrics.nsys_gpu_speedup = (
baseline_nsys_gpu_kernel_sum / current_nsys_gpu_kernel_sum
if baseline_nsys_gpu_kernel_sum and current_nsys_gpu_kernel_sum
else None
)
if "kineto_trace" in self.required_metrics:
metrics.kineto_trace = self.kineto_trace(input_id, fn)
if "best_config" in self.required_metrics:
Expand Down

0 comments on commit deb9183

Please sign in to comment.