using Common;
using Syntax.Diagnostics;
using Syntax.Node;
using Syntax.Tokenization;
using UnaryExpressionNode = Syntax.Node.UnaryExpressionNode;

namespace Syntax.Binding;

/// <summary>
/// Turns a parsed <see cref="SyntaxTree"/> into a <see cref="BoundSyntaxTree"/> by attaching a resolved
/// <see cref="NubType"/> to every expression, using the <see cref="DefinitionTable"/> for functions, traits
/// and structs, and a name-to-type map for the local variables currently in scope.
/// </summary>
/// <remarks>
/// Binding state lives in mutable static fields that <see cref="Bind"/> resets on entry, so concurrent calls
/// would interfere with each other. Binding errors are currently reported by throwing
/// <see cref="NotImplementedException"/> rather than as diagnostics.
/// </remarks>
public static class Binder
{
    private static SyntaxTree _syntaxTree = null!;
    private static DefinitionTable _definitionTable = null!;

    // Local variables and parameters visible at the current binding position, keyed by name.
    private static Dictionary _variables = new();

    // Declared return type of the function whose body is being bound; used to type `return` values.
    private static NubType? _funcReturnType;

    /// <summary>Binds every top-level node of <paramref name="syntaxTree"/>.</summary>
    /// <param name="syntaxTree">The parsed tree to bind.</param>
    /// <param name="definitionTable">Lookup table for functions, traits, trait implementations and structs.</param>
    /// <param name="diagnostics">Always empty at present; errors are thrown instead.</param>
    /// <returns>The bound tree, in the same namespace as <paramref name="syntaxTree"/>.</returns>
    public static BoundSyntaxTree Bind(SyntaxTree syntaxTree, DefinitionTable definitionTable, out IEnumerable diagnostics)
    {
        _syntaxTree = syntaxTree;
        _definitionTable = definitionTable;
        _variables = [];
        _funcReturnType = null;

        var definitions = new List();

        foreach (var topLevel in syntaxTree.TopLevelNodes)
        {
            definitions.Add(BindTopLevel(topLevel));
        }

        diagnostics = [];
        return new BoundSyntaxTree(syntaxTree.Namespace, definitions);
    }

    /// <summary>Dispatches a top-level node to its binder.</summary>
    /// <exception cref="ArgumentOutOfRangeException">The node kind is not handled.</exception>
    private static BoundTopLevelNode BindTopLevel(TopLevelNode node)
    {
        return node switch
        {
            ExternFuncNode definition => BindExternFuncDefinition(definition),
            TraitImplNode definition => BindTraitImplementation(definition),
            TraitNode definition => BindTraitDefinition(definition),
            LocalFuncNode definition => BindLocalFuncDefinition(definition),
            StructNode definition => BindStruct(definition),
            _ => throw new ArgumentOutOfRangeException(nameof(node))
        };
    }

    /// <summary>Binds the body of every function in a trait implementation.</summary>
    private static BoundTraitImplNode BindTraitImplementation(TraitImplNode node)
    {
        var functions = new List();

        foreach (var function in node.Functions)
        {
            // Reset per function so locals do not leak between sibling functions and `return`
            // is typed against this function's own return type.
            _variables.Clear();
            _funcReturnType = function.ReturnType;

            foreach (var parameter in function.Parameters)
            {
                _variables[parameter.Name] = parameter.Type;
            }

            functions.Add(new BoundTraitFuncImplNode(function.Tokens, function.Name, function.Parameters, function.ReturnType, BindBlock(function.Body)));
        }

        return new BoundTraitImplNode(node.Tokens, node.Namespace, node.TraitType, node.ForType, functions);
    }

    /// <summary>Converts a trait declaration; trait functions have signatures only, so there is no body to bind.</summary>
    private static BoundTraitNode BindTraitDefinition(TraitNode node)
    {
        var functions = new List();

        foreach (var func in node.Functions)
        {
            // Each bound function carries its own tokens, matching how trait implementations are bound.
            functions.Add(new BoundTraitFuncNode(func.Tokens, func.Name, func.Parameters, func.ReturnType));
        }

        return new BoundTraitNode(node.Tokens, node.Namespace, node.Name, functions);
    }

