This commit is contained in:
nub31
2025-05-12 21:41:00 +02:00
parent b4a80bb7e1
commit 90ef9fb8e8
8 changed files with 186 additions and 329 deletions

View File

@@ -1,4 +1,4 @@
import "c"; import c;
global func main() { global func main() {
printf("something %s\n", "your mom"); printf("something %s\n", "your mom");

View File

@@ -2,6 +2,5 @@
<project version="4"> <project version="4">
<component name="VcsDirectoryMappings"> <component name="VcsDirectoryMappings">
<mapping directory="$PROJECT_DIR$/../.." vcs="Git" /> <mapping directory="$PROJECT_DIR$/../.." vcs="Git" />
<mapping directory="$PROJECT_DIR$/.." vcs="Git" />
</component> </component>
</project> </project>

View File

@@ -44,22 +44,64 @@ public class Generator
private string QbeTypeName(NubType type) private string QbeTypeName(NubType type)
{ {
if (type.Equals(NubType.Int32) || type.Equals(NubType.Bool)) if (type is NubPrimitiveType primitiveType)
{ {
switch (primitiveType.Kind)
{
case PrimitiveTypeKind.I64:
case PrimitiveTypeKind.U64:
case PrimitiveTypeKind.String:
case PrimitiveTypeKind.Any:
return "l";
case PrimitiveTypeKind.I32:
case PrimitiveTypeKind.U32:
case PrimitiveTypeKind.I16:
case PrimitiveTypeKind.U16:
case PrimitiveTypeKind.I8:
case PrimitiveTypeKind.U8:
case PrimitiveTypeKind.Bool:
return "w"; return "w";
case PrimitiveTypeKind.F64:
return "d";
case PrimitiveTypeKind.F32:
return "s";
default:
throw new ArgumentOutOfRangeException();
}
} }
return "l"; throw new NotImplementedException();
} }
private int QbeTypeSize(NubType type) private int QbeTypeSize(NubType type)
{ {
if (type.Equals(NubType.Int32) || type.Equals(NubType.Bool)) if (type is NubPrimitiveType primitiveType)
{ {
switch (primitiveType.Kind)
{
case PrimitiveTypeKind.I64:
case PrimitiveTypeKind.U64:
case PrimitiveTypeKind.String:
case PrimitiveTypeKind.Any:
return 8;
case PrimitiveTypeKind.I32:
case PrimitiveTypeKind.U32:
case PrimitiveTypeKind.I16:
case PrimitiveTypeKind.U16:
case PrimitiveTypeKind.I8:
case PrimitiveTypeKind.U8:
case PrimitiveTypeKind.Bool:
return 4; return 4;
case PrimitiveTypeKind.F64:
return 8;
case PrimitiveTypeKind.F32:
return 4;
default:
throw new ArgumentOutOfRangeException();
}
} }
return 8; throw new NotImplementedException();
} }
private void GenerateFuncDefinition(LocalFuncDefinitionNode node) private void GenerateFuncDefinition(LocalFuncDefinitionNode node)
@@ -349,25 +391,25 @@ public class Generator
{ {
case BinaryExpressionOperator.Equal: case BinaryExpressionOperator.Equal:
{ {
if (binaryExpression.Left.Type.Equals(NubType.Int64) && binaryExpression.Right.Type.Equals(NubType.Int64)) if (binaryExpression.Left.Type.Equals(NubPrimitiveType.I64))
{ {
_builder.AppendLine($" %{outputLabel} =w ceql {left}, {right}"); _builder.AppendLine($" %{outputLabel} =w ceql {left}, {right}");
return $"%{outputLabel}"; return $"%{outputLabel}";
} }
if (binaryExpression.Left.Type.Equals(NubType.Int32) && binaryExpression.Right.Type.Equals(NubType.Int32)) if (binaryExpression.Left.Type.Equals(NubPrimitiveType.I32))
{ {
_builder.AppendLine($" %{outputLabel} =w ceqw {left}, {right}"); _builder.AppendLine($" %{outputLabel} =w ceqw {left}, {right}");
return $"%{outputLabel}"; return $"%{outputLabel}";
} }
if (binaryExpression.Left.Type.Equals(NubType.String) && binaryExpression.Right.Type.Equals(NubType.String)) if (binaryExpression.Left.Type.Equals(NubPrimitiveType.String))
{ {
_builder.AppendLine($" %{outputLabel} =w call $nub_strcmp(l {left}, l {right})"); _builder.AppendLine($" %{outputLabel} =w call $nub_strcmp(l {left}, l {right})");
return $"%{outputLabel}"; return $"%{outputLabel}";
} }
if (binaryExpression.Left.Type.Equals(NubType.Bool) && binaryExpression.Right.Type.Equals(NubType.Bool)) if (binaryExpression.Left.Type.Equals(NubPrimitiveType.Bool))
{ {
_builder.AppendLine($" %{outputLabel} =w ceqw {left}, {right}"); _builder.AppendLine($" %{outputLabel} =w ceqw {left}, {right}");
return $"%{outputLabel}"; return $"%{outputLabel}";
@@ -376,26 +418,26 @@ public class Generator
} }
case BinaryExpressionOperator.NotEqual: case BinaryExpressionOperator.NotEqual:
{ {
if (binaryExpression.Left.Type.Equals(NubType.Int64) && binaryExpression.Right.Type.Equals(NubType.Int64)) if (binaryExpression.Left.Type.Equals(NubPrimitiveType.I64))
{ {
_builder.AppendLine($" %{outputLabel} =w cnel {left}, {right}"); _builder.AppendLine($" %{outputLabel} =w cnel {left}, {right}");
return $"%{outputLabel}"; return $"%{outputLabel}";
} }
if (binaryExpression.Left.Type.Equals(NubType.Int32) && binaryExpression.Right.Type.Equals(NubType.Int32)) if (binaryExpression.Left.Type.Equals(NubPrimitiveType.I32))
{ {
_builder.AppendLine($" %{outputLabel} =w cnew {left}, {right}"); _builder.AppendLine($" %{outputLabel} =w cnew {left}, {right}");
return $"%{outputLabel}"; return $"%{outputLabel}";
} }
if (binaryExpression.Left.Type.Equals(NubType.String) && binaryExpression.Right.Type.Equals(NubType.String)) if (binaryExpression.Left.Type.Equals(NubPrimitiveType.String))
{ {
_builder.AppendLine($" %{outputLabel} =w call $nub_strcmp(l {left}, l {right})"); _builder.AppendLine($" %{outputLabel} =w call $nub_strcmp(l {left}, l {right})");
_builder.AppendLine($" %{outputLabel} =w xor %{outputLabel}, 1"); _builder.AppendLine($" %{outputLabel} =w xor %{outputLabel}, 1");
return $"%{outputLabel}"; return $"%{outputLabel}";
} }
if (binaryExpression.Left.Type.Equals(NubType.Bool) && binaryExpression.Right.Type.Equals(NubType.Bool)) if (binaryExpression.Left.Type.Equals(NubPrimitiveType.Bool))
{ {
_builder.AppendLine($" %{outputLabel} =w cnew {left}, {right}"); _builder.AppendLine($" %{outputLabel} =w cnew {left}, {right}");
return $"%{outputLabel}"; return $"%{outputLabel}";
@@ -404,19 +446,19 @@ public class Generator
} }
case BinaryExpressionOperator.GreaterThan: case BinaryExpressionOperator.GreaterThan:
{ {
if (binaryExpression.Left.Type.Equals(NubType.Int64) && binaryExpression.Right.Type.Equals(NubType.Int64)) if (binaryExpression.Left.Type.Equals(NubPrimitiveType.I64))
{ {
_builder.AppendLine($" %{outputLabel} =w csgtl {left}, {right}"); _builder.AppendLine($" %{outputLabel} =w csgtl {left}, {right}");
return $"%{outputLabel}"; return $"%{outputLabel}";
} }
if (binaryExpression.Left.Type.Equals(NubType.Int32) && binaryExpression.Right.Type.Equals(NubType.Int32)) if (binaryExpression.Left.Type.Equals(NubPrimitiveType.I32))
{ {
_builder.AppendLine($" %{outputLabel} =w csgtw {left}, {right}"); _builder.AppendLine($" %{outputLabel} =w csgtw {left}, {right}");
return $"%{outputLabel}"; return $"%{outputLabel}";
} }
if (binaryExpression.Left.Type.Equals(NubType.Bool) && binaryExpression.Right.Type.Equals(NubType.Bool)) if (binaryExpression.Left.Type.Equals(NubPrimitiveType.Bool))
{ {
_builder.AppendLine($" %{outputLabel} =w csgtw {left}, {right}"); _builder.AppendLine($" %{outputLabel} =w csgtw {left}, {right}");
return $"%{outputLabel}"; return $"%{outputLabel}";
@@ -425,19 +467,19 @@ public class Generator
} }
case BinaryExpressionOperator.GreaterThanOrEqual: case BinaryExpressionOperator.GreaterThanOrEqual:
{ {
if (binaryExpression.Left.Type.Equals(NubType.Int64) && binaryExpression.Right.Type.Equals(NubType.Int64)) if (binaryExpression.Left.Type.Equals(NubPrimitiveType.I64))
{ {
_builder.AppendLine($" %{outputLabel} =w csgel {left}, {right}"); _builder.AppendLine($" %{outputLabel} =w csgel {left}, {right}");
return $"%{outputLabel}"; return $"%{outputLabel}";
} }
if (binaryExpression.Left.Type.Equals(NubType.Int32) && binaryExpression.Right.Type.Equals(NubType.Int32)) if (binaryExpression.Left.Type.Equals(NubPrimitiveType.I32))
{ {
_builder.AppendLine($" %{outputLabel} =w csgew {left}, {right}"); _builder.AppendLine($" %{outputLabel} =w csgew {left}, {right}");
return $"%{outputLabel}"; return $"%{outputLabel}";
} }
if (binaryExpression.Left.Type.Equals(NubType.Bool) && binaryExpression.Right.Type.Equals(NubType.Bool)) if (binaryExpression.Left.Type.Equals(NubPrimitiveType.Bool))
{ {
_builder.AppendLine($" %{outputLabel} =w csgew {left}, {right}"); _builder.AppendLine($" %{outputLabel} =w csgew {left}, {right}");
return $"%{outputLabel}"; return $"%{outputLabel}";
@@ -446,19 +488,19 @@ public class Generator
} }
case BinaryExpressionOperator.LessThan: case BinaryExpressionOperator.LessThan:
{ {
if (binaryExpression.Left.Type.Equals(NubType.Int64) && binaryExpression.Right.Type.Equals(NubType.Int64)) if (binaryExpression.Left.Type.Equals(NubPrimitiveType.I64))
{ {
_builder.AppendLine($" %{outputLabel} =w csltl {left}, {right}"); _builder.AppendLine($" %{outputLabel} =w csltl {left}, {right}");
return $"%{outputLabel}"; return $"%{outputLabel}";
} }
if (binaryExpression.Left.Type.Equals(NubType.Int32) && binaryExpression.Right.Type.Equals(NubType.Int32)) if (binaryExpression.Left.Type.Equals(NubPrimitiveType.I32))
{ {
_builder.AppendLine($" %{outputLabel} =w csltw {left}, {right}"); _builder.AppendLine($" %{outputLabel} =w csltw {left}, {right}");
return $"%{outputLabel}"; return $"%{outputLabel}";
} }
if (binaryExpression.Left.Type.Equals(NubType.Bool) && binaryExpression.Right.Type.Equals(NubType.Bool)) if (binaryExpression.Left.Type.Equals(NubPrimitiveType.Bool))
{ {
_builder.AppendLine($" %{outputLabel} =w csltw {left}, {right}"); _builder.AppendLine($" %{outputLabel} =w csltw {left}, {right}");
return $"%{outputLabel}"; return $"%{outputLabel}";
@@ -467,19 +509,19 @@ public class Generator
} }
case BinaryExpressionOperator.LessThanOrEqual: case BinaryExpressionOperator.LessThanOrEqual:
{ {
if (binaryExpression.Left.Type.Equals(NubType.Int64) && binaryExpression.Right.Type.Equals(NubType.Int64)) if (binaryExpression.Left.Type.Equals(NubPrimitiveType.I64))
{ {
_builder.AppendLine($" %{outputLabel} =w cslel {left}, {right}"); _builder.AppendLine($" %{outputLabel} =w cslel {left}, {right}");
return $"%{outputLabel}"; return $"%{outputLabel}";
} }
if (binaryExpression.Left.Type.Equals(NubType.Int32) && binaryExpression.Right.Type.Equals(NubType.Int32)) if (binaryExpression.Left.Type.Equals(NubPrimitiveType.I32))
{ {
_builder.AppendLine($" %{outputLabel} =w cslew {left}, {right}"); _builder.AppendLine($" %{outputLabel} =w cslew {left}, {right}");
return $"%{outputLabel}"; return $"%{outputLabel}";
} }
if (binaryExpression.Left.Type.Equals(NubType.Bool) && binaryExpression.Right.Type.Equals(NubType.Bool)) if (binaryExpression.Left.Type.Equals(NubPrimitiveType.Bool))
{ {
_builder.AppendLine($" %{outputLabel} =w cslew {left}, {right}"); _builder.AppendLine($" %{outputLabel} =w cslew {left}, {right}");
return $"%{outputLabel}"; return $"%{outputLabel}";
@@ -488,13 +530,13 @@ public class Generator
} }
case BinaryExpressionOperator.Plus: case BinaryExpressionOperator.Plus:
{ {
if (binaryExpression.Left.Type.Equals(NubType.Int64) && binaryExpression.Right.Type.Equals(NubType.Int64)) if (binaryExpression.Left.Type.Equals(NubPrimitiveType.I64))
{ {
_builder.AppendLine($" %{outputLabel} =l add {left}, {right}"); _builder.AppendLine($" %{outputLabel} =l add {left}, {right}");
return $"%{outputLabel}"; return $"%{outputLabel}";
} }
if (binaryExpression.Left.Type.Equals(NubType.Int32) && binaryExpression.Right.Type.Equals(NubType.Int32)) if (binaryExpression.Left.Type.Equals(NubPrimitiveType.I32))
{ {
_builder.AppendLine($" %{outputLabel} =w add {left}, {right}"); _builder.AppendLine($" %{outputLabel} =w add {left}, {right}");
return $"%{outputLabel}"; return $"%{outputLabel}";
@@ -503,13 +545,13 @@ public class Generator
} }
case BinaryExpressionOperator.Minus: case BinaryExpressionOperator.Minus:
{ {
if (binaryExpression.Left.Type.Equals(NubType.Int64) && binaryExpression.Right.Type.Equals(NubType.Int64)) if (binaryExpression.Left.Type.Equals(NubPrimitiveType.I64))
{ {
_builder.AppendLine($" %{outputLabel} =l sub {left}, {right}"); _builder.AppendLine($" %{outputLabel} =l sub {left}, {right}");
return $"%{outputLabel}"; return $"%{outputLabel}";
} }
if (binaryExpression.Left.Type.Equals(NubType.Int32) && binaryExpression.Right.Type.Equals(NubType.Int32)) if (binaryExpression.Left.Type.Equals(NubPrimitiveType.I32))
{ {
_builder.AppendLine($" %{outputLabel} =w sub {left}, {right}"); _builder.AppendLine($" %{outputLabel} =w sub {left}, {right}");
return $"%{outputLabel}"; return $"%{outputLabel}";
@@ -518,13 +560,13 @@ public class Generator
} }
case BinaryExpressionOperator.Multiply: case BinaryExpressionOperator.Multiply:
{ {
if (binaryExpression.Left.Type.Equals(NubType.Int64) && binaryExpression.Right.Type.Equals(NubType.Int64)) if (binaryExpression.Left.Type.Equals(NubPrimitiveType.I64))
{ {
_builder.AppendLine($" %{outputLabel} =l mul {left}, {right}"); _builder.AppendLine($" %{outputLabel} =l mul {left}, {right}");
return $"%{outputLabel}"; return $"%{outputLabel}";
} }
if (binaryExpression.Left.Type.Equals(NubType.Int32) && binaryExpression.Right.Type.Equals(NubType.Int32)) if (binaryExpression.Left.Type.Equals(NubPrimitiveType.I32))
{ {
_builder.AppendLine($" %{outputLabel} =w mul {left}, {right}"); _builder.AppendLine($" %{outputLabel} =w mul {left}, {right}");
return $"%{outputLabel}"; return $"%{outputLabel}";
@@ -533,13 +575,13 @@ public class Generator
} }
case BinaryExpressionOperator.Divide: case BinaryExpressionOperator.Divide:
{ {
if (binaryExpression.Left.Type.Equals(NubType.Int64) && binaryExpression.Right.Type.Equals(NubType.Int64)) if (binaryExpression.Left.Type.Equals(NubPrimitiveType.I64))
{ {
_builder.AppendLine($" %{outputLabel} =l div {left}, {right}"); _builder.AppendLine($" %{outputLabel} =l div {left}, {right}");
return $"%{outputLabel}"; return $"%{outputLabel}";
} }
if (binaryExpression.Left.Type.Equals(NubType.Int32) && binaryExpression.Right.Type.Equals(NubType.Int32)) if (binaryExpression.Left.Type.Equals(NubPrimitiveType.I32))
{ {
_builder.AppendLine($" %{outputLabel} =w div {left}, {right}"); _builder.AppendLine($" %{outputLabel} =w div {left}, {right}");
return $"%{outputLabel}"; return $"%{outputLabel}";
@@ -562,18 +604,18 @@ public class Generator
private string GenerateLiteral(LiteralNode literal) private string GenerateLiteral(LiteralNode literal)
{ {
if (literal.LiteralType.Equals(NubType.String)) if (literal.LiteralType.Equals(NubPrimitiveType.String))
{ {
_strings.Add(literal.Literal); _strings.Add(literal.Literal);
return $"$str{_strings.Count}"; return $"$str{_strings.Count}";
} }
if (literal.LiteralType.Equals(NubType.Int64) || literal.LiteralType.Equals(NubType.Int32)) if (literal.LiteralType.Equals(NubPrimitiveType.I64) || literal.LiteralType.Equals(NubPrimitiveType.I32))
{ {
return literal.Literal; return literal.Literal;
} }
if (literal.LiteralType.Equals(NubType.Bool)) if (literal.LiteralType.Equals(NubPrimitiveType.Bool))
{ {
return bool.Parse(literal.Literal) ? "1" : "0"; return bool.Parse(literal.Literal) ? "1" : "0";
} }

View File

@@ -86,7 +86,7 @@ public class Lexer
if (buffer is "true" or "false") if (buffer is "true" or "false")
{ {
return new LiteralToken(NubType.Bool, buffer); return new LiteralToken(NubPrimitiveType.Bool, buffer);
} }
return new IdentifierToken(buffer); return new IdentifierToken(buffer);
@@ -103,7 +103,7 @@ public class Lexer
current = Peek(); current = Peek();
} }
return new LiteralToken(NubType.Int64, buffer); return new LiteralToken(NubPrimitiveType.I64, buffer);
} }
// TODO: Revisit this // TODO: Revisit this
@@ -148,7 +148,7 @@ public class Lexer
buffer += current.Value; buffer += current.Value;
} }
return new LiteralToken(NubType.String, buffer); return new LiteralToken(NubPrimitiveType.String, buffer);
} }
if (char.IsWhiteSpace(current.Value)) if (char.IsWhiteSpace(current.Value))

View File

@@ -20,12 +20,7 @@ public class Parser
{ {
if (TryExpectSymbol(Symbol.Import)) if (TryExpectSymbol(Symbol.Import))
{ {
var name = ExpectLiteral(); var name = ExpectIdentifier();
if (!name.Type.Equals(NubType.String))
{
throw new Exception("Import statements must have a string literal value");
}
TryExpectSymbol(Symbol.Semicolon); TryExpectSymbol(Symbol.Semicolon);
imports.Add(name.Value); imports.Add(name.Value);
} }
@@ -466,7 +461,7 @@ public class Parser
private NubType ParseType() private NubType ParseType()
{ {
var name = ExpectIdentifier().Value; var name = ExpectIdentifier().Value;
return new NubType(name); return NubType.Parse(name);
} }
private Token ExpectToken() private Token ExpectToken()

View File

@@ -197,7 +197,7 @@ public class ExpressionTyper
case BinaryExpressionOperator.LessThan: case BinaryExpressionOperator.LessThan:
case BinaryExpressionOperator.LessThanOrEqual: case BinaryExpressionOperator.LessThanOrEqual:
{ {
binaryExpression.Type = new NubType("bool"); binaryExpression.Type = NubPrimitiveType.Bool;
break; break;
} }
case BinaryExpressionOperator.Plus: case BinaryExpressionOperator.Plus:

View File

@@ -1,32 +1,117 @@
namespace Nub.Lang; using System.Diagnostics.CodeAnalysis;
public sealed class NubType namespace Nub.Lang;
public abstract class NubType
{ {
public NubType(string name) protected NubType(string name)
{ {
Name = name; Name = name;
} }
public string Name { get; } public string Name { get; }
public static NubType Int64 => new("i64"); public static NubType Parse(string s)
public static NubType Int32 => new("i32");
public static NubType Bool => new("bool");
public static NubType String => new("string");
public static NubType Any => new("any");
public override bool Equals(object? obj)
{ {
return obj is NubType item && Name.Equals(item.Name); if (NubPrimitiveType.TryParse(s, out var kind))
{
return new NubPrimitiveType(kind.Value);
} }
public override int GetHashCode() return new NubCustomType(s);
{
return HashCode.Combine(Name);
} }
public override string ToString() public override bool Equals(object? obj) => obj is NubType item && Name.Equals(item.Name);
public override int GetHashCode() => HashCode.Combine(Name);
public override string ToString() => Name;
}
public class NubCustomType(string name) : NubType(name);
public class NubPrimitiveType : NubType
{ {
return $"{Name}"; public NubPrimitiveType(PrimitiveTypeKind kind) : base(KindToString(kind))
{
Kind = kind;
}
public PrimitiveTypeKind Kind { get; }
public static NubPrimitiveType I64 => new(PrimitiveTypeKind.I64);
public static NubPrimitiveType I32 => new(PrimitiveTypeKind.I32);
public static NubPrimitiveType I16 => new(PrimitiveTypeKind.I16);
public static NubPrimitiveType I8 => new(PrimitiveTypeKind.I8);
public static NubPrimitiveType U64 => new(PrimitiveTypeKind.U64);
public static NubPrimitiveType U32 => new(PrimitiveTypeKind.U32);
public static NubPrimitiveType U16 => new(PrimitiveTypeKind.U16);
public static NubPrimitiveType U8 => new(PrimitiveTypeKind.U8);
public static NubPrimitiveType F64 => new(PrimitiveTypeKind.F64);
public static NubPrimitiveType F32 => new(PrimitiveTypeKind.F32);
public static NubPrimitiveType Bool => new(PrimitiveTypeKind.Bool);
public static NubPrimitiveType String => new(PrimitiveTypeKind.String);
public static NubPrimitiveType Any => new(PrimitiveTypeKind.Any);
public static bool TryParse(string s, [NotNullWhen(true)] out PrimitiveTypeKind? kind)
{
kind = s switch
{
"i64" => PrimitiveTypeKind.I64,
"i32" => PrimitiveTypeKind.I32,
"i16" => PrimitiveTypeKind.I16,
"i8" => PrimitiveTypeKind.I8,
"u64" => PrimitiveTypeKind.U64,
"u32" => PrimitiveTypeKind.U32,
"u16" => PrimitiveTypeKind.U16,
"u8" => PrimitiveTypeKind.U8,
"f64" => PrimitiveTypeKind.F64,
"f32" => PrimitiveTypeKind.F32,
"bool" => PrimitiveTypeKind.Bool,
"string" => PrimitiveTypeKind.String,
"any" => PrimitiveTypeKind.Any,
_ => null
};
return kind != null;
}
public static string KindToString(PrimitiveTypeKind kind)
{
return kind switch
{
PrimitiveTypeKind.I64 => "i64",
PrimitiveTypeKind.I32 => "i32",
PrimitiveTypeKind.I16 => "i16",
PrimitiveTypeKind.I8 => "i8",
PrimitiveTypeKind.U64 => "u64",
PrimitiveTypeKind.U32 => "u32",
PrimitiveTypeKind.U16 => "u16",
PrimitiveTypeKind.U8 => "u8",
PrimitiveTypeKind.F64 => "f64",
PrimitiveTypeKind.F32 => "f32",
PrimitiveTypeKind.Bool => "bool",
PrimitiveTypeKind.String => "string",
PrimitiveTypeKind.Any => "any",
_ => throw new ArgumentOutOfRangeException(nameof(kind), kind, null)
};
} }
} }
public enum PrimitiveTypeKind
{
I64,
I32,
I16,
I8,
U64,
U32,
U16,
U8,
F64,
F32,
Bool,
String,
Any
}

View File

@@ -1,264 +0,0 @@
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <sys/mman.h>
#define MINIMUM_THRESHOLD (1024 * 1024 * 8)
#define MINIMUM_BLOCK_SIZE 4096
typedef struct alloc_block {
uint64_t mark;
uint64_t size;
struct alloc_block* next;
} alloc_block_t;
typedef struct free_block {
uint64_t size;
struct free_block* next;
} free_block_t;
static alloc_block_t* alloc_list_head = NULL;
static free_block_t* free_list_head = NULL;
static void* stack_start = NULL;
static int64_t free_list_size = 0;
static int64_t mark_count = 0;
/* Bytes allocated since last collect */
static int64_t bytes_allocated = 0;
/* Threshold for next collect */
static int64_t trigger_threshold = MINIMUM_THRESHOLD;
static void* sys_mmap(size_t size);
static void* get_sp(void);
static void gc_collect(void);
static void gc_mark(void* ptr);
static void gc_mark_stack(void);
static void gc_sweep(void);
static int64_t max(int64_t a, int64_t b);
static void insert_into_free(free_block_t* block);
static void merge(free_block_t* block);
void gc_init(void) {
stack_start = get_sp();
}
/* Allocate memory with garbage collection */
void* gc_alloc(int64_t size) {
size += sizeof(alloc_block_t); // Adjust for metadata size
if (bytes_allocated > trigger_threshold) {
gc_collect();
}
bytes_allocated += size;
// Search free list for a suitable block
free_block_t* current = free_list_head;
free_block_t* prev = NULL;
while (current != NULL) {
if (current->size >= size) {
// Found a suitable block
break;
}
prev = current;
current = current->next;
}
if (current == NULL) {
// No suitable block found, allocate a new one
int64_t alloc_size = max(size, MINIMUM_BLOCK_SIZE);
void* memory = sys_mmap(alloc_size);
free_block_t* new_block = (free_block_t*)memory;
new_block->size = alloc_size - sizeof(free_block_t);
new_block->next = NULL;
insert_into_free(new_block);
current = new_block;
// Recalculate prev
if (current == free_list_head) {
prev = NULL;
} else {
prev = free_list_head;
while (prev->next != current) {
prev = prev->next;
}
}
}
// Use the block
alloc_block_t* result;
if (current->size > size) {
// Block is larger than needed, split it
result = (alloc_block_t*)((char*)current + current->size + sizeof(free_block_t) - size);
current->size -= size;
} else {
// Use the entire block
result = (alloc_block_t*)current;
// Remove block from free list
if (prev == NULL) {
free_list_head = current->next;
} else {
prev->next = current->next;
}
free_list_size--;
}
// Initialize metadata
result->mark = 0;
result->size = size - sizeof(alloc_block_t);
result->next = alloc_list_head;
alloc_list_head = result;
// Return pointer to usable memory
return (void*)(result + 1);
}
/* Run garbage collection */
static void gc_collect(void) {
gc_mark_stack();
gc_sweep();
trigger_threshold = max(bytes_allocated * 2, MINIMUM_THRESHOLD);
bytes_allocated = 0;
}
static void gc_mark_stack(void) {
mark_count = 0;
void** current = get_sp();
void** end = (void**)stack_start;
while (current < end) {
gc_mark(*current);
current++;
}
}
/* Mark a single object and recursively mark its contents */
static void gc_mark(void* ptr) {
if (ptr == NULL) {
return;
}
alloc_block_t* block = alloc_list_head;
while (block != NULL) {
void* block_data = (void*)(block + 1);
if (block_data == ptr) {
if (block->mark == 0) {
mark_count++;
block->mark = 1;
void** p = (void**)block_data;
void** end = (void**)((char*)block_data + block->size);
while (p < end) {
gc_mark(*p);
p++;
}
}
return;
}
block = block->next;
}
}
static void gc_sweep(void) {
alloc_block_t* current = alloc_list_head;
alloc_block_t* prev = NULL;
while (current != NULL) {
if (current->mark == 0) {
alloc_block_t* next = current->next;
if (prev == NULL) {
alloc_list_head = next;
} else {
prev->next = next;
}
bytes_allocated -= (current->size + sizeof(alloc_block_t));
free_block_t* free_block = (free_block_t*)current;
free_block->size = current->size + sizeof(alloc_block_t) - sizeof(free_block_t);
free_block->next = NULL;
insert_into_free(free_block);
current = next;
} else {
current->mark = 0;
prev = current;
current = current->next;
}
}
}
/* Insert a block into the free list, maintaining address order */
static void insert_into_free(free_block_t* block) {
if (free_list_head == NULL || block < free_list_head) {
// Insert at head
block->next = free_list_head;
free_list_head = block;
free_list_size++;
merge(block);
return;
}
// Find insertion point
free_block_t* current = free_list_head;
while (current->next != NULL && current->next < block) {
current = current->next;
}
// Insert after current
block->next = current->next;
current->next = block;
free_list_size++;
// Try to merge adjacent blocks
merge(current);
}
static void merge(free_block_t* block) {
while (block->next != NULL) {
char* block_end = (char*)block + block->size + sizeof(free_block_t);
if (block_end == (char*)block->next) {
free_list_size--;
block->size += block->next->size + sizeof(free_block_t);
block->next = block->next->next;
} else {
break;
}
}
}
static void* sys_mmap(size_t size) {
void* result = mmap(NULL, size, PROT_READ | PROT_WRITE, MAP_PRIVATE | MAP_ANONYMOUS, -1, 0);
if (result == MAP_FAILED) {
perror("[sys_mmap] mmap failed");
exit(1);
}
return result;
}
static int64_t max(int64_t a, int64_t b) {
if (a > b) {
return a;
} else {
return b;
}
}
void* get_sp(void) {
volatile unsigned long var = 0;
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wreturn-local-addr"
return (void*)((unsigned long)&var + 4);
#pragma GCC diagnostic pop
}