namespace Compiler;

/// <summary>
/// Checks a single <see cref="NodeDefinitionFunc"/> and lowers it to a typed tree
/// (<see cref="TypedNodeDefinitionFunc"/>). Type errors are raised internally as
/// <see cref="CompileException"/> and collected into a diagnostics list rather than
/// escaping to the caller.
/// </summary>
public sealed class TypeChecker(string fileName, NodeDefinitionFunc function, TypeResolver typeResolver)
{
    /// <summary>
    /// Type-checks <paramref name="function"/>.
    /// </summary>
    /// <param name="fileName">File name attached to every emitted diagnostic.</param>
    /// <param name="function">The untyped function definition to check.</param>
    /// <param name="typeResolver">Resolves type syntax and named structs.</param>
    /// <param name="diagnostics">Receives every diagnostic collected while checking.</param>
    /// <returns>
    /// The typed function, or <c>null</c> if any parameter, the return type, or the body failed to check.
    /// </returns>
    public static TypedNodeDefinitionFunc? CheckFunction(string fileName, NodeDefinitionFunc function, TypeResolver typeResolver, out List<Diagnostic> diagnostics)
    {
        return new TypeChecker(fileName, function, typeResolver).CheckFunction(out diagnostics);
    }

    // Innermost lexical scope. Swapped for a child scope while a block is being checked.
    private Scope scope = new(null);

    // Resolved return type of the function being checked. Null when resolution failed;
    // return statements are then not checked against it (the failure is already reported).
    private NubType? expectedReturnType;

    private TypedNodeDefinitionFunc? CheckFunction(out List<Diagnostic> diagnostics)
    {
        diagnostics = [];

        var parameters = new List<TypedNodeDefinitionFunc.Param>();
        var invalidParameter = false;
        TypedNodeStatement? body = null;
        NubType? returnType = null;

        // Parameters are resolved and declared in the function scope first so the body can refer to them.
        // A parameter whose type fails to resolve is not declared, so uses of it in the body are reported too.
        foreach (var parameter in function.Parameters)
        {
            try
            {
                parameters.Add(CheckDefinitionFuncParameter(parameter));
            }
            catch (CompileException e)
            {
                diagnostics.Add(e.Diagnostic);
                invalidParameter = true;
            }
        }

        // Resolved before the body so that return statements can be checked against it.
        try
        {
            returnType = typeResolver.Resolve(function.ReturnType);
        }
        catch (CompileException e)
        {
            diagnostics.Add(e.Diagnostic);
        }

        expectedReturnType = returnType;

        // Body checking stops at the first error raised inside it; only that diagnostic is reported.
        try
        {
            body = CheckStatement(function.Body);
        }
        catch (CompileException e)
        {
            diagnostics.Add(e.Diagnostic);
        }

        if (body == null || returnType == null || invalidParameter)
            return null;

        return new TypedNodeDefinitionFunc(function.Tokens, function.Name, parameters, body, returnType);
    }

    /// <summary>Resolves a parameter's type and declares the parameter in the current scope.</summary>
    /// <exception cref="CompileException">The type does not resolve, or the name is already declared.</exception>
    private TypedNodeDefinitionFunc.Param CheckDefinitionFuncParameter(NodeDefinitionFunc.Param node)
    {
        var type = typeResolver.Resolve(node.Type);
        DeclareIdentifier(node.Name, type);
        return new TypedNodeDefinitionFunc.Param(node.Tokens, node.Name, type);
    }

    /// <summary>
    /// Declares <paramref name="name"/> in the innermost scope, reporting a redeclaration in that
    /// same scope as a <see cref="CompileException"/> instead of letting the dictionary throw.
    /// </summary>
    private void DeclareIdentifier(TokenIdent name, NubType type)
    {
        if (!scope.TryDeclareIdentifier(name.Ident, type))
            throw new CompileException(Diagnostic.Error($"Identifier '{name.Ident}' is already declared in this scope").At(fileName, name).Build());
    }

