This commit is contained in:
nub31
2025-10-31 14:42:58 +01:00
parent 031b118a24
commit 7c7624b1bc
17 changed files with 453 additions and 605 deletions

View File

@@ -1,12 +0,0 @@
using NubLang.Syntax;
namespace NubLang.Ast;
// public sealed class CompilationUnit(IdentifierToken module, List<FuncNode> functions, List<StructNode> structTypes, Dictionary<IdentifierToken, List<NubStructType>> importedStructTypes, Dictionary<IdentifierToken, List<FuncPrototypeNode>> importedFunctions)
// {
// public IdentifierToken Module { get; } = module;
// public List<FuncNode> Functions { get; } = functions;
// public List<StructNode> Structs { get; } = structTypes;
// public Dictionary<IdentifierToken, List<NubStructType>> ImportedStructTypes { get; } = importedStructTypes;
// public Dictionary<IdentifierToken, List<FuncPrototypeNode>> ImportedFunctions { get; } = importedFunctions;
// }

View File

@@ -1,4 +1,5 @@
using NubLang.Syntax;
using NubLang.Types;
namespace NubLang.Ast;
@@ -31,16 +32,6 @@ public abstract class Node(List<Token> tokens)
public abstract class TopLevelNode(List<Token> tokens) : Node(tokens);
public class ImportNode(List<Token> tokens, IdentifierToken nameToken) : TopLevelNode(tokens)
{
public IdentifierToken NameToken { get; } = nameToken;
public override IEnumerable<Node> Children()
{
return [];
}
}
public class ModuleNode(List<Token> tokens, IdentifierToken nameToken) : TopLevelNode(tokens)
{
public IdentifierToken NameToken { get; } = nameToken;
@@ -490,7 +481,18 @@ public class VariableIdentifierNode(List<Token> tokens, NubType type, Identifier
}
}
public class FuncIdentifierNode(List<Token> tokens, NubType type, IdentifierToken moduleToken, IdentifierToken nameToken, StringLiteralToken? externSymbolToken) : RValue(tokens, type)
public class LocalFuncIdentifierNode(List<Token> tokens, NubType type, IdentifierToken nameToken, StringLiteralToken? externSymbolToken) : RValue(tokens, type)
{
public IdentifierToken NameToken { get; } = nameToken;
public StringLiteralToken? ExternSymbolToken { get; } = externSymbolToken;
public override IEnumerable<Node> Children()
{
return [];
}
}
public class ModuleFuncIdentifierNode(List<Token> tokens, NubType type, IdentifierToken moduleToken, IdentifierToken nameToken, StringLiteralToken? externSymbolToken) : RValue(tokens, type)
{
public IdentifierToken ModuleToken { get; } = moduleToken;
public IdentifierToken NameToken { get; } = nameToken;
@@ -612,17 +614,4 @@ public class ConstArrayInitializerNode(List<Token> tokens, NubType type, List<Ex
}
}
public abstract class IntermediateExpression(List<Token> tokens) : ExpressionNode(tokens, new NubVoidType());
public class EnumReferenceIntermediateNode(List<Token> tokens, IdentifierToken moduleToken, IdentifierToken nameToken) : IntermediateExpression(tokens)
{
public IdentifierToken ModuleToken { get; } = moduleToken;
public IdentifierToken NameToken { get; } = nameToken;
public override IEnumerable<Node> Children()
{
return [];
}
}
#endregion

View File

