235 lines
7.7 KiB
C#
235 lines
7.7 KiB
C#
namespace Compiler;
|
|
|
|
/// <summary>
/// Turns syntactic type nodes (<see cref="NodeType"/>) into semantic <see cref="NubType"/> instances
/// and holds the struct types declared in one source file.
/// </summary>
/// <param name="fileName">Name of the source file; attached to diagnostics raised while resolving.</param>
public sealed class TypeResolver(string fileName)
{
    // Struct types declared in the file, keyed by (module name, struct name).
    private readonly Dictionary<(string Module, string Name), NubTypeStruct> structTypes = [];

    /// <summary>Looks up the struct named <paramref name="name"/> registered under <paramref name="module"/>.</summary>
    /// <returns>The struct type, or <c>null</c> when no such struct was registered.</returns>
    public NubTypeStruct? GetNamedStruct(string module, string name) => structTypes.GetValueOrDefault((module, name));

    /// <summary>
    /// Creates a resolver for <paramref name="ast"/>. Every struct in the file is registered under the
    /// module named by the file's first module definition, then each struct's field types are resolved.
    /// </summary>
    /// <param name="fileName">Source file name, used for diagnostic locations.</param>
    /// <param name="ast">Parsed file whose module and struct definitions are read.</param>
    /// <param name="diagnostics">
    /// Receives an error when the file has no module definition, a warning for each extra module
    /// definition, an error for each duplicate struct name, and an error for each field type that
    /// cannot be resolved.
    /// </param>
    /// <returns>The resolver. A struct with any unresolvable field stays registered but unresolved.</returns>
    public static TypeResolver Create(string fileName, Ast ast, out List<Diagnostic> diagnostics)
    {
        diagnostics = [];
        var resolver = new TypeResolver(fileName);

        var moduleDefinitions = ast.Definitions.OfType<NodeDefinitionModule>().ToList();
        if (moduleDefinitions.Count == 0)
        {
            // Without a module there is no key to register structs under.
            diagnostics.Add(Diagnostic.Error($"'{fileName}' is not part of a module").At(fileName, 1, 1, 1).Build());
            return resolver;
        }

        foreach (var moduleDefinition in moduleDefinitions.Skip(1))
            diagnostics.Add(Diagnostic.Warning("Duplicate module definition").At(fileName, moduleDefinition).Build());

        var currentModule = moduleDefinitions[0].Name.Ident;

        // Pass 1: register every struct before resolving any field, so a field may refer to a struct
        // declared later in the file, or to its own struct.
        var registered = new List<(NodeDefinitionStruct Definition, NubTypeStruct Type)>();
        foreach (var structDef in ast.Definitions.OfType<NodeDefinitionStruct>())
        {
            var structType = new NubTypeStruct();
            if (!resolver.structTypes.TryAdd((currentModule, structDef.Name.Ident), structType))
            {
                diagnostics.Add(Diagnostic.Error($"Duplicate struct: {structDef.Name.Ident}").At(fileName, structDef.Name).Build());
                continue;
            }

            registered.Add((structDef, structType));
        }

        // Pass 2: resolve fields only for the definitions registered above. Looking a duplicate up by
        // name again would hand back the first definition's type and call ResolveFields on it twice.
        foreach (var (structDef, structType) in registered)
        {
            var fields = new List<NubTypeStruct.Field>();
            var allResolved = true;

            foreach (var field in structDef.Fields)
            {
                try
                {
                    fields.Add(new NubTypeStruct.Field(field.Name.Ident, resolver.Resolve(field.Type)));
                }
                catch (CompileException e)
                {
                    // Keep going so every unresolvable field is reported, not only the first one.
                    diagnostics.Add(e.Diagnostic);
                    allResolved = false;
                }
            }

            if (allResolved)
                structType.ResolveFields(fields);
        }

        return resolver;
    }

    /// <summary>Resolves a syntactic type node to its semantic type.</summary>
    /// <exception cref="CompileException">A custom type names a struct that has not been registered.</exception>
    /// <exception cref="ArgumentOutOfRangeException"><paramref name="node"/> is a node kind this resolver does not handle.</exception>
    public NubType Resolve(NodeType node) => node switch
    {
        NodeTypeBool => new NubTypeBool(),
        NodeTypeCustom custom => ResolveStruct(custom),
        NodeTypeFunc func => new NubTypeFunc(func.Parameters.Select(Resolve).ToList(), Resolve(func.ReturnType)),
        NodeTypePointer pointer => new NubTypePointer(Resolve(pointer.To)),
        NodeTypeSInt signedInt => new NubTypeSInt(signedInt.Width),
        NodeTypeUInt unsignedInt => new NubTypeUInt(unsignedInt.Width),
        NodeTypeString => new NubTypeString(),
        NodeTypeVoid => new NubTypeVoid(),
        _ => throw new ArgumentOutOfRangeException(nameof(node), node, "Unsupported type node"),
    };

