namespace Compiler;

/// <summary>
/// Type-checks a single function definition: converts the untyped <see cref="NodeDefinitionFunc"/> tree into a
/// typed tree and reports errors as diagnostics.
/// </summary>
/// <param name="fileName">File name attached to every reported diagnostic.</param>
/// <param name="moduleName">Module that owns the function; private identifiers/types of this module are visible.</param>
/// <param name="function">The function definition to check.</param>
/// <param name="moduleGraph">Used to resolve module-qualified identifiers and custom types.</param>
public sealed class TypeChecker(string fileName, string moduleName, NodeDefinitionFunc function, ModuleGraph moduleGraph)
{
    /// <summary>
    /// Type-checks <paramref name="function"/>.
    /// </summary>
    /// <param name="diagnostics">Receives one diagnostic per failed parameter, return type or body check.</param>
    /// <returns>The typed function, or <c>null</c> if any parameter, the return type or the body failed to check.</returns>
    public static TypedNodeDefinitionFunc? CheckFunction(string fileName, string moduleName, NodeDefinitionFunc function, ModuleGraph moduleGraph, out List<Diagnostic> diagnostics)
    {
        return new TypeChecker(fileName, moduleName, function, moduleGraph).CheckFunction(out diagnostics);
    }

    // Innermost lexical scope. The root scope holds the function parameters; nested blocks
    // temporarily replace it with a child scope (see WithChildScope).
    private Scope scope = new(null);

    // Declared return type of the function being checked. Null when it could not be resolved,
    // in which case return statements are not checked against it (the error is already reported).
    private NubType? expectedReturnType;

    private TypedNodeDefinitionFunc? CheckFunction(out List<Diagnostic> diagnostics)
    {
        diagnostics = [];

        var parameters = new List<TypedNodeDefinitionFunc.Param>();
        var invalidParameter = false;
        TypedNodeStatement? body = null;
        NubType? returnType = null;

        foreach (var parameter in function.Parameters)
        {
            try
            {
                var typedParameter = CheckDefinitionFuncParameter(parameter);

                // Parameters must be visible to the body, so they are declared in the root scope.
                if (!scope.TryDeclareIdentifier(typedParameter.Name.Ident, typedParameter.Type))
                    throw new CompileException(Diagnostic.Error($"Parameter '{typedParameter.Name.Ident}' is declared more than once").At(fileName, typedParameter.Name).Build());

                parameters.Add(typedParameter);
            }
            catch (CompileException e)
            {
                diagnostics.Add(e.Diagnostic);
                invalidParameter = true;
            }
        }

        // Resolved before the body so that return statements can be checked against it.
        try
        {
            returnType = ResolveType(function.ReturnType);
        }
        catch (CompileException e)
        {
            diagnostics.Add(e.Diagnostic);
        }

        expectedReturnType = returnType;

        try
        {
            body = CheckStatement(function.Body);
        }
        catch (CompileException e)
        {
            diagnostics.Add(e.Diagnostic);
        }

        if (body == null || returnType == null || invalidParameter)
            return null;

        return new TypedNodeDefinitionFunc(function.Tokens, moduleName, function.Name, parameters, body, returnType);
    }

    private TypedNodeDefinitionFunc.Param CheckDefinitionFuncParameter(NodeDefinitionFunc.Param node)
    {
        return new TypedNodeDefinitionFunc.Param(node.Tokens, node.Name, ResolveType(node.Type));
    }

    /// <summary>Runs <paramref name="check"/> inside a fresh child scope, restoring the enclosing scope afterwards.</summary>
    private T WithChildScope<T>(Func<T> check)
    {
        var enclosing = scope;
        scope = new Scope(enclosing);
        try
        {
            return check();
        }
        finally
        {
            scope = enclosing;
        }
    }

