This commit is contained in:
nub31
2025-11-03 13:46:25 +01:00
parent 085f7a1a6a
commit f231a45285
4 changed files with 133 additions and 63 deletions

View File

@@ -581,9 +581,25 @@ public class SizeNode(List<Token> tokens, NubType targetType) : RValue(tokens, n
} }
} }
public class CastNode(List<Token> tokens, NubType type, ExpressionNode value) : RValue(tokens, type) public class CastNode(List<Token> tokens, NubType type, ExpressionNode value, CastNode.Conversion conversionType) : RValue(tokens, type)
{ {
public enum Conversion
{
IntToInt,
FloatToFloat,
IntToFloat,
FloatToInt,
PointerToPointer,
PointerToUInt64,
UInt64ToPointer,
ConstArrayToArray,
ConstArrayToSlice,
}
public ExpressionNode Value { get; } = value; public ExpressionNode Value { get; } = value;
public Conversion ConversionType { get; } = conversionType;
public override IEnumerable<Node> Children() public override IEnumerable<Node> Children()
{ {

View File

@@ -322,9 +322,9 @@ public sealed class TypeChecker
return result; return result;
} }
if (IsCastAllowed(result.Type, expectedType)) if (IsCastAllowed(result.Type, expectedType, out var conversion))
{ {
return new CastNode(result.Tokens, expectedType, result); return new CastNode(result.Tokens, expectedType, result, conversion);
} }
} }
@@ -354,26 +354,39 @@ public sealed class TypeChecker
return value; return value;
} }
if (!IsCastAllowed(value.Type, expectedType, false)) if (!IsCastAllowed(value.Type, expectedType, out var conversion, false))
{ {
throw new CompileException(Diagnostic throw new CompileException(Diagnostic
.Error($"Cannot cast from {value.Type} to {expectedType}") .Error($"Cannot cast from {value.Type} to {expectedType}")
.Build()); .Build());
} }
return new CastNode(expression.Tokens, expectedType, value); return new CastNode(expression.Tokens, expectedType, value, conversion);
} }
private static bool IsCastAllowed(NubType from, NubType to, bool strict = true) private static bool IsCastAllowed(NubType from, NubType to, out CastNode.Conversion conversion, bool strict = true)
{ {
// note(nub31): Implicit casts // note(nub31): Implicit casts
switch (from) switch (from)
{ {
case NubIntType fromInt when to is NubIntType toInt && fromInt.Width < toInt.Width: case NubIntType fromInt when to is NubIntType toInt && fromInt.Width < toInt.Width:
{
conversion = CastNode.Conversion.IntToInt;
return true;
}
case NubPointerType when to is NubPointerType { BaseType: NubVoidType }: case NubPointerType when to is NubPointerType { BaseType: NubVoidType }:
{
conversion = CastNode.Conversion.PointerToPointer;
return true;
}
case NubConstArrayType constArrayType1 when to is NubArrayType arrayType && constArrayType1.ElementType == arrayType.ElementType: case NubConstArrayType constArrayType1 when to is NubArrayType arrayType && constArrayType1.ElementType == arrayType.ElementType:
{
conversion = CastNode.Conversion.ConstArrayToArray;
return true;
}
case NubConstArrayType constArrayType3 when to is NubSliceType sliceType2 && constArrayType3.ElementType == sliceType2.ElementType: case NubConstArrayType constArrayType3 when to is NubSliceType sliceType2 && constArrayType3.ElementType == sliceType2.ElementType:
{ {
conversion = CastNode.Conversion.ConstArrayToSlice;
return true; return true;
} }
} }
@@ -384,19 +397,44 @@ public sealed class TypeChecker
switch (from) switch (from)
{ {
case NubIntType when to is NubIntType: case NubIntType when to is NubIntType:
case NubIntType when to is NubFloatType:
case NubFloatType when to is NubIntType:
case NubFloatType when to is NubFloatType:
case NubPointerType when to is NubPointerType:
case NubPointerType when to is NubIntType:
case NubIntType when to is NubPointerType:
{ {
conversion = CastNode.Conversion.IntToInt;
return true;
}
case NubIntType when to is NubFloatType:
{
conversion = CastNode.Conversion.IntToFloat;
return true;
}
case NubFloatType when to is NubIntType:
{
conversion = CastNode.Conversion.FloatToInt;
return true;
}
case NubFloatType when to is NubFloatType:
{
conversion = CastNode.Conversion.FloatToFloat;
return true;
}
case NubPointerType when to is NubPointerType:
{
conversion = CastNode.Conversion.PointerToPointer;
return true;
}
case NubPointerType when to is NubIntType { Signed: false, Width: 64 }:
{
conversion = CastNode.Conversion.PointerToUInt64;
return true;
}
case NubIntType { Signed: false, Width: 64 } when to is NubPointerType:
{
conversion = CastNode.Conversion.UInt64ToPointer;
return true; return true;
} }
} }
} }
conversion = default;
return false; return false;
} }

