...
This commit is contained in:
@@ -2,10 +2,11 @@ using NubLang.Syntax;
|
||||
|
||||
namespace NubLang.Ast;
|
||||
|
||||
public sealed class CompilationUnit(IdentifierToken module, List<FuncNode> functions, Dictionary<IdentifierToken, List<NubStructType>> importedStructTypes, Dictionary<IdentifierToken, List<FuncPrototypeNode>> importedFunctions)
|
||||
{
|
||||
public IdentifierToken Module { get; } = module;
|
||||
public List<FuncNode> Functions { get; } = functions;
|
||||
public Dictionary<IdentifierToken, List<NubStructType>> ImportedStructTypes { get; } = importedStructTypes;
|
||||
public Dictionary<IdentifierToken, List<FuncPrototypeNode>> ImportedFunctions { get; } = importedFunctions;
|
||||
}
|
||||
// public sealed class CompilationUnit(IdentifierToken module, List<FuncNode> functions, List<StructNode> structTypes, Dictionary<IdentifierToken, List<NubStructType>> importedStructTypes, Dictionary<IdentifierToken, List<FuncPrototypeNode>> importedFunctions)
|
||||
// {
|
||||
// public IdentifierToken Module { get; } = module;
|
||||
// public List<FuncNode> Functions { get; } = functions;
|
||||
// public List<StructNode> Structs { get; } = structTypes;
|
||||
// public Dictionary<IdentifierToken, List<NubStructType>> ImportedStructTypes { get; } = importedStructTypes;
|
||||
// public Dictionary<IdentifierToken, List<FuncPrototypeNode>> ImportedFunctions { get; } = importedFunctions;
|
||||
// }
|
||||
@@ -29,9 +29,31 @@ public abstract class Node(List<Token> tokens)
|
||||
}
|
||||
}
|
||||
|
||||
public abstract class TopLevelNode(List<Token> tokens) : Node(tokens);
|
||||
|
||||
public class ImportNode(List<Token> tokens, IdentifierToken nameToken) : TopLevelNode(tokens)
|
||||
{
|
||||
public IdentifierToken NameToken { get; } = nameToken;
|
||||
|
||||
public override IEnumerable<Node> Children()
|
||||
{
|
||||
return [];
|
||||
}
|
||||
}
|
||||
|
||||
public class ModuleNode(List<Token> tokens, IdentifierToken nameToken) : TopLevelNode(tokens)
|
||||
{
|
||||
public IdentifierToken NameToken { get; } = nameToken;
|
||||
|
||||
public override IEnumerable<Node> Children()
|
||||
{
|
||||
return [];
|
||||
}
|
||||
}
|
||||
|
||||
#region Definitions
|
||||
|
||||
public abstract class DefinitionNode(List<Token> tokens, IdentifierToken nameToken) : Node(tokens)
|
||||
public abstract class DefinitionNode(List<Token> tokens, IdentifierToken nameToken) : TopLevelNode(tokens)
|
||||
{
|
||||
public IdentifierToken NameToken { get; } = nameToken;
|
||||
}
|
||||
@@ -75,6 +97,35 @@ public class FuncNode(List<Token> tokens, FuncPrototypeNode prototype, BlockNode
|
||||
}
|
||||
}
|
||||
|
||||
public class StructFieldNode(List<Token> tokens, IdentifierToken nameToken, NubType type, ExpressionNode? value) : Node(tokens)
|
||||
{
|
||||
public IdentifierToken NameToken { get; } = nameToken;
|
||||
public NubType Type { get; } = type;
|
||||
public ExpressionNode? Value { get; } = value;
|
||||
|
||||
public override IEnumerable<Node> Children()
|
||||
{
|
||||
if (Value != null)
|
||||
{
|
||||
yield return Value;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public class StructNode(List<Token> tokens, IdentifierToken name, NubStructType structType, List<StructFieldNode> fields) : DefinitionNode(tokens, name)
|
||||
{
|
||||
public NubStructType StructType { get; } = structType;
|
||||
public List<StructFieldNode> Fields { get; } = fields;
|
||||
|
||||
public override IEnumerable<Node> Children()
|
||||
{
|
||||
foreach (var field in Fields)
|
||||
{
|
||||
yield return field;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#endregion
|
||||
|
||||
#region Statements
|
||||
|
||||
@@ -10,8 +10,8 @@ public sealed class TypeChecker
|
||||
private readonly Dictionary<string, Module> _modules;
|
||||
|
||||
private readonly Stack<Scope> _scopes = [];
|
||||
private readonly Dictionary<(string Module, string Name), NubType> _typeCache = new();
|
||||
private readonly HashSet<(string Module, string Name)> _resolvingTypes = [];
|
||||
|
||||
private readonly TypeResolver _typeResolver;
|
||||
|
||||
private Scope Scope => _scopes.Peek();
|
||||
|
||||
@@ -21,19 +21,18 @@ public sealed class TypeChecker
|
||||
{
|
||||
_syntaxTree = syntaxTree;
|
||||
_modules = modules;
|
||||
_typeResolver = new TypeResolver(_modules);
|
||||
}
|
||||
|
||||
public CompilationUnit? Check()
|
||||
public List<TopLevelNode> Check()
|
||||
{
|
||||
_scopes.Clear();
|
||||
_typeCache.Clear();
|
||||
_resolvingTypes.Clear();
|
||||
|
||||
var moduleDeclarations = _syntaxTree.TopLevelSyntaxNodes.OfType<ModuleSyntax>().ToList();
|
||||
if (moduleDeclarations.Count == 0)
|
||||
{
|
||||
Diagnostics.Add(Diagnostic.Error("Missing module declaration").WithHelp("module \"main\"").Build());
|
||||
return null;
|
||||
return [];
|
||||
}
|
||||
|
||||
if (moduleDeclarations.Count > 1)
|
||||
@@ -79,72 +78,45 @@ public sealed class TypeChecker
|
||||
.At(last)
|
||||
.Build());
|
||||
|
||||
return null;
|
||||
return [];
|
||||
}
|
||||
}
|
||||
|
||||
var functions = new List<FuncNode>();
|
||||
var topLevelNodes = new List<TopLevelNode>();
|
||||
|
||||
using (BeginRootScope(moduleName))
|
||||
{
|
||||
foreach (var funcSyntax in _syntaxTree.TopLevelSyntaxNodes.OfType<FuncSyntax>())
|
||||
foreach (var topLevelSyntaxNode in _syntaxTree.TopLevelSyntaxNodes)
|
||||
{
|
||||
try
|
||||
switch (topLevelSyntaxNode)
|
||||
{
|
||||
functions.Add(CheckFuncDefinition(funcSyntax));
|
||||
}
|
||||
catch (TypeCheckerException e)
|
||||
{
|
||||
Diagnostics.Add(e.Diagnostic);
|
||||
case EnumSyntax:
|
||||
break;
|
||||
case FuncSyntax funcSyntax:
|
||||
topLevelNodes.Add(CheckFuncDefinition(funcSyntax));
|
||||
break;
|
||||
case StructSyntax structSyntax:
|
||||
topLevelNodes.Add(CheckStructDefinition(structSyntax));
|
||||
break;
|
||||
case ImportSyntax importSyntax:
|
||||
topLevelNodes.Add(new ImportNode(importSyntax.Tokens, importSyntax.NameToken));
|
||||
break;
|
||||
case ModuleSyntax moduleSyntax:
|
||||
topLevelNodes.Add(new ModuleNode(moduleSyntax.Tokens, moduleSyntax.NameToken));
|
||||
break;
|
||||
default:
|
||||
throw new ArgumentOutOfRangeException(nameof(topLevelSyntaxNode));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var importedStructTypes = new Dictionary<IdentifierToken, List<NubStructType>>();
|
||||
var importedFunctions = new Dictionary<IdentifierToken, List<FuncPrototypeNode>>();
|
||||
return topLevelNodes;
|
||||
}
|
||||
|
||||
foreach (var (name, module) in GetImportedModules())
|
||||
{
|
||||
var moduleStructs = new List<NubStructType>();
|
||||
var moduleFunctions = new List<FuncPrototypeNode>();
|
||||
|
||||
using (BeginRootScope(name))
|
||||
{
|
||||
foreach (var structSyntax in module.Structs(true))
|
||||
{
|
||||
try
|
||||
{
|
||||
var fields = structSyntax.Fields
|
||||
.Select(f => new NubStructFieldType(f.NameToken.Value, ResolveType(f.Type), f.Value != null))
|
||||
.ToList();
|
||||
|
||||
moduleStructs.Add(new NubStructType(name.Value, structSyntax.NameToken.Value, fields));
|
||||
}
|
||||
catch (TypeCheckerException e)
|
||||
{
|
||||
Diagnostics.Add(e.Diagnostic);
|
||||
}
|
||||
}
|
||||
|
||||
importedStructTypes[name] = moduleStructs;
|
||||
|
||||
foreach (var funcSyntax in module.Functions(true))
|
||||
{
|
||||
try
|
||||
{
|
||||
moduleFunctions.Add(CheckFuncPrototype(funcSyntax.Prototype));
|
||||
}
|
||||
catch (TypeCheckerException e)
|
||||
{
|
||||
Diagnostics.Add(e.Diagnostic);
|
||||
}
|
||||
}
|
||||
|
||||
importedFunctions[name] = moduleFunctions;
|
||||
}
|
||||
}
|
||||
|
||||
return new CompilationUnit(moduleName, functions, importedStructTypes, importedFunctions);
|
||||
private (IdentifierToken Name, Module Module) GetCurrentModule()
|
||||
{
|
||||
var currentModule = _syntaxTree.TopLevelSyntaxNodes.OfType<ModuleSyntax>().First().NameToken;
|
||||
return (currentModule, _modules[currentModule.Value]);
|
||||
}
|
||||
|
||||
private List<(IdentifierToken Name, Module Module)> GetImportedModules()
|
||||
@@ -225,19 +197,48 @@ public sealed class TypeChecker
|
||||
}
|
||||
}
|
||||
|
||||
private StructNode CheckStructDefinition(StructSyntax structSyntax)
|
||||
{
|
||||
var fields = new List<StructFieldNode>();
|
||||
|
||||
foreach (var field in structSyntax.Fields)
|
||||
{
|
||||
var fieldType = _typeResolver.ResolveType(field.Type, Scope.Module.Value);
|
||||
ExpressionNode? value = null;
|
||||
if (field.Value != null)
|
||||
{
|
||||
value = CheckExpression(field.Value, fieldType);
|
||||
if (value.Type != fieldType)
|
||||
{
|
||||
throw new CompileException(Diagnostic
|
||||
.Error($"Type {value.Type} is not assignable to {field.Type} for field {field.NameToken.Value}")
|
||||
.At(field)
|
||||
.Build());
|
||||
}
|
||||
}
|
||||
|
||||
fields.Add(new StructFieldNode(field.Tokens, field.NameToken, fieldType, value));
|
||||
}
|
||||
|
||||
var currentModule = GetCurrentModule();
|
||||
var type = new NubStructType(currentModule.Name.Value, structSyntax.NameToken.Value, fields.Select(x => new NubStructFieldType(x.NameToken.Value, x.Type, x.Value != null)).ToList());
|
||||
|
||||
return new StructNode(structSyntax.Tokens, structSyntax.NameToken, type, fields);
|
||||
}
|
||||
|
||||
private AssignmentNode CheckAssignment(AssignmentSyntax statement)
|
||||
{
|
||||
var target = CheckExpression(statement.Target);
|
||||
if (target is not LValueExpressionNode lValue)
|
||||
{
|
||||
throw new TypeCheckerException(Diagnostic.Error("Cannot assign to an rvalue").At(statement).Build());
|
||||
throw new CompileException(Diagnostic.Error("Cannot assign to an rvalue").At(statement).Build());
|
||||
}
|
||||
|
||||
var value = CheckExpression(statement.Value, lValue.Type);
|
||||
|
||||
if (value.Type != lValue.Type)
|
||||
{
|
||||
throw new TypeCheckerException(Diagnostic
|
||||
throw new CompileException(Diagnostic
|
||||
.Error($"Cannot assign {value.Type} to {lValue.Type}")
|
||||
.At(statement.Value)
|
||||
.Build());
|
||||
@@ -279,7 +280,7 @@ public sealed class TypeChecker
|
||||
return expression switch
|
||||
{
|
||||
FuncCallNode funcCall => new StatementFuncCallNode(statement.Tokens, funcCall),
|
||||
_ => throw new TypeCheckerException(Diagnostic.Error("Expressions statements can only be function calls").At(statement).Build())
|
||||
_ => throw new CompileException(Diagnostic.Error("Expressions statements can only be function calls").At(statement).Build())
|
||||
};
|
||||
}
|
||||
|
||||
@@ -290,7 +291,7 @@ public sealed class TypeChecker
|
||||
|
||||
if (statement.ExplicitType != null)
|
||||
{
|
||||
type = ResolveType(statement.ExplicitType);
|
||||
type = _typeResolver.ResolveType(statement.ExplicitType, Scope.Module.Value);
|
||||
}
|
||||
|
||||
if (statement.Assignment != null)
|
||||
@@ -303,7 +304,7 @@ public sealed class TypeChecker
|
||||
}
|
||||
else if (assignmentNode.Type != type)
|
||||
{
|
||||
throw new TypeCheckerException(Diagnostic
|
||||
throw new CompileException(Diagnostic
|
||||
.Error($"Cannot assign {assignmentNode.Type} to variable of type {type}")
|
||||
.At(statement.Assignment)
|
||||
.Build());
|
||||
@@ -312,7 +313,7 @@ public sealed class TypeChecker
|
||||
|
||||
if (type == null)
|
||||
{
|
||||
throw new TypeCheckerException(Diagnostic
|
||||
throw new CompileException(Diagnostic
|
||||
.Error($"Cannot infer type of variable {statement.NameToken.Value}")
|
||||
.At(statement)
|
||||
.Build());
|
||||
@@ -367,7 +368,7 @@ public sealed class TypeChecker
|
||||
}
|
||||
default:
|
||||
{
|
||||
throw new TypeCheckerException(Diagnostic
|
||||
throw new CompileException(Diagnostic
|
||||
.Error($"Cannot iterate over type {target.Type} which does not have size information")
|
||||
.At(forSyntax.Target)
|
||||
.Build());
|
||||
@@ -380,10 +381,10 @@ public sealed class TypeChecker
|
||||
var parameters = new List<FuncParameterNode>();
|
||||
foreach (var parameter in statement.Parameters)
|
||||
{
|
||||
parameters.Add(new FuncParameterNode(parameter.Tokens, parameter.NameToken, ResolveType(parameter.Type)));
|
||||
parameters.Add(new FuncParameterNode(parameter.Tokens, parameter.NameToken, _typeResolver.ResolveType(parameter.Type, Scope.Module.Value)));
|
||||
}
|
||||
|
||||
return new FuncPrototypeNode(statement.Tokens, statement.NameToken, statement.ExternSymbolToken, parameters, ResolveType(statement.ReturnType));
|
||||
return new FuncPrototypeNode(statement.Tokens, statement.NameToken, statement.ExternSymbolToken, parameters, _typeResolver.ResolveType(statement.ReturnType, Scope.Module.Value));
|
||||
}
|
||||
|
||||
private ExpressionNode CheckExpression(ExpressionSyntax node, NubType? expectedType = null)
|
||||
@@ -405,7 +406,7 @@ public sealed class TypeChecker
|
||||
FloatLiteralSyntax expression => CheckFloatLiteral(expression, expectedType),
|
||||
MemberAccessSyntax expression => CheckMemberAccess(expression, expectedType),
|
||||
StructInitializerSyntax expression => CheckStructInitializer(expression, expectedType),
|
||||
SizeSyntax expression => new SizeNode(node.Tokens, ResolveType(expression.Type)),
|
||||
SizeSyntax expression => new SizeNode(node.Tokens, _typeResolver.ResolveType(expression.Type, Scope.Module.Value)),
|
||||
CastSyntax expression => CheckCast(expression, expectedType),
|
||||
_ => throw new ArgumentOutOfRangeException(nameof(node))
|
||||
};
|
||||
@@ -430,7 +431,7 @@ public sealed class TypeChecker
|
||||
{
|
||||
if (expectedType == null)
|
||||
{
|
||||
throw new TypeCheckerException(Diagnostic
|
||||
throw new CompileException(Diagnostic
|
||||
.Error("Unable to infer target type of cast")
|
||||
.At(expression)
|
||||
.WithHelp("Specify target type where value is used")
|
||||
@@ -451,7 +452,7 @@ public sealed class TypeChecker
|
||||
|
||||
if (!IsCastAllowed(value.Type, expectedType, false))
|
||||
{
|
||||
throw new TypeCheckerException(Diagnostic
|
||||
throw new CompileException(Diagnostic
|
||||
.Error($"Cannot cast from {value.Type} to {expectedType}")
|
||||
.Build());
|
||||
}
|
||||
@@ -500,7 +501,7 @@ public sealed class TypeChecker
|
||||
var target = CheckExpression(expression.Target, (expectedType as NubPointerType)?.BaseType);
|
||||
if (target is not LValueExpressionNode lvalue)
|
||||
{
|
||||
throw new TypeCheckerException(Diagnostic.Error("Cannot take address of an rvalue").At(expression).Build());
|
||||
throw new CompileException(Diagnostic.Error("Cannot take address of an rvalue").At(expression).Build());
|
||||
}
|
||||
|
||||
var type = new NubPointerType(target.Type);
|
||||
@@ -512,7 +513,7 @@ public sealed class TypeChecker
|
||||
var index = CheckExpression(expression.Index);
|
||||
if (index.Type is not NubIntType)
|
||||
{
|
||||
throw new TypeCheckerException(Diagnostic
|
||||
throw new CompileException(Diagnostic
|
||||
.Error("Array indexer must be of type int")
|
||||
.At(expression.Index)
|
||||
.Build());
|
||||
@@ -525,7 +526,7 @@ public sealed class TypeChecker
|
||||
NubArrayType arrayType => new ArrayIndexAccessNode(expression.Tokens, arrayType.ElementType, target, index),
|
||||
NubConstArrayType constArrayType => new ConstArrayIndexAccessNode(expression.Tokens, constArrayType.ElementType, target, index),
|
||||
NubSliceType sliceType => new SliceIndexAccessNode(expression.Tokens, sliceType.ElementType, target, index),
|
||||
_ => throw new TypeCheckerException(Diagnostic.Error($"Cannot use array indexer on type {target.Type}").At(expression).Build())
|
||||
_ => throw new CompileException(Diagnostic.Error($"Cannot use array indexer on type {target.Type}").At(expression).Build())
|
||||
};
|
||||
}
|
||||
|
||||
@@ -550,7 +551,7 @@ public sealed class TypeChecker
|
||||
|
||||
if (elementType == null)
|
||||
{
|
||||
throw new TypeCheckerException(Diagnostic
|
||||
throw new CompileException(Diagnostic
|
||||
.Error("Unable to infer type of array initializer")
|
||||
.At(expression)
|
||||
.WithHelp("Provide a type for a variable assignment")
|
||||
@@ -563,7 +564,7 @@ public sealed class TypeChecker
|
||||
var value = CheckExpression(valueExpression, elementType);
|
||||
if (value.Type != elementType)
|
||||
{
|
||||
throw new TypeCheckerException(Diagnostic
|
||||
throw new CompileException(Diagnostic
|
||||
.Error("Value in array initializer is not the same as the array type")
|
||||
.At(valueExpression)
|
||||
.Build());
|
||||
@@ -613,7 +614,7 @@ public sealed class TypeChecker
|
||||
var left = CheckExpression(expression.Left);
|
||||
if (left.Type is not NubIntType and not NubFloatType and not NubBoolType)
|
||||
{
|
||||
throw new TypeCheckerException(Diagnostic
|
||||
throw new CompileException(Diagnostic
|
||||
.Error("Equal and not equal operators must must be used with int, float or bool types")
|
||||
.At(expression.Left)
|
||||
.Build());
|
||||
@@ -622,7 +623,7 @@ public sealed class TypeChecker
|
||||
var right = CheckExpression(expression.Right, left.Type);
|
||||
if (right.Type != left.Type)
|
||||
{
|
||||
throw new TypeCheckerException(Diagnostic
|
||||
throw new CompileException(Diagnostic
|
||||
.Error($"Expected type {left.Type} from left side of binary expression, but got {right.Type}")
|
||||
.At(expression.Right)
|
||||
.Build());
|
||||
@@ -638,7 +639,7 @@ public sealed class TypeChecker
|
||||
var left = CheckExpression(expression.Left);
|
||||
if (left.Type is not NubIntType and not NubFloatType)
|
||||
{
|
||||
throw new TypeCheckerException(Diagnostic
|
||||
throw new CompileException(Diagnostic
|
||||
.Error("Greater than and less than operators must must be used with int or float types")
|
||||
.At(expression.Left)
|
||||
.Build());
|
||||
@@ -647,7 +648,7 @@ public sealed class TypeChecker
|
||||
var right = CheckExpression(expression.Right, left.Type);
|
||||
if (right.Type != left.Type)
|
||||
{
|
||||
throw new TypeCheckerException(Diagnostic
|
||||
throw new CompileException(Diagnostic
|
||||
.Error($"Expected type {left.Type} from left side of binary expression, but got {right.Type}")
|
||||
.At(expression.Right)
|
||||
.Build());
|
||||
@@ -661,7 +662,7 @@ public sealed class TypeChecker
|
||||
var left = CheckExpression(expression.Left);
|
||||
if (left.Type is not NubBoolType)
|
||||
{
|
||||
throw new TypeCheckerException(Diagnostic
|
||||
throw new CompileException(Diagnostic
|
||||
.Error("Logical and/or must must be used with bool types")
|
||||
.At(expression.Left)
|
||||
.Build());
|
||||
@@ -670,7 +671,7 @@ public sealed class TypeChecker
|
||||
var right = CheckExpression(expression.Right, left.Type);
|
||||
if (right.Type != left.Type)
|
||||
{
|
||||
throw new TypeCheckerException(Diagnostic
|
||||
throw new CompileException(Diagnostic
|
||||
.Error($"Expected type {left.Type} from left side of binary expression, but got {right.Type}")
|
||||
.At(expression.Right)
|
||||
.Build());
|
||||
@@ -683,7 +684,7 @@ public sealed class TypeChecker
|
||||
var left = CheckExpression(expression.Left, expectedType);
|
||||
if (left.Type is not NubIntType and not NubFloatType)
|
||||
{
|
||||
throw new TypeCheckerException(Diagnostic
|
||||
throw new CompileException(Diagnostic
|
||||
.Error("The plus operator must only be used with int and float types")
|
||||
.At(expression.Left)
|
||||
.Build());
|
||||
@@ -692,7 +693,7 @@ public sealed class TypeChecker
|
||||
var right = CheckExpression(expression.Right, left.Type);
|
||||
if (right.Type != left.Type)
|
||||
{
|
||||
throw new TypeCheckerException(Diagnostic
|
||||
throw new CompileException(Diagnostic
|
||||
.Error($"Expected type {left.Type} from left side of binary expression, but got {right.Type}")
|
||||
.At(expression.Right)
|
||||
.Build());
|
||||
@@ -708,7 +709,7 @@ public sealed class TypeChecker
|
||||
var left = CheckExpression(expression.Left, expectedType);
|
||||
if (left.Type is not NubIntType and not NubFloatType)
|
||||
{
|
||||
throw new TypeCheckerException(Diagnostic
|
||||
throw new CompileException(Diagnostic
|
||||
.Error("Math operators must be used with int or float types")
|
||||
.At(expression.Left)
|
||||
.Build());
|
||||
@@ -717,7 +718,7 @@ public sealed class TypeChecker
|
||||
var right = CheckExpression(expression.Right, left.Type);
|
||||
if (right.Type != left.Type)
|
||||
{
|
||||
throw new TypeCheckerException(Diagnostic
|
||||
throw new CompileException(Diagnostic
|
||||
.Error($"Expected type {left.Type} from left side of binary expression, but got {right.Type}")
|
||||
.At(expression.Right)
|
||||
.Build());
|
||||
@@ -734,7 +735,7 @@ public sealed class TypeChecker
|
||||
var left = CheckExpression(expression.Left, expectedType);
|
||||
if (left.Type is not NubIntType)
|
||||
{
|
||||
throw new TypeCheckerException(Diagnostic
|
||||
throw new CompileException(Diagnostic
|
||||
.Error("Bitwise operators must be used with int types")
|
||||
.At(expression.Left)
|
||||
.Build());
|
||||
@@ -743,7 +744,7 @@ public sealed class TypeChecker
|
||||
var right = CheckExpression(expression.Right, left.Type);
|
||||
if (right.Type != left.Type)
|
||||
{
|
||||
throw new TypeCheckerException(Diagnostic
|
||||
throw new CompileException(Diagnostic
|
||||
.Error($"Expected type {left.Type} from left side of binary expression, but got {right.Type}")
|
||||
.At(expression.Right)
|
||||
.Build());
|
||||
@@ -767,7 +768,7 @@ public sealed class TypeChecker
|
||||
var operand = CheckExpression(expression.Operand, expectedType);
|
||||
if (operand.Type is not NubIntType { Signed: true } and not NubFloatType)
|
||||
{
|
||||
throw new TypeCheckerException(Diagnostic
|
||||
throw new CompileException(Diagnostic
|
||||
.Error("Negation operator must be used with signed integer or float types")
|
||||
.At(expression)
|
||||
.Build());
|
||||
@@ -780,7 +781,7 @@ public sealed class TypeChecker
|
||||
var operand = CheckExpression(expression.Operand, expectedType);
|
||||
if (operand.Type is not NubBoolType)
|
||||
{
|
||||
throw new TypeCheckerException(Diagnostic
|
||||
throw new CompileException(Diagnostic
|
||||
.Error("Invert operator must be used with booleans")
|
||||
.At(expression)
|
||||
.Build());
|
||||
@@ -803,7 +804,7 @@ public sealed class TypeChecker
|
||||
{
|
||||
NubPointerType pointerType => new DereferenceNode(expression.Tokens, pointerType.BaseType, target),
|
||||
NubRefType refType => new RefDereferenceNode(expression.Tokens, refType.BaseType, target),
|
||||
_ => throw new TypeCheckerException(Diagnostic.Error($"Cannot dereference non-pointer type {target.Type}").At(expression).Build())
|
||||
_ => throw new CompileException(Diagnostic.Error($"Cannot dereference non-pointer type {target.Type}").At(expression).Build())
|
||||
};
|
||||
}
|
||||
|
||||
@@ -812,12 +813,12 @@ public sealed class TypeChecker
|
||||
var accessor = CheckExpression(expression.Expression);
|
||||
if (accessor.Type is not NubFuncType funcType)
|
||||
{
|
||||
throw new TypeCheckerException(Diagnostic.Error($"Cannot call non-function type {accessor.Type}").At(expression.Expression).Build());
|
||||
throw new CompileException(Diagnostic.Error($"Cannot call non-function type {accessor.Type}").At(expression.Expression).Build());
|
||||
}
|
||||
|
||||
if (expression.Parameters.Count != funcType.Parameters.Count)
|
||||
{
|
||||
throw new TypeCheckerException(Diagnostic
|
||||
throw new CompileException(Diagnostic
|
||||
.Error($"Function {funcType} expects {funcType.Parameters.Count} parameters but got {expression.Parameters.Count}")
|
||||
.At(expression.Parameters.LastOrDefault(expression))
|
||||
.Build());
|
||||
@@ -832,7 +833,7 @@ public sealed class TypeChecker
|
||||
var parameterExpression = CheckExpression(parameter, expectedParameterType);
|
||||
if (parameterExpression.Type != expectedParameterType)
|
||||
{
|
||||
throw new TypeCheckerException(Diagnostic
|
||||
throw new CompileException(Diagnostic
|
||||
.Error($"Parameter {i + 1} does not match the type {expectedParameterType} for function {funcType}")
|
||||
.At(parameter)
|
||||
.Build());
|
||||
@@ -858,8 +859,8 @@ public sealed class TypeChecker
|
||||
var function = module.Functions(true).FirstOrDefault(x => x.NameToken.Value == expression.NameToken.Value);
|
||||
if (function != null)
|
||||
{
|
||||
var parameters = function.Prototype.Parameters.Select(x => ResolveType(x.Type)).ToList();
|
||||
var type = new NubFuncType(parameters, ResolveType(function.Prototype.ReturnType));
|
||||
var parameters = function.Prototype.Parameters.Select(x => _typeResolver.ResolveType(x.Type, Scope.Module.Value)).ToList();
|
||||
var type = new NubFuncType(parameters, _typeResolver.ResolveType(function.Prototype.ReturnType, Scope.Module.Value));
|
||||
return new FuncIdentifierNode(expression.Tokens, type, Scope.Module, expression.NameToken, function.Prototype.ExternSymbolToken);
|
||||
}
|
||||
|
||||
@@ -869,7 +870,7 @@ public sealed class TypeChecker
|
||||
return new EnumReferenceIntermediateNode(expression.Tokens, Scope.Module, expression.NameToken);
|
||||
}
|
||||
|
||||
throw new TypeCheckerException(Diagnostic
|
||||
throw new CompileException(Diagnostic
|
||||
.Error($"There is no identifier named {expression.NameToken.Value}")
|
||||
.At(expression)
|
||||
.Build());
|
||||
@@ -880,7 +881,7 @@ public sealed class TypeChecker
|
||||
var module = GetImportedModule(expression.ModuleToken.Value);
|
||||
if (module == null)
|
||||
{
|
||||
throw new TypeCheckerException(Diagnostic
|
||||
throw new CompileException(Diagnostic
|
||||
.Error($"Module {expression.ModuleToken.Value} not found")
|
||||
.WithHelp($"import \"{expression.ModuleToken.Value}\"")
|
||||
.At(expression.ModuleToken)
|
||||
@@ -892,8 +893,8 @@ public sealed class TypeChecker
|
||||
{
|
||||
using (BeginRootScope(expression.ModuleToken))
|
||||
{
|
||||
var parameters = function.Prototype.Parameters.Select(x => ResolveType(x.Type)).ToList();
|
||||
var type = new NubFuncType(parameters, ResolveType(function.Prototype.ReturnType));
|
||||
var parameters = function.Prototype.Parameters.Select(x => _typeResolver.ResolveType(x.Type, Scope.Module.Value)).ToList();
|
||||
var type = new NubFuncType(parameters, _typeResolver.ResolveType(function.Prototype.ReturnType, Scope.Module.Value));
|
||||
return new FuncIdentifierNode(expression.Tokens, type, expression.ModuleToken, expression.NameToken, function.Prototype.ExternSymbolToken);
|
||||
}
|
||||
}
|
||||
@@ -904,7 +905,7 @@ public sealed class TypeChecker
|
||||
return new EnumReferenceIntermediateNode(expression.Tokens, expression.ModuleToken, expression.NameToken);
|
||||
}
|
||||
|
||||
throw new TypeCheckerException(Diagnostic
|
||||
throw new CompileException(Diagnostic
|
||||
.Error($"Module {expression.ModuleToken.Value} does not export a member named {expression.NameToken.Value}")
|
||||
.At(expression)
|
||||
.Build());
|
||||
@@ -982,16 +983,16 @@ public sealed class TypeChecker
|
||||
var field = enumDef.Fields.FirstOrDefault(x => x.NameToken.Value == expression.MemberToken.Value);
|
||||
if (field == null)
|
||||
{
|
||||
throw new TypeCheckerException(Diagnostic
|
||||
throw new CompileException(Diagnostic
|
||||
.Error($"Enum {Scope.Module.Value}::{enumReferenceIntermediate.NameToken.Value} does not have a field named {expression.MemberToken.Value}")
|
||||
.At(enumDef)
|
||||
.Build());
|
||||
}
|
||||
|
||||
var enumType = enumDef.Type != null ? ResolveType(enumDef.Type) : new NubIntType(false, 64);
|
||||
var enumType = enumDef.Type != null ? _typeResolver.ResolveType(enumDef.Type, Scope.Module.Value) : new NubIntType(false, 64);
|
||||
if (enumType is not NubIntType enumIntType)
|
||||
{
|
||||
throw new TypeCheckerException(Diagnostic.Error("Enum type must be an int type").At(enumDef.Type).Build());
|
||||
throw new CompileException(Diagnostic.Error("Enum type must be an int type").At(enumDef.Type).Build());
|
||||
}
|
||||
|
||||
if (enumIntType.Signed)
|
||||
@@ -1027,7 +1028,7 @@ public sealed class TypeChecker
|
||||
var field = structType.Fields.FirstOrDefault(x => x.Name == expression.MemberToken.Value);
|
||||
if (field == null)
|
||||
{
|
||||
throw new TypeCheckerException(Diagnostic
|
||||
throw new CompileException(Diagnostic
|
||||
.Error($"Struct {target.Type} does not have a field with the name {expression.MemberToken.Value}")
|
||||
.At(expression)
|
||||
.Build());
|
||||
@@ -1037,7 +1038,7 @@ public sealed class TypeChecker
|
||||
}
|
||||
default:
|
||||
{
|
||||
throw new TypeCheckerException(Diagnostic
|
||||
throw new CompileException(Diagnostic
|
||||
.Error($"Cannot access struct member {expression.MemberToken.Value} on type {target.Type}")
|
||||
.At(expression)
|
||||
.Build());
|
||||
@@ -1095,7 +1096,7 @@ public sealed class TypeChecker
|
||||
|
||||
if (expression.StructType != null)
|
||||
{
|
||||
var checkedType = ResolveType(expression.StructType);
|
||||
var checkedType = _typeResolver.ResolveType(expression.StructType, Scope.Module.Value);
|
||||
if (checkedType is not NubStructType checkedStructType)
|
||||
{
|
||||
throw new UnreachableException("Parser fucked up");
|
||||
@@ -1115,7 +1116,7 @@ public sealed class TypeChecker
|
||||
|
||||
if (structType == null)
|
||||
{
|
||||
throw new TypeCheckerException(Diagnostic
|
||||
throw new CompileException(Diagnostic
|
||||
.Error("Cannot get implicit type of struct")
|
||||
.WithHelp("Specify struct type with struct {type_name} syntax")
|
||||
.At(expression)
|
||||
@@ -1174,7 +1175,7 @@ public sealed class TypeChecker
|
||||
{
|
||||
statements.Add(CheckStatement(statement));
|
||||
}
|
||||
catch (TypeCheckerException e)
|
||||
catch (CompileException e)
|
||||
{
|
||||
Diagnostics.Add(e.Diagnostic);
|
||||
}
|
||||
@@ -1202,85 +1203,6 @@ public sealed class TypeChecker
|
||||
_ => throw new ArgumentOutOfRangeException(nameof(statement))
|
||||
};
|
||||
}
|
||||
|
||||
private NubType ResolveType(TypeSyntax type)
|
||||
{
|
||||
return type switch
|
||||
{
|
||||
ArrayTypeSyntax arr => new NubArrayType(ResolveType(arr.BaseType)),
|
||||
BoolTypeSyntax => new NubBoolType(),
|
||||
IntTypeSyntax i => new NubIntType(i.Signed, i.Width),
|
||||
FloatTypeSyntax f => new NubFloatType(f.Width),
|
||||
FuncTypeSyntax func => new NubFuncType(func.Parameters.Select(ResolveType).ToList(), ResolveType(func.ReturnType)),
|
||||
SliceTypeSyntax slice => new NubSliceType(ResolveType(slice.BaseType)),
|
||||
ConstArrayTypeSyntax arr => new NubConstArrayType(ResolveType(arr.BaseType), arr.Size),
|
||||
PointerTypeSyntax ptr => new NubPointerType(ResolveType(ptr.BaseType)),
|
||||
RefTypeSyntax r => new NubRefType(ResolveType(r.BaseType)),
|
||||
StringTypeSyntax => new NubStringType(),
|
||||
CustomTypeSyntax c => ResolveCustomType(c),
|
||||
VoidTypeSyntax => new NubVoidType(),
|
||||
_ => throw new NotSupportedException($"Unknown type syntax: {type}")
|
||||
};
|
||||
}
|
||||
|
||||
private NubType ResolveCustomType(CustomTypeSyntax customType)
|
||||
{
|
||||
var module = GetImportedModule(customType.ModuleToken?.Value ?? Scope.Module.Value);
|
||||
if (module == null)
|
||||
{
|
||||
throw new TypeCheckerException(Diagnostic
|
||||
.Error($"Module {customType.ModuleToken?.Value ?? Scope.Module.Value} not found")
|
||||
.WithHelp($"import \"{customType.ModuleToken?.Value ?? Scope.Module.Value}\"")
|
||||
.At(customType)
|
||||
.Build());
|
||||
}
|
||||
|
||||
var enumDef = module.Enums(IsCurrentModule(customType.ModuleToken)).FirstOrDefault(x => x.NameToken.Value == customType.NameToken.Value);
|
||||
if (enumDef != null)
|
||||
{
|
||||
return enumDef.Type != null ? ResolveType(enumDef.Type) : new NubIntType(false, 64);
|
||||
}
|
||||
|
||||
var structDef = module.Structs(IsCurrentModule(customType.ModuleToken)).FirstOrDefault(x => x.NameToken.Value == customType.NameToken.Value);
|
||||
if (structDef != null)
|
||||
{
|
||||
var key = (customType.ModuleToken?.Value ?? Scope.Module.Value, customType.NameToken.Value);
|
||||
|
||||
if (_typeCache.TryGetValue(key, out var cachedType))
|
||||
{
|
||||
return cachedType;
|
||||
}
|
||||
|
||||
if (!_resolvingTypes.Add(key))
|
||||
{
|
||||
var placeholder = new NubStructType(customType.ModuleToken?.Value ?? Scope.Module.Value, customType.NameToken.Value, []);
|
||||
_typeCache[key] = placeholder;
|
||||
return placeholder;
|
||||
}
|
||||
|
||||
try
|
||||
{
|
||||
var result = new NubStructType(customType.ModuleToken?.Value ?? Scope.Module.Value, structDef.NameToken.Value, []);
|
||||
_typeCache[key] = result;
|
||||
|
||||
var fields = structDef.Fields
|
||||
.Select(x => new NubStructFieldType(x.NameToken.Value, ResolveType(x.Type), x.Value != null))
|
||||
.ToList();
|
||||
|
||||
result.Fields.AddRange(fields);
|
||||
return result;
|
||||
}
|
||||
finally
|
||||
{
|
||||
_resolvingTypes.Remove(key);
|
||||
}
|
||||
}
|
||||
|
||||
throw new TypeCheckerException(Diagnostic
|
||||
.Error($"Type {customType.NameToken.Value} not found in module {customType.ModuleToken?.Value ?? Scope.Module.Value}")
|
||||
.At(customType)
|
||||
.Build());
|
||||
}
|
||||
}
|
||||
|
||||
public record Variable(IdentifierToken Name, NubType Type);
|
||||
@@ -1321,14 +1243,4 @@ public class Scope(IdentifierToken module, Scope? parent = null)
|
||||
{
|
||||
return new Scope(Module, this);
|
||||
}
|
||||
}
|
||||
|
||||
public class TypeCheckerException : Exception
|
||||
{
|
||||
public Diagnostic Diagnostic { get; }
|
||||
|
||||
public TypeCheckerException(Diagnostic diagnostic) : base(diagnostic.Message)
|
||||
{
|
||||
Diagnostic = diagnostic;
|
||||
}
|
||||
}
|
||||
97
compiler/NubLang/Ast/TypeResolver.cs
Normal file
97
compiler/NubLang/Ast/TypeResolver.cs
Normal file
@@ -0,0 +1,97 @@
|
||||
using NubLang.Diagnostics;
|
||||
using NubLang.Syntax;
|
||||
|
||||
namespace NubLang.Ast;
|
||||
|
||||
public class TypeResolver
|
||||
{
|
||||
private readonly Dictionary<string, Module> _modules;
|
||||
private readonly Dictionary<(string Module, string Name), NubType> _typeCache = new();
|
||||
private readonly HashSet<(string Module, string Name)> _resolvingTypes = [];
|
||||
|
||||
public TypeResolver(Dictionary<string, Module> modules)
|
||||
{
|
||||
_modules = modules;
|
||||
}
|
||||
|
||||
public NubType ResolveType(TypeSyntax type, string currentModule)
|
||||
{
|
||||
return type switch
|
||||
{
|
||||
ArrayTypeSyntax arr => new NubArrayType(ResolveType(arr.BaseType, currentModule)),
|
||||
BoolTypeSyntax => new NubBoolType(),
|
||||
IntTypeSyntax i => new NubIntType(i.Signed, i.Width),
|
||||
FloatTypeSyntax f => new NubFloatType(f.Width),
|
||||
FuncTypeSyntax func => new NubFuncType(func.Parameters.Select(x => ResolveType(x, currentModule)).ToList(), ResolveType(func.ReturnType, currentModule)),
|
||||
SliceTypeSyntax slice => new NubSliceType(ResolveType(slice.BaseType, currentModule)),
|
||||
ConstArrayTypeSyntax arr => new NubConstArrayType(ResolveType(arr.BaseType, currentModule), arr.Size),
|
||||
PointerTypeSyntax ptr => new NubPointerType(ResolveType(ptr.BaseType, currentModule)),
|
||||
RefTypeSyntax r => new NubRefType(ResolveType(r.BaseType, currentModule)),
|
||||
StringTypeSyntax => new NubStringType(),
|
||||
CustomTypeSyntax c => ResolveCustomType(c, currentModule),
|
||||
VoidTypeSyntax => new NubVoidType(),
|
||||
_ => throw new NotSupportedException($"Unknown type syntax: {type}")
|
||||
};
|
||||
}
|
||||
|
||||
private NubType ResolveCustomType(CustomTypeSyntax customType, string currentModule)
|
||||
{
|
||||
var module = _modules[customType.ModuleToken?.Value ?? currentModule];
|
||||
|
||||
var enumDef = module.Enums(true).FirstOrDefault(x => x.NameToken.Value == customType.NameToken.Value);
|
||||
if (enumDef != null)
|
||||
{
|
||||
return enumDef.Type != null ? ResolveType(enumDef.Type, currentModule) : new NubIntType(false, 64);
|
||||
}
|
||||
|
||||
var structDef = module.Structs(true).FirstOrDefault(x => x.NameToken.Value == customType.NameToken.Value);
|
||||
if (structDef != null)
|
||||
{
|
||||
var key = (customType.ModuleToken?.Value ?? currentModule, customType.NameToken.Value);
|
||||
|
||||
if (_typeCache.TryGetValue(key, out var cachedType))
|
||||
{
|
||||
return cachedType;
|
||||
}
|
||||
|
||||
if (!_resolvingTypes.Add(key))
|
||||
{
|
||||
var placeholder = new NubStructType(customType.ModuleToken?.Value ?? currentModule, customType.NameToken.Value, []);
|
||||
_typeCache[key] = placeholder;
|
||||
return placeholder;
|
||||
}
|
||||
|
||||
try
|
||||
{
|
||||
var result = new NubStructType(customType.ModuleToken?.Value ?? currentModule, structDef.NameToken.Value, []);
|
||||
_typeCache[key] = result;
|
||||
|
||||
var fields = structDef.Fields
|
||||
.Select(x => new NubStructFieldType(x.NameToken.Value, ResolveType(x.Type, currentModule), x.Value != null))
|
||||
.ToList();
|
||||
|
||||
result.Fields.AddRange(fields);
|
||||
return result;
|
||||
}
|
||||
finally
|
||||
{
|
||||
_resolvingTypes.Remove(key);
|
||||
}
|
||||
}
|
||||
|
||||
throw new TypeResolverException(Diagnostic
|
||||
.Error($"Type {customType.NameToken.Value} not found in module {customType.ModuleToken?.Value ?? currentModule}")
|
||||
.At(customType)
|
||||
.Build());
|
||||
}
|
||||
}
|
||||
|
||||
public class TypeResolverException : Exception
|
||||
{
|
||||
public Diagnostic Diagnostic { get; }
|
||||
|
||||
public TypeResolverException(Diagnostic diagnostic) : base(diagnostic.Message)
|
||||
{
|
||||
Diagnostic = diagnostic;
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user