...
This commit is contained in:
171
compiler/NubLang/TypeChecking/Module.cs
Normal file
171
compiler/NubLang/TypeChecking/Module.cs
Normal file
@@ -0,0 +1,171 @@
|
||||
using NubLang.Parsing.Syntax;
|
||||
using NubLang.TypeChecking.Node;
|
||||
|
||||
namespace NubLang.TypeChecking;
|
||||
|
||||
public class Module
|
||||
{
|
||||
public static IReadOnlyList<Module> CollectFromSyntaxTrees(IReadOnlyList<SyntaxTree> syntaxTrees)
|
||||
{
|
||||
var modules = new Dictionary<string, Module>();
|
||||
|
||||
foreach (var syntaxTree in syntaxTrees)
|
||||
{
|
||||
var name = syntaxTree.Metadata.ModuleName;
|
||||
if (name == null)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!modules.TryGetValue(name, out var module))
|
||||
{
|
||||
module = new Module(name, syntaxTree.Metadata.Imports);
|
||||
modules[name] = module;
|
||||
}
|
||||
|
||||
foreach (var definition in syntaxTree.Definitions)
|
||||
{
|
||||
module.AddDefinition(definition);
|
||||
}
|
||||
}
|
||||
|
||||
return modules.Values.ToList();
|
||||
}
|
||||
|
||||
private readonly List<DefinitionSyntax> _definitions = [];
|
||||
|
||||
public Module(string name, IReadOnlyList<string> imports)
|
||||
{
|
||||
Name = name;
|
||||
Imports = imports;
|
||||
}
|
||||
|
||||
public string Name { get; }
|
||||
public IReadOnlyList<string> Imports { get; }
|
||||
|
||||
public IReadOnlyList<DefinitionSyntax> Definitions => _definitions;
|
||||
|
||||
private void AddDefinition(DefinitionSyntax syntax)
|
||||
{
|
||||
_definitions.Add(syntax);
|
||||
}
|
||||
}
|
||||
|
||||
public class TypedModule
|
||||
{
|
||||
public TypedModule(string name, IReadOnlyList<DefinitionNode> definitions)
|
||||
{
|
||||
Name = name;
|
||||
Definitions = definitions;
|
||||
}
|
||||
|
||||
public string Name { get; }
|
||||
public IReadOnlyList<DefinitionNode> Definitions { get; }
|
||||
}
|
||||
|
||||
public class ModuleSignature
|
||||
{
|
||||
public static IReadOnlyDictionary<string, ModuleSignature> CollectFromSyntaxTrees(IReadOnlyList<SyntaxTree> syntaxTrees)
|
||||
{
|
||||
var modules = new Dictionary<string, ModuleSignature>();
|
||||
|
||||
foreach (var syntaxTree in syntaxTrees)
|
||||
{
|
||||
var moduleName = syntaxTree.Metadata.ModuleName;
|
||||
if (moduleName == null)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!modules.TryGetValue(moduleName, out var module))
|
||||
{
|
||||
module = new ModuleSignature();
|
||||
modules[moduleName] = module;
|
||||
}
|
||||
|
||||
foreach (var def in syntaxTree.Definitions)
|
||||
{
|
||||
if (!def.Exported) continue;
|
||||
|
||||
switch (def)
|
||||
{
|
||||
case FuncSyntax funcDef:
|
||||
{
|
||||
var parameters = funcDef.Signature.Parameters.Select(p => TypeResolver.ResolveType(p.Type, modules)).ToList();
|
||||
var returnType = TypeResolver.ResolveType(funcDef.Signature.ReturnType, modules);
|
||||
|
||||
var type = new FuncTypeNode(parameters, returnType, funcDef.ExternSymbol);
|
||||
module._functions.Add(funcDef.Name, type);
|
||||
break;
|
||||
}
|
||||
case InterfaceSyntax interfaceDef:
|
||||
{
|
||||
var functions = new List<InterfaceTypeFunc>();
|
||||
for (var i = 0; i < interfaceDef.Functions.Count; i++)
|
||||
{
|
||||
var function = interfaceDef.Functions[i];
|
||||
var parameters = function.Signature.Parameters.Select(p => TypeResolver.ResolveType(p.Type, modules)).ToList();
|
||||
var returnType = TypeResolver.ResolveType(function.Signature.ReturnType, modules);
|
||||
functions.Add(new InterfaceTypeFunc(function.Name, new FuncTypeNode(parameters, returnType), i));
|
||||
}
|
||||
|
||||
var type = new InterfaceTypeNode(moduleName, interfaceDef.Name, functions);
|
||||
module._interfaces.Add(type);
|
||||
break;
|
||||
}
|
||||
case StructSyntax structDef:
|
||||
{
|
||||
var fields = new List<StructTypeField>();
|
||||
foreach (var field in structDef.Fields)
|
||||
{
|
||||
fields.Add(new StructTypeField(field.Name, TypeResolver.ResolveType(field.Type, modules), field.Index, field.Value.HasValue));
|
||||
}
|
||||
|
||||
var functions = new List<StructTypeFunc>();
|
||||
foreach (var function in structDef.Functions)
|
||||
{
|
||||
var parameters = function.Signature.Parameters.Select(p => TypeResolver.ResolveType(p.Type, modules)).ToList();
|
||||
var returnType = TypeResolver.ResolveType(function.Signature.ReturnType, modules);
|
||||
functions.Add(new StructTypeFunc(function.Name, new FuncTypeNode(parameters, returnType)));
|
||||
}
|
||||
|
||||
var interfaceImplementations = new List<InterfaceTypeNode>();
|
||||
foreach (var interfaceImplementation in structDef.InterfaceImplementations)
|
||||
{
|
||||
if (interfaceImplementation is not CustomTypeSyntax customType)
|
||||
{
|
||||
throw new Exception("Interface implementation is not a custom type");
|
||||
}
|
||||
|
||||
var resolvedType = TypeResolver.ResolveCustomType(customType.Module, customType.Name, modules);
|
||||
if (resolvedType is not InterfaceTypeNode interfaceType)
|
||||
{
|
||||
throw new Exception("Interface implementation is not a interface");
|
||||
}
|
||||
|
||||
interfaceImplementations.Add(interfaceType);
|
||||
}
|
||||
|
||||
var type = new StructTypeNode(moduleName, structDef.Name, fields, functions, interfaceImplementations);
|
||||
module._structs.Add(type);
|
||||
break;
|
||||
}
|
||||
default:
|
||||
{
|
||||
throw new ArgumentOutOfRangeException(nameof(def));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return modules;
|
||||
}
|
||||
|
||||
private readonly List<StructTypeNode> _structs = [];
|
||||
private readonly List<InterfaceTypeNode> _interfaces = [];
|
||||
private readonly Dictionary<string, FuncTypeNode> _functions = [];
|
||||
|
||||
public IReadOnlyList<StructTypeNode> StructTypes => _structs;
|
||||
public IReadOnlyList<InterfaceTypeNode> InterfaceTypes => _interfaces;
|
||||
public IReadOnlyDictionary<string, FuncTypeNode> Functions => _functions;
|
||||
}
|
||||
19
compiler/NubLang/TypeChecking/Node/DefinitionNode.cs
Normal file
19
compiler/NubLang/TypeChecking/Node/DefinitionNode.cs
Normal file
@@ -0,0 +1,19 @@
|
||||
namespace NubLang.TypeChecking.Node;
|
||||
|
||||
public abstract record DefinitionNode : Node;
|
||||
|
||||
public record FuncParameterNode(string Name, TypeNode Type) : Node;
|
||||
|
||||
public record FuncSignatureNode(IReadOnlyList<FuncParameterNode> Parameters, TypeNode ReturnType) : Node;
|
||||
|
||||
public record FuncNode(string Name, FuncSignatureNode Signature, BlockNode? Body) : DefinitionNode;
|
||||
|
||||
public record StructFieldNode(int Index, string Name, TypeNode Type, Optional<ExpressionNode> Value) : Node;
|
||||
|
||||
public record StructFuncNode(string Name, FuncSignatureNode Signature, BlockNode Body) : Node;
|
||||
|
||||
public record StructNode(string Name, IReadOnlyList<StructFieldNode> Fields, IReadOnlyList<StructFuncNode> Functions, IReadOnlyList<InterfaceTypeNode> InterfaceImplementations) : DefinitionNode;
|
||||
|
||||
public record InterfaceFuncNode(string Name, FuncSignatureNode Signature) : Node;
|
||||
|
||||
public record InterfaceNode(string Name, IReadOnlyList<InterfaceFuncNode> Functions) : DefinitionNode;
|
||||
72
compiler/NubLang/TypeChecking/Node/ExpressionNode.cs
Normal file
72
compiler/NubLang/TypeChecking/Node/ExpressionNode.cs
Normal file
@@ -0,0 +1,72 @@
|
||||
using NubLang.Tokenization;
|
||||
|
||||
namespace NubLang.TypeChecking.Node;
|
||||
|
||||
public enum UnaryOperator
|
||||
{
|
||||
Negate,
|
||||
Invert
|
||||
}
|
||||
|
||||
public enum BinaryOperator
|
||||
{
|
||||
Equal,
|
||||
NotEqual,
|
||||
GreaterThan,
|
||||
GreaterThanOrEqual,
|
||||
LessThan,
|
||||
LessThanOrEqual,
|
||||
LogicalAnd,
|
||||
LogicalOr,
|
||||
Plus,
|
||||
Minus,
|
||||
Multiply,
|
||||
Divide,
|
||||
Modulo,
|
||||
LeftShift,
|
||||
RightShift,
|
||||
BitwiseAnd,
|
||||
BitwiseXor,
|
||||
BitwiseOr
|
||||
}
|
||||
|
||||
public abstract record ExpressionNode(TypeNode Type) : Node;
|
||||
|
||||
public abstract record LValueExpressionNode(TypeNode Type) : RValueExpressionNode(Type);
|
||||
public abstract record RValueExpressionNode(TypeNode Type) : ExpressionNode(Type);
|
||||
|
||||
public record BinaryExpressionNode(TypeNode Type, ExpressionNode Left, BinaryOperator Operator, ExpressionNode Right) : RValueExpressionNode(Type);
|
||||
|
||||
public record UnaryExpressionNode(TypeNode Type, UnaryOperator Operator, ExpressionNode Operand) : RValueExpressionNode(Type);
|
||||
|
||||
public record FuncCallNode(TypeNode Type, ExpressionNode Expression, IReadOnlyList<ExpressionNode> Parameters) : RValueExpressionNode(Type);
|
||||
|
||||
public record StructFuncCallNode(TypeNode Type, string Name, StructTypeNode StructType, ExpressionNode StructExpression, IReadOnlyList<ExpressionNode> Parameters) : RValueExpressionNode(Type);
|
||||
|
||||
public record InterfaceFuncCallNode(TypeNode Type, string Name, InterfaceTypeNode InterfaceType, ExpressionNode InterfaceExpression, IReadOnlyList<ExpressionNode> Parameters) : RValueExpressionNode(Type);
|
||||
|
||||
public record VariableIdentifierNode(TypeNode Type, string Name) : LValueExpressionNode(Type);
|
||||
|
||||
public record FuncParameterIdentifierNode(TypeNode Type, string Name) : RValueExpressionNode(Type);
|
||||
|
||||
public record FuncIdentifierNode(TypeNode Type, string Module, string Name) : RValueExpressionNode(Type);
|
||||
|
||||
public record ArrayInitializerNode(TypeNode Type, ExpressionNode Capacity, TypeNode ElementType) : RValueExpressionNode(Type);
|
||||
|
||||
public record ArrayIndexAccessNode(TypeNode Type, ExpressionNode Target, ExpressionNode Index) : LValueExpressionNode(Type);
|
||||
|
||||
public record AddressOfNode(TypeNode Type, LValueExpressionNode LValue) : RValueExpressionNode(Type);
|
||||
|
||||
public record LiteralNode(TypeNode Type, string Value, LiteralKind Kind) : RValueExpressionNode(Type);
|
||||
|
||||
public record StructFieldAccessNode(TypeNode Type, StructTypeNode StructType, ExpressionNode Target, string Field) : LValueExpressionNode(Type);
|
||||
|
||||
public record StructInitializerNode(StructTypeNode StructType, Dictionary<string, ExpressionNode> Initializers) : RValueExpressionNode(StructType);
|
||||
|
||||
public record DereferenceNode(TypeNode Type, ExpressionNode Expression) : RValueExpressionNode(Type);
|
||||
|
||||
public record ConvertToInterfaceNode(TypeNode Type, InterfaceTypeNode InterfaceType, StructTypeNode StructType, ExpressionNode Implementation) : RValueExpressionNode(Type);
|
||||
|
||||
public record ConvertIntNode(TypeNode Type, ExpressionNode Value, IntTypeNode ValueType, IntTypeNode TargetType) : RValueExpressionNode(Type);
|
||||
|
||||
public record ConvertFloatNode(TypeNode Type, ExpressionNode Value, FloatTypeNode ValueType, FloatTypeNode TargetType) : RValueExpressionNode(Type);
|
||||
5
compiler/NubLang/TypeChecking/Node/Node.cs
Normal file
5
compiler/NubLang/TypeChecking/Node/Node.cs
Normal file
@@ -0,0 +1,5 @@
|
||||
namespace NubLang.TypeChecking.Node;
|
||||
|
||||
public abstract record Node;
|
||||
|
||||
public record BlockNode(IReadOnlyList<StatementNode> Statements) : Node;
|
||||
19
compiler/NubLang/TypeChecking/Node/StatementNode.cs
Normal file
19
compiler/NubLang/TypeChecking/Node/StatementNode.cs
Normal file
@@ -0,0 +1,19 @@
|
||||
namespace NubLang.TypeChecking.Node;
|
||||
|
||||
public record StatementNode : Node;
|
||||
|
||||
public record StatementExpressionNode(ExpressionNode Expression) : StatementNode;
|
||||
|
||||
public record ReturnNode(Optional<ExpressionNode> Value) : StatementNode;
|
||||
|
||||
public record AssignmentNode(LValueExpressionNode Target, ExpressionNode Value) : StatementNode;
|
||||
|
||||
public record IfNode(ExpressionNode Condition, BlockNode Body, Optional<Variant<IfNode, BlockNode>> Else) : StatementNode;
|
||||
|
||||
public record VariableDeclarationNode(string Name, Optional<ExpressionNode> Assignment, TypeNode Type) : StatementNode;
|
||||
|
||||
public record ContinueNode : StatementNode;
|
||||
|
||||
public record BreakNode : StatementNode;
|
||||
|
||||
public record WhileNode(ExpressionNode Condition, BlockNode Body) : StatementNode;
|
||||
233
compiler/NubLang/TypeChecking/Node/TypeNode.cs
Normal file
233
compiler/NubLang/TypeChecking/Node/TypeNode.cs
Normal file
@@ -0,0 +1,233 @@
|
||||
using System.Diagnostics.CodeAnalysis;
|
||||
|
||||
namespace NubLang.TypeChecking.Node;
|
||||
|
||||
public abstract class TypeNode : IEquatable<TypeNode>
|
||||
{
|
||||
public bool IsSimpleType([NotNullWhen(true)] out SimpleTypeNode? simpleType, [NotNullWhen(false)] out ComplexTypeNode? complexType)
|
||||
{
|
||||
if (this is SimpleTypeNode st)
|
||||
{
|
||||
complexType = null;
|
||||
simpleType = st;
|
||||
return true;
|
||||
}
|
||||
|
||||
if (this is ComplexTypeNode ct)
|
||||
{
|
||||
complexType = ct;
|
||||
simpleType = null;
|
||||
return false;
|
||||
}
|
||||
|
||||
throw new ArgumentException($"Type {this} is not a simple type nor a complex type");
|
||||
}
|
||||
|
||||
public override bool Equals(object? obj) => obj is TypeNode other && Equals(other);
|
||||
|
||||
public abstract bool Equals(TypeNode? other);
|
||||
public abstract override int GetHashCode();
|
||||
public abstract override string ToString();
|
||||
|
||||
public static bool operator ==(TypeNode? left, TypeNode? right) => Equals(left, right);
|
||||
public static bool operator !=(TypeNode? left, TypeNode? right) => !Equals(left, right);
|
||||
}
|
||||
|
||||
public enum StorageSize
|
||||
{
|
||||
Void,
|
||||
I8,
|
||||
I16,
|
||||
I32,
|
||||
I64,
|
||||
U8,
|
||||
U16,
|
||||
U32,
|
||||
U64,
|
||||
F32,
|
||||
F64
|
||||
}
|
||||
|
||||
public abstract class SimpleTypeNode : TypeNode
|
||||
{
|
||||
public abstract StorageSize StorageSize { get; }
|
||||
}
|
||||
|
||||
#region Simple types
|
||||
|
||||
public class IntTypeNode(bool signed, int width) : SimpleTypeNode
|
||||
{
|
||||
public bool Signed { get; } = signed;
|
||||
public int Width { get; } = width;
|
||||
|
||||
public override StorageSize StorageSize => Signed switch
|
||||
{
|
||||
true => Width switch
|
||||
{
|
||||
8 => StorageSize.I8,
|
||||
16 => StorageSize.I16,
|
||||
32 => StorageSize.I32,
|
||||
64 => StorageSize.I64,
|
||||
_ => throw new ArgumentOutOfRangeException(nameof(Width))
|
||||
},
|
||||
false => Width switch
|
||||
{
|
||||
8 => StorageSize.U8,
|
||||
16 => StorageSize.U16,
|
||||
32 => StorageSize.U32,
|
||||
64 => StorageSize.U64,
|
||||
_ => throw new ArgumentOutOfRangeException(nameof(Width))
|
||||
}
|
||||
};
|
||||
|
||||
public override string ToString() => $"{(Signed ? "i" : "u")}{Width}";
|
||||
public override bool Equals(TypeNode? other) => other is IntTypeNode @int && @int.Width == Width && @int.Signed == Signed;
|
||||
public override int GetHashCode() => HashCode.Combine(typeof(IntTypeNode), Signed, Width);
|
||||
}
|
||||
|
||||
public class FloatTypeNode(int width) : SimpleTypeNode
|
||||
{
|
||||
public int Width { get; } = width;
|
||||
|
||||
public override StorageSize StorageSize => Width switch
|
||||
{
|
||||
32 => StorageSize.F32,
|
||||
64 => StorageSize.F64,
|
||||
_ => throw new ArgumentOutOfRangeException(nameof(Width))
|
||||
};
|
||||
|
||||
public override string ToString() => $"f{Width}";
|
||||
public override bool Equals(TypeNode? other) => other is FloatTypeNode @int && @int.Width == Width;
|
||||
public override int GetHashCode() => HashCode.Combine(typeof(FloatTypeNode), Width);
|
||||
}
|
||||
|
||||
public class BoolTypeNode : SimpleTypeNode
|
||||
{
|
||||
public override StorageSize StorageSize => StorageSize.U8;
|
||||
|
||||
public override string ToString() => "bool";
|
||||
public override bool Equals(TypeNode? other) => other is BoolTypeNode;
|
||||
public override int GetHashCode() => HashCode.Combine(typeof(BoolTypeNode));
|
||||
}
|
||||
|
||||
public class FuncTypeNode(IReadOnlyList<TypeNode> parameters, TypeNode returnType, string? externSymbol = null) : SimpleTypeNode
|
||||
{
|
||||
public string? ExternSymbol { get; } = externSymbol;
|
||||
public IReadOnlyList<TypeNode> Parameters { get; } = parameters;
|
||||
public TypeNode ReturnType { get; } = returnType;
|
||||
|
||||
public override StorageSize StorageSize => StorageSize.U64;
|
||||
|
||||
public override string ToString() => $"func({string.Join(", ", Parameters)}): {ReturnType}";
|
||||
public override bool Equals(TypeNode? other) => other is FuncTypeNode func && ReturnType.Equals(func.ReturnType) && Parameters.SequenceEqual(func.Parameters);
|
||||
|
||||
public override int GetHashCode()
|
||||
{
|
||||
var hash = new HashCode();
|
||||
hash.Add(typeof(FuncTypeNode));
|
||||
hash.Add(ReturnType);
|
||||
foreach (var param in Parameters)
|
||||
{
|
||||
hash.Add(param);
|
||||
}
|
||||
|
||||
return hash.ToHashCode();
|
||||
}
|
||||
}
|
||||
|
||||
public class PointerTypeNode(TypeNode baseType) : SimpleTypeNode
|
||||
{
|
||||
public TypeNode BaseType { get; } = baseType;
|
||||
|
||||
public override StorageSize StorageSize => StorageSize.U64;
|
||||
|
||||
public override string ToString() => "^" + BaseType;
|
||||
public override bool Equals(TypeNode? other) => other is PointerTypeNode pointer && BaseType.Equals(pointer.BaseType);
|
||||
public override int GetHashCode() => HashCode.Combine(typeof(PointerTypeNode), BaseType);
|
||||
}
|
||||
|
||||
public class VoidTypeNode : SimpleTypeNode
|
||||
{
|
||||
public override StorageSize StorageSize => StorageSize.Void;
|
||||
|
||||
public override string ToString() => "void";
|
||||
public override bool Equals(TypeNode? other) => other is VoidTypeNode;
|
||||
public override int GetHashCode() => HashCode.Combine(typeof(VoidTypeNode));
|
||||
}
|
||||
|
||||
#endregion
|
||||
|
||||
public abstract class ComplexTypeNode : TypeNode;
|
||||
|
||||
#region Complex types
|
||||
|
||||
public class CStringTypeNode : ComplexTypeNode
|
||||
{
|
||||
public override string ToString() => "cstring";
|
||||
public override bool Equals(TypeNode? other) => other is CStringTypeNode;
|
||||
public override int GetHashCode() => HashCode.Combine(typeof(CStringTypeNode));
|
||||
}
|
||||
|
||||
public class StringTypeNode : ComplexTypeNode
|
||||
{
|
||||
public override string ToString() => "string";
|
||||
public override bool Equals(TypeNode? other) => other is StringTypeNode;
|
||||
public override int GetHashCode() => HashCode.Combine(typeof(StringTypeNode));
|
||||
}
|
||||
|
||||
public class StructTypeField(string name, TypeNode type, int index, bool hasDefaultValue)
|
||||
{
|
||||
public string Name { get; } = name;
|
||||
public TypeNode Type { get; } = type;
|
||||
public int Index { get; } = index;
|
||||
public bool HasDefaultValue { get; } = hasDefaultValue;
|
||||
}
|
||||
|
||||
public class StructTypeFunc(string name, FuncTypeNode type)
|
||||
{
|
||||
public string Name { get; } = name;
|
||||
public FuncTypeNode Type { get; } = type;
|
||||
}
|
||||
|
||||
public class StructTypeNode(string module, string name, IReadOnlyList<StructTypeField> fields, IReadOnlyList<StructTypeFunc> functions, IReadOnlyList<InterfaceTypeNode> interfaceImplementations) : ComplexTypeNode
|
||||
{
|
||||
public string Module { get; } = module;
|
||||
public string Name { get; } = name;
|
||||
public IReadOnlyList<StructTypeField> Fields { get; set; } = fields;
|
||||
public IReadOnlyList<StructTypeFunc> Functions { get; set; } = functions;
|
||||
public IReadOnlyList<InterfaceTypeNode> InterfaceImplementations { get; set; } = interfaceImplementations;
|
||||
|
||||
public override string ToString() => Name;
|
||||
public override bool Equals(TypeNode? other) => other is StructTypeNode structType && Name == structType.Name && Module == structType.Module;
|
||||
public override int GetHashCode() => HashCode.Combine(typeof(StructTypeNode), Name);
|
||||
}
|
||||
|
||||
public class InterfaceTypeFunc(string name, FuncTypeNode type, int index)
|
||||
{
|
||||
public string Name { get; } = name;
|
||||
public FuncTypeNode Type { get; } = type;
|
||||
public int Index { get; } = index;
|
||||
}
|
||||
|
||||
public class InterfaceTypeNode(string module, string name, IReadOnlyList<InterfaceTypeFunc> functions) : ComplexTypeNode
|
||||
{
|
||||
public string Module { get; } = module;
|
||||
public string Name { get; } = name;
|
||||
public IReadOnlyList<InterfaceTypeFunc> Functions { get; set; } = functions;
|
||||
|
||||
public override string ToString() => Name;
|
||||
public override bool Equals(TypeNode? other) => other is InterfaceTypeNode interfaceType && Name == interfaceType.Name && Module == interfaceType.Module;
|
||||
public override int GetHashCode() => HashCode.Combine(typeof(InterfaceTypeNode), Name);
|
||||
}
|
||||
|
||||
public class ArrayTypeNode(TypeNode elementType) : ComplexTypeNode
|
||||
{
|
||||
public TypeNode ElementType { get; } = elementType;
|
||||
|
||||
public override string ToString() => "[]" + ElementType;
|
||||
|
||||
public override bool Equals(TypeNode? other) => other is ArrayTypeNode array && ElementType.Equals(array.ElementType);
|
||||
public override int GetHashCode() => HashCode.Combine(typeof(ArrayTypeNode), ElementType);
|
||||
}
|
||||
|
||||
#endregion
|
||||
456
compiler/NubLang/TypeChecking/TypeChecker.cs
Normal file
456
compiler/NubLang/TypeChecking/TypeChecker.cs
Normal file
@@ -0,0 +1,456 @@
|
||||
using NubLang.Diagnostics;
|
||||
using NubLang.Parsing.Syntax;
|
||||
using NubLang.Tokenization;
|
||||
using NubLang.TypeChecking.Node;
|
||||
|
||||
namespace NubLang.TypeChecking;
|
||||
|
||||
public sealed class TypeChecker
|
||||
{
|
||||
private readonly Module _currentModule;
|
||||
private readonly IReadOnlyDictionary<string, ModuleSignature> _moduleSignatures;
|
||||
|
||||
private readonly Stack<Scope> _scopes = [];
|
||||
private readonly Stack<TypeNode> _funcReturnTypes = [];
|
||||
private readonly List<Diagnostic> _diagnostics = [];
|
||||
|
||||
private Scope Scope => _scopes.Peek();
|
||||
|
||||
public TypeChecker(Module currentModule, IReadOnlyDictionary<string, ModuleSignature> moduleSignatures)
|
||||
{
|
||||
_currentModule = currentModule;
|
||||
_moduleSignatures = moduleSignatures.Where(x => currentModule.Imports.Contains(x.Key) || _currentModule.Name == x.Key).ToDictionary();
|
||||
}
|
||||
|
||||
public IReadOnlyList<Diagnostic> GetDiagnostics() => _diagnostics;
|
||||
|
||||
public TypedModule CheckModule()
|
||||
{
|
||||
_diagnostics.Clear();
|
||||
_scopes.Clear();
|
||||
|
||||
var definitions = new List<DefinitionNode>();
|
||||
|
||||
foreach (var definition in _currentModule.Definitions)
|
||||
{
|
||||
try
|
||||
{
|
||||
definitions.Add(CheckDefinition(definition));
|
||||
}
|
||||
catch (TypeCheckerException e)
|
||||
{
|
||||
_diagnostics.Add(e.Diagnostic);
|
||||
}
|
||||
}
|
||||
|
||||
return new TypedModule(_currentModule.Name, definitions);
|
||||
}
|
||||
|
||||
private DefinitionNode CheckDefinition(DefinitionSyntax node)
|
||||
{
|
||||
return node switch
|
||||
{
|
||||
InterfaceSyntax definition => CheckInterfaceDefinition(definition),
|
||||
FuncSyntax definition => CheckFuncDefinition(definition),
|
||||
StructSyntax definition => CheckStructDefinition(definition),
|
||||
_ => throw new ArgumentOutOfRangeException(nameof(node))
|
||||
};
|
||||
}
|
||||
|
||||
private InterfaceNode CheckInterfaceDefinition(InterfaceSyntax node)
|
||||
{
|
||||
throw new NotImplementedException();
|
||||
}
|
||||
|
||||
private StructNode CheckStructDefinition(StructSyntax node)
|
||||
{
|
||||
var fields = new List<StructFieldNode>();
|
||||
foreach (var field in node.Fields)
|
||||
{
|
||||
var value = Optional.Empty<ExpressionNode>();
|
||||
if (field.Value.HasValue)
|
||||
{
|
||||
value = CheckExpression(field.Value.Value);
|
||||
}
|
||||
|
||||
fields.Add(new StructFieldNode(field.Index, field.Name, ResolveType(field.Type), value));
|
||||
}
|
||||
|
||||
var functions = new List<StructFuncNode>();
|
||||
foreach (var function in node.Functions)
|
||||
{
|
||||
var scope = new Scope();
|
||||
// todo(nub31): Add this parameter
|
||||
foreach (var parameter in function.Signature.Parameters)
|
||||
{
|
||||
scope.Declare(new Identifier(parameter.Name, ResolveType(parameter.Type), IdentifierKind.FunctionParameter));
|
||||
}
|
||||
|
||||
_funcReturnTypes.Push(ResolveType(function.Signature.ReturnType));
|
||||
var body = CheckBlock(function.Body, scope);
|
||||
_funcReturnTypes.Pop();
|
||||
functions.Add(new StructFuncNode(function.Name, CheckFuncSignature(function.Signature), body));
|
||||
}
|
||||
|
||||
var interfaceImplementations = new List<InterfaceTypeNode>();
|
||||
foreach (var interfaceImplementation in node.InterfaceImplementations)
|
||||
{
|
||||
var type = ResolveType(interfaceImplementation);
|
||||
if (type is not InterfaceTypeNode interfaceType)
|
||||
{
|
||||
_diagnostics.Add(Diagnostic.Error($"Struct {node.Name} cannot implement non-struct type {interfaceImplementation}").At(interfaceImplementation).Build());
|
||||
continue;
|
||||
}
|
||||
|
||||
interfaceImplementations.Add(interfaceType);
|
||||
}
|
||||
|
||||
return new StructNode(node.Name, fields, functions, interfaceImplementations);
|
||||
}
|
||||
|
||||
private FuncNode CheckFuncDefinition(FuncSyntax node)
|
||||
{
|
||||
var scope = new Scope();
|
||||
foreach (var parameter in node.Signature.Parameters)
|
||||
{
|
||||
scope.Declare(new Identifier(parameter.Name, ResolveType(parameter.Type), IdentifierKind.FunctionParameter));
|
||||
}
|
||||
|
||||
BlockNode? body = null;
|
||||
if (node.Body != null)
|
||||
{
|
||||
_funcReturnTypes.Push(ResolveType(node.Signature.ReturnType));
|
||||
body = CheckBlock(node.Body, scope);
|
||||
_funcReturnTypes.Pop();
|
||||
}
|
||||
|
||||
return new FuncNode(node.Name, CheckFuncSignature(node.Signature), body);
|
||||
}
|
||||
|
||||
private StatementNode CheckStatement(StatementSyntax node)
|
||||
{
|
||||
return node switch
|
||||
{
|
||||
AssignmentSyntax statement => CheckAssignment(statement),
|
||||
BreakSyntax => new BreakNode(),
|
||||
ContinueSyntax => new ContinueNode(),
|
||||
IfSyntax statement => CheckIf(statement),
|
||||
ReturnSyntax statement => CheckReturn(statement),
|
||||
StatementExpressionSyntax statement => CheckStatementExpression(statement),
|
||||
VariableDeclarationSyntax statement => CheckVariableDeclaration(statement),
|
||||
WhileSyntax statement => CheckWhile(statement),
|
||||
_ => throw new ArgumentOutOfRangeException(nameof(node))
|
||||
};
|
||||
}
|
||||
|
||||
private StatementNode CheckAssignment(AssignmentSyntax statement)
|
||||
{
|
||||
throw new NotImplementedException();
|
||||
}
|
||||
|
||||
private IfNode CheckIf(IfSyntax statement)
|
||||
{
|
||||
throw new NotImplementedException();
|
||||
}
|
||||
|
||||
private ReturnNode CheckReturn(ReturnSyntax statement)
|
||||
{
|
||||
var value = Optional.Empty<ExpressionNode>();
|
||||
|
||||
if (statement.Value.HasValue)
|
||||
{
|
||||
value = CheckExpression(statement.Value.Value, _funcReturnTypes.Peek());
|
||||
}
|
||||
|
||||
return new ReturnNode(value);
|
||||
}
|
||||
|
||||
private StatementExpressionNode CheckStatementExpression(StatementExpressionSyntax statement)
|
||||
{
|
||||
return new StatementExpressionNode(CheckExpression(statement.Expression));
|
||||
}
|
||||
|
||||
private VariableDeclarationNode CheckVariableDeclaration(VariableDeclarationSyntax statement)
|
||||
{
|
||||
TypeNode? type = null;
|
||||
ExpressionNode? assignmentNode = null;
|
||||
|
||||
if (statement.ExplicitType.TryGetValue(out var explicitType))
|
||||
{
|
||||
type = ResolveType(explicitType);
|
||||
}
|
||||
|
||||
if (statement.Assignment.TryGetValue(out var assignment))
|
||||
{
|
||||
assignmentNode = CheckExpression(assignment, type);
|
||||
type ??= assignmentNode.Type;
|
||||
}
|
||||
|
||||
if (type == null)
|
||||
{
|
||||
throw new TypeCheckerException(Diagnostic.Error($"Cannot infer type of variable {statement.Name}").At(statement).Build());
|
||||
}
|
||||
|
||||
Scope.Declare(new Identifier(statement.Name, type, IdentifierKind.Variable));
|
||||
|
||||
return new VariableDeclarationNode(statement.Name, Optional.OfNullable(assignmentNode), type);
|
||||
}
|
||||
|
||||
private WhileNode CheckWhile(WhileSyntax statement)
|
||||
{
|
||||
throw new NotImplementedException();
|
||||
}
|
||||
|
||||
private FuncSignatureNode CheckFuncSignature(FuncSignatureSyntax statement)
|
||||
{
|
||||
var parameters = new List<FuncParameterNode>();
|
||||
foreach (var parameter in statement.Parameters)
|
||||
{
|
||||
parameters.Add(new FuncParameterNode(parameter.Name, ResolveType(parameter.Type)));
|
||||
}
|
||||
|
||||
return new FuncSignatureNode(parameters, ResolveType(statement.ReturnType));
|
||||
}
|
||||
|
||||
private ExpressionNode CheckExpression(ExpressionSyntax node, TypeNode? expectedType = null)
|
||||
{
|
||||
var result = node switch
|
||||
{
|
||||
AddressOfSyntax expression => CheckAddressOf(expression),
|
||||
ArrayIndexAccessSyntax expression => CheckArrayIndexAccess(expression),
|
||||
ArrayInitializerSyntax expression => CheckArrayInitializer(expression),
|
||||
BinaryExpressionSyntax expression => CheckBinaryExpression(expression),
|
||||
DereferenceSyntax expression => CheckDereference(expression),
|
||||
DotFuncCallSyntax expression => CheckDotFuncCall(expression),
|
||||
FuncCallSyntax expression => CheckFuncCall(expression),
|
||||
IdentifierSyntax expression => CheckIdentifier(expression),
|
||||
LiteralSyntax expression => CheckLiteral(expression, expectedType),
|
||||
StructFieldAccessSyntax expression => CheckStructFieldAccess(expression),
|
||||
StructInitializerSyntax expression => CheckStructInitializer(expression, expectedType),
|
||||
UnaryExpressionSyntax expression => CheckUnaryExpression(expression),
|
||||
_ => throw new ArgumentOutOfRangeException(nameof(node))
|
||||
};
|
||||
|
||||
if (expectedType == null || result.Type == expectedType)
|
||||
{
|
||||
return result;
|
||||
}
|
||||
|
||||
if (result.Type is StructTypeNode structType && expectedType is InterfaceTypeNode interfaceType)
|
||||
{
|
||||
return new ConvertToInterfaceNode(interfaceType, interfaceType, structType, result);
|
||||
}
|
||||
|
||||
if (result.Type is IntTypeNode sourceIntType && expectedType is IntTypeNode targetIntType)
|
||||
{
|
||||
if (sourceIntType.Signed == targetIntType.Signed && sourceIntType.Width < targetIntType.Width)
|
||||
{
|
||||
return new ConvertIntNode(targetIntType, result, sourceIntType, targetIntType);
|
||||
}
|
||||
}
|
||||
|
||||
if (result.Type is FloatTypeNode sourceFloatType && expectedType is FloatTypeNode targetFloatType)
|
||||
{
|
||||
if (sourceFloatType.Width < targetFloatType.Width)
|
||||
{
|
||||
return new ConvertFloatNode(targetFloatType, result, sourceFloatType, targetFloatType);
|
||||
}
|
||||
}
|
||||
|
||||
throw new TypeCheckerException(Diagnostic.Error($"Cannot convert {result.Type} to {expectedType}").At(node).Build());
|
||||
}
|
||||
|
||||
private AddressOfNode CheckAddressOf(AddressOfSyntax expression)
|
||||
{
|
||||
throw new NotImplementedException();
|
||||
}
|
||||
|
||||
private ArrayIndexAccessNode CheckArrayIndexAccess(ArrayIndexAccessSyntax expression)
|
||||
{
|
||||
throw new NotImplementedException();
|
||||
}
|
||||
|
||||
private ArrayInitializerNode CheckArrayInitializer(ArrayInitializerSyntax expression)
|
||||
{
|
||||
throw new NotImplementedException();
|
||||
}
|
||||
|
||||
private BinaryExpressionNode CheckBinaryExpression(BinaryExpressionSyntax expression)
|
||||
{
|
||||
throw new NotImplementedException();
|
||||
}
|
||||
|
||||
private DereferenceNode CheckDereference(DereferenceSyntax expression)
|
||||
{
|
||||
throw new NotImplementedException();
|
||||
}
|
||||
|
||||
private FuncCallNode CheckFuncCall(FuncCallSyntax expression)
|
||||
{
|
||||
var accessor = CheckExpression(expression.Expression);
|
||||
if (accessor.Type is not FuncTypeNode funcType)
|
||||
{
|
||||
throw new TypeCheckerException(Diagnostic.Error($"Cannot call non-function type {accessor.Type}").At(expression.Expression).Build());
|
||||
}
|
||||
|
||||
if (expression.Parameters.Count != funcType.Parameters.Count)
|
||||
{
|
||||
throw new TypeCheckerException(Diagnostic.Error($"Function {funcType} expects {funcType.Parameters} but got {expression.Parameters.Count} parameters").At(expression.Expression).Build());
|
||||
}
|
||||
|
||||
var parameters = new List<ExpressionNode>();
|
||||
for (var i = 0; i < expression.Parameters.Count; i++)
|
||||
{
|
||||
var parameter = expression.Parameters[i];
|
||||
var expectedType = funcType.Parameters[i];
|
||||
|
||||
var parameterExpression = CheckExpression(parameter, expectedType);
|
||||
if (parameterExpression.Type != expectedType)
|
||||
{
|
||||
throw new Exception($"Parameter {i + 1} does not match the type {expectedType} for function {funcType}");
|
||||
}
|
||||
|
||||
parameters.Add(parameterExpression);
|
||||
}
|
||||
|
||||
return new FuncCallNode(funcType.ReturnType, accessor, parameters);
|
||||
}
|
||||
|
||||
private ExpressionNode CheckDotFuncCall(DotFuncCallSyntax expression)
|
||||
{
|
||||
throw new NotImplementedException();
|
||||
}
|
||||
|
||||
private ExpressionNode CheckIdentifier(IdentifierSyntax expression)
|
||||
{
|
||||
// If the identifier does not have a module specified, first check if a local variable or function parameter with that identifier exists
|
||||
if (!expression.Module.TryGetValue(out var moduleName))
|
||||
{
|
||||
var scopeIdent = Scope.Lookup(expression.Name);
|
||||
if (scopeIdent != null)
|
||||
{
|
||||
switch (scopeIdent.Kind)
|
||||
{
|
||||
case IdentifierKind.Variable:
|
||||
{
|
||||
return new VariableIdentifierNode(scopeIdent.Type, expression.Name);
|
||||
}
|
||||
case IdentifierKind.FunctionParameter:
|
||||
{
|
||||
return new FuncParameterIdentifierNode(scopeIdent.Type, expression.Name);
|
||||
}
|
||||
default:
|
||||
{
|
||||
throw new ArgumentOutOfRangeException();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
moduleName ??= _currentModule.Name;
|
||||
if (_moduleSignatures.TryGetValue(moduleName, out var module))
|
||||
{
|
||||
if (module.Functions.TryGetValue(expression.Name, out var function))
|
||||
{
|
||||
return new FuncIdentifierNode(function, moduleName, expression.Name);
|
||||
}
|
||||
}
|
||||
|
||||
throw new TypeCheckerException(Diagnostic.Error($"Identifier {expression.Name} not found").At(expression).Build());
|
||||
}
|
||||
|
||||
private LiteralNode CheckLiteral(LiteralSyntax expression, TypeNode? expectedType)
|
||||
{
|
||||
// todo(nub31): Check if the types can actually be represented as another one. For example, an int should be passed when a string is expected
|
||||
var type = expectedType ?? expression.Kind switch
|
||||
{
|
||||
LiteralKind.Integer => new IntTypeNode(true, 64),
|
||||
LiteralKind.Float => new FloatTypeNode(64),
|
||||
LiteralKind.String => new StringTypeNode(),
|
||||
LiteralKind.Bool => new BoolTypeNode(),
|
||||
_ => throw new ArgumentOutOfRangeException()
|
||||
};
|
||||
|
||||
return new LiteralNode(type, expression.Value, expression.Kind);
|
||||
}
|
||||
|
||||
private StructFieldAccessNode CheckStructFieldAccess(StructFieldAccessSyntax expression)
|
||||
{
|
||||
throw new NotImplementedException();
|
||||
}
|
||||
|
||||
private StructInitializerNode CheckStructInitializer(StructInitializerSyntax expression, TypeNode? expectedType)
|
||||
{
|
||||
throw new NotImplementedException();
|
||||
}
|
||||
|
||||
private UnaryExpressionNode CheckUnaryExpression(UnaryExpressionSyntax expression)
|
||||
{
|
||||
throw new NotImplementedException();
|
||||
}
|
||||
|
||||
private BlockNode CheckBlock(BlockSyntax node, Scope? scope = null)
|
||||
{
|
||||
var statements = new List<StatementNode>();
|
||||
|
||||
_scopes.Push(scope ?? Scope.SubScope());
|
||||
|
||||
foreach (var statement in node.Statements)
|
||||
{
|
||||
statements.Add(CheckStatement(statement));
|
||||
}
|
||||
|
||||
_scopes.Pop();
|
||||
|
||||
return new BlockNode(statements);
|
||||
}
|
||||
|
||||
private TypeNode ResolveType(TypeSyntax fieldType)
|
||||
{
|
||||
return TypeResolver.ResolveType(fieldType, _moduleSignatures);
|
||||
}
|
||||
}
|
||||
|
||||
public enum IdentifierKind
|
||||
{
|
||||
Variable,
|
||||
FunctionParameter
|
||||
}
|
||||
|
||||
public record Identifier(string Name, TypeNode Type, IdentifierKind Kind);
|
||||
|
||||
public class Scope(Scope? parent = null)
|
||||
{
|
||||
private readonly List<Identifier> _variables = [];
|
||||
|
||||
public Identifier? Lookup(string name)
|
||||
{
|
||||
var variable = _variables.FirstOrDefault(x => x.Name == name);
|
||||
if (variable != null)
|
||||
{
|
||||
return variable;
|
||||
}
|
||||
|
||||
return parent?.Lookup(name);
|
||||
}
|
||||
|
||||
public void Declare(Identifier identifier)
|
||||
{
|
||||
_variables.Add(identifier);
|
||||
}
|
||||
|
||||
public Scope SubScope()
|
||||
{
|
||||
return new Scope(this);
|
||||
}
|
||||
}
|
||||
|
||||
public class TypeCheckerException : Exception
|
||||
{
|
||||
public Diagnostic Diagnostic { get; }
|
||||
|
||||
public TypeCheckerException(Diagnostic diagnostic) : base(diagnostic.Message)
|
||||
{
|
||||
Diagnostic = diagnostic;
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user