...
This commit is contained in:
@@ -6,7 +6,6 @@ namespace Nub.Lang.Backend.Custom;
|
||||
public class Generator
|
||||
{
|
||||
private const string Entrypoint = "main";
|
||||
private const bool ZeroBasedIndexing = false;
|
||||
|
||||
private readonly List<DefinitionNode> _definitions;
|
||||
private readonly SymbolTable _symbolTable;
|
||||
@@ -129,17 +128,20 @@ public class Generator
|
||||
}
|
||||
case LiteralNode literal:
|
||||
{
|
||||
if (literal.Type is not PrimitiveType primitiveType)
|
||||
if (literal.LiteralType.Equals(NubType.Int64)
|
||||
|| literal.LiteralType.Equals(NubType.Int32)
|
||||
|| literal.LiteralType.Equals(NubType.Int16)
|
||||
|| literal.LiteralType.Equals(NubType.Int8))
|
||||
{
|
||||
throw new NotSupportedException("Global variable literals must be of a primitive type");
|
||||
return literal.Literal;
|
||||
}
|
||||
|
||||
return primitiveType.Kind switch
|
||||
if (literal.LiteralType.Equals(NubType.Bool))
|
||||
{
|
||||
PrimitiveTypeKind.Bool => bool.Parse(literal.Literal) ? "1" : "0",
|
||||
PrimitiveTypeKind.Int64 or PrimitiveTypeKind.Int32 => $"{literal.Literal}",
|
||||
_ => throw new ArgumentOutOfRangeException()
|
||||
};
|
||||
return bool.Parse(literal.Literal) ? "1" : "0";
|
||||
}
|
||||
|
||||
throw new InvalidOperationException("BAD");
|
||||
}
|
||||
default:
|
||||
{
|
||||
@@ -194,9 +196,6 @@ public class Generator
|
||||
{
|
||||
switch (statement)
|
||||
{
|
||||
case ArrayIndexAssignmentNode arrayIndexAssignment:
|
||||
GenerateArrayIndexAssignment(arrayIndexAssignment, func);
|
||||
break;
|
||||
case BreakNode:
|
||||
GenerateBreak();
|
||||
break;
|
||||
@@ -239,15 +238,6 @@ public class Generator
|
||||
_builder.AppendLine($" jmp {_loops.Peek().StartLabel}");
|
||||
}
|
||||
|
||||
private void GenerateArrayIndexAssignment(ArrayIndexAssignmentNode arrayIndexAssignment, LocalFunc func)
|
||||
{
|
||||
GenerateExpression(arrayIndexAssignment.Value, func);
|
||||
_builder.AppendLine(" push rax");
|
||||
GenerateArrayIndexPointerAccess(arrayIndexAssignment.Identifier, arrayIndexAssignment.Index, func);
|
||||
_builder.AppendLine(" pop rdx");
|
||||
_builder.AppendLine(" mov [rax], rdx");
|
||||
}
|
||||
|
||||
private void GenerateIf(IfNode ifStatement, LocalFunc func)
|
||||
{
|
||||
var endLabel = _labelFactory.Create();
|
||||
@@ -319,12 +309,6 @@ public class Generator
|
||||
{
|
||||
switch (expression)
|
||||
{
|
||||
case ArrayIndexAccessNode arrayIndexAccess:
|
||||
GenerateArrayIndexAccess(arrayIndexAccess, func);
|
||||
break;
|
||||
case ArrayInitializerNode arrayInitializer:
|
||||
GenerateArrayInitializer(arrayInitializer);
|
||||
break;
|
||||
case BinaryExpressionNode binaryExpression:
|
||||
GenerateBinaryExpression(binaryExpression, func);
|
||||
break;
|
||||
@@ -355,31 +339,21 @@ public class Generator
|
||||
{
|
||||
var variable = func.ResolveLocalVariable(structMemberAccessor.Members[0]);
|
||||
|
||||
if (variable.Type is not StructType structType)
|
||||
{
|
||||
throw new Exception($"Cannot access struct member on {variable} since it is not a struct type");
|
||||
}
|
||||
|
||||
_builder.AppendLine($" mov rax, [rbp - {variable.Offset}]");
|
||||
|
||||
Type prevMemberType = structType;
|
||||
var prevMemberType = variable.Type;
|
||||
for (var i = 1; i < structMemberAccessor.Members.Count; i++)
|
||||
{
|
||||
if (prevMemberType is not StructType prevMemberStructType)
|
||||
{
|
||||
throw new Exception($"Cannot access {structMemberAccessor.Members[i]} on type {prevMemberType} because it is not a struct type");
|
||||
}
|
||||
|
||||
var structDefinition = _definitions.OfType<StructDefinitionNode>().FirstOrDefault(sd => sd.Name == prevMemberStructType.Name);
|
||||
var structDefinition = _definitions.OfType<StructDefinitionNode>().FirstOrDefault(sd => sd.Name == prevMemberType.Name);
|
||||
if (structDefinition == null)
|
||||
{
|
||||
throw new Exception($"Struct {prevMemberStructType} is not defined");
|
||||
throw new Exception($"Struct {prevMemberType} is not defined");
|
||||
}
|
||||
|
||||
var member = structDefinition.Members.FirstOrDefault(m => m.Name == structMemberAccessor.Members[i]);
|
||||
if (member == null)
|
||||
{
|
||||
throw new Exception($"Struct {prevMemberStructType} has no member with name {structMemberAccessor.Members[i]}");
|
||||
throw new Exception($"Struct {prevMemberType} has no member with name {structMemberAccessor.Members[i]}");
|
||||
}
|
||||
|
||||
var offset = structDefinition.Members.IndexOf(member);
|
||||
@@ -389,19 +363,6 @@ public class Generator
|
||||
}
|
||||
}
|
||||
|
||||
private void GenerateArrayIndexAccess(ArrayIndexAccessNode arrayIndexAccess, LocalFunc func)
|
||||
{
|
||||
GenerateArrayIndexPointerAccess(arrayIndexAccess.Identifier, arrayIndexAccess.Index, func);
|
||||
_builder.AppendLine(" mov rax, [rax]");
|
||||
}
|
||||
|
||||
private void GenerateArrayInitializer(ArrayInitializerNode arrayInitializer)
|
||||
{
|
||||
_builder.AppendLine($" mov rdi, {8 + arrayInitializer.Length * 8}");
|
||||
_builder.AppendLine(" call gc_alloc");
|
||||
_builder.AppendLine($" mov qword [rax], {arrayInitializer.Length}");
|
||||
}
|
||||
|
||||
private void GenerateBinaryExpression(BinaryExpressionNode binaryExpression, LocalFunc func)
|
||||
{
|
||||
GenerateExpression(binaryExpression.Left, func);
|
||||
@@ -459,108 +420,121 @@ public class Generator
|
||||
}
|
||||
}
|
||||
|
||||
private void GenerateComparison(Type type)
|
||||
private void GenerateComparison(NubType type)
|
||||
{
|
||||
switch (type)
|
||||
if (type.Equals(NubType.String))
|
||||
{
|
||||
case AnyType:
|
||||
throw new InvalidOperationException($"Cannot compare type {type}");
|
||||
case ArrayType:
|
||||
// compare pointers
|
||||
_builder.AppendLine(" cmp rax, rcx");
|
||||
break;
|
||||
case PrimitiveType:
|
||||
_builder.AppendLine(" cmp rax, rcx");
|
||||
break;
|
||||
case StringType:
|
||||
_builder.AppendLine(" mov rdi, rax");
|
||||
_builder.AppendLine(" mov rsi, rcx");
|
||||
_builder.AppendLine(" call str_cmp");
|
||||
break;
|
||||
default:
|
||||
throw new ArgumentOutOfRangeException(nameof(type));
|
||||
_builder.AppendLine(" mov rdi, rax");
|
||||
_builder.AppendLine(" mov rsi, rcx");
|
||||
_builder.AppendLine(" call str_cmp");
|
||||
}
|
||||
else if (type.Equals(NubType.Bool) || type.Equals(NubType.Int64) || type.Equals(NubType.Int32) || type.Equals(NubType.Int16) || type.Equals(NubType.Int8))
|
||||
{
|
||||
_builder.AppendLine(" cmp rax, rcx");
|
||||
}
|
||||
else
|
||||
{
|
||||
throw new ArgumentOutOfRangeException(nameof(type));
|
||||
}
|
||||
}
|
||||
|
||||
private void GenerateBinaryAddition(Type type)
|
||||
private void GenerateBinaryAddition(NubType type)
|
||||
{
|
||||
if (type is not PrimitiveType primitiveType)
|
||||
if (type.Equals(NubType.Int64))
|
||||
{
|
||||
throw new InvalidOperationException("Addition can only be done on primitive types");
|
||||
_builder.AppendLine(" add rax, rcx");
|
||||
}
|
||||
|
||||
switch (primitiveType.Kind)
|
||||
else if (type.Equals(NubType.Int32))
|
||||
{
|
||||
case PrimitiveTypeKind.Int64:
|
||||
_builder.AppendLine(" add rax, rcx");
|
||||
break;
|
||||
case PrimitiveTypeKind.Int32:
|
||||
_builder.AppendLine(" add eax, ecx");
|
||||
break;
|
||||
default:
|
||||
throw new InvalidOperationException($"Invalid type {primitiveType.Kind}");
|
||||
_builder.AppendLine(" add eax, ecx");
|
||||
}
|
||||
else if (type.Equals(NubType.Int16))
|
||||
{
|
||||
_builder.AppendLine(" add ax, cx");
|
||||
}
|
||||
else if (type.Equals(NubType.Int8))
|
||||
{
|
||||
_builder.AppendLine(" add al, cl");
|
||||
}
|
||||
else
|
||||
{
|
||||
throw new InvalidOperationException($"Invalid type for addition {type}");
|
||||
}
|
||||
}
|
||||
|
||||
private void GenerateBinarySubtraction(Type type)
|
||||
private void GenerateBinarySubtraction(NubType type)
|
||||
{
|
||||
if (type is not PrimitiveType primitiveType)
|
||||
if (type.Equals(NubType.Int64))
|
||||
{
|
||||
throw new InvalidOperationException("Subtraction can only be done on primitive types");
|
||||
_builder.AppendLine(" sub rax, rcx");
|
||||
}
|
||||
|
||||
switch (primitiveType.Kind)
|
||||
else if (type.Equals(NubType.Int32))
|
||||
{
|
||||
case PrimitiveTypeKind.Int64:
|
||||
_builder.AppendLine(" sub rax, rcx");
|
||||
break;
|
||||
case PrimitiveTypeKind.Int32:
|
||||
_builder.AppendLine(" sub eax, ecx");
|
||||
break;
|
||||
default:
|
||||
throw new InvalidOperationException($"Invalid type {primitiveType.Kind}");
|
||||
_builder.AppendLine(" sub eax, ecx");
|
||||
}
|
||||
else if (type.Equals(NubType.Int16))
|
||||
{
|
||||
_builder.AppendLine(" sub ax, cx");
|
||||
}
|
||||
else if (type.Equals(NubType.Int8))
|
||||
{
|
||||
_builder.AppendLine(" sub al, cl");
|
||||
}
|
||||
else
|
||||
{
|
||||
throw new InvalidOperationException($"Invalid type for subtraction {type}");
|
||||
}
|
||||
}
|
||||
|
||||
private void GenerateBinaryMultiplication(Type type)
|
||||
private void GenerateBinaryMultiplication(NubType type)
|
||||
{
|
||||
if (type is not PrimitiveType primitiveType)
|
||||
if (type.Equals(NubType.Int64))
|
||||
{
|
||||
throw new InvalidOperationException("Multiplication can only be done on primitive types");
|
||||
_builder.AppendLine(" imul rcx");
|
||||
}
|
||||
|
||||
switch (primitiveType.Kind)
|
||||
else if (type.Equals(NubType.Int32))
|
||||
{
|
||||
case PrimitiveTypeKind.Int64:
|
||||
_builder.AppendLine(" imul rcx");
|
||||
break;
|
||||
case PrimitiveTypeKind.Int32:
|
||||
_builder.AppendLine(" imul ecx");
|
||||
break;
|
||||
default:
|
||||
throw new InvalidOperationException($"Invalid type {primitiveType.Kind}");
|
||||
_builder.AppendLine(" imul ecx");
|
||||
}
|
||||
else if (type.Equals(NubType.Int16))
|
||||
{
|
||||
_builder.AppendLine(" imul cx");
|
||||
}
|
||||
else if (type.Equals(NubType.Int8))
|
||||
{
|
||||
_builder.AppendLine(" imul cl");
|
||||
}
|
||||
else
|
||||
{
|
||||
throw new InvalidOperationException($"Invalid type for multiplication {type}");
|
||||
}
|
||||
}
|
||||
|
||||
private void GenerateBinaryDivision(Type type)
|
||||
private void GenerateBinaryDivision(NubType type)
|
||||
{
|
||||
if (type is not PrimitiveType primitiveType)
|
||||
if (type.Equals(NubType.Int64))
|
||||
{
|
||||
throw new InvalidOperationException("Division can only be done on primitive types");
|
||||
_builder.AppendLine(" cqo");
|
||||
_builder.AppendLine(" idiv rcx");
|
||||
}
|
||||
|
||||
switch (primitiveType.Kind)
|
||||
else if (type.Equals(NubType.Int32))
|
||||
{
|
||||
case PrimitiveTypeKind.Int64:
|
||||
_builder.AppendLine(" cqo");
|
||||
_builder.AppendLine(" idiv rcx");
|
||||
break;
|
||||
case PrimitiveTypeKind.Int32:
|
||||
_builder.AppendLine(" cdq");
|
||||
_builder.AppendLine(" idiv ecx");
|
||||
break;
|
||||
default:
|
||||
throw new InvalidOperationException($"Invalid type {primitiveType.Kind}");
|
||||
_builder.AppendLine(" cdq");
|
||||
_builder.AppendLine(" idiv ecx");
|
||||
}
|
||||
else if (type.Equals(NubType.Int16))
|
||||
{
|
||||
_builder.AppendLine(" cwd");
|
||||
_builder.AppendLine(" idiv cx");
|
||||
}
|
||||
else if (type.Equals(NubType.Int8))
|
||||
{
|
||||
_builder.AppendLine(" cbw");
|
||||
_builder.AppendLine(" idiv cl");
|
||||
}
|
||||
else
|
||||
{
|
||||
throw new InvalidOperationException($"Invalid type for division {type}");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -586,34 +560,22 @@ public class Generator
|
||||
|
||||
private void GenerateLiteral(LiteralNode literal)
|
||||
{
|
||||
switch (literal.Type)
|
||||
if (literal.Type.Equals(NubType.String))
|
||||
{
|
||||
case StringType:
|
||||
{
|
||||
var label = _symbolTable.DefineString(literal.Literal);
|
||||
_builder.AppendLine($" mov rax, {label}");
|
||||
break;
|
||||
}
|
||||
case PrimitiveType primitive:
|
||||
{
|
||||
switch (primitive.Kind)
|
||||
{
|
||||
case PrimitiveTypeKind.Bool:
|
||||
_builder.AppendLine($" mov rax, {(bool.Parse(literal.Literal) ? "1" : "0")}");
|
||||
break;
|
||||
case PrimitiveTypeKind.Int64:
|
||||
_builder.AppendLine($" mov rax, {literal.Literal}");
|
||||
break;
|
||||
case PrimitiveTypeKind.Int32:
|
||||
_builder.AppendLine($" mov rax, {literal.Literal}");
|
||||
break;
|
||||
default:
|
||||
throw new Exception("Cannot convert literal to string");
|
||||
}
|
||||
break;
|
||||
}
|
||||
default:
|
||||
throw new ArgumentOutOfRangeException();
|
||||
var label = _symbolTable.DefineString(literal.Literal);
|
||||
_builder.AppendLine($" mov rax, {label}");
|
||||
}
|
||||
else if (literal.Type.Equals(NubType.Int64) || literal.Type.Equals(NubType.Int32) || literal.Type.Equals(NubType.Int16) || literal.Type.Equals(NubType.Int8))
|
||||
{
|
||||
_builder.AppendLine($" mov rax, {literal.Literal}");
|
||||
}
|
||||
else if (literal.Type.Equals(NubType.Bool))
|
||||
{
|
||||
_builder.AppendLine($" mov rax, {(bool.Parse(literal.Literal) ? "1" : "0")}");
|
||||
}
|
||||
else
|
||||
{
|
||||
throw new NotImplementedException($"Literal type {literal.Type} not implemented");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -706,31 +668,4 @@ public class Generator
|
||||
|
||||
_builder.AppendLine(" syscall");
|
||||
}
|
||||
|
||||
private void GenerateArrayIndexPointerAccess(IdentifierNode identifier, ExpressionNode index, LocalFunc func)
|
||||
{
|
||||
GenerateExpression(index, func);
|
||||
_builder.AppendLine(" push rax");
|
||||
GenerateIdentifier(identifier, func);
|
||||
_builder.AppendLine(" pop rdx");
|
||||
|
||||
// rcx now holds the length of the array which we can use to check bounds
|
||||
_builder.AppendLine(" mov rcx, [rax]");
|
||||
_builder.AppendLine(" cmp rdx, rcx");
|
||||
if (ZeroBasedIndexing)
|
||||
{
|
||||
_builder.AppendLine(" jge eb6e_oob_error");
|
||||
_builder.AppendLine(" cmp rdx, 0");
|
||||
}
|
||||
else
|
||||
{
|
||||
_builder.AppendLine(" jg eb6e_oob_error");
|
||||
_builder.AppendLine(" cmp rdx, 1");
|
||||
}
|
||||
_builder.AppendLine(" jl eb6e_oob_error");
|
||||
|
||||
_builder.AppendLine(" inc rdx");
|
||||
_builder.AppendLine(" shl rdx, 3");
|
||||
_builder.AppendLine(" add rax, rdx");
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user