using NubLang.Common;
using NubLang.Diagnostics;
using NubLang.Syntax.Binding.Node;
using NubLang.Syntax.Parsing.Node;
using NubLang.Syntax.Tokenization;

namespace NubLang.Syntax.Binding;

/// <summary>
/// Turns a parsed <see cref="SyntaxTree"/> into a <see cref="BoundSyntaxTree"/>: resolves identifiers,
/// member accesses and struct initializers against a <see cref="DefinitionTable"/> and attaches a type
/// to every bound expression.
/// </summary>
public sealed class Binder
{
    private readonly SyntaxTree _syntaxTree;
    private readonly DefinitionTable _definitionTable;

    // Lexical scopes of the function body currently being bound; innermost scope on top.
    private readonly Stack<Scope> _scopes = new();

    // Declared return types of the enclosing functions / arrow functions; innermost on top.
    private readonly Stack<NubType> _funcReturnTypes = new();

    private Scope Scope => _scopes.Peek();

    public Binder(SyntaxTree syntaxTree, DefinitionTable definitionTable)
    {
        _syntaxTree = syntaxTree;
        _definitionTable = definitionTable;
    }

    /// <summary>
    /// Binds every top-level definition of the syntax tree.
    /// </summary>
    /// <returns>
    /// The bound tree. A definition whose binding raises a <see cref="BindException"/> is left out of
    /// the result and its diagnostic is collected instead.
    /// </returns>
    public BoundSyntaxTree Bind()
    {
        _funcReturnTypes.Clear();
        _scopes.Clear();

        var diagnostics = new List<Diagnostic>();
        var definitions = new List<BoundDefinition>();

        foreach (var definition in _syntaxTree.Definitions)
        {
            try
            {
                definitions.Add(BindDefinition(definition));
            }
            catch (BindException e)
            {
                diagnostics.Add(e.Diagnostic);
            }
        }

        return new BoundSyntaxTree(_syntaxTree.Namespace, definitions, diagnostics);
    }

    private BoundDefinition BindDefinition(DefinitionSyntax node)
    {
        return node switch
        {
            ExternFuncSyntax definition => BindExternFuncDefinition(definition),
            TraitImplSyntax definition => BindTraitImplementation(definition),
            TraitSyntax definition => BindTraitDefinition(definition),
            LocalFuncSyntax definition => BindLocalFuncDefinition(definition),
            StructSyntax definition => BindStruct(definition),
            _ => throw new ArgumentOutOfRangeException(nameof(node))
        };
    }

    private BoundTraitImpl BindTraitImplementation(TraitImplSyntax node)
    {
        var functions = new List<BoundTraitFuncImpl>();

        foreach (var func in node.Functions)
        {
            var signature = BindFuncSignature(func.Signature);
            var body = BindFuncBody(func.Body, signature.ReturnType, signature.Parameters);
            functions.Add(new BoundTraitFuncImpl(func.Tokens, func.Name, signature, body));
        }

        return new BoundTraitImpl(node.Tokens, node.Namespace, node.TraitType, node.ForType, functions);
    }

    private BoundTrait BindTraitDefinition(TraitSyntax node)
    {
        var functions = new List<BoundTraitFunc>();

        foreach (var function in node.Functions)
        {
            // Use the function's own tokens (as BindTraitImplementation does) so diagnostics point at
            // the function rather than the whole trait.
            functions.Add(new BoundTraitFunc(function.Tokens, function.Name, BindFuncSignature(function.Signature)));
        }

        return new BoundTrait(node.Tokens, node.Namespace, node.Name, functions);
    }

    private BoundStruct BindStruct(StructSyntax node)
    {
        var structFields = new List<BoundStructField>();

        foreach (var field in node.Fields)
        {
            // Default values are bound outside any function body, so no local scope is active here.
            var value = Optional.Empty<BoundExpression>();

            if (field.Value.HasValue)
            {
                value = BindExpression(field.Value.Value, field.Type);
            }

            structFields.Add(new BoundStructField(field.Tokens, field.Index, field.Name, field.Type, value));
        }

        return new BoundStruct(node.Tokens, node.Namespace, node.Name, structFields);
    }

