using System.Data.Common;
using System.Diagnostics.CodeAnalysis;
using System.Formats.Tar;

namespace Compiler;

/// <summary>
/// Type-checks a single parsed function definition and produces its typed counterpart.
/// Errors are raised internally as <see cref="CompileException"/>s and surfaced to the caller
/// as diagnostics by <see cref="CheckFunction(string, string, NodeDefinitionFunc, ModuleGraph, out List{Diagnostic})"/>.
/// </summary>
public class TypeChecker
{
    /// <summary>
    /// Type-checks <paramref name="function"/> as a member of <paramref name="currentModule"/>.
    /// </summary>
    /// <param name="fileName">File name attached to every reported diagnostic.</param>
    /// <param name="currentModule">Name of the module the function is defined in.</param>
    /// <param name="function">The parsed function definition to check.</param>
    /// <param name="moduleGraph">Used to resolve types and global identifiers across modules.</param>
    /// <param name="diagnostics">Receives the diagnostics produced while checking.</param>
    /// <returns>
    /// The typed function, or <c>null</c> if the return type, any parameter, or the body failed to check.
    /// </returns>
    public static TypedNodeDefinitionFunc? CheckFunction(string fileName, string currentModule, NodeDefinitionFunc function, ModuleGraph moduleGraph, out List<Diagnostic> diagnostics)
    {
        return new TypeChecker(fileName, currentModule, function, moduleGraph).CheckFunction(out diagnostics);
    }

    private TypeChecker(string fileName, string currentModule, NodeDefinitionFunc function, ModuleGraph moduleGraph)
    {
        this.fileName = fileName;
        this.currentModule = currentModule;
        this.function = function;
        this.moduleGraph = moduleGraph;
    }

    private readonly string fileName;
    private readonly string currentModule;
    private readonly NodeDefinitionFunc function;
    // Assigned at the start of CheckFunction before any statement is checked.
    private NubType functionReturnType = null!;
    private readonly ModuleGraph moduleGraph;
    private readonly Scope scope = new();

    private TypedNodeDefinitionFunc? CheckFunction(out List<Diagnostic> diagnostics)
    {
        diagnostics = [];
        var parameters = new List<TypedNodeDefinitionFunc.Param>();
        var invalidParameter = false;
        TypedNodeStatement? body = null;

        try
        {
            functionReturnType = ResolveType(function.ReturnType);
        }
        catch (CompileException e)
        {
            diagnostics.Add(e.Diagnostic);
            return null;
        }

        using (scope.EnterScope())
        {
            foreach (var parameter in function.Parameters)
            {
                NubType parameterType;
                try
                {
                    parameterType = ResolveType(parameter.Type);
                    // A duplicate parameter name becomes a diagnostic instead of an unhandled exception.
                    DeclareLocal(parameter.Name, parameterType);
                }
                catch (CompileException e)
                {
                    diagnostics.Add(e.Diagnostic);
                    invalidParameter = true;
                    continue;
                }

                parameters.Add(new TypedNodeDefinitionFunc.Param(parameter.Tokens, parameter.Name, parameterType));
            }

            // The body is checked even when a parameter failed, so an error in it is reported as well.
            try
            {
                body = CheckStatement(function.Body);
            }
            catch (CompileException e)
            {
                diagnostics.Add(e.Diagnostic);
            }

            if (body == null || invalidParameter)
                return null;

            return new TypedNodeDefinitionFunc(function.Tokens, currentModule, function.Name, parameters, body, functionReturnType);
        }
    }

    private TypedNodeStatement CheckStatement(NodeStatement node)
    {
        return node switch
        {
            NodeStatementAssignment statement => CheckStatementAssignment(statement),
            NodeStatementBlock statement => CheckStatementBlock(statement),
            NodeStatementExpression statement => CheckStatementExpression(statement),
            NodeStatementIf statement => CheckStatementIf(statement),
            NodeStatementReturn statement => CheckStatementReturn(statement),
            NodeStatementVariableDeclaration statement => CheckStatementVariableDeclaration(statement),
            NodeStatementWhile statement => CheckStatementWhile(statement),
            NodeStatementMatch statement => CheckStatementMatch(statement),
            _ => throw new ArgumentOutOfRangeException(nameof(node))
        };
    }

    private TypedNodeStatementAssignment CheckStatementAssignment(NodeStatementAssignment statement)
    {
        var target = CheckExpression(statement.Target, null);
        var value = CheckExpression(statement.Value, target.Type);

        // Same assignability rule as variable declarations, returns and struct literal fields.
        if (!value.Type.IsAssignableTo(target.Type))
            throw BasicError($"Type of assigned value ({value.Type}) is not assignable to the type of the target ({target.Type})", value);

        return new TypedNodeStatementAssignment(statement.Tokens, target, value);
    }

    private TypedNodeStatementBlock CheckStatementBlock(NodeStatementBlock statement)
    {
        // Locals declared inside the block are dropped when the scope guard is disposed.
        using (scope.EnterScope())
        {
            var statements = statement.Statements.Select(CheckStatement).ToList();
            return new TypedNodeStatementBlock(statement.Tokens, statements);
        }
    }

