feat: migrate from OpView to DoubleLinkedList<Op*> + multiple

funciton arguments
This commit is contained in:
2026-01-06 16:39:51 +01:00
parent 806e20d9b1
commit 589586db51
10 changed files with 136 additions and 89 deletions

View File

@@ -1,5 +1,6 @@
#pragma once
#include "ir/op.hpp"
#include "prelude/linkedlist.hpp"
#include "prelude/string.hpp"
class CodeGenerator
@@ -9,7 +10,7 @@ public:
virtual ~CodeGenerator() {}
public:
virtual bool Generate(const IR::OpView* ops) = 0;
virtual bool Generate(const DoubleLinkedList<IR::Op*>* ops) = 0;
StringView GetOutput() { return output().view(); }
protected:

View File

@@ -13,30 +13,35 @@ public:
FasmX86_64Generator() = default;
public:
bool Generate(const IR::OpView* ops) override
bool Generate(const DoubleLinkedList<IR::Op*>* ops) override
{
m_ops = ops;
output().Extend("format ELF64\n");
output().Extend("section '.text' executable\n");
for (size_t i = 0; i < ops->size; ++i)
for (m_current = m_ops->Begin(); m_current != nullptr; node_next())
{
GenerateStatement(ops->data[i]);
GenerateStatement();
}
return true;
}
private:
void GenerateExtern(IR::ExternOp* extrn)
void GenerateExtern()
{
auto extrn = current<IR::ExternOp>();
// TODO: instead of __<symbol, use some random UID or hash string
// for safety and collision considerations
output().AppendFormat("extrn '%s' as __%s\n", extrn->symbol().c_str(), extrn->symbol().c_str());
output().AppendFormat("%s = PLT __%s\n", extrn->symbol().c_str(), extrn->symbol().c_str());
}
void GenerateFunction(IR::FnOp* fn)
void GenerateFunction()
{
auto fn = current<IR::FnOp>();
m_slots.clear();
m_stackCounter = 0;
@@ -45,17 +50,23 @@ private:
output().Extend(" push rbp\n");
output().Extend(" mov rbp, rsp\n");
for (auto cur = fn->body().ops().Begin(); cur != nullptr; cur = cur->next)
auto fnNode = node<IR::Op>();
auto ops = m_ops;
m_ops = &fn->body().ops();
for (m_current = m_ops->Begin(); m_current != nullptr; node_next())
{
GenerateStatement(cur->value);
GenerateStatement();
}
m_ops = ops;
m_current = fnNode;
output().Extend(" leave\n");
output().Extend(" ret\n");
}
void GenerateCall(IR::CallOp* call)
void GenerateCall()
{
auto call = current<IR::CallOp>();
// TODO: support stack spilled arguments
assert(call->args().size < 7 && "stack arguments not supported yet");
const char *regs[6] = {"edi", "esi", "edx", "ecx","e8", "e9"};
@@ -77,16 +88,29 @@ private:
output().AppendFormat(" mov %s [rbp-%d], eax\n", size, sp);
}
void GenerateAllocate(IR::AllocateOp* alloc)
void GenerateAllocate()
{
// TODO: support other types
assert(alloc->Type()->kind == IR::ValueHandle::Type::Kind::Int);
auto totalAllocSize = 0;
while (current<IR::Op>()->GetType() == IR::OpType::ALLOCATE)
{
auto alloc = current<IR::AllocateOp>();
// TODO: support other types
assert(alloc->Type()->kind == IR::ValueHandle::Type::Kind::Int);
// TODO: dynamic size + alignment
auto allocSize = 4;
totalAllocSize += allocSize;
m_stackCounter += allocSize;
m_slots.insert(std::make_pair(alloc->result()->GetId(), m_stackCounter));
EnsureSlot(alloc->result());
if (seek<IR::Op>() && seek<IR::Op>()->get()->GetType() == IR::OpType::ALLOCATE) node_next();
else break;
};
output().AppendFormat(" sub rsp, %d\n", totalAllocSize);
}
void GenerateStore(IR::StoreOp* store)
void GenerateStore()
{
auto store = current<IR::StoreOp>();
auto sp = EnsureSlot(store->dst());
// TODO: support other types
@@ -103,8 +127,9 @@ private:
}
}
void GenerateLoad(IR::LoadOp* load)
void GenerateLoad()
{
auto load = current<IR::LoadOp>();
auto sp = EnsureSlot(load->Ptr());
// TODO: support other types
@@ -115,8 +140,9 @@ private:
output().AppendFormat(" mov dword [rbp-%d], eax\n", sp);
}
void GenerateMath(IR::MathOp* math)
void GenerateMath()
{
auto math = current<IR::MathOp>();
StringBuilder sb;
switch(math->GetType())
@@ -157,30 +183,31 @@ private:
output().AppendFormat(" mov %s [rbp-%d], eax\n", size, sp);
}
void GenerateStatement(IR::Op* op)
void GenerateStatement()
{
switch(op->GetType())
{
case IR::OpType::EXTERN:
return GenerateExtern(reinterpret_cast<IR::ExternOp*>(op));
case IR::OpType::FN:
return GenerateFunction(reinterpret_cast<IR::FnOp*>(op));
case IR::OpType::CALL:
return GenerateCall(reinterpret_cast<IR::CallOp*>(op));
case IR::OpType::ALLOCATE:
return GenerateAllocate(reinterpret_cast<IR::AllocateOp*>(op));
case IR::OpType::STORE:
return GenerateStore(reinterpret_cast<IR::StoreOp*>(op));
case IR::OpType::LOAD:
return GenerateLoad(reinterpret_cast<IR::LoadOp*>(op));
case IR::OpType::ADD:
case IR::OpType::SUB:
case IR::OpType::MUL:
case IR::OpType::DIV:
return GenerateMath(reinterpret_cast<IR::MathOp*>(op));
// TODO:
default: output().AppendFormat(" ; %d not implemented\n", op->GetType());
}
auto op = current<IR::Op>();
switch(op->GetType())
{
case IR::OpType::EXTERN:
return GenerateExtern();
case IR::OpType::FN:
return GenerateFunction();
case IR::OpType::CALL:
return GenerateCall();
case IR::OpType::ALLOCATE:
return GenerateAllocate();
case IR::OpType::STORE:
return GenerateStore();
case IR::OpType::LOAD:
return GenerateLoad();
case IR::OpType::ADD:
case IR::OpType::SUB:
case IR::OpType::MUL:
case IR::OpType::DIV:
return GenerateMath();
// TODO:
default: output().AppendFormat(" ; %d not implemented\n", op->GetType());
}
}
private:
@@ -202,6 +229,19 @@ private:
}
private:
template<typename T>
ListNode<T*>* node() { return reinterpret_cast<ListNode<T*>*>(m_current); }
template<typename T>
ListNode<T*>* seek() { return m_current && m_current->next ? reinterpret_cast<ListNode<T*>*>(m_current->next) : nullptr; }
void node_next() { assert(m_current); m_current = m_current->next; }
template<typename T>
T* current() const { return reinterpret_cast<T*>(m_current->value); }
private:
const DoubleLinkedList<IR::Op*>* m_ops;
ListNode<IR::Op*>* m_current;
std::unordered_map<uint32_t, uint32_t> m_slots;
uint32_t m_stackCounter = 0;
};

