Implements struct initializers

This commit is contained in:
nub31
2025-01-31 22:47:55 +01:00
parent ee640ae7a8
commit 16a031823e
11 changed files with 221 additions and 52 deletions

View File

@@ -16,7 +16,7 @@ public class Generator
public Generator(List<DefinitionNode> definitions) public Generator(List<DefinitionNode> definitions)
{ {
_definitions = []; _definitions = definitions;
_builder = new StringBuilder(); _builder = new StringBuilder();
_labelFactory = new LabelFactory(); _labelFactory = new LabelFactory();
_symbolTable = new SymbolTable(_labelFactory); _symbolTable = new SymbolTable(_labelFactory);
@@ -25,19 +25,16 @@ public class Generator
foreach (var globalVariableDefinition in definitions.OfType<GlobalVariableDefinitionNode>()) foreach (var globalVariableDefinition in definitions.OfType<GlobalVariableDefinitionNode>())
{ {
_symbolTable.DefineGlobalVariable(globalVariableDefinition); _symbolTable.DefineGlobalVariable(globalVariableDefinition);
_definitions.Add(globalVariableDefinition);
} }
foreach (var funcDefinitionNode in definitions.OfType<ExternFuncDefinitionNode>()) foreach (var funcDefinitionNode in definitions.OfType<ExternFuncDefinitionNode>())
{ {
_symbolTable.DefineFunc(funcDefinitionNode); _symbolTable.DefineFunc(funcDefinitionNode);
_definitions.Add(funcDefinitionNode);
} }
foreach (var funcDefinitionNode in definitions.OfType<LocalFuncDefinitionNode>()) foreach (var funcDefinitionNode in definitions.OfType<LocalFuncDefinitionNode>())
{ {
_symbolTable.DefineFunc(funcDefinitionNode); _symbolTable.DefineFunc(funcDefinitionNode);
_definitions.Add(funcDefinitionNode);
} }
} }
@@ -53,7 +50,7 @@ public class Generator
_builder.AppendLine(); _builder.AppendLine();
_builder.AppendLine("section .text"); _builder.AppendLine("section .text");
// TODO: Only add start label if main is present // TODO: Only add start label if entrypoint is present, otherwise assume library
var main = _symbolTable.ResolveLocalFunc(Entrypoint, []); var main = _symbolTable.ResolveLocalFunc(Entrypoint, []);
_builder.AppendLine("_start:"); _builder.AppendLine("_start:");
@@ -102,16 +99,15 @@ public class Generator
foreach (var str in _symbolTable.Strings) foreach (var str in _symbolTable.Strings)
{ {
_builder.AppendLine($"{str.Key}: db `{str.Value}`, 0"); _builder.AppendLine($" {str.Key}: db `{str.Value}`, 0");
} }
Dictionary<string, string> completed = []; Dictionary<string, string> completed = [];
foreach (var globalVariableDefinition in _definitions.OfType<GlobalVariableDefinitionNode>()) foreach (var globalVariableDefinition in _definitions.OfType<GlobalVariableDefinitionNode>())
{ {
var variable = _symbolTable.ResolveGlobalVariable(globalVariableDefinition.Name); var variable = _symbolTable.ResolveGlobalVariable(globalVariableDefinition.Name);
var evaluated = EvaluateExpression(globalVariableDefinition.Value, completed); var evaluated = EvaluateExpression(globalVariableDefinition.Value, completed);
_builder.AppendLine($"{variable.Identifier}: dq {evaluated}"); _builder.AppendLine($" {variable.Identifier}: dq {evaluated}");
completed[variable.Name] = evaluated; completed[variable.Name] = evaluated;
} }
@@ -341,7 +337,7 @@ public class Generator
GenerateArrayIndexAccess(arrayIndexAccess, func); GenerateArrayIndexAccess(arrayIndexAccess, func);
break; break;
case ArrayInitializerNode arrayInitializer: case ArrayInitializerNode arrayInitializer:
GenerateArrayInitializer(arrayInitializer, func); GenerateArrayInitializer(arrayInitializer);
break; break;
case BinaryExpressionNode binaryExpression: case BinaryExpressionNode binaryExpression:
GenerateBinaryExpression(binaryExpression, func); GenerateBinaryExpression(binaryExpression, func);
@@ -355,6 +351,9 @@ public class Generator
case LiteralNode literal: case LiteralNode literal:
GenerateLiteral(literal); GenerateLiteral(literal);
break; break;
case StructInitializerNode structInitializer:
GenerateStructInitializer(structInitializer, func);
break;
case SyscallExpressionNode syscallExpression: case SyscallExpressionNode syscallExpression:
GenerateSyscall(syscallExpression.Syscall, func); GenerateSyscall(syscallExpression.Syscall, func);
break; break;
@@ -369,7 +368,7 @@ public class Generator
_builder.AppendLine(" mov rax, [rax]"); _builder.AppendLine(" mov rax, [rax]");
} }
private void GenerateArrayInitializer(ArrayInitializerNode arrayInitializer, LocalFunc func) private void GenerateArrayInitializer(ArrayInitializerNode arrayInitializer)
{ {
_builder.AppendLine($" sub rsp, {8 + arrayInitializer.Length * 8}"); _builder.AppendLine($" sub rsp, {8 + arrayInitializer.Length * 8}");
_builder.AppendLine(" mov rax, rsp"); _builder.AppendLine(" mov rax, rsp");
@@ -591,6 +590,45 @@ public class Generator
} }
} }
private void GenerateStructInitializer(StructInitializerNode structInitializer, LocalFunc func)
{
var structDefinition = _definitions
.OfType<StructDefinitionNode>()
.FirstOrDefault(sd => sd.Name == structInitializer.StructType.Name);
if (structDefinition == null)
{
throw new Exception($"Struct {structInitializer.StructType} is not defined");
}
_builder.AppendLine($" add rsp, {structDefinition.Members.Count * 8}");
foreach (var initializer in structInitializer.Initializers)
{
GenerateExpression(initializer.Value, func);
var index = structDefinition.Members.FindIndex(sd => sd.Name == initializer.Key);
if (index == -1)
{
throw new Exception($"Member {initializer.Key} is not defined on struct {structInitializer.StructType}");
}
_builder.AppendLine($" mov [rsp + {index * 8}], rax");
}
foreach (var uninitializedMember in structDefinition.Members.Where(m => !structInitializer.Initializers.ContainsKey(m.Name)))
{
if (!uninitializedMember.Value.HasValue)
{
throw new Exception($"Struct {structInitializer.StructType} must be initializer with member {uninitializedMember.Name}");
}
GenerateExpression(uninitializedMember.Value.Value, func);
_builder.AppendLine($" mov [rsp + {structDefinition.Members.IndexOf(uninitializedMember) * 8}], rax");
}
_builder.AppendLine(" mov rax, rsp");
}
private void GenerateFuncCall(FuncCall funcCall, LocalFunc func) private void GenerateFuncCall(FuncCall funcCall, LocalFunc func)
{ {
var symbol = _symbolTable.ResolveFunc(funcCall.Name, funcCall.Parameters.Select(p => p.Type).ToList()); var symbol = _symbolTable.ResolveFunc(funcCall.Name, funcCall.Parameters.Select(p => p.Type).ToList());

View File

@@ -17,6 +17,7 @@ public class Lexer
["continue"] = Symbol.Continue, ["continue"] = Symbol.Continue,
["return"] = Symbol.Return, ["return"] = Symbol.Return,
["new"] = Symbol.New, ["new"] = Symbol.New,
["struct"] = Symbol.Struct,
}; };
private static readonly Dictionary<char[], Symbol> Chians = new() private static readonly Dictionary<char[], Symbol> Chians = new()

