using System.Diagnostics;
using Common;
using NubLang.Diagnostics;
using NubLang.Syntax.Node;
using NubLang.Syntax.Tokenization;
using Node_UnaryExpressionNode = NubLang.Syntax.Node.UnaryExpressionNode;

namespace NubLang.Syntax.Binding;

/// <summary>
/// Turns a <see cref="SyntaxTree"/> into a <see cref="BoundSyntaxTree"/> by resolving every
/// syntactic type to its bound counterpart and attaching a type to every expression.
/// Definitions are resolved through the supplied <see cref="DefinitionTable"/>.
/// </summary>
/// <remarks>
/// Binding errors are raised as <see cref="BindException"/> and collected per top-level node by
/// <see cref="Bind"/>; a failing top-level node is dropped from the result.
/// </remarks>
public sealed class Binder
{
    private readonly SyntaxTree _syntaxTree;
    private readonly DefinitionTable _definitionTable;

    // TODO: Implement proper variable tracking and scoping
    // Flat name -> type table for the function currently being bound. Blocks do not open a
    // nested scope; only function boundaries reset (or, for anonymous functions, fork) it.
    private Dictionary<string, BoundNubType> _variables = new();

    // Return type of the function whose body is currently being bound; used as the expected
    // type when binding `return` values.
    private BoundNubType? _functionReturnType;

    public Binder(SyntaxTree syntaxTree, DefinitionTable definitionTable)
    {
        _syntaxTree = syntaxTree;
        _definitionTable = definitionTable;
    }

    /// <summary>
    /// Binds every top-level node of the syntax tree.
    /// </summary>
    /// <returns>
    /// A bound tree holding the successfully bound top-level nodes plus one diagnostic for each
    /// top-level node that raised a <see cref="BindException"/>.
    /// </returns>
    public BoundSyntaxTree Bind()
    {
        _variables = [];
        _functionReturnType = null;

        var diagnostics = new List<Diagnostic>();
        var topLevelNodes = new List<BoundTopLevelNode>();

        foreach (var topLevel in _syntaxTree.TopLevelNodes)
        {
            try
            {
                topLevelNodes.Add(BindTopLevel(topLevel));
            }
            catch (BindException e)
            {
                // Only the first error of a top-level node is reported; binding continues
                // with the next top-level node.
                diagnostics.Add(e.Diagnostic);
            }
        }

        return new BoundSyntaxTree(_syntaxTree.Namespace, topLevelNodes, diagnostics);
    }

    /// <summary>Dispatches a top-level node to its dedicated binder.</summary>
    private BoundTopLevelNode BindTopLevel(TopLevelNode node)
    {
        return node switch
        {
            ExternFuncNode topLevel => BindExternFuncDefinition(topLevel),
            TraitImplNode topLevel => BindTraitImplementation(topLevel),
            TraitNode topLevel => BindTraitDefinition(topLevel),
            LocalFuncNode topLevel => BindLocalFuncDefinition(topLevel),
            StructNode topLevel => BindStruct(topLevel),
            _ => throw new ArgumentOutOfRangeException(nameof(node))
        };
    }

    /// <summary>
    /// Binds a trait implementation. Each implemented function is bound with its own variable
    /// table and its own return type.
    /// </summary>
    private BoundTraitImplNode BindTraitImplementation(TraitImplNode node)
    {
        var functions = new List<BoundTraitFuncImplNode>();

        foreach (var function in node.Functions)
        {
            // Reset per function: parameters of one implementation must not be visible in the
            // next, and `return` must be checked against this function's return type rather
            // than whatever function happened to be bound before.
            _variables.Clear();
            var returnType = BindType(function.ReturnType);
            _functionReturnType = returnType;

            var parameters = new List<BoundFuncParameterNode>();
            foreach (var parameter in function.Parameters)
            {
                var parameterType = BindType(parameter.Type);
                _variables[parameter.Name] = parameterType;
                parameters.Add(new BoundFuncParameterNode(parameter.Tokens, parameter.Name, parameterType));
            }

            functions.Add(new BoundTraitFuncImplNode(function.Tokens, function.Name, parameters, returnType, BindBlock(function.Body)));
        }

        return new BoundTraitImplNode(node.Tokens, node.Namespace, BindType(node.TraitType), BindType(node.ForType), functions);
    }

