From 89a827fdefc3322e13ef310967361bd7c32e6171 Mon Sep 17 00:00:00 2001 From: nub31 Date: Mon, 9 Feb 2026 15:52:14 +0100 Subject: [PATCH] struct resolution --- compiler/Program.cs | 22 +++++------- compiler/TypeChecker.cs | 77 ++++++++++++++++------------------------ compiler/TypeResolver.cs | 67 ++++++++++++++++++++++++++++++++++ 3 files changed, 106 insertions(+), 60 deletions(-) create mode 100644 compiler/TypeResolver.cs diff --git a/compiler/Program.cs b/compiler/Program.cs index 0b86813..84dd4c9 100644 --- a/compiler/Program.cs +++ b/compiler/Program.cs @@ -8,38 +8,34 @@ foreach (var fileName in args) var tokens = Tokenizer.Tokenize(fileName, file, out var tokenizerDiagnostics); foreach (var diagnostic in tokenizerDiagnostics) - { DiagnosticFormatter.Print(diagnostic, Console.Error); - } if (tokenizerDiagnostics.Any(x => x.Severity == DiagnosticSeverity.Error)) - { return 1; - } var ast = Parser.Parse(fileName, tokens, out var parserDiagnostics); foreach (var diagnostic in parserDiagnostics) - { DiagnosticFormatter.Print(diagnostic, Console.Error); - } if (parserDiagnostics.Any(x => x.Severity == DiagnosticSeverity.Error)) - { return 1; - } - var typedAst = TypeChecker.Check(fileName, ast, out var typeCheckerDiagnostics); + var typeResolver = TypeResolver.Create(fileName, ast, out var typeResolverDiagnostics); + + foreach (var diagnostic in typeResolverDiagnostics) + DiagnosticFormatter.Print(diagnostic, Console.Error); + + if (typeResolverDiagnostics.Any(x => x.Severity == DiagnosticSeverity.Error)) + return 1; + + var typedAst = TypeChecker.Check(fileName, ast, typeResolver, out var typeCheckerDiagnostics); foreach (var diagnostic in typeCheckerDiagnostics) - { DiagnosticFormatter.Print(diagnostic, Console.Error); - } if (typeCheckerDiagnostics.Any(x => x.Severity == DiagnosticSeverity.Error)) - { return 1; - } var output = Generator.Emit(typedAst); diff --git a/compiler/TypeChecker.cs b/compiler/TypeChecker.cs index ad565ff..3b9448f 100644 --- a/compiler/TypeChecker.cs +++ b/compiler/TypeChecker.cs @@ -1,30 +1,22 @@ namespace Compiler; -public sealed class TypeChecker(string fileName, Ast ast) +public sealed class TypeChecker(string fileName, Ast ast, TypeResolver typeResolver) { - public static TypedAst Check(string fileName, Ast ast, out List diagnostics) + public static TypedAst Check(string fileName, Ast ast, TypeResolver typeResolver, out List diagnostics) { - return new TypeChecker(fileName, ast).Check(out diagnostics); + return new TypeChecker(fileName, ast, typeResolver).Check(out diagnostics); } private Scope scope = new(null); - private Dictionary structTypes = new(); private TypedAst Check(out List diagnostics) { - var functions = new List(); diagnostics = []; - - // todo(nub31): Types must be resolved better to prevent circular dependencies and independent ordering - foreach (var structDef in ast.Structs) - { - var fields = structDef.Fields.Select(x => new NubTypeStruct.Field(x.Name.Ident, CheckType(x.Type))).ToList(); - structTypes.Add(structDef.Name.Ident, new NubTypeStruct(fields)); - } + var functions = new List(); foreach (var funcDef in ast.Functions) { - var type = new NubTypeFunc(funcDef.Parameters.Select(x => CheckType(x.Type)).ToList(), CheckType(funcDef.ReturnType)); + var type = new NubTypeFunc(funcDef.Parameters.Select(x => typeResolver.Resolve(x.Type)).ToList(), typeResolver.Resolve(funcDef.ReturnType)); scope.DeclareIdentifier(funcDef.Name.Ident, type); } @@ -40,17 +32,17 @@ public sealed class TypeChecker(string fileName, Ast ast) } } - return new TypedAst(structTypes.Values.ToList(), functions); + return new TypedAst(typeResolver.GetAllStructs(), functions); } private TypedNodeDefinitionFunc CheckDefinitionFunc(NodeDefinitionFunc definition) { - return new TypedNodeDefinitionFunc(definition.Tokens, definition.Name, definition.Parameters.Select(CheckDefinitionFuncParameter).ToList(), CheckStatement(definition.Body), CheckType(definition.ReturnType)); + return new TypedNodeDefinitionFunc(definition.Tokens, definition.Name, definition.Parameters.Select(CheckDefinitionFuncParameter).ToList(), CheckStatement(definition.Body), typeResolver.Resolve(definition.ReturnType)); } private TypedNodeDefinitionFunc.Param CheckDefinitionFuncParameter(NodeDefinitionFunc.Param node) { - return new TypedNodeDefinitionFunc.Param(node.Tokens, node.Name, CheckType(node.Type)); + return new TypedNodeDefinitionFunc.Param(node.Tokens, node.Name, typeResolver.Resolve(node.Type)); } private TypedNodeStatement CheckStatement(NodeStatement node) @@ -95,7 +87,7 @@ public sealed class TypeChecker(string fileName, Ast ast) private TypedNodeStatementVariableDeclaration CheckStatementVariableDeclaration(NodeStatementVariableDeclaration statement) { - var type = CheckType(statement.Type); + var type = typeResolver.Resolve(statement.Type); var value = CheckExpression(statement.Value); if (type != value.Type) @@ -299,7 +291,7 @@ public sealed class TypeChecker(string fileName, Ast ast) private TypedNodeExpressionStructLiteral CheckExpressionStructLiteral(NodeExpressionStructLiteral expression) { - var type = structTypes.GetValueOrDefault(expression.Name.Ident); + var type = typeResolver.GetNamedStruct(expression.Name.Ident); if (type == null) throw new CompileException(Diagnostic.Error($"Undeclared struct '{expression.Name.Ident}'").At(fileName, expression.Name).Build()); @@ -320,34 +312,9 @@ public sealed class TypeChecker(string fileName, Ast ast) return new TypedNodeExpressionStructLiteral(expression.Tokens, type, initializers); } - private NubType CheckType(NodeType node) - { - return node switch - { - NodeTypeBool type => new NubTypeBool(), - NodeTypeCustom type => CheckStructType(type), - NodeTypeFunc type => new NubTypeFunc(type.Parameters.Select(CheckType).ToList(), CheckType(type.ReturnType)), - NodeTypePointer type => new NubTypePointer(CheckType(type.To)), - NodeTypeSInt type => new NubTypeSInt(type.Width), - NodeTypeUInt type => new NubTypeUInt(type.Width), - NodeTypeString type => new NubTypeString(), - NodeTypeVoid type => new NubTypeVoid(), - _ => throw new ArgumentOutOfRangeException(nameof(node)) - }; - } - - private NubTypeStruct CheckStructType(NodeTypeCustom type) - { - var structType = structTypes.GetValueOrDefault(type.Name.Ident); - if (structType == null) - throw new CompileException(Diagnostic.Error($"Unknown custom type: {type}").At(fileName, type).Build()); - - return structType; - } - private class Scope(Scope? parent) { - private Dictionary identifiers = new(); + private readonly Dictionary identifiers = new(); public void DeclareIdentifier(string name, NubType type) { @@ -599,10 +566,26 @@ public sealed class NubTypeString : NubType public override int GetHashCode() => HashCode.Combine(typeof(NubTypeString)); } -public sealed class NubTypeStruct(List fields) : NubType +public sealed class NubTypeStruct : NubType { - public readonly List Fields = fields; - public override string ToString() => $"struct {{ {string.Join(' ', Fields.Select(x => $"{x.Name}: {x.Type}"))} }}"; + private List? _resolvedFields; + public List Fields => _resolvedFields ?? throw new InvalidOperationException(); + + public void ResolveFields(List fields) + { + if (_resolvedFields != null) + throw new InvalidOperationException($"{ToString()} already resolved"); + + _resolvedFields = fields; + } + + public override string ToString() + { + if (_resolvedFields == null) + return "struct "; + + return $"struct {{ {string.Join(' ', Fields.Select(f => $"{f.Name}: {f.Type}"))} }}"; + } public override bool Equals(NubType? other) { diff --git a/compiler/TypeResolver.cs b/compiler/TypeResolver.cs new file mode 100644 index 0000000..25c56e9 --- /dev/null +++ b/compiler/TypeResolver.cs @@ -0,0 +1,67 @@ +namespace Compiler; + +public sealed class TypeResolver(string fileName) +{ + private readonly Dictionary structTypes = []; + + public List GetAllStructs() => structTypes.Values.ToList(); + public NubTypeStruct? GetNamedStruct(string name) => structTypes.GetValueOrDefault(name); + + public static TypeResolver Create(string fileName, Ast ast, out List diagnostics) + { + diagnostics = []; + var resolver = new TypeResolver(fileName); + + foreach (var structDef in ast.Structs) + { + if (resolver.structTypes.ContainsKey(structDef.Name.Ident)) + { + diagnostics.Add(Diagnostic.Error($"Duplicate struct: {structDef.Name.Ident}").At(fileName, structDef.Name).Build()); + continue; + } + + resolver.structTypes.Add(structDef.Name.Ident, new NubTypeStruct()); + } + + foreach (var structDef in ast.Structs) + { + var structType = resolver.structTypes[structDef.Name.Ident]; + + try + { + structType.ResolveFields(structDef.Fields.Select(f => new NubTypeStruct.Field(f.Name.Ident, resolver.Resolve(f.Type))).ToList()); + } + catch (CompileException e) + { + diagnostics.Add(e.Diagnostic); + } + } + + return resolver; + } + + public NubType Resolve(NodeType node) + { + return node switch + { + NodeTypeBool type => new NubTypeBool(), + NodeTypeCustom type => ResolveStruct(type), + NodeTypeFunc type => new NubTypeFunc(type.Parameters.Select(Resolve).ToList(), Resolve(type.ReturnType)), + NodeTypePointer type => new NubTypePointer(Resolve(type.To)), + NodeTypeSInt type => new NubTypeSInt(type.Width), + NodeTypeUInt type => new NubTypeUInt(type.Width), + NodeTypeString type => new NubTypeString(), + NodeTypeVoid type => new NubTypeVoid(), + _ => throw new ArgumentOutOfRangeException(nameof(node)) + }; + } + + private NubTypeStruct ResolveStruct(NodeTypeCustom type) + { + var structType = structTypes.GetValueOrDefault(type.Name.Ident); + if (structType == null) + throw new CompileException(Diagnostic.Error($"Unknown custom type: {type}").At(fileName, type).Build()); + + return structType; + } +} \ No newline at end of file