    private BoundExternFunc BindExternFuncDefinition(ExternFuncSyntax node)
    {
        return new BoundExternFunc(node.Tokens, node.Namespace, node.Name, node.CallName, BindFuncSignature(node.Signature));
    }

    private BoundLocalFunc BindLocalFuncDefinition(LocalFuncSyntax node)
    {
        var signature = BindFuncSignature(node.Signature);
        var body = BindFuncBody(node.Body, signature.ReturnType, signature.Parameters);

        return new BoundLocalFunc(node.Tokens, node.Namespace, node.Name, signature, body);
    }

    private BoundStatement BindStatement(StatementSyntax node)
    {
        return node switch
        {
            AssignmentSyntax statement => BindAssignment(statement),
            BreakSyntax statement => BindBreak(statement),
            ContinueSyntax statement => BindContinue(statement),
            IfSyntax statement => BindIf(statement),
            ReturnSyntax statement => BindReturn(statement),
            StatementExpressionSyntax statement => BindStatementExpression(statement),
            VariableDeclarationSyntax statement => BindVariableDeclaration(statement),
            WhileSyntax statement => BindWhile(statement),
            _ => throw new ArgumentOutOfRangeException(nameof(node))
        };
    }

    private BoundStatement BindAssignment(AssignmentSyntax statement)
    {
        // The target is bound first so its type can guide the binding of the value (e.g. literals).
        var expression = BindExpression(statement.Target);
        var value = BindExpression(statement.Value, expression.Type);
        return new BoundAssignment(statement.Tokens, expression, value);
    }

    private BoundBreak BindBreak(BreakSyntax statement)
    {
        return new BoundBreak(statement.Tokens);
    }

    private BoundContinue BindContinue(ContinueSyntax statement)
    {
        return new BoundContinue(statement.Tokens);
    }

    private BoundIf BindIf(IfSyntax statement)
    {
        var elseStatement = Optional.Empty<Variant<BoundIf, BoundBlock>>();

        if (statement.Else.HasValue)
        {
            elseStatement = statement.Else.Value.Match<Variant<BoundIf, BoundBlock>>
            (
                elseIf => BindIf(elseIf),
                @else => BindBlock(@else)
            );
        }

        return new BoundIf(statement.Tokens, BindExpression(statement.Condition, new NubPrimitiveType(PrimitiveTypeKind.Bool)), BindBlock(statement.Body), elseStatement);
    }

    private BoundReturn BindReturn(ReturnSyntax statement)
    {
        var value = Optional.Empty<BoundExpression>();

        if (statement.Value.HasValue)
        {
            // The returned value is bound against the innermost enclosing function's return type.
            value = BindExpression(statement.Value.Value, _funcReturnTypes.Peek());
        }

        return new BoundReturn(statement.Tokens, value);
    }

    private BoundStatementExpression BindStatementExpression(StatementExpressionSyntax statement)
    {
        return new BoundStatementExpression(statement.Tokens, BindExpression(statement.Expression));
    }

    /// <summary>
    /// Binds a variable declaration and declares the variable in the current scope. The variable's
    /// type is the type of the bound initializer if present, otherwise the explicit type.
    /// </summary>
    /// <exception cref="BindException">Neither an explicit type nor an initializer is given.</exception>
    private BoundVariableDeclaration BindVariableDeclaration(VariableDeclarationSyntax statement)
    {
        NubType? type = null;

        if (statement.ExplicitType.HasValue)
        {
            type = statement.ExplicitType.Value;
        }

        var assignment = Optional.Empty<BoundExpression>();
        if (statement.Assignment.HasValue)
        {
            var boundValue = BindExpression(statement.Assignment.Value, type);
            assignment = boundValue;
            type = boundValue.Type;
        }

        if (type == null)
        {
            throw new BindException(Diagnostic.Error($"Cannot infer the type of variable {statement.Name}").At(statement).Build());
        }

        // NOTE(review): redeclaring a name in the same scope is not rejected, and Scope.Lookup returns
        // the earliest declaration — confirm this is the intended shadowing rule.
        Scope.Declare(new Variable(statement.Name, type));

        return new BoundVariableDeclaration(statement.Tokens, statement.Name, assignment, type);
    }

