From 3db412a0606e47d252b990df7ee0df4857a281ab Mon Sep 17 00:00:00 2001 From: Oliver Stene Date: Sun, 8 Feb 2026 23:05:50 +0100 Subject: [PATCH] Type checking basics --- compiler/Compiler/Diagnostic.cs | 11 + compiler/Compiler/Generator.cs | 146 +++---- compiler/Compiler/Program.cs | 22 +- compiler/Compiler/TypeChecker.cs | 674 +++++++++++++++++++++++++++++++ 4 files changed, 780 insertions(+), 73 deletions(-) create mode 100644 compiler/Compiler/TypeChecker.cs diff --git a/compiler/Compiler/Diagnostic.cs b/compiler/Compiler/Diagnostic.cs index 5ed6891..1b9ce99 100644 --- a/compiler/Compiler/Diagnostic.cs +++ b/compiler/Compiler/Diagnostic.cs @@ -44,6 +44,17 @@ public sealed class DiagnosticBuilder(DiagnosticSeverity severity, string messag return this; } + public DiagnosticBuilder At(string fileName, TypedNode? node) + { + if (node != null && node.Tokens.Count != 0) + { + // todo(nub31): Calculate length based on last token + At(fileName, node.Tokens[0]); + } + + return this; + } + public DiagnosticBuilder WithHelp(string helpMessage) { help = helpMessage; diff --git a/compiler/Compiler/Generator.cs b/compiler/Compiler/Generator.cs index 1168f22..0705898 100644 --- a/compiler/Compiler/Generator.cs +++ b/compiler/Compiler/Generator.cs @@ -2,14 +2,15 @@ namespace Compiler; -public sealed class Generator(List nodes) +public sealed class Generator(TypedAst ast) { - public static string Emit(List nodes) + public static string Emit(TypedAst ast) { - return new Generator(nodes).Emit(); + return new Generator(ast).Emit(); } private IndentedTextWriter writer = new(); + private Dictionary structTypeNames = new(); private string Emit() { @@ -27,23 +28,22 @@ public sealed class Generator(List nodes) """); - foreach (var node in nodes.OfType()) + for (var i = 0; i < ast.StructTypes.Count; i++) { - var parameters = node.Parameters.Select(x => CType(x.Type, x.Name.Ident)); - writer.WriteLine($"{CType(node.ReturnType, node.Name.Ident)}({string.Join(", ", parameters)});"); + var structType = ast.StructTypes[i]; + structTypeNames[structType] = $"s{i}"; } - writer.WriteLine(); - - foreach (var node in nodes.OfType()) + foreach (var structType in ast.StructTypes) { - writer.WriteLine($"struct {node.Name.Ident}"); + var name = structTypeNames[structType]; + writer.WriteLine($"struct {name}"); writer.WriteLine("{"); using (writer.Indent()) { - foreach (var field in node.Fields) + foreach (var field in structType.Fields) { - writer.WriteLine($"{CType(field.Type, field.Name.Ident)};"); + writer.WriteLine($"{CType(field.Type, field.Name)};"); } } @@ -52,7 +52,15 @@ public sealed class Generator(List nodes) writer.WriteLine(); - foreach (var node in nodes.OfType()) + foreach (var node in ast.Functions) + { + var parameters = node.Parameters.Select(x => CType(x.Type, x.Name.Ident)); + writer.WriteLine($"{CType(node.ReturnType, node.Name.Ident)}({string.Join(", ", parameters)});"); + } + + writer.WriteLine(); + + foreach (var node in ast.Functions) { var parameters = node.Parameters.Select(x => CType(x.Type, x.Name.Ident)); writer.WriteLine($"{CType(node.ReturnType, node.Name.Ident)}({string.Join(", ", parameters)})"); @@ -69,29 +77,29 @@ public sealed class Generator(List nodes) return writer.ToString(); } - private void EmitStatement(NodeStatement node) + private void EmitStatement(TypedNodeStatement node) { switch (node) { - case NodeStatementBlock statement: + case TypedNodeStatementBlock statement: EmitStatementBlock(statement); break; - case NodeStatementFuncCall statement: + case TypedNodeStatementFuncCall statement: EmitStatementFuncCall(statement); break; - case NodeStatementReturn statement: + case TypedNodeStatementReturn statement: EmitStatementReturn(statement); break; - case NodeStatementVariableDeclaration statement: + case TypedNodeStatementVariableDeclaration statement: EmitStatementVariableDeclaration(statement); break; - case NodeStatementAssignment statement: + case TypedNodeStatementAssignment statement: EmitStatementAssignment(statement); break; - case NodeStatementIf statement: + case TypedNodeStatementIf statement: EmitStatementIf(statement); break; - case NodeStatementWhile statement: + case TypedNodeStatementWhile statement: EmitStatementWhile(statement); break; default: @@ -99,7 +107,7 @@ public sealed class Generator(List nodes) } } - private void EmitStatementBlock(NodeStatementBlock node) + private void EmitStatementBlock(TypedNodeStatementBlock node) { writer.WriteLine("{"); using (writer.Indent()) @@ -111,33 +119,33 @@ public sealed class Generator(List nodes) writer.WriteLine("}"); } - private void EmitStatementFuncCall(NodeStatementFuncCall node) + private void EmitStatementFuncCall(TypedNodeStatementFuncCall node) { var name = EmitExpression(node.Target); var parameterValues = node.Parameters.Select(EmitExpression).ToList(); writer.WriteLine($"{name}({string.Join(", ", parameterValues)});"); } - private void EmitStatementReturn(NodeStatementReturn statement) + private void EmitStatementReturn(TypedNodeStatementReturn statement) { var value = EmitExpression(statement.Value); writer.WriteLine($"return {value};"); } - private void EmitStatementVariableDeclaration(NodeStatementVariableDeclaration statement) + private void EmitStatementVariableDeclaration(TypedNodeStatementVariableDeclaration statement) { var value = EmitExpression(statement.Value); writer.WriteLine($"{CType(statement.Type)} {statement.Name.Ident} = {value};"); } - private void EmitStatementAssignment(NodeStatementAssignment statement) + private void EmitStatementAssignment(TypedNodeStatementAssignment statement) { var target = EmitExpression(statement.Target); var value = EmitExpression(statement.Value); writer.WriteLine($"{target} = {value};"); } - private void EmitStatementIf(NodeStatementIf statement) + private void EmitStatementIf(TypedNodeStatementIf statement) { var condition = EmitExpression(statement.Condition); writer.WriteLine($"if ({condition})"); @@ -152,7 +160,7 @@ public sealed class Generator(List nodes) if (statement.ElseBlock != null) { writer.Write("else"); - if (statement.ElseBlock is NodeStatementIf) + if (statement.ElseBlock is TypedNodeStatementIf) writer.Write(" "); else writer.WriteLine(); @@ -167,7 +175,7 @@ public sealed class Generator(List nodes) } } - private void EmitStatementWhile(NodeStatementWhile statement) + private void EmitStatementWhile(TypedNodeStatementWhile statement) { var condition = EmitExpression(statement.Condition); writer.WriteLine($"while ({condition})"); @@ -180,61 +188,61 @@ public sealed class Generator(List nodes) writer.WriteLine("}"); } - private string EmitExpression(NodeExpression node) + private string EmitExpression(TypedNodeExpression node) { return node switch { - NodeExpressionBinary expression => EmitExpressionBinary(expression), - NodeExpressionUnary expression => EmitExpressionUnary(expression), - NodeExpressionBoolLiteral expression => expression.Value.Value ? "true" : "false", - NodeExpressionIntLiteral expression => expression.Value.Value.ToString(), - NodeExpressionStringLiteral expression => $"(struct string){{ \"{expression.Value.Value}\", {expression.Value.Value.Length} }}", - NodeExpressionStructLiteral expression => EmitExpressionStructLiteral(expression), - NodeExpressionMemberAccess expression => EmitExpressionMemberAccess(expression), - NodeExpressionIdent expression => expression.Value.Ident, + TypedNodeExpressionBinary expression => EmitExpressionBinary(expression), + TypedNodeExpressionUnary expression => EmitExpressionUnary(expression), + TypedNodeExpressionBoolLiteral expression => expression.Value.Value ? "true" : "false", + TypedNodeExpressionIntLiteral expression => expression.Value.Value.ToString(), + TypedNodeExpressionStringLiteral expression => $"(struct string){{ \"{expression.Value.Value}\", {expression.Value.Value.Length} }}", + TypedNodeExpressionStructLiteral expression => EmitExpressionStructLiteral(expression), + TypedNodeExpressionMemberAccess expression => EmitExpressionMemberAccess(expression), + TypedNodeExpressionIdent expression => expression.Value.Ident, _ => throw new ArgumentOutOfRangeException(nameof(node), node, null) }; } - private string EmitExpressionBinary(NodeExpressionBinary expression) + private string EmitExpressionBinary(TypedNodeExpressionBinary expression) { var left = EmitExpression(expression.Left); var right = EmitExpression(expression.Right); return expression.Operation switch { - NodeExpressionBinary.Op.Add => $"({left} + {right})", - NodeExpressionBinary.Op.Subtract => $"({left} - {right})", - NodeExpressionBinary.Op.Multiply => $"({left} * {right})", - NodeExpressionBinary.Op.Divide => $"({left} / {right})", - NodeExpressionBinary.Op.Modulo => $"({left} % {right})", - NodeExpressionBinary.Op.Equal => $"({left} == {right})", - NodeExpressionBinary.Op.NotEqual => $"({left} != {right})", - NodeExpressionBinary.Op.LessThan => $"({left} < {right})", - NodeExpressionBinary.Op.LessThanOrEqual => $"({left} <= {right})", - NodeExpressionBinary.Op.GreaterThan => $"({left} > {right})", - NodeExpressionBinary.Op.GreaterThanOrEqual => $"({left} >= {right})", - NodeExpressionBinary.Op.LeftShift => $"({left} << {right})", - NodeExpressionBinary.Op.RightShift => $"({left} >> {right})", - NodeExpressionBinary.Op.LogicalAnd => $"({left} && {right})", - NodeExpressionBinary.Op.LogicalOr => $"({left} || {right})", + TypedNodeExpressionBinary.Op.Add => $"({left} + {right})", + TypedNodeExpressionBinary.Op.Subtract => $"({left} - {right})", + TypedNodeExpressionBinary.Op.Multiply => $"({left} * {right})", + TypedNodeExpressionBinary.Op.Divide => $"({left} / {right})", + TypedNodeExpressionBinary.Op.Modulo => $"({left} % {right})", + TypedNodeExpressionBinary.Op.Equal => $"({left} == {right})", + TypedNodeExpressionBinary.Op.NotEqual => $"({left} != {right})", + TypedNodeExpressionBinary.Op.LessThan => $"({left} < {right})", + TypedNodeExpressionBinary.Op.LessThanOrEqual => $"({left} <= {right})", + TypedNodeExpressionBinary.Op.GreaterThan => $"({left} > {right})", + TypedNodeExpressionBinary.Op.GreaterThanOrEqual => $"({left} >= {right})", + TypedNodeExpressionBinary.Op.LeftShift => $"({left} << {right})", + TypedNodeExpressionBinary.Op.RightShift => $"({left} >> {right})", + TypedNodeExpressionBinary.Op.LogicalAnd => $"({left} && {right})", + TypedNodeExpressionBinary.Op.LogicalOr => $"({left} || {right})", _ => throw new ArgumentOutOfRangeException() }; } - private string EmitExpressionUnary(NodeExpressionUnary expression) + private string EmitExpressionUnary(TypedNodeExpressionUnary expression) { var target = EmitExpression(expression.Target); return expression.Operation switch { - NodeExpressionUnary.Op.Negate => $"(-{target})", - NodeExpressionUnary.Op.Invert => $"(!{target})", + TypedNodeExpressionUnary.Op.Negate => $"(-{target})", + TypedNodeExpressionUnary.Op.Invert => $"(!{target})", _ => throw new ArgumentOutOfRangeException() }; } - private string EmitExpressionStructLiteral(NodeExpressionStructLiteral expression) + private string EmitExpressionStructLiteral(TypedNodeExpressionStructLiteral expression) { var initializerValues = new Dictionary(); @@ -246,27 +254,27 @@ public sealed class Generator(List nodes) var initializerStrings = initializerValues.Select(x => $".{x.Key} = {x.Value}"); - return $"(struct {expression.Name.Ident}){{ {string.Join(", ", initializerStrings)} }}"; + return $"(struct {structTypeNames[(NubTypeStruct)expression.Type]}){{ {string.Join(", ", initializerStrings)} }}"; } - private string EmitExpressionMemberAccess(NodeExpressionMemberAccess expression) + private string EmitExpressionMemberAccess(TypedNodeExpressionMemberAccess expression) { var target = EmitExpression(expression.Target); return $"{target}.{expression.Name.Ident}"; } - private static string CType(NodeType node, string? varName = null) + private string CType(NubType node, string? varName = null) { return node switch { - NodeTypeVoid => "void" + (varName != null ? $" {varName}" : ""), - NodeTypeBool => "bool" + (varName != null ? $" {varName}" : ""), - NodeTypeCustom type => $"struct {type.Name.Ident}" + (varName != null ? $" {varName}" : ""), - NodeTypeSInt type => $"int{type.Width}_t" + (varName != null ? $" {varName}" : ""), - NodeTypeUInt type => $"uint{type.Width}_t" + (varName != null ? $" {varName}" : ""), - NodeTypePointer type => CType(type.To) + (varName != null ? $" *{varName}" : "*"), - NodeTypeString => "struct string" + (varName != null ? $" {varName}" : ""), - NodeTypeFunc type => $"{CType(type.ReturnType)} (*{varName})({string.Join(", ", type.Parameters.Select(p => CType(p)))})", + NubTypeVoid => "void" + (varName != null ? $" {varName}" : ""), + NubTypeBool => "bool" + (varName != null ? $" {varName}" : ""), + NubTypeStruct type => $"struct {structTypeNames[type]}" + (varName != null ? $" {varName}" : ""), + NubTypeSInt type => $"int{type.Width}_t" + (varName != null ? $" {varName}" : ""), + NubTypeUInt type => $"uint{type.Width}_t" + (varName != null ? $" {varName}" : ""), + NubTypePointer type => CType(type.To) + (varName != null ? $" *{varName}" : "*"), + NubTypeString => "struct string" + (varName != null ? $" {varName}" : ""), + NubTypeFunc type => $"{CType(type.ReturnType)} (*{varName})({string.Join(", ", type.Parameters.Select(p => CType(p)))})", _ => throw new ArgumentOutOfRangeException(nameof(node), node, null) }; } diff --git a/compiler/Compiler/Program.cs b/compiler/Compiler/Program.cs index c81f58e..794682f 100644 --- a/compiler/Compiler/Program.cs +++ b/compiler/Compiler/Program.cs @@ -1,8 +1,10 @@ using Compiler; -var file = File.ReadAllText("test.nub"); +const string fileName = "test.nub"; -var tokens = Tokenizer.Tokenize("test.nub", file, out var tokenizerDiagnostics); +var file = File.ReadAllText(fileName); + +var tokens = Tokenizer.Tokenize(fileName, file, out var tokenizerDiagnostics); foreach (var diagnostic in tokenizerDiagnostics) { @@ -14,7 +16,7 @@ if (tokenizerDiagnostics.Any(x => x.Severity == DiagnosticSeverity.Error)) return 1; } -var nodes = Parser.Parse("test.nub", tokens, out var parserDiagnostics); +var nodes = Parser.Parse(fileName, tokens, out var parserDiagnostics); foreach (var diagnostic in parserDiagnostics) { @@ -26,7 +28,19 @@ if (parserDiagnostics.Any(x => x.Severity == DiagnosticSeverity.Error)) return 1; } -var output = Generator.Emit(nodes); +var typedNodes = TypeChecker.Check(fileName, nodes, out var typeCheckerDiagnostics); + +foreach (var diagnostic in typeCheckerDiagnostics) +{ + DiagnosticFormatter.Print(diagnostic, Console.Error); +} + +if (typeCheckerDiagnostics.Any(x => x.Severity == DiagnosticSeverity.Error)) +{ + return 1; +} + +var output = Generator.Emit(typedNodes); File.WriteAllText("C:/Users/oliste/repos/nub-lang/compiler/Compiler/out.c", output); diff --git a/compiler/Compiler/TypeChecker.cs b/compiler/Compiler/TypeChecker.cs new file mode 100644 index 0000000..0408c65 --- /dev/null +++ b/compiler/Compiler/TypeChecker.cs @@ -0,0 +1,674 @@ +namespace Compiler; + +public sealed class TypeChecker(string fileName, List definitions) +{ + public static TypedAst Check(string fileName, List nodes, out List diagnostics) + { + return new TypeChecker(fileName, nodes).Check(out diagnostics); + } + + private Scope scope = new(null); + private Dictionary structTypes = new(); + + private TypedAst Check(out List diagnostics) + { + var functions = new List(); + diagnostics = []; + + // todo(nub31): Types must be resolved better to prevent circular dependencies and independent ordering + foreach (var structDef in definitions.OfType()) + { + var fields = structDef.Fields.Select(x => new NubTypeStruct.Field(x.Name.Ident, CheckType(x.Type))).ToList(); + structTypes.Add(structDef.Name.Ident, new NubTypeStruct(fields)); + } + + foreach (var funcDef in definitions.OfType()) + { + var type = new NubTypeFunc(funcDef.Parameters.Select(x => CheckType(x.Type)).ToList(), CheckType(funcDef.ReturnType)); + scope.DeclareIdentifier(funcDef.Name.Ident, type); + } + + foreach (var funcDef in definitions.OfType()) + { + try + { + functions.Add(CheckDefinitionFunc(funcDef)); + } + catch (CompileException e) + { + diagnostics.Add(e.Diagnostic); + } + } + + return new TypedAst(functions, structTypes.Values.ToList()); + } + + private TypedNodeDefinitionFunc CheckDefinitionFunc(NodeDefinitionFunc definition) + { + return new TypedNodeDefinitionFunc(definition.Tokens, definition.Name, definition.Parameters.Select(CheckDefinitionFuncParameter).ToList(), CheckStatement(definition.Body), CheckType(definition.ReturnType)); + } + + private TypedNodeDefinitionFunc.Param CheckDefinitionFuncParameter(NodeDefinitionFunc.Param node) + { + return new TypedNodeDefinitionFunc.Param(node.Tokens, node.Name, CheckType(node.Type)); + } + + private TypedNodeStatement CheckStatement(NodeStatement node) + { + return node switch + { + NodeStatementAssignment statement => CheckStatementAssignment(statement), + NodeStatementBlock statement => CheckStatementBlock(statement), + NodeStatementFuncCall statement => CheckStatementFuncCall(statement), + NodeStatementIf statement => CheckStatementIf(statement), + NodeStatementReturn statement => CheckStatementReturn(statement), + NodeStatementVariableDeclaration statement => CheckStatementVariableDeclaration(statement), + NodeStatementWhile statement => CheckStatementWhile(statement), + _ => throw new ArgumentOutOfRangeException(nameof(node)) + }; + } + + private TypedNodeStatementAssignment CheckStatementAssignment(NodeStatementAssignment statement) + { + return new TypedNodeStatementAssignment(statement.Tokens, CheckExpression(statement.Target), CheckExpression(statement.Value)); + } + + private TypedNodeStatementBlock CheckStatementBlock(NodeStatementBlock statement) + { + return new TypedNodeStatementBlock(statement.Tokens, statement.Statements.Select(CheckStatement).ToList()); + } + + private TypedNodeStatementFuncCall CheckStatementFuncCall(NodeStatementFuncCall statement) + { + return new TypedNodeStatementFuncCall(statement.Tokens, CheckExpression(statement.Target), statement.Parameters.Select(CheckExpression).ToList()); + } + + private TypedNodeStatementIf CheckStatementIf(NodeStatementIf statement) + { + return new TypedNodeStatementIf(statement.Tokens, CheckExpression(statement.Condition), CheckStatement(statement.ThenBlock), statement.ElseBlock == null ? null : CheckStatement(statement.ElseBlock)); + } + + private TypedNodeStatementReturn CheckStatementReturn(NodeStatementReturn statement) + { + return new TypedNodeStatementReturn(statement.Tokens, CheckExpression(statement.Value)); + } + + private TypedNodeStatementVariableDeclaration CheckStatementVariableDeclaration(NodeStatementVariableDeclaration statement) + { + var type = CheckType(statement.Type); + var value = CheckExpression(statement.Value); + + if (type != value.Type) + throw new CompileException(Diagnostic.Error("Type of variable does match type of assigned value").At(fileName, value).Build()); + + scope.DeclareIdentifier(statement.Name.Ident, type); + + return new TypedNodeStatementVariableDeclaration(statement.Tokens, statement.Name, type, value); + } + + private TypedNodeStatementWhile CheckStatementWhile(NodeStatementWhile statement) + { + return new TypedNodeStatementWhile(statement.Tokens, CheckExpression(statement.Condition), CheckStatement(statement.Block)); + } + + private TypedNodeExpression CheckExpression(NodeExpression node) + { + return node switch + { + NodeExpressionBinary expression => CheckExpressionBinary(expression), + NodeExpressionUnary expression => CheckExpressionUnary(expression), + NodeExpressionBoolLiteral expression => CheckExpressionBoolLiteral(expression), + NodeExpressionIdent expression => CheckExpressionIdent(expression), + NodeExpressionIntLiteral expression => CheckExpressionIntLiteral(expression), + NodeExpressionMemberAccess expression => CheckExpressionMemberAccess(expression), + NodeExpressionStringLiteral expression => CheckExpressionStringLiteral(expression), + NodeExpressionStructLiteral expression => CheckExpressionStructLiteral(expression), + _ => throw new ArgumentOutOfRangeException(nameof(node)) + }; + } + + private TypedNodeExpressionBinary CheckExpressionBinary(NodeExpressionBinary expression) + { + var left = CheckExpression(expression.Left); + var right = CheckExpression(expression.Right); + NubType type; + + switch (expression.Operation) + { + case NodeExpressionBinary.Op.Add: + case NodeExpressionBinary.Op.Subtract: + case NodeExpressionBinary.Op.Multiply: + case NodeExpressionBinary.Op.Divide: + case NodeExpressionBinary.Op.Modulo: + { + if (left.Type is not NubTypeSInt and not NubTypeUInt) + throw new CompileException(Diagnostic.Error($"Unsupported type for left hand side arithmetic operation: {left.Type}").At(fileName, left).Build()); + + if (right.Type is not NubTypeSInt and not NubTypeUInt) + throw new CompileException(Diagnostic.Error($"Unsupported type for right hand side arithmetic operation: {right.Type}").At(fileName, right).Build()); + + type = left.Type; + break; + } + case NodeExpressionBinary.Op.LeftShift: + case NodeExpressionBinary.Op.RightShift: + { + if (left.Type is not NubTypeUInt) + throw new CompileException(Diagnostic.Error($"Unsupported type for left hand side of left/right shift operation: {left.Type}").At(fileName, left).Build()); + + if (right.Type is not NubTypeUInt) + throw new CompileException(Diagnostic.Error($"Unsupported type for right hand side of left/right shift operation: {right.Type}").At(fileName, right).Build()); + + type = left.Type; + break; + } + case NodeExpressionBinary.Op.Equal: + case NodeExpressionBinary.Op.NotEqual: + case NodeExpressionBinary.Op.LessThan: + case NodeExpressionBinary.Op.LessThanOrEqual: + case NodeExpressionBinary.Op.GreaterThan: + case NodeExpressionBinary.Op.GreaterThanOrEqual: + { + if (left.Type is not NubTypeSInt and not NubTypeUInt) + throw new CompileException(Diagnostic.Error($"Unsupported type for left hand side of comparison: {left.Type}").At(fileName, left).Build()); + + if (right.Type is not NubTypeSInt and not NubTypeUInt) + throw new CompileException(Diagnostic.Error($"Unsupported type for right hand side of comparison: {right.Type}").At(fileName, right).Build()); + + type = new NubTypeBool(); + break; + } + case NodeExpressionBinary.Op.LogicalAnd: + case NodeExpressionBinary.Op.LogicalOr: + { + if (left.Type is not NubTypeBool) + throw new CompileException(Diagnostic.Error($"Unsupported type for left hand side of logical operation: {left.Type}").At(fileName, left).Build()); + + if (right.Type is not NubTypeBool) + throw new CompileException(Diagnostic.Error($"Unsupported type for right hand side of logical operation: {right.Type}").At(fileName, right).Build()); + + type = new NubTypeBool(); + break; + } + default: + throw new ArgumentOutOfRangeException(); + } + + return new TypedNodeExpressionBinary(expression.Tokens, type, left, CheckExpressionBinaryOperation(expression.Operation), right); + } + + private static TypedNodeExpressionBinary.Op CheckExpressionBinaryOperation(NodeExpressionBinary.Op op) + { + return op switch + { + NodeExpressionBinary.Op.Add => TypedNodeExpressionBinary.Op.Add, + NodeExpressionBinary.Op.Subtract => TypedNodeExpressionBinary.Op.Subtract, + NodeExpressionBinary.Op.Multiply => TypedNodeExpressionBinary.Op.Multiply, + NodeExpressionBinary.Op.Divide => TypedNodeExpressionBinary.Op.Divide, + NodeExpressionBinary.Op.Modulo => TypedNodeExpressionBinary.Op.Modulo, + NodeExpressionBinary.Op.Equal => TypedNodeExpressionBinary.Op.Equal, + NodeExpressionBinary.Op.NotEqual => TypedNodeExpressionBinary.Op.NotEqual, + NodeExpressionBinary.Op.LessThan => TypedNodeExpressionBinary.Op.LessThan, + NodeExpressionBinary.Op.LessThanOrEqual => TypedNodeExpressionBinary.Op.LessThanOrEqual, + NodeExpressionBinary.Op.GreaterThan => TypedNodeExpressionBinary.Op.GreaterThan, + NodeExpressionBinary.Op.GreaterThanOrEqual => TypedNodeExpressionBinary.Op.GreaterThanOrEqual, + NodeExpressionBinary.Op.LeftShift => TypedNodeExpressionBinary.Op.LeftShift, + NodeExpressionBinary.Op.RightShift => TypedNodeExpressionBinary.Op.RightShift, + NodeExpressionBinary.Op.LogicalAnd => TypedNodeExpressionBinary.Op.LogicalAnd, + NodeExpressionBinary.Op.LogicalOr => TypedNodeExpressionBinary.Op.LogicalOr, + _ => throw new ArgumentOutOfRangeException(nameof(op), op, null) + }; + } + + private TypedNodeExpressionUnary CheckExpressionUnary(NodeExpressionUnary expression) + { + var target = CheckExpression(expression.Target); + NubType type; + + switch (expression.Operation) + { + case NodeExpressionUnary.Op.Negate: + { + if (target.Type is not NubTypeSInt and not NubTypeUInt) + throw new CompileException(Diagnostic.Error($"Unsupported type for negation: {target.Type}").At(fileName, target).Build()); + + type = target.Type; + break; + } + case NodeExpressionUnary.Op.Invert: + { + if (target.Type is not NubTypeBool) + throw new CompileException(Diagnostic.Error($"Unsupported type for inversion: {target.Type}").At(fileName, target).Build()); + + type = new NubTypeBool(); + break; + } + default: + throw new ArgumentOutOfRangeException(); + } + + return new TypedNodeExpressionUnary(expression.Tokens, type, target, CheckExpressionUnaryOperation(expression.Operation)); + } + + private static TypedNodeExpressionUnary.Op CheckExpressionUnaryOperation(NodeExpressionUnary.Op op) + { + return op switch + { + NodeExpressionUnary.Op.Negate => TypedNodeExpressionUnary.Op.Negate, + NodeExpressionUnary.Op.Invert => TypedNodeExpressionUnary.Op.Invert, + _ => throw new ArgumentOutOfRangeException(nameof(op), op, null) + }; + } + + private TypedNodeExpressionBoolLiteral CheckExpressionBoolLiteral(NodeExpressionBoolLiteral expression) + { + return new TypedNodeExpressionBoolLiteral(expression.Tokens, new NubTypeBool(), expression.Value); + } + + private TypedNodeExpressionIdent CheckExpressionIdent(NodeExpressionIdent expression) + { + var type = scope.GetIdentifierType(expression.Value.Ident); + if (type == null) + throw new CompileException(Diagnostic.Error($"Identifier '{expression.Value.Ident}' is not declared").At(fileName, expression.Value).Build()); + + return new TypedNodeExpressionIdent(expression.Tokens, type, expression.Value); + } + + private TypedNodeExpressionIntLiteral CheckExpressionIntLiteral(NodeExpressionIntLiteral expression) + { + return new TypedNodeExpressionIntLiteral(expression.Tokens, new NubTypeSInt(32), expression.Value); + } + + private TypedNodeExpressionMemberAccess CheckExpressionMemberAccess(NodeExpressionMemberAccess expression) + { + var target = CheckExpression(expression.Target); + if (target.Type is not NubTypeStruct structType) + throw new CompileException(Diagnostic.Error($"Cannot access member of non-struct type {target.Type}").At(fileName, target).Build()); + + var field = structType.Fields.FirstOrDefault(x => x.Name == expression.Name.Ident); + if (field == null) + throw new CompileException(Diagnostic.Error($"Struct {target.Type} does not have a field matching the name '{expression.Name.Ident}'").At(fileName, target).Build()); + + return new TypedNodeExpressionMemberAccess(expression.Tokens, field.Type, target, expression.Name); + } + + private TypedNodeExpressionStringLiteral CheckExpressionStringLiteral(NodeExpressionStringLiteral expression) + { + return new TypedNodeExpressionStringLiteral(expression.Tokens, new NubTypeString(), expression.Value); + } + + private TypedNodeExpressionStructLiteral CheckExpressionStructLiteral(NodeExpressionStructLiteral expression) + { + var type = structTypes.GetValueOrDefault(expression.Name.Ident); + if (type == null) + throw new CompileException(Diagnostic.Error($"Undeclared struct '{expression.Name.Ident}'").At(fileName, expression.Name).Build()); + + var initializers = new List(); + foreach (var initializer in expression.Initializers) + { + var field = type.Fields.FirstOrDefault(x => x.Name == initializer.Name.Ident); + if (field == null) + throw new CompileException(Diagnostic.Error($"Field '{initializer.Name.Ident}' does not exist on struct '{expression.Name.Ident}'").At(fileName, initializer.Name).Build()); + + var value = CheckExpression(initializer.Value); + if (value.Type != field.Type) + throw new CompileException(Diagnostic.Error($"Type of assignment ({value.Type}) does not match expected type of field '{field.Name}' ({field.Type})").At(fileName, initializer.Name).Build()); + + initializers.Add(new TypedNodeExpressionStructLiteral.Initializer(initializer.Tokens, initializer.Name, value)); + } + + return new TypedNodeExpressionStructLiteral(expression.Tokens, type, initializers); + } + + private NubType CheckType(NodeType node) + { + return node switch + { + NodeTypeBool type => new NubTypeBool(), + NodeTypeCustom type => CheckStructType(type), + NodeTypeFunc type => new NubTypeFunc(type.Parameters.Select(CheckType).ToList(), CheckType(type.ReturnType)), + NodeTypePointer type => new NubTypePointer(CheckType(type.To)), + NodeTypeSInt type => new NubTypeSInt(type.Width), + NodeTypeUInt type => new NubTypeUInt(type.Width), + NodeTypeString type => new NubTypeString(), + NodeTypeVoid type => new NubTypeVoid(), + _ => throw new ArgumentOutOfRangeException(nameof(node)) + }; + } + + private NubTypeStruct CheckStructType(NodeTypeCustom type) + { + var structType = structTypes.GetValueOrDefault(type.Name.Ident); + if (structType == null) + throw new CompileException(Diagnostic.Error($"Unknown custom type: {type}").At(fileName, type).Build()); + + return structType; + } + + private class Scope(Scope? parent) + { + private Dictionary identifiers = new(); + + public void DeclareIdentifier(string name, NubType type) + { + identifiers.Add(name, type); + } + + public NubType? GetIdentifierType(string name) + { + return identifiers.TryGetValue(name, out var type) + ? type + : parent?.GetIdentifierType(name); + } + } +} + +public sealed class TypedAst(List functions, List structTypes) +{ + public List Functions = functions; + public List StructTypes = structTypes; +} + +public abstract class TypedNode(List tokens) +{ + public readonly List Tokens = tokens; +} + +public abstract class TypedNodeDefinition(List tokens) : TypedNode(tokens); + +public sealed class TypedNodeDefinitionFunc(List tokens, TokenIdent name, List parameters, TypedNodeStatement body, NubType returnType) : TypedNodeDefinition(tokens) +{ + public readonly TokenIdent Name = name; + public readonly List Parameters = parameters; + public readonly TypedNodeStatement Body = body; + public readonly NubType ReturnType = returnType; + + public sealed class Param(List tokens, TokenIdent name, NubType type) : TypedNode(tokens) + { + public readonly TokenIdent Name = name; + public readonly NubType Type = type; + } +} + +public sealed class TypedNodeDefinitionStruct(List tokens, TokenIdent name, List fields) : TypedNodeDefinition(tokens) +{ + public readonly TokenIdent Name = name; + public readonly List Fields = fields; + + public sealed class Field(List tokens, TokenIdent name, NubType type) : TypedNode(tokens) + { + public readonly TokenIdent Name = name; + public readonly NubType Type = type; + } +} + +public abstract class TypedNodeStatement(List tokens) : TypedNode(tokens); + +public sealed class TypedNodeStatementBlock(List tokens, List statements) : TypedNodeStatement(tokens) +{ + public readonly List Statements = statements; +} + +public sealed class TypedNodeStatementFuncCall(List tokens, TypedNodeExpression target, List parameters) : TypedNodeStatement(tokens) +{ + public readonly TypedNodeExpression Target = target; + public readonly List Parameters = parameters; +} + +public sealed class TypedNodeStatementReturn(List tokens, TypedNodeExpression value) : TypedNodeStatement(tokens) +{ + public readonly TypedNodeExpression Value = value; +} + +public sealed class TypedNodeStatementVariableDeclaration(List tokens, TokenIdent name, NubType type, TypedNodeExpression value) : TypedNodeStatement(tokens) +{ + public readonly TokenIdent Name = name; + public readonly NubType Type = type; + public readonly TypedNodeExpression Value = value; +} + +public sealed class TypedNodeStatementAssignment(List tokens, TypedNodeExpression target, TypedNodeExpression value) : TypedNodeStatement(tokens) +{ + public readonly TypedNodeExpression Target = target; + public readonly TypedNodeExpression Value = value; +} + +public sealed class TypedNodeStatementIf(List tokens, TypedNodeExpression condition, TypedNodeStatement thenBlock, TypedNodeStatement? elseBlock) : TypedNodeStatement(tokens) +{ + public readonly TypedNodeExpression Condition = condition; + public readonly TypedNodeStatement ThenBlock = thenBlock; + public readonly TypedNodeStatement? ElseBlock = elseBlock; +} + +public sealed class TypedNodeStatementWhile(List tokens, TypedNodeExpression condition, TypedNodeStatement block) : TypedNodeStatement(tokens) +{ + public readonly TypedNodeExpression Condition = condition; + public readonly TypedNodeStatement Block = block; +} + +public abstract class TypedNodeExpression(List tokens, NubType type) : TypedNode(tokens) +{ + public readonly NubType Type = type; +} + +public sealed class TypedNodeExpressionIntLiteral(List tokens, NubType type, TokenIntLiteral value) : TypedNodeExpression(tokens, type) +{ + public readonly TokenIntLiteral Value = value; +} + +public sealed class TypedNodeExpressionStringLiteral(List tokens, NubType type, TokenStringLiteral value) : TypedNodeExpression(tokens, type) +{ + public readonly TokenStringLiteral Value = value; +} + +public sealed class TypedNodeExpressionBoolLiteral(List tokens, NubType type, TokenBoolLiteral value) : TypedNodeExpression(tokens, type) +{ + public readonly TokenBoolLiteral Value = value; +} + +public sealed class TypedNodeExpressionStructLiteral(List tokens, NubType type, List initializers) : TypedNodeExpression(tokens, type) +{ + public readonly List Initializers = initializers; + + public sealed class Initializer(List tokens, TokenIdent name, TypedNodeExpression value) : Node(tokens) + { + public readonly TokenIdent Name = name; + public readonly TypedNodeExpression Value = value; + } +} + +public sealed class TypedNodeExpressionMemberAccess(List tokens, NubType type, TypedNodeExpression target, TokenIdent name) : TypedNodeExpression(tokens, type) +{ + public readonly TypedNodeExpression Target = target; + public readonly TokenIdent Name = name; +} + +public sealed class TypedNodeExpressionIdent(List tokens, NubType type, TokenIdent value) : TypedNodeExpression(tokens, type) +{ + public readonly TokenIdent Value = value; +} + +public sealed class TypedNodeExpressionBinary(List tokens, NubType type, TypedNodeExpression left, TypedNodeExpressionBinary.Op operation, TypedNodeExpression right) : TypedNodeExpression(tokens, type) +{ + public readonly TypedNodeExpression Left = left; + public readonly Op Operation = operation; + public readonly TypedNodeExpression Right = right; + + public enum Op + { + Add, + Subtract, + Multiply, + Divide, + Modulo, + + Equal, + NotEqual, + LessThan, + LessThanOrEqual, + GreaterThan, + GreaterThanOrEqual, + + LeftShift, + RightShift, + + // BitwiseAnd, + // BitwiseXor, + // BitwiseOr, + + LogicalAnd, + LogicalOr, + } +} + +public sealed class TypedNodeExpressionUnary(List tokens, NubType type, TypedNodeExpression target, TypedNodeExpressionUnary.Op op) : TypedNodeExpression(tokens, type) +{ + public TypedNodeExpression Target { get; } = target; + public Op Operation { get; } = op; + + public enum Op + { + Negate, + Invert, + } +} + +public abstract class NubType : IEquatable +{ + public abstract override string ToString(); + + public abstract bool Equals(NubType? other); + + public override bool Equals(object? obj) + { + if (obj is NubType otherNubType) + { + return Equals(otherNubType); + } + + return false; + } + + public abstract override int GetHashCode(); + + public static bool operator ==(NubType? left, NubType? right) => Equals(left, right); + public static bool operator !=(NubType? left, NubType? right) => !Equals(left, right); +} + +public sealed class NubTypeVoid : NubType +{ + public override string ToString() => "void"; + + public override bool Equals(NubType? other) => other is NubTypeVoid; + public override int GetHashCode() => HashCode.Combine(typeof(NubTypeVoid)); +} + +public sealed class NubTypeUInt(int width) : NubType +{ + public readonly int Width = width; + + public override string ToString() => $"u{Width}"; + + public override bool Equals(NubType? other) => other is NubTypeUInt otherUInt && Width == otherUInt.Width; + public override int GetHashCode() => HashCode.Combine(typeof(NubTypeUInt), Width); +} + +public sealed class NubTypeSInt(int width) : NubType +{ + public readonly int Width = width; + + public override string ToString() => $"i{Width}"; + + public override bool Equals(NubType? other) => other is NubTypeSInt otherUInt && Width == otherUInt.Width; + public override int GetHashCode() => HashCode.Combine(typeof(NubTypeSInt), Width); +} + +public sealed class NubTypeBool : NubType +{ + public override string ToString() => "bool"; + + public override bool Equals(NubType? other) => other is NubTypeBool; + public override int GetHashCode() => HashCode.Combine(typeof(NubTypeBool)); +} + +public sealed class NubTypeString : NubType +{ + public override string ToString() => "string"; + + public override bool Equals(NubType? other) => other is NubTypeString; + public override int GetHashCode() => HashCode.Combine(typeof(NubTypeString)); +} + +public sealed class NubTypeStruct(List fields) : NubType +{ + public readonly List Fields = fields; + public override string ToString() => $"struct {{ {string.Join(' ', Fields.Select(x => $"{x.Name}: {x.Type}"))} }}"; + + public override bool Equals(NubType? other) + { + if (other is not NubTypeStruct structType) + return false; + + if (Fields.Count != structType.Fields.Count) + return false; + + for (var i = 0; i < Fields.Count; i++) + { + if (Fields[i].Name != structType.Fields[i].Name) + return false; + + if (Fields[i].Type != structType.Fields[i].Type) + return false; + } + + return true; + } + + public override int GetHashCode() + { + var hash = new HashCode(); + hash.Add(typeof(NubTypeStruct)); + foreach (var field in Fields) + { + hash.Add(field.Name); + hash.Add(field.Type); + } + + return hash.ToHashCode(); + } + + public sealed class Field(string name, NubType type) + { + public readonly string Name = name; + public readonly NubType Type = type; + } +} + +public sealed class NubTypePointer(NubType to) : NubType +{ + public readonly NubType To = to; + public override string ToString() => $"^{To}"; + + public override bool Equals(NubType? other) => other is NubTypePointer pointer && To == pointer.To; + public override int GetHashCode() => HashCode.Combine(typeof(NubTypePointer)); +} + +public sealed class NubTypeFunc(List parameters, NubType returnType) : NubType +{ + public readonly List Parameters = parameters; + public readonly NubType ReturnType = returnType; + public override string ToString() => $"func({string.Join(' ', Parameters)}): {ReturnType}"; + + public override bool Equals(NubType? other) => other is NubTypeFunc func && ReturnType.Equals(func.ReturnType) && Parameters.SequenceEqual(func.Parameters); + + public override int GetHashCode() + { + var hash = new HashCode(); + hash.Add(typeof(NubTypeFunc)); + hash.Add(ReturnType); + foreach (var param in Parameters) + hash.Add(param); + + return hash.ToHashCode(); + } +} \ No newline at end of file