From 68f0da80658f593717fca1af9c8213ac31bf468c Mon Sep 17 00:00:00 2001 From: nub31 Date: Sat, 20 Sep 2025 19:13:17 +0200 Subject: [PATCH] ... --- compiler/NubLang.CLI/Program.cs | 2 +- .../NubLang/Generation/QBE/QBEGenerator.cs | 334 +++++++++-------- compiler/NubLang/Modules/Module.cs | 4 +- .../TypeChecking/Node/DefinitionNode.cs | 12 +- .../TypeChecking/Node/ExpressionNode.cs | 52 +-- compiler/NubLang/TypeChecking/Node/NubType.cs | 194 ++++++++++ .../TypeChecking/Node/StatementNode.cs | 4 +- .../NubLang/TypeChecking/Node/TypeNode.cs | 193 ---------- compiler/NubLang/TypeChecking/TypeChecker.cs | 343 ++++++++---------- example/src/main.nub | 72 ++-- 10 files changed, 592 insertions(+), 618 deletions(-) create mode 100644 compiler/NubLang/TypeChecking/Node/NubType.cs delete mode 100644 compiler/NubLang/TypeChecking/Node/TypeNode.cs diff --git a/compiler/NubLang.CLI/Program.cs b/compiler/NubLang.CLI/Program.cs index 1bed091..a1bc56d 100644 --- a/compiler/NubLang.CLI/Program.cs +++ b/compiler/NubLang.CLI/Program.cs @@ -65,7 +65,7 @@ var moduleRepository = new ModuleRepository(syntaxTrees); var definitions = new List(); -var referencedStructTypes = new HashSet(); +var referencedStructTypes = new HashSet(); foreach (var syntaxTree in syntaxTrees) { diff --git a/compiler/NubLang/Generation/QBE/QBEGenerator.cs b/compiler/NubLang/Generation/QBE/QBEGenerator.cs index bed24a1..ad8fe5c 100644 --- a/compiler/NubLang/Generation/QBE/QBEGenerator.cs +++ b/compiler/NubLang/Generation/QBE/QBEGenerator.cs @@ -10,7 +10,7 @@ public class QBEGenerator private readonly QBEWriter _writer; private readonly List _definitions; - private readonly HashSet _structTypes; + private readonly HashSet _structTypes; private readonly List _cStringLiterals = []; private readonly List _stringLiterals = []; @@ -24,7 +24,7 @@ public class QBEGenerator private Scope Scope => _scopes.Peek(); - public QBEGenerator(List definitions, HashSet structTypes) + public QBEGenerator(List definitions, HashSet structTypes) { _definitions = definitions; _structTypes = structTypes; @@ -142,72 +142,72 @@ public class QBEGenerator return _writer.ToString(); } - private static string QBEAssign(TypeNode type) + private static string QBEAssign(NubType type) { return type switch { - IntTypeNode { Width: <= 32 } => "=w", - IntTypeNode { Width: 64 } => "=l", - FloatTypeNode { Width: 32 } => "=s", - FloatTypeNode { Width: 64 } => "=d", - BoolTypeNode => "=w", - PointerTypeNode => "=l", - FuncTypeNode => "=l", - CStringTypeNode => "=l", - StringTypeNode => "=l", - ArrayTypeNode => "=l", - StructTypeNode => throw new InvalidOperationException("Structs are not loaded/stored directly"), - VoidTypeNode => throw new InvalidOperationException("Void has no assignment"), + NubIntType { Width: <= 32 } => "=w", + NubIntType { Width: 64 } => "=l", + NubFloatType { Width: 32 } => "=s", + NubFloatType { Width: 64 } => "=d", + NubBoolType => "=w", + NubPointerType => "=l", + NubFuncType => "=l", + NubCStringType => "=l", + NubStringType => "=l", + NubArrayType => "=l", + NubStructType => throw new InvalidOperationException("Structs are not loaded/stored directly"), + NubVoidType => throw new InvalidOperationException("Void has no assignment"), _ => throw new ArgumentOutOfRangeException(nameof(type), $"Unknown type {type.GetType()}") }; } - private void EmitStore(TypeNode type, string value, string destination) + private void EmitStore(NubType type, string value, string destination) { var store = type switch { - BoolTypeNode => "storeb", - IntTypeNode { Width: 8 } => "storeb", - IntTypeNode { Width: 16 } => "storeh", - IntTypeNode { Width: 32 } => "storew", - IntTypeNode { Width: 64 } => "storel", - FloatTypeNode { Width: 32 } => "stores", - FloatTypeNode { Width: 64 } => "stored", - PointerTypeNode => "storel", - FuncTypeNode => "storel", - CStringTypeNode => "storel", - StringTypeNode => "storel", - ArrayTypeNode => "storel", - StructTypeNode => throw new InvalidOperationException("Struct stores must use blit/memcpy"), - VoidTypeNode => throw new InvalidOperationException("Cannot store void"), + NubBoolType => "storeb", + NubIntType { Width: 8 } => "storeb", + NubIntType { Width: 16 } => "storeh", + NubIntType { Width: 32 } => "storew", + NubIntType { Width: 64 } => "storel", + NubFloatType { Width: 32 } => "stores", + NubFloatType { Width: 64 } => "stored", + NubPointerType => "storel", + NubFuncType => "storel", + NubCStringType => "storel", + NubStringType => "storel", + NubArrayType => "storel", + NubStructType => throw new InvalidOperationException("Struct stores must use blit/memcpy"), + NubVoidType => throw new InvalidOperationException("Cannot store void"), _ => throw new ArgumentOutOfRangeException(nameof(type), $"Unknown type {type.GetType()}") }; _writer.Indented($"{store} {value}, {destination}"); } - private string EmitLoad(TypeNode type, string from) + private string EmitLoad(NubType type, string from) { string load = type switch { - BoolTypeNode => "loadub", - IntTypeNode { Signed: true, Width: 8 } => "loadsb", - IntTypeNode { Signed: true, Width: 16 } => "loadsh", - IntTypeNode { Signed: true, Width: 32 } => "loadw", - IntTypeNode { Signed: true, Width: 64 } => "loadl", - IntTypeNode { Signed: false, Width: 8 } => "loadsb", - IntTypeNode { Signed: false, Width: 16 } => "loadsh", - IntTypeNode { Signed: false, Width: 32 } => "loadw", - IntTypeNode { Signed: false, Width: 64 } => "loadl", - FloatTypeNode { Width: 32 } => "loads", - FloatTypeNode { Width: 64 } => "loadd", - PointerTypeNode => "loadl", - FuncTypeNode => "loadl", - CStringTypeNode => "loadl", - StringTypeNode => "loadl", - ArrayTypeNode => "loadl", - StructTypeNode => throw new InvalidOperationException("Struct loads must use blit/memcpy"), - VoidTypeNode => throw new InvalidOperationException("Cannot load void"), + NubBoolType => "loadub", + NubIntType { Signed: true, Width: 8 } => "loadsb", + NubIntType { Signed: true, Width: 16 } => "loadsh", + NubIntType { Signed: true, Width: 32 } => "loadw", + NubIntType { Signed: true, Width: 64 } => "loadl", + NubIntType { Signed: false, Width: 8 } => "loadsb", + NubIntType { Signed: false, Width: 16 } => "loadsh", + NubIntType { Signed: false, Width: 32 } => "loadw", + NubIntType { Signed: false, Width: 64 } => "loadl", + NubFloatType { Width: 32 } => "loads", + NubFloatType { Width: 64 } => "loadd", + NubPointerType => "loadl", + NubFuncType => "loadl", + NubCStringType => "loadl", + NubStringType => "loadl", + NubArrayType => "loadl", + NubStructType => throw new InvalidOperationException("Struct loads must use blit/memcpy"), + NubVoidType => throw new InvalidOperationException("Cannot load void"), _ => throw new ArgumentOutOfRangeException(nameof(type), $"Unknown type {type.GetType()}") }; @@ -217,48 +217,48 @@ public class QBEGenerator } - private static string StructDefQBEType(TypeNode type) + private static string StructDefQBEType(NubType type) { return type switch { - BoolTypeNode => "b", - IntTypeNode { Width: 8 } => "b", - IntTypeNode { Width: 16 } => "h", - IntTypeNode { Width: 32 } => "w", - IntTypeNode { Width: 64 } => "l", - FloatTypeNode { Width: 32 } => "s", - FloatTypeNode { Width: 64 } => "d", - PointerTypeNode => "l", - FuncTypeNode => "l", - CStringTypeNode => "l", - StringTypeNode => "l", - ArrayTypeNode => "l", - StructTypeNode st => StructTypeName(st.Module, st.Name), - VoidTypeNode => throw new InvalidOperationException("Void has no QBE type"), + NubBoolType => "b", + NubIntType { Width: 8 } => "b", + NubIntType { Width: 16 } => "h", + NubIntType { Width: 32 } => "w", + NubIntType { Width: 64 } => "l", + NubFloatType { Width: 32 } => "s", + NubFloatType { Width: 64 } => "d", + NubPointerType => "l", + NubFuncType => "l", + NubCStringType => "l", + NubStringType => "l", + NubArrayType => "l", + NubStructType st => StructTypeName(st.Module, st.Name), + NubVoidType => throw new InvalidOperationException("Void has no QBE type"), _ => throw new ArgumentOutOfRangeException(nameof(type), $"Unknown type: {type.GetType()}") }; } - private string FuncQBETypeName(TypeNode type) + private string FuncQBETypeName(NubType type) { return type switch { - BoolTypeNode => "ub", - IntTypeNode { Signed: true, Width: 8 } => "sb", - IntTypeNode { Signed: true, Width: 16 } => "sh", - IntTypeNode { Signed: false, Width: 8 } => "ub", - IntTypeNode { Signed: false, Width: 16 } => "uh", - IntTypeNode { Width: 32 } => "w", - IntTypeNode { Width: 64 } => "l", - FloatTypeNode { Width: 32 } => "s", - FloatTypeNode { Width: 64 } => "d", - PointerTypeNode => "l", - FuncTypeNode => "l", - CStringTypeNode => "l", - StringTypeNode => "l", - ArrayTypeNode => "l", - StructTypeNode st => StructTypeName(st.Module, st.Name), - VoidTypeNode => throw new InvalidOperationException("Void has no QBE type"), + NubBoolType => "ub", + NubIntType { Signed: true, Width: 8 } => "sb", + NubIntType { Signed: true, Width: 16 } => "sh", + NubIntType { Signed: false, Width: 8 } => "ub", + NubIntType { Signed: false, Width: 16 } => "uh", + NubIntType { Width: 32 } => "w", + NubIntType { Width: 64 } => "l", + NubFloatType { Width: 32 } => "s", + NubFloatType { Width: 64 } => "d", + NubPointerType => "l", + NubFuncType => "l", + NubCStringType => "l", + NubStringType => "l", + NubArrayType => "l", + NubStructType st => StructTypeName(st.Module, st.Name), + NubVoidType => throw new InvalidOperationException("Void has no QBE type"), _ => throw new ArgumentOutOfRangeException(nameof(type), $"Unknown type: {type.GetType()}") }; } @@ -309,7 +309,7 @@ public class QBEGenerator var value = EmitExpression(source); _writer.Indented($"blit {value}, {destinationAddress}, {SizeOf(source.Type)}"); - if (source.Type is StructTypeNode structType) + if (source.Type is NubStructType structType) { var copyFunc = structType.Functions.FirstOrDefault(x => x.Hook == "oncopy"); if (copyFunc != null) @@ -323,7 +323,7 @@ public class QBEGenerator switch (source.Type) { - case CStringTypeNode: + case NubCStringType: { var value = EmitExpression(source); var size = EmitCStringSizeInBytes(value); @@ -333,7 +333,7 @@ public class QBEGenerator EmitStore(source.Type, buffer, destinationAddress); return; } - case StringTypeNode: + case NubStringType: { var value = EmitExpression(source); var size = EmitStringSizeInBytes(value); @@ -343,14 +343,14 @@ public class QBEGenerator EmitStore(source.Type, buffer, destinationAddress); return; } - case VoidTypeNode: + case NubVoidType: throw new InvalidOperationException("Cannot copy void"); default: throw new ArgumentOutOfRangeException(nameof(source.Type), $"Unknown type {source.Type}"); } } - private void EmitStructType(StructTypeNode structType) + private void EmitStructType(NubStructType structType) { // todo(nub31): qbe expects structs to be declared in order. We must Check the dependencies of the struct to see if a type need to be declared before this one // qbe allows multiple declarations of the same struct, but we should keep track and only emit an new one if necessary @@ -374,7 +374,7 @@ public class QBEGenerator { if (field.Value.TryGetValue(out var value)) { - var offset = OffsetOf(structDef.StructType, field.Name); + var offset = OffsetOf(TypeOfStruct(structDef), field.Name); var destination = TmpName(); _writer.Indented($"{destination} =l add %struct, {offset}"); EmitCopyInto(value, destination); @@ -393,7 +393,7 @@ public class QBEGenerator _writer.Write("export function "); - if (function.Signature.ReturnType is not VoidTypeNode) + if (function.Signature.ReturnType is not NubVoidType) { _writer.Write(FuncQBETypeName(function.Signature.ReturnType) + ' '); } @@ -418,7 +418,7 @@ public class QBEGenerator EmitBlock(function.Body, scope); // Implicit return for void functions if no explicit return has been set - if (function.Signature.ReturnType is VoidTypeNode && function.Body.Statements.LastOrDefault() is not ReturnNode) + if (function.Signature.ReturnType is NubVoidType && function.Body.Statements.LastOrDefault() is not ReturnNode) { _writer.Indented("ret"); } @@ -437,7 +437,7 @@ public class QBEGenerator _writer.Write(funcDef.ExternSymbol != null ? "export function " : "function "); - if (funcDef.Signature.ReturnType is not VoidTypeNode) + if (funcDef.Signature.ReturnType is not NubVoidType) { _writer.Write(FuncQBETypeName(funcDef.Signature.ReturnType) + ' '); } @@ -533,7 +533,7 @@ public class QBEGenerator while (Scope.Variables.TryPop(out var variable)) { - if (variable.Type is StructTypeNode structType) + if (variable.Type is NubStructType structType) { var destroyFunc = structType.Functions.FirstOrDefault(x => x.Hook == "ondestroy"); if (destroyFunc != null) @@ -643,7 +643,7 @@ public class QBEGenerator var elementAddress = TmpName(); _writer.Indented($"{elementAddress} =l add {arrayStart}, {arrayOffset}"); - if (forArray.ArrayType.ElementType is StructTypeNode) + if (forArray.ArrayType.ElementType is NubStructType) { _writer.Indented($"%{forArray.ElementIdent} =l copy {elementAddress}"); } @@ -689,7 +689,7 @@ public class QBEGenerator ConvertIntNode expr => EmitConvertInt(expr), FuncCallNode expr => EmitFuncCall(expr), FuncIdentifierNode expr => FuncName(expr.Module, expr.Name, expr.ExternSymbol), - FuncParameterIdentifierNode expr => $"%{expr.Name}", + RValueIdentifierNode expr => $"%{expr.Name}", StructFuncCallNode expr => EmitStructFuncCall(expr), StructInitializerNode expr => EmitStructInitializer(expr), UnaryExpressionNode expr => EmitUnaryExpression(expr), @@ -789,7 +789,7 @@ public class QBEGenerator ArrayIndexAccessNode arrayIndexAccess => EmitAddressOfArrayIndexAccess(arrayIndexAccess), DereferenceNode dereference => EmitExpression(dereference.Expression), StructFieldAccessNode structFieldAccess => EmitAddressOfStructFieldAccess(structFieldAccess), - VariableIdentifierNode variableIdent => $"%{variableIdent.Name}", + LValueIdentifierNode variableIdent => $"%{variableIdent.Name}", _ => throw new ArgumentOutOfRangeException(nameof(lval)) }; } @@ -799,7 +799,7 @@ public class QBEGenerator var array = EmitExpression(arrayIndexAccess.Target); var index = EmitExpression(arrayIndexAccess.Index); - var elementType = ((ArrayTypeNode)arrayIndexAccess.Target.Type).ElementType; + var elementType = ((NubArrayType)arrayIndexAccess.Target.Type).ElementType; var offset = TmpName(); _writer.Indented($"{offset} =l mul {index}, {SizeOf(elementType)}"); @@ -810,8 +810,9 @@ public class QBEGenerator private string EmitAddressOfStructFieldAccess(StructFieldAccessNode structFieldAccess) { + var structType = (NubStructType)structFieldAccess.Target.Type; var target = EmitExpression(structFieldAccess.Target); - var offset = OffsetOf(structFieldAccess.StructType, structFieldAccess.Field); + var offset = OffsetOf(structType, structFieldAccess.Field); var address = TmpName(); _writer.Indented($"{address} =l add {target}, {offset}"); @@ -831,15 +832,15 @@ public class QBEGenerator return outputName; } - private static string EmitBinaryInstructionForOperator(BinaryOperator op, TypeNode type) + private static string EmitBinaryInstructionForOperator(BinaryOperator op, NubType type) { // todo(nub31): Add support for string concatenation. Currently this expects ints or floats and will treat strings as ints return op switch { BinaryOperator.RightShift => type switch { - IntTypeNode { Signed: true } => "sar", - IntTypeNode { Signed: false } => "shr", + NubIntType { Signed: true } => "sar", + NubIntType { Signed: false } => "shr", _ => throw new NotSupportedException($"Right shift not supported for type '{type}'") }, BinaryOperator.BitwiseAnd => "and", @@ -848,15 +849,15 @@ public class QBEGenerator BinaryOperator.LeftShift => "shl", BinaryOperator.Divide => type switch { - IntTypeNode { Signed: true } => "div", - IntTypeNode { Signed: false } => "udiv", - FloatTypeNode => "div", + NubIntType { Signed: true } => "div", + NubIntType { Signed: false } => "udiv", + NubFloatType => "div", _ => throw new NotSupportedException($"Division not supported for type '{type}'") }, BinaryOperator.Modulo => type switch { - IntTypeNode { Signed: true } => "rem", - IntTypeNode { Signed: false } => "urem", + NubIntType { Signed: true } => "rem", + NubIntType { Signed: false } => "urem", _ => throw new NotSupportedException($"Modulo not supported for type '{type}'") }, BinaryOperator.Plus => "add", @@ -864,13 +865,13 @@ public class QBEGenerator BinaryOperator.Multiply => "mul", BinaryOperator.Equal => type switch { - IntTypeNode intType => intType.Width switch + NubIntType intType => intType.Width switch { <= 32 => "ceqw", 64 => "ceql", _ => throw new ArgumentOutOfRangeException() }, - FloatTypeNode floatType => floatType.Width switch + NubFloatType floatType => floatType.Width switch { 32 => "ceqs", 64 => "ceqd", @@ -880,13 +881,13 @@ public class QBEGenerator }, BinaryOperator.NotEqual => type switch { - IntTypeNode intType => intType.Width switch + NubIntType intType => intType.Width switch { <= 32 => "cnew", 64 => "cnel", _ => throw new ArgumentOutOfRangeException() }, - FloatTypeNode floatType => floatType.Width switch + NubFloatType floatType => floatType.Width switch { 32 => "cnes", 64 => "cned", @@ -896,19 +897,19 @@ public class QBEGenerator }, BinaryOperator.LessThan => type switch { - IntTypeNode { Signed: true } intType => intType.Width switch + NubIntType { Signed: true } intType => intType.Width switch { <= 32 => "csltw", 64 => "csltl", _ => throw new ArgumentOutOfRangeException() }, - IntTypeNode { Signed: false } intType => intType.Width switch + NubIntType { Signed: false } intType => intType.Width switch { <= 32 => "cultw", 64 => "cultl", _ => throw new ArgumentOutOfRangeException() }, - FloatTypeNode floatType => floatType.Width switch + NubFloatType floatType => floatType.Width switch { 32 => "clts", 64 => "cltd", @@ -918,19 +919,19 @@ public class QBEGenerator }, BinaryOperator.LessThanOrEqual => type switch { - IntTypeNode { Signed: true } intType => intType.Width switch + NubIntType { Signed: true } intType => intType.Width switch { <= 32 => "cslew", 64 => "cslel", _ => throw new ArgumentOutOfRangeException() }, - IntTypeNode { Signed: false } intType => intType.Width switch + NubIntType { Signed: false } intType => intType.Width switch { <= 32 => "culew", 64 => "culel", _ => throw new ArgumentOutOfRangeException() }, - FloatTypeNode floatType => floatType.Width switch + NubFloatType floatType => floatType.Width switch { 32 => "cles", 64 => "cled", @@ -940,19 +941,19 @@ public class QBEGenerator }, BinaryOperator.GreaterThan => type switch { - IntTypeNode { Signed: true } intType => intType.Width switch + NubIntType { Signed: true } intType => intType.Width switch { <= 32 => "csgtw", 64 => "csgtl", _ => throw new ArgumentOutOfRangeException() }, - IntTypeNode { Signed: false } intType => intType.Width switch + NubIntType { Signed: false } intType => intType.Width switch { <= 32 => "cugtw", 64 => "cugtl", _ => throw new ArgumentOutOfRangeException() }, - FloatTypeNode floatType => floatType.Width switch + NubFloatType floatType => floatType.Width switch { 32 => "cgts", 64 => "cgtd", @@ -962,19 +963,19 @@ public class QBEGenerator }, BinaryOperator.GreaterThanOrEqual => type switch { - IntTypeNode { Signed: true } intType => intType.Width switch + NubIntType { Signed: true } intType => intType.Width switch { <= 32 => "csgew", 64 => "csgel", _ => throw new ArgumentOutOfRangeException() }, - IntTypeNode { Signed: false } intType => intType.Width switch + NubIntType { Signed: false } intType => intType.Width switch { <= 32 => "cugew", 64 => "cugel", _ => throw new ArgumentOutOfRangeException() }, - FloatTypeNode floatType => floatType.Width switch + NubFloatType floatType => floatType.Width switch { 32 => "cges", 64 => "cged", @@ -1027,16 +1028,16 @@ public class QBEGenerator { switch (unaryExpression.Operand.Type) { - case IntTypeNode { Signed: true, Width: 64 }: + case NubIntType { Signed: true, Width: 64 }: _writer.Indented($"{outputName} =l neg {operand}"); return outputName; - case IntTypeNode { Signed: true, Width: 8 or 16 or 32 }: + case NubIntType { Signed: true, Width: 8 or 16 or 32 }: _writer.Indented($"{outputName} =w neg {operand}"); return outputName; - case FloatTypeNode { Width: 64 }: + case NubFloatType { Width: 64 }: _writer.Indented($"{outputName} =d neg {operand}"); return outputName; - case FloatTypeNode { Width: 32 }: + case NubFloatType { Width: 32 }: _writer.Indented($"{outputName} =s neg {operand}"); return outputName; } @@ -1047,7 +1048,7 @@ public class QBEGenerator { switch (unaryExpression.Operand.Type) { - case BoolTypeNode: + case NubBoolType: _writer.Indented($"{outputName} =w xor {operand}, 1"); return outputName; } @@ -1065,7 +1066,7 @@ public class QBEGenerator private string EmitStructFuncCall(StructFuncCallNode structFuncCall) { - var func = StructFuncName(structFuncCall.StructType.Module, structFuncCall.StructType.Name, structFuncCall.Name); + var func = StructFuncName(structFuncCall.Module, structFuncCall.StructName, structFuncCall.FuncName); var thisParameter = EmitExpression(structFuncCall.StructExpression); @@ -1075,7 +1076,7 @@ public class QBEGenerator { var value = EmitExpression(parameter); - if (parameter.Type is StructTypeNode structType) + if (parameter.Type is NubStructType structType) { var copyFunc = structType.Functions.FirstOrDefault(x => x.Hook == "oncopy"); if (copyFunc != null) @@ -1087,7 +1088,7 @@ public class QBEGenerator parameterStrings.Add($"{FuncQBETypeName(parameter.Type)} {value}"); } - if (structFuncCall.Type is VoidTypeNode) + if (structFuncCall.Type is NubVoidType) { _writer.Indented($"call {func}({string.Join(", ", parameterStrings)})"); return string.Empty; @@ -1163,7 +1164,7 @@ public class QBEGenerator { var value = EmitExpression(parameter); - if (parameter.Type is StructTypeNode structType) + if (parameter.Type is NubStructType structType) { var copyFunc = structType.Functions.FirstOrDefault(x => x.Hook == "oncopy"); if (copyFunc != null) @@ -1175,7 +1176,7 @@ public class QBEGenerator parameterStrings.Add($"{FuncQBETypeName(parameter.Type)} {value}"); } - if (funcCall.Type is VoidTypeNode) + if (funcCall.Type is NubVoidType) { _writer.Indented($"call {funcPointer}({string.Join(", ", parameterStrings)})"); return string.Empty; @@ -1188,25 +1189,25 @@ public class QBEGenerator } } - private static int SizeOf(TypeNode type) + private static int SizeOf(NubType type) { return type switch { - IntTypeNode intType => intType.Width / 8, - FloatTypeNode fType => fType.Width / 8, - BoolTypeNode => 1, - PointerTypeNode => PTR_SIZE, - FuncTypeNode => PTR_SIZE, - StructTypeNode structType => CalculateStructSize(structType), - VoidTypeNode => throw new InvalidOperationException("Void type has no size"), - CStringTypeNode => PTR_SIZE, - StringTypeNode => PTR_SIZE, - ArrayTypeNode => PTR_SIZE, + NubIntType intType => intType.Width / 8, + NubFloatType fType => fType.Width / 8, + NubBoolType => 1, + NubPointerType => PTR_SIZE, + NubFuncType => PTR_SIZE, + NubStructType structType => CalculateStructSize(structType), + NubVoidType => throw new InvalidOperationException("Void type has no size"), + NubCStringType => PTR_SIZE, + NubStringType => PTR_SIZE, + NubArrayType => PTR_SIZE, _ => throw new ArgumentOutOfRangeException(nameof(type), $"Unknown type: {type.GetType()}") }; } - private static int CalculateStructSize(StructTypeNode structType) + private static int CalculateStructSize(NubStructType structType) { var offset = 0; @@ -1221,26 +1222,26 @@ public class QBEGenerator return AlignTo(offset, structAlignment); } - private static int AlignmentOf(TypeNode type) + private static int AlignmentOf(NubType type) { return type switch { - IntTypeNode intType => intType.Width / 8, - FloatTypeNode fType => fType.Width / 8, - BoolTypeNode => 1, - PointerTypeNode => PTR_SIZE, - FuncTypeNode => PTR_SIZE, - StructTypeNode st => CalculateStructAlignment(st), - CStringTypeNode => PTR_SIZE, - StringTypeNode => PTR_SIZE, - ArrayTypeNode => PTR_SIZE, - VoidTypeNode => throw new InvalidOperationException("Void has no alignment"), + NubIntType intType => intType.Width / 8, + NubFloatType fType => fType.Width / 8, + NubBoolType => 1, + NubPointerType => PTR_SIZE, + NubFuncType => PTR_SIZE, + NubStructType st => CalculateStructAlignment(st), + NubCStringType => PTR_SIZE, + NubStringType => PTR_SIZE, + NubArrayType => PTR_SIZE, + NubVoidType => throw new InvalidOperationException("Void has no alignment"), _ => throw new ArgumentOutOfRangeException(nameof(type), $"Unknown type: {type.GetType()}") }; } - private static int CalculateStructAlignment(StructTypeNode structType) + private static int CalculateStructAlignment(NubStructType structType) { var maxAlignment = 1; @@ -1258,7 +1259,7 @@ public class QBEGenerator return (offset + alignment - 1) & ~(alignment - 1); } - private static int OffsetOf(StructTypeNode type, string member) + private static int OffsetOf(NubStructType type, string member) { var offset = 0; @@ -1278,6 +1279,27 @@ public class QBEGenerator throw new UnreachableException($"Member '{member}' not found in struct"); } + private NubStructType TypeOfStruct(StructNode definition) + { + var fieldTypes = definition.Fields + .Select(x => new NubStructFieldType(x.Name, x.Type, x.Value.HasValue)) + .ToList(); + + var functionTypes = new List(); + foreach (var function in definition.Functions) + { + var parameters = function.Signature.Parameters.Select(x => x.Type).ToList(); + functionTypes.Add(new NubStructFuncType(function.Name, function.Hook, parameters, function.Signature.ReturnType)); + } + + return new NubStructType(definition.Module, definition.Name, fieldTypes, functionTypes); + } + + private NubFuncType TypeOfFunc(FuncSignatureNode signature) + { + var parameters = signature.Parameters.Select(x => x.Type).ToList(); + return new NubFuncType(parameters, signature.ReturnType); + } private string TmpName() { @@ -1332,7 +1354,7 @@ public class Scope(Scope? parent = null) } } -public record Variable(string Name, TypeNode Type); +public record Variable(string Name, NubType Type); public class StringLiteral(string value, string name) { diff --git a/compiler/NubLang/Modules/Module.cs b/compiler/NubLang/Modules/Module.cs index c1cbb96..ff28711 100644 --- a/compiler/NubLang/Modules/Module.cs +++ b/compiler/NubLang/Modules/Module.cs @@ -38,4 +38,6 @@ public record ModuleStruct(bool Exported, string Name, List F public record ModuleFunctionParameter(string Name, TypeSyntax Type); -public record ModuleFunction(bool Exported, string Name, string? ExternSymbol, List Parameters, TypeSyntax ReturnType); \ No newline at end of file +public record ModuleFunction(bool Exported, string Name, string? ExternSymbol, List Parameters, TypeSyntax ReturnType); + +public record ModuleTemplateStruct(bool Exported, string? Name); \ No newline at end of file diff --git a/compiler/NubLang/TypeChecking/Node/DefinitionNode.cs b/compiler/NubLang/TypeChecking/Node/DefinitionNode.cs index cfecb44..e24e02e 100644 --- a/compiler/NubLang/TypeChecking/Node/DefinitionNode.cs +++ b/compiler/NubLang/TypeChecking/Node/DefinitionNode.cs @@ -2,14 +2,16 @@ public abstract record DefinitionNode(string Module, string Name) : Node; -public record FuncParameterNode(string Name, TypeNode Type) : Node; +public record FuncParameterNode(string Name, NubType Type) : Node; -public record FuncSignatureNode(List Parameters, TypeNode ReturnType) : Node; +public record FuncSignatureNode(List Parameters, NubType ReturnType) : Node; public record FuncNode(string Module, string Name, string? ExternSymbol, FuncSignatureNode Signature, BlockNode? Body) : DefinitionNode(Module, Name); -public record StructFieldNode(string Name, TypeNode Type, Optional Value) : Node; +public record StructFieldNode(string Name, NubType Type, Optional Value) : Node; -public record StructFuncNode(string Name, FuncSignatureNode Signature, BlockNode Body) : Node; +public record StructFuncNode(string Name, string? Hook, FuncSignatureNode Signature, BlockNode Body) : Node; -public record StructNode(StructTypeNode StructType, string Module, string Name, List Fields, List Functions) : DefinitionNode(Module, Name); \ No newline at end of file +public record StructNode(string Module, string Name, List Fields, List Functions) : DefinitionNode(Module, Name); + +public record StructTemplateNode(string Module, string Name, List TemplateArguments, List Fields, List Functions) : DefinitionNode(Module, Name); \ No newline at end of file diff --git a/compiler/NubLang/TypeChecking/Node/ExpressionNode.cs b/compiler/NubLang/TypeChecking/Node/ExpressionNode.cs index 27db6d2..358f502 100644 --- a/compiler/NubLang/TypeChecking/Node/ExpressionNode.cs +++ b/compiler/NubLang/TypeChecking/Node/ExpressionNode.cs @@ -28,54 +28,54 @@ public enum BinaryOperator BitwiseOr } -public abstract record ExpressionNode(TypeNode Type) : Node; +public abstract record ExpressionNode(NubType Type) : Node; -public abstract record LValueExpressionNode(TypeNode Type) : ExpressionNode(Type); +public abstract record LValueExpressionNode(NubType Type) : ExpressionNode(Type); -public abstract record RValueExpressionNode(TypeNode Type) : ExpressionNode(Type); +public abstract record RValueExpressionNode(NubType Type) : ExpressionNode(Type); -public record StringLiteralNode(TypeNode Type, string Value) : RValueExpressionNode(Type); +public record StringLiteralNode(NubType Type, string Value) : RValueExpressionNode(Type); -public record CStringLiteralNode(TypeNode Type, string Value) : RValueExpressionNode(Type); +public record CStringLiteralNode(NubType Type, string Value) : RValueExpressionNode(Type); -public record IntLiteralNode(TypeNode Type, long Value) : RValueExpressionNode(Type); +public record IntLiteralNode(NubType Type, long Value) : RValueExpressionNode(Type); -public record UIntLiteralNode(TypeNode Type, ulong Value) : RValueExpressionNode(Type); +public record UIntLiteralNode(NubType Type, ulong Value) : RValueExpressionNode(Type); -public record Float32LiteralNode(TypeNode Type, float Value) : RValueExpressionNode(Type); +public record Float32LiteralNode(NubType Type, float Value) : RValueExpressionNode(Type); -public record Float64LiteralNode(TypeNode Type, double Value) : RValueExpressionNode(Type); +public record Float64LiteralNode(NubType Type, double Value) : RValueExpressionNode(Type); -public record BoolLiteralNode(TypeNode Type, bool Value) : RValueExpressionNode(Type); +public record BoolLiteralNode(NubType Type, bool Value) : RValueExpressionNode(Type); -public record BinaryExpressionNode(TypeNode Type, ExpressionNode Left, BinaryOperator Operator, ExpressionNode Right) : RValueExpressionNode(Type); +public record BinaryExpressionNode(NubType Type, ExpressionNode Left, BinaryOperator Operator, ExpressionNode Right) : RValueExpressionNode(Type); -public record UnaryExpressionNode(TypeNode Type, UnaryOperator Operator, ExpressionNode Operand) : RValueExpressionNode(Type); +public record UnaryExpressionNode(NubType Type, UnaryOperator Operator, ExpressionNode Operand) : RValueExpressionNode(Type); -public record FuncCallNode(TypeNode Type, ExpressionNode Expression, List Parameters) : RValueExpressionNode(Type); +public record FuncCallNode(NubType Type, ExpressionNode Expression, List Parameters) : RValueExpressionNode(Type); -public record StructFuncCallNode(TypeNode Type, string Name, StructTypeNode StructType, ExpressionNode StructExpression, List Parameters) : RValueExpressionNode(Type); +public record StructFuncCallNode(NubType Type, string Module, string StructName, string FuncName, ExpressionNode StructExpression, List Parameters) : RValueExpressionNode(Type); -public record VariableIdentifierNode(TypeNode Type, string Name) : LValueExpressionNode(Type); +public record LValueIdentifierNode(NubType Type, string Name) : LValueExpressionNode(Type); -public record FuncParameterIdentifierNode(TypeNode Type, string Name) : RValueExpressionNode(Type); +public record RValueIdentifierNode(NubType Type, string Name) : RValueExpressionNode(Type); -public record FuncIdentifierNode(TypeNode Type, string Module, string Name, string? ExternSymbol) : RValueExpressionNode(Type); +public record FuncIdentifierNode(NubType Type, string Module, string Name, string? ExternSymbol) : RValueExpressionNode(Type); -public record ArrayInitializerNode(TypeNode Type, ExpressionNode Capacity, TypeNode ElementType) : RValueExpressionNode(Type); +public record ArrayInitializerNode(NubType Type, ExpressionNode Capacity, NubType ElementType) : RValueExpressionNode(Type); -public record ArrayIndexAccessNode(TypeNode Type, ExpressionNode Target, ExpressionNode Index) : LValueExpressionNode(Type); +public record ArrayIndexAccessNode(NubType Type, ExpressionNode Target, ExpressionNode Index) : LValueExpressionNode(Type); -public record AddressOfNode(TypeNode Type, LValueExpressionNode LValue) : RValueExpressionNode(Type); +public record AddressOfNode(NubType Type, LValueExpressionNode LValue) : RValueExpressionNode(Type); -public record StructFieldAccessNode(TypeNode Type, StructTypeNode StructType, ExpressionNode Target, string Field) : LValueExpressionNode(Type); +public record StructFieldAccessNode(NubType Type, ExpressionNode Target, string Field) : LValueExpressionNode(Type); -public record StructInitializerNode(StructTypeNode StructType, Dictionary Initializers) : RValueExpressionNode(StructType); +public record StructInitializerNode(NubStructType StructType, Dictionary Initializers) : RValueExpressionNode(StructType); -public record DereferenceNode(TypeNode Type, ExpressionNode Expression) : LValueExpressionNode(Type); +public record DereferenceNode(NubType Type, ExpressionNode Expression) : LValueExpressionNode(Type); -public record ConvertIntNode(TypeNode Type, ExpressionNode Value, IntTypeNode ValueType, IntTypeNode TargetType) : RValueExpressionNode(Type); +public record ConvertIntNode(NubType Type, ExpressionNode Value, NubIntType ValueType, NubIntType TargetType) : RValueExpressionNode(Type); -public record ConvertFloatNode(TypeNode Type, ExpressionNode Value, FloatTypeNode ValueType, FloatTypeNode TargetType) : RValueExpressionNode(Type); +public record ConvertFloatNode(NubType Type, ExpressionNode Value, NubFloatType ValueType, NubFloatType TargetType) : RValueExpressionNode(Type); -public record SizeCompilerMacroNode(TypeNode Type, TypeNode TargetType) : RValueExpressionNode(Type); \ No newline at end of file +public record SizeCompilerMacroNode(NubType Type, NubType TargetType) : RValueExpressionNode(Type); \ No newline at end of file diff --git a/compiler/NubLang/TypeChecking/Node/NubType.cs b/compiler/NubLang/TypeChecking/Node/NubType.cs new file mode 100644 index 0000000..6394a83 --- /dev/null +++ b/compiler/NubLang/TypeChecking/Node/NubType.cs @@ -0,0 +1,194 @@ +using System.Security.Cryptography; +using System.Text; + +namespace NubLang.TypeChecking.Node; + +public abstract class NubType : IEquatable +{ + public abstract bool IsValueType { get; } + public abstract bool IsScalar { get; } + + public override bool Equals(object? obj) => obj is NubType other && Equals(other); + public abstract bool Equals(NubType? other); + + public abstract override int GetHashCode(); + public abstract override string ToString(); + + public static bool operator ==(NubType? left, NubType? right) => Equals(left, right); + public static bool operator !=(NubType? left, NubType? right) => !Equals(left, right); +} + +public class NubVoidType : NubType +{ + public override bool IsValueType => false; + public override bool IsScalar => false; + + public override string ToString() => "void"; + public override bool Equals(NubType? other) => other is NubVoidType; + public override int GetHashCode() => HashCode.Combine(typeof(NubVoidType)); +} + +public sealed class NubIntType(bool signed, int width) : NubType +{ + public override bool IsValueType => true; + public override bool IsScalar => true; + + public bool Signed { get; } = signed; + public int Width { get; } = width; + + public override string ToString() => $"{(Signed ? "i" : "u")}{Width}"; + public override bool Equals(NubType? other) => other is NubIntType @int && @int.Width == Width && @int.Signed == Signed; + public override int GetHashCode() => HashCode.Combine(typeof(NubIntType), Signed, Width); +} + +public sealed class NubFloatType(int width) : NubType +{ + public override bool IsValueType => true; + public override bool IsScalar => true; + + public int Width { get; } = width; + + public override string ToString() => $"f{Width}"; + public override bool Equals(NubType? other) => other is NubFloatType @float && @float.Width == Width; + public override int GetHashCode() => HashCode.Combine(typeof(NubFloatType), Width); +} + +public class NubBoolType : NubType +{ + public override bool IsValueType => true; + public override bool IsScalar => true; + + public override string ToString() => "bool"; + public override bool Equals(NubType? other) => other is NubBoolType; + public override int GetHashCode() => HashCode.Combine(typeof(NubBoolType)); +} + +public sealed class NubPointerType(NubType baseType) : NubType +{ + public override bool IsValueType => true; + public override bool IsScalar => true; + + public NubType BaseType { get; } = baseType; + + public override string ToString() => "^" + BaseType; + public override bool Equals(NubType? other) => other is NubPointerType pointer && BaseType.Equals(pointer.BaseType); + public override int GetHashCode() => HashCode.Combine(typeof(NubPointerType), BaseType); +} + +public class NubFuncType(List parameters, NubType returnType) : NubType +{ + public override bool IsValueType => true; + public override bool IsScalar => true; + + public List Parameters { get; } = parameters; + public NubType ReturnType { get; } = returnType; + + public override string ToString() => $"func({string.Join(", ", Parameters)}): {ReturnType}"; + public override bool Equals(NubType? other) => other is NubFuncType func && ReturnType.Equals(func.ReturnType) && Parameters.SequenceEqual(func.Parameters); + + public override int GetHashCode() + { + var hash = new HashCode(); + hash.Add(typeof(NubFuncType)); + hash.Add(ReturnType); + foreach (var param in Parameters) + { + hash.Add(param); + } + + return hash.ToHashCode(); + } +} + +public class NubStructType(string module, string name, List fields, List functions) : NubType +{ + public override bool IsValueType => true; + public override bool IsScalar => false; + + public string Module { get; } = module; + public string Name { get; } = name; + public List Fields { get; set; } = fields; + public List Functions { get; set; } = functions; + + public override string ToString() => $"{Module}.{Name}"; + public override bool Equals(NubType? other) => other is NubStructType structType && Name == structType.Name && Module == structType.Module; + public override int GetHashCode() => HashCode.Combine(typeof(NubStructType), Name); +} + +public class NubStructFieldType(string name, NubType type, bool hasDefaultValue) +{ + public string Name { get; } = name; + public NubType Type { get; } = type; + public bool HasDefaultValue { get; } = hasDefaultValue; +} + +public class NubStructFuncType(string name, string? hook, List parameters, NubType returnType) +{ + public string Name { get; } = name; + public string? Hook { get; set; } = hook; + public List Parameters { get; } = parameters; + public NubType ReturnType { get; } = returnType; +} + +public class NubArrayType(NubType elementType) : NubType +{ + public override bool IsValueType => false; + public override bool IsScalar => true; + + public NubType ElementType { get; } = elementType; + + public override string ToString() => "[]" + ElementType; + public override bool Equals(NubType? other) => other is NubArrayType array && ElementType.Equals(array.ElementType); + public override int GetHashCode() => HashCode.Combine(typeof(NubArrayType), ElementType); +} + +public class NubCStringType : NubType +{ + public override bool IsValueType => false; + public override bool IsScalar => false; + + public override string ToString() => "cstring"; + public override bool Equals(NubType? other) => other is NubCStringType; + public override int GetHashCode() => HashCode.Combine(typeof(NubCStringType)); +} + +public class NubStringType : NubType +{ + public override bool IsValueType => false; + public override bool IsScalar => false; + + public override string ToString() => "string"; + public override bool Equals(NubType? other) => other is NubStringType; + public override int GetHashCode() => HashCode.Combine(typeof(NubStringType)); +} + +public static class NameMangler +{ + public static string Mangle(params IEnumerable types) + { + var readable = string.Join("_", types.Select(EncodeType)); + return ComputeShortHash(readable); + } + + private static string EncodeType(NubType node) => node switch + { + NubVoidType => "V", + NubBoolType => "B", + NubIntType i => (i.Signed ? "I" : "U") + i.Width, + NubFloatType f => "F" + f.Width, + NubCStringType => "CS", + NubStringType => "S", + NubPointerType p => "P" + EncodeType(p.BaseType), + NubArrayType a => "A" + EncodeType(a.ElementType), + NubFuncType fn => "FN(" + string.Join(",", fn.Parameters.Select(EncodeType)) + ")" + EncodeType(fn.ReturnType), + NubStructType st => "ST(" + st.Module + "." + st.Name + ")", + _ => throw new NotSupportedException($"Cannot encode type: {node}") + }; + + private static string ComputeShortHash(string input) + { + var bytes = Encoding.UTF8.GetBytes(input); + var hash = SHA256.HashData(bytes); + return Convert.ToHexString(hash[..8]).ToLower(); + } +} \ No newline at end of file diff --git a/compiler/NubLang/TypeChecking/Node/StatementNode.cs b/compiler/NubLang/TypeChecking/Node/StatementNode.cs index 3ec3764..92d2b8a 100644 --- a/compiler/NubLang/TypeChecking/Node/StatementNode.cs +++ b/compiler/NubLang/TypeChecking/Node/StatementNode.cs @@ -16,7 +16,7 @@ public record AssignmentNode(LValueExpressionNode Target, ExpressionNode Value) public record IfNode(ExpressionNode Condition, BlockNode Body, Optional> Else) : StatementNode; -public record VariableDeclarationNode(string Name, Optional Assignment, TypeNode Type) : StatementNode; +public record VariableDeclarationNode(string Name, Optional Assignment, NubType Type) : StatementNode; public record ContinueNode : TerminalStatementNode; @@ -26,4 +26,4 @@ public record DeferNode(StatementNode Statement) : StatementNode; public record WhileNode(ExpressionNode Condition, BlockNode Body) : StatementNode; -public record ForArrayNode(ArrayTypeNode ArrayType, string ElementIdent, string? IndexIdent, ExpressionNode Target, BlockNode Body) : StatementNode; \ No newline at end of file +public record ForArrayNode(NubArrayType ArrayType, string ElementIdent, string? IndexIdent, ExpressionNode Target, BlockNode Body) : StatementNode; \ No newline at end of file diff --git a/compiler/NubLang/TypeChecking/Node/TypeNode.cs b/compiler/NubLang/TypeChecking/Node/TypeNode.cs deleted file mode 100644 index dd042f9..0000000 --- a/compiler/NubLang/TypeChecking/Node/TypeNode.cs +++ /dev/null @@ -1,193 +0,0 @@ -using System.Security.Cryptography; -using System.Text; - -namespace NubLang.TypeChecking.Node; - -public abstract class TypeNode : IEquatable -{ - public abstract bool IsValueType { get; } - public abstract bool IsScalar { get; } - - public override bool Equals(object? obj) => obj is TypeNode other && Equals(other); - public abstract bool Equals(TypeNode? other); - - public abstract override int GetHashCode(); - public abstract override string ToString(); - - public static bool operator ==(TypeNode? left, TypeNode? right) => Equals(left, right); - public static bool operator !=(TypeNode? left, TypeNode? right) => !Equals(left, right); -} - -public class VoidTypeNode : TypeNode -{ - public override bool IsValueType => false; - public override bool IsScalar => false; - - public override string ToString() => "void"; - public override bool Equals(TypeNode? other) => other is VoidTypeNode; - public override int GetHashCode() => HashCode.Combine(typeof(VoidTypeNode)); -} - -public sealed class IntTypeNode(bool signed, int width) : TypeNode -{ - public override bool IsValueType => true; - public override bool IsScalar => true; - - public bool Signed { get; } = signed; - public int Width { get; } = width; - - public override string ToString() => $"{(Signed ? "i" : "u")}{Width}"; - public override bool Equals(TypeNode? other) => other is IntTypeNode @int && @int.Width == Width && @int.Signed == Signed; - public override int GetHashCode() => HashCode.Combine(typeof(IntTypeNode), Signed, Width); -} - -public sealed class FloatTypeNode(int width) : TypeNode -{ - public override bool IsValueType => true; - public override bool IsScalar => true; - - public int Width { get; } = width; - - public override string ToString() => $"f{Width}"; - public override bool Equals(TypeNode? other) => other is FloatTypeNode @float && @float.Width == Width; - public override int GetHashCode() => HashCode.Combine(typeof(FloatTypeNode), Width); -} - -public class BoolTypeNode : TypeNode -{ - public override bool IsValueType => true; - public override bool IsScalar => true; - - public override string ToString() => "bool"; - public override bool Equals(TypeNode? other) => other is BoolTypeNode; - public override int GetHashCode() => HashCode.Combine(typeof(BoolTypeNode)); -} - -public sealed class PointerTypeNode(TypeNode baseType) : TypeNode -{ - public override bool IsValueType => true; - public override bool IsScalar => true; - - public TypeNode BaseType { get; } = baseType; - - public override string ToString() => "^" + BaseType; - public override bool Equals(TypeNode? other) => other is PointerTypeNode pointer && BaseType.Equals(pointer.BaseType); - public override int GetHashCode() => HashCode.Combine(typeof(PointerTypeNode), BaseType); -} - -public class FuncTypeNode(List parameters, TypeNode returnType) : TypeNode -{ - public override bool IsValueType => true; - public override bool IsScalar => true; - - public List Parameters { get; } = parameters; - public TypeNode ReturnType { get; } = returnType; - - public override string ToString() => $"func({string.Join(", ", Parameters)}): {ReturnType}"; - public override bool Equals(TypeNode? other) => other is FuncTypeNode func && ReturnType.Equals(func.ReturnType) && Parameters.SequenceEqual(func.Parameters); - - public override int GetHashCode() - { - var hash = new HashCode(); - hash.Add(typeof(FuncTypeNode)); - hash.Add(ReturnType); - foreach (var param in Parameters) - { - hash.Add(param); - } - - return hash.ToHashCode(); - } -} - -public class StructTypeNode(string module, string name, List fields, List functions) : TypeNode -{ - public override bool IsValueType => true; - public override bool IsScalar => false; - - public string Module { get; } = module; - public string Name { get; } = name; - public List Fields { get; set; } = fields; - public List Functions { get; set; } = functions; - - public override string ToString() => Name; - public override bool Equals(TypeNode? other) => other is StructTypeNode structType && Name == structType.Name && Module == structType.Module; - public override int GetHashCode() => HashCode.Combine(typeof(StructTypeNode), Name); -} - -public class StructTypeField(string name, TypeNode type, bool hasDefaultValue) -{ - public string Name { get; } = name; - public TypeNode Type { get; } = type; - public bool HasDefaultValue { get; } = hasDefaultValue; -} - -public class StructTypeFunc(string name, string? hook, FuncTypeNode type) -{ - public string Name { get; } = name; - public string? Hook { get; set; } = hook; - public FuncTypeNode Type { get; } = type; -} - -public class ArrayTypeNode(TypeNode elementType) : TypeNode -{ - public override bool IsValueType => false; - public override bool IsScalar => true; - - public TypeNode ElementType { get; } = elementType; - - public override string ToString() => "[]" + ElementType; - public override bool Equals(TypeNode? other) => other is ArrayTypeNode array && ElementType.Equals(array.ElementType); - public override int GetHashCode() => HashCode.Combine(typeof(ArrayTypeNode), ElementType); -} - -public class CStringTypeNode : TypeNode -{ - public override bool IsValueType => false; - public override bool IsScalar => false; - - public override string ToString() => "cstring"; - public override bool Equals(TypeNode? other) => other is CStringTypeNode; - public override int GetHashCode() => HashCode.Combine(typeof(CStringTypeNode)); -} - -public class StringTypeNode : TypeNode -{ - public override bool IsValueType => false; - public override bool IsScalar => false; - - public override string ToString() => "string"; - public override bool Equals(TypeNode? other) => other is StringTypeNode; - public override int GetHashCode() => HashCode.Combine(typeof(StringTypeNode)); -} - -public static class NameMangler -{ - public static string Mangle(params IEnumerable types) - { - var readable = string.Join("_", types.Select(EncodeType)); - return ComputeShortHash(readable); - } - - private static string EncodeType(TypeNode node) => node switch - { - VoidTypeNode => "V", - BoolTypeNode => "B", - IntTypeNode i => (i.Signed ? "I" : "U") + i.Width, - FloatTypeNode f => "F" + f.Width, - CStringTypeNode => "CS", - StringTypeNode => "S", - PointerTypeNode p => "P" + EncodeType(p.BaseType), - ArrayTypeNode a => "A" + EncodeType(a.ElementType), - FuncTypeNode fn => "FN(" + string.Join(",", fn.Parameters.Select(EncodeType)) + ")" + EncodeType(fn.ReturnType), - StructTypeNode st => "ST(" + st.Module + "." + st.Name + ")", - _ => throw new NotSupportedException($"Cannot encode type: {node}") - }; - - private static string ComputeShortHash(string input) - { - var bytes = Encoding.UTF8.GetBytes(input); - var hash = SHA256.HashData(bytes); - return Convert.ToHexString(hash[..8]).ToLower(); - } -} \ No newline at end of file diff --git a/compiler/NubLang/TypeChecking/TypeChecker.cs b/compiler/NubLang/TypeChecking/TypeChecker.cs index d865872..e4cb05d 100644 --- a/compiler/NubLang/TypeChecking/TypeChecker.cs +++ b/compiler/NubLang/TypeChecking/TypeChecker.cs @@ -1,6 +1,5 @@ using System.Diagnostics; using System.Globalization; -using System.Security.Cryptography; using NubLang.Diagnostics; using NubLang.Modules; using NubLang.Parsing.Syntax; @@ -15,95 +14,121 @@ public sealed class TypeChecker private readonly Dictionary _visibleModules; private readonly Stack _scopes = []; - private readonly Stack _funcReturnTypes = []; - private readonly Dictionary<(string Module, string Name), TypeNode> _typeCache = new(); + private Scope _globalScope = new(); + private readonly Stack _funcReturnTypes = []; + private readonly Dictionary<(string Module, string Name), NubType> _typeCache = new(); private readonly HashSet<(string Module, string Name)> _resolvingTypes = []; - private Scope Scope => _scopes.Peek(); + private Scope CurrentScope => _scopes.Peek(); + private string CurrentModule => _syntaxTree.Metadata.ModuleName; public TypeChecker(SyntaxTree syntaxTree, ModuleRepository moduleRepository) { _syntaxTree = syntaxTree; _visibleModules = moduleRepository .Modules() - .Where(x => syntaxTree.Metadata.Imports.Contains(x.Key) || _syntaxTree.Metadata.ModuleName == x.Key) + .Where(x => syntaxTree.Metadata.Imports.Contains(x.Key) || CurrentModule == x.Key) .ToDictionary(); } public List Definitions { get; } = []; public List Diagnostics { get; } = []; - public List ReferencedStructTypes { get; } = []; + public List ReferencedStructTypes { get; } = []; public void Check() { _scopes.Clear(); + _globalScope = new Scope(); _funcReturnTypes.Clear(); - Diagnostics.Clear(); - Definitions.Clear(); - ReferencedStructTypes.Clear(); _typeCache.Clear(); _resolvingTypes.Clear(); + Diagnostics.Clear(); + Definitions.Clear(); + ReferencedStructTypes.Clear(); + foreach (var definition in _syntaxTree.Definitions) { + BeginScope(true); + try { - switch (definition) + Definitions.Add(definition switch { - case FuncSyntax funcSyntax: - Definitions.Add(CheckFuncDefinition(funcSyntax)); - break; - case StructSyntax structSyntax: - Definitions.Add(CheckStructDefinition(structSyntax)); - break; - case StructTemplateSyntax: - break; - default: - throw new ArgumentOutOfRangeException(nameof(definition)); - } + FuncSyntax funcSyntax => CheckFuncDefinition(funcSyntax), + StructSyntax structSyntax => CheckStructDefinition(structSyntax), + StructTemplateSyntax structTemplate => CheckStructTemplateDefinition(structTemplate), + _ => throw new ArgumentOutOfRangeException() + }); } catch (TypeCheckerException e) { Diagnostics.Add(e.Diagnostic); } + + EndScope(); } } + private Scope BeginScope(bool root) + { + var scope = root + ? _globalScope.SubScope() + : _scopes.Peek().SubScope(); + + _scopes.Push(scope); + return scope; + } + + private Scope EndScope() + { + return _scopes.Pop(); + } + private StructNode CheckStructDefinition(StructSyntax node) { var fieldTypes = node.Fields - .Select(x => new StructTypeField(x.Name, ResolveType(x.Type), x.Value.HasValue)) + .Select(x => new NubStructFieldType(x.Name, ResolveType(x.Type), x.Value.HasValue)) .ToList(); - var functionTypes = new List(); - foreach (var function in node.Functions) - { - var parameters = function.Signature.Parameters.Select(x => ResolveType(x.Type)).ToList(); - var funcType = new FuncTypeNode(parameters, ResolveType(function.Signature.ReturnType)); - functionTypes.Add(new StructTypeFunc(function.Name, function.Hook, funcType)); - } + var fieldFunctions = node.Functions + .Select(x => + { + var parameters = x.Signature.Parameters.Select(y => ResolveType(y.Type)).ToList(); + var returnType = ResolveType(x.Signature.ReturnType); + return new NubStructFuncType(x.Name, x.Hook, parameters, returnType); + }) + .ToList(); + + var structType = new NubStructType(CurrentModule, node.Name, fieldTypes, fieldFunctions); + + CurrentScope.DeclareVariable(new Variable("this", structType, VariableKind.RValue)); - var type = new StructTypeNode(_syntaxTree.Metadata.ModuleName, node.Name, fieldTypes, functionTypes); var fields = node.Fields.Select(CheckStructField).ToList(); - var functions = node.Functions.Select(x => CheckStructFunc(type, x)).ToList(); + var functions = node.Functions.Select(CheckStructFunc).ToList(); - return new StructNode(type, _syntaxTree.Metadata.ModuleName, node.Name, fields, functions); + return new StructNode(CurrentModule, node.Name, fields, functions); } - private StructFuncNode CheckStructFunc(StructTypeNode type, StructFuncSyntax function, Scope? scope = null) + private StructTemplateNode CheckStructTemplateDefinition(StructTemplateSyntax node) { - scope ??= new Scope(); - scope.DeclareVariable(new Variable("this", type, VariableKind.FunctionParameter)); + var fields = node.Fields.Select(CheckStructField).ToList(); + var functions = node.Functions.Select(CheckStructFunc).ToList(); + return new StructTemplateNode(CurrentModule, node.Name, node.TemplateArguments, fields, functions); + } + + private StructFuncNode CheckStructFunc(StructFuncSyntax function) + { foreach (var parameter in function.Signature.Parameters) { - scope.DeclareVariable(new Variable(parameter.Name, ResolveType(parameter.Type), VariableKind.FunctionParameter)); + CurrentScope.DeclareVariable(new Variable(parameter.Name, ResolveType(parameter.Type), VariableKind.RValue)); } _funcReturnTypes.Push(ResolveType(function.Signature.ReturnType)); - var body = CheckBlock(function.Body, scope); + var body = CheckBlock(function.Body, CurrentScope); _funcReturnTypes.Pop(); - return new StructFuncNode(function.Name, CheckFuncSignature(function.Signature), body); + return new StructFuncNode(function.Name, function.Hook, CheckFuncSignature(function.Signature), body); } private StructFieldNode CheckStructField(StructFieldSyntax field) @@ -122,7 +147,7 @@ public sealed class TypeChecker var scope = new Scope(); foreach (var parameter in node.Signature.Parameters) { - scope.DeclareVariable(new Variable(parameter.Name, ResolveType(parameter.Type), VariableKind.FunctionParameter)); + scope.DeclareVariable(new Variable(parameter.Name, ResolveType(parameter.Type), VariableKind.RValue)); } var signature = CheckFuncSignature(node.Signature); @@ -136,7 +161,7 @@ public sealed class TypeChecker if (!AlwaysReturns(body)) { - if (signature.ReturnType is VoidTypeNode) + if (signature.ReturnType is NubVoidType) { body.Statements.Add(new ReturnNode(Optional.Empty())); } @@ -152,7 +177,7 @@ public sealed class TypeChecker _funcReturnTypes.Pop(); } - return new FuncNode(_syntaxTree.Metadata.ModuleName, node.Name, node.ExternSymbol, signature, body); + return new FuncNode(CurrentModule, node.Name, node.ExternSymbol, signature, body); } private StatementNode CheckStatement(StatementSyntax node) @@ -187,7 +212,7 @@ public sealed class TypeChecker private IfNode CheckIf(IfSyntax statement) { - var condition = CheckExpression(statement.Condition, new BoolTypeNode()); + var condition = CheckExpression(statement.Condition, new NubBoolType()); var body = CheckBlock(statement.Body); var elseStatement = Optional.Empty>(); if (statement.Else.TryGetValue(out var elseSyntax)) @@ -224,7 +249,7 @@ public sealed class TypeChecker private VariableDeclarationNode CheckVariableDeclaration(VariableDeclarationSyntax statement) { - TypeNode? type = null; + NubType? type = null; ExpressionNode? assignmentNode = null; if (statement.ExplicitType.TryGetValue(out var explicitType)) @@ -243,7 +268,7 @@ public sealed class TypeChecker throw new TypeCheckerException(Diagnostic.Error($"Cannot infer type of variable {statement.Name}").At(statement).Build()); } - Scope.DeclareVariable(new Variable(statement.Name, type, VariableKind.Variable)); + CurrentScope.DeclareVariable(new Variable(statement.Name, type, VariableKind.LValue)); return new VariableDeclarationNode(statement.Name, Optional.OfNullable(assignmentNode), type); } @@ -255,7 +280,7 @@ public sealed class TypeChecker private WhileNode CheckWhile(WhileSyntax statement) { - var condition = CheckExpression(statement.Condition, new BoolTypeNode()); + var condition = CheckExpression(statement.Condition, new NubBoolType()); var body = CheckBlock(statement.Body); return new WhileNode(condition, body); } @@ -266,13 +291,13 @@ public sealed class TypeChecker switch (target.Type) { - case ArrayTypeNode arrayType: + case NubArrayType arrayType: { - var scope = Scope.SubScope(); - scope.DeclareVariable(new Variable(statement.ElementIdent, arrayType.ElementType, VariableKind.FunctionParameter)); + var scope = CurrentScope.SubScope(); + scope.DeclareVariable(new Variable(statement.ElementIdent, arrayType.ElementType, VariableKind.RValue)); if (statement.IndexIdent != null) { - scope.DeclareVariable(new Variable(statement.ElementIdent, new IntTypeNode(true, 64), VariableKind.FunctionParameter)); + scope.DeclareVariable(new Variable(statement.ElementIdent, new NubIntType(true, 64), VariableKind.RValue)); } var body = CheckBlock(statement.Body, scope); @@ -296,7 +321,7 @@ public sealed class TypeChecker return new FuncSignatureNode(parameters, ResolveType(statement.ReturnType)); } - private ExpressionNode CheckExpression(ExpressionSyntax node, TypeNode? expectedType = null) + private ExpressionNode CheckExpression(ExpressionSyntax node, NubType? expectedType = null) { var result = node switch { @@ -314,7 +339,7 @@ public sealed class TypeChecker StructFieldAccessSyntax expression => CheckStructFieldAccess(expression), StructInitializerSyntax expression => CheckStructInitializer(expression, expectedType), InterpretCompilerMacroSyntax expression => CheckExpression(expression.Target) with { Type = ResolveType(expression.Type) }, - SizeCompilerMacroSyntax expression => new SizeCompilerMacroNode(new IntTypeNode(false, 64), ResolveType(expression.Type)), + SizeCompilerMacroSyntax expression => new SizeCompilerMacroNode(new NubIntType(false, 64), ResolveType(expression.Type)), _ => throw new ArgumentOutOfRangeException(nameof(node)) }; @@ -323,7 +348,7 @@ public sealed class TypeChecker return result; } - if (result.Type is IntTypeNode sourceIntType && expectedType is IntTypeNode targetIntType) + if (result.Type is NubIntType sourceIntType && expectedType is NubIntType targetIntType) { if (sourceIntType.Signed == targetIntType.Signed && sourceIntType.Width < targetIntType.Width) { @@ -331,7 +356,7 @@ public sealed class TypeChecker } } - if (result.Type is FloatTypeNode sourceFloatType && expectedType is FloatTypeNode targetFloatType) + if (result.Type is NubFloatType sourceFloatType && expectedType is NubFloatType targetFloatType) { if (sourceFloatType.Width < targetFloatType.Width) { @@ -350,15 +375,15 @@ public sealed class TypeChecker throw new TypeCheckerException(Diagnostic.Error("Cannot take address of an rvalue").At(expression).Build()); } - var type = new PointerTypeNode(target.Type); + var type = new NubPointerType(target.Type); return new AddressOfNode(type, lvalue); } private ArrayIndexAccessNode CheckArrayIndexAccess(ArrayIndexAccessSyntax expression) { - var index = CheckExpression(expression.Index, new IntTypeNode(false, 64)); + var index = CheckExpression(expression.Index, new NubIntType(false, 64)); var target = CheckExpression(expression.Target); - if (target.Type is not ArrayTypeNode arrayType) + if (target.Type is not NubArrayType arrayType) { throw new TypeCheckerException(Diagnostic.Error($"Cannot use array indexer on type {target.Type}").At(expression).Build()); } @@ -369,7 +394,7 @@ public sealed class TypeChecker private ArrayInitializerNode CheckArrayInitializer(ArrayInitializerSyntax expression) { var elementType = ResolveType(expression.ElementType); - var type = new ArrayTypeNode(elementType); + var type = new NubArrayType(elementType); var capacity = CheckExpression(expression.Capacity); return new ArrayInitializerNode(type, capacity, elementType); } @@ -411,19 +436,19 @@ public sealed class TypeChecker case BinaryOperatorSyntax.LogicalOr: { var left = CheckExpression(expression.Left); - if (left.Type is not IntTypeNode or FloatTypeNode) + if (left.Type is not NubIntType or NubFloatType) { throw new TypeCheckerException(Diagnostic.Error("Logical operators must must be used with int or float types").At(expression.Left).Build()); } var right = CheckExpression(expression.Right, left.Type); - return new BinaryExpressionNode(new BoolTypeNode(), left, op, right); + return new BinaryExpressionNode(new NubBoolType(), left, op, right); } case BinaryOperatorSyntax.Plus: { var left = CheckExpression(expression.Left); - if (left.Type is IntTypeNode or FloatTypeNode or StringTypeNode or CStringTypeNode) + if (left.Type is NubIntType or NubFloatType or NubStringType or NubCStringType) { var right = CheckExpression(expression.Right, left.Type); return new BinaryExpressionNode(left.Type, left, op, right); @@ -437,7 +462,7 @@ public sealed class TypeChecker case BinaryOperatorSyntax.Modulo: { var left = CheckExpression(expression.Left); - if (left.Type is not IntTypeNode or FloatTypeNode) + if (left.Type is not NubIntType or NubFloatType) { throw new TypeCheckerException(Diagnostic.Error("Math operators must be used with int or float types").At(expression.Left).Build()); } @@ -453,7 +478,7 @@ public sealed class TypeChecker case BinaryOperatorSyntax.BitwiseOr: { var left = CheckExpression(expression.Left); - if (left.Type is not IntTypeNode) + if (left.Type is not NubIntType) { throw new TypeCheckerException(Diagnostic.Error("Bitwise operators must be used with int types").At(expression.Left).Build()); } @@ -476,7 +501,7 @@ public sealed class TypeChecker case UnaryOperatorSyntax.Negate: { var operand = CheckExpression(expression.Operand); - if (operand.Type is not IntTypeNode { Signed: false } or FloatTypeNode) + if (operand.Type is not NubIntType { Signed: false } or NubFloatType) { throw new TypeCheckerException(Diagnostic.Error("Negation operator must be used with signed integer or float types").At(expression).Build()); } @@ -486,7 +511,7 @@ public sealed class TypeChecker case UnaryOperatorSyntax.Invert: { var operand = CheckExpression(expression.Operand); - if (operand.Type is not BoolTypeNode) + if (operand.Type is not NubBoolType) { throw new TypeCheckerException(Diagnostic.Error("Invert operator must be used with booleans").At(expression).Build()); } @@ -503,7 +528,7 @@ public sealed class TypeChecker private DereferenceNode CheckDereference(DereferenceSyntax expression) { var target = CheckExpression(expression.Target); - if (target.Type is not PointerTypeNode pointerType) + if (target.Type is not NubPointerType pointerType) { throw new TypeCheckerException(Diagnostic.Error($"Cannot dereference non-pointer type {target.Type}").At(expression).Build()); } @@ -514,7 +539,7 @@ public sealed class TypeChecker private FuncCallNode CheckFuncCall(FuncCallSyntax expression) { var accessor = CheckExpression(expression.Expression); - if (accessor.Type is not FuncTypeNode funcType) + if (accessor.Type is not NubFuncType funcType) { throw new TypeCheckerException(Diagnostic.Error($"Cannot call non-function type {accessor.Type}").At(expression.Expression).Build()); } @@ -552,18 +577,21 @@ public sealed class TypeChecker { // todo(nub31): When adding interfaces, also support other types than structs var target = CheckExpression(expression.Target); - if (target.Type is StructTypeNode structType) + if (target.Type is NubStructType structType) { var function = structType.Functions.FirstOrDefault(x => x.Name == expression.Name); if (function == null) { - throw new TypeCheckerException(Diagnostic.Error($"Function {expression.Name} not found on struct {structType}").At(expression).Build()); + throw new TypeCheckerException(Diagnostic + .Error($"Function {expression.Name} not found on struct {structType}") + .At(expression) + .Build()); } - if (expression.Parameters.Count != function.Type.Parameters.Count) + if (expression.Parameters.Count != function.Parameters.Count) { throw new TypeCheckerException(Diagnostic - .Error($"Function {function.Type} expects {function.Type.Parameters.Count} parameters but got {expression.Parameters.Count}") + .Error($"Function {function.Name} expects {function.Parameters.Count} parameters but got {expression.Parameters.Count}") .At(expression.Parameters.LastOrDefault(expression)) .Build()); } @@ -572,7 +600,7 @@ public sealed class TypeChecker for (var i = 0; i < expression.Parameters.Count; i++) { var parameter = expression.Parameters[i]; - var expectedType = function.Type.Parameters[i]; + var expectedType = function.Parameters[i]; var parameterExpression = CheckExpression(parameter, expectedType); if (parameterExpression.Type != expectedType) @@ -586,7 +614,7 @@ public sealed class TypeChecker parameters.Add(parameterExpression); } - return new StructFuncCallNode(function.Type.ReturnType, expression.Name, structType, target, parameters); + return new StructFuncCallNode(function.ReturnType, structType.Module, structType.Name, expression.Name, target, parameters); } throw new TypeCheckerException(Diagnostic @@ -598,35 +626,26 @@ public sealed class TypeChecker private ExpressionNode CheckLocalIdentifier(LocalIdentifierSyntax expression) { // First, look in the current scope for a matching identifier - var scopeIdent = Scope.LookupVariable(expression.Name); + var scopeIdent = CurrentScope.LookupVariable(expression.Name); if (scopeIdent != null) { - switch (scopeIdent.Kind) + return scopeIdent.Kind switch { - case VariableKind.Variable: - { - return new VariableIdentifierNode(scopeIdent.Type, expression.Name); - } - case VariableKind.FunctionParameter: - { - return new FuncParameterIdentifierNode(scopeIdent.Type, expression.Name); - } - default: - { - throw new ArgumentOutOfRangeException(); - } - } + VariableKind.LValue => new LValueIdentifierNode(scopeIdent.Type, expression.Name), + VariableKind.RValue => new RValueIdentifierNode(scopeIdent.Type, expression.Name), + _ => throw new ArgumentOutOfRangeException() + }; } // Second, look in the current module for a function matching the identifier - var module = _visibleModules[_syntaxTree.Metadata.ModuleName]; + var module = _visibleModules[CurrentModule]; var function = module.Functions(true).FirstOrDefault(x => x.Name == expression.Name); if (function != null) { var parameters = function.Parameters.Select(x => ResolveType(x.Type)).ToList(); - var type = new FuncTypeNode(parameters, ResolveType(function.ReturnType)); - return new FuncIdentifierNode(type, _syntaxTree.Metadata.ModuleName, expression.Name, function.ExternSymbol); + var type = new NubFuncType(parameters, ResolveType(function.ReturnType)); + return new FuncIdentifierNode(type, CurrentModule, expression.Name, function.ExternSymbol); } throw new TypeCheckerException(Diagnostic.Error($"Symbol {expression.Name} not found").At(expression).Build()); @@ -643,14 +662,14 @@ public sealed class TypeChecker .Build()); } - var includePrivate = expression.Module == _syntaxTree.Metadata.ModuleName; + var includePrivate = expression.Module == CurrentModule; // First, look for the exported function in the specified module var function = module.Functions(includePrivate).FirstOrDefault(x => x.Name == expression.Name); if (function != null) { var parameters = function.Parameters.Select(x => ResolveType(x.Type)).ToList(); - var type = new FuncTypeNode(parameters, ResolveType(function.ReturnType)); + var type = new NubFuncType(parameters, ResolveType(function.ReturnType)); return new FuncIdentifierNode(type, expression.Module, expression.Name, function.ExternSymbol); } @@ -660,20 +679,20 @@ public sealed class TypeChecker .Build()); } - private ExpressionNode CheckLiteral(LiteralSyntax expression, TypeNode? expectedType) + private ExpressionNode CheckLiteral(LiteralSyntax expression, NubType? expectedType) { switch (expression.Kind) { case LiteralKind.Integer: { - var type = expectedType as IntTypeNode ?? new IntTypeNode(true, 64); + var type = expectedType as NubIntType ?? new NubIntType(true, 64); return type.Signed ? new IntLiteralNode(type, long.Parse(expression.Value)) : new UIntLiteralNode(type, ulong.Parse(expression.Value)); } case LiteralKind.Float: { - var type = expectedType as FloatTypeNode ?? new FloatTypeNode(64); + var type = expectedType as NubFloatType ?? new NubFloatType(64); return type.Width == 32 ? new Float32LiteralNode(type, float.Parse(expression.Value, CultureInfo.InvariantCulture)) : new Float64LiteralNode(type, double.Parse(expression.Value, CultureInfo.InvariantCulture)); @@ -682,14 +701,14 @@ public sealed class TypeChecker { return expectedType switch { - CStringTypeNode => new CStringLiteralNode(expectedType, expression.Value), - StringTypeNode => new StringLiteralNode(expectedType, expression.Value), - _ => new CStringLiteralNode(new CStringTypeNode(), expression.Value) + NubCStringType => new CStringLiteralNode(expectedType, expression.Value), + NubStringType => new StringLiteralNode(expectedType, expression.Value), + _ => new CStringLiteralNode(new NubCStringType(), expression.Value) }; } case LiteralKind.Bool: { - return new BoolLiteralNode(new BoolTypeNode(), bool.Parse(expression.Value)); + return new BoolLiteralNode(new NubBoolType(), bool.Parse(expression.Value)); } default: { @@ -702,7 +721,7 @@ public sealed class TypeChecker { var target = CheckExpression(expression.Target); - if (target.Type is not StructTypeNode structType) + if (target.Type is not NubStructType structType) { throw new TypeCheckerException(Diagnostic .Error($"Cannot access struct member on non-struct type {target.Type}") @@ -719,24 +738,24 @@ public sealed class TypeChecker .Build()); } - return new StructFieldAccessNode(field.Type, structType, target, expression.Member); + return new StructFieldAccessNode(field.Type, target, expression.Member); } - private StructInitializerNode CheckStructInitializer(StructInitializerSyntax expression, TypeNode? expectedType) + private StructInitializerNode CheckStructInitializer(StructInitializerSyntax expression, NubType? expectedType) { - StructTypeNode? structType = null; + NubStructType? structType = null; if (expression.StructType.TryGetValue(out var customType)) { var checkedType = ResolveType(customType); - if (checkedType is not StructTypeNode checkedStructType) + if (checkedType is not NubStructType checkedStructType) { throw new UnreachableException("Parser fucked up"); } structType = checkedStructType; } - else if (expectedType is StructTypeNode expectedStructType) + else if (expectedType is NubStructType expectedStructType) { structType = expectedStructType; } @@ -788,7 +807,7 @@ public sealed class TypeChecker { var statements = new List(); - _scopes.Push(scope ?? Scope.SubScope()); + _scopes.Push(scope ?? CurrentScope.SubScope()); var reachable = true; var warnedUnreachable = false; @@ -831,87 +850,27 @@ public sealed class TypeChecker }; } - private TypeNode ResolveType(TypeSyntax type) + private NubType ResolveType(TypeSyntax type) { return type switch { - BoolTypeSyntax => new BoolTypeNode(), - CStringTypeSyntax => new CStringTypeNode(), - IntTypeSyntax i => new IntTypeNode(i.Signed, i.Width), + BoolTypeSyntax => new NubBoolType(), + CStringTypeSyntax => new NubCStringType(), + IntTypeSyntax i => new NubIntType(i.Signed, i.Width), CustomTypeSyntax c => ResolveCustomType(c), - FloatTypeSyntax f => new FloatTypeNode(f.Width), - FuncTypeSyntax func => new FuncTypeNode(func.Parameters.Select(ResolveType).ToList(), ResolveType(func.ReturnType)), - ArrayTypeSyntax arr => new ArrayTypeNode(ResolveType(arr.BaseType)), - PointerTypeSyntax ptr => new PointerTypeNode(ResolveType(ptr.BaseType)), - StringTypeSyntax => new StringTypeNode(), - TemplateTypeSyntax template => ResolveTemplateType(template), - VoidTypeSyntax => new VoidTypeNode(), + FloatTypeSyntax f => new NubFloatType(f.Width), + FuncTypeSyntax func => new NubFuncType(func.Parameters.Select(ResolveType).ToList(), ResolveType(func.ReturnType)), + ArrayTypeSyntax arr => new NubArrayType(ResolveType(arr.BaseType)), + PointerTypeSyntax ptr => new NubPointerType(ResolveType(ptr.BaseType)), + StringTypeSyntax => new NubStringType(), + TemplateTypeSyntax template => throw new NotImplementedException(), + VoidTypeSyntax => new NubVoidType(), _ => throw new NotSupportedException($"Unknown type syntax: {type}") }; } - private StructTypeNode ResolveTemplateType(TemplateTypeSyntax template) + private NubType ResolveCustomType(CustomTypeSyntax customType) { - // todo(nub31): Add module support for template types - var definition = _syntaxTree.Definitions - .OfType() - .FirstOrDefault(x => x.Name == template.Name); - - if (definition == null) - { - throw new TypeCheckerException(Diagnostic.Error($"Template {template.Name} does not exist").At(template).Build()); - } - - if (definition.TemplateArguments.Count != template.TemplateParameters.Count) - { - throw new TypeCheckerException(Diagnostic - .Error($"Template {template.Name} has {definition.TemplateArguments.Count} arguments, but usage only has {template.TemplateParameters.Count} parameters") - .At(template) - .Build()); - } - - var scope = new Scope(); - - for (var i = 0; i < definition.TemplateArguments.Count; i++) - { - scope.DeclareGenericType(definition.TemplateArguments[i], ResolveType(template.TemplateParameters[i])); - } - - _scopes.Push(scope); - - var fields = definition.Fields - .Select(x => new StructTypeField(x.Name, ResolveType(x.Type), x.Value.HasValue)) - .ToList(); - - var functions = definition.Functions - .Select(x => new StructTypeFunc(x.Name, x.Hook, new FuncTypeNode(x.Signature.Parameters.Select(y => ResolveType(y.Type)).ToList(), ResolveType(x.Signature.ReturnType)))) - .ToList(); - - var name = $"{template.Name}.{NameMangler.Mangle(template.TemplateParameters.Select(ResolveType))}"; - - var type = new StructTypeNode(template.Module, name, fields, functions); - - var checkedFields = definition.Fields.Select(CheckStructField).ToList(); - var checkedFunctions = definition.Functions.Select(x => CheckStructFunc(type, x, scope)).ToList(); - - Definitions.Add(new StructNode(type, template.Module, name, checkedFields, checkedFunctions)); - - _scopes.Pop(); - - return type; - } - - private TypeNode ResolveCustomType(CustomTypeSyntax customType) - { - if (_syntaxTree.Metadata.ModuleName == customType.Module && _scopes.TryPeek(out var scope)) - { - var generic = scope.LookupGenericType(customType.Name); - if (generic != null) - { - return generic; - } - } - var key = (customType.Module, customType.Name); if (_typeCache.TryGetValue(key, out var cachedType)) @@ -921,7 +880,7 @@ public sealed class TypeChecker if (!_resolvingTypes.Add(key)) { - var placeholder = new StructTypeNode(customType.Module, customType.Name, [], []); + var placeholder = new NubStructType(customType.Module, customType.Name, [], []); _typeCache[key] = placeholder; return placeholder; } @@ -937,22 +896,21 @@ public sealed class TypeChecker .Build()); } - var includePrivate = customType.Module == _syntaxTree.Metadata.ModuleName; + var includePrivate = customType.Module == CurrentModule; var structType = module.StructTypes(includePrivate).FirstOrDefault(x => x.Name == customType.Name); if (structType != null) { - var result = new StructTypeNode(customType.Module, structType.Name, [], []); + var result = new NubStructType(customType.Module, structType.Name, [], []); _typeCache[key] = result; - var fields = structType.Fields.Select(x => new StructTypeField(x.Name, ResolveType(x.Type), x.HasDefaultValue)).ToList(); + var fields = structType.Fields.Select(x => new NubStructFieldType(x.Name, ResolveType(x.Type), x.HasDefaultValue)).ToList(); result.Fields.AddRange(fields); foreach (var function in structType.Functions) { var parameters = function.Parameters.Select(x => ResolveType(x.Type)).ToList(); - var type = new FuncTypeNode(parameters, ResolveType(function.ReturnType)); - result.Functions.Add(new StructTypeFunc(function.Name, function.Hook, type)); + result.Functions.Add(new NubStructFuncType(function.Name, function.Hook, parameters, ResolveType(function.ReturnType))); } ReferencedStructTypes.Add(result); @@ -973,16 +931,15 @@ public sealed class TypeChecker public enum VariableKind { - Variable, - FunctionParameter + LValue, + RValue } -public record Variable(string Name, TypeNode Type, VariableKind Kind); +public record Variable(string Name, NubType Type, VariableKind Kind); public class Scope(Scope? parent = null) { private readonly List _variables = []; - private readonly Dictionary _typeArguments = []; public Variable? LookupVariable(string name) { @@ -1000,16 +957,6 @@ public class Scope(Scope? parent = null) _variables.Add(variable); } - public void DeclareGenericType(string typeArgument, TypeNode type) - { - _typeArguments[typeArgument] = type; - } - - public TypeNode? LookupGenericType(string typeArgument) - { - return _typeArguments.GetValueOrDefault(typeArgument); - } - public Scope SubScope() { return new Scope(this); diff --git a/example/src/main.nub b/example/src/main.nub index 024d0a9..12ded5a 100644 --- a/example/src/main.nub +++ b/example/src/main.nub @@ -11,49 +11,49 @@ struct Human extern "main" func main(args: []cstring): i64 { - let x: ref = {} + // let x: ref = {} - test(x) + // test(x) return 0 } -func test(x: ref) -{ +// func test(x: ref) +// { -} +// } -struct ref -{ - value: ^T - count: ^u64 +// struct ref +// { +// value: ^T +// count: ^u64 - @oncreate - func on_create() - { - puts("on_create") - this.value = @interpret(^T, malloc(@size(T))) - this.count = @interpret(^u64, malloc(@size(u64))) - this.count^ = 1 - } +// @oncreate +// func on_create() +// { +// puts("on_create") +// this.value = @interpret(^T, malloc(@size(T))) +// this.count = @interpret(^u64, malloc(@size(u64))) +// this.count^ = 1 +// } - @oncopy - func on_copy() - { - puts("on_copy") - this.count^ = this.count^ + 1 - } +// @oncopy +// func on_copy() +// { +// puts("on_copy") +// this.count^ = this.count^ + 1 +// } - @ondestroy - func on_destroy() - { - puts("on_destroy") - this.count^ = this.count^ - 1 - if this.count^ <= 0 - { - puts("free") - free(@interpret(^void, this.value)) - free(@interpret(^void, this.count)) - } - } -} \ No newline at end of file +// @ondestroy +// func on_destroy() +// { +// puts("on_destroy") +// this.count^ = this.count^ - 1 +// if this.count^ <= 0 +// { +// puts("free") +// free(@interpret(^void, this.value)) +// free(@interpret(^void, this.count)) +// } +// } +// } \ No newline at end of file