Files
nub-lang/compiler/TypeChecker.cs
nub31 96670b1201 ...
2026-02-09 19:34:47 +01:00

510 lines
22 KiB
C#

namespace Compiler;
/// <summary>
/// Type-checks a single parsed function definition and builds its typed tree.
/// Errors are collected as <see cref="Diagnostic"/>s rather than thrown to the caller.
/// </summary>
public sealed class TypeChecker(string fileName, NodeDefinitionFunc function, TypeResolver typeResolver)
{
    /// <summary>
    /// Type-checks <paramref name="function"/>.
    /// </summary>
    /// <param name="fileName">File name attached to reported diagnostics.</param>
    /// <param name="function">The parsed function definition to check.</param>
    /// <param name="typeResolver">Resolves syntactic types and named structs.</param>
    /// <param name="diagnostics">Receives every error reported while checking.</param>
    /// <returns>
    /// The typed function, or <c>null</c> if a parameter, the return type or the body failed to check.
    /// </returns>
    public static TypedNodeDefinitionFunc? CheckFunction(string fileName, NodeDefinitionFunc function, TypeResolver typeResolver, out List<Diagnostic> diagnostics)
    {
        return new TypeChecker(fileName, function, typeResolver).CheckFunction(out diagnostics);
    }

    // Innermost lexical scope. Parameters are declared in the root scope;
    // every block statement pushes a child scope and pops it on exit.
    private Scope scope = new(null);

    // Resolved return type of the function being checked, used to validate return statements.
    // Null when the return type failed to resolve (that error has already been reported).
    private NubType? expectedReturnType;

    private TypedNodeDefinitionFunc? CheckFunction(out List<Diagnostic> diagnostics)
    {
        diagnostics = [];
        var parameters = new List<TypedNodeDefinitionFunc.Param>();
        var invalidParameter = false;
        TypedNodeStatement? body = null;
        NubType? returnType = null;

        // Each parameter is checked independently so one bad parameter does not hide errors in the others.
        foreach (var parameter in function.Parameters)
        {
            try
            {
                parameters.Add(CheckDefinitionFuncParameter(parameter));
            }
            catch (CompileException e)
            {
                diagnostics.Add(e.Diagnostic);
                invalidParameter = true;
            }
        }

        // Resolved before the body so that return statements can be checked against it.
        try
        {
            returnType = typeResolver.Resolve(function.ReturnType);
        }
        catch (CompileException e)
        {
            diagnostics.Add(e.Diagnostic);
        }

        expectedReturnType = returnType;

        try
        {
            body = CheckStatement(function.Body);
        }
        catch (CompileException e)
        {
            diagnostics.Add(e.Diagnostic);
        }

        if (body == null || returnType == null || invalidParameter)
            return null;

        return new TypedNodeDefinitionFunc(function.Tokens, function.Name, parameters, body, returnType);
    }

    // Resolves a parameter's type and declares it in the root scope so the body can reference it.
    private TypedNodeDefinitionFunc.Param CheckDefinitionFuncParameter(NodeDefinitionFunc.Param node)
    {
        var type = typeResolver.Resolve(node.Type);
        if (!scope.TryDeclareIdentifier(node.Name.Ident, type))
            throw new CompileException(Diagnostic.Error($"Parameter '{node.Name.Ident}' is declared more than once").At(fileName, node.Name).Build());
        return new TypedNodeDefinitionFunc.Param(node.Tokens, node.Name, type);
    }

    private TypedNodeStatement CheckStatement(NodeStatement node)
    {
        return node switch
        {
            NodeStatementAssignment statement => CheckStatementAssignment(statement),
            NodeStatementBlock statement => CheckStatementBlock(statement),
            NodeStatementFuncCall statement => CheckStatementFuncCall(statement),
            NodeStatementIf statement => CheckStatementIf(statement),
            NodeStatementReturn statement => CheckStatementReturn(statement),
            NodeStatementVariableDeclaration statement => CheckStatementVariableDeclaration(statement),
            NodeStatementWhile statement => CheckStatementWhile(statement),
            _ => throw new ArgumentOutOfRangeException(nameof(node))
        };
    }