    /// <summary>Runs <paramref name="check"/> inside a fresh child scope, restoring the previous scope afterwards.</summary>
    private T InChildScope<T>(Func<T> check)
    {
        var parent = scope;
        scope = new Scope(parent);
        try
        {
            return check();
        }
        finally
        {
            scope = parent;
        }
    }

    private TypedNodeStatement CheckStatement(NodeStatement node)
    {
        return node switch
        {
            NodeStatementAssignment statement => CheckStatementAssignment(statement),
            NodeStatementBlock statement => CheckStatementBlock(statement),
            NodeStatementFuncCall statement => CheckStatementFuncCall(statement),
            NodeStatementIf statement => CheckStatementIf(statement),
            NodeStatementReturn statement => CheckStatementReturn(statement),
            NodeStatementVariableDeclaration statement => CheckStatementVariableDeclaration(statement),
            NodeStatementWhile statement => CheckStatementWhile(statement),
            _ => throw new ArgumentOutOfRangeException(nameof(node))
        };
    }

    /// <summary>
    /// Checks an assignment: the target must be an identifier or member access, and the value's
    /// type must match the target's type (the same rule used for variable declarations).
    /// </summary>
    private TypedNodeStatementAssignment CheckStatementAssignment(NodeStatementAssignment statement)
    {
        var target = CheckExpression(statement.Target);
        var value = CheckExpression(statement.Value);

        if (target is not TypedNodeExpressionIdent and not TypedNodeExpressionMemberAccess)
            throw new CompileException(Diagnostic.Error("Cannot assign to this expression").At(fileName, target).Build());

        if (value.Type != target.Type)
            throw new CompileException(Diagnostic.Error($"Type of assigned value ({value.Type}) does not match type of assignment target ({target.Type})").At(fileName, value).Build());

        return new TypedNodeStatementAssignment(statement.Tokens, target, value);
    }

    /// <summary>Checks a block in its own child scope so its declarations do not leak out of it.</summary>
    private TypedNodeStatementBlock CheckStatementBlock(NodeStatementBlock statement)
    {
        return InChildScope(() => new TypedNodeStatementBlock(statement.Tokens, statement.Statements.Select(CheckStatement).ToList()));
    }

    private TypedNodeStatementFuncCall CheckStatementFuncCall(NodeStatementFuncCall statement)
    {
        // NOTE(review): the callee's type and the argument types are not checked here, and the
        // scope only holds parameters/locals; confirm how function names are meant to resolve.
        return new TypedNodeStatementFuncCall(statement.Tokens, CheckExpression(statement.Target), statement.Parameters.Select(CheckExpression).ToList());
    }

    private TypedNodeStatementIf CheckStatementIf(NodeStatementIf statement)
    {
        var condition = CheckCondition(statement.Condition, "if statement");
        var thenBlock = CheckStatement(statement.ThenBlock);
        var elseBlock = statement.ElseBlock == null ? null : CheckStatement(statement.ElseBlock);
        return new TypedNodeStatementIf(statement.Tokens, condition, thenBlock, elseBlock);
    }

    /// <summary>Checks a return value against the function's resolved return type (when it resolved).</summary>
    private TypedNodeStatementReturn CheckStatementReturn(NodeStatementReturn statement)
    {
        var value = CheckExpression(statement.Value);

        if (expectedReturnType != null && value.Type != expectedReturnType)
            throw new CompileException(Diagnostic.Error($"Type of returned value ({value.Type}) does not match function return type ({expectedReturnType})").At(fileName, value).Build());

        return new TypedNodeStatementReturn(statement.Tokens, value);
    }

    /// <summary>
    /// Checks a variable declaration. The initializer is checked before the name is declared,
    /// so it cannot refer to the variable being declared.
    /// </summary>
    private TypedNodeStatementVariableDeclaration CheckStatementVariableDeclaration(NodeStatementVariableDeclaration statement)
    {
        var type = typeResolver.Resolve(statement.Type);
        var value = CheckExpression(statement.Value);

        if (type != value.Type)
            throw new CompileException(Diagnostic.Error("Type of variable does not match type of assigned value").At(fileName, value).Build());

        DeclareIdentifier(statement.Name, type);
        return new TypedNodeStatementVariableDeclaration(statement.Tokens, statement.Name, type, value);
    }

