Files
pl/include/ir/ir.hpp

168 lines
4.8 KiB
C++

#pragma once
#include <cassert>
#include <unordered_map>
#include <utility>

#include "prelude/string.hpp"
#include "parser/ast.hpp"
#include "ir/slot.hpp"
#include "ir/allocator.hpp"
#include "ir/value.hpp"
#include "ir/ops.hpp"
namespace IR
{
// Lowers a parsed AST (rooted at a ProgramNode) into a flat list of IR ops.
//
// Every intermediate result is written to a fresh virtual register (`Value`)
// numbered by a monotonically increasing counter. Local variables are not
// stored anywhere: a declaration just binds its name to the register holding
// the initializer, and references resolve to that register.
//
// NOTE(review): `m_ops` is allocated with `new` in the constructor and there
// is no destructor, so the top-level OpBuilder is never released; copying an
// IRBuilder also shallow-copies that pointer. The OpViews returned by
// Build()/ops() point straight into `m_ops->data`, so ownership of that
// buffer has to be decided before adding a destructor.
class IRBuilder
{
public:
    // `root` is borrowed, not owned; Build() asserts it is a ProgramNode.
    IRBuilder(const Node* root)
        : m_root(root), m_ops(new OpBuilder()) {}
public:
    // Emits a LoadConstOp of the literal's integer into a new register and
    // returns that register.
    // TODO: support other literals
    Value ParseIntLiteral(const IntLiteralNode* literal)
    {
        auto dst = AllocateRegister();
        m_ops->Push(new LoadConstOp(dst, literal->integer()));
        return dst;
    }

    // Resolves a variable reference to the register bound by its declaration.
    // No op is emitted for the reference itself.
    Value ParseVariable(const VariableNode* var)
    {
        // auto dst = AllocateRegister();
        // m_ops->Push(new LoadOp(dst, var->name()));

        // Reuse the iterator from find(): a follow-up operator[] would hash
        // again and, with asserts compiled out (NDEBUG), would insert a
        // default Value for the unknown name into m_locals.
        auto it = m_locals.find(var->name());
        if (it == m_locals.end())
        {
            // TODO: throw proper error
            assert(0 && "ERROR: variable does not exist");
            return Value();
        }
        return it->second;
    }

    // Lowers a call: evaluates the optional single argument into a register,
    // then emits a CallOp whose result goes into a new register.
    // TODO: support multiple args
    // NOTE(review): `argRegs` is a local that is destroyed on return while
    // CallOp keeps whatever argRegs.view() returned. This is only safe if
    // ValueBuilder does not free its buffer on destruction — confirm against
    // ValueBuilder (compare ParseBlock, which frees its OpBuilder with
    // `operator delete` so that no destructor runs after taking view()).
    Value ParseFnCall(const FnCallNode* fn)
    {
        auto argRegs = ValueBuilder();
        if (fn->arg() != nullptr)
        {
            argRegs.Push(ParseExpression(fn->arg()));
        }
        auto dst = AllocateRegister();
        m_ops->Push(new CallOp(dst, fn->name(), argRegs.view()));
        return dst;
    }

    // Dispatches a leaf of an expression tree (literal, variable or call) to
    // its dedicated lowering routine.
    Value ParseFactor(const Node* factor)
    {
        switch (factor->GetType())
        {
        case NodeType::IntLiteral: return ParseIntLiteral(reinterpret_cast<const IntLiteralNode*>(factor));
        case NodeType::Variable: return ParseVariable(reinterpret_cast<const VariableNode*>(factor));
        case NodeType::FnCall: return ParseFnCall(reinterpret_cast<const FnCallNode*>(factor));
        default: assert(0 && "some factor may not be handled"); break;
        }
        assert(0 && "unreachable");
        return Value();
    }

    // Lowers an expression. A binary ExpressionNode emits its left operand's
    // ops, then its right operand's, then the operation into a new register;
    // any other node is lowered as a factor. Only `Plus` is implemented.
    Value ParseExpression(const Node* expression)
    {
        if (expression->GetType() == NodeType::Expression)
        {
            // Compile-time check: adding an operator breaks the build here
            // instead of failing at runtime on the first expression parsed.
            static_assert(4 == static_cast<int>(ExpressionNode::Operator::COUNT_OPERATORS), "some operators may not be handled");
            auto expr = reinterpret_cast<const ExpressionNode*>(expression);
            auto lhs = ParseExpression(expr->left());
            auto rhs = ParseExpression(expr->right());
            auto dst = AllocateRegister();
            switch (expr->op())
            {
            case ExpressionNode::Operator::Plus: m_ops->Push(new AddOp(dst, lhs, rhs)); break;
            default: assert(0 && "TODO: implement other operations"); break;
            }
            return dst;
        }
        return ParseFactor(expression);
    }

