type checker
This commit is contained in:
@@ -3,6 +3,6 @@ import c
|
||||
global func main(argc: i64, argv: i64) {
|
||||
printf("args: %d, starts at %p\n", argc, argv)
|
||||
|
||||
x: i8 = 320000
|
||||
x: i8 = (i8)320000
|
||||
printf("%d\n", x)
|
||||
}
|
||||
@@ -14,7 +14,6 @@ public class Generator
|
||||
private readonly Stack<string> _breakLabels = new();
|
||||
private readonly Stack<string> _continueLabels = new();
|
||||
private bool _codeIsReachable = true;
|
||||
private LocalFuncDefinitionNode? _currentFuncDefininition;
|
||||
|
||||
public Generator(List<DefinitionNode> definitions)
|
||||
{
|
||||
@@ -218,7 +217,6 @@ public class Generator
|
||||
|
||||
private void GenerateFuncDefinition(LocalFuncDefinitionNode node)
|
||||
{
|
||||
_currentFuncDefininition = node;
|
||||
_variables.Clear();
|
||||
|
||||
if (node.Global)
|
||||
@@ -290,7 +288,6 @@ public class Generator
|
||||
}
|
||||
|
||||
_builder.AppendLine("}");
|
||||
_currentFuncDefininition = null;
|
||||
}
|
||||
|
||||
private void GenerateStructDefinition(StructDefinitionNode structDefinition)
|
||||
@@ -370,8 +367,7 @@ public class Generator
|
||||
}
|
||||
|
||||
var parameter = funcCall.Parameters[i];
|
||||
var parameterOutput = GenerateExpression(parameter);
|
||||
var result = GenerateTypeConversion(parameterOutput, parameter.Type, expectedType);
|
||||
var result = GenerateExpression(parameter);
|
||||
|
||||
var qbeParameterType = SQT(expectedType.Equals(NubPrimitiveType.Any) ? parameter.Type : expectedType);
|
||||
parameterStrings.Add($"{qbeParameterType} {result}");
|
||||
@@ -436,14 +432,8 @@ public class Generator
|
||||
{
|
||||
if (@return.Value.HasValue)
|
||||
{
|
||||
if (!_currentFuncDefininition!.ReturnType.HasValue)
|
||||
{
|
||||
throw new Exception("Cannot return a value when function does not have a return value");
|
||||
}
|
||||
|
||||
var result = GenerateExpression(@return.Value.Value);
|
||||
var converted = GenerateTypeConversion(result, @return.Value.Value.Type, _currentFuncDefininition.ReturnType.Value);
|
||||
_builder.AppendLine($" ret {converted}");
|
||||
_builder.AppendLine($" ret {result}");
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -454,17 +444,10 @@ public class Generator
|
||||
private void GenerateVariableAssignment(VariableAssignmentNode variableAssignment)
|
||||
{
|
||||
var result = GenerateExpression(variableAssignment.Value);
|
||||
var variableType = variableAssignment.Value.Type;
|
||||
|
||||
if (variableAssignment.ExplicitType.HasValue)
|
||||
{
|
||||
result = GenerateTypeConversion(result, variableType, variableAssignment.ExplicitType.Value);
|
||||
variableType = variableAssignment.ExplicitType.Value;
|
||||
}
|
||||
_variables[variableAssignment.Name] = new Variable
|
||||
{
|
||||
Identifier = result,
|
||||
Type = variableType
|
||||
Type = variableAssignment.Value.Type
|
||||
};
|
||||
}
|
||||
|
||||
|
||||
@@ -1,292 +0,0 @@
|
||||
using Nub.Lang.Frontend.Parsing;
|
||||
|
||||
namespace Nub.Lang.Frontend.Typing;
|
||||
|
||||
public class Func(string name, List<FuncParameter> parameters, Optional<BlockNode> body, Optional<NubType> returnType)
|
||||
{
|
||||
public string Name { get; } = name;
|
||||
public List<FuncParameter> Parameters { get; } = parameters;
|
||||
public Optional<BlockNode> Body { get; } = body;
|
||||
public Optional<NubType> ReturnType { get; } = returnType;
|
||||
}
|
||||
|
||||
public class ExpressionTyper
|
||||
{
|
||||
private readonly List<Func> _functions;
|
||||
private readonly List<StructDefinitionNode> _structDefinitions;
|
||||
private readonly Stack<Variable> _variables;
|
||||
|
||||
public ExpressionTyper(List<DefinitionNode> definitions)
|
||||
{
|
||||
_variables = new Stack<Variable>();
|
||||
_functions = [];
|
||||
|
||||
_structDefinitions = definitions.OfType<StructDefinitionNode>().ToList();
|
||||
|
||||
var functions = definitions
|
||||
.OfType<LocalFuncDefinitionNode>()
|
||||
.Select(f => new Func(f.Name, f.Parameters, f.Body, f.ReturnType))
|
||||
.ToList();
|
||||
|
||||
var externFunctions = definitions
|
||||
.OfType<ExternFuncDefinitionNode>()
|
||||
.Select(f => new Func(f.Name, f.Parameters, Optional<BlockNode>.Empty(), f.ReturnType))
|
||||
.ToList();
|
||||
|
||||
_functions.AddRange(functions);
|
||||
_functions.AddRange(externFunctions);
|
||||
}
|
||||
|
||||
public void Populate()
|
||||
{
|
||||
_variables.Clear();
|
||||
|
||||
foreach (var @class in _structDefinitions)
|
||||
{
|
||||
foreach (var variable in @class.Fields)
|
||||
{
|
||||
if (variable.Value.HasValue)
|
||||
{
|
||||
PopulateExpression(variable.Value.Value);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
foreach (var function in _functions)
|
||||
{
|
||||
foreach (var parameter in function.Parameters)
|
||||
{
|
||||
_variables.Push(new Variable(parameter.Name, parameter.Type));
|
||||
}
|
||||
|
||||
if (function.Body.HasValue)
|
||||
{
|
||||
PopulateBlock(function.Body.Value);
|
||||
}
|
||||
for (var i = 0; i < function.Parameters.Count; i++)
|
||||
{
|
||||
_variables.Pop();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private void PopulateBlock(BlockNode block)
|
||||
{
|
||||
var variableCount = _variables.Count;
|
||||
foreach (var statement in block.Statements)
|
||||
{
|
||||
PopulateStatement(statement);
|
||||
}
|
||||
while (_variables.Count > variableCount)
|
||||
{
|
||||
_variables.Pop();
|
||||
}
|
||||
}
|
||||
|
||||
private void PopulateStatement(StatementNode statement)
|
||||
{
|
||||
switch (statement)
|
||||
{
|
||||
case BreakNode:
|
||||
case ContinueNode:
|
||||
break;
|
||||
case FuncCallStatementNode funcCall:
|
||||
PopulateFuncCallStatement(funcCall);
|
||||
break;
|
||||
case IfNode ifStatement:
|
||||
PopulateIf(ifStatement);
|
||||
break;
|
||||
case ReturnNode returnNode:
|
||||
PopulateReturn(returnNode);
|
||||
break;
|
||||
case VariableAssignmentNode variableAssignment:
|
||||
PopulateVariableAssignment(variableAssignment);
|
||||
break;
|
||||
case WhileNode whileStatement:
|
||||
PopulateWhileStatement(whileStatement);
|
||||
break;
|
||||
default:
|
||||
throw new ArgumentOutOfRangeException(nameof(statement));
|
||||
}
|
||||
}
|
||||
|
||||
private void PopulateFuncCallStatement(FuncCallStatementNode funcCall)
|
||||
{
|
||||
foreach (var parameter in funcCall.FuncCall.Parameters)
|
||||
{
|
||||
PopulateExpression(parameter);
|
||||
}
|
||||
}
|
||||
|
||||
private void PopulateIf(IfNode ifStatement)
|
||||
{
|
||||
PopulateExpression(ifStatement.Condition);
|
||||
PopulateBlock(ifStatement.Body);
|
||||
if (ifStatement.Else.HasValue)
|
||||
{
|
||||
ifStatement.Else.Value.Match
|
||||
(
|
||||
PopulateIf,
|
||||
PopulateBlock
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
private void PopulateReturn(ReturnNode returnNode)
|
||||
{
|
||||
if (returnNode.Value.HasValue)
|
||||
{
|
||||
PopulateExpression(returnNode.Value.Value);
|
||||
}
|
||||
}
|
||||
|
||||
private void PopulateVariableAssignment(VariableAssignmentNode variableAssignment)
|
||||
{
|
||||
PopulateExpression(variableAssignment.Value);
|
||||
_variables.Push(new Variable(variableAssignment.Name, variableAssignment.ExplicitType.HasValue ? variableAssignment.ExplicitType.Value : variableAssignment.Value.Type));
|
||||
}
|
||||
|
||||
private void PopulateVariableReassignment(VariableAssignmentNode variableAssignment)
|
||||
{
|
||||
PopulateExpression(variableAssignment.Value);
|
||||
}
|
||||
|
||||
private void PopulateWhileStatement(WhileNode whileStatement)
|
||||
{
|
||||
PopulateExpression(whileStatement.Condition);
|
||||
PopulateBlock(whileStatement.Body);
|
||||
}
|
||||
|
||||
private void PopulateExpression(ExpressionNode expression)
|
||||
{
|
||||
switch (expression)
|
||||
{
|
||||
case BinaryExpressionNode binaryExpression:
|
||||
PopulateBinaryExpression(binaryExpression);
|
||||
break;
|
||||
case FuncCallExpressionNode funcCall:
|
||||
PopulateFuncCallExpression(funcCall);
|
||||
break;
|
||||
case IdentifierNode identifier:
|
||||
PopulateIdentifier(identifier);
|
||||
break;
|
||||
case LiteralNode literal:
|
||||
PopulateLiteral(literal);
|
||||
break;
|
||||
case StructInitializerNode structInitializer:
|
||||
PopulateStructInitializer(structInitializer);
|
||||
break;
|
||||
case StructFieldAccessorNode structMemberAccessor:
|
||||
PopulateStructMemberAccessorNode(structMemberAccessor);
|
||||
break;
|
||||
default:
|
||||
throw new ArgumentOutOfRangeException(nameof(expression));
|
||||
}
|
||||
}
|
||||
|
||||
private void PopulateBinaryExpression(BinaryExpressionNode binaryExpression)
|
||||
{
|
||||
PopulateExpression(binaryExpression.Left);
|
||||
PopulateExpression(binaryExpression.Right);
|
||||
switch (binaryExpression.Operator)
|
||||
{
|
||||
case BinaryExpressionOperator.Equal:
|
||||
case BinaryExpressionOperator.NotEqual:
|
||||
case BinaryExpressionOperator.GreaterThan:
|
||||
case BinaryExpressionOperator.GreaterThanOrEqual:
|
||||
case BinaryExpressionOperator.LessThan:
|
||||
case BinaryExpressionOperator.LessThanOrEqual:
|
||||
{
|
||||
binaryExpression.Type = NubPrimitiveType.Bool;
|
||||
break;
|
||||
}
|
||||
case BinaryExpressionOperator.Plus:
|
||||
case BinaryExpressionOperator.Minus:
|
||||
case BinaryExpressionOperator.Multiply:
|
||||
case BinaryExpressionOperator.Divide:
|
||||
{
|
||||
binaryExpression.Type = binaryExpression.Left.Type;
|
||||
break;
|
||||
}
|
||||
default:
|
||||
{
|
||||
throw new ArgumentOutOfRangeException(nameof(binaryExpression.Operator));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private void PopulateFuncCallExpression(FuncCallExpressionNode funcCall)
|
||||
{
|
||||
foreach (var parameter in funcCall.FuncCall.Parameters)
|
||||
{
|
||||
PopulateExpression(parameter);
|
||||
}
|
||||
|
||||
var function = _functions.FirstOrDefault(f => f.Name == funcCall.FuncCall.Name);
|
||||
if (function == null)
|
||||
{
|
||||
throw new Exception($"Func {funcCall} is not defined");
|
||||
}
|
||||
if (!function.ReturnType.HasValue)
|
||||
{
|
||||
throw new Exception($"Func {funcCall} must have a return type when used in an expression");
|
||||
}
|
||||
funcCall.Type = function.ReturnType.Value;
|
||||
}
|
||||
|
||||
private void PopulateIdentifier(IdentifierNode identifier)
|
||||
{
|
||||
var type = _variables.FirstOrDefault(v => v.Name == identifier.Identifier)?.Type;
|
||||
if (type == null)
|
||||
{
|
||||
throw new Exception($"Variable {identifier} is not defined");
|
||||
}
|
||||
identifier.Type = type;
|
||||
}
|
||||
|
||||
private static void PopulateLiteral(LiteralNode literal)
|
||||
{
|
||||
literal.Type = literal.LiteralType;
|
||||
}
|
||||
|
||||
private void PopulateStructInitializer(StructInitializerNode structInitializer)
|
||||
{
|
||||
foreach (var initializer in structInitializer.Initializers)
|
||||
{
|
||||
PopulateExpression(initializer.Value);
|
||||
}
|
||||
|
||||
structInitializer.Type = structInitializer.StructType;
|
||||
}
|
||||
|
||||
private void PopulateStructMemberAccessorNode(StructFieldAccessorNode structFieldAccessor)
|
||||
{
|
||||
PopulateExpression(structFieldAccessor.Struct);
|
||||
|
||||
var structType = structFieldAccessor.Struct.Type;
|
||||
if (structType == null)
|
||||
{
|
||||
throw new Exception($"Cannot access field on non-struct type: {structFieldAccessor.Struct}");
|
||||
}
|
||||
|
||||
var structDefinition = _structDefinitions.FirstOrDefault(s => s.Name == structType.Name);
|
||||
if (structDefinition == null)
|
||||
{
|
||||
throw new Exception($"Struct {structType.Name} is not defined");
|
||||
}
|
||||
|
||||
var field = structDefinition.Fields.FirstOrDefault(f => f.Name == structFieldAccessor.Field);
|
||||
if (field == null)
|
||||
{
|
||||
throw new Exception($"Field {structFieldAccessor.Field} is not defined in struct {structType.Name}");
|
||||
}
|
||||
|
||||
structFieldAccessor.Type = field.Type;
|
||||
}
|
||||
|
||||
private class Variable(string name, NubType type)
|
||||
{
|
||||
public string Name { get; } = name;
|
||||
public NubType Type { get; } = type;
|
||||
}
|
||||
}
|
||||
412
src/compiler/Nub.Lang/Frontend/Typing/TypeChecker.cs
Normal file
412
src/compiler/Nub.Lang/Frontend/Typing/TypeChecker.cs
Normal file
@@ -0,0 +1,412 @@
|
||||
using Nub.Lang.Frontend.Parsing;
|
||||
|
||||
namespace Nub.Lang.Frontend.Typing;
|
||||
|
||||
public class TypeCheckingException : Exception
|
||||
{
|
||||
public TypeCheckingException(string message) : base(message) { }
|
||||
}
|
||||
|
||||
public class TypeChecker
|
||||
{
|
||||
private readonly Dictionary<string, NubType> _variables = new();
|
||||
private readonly Dictionary<string, (List<FuncParameter> Parameters, Optional<NubType> ReturnType)> _functions = new();
|
||||
private readonly Dictionary<string, Dictionary<string, NubType>> _structs = new();
|
||||
private NubType? _currentFunctionReturnType;
|
||||
private bool _hasReturnStatement;
|
||||
|
||||
public void TypeCheck(List<DefinitionNode> definitions)
|
||||
{
|
||||
CollectDefinitions(definitions);
|
||||
|
||||
foreach (var definition in definitions)
|
||||
{
|
||||
if (definition is LocalFuncDefinitionNode funcDef)
|
||||
{
|
||||
TypeCheckFunction(funcDef);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private void CollectDefinitions(List<DefinitionNode> definitions)
|
||||
{
|
||||
foreach (var definition in definitions)
|
||||
{
|
||||
switch (definition)
|
||||
{
|
||||
case StructDefinitionNode structDef:
|
||||
RegisterStruct(structDef);
|
||||
break;
|
||||
case LocalFuncDefinitionNode funcDef:
|
||||
RegisterFunction(funcDef);
|
||||
break;
|
||||
case ExternFuncDefinitionNode externFuncDef:
|
||||
RegisterExternFunction(externFuncDef);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private void RegisterStruct(StructDefinitionNode structDef)
|
||||
{
|
||||
var fields = new Dictionary<string, NubType>();
|
||||
foreach (var field in structDef.Fields)
|
||||
{
|
||||
if (fields.ContainsKey(field.Name))
|
||||
{
|
||||
throw new TypeCheckingException($"Duplicate field '{field.Name}' in struct '{structDef.Name}'");
|
||||
}
|
||||
fields[field.Name] = field.Type;
|
||||
}
|
||||
_structs[structDef.Name] = fields;
|
||||
}
|
||||
|
||||
private void RegisterFunction(LocalFuncDefinitionNode funcDef)
|
||||
{
|
||||
_functions[funcDef.Name] = (funcDef.Parameters, funcDef.ReturnType);
|
||||
}
|
||||
|
||||
private void RegisterExternFunction(ExternFuncDefinitionNode funcDef)
|
||||
{
|
||||
_functions[funcDef.Name] = (funcDef.Parameters, funcDef.ReturnType);
|
||||
}
|
||||
|
||||
private void TypeCheckFunction(LocalFuncDefinitionNode funcDef)
|
||||
{
|
||||
_variables.Clear();
|
||||
_currentFunctionReturnType = funcDef.ReturnType.HasValue ? funcDef.ReturnType.Value : null;
|
||||
_hasReturnStatement = false;
|
||||
|
||||
foreach (var param in funcDef.Parameters)
|
||||
{
|
||||
_variables[param.Name] = param.Type;
|
||||
}
|
||||
|
||||
TypeCheckBlock(funcDef.Body);
|
||||
|
||||
if (_currentFunctionReturnType != null && !_hasReturnStatement)
|
||||
{
|
||||
throw new TypeCheckingException($"Function '{funcDef.Name}' must return a value of type '{_currentFunctionReturnType}'");
|
||||
}
|
||||
}
|
||||
|
||||
private void TypeCheckBlock(BlockNode block)
|
||||
{
|
||||
foreach (var statement in block.Statements)
|
||||
{
|
||||
TypeCheckStatement(statement);
|
||||
}
|
||||
}
|
||||
|
||||
private void TypeCheckStatement(StatementNode statement)
|
||||
{
|
||||
switch (statement)
|
||||
{
|
||||
case VariableAssignmentNode varAssign:
|
||||
TypeCheckVariableAssignment(varAssign);
|
||||
break;
|
||||
case FuncCallStatementNode funcCall:
|
||||
TypeCheckFuncCall(funcCall.FuncCall);
|
||||
break;
|
||||
case IfNode ifNode:
|
||||
TypeCheckIf(ifNode);
|
||||
break;
|
||||
case WhileNode whileNode:
|
||||
TypeCheckWhile(whileNode);
|
||||
break;
|
||||
case ReturnNode returnNode:
|
||||
TypeCheckReturn(returnNode);
|
||||
break;
|
||||
case BreakNode:
|
||||
case ContinueNode:
|
||||
break;
|
||||
default:
|
||||
throw new TypeCheckingException($"Unsupported statement type: {statement.GetType().Name}");
|
||||
}
|
||||
}
|
||||
|
||||
private void TypeCheckVariableAssignment(VariableAssignmentNode varAssign)
|
||||
{
|
||||
var valueType = TypeCheckExpression(varAssign.Value);
|
||||
|
||||
if (varAssign.ExplicitType.HasValue)
|
||||
{
|
||||
var explicitType = varAssign.ExplicitType.Value;
|
||||
if (!AreTypesCompatible(valueType, explicitType))
|
||||
{
|
||||
throw new TypeCheckingException($"Cannot assign expression of type '{valueType}' to variable '{varAssign.Name}' of type '{explicitType}'");
|
||||
}
|
||||
_variables[varAssign.Name] = explicitType;
|
||||
}
|
||||
else
|
||||
{
|
||||
_variables[varAssign.Name] = valueType;
|
||||
}
|
||||
}
|
||||
|
||||
private NubType TypeCheckFuncCall(FuncCall funcCall)
|
||||
{
|
||||
if (!_functions.TryGetValue(funcCall.Name, out var funcSignature))
|
||||
{
|
||||
throw new TypeCheckingException($"Function '{funcCall.Name}' is not defined");
|
||||
}
|
||||
|
||||
var paramTypes = funcSignature.Parameters;
|
||||
if (paramTypes.Take(paramTypes.Count - 1).Any(x => x.Variadic))
|
||||
{
|
||||
throw new TypeCheckingException($"Function '{funcCall.Name}' has multiple variadic parameters");
|
||||
}
|
||||
|
||||
for (var i = 0; i < funcCall.Parameters.Count; i++)
|
||||
{
|
||||
var argType = TypeCheckExpression(funcCall.Parameters[i]);
|
||||
|
||||
NubType paramType;
|
||||
if (i < paramTypes.Count)
|
||||
{
|
||||
paramType = paramTypes[i].Type;
|
||||
}
|
||||
else if (paramTypes.LastOrDefault()?.Variadic ?? false)
|
||||
{
|
||||
return paramTypes[^1].Type;
|
||||
}
|
||||
else
|
||||
{
|
||||
throw new TypeCheckingException($"Function '{funcCall.Name}' does not take {funcCall.Parameters.Count} parameters");
|
||||
}
|
||||
|
||||
if (!AreTypesCompatible(argType, paramType))
|
||||
{
|
||||
throw new TypeCheckingException($"Parameter {i} of function '{funcCall.Name}' expects type '{paramType}', but got '{argType}'");
|
||||
}
|
||||
}
|
||||
|
||||
return funcSignature.ReturnType.HasValue ? funcSignature.ReturnType.Value : NubPrimitiveType.Any;
|
||||
}
|
||||
|
||||
private void TypeCheckIf(IfNode ifNode)
|
||||
{
|
||||
var conditionType = TypeCheckExpression(ifNode.Condition);
|
||||
if (!conditionType.Equals(NubPrimitiveType.Bool))
|
||||
{
|
||||
throw new TypeCheckingException($"If condition must be a boolean expression, got '{conditionType}'");
|
||||
}
|
||||
|
||||
TypeCheckBlock(ifNode.Body);
|
||||
|
||||
if (ifNode.Else.HasValue)
|
||||
{
|
||||
var elseValue = ifNode.Else.Value;
|
||||
elseValue.Match(TypeCheckIf, TypeCheckBlock);
|
||||
}
|
||||
}
|
||||
|
||||
private void TypeCheckWhile(WhileNode whileNode)
|
||||
{
|
||||
var conditionType = TypeCheckExpression(whileNode.Condition);
|
||||
if (!conditionType.Equals(NubPrimitiveType.Bool))
|
||||
{
|
||||
throw new TypeCheckingException($"While condition must be a boolean expression, got '{conditionType}'");
|
||||
}
|
||||
|
||||
TypeCheckBlock(whileNode.Body);
|
||||
}
|
||||
|
||||
private void TypeCheckReturn(ReturnNode returnNode)
|
||||
{
|
||||
_hasReturnStatement = true;
|
||||
|
||||
if (returnNode.Value.HasValue)
|
||||
{
|
||||
var returnType = TypeCheckExpression(returnNode.Value.Value);
|
||||
|
||||
if (_currentFunctionReturnType == null)
|
||||
{
|
||||
throw new TypeCheckingException("Cannot return a value from a function with no return type");
|
||||
}
|
||||
|
||||
if (!AreTypesCompatible(returnType, _currentFunctionReturnType))
|
||||
{
|
||||
throw new TypeCheckingException($"Return value of type '{returnType}' is not compatible with function return type '{_currentFunctionReturnType}'");
|
||||
}
|
||||
}
|
||||
else if (_currentFunctionReturnType != null)
|
||||
{
|
||||
throw new TypeCheckingException($"Function must return a value of type '{_currentFunctionReturnType}'");
|
||||
}
|
||||
}
|
||||
|
||||
private NubType TypeCheckExpression(ExpressionNode expression)
|
||||
{
|
||||
NubType resultType;
|
||||
|
||||
switch (expression)
|
||||
{
|
||||
case LiteralNode literal:
|
||||
resultType = literal.LiteralType;
|
||||
break;
|
||||
|
||||
case IdentifierNode identifier:
|
||||
if (!_variables.TryGetValue(identifier.Identifier, out var varType))
|
||||
{
|
||||
throw new TypeCheckingException($"Variable '{identifier.Identifier}' is not defined");
|
||||
}
|
||||
resultType = varType;
|
||||
break;
|
||||
|
||||
case BinaryExpressionNode binaryExpr:
|
||||
resultType = TypeCheckBinaryExpression(binaryExpr);
|
||||
break;
|
||||
|
||||
case FuncCallExpressionNode funcCallExpr:
|
||||
resultType = TypeCheckFuncCall(funcCallExpr.FuncCall);
|
||||
break;
|
||||
|
||||
case StructInitializerNode structInit:
|
||||
resultType = TypeCheckStructInitializer(structInit);
|
||||
break;
|
||||
|
||||
case StructFieldAccessorNode fieldAccess:
|
||||
resultType = TypeCheckStructFieldAccess(fieldAccess);
|
||||
break;
|
||||
|
||||
default:
|
||||
throw new TypeCheckingException($"Unsupported expression type: {expression.GetType().Name}");
|
||||
}
|
||||
|
||||
expression.Type = resultType;
|
||||
return resultType;
|
||||
}
|
||||
|
||||
private NubType TypeCheckBinaryExpression(BinaryExpressionNode binaryExpr)
|
||||
{
|
||||
var leftType = TypeCheckExpression(binaryExpr.Left);
|
||||
var rightType = TypeCheckExpression(binaryExpr.Right);
|
||||
|
||||
if (!leftType.Equals(rightType))
|
||||
{
|
||||
throw new TypeCheckingException($"Left '{leftType}' and right '{rightType}' side of the binary expression is not equal");
|
||||
}
|
||||
|
||||
switch (binaryExpr.Operator)
|
||||
{
|
||||
case BinaryExpressionOperator.Equal:
|
||||
case BinaryExpressionOperator.NotEqual:
|
||||
return NubPrimitiveType.Bool;
|
||||
case BinaryExpressionOperator.GreaterThan:
|
||||
case BinaryExpressionOperator.GreaterThanOrEqual:
|
||||
case BinaryExpressionOperator.LessThan:
|
||||
case BinaryExpressionOperator.LessThanOrEqual:
|
||||
if (!IsNumeric(leftType))
|
||||
{
|
||||
throw new TypeCheckingException($"Comparison operators require numeric operands, got '{leftType}' and '{rightType}'");
|
||||
}
|
||||
return NubPrimitiveType.Bool;
|
||||
case BinaryExpressionOperator.Plus:
|
||||
case BinaryExpressionOperator.Minus:
|
||||
case BinaryExpressionOperator.Multiply:
|
||||
case BinaryExpressionOperator.Divide:
|
||||
if (!IsNumeric(leftType))
|
||||
{
|
||||
throw new TypeCheckingException($"Arithmetic operators require numeric operands, got '{leftType}' and '{rightType}'");
|
||||
}
|
||||
return leftType;
|
||||
default:
|
||||
throw new TypeCheckingException($"Unsupported binary operator: {binaryExpr.Operator}");
|
||||
}
|
||||
}
|
||||
|
||||
private NubType TypeCheckStructInitializer(StructInitializerNode structInit)
|
||||
{
|
||||
var structType = structInit.StructType;
|
||||
if (structType is not NubCustomType customType)
|
||||
{
|
||||
throw new TypeCheckingException($"Type '{structType}' is not a struct type");
|
||||
}
|
||||
|
||||
if (!_structs.TryGetValue(customType.Name, out var fields))
|
||||
{
|
||||
throw new TypeCheckingException($"Struct type '{customType.Name}' is not defined");
|
||||
}
|
||||
|
||||
foreach (var initializer in structInit.Initializers)
|
||||
{
|
||||
if (!fields.TryGetValue(initializer.Key, out var fieldType))
|
||||
{
|
||||
throw new TypeCheckingException($"Field '{initializer.Key}' does not exist in struct '{customType.Name}'");
|
||||
}
|
||||
|
||||
var initializerType = TypeCheckExpression(initializer.Value);
|
||||
if (!AreTypesCompatible(initializerType, fieldType))
|
||||
{
|
||||
throw new TypeCheckingException($"Cannot initialize field '{initializer.Key}' of type '{fieldType}' with expression of type '{initializerType}'");
|
||||
}
|
||||
}
|
||||
|
||||
foreach (var field in fields)
|
||||
{
|
||||
if (!structInit.Initializers.ContainsKey(field.Key))
|
||||
{
|
||||
throw new TypeCheckingException($"Field '{field.Key}' of struct '{customType.Name}' is not initialized");
|
||||
}
|
||||
}
|
||||
|
||||
return structType;
|
||||
}
|
||||
|
||||
private NubType TypeCheckStructFieldAccess(StructFieldAccessorNode fieldAccess)
|
||||
{
|
||||
var structType = TypeCheckExpression(fieldAccess.Struct);
|
||||
|
||||
if (structType is not NubCustomType customType)
|
||||
{
|
||||
throw new TypeCheckingException($"Cannot access field '{fieldAccess.Field}' on non-struct type '{structType}'");
|
||||
}
|
||||
|
||||
if (!_structs.TryGetValue(customType.Name, out var fields))
|
||||
{
|
||||
throw new TypeCheckingException($"Struct type '{customType.Name}' is not defined");
|
||||
}
|
||||
|
||||
if (!fields.TryGetValue(fieldAccess.Field, out var fieldType))
|
||||
{
|
||||
throw new TypeCheckingException($"Field '{fieldAccess.Field}' does not exist in struct '{customType.Name}'");
|
||||
}
|
||||
|
||||
return fieldType;
|
||||
}
|
||||
|
||||
#region Type Helper Methods
|
||||
|
||||
private static bool AreTypesCompatible(NubType sourceType, NubType targetType)
|
||||
{
|
||||
return targetType.Equals(NubPrimitiveType.Any) || sourceType.Equals(targetType);
|
||||
}
|
||||
|
||||
private static bool IsNumeric(NubType type)
|
||||
{
|
||||
if (type is not NubPrimitiveType primitiveType)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
switch (primitiveType.Kind)
|
||||
{
|
||||
case PrimitiveTypeKind.I8:
|
||||
case PrimitiveTypeKind.I16:
|
||||
case PrimitiveTypeKind.I32:
|
||||
case PrimitiveTypeKind.I64:
|
||||
case PrimitiveTypeKind.U8:
|
||||
case PrimitiveTypeKind.U16:
|
||||
case PrimitiveTypeKind.U32:
|
||||
case PrimitiveTypeKind.U64:
|
||||
case PrimitiveTypeKind.F32:
|
||||
case PrimitiveTypeKind.F64:
|
||||
return true;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
#endregion
|
||||
}
|
||||
@@ -44,8 +44,8 @@ internal static class Program
|
||||
var modules = RunFrontend(input);
|
||||
var definitions = modules.SelectMany(f => f.Definitions).ToList();
|
||||
|
||||
var typer = new ExpressionTyper(definitions);
|
||||
typer.Populate();
|
||||
var typeChecker = new TypeChecker();
|
||||
typeChecker.TypeCheck(definitions);
|
||||
|
||||
var generator = new Generator(definitions);
|
||||
var result = generator.Generate();
|
||||
|
||||
Reference in New Issue
Block a user