Branch checker
This commit is contained in:
32
Nub.Lang/Nub.Lang/Branching/BranchChecker.cs
Normal file
32
Nub.Lang/Nub.Lang/Branching/BranchChecker.cs
Normal file
@@ -0,0 +1,32 @@
|
||||
using Nub.Lang.Parsing;
|
||||
|
||||
namespace Nub.Lang.Branching;
|
||||
|
||||
public class BranchChecker
|
||||
{
|
||||
private readonly IReadOnlyCollection<DefinitionNode> _definitions;
|
||||
|
||||
public BranchChecker(IReadOnlyCollection<DefinitionNode> definitions)
|
||||
{
|
||||
_definitions = definitions;
|
||||
}
|
||||
|
||||
public void Check()
|
||||
{
|
||||
foreach (var funcDefinition in _definitions.OfType<FuncDefinitionNode>())
|
||||
{
|
||||
if (funcDefinition.ReturnType.HasValue)
|
||||
{
|
||||
CheckBlock(funcDefinition.Body);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private void CheckBlock(BlockNode block)
|
||||
{
|
||||
if (!block.Statements.Any(s => s is ReturnNode))
|
||||
{
|
||||
throw new Exception("Block must contain a return statement");
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -53,7 +53,7 @@ public class Generator
|
||||
|
||||
_builder.AppendLine();
|
||||
_builder.AppendLine($" ; Call entrypoint {Entrypoint}");
|
||||
_builder.AppendLine($" call {main.Label}");
|
||||
_builder.AppendLine($" call {main.StartLabel}");
|
||||
|
||||
_builder.AppendLine();
|
||||
_builder.AppendLine(" ; Exit with status code 0");
|
||||
@@ -96,20 +96,20 @@ public class Generator
|
||||
return _builder.ToString();
|
||||
}
|
||||
|
||||
private void GenerateFuncDefinition(FuncDefinitionNode funcDefinition)
|
||||
private void GenerateFuncDefinition(FuncDefinitionNode node)
|
||||
{
|
||||
var symbol = _symbolTable.ResolveFunc(funcDefinition.Name, funcDefinition.Parameters.Select(p => p.Type).ToList());
|
||||
_builder.AppendLine($"; {funcDefinition.ToString()}");
|
||||
_builder.AppendLine($"{symbol.Label}:");
|
||||
var func = _symbolTable.ResolveFunc(node.Name, node.Parameters.Select(p => p.Type).ToList());
|
||||
_builder.AppendLine($"; {node.ToString()}");
|
||||
_builder.AppendLine($"{func.StartLabel}:");
|
||||
_builder.AppendLine(" push rbp");
|
||||
_builder.AppendLine(" mov rbp, rsp");
|
||||
_builder.AppendLine($" sub rsp, {symbol.StackAllocation}");
|
||||
_builder.AppendLine($" sub rsp, {func.StackAllocation}");
|
||||
|
||||
string[] registers = ["rdi", "rsi", "rdx", "rcx", "r8", "r9"];
|
||||
|
||||
for (var i = 0; i < symbol.Parameters.Count; i++)
|
||||
for (var i = 0; i < func.Parameters.Count; i++)
|
||||
{
|
||||
var parameter = symbol.ResolveLocalVariable(symbol.Parameters.ElementAt(i).Name);
|
||||
var parameter = func.ResolveLocalVariable(func.Parameters.ElementAt(i).Name);
|
||||
if (i < registers.Length)
|
||||
{
|
||||
_builder.AppendLine($" mov [rbp - {parameter.Offset}], {registers[i]}");
|
||||
@@ -122,7 +122,9 @@ public class Generator
|
||||
}
|
||||
}
|
||||
|
||||
GenerateBlock(funcDefinition.Body, symbol);
|
||||
GenerateBlock(node.Body, func);
|
||||
|
||||
_builder.AppendLine($"{func.EndLabel}:");
|
||||
_builder.AppendLine(" mov rsp, rbp");
|
||||
_builder.AppendLine(" pop rbp");
|
||||
_builder.AppendLine(" ret");
|
||||
@@ -143,6 +145,13 @@ public class Generator
|
||||
case FuncCallStatementNode funcCallStatement:
|
||||
GenerateFuncCall(funcCallStatement.FuncCall, func);
|
||||
break;
|
||||
case ReturnNode @return:
|
||||
if (@return.Value.HasValue)
|
||||
{
|
||||
GenerateExpression(@return.Value.Value, func);
|
||||
}
|
||||
_builder.AppendLine($" jmp {func.EndLabel}");
|
||||
break;
|
||||
case SyscallStatementNode syscallStatement:
|
||||
GenerateSyscall(syscallStatement.Syscall, func);
|
||||
break;
|
||||
@@ -159,7 +168,7 @@ public class Generator
|
||||
switch (expression)
|
||||
{
|
||||
case FuncCallExpressionNode funcCallExpression:
|
||||
throw new NotImplementedException();
|
||||
GenerateFuncCall(funcCallExpression.FuncCall, func);
|
||||
break;
|
||||
case IdentifierNode identifier:
|
||||
GenerateIdentifier(identifier, func);
|
||||
@@ -280,7 +289,7 @@ public class Generator
|
||||
_builder.AppendLine($" pop {registers[i]}");
|
||||
}
|
||||
|
||||
_builder.AppendLine($" call {symbol.Label}");
|
||||
_builder.AppendLine($" call {symbol.StartLabel}");
|
||||
if (stackParameters != 0)
|
||||
{
|
||||
_builder.AppendLine($" add rsp, {stackParameters}");
|
||||
|
||||
@@ -21,9 +21,10 @@ public class SymbolTable
|
||||
|
||||
public void DefineFunc(FuncDefinitionNode funcDefinition)
|
||||
{
|
||||
var label = $"func{++_labelIndex}";
|
||||
var startLabel = $"func{++_labelIndex}";
|
||||
var endLabel = $"endfunc{_labelIndex}";
|
||||
var localVariables = ResolveFunctionVariables(funcDefinition);
|
||||
_functions.Add(new Func(label, funcDefinition.Name, funcDefinition.Parameters, funcDefinition.ReturnType, _globalVariables.Concat<Variable>(localVariables.Variables).ToList(), localVariables.StackSize));
|
||||
_functions.Add(new Func(startLabel, endLabel, funcDefinition.Name, funcDefinition.Parameters, funcDefinition.ReturnType, _globalVariables.Concat<Variable>(localVariables.Variables).ToList(), localVariables.StackSize));
|
||||
}
|
||||
|
||||
private (int StackSize, List<LocalVariable> Variables) ResolveFunctionVariables(FuncDefinitionNode funcDefinition)
|
||||
@@ -88,9 +89,10 @@ public class GlobalVariable(string name, Type type, string identifier) : Variabl
|
||||
public string Identifier { get; } = identifier;
|
||||
}
|
||||
|
||||
public class Func(string label, string name, IReadOnlyCollection<FuncParameter> parameters, Optional<Type> returnType, IReadOnlyCollection<Variable> variables, int stackAllocation)
|
||||
public class Func(string startLabel, string endLabel, string name, IReadOnlyCollection<FuncParameter> parameters, Optional<Type> returnType, IReadOnlyCollection<Variable> variables, int stackAllocation)
|
||||
{
|
||||
public string Label { get; } = label;
|
||||
public string StartLabel { get; } = startLabel;
|
||||
public string EndLabel { get; } = endLabel;
|
||||
public string Name { get; } = name;
|
||||
public IReadOnlyCollection<FuncParameter> Parameters { get; } = parameters;
|
||||
public Optional<Type> ReturnType { get; } = returnType;
|
||||
|
||||
@@ -4,32 +4,11 @@ let STD_OUT = 1;
|
||||
let STD_ERR = 2;
|
||||
|
||||
func main() {
|
||||
test
|
||||
(
|
||||
"a\n",
|
||||
"b\n",
|
||||
"c\n",
|
||||
"d\n",
|
||||
"e\n",
|
||||
"f\n",
|
||||
"g\n",
|
||||
"h\n",
|
||||
"i\n",
|
||||
"j\n",
|
||||
);
|
||||
write(test());
|
||||
}
|
||||
|
||||
func test(a: String, b: String, c: String, d: String, e: String, f: String, g: String, h: String, i: String, j: String) {
|
||||
write(a);
|
||||
write(b);
|
||||
write(c);
|
||||
write(d);
|
||||
write(e);
|
||||
write(f);
|
||||
write(g);
|
||||
write(h);
|
||||
write(i);
|
||||
write(j);
|
||||
func test(): String {
|
||||
return "test";
|
||||
}
|
||||
|
||||
func write(msg: String) {
|
||||
|
||||
@@ -3,4 +3,6 @@
|
||||
public class FuncCallStatementNode(FuncCall funcCall) : StatementNode
|
||||
{
|
||||
public FuncCall FuncCall { get; } = funcCall;
|
||||
|
||||
public override string ToString() => FuncCall.ToString();
|
||||
}
|
||||
@@ -114,6 +114,27 @@ public class Parser
|
||||
throw new Exception($"Unexpected symbol {symbol.Symbol}");
|
||||
}
|
||||
}
|
||||
case SymbolToken symbol:
|
||||
{
|
||||
switch (symbol.Symbol)
|
||||
{
|
||||
case Symbol.Return:
|
||||
{
|
||||
var value = Optional<ExpressionNode>.Empty();
|
||||
if (!TryExpectSymbol(Symbol.Semicolon))
|
||||
{
|
||||
value = ParseExpression();
|
||||
ExpectSymbol(Symbol.Semicolon);
|
||||
}
|
||||
|
||||
return new ReturnNode(value);
|
||||
}
|
||||
default:
|
||||
{
|
||||
throw new Exception($"Unexpected symbol {symbol.Symbol}");
|
||||
}
|
||||
}
|
||||
}
|
||||
default:
|
||||
throw new Exception($"Unexpected token type {token.GetType().Name}");
|
||||
}
|
||||
|
||||
8
Nub.Lang/Nub.Lang/Parsing/ReturnNode.cs
Normal file
8
Nub.Lang/Nub.Lang/Parsing/ReturnNode.cs
Normal file
@@ -0,0 +1,8 @@
|
||||
using Nub.Lib;
|
||||
|
||||
namespace Nub.Lang.Parsing;
|
||||
|
||||
public class ReturnNode(Optional<ExpressionNode> value) : StatementNode
|
||||
{
|
||||
public Optional<ExpressionNode> Value { get; } = value;
|
||||
}
|
||||
@@ -1,4 +1,5 @@
|
||||
using Nub.Lang.Generation;
|
||||
using Nub.Lang.Branching;
|
||||
using Nub.Lang.Generation;
|
||||
using Nub.Lang.Lexing;
|
||||
using Nub.Lang.Parsing;
|
||||
using Nub.Lang.Typing;
|
||||
@@ -14,6 +15,9 @@ var definitions = parser.Parse();
|
||||
var typer = new ExpressionTyper(definitions);
|
||||
typer.Populate();
|
||||
|
||||
var branchChecker = new BranchChecker(definitions);
|
||||
branchChecker.Check();
|
||||
|
||||
var generator = new Generator(definitions);
|
||||
var asm = generator.Generate();
|
||||
|
||||
|
||||
@@ -59,6 +59,9 @@ public class ExpressionTyper
|
||||
case FuncCallStatementNode funcCall:
|
||||
PopulateFuncCallStatement(funcCall);
|
||||
break;
|
||||
case ReturnNode returnNode:
|
||||
PopulateReturn(returnNode);
|
||||
break;
|
||||
case SyscallStatementNode syscall:
|
||||
PopulateSyscallStatement(syscall);
|
||||
break;
|
||||
@@ -86,6 +89,14 @@ public class ExpressionTyper
|
||||
}
|
||||
}
|
||||
|
||||
private void PopulateReturn(ReturnNode returnNode)
|
||||
{
|
||||
if (returnNode.Value.HasValue)
|
||||
{
|
||||
PopulateExpression(returnNode.Value.Value);
|
||||
}
|
||||
}
|
||||
|
||||
private void PopulateVariableAssignment(VariableAssignmentNode variableAssignment)
|
||||
{
|
||||
PopulateExpression(variableAssignment.Value);
|
||||
|
||||
Reference in New Issue
Block a user