Implements struct initializers
This commit is contained in:
@@ -16,7 +16,7 @@ public class Generator
|
||||
|
||||
public Generator(List<DefinitionNode> definitions)
|
||||
{
|
||||
_definitions = [];
|
||||
_definitions = definitions;
|
||||
_builder = new StringBuilder();
|
||||
_labelFactory = new LabelFactory();
|
||||
_symbolTable = new SymbolTable(_labelFactory);
|
||||
@@ -25,19 +25,16 @@ public class Generator
|
||||
foreach (var globalVariableDefinition in definitions.OfType<GlobalVariableDefinitionNode>())
|
||||
{
|
||||
_symbolTable.DefineGlobalVariable(globalVariableDefinition);
|
||||
_definitions.Add(globalVariableDefinition);
|
||||
}
|
||||
|
||||
foreach (var funcDefinitionNode in definitions.OfType<ExternFuncDefinitionNode>())
|
||||
{
|
||||
_symbolTable.DefineFunc(funcDefinitionNode);
|
||||
_definitions.Add(funcDefinitionNode);
|
||||
}
|
||||
|
||||
foreach (var funcDefinitionNode in definitions.OfType<LocalFuncDefinitionNode>())
|
||||
{
|
||||
_symbolTable.DefineFunc(funcDefinitionNode);
|
||||
_definitions.Add(funcDefinitionNode);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -53,7 +50,7 @@ public class Generator
|
||||
_builder.AppendLine();
|
||||
_builder.AppendLine("section .text");
|
||||
|
||||
// TODO: Only add start label if main is present
|
||||
// TODO: Only add start label if entrypoint is present, otherwise assume library
|
||||
var main = _symbolTable.ResolveLocalFunc(Entrypoint, []);
|
||||
|
||||
_builder.AppendLine("_start:");
|
||||
@@ -102,16 +99,15 @@ public class Generator
|
||||
|
||||
foreach (var str in _symbolTable.Strings)
|
||||
{
|
||||
_builder.AppendLine($"{str.Key}: db `{str.Value}`, 0");
|
||||
_builder.AppendLine($" {str.Key}: db `{str.Value}`, 0");
|
||||
}
|
||||
|
||||
|
||||
Dictionary<string, string> completed = [];
|
||||
foreach (var globalVariableDefinition in _definitions.OfType<GlobalVariableDefinitionNode>())
|
||||
{
|
||||
var variable = _symbolTable.ResolveGlobalVariable(globalVariableDefinition.Name);
|
||||
var evaluated = EvaluateExpression(globalVariableDefinition.Value, completed);
|
||||
_builder.AppendLine($"{variable.Identifier}: dq {evaluated}");
|
||||
_builder.AppendLine($" {variable.Identifier}: dq {evaluated}");
|
||||
completed[variable.Name] = evaluated;
|
||||
}
|
||||
|
||||
@@ -341,7 +337,7 @@ public class Generator
|
||||
GenerateArrayIndexAccess(arrayIndexAccess, func);
|
||||
break;
|
||||
case ArrayInitializerNode arrayInitializer:
|
||||
GenerateArrayInitializer(arrayInitializer, func);
|
||||
GenerateArrayInitializer(arrayInitializer);
|
||||
break;
|
||||
case BinaryExpressionNode binaryExpression:
|
||||
GenerateBinaryExpression(binaryExpression, func);
|
||||
@@ -355,6 +351,9 @@ public class Generator
|
||||
case LiteralNode literal:
|
||||
GenerateLiteral(literal);
|
||||
break;
|
||||
case StructInitializerNode structInitializer:
|
||||
GenerateStructInitializer(structInitializer, func);
|
||||
break;
|
||||
case SyscallExpressionNode syscallExpression:
|
||||
GenerateSyscall(syscallExpression.Syscall, func);
|
||||
break;
|
||||
@@ -369,7 +368,7 @@ public class Generator
|
||||
_builder.AppendLine(" mov rax, [rax]");
|
||||
}
|
||||
|
||||
private void GenerateArrayInitializer(ArrayInitializerNode arrayInitializer, LocalFunc func)
|
||||
private void GenerateArrayInitializer(ArrayInitializerNode arrayInitializer)
|
||||
{
|
||||
_builder.AppendLine($" sub rsp, {8 + arrayInitializer.Length * 8}");
|
||||
_builder.AppendLine(" mov rax, rsp");
|
||||
@@ -591,6 +590,45 @@ public class Generator
|
||||
}
|
||||
}
|
||||
|
||||
private void GenerateStructInitializer(StructInitializerNode structInitializer, LocalFunc func)
|
||||
{
|
||||
var structDefinition = _definitions
|
||||
.OfType<StructDefinitionNode>()
|
||||
.FirstOrDefault(sd => sd.Name == structInitializer.StructType.Name);
|
||||
|
||||
if (structDefinition == null)
|
||||
{
|
||||
throw new Exception($"Struct {structInitializer.StructType} is not defined");
|
||||
}
|
||||
|
||||
_builder.AppendLine($" add rsp, {structDefinition.Members.Count * 8}");
|
||||
|
||||
foreach (var initializer in structInitializer.Initializers)
|
||||
{
|
||||
GenerateExpression(initializer.Value, func);
|
||||
var index = structDefinition.Members.FindIndex(sd => sd.Name == initializer.Key);
|
||||
if (index == -1)
|
||||
{
|
||||
throw new Exception($"Member {initializer.Key} is not defined on struct {structInitializer.StructType}");
|
||||
}
|
||||
|
||||
_builder.AppendLine($" mov [rsp + {index * 8}], rax");
|
||||
}
|
||||
|
||||
foreach (var uninitializedMember in structDefinition.Members.Where(m => !structInitializer.Initializers.ContainsKey(m.Name)))
|
||||
{
|
||||
if (!uninitializedMember.Value.HasValue)
|
||||
{
|
||||
throw new Exception($"Struct {structInitializer.StructType} must be initializer with member {uninitializedMember.Name}");
|
||||
}
|
||||
|
||||
GenerateExpression(uninitializedMember.Value.Value, func);
|
||||
_builder.AppendLine($" mov [rsp + {structDefinition.Members.IndexOf(uninitializedMember) * 8}], rax");
|
||||
}
|
||||
|
||||
_builder.AppendLine(" mov rax, rsp");
|
||||
}
|
||||
|
||||
private void GenerateFuncCall(FuncCall funcCall, LocalFunc func)
|
||||
{
|
||||
var symbol = _symbolTable.ResolveFunc(funcCall.Name, funcCall.Parameters.Select(p => p.Type).ToList());
|
||||
|
||||
@@ -17,6 +17,7 @@ public class Lexer
|
||||
["continue"] = Symbol.Continue,
|
||||
["return"] = Symbol.Return,
|
||||
["new"] = Symbol.New,
|
||||
["struct"] = Symbol.Struct,
|
||||
};
|
||||
|
||||
private static readonly Dictionary<char[], Symbol> Chians = new()
|
||||
|
||||
@@ -41,4 +41,5 @@ public enum Symbol
|
||||
Star,
|
||||
ForwardSlash,
|
||||
New,
|
||||
Struct
|
||||
}
|
||||
@@ -47,6 +47,7 @@ public class Parser
|
||||
Symbol.Let => ParseGlobalVariableDefinition(),
|
||||
Symbol.Func => ParseFuncDefinition(),
|
||||
Symbol.Extern => ParseExternFuncDefinition(),
|
||||
Symbol.Struct => ParseStruct(),
|
||||
_ => throw new Exception("Unexpected symbol: " + keyword.Symbol)
|
||||
};
|
||||
}
|
||||
@@ -112,6 +113,36 @@ public class Parser
|
||||
return new ExternFuncDefinitionNode(name.Value, parameters, returnType);
|
||||
}
|
||||
|
||||
private StructDefinitionNode ParseStruct()
|
||||
{
|
||||
var name = ExpectIdentifier().Value;
|
||||
|
||||
ExpectSymbol(Symbol.OpenBrace);
|
||||
|
||||
List<StructMember> variables = [];
|
||||
|
||||
while (!TryExpectSymbol(Symbol.CloseBrace))
|
||||
{
|
||||
ExpectSymbol(Symbol.Let);
|
||||
var variableName = ExpectIdentifier().Value;
|
||||
ExpectSymbol(Symbol.Colon);
|
||||
var variableType = ParseType();
|
||||
|
||||
var variableValue = Optional<ExpressionNode>.Empty();
|
||||
|
||||
if (TryExpectSymbol(Symbol.Assign))
|
||||
{
|
||||
variableValue = ParseExpression();
|
||||
}
|
||||
|
||||
ExpectSymbol(Symbol.Semicolon);
|
||||
|
||||
variables.Add(new StructMember(variableName, variableType, variableValue));
|
||||
}
|
||||
|
||||
return new StructDefinitionNode(name, variables);
|
||||
}
|
||||
|
||||
private FuncParameter ParseFuncParameter()
|
||||
{
|
||||
var name = ExpectIdentifier();
|
||||
@@ -346,6 +377,12 @@ public class Parser
|
||||
case Symbol.New:
|
||||
{
|
||||
var type = ParseType();
|
||||
|
||||
switch (type)
|
||||
{
|
||||
// TODO: Parse arrays differently
|
||||
case ArrayType:
|
||||
{
|
||||
ExpectSymbol(Symbol.OpenParen);
|
||||
var size = ExpectLiteral();
|
||||
if (size.Type is not PrimitiveType { Kind: PrimitiveTypeKind.Int64 })
|
||||
@@ -353,8 +390,28 @@ public class Parser
|
||||
throw new Exception($"Array initializer size must be an {PrimitiveTypeKind.Int64}");
|
||||
}
|
||||
ExpectSymbol(Symbol.CloseParen);
|
||||
|
||||
return new ArrayInitializerNode(long.Parse(size.Value), type);
|
||||
}
|
||||
case StructType structType:
|
||||
{
|
||||
Dictionary<string, ExpressionNode> initializers = [];
|
||||
ExpectSymbol(Symbol.OpenBrace);
|
||||
while (!TryExpectSymbol(Symbol.CloseBrace))
|
||||
{
|
||||
var name = ExpectIdentifier().Value;
|
||||
ExpectSymbol(Symbol.Assign);
|
||||
var value = ParseExpression();
|
||||
TryExpectSymbol(Symbol.Comma);
|
||||
initializers.Add(name, value);
|
||||
}
|
||||
|
||||
return new StructInitializerNode(structType, initializers);
|
||||
}
|
||||
default:
|
||||
throw new Exception($"Type {type} cannot be initialized with the new keyword");
|
||||
}
|
||||
}
|
||||
default:
|
||||
throw new Exception($"Unknown symbol: {symbolToken.Symbol}");
|
||||
}
|
||||
@@ -408,7 +465,6 @@ public class Parser
|
||||
private Type ParseType()
|
||||
{
|
||||
var name = ExpectIdentifier().Value;
|
||||
|
||||
switch (name)
|
||||
{
|
||||
case "String":
|
||||
@@ -428,7 +484,12 @@ public class Parser
|
||||
}
|
||||
default:
|
||||
{
|
||||
return PrimitiveType.Parse(name);
|
||||
if (PrimitiveType.TryParse(name, out var primitiveType))
|
||||
{
|
||||
return primitiveType;
|
||||
}
|
||||
|
||||
return new StructType(name);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,7 @@
|
||||
namespace Nub.Lang.Frontend.Parsing;
|
||||
|
||||
public class StructDefinitionNode(string name, List<StructMember> members) : DefinitionNode
|
||||
{
|
||||
public string Name { get; } = name;
|
||||
public List<StructMember> Members { get; } = members;
|
||||
}
|
||||
@@ -0,0 +1,7 @@
|
||||
namespace Nub.Lang.Frontend.Parsing;
|
||||
|
||||
public class StructInitializerNode(StructType structType, Dictionary<string, ExpressionNode> initializers) : ExpressionNode
|
||||
{
|
||||
public StructType StructType { get; } = structType;
|
||||
public Dictionary<string, ExpressionNode> Initializers { get; } = initializers;
|
||||
}
|
||||
@@ -15,6 +15,7 @@ public class ExpressionTyper
|
||||
{
|
||||
private readonly List<Func> _functions;
|
||||
private readonly List<GlobalVariableDefinitionNode> _variableDefinitions;
|
||||
private readonly List<StructDefinitionNode> _classes;
|
||||
private readonly Stack<Variable> _variables;
|
||||
|
||||
public ExpressionTyper(List<DefinitionNode> definitions)
|
||||
@@ -23,6 +24,8 @@ public class ExpressionTyper
|
||||
_functions = [];
|
||||
_variableDefinitions = [];
|
||||
|
||||
_classes = definitions.OfType<StructDefinitionNode>().ToList();
|
||||
|
||||
var functions = definitions
|
||||
.OfType<LocalFuncDefinitionNode>()
|
||||
.Select(f => new Func(f.Name, f.Parameters, f.Body, f.ReturnType))
|
||||
@@ -42,6 +45,17 @@ public class ExpressionTyper
|
||||
{
|
||||
_variables.Clear();
|
||||
|
||||
foreach (var @class in _classes)
|
||||
{
|
||||
foreach (var variable in @class.Members)
|
||||
{
|
||||
if (variable.Value.HasValue)
|
||||
{
|
||||
PopulateExpression(variable.Value.Value);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
foreach (var variable in _variableDefinitions)
|
||||
{
|
||||
PopulateExpression(variable.Value);
|
||||
@@ -199,6 +213,9 @@ public class ExpressionTyper
|
||||
case LiteralNode literal:
|
||||
PopulateLiteral(literal);
|
||||
break;
|
||||
case StructInitializerNode structInitializer:
|
||||
PopulateStructInitializer(structInitializer);
|
||||
break;
|
||||
case SyscallExpressionNode syscall:
|
||||
PopulateSyscallExpression(syscall);
|
||||
break;
|
||||
@@ -296,6 +313,16 @@ public class ExpressionTyper
|
||||
literal.Type = literal.LiteralType;
|
||||
}
|
||||
|
||||
private void PopulateStructInitializer(StructInitializerNode structInitializer)
|
||||
{
|
||||
foreach (var initializer in structInitializer.Initializers)
|
||||
{
|
||||
PopulateExpression(initializer.Value);
|
||||
}
|
||||
|
||||
structInitializer.Type = structInitializer.StructType;
|
||||
}
|
||||
|
||||
private void PopulateSyscallExpression(SyscallExpressionNode syscall)
|
||||
{
|
||||
foreach (var parameter in syscall.Syscall.Parameters)
|
||||
|
||||
11
Nub.Lang/Nub.Lang/StructMember.cs
Normal file
11
Nub.Lang/Nub.Lang/StructMember.cs
Normal file
@@ -0,0 +1,11 @@
|
||||
using Nub.Core;
|
||||
using Nub.Lang.Frontend.Parsing;
|
||||
|
||||
namespace Nub.Lang;
|
||||
|
||||
public class StructMember(string name, Type type, Optional<ExpressionNode> value)
|
||||
{
|
||||
public string Name { get; } = name;
|
||||
public Type Type { get; } = type;
|
||||
public Optional<ExpressionNode> Value { get; } = value;
|
||||
}
|
||||
@@ -1,4 +1,6 @@
|
||||
namespace Nub.Lang;
|
||||
using System.Diagnostics.CodeAnalysis;
|
||||
|
||||
namespace Nub.Lang;
|
||||
|
||||
public abstract class Type
|
||||
{
|
||||
@@ -17,12 +19,14 @@ public abstract class Type
|
||||
protected abstract bool Equals(Type other);
|
||||
public abstract override int GetHashCode();
|
||||
|
||||
public static bool operator == (Type left, Type right)
|
||||
public static bool operator == (Type? left, Type? right)
|
||||
{
|
||||
if (left is null && right is null) return true;
|
||||
if (left is null || right is null) return false;
|
||||
return ReferenceEquals(left, right) || left.Equals(right);
|
||||
}
|
||||
|
||||
public static bool operator !=(Type left, Type right) => !(left == right);
|
||||
public static bool operator !=(Type? left, Type? right) => !(left == right);
|
||||
}
|
||||
|
||||
public class AnyType : Type
|
||||
@@ -32,13 +36,8 @@ public class AnyType : Type
|
||||
public override string ToString() => "Any";
|
||||
}
|
||||
|
||||
public class PrimitiveType : Type
|
||||
public class PrimitiveType(PrimitiveTypeKind kind) : Type
|
||||
{
|
||||
public PrimitiveType(PrimitiveTypeKind kind)
|
||||
{
|
||||
Kind = kind;
|
||||
}
|
||||
|
||||
// TODO: This should be looked at more in the future
|
||||
public override bool IsAssignableTo(Type otherType)
|
||||
{
|
||||
@@ -56,20 +55,20 @@ public class PrimitiveType : Type
|
||||
return false;
|
||||
}
|
||||
|
||||
public static PrimitiveType Parse(string value)
|
||||
public static bool TryParse(string value, [NotNullWhen(true)] out PrimitiveType? result)
|
||||
{
|
||||
var kind = value switch
|
||||
result = value switch
|
||||
{
|
||||
"bool" => PrimitiveTypeKind.Bool,
|
||||
"int64" => PrimitiveTypeKind.Int64,
|
||||
"int32" => PrimitiveTypeKind.Int32,
|
||||
_ => throw new ArgumentOutOfRangeException(nameof(value), value, null)
|
||||
"bool" => new PrimitiveType(PrimitiveTypeKind.Bool),
|
||||
"int64" => new PrimitiveType(PrimitiveTypeKind.Int64),
|
||||
"int32" => new PrimitiveType(PrimitiveTypeKind.Int32),
|
||||
_ => null
|
||||
};
|
||||
|
||||
return new PrimitiveType(kind);
|
||||
return result != null;
|
||||
}
|
||||
|
||||
public PrimitiveTypeKind Kind { get; }
|
||||
public PrimitiveTypeKind Kind { get; } = kind;
|
||||
|
||||
protected override bool Equals(Type other) => other is PrimitiveType primitiveType && Kind == primitiveType.Kind;
|
||||
public override int GetHashCode() => Kind.GetHashCode();
|
||||
@@ -90,14 +89,9 @@ public class StringType : Type
|
||||
public override string ToString() => "String";
|
||||
}
|
||||
|
||||
public class ArrayType : Type
|
||||
public class ArrayType(Type innerType) : Type
|
||||
{
|
||||
public ArrayType(Type innerType)
|
||||
{
|
||||
InnerType = innerType;
|
||||
}
|
||||
|
||||
public Type InnerType { get; }
|
||||
public Type InnerType { get; } = innerType;
|
||||
|
||||
public override bool IsAssignableTo(Type otherType)
|
||||
{
|
||||
@@ -109,3 +103,12 @@ public class ArrayType : Type
|
||||
public override int GetHashCode() => HashCode.Combine(InnerType);
|
||||
public override string ToString() => $"Array<{InnerType}>";
|
||||
}
|
||||
|
||||
public class StructType(string name) : Type
|
||||
{
|
||||
public string Name { get; } = name;
|
||||
|
||||
protected override bool Equals(Type other) => other is StructType classType && Name == classType.Name;
|
||||
public override int GetHashCode() => Name.GetHashCode();
|
||||
public override string ToString() => Name;
|
||||
}
|
||||
@@ -6,12 +6,12 @@ func print(msg: String) {
|
||||
syscall(SYS_WRITE, STD_OUT, msg, str_len(msg));
|
||||
}
|
||||
|
||||
func print(value: int64) {
|
||||
print(itoa(value));
|
||||
func print(value1: int64) {
|
||||
print(itoa(value1));
|
||||
}
|
||||
|
||||
func print(value: bool) {
|
||||
if value {
|
||||
func print(value2: bool) {
|
||||
if value2 {
|
||||
print("true");
|
||||
} else {
|
||||
print("false");
|
||||
@@ -27,12 +27,12 @@ func println(msg: String) {
|
||||
println();
|
||||
}
|
||||
|
||||
func println(value: bool) {
|
||||
print(value);
|
||||
func println(value3: bool) {
|
||||
print(value3);
|
||||
println();
|
||||
}
|
||||
|
||||
func println(value: int64) {
|
||||
print(value);
|
||||
func println(value4: int64) {
|
||||
print(value4);
|
||||
println();
|
||||
}
|
||||
|
||||
@@ -1,6 +1,19 @@
|
||||
import "core";
|
||||
|
||||
func main() {
|
||||
let x = new Test
|
||||
{
|
||||
some_string = "test2",
|
||||
some_int = 69,
|
||||
};
|
||||
}
|
||||
|
||||
struct Test {
|
||||
let some_string: String;
|
||||
let some_int: int64;
|
||||
}
|
||||
|
||||
func example() {
|
||||
let some_string = "test";
|
||||
println(some_string);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user