diff --git a/src/compiler/NubLang/Generation/QBE/QBEGenerator.Expression.cs b/src/compiler/NubLang/Generation/QBE/QBEGenerator.Expression.cs index c4c5460..8064180 100644 --- a/src/compiler/NubLang/Generation/QBE/QBEGenerator.Expression.cs +++ b/src/compiler/NubLang/Generation/QBE/QBEGenerator.Expression.cs @@ -48,7 +48,7 @@ public partial class QBEGenerator var elementType = ((ArrayTypeNode)arrayIndexAccess.Target.Type).ElementType; var pointer = TmpName(); - _writer.Indented($"{pointer} =l mul {index}, {elementType.Size(_definitionTable)}"); + _writer.Indented($"{pointer} =l mul {index}, {SizeOf(elementType)}"); _writer.Indented($"{pointer} =l add {pointer}, 8"); _writer.Indented($"{pointer} =l add {array}, {pointer}"); return new Val(pointer, arrayIndexAccess.Type, ValKind.Pointer); @@ -81,7 +81,7 @@ public partial class QBEGenerator private Val EmitArrayInitializer(ArrayInitializerNode arrayInitializer) { var capacity = EmitUnwrap(EmitExpression(arrayInitializer.Capacity)); - var elementSize = arrayInitializer.ElementType.Size(_definitionTable); + var elementSize = SizeOf(arrayInitializer.ElementType); var capacityInBytes = TmpName(); _writer.Indented($"{capacityInBytes} =l mul {capacity}, {elementSize}"); @@ -291,7 +291,7 @@ public partial class QBEGenerator } case LiteralKind.String: { - if (literal.Type is NubStringTypeNode) + if (literal.Type is StringTypeNode) { var stringLiteral = new StringLiteral(literal.Value, StringName()); _stringLiterals.Add(stringLiteral); @@ -328,7 +328,7 @@ public partial class QBEGenerator if (destination == null) { destination = TmpName(); - var size = structInitializer.StructType.Size(_definitionTable); + var size = SizeOf(structInitializer.StructType); if (structDef.InterfaceImplementations.Any()) { @@ -413,7 +413,7 @@ public partial class QBEGenerator _writer.Indented($"{output} =l add {target}, {offset}"); // If the accessed member is an inline struct, it will not be a pointer - if (structFieldAccess.Type is CustomTypeNode customType && customType.Kind(_definitionTable) == CustomTypeKind.Struct) + if (structFieldAccess.Type is StructTypeNode) { return new Val(output, structFieldAccess.Type, ValKind.Direct); } diff --git a/src/compiler/NubLang/Generation/QBE/QBEGenerator.cs b/src/compiler/NubLang/Generation/QBE/QBEGenerator.cs index 5fe9ba7..cb2435d 100644 --- a/src/compiler/NubLang/Generation/QBE/QBEGenerator.cs +++ b/src/compiler/NubLang/Generation/QBE/QBEGenerator.cs @@ -189,7 +189,7 @@ public partial class QBEGenerator { var size = TmpName(); _writer.Indented($"{size} =l loadl {array}"); - _writer.Indented($"{size} =l mul {size}, {type.ElementType.Size(_definitionTable)}"); + _writer.Indented($"{size} =l mul {size}, {SizeOf(type.ElementType)}"); _writer.Indented($"{size} =l add {size}, 8"); return size; } @@ -250,9 +250,9 @@ public partial class QBEGenerator } else { - if (complexType is CustomTypeNode customType) + if (complexType is StructTypeNode structType) { - EmitMemcpy(value, destinationPointer, customType.Size(_definitionTable).ToString()); + EmitMemcpy(value, destinationPointer, SizeOf(structType).ToString()); } else { @@ -260,7 +260,8 @@ public partial class QBEGenerator { ArrayTypeNode arrayType => EmitArraySizeInBytes(arrayType, value), CStringTypeNode => EmitCStringSizeInBytes(value), - NubStringTypeNode => EmitStringSizeInBytes(value), + StringTypeNode => EmitStringSizeInBytes(value), + InterfaceTypeNode => 16.ToString(), _ => throw new ArgumentOutOfRangeException(nameof(source.Type)) }; @@ -309,8 +310,9 @@ public partial class QBEGenerator { ArrayTypeNode arrayType => EmitArraySizeInBytes(arrayType, value), CStringTypeNode => EmitCStringSizeInBytes(value), - NubStringTypeNode => EmitStringSizeInBytes(value), - CustomTypeNode customType => customType.Size(_definitionTable).ToString(), + StringTypeNode => EmitStringSizeInBytes(value), + InterfaceTypeNode => 16.ToString(), + StructTypeNode structType => SizeOf(structType).ToString(), _ => throw new ArgumentOutOfRangeException(nameof(source.Type)) }; @@ -339,9 +341,14 @@ public partial class QBEGenerator }; } - if (complexType is CustomTypeNode customType) + if (complexType is StructTypeNode structType) { - return CustomTypeName(customType.Name); + return StructTypeName(structType.Name); + } + + if (complexType is InterfaceTypeNode interfaceType) + { + return InterfaceTypeName(interfaceType.Name); } return "l"; @@ -387,7 +394,7 @@ public partial class QBEGenerator private void EmitStructTypeDefinition(StructNode structDef) { - _writer.WriteLine($"type {CustomTypeName(structDef.Name)} = {{ "); + _writer.WriteLine($"type {StructTypeName(structDef.Name)} = {{ "); var types = new Dictionary(); @@ -422,9 +429,14 @@ public partial class QBEGenerator }; } - if (complexType is CustomTypeNode customType) + if (complexType is StructTypeNode structType) { - return CustomTypeName(customType.Name); + return StructTypeName(structType.Name); + } + + if (complexType is InterfaceTypeNode interfaceType) + { + return InterfaceTypeName(interfaceType.Name); } return "l"; @@ -458,6 +470,93 @@ public partial class QBEGenerator }; } + private static int SizeOf(TypeNode type) + { + return type switch + { + SimpleTypeNode simple => simple.StorageSize switch + { + StorageSize.Void => 0, + StorageSize.I8 or StorageSize.U8 => 1, + StorageSize.I16 or StorageSize.U16 => 2, + StorageSize.I32 or StorageSize.U32 or StorageSize.F32 => 4, + StorageSize.I64 or StorageSize.U64 or StorageSize.F64 => 8, + _ => throw new ArgumentOutOfRangeException(nameof(type), $"Unknown storage size: {simple.StorageSize}") + }, + CStringTypeNode => 8, + StringTypeNode => 8, + ArrayTypeNode => 8, + StructTypeNode structType => CalculateStructSize(structType), + InterfaceTypeNode => 16, + _ => throw new ArgumentOutOfRangeException(nameof(type), $"Unknown type: {type.GetType()}") + }; + } + + private static int CalculateStructSize(StructTypeNode structType) + { + var offset = 0; + + if (structType.InterfaceImplementations.Any()) + { + offset = 8; + } + + foreach (var field in structType.Fields) + { + var fieldAlignment = AlignmentOf(field); + offset = AlignTo(offset, fieldAlignment); + offset += SizeOf(field); + } + + var structAlignment = CalculateStructAlignment(structType); + return AlignTo(offset, structAlignment); + } + + private static int AlignmentOf(TypeNode type) + { + return type switch + { + SimpleTypeNode simple => simple.StorageSize switch + { + StorageSize.Void => 1, + StorageSize.I8 or StorageSize.U8 => 1, + StorageSize.I16 or StorageSize.U16 => 2, + StorageSize.I32 or StorageSize.U32 or StorageSize.F32 => 4, + StorageSize.I64 or StorageSize.U64 or StorageSize.F64 => 8, + _ => throw new ArgumentOutOfRangeException(nameof(type), $"Unknown storage size: {simple.StorageSize}") + }, + CStringTypeNode => 8, + StringTypeNode => 8, + ArrayTypeNode => 8, + StructTypeNode structType => CalculateStructAlignment(structType), + InterfaceTypeNode => 8, + _ => throw new ArgumentOutOfRangeException(nameof(type), $"Unknown type: {type.GetType()}") + }; + } + + private static int CalculateStructAlignment(StructTypeNode structType) + { + var maxAlignment = 1; + + if (structType.InterfaceImplementations.Any()) + { + maxAlignment = Math.Max(maxAlignment, 8); + } + + foreach (var field in structType.Fields) + { + var fieldAlignment = AlignmentOf(field); + maxAlignment = Math.Max(maxAlignment, fieldAlignment); + } + + return maxAlignment; + } + + private static int AlignTo(int offset, int alignment) + { + return (offset + alignment - 1) & ~(alignment - 1); + } + private int OffsetOf(StructNode structDef, string member) { var offset = 0; @@ -474,10 +573,10 @@ public partial class QBEGenerator return offset; } - var fieldAlignment = field.Type.Alignment(_definitionTable); + var fieldAlignment = AlignmentOf(field.Type); - offset = TypeNode.AlignTo(offset, fieldAlignment); - offset += field.Type.Size(_definitionTable); + offset = AlignTo(offset, fieldAlignment); + offset += SizeOf(field.Type); } throw new UnreachableException($"Member '{member}' not found in struct"); @@ -515,7 +614,12 @@ public partial class QBEGenerator return $"${funcDef.CallName}"; } - private string CustomTypeName(string name) + private string StructTypeName(string name) + { + return $":{name}"; + } + + private string InterfaceTypeName(string name) { return $":{name}"; } diff --git a/src/compiler/NubLang/TypeChecking/DefinitionTable.cs b/src/compiler/NubLang/TypeChecking/DefinitionTable.cs index ee094d4..454a57f 100644 --- a/src/compiler/NubLang/TypeChecking/DefinitionTable.cs +++ b/src/compiler/NubLang/TypeChecking/DefinitionTable.cs @@ -26,11 +26,11 @@ public class DefinitionTable .Where(x => x.Name == name); } - public IEnumerable LookupStruct(CustomTypeNode type) + public IEnumerable LookupStruct(string name) { return _definitions .OfType() - .Where(x => x.Name == type.Name); + .Where(x => x.Name == name); } public IEnumerable LookupStructField(StructSyntax @struct, string field) @@ -43,11 +43,11 @@ public class DefinitionTable return @struct.Functions.Where(x => x.Name == func); } - public IEnumerable LookupInterface(CustomTypeNode type) + public IEnumerable LookupInterface(string name) { return _definitions .OfType() - .Where(x => x.Name == type.Name); + .Where(x => x.Name == name); } public IEnumerable LookupInterfaceFunc(InterfaceSyntax @interface, string name) diff --git a/src/compiler/NubLang/TypeChecking/Node/DefinitionNode.cs b/src/compiler/NubLang/TypeChecking/Node/DefinitionNode.cs index f7ad577..3a42cd3 100644 --- a/src/compiler/NubLang/TypeChecking/Node/DefinitionNode.cs +++ b/src/compiler/NubLang/TypeChecking/Node/DefinitionNode.cs @@ -14,7 +14,7 @@ public record StructFieldNode(int Index, string Name, TypeNode Type, Optional Fields, IReadOnlyList Functions, IReadOnlyList InterfaceImplementations) : DefinitionNode; +public record StructNode(string Name, IReadOnlyList Fields, IReadOnlyList Functions, IReadOnlyList InterfaceImplementations) : DefinitionNode; public record InterfaceFuncNode(string Name, FuncSignatureNode Signature) : Node; diff --git a/src/compiler/NubLang/TypeChecking/Node/ExpressionNode.cs b/src/compiler/NubLang/TypeChecking/Node/ExpressionNode.cs index 754ba22..9d835d4 100644 --- a/src/compiler/NubLang/TypeChecking/Node/ExpressionNode.cs +++ b/src/compiler/NubLang/TypeChecking/Node/ExpressionNode.cs @@ -46,12 +46,12 @@ public record AddressOfNode(TypeNode Type, ExpressionNode Expression) : Expressi public record LiteralNode(TypeNode Type, string Value, LiteralKind Kind) : ExpressionNode(Type); -public record StructFieldAccessNode(TypeNode Type, CustomTypeNode StructType, ExpressionNode Target, string Field) : ExpressionNode(Type); +public record StructFieldAccessNode(TypeNode Type, StructTypeNode StructType, ExpressionNode Target, string Field) : ExpressionNode(Type); -public record StructFuncAccessNode(TypeNode Type, CustomTypeNode StructType, ExpressionNode Target, string Func) : ExpressionNode(Type); +public record StructFuncAccessNode(TypeNode Type, StructTypeNode StructType, ExpressionNode Target, string Func) : ExpressionNode(Type); -public record InterfaceFuncAccessNode(TypeNode Type, CustomTypeNode InterfaceType, ExpressionNode Target, string FuncName) : ExpressionNode(Type); +public record InterfaceFuncAccessNode(TypeNode Type, InterfaceTypeNode InterfaceType, ExpressionNode Target, string FuncName) : ExpressionNode(Type); -public record StructInitializerNode(CustomTypeNode StructType, Dictionary Initializers) : ExpressionNode(StructType); +public record StructInitializerNode(StructTypeNode StructType, Dictionary Initializers) : ExpressionNode(StructType); public record DereferenceNode(TypeNode Type, ExpressionNode Expression) : ExpressionNode(Type); \ No newline at end of file diff --git a/src/compiler/NubLang/TypeChecking/Node/TypeNode.cs b/src/compiler/NubLang/TypeChecking/Node/TypeNode.cs index 9567073..511a872 100644 --- a/src/compiler/NubLang/TypeChecking/Node/TypeNode.cs +++ b/src/compiler/NubLang/TypeChecking/Node/TypeNode.cs @@ -1,11 +1,10 @@ using System.Diagnostics.CodeAnalysis; -using NubLang.Generation; namespace NubLang.TypeChecking.Node; public abstract class TypeNode : IEquatable { - public bool IsSimpleType([NotNullWhen(true)] out SimpleTypeNode? simpleType, [NotNullWhen(false)] out NubComplexTypeNode? complexType) + public bool IsSimpleType([NotNullWhen(true)] out SimpleTypeNode? simpleType, [NotNullWhen(false)] out ComplexTypeNode? complexType) { if (this is SimpleTypeNode st) { @@ -14,7 +13,7 @@ public abstract class TypeNode : IEquatable return true; } - if (this is NubComplexTypeNode ct) + if (this is ComplexTypeNode ct) { complexType = ct; simpleType = null; @@ -24,14 +23,6 @@ public abstract class TypeNode : IEquatable throw new ArgumentException($"Type {this} is not a simple type nor a complex type"); } - public abstract int Size(TypedDefinitionTable definitionTable); - public abstract int Alignment(TypedDefinitionTable definitionTable); - - public static int AlignTo(int offset, int alignment) - { - return (offset + alignment - 1) & ~(alignment - 1); - } - public override bool Equals(object? obj) => obj is TypeNode other && Equals(other); public abstract bool Equals(TypeNode? other); @@ -60,23 +51,6 @@ public enum StorageSize public abstract class SimpleTypeNode : TypeNode { public abstract StorageSize StorageSize { get; } - - public override int Size(TypedDefinitionTable definitionTable) - { - return StorageSize switch - { - StorageSize.I64 or StorageSize.U64 or StorageSize.F64 => 8, - StorageSize.I32 or StorageSize.U32 or StorageSize.F32 => 4, - StorageSize.I16 or StorageSize.U16 => 2, - StorageSize.I8 or StorageSize.U8 => 1, - _ => throw new ArgumentOutOfRangeException(nameof(StorageSize)) - }; - } - - public override int Alignment(TypedDefinitionTable definitionTable) - { - return Size(definitionTable); - } } #region Simple types @@ -136,7 +110,7 @@ public class BoolTypeNode : SimpleTypeNode public override int GetHashCode() => HashCode.Combine(typeof(BoolTypeNode)); } -public class FuncTypeNode(List parameters, TypeNode returnType) : SimpleTypeNode +public class FuncTypeNode(IReadOnlyList parameters, TypeNode returnType) : SimpleTypeNode { public IReadOnlyList Parameters { get; } = parameters; public TypeNode ReturnType { get; } = returnType; @@ -182,110 +156,50 @@ public class VoidTypeNode : SimpleTypeNode #endregion -public abstract class NubComplexTypeNode : TypeNode; +public abstract class ComplexTypeNode : TypeNode; #region Complex types -public class CStringTypeNode : NubComplexTypeNode +public class CStringTypeNode : ComplexTypeNode { - public override int Size(TypedDefinitionTable definitionTable) => 8; - public override int Alignment(TypedDefinitionTable definitionTable) => Size(definitionTable); - public override string ToString() => "cstring"; public override bool Equals(TypeNode? other) => other is CStringTypeNode; public override int GetHashCode() => HashCode.Combine(typeof(CStringTypeNode)); } -public class NubStringTypeNode : NubComplexTypeNode +public class StringTypeNode : ComplexTypeNode { - public override int Size(TypedDefinitionTable definitionTable) => 8; - public override int Alignment(TypedDefinitionTable definitionTable) => Size(definitionTable); - public override string ToString() => "string"; - public override bool Equals(TypeNode? other) => other is NubStringTypeNode; - public override int GetHashCode() => HashCode.Combine(typeof(NubStringTypeNode)); + public override bool Equals(TypeNode? other) => other is StringTypeNode; + public override int GetHashCode() => HashCode.Combine(typeof(StringTypeNode)); } -public class CustomTypeNode(string name) : NubComplexTypeNode +public class StructTypeNode(string name, IReadOnlyList fields, IReadOnlyList functions, IReadOnlyList interfaceImplementations) : ComplexTypeNode { public string Name { get; } = name; - - public CustomTypeKind Kind(TypedDefinitionTable definitionTable) - { - if (definitionTable.GetStructs().Any(x => x.Name == Name)) - { - return CustomTypeKind.Struct; - } - - if (definitionTable.GetInterfaces().Any(x => x.Name == Name)) - { - return CustomTypeKind.Interface; - } - - throw new ArgumentException($"Definition table does not have any type information for {this}"); - } - - public override int Size(TypedDefinitionTable definitionTable) - { - switch (Kind(definitionTable)) - { - case CustomTypeKind.Struct: - { - var structDef = definitionTable.LookupStruct(Name); - var size = 0; - var maxAlignment = 1; - - foreach (var field in structDef.Fields) - { - var fieldAlignment = field.Type.Alignment(definitionTable); - maxAlignment = Math.Max(maxAlignment, fieldAlignment); - - size = AlignTo(size, fieldAlignment); - size += field.Type.Size(definitionTable); - } - - return AlignTo(size, maxAlignment); - } - case CustomTypeKind.Interface: - { - return 16; - } - default: - throw new ArgumentOutOfRangeException(); - } - } - - public override int Alignment(TypedDefinitionTable definitionTable) - { - switch (Kind(definitionTable)) - { - case CustomTypeKind.Struct: - return definitionTable.LookupStruct(Name).Fields.Max(f => f.Type.Alignment(definitionTable)); - case CustomTypeKind.Interface: - return 8; - default: - throw new ArgumentOutOfRangeException(); - } - } + public IReadOnlyList Fields { get; } = fields; + public IReadOnlyList Functions { get; } = functions; + public IReadOnlyList InterfaceImplementations { get; } = interfaceImplementations; public override string ToString() => Name; - public override bool Equals(TypeNode? other) => other is CustomTypeNode custom && Name == custom.Name; - public override int GetHashCode() => HashCode.Combine(typeof(CustomTypeNode), Name); + public override bool Equals(TypeNode? other) => other is StructTypeNode custom && Name == custom.Name; + public override int GetHashCode() => HashCode.Combine(typeof(StructTypeNode), Name); } -public enum CustomTypeKind +public class InterfaceTypeNode(string name, IReadOnlyList funcs) : ComplexTypeNode { - Struct, - Interface + public string Name { get; } = name; + public IReadOnlyList Funcs { get; } = funcs; + + public override string ToString() => Name; + public override bool Equals(TypeNode? other) => other is InterfaceTypeNode custom && Name == custom.Name; + public override int GetHashCode() => HashCode.Combine(typeof(InterfaceTypeNode), Name); } -public class ArrayTypeNode(TypeNode elementType) : NubComplexTypeNode +public class ArrayTypeNode(TypeNode elementType) : ComplexTypeNode { public TypeNode ElementType { get; } = elementType; - public override int Size(TypedDefinitionTable definitionTable) => 8; - public override int Alignment(TypedDefinitionTable definitionTable) => Size(definitionTable); - public override string ToString() => "[]" + ElementType; public override bool Equals(TypeNode? other) => other is ArrayTypeNode array && ElementType.Equals(array.ElementType); diff --git a/src/compiler/NubLang/TypeChecking/TypeChecker.cs b/src/compiler/NubLang/TypeChecking/TypeChecker.cs index 9c6f06d..e29b612 100644 --- a/src/compiler/NubLang/TypeChecking/TypeChecker.cs +++ b/src/compiler/NubLang/TypeChecking/TypeChecker.cs @@ -38,7 +38,7 @@ public sealed class TypeChecker { definitions.Add(CheckDefinition(definition)); } - catch (CheckException e) + catch (TypeCheckerException e) { _diagnostics.Add(e.Diagnostic); } @@ -100,32 +100,32 @@ public sealed class TypeChecker funcs.Add(new StructFuncNode(func.Name, CheckFuncSignature(func.Signature), CheckFuncBody(func.Body, CheckType(func.Signature.ReturnType), parameters))); } - var interfaceImplementations = new List(); + var interfaceImplementations = new List(); foreach (var interfaceImplementation in node.InterfaceImplementations) { - var interfaceType = CheckType(interfaceImplementation); - if (interfaceType is not CustomTypeNode customType) + var type = CheckType(interfaceImplementation); + if (type is not InterfaceTypeNode interfaceType) { _diagnostics.Add(Diagnostic.Error("Interface implementation is not a custom type").Build()); continue; } - var interfaceDefs = _definitionTable.LookupInterface(customType).ToArray(); + var interfaceDefs = _definitionTable.LookupInterface(interfaceType.Name).ToArray(); if (interfaceDefs.Length == 0) { - _diagnostics.Add(Diagnostic.Error($"Interface {customType.Name} is not defined").Build()); + _diagnostics.Add(Diagnostic.Error($"Interface {interfaceType.Name} is not defined").Build()); continue; } if (interfaceDefs.Length > 1) { - _diagnostics.Add(Diagnostic.Error($"Interface {customType.Name} has multiple definitions").Build()); + _diagnostics.Add(Diagnostic.Error($"Interface {interfaceType.Name} has multiple definitions").Build()); continue; } - interfaceImplementations.Add(customType); + interfaceImplementations.Add(interfaceType); } return new StructNode(node.Name, structFields, funcs, interfaceImplementations); @@ -262,12 +262,12 @@ public sealed class TypeChecker { if (expectedType == null) { - throw new CheckException(Diagnostic.Error("Cannot infer argument types for arrow function").Build()); + throw new TypeCheckerException(Diagnostic.Error("Cannot infer argument types for arrow function").Build()); } if (expectedType is not FuncTypeNode funcType) { - throw new CheckException(Diagnostic.Error($"Expected {expectedType}, but got arrow function").Build()); + throw new TypeCheckerException(Diagnostic.Error($"Expected {expectedType}, but got arrow function").Build()); } var parameters = new List(); @@ -276,7 +276,7 @@ public sealed class TypeChecker { if (i >= funcType.Parameters.Count) { - throw new CheckException(Diagnostic.Error($"Arrow function expected a maximum of {funcType.Parameters.Count} arguments").Build()); + throw new TypeCheckerException(Diagnostic.Error($"Arrow function expected a maximum of {funcType.Parameters.Count} arguments").Build()); } var expectedParameterType = funcType.Parameters[i]; @@ -353,7 +353,7 @@ public sealed class TypeChecker { if (localFuncs.Length > 1) { - throw new CheckException(Diagnostic.Error($"Extern func {expression.Name} has multiple definitions").Build()); + throw new TypeCheckerException(Diagnostic.Error($"Extern func {expression.Name} has multiple definitions").Build()); } var localFunc = localFuncs[0]; @@ -369,7 +369,7 @@ public sealed class TypeChecker { if (externFuncs.Length > 1) { - throw new CheckException(Diagnostic.Error($"Extern func {expression.Name} has multiple definitions").Build()); + throw new TypeCheckerException(Diagnostic.Error($"Extern func {expression.Name} has multiple definitions").Build()); } var externFunc = externFuncs[0]; @@ -380,7 +380,7 @@ public sealed class TypeChecker return new ExternFuncIdentNode(type, expression.Name); } - throw new CheckException(Diagnostic.Error($"No identifier with the name {expression.Name} exists").Build()); + throw new TypeCheckerException(Diagnostic.Error($"No identifier with the name {expression.Name} exists").Build()); } private LiteralNode CheckLiteral(LiteralSyntax expression, TypeNode? expectedType = null) @@ -389,7 +389,7 @@ public sealed class TypeChecker { LiteralKind.Integer => new IntTypeNode(true, 64), LiteralKind.Float => new FloatTypeNode(64), - LiteralKind.String => new NubStringTypeNode(), + LiteralKind.String => new StringTypeNode(), LiteralKind.Bool => new BoolTypeNode(), _ => throw new ArgumentOutOfRangeException() }; @@ -401,14 +401,14 @@ public sealed class TypeChecker { var boundExpression = CheckExpression(expression.Target); - if (boundExpression.Type is CustomTypeNode customType) + if (boundExpression.Type is InterfaceTypeNode customType) { - var interfaces = _definitionTable.LookupInterface(customType).ToArray(); + var interfaces = _definitionTable.LookupInterface(customType.Name).ToArray(); if (interfaces.Length > 0) { if (interfaces.Length > 1) { - throw new CheckException(Diagnostic.Error($"Interface {customType} has multiple definitions").Build()); + throw new TypeCheckerException(Diagnostic.Error($"Interface {customType} has multiple definitions").Build()); } var @interface = interfaces[0]; @@ -418,7 +418,7 @@ public sealed class TypeChecker { if (interfaceFuncs.Length > 1) { - throw new CheckException(Diagnostic.Error($"Interface {customType} has multiple functions with the name {expression.Member}").Build()); + throw new TypeCheckerException(Diagnostic.Error($"Interface {customType} has multiple functions with the name {expression.Member}").Build()); } var interfaceFunc = interfaceFuncs[0]; @@ -429,13 +429,16 @@ public sealed class TypeChecker return new InterfaceFuncAccessNode(type, customType, boundExpression, expression.Member); } } + } - var structs = _definitionTable.LookupStruct(customType).ToArray(); + if (boundExpression.Type is StructTypeNode structType) + { + var structs = _definitionTable.LookupStruct(structType.Name).ToArray(); if (structs.Length > 0) { if (structs.Length > 1) { - throw new CheckException(Diagnostic.Error($"Struct {customType} has multiple definitions").Build()); + throw new TypeCheckerException(Diagnostic.Error($"Struct {structType} has multiple definitions").Build()); } var @struct = structs[0]; @@ -445,12 +448,12 @@ public sealed class TypeChecker { if (fields.Length > 1) { - throw new CheckException(Diagnostic.Error($"Struct {customType} has multiple fields with the name {expression.Member}").Build()); + throw new TypeCheckerException(Diagnostic.Error($"Struct {structType} has multiple fields with the name {expression.Member}").Build()); } var field = fields[0]; - return new StructFieldAccessNode(CheckType(field.Type), customType, boundExpression, expression.Member); + return new StructFieldAccessNode(CheckType(field.Type), structType, boundExpression, expression.Member); } var funcs = _definitionTable.LookupStructFunc(@struct, expression.Member).ToArray(); @@ -458,40 +461,40 @@ public sealed class TypeChecker { if (funcs.Length > 1) { - throw new CheckException(Diagnostic.Error($"Struct {customType} has multiple functions with the name {expression.Member}").Build()); + throw new TypeCheckerException(Diagnostic.Error($"Struct {structType} has multiple functions with the name {expression.Member}").Build()); } var func = funcs[0]; var parameters = func.Signature.Parameters.Select(x => CheckType(x.Type)).ToList(); var returnType = CheckType(func.Signature.ReturnType); - return new StructFuncAccessNode(new FuncTypeNode(parameters, returnType), customType, boundExpression, expression.Member); + return new StructFuncAccessNode(new FuncTypeNode(parameters, returnType), structType, boundExpression, expression.Member); } } } - throw new CheckException(Diagnostic.Error($"{boundExpression.Type} does not have a member with the name {expression.Member}").Build()); + throw new TypeCheckerException(Diagnostic.Error($"{boundExpression.Type} does not have a member with the name {expression.Member}").Build()); } private StructInitializerNode CheckStructInitializer(StructInitializerSyntax expression) { var boundType = CheckType(expression.StructType); - if (boundType is not CustomTypeNode structType) + if (boundType is not StructTypeNode structType) { - throw new CheckException(Diagnostic.Error($"Cannot initialize non-struct type {expression.StructType}").Build()); + throw new TypeCheckerException(Diagnostic.Error($"Cannot initialize non-struct type {expression.StructType}").Build()); } - var structs = _definitionTable.LookupStruct(structType).ToArray(); + var structs = _definitionTable.LookupStruct(structType.Name).ToArray(); if (structs.Length == 0) { - throw new CheckException(Diagnostic.Error($"Struct {structType} is not defined").Build()); + throw new TypeCheckerException(Diagnostic.Error($"Struct {structType} is not defined").Build()); } if (structs.Length > 1) { - throw new CheckException(Diagnostic.Error($"Struct {structType} has multiple definitions").Build()); + throw new TypeCheckerException(Diagnostic.Error($"Struct {structType} has multiple definitions").Build()); } var @struct = structs[0]; @@ -504,12 +507,12 @@ public sealed class TypeChecker if (fields.Length == 0) { - throw new CheckException(Diagnostic.Error($"Struct {structType} does not have a field with the name {field}").Build()); + throw new TypeCheckerException(Diagnostic.Error($"Struct {structType} does not have a field with the name {field}").Build()); } if (fields.Length > 1) { - throw new CheckException(Diagnostic.Error($"Struct {structType} has multiple fields with the name {field}").Build()); + throw new TypeCheckerException(Diagnostic.Error($"Struct {structType} has multiple fields with the name {field}").Build()); } initializers[field] = CheckExpression(initializer, CheckType(fields[0].Type)); @@ -632,16 +635,69 @@ public sealed class TypeChecker ArrayTypeSyntax type => new ArrayTypeNode(CheckType(type.BaseType)), BoolTypeSyntax => new BoolTypeNode(), CStringTypeSyntax => new CStringTypeNode(), - CustomTypeSyntax type => new CustomTypeNode(type.Name), + CustomTypeSyntax type => CheckCustomType(type), FloatTypeSyntax @float => new FloatTypeNode(@float.Width), FuncTypeSyntax type => new FuncTypeNode(type.Parameters.Select(CheckType).ToList(), CheckType(type.ReturnType)), IntTypeSyntax @int => new IntTypeNode(@int.Signed, @int.Width), PointerTypeSyntax type => new PointerTypeNode(CheckType(type.BaseType)), - StringTypeSyntax => new NubStringTypeNode(), + StringTypeSyntax => new StringTypeNode(), VoidTypeSyntax => new VoidTypeNode(), _ => throw new ArgumentOutOfRangeException(nameof(node)) }; } + + private TypeNode CheckCustomType(CustomTypeSyntax type) + { + var structs = _definitionTable.LookupStruct(type.Name).ToArray(); + if (structs.Length > 0) + { + if (structs.Length > 1) + { + throw new TypeCheckerException(Diagnostic.Error($"Struct {type.Name} has multiple definitions").Build()); + } + + var @struct = structs[0]; + + var fields = @struct.Fields.Select(x => CheckType(x.Type)).ToList(); + + var funcs = @struct.Functions + .Select(x => new FuncTypeNode(x.Signature.Parameters.Select(p => CheckType(p.Type)).ToList(), CheckType(x.Signature.ReturnType))) + .ToList(); + + var interfaceImplementations = new List(); + + foreach (var structInterfaceImplementation in @struct.InterfaceImplementations) + { + var checkedInterfaceType = CheckType(structInterfaceImplementation); + if (checkedInterfaceType is not InterfaceTypeNode interfaceType) + { + throw new TypeCheckerException(Diagnostic.Error($"{type.Name} cannot implement non-interface type {checkedInterfaceType}").Build()); + } + + interfaceImplementations.Add(interfaceType); + } + + return new StructTypeNode(type.Name, fields, funcs, interfaceImplementations); + } + + var interfaces = _definitionTable.LookupInterface(type.Name).ToArray(); + if (interfaces.Length > 0) + { + if (interfaces.Length > 1) + { + throw new TypeCheckerException(Diagnostic.Error($"Interface {type.Name} has multiple definitions").Build()); + } + + var @interface = interfaces[0]; + var functions = @interface.Functions + .Select(x => new FuncTypeNode(x.Signature.Parameters.Select(y => CheckType(y.Type)).ToList(), CheckType(x.Signature.ReturnType))) + .ToList(); + + return new InterfaceTypeNode(type.Name, functions); + } + + throw new TypeCheckerException(Diagnostic.Error($"Type {type.Name} is not defined").Build()); + } } public record Variable(string Name, TypeNode Type); @@ -672,11 +728,11 @@ public class Scope(Scope? parent = null) } } -public class CheckException : Exception +public class TypeCheckerException : Exception { public Diagnostic Diagnostic { get; } - public CheckException(Diagnostic diagnostic) : base(diagnostic.Message) + public TypeCheckerException(Diagnostic diagnostic) : base(diagnostic.Message) { Diagnostic = diagnostic; }