Skip to content

Commit

Permalink
[SYCL][NATIVE_CPU] Fill in any SYCL functions which require mapping t…
Browse files Browse the repository at this point in the history
…o mux (#15592)

Native cpu can make calls to mux builtins such as shuffle which are ABI
compliant but are not what is expected by ock passes. This fixes them up
by remove the vector versions from libnativecpu.cpp and using a pass to
convert from parameters which relate to the ABI to calling the mux
functions with the set interface unaffected by the ABI.

This currently only handle a small number of cases for shuffle such as
when a vector i2 is replaced with double or byval is used. It will be
expanded over time as needed.
  • Loading branch information
coldav authored Oct 4, 2024
1 parent b808427 commit ddd23ad
Show file tree
Hide file tree
Showing 7 changed files with 365 additions and 22 deletions.
15 changes: 1 addition & 14 deletions libdevice/nativecpu_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -296,20 +296,7 @@ DefShuffleINTEL_All(uint8_t, i8, int8_t)
DefShuffleINTEL_All(double, f64, double)
DefShuffleINTEL_All(float, f32, float)

#define DefineShuffleVec(T, N, Sfx, MuxType) \
using vt##T##N = sycl::vec<T, N>::vector_t; \
using vt##MuxType##N = sycl::vec<MuxType, N>::vector_t; \
DefShuffleINTEL_All(vt##T##N, v##N##Sfx, vt##MuxType##N)

#define DefineShuffleVec2to16(Type, Sfx, MuxType) \
DefineShuffleVec(Type, 2, Sfx, MuxType) \
DefineShuffleVec(Type, 4, Sfx, MuxType) \
DefineShuffleVec(Type, 8, Sfx, MuxType) \
DefineShuffleVec(Type, 16, Sfx, MuxType)

DefineShuffleVec2to16(int32_t, i32, int32_t)
DefineShuffleVec2to16(uint32_t, i32, int32_t)
DefineShuffleVec2to16(float, f32, float)
// Vector versions of shuffle are generated by the FixABIBuiltinsSYCLNativeCPU pass

#define Define2ArgForward(Type, Name, Callee)\
DEVICE_EXTERNAL Type Name(Type a, Type b) { return Callee(a,b);}
Expand Down
29 changes: 29 additions & 0 deletions llvm/include/llvm/SYCLLowerIR/FixABIMuxBuiltinsSYCLNativeCPU.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
//===---- FixABIMuxBuiltins.h - Fixup ABI issues with called mux builtins ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Creates calls to shuffle up/down/xor mux builtins taking into account ABI of the
// SYCL functions. For now this only is used for vector variants.
//
//===----------------------------------------------------------------------===//

#pragma once

#include "llvm/IR/Module.h"
#include "llvm/IR/PassManager.h"


namespace llvm {

class FixABIMuxBuiltinsPass final
: public llvm::PassInfoMixin<FixABIMuxBuiltinsPass> {
public:
llvm::PreservedAnalyses run(llvm::Module &, llvm::ModuleAnalysisManager &);
};

} // namespace llvm

2 changes: 1 addition & 1 deletion llvm/lib/SYCLNativeCPUUtils/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ add_llvm_component_library(LLVMSYCLNativeCPUUtils
PrepareSYCLNativeCPU.cpp
RenameKernelSYCLNativeCPU.cpp
ConvertToMuxBuiltinsSYCLNativeCPU.cpp

FixABIMuxBuiltinsSYCLNativeCPU.cpp

ADDITIONAL_HEADER_DIRS
${LLVM_MAIN_INCLUDE_DIR}/llvm/SYCLLowerIR
Expand Down
226 changes: 226 additions & 0 deletions llvm/lib/SYCLNativeCPUUtils/FixABIMuxBuiltinsSYCLNativeCPU.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,226 @@
//===-- FixABIMuxBuiltinsSYCLNativeCPU.cpp - Fixup mux ABI issues ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Creates calls to shuffle up/down/xor mux builtins taking into account ABI of
// the SYCL functions. For now this only is used for vector variants.
//
//===----------------------------------------------------------------------===//

#include <llvm/IR/IRBuilder.h>
#include <llvm/IR/Module.h>
#include <llvm/IR/Type.h>
#include <llvm/SYCLLowerIR/FixABIMuxBuiltinsSYCLNativeCPU.h>

#define DEBUG_TYPE "fix-abi-mux-builtins"

using namespace llvm;

PreservedAnalyses FixABIMuxBuiltinsPass::run(Module &M,
ModuleAnalysisManager &AM) {
bool Changed = false;

// Decide if a function needs updated and if so what parameters need changing,
// as well as the return value
auto FunctionNeedsFixing =
[](Function &F,
llvm::SmallVectorImpl<std::pair<unsigned int, llvm::Type *>> &Updates,
llvm::Type *&RetVal, std::string &MuxFuncNameToCall) {
if (!F.isDeclaration()) {
return false;
}
if (!F.getName().contains("__spirv_SubgroupShuffle")) {
return false;
}
Updates.clear();
auto LIDvPos = F.getName().find("ELIDv");
llvm::StringRef NameToMatch;
if (LIDvPos != llvm::StringRef::npos) {
// Add sizeof ELIDv to get num characters to match against
NameToMatch = F.getName().take_front(LIDvPos + 5);
} else {
return false;
}

unsigned int StartIdx = 0;
unsigned int EndIdx = 1;
if (NameToMatch == "_Z32__spirv_SubgroupShuffleDownINTELIDv") {
MuxFuncNameToCall = "__mux_sub_group_shuffle_down_";
} else if (NameToMatch == "_Z30__spirv_SubgroupShuffleUpINTELIDv") {
MuxFuncNameToCall = "__mux_sub_group_shuffle_up_";
} else if (NameToMatch == "_Z28__spirv_SubgroupShuffleINTELIDv") {
MuxFuncNameToCall = "__mux_sub_group_shuffle_";
EndIdx = 0;
} else if (NameToMatch == "_Z31__spirv_SubgroupShuffleXorINTELIDv") {
MuxFuncNameToCall = "__mux_sub_group_shuffle_xor_";
EndIdx = 0;
} else {
return false;
}

// We need to create the body for this. First we need to find out what
// the first arguments should be
llvm::StringRef RemainingName =
F.getName().drop_front(NameToMatch.size());
std::string MuxFuncTypeStr = "UNKNOWN";

unsigned int VecWidth = 0;
if (RemainingName.consumeInteger(10, VecWidth)) {
return false;
}
if (!RemainingName.consume_front("_")) {
return false;
}

char TypeCh = RemainingName[0];
Type *BaseType = nullptr;
switch (TypeCh) {
case 'a':
case 'h':
BaseType = llvm::Type::getInt8Ty(F.getContext());
MuxFuncTypeStr = "i8";
break;
case 's':
case 't':
BaseType = llvm::Type::getInt16Ty(F.getContext());
MuxFuncTypeStr = "i16";
break;

case 'i':
case 'j':
BaseType = llvm::Type::getInt32Ty(F.getContext());
MuxFuncTypeStr = "i32";
break;
case 'l':
case 'm':
BaseType = llvm::Type::getInt64Ty(F.getContext());
MuxFuncTypeStr = "i64";
break;
case 'f':
BaseType = llvm::Type::getFloatTy(F.getContext());
MuxFuncTypeStr = "f32";
break;
case 'd':
BaseType = llvm::Type::getDoubleTy(F.getContext());
MuxFuncTypeStr = "f64";
break;
default:
return false;
}
auto *VecType = llvm::FixedVectorType::get(BaseType, VecWidth);
RetVal = VecType;

// Work out the mux function to call's type extension based on v##N##Sfx
MuxFuncNameToCall += "v";
MuxFuncNameToCall += std::to_string(VecWidth);
MuxFuncNameToCall += MuxFuncTypeStr;

unsigned int CurrentIndex = 0;
for (auto &Arg : F.args()) {
if (Arg.hasStructRetAttr()) {
StartIdx++;
EndIdx++;
} else {
if (CurrentIndex >= StartIdx && CurrentIndex <= EndIdx) {
if (Arg.getType() != VecType) {
Updates.push_back(std::pair<unsigned int, llvm::Type *>(
CurrentIndex, VecType));
}
}
}
CurrentIndex++;
}
return true;
};

llvm::SmallVector<Function *, 4> FuncsToProcess;
for (auto &F : M.functions()) {
FuncsToProcess.push_back(&F);
}

for (auto *F : FuncsToProcess) {
llvm::SmallVector<std::pair<unsigned int, llvm::Type *>, 4> ArgUpdates;
llvm::Type *RetType = nullptr;
std::string MuxFuncNameToCall;
if (!FunctionNeedsFixing(*F, ArgUpdates, RetType, MuxFuncNameToCall)) {
continue;
}
if (!F->isDeclaration()) {
continue;
}
Changed = true;
IRBuilder<> IR(BasicBlock::Create(F->getContext(), "", F));

llvm::SmallVector<Type *, 8> Args;
unsigned int ArgIndex = 0;
unsigned int UpdateIndex = 0;

for (auto &Arg : F->args()) {
if (!Arg.hasStructRetAttr()) {
if (UpdateIndex < ArgUpdates.size() &&
std::get<0>(ArgUpdates[UpdateIndex]) == ArgIndex) {
Args.push_back(std::get<1>(ArgUpdates[UpdateIndex]));
UpdateIndex++;
} else {
Args.push_back(Arg.getType());
}
}
ArgIndex++;
}

FunctionType *FT = FunctionType::get(RetType, Args, false);
Function *NewFunc =
Function::Create(FT, F->getLinkage(), MuxFuncNameToCall, M);
llvm::SmallVector<Value *, 8> CallArgs;
auto NewFuncArgItr = NewFunc->args().begin();
Argument *SretPtr = nullptr;
for (auto &Arg : F->args()) {
if (Arg.hasStructRetAttr()) {
SretPtr = &Arg;
} else {
if (Arg.getType() != (*NewFuncArgItr).getType()) {
if (Arg.getType()->isPointerTy()) {
Value *ArgLoad = IR.CreateLoad((*NewFuncArgItr).getType(), &Arg);
CallArgs.push_back(ArgLoad);
} else {
Value *ArgCast = IR.CreateBitCast(&Arg, (*NewFuncArgItr).getType());
CallArgs.push_back(ArgCast);
}
} else {
CallArgs.push_back(&Arg);
}
NewFuncArgItr++;
}
}

Value *Res = IR.CreateCall(NewFunc, CallArgs);
// If the return type is different to the initial function, then bitcast it
// unless it's void in which case we'd expect an StructRet parameter which
// needs stored to.
if (F->getReturnType() != RetType) {
if (F->getReturnType()->isVoidTy()) {
// If we don't have an StructRet parameter then something is wrong with
// the initial function
if (!SretPtr) {
llvm_unreachable(
"No struct ret pointer for Sub group shuffle function");
}

IR.CreateStore(Res, SretPtr);
} else {
Res = IR.CreateBitCast(Res, F->getReturnType());
}
}
if (F->getReturnType()->isVoidTy()) {
IR.CreateRetVoid();
} else {
IR.CreateRet(Res);
}
}

return Changed ? PreservedAnalyses::none() : PreservedAnalyses::all();
}
2 changes: 2 additions & 0 deletions llvm/lib/SYCLNativeCPUUtils/PipelineSYCLNativeCPU.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
//
//===----------------------------------------------------------------------===//
#include "llvm/SYCLLowerIR/ConvertToMuxBuiltinsSYCLNativeCPU.h"
#include "llvm/SYCLLowerIR/FixABIMuxBuiltinsSYCLNativeCPU.h"
#include "llvm/SYCLLowerIR/PrepareSYCLNativeCPU.h"
#include "llvm/SYCLLowerIR/RenameKernelSYCLNativeCPU.h"
#include "llvm/SYCLLowerIR/SpecConstants.h"
Expand Down Expand Up @@ -65,6 +66,7 @@ void llvm::sycl::utils::addSYCLNativeCPUBackendPasses(
MPM.addPass(ConvertToMuxBuiltinsSYCLNativeCPUPass());
#ifdef NATIVECPU_USE_OCK
MPM.addPass(compiler::utils::TransferKernelMetadataPass());
MPM.addPass(FixABIMuxBuiltinsPass());
// Always enable vectorizer, unless explictly disabled or -O0 is set.
if (OptLevel != OptimizationLevel::O0 && !SYCLNativeCPUNoVecz) {
MAM.registerPass([] { return vecz::TargetInfoAnalysis(); });
Expand Down
18 changes: 11 additions & 7 deletions llvm/lib/SYCLNativeCPUUtils/PrepareSYCLNativeCPU.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -464,13 +464,17 @@ PreservedAnalyses PrepareSYCLNativeCPUPass::run(Module &M,
F->eraseFromParent();
ModuleChanged = true;
}
for (auto It = M.begin(); It != M.end();) {
auto Curr = It++;
Function &F = *Curr;
if (F.getNumUses() == 0 && F.isDeclaration() &&
F.getName().starts_with("__mux_")) {
F.eraseFromParent();
ModuleChanged = true;

// We do these twice because we create abi wrappers for mux which may show up
// before we've removed their user
for (unsigned int i = 0; i < 2; i++) {
for (auto It = M.begin(); It != M.end();) {
auto Curr = It++;
Function &F = *Curr;
if (F.getNumUses() == 0 && F.getName().starts_with("__mux_")) {
F.eraseFromParent();
ModuleChanged = true;
}
}
}

Expand Down
Loading

0 comments on commit ddd23ad

Please sign in to comment.