Files
nub-lang/src/compiler/NubLang/Syntax/Binding/Binder.cs
nub31 4055002a8c ...
2025-07-22 22:10:31 +02:00

657 lines
24 KiB
C#

using NubLang.Common;
using NubLang.Diagnostics;
using NubLang.Syntax.Binding.Node;
using NubLang.Syntax.Parsing.Node;
using NubLang.Syntax.Tokenization;
namespace NubLang.Syntax.Binding;
/// <summary>
/// Binds a parsed <see cref="SyntaxTree"/> into a <see cref="BoundSyntaxTree"/>: resolves identifiers
/// against local scopes and the <see cref="DefinitionTable"/>, and attaches a <see cref="NubType"/> to
/// every bound expression.
/// </summary>
/// <remarks>
/// Binding errors are reported by throwing <see cref="BindException"/>. <see cref="Bind"/> catches them
/// per top-level definition, records the diagnostic, and continues with the next definition.
/// </remarks>
public sealed class Binder
{
    private readonly SyntaxTree _syntaxTree;
    private readonly DefinitionTable _definitionTable;

    // Variable scopes of the function body currently being bound; the innermost scope is on top.
    private readonly Stack<Scope> _scopes = [];

    // Return types of the enclosing function / arrow-function bodies; the innermost is on top.
    // `return` values are bound against the top entry.
    private readonly Stack<NubType> _funcReturnTypes = [];

    // Innermost scope. Only valid while a function body is being bound.
    private Scope Scope => _scopes.Peek();

    public Binder(SyntaxTree syntaxTree, DefinitionTable definitionTable)
    {
        _syntaxTree = syntaxTree;
        _definitionTable = definitionTable;
    }

    /// <summary>
    /// Binds every top-level definition of the syntax tree.
    /// </summary>
    /// <returns>
    /// A bound tree containing the definitions that bound successfully, plus one diagnostic for each
    /// definition whose binding failed with a <see cref="BindException"/>.
    /// </returns>
    public BoundSyntaxTree Bind()
    {
        _funcReturnTypes.Clear();
        _scopes.Clear();

        var diagnostics = new List<Diagnostic>();
        var definitions = new List<BoundDefinition>();

        foreach (var definition in _syntaxTree.Definitions)
        {
            try
            {
                definitions.Add(BindDefinition(definition));
            }
            catch (BindException e)
            {
                diagnostics.Add(e.Diagnostic);
            }
        }

        return new BoundSyntaxTree(_syntaxTree.Namespace, definitions, diagnostics);
    }

    private BoundDefinition BindDefinition(DefinitionSyntax node)
    {
        return node switch
        {
            ExternFuncSyntax definition => BindExternFuncDefinition(definition),
            InterfaceSyntax definition => BindTraitDefinition(definition),
            LocalFuncSyntax definition => BindLocalFuncDefinition(definition),
            StructSyntax definition => BindStruct(definition),
            _ => throw new ArgumentOutOfRangeException(nameof(node))
        };
    }

    // Interface syntax is bound into a BoundTrait; only the function signatures are bound.
    private BoundTrait BindTraitDefinition(InterfaceSyntax node)
    {
        var functions = new List<BoundTraitFunc>();
        foreach (var function in node.Functions)
        {
            functions.Add(new BoundTraitFunc(function.Name, BindFuncSignature(function.Signature)));
        }

        return new BoundTrait(node.Namespace, node.Name, functions);
    }

    // Binds each field's type and, when present, its default value against that type.
    // Defaults are bound outside any function body, so no variable scope is active for them.
    private BoundStruct BindStruct(StructSyntax node)
    {
        var structFields = new List<BoundStructField>();

        foreach (var field in node.Fields)
        {
            var fieldType = BindType(field.Type);

            var value = Optional.Empty<BoundExpression>();
            if (field.Value.HasValue)
            {
                value = BindExpression(field.Value.Value, fieldType);
            }

            structFields.Add(new BoundStructField(field.Index, field.Name, fieldType, value));
        }

        return new BoundStruct(node.Namespace, node.Name, structFields);
    }

