Files
pl/include/ir/ir.hpp

203 lines
6.6 KiB
C++

#pragma once
#include <cassert>
#include <unordered_map>
#include <utility>
#include "parser/nodes.hpp"
#include "prelude/string.hpp"
#include "ir/value.hpp"
#include "ir/ops.hpp"
namespace IR
{
class IRBuilder
{
public:
IRBuilder(const Node *root)
: m_root(root), m_ops(new OpBuilder()) {}
public:
// TODO: support other literals
ValueHandle *ParseIntLiteral(const IntLiteralNode *literal)
{
auto dst = AllocateUnnamed<ConstantInt>(literal->integer());
return dst;
}
void ParseVarDecl(const VarDeclNode *varDecl)
{
auto value = ParseExpression(varDecl->value());
// TODO: gather type information from var decl signature, aka local <int> v = 0;
auto dst = AllocateNamed<Pointer>(value->GetType());
m_ops->Push(new AllocateOp(dst, value->GetType()));
m_ops->Push(new StoreOp(value, reinterpret_cast<Pointer *>(dst)));
m_locals.insert(std::make_pair(varDecl->name(), reinterpret_cast<Pointer *>(dst)));
}
ValueHandle *ParseVariable(const VariableNode *var)
{
if (m_locals.find(var->name()) == m_locals.end())
{
// TODO: throw proper error
assert(0 && "ERROR: variable does not exist");
}
auto dst = AllocateNamed<Instruction>(m_locals[var->name()]->GetValueType());
m_ops->Push(new LoadOp(dst, m_locals[var->name()]));
return reinterpret_cast<ValueHandle *>(dst);
}
ValueHandle *ParseFnCall(const FnCallNode *fn)
{
// TODO: support multiple args
auto argRegs = ValueBuilder();
if (fn->arg() != nullptr)
{
auto arg = ParseExpression(fn->arg());
argRegs.Push(arg);
}
// TODO: gather return type of the function
auto dst = AllocateNamed<Instruction>(new ValueHandle::Type {ValueHandle::Type::Kind::Void});
m_ops->Push(new CallOp(dst, fn->name(), argRegs.view()));
return dst;
}
ValueHandle *ParseFactor(const Node *factor)
{
switch (factor->GetType())
{
case NodeType::IntLiteral:
return ParseIntLiteral(reinterpret_cast<const IntLiteralNode *>(factor));
case NodeType::Variable:
return ParseVariable(reinterpret_cast<const VariableNode *>(factor));
case NodeType::FnCall:
return ParseFnCall(reinterpret_cast<const FnCallNode *>(factor));
default:
assert(0 && "some factor may not be handled");
break;
}
assert(0 && "unreachable");
return nullptr;
}
ValueHandle *ParseExpression(const Node *expression)
{
if (expression->GetType() == NodeType::Expression)
{
auto expr = reinterpret_cast<const ExpressionNode *>(expression);
auto lhs = ParseExpression(expr->left());
auto rhs = ParseExpression(expr->right());
auto dst = AllocateNamed<Instruction>(lhs->GetType());
assert(4 == static_cast<int>(ExpressionNode::Operator::COUNT_OPERATORS) && "some operators may not be handled");
switch (expr->op())
{
case ExpressionNode::Operator::Plus:
m_ops->Push(new MathOp(dst, lhs, rhs, OpType::ADD));
break;
case ExpressionNode::Operator::Multiply:
m_ops->Push(new MathOp(dst, lhs, rhs, OpType::MUL));
break;
case ExpressionNode::Operator::Minus:
m_ops->Push(new MathOp(dst, lhs, rhs, OpType::SUB));
break;
case ExpressionNode::Operator::Divide:
m_ops->Push(new MathOp(dst, lhs, rhs, OpType::DIV));
break;
default:
assert(0 && "unreachable");
break;
}
return dst;
}
return ParseFactor(expression);
}
Block ParseBlock(const CompoundNode *compound)
{
StartBlock();
for (auto &statement : *compound)
{
switch (statement->GetType())
{
case NodeType::VarDecl:
ParseVarDecl(reinterpret_cast<VarDeclNode *>(statement));
continue;
default:
ParseExpression(statement);
continue;
}
}
auto ops = EndBlock();
auto block = Block(m_block_counter++, std::move(ops->view()));
operator delete(ops);
return block;
}
OpView Build()
{
assert(m_root->GetType() == NodeType::Program && "root should be a program");
auto program = reinterpret_cast<const ProgramNode *>(m_root);
// Externs
for (auto &extrn : program->externs())
{
m_ops->Push(new ExternOp(extrn->symbol()));
}
// Functions
for (auto &fn : program->funcs())
{
auto block = ParseBlock(fn->body());
m_ops->Push(new FnOp(fn->name(), fn->params(), std::move(block)));
}
return OpView(m_ops->data, m_ops->size);
}
public:
OpView ops() const { return OpView(m_ops->data, m_ops->size); }
private:
void StartBlock()
{
m_containers.Push(m_ops);
m_ops = new OpBuilder();
}
OpBuilder *EndBlock()
{
assert(m_containers.size > 0 && "containers stack is empty");
auto current = m_ops;
m_ops = m_containers.data[m_containers.size - 1];
m_containers.size--;
return current;
}
private:
template <typename V, typename... Args>
ValueHandle *AllocateNamed(Args &&...args)
{
return new V(++m_value_counter, std::forward<Args>(args)...);
}
template <typename V, typename... Args>
ValueHandle *AllocateUnnamed(Args &&...args)
{
return new V(ValueHandle::kNoId, std::forward<Args>(args)...);
}
private:
const Node *m_root = nullptr;
OpBuilder *m_ops = nullptr;
unsigned int m_value_counter = 0;
unsigned int m_block_counter = 0;
std::unordered_map<StringView, Pointer *> m_locals;
Builder<OpBuilder *> m_containers;
};
} // namespace IR