using NubLang.Common;
using NubLang.Diagnostics;
using NubLang.Syntax.Binding.Node;
using NubLang.Syntax.Parsing.Node;
using NubLang.Syntax.Tokenization;

namespace NubLang.Syntax.Binding;

/// <summary>
/// Resolves names and types in a parsed <see cref="SyntaxTree"/> and produces a <see cref="BoundSyntaxTree"/>.
/// A <see cref="BindException"/> raised while binding one top-level definition is recorded as a diagnostic,
/// and binding continues with the next definition.
/// </summary>
public sealed class Binder
{
    private readonly SyntaxTree _syntaxTree;
    private readonly DefinitionTable _definitionTable;

    // Lexical scopes used for variable lookup; the top of the stack is the innermost block.
    private readonly Stack<Scope> _scopes = [];

    // Return types of the function bodies currently being bound (innermost on top); read by `return`.
    private readonly Stack<NubType> _funcReturnTypes = [];

    private Scope Scope => _scopes.Peek();

    public Binder(SyntaxTree syntaxTree, DefinitionTable definitionTable)
    {
        _syntaxTree = syntaxTree;
        _definitionTable = definitionTable;
    }

    /// <summary>
    /// Binds every definition of the syntax tree.
    /// </summary>
    /// <returns>
    /// The successfully bound definitions together with one diagnostic per definition that failed to bind.
    /// </returns>
    public BoundSyntaxTree Bind()
    {
        _funcReturnTypes.Clear();
        _scopes.Clear();

        var diagnostics = new List<Diagnostic>();
        var definitions = new List<BoundDefinition>();

        foreach (var definition in _syntaxTree.Definitions)
        {
            try
            {
                definitions.Add(BindDefinition(definition));
            }
            catch (BindException e)
            {
                diagnostics.Add(e.Diagnostic);
            }
        }

        return new BoundSyntaxTree(_syntaxTree.Namespace, definitions, diagnostics);
    }

    private BoundDefinition BindDefinition(DefinitionSyntax node)
    {
        return node switch
        {
            ExternFuncSyntax definition => BindExternFuncDefinition(definition),
            InterfaceSyntax definition => BindTraitDefinition(definition),
            LocalFuncSyntax definition => BindLocalFuncDefinition(definition),
            StructSyntax definition => BindStruct(definition),
            _ => throw new ArgumentOutOfRangeException(nameof(node))
        };
    }

    private BoundTrait BindTraitDefinition(InterfaceSyntax node)
    {
        var functions = new List<BoundTraitFunc>();

        foreach (var function in node.Functions)
        {
            functions.Add(new BoundTraitFunc(function.Name, BindFuncSignature(function.Signature)));
        }

        return new BoundTrait(node.Namespace, node.Name, functions);
    }

    private BoundStruct BindStruct(StructSyntax node)
    {
        var structFields = new List<BoundStructField>();

        foreach (var field in node.Fields)
        {
            // Bind the declared type once; it is both the field's type and the expected type of its default value.
            var fieldType = BindType(field.Type);

            var value = Optional.Empty<BoundExpression>();
            if (field.Value.HasValue)
            {
                value = BindExpression(field.Value.Value, fieldType);
            }

            structFields.Add(new BoundStructField(field.Index, field.Name, fieldType, value));
        }

        return new BoundStruct(node.Namespace, node.Name, structFields);
    }

    private BoundExternFunc BindExternFuncDefinition(ExternFuncSyntax node)
    {
        return new BoundExternFunc(node.Namespace, node.Name, node.CallName, BindFuncSignature(node.Signature));
    }

    private BoundLocalFunc BindLocalFuncDefinition(LocalFuncSyntax node)
    {
        var signature = BindFuncSignature(node.Signature);
        var body = BindFuncBody(node.Body, signature.ReturnType, signature.Parameters);

        return new BoundLocalFunc(node.Namespace, node.Name, signature, body);
    }

