using System.Diagnostics;
using Common;
using Syntax.Diagnostics;
using Syntax.Node;
using Syntax.Tokenization;
using UnaryExpressionNode = Syntax.Node.UnaryExpressionNode;

namespace Syntax.Binding;

/// <summary>
/// Binds a <see cref="SyntaxTree"/> into a <see cref="BoundSyntaxTree"/>. Binding resolves identifiers,
/// member accesses and custom types against a <see cref="DefinitionTable"/> and attaches a type to every expression.
/// </summary>
/// <remarks>
/// Binding errors are raised as <see cref="BindException"/>. <see cref="Bind"/> catches them per top-level node
/// and collects their diagnostics, so one faulty definition does not prevent the others from being bound.
/// </remarks>
public sealed class Binder
{
    private readonly SyntaxTree _syntaxTree;
    private readonly DefinitionTable _definitionTable;

    // Variables visible in the function currently being bound, keyed by name.
    // TODO: Block-level scoping is not implemented. A variable declared inside a nested block stays visible
    // until the enclosing function (or anonymous function) has finished binding.
    private Dictionary<string, NubType> _variables = new();

    // Declared return type of the function currently being bound; used as the expected type of `return` values.
    private NubType? _functionReturnType;

    public Binder(SyntaxTree syntaxTree, DefinitionTable definitionTable)
    {
        _syntaxTree = syntaxTree;
        _definitionTable = definitionTable;
    }

    /// <summary>
    /// Binds every top-level node of the syntax tree.
    /// </summary>
    /// <returns>
    /// The bound tree. Nodes that failed to bind are omitted and their diagnostics are included instead.
    /// </returns>
    public BoundSyntaxTree Bind()
    {
        _variables = [];
        _functionReturnType = null;

        var diagnostics = new List<Diagnostic>();
        var topLevelNodes = new List<BoundTopLevelNode>();

        foreach (var topLevel in _syntaxTree.TopLevelNodes)
        {
            try
            {
                topLevelNodes.Add(BindTopLevel(topLevel));
            }
            catch (BindException e)
            {
                diagnostics.Add(e.Diagnostic);
            }
        }

        return new BoundSyntaxTree(_syntaxTree.Namespace, topLevelNodes, diagnostics);
    }

    private BoundTopLevelNode BindTopLevel(TopLevelNode node)
    {
        // Start every top-level definition from a clean slate. Variables or a return type left behind by a
        // previous definition (possibly one that failed halfway) must not be visible here.
        _variables.Clear();
        _functionReturnType = null;

        return node switch
        {
            ExternFuncNode topLevel => BindExternFuncDefinition(topLevel),
            TraitImplNode topLevel => BindTraitImplementation(topLevel),
            TraitNode topLevel => BindTraitDefinition(topLevel),
            LocalFuncNode topLevel => BindLocalFuncDefinition(topLevel),
            StructNode topLevel => BindStruct(topLevel),
            _ => throw new ArgumentOutOfRangeException(nameof(node))
        };
    }

    private BoundTraitImplNode BindTraitImplementation(TraitImplNode node)
    {
        var functions = new List<BoundTraitFuncImplNode>();

        foreach (var function in node.Functions)
        {
            // Each implemented function gets its own variable set and return type, just like a local func.
            _variables.Clear();
            var returnType = BindType(function.ReturnType);
            _functionReturnType = returnType;

            foreach (var parameter in function.Parameters)
            {
                _variables[parameter.Name] = parameter.Type;
            }

            functions.Add(new BoundTraitFuncImplNode(function.Tokens, function.Name, function.Parameters, returnType, BindBlock(function.Body)));
        }

        return new BoundTraitImplNode(node.Tokens, node.Namespace, node.TraitType, node.ForType, functions);
    }

    private BoundTraitNode BindTraitDefinition(TraitNode node)
    {
        var functions = new List<BoundTraitFuncNode>();

        foreach (var func in node.Functions)
        {
            // Use the function's own tokens so diagnostics point at the function, not the whole trait.
            functions.Add(new BoundTraitFuncNode(func.Tokens, func.Name, func.Parameters, BindType(func.ReturnType)));
        }

        return new BoundTraitNode(node.Tokens, node.Namespace, node.Name, functions);
    }

