Branch checker

This commit is contained in:
nub31
2025-01-26 21:52:54 +01:00
parent 6e6a1a830e
commit 738aa92da5
9 changed files with 108 additions and 40 deletions

View File

@@ -0,0 +1,32 @@
using Nub.Lang.Parsing;
namespace Nub.Lang.Branching;
public class BranchChecker
{
private readonly IReadOnlyCollection<DefinitionNode> _definitions;
public BranchChecker(IReadOnlyCollection<DefinitionNode> definitions)
{
_definitions = definitions;
}
public void Check()
{
foreach (var funcDefinition in _definitions.OfType<FuncDefinitionNode>())
{
if (funcDefinition.ReturnType.HasValue)
{
CheckBlock(funcDefinition.Body);
}
}
}
private void CheckBlock(BlockNode block)
{
if (!block.Statements.Any(s => s is ReturnNode))
{
throw new Exception("Block must contain a return statement");
}
}
}

View File

@@ -53,7 +53,7 @@ public class Generator
_builder.AppendLine(); _builder.AppendLine();
_builder.AppendLine($" ; Call entrypoint {Entrypoint}"); _builder.AppendLine($" ; Call entrypoint {Entrypoint}");
_builder.AppendLine($" call {main.Label}"); _builder.AppendLine($" call {main.StartLabel}");
_builder.AppendLine(); _builder.AppendLine();
_builder.AppendLine(" ; Exit with status code 0"); _builder.AppendLine(" ; Exit with status code 0");
@@ -96,20 +96,20 @@ public class Generator
return _builder.ToString(); return _builder.ToString();
} }
private void GenerateFuncDefinition(FuncDefinitionNode funcDefinition) private void GenerateFuncDefinition(FuncDefinitionNode node)
{ {
var symbol = _symbolTable.ResolveFunc(funcDefinition.Name, funcDefinition.Parameters.Select(p => p.Type).ToList()); var func = _symbolTable.ResolveFunc(node.Name, node.Parameters.Select(p => p.Type).ToList());
_builder.AppendLine($"; {funcDefinition.ToString()}"); _builder.AppendLine($"; {node.ToString()}");
_builder.AppendLine($"{symbol.Label}:"); _builder.AppendLine($"{func.StartLabel}:");
_builder.AppendLine(" push rbp"); _builder.AppendLine(" push rbp");
_builder.AppendLine(" mov rbp, rsp"); _builder.AppendLine(" mov rbp, rsp");
_builder.AppendLine($" sub rsp, {symbol.StackAllocation}"); _builder.AppendLine($" sub rsp, {func.StackAllocation}");
string[] registers = ["rdi", "rsi", "rdx", "rcx", "r8", "r9"]; string[] registers = ["rdi", "rsi", "rdx", "rcx", "r8", "r9"];
for (var i = 0; i < symbol.Parameters.Count; i++) for (var i = 0; i < func.Parameters.Count; i++)
{ {
var parameter = symbol.ResolveLocalVariable(symbol.Parameters.ElementAt(i).Name); var parameter = func.ResolveLocalVariable(func.Parameters.ElementAt(i).Name);
if (i < registers.Length) if (i < registers.Length)
{ {
_builder.AppendLine($" mov [rbp - {parameter.Offset}], {registers[i]}"); _builder.AppendLine($" mov [rbp - {parameter.Offset}], {registers[i]}");
@@ -122,7 +122,9 @@ public class Generator
} }
} }
GenerateBlock(funcDefinition.Body, symbol); GenerateBlock(node.Body, func);
_builder.AppendLine($"{func.EndLabel}:");
_builder.AppendLine(" mov rsp, rbp"); _builder.AppendLine(" mov rsp, rbp");
_builder.AppendLine(" pop rbp"); _builder.AppendLine(" pop rbp");
_builder.AppendLine(" ret"); _builder.AppendLine(" ret");
@@ -143,6 +145,13 @@ public class Generator
case FuncCallStatementNode funcCallStatement: case FuncCallStatementNode funcCallStatement:
GenerateFuncCall(funcCallStatement.FuncCall, func); GenerateFuncCall(funcCallStatement.FuncCall, func);
break; break;
case ReturnNode @return:
if (@return.Value.HasValue)
{
GenerateExpression(@return.Value.Value, func);
}
_builder.AppendLine($" jmp {func.EndLabel}");
break;
case SyscallStatementNode syscallStatement: case SyscallStatementNode syscallStatement:
GenerateSyscall(syscallStatement.Syscall, func); GenerateSyscall(syscallStatement.Syscall, func);
break; break;
@@ -159,7 +168,7 @@ public class Generator
switch (expression) switch (expression)
{ {
case FuncCallExpressionNode funcCallExpression: case FuncCallExpressionNode funcCallExpression:
throw new NotImplementedException(); GenerateFuncCall(funcCallExpression.FuncCall, func);
break; break;
case IdentifierNode identifier: case IdentifierNode identifier:
GenerateIdentifier(identifier, func); GenerateIdentifier(identifier, func);
@@ -280,7 +289,7 @@ public class Generator
_builder.AppendLine($" pop {registers[i]}"); _builder.AppendLine($" pop {registers[i]}");
} }
_builder.AppendLine($" call {symbol.Label}"); _builder.AppendLine($" call {symbol.StartLabel}");
if (stackParameters != 0) if (stackParameters != 0)
{ {
_builder.AppendLine($" add rsp, {stackParameters}"); _builder.AppendLine($" add rsp, {stackParameters}");

View File

@@ -21,9 +21,10 @@ public class SymbolTable
public void DefineFunc(FuncDefinitionNode funcDefinition) public void DefineFunc(FuncDefinitionNode funcDefinition)
{ {
var label = $"func{++_labelIndex}"; var startLabel = $"func{++_labelIndex}";
var endLabel = $"endfunc{_labelIndex}";
var localVariables = ResolveFunctionVariables(funcDefinition); var localVariables = ResolveFunctionVariables(funcDefinition);
_functions.Add(new Func(label, funcDefinition.Name, funcDefinition.Parameters, funcDefinition.ReturnType, _globalVariables.Concat<Variable>(localVariables.Variables).ToList(), localVariables.StackSize)); _functions.Add(new Func(startLabel, endLabel, funcDefinition.Name, funcDefinition.Parameters, funcDefinition.ReturnType, _globalVariables.Concat<Variable>(localVariables.Variables).ToList(), localVariables.StackSize));
} }
private (int StackSize, List<LocalVariable> Variables) ResolveFunctionVariables(FuncDefinitionNode funcDefinition) private (int StackSize, List<LocalVariable> Variables) ResolveFunctionVariables(FuncDefinitionNode funcDefinition)
@@ -88,9 +89,10 @@ public class GlobalVariable(string name, Type type, string identifier) : Variabl
public string Identifier { get; } = identifier; public string Identifier { get; } = identifier;
} }
public class Func(string label, string name, IReadOnlyCollection<FuncParameter> parameters, Optional<Type> returnType, IReadOnlyCollection<Variable> variables, int stackAllocation) public class Func(string startLabel, string endLabel, string name, IReadOnlyCollection<FuncParameter> parameters, Optional<Type> returnType, IReadOnlyCollection<Variable> variables, int stackAllocation)
{ {
public string Label { get; } = label; public string StartLabel { get; } = startLabel;
public string EndLabel { get; } = endLabel;
public string Name { get; } = name; public string Name { get; } = name;
public IReadOnlyCollection<FuncParameter> Parameters { get; } = parameters; public IReadOnlyCollection<FuncParameter> Parameters { get; } = parameters;
public Optional<Type> ReturnType { get; } = returnType; public Optional<Type> ReturnType { get; } = returnType;

View File

@@ -4,32 +4,11 @@ let STD_OUT = 1;
let STD_ERR = 2; let STD_ERR = 2;
func main() { func main() {
test write(test());
(
"a\n",
"b\n",
"c\n",
"d\n",
"e\n",
"f\n",
"g\n",
"h\n",
"i\n",
"j\n",
);
} }
func test(a: String, b: String, c: String, d: String, e: String, f: String, g: String, h: String, i: String, j: String) { func test(): String {
write(a); return "test";
write(b);
write(c);
write(d);
write(e);
write(f);
write(g);
write(h);
write(i);
write(j);
} }
func write(msg: String) { func write(msg: String) {

View File

@@ -3,4 +3,6 @@
public class FuncCallStatementNode(FuncCall funcCall) : StatementNode public class FuncCallStatementNode(FuncCall funcCall) : StatementNode
{ {
public FuncCall FuncCall { get; } = funcCall; public FuncCall FuncCall { get; } = funcCall;
public override string ToString() => FuncCall.ToString();
} }

View File

@@ -114,6 +114,27 @@ public class Parser
throw new Exception($"Unexpected symbol {symbol.Symbol}"); throw new Exception($"Unexpected symbol {symbol.Symbol}");
} }
} }
case SymbolToken symbol:
{
switch (symbol.Symbol)
{
case Symbol.Return:
{
var value = Optional<ExpressionNode>.Empty();
if (!TryExpectSymbol(Symbol.Semicolon))
{
value = ParseExpression();
ExpectSymbol(Symbol.Semicolon);
}
return new ReturnNode(value);
}
default:
{
throw new Exception($"Unexpected symbol {symbol.Symbol}");
}
}
}
default: default:
throw new Exception($"Unexpected token type {token.GetType().Name}"); throw new Exception($"Unexpected token type {token.GetType().Name}");
} }

