...
This commit is contained in:
@@ -1,43 +1,58 @@
|
||||
namespace Compiler;
|
||||
|
||||
public sealed class TypeChecker(string fileName, Ast ast, TypeResolver typeResolver)
|
||||
public sealed class TypeChecker(string fileName, NodeDefinitionFunc function, TypeResolver typeResolver)
|
||||
{
|
||||
public static TypedAst Check(string fileName, Ast ast, TypeResolver typeResolver, out List<Diagnostic> diagnostics)
|
||||
public static TypedNodeDefinitionFunc? CheckFunction(string fileName, NodeDefinitionFunc function, TypeResolver typeResolver, out List<Diagnostic> diagnostics)
|
||||
{
|
||||
return new TypeChecker(fileName, ast, typeResolver).Check(out diagnostics);
|
||||
return new TypeChecker(fileName, function, typeResolver).CheckFunction(out diagnostics);
|
||||
}
|
||||
|
||||
private Scope scope = new(null);
|
||||
|
||||
private TypedAst Check(out List<Diagnostic> diagnostics)
|
||||
private TypedNodeDefinitionFunc? CheckFunction(out List<Diagnostic> diagnostics)
|
||||
{
|
||||
diagnostics = [];
|
||||
var functions = new List<TypedNodeDefinitionFunc>();
|
||||
|
||||
foreach (var funcDef in ast.Functions)
|
||||
{
|
||||
var type = new NubTypeFunc(funcDef.Parameters.Select(x => typeResolver.Resolve(x.Type)).ToList(), typeResolver.Resolve(funcDef.ReturnType));
|
||||
scope.DeclareIdentifier(funcDef.Name.Ident, type);
|
||||
}
|
||||
var parameters = new List<TypedNodeDefinitionFunc.Param>();
|
||||
var invalidParameter = false;
|
||||
TypedNodeStatement? body = null;
|
||||
NubType? returnType = null;
|
||||
|
||||
foreach (var funcDef in ast.Functions)
|
||||
foreach (var parameter in function.Parameters)
|
||||
{
|
||||
try
|
||||
{
|
||||
functions.Add(CheckDefinitionFunc(funcDef));
|
||||
parameters.Add(CheckDefinitionFuncParameter(parameter));
|
||||
}
|
||||
catch (CompileException e)
|
||||
{
|
||||
diagnostics.Add(e.Diagnostic);
|
||||
invalidParameter = true;
|
||||
}
|
||||
}
|
||||
|
||||
return new TypedAst(typeResolver.GetAllStructs(), functions);
|
||||
}
|
||||
try
|
||||
{
|
||||
body = CheckStatement(function.Body);
|
||||
}
|
||||
catch (CompileException e)
|
||||
{
|
||||
diagnostics.Add(e.Diagnostic);
|
||||
}
|
||||
|
||||
private TypedNodeDefinitionFunc CheckDefinitionFunc(NodeDefinitionFunc definition)
|
||||
{
|
||||
return new TypedNodeDefinitionFunc(definition.Tokens, definition.Name, definition.Parameters.Select(CheckDefinitionFuncParameter).ToList(), CheckStatement(definition.Body), typeResolver.Resolve(definition.ReturnType));
|
||||
try
|
||||
{
|
||||
returnType = typeResolver.Resolve(function.ReturnType);
|
||||
}
|
||||
catch (CompileException e)
|
||||
{
|
||||
diagnostics.Add(e.Diagnostic);
|
||||
}
|
||||
|
||||
if (body == null || returnType == null || invalidParameter)
|
||||
return null;
|
||||
|
||||
return new TypedNodeDefinitionFunc(function.Tokens, function.Name, parameters, body, returnType);
|
||||
}
|
||||
|
||||
private TypedNodeDefinitionFunc.Param CheckDefinitionFuncParameter(NodeDefinitionFunc.Param node)
|
||||
@@ -291,9 +306,9 @@ public sealed class TypeChecker(string fileName, Ast ast, TypeResolver typeResol
|
||||
|
||||
private TypedNodeExpressionStructLiteral CheckExpressionStructLiteral(NodeExpressionStructLiteral expression)
|
||||
{
|
||||
var type = typeResolver.GetNamedStruct(expression.Name.Ident);
|
||||
var type = typeResolver.GetNamedStruct(expression.Module.Ident, expression.Name.Ident);
|
||||
if (type == null)
|
||||
throw new CompileException(Diagnostic.Error($"Undeclared struct '{expression.Name.Ident}'").At(fileName, expression.Name).Build());
|
||||
throw new CompileException(Diagnostic.Error($"Undeclared struct '{expression.Module.Ident}::{expression.Name.Ident}'").At(fileName, expression.Name).Build());
|
||||
|
||||
var initializers = new List<TypedNodeExpressionStructLiteral.Initializer>();
|
||||
foreach (var initializer in expression.Initializers)
|
||||
@@ -330,12 +345,6 @@ public sealed class TypeChecker(string fileName, Ast ast, TypeResolver typeResol
|
||||
}
|
||||
}
|
||||
|
||||
public sealed class TypedAst(List<NubTypeStruct> structTypes, List<TypedNodeDefinitionFunc> functions)
|
||||
{
|
||||
public readonly List<NubTypeStruct> StructTypes = structTypes;
|
||||
public readonly List<TypedNodeDefinitionFunc> Functions = functions;
|
||||
}
|
||||
|
||||
public abstract class TypedNode(List<Token> tokens)
|
||||
{
|
||||
public readonly List<Token> Tokens = tokens;
|
||||
@@ -498,160 +507,4 @@ public sealed class TypedNodeExpressionUnary(List<Token> tokens, NubType type, T
|
||||
Negate,
|
||||
Invert,
|
||||
}
|
||||
}
|
||||
|
||||
public abstract class NubType : IEquatable<NubType>
|
||||
{
|
||||
public abstract override string ToString();
|
||||
|
||||
public abstract bool Equals(NubType? other);
|
||||
|
||||
public override bool Equals(object? obj)
|
||||
{
|
||||
if (obj is NubType otherNubType)
|
||||
{
|
||||
return Equals(otherNubType);
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
public abstract override int GetHashCode();
|
||||
|
||||
public static bool operator ==(NubType? left, NubType? right) => Equals(left, right);
|
||||
public static bool operator !=(NubType? left, NubType? right) => !Equals(left, right);
|
||||
}
|
||||
|
||||
public sealed class NubTypeVoid : NubType
|
||||
{
|
||||
public override string ToString() => "void";
|
||||
|
||||
public override bool Equals(NubType? other) => other is NubTypeVoid;
|
||||
public override int GetHashCode() => HashCode.Combine(typeof(NubTypeVoid));
|
||||
}
|
||||
|
||||
public sealed class NubTypeUInt(int width) : NubType
|
||||
{
|
||||
public readonly int Width = width;
|
||||
|
||||
public override string ToString() => $"u{Width}";
|
||||
|
||||
public override bool Equals(NubType? other) => other is NubTypeUInt otherUInt && Width == otherUInt.Width;
|
||||
public override int GetHashCode() => HashCode.Combine(typeof(NubTypeUInt), Width);
|
||||
}
|
||||
|
||||
public sealed class NubTypeSInt(int width) : NubType
|
||||
{
|
||||
public readonly int Width = width;
|
||||
|
||||
public override string ToString() => $"i{Width}";
|
||||
|
||||
public override bool Equals(NubType? other) => other is NubTypeSInt otherUInt && Width == otherUInt.Width;
|
||||
public override int GetHashCode() => HashCode.Combine(typeof(NubTypeSInt), Width);
|
||||
}
|
||||
|
||||
public sealed class NubTypeBool : NubType
|
||||
{
|
||||
public override string ToString() => "bool";
|
||||
|
||||
public override bool Equals(NubType? other) => other is NubTypeBool;
|
||||
public override int GetHashCode() => HashCode.Combine(typeof(NubTypeBool));
|
||||
}
|
||||
|
||||
public sealed class NubTypeString : NubType
|
||||
{
|
||||
public override string ToString() => "string";
|
||||
|
||||
public override bool Equals(NubType? other) => other is NubTypeString;
|
||||
public override int GetHashCode() => HashCode.Combine(typeof(NubTypeString));
|
||||
}
|
||||
|
||||
public sealed class NubTypeStruct : NubType
|
||||
{
|
||||
private List<Field>? _resolvedFields;
|
||||
public List<Field> Fields => _resolvedFields ?? throw new InvalidOperationException();
|
||||
|
||||
public void ResolveFields(List<Field> fields)
|
||||
{
|
||||
if (_resolvedFields != null)
|
||||
throw new InvalidOperationException($"{ToString()} already resolved");
|
||||
|
||||
_resolvedFields = fields;
|
||||
}
|
||||
|
||||
public override string ToString()
|
||||
{
|
||||
if (_resolvedFields == null)
|
||||
return "struct <unresolved>";
|
||||
|
||||
return $"struct {{ {string.Join(' ', Fields.Select(f => $"{f.Name}: {f.Type}"))} }}";
|
||||
}
|
||||
|
||||
public override bool Equals(NubType? other)
|
||||
{
|
||||
if (other is not NubTypeStruct structType)
|
||||
return false;
|
||||
|
||||
if (Fields.Count != structType.Fields.Count)
|
||||
return false;
|
||||
|
||||
for (var i = 0; i < Fields.Count; i++)
|
||||
{
|
||||
if (Fields[i].Name != structType.Fields[i].Name)
|
||||
return false;
|
||||
|
||||
if (Fields[i].Type != structType.Fields[i].Type)
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
public override int GetHashCode()
|
||||
{
|
||||
var hash = new HashCode();
|
||||
hash.Add(typeof(NubTypeStruct));
|
||||
foreach (var field in Fields)
|
||||
{
|
||||
hash.Add(field.Name);
|
||||
hash.Add(field.Type);
|
||||
}
|
||||
|
||||
return hash.ToHashCode();
|
||||
}
|
||||
|
||||
public sealed class Field(string name, NubType type)
|
||||
{
|
||||
public readonly string Name = name;
|
||||
public readonly NubType Type = type;
|
||||
}
|
||||
}
|
||||
|
||||
public sealed class NubTypePointer(NubType to) : NubType
|
||||
{
|
||||
public readonly NubType To = to;
|
||||
public override string ToString() => $"^{To}";
|
||||
|
||||
public override bool Equals(NubType? other) => other is NubTypePointer pointer && To == pointer.To;
|
||||
public override int GetHashCode() => HashCode.Combine(typeof(NubTypePointer));
|
||||
}
|
||||
|
||||
public sealed class NubTypeFunc(List<NubType> parameters, NubType returnType) : NubType
|
||||
{
|
||||
public readonly List<NubType> Parameters = parameters;
|
||||
public readonly NubType ReturnType = returnType;
|
||||
public override string ToString() => $"func({string.Join(' ', Parameters)}): {ReturnType}";
|
||||
|
||||
public override bool Equals(NubType? other) => other is NubTypeFunc func && ReturnType.Equals(func.ReturnType) && Parameters.SequenceEqual(func.Parameters);
|
||||
|
||||
public override int GetHashCode()
|
||||
{
|
||||
var hash = new HashCode();
|
||||
hash.Add(typeof(NubTypeFunc));
|
||||
hash.Add(ReturnType);
|
||||
foreach (var param in Parameters)
|
||||
hash.Add(param);
|
||||
|
||||
return hash.ToHashCode();
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user