    /// <summary>Binds a trait declaration (function signatures only, no bodies).</summary>
    private BoundTraitNode BindTraitDefinition(TraitNode node)
    {
        var functions = new List<BoundTraitFuncNode>();

        foreach (var function in node.Functions)
        {
            var parameters = new List<BoundFuncParameterNode>();
            foreach (var parameter in function.Parameters)
            {
                parameters.Add(new BoundFuncParameterNode(parameter.Tokens, parameter.Name, BindType(parameter.Type)));
            }

            // NOTE(review): every function receives the tokens of the whole trait (node.Tokens).
            // function.Tokens looks like the intended value - confirm before changing.
            functions.Add(new BoundTraitFuncNode(node.Tokens, function.Name, parameters, BindType(function.ReturnType)));
        }

        return new BoundTraitNode(node.Tokens, node.Namespace, node.Name, functions);
    }

    /// <summary>
    /// Binds a struct declaration. A field's default value, when present, is bound with the
    /// field's type as the expected type.
    /// </summary>
    private BoundStructNode BindStruct(StructNode node)
    {
        var structFields = new List<BoundStructFieldNode>();

        foreach (var field in node.Fields)
        {
            var fieldType = BindType(field.Type);

            var value = Optional.Empty<BoundExpressionNode>();
            if (field.Value.HasValue)
            {
                value = BindExpression(field.Value.Value, fieldType);
            }

            structFields.Add(new BoundStructFieldNode(field.Tokens, field.Index, field.Name, fieldType, value));
        }

        return new BoundStructNode(node.Tokens, node.Namespace, node.Name, structFields);
    }

    /// <summary>Binds the signature of an extern function.</summary>
    private BoundExternFuncNode BindExternFuncDefinition(ExternFuncNode node)
    {
        var parameters = new List<BoundFuncParameterNode>();
        foreach (var parameter in node.Parameters)
        {
            parameters.Add(new BoundFuncParameterNode(parameter.Tokens, parameter.Name, BindType(parameter.Type)));
        }

        return new BoundExternFuncNode(node.Tokens, node.Namespace, node.Name, node.CallName, parameters, BindType(node.ReturnType));
    }

    /// <summary>
    /// Binds a local function: resets the variable table, registers the parameters and binds
    /// the body against the declared return type.
    /// </summary>
    private BoundLocalFuncNode BindLocalFuncDefinition(LocalFuncNode node)
    {
        _variables.Clear();

        var returnType = BindType(node.ReturnType);
        _functionReturnType = returnType;

        var parameters = new List<BoundFuncParameterNode>();
        foreach (var parameter in node.Parameters)
        {
            var parameterType = BindType(parameter.Type);
            _variables[parameter.Name] = parameterType;
            parameters.Add(new BoundFuncParameterNode(parameter.Tokens, parameter.Name, parameterType));
        }

        var body = BindBlock(node.Body);

        return new BoundLocalFuncNode(node.Tokens, node.Namespace, node.Name, parameters, body, returnType, node.Exported);
    }

    /// <summary>Binds the statements of a block in source order.</summary>
    private BoundBlock BindBlock(BlockNode node)
    {
        var statements = new List<BoundStatementNode>();
        foreach (var statement in node.Statements)
        {
            statements.Add(BindStatement(statement));
        }

        return new BoundBlock(node.Tokens, statements);
    }

    /// <summary>Dispatches a statement to its dedicated binder.</summary>
    private BoundStatementNode BindStatement(StatementNode node)
    {
        return node switch
        {
            AssignmentNode statement => BindAssignment(statement),
            BreakNode statement => BindBreak(statement),
            ContinueNode statement => BindContinue(statement),
            IfNode statement => BindIf(statement),
            ReturnNode statement => BindReturn(statement),
            StatementExpressionNode statement => BindStatementExpression(statement),
            VariableDeclarationNode statement => BindVariableDeclaration(statement),
            WhileNode statement => BindWhile(statement),
            _ => throw new ArgumentOutOfRangeException(nameof(node))
        };
    }