    private BoundWhile BindWhile(WhileSyntax statement)
    {
        return new BoundWhile(statement.Tokens, BindExpression(statement.Condition, new NubPrimitiveType(PrimitiveTypeKind.Bool)), BindBlock(statement.Body));
    }

    /// <summary>
    /// Binds an expression. <paramref name="expectedType"/> is only consulted by arrow functions
    /// (required to infer parameter types) and literals (overrides the literal's default type).
    /// </summary>
    private BoundExpression BindExpression(ExpressionSyntax node, NubType? expectedType = null)
    {
        return node switch
        {
            AddressOfSyntax expression => BindAddressOf(expression),
            ArrowFuncSyntax expression => BindArrowFunc(expression, expectedType),
            ArrayIndexAccessSyntax expression => BindArrayIndexAccess(expression),
            ArrayInitializerSyntax expression => BindArrayInitializer(expression),
            BinaryExpressionSyntax expression => BindBinaryExpression(expression),
            DereferenceSyntax expression => BindDereference(expression),
            FuncCallSyntax expression => BindFuncCall(expression),
            IdentifierSyntax expression => BindIdentifier(expression),
            LiteralSyntax expression => BindLiteral(expression, expectedType),
            MemberAccessSyntax expression => BindMemberAccess(expression),
            StructInitializerSyntax expression => BindStructInitializer(expression),
            UnaryExpressionSyntax expression => BindUnaryExpression(expression),
            _ => throw new ArgumentOutOfRangeException(nameof(node))
        };
    }

    private BoundAddressOf BindAddressOf(AddressOfSyntax expression)
    {
        var inner = BindExpression(expression.Expression);
        return new BoundAddressOf(expression.Tokens, new NubPointerType(inner.Type), inner);
    }

    /// <summary>
    /// Binds an arrow function. Its parameter and return types come from <paramref name="expectedType"/>,
    /// which must be a function type; the arrow function may declare fewer parameters than that type.
    /// </summary>
    private BoundArrowFunc BindArrowFunc(ArrowFuncSyntax expression, NubType? expectedType = null)
    {
        if (expectedType == null)
        {
            throw new BindException(Diagnostic.Error("Cannot infer argument types for arrow function").At(expression).Build());
        }

        if (expectedType is not NubFuncType funcType)
        {
            throw new BindException(Diagnostic.Error($"Expected {expectedType}, but got arrow function").At(expression).Build());
        }

        var parameters = new List<BoundFuncParameter>();

        for (var i = 0; i < expression.Parameters.Count; i++)
        {
            if (i >= funcType.Parameters.Count)
            {
                throw new BindException(Diagnostic.Error($"Arrow function expected a maximum of {funcType.Parameters.Count} arguments").At(expression).Build());
            }

            var expectedParameterType = funcType.Parameters[i];
            var parameter = expression.Parameters[i];
            parameters.Add(new BoundFuncParameter(parameter.Tokens, parameter.Name, expectedParameterType));
        }

        var body = BindFuncBody(expression.Body, funcType.ReturnType, parameters);

        return new BoundArrowFunc(expression.Tokens, new NubFuncType(funcType.ReturnType, parameters.Select(x => x.Type).ToList()), parameters, funcType.ReturnType, body);
    }

    private BoundArrayIndexAccess BindArrayIndexAccess(ArrayIndexAccessSyntax expression)
    {
        var boundArray = BindExpression(expression.Target);

        // Report a diagnostic instead of letting a cast exception escape Bind().
        if (boundArray.Type is not NubArrayType arrayType)
        {
            throw new BindException(Diagnostic.Error($"Cannot index into non-array type {boundArray.Type}").At(expression).Build());
        }

        return new BoundArrayIndexAccess(expression.Tokens, arrayType.ElementType, boundArray, BindExpression(expression.Index, new NubPrimitiveType(PrimitiveTypeKind.U64)));
    }