    private BoundExternFunc BindExternFuncDefinition(ExternFuncSyntax node)
    {
        return new BoundExternFunc(node.Namespace, node.Name, node.CallName, BindFuncSignature(node.Signature));
    }

    private BoundLocalFunc BindLocalFuncDefinition(LocalFuncSyntax node)
    {
        var signature = BindFuncSignature(node.Signature);
        var body = BindFuncBody(node.Body, signature.ReturnType, signature.Parameters);
        return new BoundLocalFunc(node.Namespace, node.Name, signature, body);
    }

    private BoundStatement BindStatement(StatementSyntax node)
    {
        return node switch
        {
            AssignmentSyntax statement => BindAssignment(statement),
            BreakSyntax => new BoundBreak(),
            ContinueSyntax => new BoundContinue(),
            IfSyntax statement => BindIf(statement),
            ReturnSyntax statement => BindReturn(statement),
            StatementExpressionSyntax statement => BindStatementExpression(statement),
            VariableDeclarationSyntax statement => BindVariableDeclaration(statement),
            WhileSyntax statement => BindWhile(statement),
            _ => throw new ArgumentOutOfRangeException(nameof(node))
        };
    }

    // The value is bound with the target's type as its expected type (e.g. untyped literals adopt it).
    private BoundStatement BindAssignment(AssignmentSyntax statement)
    {
        var expression = BindExpression(statement.Target);
        var value = BindExpression(statement.Value, expression.Type);
        return new BoundAssignment(expression, value);
    }

    private BoundIf BindIf(IfSyntax statement)
    {
        var elseStatement = Optional.Empty<Variant<BoundIf, BoundBlock>>();
        if (statement.Else.HasValue)
        {
            elseStatement = statement.Else.Value.Match<Variant<BoundIf, BoundBlock>>
            (
                elseIf => BindIf(elseIf),
                @else => BindBlock(@else)
            );
        }

        return new BoundIf(BindExpression(statement.Condition, new NubPrimitiveType(PrimitiveTypeKind.Bool)), BindBlock(statement.Body), elseStatement);
    }

    // A returned value is bound against the return type of the innermost enclosing function.
    private BoundReturn BindReturn(ReturnSyntax statement)
    {
        var value = Optional.Empty<BoundExpression>();
        if (statement.Value.HasValue)
        {
            value = BindExpression(statement.Value.Value, _funcReturnTypes.Peek());
        }

        return new BoundReturn(value);
    }

    private BoundStatementExpression BindStatementExpression(StatementExpressionSyntax statement)
    {
        return new BoundStatementExpression(BindExpression(statement.Expression));
    }

    /// <summary>
    /// Binds a variable declaration and declares the variable in the current scope.
    /// When an initializer is present, the variable takes the initializer's bound type (the explicit
    /// type, if any, is only used as the initializer's expected type).
    /// </summary>
    /// <exception cref="BindException">Neither an explicit type nor an initializer is present.</exception>
    private BoundVariableDeclaration BindVariableDeclaration(VariableDeclarationSyntax statement)
    {
        NubType? type = null;
        if (statement.ExplicitType.HasValue)
        {
            type = BindType(statement.ExplicitType.Value);
        }

        var assignment = Optional<BoundExpression>.Empty();
        if (statement.Assignment.HasValue)
        {
            var boundValue = BindExpression(statement.Assignment.Value, type);
            assignment = boundValue;
            type = boundValue.Type;
        }

        if (type == null)
        {
            throw new BindException(Diagnostic.Error($"Cannot infer the type of variable {statement.Name}: add an explicit type or an initial value").Build());
        }

        Scope.Declare(new Variable(statement.Name, type));
        return new BoundVariableDeclaration(statement.Name, assignment, type);
    }

    private BoundWhile BindWhile(WhileSyntax statement)
    {
        return new BoundWhile(BindExpression(statement.Condition, new NubPrimitiveType(PrimitiveTypeKind.Bool)), BindBlock(statement.Body));
    }

