This commit is contained in:
nub31
2025-08-12 20:55:35 +02:00
parent 1ef1df545f
commit a591f5b553
7 changed files with 248 additions and 174 deletions

View File

@@ -48,7 +48,7 @@ public partial class QBEGenerator
var elementType = ((ArrayTypeNode)arrayIndexAccess.Target.Type).ElementType; var elementType = ((ArrayTypeNode)arrayIndexAccess.Target.Type).ElementType;
var pointer = TmpName(); var pointer = TmpName();
_writer.Indented($"{pointer} =l mul {index}, {elementType.Size(_definitionTable)}"); _writer.Indented($"{pointer} =l mul {index}, {SizeOf(elementType)}");
_writer.Indented($"{pointer} =l add {pointer}, 8"); _writer.Indented($"{pointer} =l add {pointer}, 8");
_writer.Indented($"{pointer} =l add {array}, {pointer}"); _writer.Indented($"{pointer} =l add {array}, {pointer}");
return new Val(pointer, arrayIndexAccess.Type, ValKind.Pointer); return new Val(pointer, arrayIndexAccess.Type, ValKind.Pointer);
@@ -81,7 +81,7 @@ public partial class QBEGenerator
private Val EmitArrayInitializer(ArrayInitializerNode arrayInitializer) private Val EmitArrayInitializer(ArrayInitializerNode arrayInitializer)
{ {
var capacity = EmitUnwrap(EmitExpression(arrayInitializer.Capacity)); var capacity = EmitUnwrap(EmitExpression(arrayInitializer.Capacity));
var elementSize = arrayInitializer.ElementType.Size(_definitionTable); var elementSize = SizeOf(arrayInitializer.ElementType);
var capacityInBytes = TmpName(); var capacityInBytes = TmpName();
_writer.Indented($"{capacityInBytes} =l mul {capacity}, {elementSize}"); _writer.Indented($"{capacityInBytes} =l mul {capacity}, {elementSize}");
@@ -291,7 +291,7 @@ public partial class QBEGenerator
} }
case LiteralKind.String: case LiteralKind.String:
{ {
if (literal.Type is NubStringTypeNode) if (literal.Type is StringTypeNode)
{ {
var stringLiteral = new StringLiteral(literal.Value, StringName()); var stringLiteral = new StringLiteral(literal.Value, StringName());
_stringLiterals.Add(stringLiteral); _stringLiterals.Add(stringLiteral);
@@ -328,7 +328,7 @@ public partial class QBEGenerator
if (destination == null) if (destination == null)
{ {
destination = TmpName(); destination = TmpName();
var size = structInitializer.StructType.Size(_definitionTable); var size = SizeOf(structInitializer.StructType);
if (structDef.InterfaceImplementations.Any()) if (structDef.InterfaceImplementations.Any())
{ {
@@ -413,7 +413,7 @@ public partial class QBEGenerator
_writer.Indented($"{output} =l add {target}, {offset}"); _writer.Indented($"{output} =l add {target}, {offset}");
// If the accessed member is an inline struct, it will not be a pointer // If the accessed member is an inline struct, it will not be a pointer
if (structFieldAccess.Type is CustomTypeNode customType && customType.Kind(_definitionTable) == CustomTypeKind.Struct) if (structFieldAccess.Type is StructTypeNode)
{ {
return new Val(output, structFieldAccess.Type, ValKind.Direct); return new Val(output, structFieldAccess.Type, ValKind.Direct);
} }

View File

@@ -189,7 +189,7 @@ public partial class QBEGenerator
{ {
var size = TmpName(); var size = TmpName();
_writer.Indented($"{size} =l loadl {array}"); _writer.Indented($"{size} =l loadl {array}");
_writer.Indented($"{size} =l mul {size}, {type.ElementType.Size(_definitionTable)}"); _writer.Indented($"{size} =l mul {size}, {SizeOf(type.ElementType)}");
_writer.Indented($"{size} =l add {size}, 8"); _writer.Indented($"{size} =l add {size}, 8");
return size; return size;
} }
@@ -250,9 +250,9 @@ public partial class QBEGenerator
} }
else else
{ {
if (complexType is CustomTypeNode customType) if (complexType is StructTypeNode structType)
{ {
EmitMemcpy(value, destinationPointer, customType.Size(_definitionTable).ToString()); EmitMemcpy(value, destinationPointer, SizeOf(structType).ToString());
} }
else else
{ {
@@ -260,7 +260,8 @@ public partial class QBEGenerator
{ {
ArrayTypeNode arrayType => EmitArraySizeInBytes(arrayType, value), ArrayTypeNode arrayType => EmitArraySizeInBytes(arrayType, value),
CStringTypeNode => EmitCStringSizeInBytes(value), CStringTypeNode => EmitCStringSizeInBytes(value),
NubStringTypeNode => EmitStringSizeInBytes(value), StringTypeNode => EmitStringSizeInBytes(value),
InterfaceTypeNode => 16.ToString(),
_ => throw new ArgumentOutOfRangeException(nameof(source.Type)) _ => throw new ArgumentOutOfRangeException(nameof(source.Type))
}; };
@@ -309,8 +310,9 @@ public partial class QBEGenerator
{ {
ArrayTypeNode arrayType => EmitArraySizeInBytes(arrayType, value), ArrayTypeNode arrayType => EmitArraySizeInBytes(arrayType, value),
CStringTypeNode => EmitCStringSizeInBytes(value), CStringTypeNode => EmitCStringSizeInBytes(value),
NubStringTypeNode => EmitStringSizeInBytes(value), StringTypeNode => EmitStringSizeInBytes(value),
CustomTypeNode customType => customType.Size(_definitionTable).ToString(), InterfaceTypeNode => 16.ToString(),
StructTypeNode structType => SizeOf(structType).ToString(),
_ => throw new ArgumentOutOfRangeException(nameof(source.Type)) _ => throw new ArgumentOutOfRangeException(nameof(source.Type))
}; };
@@ -339,9 +341,14 @@ public partial class QBEGenerator
}; };
} }
if (complexType is CustomTypeNode customType) if (complexType is StructTypeNode structType)
{ {
return CustomTypeName(customType.Name); return StructTypeName(structType.Name);
}
if (complexType is InterfaceTypeNode interfaceType)
{
return InterfaceTypeName(interfaceType.Name);
} }
return "l"; return "l";
@@ -387,7 +394,7 @@ public partial class QBEGenerator
private void EmitStructTypeDefinition(StructNode structDef) private void EmitStructTypeDefinition(StructNode structDef)
{ {
_writer.WriteLine($"type {CustomTypeName(structDef.Name)} = {{ "); _writer.WriteLine($"type {StructTypeName(structDef.Name)} = {{ ");
var types = new Dictionary<string, string>(); var types = new Dictionary<string, string>();
@@ -422,9 +429,14 @@ public partial class QBEGenerator
}; };
} }
if (complexType is CustomTypeNode customType) if (complexType is StructTypeNode structType)
{ {
return CustomTypeName(customType.Name); return StructTypeName(structType.Name);
}
if (complexType is InterfaceTypeNode interfaceType)
{
return InterfaceTypeName(interfaceType.Name);
} }
return "l"; return "l";
@@ -458,6 +470,93 @@ public partial class QBEGenerator
}; };
} }
private static int SizeOf(TypeNode type)
{
return type switch
{
SimpleTypeNode simple => simple.StorageSize switch
{
StorageSize.Void => 0,
StorageSize.I8 or StorageSize.U8 => 1,
StorageSize.I16 or StorageSize.U16 => 2,
StorageSize.I32 or StorageSize.U32 or StorageSize.F32 => 4,
StorageSize.I64 or StorageSize.U64 or StorageSize.F64 => 8,
_ => throw new ArgumentOutOfRangeException(nameof(type), $"Unknown storage size: {simple.StorageSize}")
},
CStringTypeNode => 8,
StringTypeNode => 8,
ArrayTypeNode => 8,
StructTypeNode structType => CalculateStructSize(structType),
InterfaceTypeNode => 16,
_ => throw new ArgumentOutOfRangeException(nameof(type), $"Unknown type: {type.GetType()}")
};
}
private static int CalculateStructSize(StructTypeNode structType)
{
var offset = 0;
if (structType.InterfaceImplementations.Any())
{
offset = 8;
}
foreach (var field in structType.Fields)
{
var fieldAlignment = AlignmentOf(field);
offset = AlignTo(offset, fieldAlignment);
offset += SizeOf(field);
}
var structAlignment = CalculateStructAlignment(structType);
return AlignTo(offset, structAlignment);
}
private static int AlignmentOf(TypeNode type)
{
return type switch
{
SimpleTypeNode simple => simple.StorageSize switch
{
StorageSize.Void => 1,
StorageSize.I8 or StorageSize.U8 => 1,
StorageSize.I16 or StorageSize.U16 => 2,
StorageSize.I32 or StorageSize.U32 or StorageSize.F32 => 4,
StorageSize.I64 or StorageSize.U64 or StorageSize.F64 => 8,
_ => throw new ArgumentOutOfRangeException(nameof(type), $"Unknown storage size: {simple.StorageSize}")
},
CStringTypeNode => 8,
StringTypeNode => 8,
ArrayTypeNode => 8,
StructTypeNode structType => CalculateStructAlignment(structType),
InterfaceTypeNode => 8,
_ => throw new ArgumentOutOfRangeException(nameof(type), $"Unknown type: {type.GetType()}")
};
}
private static int CalculateStructAlignment(StructTypeNode structType)
{
var maxAlignment = 1;
if (structType.InterfaceImplementations.Any())
{
maxAlignment = Math.Max(maxAlignment, 8);
}
foreach (var field in structType.Fields)
{
var fieldAlignment = AlignmentOf(field);
maxAlignment = Math.Max(maxAlignment, fieldAlignment);
}
return maxAlignment;
}
private static int AlignTo(int offset, int alignment)
{
return (offset + alignment - 1) & ~(alignment - 1);
}
private int OffsetOf(StructNode structDef, string member) private int OffsetOf(StructNode structDef, string member)
{ {
var offset = 0; var offset = 0;
@@ -474,10 +573,10 @@ public partial class QBEGenerator
return offset; return offset;
} }
var fieldAlignment = field.Type.Alignment(_definitionTable); var fieldAlignment = AlignmentOf(field.Type);
offset = TypeNode.AlignTo(offset, fieldAlignment); offset = AlignTo(offset, fieldAlignment);
offset += field.Type.Size(_definitionTable); offset += SizeOf(field.Type);
} }
throw new UnreachableException($"Member '{member}' not found in struct"); throw new UnreachableException($"Member '{member}' not found in struct");
@@ -515,7 +614,12 @@ public partial class QBEGenerator
return $"${funcDef.CallName}"; return $"${funcDef.CallName}";
} }
private string CustomTypeName(string name) private string StructTypeName(string name)
{
return $":{name}";
}
private string InterfaceTypeName(string name)
{ {
return $":{name}"; return $":{name}";
} }