    // The assigned value must have the same type as the assignment target.
    private TypedNodeStatementAssignment CheckStatementAssignment(NodeStatementAssignment statement)
    {
        var target = CheckExpression(statement.Target);
        var value = CheckExpression(statement.Value);
        if (target.Type != value.Type)
            throw new CompileException(Diagnostic.Error($"Type of assigned value ({value.Type}) does not match type of assignment target ({target.Type})").At(fileName, value).Build());
        return new TypedNodeStatementAssignment(statement.Tokens, target, value);
    }

    // A block introduces a new lexical scope: its declarations are invisible after the block,
    // and sibling blocks may reuse the same names.
    private TypedNodeStatementBlock CheckStatementBlock(NodeStatementBlock statement)
    {
        var enclosing = scope;
        scope = new Scope(enclosing);
        try
        {
            return new TypedNodeStatementBlock(statement.Tokens, statement.Statements.Select(CheckStatement).ToList());
        }
        finally
        {
            // Restore even when a CompileException propagates out of the block.
            scope = enclosing;
        }
    }

    private TypedNodeStatementFuncCall CheckStatementFuncCall(NodeStatementFuncCall statement)
    {
        return new TypedNodeStatementFuncCall(statement.Tokens, CheckExpression(statement.Target), statement.Parameters.Select(CheckExpression).ToList());
    }

    private TypedNodeStatementIf CheckStatementIf(NodeStatementIf statement)
    {
        var condition = CheckCondition(statement.Condition);
        var thenBlock = CheckStatement(statement.ThenBlock);
        var elseBlock = statement.ElseBlock == null ? null : CheckStatement(statement.ElseBlock);
        return new TypedNodeStatementIf(statement.Tokens, condition, thenBlock, elseBlock);
    }

    // The returned value must match the declared return type, when that type resolved successfully.
    private TypedNodeStatementReturn CheckStatementReturn(NodeStatementReturn statement)
    {
        var value = CheckExpression(statement.Value);
        if (expectedReturnType != null && value.Type != expectedReturnType)
            throw new CompileException(Diagnostic.Error($"Type of returned value ({value.Type}) does not match function return type ({expectedReturnType})").At(fileName, value).Build());
        return new TypedNodeStatementReturn(statement.Tokens, value);
    }

    // The initializer is checked before the name is declared, so it cannot refer to the variable itself.
    private TypedNodeStatementVariableDeclaration CheckStatementVariableDeclaration(NodeStatementVariableDeclaration statement)
    {
        var type = typeResolver.Resolve(statement.Type);
        var value = CheckExpression(statement.Value);
        if (type != value.Type)
            throw new CompileException(Diagnostic.Error($"Type of variable ({type}) does not match type of assigned value ({value.Type})").At(fileName, value).Build());
        if (!scope.TryDeclareIdentifier(statement.Name.Ident, type))
            throw new CompileException(Diagnostic.Error($"Identifier '{statement.Name.Ident}' is already declared in this scope").At(fileName, statement.Name).Build());
        return new TypedNodeStatementVariableDeclaration(statement.Tokens, statement.Name, type, value);
    }

    private TypedNodeStatementWhile CheckStatementWhile(NodeStatementWhile statement)
    {
        var condition = CheckCondition(statement.Condition);
        return new TypedNodeStatementWhile(statement.Tokens, condition, CheckStatement(statement.Block));
    }

    // Conditions of if/while must be bool, matching the bool-only operands of !, && and ||.
    private TypedNodeExpression CheckCondition(NodeExpression node)
    {
        var condition = CheckExpression(node);
        if (condition.Type is not NubTypeBool)
            throw new CompileException(Diagnostic.Error($"Unsupported type for condition: {condition.Type}").At(fileName, condition).Build());
        return condition;
    }

    private TypedNodeExpression CheckExpression(NodeExpression node)
    {
        return node switch
        {
            NodeExpressionBinary expression => CheckExpressionBinary(expression),
            NodeExpressionUnary expression => CheckExpressionUnary(expression),
            NodeExpressionBoolLiteral expression => CheckExpressionBoolLiteral(expression),
            NodeExpressionIdent expression => CheckExpressionIdent(expression),
            NodeExpressionIntLiteral expression => CheckExpressionIntLiteral(expression),
            NodeExpressionMemberAccess expression => CheckExpressionMemberAccess(expression),
            NodeExpressionStringLiteral expression => CheckExpressionStringLiteral(expression),
            NodeExpressionStructLiteral expression => CheckExpressionStructLiteral(expression),
            _ => throw new ArgumentOutOfRangeException(nameof(node))
        };
    }

