This commit is contained in:
nub31
2025-09-21 21:56:59 +02:00
parent 822cdf4bda
commit d0ad361776
7 changed files with 284 additions and 164 deletions

View File

@@ -11,6 +11,7 @@ public sealed class Parser
private List<Token> _tokens = []; private List<Token> _tokens = [];
private int _tokenIndex; private int _tokenIndex;
private string _moduleName = string.Empty; private string _moduleName = string.Empty;
private HashSet<string> _templateArguments = [];
private Token? CurrentToken => _tokenIndex < _tokens.Count ? _tokens[_tokenIndex] : null; private Token? CurrentToken => _tokenIndex < _tokens.Count ? _tokens[_tokenIndex] : null;
private bool HasToken => CurrentToken != null; private bool HasToken => CurrentToken != null;
@@ -26,6 +27,7 @@ public sealed class Parser
_tokens = tokens; _tokens = tokens;
_tokenIndex = 0; _tokenIndex = 0;
_moduleName = string.Empty; _moduleName = string.Empty;
_templateArguments.Clear();
var metadata = ParseMetadata(); var metadata = ParseMetadata();
var definitions = ParseDefinitions(); var definitions = ParseDefinitions();
@@ -39,13 +41,13 @@ public sealed class Parser
try try
{ {
ExpectSymbol(Symbol.Module);
_moduleName = ExpectLiteral(LiteralKind.String).Value;
while (TryExpectSymbol(Symbol.Import)) while (TryExpectSymbol(Symbol.Import))
{ {
imports.Add(ExpectLiteral(LiteralKind.String).Value); imports.Add(ExpectLiteral(LiteralKind.String).Value);
} }
ExpectSymbol(Symbol.Module);
_moduleName = ExpectLiteral(LiteralKind.String).Value;
} }
catch (ParseException e) catch (ParseException e)
{ {
@@ -167,12 +169,11 @@ public sealed class Parser
{ {
var name = ExpectIdentifier(); var name = ExpectIdentifier();
var templateArguments = new List<string>();
if (TryExpectSymbol(Symbol.LessThan)) if (TryExpectSymbol(Symbol.LessThan))
{ {
while (!TryExpectSymbol(Symbol.GreaterThan)) while (!TryExpectSymbol(Symbol.GreaterThan))
{ {
templateArguments.Add(ExpectIdentifier().Value); _templateArguments.Add(ExpectIdentifier().Value);
TryExpectSymbol(Symbol.Comma); TryExpectSymbol(Symbol.Comma);
} }
} }
@@ -217,8 +218,10 @@ public sealed class Parser
} }
} }
if (templateArguments.Count > 0) if (_templateArguments.Count > 0)
{ {
var templateArguments = _templateArguments.ToList();
_templateArguments.Clear();
return new StructTemplateSyntax(GetTokens(startIndex), templateArguments, name.Value, exported, fields, funcs); return new StructTemplateSyntax(GetTokens(startIndex), templateArguments, name.Value, exported, fields, funcs);
} }
@@ -670,6 +673,11 @@ public sealed class Parser
var startIndex = _tokenIndex; var startIndex = _tokenIndex;
if (TryExpectIdentifier(out var name)) if (TryExpectIdentifier(out var name))
{ {
if (_templateArguments.Contains(name.Value))
{
return new SubstitutionTypeSyntax(GetTokens(startIndex), name.Value);
}
if (name.Value[0] == 'u' && int.TryParse(name.Value[1..], out var size)) if (name.Value[0] == 'u' && int.TryParse(name.Value[1..], out var size))
{ {
if (size is not 8 and not 16 and not 32 and not 64) if (size is not 8 and not 16 and not 32 and not 64)
@@ -745,7 +753,7 @@ public sealed class Parser
if (templateParameters.Count > 0) if (templateParameters.Count > 0)
{ {
return new TemplateTypeSyntax(GetTokens(startIndex), templateParameters, module, name.Value); return new StructTemplateTypeSyntax(GetTokens(startIndex), templateParameters, module, name.Value);
} }
return new CustomTypeSyntax(GetTokens(startIndex), module, name.Value); return new CustomTypeSyntax(GetTokens(startIndex), module, name.Value);

View File

@@ -24,4 +24,6 @@ public record ArrayTypeSyntax(List<Token> Tokens, TypeSyntax BaseType) : TypeSyn
public record CustomTypeSyntax(List<Token> Tokens, string Module, string Name) : TypeSyntax(Tokens); public record CustomTypeSyntax(List<Token> Tokens, string Module, string Name) : TypeSyntax(Tokens);
public record TemplateTypeSyntax(List<Token> Tokens, List<TypeSyntax> TemplateParameters, string Module, string Name) : TypeSyntax(Tokens); public record StructTemplateTypeSyntax(List<Token> Tokens, List<TypeSyntax> TemplateParameters, string Module, string Name) : TypeSyntax(Tokens);
public record SubstitutionTypeSyntax(List<Token> Tokens, string Name) : TypeSyntax(Tokens);

View File

@@ -13,5 +13,3 @@ public record StructFieldNode(string Name, NubType Type, ExpressionNode? Value)
public record StructFuncNode(string Name, string? Hook, FuncSignatureNode Signature, BlockNode Body) : Node; public record StructFuncNode(string Name, string? Hook, FuncSignatureNode Signature, BlockNode Body) : Node;
public record StructNode(string Module, string Name, List<StructFieldNode> Fields, List<StructFuncNode> Functions) : DefinitionNode(Module, Name); public record StructNode(string Module, string Name, List<StructFieldNode> Fields, List<StructFuncNode> Functions) : DefinitionNode(Module, Name);
public record StructTemplateNode(string Module, string Name, List<string> TemplateArguments, List<StructFieldNode> Fields, List<StructFuncNode> Functions) : DefinitionNode(Module, Name);

View File

@@ -14,20 +14,19 @@ public sealed class TypeChecker
private readonly Dictionary<string, Module> _visibleModules; private readonly Dictionary<string, Module> _visibleModules;
private readonly Stack<Scope> _scopes = []; private readonly Stack<Scope> _scopes = [];
private Scope _globalScope = new();
private readonly Stack<NubType> _funcReturnTypes = []; private readonly Stack<NubType> _funcReturnTypes = [];
private readonly Dictionary<(string Module, string Name), NubType> _typeCache = new(); private readonly Dictionary<(string Module, string Name), NubType> _typeCache = new();
private readonly HashSet<(string Module, string Name)> _resolvingTypes = []; private readonly HashSet<(string Module, string Name)> _resolvingTypes = [];
private readonly HashSet<string> _checkedTemplateStructs = [];
private Scope CurrentScope => _scopes.Peek(); private Scope CurrentScope => _scopes.Peek();
private string CurrentModule => _syntaxTree.Metadata.ModuleName;
public TypeChecker(SyntaxTree syntaxTree, ModuleRepository moduleRepository) public TypeChecker(SyntaxTree syntaxTree, ModuleRepository moduleRepository)
{ {
_syntaxTree = syntaxTree; _syntaxTree = syntaxTree;
_visibleModules = moduleRepository _visibleModules = moduleRepository
.Modules() .Modules()
.Where(x => syntaxTree.Metadata.Imports.Contains(x.Key) || CurrentModule == x.Key) .Where(x => syntaxTree.Metadata.Imports.Contains(x.Key) || _syntaxTree.Metadata.ModuleName == x.Key)
.ToDictionary(); .ToDictionary();
} }
@@ -38,10 +37,10 @@ public sealed class TypeChecker
public void Check() public void Check()
{ {
_scopes.Clear(); _scopes.Clear();
_globalScope = new Scope();
_funcReturnTypes.Clear(); _funcReturnTypes.Clear();
_typeCache.Clear(); _typeCache.Clear();
_resolvingTypes.Clear(); _resolvingTypes.Clear();
_checkedTemplateStructs.Clear();
Diagnostics.Clear(); Diagnostics.Clear();
Definitions.Clear(); Definitions.Clear();
@@ -49,42 +48,64 @@ public sealed class TypeChecker
foreach (var definition in _syntaxTree.Definitions) foreach (var definition in _syntaxTree.Definitions)
{ {
BeginScope(true);
try try
{ {
Definitions.Add(definition switch switch (definition)
{ {
FuncSyntax funcSyntax => CheckFuncDefinition(funcSyntax), case FuncSyntax funcSyntax:
StructSyntax structSyntax => CheckStructDefinition(structSyntax), Definitions.Add(CheckFuncDefinition(funcSyntax));
StructTemplateSyntax structTemplate => CheckStructTemplateDefinition(structTemplate), break;
_ => throw new ArgumentOutOfRangeException() case StructSyntax structSyntax:
}); Definitions.Add(CheckStructDefinition(structSyntax));
break;
case StructTemplateSyntax:
break;
default:
throw new ArgumentOutOfRangeException();
}
} }
catch (TypeCheckerException e) catch (TypeCheckerException e)
{ {
Diagnostics.Add(e.Diagnostic); Diagnostics.Add(e.Diagnostic);
} }
EndScope();
} }
} }
private void BeginScope(bool root) private ScopeDisposer BeginScope()
{ {
var scope = root if (_scopes.TryPeek(out var scope))
? _globalScope.SubScope() {
: _scopes.Peek().SubScope(); _scopes.Push(scope.SubScope());
}
_scopes.Push(scope); else
{
_scopes.Push(new Scope(_syntaxTree.Metadata.ModuleName));
} }
private void EndScope() return new ScopeDisposer(this);
}
private ScopeDisposer BeginRootScope(string module)
{ {
_scopes.Pop(); _scopes.Push(new Scope(module));
return new ScopeDisposer(this);
}
private sealed class ScopeDisposer(TypeChecker owner) : IDisposable
{
private bool _disposed;
public void Dispose()
{
if (_disposed) return;
owner._scopes.Pop();
_disposed = true;
}
} }
private StructNode CheckStructDefinition(StructSyntax node) private StructNode CheckStructDefinition(StructSyntax node)
{
using (BeginRootScope(_syntaxTree.Metadata.ModuleName))
{ {
var fieldTypes = node.Fields var fieldTypes = node.Fields
.Select(x => new NubStructFieldType(x.Name, ResolveType(x.Type), x.Value != null)) .Select(x => new NubStructFieldType(x.Name, ResolveType(x.Type), x.Value != null))
@@ -99,22 +120,15 @@ public sealed class TypeChecker
}) })
.ToList(); .ToList();
var structType = new NubStructType(CurrentModule, node.Name, fieldTypes, fieldFunctions); var structType = new NubStructType(CurrentScope.Module, node.Name, fieldTypes, fieldFunctions);
CurrentScope.DeclareVariable(new Variable("this", structType, VariableKind.RValue)); CurrentScope.DeclareVariable(new Variable("this", structType, VariableKind.RValue));
var fields = node.Fields.Select(CheckStructField).ToList(); var fields = node.Fields.Select(CheckStructField).ToList();
var functions = node.Functions.Select(CheckStructFunc).ToList(); var functions = node.Functions.Select(CheckStructFunc).ToList();
return new StructNode(CurrentModule, node.Name, fields, functions); return new StructNode(CurrentScope.Module, node.Name, fields, functions);
} }
private StructTemplateNode CheckStructTemplateDefinition(StructTemplateSyntax node)
{
var fields = node.Fields.Select(CheckStructField).ToList();
var functions = node.Functions.Select(CheckStructFunc).ToList();
return new StructTemplateNode(CurrentModule, node.Name, node.TemplateArguments, fields, functions);
} }
private StructFuncNode CheckStructFunc(StructFuncSyntax function) private StructFuncNode CheckStructFunc(StructFuncSyntax function)
@@ -142,6 +156,8 @@ public sealed class TypeChecker
} }
private FuncNode CheckFuncDefinition(FuncSyntax node) private FuncNode CheckFuncDefinition(FuncSyntax node)
{
using (BeginRootScope(_syntaxTree.Metadata.ModuleName))
{ {
foreach (var parameter in node.Signature.Parameters) foreach (var parameter in node.Signature.Parameters)
{ {
@@ -175,7 +191,8 @@ public sealed class TypeChecker
_funcReturnTypes.Pop(); _funcReturnTypes.Pop();
} }
return new FuncNode(CurrentModule, node.Name, node.ExternSymbol, signature, body); return new FuncNode(CurrentScope.Module, node.Name, node.ExternSymbol, signature, body);
}
} }
private AssignmentNode CheckAssignment(AssignmentSyntax statement) private AssignmentNode CheckAssignment(AssignmentSyntax statement)
@@ -588,14 +605,14 @@ public sealed class TypeChecker
} }
// Second, look in the current module for a function matching the identifier // Second, look in the current module for a function matching the identifier
var module = _visibleModules[CurrentModule]; var module = _visibleModules[CurrentScope.Module];
var function = module.Functions(true).FirstOrDefault(x => x.Name == expression.Name); var function = module.Functions(true).FirstOrDefault(x => x.Name == expression.Name);
if (function != null) if (function != null)
{ {
var parameters = function.Signature.Parameters.Select(x => ResolveType(x.Type)).ToList(); var parameters = function.Signature.Parameters.Select(x => ResolveType(x.Type)).ToList();
var type = new NubFuncType(parameters, ResolveType(function.Signature.ReturnType)); var type = new NubFuncType(parameters, ResolveType(function.Signature.ReturnType));
return new FuncIdentifierNode(type, CurrentModule, expression.Name, function.ExternSymbol); return new FuncIdentifierNode(type, CurrentScope.Module, expression.Name, function.ExternSymbol);
} }
throw new TypeCheckerException(Diagnostic.Error($"Symbol {expression.Name} not found").At(expression).Build()); throw new TypeCheckerException(Diagnostic.Error($"Symbol {expression.Name} not found").At(expression).Build());
@@ -612,7 +629,7 @@ public sealed class TypeChecker
.Build()); .Build());
} }
var includePrivate = expression.Module == CurrentModule; var includePrivate = expression.Module == CurrentScope.Module;
// First, look for the exported function in the specified module // First, look for the exported function in the specified module
var function = module.Functions(includePrivate).FirstOrDefault(x => x.Name == expression.Name); var function = module.Functions(includePrivate).FirstOrDefault(x => x.Name == expression.Name);
@@ -760,8 +777,8 @@ public sealed class TypeChecker
var reachable = true; var reachable = true;
var warnedUnreachable = false; var warnedUnreachable = false;
BeginScope(false); using (BeginScope())
{
foreach (var statement in node.Statements) foreach (var statement in node.Statements)
{ {
var checkedStatement = CheckStatement(statement); var checkedStatement = CheckStatement(statement);
@@ -785,10 +802,9 @@ public sealed class TypeChecker
} }
} }
EndScope();
return new BlockNode(statements); return new BlockNode(statements);
} }
}
private StatementNode CheckStatement(StatementSyntax statement) private StatementNode CheckStatement(StatementSyntax statement)
{ {
@@ -843,12 +859,27 @@ public sealed class TypeChecker
PointerTypeSyntax ptr => new NubPointerType(ResolveType(ptr.BaseType)), PointerTypeSyntax ptr => new NubPointerType(ResolveType(ptr.BaseType)),
StringTypeSyntax => new NubStringType(), StringTypeSyntax => new NubStringType(),
CustomTypeSyntax c => ResolveCustomType(c), CustomTypeSyntax c => ResolveCustomType(c),
TemplateTypeSyntax t => ResolveTemplateType(t), StructTemplateTypeSyntax t => ResolveStructTemplateType(t),
SubstitutionTypeSyntax s => ResolveTypeSubstitution(s),
VoidTypeSyntax => new NubVoidType(), VoidTypeSyntax => new NubVoidType(),
_ => throw new NotSupportedException($"Unknown type syntax: {type}") _ => throw new NotSupportedException($"Unknown type syntax: {type}")
}; };
} }
private NubType ResolveTypeSubstitution(SubstitutionTypeSyntax substitution)
{
var type = CurrentScope.LookupTypeSubstitution(substitution.Name);
if (type == null)
{
throw new TypeCheckerException(Diagnostic
.Error($"Template argument {substitution.Name} does not exist")
.At(substitution)
.Build());
}
return type;
}
private NubType ResolveCustomType(CustomTypeSyntax customType) private NubType ResolveCustomType(CustomTypeSyntax customType)
{ {
var key = (customType.Module, customType.Name); var key = (customType.Module, customType.Name);
@@ -876,21 +907,21 @@ public sealed class TypeChecker
.Build()); .Build());
} }
var includePrivate = customType.Module == CurrentModule; var includePrivate = customType.Module == CurrentScope.Module;
var strctDef = module.Structs(includePrivate).FirstOrDefault(x => x.Name == customType.Name); var structDef = module.Structs(includePrivate).FirstOrDefault(x => x.Name == customType.Name);
if (strctDef != null) if (structDef != null)
{ {
var result = new NubStructType(customType.Module, strctDef.Name, [], []); var result = new NubStructType(customType.Module, structDef.Name, [], []);
_typeCache[key] = result; _typeCache[key] = result;
var fields = strctDef.Fields var fields = structDef.Fields
.Select(x => new NubStructFieldType(x.Name, ResolveType(x.Type), x.Value != null)) .Select(x => new NubStructFieldType(x.Name, ResolveType(x.Type), x.Value != null))
.ToList(); .ToList();
result.Fields.AddRange(fields); result.Fields.AddRange(fields);
var functions = strctDef.Functions var functions = structDef.Functions
.Select(x => .Select(x =>
{ {
var parameters = x.Signature.Parameters var parameters = x.Signature.Parameters
@@ -918,9 +949,71 @@ public sealed class TypeChecker
} }
} }
private NubType ResolveTemplateType(TemplateTypeSyntax template) private NubStructType ResolveStructTemplateType(StructTemplateTypeSyntax structTemplate)
{ {
throw new NotImplementedException(); if (!_visibleModules.TryGetValue(structTemplate.Module, out var module))
{
throw new TypeCheckerException(Diagnostic
.Error($"Module {structTemplate.Module} not found")
.WithHelp($"import \"{structTemplate.Module}\"")
.At(structTemplate)
.Build());
}
var includePrivate = structTemplate.Module == CurrentScope.Module;
var templateDef = module
.StructTemplates(includePrivate)
.FirstOrDefault(x => x.Name == structTemplate.Name);
if (templateDef == null)
{
throw new TypeCheckerException(Diagnostic
.Error($"Template type {structTemplate.Name} not found in module {structTemplate.Module}")
.At(structTemplate)
.Build());
}
var templateParameterTypes = structTemplate.TemplateParameters.Select(ResolveType).ToList();
var mangledName = $"{structTemplate.Name}.{NameMangler.Mangle(templateParameterTypes)}";
using (BeginRootScope(structTemplate.Module))
{
for (var i = 0; i < templateParameterTypes.Count; i++)
{
var parameterName = templateDef.TemplateArguments[i];
var parameterType = templateParameterTypes[i];
CurrentScope.DeclareTypeSubstitution(parameterName, parameterType);
}
var fieldTypes = templateDef.Fields
.Select(x => new NubStructFieldType(x.Name, ResolveType(x.Type), x.Value != null))
.ToList();
var fieldFunctions = templateDef.Functions
.Select(x =>
{
var parameters = x.Signature.Parameters.Select(y => ResolveType(y.Type)).ToList();
var returnType = ResolveType(x.Signature.ReturnType);
return new NubStructFuncType(x.Name, x.Hook, parameters, returnType);
})
.ToList();
var structType = new NubStructType(structTemplate.Module, mangledName, fieldTypes, fieldFunctions);
if (!_checkedTemplateStructs.Contains($"{structTemplate.Module}.{mangledName}"))
{
CurrentScope.DeclareVariable(new Variable("this", structType, VariableKind.RValue));
var fields = templateDef.Fields.Select(CheckStructField).ToList();
var functions = templateDef.Functions.Select(CheckStructFunc).ToList();
Definitions.Add(new StructNode(structTemplate.Module, mangledName, fields, functions));
_checkedTemplateStructs.Add($"{structTemplate.Module}.{mangledName}");
}
ReferencedStructTypes.Add(structType);
return structType;
}
} }
} }
@@ -932,9 +1025,16 @@ public enum VariableKind
public record Variable(string Name, NubType Type, VariableKind Kind); public record Variable(string Name, NubType Type, VariableKind Kind);
public class Scope(Scope? parent = null) public class Scope(string module, Scope? parent = null)
{ {
private readonly List<Variable> _variables = []; private readonly List<Variable> _variables = [];
private readonly Dictionary<string, NubType> _typeSubstitutions = [];
public string Module { get; } = module;
public void DeclareVariable(Variable variable)
{
_variables.Add(variable);
}
public Variable? LookupVariable(string name) public Variable? LookupVariable(string name)
{ {
@@ -947,14 +1047,24 @@ public class Scope(Scope? parent = null)
return parent?.LookupVariable(name); return parent?.LookupVariable(name);
} }
public void DeclareVariable(Variable variable) public void DeclareTypeSubstitution(string name, NubType type)
{ {
_variables.Add(variable); _typeSubstitutions[name] = type;
}
public NubType? LookupTypeSubstitution(string name)
{
if (_typeSubstitutions.TryGetValue(name, out var type))
{
return type;
}
return parent?.LookupTypeSubstitution(name);
} }
public Scope SubScope() public Scope SubScope()
{ {
return new Scope(this); return new Scope(Module, this);
} }
} }