    /// <summary>
    /// Only function calls are accepted as expression statements.
    /// </summary>
    private TypedNodeStatementFuncCall CheckStatementExpression(NodeStatementExpression statement)
    {
        if (statement.Expression is not NodeExpressionFuncCall funcCall)
            throw BasicError("Expected statement or function call", statement);

        var (target, _, arguments) = CheckCall(funcCall);
        return new TypedNodeStatementFuncCall(statement.Tokens, target, arguments);
    }

    private TypedNodeStatementIf CheckStatementIf(NodeStatementIf statement)
    {
        var condition = CheckExpression(statement.Condition, NubTypeBool.Instance);
        if (!condition.Type.IsAssignableTo(NubTypeBool.Instance))
            throw BasicError("Condition part of if statement must be a boolean", condition);

        var thenBlock = CheckStatement(statement.ThenBlock);
        var elseBlock = statement.ElseBlock == null ? null : CheckStatement(statement.ElseBlock);
        return new TypedNodeStatementIf(statement.Tokens, condition, thenBlock, elseBlock);
    }

    private TypedNodeStatementReturn CheckStatementReturn(NodeStatementReturn statement)
    {
        var value = CheckExpression(statement.Value, functionReturnType);
        if (!value.Type.IsAssignableTo(functionReturnType))
            throw BasicError($"Type of returned value ({value.Type}) is not assignable to the return type of the function ({functionReturnType})", value);

        return new TypedNodeStatementReturn(statement.Tokens, value);
    }

    /// <summary>
    /// Checks a variable declaration. Without an explicit type, the variable takes the type of its value.
    /// </summary>
    private TypedNodeStatementVariableDeclaration CheckStatementVariableDeclaration(NodeStatementVariableDeclaration statement)
    {
        NubType? type = null;
        if (statement.Type != null)
            type = ResolveType(statement.Type);

        var value = CheckExpression(statement.Value, type);
        if (type is not null && !value.Type.IsAssignableTo(type))
            throw BasicError("Type of variable does not match type of assigned value", value);

        type ??= value.Type;

        // Declared after the value is checked, so the initializer cannot refer to the new variable.
        DeclareLocal(statement.Name, type);
        return new TypedNodeStatementVariableDeclaration(statement.Tokens, statement.Name, type, value);
    }

    private TypedNodeStatementWhile CheckStatementWhile(NodeStatementWhile statement)
    {
        var condition = CheckExpression(statement.Condition, NubTypeBool.Instance);
        if (!condition.Type.IsAssignableTo(NubTypeBool.Instance))
            throw BasicError("Condition part of while statement must be a boolean", condition);

        var body = CheckStatement(statement.Body);
        return new TypedNodeStatementWhile(statement.Tokens, condition, body);
    }

    /// <summary>
    /// Checks a match over an enum value. Each case binds its variable, typed as the matched
    /// variant, in a fresh scope around the case body.
    /// </summary>
    private TypedNodeStatementMatch CheckStatementMatch(NodeStatementMatch statement)
    {
        var target = CheckExpression(statement.Target, null);
        if (target.Type is not NubTypeEnum enumType)
            throw BasicError("A match statement can only be used on enum types", target);

        if (!moduleGraph.TryResolveType(enumType.Module, enumType.Name, enumType.Module == currentModule, out var info))
            throw BasicError($"Type '{enumType}' not found", target);
        if (info is not Module.TypeInfoEnum enumInfo)
            throw BasicError($"Type '{enumType}' is not an enum", target);

        var cases = new List<TypedNodeStatementMatch.Case>();
        foreach (var @case in statement.Cases)
        {
            // Reject arms naming a variant the enum does not declare.
            if (enumInfo.Variants.All(v => v.Name != @case.Variant.Ident))
                throw BasicError($"Enum '{enumType}' does not have a variant named '{@case.Variant.Ident}'", @case.Variant);

            using (scope.EnterScope())
            {
                DeclareLocal(@case.VariableName, NubTypeEnumVariant.Get(NubTypeEnum.Get(enumType.Module, enumType.Name), @case.Variant.Ident));
                var body = CheckStatement(@case.Body);
                cases.Add(new TypedNodeStatementMatch.Case(@case.Tokens, @case.Variant, @case.VariableName, body));
            }
        }

        return new TypedNodeStatementMatch(statement.Tokens, target, cases);
    }

    /// <summary>
    /// Dispatches on the expression kind. <paramref name="expectedType"/> is a hint from the
    /// surrounding context, which some checkers use for inference. It may be <c>null</c>.
    /// </summary>
    private TypedNodeExpression CheckExpression(NodeExpression node, NubType? expectedType)
    {
        return node switch
        {
            NodeExpressionBinary expression => CheckExpressionBinary(expression, expectedType),
            NodeExpressionUnary expression => CheckExpressionUnary(expression, expectedType),
            NodeExpressionBoolLiteral expression => CheckExpressionBoolLiteral(expression, expectedType),
            NodeExpressionIdent expression => CheckExpressionIdent(expression, expectedType),
            NodeExpressionIntLiteral expression => CheckExpressionIntLiteral(expression, expectedType),
            NodeExpressionMemberAccess expression => CheckExpressionMemberAccess(expression, expectedType),
            NodeExpressionFuncCall expression => CheckExpressionFuncCall(expression, expectedType),
            NodeExpressionStringLiteral expression => CheckExpressionStringLiteral(expression, expectedType),
            NodeExpressionStructLiteral expression => CheckExpressionStructLiteral(expression, expectedType),
            NodeExpressionEnumLiteral expression => CheckExpressionEnumLiteral(expression, expectedType),
            _ => throw new ArgumentOutOfRangeException(nameof(node))
        };
    }