    private TypedNodeExpressionBinary CheckExpressionBinary(NodeExpressionBinary expression)
    {
        var left = CheckExpression(expression.Left);
        var right = CheckExpression(expression.Right);
        NubType type;
        switch (expression.Operation)
        {
            case NodeExpressionBinary.Op.Add:
            case NodeExpressionBinary.Op.Subtract:
            case NodeExpressionBinary.Op.Multiply:
            case NodeExpressionBinary.Op.Divide:
            case NodeExpressionBinary.Op.Modulo:
            {
                if (left.Type is not NubTypeSInt and not NubTypeUInt)
                    throw new CompileException(Diagnostic.Error($"Unsupported type for left hand side arithmetic operation: {left.Type}").At(fileName, left).Build());
                if (right.Type is not NubTypeSInt and not NubTypeUInt)
                    throw new CompileException(Diagnostic.Error($"Unsupported type for right hand side arithmetic operation: {right.Type}").At(fileName, right).Build());
                // Result takes the left operand's type.
                type = left.Type;
                break;
            }
            case NodeExpressionBinary.Op.LeftShift:
            case NodeExpressionBinary.Op.RightShift:
            {
                if (left.Type is not NubTypeUInt)
                    throw new CompileException(Diagnostic.Error($"Unsupported type for left hand side of left/right shift operation: {left.Type}").At(fileName, left).Build());
                if (right.Type is not NubTypeUInt)
                    throw new CompileException(Diagnostic.Error($"Unsupported type for right hand side of left/right shift operation: {right.Type}").At(fileName, right).Build());
                type = left.Type;
                break;
            }
            case NodeExpressionBinary.Op.Equal:
            case NodeExpressionBinary.Op.NotEqual:
            case NodeExpressionBinary.Op.LessThan:
            case NodeExpressionBinary.Op.LessThanOrEqual:
            case NodeExpressionBinary.Op.GreaterThan:
            case NodeExpressionBinary.Op.GreaterThanOrEqual:
            {
                if (left.Type is not NubTypeSInt and not NubTypeUInt)
                    throw new CompileException(Diagnostic.Error($"Unsupported type for left hand side of comparison: {left.Type}").At(fileName, left).Build());
                if (right.Type is not NubTypeSInt and not NubTypeUInt)
                    throw new CompileException(Diagnostic.Error($"Unsupported type for right hand side of comparison: {right.Type}").At(fileName, right).Build());
                type = new NubTypeBool();
                break;
            }
            case NodeExpressionBinary.Op.LogicalAnd:
            case NodeExpressionBinary.Op.LogicalOr:
            {
                if (left.Type is not NubTypeBool)
                    throw new CompileException(Diagnostic.Error($"Unsupported type for left hand side of logical operation: {left.Type}").At(fileName, left).Build());
                if (right.Type is not NubTypeBool)
                    throw new CompileException(Diagnostic.Error($"Unsupported type for right hand side of logical operation: {right.Type}").At(fileName, right).Build());
                type = new NubTypeBool();
                break;
            }
            default:
                throw new ArgumentOutOfRangeException();
        }

        return new TypedNodeExpressionBinary(expression.Tokens, type, left, CheckExpressionBinaryOperation(expression.Operation), right);
    }

    // One-to-one mapping from parser operators to typed-tree operators.
    private static TypedNodeExpressionBinary.Op CheckExpressionBinaryOperation(NodeExpressionBinary.Op op)
    {
        return op switch
        {
            NodeExpressionBinary.Op.Add => TypedNodeExpressionBinary.Op.Add,
            NodeExpressionBinary.Op.Subtract => TypedNodeExpressionBinary.Op.Subtract,
            NodeExpressionBinary.Op.Multiply => TypedNodeExpressionBinary.Op.Multiply,
            NodeExpressionBinary.Op.Divide => TypedNodeExpressionBinary.Op.Divide,
            NodeExpressionBinary.Op.Modulo => TypedNodeExpressionBinary.Op.Modulo,
            NodeExpressionBinary.Op.Equal => TypedNodeExpressionBinary.Op.Equal,
            NodeExpressionBinary.Op.NotEqual => TypedNodeExpressionBinary.Op.NotEqual,
            NodeExpressionBinary.Op.LessThan => TypedNodeExpressionBinary.Op.LessThan,
            NodeExpressionBinary.Op.LessThanOrEqual => TypedNodeExpressionBinary.Op.LessThanOrEqual,
            NodeExpressionBinary.Op.GreaterThan => TypedNodeExpressionBinary.Op.GreaterThan,
            NodeExpressionBinary.Op.GreaterThanOrEqual => TypedNodeExpressionBinary.Op.GreaterThanOrEqual,
            NodeExpressionBinary.Op.LeftShift => TypedNodeExpressionBinary.Op.LeftShift,
            NodeExpressionBinary.Op.RightShift => TypedNodeExpressionBinary.Op.RightShift,
            NodeExpressionBinary.Op.LogicalAnd => TypedNodeExpressionBinary.Op.LogicalAnd,
            NodeExpressionBinary.Op.LogicalOr => TypedNodeExpressionBinary.Op.LogicalOr,
            _ => throw new ArgumentOutOfRangeException(nameof(op), op, null)
        };
    }