    private TypedNodeStatement CheckStatement(NodeStatement node)
    {
        return node switch
        {
            NodeStatementAssignment statement => CheckStatementAssignment(statement),
            NodeStatementBlock statement => CheckStatementBlock(statement),
            NodeStatementExpression statement => CheckStatementExpression(statement),
            NodeStatementIf statement => CheckStatementIf(statement),
            NodeStatementReturn statement => CheckStatementReturn(statement),
            NodeStatementVariableDeclaration statement => CheckStatementVariableDeclaration(statement),
            NodeStatementWhile statement => CheckStatementWhile(statement),
            _ => throw new ArgumentOutOfRangeException(nameof(node))
        };
    }

    private TypedNodeStatementAssignment CheckStatementAssignment(NodeStatementAssignment statement)
    {
        var target = CheckExpression(statement.Target);
        var value = CheckExpression(statement.Value);

        if (value.Type != target.Type)
            throw new CompileException(Diagnostic.Error($"Type of assigned value ({value.Type}) does not match type of assignment target ({target.Type})").At(fileName, value).Build());

        return new TypedNodeStatementAssignment(statement.Tokens, target, value);
    }

    private TypedNodeStatementBlock CheckStatementBlock(NodeStatementBlock statement)
    {
        // Each block opens its own scope so that locals declared inside it are not visible after it ends.
        // ToList() forces evaluation while the child scope is still active.
        return WithChildScope(() => new TypedNodeStatementBlock(statement.Tokens, statement.Statements.Select(CheckStatement).ToList()));
    }

    private TypedNodeStatementFuncCall CheckStatementExpression(NodeStatementExpression statement)
    {
        if (statement.Expression is not NodeExpressionFuncCall funcCall)
            throw new CompileException(Diagnostic.Error("Expected statement or function call").At(fileName, statement).Build());

        // Reuse the expression check so the call target is validated to be a function here as well.
        var call = CheckExpressionFuncCall(funcCall);
        return new TypedNodeStatementFuncCall(statement.Tokens, call.Target, call.Parameters);
    }

    private TypedNodeStatementIf CheckStatementIf(NodeStatementIf statement)
    {
        var condition = CheckCondition(statement.Condition);
        var thenBlock = CheckStatement(statement.ThenBlock);
        var elseBlock = statement.ElseBlock == null ? null : CheckStatement(statement.ElseBlock);
        return new TypedNodeStatementIf(statement.Tokens, condition, thenBlock, elseBlock);
    }

    private TypedNodeStatementReturn CheckStatementReturn(NodeStatementReturn statement)
    {
        var value = CheckExpression(statement.Value);

        if (expectedReturnType is not null && value.Type != expectedReturnType)
            throw new CompileException(Diagnostic.Error($"Type of returned value ({value.Type}) does not match return type of function ({expectedReturnType})").At(fileName, value).Build());

        return new TypedNodeStatementReturn(statement.Tokens, value);
    }

    private TypedNodeStatementVariableDeclaration CheckStatementVariableDeclaration(NodeStatementVariableDeclaration statement)
    {
        var type = ResolveType(statement.Type);

        // The initializer is checked before the name is declared, so it cannot refer to the new variable.
        var value = CheckExpression(statement.Value);

        if (type != value.Type)
            throw new CompileException(Diagnostic.Error("Type of variable does not match type of assigned value").At(fileName, value).Build());

        if (!scope.TryDeclareIdentifier(statement.Name.Ident, type))
            throw new CompileException(Diagnostic.Error($"Variable '{statement.Name.Ident}' is already declared in this scope").At(fileName, statement.Name).Build());

        return new TypedNodeStatementVariableDeclaration(statement.Tokens, statement.Name, type, value);
    }

    private TypedNodeStatementWhile CheckStatementWhile(NodeStatementWhile statement)
    {
        var condition = CheckCondition(statement.Condition);
        var block = CheckStatement(statement.Block);
        return new TypedNodeStatementWhile(statement.Tokens, condition, block);
    }