    /// <summary>
    /// Binds a struct declaration. Field default values are bound against the field type recorded in the
    /// definition table.
    /// </summary>
    /// <exception cref="NotImplementedException">The struct or one of its defaulted fields is missing from the definition table.</exception>
    private static BoundStructNode BindStruct(StructNode node)
    {
        var defOpt = _definitionTable.LookupStruct(node.Namespace, node.Name);
        if (!defOpt.TryGetValue(out var definition))
        {
            throw new NotImplementedException("Diagnostics not implemented");
        }

        var structFields = new List();

        foreach (var structField in node.Fields)
        {
            var value = Optional.Empty();

            if (structField.Value.HasValue)
            {
                var definitionField = definition.Fields.FirstOrDefault(f => f.Name == structField.Name);
                if (definitionField == null)
                {
                    throw new NotImplementedException("Diagnostics not implemented");
                }

                value = BindExpression(structField.Value.Value, definitionField.Type);
            }

            structFields.Add(new BoundStructFieldNode(structField.Tokens, structField.Name, structField.Type, value));
        }

        return new BoundStructNode(node.Tokens, node.Namespace, node.Name, structFields);
    }

    /// <summary>Converts an extern function declaration; it has no body to bind.</summary>
    private static BoundExternFuncNode BindExternFuncDefinition(ExternFuncNode node)
    {
        return new BoundExternFuncNode(node.Tokens, node.Namespace, node.Name, node.CallName, node.Parameters, node.ReturnType);
    }

    /// <summary>Binds a local function body in a fresh scope containing only its parameters.</summary>
    private static BoundLocalFuncNode BindLocalFuncDefinition(LocalFuncNode node)
    {
        _variables.Clear();
        _funcReturnType = node.ReturnType;

        foreach (var parameter in node.Parameters)
        {
            _variables[parameter.Name] = parameter.Type;
        }

        var body = BindBlock(node.Body);

        return new BoundLocalFuncNode(node.Tokens, node.Namespace, node.Name, node.Parameters, body, node.ReturnType, node.Exported);
    }

    /// <summary>Binds each statement of a block in order.</summary>
    private static BoundBlock BindBlock(BlockNode node)
    {
        var statements = new List();

        foreach (var statement in node.Statements)
        {
            statements.Add(BindStatement(statement));
        }

        return new BoundBlock(node.Tokens, statements);
    }

    /// <summary>Dispatches a statement to its binder.</summary>
    /// <exception cref="ArgumentOutOfRangeException">The statement kind is not handled.</exception>
    private static BoundStatementNode BindStatement(StatementNode node)
    {
        return node switch
        {
            AssignmentNode statement => BindAssignment(statement),
            BreakNode statement => BindBreak(statement),
            ContinueNode statement => BindContinue(statement),
            IfNode statement => BindIf(statement),
            ReturnNode statement => BindReturn(statement),
            StatementExpressionNode statement => BindStatementExpression(statement),
            VariableDeclarationNode statement => BindVariableDeclaration(statement),
            WhileNode statement => BindWhile(statement),
            _ => throw new ArgumentOutOfRangeException(nameof(node))
        };
    }

    /// <summary>Binds an assignment; the assigned value is bound with the target's type as its expected type.</summary>
    private static BoundStatementNode BindAssignment(AssignmentNode statement)
    {
        var expression = BindExpression(statement.Expression);
        var value = BindExpression(statement.Value, expression.Type);
        return new BoundAssignmentNode(statement.Tokens, expression, value);
    }

    private static BoundBreakNode BindBreak(BreakNode statement)
    {
        return new BoundBreakNode(statement.Tokens);
    }

    private static BoundContinueNode BindContinue(ContinueNode statement)
    {
        return new BoundContinueNode(statement.Tokens);
    }