    private TypedNodeExpressionBinary CheckExpressionBinary(NodeExpressionBinary expression, NubType? expectedType)
    {
        // todo(nub31): Add proper inference here
        var left = CheckExpression(expression.Left, null);
        var right = CheckExpression(expression.Right, null);

        NubType type;
        switch (expression.Operation)
        {
            case NodeExpressionBinary.Op.Add:
            case NodeExpressionBinary.Op.Subtract:
            case NodeExpressionBinary.Op.Multiply:
            case NodeExpressionBinary.Op.Divide:
            case NodeExpressionBinary.Op.Modulo:
            {
                if (left.Type is not NubTypeSInt and not NubTypeUInt)
                    throw BasicError($"Unsupported type for left hand side arithmetic operation: {left.Type}", left);
                if (right.Type is not NubTypeSInt and not NubTypeUInt)
                    throw BasicError($"Unsupported type for right hand side arithmetic operation: {right.Type}", right);

                // The result takes the type of the left operand.
                type = left.Type;
                break;
            }
            case NodeExpressionBinary.Op.LeftShift:
            case NodeExpressionBinary.Op.RightShift:
            {
                if (left.Type is not NubTypeUInt)
                    throw BasicError($"Unsupported type for left hand side of left/right shift operation: {left.Type}", left);
                if (right.Type is not NubTypeUInt)
                    throw BasicError($"Unsupported type for right hand side of left/right shift operation: {right.Type}", right);

                type = left.Type;
                break;
            }
            case NodeExpressionBinary.Op.Equal:
            case NodeExpressionBinary.Op.NotEqual:
            case NodeExpressionBinary.Op.LessThan:
            case NodeExpressionBinary.Op.LessThanOrEqual:
            case NodeExpressionBinary.Op.GreaterThan:
            case NodeExpressionBinary.Op.GreaterThanOrEqual:
            {
                if (left.Type is not NubTypeSInt and not NubTypeUInt)
                    throw BasicError($"Unsupported type for left hand side of comparison: {left.Type}", left);
                if (right.Type is not NubTypeSInt and not NubTypeUInt)
                    throw BasicError($"Unsupported type for right hand side of comparison: {right.Type}", right);

                type = NubTypeBool.Instance;
                break;
            }
            case NodeExpressionBinary.Op.LogicalAnd:
            case NodeExpressionBinary.Op.LogicalOr:
            {
                if (left.Type is not NubTypeBool)
                    throw BasicError($"Unsupported type for left hand side of logical operation: {left.Type}", left);
                if (right.Type is not NubTypeBool)
                    throw BasicError($"Unsupported type for right hand side of logical operation: {right.Type}", right);

                type = NubTypeBool.Instance;
                break;
            }
            default:
                throw new ArgumentOutOfRangeException();
        }

        return new TypedNodeExpressionBinary(expression.Tokens, type, left, CheckExpressionBinaryOperation(expression.Operation), right);
    }

    private static TypedNodeExpressionBinary.Op CheckExpressionBinaryOperation(NodeExpressionBinary.Op op)
    {
        return op switch
        {
            NodeExpressionBinary.Op.Add => TypedNodeExpressionBinary.Op.Add,
            NodeExpressionBinary.Op.Subtract => TypedNodeExpressionBinary.Op.Subtract,
            NodeExpressionBinary.Op.Multiply => TypedNodeExpressionBinary.Op.Multiply,
            NodeExpressionBinary.Op.Divide => TypedNodeExpressionBinary.Op.Divide,
            NodeExpressionBinary.Op.Modulo => TypedNodeExpressionBinary.Op.Modulo,
            NodeExpressionBinary.Op.Equal => TypedNodeExpressionBinary.Op.Equal,
            NodeExpressionBinary.Op.NotEqual => TypedNodeExpressionBinary.Op.NotEqual,
            NodeExpressionBinary.Op.LessThan => TypedNodeExpressionBinary.Op.LessThan,
            NodeExpressionBinary.Op.LessThanOrEqual => TypedNodeExpressionBinary.Op.LessThanOrEqual,
            NodeExpressionBinary.Op.GreaterThan => TypedNodeExpressionBinary.Op.GreaterThan,
            NodeExpressionBinary.Op.GreaterThanOrEqual => TypedNodeExpressionBinary.Op.GreaterThanOrEqual,
            NodeExpressionBinary.Op.LeftShift => TypedNodeExpressionBinary.Op.LeftShift,
            NodeExpressionBinary.Op.RightShift => TypedNodeExpressionBinary.Op.RightShift,
            NodeExpressionBinary.Op.LogicalAnd => TypedNodeExpressionBinary.Op.LogicalAnd,
            NodeExpressionBinary.Op.LogicalOr => TypedNodeExpressionBinary.Op.LogicalOr,
            _ => throw new ArgumentOutOfRangeException(nameof(op), op, null)
        };
    }

