Clean up cast syntax and rules

This commit is contained in:
nub31
2025-10-23 21:39:24 +02:00
parent a7c45784b9
commit 9f91e42d63
6 changed files with 142 additions and 145 deletions

View File

@@ -439,31 +439,7 @@ public record DereferenceNode(List<Token> Tokens, NubType Type, ExpressionNode T
} }
} }
public record ConvertIntNode(List<Token> Tokens, ExpressionNode Value, int StartWidth, int TargetWidth, bool StartSignedness, bool TargetSignedness) : RValueExpressionNode(Tokens, new NubIntType(TargetSignedness, TargetWidth)) public record SizeNode(List<Token> Tokens, NubType Type, NubType TargetType) : RValueExpressionNode(Tokens, Type)
{
public override IEnumerable<Node> Children()
{
yield return Value;
}
}
public record ConvertFloatNode(List<Token> Tokens, ExpressionNode Value, int StartWidth, int TargetWidth) : RValueExpressionNode(Tokens, new NubFloatType(TargetWidth))
{
public override IEnumerable<Node> Children()
{
yield return Value;
}
}
public record ConvertCStringToStringNode(List<Token> Tokens, ExpressionNode Value) : RValueExpressionNode(Tokens, new NubStringType())
{
public override IEnumerable<Node> Children()
{
yield return Value;
}
}
public record SizeBuiltinNode(List<Token> Tokens, NubType Type, NubType TargetType) : RValueExpressionNode(Tokens, Type)
{ {
public override IEnumerable<Node> Children() public override IEnumerable<Node> Children()
{ {
@@ -471,7 +447,7 @@ public record SizeBuiltinNode(List<Token> Tokens, NubType Type, NubType TargetTy
} }
} }
public record FloatToIntBuiltinNode(List<Token> Tokens, NubType Type, ExpressionNode Value, NubFloatType ValueType, NubIntType TargetType) : RValueExpressionNode(Tokens, Type) public record CastNode(List<Token> Tokens, NubType Type, ExpressionNode Value) : RValueExpressionNode(Tokens, Type)
{ {
public override IEnumerable<Node> Children() public override IEnumerable<Node> Children()
{ {
@@ -479,14 +455,6 @@ public record FloatToIntBuiltinNode(List<Token> Tokens, NubType Type, Expression
} }
} }
public record ConstArrayToSliceNode(List<Token> Tokens, NubType Type, ExpressionNode Array) : RValueExpressionNode(Tokens, Type)
{
public override IEnumerable<Node> Children()
{
yield return Array;
}
}
public record EnumReferenceIntermediateNode(List<Token> Tokens, string Module, string Name) : IntermediateExpression(Tokens) public record EnumReferenceIntermediateNode(List<Token> Tokens, string Module, string Name) : IntermediateExpression(Tokens)
{ {
public override IEnumerable<Node> Children() public override IEnumerable<Node> Children()

View File

@@ -309,69 +309,97 @@ public sealed class TypeChecker
FloatLiteralSyntax expression => CheckFloatLiteral(expression, expectedType), FloatLiteralSyntax expression => CheckFloatLiteral(expression, expectedType),
MemberAccessSyntax expression => CheckMemberAccess(expression, expectedType), MemberAccessSyntax expression => CheckMemberAccess(expression, expectedType),
StructInitializerSyntax expression => CheckStructInitializer(expression, expectedType), StructInitializerSyntax expression => CheckStructInitializer(expression, expectedType),
InterpretBuiltinSyntax expression => CheckExpression(expression.Target, expectedType) with { Type = ResolveType(expression.Type) }, InterpretSyntax expression => CheckExpression(expression.Target, expectedType) with { Type = ResolveType(expression.Type) },
SizeBuiltinSyntax expression => new SizeBuiltinNode(node.Tokens, new NubIntType(false, 64), ResolveType(expression.Type)), SizeSyntax expression => new SizeNode(node.Tokens, new NubIntType(false, 64), ResolveType(expression.Type)),
FloatToIntBuiltinSyntax expression => CheckFloatToInt(expression, expectedType), CastSyntax expression => CheckCast(expression, expectedType),
_ => throw new ArgumentOutOfRangeException(nameof(node)) _ => throw new ArgumentOutOfRangeException(nameof(node))
}; };
switch (expectedType) if (expectedType != null)
{ {
// note(nub31): Implicit conversion of const array to unsized array if (result.Type == expectedType)
case NubArrayType when result.Type is NubConstArrayType constArrayType:
{
return result with { Type = new NubArrayType(constArrayType.ElementType) };
}
// note(nub31): Implicit conversion of const array to slice
case NubSliceType when result.Type is NubConstArrayType constArrayType:
{
return new ConstArrayToSliceNode(result.Tokens, new NubSliceType(constArrayType.ElementType), result);
}
// note(nub31): Implicit conversion of int to larger int
case NubIntType expectedIntType when result.Type is NubIntType intType && expectedIntType.Width > intType.Width:
{
return new ConvertIntNode(result.Tokens, result, intType.Width, expectedIntType.Width, intType.Signed, expectedIntType.Signed);
}
// note(nub31): Implicit conversion of f32 to f64
case NubFloatType expectedFloatType when result.Type is NubFloatType floatType && expectedFloatType.Width > floatType.Width:
{
return new ConvertFloatNode(result.Tokens, result, floatType.Width, expectedFloatType.Width);
}
// note(nub31): Implicit conversion of cstring to string
case NubStringType when result.Type is NubCStringType:
{
return new ConvertCStringToStringNode(result.Tokens, result);
}
// note(nub31): No implicit conversion was possible or the result value was already the correct type
default:
{ {
return result; return result;
} }
if (IsCastAllowed(result.Type, expectedType))
{
return new CastNode(result.Tokens, expectedType, result);
}
} }
return result;
} }
// todo(nub31): Infer int type instead of explicit type syntax private ExpressionNode CheckCast(CastSyntax expression, NubType? expectedType)
private FloatToIntBuiltinNode CheckFloatToInt(FloatToIntBuiltinSyntax expression, NubType? _)
{ {
var value = CheckExpression(expression.Value); if (expectedType == null)
if (value.Type is not NubFloatType sourceFloatType)
{ {
throw new TypeCheckerException(Diagnostic throw new TypeCheckerException(Diagnostic
.Error("Source type of float to int conversion must be an float") .Error("Unable to infer target type of cast")
.At(expression.Value) .At(expression)
.WithHelp("Specify target type where value is used")
.Build()); .Build());
} }
var targetType = ResolveType(expression.Type); var value = CheckExpression(expression.Value, expectedType);
if (targetType is not NubIntType targetIntType)
if (value.Type == expectedType)
{
Diagnostics.Add(Diagnostic
.Warning("Target type of cast is same as the value. Cast is unnecessary")
.At(expression)
.Build());
return value;
}
if (!IsCastAllowed(value.Type, expectedType, false))
{ {
throw new TypeCheckerException(Diagnostic throw new TypeCheckerException(Diagnostic
.Error("Target type of float to int conversion must be an integer") .Error($"Cannot cast from {value.Type} to {expectedType}")
.At(expression.Type)
.Build()); .Build());
} }
return new FloatToIntBuiltinNode(expression.Tokens, targetIntType, value, sourceFloatType, targetIntType); return new CastNode(expression.Tokens, expectedType, value);
}
private static bool IsCastAllowed(NubType from, NubType to, bool strict = true)
{
// note(nub31): Implicit casts
switch (from)
{
case NubIntType fromInt when to is NubIntType toInt && fromInt.Width < toInt.Width:
case NubPointerType when to is NubPointerType { BaseType: NubVoidType }:
case NubConstArrayType constArrayType1 when to is NubArrayType arrayType && constArrayType1.ElementType == arrayType.ElementType:
case NubConstArrayType constArrayType3 when to is NubSliceType sliceType2 && constArrayType3.ElementType == sliceType2.ElementType:
case NubCStringType when to is NubStringType:
{
return true;
}
}
if (!strict)
{
// note(nub31): Explicit casts
switch (from)
{
case NubIntType when to is NubIntType:
case NubIntType when to is NubFloatType:
case NubFloatType when to is NubIntType:
case NubFloatType when to is NubFloatType:
case NubPointerType when to is NubPointerType:
case NubPointerType when to is NubIntType:
case NubIntType when to is NubPointerType:
case NubCStringType when to is NubPointerType { BaseType: NubIntType { Width: 8 } }:
{
return true;
}
}
}
return false;
} }
private AddressOfNode CheckAddressOf(AddressOfSyntax expression, NubType? expectedType) private AddressOfNode CheckAddressOf(AddressOfSyntax expression, NubType? expectedType)
@@ -499,6 +527,13 @@ public sealed class TypeChecker
} }
var right = CheckExpression(expression.Right, left.Type); var right = CheckExpression(expression.Right, left.Type);
if (right.Type != left.Type)
{
throw new TypeCheckerException(Diagnostic
.Error($"Expected type {left.Type} from left side of binary expression, but got {right.Type}")
.At(expression.Right)
.Build());
}
return new BinaryExpressionNode(expression.Tokens, new NubBoolType(), left, op, right); return new BinaryExpressionNode(expression.Tokens, new NubBoolType(), left, op, right);
} }
@@ -517,6 +552,13 @@ public sealed class TypeChecker
} }
var right = CheckExpression(expression.Right, left.Type); var right = CheckExpression(expression.Right, left.Type);
if (right.Type != left.Type)
{
throw new TypeCheckerException(Diagnostic
.Error($"Expected type {left.Type} from left side of binary expression, but got {right.Type}")
.At(expression.Right)
.Build());
}
return new BinaryExpressionNode(expression.Tokens, new NubBoolType(), left, op, right); return new BinaryExpressionNode(expression.Tokens, new NubBoolType(), left, op, right);
} }
@@ -533,21 +575,36 @@ public sealed class TypeChecker
} }
var right = CheckExpression(expression.Right, left.Type); var right = CheckExpression(expression.Right, left.Type);
if (right.Type != left.Type)
{
throw new TypeCheckerException(Diagnostic
.Error($"Expected type {left.Type} from left side of binary expression, but got {right.Type}")
.At(expression.Right)
.Build());
}
return new BinaryExpressionNode(expression.Tokens, new NubBoolType(), left, op, right); return new BinaryExpressionNode(expression.Tokens, new NubBoolType(), left, op, right);
} }
case BinaryOperatorSyntax.Plus: case BinaryOperatorSyntax.Plus:
{ {
var left = CheckExpression(expression.Left, expectedType); var left = CheckExpression(expression.Left, expectedType);
if (left.Type is not NubIntType and not NubFloatType and not NubStringType and not NubCStringType) if (left.Type is not NubIntType and not NubFloatType)
{ {
throw new TypeCheckerException(Diagnostic throw new TypeCheckerException(Diagnostic
.Error("The plus operator must be used with int, float, cstring or string types") .Error("The plus operator must only be used with int and float types")
.At(expression.Left) .At(expression.Left)
.Build()); .Build());
} }
var right = CheckExpression(expression.Right, left.Type); var right = CheckExpression(expression.Right, left.Type);
if (right.Type != left.Type)
{
throw new TypeCheckerException(Diagnostic
.Error($"Expected type {left.Type} from left side of binary expression, but got {right.Type}")
.At(expression.Right)
.Build());
}
return new BinaryExpressionNode(expression.Tokens, left.Type, left, op, right); return new BinaryExpressionNode(expression.Tokens, left.Type, left, op, right);
} }
case BinaryOperatorSyntax.Minus: case BinaryOperatorSyntax.Minus:
@@ -565,6 +622,13 @@ public sealed class TypeChecker
} }
var right = CheckExpression(expression.Right, left.Type); var right = CheckExpression(expression.Right, left.Type);
if (right.Type != left.Type)
{
throw new TypeCheckerException(Diagnostic
.Error($"Expected type {left.Type} from left side of binary expression, but got {right.Type}")
.At(expression.Right)
.Build());
}
return new BinaryExpressionNode(expression.Tokens, left.Type, left, op, right); return new BinaryExpressionNode(expression.Tokens, left.Type, left, op, right);
} }
@@ -584,6 +648,13 @@ public sealed class TypeChecker
} }
var right = CheckExpression(expression.Right, left.Type); var right = CheckExpression(expression.Right, left.Type);
if (right.Type != left.Type)
{
throw new TypeCheckerException(Diagnostic
.Error($"Expected type {left.Type} from left side of binary expression, but got {right.Type}")
.At(expression.Right)
.Build());
}
return new BinaryExpressionNode(expression.Tokens, left.Type, left, op, right); return new BinaryExpressionNode(expression.Tokens, left.Type, left, op, right);
} }

