From 2b0244616aad9bc2eb4006c005d41893828ffadd Mon Sep 17 00:00:00 2001
From: NNVM Authors <xxs_chy@outlook.com>
Date: Sun, 7 Jul 2024 16:33:38 +0000
Subject: [PATCH] sync Sun Jul  7 16:33:38 UTC 2024

---
 src/nnvm/nnvm/ADT/Ranges.h                    | 23 +++--
 src/nnvm/nnvm/Analysis/AAInfo.h               |  6 ++
 src/nnvm/nnvm/Analysis/AliasAnalysis.cpp      |  5 ++
 src/nnvm/nnvm/Analysis/AliasAnalysis.h        |  4 +
 src/nnvm/nnvm/Analysis/BasicAA.cpp            |  9 ++
 src/nnvm/nnvm/Analysis/BasicAA.h              | 13 +++
 src/nnvm/nnvm/Backend/RISCV/CodeLayoutOpt.cpp | 22 +++++
 src/nnvm/nnvm/Backend/RISCV/CodeLayoutOpt.h   | 12 +++
 src/nnvm/nnvm/Backend/RISCV/CodegenInfo.cpp   |  2 +
 src/nnvm/nnvm/Backend/RISCV/LowIR.cpp         |  3 +
 src/nnvm/nnvm/Backend/RISCV/LowInstType.h     |  2 +
 src/nnvm/nnvm/Backend/RISCV/PhiResolution.cpp | 10 ++-
 src/nnvm/nnvm/Backend/RISCV/Spiller.h         |  7 --
 src/nnvm/nnvm/Frontend/IRGenerator.cpp        |  2 +
 src/nnvm/nnvm/IR/BasicBlock.h                 | 15 ++++
 src/nnvm/nnvm/IR/GlobalVariable.cpp           |  8 +-
 src/nnvm/nnvm/IR/GlobalVariable.h             |  4 +
 src/nnvm/nnvm/IR/Instruction.h                |  1 +
 src/nnvm/nnvm/Transform/Infra/BlockUtils.cpp  | 11 +++
 src/nnvm/nnvm/Transform/Infra/BlockUtils.h    |  8 ++
 .../nnvm/Transform/Scalar/CFGCombiner.cpp     | 84 +++++++++++++++++--
 src/nnvm/nnvm/Transform/Scalar/CFGCombiner.h  |  2 +
 src/nnvm/nnvm/Transform/Scalar/CSE.cpp        |  6 +-
 .../nnvm/Transform/Scalar/CombinePatterns.h   | 62 +++++++++++++-
 src/nnvm/nnvm/Transform/Scalar/Combiner.cpp   | 28 ++++++-
 src/nnvm/nnvm/Transform/Scalar/Combiner.h     |  2 +
 .../nnvm/Transform/Scalar/ConstantFold.cpp    | 13 +++
 src/nnvm/nnvm/Transform/Scalar/ConstantFold.h |  1 +
 timestamp                                     |  2 +-
 29 files changed, 326 insertions(+), 41 deletions(-)
 create mode 100644 src/nnvm/nnvm/Analysis/AAInfo.h
 create mode 100644 src/nnvm/nnvm/Analysis/BasicAA.cpp
 create mode 100644 src/nnvm/nnvm/Analysis/BasicAA.h
 create mode 100644 src/nnvm/nnvm/Backend/RISCV/CodeLayoutOpt.cpp
 create mode 100644 src/nnvm/nnvm/Backend/RISCV/CodeLayoutOpt.h
 delete mode 100644 src/nnvm/nnvm/Backend/RISCV/Spiller.h
 create mode 100644 src/nnvm/nnvm/Transform/Infra/BlockUtils.cpp
 create mode 100644 src/nnvm/nnvm/Transform/Infra/BlockUtils.h

diff --git a/src/nnvm/nnvm/ADT/Ranges.h b/src/nnvm/nnvm/ADT/Ranges.h
index 620ce8b..1151667 100644
--- a/src/nnvm/nnvm/ADT/Ranges.h
+++ b/src/nnvm/nnvm/ADT/Ranges.h
@@ -68,21 +68,18 @@ template <typename Range> IncChangeRange<Range> incChange(Range &range) {
   return IncChangeRange<Range>(range);
 }
 
-template <typename T> struct reversion_wrapper {
-  T &iterable;
-};
+template <typename Iter> struct RangeWrapper {
+public:
+  Iter begin() { return beginIter; }
+  Iter end() { return endIter; }
 
-template <typename T> auto begin(reversion_wrapper<T> w) {
-  auto begin = std::begin(w.iterable);
-  return std::reverse_iterator<decltype(begin)>(begin);
-}
+private:
+  Iter beginIter;
+  Iter endIter;
+};
 
