Files
nub-lang/compiler/TypeChecker.cs
nub31 2b7eb56895 ...
2026-03-01 18:14:54 +01:00

976 lines
41 KiB
C#

using System.Data.Common;
using System.Diagnostics.CodeAnalysis;
using System.Formats.Tar;
namespace Compiler;
/// <summary>
/// Type-checks a single function definition against a <see cref="ModuleGraph"/> and produces the
/// typed AST (<see cref="TypedNodeDefinitionFunc"/>).
/// </summary>
/// <remarks>
/// Checks raise <see cref="CompileException"/> internally; they are converted to
/// <see cref="Diagnostic"/>s in <see cref="CheckFunction(out List{Diagnostic})"/>. Every invalid
/// parameter is reported, but checking of the body stops at its first error.
/// </remarks>
public class TypeChecker
{
    /// <summary>
    /// Type-checks <paramref name="function"/> as a member of <paramref name="currentModule"/>.
    /// </summary>
    /// <param name="fileName">File name attached to every produced diagnostic.</param>
    /// <param name="currentModule">
    /// Module that owns the function. Unqualified global names and types resolve here, and lookups
    /// pass "symbol module == currentModule" as their third argument (assumed to be the
    /// private-visibility flag).
    /// </param>
    /// <param name="function">The untyped function definition.</param>
    /// <param name="moduleGraph">Lookup for modules, named types and global identifiers.</param>
    /// <param name="diagnostics">Receives all diagnostics produced while checking.</param>
    /// <returns>
    /// The typed function, or <c>null</c> if the return type, a parameter, or the body failed to check.
    /// </returns>
    public static TypedNodeDefinitionFunc? CheckFunction(string fileName, string currentModule, NodeDefinitionFunc function, ModuleGraph moduleGraph, out List<Diagnostic> diagnostics)
    {
        return new TypeChecker(fileName, currentModule, function, moduleGraph).CheckFunction(out diagnostics);
    }

    private TypeChecker(string fileName, string currentModule, NodeDefinitionFunc function, ModuleGraph moduleGraph)
    {
        this.fileName = fileName;
        this.currentModule = currentModule;
        this.function = function;
        this.moduleGraph = moduleGraph;
    }

    private readonly string fileName;
    private readonly string currentModule;
    private readonly NodeDefinitionFunc function;
    // Assigned at the start of CheckFunction before any statement is checked.
    private NubType functionReturnType = null!;
    private readonly ModuleGraph moduleGraph;
    private readonly Scope scope = new();

    /// <summary>
    /// Resolves the signature, declares parameters in a fresh scope and checks the body.
    /// </summary>
    private TypedNodeDefinitionFunc? CheckFunction(out List<Diagnostic> diagnostics)
    {
        diagnostics = [];
        var parameters = new List<TypedNodeDefinitionFunc.Param>();
        var invalidParameter = false;
        TypedNodeStatement? body = null;

        if (function.ReturnType is null)
        {
            functionReturnType = NubTypeVoid.Instance;
        }
        else
        {
            try
            {
                functionReturnType = ResolveType(function.ReturnType);
            }
            catch (CompileException e)
            {
                // Without a return type, return statements in the body cannot be checked.
                diagnostics.Add(e.Diagnostic);
                return null;
            }
        }

        using (scope.EnterScope())
        {
            foreach (var parameter in function.Parameters)
            {
                NubType parameterType;
                try
                {
                    parameterType = ResolveType(parameter.Type);
                }
                catch (CompileException e)
                {
                    // Keep going so every bad parameter is reported, not just the first.
                    diagnostics.Add(e.Diagnostic);
                    invalidParameter = true;
                    continue;
                }

                // A repeated parameter name used to crash with ArgumentException (Dictionary.Add).
                if (!scope.TryDeclareIdentifier(parameter.Name.Ident, parameterType))
                {
                    diagnostics.Add(BasicError($"Parameter '{parameter.Name.Ident}' is declared more than once", parameter.Name).Diagnostic);
                    invalidParameter = true;
                    continue;
                }

                parameters.Add(new TypedNodeDefinitionFunc.Param(parameter.Tokens, parameter.Name, parameterType));
            }

            try
            {
                body = CheckStatement(function.Body);
            }
            catch (CompileException e)
            {
                diagnostics.Add(e.Diagnostic);
            }

            if (body == null || invalidParameter)
                return null;

            return new TypedNodeDefinitionFunc(function.Tokens, currentModule, function.Name, parameters, body, functionReturnType);
        }
    }

    /// <summary>Dispatches a statement to its specific checker.</summary>
    private TypedNodeStatement CheckStatement(NodeStatement node)
    {
        return node switch
        {
            NodeStatementAssignment statement => CheckStatementAssignment(statement),
            NodeStatementBlock statement => CheckStatementBlock(statement),
            NodeStatementExpression statement => CheckStatementExpression(statement),
            NodeStatementIf statement => CheckStatementIf(statement),
            NodeStatementReturn statement => CheckStatementReturn(statement),
            NodeStatementVariableDeclaration statement => CheckStatementVariableDeclaration(statement),
            NodeStatementWhile statement => CheckStatementWhile(statement),
            NodeStatementMatch statement => CheckStatementMatch(statement),
            _ => throw new ArgumentOutOfRangeException(nameof(node))
        };
    }

    /// <summary>
    /// Checks <c>target = value</c>. The value is inferred against the target's type and must be
    /// assignable to it.
    /// </summary>
    private TypedNodeStatementAssignment CheckStatementAssignment(NodeStatementAssignment statement)
    {
        var target = CheckExpression(statement.Target, null);
        var value = CheckExpression(statement.Value, target.Type);
        // Previously missing: any value type was accepted for any target.
        if (!value.Type.IsAssignableTo(target.Type))
            throw BasicError($"Type of assigned value ({value.Type}) is not assignable to the type of the target ({target.Type})", value);
        return new TypedNodeStatementAssignment(statement.Tokens, target, value);
    }

    /// <summary>Checks each statement of a block inside its own lexical scope.</summary>
    private TypedNodeStatementBlock CheckStatementBlock(NodeStatementBlock statement)
    {
        using (scope.EnterScope())
        {
            var statements = statement.Statements.Select(CheckStatement).ToList();
            return new TypedNodeStatementBlock(statement.Tokens, statements);
        }
    }

    /// <summary>Only function calls are allowed as expression statements.</summary>
    private TypedNodeStatementFuncCall CheckStatementExpression(NodeStatementExpression statement)
    {
        if (statement.Expression is not NodeExpressionFuncCall funcCall)
            throw BasicError("Expected statement or function call", statement);
        var expr = CheckExpressionFuncCall(funcCall, null);
        return new TypedNodeStatementFuncCall(expr.Tokens, expr.Target, expr.Parameters);
    }