    /// <summary>Binds an assignment; the value is bound with the target's type as expected type.</summary>
    private BoundStatementNode BindAssignment(AssignmentNode statement)
    {
        var expression = BindExpression(statement.Target);
        var value = BindExpression(statement.Value, expression.Type);
        return new BoundAssignmentNode(statement.Tokens, expression, value);
    }

    private BoundBreakNode BindBreak(BreakNode statement)
    {
        return new BoundBreakNode(statement.Tokens);
    }

    private BoundContinueNode BindContinue(ContinueNode statement)
    {
        return new BoundContinueNode(statement.Tokens);
    }

    /// <summary>
    /// Binds an if statement. The optional else part is either another if (else-if chain) or a
    /// plain block.
    /// </summary>
    private BoundIfNode BindIf(IfNode statement)
    {
        // NOTE(review): Variant<BoundIfNode, BoundBlock> is assumed to be the two-case union
        // type from Common that IfNode.Else matches over - confirm the type name.
        var elseStatement = Optional.Empty<Variant<BoundIfNode, BoundBlock>>();
        if (statement.Else.HasValue)
        {
            elseStatement = statement.Else.Value.Match<Variant<BoundIfNode, BoundBlock>>
            (
                elseIf => BindIf(elseIf),
                @else => BindBlock(@else)
            );
        }

        // The else part is deliberately bound before condition and body: the variable table
        // is flat, so binding order decides which declarations are visible where.
        return new BoundIfNode(statement.Tokens, BindExpression(statement.Condition, BoundNubPrimitiveType.Bool), BindBlock(statement.Body), elseStatement);
    }

    /// <summary>Binds a return statement against the current function's return type.</summary>
    private BoundReturnNode BindReturn(ReturnNode statement)
    {
        var value = Optional.Empty<BoundExpressionNode>();
        if (statement.Value.HasValue)
        {
            value = BindExpression(statement.Value.Value, _functionReturnType);
        }

        return new BoundReturnNode(statement.Tokens, value);
    }

    private BoundStatementExpressionNode BindStatementExpression(StatementExpressionNode statement)
    {
        return new BoundStatementExpressionNode(statement.Tokens, BindExpression(statement.Expression));
    }

    /// <summary>
    /// Binds a variable declaration and records the variable's type. When an initial value is
    /// present its bound type wins over the explicit type; the explicit type is only used as
    /// the expected type while binding the value.
    /// </summary>
    /// <exception cref="BindException">Neither an explicit type nor an initial value is given.</exception>
    private BoundVariableDeclarationNode BindVariableDeclaration(VariableDeclarationNode statement)
    {
        BoundNubType? type = null;

        if (statement.ExplicitType.HasValue)
        {
            type = BindType(statement.ExplicitType.Value);
        }

        var assignment = Optional.Empty<BoundExpressionNode>();
        if (statement.Assignment.HasValue)
        {
            var boundValue = BindExpression(statement.Assignment.Value, type);
            assignment = boundValue;
            type = boundValue.Type;
        }

        if (type is null)
        {
            throw new BindException(Diagnostic
                .Error($"Cannot infer the type of variable {statement.Name}")
                .WithHelp("Add an explicit type or an initial value")
                .Build());
        }

        _variables[statement.Name] = type;

        return new BoundVariableDeclarationNode(statement.Tokens, statement.Name, assignment, type);
    }

    private BoundWhileNode BindWhile(WhileNode statement)
    {
        return new BoundWhileNode(statement.Tokens, BindExpression(statement.Condition, BoundNubPrimitiveType.Bool), BindBlock(statement.Body));
    }