    private BoundStructNode BindStruct(StructNode node)
    {
        var structFields = new List<BoundStructFieldNode>();

        foreach (var field in node.Fields)
        {
            var value = Optional.Empty<BoundExpressionNode>();

            if (field.Value.HasValue)
            {
                // The default value is bound with the declared field type as its expected type.
                value = BindExpression(field.Value.Value, field.Type);
            }

            structFields.Add(new BoundStructFieldNode(field.Tokens, field.Index, field.Name, field.Type, value));
        }

        return new BoundStructNode(node.Tokens, node.Namespace, node.Name, structFields);
    }

    private BoundExternFuncNode BindExternFuncDefinition(ExternFuncNode node)
    {
        return new BoundExternFuncNode(node.Tokens, node.Namespace, node.Name, node.CallName, node.Parameters, BindType(node.ReturnType));
    }

    private BoundLocalFuncNode BindLocalFuncDefinition(LocalFuncNode node)
    {
        _variables.Clear();
        var returnType = BindType(node.ReturnType);
        _functionReturnType = returnType;

        foreach (var parameter in node.Parameters)
        {
            _variables[parameter.Name] = parameter.Type;
        }

        var body = BindBlock(node.Body);

        return new BoundLocalFuncNode(node.Tokens, node.Namespace, node.Name, node.Parameters, body, returnType, node.Exported);
    }

    private BoundBlock BindBlock(BlockNode node)
    {
        var statements = new List<BoundStatementNode>();

        foreach (var statement in node.Statements)
        {
            statements.Add(BindStatement(statement));
        }

        return new BoundBlock(node.Tokens, statements);
    }

    private BoundStatementNode BindStatement(StatementNode node)
    {
        return node switch
        {
            AssignmentNode statement => BindAssignment(statement),
            BreakNode statement => BindBreak(statement),
            ContinueNode statement => BindContinue(statement),
            IfNode statement => BindIf(statement),
            ReturnNode statement => BindReturn(statement),
            StatementExpressionNode statement => BindStatementExpression(statement),
            VariableDeclarationNode statement => BindVariableDeclaration(statement),
            WhileNode statement => BindWhile(statement),
            _ => throw new ArgumentOutOfRangeException(nameof(node))
        };
    }

    private BoundStatementNode BindAssignment(AssignmentNode statement)
    {
        var expression = BindExpression(statement.Target);
        // The assigned value is bound with the target's type as its expected type.
        var value = BindExpression(statement.Value, expression.Type);
        return new BoundAssignmentNode(statement.Tokens, expression, value);
    }

    private BoundBreakNode BindBreak(BreakNode statement)
    {
        return new BoundBreakNode(statement.Tokens);
    }

    private BoundContinueNode BindContinue(ContinueNode statement)
    {
        return new BoundContinueNode(statement.Tokens);
    }

    private BoundIfNode BindIf(IfNode statement)
    {
        var elseStatement = Optional.Empty<Variant<BoundIfNode, BoundBlock>>();

        if (statement.Else.HasValue)
        {
            elseStatement = statement.Else.Value.Match<Variant<BoundIfNode, BoundBlock>>
            (
                elseIf => BindIf(elseIf),
                @else => BindBlock(@else)
            );
        }

        return new BoundIfNode(statement.Tokens, BindExpression(statement.Condition, NubPrimitiveType.Bool), BindBlock(statement.Body), elseStatement);
    }

    private BoundReturnNode BindReturn(ReturnNode statement)
    {
        var value = Optional.Empty<BoundExpressionNode>();

        if (statement.Value.HasValue)
        {
            value = BindExpression(statement.Value.Value, _functionReturnType);
        }

        return new BoundReturnNode(statement.Tokens, value);
    }

    private BoundStatementExpressionNode BindStatementExpression(StatementExpressionNode statement)
    {
        return new BoundStatementExpressionNode(statement.Tokens, BindExpression(statement.Expression));
    }