    private TypedNodeExpressionUnary CheckExpressionUnary(NodeExpressionUnary expression)
    {
        var target = CheckExpression(expression.Target);
        NubType type;
        switch (expression.Operation)
        {
            case NodeExpressionUnary.Op.Negate:
            {
                if (target.Type is not NubTypeSInt and not NubTypeUInt)
                    throw new CompileException(Diagnostic.Error($"Unsupported type for negation: {target.Type}").At(fileName, target).Build());
                type = target.Type;
                break;
            }
            case NodeExpressionUnary.Op.Invert:
            {
                if (target.Type is not NubTypeBool)
                    throw new CompileException(Diagnostic.Error($"Unsupported type for inversion: {target.Type}").At(fileName, target).Build());
                type = new NubTypeBool();
                break;
            }
            default:
                throw new ArgumentOutOfRangeException();
        }

        return new TypedNodeExpressionUnary(expression.Tokens, type, target, CheckExpressionUnaryOperation(expression.Operation));
    }

    private static TypedNodeExpressionUnary.Op CheckExpressionUnaryOperation(NodeExpressionUnary.Op op)
    {
        return op switch
        {
            NodeExpressionUnary.Op.Negate => TypedNodeExpressionUnary.Op.Negate,
            NodeExpressionUnary.Op.Invert => TypedNodeExpressionUnary.Op.Invert,
            _ => throw new ArgumentOutOfRangeException(nameof(op), op, null)
        };
    }

    private TypedNodeExpressionBoolLiteral CheckExpressionBoolLiteral(NodeExpressionBoolLiteral expression)
    {
        return new TypedNodeExpressionBoolLiteral(expression.Tokens, new NubTypeBool(), expression.Value);
    }

    private TypedNodeExpressionIdent CheckExpressionIdent(NodeExpressionIdent expression)
    {
        var type = scope.GetIdentifierType(expression.Value.Ident);
        if (type == null)
            throw new CompileException(Diagnostic.Error($"Identifier '{expression.Value.Ident}' is not declared").At(fileName, expression.Value).Build());
        return new TypedNodeExpressionIdent(expression.Tokens, type, expression.Value);
    }

    // Integer literals are currently always typed as a 32-bit signed integer.
    private TypedNodeExpressionIntLiteral CheckExpressionIntLiteral(NodeExpressionIntLiteral expression)
    {
        return new TypedNodeExpressionIntLiteral(expression.Tokens, new NubTypeSInt(32), expression.Value);
    }

    private TypedNodeExpressionMemberAccess CheckExpressionMemberAccess(NodeExpressionMemberAccess expression)
    {
        var target = CheckExpression(expression.Target);
        if (target.Type is not NubTypeStruct structType)
            throw new CompileException(Diagnostic.Error($"Cannot access member of non-struct type {target.Type}").At(fileName, target).Build());
        var field = structType.Fields.FirstOrDefault(x => x.Name == expression.Name.Ident);
        if (field == null)
            // Point at the offending member name rather than the whole target expression.
            throw new CompileException(Diagnostic.Error($"Struct {target.Type} does not have a field matching the name '{expression.Name.Ident}'").At(fileName, expression.Name).Build());
        return new TypedNodeExpressionMemberAccess(expression.Tokens, field.Type, target, expression.Name);
    }

    private TypedNodeExpressionStringLiteral CheckExpressionStringLiteral(NodeExpressionStringLiteral expression)
    {
        return new TypedNodeExpressionStringLiteral(expression.Tokens, new NubTypeString(), expression.Value);
    }