    // Looks up the struct a custom type node refers to. While Create is still resolving fields, the
    // returned struct may not have its own fields resolved yet.
    private NubTypeStruct ResolveStruct(NodeTypeCustom type)
    {
        if (structTypes.TryGetValue((type.Module.Ident, type.Name.Ident), out var structType))
            return structType;

        throw new CompileException(Diagnostic.Error($"Unknown custom type: {type.Module.Ident}::{type.Name.Ident}").At(fileName, type).Build());
    }
}
|
|
|
|
/// <summary>
/// Base class of every semantic type. Subclasses supply value equality through
/// <see cref="Equals(NubType)"/> together with a matching <see cref="GetHashCode"/>;
/// the <c>==</c> and <c>!=</c> operators compare using that equality.
/// </summary>
public abstract class NubType : IEquatable<NubType>
{
    /// <summary>Compares this type with <paramref name="other"/> by value.</summary>
    public abstract bool Equals(NubType? other);

    /// <inheritdoc/>
    public abstract override int GetHashCode();

    /// <inheritdoc/>
    public abstract override string ToString();

    // Anything that is not a NubType (including null) is never equal to a type.
    public override bool Equals(object? obj) => obj is NubType other && Equals(other);

    // object.Equals short-circuits on reference identity and nulls, then defers to the virtual Equals(object).
    public static bool operator ==(NubType? left, NubType? right) => object.Equals(left, right);

    public static bool operator !=(NubType? left, NubType? right) => !(left == right);
}
|
|
|
|
/// <summary>The <c>void</c> type. It carries no state, so every instance equals every other.</summary>
public sealed class NubTypeVoid : NubType
{
    public override bool Equals(NubType? other) => other is NubTypeVoid;

    // Only the type tag is hashed: instances have nothing else to distinguish them.
    public override int GetHashCode() => HashCode.Combine(typeof(NubTypeVoid));

    public override string ToString() => "void";
}
|
|
|
|
/// <summary>An unsigned integer type, rendered as <c>u</c> followed by its width (e.g. <c>u32</c>).</summary>
/// <param name="width">The integer width; presumably a bit count — TODO confirm against the parser.</param>
public sealed class NubTypeUInt(int width) : NubType
{
    /// <summary>The width shown in the rendered name (the N in uN).</summary>
    public readonly int Width = width;

    // Two unsigned integer types are equal exactly when their widths match.
    public override bool Equals(NubType? other) => other is NubTypeUInt { Width: var otherWidth } && otherWidth == Width;

    // The type tag is mixed in so signed and unsigned types of the same width are unlikely to collide.
    public override int GetHashCode() => HashCode.Combine(typeof(NubTypeUInt), Width);

    public override string ToString() => $"u{Width}";
}
|
|
|
|
/// <summary>A signed integer type, rendered as <c>i</c> followed by its width (e.g. <c>i32</c>).</summary>
/// <param name="width">The integer width; presumably a bit count — TODO confirm against the parser.</param>
public sealed class NubTypeSInt(int width) : NubType
{
    /// <summary>The width shown in the rendered name (the N in iN).</summary>
    public readonly int Width = width;

    // Two signed integer types are equal exactly when their widths match.
    public override bool Equals(NubType? other) => other is NubTypeSInt { Width: var otherWidth } && otherWidth == Width;

    // The type tag is mixed in so signed and unsigned types of the same width are unlikely to collide.
    public override int GetHashCode() => HashCode.Combine(typeof(NubTypeSInt), Width);

    public override string ToString() => $"i{Width}";
}
|
|
|
|
/// <summary>The <c>bool</c> type. It carries no state, so every instance equals every other.</summary>
public sealed class NubTypeBool : NubType
{
    public override bool Equals(NubType? other) => other is NubTypeBool;

    // Only the type tag is hashed: instances have nothing else to distinguish them.
    public override int GetHashCode() => HashCode.Combine(typeof(NubTypeBool));

    public override string ToString() => "bool";
}
|
|
|
|
/// <summary>The <c>string</c> type. It carries no state, so every instance equals every other.</summary>
public sealed class NubTypeString : NubType
{
    public override bool Equals(NubType? other) => other is NubTypeString;

    // Only the type tag is hashed: instances have nothing else to distinguish them.
    public override int GetHashCode() => HashCode.Combine(typeof(NubTypeString));

    public override string ToString() => "string";
}
|
|
|
|
/// <summary>
/// A struct type. An instance is created before its fields are known and is completed once through
/// <see cref="ResolveFields"/>, which lets a field's type refer back to the struct itself.
/// Equality is structural: fields must match by name and type, in order.
/// </summary>
public sealed class NubTypeStruct : NubType
{
    // Structs currently being rendered by ToString on this thread; lets a self-referential
    // struct stop instead of recursing forever.
    [ThreadStatic]
    private static List<NubTypeStruct>? _printing;

    // (this, other) pairs currently being compared by Equals on this thread; lets comparison of
    // self-referential structs stop instead of recursing forever.
    [ThreadStatic]
    private static List<(NubTypeStruct Left, NubTypeStruct Right)>? _comparing;