    private TypedNodeStatementWhile CheckStatementWhile(NodeStatementWhile statement)
    {
        var condition = CheckCondition(statement.Condition, "while statement");
        return new TypedNodeStatementWhile(statement.Tokens, condition, CheckStatement(statement.Block));
    }

    /// <summary>Checks a branch/loop condition and requires it to be of type bool.</summary>
    private TypedNodeExpression CheckCondition(NodeExpression node, string construct)
    {
        var condition = CheckExpression(node);
        if (condition.Type is not NubTypeBool)
            throw new CompileException(Diagnostic.Error($"Condition of {construct} must be of type bool, got {condition.Type}").At(fileName, condition).Build());
        return condition;
    }

    private TypedNodeExpression CheckExpression(NodeExpression node)
    {
        return node switch
        {
            NodeExpressionBinary expression => CheckExpressionBinary(expression),
            NodeExpressionUnary expression => CheckExpressionUnary(expression),
            NodeExpressionBoolLiteral expression => CheckExpressionBoolLiteral(expression),
            NodeExpressionIdent expression => CheckExpressionIdent(expression),
            NodeExpressionIntLiteral expression => CheckExpressionIntLiteral(expression),
            NodeExpressionMemberAccess expression => CheckExpressionMemberAccess(expression),
            NodeExpressionStringLiteral expression => CheckExpressionStringLiteral(expression),
            NodeExpressionStructLiteral expression => CheckExpressionStructLiteral(expression),
            _ => throw new ArgumentOutOfRangeException(nameof(node))
        };
    }

    /// <summary>
    /// Checks a binary expression. Arithmetic keeps the left operand's type; shifts require unsigned
    /// operands; comparisons and logical operators produce bool.
    /// </summary>
    private TypedNodeExpressionBinary CheckExpressionBinary(NodeExpressionBinary expression)
    {
        var left = CheckExpression(expression.Left);
        var right = CheckExpression(expression.Right);
        NubType type;

        // NOTE(review): operand types are validated individually only; mixed operands (signed vs
        // unsigned, or different widths) are accepted and arithmetic takes the left type — confirm intended.
        switch (expression.Operation)
        {
            case NodeExpressionBinary.Op.Add:
            case NodeExpressionBinary.Op.Subtract:
            case NodeExpressionBinary.Op.Multiply:
            case NodeExpressionBinary.Op.Divide:
            case NodeExpressionBinary.Op.Modulo:
            {
                if (left.Type is not NubTypeSInt and not NubTypeUInt)
                    throw new CompileException(Diagnostic.Error($"Unsupported type for left hand side arithmetic operation: {left.Type}").At(fileName, left).Build());
                if (right.Type is not NubTypeSInt and not NubTypeUInt)
                    throw new CompileException(Diagnostic.Error($"Unsupported type for right hand side arithmetic operation: {right.Type}").At(fileName, right).Build());
                type = left.Type;
                break;
            }
            case NodeExpressionBinary.Op.LeftShift:
            case NodeExpressionBinary.Op.RightShift:
            {
                if (left.Type is not NubTypeUInt)
                    throw new CompileException(Diagnostic.Error($"Unsupported type for left hand side of left/right shift operation: {left.Type}").At(fileName, left).Build());
                if (right.Type is not NubTypeUInt)
                    throw new CompileException(Diagnostic.Error($"Unsupported type for right hand side of left/right shift operation: {right.Type}").At(fileName, right).Build());
                type = left.Type;
                break;
            }
            case NodeExpressionBinary.Op.Equal:
            case NodeExpressionBinary.Op.NotEqual:
            case NodeExpressionBinary.Op.LessThan:
            case NodeExpressionBinary.Op.LessThanOrEqual:
            case NodeExpressionBinary.Op.GreaterThan:
            case NodeExpressionBinary.Op.GreaterThanOrEqual:
            {
                if (left.Type is not NubTypeSInt and not NubTypeUInt)
                    throw new CompileException(Diagnostic.Error($"Unsupported type for left hand side of comparison: {left.Type}").At(fileName, left).Build());
                if (right.Type is not NubTypeSInt and not NubTypeUInt)
                    throw new CompileException(Diagnostic.Error($"Unsupported type for right hand side of comparison: {right.Type}").At(fileName, right).Build());
                type = new NubTypeBool();
                break;
            }
            case NodeExpressionBinary.Op.LogicalAnd:
            case NodeExpressionBinary.Op.LogicalOr:
            {
                if (left.Type is not NubTypeBool)
                    throw new CompileException(Diagnostic.Error($"Unsupported type for left hand side of logical operation: {left.Type}").At(fileName, left).Build());
                if (right.Type is not NubTypeBool)
                    throw new CompileException(Diagnostic.Error($"Unsupported type for right hand side of logical operation: {right.Type}").At(fileName, right).Build());
                type = new NubTypeBool();
                break;
            }
            default:
                throw new ArgumentOutOfRangeException();
        }

        return new TypedNodeExpressionBinary(expression.Tokens, type, left, CheckExpressionBinaryOperation(expression.Operation), right);
    }

