#pragma once #include "string.hpp" #include "ast.hpp" namespace IR { enum class OpType { EXTERN = 0, FN, LOAD_CONST, LOAD, STORE, ADD, CALL, COUNT_OPS, }; #define OP_TYPE(x) \ OpType GetType() const override { return OpType::x; } using Reg = int; using RegBuilder = Builder; using RegView = View; class Op { public: virtual OpType GetType() const = 0; virtual ~Op() {} virtual StringView Format(int indent) const = 0; }; using OpView = View; using OpBuilder = Builder; class Valued { public: Valued(Reg dest) : m_dest(dest) {} ~Valued() = default; public: Reg result() const { return m_dest; } private: Reg m_dest; }; class ExternOp : public Op { public: ExternOp(StringView symbol) : m_symbol(symbol) {} ~ExternOp() {} OP_TYPE(EXTERN) public: StringView Format(int indent) const override { StringBuilder sb; sb.AppendIndent(indent); sb << "EXTRN " << m_symbol.c_str(); return sb.view(); } public: const StringView& symbol() const { return m_symbol; } private: StringView m_symbol; }; class FnOp : public Op { public: FnOp(StringView name, const CompoundNode* body, const View& params); ~FnOp() {} OP_TYPE(FN) public: StringView Format(int indent) const override { StringBuilder sb; sb.AppendIndent(indent); sb << "LABEL " << m_name.c_str() << ':' << '\n'; for (size_t i = 0; i < m_ops.size; ++i) { sb << m_ops.data[i]->Format(indent + 2) << '\n'; } return sb.view(); } public: const StringView& name() const { return m_name; } const OpView& ops() const { return m_ops; } const View& params() const { return m_params; } private: StringView m_name; OpView m_ops; View m_params; }; class LoadConstOp : public Op, public Valued { public: LoadConstOp(Reg dest, long value) : Valued(dest), m_value(value) {} ~LoadConstOp() {} OP_TYPE(LOAD_CONST) public: StringView Format(int indent) const override { StringBuilder sb; sb.AppendIndent(indent); sb << 't' << result() << " = LOAD_CONST " << m_value; return sb.view(); } public: long value() const { return m_value; } private: long m_value; }; class LoadOp : public Op, public Valued { public: LoadOp(Reg dest, StringView addr) : Valued(dest), m_addr(addr) {} ~LoadOp() {} OP_TYPE(LOAD) public: StringView Format(int indent) const override { StringBuilder sb; sb.AppendIndent(indent); sb << 't' << result() << " = LOAD \"" << m_addr.c_str() << "\""; return sb.view(); } public: const StringView& addr() const { return m_addr; } private: StringView m_addr; }; class StoreOp : public Op { public: StoreOp(StringView addr, Reg src) : m_addr(addr), m_src(src) {} ~StoreOp() {} OP_TYPE(STORE) public: StringView Format(int indent) const override { StringBuilder sb; sb.AppendIndent(indent); sb << "STORE \"" << m_addr.c_str() << "\", t" << m_src; return sb.view(); } public: const StringView& addr() const { return m_addr; } Reg src() const { return m_src; } private: StringView m_addr; Reg m_src; }; class AddOp : public Op, public Valued { public: AddOp(Reg dest, Reg lhs, Reg rhs) : Valued(dest), m_lhs(lhs), m_rhs(rhs) {} ~AddOp() {} OP_TYPE(ADD) public: StringView Format(int indent) const override { StringBuilder sb; sb.AppendIndent(indent); sb << 't' << result() << " = ADD t" << m_lhs << ", t" << m_rhs; return sb.view(); } public: Reg lhs() const { return m_lhs; } Reg rhs() const { return m_rhs; } private: Reg m_lhs; Reg m_rhs; }; class CallOp : public Op, public Valued { public: CallOp(Reg dest, StringView callee, RegView args) : Valued(dest), m_callee(callee), m_args(args) {} ~CallOp() {} OP_TYPE(CALL) public: StringView Format(int indent) const override { StringBuilder sb; for (size_t i = 0; i < m_args.size; ++i) { sb.AppendIndent(indent); sb << "PARAM t" << m_args.data[i] << '\n'; } sb.AppendIndent(indent); sb << 't' << result() << " = CALL " << m_callee.c_str(); return sb.view(); } public: const StringView& callee() const { return m_callee; } const RegView& args() const { return m_args; } private: StringView m_callee; RegView m_args; }; class IRBuilder { public: IRBuilder(const Node* root) : m_root(root) {} public: // TODO: support other literals Reg ParseIntLiteral(const IntLiteralNode* literal) { auto dst = AllocateRegister(); m_ops.Push(new LoadConstOp(dst, literal->integer())); return dst; } Reg ParseVariable(const VariableNode* var) { auto dst = AllocateRegister(); m_ops.Push(new LoadOp(dst, var->name())); return dst; } Reg ParseFnCall(const FnCallNode* fn) { // TODO: support multiple args auto argRegs = RegBuilder(); if (fn->arg() != nullptr) { auto arg = ParseExpression(fn->arg()); argRegs.Push(arg); } auto dst = AllocateRegister(); m_ops.Push(new CallOp(dst, fn->name(), argRegs.view())); return dst; } Reg ParseFactor(const Node* factor) { switch(factor->GetType()) { case NodeType::IntLiteral: return ParseIntLiteral(reinterpret_cast(factor)); case NodeType::Variable: return ParseVariable(reinterpret_cast(factor)); case NodeType::FnCall: return ParseFnCall(reinterpret_cast(factor)); default: assert(0 && "some factor may not be handled"); break; } assert(0 && "unreachable"); return -1; } Reg ParseExpression(const Node* expression) { if (expression->GetType() == NodeType::Expression) { auto expr = reinterpret_cast(expression); auto lhs = ParseExpression(expr->left()); auto rhs = ParseExpression(expr->right()); auto dst = AllocateRegister(); assert(4 == static_cast(ExpressionNode::Operator::COUNT_OPERATORS) && "some operators may not be handled"); switch (expr->op()) { case ExpressionNode::Operator::Plus: m_ops.Push(new AddOp(dst, lhs, rhs)); break; default: assert(0 && "TODO: implement other operations"); break; } return dst; } return ParseFactor(expression); } void ParseVarDecl(const VarDeclNode* varDecl) { auto value = ParseExpression(varDecl->value()); m_ops.Push(new StoreOp(varDecl->name(), value)); } void ParseBlock(const CompoundNode* compound) { for (auto &statement : *compound) { switch(statement->GetType()) { case NodeType::VarDecl: ParseVarDecl(reinterpret_cast(statement)); continue; default: ParseExpression(statement); continue; } } } OpView Build() { assert(m_root->GetType() == NodeType::Program && "root should be a program"); auto program = reinterpret_cast(m_root); // Externs for (auto &extrn : program->externs()) { m_ops.Push(new ExternOp(extrn->symbol())); } // Functions for (auto &fn : program->funcs()) { m_ops.Push(new FnOp(fn->name(), fn->body(), fn->params())); } return OpView(m_ops.data, m_ops.size); } public: // TODO: think about safety (copying m_ops.data before giving) OpView ops() const { return OpView(m_ops.data, m_ops.size); } private: Reg AllocateRegister() { return m_reg_counter++; } private: OpBuilder m_ops; const Node* m_root = nullptr; Reg m_reg_counter = 0; }; } // namespace IR