From 2b0244616aad9bc2eb4006c005d41893828ffadd Mon Sep 17 00:00:00 2001 From: NNVM Authors <xxs_chy@outlook.com> Date: Sun, 7 Jul 2024 16:33:38 +0000 Subject: [PATCH] sync Sun Jul 7 16:33:38 UTC 2024 --- src/nnvm/nnvm/ADT/Ranges.h | 23 +++-- src/nnvm/nnvm/Analysis/AAInfo.h | 6 ++ src/nnvm/nnvm/Analysis/AliasAnalysis.cpp | 5 ++ src/nnvm/nnvm/Analysis/AliasAnalysis.h | 4 + src/nnvm/nnvm/Analysis/BasicAA.cpp | 9 ++ src/nnvm/nnvm/Analysis/BasicAA.h | 13 +++ src/nnvm/nnvm/Backend/RISCV/CodeLayoutOpt.cpp | 22 +++++ src/nnvm/nnvm/Backend/RISCV/CodeLayoutOpt.h | 12 +++ src/nnvm/nnvm/Backend/RISCV/CodegenInfo.cpp | 2 + src/nnvm/nnvm/Backend/RISCV/LowIR.cpp | 3 + src/nnvm/nnvm/Backend/RISCV/LowInstType.h | 2 + src/nnvm/nnvm/Backend/RISCV/PhiResolution.cpp | 10 ++- src/nnvm/nnvm/Backend/RISCV/Spiller.h | 7 -- src/nnvm/nnvm/Frontend/IRGenerator.cpp | 2 + src/nnvm/nnvm/IR/BasicBlock.h | 15 ++++ src/nnvm/nnvm/IR/GlobalVariable.cpp | 8 +- src/nnvm/nnvm/IR/GlobalVariable.h | 4 + src/nnvm/nnvm/IR/Instruction.h | 1 + src/nnvm/nnvm/Transform/Infra/BlockUtils.cpp | 11 +++ src/nnvm/nnvm/Transform/Infra/BlockUtils.h | 8 ++ .../nnvm/Transform/Scalar/CFGCombiner.cpp | 84 +++++++++++++++++-- src/nnvm/nnvm/Transform/Scalar/CFGCombiner.h | 2 + src/nnvm/nnvm/Transform/Scalar/CSE.cpp | 6 +- .../nnvm/Transform/Scalar/CombinePatterns.h | 62 +++++++++++++- src/nnvm/nnvm/Transform/Scalar/Combiner.cpp | 28 ++++++- src/nnvm/nnvm/Transform/Scalar/Combiner.h | 2 + .../nnvm/Transform/Scalar/ConstantFold.cpp | 13 +++ src/nnvm/nnvm/Transform/Scalar/ConstantFold.h | 1 + timestamp | 2 +- 29 files changed, 326 insertions(+), 41 deletions(-) create mode 100644 src/nnvm/nnvm/Analysis/AAInfo.h create mode 100644 src/nnvm/nnvm/Analysis/BasicAA.cpp create mode 100644 src/nnvm/nnvm/Analysis/BasicAA.h create mode 100644 src/nnvm/nnvm/Backend/RISCV/CodeLayoutOpt.cpp create mode 100644 src/nnvm/nnvm/Backend/RISCV/CodeLayoutOpt.h delete mode 100644 src/nnvm/nnvm/Backend/RISCV/Spiller.h create mode 100644 src/nnvm/nnvm/Transform/Infra/BlockUtils.cpp create mode 100644 src/nnvm/nnvm/Transform/Infra/BlockUtils.h diff --git a/src/nnvm/nnvm/ADT/Ranges.h b/src/nnvm/nnvm/ADT/Ranges.h index 620ce8b..1151667 100644 --- a/src/nnvm/nnvm/ADT/Ranges.h +++ b/src/nnvm/nnvm/ADT/Ranges.h @@ -68,21 +68,18 @@ template <typename Range> IncChangeRange<Range> incChange(Range &range) { return IncChangeRange<Range>(range); } -template <typename T> struct reversion_wrapper { - T &iterable; -}; +template <typename Iter> struct RangeWrapper { +public: + Iter begin() { return beginIter; } + Iter end() { return endIter; } -template <typename T> auto begin(reversion_wrapper<T> w) { - auto begin = std::begin(w.iterable); - return std::reverse_iterator<decltype(begin)>(begin); -} +private: + Iter beginIter; + Iter endIter; +}; -template <typename T> auto end(reversion_wrapper<T> w) { - auto end = std::end(w.iterable); - return std::reverse_iterator<decltype(end)>(end); +template <typename Iter> RangeWrapper<Iter> makeRange(Iter begin, Iter end) { + return RangeWrapper<Iter>(begin, end); } -template <typename T> reversion_wrapper<T> reverseRange(T &&iterable) { - return {iterable}; -} } /* namespace nnvm */ diff --git a/src/nnvm/nnvm/Analysis/AAInfo.h b/src/nnvm/nnvm/Analysis/AAInfo.h new file mode 100644 index 0000000..744be87 --- /dev/null +++ b/src/nnvm/nnvm/Analysis/AAInfo.h @@ -0,0 +1,6 @@ +#pragma once + +namespace nnvm { +enum AAFlag { MayAlias, MustAlias, NotAlias }; + +} /* namespace nnvm */ diff --git a/src/nnvm/nnvm/Analysis/AliasAnalysis.cpp b/src/nnvm/nnvm/Analysis/AliasAnalysis.cpp index e69de29..e2b9238 100644 --- a/src/nnvm/nnvm/Analysis/AliasAnalysis.cpp +++ b/src/nnvm/nnvm/Analysis/AliasAnalysis.cpp @@ -0,0 +1,5 @@ +#include "AliasAnalysis.h" + +using namespace nnvm; + +AAFlag AliasAnalysis::alias(Value *a, Value *b) { return MayAlias; } diff --git a/src/nnvm/nnvm/Analysis/AliasAnalysis.h b/src/nnvm/nnvm/Analysis/AliasAnalysis.h index 317f2e9..f1d12f7 100644 --- a/src/nnvm/nnvm/Analysis/AliasAnalysis.h +++ b/src/nnvm/nnvm/Analysis/AliasAnalysis.h @@ -5,16 +5,20 @@ #pragma once +#include "Analysis/AAInfo.h" #include "IR/Instruction.h" #include "Transform/Infra/Pass.h" #include <unordered_map> #include <vector> namespace nnvm { + class AliasAnalysis : public FunctionPass { public: bool run(Function &F); + AAFlag alias(Value *a, Value *b); + private: }; } /* namespace nnvm */ diff --git a/src/nnvm/nnvm/Analysis/BasicAA.cpp b/src/nnvm/nnvm/Analysis/BasicAA.cpp new file mode 100644 index 0000000..5a3e181 --- /dev/null +++ b/src/nnvm/nnvm/Analysis/BasicAA.cpp @@ -0,0 +1,9 @@ +#include "BasicAA.h" +#include "Analysis/AAInfo.h" +using namespace nnvm; + +AAFlag BasicAA::alias(Value *a, Value *b) { + if (a == b) + return MustAlias; + return MayAlias; +} diff --git a/src/nnvm/nnvm/Analysis/BasicAA.h b/src/nnvm/nnvm/Analysis/BasicAA.h new file mode 100644 index 0000000..5fa4bb9 --- /dev/null +++ b/src/nnvm/nnvm/Analysis/BasicAA.h @@ -0,0 +1,13 @@ +#pragma once + +#include "Analysis/AAInfo.h" +#include "IR/Value.h" + +namespace nnvm { + +class BasicAA { +public: + AAFlag alias(Value *a, Value *b); +}; + +} /* namespace nnvm */ diff --git a/src/nnvm/nnvm/Backend/RISCV/CodeLayoutOpt.cpp b/src/nnvm/nnvm/Backend/RISCV/CodeLayoutOpt.cpp new file mode 100644 index 0000000..b420ed8 --- /dev/null +++ b/src/nnvm/nnvm/Backend/RISCV/CodeLayoutOpt.cpp @@ -0,0 +1,22 @@ +#include "CodeLayoutOpt.h" +#include "ADT/GenericInt.h" +#include "ADT/Ranges.h" +#include "Backend/RISCV/CodegenInfo.h" +#include "Backend/RISCV/LowIR.h" +#include "Backend/RISCV/LowIR/LIRValue.h" +#include "Backend/RISCV/LowInstType.h" +#include "Backend/RISCV/PhiResolution.h" +#include "IR/Instruction.h" +#include "Utils/Cast.h" +#include "Utils/Debug.h" +#include <utility> +using namespace nnvm; +using namespace nnvm::riscv; + +void CodeLayoutOpt::layout(LIRFunc &func) { + std::list<LIRBB *> result; + std::unordered_set<LIRBB *> visited; + + std::stack<LIRBB *> worklist; + worklist.push(func.getEntry()); +} diff --git a/src/nnvm/nnvm/Backend/RISCV/CodeLayoutOpt.h b/src/nnvm/nnvm/Backend/RISCV/CodeLayoutOpt.h new file mode 100644 index 0000000..1961582 --- /dev/null +++ b/src/nnvm/nnvm/Backend/RISCV/CodeLayoutOpt.h @@ -0,0 +1,12 @@ + +#pragma once +#include "Backend/RISCV/LowIR.h" +#include "Backend/RISCV/LowIR/Builder.h" +namespace nnvm::riscv { + +class CodeLayoutOpt { +public: + void layout(LIRFunc &func); +}; + +} /* namespace nnvm::riscv */ diff --git a/src/nnvm/nnvm/Backend/RISCV/CodegenInfo.cpp b/src/nnvm/nnvm/Backend/RISCV/CodegenInfo.cpp index 40d06e6..711f836 100644 --- a/src/nnvm/nnvm/Backend/RISCV/CodegenInfo.cpp +++ b/src/nnvm/nnvm/Backend/RISCV/CodegenInfo.cpp @@ -70,6 +70,7 @@ LIRBB *riscv::getBranchDest(LIRInst *inst) { case (B_BEGIN + 1)...(B_END - 1): return inst->getOp(2)->as<LIRBB>(); default: + std::cerr << "Invalid opcode" << inst->getOpcode() << "\n"; nnvm_unreachable("Must be a valid branch instruction") } } @@ -83,6 +84,7 @@ void riscv::setBranchDest(LIRInst *inst, LIRBB *dest) { inst->setUse(2, dest); break; default: + std::cerr << "Invalid opcode" << inst->getOpcode() << "\n"; nnvm_unreachable("Must be a valid branch instruction") } } diff --git a/src/nnvm/nnvm/Backend/RISCV/LowIR.cpp b/src/nnvm/nnvm/Backend/RISCV/LowIR.cpp index 8130ae9..82d1073 100644 --- a/src/nnvm/nnvm/Backend/RISCV/LowIR.cpp +++ b/src/nnvm/nnvm/Backend/RISCV/LowIR.cpp @@ -11,6 +11,9 @@ using namespace nnvm::riscv; void LIRInst::emit(std::ostream &out, EmitInfo &info) { switch (type) { + case IMPLICIT_JUMP: + out << ""; + break; case SB: case SH: case SW: diff --git a/src/nnvm/nnvm/Backend/RISCV/LowInstType.h b/src/nnvm/nnvm/Backend/RISCV/LowInstType.h index ba987c3..302d19f 100644 --- a/src/nnvm/nnvm/Backend/RISCV/LowInstType.h +++ b/src/nnvm/nnvm/Backend/RISCV/LowInstType.h @@ -53,6 +53,8 @@ namespace nnvm::riscv { enum LIRInstID : uint64_t { // ==== Middle IR Reserved ==== // ..... + // ==== Generic ==== + IMPLICIT_JUMP, // ==== RISC-V Specific ==== ISA_BEGIN = (uint64_t)InstID::INST_END + 1, NONE, diff --git a/src/nnvm/nnvm/Backend/RISCV/PhiResolution.cpp b/src/nnvm/nnvm/Backend/RISCV/PhiResolution.cpp index a13e4a5..dba594b 100644 --- a/src/nnvm/nnvm/Backend/RISCV/PhiResolution.cpp +++ b/src/nnvm/nnvm/Backend/RISCV/PhiResolution.cpp @@ -29,10 +29,14 @@ void PhiResolution::processBB(LIRBB *BB) { for (uint64_t i = 1; i < firstPhi->getNumOp(); i += 2) { LIRBB *incomingBB = firstPhi->getOp(i)->as<LIRBB>(); if (incomingBB->getSuccNum() == 2) { - int succIndex; - for (succIndex = 0; succIndex < incomingBB->getSuccNum(); succIndex++) - if (incomingBB->getSucc(succIndex) == BB) + int succIndex = -1; + for (int index = 0; index < incomingBB->getSuccNum(); index++) + if (incomingBB->getSucc(index) == BB) { + succIndex = index; break; + } + + assert(succIndex != -1); LIRBB *splittedBB = new LIRBB; incomingBB->setSucc(succIndex, splittedBB); diff --git a/src/nnvm/nnvm/Backend/RISCV/Spiller.h b/src/nnvm/nnvm/Backend/RISCV/Spiller.h deleted file mode 100644 index 3a4d55b..0000000 --- a/src/nnvm/nnvm/Backend/RISCV/Spiller.h +++ /dev/null @@ -1,7 +0,0 @@ -#pragma once - -namespace nnvm::riscv { -class Spiller { -public: -}; -} /* namespace nnvm::riscv */ diff --git a/src/nnvm/nnvm/Frontend/IRGenerator.cpp b/src/nnvm/nnvm/Frontend/IRGenerator.cpp index e61fad7..81ca267 100644 --- a/src/nnvm/nnvm/Frontend/IRGenerator.cpp +++ b/src/nnvm/nnvm/Frontend/IRGenerator.cpp @@ -339,6 +339,7 @@ Any IRGenerator::constDef(SysYParser::ConstDefContext *ctx, if (symbolTable.isGlobal()) { GlobalVariable *global = new GlobalVariable(*ir, constVal); global->setName(ctx->IDENT()->getText()); + global->setImmutable(true); return symbolTable.create(symbolName, symbolType, global); } else { Type *irElementType = toIRType(symbolType->getInnerMost()); @@ -361,6 +362,7 @@ Any IRGenerator::constDef(SysYParser::ConstDefContext *ctx, if (symbolTable.isGlobal()) { GlobalVariable *global = new GlobalVariable(*ir, constVal); global->setName(ctx->IDENT()->getText()); + global->setImmutable(true); return symbolTable.create(symbolName, symbolType, global); } else { return symbolTable.create(symbolName, symbolType, constVal); diff --git a/src/nnvm/nnvm/IR/BasicBlock.h b/src/nnvm/nnvm/IR/BasicBlock.h index 0b4e0d1..044ca3c 100644 --- a/src/nnvm/nnvm/IR/BasicBlock.h +++ b/src/nnvm/nnvm/IR/BasicBlock.h @@ -76,9 +76,19 @@ public: return *this; } + PredIterator operator++() { + PredIterator ret = cur; + cur++; + while (cur != end && !dyn_cast<TerminatorInst>((*cur)->getUser())) + cur++; + return ret; + } + BasicBlock *operator*() { return cast<Instruction>((*cur)->getUser())->getParent(); } + + bool operator==(PredIterator other) { return cur == other.cur; } bool operator!=(PredIterator other) { return cur != other.cur; } private: @@ -102,6 +112,11 @@ public: ListTrait<BasicBlock>::eraseFromList(); } + bool containsPhi() { return dyn_cast<PhiInst>(*begin()); } + + const List<Instruction> &getInsts() const { return instList; } + List<Instruction> &getInsts() { return instList; } + ~BasicBlock(); private: diff --git a/src/nnvm/nnvm/IR/GlobalVariable.cpp b/src/nnvm/nnvm/IR/GlobalVariable.cpp index 8ca39c4..cd9be77 100644 --- a/src/nnvm/nnvm/IR/GlobalVariable.cpp +++ b/src/nnvm/nnvm/IR/GlobalVariable.cpp @@ -5,11 +5,11 @@ using namespace nnvm; GlobalVariable::GlobalVariable(Module &module, Type *innerType) : Value(ValueID::GlobalVariable, module.getPtrType()), innerType(innerType), - module(module) {} + immutable(false), module(module) {} GlobalVariable::GlobalVariable(Module &module, Constant *initVal) : Value(ValueID::GlobalVariable, module.getPtrType()), initVal(initVal), - innerType(initVal->getType()), module(module) { + innerType(initVal->getType()), immutable(false), module(module) { module.addGlobalVar(this); } @@ -20,8 +20,10 @@ void GlobalVariable::setName(const std::string &name) { } std::string GlobalVariable::dump() { + std::string decorator = immutable ? "immutable " : ""; auto initDump = (initVal ? (" init with " + initVal->dumpAsOperand()) : ""); - return "global " + type->dump() + " " + getName() + initDump + "\n"; + return decorator + "global " + type->dump() + " " + getName() + initDump + + "\n"; } std::string GlobalVariable::dumpAsOperand() { diff --git a/src/nnvm/nnvm/IR/GlobalVariable.h b/src/nnvm/nnvm/IR/GlobalVariable.h index 7e0fb3c..71931aa 100644 --- a/src/nnvm/nnvm/IR/GlobalVariable.h +++ b/src/nnvm/nnvm/IR/GlobalVariable.h @@ -21,9 +21,13 @@ public: void setInnerType(Type *innerType) { this->innerType = innerType; } Type *getInnerType() { return innerType; } + void setImmutable(bool immutable) { this->immutable = immutable; } + bool isImmutable() const { return immutable; } + private: Constant *initVal; Type *innerType; + bool immutable; Module &module; }; diff --git a/src/nnvm/nnvm/IR/Instruction.h b/src/nnvm/nnvm/IR/Instruction.h index 6f54409..967acf5 100644 --- a/src/nnvm/nnvm/IR/Instruction.h +++ b/src/nnvm/nnvm/IR/Instruction.h @@ -119,6 +119,7 @@ public: const Metadata *getMetadata() const { return metadata; } Metadata *getMetadata() { return metadata; } + void removeFromBB() { ListTrait<Instruction>::removeFromList(); } void eraseFromBB() { for (Use *use : useeList) use->eraseFromList(); diff --git a/src/nnvm/nnvm/Transform/Infra/BlockUtils.cpp b/src/nnvm/nnvm/Transform/Infra/BlockUtils.cpp new file mode 100644 index 0000000..c83bdf4 --- /dev/null +++ b/src/nnvm/nnvm/Transform/Infra/BlockUtils.cpp @@ -0,0 +1,11 @@ +#include "BlockUtils.h" +#include "ADT/Ranges.h" + +using namespace nnvm; + +void nnvm::moveInstInBlock(BasicBlock *from, BasicBlock *to) { + for (Instruction *I : incChange(*from)) { + I->removeFromBB(); + to->end().insertBefore(I); + } +} diff --git a/src/nnvm/nnvm/Transform/Infra/BlockUtils.h b/src/nnvm/nnvm/Transform/Infra/BlockUtils.h new file mode 100644 index 0000000..639396d --- /dev/null +++ b/src/nnvm/nnvm/Transform/Infra/BlockUtils.h @@ -0,0 +1,8 @@ +#pragma once + +#include "IR/Function.h" +#include "IR/Module.h" +#include <memory> +namespace nnvm { +void moveInstInBlock(BasicBlock *from, BasicBlock *to); +} /* namespace nnvm */ diff --git a/src/nnvm/nnvm/Transform/Scalar/CFGCombiner.cpp b/src/nnvm/nnvm/Transform/Scalar/CFGCombiner.cpp index 93bb8a0..c66c1a2 100644 --- a/src/nnvm/nnvm/Transform/Scalar/CFGCombiner.cpp +++ b/src/nnvm/nnvm/Transform/Scalar/CFGCombiner.cpp @@ -1,6 +1,7 @@ #include "CFGCombiner.h" #include "ADT/Ranges.h" #include "IR/Instruction.h" +#include "Transform/Infra/BlockUtils.h" #include "Utils/Cast.h" #include "Utils/Debug.h" using namespace nnvm; @@ -34,14 +35,81 @@ bool CFGCombinerPass::run(Function &F) { bool CFGCombinerPass::processBB(BasicBlock *BB) { if (BranchInst *BI = dyn_cast<BranchInst>(BB->getTerminator())) { - // Replace "br true, a, b" with "br a". - if (BI->isConditional() && BI->getCondition()->isConstant()) { - ConstantInt *constCond = dyn_cast<ConstantInt>(BI->getCondition()); - builder.setInsertPoint(BB->end()); - builder.buildBr(constCond->getValue() ? BI->getSucc(0) : BI->getSucc(1)); - BI->eraseFromBB(); - return true; - } + if (!BI->isConditional()) + return foldBBWithUncondBr(BB, BI); + + if (BI->isConditional()) + return foldBBWithCondBr(BB, BI); + } + return false; +} + +bool CFGCombinerPass::foldBBWithUncondBr(BasicBlock *BB, BranchInst *BI) { + + BasicBlock *succ = BI->getSucc(0); + if (BB == succ) + return false; + + // Before: + // preds --> BB --> succ --> ... + // After: + // preds --> [BB -- succ] --> ... + if (succ->getInsts().size() > 1) { + // The successor must have BB as the single predecessor. + if (succ->getPredNum() != 1 || succ->containsPhi()) + return false; + + BI->eraseFromBB(); + succ->replaceSelf(BB); + moveInstInBlock(succ, BB); + + IRBuilder builder; + builder.setInsertPoint(succ->end()); + builder.buildUnreachable(); + return true; + } + + // Before: + // pred --> BB --> succ + // After: + // pred --> succ + + if (BB->getInsts().size() != 1) + return false; + if (BB->getPredNum() != 1) + return false; + BasicBlock *pred = *BB->getPredBegin(); + if (pred->getSuccNum() != 1) + return false; + // Those jump from pred to BB, now jump from pred to succ directly. + TerminatorInst *TI = pred->getTerminator(); + for (int i = 0; i < TI->getSuccNum(); i++) + if (TI->getSucc(i) == BB) + TI->setSucc(i, succ); + + // Replace BB in phis with pred. + BB->replaceSelf(pred); + + return true; +} + +bool CFGCombinerPass::foldBBWithCondBr(BasicBlock *BB, BranchInst *BI) { + // Replace "br true, a, b" with "br a". + if (BI->getCondition()->isConstant()) { + ConstantInt *constCond = dyn_cast<ConstantInt>(BI->getCondition()); + + BasicBlock *unlinked = + constCond->getValue() ? BI->getSucc(1) : BI->getSucc(0); + for (Instruction *I : *unlinked) + if (PhiInst *phi = dyn_cast<PhiInst>(I)) + phi->removeIncoming(BB); + else + break; + + builder.setInsertPoint(BB->end()); + builder.buildBr(constCond->getValue() ? BI->getSucc(0) : BI->getSucc(1)); + BI->eraseFromBB(); + return true; } return false; } diff --git a/src/nnvm/nnvm/Transform/Scalar/CFGCombiner.h b/src/nnvm/nnvm/Transform/Scalar/CFGCombiner.h index 8ef902b..1181bf4 100644 --- a/src/nnvm/nnvm/Transform/Scalar/CFGCombiner.h +++ b/src/nnvm/nnvm/Transform/Scalar/CFGCombiner.h @@ -18,6 +18,8 @@ public: private: bool processBB(BasicBlock *BB); + bool foldBBWithUncondBr(BasicBlock *BB, BranchInst *BI); + bool foldBBWithCondBr(BasicBlock *BB, BranchInst *BI); IRBuilder builder; }; } /* namespace nnvm */ diff --git a/src/nnvm/nnvm/Transform/Scalar/CSE.cpp b/src/nnvm/nnvm/Transform/Scalar/CSE.cpp index d04e800..9a041f3 100644 --- a/src/nnvm/nnvm/Transform/Scalar/CSE.cpp +++ b/src/nnvm/nnvm/Transform/Scalar/CSE.cpp @@ -19,6 +19,8 @@ bool CSEPass::EqInstImpl::operator()(Instruction *A, Instruction *B) const { return true; if (A->getOpcode() != B->getOpcode()) return false; + if (A->getType() != B->getType()) + return false; if (A->getOperandNum() != B->getOperandNum()) return false; for (uint i = 0; i < A->getOperandNum(); i++) { @@ -28,11 +30,11 @@ bool CSEPass::EqInstImpl::operator()(Instruction *A, Instruction *B) const { if (dyn_cast<ICmpInst>(A)) return cast<ICmpInst>(A)->getPredicate() == - dyn_cast<ICmpInst>(B)->getPredicate(); + cast<ICmpInst>(B)->getPredicate(); if (dyn_cast<FCmpInst>(A)) return cast<FCmpInst>(A)->getPredicate() == - dyn_cast<FCmpInst>(B)->getPredicate(); + cast<FCmpInst>(B)->getPredicate(); // TODO: handle commutative operators. return true; diff --git a/src/nnvm/nnvm/Transform/Scalar/CombinePatterns.h b/src/nnvm/nnvm/Transform/Scalar/CombinePatterns.h index 64d2493..780b110 100644 --- a/src/nnvm/nnvm/Transform/Scalar/CombinePatterns.h +++ b/src/nnvm/nnvm/Transform/Scalar/CombinePatterns.h @@ -37,6 +37,42 @@ protected: Value **receiver; }; +class pConstantInt { +public: + pConstantInt() : receiver(nullptr) {} + pConstantInt(ConstantInt *&receiver) : receiver(&receiver) {} + bool match(Value *op) { + ConstantInt *opcasted = dyn_cast<ConstantInt>(op); + if (!opcasted) + return false; + if (receiver) + *receiver = opcasted; + return true; + } + +protected: + ConstantInt **receiver; +}; + +class pZero { +public: + pZero() : receiver(nullptr) {} + pZero(ConstantInt *&receiver) : receiver(&receiver) {} + bool match(Value *op) { + ConstantInt *opcasted = dyn_cast<ConstantInt>(op); + if (!opcasted) + return false; + if (opcasted->getValue() != 0) + return false; + if (receiver) + *receiver = opcasted; + return true; + } + +protected: + ConstantInt **receiver; +}; + class pInst { public: pInst() : receiver(nullptr) {} @@ -71,7 +107,8 @@ public: template <InstID instID, typename LSubPattern, typename RSubPattern> class pBinOp : public pSpecificInst<instID> { public: - pBinOp(LSubPattern LHS, RSubPattern RHS) : pSpecificInst<instID>(), LHS(LHS), RHS(RHS) {} + pBinOp(LSubPattern LHS, RSubPattern RHS) + : pSpecificInst<instID>(), LHS(LHS), RHS(RHS) {} bool match(Value *op) { if (!pSpecificInst<instID>::match(op)) @@ -99,4 +136,27 @@ public: : pBinOp<InstID::Add, LSubPattern, RSubPattern>(LHS, RHS) {} }; +template <typename LSubPattern, typename RSubPattern> +class pICmp : public pSpecificInst<InstID::ICmp> { +public: + pICmp(LSubPattern LHS, RSubPattern RHS) + : pSpecificInst<InstID::ICmp>(), LHS(LHS), RHS(RHS) {} + + bool match(Value *op) { + if (!pSpecificInst<InstID::ICmp>::match(op)) + return false; + + ICmpInst *I = cast<ICmpInst>(op); + if (!LHS.match(I->getOperand(0))) + return false; + if (!RHS.match(I->getOperand(1))) + return false; + return true; + } + +protected: + LSubPattern LHS; + RSubPattern RHS; +}; + } // namespace nnvm::pattern diff --git a/src/nnvm/nnvm/Transform/Scalar/Combiner.cpp b/src/nnvm/nnvm/Transform/Scalar/Combiner.cpp index 020ae84..9185ad9 100644 --- a/src/nnvm/nnvm/Transform/Scalar/Combiner.cpp +++ b/src/nnvm/nnvm/Transform/Scalar/Combiner.cpp @@ -50,14 +50,26 @@ Value *CombinerPass::simplifyInst(Instruction *I) { if (SDivInst *SI = dyn_cast<SDivInst>(I)) return simplifySDiv(SI); - //if (PhiInst *phi = dyn_cast<PhiInst>(I)) - //return simplifyPhi(phi); + if (ICmpInst *ICI = dyn_cast<ICmpInst>(I)) + return simplifyICmp(ICI); + + if (PhiInst *phi = dyn_cast<PhiInst>(I)) + return simplifyPhi(phi); return nullptr; } Value *CombinerPass::simplifyAdd(AddInst *I) { Value *A, *B, *C; + + // C1 + A --> A + C1 + if (match(I, pAdd(pConstant(A), pValue(B)))) + return builder.buildBinOp<AddInst>(B, A, I->getType()); + + // A + 0 --> A + if (match(I, pAdd(pValue(A), pZero()))) + return A; + // (A + C1) + C2 --> A + (C1 + C2) if (match(I, pAdd(pAdd(pValue(A), pConstant(B)), pConstant(C)))) { Value *addc = builder.buildBinOp<AddInst>(B, C, I->getType()); @@ -69,9 +81,19 @@ Value *CombinerPass::simplifyAdd(AddInst *I) { Value *CombinerPass::simplifySDiv(SDivInst *I) { return nullptr; } +Value *CombinerPass::simplifyICmp(ICmpInst *I) { + Value *A, *B; + // A != 0 --> A + if (I->getPredicate() == ICmpInst::NE && + match(I, pICmp(pValue(A), pZero())) && A->getType()->isIntegerNBits(1)) + return A; + return nullptr; +} + Value *CombinerPass::simplifyPhi(PhiInst *I) { if (I->getIncomingNum() == 1) - return I->getIncomingValue(0); + return I->getIncomingValue(0)->isInstruction() ? nullptr + : I->getIncomingValue(0); return nullptr; } diff --git a/src/nnvm/nnvm/Transform/Scalar/Combiner.h b/src/nnvm/nnvm/Transform/Scalar/Combiner.h index d463170..fcf6d60 100644 --- a/src/nnvm/nnvm/Transform/Scalar/Combiner.h +++ b/src/nnvm/nnvm/Transform/Scalar/Combiner.h @@ -22,6 +22,8 @@ private: Value *simplifyInst(Instruction *I); Value *simplifyAdd(AddInst *I); Value *simplifySDiv(SDivInst *I); + + Value *simplifyICmp(ICmpInst *I); Value *simplifyPhi(PhiInst *I); std::queue<Instruction *> worklist; diff --git a/src/nnvm/nnvm/Transform/Scalar/ConstantFold.cpp b/src/nnvm/nnvm/Transform/Scalar/ConstantFold.cpp index 606e4a8..51882c0 100644 --- a/src/nnvm/nnvm/Transform/Scalar/ConstantFold.cpp +++ b/src/nnvm/nnvm/Transform/Scalar/ConstantFold.cpp @@ -45,6 +45,9 @@ Value *ConstantFold::fold(Instruction *I) { dyn_cast<Constant>(CI->getOperand(1))) return foldICmp(CI); + if (LoadInst *LI = dyn_cast<LoadInst>(I)) + return foldLoad(LI); + // TODO: Handle other operator on constant operands, such as "a[0]", where "a" // is a constant array. @@ -123,3 +126,13 @@ Value *ConstantFold::foldICmp(ICmpInst *I) { return nullptr; } } + +Value *ConstantFold::foldLoad(LoadInst *I) { + if (GlobalVariable *GV = dyn_cast<GlobalVariable>(I->getSrc())) { + if (!GV->isImmutable()) + return nullptr; + if (I->getType() == GV->getInitVal()->getType()) + return GV->getInitVal(); + } + return nullptr; +} diff --git a/src/nnvm/nnvm/Transform/Scalar/ConstantFold.h b/src/nnvm/nnvm/Transform/Scalar/ConstantFold.h index 9b7184d..e0f2063 100644 --- a/src/nnvm/nnvm/Transform/Scalar/ConstantFold.h +++ b/src/nnvm/nnvm/Transform/Scalar/ConstantFold.h @@ -25,6 +25,7 @@ public: Value *foldSDiv(SDivInst *I); Value *foldSRem(SRemInst *I); Value *foldICmp(ICmpInst *I); + Value *foldLoad(LoadInst *I); void setModule(Module *module) { this->module = module; } diff --git a/timestamp b/timestamp index 1494600..d058203 100644 --- a/timestamp +++ b/timestamp @@ -1 +1 @@ -Sun Jul 7 08:24:46 UTC 2024 +Sun Jul 7 16:33:38 UTC 2024 -- GitLab