struct resolution

This commit is contained in:
nub31
2026-02-09 15:52:14 +01:00
parent f035499ba7
commit 89a827fdef
3 changed files with 106 additions and 60 deletions

View File

@@ -1,30 +1,22 @@
namespace Compiler;
public sealed class TypeChecker(string fileName, Ast ast)
public sealed class TypeChecker(string fileName, Ast ast, TypeResolver typeResolver)
{
public static TypedAst Check(string fileName, Ast ast, out List<Diagnostic> diagnostics)
public static TypedAst Check(string fileName, Ast ast, TypeResolver typeResolver, out List<Diagnostic> diagnostics)
{
return new TypeChecker(fileName, ast).Check(out diagnostics);
return new TypeChecker(fileName, ast, typeResolver).Check(out diagnostics);
}
private Scope scope = new(null);
private Dictionary<string, NubTypeStruct> structTypes = new();
private TypedAst Check(out List<Diagnostic> diagnostics)
{
var functions = new List<TypedNodeDefinitionFunc>();
diagnostics = [];
// todo(nub31): Types must be resolved better to prevent circular dependencies and independent ordering
foreach (var structDef in ast.Structs)
{
var fields = structDef.Fields.Select(x => new NubTypeStruct.Field(x.Name.Ident, CheckType(x.Type))).ToList();
structTypes.Add(structDef.Name.Ident, new NubTypeStruct(fields));
}
var functions = new List<TypedNodeDefinitionFunc>();
foreach (var funcDef in ast.Functions)
{
var type = new NubTypeFunc(funcDef.Parameters.Select(x => CheckType(x.Type)).ToList(), CheckType(funcDef.ReturnType));
var type = new NubTypeFunc(funcDef.Parameters.Select(x => typeResolver.Resolve(x.Type)).ToList(), typeResolver.Resolve(funcDef.ReturnType));
scope.DeclareIdentifier(funcDef.Name.Ident, type);
}
@@ -40,17 +32,17 @@ public sealed class TypeChecker(string fileName, Ast ast)
}
}
return new TypedAst(structTypes.Values.ToList(), functions);
return new TypedAst(typeResolver.GetAllStructs(), functions);
}
private TypedNodeDefinitionFunc CheckDefinitionFunc(NodeDefinitionFunc definition)
{
return new TypedNodeDefinitionFunc(definition.Tokens, definition.Name, definition.Parameters.Select(CheckDefinitionFuncParameter).ToList(), CheckStatement(definition.Body), CheckType(definition.ReturnType));
return new TypedNodeDefinitionFunc(definition.Tokens, definition.Name, definition.Parameters.Select(CheckDefinitionFuncParameter).ToList(), CheckStatement(definition.Body), typeResolver.Resolve(definition.ReturnType));
}
private TypedNodeDefinitionFunc.Param CheckDefinitionFuncParameter(NodeDefinitionFunc.Param node)
{
return new TypedNodeDefinitionFunc.Param(node.Tokens, node.Name, CheckType(node.Type));
return new TypedNodeDefinitionFunc.Param(node.Tokens, node.Name, typeResolver.Resolve(node.Type));
}
private TypedNodeStatement CheckStatement(NodeStatement node)
@@ -95,7 +87,7 @@ public sealed class TypeChecker(string fileName, Ast ast)
private TypedNodeStatementVariableDeclaration CheckStatementVariableDeclaration(NodeStatementVariableDeclaration statement)
{
var type = CheckType(statement.Type);
var type = typeResolver.Resolve(statement.Type);
var value = CheckExpression(statement.Value);
if (type != value.Type)
@@ -299,7 +291,7 @@ public sealed class TypeChecker(string fileName, Ast ast)
private TypedNodeExpressionStructLiteral CheckExpressionStructLiteral(NodeExpressionStructLiteral expression)
{
var type = structTypes.GetValueOrDefault(expression.Name.Ident);
var type = typeResolver.GetNamedStruct(expression.Name.Ident);
if (type == null)
throw new CompileException(Diagnostic.Error($"Undeclared struct '{expression.Name.Ident}'").At(fileName, expression.Name).Build());
@@ -320,34 +312,9 @@ public sealed class TypeChecker(string fileName, Ast ast)
return new TypedNodeExpressionStructLiteral(expression.Tokens, type, initializers);
}
private NubType CheckType(NodeType node)
{
return node switch
{
NodeTypeBool type => new NubTypeBool(),
NodeTypeCustom type => CheckStructType(type),
NodeTypeFunc type => new NubTypeFunc(type.Parameters.Select(CheckType).ToList(), CheckType(type.ReturnType)),
NodeTypePointer type => new NubTypePointer(CheckType(type.To)),
NodeTypeSInt type => new NubTypeSInt(type.Width),
NodeTypeUInt type => new NubTypeUInt(type.Width),
NodeTypeString type => new NubTypeString(),
NodeTypeVoid type => new NubTypeVoid(),
_ => throw new ArgumentOutOfRangeException(nameof(node))
};
}
private NubTypeStruct CheckStructType(NodeTypeCustom type)
{
var structType = structTypes.GetValueOrDefault(type.Name.Ident);
if (structType == null)
throw new CompileException(Diagnostic.Error($"Unknown custom type: {type}").At(fileName, type).Build());
return structType;
}
private class Scope(Scope? parent)
{
private Dictionary<string, NubType> identifiers = new();
private readonly Dictionary<string, NubType> identifiers = new();
public void DeclareIdentifier(string name, NubType type)
{
@@ -599,10 +566,26 @@ public sealed class NubTypeString : NubType
public override int GetHashCode() => HashCode.Combine(typeof(NubTypeString));
}
public sealed class NubTypeStruct(List<NubTypeStruct.Field> fields) : NubType
public sealed class NubTypeStruct : NubType
{
public readonly List<Field> Fields = fields;
public override string ToString() => $"struct {{ {string.Join(' ', Fields.Select(x => $"{x.Name}: {x.Type}"))} }}";
private List<Field>? _resolvedFields;
public List<Field> Fields => _resolvedFields ?? throw new InvalidOperationException();
public void ResolveFields(List<Field> fields)
{
if (_resolvedFields != null)
throw new InvalidOperationException($"{ToString()} already resolved");
_resolvedFields = fields;
}
public override string ToString()
{
if (_resolvedFields == null)
return "struct <unresolved>";
return $"struct {{ {string.Join(' ', Fields.Select(f => $"{f.Name}: {f.Type}"))} }}";
}
public override bool Equals(NubType? other)
{