This commit is contained in:
nub31
2025-09-11 21:22:30 +02:00
parent 5fecfeba43
commit fd27d2709d
46 changed files with 5339 additions and 0 deletions

View File

@@ -0,0 +1,171 @@
using NubLang.Parsing.Syntax;
using NubLang.TypeChecking.Node;
namespace NubLang.TypeChecking;
public class Module
{
public static IReadOnlyList<Module> CollectFromSyntaxTrees(IReadOnlyList<SyntaxTree> syntaxTrees)
{
var modules = new Dictionary<string, Module>();
foreach (var syntaxTree in syntaxTrees)
{
var name = syntaxTree.Metadata.ModuleName;
if (name == null)
{
continue;
}
if (!modules.TryGetValue(name, out var module))
{
module = new Module(name, syntaxTree.Metadata.Imports);
modules[name] = module;
}
foreach (var definition in syntaxTree.Definitions)
{
module.AddDefinition(definition);
}
}
return modules.Values.ToList();
}
private readonly List<DefinitionSyntax> _definitions = [];
public Module(string name, IReadOnlyList<string> imports)
{
Name = name;
Imports = imports;
}
public string Name { get; }
public IReadOnlyList<string> Imports { get; }
public IReadOnlyList<DefinitionSyntax> Definitions => _definitions;
private void AddDefinition(DefinitionSyntax syntax)
{
_definitions.Add(syntax);
}
}
public class TypedModule
{
public TypedModule(string name, IReadOnlyList<DefinitionNode> definitions)
{
Name = name;
Definitions = definitions;
}
public string Name { get; }
public IReadOnlyList<DefinitionNode> Definitions { get; }
}
public class ModuleSignature
{
public static IReadOnlyDictionary<string, ModuleSignature> CollectFromSyntaxTrees(IReadOnlyList<SyntaxTree> syntaxTrees)
{
var modules = new Dictionary<string, ModuleSignature>();
foreach (var syntaxTree in syntaxTrees)
{
var moduleName = syntaxTree.Metadata.ModuleName;
if (moduleName == null)
{
continue;
}
if (!modules.TryGetValue(moduleName, out var module))
{
module = new ModuleSignature();
modules[moduleName] = module;
}
foreach (var def in syntaxTree.Definitions)
{
if (!def.Exported) continue;
switch (def)
{
case FuncSyntax funcDef:
{
var parameters = funcDef.Signature.Parameters.Select(p => TypeResolver.ResolveType(p.Type, modules)).ToList();
var returnType = TypeResolver.ResolveType(funcDef.Signature.ReturnType, modules);
var type = new FuncTypeNode(parameters, returnType, funcDef.ExternSymbol);
module._functions.Add(funcDef.Name, type);
break;
}
case InterfaceSyntax interfaceDef:
{
var functions = new List<InterfaceTypeFunc>();
for (var i = 0; i < interfaceDef.Functions.Count; i++)
{
var function = interfaceDef.Functions[i];
var parameters = function.Signature.Parameters.Select(p => TypeResolver.ResolveType(p.Type, modules)).ToList();
var returnType = TypeResolver.ResolveType(function.Signature.ReturnType, modules);
functions.Add(new InterfaceTypeFunc(function.Name, new FuncTypeNode(parameters, returnType), i));
}
var type = new InterfaceTypeNode(moduleName, interfaceDef.Name, functions);
module._interfaces.Add(type);
break;
}
case StructSyntax structDef:
{
var fields = new List<StructTypeField>();
foreach (var field in structDef.Fields)
{
fields.Add(new StructTypeField(field.Name, TypeResolver.ResolveType(field.Type, modules), field.Index, field.Value.HasValue));
}
var functions = new List<StructTypeFunc>();
foreach (var function in structDef.Functions)
{
var parameters = function.Signature.Parameters.Select(p => TypeResolver.ResolveType(p.Type, modules)).ToList();
var returnType = TypeResolver.ResolveType(function.Signature.ReturnType, modules);
functions.Add(new StructTypeFunc(function.Name, new FuncTypeNode(parameters, returnType)));
}
var interfaceImplementations = new List<InterfaceTypeNode>();
foreach (var interfaceImplementation in structDef.InterfaceImplementations)
{
if (interfaceImplementation is not CustomTypeSyntax customType)
{
throw new Exception("Interface implementation is not a custom type");
}
var resolvedType = TypeResolver.ResolveCustomType(customType.Module, customType.Name, modules);
if (resolvedType is not InterfaceTypeNode interfaceType)
{
throw new Exception("Interface implementation is not a interface");
}
interfaceImplementations.Add(interfaceType);
}
var type = new StructTypeNode(moduleName, structDef.Name, fields, functions, interfaceImplementations);
module._structs.Add(type);
break;
}
default:
{
throw new ArgumentOutOfRangeException(nameof(def));
}
}
}
}
return modules;
}
private readonly List<StructTypeNode> _structs = [];
private readonly List<InterfaceTypeNode> _interfaces = [];
private readonly Dictionary<string, FuncTypeNode> _functions = [];
public IReadOnlyList<StructTypeNode> StructTypes => _structs;
public IReadOnlyList<InterfaceTypeNode> InterfaceTypes => _interfaces;
public IReadOnlyDictionary<string, FuncTypeNode> Functions => _functions;
}

