diff --git a/compiler/Generator.cs b/compiler/Generator.cs index a3e22fc..c71a04e 100644 --- a/compiler/Generator.cs +++ b/compiler/Generator.cs @@ -28,7 +28,7 @@ public sealed class Generator(List functions, ModuleGra """); - foreach (var (i, structType) in moduleGraph.GetStructTypes().Index()) + foreach (var (i, structType) in moduleGraph.GetModules().SelectMany(x => x.GetStructTypes().Index())) structTypeNames[structType] = $"s{i}"; foreach (var typeName in structTypeNames) diff --git a/compiler/ModuleGraph.cs b/compiler/ModuleGraph.cs index 7b62513..1aac96d 100644 --- a/compiler/ModuleGraph.cs +++ b/compiler/ModuleGraph.cs @@ -1,38 +1,21 @@ +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; + namespace Compiler; public class ModuleGraph(Dictionary modules) { public static Builder Create() => new(); - public NubTypeStruct ResolveStruct(string moduleName, string name) + public List GetModules() { - var module = modules.GetValueOrDefault(moduleName); - if (module == null) - throw new CompileException(Diagnostic.Error($"Module '{moduleName}' not found").Build()); - - var structType = module.ResolveStruct(name); - if (structType == null) - throw new CompileException(Diagnostic.Error($"Struct '{moduleName}::{name}' not found").Build()); - - return structType; + return modules.Values.ToList(); } - public NubType ResolveIdentifier(string moduleName, string name) + public bool TryResolveModule(string moduleName, [NotNullWhen(true)] out Module? module) { - var module = modules.GetValueOrDefault(moduleName); - if (module == null) - throw new CompileException(Diagnostic.Error($"Module '{moduleName}' not found").Build()); - - var identType = module.ResolveIdentifier(name); - if (identType == null) - throw new CompileException(Diagnostic.Error($"Function '{moduleName}::{name}' not found").Build()); - - return identType; - } - - public List GetStructTypes() - { - return modules.SelectMany(x => x.Value.GetStructTypes()).ToList(); + module = modules.GetValueOrDefault(moduleName); + return module != null; } public sealed class Module @@ -45,19 +28,16 @@ public class ModuleGraph(Dictionary modules) return structTypes.Values.ToList(); } - public List GetIdentifiers() + public bool TryResolveStructType(string structName, [NotNullWhen(true)] out NubTypeStruct? structType) { - return identifierTypes.Values.ToList(); + structType = structTypes.GetValueOrDefault(structName); + return structType != null; } - public NubTypeStruct? ResolveStruct(string name) + public bool TryResolveIdentifierType(string identifierName, [NotNullWhen(true)] out NubType? identifier) { - return structTypes.GetValueOrDefault(name); - } - - public NubType? ResolveIdentifier(string name) - { - return identifierTypes.GetValueOrDefault(name); + identifier = identifierTypes.GetValueOrDefault(identifierName); + return identifier != null; } public void AddStruct(string name, NubTypeStruct structType) @@ -84,45 +64,54 @@ public class ModuleGraph(Dictionary modules) { diagnostics = []; + var astModuleCache = new Dictionary(); var modules = new Dictionary(); // First pass: Register modules foreach (var ast in asts) { var moduleDefinitions = ast.Definitions.OfType().ToList(); - var currentModule = moduleDefinitions[0].Name.Ident; - if (!modules.ContainsKey(currentModule)) - modules.Add(currentModule, new Module()); + if (moduleDefinitions.Count == 0) + diagnostics.Add(Diagnostic.Error("Missing module declaration").At(ast.FileName, 1, 1, 1).Build()); + + foreach (var extraModuleDefinition in moduleDefinitions.Skip(1)) + diagnostics.Add(Diagnostic.Warning("Duplicate module declaration will be ignored").At(ast.FileName, extraModuleDefinition).Build()); + + if (moduleDefinitions.Count >= 1) + { + var currentModule = moduleDefinitions[0].Name.Ident; + + if (!modules.ContainsKey(currentModule)) + { + var module = new Module(); + modules.Add(currentModule, module); + astModuleCache[ast] = module; + } + } } // Second pass: Register struct types without fields foreach (var ast in asts) { - var moduleDefinitions = ast.Definitions.OfType().ToList(); - var currentModule = moduleDefinitions[0].Name.Ident; - - if (!modules.TryGetValue(currentModule, out var module)) - { - module = new Module(); - modules[currentModule] = module; - } + var module = astModuleCache.GetValueOrDefault(ast); + if (module == null) continue; foreach (var structDef in ast.Definitions.OfType()) - { module.AddStruct(structDef.Name.Ident, new NubTypeStruct()); - } } // Third pass: Resolve struct fields foreach (var ast in asts) { - var moduleDefinitions = ast.Definitions.OfType().ToList(); - var module = modules[moduleDefinitions[0].Name.Ident]; + var module = astModuleCache.GetValueOrDefault(ast); + if (module == null) continue; foreach (var structDef in ast.Definitions.OfType()) { - var structType = module.ResolveStruct(structDef.Name.Ident); + if (!module.TryResolveStructType(structDef.Name.Ident, out var structType)) + throw new UnreachableException($"{nameof(structType)} should always be registered"); + var fields = structDef.Fields.Select(f => new NubTypeStruct.Field(f.Name.Ident, Resolve(f.Type))).ToList(); structType.ResolveFields(fields); } @@ -131,8 +120,8 @@ public class ModuleGraph(Dictionary modules) // Fourth pass: Register identifiers foreach (var ast in asts) { - var moduleDefinitions = ast.Definitions.OfType().ToList(); - var module = modules[moduleDefinitions[0].Name.Ident]; + var module = astModuleCache.GetValueOrDefault(ast); + if (module == null) continue; foreach (var funcDef in ast.Definitions.OfType()) { @@ -149,14 +138,14 @@ public class ModuleGraph(Dictionary modules) { return node switch { - NodeTypeBool type => new NubTypeBool(), + NodeTypeBool => new NubTypeBool(), NodeTypeCustom type => ResolveStruct(type), NodeTypeFunc type => new NubTypeFunc(type.Parameters.Select(Resolve).ToList(), Resolve(type.ReturnType)), NodeTypePointer type => new NubTypePointer(Resolve(type.To)), NodeTypeSInt type => new NubTypeSInt(type.Width), NodeTypeUInt type => new NubTypeUInt(type.Width), - NodeTypeString type => new NubTypeString(), - NodeTypeVoid type => new NubTypeVoid(), + NodeTypeString => new NubTypeString(), + NodeTypeVoid => new NubTypeVoid(), _ => throw new ArgumentOutOfRangeException(nameof(node)) }; } @@ -167,8 +156,7 @@ public class ModuleGraph(Dictionary modules) if (module == null) throw new CompileException(Diagnostic.Error($"Unknown module: {type.Module.Ident}").Build()); - var structType = module.ResolveStruct(type.Name.Ident); - if (structType == null) + if (!module.TryResolveStructType(type.Name.Ident, out var structType)) throw new CompileException(Diagnostic.Error($"Unknown custom type: {type.Module.Ident}::{type.Name.Ident}").Build()); return structType; diff --git a/compiler/TypeChecker.cs b/compiler/TypeChecker.cs index 0cf774f..567555e 100644 --- a/compiler/TypeChecker.cs +++ b/compiler/TypeChecker.cs @@ -284,7 +284,12 @@ public sealed class TypeChecker(string fileName, NodeDefinitionFunc function, Mo private TypedNodeExpressionModuleIdent CheckExpressionModuleIdent(NodeExpressionModuleIdent expression) { - var identifierType = moduleGraph.ResolveIdentifier(expression.Module.Ident, expression.Value.Ident); + if (!moduleGraph.TryResolveModule(expression.Module.Ident, out var module)) + throw new CompileException(Diagnostic.Error($"Module '{expression.Module.Ident}' not found").At(fileName, expression.Module).Build()); + + if (!module.TryResolveIdentifierType(expression.Value.Ident, out var identifierType)) + throw new CompileException(Diagnostic.Error($"Identifier '{expression.Module.Ident}::{expression.Value.Ident}' not found").At(fileName, expression.Value).Build()); + return new TypedNodeExpressionModuleIdent(expression.Tokens, identifierType, expression.Module, expression.Value); } @@ -301,7 +306,7 @@ public sealed class TypeChecker(string fileName, NodeDefinitionFunc function, Mo var field = structType.Fields.FirstOrDefault(x => x.Name == expression.Name.Ident); if (field == null) - throw new CompileException(Diagnostic.Error($"Struct {target.Type} does not have a field matching the name '{expression.Name.Ident}'").At(fileName, target).Build()); + throw new CompileException(Diagnostic.Error($"Struct '{target.Type}' does not have a field matching the name '{expression.Name.Ident}'").At(fileName, target).Build()); return new TypedNodeExpressionMemberAccess(expression.Tokens, field.Type, target, expression.Name); } @@ -313,14 +318,16 @@ public sealed class TypeChecker(string fileName, NodeDefinitionFunc function, Mo private TypedNodeExpressionStructLiteral CheckExpressionStructLiteral(NodeExpressionStructLiteral expression) { - var type = moduleGraph.ResolveStruct(expression.Module.Ident, expression.Name.Ident); - if (type == null) - throw new CompileException(Diagnostic.Error($"Undeclared struct '{expression.Module.Ident}::{expression.Name.Ident}'").At(fileName, expression.Name).Build()); + if (!moduleGraph.TryResolveModule(expression.Module.Ident, out var module)) + throw new CompileException(Diagnostic.Error($"Module '{expression.Module.Ident}' not found").At(fileName, expression.Module).Build()); + + if (!module.TryResolveStructType(expression.Name.Ident, out var structType)) + throw new CompileException(Diagnostic.Error($"Struct '{expression.Module.Ident}::{expression.Name.Ident}' not found").At(fileName, expression.Name).Build()); var initializers = new List(); foreach (var initializer in expression.Initializers) { - var field = type.Fields.FirstOrDefault(x => x.Name == initializer.Name.Ident); + var field = structType.Fields.FirstOrDefault(x => x.Name == initializer.Name.Ident); if (field == null) throw new CompileException(Diagnostic.Error($"Field '{initializer.Name.Ident}' does not exist on struct '{expression.Name.Ident}'").At(fileName, initializer.Name).Build()); @@ -331,25 +338,36 @@ public sealed class TypeChecker(string fileName, NodeDefinitionFunc function, Mo initializers.Add(new TypedNodeExpressionStructLiteral.Initializer(initializer.Tokens, initializer.Name, value)); } - return new TypedNodeExpressionStructLiteral(expression.Tokens, type, initializers); + return new TypedNodeExpressionStructLiteral(expression.Tokens, structType, initializers); } private NubType ResolveType(NodeType node) { return node switch { - NodeTypeBool type => new NubTypeBool(), - NodeTypeCustom type => moduleGraph.ResolveStruct(type.Module.Ident, type.Name.Ident), + NodeTypeBool => new NubTypeBool(), + NodeTypeCustom type => ResolveCustomType(type), NodeTypeFunc type => new NubTypeFunc(type.Parameters.Select(ResolveType).ToList(), ResolveType(type.ReturnType)), NodeTypePointer type => new NubTypePointer(ResolveType(type.To)), NodeTypeSInt type => new NubTypeSInt(type.Width), NodeTypeUInt type => new NubTypeUInt(type.Width), - NodeTypeString type => new NubTypeString(), - NodeTypeVoid type => new NubTypeVoid(), + NodeTypeString => new NubTypeString(), + NodeTypeVoid => new NubTypeVoid(), _ => throw new ArgumentOutOfRangeException(nameof(node)) }; } + private NubTypeStruct ResolveCustomType(NodeTypeCustom type) + { + if (!moduleGraph.TryResolveModule(type.Module.Ident, out var module)) + throw new CompileException(Diagnostic.Error($"Module '{type.Module.Ident}' not found").At(fileName, type.Module).Build()); + + if (!module.TryResolveStructType(type.Module.Ident, out var structType)) + throw new CompileException(Diagnostic.Error($"Custom type '{type.Module.Ident}::{type.Name.Ident}' not found").At(fileName, type.Name).Build()); + + return structType; + } + private class Scope(Scope? parent) { private readonly Dictionary identifiers = new();