    /// <summary>Checks an if/while condition and requires it to be of type bool.</summary>
    private TypedNodeExpression CheckCondition(NodeExpression node)
    {
        var condition = CheckExpression(node);

        if (condition.Type is not NubTypeBool)
            throw new CompileException(Diagnostic.Error($"Condition must be of type bool, but has type {condition.Type}").At(fileName, condition).Build());

        return condition;
    }

    private TypedNodeExpression CheckExpression(NodeExpression node)
    {
        return node switch
        {
            NodeExpressionBinary expression => CheckExpressionBinary(expression),
            NodeExpressionUnary expression => CheckExpressionUnary(expression),
            NodeExpressionBoolLiteral expression => CheckExpressionBoolLiteral(expression),
            NodeExpressionLocalIdent expression => CheckExpressionIdent(expression),
            NodeExpressionModuleIdent expression => CheckExpressionModuleIdent(expression),
            NodeExpressionIntLiteral expression => CheckExpressionIntLiteral(expression),
            NodeExpressionMemberAccess expression => CheckExpressionMemberAccess(expression),
            NodeExpressionFuncCall expression => CheckExpressionFuncCall(expression),
            NodeExpressionStringLiteral expression => CheckExpressionStringLiteral(expression),
            NodeExpressionStructLiteral expression => CheckExpressionStructLiteral(expression),
            _ => throw new ArgumentOutOfRangeException(nameof(node))
        };
    }

    private TypedNodeExpressionBinary CheckExpressionBinary(NodeExpressionBinary expression)
    {
        var left = CheckExpression(expression.Left);
        var right = CheckExpression(expression.Right);
        NubType type;

        switch (expression.Operation)
        {
            case NodeExpressionBinary.Op.Add:
            case NodeExpressionBinary.Op.Subtract:
            case NodeExpressionBinary.Op.Multiply:
            case NodeExpressionBinary.Op.Divide:
            case NodeExpressionBinary.Op.Modulo:
            {
                if (left.Type is not NubTypeSInt and not NubTypeUInt)
                    throw new CompileException(Diagnostic.Error($"Unsupported type for left hand side arithmetic operation: {left.Type}").At(fileName, left).Build());

                if (right.Type is not NubTypeSInt and not NubTypeUInt)
                    throw new CompileException(Diagnostic.Error($"Unsupported type for right hand side arithmetic operation: {right.Type}").At(fileName, right).Build());

                // The result takes the left operand's type, so both operands must agree.
                if (left.Type != right.Type)
                    throw new CompileException(Diagnostic.Error($"Mismatched operand types for arithmetic operation: {left.Type} and {right.Type}").At(fileName, right).Build());

                type = left.Type;
                break;
            }
            case NodeExpressionBinary.Op.LeftShift:
            case NodeExpressionBinary.Op.RightShift:
            {
                if (left.Type is not NubTypeUInt)
                    throw new CompileException(Diagnostic.Error($"Unsupported type for left hand side of left/right shift operation: {left.Type}").At(fileName, left).Build());

                if (right.Type is not NubTypeUInt)
                    throw new CompileException(Diagnostic.Error($"Unsupported type for right hand side of left/right shift operation: {right.Type}").At(fileName, right).Build());

                // The shift amount may have a different width; the result keeps the shifted value's type.
                type = left.Type;
                break;
            }
            case NodeExpressionBinary.Op.Equal:
            case NodeExpressionBinary.Op.NotEqual:
            case NodeExpressionBinary.Op.LessThan:
            case NodeExpressionBinary.Op.LessThanOrEqual:
            case NodeExpressionBinary.Op.GreaterThan:
            case NodeExpressionBinary.Op.GreaterThanOrEqual:
            {
                if (left.Type is not NubTypeSInt and not NubTypeUInt)
                    throw new CompileException(Diagnostic.Error($"Unsupported type for left hand side of comparison: {left.Type}").At(fileName, left).Build());

                if (right.Type is not NubTypeSInt and not NubTypeUInt)
                    throw new CompileException(Diagnostic.Error($"Unsupported type for right hand side of comparison: {right.Type}").At(fileName, right).Build());

                if (left.Type != right.Type)
                    throw new CompileException(Diagnostic.Error($"Mismatched operand types for comparison: {left.Type} and {right.Type}").At(fileName, right).Build());

                type = NubTypeBool.Instance;
                break;
            }
            case NodeExpressionBinary.Op.LogicalAnd:
            case NodeExpressionBinary.Op.LogicalOr:
            {
                if (left.Type is not NubTypeBool)
                    throw new CompileException(Diagnostic.Error($"Unsupported type for left hand side of logical operation: {left.Type}").At(fileName, left).Build());

                if (right.Type is not NubTypeBool)
                    throw new CompileException(Diagnostic.Error($"Unsupported type for right hand side of logical operation: {right.Type}").At(fileName, right).Build());

                type = NubTypeBool.Instance;
                break;
            }
            default:
                throw new ArgumentOutOfRangeException();
        }

        return new TypedNodeExpressionBinary(expression.Tokens, type, left, CheckExpressionBinaryOperation(expression.Operation), right);
    }