View File

@@ -0,0 +1,19 @@
namespace NubLang.TypeChecking.Node;
public abstract record DefinitionNode : Node;
public record FuncParameterNode(string Name, TypeNode Type) : Node;
public record FuncSignatureNode(IReadOnlyList<FuncParameterNode> Parameters, TypeNode ReturnType) : Node;
public record FuncNode(string Name, FuncSignatureNode Signature, BlockNode? Body) : DefinitionNode;
public record StructFieldNode(int Index, string Name, TypeNode Type, Optional<ExpressionNode> Value) : Node;
public record StructFuncNode(string Name, FuncSignatureNode Signature, BlockNode Body) : Node;
public record StructNode(string Name, IReadOnlyList<StructFieldNode> Fields, IReadOnlyList<StructFuncNode> Functions, IReadOnlyList<InterfaceTypeNode> InterfaceImplementations) : DefinitionNode;
public record InterfaceFuncNode(string Name, FuncSignatureNode Signature) : Node;
public record InterfaceNode(string Name, IReadOnlyList<InterfaceFuncNode> Functions) : DefinitionNode;

View File

@@ -0,0 +1,72 @@
using NubLang.Tokenization;
namespace NubLang.TypeChecking.Node;
public enum UnaryOperator
{
Negate,
Invert
}
public enum BinaryOperator
{
Equal,
NotEqual,
GreaterThan,
GreaterThanOrEqual,
LessThan,
LessThanOrEqual,
LogicalAnd,
LogicalOr,
Plus,
Minus,
Multiply,
Divide,
Modulo,
LeftShift,
RightShift,
BitwiseAnd,
BitwiseXor,
BitwiseOr
}
public abstract record ExpressionNode(TypeNode Type) : Node;
public abstract record LValueExpressionNode(TypeNode Type) : RValueExpressionNode(Type);
public abstract record RValueExpressionNode(TypeNode Type) : ExpressionNode(Type);
public record BinaryExpressionNode(TypeNode Type, ExpressionNode Left, BinaryOperator Operator, ExpressionNode Right) : RValueExpressionNode(Type);
public record UnaryExpressionNode(TypeNode Type, UnaryOperator Operator, ExpressionNode Operand) : RValueExpressionNode(Type);
public record FuncCallNode(TypeNode Type, ExpressionNode Expression, IReadOnlyList<ExpressionNode> Parameters) : RValueExpressionNode(Type);
public record StructFuncCallNode(TypeNode Type, string Name, StructTypeNode StructType, ExpressionNode StructExpression, IReadOnlyList<ExpressionNode> Parameters) : RValueExpressionNode(Type);
public record InterfaceFuncCallNode(TypeNode Type, string Name, InterfaceTypeNode InterfaceType, ExpressionNode InterfaceExpression, IReadOnlyList<ExpressionNode> Parameters) : RValueExpressionNode(Type);
public record VariableIdentifierNode(TypeNode Type, string Name) : LValueExpressionNode(Type);
public record FuncParameterIdentifierNode(TypeNode Type, string Name) : RValueExpressionNode(Type);
public record FuncIdentifierNode(TypeNode Type, string Module, string Name) : RValueExpressionNode(Type);
public record ArrayInitializerNode(TypeNode Type, ExpressionNode Capacity, TypeNode ElementType) : RValueExpressionNode(Type);
public record ArrayIndexAccessNode(TypeNode Type, ExpressionNode Target, ExpressionNode Index) : LValueExpressionNode(Type);
public record AddressOfNode(TypeNode Type, LValueExpressionNode LValue) : RValueExpressionNode(Type);
public record LiteralNode(TypeNode Type, string Value, LiteralKind Kind) : RValueExpressionNode(Type);
public record StructFieldAccessNode(TypeNode Type, StructTypeNode StructType, ExpressionNode Target, string Field) : LValueExpressionNode(Type);
public record StructInitializerNode(StructTypeNode StructType, Dictionary<string, ExpressionNode> Initializers) : RValueExpressionNode(StructType);
public record DereferenceNode(TypeNode Type, ExpressionNode Expression) : RValueExpressionNode(Type);
public record ConvertToInterfaceNode(TypeNode Type, InterfaceTypeNode InterfaceType, StructTypeNode StructType, ExpressionNode Implementation) : RValueExpressionNode(Type);
public record ConvertIntNode(TypeNode Type, ExpressionNode Value, IntTypeNode ValueType, IntTypeNode TargetType) : RValueExpressionNode(Type);
public record ConvertFloatNode(TypeNode Type, ExpressionNode Value, FloatTypeNode ValueType, FloatTypeNode TargetType) : RValueExpressionNode(Type);

