From eba128463da49a0a12556e1514253ac01cc187e2 Mon Sep 17 00:00:00 2001 From: nub31 Date: Thu, 22 May 2025 19:31:11 +0200 Subject: [PATCH] Array type works with args --- example/program.nub | 27 +--- src/compiler/Nub.Lang/Backend/Generator.cs | 131 ++++++++++++------ .../Frontend/Parsing/ArrayIndexNode.cs | 7 + .../Frontend/Parsing/MemberAccessNode.cs | 7 + .../Nub.Lang/Frontend/Parsing/Parser.cs | 52 +++++-- .../Parsing/StructFieldAccessorNode.cs | 7 - .../Nub.Lang/Frontend/Typing/TypeChecker.cs | 91 +++++++++--- src/compiler/Nub.Lang/NubType.cs | 24 +++- src/runtime/runtime.asm | 9 +- 9 files changed, 250 insertions(+), 105 deletions(-) create mode 100644 src/compiler/Nub.Lang/Frontend/Parsing/ArrayIndexNode.cs create mode 100644 src/compiler/Nub.Lang/Frontend/Parsing/MemberAccessNode.cs delete mode 100644 src/compiler/Nub.Lang/Frontend/Parsing/StructFieldAccessorNode.cs diff --git a/example/program.nub b/example/program.nub index aa95c8a..e1cfe83 100644 --- a/example/program.nub +++ b/example/program.nub @@ -1,23 +1,10 @@ import c -struct Test { - name: ^string -} - -struct Test2 { - parent: ^Test -} - -global func main(argc: i64, argv: i64) { - name = "Oliver" - - parent = new Test { - name = &name +global func main(args: []string) { + i = 0 + printf("%d\n", args.count) + while i < args.count { + printf("%s\n", args[i]) + i = i + 1 } - - test = new Test2 { - parent = &parent - } - - printf("%s\n", (test.parent^.name^)) -} \ No newline at end of file +} diff --git a/src/compiler/Nub.Lang/Backend/Generator.cs b/src/compiler/Nub.Lang/Backend/Generator.cs index b60e9af..b4314eb 100644 --- a/src/compiler/Nub.Lang/Backend/Generator.cs +++ b/src/compiler/Nub.Lang/Backend/Generator.cs @@ -74,6 +74,7 @@ public class Generator } case NubStructType: case NubPointerType: + case NubArrayType: { return "l"; } @@ -120,6 +121,7 @@ public class Generator return ":" + nubCustomType.Name; } case NubPointerType: + case NubArrayType: { return "l"; } @@ -169,6 +171,7 @@ public class Generator return ":" + nubCustomType.Name; } case NubPointerType: + case NubArrayType: { return "l"; } @@ -221,6 +224,7 @@ public class Generator return definition.Fields.Sum(f => QbeTypeSize(f.Type)); } case NubPointerType: + case NubArrayType: { return 8; } @@ -259,7 +263,7 @@ public class Generator _builder.AppendLine("@start"); _builder.AppendLine(" # Variable allocation"); - + foreach (var parameter in node.Parameters) { var parameterName = parameter.Name; @@ -287,8 +291,8 @@ public class Generator var pointerLabel = GenName(); _builder.AppendLine($" %{pointerLabel} ={SQT(parameter.Type)} alloc8 {QbeTypeSize(parameter.Type)}"); _builder.AppendLine($" storel %{parameterName}, %{pointerLabel}"); - - + + _variables[parameter.Name] = new Variable { Pointer = $"%{pointerLabel}", @@ -474,15 +478,23 @@ public class Generator private void GenerateVariableAssignment(VariableAssignmentNode variableAssignment) { var result = GenerateExpression(variableAssignment.Value); - var pointerLabel = GenName(); - _builder.AppendLine($" %{pointerLabel} ={SQT(variableAssignment.Value.Type)} alloc8 {QbeTypeSize(variableAssignment.Value.Type)}"); - _builder.AppendLine($" storel {result}, %{pointerLabel}"); - _variables[variableAssignment.Name] = new Variable + if (_variables.TryGetValue(variableAssignment.Name, out var existingVariable)) { - Pointer = $"%{pointerLabel}", - Type = variableAssignment.Value.Type - }; + _builder.AppendLine($" storel {result}, {existingVariable.Pointer}"); + } + else + { + var pointerLabel = GenName(); + _builder.AppendLine($" %{pointerLabel} ={SQT(variableAssignment.Value.Type)} alloc8 {QbeTypeSize(variableAssignment.Value.Type)}"); + _builder.AppendLine($" storel {result}, %{pointerLabel}"); + + _variables[variableAssignment.Name] = new Variable + { + Pointer = $"%{pointerLabel}", + Type = variableAssignment.Value.Type + }; + } } private void GenerateWhile(WhileNode whileStatement) @@ -511,6 +523,7 @@ public class Generator return expression switch { AddressOfNode addressOf => GenerateAddressOf(addressOf), + ArrayIndexNode arrayIndex => GenerateArrayIndex(arrayIndex), BinaryExpressionNode binaryExpression => GenerateBinaryExpression(binaryExpression), CastNode cast => GenerateCast(cast), DereferenceNode dereference => GenerateDereference(dereference), @@ -519,11 +532,29 @@ public class Generator LiteralNode literal => GenerateLiteral(literal), StructInitializerNode structInitializer => GenerateStructInitializer(structInitializer), UnaryExpressionNode unaryExpression => GenerateUnaryExpression(unaryExpression), - StructFieldAccessorNode structMemberAccessor => GenerateStructFieldAccessor(structMemberAccessor), + MemberAccessNode memberAccess => GenerateMemberAccess(memberAccess), _ => throw new ArgumentOutOfRangeException(nameof(expression)) }; } + private string GenerateArrayIndex(ArrayIndexNode arrayIndex) + { + var array = GenerateExpression(arrayIndex.Expression); + var index = GenerateExpression(arrayIndex.Index); + + var arrayBaseType = ((NubArrayType)arrayIndex.Expression.Type).BaseType; + + var firstItem = GenName(); + _builder.AppendLine($" %{firstItem} =l add {array}, 8"); + var adjustedIndex = GenName(); + _builder.AppendLine($" %{adjustedIndex} =l mul {index}, {QbeTypeSize(arrayBaseType)}"); + var indexLabel = GenName(); + _builder.AppendLine($" %{indexLabel} ={SQT(arrayIndex.Type)} add %{firstItem}, %{adjustedIndex}"); + var outputLabel = GenName(); + _builder.AppendLine($" %{outputLabel} =l load{SQT(arrayBaseType)} %{indexLabel}"); + return $"%{outputLabel}"; + } + private string GenerateDereference(DereferenceNode dereference) { var result = GenerateExpression(dereference.Expression); @@ -1331,7 +1362,7 @@ public class Generator _builder.AppendLine($" %{outputLabel} =s neg {operand}"); return $"%{outputLabel}"; } - + break; } case UnaryExpressionOperator.Invert: @@ -1354,42 +1385,60 @@ public class Generator throw new NotSupportedException($"Unary operator {unaryExpression.Operator} for type {unaryExpression.Operand.Type} not supported"); } - private string GenerateStructFieldAccessor(StructFieldAccessorNode structFieldAccessor) + private string GenerateMemberAccess(MemberAccessNode memberAccess) { - var structType = structFieldAccessor.Struct.Type; - var structDefinition = _definitions - .OfType() - .FirstOrDefault(s => s.Name == structType.Name); + var expression = GenerateExpression(memberAccess.Expression); - if (structDefinition == null) + switch (memberAccess.Expression.Type) { - throw new Exception($"Struct {structType.Name} is not defined"); - } - - var @struct = GenerateExpression(structFieldAccessor.Struct); - - var fieldIndex = -1; - for (var i = 0; i < structDefinition.Fields.Count; i++) - { - if (structDefinition.Fields[i].Name == structFieldAccessor.Field) + case NubArrayType: { - fieldIndex = i; + if (memberAccess.Member == "count") + { + var outputLabel = GenName(); + _builder.AppendLine($" %{outputLabel} =l loadl {expression}"); + return $"%{outputLabel}"; + } + break; } + case NubStructType structType: + { + var structDefinition = _definitions + .OfType() + .FirstOrDefault(s => s.Name == structType.Name); + + if (structDefinition == null) + { + throw new Exception($"Struct {structType.Name} is not defined"); + } + + var fieldIndex = -1; + for (var i = 0; i < structDefinition.Fields.Count; i++) + { + if (structDefinition.Fields[i].Name == memberAccess.Member) + { + fieldIndex = i; + break; + } + } + + if (fieldIndex == -1) + { + throw new Exception($"Field {memberAccess.Member} is not defined in struct {structType.Name}"); + } + + var offsetLabel = GenName(); + _builder.AppendLine($" %{offsetLabel} =l add {expression}, {fieldIndex * QbeTypeSize(memberAccess.Type)}"); + + var outputLabel = GenName(); + _builder.AppendLine($" %{outputLabel} ={SQT(memberAccess.Type)} load{SQT(memberAccess.Type)} %{offsetLabel}"); + + return $"%{outputLabel}"; + } } - - if (fieldIndex == -1) - { - throw new Exception($"Field {structFieldAccessor.Field} is not defined in struct {structType.Name}"); - } - - var offsetLabel = GenName(); - _builder.AppendLine($" %{offsetLabel} =l add {@struct}, {fieldIndex * QbeTypeSize(structFieldAccessor.Type)}"); - - var outputLabel = GenName(); - _builder.AppendLine($" %{outputLabel} ={SQT(structFieldAccessor.Type)} load{SQT(structFieldAccessor.Type)} %{offsetLabel}"); - - return $"%{outputLabel}"; + + throw new ArgumentOutOfRangeException(nameof(memberAccess.Expression.Type)); } private string GenerateExpressionFuncCall(FuncCallExpressionNode funcCall) diff --git a/src/compiler/Nub.Lang/Frontend/Parsing/ArrayIndexNode.cs b/src/compiler/Nub.Lang/Frontend/Parsing/ArrayIndexNode.cs new file mode 100644 index 0000000..4fe8606 --- /dev/null +++ b/src/compiler/Nub.Lang/Frontend/Parsing/ArrayIndexNode.cs @@ -0,0 +1,7 @@ +namespace Nub.Lang.Frontend.Parsing; + +public class ArrayIndexNode(ExpressionNode expression, ExpressionNode index) : ExpressionNode +{ + public ExpressionNode Expression { get; } = expression; + public ExpressionNode Index { get; } = index; +} \ No newline at end of file diff --git a/src/compiler/Nub.Lang/Frontend/Parsing/MemberAccessNode.cs b/src/compiler/Nub.Lang/Frontend/Parsing/MemberAccessNode.cs new file mode 100644 index 0000000..592e6ef --- /dev/null +++ b/src/compiler/Nub.Lang/Frontend/Parsing/MemberAccessNode.cs @@ -0,0 +1,7 @@ +namespace Nub.Lang.Frontend.Parsing; + +public class MemberAccessNode(ExpressionNode expression, string member) : ExpressionNode +{ + public ExpressionNode Expression { get; } = expression; + public string Member { get; } = member; +} \ No newline at end of file diff --git a/src/compiler/Nub.Lang/Frontend/Parsing/Parser.cs b/src/compiler/Nub.Lang/Frontend/Parsing/Parser.cs index 1698e13..e4b1234 100644 --- a/src/compiler/Nub.Lang/Frontend/Parsing/Parser.cs +++ b/src/compiler/Nub.Lang/Frontend/Parsing/Parser.cs @@ -427,17 +427,17 @@ public class Parser if (TryExpectSymbol(Symbol.Period)) { var structMember = ExpectIdentifier().Value; - expr = new StructFieldAccessorNode(expr, structMember); + expr = new MemberAccessNode(expr, structMember); continue; } - // if (TryExpectSymbol(Symbol.OpenBracket)) - // { - // var index = ParseExpression(); - // ExpectSymbol(Symbol.CloseBracket); - // expr = new ArrayIndexNode(expr, index); - // continue; - // } + if (TryExpectSymbol(Symbol.OpenBracket)) + { + var index = ParseExpression(); + ExpectSymbol(Symbol.CloseBracket); + expr = new ArrayIndexNode(expr, index); + continue; + } break; } @@ -459,10 +459,25 @@ public class Parser private NubType ParseType() { - var pointer = TryExpectSymbol(Symbol.Caret); - var name = ExpectIdentifier().Value; - var type = NubType.Parse(name); - return pointer ? new NubPointerType(type) : type; + if (TryExpectIdentifier(out var name)) + { + return NubType.Parse(name); + } + + if (TryExpectSymbol(Symbol.Caret)) + { + var baseType = ParseType(); + return new NubPointerType(baseType); + } + + if (TryExpectSymbol(Symbol.OpenBracket)) + { + ExpectSymbol(Symbol.CloseBracket); + var baseType = ParseType(); + return new NubArrayType(baseType); + } + + throw new Exception($"Unexpected token {Peek()} when parsing type"); } private Token ExpectToken() @@ -517,6 +532,19 @@ public class Parser return false; } + private bool TryExpectIdentifier([NotNullWhen(true)] out string? identifier) + { + if (Peek() is { HasValue: true, Value: IdentifierToken identifierToken }) + { + identifier = identifierToken.Value; + Next(); + return true; + } + + identifier = null; + return false; + } + private IdentifierToken ExpectIdentifier() { var token = ExpectToken(); diff --git a/src/compiler/Nub.Lang/Frontend/Parsing/StructFieldAccessorNode.cs b/src/compiler/Nub.Lang/Frontend/Parsing/StructFieldAccessorNode.cs deleted file mode 100644 index cd7e285..0000000 --- a/src/compiler/Nub.Lang/Frontend/Parsing/StructFieldAccessorNode.cs +++ /dev/null @@ -1,7 +0,0 @@ -namespace Nub.Lang.Frontend.Parsing; - -public class StructFieldAccessorNode(ExpressionNode @struct, string field) : ExpressionNode -{ - public ExpressionNode Struct { get; } = @struct; - public string Field { get; } = field; -} \ No newline at end of file diff --git a/src/compiler/Nub.Lang/Frontend/Typing/TypeChecker.cs b/src/compiler/Nub.Lang/Frontend/Typing/TypeChecker.cs index 2b18ce8..372b248 100644 --- a/src/compiler/Nub.Lang/Frontend/Typing/TypeChecker.cs +++ b/src/compiler/Nub.Lang/Frontend/Typing/TypeChecker.cs @@ -253,6 +253,7 @@ public class TypeChecker var resultType = expression switch { AddressOfNode addressOf => TypeCheckAddressOf(addressOf), + ArrayIndexNode arrayIndex => TypeCheckArrayIndex(arrayIndex), LiteralNode literal => literal.LiteralType, IdentifierNode identifier => TypeCheckIdentifier(identifier), BinaryExpressionNode binaryExpr => TypeCheckBinaryExpression(binaryExpr), @@ -261,7 +262,7 @@ public class TypeChecker FuncCallExpressionNode funcCallExpr => TypeCheckFuncCall(funcCallExpr.FuncCall), StructInitializerNode structInit => TypeCheckStructInitializer(structInit), UnaryExpressionNode unaryExpression => TypeCheckUnaryExpression(unaryExpression), - StructFieldAccessorNode fieldAccess => TypeCheckStructFieldAccess(fieldAccess), + MemberAccessNode memberAccess => TypeCheckMemberAccess(memberAccess), _ => throw new TypeCheckingException($"Unsupported expression type: {expression.GetType().Name}") }; @@ -269,6 +270,23 @@ public class TypeChecker return resultType; } + private NubType TypeCheckArrayIndex(ArrayIndexNode arrayIndex) + { + var expressionType = TypeCheckExpression(arrayIndex.Expression); + if (expressionType is not NubArrayType arrayType) + { + throw new TypeCheckingException($"Annot access index of non-array type {expressionType}"); + } + + var indexType = TypeCheckExpression(arrayIndex.Index); + if (!IsInteger(indexType)) + { + throw new TypeCheckingException("Array index type must be an integer"); + } + + return arrayType.BaseType; + } + private NubType TypeCheckIdentifier(IdentifierNode identifier) { if (!_variables.TryGetValue(identifier.Identifier, out var varType)) @@ -282,7 +300,7 @@ public class TypeChecker private NubType TypeCheckAddressOf(AddressOfNode addressOf) { TypeCheckExpression(addressOf.Expression); - if (addressOf.Expression is not (IdentifierNode or StructFieldAccessorNode)) + if (addressOf.Expression is not (IdentifierNode or MemberAccessNode)) { throw new TypeCheckingException($"Cannot take the address of {addressOf.Expression.Type}"); } @@ -422,27 +440,39 @@ public class TypeChecker } } - private NubType TypeCheckStructFieldAccess(StructFieldAccessorNode fieldAccess) + private NubType TypeCheckMemberAccess(MemberAccessNode memberAccess) { - var structType = TypeCheckExpression(fieldAccess.Struct); - if (structType is not NubStructType customType) + var expressionType = TypeCheckExpression(memberAccess.Expression); + switch (expressionType) { - throw new TypeCheckingException($"Cannot access field '{fieldAccess.Field}' on non-struct type '{structType}'"); + case NubArrayType: + { + if (memberAccess.Member == "count") + { + return NubPrimitiveType.I64; + } + + break; + } + case NubStructType structType: + { + var definition = _definitions.OfType().FirstOrDefault(s => s.Name == structType.Name); + if (definition == null) + { + throw new TypeCheckingException($"Struct type '{structType.Name}' is not defined"); + } + + var field = definition.Fields.FirstOrDefault(f => f.Name == memberAccess.Member); + if (field == null) + { + throw new TypeCheckingException($"Field '{memberAccess.Member}' does not exist in struct '{structType.Name}'"); + } + + return field.Type; + } } - var definition = _definitions.OfType().FirstOrDefault(s => s.Name == customType.Name); - if (definition == null) - { - throw new TypeCheckingException($"Struct type '{customType.Name}' is not defined"); - } - - var field = definition.Fields.FirstOrDefault(f => f.Name == fieldAccess.Field); - if (field == null) - { - throw new TypeCheckingException($"Field '{fieldAccess.Field}' does not exist in struct '{customType.Name}'"); - } - - return field.Type; + throw new TypeCheckingException($"Cannot access member '{memberAccess.Member}' on type '{expressionType}'"); } private static bool AreTypesCompatible(NubType sourceType, NubType targetType) @@ -474,4 +504,27 @@ public class TypeChecker return false; } } + + private static bool IsInteger(NubType type) + { + if (type is not NubPrimitiveType primitiveType) + { + return false; + } + + switch (primitiveType.Kind) + { + case PrimitiveTypeKind.I8: + case PrimitiveTypeKind.I16: + case PrimitiveTypeKind.I32: + case PrimitiveTypeKind.I64: + case PrimitiveTypeKind.U8: + case PrimitiveTypeKind.U16: + case PrimitiveTypeKind.U32: + case PrimitiveTypeKind.U64: + return true; + default: + return false; + } + } } \ No newline at end of file diff --git a/src/compiler/Nub.Lang/NubType.cs b/src/compiler/Nub.Lang/NubType.cs index edd8f4d..56e983e 100644 --- a/src/compiler/Nub.Lang/NubType.cs +++ b/src/compiler/Nub.Lang/NubType.cs @@ -41,7 +41,29 @@ public class NubPointerType(NubType baseType) : NubType("^" + baseType) return false; } - public override int GetHashCode() => BaseType.GetHashCode() * 31; + public override int GetHashCode() + { + return HashCode.Combine(base.GetHashCode(), BaseType); + } +} + +public class NubArrayType(NubType baseType) : NubType("[]" + baseType) +{ + public NubType BaseType { get; } = baseType; + + public override bool Equals(object? obj) + { + if (obj is NubArrayType other) + { + return BaseType.Equals(other.BaseType); + } + return false; + } + + public override int GetHashCode() + { + return HashCode.Combine(base.GetHashCode(), BaseType); + } } public class NubPrimitiveType(PrimitiveTypeKind kind) : NubType(KindToString(kind)) diff --git a/src/runtime/runtime.asm b/src/runtime/runtime.asm index f710165..c050e9b 100644 --- a/src/runtime/runtime.asm +++ b/src/runtime/runtime.asm @@ -3,11 +3,9 @@ extern main section .text _start: - ; Extract argc and argv from the stack - mov rdi, [rsp] ; rdi = argc - lea rsi, [rsp + 8] ; rsi = argv (pointer to array of strings) + ; The args already match our array structure, so we pass the result directly + mov rdi, rsp - ; Call main(argc, argv) call main ; main returns int in rax ; Exit with main's return value @@ -33,4 +31,5 @@ nub_strcmp: ret .equal: mov rax, 1 - ret \ No newline at end of file + ret + \ No newline at end of file