This commit is contained in:
nub31
2025-09-21 21:56:59 +02:00
parent 822cdf4bda
commit d0ad361776
7 changed files with 284 additions and 164 deletions

View File

@@ -11,6 +11,7 @@ public sealed class Parser
private List<Token> _tokens = [];
private int _tokenIndex;
private string _moduleName = string.Empty;
private HashSet<string> _templateArguments = [];
private Token? CurrentToken => _tokenIndex < _tokens.Count ? _tokens[_tokenIndex] : null;
private bool HasToken => CurrentToken != null;
@@ -26,6 +27,7 @@ public sealed class Parser
_tokens = tokens;
_tokenIndex = 0;
_moduleName = string.Empty;
_templateArguments.Clear();
var metadata = ParseMetadata();
var definitions = ParseDefinitions();
@@ -39,13 +41,13 @@ public sealed class Parser
try
{
ExpectSymbol(Symbol.Module);
_moduleName = ExpectLiteral(LiteralKind.String).Value;
while (TryExpectSymbol(Symbol.Import))
{
imports.Add(ExpectLiteral(LiteralKind.String).Value);
}
ExpectSymbol(Symbol.Module);
_moduleName = ExpectLiteral(LiteralKind.String).Value;
}
catch (ParseException e)
{
@@ -167,12 +169,11 @@ public sealed class Parser
{
var name = ExpectIdentifier();
var templateArguments = new List<string>();
if (TryExpectSymbol(Symbol.LessThan))
{
while (!TryExpectSymbol(Symbol.GreaterThan))
{
templateArguments.Add(ExpectIdentifier().Value);
_templateArguments.Add(ExpectIdentifier().Value);
TryExpectSymbol(Symbol.Comma);
}
}
@@ -217,8 +218,10 @@ public sealed class Parser
}
}
if (templateArguments.Count > 0)
if (_templateArguments.Count > 0)
{
var templateArguments = _templateArguments.ToList();
_templateArguments.Clear();
return new StructTemplateSyntax(GetTokens(startIndex), templateArguments, name.Value, exported, fields, funcs);
}
@@ -670,6 +673,11 @@ public sealed class Parser
var startIndex = _tokenIndex;
if (TryExpectIdentifier(out var name))
{
if (_templateArguments.Contains(name.Value))
{
return new SubstitutionTypeSyntax(GetTokens(startIndex), name.Value);
}
if (name.Value[0] == 'u' && int.TryParse(name.Value[1..], out var size))
{
if (size is not 8 and not 16 and not 32 and not 64)
@@ -745,7 +753,7 @@ public sealed class Parser
if (templateParameters.Count > 0)
{
return new TemplateTypeSyntax(GetTokens(startIndex), templateParameters, module, name.Value);
return new StructTemplateTypeSyntax(GetTokens(startIndex), templateParameters, module, name.Value);
}
return new CustomTypeSyntax(GetTokens(startIndex), module, name.Value);

View File

@@ -24,4 +24,6 @@ public record ArrayTypeSyntax(List<Token> Tokens, TypeSyntax BaseType) : TypeSyn
public record CustomTypeSyntax(List<Token> Tokens, string Module, string Name) : TypeSyntax(Tokens);
public record TemplateTypeSyntax(List<Token> Tokens, List<TypeSyntax> TemplateParameters, string Module, string Name) : TypeSyntax(Tokens);
public record StructTemplateTypeSyntax(List<Token> Tokens, List<TypeSyntax> TemplateParameters, string Module, string Name) : TypeSyntax(Tokens);
public record SubstitutionTypeSyntax(List<Token> Tokens, string Name) : TypeSyntax(Tokens);

View File

@@ -13,5 +13,3 @@ public record StructFieldNode(string Name, NubType Type, ExpressionNode? Value)
public record StructFuncNode(string Name, string? Hook, FuncSignatureNode Signature, BlockNode Body) : Node;
public record StructNode(string Module, string Name, List<StructFieldNode> Fields, List<StructFuncNode> Functions) : DefinitionNode(Module, Name);
public record StructTemplateNode(string Module, string Name, List<string> TemplateArguments, List<StructFieldNode> Fields, List<StructFuncNode> Functions) : DefinitionNode(Module, Name);

View File

@@ -14,20 +14,19 @@ public sealed class TypeChecker
private readonly Dictionary<string, Module> _visibleModules;
private readonly Stack<Scope> _scopes = [];
private Scope _globalScope = new();
private readonly Stack<NubType> _funcReturnTypes = [];
private readonly Dictionary<(string Module, string Name), NubType> _typeCache = new();
private readonly HashSet<(string Module, string Name)> _resolvingTypes = [];
private readonly HashSet<string> _checkedTemplateStructs = [];
private Scope CurrentScope => _scopes.Peek();
private string CurrentModule => _syntaxTree.Metadata.ModuleName;
public TypeChecker(SyntaxTree syntaxTree, ModuleRepository moduleRepository)
{
_syntaxTree = syntaxTree;
_visibleModules = moduleRepository
.Modules()
.Where(x => syntaxTree.Metadata.Imports.Contains(x.Key) || CurrentModule == x.Key)
.Where(x => syntaxTree.Metadata.Imports.Contains(x.Key) || _syntaxTree.Metadata.ModuleName == x.Key)
.ToDictionary();
}
@@ -38,10 +37,10 @@ public sealed class TypeChecker
public void Check()
{
_scopes.Clear();
_globalScope = new Scope();
_funcReturnTypes.Clear();
_typeCache.Clear();
_resolvingTypes.Clear();
_checkedTemplateStructs.Clear();
Diagnostics.Clear();
Definitions.Clear();
@@ -49,42 +48,64 @@ public sealed class TypeChecker
foreach (var definition in _syntaxTree.Definitions)
{
BeginScope(true);
try
{
Definitions.Add(definition switch
switch (definition)
{
FuncSyntax funcSyntax => CheckFuncDefinition(funcSyntax),
StructSyntax structSyntax => CheckStructDefinition(structSyntax),
StructTemplateSyntax structTemplate => CheckStructTemplateDefinition(structTemplate),
_ => throw new ArgumentOutOfRangeException()
});
case FuncSyntax funcSyntax:
Definitions.Add(CheckFuncDefinition(funcSyntax));
break;
case StructSyntax structSyntax:
Definitions.Add(CheckStructDefinition(structSyntax));
break;
case StructTemplateSyntax:
break;
default:
throw new ArgumentOutOfRangeException();
}
}
catch (TypeCheckerException e)
{
Diagnostics.Add(e.Diagnostic);
}
EndScope();
}
}
private void BeginScope(bool root)
private ScopeDisposer BeginScope()
{
var scope = root
? _globalScope.SubScope()
: _scopes.Peek().SubScope();
_scopes.Push(scope);
if (_scopes.TryPeek(out var scope))
{
_scopes.Push(scope.SubScope());
}
else
{
_scopes.Push(new Scope(_syntaxTree.Metadata.ModuleName));
}
private void EndScope()
return new ScopeDisposer(this);
}
private ScopeDisposer BeginRootScope(string module)
{
_scopes.Pop();
_scopes.Push(new Scope(module));
return new ScopeDisposer(this);
}
private sealed class ScopeDisposer(TypeChecker owner) : IDisposable
{
private bool _disposed;
public void Dispose()
{
if (_disposed) return;
owner._scopes.Pop();
_disposed = true;
}
}
private StructNode CheckStructDefinition(StructSyntax node)
{
using (BeginRootScope(_syntaxTree.Metadata.ModuleName))
{
var fieldTypes = node.Fields
.Select(x => new NubStructFieldType(x.Name, ResolveType(x.Type), x.Value != null))
@@ -99,22 +120,15 @@ public sealed class TypeChecker
})
.ToList();
var structType = new NubStructType(CurrentModule, node.Name, fieldTypes, fieldFunctions);
var structType = new NubStructType(CurrentScope.Module, node.Name, fieldTypes, fieldFunctions);
CurrentScope.DeclareVariable(new Variable("this", structType, VariableKind.RValue));
var fields = node.Fields.Select(CheckStructField).ToList();
var functions = node.Functions.Select(CheckStructFunc).ToList();
return new StructNode(CurrentModule, node.Name, fields, functions);
return new StructNode(CurrentScope.Module, node.Name, fields, functions);
}
private StructTemplateNode CheckStructTemplateDefinition(StructTemplateSyntax node)
{
var fields = node.Fields.Select(CheckStructField).ToList();
var functions = node.Functions.Select(CheckStructFunc).ToList();
return new StructTemplateNode(CurrentModule, node.Name, node.TemplateArguments, fields, functions);
}
private StructFuncNode CheckStructFunc(StructFuncSyntax function)
@@ -142,6 +156,8 @@ public sealed class TypeChecker
}
private FuncNode CheckFuncDefinition(FuncSyntax node)
{
using (BeginRootScope(_syntaxTree.Metadata.ModuleName))
{
foreach (var parameter in node.Signature.Parameters)
{
@@ -175,7 +191,8 @@ public sealed class TypeChecker
_funcReturnTypes.Pop();
}
return new FuncNode(CurrentModule, node.Name, node.ExternSymbol, signature, body);
return new FuncNode(CurrentScope.Module, node.Name, node.ExternSymbol, signature, body);
}
}
private AssignmentNode CheckAssignment(AssignmentSyntax statement)
@@ -588,14 +605,14 @@ public sealed class TypeChecker
}
// Second, look in the current module for a function matching the identifier
var module = _visibleModules[CurrentModule];
var module = _visibleModules[CurrentScope.Module];
var function = module.Functions(true).FirstOrDefault(x => x.Name == expression.Name);
if (function != null)
{
var parameters = function.Signature.Parameters.Select(x => ResolveType(x.Type)).ToList();
var type = new NubFuncType(parameters, ResolveType(function.Signature.ReturnType));
return new FuncIdentifierNode(type, CurrentModule, expression.Name, function.ExternSymbol);
return new FuncIdentifierNode(type, CurrentScope.Module, expression.Name, function.ExternSymbol);
}
throw new TypeCheckerException(Diagnostic.Error($"Symbol {expression.Name} not found").At(expression).Build());
@@ -612,7 +629,7 @@ public sealed class TypeChecker
.Build());
}
var includePrivate = expression.Module == CurrentModule;
var includePrivate = expression.Module == CurrentScope.Module;
// First, look for the exported function in the specified module
var function = module.Functions(includePrivate).FirstOrDefault(x => x.Name == expression.Name);
@@ -760,8 +777,8 @@ public sealed class TypeChecker
var reachable = true;
var warnedUnreachable = false;
BeginScope(false);
using (BeginScope())
{
foreach (var statement in node.Statements)
{
var checkedStatement = CheckStatement(statement);
@@ -785,10 +802,9 @@ public sealed class TypeChecker
}
}
EndScope();
return new BlockNode(statements);
}
}
private StatementNode CheckStatement(StatementSyntax statement)
{
@@ -843,12 +859,27 @@ public sealed class TypeChecker
PointerTypeSyntax ptr => new NubPointerType(ResolveType(ptr.BaseType)),
StringTypeSyntax => new NubStringType(),
CustomTypeSyntax c => ResolveCustomType(c),
TemplateTypeSyntax t => ResolveTemplateType(t),
StructTemplateTypeSyntax t => ResolveStructTemplateType(t),
SubstitutionTypeSyntax s => ResolveTypeSubstitution(s),
VoidTypeSyntax => new NubVoidType(),
_ => throw new NotSupportedException($"Unknown type syntax: {type}")
};
}
private NubType ResolveTypeSubstitution(SubstitutionTypeSyntax substitution)
{
var type = CurrentScope.LookupTypeSubstitution(substitution.Name);
if (type == null)
{
throw new TypeCheckerException(Diagnostic
.Error($"Template argument {substitution.Name} does not exist")
.At(substitution)
.Build());
}
return type;
}
private NubType ResolveCustomType(CustomTypeSyntax customType)
{
var key = (customType.Module, customType.Name);
@@ -876,21 +907,21 @@ public sealed class TypeChecker
.Build());
}
var includePrivate = customType.Module == CurrentModule;
var includePrivate = customType.Module == CurrentScope.Module;
var strctDef = module.Structs(includePrivate).FirstOrDefault(x => x.Name == customType.Name);
if (strctDef != null)
var structDef = module.Structs(includePrivate).FirstOrDefault(x => x.Name == customType.Name);
if (structDef != null)
{
var result = new NubStructType(customType.Module, strctDef.Name, [], []);
var result = new NubStructType(customType.Module, structDef.Name, [], []);
_typeCache[key] = result;
var fields = strctDef.Fields
var fields = structDef.Fields
.Select(x => new NubStructFieldType(x.Name, ResolveType(x.Type), x.Value != null))
.ToList();
result.Fields.AddRange(fields);
var functions = strctDef.Functions
var functions = structDef.Functions
.Select(x =>
{
var parameters = x.Signature.Parameters
@@ -918,9 +949,71 @@ public sealed class TypeChecker
}
}
private NubType ResolveTemplateType(TemplateTypeSyntax template)
private NubStructType ResolveStructTemplateType(StructTemplateTypeSyntax structTemplate)
{
throw new NotImplementedException();
if (!_visibleModules.TryGetValue(structTemplate.Module, out var module))
{
throw new TypeCheckerException(Diagnostic
.Error($"Module {structTemplate.Module} not found")
.WithHelp($"import \"{structTemplate.Module}\"")
.At(structTemplate)
.Build());
}
var includePrivate = structTemplate.Module == CurrentScope.Module;
var templateDef = module
.StructTemplates(includePrivate)
.FirstOrDefault(x => x.Name == structTemplate.Name);
if (templateDef == null)
{
throw new TypeCheckerException(Diagnostic
.Error($"Template type {structTemplate.Name} not found in module {structTemplate.Module}")
.At(structTemplate)
.Build());
}
var templateParameterTypes = structTemplate.TemplateParameters.Select(ResolveType).ToList();
var mangledName = $"{structTemplate.Name}.{NameMangler.Mangle(templateParameterTypes)}";
using (BeginRootScope(structTemplate.Module))
{
for (var i = 0; i < templateParameterTypes.Count; i++)
{
var parameterName = templateDef.TemplateArguments[i];
var parameterType = templateParameterTypes[i];
CurrentScope.DeclareTypeSubstitution(parameterName, parameterType);
}
var fieldTypes = templateDef.Fields
.Select(x => new NubStructFieldType(x.Name, ResolveType(x.Type), x.Value != null))
.ToList();
var fieldFunctions = templateDef.Functions
.Select(x =>
{
var parameters = x.Signature.Parameters.Select(y => ResolveType(y.Type)).ToList();
var returnType = ResolveType(x.Signature.ReturnType);
return new NubStructFuncType(x.Name, x.Hook, parameters, returnType);
})
.ToList();
var structType = new NubStructType(structTemplate.Module, mangledName, fieldTypes, fieldFunctions);
if (!_checkedTemplateStructs.Contains($"{structTemplate.Module}.{mangledName}"))
{
CurrentScope.DeclareVariable(new Variable("this", structType, VariableKind.RValue));
var fields = templateDef.Fields.Select(CheckStructField).ToList();
var functions = templateDef.Functions.Select(CheckStructFunc).ToList();
Definitions.Add(new StructNode(structTemplate.Module, mangledName, fields, functions));
_checkedTemplateStructs.Add($"{structTemplate.Module}.{mangledName}");
}
ReferencedStructTypes.Add(structType);
return structType;
}
}
}
@@ -932,9 +1025,16 @@ public enum VariableKind
public record Variable(string Name, NubType Type, VariableKind Kind);
public class Scope(Scope? parent = null)
public class Scope(string module, Scope? parent = null)
{
private readonly List<Variable> _variables = [];
private readonly Dictionary<string, NubType> _typeSubstitutions = [];
public string Module { get; } = module;
public void DeclareVariable(Variable variable)
{
_variables.Add(variable);
}
public Variable? LookupVariable(string name)
{
@@ -947,14 +1047,24 @@ public class Scope(Scope? parent = null)
return parent?.LookupVariable(name);
}
public void DeclareVariable(Variable variable)
public void DeclareTypeSubstitution(string name, NubType type)
{
_variables.Add(variable);
_typeSubstitutions[name] = type;
}
public NubType? LookupTypeSubstitution(string name)
{
if (_typeSubstitutions.TryGetValue(name, out var type))
{
return type;
}
return parent?.LookupTypeSubstitution(name);
}
public Scope SubScope()
{
return new Scope(this);
return new Scope(Module, this);
}
}

