Joos1W Compiler Framework
All Classes Functions Typedefs Pages
SimplifyCFG.cc
1 #include <unordered_set>
2 
3 #include "../IRContextPass.h"
4 #include "tir/BasicBlock.h"
5 #include "tir/Instructions.h"
6 #include "utils/PassManager.h"
7 
8 using std::string_view;
9 using utils::Pass;
10 using utils::PassManager;
11 
12 namespace {
13 
14 // Remove all instructions after the first terminator in a basic block
15 bool eliminateAfterFirstTerminator(tir::BasicBlock& bb) {
16  auto firstTerminator = bb.begin();
17  for(auto it = bb.begin(); it != bb.end(); ++it) {
18  if(it->isTerminator()) {
19  firstTerminator = it;
20  break;
21  }
22  }
23  if(firstTerminator == --bb.end()) return false;
24  auto instr = *(++firstTerminator);
25  while(instr != nullptr) {
26  auto next = instr->next();
27  bb.erase(instr);
28  instr = next;
29  }
30  return true;
31 }
32 
33 // Merge two basic blocks if the second one has a single predecessor
34 // and the first one has a single successor
35 bool mergeSinglePredSingleSucc(tir::BasicBlock& bb) {
36  // 1. Grab the successor of the current basic block
37  auto term = dyn_cast<tir::BranchInst>(bb.terminator());
38  if(term == nullptr) return false;
39  auto succ = term->getSuccessor(0);
40  if(term->getSuccessor(1) != succ) return false;
41  // 2. Check if the successor's predecessors is this BB
42  for(auto pred : succ->users()) {
43  if(pred != term) return false;
44  }
45  // 3. Move all the instructions from the successor to the current basic block
46  term->eraseFromParent();
47  tir::Instruction* instr = *succ->begin();
48  while(instr) {
49  auto next = instr->next();
50  instr->eraseFromParent(true); // Keep all refs
51  bb.appendAfterEnd(instr);
52  instr = next;
53  }
54  // 4. Remove the successor from the parent function
55  succ->eraseFromParent();
56  return true;
57 }
58 
59 /**
60  * @brief Merge basic block with single successor in one branch.
61  *
62  * +---------------+
63  * | bb0: |
64  * | br bb1, bb2 | =====> Can be transformed into br bb1, bb3
65  * +-+----------+--+
66  * | |
67  * v v
68  * +----+--+ +---+------+
69  * | bb1: | | bb2: |
70  * | ... | | br bb3 |
71  * +-------+ +----------+
72  *
73  */
74 bool replaceSucessorInOneBranch(tir::BasicBlock& bb) {
75  bool changed = false;
76  auto term = dyn_cast<tir::BranchInst>(bb.terminator());
77  if(term == nullptr) return changed;
78  // Try to replace either successor
79  for(int i = 0; i < 2; i++) {
80  auto succ = term->getSuccessor(i);
81  // 1. There must be only one successor
82  if(++succ->begin() != succ->end()) continue;
83  // 2. The sucessor must be an unconditional branch
84  auto sterm = dyn_cast<tir::BranchInst>(succ->terminator());
85  if(sterm == nullptr) continue;
86  if(sterm->getSuccessor(0) != sterm->getSuccessor(1)) continue;
87  // 3. Then this sucessor can be replaced
88  term->replaceSuccessor(i, sterm->getSuccessor(0));
89  changed = true;
90  }
91  return changed;
92 }
93 
94 } // namespace
95 
96 /* ===--------------------------------------------------------------------=== */
97 // SimplifyCFG pass infrastructure
98 /* ===--------------------------------------------------------------------=== */
99 
100 class SimplifyCFG final : public Pass {
101 public:
102  SimplifyCFG(PassManager& PM) noexcept : Pass(PM) {}
103  void Run() override {
104  tir::CompilationUnit& CU = GetPass<IRContextPass>().CU();
105  std::vector<tir::BasicBlock*> toRemove;
106  for(auto func : CU.functions()) {
107  // 1. Iteratively simplify the CFG
108  bool changed = false;
109  do {
110  changed = false;
111  visited.clear();
112  if(func->getEntryBlock()) {
113  changed = visitBB(*func->getEntryBlock());
114  }
115  } while(changed);
116  // 2. Record all the basic blocks that were not visited
117  toRemove.clear();
118  for(auto bb : func->body()) {
119  if(visited.count(bb)) continue;
120  toRemove.push_back(bb);
121  }
122  // 3. Remove the basic blocks that are unreachable
123  for(auto bb : toRemove) {
124  bb->eraseFromParent();
125  }
126  }
127  }
128  string_view Name() const override { return "simplifycfg"; }
129  string_view Desc() const override { return "Simplify CFG"; }
130 
131 private:
132  bool visitBB(tir::BasicBlock& bb) {
133  bool changed = false;
134  if(visited.count(&bb)) return false;
135  visited.insert(&bb);
136  // 1. Run all the simplifications
137  changed |= eliminateAfterFirstTerminator(bb);
138  changed |= mergeSinglePredSingleSucc(bb);
139  changed |= replaceSucessorInOneBranch(bb);
140  // 2. Grab the next basic block
141  auto term = dyn_cast<tir::BranchInst>(bb.terminator());
142  if(term == nullptr) return false;
143  changed |= visitBB(*term->getSuccessor(0));
144  changed |= visitBB(*term->getSuccessor(1));
145  return changed;
146  }
147 
148 private:
149  void computeDependencies() override {
150  ComputeDependency(GetPass<IRContextPass>());
151  }
152  std::unordered_set<tir::BasicBlock*> visited;
153 };
154 
155 REGISTER_PASS(SimplifyCFG);