Forked from bddd / bddd
10 commits behind the upstream repository.
ast.h 17.10 KiB
#ifndef BDDD_AST_H
#define BDDD_AST_H
#include <cassert>
#include <fstream>
#include <map>
#include <memory>
#include <ostream>
#include <string>
#include <utility>
#include <variant>
#include <vector>
#include "ast/type.h"
class Value;
class IRBuilder;
/* used as var_type indicator of ExprAST instance */
enum class Op {
  // Arithmetic unary operators
  POSITIVE,
  NEGATIVE,
  // Arithmetic binary operators
  PLUS,
  MINUS,
  MULTI,
  DIV,
  MOD,
  // relational binary operators
  LE,
  LEQ,
  GE,
  GEQ,
  EQ,
  NEQ,
  // Logical binary operators
  AND,
  OR,
  // Logical unary operators
  NOT,
  // Special indicator
  CONST_INT,
  CONST_FLOAT,
  LVAL,
  FUNC_CALL,
enum class VarType {
  INT,    // i32
  FLOAT,  // float
  VOID,   // void
  CHAR,   // i8
  BOOL,   // i1
  UNKNOWN,
class SymbolTable;
class AST : public std::enable_shared_from_this<AST> {
public:
  virtual ~AST() = default;
  virtual void Debug(std::ofstream &ofs, int depth) = 0;
  virtual void TypeCheck(SymbolTable &symbol_table) = 0;
7172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140
virtual std::shared_ptr<Value> CodeGen(std::shared_ptr<IRBuilder> builder) = 0; protected: template <typename Derived> std::shared_ptr<Derived> shared_from_base() { return std::static_pointer_cast<Derived>(shared_from_this()); } }; class StmtAST : public AST {}; class DeclAST; class ExprAST; /** * @expr not null when single expression, this time vals is null * @vals not null when array of items, this time expr is null * @is_const available in typechecking */ class InitValAST : public AST { private: bool m_is_const; // true => expr is const or all sub-init-vals are const bool m_all_zero; // for global array, we can use "zero initialize" public: [[nodiscard]] bool IsConst() const { return m_is_const; } void SetIsConst(bool is_const) { m_is_const = is_const; } [[nodiscard]] bool AllZero() const { return m_all_zero; } void SetAllZero(bool all_zero) { m_all_zero = all_zero; } public: std::shared_ptr<ExprAST> m_expr; // if single init-val std::vector<std::unique_ptr<InitValAST>> m_vals; // if multiple sub init-vals explicit InitValAST() : m_expr(nullptr), m_vals(), m_is_const(false), m_all_zero(false) {} explicit InitValAST(std::unique_ptr<ExprAST> expr) : m_expr(std::move(expr)), m_vals(), m_is_const(false), m_all_zero(false) {} explicit InitValAST(std::unique_ptr<InitValAST> val) : m_expr(nullptr), m_vals(), m_is_const(false), m_all_zero(false) { m_vals.push_back(std::move(val)); } void AppendVal(std::unique_ptr<InitValAST> val) { m_vals.push_back(std::move(val)); } void Debug(std::ofstream &ofs, int depth) override; void TypeCheck(SymbolTable &symbol_table) override; bool FillVals(int n, int &offset, const std::vector<int> &sizes, std::vector<std::shared_ptr<ExprAST>> &vals); std::shared_ptr<Value> CodeGen(std::shared_ptr<IRBuilder> builder) override; friend class DeclAST; }; /** * @name cannot be empty * @dimensions optional (nullptr might appear in m_indices) * @decl available after typechecking */ class LValAST : public AST {
141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210
private: std::string m_name; std::vector<std::unique_ptr<ExprAST>> m_indices; public: std::shared_ptr<DeclAST> m_decl; std::string Name() const { return m_name; } explicit LValAST(std::string name) : m_name(std::move(name)), m_indices(), m_decl(nullptr) {} // methods used in AST construction void AddDimension(int x); void AddDimension(std::unique_ptr<ExprAST> expr); void Debug(std::ofstream &ofs, int depth) override; bool IsSingle(); bool IsArray(); bool HasIndex(); // methods used in typechecking void TypeCheck(SymbolTable &symbol_table) override; // called only when is_const is true and not an array EvalValue Evaluate(SymbolTable &symbol_table); // methods used in codegen std::shared_ptr<Value> CodeGen(std::shared_ptr<IRBuilder> builder) override; std::shared_ptr<Value> CodeGenAddr(std::shared_ptr<IRBuilder> builder); }; class FuncCallAST; /** * @op indicates the type of expression * @lhs meaningful when op is unary or binary operator * @rhs meaningful only if op is binary operator * @int_val meaningful only if op == CONST_INT * @float_val meaningful only if op == CONST_FLOAT * @func_call meaningful only if op == FUNC_CALL * @lval meaningful only if op == LVAL * @is_const available when typechecking */ class ExprAST : public AST { private: Op m_op; std::unique_ptr<ExprAST> m_lhs; std::unique_ptr<ExprAST> m_rhs; std::unique_ptr<FuncCallAST> m_func_call; int m_int_val; float m_float_val; std::unique_ptr<LValAST> m_lval; bool m_is_const; // true => can get value from int_val or float_val private: EvalValue EvaluateInner(SymbolTable &symbol_table); public: [[nodiscard]] bool IsConst() const { return m_is_const; } int IntVal() const { assert(m_op == Op::CONST_INT); return m_int_val; } float FloatVal() const { assert(m_op == Op::CONST_FLOAT); return m_float_val;
211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280
} void SetIsConst(bool is_const) { m_is_const = is_const; } Op GetOp() const { return m_op; } explicit ExprAST(Op op, std::unique_ptr<ExprAST> lhs, std::unique_ptr<ExprAST> rhs = nullptr) : m_op(op), m_lhs(std::move(lhs)), m_rhs(std::move(rhs)), m_func_call(nullptr), m_int_val(0), m_float_val(0.0), m_lval(nullptr), m_is_const(false) {} explicit ExprAST(std::unique_ptr<FuncCallAST> func_call) : m_op(Op::FUNC_CALL), m_lhs(nullptr), m_rhs(nullptr), m_func_call(std::move(func_call)), m_int_val(0), m_float_val(0.0), m_lval(nullptr), m_is_const(false) {} explicit ExprAST(int val) : m_op(Op::CONST_INT), m_lhs(nullptr), m_rhs(nullptr), m_func_call(nullptr), m_int_val(val), m_float_val(0.0), m_lval(nullptr), m_is_const(true) {} explicit ExprAST(float val) : m_op(Op::CONST_FLOAT), m_lhs(nullptr), m_rhs(nullptr), m_func_call(nullptr), m_int_val(0), m_float_val(val), m_lval(nullptr), m_is_const(true) {} explicit ExprAST(std::unique_ptr<LValAST> lval) : m_op(Op::LVAL), m_lhs(nullptr), m_rhs(nullptr), m_func_call(nullptr), m_int_val(0), m_float_val(0.0), m_lval(std::move(lval)), m_is_const(false) {} void Debug(std::ofstream &ofs, int depth) override; void TypeCheck(SymbolTable &symbol_table) override; EvalValue Evaluate(SymbolTable &symbol_table); EvalValue EvaluateInitVal(); // no need for a symbol table std::shared_ptr<Value> CodeGen(std::shared_ptr<IRBuilder> builder) override; std::shared_ptr<Value> CodeGenAnd(std::shared_ptr<IRBuilder> builder); std::shared_ptr<Value> CodeGenOr(std::shared_ptr<IRBuilder> builder); void SetIntVal(int int_val) { m_op = Op::CONST_INT; m_lhs = nullptr; m_rhs = nullptr;
281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350
m_func_call = nullptr; m_int_val = int_val; m_float_val = 0.0; m_lval = nullptr; m_is_const = true; } void SetFloatVal(float float_val) { m_op = Op::CONST_FLOAT; m_lhs = nullptr; m_rhs = nullptr; m_func_call = nullptr; m_int_val = 0; m_float_val = float_val; m_lval = nullptr; m_is_const = true; } }; /** * @is_const indicates whether it is a const declaration * @is_global used in typechecking * @is_param not declaration of variable but an argument of function definition * @var_type only can be CONST_INT or CONST_FLOAT */ class DeclAST : public AST { private: bool m_is_const; // true => init_val is also const bool m_is_global; // false at default VarType m_var_type; std::string m_varname; bool m_is_param; // false at default public: std::vector<std::unique_ptr<ExprAST>> m_dimensions; std::unique_ptr<InitValAST> m_init_val; std::vector<std::shared_ptr<ExprAST>> m_flatten_vals; std::vector<int> m_products; // suffix product std::shared_ptr<Value> m_addr; void SetIsConst(bool is_const) { m_is_const = is_const; } void SetIsGlobal(bool is_global) { m_is_global = is_global; } void SetIsParam(bool is_param) { m_is_param = is_param; } void SetVarType(VarType var_type) { m_var_type = var_type; } void SetInitVal(std::unique_ptr<InitValAST> init_val) { m_init_val = std::move(init_val); } size_t DimensionsSize() const { return m_dimensions.size(); } bool IsGlobal() const { return m_is_global; } bool IsConst() const { return m_is_const; } bool IsParam() const { return m_is_param; } bool IsArray() const { return !m_dimensions.empty(); } VarType GetVarType() const { return m_var_type; } std::string VarName() const { return m_varname; } explicit DeclAST(std::string varname, std::unique_ptr<InitValAST> init_val = nullptr) : m_is_const(false), m_is_global(false), m_is_param(false), m_var_type(VarType::UNKNOWN), m_varname(std::move(varname)), m_init_val(std::move(init_val)), m_flatten_vals(), m_products() {} void AddDimension(int x); void AddDimension(std::unique_ptr<ExprAST> expr);
351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420
void Debug(std::ofstream &ofs, int depth) override; void TypeCheck(SymbolTable &symbol_table) override; // called only when is_const = true and flatten_vals is constructed EvalValue GetFlattenVal(SymbolTable &symbol_table, int offset); std::shared_ptr<Value> CodeGen(std::shared_ptr<IRBuilder> builder) override; }; class FuncDefAST; /** * @func_name the m_name of called function * @params the parameters of function call * @return_type the return type of function call, initially unknown */ class FuncCallAST : public AST { private: std::string m_func_name; std::vector<std::shared_ptr<ExprAST>> m_params; VarType m_return_type; // initially UNKNOWN, available after typechecking std::shared_ptr<FuncDefAST> m_func_def; // is nullptr until typechecking public: [[nodiscard]] size_t ParamsSize() const { return m_params.size(); } VarType ReturnType() const { return m_return_type; } std::string FuncName() const { return m_func_name; } public: explicit FuncCallAST(std::string func_name) : m_func_name(std::move(func_name)), m_params(), m_return_type(VarType::UNKNOWN) {} void AssignParams(std::vector<std::unique_ptr<ExprAST>> params); void Debug(std::ofstream &ofs, int depth) override; void TypeCheck(SymbolTable &symbol_table) override; std::shared_ptr<Value> CodeGen(std::shared_ptr<IRBuilder> builder) override; friend class IRBuilder; }; class CondAST : public AST { private: std::unique_ptr<ExprAST> m_expr; public: explicit CondAST(std::unique_ptr<ExprAST> expr) : m_expr(std::move(expr)) {} void Debug(std::ofstream &ofs, int depth) override; void TypeCheck(SymbolTable &symbol_table) override; std::shared_ptr<Value> CodeGen(std::shared_ptr<IRBuilder> builder) override; }; /** * @attention * FuncFParam is a special wrapper of DeclAST, in which m_is_const = false, * m_is_global = false, m_init_val = nullptr, and m_flatten_vals is meaningless * * We only use m_name, m_var_type and m_indices */ class FuncFParamAST : public AST { private: std::shared_ptr<DeclAST> m_decl;
421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490
public: explicit FuncFParamAST(VarType type, std::string name) : m_decl(std::make_shared<DeclAST>(std::move(name))) { m_decl->SetVarType(type); m_decl->SetIsParam(true); } explicit FuncFParamAST(VarType type, std::string name, std::unique_ptr<ExprAST> dimension) : m_decl(std::make_shared<DeclAST>(std::move(name))) { m_decl->SetVarType(type); m_decl->SetIsParam(true); m_decl->m_dimensions.push_back(std::move(dimension)); } void AddDimension(int x); void AddDimension(std::unique_ptr<ExprAST> expr); void Debug(std::ofstream &ofs, int depth) override; void TypeCheck(SymbolTable &symbol_table) override; std::shared_ptr<Value> CodeGen(std::shared_ptr<IRBuilder> builder) override; friend class Function; friend class FunctionArg; }; class BlockAST : public StmtAST { private: std::vector<std::shared_ptr<AST>> m_nodes; public: void AppendNodes(std::vector<std::unique_ptr<AST>> nodes); void Debug(std::ofstream &ofs, int depth) override; void TypeCheck(SymbolTable &symbol_table) override; std::shared_ptr<Value> CodeGen(std::shared_ptr<IRBuilder> builder) override; }; class FuncDefAST : public AST { private: VarType m_return_type; std::string m_func_name; std::vector<std::unique_ptr<FuncFParamAST>> m_params; std::unique_ptr<BlockAST> m_block; bool m_is_builtin; public: size_t ParamsSize() const { return m_params.size(); } VarType ReturnType() const { return m_return_type; } std::string FuncName() const { return m_func_name; } bool IsBuiltin() const { return m_is_builtin; } public: explicit FuncDefAST(VarType return_type, std::string func_name, std::vector<std::unique_ptr<FuncFParamAST>> params, std::unique_ptr<BlockAST> block, bool is_builtin = false) : m_return_type(return_type), m_func_name(std::move(func_name)), m_params(std::move(params)), m_block(std::move(block)), m_is_builtin(is_builtin) {} explicit FuncDefAST(VarType return_type, std::string func_name, std::unique_ptr<BlockAST> block, bool is_builtin = false) : m_return_type(return_type), m_func_name(std::move(func_name)), m_params(),
491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560
m_block(std::move(block)), m_is_builtin(is_builtin) {} void AssignParams(std::vector<std::unique_ptr<FuncFParamAST>> params); void Debug(std::ofstream &ofs, int depth) override; void TypeCheck(SymbolTable &symbol_table) override; std::shared_ptr<Value> CodeGen(std::shared_ptr<IRBuilder> builder) override; friend class Function; friend class FunctionDecl; }; class AssignStmtAST : public StmtAST { private: std::unique_ptr<LValAST> m_lval; std::unique_ptr<ExprAST> m_rhs; public: explicit AssignStmtAST(std::unique_ptr<LValAST> lval, std::unique_ptr<ExprAST> rhs) : m_lval(std::move(lval)), m_rhs(std::move(rhs)) {} void Debug(std::ofstream &ofs, int depth) override; void TypeCheck(SymbolTable &symbol_table) override; std::shared_ptr<Value> CodeGen(std::shared_ptr<IRBuilder> builder) override; }; class EvalStmtAST : public StmtAST { private: std::unique_ptr<ExprAST> m_expr; public: explicit EvalStmtAST(std::unique_ptr<ExprAST> expr) : m_expr(std::move(expr)) {} void Debug(std::ofstream &ofs, int depth) override; void TypeCheck(SymbolTable &symbol_table) override; std::shared_ptr<Value> CodeGen(std::shared_ptr<IRBuilder> builder) override; }; class IfStmtAST : public StmtAST { private: std::unique_ptr<CondAST> m_cond; std::unique_ptr<StmtAST> m_then; std::unique_ptr<StmtAST> m_else; public: explicit IfStmtAST(std::unique_ptr<CondAST> cond, std::unique_ptr<StmtAST> then_stmt, std::unique_ptr<StmtAST> else_stmt = nullptr) : m_cond(std::move(cond)), m_then(std::move(then_stmt)), m_else(std::move(else_stmt)) {} void Debug(std::ofstream &ofs, int depth) override; void TypeCheck(SymbolTable &symbol_table) override; std::shared_ptr<Value> CodeGen(std::shared_ptr<IRBuilder> builder) override; }; class ReturnStmtAST : public StmtAST { private: VarType m_expected_type; // initially void, filled in typechecking std::unique_ptr<ExprAST> m_ret; public:
561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621
explicit ReturnStmtAST(std::unique_ptr<ExprAST> ret = nullptr) : m_expected_type(VarType::VOID), m_ret(std::move(ret)) {} void Debug(std::ofstream &ofs, int depth) override; void TypeCheck(SymbolTable &symbol_table) override; std::shared_ptr<Value> CodeGen(std::shared_ptr<IRBuilder> builder) override; }; class WhileStmtAST : public StmtAST { private: std::unique_ptr<CondAST> m_cond; std::unique_ptr<StmtAST> m_stmt; public: explicit WhileStmtAST(std::unique_ptr<CondAST> cond, std::unique_ptr<StmtAST> stmt) : m_cond(std::move(cond)), m_stmt(std::move(stmt)) {} void Debug(std::ofstream &ofs, int depth) override; void TypeCheck(SymbolTable &symbol_table) override; std::shared_ptr<Value> CodeGen(std::shared_ptr<IRBuilder> builder) override; }; class BreakStmtAST : public StmtAST { public: void Debug(std::ofstream &ofs, int depth) override; void TypeCheck(SymbolTable &symbol_table) override; std::shared_ptr<Value> CodeGen(std::shared_ptr<IRBuilder> builder) override; }; class ContinueStmtAST : public StmtAST { public: void Debug(std::ofstream &ofs, int depth) override; void TypeCheck(SymbolTable &symbol_table) override; std::shared_ptr<Value> CodeGen(std::shared_ptr<IRBuilder> builder) override; }; class CompUnitAST : public AST { private: std::vector<std::shared_ptr<AST>> m_nodes; public: void AppendDecls(std::vector<std::unique_ptr<DeclAST>> decls); void AppendFuncDef(std::unique_ptr<FuncDefAST> funcDef); void Debug(std::ofstream &ofs, int depth) override; void TypeCheck(SymbolTable &symbol_table) override; std::shared_ptr<Value> CodeGen(std::shared_ptr<IRBuilder> builder) override; }; extern std::vector<std::shared_ptr<FuncDefAST>> g_builtin_funcs; // global void InitBuiltinFunctions(); #endif // BDDD_AST_H