    private BoundStatement BindStatement(StatementSyntax node)
    {
        return node switch
        {
            AssignmentSyntax statement => BindAssignment(statement),
            BreakSyntax => new BoundBreak(),
            ContinueSyntax => new BoundContinue(),
            IfSyntax statement => BindIf(statement),
            ReturnSyntax statement => BindReturn(statement),
            StatementExpressionSyntax statement => BindStatementExpression(statement),
            VariableDeclarationSyntax statement => BindVariableDeclaration(statement),
            WhileSyntax statement => BindWhile(statement),
            _ => throw new ArgumentOutOfRangeException(nameof(node))
        };
    }

    private BoundStatement BindAssignment(AssignmentSyntax statement)
    {
        var expression = BindExpression(statement.Target);
        // The target's type is the expected type of the assigned value.
        var value = BindExpression(statement.Value, expression.Type);
        return new BoundAssignment(expression, value);
    }

    private BoundIf BindIf(IfSyntax statement)
    {
        var elseStatement = Optional.Empty<Variant<BoundIf, BoundBlock>>();
        if (statement.Else.HasValue)
        {
            elseStatement = statement.Else.Value.Match<Variant<BoundIf, BoundBlock>>
            (
                elseIf => BindIf(elseIf),
                @else => BindBlock(@else)
            );
        }

        return new BoundIf(BindExpression(statement.Condition, new NubPrimitiveType(PrimitiveTypeKind.Bool)), BindBlock(statement.Body), elseStatement);
    }

    private BoundReturn BindReturn(ReturnSyntax statement)
    {
        var value = Optional.Empty<BoundExpression>();

        if (statement.Value.HasValue)
        {
            // The enclosing function's return type is the expected type of the returned value.
            value = BindExpression(statement.Value.Value, _funcReturnTypes.Peek());
        }

        return new BoundReturn(value);
    }

    private BoundStatementExpression BindStatementExpression(StatementExpressionSyntax statement)
    {
        return new BoundStatementExpression(BindExpression(statement.Expression));
    }

    private BoundVariableDeclaration BindVariableDeclaration(VariableDeclarationSyntax statement)
    {
        NubType? type = null;

        if (statement.ExplicitType.HasValue)
        {
            type = BindType(statement.ExplicitType.Value);
        }

        var assignment = Optional.Empty<BoundExpression>();
        if (statement.Assignment.HasValue)
        {
            var boundValue = BindExpression(statement.Assignment.Value, type);
            assignment = boundValue;
            // NOTE(review): the bound value's type replaces an explicit annotation here; this keeps the declared
            // type equal to the value's type, but confirm whether the annotation should take precedence instead.
            type = boundValue.Type;
        }

        if (type is null)
        {
            throw new BindException(Diagnostic.Error($"Cannot infer the type of variable {statement.Name}; add a type annotation or an initial value").Build());
        }

        Scope.Declare(new Variable(statement.Name, type));

        return new BoundVariableDeclaration(statement.Name, assignment, type);
    }

    private BoundWhile BindWhile(WhileSyntax statement)
    {
        return new BoundWhile(BindExpression(statement.Condition, new NubPrimitiveType(PrimitiveTypeKind.Bool)), BindBlock(statement.Body));
    }

    /// <summary>
    /// Binds an expression. <paramref name="expectedType"/> is forwarded only to arrow functions and literals,
    /// which use it to determine their own type.
    /// </summary>
    private BoundExpression BindExpression(ExpressionSyntax node, NubType? expectedType = null)
    {
        return node switch
        {
            AddressOfSyntax expression => BindAddressOf(expression),
            ArrowFuncSyntax expression => BindArrowFunc(expression, expectedType),
            ArrayIndexAccessSyntax expression => BindArrayIndexAccess(expression),
            ArrayInitializerSyntax expression => BindArrayInitializer(expression),
            BinaryExpressionSyntax expression => BindBinaryExpression(expression),
            DereferenceSyntax expression => BindDereference(expression),
            FuncCallSyntax expression => BindFuncCall(expression),
            IdentifierSyntax expression => BindIdentifier(expression),
            LiteralSyntax expression => BindLiteral(expression, expectedType),
            MemberAccessSyntax expression => BindMemberAccess(expression),
            StructInitializerSyntax expression => BindStructInitializer(expression),
            UnaryExpressionSyntax expression => BindUnaryExpression(expression),
            _ => throw new ArgumentOutOfRangeException(nameof(node))
        };
    }

