namespace Compiler;

/// <summary>
/// Resolves parsed type syntax (<see cref="NodeType"/>) into semantic <see cref="NubType"/> instances
/// and owns the struct types declared by a single source file.
/// </summary>
/// <param name="fileName">Source file name, used to locate diagnostics.</param>
public sealed class TypeResolver(string fileName)
{
    // Struct types declared in this file, keyed by (module, struct name).
    private readonly Dictionary<(string Module, string Name), NubTypeStruct> structTypes = [];

    /// <summary>Looks up a struct declared in this file, returning <c>null</c> when there is none.</summary>
    public NubTypeStruct? GetNamedStruct(string module, string name) => structTypes.GetValueOrDefault((module, name));

    /// <summary>
    /// Builds a resolver for <paramref name="ast"/>. It checks the module declaration, registers every
    /// struct definition under the file's (first) module, then resolves each struct's field types.
    /// </summary>
    /// <param name="fileName">Source file name, used to locate diagnostics.</param>
    /// <param name="ast">Parsed file whose definitions are scanned.</param>
    /// <param name="diagnostics">Receives every error and warning produced while resolving.</param>
    /// <returns>The resolver; it is still returned when diagnostics contain errors.</returns>
    public static TypeResolver Create(string fileName, Ast ast, out List<Diagnostic> diagnostics)
    {
        diagnostics = [];
        var resolver = new TypeResolver(fileName);

        // NOTE(review): generic arguments were missing in the reviewed copy; the definition node type
        // names below follow the file's Node* naming scheme — confirm against the parser.
        var moduleDefinitions = ast.Definitions.OfType<NodeDefinitionModule>().ToList();
        if (moduleDefinitions.Count == 0)
        {
            diagnostics.Add(Diagnostic.Error($"'{fileName}' is not part of a module").At(fileName, 1, 1, 1).Build());
            return resolver;
        }

        foreach (var moduleDefinition in moduleDefinitions.Skip(1))
        {
            diagnostics.Add(Diagnostic.Warning("Duplicate module definition").At(fileName, moduleDefinition).Build());
        }

        var currentModule = moduleDefinitions[0].Name.Ident;

        // Pass 1: register every struct name before resolving any field type. Structs can then refer
        // to each other (and to themselves) regardless of declaration order.
        var registered = new List<(NodeDefinitionStruct Definition, NubTypeStruct Type)>();
        foreach (var structDef in ast.Definitions.OfType<NodeDefinitionStruct>())
        {
            var structType = new NubTypeStruct();
            if (!resolver.structTypes.TryAdd((currentModule, structDef.Name.Ident), structType))
            {
                diagnostics.Add(Diagnostic.Error($"Duplicate struct: {structDef.Name.Ident}").At(fileName, structDef.Name).Build());
                continue;
            }

            registered.Add((structDef, structType));
        }

        // Pass 2: resolve fields only for the definitions that were registered. Revisiting a duplicate
        // definition would hand the already-resolved type to ResolveFields, which throws.
        foreach (var (structDef, structType) in registered)
        {
            var fields = new List<NubTypeStruct.Field>();
            foreach (var field in structDef.Fields)
            {
                try
                {
                    fields.Add(new NubTypeStruct.Field(field.Name.Ident, resolver.Resolve(field.Type)));
                }
                catch (CompileException e)
                {
                    // Report every unresolvable field, not only the first one.
                    diagnostics.Add(e.Diagnostic);
                }
            }

            // Always resolve, even after errors, so later code never observes an unresolved struct
            // (whose Fields property throws). Fields that failed to resolve are omitted.
            structType.ResolveFields(fields);
        }

        return resolver;
    }

