diff --git a/compiler/Generator.cs b/compiler/Generator.cs index 9e95965..372f258 100644 --- a/compiler/Generator.cs +++ b/compiler/Generator.cs @@ -489,7 +489,7 @@ public class Generator TypedNodeExpressionIntLiteral expression => expression.Value.Value.ToString(), TypedNodeExpressionStringLiteral expression => EmitExpressionStringLiteral(expression), TypedNodeExpressionStructLiteral expression => EmitExpressionStructLiteral(expression), - TypedNodeExpressionNewNamedType expression => EmitNodeExpressionNewNamedType(expression), + TypedNodeExpressionEnumLiteral expression => EmitExpressionEnumLiteral(expression), TypedNodeExpressionStructMemberAccess expression => EmitExpressionMemberAccess(expression), TypedNodeExpressionStringLength expression => EmitExpressionStringLength(expression), TypedNodeExpressionStringPointer expression => EmitExpressionStringPointer(expression), @@ -578,35 +578,34 @@ public class Generator return name; } - private string EmitNodeExpressionNewNamedType(TypedNodeExpressionNewNamedType expression) + private string EmitExpressionEnumLiteral(TypedNodeExpressionEnumLiteral expression) { - switch (expression.Type) + var name = TmpName(); + scopes.Peek().DeconstructableNames.Add((name, expression.Type)); + + var enumVariantType = (NubTypeEnumVariant)expression.Type; + + if (!moduleGraph.TryResolveType(enumVariantType.EnumType.Module, enumVariantType.EnumType.Name, true, out var info)) + throw new UnreachableException(); + + var enumInfo = (Module.TypeInfoEnum)info; + var tag = enumInfo.Variants.ToList().FindIndex(x => x.Name == enumVariantType.Variant); + + string? value = null; + if (expression.Value != null) { - case NubTypeEnumVariant enumVariantType: - { - var name = TmpName(); - scopes.Peek().DeconstructableNames.Add((name, expression.Type)); - - if (!moduleGraph.TryResolveType(enumVariantType.EnumType.Module, enumVariantType.EnumType.Name, true, out var info)) - throw new UnreachableException(); - - var enumInfo = (Module.TypeInfoEnum)info; - var tag = enumInfo.Variants.ToList().FindIndex(x => x.Name == enumVariantType.Variant); - - var value = EmitExpression(expression.Value); - EmitCopyConstructor(value, expression.Value.Type); - - writer.WriteLine($"{CType(expression.Type, name)} = ({CType(expression.Type)}){{ .tag = {tag}, .{enumVariantType.Variant} = {value} }};"); - - return name; - } - case NubTypeStruct structType: - { - return EmitExpression(expression.Value); - } - default: - throw new UnreachableException(); + value = EmitExpression(expression.Value); + EmitCopyConstructor(value, expression.Value.Type); } + + writer.Write($"{CType(expression.Type, name)} = ({CType(expression.Type)}){{ .tag = {tag}"); + + if (value != null) + writer.WriteLine($", .{enumVariantType.Variant} = {value} }};"); + else + writer.WriteLine(" }};"); + + return name; } private string EmitExpressionMemberAccess(TypedNodeExpressionStructMemberAccess expression) diff --git a/compiler/Parser.cs b/compiler/Parser.cs index 2830127..1dc3643 100644 --- a/compiler/Parser.cs +++ b/compiler/Parser.cs @@ -375,7 +375,7 @@ public class Parser initializers.Add(new NodeExpressionStructLiteral.Initializer(TokensFrom(initializerStartIndex), fieldName, fieldValue)); } - expr = new NodeExpressionStructLiteral(TokensFrom(startIndex), initializers); + expr = new NodeExpressionStructLiteral(TokensFrom(startIndex), null, initializers); } else if (TryExpectSymbol(Symbol.Bang)) { @@ -408,8 +408,32 @@ public class Parser else if (TryExpectKeyword(Keyword.New)) { var type = ParseType(); - var value = ParseExpression(); - return new NodeExpressionNewNamedType(TokensFrom(startIndex), type, value); + + if (TryExpectSymbol(Symbol.OpenParen)) + { + var value = ParseExpression(); + ExpectSymbol(Symbol.CloseParen); + + expr = new NodeExpressionEnumLiteral(TokensFrom(startIndex), type, value); + } + else if (TryExpectSymbol(Symbol.OpenCurly)) + { + var initializers = new List(); + while (!TryExpectSymbol(Symbol.CloseCurly)) + { + var initializerStartIndex = startIndex; + var fieldName = ExpectIdent(); + ExpectSymbol(Symbol.Equal); + var fieldValue = ParseExpression(); + initializers.Add(new NodeExpressionStructLiteral.Initializer(TokensFrom(initializerStartIndex), fieldName, fieldValue)); + } + + expr = new NodeExpressionStructLiteral(TokensFrom(startIndex), null, initializers); + } + else + { + expr = new NodeExpressionEnumLiteral(TokensFrom(startIndex), type, null); + } } else { @@ -869,14 +893,9 @@ public class NodeExpressionBoolLiteral(List tokens, TokenBoolLiteral valu public TokenBoolLiteral Value { get; } = value; } -public class NodeExpressionNewNamedType(List tokens, NodeType type, NodeExpression value) : NodeExpression(tokens) -{ - public NodeType Type { get; } = type; - public NodeExpression Value { get; } = value; -} - -public class NodeExpressionStructLiteral(List tokens, List initializers) : NodeExpression(tokens) +public class NodeExpressionStructLiteral(List tokens, NodeType? type, List initializers) : NodeExpression(tokens) { + public NodeType? Type { get; } = type; public List Initializers { get; } = initializers; public class Initializer(List tokens, TokenIdent name, NodeExpression value) : Node(tokens) @@ -886,6 +905,12 @@ public class NodeExpressionStructLiteral(List tokens, List tokens, NodeType type, NodeExpression? value) : NodeExpression(tokens) +{ + public NodeType Type { get; } = type; + public NodeExpression? Value { get; } = value; +} + public class NodeExpressionMemberAccess(List tokens, NodeExpression target, TokenIdent name) : NodeExpression(tokens) { public NodeExpression Target { get; } = target; diff --git a/compiler/TypeChecker.cs b/compiler/TypeChecker.cs index 3fbf68b..384f709 100644 --- a/compiler/TypeChecker.cs +++ b/compiler/TypeChecker.cs @@ -249,7 +249,7 @@ public class TypeChecker NodeExpressionFuncCall expression => CheckExpressionFuncCall(expression, expectedType), NodeExpressionStringLiteral expression => CheckExpressionStringLiteral(expression, expectedType), NodeExpressionStructLiteral expression => CheckExpressionStructLiteral(expression, expectedType), - NodeExpressionNewNamedType expression => CheckExpressionNewNamedType(expression, expectedType), + NodeExpressionEnumLiteral expression => CheckExpressionEnumLiteral(expression, expectedType), _ => throw new ArgumentOutOfRangeException(nameof(node)) }; } @@ -526,7 +526,35 @@ public class TypeChecker private TypedNodeExpressionStructLiteral CheckExpressionStructLiteral(NodeExpressionStructLiteral expression, NubType? expectedType) { - if (expectedType is NubTypeStruct structType) + if (expression.Type != null) + { + var type = ResolveType(expression.Type); + if (type is not NubTypeStruct structType) + throw BasicError("Type of struct literal is not a struct", expression); + + if (!moduleGraph.TryResolveType(structType.Module, structType.Name, structType.Module == currentModule, out var info)) + throw BasicError($"Type '{structType}' struct literal not found", expression); + + if (info is not Module.TypeInfoStruct structInfo) + throw BasicError($"Type '{structType}' is not a struct", expression.Type); + + var initializers = new List(); + foreach (var initializer in expression.Initializers) + { + var field = structInfo.Fields.FirstOrDefault(x => x.Name == initializer.Name.Ident); + if (field == null) + throw BasicError($"Field '{initializer.Name.Ident}' does not exist on struct '{structType.Module}::{structType.Name}'", initializer.Name); + + var value = CheckExpression(initializer.Value, field.Type); + if (!value.Type.IsAssignableTo(field.Type)) + throw BasicError($"Type of assignment ({value.Type}) does not match expected type of field '{field.Name}' ({field.Type})", initializer.Name); + + initializers.Add(new TypedNodeExpressionStructLiteral.Initializer(initializer.Tokens, initializer.Name, value)); + } + + return new TypedNodeExpressionStructLiteral(expression.Tokens, structType, initializers); + } + else if (expectedType is NubTypeStruct structType) { if (!moduleGraph.TryResolveType(structType.Module, structType.Name, structType.Module == currentModule, out var info)) throw BasicError($"Type '{structType}' struct literal not found", expression); @@ -583,36 +611,31 @@ public class TypeChecker } } - private TypedNodeExpressionNewNamedType CheckExpressionNewNamedType(NodeExpressionNewNamedType expression, NubType? expectedType) + private TypedNodeExpressionEnumLiteral CheckExpressionEnumLiteral(NodeExpressionEnumLiteral expression, NubType? expectedType) { var type = ResolveType(expression.Type); - switch (type) - { - case NubTypeStruct structType: - { - var value = CheckExpression(expression.Value, structType); - return new TypedNodeExpressionNewNamedType(expression.Tokens, structType, value); - } - case NubTypeEnumVariant enumVariantType: - { - if (!moduleGraph.TryResolveType(enumVariantType.EnumType.Module, enumVariantType.EnumType.Name, enumVariantType.EnumType.Module == currentModule, out var info)) - throw BasicError($"Type '{enumVariantType.EnumType}' not found", expression.Type); + if (type is not NubTypeEnumVariant variantType) + throw BasicError("Expected enum variant type", expression.Type); - if (info is not Module.TypeInfoEnum enumInfo) - throw BasicError($"Type '{enumVariantType.EnumType}' is not an enum", expression.Type); + if (!moduleGraph.TryResolveType(variantType.EnumType.Module, variantType.EnumType.Name, variantType.EnumType.Module == currentModule, out var info)) + throw BasicError($"Type '{variantType.EnumType}' not found", expression.Type); - var variant = enumInfo.Variants.FirstOrDefault(x => x.Name == enumVariantType.Variant); - if (variant == null) - throw BasicError($"Enum type '{enumVariantType.EnumType}' does not have a variant named '{enumVariantType.Variant}'", expression.Type); + if (info is not Module.TypeInfoEnum enumInfo) + throw BasicError($"Type '{variantType.EnumType}' is not an enum", expression.Type); - var value = CheckExpression(expression.Value, variant.Type); - return new TypedNodeExpressionNewNamedType(expression.Tokens, enumVariantType, value); - } - default: - { - throw BasicError($"'{type}' is not a valid type for the new operator", expression); - } - } + var variant = enumInfo.Variants.FirstOrDefault(x => x.Name == variantType.Variant); + if (variant == null) + throw BasicError($"Enum '{variantType.EnumType}' does not have a variant named '{variantType.Variant}'", expression.Type); + + if (expression.Value == null && variant.Type is not null) + throw BasicError($"Enum variant '{variantType.EnumType}' expects a value of type '{variant.Type}'", expression.Type); + + if (expression.Value != null && variant.Type is null) + throw BasicError($"Enum variant '{variantType.EnumType}' does not expect any data", expression.Value); + + var value = expression.Value == null ? null : CheckExpression(expression.Value, variant.Type); + + return new TypedNodeExpressionEnumLiteral(expression.Tokens, type, value); } private NubType ResolveType(NodeType node) @@ -892,9 +915,9 @@ public class TypedNodeExpressionStructLiteral(List tokens, NubType type, } } -public class TypedNodeExpressionNewNamedType(List tokens, NubType type, TypedNodeExpression value) : TypedNodeExpression(tokens, type) +public class TypedNodeExpressionEnumLiteral(List tokens, NubType type, TypedNodeExpression? value) : TypedNodeExpression(tokens, type) { - public TypedNodeExpression Value { get; } = value; + public TypedNodeExpression? Value { get; } = value; } public class TypedNodeExpressionStructMemberAccess(List tokens, NubType type, TypedNodeExpression target, TokenIdent name) : TypedNodeExpression(tokens, type) diff --git a/examples/program/main.nub b/examples/program/main.nub index d9b8b1e..f78f57c 100644 --- a/examples/program/main.nub +++ b/examples/program/main.nub @@ -14,16 +14,22 @@ func main(): i32 { core::println("Hello, world!") core::println("Hello" + "World") - let message: Message = new Message::Say "test" + let message = getMessage() match message { - Quit { + Quit + { core::println("quit") } - Say message { + Say message + { core::println(message) } } return 0 +} + +func getMessage(): Message { + return new Message::Say("test") } \ No newline at end of file