@@ -1,267 +0,0 @@
using System.Security.Cryptography;
using System.Text;
namespace NubLang.Ast;
public abstract class NubType : IEquatable<NubType>
{
public abstract ulong GetSize();
public abstract ulong GetAlignment();
public abstract bool IsAggregate();
public override bool Equals(object? obj) => obj is NubType other && Equals(other);
public abstract bool Equals(NubType? other);
public abstract override int GetHashCode();
public abstract override string ToString();
public static bool operator ==(NubType? left, NubType? right) => Equals(left, right);
public static bool operator !=(NubType? left, NubType? right) => !Equals(left, right);
}
public class NubVoidType : NubType
{
public override ulong GetSize() => 8;
public override ulong GetAlignment() => 8;
public override bool IsAggregate() => false;
public override string ToString() => "void";
public override bool Equals(NubType? other) => other is NubVoidType;
public override int GetHashCode() => HashCode.Combine(typeof(NubVoidType));
}
public sealed class NubIntType(bool signed, ulong width) : NubType
{
public bool Signed { get; } = signed;
public ulong Width { get; } = width;
public override ulong GetSize() => Width / 8;
public override ulong GetAlignment() => Width / 8;
public override bool IsAggregate() => false;
public override string ToString() => $"{(Signed ? "i" : "u")}{Width}";
public override bool Equals(NubType? other) => other is NubIntType @int && @int.Width == Width && @int.Signed == Signed;
public override int GetHashCode() => HashCode.Combine(typeof(NubIntType), Signed, Width);
}
public sealed class NubFloatType(ulong width) : NubType
{
public ulong Width { get; } = width;
public override ulong GetSize() => Width / 8;
public override ulong GetAlignment() => Width / 8;
public override bool IsAggregate() => false;
public override string ToString() => $"f{Width}";
public override bool Equals(NubType? other) => other is NubFloatType @float && @float.Width == Width;
public override int GetHashCode() => HashCode.Combine(typeof(NubFloatType), Width);
}
public class NubBoolType : NubType
{
public override ulong GetSize() => 1;
public override ulong GetAlignment() => 1;
public override bool IsAggregate() => false;
public override string ToString() => "bool";
public override bool Equals(NubType? other) => other is NubBoolType;
public override int GetHashCode() => HashCode.Combine(typeof(NubBoolType));
}
public sealed class NubPointerType(NubType baseType) : NubType
{
public NubType BaseType { get; } = baseType;
public override ulong GetSize() => 8;
public override ulong GetAlignment() => 8;
public override bool IsAggregate() => false;
public override string ToString() => "^" + BaseType;
public override bool Equals(NubType? other) => other is NubPointerType pointer && BaseType.Equals(pointer.BaseType);
public override int GetHashCode() => HashCode.Combine(typeof(NubPointerType), BaseType);
}
public class NubFuncType(List<NubType> parameters, NubType returnType) : NubType
{
public List<NubType> Parameters { get; } = parameters;
public NubType ReturnType { get; } = returnType;
public override ulong GetSize() => 8;
public override ulong GetAlignment() => 8;
public override bool IsAggregate() => false;
public override string ToString() => $"func({string.Join(", ", Parameters)}): {ReturnType}";
public override bool Equals(NubType? other) => other is NubFuncType func && ReturnType.Equals(func.ReturnType) && Parameters.SequenceEqual(func.Parameters);
public override int GetHashCode()
{
var hash = new HashCode();
hash.Add(typeof(NubFuncType));
hash.Add(ReturnType);
foreach (var param in Parameters)
{
hash.Add(param);
}
return hash.ToHashCode();
}
}
public class NubStructType(string module, string name, bool packed, List<NubStructFieldType> fields) : NubType
{
public string Module { get; } = module;
public string Name { get; } = name;
public bool Packed { get; } = packed;
public List<NubStructFieldType> Fields { get; set; } = fields;
public int GetFieldIndex(string name)
{
return Fields.FindIndex(x => x.Name == name);
}
public Dictionary<string, ulong> GetFieldOffsets()
{
var offsets = new Dictionary<string, ulong>();
ulong offset = 0;
foreach (var field in Fields)
{
var alignment = Packed ? 1 : field.Type.GetAlignment();
if (!Packed)
{
var padding = (alignment - offset % alignment) % alignment;
offset += padding;
}
offsets[field.Name] = offset;
offset += field.Type.GetSize();
}
return offsets;
}
public override ulong GetSize()
{
var offsets = GetFieldOffsets();
if (Fields.Count == 0)
{
return 0;
}
var lastField = Fields.Last();
var size = offsets[lastField.Name] + lastField.Type.GetSize();
if (!Packed)
{
var structAlignment = GetAlignment();
var padding = (structAlignment - size % structAlignment) % structAlignment;
size += padding;
}
return size;
}
public override ulong GetAlignment()
{
if (Fields.Count == 0)
return 1;
return Packed ? 1 : Fields.Max(f => f.Type.GetAlignment());
}
public override bool IsAggregate() => true;
public override string ToString() => $"{Module}::{Name}";
public override bool Equals(NubType? other) => other is NubStructType structType && Name == structType.Name && Module == structType.Module;
public override int GetHashCode() => HashCode.Combine(typeof(NubStructType), Module, Name);
}
public class NubStructFieldType(string name, NubType type, bool hasDefaultValue)
{
public string Name { get; } = name;
public NubType Type { get; } = type;
public bool HasDefaultValue { get; } = hasDefaultValue;
}
public class NubSliceType(NubType elementType) : NubType
{
public NubType ElementType { get; } = elementType;
public override ulong GetSize() => 16; // note(nub31): Fat pointer
public override ulong GetAlignment() => 8;
public override bool IsAggregate() => true;
public override string ToString() => "[]" + ElementType;
public override bool Equals(NubType? other) => other is NubSliceType slice && ElementType.Equals(slice.ElementType);
public override int GetHashCode() => HashCode.Combine(typeof(NubSliceType), ElementType);
}
public class NubConstArrayType(NubType elementType, ulong size) : NubType
{
public NubType ElementType { get; } = elementType;
public ulong Size { get; } = size;
public override ulong GetSize() => ElementType.GetSize() * Size;
public override ulong GetAlignment() => ElementType.GetAlignment();
public override bool IsAggregate() => true;
public override string ToString() => $"[{Size}]{ElementType}";
public override bool Equals(NubType? other) => other is NubConstArrayType array && ElementType.Equals(array.ElementType) && Size == array.Size;
public override int GetHashCode() => HashCode.Combine(typeof(NubConstArrayType), ElementType, Size);
}
public class NubArrayType(NubType elementType) : NubType
{
public NubType ElementType { get; } = elementType;
public override ulong GetSize() => 8;
public override ulong GetAlignment() => 8;
public override bool IsAggregate() => false; // note(nub31): Just a pointer
public override string ToString() => $"[?]{ElementType}";
public override bool Equals(NubType? other) => other is NubArrayType array && ElementType.Equals(array.ElementType);
public override int GetHashCode() => HashCode.Combine(typeof(NubArrayType), ElementType);
}
public class NubStringType : NubType
{
public override ulong GetSize() => 16; // note(nub31): Fat pointer
public override ulong GetAlignment() => 8;
public override bool IsAggregate() => true;
public override string ToString() => "string";
public override bool Equals(NubType? other) => other is NubStringType;
public override int GetHashCode() => HashCode.Combine(typeof(NubStringType));
}
public static class NameMangler
{
public static string Mangle(params IEnumerable<NubType> types)
{
var readable = string.Join(":", types.Select(EncodeType));
return ComputeShortHash(readable);
}
private static string EncodeType(NubType node) => node switch
{
NubVoidType => "V",
NubBoolType => "B",
NubIntType i => (i.Signed ? "I" : "U") + i.Width,
NubFloatType f => "F" + f.Width,
NubStringType => "S",
NubArrayType a => $"A({EncodeType(a.ElementType)})",
NubConstArrayType ca => $"CA({EncodeType(ca.ElementType)})",
NubSliceType a => $"SL{EncodeType(a.ElementType)}()",
NubPointerType p => $"P({EncodeType(p.BaseType)})",
NubFuncType fn => $"FN({string.Join(":", fn.Parameters.Select(EncodeType))}:{EncodeType(fn.ReturnType)})",
NubStructType st => $"ST({st.Module}:{st.Name})",
_ => throw new NotSupportedException($"Cannot encode type: {node}")
};
private static string ComputeShortHash(string input)
{
var bytes = Encoding.UTF8.GetBytes(input);
var hash = SHA256.HashData(bytes);
return Convert.ToHexString(hash[..8]).ToLower();
}
}

