diff --git a/compiler/Generator.cs b/compiler/Generator.cs index c71a04e..b6c0515 100644 --- a/compiler/Generator.cs +++ b/compiler/Generator.cs @@ -28,7 +28,7 @@ public sealed class Generator(List functions, ModuleGra """); - foreach (var (i, structType) in moduleGraph.GetModules().SelectMany(x => x.GetStructTypes().Index())) + foreach (var (i, structType) in moduleGraph.GetModules().SelectMany(x => x.GetCustomTypes().OfType().Index())) structTypeNames[structType] = $"s{i}"; foreach (var typeName in structTypeNames) diff --git a/compiler/ModuleGraph.cs b/compiler/ModuleGraph.cs index 1aac96d..0dc3498 100644 --- a/compiler/ModuleGraph.cs +++ b/compiler/ModuleGraph.cs @@ -20,29 +20,29 @@ public class ModuleGraph(Dictionary modules) public sealed class Module { - private readonly Dictionary structTypes = new(); + private readonly Dictionary customTypes = new(); private readonly Dictionary identifierTypes = new(); - public List GetStructTypes() + public List GetCustomTypes() { - return structTypes.Values.ToList(); + return customTypes.Values.ToList(); } - public bool TryResolveStructType(string structName, [NotNullWhen(true)] out NubTypeStruct? structType) + public bool TryResolveCustomType(string name, [NotNullWhen(true)] out NubType? customType) { - structType = structTypes.GetValueOrDefault(structName); - return structType != null; + customType = customTypes.GetValueOrDefault(name); + return customType != null; } - public bool TryResolveIdentifierType(string identifierName, [NotNullWhen(true)] out NubType? identifier) + public bool TryResolveIdentifierType(string name, [NotNullWhen(true)] out NubType? identifier) { - identifier = identifierTypes.GetValueOrDefault(identifierName); + identifier = identifierTypes.GetValueOrDefault(name); return identifier != null; } - public void AddStruct(string name, NubTypeStruct structType) + public void AddCustomType(string name, NubType type) { - structTypes.Add(name, structType); + customTypes.Add(name, type); } public void AddIdentifier(string name, NubType identifier) @@ -98,7 +98,7 @@ public class ModuleGraph(Dictionary modules) if (module == null) continue; foreach (var structDef in ast.Definitions.OfType()) - module.AddStruct(structDef.Name.Ident, new NubTypeStruct()); + module.AddCustomType(structDef.Name.Ident, new NubTypeStruct()); } // Third pass: Resolve struct fields @@ -109,11 +109,11 @@ public class ModuleGraph(Dictionary modules) foreach (var structDef in ast.Definitions.OfType()) { - if (!module.TryResolveStructType(structDef.Name.Ident, out var structType)) - throw new UnreachableException($"{nameof(structType)} should always be registered"); + if (!module.TryResolveCustomType(structDef.Name.Ident, out var customType)) + throw new UnreachableException($"{nameof(customType)} should always be registered"); - var fields = structDef.Fields.Select(f => new NubTypeStruct.Field(f.Name.Ident, Resolve(f.Type))).ToList(); - structType.ResolveFields(fields); + var fields = structDef.Fields.Select(f => new NubTypeStruct.Field(f.Name.Ident, ResolveType(f.Type))).ToList(); + ((NubTypeStruct)customType).ResolveFields(fields); } } @@ -125,8 +125,8 @@ public class ModuleGraph(Dictionary modules) foreach (var funcDef in ast.Definitions.OfType()) { - var parameters = funcDef.Parameters.Select(x => Resolve(x.Type)).ToList(); - var returnType = Resolve(funcDef.ReturnType); + var parameters = funcDef.Parameters.Select(x => ResolveType(x.Type)).ToList(); + var returnType = ResolveType(funcDef.ReturnType); var funcType = new NubTypeFunc(parameters, returnType); module.AddIdentifier(funcDef.Name.Ident, funcType); } @@ -134,14 +134,14 @@ public class ModuleGraph(Dictionary modules) return new ModuleGraph(modules); - NubType Resolve(NodeType node) + NubType ResolveType(NodeType node) { return node switch { NodeTypeBool => new NubTypeBool(), - NodeTypeCustom type => ResolveStruct(type), - NodeTypeFunc type => new NubTypeFunc(type.Parameters.Select(Resolve).ToList(), Resolve(type.ReturnType)), - NodeTypePointer type => new NubTypePointer(Resolve(type.To)), + NodeTypeCustom type => ResolveCustomType(type), + NodeTypeFunc type => new NubTypeFunc(type.Parameters.Select(ResolveType).ToList(), ResolveType(type.ReturnType)), + NodeTypePointer type => new NubTypePointer(ResolveType(type.To)), NodeTypeSInt type => new NubTypeSInt(type.Width), NodeTypeUInt type => new NubTypeUInt(type.Width), NodeTypeString => new NubTypeString(), @@ -150,16 +150,16 @@ public class ModuleGraph(Dictionary modules) }; } - NubTypeStruct ResolveStruct(NodeTypeCustom type) + NubType ResolveCustomType(NodeTypeCustom type) { var module = modules.GetValueOrDefault(type.Module.Ident); if (module == null) throw new CompileException(Diagnostic.Error($"Unknown module: {type.Module.Ident}").Build()); - if (!module.TryResolveStructType(type.Name.Ident, out var structType)) + if (!module.TryResolveCustomType(type.Name.Ident, out var customType)) throw new CompileException(Diagnostic.Error($"Unknown custom type: {type.Module.Ident}::{type.Name.Ident}").Build()); - return structType; + return customType; } } } diff --git a/compiler/TypeChecker.cs b/compiler/TypeChecker.cs index 567555e..0247c41 100644 --- a/compiler/TypeChecker.cs +++ b/compiler/TypeChecker.cs @@ -321,9 +321,12 @@ public sealed class TypeChecker(string fileName, NodeDefinitionFunc function, Mo if (!moduleGraph.TryResolveModule(expression.Module.Ident, out var module)) throw new CompileException(Diagnostic.Error($"Module '{expression.Module.Ident}' not found").At(fileName, expression.Module).Build()); - if (!module.TryResolveStructType(expression.Name.Ident, out var structType)) + if (!module.TryResolveCustomType(expression.Name.Ident, out var customType)) throw new CompileException(Diagnostic.Error($"Struct '{expression.Module.Ident}::{expression.Name.Ident}' not found").At(fileName, expression.Name).Build()); + if (customType is not NubTypeStruct structType) + throw new CompileException(Diagnostic.Error($"Cannot create struct literal of non-struct type '{expression.Module.Ident}::{expression.Name.Ident}'").At(fileName, expression.Name).Build()); + var initializers = new List(); foreach (var initializer in expression.Initializers) { @@ -357,15 +360,15 @@ public sealed class TypeChecker(string fileName, NodeDefinitionFunc function, Mo }; } - private NubTypeStruct ResolveCustomType(NodeTypeCustom type) + private NubType ResolveCustomType(NodeTypeCustom type) { if (!moduleGraph.TryResolveModule(type.Module.Ident, out var module)) throw new CompileException(Diagnostic.Error($"Module '{type.Module.Ident}' not found").At(fileName, type.Module).Build()); - if (!module.TryResolveStructType(type.Module.Ident, out var structType)) + if (!module.TryResolveCustomType(type.Module.Ident, out var customType)) throw new CompileException(Diagnostic.Error($"Custom type '{type.Module.Ident}::{type.Name.Ident}' not found").At(fileName, type.Name).Build()); - return structType; + return customType; } private class Scope(Scope? parent)