    private static TypedNodeExpressionBinary.Op CheckExpressionBinaryOperation(NodeExpressionBinary.Op op)
    {
        return op switch
        {
            NodeExpressionBinary.Op.Add => TypedNodeExpressionBinary.Op.Add,
            NodeExpressionBinary.Op.Subtract => TypedNodeExpressionBinary.Op.Subtract,
            NodeExpressionBinary.Op.Multiply => TypedNodeExpressionBinary.Op.Multiply,
            NodeExpressionBinary.Op.Divide => TypedNodeExpressionBinary.Op.Divide,
            NodeExpressionBinary.Op.Modulo => TypedNodeExpressionBinary.Op.Modulo,
            NodeExpressionBinary.Op.Equal => TypedNodeExpressionBinary.Op.Equal,
            NodeExpressionBinary.Op.NotEqual => TypedNodeExpressionBinary.Op.NotEqual,
            NodeExpressionBinary.Op.LessThan => TypedNodeExpressionBinary.Op.LessThan,
            NodeExpressionBinary.Op.LessThanOrEqual => TypedNodeExpressionBinary.Op.LessThanOrEqual,
            NodeExpressionBinary.Op.GreaterThan => TypedNodeExpressionBinary.Op.GreaterThan,
            NodeExpressionBinary.Op.GreaterThanOrEqual => TypedNodeExpressionBinary.Op.GreaterThanOrEqual,
            NodeExpressionBinary.Op.LeftShift => TypedNodeExpressionBinary.Op.LeftShift,
            NodeExpressionBinary.Op.RightShift => TypedNodeExpressionBinary.Op.RightShift,
            NodeExpressionBinary.Op.LogicalAnd => TypedNodeExpressionBinary.Op.LogicalAnd,
            NodeExpressionBinary.Op.LogicalOr => TypedNodeExpressionBinary.Op.LogicalOr,
            _ => throw new ArgumentOutOfRangeException(nameof(op), op, null)
        };
    }

    private TypedNodeExpressionUnary CheckExpressionUnary(NodeExpressionUnary expression)
    {
        var target = CheckExpression(expression.Target);
        NubType type;

        switch (expression.Operation)
        {
            case NodeExpressionUnary.Op.Negate:
            {
                if (target.Type is not NubTypeSInt and not NubTypeUInt)
                    throw new CompileException(Diagnostic.Error($"Unsupported type for negation: {target.Type}").At(fileName, target).Build());

                type = target.Type;
                break;
            }
            case NodeExpressionUnary.Op.Invert:
            {
                if (target.Type is not NubTypeBool)
                    throw new CompileException(Diagnostic.Error($"Unsupported type for inversion: {target.Type}").At(fileName, target).Build());

                type = NubTypeBool.Instance;
                break;
            }
            default:
                throw new ArgumentOutOfRangeException();
        }

        return new TypedNodeExpressionUnary(expression.Tokens, type, target, CheckExpressionUnaryOperation(expression.Operation));
    }