    private BoundAddressOf BindAddressOf(AddressOfSyntax expression)
    {
        var inner = BindExpression(expression.Expression);
        return new BoundAddressOf(new NubPointerType(inner.Type), inner);
    }

    /// <summary>
    /// Binds an arrow function against the function type it is expected to satisfy.
    /// The parameter types and return type come from <paramref name="expectedType"/>; the arrow function may
    /// declare fewer parameters than the expected type, but not more.
    /// </summary>
    private BoundArrowFunc BindArrowFunc(ArrowFuncSyntax expression, NubType? expectedType = null)
    {
        if (expectedType == null)
        {
            throw new BindException(Diagnostic.Error("Cannot infer argument types for arrow function").Build());
        }

        if (expectedType is not NubFuncType funcType)
        {
            throw new BindException(Diagnostic.Error($"Expected {expectedType}, but got arrow function").Build());
        }

        var parameters = new List<BoundFuncParameter>();

        for (var i = 0; i < expression.Parameters.Count; i++)
        {
            if (i >= funcType.Parameters.Count)
            {
                throw new BindException(Diagnostic.Error($"Arrow function expected a maximum of {funcType.Parameters.Count} arguments").Build());
            }

            var expectedParameterType = funcType.Parameters[i];
            var parameter = expression.Parameters[i];
            parameters.Add(new BoundFuncParameter(parameter.Name, expectedParameterType));
        }

        var body = BindFuncBody(expression.Body, funcType.ReturnType, parameters);

        return new BoundArrowFunc(new NubFuncType(parameters.Select(x => x.Type).ToList(), funcType.ReturnType), parameters, funcType.ReturnType, body);
    }

    private BoundArrayIndexAccess BindArrayIndexAccess(ArrayIndexAccessSyntax expression)
    {
        var boundArray = BindExpression(expression.Target);
        if (boundArray.Type is not NubArrayType arrayType)
        {
            throw new BindException(Diagnostic.Error($"Cannot index into non-array type {boundArray.Type}").Build());
        }

        var index = BindExpression(expression.Index, new NubPrimitiveType(PrimitiveTypeKind.U64));
        return new BoundArrayIndexAccess(arrayType.ElementType, boundArray, index);
    }

    private BoundArrayInitializer BindArrayInitializer(ArrayInitializerSyntax expression)
    {
        var capacity = BindExpression(expression.Capacity, new NubPrimitiveType(PrimitiveTypeKind.U64));
        var elementType = BindType(expression.ElementType);
        return new BoundArrayInitializer(new NubArrayType(elementType), capacity, elementType);
    }

    private BoundBinaryExpression BindBinaryExpression(BinaryExpressionSyntax expression)
    {
        var boundLeft = BindExpression(expression.Left);
        // The left operand's type is the expected type of the right operand.
        var boundRight = BindExpression(expression.Right, boundLeft.Type);

        // NOTE(review): the result takes the left operand's type for every operator, including comparisons
        // (==, <, ...); confirm whether comparison results should be typed as bool instead.
        return new BoundBinaryExpression(boundLeft.Type, boundLeft, BindBinaryOperator(expression.Operator), boundRight);
    }

    private BoundDereference BindDereference(DereferenceSyntax expression)
    {
        var boundExpression = BindExpression(expression.Expression);
        if (boundExpression.Type is not NubPointerType pointerType)
        {
            throw new BindException(Diagnostic.Error($"Cannot dereference non-pointer type {boundExpression.Type}").Build());
        }

        return new BoundDereference(pointerType.BaseType, boundExpression);
    }