    /// <summary>
    /// Dispatches an expression to its dedicated binder.
    /// </summary>
    /// <param name="node">The expression to bind.</param>
    /// <param name="expectedType">
    /// Type hint from the surrounding context. It is only forwarded to literals; every other
    /// expression kind ignores it.
    /// </param>
    private BoundExpressionNode BindExpression(ExpressionNode node, BoundNubType? expectedType = null)
    {
        return node switch
        {
            AddressOfNode expression => BindAddressOf(expression),
            AnonymousFuncNode expression => BindAnonymousFunc(expression),
            ArrayIndexAccessNode expression => BindArrayIndexAccess(expression),
            ArrayInitializerNode expression => BindArrayInitializer(expression),
            BinaryExpressionNode expression => BindBinaryExpression(expression),
            DereferenceNode expression => BindDereference(expression),
            FuncCallNode expression => BindFuncCall(expression),
            IdentifierNode expression => BindIdentifier(expression),
            LiteralNode expression => BindLiteral(expression, expectedType),
            MemberAccessNode expression => BindMemberAccess(expression),
            StructInitializerNode expression => BindStructInitializer(expression),
            Node_UnaryExpressionNode expression => BindUnaryExpression(expression),
            _ => throw new ArgumentOutOfRangeException(nameof(node))
        };
    }

    /// <summary>Binds an address-of expression; the result is a pointer to the operand's type.</summary>
    private BoundAddressOfNode BindAddressOf(AddressOfNode expression)
    {
        var inner = BindExpression(expression.Expression);
        return new BoundAddressOfNode(expression.Tokens, new BoundNubPointerType(inner.Type), inner);
    }

    /// <summary>
    /// Binds an anonymous function. The body is bound in a forked copy of the current variable
    /// table (so the parameters are visible and inner declarations do not leak out) and against
    /// the anonymous function's own return type. Both are restored afterwards.
    /// </summary>
    private BoundAnonymousFuncNode BindAnonymousFunc(AnonymousFuncNode expression)
    {
        var returnType = BindType(expression.ReturnType);

        var parameters = new List<BoundFuncParameterNode>();
        foreach (var parameter in expression.Parameters)
        {
            parameters.Add(new BoundFuncParameterNode(parameter.Tokens, parameter.Name, BindType(parameter.Type)));
        }

        var outerVariables = _variables;
        var outerReturnType = _functionReturnType;

        // Names of the enclosing function stay resolvable inside the body, exactly as before.
        _variables = new Dictionary<string, BoundNubType>(outerVariables);
        _functionReturnType = returnType;

        BoundBlock body;
        try
        {
            foreach (var parameter in parameters)
            {
                _variables[parameter.Name] = parameter.Type;
            }

            body = BindBlock(expression.Body);
        }
        finally
        {
            _variables = outerVariables;
            _functionReturnType = outerReturnType;
        }

        var funcType = new BoundNubFuncType(returnType, parameters.Select(x => x.Type).ToList());
        return new BoundAnonymousFuncNode(expression.Tokens, funcType, parameters, body, returnType);
    }

    /// <summary>Binds an array index access; the index is bound as U64.</summary>
    /// <exception cref="BindException">The indexed expression is not an array.</exception>
    private BoundArrayIndexAccessNode BindArrayIndexAccess(ArrayIndexAccessNode expression)
    {
        var boundArray = BindExpression(expression.Target);
        if (boundArray.Type is not BoundNubArrayType arrayType)
        {
            throw new BindException(Diagnostic.Error($"Cannot index into non-array type {boundArray.Type}").Build());
        }

        return new BoundArrayIndexAccessNode(expression.Tokens, arrayType.ElementType, boundArray, BindExpression(expression.Index, BoundNubPrimitiveType.U64));
    }

    /// <summary>Binds an array initializer; the capacity is bound as U64.</summary>
    private BoundArrayInitializerNode BindArrayInitializer(ArrayInitializerNode expression)
    {
        var elementType = BindType(expression.ElementType);
        return new BoundArrayInitializerNode(expression.Tokens, new BoundNubArrayType(elementType), BindExpression(expression.Capacity, BoundNubPrimitiveType.U64), elementType);
    }

