Fix type checking for anonymous functions

This commit is contained in:
nub31
2025-06-08 17:26:50 +02:00
parent de071808e6
commit dce5a2b566
7 changed files with 140 additions and 31 deletions

View File

@@ -1,16 +1,18 @@
#!/bin/bash #!/bin/bash
set -e set -e
./clean.sh
mkdir -p out mkdir -p out
dotnet build src/lang/Nub.Lang.CLI dotnet build src/lang/Nub.Lang.CLI
nub example > out/out.qbe nub example > out/out.ssa
nasm -g -felf64 src/runtime/runtime.asm -o out/runtime.o nasm -g -felf64 src/runtime/runtime.asm -o out/runtime.o
nasm -g -felf64 src/runtime/core/syscall.asm -o out/syscall.o nasm -g -felf64 src/runtime/core/syscall.asm -o out/syscall.o
qbe out/out.qbe > out/out.s qbe out/out.ssa > out/out.s
gcc -c -g out/out.s -o out/out.o gcc -c -g out/out.s -o out/out.o
gcc -nostartfiles -o out/program out/runtime.o out/syscall.o out/out.o gcc -nostartfiles -o out/program out/runtime.o out/syscall.o out/out.o -no-pie

View File

@@ -2,23 +2,42 @@ namespace main
struct Human { struct Human {
age: u64 age: u64
print_age: func() = () => { print_age: func() = func() {
c::puts("pwp")
} }
} }
export func main(args: []^string): i64 { export func main(args: []^string): i64 {
let x = 2
let uwu = func() {
c::puts("uwu")
}
uwu()
func() {
c::puts("owo")
}()
let me = alloc Human { let me = alloc Human {
age = 23 age = 23
} }
me.print_age() me.print_age()
print_age() if true {
// do something
}
c::puts("test")
let i = 1
while i <= 10 {
c::puts("test")
i = i + 1
}
return 0 return 0
} }
func print_age() {
c::puts("TEST")
}

1
run.sh
View File

@@ -1,5 +1,4 @@
#!/bin/bash #!/bin/bash
set -e set -e
./clean.sh
./build.sh ./build.sh
bash -c './out/program; echo "Process exited with status code $?"' bash -c './out/program; echo "Process exited with status code $?"'

View File

