Branch checker

This commit is contained in:
nub31
2025-01-26 21:52:54 +01:00
parent 6e6a1a830e
commit 738aa92da5
9 changed files with 108 additions and 40 deletions

View File

@@ -0,0 +1,32 @@
using Nub.Lang.Parsing;
namespace Nub.Lang.Branching;
public class BranchChecker
{
private readonly IReadOnlyCollection<DefinitionNode> _definitions;
public BranchChecker(IReadOnlyCollection<DefinitionNode> definitions)
{
_definitions = definitions;
}
public void Check()
{
foreach (var funcDefinition in _definitions.OfType<FuncDefinitionNode>())
{
if (funcDefinition.ReturnType.HasValue)
{
CheckBlock(funcDefinition.Body);
}
}
}
private void CheckBlock(BlockNode block)
{
if (!block.Statements.Any(s => s is ReturnNode))
{
throw new Exception("Block must contain a return statement");
}
}
}

View File

@@ -53,7 +53,7 @@ public class Generator
_builder.AppendLine();
_builder.AppendLine($" ; Call entrypoint {Entrypoint}");
_builder.AppendLine($" call {main.Label}");
_builder.AppendLine($" call {main.StartLabel}");
_builder.AppendLine();
_builder.AppendLine(" ; Exit with status code 0");
@@ -96,20 +96,20 @@ public class Generator
return _builder.ToString();
}
private void GenerateFuncDefinition(FuncDefinitionNode funcDefinition)
private void GenerateFuncDefinition(FuncDefinitionNode node)
{
var symbol = _symbolTable.ResolveFunc(funcDefinition.Name, funcDefinition.Parameters.Select(p => p.Type).ToList());
_builder.AppendLine($"; {funcDefinition.ToString()}");
_builder.AppendLine($"{symbol.Label}:");
var func = _symbolTable.ResolveFunc(node.Name, node.Parameters.Select(p => p.Type).ToList());
_builder.AppendLine($"; {node.ToString()}");
_builder.AppendLine($"{func.StartLabel}:");
_builder.AppendLine(" push rbp");
_builder.AppendLine(" mov rbp, rsp");
_builder.AppendLine($" sub rsp, {symbol.StackAllocation}");
_builder.AppendLine($" sub rsp, {func.StackAllocation}");
string[] registers = ["rdi", "rsi", "rdx", "rcx", "r8", "r9"];
for (var i = 0; i < symbol.Parameters.Count; i++)
for (var i = 0; i < func.Parameters.Count; i++)
{
var parameter = symbol.ResolveLocalVariable(symbol.Parameters.ElementAt(i).Name);
var parameter = func.ResolveLocalVariable(func.Parameters.ElementAt(i).Name);
if (i < registers.Length)
{
_builder.AppendLine($" mov [rbp - {parameter.Offset}], {registers[i]}");
@@ -122,7 +122,9 @@ public class Generator
}
}
GenerateBlock(funcDefinition.Body, symbol);
GenerateBlock(node.Body, func);
_builder.AppendLine($"{func.EndLabel}:");
_builder.AppendLine(" mov rsp, rbp");
_builder.AppendLine(" pop rbp");
_builder.AppendLine(" ret");
@@ -143,6 +145,13 @@ public class Generator
case FuncCallStatementNode funcCallStatement:
GenerateFuncCall(funcCallStatement.FuncCall, func);
break;
case ReturnNode @return:
if (@return.Value.HasValue)
{
GenerateExpression(@return.Value.Value, func);
}
_builder.AppendLine($" jmp {func.EndLabel}");
break;
case SyscallStatementNode syscallStatement:
GenerateSyscall(syscallStatement.Syscall, func);
break;
@@ -159,7 +168,7 @@ public class Generator
switch (expression)
{
case FuncCallExpressionNode funcCallExpression:
throw new NotImplementedException();
GenerateFuncCall(funcCallExpression.FuncCall, func);
break;
case IdentifierNode identifier:
GenerateIdentifier(identifier, func);
@@ -280,7 +289,7 @@ public class Generator
_builder.AppendLine($" pop {registers[i]}");
}
_builder.AppendLine($" call {symbol.Label}");
_builder.AppendLine($" call {symbol.StartLabel}");
if (stackParameters != 0)
{
_builder.AppendLine($" add rsp, {stackParameters}");

View File

@@ -21,9 +21,10 @@ public class SymbolTable
public void DefineFunc(FuncDefinitionNode funcDefinition)
{
var label = $"func{++_labelIndex}";
var startLabel = $"func{++_labelIndex}";
var endLabel = $"endfunc{_labelIndex}";
var localVariables = ResolveFunctionVariables(funcDefinition);
_functions.Add(new Func(label, funcDefinition.Name, funcDefinition.Parameters, funcDefinition.ReturnType, _globalVariables.Concat<Variable>(localVariables.Variables).ToList(), localVariables.StackSize));
_functions.Add(new Func(startLabel, endLabel, funcDefinition.Name, funcDefinition.Parameters, funcDefinition.ReturnType, _globalVariables.Concat<Variable>(localVariables.Variables).ToList(), localVariables.StackSize));
}
private (int StackSize, List<LocalVariable> Variables) ResolveFunctionVariables(FuncDefinitionNode funcDefinition)
@@ -88,9 +89,10 @@ public class GlobalVariable(string name, Type type, string identifier) : Variabl
public string Identifier { get; } = identifier;
}
public class Func(string label, string name, IReadOnlyCollection<FuncParameter> parameters, Optional<Type> returnType, IReadOnlyCollection<Variable> variables, int stackAllocation)
public class Func(string startLabel, string endLabel, string name, IReadOnlyCollection<FuncParameter> parameters, Optional<Type> returnType, IReadOnlyCollection<Variable> variables, int stackAllocation)
{
public string Label { get; } = label;
public string StartLabel { get; } = startLabel;
public string EndLabel { get; } = endLabel;
public string Name { get; } = name;
public IReadOnlyCollection<FuncParameter> Parameters { get; } = parameters;
public Optional<Type> ReturnType { get; } = returnType;

View File

@@ -4,32 +4,11 @@ let STD_OUT = 1;
let STD_ERR = 2;
func main() {
test
(
"a\n",
"b\n",
"c\n",
"d\n",
"e\n",
"f\n",
"g\n",
"h\n",
"i\n",
"j\n",
);
write(test());
}
func test(a: String, b: String, c: String, d: String, e: String, f: String, g: String, h: String, i: String, j: String) {
write(a);
write(b);
write(c);
write(d);
write(e);
write(f);
write(g);
write(h);
write(i);
write(j);
func test(): String {
return "test";
}
func write(msg: String) {

View File

@@ -3,4 +3,6 @@
public class FuncCallStatementNode(FuncCall funcCall) : StatementNode
{
public FuncCall FuncCall { get; } = funcCall;
public override string ToString() => FuncCall.ToString();
}

View File

@@ -114,6 +114,27 @@ public class Parser
throw new Exception($"Unexpected symbol {symbol.Symbol}");
}
}
case SymbolToken symbol:
{
switch (symbol.Symbol)
{
case Symbol.Return:
{
var value = Optional<ExpressionNode>.Empty();
if (!TryExpectSymbol(Symbol.Semicolon))
{
value = ParseExpression();
ExpectSymbol(Symbol.Semicolon);
}
return new ReturnNode(value);
}
default:
{
throw new Exception($"Unexpected symbol {symbol.Symbol}");
}
}
}
default:
throw new Exception($"Unexpected token type {token.GetType().Name}");
}

View File

@@ -0,0 +1,8 @@
using Nub.Lib;
namespace Nub.Lang.Parsing;
public class ReturnNode(Optional<ExpressionNode> value) : StatementNode
{
public Optional<ExpressionNode> Value { get; } = value;
}

View File

@@ -1,4 +1,5 @@
using Nub.Lang.Generation;
using Nub.Lang.Branching;
using Nub.Lang.Generation;
using Nub.Lang.Lexing;
using Nub.Lang.Parsing;
using Nub.Lang.Typing;
@@ -14,6 +15,9 @@ var definitions = parser.Parse();
var typer = new ExpressionTyper(definitions);
typer.Populate();
var branchChecker = new BranchChecker(definitions);
branchChecker.Check();
var generator = new Generator(definitions);
var asm = generator.Generate();

View File

@@ -59,6 +59,9 @@ public class ExpressionTyper
case FuncCallStatementNode funcCall:
PopulateFuncCallStatement(funcCall);
break;
case ReturnNode returnNode:
PopulateReturn(returnNode);
break;
case SyscallStatementNode syscall:
PopulateSyscallStatement(syscall);
break;
@@ -86,6 +89,14 @@ public class ExpressionTyper
}
}
private void PopulateReturn(ReturnNode returnNode)
{
if (returnNode.Value.HasValue)
{
PopulateExpression(returnNode.Value.Value);
}
}
private void PopulateVariableAssignment(VariableAssignmentNode variableAssignment)
{
PopulateExpression(variableAssignment.Value);