This commit is contained in:
nub31
2025-10-25 18:07:34 +02:00
parent 3f18aa4782
commit 396ddf93a2
18 changed files with 951 additions and 598 deletions

View File

@@ -1,15 +1,11 @@
using NubLang.Syntax;
namespace NubLang.Ast;
public sealed class CompilationUnit
public sealed class CompilationUnit(IdentifierToken module, List<FuncNode> functions, Dictionary<IdentifierToken, List<NubStructType>> importedStructTypes, Dictionary<IdentifierToken, List<FuncPrototypeNode>> importedFunctions)
{
public CompilationUnit(List<FuncNode> functions, List<NubStructType> importedStructTypes, List<FuncPrototypeNode> importedFunctions)
{
Functions = functions;
ImportedStructTypes = importedStructTypes;
ImportedFunctions = importedFunctions;
}
public List<FuncNode> Functions { get; }
public List<NubStructType> ImportedStructTypes { get; }
public List<FuncPrototypeNode> ImportedFunctions { get; }
public IdentifierToken Module { get; } = module;
public List<FuncNode> Functions { get; } = functions;
public Dictionary<IdentifierToken, List<NubStructType>> ImportedStructTypes { get; } = importedStructTypes;
public Dictionary<IdentifierToken, List<FuncPrototypeNode>> ImportedFunctions { get; } = importedFunctions;
}

View File

