From 9f91e42d63b7cf39e5230a956055c11241158f77 Mon Sep 17 00:00:00 2001 From: nub31 Date: Thu, 23 Oct 2025 21:39:24 +0200 Subject: [PATCH] Clean up cast syntax and rules --- compiler/NubLang/Ast/Node.cs | 36 +---- compiler/NubLang/Ast/TypeChecker.cs | 159 ++++++++++++++++------- compiler/NubLang/Generation/Generator.cs | 72 +++------- compiler/NubLang/Syntax/Parser.cs | 10 +- compiler/NubLang/Syntax/Syntax.cs | 6 +- examples/raylib/main.nub | 4 +- 6 files changed, 142 insertions(+), 145 deletions(-) diff --git a/compiler/NubLang/Ast/Node.cs b/compiler/NubLang/Ast/Node.cs index 73c3641..020e8c7 100644 --- a/compiler/NubLang/Ast/Node.cs +++ b/compiler/NubLang/Ast/Node.cs @@ -439,31 +439,7 @@ public record DereferenceNode(List Tokens, NubType Type, ExpressionNode T } } -public record ConvertIntNode(List Tokens, ExpressionNode Value, int StartWidth, int TargetWidth, bool StartSignedness, bool TargetSignedness) : RValueExpressionNode(Tokens, new NubIntType(TargetSignedness, TargetWidth)) -{ - public override IEnumerable Children() - { - yield return Value; - } -} - -public record ConvertFloatNode(List Tokens, ExpressionNode Value, int StartWidth, int TargetWidth) : RValueExpressionNode(Tokens, new NubFloatType(TargetWidth)) -{ - public override IEnumerable Children() - { - yield return Value; - } -} - -public record ConvertCStringToStringNode(List Tokens, ExpressionNode Value) : RValueExpressionNode(Tokens, new NubStringType()) -{ - public override IEnumerable Children() - { - yield return Value; - } -} - -public record SizeBuiltinNode(List Tokens, NubType Type, NubType TargetType) : RValueExpressionNode(Tokens, Type) +public record SizeNode(List Tokens, NubType Type, NubType TargetType) : RValueExpressionNode(Tokens, Type) { public override IEnumerable Children() { @@ -471,7 +447,7 @@ public record SizeBuiltinNode(List Tokens, NubType Type, NubType TargetTy } } -public record FloatToIntBuiltinNode(List Tokens, NubType Type, ExpressionNode Value, NubFloatType ValueType, NubIntType TargetType) : RValueExpressionNode(Tokens, Type) +public record CastNode(List Tokens, NubType Type, ExpressionNode Value) : RValueExpressionNode(Tokens, Type) { public override IEnumerable Children() { @@ -479,14 +455,6 @@ public record FloatToIntBuiltinNode(List Tokens, NubType Type, Expression } } -public record ConstArrayToSliceNode(List Tokens, NubType Type, ExpressionNode Array) : RValueExpressionNode(Tokens, Type) -{ - public override IEnumerable Children() - { - yield return Array; - } -} - public record EnumReferenceIntermediateNode(List Tokens, string Module, string Name) : IntermediateExpression(Tokens) { public override IEnumerable Children() diff --git a/compiler/NubLang/Ast/TypeChecker.cs b/compiler/NubLang/Ast/TypeChecker.cs index f5b3f20..4801eb9 100644 --- a/compiler/NubLang/Ast/TypeChecker.cs +++ b/compiler/NubLang/Ast/TypeChecker.cs @@ -309,69 +309,97 @@ public sealed class TypeChecker FloatLiteralSyntax expression => CheckFloatLiteral(expression, expectedType), MemberAccessSyntax expression => CheckMemberAccess(expression, expectedType), StructInitializerSyntax expression => CheckStructInitializer(expression, expectedType), - InterpretBuiltinSyntax expression => CheckExpression(expression.Target, expectedType) with { Type = ResolveType(expression.Type) }, - SizeBuiltinSyntax expression => new SizeBuiltinNode(node.Tokens, new NubIntType(false, 64), ResolveType(expression.Type)), - FloatToIntBuiltinSyntax expression => CheckFloatToInt(expression, expectedType), + InterpretSyntax expression => CheckExpression(expression.Target, expectedType) with { Type = ResolveType(expression.Type) }, + SizeSyntax expression => new SizeNode(node.Tokens, new NubIntType(false, 64), ResolveType(expression.Type)), + CastSyntax expression => CheckCast(expression, expectedType), _ => throw new ArgumentOutOfRangeException(nameof(node)) }; - switch (expectedType) + if (expectedType != null) { - // note(nub31): Implicit conversion of const array to unsized array - case NubArrayType when result.Type is NubConstArrayType constArrayType: - { - return result with { Type = new NubArrayType(constArrayType.ElementType) }; - } - // note(nub31): Implicit conversion of const array to slice - case NubSliceType when result.Type is NubConstArrayType constArrayType: - { - return new ConstArrayToSliceNode(result.Tokens, new NubSliceType(constArrayType.ElementType), result); - } - // note(nub31): Implicit conversion of int to larger int - case NubIntType expectedIntType when result.Type is NubIntType intType && expectedIntType.Width > intType.Width: - { - return new ConvertIntNode(result.Tokens, result, intType.Width, expectedIntType.Width, intType.Signed, expectedIntType.Signed); - } - // note(nub31): Implicit conversion of f32 to f64 - case NubFloatType expectedFloatType when result.Type is NubFloatType floatType && expectedFloatType.Width > floatType.Width: - { - return new ConvertFloatNode(result.Tokens, result, floatType.Width, expectedFloatType.Width); - } - // note(nub31): Implicit conversion of cstring to string - case NubStringType when result.Type is NubCStringType: - { - return new ConvertCStringToStringNode(result.Tokens, result); - } - // note(nub31): No implicit conversion was possible or the result value was already the correct type - default: + if (result.Type == expectedType) { return result; } + + if (IsCastAllowed(result.Type, expectedType)) + { + return new CastNode(result.Tokens, expectedType, result); + } } + + return result; } - // todo(nub31): Infer int type instead of explicit type syntax - private FloatToIntBuiltinNode CheckFloatToInt(FloatToIntBuiltinSyntax expression, NubType? _) + private ExpressionNode CheckCast(CastSyntax expression, NubType? expectedType) { - var value = CheckExpression(expression.Value); - if (value.Type is not NubFloatType sourceFloatType) + if (expectedType == null) { throw new TypeCheckerException(Diagnostic - .Error("Source type of float to int conversion must be an float") - .At(expression.Value) + .Error("Unable to infer target type of cast") + .At(expression) + .WithHelp("Specify target type where value is used") .Build()); } - var targetType = ResolveType(expression.Type); - if (targetType is not NubIntType targetIntType) + var value = CheckExpression(expression.Value, expectedType); + + if (value.Type == expectedType) + { + Diagnostics.Add(Diagnostic + .Warning("Target type of cast is same as the value. Cast is unnecessary") + .At(expression) + .Build()); + + return value; + } + + if (!IsCastAllowed(value.Type, expectedType, false)) { throw new TypeCheckerException(Diagnostic - .Error("Target type of float to int conversion must be an integer") - .At(expression.Type) + .Error($"Cannot cast from {value.Type} to {expectedType}") .Build()); } - return new FloatToIntBuiltinNode(expression.Tokens, targetIntType, value, sourceFloatType, targetIntType); + return new CastNode(expression.Tokens, expectedType, value); + } + + private static bool IsCastAllowed(NubType from, NubType to, bool strict = true) + { + // note(nub31): Implicit casts + switch (from) + { + case NubIntType fromInt when to is NubIntType toInt && fromInt.Width < toInt.Width: + case NubPointerType when to is NubPointerType { BaseType: NubVoidType }: + case NubConstArrayType constArrayType1 when to is NubArrayType arrayType && constArrayType1.ElementType == arrayType.ElementType: + case NubConstArrayType constArrayType3 when to is NubSliceType sliceType2 && constArrayType3.ElementType == sliceType2.ElementType: + case NubCStringType when to is NubStringType: + { + return true; + } + } + + if (!strict) + { + // note(nub31): Explicit casts + switch (from) + { + case NubIntType when to is NubIntType: + case NubIntType when to is NubFloatType: + case NubFloatType when to is NubIntType: + case NubFloatType when to is NubFloatType: + case NubPointerType when to is NubPointerType: + case NubPointerType when to is NubIntType: + case NubIntType when to is NubPointerType: + case NubCStringType when to is NubPointerType { BaseType: NubIntType { Width: 8 } }: + { + return true; + } + } + } + + + return false; } private AddressOfNode CheckAddressOf(AddressOfSyntax expression, NubType? expectedType) @@ -499,6 +527,13 @@ public sealed class TypeChecker } var right = CheckExpression(expression.Right, left.Type); + if (right.Type != left.Type) + { + throw new TypeCheckerException(Diagnostic + .Error($"Expected type {left.Type} from left side of binary expression, but got {right.Type}") + .At(expression.Right) + .Build()); + } return new BinaryExpressionNode(expression.Tokens, new NubBoolType(), left, op, right); } @@ -517,6 +552,13 @@ public sealed class TypeChecker } var right = CheckExpression(expression.Right, left.Type); + if (right.Type != left.Type) + { + throw new TypeCheckerException(Diagnostic + .Error($"Expected type {left.Type} from left side of binary expression, but got {right.Type}") + .At(expression.Right) + .Build()); + } return new BinaryExpressionNode(expression.Tokens, new NubBoolType(), left, op, right); } @@ -533,21 +575,36 @@ public sealed class TypeChecker } var right = CheckExpression(expression.Right, left.Type); + if (right.Type != left.Type) + { + throw new TypeCheckerException(Diagnostic + .Error($"Expected type {left.Type} from left side of binary expression, but got {right.Type}") + .At(expression.Right) + .Build()); + } return new BinaryExpressionNode(expression.Tokens, new NubBoolType(), left, op, right); } case BinaryOperatorSyntax.Plus: { var left = CheckExpression(expression.Left, expectedType); - if (left.Type is not NubIntType and not NubFloatType and not NubStringType and not NubCStringType) + if (left.Type is not NubIntType and not NubFloatType) { throw new TypeCheckerException(Diagnostic - .Error("The plus operator must be used with int, float, cstring or string types") + .Error("The plus operator must only be used with int and float types") .At(expression.Left) .Build()); } var right = CheckExpression(expression.Right, left.Type); + if (right.Type != left.Type) + { + throw new TypeCheckerException(Diagnostic + .Error($"Expected type {left.Type} from left side of binary expression, but got {right.Type}") + .At(expression.Right) + .Build()); + } + return new BinaryExpressionNode(expression.Tokens, left.Type, left, op, right); } case BinaryOperatorSyntax.Minus: @@ -565,6 +622,13 @@ public sealed class TypeChecker } var right = CheckExpression(expression.Right, left.Type); + if (right.Type != left.Type) + { + throw new TypeCheckerException(Diagnostic + .Error($"Expected type {left.Type} from left side of binary expression, but got {right.Type}") + .At(expression.Right) + .Build()); + } return new BinaryExpressionNode(expression.Tokens, left.Type, left, op, right); } @@ -584,6 +648,13 @@ public sealed class TypeChecker } var right = CheckExpression(expression.Right, left.Type); + if (right.Type != left.Type) + { + throw new TypeCheckerException(Diagnostic + .Error($"Expected type {left.Type} from left side of binary expression, but got {right.Type}") + .At(expression.Right) + .Build()); + } return new BinaryExpressionNode(expression.Tokens, left.Type, left, op, right); } diff --git a/compiler/NubLang/Generation/Generator.cs b/compiler/NubLang/Generation/Generator.cs index b57de89..4b1a5c8 100644 --- a/compiler/NubLang/Generation/Generator.cs +++ b/compiler/NubLang/Generation/Generator.cs @@ -339,19 +339,15 @@ public class Generator BoolLiteralNode boolLiteralNode => boolLiteralNode.Value ? "true" : "false", ConstArrayIndexAccessNode constArrayIndexAccessNode => EmitConstArrayIndexAccess(constArrayIndexAccessNode), ConstArrayInitializerNode constArrayInitializerNode => EmitConstArrayInitializer(constArrayInitializerNode), - ConstArrayToSliceNode constArrayToSliceNode => EmitConstArrayToSlice(constArrayToSliceNode), - ConvertCStringToStringNode convertCStringToStringNode => EmitConvertCStringToString(convertCStringToStringNode), - ConvertFloatNode convertFloatNode => EmitConvertFloat(convertFloatNode), - ConvertIntNode convertIntNode => EmitConvertInt(convertIntNode), CStringLiteralNode cStringLiteralNode => $"\"{cStringLiteralNode.Value}\"", DereferenceNode dereferenceNode => EmitDereference(dereferenceNode), Float32LiteralNode float32LiteralNode => EmitFloat32Literal(float32LiteralNode), Float64LiteralNode float64LiteralNode => EmitFloat64Literal(float64LiteralNode), - FloatToIntBuiltinNode floatToIntBuiltinNode => EmitFloatToIntBuiltin(floatToIntBuiltinNode), + CastNode castNode => EmitCast(castNode), FuncCallNode funcCallNode => EmitFuncCall(funcCallNode), FuncIdentifierNode funcIdentifierNode => FuncName(funcIdentifierNode.Module, funcIdentifierNode.Name, funcIdentifierNode.ExternSymbol), AddressOfNode addressOfNode => EmitAddressOf(addressOfNode), - SizeBuiltinNode sizeBuiltinNode => $"sizeof({CType.Create(sizeBuiltinNode.TargetType)})", + SizeNode sizeBuiltinNode => $"sizeof({CType.Create(sizeBuiltinNode.TargetType)})", SliceIndexAccessNode sliceIndexAccessNode => EmitSliceArrayIndexAccess(sliceIndexAccessNode), StringLiteralNode stringLiteralNode => EmitStringLiteral(stringLiteralNode), StructFieldAccessNode structFieldAccessNode => EmitStructFieldAccess(structFieldAccessNode), @@ -442,46 +438,6 @@ public class Generator return $"({CType.Create(arrayType.ElementType)}[{arrayType.Size}]){{{string.Join(", ", values)}}}"; } - private string EmitConstArrayToSlice(ConstArrayToSliceNode constArrayToSliceNode) - { - var arrayType = (NubConstArrayType)constArrayToSliceNode.Array.Type; - var array = EmitExpression(constArrayToSliceNode.Array); - return $"(slice){{.length = {arrayType.Size}, .data = (void*){array}}}"; - } - - private string EmitConvertCStringToString(ConvertCStringToStringNode convertCStringToStringNode) - { - var value = EmitExpression(convertCStringToStringNode.Value); - return $"(string){{.length = strlen({value}), .data = {value}}}"; - } - - private string EmitConvertFloat(ConvertFloatNode convertFloatNode) - { - var value = EmitExpression(convertFloatNode.Value); - var targetCast = convertFloatNode.TargetWidth switch - { - 32 => "f32", - 64 => "f64", - _ => throw new ArgumentOutOfRangeException() - }; - - return $"({targetCast}){value}"; - } - - private string EmitConvertInt(ConvertIntNode convertIntNode) - { - var value = EmitExpression(convertIntNode.Value); - var targetType = convertIntNode.TargetWidth switch - { - 8 => convertIntNode.TargetSignedness ? "int8_t" : "uint8_t", - 16 => convertIntNode.TargetSignedness ? "int16_t" : "uint16_t", - 32 => convertIntNode.TargetSignedness ? "int32_t" : "uint32_t", - 64 => convertIntNode.TargetSignedness ? "int64_t" : "uint64_t", - _ => throw new ArgumentOutOfRangeException() - }; - return $"({targetType}){value}"; - } - private string EmitDereference(DereferenceNode dereferenceNode) { var pointer = EmitExpression(dereferenceNode.Target); @@ -510,18 +466,22 @@ public class Generator return str; } - private string EmitFloatToIntBuiltin(FloatToIntBuiltinNode floatToIntBuiltinNode) + private string EmitCast(CastNode castNode) { - var value = EmitExpression(floatToIntBuiltinNode.Value); - var targetType = floatToIntBuiltinNode.TargetType.Width switch + var value = EmitExpression(castNode.Value); + + if (castNode is { Type: NubSliceType, Value.Type: NubConstArrayType arrayType }) { - 8 => floatToIntBuiltinNode.TargetType.Signed ? "int8_t" : "uint8_t", - 16 => floatToIntBuiltinNode.TargetType.Signed ? "int16_t" : "uint16_t", - 32 => floatToIntBuiltinNode.TargetType.Signed ? "int32_t" : "uint32_t", - 64 => floatToIntBuiltinNode.TargetType.Signed ? "int64_t" : "uint64_t", - _ => throw new ArgumentOutOfRangeException() - }; - return $"({targetType}){value}"; + return $"(slice){{.length = {arrayType.Size}, .data = (void*){value}}}"; + } + + // todo(nub31): Stop depending on libc + if (castNode is { Type: NubCStringType, Value.Type: NubStringType }) + { + return $"(string){{.length = strlen({value}), .data = {value}}}"; + } + + return $"({CType.Create(castNode.Type)}){value}"; } private string EmitFuncCall(FuncCallNode funcCallNode) diff --git a/compiler/NubLang/Syntax/Parser.cs b/compiler/NubLang/Syntax/Parser.cs index fba5bf9..b122682 100644 --- a/compiler/NubLang/Syntax/Parser.cs +++ b/compiler/NubLang/Syntax/Parser.cs @@ -508,7 +508,7 @@ public sealed class Parser { var type = ParseType(); ExpectSymbol(Symbol.CloseParen); - return new SizeBuiltinSyntax(GetTokens(startIndex), type); + return new SizeSyntax(GetTokens(startIndex), type); } case "interpret": { @@ -516,15 +516,13 @@ public sealed class Parser ExpectSymbol(Symbol.Comma); var expression = ParseExpression(); ExpectSymbol(Symbol.CloseParen); - return new InterpretBuiltinSyntax(GetTokens(startIndex), type, expression); + return new InterpretSyntax(GetTokens(startIndex), type, expression); } - case "floatToInt": + case "cast": { - var type = ParseType(); - ExpectSymbol(Symbol.Comma); var expression = ParseExpression(); ExpectSymbol(Symbol.CloseParen); - return new FloatToIntBuiltinSyntax(GetTokens(startIndex), type, expression); + return new CastSyntax(GetTokens(startIndex), expression); } default: { diff --git a/compiler/NubLang/Syntax/Syntax.cs b/compiler/NubLang/Syntax/Syntax.cs index 1f4d48b..5e1fbae 100644 --- a/compiler/NubLang/Syntax/Syntax.cs +++ b/compiler/NubLang/Syntax/Syntax.cs @@ -112,11 +112,11 @@ public record StructInitializerSyntax(List Tokens, TypeSyntax? StructType public record DereferenceSyntax(List Tokens, ExpressionSyntax Target) : ExpressionSyntax(Tokens); -public record SizeBuiltinSyntax(List Tokens, TypeSyntax Type) : ExpressionSyntax(Tokens); +public record SizeSyntax(List Tokens, TypeSyntax Type) : ExpressionSyntax(Tokens); -public record InterpretBuiltinSyntax(List Tokens, TypeSyntax Type, ExpressionSyntax Target) : ExpressionSyntax(Tokens); +public record InterpretSyntax(List Tokens, TypeSyntax Type, ExpressionSyntax Target) : ExpressionSyntax(Tokens); -public record FloatToIntBuiltinSyntax(List Tokens, TypeSyntax Type, ExpressionSyntax Value) : ExpressionSyntax(Tokens); +public record CastSyntax(List Tokens, ExpressionSyntax Value) : ExpressionSyntax(Tokens); #endregion diff --git a/examples/raylib/main.nub b/examples/raylib/main.nub index 2c3dfdb..f0b05bb 100644 --- a/examples/raylib/main.nub +++ b/examples/raylib/main.nub @@ -37,8 +37,8 @@ extern "main" func main(argc: i64, argv: [?]cstring): i64 direction.y = -1 } - x = x + @floatToInt(i32, direction.x * speed * raylib::GetFrameTime()) - y = y + @floatToInt(i32, direction.y * speed * raylib::GetFrameTime()) + x = x + @cast(direction.x * speed * raylib::GetFrameTime()) + y = y + @cast(direction.y * speed * raylib::GetFrameTime()) raylib::BeginDrawing() {