    private List<Field>? _resolvedFields;

    /// <summary>The fields, in the order passed to <see cref="ResolveFields"/>.</summary>
    /// <exception cref="InvalidOperationException">The fields have not been resolved yet.</exception>
    public List<Field> Fields => _resolvedFields ?? throw new InvalidOperationException("Struct fields have not been resolved yet");

    /// <summary>Supplies the struct's fields. May be called only once.</summary>
    /// <exception cref="InvalidOperationException">The fields were already resolved.</exception>
    public void ResolveFields(List<Field> fields)
    {
        if (_resolvedFields != null)
            throw new InvalidOperationException($"{ToString()} already resolved");

        _resolvedFields = fields;
    }

    public override string ToString()
    {
        if (_resolvedFields == null)
            return "struct <unresolved>";

        _printing ??= [];
        // Reached this struct again while printing it: a field refers back to it (e.g. via a pointer).
        if (_printing.Exists(s => ReferenceEquals(s, this)))
            return "struct <recursive>";

        _printing.Add(this);
        try
        {
            return $"struct {{ {string.Join(' ', _resolvedFields.Select(f => $"{f.Name}: {f.Type}"))} }}";
        }
        finally
        {
            _printing.RemoveAt(_printing.Count - 1);
        }
    }

    public override bool Equals(NubType? other)
    {
        if (other is not NubTypeStruct structType)
            return false;

        if (ReferenceEquals(this, structType))
            return true;

        if (Fields.Count != structType.Fields.Count)
            return false;

        _comparing ??= [];
        // This exact pair is already being compared further up the stack: we came back around a
        // cycle of self-referential fields. Assume equal (coinductive equality); any real mismatch
        // is still detected and reported by the outer comparison.
        if (_comparing.Exists(p => ReferenceEquals(p.Left, this) && ReferenceEquals(p.Right, structType)))
            return true;

        _comparing.Add((this, structType));
        try
        {
            for (var i = 0; i < Fields.Count; i++)
            {
                if (Fields[i].Name != structType.Fields[i].Name || Fields[i].Type != structType.Fields[i].Type)
                    return false;
            }

            return true;
        }
        finally
        {
            _comparing.RemoveAt(_comparing.Count - 1);
        }
    }

    public override int GetHashCode()
    {
        // Only field names are hashed. Hashing field types can recurse forever through
        // self-referential structs (e.g. a field of type func(ThisStruct): void). Equal structs
        // always have identical field names in the same order, so this stays consistent with Equals.
        var hash = new HashCode();
        hash.Add(typeof(NubTypeStruct));
        foreach (var field in Fields)
            hash.Add(field.Name);

        return hash.ToHashCode();
    }

    /// <summary>A named struct field and its type.</summary>
    public sealed class Field(string name, NubType type)
    {
        public readonly string Name = name;

        public readonly NubType Type = type;
    }
}
|
|
|
|
/// <summary>A pointer type, rendered as <c>^</c> followed by the pointee type.</summary>
public sealed class NubTypePointer(NubType to) : NubType
{
    /// <summary>The pointee type.</summary>
    public readonly NubType To = to;

    // Two pointer types are equal when their pointee types are equal.
    public override bool Equals(NubType? other) => other is NubTypePointer { To: var otherTo } && To == otherTo;

    // NOTE(review): only the type tag is hashed, not To. This is consistent with Equals (all pointers
    // share one hash) at the cost of every pointer type colliding. Before mixing To in, confirm hashing
    // cannot loop: a self-referential struct reaches itself again through a pointer field whenever
    // struct hashing descends into field types.
    public override int GetHashCode() => HashCode.Combine(typeof(NubTypePointer));

    public override string ToString() => $"^{To}";
}
|
|
|
|
/// <summary>
/// A function type: ordered parameter types plus a return type, rendered as
/// <c>func(P1 P2 ...): R</c>.
/// </summary>
public sealed class NubTypeFunc(List<NubType> parameters, NubType returnType) : NubType
{
    /// <summary>Parameter types, in order.</summary>
    public readonly List<NubType> Parameters = parameters;

    /// <summary>The return type.</summary>
    public readonly NubType ReturnType = returnType;

    public override bool Equals(NubType? other)
    {
        if (other is not NubTypeFunc func)
            return false;

        // Return type first, then the parameter types pairwise and in order.
        return ReturnType.Equals(func.ReturnType) && Parameters.SequenceEqual(func.Parameters);
    }

    public override int GetHashCode()
    {
        var hash = new HashCode();
        hash.Add(typeof(NubTypeFunc));
        hash.Add(ReturnType);
        foreach (var parameter in Parameters)
            hash.Add(parameter);

        return hash.ToHashCode();
    }

    public override string ToString()
    {
        // Parameter types are separated by a single space.
        var parameterList = string.Join(' ', Parameters);
        return $"func({parameterList}): {ReturnType}";
    }
}