    // Evaluates the initializer and binds the variable's name to the
    // resulting register. The initializer is lowered before binding, so it
    // still sees any previous binding of the same name. A redeclaration
    // replaces the old binding, so later references see the newest value
    // (std::unordered_map::insert would keep the first binding and silently
    // drop the new one).
    void ParseVarDecl(const VarDeclNode* varDecl)
    {
        auto value = ParseExpression(varDecl->value());
        // m_ops->Push(new StoreOp(varDecl->name(), value));
        m_locals[varDecl->name()] = value;
    }

    // Lowers every statement of `compound` into its own op list and wraps it
    // in a Block numbered by m_block_counter. While the block is open, ops go
    // to the new list; the enclosing list is restored afterwards.
    Block ParseBlock(const CompoundNode* compound)
    {
        StartBlock();
        for (auto &statement : *compound)
        {
            switch (statement->GetType())
            {
            case NodeType::VarDecl: ParseVarDecl(reinterpret_cast<const VarDeclNode*>(statement)); break;
            default: ParseExpression(statement); break;
            }
        }
        auto ops = EndBlock();
        auto block = Block(m_block_counter++, std::move(ops->view()));
        // Releases only the OpBuilder object's own storage; its destructor is
        // not run. Presumably this keeps the op buffer now referenced by
        // `block` alive — confirm against OpBuilder before turning this into
        // `delete ops`.
        operator delete(ops);
        return block;
    }

    // Lowers the whole program into the top-level op list. It emits one
    // ExternOp per extern, then one FnOp per function, with each function
    // body lowered into its own Block. Returns the same view as ops().
    OpView Build()
    {
        assert(m_root->GetType() == NodeType::Program && "root should be a program");
        auto program = reinterpret_cast<const ProgramNode*>(m_root);
        // Externs
        for (auto &extrn : program->externs())
        {
            m_ops->Push(new ExternOp(extrn->symbol()));
        }
        // Functions
        for (auto &fn : program->funcs())
        {
            // Locals are function-scoped. Without this reset, a name declared
            // in an earlier function would still resolve (to that function's
            // register) inside every function lowered after it.
            m_locals.clear();
            auto block = ParseBlock(fn->body());
            m_ops->Push(new FnOp(fn->name(), fn->params(), std::move(block)));
        }
        return ops();
    }
public:
    // Non-owning view over the top-level op list. It points directly at
    // m_ops' storage rather than copying it.
    // TODO: think about safety (copying m_ops->data before giving)
    OpView ops() const { return OpView(m_ops->data, m_ops->size); }
private:
    // Saves the current op list on m_containers and redirects emission into
    // a fresh OpBuilder.
    void StartBlock()
    {
        m_containers.Push(m_ops);
        m_ops = new OpBuilder();
    }

    // Pops the op list saved by the matching StartBlock(), makes it current
    // again and returns the builder that was active. The builder is no longer
    // referenced by this object, so the caller is responsible for it.
    OpBuilder* EndBlock()
    {
        assert(m_containers.size > 0 && "containers stack is empty");
        auto current = m_ops;
        m_ops = m_containers.data[m_containers.size - 1];
        m_containers.size--;
        return current;
    }
private:
    // Returns a new register. The counter is shared by the whole builder and
    // never reset, so register numbers are unique across all functions.
    Value AllocateRegister()
    {
        return Value(m_value_counter++);
    }
private:
    const Node* m_root = nullptr;          // AST being lowered (borrowed)
    OpBuilder* m_ops = nullptr;            // op list currently receiving emitted ops
    unsigned int m_value_counter = 0;      // next register number
    unsigned int m_block_counter = 0;      // next Block id
    std::unordered_map<StringView, Value> m_locals; // variable name -> register, for the current function
    Builder<OpBuilder*> m_containers;      // enclosing op lists saved by StartBlock()
};
} // namespace IR