Adjust fp8 CK GEMM heuristic (pytorch#2912)
Summary: Pull Request resolved: pytorch#2912

Reviewed By: jwfromm

Differential Revision: D60354356

fbshipit-source-id: 108eefa5ddbe266ffe6526d44cc64892487adf02
zjing14 authored and facebook-github-bot committed Jul 31, 2024
1 parent 336a854 commit 20db709
Showing 4 changed files with 128 additions and 43 deletions.
2 changes: 1 addition & 1 deletion fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py
@@ -47,7 +47,7 @@ def bench_with_rotating_buffer(self, fn, args):
 
         # torch.cuda.get_device_properties does not have L2 cache size,
         # so hard code an overapproximation of L2 cache size to ensure L2 cache flush
-        total_buffer_size = 16 * 1024 * 1024
+        total_buffer_size = 10000 * 1024 * 1024
 
         # Use pickle to serialize model input to estimate total sizes of input
         input_sizes = len(pickle.dumps(args))
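The rotating-buffer benchmark keeps each timed call from hitting warm cache lines by cycling through enough independent copies of the inputs to overflow the cache; the hunk above raises the hard-coded flush size from 16 MiB to roughly 10 GB, presumably so the overapproximation also covers the much larger last-level caches on current accelerators. A minimal sketch of the idea, in plain Python with a hypothetical `fn`/`args` and a copy cap added to keep it cheap (not the fbgemm_gpu implementation):

```python
import copy
import pickle
import time


def bench_with_rotating_buffer_sketch(fn, args, flush_bytes=10000 * 1024 * 1024, iters=10):
    """Illustrative sketch: time fn(*args) while rotating through enough
    copies of the inputs that each call reads data that is no longer cached."""
    # Estimate the footprint of one set of inputs via pickle, as the real
    # benchmark does, and derive how many copies would cover the flush buffer.
    input_bytes = len(pickle.dumps(args))
    num_copies = max(1, flush_bytes // max(1, input_bytes))

    # Cap the copy count so this sketch stays cheap; the real benchmark keeps
    # the rotating copies resident on the GPU instead.
    num_copies = min(num_copies, 64)
    rotating_args = [copy.deepcopy(args) for _ in range(num_copies)]

    start = time.perf_counter()
    for i in range(iters):
        fn(*rotating_args[i % num_copies])  # each iteration sees "cold" inputs
    return (time.perf_counter() - start) / iters


# Toy usage: average the runtime of a trivial reduction over rotated inputs.
print(bench_with_rotating_buffer_sketch(lambda xs: sum(xs), ([1.0] * 1024,)))
```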
@@ -16,23 +16,79 @@ fp8_rowwise_256x224x256x128_16x16_7x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave
     at::Tensor w_scale,
     at::Tensor Y) {
   // This kernel works well for many medium to large shapes.
-  using DeviceGemmInstance = DeviceGemmHelper<
-      256,
-      224,
-      256,
-      128,
-      16,
-      16,
-      7,
-      8,
-      S<8, 32, 1>,
-      S<8, 32, 1>,
-      S<1, 32, 1, 8>,
-      S<8, 8, 1>,
-      1,
-      2,
-      ck::BlockGemmPipelineScheduler::Intrawave,
-      ck::BlockGemmPipelineVersion::v3>;
-  // Run kernel instance.
-  return f8f8bf16_rowwise_impl<DeviceGemmInstance>(XQ, WQ, x_scale, w_scale, Y);
+
+  int M = size_to_dim_(XQ.dim() - 1, XQ.sizes());
+  int N = WQ.size(0);
+  int K = WQ.size(1);
+
+  bool mnpad = (M % 224 != 0) || (N % 256 != 0);
+  bool kpad = K % 128 != 0;
+
+  if (kpad) {
+    using DeviceGemmInstance = DeviceGemmHelper<
+        256,
+        224,
+        256,
+        128,
+        16,
+        16,
+        7,
+        8,
+        S<8, 32, 1>,
+        S<8, 32, 1>,
+        S<1, 32, 1, 8>,
+        S<8, 8, 1>,
+        1,
+        2,
+        ck::BlockGemmPipelineScheduler::Intrawave,
+        ck::BlockGemmPipelineVersion::v3,
+        ck::tensor_operation::device::GemmSpecialization::MNKPadding>;
+    // Run kernel instance.
+    return f8f8bf16_rowwise_impl<DeviceGemmInstance>(
+        XQ, WQ, x_scale, w_scale, Y);
+  } else if (mnpad) {
+    using DeviceGemmInstance = DeviceGemmHelper<
+        256,
+        224,
+        256,
+        128,
+        16,
+        16,
+        7,
+        8,
+        S<8, 32, 1>,
+        S<8, 32, 1>,
+        S<1, 32, 1, 8>,
+        S<8, 8, 1>,
+        1,
+        2,
+        ck::BlockGemmPipelineScheduler::Intrawave,
+        ck::BlockGemmPipelineVersion::v3,
+        ck::tensor_operation::device::GemmSpecialization::MNPadding>;
+    // Run kernel instance.
+    return f8f8bf16_rowwise_impl<DeviceGemmInstance>(
+        XQ, WQ, x_scale, w_scale, Y);
+  } else {
+    using DeviceGemmInstance = DeviceGemmHelper<
+        256,
+        224,
+        256,
+        128,
+        16,
+        16,
+        7,
+        8,
+        S<8, 32, 1>,
+        S<8, 32, 1>,
+        S<1, 32, 1, 8>,
+        S<8, 8, 1>,
+        1,
+        2,
+        ck::BlockGemmPipelineScheduler::Intrawave,
+        ck::BlockGemmPipelineVersion::v3,
+        ck::tensor_operation::device::GemmSpecialization::Default>;
+    // Run kernel instance.
+    return f8f8bf16_rowwise_impl<DeviceGemmInstance>(
+        XQ, WQ, x_scale, w_scale, Y);
+  }
 }
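The new dispatch only pays for padding when a dimension actually needs it: the K check comes first because MNKPadding also covers ragged M/N, MNPadding handles the case where only M or N falls off the 224x256 tile, and the unspecialized (Default) instance runs when every dimension is tile-aligned. A sketch of the same decision in Python, with the tile sizes taken from this kernel and the strings mirroring the CK GemmSpecialization names:

```python
def pick_gemm_specialization(M: int, N: int, K: int,
                             tile_m: int = 224, tile_n: int = 256, tile_k: int = 128) -> str:
    """Mirror of the branch above: choose the cheapest CK GemmSpecialization
    that still covers ragged dimensions for this tile shape."""
    kpad = K % tile_k != 0
    mnpad = (M % tile_m != 0) or (N % tile_n != 0)
    if kpad:
        return "MNKPadding"  # K is ragged, so pad everything
    if mnpad:
        return "MNPadding"   # only M and/or N need padding
    return "Default"         # every dimension is tile-aligned


# Example: K = 5120 and N = 2048 divide their tiles, but M = 2048 is not a
# multiple of 224, so only M/N padding is requested.
assert pick_gemm_specialization(2048, 2048, 5120) == "MNPadding"
```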
@@ -15,24 +15,55 @@ fp8_rowwise_256x256x224x128_16x16_8x7_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave
     at::Tensor x_scale,
     at::Tensor w_scale,
     at::Tensor Y) {
+  // Check if this input needs to be padded.
+  int M = size_to_dim_(XQ.dim() - 1, XQ.sizes());
+  int N = WQ.size(0);
+  int K = WQ.size(1);
+  bool pad = (M % 256 != 0) || (N % 224 != 0) || (K % 128 != 0);
+
   // This kernel seems optimal in the most purely compute bound tasks.
-  using DeviceGemmInstance = DeviceGemmHelper<
-      256,
-      256,
-      224,
-      128,
-      16,
-      16,
-      8,
-      7,
-      S<8, 32, 1>,
-      S<8, 32, 1>,
-      S<1, 64, 1, 4>,
-      S<8, 8, 1>,
-      2,
-      1,
-      ck::BlockGemmPipelineScheduler::Intrawave,
-      ck::BlockGemmPipelineVersion::v3>;
-  // Run kernel instance.
-  return f8f8bf16_rowwise_impl<DeviceGemmInstance>(XQ, WQ, x_scale, w_scale, Y);
+  if (pad) {
+    using DeviceGemmInstance = DeviceGemmHelper<
+        256,
+        256,
+        224,
+        128,
+        16,
+        16,
+        8,
+        7,
+        S<8, 32, 1>,
+        S<8, 32, 1>,
+        S<1, 64, 1, 4>,
+        S<8, 8, 1>,
+        2,
+        1,
+        ck::BlockGemmPipelineScheduler::Intrawave,
+        ck::BlockGemmPipelineVersion::v3>;
+    // Run kernel instance.
+    return f8f8bf16_rowwise_impl<DeviceGemmInstance>(
+        XQ, WQ, x_scale, w_scale, Y);
+  } else {
+    using DeviceGemmInstance = DeviceGemmHelper<
+        256,
+        256,
+        224,
+        128,
+        16,
+        16,
+        8,
+        7,
+        S<8, 32, 1>,
+        S<8, 32, 1>,
+        S<1, 64, 1, 4>,
+        S<8, 8, 1>,
+        2,
+        1,
+        ck::BlockGemmPipelineScheduler::Intrawave,
+        ck::BlockGemmPipelineVersion::v3,
+        ck::tensor_operation::device::GemmSpecialization::Default>;
+    // Run kernel instance.
+    return f8f8bf16_rowwise_impl<DeviceGemmInstance>(
+        XQ, WQ, x_scale, w_scale, Y);
+  }
 }
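Both kernels derive M by flattening every leading dimension of XQ: `size_to_dim_(XQ.dim() - 1, XQ.sizes())` amounts to the product of all dimensions except the last, so a batched activation of shape [B, T, K] is treated as a (B*T) x K GEMM operand. A rough Python equivalent, illustrative only:

```python
import math
from typing import Sequence


def size_to_dim(k: int, sizes: Sequence[int]) -> int:
    """Product of sizes[0:k], matching the intent of Caffe2's size_to_dim_."""
    return math.prod(sizes[:k])


# For XQ of shape [B, T, K] = [4, 128, 512] the GEMM sees M = 4 * 128 rows.
assert size_to_dim(2, (4, 128, 512)) == 4 * 128
```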
@@ -101,7 +101,7 @@ template <
     ck::BlockGemmPipelineScheduler LOOP_SCHED,
     ck::BlockGemmPipelineVersion PIPELINE_VERSION,
     ck::tensor_operation::device::GemmSpecialization GEMM_SPEC =
-        ck::tensor_operation::device::GemmSpecialization::MNKPadding>
+        ck::tensor_operation::device::GemmSpecialization::MNPadding>
 using DeviceGemmHelper =
     ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShuffle_V3<
         ALayout,
@@ -175,8 +175,6 @@ at::Tensor f8f8bf16_rowwise_impl(
   auto cde_element_op = CDEElementOp{};
 
   constexpr ck::index_t NumDTensor = ck::Number<2>{};
-  constexpr auto I0 =
-      ck::Number<0>{}; // Used to indicate 0 stride for row and col broadcast.
 
   auto argument = gemm.MakeArgument(
       reinterpret_cast<ADataType*>(XQ.data_ptr()),
@@ -190,7 +188,7 @@ at::Tensor f8f8bf16_rowwise_impl(
       K,
       StrideA,
       StrideB,
-      std::array<ck::index_t, NumDTensor>{I0, I0},
+      std::array<ck::index_t, NumDTensor>{0, 0},
       StrideE,
       a_element_op,
       b_element_op,
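The last two hunks drop the named `I0` constant in favor of literal zeros. The two D tensors here are the row-wise and column-wise fp8 scales (x_scale and w_scale), and a stride of 0 is what lets the epilogue broadcast the per-row scale across each output row and the per-column scale down each output column while it rescales the accumulator. A small NumPy illustration of that broadcast (not CK code; the shapes are made up):

```python
import numpy as np

# Hypothetical sizes for illustration.
M, N = 4, 6
acc = np.random.rand(M, N).astype(np.float32)    # GEMM accumulator for XQ @ WQ^T
x_scale = np.random.rand(M).astype(np.float32)   # one scale per output row
w_scale = np.random.rand(N).astype(np.float32)   # one scale per output column

# Stride-0 broadcast in the epilogue: every element of row m is multiplied by
# x_scale[m], and every element of column n by w_scale[n].
y = acc * x_scale[:, None] * w_scale[None, :]
assert y.shape == (M, N)
```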