    private TypedNodeExpressionUnary CheckExpressionUnary(NodeExpressionUnary expression, NubType? expectedType)
    {
        // todo(nub31): Add proper inference here
        var target = CheckExpression(expression.Target, null);

        NubType type;
        switch (expression.Operation)
        {
            case NodeExpressionUnary.Op.Negate:
            {
                if (target.Type is not NubTypeSInt and not NubTypeUInt)
                    throw BasicError($"Unsupported type for negation: {target.Type}", target);

                type = target.Type;
                break;
            }
            case NodeExpressionUnary.Op.Invert:
            {
                if (target.Type is not NubTypeBool)
                    throw BasicError($"Unsupported type for inversion: {target.Type}", target);

                type = NubTypeBool.Instance;
                break;
            }
            default:
                throw new ArgumentOutOfRangeException();
        }

        return new TypedNodeExpressionUnary(expression.Tokens, type, target, CheckExpressionUnaryOperation(expression.Operation));
    }

    private static TypedNodeExpressionUnary.Op CheckExpressionUnaryOperation(NodeExpressionUnary.Op op)
    {
        return op switch
        {
            NodeExpressionUnary.Op.Negate => TypedNodeExpressionUnary.Op.Negate,
            NodeExpressionUnary.Op.Invert => TypedNodeExpressionUnary.Op.Invert,
            _ => throw new ArgumentOutOfRangeException(nameof(op), op, null)
        };
    }

    private TypedNodeExpressionBoolLiteral CheckExpressionBoolLiteral(NodeExpressionBoolLiteral expression, NubType? expectedType)
    {
        return new TypedNodeExpressionBoolLiteral(expression.Tokens, NubTypeBool.Instance, expression.Value);
    }

    /// <summary>
    /// Resolves <c>name</c> (a local first, then a global in the current module) or <c>module::name</c>.
    /// </summary>
    private TypedNodeExpression CheckExpressionIdent(NodeExpressionIdent expression, NubType? expectedType)
    {
        if (expression.Sections.Count == 1)
        {
            var name = expression.Sections[0].Ident;

            var localType = scope.GetIdentifierType(name);
            if (localType is not null)
                return new TypedNodeExpressionLocalIdent(expression.Tokens, localType, name);

            if (moduleGraph.TryResolveIdentifier(currentModule, name, true, out var ident))
                return new TypedNodeExpressionGlobalIdent(expression.Tokens, ident.Type, currentModule, name);
        }
        else if (expression.Sections.Count == 2)
        {
            var module = expression.Sections[0].Ident;
            var name = expression.Sections[1].Ident;

            // Non-public members are only reachable from inside their own module, matching the type lookups.
            if (moduleGraph.TryResolveIdentifier(module, name, module == currentModule, out var ident))
                return new TypedNodeExpressionGlobalIdent(expression.Tokens, ident.Type, module, name);
        }

        throw BasicError($"Unknown identifier '{string.Join("::", expression.Sections.Select(x => x.Ident))}'", expression);
    }

    private TypedNodeExpressionIntLiteral CheckExpressionIntLiteral(NodeExpressionIntLiteral expression, NubType? expectedType)
    {
        // Integer literals are currently always typed as a 32-bit signed integer.
        return new TypedNodeExpressionIntLiteral(expression.Tokens, NubTypeSInt.Get(32), expression.Value);
    }

    /// <summary>
    /// Checks member access on a struct value (field lookup) or on an enum variant value.
    /// Access on an enum variant yields the variant's payload type.
    /// </summary>
    private TypedNodeExpressionMemberAccess CheckExpressionMemberAccess(NodeExpressionMemberAccess expression, NubType? expectedType)
    {
        var target = CheckExpression(expression.Target, null);

        switch (target.Type)
        {
            case NubTypeStruct structType:
            {
                if (!moduleGraph.TryResolveModule(structType.Module, out var module))
                    throw BasicError($"Module '{structType.Module}' not found", expression.Target);
                if (!module.TryResolveType(structType.Name, currentModule == structType.Module, out var typeDef))
                    throw BasicError($"Type '{structType.Name}' not found in module '{structType.Module}'", expression.Target);
                if (typeDef is not Module.TypeInfoStruct structDef)
                    throw BasicError($"Type '{target.Type}' is not a struct", expression.Target);

                var field = structDef.Fields.FirstOrDefault(x => x.Name == expression.Name.Ident);
                if (field == null)
                    throw BasicError($"Struct '{target.Type}' does not have a field matching the name '{expression.Name.Ident}'", target);

                return new TypedNodeExpressionMemberAccess(expression.Tokens, field.Type, target, expression.Name);
            }
            case NubTypeEnumVariant enumVariantType:
            {
                if (!moduleGraph.TryResolveModule(enumVariantType.EnumType.Module, out var module))
                    throw BasicError($"Module '{enumVariantType.EnumType.Module}' not found", expression.Target);
                if (!module.TryResolveType(enumVariantType.EnumType.Name, currentModule == enumVariantType.EnumType.Module, out var typeDef))
                    throw BasicError($"Type '{enumVariantType.EnumType.Name}' not found in module '{enumVariantType.EnumType.Module}'", expression.Target);
                if (typeDef is not Module.TypeInfoEnum enumDef)
                    throw BasicError($"Type '{enumVariantType.EnumType.Module}::{enumVariantType.EnumType.Name}' is not an enum", expression.Target);

                var variant = enumDef.Variants.FirstOrDefault(x => x.Name == enumVariantType.Variant);
                if (variant == null)
                    throw BasicError($"Type '{target.Type}' does not have a variant named '{enumVariantType.Variant}'", expression.Target);

                return new TypedNodeExpressionMemberAccess(expression.Tokens, variant.Type, target, expression.Name);
            }
            default:
                throw BasicError($"{target.Type} has no member '{expression.Name.Ident}'", target);
        }
    }

