From 99e4543e745a6259f83d1acf2ba0919501b96b17 Mon Sep 17 00:00:00 2001 From: nub31 Date: Fri, 16 May 2025 22:14:41 +0200 Subject: [PATCH] type checker --- example/program.nub | 2 +- src/compiler/Nub.Lang/Backend/Generator.cs | 23 +- .../Frontend/Typing/ExpressionTyper.cs | 292 ------------- .../Nub.Lang/Frontend/Typing/TypeChecker.cs | 412 ++++++++++++++++++ src/compiler/Nub.Lang/Program.cs | 4 +- 5 files changed, 418 insertions(+), 315 deletions(-) delete mode 100644 src/compiler/Nub.Lang/Frontend/Typing/ExpressionTyper.cs create mode 100644 src/compiler/Nub.Lang/Frontend/Typing/TypeChecker.cs diff --git a/example/program.nub b/example/program.nub index eb398cf..aa00e3a 100644 --- a/example/program.nub +++ b/example/program.nub @@ -3,6 +3,6 @@ import c global func main(argc: i64, argv: i64) { printf("args: %d, starts at %p\n", argc, argv) - x: i8 = 320000 + x: i8 = (i8)320000 printf("%d\n", x) } \ No newline at end of file diff --git a/src/compiler/Nub.Lang/Backend/Generator.cs b/src/compiler/Nub.Lang/Backend/Generator.cs index 9e72769..1885ef3 100644 --- a/src/compiler/Nub.Lang/Backend/Generator.cs +++ b/src/compiler/Nub.Lang/Backend/Generator.cs @@ -14,7 +14,6 @@ public class Generator private readonly Stack _breakLabels = new(); private readonly Stack _continueLabels = new(); private bool _codeIsReachable = true; - private LocalFuncDefinitionNode? 
_currentFuncDefininition; public Generator(List definitions) { @@ -218,7 +217,6 @@ public class Generator private void GenerateFuncDefinition(LocalFuncDefinitionNode node) { - _currentFuncDefininition = node; _variables.Clear(); if (node.Global) @@ -290,7 +288,6 @@ public class Generator } _builder.AppendLine("}"); - _currentFuncDefininition = null; } private void GenerateStructDefinition(StructDefinitionNode structDefinition) @@ -370,8 +367,7 @@ public class Generator } var parameter = funcCall.Parameters[i]; - var parameterOutput = GenerateExpression(parameter); - var result = GenerateTypeConversion(parameterOutput, parameter.Type, expectedType); + var result = GenerateExpression(parameter); var qbeParameterType = SQT(expectedType.Equals(NubPrimitiveType.Any) ? parameter.Type : expectedType); parameterStrings.Add($"{qbeParameterType} {result}"); @@ -436,14 +432,8 @@ public class Generator { if (@return.Value.HasValue) { - if (!_currentFuncDefininition!.ReturnType.HasValue) - { - throw new Exception("Cannot return a value when function does not have a return value"); - } - var result = GenerateExpression(@return.Value.Value); - var converted = GenerateTypeConversion(result, @return.Value.Value.Type, _currentFuncDefininition.ReturnType.Value); - _builder.AppendLine($" ret {converted}"); + _builder.AppendLine($" ret {result}"); } else { @@ -454,17 +444,10 @@ public class Generator private void GenerateVariableAssignment(VariableAssignmentNode variableAssignment) { var result = GenerateExpression(variableAssignment.Value); - var variableType = variableAssignment.Value.Type; - - if (variableAssignment.ExplicitType.HasValue) - { - result = GenerateTypeConversion(result, variableType, variableAssignment.ExplicitType.Value); - variableType = variableAssignment.ExplicitType.Value; - } _variables[variableAssignment.Name] = new Variable { Identifier = result, - Type = variableType + Type = variableAssignment.Value.Type }; } diff --git 
a/src/compiler/Nub.Lang/Frontend/Typing/ExpressionTyper.cs b/src/compiler/Nub.Lang/Frontend/Typing/ExpressionTyper.cs deleted file mode 100644 index 09fb4b3..0000000 --- a/src/compiler/Nub.Lang/Frontend/Typing/ExpressionTyper.cs +++ /dev/null @@ -1,292 +0,0 @@ -using Nub.Lang.Frontend.Parsing; - -namespace Nub.Lang.Frontend.Typing; - -public class Func(string name, List parameters, Optional body, Optional returnType) -{ - public string Name { get; } = name; - public List Parameters { get; } = parameters; - public Optional Body { get; } = body; - public Optional ReturnType { get; } = returnType; -} - -public class ExpressionTyper -{ - private readonly List _functions; - private readonly List _structDefinitions; - private readonly Stack _variables; - - public ExpressionTyper(List definitions) - { - _variables = new Stack(); - _functions = []; - - _structDefinitions = definitions.OfType().ToList(); - - var functions = definitions - .OfType() - .Select(f => new Func(f.Name, f.Parameters, f.Body, f.ReturnType)) - .ToList(); - - var externFunctions = definitions - .OfType() - .Select(f => new Func(f.Name, f.Parameters, Optional.Empty(), f.ReturnType)) - .ToList(); - - _functions.AddRange(functions); - _functions.AddRange(externFunctions); - } - - public void Populate() - { - _variables.Clear(); - - foreach (var @class in _structDefinitions) - { - foreach (var variable in @class.Fields) - { - if (variable.Value.HasValue) - { - PopulateExpression(variable.Value.Value); - } - } - } - - foreach (var function in _functions) - { - foreach (var parameter in function.Parameters) - { - _variables.Push(new Variable(parameter.Name, parameter.Type)); - } - - if (function.Body.HasValue) - { - PopulateBlock(function.Body.Value); - } - for (var i = 0; i < function.Parameters.Count; i++) - { - _variables.Pop(); - } - } - } - - private void PopulateBlock(BlockNode block) - { - var variableCount = _variables.Count; - foreach (var statement in block.Statements) - { - 
PopulateStatement(statement); - } - while (_variables.Count > variableCount) - { - _variables.Pop(); - } - } - - private void PopulateStatement(StatementNode statement) - { - switch (statement) - { - case BreakNode: - case ContinueNode: - break; - case FuncCallStatementNode funcCall: - PopulateFuncCallStatement(funcCall); - break; - case IfNode ifStatement: - PopulateIf(ifStatement); - break; - case ReturnNode returnNode: - PopulateReturn(returnNode); - break; - case VariableAssignmentNode variableAssignment: - PopulateVariableAssignment(variableAssignment); - break; - case WhileNode whileStatement: - PopulateWhileStatement(whileStatement); - break; - default: - throw new ArgumentOutOfRangeException(nameof(statement)); - } - } - - private void PopulateFuncCallStatement(FuncCallStatementNode funcCall) - { - foreach (var parameter in funcCall.FuncCall.Parameters) - { - PopulateExpression(parameter); - } - } - - private void PopulateIf(IfNode ifStatement) - { - PopulateExpression(ifStatement.Condition); - PopulateBlock(ifStatement.Body); - if (ifStatement.Else.HasValue) - { - ifStatement.Else.Value.Match - ( - PopulateIf, - PopulateBlock - ); - } - } - - private void PopulateReturn(ReturnNode returnNode) - { - if (returnNode.Value.HasValue) - { - PopulateExpression(returnNode.Value.Value); - } - } - - private void PopulateVariableAssignment(VariableAssignmentNode variableAssignment) - { - PopulateExpression(variableAssignment.Value); - _variables.Push(new Variable(variableAssignment.Name, variableAssignment.ExplicitType.HasValue ? 
variableAssignment.ExplicitType.Value : variableAssignment.Value.Type)); - } - - private void PopulateVariableReassignment(VariableAssignmentNode variableAssignment) - { - PopulateExpression(variableAssignment.Value); - } - - private void PopulateWhileStatement(WhileNode whileStatement) - { - PopulateExpression(whileStatement.Condition); - PopulateBlock(whileStatement.Body); - } - - private void PopulateExpression(ExpressionNode expression) - { - switch (expression) - { - case BinaryExpressionNode binaryExpression: - PopulateBinaryExpression(binaryExpression); - break; - case FuncCallExpressionNode funcCall: - PopulateFuncCallExpression(funcCall); - break; - case IdentifierNode identifier: - PopulateIdentifier(identifier); - break; - case LiteralNode literal: - PopulateLiteral(literal); - break; - case StructInitializerNode structInitializer: - PopulateStructInitializer(structInitializer); - break; - case StructFieldAccessorNode structMemberAccessor: - PopulateStructMemberAccessorNode(structMemberAccessor); - break; - default: - throw new ArgumentOutOfRangeException(nameof(expression)); - } - } - - private void PopulateBinaryExpression(BinaryExpressionNode binaryExpression) - { - PopulateExpression(binaryExpression.Left); - PopulateExpression(binaryExpression.Right); - switch (binaryExpression.Operator) - { - case BinaryExpressionOperator.Equal: - case BinaryExpressionOperator.NotEqual: - case BinaryExpressionOperator.GreaterThan: - case BinaryExpressionOperator.GreaterThanOrEqual: - case BinaryExpressionOperator.LessThan: - case BinaryExpressionOperator.LessThanOrEqual: - { - binaryExpression.Type = NubPrimitiveType.Bool; - break; - } - case BinaryExpressionOperator.Plus: - case BinaryExpressionOperator.Minus: - case BinaryExpressionOperator.Multiply: - case BinaryExpressionOperator.Divide: - { - binaryExpression.Type = binaryExpression.Left.Type; - break; - } - default: - { - throw new ArgumentOutOfRangeException(nameof(binaryExpression.Operator)); - } - } - } - 
- private void PopulateFuncCallExpression(FuncCallExpressionNode funcCall) - { - foreach (var parameter in funcCall.FuncCall.Parameters) - { - PopulateExpression(parameter); - } - - var function = _functions.FirstOrDefault(f => f.Name == funcCall.FuncCall.Name); - if (function == null) - { - throw new Exception($"Func {funcCall} is not defined"); - } - if (!function.ReturnType.HasValue) - { - throw new Exception($"Func {funcCall} must have a return type when used in an expression"); - } - funcCall.Type = function.ReturnType.Value; - } - - private void PopulateIdentifier(IdentifierNode identifier) - { - var type = _variables.FirstOrDefault(v => v.Name == identifier.Identifier)?.Type; - if (type == null) - { - throw new Exception($"Variable {identifier} is not defined"); - } - identifier.Type = type; - } - - private static void PopulateLiteral(LiteralNode literal) - { - literal.Type = literal.LiteralType; - } - - private void PopulateStructInitializer(StructInitializerNode structInitializer) - { - foreach (var initializer in structInitializer.Initializers) - { - PopulateExpression(initializer.Value); - } - - structInitializer.Type = structInitializer.StructType; - } - - private void PopulateStructMemberAccessorNode(StructFieldAccessorNode structFieldAccessor) - { - PopulateExpression(structFieldAccessor.Struct); - - var structType = structFieldAccessor.Struct.Type; - if (structType == null) - { - throw new Exception($"Cannot access field on non-struct type: {structFieldAccessor.Struct}"); - } - - var structDefinition = _structDefinitions.FirstOrDefault(s => s.Name == structType.Name); - if (structDefinition == null) - { - throw new Exception($"Struct {structType.Name} is not defined"); - } - - var field = structDefinition.Fields.FirstOrDefault(f => f.Name == structFieldAccessor.Field); - if (field == null) - { - throw new Exception($"Field {structFieldAccessor.Field} is not defined in struct {structType.Name}"); - } - - structFieldAccessor.Type = field.Type; - } - - 
private class Variable(string name, NubType type) - { - public string Name { get; } = name; - public NubType Type { get; } = type; - } -} \ No newline at end of file diff --git a/src/compiler/Nub.Lang/Frontend/Typing/TypeChecker.cs b/src/compiler/Nub.Lang/Frontend/Typing/TypeChecker.cs new file mode 100644 index 0000000..03b5ec1 --- /dev/null +++ b/src/compiler/Nub.Lang/Frontend/Typing/TypeChecker.cs @@ -0,0 +1,412 @@ +using Nub.Lang.Frontend.Parsing; + +namespace Nub.Lang.Frontend.Typing; + +public class TypeCheckingException : Exception +{ + public TypeCheckingException(string message) : base(message) { } +} + +public class TypeChecker +{ + private readonly Dictionary _variables = new(); + private readonly Dictionary Parameters, Optional ReturnType)> _functions = new(); + private readonly Dictionary> _structs = new(); + private NubType? _currentFunctionReturnType; + private bool _hasReturnStatement; + + public void TypeCheck(List definitions) + { + CollectDefinitions(definitions); + + foreach (var definition in definitions) + { + if (definition is LocalFuncDefinitionNode funcDef) + { + TypeCheckFunction(funcDef); + } + } + } + + private void CollectDefinitions(List definitions) + { + foreach (var definition in definitions) + { + switch (definition) + { + case StructDefinitionNode structDef: + RegisterStruct(structDef); + break; + case LocalFuncDefinitionNode funcDef: + RegisterFunction(funcDef); + break; + case ExternFuncDefinitionNode externFuncDef: + RegisterExternFunction(externFuncDef); + break; + } + } + } + + private void RegisterStruct(StructDefinitionNode structDef) + { + var fields = new Dictionary(); + foreach (var field in structDef.Fields) + { + if (fields.ContainsKey(field.Name)) + { + throw new TypeCheckingException($"Duplicate field '{field.Name}' in struct '{structDef.Name}'"); + } + fields[field.Name] = field.Type; + } + _structs[structDef.Name] = fields; + } + + private void RegisterFunction(LocalFuncDefinitionNode funcDef) + { + 
_functions[funcDef.Name] = (funcDef.Parameters, funcDef.ReturnType); + } + + private void RegisterExternFunction(ExternFuncDefinitionNode funcDef) + { + _functions[funcDef.Name] = (funcDef.Parameters, funcDef.ReturnType); + } + + private void TypeCheckFunction(LocalFuncDefinitionNode funcDef) + { + _variables.Clear(); + _currentFunctionReturnType = funcDef.ReturnType.HasValue ? funcDef.ReturnType.Value : null; + _hasReturnStatement = false; + + foreach (var param in funcDef.Parameters) + { + _variables[param.Name] = param.Type; + } + + TypeCheckBlock(funcDef.Body); + + if (_currentFunctionReturnType != null && !_hasReturnStatement) + { + throw new TypeCheckingException($"Function '{funcDef.Name}' must return a value of type '{_currentFunctionReturnType}'"); + } + } + + private void TypeCheckBlock(BlockNode block) + { + foreach (var statement in block.Statements) + { + TypeCheckStatement(statement); + } + } + + private void TypeCheckStatement(StatementNode statement) + { + switch (statement) + { + case VariableAssignmentNode varAssign: + TypeCheckVariableAssignment(varAssign); + break; + case FuncCallStatementNode funcCall: + TypeCheckFuncCall(funcCall.FuncCall); + break; + case IfNode ifNode: + TypeCheckIf(ifNode); + break; + case WhileNode whileNode: + TypeCheckWhile(whileNode); + break; + case ReturnNode returnNode: + TypeCheckReturn(returnNode); + break; + case BreakNode: + case ContinueNode: + break; + default: + throw new TypeCheckingException($"Unsupported statement type: {statement.GetType().Name}"); + } + } + + private void TypeCheckVariableAssignment(VariableAssignmentNode varAssign) + { + var valueType = TypeCheckExpression(varAssign.Value); + + if (varAssign.ExplicitType.HasValue) + { + var explicitType = varAssign.ExplicitType.Value; + if (!AreTypesCompatible(valueType, explicitType)) + { + throw new TypeCheckingException($"Cannot assign expression of type '{valueType}' to variable '{varAssign.Name}' of type '{explicitType}'"); + } + 
_variables[varAssign.Name] = explicitType; + } + else + { + _variables[varAssign.Name] = valueType; + } + } +
+    /// <summary>
+    /// Type checks a function call against the signature registered for its name.
+    /// Every argument expression is visited, so each one gets its <c>Type</c> assigned
+    /// by <see cref="TypeCheckExpression"/>, including arguments bound to a trailing
+    /// variadic parameter.
+    /// </summary>
+    /// <param name="funcCall">The call to check (callee name plus argument expressions).</param>
+    /// <returns>
+    /// The callee's declared return type, or <c>NubPrimitiveType.Any</c> when the callee
+    /// declares no return type.
+    /// </returns>
+    /// <exception cref="TypeCheckingException">
+    /// The callee is unknown, a variadic parameter is not the last parameter, the number of
+    /// arguments does not fit the signature, or an argument type is incompatible.
+    /// </exception>
+    private NubType TypeCheckFuncCall(FuncCall funcCall)
+    {
+        if (!_functions.TryGetValue(funcCall.Name, out var funcSignature))
+        {
+            throw new TypeCheckingException($"Function '{funcCall.Name}' is not defined");
+        }
+
+        var paramTypes = funcSignature.Parameters;
+
+        // Only the last parameter may be variadic.
+        if (paramTypes.Take(paramTypes.Count - 1).Any(x => x.Variadic))
+        {
+            throw new TypeCheckingException($"Function '{funcCall.Name}' has multiple variadic parameters");
+        }
+
+        var isVariadic = paramTypes.Count > 0 && paramTypes[^1].Variadic;
+
+        // Reject calls that omit required arguments.
+        // NOTE(review): assumes a variadic parameter may receive zero arguments - confirm.
+        var requiredCount = isVariadic ? paramTypes.Count - 1 : paramTypes.Count;
+        if (funcCall.Parameters.Count < requiredCount)
+        {
+            throw new TypeCheckingException($"Function '{funcCall.Name}' does not take {funcCall.Parameters.Count} parameters");
+        }
+
+        for (var i = 0; i < funcCall.Parameters.Count; i++)
+        {
+            var argType = TypeCheckExpression(funcCall.Parameters[i]);
+
+            NubType paramType;
+            if (i < paramTypes.Count)
+            {
+                paramType = paramTypes[i].Type;
+            }
+            else if (isVariadic)
+            {
+                // Surplus arguments are checked against the variadic parameter's type.
+                // Do not return here: the remaining arguments must still be visited, and
+                // the call's result type is the function's return type, not this one.
+                paramType = paramTypes[^1].Type;
+            }
+            else
+            {
+                throw new TypeCheckingException($"Function '{funcCall.Name}' does not take {funcCall.Parameters.Count} parameters");
+            }
+
+            if (!AreTypesCompatible(argType, paramType))
+            {
+                throw new TypeCheckingException($"Parameter {i} of function '{funcCall.Name}' expects type '{paramType}', but got '{argType}'");
+            }
+        }
+
+        return funcSignature.ReturnType.HasValue ? funcSignature.ReturnType.Value : NubPrimitiveType.Any;
+    }
+
+ private void TypeCheckIf(IfNode ifNode) + { + var conditionType = TypeCheckExpression(ifNode.Condition); + if (!conditionType.Equals(NubPrimitiveType.Bool)) + { + throw new TypeCheckingException($"If condition must be a boolean expression, got '{conditionType}'"); + } + + TypeCheckBlock(ifNode.Body); + + if (ifNode.Else.HasValue) + { + var elseValue = ifNode.Else.Value; + elseValue.Match(TypeCheckIf, TypeCheckBlock); + } + } + + private void TypeCheckWhile(WhileNode whileNode) + { + var conditionType = TypeCheckExpression(whileNode.Condition); + if (!conditionType.Equals(NubPrimitiveType.Bool)) + { + throw new TypeCheckingException($"While condition must be a boolean expression, got '{conditionType}'"); + } + + TypeCheckBlock(whileNode.Body); + } + + private void TypeCheckReturn(ReturnNode returnNode) + { + _hasReturnStatement = true; + + if (returnNode.Value.HasValue) + { + var returnType = TypeCheckExpression(returnNode.Value.Value); + + if (_currentFunctionReturnType == null) + { + throw new TypeCheckingException("Cannot return a value from a function with no return type"); + } + + if (!AreTypesCompatible(returnType, _currentFunctionReturnType)) + { + throw new TypeCheckingException($"Return value of type '{returnType}' is not compatible with function return type '{_currentFunctionReturnType}'"); + } + } + else if (_currentFunctionReturnType != null) + { + throw new TypeCheckingException($"Function must return a value of type '{_currentFunctionReturnType}'"); + } + } + + private NubType TypeCheckExpression(ExpressionNode expression) + { + NubType resultType; + + switch (expression) + { + case LiteralNode literal: + resultType = literal.LiteralType; + break; + + case IdentifierNode identifier: + if (!_variables.TryGetValue(identifier.Identifier, out var varType)) + { + throw new TypeCheckingException($"Variable '{identifier.Identifier}' is not defined"); + } + resultType = varType; + break; + + case 
BinaryExpressionNode binaryExpr: + resultType = TypeCheckBinaryExpression(binaryExpr); + break; + + case FuncCallExpressionNode funcCallExpr: + resultType = TypeCheckFuncCall(funcCallExpr.FuncCall); + break; + + case StructInitializerNode structInit: + resultType = TypeCheckStructInitializer(structInit); + break; + + case StructFieldAccessorNode fieldAccess: + resultType = TypeCheckStructFieldAccess(fieldAccess); + break; + + default: + throw new TypeCheckingException($"Unsupported expression type: {expression.GetType().Name}"); + } + + expression.Type = resultType; + return resultType; + } + + private NubType TypeCheckBinaryExpression(BinaryExpressionNode binaryExpr) + { + var leftType = TypeCheckExpression(binaryExpr.Left); + var rightType = TypeCheckExpression(binaryExpr.Right); + + if (!leftType.Equals(rightType)) + { + throw new TypeCheckingException($"Left '{leftType}' and right '{rightType}' side of the binary expression is not equal"); + } + + switch (binaryExpr.Operator) + { + case BinaryExpressionOperator.Equal: + case BinaryExpressionOperator.NotEqual: + return NubPrimitiveType.Bool; + case BinaryExpressionOperator.GreaterThan: + case BinaryExpressionOperator.GreaterThanOrEqual: + case BinaryExpressionOperator.LessThan: + case BinaryExpressionOperator.LessThanOrEqual: + if (!IsNumeric(leftType)) + { + throw new TypeCheckingException($"Comparison operators require numeric operands, got '{leftType}' and '{rightType}'"); + } + return NubPrimitiveType.Bool; + case BinaryExpressionOperator.Plus: + case BinaryExpressionOperator.Minus: + case BinaryExpressionOperator.Multiply: + case BinaryExpressionOperator.Divide: + if (!IsNumeric(leftType)) + { + throw new TypeCheckingException($"Arithmetic operators require numeric operands, got '{leftType}' and '{rightType}'"); + } + return leftType; + default: + throw new TypeCheckingException($"Unsupported binary operator: {binaryExpr.Operator}"); + } + } + + private NubType 
TypeCheckStructInitializer(StructInitializerNode structInit) + { + var structType = structInit.StructType; + if (structType is not NubCustomType customType) + { + throw new TypeCheckingException($"Type '{structType}' is not a struct type"); + } + + if (!_structs.TryGetValue(customType.Name, out var fields)) + { + throw new TypeCheckingException($"Struct type '{customType.Name}' is not defined"); + } + + foreach (var initializer in structInit.Initializers) + { + if (!fields.TryGetValue(initializer.Key, out var fieldType)) + { + throw new TypeCheckingException($"Field '{initializer.Key}' does not exist in struct '{customType.Name}'"); + } + + var initializerType = TypeCheckExpression(initializer.Value); + if (!AreTypesCompatible(initializerType, fieldType)) + { + throw new TypeCheckingException($"Cannot initialize field '{initializer.Key}' of type '{fieldType}' with expression of type '{initializerType}'"); + } + } + + foreach (var field in fields) + { + if (!structInit.Initializers.ContainsKey(field.Key)) + { + throw new TypeCheckingException($"Field '{field.Key}' of struct '{customType.Name}' is not initialized"); + } + } + + return structType; + } + + private NubType TypeCheckStructFieldAccess(StructFieldAccessorNode fieldAccess) + { + var structType = TypeCheckExpression(fieldAccess.Struct); + + if (structType is not NubCustomType customType) + { + throw new TypeCheckingException($"Cannot access field '{fieldAccess.Field}' on non-struct type '{structType}'"); + } + + if (!_structs.TryGetValue(customType.Name, out var fields)) + { + throw new TypeCheckingException($"Struct type '{customType.Name}' is not defined"); + } + + if (!fields.TryGetValue(fieldAccess.Field, out var fieldType)) + { + throw new TypeCheckingException($"Field '{fieldAccess.Field}' does not exist in struct '{customType.Name}'"); + } + + return fieldType; + } + + #region Type Helper Methods + + private static bool AreTypesCompatible(NubType sourceType, NubType targetType) + { + return 
targetType.Equals(NubPrimitiveType.Any) || sourceType.Equals(targetType); + } + + private static bool IsNumeric(NubType type) + { + if (type is not NubPrimitiveType primitiveType) + { + return false; + } + + switch (primitiveType.Kind) + { + case PrimitiveTypeKind.I8: + case PrimitiveTypeKind.I16: + case PrimitiveTypeKind.I32: + case PrimitiveTypeKind.I64: + case PrimitiveTypeKind.U8: + case PrimitiveTypeKind.U16: + case PrimitiveTypeKind.U32: + case PrimitiveTypeKind.U64: + case PrimitiveTypeKind.F32: + case PrimitiveTypeKind.F64: + return true; + default: + return false; + } + } + + #endregion +} \ No newline at end of file diff --git a/src/compiler/Nub.Lang/Program.cs b/src/compiler/Nub.Lang/Program.cs index 955028a..a7b94cd 100644 --- a/src/compiler/Nub.Lang/Program.cs +++ b/src/compiler/Nub.Lang/Program.cs @@ -44,8 +44,8 @@ internal static class Program var modules = RunFrontend(input); var definitions = modules.SelectMany(f => f.Definitions).ToList(); - var typer = new ExpressionTyper(definitions); - typer.Populate(); + var typeChecker = new TypeChecker(); + typeChecker.TypeCheck(definitions); var generator = new Generator(definitions); var result = generator.Generate();