Joos1W Compiler Framework
All Classes Functions Typedefs Pages
MemToReg.cc
1 #include <queue>
2 #include <unordered_set>
3 
4 #include "../IRContextPass.h"
5 #include "diagnostics/Diagnostics.h"
6 #include "tir/BasicBlock.h"
7 #include "tir/Constant.h"
8 #include "tir/Instructions.h"
9 #include "utils/PassManager.h"
10 
11 using std::string_view;
12 using utils::Pass;
13 using utils::PassManager;
14 using DE = diagnostics::DiagnosticEngine;
15 using namespace tir;
16 
17 // Enclose everything but the pass in an anonymous namespace
18 namespace {
19 
20 /* ===--------------------------------------------------------------------=== */
21 // Class definitions
22 /* ===--------------------------------------------------------------------=== */
23 
24 /**
25  * @brief A class to compute and store the dominator tree and
26  * dominance frontiers of a given function.
27  * Citation: A Simple, Fast Dominance Algorithm
28  * By: Keith D. Cooper, Timothy J. Harvey, and Ken Kennedy
29  */
30 class DominantorTree {
31 public:
32  DominantorTree(Function* func) : func(func) {
33  computePostorderIdx(func);
34  computeDominators(func);
35  computeFrontiers(func);
36  }
37 
38  std::ostream& print(std::ostream& os) const {
39  // Print the dominator tree
40  os << "*** Dominator Tree ***\n";
41  for(auto b : func->body()) {
42  if(doms.contains(b)) {
43  os << " Dom(";
44  b->printName(os) << ") = ";
45  doms.at(b)->printName(os) << std::endl;
46  }
47  }
48  // Print the dominance frontier
49  os << "*** Dominance Frontier ***\n";
50  for(auto [b, frontier] : frontiers) {
51  os << " DF(";
52  b->printName(os) << ") = {";
53  bool first = true;
54  for(auto f : frontier) {
55  if(!first) os << ", ";
56  first = false;
57  f->printName(os);
58  }
59  os << "}\n";
60  }
61  return os;
62  }
63 
64  // Get the dominance frontier of block b
65  auto& DF(BasicBlock* b) { return frontiers[b]; }
66 
67  // Get the immediate dominator of block b, nullptr if it does not exist
68  BasicBlock* getIDom(BasicBlock* b) {
69  return doms.contains(b) ? doms.at(b) : nullptr;
70  }
71 
72 private:
73  Function* func;
74  std::unordered_map<BasicBlock*, BasicBlock*> doms;
75  std::unordered_map<BasicBlock*, int> poidx;
76  std::unordered_map<BasicBlock*, std::unordered_set<BasicBlock*>> frontiers;
77 
78  void computePostorderIdx(Function* func);
79  void computeDominators(Function* func);
80  BasicBlock* intersect(BasicBlock* b1, BasicBlock* b2);
81  void computeFrontiers(Function* func);
82 };
83 
84 /**
85  * @brief Hoist alloca instructions into registers by placing phi nodes
86  * in the dominance frontier of the alloca.
87  * Citation: Simple and Efficient Construction of Static Single Assignment Form
88  * Authors: Ron Cytron, et al.
89  * URL: https://pages.cs.wisc.edu/~fischer/cs701.f14/ssa.pdf
90  * Also referenced: https://c9x.me/compile/bib/braun13cc.pdf
91  */
92 class HoistAlloca {
93 public:
94  HoistAlloca(Function* func, DE& diag) : func(func), DT{func} {
95  // 1. Print out the dominator tree
96  if(diag.Verbose(2)) {
97  auto dbg = diag.ReportDebug();
98  DT.print(dbg.get());
99  }
100  // 2. Place PHI nodes for each alloca
101  for(auto alloca : func->allocas()) {
102  if(canAllocaBeReplaced(alloca)) {
103  placePHINodes(alloca);
104  }
105  }
106 
107  // 3. Print out the alloca and phi nodes
108  if(diag.Verbose()) {
109  print(diag.ReportDebug().get());
110  }
111  }
112 
113  std::ostream& print(std::ostream& os) const {
114  os << "*** PHI node insertion points ***\n";
115  for(auto [alloca, phis] : allocaPhiMap) {
116  os << " Insert PHI for: ";
117  alloca->printName(os) << "\n";
118  for(auto bb : phis) {
119  os << " at: ";
120  bb->printName(os) << "\n";
121  }
122  }
123  return os;
124  }
125 
126 private:
127  Function* func;
128  DominantorTree DT;
129  // Places to insert PHI for alloca
130  std::unordered_map<AllocaInst*, std::vector<BasicBlock*>> allocaPhiMap;
131 
132 private:
133  void placePHINodes(AllocaInst* alloca);
134  bool canAllocaBeReplaced(AllocaInst* alloca);
135  void replaceUses(AllocaInst* alloca);
136 };
137 
138 /* ===--------------------------------------------------------------------=== */
139 // DominatorTree implementation
140 /* ===--------------------------------------------------------------------=== */
141 
142 void DominantorTree::computePostorderIdx(Function* func) {
143  int i = 0;
144  for(auto b : func->reversePostOrder()) poidx[b] = i++;
145 }
146 
147 void DominantorTree::computeDominators(Function* func) {
148  doms[func->getEntryBlock()] = func->getEntryBlock();
149  bool changed = true;
150  while(changed) {
151  changed = false;
152  for(auto b : func->reversePostOrder()) {
153  BasicBlock* newIdom = nullptr;
154  for(auto pred : b->predecessors()) {
155  if(doms.contains(pred)) {
156  if(!newIdom) {
157  newIdom = pred;
158  } else {
159  newIdom = intersect(pred, newIdom);
160  }
161  }
162  }
163  if(!newIdom) continue;
164  if(!doms.contains(b) || doms[b] != newIdom) {
165  doms[b] = newIdom;
166  changed = true;
167  }
168  }
169  }
170 }
171 
172 BasicBlock* DominantorTree::intersect(BasicBlock* b1, BasicBlock* b2) {
173  BasicBlock* finger1 = b1;
174  BasicBlock* finger2 = b2;
175  while(finger1 != finger2) {
176  while(poidx[finger1] > poidx[finger2]) finger1 = doms[finger1];
177  while(poidx[finger2] > poidx[finger1]) finger2 = doms[finger2];
178  }
179  return finger1;
180 }
181 
182 void DominantorTree::computeFrontiers(Function* func) {
183  for(auto b : func->body()) {
184  int numPreds = 0;
185  for(auto p : b->predecessors()) {
186  (void)p;
187  numPreds++;
188  }
189  if(numPreds < 2) continue;
190  for(auto pred : b->predecessors()) {
191  BasicBlock* runner = pred;
192  while(runner != doms[b]) {
193  frontiers[runner].insert(b);
194  runner = doms[runner];
195  }
196  }
197  }
198 }
199 
200 /* ===--------------------------------------------------------------------=== */
201 // HoistAlloca implementation
202 /* ===--------------------------------------------------------------------=== */
203 
204 bool HoistAlloca::canAllocaBeReplaced(AllocaInst* alloca) {
205  // 1. The allocas must be a scalar type
206  if(!alloca->type()->isIntegerType() && !alloca->type()->isPointerType())
207  return false;
208  // 2. Each use must be a load or store instruction
209  for(auto user : alloca->users()) {
210  if(!dyn_cast<LoadInst>(user) && !dyn_cast<StoreInst>(user)) return false;
211  }
212  return true;
213 }
214 
215 // Ref from paper, Figure 4. Placement of PHI-functions
216 void HoistAlloca::placePHINodes(AllocaInst* V) {
217  std::unordered_set<BasicBlock*> DFPlus;
218  std::unordered_set<BasicBlock*> Work;
219  std::queue<BasicBlock*> W;
220  // NOTE: A(V) = set of stores to V
221  for(auto user : V->users()) {
222  if(auto X = dyn_cast<StoreInst>(user)) {
223  Work.insert(X->parent());
224  W.push(X->parent());
225  }
226  }
227  while(!W.empty()) {
228  BasicBlock* X = W.front();
229  W.pop();
230  for(auto Y : DT.DF(X)) {
231  if(DFPlus.contains(Y)) continue;
232  allocaPhiMap[V].push_back(Y);
233  DFPlus.insert(Y);
234  if(!Work.contains(Y)) {
235  Work.insert(Y);
236  W.push(Y);
237  }
238  }
239  }
240 }
241 
242 // Ref from paper, Figure 5. Construction of SSA form
243 void HoistAlloca::replaceUses(AllocaInst* V) {
244 
245 }
246 
247 } // namespace
248 
249 /* ===--------------------------------------------------------------------=== */
250 // MemToReg pass wrapper
251 /* ===--------------------------------------------------------------------=== */
252 
253 class MemToReg final : public Pass {
254 public:
255  MemToReg(PassManager& PM) noexcept : Pass(PM) {}
256  void Run() override {
257  tir::CompilationUnit& CU = GetPass<IRContextPass>().CU();
258  for(auto func : CU.functions()) {
259  if(!func->hasBody() || !func->getEntryBlock()) continue;
260  if(PM().Diag().Verbose()) {
261  PM().Diag().ReportDebug()
262  << "*** Running on function: " << func->name() << " ***";
263  }
264  HoistAlloca hoister{func, PM().Diag()};
265  }
266  }
267  string_view Name() const override { return "mem2reg"; }
268  string_view Desc() const override { return "Promote memory to register"; }
269 
270 private:
271  void computeDependencies() override {
272  ComputeDependency(GetPass<IRContextPass>());
273  }
274 };
275 
276 REGISTER_PASS(MemToReg);