    private static TypedNodeExpressionUnary.Op CheckExpressionUnaryOperation(NodeExpressionUnary.Op op)
    {
        return op switch
        {
            NodeExpressionUnary.Op.Negate => TypedNodeExpressionUnary.Op.Negate,
            NodeExpressionUnary.Op.Invert => TypedNodeExpressionUnary.Op.Invert,
            _ => throw new ArgumentOutOfRangeException(nameof(op), op, null)
        };
    }

    private TypedNodeExpressionBoolLiteral CheckExpressionBoolLiteral(NodeExpressionBoolLiteral expression)
    {
        return new TypedNodeExpressionBoolLiteral(expression.Tokens, NubTypeBool.Instance, expression.Value);
    }

    private TypedNodeExpressionLocalIdent CheckExpressionIdent(NodeExpressionLocalIdent expression)
    {
        var type = scope.GetIdentifierType(expression.Value.Ident);
        if (type == null)
            throw new CompileException(Diagnostic.Error($"Identifier '{expression.Value.Ident}' is not declared").At(fileName, expression.Value).Build());

        return new TypedNodeExpressionLocalIdent(expression.Tokens, type, expression.Value);
    }

    private TypedNodeExpressionModuleIdent CheckExpressionModuleIdent(NodeExpressionModuleIdent expression)
    {
        if (!moduleGraph.TryResolveModule(expression.Module.Ident, out var module))
            throw new CompileException(Diagnostic.Error($"Module '{expression.Module.Ident}' not found").At(fileName, expression.Module).Build());

        // Private members are only visible from inside their own module.
        var includePrivate = expression.Module.Ident == moduleName;
        if (!module.TryResolveIdentifierType(expression.Value.Ident, includePrivate, out var identifierType))
            throw new CompileException(Diagnostic.Error($"Identifier '{expression.Module.Ident}::{expression.Value.Ident}' not found").At(fileName, expression.Value).Build());

        return new TypedNodeExpressionModuleIdent(expression.Tokens, identifierType, expression.Module, expression.Value);
    }

    private TypedNodeExpressionIntLiteral CheckExpressionIntLiteral(NodeExpressionIntLiteral expression)
    {
        // Integer literals are always typed as a signed 32-bit integer.
        return new TypedNodeExpressionIntLiteral(expression.Tokens, NubTypeSInt.Get(32), expression.Value);
    }

    private TypedNodeExpressionMemberAccess CheckExpressionMemberAccess(NodeExpressionMemberAccess expression)
    {
        var target = CheckExpression(expression.Target);
        if (target.Type is not NubTypeStruct structType)
            throw new CompileException(Diagnostic.Error($"Cannot access member of non-struct type {target.Type}").At(fileName, target).Build());

        var field = structType.Fields.FirstOrDefault(x => x.Name == expression.Name.Ident);
        if (field == null)
            throw new CompileException(Diagnostic.Error($"Struct '{target.Type}' does not have a field matching the name '{expression.Name.Ident}'").At(fileName, expression.Name).Build());

        return new TypedNodeExpressionMemberAccess(expression.Tokens, field.Type, target, expression.Name);
    }

    private TypedNodeExpressionFuncCall CheckExpressionFuncCall(NodeExpressionFuncCall expression)
    {
        var target = CheckExpression(expression.Target);
        if (target.Type is not NubTypeFunc funcType)
            throw new CompileException(Diagnostic.Error($"Cannot invoke function call on type '{target.Type}'").At(fileName, target).Build());

        var parameters = expression.Parameters.Select(CheckExpression).ToList();
        return new TypedNodeExpressionFuncCall(expression.Tokens, funcType.ReturnType, target, parameters);
    }