View File

@@ -1,32 +1,28 @@
using System.Diagnostics;
using NubLang.Diagnostics;
using NubLang.Modules;
using NubLang.Syntax;
using NubLang.Types;
namespace NubLang.Ast;
public sealed class TypeChecker
{
private readonly SyntaxTree _syntaxTree;
private readonly Dictionary<string, Module> _modules;
private SyntaxTree _syntaxTree = null!;
private ModuleRepository _repository = null!;
private readonly Stack<Scope> _scopes = [];
private Stack<Scope> _scopes = [];
private readonly TypeResolver _typeResolver;
private Scope CurrentScope => _scopes.Peek();
private Scope Scope => _scopes.Peek();
public List<Diagnostic> Diagnostics { get; set; } = [];
public List<Diagnostic> Diagnostics { get; } = [];
public TypeChecker(SyntaxTree syntaxTree, Dictionary<string, Module> modules)
public List<TopLevelNode> Check(SyntaxTree syntaxTree, ModuleRepository repository)
{
_syntaxTree = syntaxTree;
_modules = modules;
_typeResolver = new TypeResolver(_modules);
}
public List<TopLevelNode> Check()
{
_scopes.Clear();
_repository = repository;
Diagnostics = [];
_scopes = [];
var moduleDeclarations = _syntaxTree.TopLevelSyntaxNodes.OfType<ModuleSyntax>().ToList();
if (moduleDeclarations.Count == 0)
@@ -40,51 +36,11 @@ public sealed class TypeChecker
Diagnostics.Add(Diagnostic.Error("Multiple module declarations").WithHelp("Remove extra module declarations").Build());
}
var moduleName = moduleDeclarations[0].NameToken;
var importDeclarations = _syntaxTree.TopLevelSyntaxNodes.OfType<ImportSyntax>().ToList();
foreach (var importDeclaration in importDeclarations)
{
var name = importDeclaration.NameToken.Value;
var last = importDeclarations.Last(x => x.NameToken.Value == name);
if (importDeclaration != last)
{
Diagnostics.Add(Diagnostic
.Warning($"Module \"{last.NameToken.Value}\" is imported twice")
.WithHelp($"Remove duplicate import \"{last.NameToken.Value}\"")
.At(last)
.Build());
}
var exists = _modules.ContainsKey(name);
if (!exists)
{
var suggestions = _modules.Keys
.Select(m => new { Name = m, Distance = Utils.LevenshteinDistance(name, m) })
.OrderBy(x => x.Distance)
.Take(3)
.Where(x => x.Distance <= 3)
.Select(x => $"\"{x.Name}\"")
.ToArray();
var suggestionText = suggestions.Length != 0
? $"Did you mean {string.Join(", ", suggestions)}?"
: "Check that the module name is correct.";
Diagnostics.Add(Diagnostic
.Error($"Module \"{name}\" does not exist")
.WithHelp(suggestionText)
.At(last)
.Build());
return [];
}
}
var module = _repository.Get(moduleDeclarations[0].NameToken);
var topLevelNodes = new List<TopLevelNode>();
using (BeginRootScope(moduleName))
using (BeginRootScope(module))
{
foreach (var topLevelSyntaxNode in _syntaxTree.TopLevelSyntaxNodes)
{
@@ -98,9 +54,6 @@ public sealed class TypeChecker
case StructSyntax structSyntax:
topLevelNodes.Add(CheckStructDefinition(structSyntax));
break;
case ImportSyntax importSyntax:
topLevelNodes.Add(new ImportNode(importSyntax.Tokens, importSyntax.NameToken));
break;
case ModuleSyntax moduleSyntax:
topLevelNodes.Add(new ModuleNode(moduleSyntax.Tokens, moduleSyntax.NameToken));
break;
@@ -113,58 +66,15 @@ public sealed class TypeChecker
return topLevelNodes;
}
private (IdentifierToken Name, Module Module) GetCurrentModule()
{
var currentModule = _syntaxTree.TopLevelSyntaxNodes.OfType<ModuleSyntax>().First().NameToken;
return (currentModule, _modules[currentModule.Value]);
}
private List<(IdentifierToken Name, Module Module)> GetImportedModules()
{
var currentModule = _syntaxTree.TopLevelSyntaxNodes.OfType<ModuleSyntax>().First().NameToken;
return _syntaxTree.TopLevelSyntaxNodes
.OfType<ImportSyntax>()
.Select(x => (Name: x.NameToken, Module: _modules[x.NameToken.Value]))
.Concat([(Name: currentModule, Module: _modules[currentModule.Value])])
.ToList();
}
private bool IsCurrentModule(IdentifierToken? module)
{
if (module == null)
{
return true;
}
return module.Value == Scope.Module.Value;
}
private Module? GetImportedModule(string module)
{
var currentModule = _syntaxTree.TopLevelSyntaxNodes.OfType<ModuleSyntax>().First().NameToken;
if (module == currentModule.Value)
{
return _modules[currentModule.Value];
}
var import = _syntaxTree.TopLevelSyntaxNodes.OfType<ImportSyntax>().FirstOrDefault(x => x.NameToken.Value == module);
if (import != null)
{
return _modules[import.NameToken.Value];
}
return null;
}
private ScopeDisposer BeginScope()
{
_scopes.Push(Scope.SubScope());
_scopes.Push(CurrentScope.SubScope());
return new ScopeDisposer(this);
}
private ScopeDisposer BeginRootScope(IdentifierToken moduleName)
private ScopeDisposer BeginRootScope(ModuleRepository.Module module)
{
_scopes.Push(new Scope(moduleName));
_scopes.Push(new Scope(module));
return new ScopeDisposer(this);
}
@@ -186,10 +96,10 @@ public sealed class TypeChecker
{
var prototype = CheckFuncPrototype(node.Prototype);
Scope.SetReturnType(prototype.ReturnType);
CurrentScope.SetReturnType(prototype.ReturnType);
foreach (var parameter in prototype.Parameters)
{
Scope.DeclareVariable(new Variable(parameter.NameToken, parameter.Type));
CurrentScope.DeclareVariable(new Variable(parameter.NameToken, parameter.Type));
}
var body = node.Body == null ? null : CheckBlock(node.Body);
@@ -203,7 +113,7 @@ public sealed class TypeChecker
foreach (var field in structSyntax.Fields)
{
var fieldType = _typeResolver.ResolveType(field.Type, Scope.Module.Value);
var fieldType = ResolveType(field.Type);
ExpressionNode? value = null;
if (field.Value != null)
{
@@ -220,8 +130,7 @@ public sealed class TypeChecker
fields.Add(new StructFieldNode(field.Tokens, field.NameToken, fieldType, value));
}
var currentModule = GetCurrentModule();
var type = new NubStructType(currentModule.Name.Value, structSyntax.NameToken.Value, structSyntax.Packed, fields.Select(x => new NubStructFieldType(x.NameToken.Value, x.Type, x.Value != null)).ToList());
var type = new NubStructType(CurrentScope.Module.Name, structSyntax.NameToken.Value, structSyntax.Packed, fields.Select(x => new NubStructFieldType(x.NameToken.Value, x.Type, x.Value != null)).ToList());
return new StructNode(structSyntax.Tokens, structSyntax.NameToken, type, structSyntax.Packed, fields);
}
@@ -261,7 +170,7 @@ public sealed class TypeChecker
if (statement.Value != null)
{
var expectedReturnType = Scope.GetReturnType();
var expectedReturnType = CurrentScope.GetReturnType();
value = CheckExpression(statement.Value, expectedReturnType);
}
@@ -286,7 +195,7 @@ public sealed class TypeChecker
if (statement.ExplicitType != null)
{
type = _typeResolver.ResolveType(statement.ExplicitType, Scope.Module.Value);
type = ResolveType(statement.ExplicitType);
}
if (statement.Assignment != null)
@@ -314,7 +223,7 @@ public sealed class TypeChecker
.Build());
}
Scope.DeclareVariable(new Variable(statement.NameToken, type));
CurrentScope.DeclareVariable(new Variable(statement.NameToken, type));
return new VariableDeclarationNode(statement.Tokens, statement.NameToken, assignmentNode, type);
}
@@ -337,10 +246,10 @@ public sealed class TypeChecker
{
using (BeginScope())
{
Scope.DeclareVariable(new Variable(forSyntax.ElementNameToken, sliceType.ElementType));
CurrentScope.DeclareVariable(new Variable(forSyntax.ElementNameToken, sliceType.ElementType));
if (forSyntax.IndexNameToken != null)
{
Scope.DeclareVariable(new Variable(forSyntax.IndexNameToken, new NubIntType(false, 64)));
CurrentScope.DeclareVariable(new Variable(forSyntax.IndexNameToken, new NubIntType(false, 64)));
}
var body = CheckBlock(forSyntax.Body);
@@ -351,10 +260,10 @@ public sealed class TypeChecker
{
using (BeginScope())
{
Scope.DeclareVariable(new Variable(forSyntax.ElementNameToken, constArrayType.ElementType));
CurrentScope.DeclareVariable(new Variable(forSyntax.ElementNameToken, constArrayType.ElementType));
if (forSyntax.IndexNameToken != null)
{
Scope.DeclareVariable(new Variable(forSyntax.IndexNameToken, new NubIntType(false, 64)));
CurrentScope.DeclareVariable(new Variable(forSyntax.IndexNameToken, new NubIntType(false, 64)));
}
var body = CheckBlock(forSyntax.Body);
@@ -376,10 +285,10 @@ public sealed class TypeChecker
var parameters = new List<FuncParameterNode>();
foreach (var parameter in statement.Parameters)
{
parameters.Add(new FuncParameterNode(parameter.Tokens, parameter.NameToken, _typeResolver.ResolveType(parameter.Type, Scope.Module.Value)));
parameters.Add(new FuncParameterNode(parameter.Tokens, parameter.NameToken, ResolveType(parameter.Type)));
}
return new FuncPrototypeNode(statement.Tokens, statement.NameToken, statement.ExternSymbolToken, parameters, _typeResolver.ResolveType(statement.ReturnType, Scope.Module.Value));
return new FuncPrototypeNode(statement.Tokens, statement.NameToken, statement.ExternSymbolToken, parameters, ResolveType(statement.ReturnType));
}
private ExpressionNode CheckExpression(ExpressionSyntax node, NubType? expectedType = null)
@@ -401,7 +310,7 @@ public sealed class TypeChecker
FloatLiteralSyntax expression => CheckFloatLiteral(expression, expectedType),
MemberAccessSyntax expression => CheckMemberAccess(expression, expectedType),
StructInitializerSyntax expression => CheckStructInitializer(expression, expectedType),
SizeSyntax expression => new SizeNode(node.Tokens, _typeResolver.ResolveType(expression.Type, Scope.Module.Value)),
SizeSyntax expression => new SizeNode(node.Tokens, ResolveType(expression.Type)),
CastSyntax expression => CheckCast(expression, expectedType),
_ => throw new ArgumentOutOfRangeException(nameof(node))
};
@@ -856,26 +765,16 @@ public sealed class TypeChecker
private ExpressionNode CheckLocalIdentifier(LocalIdentifierSyntax expression, NubType? _)
{
// note(nub31): Local identifiers can be variables or a symbol in a module
var scopeIdent = Scope.LookupVariable(expression.NameToken);
var scopeIdent = CurrentScope.LookupVariable(expression.NameToken);
if (scopeIdent != null)
{
return new VariableIdentifierNode(expression.Tokens, scopeIdent.Type, expression.NameToken);
}
var module = GetImportedModule(Scope.Module.Value)!;
var function = module.Functions(true).FirstOrDefault(x => x.NameToken.Value == expression.NameToken.Value);
if (function != null)
if (CurrentScope.Module.TryResolveFunc(expression.NameToken, out var function, out var _))
{
var parameters = function.Prototype.Parameters.Select(x => _typeResolver.ResolveType(x.Type, Scope.Module.Value)).ToList();
var type = new NubFuncType(parameters, _typeResolver.ResolveType(function.Prototype.ReturnType, Scope.Module.Value));
return new FuncIdentifierNode(expression.Tokens, type, Scope.Module, expression.NameToken, function.Prototype.ExternSymbolToken);
}
var enumDef = module.Enums(true).FirstOrDefault(x => x.NameToken.Value == expression.NameToken.Value);
if (enumDef != null)
{
return new EnumReferenceIntermediateNode(expression.Tokens, Scope.Module, expression.NameToken);
var type = new NubFuncType(function.Parameters.Select(x => x.Type).ToList(), function.ReturnType);
return new LocalFuncIdentifierNode(expression.Tokens, type, expression.NameToken, function.ExternSymbolToken);
}
throw new CompileException(Diagnostic
@@ -886,37 +785,20 @@ public sealed class TypeChecker
private ExpressionNode CheckModuleIdentifier(ModuleIdentifierSyntax expression, NubType? _)
{
var module = GetImportedModule(expression.ModuleToken.Value);
if (module == null)
var module = _repository.Get(expression.ModuleToken);
using (BeginRootScope(module))
{
if (module.TryResolveFunc(expression.NameToken, out var function, out var _))
{
var type = new NubFuncType(function.Parameters.Select(x => x.Type).ToList(), function.ReturnType);
return new ModuleFuncIdentifierNode(expression.Tokens, type, expression.ModuleToken, expression.NameToken, function.ExternSymbolToken);
}
throw new CompileException(Diagnostic
.Error($"Module {expression.ModuleToken.Value} not found")
.WithHelp($"import \"{expression.ModuleToken.Value}\"")
.At(expression.ModuleToken)
.Error($"Module {expression.ModuleToken.Value} does not export a member named {expression.NameToken.Value}")
.At(expression)
.Build());
}
var function = module.Functions(false).FirstOrDefault(x => x.NameToken.Value == expression.NameToken.Value);
if (function != null)
{
using (BeginRootScope(expression.ModuleToken))
{
var parameters = function.Prototype.Parameters.Select(x => _typeResolver.ResolveType(x.Type, Scope.Module.Value)).ToList();
var type = new NubFuncType(parameters, _typeResolver.ResolveType(function.Prototype.ReturnType, Scope.Module.Value));
return new FuncIdentifierNode(expression.Tokens, type, expression.ModuleToken, expression.NameToken, function.Prototype.ExternSymbolToken);
}
}
var enumDef = module.Enums(false).FirstOrDefault(x => x.NameToken.Value == expression.NameToken.Value);
if (enumDef != null)
{
return new EnumReferenceIntermediateNode(expression.Tokens, expression.ModuleToken, expression.NameToken);
}
throw new CompileException(Diagnostic
.Error($"Module {expression.ModuleToken.Value} does not export a member named {expression.NameToken.Value}")
.At(expression)
.Build());
}
private ExpressionNode CheckStringLiteral(StringLiteralSyntax expression, NubType? expectedType)
@@ -980,55 +862,6 @@ public sealed class TypeChecker
{
var target = CheckExpression(expression.Target);
if (target is EnumReferenceIntermediateNode enumReferenceIntermediate)
{
var enumDef = GetImportedModules()
.First(x => x.Name.Value == enumReferenceIntermediate.ModuleToken.Value)
.Module
.Enums(IsCurrentModule(enumReferenceIntermediate.ModuleToken))
.First(x => x.NameToken.Value == enumReferenceIntermediate.NameToken.Value);
var field = enumDef.Fields.FirstOrDefault(x => x.NameToken.Value == expression.MemberToken.Value);
if (field == null)
{
throw new CompileException(Diagnostic
.Error($"Enum {Scope.Module.Value}::{enumReferenceIntermediate.NameToken.Value} does not have a field named {expression.MemberToken.Value}")
.At(enumDef)
.Build());
}
var enumType = enumDef.Type != null ? _typeResolver.ResolveType(enumDef.Type, Scope.Module.Value) : new NubIntType(false, 64);
if (enumType is not NubIntType enumIntType)
{
throw new CompileException(Diagnostic.Error("Enum type must be an int type").At(enumDef.Type).Build());
}
if (enumIntType.Signed)
{
var fieldValue = CalculateSignedEnumFieldValue(enumDef, field);
return enumIntType.Width switch
{
8 => new I8LiteralNode(expression.Tokens, (sbyte)fieldValue),
16 => new I16LiteralNode(expression.Tokens, (short)fieldValue),
32 => new I32LiteralNode(expression.Tokens, (int)fieldValue),
64 => new I64LiteralNode(expression.Tokens, fieldValue),
_ => throw new ArgumentOutOfRangeException()
};
}
else
{
var fieldValue = CalculateUnsignedEnumFieldValue(enumDef, field);
return enumIntType.Width switch
{
8 => new U8LiteralNode(expression.Tokens, (byte)fieldValue),
16 => new U16LiteralNode(expression.Tokens, (ushort)fieldValue),
32 => new U32LiteralNode(expression.Tokens, (uint)fieldValue),
64 => new U64LiteralNode(expression.Tokens, fieldValue),
_ => throw new ArgumentOutOfRangeException()
};
}
}
switch (target.Type)
{
case NubStructType structType:
@@ -1054,57 +887,13 @@ public sealed class TypeChecker
}
}
private static long CalculateSignedEnumFieldValue(EnumSyntax enumDef, EnumFieldSyntax field)
{
long currentValue = 0;
foreach (var f in enumDef.Fields)
{
if (f.ValueToken != null)
{
currentValue = f.ValueToken.AsI64;
}
if (f == field)
{
return currentValue;
}
currentValue++;
}
throw new UnreachableException();
}
private static ulong CalculateUnsignedEnumFieldValue(EnumSyntax enumDef, EnumFieldSyntax field)
{
ulong currentValue = 0;
foreach (var f in enumDef.Fields)
{
if (f.ValueToken != null)
{
currentValue = f.ValueToken.AsU64;
}
if (f == field)
{
return currentValue;
}
currentValue++;
}
throw new UnreachableException();
}
private StructInitializerNode CheckStructInitializer(StructInitializerSyntax expression, NubType? expectedType)
{
NubStructType? structType = null;
if (expression.StructType != null)
{
var checkedType = _typeResolver.ResolveType(expression.StructType, Scope.Module.Value);
var checkedType = ResolveType(expression.StructType);
if (checkedType is not NubStructType checkedStructType)
{
throw new UnreachableException("Parser fucked up");
@@ -1200,44 +989,85 @@ public sealed class TypeChecker
_ => throw new ArgumentOutOfRangeException(nameof(statement))
};
}
}
public record Variable(IdentifierToken Name, NubType Type);
public class Scope(IdentifierToken module, Scope? parent = null)
{
private NubType? _returnType;
private readonly List<Variable> _variables = [];
public IdentifierToken Module { get; } = module;
public void DeclareVariable(Variable variable)
private NubType ResolveType(TypeSyntax type)
{
_variables.Add(variable);
}
public void SetReturnType(NubType returnType)
{
_returnType = returnType;
}
public NubType? GetReturnType()
{
return _returnType ?? parent?.GetReturnType();
}
public Variable? LookupVariable(IdentifierToken name)
{
var variable = _variables.FirstOrDefault(x => x.Name.Value == name.Value);
if (variable != null)
return type switch
{
return variable;
ArrayTypeSyntax arr => new NubArrayType(ResolveType(arr.BaseType)),
BoolTypeSyntax => new NubBoolType(),
IntTypeSyntax i => new NubIntType(i.Signed, i.Width),
FloatTypeSyntax f => new NubFloatType(f.Width),
FuncTypeSyntax func => new NubFuncType(func.Parameters.Select(ResolveType).ToList(), ResolveType(func.ReturnType)),
SliceTypeSyntax slice => new NubSliceType(ResolveType(slice.BaseType)),
ConstArrayTypeSyntax arr => new NubConstArrayType(ResolveType(arr.BaseType), arr.Size),
PointerTypeSyntax ptr => new NubPointerType(ResolveType(ptr.BaseType)),
StringTypeSyntax => new NubStringType(),
CustomTypeSyntax c => ResolveCustomType(c),
VoidTypeSyntax => new NubVoidType(),
_ => throw new NotSupportedException($"Unknown type syntax: {type}")
};
}
private NubType ResolveCustomType(CustomTypeSyntax customType)
{
var module = customType.ModuleToken != null ? _repository.Get(customType.ModuleToken) : CurrentScope.Module;
var structType = module.StructTypes.FirstOrDefault(x => x.Name == customType.NameToken.Value);
if (structType != null)
{
return structType;
}
return parent?.LookupVariable(name);
var enumType = module.EnumTypes.GetValueOrDefault(customType.NameToken.Value);
if (enumType != null)
{
return enumType;
}
throw new CompileException(Diagnostic
.Error($"Type {customType.NameToken.Value} not found in module {module.Name}")
.At(customType)
.Build());
}
public Scope SubScope()
private record Variable(IdentifierToken Name, NubType Type);
private class Scope(ModuleRepository.Module module, Scope? parent = null)
{
return new Scope(Module, this);
private NubType? _returnType;
private readonly List<Variable> _variables = [];
public ModuleRepository.Module Module { get; } = module;
public void DeclareVariable(Variable variable)
{
_variables.Add(variable);
}
public void SetReturnType(NubType returnType)
{
_returnType = returnType;
}
public NubType? GetReturnType()
{
return _returnType ?? parent?.GetReturnType();
}
public Variable? LookupVariable(IdentifierToken name)
{
var variable = _variables.FirstOrDefault(x => x.Name.Value == name.Value);
if (variable != null)
{
return variable;
}
return parent?.LookupVariable(name);
}
public Scope SubScope()
{
return new Scope(Module, this);
}
}
}

View File

@@ -1,96 +0,0 @@
using NubLang.Diagnostics;
using NubLang.Syntax;
namespace NubLang.Ast;
public class TypeResolver
{
private readonly Dictionary<string, Module> _modules;
private readonly Dictionary<(string Module, string Name), NubType> _typeCache = new();
private readonly HashSet<(string Module, string Name)> _resolvingTypes = [];
public TypeResolver(Dictionary<string, Module> modules)
{
_modules = modules;
}
public NubType ResolveType(TypeSyntax type, string currentModule)
{
return type switch
{
ArrayTypeSyntax arr => new NubArrayType(ResolveType(arr.BaseType, currentModule)),
BoolTypeSyntax => new NubBoolType(),
IntTypeSyntax i => new NubIntType(i.Signed, i.Width),
FloatTypeSyntax f => new NubFloatType(f.Width),
FuncTypeSyntax func => new NubFuncType(func.Parameters.Select(x => ResolveType(x, currentModule)).ToList(), ResolveType(func.ReturnType, currentModule)),
SliceTypeSyntax slice => new NubSliceType(ResolveType(slice.BaseType, currentModule)),
ConstArrayTypeSyntax arr => new NubConstArrayType(ResolveType(arr.BaseType, currentModule), arr.Size),
PointerTypeSyntax ptr => new NubPointerType(ResolveType(ptr.BaseType, currentModule)),
StringTypeSyntax => new NubStringType(),
CustomTypeSyntax c => ResolveCustomType(c, currentModule),
VoidTypeSyntax => new NubVoidType(),
_ => throw new NotSupportedException($"Unknown type syntax: {type}")
};
}
private NubType ResolveCustomType(CustomTypeSyntax customType, string currentModule)
{
var module = _modules[customType.ModuleToken?.Value ?? currentModule];
var enumDef = module.Enums(true).FirstOrDefault(x => x.NameToken.Value == customType.NameToken.Value);
if (enumDef != null)
{
return enumDef.Type != null ? ResolveType(enumDef.Type, currentModule) : new NubIntType(false, 64);
}
var structDef = module.Structs(true).FirstOrDefault(x => x.NameToken.Value == customType.NameToken.Value);
if (structDef != null)
{
var key = (customType.ModuleToken?.Value ?? currentModule, customType.NameToken.Value);
if (_typeCache.TryGetValue(key, out var cachedType))
{
return cachedType;
}
if (!_resolvingTypes.Add(key))
{
var placeholder = new NubStructType(customType.ModuleToken?.Value ?? currentModule, customType.NameToken.Value, structDef.Packed, []);
_typeCache[key] = placeholder;
return placeholder;
}
try
{
var result = new NubStructType(customType.ModuleToken?.Value ?? currentModule, structDef.NameToken.Value, structDef.Packed, []);
_typeCache[key] = result;
var fields = structDef.Fields
.Select(x => new NubStructFieldType(x.NameToken.Value, ResolveType(x.Type, currentModule), x.Value != null))
.ToList();
result.Fields.AddRange(fields);
return result;
}
finally
{
_resolvingTypes.Remove(key);
}
}
throw new TypeResolverException(Diagnostic
.Error($"Type {customType.NameToken.Value} not found in module {customType.ModuleToken?.Value ?? currentModule}")
.At(customType)
.Build());
}
}
public class TypeResolverException : Exception
{
public Diagnostic Diagnostic { get; }
public TypeResolverException(Diagnostic diagnostic) : base(diagnostic.Message)
{
Diagnostic = diagnostic;
}
}