    /// <summary>Checks an if statement; the condition must be assignable to bool.</summary>
    private TypedNodeStatementIf CheckStatementIf(NodeStatementIf statement)
    {
        var condition = CheckExpression(statement.Condition, NubTypeBool.Instance);
        if (!condition.Type.IsAssignableTo(NubTypeBool.Instance))
            throw BasicError("Condition part of if statement must be a boolean", condition);
        var thenBlock = CheckStatement(statement.ThenBlock);
        var elseBlock = statement.ElseBlock == null ? null : CheckStatement(statement.ElseBlock);
        return new TypedNodeStatementIf(statement.Tokens, condition, thenBlock, elseBlock);
    }

    /// <summary>
    /// Checks a return statement against the enclosing function's return type. A bare
    /// <c>return</c> is only valid in void functions.
    /// </summary>
    private TypedNodeStatementReturn CheckStatementReturn(NodeStatementReturn statement)
    {
        if (statement.Value == null)
        {
            if (functionReturnType is not NubTypeVoid)
                throw BasicError($"Missing return value. Expected '{functionReturnType}'", statement);
            return new TypedNodeStatementReturn(statement.Tokens, null);
        }

        var value = CheckExpression(statement.Value, functionReturnType);
        if (!value.Type.IsAssignableTo(functionReturnType))
            throw BasicError($"Type of returned value ({value.Type}) is not assignable to the return type of the function ({functionReturnType})", value);
        return new TypedNodeStatementReturn(statement.Tokens, value);
    }

    /// <summary>
    /// Checks a variable declaration. If no type is written, the variable takes the value's
    /// inferred type. The name is declared in the innermost scope.
    /// </summary>
    private TypedNodeStatementVariableDeclaration CheckStatementVariableDeclaration(NodeStatementVariableDeclaration statement)
    {
        NubType? type = null;
        if (statement.Type != null)
            type = ResolveType(statement.Type);
        var value = CheckExpression(statement.Value, type);
        if (type is not null && !value.Type.IsAssignableTo(type))
            throw BasicError($"Type of assigned value ({value.Type}) does not match the type of the variable ({type})", value);
        type ??= value.Type;
        DeclareLocal(statement.Name, type);
        return new TypedNodeStatementVariableDeclaration(statement.Tokens, statement.Name, type, value);
    }

    /// <summary>Checks a while loop; the condition must be assignable to bool.</summary>
    private TypedNodeStatementWhile CheckStatementWhile(NodeStatementWhile statement)
    {
        var condition = CheckExpression(statement.Condition, NubTypeBool.Instance);
        if (!condition.Type.IsAssignableTo(NubTypeBool.Instance))
            throw BasicError("Condition part of while statement must be a boolean", condition);
        var body = CheckStatement(statement.Body);
        return new TypedNodeStatementWhile(statement.Tokens, condition, body);
    }

    /// <summary>
    /// Checks a match over an enum value. Every variant must be handled exactly once, and a case
    /// may bind a variable only when its variant carries a payload type.
    /// </summary>
    private TypedNodeStatementMatch CheckStatementMatch(NodeStatementMatch statement)
    {
        var target = CheckExpression(statement.Target, null);
        if (target.Type is not NubTypeEnum enumType)
            throw BasicError("A match statement can only be used on enum types", target);
        if (!moduleGraph.TryResolveType(enumType.Module, enumType.Name, enumType.Module == currentModule, out var info))
            throw BasicError($"Type '{enumType}' not found", target);
        if (info is not Module.TypeInfoEnum enumInfo)
            throw BasicError($"Type '{enumType}' is not an enum", target);

        var coveredVariants = new HashSet<string>();
        var cases = new List<TypedNodeStatementMatch.Case>();
        foreach (var @case in statement.Cases)
        {
            var variant = enumInfo.Variants.FirstOrDefault(x => x.Name == @case.Variant.Ident);
            if (variant == null)
                throw BasicError($"Enum type '{enumType}' does not have a variant named '{@case.Variant.Ident}'", @case.Variant);
            // A repeated case could never be reached; report it instead of silently accepting it.
            if (!coveredVariants.Add(@case.Variant.Ident))
                throw BasicError($"Variant '{@case.Variant.Ident}' is already handled by a previous case", @case.Variant);

            using (scope.EnterScope())
            {
                if (@case.VariableName != null)
                {
                    if (variant.Type is null)
                        throw BasicError("Cannot capture variable for enum variant without type", @case.VariableName);
                    DeclareLocal(@case.VariableName, variant.Type);
                }

                var body = CheckStatement(@case.Body);
                cases.Add(new TypedNodeStatementMatch.Case(@case.Tokens, @case.Variant, @case.VariableName, body));
            }
        }

        // Report the missing variants in declaration order.
        var uncoveredCases = enumInfo.Variants.Select(x => x.Name).Where(name => !coveredVariants.Contains(name)).ToList();
        if (uncoveredCases.Count > 0)
            throw BasicError($"Match statement does not cover the following cases: {string.Join(", ", uncoveredCases)}", statement);
        return new TypedNodeStatementMatch(statement.Tokens, target, cases);
    }

    /// <summary>
    /// Dispatches an expression to its specific checker. <paramref name="expectedType"/> is an
    /// optional inference hint; callers still verify assignability themselves.
    /// </summary>
    private TypedNodeExpression CheckExpression(NodeExpression node, NubType? expectedType)
    {
        return node switch
        {
            NodeExpressionBinary expression => CheckExpressionBinary(expression, expectedType),
            NodeExpressionUnary expression => CheckExpressionUnary(expression, expectedType),
            NodeExpressionBoolLiteral expression => CheckExpressionBoolLiteral(expression, expectedType),
            NodeExpressionIdent expression => CheckExpressionIdent(expression, expectedType),
            NodeExpressionIntLiteral expression => CheckExpressionIntLiteral(expression, expectedType),
            NodeExpressionMemberAccess expression => CheckExpressionMemberAccess(expression, expectedType),
            NodeExpressionFuncCall expression => CheckExpressionFuncCall(expression, expectedType),
            NodeExpressionStringLiteral expression => CheckExpressionStringLiteral(expression, expectedType),
            NodeExpressionStructLiteral expression => CheckExpressionStructLiteral(expression, expectedType),
            NodeExpressionNewNamedType expression => CheckExpressionNewNamedType(expression, expectedType),
            _ => throw new ArgumentOutOfRangeException(nameof(node))
        };
    }