    private BoundFuncCall BindFuncCall(FuncCallSyntax expression)
    {
        var boundExpression = BindExpression(expression.Expression);
        if (boundExpression.Type is not NubFuncType funcType)
        {
            throw new BindException(Diagnostic.Error($"Cannot call non-function type {boundExpression.Type}").Build());
        }

        var parameters = new List<BoundExpression>();

        foreach (var (i, parameter) in expression.Parameters.Index())
        {
            if (i >= funcType.Parameters.Count)
            {
                throw new BindException(Diagnostic.Error($"Function call expected a maximum of {funcType.Parameters.Count} arguments").Build());
            }

            // Each argument is bound against the corresponding declared parameter type.
            parameters.Add(BindExpression(parameter, funcType.Parameters[i]));
        }

        return new BoundFuncCall(funcType.ReturnType, boundExpression, parameters);
    }

    /// <summary>
    /// Resolves an identifier. Lookup order: local functions, then extern functions (both in the identifier's
    /// namespace, defaulting to the current file's namespace), then, only for un-namespaced identifiers,
    /// variables in the current scope chain.
    /// </summary>
    private BoundExpression BindIdentifier(IdentifierSyntax expression)
    {
        var @namespace = expression.Namespace.Or(_syntaxTree.Namespace);

        var localFuncs = _definitionTable.LookupLocalFunc(@namespace, expression.Name).ToArray();
        if (localFuncs.Length > 0)
        {
            if (localFuncs.Length > 1)
            {
                throw new BindException(Diagnostic.Error($"Func {@namespace}::{expression.Name} has multiple definitions").Build());
            }

            return new BoundLocalFuncIdent(BindFuncType(localFuncs[0].Signature), @namespace, expression.Name);
        }

        var externFuncs = _definitionTable.LookupExternFunc(@namespace, expression.Name).ToArray();
        if (externFuncs.Length > 0)
        {
            if (externFuncs.Length > 1)
            {
                throw new BindException(Diagnostic.Error($"Extern func {@namespace}::{expression.Name} has multiple definitions").Build());
            }

            return new BoundExternFuncIdent(BindFuncType(externFuncs[0].Signature), @namespace, expression.Name);
        }

        if (!expression.Namespace.HasValue)
        {
            var variable = Scope.Lookup(expression.Name);
            if (variable != null)
            {
                return new BoundVariableIdent(variable.Type, variable.Name);
            }
        }

        throw new BindException(Diagnostic.Error($"No identifier with the name {(expression.Namespace.HasValue ? $"{expression.Namespace.Value}::" : "")}{expression.Name} exists").Build());
    }

    /// <summary>
    /// Binds a literal. When an expected type is supplied it is used as-is; otherwise the type is chosen from
    /// the literal kind (integer → i64, float → f64, string, bool).
    /// </summary>
    private BoundLiteral BindLiteral(LiteralSyntax expression, NubType? expectedType = null)
    {
        var type = expectedType ?? expression.Kind switch
        {
            LiteralKind.Integer => new NubPrimitiveType(PrimitiveTypeKind.I64),
            LiteralKind.Float => new NubPrimitiveType(PrimitiveTypeKind.F64),
            LiteralKind.String => new NubStringType(),
            LiteralKind.Bool => new NubPrimitiveType(PrimitiveTypeKind.Bool),
            _ => throw new ArgumentOutOfRangeException()
        };

        return new BoundLiteral(type, expression.Value, expression.Kind);
    }

