This commit is contained in:
nub31
2025-05-04 16:02:48 +02:00
parent 248f95fa6e
commit 5f2d1ff3f9
25 changed files with 264 additions and 539 deletions

View File

@@ -6,7 +6,6 @@ namespace Nub.Lang.Backend.Custom;
public class Generator
{
private const string Entrypoint = "main";
private const bool ZeroBasedIndexing = false;
private readonly List<DefinitionNode> _definitions;
private readonly SymbolTable _symbolTable;
@@ -129,17 +128,20 @@ public class Generator
}
case LiteralNode literal:
{
if (literal.Type is not PrimitiveType primitiveType)
if (literal.LiteralType.Equals(NubType.Int64)
|| literal.LiteralType.Equals(NubType.Int32)
|| literal.LiteralType.Equals(NubType.Int16)
|| literal.LiteralType.Equals(NubType.Int8))
{
throw new NotSupportedException("Global variable literals must be of a primitive type");
return literal.Literal;
}
return primitiveType.Kind switch
if (literal.LiteralType.Equals(NubType.Bool))
{
PrimitiveTypeKind.Bool => bool.Parse(literal.Literal) ? "1" : "0",
PrimitiveTypeKind.Int64 or PrimitiveTypeKind.Int32 => $"{literal.Literal}",
_ => throw new ArgumentOutOfRangeException()
};
return bool.Parse(literal.Literal) ? "1" : "0";
}
throw new InvalidOperationException("BAD");
}
default:
{
@@ -194,9 +196,6 @@ public class Generator
{
switch (statement)
{
case ArrayIndexAssignmentNode arrayIndexAssignment:
GenerateArrayIndexAssignment(arrayIndexAssignment, func);
break;
case BreakNode:
GenerateBreak();
break;
@@ -239,15 +238,6 @@ public class Generator
_builder.AppendLine($" jmp {_loops.Peek().StartLabel}");
}
private void GenerateArrayIndexAssignment(ArrayIndexAssignmentNode arrayIndexAssignment, LocalFunc func)
{
GenerateExpression(arrayIndexAssignment.Value, func);
_builder.AppendLine(" push rax");
GenerateArrayIndexPointerAccess(arrayIndexAssignment.Identifier, arrayIndexAssignment.Index, func);
_builder.AppendLine(" pop rdx");
_builder.AppendLine(" mov [rax], rdx");
}
private void GenerateIf(IfNode ifStatement, LocalFunc func)
{
var endLabel = _labelFactory.Create();
@@ -319,12 +309,6 @@ public class Generator
{
switch (expression)
{
case ArrayIndexAccessNode arrayIndexAccess:
GenerateArrayIndexAccess(arrayIndexAccess, func);
break;
case ArrayInitializerNode arrayInitializer:
GenerateArrayInitializer(arrayInitializer);
break;
case BinaryExpressionNode binaryExpression:
GenerateBinaryExpression(binaryExpression, func);
break;
@@ -355,31 +339,21 @@ public class Generator
{
var variable = func.ResolveLocalVariable(structMemberAccessor.Members[0]);
if (variable.Type is not StructType structType)
{
throw new Exception($"Cannot access struct member on {variable} since it is not a struct type");
}
_builder.AppendLine($" mov rax, [rbp - {variable.Offset}]");
Type prevMemberType = structType;
var prevMemberType = variable.Type;
for (var i = 1; i < structMemberAccessor.Members.Count; i++)
{
if (prevMemberType is not StructType prevMemberStructType)
{
throw new Exception($"Cannot access {structMemberAccessor.Members[i]} on type {prevMemberType} because it is not a struct type");
}
var structDefinition = _definitions.OfType<StructDefinitionNode>().FirstOrDefault(sd => sd.Name == prevMemberStructType.Name);
var structDefinition = _definitions.OfType<StructDefinitionNode>().FirstOrDefault(sd => sd.Name == prevMemberType.Name);
if (structDefinition == null)
{
throw new Exception($"Struct {prevMemberStructType} is not defined");
throw new Exception($"Struct {prevMemberType} is not defined");
}
var member = structDefinition.Members.FirstOrDefault(m => m.Name == structMemberAccessor.Members[i]);
if (member == null)
{
throw new Exception($"Struct {prevMemberStructType} has no member with name {structMemberAccessor.Members[i]}");
throw new Exception($"Struct {prevMemberType} has no member with name {structMemberAccessor.Members[i]}");
}
var offset = structDefinition.Members.IndexOf(member);
@@ -389,19 +363,6 @@ public class Generator
}
}
private void GenerateArrayIndexAccess(ArrayIndexAccessNode arrayIndexAccess, LocalFunc func)
{
GenerateArrayIndexPointerAccess(arrayIndexAccess.Identifier, arrayIndexAccess.Index, func);
_builder.AppendLine(" mov rax, [rax]");
}
private void GenerateArrayInitializer(ArrayInitializerNode arrayInitializer)
{
_builder.AppendLine($" mov rdi, {8 + arrayInitializer.Length * 8}");
_builder.AppendLine(" call gc_alloc");
_builder.AppendLine($" mov qword [rax], {arrayInitializer.Length}");
}
private void GenerateBinaryExpression(BinaryExpressionNode binaryExpression, LocalFunc func)
{
GenerateExpression(binaryExpression.Left, func);
@@ -459,108 +420,121 @@ public class Generator
}
}
private void GenerateComparison(Type type)
private void GenerateComparison(NubType type)
{
switch (type)
if (type.Equals(NubType.String))
{
case AnyType:
throw new InvalidOperationException($"Cannot compare type {type}");
case ArrayType:
// compare pointers
_builder.AppendLine(" cmp rax, rcx");
break;
case PrimitiveType:
_builder.AppendLine(" cmp rax, rcx");
break;
case StringType:
_builder.AppendLine(" mov rdi, rax");
_builder.AppendLine(" mov rsi, rcx");
_builder.AppendLine(" call str_cmp");
break;
default:
throw new ArgumentOutOfRangeException(nameof(type));
_builder.AppendLine(" mov rdi, rax");
_builder.AppendLine(" mov rsi, rcx");
_builder.AppendLine(" call str_cmp");
}
else if (type.Equals(NubType.Bool) || type.Equals(NubType.Int64) || type.Equals(NubType.Int32) || type.Equals(NubType.Int16) || type.Equals(NubType.Int8))
{
_builder.AppendLine(" cmp rax, rcx");
}
else
{
throw new ArgumentOutOfRangeException(nameof(type));
}
}
private void GenerateBinaryAddition(Type type)
private void GenerateBinaryAddition(NubType type)
{
if (type is not PrimitiveType primitiveType)
if (type.Equals(NubType.Int64))
{
throw new InvalidOperationException("Addition can only be done on primitive types");
_builder.AppendLine(" add rax, rcx");
}
switch (primitiveType.Kind)
else if (type.Equals(NubType.Int32))
{
case PrimitiveTypeKind.Int64:
_builder.AppendLine(" add rax, rcx");
break;
case PrimitiveTypeKind.Int32:
_builder.AppendLine(" add eax, ecx");
break;
default:
throw new InvalidOperationException($"Invalid type {primitiveType.Kind}");
_builder.AppendLine(" add eax, ecx");
}
else if (type.Equals(NubType.Int16))
{
_builder.AppendLine(" add ax, cx");
}
else if (type.Equals(NubType.Int8))
{
_builder.AppendLine(" add al, cl");
}
else
{
throw new InvalidOperationException($"Invalid type for addition {type}");
}
}
private void GenerateBinarySubtraction(Type type)
private void GenerateBinarySubtraction(NubType type)
{
if (type is not PrimitiveType primitiveType)
if (type.Equals(NubType.Int64))
{
throw new InvalidOperationException("Subtraction can only be done on primitive types");
_builder.AppendLine(" sub rax, rcx");
}
switch (primitiveType.Kind)
else if (type.Equals(NubType.Int32))
{
case PrimitiveTypeKind.Int64:
_builder.AppendLine(" sub rax, rcx");
break;
case PrimitiveTypeKind.Int32:
_builder.AppendLine(" sub eax, ecx");
break;
default:
throw new InvalidOperationException($"Invalid type {primitiveType.Kind}");
_builder.AppendLine(" sub eax, ecx");
}
else if (type.Equals(NubType.Int16))
{
_builder.AppendLine(" sub ax, cx");
}
else if (type.Equals(NubType.Int8))
{
_builder.AppendLine(" sub al, cl");
}
else
{
throw new InvalidOperationException($"Invalid type for subtraction {type}");
}
}
private void GenerateBinaryMultiplication(Type type)
private void GenerateBinaryMultiplication(NubType type)
{
if (type is not PrimitiveType primitiveType)
if (type.Equals(NubType.Int64))
{
throw new InvalidOperationException("Multiplication can only be done on primitive types");
_builder.AppendLine(" imul rcx");
}
switch (primitiveType.Kind)
else if (type.Equals(NubType.Int32))
{
case PrimitiveTypeKind.Int64:
_builder.AppendLine(" imul rcx");
break;
case PrimitiveTypeKind.Int32:
_builder.AppendLine(" imul ecx");
break;
default:
throw new InvalidOperationException($"Invalid type {primitiveType.Kind}");
_builder.AppendLine(" imul ecx");
}
else if (type.Equals(NubType.Int16))
{
_builder.AppendLine(" imul cx");
}
else if (type.Equals(NubType.Int8))
{
_builder.AppendLine(" imul cl");
}
else
{
throw new InvalidOperationException($"Invalid type for multiplication {type}");
}
}
private void GenerateBinaryDivision(Type type)
private void GenerateBinaryDivision(NubType type)
{
if (type is not PrimitiveType primitiveType)
if (type.Equals(NubType.Int64))
{
throw new InvalidOperationException("Division can only be done on primitive types");
_builder.AppendLine(" cqo");
_builder.AppendLine(" idiv rcx");
}
switch (primitiveType.Kind)
else if (type.Equals(NubType.Int32))
{
case PrimitiveTypeKind.Int64:
_builder.AppendLine(" cqo");
_builder.AppendLine(" idiv rcx");
break;
case PrimitiveTypeKind.Int32:
_builder.AppendLine(" cdq");
_builder.AppendLine(" idiv ecx");
break;
default:
throw new InvalidOperationException($"Invalid type {primitiveType.Kind}");
_builder.AppendLine(" cdq");
_builder.AppendLine(" idiv ecx");
}
else if (type.Equals(NubType.Int16))
{
_builder.AppendLine(" cwd");
_builder.AppendLine(" idiv cx");
}
else if (type.Equals(NubType.Int8))
{
_builder.AppendLine(" cbw");
_builder.AppendLine(" idiv cl");
}
else
{
throw new InvalidOperationException($"Invalid type for division {type}");
}
}
@@ -586,34 +560,22 @@ public class Generator
private void GenerateLiteral(LiteralNode literal)
{
switch (literal.Type)
if (literal.Type.Equals(NubType.String))
{
case StringType:
{
var label = _symbolTable.DefineString(literal.Literal);
_builder.AppendLine($" mov rax, {label}");
break;
}
case PrimitiveType primitive:
{
switch (primitive.Kind)
{
case PrimitiveTypeKind.Bool:
_builder.AppendLine($" mov rax, {(bool.Parse(literal.Literal) ? "1" : "0")}");
break;
case PrimitiveTypeKind.Int64:
_builder.AppendLine($" mov rax, {literal.Literal}");
break;
case PrimitiveTypeKind.Int32:
_builder.AppendLine($" mov rax, {literal.Literal}");
break;
default:
throw new Exception("Cannot convert literal to string");
}
break;
}
default:
throw new ArgumentOutOfRangeException();
var label = _symbolTable.DefineString(literal.Literal);
_builder.AppendLine($" mov rax, {label}");
}
else if (literal.Type.Equals(NubType.Int64) || literal.Type.Equals(NubType.Int32) || literal.Type.Equals(NubType.Int16) || literal.Type.Equals(NubType.Int8))
{
_builder.AppendLine($" mov rax, {literal.Literal}");
}
else if (literal.Type.Equals(NubType.Bool))
{
_builder.AppendLine($" mov rax, {(bool.Parse(literal.Literal) ? "1" : "0")}");
}
else
{
throw new NotImplementedException($"Literal type {literal.Type} not implemented");
}
}
@@ -706,31 +668,4 @@ public class Generator
_builder.AppendLine(" syscall");
}
private void GenerateArrayIndexPointerAccess(IdentifierNode identifier, ExpressionNode index, LocalFunc func)
{
GenerateExpression(index, func);
_builder.AppendLine(" push rax");
GenerateIdentifier(identifier, func);
_builder.AppendLine(" pop rdx");
// rcx now holds the length of the array which we can use to check bounds
_builder.AppendLine(" mov rcx, [rax]");
_builder.AppendLine(" cmp rdx, rcx");
if (ZeroBasedIndexing)
{
_builder.AppendLine(" jge eb6e_oob_error");
_builder.AppendLine(" cmp rdx, 0");
}
else
{
_builder.AppendLine(" jg eb6e_oob_error");
_builder.AppendLine(" cmp rdx, 1");
}
_builder.AppendLine(" jl eb6e_oob_error");
_builder.AppendLine(" inc rdx");
_builder.AppendLine(" shl rdx, 3");
_builder.AppendLine(" add rax, rdx");
}
}