This commit is contained in:
nub31
2025-09-12 16:58:09 +02:00
parent adcc9f3580
commit ef1720195d
12 changed files with 212 additions and 278 deletions

View File

@@ -1,189 +0,0 @@
using NubLang.Parsing.Syntax;
using NubLang.TypeChecking.Node;
namespace NubLang.TypeChecking;
public class Module
{
public static IReadOnlyDictionary<string, Module> CollectFromSyntaxTrees(IReadOnlyList<SyntaxTree> syntaxTrees)
{
var modules = new Dictionary<string, Module>();
foreach (var syntaxTree in syntaxTrees)
{
var moduleName = syntaxTree.Metadata.ModuleName;
if (!modules.TryGetValue(moduleName, out var module))
{
module = new Module();
modules[moduleName] = module;
}
foreach (var def in syntaxTree.Definitions)
{
switch (def)
{
case FuncSyntax funcDef:
{
var parameters = funcDef.Signature.Parameters.Select(p => ResolveType(p.Type)).ToList();
var returnType = ResolveType(funcDef.Signature.ReturnType);
var type = new FuncTypeNode(parameters, returnType);
module._functions.Add(new ModuleFuncType(def.Exported, funcDef.Name, funcDef.ExternSymbol, type));
break;
}
case InterfaceSyntax interfaceDef:
{
var functions = new List<InterfaceTypeFunc>();
for (var i = 0; i < interfaceDef.Functions.Count; i++)
{
var function = interfaceDef.Functions[i];
var parameters = function.Signature.Parameters.Select(p => ResolveType(p.Type)).ToList();
var returnType = ResolveType(function.Signature.ReturnType);
functions.Add(new InterfaceTypeFunc(function.Name, new FuncTypeNode(parameters, returnType), i));
}
var type = new InterfaceTypeNode(moduleName, interfaceDef.Name, functions);
module._interfaces.Add(new ModuleInterfaceType(def.Exported, type));
break;
}
case StructSyntax structDef:
{
var fields = new List<StructTypeField>();
foreach (var field in structDef.Fields)
{
fields.Add(new StructTypeField(field.Name, ResolveType(field.Type), field.Index, field.Value.HasValue));
}
var functions = new List<StructTypeFunc>();
foreach (var function in structDef.Functions)
{
var parameters = function.Signature.Parameters.Select(p => ResolveType(p.Type)).ToList();
var returnType = ResolveType(function.Signature.ReturnType);
functions.Add(new StructTypeFunc(function.Name, new FuncTypeNode(parameters, returnType)));
}
var interfaceImplementations = new List<InterfaceTypeNode>();
foreach (var interfaceImplementation in structDef.InterfaceImplementations)
{
if (interfaceImplementation is not CustomTypeSyntax customType)
{
throw new Exception("Interface implementation is not a custom type");
}
var resolvedType = ResolveType(customType);
if (resolvedType is not InterfaceTypeNode interfaceType)
{
throw new Exception("Interface implementation is not a interface");
}
interfaceImplementations.Add(interfaceType);
}
var type = new StructTypeNode(moduleName, structDef.Name, fields, functions, interfaceImplementations);
module._structs.Add(new ModuleStructType(def.Exported, type));
break;
}
default:
{
throw new ArgumentOutOfRangeException(nameof(def));
}
}
}
}
return modules;
TypeNode ResolveType(TypeSyntax type)
{
return type switch
{
BoolTypeSyntax => new BoolTypeNode(),
CStringTypeSyntax => new CStringTypeNode(),
IntTypeSyntax i => new IntTypeNode(i.Signed, i.Width),
CustomTypeSyntax c => ResolveCustomType(c.Module, c.Name),
FloatTypeSyntax f => new FloatTypeNode(f.Width),
FuncTypeSyntax func => new FuncTypeNode(func.Parameters.Select(ResolveType).ToList(), ResolveType(func.ReturnType)),
ArrayTypeSyntax arr => new ArrayTypeNode(ResolveType(arr.BaseType)),
PointerTypeSyntax ptr => new PointerTypeNode(ResolveType(ptr.BaseType)),
StringTypeSyntax => new StringTypeNode(),
VoidTypeSyntax => new VoidTypeNode(),
_ => throw new NotSupportedException($"Unknown type syntax: {type}")
};
}
TypeNode ResolveCustomType(string moduleName, string typeName)
{
if (!modules.TryGetValue(moduleName, out var module))
{
throw new Exception("Module not found: " + moduleName);
}
var structType = module.AllStructTypes.FirstOrDefault(x => x.StructType.Name == typeName);
if (structType != null)
{
return structType.StructType;
}
var interfaceType = module.AllInterfaceTypes.FirstOrDefault(x => x.InterfaceType.Name == typeName);
if (interfaceType != null)
{
return interfaceType.InterfaceType;
}
throw new Exception($"Type {typeName} not found in module {moduleName}");
}
}
private readonly List<ModuleStructType> _structs = [];
private readonly List<ModuleInterfaceType> _interfaces = [];
private readonly List<ModuleFuncType> _functions = [];
public IReadOnlyList<ModuleStructType> ExportedStructTypes => _structs.Where(x => x.Exported).ToList();
public IReadOnlyList<ModuleInterfaceType> ExportedInterfaceTypes => _interfaces.Where(x => x.Exported).ToList();
public IReadOnlyList<ModuleFuncType> ExportedFunctions => _functions.Where(x => x.Exported).ToList();
public IReadOnlyList<ModuleStructType> AllStructTypes => _structs;
public IReadOnlyList<ModuleInterfaceType> AllInterfaceTypes => _interfaces;
public IReadOnlyList<ModuleFuncType> AllFunctions => _functions;
}
public class ModuleStructType
{
public ModuleStructType(bool exported, StructTypeNode structType)
{
Exported = exported;
StructType = structType;
}
public bool Exported { get; }
public StructTypeNode StructType { get; }
}
public class ModuleInterfaceType
{
public ModuleInterfaceType(bool exported, InterfaceTypeNode interfaceType)
{
Exported = exported;
InterfaceType = interfaceType;
}
public bool Exported { get; }
public InterfaceTypeNode InterfaceType { get; }
}
public class ModuleFuncType
{
public ModuleFuncType(bool exported, string name, string? externSymbol, FuncTypeNode funcType)
{
Exported = exported;
Name = name;
ExternSymbol = externSymbol;
FuncType = funcType;
}
public bool Exported { get; }
public string Name { get; }
public string? ExternSymbol { get; }
public FuncTypeNode FuncType { get; }
}

