diff --git a/compiler/Generator.cs b/compiler/Generator.cs index 4e8a539..3fec29c 100644 --- a/compiler/Generator.cs +++ b/compiler/Generator.cs @@ -301,13 +301,13 @@ public class Generator { foreach (var @case in statement.Cases) { - var tag = enumInfo.Variants.ToList().FindIndex(x => x.Name == @case.Type.Ident); + var tag = enumInfo.Variants.ToList().FindIndex(x => x.Name == @case.Variant.Ident); writer.WriteLine($"case {tag}:"); writer.WriteLine("{"); using (writer.Indent()) { - writer.WriteLine($"auto {@case.VariableName.Ident} = {target}.{@case.Type.Ident};"); + writer.WriteLine($"auto {@case.VariableName.Ident} = {target}.{@case.Variant.Ident};"); EmitStatement(@case.Body); } writer.WriteLine("}"); diff --git a/compiler/TypeChecker.cs b/compiler/TypeChecker.cs index 571ec85..36b6a4f 100644 --- a/compiler/TypeChecker.cs +++ b/compiler/TypeChecker.cs @@ -191,9 +191,22 @@ public class TypeChecker if (target.Type is not NubTypeEnum enumType) throw BasicError("A match statement can only be used on enum types", target); + if (!moduleGraph.TryResolveType(enumType.Module, enumType.Name, enumType.Module == currentModule, out var info)) + throw BasicError($"Type '{enumType}' not found", target); + + if (info is not Module.TypeInfoEnum enumInfo) + throw BasicError($"Type '{enumType}' is not an enum", target); + + var uncoveredCases = enumInfo.Variants.Select(x => x.Name).ToList(); + var cases = new List(); foreach (var @case in statement.Cases) { + if (!enumInfo.Variants.Any(x => x.Name == @case.Variant.Ident)) + throw BasicError($"Enum type'{enumType}' does not have a variant named '{@case.Variant.Ident}'", @case.Variant); + + uncoveredCases.Remove(@case.Variant.Ident); + using (scope.EnterScope()) { scope.DeclareIdentifier(@case.VariableName.Ident, NubTypeEnumVariant.Get(NubTypeEnum.Get(enumType.Module, enumType.Name), @case.Variant.Ident)); @@ -202,6 +215,9 @@ public class TypeChecker } } + if (uncoveredCases.Any()) + throw BasicError($"Match statement does not cover the following cases: {string.Join(", ", uncoveredCases)}", statement); + return new TypedNodeStatementMatch(statement.Tokens, target, cases); } @@ -791,7 +807,7 @@ public class TypedNodeStatementMatch(List tokens, TypedNodeExpression tar public class Case(List tokens, TokenIdent type, TokenIdent variableName, TypedNodeStatement body) : Node(tokens) { - public TokenIdent Type { get; } = type; + public TokenIdent Variant { get; } = type; public TokenIdent VariableName { get; } = variableName; public TypedNodeStatement Body { get; } = body; } diff --git a/examples/math/math.nub b/examples/math/math.nub index 062e4d6..0d2b1e1 100644 --- a/examples/math/math.nub +++ b/examples/math/math.nub @@ -1,53 +1,34 @@ module math -export struct vec2 { +export struct Human { + name: string + age: i32 +} + +export struct Pos { x: i32 y: i32 } -export struct vec3 { - x: i32 - y: i32 - z: i32 -} - -export struct color { - r: i32 - g: i32 - b: i32 - a: i32 -} - -export struct example { - b: color -} - -export enum message { - quit: {} - move: color +export enum Message { + Quit: {} + Move: Pos } export func add(a: i32 b: i32): i32 { - let message: message = enum message::move { - r = 23 - g = 46 - b = 56 + let msg: Message = enum Message::Move { + x = 10 + y = 10 } - match message { - quit q {} - move m { - m.r = 23 - m.g = 23 - m.b = 23 + match msg { + Quit q { + // quit + } + Move m { + // move } - } - - let color: color = { - r = 23 - g = 23 - b = 23 } return add_internal(a b)