    /// <summary>Converts a syntactic type node into its semantic type.</summary>
    /// <exception cref="CompileException">A custom type does not name a known struct.</exception>
    /// <exception cref="ArgumentOutOfRangeException"><paramref name="node"/> is an unsupported node kind.</exception>
    public NubType Resolve(NodeType node) => node switch
    {
        NodeTypeBool => new NubTypeBool(),
        NodeTypeCustom type => ResolveStruct(type),
        NodeTypeFunc type => new NubTypeFunc(type.Parameters.Select(Resolve).ToList(), Resolve(type.ReturnType)),
        NodeTypePointer type => new NubTypePointer(Resolve(type.To)),
        NodeTypeSInt type => new NubTypeSInt(type.Width),
        NodeTypeUInt type => new NubTypeUInt(type.Width),
        NodeTypeString => new NubTypeString(),
        NodeTypeVoid => new NubTypeVoid(),
        _ => throw new ArgumentOutOfRangeException(nameof(node)),
    };

    // Looks up a `module::name` reference among the structs registered by Create.
    private NubTypeStruct ResolveStruct(NodeTypeCustom type)
    {
        if (structTypes.TryGetValue((type.Module.Ident, type.Name.Ident), out var structType))
        {
            return structType;
        }

        throw new CompileException(Diagnostic.Error($"Unknown custom type: {type.Module.Ident}::{type.Name.Ident}").At(fileName, type).Build());
    }
}

/// <summary>Base class of all semantic types. Equality is structural and defined by each subclass.</summary>
public abstract class NubType : IEquatable<NubType>
{
    public abstract override string ToString();

    public abstract bool Equals(NubType? other);

    public override bool Equals(object? obj) => obj is NubType other && Equals(other);

    public abstract override int GetHashCode();

    // object.Equals(a, b) handles nulls and same-reference before calling the virtual Equals.
    public static bool operator ==(NubType? left, NubType? right) => Equals(left, right);

    public static bool operator !=(NubType? left, NubType? right) => !Equals(left, right);
}

/// <summary>The <c>void</c> type.</summary>
public sealed class NubTypeVoid : NubType
{
    public override string ToString() => "void";

    public override bool Equals(NubType? other) => other is NubTypeVoid;

    public override int GetHashCode() => HashCode.Combine(typeof(NubTypeVoid));
}

/// <summary>Unsigned integer type of a given bit width, printed as <c>u{Width}</c>.</summary>
public sealed class NubTypeUInt(int width) : NubType
{
    public readonly int Width = width;

    public override string ToString() => $"u{Width}";

    public override bool Equals(NubType? other) => other is NubTypeUInt otherUInt && Width == otherUInt.Width;

    public override int GetHashCode() => HashCode.Combine(typeof(NubTypeUInt), Width);
}

/// <summary>Signed integer type of a given bit width, printed as <c>i{Width}</c>.</summary>
public sealed class NubTypeSInt(int width) : NubType
{
    public readonly int Width = width;

    public override string ToString() => $"i{Width}";

    public override bool Equals(NubType? other) => other is NubTypeSInt otherSInt && Width == otherSInt.Width;

    public override int GetHashCode() => HashCode.Combine(typeof(NubTypeSInt), Width);
}

/// <summary>The <c>bool</c> type.</summary>
public sealed class NubTypeBool : NubType
{
    public override string ToString() => "bool";

    public override bool Equals(NubType? other) => other is NubTypeBool;

    public override int GetHashCode() => HashCode.Combine(typeof(NubTypeBool));
}

/// <summary>The <c>string</c> type.</summary>
public sealed class NubTypeString : NubType
{
    public override string ToString() => "string";

    public override bool Equals(NubType? other) => other is NubTypeString;

    public override int GetHashCode() => HashCode.Combine(typeof(NubTypeString));
}

/// <summary>
/// A struct type. It is created empty and completed exactly once via <see cref="ResolveFields"/>, so
/// self-referential and mutually recursive structs can be built. Equality is structural: same field
/// names and equal field types, in order.
/// </summary>
public sealed class NubTypeStruct : NubType
{
    // Pairs currently being compared by Equals on this thread. Re-encountering a pair means we are in
    // a cycle; assuming equality there (coinductively) lets recursive structs compare without
    // overflowing the stack.
    [ThreadStatic]
    private static List<(NubTypeStruct Left, NubTypeStruct Right)>? t_assumedEqual;

    // Structs currently being printed by ToString on this thread, used to cut printing cycles.
    [ThreadStatic]
    private static List<NubTypeStruct>? t_printing;

    private List<Field>? _resolvedFields;