    private TypedNodeExpressionFuncCall CheckExpressionFuncCall(NodeExpressionFuncCall expression, NubType? expectedType)
    {
        var (target, funcType, arguments) = CheckCall(expression);
        return new TypedNodeExpressionFuncCall(expression.Tokens, funcType.ReturnType, target, arguments);
    }

    /// <summary>
    /// Checks the callee and arguments of a function call. Shared by call statements and call expressions.
    /// Each argument is checked against, and must be assignable to, the corresponding parameter type.
    /// </summary>
    private (TypedNodeExpression Target, NubTypeFunc FuncType, List<TypedNodeExpression> Arguments) CheckCall(NodeExpressionFuncCall call)
    {
        var target = CheckExpression(call.Target, null);
        if (target.Type is not NubTypeFunc funcType)
            throw BasicError("Expected a function type", target);

        if (funcType.Parameters.Count != call.Parameters.Count)
            throw BasicError($"Expected {funcType.Parameters.Count} parameters but got {call.Parameters.Count}", call);

        var arguments = new List<TypedNodeExpression>(call.Parameters.Count);
        for (var i = 0; i < call.Parameters.Count; i++)
        {
            var parameterType = funcType.Parameters[i];
            var argument = CheckExpression(call.Parameters[i], parameterType);
            if (!argument.Type.IsAssignableTo(parameterType))
                throw BasicError($"Type of argument {i + 1} ({argument.Type}) is not assignable to the parameter type ({parameterType})", argument);

            arguments.Add(argument);
        }

        return (target, funcType, arguments);
    }

    private TypedNodeExpressionStringLiteral CheckExpressionStringLiteral(NodeExpressionStringLiteral expression, NubType? expectedType)
    {
        return new TypedNodeExpressionStringLiteral(expression.Tokens, NubTypeString.Instance, expression.Value);
    }

    /// <summary>
    /// Checks a struct literal. The struct type is taken from, in order:
    /// 1. an explicit type on the literal;
    /// 2. a named struct <paramref name="expectedType"/>;
    /// 3. otherwise an anonymous struct type built from the initializers.
    /// </summary>
    private TypedNodeExpressionStructLiteral CheckExpressionStructLiteral(NodeExpressionStructLiteral expression, NubType? expectedType)
    {
        if (expression.Type != null)
        {
            var type = ResolveType(expression.Type);
            if (type is not NubTypeStruct explicitStructType)
                throw BasicError("Type of struct literal is not a struct", expression);

            return CheckNamedStructLiteral(expression, explicitStructType, expression.Type);
        }

        if (expectedType is NubTypeStruct expectedStructType)
            return CheckNamedStructLiteral(expression, expectedStructType, expression);

        // todo(nub31): Infer anonymous struct types if expectedType is anonymous struct
        var initializers = new List<TypedNodeExpressionStructLiteral.Initializer>();
        foreach (var initializer in expression.Initializers)
        {
            var value = CheckExpression(initializer.Value, null);
            initializers.Add(new TypedNodeExpressionStructLiteral.Initializer(initializer.Tokens, initializer.Name, value));
        }

        var anonymousType = NubTypeAnonymousStruct.Get(initializers.Select(x => new NubTypeAnonymousStruct.Field(x.Name.Ident, x.Value.Type)).ToList());
        return new TypedNodeExpressionStructLiteral(expression.Tokens, anonymousType, initializers);
    }

    /// <summary>
    /// Checks the initializers of a struct literal against the fields of the named struct <paramref name="structType"/>.
    /// </summary>
    /// <param name="notStructLocation">Where to report the error if <paramref name="structType"/> resolves to a non-struct.</param>
    private TypedNodeExpressionStructLiteral CheckNamedStructLiteral(NodeExpressionStructLiteral expression, NubTypeStruct structType, Node notStructLocation)
    {
        if (!moduleGraph.TryResolveType(structType.Module, structType.Name, structType.Module == currentModule, out var info))
            throw BasicError($"Type '{structType}' struct literal not found", expression);
        if (info is not Module.TypeInfoStruct structInfo)
            throw BasicError($"Type '{structType}' is not a struct", notStructLocation);

        var initializers = new List<TypedNodeExpressionStructLiteral.Initializer>();
        foreach (var initializer in expression.Initializers)
        {
            var field = structInfo.Fields.FirstOrDefault(x => x.Name == initializer.Name.Ident);
            if (field == null)
                throw BasicError($"Field '{initializer.Name.Ident}' does not exist on struct '{structType.Module}::{structType.Name}'", initializer.Name);

            var value = CheckExpression(initializer.Value, field.Type);
            if (!value.Type.IsAssignableTo(field.Type))
                throw BasicError($"Type of assignment ({value.Type}) does not match expected type of field '{field.Name}' ({field.Type})", initializer.Name);

            initializers.Add(new TypedNodeExpressionStructLiteral.Initializer(initializer.Tokens, initializer.Name, value));
        }

        return new TypedNodeExpressionStructLiteral(expression.Tokens, structType, initializers);
    }

