Implements struct initializers

This commit is contained in:
nub31
2025-01-31 22:47:55 +01:00
parent ee640ae7a8
commit 16a031823e
11 changed files with 221 additions and 52 deletions

View File

@@ -16,7 +16,7 @@ public class Generator
public Generator(List<DefinitionNode> definitions)
{
_definitions = [];
_definitions = definitions;
_builder = new StringBuilder();
_labelFactory = new LabelFactory();
_symbolTable = new SymbolTable(_labelFactory);
@@ -25,19 +25,16 @@ public class Generator
foreach (var globalVariableDefinition in definitions.OfType<GlobalVariableDefinitionNode>())
{
_symbolTable.DefineGlobalVariable(globalVariableDefinition);
_definitions.Add(globalVariableDefinition);
}
foreach (var funcDefinitionNode in definitions.OfType<ExternFuncDefinitionNode>())
{
_symbolTable.DefineFunc(funcDefinitionNode);
_definitions.Add(funcDefinitionNode);
}
foreach (var funcDefinitionNode in definitions.OfType<LocalFuncDefinitionNode>())
{
_symbolTable.DefineFunc(funcDefinitionNode);
_definitions.Add(funcDefinitionNode);
}
}
@@ -53,7 +50,7 @@ public class Generator
_builder.AppendLine();
_builder.AppendLine("section .text");
// TODO: Only add start label if main is present
// TODO: Only add start label if entrypoint is present, otherwise assume library
var main = _symbolTable.ResolveLocalFunc(Entrypoint, []);
_builder.AppendLine("_start:");
@@ -102,16 +99,15 @@ public class Generator
foreach (var str in _symbolTable.Strings)
{
_builder.AppendLine($"{str.Key}: db `{str.Value}`, 0");
_builder.AppendLine($" {str.Key}: db `{str.Value}`, 0");
}
Dictionary<string, string> completed = [];
foreach (var globalVariableDefinition in _definitions.OfType<GlobalVariableDefinitionNode>())
{
var variable = _symbolTable.ResolveGlobalVariable(globalVariableDefinition.Name);
var evaluated = EvaluateExpression(globalVariableDefinition.Value, completed);
_builder.AppendLine($"{variable.Identifier}: dq {evaluated}");
_builder.AppendLine($" {variable.Identifier}: dq {evaluated}");
completed[variable.Name] = evaluated;
}
@@ -341,7 +337,7 @@ public class Generator
GenerateArrayIndexAccess(arrayIndexAccess, func);
break;
case ArrayInitializerNode arrayInitializer:
GenerateArrayInitializer(arrayInitializer, func);
GenerateArrayInitializer(arrayInitializer);
break;
case BinaryExpressionNode binaryExpression:
GenerateBinaryExpression(binaryExpression, func);
@@ -355,6 +351,9 @@ public class Generator
case LiteralNode literal:
GenerateLiteral(literal);
break;
case StructInitializerNode structInitializer:
GenerateStructInitializer(structInitializer, func);
break;
case SyscallExpressionNode syscallExpression:
GenerateSyscall(syscallExpression.Syscall, func);
break;
@@ -369,7 +368,7 @@ public class Generator
_builder.AppendLine(" mov rax, [rax]");
}
private void GenerateArrayInitializer(ArrayInitializerNode arrayInitializer, LocalFunc func)
private void GenerateArrayInitializer(ArrayInitializerNode arrayInitializer)
{
_builder.AppendLine($" sub rsp, {8 + arrayInitializer.Length * 8}");
_builder.AppendLine(" mov rax, rsp");
@@ -591,6 +590,45 @@ public class Generator
}
}
private void GenerateStructInitializer(StructInitializerNode structInitializer, LocalFunc func)
{
var structDefinition = _definitions
.OfType<StructDefinitionNode>()
.FirstOrDefault(sd => sd.Name == structInitializer.StructType.Name);
if (structDefinition == null)
{
throw new Exception($"Struct {structInitializer.StructType} is not defined");
}
_builder.AppendLine($" add rsp, {structDefinition.Members.Count * 8}");
foreach (var initializer in structInitializer.Initializers)
{
GenerateExpression(initializer.Value, func);
var index = structDefinition.Members.FindIndex(sd => sd.Name == initializer.Key);
if (index == -1)
{
throw new Exception($"Member {initializer.Key} is not defined on struct {structInitializer.StructType}");
}
_builder.AppendLine($" mov [rsp + {index * 8}], rax");
}
foreach (var uninitializedMember in structDefinition.Members.Where(m => !structInitializer.Initializers.ContainsKey(m.Name)))
{
if (!uninitializedMember.Value.HasValue)
{
throw new Exception($"Struct {structInitializer.StructType} must be initializer with member {uninitializedMember.Name}");
}
GenerateExpression(uninitializedMember.Value.Value, func);
_builder.AppendLine($" mov [rsp + {structDefinition.Members.IndexOf(uninitializedMember) * 8}], rax");
}
_builder.AppendLine(" mov rax, rsp");
}
private void GenerateFuncCall(FuncCall funcCall, LocalFunc func)
{
var symbol = _symbolTable.ResolveFunc(funcCall.Name, funcCall.Parameters.Select(p => p.Type).ToList());

View File

@@ -17,6 +17,7 @@ public class Lexer
["continue"] = Symbol.Continue,
["return"] = Symbol.Return,
["new"] = Symbol.New,
["struct"] = Symbol.Struct,
};
private static readonly Dictionary<char[], Symbol> Chians = new()

