diff --git a/include/opt/LoopIdvSimplify.hh b/include/opt/LoopIdvSimplify.hh new file mode 100644 index 0000000000000000000000000000000000000000..1e40a679456a9df640440bec46668c2bf1d1cefe --- /dev/null +++ b/include/opt/LoopIdvSimplify.hh @@ -0,0 +1,17 @@ +#ifndef _LOOP_IDV_SIMPLIFY_H_ +#define _LOOP_IDV_SIMPLIFY_H_ + +#include "LoopInfo.hh" +#include "Optimization.hh" + +class LoopIdvSimplify : public Optimization { + private: + bool runOnLoop(LoopInfo* loopInfo); + + public: + LoopIdvSimplify() {} + bool runOnModule(ANTPIE::Module* module) override; + bool runOnFunction(Function* func) override; +}; + +#endif \ No newline at end of file diff --git a/include/opt/LoopInfo.hh b/include/opt/LoopInfo.hh index 72ff152e1edd6cc363d8700908d08ee1a252691c..6d505e8f9de64c2509249532548e2b29c0cbd558 100644 --- a/include/opt/LoopInfo.hh +++ b/include/opt/LoopInfo.hh @@ -51,6 +51,8 @@ class LoopInfo { bool isEmptyLoop() const; void deleteLoop(); + + bool isInvariant(Value* value); string getName() { return header->getName() + "Loop"; } void dump(); diff --git a/src/ir/Module.cc b/src/ir/Module.cc index db67398402dd4a9f479c66e197cc4574a1d2abae..452f936e25f4e271cda78bac9c6ecbd8c9ba7c8e 100644 --- a/src/ir/Module.cc +++ b/src/ir/Module.cc @@ -14,6 +14,7 @@ #include "Inlining.hh" #include "LoadElimination.hh" #include "LoopAnalysis.hh" +#include "LoopIdvSimplify.hh" #include "LoopInvariantCodeMotion.hh" #include "LoopSimplify.hh" #include "LoopUnroll.hh" @@ -298,6 +299,7 @@ void Module::irOptimize() { optimizations.pushBack(new LoopAnalysis()); optimizations.pushBack(new LoopSimplify(false)); optimizations.pushBack(new AliasAnalysis()); + optimizations.pushBack(new LoopIdvSimplify()); optimizations.pushBack(new LoopInvariantCodeMotion()); optimizations.pushBack(new LoopUnroll()); optimizations.pushBack(new DeadCodeElimination()); @@ -324,9 +326,11 @@ void Module::irOptimize() { optimizations.pushBack(new InductionVariableSimplify()); optimizations.pushBack(new StoreElimination()); + optimizations.pushBack(new Reassociate()); optimizations.pushBack(new ConstantFolding()); optimizations.pushBack(new CFGSimplify()); optimizations.pushBack(new MergeBlock()); + optimizations.pushBack(new DeadCodeElimination()); // run all pass diff --git a/src/opt/CMakeLists.txt b/src/opt/CMakeLists.txt index 3ec34f2a552b0224999cf86a59a77d2bf2a77cb5..62eb5eaae7200f64c694873cebcb8f3fda1aedbf 100644 --- a/src/opt/CMakeLists.txt +++ b/src/opt/CMakeLists.txt @@ -26,4 +26,5 @@ add_library( StoreElimination.cc LruCache.cc InductionVariableSimplify.cc + LoopIdvSimplify.cc ) \ No newline at end of file diff --git a/src/opt/GEPSimplify.cc b/src/opt/GEPSimplify.cc index e84e6e9a0315355622c24cc369c7d3e0fa4c427b..8fdc0ca52ec91bccd61217e0a12f37586c19f7f6 100644 --- a/src/opt/GEPSimplify.cc +++ b/src/opt/GEPSimplify.cc @@ -20,6 +20,8 @@ bool GEPSimplify::runOnBasicBlock(BasicBlock* block) { unordered_map<Value*, std::pair<GetElemPtrInst*, GetElemPtrInst*>> preInstrMap; vector<std::pair<GetElemPtrInst*, GetElemPtrInst*>> gepPairs; + + // gep addr, 0, idx for (Instruction* instr : *block->getInstructions()) { GetElemPtrInst* oldGep = dynamic_cast<GetElemPtrInst*>(instr); if (!oldGep || oldGep->getRValueSize() != 3) continue; @@ -50,6 +52,38 @@ bool GEPSimplify::runOnBasicBlock(BasicBlock* block) { preInstrMap[loc] = {oldGep, newGep}; } + + // gep addr, idx + preInstrMap.clear(); + for (Instruction* instr : *block->getInstructions()) { + GetElemPtrInst* oldGep = dynamic_cast<GetElemPtrInst*>(instr); + if (!oldGep || oldGep->getRValueSize() != 2) continue; + Value* loc = oldGep->getRValue(0); + auto it = preInstrMap.find(loc); + if (it == preInstrMap.end()) { + preInstrMap[loc] = {oldGep, oldGep}; + continue; + } + auto preOldGep = it->second.first; + auto preNewGep = it->second.second; + + if (preOldGep->getRValue(0) != oldGep->getRValue(0)) { + preInstrMap[loc] = {oldGep, oldGep}; + continue; + } + BinaryOpInst* strideInstr = + dynamic_cast<BinaryOpInst*>(oldGep->getRValue(1)); + if (!strideInstr || strideInstr->getOpTag() != ADD || + strideInstr->getRValue(0) != preOldGep->getRValue(1)) { + preInstrMap[loc] = {oldGep, oldGep}; + continue; + } + GetElemPtrInst* newGep = + new GetElemPtrInst(preNewGep, strideInstr->getRValue(1), "new.gep"); + gepPairs.emplace_back(oldGep, newGep); + preInstrMap[loc] = {oldGep, newGep}; + } + for (auto [oldGep, newGep] : gepPairs) { newGep->moveBefore(oldGep); oldGep->replaceAllUsesWith(newGep); diff --git a/src/opt/LoopIdvSimplify.cc b/src/opt/LoopIdvSimplify.cc new file mode 100644 index 0000000000000000000000000000000000000000..5e96849d8881d930037388e000893c01724417e5 --- /dev/null +++ b/src/opt/LoopIdvSimplify.cc @@ -0,0 +1,134 @@ +#include "LoopIdvSimplify.hh" + +bool LoopIdvSimplify::runOnModule(ANTPIE::Module* module) { + bool changed = false; + for (Function* func : *module->getFunctions()) { + changed |= runOnFunction(func); + } + return changed; +} + +bool LoopIdvSimplify::runOnFunction(Function* func) { + bool changed = false; + for (LoopInfo* loopInfo : func->getLoopInfoBase()->loopInfos) { + changed |= runOnLoop(loopInfo); + } + return changed; +} + +/** + * while(i < n) { + * a = i * 4; + * i = i + 1; + * } + * => + * ----------------------- + * while(i < n) { + * a = a + 4; + * i = i + 1; + * } + */ +bool LoopIdvSimplify::runOnLoop(LoopInfo* loopInfo) { + bool changed = false; + SimpleLoopInfo* simpleLoop = loopInfo->simpleLoop; + if (!simpleLoop) return false; + PhiInst* counter = simpleLoop->phiInstr; + BinaryOpInst* strideInst = simpleLoop->strideInstr; + if (strideInst->getOpTag() != ADD) return false; + Value* stride = strideInst->getRValue(1); + if (!loopInfo->isInvariant(stride)) return false; + + BranchInst* branch = simpleLoop->brInstr; + BasicBlock* exitBlock = *loopInfo->exits.begin(); + BasicBlock* trueBlock = nullptr; + BasicBlock* header = loopInfo->header; + BasicBlock* preHeader = loopInfo->preHeader; + Value* initValue = simpleLoop->initValue; + if ((BasicBlock*)branch->getRValue(1) == exitBlock) { + trueBlock = (BasicBlock*)branch->getRValue(2); + } else { + trueBlock = (BasicBlock*)branch->getRValue(1); + } + for (Instruction* instr : *trueBlock->getInstructions()) { + BinaryOpInst* bopInstr = dynamic_cast<BinaryOpInst*>(instr); + if (!bopInstr || bopInstr == strideInst) continue; + if (bopInstr->getRValue(1) == counter) { + bopInstr->swapRValueAt(0, 1); + } + if (bopInstr->getRValue(0) != counter) continue; + Value* offset = bopInstr->getRValue(1); + if (!loopInfo->isInvariant(offset)) continue; + switch (bopInstr->getOpTag()) { + case ADD: { + PhiInst* idvPhi = new PhiInst(bopInstr->getName() + ".phi"); + BinaryOpInst* idvInit = new BinaryOpInst(ADD, initValue, offset, + bopInstr->getName() + ".init"); + idvInit->moveBefore(preHeader->getTailInstr()); + idvPhi->pushIncoming(idvInit, preHeader); + BinaryOpInst* idvStrideInst = new BinaryOpInst( + ADD, idvPhi, stride, bopInstr->getName() + ".stride"); + idvStrideInst->moveBefore(trueBlock->getTailInstr()); + bopInstr->replaceAllUsesWith(idvPhi); + for (BasicBlock* latch : loopInfo->latches) { + idvPhi->pushIncoming(idvStrideInst, latch); + } + header->pushInstrAtHead(idvPhi); + changed = true; + break; + } + + case MUL: { + PhiInst* idvPhi = new PhiInst(bopInstr->getName() + ".phi"); + BinaryOpInst* idvInit = new BinaryOpInst(MUL, initValue, offset, + bopInstr->getName() + ".init"); + BinaryOpInst* increValue = new BinaryOpInst( + MUL, stride, offset, bopInstr->getName() + ".incr"); + idvInit->moveBefore(preHeader->getTailInstr()); + increValue->moveBefore(preHeader->getTailInstr()); + idvPhi->pushIncoming(idvInit, preHeader); + BinaryOpInst* idvStrideInst = new BinaryOpInst( + ADD, idvPhi, increValue, bopInstr->getName() + ".stride"); + idvStrideInst->moveBefore(trueBlock->getTailInstr()); + bopInstr->replaceAllUsesWith(idvPhi); + for (BasicBlock* latch : loopInfo->latches) { + idvPhi->pushIncoming(idvStrideInst, latch); + } + header->pushInstrAtHead(idvPhi); + changed = true; + break; + } + + case SHL: { + PhiInst* idvPhi = new PhiInst(bopInstr->getName() + ".phi"); + BinaryOpInst* idvInit = new BinaryOpInst(SHL, initValue, offset, + bopInstr->getName() + ".init"); + Value* mulValue = 0; + if (offset->isa(VT_INTCONST)) { + mulValue = IntegerConstant::getConstInt( + 1 << ((IntegerConstant*)offset)->getValue()); + } else { + mulValue = new BinaryOpInst(SHL, IntegerConstant::getConstInt(1), + offset, bopInstr->getName() + ".mulv"); + } + BinaryOpInst* increValue = new BinaryOpInst( + MUL, stride, mulValue, bopInstr->getName() + ".incr"); + idvInit->moveBefore(preHeader->getTailInstr()); + increValue->moveBefore(preHeader->getTailInstr()); + idvPhi->pushIncoming(idvInit, preHeader); + BinaryOpInst* idvStrideInst = new BinaryOpInst( + ADD, idvPhi, increValue, bopInstr->getName() + ".stride"); + idvStrideInst->moveBefore(trueBlock->getTailInstr()); + bopInstr->replaceAllUsesWith(idvPhi); + for (BasicBlock* latch : loopInfo->latches) { + idvPhi->pushIncoming(idvStrideInst, latch); + } + header->pushInstrAtHead(idvPhi); + changed = true; + break; + } + default: + continue; + } + } + return true; +} diff --git a/src/opt/LoopInfo.cc b/src/opt/LoopInfo.cc index d3e6771ad4edf3cd4862b684e1f691ffe8ce599d..98cc2a4098e99cf2759964e0762ff413409c1806 100644 --- a/src/opt/LoopInfo.cc +++ b/src/opt/LoopInfo.cc @@ -122,6 +122,11 @@ void LoopInfo::deleteLoop() { } } } +bool LoopInfo::isInvariant(Value* value) { + Instruction* instr = dynamic_cast<Instruction*>(value); + if (!instr) return true; + return !containBlockInChildren(instr->getParent()); +} void LoopInfo::dump() { std::cout << "Header: " << header->getName() << std::endl;