type checker
This commit is contained in:
@@ -3,6 +3,6 @@ import c
|
|||||||
global func main(argc: i64, argv: i64) {
|
global func main(argc: i64, argv: i64) {
|
||||||
printf("args: %d, starts at %p\n", argc, argv)
|
printf("args: %d, starts at %p\n", argc, argv)
|
||||||
|
|
||||||
x: i8 = 320000
|
x: i8 = (i8)320000
|
||||||
printf("%d\n", x)
|
printf("%d\n", x)
|
||||||
}
|
}
|
||||||
@@ -14,7 +14,6 @@ public class Generator
|
|||||||
private readonly Stack<string> _breakLabels = new();
|
private readonly Stack<string> _breakLabels = new();
|
||||||
private readonly Stack<string> _continueLabels = new();
|
private readonly Stack<string> _continueLabels = new();
|
||||||
private bool _codeIsReachable = true;
|
private bool _codeIsReachable = true;
|
||||||
private LocalFuncDefinitionNode? _currentFuncDefininition;
|
|
||||||
|
|
||||||
public Generator(List<DefinitionNode> definitions)
|
public Generator(List<DefinitionNode> definitions)
|
||||||
{
|
{
|
||||||
@@ -218,7 +217,6 @@ public class Generator
|
|||||||
|
|
||||||
private void GenerateFuncDefinition(LocalFuncDefinitionNode node)
|
private void GenerateFuncDefinition(LocalFuncDefinitionNode node)
|
||||||
{
|
{
|
||||||
_currentFuncDefininition = node;
|
|
||||||
_variables.Clear();
|
_variables.Clear();
|
||||||
|
|
||||||
if (node.Global)
|
if (node.Global)
|
||||||
@@ -290,7 +288,6 @@ public class Generator
|
|||||||
}
|
}
|
||||||
|
|
||||||
_builder.AppendLine("}");
|
_builder.AppendLine("}");
|
||||||
_currentFuncDefininition = null;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private void GenerateStructDefinition(StructDefinitionNode structDefinition)
|
private void GenerateStructDefinition(StructDefinitionNode structDefinition)
|
||||||
@@ -370,8 +367,7 @@ public class Generator
|
|||||||
}
|
}
|
||||||
|
|
||||||
var parameter = funcCall.Parameters[i];
|
var parameter = funcCall.Parameters[i];
|
||||||
var parameterOutput = GenerateExpression(parameter);
|
var result = GenerateExpression(parameter);
|
||||||
var result = GenerateTypeConversion(parameterOutput, parameter.Type, expectedType);
|
|
||||||
|
|
||||||
var qbeParameterType = SQT(expectedType.Equals(NubPrimitiveType.Any) ? parameter.Type : expectedType);
|
var qbeParameterType = SQT(expectedType.Equals(NubPrimitiveType.Any) ? parameter.Type : expectedType);
|
||||||
parameterStrings.Add($"{qbeParameterType} {result}");
|
parameterStrings.Add($"{qbeParameterType} {result}");
|
||||||
@@ -436,14 +432,8 @@ public class Generator
|
|||||||
{
|
{
|
||||||
if (@return.Value.HasValue)
|
if (@return.Value.HasValue)
|
||||||
{
|
{
|
||||||
if (!_currentFuncDefininition!.ReturnType.HasValue)
|
|
||||||
{
|
|
||||||
throw new Exception("Cannot return a value when function does not have a return value");
|
|
||||||
}
|
|
||||||
|
|
||||||
var result = GenerateExpression(@return.Value.Value);
|
var result = GenerateExpression(@return.Value.Value);
|
||||||
var converted = GenerateTypeConversion(result, @return.Value.Value.Type, _currentFuncDefininition.ReturnType.Value);
|
_builder.AppendLine($" ret {result}");
|
||||||
_builder.AppendLine($" ret {converted}");
|
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
@@ -454,17 +444,10 @@ public class Generator
|
|||||||
private void GenerateVariableAssignment(VariableAssignmentNode variableAssignment)
|
private void GenerateVariableAssignment(VariableAssignmentNode variableAssignment)
|
||||||
{
|
{
|
||||||
var result = GenerateExpression(variableAssignment.Value);
|
var result = GenerateExpression(variableAssignment.Value);
|
||||||
var variableType = variableAssignment.Value.Type;
|
|
||||||
|
|
||||||
if (variableAssignment.ExplicitType.HasValue)
|
|
||||||
{
|
|
||||||
result = GenerateTypeConversion(result, variableType, variableAssignment.ExplicitType.Value);
|
|
||||||
variableType = variableAssignment.ExplicitType.Value;
|
|
||||||
}
|
|
||||||
_variables[variableAssignment.Name] = new Variable
|
_variables[variableAssignment.Name] = new Variable
|
||||||
{
|
{
|
||||||
Identifier = result,
|
Identifier = result,
|
||||||
Type = variableType
|
Type = variableAssignment.Value.Type
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,292 +0,0 @@
|
|||||||
using Nub.Lang.Frontend.Parsing;
|
|
||||||
|
|
||||||
namespace Nub.Lang.Frontend.Typing;
|
|
||||||
|
|
||||||
public class Func(string name, List<FuncParameter> parameters, Optional<BlockNode> body, Optional<NubType> returnType)
|
|
||||||
{
|
|
||||||
public string Name { get; } = name;
|
|
||||||
public List<FuncParameter> Parameters { get; } = parameters;
|
|
||||||
public Optional<BlockNode> Body { get; } = body;
|
|
||||||
public Optional<NubType> ReturnType { get; } = returnType;
|
|
||||||
}
|
|
||||||
|
|
||||||
public class ExpressionTyper
|
|
||||||
{
|
|
||||||
private readonly List<Func> _functions;
|
|
||||||
private readonly List<StructDefinitionNode> _structDefinitions;
|
|
||||||
private readonly Stack<Variable> _variables;
|
|
||||||
|
|
||||||
public ExpressionTyper(List<DefinitionNode> definitions)
|
|
||||||
{
|
|
||||||
_variables = new Stack<Variable>();
|
|
||||||
_functions = [];
|
|
||||||
|
|
||||||
_structDefinitions = definitions.OfType<StructDefinitionNode>().ToList();
|
|
||||||
|
|
||||||
var functions = definitions
|
|
||||||
.OfType<LocalFuncDefinitionNode>()
|
|
||||||
.Select(f => new Func(f.Name, f.Parameters, f.Body, f.ReturnType))
|
|
||||||
.ToList();
|
|
||||||
|
|
||||||
var externFunctions = definitions
|
|
||||||
.OfType<ExternFuncDefinitionNode>()
|
|
||||||
.Select(f => new Func(f.Name, f.Parameters, Optional<BlockNode>.Empty(), f.ReturnType))
|
|
||||||
.ToList();
|
|
||||||
|
|
||||||
_functions.AddRange(functions);
|
|
||||||
_functions.AddRange(externFunctions);
|
|
||||||
}
|
|
||||||
|
|
||||||
public void Populate()
|
|
||||||
{
|
|
||||||
_variables.Clear();
|
|
||||||
|
|
||||||
foreach (var @class in _structDefinitions)
|
|
||||||
{
|
|
||||||
foreach (var variable in @class.Fields)
|
|
||||||
{
|
|
||||||
if (variable.Value.HasValue)
|
|
||||||
{
|
|
||||||
PopulateExpression(variable.Value.Value);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
foreach (var function in _functions)
|
|
||||||
{
|
|
||||||
foreach (var parameter in function.Parameters)
|
|
||||||
{
|
|
||||||
_variables.Push(new Variable(parameter.Name, parameter.Type));
|
|
||||||
}
|
|
||||||
|
|
||||||
if (function.Body.HasValue)
|
|
||||||
{
|
|
||||||
PopulateBlock(function.Body.Value);
|
|
||||||
}
|
|
||||||
for (var i = 0; i < function.Parameters.Count; i++)
|
|
||||||
{
|
|
||||||
_variables.Pop();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private void PopulateBlock(BlockNode block)
|
|
||||||
{
|
|
||||||
var variableCount = _variables.Count;
|
|
||||||
foreach (var statement in block.Statements)
|
|
||||||
{
|
|
||||||
PopulateStatement(statement);
|
|
||||||
}
|
|
||||||
while (_variables.Count > variableCount)
|
|
||||||
{
|
|
||||||
_variables.Pop();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private void PopulateStatement(StatementNode statement)
|
|
||||||
{
|
|
||||||
switch (statement)
|
|
||||||
{
|
|
||||||
case BreakNode:
|
|
||||||
case ContinueNode:
|
|
||||||
break;
|
|
||||||
case FuncCallStatementNode funcCall:
|
|
||||||
PopulateFuncCallStatement(funcCall);
|
|
||||||
break;
|
|
||||||
case IfNode ifStatement:
|
|
||||||
PopulateIf(ifStatement);
|
|
||||||
break;
|
|
||||||
case ReturnNode returnNode:
|
|
||||||
PopulateReturn(returnNode);
|
|
||||||
break;
|
|
||||||
case VariableAssignmentNode variableAssignment:
|
|
||||||
PopulateVariableAssignment(variableAssignment);
|
|
||||||
break;
|
|
||||||
case WhileNode whileStatement:
|
|
||||||
PopulateWhileStatement(whileStatement);
|
|
||||||
break;
|
|
||||||
default:
|
|
||||||
throw new ArgumentOutOfRangeException(nameof(statement));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private void PopulateFuncCallStatement(FuncCallStatementNode funcCall)
|
|
||||||
{
|
|
||||||
foreach (var parameter in funcCall.FuncCall.Parameters)
|
|
||||||
{
|
|
||||||
PopulateExpression(parameter);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private void PopulateIf(IfNode ifStatement)
|
|
||||||
{
|
|
||||||
PopulateExpression(ifStatement.Condition);
|
|
||||||
PopulateBlock(ifStatement.Body);
|
|
||||||
if (ifStatement.Else.HasValue)
|
|
||||||
{
|
|
||||||
ifStatement.Else.Value.Match
|
|
||||||
(
|
|
||||||
PopulateIf,
|
|
||||||
PopulateBlock
|
|
||||||
);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private void PopulateReturn(ReturnNode returnNode)
|
|
||||||
{
|
|
||||||
if (returnNode.Value.HasValue)
|
|
||||||
{
|
|
||||||
PopulateExpression(returnNode.Value.Value);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private void PopulateVariableAssignment(VariableAssignmentNode variableAssignment)
|
|
||||||
{
|
|
||||||
PopulateExpression(variableAssignment.Value);
|
|
||||||
_variables.Push(new Variable(variableAssignment.Name, variableAssignment.ExplicitType.HasValue ? variableAssignment.ExplicitType.Value : variableAssignment.Value.Type));
|
|
||||||
}
|
|
||||||
|
|
||||||
private void PopulateVariableReassignment(VariableAssignmentNode variableAssignment)
|
|
||||||
{
|
|
||||||
PopulateExpression(variableAssignment.Value);
|
|
||||||
}
|
|
||||||
|
|
||||||
private void PopulateWhileStatement(WhileNode whileStatement)
|
|
||||||
{
|
|
||||||
PopulateExpression(whileStatement.Condition);
|
|
||||||
PopulateBlock(whileStatement.Body);
|
|
||||||
}
|
|
||||||
|
|
||||||
private void PopulateExpression(ExpressionNode expression)
|
|
||||||
{
|
|
||||||
switch (expression)
|
|
||||||
{
|
|
||||||
case BinaryExpressionNode binaryExpression:
|
|
||||||
PopulateBinaryExpression(binaryExpression);
|
|
||||||
break;
|
|
||||||
case FuncCallExpressionNode funcCall:
|
|
||||||
PopulateFuncCallExpression(funcCall);
|
|
||||||
break;
|
|
||||||
case IdentifierNode identifier:
|
|
||||||
PopulateIdentifier(identifier);
|
|
||||||
break;
|
|
||||||
case LiteralNode literal:
|
|
||||||
PopulateLiteral(literal);
|
|
||||||
break;
|
|
||||||
case StructInitializerNode structInitializer:
|
|
||||||
PopulateStructInitializer(structInitializer);
|
|
||||||
break;
|
|
||||||
case StructFieldAccessorNode structMemberAccessor:
|
|
||||||
PopulateStructMemberAccessorNode(structMemberAccessor);
|
|
||||||
break;
|
|
||||||
default:
|
|
||||||
throw new ArgumentOutOfRangeException(nameof(expression));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private void PopulateBinaryExpression(BinaryExpressionNode binaryExpression)
|
|
||||||
{
|
|
||||||
PopulateExpression(binaryExpression.Left);
|
|
||||||
PopulateExpression(binaryExpression.Right);
|
|
||||||
switch (binaryExpression.Operator)
|
|
||||||
{
|
|
||||||
case BinaryExpressionOperator.Equal:
|
|
||||||
case BinaryExpressionOperator.NotEqual:
|
|
||||||
case BinaryExpressionOperator.GreaterThan:
|
|
||||||
case BinaryExpressionOperator.GreaterThanOrEqual:
|
|
||||||
case BinaryExpressionOperator.LessThan:
|
|
||||||
case BinaryExpressionOperator.LessThanOrEqual:
|
|
||||||
{
|
|
||||||
binaryExpression.Type = NubPrimitiveType.Bool;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
case BinaryExpressionOperator.Plus:
|
|
||||||
case BinaryExpressionOperator.Minus:
|
|
||||||
case BinaryExpressionOperator.Multiply:
|
|
||||||
case BinaryExpressionOperator.Divide:
|
|
||||||
{
|
|
||||||
binaryExpression.Type = binaryExpression.Left.Type;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
{
|
|
||||||
throw new ArgumentOutOfRangeException(nameof(binaryExpression.Operator));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private void PopulateFuncCallExpression(FuncCallExpressionNode funcCall)
|
|
||||||
{
|
|
||||||
foreach (var parameter in funcCall.FuncCall.Parameters)
|
|
||||||
{
|
|
||||||
PopulateExpression(parameter);
|
|
||||||
}
|
|
||||||
|
|
||||||
var function = _functions.FirstOrDefault(f => f.Name == funcCall.FuncCall.Name);
|
|
||||||
if (function == null)
|
|
||||||
{
|
|
||||||
throw new Exception($"Func {funcCall} is not defined");
|
|
||||||
}
|
|
||||||
if (!function.ReturnType.HasValue)
|
|
||||||
{
|
|
||||||
throw new Exception($"Func {funcCall} must have a return type when used in an expression");
|
|
||||||
}
|
|
||||||
funcCall.Type = function.ReturnType.Value;
|
|
||||||
}
|
|
||||||
|
|
||||||
private void PopulateIdentifier(IdentifierNode identifier)
|
|
||||||
{
|
|
||||||
var type = _variables.FirstOrDefault(v => v.Name == identifier.Identifier)?.Type;
|
|
||||||
if (type == null)
|
|
||||||
{
|
|
||||||
throw new Exception($"Variable {identifier} is not defined");
|
|
||||||
}
|
|
||||||
identifier.Type = type;
|
|
||||||
}
|
|
||||||
|
|
||||||
private static void PopulateLiteral(LiteralNode literal)
|
|
||||||
{
|
|
||||||
literal.Type = literal.LiteralType;
|
|
||||||
}
|
|
||||||
|
|
||||||
private void PopulateStructInitializer(StructInitializerNode structInitializer)
|
|
||||||
{
|
|
||||||
foreach (var initializer in structInitializer.Initializers)
|
|
||||||
{
|
|
||||||
PopulateExpression(initializer.Value);
|
|
||||||
}
|
|
||||||
|
|
||||||
structInitializer.Type = structInitializer.StructType;
|
|
||||||
}
|
|
||||||
|
|
||||||
private void PopulateStructMemberAccessorNode(StructFieldAccessorNode structFieldAccessor)
|
|
||||||
{
|
|
||||||
PopulateExpression(structFieldAccessor.Struct);
|
|
||||||
|
|
||||||
var structType = structFieldAccessor.Struct.Type;
|
|
||||||
if (structType == null)
|
|
||||||
{
|
|
||||||
throw new Exception($"Cannot access field on non-struct type: {structFieldAccessor.Struct}");
|
|
||||||
}
|
|
||||||
|
|
||||||
var structDefinition = _structDefinitions.FirstOrDefault(s => s.Name == structType.Name);
|
|
||||||
if (structDefinition == null)
|
|
||||||
{
|
|
||||||
throw new Exception($"Struct {structType.Name} is not defined");
|
|
||||||
}
|
|
||||||
|
|
||||||
var field = structDefinition.Fields.FirstOrDefault(f => f.Name == structFieldAccessor.Field);
|
|
||||||
if (field == null)
|
|
||||||
{
|
|
||||||
throw new Exception($"Field {structFieldAccessor.Field} is not defined in struct {structType.Name}");
|
|
||||||
}
|
|
||||||
|
|
||||||
structFieldAccessor.Type = field.Type;
|
|
||||||
}
|
|
||||||
|
|
||||||
private class Variable(string name, NubType type)
|
|
||||||
{
|
|
||||||
public string Name { get; } = name;
|
|
||||||
public NubType Type { get; } = type;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
412
src/compiler/Nub.Lang/Frontend/Typing/TypeChecker.cs
Normal file
412
src/compiler/Nub.Lang/Frontend/Typing/TypeChecker.cs
Normal file
@@ -0,0 +1,412 @@
|
|||||||
|
using Nub.Lang.Frontend.Parsing;
|
||||||
|
|
||||||
|
namespace Nub.Lang.Frontend.Typing;
|
||||||
|
|
||||||
|
public class TypeCheckingException : Exception
|
||||||
|
{
|
||||||
|
public TypeCheckingException(string message) : base(message) { }
|
||||||
|
}
|
||||||
|
|
||||||
|
public class TypeChecker
|
||||||
|
{
|
||||||
|
private readonly Dictionary<string, NubType> _variables = new();
|
||||||
|
private readonly Dictionary<string, (List<FuncParameter> Parameters, Optional<NubType> ReturnType)> _functions = new();
|
||||||
|
private readonly Dictionary<string, Dictionary<string, NubType>> _structs = new();
|
||||||
|
private NubType? _currentFunctionReturnType;
|
||||||
|
private bool _hasReturnStatement;
|
||||||
|
|
||||||
|
public void TypeCheck(List<DefinitionNode> definitions)
|
||||||
|
{
|
||||||
|
CollectDefinitions(definitions);
|
||||||
|
|
||||||
|
foreach (var definition in definitions)
|
||||||
|
{
|
||||||
|
if (definition is LocalFuncDefinitionNode funcDef)
|
||||||
|
{
|
||||||
|
TypeCheckFunction(funcDef);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private void CollectDefinitions(List<DefinitionNode> definitions)
|
||||||
|
{
|
||||||
|
foreach (var definition in definitions)
|
||||||
|
{
|
||||||
|
switch (definition)
|
||||||
|
{
|
||||||
|
case StructDefinitionNode structDef:
|
||||||
|
RegisterStruct(structDef);
|
||||||
|
break;
|
||||||
|
case LocalFuncDefinitionNode funcDef:
|
||||||
|
RegisterFunction(funcDef);
|
||||||
|
break;
|
||||||
|
case ExternFuncDefinitionNode externFuncDef:
|
||||||
|
RegisterExternFunction(externFuncDef);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private void RegisterStruct(StructDefinitionNode structDef)
|
||||||
|
{
|
||||||
|
var fields = new Dictionary<string, NubType>();
|
||||||
|
foreach (var field in structDef.Fields)
|
||||||
|
{
|
||||||
|
if (fields.ContainsKey(field.Name))
|
||||||
|
{
|
||||||
|
throw new TypeCheckingException($"Duplicate field '{field.Name}' in struct '{structDef.Name}'");
|
||||||
|
}
|
||||||
|
fields[field.Name] = field.Type;
|
||||||
|
}
|
||||||
|
_structs[structDef.Name] = fields;
|
||||||
|
}
|
||||||
|
|
||||||
|
private void RegisterFunction(LocalFuncDefinitionNode funcDef)
|
||||||
|
{
|
||||||
|
_functions[funcDef.Name] = (funcDef.Parameters, funcDef.ReturnType);
|
||||||
|
}
|
||||||
|
|
||||||
|
private void RegisterExternFunction(ExternFuncDefinitionNode funcDef)
|
||||||
|
{
|
||||||
|
_functions[funcDef.Name] = (funcDef.Parameters, funcDef.ReturnType);
|
||||||
|
}
|
||||||
|
|
||||||
|
private void TypeCheckFunction(LocalFuncDefinitionNode funcDef)
|
||||||
|
{
|
||||||
|
_variables.Clear();
|
||||||
|
_currentFunctionReturnType = funcDef.ReturnType.HasValue ? funcDef.ReturnType.Value : null;
|
||||||
|
_hasReturnStatement = false;
|
||||||
|
|
||||||
|
foreach (var param in funcDef.Parameters)
|
||||||
|
{
|
||||||
|
_variables[param.Name] = param.Type;
|
||||||
|
}
|
||||||
|
|
||||||
|
TypeCheckBlock(funcDef.Body);
|
||||||
|
|
||||||
|
if (_currentFunctionReturnType != null && !_hasReturnStatement)
|
||||||
|
{
|
||||||
|
throw new TypeCheckingException($"Function '{funcDef.Name}' must return a value of type '{_currentFunctionReturnType}'");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private void TypeCheckBlock(BlockNode block)
|
||||||
|
{
|
||||||
|
foreach (var statement in block.Statements)
|
||||||
|
{
|
||||||
|
TypeCheckStatement(statement);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private void TypeCheckStatement(StatementNode statement)
|
||||||
|
{
|
||||||
|
switch (statement)
|
||||||
|
{
|
||||||
|
case VariableAssignmentNode varAssign:
|
||||||
|
TypeCheckVariableAssignment(varAssign);
|
||||||
|
break;
|
||||||
|
case FuncCallStatementNode funcCall:
|
||||||
|
TypeCheckFuncCall(funcCall.FuncCall);
|
||||||
|
break;
|
||||||
|
case IfNode ifNode:
|
||||||
|
TypeCheckIf(ifNode);
|
||||||
|
break;
|
||||||
|
case WhileNode whileNode:
|
||||||
|
TypeCheckWhile(whileNode);
|
||||||
|
break;
|
||||||
|
case ReturnNode returnNode:
|
||||||
|
TypeCheckReturn(returnNode);
|
||||||
|
break;
|
||||||
|
case BreakNode:
|
||||||
|
case ContinueNode:
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
throw new TypeCheckingException($"Unsupported statement type: {statement.GetType().Name}");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private void TypeCheckVariableAssignment(VariableAssignmentNode varAssign)
|
||||||
|
{
|
||||||
|
var valueType = TypeCheckExpression(varAssign.Value);
|
||||||
|
|
||||||
|
if (varAssign.ExplicitType.HasValue)
|
||||||
|
{
|
||||||
|
var explicitType = varAssign.ExplicitType.Value;
|
||||||
|
if (!AreTypesCompatible(valueType, explicitType))
|
||||||
|
{
|
||||||
|
throw new TypeCheckingException($"Cannot assign expression of type '{valueType}' to variable '{varAssign.Name}' of type '{explicitType}'");
|
||||||
|
}
|
||||||
|
_variables[varAssign.Name] = explicitType;
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
_variables[varAssign.Name] = valueType;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private NubType TypeCheckFuncCall(FuncCall funcCall)
|
||||||
|
{
|
||||||
|
if (!_functions.TryGetValue(funcCall.Name, out var funcSignature))
|
||||||
|
{
|
||||||
|
throw new TypeCheckingException($"Function '{funcCall.Name}' is not defined");
|
||||||
|
}
|
||||||
|
|
||||||
|
var paramTypes = funcSignature.Parameters;
|
||||||
|
if (paramTypes.Take(paramTypes.Count - 1).Any(x => x.Variadic))
|
||||||
|
{
|
||||||
|
throw new TypeCheckingException($"Function '{funcCall.Name}' has multiple variadic parameters");
|
||||||
|
}
|
||||||
|
|
||||||
|
for (var i = 0; i < funcCall.Parameters.Count; i++)
|
||||||
|
{
|
||||||
|
var argType = TypeCheckExpression(funcCall.Parameters[i]);
|
||||||
|
|
||||||
|
NubType paramType;
|
||||||
|
if (i < paramTypes.Count)
|
||||||
|
{
|
||||||
|
paramType = paramTypes[i].Type;
|
||||||
|
}
|
||||||
|
else if (paramTypes.LastOrDefault()?.Variadic ?? false)
|
||||||
|
{
|
||||||
|
return paramTypes[^1].Type;
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
throw new TypeCheckingException($"Function '{funcCall.Name}' does not take {funcCall.Parameters.Count} parameters");
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!AreTypesCompatible(argType, paramType))
|
||||||
|
{
|
||||||
|
throw new TypeCheckingException($"Parameter {i} of function '{funcCall.Name}' expects type '{paramType}', but got '{argType}'");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return funcSignature.ReturnType.HasValue ? funcSignature.ReturnType.Value : NubPrimitiveType.Any;
|
||||||
|
}
|
||||||
|
|
||||||
|
private void TypeCheckIf(IfNode ifNode)
|
||||||
|
{
|
||||||
|
var conditionType = TypeCheckExpression(ifNode.Condition);
|
||||||
|
if (!conditionType.Equals(NubPrimitiveType.Bool))
|
||||||
|
{
|
||||||
|
throw new TypeCheckingException($"If condition must be a boolean expression, got '{conditionType}'");
|
||||||
|
}
|
||||||
|
|
||||||
|
TypeCheckBlock(ifNode.Body);
|
||||||
|
|
||||||
|
if (ifNode.Else.HasValue)
|
||||||
|
{
|
||||||
|
var elseValue = ifNode.Else.Value;
|
||||||
|
elseValue.Match(TypeCheckIf, TypeCheckBlock);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private void TypeCheckWhile(WhileNode whileNode)
|
||||||
|
{
|
||||||
|
var conditionType = TypeCheckExpression(whileNode.Condition);
|
||||||
|
if (!conditionType.Equals(NubPrimitiveType.Bool))
|
||||||
|
{
|
||||||
|
throw new TypeCheckingException($"While condition must be a boolean expression, got '{conditionType}'");
|
||||||
|
}
|
||||||
|
|
||||||
|
TypeCheckBlock(whileNode.Body);
|
||||||
|
}
|
||||||
|
|
||||||
|
private void TypeCheckReturn(ReturnNode returnNode)
|
||||||
|
{
|
||||||
|
_hasReturnStatement = true;
|
||||||
|
|
||||||
|
if (returnNode.Value.HasValue)
|
||||||
|
{
|
||||||
|
var returnType = TypeCheckExpression(returnNode.Value.Value);
|
||||||
|
|
||||||
|
if (_currentFunctionReturnType == null)
|
||||||
|
{
|
||||||
|
throw new TypeCheckingException("Cannot return a value from a function with no return type");
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!AreTypesCompatible(returnType, _currentFunctionReturnType))
|
||||||
|
{
|
||||||
|
throw new TypeCheckingException($"Return value of type '{returnType}' is not compatible with function return type '{_currentFunctionReturnType}'");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else if (_currentFunctionReturnType != null)
|
||||||
|
{
|
||||||
|
throw new TypeCheckingException($"Function must return a value of type '{_currentFunctionReturnType}'");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private NubType TypeCheckExpression(ExpressionNode expression)
|
||||||
|
{
|
||||||
|
NubType resultType;
|
||||||
|
|
||||||
|
switch (expression)
|
||||||
|
{
|
||||||
|
case LiteralNode literal:
|
||||||
|
resultType = literal.LiteralType;
|
||||||
|
break;
|
||||||
|
|
||||||
|
case IdentifierNode identifier:
|
||||||
|
if (!_variables.TryGetValue(identifier.Identifier, out var varType))
|
||||||
|
{
|
||||||
|
throw new TypeCheckingException($"Variable '{identifier.Identifier}' is not defined");
|
||||||
|
}
|
||||||
|
resultType = varType;
|
||||||
|
break;
|
||||||
|
|
||||||
|
case BinaryExpressionNode binaryExpr:
|
||||||
|
resultType = TypeCheckBinaryExpression(binaryExpr);
|
||||||
|
break;
|
||||||
|
|
||||||
|
case FuncCallExpressionNode funcCallExpr:
|
||||||
|
resultType = TypeCheckFuncCall(funcCallExpr.FuncCall);
|
||||||
|
break;
|
||||||
|
|
||||||
|
case StructInitializerNode structInit:
|
||||||
|
resultType = TypeCheckStructInitializer(structInit);
|
||||||
|
break;
|
||||||
|
|
||||||
|
case StructFieldAccessorNode fieldAccess:
|
||||||
|
resultType = TypeCheckStructFieldAccess(fieldAccess);
|
||||||
|
break;
|
||||||
|
|
||||||
|
default:
|
||||||
|
throw new TypeCheckingException($"Unsupported expression type: {expression.GetType().Name}");
|
||||||
|
}
|
||||||
|
|
||||||
|
expression.Type = resultType;
|
||||||
|
return resultType;
|
||||||
|
}
|
||||||
|
|
||||||
|
private NubType TypeCheckBinaryExpression(BinaryExpressionNode binaryExpr)
|
||||||
|
{
|
||||||
|
var leftType = TypeCheckExpression(binaryExpr.Left);
|
||||||
|
var rightType = TypeCheckExpression(binaryExpr.Right);
|
||||||
|
|
||||||
|
if (!leftType.Equals(rightType))
|
||||||
|
{
|
||||||
|
throw new TypeCheckingException($"Left '{leftType}' and right '{rightType}' side of the binary expression is not equal");
|
||||||
|
}
|
||||||
|
|
||||||
|
switch (binaryExpr.Operator)
|
||||||
|
{
|
||||||
|
case BinaryExpressionOperator.Equal:
|
||||||
|
case BinaryExpressionOperator.NotEqual:
|
||||||
|
return NubPrimitiveType.Bool;
|
||||||
|
case BinaryExpressionOperator.GreaterThan:
|
||||||
|
case BinaryExpressionOperator.GreaterThanOrEqual:
|
||||||
|
case BinaryExpressionOperator.LessThan:
|
||||||
|
case BinaryExpressionOperator.LessThanOrEqual:
|
||||||
|
if (!IsNumeric(leftType))
|
||||||
|
{
|
||||||
|
throw new TypeCheckingException($"Comparison operators require numeric operands, got '{leftType}' and '{rightType}'");
|
||||||
|
}
|
||||||
|
return NubPrimitiveType.Bool;
|
||||||
|
case BinaryExpressionOperator.Plus:
|
||||||
|
case BinaryExpressionOperator.Minus:
|
||||||
|
case BinaryExpressionOperator.Multiply:
|
||||||
|
case BinaryExpressionOperator.Divide:
|
||||||
|
if (!IsNumeric(leftType))
|
||||||
|
{
|
||||||
|
throw new TypeCheckingException($"Arithmetic operators require numeric operands, got '{leftType}' and '{rightType}'");
|
||||||
|
}
|
||||||
|
return leftType;
|
||||||
|
default:
|
||||||
|
throw new TypeCheckingException($"Unsupported binary operator: {binaryExpr.Operator}");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private NubType TypeCheckStructInitializer(StructInitializerNode structInit)
|
||||||
|
{
|
||||||
|
var structType = structInit.StructType;
|
||||||
|
if (structType is not NubCustomType customType)
|
||||||
|
{
|
||||||
|
throw new TypeCheckingException($"Type '{structType}' is not a struct type");
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!_structs.TryGetValue(customType.Name, out var fields))
|
||||||
|
{
|
||||||
|
throw new TypeCheckingException($"Struct type '{customType.Name}' is not defined");
|
||||||
|
}
|
||||||
|
|
||||||
|
foreach (var initializer in structInit.Initializers)
|
||||||
|
{
|
||||||
|
if (!fields.TryGetValue(initializer.Key, out var fieldType))
|
||||||
|
{
|
||||||
|
throw new TypeCheckingException($"Field '{initializer.Key}' does not exist in struct '{customType.Name}'");
|
||||||
|
}
|
||||||
|
|
||||||
|
var initializerType = TypeCheckExpression(initializer.Value);
|
||||||
|
if (!AreTypesCompatible(initializerType, fieldType))
|
||||||
|
{
|
||||||
|
throw new TypeCheckingException($"Cannot initialize field '{initializer.Key}' of type '{fieldType}' with expression of type '{initializerType}'");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
foreach (var field in fields)
|
||||||
|
{
|
||||||
|
if (!structInit.Initializers.ContainsKey(field.Key))
|
||||||
|
{
|
||||||
|
throw new TypeCheckingException($"Field '{field.Key}' of struct '{customType.Name}' is not initialized");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return structType;
|
||||||
|
}
|
||||||
|
|
||||||
|
private NubType TypeCheckStructFieldAccess(StructFieldAccessorNode fieldAccess)
|
||||||
|
{
|
||||||
|
var structType = TypeCheckExpression(fieldAccess.Struct);
|
||||||
|
|
||||||
|
if (structType is not NubCustomType customType)
|
||||||
|
{
|
||||||
|
throw new TypeCheckingException($"Cannot access field '{fieldAccess.Field}' on non-struct type '{structType}'");
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!_structs.TryGetValue(customType.Name, out var fields))
|
||||||
|
{
|
||||||
|
throw new TypeCheckingException($"Struct type '{customType.Name}' is not defined");
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!fields.TryGetValue(fieldAccess.Field, out var fieldType))
|
||||||
|
{
|
||||||
|
throw new TypeCheckingException($"Field '{fieldAccess.Field}' does not exist in struct '{customType.Name}'");
|
||||||
|
}
|
||||||
|
|
||||||
|
return fieldType;
|
||||||
|
}
|
||||||
|
|
||||||
|
#region Type Helper Methods
|
||||||
|
|
||||||
|
private static bool AreTypesCompatible(NubType sourceType, NubType targetType)
|
||||||
|
{
|
||||||
|
return targetType.Equals(NubPrimitiveType.Any) || sourceType.Equals(targetType);
|
||||||
|
}
|
||||||
|
|
||||||
|
private static bool IsNumeric(NubType type)
|
||||||
|
{
|
||||||
|
if (type is not NubPrimitiveType primitiveType)
|
||||||
|
{
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
switch (primitiveType.Kind)
|
||||||
|
{
|
||||||
|
case PrimitiveTypeKind.I8:
|
||||||
|
case PrimitiveTypeKind.I16:
|
||||||
|
case PrimitiveTypeKind.I32:
|
||||||
|
case PrimitiveTypeKind.I64:
|
||||||
|
case PrimitiveTypeKind.U8:
|
||||||
|
case PrimitiveTypeKind.U16:
|
||||||
|
case PrimitiveTypeKind.U32:
|
||||||
|
case PrimitiveTypeKind.U64:
|
||||||
|
case PrimitiveTypeKind.F32:
|
||||||
|
case PrimitiveTypeKind.F64:
|
||||||
|
return true;
|
||||||
|
default:
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#endregion
|
||||||
|
}
|
||||||
@@ -44,8 +44,8 @@ internal static class Program
|
|||||||
var modules = RunFrontend(input);
|
var modules = RunFrontend(input);
|
||||||
var definitions = modules.SelectMany(f => f.Definitions).ToList();
|
var definitions = modules.SelectMany(f => f.Definitions).ToList();
|
||||||
|
|
||||||
var typer = new ExpressionTyper(definitions);
|
var typeChecker = new TypeChecker();
|
||||||
typer.Populate();
|
typeChecker.TypeCheck(definitions);
|
||||||
|
|
||||||
var generator = new Generator(definitions);
|
var generator = new Generator(definitions);
|
||||||
var result = generator.Generate();
|
var result = generator.Generate();
|
||||||
|
|||||||
Reference in New Issue
Block a user