Skip to content

Commit

Permalink
more refactorization and simplification
Browse files Browse the repository at this point in the history
Now have tiled implementations for SIMD16 as well.
  • Loading branch information
bashbaug committed Jan 15, 2024
1 parent feb1064 commit 4e89026
Show file tree
Hide file tree
Showing 4 changed files with 369 additions and 660 deletions.
2 changes: 1 addition & 1 deletion samples/99_matrixexperiments/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,4 @@ add_opencl_sample(
TARGET matrixexperiments
VERSION 120
SOURCES main.cpp
KERNELS matrix_helpers.cl matrix_kernels.cl)
KERNELS matrix_helpers.cl matrix_kernels.cl matrix_kernel_tiled.cl)
118 changes: 72 additions & 46 deletions samples/99_matrixexperiments/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ static float hw_time(cl::Event& event)
{
auto ns = event.getProfilingInfo<CL_PROFILING_COMMAND_END>() -
event.getProfilingInfo<CL_PROFILING_COMMAND_START>();
return ns / 1e9;
return ns / 1e9f;
}

static void go_naive(
Expand All @@ -166,34 +166,38 @@ static void go_naive(
printf("%80s: ", makeTestName(__FUNCTION__, M, N, K).c_str()); fflush(stdout);

cl::Kernel kernel{program, "bfloat16_naive"};
kernel.setArg(0, C);
kernel.setArg(1, A);
kernel.setArg(2, B);
kernel.setArg(3, static_cast<cl_int>(K));

queue.enqueueFillBuffer(C, 0, 0, C_ref.size());

float best = 999.0f;
for (int test = 0; test < testIterations; test++) {
cl::Event event;
auto start = test_clock::now();
queue.enqueueNDRangeKernel(kernel, cl::NullRange,
cl::NDRange{N, M}, cl::NullRange, nullptr, &event);
queue.finish();
auto end = test_clock::now();
std::chrono::duration<float> sw_time = end - start;
auto elapsed = wallclock ? sw_time.count() : hw_time(event);
best = std::min(best, elapsed);
}
auto gops = 2.0 * M * N * K / best / 1e9;
printf("Best in %f seconds (%f gops)\n", best, gops);
if (kernel() == nullptr) {
printf("unsupported.\n");
} else {
kernel.setArg(0, C);
kernel.setArg(1, A);
kernel.setArg(2, B);
kernel.setArg(3, static_cast<cl_int>(K));

if (validate) {
printf("Checking results... "); fflush(stdout);
std::vector<float> C_check(C_ref.size());
queue.enqueueReadBuffer(C, CL_TRUE, 0, C_check.size() * sizeof(C_check[0]), C_check.data());
check_results(M, N, C_check, C_ref);
printf(" done!\n");
queue.enqueueFillBuffer(C, 0, 0, C_ref.size());

float best = 999.0f;
for (int test = 0; test < testIterations; test++) {
cl::Event event;
auto start = test_clock::now();
queue.enqueueNDRangeKernel(kernel, cl::NullRange,
cl::NDRange{N, M}, cl::NullRange, nullptr, &event);
queue.finish();
auto end = test_clock::now();
std::chrono::duration<float> sw_time = end - start;
auto elapsed = wallclock ? sw_time.count() : hw_time(event);
best = std::min(best, elapsed);
}
auto gops = 2.0 * M * N * K / best / 1e9;
printf("Best in %f seconds (%f gops)\n", best, gops);

if (validate) {
printf("Checking results... "); fflush(stdout);
std::vector<float> C_check(C_ref.size());
queue.enqueueReadBuffer(C, CL_TRUE, 0, C_check.size() * sizeof(C_check[0]), C_check.data());
check_results(M, N, C_check, C_ref);
printf(" done!\n");
}
}
}

Expand All @@ -210,7 +214,9 @@ static void go_dpas_rowmajor(
kernelName += "_m" + std::to_string(tM);
kernelName += "_n" + std::to_string(tN);
cl::Kernel kernel{program, kernelName.c_str()};
if (kernel()) {
if (kernel() == nullptr) {
printf("unsupported.\n");
} else {
kernel.setArg(0, C);
kernel.setArg(1, A);
kernel.setArg(2, B);
Expand Down Expand Up @@ -238,8 +244,6 @@ static void go_dpas_rowmajor(
check_results(M, N, C_check, C_ref);
printf(" done!\n");
}
} else {
printf("unsupported.\n");
}
}

Expand All @@ -258,7 +262,13 @@ static void go_dpas_rowmajor_tiled(
kernelName += "_" + std::to_string(MM);
kernelName += "x" + std::to_string(NN);
cl::Kernel kernel{program, kernelName.c_str()};
if (kernel()) {
if (tM * MM > M) {
printf("M is too small.\n");
} else if (tN * NN > N) {
printf("N is too small.\n");
} else if (kernel() == nullptr) {
printf("unsupported.\n");
} else {
kernel.setArg(0, C);
kernel.setArg(1, A);
kernel.setArg(2, B);
Expand Down Expand Up @@ -286,8 +296,6 @@ static void go_dpas_rowmajor_tiled(
check_results(M, N, C_check, C_ref);
printf(" done!\n");
}
} else {
printf("unsupported.\n");
}
}

Expand All @@ -304,7 +312,9 @@ static void go_dpas_vnni(
kernelName += "_m" + std::to_string(tM);
kernelName += "_n" + std::to_string(tN);
cl::Kernel kernel{program, kernelName.c_str()};
if (kernel()) {
if (kernel() == nullptr) {
printf("unsupported.\n");
} else {
kernel.setArg(0, C);
kernel.setArg(1, A);
kernel.setArg(2, B);
Expand Down Expand Up @@ -334,8 +344,6 @@ static void go_dpas_vnni(
check_results(M, N, C_check, C_ref);
printf(" done!\n");
}
} else {
printf("unsupported.\n");
}
}

Expand All @@ -354,7 +362,13 @@ static void go_dpas_vnni_tiled(
kernelName += "_" + std::to_string(MM);
kernelName += "x" + std::to_string(NN);
cl::Kernel kernel{program, kernelName.c_str()};
if (kernel()) {
if (kernel() == nullptr) {
printf("unsupported.\n");
} else if (tM * MM > M) {
printf("M is too small.\n");
} else if (tN * NN > N) {
printf("N is too small.\n");
} else {
kernel.setArg(0, C);
kernel.setArg(1, A);
kernel.setArg(2, B);
Expand Down Expand Up @@ -384,8 +398,6 @@ static void go_dpas_vnni_tiled(
check_results(M, N, C_check, C_ref);
printf(" done!\n");
}
} else {
printf("unsupported.\n");
}
}

Expand All @@ -402,7 +414,9 @@ static void go_dpas_blockread_rowmajor(
kernelName += "_m" + std::to_string(tM);
kernelName += "_n" + std::to_string(tN);
cl::Kernel kernel{program, kernelName.c_str()};
if (kernel()) {
if (kernel() == nullptr) {
printf("unsupported.\n");
} else {
kernel.setArg(0, C);
kernel.setArg(1, A);
kernel.setArg(2, B);
Expand Down Expand Up @@ -430,8 +444,6 @@ static void go_dpas_blockread_rowmajor(
check_results(M, N, C_check, C_ref);
printf(" done!\n");
}
} else {
printf("unsupported.\n");
}
}

Expand All @@ -448,7 +460,9 @@ static void go_dpas_blockread_vnni(
kernelName += "_m" + std::to_string(tM);
kernelName += "_n" + std::to_string(tN);
cl::Kernel kernel{program, kernelName.c_str()};
if (kernel()) {
if (kernel() == nullptr) {
printf("unsupported.\n");
} else {
kernel.setArg(0, C);
kernel.setArg(1, A);
kernel.setArg(2, B);
Expand Down Expand Up @@ -476,8 +490,6 @@ static void go_dpas_blockread_vnni(
check_results(M, N, C_check, C_ref);
printf(" done!\n");
}
} else {
printf("unsupported.\n");
}
}

Expand Down Expand Up @@ -637,11 +649,25 @@ int main(int argc, char** argv)
go_dpas_rowmajor<4, 16, 16>(context, program, queue, C, A, B, M, N, K, C_ref);
go_dpas_rowmajor<8, 16, 16>(context, program, queue, C, A, B, M, N, K, C_ref);

go_dpas_rowmajor_tiled<8, 16, 16, 2, 1>(context, program, queue, C, A, B, M, N, K, C_ref);
go_dpas_rowmajor_tiled<8, 16, 16, 1, 2>(context, program, queue, C, A, B, M, N, K, C_ref);
go_dpas_rowmajor_tiled<8, 16, 16, 2, 2>(context, program, queue, C, A, B, M, N, K, C_ref);
go_dpas_rowmajor_tiled<8, 16, 16, 4, 2>(context, program, queue, C, A, B, M, N, K, C_ref);
go_dpas_rowmajor_tiled<8, 16, 16, 2, 4>(context, program, queue, C, A, B, M, N, K, C_ref);
go_dpas_rowmajor_tiled<8, 16, 16, 4, 4>(context, program, queue, C, A, B, M, N, K, C_ref);

go_dpas_vnni<1, 16, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref);
go_dpas_vnni<2, 16, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref);
go_dpas_vnni<4, 16, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref);
go_dpas_vnni<8, 16, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref);

go_dpas_vnni_tiled<8, 16, 16, 2, 1>(context, program, queue, C, A, Bvnni, M, N, K, C_ref);
go_dpas_vnni_tiled<8, 16, 16, 1, 2>(context, program, queue, C, A, Bvnni, M, N, K, C_ref);
go_dpas_vnni_tiled<8, 16, 16, 2, 2>(context, program, queue, C, A, Bvnni, M, N, K, C_ref);
go_dpas_vnni_tiled<8, 16, 16, 4, 2>(context, program, queue, C, A, Bvnni, M, N, K, C_ref);
go_dpas_vnni_tiled<8, 16, 16, 2, 4>(context, program, queue, C, A, Bvnni, M, N, K, C_ref);
go_dpas_vnni_tiled<8, 16, 16, 4, 4>(context, program, queue, C, A, Bvnni, M, N, K, C_ref);

go_dpas_blockread_rowmajor<1, 16, 16>(context, program, queue, C, A, B, M, N, K, C_ref);
go_dpas_blockread_rowmajor<2, 16, 16>(context, program, queue, C, A, B, M, N, K, C_ref);
go_dpas_blockread_rowmajor<4, 16, 16>(context, program, queue, C, A, B, M, N, K, C_ref);
Expand Down
Loading

0 comments on commit 4e89026

Please sign in to comment.