View File

@@ -8,7 +8,7 @@ public record FuncSignatureNode(IReadOnlyList<FuncParameterNode> Parameters, Typ
public record FuncNode(string Module, string Name, string? ExternSymbol, FuncSignatureNode Signature, BlockNode? Body) : DefinitionNode(Module, Name);
public record StructFieldNode(int Index, string Name, TypeNode Type, Optional<ExpressionNode> Value) : Node;
public record StructFieldNode(string Name, TypeNode Type, Optional<ExpressionNode> Value) : Node;
public record StructFuncNode(string Name, FuncSignatureNode Signature, BlockNode Body) : Node;

View File

@@ -49,7 +49,7 @@ public record VariableIdentifierNode(TypeNode Type, string Name) : LValueExpress
public record FuncParameterIdentifierNode(TypeNode Type, string Name) : RValueExpressionNode(Type);
public record FuncIdentifierNode(TypeNode Type, string Module, string Name) : RValueExpressionNode(Type);
public record FuncIdentifierNode(TypeNode Type, string Module, string Name, string? ExternSymbol) : RValueExpressionNode(Type);
public record ArrayInitializerNode(TypeNode Type, ExpressionNode Capacity, TypeNode ElementType) : RValueExpressionNode(Type);

View File

@@ -174,11 +174,10 @@ public class StringTypeNode : ComplexTypeNode
public override int GetHashCode() => HashCode.Combine(typeof(StringTypeNode));
}
public class StructTypeField(string name, TypeNode type, int index, bool hasDefaultValue)
public class StructTypeField(string name, TypeNode type, bool hasDefaultValue)
{
public string Name { get; } = name;
public TypeNode Type { get; } = type;
public int Index { get; } = index;
public bool HasDefaultValue { get; } = hasDefaultValue;
}