    /// <summary>
    /// Binds a binary expression. The right operand is bound with the left operand's type as
    /// expected type.
    /// </summary>
    private BoundBinaryExpressionNode BindBinaryExpression(BinaryExpressionNode expression)
    {
        var boundLeft = BindExpression(expression.Left);
        var boundRight = BindExpression(expression.Right, boundLeft.Type);

        // NOTE(review): the result type is always the left operand's type, regardless of the
        // operator. If the operator set contains comparisons they presumably should yield
        // Bool - confirm against the operator enum.
        return new BoundBinaryExpressionNode(expression.Tokens, boundLeft.Type, boundLeft, expression.Operator, boundRight);
    }

    /// <summary>Binds a dereference; the result type is the pointer's base type.</summary>
    /// <exception cref="BindException">The operand is not a pointer.</exception>
    private BoundDereferenceNode BindDereference(DereferenceNode expression)
    {
        var boundExpression = BindExpression(expression.Expression);
        if (boundExpression.Type is not BoundNubPointerType pointerType)
        {
            throw new BindException(Diagnostic.Error($"Cannot dereference non-pointer type {boundExpression.Type}").Build());
        }

        return new BoundDereferenceNode(expression.Tokens, pointerType.BaseType, boundExpression);
    }

    /// <summary>
    /// Binds a call. Each argument is bound with the corresponding parameter type as expected
    /// type. Only surplus arguments are rejected; missing arguments are not checked here.
    /// </summary>
    /// <exception cref="BindException">The callee is not a function, or too many arguments are passed.</exception>
    private BoundFuncCallNode BindFuncCall(FuncCallNode expression)
    {
        var boundExpression = BindExpression(expression.Expression);
        if (boundExpression.Type is not BoundNubFuncType funcType)
        {
            throw new BindException(Diagnostic.Error($"Cannot call non-function type {boundExpression.Type}").Build());
        }

        var parameters = new List<BoundExpressionNode>();

        foreach (var (i, parameter) in expression.Parameters.Index())
        {
            if (i >= funcType.Parameters.Count)
            {
                throw new BindException(Diagnostic.Error($"Too many arguments, function takes {funcType.Parameters.Count}").Build());
            }

            var expectedType = funcType.Parameters[i];
            parameters.Add(BindExpression(parameter, expectedType));
        }

        return new BoundFuncCallNode(expression.Tokens, funcType.ReturnType, boundExpression, parameters);
    }

    /// <summary>
    /// Resolves an identifier. Lookup order: local functions, extern functions (both in the
    /// identifier's namespace, defaulting to the namespace of the syntax tree), then - only for
    /// identifiers without an explicit namespace - variables of the current function.
    /// </summary>
    /// <exception cref="BindException">The name is ambiguous or cannot be resolved.</exception>
    private BoundExpressionNode BindIdentifier(IdentifierNode expression)
    {
        var @namespace = expression.Namespace.Or(_syntaxTree.Namespace);

        var localFuncs = _definitionTable.LookupLocalFunc(@namespace, expression.Name).ToArray();
        if (localFuncs.Length > 0)
        {
            if (localFuncs.Length > 1)
            {
                throw new BindException(Diagnostic.Error($"Local func {@namespace}::{expression.Name} has multiple definitions").Build());
            }

            var localFunc = localFuncs[0];
            var type = new NubFuncType(localFunc.ReturnType, localFunc.Parameters.Select(p => p.Type).ToList());
            return new BoundLocalFuncIdentNode(expression.Tokens, BindType(type), @namespace, expression.Name);
        }

        var externFuncs = _definitionTable.LookupExternFunc(@namespace, expression.Name).ToArray();
        if (externFuncs.Length > 0)
        {
            if (externFuncs.Length > 1)
            {
                throw new BindException(Diagnostic.Error($"Extern func {@namespace}::{expression.Name} has multiple definitions").Build());
            }

            var externFunc = externFuncs[0];
            var type = new NubFuncType(externFunc.ReturnType, externFunc.Parameters.Select(p => p.Type).ToList());
            return new BoundExternFuncIdentNode(expression.Tokens, BindType(type), @namespace, expression.Name);
        }

        // An unknown variable falls through to the diagnostic below instead of surfacing as a
        // KeyNotFoundException that Bind() would not catch.
        if (!expression.Namespace.HasValue && _variables.TryGetValue(expression.Name, out var variableType))
        {
            return new BoundVariableIdentNode(expression.Tokens, variableType, expression.Name);
        }

        throw new BindException(Diagnostic.Error($"No identifier with the name {(expression.Namespace.HasValue ? $"{expression.Namespace.Value}::" : "")}{expression.Name} exists").Build());
    }