    /// <summary>
    /// Binds <c>target.member</c> on a custom type: first as a trait (interface) function, then as a struct field.
    /// </summary>
    private BoundExpression BindMemberAccess(MemberAccessSyntax expression)
    {
        var boundExpression = BindExpression(expression.Target);

        // var traitFuncImpls = _definitionTable.LookupTraitFuncImpl(boundExpression.Type, expression.Member).ToArray();
        // if (traitFuncImpls.Length > 0)
        // {
        //     if (traitFuncImpls.Length > 1)
        //     {
        //         throw new BindException(Diagnostic.Error($"Type {boundExpression.Type} implements multiple traits with the function {expression.Member}").Build());
        //     }
        //
        //     var impl = traitFuncImpls[0];
        //
        //     var returnType = BindType(impl.Signature.ReturnType);
        //     var parameterTypes = impl.Signature.Parameters.Select(p => BindType(p.Type)).ToList();
        //     var type = new NubFuncType(parameterTypes, returnType);
        //     return new BoundTraitImplFuncAccess(type, boundExpression, expression.Member);
        // }

        if (boundExpression.Type is NubCustomType customType)
        {
            var traits = _definitionTable.LookupTrait(customType).ToArray();
            if (traits.Length > 0)
            {
                if (traits.Length > 1)
                {
                    throw new BindException(Diagnostic.Error($"Trait {customType} has multiple definitions").Build());
                }

                var trait = traits[0];

                // Only claim the member as a trait function if the trait actually declares it; otherwise fall
                // through to the struct lookup and the "no member" diagnostic below.
                var traitFuncs = _definitionTable.LookupTraitFunc(trait, expression.Member).ToArray();
                if (traitFuncs.Length > 0)
                {
                    if (traitFuncs.Length > 1)
                    {
                        throw new BindException(Diagnostic.Error($"Trait {customType} has multiple functions with the name {expression.Member}").Build());
                    }

                    var traitFunc = traitFuncs[0];
                    return new BoundInterfaceFuncAccess(BindFuncType(traitFunc.Signature), customType, boundExpression, expression.Member);
                }
            }

            var structs = _definitionTable.LookupStruct(customType).ToArray();
            if (structs.Length > 0)
            {
                if (structs.Length > 1)
                {
                    throw new BindException(Diagnostic.Error($"Struct {customType} has multiple definitions").Build());
                }

                var @struct = structs[0];

                var fields = _definitionTable.LookupStructField(@struct, expression.Member).ToArray();
                if (fields.Length > 0)
                {
                    if (fields.Length > 1)
                    {
                        throw new BindException(Diagnostic.Error($"Struct {customType} has multiple fields with the name {expression.Member}").Build());
                    }

                    var field = fields[0];

                    return new BoundStructFieldAccess(BindType(field.Type), customType, boundExpression, expression.Member);
                }
            }
        }

        throw new BindException(Diagnostic.Error($"{boundExpression.Type} does not have a member with the name {expression.Member}").Build());
    }

    private BoundStructInitializer BindStructInitializer(StructInitializerSyntax expression)
    {
        var boundType = BindType(expression.StructType);

        if (boundType is not NubCustomType structType)
        {
            throw new BindException(Diagnostic.Error($"Cannot initialize non-struct type {expression.StructType}").Build());
        }

        var structs = _definitionTable.LookupStruct(structType).ToArray();

        if (structs.Length == 0)
        {
            throw new BindException(Diagnostic.Error($"Struct {structType} is not defined").Build());
        }

        if (structs.Length > 1)
        {
            throw new BindException(Diagnostic.Error($"Struct {structType} has multiple definitions").Build());
        }

        var @struct = structs[0];

        var initializers = new Dictionary<string, BoundExpression>();

        foreach (var (field, initializer) in expression.Initializers)
        {
            var fields = _definitionTable.LookupStructField(@struct, field).ToArray();

            if (fields.Length == 0)
            {
                throw new BindException(Diagnostic.Error($"Struct {structType} does not have a field with the name {field}").Build());
            }

            if (fields.Length > 1)
            {
                throw new BindException(Diagnostic.Error($"Struct {structType} has multiple fields with the name {field}").Build());
            }

            initializers[field] = BindExpression(initializer, BindType(fields[0].Type));
        }

        return new BoundStructInitializer(structType, initializers);
    }