-template <typename T> auto end(reversion_wrapper<T> w) {
-  auto end = std::end(w.iterable);
-  return std::reverse_iterator<decltype(end)>(end);
+template <typename Iter> RangeWrapper<Iter> makeRange(Iter begin, Iter end) {
+  return RangeWrapper<Iter>(begin, end);
 }
 
-template <typename T> reversion_wrapper<T> reverseRange(T &&iterable) {
-  return {iterable};
-}
 } /* namespace nnvm */
diff --git a/src/nnvm/nnvm/Analysis/AAInfo.h b/src/nnvm/nnvm/Analysis/AAInfo.h
new file mode 100644
index 0000000..744be87
--- /dev/null
+++ b/src/nnvm/nnvm/Analysis/AAInfo.h
@@ -0,0 +1,6 @@
+#pragma once
+
+namespace nnvm {
+enum AAFlag { MayAlias, MustAlias, NotAlias };
+
+} /* namespace nnvm */
diff --git a/src/nnvm/nnvm/Analysis/AliasAnalysis.cpp b/src/nnvm/nnvm/Analysis/AliasAnalysis.cpp
index e69de29..e2b9238 100644
--- a/src/nnvm/nnvm/Analysis/AliasAnalysis.cpp
+++ b/src/nnvm/nnvm/Analysis/AliasAnalysis.cpp
@@ -0,0 +1,5 @@
+#include "AliasAnalysis.h"
+
+using namespace nnvm;
+
+AAFlag AliasAnalysis::alias(Value *a, Value *b) { return MayAlias; }
diff --git a/src/nnvm/nnvm/Analysis/AliasAnalysis.h b/src/nnvm/nnvm/Analysis/AliasAnalysis.h
index 317f2e9..f1d12f7 100644
--- a/src/nnvm/nnvm/Analysis/AliasAnalysis.h
+++ b/src/nnvm/nnvm/Analysis/AliasAnalysis.h
@@ -5,16 +5,20 @@
 
 #pragma once
 
+#include "Analysis/AAInfo.h"
 #include "IR/Instruction.h"
 #include "Transform/Infra/Pass.h"
 #include <unordered_map>
 #include <vector>
 
 namespace nnvm {
+
 class AliasAnalysis : public FunctionPass {
 public:
   bool run(Function &F);
 
+  AAFlag alias(Value *a, Value *b);
+
 private:
 };
 } /* namespace nnvm */
