type checker

This commit is contained in:
nub31
2025-05-16 22:14:41 +02:00
parent 60f56a0b85
commit 0679eea140
5 changed files with 418 additions and 315 deletions

View File

@@ -3,6 +3,6 @@ import c
global func main(argc: i64, argv: i64) { global func main(argc: i64, argv: i64) {
printf("args: %d, starts at %p\n", argc, argv) printf("args: %d, starts at %p\n", argc, argv)
x: i8 = 320000 x: i8 = (i8)320000
printf("%d\n", x) printf("%d\n", x)
} }

View File

@@ -14,7 +14,6 @@ public class Generator
private readonly Stack<string> _breakLabels = new(); private readonly Stack<string> _breakLabels = new();
private readonly Stack<string> _continueLabels = new(); private readonly Stack<string> _continueLabels = new();
private bool _codeIsReachable = true; private bool _codeIsReachable = true;
private LocalFuncDefinitionNode? _currentFuncDefininition;
public Generator(List<DefinitionNode> definitions) public Generator(List<DefinitionNode> definitions)
{ {
@@ -218,7 +217,6 @@ public class Generator
private void GenerateFuncDefinition(LocalFuncDefinitionNode node) private void GenerateFuncDefinition(LocalFuncDefinitionNode node)
{ {
_currentFuncDefininition = node;
_variables.Clear(); _variables.Clear();
if (node.Global) if (node.Global)
@@ -290,7 +288,6 @@ public class Generator
} }
_builder.AppendLine("}"); _builder.AppendLine("}");
_currentFuncDefininition = null;
} }
private void GenerateStructDefinition(StructDefinitionNode structDefinition) private void GenerateStructDefinition(StructDefinitionNode structDefinition)
@@ -370,8 +367,7 @@ public class Generator
} }
var parameter = funcCall.Parameters[i]; var parameter = funcCall.Parameters[i];
var parameterOutput = GenerateExpression(parameter); var result = GenerateExpression(parameter);
var result = GenerateTypeConversion(parameterOutput, parameter.Type, expectedType);
var qbeParameterType = SQT(expectedType.Equals(NubPrimitiveType.Any) ? parameter.Type : expectedType); var qbeParameterType = SQT(expectedType.Equals(NubPrimitiveType.Any) ? parameter.Type : expectedType);
parameterStrings.Add($"{qbeParameterType} {result}"); parameterStrings.Add($"{qbeParameterType} {result}");
@@ -436,14 +432,8 @@ public class Generator
{ {
if (@return.Value.HasValue) if (@return.Value.HasValue)
{ {
if (!_currentFuncDefininition!.ReturnType.HasValue)
{
throw new Exception("Cannot return a value when function does not have a return value");
}
var result = GenerateExpression(@return.Value.Value); var result = GenerateExpression(@return.Value.Value);
var converted = GenerateTypeConversion(result, @return.Value.Value.Type, _currentFuncDefininition.ReturnType.Value); _builder.AppendLine($" ret {result}");
_builder.AppendLine($" ret {converted}");
} }
else else
{ {
@@ -454,17 +444,10 @@ public class Generator
private void GenerateVariableAssignment(VariableAssignmentNode variableAssignment) private void GenerateVariableAssignment(VariableAssignmentNode variableAssignment)
{ {
var result = GenerateExpression(variableAssignment.Value); var result = GenerateExpression(variableAssignment.Value);
var variableType = variableAssignment.Value.Type;
if (variableAssignment.ExplicitType.HasValue)
{
result = GenerateTypeConversion(result, variableType, variableAssignment.ExplicitType.Value);
variableType = variableAssignment.ExplicitType.Value;
}
_variables[variableAssignment.Name] = new Variable _variables[variableAssignment.Name] = new Variable
{ {
Identifier = result, Identifier = result,
Type = variableType Type = variableAssignment.Value.Type
}; };
} }

View File

@@ -1,292 +0,0 @@
using Nub.Lang.Frontend.Parsing;
namespace Nub.Lang.Frontend.Typing;
public class Func(string name, List<FuncParameter> parameters, Optional<BlockNode> body, Optional<NubType> returnType)
{
public string Name { get; } = name;
public List<FuncParameter> Parameters { get; } = parameters;
public Optional<BlockNode> Body { get; } = body;
public Optional<NubType> ReturnType { get; } = returnType;
}
public class ExpressionTyper
{
private readonly List<Func> _functions;
private readonly List<StructDefinitionNode> _structDefinitions;
private readonly Stack<Variable> _variables;
public ExpressionTyper(List<DefinitionNode> definitions)
{
_variables = new Stack<Variable>();
_functions = [];
_structDefinitions = definitions.OfType<StructDefinitionNode>().ToList();
var functions = definitions
.OfType<LocalFuncDefinitionNode>()
.Select(f => new Func(f.Name, f.Parameters, f.Body, f.ReturnType))
.ToList();
var externFunctions = definitions
.OfType<ExternFuncDefinitionNode>()
.Select(f => new Func(f.Name, f.Parameters, Optional<BlockNode>.Empty(), f.ReturnType))
.ToList();
_functions.AddRange(functions);
_functions.AddRange(externFunctions);
}
public void Populate()
{
_variables.Clear();
foreach (var @class in _structDefinitions)
{
foreach (var variable in @class.Fields)
{
if (variable.Value.HasValue)
{
PopulateExpression(variable.Value.Value);
}
}
}
foreach (var function in _functions)
{
foreach (var parameter in function.Parameters)
{
_variables.Push(new Variable(parameter.Name, parameter.Type));
}
if (function.Body.HasValue)
{
PopulateBlock(function.Body.Value);
}
for (var i = 0; i < function.Parameters.Count; i++)
{
_variables.Pop();
}
}
}
private void PopulateBlock(BlockNode block)
{
var variableCount = _variables.Count;
foreach (var statement in block.Statements)
{
PopulateStatement(statement);
}
while (_variables.Count > variableCount)
{
_variables.Pop();
}
}
private void PopulateStatement(StatementNode statement)
{
switch (statement)
{
case BreakNode:
case ContinueNode:
break;
case FuncCallStatementNode funcCall:
PopulateFuncCallStatement(funcCall);
break;
case IfNode ifStatement:
PopulateIf(ifStatement);
break;
case ReturnNode returnNode:
PopulateReturn(returnNode);
break;
case VariableAssignmentNode variableAssignment:
PopulateVariableAssignment(variableAssignment);
break;
case WhileNode whileStatement:
PopulateWhileStatement(whileStatement);
break;
default:
throw new ArgumentOutOfRangeException(nameof(statement));
}
}
private void PopulateFuncCallStatement(FuncCallStatementNode funcCall)
{
foreach (var parameter in funcCall.FuncCall.Parameters)
{
PopulateExpression(parameter);
}
}
private void PopulateIf(IfNode ifStatement)
{
PopulateExpression(ifStatement.Condition);
PopulateBlock(ifStatement.Body);
if (ifStatement.Else.HasValue)
{
ifStatement.Else.Value.Match
(
PopulateIf,
PopulateBlock
);
}
}
private void PopulateReturn(ReturnNode returnNode)
{
if (returnNode.Value.HasValue)
{
PopulateExpression(returnNode.Value.Value);
}
}
private void PopulateVariableAssignment(VariableAssignmentNode variableAssignment)
{
PopulateExpression(variableAssignment.Value);
_variables.Push(new Variable(variableAssignment.Name, variableAssignment.ExplicitType.HasValue ? variableAssignment.ExplicitType.Value : variableAssignment.Value.Type));
}
private void PopulateVariableReassignment(VariableAssignmentNode variableAssignment)
{
PopulateExpression(variableAssignment.Value);
}
private void PopulateWhileStatement(WhileNode whileStatement)
{
PopulateExpression(whileStatement.Condition);
PopulateBlock(whileStatement.Body);
}
private void PopulateExpression(ExpressionNode expression)
{
switch (expression)
{
case BinaryExpressionNode binaryExpression:
PopulateBinaryExpression(binaryExpression);
break;
case FuncCallExpressionNode funcCall:
PopulateFuncCallExpression(funcCall);
break;
case IdentifierNode identifier:
PopulateIdentifier(identifier);
break;
case LiteralNode literal:
PopulateLiteral(literal);
break;
case StructInitializerNode structInitializer:
PopulateStructInitializer(structInitializer);
break;
case StructFieldAccessorNode structMemberAccessor:
PopulateStructMemberAccessorNode(structMemberAccessor);
break;
default:
throw new ArgumentOutOfRangeException(nameof(expression));
}
}
private void PopulateBinaryExpression(BinaryExpressionNode binaryExpression)
{
PopulateExpression(binaryExpression.Left);
PopulateExpression(binaryExpression.Right);
switch (binaryExpression.Operator)
{
case BinaryExpressionOperator.Equal:
case BinaryExpressionOperator.NotEqual:
case BinaryExpressionOperator.GreaterThan:
case BinaryExpressionOperator.GreaterThanOrEqual:
case BinaryExpressionOperator.LessThan:
case BinaryExpressionOperator.LessThanOrEqual:
{
binaryExpression.Type = NubPrimitiveType.Bool;
break;
}
case BinaryExpressionOperator.Plus:
case BinaryExpressionOperator.Minus:
case BinaryExpressionOperator.Multiply:
case BinaryExpressionOperator.Divide:
{
binaryExpression.Type = binaryExpression.Left.Type;
break;
}
default:
{
throw new ArgumentOutOfRangeException(nameof(binaryExpression.Operator));
}
}
}
private void PopulateFuncCallExpression(FuncCallExpressionNode funcCall)
{
foreach (var parameter in funcCall.FuncCall.Parameters)
{
PopulateExpression(parameter);
}
var function = _functions.FirstOrDefault(f => f.Name == funcCall.FuncCall.Name);
if (function == null)
{
throw new Exception($"Func {funcCall} is not defined");
}
if (!function.ReturnType.HasValue)
{
throw new Exception($"Func {funcCall} must have a return type when used in an expression");
}
funcCall.Type = function.ReturnType.Value;
}
private void PopulateIdentifier(IdentifierNode identifier)
{
var type = _variables.FirstOrDefault(v => v.Name == identifier.Identifier)?.Type;
if (type == null)
{
throw new Exception($"Variable {identifier} is not defined");
}
identifier.Type = type;
}
private static void PopulateLiteral(LiteralNode literal)
{
literal.Type = literal.LiteralType;
}
private void PopulateStructInitializer(StructInitializerNode structInitializer)
{
foreach (var initializer in structInitializer.Initializers)
{
PopulateExpression(initializer.Value);
}
structInitializer.Type = structInitializer.StructType;
}
private void PopulateStructMemberAccessorNode(StructFieldAccessorNode structFieldAccessor)
{
PopulateExpression(structFieldAccessor.Struct);
var structType = structFieldAccessor.Struct.Type;
if (structType == null)
{
throw new Exception($"Cannot access field on non-struct type: {structFieldAccessor.Struct}");
}
var structDefinition = _structDefinitions.FirstOrDefault(s => s.Name == structType.Name);
if (structDefinition == null)
{
throw new Exception($"Struct {structType.Name} is not defined");
}
var field = structDefinition.Fields.FirstOrDefault(f => f.Name == structFieldAccessor.Field);
if (field == null)
{
throw new Exception($"Field {structFieldAccessor.Field} is not defined in struct {structType.Name}");
}
structFieldAccessor.Type = field.Type;
}
private class Variable(string name, NubType type)
{
public string Name { get; } = name;
public NubType Type { get; } = type;
}
}

View File

@@ -0,0 +1,412 @@
using Nub.Lang.Frontend.Parsing;
namespace Nub.Lang.Frontend.Typing;
public class TypeCheckingException : Exception
{
public TypeCheckingException(string message) : base(message) { }
}
public class TypeChecker
{
private readonly Dictionary<string, NubType> _variables = new();
private readonly Dictionary<string, (List<FuncParameter> Parameters, Optional<NubType> ReturnType)> _functions = new();
private readonly Dictionary<string, Dictionary<string, NubType>> _structs = new();
private NubType? _currentFunctionReturnType;
private bool _hasReturnStatement;
public void TypeCheck(List<DefinitionNode> definitions)
{
CollectDefinitions(definitions);
foreach (var definition in definitions)
{
if (definition is LocalFuncDefinitionNode funcDef)
{
TypeCheckFunction(funcDef);
}
}
}
private void CollectDefinitions(List<DefinitionNode> definitions)
{
foreach (var definition in definitions)
{
switch (definition)
{
case StructDefinitionNode structDef:
RegisterStruct(structDef);
break;
case LocalFuncDefinitionNode funcDef:
RegisterFunction(funcDef);
break;
case ExternFuncDefinitionNode externFuncDef:
RegisterExternFunction(externFuncDef);
break;
}
}
}
private void RegisterStruct(StructDefinitionNode structDef)
{
var fields = new Dictionary<string, NubType>();
foreach (var field in structDef.Fields)
{
if (fields.ContainsKey(field.Name))
{
throw new TypeCheckingException($"Duplicate field '{field.Name}' in struct '{structDef.Name}'");
}
fields[field.Name] = field.Type;
}
_structs[structDef.Name] = fields;
}
private void RegisterFunction(LocalFuncDefinitionNode funcDef)
{
_functions[funcDef.Name] = (funcDef.Parameters, funcDef.ReturnType);
}
private void RegisterExternFunction(ExternFuncDefinitionNode funcDef)
{
_functions[funcDef.Name] = (funcDef.Parameters, funcDef.ReturnType);
}
private void TypeCheckFunction(LocalFuncDefinitionNode funcDef)
{
_variables.Clear();
_currentFunctionReturnType = funcDef.ReturnType.HasValue ? funcDef.ReturnType.Value : null;
_hasReturnStatement = false;
foreach (var param in funcDef.Parameters)
{
_variables[param.Name] = param.Type;
}
TypeCheckBlock(funcDef.Body);
if (_currentFunctionReturnType != null && !_hasReturnStatement)
{
throw new TypeCheckingException($"Function '{funcDef.Name}' must return a value of type '{_currentFunctionReturnType}'");
}
}
private void TypeCheckBlock(BlockNode block)
{
foreach (var statement in block.Statements)
{
TypeCheckStatement(statement);
}
}
private void TypeCheckStatement(StatementNode statement)
{
switch (statement)
{
case VariableAssignmentNode varAssign:
TypeCheckVariableAssignment(varAssign);
break;
case FuncCallStatementNode funcCall:
TypeCheckFuncCall(funcCall.FuncCall);
break;
case IfNode ifNode:
TypeCheckIf(ifNode);
break;
case WhileNode whileNode:
TypeCheckWhile(whileNode);
break;
case ReturnNode returnNode:
TypeCheckReturn(returnNode);
break;
case BreakNode:
case ContinueNode:
break;
default:
throw new TypeCheckingException($"Unsupported statement type: {statement.GetType().Name}");
}
}
private void TypeCheckVariableAssignment(VariableAssignmentNode varAssign)
{
var valueType = TypeCheckExpression(varAssign.Value);
if (varAssign.ExplicitType.HasValue)
{
var explicitType = varAssign.ExplicitType.Value;
if (!AreTypesCompatible(valueType, explicitType))
{
throw new TypeCheckingException($"Cannot assign expression of type '{valueType}' to variable '{varAssign.Name}' of type '{explicitType}'");
}
_variables[varAssign.Name] = explicitType;
}
else
{
_variables[varAssign.Name] = valueType;
}
}
private NubType TypeCheckFuncCall(FuncCall funcCall)
{
if (!_functions.TryGetValue(funcCall.Name, out var funcSignature))
{
throw new TypeCheckingException($"Function '{funcCall.Name}' is not defined");
}
var paramTypes = funcSignature.Parameters;
if (paramTypes.Take(paramTypes.Count - 1).Any(x => x.Variadic))
{
throw new TypeCheckingException($"Function '{funcCall.Name}' has multiple variadic parameters");
}
for (var i = 0; i < funcCall.Parameters.Count; i++)
{
var argType = TypeCheckExpression(funcCall.Parameters[i]);
NubType paramType;
if (i < paramTypes.Count)
{
paramType = paramTypes[i].Type;
}
else if (paramTypes.LastOrDefault()?.Variadic ?? false)
{
return paramTypes[^1].Type;
}
else
{
throw new TypeCheckingException($"Function '{funcCall.Name}' does not take {funcCall.Parameters.Count} parameters");
}
if (!AreTypesCompatible(argType, paramType))
{
throw new TypeCheckingException($"Parameter {i} of function '{funcCall.Name}' expects type '{paramType}', but got '{argType}'");
}
}
return funcSignature.ReturnType.HasValue ? funcSignature.ReturnType.Value : NubPrimitiveType.Any;
}
private void TypeCheckIf(IfNode ifNode)
{
var conditionType = TypeCheckExpression(ifNode.Condition);
if (!conditionType.Equals(NubPrimitiveType.Bool))
{
throw new TypeCheckingException($"If condition must be a boolean expression, got '{conditionType}'");
}
TypeCheckBlock(ifNode.Body);
if (ifNode.Else.HasValue)
{
var elseValue = ifNode.Else.Value;
elseValue.Match(TypeCheckIf, TypeCheckBlock);
}
}
private void TypeCheckWhile(WhileNode whileNode)
{
var conditionType = TypeCheckExpression(whileNode.Condition);
if (!conditionType.Equals(NubPrimitiveType.Bool))
{
throw new TypeCheckingException($"While condition must be a boolean expression, got '{conditionType}'");
}
TypeCheckBlock(whileNode.Body);
}
private void TypeCheckReturn(ReturnNode returnNode)
{
_hasReturnStatement = true;
if (returnNode.Value.HasValue)
{
var returnType = TypeCheckExpression(returnNode.Value.Value);
if (_currentFunctionReturnType == null)
{
throw new TypeCheckingException("Cannot return a value from a function with no return type");
}
if (!AreTypesCompatible(returnType, _currentFunctionReturnType))
{
throw new TypeCheckingException($"Return value of type '{returnType}' is not compatible with function return type '{_currentFunctionReturnType}'");
}
}
else if (_currentFunctionReturnType != null)
{
throw new TypeCheckingException($"Function must return a value of type '{_currentFunctionReturnType}'");
}
}
private NubType TypeCheckExpression(ExpressionNode expression)
{
NubType resultType;
switch (expression)
{
case LiteralNode literal:
resultType = literal.LiteralType;
break;
case IdentifierNode identifier:
if (!_variables.TryGetValue(identifier.Identifier, out var varType))
{
throw new TypeCheckingException($"Variable '{identifier.Identifier}' is not defined");
}
resultType = varType;
break;
case BinaryExpressionNode binaryExpr:
resultType = TypeCheckBinaryExpression(binaryExpr);
break;
case FuncCallExpressionNode funcCallExpr:
resultType = TypeCheckFuncCall(funcCallExpr.FuncCall);
break;
case StructInitializerNode structInit:
resultType = TypeCheckStructInitializer(structInit);
break;
case StructFieldAccessorNode fieldAccess:
resultType = TypeCheckStructFieldAccess(fieldAccess);
break;
default:
throw new TypeCheckingException($"Unsupported expression type: {expression.GetType().Name}");
}
expression.Type = resultType;
return resultType;
}
private NubType TypeCheckBinaryExpression(BinaryExpressionNode binaryExpr)
{
var leftType = TypeCheckExpression(binaryExpr.Left);
var rightType = TypeCheckExpression(binaryExpr.Right);
if (!leftType.Equals(rightType))
{
throw new TypeCheckingException($"Left '{leftType}' and right '{rightType}' side of the binary expression is not equal");
}
switch (binaryExpr.Operator)
{
case BinaryExpressionOperator.Equal:
case BinaryExpressionOperator.NotEqual:
return NubPrimitiveType.Bool;
case BinaryExpressionOperator.GreaterThan:
case BinaryExpressionOperator.GreaterThanOrEqual:
case BinaryExpressionOperator.LessThan:
case BinaryExpressionOperator.LessThanOrEqual:
if (!IsNumeric(leftType))
{
throw new TypeCheckingException($"Comparison operators require numeric operands, got '{leftType}' and '{rightType}'");
}
return NubPrimitiveType.Bool;
case BinaryExpressionOperator.Plus:
case BinaryExpressionOperator.Minus:
case BinaryExpressionOperator.Multiply:
case BinaryExpressionOperator.Divide:
if (!IsNumeric(leftType))
{
throw new TypeCheckingException($"Arithmetic operators require numeric operands, got '{leftType}' and '{rightType}'");
}
return leftType;
default:
throw new TypeCheckingException($"Unsupported binary operator: {binaryExpr.Operator}");
}
}
private NubType TypeCheckStructInitializer(StructInitializerNode structInit)
{
var structType = structInit.StructType;
if (structType is not NubCustomType customType)
{
throw new TypeCheckingException($"Type '{structType}' is not a struct type");
}
if (!_structs.TryGetValue(customType.Name, out var fields))
{
throw new TypeCheckingException($"Struct type '{customType.Name}' is not defined");
}
foreach (var initializer in structInit.Initializers)
{
if (!fields.TryGetValue(initializer.Key, out var fieldType))
{
throw new TypeCheckingException($"Field '{initializer.Key}' does not exist in struct '{customType.Name}'");
}
var initializerType = TypeCheckExpression(initializer.Value);
if (!AreTypesCompatible(initializerType, fieldType))
{
throw new TypeCheckingException($"Cannot initialize field '{initializer.Key}' of type '{fieldType}' with expression of type '{initializerType}'");
}
}
foreach (var field in fields)
{
if (!structInit.Initializers.ContainsKey(field.Key))
{
throw new TypeCheckingException($"Field '{field.Key}' of struct '{customType.Name}' is not initialized");
}
}
return structType;
}
private NubType TypeCheckStructFieldAccess(StructFieldAccessorNode fieldAccess)
{
var structType = TypeCheckExpression(fieldAccess.Struct);
if (structType is not NubCustomType customType)
{
throw new TypeCheckingException($"Cannot access field '{fieldAccess.Field}' on non-struct type '{structType}'");
}
if (!_structs.TryGetValue(customType.Name, out var fields))
{
throw new TypeCheckingException($"Struct type '{customType.Name}' is not defined");
}
if (!fields.TryGetValue(fieldAccess.Field, out var fieldType))
{
throw new TypeCheckingException($"Field '{fieldAccess.Field}' does not exist in struct '{customType.Name}'");
}
return fieldType;
}
#region Type Helper Methods
private static bool AreTypesCompatible(NubType sourceType, NubType targetType)
{
return targetType.Equals(NubPrimitiveType.Any) || sourceType.Equals(targetType);
}
private static bool IsNumeric(NubType type)
{
if (type is not NubPrimitiveType primitiveType)
{
return false;
}
switch (primitiveType.Kind)
{
case PrimitiveTypeKind.I8:
case PrimitiveTypeKind.I16:
case PrimitiveTypeKind.I32:
case PrimitiveTypeKind.I64:
case PrimitiveTypeKind.U8:
case PrimitiveTypeKind.U16:
case PrimitiveTypeKind.U32:
case PrimitiveTypeKind.U64:
case PrimitiveTypeKind.F32:
case PrimitiveTypeKind.F64:
return true;
default:
return false;
}
}
#endregion
}

View File

@@ -44,8 +44,8 @@ internal static class Program
var modules = RunFrontend(input); var modules = RunFrontend(input);
var definitions = modules.SelectMany(f => f.Definitions).ToList(); var definitions = modules.SelectMany(f => f.Definitions).ToList();
var typer = new ExpressionTyper(definitions); var typeChecker = new TypeChecker();
typer.Populate(); typeChecker.TypeCheck(definitions);
var generator = new Generator(definitions); var generator = new Generator(definitions);
var result = generator.Generate(); var result = generator.Generate();