    /// <summary>Binds an if statement; the condition is bound expecting bool, and the else branch is either another if or a block.</summary>
    private static BoundIfNode BindIf(IfNode statement)
    {
        var elseStatement = Optional.Empty>();

        if (statement.Else.HasValue)
        {
            elseStatement = statement.Else.Value.Match>
            (
                elseIf => BindIf(elseIf),
                @else => BindBlock(@else)
            );
        }

        return new BoundIfNode(statement.Tokens, BindExpression(statement.Condition, NubPrimitiveType.Bool), BindBlock(statement.Body), elseStatement);
    }

    /// <summary>Binds a return statement; a returned value is bound against the current function's return type.</summary>
    private static BoundReturnNode BindReturn(ReturnNode statement)
    {
        var value = Optional.Empty();

        if (statement.Value.HasValue)
        {
            value = BindExpression(statement.Value.Value, _funcReturnType);
        }

        return new BoundReturnNode(statement.Tokens, value);
    }

    private static BoundStatementExpressionNode BindStatementExpression(StatementExpressionNode statement)
    {
        return new BoundStatementExpressionNode(statement.Tokens, BindExpression(statement.Expression));
    }

    /// <summary>
    /// Binds a variable declaration and records the variable in the current scope. When an initializer is
    /// present, the variable's type is the initializer's bound type (with the explicit type as its expected type).
    /// </summary>
    /// <exception cref="NotImplementedException">Neither an explicit type nor an initializer is present.</exception>
    private static BoundVariableDeclarationNode BindVariableDeclaration(VariableDeclarationNode statement)
    {
        NubType? type = null;

        if (statement.ExplicitType.HasValue)
        {
            type = statement.ExplicitType.Value;
        }

        var assignment = Optional.Empty();
        if (statement.Assignment.HasValue)
        {
            var boundValue = BindExpression(statement.Assignment.Value, type);
            assignment = boundValue;
            type = boundValue.Type;
        }

        if (type == null)
        {
            throw new NotImplementedException("Diagnostics not implemented");
        }

        _variables[statement.Name] = type;

        return new BoundVariableDeclarationNode(statement.Tokens, statement.Name, statement.ExplicitType, assignment, type);
    }

    /// <summary>Binds a while loop; the condition is bound expecting bool.</summary>
    private static BoundWhileNode BindWhile(WhileNode statement)
    {
        return new BoundWhileNode(statement.Tokens, BindExpression(statement.Condition, NubPrimitiveType.Bool), BindBlock(statement.Body));
    }

    /// <summary>Dispatches an expression to its binder.</summary>
    /// <param name="node">The expression to bind.</param>
    /// <param name="expectedType">
    /// Type the surrounding context expects, if known. Only literals and unary expressions use it, to pick
    /// the literal's type.
    /// </param>
    /// <exception cref="ArgumentOutOfRangeException">The expression kind is not handled.</exception>
    private static BoundExpressionNode BindExpression(ExpressionNode node, NubType? expectedType = null)
    {
        return node switch
        {
            AddressOfNode expression => BindAddressOf(expression),
            AnonymousFuncNode expression => BindAnonymousFunc(expression),
            ArrayIndexAccessNode expression => BindArrayIndexAccess(expression),
            ArrayInitializerNode expression => BindArrayInitializer(expression),
            BinaryExpressionNode expression => BindBinaryExpression(expression),
            DereferenceNode expression => BindDereference(expression),
            FuncCallNode expression => BindFuncCall(expression),
            IdentifierNode expression => BindIdentifier(expression),
            LiteralNode expression => BindLiteral(expression, expectedType),
            MemberAccessNode expression => BindMemberAccess(expression),
            StructInitializerNode expression => BindStructInitializer(expression),
            UnaryExpressionNode expression => BindUnaryExpression(expression, expectedType),
            _ => throw new ArgumentOutOfRangeException(nameof(node))
        };
    }