    /// <summary>Maps an untyped binary operator to its typed-tree counterpart (one-to-one).</summary>
    private static TypedNodeExpressionBinary.Op CheckExpressionBinaryOperation(NodeExpressionBinary.Op op)
    {
        return op switch
        {
            NodeExpressionBinary.Op.Add => TypedNodeExpressionBinary.Op.Add,
            NodeExpressionBinary.Op.Subtract => TypedNodeExpressionBinary.Op.Subtract,
            NodeExpressionBinary.Op.Multiply => TypedNodeExpressionBinary.Op.Multiply,
            NodeExpressionBinary.Op.Divide => TypedNodeExpressionBinary.Op.Divide,
            NodeExpressionBinary.Op.Modulo => TypedNodeExpressionBinary.Op.Modulo,
            NodeExpressionBinary.Op.Equal => TypedNodeExpressionBinary.Op.Equal,
            NodeExpressionBinary.Op.NotEqual => TypedNodeExpressionBinary.Op.NotEqual,
            NodeExpressionBinary.Op.LessThan => TypedNodeExpressionBinary.Op.LessThan,
            NodeExpressionBinary.Op.LessThanOrEqual => TypedNodeExpressionBinary.Op.LessThanOrEqual,
            NodeExpressionBinary.Op.GreaterThan => TypedNodeExpressionBinary.Op.GreaterThan,
            NodeExpressionBinary.Op.GreaterThanOrEqual => TypedNodeExpressionBinary.Op.GreaterThanOrEqual,
            NodeExpressionBinary.Op.LeftShift => TypedNodeExpressionBinary.Op.LeftShift,
            NodeExpressionBinary.Op.RightShift => TypedNodeExpressionBinary.Op.RightShift,
            NodeExpressionBinary.Op.LogicalAnd => TypedNodeExpressionBinary.Op.LogicalAnd,
            NodeExpressionBinary.Op.LogicalOr => TypedNodeExpressionBinary.Op.LogicalOr,
            _ => throw new ArgumentOutOfRangeException(nameof(op), op, null)
        };
    }

    /// <summary>Checks a unary expression: negation requires an integer operand, inversion a bool operand.</summary>
    private TypedNodeExpressionUnary CheckExpressionUnary(NodeExpressionUnary expression)
    {
        var target = CheckExpression(expression.Target);
        NubType type;

        switch (expression.Operation)
        {
            case NodeExpressionUnary.Op.Negate:
            {
                // NOTE(review): unsigned operands are accepted for negation — confirm intended.
                if (target.Type is not NubTypeSInt and not NubTypeUInt)
                    throw new CompileException(Diagnostic.Error($"Unsupported type for negation: {target.Type}").At(fileName, target).Build());
                type = target.Type;
                break;
            }
            case NodeExpressionUnary.Op.Invert:
            {
                if (target.Type is not NubTypeBool)
                    throw new CompileException(Diagnostic.Error($"Unsupported type for inversion: {target.Type}").At(fileName, target).Build());
                type = new NubTypeBool();
                break;
            }
            default:
                throw new ArgumentOutOfRangeException();
        }

        return new TypedNodeExpressionUnary(expression.Tokens, type, target, CheckExpressionUnaryOperation(expression.Operation));
    }

