Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -365,6 +365,10 @@ install(EXPORT RaptorTargets
DESTINATION "${INSTALL_CMAKE_DIR}"
COMPONENT dev)

install(PROGRAMS
"${CMAKE_CURRENT_SOURCE_DIR}/scripts/raptor_plot_float_histogram.py"
DESTINATION bin)

add_subdirectory(runtime)
add_subdirectory(test)
add_subdirectory(wrappers)
68 changes: 57 additions & 11 deletions pass/Raptor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
#include <memory>

#include "llvm/Analysis/ScalarEvolution.h"
#include "llvm/Transforms/Utils/Instrumentation.h"
#include "llvm/Transforms/Utils/ScalarEvolutionExpander.h"

#include "llvm/ADT/ArrayRef.h"
Expand Down Expand Up @@ -567,6 +568,43 @@ class RaptorBase {
llvm_unreachable("Unknown float type");
}

bool HandleLogFlops(CallInst *CI) {
IRBuilder<> Builder(CI);
Function *F = parseFunctionParameter(CI);
if (!F)
return false;
unsigned ArgNum = CI->arg_size();
if (ArgNum != 1) {
EmitFailure("TooManyArgs", CI->getDebugLoc(), CI,
"Had incorrect number of args ", *CI, " - expected 1");
return false;
}

RequestContext context(CI, &Builder);
for (auto FR :
{/* FloatRepresentation::getIEEE(16), */ FloatRepresentation::getIEEE(
32),
FloatRepresentation::getIEEE(64)}) {
FunctionType *Ty = FunctionType::get(
Builder.getVoidTy(), FR.getMustBeBuiltinType(Builder.getContext()),
false);
FunctionCallee FlopLogger = CI->getModule()->getOrInsertFunction(
std::string(RaptorPrefix) + "log_flops_" + FR.getMangling(), Ty);
llvm::Value *Res = Logic.CreateTruncateFunc(
context, F,
TruncationConfiguration::getInitialLogFlops(
FR, *cast<Function>(FlopLogger.getCallee())));
if (!Res)
return false;
F = cast<Function>(Res);
}
llvm::Value *Res = Builder.CreatePointerCast(F, CI->getType());

CI->replaceAllUsesWith(Res);
CI->eraseFromParent();
return true;
}

bool HandleTruncateFunc(CallInst *CI, TruncateMode Mode) {
IRBuilder<> Builder(CI);
Function *F = parseFunctionParameter(CI);
Expand All @@ -583,7 +621,8 @@ class RaptorBase {

RequestContext context(CI, &Builder);
llvm::Value *res = Logic.CreateTruncateFunc(
context, F, TruncationConfiguration::getInitial(Truncation, Mode));
context, F,
TruncationConfiguration::getInitial(Truncation, Builder.getContext()));
if (!res)
return false;
res = Builder.CreatePointerCast(res, CI->getType());
Expand Down Expand Up @@ -718,7 +757,7 @@ class RaptorBase {
Function *TruncatedFunc =
Logic.CreateTruncateFunc(context, &F,
TruncationConfiguration::getInitial(
Truncation, TruncOpFullModuleMode));
Truncation, Builder.getContext()));

ValueToValueMapTy Mapping;
for (auto &&[Arg, TArg] : llvm::zip(F.args(), TruncatedFunc->args()))
Expand Down Expand Up @@ -790,6 +829,7 @@ class RaptorBase {
Changed = true;
}

SmallVector<CallInst *, 4> toLogFlops;
SmallVector<CallInst *, 4> toTruncateFuncMem;
SmallVector<CallInst *, 4> toTruncateFuncOp;
SmallVector<CallInst *, 4> toTruncateValue;
Expand Down Expand Up @@ -1017,11 +1057,15 @@ class RaptorBase {
}

bool enableRaptor = false;
bool logFlops = false;
bool truncateFuncOp = false;
bool truncateFuncMem = false;
bool truncateValue = false;
bool expandValue = false;
if (false) {
} else if (Fn->getName().contains("__raptor_log_flops")) {
enableRaptor = true;
logFlops = true;
} else if (Fn->getName().contains("__raptor_truncate_mem_func")) {
enableRaptor = true;
truncateFuncMem = true;
Expand Down Expand Up @@ -1073,7 +1117,11 @@ class RaptorBase {
}
goto retry;
}
if (truncateFuncOp)
if (false)
abort();
else if (logFlops)
toLogFlops.push_back(CI);
else if (truncateFuncOp)
toTruncateFuncOp.push_back(CI);
else if (truncateFuncMem)
toTruncateFuncMem.push_back(CI);
Expand All @@ -1095,18 +1143,16 @@ class RaptorBase {
}
}

for (auto call : toTruncateFuncMem) {
for (auto call : toLogFlops)
HandleLogFlops(call);
for (auto call : toTruncateFuncMem)
HandleTruncateFunc(call, TruncMemMode);
}
for (auto call : toTruncateFuncOp) {
for (auto call : toTruncateFuncOp)
HandleTruncateFunc(call, TruncOpMode);
}
for (auto call : toTruncateValue) {
for (auto call : toTruncateValue)
HandleTruncateValue(call, true);
}
for (auto call : toExpandValue) {
for (auto call : toExpandValue)
HandleTruncateValue(call, false);
}