    private BoundVariableDeclarationNode BindVariableDeclaration(VariableDeclarationNode statement)
    {
        NubType? type = null;

        if (statement.ExplicitType.HasValue)
        {
            type = statement.ExplicitType.Value;
        }

        var assignment = Optional.Empty<BoundExpressionNode>();
        if (statement.Assignment.HasValue)
        {
            var boundValue = BindExpression(statement.Assignment.Value, type);
            assignment = boundValue;
            // The bound initializer's type becomes the variable's type, even when an explicit type was given.
            type = boundValue.Type;
        }

        if (type == null)
        {
            throw new BindException(Diagnostic.Error($"Unable to infer the type of variable {statement.Name}").WithHelp("Give the variable an explicit type or an initial value").Build());
        }

        _variables[statement.Name] = type;

        return new BoundVariableDeclarationNode(statement.Tokens, statement.Name, statement.ExplicitType, assignment, type);
    }

    private BoundWhileNode BindWhile(WhileNode statement)
    {
        return new BoundWhileNode(statement.Tokens, BindExpression(statement.Condition, NubPrimitiveType.Bool), BindBlock(statement.Body));
    }

    /// <summary>
    /// Binds an expression.
    /// </summary>
    /// <param name="node">The expression to bind.</param>
    /// <param name="expectedType">
    /// Type hint from the context. It is only forwarded to literals, which adopt it as their type.
    /// </param>
    private BoundExpressionNode BindExpression(ExpressionNode node, NubType? expectedType = null)
    {
        return node switch
        {
            AddressOfNode expression => BindAddressOf(expression),
            AnonymousFuncNode expression => BindAnonymousFunc(expression),
            ArrayIndexAccessNode expression => BindArrayIndexAccess(expression),
            ArrayInitializerNode expression => BindArrayInitializer(expression),
            BinaryExpressionNode expression => BindBinaryExpression(expression),
            DereferenceNode expression => BindDereference(expression),
            FuncCallNode expression => BindFuncCall(expression),
            IdentifierNode expression => BindIdentifier(expression),
            LiteralNode expression => BindLiteral(expression, expectedType),
            MemberAccessNode expression => BindMemberAccess(expression),
            StructInitializerNode expression => BindStructInitializer(expression),
            UnaryExpressionNode expression => BindUnaryExpression(expression),
            _ => throw new ArgumentOutOfRangeException(nameof(node))
        };
    }

    private BoundAddressOfNode BindAddressOf(AddressOfNode expression)
    {
        var inner = BindExpression(expression.Expression);
        return new BoundAddressOfNode(expression.Tokens, new NubPointerType(inner.Type), inner);
    }

    /// <summary>
    /// Binds an anonymous function. The body sees the enclosing variables plus its own parameters.
    /// Its `return` statements are bound against its own return type.
    /// Variables declared in the body do not leak into the enclosing function.
    /// </summary>
    private BoundAnonymousFuncNode BindAnonymousFunc(AnonymousFuncNode expression)
    {
        var returnType = BindType(expression.ReturnType);
        var parameterTypes = expression.Parameters.Select(x => x.Type).ToList();

        var outerVariables = _variables;
        var outerReturnType = _functionReturnType;

        // Child scope: a copy of the enclosing variables, so identifier lookups that worked before still work.
        _variables = new Dictionary<string, NubType>(outerVariables);
        foreach (var parameter in expression.Parameters)
        {
            _variables[parameter.Name] = parameter.Type;
        }

        _functionReturnType = returnType;

        BoundBlock body;
        try
        {
            body = BindBlock(expression.Body);
        }
        finally
        {
            // Restore the enclosing function's state even if binding the body fails.
            _variables = outerVariables;
            _functionReturnType = outerReturnType;
        }

        return new BoundAnonymousFuncNode(expression.Tokens, new NubFuncType(returnType, parameterTypes), expression.Parameters, body, returnType);
    }

    private BoundArrayIndexAccessNode BindArrayIndexAccess(ArrayIndexAccessNode expression)
    {
        var boundArray = BindExpression(expression.Array);

        if (boundArray.Type is not NubArrayType arrayType)
        {
            throw new BindException(Diagnostic.Error($"Cannot index into non-array type {boundArray.Type}").Build());
        }

        return new BoundArrayIndexAccessNode(expression.Tokens, arrayType.ElementType, boundArray, BindExpression(expression.Index, NubPrimitiveType.U64));
    }

    private BoundArrayInitializerNode BindArrayInitializer(ArrayInitializerNode expression)
    {
        return new BoundArrayInitializerNode(expression.Tokens, new NubArrayType(expression.ElementType), BindExpression(expression.Capacity, NubPrimitiveType.U64), expression.ElementType);
    }