    /// <summary>
    /// Checks an enum literal. Its type must resolve to an existing enum variant, and its value
    /// must be assignable to that variant's payload type.
    /// </summary>
    private TypedNodeExpressionEnumLiteral CheckExpressionEnumLiteral(NodeExpressionEnumLiteral expression, NubType? expectedType)
    {
        var type = ResolveType(expression.Type);
        if (type is not NubTypeEnumVariant variantType)
            throw BasicError("Expected enum variant type", expression.Type);

        if (!moduleGraph.TryResolveType(variantType.EnumType.Module, variantType.EnumType.Name, variantType.EnumType.Module == currentModule, out var info))
            throw BasicError($"Type '{variantType.EnumType}' not found", expression.Type);
        if (info is not Module.TypeInfoEnum enumInfo)
            throw BasicError($"Type '{variantType.EnumType}' is not an enum", expression.Type);

        var variant = enumInfo.Variants.FirstOrDefault(x => x.Name == variantType.Variant);
        if (variant == null)
            throw BasicError($"Enum '{variantType.EnumType}' does not have a variant named '{variantType.Variant}'", expression.Type);

        var value = CheckExpression(expression.Value, variant.Type);
        if (!value.Type.IsAssignableTo(variant.Type))
            throw BasicError($"Type of enum value ({value.Type}) is not assignable to the type of variant '{variantType.Variant}' ({variant.Type})", value);

        return new TypedNodeExpressionEnumLiteral(expression.Tokens, type, value);
    }

    private NubType ResolveType(NodeType node)
    {
        return node switch
        {
            NodeTypeBool => NubTypeBool.Instance,
            NodeTypeNamed type => ResolveNamedType(type),
            NodeTypeAnonymousStruct type => NubTypeAnonymousStruct.Get(type.Fields.Select(x => new NubTypeAnonymousStruct.Field(x.Name.Ident, ResolveType(x.Type))).ToList()),
            NodeTypeFunc type => NubTypeFunc.Get(type.Parameters.Select(ResolveType).ToList(), ResolveType(type.ReturnType)),
            NodeTypePointer type => NubTypePointer.Get(ResolveType(type.To)),
            NodeTypeSInt type => NubTypeSInt.Get(type.Width),
            NodeTypeUInt type => NubTypeUInt.Get(type.Width),
            NodeTypeString => NubTypeString.Instance,
            NodeTypeVoid => NubTypeVoid.Instance,
            _ => throw new ArgumentOutOfRangeException(nameof(node))
        };
    }

    /// <summary>
    /// Resolves a named type by how many <c>::</c>-separated sections it has.
    /// </summary>
    private NubType ResolveNamedType(NodeTypeNamed type)
    {
        return type.Sections.Count switch
        {
            3 => ResolveThreePartType(type.Sections[0], type.Sections[1], type.Sections[2]),
            2 => ResolveTwoPartType(type.Sections[0], type.Sections[1]),
            1 => ResolveOnePartType(type.Sections[0]),
            _ => throw BasicError("Invalid type name", type)
        };
    }

    /// <summary>
    /// Resolves <c>module::enum::variant</c>.
    /// </summary>
    private NubType ResolveThreePartType(TokenIdent first, TokenIdent second, TokenIdent third)
    {
        if (TryResolveEnumVariant(first.Ident, second.Ident, third.Ident, out var variantType))
            return variantType;

        throw BasicError($"Enum '{first.Ident}::{second.Ident}' does not have a variant named '{third.Ident}'", third);
    }

    /// <summary>
    /// Resolves <c>enum::variant</c> in the current module, falling back to <c>module::type</c>.
    /// </summary>
    private NubType ResolveTwoPartType(TokenIdent first, TokenIdent second)
    {
        if (TryResolveEnumVariant(currentModule, first.Ident, second.Ident, out var variantType))
            return variantType;

        var typeInfo = ResolveModuleTypeInfo(ResolveModule(first), second);
        return typeInfo switch
        {
            Module.TypeInfoStruct => NubTypeStruct.Get(first.Ident, second.Ident),
            Module.TypeInfoEnum => NubTypeEnum.Get(first.Ident, second.Ident),
            _ => throw new ArgumentOutOfRangeException(nameof(typeInfo))
        };
    }

    /// <summary>
    /// Resolves a bare type name in the current module.
    /// </summary>
    private NubType ResolveOnePartType(TokenIdent name)
    {
        if (!moduleGraph.TryResolveModule(currentModule, out var module))
            throw BasicError($"Module '{currentModule}' not found", name);

        var typeInfo = ResolveModuleTypeInfo(module, name);
        return typeInfo switch
        {
            Module.TypeInfoStruct => NubTypeStruct.Get(currentModule, name.Ident),
            Module.TypeInfoEnum => NubTypeEnum.Get(currentModule, name.Ident),
            _ => throw new ArgumentOutOfRangeException(nameof(typeInfo))
        };
    }

    private Module ResolveModule(TokenIdent name)
    {
        if (!moduleGraph.TryResolveModule(name.Ident, out var module))
            throw BasicError($"Module '{name.Ident}' not found", name);

        return module;
    }

    private Module.TypeInfo ResolveModuleTypeInfo(Module module, TokenIdent name)
    {
        if (!module.TryResolveType(name.Ident, currentModule == module.Name, out var type))
            throw BasicError($"Named type '{module.Name}::{name.Ident}' not found", name);

        return type;
    }