    /// <summary>
    /// Binds an expression. <paramref name="expectedType"/> is a hint consumed by literals (which adopt it)
    /// and arrow functions (which require it); other expression kinds ignore it.
    /// </summary>
    private BoundExpression BindExpression(ExpressionSyntax node, NubType? expectedType = null)
    {
        return node switch
        {
            AddressOfSyntax expression => BindAddressOf(expression),
            ArrowFuncSyntax expression => BindArrowFunc(expression, expectedType),
            ArrayIndexAccessSyntax expression => BindArrayIndexAccess(expression),
            ArrayInitializerSyntax expression => BindArrayInitializer(expression),
            BinaryExpressionSyntax expression => BindBinaryExpression(expression),
            DereferenceSyntax expression => BindDereference(expression),
            FuncCallSyntax expression => BindFuncCall(expression),
            IdentifierSyntax expression => BindIdentifier(expression),
            LiteralSyntax expression => BindLiteral(expression, expectedType),
            MemberAccessSyntax expression => BindMemberAccess(expression),
            StructInitializerSyntax expression => BindStructInitializer(expression),
            UnaryExpressionSyntax expression => BindUnaryExpression(expression),
            _ => throw new ArgumentOutOfRangeException(nameof(node))
        };
    }

    private BoundAddressOf BindAddressOf(AddressOfSyntax expression)
    {
        var inner = BindExpression(expression.Expression);
        return new BoundAddressOf(new NubPointerType(inner.Type), inner);
    }

    /// <summary>
    /// Binds an arrow function. Parameter types and the return type are taken from
    /// <paramref name="expectedType"/>, which must be a function type with at least as many
    /// parameters as the arrow function declares.
    /// </summary>
    private BoundArrowFunc BindArrowFunc(ArrowFuncSyntax expression, NubType? expectedType = null)
    {
        if (expectedType == null)
        {
            throw new BindException(Diagnostic.Error("Cannot infer argument types for arrow function").Build());
        }

        if (expectedType is not NubFuncType funcType)
        {
            throw new BindException(Diagnostic.Error($"Expected {expectedType}, but got arrow function").Build());
        }

        var parameters = new List<BoundFuncParameter>();
        for (var i = 0; i < expression.Parameters.Count; i++)
        {
            if (i >= funcType.Parameters.Count)
            {
                throw new BindException(Diagnostic.Error($"Arrow function expected a maximum of {funcType.Parameters.Count} arguments").Build());
            }

            parameters.Add(new BoundFuncParameter(expression.Parameters[i].Name, funcType.Parameters[i]));
        }

        var body = BindFuncBody(expression.Body, funcType.ReturnType, parameters);
        var arrowType = new NubFuncType(parameters.Select(x => x.Type).ToList(), funcType.ReturnType);
        return new BoundArrowFunc(arrowType, parameters, funcType.ReturnType, body);
    }

    // The index is bound with an expected type of u64.
    private BoundArrayIndexAccess BindArrayIndexAccess(ArrayIndexAccessSyntax expression)
    {
        var boundArray = BindExpression(expression.Target);
        if (boundArray.Type is not NubArrayType arrayType)
        {
            throw new BindException(Diagnostic.Error($"Cannot index into non-array type {boundArray.Type}").Build());
        }

        var index = BindExpression(expression.Index, new NubPrimitiveType(PrimitiveTypeKind.U64));
        return new BoundArrayIndexAccess(arrayType.ElementType, boundArray, index);
    }

    private BoundArrayInitializer BindArrayInitializer(ArrayInitializerSyntax expression)
    {
        var capacity = BindExpression(expression.Capacity, new NubPrimitiveType(PrimitiveTypeKind.U64));
        var elementType = BindType(expression.ElementType);
        return new BoundArrayInitializer(new NubArrayType(elementType), capacity, elementType);
    }

    // The right operand is bound against the left operand's type.
    // NOTE(review): the result type is the left operand's type for every operator, including
    // comparisons (==, <, ...), which arguably should produce bool. Confirm whether code generation
    // relies on this before changing it.
    private BoundBinaryExpression BindBinaryExpression(BinaryExpressionSyntax expression)
    {
        var boundLeft = BindExpression(expression.Left);
        var boundRight = BindExpression(expression.Right, boundLeft.Type);
        return new BoundBinaryExpression(boundLeft.Type, boundLeft, BindBinaryOperator(expression.Operator), boundRight);
    }