@@ -31,15 +31,14 @@ public abstract class Node(List<Token> tokens)
#region Definitions
public abstract class DefinitionNode(List<Token> tokens, string module, string name) : Node(tokens)
public abstract class DefinitionNode(List<Token> tokens, IdentifierToken nameToken) : Node(tokens)
{
public string Module { get; } = module;
public string Name { get; } = name;
public IdentifierToken NameToken { get; } = nameToken;
}
public class FuncParameterNode(List<Token> tokens, string name, NubType type) : Node(tokens)
public class FuncParameterNode(List<Token> tokens, IdentifierToken nameToken, NubType type) : Node(tokens)
{
public string Name { get; } = name;
public IdentifierToken NameToken { get; } = nameToken;
public NubType Type { get; } = type;
public override IEnumerable<Node> Children()
@@ -48,11 +47,10 @@ public class FuncParameterNode(List<Token> tokens, string name, NubType type) :
}
}
public class FuncPrototypeNode(List<Token> tokens, string module, string name, string? externSymbol, List<FuncParameterNode> parameters, NubType returnType) : Node(tokens)
public class FuncPrototypeNode(List<Token> tokens, IdentifierToken nameToken, StringLiteralToken? externSymbolToken, List<FuncParameterNode> parameters, NubType returnType) : Node(tokens)
{
public string Module { get; } = module;
public string Name { get; } = name;
public string? ExternSymbol { get; } = externSymbol;
public IdentifierToken NameToken { get; } = nameToken;
public StringLiteralToken? ExternSymbolToken { get; } = externSymbolToken;
public List<FuncParameterNode> Parameters { get; } = parameters;
public NubType ReturnType { get; } = returnType;
@@ -62,7 +60,7 @@ public class FuncPrototypeNode(List<Token> tokens, string module, string name, s
}
}
public class FuncNode(List<Token> tokens, FuncPrototypeNode prototype, BlockNode? body) : DefinitionNode(tokens, prototype.Module, prototype.Name)
public class FuncNode(List<Token> tokens, FuncPrototypeNode prototype, BlockNode? body) : DefinitionNode(tokens, prototype.NameToken)
{
public FuncPrototypeNode Prototype { get; } = prototype;
public BlockNode? Body { get; } = body;
@@ -144,9 +142,9 @@ public class IfNode(List<Token> tokens, ExpressionNode condition, BlockNode body
}
}
public class VariableDeclarationNode(List<Token> tokens, string name, ExpressionNode? assignment, NubType type) : StatementNode(tokens)
public class VariableDeclarationNode(List<Token> tokens, IdentifierToken nameToken, ExpressionNode? assignment, NubType type) : StatementNode(tokens)
{
public string Name { get; } = name;
public IdentifierToken NameToken { get; } = nameToken;
public ExpressionNode? Assignment { get; } = assignment;
public NubType Type { get; } = type;
@@ -184,10 +182,10 @@ public class WhileNode(List<Token> tokens, ExpressionNode condition, BlockNode b
}
}
public class ForSliceNode(List<Token> tokens, string elementName, string? indexName, ExpressionNode target, BlockNode body) : StatementNode(tokens)
public class ForSliceNode(List<Token> tokens, IdentifierToken elementNameToken, IdentifierToken? indexNameToken, ExpressionNode target, BlockNode body) : StatementNode(tokens)
{
public string ElementName { get; } = elementName;
public string? IndexName { get; } = indexName;
public IdentifierToken ElementNameToken { get; } = elementNameToken;
public IdentifierToken? IndexNameToken { get; } = indexNameToken;
public ExpressionNode Target { get; } = target;
public BlockNode Body { get; } = body;
@@ -198,10 +196,10 @@ public class ForSliceNode(List<Token> tokens, string elementName, string? indexN
}
}
public class ForConstArrayNode(List<Token> tokens, string elementName, string? indexName, ExpressionNode target, BlockNode body) : StatementNode(tokens)
public class ForConstArrayNode(List<Token> tokens, IdentifierToken elementNameToken, IdentifierToken? indexNameToken, ExpressionNode target, BlockNode body) : StatementNode(tokens)
{
public string ElementName { get; } = elementName;
public string? IndexName { get; } = indexName;
public IdentifierToken ElementNameToken { get; } = elementNameToken;
public IdentifierToken? IndexNameToken { get; } = indexNameToken;
public ExpressionNode Target { get; } = target;
public BlockNode Body { get; } = body;
@@ -385,7 +383,7 @@ public class Float64LiteralNode(List<Token> tokens, double value) : RValueExpres
}
}
public class BoolLiteralNode(List<Token> tokens, NubType type, bool value) : RValueExpressionNode(tokens, type)
public class BoolLiteralNode(List<Token> tokens, bool value) : RValueExpressionNode(tokens, new NubBoolType())
{
public bool Value { get; } = value;
@@ -434,9 +432,9 @@ public class FuncCallNode(List<Token> tokens, NubType type, ExpressionNode expre
}
}
public class VariableIdentifierNode(List<Token> tokens, NubType type, string name) : LValueExpressionNode(tokens, type)
public class VariableIdentifierNode(List<Token> tokens, NubType type, IdentifierToken nameToken) : LValueExpressionNode(tokens, type)
{
public string Name { get; } = name;
public IdentifierToken NameToken { get; } = nameToken;
public override IEnumerable<Node> Children()
{
@@ -444,11 +442,11 @@ public class VariableIdentifierNode(List<Token> tokens, NubType type, string nam
}
}
public class FuncIdentifierNode(List<Token> tokens, NubType type, string module, string name, string? externSymbol) : RValueExpressionNode(tokens, type)
public class FuncIdentifierNode(List<Token> tokens, NubType type, IdentifierToken moduleToken, IdentifierToken nameToken, StringLiteralToken? externSymbolToken) : RValueExpressionNode(tokens, type)
{
public string Module { get; } = module;
public string Name { get; } = name;
public string? ExternSymbol { get; } = externSymbol;
public IdentifierToken ModuleToken { get; } = moduleToken;
public IdentifierToken NameToken { get; } = nameToken;
public StringLiteralToken? ExternSymbolToken { get; } = externSymbolToken;
public override IEnumerable<Node> Children()
{
@@ -522,10 +520,10 @@ public class AddressOfNode(List<Token> tokens, NubType type, LValueExpressionNod
}
}
public class StructFieldAccessNode(List<Token> tokens, NubType type, ExpressionNode target, string field) : LValueExpressionNode(tokens, type)
public class StructFieldAccessNode(List<Token> tokens, NubType type, ExpressionNode target, IdentifierToken fieldToken) : LValueExpressionNode(tokens, type)
{
public ExpressionNode Target { get; } = target;
public string Field { get; } = field;
public IdentifierToken FieldToken { get; } = fieldToken;
public override IEnumerable<Node> Children()
{
@@ -533,9 +531,9 @@ public class StructFieldAccessNode(List<Token> tokens, NubType type, ExpressionN
}
}
public class StructInitializerNode(List<Token> tokens, NubType type, Dictionary<string, ExpressionNode> initializers) : RValueExpressionNode(tokens, type)
public class StructInitializerNode(List<Token> tokens, NubType type, Dictionary<IdentifierToken, ExpressionNode> initializers) : RValueExpressionNode(tokens, type)
{
public Dictionary<string, ExpressionNode> Initializers { get; } = initializers;
public Dictionary<IdentifierToken, ExpressionNode> Initializers { get; } = initializers;
public override IEnumerable<Node> Children()
{
@@ -576,10 +574,10 @@ public class CastNode(List<Token> tokens, NubType type, ExpressionNode value) :
}
}
public class EnumReferenceIntermediateNode(List<Token> tokens, string module, string name) : IntermediateExpression(tokens)
public class EnumReferenceIntermediateNode(List<Token> tokens, IdentifierToken moduleToken, IdentifierToken nameToken) : IntermediateExpression(tokens)
{
public string Module { get; } = module;
public string Name { get; } = name;
public IdentifierToken ModuleToken { get; } = moduleToken;
public IdentifierToken NameToken { get; } = nameToken;
public override IEnumerable<Node> Children()
{

View File

@@ -106,10 +106,10 @@ public class NubSliceType(NubType elementType) : NubType
public override int GetHashCode() => HashCode.Combine(typeof(NubSliceType), ElementType);
}
public class NubConstArrayType(NubType elementType, long size) : NubType
public class NubConstArrayType(NubType elementType, ulong size) : NubType
{
public NubType ElementType { get; } = elementType;
public long Size { get; } = size;
public ulong Size { get; } = size;
public override string ToString() => $"[{Size}]{ElementType}";
public override bool Equals(NubType? other) => other is NubConstArrayType array && ElementType.Equals(array.ElementType) && Size == array.Size;

View File

@@ -7,7 +7,7 @@ namespace NubLang.Ast;
public sealed class TypeChecker
{
private readonly SyntaxTree _syntaxTree;
private readonly Dictionary<string, Module> _importedModules;
private readonly Dictionary<string, Module> _modules;
private readonly Stack<Scope> _scopes = [];
private readonly Dictionary<(string Module, string Name), NubType> _typeCache = new();
@@ -20,22 +20,74 @@ public sealed class TypeChecker
public TypeChecker(SyntaxTree syntaxTree, Dictionary<string, Module> modules)
{
_syntaxTree = syntaxTree;
_importedModules = modules
.Where(x => syntaxTree.Imports.Contains(x.Key) || _syntaxTree.ModuleName == x.Key)
.ToDictionary();
_modules = modules;
}
public CompilationUnit Check()
public CompilationUnit? Check()
{
_scopes.Clear();
_typeCache.Clear();
_resolvingTypes.Clear();
var moduleDeclarations = _syntaxTree.TopLevelSyntaxNodes.OfType<ModuleSyntax>().ToList();
if (moduleDeclarations.Count == 0)
{
Diagnostics.Add(Diagnostic.Error("Missing module declaration").WithHelp("module \"main\"").Build());
return null;
}
if (moduleDeclarations.Count > 1)
{
Diagnostics.Add(Diagnostic.Error("Multiple module declarations").WithHelp("Remove extra module declarations").Build());
}
var moduleName = moduleDeclarations[0].NameToken;
var importDeclarations = _syntaxTree.TopLevelSyntaxNodes.OfType<ImportSyntax>().ToList();
foreach (var importDeclaration in importDeclarations)
{
var name = importDeclaration.NameToken.Value;
var last = importDeclarations.Last(x => x.NameToken.Value == name);
if (importDeclaration != last)
{
Diagnostics.Add(Diagnostic
.Warning($"Module \"{last.NameToken.Value}\" is imported twice")
.WithHelp($"Remove duplicate import \"{last.NameToken.Value}\"")
.At(last)
.Build());
}
var exists = _modules.ContainsKey(name);
if (!exists)
{
var suggestions = _modules.Keys
.Select(m => new { Name = m, Distance = Utils.LevenshteinDistance(name, m) })
.OrderBy(x => x.Distance)
.Take(3)
.Where(x => x.Distance <= 3)
.Select(x => $"\"{x.Name}\"")
.ToArray();
var suggestionText = suggestions.Length != 0
? $"Did you mean {string.Join(", ", suggestions)}?"
: "Check that the module name is correct.";
Diagnostics.Add(Diagnostic
.Error($"Module \"{name}\" does not exist")
.WithHelp(suggestionText)
.At(last)
.Build());
return null;
}
}
var functions = new List<FuncNode>();
using (BeginRootScope(_syntaxTree.ModuleName))
using (BeginRootScope(moduleName))
{
foreach (var funcSyntax in _syntaxTree.Definitions.OfType<FuncSyntax>())
foreach (var funcSyntax in _syntaxTree.TopLevelSyntaxNodes.OfType<FuncSyntax>())
{
try
{
@@ -48,11 +100,14 @@ public sealed class TypeChecker
}
}
var importedStructTypes = new List<NubStructType>();
var importedFunctions = new List<FuncPrototypeNode>();
var importedStructTypes = new Dictionary<IdentifierToken, List<NubStructType>>();
var importedFunctions = new Dictionary<IdentifierToken, List<FuncPrototypeNode>>();
foreach (var (name, module) in _importedModules)
foreach (var (name, module) in GetImportedModules())
{
var moduleStructs = new List<NubStructType>();
var moduleFunctions = new List<FuncPrototypeNode>();
using (BeginRootScope(name))
{
foreach (var structSyntax in module.Structs(true))
@@ -60,10 +115,10 @@ public sealed class TypeChecker
try
{
var fields = structSyntax.Fields
.Select(f => new NubStructFieldType(f.Name, ResolveType(f.Type), f.Value != null))
.Select(f => new NubStructFieldType(f.NameToken.Value, ResolveType(f.Type), f.Value != null))
.ToList();
importedStructTypes.Add(new NubStructType(name, structSyntax.Name, fields));
moduleStructs.Add(new NubStructType(name.Value, structSyntax.NameToken.Value, fields));
}
catch (TypeCheckerException e)
{
@@ -71,21 +126,62 @@ public sealed class TypeChecker
}
}
importedStructTypes[name] = moduleStructs;
foreach (var funcSyntax in module.Functions(true))
{
try
{
importedFunctions.Add(CheckFuncPrototype(funcSyntax.Prototype));
moduleFunctions.Add(CheckFuncPrototype(funcSyntax.Prototype));
}
catch (TypeCheckerException e)
{
Diagnostics.Add(e.Diagnostic);
}
}
importedFunctions[name] = moduleFunctions;
}
}
return new CompilationUnit(functions, importedStructTypes, importedFunctions);
return new CompilationUnit(moduleName, functions, importedStructTypes, importedFunctions);
}
private List<(IdentifierToken Name, Module Module)> GetImportedModules()
{
var currentModule = _syntaxTree.TopLevelSyntaxNodes.OfType<ModuleSyntax>().First().NameToken;
return _syntaxTree.TopLevelSyntaxNodes
.OfType<ImportSyntax>()
.Select(x => (Name: x.NameToken, Module: _modules[x.NameToken.Value]))
.Concat([(Name: currentModule, Module: _modules[currentModule.Value])])
.ToList();
}
private bool IsCurrentModule(IdentifierToken? module)
{
if (module == null)
{
return true;
}
return module.Value == Scope.Module.Value;
}
private Module? GetImportedModule(string module)
{
var currentModule = _syntaxTree.TopLevelSyntaxNodes.OfType<ModuleSyntax>().First().NameToken;
if (module == currentModule.Value)
{
return _modules[currentModule.Value];
}
var import = _syntaxTree.TopLevelSyntaxNodes.OfType<ImportSyntax>().FirstOrDefault(x => x.NameToken.Value == module);
if (import != null)
{
return _modules[import.NameToken.Value];
}
return null;
}
private ScopeDisposer BeginScope()
@@ -94,7 +190,7 @@ public sealed class TypeChecker
return new ScopeDisposer(this);
}
private ScopeDisposer BeginRootScope(string moduleName)
private ScopeDisposer BeginRootScope(IdentifierToken moduleName)
{
_scopes.Push(new Scope(moduleName));
return new ScopeDisposer(this);
@@ -121,7 +217,7 @@ public sealed class TypeChecker
Scope.SetReturnType(prototype.ReturnType);
foreach (var parameter in prototype.Parameters)
{
Scope.DeclareVariable(new Variable(parameter.Name, parameter.Type));
Scope.DeclareVariable(new Variable(parameter.NameToken, parameter.Type));
}
var body = node.Body == null ? null : CheckBlock(node.Body);
@@ -217,14 +313,14 @@ public sealed class TypeChecker
if (type == null)
{
throw new TypeCheckerException(Diagnostic
.Error($"Cannot infer type of variable {statement.Name}")
.Error($"Cannot infer type of variable {statement.NameToken.Value}")
.At(statement)
.Build());
}
Scope.DeclareVariable(new Variable(statement.Name, type));
Scope.DeclareVariable(new Variable(statement.NameToken, type));
return new VariableDeclarationNode(statement.Tokens, statement.Name, assignmentNode, type);
return new VariableDeclarationNode(statement.Tokens, statement.NameToken, assignmentNode, type);
}
private WhileNode CheckWhile(WhileSyntax statement)
@@ -245,28 +341,28 @@ public sealed class TypeChecker
{
using (BeginScope())
{
Scope.DeclareVariable(new Variable(forSyntax.ElementName, sliceType.ElementType));
if (forSyntax.IndexName != null)
Scope.DeclareVariable(new Variable(forSyntax.ElementNameToken, sliceType.ElementType));
if (forSyntax.IndexNameToken != null)
{
Scope.DeclareVariable(new Variable(forSyntax.IndexName, new NubIntType(false, 64)));
Scope.DeclareVariable(new Variable(forSyntax.IndexNameToken, new NubIntType(false, 64)));
}
var body = CheckBlock(forSyntax.Body);
return new ForSliceNode(forSyntax.Tokens, forSyntax.ElementName, forSyntax.IndexName, target, body);
return new ForSliceNode(forSyntax.Tokens, forSyntax.ElementNameToken, forSyntax.IndexNameToken, target, body);
}
}
case NubConstArrayType constArrayType:
{
using (BeginScope())
{
Scope.DeclareVariable(new Variable(forSyntax.ElementName, constArrayType.ElementType));
if (forSyntax.IndexName != null)
Scope.DeclareVariable(new Variable(forSyntax.ElementNameToken, constArrayType.ElementType));
if (forSyntax.IndexNameToken != null)
{
Scope.DeclareVariable(new Variable(forSyntax.IndexName, new NubIntType(false, 64)));
Scope.DeclareVariable(new Variable(forSyntax.IndexNameToken, new NubIntType(false, 64)));
}
var body = CheckBlock(forSyntax.Body);
return new ForConstArrayNode(forSyntax.Tokens, forSyntax.ElementName, forSyntax.IndexName, target, body);
return new ForConstArrayNode(forSyntax.Tokens, forSyntax.ElementNameToken, forSyntax.IndexNameToken, target, body);
}
}
default:
@@ -284,10 +380,10 @@ public sealed class TypeChecker
var parameters = new List<FuncParameterNode>();
foreach (var parameter in statement.Parameters)
{
parameters.Add(new FuncParameterNode(parameter.Tokens, parameter.Name, ResolveType(parameter.Type)));
parameters.Add(new FuncParameterNode(parameter.Tokens, parameter.NameToken, ResolveType(parameter.Type)));
}
return new FuncPrototypeNode(statement.Tokens, Scope.Module, statement.Name, statement.ExternSymbol, parameters, ResolveType(statement.ReturnType));
return new FuncPrototypeNode(statement.Tokens, statement.NameToken, statement.ExternSymbolToken, parameters, ResolveType(statement.ReturnType));
}
private ExpressionNode CheckExpression(ExpressionSyntax node, NubType? expectedType = null)
@@ -480,7 +576,7 @@ public sealed class TypeChecker
{
NubArrayType => new ArrayInitializerNode(expression.Tokens, new NubArrayType(elementType), values),
NubConstArrayType constArrayType => new ConstArrayInitializerNode(expression.Tokens, constArrayType, values),
_ => new ConstArrayInitializerNode(expression.Tokens, new NubConstArrayType(elementType, expression.Values.Count), values)
_ => new ConstArrayInitializerNode(expression.Tokens, new NubConstArrayType(elementType, (ulong)expression.Values.Count), values)
};
}
@@ -747,81 +843,80 @@ public sealed class TypeChecker
return new FuncCallNode(expression.Tokens, funcType.ReturnType, accessor, parameters);
}
private ExpressionNode? CheckIdentifier(ExpressionSyntax expression, string moduleName, string name)
{
if (!_importedModules.TryGetValue(moduleName, out var module))
{
throw new TypeCheckerException(Diagnostic
.Error($"Module {moduleName} not found")
.WithHelp($"import \"{moduleName}\"")
.At(expression)
.Build());
}
var function = module.Functions(IsCurretModule(moduleName)).FirstOrDefault(x => x.Name == name);
if (function != null)
{
using (BeginRootScope(moduleName))
{
var parameters = function.Prototype.Parameters.Select(x => ResolveType(x.Type)).ToList();
var type = new NubFuncType(parameters, ResolveType(function.Prototype.ReturnType));
return new FuncIdentifierNode(expression.Tokens, type, moduleName, name, function.Prototype.ExternSymbol);
}
}
var enumDef = module.Enums(IsCurretModule(moduleName)).FirstOrDefault(x => x.Name == name);
if (enumDef != null)
{
return new EnumReferenceIntermediateNode(expression.Tokens, moduleName, name);
}
return null;
}
private ExpressionNode CheckLocalIdentifier(LocalIdentifierSyntax expression, NubType? _)
{
// note(nub31): Local identifiers can be variables or a symbol in a module
var scopeIdent = Scope.LookupVariable(expression.Name);
var scopeIdent = Scope.LookupVariable(expression.NameToken);
if (scopeIdent != null)
{
return new VariableIdentifierNode(expression.Tokens, scopeIdent.Type, expression.Name);
return new VariableIdentifierNode(expression.Tokens, scopeIdent.Type, expression.NameToken);
}
var ident = CheckIdentifier(expression, Scope.Module, expression.Name);
if (ident == null)
var module = GetImportedModule(Scope.Module.Value)!;
var function = module.Functions(true).FirstOrDefault(x => x.NameToken.Value == expression.NameToken.Value);
if (function != null)
{
throw new TypeCheckerException(Diagnostic
.Error($"There is no identifier named {expression.Name}")
.At(expression)
.Build());
var parameters = function.Prototype.Parameters.Select(x => ResolveType(x.Type)).ToList();
var type = new NubFuncType(parameters, ResolveType(function.Prototype.ReturnType));
return new FuncIdentifierNode(expression.Tokens, type, Scope.Module, expression.NameToken, function.Prototype.ExternSymbolToken);
}
return ident;
var enumDef = module.Enums(true).FirstOrDefault(x => x.NameToken.Value == expression.NameToken.Value);
if (enumDef != null)
{
return new EnumReferenceIntermediateNode(expression.Tokens, Scope.Module, expression.NameToken);
}
throw new TypeCheckerException(Diagnostic
.Error($"There is no identifier named {expression.NameToken.Value}")
.At(expression)
.Build());
}
private ExpressionNode CheckModuleIdentifier(ModuleIdentifierSyntax expression, NubType? _)
{
// note(nub31): Unlike local identifiers, module identifiers does not look for local variables
var ident = CheckIdentifier(expression, expression.Module, expression.Name);
if (ident == null)
var module = GetImportedModule(expression.ModuleToken.Value);
if (module == null)
{
throw new TypeCheckerException(Diagnostic
.Error($"Module {expression.Module} does not export a member named {expression.Name}")
.At(expression)
.Error($"Module {expression.ModuleToken.Value} not found")
.WithHelp($"import \"{expression.ModuleToken.Value}\"")
.At(expression.ModuleToken)
.Build());
}
return ident;
var function = module.Functions(false).FirstOrDefault(x => x.NameToken.Value == expression.NameToken.Value);
if (function != null)
{
using (BeginRootScope(expression.ModuleToken))
{
var parameters = function.Prototype.Parameters.Select(x => ResolveType(x.Type)).ToList();
var type = new NubFuncType(parameters, ResolveType(function.Prototype.ReturnType));
return new FuncIdentifierNode(expression.Tokens, type, expression.ModuleToken, expression.NameToken, function.Prototype.ExternSymbolToken);
}
}
var enumDef = module.Enums(false).FirstOrDefault(x => x.NameToken.Value == expression.NameToken.Value);
if (enumDef != null)
{
return new EnumReferenceIntermediateNode(expression.Tokens, expression.ModuleToken, expression.NameToken);
}
throw new TypeCheckerException(Diagnostic
.Error($"Module {expression.ModuleToken.Value} does not export a member named {expression.NameToken.Value}")
.At(expression)
.Build());
}
private ExpressionNode CheckStringLiteral(StringLiteralSyntax expression, NubType? expectedType)
{
if (expectedType is NubPointerType { BaseType: NubIntType { Signed: true, Width: 8 } })
{
return new CStringLiteralNode(expression.Tokens, expression.Value);
return new CStringLiteralNode(expression.Tokens, expression.Token.Value);
}
return new StringLiteralNode(expression.Tokens, expression.Value);
return new StringLiteralNode(expression.Tokens, expression.Token.Value);
}
private ExpressionNode CheckIntLiteral(IntLiteralSyntax expression, NubType? expectedType)
@@ -830,10 +925,10 @@ public sealed class TypeChecker
{
return intType.Width switch
{
8 => intType.Signed ? new I8LiteralNode(expression.Tokens, Convert.ToSByte(expression.Value, expression.Base)) : new U8LiteralNode(expression.Tokens, Convert.ToByte(expression.Value, expression.Base)),
16 => intType.Signed ? new I16LiteralNode(expression.Tokens, Convert.ToInt16(expression.Value, expression.Base)) : new U16LiteralNode(expression.Tokens, Convert.ToUInt16(expression.Value, expression.Base)),
32 => intType.Signed ? new I32LiteralNode(expression.Tokens, Convert.ToInt32(expression.Value, expression.Base)) : new U32LiteralNode(expression.Tokens, Convert.ToUInt32(expression.Value, expression.Base)),
64 => intType.Signed ? new I64LiteralNode(expression.Tokens, Convert.ToInt64(expression.Value, expression.Base)) : new U64LiteralNode(expression.Tokens, Convert.ToUInt64(expression.Value, expression.Base)),
8 => intType.Signed ? new I8LiteralNode(expression.Tokens, expression.Token.AsI8) : new U8LiteralNode(expression.Tokens, expression.Token.AsU8),
16 => intType.Signed ? new I16LiteralNode(expression.Tokens, expression.Token.AsI16) : new U16LiteralNode(expression.Tokens, expression.Token.AsU16),
32 => intType.Signed ? new I32LiteralNode(expression.Tokens, expression.Token.AsI32) : new U32LiteralNode(expression.Tokens, expression.Token.AsU32),
64 => intType.Signed ? new I64LiteralNode(expression.Tokens, expression.Token.AsI64) : new U64LiteralNode(expression.Tokens, expression.Token.AsU64),
_ => throw new ArgumentOutOfRangeException()
};
}
@@ -842,13 +937,13 @@ public sealed class TypeChecker
{
return floatType.Width switch
{
32 => new Float32LiteralNode(expression.Tokens, Convert.ToSingle(expression.Value)),
64 => new Float64LiteralNode(expression.Tokens, Convert.ToDouble(expression.Value)),
32 => new Float32LiteralNode(expression.Tokens, expression.Token.AsF32),
64 => new Float64LiteralNode(expression.Tokens, expression.Token.AsF64),
_ => throw new ArgumentOutOfRangeException()
};
}
return new I64LiteralNode(expression.Tokens, Convert.ToInt64(expression.Value, expression.Base));
return new I64LiteralNode(expression.Tokens, expression.Token.AsI64);
}
private ExpressionNode CheckFloatLiteral(FloatLiteralSyntax expression, NubType? expectedType)
@@ -857,18 +952,18 @@ public sealed class TypeChecker
{
return floatType.Width switch
{
32 => new Float32LiteralNode(expression.Tokens, Convert.ToSingle(expression.Value)),
64 => new Float64LiteralNode(expression.Tokens, Convert.ToDouble(expression.Value)),
32 => new Float32LiteralNode(expression.Tokens, expression.Token.AsF32),
64 => new Float64LiteralNode(expression.Tokens, expression.Token.AsF64),
_ => throw new ArgumentOutOfRangeException()
};
}
return new Float64LiteralNode(expression.Tokens, Convert.ToDouble(expression.Value));
return new Float64LiteralNode(expression.Tokens, expression.Token.AsF64);
}
private BoolLiteralNode CheckBoolLiteral(BoolLiteralSyntax expression, NubType? _)
{
return new BoolLiteralNode(expression.Tokens, new NubBoolType(), expression.Value);
return new BoolLiteralNode(expression.Tokens, expression.Token.Value);
}
private ExpressionNode CheckMemberAccess(MemberAccessSyntax expression, NubType? _)
@@ -877,15 +972,17 @@ public sealed class TypeChecker
if (target is EnumReferenceIntermediateNode enumReferenceIntermediate)
{
var enumDef = _importedModules[enumReferenceIntermediate.Module]
.Enums(IsCurretModule(enumReferenceIntermediate.Module))
.First(x => x.Name == enumReferenceIntermediate.Name);
var enumDef = GetImportedModules()
.First(x => x.Name.Value == enumReferenceIntermediate.ModuleToken.Value)
.Module
.Enums(IsCurrentModule(enumReferenceIntermediate.ModuleToken))
.First(x => x.NameToken.Value == enumReferenceIntermediate.NameToken.Value);
var field = enumDef.Fields.FirstOrDefault(x => x.Name == expression.Member);
var field = enumDef.Fields.FirstOrDefault(x => x.NameToken.Value == expression.MemberToken.Value);
if (field == null)
{
throw new TypeCheckerException(Diagnostic
.Error($"Enum {Scope.Module}::{enumReferenceIntermediate.Name} does not have a field named {expression.Member}")
.Error($"Enum {Scope.Module.Value}::{enumReferenceIntermediate.NameToken.Value} does not have a field named {expression.MemberToken.Value}")
.At(enumDef)
.Build());
}
@@ -896,36 +993,96 @@ public sealed class TypeChecker
throw new TypeCheckerException(Diagnostic.Error("Enum type must be an int type").At(enumDef.Type).Build());
}
return enumIntType.Width switch
if (enumIntType.Signed)
{
8 => enumIntType.Signed ? new I8LiteralNode(expression.Tokens, (sbyte)field.Value) : new U8LiteralNode(expression.Tokens, (byte)field.Value),
16 => enumIntType.Signed ? new I16LiteralNode(expression.Tokens, (short)field.Value) : new U16LiteralNode(expression.Tokens, (ushort)field.Value),
32 => enumIntType.Signed ? new I32LiteralNode(expression.Tokens, (int)field.Value) : new U32LiteralNode(expression.Tokens, (uint)field.Value),
64 => enumIntType.Signed ? new I64LiteralNode(expression.Tokens, field.Value) : new U64LiteralNode(expression.Tokens, (ulong)field.Value),
_ => throw new ArgumentOutOfRangeException()
};
var fieldValue = CalculateSignedEnumFieldValue(enumDef, field);
return enumIntType.Width switch
{
8 => new I8LiteralNode(expression.Tokens, (sbyte)fieldValue),
16 => new I16LiteralNode(expression.Tokens, (short)fieldValue),
32 => new I32LiteralNode(expression.Tokens, (int)fieldValue),
64 => new I64LiteralNode(expression.Tokens, fieldValue),
_ => throw new ArgumentOutOfRangeException()
};
}
else
{
var fieldValue = CalculateUnsignedEnumFieldValue(enumDef, field);
return enumIntType.Width switch
{
8 => new U8LiteralNode(expression.Tokens, (byte)fieldValue),
16 => new U16LiteralNode(expression.Tokens, (ushort)fieldValue),
32 => new U32LiteralNode(expression.Tokens, (uint)fieldValue),
64 => new U64LiteralNode(expression.Tokens, fieldValue),
_ => throw new ArgumentOutOfRangeException()
};
}
}
if (target.Type is NubStructType structType)
{
var field = structType.Fields.FirstOrDefault(x => x.Name == expression.Member);
var field = structType.Fields.FirstOrDefault(x => x.Name == expression.MemberToken.Value);
if (field == null)
{
throw new TypeCheckerException(Diagnostic
.Error($"Struct {target.Type} does not have a field with the name {expression.Member}")
.Error($"Struct {target.Type} does not have a field with the name {expression.MemberToken.Value}")
.At(expression)
.Build());
}
return new StructFieldAccessNode(expression.Tokens, field.Type, target, expression.Member);
return new StructFieldAccessNode(expression.Tokens, field.Type, target, expression.MemberToken);
}
throw new TypeCheckerException(Diagnostic
.Error($"Cannot access struct member {expression.Member} on type {target.Type}")
.Error($"Cannot access struct member {expression.MemberToken.Value} on type {target.Type}")
.At(expression)
.Build());
}
private static long CalculateSignedEnumFieldValue(EnumSyntax enumDef, EnumFieldSyntax field)
{
long currentValue = 0;
foreach (var f in enumDef.Fields)
{
if (f.ValueToken != null)
{
currentValue = f.ValueToken.AsI64;
}
if (f == field)
{
return currentValue;
}
currentValue++;
}
throw new UnreachableException();
}
private static ulong CalculateUnsignedEnumFieldValue(EnumSyntax enumDef, EnumFieldSyntax field)
{
ulong currentValue = 0;
foreach (var f in enumDef.Fields)
{
if (f.ValueToken != null)
{
currentValue = f.ValueToken.AsU64;
}
if (f == field)
{
return currentValue;
}
currentValue++;
}
throw new UnreachableException();
}
private StructInitializerNode CheckStructInitializer(StructInitializerSyntax expression, NubType? expectedType)
{
NubStructType? structType = null;
@@ -954,11 +1111,11 @@ public sealed class TypeChecker
.Build());
}
var initializers = new Dictionary<string, ExpressionNode>();
var initializers = new Dictionary<IdentifierToken, ExpressionNode>();
foreach (var initializer in expression.Initializers)
{
var typeField = structType.Fields.FirstOrDefault(x => x.Name == initializer.Key);
var typeField = structType.Fields.FirstOrDefault(x => x.Name == initializer.Key.Value);
if (typeField == null)
{
Diagnostics.AddRange(Diagnostic
@@ -973,7 +1130,7 @@ public sealed class TypeChecker
}
var missingFields = structType.Fields
.Where(x => !x.HasDefaultValue && !initializers.ContainsKey(x.Name))
.Where(x => !x.HasDefaultValue && initializers.All(y => y.Key.Value != x.Name))
.Select(x => x.Name)
.ToArray();
@@ -1049,25 +1206,26 @@ public sealed class TypeChecker
private NubType ResolveCustomType(CustomTypeSyntax customType)
{
if (!_importedModules.TryGetValue(customType.Module ?? Scope.Module, out var module))
var module = GetImportedModule(customType.ModuleToken?.Value ?? Scope.Module.Value);
if (module == null)
{
throw new TypeCheckerException(Diagnostic
.Error($"Module {customType.Module ?? Scope.Module} not found")
.WithHelp($"import \"{customType.Module ?? Scope.Module}\"")
.Error($"Module {customType.ModuleToken?.Value ?? Scope.Module.Value} not found")
.WithHelp($"import \"{customType.ModuleToken?.Value ?? Scope.Module.Value}\"")
.At(customType)
.Build());
}
var enumDef = module.Enums(IsCurretModule(customType.Module)).FirstOrDefault(x => x.Name == customType.Name);
var enumDef = module.Enums(IsCurrentModule(customType.ModuleToken)).FirstOrDefault(x => x.NameToken.Value == customType.NameToken.Value);
if (enumDef != null)
{
return enumDef.Type != null ? ResolveType(enumDef.Type) : new NubIntType(false, 64);
}
var structDef = module.Structs(IsCurretModule(customType.Module)).FirstOrDefault(x => x.Name == customType.Name);
var structDef = module.Structs(IsCurrentModule(customType.ModuleToken)).FirstOrDefault(x => x.NameToken.Value == customType.NameToken.Value);
if (structDef != null)
{
var key = (customType.Module ?? Scope.Module, customType.Name);
var key = (customType.ModuleToken?.Value ?? Scope.Module.Value, customType.NameToken.Value);
if (_typeCache.TryGetValue(key, out var cachedType))
{
@@ -1076,18 +1234,18 @@ public sealed class TypeChecker
if (!_resolvingTypes.Add(key))
{
var placeholder = new NubStructType(customType.Module ?? Scope.Module, customType.Name, []);
var placeholder = new NubStructType(customType.ModuleToken?.Value ?? Scope.Module.Value, customType.NameToken.Value, []);
_typeCache[key] = placeholder;
return placeholder;
}
try
{
var result = new NubStructType(customType.Module ?? Scope.Module, structDef.Name, []);
var result = new NubStructType(customType.ModuleToken?.Value ?? Scope.Module.Value, structDef.NameToken.Value, []);
_typeCache[key] = result;
var fields = structDef.Fields
.Select(x => new NubStructFieldType(x.Name, ResolveType(x.Type), x.Value != null))
.Select(x => new NubStructFieldType(x.NameToken.Value, ResolveType(x.Type), x.Value != null))
.ToList();
result.Fields.AddRange(fields);
@@ -1100,29 +1258,19 @@ public sealed class TypeChecker
}
throw new TypeCheckerException(Diagnostic
.Error($"Type {customType.Name} not found in module {customType.Module ?? Scope.Module}")
.Error($"Type {customType.NameToken.Value} not found in module {customType.ModuleToken?.Value ?? Scope.Module.Value}")
.At(customType)
.Build());
}
private bool IsCurretModule(string? module)
{
if (module == null)
{
return true;
}
return module == Scope.Module;
}
}
public record Variable(string Name, NubType Type);
public record Variable(IdentifierToken Name, NubType Type);
public class Scope(string module, Scope? parent = null)
public class Scope(IdentifierToken module, Scope? parent = null)
{
private NubType? _returnType;
private readonly List<Variable> _variables = [];
public string Module { get; } = module;
public IdentifierToken Module { get; } = module;
public void DeclareVariable(Variable variable)
{
@@ -1139,9 +1287,9 @@ public class Scope(string module, Scope? parent = null)
return _returnType ?? parent?.GetReturnType();
}
public Variable? LookupVariable(string name)
public Variable? LookupVariable(IdentifierToken name)
{
var variable = _variables.FirstOrDefault(x => x.Name == name);
var variable = _variables.FirstOrDefault(x => x.Name.Value == name.Value);
if (variable != null)
{
return variable;