    /// <summary>
    /// Looks up <c>moduleName::enumName::variantName</c> without reporting errors.
    /// </summary>
    private bool TryResolveEnumVariant(string moduleName, string enumName, string variantName, [NotNullWhen(true)] out NubType? result)
    {
        result = null;

        if (!moduleGraph.TryResolveModule(moduleName, out var module))
            return false;
        // Same visibility rule as ResolveModuleTypeInfo: other modules' non-public enums are not visible.
        if (!module.TryResolveType(enumName, moduleName == currentModule, out var type))
            return false;
        if (type is not Module.TypeInfoEnum enumInfo)
            return false;

        var variant = enumInfo.Variants.FirstOrDefault(v => v.Name == variantName);
        if (variant == null)
            return false;

        result = NubTypeEnumVariant.Get(NubTypeEnum.Get(moduleName, enumName), variantName);
        return true;
    }

    /// <summary>
    /// Declares a local in the innermost scope. Redeclaring a name in the same scope is a compile error.
    /// </summary>
    private void DeclareLocal(TokenIdent name, NubType type)
    {
        if (!scope.TryDeclareIdentifier(name.Ident, type))
            throw BasicError($"Identifier '{name.Ident}' is already declared in this scope", name);
    }

    private CompileException BasicError(string message, TokenIdent ident)
    {
        return new CompileException(Diagnostic.Error(message).At(fileName, ident).Build());
    }

    private CompileException BasicError(string message, Node node)
    {
        return new CompileException(Diagnostic.Error(message).At(fileName, node).Build());
    }

    private CompileException BasicError(string message, TypedNode node)
    {
        return new CompileException(Diagnostic.Error(message).At(fileName, node).Build());
    }

    /// <summary>
    /// A stack of lexical scopes mapping local names to their types.
    /// </summary>
    private sealed class Scope
    {
        private readonly Stack<Dictionary<string, NubType>> scopes = new();

        /// <summary>
        /// Pushes a new innermost scope. Disposing the returned guard pops it.
        /// </summary>
        public IDisposable EnterScope()
        {
            scopes.Push([]);
            return new ScopeGuard(this);
        }

        /// <summary>
        /// Declares <paramref name="name"/> in the innermost scope.
        /// </summary>
        /// <returns><c>false</c> if the name is already declared in the innermost scope.</returns>
        public bool TryDeclareIdentifier(string name, NubType type)
        {
            return scopes.Peek().TryAdd(name, type);
        }

        /// <summary>
        /// Returns the type of the innermost declaration of <paramref name="name"/>, or <c>null</c> if none exists.
        /// </summary>
        public NubType? GetIdentifierType(string name)
        {
            // Stack enumeration starts at the most recently pushed scope, so inner declarations shadow outer ones.
            foreach (var frame in scopes)
            {
                if (frame.TryGetValue(name, out var type))
                    return type;
            }

            return null;
        }

        private void ExitScope()
        {
            scopes.Pop();
        }

        private sealed class ScopeGuard(Scope owner) : IDisposable
        {
            public void Dispose()
            {
                owner.ExitScope();
            }
        }
    }
}

/// <summary>
/// Base class of all nodes produced by the type checker.
/// </summary>
public abstract class TypedNode(List<Token> tokens)
{
    public List<Token> Tokens { get; } = tokens;
}

public abstract class TypedNodeDefinition(List<Token> tokens, string module) : TypedNode(tokens)
{
    public string Module { get; } = module;
}

public class TypedNodeDefinitionFunc(List<Token> tokens, string module, TokenIdent name, List<TypedNodeDefinitionFunc.Param> parameters, TypedNodeStatement body, NubType returnType) : TypedNodeDefinition(tokens, module)
{
    public TokenIdent Name { get; } = name;
    public List<Param> Parameters { get; } = parameters;
    public TypedNodeStatement Body { get; } = body;
    public NubType ReturnType { get; } = returnType;

    /// <summary>
    /// Builds the function type from the parameter types and the return type.
    /// </summary>
    public NubTypeFunc GetNubType()
    {
        return NubTypeFunc.Get(Parameters.Select(x => x.Type).ToList(), ReturnType);
    }

    public class Param(List<Token> tokens, TokenIdent name, NubType type) : TypedNode(tokens)
    {
        public TokenIdent Name { get; } = name;
        public NubType Type { get; } = type;
    }
}

public abstract class TypedNodeStatement(List<Token> tokens) : TypedNode(tokens);

public class TypedNodeStatementBlock(List<Token> tokens, List<TypedNodeStatement> statements) : TypedNodeStatement(tokens)
{
    public List<TypedNodeStatement> Statements { get; } = statements;
}

public class TypedNodeStatementFuncCall(List<Token> tokens, TypedNodeExpression target, List<TypedNodeExpression> parameters) : TypedNodeStatement(tokens)
{
    public TypedNodeExpression Target { get; } = target;
    public List<TypedNodeExpression> Parameters { get; } = parameters;
}

public class TypedNodeStatementReturn(List<Token> tokens, TypedNodeExpression value) : TypedNodeStatement(tokens)
{
    public TypedNodeExpression Value { get; } = value;
}

public class TypedNodeStatementVariableDeclaration(List<Token> tokens, TokenIdent name, NubType type, TypedNodeExpression value) : TypedNodeStatement(tokens)
{
    public TokenIdent Name { get; } = name;
    public NubType Type { get; } = type;
    public TypedNodeExpression Value { get; } = value;
}