    /// <summary>
    /// Binds a literal. The expected type, when given, is used as the literal's type without
    /// further checking; otherwise a default is chosen from the literal kind
    /// (Integer: I64, Float: F64, String: string, Bool: Bool).
    /// </summary>
    private BoundLiteralNode BindLiteral(LiteralNode expression, BoundNubType? expectedType = null)
    {
        var type = expectedType ?? expression.Kind switch
        {
            LiteralKind.Integer => BoundNubPrimitiveType.I64,
            LiteralKind.Float => BoundNubPrimitiveType.F64,
            LiteralKind.String => new BoundNubStringType(),
            LiteralKind.Bool => BoundNubPrimitiveType.Bool,
            _ => throw new ArgumentOutOfRangeException()
        };

        return new BoundLiteralNode(expression.Tokens, type, expression.Literal, expression.Kind);
    }

    /// <summary>
    /// Binds a member access. Resolution order: trait function implementations for the target's
    /// type, then trait functions (target is a trait), then struct fields (target is a struct).
    /// </summary>
    /// <exception cref="BindException">The member is ambiguous or does not exist.</exception>
    private BoundExpressionNode BindMemberAccess(MemberAccessNode expression)
    {
        var boundExpression = BindExpression(expression.Target);

        var traitFuncImpls = _definitionTable.LookupTraitFuncImpl(UnbindType(boundExpression.Type), expression.Member).ToArray();
        if (traitFuncImpls.Length > 0)
        {
            if (traitFuncImpls.Length > 1)
            {
                throw new BindException(Diagnostic.Error($"Type {boundExpression.Type} implements multiple traits with the function {expression.Member}").Build());
            }

            var impl = traitFuncImpls[0];
            var type = new BoundNubFuncType(BindType(impl.ReturnType), impl.Parameters.Select(p => BindType(p.Type)).ToList());
            return new BoundTraitImplFuncAccessNode(expression.Tokens, type, boundExpression, expression.Member);
        }

        if (boundExpression.Type is BoundNubTraitType traitType)
        {
            var traits = _definitionTable.LookupTrait(traitType.Namespace, traitType.Name).ToArray();
            if (traits.Length > 0)
            {
                if (traits.Length > 1)
                {
                    throw new BindException(Diagnostic.Error($"Trait {traitType.Namespace}::{traitType.Name} has multiple definitions").Build());
                }

                var trait = traits[0];

                // Check the looked-up functions, not the traits: a trait without the requested
                // member must fall through to the "does not have a member" diagnostic.
                var traitFuncs = _definitionTable.LookupTraitFunc(trait, expression.Member).ToArray();
                if (traitFuncs.Length > 0)
                {
                    if (traitFuncs.Length > 1)
                    {
                        throw new BindException(Diagnostic.Error($"Trait {trait.Namespace}::{trait.Name} has multiple functions with the name {expression.Member}").Build());
                    }

                    var traitFunc = traitFuncs[0];
                    var type = new BoundNubFuncType(BindType(traitFunc.ReturnType), traitFunc.Parameters.Select(p => BindType(p.Type)).ToList());
                    return new BoundTraitFuncAccessNode(expression.Tokens, type, traitType, boundExpression, expression.Member);
                }
            }
        }

        if (boundExpression.Type is BoundNubStructType structType)
        {
            var structs = _definitionTable.LookupStruct(structType.Namespace, structType.Name).ToArray();
            if (structs.Length > 0)
            {
                if (structs.Length > 1)
                {
                    throw new BindException(Diagnostic.Error($"Struct {structType.Namespace}::{structType.Name} has multiple definitions").Build());
                }

                var @struct = structs[0];

                var fields = _definitionTable.LookupStructField(@struct, expression.Member).ToArray();
                if (fields.Length > 0)
                {
                    if (fields.Length > 1)
                    {
                        throw new BindException(Diagnostic.Error($"Struct {@struct.Namespace}::{@struct.Name} has multiple fields with the name {expression.Member}").Build());
                    }

                    var field = fields[0];
                    return new BoundStructFieldAccessNode(expression.Tokens, BindType(field.Type), structType, boundExpression, expression.Member);
                }
            }
        }

        throw new BindException(Diagnostic.Error($"{boundExpression.Type} does not have a member with the name {expression.Member}").Build());
    }