View File

@@ -41,4 +41,5 @@ public enum Symbol
Star,
ForwardSlash,
New,
Struct
}

View File

@@ -47,6 +47,7 @@ public class Parser
Symbol.Let => ParseGlobalVariableDefinition(),
Symbol.Func => ParseFuncDefinition(),
Symbol.Extern => ParseExternFuncDefinition(),
Symbol.Struct => ParseStruct(),
_ => throw new Exception("Unexpected symbol: " + keyword.Symbol)
};
}
@@ -112,6 +113,36 @@ public class Parser
return new ExternFuncDefinitionNode(name.Value, parameters, returnType);
}
private StructDefinitionNode ParseStruct()
{
var name = ExpectIdentifier().Value;
ExpectSymbol(Symbol.OpenBrace);
List<StructMember> variables = [];
while (!TryExpectSymbol(Symbol.CloseBrace))
{
ExpectSymbol(Symbol.Let);
var variableName = ExpectIdentifier().Value;
ExpectSymbol(Symbol.Colon);
var variableType = ParseType();
var variableValue = Optional<ExpressionNode>.Empty();
if (TryExpectSymbol(Symbol.Assign))
{
variableValue = ParseExpression();
}
ExpectSymbol(Symbol.Semicolon);
variables.Add(new StructMember(variableName, variableType, variableValue));
}
return new StructDefinitionNode(name, variables);
}
private FuncParameter ParseFuncParameter()
{
var name = ExpectIdentifier();
@@ -346,6 +377,12 @@ public class Parser
case Symbol.New:
{
var type = ParseType();
switch (type)
{
// TODO: Parse arrays differently
case ArrayType:
{
ExpectSymbol(Symbol.OpenParen);
var size = ExpectLiteral();
if (size.Type is not PrimitiveType { Kind: PrimitiveTypeKind.Int64 })
@@ -353,8 +390,28 @@ public class Parser
throw new Exception($"Array initializer size must be an {PrimitiveTypeKind.Int64}");
}
ExpectSymbol(Symbol.CloseParen);
return new ArrayInitializerNode(long.Parse(size.Value), type);
}
case StructType structType:
{
Dictionary<string, ExpressionNode> initializers = [];
ExpectSymbol(Symbol.OpenBrace);
while (!TryExpectSymbol(Symbol.CloseBrace))
{
var name = ExpectIdentifier().Value;
ExpectSymbol(Symbol.Assign);
var value = ParseExpression();
TryExpectSymbol(Symbol.Comma);
initializers.Add(name, value);
}
return new StructInitializerNode(structType, initializers);
}
default:
throw new Exception($"Type {type} cannot be initialized with the new keyword");
}
}
default:
throw new Exception($"Unknown symbol: {symbolToken.Symbol}");
}
@@ -408,7 +465,6 @@ public class Parser
private Type ParseType()
{
var name = ExpectIdentifier().Value;
switch (name)
{
case "String":
@@ -428,7 +484,12 @@ public class Parser
}
default:
{
return PrimitiveType.Parse(name);
if (PrimitiveType.TryParse(name, out var primitiveType))
{
return primitiveType;
}
return new StructType(name);
}
}
}

View File

@@ -0,0 +1,7 @@
namespace Nub.Lang.Frontend.Parsing;
public class StructDefinitionNode(string name, List<StructMember> members) : DefinitionNode
{
public string Name { get; } = name;
public List<StructMember> Members { get; } = members;
}

View File

@@ -0,0 +1,7 @@
namespace Nub.Lang.Frontend.Parsing;
public class StructInitializerNode(StructType structType, Dictionary<string, ExpressionNode> initializers) : ExpressionNode
{
public StructType StructType { get; } = structType;
public Dictionary<string, ExpressionNode> Initializers { get; } = initializers;
}

View File