View File

@@ -3,8 +3,8 @@ NUBC = ../compiler/NubLang.CLI/bin/Debug/net9.0/nubc
out: .build/out.o out: .build/out.o
gcc -nostartfiles -o out x86_64.s .build/out.o gcc -nostartfiles -o out x86_64.s .build/out.o
.build/out.o: $(NUBC) src/main.nub .build/out.o: $(NUBC) src/main.nub src/ref.nub
$(NUBC) src/main.nub $(NUBC) src/main.nub src/ref.nub
.PHONY: $(NUBC) .PHONY: $(NUBC)
$(NUBC): $(NUBC):

View File

@@ -1,8 +1,6 @@
module "main" import "core"
extern "puts" func puts(text: cstring) module "main"
extern "malloc" func malloc(size: u64): ^void
extern "free" func free(address: ^void)
struct Human struct Human
{ {
@@ -11,47 +9,11 @@ struct Human
extern "main" func main(args: []cstring): i64 extern "main" func main(args: []cstring): i64
{ {
let x: ref<Human> = {} let x: core::ref<Human> = {}
test(x) test(x)
return 0 return 0
} }
func test(x: ref<Human>) func test(x: core::ref<Human>)
{ {
}
struct ref<T>
{
value: ^T
count: ^u64
@oncreate
func on_create()
{
puts("on_create")
this.value = @interpret(^T, malloc(@size(T)))
this.count = @interpret(^u64, malloc(@size(u64)))
this.count^ = 1
}
@oncopy
func on_copy()
{
puts("on_copy")
this.count^ = this.count^ + 1
}
@ondestroy
func on_destroy()
{
puts("on_destroy")
this.count^ = this.count^ - 1
if this.count^ <= 0
{
puts("free")
free(@interpret(^void, this.value))
free(@interpret(^void, this.count))
}
}
} }

40
example/src/ref.nub Normal file
View File

@@ -0,0 +1,40 @@
module "core"
extern "puts" func puts(text: cstring)
extern "malloc" func malloc(size: u64): ^void
extern "free" func free(address: ^void)
export struct ref<T>
{
value: ^T
count: ^u64
@oncreate
func on_create()
{
puts("on_create")
this.value = @interpret(^T, malloc(@size(T)))
this.count = @interpret(^u64, malloc(@size(u64)))
this.count^ = 1
}
@oncopy
func on_copy()
{
puts("on_copy")
this.count^ = this.count^ + 1
}
@ondestroy
func on_destroy()
{
puts("on_destroy")
this.count^ = this.count^ - 1
if this.count^ <= 0
{
puts("free")
free(@interpret(^void, this.value))
free(@interpret(^void, this.count))
}
}
}