View File

@@ -0,0 +1,8 @@
using Nub.Lib;
namespace Nub.Lang.Parsing;
public class ReturnNode(Optional<ExpressionNode> value) : StatementNode
{
public Optional<ExpressionNode> Value { get; } = value;
}

View File

@@ -1,4 +1,5 @@
using Nub.Lang.Generation; using Nub.Lang.Branching;
using Nub.Lang.Generation;
using Nub.Lang.Lexing; using Nub.Lang.Lexing;
using Nub.Lang.Parsing; using Nub.Lang.Parsing;
using Nub.Lang.Typing; using Nub.Lang.Typing;
@@ -14,6 +15,9 @@ var definitions = parser.Parse();
var typer = new ExpressionTyper(definitions); var typer = new ExpressionTyper(definitions);
typer.Populate(); typer.Populate();
var branchChecker = new BranchChecker(definitions);
branchChecker.Check();
var generator = new Generator(definitions); var generator = new Generator(definitions);
var asm = generator.Generate(); var asm = generator.Generate();

View File

@@ -59,6 +59,9 @@ public class ExpressionTyper
case FuncCallStatementNode funcCall: case FuncCallStatementNode funcCall:
PopulateFuncCallStatement(funcCall); PopulateFuncCallStatement(funcCall);
break; break;
case ReturnNode returnNode:
PopulateReturn(returnNode);
break;
case SyscallStatementNode syscall: case SyscallStatementNode syscall:
PopulateSyscallStatement(syscall); PopulateSyscallStatement(syscall);
break; break;
@@ -86,6 +89,14 @@ public class ExpressionTyper
} }
} }
private void PopulateReturn(ReturnNode returnNode)
{
if (returnNode.Value.HasValue)
{
PopulateExpression(returnNode.Value.Value);
}
}
private void PopulateVariableAssignment(VariableAssignmentNode variableAssignment) private void PopulateVariableAssignment(VariableAssignmentNode variableAssignment)
{ {
PopulateExpression(variableAssignment.Value); PopulateExpression(variableAssignment.Value);