From 3323d760e8336fd38b9d68762530899246bfe0b6 Mon Sep 17 00:00:00 2001 From: nub31 Date: Thu, 26 Feb 2026 20:44:29 +0100 Subject: [PATCH] infer enum types --- compiler/Parser.cs | 15 ++++----------- compiler/TypeChecker.cs | 18 +++++++++++++++--- examples/math/math.nub | 6 +++++- 3 files changed, 24 insertions(+), 15 deletions(-) diff --git a/compiler/Parser.cs b/compiler/Parser.cs index 586081e..90defac 100644 --- a/compiler/Parser.cs +++ b/compiler/Parser.cs @@ -381,15 +381,10 @@ public class Parser } else if (TryExpectKeyword(Keyword.Enum)) { - var module = ExpectIdent(); - ExpectSymbol(Symbol.ColonColon); - var enumName = ExpectIdent(); - ExpectSymbol(Symbol.ColonColon); - var variantName = ExpectIdent(); - + var type = ParseType(); var value = ParseExpression(); - expr = new NodeExpressionEnumLiteral(TokensFrom(startIndex), module, enumName, variantName, value); + expr = new NodeExpressionEnumLiteral(TokensFrom(startIndex), type, value); } else { @@ -846,11 +841,9 @@ public class NodeExpressionStructLiteral(List tokens, NodeType? type, Lis } } -public class NodeExpressionEnumLiteral(List tokens, TokenIdent module, TokenIdent enumName, TokenIdent variantName, NodeExpression value) : NodeExpression(tokens) +public class NodeExpressionEnumLiteral(List tokens, NodeType type, NodeExpression value) : NodeExpression(tokens) { - public TokenIdent Module { get; } = module; - public TokenIdent EnumName { get; } = enumName; - public TokenIdent VariantName { get; } = variantName; + public NodeType Type { get; } = type; public NodeExpression Value { get; } = value; } diff --git a/compiler/TypeChecker.cs b/compiler/TypeChecker.cs index 97c1532..571ec85 100644 --- a/compiler/TypeChecker.cs +++ b/compiler/TypeChecker.cs @@ -533,9 +533,21 @@ public class TypeChecker private TypedNodeExpressionEnumLiteral CheckExpressionEnumLiteral(NodeExpressionEnumLiteral expression, NubType? expectedType) { - // todo(nub31): Infer type of enum variant - var type = NubTypeEnumVariant.Get(NubTypeEnum.Get(expression.Module.Ident, expression.EnumName.Ident), expression.VariantName.Ident); - var value = CheckExpression(expression.Value, null); + var type = ResolveType(expression.Type); + if (type is not NubTypeEnumVariant variantType) + throw BasicError("Expected enum variant type", expression.Type); + + if (!moduleGraph.TryResolveType(variantType.EnumType.Module, variantType.EnumType.Name, variantType.EnumType.Module == currentModule, out var info)) + throw BasicError($"Type '{variantType.EnumType}' not found", expression.Type); + + if (info is not Module.TypeInfoEnum enumInfo) + throw BasicError($"Type '{variantType.EnumType}' is not an enum", expression.Type); + + var variant = enumInfo.Variants.FirstOrDefault(x => x.Name == variantType.Variant); + if (variant == null) + throw BasicError($"Enum '{variantType.EnumType}' does not have a variant named '{variantType.Variant}'", expression.Type); + + var value = CheckExpression(expression.Value, variant.Type); return new TypedNodeExpressionEnumLiteral(expression.Tokens, type, value); } diff --git a/examples/math/math.nub b/examples/math/math.nub index 1ae7fb2..062e4d6 100644 --- a/examples/math/math.nub +++ b/examples/math/math.nub @@ -29,7 +29,11 @@ export enum message { export func add(a: i32 b: i32): i32 { - let message: message = enum math::message::move struct color { r = 23 g = 46 b = 56 } + let message: message = enum message::move { + r = 23 + g = 46 + b = 56 + } match message { quit q {}