    /// <summary>Maps an untyped unary operator to its typed-tree counterpart.</summary>
    private static TypedNodeExpressionUnary.Op CheckExpressionUnaryOperation(NodeExpressionUnary.Op op)
    {
        return op switch
        {
            NodeExpressionUnary.Op.Negate => TypedNodeExpressionUnary.Op.Negate,
            NodeExpressionUnary.Op.Invert => TypedNodeExpressionUnary.Op.Invert,
            _ => throw new ArgumentOutOfRangeException(nameof(op), op, null)
        };
    }

    private TypedNodeExpressionBoolLiteral CheckExpressionBoolLiteral(NodeExpressionBoolLiteral expression)
    {
        return new TypedNodeExpressionBoolLiteral(expression.Tokens, new NubTypeBool(), expression.Value);
    }

    /// <summary>Looks the identifier up through the scope chain, innermost first.</summary>
    private TypedNodeExpressionIdent CheckExpressionIdent(NodeExpressionIdent expression)
    {
        var type = scope.GetIdentifierType(expression.Value.Ident);
        if (type == null)
            throw new CompileException(Diagnostic.Error($"Identifier '{expression.Value.Ident}' is not declared").At(fileName, expression.Value).Build());

        return new TypedNodeExpressionIdent(expression.Tokens, type, expression.Value);
    }

    /// <summary>Integer literals are always typed as a 32-bit signed integer.</summary>
    private TypedNodeExpressionIntLiteral CheckExpressionIntLiteral(NodeExpressionIntLiteral expression)
    {
        return new TypedNodeExpressionIntLiteral(expression.Tokens, new NubTypeSInt(32), expression.Value);
    }

    /// <summary>Checks a struct member access; the result has the accessed field's type.</summary>
    private TypedNodeExpressionMemberAccess CheckExpressionMemberAccess(NodeExpressionMemberAccess expression)
    {
        var target = CheckExpression(expression.Target);
        if (target.Type is not NubTypeStruct structType)
            throw new CompileException(Diagnostic.Error($"Cannot access member of non-struct type {target.Type}").At(fileName, target).Build());

        var field = structType.Fields.FirstOrDefault(x => x.Name == expression.Name.Ident);
        if (field == null)
            throw new CompileException(Diagnostic.Error($"Struct {target.Type} does not have a field matching the name '{expression.Name.Ident}'").At(fileName, expression.Name).Build());

        return new TypedNodeExpressionMemberAccess(expression.Tokens, field.Type, target, expression.Name);
    }

    private TypedNodeExpressionStringLiteral CheckExpressionStringLiteral(NodeExpressionStringLiteral expression)
    {
        return new TypedNodeExpressionStringLiteral(expression.Tokens, new NubTypeString(), expression.Value);
    }

    /// <summary>
    /// Checks a struct literal: the struct must be declared, every initializer must name an existing
    /// field exactly once, and each value's type must match its field's type.
    /// Fields without an initializer are not reported here.
    /// </summary>
    private TypedNodeExpressionStructLiteral CheckExpressionStructLiteral(NodeExpressionStructLiteral expression)
    {
        var type = typeResolver.GetNamedStruct(expression.Module.Ident, expression.Name.Ident);
        if (type == null)
            throw new CompileException(Diagnostic.Error($"Undeclared struct '{expression.Module.Ident}::{expression.Name.Ident}'").At(fileName, expression.Name).Build());

        var initializers = new List<TypedNodeExpressionStructLiteral.Initializer>();
        var initializedFields = new HashSet<string>();

        foreach (var initializer in expression.Initializers)
        {
            var field = type.Fields.FirstOrDefault(x => x.Name == initializer.Name.Ident);
            if (field == null)
                throw new CompileException(Diagnostic.Error($"Field '{initializer.Name.Ident}' does not exist on struct '{expression.Name.Ident}'").At(fileName, initializer.Name).Build());

            if (!initializedFields.Add(initializer.Name.Ident))
                throw new CompileException(Diagnostic.Error($"Field '{initializer.Name.Ident}' is initialized more than once").At(fileName, initializer.Name).Build());

            var value = CheckExpression(initializer.Value);
            if (value.Type != field.Type)
                throw new CompileException(Diagnostic.Error($"Type of assignment ({value.Type}) does not match expected type of field '{field.Name}' ({field.Type})").At(fileName, initializer.Name).Build());

            initializers.Add(new TypedNodeExpressionStructLiteral.Initializer(initializer.Tokens, initializer.Name, value));
        }

        return new TypedNodeExpressionStructLiteral(expression.Tokens, type, initializers);
    }

