diff --git a/Nub.Lang/Nub.Lang/Backend/Custom/Generator.cs b/Nub.Lang/Nub.Lang/Backend/Custom/Generator.cs index eb2f730..dc22c8d 100644 --- a/Nub.Lang/Nub.Lang/Backend/Custom/Generator.cs +++ b/Nub.Lang/Nub.Lang/Backend/Custom/Generator.cs @@ -16,7 +16,7 @@ public class Generator public Generator(List definitions) { - _definitions = []; + _definitions = definitions; _builder = new StringBuilder(); _labelFactory = new LabelFactory(); _symbolTable = new SymbolTable(_labelFactory); @@ -25,19 +25,16 @@ public class Generator foreach (var globalVariableDefinition in definitions.OfType()) { _symbolTable.DefineGlobalVariable(globalVariableDefinition); - _definitions.Add(globalVariableDefinition); } foreach (var funcDefinitionNode in definitions.OfType()) { _symbolTable.DefineFunc(funcDefinitionNode); - _definitions.Add(funcDefinitionNode); } foreach (var funcDefinitionNode in definitions.OfType()) { _symbolTable.DefineFunc(funcDefinitionNode); - _definitions.Add(funcDefinitionNode); } } @@ -53,7 +50,7 @@ public class Generator _builder.AppendLine(); _builder.AppendLine("section .text"); - // TODO: Only add start label if main is present + // TODO: Only add start label if entrypoint is present, otherwise assume library var main = _symbolTable.ResolveLocalFunc(Entrypoint, []); _builder.AppendLine("_start:"); @@ -102,16 +99,15 @@ public class Generator foreach (var str in _symbolTable.Strings) { - _builder.AppendLine($"{str.Key}: db `{str.Value}`, 0"); + _builder.AppendLine($" {str.Key}: db `{str.Value}`, 0"); } - Dictionary completed = []; foreach (var globalVariableDefinition in _definitions.OfType()) { var variable = _symbolTable.ResolveGlobalVariable(globalVariableDefinition.Name); var evaluated = EvaluateExpression(globalVariableDefinition.Value, completed); - _builder.AppendLine($"{variable.Identifier}: dq {evaluated}"); + _builder.AppendLine($" {variable.Identifier}: dq {evaluated}"); completed[variable.Name] = evaluated; } @@ -341,7 +337,7 @@ public class Generator GenerateArrayIndexAccess(arrayIndexAccess, func); break; case ArrayInitializerNode arrayInitializer: - GenerateArrayInitializer(arrayInitializer, func); + GenerateArrayInitializer(arrayInitializer); break; case BinaryExpressionNode binaryExpression: GenerateBinaryExpression(binaryExpression, func); @@ -355,6 +351,9 @@ public class Generator case LiteralNode literal: GenerateLiteral(literal); break; + case StructInitializerNode structInitializer: + GenerateStructInitializer(structInitializer, func); + break; case SyscallExpressionNode syscallExpression: GenerateSyscall(syscallExpression.Syscall, func); break; @@ -369,7 +368,7 @@ public class Generator _builder.AppendLine(" mov rax, [rax]"); } - private void GenerateArrayInitializer(ArrayInitializerNode arrayInitializer, LocalFunc func) + private void GenerateArrayInitializer(ArrayInitializerNode arrayInitializer) { _builder.AppendLine($" sub rsp, {8 + arrayInitializer.Length * 8}"); _builder.AppendLine(" mov rax, rsp"); @@ -591,6 +590,45 @@ public class Generator } } + private void GenerateStructInitializer(StructInitializerNode structInitializer, LocalFunc func) + { + var structDefinition = _definitions + .OfType() + .FirstOrDefault(sd => sd.Name == structInitializer.StructType.Name); + + if (structDefinition == null) + { + throw new Exception($"Struct {structInitializer.StructType} is not defined"); + } + + _builder.AppendLine($" add rsp, {structDefinition.Members.Count * 8}"); + + foreach (var initializer in structInitializer.Initializers) + { + GenerateExpression(initializer.Value, func); + var index = structDefinition.Members.FindIndex(sd => sd.Name == initializer.Key); + if (index == -1) + { + throw new Exception($"Member {initializer.Key} is not defined on struct {structInitializer.StructType}"); + } + + _builder.AppendLine($" mov [rsp + {index * 8}], rax"); + } + + foreach (var uninitializedMember in structDefinition.Members.Where(m => !structInitializer.Initializers.ContainsKey(m.Name))) + { + if (!uninitializedMember.Value.HasValue) + { + throw new Exception($"Struct {structInitializer.StructType} must be initializer with member {uninitializedMember.Name}"); + } + + GenerateExpression(uninitializedMember.Value.Value, func); + _builder.AppendLine($" mov [rsp + {structDefinition.Members.IndexOf(uninitializedMember) * 8}], rax"); + } + + _builder.AppendLine(" mov rax, rsp"); + } + private void GenerateFuncCall(FuncCall funcCall, LocalFunc func) { var symbol = _symbolTable.ResolveFunc(funcCall.Name, funcCall.Parameters.Select(p => p.Type).ToList()); diff --git a/Nub.Lang/Nub.Lang/Frontend/Lexing/Lexer.cs b/Nub.Lang/Nub.Lang/Frontend/Lexing/Lexer.cs index 7a60d85..3af365d 100644 --- a/Nub.Lang/Nub.Lang/Frontend/Lexing/Lexer.cs +++ b/Nub.Lang/Nub.Lang/Frontend/Lexing/Lexer.cs @@ -17,6 +17,7 @@ public class Lexer ["continue"] = Symbol.Continue, ["return"] = Symbol.Return, ["new"] = Symbol.New, + ["struct"] = Symbol.Struct, }; private static readonly Dictionary Chians = new() diff --git a/Nub.Lang/Nub.Lang/Frontend/Lexing/SymbolToken.cs b/Nub.Lang/Nub.Lang/Frontend/Lexing/SymbolToken.cs index a86357a..123364a 100644 --- a/Nub.Lang/Nub.Lang/Frontend/Lexing/SymbolToken.cs +++ b/Nub.Lang/Nub.Lang/Frontend/Lexing/SymbolToken.cs @@ -41,4 +41,5 @@ public enum Symbol Star, ForwardSlash, New, + Struct } \ No newline at end of file diff --git a/Nub.Lang/Nub.Lang/Frontend/Parsing/Parser.cs b/Nub.Lang/Nub.Lang/Frontend/Parsing/Parser.cs index d2f0904..64a5d9d 100644 --- a/Nub.Lang/Nub.Lang/Frontend/Parsing/Parser.cs +++ b/Nub.Lang/Nub.Lang/Frontend/Parsing/Parser.cs @@ -47,6 +47,7 @@ public class Parser Symbol.Let => ParseGlobalVariableDefinition(), Symbol.Func => ParseFuncDefinition(), Symbol.Extern => ParseExternFuncDefinition(), + Symbol.Struct => ParseStruct(), _ => throw new Exception("Unexpected symbol: " + keyword.Symbol) }; } @@ -112,6 +113,36 @@ public class Parser return new ExternFuncDefinitionNode(name.Value, parameters, returnType); } + private StructDefinitionNode ParseStruct() + { + var name = ExpectIdentifier().Value; + + ExpectSymbol(Symbol.OpenBrace); + + List variables = []; + + while (!TryExpectSymbol(Symbol.CloseBrace)) + { + ExpectSymbol(Symbol.Let); + var variableName = ExpectIdentifier().Value; + ExpectSymbol(Symbol.Colon); + var variableType = ParseType(); + + var variableValue = Optional.Empty(); + + if (TryExpectSymbol(Symbol.Assign)) + { + variableValue = ParseExpression(); + } + + ExpectSymbol(Symbol.Semicolon); + + variables.Add(new StructMember(variableName, variableType, variableValue)); + } + + return new StructDefinitionNode(name, variables); + } + private FuncParameter ParseFuncParameter() { var name = ExpectIdentifier(); @@ -346,14 +377,40 @@ public class Parser case Symbol.New: { var type = ParseType(); - ExpectSymbol(Symbol.OpenParen); - var size = ExpectLiteral(); - if (size.Type is not PrimitiveType { Kind: PrimitiveTypeKind.Int64 }) + + switch (type) { - throw new Exception($"Array initializer size must be an {PrimitiveTypeKind.Int64}"); + // TODO: Parse arrays differently + case ArrayType: + { + ExpectSymbol(Symbol.OpenParen); + var size = ExpectLiteral(); + if (size.Type is not PrimitiveType { Kind: PrimitiveTypeKind.Int64 }) + { + throw new Exception($"Array initializer size must be an {PrimitiveTypeKind.Int64}"); + } + ExpectSymbol(Symbol.CloseParen); + + return new ArrayInitializerNode(long.Parse(size.Value), type); + } + case StructType structType: + { + Dictionary initializers = []; + ExpectSymbol(Symbol.OpenBrace); + while (!TryExpectSymbol(Symbol.CloseBrace)) + { + var name = ExpectIdentifier().Value; + ExpectSymbol(Symbol.Assign); + var value = ParseExpression(); + TryExpectSymbol(Symbol.Comma); + initializers.Add(name, value); + } + + return new StructInitializerNode(structType, initializers); + } + default: + throw new Exception($"Type {type} cannot be initialized with the new keyword"); } - ExpectSymbol(Symbol.CloseParen); - return new ArrayInitializerNode(long.Parse(size.Value), type); } default: throw new Exception($"Unknown symbol: {symbolToken.Symbol}"); @@ -408,7 +465,6 @@ public class Parser private Type ParseType() { var name = ExpectIdentifier().Value; - switch (name) { case "String": @@ -428,7 +484,12 @@ public class Parser } default: { - return PrimitiveType.Parse(name); + if (PrimitiveType.TryParse(name, out var primitiveType)) + { + return primitiveType; + } + + return new StructType(name); } } } diff --git a/Nub.Lang/Nub.Lang/Frontend/Parsing/StructDefinitionNode.cs b/Nub.Lang/Nub.Lang/Frontend/Parsing/StructDefinitionNode.cs new file mode 100644 index 0000000..d4eaada --- /dev/null +++ b/Nub.Lang/Nub.Lang/Frontend/Parsing/StructDefinitionNode.cs @@ -0,0 +1,7 @@ +namespace Nub.Lang.Frontend.Parsing; + +public class StructDefinitionNode(string name, List members) : DefinitionNode +{ + public string Name { get; } = name; + public List Members { get; } = members; +} \ No newline at end of file diff --git a/Nub.Lang/Nub.Lang/Frontend/Parsing/StructInitializerNode.cs b/Nub.Lang/Nub.Lang/Frontend/Parsing/StructInitializerNode.cs new file mode 100644 index 0000000..a04df84 --- /dev/null +++ b/Nub.Lang/Nub.Lang/Frontend/Parsing/StructInitializerNode.cs @@ -0,0 +1,7 @@ +namespace Nub.Lang.Frontend.Parsing; + +public class StructInitializerNode(StructType structType, Dictionary initializers) : ExpressionNode +{ + public StructType StructType { get; } = structType; + public Dictionary Initializers { get; } = initializers; +} \ No newline at end of file diff --git a/Nub.Lang/Nub.Lang/Frontend/Typing/ExpressionTyper.cs b/Nub.Lang/Nub.Lang/Frontend/Typing/ExpressionTyper.cs index bad0be9..eab4ef0 100644 --- a/Nub.Lang/Nub.Lang/Frontend/Typing/ExpressionTyper.cs +++ b/Nub.Lang/Nub.Lang/Frontend/Typing/ExpressionTyper.cs @@ -15,6 +15,7 @@ public class ExpressionTyper { private readonly List _functions; private readonly List _variableDefinitions; + private readonly List _classes; private readonly Stack _variables; public ExpressionTyper(List definitions) @@ -23,6 +24,8 @@ public class ExpressionTyper _functions = []; _variableDefinitions = []; + _classes = definitions.OfType().ToList(); + var functions = definitions .OfType() .Select(f => new Func(f.Name, f.Parameters, f.Body, f.ReturnType)) @@ -42,6 +45,17 @@ public class ExpressionTyper { _variables.Clear(); + foreach (var @class in _classes) + { + foreach (var variable in @class.Members) + { + if (variable.Value.HasValue) + { + PopulateExpression(variable.Value.Value); + } + } + } + foreach (var variable in _variableDefinitions) { PopulateExpression(variable.Value); @@ -199,6 +213,9 @@ public class ExpressionTyper case LiteralNode literal: PopulateLiteral(literal); break; + case StructInitializerNode structInitializer: + PopulateStructInitializer(structInitializer); + break; case SyscallExpressionNode syscall: PopulateSyscallExpression(syscall); break; @@ -296,6 +313,16 @@ public class ExpressionTyper literal.Type = literal.LiteralType; } + private void PopulateStructInitializer(StructInitializerNode structInitializer) + { + foreach (var initializer in structInitializer.Initializers) + { + PopulateExpression(initializer.Value); + } + + structInitializer.Type = structInitializer.StructType; + } + private void PopulateSyscallExpression(SyscallExpressionNode syscall) { foreach (var parameter in syscall.Syscall.Parameters) diff --git a/Nub.Lang/Nub.Lang/StructMember.cs b/Nub.Lang/Nub.Lang/StructMember.cs new file mode 100644 index 0000000..61bd9c0 --- /dev/null +++ b/Nub.Lang/Nub.Lang/StructMember.cs @@ -0,0 +1,11 @@ +using Nub.Core; +using Nub.Lang.Frontend.Parsing; + +namespace Nub.Lang; + +public class StructMember(string name, Type type, Optional value) +{ + public string Name { get; } = name; + public Type Type { get; } = type; + public Optional Value { get; } = value; +} \ No newline at end of file diff --git a/Nub.Lang/Nub.Lang/Type.cs b/Nub.Lang/Nub.Lang/Type.cs index da2196e..76b5a3b 100644 --- a/Nub.Lang/Nub.Lang/Type.cs +++ b/Nub.Lang/Nub.Lang/Type.cs @@ -1,4 +1,6 @@ -namespace Nub.Lang; +using System.Diagnostics.CodeAnalysis; + +namespace Nub.Lang; public abstract class Type { @@ -17,12 +19,14 @@ public abstract class Type protected abstract bool Equals(Type other); public abstract override int GetHashCode(); - public static bool operator == (Type left, Type right) + public static bool operator == (Type? left, Type? right) { + if (left is null && right is null) return true; + if (left is null || right is null) return false; return ReferenceEquals(left, right) || left.Equals(right); } - public static bool operator !=(Type left, Type right) => !(left == right); + public static bool operator !=(Type? left, Type? right) => !(left == right); } public class AnyType : Type @@ -32,13 +36,8 @@ public class AnyType : Type public override string ToString() => "Any"; } -public class PrimitiveType : Type +public class PrimitiveType(PrimitiveTypeKind kind) : Type { - public PrimitiveType(PrimitiveTypeKind kind) - { - Kind = kind; - } - // TODO: This should be looked at more in the future public override bool IsAssignableTo(Type otherType) { @@ -56,20 +55,20 @@ public class PrimitiveType : Type return false; } - public static PrimitiveType Parse(string value) + public static bool TryParse(string value, [NotNullWhen(true)] out PrimitiveType? result) { - var kind = value switch + result = value switch { - "bool" => PrimitiveTypeKind.Bool, - "int64" => PrimitiveTypeKind.Int64, - "int32" => PrimitiveTypeKind.Int32, - _ => throw new ArgumentOutOfRangeException(nameof(value), value, null) + "bool" => new PrimitiveType(PrimitiveTypeKind.Bool), + "int64" => new PrimitiveType(PrimitiveTypeKind.Int64), + "int32" => new PrimitiveType(PrimitiveTypeKind.Int32), + _ => null }; - return new PrimitiveType(kind); + return result != null; } - - public PrimitiveTypeKind Kind { get; } + + public PrimitiveTypeKind Kind { get; } = kind; protected override bool Equals(Type other) => other is PrimitiveType primitiveType && Kind == primitiveType.Kind; public override int GetHashCode() => Kind.GetHashCode(); @@ -90,14 +89,9 @@ public class StringType : Type public override string ToString() => "String"; } -public class ArrayType : Type +public class ArrayType(Type innerType) : Type { - public ArrayType(Type innerType) - { - InnerType = innerType; - } - - public Type InnerType { get; } + public Type InnerType { get; } = innerType; public override bool IsAssignableTo(Type otherType) { @@ -108,4 +102,13 @@ public class ArrayType : Type protected override bool Equals(Type other) => other is ArrayType at && InnerType.Equals(at.InnerType); public override int GetHashCode() => HashCode.Combine(InnerType); public override string ToString() => $"Array<{InnerType}>"; +} + +public class StructType(string name) : Type +{ + public string Name { get; } = name; + + protected override bool Equals(Type other) => other is StructType classType && Name == classType.Name; + public override int GetHashCode() => Name.GetHashCode(); + public override string ToString() => Name; } \ No newline at end of file diff --git a/input/core/print.nub b/input/core/print.nub index bc4d90e..624233a 100644 --- a/input/core/print.nub +++ b/input/core/print.nub @@ -6,12 +6,12 @@ func print(msg: String) { syscall(SYS_WRITE, STD_OUT, msg, str_len(msg)); } -func print(value: int64) { - print(itoa(value)); +func print(value1: int64) { + print(itoa(value1)); } -func print(value: bool) { - if value { +func print(value2: bool) { + if value2 { print("true"); } else { print("false"); @@ -27,12 +27,12 @@ func println(msg: String) { println(); } -func println(value: bool) { - print(value); +func println(value3: bool) { + print(value3); println(); } -func println(value: int64) { - print(value); +func println(value4: int64) { + print(value4); println(); } diff --git a/input/program.nub b/input/program.nub index 116925b..4dc63f7 100644 --- a/input/program.nub +++ b/input/program.nub @@ -1,6 +1,19 @@ import "core"; func main() { + let x = new Test + { + some_string = "test2", + some_int = 69, + }; +} + +struct Test { + let some_string: String; + let some_int: int64; +} + +func example() { let some_string = "test"; println(some_string); @@ -17,4 +30,4 @@ func main() { println(some_array[i]); i = i + 1; } -} +} \ No newline at end of file