    /// <summary>
    /// Checks a binary operation. Both operands are inferred without a hint, so an int literal
    /// operand defaults to i32. Arithmetic and shifts take the left operand's type. Comparisons and
    /// logical operators produce bool. String + string is concatenation.
    /// </summary>
    private TypedNodeExpressionBinary CheckExpressionBinary(NodeExpressionBinary expression, NubType? expectedType)
    {
        // todo(nub31): Add proper inference here
        var left = CheckExpression(expression.Left, null);
        var right = CheckExpression(expression.Right, null);
        NubType type;
        switch (expression.Operation)
        {
            case NodeExpressionBinary.Op.Add:
            {
                if (left.Type is NubTypeString)
                {
                    if (right.Type is not NubTypeString)
                        throw BasicError("Right hand side of string concatination operator must be a string", right);
                    return new TypedNodeExpressionBinary(expression.Tokens, NubTypeString.Instance, left, CheckExpressionBinaryOperation(expression.Operation), right);
                }
                if (left.Type is not NubTypeSInt and not NubTypeUInt)
                    throw BasicError($"Unsupported type for left hand side arithmetic operation: {left.Type}", left);
                if (right.Type is not NubTypeSInt and not NubTypeUInt)
                    throw BasicError($"Unsupported type for right hand side arithmetic operation: {right.Type}", right);
                type = left.Type;
                break;
            }
            case NodeExpressionBinary.Op.Subtract:
            case NodeExpressionBinary.Op.Multiply:
            case NodeExpressionBinary.Op.Divide:
            case NodeExpressionBinary.Op.Modulo:
            {
                if (left.Type is not NubTypeSInt and not NubTypeUInt)
                    throw BasicError($"Unsupported type for left hand side arithmetic operation: {left.Type}", left);
                if (right.Type is not NubTypeSInt and not NubTypeUInt)
                    throw BasicError($"Unsupported type for right hand side arithmetic operation: {right.Type}", right);
                type = left.Type;
                break;
            }
            case NodeExpressionBinary.Op.LeftShift:
            case NodeExpressionBinary.Op.RightShift:
            {
                if (left.Type is not NubTypeUInt)
                    throw BasicError($"Unsupported type for left hand side of left/right shift operation: {left.Type}", left);
                if (right.Type is not NubTypeUInt)
                    throw BasicError($"Unsupported type for right hand side of left/right shift operation: {right.Type}", right);
                type = left.Type;
                break;
            }
            case NodeExpressionBinary.Op.Equal:
            case NodeExpressionBinary.Op.NotEqual:
            case NodeExpressionBinary.Op.LessThan:
            case NodeExpressionBinary.Op.LessThanOrEqual:
            case NodeExpressionBinary.Op.GreaterThan:
            case NodeExpressionBinary.Op.GreaterThanOrEqual:
            {
                if (left.Type is not NubTypeSInt and not NubTypeUInt)
                    throw BasicError($"Unsupported type for left hand side of comparison: {left.Type}", left);
                if (right.Type is not NubTypeSInt and not NubTypeUInt)
                    throw BasicError($"Unsupported type for right hand side of comparison: {right.Type}", right);
                type = NubTypeBool.Instance;
                break;
            }
            case NodeExpressionBinary.Op.LogicalAnd:
            case NodeExpressionBinary.Op.LogicalOr:
            {
                if (left.Type is not NubTypeBool)
                    throw BasicError($"Unsupported type for left hand side of logical operation: {left.Type}", left);
                if (right.Type is not NubTypeBool)
                    throw BasicError($"Unsupported type for right hand side of logical operation: {right.Type}", right);
                type = NubTypeBool.Instance;
                break;
            }
            default:
                throw new ArgumentOutOfRangeException();
        }
        return new TypedNodeExpressionBinary(expression.Tokens, type, left, CheckExpressionBinaryOperation(expression.Operation), right);
    }

    /// <summary>Maps an untyped binary operator onto its typed-AST counterpart.</summary>
    private static TypedNodeExpressionBinary.Op CheckExpressionBinaryOperation(NodeExpressionBinary.Op op)
    {
        return op switch
        {
            NodeExpressionBinary.Op.Add => TypedNodeExpressionBinary.Op.Add,
            NodeExpressionBinary.Op.Subtract => TypedNodeExpressionBinary.Op.Subtract,
            NodeExpressionBinary.Op.Multiply => TypedNodeExpressionBinary.Op.Multiply,
            NodeExpressionBinary.Op.Divide => TypedNodeExpressionBinary.Op.Divide,
            NodeExpressionBinary.Op.Modulo => TypedNodeExpressionBinary.Op.Modulo,
            NodeExpressionBinary.Op.Equal => TypedNodeExpressionBinary.Op.Equal,
            NodeExpressionBinary.Op.NotEqual => TypedNodeExpressionBinary.Op.NotEqual,
            NodeExpressionBinary.Op.LessThan => TypedNodeExpressionBinary.Op.LessThan,
            NodeExpressionBinary.Op.LessThanOrEqual => TypedNodeExpressionBinary.Op.LessThanOrEqual,
            NodeExpressionBinary.Op.GreaterThan => TypedNodeExpressionBinary.Op.GreaterThan,
            NodeExpressionBinary.Op.GreaterThanOrEqual => TypedNodeExpressionBinary.Op.GreaterThanOrEqual,
            NodeExpressionBinary.Op.LeftShift => TypedNodeExpressionBinary.Op.LeftShift,
            NodeExpressionBinary.Op.RightShift => TypedNodeExpressionBinary.Op.RightShift,
            NodeExpressionBinary.Op.LogicalAnd => TypedNodeExpressionBinary.Op.LogicalAnd,
            NodeExpressionBinary.Op.LogicalOr => TypedNodeExpressionBinary.Op.LogicalOr,
            _ => throw new ArgumentOutOfRangeException(nameof(op), op, null)
        };
    }

    /// <summary>
    /// Checks a unary operation. Negation keeps the integer operand's type; inversion requires and
    /// yields bool.
    /// </summary>
    private TypedNodeExpressionUnary CheckExpressionUnary(NodeExpressionUnary expression, NubType? expectedType)
    {
        // todo(nub31): Add proper inference here
        var target = CheckExpression(expression.Target, null);
        NubType type;
        switch (expression.Operation)
        {
            case NodeExpressionUnary.Op.Negate:
            {
                if (target.Type is not NubTypeSInt and not NubTypeUInt)
                    throw BasicError($"Unsupported type for negation: {target.Type}", target);
                type = target.Type;
                break;
            }
            case NodeExpressionUnary.Op.Invert:
            {
                if (target.Type is not NubTypeBool)
                    throw BasicError($"Unsupported type for inversion: {target.Type}", target);
                type = NubTypeBool.Instance;
                break;
            }
            default:
                throw new ArgumentOutOfRangeException();
        }
        return new TypedNodeExpressionUnary(expression.Tokens, type, target, CheckExpressionUnaryOperation(expression.Operation));
    }