    private TypedNodeExpressionStructLiteral CheckExpressionStructLiteral(NodeExpressionStructLiteral expression)
    {
        var type = typeResolver.GetNamedStruct(expression.Module.Ident, expression.Name.Ident);
        if (type == null)
            throw new CompileException(Diagnostic.Error($"Undeclared struct '{expression.Module.Ident}::{expression.Name.Ident}'").At(fileName, expression.Name).Build());

        var initializers = new List<TypedNodeExpressionStructLiteral.Initializer>();
        foreach (var initializer in expression.Initializers)
        {
            var field = type.Fields.FirstOrDefault(x => x.Name == initializer.Name.Ident);
            if (field == null)
                throw new CompileException(Diagnostic.Error($"Field '{initializer.Name.Ident}' does not exist on struct '{expression.Name.Ident}'").At(fileName, initializer.Name).Build());
            var value = CheckExpression(initializer.Value);
            if (value.Type != field.Type)
                throw new CompileException(Diagnostic.Error($"Type of assignment ({value.Type}) does not match expected type of field '{field.Name}' ({field.Type})").At(fileName, initializer.Name).Build());
            initializers.Add(new TypedNodeExpressionStructLiteral.Initializer(initializer.Tokens, initializer.Name, value));
        }

        return new TypedNodeExpressionStructLiteral(expression.Tokens, type, initializers);
    }

    // A lexical scope mapping identifier names to types, chained to its enclosing scope.
    private sealed class Scope(Scope? parent)
    {
        private readonly Dictionary<string, NubType> identifiers = new();

        // Declares a name in this scope only. Returns false if it is already declared here;
        // shadowing a name from an enclosing scope is allowed.
        public bool TryDeclareIdentifier(string name, NubType type)
        {
            return identifiers.TryAdd(name, type);
        }

        // Looks a name up in this scope, then in each enclosing scope; null if not found anywhere.
        public NubType? GetIdentifierType(string name)
        {
            return identifiers.TryGetValue(name, out var type)
                ? type
                : parent?.GetIdentifierType(name);
        }
    }
}
/// <summary>
/// Common base of every node in the type-checked tree.
/// </summary>
/// <param name="tokens">Tokens the node spans, carried over from the parsed node.</param>
public abstract class TypedNode(List<Token> tokens)
{
    /// <summary>Tokens this node spans.</summary>
    public readonly List<Token> Tokens = tokens;
}
/// <summary>
/// Base class for typed top-level definitions such as functions and structs.
/// </summary>
public abstract class TypedNodeDefinition(List<Token> tokens) : TypedNode(tokens)
{
}
/// <summary>
/// A type-checked function definition.
/// </summary>
public sealed class TypedNodeDefinitionFunc(List<Token> tokens, TokenIdent name, List<TypedNodeDefinitionFunc.Param> parameters, TypedNodeStatement body, NubType returnType) : TypedNodeDefinition(tokens)
{
    /// <summary>Function name token.</summary>
    public readonly TokenIdent Name = name;

    /// <summary>Parameters with their resolved types.</summary>
    public readonly List<Param> Parameters = parameters;

    /// <summary>Checked function body.</summary>
    public readonly TypedNodeStatement Body = body;

    /// <summary>Resolved return type.</summary>
    public readonly NubType ReturnType = returnType;

    /// <summary>
    /// A single function parameter together with its resolved type.
    /// </summary>
    public sealed class Param(List<Token> tokens, TokenIdent name, NubType type) : TypedNode(tokens)
    {
        /// <summary>Parameter name token.</summary>
        public readonly TokenIdent Name = name;

        /// <summary>Resolved parameter type.</summary>
        public readonly NubType Type = type;
    }
}
/// <summary>
/// A type-checked struct definition.
/// </summary>
public sealed class TypedNodeDefinitionStruct(List<Token> tokens, TokenIdent name, List<TypedNodeDefinitionStruct.Field> fields) : TypedNodeDefinition(tokens)
{
    /// <summary>Struct name token.</summary>
    public readonly TokenIdent Name = name;

    /// <summary>Declared fields with their resolved types.</summary>
    public readonly List<Field> Fields = fields;

    /// <summary>
    /// A single struct field together with its resolved type.
    /// </summary>
    public sealed class Field(List<Token> tokens, TokenIdent name, NubType type) : TypedNode(tokens)
    {
        /// <summary>Field name token.</summary>
        public readonly TokenIdent Name = name;

