diff --git a/Nub.Lang/Nub.Lang/Branching/BranchChecker.cs b/Nub.Lang/Nub.Lang/Branching/BranchChecker.cs new file mode 100644 index 0000000..8954360 --- /dev/null +++ b/Nub.Lang/Nub.Lang/Branching/BranchChecker.cs @@ -0,0 +1,32 @@ +using Nub.Lang.Parsing; + +namespace Nub.Lang.Branching; + +public class BranchChecker +{ + private readonly IReadOnlyCollection _definitions; + + public BranchChecker(IReadOnlyCollection definitions) + { + _definitions = definitions; + } + + public void Check() + { + foreach (var funcDefinition in _definitions.OfType()) + { + if (funcDefinition.ReturnType.HasValue) + { + CheckBlock(funcDefinition.Body); + } + } + } + + private void CheckBlock(BlockNode block) + { + if (!block.Statements.Any(s => s is ReturnNode)) + { + throw new Exception("Block must contain a return statement"); + } + } +} \ No newline at end of file diff --git a/Nub.Lang/Nub.Lang/Generation/Generator.cs b/Nub.Lang/Nub.Lang/Generation/Generator.cs index 471b797..6702c23 100644 --- a/Nub.Lang/Nub.Lang/Generation/Generator.cs +++ b/Nub.Lang/Nub.Lang/Generation/Generator.cs @@ -53,7 +53,7 @@ public class Generator _builder.AppendLine(); _builder.AppendLine($" ; Call entrypoint {Entrypoint}"); - _builder.AppendLine($" call {main.Label}"); + _builder.AppendLine($" call {main.StartLabel}"); _builder.AppendLine(); _builder.AppendLine(" ; Exit with status code 0"); @@ -96,20 +96,20 @@ public class Generator return _builder.ToString(); } - private void GenerateFuncDefinition(FuncDefinitionNode funcDefinition) + private void GenerateFuncDefinition(FuncDefinitionNode node) { - var symbol = _symbolTable.ResolveFunc(funcDefinition.Name, funcDefinition.Parameters.Select(p => p.Type).ToList()); - _builder.AppendLine($"; {funcDefinition.ToString()}"); - _builder.AppendLine($"{symbol.Label}:"); + var func = _symbolTable.ResolveFunc(node.Name, node.Parameters.Select(p => p.Type).ToList()); + _builder.AppendLine($"; {node.ToString()}"); + _builder.AppendLine($"{func.StartLabel}:"); _builder.AppendLine(" push rbp"); _builder.AppendLine(" mov rbp, rsp"); - _builder.AppendLine($" sub rsp, {symbol.StackAllocation}"); + _builder.AppendLine($" sub rsp, {func.StackAllocation}"); string[] registers = ["rdi", "rsi", "rdx", "rcx", "r8", "r9"]; - for (var i = 0; i < symbol.Parameters.Count; i++) + for (var i = 0; i < func.Parameters.Count; i++) { - var parameter = symbol.ResolveLocalVariable(symbol.Parameters.ElementAt(i).Name); + var parameter = func.ResolveLocalVariable(func.Parameters.ElementAt(i).Name); if (i < registers.Length) { _builder.AppendLine($" mov [rbp - {parameter.Offset}], {registers[i]}"); @@ -122,7 +122,9 @@ public class Generator } } - GenerateBlock(funcDefinition.Body, symbol); + GenerateBlock(node.Body, func); + + _builder.AppendLine($"{func.EndLabel}:"); _builder.AppendLine(" mov rsp, rbp"); _builder.AppendLine(" pop rbp"); _builder.AppendLine(" ret"); @@ -143,6 +145,13 @@ public class Generator case FuncCallStatementNode funcCallStatement: GenerateFuncCall(funcCallStatement.FuncCall, func); break; + case ReturnNode @return: + if (@return.Value.HasValue) + { + GenerateExpression(@return.Value.Value, func); + } + _builder.AppendLine($" jmp {func.EndLabel}"); + break; case SyscallStatementNode syscallStatement: GenerateSyscall(syscallStatement.Syscall, func); break; @@ -159,7 +168,7 @@ public class Generator switch (expression) { case FuncCallExpressionNode funcCallExpression: - throw new NotImplementedException(); + GenerateFuncCall(funcCallExpression.FuncCall, func); break; case IdentifierNode identifier: GenerateIdentifier(identifier, func); @@ -280,7 +289,7 @@ public class Generator _builder.AppendLine($" pop {registers[i]}"); } - _builder.AppendLine($" call {symbol.Label}"); + _builder.AppendLine($" call {symbol.StartLabel}"); if (stackParameters != 0) { _builder.AppendLine($" add rsp, {stackParameters}"); diff --git a/Nub.Lang/Nub.Lang/Generation/SymbolTable.cs b/Nub.Lang/Nub.Lang/Generation/SymbolTable.cs index 0af7b19..66688e6 100644 --- a/Nub.Lang/Nub.Lang/Generation/SymbolTable.cs +++ b/Nub.Lang/Nub.Lang/Generation/SymbolTable.cs @@ -21,9 +21,10 @@ public class SymbolTable public void DefineFunc(FuncDefinitionNode funcDefinition) { - var label = $"func{++_labelIndex}"; + var startLabel = $"func{++_labelIndex}"; + var endLabel = $"endfunc{_labelIndex}"; var localVariables = ResolveFunctionVariables(funcDefinition); - _functions.Add(new Func(label, funcDefinition.Name, funcDefinition.Parameters, funcDefinition.ReturnType, _globalVariables.Concat(localVariables.Variables).ToList(), localVariables.StackSize)); + _functions.Add(new Func(startLabel, endLabel, funcDefinition.Name, funcDefinition.Parameters, funcDefinition.ReturnType, _globalVariables.Concat(localVariables.Variables).ToList(), localVariables.StackSize)); } private (int StackSize, List Variables) ResolveFunctionVariables(FuncDefinitionNode funcDefinition) @@ -88,9 +89,10 @@ public class GlobalVariable(string name, Type type, string identifier) : Variabl public string Identifier { get; } = identifier; } -public class Func(string label, string name, IReadOnlyCollection parameters, Optional returnType, IReadOnlyCollection variables, int stackAllocation) +public class Func(string startLabel, string endLabel, string name, IReadOnlyCollection parameters, Optional returnType, IReadOnlyCollection variables, int stackAllocation) { - public string Label { get; } = label; + public string StartLabel { get; } = startLabel; + public string EndLabel { get; } = endLabel; public string Name { get; } = name; public IReadOnlyCollection Parameters { get; } = parameters; public Optional ReturnType { get; } = returnType; diff --git a/Nub.Lang/Nub.Lang/Input/program.nub b/Nub.Lang/Nub.Lang/Input/program.nub index f21d41b..64b3f58 100644 --- a/Nub.Lang/Nub.Lang/Input/program.nub +++ b/Nub.Lang/Nub.Lang/Input/program.nub @@ -4,32 +4,11 @@ let STD_OUT = 1; let STD_ERR = 2; func main() { - test - ( - "a\n", - "b\n", - "c\n", - "d\n", - "e\n", - "f\n", - "g\n", - "h\n", - "i\n", - "j\n", - ); + write(test()); } -func test(a: String, b: String, c: String, d: String, e: String, f: String, g: String, h: String, i: String, j: String) { - write(a); - write(b); - write(c); - write(d); - write(e); - write(f); - write(g); - write(h); - write(i); - write(j); +func test(): String { + return "test"; } func write(msg: String) { diff --git a/Nub.Lang/Nub.Lang/Parsing/FuncCallStatementNode.cs b/Nub.Lang/Nub.Lang/Parsing/FuncCallStatementNode.cs index 6d24695..bd60bf5 100644 --- a/Nub.Lang/Nub.Lang/Parsing/FuncCallStatementNode.cs +++ b/Nub.Lang/Nub.Lang/Parsing/FuncCallStatementNode.cs @@ -3,4 +3,6 @@ public class FuncCallStatementNode(FuncCall funcCall) : StatementNode { public FuncCall FuncCall { get; } = funcCall; + + public override string ToString() => FuncCall.ToString(); } \ No newline at end of file diff --git a/Nub.Lang/Nub.Lang/Parsing/Parser.cs b/Nub.Lang/Nub.Lang/Parsing/Parser.cs index e2921bd..0252c57 100644 --- a/Nub.Lang/Nub.Lang/Parsing/Parser.cs +++ b/Nub.Lang/Nub.Lang/Parsing/Parser.cs @@ -114,6 +114,27 @@ public class Parser throw new Exception($"Unexpected symbol {symbol.Symbol}"); } } + case SymbolToken symbol: + { + switch (symbol.Symbol) + { + case Symbol.Return: + { + var value = Optional.Empty(); + if (!TryExpectSymbol(Symbol.Semicolon)) + { + value = ParseExpression(); + ExpectSymbol(Symbol.Semicolon); + } + + return new ReturnNode(value); + } + default: + { + throw new Exception($"Unexpected symbol {symbol.Symbol}"); + } + } + } default: throw new Exception($"Unexpected token type {token.GetType().Name}"); } diff --git a/Nub.Lang/Nub.Lang/Parsing/ReturnNode.cs b/Nub.Lang/Nub.Lang/Parsing/ReturnNode.cs new file mode 100644 index 0000000..ca49538 --- /dev/null +++ b/Nub.Lang/Nub.Lang/Parsing/ReturnNode.cs @@ -0,0 +1,8 @@ +using Nub.Lib; + +namespace Nub.Lang.Parsing; + +public class ReturnNode(Optional value) : StatementNode +{ + public Optional Value { get; } = value; +} \ No newline at end of file diff --git a/Nub.Lang/Nub.Lang/Program.cs b/Nub.Lang/Nub.Lang/Program.cs index 5492bed..df63529 100644 --- a/Nub.Lang/Nub.Lang/Program.cs +++ b/Nub.Lang/Nub.Lang/Program.cs @@ -1,4 +1,5 @@ -using Nub.Lang.Generation; +using Nub.Lang.Branching; +using Nub.Lang.Generation; using Nub.Lang.Lexing; using Nub.Lang.Parsing; using Nub.Lang.Typing; @@ -14,6 +15,9 @@ var definitions = parser.Parse(); var typer = new ExpressionTyper(definitions); typer.Populate(); +var branchChecker = new BranchChecker(definitions); +branchChecker.Check(); + var generator = new Generator(definitions); var asm = generator.Generate(); diff --git a/Nub.Lang/Nub.Lang/Typing/ExpressionTyper.cs b/Nub.Lang/Nub.Lang/Typing/ExpressionTyper.cs index 1c5cdc2..1b5f80a 100644 --- a/Nub.Lang/Nub.Lang/Typing/ExpressionTyper.cs +++ b/Nub.Lang/Nub.Lang/Typing/ExpressionTyper.cs @@ -59,6 +59,9 @@ public class ExpressionTyper case FuncCallStatementNode funcCall: PopulateFuncCallStatement(funcCall); break; + case ReturnNode returnNode: + PopulateReturn(returnNode); + break; case SyscallStatementNode syscall: PopulateSyscallStatement(syscall); break; @@ -86,6 +89,14 @@ public class ExpressionTyper } } + private void PopulateReturn(ReturnNode returnNode) + { + if (returnNode.Value.HasValue) + { + PopulateExpression(returnNode.Value.Value); + } + } + private void PopulateVariableAssignment(VariableAssignmentNode variableAssignment) { PopulateExpression(variableAssignment.Value);