    /// <summary>Maps an untyped unary operator onto its typed-AST counterpart.</summary>
    private static TypedNodeExpressionUnary.Op CheckExpressionUnaryOperation(NodeExpressionUnary.Op op)
    {
        return op switch
        {
            NodeExpressionUnary.Op.Negate => TypedNodeExpressionUnary.Op.Negate,
            NodeExpressionUnary.Op.Invert => TypedNodeExpressionUnary.Op.Invert,
            _ => throw new ArgumentOutOfRangeException(nameof(op), op, null)
        };
    }

    private TypedNodeExpressionBoolLiteral CheckExpressionBoolLiteral(NodeExpressionBoolLiteral expression, NubType? expectedType)
    {
        return new TypedNodeExpressionBoolLiteral(expression.Tokens, NubTypeBool.Instance, expression.Value);
    }

    /// <summary>
    /// Resolves an identifier. <c>name</c> is looked up as a local first, then as a global of the
    /// current module. <c>module::name</c> is looked up as a global of <c>module</c>.
    /// </summary>
    private TypedNodeExpression CheckExpressionIdent(NodeExpressionIdent expression, NubType? expectedType)
    {
        if (expression.Sections.Count == 1)
        {
            var name = expression.Sections[0].Ident;
            var localType = scope.GetIdentifierType(name);
            if (localType is not null)
                return new TypedNodeExpressionLocalIdent(expression.Tokens, localType, name);
            if (moduleGraph.TryResolveIdentifier(currentModule, name, true, out var ident))
                return new TypedNodeExpressionGlobalIdent(expression.Tokens, ident.Type, currentModule, name);
        }
        else if (expression.Sections.Count == 2)
        {
            var module = expression.Sections[0].Ident;
            var name = expression.Sections[1].Ident;
            // Pass "same module" like every type lookup in this class (assumed to gate private
            // symbols); this used to be a hard-coded `true`.
            if (moduleGraph.TryResolveIdentifier(module, name, module == currentModule, out var ident))
                return new TypedNodeExpressionGlobalIdent(expression.Tokens, ident.Type, module, name);
        }
        throw BasicError($"Unknown identifier '{string.Join("::", expression.Sections.Select(x => x.Ident))}'", expression);
    }

    /// <summary>
    /// Types an int literal as the expected integer type when one is given, otherwise as i32.
    /// </summary>
    private TypedNodeExpressionIntLiteral CheckExpressionIntLiteral(NodeExpressionIntLiteral expression, NubType? expectedType)
    {
        NubType? type = null;
        if (expectedType is NubTypeSInt or NubTypeUInt)
            type = expectedType;
        type ??= NubTypeSInt.Get(32);
        return new TypedNodeExpressionIntLiteral(expression.Tokens, type, expression.Value);
    }

    /// <summary>
    /// Checks <c>target.name</c>. This covers string <c>length</c> (u64) and <c>ptr</c> (pointer
    /// to u8), plus fields of named and anonymous structs.
    /// </summary>
    private TypedNodeExpression CheckExpressionMemberAccess(NodeExpressionMemberAccess expression, NubType? expectedType)
    {
        var target = CheckExpression(expression.Target, null);
        switch (target.Type)
        {
            case NubTypeString:
            {
                switch (expression.Name.Ident)
                {
                    case "length":
                        return new TypedNodeExpressionStringLength(expression.Tokens, NubTypeUInt.Get(64), target);
                    case "ptr":
                        return new TypedNodeExpressionStringPointer(expression.Tokens, NubTypePointer.Get(NubTypeUInt.Get(8)), target);
                    default:
                        throw BasicError($"'{expression.Name.Ident}' is not a member of type string", expression.Name);
                }
            }
            case NubTypeStruct structType:
            {
                if (!moduleGraph.TryResolveModule(structType.Module, out var module))
                    throw BasicError($"Module '{structType.Module}' not found", expression.Target);
                if (!module.TryResolveType(structType.Name, currentModule == structType.Module, out var typeDef))
                    throw BasicError($"Type '{structType.Name}' not found in module '{structType.Module}'", expression.Target);
                if (typeDef is not Module.TypeInfoStruct structDef)
                    throw BasicError($"Type '{target.Type}' is not a struct", expression.Target);
                var field = structDef.Fields.FirstOrDefault(x => x.Name == expression.Name.Ident);
                if (field == null)
                    throw BasicError($"Struct '{target.Type}' does not have a field matching the name '{expression.Name.Ident}'", expression.Name);
                return new TypedNodeExpressionStructMemberAccess(expression.Tokens, field.Type, target, expression.Name);
            }
            case NubTypeAnonymousStruct anonymousStructType:
            {
                var field = anonymousStructType.Fields.FirstOrDefault(x => x.Name == expression.Name.Ident);
                if (field == null)
                    throw BasicError($"Struct '{target.Type}' does not have a field matching the name '{expression.Name.Ident}'", expression.Name);
                return new TypedNodeExpressionStructMemberAccess(expression.Tokens, field.Type, target, expression.Name);
            }
            default:
                throw BasicError($"{target.Type} has no member '{expression.Name.Ident}'", target);
        }
    }

    /// <summary>
    /// Checks a call. The target must have a function type, the argument count must match, and
    /// each argument must be assignable to its parameter type.
    /// </summary>
    private TypedNodeExpressionFuncCall CheckExpressionFuncCall(NodeExpressionFuncCall expression, NubType? expectedType)
    {
        var target = CheckExpression(expression.Target, null);
        if (target.Type is not NubTypeFunc funcType)
            throw BasicError("Expected a function type", target);
        if (funcType.Parameters.Count != expression.Parameters.Count)
            throw BasicError($"Expected {funcType.Parameters.Count} parameters but got {expression.Parameters.Count}", expression);
        var parameters = new List<TypedNodeExpression>();
        for (int i = 0; i < expression.Parameters.Count; i++)
        {
            var parameter = CheckExpression(expression.Parameters[i], funcType.Parameters[i]);
            if (!parameter.Type.IsAssignableTo(funcType.Parameters[i]))
                throw BasicError($"Parameter {i + 1} ({parameter.Type}) is not assignable to '{funcType.Parameters[i]}'", parameter);
            parameters.Add(parameter);
        }
        return new TypedNodeExpressionFuncCall(expression.Tokens, funcType.ReturnType, target, parameters);
    }