    private BoundBinaryExpressionNode BindBinaryExpression(BinaryExpressionNode expression)
    {
        var boundLeft = BindExpression(expression.Left);
        // The right operand is bound with the left operand's type as its expected type.
        var boundRight = BindExpression(expression.Right, boundLeft.Type);
        return new BoundBinaryExpressionNode(expression.Tokens, boundLeft.Type, boundLeft, expression.Operator, boundRight);
    }

    private BoundDereferenceNode BindDereference(DereferenceNode expression)
    {
        var boundExpression = BindExpression(expression.Expression);

        if (boundExpression.Type is not NubPointerType pointerType)
        {
            throw new BindException(Diagnostic.Error($"Cannot dereference non-pointer type {boundExpression.Type}").Build());
        }

        return new BoundDereferenceNode(expression.Tokens, pointerType.BaseType, boundExpression);
    }

    private BoundFuncCallNode BindFuncCall(FuncCallNode expression)
    {
        var boundExpression = BindExpression(expression.Expression);

        if (boundExpression.Type is not NubFuncType funcType)
        {
            throw new BindException(Diagnostic.Error($"Cannot call non-function type {boundExpression.Type}").Build());
        }

        var returnType = BindType(funcType.ReturnType);

        var parameters = new List<BoundExpressionNode>();

        foreach (var (i, parameter) in expression.Parameters.Index())
        {
            if (i >= funcType.Parameters.Count)
            {
                throw new BindException(Diagnostic.Error($"Too many arguments in function call, expected {funcType.Parameters.Count}").Build());
            }

            // Each argument is bound with the corresponding parameter type as its expected type.
            parameters.Add(BindExpression(parameter, funcType.Parameters[i]));
        }

        if (parameters.Count < funcType.Parameters.Count)
        {
            throw new BindException(Diagnostic.Error($"Too few arguments in function call, expected {funcType.Parameters.Count} but got {parameters.Count}").Build());
        }

        return new BoundFuncCallNode(expression.Tokens, returnType, boundExpression, parameters);
    }

    /// <summary>
    /// Resolves an identifier. Lookup order: local funcs, then extern funcs (both in the identifier's namespace,
    /// or the current one if none is given), then, for unqualified names only, variables.
    /// </summary>
    private BoundExpressionNode BindIdentifier(IdentifierNode expression)
    {
        var @namespace = expression.Namespace.Or(_syntaxTree.Namespace);

        var localFuncs = _definitionTable.LookupLocalFunc(@namespace, expression.Name).ToArray();
        if (localFuncs.Length > 0)
        {
            if (localFuncs.Length > 1)
            {
                throw new BindException(Diagnostic.Error($"Local func {@namespace}::{expression.Name} has multiple definitions").Build());
            }

            var localFunc = localFuncs[0];

            var type = new NubFuncType(BindType(localFunc.ReturnType), localFunc.Parameters.Select(p => p.Type).ToList());

            return new BoundLocalFuncIdentNode(expression.Tokens, type, @namespace, expression.Name);
        }

        var externFuncs = _definitionTable.LookupExternFunc(@namespace, expression.Name).ToArray();
        if (externFuncs.Length > 0)
        {
            if (externFuncs.Length > 1)
            {
                throw new BindException(Diagnostic.Error($"Extern func {@namespace}::{expression.Name} has multiple definitions").Build());
            }

            var externFunc = externFuncs[0];

            var type = new NubFuncType(BindType(externFunc.ReturnType), externFunc.Parameters.Select(p => p.Type).ToList());

            return new BoundExternFuncIdentNode(expression.Tokens, type, @namespace, expression.Name);
        }

        // Variables can only be referenced without a namespace qualifier.
        if (!expression.Namespace.HasValue && _variables.TryGetValue(expression.Name, out var variableType))
        {
            return new BoundVariableIdentNode(expression.Tokens, variableType, expression.Name);
        }

        throw new BindException(Diagnostic.Error($"No identifier with the name {(expression.Namespace.HasValue ? $"{expression.Namespace.Value}::" : "")}{expression.Name} exists").Build());
    }