    /// <summary>
    /// Binds a unary expression. Negation binds its operand with an i64 hint and requires a numeric result type;
    /// inversion binds its operand as bool and produces bool.
    /// </summary>
    private BoundUnaryExpression BindUnaryExpression(UnaryExpressionSyntax expression)
    {
        BoundExpression boundOperand;
        NubType type;

        switch (expression.Operator)
        {
            case UnaryOperator.Negate:
            {
                boundOperand = BindExpression(expression.Operand, new NubPrimitiveType(PrimitiveTypeKind.I64));
                if (!boundOperand.Type.IsNumber)
                {
                    throw new BindException(Diagnostic.Error($"Cannot negate non-numeric type {boundOperand.Type}").Build());
                }

                type = boundOperand.Type;
                break;
            }
            case UnaryOperator.Invert:
            {
                boundOperand = BindExpression(expression.Operand, new NubPrimitiveType(PrimitiveTypeKind.Bool));
                type = new NubPrimitiveType(PrimitiveTypeKind.Bool);
                break;
            }
            default:
            {
                throw new ArgumentOutOfRangeException(nameof(expression), expression.Operator, null);
            }
        }

        return new BoundUnaryExpression(type, BindUnaryOperator(expression.Operator), boundOperand);
    }

    private BoundFuncSignature BindFuncSignature(FuncSignatureSyntax node)
    {
        var parameters = new List<BoundFuncParameter>();

        foreach (var parameter in node.Parameters)
        {
            parameters.Add(new BoundFuncParameter(parameter.Name, BindType(parameter.Type)));
        }

        return new BoundFuncSignature(parameters, BindType(node.ReturnType));
    }

    // Builds the function type (parameter types + return type) described by a signature.
    private NubFuncType BindFuncType(FuncSignatureSyntax signature)
    {
        var returnType = BindType(signature.ReturnType);
        var parameterTypes = signature.Parameters.Select(p => BindType(p.Type)).ToList();
        return new NubFuncType(parameterTypes, returnType);
    }

    private BoundBinaryOperator BindBinaryOperator(BinaryOperator op)
    {
        return op switch
        {
            BinaryOperator.Equal => BoundBinaryOperator.Equal,
            BinaryOperator.NotEqual => BoundBinaryOperator.NotEqual,
            BinaryOperator.GreaterThan => BoundBinaryOperator.GreaterThan,
            BinaryOperator.GreaterThanOrEqual => BoundBinaryOperator.GreaterThanOrEqual,
            BinaryOperator.LessThan => BoundBinaryOperator.LessThan,
            BinaryOperator.LessThanOrEqual => BoundBinaryOperator.LessThanOrEqual,
            BinaryOperator.Plus => BoundBinaryOperator.Plus,
            BinaryOperator.Minus => BoundBinaryOperator.Minus,
            BinaryOperator.Multiply => BoundBinaryOperator.Multiply,
            BinaryOperator.Divide => BoundBinaryOperator.Divide,
            _ => throw new ArgumentOutOfRangeException(nameof(op), op, null)
        };
    }

    private BoundUnaryOperator BindUnaryOperator(UnaryOperator op)
    {
        return op switch
        {
            UnaryOperator.Negate => BoundUnaryOperator.Negate,
            UnaryOperator.Invert => BoundUnaryOperator.Invert,
            _ => throw new ArgumentOutOfRangeException(nameof(op), op, null)
        };
    }

    /// <summary>
    /// Binds the statements of a block inside <paramref name="scope"/>, or a fresh child of the current scope
    /// when none is given. The scope is popped even if binding a statement throws.
    /// </summary>
    private BoundBlock BindBlock(BlockSyntax node, Scope? scope = null)
    {
        var statements = new List<BoundStatement>();

        _scopes.Push(scope ?? Scope.SubScope());
        try
        {
            foreach (var statement in node.Statements)
            {
                statements.Add(BindStatement(statement));
            }
        }
        finally
        {
            _scopes.Pop();
        }

        return new BoundBlock(statements);
    }