View File

@@ -10,8 +10,8 @@ using BlockID = unsigned int;
class Block
{
public:
Block(BlockID id, const OpView& ops)
: m_id(id), m_ops(DoubleLinkedList<Op*>::FromView(ops)) {}
Block(BlockID id, DoubleLinkedList<Op*> ops)
: m_id(id), m_ops(ops) {}
public:
DoubleLinkedList<Op*>& ops() { return m_ops; }
public:

View File

@@ -2,6 +2,7 @@
#include <unordered_map>
#include "parser/nodes.hpp"
#include "prelude/error.hpp"
#include "prelude/linkedlist.hpp"
#include "prelude/string.hpp"
#include "ir/value.hpp"
#include "ir/ops.hpp"
@@ -13,7 +14,7 @@ namespace IR
{
public:
IRBuilder(const StringView &filename, const Node *root)
: m_root(root), m_ops(new OpBuilder()), m_filename(filename) {}
: m_root(root), m_ops(new DoubleLinkedList<Op*>), m_filename(filename) {}
public:
// TODO: support other literals
@@ -28,8 +29,8 @@ namespace IR
auto value = ParseExpression(varDecl->value());
// TODO: gather type information from var decl signature, aka local <int> v = 0;
auto dst = AllocateNamed<Pointer>(value->GetType());
m_ops->Push(new AllocateOp(dst, value->GetType()));
m_ops->Push(new StoreOp(value, reinterpret_cast<Pointer *>(dst)));
m_ops->Append(m_ops->New(new AllocateOp(dst, value->GetType())));
m_ops->Append(m_ops->New(new StoreOp(value, reinterpret_cast<Pointer *>(dst))));
m_locals.insert(std::make_pair(varDecl->name(), reinterpret_cast<Pointer *>(dst)));
}
@@ -42,22 +43,21 @@ namespace IR
assert(false);
}
auto dst = AllocateNamed<Instruction>(m_locals[var->name()]->GetValueType());
m_ops->Push(new LoadOp(dst, m_locals[var->name()]));
m_ops->Append(m_ops->New(new LoadOp(dst, m_locals[var->name()])));
return reinterpret_cast<ValueHandle *>(dst);
}
ValueHandle *ParseFnCall(const FnCallNode *fn)
{
// TODO: support multiple args
auto argRegs = ValueBuilder();
if (fn->arg() != nullptr)
auto args = ValueBuilder();
for (size_t i = 0; i < fn->args().size; ++i)
{
auto arg = ParseExpression(fn->arg());
argRegs.Push(arg);
auto arg = ParseExpression(fn->args().data[i]);
args.Push(arg);
}
// TODO: gather return type of the function
auto dst = AllocateNamed<Instruction>(new ValueHandle::Type {ValueHandle::Type::Kind::Int});
m_ops->Push(new CallOp(dst, fn->name(), argRegs.view()));
m_ops->Append(m_ops->New(new CallOp(dst, fn->name(), args.view())));
return dst;
}
@@ -93,16 +93,16 @@ namespace IR
switch (expr->op())
{
case ExpressionNode::Operator::Plus:
m_ops->Push(new MathOp(dst, lhs, rhs, OpType::ADD));
m_ops->Append(m_ops->New(new MathOp(dst, lhs, rhs, OpType::ADD)));
break;
case ExpressionNode::Operator::Multiply:
m_ops->Push(new MathOp(dst, lhs, rhs, OpType::MUL));
m_ops->Append(m_ops->New(new MathOp(dst, lhs, rhs, OpType::MUL)));
break;
case ExpressionNode::Operator::Minus:
m_ops->Push(new MathOp(dst, lhs, rhs, OpType::SUB));
m_ops->Append(m_ops->New(new MathOp(dst, lhs, rhs, OpType::SUB)));
break;
case ExpressionNode::Operator::Divide:
m_ops->Push(new MathOp(dst, lhs, rhs, OpType::DIV));
m_ops->Append(m_ops->New(new MathOp(dst, lhs, rhs, OpType::DIV)));
break;
default:
assert(0 && "unreachable");
@@ -131,12 +131,12 @@ namespace IR
}
}
auto ops = EndBlock();
auto block = Block(m_block_counter++, std::move(ops->view()));
auto block = Block(m_block_counter++, *ops);
operator delete(ops);
return block;
}
OpView Build()
DoubleLinkedList<Op*>* Build()
{
assert(m_root->GetType() == NodeType::Program && "root should be a program");
auto program = reinterpret_cast<const ProgramNode *>(m_root);
@@ -144,30 +144,30 @@ namespace IR
// Externs
for (auto &extrn : program->externs())
{
m_ops->Push(new ExternOp(extrn->symbol()));
m_ops->Append(m_ops->New(new ExternOp(extrn->symbol())));
}
// Functions
for (auto &fn : program->funcs())
{
auto block = ParseBlock(fn->body());
m_ops->Push(new FnOp(fn->name(), fn->params(), std::move(block)));
m_ops->Append(m_ops->New(new FnOp(fn->name(), fn->params(), std::move(block))));
}
return OpView(m_ops->data, m_ops->size);
return m_ops;
}
public:
OpView ops() const { return OpView(m_ops->data, m_ops->size); }
const DoubleLinkedList<Op*>* ops() const { return m_ops; }
private:
void StartBlock()
{
m_containers.Push(m_ops);
m_ops = new OpBuilder();
m_ops = new DoubleLinkedList<Op*>;
}
OpBuilder *EndBlock()
DoubleLinkedList<Op*> *EndBlock()
{
assert(m_containers.size > 0 && "containers stack is empty");
auto current = m_ops;
@@ -193,13 +193,13 @@ namespace IR
const Node *m_root = nullptr;
StringView m_filename;
OpBuilder *m_ops = nullptr;
DoubleLinkedList<Op*> *m_ops = nullptr;
unsigned int m_value_counter = 0;
unsigned int m_block_counter = 0;
std::unordered_map<StringView, Pointer *> m_locals;
Builder<OpBuilder *> m_containers;
Builder<DoubleLinkedList<Op*>*> m_containers;
};
} // namespace IR