    private BoundDereference BindDereference(DereferenceSyntax expression)
    {
        var boundExpression = BindExpression(expression.Expression);
        if (boundExpression.Type is not NubPointerType pointerType)
        {
            throw new BindException(Diagnostic.Error($"Cannot dereference non-pointer type {boundExpression.Type}").Build());
        }

        return new BoundDereference(pointerType.BaseType, boundExpression);
    }

    /// <summary>
    /// Binds a call. The callee must have a function type; each argument is bound against the
    /// corresponding parameter type.
    /// </summary>
    /// <exception cref="BindException">The callee is not a function, or too many arguments are passed.</exception>
    private BoundFuncCall BindFuncCall(FuncCallSyntax expression)
    {
        var boundExpression = BindExpression(expression.Expression);
        if (boundExpression.Type is not NubFuncType funcType)
        {
            throw new BindException(Diagnostic.Error($"Cannot call non-function type {boundExpression.Type}").Build());
        }

        var parameters = new List<BoundExpression>();
        foreach (var (i, parameter) in expression.Parameters.Index())
        {
            if (i >= funcType.Parameters.Count)
            {
                throw new BindException(Diagnostic.Error($"Too many arguments: function of type {funcType} takes {funcType.Parameters.Count}").Build());
            }

            parameters.Add(BindExpression(parameter, funcType.Parameters[i]));
        }

        return new BoundFuncCall(funcType.ReturnType, boundExpression, parameters);
    }

    /// <summary>
    /// Resolves an identifier, trying in order: local functions, extern functions (both in the given
    /// namespace, or the current file's namespace when none is given), then — only for unqualified
    /// names inside a function body — variables in the current scope chain.
    /// </summary>
    /// <exception cref="BindException">The name is ambiguous or cannot be resolved.</exception>
    private BoundExpression BindIdentifier(IdentifierSyntax expression)
    {
        var @namespace = expression.Namespace.Or(_syntaxTree.Namespace);

        var localFuncs = _definitionTable.LookupLocalFunc(@namespace, expression.Name).ToArray();
        if (localFuncs.Length > 0)
        {
            if (localFuncs.Length > 1)
            {
                throw new BindException(Diagnostic.Error($"Func {@namespace}::{expression.Name} has multiple definitions").Build());
            }

            var localFunc = localFuncs[0];
            var returnType = BindType(localFunc.Signature.ReturnType);
            var parameterTypes = localFunc.Signature.Parameters.Select(p => BindType(p.Type)).ToList();
            return new BoundLocalFuncIdent(new NubFuncType(parameterTypes, returnType), @namespace, expression.Name);
        }

        var externFuncs = _definitionTable.LookupExternFunc(@namespace, expression.Name).ToArray();
        if (externFuncs.Length > 0)
        {
            if (externFuncs.Length > 1)
            {
                throw new BindException(Diagnostic.Error($"Extern func {@namespace}::{expression.Name} has multiple definitions").Build());
            }

            var externFunc = externFuncs[0];
            var returnType = BindType(externFunc.Signature.ReturnType);
            var parameterTypes = externFunc.Signature.Parameters.Select(p => BindType(p.Type)).ToList();
            return new BoundExternFuncIdent(new NubFuncType(parameterTypes, returnType), @namespace, expression.Name);
        }

        // TryPeek: expressions such as struct field defaults are bound with no active scope.
        if (!expression.Namespace.HasValue && _scopes.TryPeek(out var scope))
        {
            var variable = scope.Lookup(expression.Name);
            if (variable != null)
            {
                return new BoundVariableIdent(variable.Type, variable.Name);
            }
        }

        throw new BindException(Diagnostic.Error($"No identifier with the name {(expression.Namespace.HasValue ? $"{expression.Namespace.Value}::" : "")}{expression.Name} exists").Build());
    }

    // A literal takes the expected type when one is supplied; otherwise a default per literal kind.
    private BoundLiteral BindLiteral(LiteralSyntax expression, NubType? expectedType = null)
    {
        var type = expectedType ?? expression.Kind switch
        {
            LiteralKind.Integer => new NubPrimitiveType(PrimitiveTypeKind.I64),
            LiteralKind.Float => new NubPrimitiveType(PrimitiveTypeKind.F64),
            LiteralKind.String => new NubStringType(),
            LiteralKind.Bool => new NubPrimitiveType(PrimitiveTypeKind.Bool),
            _ => throw new ArgumentOutOfRangeException()
        };

        return new BoundLiteral(type, expression.Value, expression.Kind);
    }