View File

@@ -26,11 +26,11 @@ public class DefinitionTable
.Where(x => x.Name == name); .Where(x => x.Name == name);
} }
public IEnumerable<StructSyntax> LookupStruct(CustomTypeNode type) public IEnumerable<StructSyntax> LookupStruct(string name)
{ {
return _definitions return _definitions
.OfType<StructSyntax>() .OfType<StructSyntax>()
.Where(x => x.Name == type.Name); .Where(x => x.Name == name);
} }
public IEnumerable<StructFieldSyntax> LookupStructField(StructSyntax @struct, string field) public IEnumerable<StructFieldSyntax> LookupStructField(StructSyntax @struct, string field)
@@ -43,11 +43,11 @@ public class DefinitionTable
return @struct.Functions.Where(x => x.Name == func); return @struct.Functions.Where(x => x.Name == func);
} }
public IEnumerable<InterfaceSyntax> LookupInterface(CustomTypeNode type) public IEnumerable<InterfaceSyntax> LookupInterface(string name)
{ {
return _definitions return _definitions
.OfType<InterfaceSyntax>() .OfType<InterfaceSyntax>()
.Where(x => x.Name == type.Name); .Where(x => x.Name == name);
} }
public IEnumerable<InterfaceFuncSyntax> LookupInterfaceFunc(InterfaceSyntax @interface, string name) public IEnumerable<InterfaceFuncSyntax> LookupInterfaceFunc(InterfaceSyntax @interface, string name)

