Joos1W Compiler Framework
All Classes Functions Typedefs Pages
CGExpr.cc
1 #include "codegen/CGExpr.h"
2 
3 #include <utility>
4 
5 #include "ast/Decl.h"
6 #include "ast/Type.h"
7 #include "codegen/CodeGen.h"
8 #include "semantic/NameResolver.h"
9 #include "tir/Constant.h"
10 #include "tir/IRBuilder.h"
11 #include "tir/Instructions.h"
12 #include "tir/Type.h"
13 #include "utils/Utils.h"
14 
15 using namespace tir;
16 namespace ex = ast::exprnode;
17 using T = codegen::CGExprEvaluator::T;
18 
19 /* ===--------------------------------------------------------------------=== */
20 // Conversion functions in T::
21 /* ===--------------------------------------------------------------------=== */
22 
23 tir::Value* T::asRValue(tir::IRBuilder& builder) const {
24  assert(kind_ == Kind::L || kind_ == Kind::R);
25  auto [_, type, value] = std::get<TirWrapped>(data_);
26  if(kind_ == Kind::L) {
27  assert(!type->isPointerType() || isAstTypeReference(astType()));
28  return builder.createLoadInstr(type, value);
29  } else {
30  return value;
31  }
32 }
33 
34 tir::Value* T::asLValue() const {
35  assert(kind_ == Kind::L);
36  return std::get<TirWrapped>(data_).value;
37 }
38 
39 tir::Value* T::asFn() const {
40  assert(kind_ == Kind::StaticFn || kind_ == Kind::MemberFn);
41  return std::get<FnWrapped>(data_).fn;
42 }
43 
44 ast::Type const* T::astType() const {
45  if(kind_ == Kind::L || kind_ == Kind::R) {
46  return std::get<TirWrapped>(data_).astType;
47  } else if(kind_ == Kind::AstType) {
48  return std::get<ast::Type const*>(data_);
49  }
50  assert(false);
51 }
52 
53 ast::Decl const* T::asDecl() const {
54  if(kind_ == Kind::AstDecl)
55  return std::get<ast::Decl const*>(data_);
56  else if(kind_ == Kind::StaticFn || kind_ == Kind::MemberFn)
57  return std::get<FnWrapped>(data_).decl;
58  assert(false);
59 }
60 
61 tir::Type* T::irType() const {
62  assert(kind_ == Kind::L || kind_ == Kind::R);
63  return std::get<TirWrapped>(data_).type;
64 }
65 
66 bool T::validate(CodeGenerator& cg) const {
67  // 1. Check the values are not empty
68  switch(kind_) {
69  case Kind::L:
70  case Kind::R:
71  assert(irType() != nullptr);
72  assert(std::get<TirWrapped>(data_).value != nullptr);
73  break;
74  case Kind::StaticFn:
75  case Kind::MemberFn:
76  assert(std::get<FnWrapped>(data_).fn != nullptr);
77  break;
78  case Kind::AstType:
79  assert(std::get<ast::Type const*>(data_) != nullptr);
80  break;
81  case Kind::AstDecl:
82  assert(std::get<ast::Decl const*>(data_) != nullptr);
83  break;
84  }
85  // 2. If the kind is an L/R value, check the AST type agrees
86  if(kind_ == Kind::R || kind_ == Kind::L) {
87  auto type = std::get<TirWrapped>(data_).type;
88  auto astTy = std::get<TirWrapped>(data_).astType;
89  assert(type == cg.emitType(astTy));
90  }
91  return true;
92 }
93 
94 void T::dump() const {
95  switch(kind_) {
96  case Kind::L:
97  std::cout << "L-value: ";
98  std::get<TirWrapped>(data_).value->dump();
99  break;
100  case Kind::R:
101  std::cout << "R-value: ";
102  std::get<TirWrapped>(data_).value->dump();
103  break;
104  case Kind::StaticFn:
105  std::cout << "Static function: ";
106  asFn()->dump();
107  break;
108  case Kind::MemberFn:
109  std::cout << "Member function: ";
110  asFn()->dump();
111  break;
112  case Kind::AstType:
113  std::cout << "AST type: ";
114  astType()->dump();
115  break;
116  case Kind::AstDecl:
117  std::cout << "AST decl: ";
118  asDecl()->dump();
119  break;
120  }
121 }
122 
123 /* ===--------------------------------------------------------------------=== */
124 // CodeGenerator expression evaluator, helper functions first
125 /* ===--------------------------------------------------------------------=== */
126 
127 namespace codegen {
128 T CGExprEvaluator::castIntegerType(ast::Type const* aTy, tir::Type* ty,
129  T value) const {
130  using CastOp = ICastInst::CastOp;
131  auto srcAstTy = cast<ast::BuiltInType>(value.astType());
132  auto dstBits = cast<tir::IntegerType>(ty)->getBitWidth();
133  auto srcBits = srcAstTy->typeSizeBits();
134  auto isSrcSigned = srcAstTy->getKind() != ast::BuiltInType::Kind::Char;
135  auto isNarrowing = dstBits < srcBits;
136  auto isWidening = dstBits > srcBits;
137  Instruction* castInst = nullptr;
138  // Narrowing == truncation
139  if(isNarrowing) {
140  castInst = cg.builder.createICastInstr(
141  CastOp::Trunc, value.asRValue(cg.builder), ty);
142  }
143  // Widening == sign extension, if the source is signed
144  else if(isWidening && isSrcSigned) {
145  castInst = cg.builder.createICastInstr(
146  CastOp::SExt, value.asRValue(cg.builder), ty);
147  }
148  // Widening == zero extension, if the source is unsigned
149  else if(isWidening && !isSrcSigned) {
150  castInst = cg.builder.createICastInstr(
151  CastOp::ZExt, value.asRValue(cg.builder), ty);
152  }
153  // Identity cast
154  else {
155  return value;
156  }
157  return T::R(aTy, castInst);
158 }
159 } // namespace codegen
160 
161 static CmpInst::Predicate getPredicate(ex::BinaryOp::OpType op) {
162  using OpType = ex::BinaryOp::OpType;
163  switch(op) {
164  case OpType::GreaterThan:
165  return CmpInst::Predicate::GT;
166  case OpType::GreaterThanOrEqual:
167  return CmpInst::Predicate::GE;
168  case OpType::LessThan:
169  return CmpInst::Predicate::LT;
170  case OpType::LessThanOrEqual:
171  return CmpInst::Predicate::LE;
172  case OpType::Equal:
173  return CmpInst::Predicate::EQ;
174  case OpType::NotEqual:
175  return CmpInst::Predicate::NE;
176  default:
177  assert(false);
178  }
179  std::unreachable();
180 }
181 
182 static Instruction::BinOp getBinOp(ex::BinaryOp::OpType op) {
183  using OpType = ex::BinaryOp::OpType;
184  switch(op) {
185  case OpType::BitwiseAnd:
186  return Instruction::BinOp::And;
187  case OpType::BitwiseOr:
188  return Instruction::BinOp::Or;
189  case OpType::BitwiseXor:
190  return Instruction::BinOp::Xor;
191  case OpType::Add:
192  return Instruction::BinOp::Add;
193  case OpType::Subtract:
194  return Instruction::BinOp::Sub;
195  case OpType::Multiply:
196  return Instruction::BinOp::Mul;
197  case OpType::Divide:
198  return Instruction::BinOp::Div;
199  case OpType::Modulo:
200  return Instruction::BinOp::Rem;
201  default:
202  assert(false);
203  }
204  std::unreachable();
205 }
206 
207 static auto findArrayField(semantic::NameResolver& nr) {
208  for(auto field : nr.GetArrayPrototype()->fields()) {
209  if(field->name() == "length") {
210  return field;
211  }
212  }
213  assert(false && "Array prototype field not found");
214 }
215 
216 /* ===--------------------------------------------------------------------=== */
217 // Emit specific expressions
218 /* ===--------------------------------------------------------------------=== */
219 
220 namespace codegen {
221 
222 T CGExprEvaluator::mapValue(ex::ExprValue& node) const {
223  auto aTy = node.type();
224  if(auto methodName = dyn_cast<ex::MethodName>(node)) {
225  auto* methodDecl = cast<ast::MethodDecl>(methodName->decl());
226  auto kind = methodDecl->modifiers().isStatic() ? T::Kind::StaticFn
227  : T::Kind::MemberFn;
228  auto fn = cg.gvMap[methodDecl];
229  // TODO: Virtual functions should be handled somewhere here?
230  return T::Fn(kind, methodDecl, fn);
231  } else if(auto memberName = dyn_cast<ex::MemberName>(node)) {
232  auto irTy = cg.emitType(cast<ast::TypedDecl>(memberName->decl())->type());
233  // 1. If it's a field decl, handle the static and non-static cases
234  if(auto* fieldDecl = dyn_cast<ast::FieldDecl>(memberName->decl())) {
235  // a) If it's static, then grab the GV
236  if(fieldDecl->modifiers().isStatic()) {
237  auto GV = cg.gvMap[fieldDecl];
238  return T::L(aTy, irTy, GV);
239  }
240  // b) Otherwise we need to wrap it to resolve in MemberAccess
241  else {
242  return T{fieldDecl};
243  }
244  }
245  // 2. If it's a local (var) decl, grab the alloca inst
246  else {
247  auto* localDecl = cast<ast::VarDecl>(memberName->decl());
248  return T::L(aTy, irTy, cg.valueMap[localDecl]);
249  }
250  } else if(auto thisNode = dyn_cast<ex::ThisNode>(node)) {
251  // "this" will be the first argument of the function
252  return T::L(aTy, cg.emitType(aTy), curFn.args().front());
253  } else if(auto literal = dyn_cast<ex::LiteralNode>(node)) {
254  if(literal->builtinType()->isNumeric()) {
255  auto bits = static_cast<uint8_t>(literal->builtinType()->typeSizeBits());
256  auto val = literal->getAsInt();
257  return T::R(aTy, Constant::CreateInt(ctx, bits, val));
258  } else if(literal->builtinType()->isBoolean()) {
259  return T::R(aTy, Constant::CreateBool(ctx, literal->getAsInt()));
260  } else if(literal->builtinType()->isString()) {
261  // TODO: String type
262  return T::L(
263  aTy, Type::getPointerTy(ctx), Constant::CreateNullPointer(ctx));
264  } else {
265  // Null type
266  return T::R(aTy, Constant::CreateNullPointer(ctx));
267  }
268  } else if(auto type = dyn_cast<ex::TypeNode>(node)) {
269  return T{aTy};
270  }
271  std::unreachable();
272 }
273 
274 T CGExprEvaluator::evalBinaryOp(ex::BinaryOp& op, T lhs, T rhs) const {
275  using OpType = ex::BinaryOp::OpType;
276  auto aTy = op.resultType();
277  switch(op.opType()) {
278  // Assignment expression //
279  case OpType::Assignment:
280  cg.builder.createStoreInstr(rhs.asRValue(cg.builder), lhs.asLValue());
281  return lhs;
282 
283  // Comparison expressions //
284  case OpType::GreaterThan:
285  case OpType::GreaterThanOrEqual:
286  case OpType::LessThan:
287  case OpType::LessThanOrEqual:
288  case OpType::Equal:
289  case OpType::NotEqual: {
290  auto inst = cg.builder.createCmpInstr(getPredicate(op.opType()),
291  lhs.asRValue(cg.builder),
292  rhs.asRValue(cg.builder));
293  return T::R(aTy, inst);
294  }
295 
296  // Short circuit expressions //
297  case OpType::And: {
298  /*
299  curBB:
300  %v0 = i1 eval(lhs)
301  store i1 %v0, %tmp
302  br i1 %v0, bb1, bb2
303  bb1:
304  %v1 = i1 eval(rhs)
305  store i1 %v1, %tmp
306  br bb2
307  bb2:
308  %tmp as lvalue
309  */
310  auto tmp = cg.curFn->createAlloca(Type::getInt1Ty(ctx));
311  auto bb1 = cg.builder.createBasicBlock(&curFn);
312  auto bb2 = cg.builder.createBasicBlock(&curFn);
313  bb1->setName("and.true");
314  bb2->setName("and.false");
315  auto v0 = lhs.asRValue(cg.builder);
316  cg.builder.createStoreInstr(v0, tmp);
317  cg.builder.createBranchInstr(v0, bb1, bb2);
318  cg.builder.setInsertPoint(bb1);
319  auto v1 = rhs.asRValue(cg.builder);
320  cg.builder.createStoreInstr(v1, tmp);
321  cg.builder.createBranchInstr(bb2);
322  cg.builder.setInsertPoint(bb2);
323  return T::L(aTy, Type::getInt1Ty(ctx), tmp);
324  }
325  case OpType::Or: {
326  /*
327  curBB:
328  v0 = i1 eval(lhs)
329  store i1 %v0, %tmp
330  br i1 %v0, bb2, bb1
331  bb1:
332  v1 = i1 eval(rhs)
333  store i1 %v1, %tmp
334  br bb2
335  bb2:
336  %tmp as lvalue
337  */
338  auto tmp = cg.curFn->createAlloca(Type::getInt1Ty(ctx));
339  auto bb1 = cg.builder.createBasicBlock(&curFn);
340  auto bb2 = cg.builder.createBasicBlock(&curFn);
341  bb1->setName("or.true");
342  bb2->setName("or.false");
343  auto v0 = lhs.asRValue(cg.builder);
344  cg.builder.createStoreInstr(v0, tmp);
345  cg.builder.createBranchInstr(v0, bb2, bb1);
346  cg.builder.setInsertPoint(bb1);
347  auto v1 = rhs.asRValue(cg.builder);
348  cg.builder.createStoreInstr(v1, tmp);
349  cg.builder.createBranchInstr(bb2);
350  cg.builder.setInsertPoint(bb2);
351  return T::L(aTy, Type::getInt1Ty(ctx), tmp);
352  }
353 
354  // Arithmetic expressions //
355  case OpType::BitwiseAnd:
356  case OpType::BitwiseOr:
357  case OpType::BitwiseXor:
358  case OpType::Add:
359  case OpType::Subtract:
360  case OpType::Multiply:
361  case OpType::Divide:
362  case OpType::Modulo: {
363  // 1. Promote the operands to i32
364  auto lhsP = castIntegerType(aTy, Type::getInt32Ty(ctx), lhs);
365  auto rhsP = castIntegerType(aTy, Type::getInt32Ty(ctx), rhs);
366  // 2. Compute in i32
367  auto res = cg.builder.createBinaryInstr(getBinOp(op.opType()),
368  lhsP.asRValue(cg.builder),
369  rhsP.asRValue(cg.builder));
370  // 3. Narrow back to aTy implicitly
371  auto emittedTy = cg.emitType(aTy);
372  assert(res->type() == emittedTy);
373  return castIntegerType(aTy, emittedTy, T::R(aTy, res));
374  }
375 
376  // Instance of expression //
377  case OpType::InstanceOf: {
378  // TODO(kevin): Implement
379  return T::R(aTy, Constant::CreateBool(ctx, false));
380  }
381 
382  default:
383  break;
384  }
385  std::unreachable();
386 }
387 
388 T CGExprEvaluator::evalUnaryOp(ex::UnaryOp& op, T rhs) const {
389  using BinOp = Instruction::BinOp;
390  using OpType = ex::UnaryOp::OpType;
391  auto aTy = op.resultType();
392  auto value = rhs.asRValue(cg.builder);
393  auto ty = value->type();
394  switch(op.opType()) {
395  case OpType::Not:
396  case OpType::BitwiseNot: {
397  // We actually don't need to promote the operand here
398  auto allOnes = ConstantInt::AllOnes(ctx, ty);
399  auto instr = cg.builder.createBinaryInstr(BinOp::Xor, value, allOnes);
400  return T::R(aTy, instr);
401  }
402  case OpType::Plus: {
403  // Do nothing for unary plus
404  return rhs;
405  }
406  case OpType::Minus: {
407  // We also don't need to promote the operand here either
408  auto instr = cg.builder.createBinaryInstr(
409  BinOp::Sub, ConstantInt::Zero(ctx, ty), value);
410  return T::R(aTy, instr);
411  }
412  default:
413  break;
414  }
415  std::unreachable();
416 }
417 
418 T CGExprEvaluator::evalMemberAccess(ex::MemberAccess& op, T lhs, T field) const {
419  auto aTy = op.resultType();
420  auto obj = lhs.asRValue(cg.builder);
421  auto decl = field.asDecl();
422  // Special case: "field" is actually a function
423  if(field.kind() == T::Kind::MemberFn) {
424  return T::Fn(T::Kind::MemberFn, decl, field.asFn(), obj);
425  }
426  // Special case: array.length
427  else if(decl == findArrayField(cg.nr)) {
428  auto arrTy = cast<StructType>(lhs.irType());
429  auto arrSzGep =
430  cg.builder.createGEPInstr(obj, arrTy, {Constant::CreateInt32(ctx, 0)});
431  auto arrSz = cg.builder.createLoadInstr(Type::getInt32Ty(ctx), arrSzGep);
432  return T::R(aTy, arrSz);
433  }
434  // Member access
435  else {
436  assert(false);
437  }
438 }
439 
440 T CGExprEvaluator::evalMethodCall(ex::MethodInvocation& op, T method,
441  const op_array& args) const {
442  auto aTy = op.resultType();
443  std::vector<Value*> argValues;
444  // If this is a member function, push back an extra "this"
445  if(method.kind() == T::Kind::MemberFn) {
446  assert(method.thisRef());
447  argValues.push_back(method.thisRef());
448  } else {
449  assert(method.kind() == T::Kind::StaticFn);
450  }
451  // Now we can push back the arguments
452  for(auto& arg : args) {
453  argValues.push_back(arg.asRValue(cg.builder));
454  }
455  auto callVal = cg.builder.createCallInstr(method.asFn(), argValues);
456  return T::R(aTy, callVal);
457 }
458 
459 T CGExprEvaluator::evalNewObject(ex::ClassInstanceCreation& op, T object,
460  const op_array& args) const {
461  (void)op;
462  (void)object;
463  (void)args;
464  // TODO: Implement this
465  return T::L(op.resultType(),
466  Type::getPointerTy(ctx),
467  Constant::CreateNullPointer(ctx));
468 }
469 
470 T CGExprEvaluator::evalNewArray(ex::ArrayInstanceCreation& op, T type,
471  T size) const {
472  auto aTy = op.resultType();
473  auto arrTy = cast<StructType>(cg.emitType(aTy));
474  auto elemTy = cg.emitType(type.astType());
475  auto arrLength = castIntegerType(nullptr, Type::getInt32Ty(ctx), size)
476  .asRValue(cg.builder);
477  auto totalSz = cg.builder.createBinaryInstr(
478  Instruction::BinOp::Mul,
479  arrLength,
480  Constant::CreateInt32(ctx, elemTy->getSizeInBits() / 8));
481  auto arrPtr = cg.builder.createCallInstr(cu.builtinMalloc(), {totalSz});
482  auto alloca = curFn.createAlloca(cg.emitType(op.resultType()));
483  cg.emitSetArrayPtr(alloca, arrPtr);
484  cg.emitSetArraySz(alloca, arrLength);
485  auto loadArr = cg.builder.createLoadInstr(arrTy, arrPtr);
486  cg.builder.createStoreInstr(loadArr, alloca);
487  totalSz->setName("arr.sz");
488  arrPtr->setName("arr.ptr");
489  alloca->setName("arr.alloca");
490  return T::L(aTy, arrTy, alloca);
491 }
492 
493 T CGExprEvaluator::evalArrayAccess(ex::ArrayAccess& op, T array, T index) const {
494  auto arrAlloca = array.asLValue();
495  auto elemAstTy = op.resultType();
496  auto arrTy = cast<StructType>(array.irType());
497  auto arrPtr = cg.emitGetArrayPtr(arrAlloca);
498  auto arrSz = cg.emitGetArraySz(arrAlloca);
499  auto idxVal = index.asRValue(cg.builder);
500  auto lengthValid =
501  cg.builder.createCmpInstr(CmpInst::Predicate::LT, idxVal, arrSz);
502  auto bb1 = cg.builder.createBasicBlock(&curFn);
503  bb1->setName("array.oob");
504  auto bb2 = cg.builder.createBasicBlock(&curFn);
505  bb2->setName("array.inbounds");
506  cg.builder.createBranchInstr(lengthValid, bb2, bb1);
507  cg.builder.setInsertPoint(bb1);
508  cg.builder.createCallInstr(cu.builtinException(), {});
509  cg.builder.setInsertPoint(bb2);
510  auto elemPtr = cg.builder.createGEPInstr(arrPtr, arrTy, {idxVal});
511  return T::L(elemAstTy, cg.emitType(elemAstTy), elemPtr);
512 }
513 
514 T CGExprEvaluator::evalCast(ex::Cast& op, T type, T value) const {
515  auto aTy = op.resultType();
516  auto castType = type.astType();
517  if(castType->isNumeric()) {
518  // Convert either promotion or narrowing
519  return castIntegerType(aTy, cg.emitType(castType), value);
520  } else if(castType->isBoolean()) {
521  // Booleans must be identity conversion
522  return value;
523  } else if(castType->isString()) {
524  } else if(castType->isArray()) {
525  } else {
526  }
527  assert(false);
528 }
529 
530 } // namespace codegen
531 
532 /* ===--------------------------------------------------------------------=== */
533 // CodeGenerator emit router
534 /* ===--------------------------------------------------------------------=== */
535 
536 namespace codegen {
537 
538 Value* CodeGenerator::emitExpr(ast::Expr const* expr) {
539  CGExprEvaluator eval{*this};
540  T result = eval.EvaluateList(expr->list());
541  return result.asRValue(builder);
542 }
543 
544 } // namespace codegen