From 6c56404f1c4f28d74d1d77762f46c46c47695a23 Mon Sep 17 00:00:00 2001 From: nub31 Date: Sat, 20 Sep 2025 18:17:40 +0200 Subject: [PATCH] ... --- .../NubLang/Generation/QBE/QBEGenerator.cs | 1 + compiler/NubLang/Modules/ModuleRepository.cs | 5 + compiler/NubLang/Parsing/Parser.cs | 72 +++++- .../Parsing/Syntax/DefinitionSyntax.cs | 4 +- .../Parsing/Syntax/ExpressionSyntax.cs | 6 +- compiler/NubLang/Parsing/Syntax/TypeSyntax.cs | 4 +- .../TypeChecking/Node/ExpressionNode.cs | 4 +- .../NubLang/TypeChecking/Node/TypeNode.cs | 36 ++- compiler/NubLang/TypeChecking/TypeChecker.cs | 225 +++++++++++++----- example/src/main.nub | 20 +- 10 files changed, 292 insertions(+), 85 deletions(-) diff --git a/compiler/NubLang/Generation/QBE/QBEGenerator.cs b/compiler/NubLang/Generation/QBE/QBEGenerator.cs index 588fad7..bed24a1 100644 --- a/compiler/NubLang/Generation/QBE/QBEGenerator.cs +++ b/compiler/NubLang/Generation/QBE/QBEGenerator.cs @@ -693,6 +693,7 @@ public class QBEGenerator StructFuncCallNode expr => EmitStructFuncCall(expr), StructInitializerNode expr => EmitStructInitializer(expr), UnaryExpressionNode expr => EmitUnaryExpression(expr), + SizeCompilerMacroNode expr => $"{SizeOf(expr.TargetType)}", _ => throw new ArgumentOutOfRangeException(nameof(rValue)) }; } diff --git a/compiler/NubLang/Modules/ModuleRepository.cs b/compiler/NubLang/Modules/ModuleRepository.cs index c4a209c..30ea56a 100644 --- a/compiler/NubLang/Modules/ModuleRepository.cs +++ b/compiler/NubLang/Modules/ModuleRepository.cs @@ -42,6 +42,11 @@ public class ModuleRepository module.RegisterStruct(structDef.Exported, structDef.Name, fields, functions); break; } + case StructTemplateSyntax structDef: + { + // todo(nub31): Include templates in modules + break; + } default: { throw new ArgumentOutOfRangeException(nameof(definition)); diff --git a/compiler/NubLang/Parsing/Parser.cs b/compiler/NubLang/Parsing/Parser.cs index dee4fb7..d0dae35 100644 --- a/compiler/NubLang/Parsing/Parser.cs +++ b/compiler/NubLang/Parsing/Parser.cs @@ -163,10 +163,20 @@ public sealed class Parser return new FuncSyntax(GetTokens(startIndex), name.Value, exported, externSymbol, signature, body); } - private StructSyntax ParseStruct(int startIndex, bool exported) + private DefinitionSyntax ParseStruct(int startIndex, bool exported) { var name = ExpectIdentifier(); + var templateArguments = new List(); + if (TryExpectSymbol(Symbol.LessThan)) + { + while (!TryExpectSymbol(Symbol.GreaterThan)) + { + templateArguments.Add(ExpectIdentifier().Value); + TryExpectSymbol(Symbol.Comma); + } + } + ExpectSymbol(Symbol.OpenBrace); List fields = []; @@ -207,6 +217,11 @@ public sealed class Parser } } + if (templateArguments.Count > 0) + { + return new StructTemplateSyntax(GetTokens(startIndex), templateArguments, name.Value, exported, fields, funcs); + } + return new StructSyntax(GetTokens(startIndex), name.Value, exported, fields, funcs); } @@ -462,6 +477,7 @@ public sealed class Parser Symbol.OpenBracket => ParseArrayInitializer(startIndex), Symbol.OpenBrace => new StructInitializerSyntax(GetTokens(startIndex), Optional.Empty(), ParseStructInitializerBody()), Symbol.Struct => ParseStructInitializer(startIndex), + Symbol.At => ParseCompilerMacro(startIndex), _ => throw new ParseException(Diagnostic .Error($"Unexpected symbol '{symbolToken.Symbol}' in expression") .WithHelp("Expected '(', '-', '!', '[' or '{'") @@ -478,6 +494,34 @@ public sealed class Parser return ParsePostfixOperators(expr); } + private ExpressionSyntax ParseCompilerMacro(int startIndex) + { + var name = ExpectIdentifier(); + ExpectSymbol(Symbol.OpenParen); + + switch (name.Value) + { + case "size": + { + var type = ParseType(); + ExpectSymbol(Symbol.CloseParen); + return new SizeCompilerMacroSyntax(GetTokens(startIndex), type); + } + case "interpret": + { + var type = ParseType(); + ExpectSymbol(Symbol.Comma); + var expression = ParseExpression(); + ExpectSymbol(Symbol.CloseParen); + return new InterpretCompilerMacroSyntax(GetTokens(startIndex), type, expression); + } + default: + { + throw new ParseException(Diagnostic.Error("Unknown compiler macro").At(name).Build()); + } + } + } + private ExpressionSyntax ParseIdentifier(int startIndex, IdentifierToken identifier) { if (TryExpectSymbol(Symbol.DoubleColon)) @@ -612,7 +656,7 @@ public sealed class Parser ExpectSymbol(Symbol.OpenBrace); return ParseBlock(startIndex); } - + private BlockSyntax ParseBlock(int startIndex) { List statements = []; @@ -699,13 +743,31 @@ public sealed class Parser return new BoolTypeSyntax(GetTokens(startIndex)); default: { + var module = _moduleName; + if (TryExpectSymbol(Symbol.DoubleColon)) { - var customTypeName = ExpectIdentifier().Value; - return new CustomTypeSyntax(GetTokens(startIndex), name.Value, customTypeName); + var customTypeName = ExpectIdentifier(); + module = name.Value; + name = customTypeName; } - return new CustomTypeSyntax(GetTokens(startIndex), _moduleName, name.Value); + var templateParameters = new List(); + if (TryExpectSymbol(Symbol.LessThan)) + { + while (!TryExpectSymbol(Symbol.GreaterThan)) + { + templateParameters.Add(ParseType()); + TryExpectSymbol(Symbol.Comma); + } + } + + if (templateParameters.Count > 0) + { + return new TemplateTypeSyntax(GetTokens(startIndex), templateParameters, module, name.Value); + } + + return new CustomTypeSyntax(GetTokens(startIndex), module, name.Value); } } } diff --git a/compiler/NubLang/Parsing/Syntax/DefinitionSyntax.cs b/compiler/NubLang/Parsing/Syntax/DefinitionSyntax.cs index 1321433..c9c0651 100644 --- a/compiler/NubLang/Parsing/Syntax/DefinitionSyntax.cs +++ b/compiler/NubLang/Parsing/Syntax/DefinitionSyntax.cs @@ -14,4 +14,6 @@ public record StructFieldSyntax(IEnumerable Tokens, string Name, TypeSynt public record StructFuncSyntax(IEnumerable Tokens, string Name, string? Hook, FuncSignatureSyntax Signature, BlockSyntax Body) : SyntaxNode(Tokens); -public record StructSyntax(IEnumerable Tokens, string Name, bool Exported, List Fields, List Functions) : DefinitionSyntax(Tokens, Name, Exported); \ No newline at end of file +public record StructSyntax(IEnumerable Tokens, string Name, bool Exported, List Fields, List Functions) : DefinitionSyntax(Tokens, Name, Exported); + +public record StructTemplateSyntax(IEnumerable Tokens, List TemplateArguments, string Name, bool Exported, List Fields, List Functions) : DefinitionSyntax(Tokens, Name, Exported); \ No newline at end of file diff --git a/compiler/NubLang/Parsing/Syntax/ExpressionSyntax.cs b/compiler/NubLang/Parsing/Syntax/ExpressionSyntax.cs index ede9750..5d263c1 100644 --- a/compiler/NubLang/Parsing/Syntax/ExpressionSyntax.cs +++ b/compiler/NubLang/Parsing/Syntax/ExpressionSyntax.cs @@ -56,4 +56,8 @@ public record StructFieldAccessSyntax(IEnumerable Tokens, ExpressionSynta public record StructInitializerSyntax(IEnumerable Tokens, Optional StructType, Dictionary Initializers) : ExpressionSyntax(Tokens); -public record DereferenceSyntax(IEnumerable Tokens, ExpressionSyntax Target) : ExpressionSyntax(Tokens); \ No newline at end of file +public record DereferenceSyntax(IEnumerable Tokens, ExpressionSyntax Target) : ExpressionSyntax(Tokens); + +public record SizeCompilerMacroSyntax(IEnumerable Tokens, TypeSyntax Type) : ExpressionSyntax(Tokens); + +public record InterpretCompilerMacroSyntax(IEnumerable Tokens, TypeSyntax Type, ExpressionSyntax Target) : ExpressionSyntax(Tokens); \ No newline at end of file diff --git a/compiler/NubLang/Parsing/Syntax/TypeSyntax.cs b/compiler/NubLang/Parsing/Syntax/TypeSyntax.cs index 4f18cd4..809a1fa 100644 --- a/compiler/NubLang/Parsing/Syntax/TypeSyntax.cs +++ b/compiler/NubLang/Parsing/Syntax/TypeSyntax.cs @@ -22,4 +22,6 @@ public record CStringTypeSyntax(IEnumerable Tokens) : TypeSyntax(Tokens); public record ArrayTypeSyntax(IEnumerable Tokens, TypeSyntax BaseType) : TypeSyntax(Tokens); -public record CustomTypeSyntax(IEnumerable Tokens, string Module, string Name) : TypeSyntax(Tokens); \ No newline at end of file +public record CustomTypeSyntax(IEnumerable Tokens, string Module, string Name) : TypeSyntax(Tokens); + +public record TemplateTypeSyntax(IEnumerable Tokens, List TemplateParameters, string Module, string Name) : TypeSyntax(Tokens); \ No newline at end of file diff --git a/compiler/NubLang/TypeChecking/Node/ExpressionNode.cs b/compiler/NubLang/TypeChecking/Node/ExpressionNode.cs index df4824f..27db6d2 100644 --- a/compiler/NubLang/TypeChecking/Node/ExpressionNode.cs +++ b/compiler/NubLang/TypeChecking/Node/ExpressionNode.cs @@ -76,4 +76,6 @@ public record DereferenceNode(TypeNode Type, ExpressionNode Expression) : LValue public record ConvertIntNode(TypeNode Type, ExpressionNode Value, IntTypeNode ValueType, IntTypeNode TargetType) : RValueExpressionNode(Type); -public record ConvertFloatNode(TypeNode Type, ExpressionNode Value, FloatTypeNode ValueType, FloatTypeNode TargetType) : RValueExpressionNode(Type); \ No newline at end of file +public record ConvertFloatNode(TypeNode Type, ExpressionNode Value, FloatTypeNode ValueType, FloatTypeNode TargetType) : RValueExpressionNode(Type); + +public record SizeCompilerMacroNode(TypeNode Type, TypeNode TargetType) : RValueExpressionNode(Type); \ No newline at end of file diff --git a/compiler/NubLang/TypeChecking/Node/TypeNode.cs b/compiler/NubLang/TypeChecking/Node/TypeNode.cs index af43155..dd042f9 100644 --- a/compiler/NubLang/TypeChecking/Node/TypeNode.cs +++ b/compiler/NubLang/TypeChecking/Node/TypeNode.cs @@ -1,4 +1,7 @@ -namespace NubLang.TypeChecking.Node; +using System.Security.Cryptography; +using System.Text; + +namespace NubLang.TypeChecking.Node; public abstract class TypeNode : IEquatable { @@ -156,4 +159,35 @@ public class StringTypeNode : TypeNode public override string ToString() => "string"; public override bool Equals(TypeNode? other) => other is StringTypeNode; public override int GetHashCode() => HashCode.Combine(typeof(StringTypeNode)); +} + +public static class NameMangler +{ + public static string Mangle(params IEnumerable types) + { + var readable = string.Join("_", types.Select(EncodeType)); + return ComputeShortHash(readable); + } + + private static string EncodeType(TypeNode node) => node switch + { + VoidTypeNode => "V", + BoolTypeNode => "B", + IntTypeNode i => (i.Signed ? "I" : "U") + i.Width, + FloatTypeNode f => "F" + f.Width, + CStringTypeNode => "CS", + StringTypeNode => "S", + PointerTypeNode p => "P" + EncodeType(p.BaseType), + ArrayTypeNode a => "A" + EncodeType(a.ElementType), + FuncTypeNode fn => "FN(" + string.Join(",", fn.Parameters.Select(EncodeType)) + ")" + EncodeType(fn.ReturnType), + StructTypeNode st => "ST(" + st.Module + "." + st.Name + ")", + _ => throw new NotSupportedException($"Cannot encode type: {node}") + }; + + private static string ComputeShortHash(string input) + { + var bytes = Encoding.UTF8.GetBytes(input); + var hash = SHA256.HashData(bytes); + return Convert.ToHexString(hash[..8]).ToLower(); + } } \ No newline at end of file diff --git a/compiler/NubLang/TypeChecking/TypeChecker.cs b/compiler/NubLang/TypeChecking/TypeChecker.cs index 60c653d..d865872 100644 --- a/compiler/NubLang/TypeChecking/TypeChecker.cs +++ b/compiler/NubLang/TypeChecking/TypeChecker.cs @@ -1,5 +1,6 @@ using System.Diagnostics; using System.Globalization; +using System.Security.Cryptography; using NubLang.Diagnostics; using NubLang.Modules; using NubLang.Parsing.Syntax; @@ -15,6 +16,8 @@ public sealed class TypeChecker private readonly Stack _scopes = []; private readonly Stack _funcReturnTypes = []; + private readonly Dictionary<(string Module, string Name), TypeNode> _typeCache = new(); + private readonly HashSet<(string Module, string Name)> _resolvingTypes = []; private Scope Scope => _scopes.Peek(); @@ -38,12 +41,26 @@ public sealed class TypeChecker Diagnostics.Clear(); Definitions.Clear(); ReferencedStructTypes.Clear(); + _typeCache.Clear(); + _resolvingTypes.Clear(); foreach (var definition in _syntaxTree.Definitions) { try { - Definitions.Add(CheckDefinition(definition)); + switch (definition) + { + case FuncSyntax funcSyntax: + Definitions.Add(CheckFuncDefinition(funcSyntax)); + break; + case StructSyntax structSyntax: + Definitions.Add(CheckStructDefinition(structSyntax)); + break; + case StructTemplateSyntax: + break; + default: + throw new ArgumentOutOfRangeException(nameof(definition)); + } } catch (TypeCheckerException e) { @@ -52,16 +69,6 @@ public sealed class TypeChecker } } - private DefinitionNode CheckDefinition(DefinitionSyntax node) - { - return node switch - { - FuncSyntax definition => CheckFuncDefinition(definition), - StructSyntax definition => CheckStructDefinition(definition), - _ => throw new ArgumentOutOfRangeException(nameof(node)) - }; - } - private StructNode CheckStructDefinition(StructSyntax node) { var fieldTypes = node.Fields @@ -77,45 +84,45 @@ public sealed class TypeChecker } var type = new StructTypeNode(_syntaxTree.Metadata.ModuleName, node.Name, fieldTypes, functionTypes); - - var fields = new List(); - foreach (var field in node.Fields) - { - var value = Optional.Empty(); - if (field.Value.HasValue) - { - value = CheckExpression(field.Value.Value, ResolveType(field.Type)); - } - - fields.Add(new StructFieldNode(field.Name, ResolveType(field.Type), value)); - } - - var functions = new List(); - foreach (var function in node.Functions) - { - var scope = new Scope(); - scope.Declare(new Identifier("this", type, IdentifierKind.FunctionParameter)); - - foreach (var parameter in function.Signature.Parameters) - { - scope.Declare(new Identifier(parameter.Name, ResolveType(parameter.Type), IdentifierKind.FunctionParameter)); - } - - _funcReturnTypes.Push(ResolveType(function.Signature.ReturnType)); - var body = CheckBlock(function.Body, scope); - _funcReturnTypes.Pop(); - functions.Add(new StructFuncNode(function.Name, CheckFuncSignature(function.Signature), body)); - } + var fields = node.Fields.Select(CheckStructField).ToList(); + var functions = node.Functions.Select(x => CheckStructFunc(type, x)).ToList(); return new StructNode(type, _syntaxTree.Metadata.ModuleName, node.Name, fields, functions); } + private StructFuncNode CheckStructFunc(StructTypeNode type, StructFuncSyntax function, Scope? scope = null) + { + scope ??= new Scope(); + scope.DeclareVariable(new Variable("this", type, VariableKind.FunctionParameter)); + + foreach (var parameter in function.Signature.Parameters) + { + scope.DeclareVariable(new Variable(parameter.Name, ResolveType(parameter.Type), VariableKind.FunctionParameter)); + } + + _funcReturnTypes.Push(ResolveType(function.Signature.ReturnType)); + var body = CheckBlock(function.Body, scope); + _funcReturnTypes.Pop(); + return new StructFuncNode(function.Name, CheckFuncSignature(function.Signature), body); + } + + private StructFieldNode CheckStructField(StructFieldSyntax field) + { + var value = Optional.Empty(); + if (field.Value.HasValue) + { + value = CheckExpression(field.Value.Value, ResolveType(field.Type)); + } + + return new StructFieldNode(field.Name, ResolveType(field.Type), value); + } + private FuncNode CheckFuncDefinition(FuncSyntax node) { var scope = new Scope(); foreach (var parameter in node.Signature.Parameters) { - scope.Declare(new Identifier(parameter.Name, ResolveType(parameter.Type), IdentifierKind.FunctionParameter)); + scope.DeclareVariable(new Variable(parameter.Name, ResolveType(parameter.Type), VariableKind.FunctionParameter)); } var signature = CheckFuncSignature(node.Signature); @@ -236,7 +243,7 @@ public sealed class TypeChecker throw new TypeCheckerException(Diagnostic.Error($"Cannot infer type of variable {statement.Name}").At(statement).Build()); } - Scope.Declare(new Identifier(statement.Name, type, IdentifierKind.Variable)); + Scope.DeclareVariable(new Variable(statement.Name, type, VariableKind.Variable)); return new VariableDeclarationNode(statement.Name, Optional.OfNullable(assignmentNode), type); } @@ -262,10 +269,10 @@ public sealed class TypeChecker case ArrayTypeNode arrayType: { var scope = Scope.SubScope(); - scope.Declare(new Identifier(statement.ElementIdent, arrayType.ElementType, IdentifierKind.FunctionParameter)); + scope.DeclareVariable(new Variable(statement.ElementIdent, arrayType.ElementType, VariableKind.FunctionParameter)); if (statement.IndexIdent != null) { - scope.Declare(new Identifier(statement.ElementIdent, new IntTypeNode(true, 64), IdentifierKind.FunctionParameter)); + scope.DeclareVariable(new Variable(statement.ElementIdent, new IntTypeNode(true, 64), VariableKind.FunctionParameter)); } var body = CheckBlock(statement.Body, scope); @@ -273,7 +280,7 @@ public sealed class TypeChecker } default: { - throw new TypeCheckerException(Diagnostic.Error($"Type {target.Type} is not an iterable target").Build()); + throw new TypeCheckerException(Diagnostic.Error($"Type {target.Type} is not an iterable target").At(statement.Target).Build()); } } } @@ -306,6 +313,8 @@ public sealed class TypeChecker LiteralSyntax expression => CheckLiteral(expression, expectedType), StructFieldAccessSyntax expression => CheckStructFieldAccess(expression), StructInitializerSyntax expression => CheckStructInitializer(expression, expectedType), + InterpretCompilerMacroSyntax expression => CheckExpression(expression.Target) with { Type = ResolveType(expression.Type) }, + SizeCompilerMacroSyntax expression => new SizeCompilerMacroNode(new IntTypeNode(false, 64), ResolveType(expression.Type)), _ => throw new ArgumentOutOfRangeException(nameof(node)) }; @@ -338,7 +347,7 @@ public sealed class TypeChecker var target = CheckExpression(expression.Target); if (target is not LValueExpressionNode lvalue) { - throw new TypeCheckerException(Diagnostic.Error("Cannot take address of an rvalue").Build()); + throw new TypeCheckerException(Diagnostic.Error("Cannot take address of an rvalue").At(expression).Build()); } var type = new PointerTypeNode(target.Type); @@ -469,7 +478,7 @@ public sealed class TypeChecker var operand = CheckExpression(expression.Operand); if (operand.Type is not IntTypeNode { Signed: false } or FloatTypeNode) { - throw new TypeCheckerException(Diagnostic.Error("Negation operator must be used with signed integer or float types").Build()); + throw new TypeCheckerException(Diagnostic.Error("Negation operator must be used with signed integer or float types").At(expression).Build()); } return new UnaryExpressionNode(operand.Type, UnaryOperator.Negate, operand); @@ -479,7 +488,7 @@ public sealed class TypeChecker var operand = CheckExpression(expression.Operand); if (operand.Type is not BoolTypeNode) { - throw new TypeCheckerException(Diagnostic.Error("Invert operator must be used with booleans").Build()); + throw new TypeCheckerException(Diagnostic.Error("Invert operator must be used with booleans").At(expression).Build()); } return new UnaryExpressionNode(operand.Type, UnaryOperator.Invert, operand); @@ -580,22 +589,25 @@ public sealed class TypeChecker return new StructFuncCallNode(function.Type.ReturnType, expression.Name, structType, target, parameters); } - throw new TypeCheckerException(Diagnostic.Error($"No function {expression.Name} exists on type {target.Type}").Build()); + throw new TypeCheckerException(Diagnostic + .Error($"No function {expression.Name} exists on type {target.Type}") + .At(expression) + .Build()); } private ExpressionNode CheckLocalIdentifier(LocalIdentifierSyntax expression) { // First, look in the current scope for a matching identifier - var scopeIdent = Scope.Lookup(expression.Name); + var scopeIdent = Scope.LookupVariable(expression.Name); if (scopeIdent != null) { switch (scopeIdent.Kind) { - case IdentifierKind.Variable: + case VariableKind.Variable: { return new VariableIdentifierNode(scopeIdent.Type, expression.Name); } - case IdentifierKind.FunctionParameter: + case VariableKind.FunctionParameter: { return new FuncParameterIdentifierNode(scopeIdent.Type, expression.Name); } @@ -624,7 +636,11 @@ public sealed class TypeChecker { if (!_visibleModules.TryGetValue(expression.Module, out var module)) { - throw new TypeCheckerException(Diagnostic.Error($"Module {expression.Module} not found").WithHelp($"import \"{expression.Module}\"").At(expression).Build()); + throw new TypeCheckerException(Diagnostic + .Error($"Module {expression.Module} not found") + .WithHelp($"import \"{expression.Module}\"") + .At(expression) + .Build()); } var includePrivate = expression.Module == _syntaxTree.Metadata.ModuleName; @@ -638,7 +654,10 @@ public sealed class TypeChecker return new FuncIdentifierNode(type, expression.Module, expression.Name, function.ExternSymbol); } - throw new TypeCheckerException(Diagnostic.Error($"No exported symbol {expression.Name} not found in module {expression.Module}").At(expression).Build()); + throw new TypeCheckerException(Diagnostic + .Error($"No exported symbol {expression.Name} not found in module {expression.Module}") + .At(expression) + .Build()); } private ExpressionNode CheckLiteral(LiteralSyntax expression, TypeNode? expectedType) @@ -825,16 +844,74 @@ public sealed class TypeChecker ArrayTypeSyntax arr => new ArrayTypeNode(ResolveType(arr.BaseType)), PointerTypeSyntax ptr => new PointerTypeNode(ResolveType(ptr.BaseType)), StringTypeSyntax => new StringTypeNode(), + TemplateTypeSyntax template => ResolveTemplateType(template), VoidTypeSyntax => new VoidTypeNode(), _ => throw new NotSupportedException($"Unknown type syntax: {type}") }; } - private readonly Dictionary<(string Module, string Name), TypeNode> _typeCache = new(); - private readonly HashSet<(string Module, string Name)> _resolvingTypes = []; + private StructTypeNode ResolveTemplateType(TemplateTypeSyntax template) + { + // todo(nub31): Add module support for template types + var definition = _syntaxTree.Definitions + .OfType() + .FirstOrDefault(x => x.Name == template.Name); + + if (definition == null) + { + throw new TypeCheckerException(Diagnostic.Error($"Template {template.Name} does not exist").At(template).Build()); + } + + if (definition.TemplateArguments.Count != template.TemplateParameters.Count) + { + throw new TypeCheckerException(Diagnostic + .Error($"Template {template.Name} has {definition.TemplateArguments.Count} arguments, but usage only has {template.TemplateParameters.Count} parameters") + .At(template) + .Build()); + } + + var scope = new Scope(); + + for (var i = 0; i < definition.TemplateArguments.Count; i++) + { + scope.DeclareGenericType(definition.TemplateArguments[i], ResolveType(template.TemplateParameters[i])); + } + + _scopes.Push(scope); + + var fields = definition.Fields + .Select(x => new StructTypeField(x.Name, ResolveType(x.Type), x.Value.HasValue)) + .ToList(); + + var functions = definition.Functions + .Select(x => new StructTypeFunc(x.Name, x.Hook, new FuncTypeNode(x.Signature.Parameters.Select(y => ResolveType(y.Type)).ToList(), ResolveType(x.Signature.ReturnType)))) + .ToList(); + + var name = $"{template.Name}.{NameMangler.Mangle(template.TemplateParameters.Select(ResolveType))}"; + + var type = new StructTypeNode(template.Module, name, fields, functions); + + var checkedFields = definition.Fields.Select(CheckStructField).ToList(); + var checkedFunctions = definition.Functions.Select(x => CheckStructFunc(type, x, scope)).ToList(); + + Definitions.Add(new StructNode(type, template.Module, name, checkedFields, checkedFunctions)); + + _scopes.Pop(); + + return type; + } private TypeNode ResolveCustomType(CustomTypeSyntax customType) { + if (_syntaxTree.Metadata.ModuleName == customType.Module && _scopes.TryPeek(out var scope)) + { + var generic = scope.LookupGenericType(customType.Name); + if (generic != null) + { + return generic; + } + } + var key = (customType.Module, customType.Name); if (_typeCache.TryGetValue(key, out var cachedType)) @@ -853,7 +930,11 @@ public sealed class TypeChecker { if (!_visibleModules.TryGetValue(customType.Module, out var module)) { - throw new TypeCheckerException(Diagnostic.Error($"Module {customType.Module} not found").WithHelp($"import \"{customType.Module}\"").At(customType).Build()); + throw new TypeCheckerException(Diagnostic + .Error($"Module {customType.Module} not found") + .WithHelp($"import \"{customType.Module}\"") + .At(customType) + .Build()); } var includePrivate = customType.Module == _syntaxTree.Metadata.ModuleName; @@ -878,7 +959,10 @@ public sealed class TypeChecker return result; } - throw new TypeCheckerException(Diagnostic.Error($"Type {customType.Name} not found in module {customType.Module}").At(customType).Build()); + throw new TypeCheckerException(Diagnostic + .Error($"Type {customType.Name} not found in module {customType.Module}") + .At(customType) + .Build()); } finally { @@ -887,19 +971,20 @@ public sealed class TypeChecker } } -public enum IdentifierKind +public enum VariableKind { Variable, FunctionParameter } -public record Identifier(string Name, TypeNode Type, IdentifierKind Kind); +public record Variable(string Name, TypeNode Type, VariableKind Kind); public class Scope(Scope? parent = null) { - private readonly List _variables = []; + private readonly List _variables = []; + private readonly Dictionary _typeArguments = []; - public Identifier? Lookup(string name) + public Variable? LookupVariable(string name) { var variable = _variables.FirstOrDefault(x => x.Name == name); if (variable != null) @@ -907,12 +992,22 @@ public class Scope(Scope? parent = null) return variable; } - return parent?.Lookup(name); + return parent?.LookupVariable(name); } - public void Declare(Identifier identifier) + public void DeclareVariable(Variable variable) { - _variables.Add(identifier); + _variables.Add(variable); + } + + public void DeclareGenericType(string typeArgument, TypeNode type) + { + _typeArguments[typeArgument] = type; + } + + public TypeNode? LookupGenericType(string typeArgument) + { + return _typeArguments.GetValueOrDefault(typeArgument); } public Scope SubScope() diff --git a/example/src/main.nub b/example/src/main.nub index fa30e46..024d0a9 100644 --- a/example/src/main.nub +++ b/example/src/main.nub @@ -1,8 +1,8 @@ module "main" extern "puts" func puts(text: cstring) -extern "malloc" func malloc(size: u64): ^u64 -extern "free" func free(address: ^u64) +extern "malloc" func malloc(size: u64): ^void +extern "free" func free(address: ^void) struct Human { @@ -11,29 +11,29 @@ struct Human extern "main" func main(args: []cstring): i64 { - let x: ref = {} + let x: ref = {} test(x) return 0 } -func test(x: ref) +func test(x: ref) { } -struct ref +struct ref { - value: ^u64 + value: ^T count: ^u64 @oncreate func on_create() { puts("on_create") - this.value = malloc(8) - this.count = malloc(8) + this.value = @interpret(^T, malloc(@size(T))) + this.count = @interpret(^u64, malloc(@size(u64))) this.count^ = 1 } @@ -52,8 +52,8 @@ struct ref if this.count^ <= 0 { puts("free") - free(this.value) - free(this.count) + free(@interpret(^void, this.value)) + free(@interpret(^void, this.count)) } } } \ No newline at end of file