    /// <summary>Binds `&amp;expr`; the result is a pointer to the operand's type.</summary>
    private static BoundAddressOfNode BindAddressOf(AddressOfNode expression)
    {
        var inner = BindExpression(expression.Expression);
        return new BoundAddressOfNode(expression.Tokens, new NubPointerType(inner.Type), inner);
    }

    /// <summary>
    /// Binds an anonymous function. Its body is bound in its own scope: its parameters are visible,
    /// declarations inside the body do not leak into the enclosing function, and `return` is typed
    /// against the anonymous function's return type. Enclosing state is restored afterwards, even on error.
    /// </summary>
    private static BoundAnonymousFuncNode BindAnonymousFunc(AnonymousFuncNode expression)
    {
        var parameterTypes = expression.Parameters.Select(x => x.Type).ToList();

        var outerVariables = _variables;
        var outerReturnType = _funcReturnType;

        // Start from a copy so the enclosing locals stay resolvable, as they were before this scope existed.
        // NOTE(review): confirm the code generator supports capturing outer locals; if it does not, start
        // from an empty scope here instead.
        _variables = new(outerVariables);
        _funcReturnType = expression.ReturnType;

        try
        {
            foreach (var parameter in expression.Parameters)
            {
                _variables[parameter.Name] = parameter.Type;
            }

            var body = BindBlock(expression.Body);

            return new BoundAnonymousFuncNode(expression.Tokens, new NubFuncType(expression.ReturnType, parameterTypes), expression.Parameters, body, expression.ReturnType);
        }
        finally
        {
            _variables = outerVariables;
            _funcReturnType = outerReturnType;
        }
    }

    /// <summary>Binds `array[index]`; the index is bound expecting u64 and the result has the array's element type.</summary>
    /// <exception cref="NotImplementedException">The indexed expression is not an array.</exception>
    private static BoundArrayIndexAccessNode BindArrayIndexAccess(ArrayIndexAccessNode expression)
    {
        var boundArray = BindExpression(expression.Array);
        if (boundArray.Type is not NubArrayType arrayType)
        {
            throw new NotImplementedException("Diagnostics not implemented");
        }

        return new BoundArrayIndexAccessNode(expression.Tokens, arrayType.ElementType, boundArray, BindExpression(expression.Index, NubPrimitiveType.U64));
    }

    /// <summary>Binds an array allocation; the capacity is bound expecting u64.</summary>
    private static BoundArrayInitializerNode BindArrayInitializer(ArrayInitializerNode expression)
    {
        return new BoundArrayInitializerNode(expression.Tokens, new NubArrayType(expression.ElementType), BindExpression(expression.Capacity, NubPrimitiveType.U64), expression.ElementType);
    }

    /// <summary>
    /// Binds a binary expression. The right operand is bound with the left operand's type as its expected type,
    /// and the result is typed as the left operand's type.
    /// </summary>
    /// <remarks>
    /// NOTE(review): the left operand's type is used for every operator; confirm this is intended for
    /// comparison operators, whose result is presumably bool.
    /// </remarks>
    private static BoundBinaryExpressionNode BindBinaryExpression(BinaryExpressionNode expression)
    {
        var boundLeft = BindExpression(expression.Left);
        var boundRight = BindExpression(expression.Right, boundLeft.Type);
        return new BoundBinaryExpressionNode(expression.Tokens, boundLeft.Type, boundLeft, expression.Operator, boundRight);
    }

    /// <summary>Binds `*expr`; the result is the pointer's base type.</summary>
    /// <exception cref="NotImplementedException">The operand is not a pointer.</exception>
    private static BoundDereferenceNode BindDereference(DereferenceNode expression)
    {
        var boundExpression = BindExpression(expression.Expression);
        if (boundExpression.Type is not NubPointerType pointerType)
        {
            throw new NotImplementedException("Diagnostics not implemented");
        }

        return new BoundDereferenceNode(expression.Tokens, pointerType.BaseType, boundExpression);
    }