diff --git a/src/nnvm/nnvm/Analysis/BasicAA.cpp b/src/nnvm/nnvm/Analysis/BasicAA.cpp
new file mode 100644
index 0000000..5a3e181
--- /dev/null
+++ b/src/nnvm/nnvm/Analysis/BasicAA.cpp
@@ -0,0 +1,9 @@
+#include "BasicAA.h"
+#include "Analysis/AAInfo.h"
+using namespace nnvm;
+
+AAFlag BasicAA::alias(Value *a, Value *b) {
+  if (a == b)
+    return MustAlias;
+  return MayAlias;
+}
diff --git a/src/nnvm/nnvm/Analysis/BasicAA.h b/src/nnvm/nnvm/Analysis/BasicAA.h
new file mode 100644
index 0000000..5fa4bb9
--- /dev/null
+++ b/src/nnvm/nnvm/Analysis/BasicAA.h
@@ -0,0 +1,13 @@
+#pragma once
+
+#include "Analysis/AAInfo.h"
+#include "IR/Value.h"
+
+namespace nnvm {
+
+class BasicAA {
+public:
+  AAFlag alias(Value *a, Value *b);
+};
+
+} /* namespace nnvm */
diff --git a/src/nnvm/nnvm/Backend/RISCV/CodeLayoutOpt.cpp b/src/nnvm/nnvm/Backend/RISCV/CodeLayoutOpt.cpp
new file mode 100644
index 0000000..b420ed8
--- /dev/null
+++ b/src/nnvm/nnvm/Backend/RISCV/CodeLayoutOpt.cpp
@@ -0,0 +1,22 @@
+#include "CodeLayoutOpt.h"
+#include "ADT/GenericInt.h"
+#include "ADT/Ranges.h"
+#include "Backend/RISCV/CodegenInfo.h"
+#include "Backend/RISCV/LowIR.h"
+#include "Backend/RISCV/LowIR/LIRValue.h"
+#include "Backend/RISCV/LowInstType.h"
+#include "Backend/RISCV/PhiResolution.h"
+#include "IR/Instruction.h"
+#include "Utils/Cast.h"
+#include "Utils/Debug.h"
+#include <utility>
+using namespace nnvm;
+using namespace nnvm::riscv;
+
+void CodeLayoutOpt::layout(LIRFunc &func) {
+  std::list<LIRBB *> result;
+  std::unordered_set<LIRBB *> visited;
+
+  std::stack<LIRBB *> worklist;
+  worklist.push(func.getEntry());
+}
diff --git a/src/nnvm/nnvm/Backend/RISCV/CodeLayoutOpt.h b/src/nnvm/nnvm/Backend/RISCV/CodeLayoutOpt.h
new file mode 100644
index 0000000..1961582
--- /dev/null
+++ b/src/nnvm/nnvm/Backend/RISCV/CodeLayoutOpt.h
@@ -0,0 +1,12 @@
+
+#pragma once
+#include "Backend/RISCV/LowIR.h"
+#include "Backend/RISCV/LowIR/Builder.h"
+namespace nnvm::riscv {
+
+class CodeLayoutOpt {
+public:
+  void layout(LIRFunc &func);
+};
+
+} /* namespace nnvm::riscv */
diff --git a/src/nnvm/nnvm/Backend/RISCV/CodegenInfo.cpp b/src/nnvm/nnvm/Backend/RISCV/CodegenInfo.cpp
index 40d06e6..711f836 100644
--- a/src/nnvm/nnvm/Backend/RISCV/CodegenInfo.cpp
+++ b/src/nnvm/nnvm/Backend/RISCV/CodegenInfo.cpp
@@ -70,6 +70,7 @@ LIRBB *riscv::getBranchDest(LIRInst *inst) {
   case (B_BEGIN + 1)...(B_END - 1):
     return inst->getOp(2)->as<LIRBB>();
   default:
+    std::cerr << "Invalid opcode" << inst->getOpcode() << "\n";
     nnvm_unreachable("Must be a valid branch instruction")
   }
 }
@@ -83,6 +84,7 @@ void riscv::setBranchDest(LIRInst *inst, LIRBB *dest) {
     inst->setUse(2, dest);
     break;
   default:
+    std::cerr << "Invalid opcode" << inst->getOpcode() << "\n";
     nnvm_unreachable("Must be a valid branch instruction")
   }
 }
