using Common;
using Syntax.Diagnostics;
using Syntax.Parsing;
using Syntax.Parsing.Node;
using Syntax.Tokenization;
using Syntax.Typing.BoundNode;
using UnaryExpressionNode = Syntax.Parsing.Node.UnaryExpressionNode;

namespace Syntax.Typing;

/// <summary>
/// Walks a parsed <see cref="SyntaxTree"/> and produces the corresponding bound tree, attaching a
/// <see cref="NubType"/> to every expression along the way.
/// </summary>
/// <remarks>
/// Binding state lives in static fields that every call to <see cref="Bind"/> overwrites, so calls
/// must not overlap. Situations that should eventually become diagnostics currently throw
/// <see cref="NotImplementedException"/> with the message "Diagnostics not implemented".
/// NOTE(review): the generic type arguments in this file were reconstructed from usage; confirm
/// <c>IEnumerable&lt;Diagnostic&gt;</c> and <c>Variant&lt;BoundIfNode, BoundBlockNode&gt;</c> against
/// the declaring types.
/// </remarks>
public static class Binder
{
    private static SyntaxTree _syntaxTree = null!;
    private static DefinitionTable _definitionTable = null!;

    // Variables (parameters and locals) visible at the point currently being bound, by name.
    private static Dictionary<string, NubType> _variables = new();

    // Return type of the function whose body is currently being bound; used as the expected type
    // of `return` values.
    private static NubType? _funcReturnType;

    /// <summary>
    /// Binds every definition of <paramref name="syntaxTree"/>.
    /// </summary>
    /// <param name="syntaxTree">The parsed tree to bind.</param>
    /// <param name="definitionTable">Table used to look up functions and structs.</param>
    /// <param name="diagnostics">Always assigned an empty collection; diagnostics are not produced yet.</param>
    /// <returns>The bound tree, carrying the namespace of <paramref name="syntaxTree"/>.</returns>
    public static BoundSyntaxTree Bind(SyntaxTree syntaxTree, DefinitionTable definitionTable, out IEnumerable<Diagnostic> diagnostics)
    {
        _syntaxTree = syntaxTree;
        _definitionTable = definitionTable;
        _variables = [];
        _funcReturnType = null;

        var definitions = new List<BoundDefinitionNode>();
        foreach (var definition in syntaxTree.Definitions)
        {
            definitions.Add(BindDefinition(definition));
        }

        diagnostics = [];
        return new BoundSyntaxTree(syntaxTree.Namespace, definitions);
    }

    /// <summary>Dispatches a definition to the binder for its concrete node type.</summary>
    /// <exception cref="ArgumentOutOfRangeException">The node type is not handled.</exception>
    private static BoundDefinitionNode BindDefinition(DefinitionNode node)
    {
        return node switch
        {
            ExternFuncDefinitionNode definition => BindExternFuncDefinition(definition),
            ImplementationDefinitionNode definition => BindImplementation(definition),
            InterfaceDefinitionNode definition => BindInterfaceDefinition(definition),
            LocalFuncDefinitionNode definition => BindLocalFuncDefinition(definition),
            StructDefinitionNode definition => BindStruct(definition),
            _ => throw new ArgumentOutOfRangeException(nameof(node))
        };
    }

    /// <summary>Binds every function of an implementation block, including their bodies.</summary>
    private static BoundImplementationDefinitionNode BindImplementation(ImplementationDefinitionNode node)
    {
        var functions = new List<BoundImplementationFunc>();

        foreach (var function in node.Functions)
        {
            // Every function starts from a clean scope and its own return type; otherwise the
            // parameters of earlier functions stay visible and `return` statements are bound
            // against whatever function was bound last.
            _variables.Clear();
            _funcReturnType = function.ReturnType;

            var parameters = new List<BoundFuncParameter>();
            foreach (var parameter in function.Parameters)
            {
                parameters.Add(new BoundFuncParameter(parameter.Name, parameter.Type));
                _variables[parameter.Name] = parameter.Type;
            }

            functions.Add(new BoundImplementationFunc(function.Name, parameters, function.ReturnType, BindBlock(function.Body)));
        }

        return new BoundImplementationDefinitionNode(node.Tokens, node.Documentation, node.Namespace, node.Type, node.Interface, functions);
    }

