diff --git a/compiler/Generator.cs b/compiler/Generator.cs index 92c35dd..d6b0668 100644 --- a/compiler/Generator.cs +++ b/compiler/Generator.cs @@ -399,13 +399,13 @@ public class Generator private string EmitExpressionEnumLiteral(TypedNodeExpressionEnumLiteral expression) { - var enumType = (NubTypeEnum)expression.Type; + var enumVariantType = (NubTypeEnumVariant)expression.Type; - if (!moduleGraph.TryResolveType(enumType.Module, enumType.Name, true, out var info)) + if (!moduleGraph.TryResolveType(enumVariantType.EnumType.Module, enumVariantType.EnumType.Name, true, out var info)) throw new UnreachableException(); var enumInfo = (Module.TypeInfoEnum)info; - var tag = enumInfo.Variants.ToList().FindIndex(x => x.Name == expression.EnumVariant); + var tag = enumInfo.Variants.ToList().FindIndex(x => x.Name == enumVariantType.Variant); var initializerValues = new Dictionary(); @@ -417,7 +417,7 @@ public class Generator var initializerStrings = initializerValues.Select(x => $".{x.Key} = {x.Value}"); - return $"({CType(expression.Type)}){{ .tag = {tag}, .{expression.EnumVariant} = {{ {string.Join(", ", initializerStrings)} }} }}"; + return $"({CType(expression.Type)}){{ .tag = {tag}, .{enumVariantType.Variant} = {{ {string.Join(", ", initializerStrings)} }} }}"; } private string EmitExpressionMemberAccess(TypedNodeExpressionMemberAccess expression) @@ -449,6 +449,7 @@ public class Generator NubTypeBool => "bool" + (varName != null ? $" {varName}" : ""), NubTypeStruct type => $"struct {NameMangler.Mangle(type.Module, type.Name, type)}" + (varName != null ? $" {varName}" : ""), NubTypeEnum type => $"struct {NameMangler.Mangle(type.Module, type.Name, type)}" + (varName != null ? $" {varName}" : ""), + NubTypeEnumVariant type => CType(type.EnumType, varName), NubTypeSInt type => $"int{type.Width}_t" + (varName != null ? $" {varName}" : ""), NubTypeUInt type => $"uint{type.Width}_t" + (varName != null ? $" {varName}" : ""), NubTypePointer type => CType(type.To) + (varName != null ? $" *{varName}" : "*"), diff --git a/compiler/NubType.cs b/compiler/NubType.cs index 962868b..46a7040 100644 --- a/compiler/NubType.cs +++ b/compiler/NubType.cs @@ -8,6 +8,21 @@ namespace Compiler; public abstract class NubType { public abstract override string ToString(); + + [Obsolete("Use IsAssignableTo instead of ==", error: true)] + public static bool operator ==(NubType? a, NubType? b) => throw new InvalidOperationException("Use IsAssignableTo"); + + [Obsolete("Use IsAssignableTo instead of ==", error: true)] + public static bool operator !=(NubType? a, NubType? b) => throw new InvalidOperationException("Use IsAssignableTo"); + + public bool IsAssignableTo(NubType target) + { + return (this, target) switch + { + (NubTypeEnumVariant variant, NubTypeEnum targetEnum) => ReferenceEquals(variant.EnumType, targetEnum), + _ => ReferenceEquals(this, target), + }; + } } public class NubTypeVoid : NubType @@ -135,6 +150,30 @@ public class NubTypeEnum : NubType public override string ToString() => $"enum {Module}::{Name}"; } +public class NubTypeEnumVariant : NubType +{ + private static readonly Dictionary<(NubTypeEnum EnumType, string Variant), NubTypeEnumVariant> Cache = new(); + + public static NubTypeEnumVariant Get(NubTypeEnum enumType, string variant) + { + if (!Cache.TryGetValue((enumType, variant), out var variantType)) + Cache[(enumType, variant)] = variantType = new NubTypeEnumVariant(enumType, variant); + + return variantType; + } + + private NubTypeEnumVariant(NubTypeEnum enumType, string variant) + { + EnumType = enumType; + Variant = variant; + } + + public NubTypeEnum EnumType { get; } + public string Variant { get; } + + public override string ToString() => $"{EnumType}.{Variant}"; +} + public class NubTypePointer : NubType { private static readonly Dictionary Cache = new(); @@ -245,11 +284,21 @@ public class TypeEncoder sb.Append(')'); break; - case NubTypeEnum st: - sb.Append("E("); - sb.Append(st.Module); + case NubTypeEnum e: + sb.Append("EN("); + sb.Append(e.Module); sb.Append(':'); - sb.Append(st.Name); + sb.Append(e.Name); + sb.Append(')'); + break; + + case NubTypeEnumVariant ev: + sb.Append("EV("); + sb.Append(ev.EnumType.Module); + sb.Append(':'); + sb.Append(ev.EnumType.Name); + sb.Append(':'); + sb.Append(ev.Variant); sb.Append(')'); break; @@ -364,23 +413,49 @@ public class TypeDecoder return NubTypeStruct.Get(module, name); } - private NubTypeEnum DecodeEnum() + private NubType DecodeEnum() { var sb = new StringBuilder(); - Expect('('); - while (!TryExpect(':')) - sb.Append(Consume()); + if (TryExpect('V')) + { + Expect('('); + while (!TryExpect(':')) + sb.Append(Consume()); - var module = sb.ToString(); - sb.Clear(); + var module = sb.ToString(); + sb.Clear(); - while (!TryExpect(')')) - sb.Append(Consume()); + while (!TryExpect(':')) + sb.Append(Consume()); - var name = sb.ToString(); + var name = sb.ToString(); - return NubTypeEnum.Get(module, name); + while (!TryExpect(')')) + sb.Append(Consume()); + + var variant = sb.ToString(); + + return NubTypeEnumVariant.Get(NubTypeEnum.Get(module, name), variant); + } + else if (TryExpect('N')) + { + Expect('('); + while (!TryExpect(':')) + sb.Append(Consume()); + + var module = sb.ToString(); + sb.Clear(); + + while (!TryExpect(')')) + sb.Append(Consume()); + + var name = sb.ToString(); + + return NubTypeEnum.Get(module, name); + } + + throw new Exception($"Expected 'V' or 'N'"); } private bool TryPeek(out char c) diff --git a/compiler/Parser.cs b/compiler/Parser.cs index e46303f..253c2d5 100644 --- a/compiler/Parser.cs +++ b/compiler/Parser.cs @@ -802,7 +802,7 @@ public class NodeStatementMatch(List tokens, NodeExpression target, List< public class Case(List tokens, TokenIdent type, TokenIdent variableName, NodeStatement body) : Node(tokens) { - public TokenIdent Type { get; } = type; + public TokenIdent Variant { get; } = type; public TokenIdent VariableName { get; } = variableName; public NodeStatement Body { get; } = body; } diff --git a/compiler/TypeChecker.cs b/compiler/TypeChecker.cs index 75e94ca..f48dc42 100644 --- a/compiler/TypeChecker.cs +++ b/compiler/TypeChecker.cs @@ -47,7 +47,7 @@ public class TypeChecker continue; } - scope.DeclareIdentifier(parameter.Name.Ident, parameterType, null); + scope.DeclareIdentifier(parameter.Name.Ident, parameterType); parameters.Add(new TypedNodeDefinitionFunc.Param(parameter.Tokens, parameter.Name, parameterType)); } @@ -69,7 +69,7 @@ public class TypeChecker diagnostics.Add(e.Diagnostic); } - if (body == null || returnType == null || invalidParameter) + if (body == null || returnType is null || invalidParameter) return null; return new TypedNodeDefinitionFunc(function.Tokens, moduleName, function.Name, parameters, body, returnType); @@ -129,10 +129,10 @@ public class TypeChecker var type = ResolveType(statement.Type); var value = CheckExpression(statement.Value); - if (type != value.Type) + if (!value.Type.IsAssignableTo(type)) throw new CompileException(Diagnostic.Error("Type of variable does match type of assigned value").At(fileName, value).Build()); - scope.DeclareIdentifier(statement.Name.Ident, type, value.EnumVariant); + scope.DeclareIdentifier(statement.Name.Ident, type); return new TypedNodeStatementVariableDeclaration(statement.Tokens, statement.Name, type, value); } @@ -152,9 +152,9 @@ public class TypeChecker { using (scope.EnterScope()) { - scope.DeclareIdentifier(@case.VariableName.Ident, NubTypeEnum.Get(enumType.Module, enumType.Name), @case.Type.Ident); + scope.DeclareIdentifier(@case.VariableName.Ident, NubTypeEnumVariant.Get(NubTypeEnum.Get(enumType.Module, enumType.Name), @case.Variant.Ident)); var body = CheckStatement(@case.Body); - cases.Add(new TypedNodeStatementMatch.Case(@case.Tokens, @case.Type, @case.VariableName, body)); + cases.Add(new TypedNodeStatementMatch.Case(@case.Tokens, @case.Variant, @case.VariableName, body)); } } @@ -321,13 +321,10 @@ public class TypeChecker private TypedNodeExpressionLocalIdent CheckExpressionIdent(NodeExpressionLocalIdent expression) { var type = scope.GetIdentifierType(expression.Value.Ident); - if (!type.HasValue) + if (type is null) throw new CompileException(Diagnostic.Error($"Identifier '{expression.Value.Ident}' is not declared").At(fileName, expression.Value).Build()); - return new TypedNodeExpressionLocalIdent(expression.Tokens, type.Value.Type, expression.Value) - { - EnumVariant = type.Value.EnumVariant - }; + return new TypedNodeExpressionLocalIdent(expression.Tokens, type, expression.Value); } private TypedNodeExpressionModuleIdent CheckExpressionModuleIdent(NodeExpressionModuleIdent expression) @@ -371,23 +368,20 @@ public class TypeChecker return new TypedNodeExpressionMemberAccess(expression.Tokens, field.Type, target, expression.Name); } - case NubTypeEnum enumType: + case NubTypeEnumVariant enumVariantType: { - if (target.EnumVariant == null) - throw new CompileException(Diagnostic.Error($"Cannot access member '{expression.Name.Ident}' on enum '{enumType}' without knowing the variant").At(fileName, target).Build()); + if (!moduleGraph.TryResolveModule(enumVariantType.EnumType.Module, out var module)) + throw new CompileException(Diagnostic.Error($"Module '{enumVariantType.EnumType.Module}' not found").At(fileName, expression.Target).Build()); - if (!moduleGraph.TryResolveModule(enumType.Module, out var module)) - throw new CompileException(Diagnostic.Error($"Module '{enumType.Module}' not found").At(fileName, expression.Target).Build()); - - if (!module.TryResolveType(enumType.Name, moduleName == enumType.Module, out var typeDef)) - throw new CompileException(Diagnostic.Error($"Type '{enumType.Name}' not found in module '{enumType.Module}'").At(fileName, expression.Target).Build()); + if (!module.TryResolveType(enumVariantType.EnumType.Name, moduleName == enumVariantType.EnumType.Module, out var typeDef)) + throw new CompileException(Diagnostic.Error($"Type '{enumVariantType.EnumType.Name}' not found in module '{enumVariantType.EnumType.Module}'").At(fileName, expression.Target).Build()); if (typeDef is not Module.TypeInfoEnum enumDef) - throw new CompileException(Diagnostic.Error($"Type '{enumType.Module}::{enumType.Name}' is not an enum").At(fileName, expression.Target).Build()); + throw new CompileException(Diagnostic.Error($"Type '{enumVariantType.EnumType.Module}::{enumVariantType.EnumType.Name}' is not an enum").At(fileName, expression.Target).Build()); - var variant = enumDef.Variants.FirstOrDefault(x => x.Name == target.EnumVariant); + var variant = enumDef.Variants.FirstOrDefault(x => x.Name == enumVariantType.Variant); if (variant == null) - throw new CompileException(Diagnostic.Error($"Type '{target.Type}' does not have a variant '{target.EnumVariant}'").At(fileName, expression.Target).Build()); + throw new CompileException(Diagnostic.Error($"Type '{target.Type}' does not have a variant named '{enumVariantType.Variant}'").At(fileName, expression.Target).Build()); var field = variant.Fields.FirstOrDefault(x => x.Name == expression.Name.Ident); if (field == null) @@ -437,7 +431,7 @@ public class TypeChecker throw new CompileException(Diagnostic.Error($"Field '{initializer.Name.Ident}' does not exist on struct '{expression.Module.Ident}::{expression.Name.Ident}'").At(fileName, initializer.Name).Build()); var value = CheckExpression(initializer.Value); - if (value.Type != field.Type) + if (!value.Type.IsAssignableTo(field.Type)) throw new CompileException(Diagnostic.Error($"Type of assignment ({value.Type}) does not match expected type of field '{field.Name}' ({field.Type})").At(fileName, initializer.Name).Build()); initializers.Add(new TypedNodeExpressionStructLiteral.Initializer(initializer.Tokens, initializer.Name, value)); @@ -471,16 +465,13 @@ public class TypeChecker throw new CompileException(Diagnostic.Error($"Field '{initializer.Name.Ident}' does not exist on enum variant '{expression.Module.Ident}::{expression.EnumName.Ident}.{expression.VariantName.Ident}'").At(fileName, initializer.Name).Build()); var value = CheckExpression(initializer.Value); - if (value.Type != field.Type) + if (!value.Type.IsAssignableTo(field.Type)) throw new CompileException(Diagnostic.Error($"Type of assignment ({value.Type}) does not match expected type of field '{field.Name}' ({field.Type})").At(fileName, initializer.Name).Build()); initializers.Add(new TypedNodeExpressionEnumLiteral.Initializer(initializer.Tokens, initializer.Name, value)); } - return new TypedNodeExpressionEnumLiteral(expression.Tokens, NubTypeEnum.Get(expression.Module.Ident, expression.EnumName.Ident), initializers) - { - EnumVariant = expression.VariantName.Ident - }; + return new TypedNodeExpressionEnumLiteral(expression.Tokens, NubTypeEnumVariant.Get(NubTypeEnum.Get(expression.Module.Ident, expression.EnumName.Ident), expression.VariantName.Ident), initializers); } private NubType ResolveType(NodeType node) @@ -519,7 +510,7 @@ public class TypeChecker private sealed class Scope { - private readonly Stack> scopes = new(); + private readonly Stack> scopes = new(); public IDisposable EnterScope() { @@ -527,12 +518,12 @@ public class TypeChecker return new ScopeGuard(this); } - public void DeclareIdentifier(string name, NubType type, string? enumVariant) + public void DeclareIdentifier(string name, NubType type) { - scopes.Peek().Add(name, (type, enumVariant)); + scopes.Peek().Add(name, type); } - public (NubType Type, string? EnumVariant)? GetIdentifierType(string name) + public NubType? GetIdentifierType(string name) { foreach (var scope in scopes) { @@ -649,7 +640,6 @@ public class TypedNodeStatementMatch(List tokens, TypedNodeExpression tar public abstract class TypedNodeExpression(List tokens, NubType type) : TypedNode(tokens) { public NubType Type { get; } = type; - public string? EnumVariant { get; init; } } public class TypedNodeExpressionIntLiteral(List tokens, NubType type, TokenIntLiteral value) : TypedNodeExpression(tokens, type)