@@ -24,6 +24,8 @@ public class QBEGenerator
private int _labelIndex; private int _labelIndex;
private bool _codeIsReachable = true; private bool _codeIsReachable = true;
private Dictionary<IFuncSignature, string> _funcNames = []; private Dictionary<IFuncSignature, string> _funcNames = [];
private Dictionary<AnonymousFuncNode, string> _anonymousFunctions = [];
private int _anonymousFuncIndex;
public string Generate(List<SourceFile> sourceFiles) public string Generate(List<SourceFile> sourceFiles)
{ {
@@ -34,8 +36,10 @@ public class QBEGenerator
_funcNames = []; _funcNames = [];
_breakLabels = []; _breakLabels = [];
_continueLabels = []; _continueLabels = [];
_anonymousFunctions = [];
_variableIndex = 0; _variableIndex = 0;
_labelIndex = 0; _labelIndex = 0;
_anonymousFuncIndex = 0;
_codeIsReachable = true; _codeIsReachable = true;
foreach (var structDef in _sourceFiles.SelectMany(f => f.Definitions).OfType<StructDefinitionNode>()) foreach (var structDef in _sourceFiles.SelectMany(f => f.Definitions).OfType<StructDefinitionNode>())
@@ -69,7 +73,13 @@ public class QBEGenerator
foreach (var funcDef in _sourceFiles.SelectMany(f => f.Definitions).OfType<LocalFuncDefinitionNode>()) foreach (var funcDef in _sourceFiles.SelectMany(f => f.Definitions).OfType<LocalFuncDefinitionNode>())
{ {
GenerateFuncDefinition(funcDef); GenerateFuncDefinition(_funcNames[funcDef], funcDef.Parameters, funcDef.ReturnType, funcDef.Body, funcDef.Exported);
_builder.AppendLine();
}
foreach (var (func, name) in _anonymousFunctions)
{
GenerateFuncDefinition(name, func.Parameters, func.ReturnType, func.Body, false);
_builder.AppendLine(); _builder.AppendLine();
} }
@@ -304,19 +314,19 @@ public class QBEGenerator
}; };
} }
private void GenerateFuncDefinition(LocalFuncDefinitionNode node) private void GenerateFuncDefinition(string name, List<FuncParameter> parameters, NubType returnType, BlockNode body, bool exported)
{ {
_variables.Clear(); _variables.Clear();
if (node.Exported) if (exported)
{ {
_builder.Append("export "); _builder.Append("export ");
} }
_builder.Append("function "); _builder.Append("function ");
if (node.ReturnType is not NubVoidType) if (returnType is not NubVoidType)
{ {
_builder.Append(node.ReturnType switch _builder.Append(returnType switch
{ {
NubArrayType => "l", NubArrayType => "l",
NubPointerType => "l", NubPointerType => "l",
@@ -338,14 +348,14 @@ public class QBEGenerator
NubStructType structType => $":{structType.Namespace}_{structType.Name}", NubStructType structType => $":{structType.Namespace}_{structType.Name}",
NubFixedArrayType => "l", NubFixedArrayType => "l",
NubFuncType => "l", NubFuncType => "l",
_ => throw new NotSupportedException($"'{node.ReturnType}' type cannot be used as a function return type") _ => throw new NotSupportedException($"'{returnType}' type cannot be used as a function return type")
}); });
_builder.Append(' '); _builder.Append(' ');
} }
_builder.Append(_funcNames[node]); _builder.Append(name);
var parameterStrings = node.Parameters.Select(parameter => $"{parameter.Type switch var parameterStrings = parameters.Select(parameter => $"{parameter.Type switch
{ {
NubArrayType => "l", NubArrayType => "l",
NubPointerType => "l", NubPointerType => "l",
@@ -373,7 +383,7 @@ public class QBEGenerator
_builder.AppendLine($"({string.Join(", ", parameterStrings)}) {{"); _builder.AppendLine($"({string.Join(", ", parameterStrings)}) {{");
_builder.AppendLine("@start"); _builder.AppendLine("@start");
foreach (var parameter in node.Parameters) foreach (var parameter in parameters)
{ {
var parameterName = "%" + parameter.Name; var parameterName = "%" + parameter.Name;
@@ -403,11 +413,11 @@ public class QBEGenerator
_variables[parameter.Name] = parameterName; _variables[parameter.Name] = parameterName;
} }
GenerateBlock(node.Body); GenerateBlock(body);
if (node.Body.Statements.LastOrDefault() is not ReturnNode) if (body.Statements.LastOrDefault() is not ReturnNode)
{ {
if (node.ReturnType is NubVoidType) if (returnType is NubVoidType)
{ {
_builder.AppendLine(" ret"); _builder.AppendLine(" ret");
} }
@@ -673,6 +683,7 @@ public class QBEGenerator
return expression switch return expression switch
{ {
AddressOfNode addressOf => GenerateAddressOf(addressOf), AddressOfNode addressOf => GenerateAddressOf(addressOf),
AnonymousFuncNode anonymousFunc => GenerateAnonymousFunc(anonymousFunc),
ArrayIndexAccessNode arrayIndex => GenerateArrayAccessIndex(arrayIndex), ArrayIndexAccessNode arrayIndex => GenerateArrayAccessIndex(arrayIndex),
ArrayInitializerNode arrayInitializer => GenerateArrayInitializer(arrayInitializer), ArrayInitializerNode arrayInitializer => GenerateArrayInitializer(arrayInitializer),
BinaryExpressionNode binaryExpression => GenerateBinaryExpression(binaryExpression), BinaryExpressionNode binaryExpression => GenerateBinaryExpression(binaryExpression),
@@ -688,6 +699,16 @@ public class QBEGenerator
}; };
} }
private string GenerateAnonymousFunc(AnonymousFuncNode anonymousFunc)
{
var name = $"$anon_func{++_anonymousFuncIndex}";
_anonymousFunctions[anonymousFunc] = name;
var pointer = GenVarName();
_builder.AppendLine($" {pointer} =l alloc8 8");
_builder.AppendLine($" storel {name}, {pointer}");
return pointer;
}
private string GenerateArrayIndexPointer(ArrayIndexAccessNode arrayIndexAccess) private string GenerateArrayIndexPointer(ArrayIndexAccessNode arrayIndexAccess)
{ {
var array = GenerateExpression(arrayIndexAccess.Array); var array = GenerateExpression(arrayIndexAccess.Array);
@@ -1036,12 +1057,25 @@ public class QBEGenerator
private string GenerateIdentifier(IdentifierNode identifier) private string GenerateIdentifier(IdentifierNode identifier)
{ {
if (_variables.TryGetValue(identifier.Name, out var value)) if (_variables.TryGetValue(identifier.Name, out var value))
{
if (IsLargeType(identifier.Type))
{ {
return value; return value;
} }
else else
{ {
return _funcNames[LookupFuncSignature(identifier.Namespace, identifier.Name)]; var result = GenVarName();
_builder.AppendLine($" {result} {QBEAssign(identifier.Type)} {QBELoad(identifier.Type)} {value}");
return result;
}
}
else
{
var funcName = _funcNames[LookupFuncSignature(identifier.Namespace, identifier.Name)];
var pointer = GenVarName();
_builder.AppendLine($" {pointer} =l alloc8 8");
_builder.AppendLine($" storel {funcName}, {pointer}");
return pointer;
} }
} }
@@ -1270,7 +1304,10 @@ public class QBEGenerator
parameterStrings.Add($"{qbeParameterType} {result}"); parameterStrings.Add($"{qbeParameterType} {result}");
} }
var funcPointer = GenerateExpression(funcCall.Expression); // var funcPointer = GenerateExpression(funcCall.Expression);
var funcPointerPointer = GenerateExpression(funcCall.Expression);
var funcPointer = GenVarName();
_builder.AppendLine($" {funcPointer} =l loadl {funcPointerPointer}");
if (funcType.ReturnType is not NubVoidType) if (funcType.ReturnType is not NubVoidType)
{ {

View File

@@ -0,0 +1,13 @@
using Nub.Lang.Frontend.Lexing;
using Nub.Lang.Frontend.Parsing.Definitions;
using Nub.Lang.Frontend.Parsing.Statements;
using Nub.Lang.Frontend.Typing;
namespace Nub.Lang.Frontend.Parsing.Expressions;
public class AnonymousFuncNode(IReadOnlyList<Token> tokens, List<FuncParameter> parameters, BlockNode body, NubType returnType) : ExpressionNode(tokens)
{
public List<FuncParameter> Parameters { get; } = parameters;
public BlockNode Body { get; } = body;
public NubType ReturnType { get; } = returnType;
}

View File

@@ -437,6 +437,31 @@ public class Parser
{ {
switch (symbolToken.Symbol) switch (symbolToken.Symbol)
{ {
case Symbol.Func:
{
List<FuncParameter> parameters = [];
ExpectSymbol(Symbol.OpenParen);
while (!TryExpectSymbol(Symbol.CloseParen))
{
var parameter = ParseFuncParameter();
parameters.Add(parameter);
if (!TryExpectSymbol(Symbol.Comma) && Peek().TryGetValue(out var nextToken) && nextToken is not SymbolToken { Symbol: Symbol.CloseParen })
{
_diagnostics.Add(Diagnostic
.Warning("Missing comma between function arguments")
.WithHelp("Add a ',' to separate arguments")
.At(nextToken)
.Build());
}
}
var returnType = TryExpectSymbol(Symbol.Colon) ? ParseType() : new NubVoidType();
var body = ParseBlock();
expr = new AnonymousFuncNode(GetTokensForNode(startIndex), parameters, body, returnType);
break;
}
case Symbol.OpenParen: case Symbol.OpenParen:
{ {
var expression = ParseExpression(); var expression = ParseExpression();

View File

@@ -14,6 +14,7 @@ public class TypeChecker
private List<SourceFile> _sourceFiles = []; private List<SourceFile> _sourceFiles = [];
private List<Diagnostic> _diagnostics = []; private List<Diagnostic> _diagnostics = [];
private NubType? _currentFunctionReturnType; private NubType? _currentFunctionReturnType;
private List<AnonymousFuncNode> _anonymousFunctions = [];
public DiagnosticsResult TypeCheck(List<SourceFile> sourceFiles) public DiagnosticsResult TypeCheck(List<SourceFile> sourceFiles)
{ {
@@ -21,6 +22,7 @@ public class TypeChecker
_diagnostics = []; _diagnostics = [];
_currentFunctionReturnType = null; _currentFunctionReturnType = null;
_sourceFiles = sourceFiles; _sourceFiles = sourceFiles;
_anonymousFunctions = [];
var externFuncDefinitions = _sourceFiles var externFuncDefinitions = _sourceFiles
.SelectMany(f => f.Definitions) .SelectMany(f => f.Definitions)
@@ -50,7 +52,12 @@ public class TypeChecker
foreach (var funcDef in _sourceFiles.SelectMany(f => f.Definitions).OfType<LocalFuncDefinitionNode>()) foreach (var funcDef in _sourceFiles.SelectMany(f => f.Definitions).OfType<LocalFuncDefinitionNode>())
{ {
TypeCheckFuncDef(funcDef); TypeCheckFuncDef(funcDef.Parameters, funcDef.Body, funcDef.ReturnType);
}
foreach (var anonymousFuncNode in _anonymousFunctions)
{
TypeCheckFuncDef(anonymousFuncNode.Parameters, anonymousFuncNode.Body, anonymousFuncNode.ReturnType);
} }
return new DiagnosticsResult(_diagnostics); return new DiagnosticsResult(_diagnostics);
@@ -80,17 +87,17 @@ public class TypeChecker
} }
} }
private void TypeCheckFuncDef(LocalFuncDefinitionNode funcDef) private void TypeCheckFuncDef(List<FuncParameter> parameters, BlockNode body, NubType returnType)
{ {
_variables.Clear(); _variables.Clear();
_currentFunctionReturnType = funcDef.ReturnType; _currentFunctionReturnType = returnType;
foreach (var param in funcDef.Parameters) foreach (var param in parameters)
{ {
_variables[param.Name] = param.Type; _variables[param.Name] = param.Type;
} }
TypeCheckBlock(funcDef.Body); TypeCheckBlock(body);
} }
private void TypeCheckBlock(BlockNode block) private void TypeCheckBlock(BlockNode block)
@@ -343,6 +350,7 @@ public class TypeChecker
var resultType = expression switch var resultType = expression switch
{ {
AddressOfNode addressOf => TypeCheckAddressOf(addressOf), AddressOfNode addressOf => TypeCheckAddressOf(addressOf),
AnonymousFuncNode anonymousFunc => TypeCheckAnonymousFunc(anonymousFunc),
ArrayIndexAccessNode arrayIndex => TypeCheckArrayIndex(arrayIndex), ArrayIndexAccessNode arrayIndex => TypeCheckArrayIndex(arrayIndex),
ArrayInitializerNode arrayInitializer => TypeCheckArrayInitializer(arrayInitializer), ArrayInitializerNode arrayInitializer => TypeCheckArrayInitializer(arrayInitializer),
LiteralNode literal => TypeCheckLiteral(literal, expectedType), LiteralNode literal => TypeCheckLiteral(literal, expectedType),
@@ -365,6 +373,12 @@ public class TypeChecker
return resultType; return resultType;
} }
private NubType TypeCheckAnonymousFunc(AnonymousFuncNode anonymousFunc)
{
_anonymousFunctions.Add(anonymousFunc);
return new NubFuncType(anonymousFunc.ReturnType, anonymousFunc.Parameters.Select(p => p.Type).ToList());
}
private NubType? TypeCheckLiteral(LiteralNode literal, NubType? expectedType = null) private NubType? TypeCheckLiteral(LiteralNode literal, NubType? expectedType = null)
{ {
if (expectedType != null) if (expectedType != null)