diff --git a/compiler/ModuleGraph.cs b/compiler/ModuleGraph.cs index 0dc3498..1704940 100644 --- a/compiler/ModuleGraph.cs +++ b/compiler/ModuleGraph.cs @@ -18,8 +18,9 @@ public class ModuleGraph(Dictionary modules) return module != null; } - public sealed class Module + public sealed class Module(string name) { + public string Name { get; } = name; private readonly Dictionary customTypes = new(); private readonly Dictionary identifierTypes = new(); @@ -84,7 +85,7 @@ public class ModuleGraph(Dictionary modules) if (!modules.ContainsKey(currentModule)) { - var module = new Module(); + var module = new Module(currentModule); modules.Add(currentModule, module); astModuleCache[ast] = module; } @@ -98,7 +99,7 @@ public class ModuleGraph(Dictionary modules) if (module == null) continue; foreach (var structDef in ast.Definitions.OfType()) - module.AddCustomType(structDef.Name.Ident, new NubTypeStruct()); + module.AddCustomType(structDef.Name.Ident, new NubTypeStruct(module.Name, structDef.Name.Ident)); } // Third pass: Resolve struct fields @@ -127,7 +128,7 @@ public class ModuleGraph(Dictionary modules) { var parameters = funcDef.Parameters.Select(x => ResolveType(x.Type)).ToList(); var returnType = ResolveType(funcDef.ReturnType); - var funcType = new NubTypeFunc(parameters, returnType); + var funcType = NubTypeFunc.Get(parameters, returnType); module.AddIdentifier(funcDef.Name.Ident, funcType); } } @@ -138,14 +139,14 @@ public class ModuleGraph(Dictionary modules) { return node switch { - NodeTypeBool => new NubTypeBool(), + NodeTypeBool => NubTypeBool.Instance, NodeTypeCustom type => ResolveCustomType(type), - NodeTypeFunc type => new NubTypeFunc(type.Parameters.Select(ResolveType).ToList(), ResolveType(type.ReturnType)), - NodeTypePointer type => new NubTypePointer(ResolveType(type.To)), - NodeTypeSInt type => new NubTypeSInt(type.Width), - NodeTypeUInt type => new NubTypeUInt(type.Width), - NodeTypeString => new NubTypeString(), - NodeTypeVoid => new NubTypeVoid(), + NodeTypeFunc type => NubTypeFunc.Get(type.Parameters.Select(ResolveType).ToList(), ResolveType(type.ReturnType)), + NodeTypePointer type => NubTypePointer.Get(ResolveType(type.To)), + NodeTypeSInt type => NubTypeSInt.Get(type.Width), + NodeTypeUInt type => NubTypeUInt.Get(type.Width), + NodeTypeString => NubTypeString.Instance, + NodeTypeVoid => NubTypeVoid.Instance, _ => throw new ArgumentOutOfRangeException(nameof(node)) }; } diff --git a/compiler/NubType.cs b/compiler/NubType.cs index 42a9fce..6a77525 100644 --- a/compiler/NubType.cs +++ b/compiler/NubType.cs @@ -1,157 +1,165 @@ namespace Compiler; -public abstract class NubType : IEquatable +public abstract class NubType { public abstract override string ToString(); - - public abstract bool Equals(NubType? other); - - public override bool Equals(object? obj) - { - if (obj is NubType otherNubType) - { - return Equals(otherNubType); - } - - return false; - } - - public abstract override int GetHashCode(); - - public static bool operator ==(NubType? left, NubType? right) => Equals(left, right); - public static bool operator !=(NubType? left, NubType? right) => !Equals(left, right); } public sealed class NubTypeVoid : NubType { - public override string ToString() => "void"; + public static readonly NubTypeVoid Instance = new(); - public override bool Equals(NubType? other) => other is NubTypeVoid; - public override int GetHashCode() => HashCode.Combine(typeof(NubTypeVoid)); + private NubTypeVoid() + { + } + + public override string ToString() => "void"; } -public sealed class NubTypeUInt(int width) : NubType +public sealed class NubTypeUInt : NubType { - public readonly int Width = width; + private static readonly Dictionary Cache = new(); + + public static NubTypeUInt Get(int width) + { + if (!Cache.TryGetValue(width, out var type)) + Cache[width] = type = new NubTypeUInt(width); + + return type; + } + + public int Width { get; } + + private NubTypeUInt(int width) + { + Width = width; + } public override string ToString() => $"u{Width}"; - - public override bool Equals(NubType? other) => other is NubTypeUInt otherUInt && Width == otherUInt.Width; - public override int GetHashCode() => HashCode.Combine(typeof(NubTypeUInt), Width); } -public sealed class NubTypeSInt(int width) : NubType +public sealed class NubTypeSInt : NubType { - public readonly int Width = width; + private static readonly Dictionary Cache = new(); + + public static NubTypeSInt Get(int width) + { + if (!Cache.TryGetValue(width, out var type)) + Cache[width] = type = new NubTypeSInt(width); + + return type; + } + + public int Width { get; } + + private NubTypeSInt(int width) + { + Width = width; + } public override string ToString() => $"i{Width}"; - - public override bool Equals(NubType? other) => other is NubTypeSInt otherUInt && Width == otherUInt.Width; - public override int GetHashCode() => HashCode.Combine(typeof(NubTypeSInt), Width); } public sealed class NubTypeBool : NubType { - public override string ToString() => "bool"; + public static readonly NubTypeBool Instance = new(); - public override bool Equals(NubType? other) => other is NubTypeBool; - public override int GetHashCode() => HashCode.Combine(typeof(NubTypeBool)); + private NubTypeBool() + { + } + + public override string ToString() => "bool"; } public sealed class NubTypeString : NubType { - public override string ToString() => "string"; + public static readonly NubTypeString Instance = new(); - public override bool Equals(NubType? other) => other is NubTypeString; - public override int GetHashCode() => HashCode.Combine(typeof(NubTypeString)); + private NubTypeString() + { + } + + public override string ToString() => "string"; } public sealed class NubTypeStruct : NubType { - private List? _resolvedFields; - public List Fields => _resolvedFields ?? throw new InvalidOperationException(); + public string Name { get; } + public string Module { get; } - public void ResolveFields(List fields) + private IReadOnlyList? _resolvedFields; + + public IReadOnlyList Fields => _resolvedFields ?? throw new InvalidOperationException(); + + public NubTypeStruct(string module, string name) + { + Module = module; + Name = name; + } + + public void ResolveFields(IReadOnlyList fields) { if (_resolvedFields != null) - throw new InvalidOperationException($"{ToString()} already resolved"); + throw new InvalidOperationException($"{Name} already resolved"); _resolvedFields = fields; } - public override string ToString() - { - if (_resolvedFields == null) - return "struct "; - - return $"struct {{ {string.Join(' ', Fields.Select(f => $"{f.Name}: {f.Type}"))} }}"; - } - - public override bool Equals(NubType? other) - { - if (other is not NubTypeStruct structType) - return false; - - if (Fields.Count != structType.Fields.Count) - return false; - - for (var i = 0; i < Fields.Count; i++) - { - if (Fields[i].Name != structType.Fields[i].Name) - return false; - - if (Fields[i].Type != structType.Fields[i].Type) - return false; - } - - return true; - } - - public override int GetHashCode() - { - var hash = new HashCode(); - hash.Add(typeof(NubTypeStruct)); - foreach (var field in Fields) - { - hash.Add(field.Name); - hash.Add(field.Type); - } - - return hash.ToHashCode(); - } + public override string ToString() => _resolvedFields == null ? $"struct {Module}::{Name} " : $"struct {Module}::{Name} {{ {string.Join(' ', Fields.Select(f => $"{f.Name}: {f.Type}"))} }}"; public sealed class Field(string name, NubType type) { - public readonly string Name = name; - public readonly NubType Type = type; + public string Name { get; } = name; + public NubType Type { get; } = type; } } -public sealed class NubTypePointer(NubType to) : NubType +public sealed class NubTypePointer : NubType { - public readonly NubType To = to; - public override string ToString() => $"^{To}"; + private static readonly Dictionary Cache = new(); - public override bool Equals(NubType? other) => other is NubTypePointer pointer && To == pointer.To; - public override int GetHashCode() => HashCode.Combine(typeof(NubTypePointer)); + public static NubTypePointer Get(NubType to) + { + if (!Cache.TryGetValue(to, out var ptr)) + Cache[to] = ptr = new NubTypePointer(to); + + return ptr; + } + + public NubType To { get; } + + private NubTypePointer(NubType to) + { + To = to; + } + + public override string ToString() => $"^{To}"; } -public sealed class NubTypeFunc(List parameters, NubType returnType) : NubType +public sealed class NubTypeFunc : NubType { - public readonly List Parameters = parameters; - public readonly NubType ReturnType = returnType; + private static readonly Dictionary Cache = new(); + + public static NubTypeFunc Get(List parameters, NubType returnType) + { + var sig = new Signature(parameters, returnType); + + if (!Cache.TryGetValue(sig, out var func)) + Cache[sig] = func = new NubTypeFunc(parameters, returnType); + + return func; + } + + public IReadOnlyList Parameters { get; } + public NubType ReturnType { get; } + + private NubTypeFunc(List parameters, NubType returnType) + { + Parameters = parameters; + ReturnType = returnType; + } + public override string ToString() => $"func({string.Join(' ', Parameters)}): {ReturnType}"; - public override bool Equals(NubType? other) => other is NubTypeFunc func && ReturnType.Equals(func.ReturnType) && Parameters.SequenceEqual(func.Parameters); - - public override int GetHashCode() - { - var hash = new HashCode(); - hash.Add(typeof(NubTypeFunc)); - hash.Add(ReturnType); - foreach (var param in Parameters) - hash.Add(param); - - return hash.ToHashCode(); - } + private readonly record struct Signature(IReadOnlyList Parameters, NubType ReturnType); } \ No newline at end of file diff --git a/compiler/TypeChecker.cs b/compiler/TypeChecker.cs index 0247c41..4c63907 100644 --- a/compiler/TypeChecker.cs +++ b/compiler/TypeChecker.cs @@ -183,7 +183,7 @@ public sealed class TypeChecker(string fileName, NodeDefinitionFunc function, Mo if (right.Type is not NubTypeSInt and not NubTypeUInt) throw new CompileException(Diagnostic.Error($"Unsupported type for right hand side of comparison: {right.Type}").At(fileName, right).Build()); - type = new NubTypeBool(); + type = NubTypeBool.Instance; break; } case NodeExpressionBinary.Op.LogicalAnd: @@ -195,7 +195,7 @@ public sealed class TypeChecker(string fileName, NodeDefinitionFunc function, Mo if (right.Type is not NubTypeBool) throw new CompileException(Diagnostic.Error($"Unsupported type for right hand side of logical operation: {right.Type}").At(fileName, right).Build()); - type = new NubTypeBool(); + type = NubTypeBool.Instance; break; } default: @@ -248,7 +248,7 @@ public sealed class TypeChecker(string fileName, NodeDefinitionFunc function, Mo if (target.Type is not NubTypeBool) throw new CompileException(Diagnostic.Error($"Unsupported type for inversion: {target.Type}").At(fileName, target).Build()); - type = new NubTypeBool(); + type = NubTypeBool.Instance; break; } default: @@ -270,7 +270,7 @@ public sealed class TypeChecker(string fileName, NodeDefinitionFunc function, Mo private TypedNodeExpressionBoolLiteral CheckExpressionBoolLiteral(NodeExpressionBoolLiteral expression) { - return new TypedNodeExpressionBoolLiteral(expression.Tokens, new NubTypeBool(), expression.Value); + return new TypedNodeExpressionBoolLiteral(expression.Tokens, NubTypeBool.Instance, expression.Value); } private TypedNodeExpressionLocalIdent CheckExpressionIdent(NodeExpressionLocalIdent expression) @@ -295,7 +295,7 @@ public sealed class TypeChecker(string fileName, NodeDefinitionFunc function, Mo private TypedNodeExpressionIntLiteral CheckExpressionIntLiteral(NodeExpressionIntLiteral expression) { - return new TypedNodeExpressionIntLiteral(expression.Tokens, new NubTypeSInt(32), expression.Value); + return new TypedNodeExpressionIntLiteral(expression.Tokens, NubTypeSInt.Get(32), expression.Value); } private TypedNodeExpressionMemberAccess CheckExpressionMemberAccess(NodeExpressionMemberAccess expression) @@ -313,7 +313,7 @@ public sealed class TypeChecker(string fileName, NodeDefinitionFunc function, Mo private TypedNodeExpressionStringLiteral CheckExpressionStringLiteral(NodeExpressionStringLiteral expression) { - return new TypedNodeExpressionStringLiteral(expression.Tokens, new NubTypeString(), expression.Value); + return new TypedNodeExpressionStringLiteral(expression.Tokens, NubTypeString.Instance, expression.Value); } private TypedNodeExpressionStructLiteral CheckExpressionStructLiteral(NodeExpressionStructLiteral expression) @@ -348,14 +348,14 @@ public sealed class TypeChecker(string fileName, NodeDefinitionFunc function, Mo { return node switch { - NodeTypeBool => new NubTypeBool(), + NodeTypeBool => NubTypeBool.Instance, NodeTypeCustom type => ResolveCustomType(type), - NodeTypeFunc type => new NubTypeFunc(type.Parameters.Select(ResolveType).ToList(), ResolveType(type.ReturnType)), - NodeTypePointer type => new NubTypePointer(ResolveType(type.To)), - NodeTypeSInt type => new NubTypeSInt(type.Width), - NodeTypeUInt type => new NubTypeUInt(type.Width), - NodeTypeString => new NubTypeString(), - NodeTypeVoid => new NubTypeVoid(), + NodeTypeFunc type => NubTypeFunc.Get(type.Parameters.Select(ResolveType).ToList(), ResolveType(type.ReturnType)), + NodeTypePointer type => NubTypePointer.Get(ResolveType(type.To)), + NodeTypeSInt type => NubTypeSInt.Get(type.Width), + NodeTypeUInt type => NubTypeUInt.Get(type.Width), + NodeTypeString => NubTypeString.Instance, + NodeTypeVoid => NubTypeVoid.Instance, _ => throw new ArgumentOutOfRangeException(nameof(node)) }; } @@ -365,7 +365,7 @@ public sealed class TypeChecker(string fileName, NodeDefinitionFunc function, Mo if (!moduleGraph.TryResolveModule(type.Module.Ident, out var module)) throw new CompileException(Diagnostic.Error($"Module '{type.Module.Ident}' not found").At(fileName, type.Module).Build()); - if (!module.TryResolveCustomType(type.Module.Ident, out var customType)) + if (!module.TryResolveCustomType(type.Name.Ident, out var customType)) throw new CompileException(Diagnostic.Error($"Custom type '{type.Module.Ident}::{type.Name.Ident}' not found").At(fileName, type.Name).Build()); return customType;