View File

@@ -41,4 +41,5 @@ public enum Symbol
Star, Star,
ForwardSlash, ForwardSlash,
New, New,
Struct
} }

View File

@@ -47,6 +47,7 @@ public class Parser
Symbol.Let => ParseGlobalVariableDefinition(), Symbol.Let => ParseGlobalVariableDefinition(),
Symbol.Func => ParseFuncDefinition(), Symbol.Func => ParseFuncDefinition(),
Symbol.Extern => ParseExternFuncDefinition(), Symbol.Extern => ParseExternFuncDefinition(),
Symbol.Struct => ParseStruct(),
_ => throw new Exception("Unexpected symbol: " + keyword.Symbol) _ => throw new Exception("Unexpected symbol: " + keyword.Symbol)
}; };
} }
@@ -112,6 +113,36 @@ public class Parser
return new ExternFuncDefinitionNode(name.Value, parameters, returnType); return new ExternFuncDefinitionNode(name.Value, parameters, returnType);
} }
private StructDefinitionNode ParseStruct()
{
var name = ExpectIdentifier().Value;
ExpectSymbol(Symbol.OpenBrace);
List<StructMember> variables = [];
while (!TryExpectSymbol(Symbol.CloseBrace))
{
ExpectSymbol(Symbol.Let);
var variableName = ExpectIdentifier().Value;
ExpectSymbol(Symbol.Colon);
var variableType = ParseType();
var variableValue = Optional<ExpressionNode>.Empty();
if (TryExpectSymbol(Symbol.Assign))
{
variableValue = ParseExpression();
}
ExpectSymbol(Symbol.Semicolon);
variables.Add(new StructMember(variableName, variableType, variableValue));
}
return new StructDefinitionNode(name, variables);
}
private FuncParameter ParseFuncParameter() private FuncParameter ParseFuncParameter()
{ {
var name = ExpectIdentifier(); var name = ExpectIdentifier();
@@ -346,14 +377,40 @@ public class Parser
case Symbol.New: case Symbol.New:
{ {
var type = ParseType(); var type = ParseType();
ExpectSymbol(Symbol.OpenParen);
var size = ExpectLiteral(); switch (type)
if (size.Type is not PrimitiveType { Kind: PrimitiveTypeKind.Int64 })
{ {
throw new Exception($"Array initializer size must be an {PrimitiveTypeKind.Int64}"); // TODO: Parse arrays differently
case ArrayType:
{
ExpectSymbol(Symbol.OpenParen);
var size = ExpectLiteral();
if (size.Type is not PrimitiveType { Kind: PrimitiveTypeKind.Int64 })
{
throw new Exception($"Array initializer size must be an {PrimitiveTypeKind.Int64}");
}
ExpectSymbol(Symbol.CloseParen);
return new ArrayInitializerNode(long.Parse(size.Value), type);
}
case StructType structType:
{
Dictionary<string, ExpressionNode> initializers = [];
ExpectSymbol(Symbol.OpenBrace);
while (!TryExpectSymbol(Symbol.CloseBrace))
{
var name = ExpectIdentifier().Value;
ExpectSymbol(Symbol.Assign);
var value = ParseExpression();
TryExpectSymbol(Symbol.Comma);
initializers.Add(name, value);
}
return new StructInitializerNode(structType, initializers);
}
default:
throw new Exception($"Type {type} cannot be initialized with the new keyword");
} }
ExpectSymbol(Symbol.CloseParen);
return new ArrayInitializerNode(long.Parse(size.Value), type);
} }
default: default:
throw new Exception($"Unknown symbol: {symbolToken.Symbol}"); throw new Exception($"Unknown symbol: {symbolToken.Symbol}");
@@ -408,7 +465,6 @@ public class Parser
private Type ParseType() private Type ParseType()
{ {
var name = ExpectIdentifier().Value; var name = ExpectIdentifier().Value;
switch (name) switch (name)
{ {
case "String": case "String":
@@ -428,7 +484,12 @@ public class Parser
} }
default: default:
{ {
return PrimitiveType.Parse(name); if (PrimitiveType.TryParse(name, out var primitiveType))
{
return primitiveType;
}
return new StructType(name);
} }
} }
} }