    private BoundArrayInitializer BindArrayInitializer(ArrayInitializerSyntax expression)
    {
        return new BoundArrayInitializer(expression.Tokens, new NubArrayType(expression.ElementType), BindExpression(expression.Capacity, new NubPrimitiveType(PrimitiveTypeKind.U64)), expression.ElementType);
    }

    private BoundBinaryExpression BindBinaryExpression(BinaryExpressionSyntax expression)
    {
        // The left operand's type guides the right operand (e.g. the type of a literal on the right).
        var boundLeft = BindExpression(expression.Left);
        var boundRight = BindExpression(expression.Right, boundLeft.Type);

        // NOTE(review): the result type is the left operand's type for every operator, including the
        // comparison operators — confirm downstream consumers do not expect bool for comparisons.
        return new BoundBinaryExpression(expression.Tokens, boundLeft.Type, boundLeft, BindBinaryOperator(expression.Operator), boundRight);
    }

    private BoundDereference BindDereference(DereferenceSyntax expression)
    {
        var boundExpression = BindExpression(expression.Expression);

        if (boundExpression.Type is not NubPointerType pointerType)
        {
            throw new BindException(Diagnostic.Error($"Cannot dereference non-pointer type {boundExpression.Type}").At(expression).Build());
        }

        return new BoundDereference(expression.Tokens, pointerType.BaseType, boundExpression);
    }

    private BoundFuncCall BindFuncCall(FuncCallSyntax expression)
    {
        var boundExpression = BindExpression(expression.Expression);

        if (boundExpression.Type is not NubFuncType funcType)
        {
            throw new BindException(Diagnostic.Error($"Cannot call non-function type {boundExpression.Type}").At(expression).Build());
        }

        // NOTE(review): passing fewer arguments than the function declares is not diagnosed here.
        var parameters = new List<BoundExpression>();

        foreach (var (i, parameter) in expression.Parameters.Index())
        {
            if (i >= funcType.Parameters.Count)
            {
                throw new BindException(Diagnostic.Error($"Function expects at most {funcType.Parameters.Count} arguments, but got {expression.Parameters.Count}").At(expression).Build());
            }

            var expectedType = funcType.Parameters[i];
            parameters.Add(BindExpression(parameter, expectedType));
        }

        return new BoundFuncCall(expression.Tokens, funcType.ReturnType, boundExpression, parameters);
    }

    /// <summary>
    /// Resolves an identifier, in order, to a local function, an extern function, or (for unqualified
    /// names inside a function body) a variable. Functions are checked first, so a function shadows a
    /// variable with the same name.
    /// </summary>
    private BoundExpression BindIdentifier(IdentifierSyntax expression)
    {
        // Unqualified names resolve in the namespace of the file being bound.
        var @namespace = expression.Namespace.Or(_syntaxTree.Namespace);

        var localFuncs = _definitionTable.LookupLocalFunc(@namespace, expression.Name).ToArray();
        if (localFuncs.Length > 0)
        {
            if (localFuncs.Length > 1)
            {
                throw new BindException(Diagnostic.Error($"Local func {@namespace}::{expression.Name} has multiple definitions").At(expression).Build());
            }

            var localFunc = localFuncs[0];

            var type = new NubFuncType(localFunc.Signature.ReturnType, localFunc.Signature.Parameters.Select(p => p.Type).ToList());
            return new BoundLocalFuncIdent(expression.Tokens, type, @namespace, expression.Name);
        }

        var externFuncs = _definitionTable.LookupExternFunc(@namespace, expression.Name).ToArray();
        if (externFuncs.Length > 0)
        {
            if (externFuncs.Length > 1)
            {
                throw new BindException(Diagnostic.Error($"Extern func {@namespace}::{expression.Name} has multiple definitions").At(expression).Build());
            }

            var externFunc = externFuncs[0];

            var type = new NubFuncType(externFunc.Signature.ReturnType, externFunc.Signature.Parameters.Select(p => p.Type).ToList());
            return new BoundExternFuncIdent(expression.Tokens, type, @namespace, expression.Name);
        }

        // The scope stack is empty outside function bodies (e.g. struct field defaults); in that case
        // there are no variables to look up, and Peek() would throw.
        if (!expression.Namespace.HasValue && _scopes.TryPeek(out var scope))
        {
            var variable = scope.Lookup(expression.Name);
            if (variable != null)
            {
                return new BoundVariableIdent(expression.Tokens, variable.Type, variable.Name);
            }
        }

        throw new BindException(Diagnostic.Error($"No identifier with the name {(expression.Namespace.HasValue ? $"{expression.Namespace.Value}::" : "")}{expression.Name} exists").At(expression).Build());
    }

