Add FusedLinearCrossEntropy #2485

Closed · 6 commits
12 changes: 12 additions & 0 deletions pyproject.toml
@@ -0,0 +1,12 @@
[build-system]
# Use legacy backend to import local packages in setup.py
build-backend = "setuptools.build_meta:__legacy__"


[tool.black]
line-length = 88
target-version = ["py38"]
exclude = '''/submodules/.*'''

[tool.usort]
excludes = ["**/submodules/**"]
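(For context: `setuptools.build_meta:__legacy__` differs from the standard backend in that it keeps the directory containing setup.py on sys.path during the build, which is what allows setup.py to import local packages.)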
1 change: 1 addition & 0 deletions torchbenchmark/operators/FusedLinearCrossEntropy/__init__.py
@@ -0,0 +1 @@
from .operator import Operator
108 changes: 108 additions & 0 deletions torchbenchmark/operators/FusedLinearCrossEntropy/operator.py
@@ -0,0 +1,108 @@
import argparse
from typing import Callable, Generator, List, Optional

import torch

from torchbenchmark.util.triton_op import BenchmarkOperator, register_benchmark

try:
    from liger_kernel.transformers.fused_linear_cross_entropy import (
        LigerFusedLinearCrossEntropyLoss,
    )
except ModuleNotFoundError:
    LigerFusedLinearCrossEntropyLoss = None

# Reference: https://github.com/linkedin/Liger-Kernel/blob/\
# 3d0653b035222cbb845435a1994854e4fd219107/benchmark/scripts/benchmark_fused_linear_cross_entropy.py


def parse_op_args(args: List[str]):
    parser = argparse.ArgumentParser()
    parser.add_argument("--hidden-size", type=int, default=4096, help="hidden size")
    parser.add_argument("--vocab-size", type=int, default=128256, help="vocab size")
    return parser.parse_args(args)


class TorchLMHeadCE(torch.nn.Module):
    """Ground truth implementation: a plain linear LM head followed by torch's cross entropy loss.

    :param H: hidden size
    :param V: vocab size
    :param dtype: dtype of the linear layer's weight
    :param ignore_index: target index to ignore when computing the loss
    """

    def __init__(self, H: int, V: int, dtype: torch.dtype, ignore_index: int = -100):
        super().__init__()
        self.lin = torch.nn.Linear(
            in_features=H, out_features=V, bias=False, dtype=dtype
        )
        self.ce_loss = torch.nn.CrossEntropyLoss(
            ignore_index=ignore_index, reduction="mean"
        )

    def forward(self, input, target):
        logits = self.lin(input)
        return self.ce_loss(logits, target)


class LigerLMHeadCE(torch.nn.Module):
    """Liger fused linear + cross entropy loss: the fused kernel consumes the
    projection weight directly, so the full [BT, V] logits are never materialized.
    """

    def __init__(self, H: int, V: int, dtype: torch.dtype, ignore_index: int = -100):
        super().__init__()
        self.lin = torch.nn.Linear(
            in_features=H, out_features=V, bias=False, dtype=dtype
        )
        self.ce_loss = LigerFusedLinearCrossEntropyLoss(
            ignore_index=ignore_index, reduction="mean"
        )

    def forward(self, input, target):
        return self.ce_loss(self.lin.weight, input, target)


class Operator(BenchmarkOperator):
    def __init__(
        self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None
    ):
        super().__init__(tb_args, extra_args)
        op_args = parse_op_args(self.extra_args)
        self.hidden_size = op_args.hidden_size
        self.vocab_size = op_args.vocab_size
        self.baseline_model = TorchLMHeadCE(
            H=self.hidden_size, V=self.vocab_size, dtype=self.dtype
        ).to(self.device)
        self.liger_model = LigerLMHeadCE(
            H=self.hidden_size, V=self.vocab_size, dtype=self.dtype
        ).to(self.device)
        self.use_cuda_graphs = False

    def get_input_iter(self) -> Generator:
        # Sweep the flattened (batch * sequence) dimension BT from 2**12 to 2**15.
        for BT in [2**i for i in range(12, 16)]:
            _input = torch.randn(
                BT,
                self.hidden_size,
                requires_grad=True,
                dtype=self.dtype,
                device=self.device,
            )
            target = torch.randint(
                self.vocab_size, (BT, 1), dtype=torch.long, device=self.device
            ).squeeze(1)
            yield _input, target

    @register_benchmark(baseline=True)
    def LMHeadCE(self, input, target) -> Callable:
        return lambda: self.baseline_model(input, target)

    @register_benchmark()
    def LigerLMHeadCE(self, input, target) -> Callable:
        return lambda: self.liger_model(input, target)

    @register_benchmark()
    def inductor_fused_linear_cross_entropy(self, input, target) -> Callable:
        compiled = torch.compile(self.baseline_model, dynamic=False)
        return lambda: compiled(input, target)

    def get_bwd_fn(self, fwd_fn: Callable) -> Callable:
        y = fwd_fn()
        return lambda: y.backward(retain_graph=True)
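For reference, a minimal sketch of the equivalence this benchmark relies on (assumptions: a CUDA device, liger-kernel installed, bf16 weights; sizes mirror the defaults above):

import torch
from liger_kernel.transformers.fused_linear_cross_entropy import (
    LigerFusedLinearCrossEntropyLoss,
)

H, V, BT = 4096, 128256, 4096
lin = torch.nn.Linear(H, V, bias=False, dtype=torch.bfloat16, device="cuda")
x = torch.randn(BT, H, dtype=torch.bfloat16, device="cuda")
target = torch.randint(V, (BT,), dtype=torch.long, device="cuda")

# Baseline path: materialize the full [BT, V] logits, then cross entropy.
eager = torch.nn.CrossEntropyLoss()(lin(x), target)
# Fused path: hand the projection weight and input to the fused loss directly.
fused = LigerFusedLinearCrossEntropyLoss()(lin.weight, x, target)
# Tolerances are a guess for bf16; tighten for fp32.
torch.testing.assert_close(eager, fused, rtol=1e-2, atol=1e-2)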
10 changes: 10 additions & 0 deletions userbenchmark/triton/install.py
@@ -66,6 +66,13 @@ def install_fa3():
    subprocess.check_call(cmd, cwd=str(FA3_PATH.resolve()))


def install_liger():
    # Liger-Kernel declares its own `triton` dependency, which conflicts with
    # the version shipped with pytorch, so install it without dependencies.
    cmd = ["pip", "install", "liger-kernel", "--no-deps"]
    subprocess.check_call(cmd)


def install_tk():
    try:
        from .tk.install import install_tk
@@ -88,6 +95,7 @@ def install_tk():
    )
    parser.add_argument("--jax", action="store_true", help="Install jax nightly")
    parser.add_argument("--tk", action="store_true", help="Install ThunderKittens")
    parser.add_argument("--liger", action="store_true", help="Install Liger-kernel")
    parser.add_argument("--test", action="store_true", help="Run test")
    args = parser.parse_args()
@@ -105,3 +113,5 @@ def install_tk():
        install_jax()
    if args.tk and not args.test:
        install_tk()
    if args.liger and not args.test:
        install_liger()
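With the flag wired up, installing the dependency and running the new operator presumably looks like the following (entry points inferred from the repo layout, not verified here):

python userbenchmark/triton/install.py --liger
python run_benchmark.py triton --op FusedLinearCrossEntropy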