Joos1W Compiler Framework
All Classes Functions Typedefs Pages
PassManager.h
1 #pragma once
2 
3 #include <utils/Error.h>
4 
#include <cassert>
#include <memory>
#include <string>
#include <string_view>
#include <tuple>
#include <type_traits>
#include <typeinfo>
#include <unordered_map>
#include <utility>
#include <vector>
10 
11 #include "diagnostics/Diagnostics.h"
12 #include "third-party/CLI11.h"
13 #include "utils/BumpAllocator.h"
14 #include "utils/Generator.h"
15 
16 namespace utils {
17 
18 class PassManager;
19 class Pass;
20 class PassOptions;
21 
22 template <typename T>
23 concept PassType = std::is_base_of_v<Pass, T>;
24 
/// @brief Command-line state shared by passes: a reference to the CLI11 app,
/// the string of passes parsed from the command line, and, for every
/// registered pass name, its description and enabled flag.
/// Copy and move construction/assignment are deleted.
25 class PassOptions {
26 public:
27  PassOptions(CLI::App& app) : app_{app} {}
28  PassOptions(PassOptions const&) = delete;
29  PassOptions(PassOptions&&) = delete;
30  PassOptions& operator=(PassOptions const&) = delete;
31  PassOptions& operator=(PassOptions&&) = delete;
32  ~PassOptions() = default;
33 
/// @brief Looks up a CLI option by the single name extracted from a CLI11
/// option spec (see get_single_name). A one-character name is looked up as
/// "-x"; any longer name is looked up as "--name".
/// @return The option, or nullptr when get_option_no_throw finds none
34  CLI::Option* FindOption(std::string_view name) {
35  std::string test_name = "--" + get_single_name(name);
// "--x" has size 3: drop one dash so single-letter names use the short form
36  if(test_name.size() == 3) test_name.erase(0, 1);
37  return app_.get_option_no_throw(test_name);
38  }
39 
/// @brief Same lookup as FindOption, but the option must exist.
/// @throws utils::FatalError if no option matches @p name
40  CLI::Option* GetExistingOption(std::string name) {
41  auto opt = FindOption(name);
42  if(opt == nullptr)
43  throw utils::FatalError("pass requested nonexistent option: " + name);
44  return opt;
45  }
46 
/// @brief Checks whether pass @p p is disabled (defined out of line).
47  /// @return True if the pass is disabled
48  bool IsPassDisabled(Pass* p);
49 
50  /// @brief Enables or disables a pass given the pass name
/// @param name A pass name registered in pass_descs_
/// @param enabled The new enabled state (defaults to true)
/// @note An unknown name fails the assert in debug builds and is silently
/// ignored when NDEBUG is defined.
51  void EnablePass(std::string_view name, bool enabled = true) {
52  if(auto it = pass_descs_.find(std::string{name});
53  it != pass_descs_.end()) {
54  it->second.enabled = enabled;
55  } else {
56  assert(false && "Pass not found");
57  }
58  }
59 
60  /// @brief Iterate through the registered passes, yielding each pass name together with its description
62  for(auto& [name, desc] : pass_descs_) co_yield {name, desc.desc};
63  }
64 
65  /// @brief Check whether a pass with the given name is registered
/// @return True if @p name is a key of pass_descs_
67  return pass_descs_.find(std::string{name}) != pass_descs_.end();
68  }
69 
70 private:
71  /**
72  * @brief Get the single name of a command line option
73  * @param name The option name to get the single name of (ie., "-o,--option")
74  * @return std::string The single name (ie., "option")
 * Preference order: the first long name, then the first short name, then
 * the positional name; if none of these were parsed, @p name is returned
 * unchanged.
75  */
77  std::vector<std::string> s, l;
78  std::string p;
79  std::tie(s, l, p) =
81  if(!l.empty()) return l[0];
82  if(!s.empty()) return s[0];
83  if(!p.empty()) return p;
84  return std::string{name};
85  }
86 
87  /// @brief Sets the pass to be enabled or disabled
88  /// @param p The pass to set
89  /// @param enabled If true, the pass is enabled
90  void setPassEnabled(Pass* p, bool enabled);
91 
92 private:
93  friend class Pass;
94  friend class PassManager;
95 
// The CLI11 application that options are registered on and looked up in
96  CLI::App& app_;
97  // A list of passes parsed from the command line
98  std::string passes_;
99  /// @brief A map of pass names to descriptions and whether they are enabled
100  struct PassDesc {
101  bool enabled;
102  std::string desc;
103  };
104  std::unordered_map<std::string, PassDesc> pass_descs_;
105 };
106 
107 /* ===--------------------------------------------------------------------=== */
108 // Pass
109 /* ===--------------------------------------------------------------------=== */
110 
/// @brief Base class for every pass owned by a PassManager. A pass is
/// constructed with a reference to its owning PassManager (see
/// PassManager::AddPass), cannot be copied or moved, and must implement
/// Run(), Desc() and computeDependencies().
111 class Pass {
112 private:
113  // Deleted copy and move constructor and assignment operator
114  Pass(Pass const&) = delete;
115  Pass(Pass&&) = delete;
116  Pass& operator=(Pass const&) = delete;
117  Pass& operator=(Pass&&) = delete;
118 
119 public:
120  virtual ~Pass() = default;
121  /// @brief Function to override when you want to acquire resources (default: no-op)
122  virtual void Init() {}
123  /// @brief Function to override to run the pass
124  virtual void Run() = 0;
125  /// @brief Function to override to get the name (id) of the pass
/// An empty name (the default) means PassManager::AddPass does not call
/// RegisterCLI() for this pass.
126  virtual std::string_view Name() const { return ""; }
127  /// @brief Function to override to get the description of the pass
128  virtual std::string_view Desc() const = 0;
129  /// @brief Preserve the analysis results of this pass
130  void Preserve() { preserve = true; }
131  /// @brief Should this pass be preserved? (true once Preserve() was called)
132  bool ShouldPreserve() const { return preserve; }
133 
134 protected:
135  /// @brief Gets the pass manager that owns the pass
136  auto& PM() { return pm_; }
137 
138  /// @brief Gets a single pass of type T. Throws if no pass is found.
139  /// Also throws if multiple passes of type T are found.
/// Also throws if this pass is Running and the found pass is not Valid.
140  /// @tparam T The type of the pass
141  /// @return T& The pass of type T
142  template <typename T>
143  requires PassType<T>
144  T& GetPass();
145 
146  /// @brief Gets a single pass by name. Throws if no pass is found.
147  Pass& GetPass(std::string_view name);
148 
149  /// @brief Gets all passes of type T. Throws if no pass is found.
/// NOTE(review): the not-found error is thrown from inside the generator
/// body, so presumably it surfaces only once iteration reaches the end —
/// confirm against Generator's semantics.
150  /// @tparam T The type of the pass
151  /// @return A generator that yields all passes of type T
152  template <typename T>
153  requires PassType<T>
154  Generator<T*> GetPasses();
155 
156  /// @brief Computes a dependency between this and another pass
/// NOTE(review): presumably records that this pass depends on @p pass
/// (cf. PassManager::addDependency) — TODO confirm in the definition.
157  /// @param pass The pass to add as a dependency
158  void ComputeDependency(Pass& pass);
159 
160  /// @brief Requests a new heap from the pass manager
161  /// @return CustomBufferResource* The new heap
163 
164 protected:
165  friend class PassManager;
166  /// @brief Override to state the dependencies of this pass
167  virtual void computeDependencies() = 0;
168  /// @brief Constructor for the pass
169  /// @param pm The pass manager that owns the pass
170  explicit Pass(PassManager& pm) noexcept : pm_(pm) {}
171 
172 private:
173  /// @brief Adds a switch to enable the pass
174  void RegisterCLI();
175 
176 private:
177  PassManager& pm_;
// Lifecycle state of the pass. PassManager::getPass requires a requested
// pass to be Valid whenever the requester is Running.
// NOTE(review): transitions presumably follow declaration order — confirm
// in PassManager::Run.
178  enum State {
179  Uninitialized,
180  PropagateEnabled,
181  AcquireResources,
182  RegisterDependencies,
183  Running,
184  Cleanup,
185  Valid,
186  Invalid
187  };
188  State state = State::Uninitialized;
// Set by Preserve(), read by ShouldPreserve()
189  bool preserve = false;
190 };
191 
192 /* ===--------------------------------------------------------------------=== */
193 // PassManager
194 /* ===--------------------------------------------------------------------=== */
195 
196 class PassManager final {
197 private:
198  /// @brief A heap is a bump allocator that is used to allocate memory. The
199  /// heap is owned by a pass and is destroyed when no future passes will
200  /// require it.
201  struct Heap {
202  std::unique_ptr<CustomBufferResource> heap;
203  Pass* owner;
204  int refCount;
205  Heap(Pass* owner)
206  : heap{std::make_unique<CustomBufferResource>()},
207  owner{owner},
208  refCount{1} {}
209  };
210 
211 public:
212  PassManager(CLI::App& app) : options_{app}, reuseHeaps_{true} {}
213  /// @brief Runs all the passes in the pass manager
214  /// @return True if all passes ran successfully
215  bool Run();
216 
217  /// @brief Resets the pass manager and frees all resources
218  void Reset();
219 
220  /// @return The last pass that was run by the pass manager
221  Pass const* LastRun() const { return lastRun_; }
222 
223  /// @brief Sets whether the pass manager should reuse heaps
224  void SetHeapReuse(bool reuse) { reuseHeaps_ = reuse; }
225 
226  // Deleted copy and move constructor and assignment operator
227  PassManager(PassManager const&) = delete;
228  PassManager(PassManager&&) = delete;
229  PassManager& operator=(PassManager const&) = delete;
230  PassManager& operator=(PassManager&&) = delete;
231 
232  ~PassManager() {
233  // Clean up passes, and then heaps
234  passes_.clear();
235  heaps_.clear();
236  }
237 
238  /// @brief Adds a pass to the pass manager
239  /// @tparam T The type of the passr
240  /// @param ...args The remaining arguments to pass to the pass constructor.
241  /// The pass manager will be passed to the pass as the first argument.
242  template <typename T, typename... Args>
243  requires PassType<T>
244  T& AddPass(Args&&... args) {
245  passes_.emplace_back(new T(*this, std::forward<Args>(args)...));
246  T& result = *cast<T*>(passes_.back().get());
247  // If the pass has a name, register it as constructible from the command
248  // line options
249  if(!result.Name().empty()) result.RegisterCLI();
250  return result;
251  }
252 
253  /// @brief Gets a reference to the diagnostic engine
254  diagnostics::DiagnosticEngine& Diag() { return diag_; }
255 
256  /// @brief Gets the pass options
257  PassOptions& PO() { return options_; }
258 
259  /// @brief Outside method to get a pass by type
260  template <typename T>
261  requires PassType<T>
262  T& FindPass() {
263  T* result = nullptr;
264  for(auto& pass : passes_) {
265  if(auto* p = dyn_cast<T*>(pass.get())) {
266  if(result != nullptr)
267  throw utils::FatalError("Multiple passes of type: " +
268  std::string(typeid(T).name()));
269  result = p;
270  }
271  }
272  if(result == nullptr) {
273  throw utils::FatalError("Pass not found: " +
274  std::string(typeid(T).name()));
275  }
276  return *result;
277  }
278 
279  /// @brief Declare an analysis to be preserved forever
280  template <typename T>
281  requires PassType<T>
282  void PreserveAnalysis() {
283  auto& pass = FindPass<T>();
284  pass.Preserve();
285  }
286 
287 private:
288  template <typename T>
289  requires PassType<T>
290  T& getPass(Pass& pass) {
291  auto& result = FindPass<T>();
292  // If the requester is running, the result must be valid
293  if(pass.state == Pass::Running && result.state != Pass::Valid) {
294  throw utils::FatalError("Pass not valid: " +
295  std::string(typeid(T).name()));
296  }
297  return result;
298  }
299 
300  template <typename T>
301  requires PassType<T>
302  Generator<T*> getPasses(Pass& pass) {
303  bool found = false;
304  for(auto& pass : passes_) {
305  if(auto* p = dyn_cast<T*>(pass.get())) {
306  // If the requester is running, the result must be valid
307  if(p->state == Pass::Running && p->state != Pass::Valid) {
308  throw utils::FatalError("Pass not valid: " +
309  std::string(typeid(T).name()));
310  }
311  co_yield p;
312  found = true;
313  }
314  }
315  if(!found)
316  throw utils::FatalError("Pass of type not found: " +
317  std::string(typeid(T).name()));
318  }
319 
320  /// @brief Gets a single pass by name. Throws if no pass is found.
321  Pass& getPass(std::string_view name) {
322  for(auto& pass : passes_)
323  if(pass->Name() == name) return *pass;
324  throw utils::FatalError("Pass not found: " + std::string{name});
325  }
326 
327 private:
328  friend class Pass;
329  friend class HeapRef;
330  CustomBufferResource* newHeap(Pass& pass);
331  void freeHeap(Heap& heap);
332  void addDependency(Pass& pass, Pass& depends) {
333  depgraph_[&depends].push_back(&pass);
334  passDeps_[&pass]++;
335  }
336 
337 private:
338  std::vector<std::unique_ptr<Pass>> passes_;
339  std::vector<Heap> heaps_;
340  std::unordered_map<Pass*, std::vector<Pass*>> depgraph_;
341  diagnostics::DiagnosticEngine diag_;
342  Pass* lastRun_ = nullptr;
343  std::unordered_map<Pass*, int> passDeps_;
344  PassOptions options_;
345  bool reuseHeaps_;
346 };
347 
348 template <typename T>
349  requires PassType<T>
350 T& Pass::GetPass() {
351  return PM().getPass<T>(*this);
352 }
353 
354 template <typename T>
355  requires PassType<T>
356 Generator<T*> Pass::GetPasses() {
357  return PM().getPasses<T>(*this);
358 }
359 
/**
 * @brief Registers NS::T with the pass manager.
 *
 * Expands to the definition of a factory function
 * `utils::Pass& New##T(utils::PassManager&)` that adds an NS::T (constructed
 * with only the pass manager) to the given manager and returns it. The
 * matching declaration is produced by DECLARE_PASS(T).
 *
 * The trailing backslash is required: without it the macro body would be
 * empty and the factory line would be emitted verbatim (and be ill-formed).
 */
#define REGISTER_PASS_NS(NS, T) \
   utils::Pass& New##T(utils::PassManager& PM) { return PM.AddPass<NS::T>(); }
365 
/**
 * @brief Registers T with the pass manager.
 *
 * Expands to the definition of a factory function
 * `utils::Pass& New##T(utils::PassManager&)` that adds a T (constructed with
 * only the pass manager) to the given manager and returns it. The matching
 * declaration is produced by DECLARE_PASS(T).
 *
 * The trailing backslash is required: without it the macro body would be
 * empty and the factory line would be emitted verbatim (and be ill-formed).
 */
#define REGISTER_PASS(T) \
   utils::Pass& New##T(utils::PassManager& PM) { return PM.AddPass<T>(); }
371 
/**
 * @brief Declares the factory function New##PassName whose definition is
 * produced by REGISTER_PASS(PassName) or REGISTER_PASS_NS(NS, PassName).
 */
#define DECLARE_PASS(PassName) \
   utils::Pass& New##PassName(utils::PassManager& pm);
373 
374 } // namespace utils