    /// <summary>
    /// Binds a literal. When <paramref name="expectedType"/> is given it becomes the literal's type;
    /// otherwise the type defaults from the literal kind (i64, f64, string, bool).
    /// </summary>
    private BoundLiteral BindLiteral(LiteralSyntax expression, NubType? expectedType = null)
    {
        // NOTE(review): the expected type is not checked for compatibility with the literal kind.
        var type = expectedType ?? expression.Kind switch
        {
            LiteralKind.Integer => new NubPrimitiveType(PrimitiveTypeKind.I64),
            LiteralKind.Float => new NubPrimitiveType(PrimitiveTypeKind.F64),
            LiteralKind.String => new NubStringType(),
            LiteralKind.Bool => new NubPrimitiveType(PrimitiveTypeKind.Bool),
            _ => throw new ArgumentOutOfRangeException(nameof(expression))
        };

        return new BoundLiteral(expression.Tokens, type, expression.Literal, expression.Kind);
    }

    /// <summary>
    /// Resolves <c>target.member</c> to, in order: a function from a trait implementation for the
    /// target's type, a function declared by the trait the target's type names, or a field of the
    /// struct the target's type names.
    /// </summary>
    private BoundExpression BindMemberAccess(MemberAccessSyntax expression)
    {
        var boundExpression = BindExpression(expression.Target);

        var traitFuncImpls = _definitionTable.LookupTraitFuncImpl(boundExpression.Type, expression.Member).ToArray();
        if (traitFuncImpls.Length > 0)
        {
            if (traitFuncImpls.Length > 1)
            {
                throw new BindException(Diagnostic.Error($"Type {boundExpression.Type} implements multiple traits with the function {expression.Member}").At(expression).Build());
            }

            var impl = traitFuncImpls[0];

            var type = new NubFuncType(impl.Signature.ReturnType, impl.Signature.Parameters.Select(p => p.Type).ToList());
            return new BoundTraitImplFuncAccess(expression.Tokens, type, boundExpression, expression.Member);
        }

        if (boundExpression.Type is NubCustomType customType)
        {
            var traits = _definitionTable.LookupTrait(customType.Namespace, customType.Name).ToArray();
            if (traits.Length > 0)
            {
                if (traits.Length > 1)
                {
                    throw new BindException(Diagnostic.Error($"Trait {customType} has multiple definitions").At(expression).Build());
                }

                var trait = traits[0];

                // Check the function lookup result, not the trait lookup: a trait without this member
                // must fall through to the "no member" diagnostic instead of indexing an empty array.
                var traitFuncs = _definitionTable.LookupTraitFunc(trait, expression.Member).ToArray();
                if (traitFuncs.Length > 0)
                {
                    if (traitFuncs.Length > 1)
                    {
                        throw new BindException(Diagnostic.Error($"Trait {customType} has multiple functions with the name {expression.Member}").At(expression).Build());
                    }

                    var traitFunc = traitFuncs[0];

                    var type = new NubFuncType(traitFunc.Signature.ReturnType, traitFunc.Signature.Parameters.Select(p => p.Type).ToList());
                    return new BoundTraitFuncAccess(expression.Tokens, type, customType, boundExpression, expression.Member);
                }
            }

            var structs = _definitionTable.LookupStruct(customType.Namespace, customType.Name).ToArray();
            if (structs.Length > 0)
            {
                if (structs.Length > 1)
                {
                    throw new BindException(Diagnostic.Error($"Struct {customType} has multiple definitions").At(expression).Build());
                }

                var @struct = structs[0];

                var fields = _definitionTable.LookupStructField(@struct, expression.Member).ToArray();
                if (fields.Length > 0)
                {
                    if (fields.Length > 1)
                    {
                        throw new BindException(Diagnostic.Error($"Struct {customType} has multiple fields with the name {expression.Member}").At(expression).Build());
                    }

                    var field = fields[0];

                    return new BoundStructFieldAccess(expression.Tokens, field.Type, customType, boundExpression, expression.Member);
                }
            }
        }

        throw new BindException(Diagnostic.Error($"{boundExpression.Type} does not have a member with the name {expression.Member}").At(expression).Build());
    }

