From d0ad3617764660c77ca17c1b71c476969b6f44b4 Mon Sep 17 00:00:00 2001 From: nub31 Date: Sun, 21 Sep 2025 21:56:59 +0200 Subject: [PATCH] ... --- compiler/NubLang/Parsing/Parser.cs | 22 +- compiler/NubLang/Parsing/Syntax/TypeSyntax.cs | 4 +- .../TypeChecking/Node/DefinitionNode.cs | 4 +- compiler/NubLang/TypeChecking/TypeChecker.cs | 328 ++++++++++++------ example/makefile | 4 +- example/src/main.nub | 46 +-- example/src/ref.nub | 40 +++ 7 files changed, 284 insertions(+), 164 deletions(-) create mode 100644 example/src/ref.nub diff --git a/compiler/NubLang/Parsing/Parser.cs b/compiler/NubLang/Parsing/Parser.cs index 3926c27..8d1dace 100644 --- a/compiler/NubLang/Parsing/Parser.cs +++ b/compiler/NubLang/Parsing/Parser.cs @@ -11,6 +11,7 @@ public sealed class Parser private List _tokens = []; private int _tokenIndex; private string _moduleName = string.Empty; + private HashSet _templateArguments = []; private Token? CurrentToken => _tokenIndex < _tokens.Count ? _tokens[_tokenIndex] : null; private bool HasToken => CurrentToken != null; @@ -26,6 +27,7 @@ public sealed class Parser _tokens = tokens; _tokenIndex = 0; _moduleName = string.Empty; + _templateArguments.Clear(); var metadata = ParseMetadata(); var definitions = ParseDefinitions(); @@ -39,13 +41,13 @@ public sealed class Parser try { - ExpectSymbol(Symbol.Module); - _moduleName = ExpectLiteral(LiteralKind.String).Value; - while (TryExpectSymbol(Symbol.Import)) { imports.Add(ExpectLiteral(LiteralKind.String).Value); } + + ExpectSymbol(Symbol.Module); + _moduleName = ExpectLiteral(LiteralKind.String).Value; } catch (ParseException e) { @@ -167,12 +169,11 @@ public sealed class Parser { var name = ExpectIdentifier(); - var templateArguments = new List(); if (TryExpectSymbol(Symbol.LessThan)) { while (!TryExpectSymbol(Symbol.GreaterThan)) { - templateArguments.Add(ExpectIdentifier().Value); + _templateArguments.Add(ExpectIdentifier().Value); TryExpectSymbol(Symbol.Comma); } } @@ -217,8 +218,10 @@ public sealed class Parser } } - if (templateArguments.Count > 0) + if (_templateArguments.Count > 0) { + var templateArguments = _templateArguments.ToList(); + _templateArguments.Clear(); return new StructTemplateSyntax(GetTokens(startIndex), templateArguments, name.Value, exported, fields, funcs); } @@ -670,6 +673,11 @@ public sealed class Parser var startIndex = _tokenIndex; if (TryExpectIdentifier(out var name)) { + if (_templateArguments.Contains(name.Value)) + { + return new SubstitutionTypeSyntax(GetTokens(startIndex), name.Value); + } + if (name.Value[0] == 'u' && int.TryParse(name.Value[1..], out var size)) { if (size is not 8 and not 16 and not 32 and not 64) @@ -745,7 +753,7 @@ public sealed class Parser if (templateParameters.Count > 0) { - return new TemplateTypeSyntax(GetTokens(startIndex), templateParameters, module, name.Value); + return new StructTemplateTypeSyntax(GetTokens(startIndex), templateParameters, module, name.Value); } return new CustomTypeSyntax(GetTokens(startIndex), module, name.Value); diff --git a/compiler/NubLang/Parsing/Syntax/TypeSyntax.cs b/compiler/NubLang/Parsing/Syntax/TypeSyntax.cs index 33ff4e7..9eb0afe 100644 --- a/compiler/NubLang/Parsing/Syntax/TypeSyntax.cs +++ b/compiler/NubLang/Parsing/Syntax/TypeSyntax.cs @@ -24,4 +24,6 @@ public record ArrayTypeSyntax(List Tokens, TypeSyntax BaseType) : TypeSyn public record CustomTypeSyntax(List Tokens, string Module, string Name) : TypeSyntax(Tokens); -public record TemplateTypeSyntax(List Tokens, List TemplateParameters, string Module, string Name) : TypeSyntax(Tokens); \ No newline at end of file +public record StructTemplateTypeSyntax(List Tokens, List TemplateParameters, string Module, string Name) : TypeSyntax(Tokens); + +public record SubstitutionTypeSyntax(List Tokens, string Name) : TypeSyntax(Tokens); \ No newline at end of file diff --git a/compiler/NubLang/TypeChecking/Node/DefinitionNode.cs b/compiler/NubLang/TypeChecking/Node/DefinitionNode.cs index 73801fe..d9c6b01 100644 --- a/compiler/NubLang/TypeChecking/Node/DefinitionNode.cs +++ b/compiler/NubLang/TypeChecking/Node/DefinitionNode.cs @@ -12,6 +12,4 @@ public record StructFieldNode(string Name, NubType Type, ExpressionNode? Value) public record StructFuncNode(string Name, string? Hook, FuncSignatureNode Signature, BlockNode Body) : Node; -public record StructNode(string Module, string Name, List Fields, List Functions) : DefinitionNode(Module, Name); - -public record StructTemplateNode(string Module, string Name, List TemplateArguments, List Fields, List Functions) : DefinitionNode(Module, Name); \ No newline at end of file +public record StructNode(string Module, string Name, List Fields, List Functions) : DefinitionNode(Module, Name); \ No newline at end of file diff --git a/compiler/NubLang/TypeChecking/TypeChecker.cs b/compiler/NubLang/TypeChecking/TypeChecker.cs index f078034..36419d1 100644 --- a/compiler/NubLang/TypeChecking/TypeChecker.cs +++ b/compiler/NubLang/TypeChecking/TypeChecker.cs @@ -14,20 +14,19 @@ public sealed class TypeChecker private readonly Dictionary _visibleModules; private readonly Stack _scopes = []; - private Scope _globalScope = new(); private readonly Stack _funcReturnTypes = []; private readonly Dictionary<(string Module, string Name), NubType> _typeCache = new(); private readonly HashSet<(string Module, string Name)> _resolvingTypes = []; + private readonly HashSet _checkedTemplateStructs = []; private Scope CurrentScope => _scopes.Peek(); - private string CurrentModule => _syntaxTree.Metadata.ModuleName; public TypeChecker(SyntaxTree syntaxTree, ModuleRepository moduleRepository) { _syntaxTree = syntaxTree; _visibleModules = moduleRepository .Modules() - .Where(x => syntaxTree.Metadata.Imports.Contains(x.Key) || CurrentModule == x.Key) + .Where(x => syntaxTree.Metadata.Imports.Contains(x.Key) || _syntaxTree.Metadata.ModuleName == x.Key) .ToDictionary(); } @@ -38,10 +37,10 @@ public sealed class TypeChecker public void Check() { _scopes.Clear(); - _globalScope = new Scope(); _funcReturnTypes.Clear(); _typeCache.Clear(); _resolvingTypes.Clear(); + _checkedTemplateStructs.Clear(); Diagnostics.Clear(); Definitions.Clear(); @@ -49,72 +48,87 @@ public sealed class TypeChecker foreach (var definition in _syntaxTree.Definitions) { - BeginScope(true); - try { - Definitions.Add(definition switch + switch (definition) { - FuncSyntax funcSyntax => CheckFuncDefinition(funcSyntax), - StructSyntax structSyntax => CheckStructDefinition(structSyntax), - StructTemplateSyntax structTemplate => CheckStructTemplateDefinition(structTemplate), - _ => throw new ArgumentOutOfRangeException() - }); + case FuncSyntax funcSyntax: + Definitions.Add(CheckFuncDefinition(funcSyntax)); + break; + case StructSyntax structSyntax: + Definitions.Add(CheckStructDefinition(structSyntax)); + break; + case StructTemplateSyntax: + break; + default: + throw new ArgumentOutOfRangeException(); + } } catch (TypeCheckerException e) { Diagnostics.Add(e.Diagnostic); } - - EndScope(); } } - private void BeginScope(bool root) + private ScopeDisposer BeginScope() { - var scope = root - ? _globalScope.SubScope() - : _scopes.Peek().SubScope(); + if (_scopes.TryPeek(out var scope)) + { + _scopes.Push(scope.SubScope()); + } + else + { + _scopes.Push(new Scope(_syntaxTree.Metadata.ModuleName)); + } - _scopes.Push(scope); + return new ScopeDisposer(this); } - private void EndScope() + private ScopeDisposer BeginRootScope(string module) { - _scopes.Pop(); + _scopes.Push(new Scope(module)); + return new ScopeDisposer(this); + } + + private sealed class ScopeDisposer(TypeChecker owner) : IDisposable + { + private bool _disposed; + + public void Dispose() + { + if (_disposed) return; + owner._scopes.Pop(); + _disposed = true; + } } private StructNode CheckStructDefinition(StructSyntax node) { - var fieldTypes = node.Fields - .Select(x => new NubStructFieldType(x.Name, ResolveType(x.Type), x.Value != null)) - .ToList(); + using (BeginRootScope(_syntaxTree.Metadata.ModuleName)) + { + var fieldTypes = node.Fields + .Select(x => new NubStructFieldType(x.Name, ResolveType(x.Type), x.Value != null)) + .ToList(); - var fieldFunctions = node.Functions - .Select(x => - { - var parameters = x.Signature.Parameters.Select(y => ResolveType(y.Type)).ToList(); - var returnType = ResolveType(x.Signature.ReturnType); - return new NubStructFuncType(x.Name, x.Hook, parameters, returnType); - }) - .ToList(); + var fieldFunctions = node.Functions + .Select(x => + { + var parameters = x.Signature.Parameters.Select(y => ResolveType(y.Type)).ToList(); + var returnType = ResolveType(x.Signature.ReturnType); + return new NubStructFuncType(x.Name, x.Hook, parameters, returnType); + }) + .ToList(); - var structType = new NubStructType(CurrentModule, node.Name, fieldTypes, fieldFunctions); + var structType = new NubStructType(CurrentScope.Module, node.Name, fieldTypes, fieldFunctions); - CurrentScope.DeclareVariable(new Variable("this", structType, VariableKind.RValue)); + CurrentScope.DeclareVariable(new Variable("this", structType, VariableKind.RValue)); - var fields = node.Fields.Select(CheckStructField).ToList(); - var functions = node.Functions.Select(CheckStructFunc).ToList(); + var fields = node.Fields.Select(CheckStructField).ToList(); + var functions = node.Functions.Select(CheckStructFunc).ToList(); - return new StructNode(CurrentModule, node.Name, fields, functions); - } - - private StructTemplateNode CheckStructTemplateDefinition(StructTemplateSyntax node) - { - var fields = node.Fields.Select(CheckStructField).ToList(); - var functions = node.Functions.Select(CheckStructFunc).ToList(); - - return new StructTemplateNode(CurrentModule, node.Name, node.TemplateArguments, fields, functions); + return new StructNode(CurrentScope.Module, node.Name, fields, functions); + } } private StructFuncNode CheckStructFunc(StructFuncSyntax function) @@ -143,39 +157,42 @@ public sealed class TypeChecker private FuncNode CheckFuncDefinition(FuncSyntax node) { - foreach (var parameter in node.Signature.Parameters) + using (BeginRootScope(_syntaxTree.Metadata.ModuleName)) { - CurrentScope.DeclareVariable(new Variable(parameter.Name, ResolveType(parameter.Type), VariableKind.RValue)); - } - - var signature = CheckFuncSignature(node.Signature); - - BlockNode? body = null; - if (node.Body != null) - { - _funcReturnTypes.Push(signature.ReturnType); - - body = CheckBlock(node.Body); - - if (!AlwaysReturns(body)) + foreach (var parameter in node.Signature.Parameters) { - if (signature.ReturnType is NubVoidType) - { - body.Statements.Add(new ReturnNode(null)); - } - else - { - Diagnostics.Add(Diagnostic - .Error("Not all code paths return a value") - .At(node.Body) - .Build()); - } + CurrentScope.DeclareVariable(new Variable(parameter.Name, ResolveType(parameter.Type), VariableKind.RValue)); } - _funcReturnTypes.Pop(); - } + var signature = CheckFuncSignature(node.Signature); - return new FuncNode(CurrentModule, node.Name, node.ExternSymbol, signature, body); + BlockNode? body = null; + if (node.Body != null) + { + _funcReturnTypes.Push(signature.ReturnType); + + body = CheckBlock(node.Body); + + if (!AlwaysReturns(body)) + { + if (signature.ReturnType is NubVoidType) + { + body.Statements.Add(new ReturnNode(null)); + } + else + { + Diagnostics.Add(Diagnostic + .Error("Not all code paths return a value") + .At(node.Body) + .Build()); + } + } + + _funcReturnTypes.Pop(); + } + + return new FuncNode(CurrentScope.Module, node.Name, node.ExternSymbol, signature, body); + } } private AssignmentNode CheckAssignment(AssignmentSyntax statement) @@ -588,14 +605,14 @@ public sealed class TypeChecker } // Second, look in the current module for a function matching the identifier - var module = _visibleModules[CurrentModule]; + var module = _visibleModules[CurrentScope.Module]; var function = module.Functions(true).FirstOrDefault(x => x.Name == expression.Name); if (function != null) { var parameters = function.Signature.Parameters.Select(x => ResolveType(x.Type)).ToList(); var type = new NubFuncType(parameters, ResolveType(function.Signature.ReturnType)); - return new FuncIdentifierNode(type, CurrentModule, expression.Name, function.ExternSymbol); + return new FuncIdentifierNode(type, CurrentScope.Module, expression.Name, function.ExternSymbol); } throw new TypeCheckerException(Diagnostic.Error($"Symbol {expression.Name} not found").At(expression).Build()); @@ -612,7 +629,7 @@ public sealed class TypeChecker .Build()); } - var includePrivate = expression.Module == CurrentModule; + var includePrivate = expression.Module == CurrentScope.Module; // First, look for the exported function in the specified module var function = module.Functions(includePrivate).FirstOrDefault(x => x.Name == expression.Name); @@ -760,34 +777,33 @@ public sealed class TypeChecker var reachable = true; var warnedUnreachable = false; - BeginScope(false); - - foreach (var statement in node.Statements) + using (BeginScope()) { - var checkedStatement = CheckStatement(statement); - - if (reachable) + foreach (var statement in node.Statements) { - if (checkedStatement is TerminalStatementNode) + var checkedStatement = CheckStatement(statement); + + if (reachable) { - reachable = false; + if (checkedStatement is TerminalStatementNode) + { + reachable = false; + } + + statements.Add(checkedStatement); } - - statements.Add(checkedStatement); - } - else - { - if (!warnedUnreachable) + else { - Diagnostics.Add(Diagnostic.Warning("Statement is unreachable").At(statement).Build()); - warnedUnreachable = true; + if (!warnedUnreachable) + { + Diagnostics.Add(Diagnostic.Warning("Statement is unreachable").At(statement).Build()); + warnedUnreachable = true; + } } } + + return new BlockNode(statements); } - - EndScope(); - - return new BlockNode(statements); } private StatementNode CheckStatement(StatementSyntax statement) @@ -843,12 +859,27 @@ public sealed class TypeChecker PointerTypeSyntax ptr => new NubPointerType(ResolveType(ptr.BaseType)), StringTypeSyntax => new NubStringType(), CustomTypeSyntax c => ResolveCustomType(c), - TemplateTypeSyntax t => ResolveTemplateType(t), + StructTemplateTypeSyntax t => ResolveStructTemplateType(t), + SubstitutionTypeSyntax s => ResolveTypeSubstitution(s), VoidTypeSyntax => new NubVoidType(), _ => throw new NotSupportedException($"Unknown type syntax: {type}") }; } + private NubType ResolveTypeSubstitution(SubstitutionTypeSyntax substitution) + { + var type = CurrentScope.LookupTypeSubstitution(substitution.Name); + if (type == null) + { + throw new TypeCheckerException(Diagnostic + .Error($"Template argument {substitution.Name} does not exist") + .At(substitution) + .Build()); + } + + return type; + } + private NubType ResolveCustomType(CustomTypeSyntax customType) { var key = (customType.Module, customType.Name); @@ -876,21 +907,21 @@ public sealed class TypeChecker .Build()); } - var includePrivate = customType.Module == CurrentModule; + var includePrivate = customType.Module == CurrentScope.Module; - var strctDef = module.Structs(includePrivate).FirstOrDefault(x => x.Name == customType.Name); - if (strctDef != null) + var structDef = module.Structs(includePrivate).FirstOrDefault(x => x.Name == customType.Name); + if (structDef != null) { - var result = new NubStructType(customType.Module, strctDef.Name, [], []); + var result = new NubStructType(customType.Module, structDef.Name, [], []); _typeCache[key] = result; - var fields = strctDef.Fields + var fields = structDef.Fields .Select(x => new NubStructFieldType(x.Name, ResolveType(x.Type), x.Value != null)) .ToList(); result.Fields.AddRange(fields); - var functions = strctDef.Functions + var functions = structDef.Functions .Select(x => { var parameters = x.Signature.Parameters @@ -918,9 +949,71 @@ public sealed class TypeChecker } } - private NubType ResolveTemplateType(TemplateTypeSyntax template) + private NubStructType ResolveStructTemplateType(StructTemplateTypeSyntax structTemplate) { - throw new NotImplementedException(); + if (!_visibleModules.TryGetValue(structTemplate.Module, out var module)) + { + throw new TypeCheckerException(Diagnostic + .Error($"Module {structTemplate.Module} not found") + .WithHelp($"import \"{structTemplate.Module}\"") + .At(structTemplate) + .Build()); + } + + var includePrivate = structTemplate.Module == CurrentScope.Module; + + var templateDef = module + .StructTemplates(includePrivate) + .FirstOrDefault(x => x.Name == structTemplate.Name); + + if (templateDef == null) + { + throw new TypeCheckerException(Diagnostic + .Error($"Template type {structTemplate.Name} not found in module {structTemplate.Module}") + .At(structTemplate) + .Build()); + } + + var templateParameterTypes = structTemplate.TemplateParameters.Select(ResolveType).ToList(); + var mangledName = $"{structTemplate.Name}.{NameMangler.Mangle(templateParameterTypes)}"; + + using (BeginRootScope(structTemplate.Module)) + { + for (var i = 0; i < templateParameterTypes.Count; i++) + { + var parameterName = templateDef.TemplateArguments[i]; + var parameterType = templateParameterTypes[i]; + CurrentScope.DeclareTypeSubstitution(parameterName, parameterType); + } + + var fieldTypes = templateDef.Fields + .Select(x => new NubStructFieldType(x.Name, ResolveType(x.Type), x.Value != null)) + .ToList(); + + var fieldFunctions = templateDef.Functions + .Select(x => + { + var parameters = x.Signature.Parameters.Select(y => ResolveType(y.Type)).ToList(); + var returnType = ResolveType(x.Signature.ReturnType); + return new NubStructFuncType(x.Name, x.Hook, parameters, returnType); + }) + .ToList(); + + var structType = new NubStructType(structTemplate.Module, mangledName, fieldTypes, fieldFunctions); + + if (!_checkedTemplateStructs.Contains($"{structTemplate.Module}.{mangledName}")) + { + CurrentScope.DeclareVariable(new Variable("this", structType, VariableKind.RValue)); + var fields = templateDef.Fields.Select(CheckStructField).ToList(); + var functions = templateDef.Functions.Select(CheckStructFunc).ToList(); + Definitions.Add(new StructNode(structTemplate.Module, mangledName, fields, functions)); + _checkedTemplateStructs.Add($"{structTemplate.Module}.{mangledName}"); + } + + ReferencedStructTypes.Add(structType); + + return structType; + } } } @@ -932,9 +1025,16 @@ public enum VariableKind public record Variable(string Name, NubType Type, VariableKind Kind); -public class Scope(Scope? parent = null) +public class Scope(string module, Scope? parent = null) { private readonly List _variables = []; + private readonly Dictionary _typeSubstitutions = []; + public string Module { get; } = module; + + public void DeclareVariable(Variable variable) + { + _variables.Add(variable); + } public Variable? LookupVariable(string name) { @@ -947,14 +1047,24 @@ public class Scope(Scope? parent = null) return parent?.LookupVariable(name); } - public void DeclareVariable(Variable variable) + public void DeclareTypeSubstitution(string name, NubType type) { - _variables.Add(variable); + _typeSubstitutions[name] = type; + } + + public NubType? LookupTypeSubstitution(string name) + { + if (_typeSubstitutions.TryGetValue(name, out var type)) + { + return type; + } + + return parent?.LookupTypeSubstitution(name); } public Scope SubScope() { - return new Scope(this); + return new Scope(Module, this); } } diff --git a/example/makefile b/example/makefile index 302a9aa..0a4e657 100644 --- a/example/makefile +++ b/example/makefile @@ -3,8 +3,8 @@ NUBC = ../compiler/NubLang.CLI/bin/Debug/net9.0/nubc out: .build/out.o gcc -nostartfiles -o out x86_64.s .build/out.o -.build/out.o: $(NUBC) src/main.nub - $(NUBC) src/main.nub +.build/out.o: $(NUBC) src/main.nub src/ref.nub + $(NUBC) src/main.nub src/ref.nub .PHONY: $(NUBC) $(NUBC): diff --git a/example/src/main.nub b/example/src/main.nub index e817e04..c989628 100644 --- a/example/src/main.nub +++ b/example/src/main.nub @@ -1,8 +1,6 @@ -module "main" +import "core" -extern "puts" func puts(text: cstring) -extern "malloc" func malloc(size: u64): ^void -extern "free" func free(address: ^void) +module "main" struct Human { @@ -11,47 +9,11 @@ struct Human extern "main" func main(args: []cstring): i64 { - let x: ref = {} + let x: core::ref = {} test(x) return 0 } -func test(x: ref) +func test(x: core::ref) { - -} - -struct ref -{ - value: ^T - count: ^u64 - - @oncreate - func on_create() - { - puts("on_create") - this.value = @interpret(^T, malloc(@size(T))) - this.count = @interpret(^u64, malloc(@size(u64))) - this.count^ = 1 - } - - @oncopy - func on_copy() - { - puts("on_copy") - this.count^ = this.count^ + 1 - } - - @ondestroy - func on_destroy() - { - puts("on_destroy") - this.count^ = this.count^ - 1 - if this.count^ <= 0 - { - puts("free") - free(@interpret(^void, this.value)) - free(@interpret(^void, this.count)) - } - } } \ No newline at end of file diff --git a/example/src/ref.nub b/example/src/ref.nub new file mode 100644 index 0000000..5bfe824 --- /dev/null +++ b/example/src/ref.nub @@ -0,0 +1,40 @@ +module "core" + +extern "puts" func puts(text: cstring) +extern "malloc" func malloc(size: u64): ^void +extern "free" func free(address: ^void) + +export struct ref +{ + value: ^T + count: ^u64 + + @oncreate + func on_create() + { + puts("on_create") + this.value = @interpret(^T, malloc(@size(T))) + this.count = @interpret(^u64, malloc(@size(u64))) + this.count^ = 1 + } + + @oncopy + func on_copy() + { + puts("on_copy") + this.count^ = this.count^ + 1 + } + + @ondestroy + func on_destroy() + { + puts("on_destroy") + this.count^ = this.count^ - 1 + if this.count^ <= 0 + { + puts("free") + free(@interpret(^void, this.value)) + free(@interpret(^void, this.count)) + } + } +} \ No newline at end of file