From d99358136113bd701cad83d6fe3077e0e2d63362 Mon Sep 17 00:00:00 2001 From: nub31 Date: Tue, 22 Jul 2025 23:20:56 +0200 Subject: [PATCH] ... --- src/compiler/NubLang.CLI/Program.cs | 27 +- .../Generation/BoundDefinitionTable.cs | 78 --- .../Generation/QBE/QBEGenerator.Expression.cs | 128 ++-- .../Generation/QBE/QBEGenerator.Statement.cs | 30 +- .../NubLang/Generation/QBE/QBEGenerator.cs | 54 +- .../Generation/TypedDefinitionTable.cs | 78 +++ .../NubLang/{Syntax => }/Parsing/Parser.cs | 58 +- .../Syntax}/DefinitionSyntax.cs | 2 +- .../Syntax}/ExpressionSyntax.cs | 13 +- .../Syntax}/StatementSyntax.cs | 2 +- .../Node => Parsing/Syntax}/SyntaxNode.cs | 2 +- .../Node => Parsing/Syntax}/TypeSyntax.cs | 2 +- src/compiler/NubLang/Syntax/Binding/Binder.cs | 637 ----------------- .../Syntax/Binding/Node/BoundDefinition.cs | 25 - .../Syntax/Binding/Node/BoundExpression.cs | 55 -- .../Syntax/Binding/Node/BoundStatement.cs | 21 - .../Syntax/Binding/Node/BoundSyntaxTree.cs | 9 - .../{Syntax => }/Tokenization/Token.cs | 2 +- .../{Syntax => }/Tokenization/Tokenizer.cs | 2 +- .../DefinitionTable.cs | 4 +- .../NubLang/TypeChecking/Node/Definition.cs | 25 + .../NubLang/TypeChecking/Node/Expression.cs | 55 ++ .../NubLang/TypeChecking/Node/Statement.cs | 21 + .../NubLang/TypeChecking/Node/SyntaxTree.cs | 7 + .../Binding => TypeChecking}/NubType.cs | 28 +- .../NubLang/TypeChecking/TypeChecker.cs | 640 ++++++++++++++++++ 26 files changed, 1002 insertions(+), 1003 deletions(-) delete mode 100644 src/compiler/NubLang/Generation/BoundDefinitionTable.cs create mode 100644 src/compiler/NubLang/Generation/TypedDefinitionTable.cs rename src/compiler/NubLang/{Syntax => }/Parsing/Parser.cs (92%) rename src/compiler/NubLang/{Syntax/Parsing/Node => Parsing/Syntax}/DefinitionSyntax.cs (98%) rename src/compiler/NubLang/{Syntax/Parsing/Node => Parsing/Syntax}/ExpressionSyntax.cs (90%) rename src/compiler/NubLang/{Syntax/Parsing/Node => Parsing/Syntax}/StatementSyntax.cs (98%) rename src/compiler/NubLang/{Syntax/Parsing/Node => Parsing/Syntax}/SyntaxNode.cs (95%) rename src/compiler/NubLang/{Syntax/Parsing/Node => Parsing/Syntax}/TypeSyntax.cs (98%) delete mode 100644 src/compiler/NubLang/Syntax/Binding/Binder.cs delete mode 100644 src/compiler/NubLang/Syntax/Binding/Node/BoundDefinition.cs delete mode 100644 src/compiler/NubLang/Syntax/Binding/Node/BoundExpression.cs delete mode 100644 src/compiler/NubLang/Syntax/Binding/Node/BoundStatement.cs delete mode 100644 src/compiler/NubLang/Syntax/Binding/Node/BoundSyntaxTree.cs rename src/compiler/NubLang/{Syntax => }/Tokenization/Token.cs (95%) rename src/compiler/NubLang/{Syntax => }/Tokenization/Tokenizer.cs (99%) rename src/compiler/NubLang/{Syntax/Binding => TypeChecking}/DefinitionTable.cs (95%) create mode 100644 src/compiler/NubLang/TypeChecking/Node/Definition.cs create mode 100644 src/compiler/NubLang/TypeChecking/Node/Expression.cs create mode 100644 src/compiler/NubLang/TypeChecking/Node/Statement.cs create mode 100644 src/compiler/NubLang/TypeChecking/Node/SyntaxTree.cs rename src/compiler/NubLang/{Syntax/Binding => TypeChecking}/NubType.cs (91%) create mode 100644 src/compiler/NubLang/TypeChecking/TypeChecker.cs diff --git a/src/compiler/NubLang.CLI/Program.cs b/src/compiler/NubLang.CLI/Program.cs index ba1639f..e050c10 100644 --- a/src/compiler/NubLang.CLI/Program.cs +++ b/src/compiler/NubLang.CLI/Program.cs @@ -5,12 +5,11 @@ using NubLang.Common; using NubLang.Diagnostics; using NubLang.Generation; using NubLang.Generation.QBE; -using NubLang.Syntax.Binding; -using NubLang.Syntax.Binding.Node; -using NubLang.Syntax.Parsing; -using NubLang.Syntax.Parsing.Node; -using NubLang.Syntax.Tokenization; -using Binder = NubLang.Syntax.Binding.Binder; +using NubLang.Parsing; +using NubLang.Parsing.Syntax; +using NubLang.Tokenization; +using NubLang.TypeChecking; +using NubLang.TypeChecking.Node; const string BIN_DIR = "bin"; const string INT_DIR = "bin-int"; @@ -92,14 +91,14 @@ foreach (var file in options.Files) var definitionTable = new DefinitionTable(syntaxTrees); -var boundSyntaxTrees = new List(); +var typedSyntaxTrees = new List(); foreach (var syntaxTree in syntaxTrees) { - var binder = new Binder(syntaxTree, definitionTable); - var boundSyntaxTree = binder.Bind(); - diagnostics.AddRange(boundSyntaxTree.Diagnostics); - boundSyntaxTrees.Add(boundSyntaxTree); + var typeChecker = new TypeChecker(syntaxTree, definitionTable); + var typedSyntaxTree = typeChecker.Check(); + diagnostics.AddRange(typeChecker.GetDiagnostics()); + typedSyntaxTrees.Add(typedSyntaxTree); } foreach (var diagnostic in diagnostics) @@ -112,15 +111,15 @@ if (diagnostics.Any(diagnostic => diagnostic.Severity == DiagnosticSeverity.Erro return 1; } -var boundDefinitionTable = new BoundDefinitionTable(boundSyntaxTrees); +var typedDefinitionTable = new TypedDefinitionTable(typedSyntaxTrees); var objectFiles = new List(); -foreach (var syntaxTree in boundSyntaxTrees) +foreach (var syntaxTree in typedSyntaxTrees) { var outFileName = HexString.CreateUnique(8); - var generator = new QBEGenerator(syntaxTree, boundDefinitionTable); + var generator = new QBEGenerator(syntaxTree, typedDefinitionTable); var ssa = generator.Emit(); File.WriteAllText(Path.Join(INT_DEBUG_DIR, $"{outFileName}.ssa"), ssa); diff --git a/src/compiler/NubLang/Generation/BoundDefinitionTable.cs b/src/compiler/NubLang/Generation/BoundDefinitionTable.cs deleted file mode 100644 index 233f945..0000000 --- a/src/compiler/NubLang/Generation/BoundDefinitionTable.cs +++ /dev/null @@ -1,78 +0,0 @@ -using NubLang.Syntax.Binding; -using NubLang.Syntax.Binding.Node; - -namespace NubLang.Generation; - -public sealed class BoundDefinitionTable -{ - private readonly List _definitions; - - public BoundDefinitionTable(IEnumerable syntaxTrees) - { - _definitions = syntaxTrees.SelectMany(x => x.Definitions).ToList(); - } - - public BoundLocalFunc LookupLocalFunc(string name) - { - return _definitions - .OfType() - .First(x => x.Name == name); - } - - public BoundExternFunc LookupExternFunc(string name) - { - return _definitions - .OfType() - .First(x => x.Name == name); - } - - public BoundStruct LookupStruct(string name) - { - return _definitions - .OfType() - .First(x => x.Name == name); - } - - public BoundStructField LookupStructField(BoundStruct @struct, string field) - { - return @struct.Fields.First(x => x.Name == field); - } - - public IEnumerable LookupTraitImpls(NubType itemType) - { - return _definitions - .OfType() - .Where(x => x.ForType == itemType); - } - - public BoundTraitFuncImpl LookupTraitFuncImpl(NubType forType, string name) - { - return _definitions - .OfType() - .Where(x => x.ForType == forType) - .SelectMany(x => x.Functions) - .First(x => x.Name == name); - } - - public BoundTrait LookupTrait(string name) - { - return _definitions - .OfType() - .First(x => x.Name == name); - } - - public BoundTraitFunc LookupTraitFunc(BoundTrait trait, string name) - { - return trait.Functions.First(x => x.Name == name); - } - - public IEnumerable GetStructs() - { - return _definitions.OfType(); - } - - public IEnumerable GetTraits() - { - return _definitions.OfType(); - } -} \ No newline at end of file diff --git a/src/compiler/NubLang/Generation/QBE/QBEGenerator.Expression.cs b/src/compiler/NubLang/Generation/QBE/QBEGenerator.Expression.cs index 7e493ea..e633156 100644 --- a/src/compiler/NubLang/Generation/QBE/QBEGenerator.Expression.cs +++ b/src/compiler/NubLang/Generation/QBE/QBEGenerator.Expression.cs @@ -1,49 +1,49 @@ using System.Diagnostics; using System.Globalization; -using NubLang.Syntax.Binding; -using NubLang.Syntax.Binding.Node; -using NubLang.Syntax.Tokenization; +using NubLang.Tokenization; +using NubLang.TypeChecking; +using NubLang.TypeChecking.Node; namespace NubLang.Generation.QBE; public partial class QBEGenerator { - private Val EmitExpression(BoundExpression expression) + private Val EmitExpression(Expression expression) { return expression switch { - BoundArrayInitializer arrayInitializer => EmitArrayInitializer(arrayInitializer), - BoundStructInitializer structInitializer => EmitStructInitializer(structInitializer), - BoundAddressOf addressOf => EmitAddressOf(addressOf), - BoundDereference dereference => EmitDereference(dereference), - BoundArrowFunc arrowFunc => EmitArrowFunc(arrowFunc), - BoundBinaryExpression binaryExpression => EmitBinaryExpression(binaryExpression), - BoundFuncCall funcCallExpression => EmitFuncCall(funcCallExpression), - BoundExternFuncIdent externFuncIdent => EmitExternFuncIdent(externFuncIdent), - BoundLocalFuncIdent localFuncIdent => EmitLocalFuncIdent(localFuncIdent), - BoundVariableIdent variableIdent => EmitVariableIdent(variableIdent), - BoundLiteral literal => EmitLiteral(literal), - BoundUnaryExpression unaryExpression => EmitUnaryExpression(unaryExpression), - BoundStructFieldAccess structFieldAccess => EmitStructFieldAccess(structFieldAccess), - BoundInterfaceFuncAccess traitFuncAccess => EmitTraitFuncAccess(traitFuncAccess), - BoundArrayIndexAccess arrayIndex => EmitArrayIndexAccess(arrayIndex), + ArrayInitializer arrayInitializer => EmitArrayInitializer(arrayInitializer), + StructInitializer structInitializer => EmitStructInitializer(structInitializer), + AddressOf addressOf => EmitAddressOf(addressOf), + Dereference dereference => EmitDereference(dereference), + ArrowFunc arrowFunc => EmitArrowFunc(arrowFunc), + BinaryExpression binaryExpression => EmitBinaryExpression(binaryExpression), + FuncCall funcCallExpression => EmitFuncCall(funcCallExpression), + ExternFuncIdent externFuncIdent => EmitExternFuncIdent(externFuncIdent), + LocalFuncIdent localFuncIdent => EmitLocalFuncIdent(localFuncIdent), + VariableIdent variableIdent => EmitVariableIdent(variableIdent), + Literal literal => EmitLiteral(literal), + UnaryExpression unaryExpression => EmitUnaryExpression(unaryExpression), + StructFieldAccess structFieldAccess => EmitStructFieldAccess(structFieldAccess), + InterfaceFuncAccess traitFuncAccess => EmitTraitFuncAccess(traitFuncAccess), + ArrayIndexAccess arrayIndex => EmitArrayIndexAccess(arrayIndex), _ => throw new ArgumentOutOfRangeException(nameof(expression)) }; } - private Val EmitArrowFunc(BoundArrowFunc arrowFunc) + private Val EmitArrowFunc(ArrowFunc arrowFunc) { var name = $"$arrow_func{++_arrowFuncIndex}"; _arrowFunctions.Enqueue((arrowFunc, name)); return new Val(name, arrowFunc.Type, ValKind.Direct); } - private Val EmitArrayIndexAccess(BoundArrayIndexAccess arrayIndexAccess) + private Val EmitArrayIndexAccess(ArrayIndexAccess arrayIndexAccess) { var array = EmitUnwrap(EmitExpression(arrayIndexAccess.Target)); var index = EmitUnwrap(EmitExpression(arrayIndexAccess.Index)); - EmitArrayBoundsCheck(array, index); + EmitArraysCheck(array, index); var elementType = ((NubArrayType)arrayIndexAccess.Target.Type).ElementType; @@ -54,7 +54,7 @@ public partial class QBEGenerator return new Val(pointer, arrayIndexAccess.Type, ValKind.Pointer); } - private void EmitArrayBoundsCheck(string array, string index) + private void EmitArraysCheck(string array, string index) { var count = TmpName(); _writer.Indented($"{count} =l loadl {array}"); @@ -78,7 +78,7 @@ public partial class QBEGenerator _writer.Indented(notOobLabel); } - private Val EmitArrayInitializer(BoundArrayInitializer arrayInitializer) + private Val EmitArrayInitializer(ArrayInitializer arrayInitializer) { var capacity = EmitUnwrap(EmitExpression(arrayInitializer.Capacity)); var elementSize = arrayInitializer.ElementType.Size(_definitionTable); @@ -99,12 +99,12 @@ public partial class QBEGenerator return new Val(arrayPointer, arrayInitializer.Type, ValKind.Direct); } - private Val EmitDereference(BoundDereference dereference) + private Val EmitDereference(Dereference dereference) { return EmitLoad(dereference.Type, EmitUnwrap(EmitExpression(dereference.Expression))); } - private Val EmitAddressOf(BoundAddressOf addressOf) + private Val EmitAddressOf(AddressOf addressOf) { var value = EmitExpression(addressOf.Expression); if (value.Kind != ValKind.Pointer) @@ -115,7 +115,7 @@ public partial class QBEGenerator return new Val(value.Name, addressOf.Type, ValKind.Direct); } - private Val EmitBinaryExpression(BoundBinaryExpression binaryExpression) + private Val EmitBinaryExpression(BinaryExpression binaryExpression) { var left = EmitUnwrap(EmitExpression(binaryExpression.Left)); var right = EmitUnwrap(EmitExpression(binaryExpression.Right)); @@ -128,15 +128,15 @@ public partial class QBEGenerator return new Val(outputName, binaryExpression.Type, ValKind.Direct); } - private string EmitBinaryInstructionFor(BoundBinaryOperator op, NubType type, string left, string right) + private string EmitBinaryInstructionFor(BinaryOperator op, NubType type, string left, string right) { if (op is - BoundBinaryOperator.Equal or - BoundBinaryOperator.NotEqual or - BoundBinaryOperator.GreaterThan or - BoundBinaryOperator.GreaterThanOrEqual or - BoundBinaryOperator.LessThan or - BoundBinaryOperator.LessThanOrEqual) + BinaryOperator.Equal or + BinaryOperator.NotEqual or + BinaryOperator.GreaterThan or + BinaryOperator.GreaterThanOrEqual or + BinaryOperator.LessThan or + BinaryOperator.LessThanOrEqual) { char suffix; @@ -177,12 +177,12 @@ public partial class QBEGenerator throw new NotSupportedException($"Unsupported type '{simpleType}' for binary operator '{op}'"); } - if (op is BoundBinaryOperator.Equal) + if (op is BinaryOperator.Equal) { return "ceq" + suffix; } - if (op is BoundBinaryOperator.NotEqual) + if (op is BinaryOperator.NotEqual) { return "cne" + suffix; } @@ -204,42 +204,42 @@ public partial class QBEGenerator return op switch { - BoundBinaryOperator.GreaterThan => 'c' + sign + "gt" + suffix, - BoundBinaryOperator.GreaterThanOrEqual => 'c' + sign + "ge" + suffix, - BoundBinaryOperator.LessThan => 'c' + sign + "lt" + suffix, - BoundBinaryOperator.LessThanOrEqual => 'c' + sign + "le" + suffix, + BinaryOperator.GreaterThan => 'c' + sign + "gt" + suffix, + BinaryOperator.GreaterThanOrEqual => 'c' + sign + "ge" + suffix, + BinaryOperator.LessThan => 'c' + sign + "lt" + suffix, + BinaryOperator.LessThanOrEqual => 'c' + sign + "le" + suffix, _ => throw new ArgumentOutOfRangeException(nameof(op), op, null) }; } return op switch { - BoundBinaryOperator.Plus => "add", - BoundBinaryOperator.Minus => "sub", - BoundBinaryOperator.Multiply => "mul", - BoundBinaryOperator.Divide => "div", + BinaryOperator.Plus => "add", + BinaryOperator.Minus => "sub", + BinaryOperator.Multiply => "mul", + BinaryOperator.Divide => "div", _ => throw new ArgumentOutOfRangeException(nameof(op)) }; } - private Val EmitExternFuncIdent(BoundExternFuncIdent externFuncIdent) + private Val EmitExternFuncIdent(ExternFuncIdent externFuncIdent) { var func = _definitionTable.LookupExternFunc(externFuncIdent.Name); return new Val(ExternFuncName(func), externFuncIdent.Type, ValKind.Direct); } - private Val EmitLocalFuncIdent(BoundLocalFuncIdent localFuncIdent) + private Val EmitLocalFuncIdent(LocalFuncIdent localFuncIdent) { var func = _definitionTable.LookupLocalFunc(localFuncIdent.Name); return new Val(LocalFuncName(func), localFuncIdent.Type, ValKind.Direct); } - private Val EmitVariableIdent(BoundVariableIdent variableIdent) + private Val EmitVariableIdent(VariableIdent variableIdent) { return Scope.Lookup(variableIdent.Name); } - private Val EmitLiteral(BoundLiteral literal) + private Val EmitLiteral(Literal literal) { switch (literal.Kind) { @@ -247,21 +247,21 @@ public partial class QBEGenerator { if (literal.Type is NubPrimitiveType { Kind: PrimitiveTypeKind.F32 }) { - var value = float.Parse(literal.Literal, CultureInfo.InvariantCulture); + var value = float.Parse(literal.Value, CultureInfo.InvariantCulture); var bits = BitConverter.SingleToInt32Bits(value); return new Val(bits.ToString(), literal.Type, ValKind.Direct); } if (literal.Type is NubPrimitiveType { Kind: PrimitiveTypeKind.F64 }) { - var value = double.Parse(literal.Literal, CultureInfo.InvariantCulture); + var value = double.Parse(literal.Value, CultureInfo.InvariantCulture); var bits = BitConverter.DoubleToInt64Bits(value); return new Val(bits.ToString(), literal.Type, ValKind.Direct); } if (literal.Type is NubPrimitiveType { Kind: PrimitiveTypeKind.I8 or PrimitiveTypeKind.U8 or PrimitiveTypeKind.I16 or PrimitiveTypeKind.U16 or PrimitiveTypeKind.I32 or PrimitiveTypeKind.U32 or PrimitiveTypeKind.I64 or PrimitiveTypeKind.U64 }) { - return new Val(literal.Literal, literal.Type, ValKind.Direct); + return new Val(literal.Value, literal.Type, ValKind.Direct); } break; @@ -270,19 +270,19 @@ public partial class QBEGenerator { if (literal.Type is NubPrimitiveType { Kind: PrimitiveTypeKind.I8 or PrimitiveTypeKind.U8 or PrimitiveTypeKind.I16 or PrimitiveTypeKind.U16 or PrimitiveTypeKind.I32 or PrimitiveTypeKind.U32 or PrimitiveTypeKind.I64 or PrimitiveTypeKind.U64 }) { - return new Val(literal.Literal.Split(".").First(), literal.Type, ValKind.Direct); + return new Val(literal.Value.Split(".").First(), literal.Type, ValKind.Direct); } if (literal.Type is NubPrimitiveType { Kind: PrimitiveTypeKind.F32 }) { - var value = float.Parse(literal.Literal, CultureInfo.InvariantCulture); + var value = float.Parse(literal.Value, CultureInfo.InvariantCulture); var bits = BitConverter.SingleToInt32Bits(value); return new Val(bits.ToString(), literal.Type, ValKind.Direct); } if (literal.Type is NubPrimitiveType { Kind: PrimitiveTypeKind.F64 }) { - var value = double.Parse(literal.Literal, CultureInfo.InvariantCulture); + var value = double.Parse(literal.Value, CultureInfo.InvariantCulture); var bits = BitConverter.DoubleToInt64Bits(value); return new Val(bits.ToString(), literal.Type, ValKind.Direct); } @@ -293,14 +293,14 @@ public partial class QBEGenerator { if (literal.Type is NubStringType) { - var stringLiteral = new StringLiteral(literal.Literal, StringName()); + var stringLiteral = new StringLiteral(literal.Value, StringName()); _stringLiterals.Add(stringLiteral); return new Val(stringLiteral.Name, literal.Type, ValKind.Direct); } if (literal.Type is NubCStringType) { - var cStringLiteral = new CStringLiteral(literal.Literal, CStringName()); + var cStringLiteral = new CStringLiteral(literal.Value, CStringName()); _cStringLiterals.Add(cStringLiteral); return new Val(cStringLiteral.Name, literal.Type, ValKind.Direct); } @@ -311,7 +311,7 @@ public partial class QBEGenerator { if (literal.Type is NubPrimitiveType { Kind: PrimitiveTypeKind.Bool }) { - return new Val(bool.Parse(literal.Literal) ? "1" : "0", literal.Type, ValKind.Direct); + return new Val(bool.Parse(literal.Value) ? "1" : "0", literal.Type, ValKind.Direct); } break; @@ -321,7 +321,7 @@ public partial class QBEGenerator throw new NotSupportedException($"Cannot create literal of kind '{literal.Kind}' for type {literal.Type}"); } - private Val EmitStructInitializer(BoundStructInitializer structInitializer, string? destination = null) + private Val EmitStructInitializer(StructInitializer structInitializer, string? destination = null) { var @struct = _definitionTable.LookupStruct(structInitializer.StructType.Name); @@ -349,14 +349,14 @@ public partial class QBEGenerator return new Val(destination, structInitializer.StructType, ValKind.Direct); } - private Val EmitUnaryExpression(BoundUnaryExpression unaryExpression) + private Val EmitUnaryExpression(UnaryExpression unaryExpression) { var operand = EmitUnwrap(EmitExpression(unaryExpression.Operand)); var outputName = TmpName(); switch (unaryExpression.Operator) { - case BoundUnaryOperator.Negate: + case UnaryOperator.Negate: { switch (unaryExpression.Operand.Type) { @@ -376,7 +376,7 @@ public partial class QBEGenerator break; } - case BoundUnaryOperator.Invert: + case UnaryOperator.Invert: { switch (unaryExpression.Operand.Type) { @@ -396,7 +396,7 @@ public partial class QBEGenerator throw new NotSupportedException($"Unary operator {unaryExpression.Operator} for type {unaryExpression.Operand.Type} not supported"); } - private Val EmitStructFieldAccess(BoundStructFieldAccess structFieldAccess) + private Val EmitStructFieldAccess(StructFieldAccess structFieldAccess) { var target = EmitUnwrap(EmitExpression(structFieldAccess.Target)); @@ -415,12 +415,12 @@ public partial class QBEGenerator return new Val(output, structFieldAccess.Type, ValKind.Pointer); } - private Val EmitTraitFuncAccess(BoundInterfaceFuncAccess interfaceFuncAccess) + private Val EmitTraitFuncAccess(InterfaceFuncAccess interfaceFuncAccess) { throw new NotImplementedException(); } - private Val EmitFuncCall(BoundFuncCall funcCall) + private Val EmitFuncCall(FuncCall funcCall) { var expression = EmitExpression(funcCall.Expression); var funcPointer = EmitUnwrap(expression); diff --git a/src/compiler/NubLang/Generation/QBE/QBEGenerator.Statement.cs b/src/compiler/NubLang/Generation/QBE/QBEGenerator.Statement.cs index 2b8853a..2b0b6c8 100644 --- a/src/compiler/NubLang/Generation/QBE/QBEGenerator.Statement.cs +++ b/src/compiler/NubLang/Generation/QBE/QBEGenerator.Statement.cs @@ -1,36 +1,36 @@ using System.Diagnostics; -using NubLang.Syntax.Binding.Node; +using NubLang.TypeChecking.Node; namespace NubLang.Generation.QBE; public partial class QBEGenerator { - private void EmitStatement(BoundStatement statement) + private void EmitStatement(Statement statement) { switch (statement) { - case BoundAssignment assignment: + case Assignment assignment: EmitAssignment(assignment); break; - case BoundBreak: + case Break: EmitBreak(); break; - case BoundContinue: + case Continue: EmitContinue(); break; - case BoundIf ifStatement: + case If ifStatement: EmitIf(ifStatement); break; - case BoundReturn @return: + case Return @return: EmitReturn(@return); break; - case BoundStatementExpression statementExpression: + case StatementExpression statementExpression: EmitExpression(statementExpression.Expression); break; - case BoundVariableDeclaration variableDeclaration: + case VariableDeclaration variableDeclaration: EmitVariableDeclaration(variableDeclaration); break; - case BoundWhile whileStatement: + case While whileStatement: EmitWhile(whileStatement); break; default: @@ -38,7 +38,7 @@ public partial class QBEGenerator } } - private void EmitAssignment(BoundAssignment assignment) + private void EmitAssignment(Assignment assignment) { var destination = EmitExpression(assignment.Target); Debug.Assert(destination.Kind == ValKind.Pointer); @@ -57,7 +57,7 @@ public partial class QBEGenerator _codeIsReachable = false; } - private void EmitIf(BoundIf ifStatement) + private void EmitIf(If ifStatement) { var trueLabel = LabelName(); var falseLabel = LabelName(); @@ -81,7 +81,7 @@ public partial class QBEGenerator _writer.WriteLine(endLabel); } - private void EmitReturn(BoundReturn @return) + private void EmitReturn(Return @return) { if (@return.Value.HasValue) { @@ -94,7 +94,7 @@ public partial class QBEGenerator } } - private void EmitVariableDeclaration(BoundVariableDeclaration variableDeclaration) + private void EmitVariableDeclaration(VariableDeclaration variableDeclaration) { var name = $"%{variableDeclaration.Name}"; _writer.Indented($"{name} =l alloc8 8"); @@ -108,7 +108,7 @@ public partial class QBEGenerator Scope.Declare(variableDeclaration.Name, new Val(name, variableDeclaration.Type, ValKind.Pointer)); } - private void EmitWhile(BoundWhile whileStatement) + private void EmitWhile(While whileStatement) { var conditionLabel = LabelName(); var iterationLabel = LabelName(); diff --git a/src/compiler/NubLang/Generation/QBE/QBEGenerator.cs b/src/compiler/NubLang/Generation/QBE/QBEGenerator.cs index a4ff66a..8f0f5ed 100644 --- a/src/compiler/NubLang/Generation/QBE/QBEGenerator.cs +++ b/src/compiler/NubLang/Generation/QBE/QBEGenerator.cs @@ -1,23 +1,23 @@ using System.Diagnostics; using System.Diagnostics.CodeAnalysis; using System.Text; -using NubLang.Syntax.Binding; -using NubLang.Syntax.Binding.Node; -using NubLang.Syntax.Tokenization; +using NubLang.Tokenization; +using NubLang.TypeChecking; +using NubLang.TypeChecking.Node; namespace NubLang.Generation.QBE; public partial class QBEGenerator { - private readonly BoundSyntaxTree _syntaxTree; - private readonly BoundDefinitionTable _definitionTable; + private readonly TypedSyntaxTree _syntaxTree; + private readonly TypedDefinitionTable _definitionTable; private readonly QBEWriter _writer; private readonly List _cStringLiterals = []; private readonly List _stringLiterals = []; private readonly Stack _breakLabels = []; private readonly Stack _continueLabels = []; - private readonly Queue<(BoundArrowFunc Func, string Name)> _arrowFunctions = []; + private readonly Queue<(ArrowFunc Func, string Name)> _arrowFunctions = []; private readonly Stack _scopes = []; private int _tmpIndex; private int _labelIndex; @@ -28,7 +28,7 @@ public partial class QBEGenerator private Scope Scope => _scopes.Peek(); - public QBEGenerator(BoundSyntaxTree syntaxTree, BoundDefinitionTable definitionTable) + public QBEGenerator(TypedSyntaxTree syntaxTree, TypedDefinitionTable definitionTable) { _syntaxTree = syntaxTree; _definitionTable = definitionTable; @@ -62,7 +62,7 @@ public partial class QBEGenerator _writer.NewLine(); } - foreach (var funcDef in _syntaxTree.Definitions.OfType()) + foreach (var funcDef in _syntaxTree.Definitions.OfType()) { EmitFuncDefinition(LocalFuncName(funcDef), funcDef.Signature.Parameters, funcDef.Signature.ReturnType, funcDef.Body); _writer.NewLine(); @@ -191,21 +191,21 @@ public partial class QBEGenerator return size; } - private bool EmitTryMoveInto(BoundExpression source, string destinationPointer) + private bool EmitTryMoveInto(Expression source, string destinationPointer) { switch (source) { - case BoundArrayInitializer arrayInitializer: + case ArrayInitializer arrayInitializer: { EmitStore(source.Type, EmitUnwrap(EmitArrayInitializer(arrayInitializer)), destinationPointer); return true; } - case BoundStructInitializer structInitializer: + case StructInitializer structInitializer: { EmitStructInitializer(structInitializer, destinationPointer); return true; } - case BoundLiteral { Kind: LiteralKind.String } literal: + case Literal { Kind: LiteralKind.String } literal: { EmitStore(source.Type, EmitUnwrap(EmitLiteral(literal)), destinationPointer); return true; @@ -215,7 +215,7 @@ public partial class QBEGenerator return false; } - private void EmitCopyIntoOrInitialize(BoundExpression source, string destinationPointer) + private void EmitCopyIntoOrInitialize(Expression source, string destinationPointer) { // If the source is a value which is not used yet such as an array/struct initializer or literal, we can skip copying if (EmitTryMoveInto(source, destinationPointer)) @@ -253,13 +253,13 @@ public partial class QBEGenerator } } - private bool EmitTryCreateWithoutCopy(BoundExpression source, [NotNullWhen(true)] out string? destination) + private bool EmitTryCreateWithoutCopy(Expression source, [NotNullWhen(true)] out string? destination) { switch (source) { - case BoundArrayInitializer: - case BoundStructInitializer: - case BoundLiteral { Kind: LiteralKind.String }: + case ArrayInitializer: + case StructInitializer: + case Literal { Kind: LiteralKind.String }: { destination = EmitUnwrap(EmitExpression(source)); return true; @@ -270,7 +270,7 @@ public partial class QBEGenerator return false; } - private string EmitCreateCopyOrInitialize(BoundExpression source) + private string EmitCreateCopyOrInitialize(Expression source) { // If the source is a value which is not used yet such as an array/struct initializer or literal, we can skip copying if (EmitTryCreateWithoutCopy(source, out var uncopiedValue)) @@ -328,7 +328,7 @@ public partial class QBEGenerator return "l"; } - private void EmitFuncDefinition(string name, IReadOnlyList parameters, NubType returnType, BoundBlock body) + private void EmitFuncDefinition(string name, IReadOnlyList parameters, NubType returnType, Block body) { _labelIndex = 0; _tmpIndex = 0; @@ -360,7 +360,7 @@ public partial class QBEGenerator EmitBlock(body, scope); // Implicit return for void functions if no explicit return has been set - if (returnType is NubVoidType && body.Statements is [.., not BoundReturn]) + if (returnType is NubVoidType && body.Statements is [.., not Return]) { if (returnType is NubVoidType) { @@ -371,7 +371,7 @@ public partial class QBEGenerator _writer.EndFunction(); } - private void EmitStructDefinition(BoundStruct structDef) + private void EmitStructDefinition(Struct structDef) { _writer.WriteLine($"type {CustomTypeName(structDef.Name)} = {{ "); @@ -392,7 +392,7 @@ public partial class QBEGenerator _writer.WriteLine("}"); return; - string StructDefQBEType(BoundStructField field) + string StructDefQBEType(StructField field) { if (field.Type.IsSimpleType(out var simpleType, out var complexType)) { @@ -417,7 +417,7 @@ public partial class QBEGenerator } } - private void EmitTraitVTable(BoundTrait traitDef) + private void EmitTraitVTable(Trait traitDef) { _writer.WriteLine($"type {CustomTypeName(traitDef.Name)} = {{"); @@ -429,7 +429,7 @@ public partial class QBEGenerator _writer.WriteLine("}"); } - private void EmitBlock(BoundBlock block, Scope? scope = null) + private void EmitBlock(Block block, Scope? scope = null) { _scopes.Push(scope ?? Scope.SubScope()); @@ -456,7 +456,7 @@ public partial class QBEGenerator }; } - private int OffsetOf(BoundStruct structDefinition, string member) + private int OffsetOf(Struct structDefinition, string member) { var offset = 0; @@ -498,12 +498,12 @@ public partial class QBEGenerator return $"$string{++_stringLiteralIndex}"; } - private string LocalFuncName(BoundLocalFunc funcDef) + private string LocalFuncName(LocalFunc funcDef) { return $"${funcDef.Name}"; } - private string ExternFuncName(BoundExternFunc funcDef) + private string ExternFuncName(ExternFunc funcDef) { return $"${funcDef.CallName}"; } diff --git a/src/compiler/NubLang/Generation/TypedDefinitionTable.cs b/src/compiler/NubLang/Generation/TypedDefinitionTable.cs new file mode 100644 index 0000000..54d1dc4 --- /dev/null +++ b/src/compiler/NubLang/Generation/TypedDefinitionTable.cs @@ -0,0 +1,78 @@ +using NubLang.TypeChecking; +using NubLang.TypeChecking.Node; + +namespace NubLang.Generation; + +public sealed class TypedDefinitionTable +{ + private readonly List _definitions; + + public TypedDefinitionTable(IEnumerable syntaxTrees) + { + _definitions = syntaxTrees.SelectMany(x => x.Definitions).ToList(); + } + + public LocalFunc LookupLocalFunc(string name) + { + return _definitions + .OfType() + .First(x => x.Name == name); + } + + public ExternFunc LookupExternFunc(string name) + { + return _definitions + .OfType() + .First(x => x.Name == name); + } + + public Struct LookupStruct(string name) + { + return _definitions + .OfType() + .First(x => x.Name == name); + } + + public StructField LookupStructField(Struct @struct, string field) + { + return @struct.Fields.First(x => x.Name == field); + } + + public IEnumerable LookupTraitImpls(NubType itemType) + { + return _definitions + .OfType() + .Where(x => x.ForType == itemType); + } + + public TraitFuncImpl LookupTraitFuncImpl(NubType forType, string name) + { + return _definitions + .OfType() + .Where(x => x.ForType == forType) + .SelectMany(x => x.Functions) + .First(x => x.Name == name); + } + + public Trait LookupTrait(string name) + { + return _definitions + .OfType() + .First(x => x.Name == name); + } + + public TraitFunc LookupTraitFunc(Trait trait, string name) + { + return trait.Functions.First(x => x.Name == name); + } + + public IEnumerable GetStructs() + { + return _definitions.OfType(); + } + + public IEnumerable GetTraits() + { + return _definitions.OfType(); + } +} \ No newline at end of file diff --git a/src/compiler/NubLang/Syntax/Parsing/Parser.cs b/src/compiler/NubLang/Parsing/Parser.cs similarity index 92% rename from src/compiler/NubLang/Syntax/Parsing/Parser.cs rename to src/compiler/NubLang/Parsing/Parser.cs index bc8dfde..53495ee 100644 --- a/src/compiler/NubLang/Syntax/Parsing/Parser.cs +++ b/src/compiler/NubLang/Parsing/Parser.cs @@ -1,10 +1,10 @@ using System.Diagnostics.CodeAnalysis; using NubLang.Common; using NubLang.Diagnostics; -using NubLang.Syntax.Parsing.Node; -using NubLang.Syntax.Tokenization; +using NubLang.Parsing.Syntax; +using NubLang.Tokenization; -namespace NubLang.Syntax.Parsing; +namespace NubLang.Parsing; public sealed class Parser { @@ -324,57 +324,57 @@ public sealed class Parser return left; } - private int GetBinaryOperatorPrecedence(BinaryOperator @operator) + private int GetBinaryOperatorPrecedence(BinaryOperatorSyntax operatorSyntax) { - return @operator switch + return operatorSyntax switch { - BinaryOperator.Multiply => 3, - BinaryOperator.Divide => 3, - BinaryOperator.Plus => 2, - BinaryOperator.Minus => 2, - BinaryOperator.GreaterThan => 1, - BinaryOperator.GreaterThanOrEqual => 1, - BinaryOperator.LessThan => 1, - BinaryOperator.LessThanOrEqual => 1, - BinaryOperator.Equal => 0, - BinaryOperator.NotEqual => 0, - _ => throw new ArgumentOutOfRangeException(nameof(@operator), @operator, null) + BinaryOperatorSyntax.Multiply => 3, + BinaryOperatorSyntax.Divide => 3, + BinaryOperatorSyntax.Plus => 2, + BinaryOperatorSyntax.Minus => 2, + BinaryOperatorSyntax.GreaterThan => 1, + BinaryOperatorSyntax.GreaterThanOrEqual => 1, + BinaryOperatorSyntax.LessThan => 1, + BinaryOperatorSyntax.LessThanOrEqual => 1, + BinaryOperatorSyntax.Equal => 0, + BinaryOperatorSyntax.NotEqual => 0, + _ => throw new ArgumentOutOfRangeException(nameof(operatorSyntax), operatorSyntax, null) }; } - private bool TryGetBinaryOperator(Symbol symbol, [NotNullWhen(true)] out BinaryOperator? binaryExpressionOperator) + private bool TryGetBinaryOperator(Symbol symbol, [NotNullWhen(true)] out BinaryOperatorSyntax? binaryExpressionOperator) { switch (symbol) { case Symbol.Equal: - binaryExpressionOperator = BinaryOperator.Equal; + binaryExpressionOperator = BinaryOperatorSyntax.Equal; return true; case Symbol.NotEqual: - binaryExpressionOperator = BinaryOperator.NotEqual; + binaryExpressionOperator = BinaryOperatorSyntax.NotEqual; return true; case Symbol.LessThan: - binaryExpressionOperator = BinaryOperator.LessThan; + binaryExpressionOperator = BinaryOperatorSyntax.LessThan; return true; case Symbol.LessThanOrEqual: - binaryExpressionOperator = BinaryOperator.LessThanOrEqual; + binaryExpressionOperator = BinaryOperatorSyntax.LessThanOrEqual; return true; case Symbol.GreaterThan: - binaryExpressionOperator = BinaryOperator.GreaterThan; + binaryExpressionOperator = BinaryOperatorSyntax.GreaterThan; return true; case Symbol.GreaterThanOrEqual: - binaryExpressionOperator = BinaryOperator.GreaterThanOrEqual; + binaryExpressionOperator = BinaryOperatorSyntax.GreaterThanOrEqual; return true; case Symbol.Plus: - binaryExpressionOperator = BinaryOperator.Plus; + binaryExpressionOperator = BinaryOperatorSyntax.Plus; return true; case Symbol.Minus: - binaryExpressionOperator = BinaryOperator.Minus; + binaryExpressionOperator = BinaryOperatorSyntax.Minus; return true; case Symbol.Star: - binaryExpressionOperator = BinaryOperator.Multiply; + binaryExpressionOperator = BinaryOperatorSyntax.Multiply; return true; case Symbol.ForwardSlash: - binaryExpressionOperator = BinaryOperator.Divide; + binaryExpressionOperator = BinaryOperatorSyntax.Divide; return true; default: binaryExpressionOperator = null; @@ -393,8 +393,8 @@ public sealed class Parser { Symbol.Func => ParseArrowFunction(), Symbol.OpenParen => ParseParenthesizedExpression(), - Symbol.Minus => new UnaryExpressionSyntax(UnaryOperator.Negate, ParsePrimaryExpression()), - Symbol.Bang => new UnaryExpressionSyntax(UnaryOperator.Invert, ParsePrimaryExpression()), + Symbol.Minus => new UnaryExpressionSyntax(UnaryOperatorSyntax.Negate, ParsePrimaryExpression()), + Symbol.Bang => new UnaryExpressionSyntax(UnaryOperatorSyntax.Invert, ParsePrimaryExpression()), Symbol.OpenBracket => ParseArrayInitializer(), Symbol.Alloc => ParseStructInitializer(), _ => throw new ParseException(Diagnostic diff --git a/src/compiler/NubLang/Syntax/Parsing/Node/DefinitionSyntax.cs b/src/compiler/NubLang/Parsing/Syntax/DefinitionSyntax.cs similarity index 98% rename from src/compiler/NubLang/Syntax/Parsing/Node/DefinitionSyntax.cs rename to src/compiler/NubLang/Parsing/Syntax/DefinitionSyntax.cs index 82b5b07..0ef71ee 100644 --- a/src/compiler/NubLang/Syntax/Parsing/Node/DefinitionSyntax.cs +++ b/src/compiler/NubLang/Parsing/Syntax/DefinitionSyntax.cs @@ -1,6 +1,6 @@ using NubLang.Common; -namespace NubLang.Syntax.Parsing.Node; +namespace NubLang.Parsing.Syntax; public abstract record DefinitionSyntax : SyntaxNode; diff --git a/src/compiler/NubLang/Syntax/Parsing/Node/ExpressionSyntax.cs b/src/compiler/NubLang/Parsing/Syntax/ExpressionSyntax.cs similarity index 90% rename from src/compiler/NubLang/Syntax/Parsing/Node/ExpressionSyntax.cs rename to src/compiler/NubLang/Parsing/Syntax/ExpressionSyntax.cs index 2defc18..e715314 100644 --- a/src/compiler/NubLang/Syntax/Parsing/Node/ExpressionSyntax.cs +++ b/src/compiler/NubLang/Parsing/Syntax/ExpressionSyntax.cs @@ -1,15 +1,14 @@ -using NubLang.Common; -using NubLang.Syntax.Tokenization; +using NubLang.Tokenization; -namespace NubLang.Syntax.Parsing.Node; +namespace NubLang.Parsing.Syntax; -public enum UnaryOperator +public enum UnaryOperatorSyntax { Negate, Invert } -public enum BinaryOperator +public enum BinaryOperatorSyntax { Equal, NotEqual, @@ -25,7 +24,7 @@ public enum BinaryOperator public abstract record ExpressionSyntax : SyntaxNode; -public record BinaryExpressionSyntax(ExpressionSyntax Left, BinaryOperator Operator, ExpressionSyntax Right) : ExpressionSyntax +public record BinaryExpressionSyntax(ExpressionSyntax Left, BinaryOperatorSyntax OperatorSyntax, ExpressionSyntax Right) : ExpressionSyntax { public override IEnumerable GetChildren() { @@ -34,7 +33,7 @@ public record BinaryExpressionSyntax(ExpressionSyntax Left, BinaryOperator Opera } } -public record UnaryExpressionSyntax(UnaryOperator Operator, ExpressionSyntax Operand) : ExpressionSyntax +public record UnaryExpressionSyntax(UnaryOperatorSyntax OperatorSyntax, ExpressionSyntax Operand) : ExpressionSyntax { public override IEnumerable GetChildren() { diff --git a/src/compiler/NubLang/Syntax/Parsing/Node/StatementSyntax.cs b/src/compiler/NubLang/Parsing/Syntax/StatementSyntax.cs similarity index 98% rename from src/compiler/NubLang/Syntax/Parsing/Node/StatementSyntax.cs rename to src/compiler/NubLang/Parsing/Syntax/StatementSyntax.cs index 61c23e3..4716e5e 100644 --- a/src/compiler/NubLang/Syntax/Parsing/Node/StatementSyntax.cs +++ b/src/compiler/NubLang/Parsing/Syntax/StatementSyntax.cs @@ -1,6 +1,6 @@ using NubLang.Common; -namespace NubLang.Syntax.Parsing.Node; +namespace NubLang.Parsing.Syntax; public abstract record StatementSyntax : SyntaxNode; diff --git a/src/compiler/NubLang/Syntax/Parsing/Node/SyntaxNode.cs b/src/compiler/NubLang/Parsing/Syntax/SyntaxNode.cs similarity index 95% rename from src/compiler/NubLang/Syntax/Parsing/Node/SyntaxNode.cs rename to src/compiler/NubLang/Parsing/Syntax/SyntaxNode.cs index 4eaf513..76ab7f1 100644 --- a/src/compiler/NubLang/Syntax/Parsing/Node/SyntaxNode.cs +++ b/src/compiler/NubLang/Parsing/Syntax/SyntaxNode.cs @@ -1,4 +1,4 @@ -namespace NubLang.Syntax.Parsing.Node; +namespace NubLang.Parsing.Syntax; public abstract record SyntaxNode { diff --git a/src/compiler/NubLang/Syntax/Parsing/Node/TypeSyntax.cs b/src/compiler/NubLang/Parsing/Syntax/TypeSyntax.cs similarity index 98% rename from src/compiler/NubLang/Syntax/Parsing/Node/TypeSyntax.cs rename to src/compiler/NubLang/Parsing/Syntax/TypeSyntax.cs index c70d371..43a316b 100644 --- a/src/compiler/NubLang/Syntax/Parsing/Node/TypeSyntax.cs +++ b/src/compiler/NubLang/Parsing/Syntax/TypeSyntax.cs @@ -1,4 +1,4 @@ -namespace NubLang.Syntax.Parsing.Node; +namespace NubLang.Parsing.Syntax; public enum PrimitiveTypeSyntaxKind { diff --git a/src/compiler/NubLang/Syntax/Binding/Binder.cs b/src/compiler/NubLang/Syntax/Binding/Binder.cs deleted file mode 100644 index c5ee868..0000000 --- a/src/compiler/NubLang/Syntax/Binding/Binder.cs +++ /dev/null @@ -1,637 +0,0 @@ -using NubLang.Common; -using NubLang.Diagnostics; -using NubLang.Syntax.Binding.Node; -using NubLang.Syntax.Parsing.Node; -using NubLang.Syntax.Tokenization; - -namespace NubLang.Syntax.Binding; - -public sealed class Binder -{ - private readonly SyntaxTree _syntaxTree; - private readonly DefinitionTable _definitionTable; - - private readonly Stack _scopes = []; - private readonly Stack _funcReturnTypes = []; - - private Scope Scope => _scopes.Peek(); - - public Binder(SyntaxTree syntaxTree, DefinitionTable definitionTable) - { - _syntaxTree = syntaxTree; - _definitionTable = definitionTable; - } - - public BoundSyntaxTree Bind() - { - _funcReturnTypes.Clear(); - _scopes.Clear(); - - var diagnostics = new List(); - var definitions = new List(); - - foreach (var definition in _syntaxTree.Definitions) - { - try - { - definitions.Add(BindDefinition(definition)); - } - catch (BindException e) - { - diagnostics.Add(e.Diagnostic); - } - } - - return new BoundSyntaxTree(definitions, diagnostics); - } - - private BoundDefinition BindDefinition(DefinitionSyntax node) - { - return node switch - { - ExternFuncSyntax definition => BindExternFuncDefinition(definition), - InterfaceSyntax definition => BindTraitDefinition(definition), - LocalFuncSyntax definition => BindLocalFuncDefinition(definition), - StructSyntax definition => BindStruct(definition), - _ => throw new ArgumentOutOfRangeException(nameof(node)) - }; - } - - private BoundTrait BindTraitDefinition(InterfaceSyntax node) - { - var functions = new List(); - - foreach (var function in node.Functions) - { - functions.Add(new BoundTraitFunc(function.Name, BindFuncSignature(function.Signature))); - } - - return new BoundTrait(node.Name, functions); - } - - private BoundStruct BindStruct(StructSyntax node) - { - var structFields = new List(); - - foreach (var field in node.Fields) - { - var value = Optional.Empty(); - - if (field.Value.HasValue) - { - value = BindExpression(field.Value.Value, BindType(field.Type)); - } - - structFields.Add(new BoundStructField(field.Index, field.Name, BindType(field.Type), value)); - } - - return new BoundStruct(node.Name, structFields); - } - - private BoundExternFunc BindExternFuncDefinition(ExternFuncSyntax node) - { - return new BoundExternFunc(node.Name, node.CallName, BindFuncSignature(node.Signature)); - } - - private BoundLocalFunc BindLocalFuncDefinition(LocalFuncSyntax node) - { - var signature = BindFuncSignature(node.Signature); - var body = BindFuncBody(node.Body, signature.ReturnType, signature.Parameters); - - return new BoundLocalFunc(node.Name, signature, body); - } - - private BoundStatement BindStatement(StatementSyntax node) - { - return node switch - { - AssignmentSyntax statement => BindAssignment(statement), - BreakSyntax => new BoundBreak(), - ContinueSyntax => new BoundContinue(), - IfSyntax statement => BindIf(statement), - ReturnSyntax statement => BindReturn(statement), - StatementExpressionSyntax statement => BindStatementExpression(statement), - VariableDeclarationSyntax statement => BindVariableDeclaration(statement), - WhileSyntax statement => BindWhile(statement), - _ => throw new ArgumentOutOfRangeException(nameof(node)) - }; - } - - private BoundStatement BindAssignment(AssignmentSyntax statement) - { - var expression = BindExpression(statement.Target); - var value = BindExpression(statement.Value, expression.Type); - return new BoundAssignment(expression, value); - } - - private BoundIf BindIf(IfSyntax statement) - { - var elseStatement = Optional.Empty>(); - - if (statement.Else.HasValue) - { - elseStatement = statement.Else.Value.Match> - ( - elseIf => BindIf(elseIf), - @else => BindBlock(@else) - ); - } - - return new BoundIf(BindExpression(statement.Condition, new NubPrimitiveType(PrimitiveTypeKind.Bool)), BindBlock(statement.Body), elseStatement); - } - - private BoundReturn BindReturn(ReturnSyntax statement) - { - var value = Optional.Empty(); - - if (statement.Value.HasValue) - { - value = BindExpression(statement.Value.Value, _funcReturnTypes.Peek()); - } - - return new BoundReturn(value); - } - - private BoundStatementExpression BindStatementExpression(StatementExpressionSyntax statement) - { - return new BoundStatementExpression(BindExpression(statement.Expression)); - } - - private BoundVariableDeclaration BindVariableDeclaration(VariableDeclarationSyntax statement) - { - NubType? type = null; - - if (statement.ExplicitType.HasValue) - { - type = BindType(statement.ExplicitType.Value); - } - - var assignment = Optional.Empty(); - if (statement.Assignment.HasValue) - { - var boundValue = BindExpression(statement.Assignment.Value, type); - assignment = boundValue; - type = boundValue.Type; - } - - if (type == null) - { - throw new NotImplementedException("Diagnostics not implemented"); - } - - Scope.Declare(new Variable(statement.Name, type)); - - return new BoundVariableDeclaration(statement.Name, assignment, type); - } - - private BoundWhile BindWhile(WhileSyntax statement) - { - return new BoundWhile(BindExpression(statement.Condition, new NubPrimitiveType(PrimitiveTypeKind.Bool)), BindBlock(statement.Body)); - } - - private BoundExpression BindExpression(ExpressionSyntax node, NubType? expectedType = null) - { - return node switch - { - AddressOfSyntax expression => BindAddressOf(expression), - ArrowFuncSyntax expression => BindArrowFunc(expression, expectedType), - ArrayIndexAccessSyntax expression => BindArrayIndexAccess(expression), - ArrayInitializerSyntax expression => BindArrayInitializer(expression), - BinaryExpressionSyntax expression => BindBinaryExpression(expression), - DereferenceSyntax expression => BindDereference(expression), - FuncCallSyntax expression => BindFuncCall(expression), - IdentifierSyntax expression => BindIdentifier(expression), - LiteralSyntax expression => BindLiteral(expression, expectedType), - MemberAccessSyntax expression => BindMemberAccess(expression), - StructInitializerSyntax expression => BindStructInitializer(expression), - UnaryExpressionSyntax expression => BindUnaryExpression(expression), - _ => throw new ArgumentOutOfRangeException(nameof(node)) - }; - } - - private BoundAddressOf BindAddressOf(AddressOfSyntax expression) - { - var inner = BindExpression(expression.Expression); - return new BoundAddressOf(new NubPointerType(inner.Type), inner); - } - - private BoundArrowFunc BindArrowFunc(ArrowFuncSyntax expression, NubType? expectedType = null) - { - if (expectedType == null) - { - throw new BindException(Diagnostic.Error("Cannot infer argument types for arrow function").Build()); - } - - if (expectedType is not NubFuncType funcType) - { - throw new BindException(Diagnostic.Error($"Expected {expectedType}, but got arrow function").Build()); - } - - var parameters = new List(); - - for (var i = 0; i < expression.Parameters.Count; i++) - { - if (i >= funcType.Parameters.Count) - { - throw new BindException(Diagnostic.Error($"Arrow function expected a maximum of {funcType.Parameters.Count} arguments").Build()); - } - - var expectedParameterType = funcType.Parameters[i]; - var parameter = expression.Parameters[i]; - parameters.Add(new BoundFuncParameter(parameter.Name, expectedParameterType)); - } - - var body = BindFuncBody(expression.Body, funcType.ReturnType, parameters); - - return new BoundArrowFunc(new NubFuncType(parameters.Select(x => x.Type).ToList(), funcType.ReturnType), parameters, funcType.ReturnType, body); - } - - private BoundArrayIndexAccess BindArrayIndexAccess(ArrayIndexAccessSyntax expression) - { - var boundArray = BindExpression(expression.Target); - var elementType = ((NubArrayType)boundArray.Type).ElementType; - return new BoundArrayIndexAccess(elementType, boundArray, BindExpression(expression.Index, new NubPrimitiveType(PrimitiveTypeKind.U64))); - } - - private BoundArrayInitializer BindArrayInitializer(ArrayInitializerSyntax expression) - { - var capacity = BindExpression(expression.Capacity, new NubPrimitiveType(PrimitiveTypeKind.U64)); - var type = new NubArrayType(BindType(expression.ElementType)); - return new BoundArrayInitializer(type, capacity, BindType(expression.ElementType)); - } - - private BoundBinaryExpression BindBinaryExpression(BinaryExpressionSyntax expression) - { - var boundLeft = BindExpression(expression.Left); - var boundRight = BindExpression(expression.Right, boundLeft.Type); - return new BoundBinaryExpression(boundLeft.Type, boundLeft, BindBinaryOperator(expression.Operator), boundRight); - } - - private BoundDereference BindDereference(DereferenceSyntax expression) - { - var boundExpression = BindExpression(expression.Expression); - var dereferencedType = ((NubPointerType)boundExpression.Type).BaseType; - return new BoundDereference(dereferencedType, boundExpression); - } - - private BoundFuncCall BindFuncCall(FuncCallSyntax expression) - { - var boundExpression = BindExpression(expression.Expression); - - var funcType = (NubFuncType)boundExpression.Type; - - var parameters = new List(); - - foreach (var (i, parameter) in expression.Parameters.Index()) - { - if (i >= funcType.Parameters.Count) - { - throw new NotImplementedException("Diagnostics not implemented"); - } - - var expectedType = funcType.Parameters[i]; - - parameters.Add(BindExpression(parameter, expectedType)); - } - - return new BoundFuncCall(funcType.ReturnType, boundExpression, parameters); - } - - private BoundExpression BindIdentifier(IdentifierSyntax expression) - { - var variable = Scope.Lookup(expression.Name); - if (variable != null) - { - return new BoundVariableIdent(variable.Type, variable.Name); - } - - var localFuncs = _definitionTable.LookupLocalFunc(expression.Name).ToArray(); - if (localFuncs.Length > 0) - { - if (localFuncs.Length > 1) - { - throw new BindException(Diagnostic.Error($"Extern func {expression.Name} has multiple definitions").Build()); - } - - var localFunc = localFuncs[0]; - - var returnType = BindType(localFunc.Signature.ReturnType); - var parameterTypes = localFunc.Signature.Parameters.Select(p => BindType(p.Type)).ToList(); - var type = new NubFuncType(parameterTypes, returnType); - return new BoundLocalFuncIdent(type, expression.Name); - } - - var externFuncs = _definitionTable.LookupExternFunc(expression.Name).ToArray(); - if (externFuncs.Length > 0) - { - if (externFuncs.Length > 1) - { - throw new BindException(Diagnostic.Error($"Extern func {expression.Name} has multiple definitions").Build()); - } - - var externFunc = externFuncs[0]; - - var returnType = BindType(externFunc.Signature.ReturnType); - var parameterTypes = externFunc.Signature.Parameters.Select(p => BindType(p.Type)).ToList(); - var type = new NubFuncType(parameterTypes, returnType); - return new BoundExternFuncIdent(type, expression.Name); - } - - throw new BindException(Diagnostic.Error($"No identifier with the name {expression.Name} exists").Build()); - } - - private BoundLiteral BindLiteral(LiteralSyntax expression, NubType? expectedType = null) - { - var type = expectedType ?? expression.Kind switch - { - LiteralKind.Integer => new NubPrimitiveType(PrimitiveTypeKind.I64), - LiteralKind.Float => new NubPrimitiveType(PrimitiveTypeKind.F64), - LiteralKind.String => new NubStringType(), - LiteralKind.Bool => new NubPrimitiveType(PrimitiveTypeKind.Bool), - _ => throw new ArgumentOutOfRangeException() - }; - - return new BoundLiteral(type, expression.Value, expression.Kind); - } - - private BoundExpression BindMemberAccess(MemberAccessSyntax expression) - { - var boundExpression = BindExpression(expression.Target); - - if (boundExpression.Type is NubCustomType customType) - { - var traits = _definitionTable.LookupTrait(customType).ToArray(); - if (traits.Length > 0) - { - if (traits.Length > 1) - { - throw new BindException(Diagnostic.Error($"Trait {customType} has multiple definitions").Build()); - } - - var trait = traits[0]; - - var traitFuncs = _definitionTable.LookupTraitFunc(trait, expression.Member).ToArray(); - if (traits.Length > 0) - { - if (traits.Length > 1) - { - throw new BindException(Diagnostic.Error($"Trait {customType} has multiple functions with the name {expression.Member}").Build()); - } - - var traitFunc = traitFuncs[0]; - - var returnType = BindType(traitFunc.Signature.ReturnType); - var parameterTypes = traitFunc.Signature.Parameters.Select(p => BindType(p.Type)).ToList(); - var type = new NubFuncType(parameterTypes, returnType); - return new BoundInterfaceFuncAccess(type, customType, boundExpression, expression.Member); - } - } - - var structs = _definitionTable.LookupStruct(customType).ToArray(); - if (structs.Length > 0) - { - if (structs.Length > 1) - { - throw new BindException(Diagnostic.Error($"Struct {customType} has multiple definitions").Build()); - } - - var @struct = structs[0]; - - var fields = _definitionTable.LookupStructField(@struct, expression.Member).ToArray(); - if (fields.Length > 0) - { - if (fields.Length > 1) - { - throw new BindException(Diagnostic.Error($"Struct {customType} has multiple fields with the name {expression.Member}").Build()); - } - - var field = fields[0]; - - return new BoundStructFieldAccess(BindType(field.Type), customType, boundExpression, expression.Member); - } - } - } - - throw new BindException(Diagnostic.Error($"{boundExpression.Type} does not have a member with the name {expression.Member}").Build()); - } - - private BoundStructInitializer BindStructInitializer(StructInitializerSyntax expression) - { - var boundType = BindType(expression.StructType); - - if (boundType is not NubCustomType structType) - { - throw new BindException(Diagnostic.Error($"Cannot initialize non-struct type {expression.StructType}").Build()); - } - - var structs = _definitionTable.LookupStruct(structType).ToArray(); - - if (structs.Length == 0) - { - throw new BindException(Diagnostic.Error($"Struct {structType} is not defined").Build()); - } - - if (structs.Length > 1) - { - throw new BindException(Diagnostic.Error($"Struct {structType} has multiple definitions").Build()); - } - - var @struct = structs[0]; - - var initializers = new Dictionary(); - - foreach (var (field, initializer) in expression.Initializers) - { - var fields = _definitionTable.LookupStructField(@struct, field).ToArray(); - - if (fields.Length == 0) - { - throw new BindException(Diagnostic.Error($"Struct {structType} does not have a field with the name {field}").Build()); - } - - if (fields.Length > 1) - { - throw new BindException(Diagnostic.Error($"Struct {structType} has multiple fields with the name {field}").Build()); - } - - initializers[field] = BindExpression(initializer, BindType(fields[0].Type)); - } - - return new BoundStructInitializer(structType, initializers); - } - - private BoundUnaryExpression BindUnaryExpression(UnaryExpressionSyntax expression) - { - var boundOperand = BindExpression(expression.Operand); - - NubType? type = null; - - switch (expression.Operator) - { - case UnaryOperator.Negate: - { - boundOperand = BindExpression(expression.Operand, new NubPrimitiveType(PrimitiveTypeKind.I64)); - - if (boundOperand.Type.IsNumber) - { - type = boundOperand.Type; - } - - break; - } - case UnaryOperator.Invert: - { - boundOperand = BindExpression(expression.Operand, new NubPrimitiveType(PrimitiveTypeKind.Bool)); - - type = new NubPrimitiveType(PrimitiveTypeKind.Bool); - break; - } - } - - if (type == null) - { - throw new NotImplementedException("Diagnostics not implemented"); - } - - return new BoundUnaryExpression(type, BindBinaryOperator(expression.Operator), boundOperand); - } - - private BoundFuncSignature BindFuncSignature(FuncSignatureSyntax node) - { - var parameters = new List(); - - foreach (var parameter in node.Parameters) - { - parameters.Add(new BoundFuncParameter(parameter.Name, BindType(parameter.Type))); - } - - return new BoundFuncSignature(parameters, BindType(node.ReturnType)); - } - - private BoundBinaryOperator BindBinaryOperator(BinaryOperator op) - { - return op switch - { - BinaryOperator.Equal => BoundBinaryOperator.Equal, - BinaryOperator.NotEqual => BoundBinaryOperator.NotEqual, - BinaryOperator.GreaterThan => BoundBinaryOperator.GreaterThan, - BinaryOperator.GreaterThanOrEqual => BoundBinaryOperator.GreaterThanOrEqual, - BinaryOperator.LessThan => BoundBinaryOperator.LessThan, - BinaryOperator.LessThanOrEqual => BoundBinaryOperator.LessThanOrEqual, - BinaryOperator.Plus => BoundBinaryOperator.Plus, - BinaryOperator.Minus => BoundBinaryOperator.Minus, - BinaryOperator.Multiply => BoundBinaryOperator.Multiply, - BinaryOperator.Divide => BoundBinaryOperator.Divide, - _ => throw new ArgumentOutOfRangeException(nameof(op), op, null) - }; - } - - private BoundUnaryOperator BindBinaryOperator(UnaryOperator op) - { - return op switch - { - UnaryOperator.Negate => BoundUnaryOperator.Negate, - UnaryOperator.Invert => BoundUnaryOperator.Invert, - _ => throw new ArgumentOutOfRangeException(nameof(op), op, null) - }; - } - - private BoundBlock BindBlock(BlockSyntax node, Scope? scope = null) - { - var statements = new List(); - - _scopes.Push(scope ?? Scope.SubScope()); - - foreach (var statement in node.Statements) - { - statements.Add(BindStatement(statement)); - } - - _scopes.Pop(); - - return new BoundBlock(statements); - } - - private BoundBlock BindFuncBody(BlockSyntax block, NubType returnType, IReadOnlyList parameters) - { - _funcReturnTypes.Push(returnType); - - var scope = new Scope(); - foreach (var parameter in parameters) - { - scope.Declare(new Variable(parameter.Name, parameter.Type)); - } - - var body = BindBlock(block, scope); - _funcReturnTypes.Pop(); - return body; - } - - private NubType BindType(TypeSyntax node) - { - return node switch - { - ArrayTypeSyntax type => new NubArrayType(BindType(type.BaseType)), - CStringTypeSyntax => new NubCStringType(), - CustomTypeSyntax type => new NubCustomType(type.MangledName()), - FuncTypeSyntax type => new NubFuncType(type.Parameters.Select(BindType).ToList(), BindType(type.ReturnType)), - PointerTypeSyntax type => new NubPointerType(BindType(type.BaseType)), - PrimitiveTypeSyntax type => new NubPrimitiveType(type.SyntaxKind switch - { - PrimitiveTypeSyntaxKind.I64 => PrimitiveTypeKind.I64, - PrimitiveTypeSyntaxKind.I32 => PrimitiveTypeKind.I32, - PrimitiveTypeSyntaxKind.I16 => PrimitiveTypeKind.I16, - PrimitiveTypeSyntaxKind.I8 => PrimitiveTypeKind.I8, - PrimitiveTypeSyntaxKind.U64 => PrimitiveTypeKind.U64, - PrimitiveTypeSyntaxKind.U32 => PrimitiveTypeKind.U32, - PrimitiveTypeSyntaxKind.U16 => PrimitiveTypeKind.U16, - PrimitiveTypeSyntaxKind.U8 => PrimitiveTypeKind.U8, - PrimitiveTypeSyntaxKind.F64 => PrimitiveTypeKind.F64, - PrimitiveTypeSyntaxKind.F32 => PrimitiveTypeKind.F32, - PrimitiveTypeSyntaxKind.Bool => PrimitiveTypeKind.Bool, - _ => throw new ArgumentOutOfRangeException() - }), - StringTypeSyntax => new NubStringType(), - VoidTypeSyntax => new NubVoidType(), - _ => throw new ArgumentOutOfRangeException(nameof(node)) - }; - } -} - -public record Variable(string Name, NubType Type); - -public class Scope(Scope? parent = null) -{ - private readonly List _variables = []; - - public Variable? Lookup(string name) - { - var variable = _variables.FirstOrDefault(x => x.Name == name); - if (variable != null) - { - return variable; - } - - return parent?.Lookup(name); - } - - public void Declare(Variable variable) - { - _variables.Add(variable); - } - - public Scope SubScope() - { - return new Scope(this); - } -} - -public class BindException : Exception -{ - public Diagnostic Diagnostic { get; } - - public BindException(Diagnostic diagnostic) : base(diagnostic.Message) - { - Diagnostic = diagnostic; - } -} \ No newline at end of file diff --git a/src/compiler/NubLang/Syntax/Binding/Node/BoundDefinition.cs b/src/compiler/NubLang/Syntax/Binding/Node/BoundDefinition.cs deleted file mode 100644 index 922a9e8..0000000 --- a/src/compiler/NubLang/Syntax/Binding/Node/BoundDefinition.cs +++ /dev/null @@ -1,25 +0,0 @@ -using NubLang.Common; - -namespace NubLang.Syntax.Binding.Node; - -public abstract record BoundDefinition : BoundNode; - -public record BoundFuncParameter(string Name, NubType Type) : BoundNode; - -public record BoundFuncSignature(IReadOnlyList Parameters, NubType ReturnType) : BoundNode; - -public record BoundLocalFunc(string Name, BoundFuncSignature Signature, BoundBlock Body) : BoundDefinition; - -public record BoundExternFunc(string Name, string CallName, BoundFuncSignature Signature) : BoundDefinition; - -public record BoundStructField(int Index, string Name, NubType Type, Optional Value) : BoundNode; - -public record BoundStruct(string Name, IReadOnlyList Fields) : BoundDefinition; - -public record BoundTraitFunc(string Name, BoundFuncSignature Signature) : BoundNode; - -public record BoundTrait(string Name, IReadOnlyList Functions) : BoundDefinition; - -public record BoundTraitFuncImpl(string Name, BoundFuncSignature Signature, BoundBlock Body) : BoundNode; - -public record BoundTraitImpl(NubType TraitType, NubType ForType, IReadOnlyList Functions) : BoundDefinition; \ No newline at end of file diff --git a/src/compiler/NubLang/Syntax/Binding/Node/BoundExpression.cs b/src/compiler/NubLang/Syntax/Binding/Node/BoundExpression.cs deleted file mode 100644 index d5e3796..0000000 --- a/src/compiler/NubLang/Syntax/Binding/Node/BoundExpression.cs +++ /dev/null @@ -1,55 +0,0 @@ -using NubLang.Syntax.Tokenization; - -namespace NubLang.Syntax.Binding.Node; - -public enum BoundUnaryOperator -{ - Negate, - Invert -} - -public enum BoundBinaryOperator -{ - Equal, - NotEqual, - GreaterThan, - GreaterThanOrEqual, - LessThan, - LessThanOrEqual, - Plus, - Minus, - Multiply, - Divide -} - -public abstract record BoundExpression(NubType Type) : BoundNode; - -public record BoundBinaryExpression(NubType Type, BoundExpression Left, BoundBinaryOperator Operator, BoundExpression Right) : BoundExpression(Type); - -public record BoundUnaryExpression(NubType Type, BoundUnaryOperator Operator, BoundExpression Operand) : BoundExpression(Type); - -public record BoundFuncCall(NubType Type, BoundExpression Expression, IReadOnlyList Parameters) : BoundExpression(Type); - -public record BoundVariableIdent(NubType Type, string Name) : BoundExpression(Type); - -public record BoundLocalFuncIdent(NubType Type, string Name) : BoundExpression(Type); - -public record BoundExternFuncIdent(NubType Type, string Name) : BoundExpression(Type); - -public record BoundArrayInitializer(NubType Type, BoundExpression Capacity, NubType ElementType) : BoundExpression(Type); - -public record BoundArrayIndexAccess(NubType Type, BoundExpression Target, BoundExpression Index) : BoundExpression(Type); - -public record BoundArrowFunc(NubType Type, IReadOnlyList Parameters, NubType ReturnType, BoundBlock Body) : BoundExpression(Type); - -public record BoundAddressOf(NubType Type, BoundExpression Expression) : BoundExpression(Type); - -public record BoundLiteral(NubType Type, string Literal, LiteralKind Kind) : BoundExpression(Type); - -public record BoundStructFieldAccess(NubType Type, NubCustomType StructType, BoundExpression Target, string Field) : BoundExpression(Type); - -public record BoundInterfaceFuncAccess(NubType Type, NubCustomType InterfaceType, BoundExpression Target, string FuncName) : BoundExpression(Type); - -public record BoundStructInitializer(NubCustomType StructType, Dictionary Initializers) : BoundExpression(StructType); - -public record BoundDereference(NubType Type, BoundExpression Expression) : BoundExpression(Type); \ No newline at end of file diff --git a/src/compiler/NubLang/Syntax/Binding/Node/BoundStatement.cs b/src/compiler/NubLang/Syntax/Binding/Node/BoundStatement.cs deleted file mode 100644 index 4d1474b..0000000 --- a/src/compiler/NubLang/Syntax/Binding/Node/BoundStatement.cs +++ /dev/null @@ -1,21 +0,0 @@ -using NubLang.Common; - -namespace NubLang.Syntax.Binding.Node; - -public record BoundStatement : BoundNode; - -public record BoundStatementExpression(BoundExpression Expression) : BoundStatement; - -public record BoundReturn(Optional Value) : BoundStatement; - -public record BoundAssignment(BoundExpression Target, BoundExpression Value) : BoundStatement; - -public record BoundIf(BoundExpression Condition, BoundBlock Body, Optional> Else) : BoundStatement; - -public record BoundVariableDeclaration(string Name, Optional Assignment, NubType Type) : BoundStatement; - -public record BoundContinue : BoundStatement; - -public record BoundBreak : BoundStatement; - -public record BoundWhile(BoundExpression Condition, BoundBlock Body) : BoundStatement; \ No newline at end of file diff --git a/src/compiler/NubLang/Syntax/Binding/Node/BoundSyntaxTree.cs b/src/compiler/NubLang/Syntax/Binding/Node/BoundSyntaxTree.cs deleted file mode 100644 index 7867e39..0000000 --- a/src/compiler/NubLang/Syntax/Binding/Node/BoundSyntaxTree.cs +++ /dev/null @@ -1,9 +0,0 @@ -using NubLang.Diagnostics; - -namespace NubLang.Syntax.Binding.Node; - -public record BoundSyntaxTree(IReadOnlyList Definitions, IReadOnlyList Diagnostics); - -public abstract record BoundNode; - -public record BoundBlock(IReadOnlyList Statements) : BoundNode; \ No newline at end of file diff --git a/src/compiler/NubLang/Syntax/Tokenization/Token.cs b/src/compiler/NubLang/Tokenization/Token.cs similarity index 95% rename from src/compiler/NubLang/Syntax/Tokenization/Token.cs rename to src/compiler/NubLang/Tokenization/Token.cs index a3ad076..8810681 100644 --- a/src/compiler/NubLang/Syntax/Tokenization/Token.cs +++ b/src/compiler/NubLang/Tokenization/Token.cs @@ -1,4 +1,4 @@ -namespace NubLang.Syntax.Tokenization; +namespace NubLang.Tokenization; public abstract class Token; diff --git a/src/compiler/NubLang/Syntax/Tokenization/Tokenizer.cs b/src/compiler/NubLang/Tokenization/Tokenizer.cs similarity index 99% rename from src/compiler/NubLang/Syntax/Tokenization/Tokenizer.cs rename to src/compiler/NubLang/Tokenization/Tokenizer.cs index 3df4b06..43ed3c6 100644 --- a/src/compiler/NubLang/Syntax/Tokenization/Tokenizer.cs +++ b/src/compiler/NubLang/Tokenization/Tokenizer.cs @@ -1,6 +1,6 @@ using NubLang.Common; -namespace NubLang.Syntax.Tokenization; +namespace NubLang.Tokenization; public sealed class Tokenizer { diff --git a/src/compiler/NubLang/Syntax/Binding/DefinitionTable.cs b/src/compiler/NubLang/TypeChecking/DefinitionTable.cs similarity index 95% rename from src/compiler/NubLang/Syntax/Binding/DefinitionTable.cs rename to src/compiler/NubLang/TypeChecking/DefinitionTable.cs index 4513060..bffb214 100644 --- a/src/compiler/NubLang/Syntax/Binding/DefinitionTable.cs +++ b/src/compiler/NubLang/TypeChecking/DefinitionTable.cs @@ -1,6 +1,6 @@ -using NubLang.Syntax.Parsing.Node; +using NubLang.Parsing.Syntax; -namespace NubLang.Syntax.Binding; +namespace NubLang.TypeChecking; public class DefinitionTable { diff --git a/src/compiler/NubLang/TypeChecking/Node/Definition.cs b/src/compiler/NubLang/TypeChecking/Node/Definition.cs new file mode 100644 index 0000000..2c8ec82 --- /dev/null +++ b/src/compiler/NubLang/TypeChecking/Node/Definition.cs @@ -0,0 +1,25 @@ +using NubLang.Common; + +namespace NubLang.TypeChecking.Node; + +public abstract record Definition : Node; + +public record FuncParameter(string Name, NubType Type) : Node; + +public record FuncSignature(IReadOnlyList Parameters, NubType ReturnType) : Node; + +public record LocalFunc(string Name, FuncSignature Signature, Block Body) : Definition; + +public record ExternFunc(string Name, string CallName, FuncSignature Signature) : Definition; + +public record StructField(int Index, string Name, NubType Type, Optional Value) : Node; + +public record Struct(string Name, IReadOnlyList Fields) : Definition; + +public record TraitFunc(string Name, FuncSignature Signature) : Node; + +public record Trait(string Name, IReadOnlyList Functions) : Definition; + +public record TraitFuncImpl(string Name, FuncSignature Signature, Block Body) : Node; + +public record TraitImpl(NubType TraitType, NubType ForType, IReadOnlyList Functions) : Definition; \ No newline at end of file diff --git a/src/compiler/NubLang/TypeChecking/Node/Expression.cs b/src/compiler/NubLang/TypeChecking/Node/Expression.cs new file mode 100644 index 0000000..1cf2a39 --- /dev/null +++ b/src/compiler/NubLang/TypeChecking/Node/Expression.cs @@ -0,0 +1,55 @@ +using NubLang.Tokenization; + +namespace NubLang.TypeChecking.Node; + +public enum UnaryOperator +{ + Negate, + Invert +} + +public enum BinaryOperator +{ + Equal, + NotEqual, + GreaterThan, + GreaterThanOrEqual, + LessThan, + LessThanOrEqual, + Plus, + Minus, + Multiply, + Divide +} + +public abstract record Expression(NubType Type) : Node; + +public record BinaryExpression(NubType Type, Expression Left, BinaryOperator Operator, Expression Right) : Expression(Type); + +public record UnaryExpression(NubType Type, UnaryOperator Operator, Expression Operand) : Expression(Type); + +public record FuncCall(NubType Type, Expression Expression, IReadOnlyList Parameters) : Expression(Type); + +public record VariableIdent(NubType Type, string Name) : Expression(Type); + +public record LocalFuncIdent(NubType Type, string Name) : Expression(Type); + +public record ExternFuncIdent(NubType Type, string Name) : Expression(Type); + +public record ArrayInitializer(NubType Type, Expression Capacity, NubType ElementType) : Expression(Type); + +public record ArrayIndexAccess(NubType Type, Expression Target, Expression Index) : Expression(Type); + +public record ArrowFunc(NubType Type, IReadOnlyList Parameters, NubType ReturnType, Block Body) : Expression(Type); + +public record AddressOf(NubType Type, Expression Expression) : Expression(Type); + +public record Literal(NubType Type, string Value, LiteralKind Kind) : Expression(Type); + +public record StructFieldAccess(NubType Type, NubCustomType StructType, Expression Target, string Field) : Expression(Type); + +public record InterfaceFuncAccess(NubType Type, NubCustomType InterfaceType, Expression Target, string FuncName) : Expression(Type); + +public record StructInitializer(NubCustomType StructType, Dictionary Initializers) : Expression(StructType); + +public record Dereference(NubType Type, Expression Expression) : Expression(Type); \ No newline at end of file diff --git a/src/compiler/NubLang/TypeChecking/Node/Statement.cs b/src/compiler/NubLang/TypeChecking/Node/Statement.cs new file mode 100644 index 0000000..9fd2eea --- /dev/null +++ b/src/compiler/NubLang/TypeChecking/Node/Statement.cs @@ -0,0 +1,21 @@ +using NubLang.Common; + +namespace NubLang.TypeChecking.Node; + +public record Statement : Node; + +public record StatementExpression(Expression Expression) : Statement; + +public record Return(Optional Value) : Statement; + +public record Assignment(Expression Target, Expression Value) : Statement; + +public record If(Expression Condition, Block Body, Optional> Else) : Statement; + +public record VariableDeclaration(string Name, Optional Assignment, NubType Type) : Statement; + +public record Continue : Statement; + +public record Break : Statement; + +public record While(Expression Condition, Block Body) : Statement; \ No newline at end of file diff --git a/src/compiler/NubLang/TypeChecking/Node/SyntaxTree.cs b/src/compiler/NubLang/TypeChecking/Node/SyntaxTree.cs new file mode 100644 index 0000000..46d72e5 --- /dev/null +++ b/src/compiler/NubLang/TypeChecking/Node/SyntaxTree.cs @@ -0,0 +1,7 @@ +namespace NubLang.TypeChecking.Node; + +public record TypedSyntaxTree(IReadOnlyList Definitions); + +public abstract record Node; + +public record Block(IReadOnlyList Statements) : Node; \ No newline at end of file diff --git a/src/compiler/NubLang/Syntax/Binding/NubType.cs b/src/compiler/NubLang/TypeChecking/NubType.cs similarity index 91% rename from src/compiler/NubLang/Syntax/Binding/NubType.cs rename to src/compiler/NubLang/TypeChecking/NubType.cs index f884732..cac5358 100644 --- a/src/compiler/NubLang/Syntax/Binding/NubType.cs +++ b/src/compiler/NubLang/TypeChecking/NubType.cs @@ -1,7 +1,7 @@ using System.Diagnostics.CodeAnalysis; using NubLang.Generation; -namespace NubLang.Syntax.Binding; +namespace NubLang.TypeChecking; public abstract class NubType : IEquatable { @@ -38,8 +38,8 @@ public abstract class NubType : IEquatable throw new ArgumentException($"Type {this} is not a simple type nor a compex type"); } - public abstract int Size(BoundDefinitionTable definitionTable); - public abstract int Alignment(BoundDefinitionTable definitionTable); + public abstract int Size(TypedDefinitionTable definitionTable); + public abstract int Alignment(TypedDefinitionTable definitionTable); public static int AlignTo(int offset, int alignment) { @@ -75,7 +75,7 @@ public abstract class NubSimpleType : NubType { public abstract StorageSize StorageSize { get; } - public override int Size(BoundDefinitionTable definitionTable) + public override int Size(TypedDefinitionTable definitionTable) { return StorageSize switch { @@ -87,7 +87,7 @@ public abstract class NubSimpleType : NubType }; } - public override int Alignment(BoundDefinitionTable definitionTable) + public override int Alignment(TypedDefinitionTable definitionTable) { return Size(definitionTable); } @@ -208,8 +208,8 @@ public abstract class NubComplexType : NubType; public class NubCStringType : NubComplexType { - public override int Size(BoundDefinitionTable definitionTable) => 8; - public override int Alignment(BoundDefinitionTable definitionTable) => Size(definitionTable); + public override int Size(TypedDefinitionTable definitionTable) => 8; + public override int Alignment(TypedDefinitionTable definitionTable) => Size(definitionTable); public override string ToString() => "cstring"; public override bool Equals(NubType? other) => other is NubCStringType; @@ -218,8 +218,8 @@ public class NubCStringType : NubComplexType public class NubStringType : NubComplexType { - public override int Size(BoundDefinitionTable definitionTable) => 8; - public override int Alignment(BoundDefinitionTable definitionTable) => Size(definitionTable); + public override int Size(TypedDefinitionTable definitionTable) => 8; + public override int Alignment(TypedDefinitionTable definitionTable) => Size(definitionTable); public override string ToString() => "string"; public override bool Equals(NubType? other) => other is NubStringType; @@ -230,7 +230,7 @@ public class NubCustomType(string name) : NubComplexType { public string Name { get; } = name; - public CustomTypeKind Kind(BoundDefinitionTable definitionTable) + public CustomTypeKind Kind(TypedDefinitionTable definitionTable) { if (definitionTable.GetStructs().Any(x => x.Name == Name)) { @@ -245,7 +245,7 @@ public class NubCustomType(string name) : NubComplexType throw new ArgumentException($"Definition table does not have any type information for {this}"); } - public override int Size(BoundDefinitionTable definitionTable) + public override int Size(TypedDefinitionTable definitionTable) { switch (Kind(definitionTable)) { @@ -275,7 +275,7 @@ public class NubCustomType(string name) : NubComplexType } } - public override int Alignment(BoundDefinitionTable definitionTable) + public override int Alignment(TypedDefinitionTable definitionTable) { switch (Kind(definitionTable)) { @@ -303,8 +303,8 @@ public class NubArrayType(NubType elementType) : NubComplexType { public NubType ElementType { get; } = elementType; - public override int Size(BoundDefinitionTable definitionTable) => 8; - public override int Alignment(BoundDefinitionTable definitionTable) => Size(definitionTable); + public override int Size(TypedDefinitionTable definitionTable) => 8; + public override int Alignment(TypedDefinitionTable definitionTable) => Size(definitionTable); public override string ToString() => "[]" + ElementType; diff --git a/src/compiler/NubLang/TypeChecking/TypeChecker.cs b/src/compiler/NubLang/TypeChecking/TypeChecker.cs new file mode 100644 index 0000000..4e16b63 --- /dev/null +++ b/src/compiler/NubLang/TypeChecking/TypeChecker.cs @@ -0,0 +1,640 @@ +using NubLang.Common; +using NubLang.Diagnostics; +using NubLang.Parsing.Syntax; +using NubLang.Tokenization; +using NubLang.TypeChecking.Node; + +namespace NubLang.TypeChecking; + +public sealed class TypeChecker +{ + private readonly SyntaxTree _syntaxTree; + private readonly DefinitionTable _definitionTable; + + private readonly Stack _scopes = []; + private readonly Stack _funcReturnTypes = []; + private readonly List _diagnostics = []; + + private Scope Scope => _scopes.Peek(); + + public TypeChecker(SyntaxTree syntaxTree, DefinitionTable definitionTable) + { + _syntaxTree = syntaxTree; + _definitionTable = definitionTable; + } + + public IReadOnlyList GetDiagnostics() => _diagnostics; + + public TypedSyntaxTree Check() + { + _diagnostics.Clear(); + _funcReturnTypes.Clear(); + _scopes.Clear(); + + var definitions = new List(); + + foreach (var definition in _syntaxTree.Definitions) + { + try + { + definitions.Add(CheckDefinition(definition)); + } + catch (CheckException e) + { + _diagnostics.Add(e.Diagnostic); + } + } + + return new TypedSyntaxTree(definitions); + } + + private Definition CheckDefinition(DefinitionSyntax node) + { + return node switch + { + ExternFuncSyntax definition => CheckExternFuncDefinition(definition), + InterfaceSyntax definition => CheckTraitDefinition(definition), + LocalFuncSyntax definition => CheckLocalFuncDefinition(definition), + StructSyntax definition => CheckStruct(definition), + _ => throw new ArgumentOutOfRangeException(nameof(node)) + }; + } + + private Trait CheckTraitDefinition(InterfaceSyntax node) + { + var functions = new List(); + + foreach (var function in node.Functions) + { + functions.Add(new TraitFunc(function.Name, CheckFuncSignature(function.Signature))); + } + + return new Trait(node.Name, functions); + } + + private Struct CheckStruct(StructSyntax node) + { + var structFields = new List(); + + foreach (var field in node.Fields) + { + var value = Optional.Empty(); + + if (field.Value.HasValue) + { + value = CheckExpression(field.Value.Value, CheckType(field.Type)); + } + + structFields.Add(new StructField(field.Index, field.Name, CheckType(field.Type), value)); + } + + return new Struct(node.Name, structFields); + } + + private ExternFunc CheckExternFuncDefinition(ExternFuncSyntax node) + { + return new ExternFunc(node.Name, node.CallName, CheckFuncSignature(node.Signature)); + } + + private LocalFunc CheckLocalFuncDefinition(LocalFuncSyntax node) + { + var signature = CheckFuncSignature(node.Signature); + var body = CheckFuncBody(node.Body, signature.ReturnType, signature.Parameters); + + return new LocalFunc(node.Name, signature, body); + } + + private Statement CheckStatement(StatementSyntax node) + { + return node switch + { + AssignmentSyntax statement => CheckAssignment(statement), + BreakSyntax => new Break(), + ContinueSyntax => new Continue(), + IfSyntax statement => CheckIf(statement), + ReturnSyntax statement => CheckReturn(statement), + StatementExpressionSyntax statement => CheckStatementExpression(statement), + VariableDeclarationSyntax statement => CheckVariableDeclaration(statement), + WhileSyntax statement => CheckWhile(statement), + _ => throw new ArgumentOutOfRangeException(nameof(node)) + }; + } + + private Statement CheckAssignment(AssignmentSyntax statement) + { + var expression = CheckExpression(statement.Target); + var value = CheckExpression(statement.Value, expression.Type); + return new Assignment(expression, value); + } + + private If CheckIf(IfSyntax statement) + { + var elseStatement = Optional.Empty>(); + + if (statement.Else.HasValue) + { + elseStatement = statement.Else.Value.Match> + ( + elseIf => CheckIf(elseIf), + @else => CheckBlock(@else) + ); + } + + return new If(CheckExpression(statement.Condition, new NubPrimitiveType(PrimitiveTypeKind.Bool)), CheckBlock(statement.Body), elseStatement); + } + + private Return CheckReturn(ReturnSyntax statement) + { + var value = Optional.Empty(); + + if (statement.Value.HasValue) + { + value = CheckExpression(statement.Value.Value, _funcReturnTypes.Peek()); + } + + return new Return(value); + } + + private StatementExpression CheckStatementExpression(StatementExpressionSyntax statement) + { + return new StatementExpression(CheckExpression(statement.Expression)); + } + + private VariableDeclaration CheckVariableDeclaration(VariableDeclarationSyntax statement) + { + NubType? type = null; + + if (statement.ExplicitType.HasValue) + { + type = CheckType(statement.ExplicitType.Value); + } + + var assignment = Optional.Empty(); + if (statement.Assignment.HasValue) + { + var boundValue = CheckExpression(statement.Assignment.Value, type); + assignment = boundValue; + type = boundValue.Type; + } + + if (type == null) + { + throw new NotImplementedException("Diagnostics not implemented"); + } + + Scope.Declare(new Variable(statement.Name, type)); + + return new VariableDeclaration(statement.Name, assignment, type); + } + + private While CheckWhile(WhileSyntax statement) + { + return new While(CheckExpression(statement.Condition, new NubPrimitiveType(PrimitiveTypeKind.Bool)), CheckBlock(statement.Body)); + } + + private Expression CheckExpression(ExpressionSyntax node, NubType? expectedType = null) + { + return node switch + { + AddressOfSyntax expression => CheckAddressOf(expression), + ArrowFuncSyntax expression => CheckArrowFunc(expression, expectedType), + ArrayIndexAccessSyntax expression => CheckArrayIndexAccess(expression), + ArrayInitializerSyntax expression => CheckArrayInitializer(expression), + BinaryExpressionSyntax expression => CheckBinaryExpression(expression), + DereferenceSyntax expression => CheckDereference(expression), + FuncCallSyntax expression => CheckFuncCall(expression), + IdentifierSyntax expression => CheckIdentifier(expression), + LiteralSyntax expression => CheckLiteral(expression, expectedType), + MemberAccessSyntax expression => CheckMemberAccess(expression), + StructInitializerSyntax expression => CheckStructInitializer(expression), + UnaryExpressionSyntax expression => CheckUnaryExpression(expression), + _ => throw new ArgumentOutOfRangeException(nameof(node)) + }; + } + + private AddressOf CheckAddressOf(AddressOfSyntax expression) + { + var inner = CheckExpression(expression.Expression); + return new AddressOf(new NubPointerType(inner.Type), inner); + } + + private ArrowFunc CheckArrowFunc(ArrowFuncSyntax expression, NubType? expectedType = null) + { + if (expectedType == null) + { + throw new CheckException(Diagnostic.Error("Cannot infer argument types for arrow function").Build()); + } + + if (expectedType is not NubFuncType funcType) + { + throw new CheckException(Diagnostic.Error($"Expected {expectedType}, but got arrow function").Build()); + } + + var parameters = new List(); + + for (var i = 0; i < expression.Parameters.Count; i++) + { + if (i >= funcType.Parameters.Count) + { + throw new CheckException(Diagnostic.Error($"Arrow function expected a maximum of {funcType.Parameters.Count} arguments").Build()); + } + + var expectedParameterType = funcType.Parameters[i]; + var parameter = expression.Parameters[i]; + parameters.Add(new FuncParameter(parameter.Name, expectedParameterType)); + } + + var body = CheckFuncBody(expression.Body, funcType.ReturnType, parameters); + + return new ArrowFunc(new NubFuncType(parameters.Select(x => x.Type).ToList(), funcType.ReturnType), parameters, funcType.ReturnType, body); + } + + private ArrayIndexAccess CheckArrayIndexAccess(ArrayIndexAccessSyntax expression) + { + var boundArray = CheckExpression(expression.Target); + var elementType = ((NubArrayType)boundArray.Type).ElementType; + return new ArrayIndexAccess(elementType, boundArray, CheckExpression(expression.Index, new NubPrimitiveType(PrimitiveTypeKind.U64))); + } + + private ArrayInitializer CheckArrayInitializer(ArrayInitializerSyntax expression) + { + var capacity = CheckExpression(expression.Capacity, new NubPrimitiveType(PrimitiveTypeKind.U64)); + var type = new NubArrayType(CheckType(expression.ElementType)); + return new ArrayInitializer(type, capacity, CheckType(expression.ElementType)); + } + + private BinaryExpression CheckBinaryExpression(BinaryExpressionSyntax expression) + { + var boundLeft = CheckExpression(expression.Left); + var boundRight = CheckExpression(expression.Right, boundLeft.Type); + return new BinaryExpression(boundLeft.Type, boundLeft, CheckBinaryOperator(expression.OperatorSyntax), boundRight); + } + + private Dereference CheckDereference(DereferenceSyntax expression) + { + var boundExpression = CheckExpression(expression.Expression); + var dereferencedType = ((NubPointerType)boundExpression.Type).BaseType; + return new Dereference(dereferencedType, boundExpression); + } + + private FuncCall CheckFuncCall(FuncCallSyntax expression) + { + var boundExpression = CheckExpression(expression.Expression); + + var funcType = (NubFuncType)boundExpression.Type; + + var parameters = new List(); + + foreach (var (i, parameter) in expression.Parameters.Index()) + { + if (i >= funcType.Parameters.Count) + { + throw new NotImplementedException("Diagnostics not implemented"); + } + + var expectedType = funcType.Parameters[i]; + + parameters.Add(CheckExpression(parameter, expectedType)); + } + + return new FuncCall(funcType.ReturnType, boundExpression, parameters); + } + + private Expression CheckIdentifier(IdentifierSyntax expression) + { + var variable = Scope.Lookup(expression.Name); + if (variable != null) + { + return new VariableIdent(variable.Type, variable.Name); + } + + var localFuncs = _definitionTable.LookupLocalFunc(expression.Name).ToArray(); + if (localFuncs.Length > 0) + { + if (localFuncs.Length > 1) + { + throw new CheckException(Diagnostic.Error($"Extern func {expression.Name} has multiple definitions").Build()); + } + + var localFunc = localFuncs[0]; + + var returnType = CheckType(localFunc.Signature.ReturnType); + var parameterTypes = localFunc.Signature.Parameters.Select(p => CheckType(p.Type)).ToList(); + var type = new NubFuncType(parameterTypes, returnType); + return new LocalFuncIdent(type, expression.Name); + } + + var externFuncs = _definitionTable.LookupExternFunc(expression.Name).ToArray(); + if (externFuncs.Length > 0) + { + if (externFuncs.Length > 1) + { + throw new CheckException(Diagnostic.Error($"Extern func {expression.Name} has multiple definitions").Build()); + } + + var externFunc = externFuncs[0]; + + var returnType = CheckType(externFunc.Signature.ReturnType); + var parameterTypes = externFunc.Signature.Parameters.Select(p => CheckType(p.Type)).ToList(); + var type = new NubFuncType(parameterTypes, returnType); + return new ExternFuncIdent(type, expression.Name); + } + + throw new CheckException(Diagnostic.Error($"No identifier with the name {expression.Name} exists").Build()); + } + + private Literal CheckLiteral(LiteralSyntax expression, NubType? expectedType = null) + { + var type = expectedType ?? expression.Kind switch + { + LiteralKind.Integer => new NubPrimitiveType(PrimitiveTypeKind.I64), + LiteralKind.Float => new NubPrimitiveType(PrimitiveTypeKind.F64), + LiteralKind.String => new NubStringType(), + LiteralKind.Bool => new NubPrimitiveType(PrimitiveTypeKind.Bool), + _ => throw new ArgumentOutOfRangeException() + }; + + return new Literal(type, expression.Value, expression.Kind); + } + + private Expression CheckMemberAccess(MemberAccessSyntax expression) + { + var boundExpression = CheckExpression(expression.Target); + + if (boundExpression.Type is NubCustomType customType) + { + var traits = _definitionTable.LookupTrait(customType).ToArray(); + if (traits.Length > 0) + { + if (traits.Length > 1) + { + throw new CheckException(Diagnostic.Error($"Trait {customType} has multiple definitions").Build()); + } + + var trait = traits[0]; + + var traitFuncs = _definitionTable.LookupTraitFunc(trait, expression.Member).ToArray(); + if (traits.Length > 0) + { + if (traits.Length > 1) + { + throw new CheckException(Diagnostic.Error($"Trait {customType} has multiple functions with the name {expression.Member}").Build()); + } + + var traitFunc = traitFuncs[0]; + + var returnType = CheckType(traitFunc.Signature.ReturnType); + var parameterTypes = traitFunc.Signature.Parameters.Select(p => CheckType(p.Type)).ToList(); + var type = new NubFuncType(parameterTypes, returnType); + return new InterfaceFuncAccess(type, customType, boundExpression, expression.Member); + } + } + + var structs = _definitionTable.LookupStruct(customType).ToArray(); + if (structs.Length > 0) + { + if (structs.Length > 1) + { + throw new CheckException(Diagnostic.Error($"Struct {customType} has multiple definitions").Build()); + } + + var @struct = structs[0]; + + var fields = _definitionTable.LookupStructField(@struct, expression.Member).ToArray(); + if (fields.Length > 0) + { + if (fields.Length > 1) + { + throw new CheckException(Diagnostic.Error($"Struct {customType} has multiple fields with the name {expression.Member}").Build()); + } + + var field = fields[0]; + + return new StructFieldAccess(CheckType(field.Type), customType, boundExpression, expression.Member); + } + } + } + + throw new CheckException(Diagnostic.Error($"{boundExpression.Type} does not have a member with the name {expression.Member}").Build()); + } + + private StructInitializer CheckStructInitializer(StructInitializerSyntax expression) + { + var boundType = CheckType(expression.StructType); + + if (boundType is not NubCustomType structType) + { + throw new CheckException(Diagnostic.Error($"Cannot initialize non-struct type {expression.StructType}").Build()); + } + + var structs = _definitionTable.LookupStruct(structType).ToArray(); + + if (structs.Length == 0) + { + throw new CheckException(Diagnostic.Error($"Struct {structType} is not defined").Build()); + } + + if (structs.Length > 1) + { + throw new CheckException(Diagnostic.Error($"Struct {structType} has multiple definitions").Build()); + } + + var @struct = structs[0]; + + var initializers = new Dictionary(); + + foreach (var (field, initializer) in expression.Initializers) + { + var fields = _definitionTable.LookupStructField(@struct, field).ToArray(); + + if (fields.Length == 0) + { + throw new CheckException(Diagnostic.Error($"Struct {structType} does not have a field with the name {field}").Build()); + } + + if (fields.Length > 1) + { + throw new CheckException(Diagnostic.Error($"Struct {structType} has multiple fields with the name {field}").Build()); + } + + initializers[field] = CheckExpression(initializer, CheckType(fields[0].Type)); + } + + return new StructInitializer(structType, initializers); + } + + private UnaryExpression CheckUnaryExpression(UnaryExpressionSyntax expression) + { + var boundOperand = CheckExpression(expression.Operand); + + NubType? type = null; + + switch (expression.OperatorSyntax) + { + case UnaryOperatorSyntax.Negate: + { + boundOperand = CheckExpression(expression.Operand, new NubPrimitiveType(PrimitiveTypeKind.I64)); + + if (boundOperand.Type.IsNumber) + { + type = boundOperand.Type; + } + + break; + } + case UnaryOperatorSyntax.Invert: + { + boundOperand = CheckExpression(expression.Operand, new NubPrimitiveType(PrimitiveTypeKind.Bool)); + + type = new NubPrimitiveType(PrimitiveTypeKind.Bool); + break; + } + } + + if (type == null) + { + throw new NotImplementedException("Diagnostics not implemented"); + } + + return new UnaryExpression(type, CheckUnaryOperator(expression.OperatorSyntax), boundOperand); + } + + private FuncSignature CheckFuncSignature(FuncSignatureSyntax node) + { + var parameters = new List(); + + foreach (var parameter in node.Parameters) + { + parameters.Add(new FuncParameter(parameter.Name, CheckType(parameter.Type))); + } + + return new FuncSignature(parameters, CheckType(node.ReturnType)); + } + + private BinaryOperator CheckBinaryOperator(BinaryOperatorSyntax op) + { + return op switch + { + BinaryOperatorSyntax.Equal => BinaryOperator.Equal, + BinaryOperatorSyntax.NotEqual => BinaryOperator.NotEqual, + BinaryOperatorSyntax.GreaterThan => BinaryOperator.GreaterThan, + BinaryOperatorSyntax.GreaterThanOrEqual => BinaryOperator.GreaterThanOrEqual, + BinaryOperatorSyntax.LessThan => BinaryOperator.LessThan, + BinaryOperatorSyntax.LessThanOrEqual => BinaryOperator.LessThanOrEqual, + BinaryOperatorSyntax.Plus => BinaryOperator.Plus, + BinaryOperatorSyntax.Minus => BinaryOperator.Minus, + BinaryOperatorSyntax.Multiply => BinaryOperator.Multiply, + BinaryOperatorSyntax.Divide => BinaryOperator.Divide, + _ => throw new ArgumentOutOfRangeException(nameof(op), op, null) + }; + } + + private UnaryOperator CheckUnaryOperator(UnaryOperatorSyntax op) + { + return op switch + { + UnaryOperatorSyntax.Negate => UnaryOperator.Negate, + UnaryOperatorSyntax.Invert => UnaryOperator.Invert, + _ => throw new ArgumentOutOfRangeException(nameof(op), op, null) + }; + } + + private Block CheckBlock(BlockSyntax node, Scope? scope = null) + { + var statements = new List(); + + _scopes.Push(scope ?? Scope.SubScope()); + + foreach (var statement in node.Statements) + { + statements.Add(CheckStatement(statement)); + } + + _scopes.Pop(); + + return new Block(statements); + } + + private Block CheckFuncBody(BlockSyntax block, NubType returnType, IReadOnlyList parameters) + { + _funcReturnTypes.Push(returnType); + + var scope = new Scope(); + foreach (var parameter in parameters) + { + scope.Declare(new Variable(parameter.Name, parameter.Type)); + } + + var body = CheckBlock(block, scope); + _funcReturnTypes.Pop(); + return body; + } + + private NubType CheckType(TypeSyntax node) + { + return node switch + { + ArrayTypeSyntax type => new NubArrayType(CheckType(type.BaseType)), + CStringTypeSyntax => new NubCStringType(), + CustomTypeSyntax type => new NubCustomType(type.MangledName()), + FuncTypeSyntax type => new NubFuncType(type.Parameters.Select(CheckType).ToList(), CheckType(type.ReturnType)), + PointerTypeSyntax type => new NubPointerType(CheckType(type.BaseType)), + PrimitiveTypeSyntax type => new NubPrimitiveType(type.SyntaxKind switch + { + PrimitiveTypeSyntaxKind.I64 => PrimitiveTypeKind.I64, + PrimitiveTypeSyntaxKind.I32 => PrimitiveTypeKind.I32, + PrimitiveTypeSyntaxKind.I16 => PrimitiveTypeKind.I16, + PrimitiveTypeSyntaxKind.I8 => PrimitiveTypeKind.I8, + PrimitiveTypeSyntaxKind.U64 => PrimitiveTypeKind.U64, + PrimitiveTypeSyntaxKind.U32 => PrimitiveTypeKind.U32, + PrimitiveTypeSyntaxKind.U16 => PrimitiveTypeKind.U16, + PrimitiveTypeSyntaxKind.U8 => PrimitiveTypeKind.U8, + PrimitiveTypeSyntaxKind.F64 => PrimitiveTypeKind.F64, + PrimitiveTypeSyntaxKind.F32 => PrimitiveTypeKind.F32, + PrimitiveTypeSyntaxKind.Bool => PrimitiveTypeKind.Bool, + _ => throw new ArgumentOutOfRangeException() + }), + StringTypeSyntax => new NubStringType(), + VoidTypeSyntax => new NubVoidType(), + _ => throw new ArgumentOutOfRangeException(nameof(node)) + }; + } +} + +public record Variable(string Name, NubType Type); + +public class Scope(Scope? parent = null) +{ + private readonly List _variables = []; + + public Variable? Lookup(string name) + { + var variable = _variables.FirstOrDefault(x => x.Name == name); + if (variable != null) + { + return variable; + } + + return parent?.Lookup(name); + } + + public void Declare(Variable variable) + { + _variables.Add(variable); + } + + public Scope SubScope() + { + return new Scope(this); + } +} + +public class CheckException : Exception +{ + public Diagnostic Diagnostic { get; } + + public CheckException(Diagnostic diagnostic) : base(diagnostic.Message) + { + Diagnostic = diagnostic; + } +} \ No newline at end of file