View File

@@ -0,0 +1,7 @@
namespace Nub.Lang.Frontend.Parsing;
public class StructDefinitionNode(string name, List<StructMember> members) : DefinitionNode
{
public string Name { get; } = name;
public List<StructMember> Members { get; } = members;
}

View File

@@ -0,0 +1,7 @@
namespace Nub.Lang.Frontend.Parsing;
public class StructInitializerNode(StructType structType, Dictionary<string, ExpressionNode> initializers) : ExpressionNode
{
public StructType StructType { get; } = structType;
public Dictionary<string, ExpressionNode> Initializers { get; } = initializers;
}

View File

@@ -15,6 +15,7 @@ public class ExpressionTyper
{ {
private readonly List<Func> _functions; private readonly List<Func> _functions;
private readonly List<GlobalVariableDefinitionNode> _variableDefinitions; private readonly List<GlobalVariableDefinitionNode> _variableDefinitions;
private readonly List<StructDefinitionNode> _classes;
private readonly Stack<Variable> _variables; private readonly Stack<Variable> _variables;
public ExpressionTyper(List<DefinitionNode> definitions) public ExpressionTyper(List<DefinitionNode> definitions)
@@ -23,6 +24,8 @@ public class ExpressionTyper
_functions = []; _functions = [];
_variableDefinitions = []; _variableDefinitions = [];
_classes = definitions.OfType<StructDefinitionNode>().ToList();
var functions = definitions var functions = definitions
.OfType<LocalFuncDefinitionNode>() .OfType<LocalFuncDefinitionNode>()
.Select(f => new Func(f.Name, f.Parameters, f.Body, f.ReturnType)) .Select(f => new Func(f.Name, f.Parameters, f.Body, f.ReturnType))
@@ -42,6 +45,17 @@ public class ExpressionTyper
{ {
_variables.Clear(); _variables.Clear();
foreach (var @class in _classes)
{
foreach (var variable in @class.Members)
{
if (variable.Value.HasValue)
{
PopulateExpression(variable.Value.Value);
}
}
}
foreach (var variable in _variableDefinitions) foreach (var variable in _variableDefinitions)
{ {
PopulateExpression(variable.Value); PopulateExpression(variable.Value);
@@ -199,6 +213,9 @@ public class ExpressionTyper
case LiteralNode literal: case LiteralNode literal:
PopulateLiteral(literal); PopulateLiteral(literal);
break; break;
case StructInitializerNode structInitializer:
PopulateStructInitializer(structInitializer);
break;
case SyscallExpressionNode syscall: case SyscallExpressionNode syscall:
PopulateSyscallExpression(syscall); PopulateSyscallExpression(syscall);
break; break;
@@ -296,6 +313,16 @@ public class ExpressionTyper
literal.Type = literal.LiteralType; literal.Type = literal.LiteralType;
} }
private void PopulateStructInitializer(StructInitializerNode structInitializer)
{
foreach (var initializer in structInitializer.Initializers)
{
PopulateExpression(initializer.Value);
}
structInitializer.Type = structInitializer.StructType;
}
private void PopulateSyscallExpression(SyscallExpressionNode syscall) private void PopulateSyscallExpression(SyscallExpressionNode syscall)
{ {
foreach (var parameter in syscall.Syscall.Parameters) foreach (var parameter in syscall.Syscall.Parameters)

View File

@@ -0,0 +1,11 @@
using Nub.Core;
using Nub.Lang.Frontend.Parsing;
namespace Nub.Lang;
public class StructMember(string name, Type type, Optional<ExpressionNode> value)
{
public string Name { get; } = name;
public Type Type { get; } = type;
public Optional<ExpressionNode> Value { get; } = value;
}

View File

@@ -1,4 +1,6 @@
namespace Nub.Lang; using System.Diagnostics.CodeAnalysis;
namespace Nub.Lang;
public abstract class Type public abstract class Type
{ {
@@ -17,12 +19,14 @@ public abstract class Type
protected abstract bool Equals(Type other); protected abstract bool Equals(Type other);
public abstract override int GetHashCode(); public abstract override int GetHashCode();
public static bool operator == (Type left, Type right) public static bool operator == (Type? left, Type? right)
{ {
if (left is null && right is null) return true;
if (left is null || right is null) return false;
return ReferenceEquals(left, right) || left.Equals(right); return ReferenceEquals(left, right) || left.Equals(right);
} }
public static bool operator !=(Type left, Type right) => !(left == right); public static bool operator !=(Type? left, Type? right) => !(left == right);
} }
public class AnyType : Type public class AnyType : Type
@@ -32,13 +36,8 @@ public class AnyType : Type
public override string ToString() => "Any"; public override string ToString() => "Any";
} }
public class PrimitiveType : Type public class PrimitiveType(PrimitiveTypeKind kind) : Type
{ {
public PrimitiveType(PrimitiveTypeKind kind)
{
Kind = kind;
}
// TODO: This should be looked at more in the future // TODO: This should be looked at more in the future
public override bool IsAssignableTo(Type otherType) public override bool IsAssignableTo(Type otherType)
{ {
@@ -56,20 +55,20 @@ public class PrimitiveType : Type
return false; return false;
} }
public static PrimitiveType Parse(string value) public static bool TryParse(string value, [NotNullWhen(true)] out PrimitiveType? result)
{ {
var kind = value switch result = value switch
{ {
"bool" => PrimitiveTypeKind.Bool, "bool" => new PrimitiveType(PrimitiveTypeKind.Bool),
"int64" => PrimitiveTypeKind.Int64, "int64" => new PrimitiveType(PrimitiveTypeKind.Int64),
"int32" => PrimitiveTypeKind.Int32, "int32" => new PrimitiveType(PrimitiveTypeKind.Int32),
_ => throw new ArgumentOutOfRangeException(nameof(value), value, null) _ => null
}; };
return new PrimitiveType(kind); return result != null;
} }
public PrimitiveTypeKind Kind { get; } public PrimitiveTypeKind Kind { get; } = kind;
protected override bool Equals(Type other) => other is PrimitiveType primitiveType && Kind == primitiveType.Kind; protected override bool Equals(Type other) => other is PrimitiveType primitiveType && Kind == primitiveType.Kind;
public override int GetHashCode() => Kind.GetHashCode(); public override int GetHashCode() => Kind.GetHashCode();
@@ -90,14 +89,9 @@ public class StringType : Type
public override string ToString() => "String"; public override string ToString() => "String";
} }
public class ArrayType : Type public class ArrayType(Type innerType) : Type
{ {
public ArrayType(Type innerType) public Type InnerType { get; } = innerType;
{
InnerType = innerType;
}
public Type InnerType { get; }
public override bool IsAssignableTo(Type otherType) public override bool IsAssignableTo(Type otherType)
{ {
@@ -108,4 +102,13 @@ public class ArrayType : Type
protected override bool Equals(Type other) => other is ArrayType at && InnerType.Equals(at.InnerType); protected override bool Equals(Type other) => other is ArrayType at && InnerType.Equals(at.InnerType);
public override int GetHashCode() => HashCode.Combine(InnerType); public override int GetHashCode() => HashCode.Combine(InnerType);
public override string ToString() => $"Array<{InnerType}>"; public override string ToString() => $"Array<{InnerType}>";
}
public class StructType(string name) : Type
{
public string Name { get; } = name;
protected override bool Equals(Type other) => other is StructType classType && Name == classType.Name;
public override int GetHashCode() => Name.GetHashCode();
public override string ToString() => Name;
} }

View File

@@ -6,12 +6,12 @@ func print(msg: String) {
syscall(SYS_WRITE, STD_OUT, msg, str_len(msg)); syscall(SYS_WRITE, STD_OUT, msg, str_len(msg));
} }
func print(value: int64) { func print(value1: int64) {
print(itoa(value)); print(itoa(value1));
} }
func print(value: bool) { func print(value2: bool) {
if value { if value2 {
print("true"); print("true");
} else { } else {
print("false"); print("false");
@@ -27,12 +27,12 @@ func println(msg: String) {
println(); println();
} }
func println(value: bool) { func println(value3: bool) {
print(value); print(value3);
println(); println();
} }
func println(value: int64) { func println(value4: int64) {
print(value); print(value4);
println(); println();
} }

View File

@@ -1,6 +1,19 @@
import "core"; import "core";
func main() { func main() {
let x = new Test
{
some_string = "test2",
some_int = 69,
};
}
struct Test {
let some_string: String;
let some_int: int64;
}
func example() {
let some_string = "test"; let some_string = "test";
println(some_string); println(some_string);
@@ -17,4 +30,4 @@ func main() {
println(some_array[i]); println(some_array[i]);
i = i + 1; i = i + 1;
} }
} }