return Changed;
}
Expand Down
88 changes: 42 additions & 46 deletions pass/RaptorLogic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@
using namespace llvm;

static Value *floatValTruncate(IRBuilderBase &B, Value *v,
FloatTruncation truncation) {
TruncationConfiguration truncation) {
if (truncation.isToFPRT())
return v;

Expand All @@ -73,7 +73,7 @@ static Value *floatValTruncate(IRBuilderBase &B, Value *v,
}

static Value *floatValExpand(IRBuilderBase &B, Value *v,
FloatTruncation truncation) {
TruncationConfiguration truncation) {
if (truncation.isToFPRT())
return v;

Expand All @@ -84,41 +84,36 @@ static Value *floatValExpand(IRBuilderBase &B, Value *v,
}

static Value *floatMemTruncate(IRBuilderBase &B, Value *v,
FloatTruncation truncation) {
if (isa<VectorType>(v->getType()))
report_fatal_error("vector operations not allowed in mem trunc mode");

Type *toTy = truncation.getToType(B.getContext());
return B.CreateBitCast(v, toTy);
TruncationConfiguration truncation) {
return v;
}

static Value *floatMemExpand(IRBuilderBase &B, Value *v,
FloatTruncation truncation) {
if (isa<VectorType>(v->getType()))
report_fatal_error("vector operations not allowed in mem trunc mode");

Type *fromTy = truncation.getFromType(B.getContext());
return B.CreateBitCast(v, fromTy);
TruncationConfiguration truncation) {
return v;
}

class TruncateUtils {
protected:
FloatTruncation truncation;
TruncationConfiguration TC;
llvm::Module *M;
Type *fromType;
Type *toType;
LLVMContext &ctx;
RaptorLogic &Logic;
Value *UnknownLoc;
Value *scratch = nullptr;
CustomArgsTy CustomArgs;
std::string RTName;

private:
std::string getOriginalFPRTName(std::string Name) {
return std::string(RaptorFPRTOriginalPrefix) + truncation.mangleFrom() +
return std::string(RaptorPrefix) + RTName + "_original_" + TC.mangleFrom() +
"_" + Name;
}
std::string getFPRTName(std::string Name) {
return std::string(RaptorFPRTPrefix) + truncation.mangleFrom() + "_" + Name;
return std::string(RaptorPrefix) + RTName + "_" + TC.mangleFrom() + "_" +
Name;
}

// Creates a function which contains the original floating point operation.
Expand Down Expand Up @@ -169,9 +164,7 @@ class TruncateUtils {
const SmallVectorImpl<Value *> &ArgsIn,
llvm::Type *RetTy, Value *LocStr) {
SmallVector<Value *, 5> Args(ArgsIn.begin(), ArgsIn.end());
Args.push_back(B.getInt64(truncation.getTo().getExponentWidth()));
Args.push_back(B.getInt64(truncation.getTo().getSignificandWidth()));
Args.push_back(B.getInt64(truncation.getMode()));
Args.append(CustomArgs);
Args.push_back(LocStr);
Args.push_back(scratch);

Expand All @@ -189,11 +182,11 @@ class TruncateUtils {
return CI;
}

TruncateUtils(FloatTruncation truncation, Module *M, RaptorLogic &Logic)
: truncation(truncation), M(M), ctx(M->getContext()), Logic(Logic) {
fromType = truncation.getFromType(ctx);
toType = truncation.getToType(ctx);

TruncateUtils(TruncationConfiguration TC, Module *M, RaptorLogic &Logic)
: TC(TC), M(M), ctx(M->getContext()), Logic(Logic),
CustomArgs(TC.CustomArgs), RTName(TC.RTName) {
fromType = TC.getFromType(M->getContext());
toType = TC.getToType(M->getContext());
UnknownLoc = getUniquedLocStr(nullptr);
scratch = ConstantPointerNull::get(PointerType::get(M->getContext(), 0));
}
Expand Down Expand Up @@ -401,7 +394,7 @@ class TruncateGenerator : public llvm::InstVisitor<TruncateGenerator>,
public TruncateUtils {
private:
ValueToValueMapTy &OriginalToNewFn;
FloatTruncation Truncation;
TruncationConfiguration TC;
TruncateMode Mode;
RaptorLogic &Logic;
LLVMContext &Ctx;
Expand All @@ -410,9 +403,9 @@ class TruncateGenerator : public llvm::InstVisitor<TruncateGenerator>,
TruncateGenerator(ValueToValueMapTy &originalToNewFn, Function *oldFunc,
Function *newFunc, RaptorLogic &Logic,
TruncationConfiguration TC)
: TruncateUtils(TC.Truncation, newFunc->getParent(), Logic),
OriginalToNewFn(originalToNewFn), Truncation(TC.Truncation),
Mode(Truncation.getMode()), Logic(Logic), Ctx(newFunc->getContext()) {
: TruncateUtils(TC, newFunc->getParent(), Logic),
OriginalToNewFn(originalToNewFn), TC(TC), Mode(TC.getMode()),
Logic(Logic), Ctx(newFunc->getContext()) {

auto AllocScratch = [&]() {
// TODO we should check at the end if we never used the scracth we should
Expand Down Expand Up @@ -444,7 +437,7 @@ class TruncateGenerator : public llvm::InstVisitor<TruncateGenerator>,
}
}
};
if (Truncation.isToFPRT()) {
if (TC.isToFPRT()) {
if (Mode == TruncOpMode) {
if (TC.NeedTruncChange || TC.NeedNewScratch)
AllocScratch();
Expand Down Expand Up @@ -503,21 +496,21 @@ class TruncateGenerator : public llvm::InstVisitor<TruncateGenerator>,
case TruncMemMode:
if (isa<ConstantFP>(v))
return createFPRTConstCall(B, v);
return floatMemTruncate(B, v, Truncation);
return floatMemTruncate(B, v, TC);
case TruncOpMode:
case TruncOpFullModuleMode:
return floatValTruncate(B, v, Truncation);
return floatValTruncate(B, v, TC);
}
llvm_unreachable("Unknown trunc mode");
}

Value *expand(IRBuilder<> &B, Value *v) {
switch (Mode) {
case TruncMemMode:
return floatMemExpand(B, v, Truncation);
return floatMemExpand(B, v, TC);
case TruncOpMode:
case TruncOpFullModuleMode:
return floatValExpand(B, v, Truncation);
return floatValExpand(B, v, TC);
}
llvm_unreachable("Unknown trunc mode");
}
Expand All @@ -527,7 +520,7 @@ class TruncateGenerator : public llvm::InstVisitor<TruncateGenerator>,
case UnaryOperator::FNeg: {
if (I.getOperand(0)->getType() != getFromType())
return;
if (!Truncation.isToFPRT())
if (!TC.isToFPRT())
return;

auto newI = getNewFromOriginal(&I);
Expand Down Expand Up @@ -565,7 +558,7 @@ class TruncateGenerator : public llvm::InstVisitor<TruncateGenerator>,
Args.push_back(truncLHS);
Args.push_back(truncRHS);
Instruction *nres;
if (Truncation.isToFPRT())
if (TC.isToFPRT())
nres = createFPRTOpCall(B, CI, B.getInt1Ty(), Args);
else
nres =
Expand Down Expand Up @@ -685,9 +678,9 @@ class TruncateGenerator : public llvm::InstVisitor<TruncateGenerator>,
auto newLHS = truncate(B, getNewFromOriginal(oldLHS));
auto newRHS = truncate(B, getNewFromOriginal(oldRHS));
Instruction *nres = nullptr;
if (Truncation.isToFPRT()) {
if (TC.isToFPRT()) {
SmallVector<Value *, 2> Args({newLHS, newRHS});
nres = createFPRTOpCall(B, BO, Truncation.getToType(Ctx), Args);
nres = createFPRTOpCall(B, BO, getToType(), Args);
} else {
nres = cast<Instruction>(B.CreateBinOp(BO.getOpcode(), newLHS, newRHS));
}
Expand Down Expand Up @@ -748,7 +741,7 @@ class TruncateGenerator : public llvm::InstVisitor<TruncateGenerator>,

Instruction *intr = nullptr;
Value *nres = nullptr;
if (Truncation.isToFPRT()) {
if (TC.isToFPRT()) {
nres = intr = createFPRTOpCall(B, CI, retTy, new_ops);
} else {
// TODO check that the intrinsic is overloaded
Expand Down Expand Up @@ -832,11 +825,13 @@ class TruncateGenerator : public llvm::InstVisitor<TruncateGenerator>,
}

Value *GetShadow(RequestContext &ctx, Value *v, bool WillPassScratch) {
if (auto F = dyn_cast<Function>(v))
return Logic.CreateTruncateFunc(
ctx, F,
TruncationConfiguration{Truncation, Mode, !WillPassScratch, false,
WillPassScratch});
if (auto F = dyn_cast<Function>(v)) {
auto NewTC = TC;
NewTC.NeedNewScratch = !WillPassScratch;
NewTC.NeedTruncChange = false;
NewTC.ScratchFromArgs = WillPassScratch;
return Logic.CreateTruncateFunc(ctx, F, NewTC);
}
llvm::errs() << " unknown get truncated func: " << *v << "\n";
llvm_unreachable("unknown get truncated func");
return v;
Expand Down Expand Up @@ -1006,8 +1001,9 @@ bool RaptorLogic::CreateTruncateValue(RequestContext context, Value *v,
IRBuilderBase &B = *context.ip;

Value *converted = nullptr;
TruncateUtils TU(Truncation, B.GetInsertBlock()->getParent()->getParent(),
*this);
TruncateUtils TU(
TruncationConfiguration::getInitial(Truncation, v->getContext()),
B.GetInsertBlock()->getParent()->getParent(), *this);
if (isTruncate)
converted = TU.createFPRTNewCall(B, v);
else
Expand Down
Loading