View File

@@ -1,3 +1,4 @@
using System.Diagnostics;
using System.Text; using System.Text;
using NubLang.Ast; using NubLang.Ast;
using NubLang.Modules; using NubLang.Modules;
@@ -156,9 +157,7 @@ public class LlvmGenerator
EmitWhile(writer, whileNode); EmitWhile(writer, whileNode);
break; break;
default: default:
{ throw new ArgumentOutOfRangeException(nameof(statementNode));
throw new NotImplementedException();
}
} }
} }
@@ -590,73 +589,81 @@ public class LlvmGenerator
private Tmp EmitCast(IndentedTextWriter writer, CastNode castNode) private Tmp EmitCast(IndentedTextWriter writer, CastNode castNode)
{ {
var source = Unwrap(writer, EmitExpression(writer, castNode.Value)); var source = Unwrap(writer, EmitExpression(writer, castNode.Value));
var sourceType = castNode.Value.Type;
var targetType = castNode.Type;
var result = NewTmp("cast"); var result = NewTmp("cast");
switch (sourceType, targetType) switch (castNode.ConversionType)
{ {
case (NubIntType sourceInt, NubIntType targetInt): case CastNode.Conversion.IntToInt:
{ {
if (sourceInt.Width < targetInt.Width) var sourceInt = (NubIntType)castNode.Value.Type;
{ var targetInt = (NubIntType)castNode.Type;
var op = sourceInt.Signed ? "sext" : "zext";
writer.WriteLine($"{result} = {op} {MapType(sourceType)} {source} to {MapType(targetType)}");
}
else if (sourceInt.Width > targetInt.Width)
{
writer.WriteLine($"{result} = trunc {MapType(sourceType)} {source} to {MapType(targetType)}");
}
else
{
writer.WriteLine($"{result} = bitcast {MapType(sourceType)} {source} to {MapType(targetType)}");
}
break; var op = sourceInt.Width < targetInt.Width
} ? sourceInt.Signed
case (NubFloatType sourceFloat, NubFloatType targetFloat): ? "sext"
{ : "zext"
if (sourceFloat.Width < targetFloat.Width) : sourceInt.Width > targetInt.Width
{ ? "trunc"
writer.WriteLine($"{result} = fpext {MapType(sourceType)} {source} to {MapType(targetType)}"); : "bitcast";
}
else
{
writer.WriteLine($"{result} = fptrunc {MapType(sourceType)} {source} to {MapType(targetType)}");
}
writer.WriteLine($"{result} = {op} {MapType(sourceInt)} {source} to {MapType(targetInt)}");
break; break;
} }
case (NubIntType intType, NubFloatType): case CastNode.Conversion.FloatToFloat:
{ {
var intToFloatOp = intType.Signed ? "sitofp" : "uitofp"; var sourceFloat = (NubFloatType)castNode.Value.Type;
writer.WriteLine($"{result} = {intToFloatOp} {MapType(sourceType)} {source} to {MapType(targetType)}"); var targetFloat = (NubFloatType)castNode.Type;
var op = sourceFloat.Width < targetFloat.Width ? "fpext" : "fptrunc";
writer.WriteLine($"{result} = {op} {MapType(sourceFloat)} {source} to {MapType(targetFloat)}");
break; break;
} }
case (NubFloatType, NubIntType targetInt): case CastNode.Conversion.IntToFloat:
{ {
var floatToIntOp = targetInt.Signed ? "fptosi" : "fptoui"; var sourceInt = (NubIntType)castNode.Value.Type;
writer.WriteLine($"{result} = {floatToIntOp} {MapType(sourceType)} {source} to {MapType(targetType)}"); var targetFloat = (NubFloatType)castNode.Type;
var op = sourceInt.Signed ? "sitofp" : "uitofp";
writer.WriteLine($"{result} = {op} {MapType(sourceInt)} {source} to {MapType(targetFloat)}");
break; break;
} }
case (NubPointerType, NubPointerType): case CastNode.Conversion.FloatToInt:
case (NubPointerType, NubIntType):
case (NubIntType, NubPointerType):
{ {
writer.WriteLine($"{result} = inttoptr {MapType(sourceType)} {source} to {MapType(targetType)}"); var sourceFloat = (NubFloatType)castNode.Value.Type;
var targetInt = (NubIntType)castNode.Type;
var op = targetInt.Signed ? "fptosi" : "fptoui";
writer.WriteLine($"{result} = {op} {MapType(sourceFloat)} {source} to {MapType(targetInt)}");
break; break;
} }
case CastNode.Conversion.PointerToPointer:
case CastNode.Conversion.PointerToUInt64:
case CastNode.Conversion.UInt64ToPointer:
{
writer.WriteLine($"{result} = inttoptr {MapType(castNode.Value.Type)} {source} to {MapType(castNode.Type)}");
break;
}
case CastNode.Conversion.ConstArrayToArray:
{
var sourceConstArrayType = (NubConstArrayType)castNode.Value.Type;
var targetArrayType = (NubArrayType)castNode.Type;
writer.WriteLine($"{result} = getelementptr {MapType(sourceConstArrayType)}, {MapType(targetArrayType)} {source}, i32 0, i32 0");
break;
}
case CastNode.Conversion.ConstArrayToSlice:
{
throw new NotImplementedException();
}
default: default:
{ {
throw new NotImplementedException($"Cast from {sourceType} to {targetType} not implemented"); throw new UnreachableException();
} }
} }
return new Tmp(result, castNode.Type, false); return new Tmp(result, castNode.Type, false);
} }
private Tmp EmitConstArrayIndexAccess(IndentedTextWriter writer, ConstArrayIndexAccessNode constArrayIndexAccessNode) private Tmp EmitConstArrayIndexAccess(IndentedTextWriter writer, ConstArrayIndexAccessNode constArrayIndexAccessNode)
{ {
var arrayPtr = Unwrap(writer, EmitExpression(writer, constArrayIndexAccessNode.Target)); var arrayPtr = Unwrap(writer, EmitExpression(writer, constArrayIndexAccessNode.Target));
@@ -898,6 +905,7 @@ public class LlvmGenerator
switch (unaryExpressionNode.Operator) switch (unaryExpressionNode.Operator)
{ {
case UnaryOperator.Negate: case UnaryOperator.Negate:
{
switch (unaryExpressionNode.Operand.Type) switch (unaryExpressionNode.Operand.Type)
{ {
case NubIntType intType: case NubIntType intType:
@@ -907,16 +915,21 @@ public class LlvmGenerator
writer.WriteLine($"{result} = fneg {MapType(floatType)} {operand}"); writer.WriteLine($"{result} = fneg {MapType(floatType)} {operand}");
break; break;
default: default:
throw new ArgumentOutOfRangeException(); throw new UnreachableException();
} }
break; break;
}
case UnaryOperator.Invert: case UnaryOperator.Invert:
{
writer.WriteLine($"{result} = xor i1 {operand}, true"); writer.WriteLine($"{result} = xor i1 {operand}, true");
break; break;
}
default: default:
{
throw new ArgumentOutOfRangeException(); throw new ArgumentOutOfRangeException();
} }
}
return new Tmp(result, unaryExpressionNode.Type, false); return new Tmp(result, unaryExpressionNode.Type, false);
} }

View File

@@ -9,9 +9,12 @@ struct Test
extern "main" func main(argc: i64, argv: [?]^i8) extern "main" func main(argc: i64, argv: [?]^i8)
{ {
let x: Test = { let x = [1, 2, 3]
field = 23
test(x)
} }
test::test() func test(arr: [?]i64)
{
} }