    private TypedNodeExpressionStringLiteral CheckExpressionStringLiteral(NodeExpressionStringLiteral expression, NubType? expectedType)
    {
        return new TypedNodeExpressionStringLiteral(expression.Tokens, NubTypeString.Instance, expression.Value);
    }

    /// <summary>
    /// Checks a struct literal against the expected named or anonymous struct type. Without such a
    /// hint, an anonymous struct type is inferred from the initializers. A field may be
    /// initialized at most once.
    /// </summary>
    private TypedNodeExpressionStructLiteral CheckExpressionStructLiteral(NodeExpressionStructLiteral expression, NubType? expectedType)
    {
        // Duplicate initializers used to be accepted; without a hint they produced an anonymous
        // struct type with duplicate fields.
        var seenFields = new HashSet<string>();
        foreach (var initializer in expression.Initializers)
        {
            if (!seenFields.Add(initializer.Name.Ident))
                throw BasicError($"Field '{initializer.Name.Ident}' is initialized more than once", initializer.Name);
        }

        if (expectedType is NubTypeStruct structType)
        {
            if (!moduleGraph.TryResolveType(structType.Module, structType.Name, structType.Module == currentModule, out var info))
                throw BasicError($"Type '{structType}' struct literal not found", expression);
            if (info is not Module.TypeInfoStruct structInfo)
                throw BasicError($"Type '{structType}' is not a struct", expression);

            var initializers = CheckStructLiteralInitializers(
                expression,
                name => structInfo.Fields.FirstOrDefault(x => x.Name == name)?.Type,
                name => $"Field '{name}' does not exist on struct '{structType.Module}::{structType.Name}'");
            return new TypedNodeExpressionStructLiteral(expression.Tokens, structType, initializers);
        }

        if (expectedType is NubTypeAnonymousStruct anonymousStructType)
        {
            var initializers = CheckStructLiteralInitializers(
                expression,
                name => anonymousStructType.Fields.FirstOrDefault(x => x.Name == name)?.Type,
                name => $"Field '{name}' does not exist on anonymous struct '{anonymousStructType}'");
            return new TypedNodeExpressionStructLiteral(expression.Tokens, anonymousStructType, initializers);
        }

        var inferredInitializers = new List<TypedNodeExpressionStructLiteral.Initializer>();
        foreach (var initializer in expression.Initializers)
        {
            var value = CheckExpression(initializer.Value, null);
            inferredInitializers.Add(new TypedNodeExpressionStructLiteral.Initializer(initializer.Tokens, initializer.Name, value));
        }
        var type = NubTypeAnonymousStruct.Get(inferredInitializers.Select(x => new NubTypeAnonymousStruct.Field(x.Name.Ident, x.Value.Type)).ToList());
        return new TypedNodeExpressionStructLiteral(expression.Tokens, type, inferredInitializers);
    }

    /// <summary>
    /// Checks each initializer against a known field type. <paramref name="getFieldType"/> returns
    /// null for unknown fields, and <paramref name="unknownFieldMessage"/> builds that error.
    /// </summary>
    private List<TypedNodeExpressionStructLiteral.Initializer> CheckStructLiteralInitializers(
        NodeExpressionStructLiteral expression,
        Func<string, NubType?> getFieldType,
        Func<string, string> unknownFieldMessage)
    {
        var initializers = new List<TypedNodeExpressionStructLiteral.Initializer>();
        foreach (var initializer in expression.Initializers)
        {
            var fieldName = initializer.Name.Ident;
            var fieldType = getFieldType(fieldName);
            if (fieldType is null)
                throw BasicError(unknownFieldMessage(fieldName), initializer.Name);
            var value = CheckExpression(initializer.Value, fieldType);
            if (!value.Type.IsAssignableTo(fieldType))
                throw BasicError($"Type of assignment ({value.Type}) does not match expected type of field '{fieldName}' ({fieldType})", initializer.Name);
            initializers.Add(new TypedNodeExpressionStructLiteral.Initializer(initializer.Tokens, initializer.Name, value));
        }
        return initializers;
    }

    /// <summary>
    /// Checks <c>new Type value</c> for a struct type or an enum variant. The value must be
    /// assignable to the struct type, or to the variant's payload type when the variant has one.
    /// </summary>
    private TypedNodeExpressionNewNamedType CheckExpressionNewNamedType(NodeExpressionNewNamedType expression, NubType? expectedType)
    {
        var type = ResolveType(expression.Type);
        switch (type)
        {
            case NubTypeStruct structType:
            {
                var value = CheckExpression(expression.Value, structType);
                // Previously unchecked: any value was accepted for any struct type.
                if (!value.Type.IsAssignableTo(structType))
                    throw BasicError($"Type of value ({value.Type}) is not assignable to '{structType}'", value);
                return new TypedNodeExpressionNewNamedType(expression.Tokens, structType, value);
            }
            case NubTypeEnumVariant enumVariantType:
            {
                if (!moduleGraph.TryResolveType(enumVariantType.EnumType.Module, enumVariantType.EnumType.Name, enumVariantType.EnumType.Module == currentModule, out var info))
                    throw BasicError($"Type '{enumVariantType.EnumType}' not found", expression.Type);
                if (info is not Module.TypeInfoEnum enumInfo)
                    throw BasicError($"Type '{enumVariantType.EnumType}' is not an enum", expression.Type);
                var variant = enumInfo.Variants.FirstOrDefault(x => x.Name == enumVariantType.Variant);
                if (variant == null)
                    throw BasicError($"Enum type '{enumVariantType.EnumType}' does not have a variant named '{enumVariantType.Variant}'", expression.Type);
                var value = CheckExpression(expression.Value, variant.Type);
                // NOTE(review): a value given for a payload-less variant is still accepted
                // unchecked, as before — confirm the intended semantics.
                if (variant.Type is not null && !value.Type.IsAssignableTo(variant.Type))
                    throw BasicError($"Type of value ({value.Type}) is not assignable to the type of variant '{enumVariantType.Variant}' ({variant.Type})", value);
                return new TypedNodeExpressionNewNamedType(expression.Tokens, enumVariantType, value);
            }
            default:
            {
                throw BasicError($"'{type}' is not a valid type for the new operator", expression);
            }
        }
    }