View File

@@ -0,0 +1,5 @@
namespace NubLang.TypeChecking.Node;
public abstract record Node;
public record BlockNode(IReadOnlyList<StatementNode> Statements) : Node;

View File

@@ -0,0 +1,19 @@
namespace NubLang.TypeChecking.Node;
public record StatementNode : Node;
public record StatementExpressionNode(ExpressionNode Expression) : StatementNode;
public record ReturnNode(Optional<ExpressionNode> Value) : StatementNode;
public record AssignmentNode(LValueExpressionNode Target, ExpressionNode Value) : StatementNode;
public record IfNode(ExpressionNode Condition, BlockNode Body, Optional<Variant<IfNode, BlockNode>> Else) : StatementNode;
public record VariableDeclarationNode(string Name, Optional<ExpressionNode> Assignment, TypeNode Type) : StatementNode;
public record ContinueNode : StatementNode;
public record BreakNode : StatementNode;
public record WhileNode(ExpressionNode Condition, BlockNode Body) : StatementNode;

View File

@@ -0,0 +1,233 @@
using System.Diagnostics.CodeAnalysis;
namespace NubLang.TypeChecking.Node;
public abstract class TypeNode : IEquatable<TypeNode>
{
public bool IsSimpleType([NotNullWhen(true)] out SimpleTypeNode? simpleType, [NotNullWhen(false)] out ComplexTypeNode? complexType)
{
if (this is SimpleTypeNode st)
{
complexType = null;
simpleType = st;
return true;
}
if (this is ComplexTypeNode ct)
{
complexType = ct;
simpleType = null;
return false;
}
throw new ArgumentException($"Type {this} is not a simple type nor a complex type");
}
public override bool Equals(object? obj) => obj is TypeNode other && Equals(other);
public abstract bool Equals(TypeNode? other);
public abstract override int GetHashCode();
public abstract override string ToString();
public static bool operator ==(TypeNode? left, TypeNode? right) => Equals(left, right);
public static bool operator !=(TypeNode? left, TypeNode? right) => !Equals(left, right);
}
public enum StorageSize
{
Void,
I8,
I16,
I32,
I64,
U8,
U16,
U32,
U64,
F32,
F64
}
public abstract class SimpleTypeNode : TypeNode
{
public abstract StorageSize StorageSize { get; }
}
#region Simple types
public class IntTypeNode(bool signed, int width) : SimpleTypeNode
{
public bool Signed { get; } = signed;
public int Width { get; } = width;
public override StorageSize StorageSize => Signed switch
{
true => Width switch
{
8 => StorageSize.I8,
16 => StorageSize.I16,
32 => StorageSize.I32,
64 => StorageSize.I64,
_ => throw new ArgumentOutOfRangeException(nameof(Width))
},
false => Width switch
{
8 => StorageSize.U8,
16 => StorageSize.U16,
32 => StorageSize.U32,
64 => StorageSize.U64,
_ => throw new ArgumentOutOfRangeException(nameof(Width))
}
};
public override string ToString() => $"{(Signed ? "i" : "u")}{Width}";
public override bool Equals(TypeNode? other) => other is IntTypeNode @int && @int.Width == Width && @int.Signed == Signed;
public override int GetHashCode() => HashCode.Combine(typeof(IntTypeNode), Signed, Width);
}
public class FloatTypeNode(int width) : SimpleTypeNode
{
public int Width { get; } = width;
public override StorageSize StorageSize => Width switch
{
32 => StorageSize.F32,
64 => StorageSize.F64,
_ => throw new ArgumentOutOfRangeException(nameof(Width))
};
public override string ToString() => $"f{Width}";
public override bool Equals(TypeNode? other) => other is FloatTypeNode @int && @int.Width == Width;
public override int GetHashCode() => HashCode.Combine(typeof(FloatTypeNode), Width);
}
public class BoolTypeNode : SimpleTypeNode
{
public override StorageSize StorageSize => StorageSize.U8;
public override string ToString() => "bool";
public override bool Equals(TypeNode? other) => other is BoolTypeNode;
public override int GetHashCode() => HashCode.Combine(typeof(BoolTypeNode));
}
public class FuncTypeNode(IReadOnlyList<TypeNode> parameters, TypeNode returnType, string? externSymbol = null) : SimpleTypeNode
{
public string? ExternSymbol { get; } = externSymbol;
public IReadOnlyList<TypeNode> Parameters { get; } = parameters;
public TypeNode ReturnType { get; } = returnType;
public override StorageSize StorageSize => StorageSize.U64;
public override string ToString() => $"func({string.Join(", ", Parameters)}): {ReturnType}";
public override bool Equals(TypeNode? other) => other is FuncTypeNode func && ReturnType.Equals(func.ReturnType) && Parameters.SequenceEqual(func.Parameters);
public override int GetHashCode()
{
var hash = new HashCode();
hash.Add(typeof(FuncTypeNode));
hash.Add(ReturnType);
foreach (var param in Parameters)
{
hash.Add(param);
}
return hash.ToHashCode();
}
}
public class PointerTypeNode(TypeNode baseType) : SimpleTypeNode
{
public TypeNode BaseType { get; } = baseType;
public override StorageSize StorageSize => StorageSize.U64;
public override string ToString() => "^" + BaseType;
public override bool Equals(TypeNode? other) => other is PointerTypeNode pointer && BaseType.Equals(pointer.BaseType);
public override int GetHashCode() => HashCode.Combine(typeof(PointerTypeNode), BaseType);
}
public class VoidTypeNode : SimpleTypeNode
{
public override StorageSize StorageSize => StorageSize.Void;
public override string ToString() => "void";
public override bool Equals(TypeNode? other) => other is VoidTypeNode;
public override int GetHashCode() => HashCode.Combine(typeof(VoidTypeNode));
}
#endregion
public abstract class ComplexTypeNode : TypeNode;
#region Complex types
public class CStringTypeNode : ComplexTypeNode
{
public override string ToString() => "cstring";
public override bool Equals(TypeNode? other) => other is CStringTypeNode;
public override int GetHashCode() => HashCode.Combine(typeof(CStringTypeNode));
}
public class StringTypeNode : ComplexTypeNode
{
public override string ToString() => "string";
public override bool Equals(TypeNode? other) => other is StringTypeNode;
public override int GetHashCode() => HashCode.Combine(typeof(StringTypeNode));
}
public class StructTypeField(string name, TypeNode type, int index, bool hasDefaultValue)
{
public string Name { get; } = name;
public TypeNode Type { get; } = type;
public int Index { get; } = index;
public bool HasDefaultValue { get; } = hasDefaultValue;
}
public class StructTypeFunc(string name, FuncTypeNode type)
{
public string Name { get; } = name;
public FuncTypeNode Type { get; } = type;
}
public class StructTypeNode(string module, string name, IReadOnlyList<StructTypeField> fields, IReadOnlyList<StructTypeFunc> functions, IReadOnlyList<InterfaceTypeNode> interfaceImplementations) : ComplexTypeNode
{
public string Module { get; } = module;
public string Name { get; } = name;
public IReadOnlyList<StructTypeField> Fields { get; set; } = fields;
public IReadOnlyList<StructTypeFunc> Functions { get; set; } = functions;
public IReadOnlyList<InterfaceTypeNode> InterfaceImplementations { get; set; } = interfaceImplementations;
public override string ToString() => Name;
public override bool Equals(TypeNode? other) => other is StructTypeNode structType && Name == structType.Name && Module == structType.Module;
public override int GetHashCode() => HashCode.Combine(typeof(StructTypeNode), Name);
}
public class InterfaceTypeFunc(string name, FuncTypeNode type, int index)
{
public string Name { get; } = name;
public FuncTypeNode Type { get; } = type;
public int Index { get; } = index;
}
public class InterfaceTypeNode(string module, string name, IReadOnlyList<InterfaceTypeFunc> functions) : ComplexTypeNode
{
public string Module { get; } = module;
public string Name { get; } = name;
public IReadOnlyList<InterfaceTypeFunc> Functions { get; set; } = functions;
public override string ToString() => Name;
public override bool Equals(TypeNode? other) => other is InterfaceTypeNode interfaceType && Name == interfaceType.Name && Module == interfaceType.Module;
public override int GetHashCode() => HashCode.Combine(typeof(InterfaceTypeNode), Name);
}
public class ArrayTypeNode(TypeNode elementType) : ComplexTypeNode
{
public TypeNode ElementType { get; } = elementType;
public override string ToString() => "[]" + ElementType;
public override bool Equals(TypeNode? other) => other is ArrayTypeNode array && ElementType.Equals(array.ElementType);
public override int GetHashCode() => HashCode.Combine(typeof(ArrayTypeNode), ElementType);
}
#endregion