View File

@@ -339,19 +339,15 @@ public class Generator
BoolLiteralNode boolLiteralNode => boolLiteralNode.Value ? "true" : "false", BoolLiteralNode boolLiteralNode => boolLiteralNode.Value ? "true" : "false",
ConstArrayIndexAccessNode constArrayIndexAccessNode => EmitConstArrayIndexAccess(constArrayIndexAccessNode), ConstArrayIndexAccessNode constArrayIndexAccessNode => EmitConstArrayIndexAccess(constArrayIndexAccessNode),
ConstArrayInitializerNode constArrayInitializerNode => EmitConstArrayInitializer(constArrayInitializerNode), ConstArrayInitializerNode constArrayInitializerNode => EmitConstArrayInitializer(constArrayInitializerNode),
ConstArrayToSliceNode constArrayToSliceNode => EmitConstArrayToSlice(constArrayToSliceNode),
ConvertCStringToStringNode convertCStringToStringNode => EmitConvertCStringToString(convertCStringToStringNode),
ConvertFloatNode convertFloatNode => EmitConvertFloat(convertFloatNode),
ConvertIntNode convertIntNode => EmitConvertInt(convertIntNode),
CStringLiteralNode cStringLiteralNode => $"\"{cStringLiteralNode.Value}\"", CStringLiteralNode cStringLiteralNode => $"\"{cStringLiteralNode.Value}\"",
DereferenceNode dereferenceNode => EmitDereference(dereferenceNode), DereferenceNode dereferenceNode => EmitDereference(dereferenceNode),
Float32LiteralNode float32LiteralNode => EmitFloat32Literal(float32LiteralNode), Float32LiteralNode float32LiteralNode => EmitFloat32Literal(float32LiteralNode),
Float64LiteralNode float64LiteralNode => EmitFloat64Literal(float64LiteralNode), Float64LiteralNode float64LiteralNode => EmitFloat64Literal(float64LiteralNode),
FloatToIntBuiltinNode floatToIntBuiltinNode => EmitFloatToIntBuiltin(floatToIntBuiltinNode), CastNode castNode => EmitCast(castNode),
FuncCallNode funcCallNode => EmitFuncCall(funcCallNode), FuncCallNode funcCallNode => EmitFuncCall(funcCallNode),
FuncIdentifierNode funcIdentifierNode => FuncName(funcIdentifierNode.Module, funcIdentifierNode.Name, funcIdentifierNode.ExternSymbol), FuncIdentifierNode funcIdentifierNode => FuncName(funcIdentifierNode.Module, funcIdentifierNode.Name, funcIdentifierNode.ExternSymbol),
AddressOfNode addressOfNode => EmitAddressOf(addressOfNode), AddressOfNode addressOfNode => EmitAddressOf(addressOfNode),
SizeBuiltinNode sizeBuiltinNode => $"sizeof({CType.Create(sizeBuiltinNode.TargetType)})", SizeNode sizeBuiltinNode => $"sizeof({CType.Create(sizeBuiltinNode.TargetType)})",
SliceIndexAccessNode sliceIndexAccessNode => EmitSliceArrayIndexAccess(sliceIndexAccessNode), SliceIndexAccessNode sliceIndexAccessNode => EmitSliceArrayIndexAccess(sliceIndexAccessNode),
StringLiteralNode stringLiteralNode => EmitStringLiteral(stringLiteralNode), StringLiteralNode stringLiteralNode => EmitStringLiteral(stringLiteralNode),
StructFieldAccessNode structFieldAccessNode => EmitStructFieldAccess(structFieldAccessNode), StructFieldAccessNode structFieldAccessNode => EmitStructFieldAccess(structFieldAccessNode),
@@ -442,46 +438,6 @@ public class Generator
return $"({CType.Create(arrayType.ElementType)}[{arrayType.Size}]){{{string.Join(", ", values)}}}"; return $"({CType.Create(arrayType.ElementType)}[{arrayType.Size}]){{{string.Join(", ", values)}}}";
} }
private string EmitConstArrayToSlice(ConstArrayToSliceNode constArrayToSliceNode)
{
var arrayType = (NubConstArrayType)constArrayToSliceNode.Array.Type;
var array = EmitExpression(constArrayToSliceNode.Array);
return $"(slice){{.length = {arrayType.Size}, .data = (void*){array}}}";
}
private string EmitConvertCStringToString(ConvertCStringToStringNode convertCStringToStringNode)
{
var value = EmitExpression(convertCStringToStringNode.Value);
return $"(string){{.length = strlen({value}), .data = {value}}}";
}
private string EmitConvertFloat(ConvertFloatNode convertFloatNode)
{
var value = EmitExpression(convertFloatNode.Value);
var targetCast = convertFloatNode.TargetWidth switch
{
32 => "f32",
64 => "f64",
_ => throw new ArgumentOutOfRangeException()
};
return $"({targetCast}){value}";
}
private string EmitConvertInt(ConvertIntNode convertIntNode)
{
var value = EmitExpression(convertIntNode.Value);
var targetType = convertIntNode.TargetWidth switch
{
8 => convertIntNode.TargetSignedness ? "int8_t" : "uint8_t",
16 => convertIntNode.TargetSignedness ? "int16_t" : "uint16_t",
32 => convertIntNode.TargetSignedness ? "int32_t" : "uint32_t",
64 => convertIntNode.TargetSignedness ? "int64_t" : "uint64_t",
_ => throw new ArgumentOutOfRangeException()
};
return $"({targetType}){value}";
}
private string EmitDereference(DereferenceNode dereferenceNode) private string EmitDereference(DereferenceNode dereferenceNode)
{ {
var pointer = EmitExpression(dereferenceNode.Target); var pointer = EmitExpression(dereferenceNode.Target);
@@ -510,18 +466,22 @@ public class Generator
return str; return str;
} }
private string EmitFloatToIntBuiltin(FloatToIntBuiltinNode floatToIntBuiltinNode) private string EmitCast(CastNode castNode)
{ {
var value = EmitExpression(floatToIntBuiltinNode.Value); var value = EmitExpression(castNode.Value);
var targetType = floatToIntBuiltinNode.TargetType.Width switch
if (castNode is { Type: NubSliceType, Value.Type: NubConstArrayType arrayType })
{ {
8 => floatToIntBuiltinNode.TargetType.Signed ? "int8_t" : "uint8_t", return $"(slice){{.length = {arrayType.Size}, .data = (void*){value}}}";
16 => floatToIntBuiltinNode.TargetType.Signed ? "int16_t" : "uint16_t", }
32 => floatToIntBuiltinNode.TargetType.Signed ? "int32_t" : "uint32_t",
64 => floatToIntBuiltinNode.TargetType.Signed ? "int64_t" : "uint64_t", // todo(nub31): Stop depending on libc
_ => throw new ArgumentOutOfRangeException() if (castNode is { Type: NubCStringType, Value.Type: NubStringType })
}; {
return $"({targetType}){value}"; return $"(string){{.length = strlen({value}), .data = {value}}}";
}
return $"({CType.Create(castNode.Type)}){value}";
} }
private string EmitFuncCall(FuncCallNode funcCallNode) private string EmitFuncCall(FuncCallNode funcCallNode)