    private BoundLiteralNode BindLiteral(LiteralNode expression, NubType? expectedType = null)
    {
        // An expected type from the context wins. Otherwise the literal kind picks a default type.
        var type = expectedType ?? expression.Kind switch
        {
            LiteralKind.Integer => NubPrimitiveType.I64,
            LiteralKind.Float => NubPrimitiveType.F64,
            LiteralKind.String => new NubStringType(),
            LiteralKind.Bool => NubPrimitiveType.Bool,
            _ => throw new ArgumentOutOfRangeException(nameof(expression))
        };

        return new BoundLiteralNode(expression.Tokens, type, expression.Literal, expression.Kind);
    }

    /// <summary>
    /// Resolves <c>target.member</c>. Candidates are tried in this order:
    /// a function from a trait implemented by the target type,
    /// a function declared on a trait type, and finally a struct field.
    /// </summary>
    private BoundExpressionNode BindMemberAccess(MemberAccessNode expression)
    {
        var boundExpression = BindExpression(expression.Target);

        var traitFuncImpls = _definitionTable.LookupTraitImpl(boundExpression.Type).SelectMany(x => _definitionTable.LookupTraitFuncImpl(x, expression.Member)).ToArray();
        if (traitFuncImpls.Length > 0)
        {
            if (traitFuncImpls.Length > 1)
            {
                throw new BindException(Diagnostic.Error($"Type {boundExpression.Type} implements multiple traits with the function {expression.Member}").Build());
            }

            var impl = traitFuncImpls[0];

            var type = new NubFuncType(BindType(impl.ReturnType), impl.Parameters.Select(p => BindType(p.Type)).ToList());
            return new BoundTraitImplFuncAccessNode(expression.Tokens, type, boundExpression, expression.Member);
        }

        if (boundExpression.Type is NubCustomType customType)
        {
            var traits = _definitionTable.LookupTrait(customType.Namespace, customType.Name).ToArray();
            if (traits.Length > 0)
            {
                if (traits.Length > 1)
                {
                    throw new BindException(Diagnostic.Error($"Trait {customType.Namespace}::{customType.Name} has multiple definitions").Build());
                }

                var trait = traits[0];

                var traitFuncs = _definitionTable.LookupTraitFunc(trait, expression.Member).ToArray();
                if (traitFuncs.Length > 0)
                {
                    if (traitFuncs.Length > 1)
                    {
                        throw new BindException(Diagnostic.Error($"Trait {trait.Namespace}::{trait.Name} has multiple functions with the name {expression.Member}").Build());
                    }

                    var traitFunc = traitFuncs[0];

                    var type = new NubFuncType(BindType(traitFunc.ReturnType), traitFunc.Parameters.Select(p => BindType(p.Type)).ToList());
                    return new BoundTraitFuncAccessNode(expression.Tokens, type, new NubTraitType(trait.Namespace, trait.Name), boundExpression, expression.Member);
                }
            }

            var structs = _definitionTable.LookupStruct(customType.Namespace, customType.Name).ToArray();
            if (structs.Length > 0)
            {
                if (structs.Length > 1)
                {
                    throw new BindException(Diagnostic.Error($"Struct {customType.Namespace}::{customType.Name} has multiple definitions").Build());
                }

                var @struct = structs[0];

                var fields = _definitionTable.LookupStructField(@struct, expression.Member).ToArray();
                if (fields.Length > 0)
                {
                    if (fields.Length > 1)
                    {
                        throw new BindException(Diagnostic.Error($"Struct {@struct.Namespace}::{@struct.Name} has multiple fields with the name {expression.Member}").Build());
                    }

                    var field = fields[0];

                    return new BoundStructFieldAccessNode(expression.Tokens, field.Type, new NubStructType(@struct.Namespace, @struct.Name), boundExpression, expression.Member);
                }
            }
        }

        throw new BindException(Diagnostic.Error($"{boundExpression.Type} has no member with the name {expression.Member}").Build());
    }