@@ -15,6 +15,7 @@ public class ExpressionTyper
{
private readonly List<Func> _functions;
private readonly List<GlobalVariableDefinitionNode> _variableDefinitions;
private readonly List<StructDefinitionNode> _classes;
private readonly Stack<Variable> _variables;
public ExpressionTyper(List<DefinitionNode> definitions)
@@ -23,6 +24,8 @@ public class ExpressionTyper
_functions = [];
_variableDefinitions = [];
_classes = definitions.OfType<StructDefinitionNode>().ToList();
var functions = definitions
.OfType<LocalFuncDefinitionNode>()
.Select(f => new Func(f.Name, f.Parameters, f.Body, f.ReturnType))
@@ -42,6 +45,17 @@ public class ExpressionTyper
{
_variables.Clear();
foreach (var @class in _classes)
{
foreach (var variable in @class.Members)
{
if (variable.Value.HasValue)
{
PopulateExpression(variable.Value.Value);
}
}
}
foreach (var variable in _variableDefinitions)
{
PopulateExpression(variable.Value);
@@ -199,6 +213,9 @@ public class ExpressionTyper
case LiteralNode literal:
PopulateLiteral(literal);
break;
case StructInitializerNode structInitializer:
PopulateStructInitializer(structInitializer);
break;
case SyscallExpressionNode syscall:
PopulateSyscallExpression(syscall);
break;
@@ -296,6 +313,16 @@ public class ExpressionTyper
literal.Type = literal.LiteralType;
}
private void PopulateStructInitializer(StructInitializerNode structInitializer)
{
foreach (var initializer in structInitializer.Initializers)
{
PopulateExpression(initializer.Value);
}
structInitializer.Type = structInitializer.StructType;
}
private void PopulateSyscallExpression(SyscallExpressionNode syscall)
{
foreach (var parameter in syscall.Syscall.Parameters)

View File

@@ -0,0 +1,11 @@
using Nub.Core;
using Nub.Lang.Frontend.Parsing;
namespace Nub.Lang;
public class StructMember(string name, Type type, Optional<ExpressionNode> value)
{
public string Name { get; } = name;
public Type Type { get; } = type;
public Optional<ExpressionNode> Value { get; } = value;
}

View File

@@ -1,4 +1,6 @@
namespace Nub.Lang;
using System.Diagnostics.CodeAnalysis;
namespace Nub.Lang;
public abstract class Type
{
@@ -17,12 +19,14 @@ public abstract class Type
protected abstract bool Equals(Type other);
public abstract override int GetHashCode();
public static bool operator == (Type left, Type right)
public static bool operator == (Type? left, Type? right)
{
if (left is null && right is null) return true;
if (left is null || right is null) return false;
return ReferenceEquals(left, right) || left.Equals(right);
}
public static bool operator !=(Type left, Type right) => !(left == right);
public static bool operator !=(Type? left, Type? right) => !(left == right);
}
public class AnyType : Type
@@ -32,13 +36,8 @@ public class AnyType : Type
public override string ToString() => "Any";
}
public class PrimitiveType : Type
public class PrimitiveType(PrimitiveTypeKind kind) : Type
{
public PrimitiveType(PrimitiveTypeKind kind)
{
Kind = kind;
}
// TODO: This should be looked at more in the future
public override bool IsAssignableTo(Type otherType)
{
@@ -56,20 +55,20 @@ public class PrimitiveType : Type
return false;
}
public static PrimitiveType Parse(string value)
public static bool TryParse(string value, [NotNullWhen(true)] out PrimitiveType? result)
{
var kind = value switch
result = value switch
{
"bool" => PrimitiveTypeKind.Bool,
"int64" => PrimitiveTypeKind.Int64,
"int32" => PrimitiveTypeKind.Int32,
_ => throw new ArgumentOutOfRangeException(nameof(value), value, null)
"bool" => new PrimitiveType(PrimitiveTypeKind.Bool),
"int64" => new PrimitiveType(PrimitiveTypeKind.Int64),
"int32" => new PrimitiveType(PrimitiveTypeKind.Int32),
_ => null
};
return new PrimitiveType(kind);
return result != null;
}
public PrimitiveTypeKind Kind { get; }
public PrimitiveTypeKind Kind { get; } = kind;
protected override bool Equals(Type other) => other is PrimitiveType primitiveType && Kind == primitiveType.Kind;
public override int GetHashCode() => Kind.GetHashCode();
@@ -90,14 +89,9 @@ public class StringType : Type
public override string ToString() => "String";
}
public class ArrayType : Type
public class ArrayType(Type innerType) : Type
{
public ArrayType(Type innerType)
{
InnerType = innerType;
}
public Type InnerType { get; }
public Type InnerType { get; } = innerType;
public override bool IsAssignableTo(Type otherType)
{
@@ -109,3 +103,12 @@ public class ArrayType : Type
public override int GetHashCode() => HashCode.Combine(InnerType);
public override string ToString() => $"Array<{InnerType}>";
}
public class StructType(string name) : Type
{
public string Name { get; } = name;
protected override bool Equals(Type other) => other is StructType classType && Name == classType.Name;
public override int GetHashCode() => Name.GetHashCode();
public override string ToString() => Name;
}

View File

@@ -6,12 +6,12 @@ func print(msg: String) {
syscall(SYS_WRITE, STD_OUT, msg, str_len(msg));
}
func print(value: int64) {
print(itoa(value));
func print(value1: int64) {
print(itoa(value1));
}
func print(value: bool) {
if value {
func print(value2: bool) {
if value2 {
print("true");
} else {
print("false");
@@ -27,12 +27,12 @@ func println(msg: String) {
println();
}
func println(value: bool) {
print(value);
func println(value3: bool) {
print(value3);
println();
}
func println(value: int64) {
print(value);
func println(value4: int64) {
print(value4);
println();
}

View File

@@ -1,6 +1,19 @@
import "core";
func main() {
let x = new Test
{
some_string = "test2",
some_int = 69,
};
}
struct Test {
let some_string: String;
let some_int: int64;
}
func example() {
let some_string = "test";
println(some_string);