    /// <summary>A lexical scope mapping identifier names to types, chained to an optional parent.</summary>
    private class Scope(Scope? parent)
    {
        private readonly Dictionary<string, NubType> identifiers = new();

        /// <summary>Declares <paramref name="name"/> in this scope; returns false if it is already declared here.</summary>
        public bool TryDeclareIdentifier(string name, NubType type)
        {
            return identifiers.TryAdd(name, type);
        }

        /// <summary>Returns the type of <paramref name="name"/>, searching outward through parents; null if undeclared.</summary>
        public NubType? GetIdentifierType(string name)
        {
            return identifiers.TryGetValue(name, out var type) ? type : parent?.GetIdentifierType(name);
        }
    }
}

/// <summary>Base class of all typed-tree nodes; keeps the source tokens the node was built from.</summary>
public abstract class TypedNode(List<Token> tokens)
{
    public readonly List<Token> Tokens = tokens;
}

public abstract class TypedNodeDefinition(List<Token> tokens) : TypedNode(tokens);

public sealed class TypedNodeDefinitionFunc(List<Token> tokens, TokenIdent name, List<TypedNodeDefinitionFunc.Param> parameters, TypedNodeStatement body, NubType returnType) : TypedNodeDefinition(tokens)
{
    public readonly TokenIdent Name = name;
    public readonly List<Param> Parameters = parameters;
    public readonly TypedNodeStatement Body = body;
    public readonly NubType ReturnType = returnType;

    public sealed class Param(List<Token> tokens, TokenIdent name, NubType type) : TypedNode(tokens)
    {
        public readonly TokenIdent Name = name;
        public readonly NubType Type = type;
    }
}

public sealed class TypedNodeDefinitionStruct(List<Token> tokens, TokenIdent name, List<TypedNodeDefinitionStruct.Field> fields) : TypedNodeDefinition(tokens)
{
    public readonly TokenIdent Name = name;
    public readonly List<Field> Fields = fields;

    public sealed class Field(List<Token> tokens, TokenIdent name, NubType type) : TypedNode(tokens)
    {
        public readonly TokenIdent Name = name;
        public readonly NubType Type = type;
    }
}

public abstract class TypedNodeStatement(List<Token> tokens) : TypedNode(tokens);

public sealed class TypedNodeStatementBlock(List<Token> tokens, List<TypedNodeStatement> statements) : TypedNodeStatement(tokens)
{
    public readonly List<TypedNodeStatement> Statements = statements;
}

public sealed class TypedNodeStatementFuncCall(List<Token> tokens, TypedNodeExpression target, List<TypedNodeExpression> parameters) : TypedNodeStatement(tokens)
{
    public readonly TypedNodeExpression Target = target;
    public readonly List<TypedNodeExpression> Parameters = parameters;
}

public sealed class TypedNodeStatementReturn(List<Token> tokens, TypedNodeExpression value) : TypedNodeStatement(tokens)
{
    public readonly TypedNodeExpression Value = value;
}

public sealed class TypedNodeStatementVariableDeclaration(List<Token> tokens, TokenIdent name, NubType type, TypedNodeExpression value) : TypedNodeStatement(tokens)
{
    public readonly TokenIdent Name = name;
    public readonly NubType Type = type;
    public readonly TypedNodeExpression Value = value;
}