    /// <summary>
    /// Binds a struct initializer; each initializer value is bound against its field's declared type.
    /// </summary>
    private BoundStructInitializer BindStructInitializer(StructInitializerSyntax expression)
    {
        if (expression.StructType is not NubCustomType structType)
        {
            throw new BindException(Diagnostic.Error($"Cannot initialize non-struct type {expression.StructType}").At(expression).Build());
        }

        var structs = _definitionTable.LookupStruct(structType.Namespace, structType.Name).ToArray();

        if (structs.Length == 0)
        {
            throw new BindException(Diagnostic.Error($"Struct {structType} is not defined").At(expression).Build());
        }

        if (structs.Length > 1)
        {
            throw new BindException(Diagnostic.Error($"Struct {structType} has multiple definitions").At(expression).Build());
        }

        var @struct = structs[0];

        var initializers = new Dictionary<string, BoundExpression>();

        foreach (var (field, initializer) in expression.Initializers)
        {
            var fields = _definitionTable.LookupStructField(@struct, field).ToArray();

            if (fields.Length == 0)
            {
                throw new BindException(Diagnostic.Error($"Struct {structType} does not have a field with the name {field}").At(expression).Build());
            }

            if (fields.Length > 1)
            {
                throw new BindException(Diagnostic.Error($"Struct {structType} has multiple fields with the name {field}").At(expression).Build());
            }

            initializers[field] = BindExpression(initializer, fields[0].Type);
        }

        return new BoundStructInitializer(expression.Tokens, structType, initializers);
    }

    /// <summary>
    /// Binds a unary expression. Negation binds its operand with an i64 hint and keeps the operand's
    /// (numeric) type; inversion binds its operand as bool and yields bool.
    /// </summary>
    private BoundUnaryExpression BindUnaryExpression(UnaryExpressionSyntax expression)
    {
        BoundExpression boundOperand;
        NubType type;

        switch (expression.Operator)
        {
            case UnaryOperator.Negate:
            {
                boundOperand = BindExpression(expression.Operand, new NubPrimitiveType(PrimitiveTypeKind.I64));

                if (!boundOperand.Type.IsNumber)
                {
                    throw new BindException(Diagnostic.Error($"Cannot negate non-numeric type {boundOperand.Type}").At(expression).Build());
                }

                type = boundOperand.Type;
                break;
            }
            case UnaryOperator.Invert:
            {
                boundOperand = BindExpression(expression.Operand, new NubPrimitiveType(PrimitiveTypeKind.Bool));
                type = new NubPrimitiveType(PrimitiveTypeKind.Bool);
                break;
            }
            default:
            {
                throw new ArgumentOutOfRangeException(nameof(expression), expression.Operator, null);
            }
        }

        return new BoundUnaryExpression(expression.Tokens, type, BindUnaryOperator(expression.Operator), boundOperand);
    }

    private BoundFuncSignature BindFuncSignature(FuncSignatureSyntax node)
    {
        var parameters = new List<BoundFuncParameter>();

        foreach (var parameter in node.Parameters)
        {
            parameters.Add(new BoundFuncParameter(parameter.Tokens, parameter.Name, parameter.Type));
        }

        return new BoundFuncSignature(node.Tokens, parameters, node.ReturnType);
    }