public class TypedNodeStatementAssignment(List<Token> tokens, TypedNodeExpression target, TypedNodeExpression value) : TypedNodeStatement(tokens)
{
    public TypedNodeExpression Target { get; } = target;
    public TypedNodeExpression Value { get; } = value;
}

public class TypedNodeStatementIf(List<Token> tokens, TypedNodeExpression condition, TypedNodeStatement thenBlock, TypedNodeStatement? elseBlock) : TypedNodeStatement(tokens)
{
    public TypedNodeExpression Condition { get; } = condition;
    public TypedNodeStatement ThenBlock { get; } = thenBlock;
    public TypedNodeStatement? ElseBlock { get; } = elseBlock;
}

public class TypedNodeStatementWhile(List<Token> tokens, TypedNodeExpression condition, TypedNodeStatement body) : TypedNodeStatement(tokens)
{
    public TypedNodeExpression Condition { get; } = condition;
    public TypedNodeStatement Body { get; } = body;
}

public class TypedNodeStatementMatch(List<Token> tokens, TypedNodeExpression target, List<TypedNodeStatementMatch.Case> cases) : TypedNodeStatement(tokens)
{
    public TypedNodeExpression Target { get; } = target;
    public List<Case> Cases { get; } = cases;

    // NOTE(review): derives from the untyped Node rather than TypedNode, unlike Param; confirm this is intended.
    public class Case(List<Token> tokens, TokenIdent type, TokenIdent variableName, TypedNodeStatement body) : Node(tokens)
    {
        public TokenIdent Type { get; } = type;
        public TokenIdent VariableName { get; } = variableName;
        public TypedNodeStatement Body { get; } = body;
    }
}

public abstract class TypedNodeExpression(List<Token> tokens, NubType type) : TypedNode(tokens)
{
    public NubType Type { get; } = type;
}

public class TypedNodeExpressionIntLiteral(List<Token> tokens, NubType type, TokenIntLiteral value) : TypedNodeExpression(tokens, type)
{
    public TokenIntLiteral Value { get; } = value;
}

public class TypedNodeExpressionStringLiteral(List<Token> tokens, NubType type, TokenStringLiteral value) : TypedNodeExpression(tokens, type)
{
    public TokenStringLiteral Value { get; } = value;
}

public class TypedNodeExpressionBoolLiteral(List<Token> tokens, NubType type, TokenBoolLiteral value) : TypedNodeExpression(tokens, type)
{
    public TokenBoolLiteral Value { get; } = value;
}

public class TypedNodeExpressionStructLiteral(List<Token> tokens, NubType type, List<TypedNodeExpressionStructLiteral.Initializer> initializers) : TypedNodeExpression(tokens, type)
{
    public List<Initializer> Initializers { get; } = initializers;

    // NOTE(review): derives from the untyped Node rather than TypedNode; confirm this is intended.
    public class Initializer(List<Token> tokens, TokenIdent name, TypedNodeExpression value) : Node(tokens)
    {
        public TokenIdent Name { get; } = name;
        public TypedNodeExpression Value { get; } = value;
    }
}

public class TypedNodeExpressionEnumLiteral(List<Token> tokens, NubType type, TypedNodeExpression value) : TypedNodeExpression(tokens, type)
{
    public TypedNodeExpression Value { get; } = value;
}

public class TypedNodeExpressionMemberAccess(List<Token> tokens, NubType type, TypedNodeExpression target, TokenIdent name) : TypedNodeExpression(tokens, type)
{
    public TypedNodeExpression Target { get; } = target;
    public TokenIdent Name { get; } = name;
}

public class TypedNodeExpressionFuncCall(List<Token> tokens, NubType type, TypedNodeExpression target, List<TypedNodeExpression> parameters) : TypedNodeExpression(tokens, type)
{
    public TypedNodeExpression Target { get; } = target;
    public List<TypedNodeExpression> Parameters { get; } = parameters;
}

public class TypedNodeExpressionLocalIdent(List<Token> tokens, NubType type, string value) : TypedNodeExpression(tokens, type)
{
    public string Name { get; } = value;
}

public class TypedNodeExpressionGlobalIdent(List<Token> tokens, NubType type, string module, string value) : TypedNodeExpression(tokens, type)
{
    public string Module { get; } = module;
    public string Name { get; } = value;
}

public class TypedNodeExpressionBinary(List<Token> tokens, NubType type, TypedNodeExpression left, TypedNodeExpressionBinary.Op operation, TypedNodeExpression right) : TypedNodeExpression(tokens, type)
{
    public TypedNodeExpression Left { get; } = left;
    public Op Operation { get; } = operation;
    public TypedNodeExpression Right { get; } = right;

    public enum Op
    {
        Add,
        Subtract,
        Multiply,
        Divide,
        Modulo,

        Equal,
        NotEqual,
        LessThan,
        LessThanOrEqual,
        GreaterThan,
        GreaterThanOrEqual,

        LeftShift,
        RightShift,

        // BitwiseAnd,
        // BitwiseXor,
        // BitwiseOr,

        LogicalAnd,
        LogicalOr,
    }
}

public class TypedNodeExpressionUnary(List<Token> tokens, NubType type, TypedNodeExpression target, TypedNodeExpressionUnary.Op op) : TypedNodeExpression(tokens, type)
{
    public TypedNodeExpression Target { get; } = target;
    public Op Operation { get; } = op;

    public enum Op
    {
        Negate,
        Invert,
    }
}