    /// <summary>Binds the function signatures declared by an interface.</summary>
    private static BoundInterfaceDefinitionNode BindInterfaceDefinition(InterfaceDefinitionNode node)
    {
        var functions = new List<BoundInterfaceFunc>();

        foreach (var func in node.Functions)
        {
            var parameters = new List<BoundFuncParameter>();
            foreach (var parameter in func.Parameters)
            {
                parameters.Add(new BoundFuncParameter(parameter.Name, parameter.Type));
            }

            functions.Add(new BoundInterfaceFunc(func.Name, parameters, func.ReturnType));
        }

        return new BoundInterfaceDefinitionNode(node.Tokens, node.Documentation, node.Namespace, node.Name, functions);
    }

    /// <summary>
    /// Binds a struct definition. Field default values are bound with the type of the matching
    /// field from the definition table as the expected type.
    /// </summary>
    /// <exception cref="NotImplementedException">
    /// The struct, or a field that has a default value, is missing from the definition table.
    /// </exception>
    private static BoundStructDefinitionNode BindStruct(StructDefinitionNode node)
    {
        var defOpt = _definitionTable.LookupStruct(node.Namespace, node.Name);
        if (!defOpt.TryGetValue(out var definition))
        {
            throw new NotImplementedException("Diagnostics not implemented");
        }

        var structFields = new List<BoundStructField>();

        foreach (var structField in node.Fields)
        {
            var value = Optional.Empty<BoundExpressionNode>();

            if (structField.Value.HasValue)
            {
                var definitionField = definition.Fields.FirstOrDefault(f => f.Name == structField.Name);
                if (definitionField == null)
                {
                    throw new NotImplementedException("Diagnostics not implemented");
                }

                value = BindExpression(structField.Value.Value, definitionField.Type);
            }

            structFields.Add(new BoundStructField(structField.Name, structField.Type, value));
        }

        return new BoundStructDefinitionNode(node.Tokens, node.Documentation, node.Namespace, node.Name, structFields);
    }

    /// <summary>Binds the signature of an extern function; there is no body to bind.</summary>
    private static BoundExternFuncDefinitionNode BindExternFuncDefinition(ExternFuncDefinitionNode node)
    {
        var parameters = new List<BoundFuncParameter>();
        foreach (var parameter in node.Parameters)
        {
            parameters.Add(new BoundFuncParameter(parameter.Name, parameter.Type));
        }

        return new BoundExternFuncDefinitionNode(node.Tokens, node.Documentation, node.Namespace, node.Name, node.CallName, parameters, node.ReturnType);
    }

    /// <summary>
    /// Binds a function with a body: resets the variable scope, registers the parameters and
    /// binds the body against the function's return type.
    /// </summary>
    private static BoundLocalFuncDefinitionNode BindLocalFuncDefinition(LocalFuncDefinitionNode node)
    {
        _variables.Clear();
        _funcReturnType = node.ReturnType;

        var parameters = new List<BoundFuncParameter>();
        foreach (var parameter in node.Parameters)
        {
            parameters.Add(new BoundFuncParameter(parameter.Name, parameter.Type));
            _variables[parameter.Name] = parameter.Type;
        }

        var body = BindBlock(node.Body);

        return new BoundLocalFuncDefinitionNode(node.Tokens, node.Documentation, node.Namespace, node.Name, parameters, body, node.ReturnType, node.Exported);
    }

    /// <summary>Binds the statements of a block in order.</summary>
    private static BoundBlockNode BindBlock(BlockNode node)
    {
        var statements = new List<BoundStatementNode>();
        foreach (var statement in node.Statements)
        {
            statements.Add(BindStatement(statement));
        }

        return new BoundBlockNode(node.Tokens, statements);
    }