    /// <summary>The struct's fields.</summary>
    /// <exception cref="InvalidOperationException">The fields have not been resolved yet.</exception>
    public List<Field> Fields => _resolvedFields ?? throw new InvalidOperationException("Struct fields have not been resolved yet");

    /// <summary>Completes this struct type with its fields. May only be called once.</summary>
    /// <exception cref="InvalidOperationException">The fields were already resolved.</exception>
    public void ResolveFields(List<Field> fields)
    {
        if (_resolvedFields != null)
        {
            throw new InvalidOperationException($"{ToString()} already resolved");
        }

        _resolvedFields = fields;
    }

    public override string ToString()
    {
        if (_resolvedFields == null)
        {
            return "struct ";
        }

        var printing = t_printing ??= [];

        // Reference comparison is required: List.Contains would call the structural Equals.
        if (printing.Exists(s => ReferenceEquals(s, this)))
        {
            return "struct { ... }";
        }

        printing.Add(this);
        try
        {
            return $"struct {{ {string.Join(' ', _resolvedFields.Select(f => $"{f.Name}: {f.Type}"))} }}";
        }
        finally
        {
            // Calls nest strictly, so this call's entry is always the last one.
            printing.RemoveAt(printing.Count - 1);
        }
    }

    public override bool Equals(NubType? other)
    {
        if (ReferenceEquals(this, other))
        {
            return true;
        }

        if (other is not NubTypeStruct structType)
        {
            return false;
        }

        var assumed = t_assumedEqual ??= [];
        if (assumed.Exists(p => ReferenceEquals(p.Left, this) && ReferenceEquals(p.Right, structType)))
        {
            return true;
        }

        if (Fields.Count != structType.Fields.Count)
        {
            return false;
        }

        assumed.Add((this, structType));
        try
        {
            for (var i = 0; i < Fields.Count; i++)
            {
                if (Fields[i].Name != structType.Fields[i].Name)
                {
                    return false;
                }

                if (Fields[i].Type != structType.Fields[i].Type)
                {
                    return false;
                }
            }

            return true;
        }
        finally
        {
            assumed.RemoveAt(assumed.Count - 1);
        }
    }

    public override int GetHashCode()
    {
        // Only the field names are hashed. Structurally equal structs share names in the same order, so
        // this agrees with Equals. Not recursing into field types keeps hashing finite for recursive
        // structs (e.g. a field of type func(): <this struct>).
        var hash = new HashCode();
        hash.Add(typeof(NubTypeStruct));
        foreach (var field in Fields)
        {
            hash.Add(field.Name);
        }

        return hash.ToHashCode();
    }

    /// <summary>A named struct member.</summary>
    public sealed class Field(string name, NubType type)
    {
        public readonly string Name = name;
        public readonly NubType Type = type;
    }
}

/// <summary>Pointer to another type, printed as <c>^{To}</c>.</summary>
public sealed class NubTypePointer(NubType to) : NubType
{
    public readonly NubType To = to;

    public override string ToString() => $"^{To}";

    public override bool Equals(NubType? other) => other is NubTypePointer pointer && To == pointer.To;

    // Hashing the pointee is safe: NubTypeStruct.GetHashCode does not recurse into field types, so this
    // cannot loop through self-referential structs.
    public override int GetHashCode() => HashCode.Combine(typeof(NubTypePointer), To);
}

/// <summary>Function type: parameter types (in order) and a return type.</summary>
public sealed class NubTypeFunc(List<NubType> parameters, NubType returnType) : NubType
{
    public readonly List<NubType> Parameters = parameters;
    public readonly NubType ReturnType = returnType;

    public override string ToString() => $"func({string.Join(' ', Parameters)}): {ReturnType}";

    public override bool Equals(NubType? other) =>
        other is NubTypeFunc func
        && ReturnType.Equals(func.ReturnType)
        && Parameters.SequenceEqual(func.Parameters);

    public override int GetHashCode()
    {
        var hash = new HashCode();
        hash.Add(typeof(NubTypeFunc));
        hash.Add(ReturnType);
        foreach (var param in Parameters)
        {
            hash.Add(param);
        }

        return hash.ToHashCode();
    }
}