    private BoundStructInitializerNode BindStructInitializer(StructInitializerNode expression)
    {
        if (expression.StructType is not NubCustomType structType)
        {
            throw new BindException(Diagnostic.Error($"Cannot initialize non-struct type {expression.StructType}").Build());
        }

        var structs = _definitionTable.LookupStruct(structType.Namespace, structType.Name).ToArray();

        if (structs.Length == 0)
        {
            throw new BindException(Diagnostic.Error($"Struct {structType.Namespace}::{structType.Name} is not defined").Build());
        }

        if (structs.Length > 1)
        {
            throw new BindException(Diagnostic.Error($"Struct {structType.Namespace}::{structType.Name} has multiple definitions").Build());
        }

        var @struct = structs[0];

        var initializers = new Dictionary<string, BoundExpressionNode>();

        foreach (var (field, initializer) in expression.Initializers)
        {
            var fields = _definitionTable.LookupStructField(@struct, field).ToArray();

            if (fields.Length == 0)
            {
                throw new BindException(Diagnostic.Error($"Struct {@struct.Namespace}::{@struct.Name} does not have a field with the name {field}").Build());
            }

            if (fields.Length > 1)
            {
                throw new BindException(Diagnostic.Error($"Struct {@struct.Namespace}::{@struct.Name} has multiple fields with the name {field}").Build());
            }

            initializers[field] = BindExpression(initializer, fields[0].Type);
        }

        return new BoundStructInitializerNode(expression.Tokens, structType, structType, initializers);
    }

    private BoundUnaryExpressionNode BindUnaryExpression(UnaryExpressionNode expression)
    {
        BoundExpressionNode boundOperand;
        NubType type;

        switch (expression.Operator)
        {
            case UnaryExpressionOperator.Negate:
            {
                // Literal operands default to i64 when negated.
                boundOperand = BindExpression(expression.Operand, NubPrimitiveType.I64);

                if (!boundOperand.Type.IsNumber)
                {
                    throw new BindException(Diagnostic.Error($"Cannot negate non-numeric type {boundOperand.Type}").Build());
                }

                type = boundOperand.Type;
                break;
            }
            case UnaryExpressionOperator.Invert:
            {
                boundOperand = BindExpression(expression.Operand, NubPrimitiveType.Bool);
                type = new NubPrimitiveType(PrimitiveTypeKind.Bool);
                break;
            }
            default:
            {
                throw new ArgumentOutOfRangeException(nameof(expression));
            }
        }

        return new BoundUnaryExpressionNode(expression.Tokens, type, expression.Operator, boundOperand);
    }

    /// <summary>
    /// Resolves a <see cref="NubCustomType"/> to a <see cref="NubStructType"/> or <see cref="NubTraitType"/>.
    /// All other types are returned unchanged.
    /// </summary>
    /// <exception cref="BindException">
    /// The name is undefined, defined more than once, or defined as both a struct and a trait.
    /// </exception>
    private NubType BindType(NubType type)
    {
        if (type is NubCustomType customType)
        {
            var structs = _definitionTable.LookupStruct(customType.Namespace, customType.Name).ToArray();
            if (structs.Length > 1)
            {
                throw new BindException(Diagnostic.Error($"Struct {type} has multiple definitions").Build());
            }

            var traits = _definitionTable.LookupTrait(customType.Namespace, customType.Name).ToArray();
            if (traits.Length > 1)
            {
                throw new BindException(Diagnostic.Error($"Trait {type} has multiple definitions").Build());
            }

            if (structs.Length == 0 && traits.Length == 0)
            {
                throw new BindException(Diagnostic.Error($"Failed to resolve type {type} to a struct or trait").Build());
            }

            if (structs.Length > 0 && traits.Length > 0)
            {
                throw new BindException(Diagnostic.Error($"Unable to determine if type {type} is a struct or trait").WithHelp($"Make sure {type} is not defined as both a struct and a trait").Build());
            }

            if (structs.Length == 1)
            {
                return new NubStructType(customType.Namespace, customType.Name);
            }

            if (traits.Length == 1)
            {
                return new NubTraitType(customType.Namespace, customType.Name);
            }

            throw new UnreachableException();
        }

        return type;
    }
}

/// <summary>
/// Raised by <see cref="Binder"/> when a node cannot be bound. It carries the diagnostic to report.
/// </summary>
public class BindException : Exception
{
    public Diagnostic Diagnostic { get; }

    public BindException(Diagnostic diagnostic) : base(diagnostic.Message)
    {
        Diagnostic = diagnostic;
    }
}