    /// <summary>Dispatches a statement to the binder for its concrete node type.</summary>
    /// <exception cref="ArgumentOutOfRangeException">The node type is not handled.</exception>
    private static BoundStatementNode BindStatement(StatementNode node)
    {
        return node switch
        {
            AssignmentNode statement => BindAssignment(statement),
            BreakNode statement => BindBreak(statement),
            ContinueNode statement => BindContinue(statement),
            IfNode statement => BindIf(statement),
            ReturnNode statement => BindReturn(statement),
            StatementExpressionNode statement => BindStatementExpression(statement),
            VariableDeclarationNode statement => BindVariableDeclaration(statement),
            WhileNode statement => BindWhile(statement),
            _ => throw new ArgumentOutOfRangeException(nameof(node))
        };
    }

    /// <summary>Binds an assignment; the value is bound with the target's type as the expected type.</summary>
    private static BoundStatementNode BindAssignment(AssignmentNode statement)
    {
        var expression = BindExpression(statement.Expression);
        var value = BindExpression(statement.Value, expression.Type);
        return new BoundAssignmentNode(statement.Tokens, expression, value);
    }

    private static BoundBreakNode BindBreak(BreakNode statement)
    {
        return new BoundBreakNode(statement.Tokens);
    }

    private static BoundContinueNode BindContinue(ContinueNode statement)
    {
        return new BoundContinueNode(statement.Tokens);
    }

    /// <summary>
    /// Binds an <c>if</c> statement. The optional else part is either another <c>if</c>
    /// (else-if chain) or a plain block; it is bound before the condition and the body.
    /// </summary>
    private static BoundIfNode BindIf(IfNode statement)
    {
        var elseStatement = Optional.Empty<Variant<BoundIfNode, BoundBlockNode>>();

        if (statement.Else.HasValue)
        {
            elseStatement = statement.Else.Value.Match<Variant<BoundIfNode, BoundBlockNode>>
            (
                elseIf => BindIf(elseIf),
                @else => BindBlock(@else)
            );
        }

        return new BoundIfNode(statement.Tokens, BindExpression(statement.Condition, NubPrimitiveType.Bool), BindBlock(statement.Body), elseStatement);
    }

    /// <summary>Binds a <c>return</c>; a value, if present, is bound against the current function's return type.</summary>
    private static BoundReturnNode BindReturn(ReturnNode statement)
    {
        var value = Optional.Empty<BoundExpressionNode>();

        if (statement.Value.HasValue)
        {
            value = BindExpression(statement.Value.Value, _funcReturnType);
        }

        return new BoundReturnNode(statement.Tokens, value);
    }

    private static BoundStatementExpressionNode BindStatementExpression(StatementExpressionNode statement)
    {
        return new BoundStatementExpressionNode(statement.Tokens, BindExpression(statement.Expression));
    }

    /// <summary>
    /// Binds a variable declaration and registers the variable in the current scope. The type of
    /// the bound initializer wins when one is present; otherwise the explicit type is used.
    /// </summary>
    /// <exception cref="NotImplementedException">Neither an explicit type nor an initializer is present.</exception>
    private static BoundVariableDeclarationNode BindVariableDeclaration(VariableDeclarationNode statement)
    {
        NubType? type = null;

        if (statement.ExplicitType.HasValue)
        {
            type = statement.ExplicitType.Value;
        }

        var assignment = Optional.Empty<BoundExpressionNode>();
        if (statement.Assignment.HasValue)
        {
            var boundValue = BindExpression(statement.Assignment.Value, type);
            assignment = boundValue;
            type = boundValue.Type;
        }

        if (type == null)
        {
            throw new NotImplementedException("Diagnostics not implemented");
        }

        _variables[statement.Name] = type;

        return new BoundVariableDeclarationNode(statement.Tokens, statement.Name, statement.ExplicitType, assignment, type);
    }

    private static BoundWhileNode BindWhile(WhileNode statement)
    {
        return new BoundWhileNode(statement.Tokens, BindExpression(statement.Condition, NubPrimitiveType.Bool), BindBlock(statement.Body));
    }