    /// <summary>
    /// Binds a struct initializer. Each field initializer is bound with the declared field type
    /// as expected type.
    /// </summary>
    /// <exception cref="BindException">
    /// The type is not a custom type, the struct is undefined or ambiguous, or an initialized
    /// field is undefined or ambiguous.
    /// </exception>
    private BoundStructInitializerNode BindStructInitializer(StructInitializerNode expression)
    {
        if (expression.StructType is not NubCustomType structType)
        {
            throw new BindException(Diagnostic.Error($"Cannot initialize non-struct type {expression.StructType}").Build());
        }

        var structs = _definitionTable.LookupStruct(structType.Namespace, structType.Name).ToArray();

        if (structs.Length == 0)
        {
            throw new BindException(Diagnostic.Error($"Struct {structType.Namespace}::{structType.Name} is not defined").Build());
        }

        if (structs.Length > 1)
        {
            throw new BindException(Diagnostic.Error($"Struct {structType.Namespace}::{structType.Name} has multiple definitions").Build());
        }

        var @struct = structs[0];

        var initializers = new Dictionary<string, BoundExpressionNode>();

        foreach (var (field, initializer) in expression.Initializers)
        {
            var fields = _definitionTable.LookupStructField(@struct, field).ToArray();

            if (fields.Length == 0)
            {
                throw new BindException(Diagnostic.Error($"Struct {@struct.Namespace}::{@struct.Name} does not have a field with the name {field}").Build());
            }

            if (fields.Length > 1)
            {
                throw new BindException(Diagnostic.Error($"Struct {@struct.Namespace}::{@struct.Name} has multiple fields with the name {field}").Build());
            }

            initializers[field] = BindExpression(initializer, BindType(fields[0].Type));
        }

        return new BoundStructInitializerNode(expression.Tokens, BindType(structType), new BoundNubStructType(@struct.Namespace, @struct.Name), initializers);
    }

    /// <summary>
    /// Binds a unary expression. Negate binds its operand with I64 as expected type and requires
    /// a numeric operand; Invert binds its operand with Bool as expected type and yields Bool.
    /// </summary>
    /// <exception cref="BindException">Negate is applied to a non-numeric operand.</exception>
    private BoundUnaryExpressionNode BindUnaryExpression(Node_UnaryExpressionNode expression)
    {
        // The operand is bound exactly once, with the operator-specific expected type.
        switch (expression.Operator)
        {
            case UnaryExpressionOperator.Negate:
            {
                var boundOperand = BindExpression(expression.Operand, BoundNubPrimitiveType.I64);
                if (!boundOperand.Type.IsNumber)
                {
                    throw new BindException(Diagnostic.Error($"Cannot negate non-numeric type {boundOperand.Type}").Build());
                }

                return new BoundUnaryExpressionNode(expression.Tokens, boundOperand.Type, expression.Operator, boundOperand);
            }
            case UnaryExpressionOperator.Invert:
            {
                var boundOperand = BindExpression(expression.Operand, BoundNubPrimitiveType.Bool);
                return new BoundUnaryExpressionNode(expression.Tokens, new BoundNubPrimitiveType(PrimitiveTypeKind.Bool), expression.Operator, boundOperand);
            }
            default:
            {
                throw new ArgumentOutOfRangeException(nameof(expression));
            }
        }
    }