public sealed class TypedNodeStatementAssignment(List<Token> tokens, TypedNodeExpression target, TypedNodeExpression value) : TypedNodeStatement(tokens)
{
    public readonly TypedNodeExpression Target = target;
    public readonly TypedNodeExpression Value = value;
}

public sealed class TypedNodeStatementIf(List<Token> tokens, TypedNodeExpression condition, TypedNodeStatement thenBlock, TypedNodeStatement? elseBlock) : TypedNodeStatement(tokens)
{
    public readonly TypedNodeExpression Condition = condition;
    public readonly TypedNodeStatement ThenBlock = thenBlock;
    public readonly TypedNodeStatement? ElseBlock = elseBlock;
}

public sealed class TypedNodeStatementWhile(List<Token> tokens, TypedNodeExpression condition, TypedNodeStatement block) : TypedNodeStatement(tokens)
{
    public readonly TypedNodeExpression Condition = condition;
    public readonly TypedNodeStatement Block = block;
}

/// <summary>Base class of typed expressions; every expression carries its checked type.</summary>
public abstract class TypedNodeExpression(List<Token> tokens, NubType type) : TypedNode(tokens)
{
    public readonly NubType Type = type;
}

public sealed class TypedNodeExpressionIntLiteral(List<Token> tokens, NubType type, TokenIntLiteral value) : TypedNodeExpression(tokens, type)
{
    public readonly TokenIntLiteral Value = value;
}

public sealed class TypedNodeExpressionStringLiteral(List<Token> tokens, NubType type, TokenStringLiteral value) : TypedNodeExpression(tokens, type)
{
    public readonly TokenStringLiteral Value = value;
}

public sealed class TypedNodeExpressionBoolLiteral(List<Token> tokens, NubType type, TokenBoolLiteral value) : TypedNodeExpression(tokens, type)
{
    public readonly TokenBoolLiteral Value = value;
}

public sealed class TypedNodeExpressionStructLiteral(List<Token> tokens, NubType type, List<TypedNodeExpressionStructLiteral.Initializer> initializers) : TypedNodeExpression(tokens, type)
{
    public readonly List<Initializer> Initializers = initializers;

    // Part of the typed tree, so it derives from TypedNode like every other typed node.
    public sealed class Initializer(List<Token> tokens, TokenIdent name, TypedNodeExpression value) : TypedNode(tokens)
    {
        public readonly TokenIdent Name = name;
        public readonly TypedNodeExpression Value = value;
    }
}

public sealed class TypedNodeExpressionMemberAccess(List<Token> tokens, NubType type, TypedNodeExpression target, TokenIdent name) : TypedNodeExpression(tokens, type)
{
    public readonly TypedNodeExpression Target = target;
    public readonly TokenIdent Name = name;
}

public sealed class TypedNodeExpressionIdent(List<Token> tokens, NubType type, TokenIdent value) : TypedNodeExpression(tokens, type)
{
    public readonly TokenIdent Value = value;
}

public sealed class TypedNodeExpressionBinary(List<Token> tokens, NubType type, TypedNodeExpression left, TypedNodeExpressionBinary.Op operation, TypedNodeExpression right) : TypedNodeExpression(tokens, type)
{
    public readonly TypedNodeExpression Left = left;
    public readonly Op Operation = operation;
    public readonly TypedNodeExpression Right = right;

    public enum Op
    {
        Add,
        Subtract,
        Multiply,
        Divide,
        Modulo,

        Equal,
        NotEqual,
        LessThan,
        LessThanOrEqual,
        GreaterThan,
        GreaterThanOrEqual,

        LeftShift,
        RightShift,

        // BitwiseAnd,
        // BitwiseXor,
        // BitwiseOr,

        LogicalAnd,
        LogicalOr,
    }
}

public sealed class TypedNodeExpressionUnary(List<Token> tokens, NubType type, TypedNodeExpression target, TypedNodeExpressionUnary.Op op) : TypedNodeExpression(tokens, type)
{
    public readonly TypedNodeExpression Target = target;
    public readonly Op Operation = op;

    public enum Op
    {
        Negate,
        Invert,
    }
}