        /// <summary>Resolved field type.</summary>
        public readonly NubType Type = type;
    }
}
/// <summary>
/// Base class for type-checked statements.
/// </summary>
public abstract class TypedNodeStatement(List<Token> tokens) : TypedNode(tokens)
{
}
/// <summary>
/// A braced sequence of checked statements.
/// </summary>
public sealed class TypedNodeStatementBlock(List<Token> tokens, List<TypedNodeStatement> statements) : TypedNodeStatement(tokens)
{
    /// <summary>Statements in source order.</summary>
    public readonly List<TypedNodeStatement> Statements = statements;
}
/// <summary>
/// A function call used as a statement.
/// </summary>
public sealed class TypedNodeStatementFuncCall(List<Token> tokens, TypedNodeExpression target, List<TypedNodeExpression> parameters) : TypedNodeStatement(tokens)
{
    /// <summary>Expression naming the function being called.</summary>
    public readonly TypedNodeExpression Target = target;

    /// <summary>Checked argument expressions.</summary>
    public readonly List<TypedNodeExpression> Parameters = parameters;
}
/// <summary>
/// A return statement carrying a checked value.
/// </summary>
public sealed class TypedNodeStatementReturn(List<Token> tokens, TypedNodeExpression value) : TypedNodeStatement(tokens)
{
    /// <summary>The returned expression.</summary>
    public readonly TypedNodeExpression Value = value;
}
/// <summary>
/// A variable declaration with an explicit type and an initial value.
/// </summary>
public sealed class TypedNodeStatementVariableDeclaration(List<Token> tokens, TokenIdent name, NubType type, TypedNodeExpression value) : TypedNodeStatement(tokens)
{
    /// <summary>Declared variable name.</summary>
    public readonly TokenIdent Name = name;

    /// <summary>Resolved declared type.</summary>
    public readonly NubType Type = type;

    /// <summary>Checked initializer expression.</summary>
    public readonly TypedNodeExpression Value = value;
}
/// <summary>
/// An assignment of a value to a target expression.
/// </summary>
public sealed class TypedNodeStatementAssignment(List<Token> tokens, TypedNodeExpression target, TypedNodeExpression value) : TypedNodeStatement(tokens)
{
    /// <summary>Expression being assigned to.</summary>
    public readonly TypedNodeExpression Target = target;

    /// <summary>Value being assigned.</summary>
    public readonly TypedNodeExpression Value = value;
}
/// <summary>
/// A conditional statement with an optional else branch.
/// </summary>
public sealed class TypedNodeStatementIf(List<Token> tokens, TypedNodeExpression condition, TypedNodeStatement thenBlock, TypedNodeStatement? elseBlock) : TypedNodeStatement(tokens)
{
    /// <summary>Checked condition expression.</summary>
    public readonly TypedNodeExpression Condition = condition;

    /// <summary>Statement executed when the condition holds.</summary>
    public readonly TypedNodeStatement ThenBlock = thenBlock;

    /// <summary>Statement for the else branch, or <c>null</c> when there is none.</summary>
    public readonly TypedNodeStatement? ElseBlock = elseBlock;
}
/// <summary>
/// A while loop.
/// </summary>
public sealed class TypedNodeStatementWhile(List<Token> tokens, TypedNodeExpression condition, TypedNodeStatement block) : TypedNodeStatement(tokens)
{
    /// <summary>Checked loop condition.</summary>
    public readonly TypedNodeExpression Condition = condition;