    /// <summary>
    /// Converts a syntactic type into its bound form. Custom types are resolved through the
    /// definition table and must name exactly one struct or exactly one trait.
    /// </summary>
    /// <exception cref="BindException">A custom type is undefined or ambiguous.</exception>
    private BoundNubType BindType(NubType type)
    {
        switch (type)
        {
            case NubCustomType customType:
            {
                var structs = _definitionTable.LookupStruct(customType.Namespace, customType.Name).ToArray();
                if (structs.Length > 1)
                {
                    throw new BindException(Diagnostic.Error($"Struct {type} has multiple definitions").Build());
                }

                var traits = _definitionTable.LookupTrait(customType.Namespace, customType.Name).ToArray();
                if (traits.Length > 1)
                {
                    throw new BindException(Diagnostic.Error($"Trait {type} has multiple definitions").Build());
                }

                if (structs.Length == 0 && traits.Length == 0)
                {
                    throw new BindException(Diagnostic.Error($"Failed to resolve type {type} to a struct or trait").Build());
                }

                if (structs.Length > 0 && traits.Length > 0)
                {
                    throw new BindException(Diagnostic
                        .Error($"Unable to determine if type {type} is a struct or trait")
                        .WithHelp($"Make sure {type} is not defined as both a struct and a trait")
                        .Build());
                }

                if (structs.Length == 1)
                {
                    return new BoundNubStructType(customType.Namespace, customType.Name);
                }

                if (traits.Length == 1)
                {
                    return new BoundNubTraitType(customType.Namespace, customType.Name);
                }

                throw new UnreachableException();
            }
            case NubArrayType arrayType:
                return new BoundNubArrayType(BindType(arrayType.ElementType));
            case NubCStringType:
                return new BoundNubCStringType();
            case NubStringType:
                return new BoundNubStringType();
            case NubFuncType funcType:
                return new BoundNubFuncType(BindType(funcType.ReturnType), funcType.Parameters.Select(BindType).ToList());
            case NubPointerType pointerType:
                return new BoundNubPointerType(BindType(pointerType.BaseType));
            case NubPrimitiveType primitiveType:
                return new BoundNubPrimitiveType(primitiveType.Kind);
            case NubVoidType:
                return new BoundNubVoidType();
            default:
                throw new ArgumentOutOfRangeException(nameof(type));
        }
    }

    /// <summary>
    /// Converts a bound type back into its syntactic form, e.g. for definition-table lookups
    /// that are keyed on syntactic types. Struct and trait types both map to a custom type.
    /// </summary>
    private NubType UnbindType(BoundNubType type)
    {
        return type switch
        {
            BoundNubArrayType arrayType => new NubArrayType(UnbindType(arrayType.ElementType)),
            BoundNubCStringType => new NubCStringType(),
            BoundNubStringType => new NubStringType(),
            BoundNubStructType structType => new NubCustomType(structType.Namespace, structType.Name),
            BoundNubTraitType traitType => new NubCustomType(traitType.Namespace, traitType.Name),
            BoundNubFuncType funcType => new NubFuncType(UnbindType(funcType.ReturnType), funcType.Parameters.Select(UnbindType).ToList()),
            BoundNubPointerType pointerType => new NubPointerType(UnbindType(pointerType.BaseType)),
            BoundNubPrimitiveType primitiveType => new NubPrimitiveType(primitiveType.Kind),
            BoundNubVoidType => new NubVoidType(),
            _ => throw new ArgumentOutOfRangeException(nameof(type))
        };
    }
}

/// <summary>
/// Raised when binding fails; carries the <see cref="Diagnostic"/> to report. The exception
/// message is the diagnostic's message.
/// </summary>
public class BindException : Exception
{
    public Diagnostic Diagnostic { get; }

    public BindException(Diagnostic diagnostic) : base(diagnostic.Message)
    {
        Diagnostic = diagnostic;
    }
}