    /// <summary>
    /// Binds a call. Each argument is bound with the corresponding parameter type as its expected type,
    /// and the result has the callee's return type. Calls with fewer arguments than parameters are not
    /// rejected here.
    /// </summary>
    /// <exception cref="NotImplementedException">The callee is not a function, or too many arguments are passed.</exception>
    private static BoundFuncCallNode BindFuncCall(FuncCallNode expression)
    {
        var boundExpression = BindExpression(expression.Expression);
        if (boundExpression.Type is not NubFuncType funcType)
        {
            throw new NotImplementedException("Diagnostics not implemented");
        }

        var parameters = new List();

        foreach (var (i, parameter) in expression.Parameters.Index())
        {
            if (i >= funcType.Parameters.Count)
            {
                throw new NotImplementedException("Diagnostics not implemented");
            }

            parameters.Add(BindExpression(parameter, funcType.Parameters[i]));
        }

        return new BoundFuncCallNode(expression.Tokens, funcType.ReturnType, boundExpression, parameters);
    }

    /// <summary>
    /// Resolves an identifier. A function in the identifier's namespace (or the tree's namespace when none is
    /// given) takes precedence; otherwise an unqualified name is looked up among the variables in scope.
    /// </summary>
    /// <exception cref="NotImplementedException">The name resolves to neither a function nor a variable.</exception>
    private static BoundIdentifierNode BindIdentifier(IdentifierNode expression)
    {
        NubType? type = null;

        var definition = _definitionTable.LookupFunction(expression.Namespace.Or(_syntaxTree.Namespace), expression.Name);
        if (definition.HasValue)
        {
            type = new NubFuncType(definition.Value.ReturnType, definition.Value.Parameters.Select(p => p.Type).ToList());
        }
        else if (!expression.Namespace.HasValue && _variables.TryGetValue(expression.Name, out var variableType))
        {
            // TryGetValue (not the indexer) so an unknown name reaches the diagnostics path below
            // instead of escaping as KeyNotFoundException.
            type = variableType;
        }

        if (type == null)
        {
            throw new NotImplementedException("Diagnostics not implemented");
        }

        return new BoundIdentifierNode(expression.Tokens, type, expression.Namespace, expression.Name);
    }

    /// <summary>
    /// Binds a literal. The expected type wins when given; otherwise the type defaults by kind:
    /// integer → i64, float → f64, string → string, bool → bool.
    /// </summary>
    private static BoundLiteralNode BindLiteral(LiteralNode expression, NubType? expectedType = null)
    {
        var type = expectedType ?? expression.Kind switch
        {
            LiteralKind.Integer => NubPrimitiveType.I64,
            LiteralKind.Float => NubPrimitiveType.F64,
            LiteralKind.String => new NubStringType(),
            LiteralKind.Bool => NubPrimitiveType.Bool,
            _ => throw new ArgumentOutOfRangeException()
        };

        return new BoundLiteralNode(expression.Tokens, type, expression.Literal, expression.Kind);
    }

    /// <summary>
    /// Binds `expr.member`, trying in order:
    /// <list type="number">
    /// <item>a trait implementation function for the operand's type;</item>
    /// <item>for custom types, a function of a trait with that name;</item>
    /// <item>for custom types, a struct field with that name;</item>
    /// <item><c>count</c> on strings, C strings and arrays (typed i64).</item>
    /// </list>
    /// </summary>
    /// <exception cref="NotImplementedException">No candidate matches, or a struct declares the field more than once.</exception>
    private static BoundMemberAccessNode BindMemberAccess(MemberAccessNode expression)
    {
        var boundExpression = BindExpression(expression.Expression);

        var implementation = _definitionTable.LookupTraitImplementationForType(boundExpression.Type, expression.Member);
        if (implementation.HasValue)
        {
            var type = new NubFuncType(implementation.Value.Item2.ReturnType, implementation.Value.Item2.Parameters.Select(p => p.Type).ToList());
            return new BoundMemberAccessNode(expression.Tokens, type, boundExpression, expression.Member);
        }

        if (boundExpression.Type is NubCustomType customType)
        {
            var function = _definitionTable.LookupFunctionOnTrait(customType.Namespace, customType.Name, expression.Member);
            if (function.HasValue)
            {
                var type = new NubFuncType(function.Value.ReturnType, function.Value.Parameters.Select(p => p.Type).ToList());
                return new BoundMemberAccessNode(expression.Tokens, type, boundExpression, expression.Member);
            }

            var structDef = _definitionTable.LookupStruct(customType.Namespace, customType.Name);
            if (structDef.HasValue)
            {
                var matchingFields = structDef.Value.Fields.Where(f => f.Name == expression.Member).ToList();

                if (matchingFields.Count > 1)
                {
                    throw new NotImplementedException("Diagnostics not implemented");
                }

                if (matchingFields.Count == 1)
                {
                    return new BoundMemberAccessNode(expression.Tokens, matchingFields[0].Type, boundExpression, expression.Member);
                }
            }
        }

        if (boundExpression.Type is NubStringType or NubCStringType or NubArrayType && expression.Member == "count")
        {
            return new BoundMemberAccessNode(expression.Tokens, NubPrimitiveType.I64, boundExpression, expression.Member);
        }

        throw new NotImplementedException("Diagnostics not implemented");
    }