    private TypedNodeExpressionStringLiteral CheckExpressionStringLiteral(NodeExpressionStringLiteral expression)
    {
        return new TypedNodeExpressionStringLiteral(expression.Tokens, NubTypeString.Instance, expression.Value);
    }

    private TypedNodeExpressionStructLiteral CheckExpressionStructLiteral(NodeExpressionStructLiteral expression)
    {
        if (!moduleGraph.TryResolveModule(expression.Module.Ident, out var module))
            throw new CompileException(Diagnostic.Error($"Module '{expression.Module.Ident}' not found").At(fileName, expression.Module).Build());

        var includePrivate = expression.Module.Ident == moduleName;
        if (!module.TryResolveCustomType(expression.Name.Ident, includePrivate, out var customType))
            throw new CompileException(Diagnostic.Error($"Struct '{expression.Module.Ident}::{expression.Name.Ident}' not found").At(fileName, expression.Name).Build());

        if (customType is not NubTypeStruct structType)
            throw new CompileException(Diagnostic.Error($"Cannot create struct literal of non-struct type '{expression.Module.Ident}::{expression.Name.Ident}'").At(fileName, expression.Name).Build());

        var initializers = new List<TypedNodeExpressionStructLiteral.Initializer>();
        foreach (var initializer in expression.Initializers)
        {
            var field = structType.Fields.FirstOrDefault(x => x.Name == initializer.Name.Ident);
            if (field == null)
                throw new CompileException(Diagnostic.Error($"Field '{initializer.Name.Ident}' does not exist on struct '{expression.Name.Ident}'").At(fileName, initializer.Name).Build());

            var value = CheckExpression(initializer.Value);
            if (value.Type != field.Type)
                throw new CompileException(Diagnostic.Error($"Type of assignment ({value.Type}) does not match expected type of field '{field.Name}' ({field.Type})").At(fileName, initializer.Name).Build());

            initializers.Add(new TypedNodeExpressionStructLiteral.Initializer(initializer.Tokens, initializer.Name, value));
        }

        return new TypedNodeExpressionStructLiteral(expression.Tokens, structType, initializers);
    }

    private NubType ResolveType(NodeType node)
    {
        return node switch
        {
            NodeTypeBool => NubTypeBool.Instance,
            NodeTypeCustom type => ResolveCustomType(type),
            NodeTypeFunc type => NubTypeFunc.Get(type.Parameters.Select(ResolveType).ToList(), ResolveType(type.ReturnType)),
            NodeTypePointer type => NubTypePointer.Get(ResolveType(type.To)),
            NodeTypeSInt type => NubTypeSInt.Get(type.Width),
            NodeTypeUInt type => NubTypeUInt.Get(type.Width),
            NodeTypeString => NubTypeString.Instance,
            NodeTypeVoid => NubTypeVoid.Instance,
            _ => throw new ArgumentOutOfRangeException(nameof(node))
        };
    }

    private NubType ResolveCustomType(NodeTypeCustom type)
    {
        if (!moduleGraph.TryResolveModule(type.Module.Ident, out var module))
            throw new CompileException(Diagnostic.Error($"Module '{type.Module.Ident}' not found").At(fileName, type.Module).Build());

        var includePrivate = type.Module.Ident == moduleName;
        if (!module.TryResolveCustomType(type.Name.Ident, includePrivate, out var customType))
            throw new CompileException(Diagnostic.Error($"Custom type '{type.Module.Ident}::{type.Name.Ident}' not found").At(fileName, type.Name).Build());

        return customType;
    }

    /// <summary>A lexical scope mapping local identifier names to their types, chained to its enclosing scope.</summary>
    private sealed class Scope(Scope? parent)
    {
        private readonly Dictionary<string, NubType> identifiers = new();

        /// <summary>Declares <paramref name="name"/> in this scope; returns false if it is already declared in this same scope.</summary>
        public bool TryDeclareIdentifier(string name, NubType type)
        {
            return identifiers.TryAdd(name, type);
        }