View File

@@ -0,0 +1,456 @@
using NubLang.Diagnostics;
using NubLang.Parsing.Syntax;
using NubLang.Tokenization;
using NubLang.TypeChecking.Node;
namespace NubLang.TypeChecking;
public sealed class TypeChecker
{
private readonly Module _currentModule;
private readonly IReadOnlyDictionary<string, ModuleSignature> _moduleSignatures;
private readonly Stack<Scope> _scopes = [];
private readonly Stack<TypeNode> _funcReturnTypes = [];
private readonly List<Diagnostic> _diagnostics = [];
private Scope Scope => _scopes.Peek();
public TypeChecker(Module currentModule, IReadOnlyDictionary<string, ModuleSignature> moduleSignatures)
{
_currentModule = currentModule;
_moduleSignatures = moduleSignatures.Where(x => currentModule.Imports.Contains(x.Key) || _currentModule.Name == x.Key).ToDictionary();
}
public IReadOnlyList<Diagnostic> GetDiagnostics() => _diagnostics;
public TypedModule CheckModule()
{
_diagnostics.Clear();
_scopes.Clear();
var definitions = new List<DefinitionNode>();
foreach (var definition in _currentModule.Definitions)
{
try
{
definitions.Add(CheckDefinition(definition));
}
catch (TypeCheckerException e)
{
_diagnostics.Add(e.Diagnostic);
}
}
return new TypedModule(_currentModule.Name, definitions);
}
private DefinitionNode CheckDefinition(DefinitionSyntax node)
{
return node switch
{
InterfaceSyntax definition => CheckInterfaceDefinition(definition),
FuncSyntax definition => CheckFuncDefinition(definition),
StructSyntax definition => CheckStructDefinition(definition),
_ => throw new ArgumentOutOfRangeException(nameof(node))
};
}
private InterfaceNode CheckInterfaceDefinition(InterfaceSyntax node)
{
throw new NotImplementedException();
}
private StructNode CheckStructDefinition(StructSyntax node)
{
var fields = new List<StructFieldNode>();
foreach (var field in node.Fields)
{
var value = Optional.Empty<ExpressionNode>();
if (field.Value.HasValue)
{
value = CheckExpression(field.Value.Value);
}
fields.Add(new StructFieldNode(field.Index, field.Name, ResolveType(field.Type), value));
}
var functions = new List<StructFuncNode>();
foreach (var function in node.Functions)
{
var scope = new Scope();
// todo(nub31): Add this parameter
foreach (var parameter in function.Signature.Parameters)
{
scope.Declare(new Identifier(parameter.Name, ResolveType(parameter.Type), IdentifierKind.FunctionParameter));
}
_funcReturnTypes.Push(ResolveType(function.Signature.ReturnType));
var body = CheckBlock(function.Body, scope);
_funcReturnTypes.Pop();
functions.Add(new StructFuncNode(function.Name, CheckFuncSignature(function.Signature), body));
}
var interfaceImplementations = new List<InterfaceTypeNode>();
foreach (var interfaceImplementation in node.InterfaceImplementations)
{
var type = ResolveType(interfaceImplementation);
if (type is not InterfaceTypeNode interfaceType)
{
_diagnostics.Add(Diagnostic.Error($"Struct {node.Name} cannot implement non-struct type {interfaceImplementation}").At(interfaceImplementation).Build());
continue;
}
interfaceImplementations.Add(interfaceType);
}
return new StructNode(node.Name, fields, functions, interfaceImplementations);
}
private FuncNode CheckFuncDefinition(FuncSyntax node)
{
var scope = new Scope();
foreach (var parameter in node.Signature.Parameters)
{
scope.Declare(new Identifier(parameter.Name, ResolveType(parameter.Type), IdentifierKind.FunctionParameter));
}
BlockNode? body = null;
if (node.Body != null)
{
_funcReturnTypes.Push(ResolveType(node.Signature.ReturnType));
body = CheckBlock(node.Body, scope);
_funcReturnTypes.Pop();
}
return new FuncNode(node.Name, CheckFuncSignature(node.Signature), body);
}
private StatementNode CheckStatement(StatementSyntax node)
{
return node switch
{
AssignmentSyntax statement => CheckAssignment(statement),
BreakSyntax => new BreakNode(),
ContinueSyntax => new ContinueNode(),
IfSyntax statement => CheckIf(statement),
ReturnSyntax statement => CheckReturn(statement),
StatementExpressionSyntax statement => CheckStatementExpression(statement),
VariableDeclarationSyntax statement => CheckVariableDeclaration(statement),
WhileSyntax statement => CheckWhile(statement),
_ => throw new ArgumentOutOfRangeException(nameof(node))
};
}
private StatementNode CheckAssignment(AssignmentSyntax statement)
{
throw new NotImplementedException();
}
private IfNode CheckIf(IfSyntax statement)
{
throw new NotImplementedException();
}
private ReturnNode CheckReturn(ReturnSyntax statement)
{
var value = Optional.Empty<ExpressionNode>();
if (statement.Value.HasValue)
{
value = CheckExpression(statement.Value.Value, _funcReturnTypes.Peek());
}
return new ReturnNode(value);
}
private StatementExpressionNode CheckStatementExpression(StatementExpressionSyntax statement)
{
return new StatementExpressionNode(CheckExpression(statement.Expression));
}
private VariableDeclarationNode CheckVariableDeclaration(VariableDeclarationSyntax statement)
{
TypeNode? type = null;
ExpressionNode? assignmentNode = null;
if (statement.ExplicitType.TryGetValue(out var explicitType))
{
type = ResolveType(explicitType);
}
if (statement.Assignment.TryGetValue(out var assignment))
{
assignmentNode = CheckExpression(assignment, type);
type ??= assignmentNode.Type;
}
if (type == null)
{
throw new TypeCheckerException(Diagnostic.Error($"Cannot infer type of variable {statement.Name}").At(statement).Build());
}
Scope.Declare(new Identifier(statement.Name, type, IdentifierKind.Variable));
return new VariableDeclarationNode(statement.Name, Optional.OfNullable(assignmentNode), type);
}
private WhileNode CheckWhile(WhileSyntax statement)
{
throw new NotImplementedException();
}
private FuncSignatureNode CheckFuncSignature(FuncSignatureSyntax statement)
{
var parameters = new List<FuncParameterNode>();
foreach (var parameter in statement.Parameters)
{
parameters.Add(new FuncParameterNode(parameter.Name, ResolveType(parameter.Type)));
}
return new FuncSignatureNode(parameters, ResolveType(statement.ReturnType));
}
private ExpressionNode CheckExpression(ExpressionSyntax node, TypeNode? expectedType = null)
{
var result = node switch
{
AddressOfSyntax expression => CheckAddressOf(expression),
ArrayIndexAccessSyntax expression => CheckArrayIndexAccess(expression),
ArrayInitializerSyntax expression => CheckArrayInitializer(expression),
BinaryExpressionSyntax expression => CheckBinaryExpression(expression),
DereferenceSyntax expression => CheckDereference(expression),
DotFuncCallSyntax expression => CheckDotFuncCall(expression),
FuncCallSyntax expression => CheckFuncCall(expression),
IdentifierSyntax expression => CheckIdentifier(expression),
LiteralSyntax expression => CheckLiteral(expression, expectedType),
StructFieldAccessSyntax expression => CheckStructFieldAccess(expression),
StructInitializerSyntax expression => CheckStructInitializer(expression, expectedType),
UnaryExpressionSyntax expression => CheckUnaryExpression(expression),
_ => throw new ArgumentOutOfRangeException(nameof(node))
};
if (expectedType == null || result.Type == expectedType)
{
return result;
}
if (result.Type is StructTypeNode structType && expectedType is InterfaceTypeNode interfaceType)
{
return new ConvertToInterfaceNode(interfaceType, interfaceType, structType, result);
}
if (result.Type is IntTypeNode sourceIntType && expectedType is IntTypeNode targetIntType)
{
if (sourceIntType.Signed == targetIntType.Signed && sourceIntType.Width < targetIntType.Width)
{
return new ConvertIntNode(targetIntType, result, sourceIntType, targetIntType);
}
}
if (result.Type is FloatTypeNode sourceFloatType && expectedType is FloatTypeNode targetFloatType)
{
if (sourceFloatType.Width < targetFloatType.Width)
{
return new ConvertFloatNode(targetFloatType, result, sourceFloatType, targetFloatType);
}
}
throw new TypeCheckerException(Diagnostic.Error($"Cannot convert {result.Type} to {expectedType}").At(node).Build());
}
private AddressOfNode CheckAddressOf(AddressOfSyntax expression)
{
throw new NotImplementedException();
}
private ArrayIndexAccessNode CheckArrayIndexAccess(ArrayIndexAccessSyntax expression)
{
throw new NotImplementedException();
}
private ArrayInitializerNode CheckArrayInitializer(ArrayInitializerSyntax expression)
{
throw new NotImplementedException();
}
private BinaryExpressionNode CheckBinaryExpression(BinaryExpressionSyntax expression)
{
throw new NotImplementedException();
}
private DereferenceNode CheckDereference(DereferenceSyntax expression)
{
throw new NotImplementedException();
}
private FuncCallNode CheckFuncCall(FuncCallSyntax expression)
{
var accessor = CheckExpression(expression.Expression);
if (accessor.Type is not FuncTypeNode funcType)
{
throw new TypeCheckerException(Diagnostic.Error($"Cannot call non-function type {accessor.Type}").At(expression.Expression).Build());
}
if (expression.Parameters.Count != funcType.Parameters.Count)
{
throw new TypeCheckerException(Diagnostic.Error($"Function {funcType} expects {funcType.Parameters} but got {expression.Parameters.Count} parameters").At(expression.Expression).Build());
}
var parameters = new List<ExpressionNode>();
for (var i = 0; i < expression.Parameters.Count; i++)
{
var parameter = expression.Parameters[i];
var expectedType = funcType.Parameters[i];
var parameterExpression = CheckExpression(parameter, expectedType);
if (parameterExpression.Type != expectedType)
{
throw new Exception($"Parameter {i + 1} does not match the type {expectedType} for function {funcType}");
}
parameters.Add(parameterExpression);
}
return new FuncCallNode(funcType.ReturnType, accessor, parameters);
}
private ExpressionNode CheckDotFuncCall(DotFuncCallSyntax expression)
{
throw new NotImplementedException();
}
private ExpressionNode CheckIdentifier(IdentifierSyntax expression)
{
// If the identifier does not have a module specified, first check if a local variable or function parameter with that identifier exists
if (!expression.Module.TryGetValue(out var moduleName))
{
var scopeIdent = Scope.Lookup(expression.Name);
if (scopeIdent != null)
{
switch (scopeIdent.Kind)
{
case IdentifierKind.Variable:
{
return new VariableIdentifierNode(scopeIdent.Type, expression.Name);
}
case IdentifierKind.FunctionParameter:
{
return new FuncParameterIdentifierNode(scopeIdent.Type, expression.Name);
}
default:
{
throw new ArgumentOutOfRangeException();
}
}
}
}
moduleName ??= _currentModule.Name;
if (_moduleSignatures.TryGetValue(moduleName, out var module))
{
if (module.Functions.TryGetValue(expression.Name, out var function))
{
return new FuncIdentifierNode(function, moduleName, expression.Name);
}
}
throw new TypeCheckerException(Diagnostic.Error($"Identifier {expression.Name} not found").At(expression).Build());
}
private LiteralNode CheckLiteral(LiteralSyntax expression, TypeNode? expectedType)
{
// todo(nub31): Check if the types can actually be represented as another one. For example, an int should be passed when a string is expected
var type = expectedType ?? expression.Kind switch
{
LiteralKind.Integer => new IntTypeNode(true, 64),
LiteralKind.Float => new FloatTypeNode(64),
LiteralKind.String => new StringTypeNode(),
LiteralKind.Bool => new BoolTypeNode(),
_ => throw new ArgumentOutOfRangeException()
};
return new LiteralNode(type, expression.Value, expression.Kind);
}
private StructFieldAccessNode CheckStructFieldAccess(StructFieldAccessSyntax expression)
{
throw new NotImplementedException();
}
private StructInitializerNode CheckStructInitializer(StructInitializerSyntax expression, TypeNode? expectedType)
{
throw new NotImplementedException();
}
private UnaryExpressionNode CheckUnaryExpression(UnaryExpressionSyntax expression)
{
throw new NotImplementedException();
}
private BlockNode CheckBlock(BlockSyntax node, Scope? scope = null)
{
var statements = new List<StatementNode>();
_scopes.Push(scope ?? Scope.SubScope());
foreach (var statement in node.Statements)
{
statements.Add(CheckStatement(statement));
}
_scopes.Pop();
return new BlockNode(statements);
}
private TypeNode ResolveType(TypeSyntax fieldType)
{
return TypeResolver.ResolveType(fieldType, _moduleSignatures);
}
}
public enum IdentifierKind
{
Variable,
FunctionParameter
}
public record Identifier(string Name, TypeNode Type, IdentifierKind Kind);
public class Scope(Scope? parent = null)
{
private readonly List<Identifier> _variables = [];
public Identifier? Lookup(string name)
{
var variable = _variables.FirstOrDefault(x => x.Name == name);
if (variable != null)
{
return variable;
}
return parent?.Lookup(name);
}
public void Declare(Identifier identifier)
{
_variables.Add(identifier);
}
public Scope SubScope()
{
return new Scope(this);
}
}
public class TypeCheckerException : Exception
{
public Diagnostic Diagnostic { get; }
public TypeCheckerException(Diagnostic diagnostic) : base(diagnostic.Message)
{
Diagnostic = diagnostic;
}
}