View File

@@ -508,7 +508,7 @@ public sealed class Parser
{ {
var type = ParseType(); var type = ParseType();
ExpectSymbol(Symbol.CloseParen); ExpectSymbol(Symbol.CloseParen);
return new SizeBuiltinSyntax(GetTokens(startIndex), type); return new SizeSyntax(GetTokens(startIndex), type);
} }
case "interpret": case "interpret":
{ {
@@ -516,15 +516,13 @@ public sealed class Parser
ExpectSymbol(Symbol.Comma); ExpectSymbol(Symbol.Comma);
var expression = ParseExpression(); var expression = ParseExpression();
ExpectSymbol(Symbol.CloseParen); ExpectSymbol(Symbol.CloseParen);
return new InterpretBuiltinSyntax(GetTokens(startIndex), type, expression); return new InterpretSyntax(GetTokens(startIndex), type, expression);
} }
case "floatToInt": case "cast":
{ {
var type = ParseType();
ExpectSymbol(Symbol.Comma);
var expression = ParseExpression(); var expression = ParseExpression();
ExpectSymbol(Symbol.CloseParen); ExpectSymbol(Symbol.CloseParen);
return new FloatToIntBuiltinSyntax(GetTokens(startIndex), type, expression); return new CastSyntax(GetTokens(startIndex), expression);
} }
default: default:
{ {

View File

@@ -112,11 +112,11 @@ public record StructInitializerSyntax(List<Token> Tokens, TypeSyntax? StructType
public record DereferenceSyntax(List<Token> Tokens, ExpressionSyntax Target) : ExpressionSyntax(Tokens); public record DereferenceSyntax(List<Token> Tokens, ExpressionSyntax Target) : ExpressionSyntax(Tokens);
public record SizeBuiltinSyntax(List<Token> Tokens, TypeSyntax Type) : ExpressionSyntax(Tokens); public record SizeSyntax(List<Token> Tokens, TypeSyntax Type) : ExpressionSyntax(Tokens);
public record InterpretBuiltinSyntax(List<Token> Tokens, TypeSyntax Type, ExpressionSyntax Target) : ExpressionSyntax(Tokens); public record InterpretSyntax(List<Token> Tokens, TypeSyntax Type, ExpressionSyntax Target) : ExpressionSyntax(Tokens);
public record FloatToIntBuiltinSyntax(List<Token> Tokens, TypeSyntax Type, ExpressionSyntax Value) : ExpressionSyntax(Tokens); public record CastSyntax(List<Token> Tokens, ExpressionSyntax Value) : ExpressionSyntax(Tokens);
#endregion #endregion

View File

@@ -37,8 +37,8 @@ extern "main" func main(argc: i64, argv: [?]cstring): i64
direction.y = -1 direction.y = -1
} }
x = x + @floatToInt(i32, direction.x * speed * raylib::GetFrameTime()) x = x + @cast(direction.x * speed * raylib::GetFrameTime())
y = y + @floatToInt(i32, direction.y * speed * raylib::GetFrameTime()) y = y + @cast(direction.y * speed * raylib::GetFrameTime())
raylib::BeginDrawing() raylib::BeginDrawing()
{ {