View File

@@ -1,5 +1,6 @@
using System.Diagnostics;
using NubLang.Diagnostics;
using NubLang.Modules;
using NubLang.Parsing.Syntax;
using NubLang.Tokenization;
using NubLang.TypeChecking.Node;
@@ -9,42 +10,51 @@ namespace NubLang.TypeChecking;
public sealed class TypeChecker
{
private readonly SyntaxTree _syntaxTree;
private readonly IReadOnlyDictionary<string, Module> _importedModules;
private readonly Dictionary<string, Module> _visibleModules;
private readonly Stack<Scope> _scopes = [];
private readonly Stack<TypeNode> _funcReturnTypes = [];
private readonly List<Diagnostic> _diagnostics = [];
private readonly List<StructTypeNode> _referencedStructTypes = [];
private readonly List<InterfaceTypeNode> _referencedInterfaceTypes = [];
private readonly List<DefinitionNode> _definitions = [];
private Scope Scope => _scopes.Peek();
public TypeChecker(SyntaxTree syntaxTree, IReadOnlyDictionary<string, Module> moduleSignatures)
public TypeChecker(SyntaxTree syntaxTree, ModuleRepository moduleRepository)
{
_syntaxTree = syntaxTree;
_importedModules = moduleSignatures.Where(x => syntaxTree.Metadata.Imports.Contains(x.Key) || _syntaxTree.Metadata.ModuleName == x.Key).ToDictionary();
_visibleModules = moduleRepository
.Modules()
.Where(x => syntaxTree.Metadata.Imports.Contains(x.Key) || _syntaxTree.Metadata.ModuleName == x.Key)
.ToDictionary();
}
public IReadOnlyList<Diagnostic> GetDiagnostics() => _diagnostics;
public IReadOnlyList<DefinitionNode> Definitions => _definitions;
public IReadOnlyList<Diagnostic> Diagnostics => _diagnostics;
public IReadOnlyList<StructTypeNode> ReferencedStructTypes => _referencedStructTypes;
public IReadOnlyList<InterfaceTypeNode> ReferencedInterfaceTypes => _referencedInterfaceTypes;
public IReadOnlyList<DefinitionNode> Check()
public void Check()
{
_diagnostics.Clear();
_scopes.Clear();
var definitions = new List<DefinitionNode>();
_funcReturnTypes.Clear();
_diagnostics.Clear();
_referencedStructTypes.Clear();
_referencedInterfaceTypes.Clear();
_definitions.Clear();
foreach (var definition in _syntaxTree.Definitions)
{
try
{
definitions.Add(CheckDefinition(definition));
_definitions.Add(CheckDefinition(definition));
}
catch (TypeCheckerException e)
{
_diagnostics.Add(e.Diagnostic);
}
}
return definitions;
}
private DefinitionNode CheckDefinition(DefinitionSyntax node)
@@ -78,7 +88,7 @@ public sealed class TypeChecker
value = CheckExpression(field.Value.Value);
}
fields.Add(new StructFieldNode(field.Index, field.Name, ResolveType(field.Type), value));
fields.Add(new StructFieldNode(field.Name, ResolveType(field.Type), value));
}
var functions = new List<StructFuncNode>();
@@ -351,12 +361,12 @@ public sealed class TypeChecker
}
// Second, look in the current module for a function matching the identifier
var module = _importedModules[_syntaxTree.Metadata.ModuleName];
var function = module.AllFunctions.FirstOrDefault(x => x.Name == expression.Name);
var module = _visibleModules[_syntaxTree.Metadata.ModuleName];
var function = module.Functions(true).FirstOrDefault(x => x.Name == expression.Name);
if (function != null)
{
return new FuncIdentifierNode(function.FuncType, _syntaxTree.Metadata.ModuleName, expression.Name);
return new FuncIdentifierNode(ResolveType(function.FuncType), _syntaxTree.Metadata.ModuleName, expression.Name, function.ExternSymbol);
}
throw new TypeCheckerException(Diagnostic.Error($"Symbol {expression.Name} not found").At(expression).Build());
@@ -364,27 +374,18 @@ public sealed class TypeChecker
private ExpressionNode CheckModuleIdentifier(ModuleIdentifierSyntax expression)
{
if (!_importedModules.TryGetValue(expression.Module, out var module))
if (!_visibleModules.TryGetValue(expression.Module, out var module))
{
throw new TypeCheckerException(Diagnostic.Error($"Module {expression.Module} not found").WithHelp($"import \"{expression.Module}\"").At(expression).Build());
}
// First, look for the exported function in the specified module (or all functions if current module)
if (expression.Module == _syntaxTree.Metadata.ModuleName)
var includePrivate = expression.Module == _syntaxTree.Metadata.ModuleName;
// First, look for the exported function in the specified module
var function = module.Functions(includePrivate).FirstOrDefault(x => x.Name == expression.Name);
if (function != null)
{
var function = module.AllFunctions.FirstOrDefault(x => x.Name == expression.Name);
if (function != null)
{
return new FuncIdentifierNode(function.FuncType, expression.Module, expression.Name);
}
}
else
{
var function = module.ExportedFunctions.FirstOrDefault(x => x.Name == expression.Name);
if (function != null)
{
return new FuncIdentifierNode(function.FuncType, expression.Module, expression.Name);
}
return new FuncIdentifierNode(ResolveType(function.FuncType), expression.Module, expression.Name, function.ExternSymbol);
}
throw new TypeCheckerException(Diagnostic.Error($"No exported symbol {expression.Name} not found in module {expression.Module}").At(expression).Build());
@@ -521,38 +522,28 @@ public sealed class TypeChecker
private TypeNode ResolveCustomType(CustomTypeSyntax customType)
{
if (!_importedModules.TryGetValue(customType.Module, out var module))
if (!_visibleModules.TryGetValue(customType.Module, out var module))
{
throw new TypeCheckerException(Diagnostic.Error($"Module {customType.Module} not found").WithHelp($"import \"{customType.Module}\"").At(customType).Build());
}
if (customType.Module == _syntaxTree.Metadata.ModuleName)
{
var structType = module.AllStructTypes.FirstOrDefault(x => x.StructType.Name == customType.Name);
if (structType != null)
{
return structType.StructType;
}
var includePrivate = customType.Module == _syntaxTree.Metadata.ModuleName;
var interfaceType = module.AllInterfaceTypes.FirstOrDefault(x => x.InterfaceType.Name == customType.Name);
if (interfaceType != null)
{
return interfaceType.InterfaceType;
}
var structType = module.StructTypes(includePrivate).FirstOrDefault(x => x.Name == customType.Name);
if (structType != null)
{
var fields = structType.Fields.Select(x => new StructTypeField(x.Name, ResolveType(x.Type), x.HasDefaultValue)).ToList();
var result = new StructTypeNode(customType.Module, structType.Name, fields, [], []);
_referencedStructTypes.Add(result);
return result;
}
else
{
var structType = module.ExportedStructTypes.FirstOrDefault(x => x.StructType.Name == customType.Name);
if (structType != null)
{
return structType.StructType;
}
var interfaceType = module.ExportedInterfaceTypes.FirstOrDefault(x => x.InterfaceType.Name == customType.Name);
if (interfaceType != null)
{
return interfaceType.InterfaceType;
}
var interfaceType = module.InterfaceTypes(includePrivate).FirstOrDefault(x => x.Name == customType.Name);
if (interfaceType != null)
{
var result = new InterfaceTypeNode(customType.Module, interfaceType.Name, []);
_referencedInterfaceTypes.AddRange(result);
return result;
}
throw new TypeCheckerException(Diagnostic.Error($"Type {customType.Name} not found in module {customType.Module}").At(customType).Build());