This commit is contained in:
nub31
2025-07-22 23:20:56 +02:00
parent 62c9d86cda
commit d993581361
26 changed files with 1002 additions and 1003 deletions

View File

@@ -5,12 +5,11 @@ using NubLang.Common;
using NubLang.Diagnostics;
using NubLang.Generation;
using NubLang.Generation.QBE;
using NubLang.Syntax.Binding;
using NubLang.Syntax.Binding.Node;
using NubLang.Syntax.Parsing;
using NubLang.Syntax.Parsing.Node;
using NubLang.Syntax.Tokenization;
using Binder = NubLang.Syntax.Binding.Binder;
using NubLang.Parsing;
using NubLang.Parsing.Syntax;
using NubLang.Tokenization;
using NubLang.TypeChecking;
using NubLang.TypeChecking.Node;
const string BIN_DIR = "bin";
const string INT_DIR = "bin-int";
@@ -92,14 +91,14 @@ foreach (var file in options.Files)
var definitionTable = new DefinitionTable(syntaxTrees);
var boundSyntaxTrees = new List<BoundSyntaxTree>();
var typedSyntaxTrees = new List<TypedSyntaxTree>();
foreach (var syntaxTree in syntaxTrees)
{
var binder = new Binder(syntaxTree, definitionTable);
var boundSyntaxTree = binder.Bind();
diagnostics.AddRange(boundSyntaxTree.Diagnostics);
boundSyntaxTrees.Add(boundSyntaxTree);
var typeChecker = new TypeChecker(syntaxTree, definitionTable);
var typedSyntaxTree = typeChecker.Check();
diagnostics.AddRange(typeChecker.GetDiagnostics());
typedSyntaxTrees.Add(typedSyntaxTree);
}
foreach (var diagnostic in diagnostics)
@@ -112,15 +111,15 @@ if (diagnostics.Any(diagnostic => diagnostic.Severity == DiagnosticSeverity.Erro
return 1;
}
var boundDefinitionTable = new BoundDefinitionTable(boundSyntaxTrees);
var typedDefinitionTable = new TypedDefinitionTable(typedSyntaxTrees);
var objectFiles = new List<string>();
foreach (var syntaxTree in boundSyntaxTrees)
foreach (var syntaxTree in typedSyntaxTrees)
{
var outFileName = HexString.CreateUnique(8);
var generator = new QBEGenerator(syntaxTree, boundDefinitionTable);
var generator = new QBEGenerator(syntaxTree, typedDefinitionTable);
var ssa = generator.Emit();
File.WriteAllText(Path.Join(INT_DEBUG_DIR, $"{outFileName}.ssa"), ssa);

View File

@@ -1,78 +0,0 @@
using NubLang.Syntax.Binding;
using NubLang.Syntax.Binding.Node;
namespace NubLang.Generation;
public sealed class BoundDefinitionTable
{
private readonly List<BoundDefinition> _definitions;
public BoundDefinitionTable(IEnumerable<BoundSyntaxTree> syntaxTrees)
{
_definitions = syntaxTrees.SelectMany(x => x.Definitions).ToList();
}
public BoundLocalFunc LookupLocalFunc(string name)
{
return _definitions
.OfType<BoundLocalFunc>()
.First(x => x.Name == name);
}
public BoundExternFunc LookupExternFunc(string name)
{
return _definitions
.OfType<BoundExternFunc>()
.First(x => x.Name == name);
}
public BoundStruct LookupStruct(string name)
{
return _definitions
.OfType<BoundStruct>()
.First(x => x.Name == name);
}
public BoundStructField LookupStructField(BoundStruct @struct, string field)
{
return @struct.Fields.First(x => x.Name == field);
}
public IEnumerable<BoundTraitImpl> LookupTraitImpls(NubType itemType)
{
return _definitions
.OfType<BoundTraitImpl>()
.Where(x => x.ForType == itemType);
}
public BoundTraitFuncImpl LookupTraitFuncImpl(NubType forType, string name)
{
return _definitions
.OfType<BoundTraitImpl>()
.Where(x => x.ForType == forType)
.SelectMany(x => x.Functions)
.First(x => x.Name == name);
}
public BoundTrait LookupTrait(string name)
{
return _definitions
.OfType<BoundTrait>()
.First(x => x.Name == name);
}
public BoundTraitFunc LookupTraitFunc(BoundTrait trait, string name)
{
return trait.Functions.First(x => x.Name == name);
}
public IEnumerable<BoundStruct> GetStructs()
{
return _definitions.OfType<BoundStruct>();
}
public IEnumerable<BoundTrait> GetTraits()
{
return _definitions.OfType<BoundTrait>();
}
}

View File

@@ -1,49 +1,49 @@
using System.Diagnostics;
using System.Globalization;
using NubLang.Syntax.Binding;
using NubLang.Syntax.Binding.Node;
using NubLang.Syntax.Tokenization;
using NubLang.Tokenization;
using NubLang.TypeChecking;
using NubLang.TypeChecking.Node;
namespace NubLang.Generation.QBE;
public partial class QBEGenerator
{
private Val EmitExpression(BoundExpression expression)
private Val EmitExpression(Expression expression)
{
return expression switch
{
BoundArrayInitializer arrayInitializer => EmitArrayInitializer(arrayInitializer),
BoundStructInitializer structInitializer => EmitStructInitializer(structInitializer),
BoundAddressOf addressOf => EmitAddressOf(addressOf),
BoundDereference dereference => EmitDereference(dereference),
BoundArrowFunc arrowFunc => EmitArrowFunc(arrowFunc),
BoundBinaryExpression binaryExpression => EmitBinaryExpression(binaryExpression),
BoundFuncCall funcCallExpression => EmitFuncCall(funcCallExpression),
BoundExternFuncIdent externFuncIdent => EmitExternFuncIdent(externFuncIdent),
BoundLocalFuncIdent localFuncIdent => EmitLocalFuncIdent(localFuncIdent),
BoundVariableIdent variableIdent => EmitVariableIdent(variableIdent),
BoundLiteral literal => EmitLiteral(literal),
BoundUnaryExpression unaryExpression => EmitUnaryExpression(unaryExpression),
BoundStructFieldAccess structFieldAccess => EmitStructFieldAccess(structFieldAccess),
BoundInterfaceFuncAccess traitFuncAccess => EmitTraitFuncAccess(traitFuncAccess),
BoundArrayIndexAccess arrayIndex => EmitArrayIndexAccess(arrayIndex),
ArrayInitializer arrayInitializer => EmitArrayInitializer(arrayInitializer),
StructInitializer structInitializer => EmitStructInitializer(structInitializer),
AddressOf addressOf => EmitAddressOf(addressOf),
Dereference dereference => EmitDereference(dereference),
ArrowFunc arrowFunc => EmitArrowFunc(arrowFunc),
BinaryExpression binaryExpression => EmitBinaryExpression(binaryExpression),
FuncCall funcCallExpression => EmitFuncCall(funcCallExpression),
ExternFuncIdent externFuncIdent => EmitExternFuncIdent(externFuncIdent),
LocalFuncIdent localFuncIdent => EmitLocalFuncIdent(localFuncIdent),
VariableIdent variableIdent => EmitVariableIdent(variableIdent),
Literal literal => EmitLiteral(literal),
UnaryExpression unaryExpression => EmitUnaryExpression(unaryExpression),
StructFieldAccess structFieldAccess => EmitStructFieldAccess(structFieldAccess),
InterfaceFuncAccess traitFuncAccess => EmitTraitFuncAccess(traitFuncAccess),
ArrayIndexAccess arrayIndex => EmitArrayIndexAccess(arrayIndex),
_ => throw new ArgumentOutOfRangeException(nameof(expression))
};
}
private Val EmitArrowFunc(BoundArrowFunc arrowFunc)
private Val EmitArrowFunc(ArrowFunc arrowFunc)
{
var name = $"$arrow_func{++_arrowFuncIndex}";
_arrowFunctions.Enqueue((arrowFunc, name));
return new Val(name, arrowFunc.Type, ValKind.Direct);
}
private Val EmitArrayIndexAccess(BoundArrayIndexAccess arrayIndexAccess)
private Val EmitArrayIndexAccess(ArrayIndexAccess arrayIndexAccess)
{
var array = EmitUnwrap(EmitExpression(arrayIndexAccess.Target));
var index = EmitUnwrap(EmitExpression(arrayIndexAccess.Index));
EmitArrayBoundsCheck(array, index);
EmitArraysCheck(array, index);
var elementType = ((NubArrayType)arrayIndexAccess.Target.Type).ElementType;
@@ -54,7 +54,7 @@ public partial class QBEGenerator
return new Val(pointer, arrayIndexAccess.Type, ValKind.Pointer);
}
private void EmitArrayBoundsCheck(string array, string index)
private void EmitArraysCheck(string array, string index)
{
var count = TmpName();
_writer.Indented($"{count} =l loadl {array}");
@@ -78,7 +78,7 @@ public partial class QBEGenerator
_writer.Indented(notOobLabel);
}
private Val EmitArrayInitializer(BoundArrayInitializer arrayInitializer)
private Val EmitArrayInitializer(ArrayInitializer arrayInitializer)
{
var capacity = EmitUnwrap(EmitExpression(arrayInitializer.Capacity));
var elementSize = arrayInitializer.ElementType.Size(_definitionTable);
@@ -99,12 +99,12 @@ public partial class QBEGenerator
return new Val(arrayPointer, arrayInitializer.Type, ValKind.Direct);
}
private Val EmitDereference(BoundDereference dereference)
private Val EmitDereference(Dereference dereference)
{
return EmitLoad(dereference.Type, EmitUnwrap(EmitExpression(dereference.Expression)));
}
private Val EmitAddressOf(BoundAddressOf addressOf)
private Val EmitAddressOf(AddressOf addressOf)
{
var value = EmitExpression(addressOf.Expression);
if (value.Kind != ValKind.Pointer)
@@ -115,7 +115,7 @@ public partial class QBEGenerator
return new Val(value.Name, addressOf.Type, ValKind.Direct);
}
private Val EmitBinaryExpression(BoundBinaryExpression binaryExpression)
private Val EmitBinaryExpression(BinaryExpression binaryExpression)
{
var left = EmitUnwrap(EmitExpression(binaryExpression.Left));
var right = EmitUnwrap(EmitExpression(binaryExpression.Right));
@@ -128,15 +128,15 @@ public partial class QBEGenerator
return new Val(outputName, binaryExpression.Type, ValKind.Direct);
}
private string EmitBinaryInstructionFor(BoundBinaryOperator op, NubType type, string left, string right)
private string EmitBinaryInstructionFor(BinaryOperator op, NubType type, string left, string right)
{
if (op is
BoundBinaryOperator.Equal or
BoundBinaryOperator.NotEqual or
BoundBinaryOperator.GreaterThan or
BoundBinaryOperator.GreaterThanOrEqual or
BoundBinaryOperator.LessThan or
BoundBinaryOperator.LessThanOrEqual)
BinaryOperator.Equal or
BinaryOperator.NotEqual or
BinaryOperator.GreaterThan or
BinaryOperator.GreaterThanOrEqual or
BinaryOperator.LessThan or
BinaryOperator.LessThanOrEqual)
{
char suffix;
@@ -177,12 +177,12 @@ public partial class QBEGenerator
throw new NotSupportedException($"Unsupported type '{simpleType}' for binary operator '{op}'");
}
if (op is BoundBinaryOperator.Equal)
if (op is BinaryOperator.Equal)
{
return "ceq" + suffix;
}
if (op is BoundBinaryOperator.NotEqual)
if (op is BinaryOperator.NotEqual)
{
return "cne" + suffix;
}
@@ -204,42 +204,42 @@ public partial class QBEGenerator
return op switch
{
BoundBinaryOperator.GreaterThan => 'c' + sign + "gt" + suffix,
BoundBinaryOperator.GreaterThanOrEqual => 'c' + sign + "ge" + suffix,
BoundBinaryOperator.LessThan => 'c' + sign + "lt" + suffix,
BoundBinaryOperator.LessThanOrEqual => 'c' + sign + "le" + suffix,
BinaryOperator.GreaterThan => 'c' + sign + "gt" + suffix,
BinaryOperator.GreaterThanOrEqual => 'c' + sign + "ge" + suffix,
BinaryOperator.LessThan => 'c' + sign + "lt" + suffix,
BinaryOperator.LessThanOrEqual => 'c' + sign + "le" + suffix,
_ => throw new ArgumentOutOfRangeException(nameof(op), op, null)
};
}
return op switch
{
BoundBinaryOperator.Plus => "add",
BoundBinaryOperator.Minus => "sub",
BoundBinaryOperator.Multiply => "mul",
BoundBinaryOperator.Divide => "div",
BinaryOperator.Plus => "add",
BinaryOperator.Minus => "sub",
BinaryOperator.Multiply => "mul",
BinaryOperator.Divide => "div",
_ => throw new ArgumentOutOfRangeException(nameof(op))
};
}
private Val EmitExternFuncIdent(BoundExternFuncIdent externFuncIdent)
private Val EmitExternFuncIdent(ExternFuncIdent externFuncIdent)
{
var func = _definitionTable.LookupExternFunc(externFuncIdent.Name);
return new Val(ExternFuncName(func), externFuncIdent.Type, ValKind.Direct);
}
private Val EmitLocalFuncIdent(BoundLocalFuncIdent localFuncIdent)
private Val EmitLocalFuncIdent(LocalFuncIdent localFuncIdent)
{
var func = _definitionTable.LookupLocalFunc(localFuncIdent.Name);
return new Val(LocalFuncName(func), localFuncIdent.Type, ValKind.Direct);
}
private Val EmitVariableIdent(BoundVariableIdent variableIdent)
private Val EmitVariableIdent(VariableIdent variableIdent)
{
return Scope.Lookup(variableIdent.Name);
}
private Val EmitLiteral(BoundLiteral literal)
private Val EmitLiteral(Literal literal)
{
switch (literal.Kind)
{
@@ -247,21 +247,21 @@ public partial class QBEGenerator
{
if (literal.Type is NubPrimitiveType { Kind: PrimitiveTypeKind.F32 })
{
var value = float.Parse(literal.Literal, CultureInfo.InvariantCulture);
var value = float.Parse(literal.Value, CultureInfo.InvariantCulture);
var bits = BitConverter.SingleToInt32Bits(value);
return new Val(bits.ToString(), literal.Type, ValKind.Direct);
}
if (literal.Type is NubPrimitiveType { Kind: PrimitiveTypeKind.F64 })
{
var value = double.Parse(literal.Literal, CultureInfo.InvariantCulture);
var value = double.Parse(literal.Value, CultureInfo.InvariantCulture);
var bits = BitConverter.DoubleToInt64Bits(value);
return new Val(bits.ToString(), literal.Type, ValKind.Direct);
}
if (literal.Type is NubPrimitiveType { Kind: PrimitiveTypeKind.I8 or PrimitiveTypeKind.U8 or PrimitiveTypeKind.I16 or PrimitiveTypeKind.U16 or PrimitiveTypeKind.I32 or PrimitiveTypeKind.U32 or PrimitiveTypeKind.I64 or PrimitiveTypeKind.U64 })
{
return new Val(literal.Literal, literal.Type, ValKind.Direct);
return new Val(literal.Value, literal.Type, ValKind.Direct);
}
break;
@@ -270,19 +270,19 @@ public partial class QBEGenerator
{
if (literal.Type is NubPrimitiveType { Kind: PrimitiveTypeKind.I8 or PrimitiveTypeKind.U8 or PrimitiveTypeKind.I16 or PrimitiveTypeKind.U16 or PrimitiveTypeKind.I32 or PrimitiveTypeKind.U32 or PrimitiveTypeKind.I64 or PrimitiveTypeKind.U64 })
{
return new Val(literal.Literal.Split(".").First(), literal.Type, ValKind.Direct);
return new Val(literal.Value.Split(".").First(), literal.Type, ValKind.Direct);
}
if (literal.Type is NubPrimitiveType { Kind: PrimitiveTypeKind.F32 })
{
var value = float.Parse(literal.Literal, CultureInfo.InvariantCulture);
var value = float.Parse(literal.Value, CultureInfo.InvariantCulture);
var bits = BitConverter.SingleToInt32Bits(value);
return new Val(bits.ToString(), literal.Type, ValKind.Direct);
}
if (literal.Type is NubPrimitiveType { Kind: PrimitiveTypeKind.F64 })
{
var value = double.Parse(literal.Literal, CultureInfo.InvariantCulture);
var value = double.Parse(literal.Value, CultureInfo.InvariantCulture);
var bits = BitConverter.DoubleToInt64Bits(value);
return new Val(bits.ToString(), literal.Type, ValKind.Direct);
}
@@ -293,14 +293,14 @@ public partial class QBEGenerator
{
if (literal.Type is NubStringType)
{
var stringLiteral = new StringLiteral(literal.Literal, StringName());
var stringLiteral = new StringLiteral(literal.Value, StringName());
_stringLiterals.Add(stringLiteral);
return new Val(stringLiteral.Name, literal.Type, ValKind.Direct);
}
if (literal.Type is NubCStringType)
{
var cStringLiteral = new CStringLiteral(literal.Literal, CStringName());
var cStringLiteral = new CStringLiteral(literal.Value, CStringName());
_cStringLiterals.Add(cStringLiteral);
return new Val(cStringLiteral.Name, literal.Type, ValKind.Direct);
}
@@ -311,7 +311,7 @@ public partial class QBEGenerator
{
if (literal.Type is NubPrimitiveType { Kind: PrimitiveTypeKind.Bool })
{
return new Val(bool.Parse(literal.Literal) ? "1" : "0", literal.Type, ValKind.Direct);
return new Val(bool.Parse(literal.Value) ? "1" : "0", literal.Type, ValKind.Direct);
}
break;
@@ -321,7 +321,7 @@ public partial class QBEGenerator
throw new NotSupportedException($"Cannot create literal of kind '{literal.Kind}' for type {literal.Type}");
}
private Val EmitStructInitializer(BoundStructInitializer structInitializer, string? destination = null)
private Val EmitStructInitializer(StructInitializer structInitializer, string? destination = null)
{
var @struct = _definitionTable.LookupStruct(structInitializer.StructType.Name);
@@ -349,14 +349,14 @@ public partial class QBEGenerator
return new Val(destination, structInitializer.StructType, ValKind.Direct);
}
private Val EmitUnaryExpression(BoundUnaryExpression unaryExpression)
private Val EmitUnaryExpression(UnaryExpression unaryExpression)
{
var operand = EmitUnwrap(EmitExpression(unaryExpression.Operand));
var outputName = TmpName();
switch (unaryExpression.Operator)
{
case BoundUnaryOperator.Negate:
case UnaryOperator.Negate:
{
switch (unaryExpression.Operand.Type)
{
@@ -376,7 +376,7 @@ public partial class QBEGenerator
break;
}
case BoundUnaryOperator.Invert:
case UnaryOperator.Invert:
{
switch (unaryExpression.Operand.Type)
{
@@ -396,7 +396,7 @@ public partial class QBEGenerator
throw new NotSupportedException($"Unary operator {unaryExpression.Operator} for type {unaryExpression.Operand.Type} not supported");
}
private Val EmitStructFieldAccess(BoundStructFieldAccess structFieldAccess)
private Val EmitStructFieldAccess(StructFieldAccess structFieldAccess)
{
var target = EmitUnwrap(EmitExpression(structFieldAccess.Target));
@@ -415,12 +415,12 @@ public partial class QBEGenerator
return new Val(output, structFieldAccess.Type, ValKind.Pointer);
}
private Val EmitTraitFuncAccess(BoundInterfaceFuncAccess interfaceFuncAccess)
private Val EmitTraitFuncAccess(InterfaceFuncAccess interfaceFuncAccess)
{
throw new NotImplementedException();
}
private Val EmitFuncCall(BoundFuncCall funcCall)
private Val EmitFuncCall(FuncCall funcCall)
{
var expression = EmitExpression(funcCall.Expression);
var funcPointer = EmitUnwrap(expression);

View File

@@ -1,36 +1,36 @@
using System.Diagnostics;
using NubLang.Syntax.Binding.Node;
using NubLang.TypeChecking.Node;
namespace NubLang.Generation.QBE;
public partial class QBEGenerator
{
private void EmitStatement(BoundStatement statement)
private void EmitStatement(Statement statement)
{
switch (statement)
{
case BoundAssignment assignment:
case Assignment assignment:
EmitAssignment(assignment);
break;
case BoundBreak:
case Break:
EmitBreak();
break;
case BoundContinue:
case Continue:
EmitContinue();
break;
case BoundIf ifStatement:
case If ifStatement:
EmitIf(ifStatement);
break;
case BoundReturn @return:
case Return @return:
EmitReturn(@return);
break;
case BoundStatementExpression statementExpression:
case StatementExpression statementExpression:
EmitExpression(statementExpression.Expression);
break;
case BoundVariableDeclaration variableDeclaration:
case VariableDeclaration variableDeclaration:
EmitVariableDeclaration(variableDeclaration);
break;
case BoundWhile whileStatement:
case While whileStatement:
EmitWhile(whileStatement);
break;
default:
@@ -38,7 +38,7 @@ public partial class QBEGenerator
}
}
private void EmitAssignment(BoundAssignment assignment)
private void EmitAssignment(Assignment assignment)
{
var destination = EmitExpression(assignment.Target);
Debug.Assert(destination.Kind == ValKind.Pointer);
@@ -57,7 +57,7 @@ public partial class QBEGenerator
_codeIsReachable = false;
}
private void EmitIf(BoundIf ifStatement)
private void EmitIf(If ifStatement)
{
var trueLabel = LabelName();
var falseLabel = LabelName();
@@ -81,7 +81,7 @@ public partial class QBEGenerator
_writer.WriteLine(endLabel);
}
private void EmitReturn(BoundReturn @return)
private void EmitReturn(Return @return)
{
if (@return.Value.HasValue)
{
@@ -94,7 +94,7 @@ public partial class QBEGenerator
}
}
private void EmitVariableDeclaration(BoundVariableDeclaration variableDeclaration)
private void EmitVariableDeclaration(VariableDeclaration variableDeclaration)
{
var name = $"%{variableDeclaration.Name}";
_writer.Indented($"{name} =l alloc8 8");
@@ -108,7 +108,7 @@ public partial class QBEGenerator
Scope.Declare(variableDeclaration.Name, new Val(name, variableDeclaration.Type, ValKind.Pointer));
}
private void EmitWhile(BoundWhile whileStatement)
private void EmitWhile(While whileStatement)
{
var conditionLabel = LabelName();
var iterationLabel = LabelName();

View File

@@ -1,23 +1,23 @@
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.Text;
using NubLang.Syntax.Binding;
using NubLang.Syntax.Binding.Node;
using NubLang.Syntax.Tokenization;
using NubLang.Tokenization;
using NubLang.TypeChecking;
using NubLang.TypeChecking.Node;
namespace NubLang.Generation.QBE;
public partial class QBEGenerator
{
private readonly BoundSyntaxTree _syntaxTree;
private readonly BoundDefinitionTable _definitionTable;
private readonly TypedSyntaxTree _syntaxTree;
private readonly TypedDefinitionTable _definitionTable;
private readonly QBEWriter _writer;
private readonly List<CStringLiteral> _cStringLiterals = [];
private readonly List<StringLiteral> _stringLiterals = [];
private readonly Stack<string> _breakLabels = [];
private readonly Stack<string> _continueLabels = [];
private readonly Queue<(BoundArrowFunc Func, string Name)> _arrowFunctions = [];
private readonly Queue<(ArrowFunc Func, string Name)> _arrowFunctions = [];
private readonly Stack<Scope> _scopes = [];
private int _tmpIndex;
private int _labelIndex;
@@ -28,7 +28,7 @@ public partial class QBEGenerator
private Scope Scope => _scopes.Peek();
public QBEGenerator(BoundSyntaxTree syntaxTree, BoundDefinitionTable definitionTable)
public QBEGenerator(TypedSyntaxTree syntaxTree, TypedDefinitionTable definitionTable)
{
_syntaxTree = syntaxTree;
_definitionTable = definitionTable;
@@ -62,7 +62,7 @@ public partial class QBEGenerator
_writer.NewLine();
}
foreach (var funcDef in _syntaxTree.Definitions.OfType<BoundLocalFunc>())
foreach (var funcDef in _syntaxTree.Definitions.OfType<LocalFunc>())
{
EmitFuncDefinition(LocalFuncName(funcDef), funcDef.Signature.Parameters, funcDef.Signature.ReturnType, funcDef.Body);
_writer.NewLine();
@@ -191,21 +191,21 @@ public partial class QBEGenerator
return size;
}
private bool EmitTryMoveInto(BoundExpression source, string destinationPointer)
private bool EmitTryMoveInto(Expression source, string destinationPointer)
{
switch (source)
{
case BoundArrayInitializer arrayInitializer:
case ArrayInitializer arrayInitializer:
{
EmitStore(source.Type, EmitUnwrap(EmitArrayInitializer(arrayInitializer)), destinationPointer);
return true;
}
case BoundStructInitializer structInitializer:
case StructInitializer structInitializer:
{
EmitStructInitializer(structInitializer, destinationPointer);
return true;
}
case BoundLiteral { Kind: LiteralKind.String } literal:
case Literal { Kind: LiteralKind.String } literal:
{
EmitStore(source.Type, EmitUnwrap(EmitLiteral(literal)), destinationPointer);
return true;
@@ -215,7 +215,7 @@ public partial class QBEGenerator
return false;
}
private void EmitCopyIntoOrInitialize(BoundExpression source, string destinationPointer)
private void EmitCopyIntoOrInitialize(Expression source, string destinationPointer)
{
// If the source is a value which is not used yet such as an array/struct initializer or literal, we can skip copying
if (EmitTryMoveInto(source, destinationPointer))
@@ -253,13 +253,13 @@ public partial class QBEGenerator
}
}
private bool EmitTryCreateWithoutCopy(BoundExpression source, [NotNullWhen(true)] out string? destination)
private bool EmitTryCreateWithoutCopy(Expression source, [NotNullWhen(true)] out string? destination)
{
switch (source)
{
case BoundArrayInitializer:
case BoundStructInitializer:
case BoundLiteral { Kind: LiteralKind.String }:
case ArrayInitializer:
case StructInitializer:
case Literal { Kind: LiteralKind.String }:
{
destination = EmitUnwrap(EmitExpression(source));
return true;
@@ -270,7 +270,7 @@ public partial class QBEGenerator
return false;
}
private string EmitCreateCopyOrInitialize(BoundExpression source)
private string EmitCreateCopyOrInitialize(Expression source)
{
// If the source is a value which is not used yet such as an array/struct initializer or literal, we can skip copying
if (EmitTryCreateWithoutCopy(source, out var uncopiedValue))
@@ -328,7 +328,7 @@ public partial class QBEGenerator
return "l";
}
private void EmitFuncDefinition(string name, IReadOnlyList<BoundFuncParameter> parameters, NubType returnType, BoundBlock body)
private void EmitFuncDefinition(string name, IReadOnlyList<FuncParameter> parameters, NubType returnType, Block body)
{
_labelIndex = 0;
_tmpIndex = 0;
@@ -360,7 +360,7 @@ public partial class QBEGenerator
EmitBlock(body, scope);
// Implicit return for void functions if no explicit return has been set
if (returnType is NubVoidType && body.Statements is [.., not BoundReturn])
if (returnType is NubVoidType && body.Statements is [.., not Return])
{
if (returnType is NubVoidType)
{
@@ -371,7 +371,7 @@ public partial class QBEGenerator
_writer.EndFunction();
}
private void EmitStructDefinition(BoundStruct structDef)
private void EmitStructDefinition(Struct structDef)
{
_writer.WriteLine($"type {CustomTypeName(structDef.Name)} = {{ ");
@@ -392,7 +392,7 @@ public partial class QBEGenerator
_writer.WriteLine("}");
return;
string StructDefQBEType(BoundStructField field)
string StructDefQBEType(StructField field)
{
if (field.Type.IsSimpleType(out var simpleType, out var complexType))
{
@@ -417,7 +417,7 @@ public partial class QBEGenerator
}
}
private void EmitTraitVTable(BoundTrait traitDef)
private void EmitTraitVTable(Trait traitDef)
{
_writer.WriteLine($"type {CustomTypeName(traitDef.Name)} = {{");
@@ -429,7 +429,7 @@ public partial class QBEGenerator
_writer.WriteLine("}");
}
private void EmitBlock(BoundBlock block, Scope? scope = null)
private void EmitBlock(Block block, Scope? scope = null)
{
_scopes.Push(scope ?? Scope.SubScope());
@@ -456,7 +456,7 @@ public partial class QBEGenerator
};
}
private int OffsetOf(BoundStruct structDefinition, string member)
private int OffsetOf(Struct structDefinition, string member)
{
var offset = 0;
@@ -498,12 +498,12 @@ public partial class QBEGenerator
return $"$string{++_stringLiteralIndex}";
}
private string LocalFuncName(BoundLocalFunc funcDef)
private string LocalFuncName(LocalFunc funcDef)
{
return $"${funcDef.Name}";
}
private string ExternFuncName(BoundExternFunc funcDef)
private string ExternFuncName(ExternFunc funcDef)
{
return $"${funcDef.CallName}";
}

View File

@@ -0,0 +1,78 @@
using NubLang.TypeChecking;
using NubLang.TypeChecking.Node;
namespace NubLang.Generation;
public sealed class TypedDefinitionTable
{
private readonly List<Definition> _definitions;
public TypedDefinitionTable(IEnumerable<TypedSyntaxTree> syntaxTrees)
{
_definitions = syntaxTrees.SelectMany(x => x.Definitions).ToList();
}
public LocalFunc LookupLocalFunc(string name)
{
return _definitions
.OfType<LocalFunc>()
.First(x => x.Name == name);
}
public ExternFunc LookupExternFunc(string name)
{
return _definitions
.OfType<ExternFunc>()
.First(x => x.Name == name);
}
public Struct LookupStruct(string name)
{
return _definitions
.OfType<Struct>()
.First(x => x.Name == name);
}
public StructField LookupStructField(Struct @struct, string field)
{
return @struct.Fields.First(x => x.Name == field);
}
public IEnumerable<TraitImpl> LookupTraitImpls(NubType itemType)
{
return _definitions
.OfType<TraitImpl>()
.Where(x => x.ForType == itemType);
}
public TraitFuncImpl LookupTraitFuncImpl(NubType forType, string name)
{
return _definitions
.OfType<TraitImpl>()
.Where(x => x.ForType == forType)
.SelectMany(x => x.Functions)
.First(x => x.Name == name);
}
public Trait LookupTrait(string name)
{
return _definitions
.OfType<Trait>()
.First(x => x.Name == name);
}
public TraitFunc LookupTraitFunc(Trait trait, string name)
{
return trait.Functions.First(x => x.Name == name);
}
public IEnumerable<Struct> GetStructs()
{
return _definitions.OfType<Struct>();
}
public IEnumerable<Trait> GetTraits()
{
return _definitions.OfType<Trait>();
}
}

View File

@@ -1,10 +1,10 @@
using System.Diagnostics.CodeAnalysis;
using NubLang.Common;
using NubLang.Diagnostics;
using NubLang.Syntax.Parsing.Node;
using NubLang.Syntax.Tokenization;
using NubLang.Parsing.Syntax;
using NubLang.Tokenization;
namespace NubLang.Syntax.Parsing;
namespace NubLang.Parsing;
public sealed class Parser
{
@@ -324,57 +324,57 @@ public sealed class Parser
return left;
}
private int GetBinaryOperatorPrecedence(BinaryOperator @operator)
private int GetBinaryOperatorPrecedence(BinaryOperatorSyntax operatorSyntax)
{
return @operator switch
return operatorSyntax switch
{
BinaryOperator.Multiply => 3,
BinaryOperator.Divide => 3,
BinaryOperator.Plus => 2,
BinaryOperator.Minus => 2,
BinaryOperator.GreaterThan => 1,
BinaryOperator.GreaterThanOrEqual => 1,
BinaryOperator.LessThan => 1,
BinaryOperator.LessThanOrEqual => 1,
BinaryOperator.Equal => 0,
BinaryOperator.NotEqual => 0,
_ => throw new ArgumentOutOfRangeException(nameof(@operator), @operator, null)
BinaryOperatorSyntax.Multiply => 3,
BinaryOperatorSyntax.Divide => 3,
BinaryOperatorSyntax.Plus => 2,
BinaryOperatorSyntax.Minus => 2,
BinaryOperatorSyntax.GreaterThan => 1,
BinaryOperatorSyntax.GreaterThanOrEqual => 1,
BinaryOperatorSyntax.LessThan => 1,
BinaryOperatorSyntax.LessThanOrEqual => 1,
BinaryOperatorSyntax.Equal => 0,
BinaryOperatorSyntax.NotEqual => 0,
_ => throw new ArgumentOutOfRangeException(nameof(operatorSyntax), operatorSyntax, null)
};
}
private bool TryGetBinaryOperator(Symbol symbol, [NotNullWhen(true)] out BinaryOperator? binaryExpressionOperator)
private bool TryGetBinaryOperator(Symbol symbol, [NotNullWhen(true)] out BinaryOperatorSyntax? binaryExpressionOperator)
{
switch (symbol)
{
case Symbol.Equal:
binaryExpressionOperator = BinaryOperator.Equal;
binaryExpressionOperator = BinaryOperatorSyntax.Equal;
return true;
case Symbol.NotEqual:
binaryExpressionOperator = BinaryOperator.NotEqual;
binaryExpressionOperator = BinaryOperatorSyntax.NotEqual;
return true;
case Symbol.LessThan:
binaryExpressionOperator = BinaryOperator.LessThan;
binaryExpressionOperator = BinaryOperatorSyntax.LessThan;
return true;
case Symbol.LessThanOrEqual:
binaryExpressionOperator = BinaryOperator.LessThanOrEqual;
binaryExpressionOperator = BinaryOperatorSyntax.LessThanOrEqual;
return true;
case Symbol.GreaterThan:
binaryExpressionOperator = BinaryOperator.GreaterThan;
binaryExpressionOperator = BinaryOperatorSyntax.GreaterThan;
return true;
case Symbol.GreaterThanOrEqual:
binaryExpressionOperator = BinaryOperator.GreaterThanOrEqual;
binaryExpressionOperator = BinaryOperatorSyntax.GreaterThanOrEqual;
return true;
case Symbol.Plus:
binaryExpressionOperator = BinaryOperator.Plus;
binaryExpressionOperator = BinaryOperatorSyntax.Plus;
return true;
case Symbol.Minus:
binaryExpressionOperator = BinaryOperator.Minus;
binaryExpressionOperator = BinaryOperatorSyntax.Minus;
return true;
case Symbol.Star:
binaryExpressionOperator = BinaryOperator.Multiply;
binaryExpressionOperator = BinaryOperatorSyntax.Multiply;
return true;
case Symbol.ForwardSlash:
binaryExpressionOperator = BinaryOperator.Divide;
binaryExpressionOperator = BinaryOperatorSyntax.Divide;
return true;
default:
binaryExpressionOperator = null;
@@ -393,8 +393,8 @@ public sealed class Parser
{
Symbol.Func => ParseArrowFunction(),
Symbol.OpenParen => ParseParenthesizedExpression(),
Symbol.Minus => new UnaryExpressionSyntax(UnaryOperator.Negate, ParsePrimaryExpression()),
Symbol.Bang => new UnaryExpressionSyntax(UnaryOperator.Invert, ParsePrimaryExpression()),
Symbol.Minus => new UnaryExpressionSyntax(UnaryOperatorSyntax.Negate, ParsePrimaryExpression()),
Symbol.Bang => new UnaryExpressionSyntax(UnaryOperatorSyntax.Invert, ParsePrimaryExpression()),
Symbol.OpenBracket => ParseArrayInitializer(),
Symbol.Alloc => ParseStructInitializer(),
_ => throw new ParseException(Diagnostic

View File

@@ -1,6 +1,6 @@
using NubLang.Common;
namespace NubLang.Syntax.Parsing.Node;
namespace NubLang.Parsing.Syntax;
public abstract record DefinitionSyntax : SyntaxNode;

View File

@@ -1,15 +1,14 @@
using NubLang.Common;
using NubLang.Syntax.Tokenization;
using NubLang.Tokenization;
namespace NubLang.Syntax.Parsing.Node;
namespace NubLang.Parsing.Syntax;
public enum UnaryOperator
public enum UnaryOperatorSyntax
{
Negate,
Invert
}
public enum BinaryOperator
public enum BinaryOperatorSyntax
{
Equal,
NotEqual,
@@ -25,7 +24,7 @@ public enum BinaryOperator
public abstract record ExpressionSyntax : SyntaxNode;
public record BinaryExpressionSyntax(ExpressionSyntax Left, BinaryOperator Operator, ExpressionSyntax Right) : ExpressionSyntax
public record BinaryExpressionSyntax(ExpressionSyntax Left, BinaryOperatorSyntax OperatorSyntax, ExpressionSyntax Right) : ExpressionSyntax
{
public override IEnumerable<SyntaxNode> GetChildren()
{
@@ -34,7 +33,7 @@ public record BinaryExpressionSyntax(ExpressionSyntax Left, BinaryOperator Opera
}
}
public record UnaryExpressionSyntax(UnaryOperator Operator, ExpressionSyntax Operand) : ExpressionSyntax
public record UnaryExpressionSyntax(UnaryOperatorSyntax OperatorSyntax, ExpressionSyntax Operand) : ExpressionSyntax
{
public override IEnumerable<SyntaxNode> GetChildren()
{

View File

@@ -1,6 +1,6 @@
using NubLang.Common;
namespace NubLang.Syntax.Parsing.Node;
namespace NubLang.Parsing.Syntax;
public abstract record StatementSyntax : SyntaxNode;

View File

@@ -1,4 +1,4 @@
namespace NubLang.Syntax.Parsing.Node;
namespace NubLang.Parsing.Syntax;
public abstract record SyntaxNode
{

View File

@@ -1,4 +1,4 @@
namespace NubLang.Syntax.Parsing.Node;
namespace NubLang.Parsing.Syntax;
public enum PrimitiveTypeSyntaxKind
{

View File

@@ -1,637 +0,0 @@
using NubLang.Common;
using NubLang.Diagnostics;
using NubLang.Syntax.Binding.Node;
using NubLang.Syntax.Parsing.Node;
using NubLang.Syntax.Tokenization;
namespace NubLang.Syntax.Binding;
public sealed class Binder
{
private readonly SyntaxTree _syntaxTree;
private readonly DefinitionTable _definitionTable;
private readonly Stack<Scope> _scopes = [];
private readonly Stack<NubType> _funcReturnTypes = [];
private Scope Scope => _scopes.Peek();
public Binder(SyntaxTree syntaxTree, DefinitionTable definitionTable)
{
_syntaxTree = syntaxTree;
_definitionTable = definitionTable;
}
public BoundSyntaxTree Bind()
{
_funcReturnTypes.Clear();
_scopes.Clear();
var diagnostics = new List<Diagnostic>();
var definitions = new List<BoundDefinition>();
foreach (var definition in _syntaxTree.Definitions)
{
try
{
definitions.Add(BindDefinition(definition));
}
catch (BindException e)
{
diagnostics.Add(e.Diagnostic);
}
}
return new BoundSyntaxTree(definitions, diagnostics);
}
private BoundDefinition BindDefinition(DefinitionSyntax node)
{
return node switch
{
ExternFuncSyntax definition => BindExternFuncDefinition(definition),
InterfaceSyntax definition => BindTraitDefinition(definition),
LocalFuncSyntax definition => BindLocalFuncDefinition(definition),
StructSyntax definition => BindStruct(definition),
_ => throw new ArgumentOutOfRangeException(nameof(node))
};
}
private BoundTrait BindTraitDefinition(InterfaceSyntax node)
{
var functions = new List<BoundTraitFunc>();
foreach (var function in node.Functions)
{
functions.Add(new BoundTraitFunc(function.Name, BindFuncSignature(function.Signature)));
}
return new BoundTrait(node.Name, functions);
}
private BoundStruct BindStruct(StructSyntax node)
{
var structFields = new List<BoundStructField>();
foreach (var field in node.Fields)
{
var value = Optional.Empty<BoundExpression>();
if (field.Value.HasValue)
{
value = BindExpression(field.Value.Value, BindType(field.Type));
}
structFields.Add(new BoundStructField(field.Index, field.Name, BindType(field.Type), value));
}
return new BoundStruct(node.Name, structFields);
}
private BoundExternFunc BindExternFuncDefinition(ExternFuncSyntax node)
{
return new BoundExternFunc(node.Name, node.CallName, BindFuncSignature(node.Signature));
}
private BoundLocalFunc BindLocalFuncDefinition(LocalFuncSyntax node)
{
var signature = BindFuncSignature(node.Signature);
var body = BindFuncBody(node.Body, signature.ReturnType, signature.Parameters);
return new BoundLocalFunc(node.Name, signature, body);
}
private BoundStatement BindStatement(StatementSyntax node)
{
return node switch
{
AssignmentSyntax statement => BindAssignment(statement),
BreakSyntax => new BoundBreak(),
ContinueSyntax => new BoundContinue(),
IfSyntax statement => BindIf(statement),
ReturnSyntax statement => BindReturn(statement),
StatementExpressionSyntax statement => BindStatementExpression(statement),
VariableDeclarationSyntax statement => BindVariableDeclaration(statement),
WhileSyntax statement => BindWhile(statement),
_ => throw new ArgumentOutOfRangeException(nameof(node))
};
}
private BoundStatement BindAssignment(AssignmentSyntax statement)
{
var expression = BindExpression(statement.Target);
var value = BindExpression(statement.Value, expression.Type);
return new BoundAssignment(expression, value);
}
private BoundIf BindIf(IfSyntax statement)
{
var elseStatement = Optional.Empty<Variant<BoundIf, BoundBlock>>();
if (statement.Else.HasValue)
{
elseStatement = statement.Else.Value.Match<Variant<BoundIf, BoundBlock>>
(
elseIf => BindIf(elseIf),
@else => BindBlock(@else)
);
}
return new BoundIf(BindExpression(statement.Condition, new NubPrimitiveType(PrimitiveTypeKind.Bool)), BindBlock(statement.Body), elseStatement);
}
private BoundReturn BindReturn(ReturnSyntax statement)
{
var value = Optional.Empty<BoundExpression>();
if (statement.Value.HasValue)
{
value = BindExpression(statement.Value.Value, _funcReturnTypes.Peek());
}
return new BoundReturn(value);
}
private BoundStatementExpression BindStatementExpression(StatementExpressionSyntax statement)
{
return new BoundStatementExpression(BindExpression(statement.Expression));
}
private BoundVariableDeclaration BindVariableDeclaration(VariableDeclarationSyntax statement)
{
NubType? type = null;
if (statement.ExplicitType.HasValue)
{
type = BindType(statement.ExplicitType.Value);
}
var assignment = Optional<BoundExpression>.Empty();
if (statement.Assignment.HasValue)
{
var boundValue = BindExpression(statement.Assignment.Value, type);
assignment = boundValue;
type = boundValue.Type;
}
if (type == null)
{
throw new NotImplementedException("Diagnostics not implemented");
}
Scope.Declare(new Variable(statement.Name, type));
return new BoundVariableDeclaration(statement.Name, assignment, type);
}
private BoundWhile BindWhile(WhileSyntax statement)
{
return new BoundWhile(BindExpression(statement.Condition, new NubPrimitiveType(PrimitiveTypeKind.Bool)), BindBlock(statement.Body));
}
private BoundExpression BindExpression(ExpressionSyntax node, NubType? expectedType = null)
{
return node switch
{
AddressOfSyntax expression => BindAddressOf(expression),
ArrowFuncSyntax expression => BindArrowFunc(expression, expectedType),
ArrayIndexAccessSyntax expression => BindArrayIndexAccess(expression),
ArrayInitializerSyntax expression => BindArrayInitializer(expression),
BinaryExpressionSyntax expression => BindBinaryExpression(expression),
DereferenceSyntax expression => BindDereference(expression),
FuncCallSyntax expression => BindFuncCall(expression),
IdentifierSyntax expression => BindIdentifier(expression),
LiteralSyntax expression => BindLiteral(expression, expectedType),
MemberAccessSyntax expression => BindMemberAccess(expression),
StructInitializerSyntax expression => BindStructInitializer(expression),
UnaryExpressionSyntax expression => BindUnaryExpression(expression),
_ => throw new ArgumentOutOfRangeException(nameof(node))
};
}
private BoundAddressOf BindAddressOf(AddressOfSyntax expression)
{
var inner = BindExpression(expression.Expression);
return new BoundAddressOf(new NubPointerType(inner.Type), inner);
}
private BoundArrowFunc BindArrowFunc(ArrowFuncSyntax expression, NubType? expectedType = null)
{
if (expectedType == null)
{
throw new BindException(Diagnostic.Error("Cannot infer argument types for arrow function").Build());
}
if (expectedType is not NubFuncType funcType)
{
throw new BindException(Diagnostic.Error($"Expected {expectedType}, but got arrow function").Build());
}
var parameters = new List<BoundFuncParameter>();
for (var i = 0; i < expression.Parameters.Count; i++)
{
if (i >= funcType.Parameters.Count)
{
throw new BindException(Diagnostic.Error($"Arrow function expected a maximum of {funcType.Parameters.Count} arguments").Build());
}
var expectedParameterType = funcType.Parameters[i];
var parameter = expression.Parameters[i];
parameters.Add(new BoundFuncParameter(parameter.Name, expectedParameterType));
}
var body = BindFuncBody(expression.Body, funcType.ReturnType, parameters);
return new BoundArrowFunc(new NubFuncType(parameters.Select(x => x.Type).ToList(), funcType.ReturnType), parameters, funcType.ReturnType, body);
}
private BoundArrayIndexAccess BindArrayIndexAccess(ArrayIndexAccessSyntax expression)
{
var boundArray = BindExpression(expression.Target);
var elementType = ((NubArrayType)boundArray.Type).ElementType;
return new BoundArrayIndexAccess(elementType, boundArray, BindExpression(expression.Index, new NubPrimitiveType(PrimitiveTypeKind.U64)));
}
private BoundArrayInitializer BindArrayInitializer(ArrayInitializerSyntax expression)
{
var capacity = BindExpression(expression.Capacity, new NubPrimitiveType(PrimitiveTypeKind.U64));
var type = new NubArrayType(BindType(expression.ElementType));
return new BoundArrayInitializer(type, capacity, BindType(expression.ElementType));
}
private BoundBinaryExpression BindBinaryExpression(BinaryExpressionSyntax expression)
{
var boundLeft = BindExpression(expression.Left);
var boundRight = BindExpression(expression.Right, boundLeft.Type);
return new BoundBinaryExpression(boundLeft.Type, boundLeft, BindBinaryOperator(expression.Operator), boundRight);
}
private BoundDereference BindDereference(DereferenceSyntax expression)
{
var boundExpression = BindExpression(expression.Expression);
var dereferencedType = ((NubPointerType)boundExpression.Type).BaseType;
return new BoundDereference(dereferencedType, boundExpression);
}
private BoundFuncCall BindFuncCall(FuncCallSyntax expression)
{
var boundExpression = BindExpression(expression.Expression);
var funcType = (NubFuncType)boundExpression.Type;
var parameters = new List<BoundExpression>();
foreach (var (i, parameter) in expression.Parameters.Index())
{
if (i >= funcType.Parameters.Count)
{
throw new NotImplementedException("Diagnostics not implemented");
}
var expectedType = funcType.Parameters[i];
parameters.Add(BindExpression(parameter, expectedType));
}
return new BoundFuncCall(funcType.ReturnType, boundExpression, parameters);
}
private BoundExpression BindIdentifier(IdentifierSyntax expression)
{
var variable = Scope.Lookup(expression.Name);
if (variable != null)
{
return new BoundVariableIdent(variable.Type, variable.Name);
}
var localFuncs = _definitionTable.LookupLocalFunc(expression.Name).ToArray();
if (localFuncs.Length > 0)
{
if (localFuncs.Length > 1)
{
throw new BindException(Diagnostic.Error($"Extern func {expression.Name} has multiple definitions").Build());
}
var localFunc = localFuncs[0];
var returnType = BindType(localFunc.Signature.ReturnType);
var parameterTypes = localFunc.Signature.Parameters.Select(p => BindType(p.Type)).ToList();
var type = new NubFuncType(parameterTypes, returnType);
return new BoundLocalFuncIdent(type, expression.Name);
}
var externFuncs = _definitionTable.LookupExternFunc(expression.Name).ToArray();
if (externFuncs.Length > 0)
{
if (externFuncs.Length > 1)
{
throw new BindException(Diagnostic.Error($"Extern func {expression.Name} has multiple definitions").Build());
}
var externFunc = externFuncs[0];
var returnType = BindType(externFunc.Signature.ReturnType);
var parameterTypes = externFunc.Signature.Parameters.Select(p => BindType(p.Type)).ToList();
var type = new NubFuncType(parameterTypes, returnType);
return new BoundExternFuncIdent(type, expression.Name);
}
throw new BindException(Diagnostic.Error($"No identifier with the name {expression.Name} exists").Build());
}
private BoundLiteral BindLiteral(LiteralSyntax expression, NubType? expectedType = null)
{
var type = expectedType ?? expression.Kind switch
{
LiteralKind.Integer => new NubPrimitiveType(PrimitiveTypeKind.I64),
LiteralKind.Float => new NubPrimitiveType(PrimitiveTypeKind.F64),
LiteralKind.String => new NubStringType(),
LiteralKind.Bool => new NubPrimitiveType(PrimitiveTypeKind.Bool),
_ => throw new ArgumentOutOfRangeException()
};
return new BoundLiteral(type, expression.Value, expression.Kind);
}
private BoundExpression BindMemberAccess(MemberAccessSyntax expression)
{
var boundExpression = BindExpression(expression.Target);
if (boundExpression.Type is NubCustomType customType)
{
var traits = _definitionTable.LookupTrait(customType).ToArray();
if (traits.Length > 0)
{
if (traits.Length > 1)
{
throw new BindException(Diagnostic.Error($"Trait {customType} has multiple definitions").Build());
}
var trait = traits[0];
var traitFuncs = _definitionTable.LookupTraitFunc(trait, expression.Member).ToArray();
if (traits.Length > 0)
{
if (traits.Length > 1)
{
throw new BindException(Diagnostic.Error($"Trait {customType} has multiple functions with the name {expression.Member}").Build());
}
var traitFunc = traitFuncs[0];
var returnType = BindType(traitFunc.Signature.ReturnType);
var parameterTypes = traitFunc.Signature.Parameters.Select(p => BindType(p.Type)).ToList();
var type = new NubFuncType(parameterTypes, returnType);
return new BoundInterfaceFuncAccess(type, customType, boundExpression, expression.Member);
}
}
var structs = _definitionTable.LookupStruct(customType).ToArray();
if (structs.Length > 0)
{
if (structs.Length > 1)
{
throw new BindException(Diagnostic.Error($"Struct {customType} has multiple definitions").Build());
}
var @struct = structs[0];
var fields = _definitionTable.LookupStructField(@struct, expression.Member).ToArray();
if (fields.Length > 0)
{
if (fields.Length > 1)
{
throw new BindException(Diagnostic.Error($"Struct {customType} has multiple fields with the name {expression.Member}").Build());
}
var field = fields[0];
return new BoundStructFieldAccess(BindType(field.Type), customType, boundExpression, expression.Member);
}
}
}
throw new BindException(Diagnostic.Error($"{boundExpression.Type} does not have a member with the name {expression.Member}").Build());
}
private BoundStructInitializer BindStructInitializer(StructInitializerSyntax expression)
{
var boundType = BindType(expression.StructType);
if (boundType is not NubCustomType structType)
{
throw new BindException(Diagnostic.Error($"Cannot initialize non-struct type {expression.StructType}").Build());
}
var structs = _definitionTable.LookupStruct(structType).ToArray();
if (structs.Length == 0)
{
throw new BindException(Diagnostic.Error($"Struct {structType} is not defined").Build());
}
if (structs.Length > 1)
{
throw new BindException(Diagnostic.Error($"Struct {structType} has multiple definitions").Build());
}
var @struct = structs[0];
var initializers = new Dictionary<string, BoundExpression>();
foreach (var (field, initializer) in expression.Initializers)
{
var fields = _definitionTable.LookupStructField(@struct, field).ToArray();
if (fields.Length == 0)
{
throw new BindException(Diagnostic.Error($"Struct {structType} does not have a field with the name {field}").Build());
}
if (fields.Length > 1)
{
throw new BindException(Diagnostic.Error($"Struct {structType} has multiple fields with the name {field}").Build());
}
initializers[field] = BindExpression(initializer, BindType(fields[0].Type));
}
return new BoundStructInitializer(structType, initializers);
}
private BoundUnaryExpression BindUnaryExpression(UnaryExpressionSyntax expression)
{
var boundOperand = BindExpression(expression.Operand);
NubType? type = null;
switch (expression.Operator)
{
case UnaryOperator.Negate:
{
boundOperand = BindExpression(expression.Operand, new NubPrimitiveType(PrimitiveTypeKind.I64));
if (boundOperand.Type.IsNumber)
{
type = boundOperand.Type;
}
break;
}
case UnaryOperator.Invert:
{
boundOperand = BindExpression(expression.Operand, new NubPrimitiveType(PrimitiveTypeKind.Bool));
type = new NubPrimitiveType(PrimitiveTypeKind.Bool);
break;
}
}
if (type == null)
{
throw new NotImplementedException("Diagnostics not implemented");
}
return new BoundUnaryExpression(type, BindBinaryOperator(expression.Operator), boundOperand);
}
private BoundFuncSignature BindFuncSignature(FuncSignatureSyntax node)
{
var parameters = new List<BoundFuncParameter>();
foreach (var parameter in node.Parameters)
{
parameters.Add(new BoundFuncParameter(parameter.Name, BindType(parameter.Type)));
}
return new BoundFuncSignature(parameters, BindType(node.ReturnType));
}
private BoundBinaryOperator BindBinaryOperator(BinaryOperator op)
{
return op switch
{
BinaryOperator.Equal => BoundBinaryOperator.Equal,
BinaryOperator.NotEqual => BoundBinaryOperator.NotEqual,
BinaryOperator.GreaterThan => BoundBinaryOperator.GreaterThan,
BinaryOperator.GreaterThanOrEqual => BoundBinaryOperator.GreaterThanOrEqual,
BinaryOperator.LessThan => BoundBinaryOperator.LessThan,
BinaryOperator.LessThanOrEqual => BoundBinaryOperator.LessThanOrEqual,
BinaryOperator.Plus => BoundBinaryOperator.Plus,
BinaryOperator.Minus => BoundBinaryOperator.Minus,
BinaryOperator.Multiply => BoundBinaryOperator.Multiply,
BinaryOperator.Divide => BoundBinaryOperator.Divide,
_ => throw new ArgumentOutOfRangeException(nameof(op), op, null)
};
}
private BoundUnaryOperator BindBinaryOperator(UnaryOperator op)
{
return op switch
{
UnaryOperator.Negate => BoundUnaryOperator.Negate,
UnaryOperator.Invert => BoundUnaryOperator.Invert,
_ => throw new ArgumentOutOfRangeException(nameof(op), op, null)
};
}
private BoundBlock BindBlock(BlockSyntax node, Scope? scope = null)
{
var statements = new List<BoundStatement>();
_scopes.Push(scope ?? Scope.SubScope());
foreach (var statement in node.Statements)
{
statements.Add(BindStatement(statement));
}
_scopes.Pop();
return new BoundBlock(statements);
}
private BoundBlock BindFuncBody(BlockSyntax block, NubType returnType, IReadOnlyList<BoundFuncParameter> parameters)
{
_funcReturnTypes.Push(returnType);
var scope = new Scope();
foreach (var parameter in parameters)
{
scope.Declare(new Variable(parameter.Name, parameter.Type));
}
var body = BindBlock(block, scope);
_funcReturnTypes.Pop();
return body;
}
private NubType BindType(TypeSyntax node)
{
return node switch
{
ArrayTypeSyntax type => new NubArrayType(BindType(type.BaseType)),
CStringTypeSyntax => new NubCStringType(),
CustomTypeSyntax type => new NubCustomType(type.MangledName()),
FuncTypeSyntax type => new NubFuncType(type.Parameters.Select(BindType).ToList(), BindType(type.ReturnType)),
PointerTypeSyntax type => new NubPointerType(BindType(type.BaseType)),
PrimitiveTypeSyntax type => new NubPrimitiveType(type.SyntaxKind switch
{
PrimitiveTypeSyntaxKind.I64 => PrimitiveTypeKind.I64,
PrimitiveTypeSyntaxKind.I32 => PrimitiveTypeKind.I32,
PrimitiveTypeSyntaxKind.I16 => PrimitiveTypeKind.I16,
PrimitiveTypeSyntaxKind.I8 => PrimitiveTypeKind.I8,
PrimitiveTypeSyntaxKind.U64 => PrimitiveTypeKind.U64,
PrimitiveTypeSyntaxKind.U32 => PrimitiveTypeKind.U32,
PrimitiveTypeSyntaxKind.U16 => PrimitiveTypeKind.U16,
PrimitiveTypeSyntaxKind.U8 => PrimitiveTypeKind.U8,
PrimitiveTypeSyntaxKind.F64 => PrimitiveTypeKind.F64,
PrimitiveTypeSyntaxKind.F32 => PrimitiveTypeKind.F32,
PrimitiveTypeSyntaxKind.Bool => PrimitiveTypeKind.Bool,
_ => throw new ArgumentOutOfRangeException()
}),
StringTypeSyntax => new NubStringType(),
VoidTypeSyntax => new NubVoidType(),
_ => throw new ArgumentOutOfRangeException(nameof(node))
};
}
}
public record Variable(string Name, NubType Type);
public class Scope(Scope? parent = null)
{
private readonly List<Variable> _variables = [];
public Variable? Lookup(string name)
{
var variable = _variables.FirstOrDefault(x => x.Name == name);
if (variable != null)
{
return variable;
}
return parent?.Lookup(name);
}
public void Declare(Variable variable)
{
_variables.Add(variable);
}
public Scope SubScope()
{
return new Scope(this);
}
}
public class BindException : Exception
{
public Diagnostic Diagnostic { get; }
public BindException(Diagnostic diagnostic) : base(diagnostic.Message)
{
Diagnostic = diagnostic;
}
}

View File

@@ -1,25 +0,0 @@
using NubLang.Common;
namespace NubLang.Syntax.Binding.Node;
public abstract record BoundDefinition : BoundNode;
public record BoundFuncParameter(string Name, NubType Type) : BoundNode;
public record BoundFuncSignature(IReadOnlyList<BoundFuncParameter> Parameters, NubType ReturnType) : BoundNode;
public record BoundLocalFunc(string Name, BoundFuncSignature Signature, BoundBlock Body) : BoundDefinition;
public record BoundExternFunc(string Name, string CallName, BoundFuncSignature Signature) : BoundDefinition;
public record BoundStructField(int Index, string Name, NubType Type, Optional<BoundExpression> Value) : BoundNode;
public record BoundStruct(string Name, IReadOnlyList<BoundStructField> Fields) : BoundDefinition;
public record BoundTraitFunc(string Name, BoundFuncSignature Signature) : BoundNode;
public record BoundTrait(string Name, IReadOnlyList<BoundTraitFunc> Functions) : BoundDefinition;
public record BoundTraitFuncImpl(string Name, BoundFuncSignature Signature, BoundBlock Body) : BoundNode;
public record BoundTraitImpl(NubType TraitType, NubType ForType, IReadOnlyList<BoundTraitFuncImpl> Functions) : BoundDefinition;

View File

@@ -1,55 +0,0 @@
using NubLang.Syntax.Tokenization;
namespace NubLang.Syntax.Binding.Node;
public enum BoundUnaryOperator
{
Negate,
Invert
}
public enum BoundBinaryOperator
{
Equal,
NotEqual,
GreaterThan,
GreaterThanOrEqual,
LessThan,
LessThanOrEqual,
Plus,
Minus,
Multiply,
Divide
}
public abstract record BoundExpression(NubType Type) : BoundNode;
public record BoundBinaryExpression(NubType Type, BoundExpression Left, BoundBinaryOperator Operator, BoundExpression Right) : BoundExpression(Type);
public record BoundUnaryExpression(NubType Type, BoundUnaryOperator Operator, BoundExpression Operand) : BoundExpression(Type);
public record BoundFuncCall(NubType Type, BoundExpression Expression, IReadOnlyList<BoundExpression> Parameters) : BoundExpression(Type);
public record BoundVariableIdent(NubType Type, string Name) : BoundExpression(Type);
public record BoundLocalFuncIdent(NubType Type, string Name) : BoundExpression(Type);
public record BoundExternFuncIdent(NubType Type, string Name) : BoundExpression(Type);
public record BoundArrayInitializer(NubType Type, BoundExpression Capacity, NubType ElementType) : BoundExpression(Type);
public record BoundArrayIndexAccess(NubType Type, BoundExpression Target, BoundExpression Index) : BoundExpression(Type);
public record BoundArrowFunc(NubType Type, IReadOnlyList<BoundFuncParameter> Parameters, NubType ReturnType, BoundBlock Body) : BoundExpression(Type);
public record BoundAddressOf(NubType Type, BoundExpression Expression) : BoundExpression(Type);
public record BoundLiteral(NubType Type, string Literal, LiteralKind Kind) : BoundExpression(Type);
public record BoundStructFieldAccess(NubType Type, NubCustomType StructType, BoundExpression Target, string Field) : BoundExpression(Type);
public record BoundInterfaceFuncAccess(NubType Type, NubCustomType InterfaceType, BoundExpression Target, string FuncName) : BoundExpression(Type);
public record BoundStructInitializer(NubCustomType StructType, Dictionary<string, BoundExpression> Initializers) : BoundExpression(StructType);
public record BoundDereference(NubType Type, BoundExpression Expression) : BoundExpression(Type);

View File

@@ -1,21 +0,0 @@
using NubLang.Common;
namespace NubLang.Syntax.Binding.Node;
public record BoundStatement : BoundNode;
public record BoundStatementExpression(BoundExpression Expression) : BoundStatement;
public record BoundReturn(Optional<BoundExpression> Value) : BoundStatement;
public record BoundAssignment(BoundExpression Target, BoundExpression Value) : BoundStatement;
public record BoundIf(BoundExpression Condition, BoundBlock Body, Optional<Variant<BoundIf, BoundBlock>> Else) : BoundStatement;
public record BoundVariableDeclaration(string Name, Optional<BoundExpression> Assignment, NubType Type) : BoundStatement;
public record BoundContinue : BoundStatement;
public record BoundBreak : BoundStatement;
public record BoundWhile(BoundExpression Condition, BoundBlock Body) : BoundStatement;

View File

@@ -1,9 +0,0 @@
using NubLang.Diagnostics;
namespace NubLang.Syntax.Binding.Node;
public record BoundSyntaxTree(IReadOnlyList<BoundDefinition> Definitions, IReadOnlyList<Diagnostic> Diagnostics);
public abstract record BoundNode;
public record BoundBlock(IReadOnlyList<BoundStatement> Statements) : BoundNode;

View File

@@ -1,4 +1,4 @@
namespace NubLang.Syntax.Tokenization;
namespace NubLang.Tokenization;
public abstract class Token;

View File

@@ -1,6 +1,6 @@
using NubLang.Common;
namespace NubLang.Syntax.Tokenization;
namespace NubLang.Tokenization;
public sealed class Tokenizer
{

View File

@@ -1,6 +1,6 @@
using NubLang.Syntax.Parsing.Node;
using NubLang.Parsing.Syntax;
namespace NubLang.Syntax.Binding;
namespace NubLang.TypeChecking;
public class DefinitionTable
{

View File

@@ -0,0 +1,25 @@
using NubLang.Common;
namespace NubLang.TypeChecking.Node;
public abstract record Definition : Node;
public record FuncParameter(string Name, NubType Type) : Node;
public record FuncSignature(IReadOnlyList<FuncParameter> Parameters, NubType ReturnType) : Node;
public record LocalFunc(string Name, FuncSignature Signature, Block Body) : Definition;
public record ExternFunc(string Name, string CallName, FuncSignature Signature) : Definition;
public record StructField(int Index, string Name, NubType Type, Optional<Expression> Value) : Node;
public record Struct(string Name, IReadOnlyList<StructField> Fields) : Definition;
public record TraitFunc(string Name, FuncSignature Signature) : Node;
public record Trait(string Name, IReadOnlyList<TraitFunc> Functions) : Definition;
public record TraitFuncImpl(string Name, FuncSignature Signature, Block Body) : Node;
public record TraitImpl(NubType TraitType, NubType ForType, IReadOnlyList<TraitFuncImpl> Functions) : Definition;

View File

@@ -0,0 +1,55 @@
using NubLang.Tokenization;
namespace NubLang.TypeChecking.Node;
public enum UnaryOperator
{
Negate,
Invert
}
public enum BinaryOperator
{
Equal,
NotEqual,
GreaterThan,
GreaterThanOrEqual,
LessThan,
LessThanOrEqual,
Plus,
Minus,
Multiply,
Divide
}
public abstract record Expression(NubType Type) : Node;
public record BinaryExpression(NubType Type, Expression Left, BinaryOperator Operator, Expression Right) : Expression(Type);
public record UnaryExpression(NubType Type, UnaryOperator Operator, Expression Operand) : Expression(Type);
public record FuncCall(NubType Type, Expression Expression, IReadOnlyList<Expression> Parameters) : Expression(Type);
public record VariableIdent(NubType Type, string Name) : Expression(Type);
public record LocalFuncIdent(NubType Type, string Name) : Expression(Type);
public record ExternFuncIdent(NubType Type, string Name) : Expression(Type);
public record ArrayInitializer(NubType Type, Expression Capacity, NubType ElementType) : Expression(Type);
public record ArrayIndexAccess(NubType Type, Expression Target, Expression Index) : Expression(Type);
public record ArrowFunc(NubType Type, IReadOnlyList<FuncParameter> Parameters, NubType ReturnType, Block Body) : Expression(Type);
public record AddressOf(NubType Type, Expression Expression) : Expression(Type);
public record Literal(NubType Type, string Value, LiteralKind Kind) : Expression(Type);
public record StructFieldAccess(NubType Type, NubCustomType StructType, Expression Target, string Field) : Expression(Type);
public record InterfaceFuncAccess(NubType Type, NubCustomType InterfaceType, Expression Target, string FuncName) : Expression(Type);
public record StructInitializer(NubCustomType StructType, Dictionary<string, Expression> Initializers) : Expression(StructType);
public record Dereference(NubType Type, Expression Expression) : Expression(Type);

View File

@@ -0,0 +1,21 @@
using NubLang.Common;
namespace NubLang.TypeChecking.Node;
public record Statement : Node;
public record StatementExpression(Expression Expression) : Statement;
public record Return(Optional<Expression> Value) : Statement;
public record Assignment(Expression Target, Expression Value) : Statement;
public record If(Expression Condition, Block Body, Optional<Variant<If, Block>> Else) : Statement;
public record VariableDeclaration(string Name, Optional<Expression> Assignment, NubType Type) : Statement;
public record Continue : Statement;
public record Break : Statement;
public record While(Expression Condition, Block Body) : Statement;

View File

@@ -0,0 +1,7 @@
namespace NubLang.TypeChecking.Node;
public record TypedSyntaxTree(IReadOnlyList<Definition> Definitions);
public abstract record Node;
public record Block(IReadOnlyList<Statement> Statements) : Node;

View File

@@ -1,7 +1,7 @@
using System.Diagnostics.CodeAnalysis;
using NubLang.Generation;
namespace NubLang.Syntax.Binding;
namespace NubLang.TypeChecking;
public abstract class NubType : IEquatable<NubType>
{
@@ -38,8 +38,8 @@ public abstract class NubType : IEquatable<NubType>
throw new ArgumentException($"Type {this} is not a simple type nor a compex type");
}
public abstract int Size(BoundDefinitionTable definitionTable);
public abstract int Alignment(BoundDefinitionTable definitionTable);
public abstract int Size(TypedDefinitionTable definitionTable);
public abstract int Alignment(TypedDefinitionTable definitionTable);
public static int AlignTo(int offset, int alignment)
{
@@ -75,7 +75,7 @@ public abstract class NubSimpleType : NubType
{
public abstract StorageSize StorageSize { get; }
public override int Size(BoundDefinitionTable definitionTable)
public override int Size(TypedDefinitionTable definitionTable)
{
return StorageSize switch
{
@@ -87,7 +87,7 @@ public abstract class NubSimpleType : NubType
};
}
public override int Alignment(BoundDefinitionTable definitionTable)
public override int Alignment(TypedDefinitionTable definitionTable)
{
return Size(definitionTable);
}
@@ -208,8 +208,8 @@ public abstract class NubComplexType : NubType;
public class NubCStringType : NubComplexType
{
public override int Size(BoundDefinitionTable definitionTable) => 8;
public override int Alignment(BoundDefinitionTable definitionTable) => Size(definitionTable);
public override int Size(TypedDefinitionTable definitionTable) => 8;
public override int Alignment(TypedDefinitionTable definitionTable) => Size(definitionTable);
public override string ToString() => "cstring";
public override bool Equals(NubType? other) => other is NubCStringType;
@@ -218,8 +218,8 @@ public class NubCStringType : NubComplexType
public class NubStringType : NubComplexType
{
public override int Size(BoundDefinitionTable definitionTable) => 8;
public override int Alignment(BoundDefinitionTable definitionTable) => Size(definitionTable);
public override int Size(TypedDefinitionTable definitionTable) => 8;
public override int Alignment(TypedDefinitionTable definitionTable) => Size(definitionTable);
public override string ToString() => "string";
public override bool Equals(NubType? other) => other is NubStringType;
@@ -230,7 +230,7 @@ public class NubCustomType(string name) : NubComplexType
{
public string Name { get; } = name;
public CustomTypeKind Kind(BoundDefinitionTable definitionTable)
public CustomTypeKind Kind(TypedDefinitionTable definitionTable)
{
if (definitionTable.GetStructs().Any(x => x.Name == Name))
{
@@ -245,7 +245,7 @@ public class NubCustomType(string name) : NubComplexType
throw new ArgumentException($"Definition table does not have any type information for {this}");
}
public override int Size(BoundDefinitionTable definitionTable)
public override int Size(TypedDefinitionTable definitionTable)
{
switch (Kind(definitionTable))
{
@@ -275,7 +275,7 @@ public class NubCustomType(string name) : NubComplexType
}
}
public override int Alignment(BoundDefinitionTable definitionTable)
public override int Alignment(TypedDefinitionTable definitionTable)
{
switch (Kind(definitionTable))
{
@@ -303,8 +303,8 @@ public class NubArrayType(NubType elementType) : NubComplexType
{
public NubType ElementType { get; } = elementType;
public override int Size(BoundDefinitionTable definitionTable) => 8;
public override int Alignment(BoundDefinitionTable definitionTable) => Size(definitionTable);
public override int Size(TypedDefinitionTable definitionTable) => 8;
public override int Alignment(TypedDefinitionTable definitionTable) => Size(definitionTable);
public override string ToString() => "[]" + ElementType;

View File

@@ -0,0 +1,640 @@
using NubLang.Common;
using NubLang.Diagnostics;
using NubLang.Parsing.Syntax;
using NubLang.Tokenization;
using NubLang.TypeChecking.Node;
namespace NubLang.TypeChecking;
public sealed class TypeChecker
{
private readonly SyntaxTree _syntaxTree;
private readonly DefinitionTable _definitionTable;
private readonly Stack<Scope> _scopes = [];
private readonly Stack<NubType> _funcReturnTypes = [];
private readonly List<Diagnostic> _diagnostics = [];
private Scope Scope => _scopes.Peek();
public TypeChecker(SyntaxTree syntaxTree, DefinitionTable definitionTable)
{
_syntaxTree = syntaxTree;
_definitionTable = definitionTable;
}
public IReadOnlyList<Diagnostic> GetDiagnostics() => _diagnostics;
public TypedSyntaxTree Check()
{
_diagnostics.Clear();
_funcReturnTypes.Clear();
_scopes.Clear();
var definitions = new List<Definition>();
foreach (var definition in _syntaxTree.Definitions)
{
try
{
definitions.Add(CheckDefinition(definition));
}
catch (CheckException e)
{
_diagnostics.Add(e.Diagnostic);
}
}
return new TypedSyntaxTree(definitions);
}
private Definition CheckDefinition(DefinitionSyntax node)
{
return node switch
{
ExternFuncSyntax definition => CheckExternFuncDefinition(definition),
InterfaceSyntax definition => CheckTraitDefinition(definition),
LocalFuncSyntax definition => CheckLocalFuncDefinition(definition),
StructSyntax definition => CheckStruct(definition),
_ => throw new ArgumentOutOfRangeException(nameof(node))
};
}
private Trait CheckTraitDefinition(InterfaceSyntax node)
{
var functions = new List<TraitFunc>();
foreach (var function in node.Functions)
{
functions.Add(new TraitFunc(function.Name, CheckFuncSignature(function.Signature)));
}
return new Trait(node.Name, functions);
}
private Struct CheckStruct(StructSyntax node)
{
var structFields = new List<StructField>();
foreach (var field in node.Fields)
{
var value = Optional.Empty<Expression>();
if (field.Value.HasValue)
{
value = CheckExpression(field.Value.Value, CheckType(field.Type));
}
structFields.Add(new StructField(field.Index, field.Name, CheckType(field.Type), value));
}
return new Struct(node.Name, structFields);
}
private ExternFunc CheckExternFuncDefinition(ExternFuncSyntax node)
{
return new ExternFunc(node.Name, node.CallName, CheckFuncSignature(node.Signature));
}
private LocalFunc CheckLocalFuncDefinition(LocalFuncSyntax node)
{
var signature = CheckFuncSignature(node.Signature);
var body = CheckFuncBody(node.Body, signature.ReturnType, signature.Parameters);
return new LocalFunc(node.Name, signature, body);
}
private Statement CheckStatement(StatementSyntax node)
{
return node switch
{
AssignmentSyntax statement => CheckAssignment(statement),
BreakSyntax => new Break(),
ContinueSyntax => new Continue(),
IfSyntax statement => CheckIf(statement),
ReturnSyntax statement => CheckReturn(statement),
StatementExpressionSyntax statement => CheckStatementExpression(statement),
VariableDeclarationSyntax statement => CheckVariableDeclaration(statement),
WhileSyntax statement => CheckWhile(statement),
_ => throw new ArgumentOutOfRangeException(nameof(node))
};
}
private Statement CheckAssignment(AssignmentSyntax statement)
{
var expression = CheckExpression(statement.Target);
var value = CheckExpression(statement.Value, expression.Type);
return new Assignment(expression, value);
}
private If CheckIf(IfSyntax statement)
{
var elseStatement = Optional.Empty<Variant<If, Block>>();
if (statement.Else.HasValue)
{
elseStatement = statement.Else.Value.Match<Variant<If, Block>>
(
elseIf => CheckIf(elseIf),
@else => CheckBlock(@else)
);
}
return new If(CheckExpression(statement.Condition, new NubPrimitiveType(PrimitiveTypeKind.Bool)), CheckBlock(statement.Body), elseStatement);
}
private Return CheckReturn(ReturnSyntax statement)
{
var value = Optional.Empty<Expression>();
if (statement.Value.HasValue)
{
value = CheckExpression(statement.Value.Value, _funcReturnTypes.Peek());
}
return new Return(value);
}
private StatementExpression CheckStatementExpression(StatementExpressionSyntax statement)
{
return new StatementExpression(CheckExpression(statement.Expression));
}
private VariableDeclaration CheckVariableDeclaration(VariableDeclarationSyntax statement)
{
NubType? type = null;
if (statement.ExplicitType.HasValue)
{
type = CheckType(statement.ExplicitType.Value);
}
var assignment = Optional<Expression>.Empty();
if (statement.Assignment.HasValue)
{
var boundValue = CheckExpression(statement.Assignment.Value, type);
assignment = boundValue;
type = boundValue.Type;
}
if (type == null)
{
throw new NotImplementedException("Diagnostics not implemented");
}
Scope.Declare(new Variable(statement.Name, type));
return new VariableDeclaration(statement.Name, assignment, type);
}
private While CheckWhile(WhileSyntax statement)
{
return new While(CheckExpression(statement.Condition, new NubPrimitiveType(PrimitiveTypeKind.Bool)), CheckBlock(statement.Body));
}
private Expression CheckExpression(ExpressionSyntax node, NubType? expectedType = null)
{
return node switch
{
AddressOfSyntax expression => CheckAddressOf(expression),
ArrowFuncSyntax expression => CheckArrowFunc(expression, expectedType),
ArrayIndexAccessSyntax expression => CheckArrayIndexAccess(expression),
ArrayInitializerSyntax expression => CheckArrayInitializer(expression),
BinaryExpressionSyntax expression => CheckBinaryExpression(expression),
DereferenceSyntax expression => CheckDereference(expression),
FuncCallSyntax expression => CheckFuncCall(expression),
IdentifierSyntax expression => CheckIdentifier(expression),
LiteralSyntax expression => CheckLiteral(expression, expectedType),
MemberAccessSyntax expression => CheckMemberAccess(expression),
StructInitializerSyntax expression => CheckStructInitializer(expression),
UnaryExpressionSyntax expression => CheckUnaryExpression(expression),
_ => throw new ArgumentOutOfRangeException(nameof(node))
};
}
private AddressOf CheckAddressOf(AddressOfSyntax expression)
{
var inner = CheckExpression(expression.Expression);
return new AddressOf(new NubPointerType(inner.Type), inner);
}
private ArrowFunc CheckArrowFunc(ArrowFuncSyntax expression, NubType? expectedType = null)
{
if (expectedType == null)
{
throw new CheckException(Diagnostic.Error("Cannot infer argument types for arrow function").Build());
}
if (expectedType is not NubFuncType funcType)
{
throw new CheckException(Diagnostic.Error($"Expected {expectedType}, but got arrow function").Build());
}
var parameters = new List<FuncParameter>();
for (var i = 0; i < expression.Parameters.Count; i++)
{
if (i >= funcType.Parameters.Count)
{
throw new CheckException(Diagnostic.Error($"Arrow function expected a maximum of {funcType.Parameters.Count} arguments").Build());
}
var expectedParameterType = funcType.Parameters[i];
var parameter = expression.Parameters[i];
parameters.Add(new FuncParameter(parameter.Name, expectedParameterType));
}
var body = CheckFuncBody(expression.Body, funcType.ReturnType, parameters);
return new ArrowFunc(new NubFuncType(parameters.Select(x => x.Type).ToList(), funcType.ReturnType), parameters, funcType.ReturnType, body);
}
private ArrayIndexAccess CheckArrayIndexAccess(ArrayIndexAccessSyntax expression)
{
var boundArray = CheckExpression(expression.Target);
var elementType = ((NubArrayType)boundArray.Type).ElementType;
return new ArrayIndexAccess(elementType, boundArray, CheckExpression(expression.Index, new NubPrimitiveType(PrimitiveTypeKind.U64)));
}
private ArrayInitializer CheckArrayInitializer(ArrayInitializerSyntax expression)
{
var capacity = CheckExpression(expression.Capacity, new NubPrimitiveType(PrimitiveTypeKind.U64));
var type = new NubArrayType(CheckType(expression.ElementType));
return new ArrayInitializer(type, capacity, CheckType(expression.ElementType));
}
private BinaryExpression CheckBinaryExpression(BinaryExpressionSyntax expression)
{
var boundLeft = CheckExpression(expression.Left);
var boundRight = CheckExpression(expression.Right, boundLeft.Type);
return new BinaryExpression(boundLeft.Type, boundLeft, CheckBinaryOperator(expression.OperatorSyntax), boundRight);
}
private Dereference CheckDereference(DereferenceSyntax expression)
{
var boundExpression = CheckExpression(expression.Expression);
var dereferencedType = ((NubPointerType)boundExpression.Type).BaseType;
return new Dereference(dereferencedType, boundExpression);
}
private FuncCall CheckFuncCall(FuncCallSyntax expression)
{
var boundExpression = CheckExpression(expression.Expression);
var funcType = (NubFuncType)boundExpression.Type;
var parameters = new List<Expression>();
foreach (var (i, parameter) in expression.Parameters.Index())
{
if (i >= funcType.Parameters.Count)
{
throw new NotImplementedException("Diagnostics not implemented");
}
var expectedType = funcType.Parameters[i];
parameters.Add(CheckExpression(parameter, expectedType));
}
return new FuncCall(funcType.ReturnType, boundExpression, parameters);
}
private Expression CheckIdentifier(IdentifierSyntax expression)
{
var variable = Scope.Lookup(expression.Name);
if (variable != null)
{
return new VariableIdent(variable.Type, variable.Name);
}
var localFuncs = _definitionTable.LookupLocalFunc(expression.Name).ToArray();
if (localFuncs.Length > 0)
{
if (localFuncs.Length > 1)
{
throw new CheckException(Diagnostic.Error($"Extern func {expression.Name} has multiple definitions").Build());
}
var localFunc = localFuncs[0];
var returnType = CheckType(localFunc.Signature.ReturnType);
var parameterTypes = localFunc.Signature.Parameters.Select(p => CheckType(p.Type)).ToList();
var type = new NubFuncType(parameterTypes, returnType);
return new LocalFuncIdent(type, expression.Name);
}
var externFuncs = _definitionTable.LookupExternFunc(expression.Name).ToArray();
if (externFuncs.Length > 0)
{
if (externFuncs.Length > 1)
{
throw new CheckException(Diagnostic.Error($"Extern func {expression.Name} has multiple definitions").Build());
}
var externFunc = externFuncs[0];
var returnType = CheckType(externFunc.Signature.ReturnType);
var parameterTypes = externFunc.Signature.Parameters.Select(p => CheckType(p.Type)).ToList();
var type = new NubFuncType(parameterTypes, returnType);
return new ExternFuncIdent(type, expression.Name);
}
throw new CheckException(Diagnostic.Error($"No identifier with the name {expression.Name} exists").Build());
}
private Literal CheckLiteral(LiteralSyntax expression, NubType? expectedType = null)
{
var type = expectedType ?? expression.Kind switch
{
LiteralKind.Integer => new NubPrimitiveType(PrimitiveTypeKind.I64),
LiteralKind.Float => new NubPrimitiveType(PrimitiveTypeKind.F64),
LiteralKind.String => new NubStringType(),
LiteralKind.Bool => new NubPrimitiveType(PrimitiveTypeKind.Bool),
_ => throw new ArgumentOutOfRangeException()
};
return new Literal(type, expression.Value, expression.Kind);
}
private Expression CheckMemberAccess(MemberAccessSyntax expression)
{
var boundExpression = CheckExpression(expression.Target);
if (boundExpression.Type is NubCustomType customType)
{
var traits = _definitionTable.LookupTrait(customType).ToArray();
if (traits.Length > 0)
{
if (traits.Length > 1)
{
throw new CheckException(Diagnostic.Error($"Trait {customType} has multiple definitions").Build());
}
var trait = traits[0];
var traitFuncs = _definitionTable.LookupTraitFunc(trait, expression.Member).ToArray();
if (traits.Length > 0)
{
if (traits.Length > 1)
{
throw new CheckException(Diagnostic.Error($"Trait {customType} has multiple functions with the name {expression.Member}").Build());
}
var traitFunc = traitFuncs[0];
var returnType = CheckType(traitFunc.Signature.ReturnType);
var parameterTypes = traitFunc.Signature.Parameters.Select(p => CheckType(p.Type)).ToList();
var type = new NubFuncType(parameterTypes, returnType);
return new InterfaceFuncAccess(type, customType, boundExpression, expression.Member);
}
}
var structs = _definitionTable.LookupStruct(customType).ToArray();
if (structs.Length > 0)
{
if (structs.Length > 1)
{
throw new CheckException(Diagnostic.Error($"Struct {customType} has multiple definitions").Build());
}
var @struct = structs[0];
var fields = _definitionTable.LookupStructField(@struct, expression.Member).ToArray();
if (fields.Length > 0)
{
if (fields.Length > 1)
{
throw new CheckException(Diagnostic.Error($"Struct {customType} has multiple fields with the name {expression.Member}").Build());
}
var field = fields[0];
return new StructFieldAccess(CheckType(field.Type), customType, boundExpression, expression.Member);
}
}
}
throw new CheckException(Diagnostic.Error($"{boundExpression.Type} does not have a member with the name {expression.Member}").Build());
}
private StructInitializer CheckStructInitializer(StructInitializerSyntax expression)
{
var boundType = CheckType(expression.StructType);
if (boundType is not NubCustomType structType)
{
throw new CheckException(Diagnostic.Error($"Cannot initialize non-struct type {expression.StructType}").Build());
}
var structs = _definitionTable.LookupStruct(structType).ToArray();
if (structs.Length == 0)
{
throw new CheckException(Diagnostic.Error($"Struct {structType} is not defined").Build());
}
if (structs.Length > 1)
{
throw new CheckException(Diagnostic.Error($"Struct {structType} has multiple definitions").Build());
}
var @struct = structs[0];
var initializers = new Dictionary<string, Expression>();
foreach (var (field, initializer) in expression.Initializers)
{
var fields = _definitionTable.LookupStructField(@struct, field).ToArray();
if (fields.Length == 0)
{
throw new CheckException(Diagnostic.Error($"Struct {structType} does not have a field with the name {field}").Build());
}
if (fields.Length > 1)
{
throw new CheckException(Diagnostic.Error($"Struct {structType} has multiple fields with the name {field}").Build());
}
initializers[field] = CheckExpression(initializer, CheckType(fields[0].Type));
}
return new StructInitializer(structType, initializers);
}
private UnaryExpression CheckUnaryExpression(UnaryExpressionSyntax expression)
{
var boundOperand = CheckExpression(expression.Operand);
NubType? type = null;
switch (expression.OperatorSyntax)
{
case UnaryOperatorSyntax.Negate:
{
boundOperand = CheckExpression(expression.Operand, new NubPrimitiveType(PrimitiveTypeKind.I64));
if (boundOperand.Type.IsNumber)
{
type = boundOperand.Type;
}
break;
}
case UnaryOperatorSyntax.Invert:
{
boundOperand = CheckExpression(expression.Operand, new NubPrimitiveType(PrimitiveTypeKind.Bool));
type = new NubPrimitiveType(PrimitiveTypeKind.Bool);
break;
}
}
if (type == null)
{
throw new NotImplementedException("Diagnostics not implemented");
}
return new UnaryExpression(type, CheckUnaryOperator(expression.OperatorSyntax), boundOperand);
}
private FuncSignature CheckFuncSignature(FuncSignatureSyntax node)
{
var parameters = new List<FuncParameter>();
foreach (var parameter in node.Parameters)
{
parameters.Add(new FuncParameter(parameter.Name, CheckType(parameter.Type)));
}
return new FuncSignature(parameters, CheckType(node.ReturnType));
}
private BinaryOperator CheckBinaryOperator(BinaryOperatorSyntax op)
{
return op switch
{
BinaryOperatorSyntax.Equal => BinaryOperator.Equal,
BinaryOperatorSyntax.NotEqual => BinaryOperator.NotEqual,
BinaryOperatorSyntax.GreaterThan => BinaryOperator.GreaterThan,
BinaryOperatorSyntax.GreaterThanOrEqual => BinaryOperator.GreaterThanOrEqual,
BinaryOperatorSyntax.LessThan => BinaryOperator.LessThan,
BinaryOperatorSyntax.LessThanOrEqual => BinaryOperator.LessThanOrEqual,
BinaryOperatorSyntax.Plus => BinaryOperator.Plus,
BinaryOperatorSyntax.Minus => BinaryOperator.Minus,
BinaryOperatorSyntax.Multiply => BinaryOperator.Multiply,
BinaryOperatorSyntax.Divide => BinaryOperator.Divide,
_ => throw new ArgumentOutOfRangeException(nameof(op), op, null)
};
}
private UnaryOperator CheckUnaryOperator(UnaryOperatorSyntax op)
{
return op switch
{
UnaryOperatorSyntax.Negate => UnaryOperator.Negate,
UnaryOperatorSyntax.Invert => UnaryOperator.Invert,
_ => throw new ArgumentOutOfRangeException(nameof(op), op, null)
};
}
private Block CheckBlock(BlockSyntax node, Scope? scope = null)
{
var statements = new List<Statement>();
_scopes.Push(scope ?? Scope.SubScope());
foreach (var statement in node.Statements)
{
statements.Add(CheckStatement(statement));
}
_scopes.Pop();
return new Block(statements);
}
private Block CheckFuncBody(BlockSyntax block, NubType returnType, IReadOnlyList<FuncParameter> parameters)
{
_funcReturnTypes.Push(returnType);
var scope = new Scope();
foreach (var parameter in parameters)
{
scope.Declare(new Variable(parameter.Name, parameter.Type));
}
var body = CheckBlock(block, scope);
_funcReturnTypes.Pop();
return body;
}
private NubType CheckType(TypeSyntax node)
{
return node switch
{
ArrayTypeSyntax type => new NubArrayType(CheckType(type.BaseType)),
CStringTypeSyntax => new NubCStringType(),
CustomTypeSyntax type => new NubCustomType(type.MangledName()),
FuncTypeSyntax type => new NubFuncType(type.Parameters.Select(CheckType).ToList(), CheckType(type.ReturnType)),
PointerTypeSyntax type => new NubPointerType(CheckType(type.BaseType)),
PrimitiveTypeSyntax type => new NubPrimitiveType(type.SyntaxKind switch
{
PrimitiveTypeSyntaxKind.I64 => PrimitiveTypeKind.I64,
PrimitiveTypeSyntaxKind.I32 => PrimitiveTypeKind.I32,
PrimitiveTypeSyntaxKind.I16 => PrimitiveTypeKind.I16,
PrimitiveTypeSyntaxKind.I8 => PrimitiveTypeKind.I8,
PrimitiveTypeSyntaxKind.U64 => PrimitiveTypeKind.U64,
PrimitiveTypeSyntaxKind.U32 => PrimitiveTypeKind.U32,
PrimitiveTypeSyntaxKind.U16 => PrimitiveTypeKind.U16,
PrimitiveTypeSyntaxKind.U8 => PrimitiveTypeKind.U8,
PrimitiveTypeSyntaxKind.F64 => PrimitiveTypeKind.F64,
PrimitiveTypeSyntaxKind.F32 => PrimitiveTypeKind.F32,
PrimitiveTypeSyntaxKind.Bool => PrimitiveTypeKind.Bool,
_ => throw new ArgumentOutOfRangeException()
}),
StringTypeSyntax => new NubStringType(),
VoidTypeSyntax => new NubVoidType(),
_ => throw new ArgumentOutOfRangeException(nameof(node))
};
}
}
public record Variable(string Name, NubType Type);
public class Scope(Scope? parent = null)
{
private readonly List<Variable> _variables = [];
public Variable? Lookup(string name)
{
var variable = _variables.FirstOrDefault(x => x.Name == name);
if (variable != null)
{
return variable;
}
return parent?.Lookup(name);
}
public void Declare(Variable variable)
{
_variables.Add(variable);
}
public Scope SubScope()
{
return new Scope(this);
}
}
public class CheckException : Exception
{
public Diagnostic Diagnostic { get; }
public CheckException(Diagnostic diagnostic) : base(diagnostic.Message)
{
Diagnostic = diagnostic;
}
}