    /// <summary>Converts a syntactic type into a <see cref="NubType"/>.</summary>
    private NubType ResolveType(NodeType node)
    {
        return node switch
        {
            NodeTypeBool => NubTypeBool.Instance,
            NodeTypeNamed type => ResolveNamedType(type),
            NodeTypeAnonymousStruct type => NubTypeAnonymousStruct.Get(type.Fields.Select(x => new NubTypeAnonymousStruct.Field(x.Name.Ident, ResolveType(x.Type))).ToList()),
            NodeTypeFunc type => NubTypeFunc.Get(type.Parameters.Select(ResolveType).ToList(), ResolveType(type.ReturnType)),
            NodeTypePointer type => NubTypePointer.Get(ResolveType(type.To)),
            NodeTypeSInt type => NubTypeSInt.Get(type.Width),
            NodeTypeUInt type => NubTypeUInt.Get(type.Width),
            NodeTypeString => NubTypeString.Instance,
            NodeTypeVoid => NubTypeVoid.Instance,
            _ => throw new ArgumentOutOfRangeException(nameof(node))
        };
    }

    /// <summary>
    /// Resolves a <c>::</c>-separated type name. One section is a type in the current module. Two
    /// sections are either <c>Enum::Variant</c> in the current module or <c>module::Type</c>.
    /// Three sections are <c>module::Enum::Variant</c>.
    /// </summary>
    private NubType ResolveNamedType(NodeTypeNamed type)
    {
        return type.Sections.Count switch
        {
            3 => ResolveThreePartType(type.Sections[0], type.Sections[1], type.Sections[2]),
            2 => ResolveTwoPartType(type.Sections[0], type.Sections[1]),
            1 => ResolveOnePartType(type.Sections[0]),
            _ => throw BasicError("Invalid type name", type)
        };
    }

    private NubType ResolveThreePartType(TokenIdent first, TokenIdent second, TokenIdent third)
    {
        if (TryResolveEnumVariant(first.Ident, second.Ident, third.Ident, out var variantType))
            return variantType;
        // The enum is first::second; the variant must not be repeated in its name.
        throw BasicError($"Enum '{first.Ident}::{second.Ident}' does not have a variant named '{third.Ident}'", third);
    }

    private NubType ResolveTwoPartType(TokenIdent first, TokenIdent second)
    {
        if (TryResolveEnumVariant(currentModule, first.Ident, second.Ident, out var variantType))
            return variantType;
        var typeInfo = ResolveModuleTypeInfo(ResolveModule(first), second);
        return typeInfo switch
        {
            Module.TypeInfoStruct => NubTypeStruct.Get(first.Ident, second.Ident),
            Module.TypeInfoEnum => NubTypeEnum.Get(first.Ident, second.Ident),
            _ => throw new ArgumentOutOfRangeException(nameof(typeInfo))
        };
    }

    private NubType ResolveOnePartType(TokenIdent name)
    {
        if (!moduleGraph.TryResolveModule(currentModule, out var module))
            throw BasicError($"Module '{currentModule}' not found", name);
        var typeInfo = ResolveModuleTypeInfo(module, name);
        return typeInfo switch
        {
            Module.TypeInfoStruct => NubTypeStruct.Get(currentModule, name.Ident),
            Module.TypeInfoEnum => NubTypeEnum.Get(currentModule, name.Ident),
            _ => throw new ArgumentOutOfRangeException(nameof(typeInfo))
        };
    }

    private Module ResolveModule(TokenIdent name)
    {
        if (!moduleGraph.TryResolveModule(name.Ident, out var module))
            throw BasicError($"Module '{name.Ident}' not found", name);
        return module;
    }

    private Module.TypeInfo ResolveModuleTypeInfo(Module module, TokenIdent name)
    {
        if (!module.TryResolveType(name.Ident, currentModule == module.Name, out var type))
            throw BasicError($"Named type '{module.Name}::{name.Ident}' not found", name);
        return type;
    }

    /// <summary>
    /// Tries to resolve <c>moduleName::enumName::variantName</c> to an enum-variant type. Returns
    /// false if the module, enum, or variant is not found.
    /// </summary>
    private bool TryResolveEnumVariant(string moduleName, string enumName, string variantName, [NotNullWhen(true)] out NubType? result)
    {
        result = null;
        if (!moduleGraph.TryResolveModule(moduleName, out var module))
            return false;
        // Same "caller is in this module" flag as the other type lookups; this used to be a
        // hard-coded `true`.
        if (!module.TryResolveType(enumName, moduleName == currentModule, out var type))
            return false;
        if (type is not Module.TypeInfoEnum enumInfo)
            return false;
        var variant = enumInfo.Variants.FirstOrDefault(v => v.Name == variantName);
        if (variant == null)
            return false;
        result = NubTypeEnumVariant.Get(NubTypeEnum.Get(moduleName, enumName), variantName);
        return true;
    }

    /// <summary>
    /// Declares a local in the innermost scope. A name already declared in that same scope is
    /// reported as an error instead of crashing; shadowing an outer scope is allowed.
    /// </summary>
    private void DeclareLocal(TokenIdent name, NubType type)
    {
        if (!scope.TryDeclareIdentifier(name.Ident, type))
            throw BasicError($"Identifier '{name.Ident}' is already declared in this scope", name);
    }

    private CompileException BasicError(string message, TokenIdent ident)
    {
        return new CompileException(Diagnostic.Error(message).At(fileName, ident).Build());
    }

    private CompileException BasicError(string message, Node node)
    {
        return new CompileException(Diagnostic.Error(message).At(fileName, node).Build());
    }

    private CompileException BasicError(string message, TypedNode node)
    {
        return new CompileException(Diagnostic.Error(message).At(fileName, node).Build());
    }

    /// <summary>Stack of lexical scopes mapping local names to types; the innermost scope is on top.</summary>
    private sealed class Scope
    {
        private readonly Stack<Dictionary<string, NubType>> scopes = new();

        /// <summary>Pushes a new scope; disposing the returned guard pops it.</summary>
        public IDisposable EnterScope()
        {
            scopes.Push([]);
            return new ScopeGuard(this);
        }

        /// <summary>
        /// Declares <paramref name="name"/> in the innermost scope. Returns false, without
        /// overwriting, if that scope already declares it.
        /// </summary>
        public bool TryDeclareIdentifier(string name, NubType type)
        {
            return scopes.Peek().TryAdd(name, type);
        }