    /// <summary>Binds a struct literal; each initializer is bound against the declared type of its field.</summary>
    /// <exception cref="NotImplementedException">The type is not a known struct, or an initializer names an unknown field.</exception>
    private static BoundStructInitializerNode BindStructInitializer(StructInitializerNode expression)
    {
        if (expression.StructType is not NubCustomType structType)
        {
            throw new NotImplementedException("Diagnostics not implemented");
        }

        var defOpt = _definitionTable.LookupStruct(structType.Namespace, structType.Name);
        if (!defOpt.TryGetValue(out var definition))
        {
            throw new NotImplementedException("Diagnostics not implemented");
        }

        var initializers = new Dictionary();

        foreach (var (member, initializer) in expression.Initializers)
        {
            var definitionField = definition.Fields.FirstOrDefault(x => x.Name == member);
            if (definitionField == null)
            {
                throw new NotImplementedException("Diagnostics not implemented");
            }

            initializers[member] = BindExpression(initializer, definitionField.Type);
        }

        return new BoundStructInitializerNode(expression.Tokens, structType, structType, initializers);
    }

    /// <summary>
    /// Binds a unary expression.
    /// <list type="bullet">
    /// <item>Negate requires a numeric operand and has the operand's type.</item>
    /// <item>Invert binds its operand expecting bool and is typed bool.</item>
    /// </list>
    /// </summary>
    /// <param name="expression">The unary expression.</param>
    /// <param name="expectedType">Type the context expects; forwarded to a negated operand when numeric.</param>
    /// <exception cref="NotImplementedException">The operand of a negation is not numeric, or the operator is unknown.</exception>
    private static BoundUnaryExpressionNode BindUnaryExpression(UnaryExpressionNode expression, NubType? expectedType = null)
    {
        BoundExpressionNode boundOperand;
        NubType? type = null;

        switch (expression.Operator)
        {
            case UnaryExpressionOperator.Negate:
            {
                // Forward a numeric expected type so `-1` in an i32 context is typed i32. Without one, the
                // literal keeps its own default, so `-1.5` stays f64 instead of being forced to i64.
                var operandExpectedType = expectedType is { IsNumber: true } ? expectedType : null;
                boundOperand = BindExpression(expression.Operand, operandExpectedType);
                if (boundOperand.Type.IsNumber)
                {
                    type = boundOperand.Type;
                }

                break;
            }
            case UnaryExpressionOperator.Invert:
            {
                boundOperand = BindExpression(expression.Operand, NubPrimitiveType.Bool);
                type = NubPrimitiveType.Bool;
                break;
            }
            default:
            {
                throw new NotImplementedException("Diagnostics not implemented");
            }
        }

        if (type == null)
        {
            throw new NotImplementedException("Diagnostics not implemented");
        }

        return new BoundUnaryExpressionNode(expression.Tokens, type, expression.Operator, boundOperand);
    }
}