View File

@@ -3,8 +3,8 @@ NUBC = ../compiler/NubLang.CLI/bin/Debug/net9.0/nubc
out: .build/out.o
gcc -nostartfiles -o out x86_64.s .build/out.o
.build/out.o: $(NUBC) src/main.nub
$(NUBC) src/main.nub
.build/out.o: $(NUBC) src/main.nub src/ref.nub
$(NUBC) src/main.nub src/ref.nub
.PHONY: $(NUBC)
$(NUBC):

View File

@@ -1,8 +1,6 @@
module "main"
import "core"
extern "puts" func puts(text: cstring)
extern "malloc" func malloc(size: u64): ^void
extern "free" func free(address: ^void)
module "main"
struct Human
{
@@ -11,47 +9,11 @@ struct Human
extern "main" func main(args: []cstring): i64
{
let x: ref<Human> = {}
let x: core::ref<Human> = {}
test(x)
return 0
}
func test(x: ref<Human>)
func test(x: core::ref<Human>)
{
}
struct ref<T>
{
value: ^T
count: ^u64
@oncreate
func on_create()
{
puts("on_create")
this.value = @interpret(^T, malloc(@size(T)))
this.count = @interpret(^u64, malloc(@size(u64)))
this.count^ = 1
}
@oncopy
func on_copy()
{
puts("on_copy")
this.count^ = this.count^ + 1
}
@ondestroy
func on_destroy()
{
puts("on_destroy")
this.count^ = this.count^ - 1
if this.count^ <= 0
{
puts("free")
free(@interpret(^void, this.value))
free(@interpret(^void, this.count))
}
}
}

40
example/src/ref.nub Normal file
View File

@@ -0,0 +1,40 @@
module "core"
extern "puts" func puts(text: cstring)
extern "malloc" func malloc(size: u64): ^void
extern "free" func free(address: ^void)
export struct ref<T>
{
value: ^T
count: ^u64
@oncreate
func on_create()
{
puts("on_create")
this.value = @interpret(^T, malloc(@size(T)))
this.count = @interpret(^u64, malloc(@size(u64)))
this.count^ = 1
}
@oncopy
func on_copy()
{
puts("on_copy")
this.count^ = this.count^ + 1
}
@ondestroy
func on_destroy()
{
puts("on_destroy")
this.count^ = this.count^ - 1
if this.count^ <= 0
{
puts("free")
free(@interpret(^void, this.value))
free(@interpret(^void, this.count))
}
}
}