diff --git a/lang/Nub.Lang/Backend/Generator.cs b/lang/Nub.Lang/Backend/Generator.cs index a645632..19c8cfd 100644 --- a/lang/Nub.Lang/Backend/Generator.cs +++ b/lang/Nub.Lang/Backend/Generator.cs @@ -5,186 +5,40 @@ namespace Nub.Lang.Backend; public class Generator { - private const string Entrypoint = "main"; - private readonly List _definitions; + private readonly StringBuilder _builder = new(); private readonly SymbolTable _symbolTable; - private readonly StringBuilder _builder; - private readonly LabelFactory _labelFactory; - private readonly Stack<(string StartLabel, string EndLabel)> _loops; public Generator(List definitions) { _definitions = definitions; - _builder = new StringBuilder(); - _labelFactory = new LabelFactory(); - _symbolTable = new SymbolTable(_labelFactory); - _loops = []; - - foreach (var globalVariableDefinition in definitions.OfType()) - { - _symbolTable.DefineGlobalVariable(globalVariableDefinition); - } - - foreach (var funcDefinitionNode in definitions.OfType()) - { - _symbolTable.DefineFunc(funcDefinitionNode); - } - - foreach (var funcDefinitionNode in definitions.OfType()) - { - _symbolTable.DefineFunc(funcDefinitionNode); - } + _symbolTable = SymbolTable.Create(definitions); } public string Generate() { - _builder.AppendLine("global _start"); - _builder.AppendLine("extern gc_init"); - _builder.AppendLine("extern gc_alloc"); - _builder.AppendLine("extern str_cmp"); - foreach (var externFuncDefinition in _definitions.OfType()) { - _builder.AppendLine($"extern {externFuncDefinition.Name}"); + GenerateExternFuncDefinition(externFuncDefinition); } - _builder.AppendLine(); - _builder.AppendLine("section .text"); - - // TODO: Only add start label if entrypoint is present, otherwise assume library - var main = _symbolTable.ResolveLocalFunc(Entrypoint, []); - - _builder.AppendLine("_start:"); - _builder.AppendLine(" call gc_init"); - _builder.AppendLine($" call {main.StartLabel}"); - - _builder.AppendLine(main.ReturnType.HasValue - ? " mov rdi, rax" - : " mov rdi, 0"); - _builder.AppendLine(" mov rax, 60"); - _builder.AppendLine(" syscall"); - foreach (var funcDefinition in _definitions.OfType()) { - _builder.AppendLine(); GenerateFuncDefinition(funcDefinition); } - _builder.AppendLine(""" - - eb6e_oob_error: - mov rax, 60 - mov rdi, 139 - syscall - """); - - _builder.AppendLine(); - _builder.AppendLine("section .data"); - - foreach (var str in _symbolTable.Strings) - { - _builder.AppendLine($" {str.Key}: db `{str.Value}`, 0"); - } - - Dictionary completed = []; - foreach (var globalVariableDefinition in _definitions.OfType()) - { - var variable = _symbolTable.ResolveGlobalVariable(globalVariableDefinition.Name); - var evaluated = EvaluateExpression(globalVariableDefinition.Value, completed); - _builder.AppendLine($" {variable.Identifier}: dq {evaluated}"); - completed[variable.Name] = evaluated; - } - return _builder.ToString(); } - private string EvaluateExpression(ExpressionNode expression, Dictionary completed) + private void GenerateExternFuncDefinition(ExternFuncDefinitionNode externFuncDefinition) { - switch (expression) - { - case BinaryExpressionNode binaryExpression: - { - var left = EvaluateExpression(binaryExpression.Left, completed); - var right = EvaluateExpression(binaryExpression.Right, completed); - return binaryExpression.Operator switch - { - BinaryExpressionOperator.Equal => bool.Parse(left) == bool.Parse(right) ? "1" : "0", - BinaryExpressionOperator.NotEqual => bool.Parse(left) != bool.Parse(right) ? "1" : "0", - BinaryExpressionOperator.GreaterThan => long.Parse(left) > long.Parse(right) ? "1" : "0", - BinaryExpressionOperator.GreaterThanOrEqual => long.Parse(left) >= long.Parse(right) ? "1" : "0", - BinaryExpressionOperator.LessThan => long.Parse(left) < long.Parse(right) ? "1" : "0", - BinaryExpressionOperator.LessThanOrEqual => long.Parse(left) <= long.Parse(right) ? "1" : "0", - BinaryExpressionOperator.Plus => (long.Parse(left) + long.Parse(right)).ToString(), - BinaryExpressionOperator.Minus => (long.Parse(left) - long.Parse(right)).ToString(), - BinaryExpressionOperator.Multiply => (long.Parse(left) * long.Parse(right)).ToString(), - BinaryExpressionOperator.Divide => (long.Parse(left) / long.Parse(right)).ToString(), - _ => throw new ArgumentOutOfRangeException() - }; - } - case IdentifierNode identifier: - { - return completed[identifier.Identifier]; - } - case LiteralNode literal: - { - if (literal.LiteralType.Equals(NubType.Int64) - || literal.LiteralType.Equals(NubType.Int32) - || literal.LiteralType.Equals(NubType.Int16) - || literal.LiteralType.Equals(NubType.Int8)) - { - return literal.Literal; - } - - if (literal.LiteralType.Equals(NubType.Bool)) - { - return bool.Parse(literal.Literal) ? "1" : "0"; - } - - throw new InvalidOperationException("BAD"); - } - default: - { - throw new InvalidOperationException("Global variables must be compile time consistant"); - } - } } - + private void GenerateFuncDefinition(LocalFuncDefinitionNode node) { - var func = _symbolTable.ResolveLocalFunc(node.Name, node.Parameters.Select(p => p.Type).ToList()); - - _builder.AppendLine($"{func.StartLabel}:"); - _builder.AppendLine(" push rbp"); - _builder.AppendLine(" mov rbp, rsp"); - _builder.AppendLine($" sub rsp, {func.StackAllocation}"); - - string[] registers = ["rdi", "rsi", "rdx", "rcx", "r8", "r9"]; - - for (var i = 0; i < func.Parameters.Count; i++) - { - var parameter = func.ResolveLocalVariable(func.Parameters.ElementAt(i).Name); - if (i < registers.Length) - { - _builder.AppendLine($" mov [rbp - {parameter.Offset}], {registers[i]}"); - } - else - { - var stackOffset = 16 + (i - registers.Length) * 8; - _builder.AppendLine($" mov rax, [rbp + {stackOffset}]"); - _builder.AppendLine($" mov [rbp - {parameter.Offset}], rax"); - } - } - - GenerateBlock(node.Body, func); - - _builder.AppendLine($"{func.EndLabel}:"); - _builder.AppendLine(" mov rsp, rbp"); - _builder.AppendLine(" pop rbp"); - _builder.AppendLine(" ret"); } - private void GenerateBlock(BlockNode block, LocalFunc func) + private void GenerateBlock(BlockNode block, LocalFuncDef func) { foreach (var statement in block.Statements) { @@ -192,7 +46,7 @@ public class Generator } } - private void GenerateStatement(StatementNode statement, LocalFunc func) + private void GenerateStatement(StatementNode statement, LocalFuncDef func) { switch (statement) { @@ -230,82 +84,33 @@ public class Generator private void GenerateBreak() { - _builder.AppendLine($" jmp {_loops.Peek().EndLabel}"); } private void GenerateContinue() { - _builder.AppendLine($" jmp {_loops.Peek().StartLabel}"); } - private void GenerateIf(IfNode ifStatement, LocalFunc func) + private void GenerateIf(IfNode ifStatement, LocalFuncDef func) { - var endLabel = _labelFactory.Create(); - GenerateIf(ifStatement, endLabel, func); - _builder.AppendLine($"{endLabel}:"); } - - private void GenerateIf(IfNode ifStatement, string endLabel, LocalFunc func) + + private void GenerateReturn(ReturnNode @return, LocalFuncDef func) { - var nextLabel = _labelFactory.Create(); - GenerateExpression(ifStatement.Condition, func); - _builder.AppendLine(" cmp rax, 0"); - _builder.AppendLine($" je {nextLabel}"); - GenerateBlock(ifStatement.Body, func); - _builder.AppendLine($" jmp {endLabel}"); - _builder.AppendLine($"{nextLabel}:"); - - if (ifStatement.Else.HasValue) - { - ifStatement.Else.Value.Match - ( - elseIfStatement => GenerateIf(elseIfStatement, endLabel, func), - elseStatement => GenerateBlock(elseStatement, func) - ); - } } - private void GenerateReturn(ReturnNode @return, LocalFunc func) + private void GenerateVariableAssignment(VariableAssignmentNode variableAssignment, LocalFuncDef func) { - if (@return.Value.HasValue) - { - GenerateExpression(@return.Value.Value, func); - } - - _builder.AppendLine($" jmp {func.EndLabel}"); } - private void GenerateVariableAssignment(VariableAssignmentNode variableAssignment, LocalFunc func) + private void GenerateVariableReassignment(VariableReassignmentNode variableReassignment, LocalFuncDef func) { - var variable = func.ResolveLocalVariable(variableAssignment.Name); - GenerateExpression(variableAssignment.Value, func); - _builder.AppendLine($" mov [rbp - {variable.Offset}], rax"); } - private void GenerateVariableReassignment(VariableReassignmentNode variableReassignment, LocalFunc func) + private void GenerateWhile(WhileNode whileStatement, LocalFuncDef func) { - var variable = func.ResolveLocalVariable(variableReassignment.Name); - GenerateExpression(variableReassignment.Value, func); - _builder.AppendLine($" mov [rbp - {variable.Offset}], rax"); } - private void GenerateWhile(WhileNode whileStatement, LocalFunc func) - { - var startLabel = _labelFactory.Create(); - var endLabel = _labelFactory.Create(); - - _builder.AppendLine($"{startLabel}:"); - GenerateExpression(whileStatement.Condition, func); - _builder.AppendLine(" cmp rax, 0"); - _builder.AppendLine($" je {endLabel}"); - _loops.Push((startLabel, endLabel)); - GenerateBlock(whileStatement.Body, func); - _loops.Pop(); - _builder.AppendLine($" jmp {startLabel}"); - _builder.AppendLine($"{endLabel}:"); - } - - private void GenerateExpression(ExpressionNode expression, LocalFunc func) + private void GenerateExpression(ExpressionNode expression, LocalFuncDef func) { switch (expression) { @@ -319,7 +124,7 @@ public class Generator GenerateIdentifier(identifier, func); break; case LiteralNode literal: - GenerateLiteral(literal); + GenerateLiteral(literal, func); break; case StructInitializerNode structInitializer: GenerateStructInitializer(structInitializer, func); @@ -335,337 +140,31 @@ public class Generator } } - private void GenerateStructMemberAccessor(StructMemberAccessorNode structMemberAccessor, LocalFunc func) + private void GenerateStructMemberAccessor(StructMemberAccessorNode structMemberAccessor, LocalFuncDef func) { - var variable = func.ResolveLocalVariable(structMemberAccessor.Members[0]); - - _builder.AppendLine($" mov rax, [rbp - {variable.Offset}]"); - - var prevMemberType = variable.Type; - for (var i = 1; i < structMemberAccessor.Members.Count; i++) - { - var structDefinition = _definitions.OfType().FirstOrDefault(sd => sd.Name == prevMemberType.Name); - if (structDefinition == null) - { - throw new Exception($"Struct {prevMemberType} is not defined"); - } - - var member = structDefinition.Members.FirstOrDefault(m => m.Name == structMemberAccessor.Members[i]); - if (member == null) - { - throw new Exception($"Struct {prevMemberType} has no member with name {structMemberAccessor.Members[i]}"); - } - - var offset = structDefinition.Members.IndexOf(member); - _builder.AppendLine($" mov rax, [rax + {offset * 8}]"); - - prevMemberType = member.Type; - } } - private void GenerateBinaryExpression(BinaryExpressionNode binaryExpression, LocalFunc func) + private void GenerateBinaryExpression(BinaryExpressionNode binaryExpression, LocalFuncDef func) { - GenerateExpression(binaryExpression.Left, func); - _builder.AppendLine(" push rax"); - GenerateExpression(binaryExpression.Right, func); - _builder.AppendLine(" mov rcx, rax"); - _builder.AppendLine(" pop rax"); - - switch (binaryExpression.Operator) - { - case BinaryExpressionOperator.Equal: - GenerateComparison(binaryExpression.Left.Type); - _builder.AppendLine(" sete al"); - _builder.AppendLine(" movzx rax, al"); - break; - case BinaryExpressionOperator.NotEqual: - GenerateComparison(binaryExpression.Left.Type); - _builder.AppendLine(" setne al"); - _builder.AppendLine(" movzx rax, al"); - break; - case BinaryExpressionOperator.GreaterThan: - GenerateComparison(binaryExpression.Left.Type); - _builder.AppendLine(" setg al"); - _builder.AppendLine(" movzx rax, al"); - break; - case BinaryExpressionOperator.GreaterThanOrEqual: - GenerateComparison(binaryExpression.Left.Type); - _builder.AppendLine(" setge al"); - _builder.AppendLine(" movzx rax, al"); - break; - case BinaryExpressionOperator.LessThan: - GenerateComparison(binaryExpression.Left.Type); - _builder.AppendLine(" setl al"); - _builder.AppendLine(" movzx rax, al"); - break; - case BinaryExpressionOperator.LessThanOrEqual: - GenerateComparison(binaryExpression.Left.Type); - _builder.AppendLine(" setle al"); - _builder.AppendLine(" movzx rax, al"); - break; - case BinaryExpressionOperator.Plus: - GenerateBinaryAddition(binaryExpression.Left.Type); - break; - case BinaryExpressionOperator.Minus: - GenerateBinarySubtraction(binaryExpression.Left.Type); - break; - case BinaryExpressionOperator.Multiply: - GenerateBinaryMultiplication(binaryExpression.Left.Type); - break; - case BinaryExpressionOperator.Divide: - GenerateBinaryDivision(binaryExpression.Left.Type); - break; - default: - throw new ArgumentOutOfRangeException(nameof(binaryExpression.Operator)); - } } - private void GenerateComparison(NubType type) + private void GenerateIdentifier(IdentifierNode identifier, LocalFuncDef func) { - if (type.Equals(NubType.String)) - { - _builder.AppendLine(" mov rdi, rax"); - _builder.AppendLine(" mov rsi, rcx"); - _builder.AppendLine(" call str_cmp"); - } - else if (type.Equals(NubType.Bool) || type.Equals(NubType.Int64) || type.Equals(NubType.Int32) || type.Equals(NubType.Int16) || type.Equals(NubType.Int8)) - { - _builder.AppendLine(" cmp rax, rcx"); - } - else - { - throw new ArgumentOutOfRangeException(nameof(type)); - } } - private void GenerateBinaryAddition(NubType type) + private void GenerateLiteral(LiteralNode literal, LocalFuncDef func) { - if (type.Equals(NubType.Int64)) - { - _builder.AppendLine(" add rax, rcx"); - } - else if (type.Equals(NubType.Int32)) - { - _builder.AppendLine(" add eax, ecx"); - } - else if (type.Equals(NubType.Int16)) - { - _builder.AppendLine(" add ax, cx"); - } - else if (type.Equals(NubType.Int8)) - { - _builder.AppendLine(" add al, cl"); - } - else - { - throw new InvalidOperationException($"Invalid type for addition {type}"); - } - } - - private void GenerateBinarySubtraction(NubType type) - { - if (type.Equals(NubType.Int64)) - { - _builder.AppendLine(" sub rax, rcx"); - } - else if (type.Equals(NubType.Int32)) - { - _builder.AppendLine(" sub eax, ecx"); - } - else if (type.Equals(NubType.Int16)) - { - _builder.AppendLine(" sub ax, cx"); - } - else if (type.Equals(NubType.Int8)) - { - _builder.AppendLine(" sub al, cl"); - } - else - { - throw new InvalidOperationException($"Invalid type for subtraction {type}"); - } } - private void GenerateBinaryMultiplication(NubType type) + private void GenerateStructInitializer(StructInitializerNode structInitializer, LocalFuncDef func) { - if (type.Equals(NubType.Int64)) - { - _builder.AppendLine(" imul rcx"); - } - else if (type.Equals(NubType.Int32)) - { - _builder.AppendLine(" imul ecx"); - } - else if (type.Equals(NubType.Int16)) - { - _builder.AppendLine(" imul cx"); - } - else if (type.Equals(NubType.Int8)) - { - _builder.AppendLine(" imul cl"); - } - else - { - throw new InvalidOperationException($"Invalid type for multiplication {type}"); - } } - private void GenerateBinaryDivision(NubType type) + private void GenerateFuncCall(FuncCall funcCall, LocalFuncDef func) { - if (type.Equals(NubType.Int64)) - { - _builder.AppendLine(" cqo"); - _builder.AppendLine(" idiv rcx"); - } - else if (type.Equals(NubType.Int32)) - { - _builder.AppendLine(" cdq"); - _builder.AppendLine(" idiv ecx"); - } - else if (type.Equals(NubType.Int16)) - { - _builder.AppendLine(" cwd"); - _builder.AppendLine(" idiv cx"); - } - else if (type.Equals(NubType.Int8)) - { - _builder.AppendLine(" cbw"); - _builder.AppendLine(" idiv cl"); - } - else - { - throw new InvalidOperationException($"Invalid type for division {type}"); - } } - private void GenerateIdentifier(IdentifierNode identifier, LocalFunc func) + private void GenerateSyscall(Syscall syscall, LocalFuncDef func) { - var variable = func.ResolveVariable(identifier.Identifier); - switch (variable) - { - case GlobalVariable globalVariable: - _builder.AppendLine($" mov rax, [{globalVariable.Identifier}]"); - break; - case LocalVariable localVariable: - { - _builder.AppendLine($" mov rax, [rbp - {localVariable.Offset}]"); - break; - } - default: - { - throw new ArgumentOutOfRangeException(nameof(variable)); - } - } - } - - private void GenerateLiteral(LiteralNode literal) - { - if (literal.Type.Equals(NubType.String)) - { - var label = _symbolTable.DefineString(literal.Literal); - _builder.AppendLine($" mov rax, {label}"); - } - else if (literal.Type.Equals(NubType.Int64) || literal.Type.Equals(NubType.Int32) || literal.Type.Equals(NubType.Int16) || literal.Type.Equals(NubType.Int8)) - { - _builder.AppendLine($" mov rax, {literal.Literal}"); - } - else if (literal.Type.Equals(NubType.Bool)) - { - _builder.AppendLine($" mov rax, {(bool.Parse(literal.Literal) ? "1" : "0")}"); - } - else - { - throw new NotImplementedException($"Literal type {literal.Type} not implemented"); - } - } - - private void GenerateStructInitializer(StructInitializerNode structInitializer, LocalFunc func) - { - var structDefinition = _definitions - .OfType() - .FirstOrDefault(sd => sd.Name == structInitializer.StructType.Name); - - if (structDefinition == null) - { - throw new Exception($"Struct {structInitializer.StructType} is not defined"); - } - - _builder.AppendLine($" mov rdi, {structDefinition.Members.Count * 8}"); - _builder.AppendLine(" call gc_alloc"); - _builder.AppendLine(" mov rcx, rax"); - - foreach (var initializer in structInitializer.Initializers) - { - _builder.AppendLine(" push rcx"); - GenerateExpression(initializer.Value, func); - var index = structDefinition.Members.FindIndex(sd => sd.Name == initializer.Key); - if (index == -1) - { - throw new Exception($"Member {initializer.Key} is not defined on struct {structInitializer.StructType}"); - } - - _builder.AppendLine(" pop rcx"); - _builder.AppendLine($" mov [rcx + {index * 8}], rax"); - } - - foreach (var uninitializedMember in structDefinition.Members.Where(m => !structInitializer.Initializers.ContainsKey(m.Name))) - { - if (!uninitializedMember.Value.HasValue) - { - throw new Exception($"Struct {structInitializer.StructType} must be initializer with member {uninitializedMember.Name}"); - } - - _builder.AppendLine(" push rcx"); - GenerateExpression(uninitializedMember.Value.Value, func); - var index = structDefinition.Members.IndexOf(uninitializedMember); - _builder.AppendLine(" pop rcx"); - _builder.AppendLine($" mov [rcx + {index * 8}], rax"); - } - - _builder.AppendLine(" mov rax, rcx"); - } - - private void GenerateFuncCall(FuncCall funcCall, LocalFunc func) - { - var symbol = _symbolTable.ResolveFunc(funcCall.Name, funcCall.Parameters.Select(p => p.Type).ToList()); - string[] registers = ["rdi", "rsi", "rdx", "rcx", "r8", "r9"]; - - for (var i = funcCall.Parameters.Count - 1; i >= 0; i--) - { - GenerateExpression(funcCall.Parameters.ElementAt(i), func); - _builder.AppendLine(" push rax"); - } - - var registerParameters = Math.Min(registers.Length, funcCall.Parameters.Count); - var stackParameters = funcCall.Parameters.Count - registerParameters; - - for (var i = 0; i < registerParameters; i++) - { - _builder.AppendLine($" pop {registers[i]}"); - } - - _builder.AppendLine($" call {symbol.StartLabel}"); - if (stackParameters != 0) - { - _builder.AppendLine($" add rsp, {stackParameters}"); - } - } - - private void GenerateSyscall(Syscall syscall, LocalFunc func) - { - string[] registers = ["rax", "rdi", "rsi", "rdx", "r10", "r8", "r9"]; - - foreach (var parameter in syscall.Parameters) - { - GenerateExpression(parameter, func); - _builder.AppendLine(" push rax"); - } - - for (var i = syscall.Parameters.Count - 1; i >= 0; i--) - { - _builder.AppendLine($" pop {registers[i]}"); - } - - _builder.AppendLine(" syscall"); } } \ No newline at end of file diff --git a/lang/Nub.Lang/Backend/LabelFactory.cs b/lang/Nub.Lang/Backend/LabelFactory.cs deleted file mode 100644 index 495c15a..0000000 --- a/lang/Nub.Lang/Backend/LabelFactory.cs +++ /dev/null @@ -1,7 +0,0 @@ -namespace Nub.Lang.Backend; - -public class LabelFactory -{ - private int _index; - public string Create() => $"label{++_index}"; -} \ No newline at end of file diff --git a/lang/Nub.Lang/Backend/SymbolTable.cs b/lang/Nub.Lang/Backend/SymbolTable.cs index eef8668..550e6ef 100644 --- a/lang/Nub.Lang/Backend/SymbolTable.cs +++ b/lang/Nub.Lang/Backend/SymbolTable.cs @@ -4,195 +4,129 @@ namespace Nub.Lang.Backend; public class SymbolTable { - private readonly List _funcDefinitions = []; - private readonly List _globalVariables = []; - private readonly LabelFactory _labelFactory; - - public readonly Dictionary Strings = []; + public static SymbolTable Create(IEnumerable program) + { + var externFuncDefs = new List(); + var localFuncDefs = new List(); - public SymbolTable(LabelFactory labelFactory) - { - _labelFactory = labelFactory; - } - - public string DefineString(string value) - { - var label = _labelFactory.Create(); - Strings.Add(label, value); - return label; - } - - public void DefineGlobalVariable(GlobalVariableDefinitionNode globalVariableDefinition) - { - var identifier = _labelFactory.Create(); - _globalVariables.Add(new GlobalVariable(globalVariableDefinition.Name, globalVariableDefinition.Value.Type, identifier)); - } - - public void DefineFunc(ExternFuncDefinitionNode externFuncDefinition) - { - var existing = _funcDefinitions - .FirstOrDefault(f => f - .SignatureMatches - ( - externFuncDefinition.Name, - externFuncDefinition.Parameters.Select(p => p.Type).ToList() - )); - - if (existing != null) + foreach (var node in program) { - throw new Exception($"Func {existing} is already defined"); - } - - _funcDefinitions.Add(new ExternFunc(externFuncDefinition.Name, externFuncDefinition.Name, externFuncDefinition.Parameters, externFuncDefinition.ReturnType)); - } - - public void DefineFunc(LocalFuncDefinitionNode localFuncDefinition) - { - var existing = _funcDefinitions - .FirstOrDefault(f => f - .SignatureMatches - ( - localFuncDefinition.Name, - localFuncDefinition.Parameters.Select(p => p.Type).ToList() - )); - - if (existing != null) - { - throw new Exception($"Func {existing} is already defined"); - } - - var startLabel = _labelFactory.Create(); - var endLabel = _labelFactory.Create(); - _funcDefinitions.Add(new LocalFunc(localFuncDefinition.Name, startLabel, endLabel, localFuncDefinition.Parameters, localFuncDefinition.ReturnType, _globalVariables.Concat(ResolveFuncVariables(localFuncDefinition)).ToList())); - } - - private static List ResolveFuncVariables(LocalFuncDefinitionNode localFuncDefinition) - { - var offset = 0; - List variables = []; - - foreach (var parameter in localFuncDefinition.Parameters) - { - offset += 8; - variables.Add(new LocalVariable(parameter.Name, parameter.Type, offset)); - } - - ResolveBlockVariables(localFuncDefinition.Body, variables, offset); - - return variables; - } - - private static int ResolveBlockVariables(BlockNode block, List variables, int offset) - { - foreach (var statement in block.Statements) - { - switch (statement) + switch (node) { - case IfNode ifStatement: + case ExternFuncDefinitionNode externFuncDefinitionNode: { - offset = ResolveBlockVariables(ifStatement.Body, variables, offset); - if (ifStatement.Else.HasValue) + var parameters = externFuncDefinitionNode.Parameters.Select(parameter => new Variable(parameter.Name, parameter.Type)).ToList(); + externFuncDefs.Add(new ExternFuncDef { - ifStatement.Else.Value.Match - ( - elseIfStatement => offset = ResolveBlockVariables(elseIfStatement.Body, variables, offset), - elseStatement => offset = ResolveBlockVariables(elseStatement, variables, offset) - ); + Name = externFuncDefinitionNode.Name, + Parameters = parameters, + ReturnType = externFuncDefinitionNode.ReturnType + }); + break; + } + case LocalFuncDefinitionNode localFuncDefinitionNode: + { + var parameters = localFuncDefinitionNode.Parameters.Select(parameter => new Variable(parameter.Name, parameter.Type)).ToList(); + var localVariables = new List(); + + FindVariables(localFuncDefinitionNode.Body); + + localFuncDefs.Add(new LocalFuncDef + { + Name = localFuncDefinitionNode.Name, + Parameters = parameters, + LocalVariables = localVariables, + ReturnType = localFuncDefinitionNode.ReturnType + }); + break; + + void FindVariables(BlockNode blockNode) + { + foreach (var statement in blockNode.Statements) + { + switch (statement) + { + case IfNode ifNode: + { + FindVariables(ifNode.Body); + break; + } + case WhileNode whileNode: + { + FindVariables(whileNode.Body); + break; + } + case VariableAssignmentNode variableAssignmentNode: + { + localVariables.Add(new Variable(variableAssignmentNode.Name, variableAssignmentNode.Value.Type)); + break; + } + } + } } - break; } - case WhileNode whileStatement: + case StructDefinitionNode structDefinitionNode: { - offset = ResolveBlockVariables(whileStatement.Body, variables, offset); - break; + throw new NotImplementedException(); } - case VariableAssignmentNode variableAssignment: + default: { - offset += 8; - variables.Add(new LocalVariable(variableAssignment.Name, variableAssignment.Value.Type, offset)); - break; + throw new ArgumentOutOfRangeException(nameof(node)); } } } - return offset; - } - - public Func ResolveFunc(string name, List parameterTypes) + return new SymbolTable(externFuncDefs, localFuncDefs); + } + + private readonly List _externFuncDefs; + private readonly List _localFuncDefs; + + private SymbolTable(List externFuncDefs, List localFuncDefs) { - var func = _funcDefinitions.FirstOrDefault(f => f.SignatureMatches(name, parameterTypes)); - if (func == null) + _externFuncDefs = externFuncDefs; + _localFuncDefs = localFuncDefs; + } + + public FuncDef ResolveFunc(string name, List parameters) + { + var matching = _externFuncDefs.Concat(_localFuncDefs).Where(funcDef => funcDef.SignatureMatches(name, parameters)).ToArray(); + return matching.Length switch { - throw new Exception($"Func {name}({string.Join(", ", parameterTypes)}) is not defined"); + 0 => throw new Exception($"Could not resolve a func with signature {name}({string.Join(", ", parameters)})"), + > 1 => throw new Exception($"Multiple functions matches the signature {name}({string.Join(", ", parameters)})"), + _ => matching[0] + }; + } + + public LocalFuncDef ResolveLocalFunc(string name, List parameters) + { + var funcDef = ResolveFunc(name, parameters); + if (funcDef is LocalFuncDef localFuncDef) + { + return localFuncDef; } - return func; + throw new Exception($"Could not resolve a local func with signature {name}({string.Join(", ", parameters)})"); } - - public LocalFunc ResolveLocalFunc(string name, List parameterTypes) + + public ExternFuncDef ResolveExternFunc(string name, List parameters) { - var func = ResolveFunc(name, parameterTypes); - if (func is not LocalFunc localFunc) + var funcDef = ResolveFunc(name, parameters); + if (funcDef is ExternFuncDef externFuncDef) { - throw new Exception($"Func {func} is not a local func"); - } - return localFunc; - } - - public ExternFunc ResolveExternFunc(string name, List parameterTypes) - { - var func = ResolveFunc(name, parameterTypes); - if (func is not ExternFunc externFunc) - { - throw new Exception($"Func {func} is not an extern func"); - } - return externFunc; - } - - public GlobalVariable ResolveGlobalVariable(string name) - { - var variable = _globalVariables.FirstOrDefault(v => v.Name == name); - if (variable == null) - { - throw new Exception($"Global variable {name} is not defined"); + return externFuncDef; } - return variable; + throw new Exception($"Could not resolve a extern func with signature {name}({string.Join(", ", parameters)})"); } } -public abstract class Variable(string name, NubType type) +public abstract class FuncDef { - public string Name { get; } = name; - public NubType Type { get; } = type; - - public override string ToString() => $"{Name}: {Type}"; -} - -public class LocalVariable(string name, NubType type, int offset) : Variable(name, type) -{ - public int Offset { get; } = offset; -} - -public class GlobalVariable(string name, NubType type, string identifier) : Variable(name, type) -{ - public string Identifier { get; } = identifier; -} - -public abstract class Func -{ - protected Func(string name, string startLabel, List parameters, Optional returnType) - { - Name = name; - Parameters = parameters; - ReturnType = returnType; - StartLabel = startLabel; - } - - public string Name { get; } - public string StartLabel { get; } - public List Parameters { get; } - public Optional ReturnType { get; } + public required string Name { get; init; } + public required List Parameters { get; init; } + public required Optional ReturnType { get; init; } public bool SignatureMatches(string name, List parameterTypes) { @@ -206,53 +140,24 @@ public abstract class Func return true; } - - public override string ToString() => $"{Name}({string.Join(", ", Parameters.Select(p => p.ToString()))}){(ReturnType.HasValue ? ": " + ReturnType.Value : "")}"; } -public class ExternFunc : Func +public sealed class LocalFuncDef : FuncDef { - public ExternFunc(string name, string startLabel, List parameters, Optional returnType) : base(name, startLabel, parameters, returnType) + public required List LocalVariables { get; set; } + + public override string ToString() { + return $"func {Name}({string.Join(", ", Parameters.Select(p => p.ToString()))}){(ReturnType.HasValue ? ": " + ReturnType.Value : "")}"; } } -public class LocalFunc : Func +public sealed class ExternFuncDef : FuncDef; + +public sealed class Variable(string name, NubType type) { - public LocalFunc(string name, string startLabel, string endLabel, List parameters, Optional returnType, List variables) : base(name, startLabel, parameters, returnType) - { - EndLabel = endLabel; - Variables = variables; - } + public string Name { get; } = name; + public NubType Type { get; } = type; - public string EndLabel { get; } - public List Variables { get; } - public int StackAllocation => Variables.OfType().Sum(variable => variable.Offset); - - public Variable ResolveVariable(string name) - { - var variable = Variables.FirstOrDefault(v => v.Name == name); - if (variable == null) - { - throw new Exception($"Variable {name} is not defined"); - } - - return variable; - } - - public LocalVariable ResolveLocalVariable(string name) - { - var variable = Variables.FirstOrDefault(v => v.Name == name); - if (variable == null) - { - throw new Exception($"Variable {name} is not defined"); - } - - if (variable is not LocalVariable localVariable) - { - throw new Exception($"Variable {name} is not a local variable"); - } - - return localVariable; - } + public override string ToString() => $"{Name}: {Type}"; } \ No newline at end of file diff --git a/lang/Nub.Lang/Frontend/Parsing/GlobalVariableDefinitionNode.cs b/lang/Nub.Lang/Frontend/Parsing/GlobalVariableDefinitionNode.cs deleted file mode 100644 index 04a69d2..0000000 --- a/lang/Nub.Lang/Frontend/Parsing/GlobalVariableDefinitionNode.cs +++ /dev/null @@ -1,7 +0,0 @@ -namespace Nub.Lang.Frontend.Parsing; - -public class GlobalVariableDefinitionNode(string name, ExpressionNode value) : DefinitionNode -{ - public string Name { get; } = name; - public ExpressionNode Value { get; } = value; -} \ No newline at end of file diff --git a/lang/Nub.Lang/Frontend/Parsing/Parser.cs b/lang/Nub.Lang/Frontend/Parsing/Parser.cs index c564d8c..81d740f 100644 --- a/lang/Nub.Lang/Frontend/Parsing/Parser.cs +++ b/lang/Nub.Lang/Frontend/Parsing/Parser.cs @@ -43,7 +43,6 @@ public class Parser var keyword = ExpectSymbol(); return keyword.Symbol switch { - Symbol.Let => ParseGlobalVariableDefinition(), Symbol.Func => ParseFuncDefinition(), Symbol.Extern => ParseExternFuncDefinition(), Symbol.Struct => ParseStruct(), @@ -51,16 +50,6 @@ public class Parser }; } - private GlobalVariableDefinitionNode ParseGlobalVariableDefinition() - { - var name = ExpectIdentifier(); - ExpectSymbol(Symbol.Assign); - var value = ParseExpression(); - ExpectSymbol(Symbol.Semicolon); - - return new GlobalVariableDefinitionNode(name.Value, value); - } - private LocalFuncDefinitionNode ParseFuncDefinition() { var name = ExpectIdentifier(); diff --git a/lang/Nub.Lang/Frontend/Typing/ExpressionTyper.cs b/lang/Nub.Lang/Frontend/Typing/ExpressionTyper.cs index c8b5f3a..0fe8d2c 100644 --- a/lang/Nub.Lang/Frontend/Typing/ExpressionTyper.cs +++ b/lang/Nub.Lang/Frontend/Typing/ExpressionTyper.cs @@ -13,7 +13,6 @@ public class Func(string name, List parameters, Optional _functions; - private readonly List _variableDefinitions; private readonly List _structDefinitions; private readonly Stack _variables; @@ -21,7 +20,6 @@ public class ExpressionTyper { _variables = new Stack(); _functions = []; - _variableDefinitions = []; _structDefinitions = definitions.OfType().ToList(); @@ -37,7 +35,6 @@ public class ExpressionTyper _functions.AddRange(functions); _functions.AddRange(externFunctions); - _variableDefinitions.AddRange(definitions.OfType()); } public void Populate() @@ -55,12 +52,6 @@ public class ExpressionTyper } } - foreach (var variable in _variableDefinitions) - { - PopulateExpression(variable.Value); - _variables.Push(new Variable(variable.Name, variable.Value.Type)); - } - foreach (var function in _functions) { foreach (var parameter in function.Parameters)