    /// <summary>
    /// Binds a function body with the given parameters declared in a new root scope and
    /// <paramref name="returnType"/> as the target of its <c>return</c> statements.
    /// </summary>
    /// <remarks>
    /// The parameter scope has no parent, so the body cannot see variables of an enclosing function
    /// (this also applies to arrow functions).
    /// </remarks>
    private BoundBlock BindFuncBody(BlockSyntax block, NubType returnType, IReadOnlyList<BoundFuncParameter> parameters)
    {
        var scope = new Scope();
        foreach (var parameter in parameters)
        {
            scope.Declare(new Variable(parameter.Name, parameter.Type));
        }

        _funcReturnTypes.Push(returnType);
        try
        {
            return BindBlock(block, scope);
        }
        finally
        {
            _funcReturnTypes.Pop();
        }
    }

    private NubType BindType(TypeSyntax node)
    {
        return node switch
        {
            ArrayTypeSyntax type => new NubArrayType(BindType(type.BaseType)),
            CStringTypeSyntax => new NubCStringType(),
            CustomTypeSyntax type => new NubCustomType(type.Namespace, type.MangledName()),
            FuncTypeSyntax type => new NubFuncType(type.Parameters.Select(BindType).ToList(), BindType(type.ReturnType)),
            PointerTypeSyntax type => new NubPointerType(BindType(type.BaseType)),
            PrimitiveTypeSyntax type => new NubPrimitiveType(type.SyntaxKind switch
            {
                PrimitiveTypeSyntaxKind.I64 => PrimitiveTypeKind.I64,
                PrimitiveTypeSyntaxKind.I32 => PrimitiveTypeKind.I32,
                PrimitiveTypeSyntaxKind.I16 => PrimitiveTypeKind.I16,
                PrimitiveTypeSyntaxKind.I8 => PrimitiveTypeKind.I8,
                PrimitiveTypeSyntaxKind.U64 => PrimitiveTypeKind.U64,
                PrimitiveTypeSyntaxKind.U32 => PrimitiveTypeKind.U32,
                PrimitiveTypeSyntaxKind.U16 => PrimitiveTypeKind.U16,
                PrimitiveTypeSyntaxKind.U8 => PrimitiveTypeKind.U8,
                PrimitiveTypeSyntaxKind.F64 => PrimitiveTypeKind.F64,
                PrimitiveTypeSyntaxKind.F32 => PrimitiveTypeKind.F32,
                PrimitiveTypeSyntaxKind.Bool => PrimitiveTypeKind.Bool,
                _ => throw new ArgumentOutOfRangeException()
            }),
            StringTypeSyntax => new NubStringType(),
            VoidTypeSyntax => new NubVoidType(),
            _ => throw new ArgumentOutOfRangeException(nameof(node))
        };
    }
}

/// <summary>A named, typed variable visible in a <see cref="Scope"/>.</summary>
public record Variable(string Name, NubType Type);

/// <summary>
/// A lexical scope holding declared variables, with an optional parent consulted when a name is not found locally.
/// </summary>
public class Scope(Scope? parent = null)
{
    private readonly List<Variable> _variables = [];

    /// <summary>
    /// Finds a variable by name in this scope, then in the parent chain; returns <c>null</c> if none matches.
    /// </summary>
    /// <remarks>
    /// NOTE(review): if a name is declared twice in the same scope, the first declaration is returned.
    /// Confirm whether redeclaration should shadow or be rejected.
    /// </remarks>
    public Variable? Lookup(string name)
    {
        var variable = _variables.FirstOrDefault(x => x.Name == name);
        if (variable != null)
        {
            return variable;
        }

        return parent?.Lookup(name);
    }

    public void Declare(Variable variable)
    {
        _variables.Add(variable);
    }

    /// <summary>Creates a child scope whose lookups fall back to this scope.</summary>
    public Scope SubScope()
    {
        return new Scope(this);
    }
}

/// <summary>
/// Thrown by <see cref="Binder"/> to report a binding error; <see cref="Binder.Bind"/> converts it into a diagnostic.
/// </summary>
public class BindException : Exception
{
    public Diagnostic Diagnostic { get; }

    public BindException(Diagnostic diagnostic) : base(diagnostic.Message)
    {
        Diagnostic = diagnostic;
    }
}