View File

@@ -14,7 +14,7 @@ public record StructFieldNode(int Index, string Name, TypeNode Type, Optional<Ex
public record StructFuncNode(string Name, FuncSignatureNode Signature, BlockNode Body) : Node; public record StructFuncNode(string Name, FuncSignatureNode Signature, BlockNode Body) : Node;
public record StructNode(string Name, IReadOnlyList<StructFieldNode> Fields, IReadOnlyList<StructFuncNode> Functions, IReadOnlyList<CustomTypeNode> InterfaceImplementations) : DefinitionNode; public record StructNode(string Name, IReadOnlyList<StructFieldNode> Fields, IReadOnlyList<StructFuncNode> Functions, IReadOnlyList<InterfaceTypeNode> InterfaceImplementations) : DefinitionNode;
public record InterfaceFuncNode(string Name, FuncSignatureNode Signature) : Node; public record InterfaceFuncNode(string Name, FuncSignatureNode Signature) : Node;

View File

@@ -46,12 +46,12 @@ public record AddressOfNode(TypeNode Type, ExpressionNode Expression) : Expressi
public record LiteralNode(TypeNode Type, string Value, LiteralKind Kind) : ExpressionNode(Type); public record LiteralNode(TypeNode Type, string Value, LiteralKind Kind) : ExpressionNode(Type);
public record StructFieldAccessNode(TypeNode Type, CustomTypeNode StructType, ExpressionNode Target, string Field) : ExpressionNode(Type); public record StructFieldAccessNode(TypeNode Type, StructTypeNode StructType, ExpressionNode Target, string Field) : ExpressionNode(Type);
public record StructFuncAccessNode(TypeNode Type, CustomTypeNode StructType, ExpressionNode Target, string Func) : ExpressionNode(Type); public record StructFuncAccessNode(TypeNode Type, StructTypeNode StructType, ExpressionNode Target, string Func) : ExpressionNode(Type);
public record InterfaceFuncAccessNode(TypeNode Type, CustomTypeNode InterfaceType, ExpressionNode Target, string FuncName) : ExpressionNode(Type); public record InterfaceFuncAccessNode(TypeNode Type, InterfaceTypeNode InterfaceType, ExpressionNode Target, string FuncName) : ExpressionNode(Type);
public record StructInitializerNode(CustomTypeNode StructType, Dictionary<string, ExpressionNode> Initializers) : ExpressionNode(StructType); public record StructInitializerNode(StructTypeNode StructType, Dictionary<string, ExpressionNode> Initializers) : ExpressionNode(StructType);
public record DereferenceNode(TypeNode Type, ExpressionNode Expression) : ExpressionNode(Type); public record DereferenceNode(TypeNode Type, ExpressionNode Expression) : ExpressionNode(Type);

View File

@@ -1,11 +1,10 @@
using System.Diagnostics.CodeAnalysis; using System.Diagnostics.CodeAnalysis;
using NubLang.Generation;
namespace NubLang.TypeChecking.Node; namespace NubLang.TypeChecking.Node;
public abstract class TypeNode : IEquatable<TypeNode> public abstract class TypeNode : IEquatable<TypeNode>
{ {
public bool IsSimpleType([NotNullWhen(true)] out SimpleTypeNode? simpleType, [NotNullWhen(false)] out NubComplexTypeNode? complexType) public bool IsSimpleType([NotNullWhen(true)] out SimpleTypeNode? simpleType, [NotNullWhen(false)] out ComplexTypeNode? complexType)
{ {
if (this is SimpleTypeNode st) if (this is SimpleTypeNode st)
{ {
@@ -14,7 +13,7 @@ public abstract class TypeNode : IEquatable<TypeNode>
return true; return true;
} }
if (this is NubComplexTypeNode ct) if (this is ComplexTypeNode ct)
{ {
complexType = ct; complexType = ct;
simpleType = null; simpleType = null;
@@ -24,14 +23,6 @@ public abstract class TypeNode : IEquatable<TypeNode>
throw new ArgumentException($"Type {this} is not a simple type nor a complex type"); throw new ArgumentException($"Type {this} is not a simple type nor a complex type");
} }
public abstract int Size(TypedDefinitionTable definitionTable);
public abstract int Alignment(TypedDefinitionTable definitionTable);
public static int AlignTo(int offset, int alignment)
{
return (offset + alignment - 1) & ~(alignment - 1);
}
public override bool Equals(object? obj) => obj is TypeNode other && Equals(other); public override bool Equals(object? obj) => obj is TypeNode other && Equals(other);
public abstract bool Equals(TypeNode? other); public abstract bool Equals(TypeNode? other);
@@ -60,23 +51,6 @@ public enum StorageSize
public abstract class SimpleTypeNode : TypeNode public abstract class SimpleTypeNode : TypeNode
{ {
public abstract StorageSize StorageSize { get; } public abstract StorageSize StorageSize { get; }
public override int Size(TypedDefinitionTable definitionTable)
{
return StorageSize switch
{
StorageSize.I64 or StorageSize.U64 or StorageSize.F64 => 8,
StorageSize.I32 or StorageSize.U32 or StorageSize.F32 => 4,
StorageSize.I16 or StorageSize.U16 => 2,
StorageSize.I8 or StorageSize.U8 => 1,
_ => throw new ArgumentOutOfRangeException(nameof(StorageSize))
};
}
public override int Alignment(TypedDefinitionTable definitionTable)
{
return Size(definitionTable);
}
} }
#region Simple types #region Simple types
@@ -136,7 +110,7 @@ public class BoolTypeNode : SimpleTypeNode
public override int GetHashCode() => HashCode.Combine(typeof(BoolTypeNode)); public override int GetHashCode() => HashCode.Combine(typeof(BoolTypeNode));
} }
public class FuncTypeNode(List<TypeNode> parameters, TypeNode returnType) : SimpleTypeNode public class FuncTypeNode(IReadOnlyList<TypeNode> parameters, TypeNode returnType) : SimpleTypeNode
{ {
public IReadOnlyList<TypeNode> Parameters { get; } = parameters; public IReadOnlyList<TypeNode> Parameters { get; } = parameters;
public TypeNode ReturnType { get; } = returnType; public TypeNode ReturnType { get; } = returnType;
@@ -182,110 +156,50 @@ public class VoidTypeNode : SimpleTypeNode
#endregion #endregion
public abstract class NubComplexTypeNode : TypeNode; public abstract class ComplexTypeNode : TypeNode;
#region Complex types #region Complex types
public class CStringTypeNode : NubComplexTypeNode public class CStringTypeNode : ComplexTypeNode
{ {
public override int Size(TypedDefinitionTable definitionTable) => 8;
public override int Alignment(TypedDefinitionTable definitionTable) => Size(definitionTable);
public override string ToString() => "cstring"; public override string ToString() => "cstring";
public override bool Equals(TypeNode? other) => other is CStringTypeNode; public override bool Equals(TypeNode? other) => other is CStringTypeNode;
public override int GetHashCode() => HashCode.Combine(typeof(CStringTypeNode)); public override int GetHashCode() => HashCode.Combine(typeof(CStringTypeNode));
} }
public class NubStringTypeNode : NubComplexTypeNode public class StringTypeNode : ComplexTypeNode
{ {
public override int Size(TypedDefinitionTable definitionTable) => 8;
public override int Alignment(TypedDefinitionTable definitionTable) => Size(definitionTable);
public override string ToString() => "string"; public override string ToString() => "string";
public override bool Equals(TypeNode? other) => other is NubStringTypeNode; public override bool Equals(TypeNode? other) => other is StringTypeNode;
public override int GetHashCode() => HashCode.Combine(typeof(NubStringTypeNode)); public override int GetHashCode() => HashCode.Combine(typeof(StringTypeNode));
} }
public class CustomTypeNode(string name) : NubComplexTypeNode public class StructTypeNode(string name, IReadOnlyList<TypeNode> fields, IReadOnlyList<FuncTypeNode> functions, IReadOnlyList<InterfaceTypeNode> interfaceImplementations) : ComplexTypeNode
{ {
public string Name { get; } = name; public string Name { get; } = name;
public IReadOnlyList<TypeNode> Fields { get; } = fields;
public CustomTypeKind Kind(TypedDefinitionTable definitionTable) public IReadOnlyList<FuncTypeNode> Functions { get; } = functions;
{ public IReadOnlyList<InterfaceTypeNode> InterfaceImplementations { get; } = interfaceImplementations;
if (definitionTable.GetStructs().Any(x => x.Name == Name))
{
return CustomTypeKind.Struct;
}
if (definitionTable.GetInterfaces().Any(x => x.Name == Name))
{
return CustomTypeKind.Interface;
}
throw new ArgumentException($"Definition table does not have any type information for {this}");
}
public override int Size(TypedDefinitionTable definitionTable)
{
switch (Kind(definitionTable))
{
case CustomTypeKind.Struct:
{
var structDef = definitionTable.LookupStruct(Name);
var size = 0;
var maxAlignment = 1;
foreach (var field in structDef.Fields)
{
var fieldAlignment = field.Type.Alignment(definitionTable);
maxAlignment = Math.Max(maxAlignment, fieldAlignment);
size = AlignTo(size, fieldAlignment);
size += field.Type.Size(definitionTable);
}
return AlignTo(size, maxAlignment);
}
case CustomTypeKind.Interface:
{
return 16;
}
default:
throw new ArgumentOutOfRangeException();
}
}
public override int Alignment(TypedDefinitionTable definitionTable)
{
switch (Kind(definitionTable))
{
case CustomTypeKind.Struct:
return definitionTable.LookupStruct(Name).Fields.Max(f => f.Type.Alignment(definitionTable));
case CustomTypeKind.Interface:
return 8;
default:
throw new ArgumentOutOfRangeException();
}
}
public override string ToString() => Name; public override string ToString() => Name;
public override bool Equals(TypeNode? other) => other is CustomTypeNode custom && Name == custom.Name; public override bool Equals(TypeNode? other) => other is StructTypeNode custom && Name == custom.Name;
public override int GetHashCode() => HashCode.Combine(typeof(CustomTypeNode), Name); public override int GetHashCode() => HashCode.Combine(typeof(StructTypeNode), Name);
} }
public enum CustomTypeKind public class InterfaceTypeNode(string name, IReadOnlyList<FuncTypeNode> funcs) : ComplexTypeNode
{ {
Struct, public string Name { get; } = name;
Interface public IReadOnlyList<FuncTypeNode> Funcs { get; } = funcs;
public override string ToString() => Name;
public override bool Equals(TypeNode? other) => other is InterfaceTypeNode custom && Name == custom.Name;
public override int GetHashCode() => HashCode.Combine(typeof(InterfaceTypeNode), Name);
} }
public class ArrayTypeNode(TypeNode elementType) : NubComplexTypeNode public class ArrayTypeNode(TypeNode elementType) : ComplexTypeNode
{ {
public TypeNode ElementType { get; } = elementType; public TypeNode ElementType { get; } = elementType;
public override int Size(TypedDefinitionTable definitionTable) => 8;
public override int Alignment(TypedDefinitionTable definitionTable) => Size(definitionTable);
public override string ToString() => "[]" + ElementType; public override string ToString() => "[]" + ElementType;
public override bool Equals(TypeNode? other) => other is ArrayTypeNode array && ElementType.Equals(array.ElementType); public override bool Equals(TypeNode? other) => other is ArrayTypeNode array && ElementType.Equals(array.ElementType);

View File

@@ -38,7 +38,7 @@ public sealed class TypeChecker
{ {
definitions.Add(CheckDefinition(definition)); definitions.Add(CheckDefinition(definition));
} }
catch (CheckException e) catch (TypeCheckerException e)
{ {
_diagnostics.Add(e.Diagnostic); _diagnostics.Add(e.Diagnostic);
} }
@@ -100,32 +100,32 @@ public sealed class TypeChecker
funcs.Add(new StructFuncNode(func.Name, CheckFuncSignature(func.Signature), CheckFuncBody(func.Body, CheckType(func.Signature.ReturnType), parameters))); funcs.Add(new StructFuncNode(func.Name, CheckFuncSignature(func.Signature), CheckFuncBody(func.Body, CheckType(func.Signature.ReturnType), parameters)));
} }
var interfaceImplementations = new List<CustomTypeNode>(); var interfaceImplementations = new List<InterfaceTypeNode>();
foreach (var interfaceImplementation in node.InterfaceImplementations) foreach (var interfaceImplementation in node.InterfaceImplementations)
{ {
var interfaceType = CheckType(interfaceImplementation); var type = CheckType(interfaceImplementation);
if (interfaceType is not CustomTypeNode customType) if (type is not InterfaceTypeNode interfaceType)
{ {
_diagnostics.Add(Diagnostic.Error("Interface implementation is not a custom type").Build()); _diagnostics.Add(Diagnostic.Error("Interface implementation is not a custom type").Build());
continue; continue;
} }
var interfaceDefs = _definitionTable.LookupInterface(customType).ToArray(); var interfaceDefs = _definitionTable.LookupInterface(interfaceType.Name).ToArray();
if (interfaceDefs.Length == 0) if (interfaceDefs.Length == 0)
{ {
_diagnostics.Add(Diagnostic.Error($"Interface {customType.Name} is not defined").Build()); _diagnostics.Add(Diagnostic.Error($"Interface {interfaceType.Name} is not defined").Build());
continue; continue;
} }
if (interfaceDefs.Length > 1) if (interfaceDefs.Length > 1)
{ {
_diagnostics.Add(Diagnostic.Error($"Interface {customType.Name} has multiple definitions").Build()); _diagnostics.Add(Diagnostic.Error($"Interface {interfaceType.Name} has multiple definitions").Build());
continue; continue;
} }
interfaceImplementations.Add(customType); interfaceImplementations.Add(interfaceType);
} }
return new StructNode(node.Name, structFields, funcs, interfaceImplementations); return new StructNode(node.Name, structFields, funcs, interfaceImplementations);
@@ -262,12 +262,12 @@ public sealed class TypeChecker
{ {
if (expectedType == null) if (expectedType == null)
{ {
throw new CheckException(Diagnostic.Error("Cannot infer argument types for arrow function").Build()); throw new TypeCheckerException(Diagnostic.Error("Cannot infer argument types for arrow function").Build());
} }
if (expectedType is not FuncTypeNode funcType) if (expectedType is not FuncTypeNode funcType)
{ {
throw new CheckException(Diagnostic.Error($"Expected {expectedType}, but got arrow function").Build()); throw new TypeCheckerException(Diagnostic.Error($"Expected {expectedType}, but got arrow function").Build());
} }
var parameters = new List<FuncParameterNode>(); var parameters = new List<FuncParameterNode>();
@@ -276,7 +276,7 @@ public sealed class TypeChecker
{ {
if (i >= funcType.Parameters.Count) if (i >= funcType.Parameters.Count)
{ {
throw new CheckException(Diagnostic.Error($"Arrow function expected a maximum of {funcType.Parameters.Count} arguments").Build()); throw new TypeCheckerException(Diagnostic.Error($"Arrow function expected a maximum of {funcType.Parameters.Count} arguments").Build());
} }
var expectedParameterType = funcType.Parameters[i]; var expectedParameterType = funcType.Parameters[i];
@@ -353,7 +353,7 @@ public sealed class TypeChecker
{ {
if (localFuncs.Length > 1) if (localFuncs.Length > 1)
{ {
throw new CheckException(Diagnostic.Error($"Extern func {expression.Name} has multiple definitions").Build()); throw new TypeCheckerException(Diagnostic.Error($"Extern func {expression.Name} has multiple definitions").Build());
} }
var localFunc = localFuncs[0]; var localFunc = localFuncs[0];
@@ -369,7 +369,7 @@ public sealed class TypeChecker
{ {
if (externFuncs.Length > 1) if (externFuncs.Length > 1)
{ {
throw new CheckException(Diagnostic.Error($"Extern func {expression.Name} has multiple definitions").Build()); throw new TypeCheckerException(Diagnostic.Error($"Extern func {expression.Name} has multiple definitions").Build());
} }
var externFunc = externFuncs[0]; var externFunc = externFuncs[0];
@@ -380,7 +380,7 @@ public sealed class TypeChecker
return new ExternFuncIdentNode(type, expression.Name); return new ExternFuncIdentNode(type, expression.Name);
} }
throw new CheckException(Diagnostic.Error($"No identifier with the name {expression.Name} exists").Build()); throw new TypeCheckerException(Diagnostic.Error($"No identifier with the name {expression.Name} exists").Build());
} }
private LiteralNode CheckLiteral(LiteralSyntax expression, TypeNode? expectedType = null) private LiteralNode CheckLiteral(LiteralSyntax expression, TypeNode? expectedType = null)
@@ -389,7 +389,7 @@ public sealed class TypeChecker
{ {
LiteralKind.Integer => new IntTypeNode(true, 64), LiteralKind.Integer => new IntTypeNode(true, 64),
LiteralKind.Float => new FloatTypeNode(64), LiteralKind.Float => new FloatTypeNode(64),
LiteralKind.String => new NubStringTypeNode(), LiteralKind.String => new StringTypeNode(),
LiteralKind.Bool => new BoolTypeNode(), LiteralKind.Bool => new BoolTypeNode(),
_ => throw new ArgumentOutOfRangeException() _ => throw new ArgumentOutOfRangeException()
}; };
@@ -401,14 +401,14 @@ public sealed class TypeChecker
{ {
var boundExpression = CheckExpression(expression.Target); var boundExpression = CheckExpression(expression.Target);
if (boundExpression.Type is CustomTypeNode customType) if (boundExpression.Type is InterfaceTypeNode customType)
{ {
var interfaces = _definitionTable.LookupInterface(customType).ToArray(); var interfaces = _definitionTable.LookupInterface(customType.Name).ToArray();
if (interfaces.Length > 0) if (interfaces.Length > 0)
{ {
if (interfaces.Length > 1) if (interfaces.Length > 1)
{ {
throw new CheckException(Diagnostic.Error($"Interface {customType} has multiple definitions").Build()); throw new TypeCheckerException(Diagnostic.Error($"Interface {customType} has multiple definitions").Build());
} }
var @interface = interfaces[0]; var @interface = interfaces[0];
@@ -418,7 +418,7 @@ public sealed class TypeChecker
{ {
if (interfaceFuncs.Length > 1) if (interfaceFuncs.Length > 1)
{ {
throw new CheckException(Diagnostic.Error($"Interface {customType} has multiple functions with the name {expression.Member}").Build()); throw new TypeCheckerException(Diagnostic.Error($"Interface {customType} has multiple functions with the name {expression.Member}").Build());
} }
var interfaceFunc = interfaceFuncs[0]; var interfaceFunc = interfaceFuncs[0];
@@ -429,13 +429,16 @@ public sealed class TypeChecker
return new InterfaceFuncAccessNode(type, customType, boundExpression, expression.Member); return new InterfaceFuncAccessNode(type, customType, boundExpression, expression.Member);
} }
} }
}
var structs = _definitionTable.LookupStruct(customType).ToArray(); if (boundExpression.Type is StructTypeNode structType)
{
var structs = _definitionTable.LookupStruct(structType.Name).ToArray();
if (structs.Length > 0) if (structs.Length > 0)
{ {
if (structs.Length > 1) if (structs.Length > 1)
{ {
throw new CheckException(Diagnostic.Error($"Struct {customType} has multiple definitions").Build()); throw new TypeCheckerException(Diagnostic.Error($"Struct {structType} has multiple definitions").Build());
} }
var @struct = structs[0]; var @struct = structs[0];
@@ -445,12 +448,12 @@ public sealed class TypeChecker
{ {
if (fields.Length > 1) if (fields.Length > 1)
{ {
throw new CheckException(Diagnostic.Error($"Struct {customType} has multiple fields with the name {expression.Member}").Build()); throw new TypeCheckerException(Diagnostic.Error($"Struct {structType} has multiple fields with the name {expression.Member}").Build());
} }
var field = fields[0]; var field = fields[0];
return new StructFieldAccessNode(CheckType(field.Type), customType, boundExpression, expression.Member); return new StructFieldAccessNode(CheckType(field.Type), structType, boundExpression, expression.Member);
} }
var funcs = _definitionTable.LookupStructFunc(@struct, expression.Member).ToArray(); var funcs = _definitionTable.LookupStructFunc(@struct, expression.Member).ToArray();
@@ -458,40 +461,40 @@ public sealed class TypeChecker
{ {
if (funcs.Length > 1) if (funcs.Length > 1)
{ {
throw new CheckException(Diagnostic.Error($"Struct {customType} has multiple functions with the name {expression.Member}").Build()); throw new TypeCheckerException(Diagnostic.Error($"Struct {structType} has multiple functions with the name {expression.Member}").Build());
} }
var func = funcs[0]; var func = funcs[0];
var parameters = func.Signature.Parameters.Select(x => CheckType(x.Type)).ToList(); var parameters = func.Signature.Parameters.Select(x => CheckType(x.Type)).ToList();
var returnType = CheckType(func.Signature.ReturnType); var returnType = CheckType(func.Signature.ReturnType);
return new StructFuncAccessNode(new FuncTypeNode(parameters, returnType), customType, boundExpression, expression.Member); return new StructFuncAccessNode(new FuncTypeNode(parameters, returnType), structType, boundExpression, expression.Member);
} }
} }
} }
throw new CheckException(Diagnostic.Error($"{boundExpression.Type} does not have a member with the name {expression.Member}").Build()); throw new TypeCheckerException(Diagnostic.Error($"{boundExpression.Type} does not have a member with the name {expression.Member}").Build());
} }
private StructInitializerNode CheckStructInitializer(StructInitializerSyntax expression) private StructInitializerNode CheckStructInitializer(StructInitializerSyntax expression)
{ {
var boundType = CheckType(expression.StructType); var boundType = CheckType(expression.StructType);
if (boundType is not CustomTypeNode structType) if (boundType is not StructTypeNode structType)
{ {
throw new CheckException(Diagnostic.Error($"Cannot initialize non-struct type {expression.StructType}").Build()); throw new TypeCheckerException(Diagnostic.Error($"Cannot initialize non-struct type {expression.StructType}").Build());
} }
var structs = _definitionTable.LookupStruct(structType).ToArray(); var structs = _definitionTable.LookupStruct(structType.Name).ToArray();
if (structs.Length == 0) if (structs.Length == 0)
{ {
throw new CheckException(Diagnostic.Error($"Struct {structType} is not defined").Build()); throw new TypeCheckerException(Diagnostic.Error($"Struct {structType} is not defined").Build());
} }
if (structs.Length > 1) if (structs.Length > 1)
{ {
throw new CheckException(Diagnostic.Error($"Struct {structType} has multiple definitions").Build()); throw new TypeCheckerException(Diagnostic.Error($"Struct {structType} has multiple definitions").Build());
} }
var @struct = structs[0]; var @struct = structs[0];
@@ -504,12 +507,12 @@ public sealed class TypeChecker
if (fields.Length == 0) if (fields.Length == 0)
{ {
throw new CheckException(Diagnostic.Error($"Struct {structType} does not have a field with the name {field}").Build()); throw new TypeCheckerException(Diagnostic.Error($"Struct {structType} does not have a field with the name {field}").Build());
} }
if (fields.Length > 1) if (fields.Length > 1)
{ {
throw new CheckException(Diagnostic.Error($"Struct {structType} has multiple fields with the name {field}").Build()); throw new TypeCheckerException(Diagnostic.Error($"Struct {structType} has multiple fields with the name {field}").Build());
} }
initializers[field] = CheckExpression(initializer, CheckType(fields[0].Type)); initializers[field] = CheckExpression(initializer, CheckType(fields[0].Type));
@@ -632,16 +635,69 @@ public sealed class TypeChecker
ArrayTypeSyntax type => new ArrayTypeNode(CheckType(type.BaseType)), ArrayTypeSyntax type => new ArrayTypeNode(CheckType(type.BaseType)),
BoolTypeSyntax => new BoolTypeNode(), BoolTypeSyntax => new BoolTypeNode(),
CStringTypeSyntax => new CStringTypeNode(), CStringTypeSyntax => new CStringTypeNode(),
CustomTypeSyntax type => new CustomTypeNode(type.Name), CustomTypeSyntax type => CheckCustomType(type),
FloatTypeSyntax @float => new FloatTypeNode(@float.Width), FloatTypeSyntax @float => new FloatTypeNode(@float.Width),
FuncTypeSyntax type => new FuncTypeNode(type.Parameters.Select(CheckType).ToList(), CheckType(type.ReturnType)), FuncTypeSyntax type => new FuncTypeNode(type.Parameters.Select(CheckType).ToList(), CheckType(type.ReturnType)),
IntTypeSyntax @int => new IntTypeNode(@int.Signed, @int.Width), IntTypeSyntax @int => new IntTypeNode(@int.Signed, @int.Width),
PointerTypeSyntax type => new PointerTypeNode(CheckType(type.BaseType)), PointerTypeSyntax type => new PointerTypeNode(CheckType(type.BaseType)),
StringTypeSyntax => new NubStringTypeNode(), StringTypeSyntax => new StringTypeNode(),
VoidTypeSyntax => new VoidTypeNode(), VoidTypeSyntax => new VoidTypeNode(),
_ => throw new ArgumentOutOfRangeException(nameof(node)) _ => throw new ArgumentOutOfRangeException(nameof(node))
}; };
} }
private TypeNode CheckCustomType(CustomTypeSyntax type)
{
var structs = _definitionTable.LookupStruct(type.Name).ToArray();
if (structs.Length > 0)
{
if (structs.Length > 1)
{
throw new TypeCheckerException(Diagnostic.Error($"Struct {type.Name} has multiple definitions").Build());
}
var @struct = structs[0];
var fields = @struct.Fields.Select(x => CheckType(x.Type)).ToList();
var funcs = @struct.Functions
.Select(x => new FuncTypeNode(x.Signature.Parameters.Select(p => CheckType(p.Type)).ToList(), CheckType(x.Signature.ReturnType)))
.ToList();
var interfaceImplementations = new List<InterfaceTypeNode>();
foreach (var structInterfaceImplementation in @struct.InterfaceImplementations)
{
var checkedInterfaceType = CheckType(structInterfaceImplementation);
if (checkedInterfaceType is not InterfaceTypeNode interfaceType)
{
throw new TypeCheckerException(Diagnostic.Error($"{type.Name} cannot implement non-interface type {checkedInterfaceType}").Build());
}
interfaceImplementations.Add(interfaceType);
}
return new StructTypeNode(type.Name, fields, funcs, interfaceImplementations);
}
var interfaces = _definitionTable.LookupInterface(type.Name).ToArray();
if (interfaces.Length > 0)
{
if (interfaces.Length > 1)
{
throw new TypeCheckerException(Diagnostic.Error($"Interface {type.Name} has multiple definitions").Build());
}
var @interface = interfaces[0];
var functions = @interface.Functions
.Select(x => new FuncTypeNode(x.Signature.Parameters.Select(y => CheckType(y.Type)).ToList(), CheckType(x.Signature.ReturnType)))
.ToList();
return new InterfaceTypeNode(type.Name, functions);
}
throw new TypeCheckerException(Diagnostic.Error($"Type {type.Name} is not defined").Build());
}
} }
public record Variable(string Name, TypeNode Type); public record Variable(string Name, TypeNode Type);
@@ -672,11 +728,11 @@ public class Scope(Scope? parent = null)
} }
} }
public class CheckException : Exception public class TypeCheckerException : Exception
{ {
public Diagnostic Diagnostic { get; } public Diagnostic Diagnostic { get; }
public CheckException(Diagnostic diagnostic) : base(diagnostic.Message) public TypeCheckerException(Diagnostic diagnostic) : base(diagnostic.Message)
{ {
Diagnostic = diagnostic; Diagnostic = diagnostic;
} }