diff --git a/src/nnvm/nnvm/Backend/RISCV/LowIR.cpp b/src/nnvm/nnvm/Backend/RISCV/LowIR.cpp
index 8130ae9..82d1073 100644
--- a/src/nnvm/nnvm/Backend/RISCV/LowIR.cpp
+++ b/src/nnvm/nnvm/Backend/RISCV/LowIR.cpp
@@ -11,6 +11,9 @@ using namespace nnvm::riscv;
 
 void LIRInst::emit(std::ostream &out, EmitInfo &info) {
   switch (type) {
+  case IMPLICIT_JUMP:
+    out << "";
+    break;
   case SB:
   case SH:
   case SW:
diff --git a/src/nnvm/nnvm/Backend/RISCV/LowInstType.h b/src/nnvm/nnvm/Backend/RISCV/LowInstType.h
index ba987c3..302d19f 100644
--- a/src/nnvm/nnvm/Backend/RISCV/LowInstType.h
+++ b/src/nnvm/nnvm/Backend/RISCV/LowInstType.h
@@ -53,6 +53,8 @@ namespace nnvm::riscv {
 enum LIRInstID : uint64_t {
   // ==== Middle IR Reserved ====
   // .....
+  // ==== Generic ====
+  IMPLICIT_JUMP,
   // ==== RISC-V Specific ====
   ISA_BEGIN = (uint64_t)InstID::INST_END + 1,
   NONE,
diff --git a/src/nnvm/nnvm/Backend/RISCV/PhiResolution.cpp b/src/nnvm/nnvm/Backend/RISCV/PhiResolution.cpp
index a13e4a5..dba594b 100644
--- a/src/nnvm/nnvm/Backend/RISCV/PhiResolution.cpp
+++ b/src/nnvm/nnvm/Backend/RISCV/PhiResolution.cpp
@@ -29,10 +29,14 @@ void PhiResolution::processBB(LIRBB *BB) {
   for (uint64_t i = 1; i < firstPhi->getNumOp(); i += 2) {
     LIRBB *incomingBB = firstPhi->getOp(i)->as<LIRBB>();
     if (incomingBB->getSuccNum() == 2) {
-      int succIndex;
-      for (succIndex = 0; succIndex < incomingBB->getSuccNum(); succIndex++)
-        if (incomingBB->getSucc(succIndex) == BB)
+      int succIndex = -1;
+      for (int index = 0; index < incomingBB->getSuccNum(); index++)
+        if (incomingBB->getSucc(index) == BB) {
+          succIndex = index;
           break;
+        }
+
+      assert(succIndex != -1);
 
       LIRBB *splittedBB = new LIRBB;
       incomingBB->setSucc(succIndex, splittedBB);
diff --git a/src/nnvm/nnvm/Backend/RISCV/Spiller.h b/src/nnvm/nnvm/Backend/RISCV/Spiller.h
deleted file mode 100644
index 3a4d55b..0000000
--- a/src/nnvm/nnvm/Backend/RISCV/Spiller.h
+++ /dev/null
@@ -1,7 +0,0 @@
-#pragma once
-
-namespace nnvm::riscv {
-class Spiller {
-public:
-};
-} /* namespace nnvm::riscv */
diff --git a/src/nnvm/nnvm/Frontend/IRGenerator.cpp b/src/nnvm/nnvm/Frontend/IRGenerator.cpp
index e61fad7..81ca267 100644
--- a/src/nnvm/nnvm/Frontend/IRGenerator.cpp
+++ b/src/nnvm/nnvm/Frontend/IRGenerator.cpp
@@ -339,6 +339,7 @@ Any IRGenerator::constDef(SysYParser::ConstDefContext *ctx,
     if (symbolTable.isGlobal()) {
       GlobalVariable *global = new GlobalVariable(*ir, constVal);
       global->setName(ctx->IDENT()->getText());
+      global->setImmutable(true);
       return symbolTable.create(symbolName, symbolType, global);
     } else {
       Type *irElementType = toIRType(symbolType->getInnerMost());
@@ -361,6 +362,7 @@ Any IRGenerator::constDef(SysYParser::ConstDefContext *ctx,
   if (symbolTable.isGlobal()) {
     GlobalVariable *global = new GlobalVariable(*ir, constVal);
     global->setName(ctx->IDENT()->getText());
+    global->setImmutable(true);
     return symbolTable.create(symbolName, symbolType, global);
   } else {
     return symbolTable.create(symbolName, symbolType, constVal);
diff --git a/src/nnvm/nnvm/IR/BasicBlock.h b/src/nnvm/nnvm/IR/BasicBlock.h
index 0b4e0d1..044ca3c 100644
--- a/src/nnvm/nnvm/IR/BasicBlock.h
+++ b/src/nnvm/nnvm/IR/BasicBlock.h
@@ -76,9 +76,19 @@ public:
       return *this;
     }
 
+    PredIterator operator++() {
+      PredIterator ret = cur;
+      cur++;
+      while (cur != end && !dyn_cast<TerminatorInst>((*cur)->getUser()))
+        cur++;
+      return ret;
+    }
+
     BasicBlock *operator*() {
       return cast<Instruction>((*cur)->getUser())->getParent();
     }
+
+    bool operator==(PredIterator other) { return cur == other.cur; }
     bool operator!=(PredIterator other) { return cur != other.cur; }
 
   private:
@@ -102,6 +112,11 @@ public:
     ListTrait<BasicBlock>::eraseFromList();
   }
 
+  bool containsPhi() { return dyn_cast<PhiInst>(*begin()); }
+
+  const List<Instruction> &getInsts() const { return instList; }
+  List<Instruction> &getInsts() { return instList; }
+
   ~BasicBlock();
 
 private:
diff --git a/src/nnvm/nnvm/IR/GlobalVariable.cpp b/src/nnvm/nnvm/IR/GlobalVariable.cpp
index 8ca39c4..cd9be77 100644
--- a/src/nnvm/nnvm/IR/GlobalVariable.cpp
+++ b/src/nnvm/nnvm/IR/GlobalVariable.cpp
@@ -5,11 +5,11 @@ using namespace nnvm;
 
 GlobalVariable::GlobalVariable(Module &module, Type *innerType)
     : Value(ValueID::GlobalVariable, module.getPtrType()), innerType(innerType),
-      module(module) {}
+      immutable(false), module(module) {}
 
 GlobalVariable::GlobalVariable(Module &module, Constant *initVal)
     : Value(ValueID::GlobalVariable, module.getPtrType()), initVal(initVal),
-      innerType(initVal->getType()), module(module) {
+      innerType(initVal->getType()), immutable(false), module(module) {
   module.addGlobalVar(this);
 }
 
@@ -20,8 +20,10 @@ void GlobalVariable::setName(const std::string &name) {
 }
 
 std::string GlobalVariable::dump() {
+  std::string decorator = immutable ? "immutable " : "";
   auto initDump = (initVal ? (" init with " + initVal->dumpAsOperand()) : "");
-  return "global " + type->dump() + " " + getName() + initDump + "\n";
+  return decorator + "global " + type->dump() + " " + getName() + initDump +
+         "\n";
 }
 
 std::string GlobalVariable::dumpAsOperand() {
diff --git a/src/nnvm/nnvm/IR/GlobalVariable.h b/src/nnvm/nnvm/IR/GlobalVariable.h
index 7e0fb3c..71931aa 100644
--- a/src/nnvm/nnvm/IR/GlobalVariable.h
+++ b/src/nnvm/nnvm/IR/GlobalVariable.h
@@ -21,9 +21,13 @@ public:
   void setInnerType(Type *innerType) { this->innerType = innerType; }
   Type *getInnerType() { return innerType; }
 
+  void setImmutable(bool immutable) { this->immutable = immutable; }
+  bool isImmutable() const { return immutable; }
+
 private:
   Constant *initVal;
   Type *innerType;
+  bool immutable;
   Module &module;
 };
 
diff --git a/src/nnvm/nnvm/IR/Instruction.h b/src/nnvm/nnvm/IR/Instruction.h
index 6f54409..967acf5 100644
--- a/src/nnvm/nnvm/IR/Instruction.h
+++ b/src/nnvm/nnvm/IR/Instruction.h
@@ -119,6 +119,7 @@ public:
   const Metadata *getMetadata() const { return metadata; }
   Metadata *getMetadata() { return metadata; }
 
+  void removeFromBB() { ListTrait<Instruction>::removeFromList(); }
   void eraseFromBB() {
     for (Use *use : useeList)
       use->eraseFromList();
diff --git a/src/nnvm/nnvm/Transform/Infra/BlockUtils.cpp b/src/nnvm/nnvm/Transform/Infra/BlockUtils.cpp
new file mode 100644
index 0000000..c83bdf4
--- /dev/null
+++ b/src/nnvm/nnvm/Transform/Infra/BlockUtils.cpp
@@ -0,0 +1,11 @@
+#include "BlockUtils.h"
+#include "ADT/Ranges.h"
+
+using namespace nnvm;
+
+void nnvm::moveInstInBlock(BasicBlock *from, BasicBlock *to) {
+  for (Instruction *I : incChange(*from)) {
+    I->removeFromBB();
+    to->end().insertBefore(I);
+  }
+}
diff --git a/src/nnvm/nnvm/Transform/Infra/BlockUtils.h b/src/nnvm/nnvm/Transform/Infra/BlockUtils.h
new file mode 100644
index 0000000..639396d
--- /dev/null
+++ b/src/nnvm/nnvm/Transform/Infra/BlockUtils.h
@@ -0,0 +1,8 @@
+#pragma once
+
+#include "IR/Function.h"
+#include "IR/Module.h"
+#include <memory>
+namespace nnvm {
+void moveInstInBlock(BasicBlock *from, BasicBlock *to);
+} /* namespace nnvm */
diff --git a/src/nnvm/nnvm/Transform/Scalar/CFGCombiner.cpp b/src/nnvm/nnvm/Transform/Scalar/CFGCombiner.cpp
index 93bb8a0..c66c1a2 100644
--- a/src/nnvm/nnvm/Transform/Scalar/CFGCombiner.cpp
+++ b/src/nnvm/nnvm/Transform/Scalar/CFGCombiner.cpp
@@ -1,6 +1,7 @@
 #include "CFGCombiner.h"
 #include "ADT/Ranges.h"
 #include "IR/Instruction.h"
+#include "Transform/Infra/BlockUtils.h"
 #include "Utils/Cast.h"
 #include "Utils/Debug.h"
 using namespace nnvm;
@@ -34,14 +35,81 @@ bool CFGCombinerPass::run(Function &F) {
 
 bool CFGCombinerPass::processBB(BasicBlock *BB) {
   if (BranchInst *BI = dyn_cast<BranchInst>(BB->getTerminator())) {
-    // Replace "br true, a, b" with "br a".
-    if (BI->isConditional() && BI->getCondition()->isConstant()) {
-      ConstantInt *constCond = dyn_cast<ConstantInt>(BI->getCondition());
-      builder.setInsertPoint(BB->end());
-      builder.buildBr(constCond->getValue() ? BI->getSucc(0) : BI->getSucc(1));
-      BI->eraseFromBB();
-      return true;
-    }
+    if (!BI->isConditional())
+      return foldBBWithUncondBr(BB, BI);
+
+    if (BI->isConditional())
+      return foldBBWithCondBr(BB, BI);
+  }
+  return false;
+}
+
+bool CFGCombinerPass::foldBBWithUncondBr(BasicBlock *BB, BranchInst *BI) {
+
+  BasicBlock *succ = BI->getSucc(0);
+  if (BB == succ)
+    return false;
+
+  // Before:
+  // preds --> BB --> succ --> ...
+  // After:
+  // preds --> [BB -- succ] --> ...
+  if (succ->getInsts().size() > 1) {
+    // The successor must have BB as the single predecessor.
+    if (succ->getPredNum() != 1 || succ->containsPhi())
+      return false;
+
+    BI->eraseFromBB();
+    succ->replaceSelf(BB);
+    moveInstInBlock(succ, BB);
+
+    IRBuilder builder;
+    builder.setInsertPoint(succ->end());
+    builder.buildUnreachable();
+    return true;
+  }
+
+  // Before:
+  // pred --> BB --> succ
+  // After:
+  // pred --> succ
+
+  if (BB->getInsts().size() != 1)
+    return false;
+  if (BB->getPredNum() != 1)
+    return false;
+  BasicBlock *pred = *BB->getPredBegin();
+  if (pred->getSuccNum() != 1)
+    return false;
+  // Those jump from pred to BB, now jump from pred to succ directly.
+  TerminatorInst *TI = pred->getTerminator();
+  for (int i = 0; i < TI->getSuccNum(); i++)
+    if (TI->getSucc(i) == BB)
+      TI->setSucc(i, succ);
+
+  // Replace BB in phis with pred.
+  BB->replaceSelf(pred);
+
+  return true;
+}
+
+bool CFGCombinerPass::foldBBWithCondBr(BasicBlock *BB, BranchInst *BI) {
+  // Replace "br true, a, b" with "br a".
+  if (BI->getCondition()->isConstant()) {
+    ConstantInt *constCond = dyn_cast<ConstantInt>(BI->getCondition());
+
+    BasicBlock *unlinked =
+        constCond->getValue() ? BI->getSucc(1) : BI->getSucc(0);
+    for (Instruction *I : *unlinked)
+      if (PhiInst *phi = dyn_cast<PhiInst>(I))
+        phi->removeIncoming(BB);
+      else
+        break;
+
+    builder.setInsertPoint(BB->end());
+    builder.buildBr(constCond->getValue() ? BI->getSucc(0) : BI->getSucc(1));
+    BI->eraseFromBB();
+    return true;
   }
   return false;
 }
diff --git a/src/nnvm/nnvm/Transform/Scalar/CFGCombiner.h b/src/nnvm/nnvm/Transform/Scalar/CFGCombiner.h
index 8ef902b..1181bf4 100644
--- a/src/nnvm/nnvm/Transform/Scalar/CFGCombiner.h
+++ b/src/nnvm/nnvm/Transform/Scalar/CFGCombiner.h
@@ -18,6 +18,8 @@ public:
 
 private:
   bool processBB(BasicBlock *BB);
+  bool foldBBWithUncondBr(BasicBlock *BB, BranchInst *BI);
+  bool foldBBWithCondBr(BasicBlock *BB, BranchInst *BI);
   IRBuilder builder;
 };
 } /* namespace nnvm */
diff --git a/src/nnvm/nnvm/Transform/Scalar/CSE.cpp b/src/nnvm/nnvm/Transform/Scalar/CSE.cpp
index d04e800..9a041f3 100644
--- a/src/nnvm/nnvm/Transform/Scalar/CSE.cpp
+++ b/src/nnvm/nnvm/Transform/Scalar/CSE.cpp
@@ -19,6 +19,8 @@ bool CSEPass::EqInstImpl::operator()(Instruction *A, Instruction *B) const {
     return true;
   if (A->getOpcode() != B->getOpcode())
     return false;
+  if (A->getType() != B->getType())
+    return false;
   if (A->getOperandNum() != B->getOperandNum())
     return false;
   for (uint i = 0; i < A->getOperandNum(); i++) {
@@ -28,11 +30,11 @@ bool CSEPass::EqInstImpl::operator()(Instruction *A, Instruction *B) const {
 
   if (dyn_cast<ICmpInst>(A))
     return cast<ICmpInst>(A)->getPredicate() ==
-           dyn_cast<ICmpInst>(B)->getPredicate();
+           cast<ICmpInst>(B)->getPredicate();
 
   if (dyn_cast<FCmpInst>(A))
     return cast<FCmpInst>(A)->getPredicate() ==
-           dyn_cast<FCmpInst>(B)->getPredicate();
+           cast<FCmpInst>(B)->getPredicate();
 
   // TODO: handle commutative operators.
   return true;
diff --git a/src/nnvm/nnvm/Transform/Scalar/CombinePatterns.h b/src/nnvm/nnvm/Transform/Scalar/CombinePatterns.h
index 64d2493..780b110 100644
--- a/src/nnvm/nnvm/Transform/Scalar/CombinePatterns.h
+++ b/src/nnvm/nnvm/Transform/Scalar/CombinePatterns.h
@@ -37,6 +37,42 @@ protected:
   Value **receiver;
 };
 
+class pConstantInt {
+public:
+  pConstantInt() : receiver(nullptr) {}
+  pConstantInt(ConstantInt *&receiver) : receiver(&receiver) {}
+  bool match(Value *op) {
+    ConstantInt *opcasted = dyn_cast<ConstantInt>(op);
+    if (!opcasted)
+      return false;
+    if (receiver)
+      *receiver = opcasted;
+    return true;
+  }
+
+protected:
+  ConstantInt **receiver;
+};
+
+class pZero {
+public:
+  pZero() : receiver(nullptr) {}
+  pZero(ConstantInt *&receiver) : receiver(&receiver) {}
+  bool match(Value *op) {
+    ConstantInt *opcasted = dyn_cast<ConstantInt>(op);
+    if (!opcasted)
+      return false;
+    if (opcasted->getValue() != 0)
+      return false;
+    if (receiver)
+      *receiver = opcasted;
+    return true;
+  }
+
+protected:
+  ConstantInt **receiver;
+};
+
 class pInst {
 public:
   pInst() : receiver(nullptr) {}
@@ -71,7 +107,8 @@ public:
 template <InstID instID, typename LSubPattern, typename RSubPattern>
 class pBinOp : public pSpecificInst<instID> {
 public:
-  pBinOp(LSubPattern LHS, RSubPattern RHS) : pSpecificInst<instID>(), LHS(LHS), RHS(RHS) {}
+  pBinOp(LSubPattern LHS, RSubPattern RHS)
+      : pSpecificInst<instID>(), LHS(LHS), RHS(RHS) {}
 
   bool match(Value *op) {
     if (!pSpecificInst<instID>::match(op))
@@ -99,4 +136,27 @@ public:
       : pBinOp<InstID::Add, LSubPattern, RSubPattern>(LHS, RHS) {}
 };
 
+template <typename LSubPattern, typename RSubPattern>
+class pICmp : public pSpecificInst<InstID::ICmp> {
+public:
+  pICmp(LSubPattern LHS, RSubPattern RHS)
+      : pSpecificInst<InstID::ICmp>(), LHS(LHS), RHS(RHS) {}
+
+  bool match(Value *op) {
+    if (!pSpecificInst<InstID::ICmp>::match(op))
+      return false;
+
+    ICmpInst *I = cast<ICmpInst>(op);
+    if (!LHS.match(I->getOperand(0)))
+      return false;
+    if (!RHS.match(I->getOperand(1)))
+      return false;
+    return true;
+  }
+
+protected:
+  LSubPattern LHS;
+  RSubPattern RHS;
+};
+
 } // namespace nnvm::pattern
diff --git a/src/nnvm/nnvm/Transform/Scalar/Combiner.cpp b/src/nnvm/nnvm/Transform/Scalar/Combiner.cpp
index 020ae84..9185ad9 100644
--- a/src/nnvm/nnvm/Transform/Scalar/Combiner.cpp
+++ b/src/nnvm/nnvm/Transform/Scalar/Combiner.cpp
@@ -50,14 +50,26 @@ Value *CombinerPass::simplifyInst(Instruction *I) {
   if (SDivInst *SI = dyn_cast<SDivInst>(I))
     return simplifySDiv(SI);
 
-  //if (PhiInst *phi = dyn_cast<PhiInst>(I))
-    //return simplifyPhi(phi);
+  if (ICmpInst *ICI = dyn_cast<ICmpInst>(I))
+    return simplifyICmp(ICI);
+
+  if (PhiInst *phi = dyn_cast<PhiInst>(I))
+    return simplifyPhi(phi);
 
   return nullptr;
 }
 
 Value *CombinerPass::simplifyAdd(AddInst *I) {
   Value *A, *B, *C;
+
+  // C1 + A --> A + C1
+  if (match(I, pAdd(pConstant(A), pValue(B))))
+    return builder.buildBinOp<AddInst>(B, A, I->getType());
+
+  // A + 0 --> A
+  if (match(I, pAdd(pValue(A), pZero())))
+    return A;
+
   // (A + C1) + C2 --> A + (C1 + C2)
   if (match(I, pAdd(pAdd(pValue(A), pConstant(B)), pConstant(C)))) {
     Value *addc = builder.buildBinOp<AddInst>(B, C, I->getType());
@@ -69,9 +81,19 @@ Value *CombinerPass::simplifyAdd(AddInst *I) {
 
 Value *CombinerPass::simplifySDiv(SDivInst *I) { return nullptr; }
 
+Value *CombinerPass::simplifyICmp(ICmpInst *I) {
+  Value *A, *B;
+  // A != 0  -->  A
+  if (I->getPredicate() == ICmpInst::NE &&
+      match(I, pICmp(pValue(A), pZero())) && A->getType()->isIntegerNBits(1))
+    return A;
+  return nullptr;
+}
+
 Value *CombinerPass::simplifyPhi(PhiInst *I) {
   if (I->getIncomingNum() == 1)
-    return I->getIncomingValue(0);
+    return I->getIncomingValue(0)->isInstruction() ? nullptr
+                                                   : I->getIncomingValue(0);
 
   return nullptr;
 }
diff --git a/src/nnvm/nnvm/Transform/Scalar/Combiner.h b/src/nnvm/nnvm/Transform/Scalar/Combiner.h
index d463170..fcf6d60 100644
--- a/src/nnvm/nnvm/Transform/Scalar/Combiner.h
+++ b/src/nnvm/nnvm/Transform/Scalar/Combiner.h
@@ -22,6 +22,8 @@ private:
   Value *simplifyInst(Instruction *I);
   Value *simplifyAdd(AddInst *I);
   Value *simplifySDiv(SDivInst *I);
+
+  Value *simplifyICmp(ICmpInst *I);
   Value *simplifyPhi(PhiInst *I);
 
   std::queue<Instruction *> worklist;
diff --git a/src/nnvm/nnvm/Transform/Scalar/ConstantFold.cpp b/src/nnvm/nnvm/Transform/Scalar/ConstantFold.cpp
index 606e4a8..51882c0 100644
--- a/src/nnvm/nnvm/Transform/Scalar/ConstantFold.cpp
+++ b/src/nnvm/nnvm/Transform/Scalar/ConstantFold.cpp
@@ -45,6 +45,9 @@ Value *ConstantFold::fold(Instruction *I) {
         dyn_cast<Constant>(CI->getOperand(1)))
       return foldICmp(CI);
 
+  if (LoadInst *LI = dyn_cast<LoadInst>(I))
+    return foldLoad(LI);
+
   // TODO: Handle other operator on constant operands, such as "a[0]", where "a"
   // is a constant array.
 
@@ -123,3 +126,13 @@ Value *ConstantFold::foldICmp(ICmpInst *I) {
     return nullptr;
   }
 }
+
+Value *ConstantFold::foldLoad(LoadInst *I) {
+  if (GlobalVariable *GV = dyn_cast<GlobalVariable>(I->getSrc())) {
+    if (!GV->isImmutable())
+      return nullptr;
+    if (I->getType() == GV->getInitVal()->getType())
+      return GV->getInitVal();
+  }
+  return nullptr;
+}
diff --git a/src/nnvm/nnvm/Transform/Scalar/ConstantFold.h b/src/nnvm/nnvm/Transform/Scalar/ConstantFold.h
index 9b7184d..e0f2063 100644
--- a/src/nnvm/nnvm/Transform/Scalar/ConstantFold.h
+++ b/src/nnvm/nnvm/Transform/Scalar/ConstantFold.h
@@ -25,6 +25,7 @@ public:
   Value *foldSDiv(SDivInst *I);
   Value *foldSRem(SRemInst *I);
   Value *foldICmp(ICmpInst *I);
+  Value *foldLoad(LoadInst *I);
 
   void setModule(Module *module) { this->module = module; }
 
diff --git a/timestamp b/timestamp
index 1494600..d058203 100644
--- a/timestamp
+++ b/timestamp
@@ -1 +1 @@
-Sun Jul  7 08:24:46 UTC 2024
+Sun Jul  7 16:33:38 UTC 2024
-- 
GitLab