        /// <summary>Looks <paramref name="name"/> up in this scope and then in enclosing scopes; null if not declared anywhere.</summary>
        public NubType? GetIdentifierType(string name)
        {
            return identifiers.TryGetValue(name, out var type) ? type : parent?.GetIdentifierType(name);
        }
    }
}

/// <summary>Base class of every node in the type-checked tree; keeps the tokens the node was built from.</summary>
public abstract class TypedNode(List<Token> tokens)
{
    public List<Token> Tokens { get; } = tokens;
}

/// <summary>A top-level definition belonging to a module.</summary>
public abstract class TypedNodeDefinition(List<Token> tokens, string module) : TypedNode(tokens)
{
    public string Module { get; } = module;
}

public sealed class TypedNodeDefinitionFunc(List<Token> tokens, string module, TokenIdent name, List<TypedNodeDefinitionFunc.Param> parameters, TypedNodeStatement body, NubType returnType) : TypedNodeDefinition(tokens, module)
{
    public TokenIdent Name { get; } = name;
    public List<Param> Parameters { get; } = parameters;
    public TypedNodeStatement Body { get; } = body;
    public NubType ReturnType { get; } = returnType;

    /// <summary>Builds the function type from the parameter types and return type.</summary>
    public NubTypeFunc GetNubType()
    {
        return NubTypeFunc.Get(Parameters.Select(x => x.Type).ToList(), ReturnType);
    }

    /// <summary>Symbol name derived from the module, function name and function type via <c>SymbolNameGen.Exported</c>.</summary>
    public string GetMangledName()
    {
        return SymbolNameGen.Exported(Module, Name.Ident, GetNubType());
    }

    public sealed class Param(List<Token> tokens, TokenIdent name, NubType type) : TypedNode(tokens)
    {
        public TokenIdent Name { get; } = name;
        public NubType Type { get; } = type;
    }
}

public abstract class TypedNodeStatement(List<Token> tokens) : TypedNode(tokens);

public sealed class TypedNodeStatementBlock(List<Token> tokens, List<TypedNodeStatement> statements) : TypedNodeStatement(tokens)
{
    public List<TypedNodeStatement> Statements { get; } = statements;
}

public sealed class TypedNodeStatementFuncCall(List<Token> tokens, TypedNodeExpression target, List<TypedNodeExpression> parameters) : TypedNodeStatement(tokens)
{
    public TypedNodeExpression Target { get; } = target;
    public List<TypedNodeExpression> Parameters { get; } = parameters;
}

public sealed class TypedNodeStatementReturn(List<Token> tokens, TypedNodeExpression value) : TypedNodeStatement(tokens)
{
    public TypedNodeExpression Value { get; } = value;
}

public sealed class TypedNodeStatementVariableDeclaration(List<Token> tokens, TokenIdent name, NubType type, TypedNodeExpression value) : TypedNodeStatement(tokens)
{
    public TokenIdent Name { get; } = name;
    public NubType Type { get; } = type;
    public TypedNodeExpression Value { get; } = value;
}

public sealed class TypedNodeStatementAssignment(List<Token> tokens, TypedNodeExpression target, TypedNodeExpression value) : TypedNodeStatement(tokens)
{
    public TypedNodeExpression Target { get; } = target;
    public TypedNodeExpression Value { get; } = value;
}

public sealed class TypedNodeStatementIf(List<Token> tokens, TypedNodeExpression condition, TypedNodeStatement thenBlock, TypedNodeStatement? elseBlock) : TypedNodeStatement(tokens)
{
    public TypedNodeExpression Condition { get; } = condition;
    public TypedNodeStatement ThenBlock { get; } = thenBlock;
    public TypedNodeStatement? ElseBlock { get; } = elseBlock;
}

public sealed class TypedNodeStatementWhile(List<Token> tokens, TypedNodeExpression condition, TypedNodeStatement block) : TypedNodeStatement(tokens)
{
    public TypedNodeExpression Condition { get; } = condition;
    public TypedNodeStatement Block { get; } = block;
}