    /// <summary>Loop body.</summary>
    public readonly TypedNodeStatement Block = block;
}
/// <summary>
/// Base class for type-checked expressions; every expression carries its computed type.
/// </summary>
public abstract class TypedNodeExpression(List<Token> tokens, NubType type) : TypedNode(tokens)
{
    /// <summary>Type assigned to this expression during checking.</summary>
    public readonly NubType Type = type;
}
/// <summary>
/// An integer literal.
/// </summary>
public sealed class TypedNodeExpressionIntLiteral(List<Token> tokens, NubType type, TokenIntLiteral value) : TypedNodeExpression(tokens, type)
{
    /// <summary>The literal token.</summary>
    public readonly TokenIntLiteral Value = value;
}
/// <summary>
/// A string literal.
/// </summary>
public sealed class TypedNodeExpressionStringLiteral(List<Token> tokens, NubType type, TokenStringLiteral value) : TypedNodeExpression(tokens, type)
{
    /// <summary>The literal token.</summary>
    public readonly TokenStringLiteral Value = value;
}
/// <summary>
/// A boolean literal.
/// </summary>
public sealed class TypedNodeExpressionBoolLiteral(List<Token> tokens, NubType type, TokenBoolLiteral value) : TypedNodeExpression(tokens, type)
{
    /// <summary>The literal token.</summary>
    public readonly TokenBoolLiteral Value = value;
}
/// <summary>
/// A struct literal whose field initializers have been type-checked.
/// </summary>
public sealed class TypedNodeExpressionStructLiteral(List<Token> tokens, NubType type, List<TypedNodeExpressionStructLiteral.Initializer> initializers) : TypedNodeExpression(tokens, type)
{
    /// <summary>Field initializers in source order.</summary>
    public readonly List<Initializer> Initializers = initializers;

    /// <summary>
    /// A single <c>field = value</c> initializer.
    /// </summary>
    // Derives from TypedNode (not the parser's Node) so it belongs to the typed tree,
    // like the other nested typed nodes (TypedNodeDefinitionFunc.Param, TypedNodeDefinitionStruct.Field).
    public sealed class Initializer(List<Token> tokens, TokenIdent name, TypedNodeExpression value) : TypedNode(tokens)
    {
        /// <summary>Name of the field being initialized.</summary>
        public readonly TokenIdent Name = name;

        /// <summary>Checked value assigned to the field.</summary>
        public readonly TypedNodeExpression Value = value;
    }
}
/// <summary>
/// Access to a named field of a struct-typed expression.
/// </summary>
public sealed class TypedNodeExpressionMemberAccess(List<Token> tokens, NubType type, TypedNodeExpression target, TokenIdent name) : TypedNodeExpression(tokens, type)
{
    /// <summary>Expression whose member is accessed.</summary>
    public readonly TypedNodeExpression Target = target;

    /// <summary>Name of the accessed member.</summary>
    public readonly TokenIdent Name = name;
}
/// <summary>
/// A reference to a declared identifier.
/// </summary>
public sealed class TypedNodeExpressionIdent(List<Token> tokens, NubType type, TokenIdent value) : TypedNodeExpression(tokens, type)
{
    /// <summary>The identifier token.</summary>
    public readonly TokenIdent Value = value;
}
/// <summary>
/// A type-checked binary operation; the inherited <c>Type</c> is the operation's result type.
/// </summary>
public sealed class TypedNodeExpressionBinary(List<Token> tokens, NubType type, TypedNodeExpression left, TypedNodeExpressionBinary.Op operation, TypedNodeExpression right) : TypedNodeExpression(tokens, type)
{
    /// <summary>Left operand.</summary>
    public readonly TypedNodeExpression Left = left;

    /// <summary>Operator applied to the operands.</summary>
    public readonly Op Operation = operation;

    /// <summary>Right operand.</summary>
    public readonly TypedNodeExpression Right = right;

    /// <summary>
    /// Supported binary operators.
    /// </summary>
    public enum Op
    {
        // Arithmetic
        Add,
        Subtract,
        Multiply,
        Divide,
        Modulo,

        // Comparison
        Equal,
        NotEqual,
        LessThan,
        LessThanOrEqual,
        GreaterThan,
        GreaterThanOrEqual,

        // Shifts
        LeftShift,
        RightShift,

        // Bitwise and/xor/or are not supported yet.

        // Logical
        LogicalAnd,
        LogicalOr,
    }
}
/// <summary>
/// A type-checked unary operation; the inherited <c>Type</c> is the operation's result type.
/// </summary>
// Exposes readonly fields (not get-only properties) for consistency with every other typed node.
public sealed class TypedNodeExpressionUnary(List<Token> tokens, NubType type, TypedNodeExpression target, TypedNodeExpressionUnary.Op op) : TypedNodeExpression(tokens, type)
{
    /// <summary>Operand the operator is applied to.</summary>
    public readonly TypedNodeExpression Target = target;

    /// <summary>Operator applied to the operand.</summary>
    public readonly Op Operation = op;

    /// <summary>
    /// Supported unary operators.
    /// </summary>
    public enum Op
    {
        // Arithmetic negation of an integer operand.
        Negate,

        // Logical not of a bool operand.
        Invert,
    }
}