    /// <summary>
    /// Binds <c>target.member</c> on a custom type. A trait function with that name is preferred;
    /// otherwise a struct field with that name is used.
    /// </summary>
    /// <exception cref="BindException">The member is ambiguous or does not exist.</exception>
    private BoundExpression BindMemberAccess(MemberAccessSyntax expression)
    {
        var boundExpression = BindExpression(expression.Target);

        // TODO: trait implementation functions (LookupTraitFuncImpl) are not resolved yet.

        if (boundExpression.Type is NubCustomType customType)
        {
            var traits = _definitionTable.LookupTrait(customType).ToArray();
            if (traits.Length > 0)
            {
                if (traits.Length > 1)
                {
                    throw new BindException(Diagnostic.Error($"Trait {customType} has multiple definitions").Build());
                }

                var traitFuncs = _definitionTable.LookupTraitFunc(traits[0], expression.Member).ToArray();
                if (traitFuncs.Length > 0)
                {
                    if (traitFuncs.Length > 1)
                    {
                        throw new BindException(Diagnostic.Error($"Trait {customType} has multiple functions with the name {expression.Member}").Build());
                    }

                    var traitFunc = traitFuncs[0];
                    var returnType = BindType(traitFunc.Signature.ReturnType);
                    var parameterTypes = traitFunc.Signature.Parameters.Select(p => BindType(p.Type)).ToList();
                    var type = new NubFuncType(parameterTypes, returnType);
                    return new BoundInterfaceFuncAccess(type, customType, boundExpression, expression.Member);
                }
            }

            var structs = _definitionTable.LookupStruct(customType).ToArray();
            if (structs.Length > 0)
            {
                if (structs.Length > 1)
                {
                    throw new BindException(Diagnostic.Error($"Struct {customType} has multiple definitions").Build());
                }

                var fields = _definitionTable.LookupStructField(structs[0], expression.Member).ToArray();
                if (fields.Length > 0)
                {
                    if (fields.Length > 1)
                    {
                        throw new BindException(Diagnostic.Error($"Struct {customType} has multiple fields with the name {expression.Member}").Build());
                    }

                    return new BoundStructFieldAccess(BindType(fields[0].Type), customType, boundExpression, expression.Member);
                }
            }
        }

        throw new BindException(Diagnostic.Error($"{boundExpression.Type} does not have a member with the name {expression.Member}").Build());
    }

    /// <summary>
    /// Binds a struct initializer; each initializer value is bound against its field's declared type.
    /// </summary>
    /// <exception cref="BindException">The type is not a single, defined struct, or a field is unknown or ambiguous.</exception>
    private BoundStructInitializer BindStructInitializer(StructInitializerSyntax expression)
    {
        var boundType = BindType(expression.StructType);
        if (boundType is not NubCustomType structType)
        {
            throw new BindException(Diagnostic.Error($"Cannot initialize non-struct type {expression.StructType}").Build());
        }

        var structs = _definitionTable.LookupStruct(structType).ToArray();
        if (structs.Length == 0)
        {
            throw new BindException(Diagnostic.Error($"Struct {structType} is not defined").Build());
        }

        if (structs.Length > 1)
        {
            throw new BindException(Diagnostic.Error($"Struct {structType} has multiple definitions").Build());
        }

        var @struct = structs[0];
        var initializers = new Dictionary<string, BoundExpression>();

        foreach (var (field, initializer) in expression.Initializers)
        {
            var fields = _definitionTable.LookupStructField(@struct, field).ToArray();
            if (fields.Length == 0)
            {
                throw new BindException(Diagnostic.Error($"Struct {structType} does not have a field with the name {field}").Build());
            }

            if (fields.Length > 1)
            {
                throw new BindException(Diagnostic.Error($"Struct {structType} has multiple fields with the name {field}").Build());
            }

            initializers[field] = BindExpression(initializer, BindType(fields[0].Type));
        }

        return new BoundStructInitializer(structType, initializers);
    }