        /// <summary>Looks a name up from the innermost scope outward; null if it is not a local.</summary>
        public NubType? GetIdentifierType(string name)
        {
            // Stack<T> enumerates from the top, i.e. innermost scope first.
            foreach (var frame in scopes)
            {
                if (frame.TryGetValue(name, out var type))
                    return type;
            }
            return null;
        }

        private void ExitScope()
        {
            scopes.Pop();
        }

        private sealed class ScopeGuard(Scope owner) : IDisposable
        {
            public void Dispose()
            {
                owner.ExitScope();
            }
        }
    }
}
/// <summary>Root of the typed AST produced by <see cref="TypeChecker"/>.</summary>
public abstract class TypedNode
{
    protected TypedNode(List<Token> tokens)
    {
        Tokens = tokens;
    }

    /// <summary>The source tokens this node was created from.</summary>
    public List<Token> Tokens { get; }
}
/// <summary>A top-level typed definition, tagged with the module it belongs to.</summary>
public abstract class TypedNodeDefinition : TypedNode
{
    protected TypedNodeDefinition(List<Token> tokens, string module) : base(tokens)
    {
        Module = module;
    }

    /// <summary>Name of the owning module.</summary>
    public string Module { get; }
}
/// <summary>A fully type-checked function definition.</summary>
public class TypedNodeDefinitionFunc : TypedNodeDefinition
{
    public TypedNodeDefinitionFunc(List<Token> tokens, string module, TokenIdent name, List<Param> parameters, TypedNodeStatement body, NubType returnType)
        : base(tokens, module)
    {
        Name = name;
        Parameters = parameters;
        Body = body;
        ReturnType = returnType;
    }

    public TokenIdent Name { get; }
    public List<Param> Parameters { get; }
    public TypedNodeStatement Body { get; }
    public NubType ReturnType { get; }

    /// <summary>Builds the function type from the parameter types and the return type.</summary>
    public NubTypeFunc GetNubType() => NubTypeFunc.Get(Parameters.Select(p => p.Type).ToList(), ReturnType);

    /// <summary>A parameter with its resolved type.</summary>
    public class Param : TypedNode
    {
        public Param(List<Token> tokens, TokenIdent name, NubType type) : base(tokens)
        {
            Name = name;
            Type = type;
        }

        public TokenIdent Name { get; }
        public NubType Type { get; }
    }
}
/// <summary>Base class for typed statements.</summary>
public abstract class TypedNodeStatement : TypedNode
{
    protected TypedNodeStatement(List<Token> tokens) : base(tokens)
    {
    }
}
/// <summary>A sequence of statements checked in its own lexical scope.</summary>
public class TypedNodeStatementBlock : TypedNodeStatement
{
    public TypedNodeStatementBlock(List<Token> tokens, List<TypedNodeStatement> statements) : base(tokens)
    {
        Statements = statements;
    }

    public List<TypedNodeStatement> Statements { get; }
}
/// <summary>A function call used as a statement; its result is discarded.</summary>
public class TypedNodeStatementFuncCall : TypedNodeStatement
{
    public TypedNodeStatementFuncCall(List<Token> tokens, TypedNodeExpression target, List<TypedNodeExpression> parameters) : base(tokens)
    {
        Target = target;
        Parameters = parameters;
    }

    public TypedNodeExpression Target { get; }
    public List<TypedNodeExpression> Parameters { get; }
}
/// <summary>A return statement; <see cref="Value"/> is null for a bare return.</summary>
public class TypedNodeStatementReturn : TypedNodeStatement
{
    public TypedNodeStatementReturn(List<Token> tokens, TypedNodeExpression? value) : base(tokens)
    {
        Value = value;
    }

    public TypedNodeExpression? Value { get; }
}
/// <summary>A local variable declaration with its resolved or inferred type.</summary>
public class TypedNodeStatementVariableDeclaration : TypedNodeStatement
{
    public TypedNodeStatementVariableDeclaration(List<Token> tokens, TokenIdent name, NubType type, TypedNodeExpression value) : base(tokens)
    {
        Name = name;
        Type = type;
        Value = value;
    }

    public TokenIdent Name { get; }
    public NubType Type { get; }
    public TypedNodeExpression Value { get; }
}
/// <summary>An assignment of <see cref="Value"/> to <see cref="Target"/>.</summary>
public class TypedNodeStatementAssignment : TypedNodeStatement
{
    public TypedNodeStatementAssignment(List<Token> tokens, TypedNodeExpression target, TypedNodeExpression value) : base(tokens)
    {
        Target = target;
        Value = value;
    }

    public TypedNodeExpression Target { get; }
    public TypedNodeExpression Value { get; }
}
/// <summary>An if statement; <see cref="ElseBlock"/> is null when there is no else branch.</summary>
public class TypedNodeStatementIf : TypedNodeStatement
{
    public TypedNodeStatementIf(List<Token> tokens, TypedNodeExpression condition, TypedNodeStatement thenBlock, TypedNodeStatement? elseBlock) : base(tokens)
    {
        Condition = condition;
        ThenBlock = thenBlock;
        ElseBlock = elseBlock;
    }

    public TypedNodeExpression Condition { get; }
    public TypedNodeStatement ThenBlock { get; }
    public TypedNodeStatement? ElseBlock { get; }
}
/// <summary>A while loop.</summary>
public class TypedNodeStatementWhile : TypedNodeStatement
{
    public TypedNodeStatementWhile(List<Token> tokens, TypedNodeExpression condition, TypedNodeStatement body) : base(tokens)
    {
        Condition = condition;
        Body = body;
    }

    public TypedNodeExpression Condition { get; }
    public TypedNodeStatement Body { get; }
}
/// <summary>A match over an enum-typed <see cref="Target"/>, with one case per variant.</summary>
public class TypedNodeStatementMatch : TypedNodeStatement
{
    public TypedNodeStatementMatch(List<Token> tokens, TypedNodeExpression target, List<Case> cases) : base(tokens)
    {
        Target = target;
        Cases = cases;
    }

    public TypedNodeExpression Target { get; }
    public List<Case> Cases { get; }

    /// <summary>One arm of the match; <see cref="VariableName"/> optionally binds the variant payload.</summary>
    // NOTE(review): derives from the untyped Node rather than TypedNode, unlike the other typed
    // nodes — confirm whether that is intentional.
    public class Case : Node
    {
        public Case(List<Token> tokens, TokenIdent type, TokenIdent? variableName, TypedNodeStatement body) : base(tokens)
        {
            Variant = type;
            VariableName = variableName;
            Body = body;
        }