    private BoundBinaryOperator BindBinaryOperator(BinaryOperator op)
    {
        return op switch
        {
            BinaryOperator.Equal => BoundBinaryOperator.Equal,
            BinaryOperator.NotEqual => BoundBinaryOperator.NotEqual,
            BinaryOperator.GreaterThan => BoundBinaryOperator.GreaterThan,
            BinaryOperator.GreaterThanOrEqual => BoundBinaryOperator.GreaterThanOrEqual,
            BinaryOperator.LessThan => BoundBinaryOperator.LessThan,
            BinaryOperator.LessThanOrEqual => BoundBinaryOperator.LessThanOrEqual,
            BinaryOperator.Plus => BoundBinaryOperator.Plus,
            BinaryOperator.Minus => BoundBinaryOperator.Minus,
            BinaryOperator.Multiply => BoundBinaryOperator.Multiply,
            BinaryOperator.Divide => BoundBinaryOperator.Divide,
            _ => throw new ArgumentOutOfRangeException(nameof(op), op, null)
        };
    }

    private BoundUnaryOperator BindUnaryOperator(UnaryOperator op)
    {
        return op switch
        {
            UnaryOperator.Negate => BoundUnaryOperator.Negate,
            UnaryOperator.Invert => BoundUnaryOperator.Invert,
            _ => throw new ArgumentOutOfRangeException(nameof(op), op, null)
        };
    }

    /// <summary>
    /// Binds a block in <paramref name="scope"/>, or in a fresh child of the current scope when none is
    /// given. The scope is popped even if binding fails, so a failed definition leaves no stale scope.
    /// </summary>
    private BoundBlock BindBlock(BlockSyntax node, Scope? scope = null)
    {
        var statements = new List<BoundStatement>();

        _scopes.Push(scope ?? Scope.SubScope());
        try
        {
            foreach (var statement in node.Statements)
            {
                statements.Add(BindStatement(statement));
            }
        }
        finally
        {
            _scopes.Pop();
        }

        return new BoundBlock(node.Tokens, statements);
    }

    /// <summary>
    /// Binds a function body in a new root scope containing only the parameters, with
    /// <paramref name="returnType"/> as the target type for <c>return</c> values.
    /// </summary>
    private BoundBlock BindFuncBody(BlockSyntax block, NubType returnType, IReadOnlyList<BoundFuncParameter> parameters)
    {
        var scope = new Scope();
        foreach (var parameter in parameters)
        {
            scope.Declare(new Variable(parameter.Name, parameter.Type));
        }

        _funcReturnTypes.Push(returnType);
        try
        {
            return BindBlock(block, scope);
        }
        finally
        {
            // Always restore the enclosing return type, even when binding throws.
            _funcReturnTypes.Pop();
        }
    }
}

/// <summary>A named, typed variable declared in a <see cref="Scope"/>.</summary>
public record Variable(string Name, NubType Type);

/// <summary>
/// A lexical scope holding declared variables, optionally chained to an enclosing <paramref name="parent"/> scope.
/// </summary>
public class Scope(Scope? parent = null)
{
    private readonly List<Variable> _variables = [];

    /// <summary>
    /// Returns the earliest-declared variable named <paramref name="name"/> in this scope, otherwise
    /// searches the enclosing scopes; returns null when no scope declares it.
    /// </summary>
    public Variable? Lookup(string name)
    {
        var variable = _variables.FirstOrDefault(x => x.Name == name);
        if (variable != null)
        {
            return variable;
        }

        return parent?.Lookup(name);
    }

    /// <summary>Adds <paramref name="variable"/> to this scope (duplicates are not rejected).</summary>
    public void Declare(Variable variable)
    {
        _variables.Add(variable);
    }

    /// <summary>Creates a new scope whose parent is this scope.</summary>
    public Scope SubScope()
    {
        return new Scope(this);
    }
}

/// <summary>
/// Raised by the binder for a user error; <see cref="Binder.Bind"/> catches it and records
/// <see cref="Diagnostic"/> instead of the failing definition.
/// </summary>
public class BindException : Exception
{
    public Diagnostic Diagnostic { get; }

    public BindException(Diagnostic diagnostic) : base(diagnostic.Message)
    {
        Diagnostic = diagnostic;
    }
}