    /// <summary>
    /// Binds a unary expression. Negation binds its operand with an expected type of i64 and requires
    /// a numeric result; inversion binds its operand with an expected type of bool and yields bool.
    /// </summary>
    private BoundUnaryExpression BindUnaryExpression(UnaryExpressionSyntax expression)
    {
        var op = BindUnaryOperator(expression.Operator);

        switch (expression.Operator)
        {
            case UnaryOperator.Negate:
            {
                var operand = BindExpression(expression.Operand, new NubPrimitiveType(PrimitiveTypeKind.I64));
                if (!operand.Type.IsNumber)
                {
                    throw new BindException(Diagnostic.Error($"Cannot negate non-numeric type {operand.Type}").Build());
                }

                return new BoundUnaryExpression(operand.Type, op, operand);
            }
            case UnaryOperator.Invert:
            {
                var operand = BindExpression(expression.Operand, new NubPrimitiveType(PrimitiveTypeKind.Bool));
                return new BoundUnaryExpression(new NubPrimitiveType(PrimitiveTypeKind.Bool), op, operand);
            }
            default:
                throw new ArgumentOutOfRangeException(nameof(expression), expression.Operator, null);
        }
    }

    private BoundFuncSignature BindFuncSignature(FuncSignatureSyntax node)
    {
        var parameters = new List<BoundFuncParameter>();
        foreach (var parameter in node.Parameters)
        {
            parameters.Add(new BoundFuncParameter(parameter.Name, BindType(parameter.Type)));
        }

        return new BoundFuncSignature(parameters, BindType(node.ReturnType));
    }

    private BoundBinaryOperator BindBinaryOperator(BinaryOperator op)
    {
        return op switch
        {
            BinaryOperator.Equal => BoundBinaryOperator.Equal,
            BinaryOperator.NotEqual => BoundBinaryOperator.NotEqual,
            BinaryOperator.GreaterThan => BoundBinaryOperator.GreaterThan,
            BinaryOperator.GreaterThanOrEqual => BoundBinaryOperator.GreaterThanOrEqual,
            BinaryOperator.LessThan => BoundBinaryOperator.LessThan,
            BinaryOperator.LessThanOrEqual => BoundBinaryOperator.LessThanOrEqual,
            BinaryOperator.Plus => BoundBinaryOperator.Plus,
            BinaryOperator.Minus => BoundBinaryOperator.Minus,
            BinaryOperator.Multiply => BoundBinaryOperator.Multiply,
            BinaryOperator.Divide => BoundBinaryOperator.Divide,
            _ => throw new ArgumentOutOfRangeException(nameof(op), op, null)
        };
    }

    private BoundUnaryOperator BindUnaryOperator(UnaryOperator op)
    {
        return op switch
        {
            UnaryOperator.Negate => BoundUnaryOperator.Negate,
            UnaryOperator.Invert => BoundUnaryOperator.Invert,
            _ => throw new ArgumentOutOfRangeException(nameof(op), op, null)
        };
    }

    /// <summary>
    /// Binds a block in <paramref name="scope"/>, or in a fresh child of the current scope when none is given.
    /// The scope is popped even if binding a statement throws, so the stack stays balanced.
    /// </summary>
    private BoundBlock BindBlock(BlockSyntax node, Scope? scope = null)
    {
        _scopes.Push(scope ?? Scope.SubScope());
        try
        {
            var statements = new List<BoundStatement>();
            foreach (var statement in node.Statements)
            {
                statements.Add(BindStatement(statement));
            }

            return new BoundBlock(statements);
        }
        finally
        {
            _scopes.Pop();
        }
    }

    /// <summary>
    /// Binds a function body in a new root scope containing only its parameters (no enclosing
    /// variables are visible), with <paramref name="returnType"/> as the target of <c>return</c> values.
    /// </summary>
    private BoundBlock BindFuncBody(BlockSyntax block, NubType returnType, IReadOnlyList<BoundFuncParameter> parameters)
    {
        var scope = new Scope();
        foreach (var parameter in parameters)
        {
            scope.Declare(new Variable(parameter.Name, parameter.Type));
        }

        _funcReturnTypes.Push(returnType);
        try
        {
            return BindBlock(block, scope);
        }
        finally
        {
            _funcReturnTypes.Pop();
        }
    }