    /// <summary>
    /// Dispatches an expression to the binder for its concrete node type.
    /// </summary>
    /// <param name="node">The expression to bind.</param>
    /// <param name="expectedType">Type hint from the context; only literals consume it.</param>
    /// <exception cref="ArgumentOutOfRangeException">The node type is not handled.</exception>
    private static BoundExpressionNode BindExpression(ExpressionNode node, NubType? expectedType = null)
    {
        return node switch
        {
            AddressOfNode expression => BindAddressOf(expression),
            AnonymousFuncNode expression => BindAnonymousFunc(expression),
            ArrayIndexAccessNode expression => BindArrayIndexAccess(expression),
            ArrayInitializerNode expression => BindArrayInitializer(expression),
            BinaryExpressionNode expression => BindBinaryExpression(expression),
            DereferenceNode expression => BindDereference(expression),
            FuncCallNode expression => BindFuncCall(expression),
            IdentifierNode expression => BindIdentifier(expression),
            LiteralNode expression => BindLiteral(expression, expectedType),
            MemberAccessNode expression => BindMemberAccess(expression),
            StructInitializerNode expression => BindStructInitializer(expression),
            UnaryExpressionNode expression => BindUnaryExpression(expression),
            _ => throw new ArgumentOutOfRangeException(nameof(node))
        };
    }

    /// <summary>Binds an address-of expression; its type is a pointer to the operand's type.</summary>
    private static BoundAddressOfNode BindAddressOf(AddressOfNode expression)
    {
        var inner = BindExpression(expression.Expression);
        return new BoundAddressOfNode(expression.Tokens, new NubPointerType(inner.Type), inner);
    }

    /// <summary>
    /// Binds an anonymous function in its own scope: the body sees the variables of the enclosing
    /// function plus the anonymous function's parameters, and its <c>return</c> statements are
    /// bound against the anonymous function's return type.
    /// </summary>
    private static BoundAnonymousFuncNode BindAnonymousFunc(AnonymousFuncNode expression)
    {
        var parameters = new List<BoundFuncParameter>();
        var parameterTypes = new List<NubType>();

        var outerVariables = _variables;
        var outerReturnType = _funcReturnType;

        // Work on a copy so that parameters and locals of the anonymous function do not leak into
        // the enclosing function once the body has been bound.
        _variables = new Dictionary<string, NubType>(outerVariables);
        _funcReturnType = expression.ReturnType;

        try
        {
            foreach (var parameter in expression.Parameters)
            {
                var boundParameter = new BoundFuncParameter(parameter.Name, parameter.Type);
                parameters.Add(boundParameter);
                parameterTypes.Add(boundParameter.Type);
                _variables[parameter.Name] = parameter.Type;
            }

            var body = BindBlock(expression.Body);

            return new BoundAnonymousFuncNode(expression.Tokens, new NubFuncType(expression.ReturnType, parameterTypes), parameters, body, expression.ReturnType);
        }
        finally
        {
            _variables = outerVariables;
            _funcReturnType = outerReturnType;
        }
    }

    /// <summary>Binds an index access; the index is bound with <c>U64</c> as the expected type.</summary>
    /// <exception cref="NotImplementedException">The indexed expression is not of an array type.</exception>
    private static BoundArrayIndexAccessNode BindArrayIndexAccess(ArrayIndexAccessNode expression)
    {
        var boundArray = BindExpression(expression.Array);
        if (boundArray.Type is not NubArrayType arrayType)
        {
            throw new NotImplementedException("Diagnostics not implemented");
        }

        return new BoundArrayIndexAccessNode(expression.Tokens, arrayType.ElementType, boundArray, BindExpression(expression.Index, NubPrimitiveType.U64));
    }

    /// <summary>Binds an array initializer; the capacity is bound with <c>U64</c> as the expected type.</summary>
    private static BoundArrayInitializerNode BindArrayInitializer(ArrayInitializerNode expression)
    {
        return new BoundArrayInitializerNode(expression.Tokens, new NubArrayType(expression.ElementType), BindExpression(expression.Capacity, NubPrimitiveType.U64), expression.ElementType);
    }