View File

@@ -45,14 +45,15 @@ public:
// m_lexer->NextExpect(TokenType::Id);
// char* name = strdup(m_lexer->token().string);
m_lexer->NextExpect('(');
Node* arg = nullptr;
// TODO: support multiple arguments
if (m_lexer->seek_token()->token != ')')
Builder<Node*> args;
while (m_lexer->seek_token()->token != ')')
{
arg = ParseExpression();
auto arg = ParseExpression();
args.Push(arg);
if (m_lexer->seek_token()->token == ',') assert(m_lexer->NextToken());
}
m_lexer->NextExpect(')');
return new FnCallNode(name, arg);
return new FnCallNode(name, std::move(args.view()));
}
Node* ParseFactor()

View File

@@ -24,7 +24,7 @@ class Node
public:
virtual NodeType GetType() const = 0;
virtual ~Node() {}
};
};
class ExpressionNode : public Node
{
@@ -144,20 +144,17 @@ class FnCallNode : public Node
{
public:
// TODO: support multiple arguments
FnCallNode(const StringView& name, Node* arg)
: m_name(name), m_arg(arg) {}
~FnCallNode() override {
delete m_arg;
}
FnCallNode(const StringView& name, View<Node*>&& arg)
: m_name(name), m_args(arg) {}
~FnCallNode() override = default;
NODE_TYPE(FnCall)
public:
const StringView& name() const { return m_name; }
// TODO: support multiple args
const Node* arg() const { return m_arg; }
const View<Node*>& args() const { return m_args; }
private:
StringView m_name;
Node* m_arg;
View<Node*> m_args;
};
class VariableNode : public Node
@@ -214,4 +211,4 @@ public:
private:
std::vector<FnDeclNode*> m_funcs;
std::vector<ExternNode*> m_externs;
};
};

View File

@@ -8,12 +8,16 @@ struct ListNode
T value;
ListNode* prev = nullptr;
ListNode* next = nullptr;
public:
T& get() noexcept { return value; }
const T& get() const noexcept { return value; }
};
template<typename T>
class DoubleLinkedList {
public:
DoubleLinkedList() = default;
~DoubleLinkedList() = default;
public:
static DoubleLinkedList<T> FromView(const View<T> &view)
@@ -26,7 +30,6 @@ public:
return list;
}
public:
View<T> ToView()
{
Builder<T> b;
@@ -37,6 +40,9 @@ public:
return b.view();
}
public:
ListNode<T>* New(T value) const { return new ListNode<T>(value); }
public:
void Append(ListNode<T>* node)
{