    private NubType BindType(TypeSyntax node)
    {
        return node switch
        {
            ArrayTypeSyntax type => new NubArrayType(BindType(type.BaseType)),
            CStringTypeSyntax => new NubCStringType(),
            CustomTypeSyntax type => new NubCustomType(type.Namespace, type.MangledName()),
            FuncTypeSyntax type => new NubFuncType(type.Parameters.Select(BindType).ToList(), BindType(type.ReturnType)),
            PointerTypeSyntax type => new NubPointerType(BindType(type.BaseType)),
            PrimitiveTypeSyntax type => new NubPrimitiveType(type.SyntaxKind switch
            {
                PrimitiveTypeSyntaxKind.I64 => PrimitiveTypeKind.I64,
                PrimitiveTypeSyntaxKind.I32 => PrimitiveTypeKind.I32,
                PrimitiveTypeSyntaxKind.I16 => PrimitiveTypeKind.I16,
                PrimitiveTypeSyntaxKind.I8 => PrimitiveTypeKind.I8,
                PrimitiveTypeSyntaxKind.U64 => PrimitiveTypeKind.U64,
                PrimitiveTypeSyntaxKind.U32 => PrimitiveTypeKind.U32,
                PrimitiveTypeSyntaxKind.U16 => PrimitiveTypeKind.U16,
                PrimitiveTypeSyntaxKind.U8 => PrimitiveTypeKind.U8,
                PrimitiveTypeSyntaxKind.F64 => PrimitiveTypeKind.F64,
                PrimitiveTypeSyntaxKind.F32 => PrimitiveTypeKind.F32,
                PrimitiveTypeSyntaxKind.Bool => PrimitiveTypeKind.Bool,
                _ => throw new ArgumentOutOfRangeException()
            }),
            StringTypeSyntax => new NubStringType(),
            VoidTypeSyntax => new NubVoidType(),
            _ => throw new ArgumentOutOfRangeException(nameof(node))
        };
    }
}
/// <summary>
/// A named, typed entry declared in a <see cref="Scope"/> (the binder declares function parameters
/// and local variable declarations this way).
/// </summary>
/// <param name="Name">The name the entry is looked up by.</param>
/// <param name="Type">The bound type of the entry.</param>
public record Variable(string Name, NubType Type);
/// <summary>
/// A lexical scope holding declared <see cref="Variable"/>s, optionally chained to an enclosing scope.
/// Lookups search this scope first and then each enclosing scope in turn.
/// </summary>
public class Scope(Scope? parent = null)
{
    private readonly Scope? _parent = parent;
    private readonly List<Variable> _variables = [];

    /// <summary>
    /// Finds the variable named <paramref name="name"/> in this scope or the nearest enclosing scope
    /// that declares it. Within a single scope, the earliest declaration of that name wins.
    /// </summary>
    /// <returns>The matching variable, or <c>null</c> when no scope in the chain declares it.</returns>
    public Variable? Lookup(string name)
    {
        for (Scope? current = this; current != null; current = current._parent)
        {
            foreach (var candidate in current._variables)
            {
                if (candidate.Name == name)
                {
                    return candidate;
                }
            }
        }

        return null;
    }

    /// <summary>Adds <paramref name="variable"/> to this scope. Duplicate names are not rejected.</summary>
    public void Declare(Variable variable) => _variables.Add(variable);

    /// <summary>Creates a child scope whose lookups fall back to this scope.</summary>
    public Scope SubScope() => new(this);
}
/// <summary>
/// Thrown by the binder when a construct cannot be bound. Carries the <see cref="Diagnostic"/> to report;
/// the exception message is the diagnostic's message.
/// </summary>
public class BindException(Diagnostic diagnostic) : Exception(diagnostic.Message)
{
    /// <summary>The diagnostic describing the binding error.</summary>
    public Diagnostic Diagnostic { get; } = diagnostic;
}