    /// <summary>
    /// Binds a binary expression. The right operand is bound with the left operand's type as the
    /// expected type, and the left operand's type is used as the type of the whole expression.
    /// </summary>
    private static BoundBinaryExpressionNode BindBinaryExpression(BinaryExpressionNode expression)
    {
        var boundLeft = BindExpression(expression.Left);
        var boundRight = BindExpression(expression.Right, boundLeft.Type);
        return new BoundBinaryExpressionNode(expression.Tokens, boundLeft.Type, boundLeft, expression.Operator, boundRight);
    }

    /// <summary>Binds a dereference; its type is the base type of the operand's pointer type.</summary>
    /// <exception cref="NotImplementedException">The operand is not of a pointer type.</exception>
    private static BoundDereferenceNode BindDereference(DereferenceNode expression)
    {
        var boundExpression = BindExpression(expression.Expression);
        if (boundExpression.Type is not NubPointerType pointerType)
        {
            throw new NotImplementedException("Diagnostics not implemented");
        }

        return new BoundDereferenceNode(expression.Tokens, pointerType.BaseType, boundExpression);
    }

    /// <summary>
    /// Binds a call. Each argument is bound with the type of the parameter at the same position
    /// as the expected type.
    /// </summary>
    /// <exception cref="NotImplementedException">
    /// The callee is not of a function type, or more arguments are passed than the function type declares.
    /// </exception>
    private static BoundFuncCallNode BindFuncCall(FuncCallNode expression)
    {
        var boundExpression = BindExpression(expression.Expression);
        if (boundExpression.Type is not NubFuncType funcType)
        {
            throw new NotImplementedException("Diagnostics not implemented");
        }

        var parameters = new List<BoundExpressionNode>();

        foreach (var (i, parameter) in expression.Parameters.Index())
        {
            if (i >= funcType.Parameters.Count)
            {
                throw new NotImplementedException("Diagnostics not implemented");
            }

            var expectedType = funcType.Parameters[i];
            parameters.Add(BindExpression(parameter, expectedType));
        }

        return new BoundFuncCallNode(expression.Tokens, funcType.ReturnType, boundExpression, parameters);
    }

    /// <summary>
    /// Resolves an identifier. A function found in the definition table (in the identifier's
    /// namespace, or the current tree's namespace when none is given) takes precedence; otherwise
    /// an identifier without a namespace is looked up among the variables in scope.
    /// </summary>
    /// <exception cref="NotImplementedException">The identifier resolves to nothing.</exception>
    private static BoundIdentifierNode BindIdentifier(IdentifierNode expression)
    {
        NubType? type = null;

        var definition = _definitionTable.LookupFunction(expression.Namespace.Or(_syntaxTree.Namespace), expression.Name);
        if (definition.HasValue)
        {
            type = new NubFuncType(definition.Value.ReturnType, definition.Value.Parameters.Select(p => p.Type).ToList());
        }

        // TryGetValue rather than the indexer: an unknown name must reach the check below instead
        // of escaping as a KeyNotFoundException.
        if (type == null && !expression.Namespace.HasValue && _variables.TryGetValue(expression.Name, out var variableType))
        {
            type = variableType;
        }

        if (type == null)
        {
            throw new NotImplementedException("Diagnostics not implemented");
        }

        return new BoundIdentifierNode(expression.Tokens, type, expression.Namespace, expression.Name);
    }

    /// <summary>
    /// Binds a literal. When an expected type is supplied it is used as-is; otherwise the type is
    /// derived from the literal kind (integer: I64, float: F64, string, bool).
    /// </summary>
    private static BoundLiteralNode BindLiteral(LiteralNode expression, NubType? expectedType = null)
    {
        var type = expectedType ?? expression.Kind switch
        {
            LiteralKind.Integer => NubPrimitiveType.I64,
            LiteralKind.Float => NubPrimitiveType.F64,
            LiteralKind.String => new NubStringType(),
            LiteralKind.Bool => NubPrimitiveType.Bool,
            _ => throw new ArgumentOutOfRangeException()
        };

        return new BoundLiteralNode(expression.Tokens, type, expression.Literal, expression.Kind);
    }