        public TokenIdent Variant { get; }
        public TokenIdent? VariableName { get; }
        public TypedNodeStatement Body { get; }
    }
}
/// <summary>Base class for typed expressions; <see cref="Type"/> is the type assigned by the checker.</summary>
public abstract class TypedNodeExpression : TypedNode
{
    protected TypedNodeExpression(List<Token> tokens, NubType type) : base(tokens)
    {
        Type = type;
    }

    public NubType Type { get; }
}
/// <summary>An integer literal; the checker gives it the expected integer type, or i32.</summary>
public class TypedNodeExpressionIntLiteral : TypedNodeExpression
{
    public TypedNodeExpressionIntLiteral(List<Token> tokens, NubType type, TokenIntLiteral value) : base(tokens, type)
    {
        Value = value;
    }

    public TokenIntLiteral Value { get; }
}
/// <summary>A string literal.</summary>
public class TypedNodeExpressionStringLiteral : TypedNodeExpression
{
    public TypedNodeExpressionStringLiteral(List<Token> tokens, NubType type, TokenStringLiteral value) : base(tokens, type)
    {
        Value = value;
    }

    public TokenStringLiteral Value { get; }
}
/// <summary>A boolean literal.</summary>
public class TypedNodeExpressionBoolLiteral : TypedNodeExpression
{
    public TypedNodeExpressionBoolLiteral(List<Token> tokens, NubType type, TokenBoolLiteral value) : base(tokens, type)
    {
        Value = value;
    }

    public TokenBoolLiteral Value { get; }
}
/// <summary>A struct literal, typed as either a named or an anonymous struct.</summary>
public class TypedNodeExpressionStructLiteral : TypedNodeExpression
{
    public TypedNodeExpressionStructLiteral(List<Token> tokens, NubType type, List<Initializer> initializers) : base(tokens, type)
    {
        Initializers = initializers;
    }

    public List<Initializer> Initializers { get; }

    /// <summary>A single <c>field = value</c> entry of the literal.</summary>
    // NOTE(review): derives from the untyped Node rather than TypedNode — confirm intent.
    public class Initializer : Node
    {
        public Initializer(List<Token> tokens, TokenIdent name, TypedNodeExpression value) : base(tokens)
        {
            Name = name;
            Value = value;
        }

        public TokenIdent Name { get; }
        public TypedNodeExpression Value { get; }
    }
}
/// <summary>A <c>new</c> expression that builds a struct or an enum variant from <see cref="Value"/>.</summary>
public class TypedNodeExpressionNewNamedType : TypedNodeExpression
{
    public TypedNodeExpressionNewNamedType(List<Token> tokens, NubType type, TypedNodeExpression value) : base(tokens, type)
    {
        Value = value;
    }

    public TypedNodeExpression Value { get; }
}
/// <summary>Access to field <see cref="Name"/> of a named or anonymous struct value.</summary>
public class TypedNodeExpressionStructMemberAccess : TypedNodeExpression
{
    public TypedNodeExpressionStructMemberAccess(List<Token> tokens, NubType type, TypedNodeExpression target, TokenIdent name) : base(tokens, type)
    {
        Target = target;
        Name = name;
    }

    public TypedNodeExpression Target { get; }
    public TokenIdent Name { get; }
}
/// <summary>The <c>length</c> member of a string value.</summary>
public class TypedNodeExpressionStringLength : TypedNodeExpression
{
    public TypedNodeExpressionStringLength(List<Token> tokens, NubType type, TypedNodeExpression target) : base(tokens, type)
    {
        Target = target;
    }

    public TypedNodeExpression Target { get; }
}
/// <summary>The <c>ptr</c> member of a string value.</summary>
public class TypedNodeExpressionStringPointer : TypedNodeExpression
{
    public TypedNodeExpressionStringPointer(List<Token> tokens, NubType type, TypedNodeExpression target) : base(tokens, type)
    {
        Target = target;
    }

    public TypedNodeExpression Target { get; }
}
/// <summary>A function call used as an expression; typed as the callee's return type.</summary>
public class TypedNodeExpressionFuncCall : TypedNodeExpression
{
    public TypedNodeExpressionFuncCall(List<Token> tokens, NubType type, TypedNodeExpression target, List<TypedNodeExpression> parameters) : base(tokens, type)
    {
        Target = target;
        Parameters = parameters;
    }

    public TypedNodeExpression Target { get; }
    public List<TypedNodeExpression> Parameters { get; }
}
/// <summary>A reference to a local variable or parameter.</summary>
public class TypedNodeExpressionLocalIdent : TypedNodeExpression
{
    public TypedNodeExpressionLocalIdent(List<Token> tokens, NubType type, string value) : base(tokens, type)
    {
        Name = value;
    }

    public string Name { get; }
}
/// <summary>A reference to a module-level identifier, qualified by its module.</summary>
public class TypedNodeExpressionGlobalIdent : TypedNodeExpression
{
    public TypedNodeExpressionGlobalIdent(List<Token> tokens, NubType type, string module, string value) : base(tokens, type)
    {
        Module = module;
        Name = value;
    }

    public string Module { get; }
    public string Name { get; }
}
/// <summary>A binary operation; the expression's type is the result type chosen by the checker.</summary>
public class TypedNodeExpressionBinary : TypedNodeExpression
{
    public TypedNodeExpressionBinary(List<Token> tokens, NubType type, TypedNodeExpression left, Op operation, TypedNodeExpression right) : base(tokens, type)
    {
        Left = left;
        Operation = operation;
        Right = right;
    }

    public TypedNodeExpression Left { get; }
    public Op Operation { get; }
    public TypedNodeExpression Right { get; }

    /// <summary>Supported binary operators (member order is significant for the underlying values).</summary>
    public enum Op
    {
        Add,
        Subtract,
        Multiply,
        Divide,
        Modulo,
        Equal,
        NotEqual,
        LessThan,
        LessThanOrEqual,
        GreaterThan,
        GreaterThanOrEqual,
        LeftShift,
        RightShift,
        // Bitwise And/Xor/Or are not supported yet.
        LogicalAnd,
        LogicalOr,
    }
}
/// <summary>A unary operation applied to <see cref="Target"/>.</summary>
public class TypedNodeExpressionUnary : TypedNodeExpression
{
    public TypedNodeExpressionUnary(List<Token> tokens, NubType type, TypedNodeExpression target, Op op) : base(tokens, type)
    {
        Target = target;
        Operation = op;
    }

    public TypedNodeExpression Target { get; }
    public Op Operation { get; }

    /// <summary>Supported unary operators: integer negation and boolean inversion.</summary>
    public enum Op
    {
        Negate,
        Invert,
    }
}