/// <summary>Base class of typed expressions; every expression carries its resolved type.</summary>
public abstract class TypedNodeExpression(List<Token> tokens, NubType type) : TypedNode(tokens)
{
    public NubType Type { get; } = type;
}

public sealed class TypedNodeExpressionIntLiteral(List<Token> tokens, NubType type, TokenIntLiteral value) : TypedNodeExpression(tokens, type)
{
    public TokenIntLiteral Value { get; } = value;
}

public sealed class TypedNodeExpressionStringLiteral(List<Token> tokens, NubType type, TokenStringLiteral value) : TypedNodeExpression(tokens, type)
{
    public TokenStringLiteral Value { get; } = value;
}

public sealed class TypedNodeExpressionBoolLiteral(List<Token> tokens, NubType type, TokenBoolLiteral value) : TypedNodeExpression(tokens, type)
{
    public TokenBoolLiteral Value { get; } = value;
}

public sealed class TypedNodeExpressionStructLiteral(List<Token> tokens, NubType type, List<TypedNodeExpressionStructLiteral.Initializer> initializers) : TypedNodeExpression(tokens, type)
{
    public List<Initializer> Initializers { get; } = initializers;

    // Part of the typed tree, so it derives from TypedNode like TypedNodeDefinitionFunc.Param does.
    public sealed class Initializer(List<Token> tokens, TokenIdent name, TypedNodeExpression value) : TypedNode(tokens)
    {
        public TokenIdent Name { get; } = name;
        public TypedNodeExpression Value { get; } = value;
    }
}

public sealed class TypedNodeExpressionMemberAccess(List<Token> tokens, NubType type, TypedNodeExpression target, TokenIdent name) : TypedNodeExpression(tokens, type)
{
    public TypedNodeExpression Target { get; } = target;
    public TokenIdent Name { get; } = name;
}

public sealed class TypedNodeExpressionFuncCall(List<Token> tokens, NubType type, TypedNodeExpression target, List<TypedNodeExpression> parameters) : TypedNodeExpression(tokens, type)
{
    public TypedNodeExpression Target { get; } = target;
    public List<TypedNodeExpression> Parameters { get; } = parameters;
}

public sealed class TypedNodeExpressionLocalIdent(List<Token> tokens, NubType type, TokenIdent value) : TypedNodeExpression(tokens, type)
{
    public TokenIdent Value { get; } = value;
}

public sealed class TypedNodeExpressionModuleIdent(List<Token> tokens, NubType type, TokenIdent module, TokenIdent value) : TypedNodeExpression(tokens, type)
{
    public TokenIdent Module { get; } = module;
    public TokenIdent Value { get; } = value;
}

public sealed class TypedNodeExpressionBinary(List<Token> tokens, NubType type, TypedNodeExpression left, TypedNodeExpressionBinary.Op operation, TypedNodeExpression right) : TypedNodeExpression(tokens, type)
{
    public TypedNodeExpression Left { get; } = left;
    public Op Operation { get; } = operation;
    public TypedNodeExpression Right { get; } = right;

    public enum Op
    {
        Add,
        Subtract,
        Multiply,
        Divide,
        Modulo,

        Equal,
        NotEqual,
        LessThan,
        LessThanOrEqual,
        GreaterThan,
        GreaterThanOrEqual,

        LeftShift,
        RightShift,

        // Not yet supported: BitwiseAnd, BitwiseXor, BitwiseOr.

        LogicalAnd,
        LogicalOr,
    }
}

public sealed class TypedNodeExpressionUnary(List<Token> tokens, NubType type, TypedNodeExpression target, TypedNodeExpressionUnary.Op op) : TypedNodeExpression(tokens, type)
{
    public TypedNodeExpression Target { get; } = target;
    public Op Operation { get; } = op;

    public enum Op
    {
        Negate,
        Invert,
    }
}