    /// <summary>
    /// Binds a member access. Arrays, strings and C strings expose <c>count</c> (U64); structs
    /// expose the fields recorded in the definition table.
    /// </summary>
    /// <exception cref="NotImplementedException">The member cannot be resolved on the target's type.</exception>
    private static BoundMemberAccessNode BindMemberAccess(MemberAccessNode expression)
    {
        var boundExpression = BindExpression(expression.Expression);

        NubType? type = null;

        // TODO: Resolve members provided through interface implementations. The previous attempt
        // (looking up the target type in _definitionTable.GetImplementations()) was left as an
        // unfinished statement that did not compile, so it has been removed until it is designed.

        switch (boundExpression.Type)
        {
            case NubArrayType:
            case NubStringType:
            case NubCStringType:
            {
                if (expression.Member == "count")
                {
                    type = NubPrimitiveType.U64;
                }

                break;
            }
            case NubStructType structType:
            {
                var defOpt = _definitionTable.LookupStruct(structType.Namespace, structType.Name);
                if (!defOpt.TryGetValue(out var definition))
                {
                    throw new NotImplementedException("Diagnostics not implemented");
                }

                var field = definition.Fields.FirstOrDefault(f => f.Name == expression.Member);
                if (field == null)
                {
                    throw new NotImplementedException("Diagnostics not implemented");
                }

                type = field.Type;
                break;
            }
        }

        if (type == null)
        {
            throw new NotImplementedException("Diagnostics not implemented");
        }

        return new BoundMemberAccessNode(expression.Tokens, type, boundExpression, expression.Member);
    }

    /// <summary>
    /// Binds a struct initializer; each member initializer is bound with the type of the matching
    /// field from the definition table as the expected type.
    /// </summary>
    /// <exception cref="NotImplementedException">The struct or an initialized member is unknown.</exception>
    private static BoundStructInitializerNode BindStructInitializer(StructInitializerNode expression)
    {
        var defOpt = _definitionTable.LookupStruct(expression.StructType.Namespace, expression.StructType.Name);
        if (!defOpt.TryGetValue(out var definition))
        {
            throw new NotImplementedException("Diagnostics not implemented");
        }

        var initializers = new Dictionary<string, BoundExpressionNode>();

        foreach (var (member, initializer) in expression.Initializers)
        {
            var definitionField = definition.Fields.FirstOrDefault(x => x.Name == member);
            if (definitionField == null)
            {
                throw new NotImplementedException("Diagnostics not implemented");
            }

            initializers[member] = BindExpression(initializer, definitionField.Type);
        }

        return new BoundStructInitializerNode(expression.Tokens, expression.StructType, expression.StructType, initializers);
    }

    /// <summary>
    /// Binds a unary expression. Negation binds the operand with I64 as the expected type and
    /// requires a numeric operand; inversion binds the operand with Bool as the expected type and
    /// always yields Bool.
    /// </summary>
    /// <exception cref="NotImplementedException">
    /// The operator is not handled, or the operand of a negation is not numeric.
    /// </exception>
    private static BoundUnaryExpressionNode BindUnaryExpression(UnaryExpressionNode expression)
    {
        // The operand is bound exactly once, inside the operator's case, because the expected
        // type depends on the operator.
        BoundExpressionNode boundOperand;
        NubType? type = null;

        switch (expression.Operator)
        {
            case UnaryExpressionOperator.Negate:
            {
                boundOperand = BindExpression(expression.Operand, NubPrimitiveType.I64);

                if (boundOperand.Type.IsNumber)
                {
                    type = boundOperand.Type;
                }

                break;
            }
            case UnaryExpressionOperator.Invert:
            {
                boundOperand = BindExpression(expression.Operand, NubPrimitiveType.Bool);

                type = new NubPrimitiveType(PrimitiveTypeKind.Bool);
                break;
            }
            default:
            {
                throw new NotImplementedException("Diagnostics not implemented");
            }
        }

        if (type == null)
        {
            throw new NotImplementedException("Diagnostics not implemented");
        }

        return new BoundUnaryExpressionNode(expression.Tokens, type, expression.Operator, boundOperand);
    }
}