defer and struct hooks

This commit is contained in:
nub31
2025-09-19 23:08:12 +02:00
parent 9167660dfc
commit c9064f0a43
15 changed files with 240 additions and 97 deletions

View File

@@ -16,11 +16,14 @@ public class QBEGenerator
private readonly List<StringLiteral> _stringLiterals = []; private readonly List<StringLiteral> _stringLiterals = [];
private readonly Stack<string> _breakLabels = []; private readonly Stack<string> _breakLabels = [];
private readonly Stack<string> _continueLabels = []; private readonly Stack<string> _continueLabels = [];
private readonly Stack<Scope> _scopes = new();
private int _tmpIndex; private int _tmpIndex;
private int _labelIndex; private int _labelIndex;
private int _cStringLiteralIndex; private int _cStringLiteralIndex;
private int _stringLiteralIndex; private int _stringLiteralIndex;
private Scope Scope => _scopes.Peek();
public QBEGenerator(List<DefinitionNode> definitions, HashSet<StructTypeNode> structTypes) public QBEGenerator(List<DefinitionNode> definitions, HashSet<StructTypeNode> structTypes)
{ {
_definitions = definitions; _definitions = definitions;
@@ -34,6 +37,7 @@ public class QBEGenerator
_stringLiterals.Clear(); _stringLiterals.Clear();
_breakLabels.Clear(); _breakLabels.Clear();
_continueLabels.Clear(); _continueLabels.Clear();
_scopes.Clear();
_tmpIndex = 0; _tmpIndex = 0;
_labelIndex = 0; _labelIndex = 0;
_cStringLiteralIndex = 0; _cStringLiteralIndex = 0;
@@ -307,6 +311,16 @@ public class QBEGenerator
{ {
var value = EmitExpression(source); var value = EmitExpression(source);
_writer.Indented($"blit {value}, {destination}, {SizeOf(source.Type)}"); _writer.Indented($"blit {value}, {destination}, {SizeOf(source.Type)}");
if (source.Type is StructTypeNode structType)
{
var copyFunc = structType.Functions.FirstOrDefault(x => x.Hook == "oncopy");
if (copyFunc != null)
{
_writer.Indented($"call {StructFuncName(structType.Module, structType.Name, copyFunc.Name)}(l {destination})");
}
}
return; return;
} }
@@ -349,48 +363,6 @@ public class QBEGenerator
} }
} }
private string EmitCopy(ExpressionNode source)
{
if (source is RValueExpressionNode || source.Type.IsScalar)
{
return EmitExpression(source);
}
var value = EmitExpression(source);
var destination = TmpName();
if (source.Type.IsValueType)
{
var size = SizeOf(source.Type);
_writer.Indented($"{destination} =l alloc8 {size}");
_writer.Indented($"blit {value}, {destination}, {size}");
return destination;
}
switch (source.Type)
{
case ArrayTypeNode arrayType:
var arraySize = EmitArraySizeInBytes(arrayType, value);
_writer.Indented($"{destination} =l alloc8 {arraySize}");
EmitMemcpy(value, destination, arraySize);
break;
case CStringTypeNode:
var cstrSize = EmitCStringSizeInBytes(value);
_writer.Indented($"{destination} =l alloc8 {cstrSize}");
EmitMemcpy(value, destination, cstrSize);
break;
case StringTypeNode:
var strSize = EmitStringSizeInBytes(value);
_writer.Indented($"{destination} =l alloc8 {strSize}");
EmitMemcpy(value, destination, strSize);
break;
default:
throw new InvalidOperationException($"Cannot copy type {source.Type}");
}
return destination;
}
private void EmitStructType(StructTypeNode structType) private void EmitStructType(StructTypeNode structType)
{ {
// todo(nub31): qbe expects structs to be declared in order. We must Check the dependencies of the struct to see if a type need to be declared before this one // todo(nub31): qbe expects structs to be declared in order. We must Check the dependencies of the struct to see if a type need to be declared before this one
@@ -450,7 +422,13 @@ public class QBEGenerator
_writer.WriteLine(") {"); _writer.WriteLine(") {");
_writer.WriteLine("@start"); _writer.WriteLine("@start");
EmitBlock(function.Body); var scope = new Scope();
foreach (var parameter in function.Signature.Parameters)
{
scope.Variables.Push(new Variable(parameter.Name, parameter.Type));
}
EmitBlock(function.Body, scope);
// Implicit return for void functions if no explicit return has been set // Implicit return for void functions if no explicit return has been set
if (function.Signature.ReturnType is VoidTypeNode && function.Body.Statements.LastOrDefault() is not ReturnNode) if (function.Signature.ReturnType is VoidTypeNode && function.Body.Statements.LastOrDefault() is not ReturnNode)
@@ -488,18 +466,31 @@ public class QBEGenerator
_writer.WriteLine(") {"); _writer.WriteLine(") {");
_writer.WriteLine("@start"); _writer.WriteLine("@start");
EmitBlock(funcDef.Body); var scope = new Scope();
foreach (var parameter in funcDef.Signature.Parameters)
{
scope.Variables.Push(new Variable(parameter.Name, parameter.Type));
}
EmitBlock(funcDef.Body, scope);
_writer.WriteLine("}"); _writer.WriteLine("}");
_writer.NewLine(); _writer.NewLine();
} }
private void EmitBlock(BlockNode block) private void EmitBlock(BlockNode block, Scope? scope = null)
{ {
scope ??= Scope.SubScope();
_scopes.Push(scope);
foreach (var statement in block.Statements) foreach (var statement in block.Statements)
{ {
EmitStatement(statement); EmitStatement(statement);
} }
EmitScopeCleanup();
_scopes.Pop();
} }
private void EmitStatement(StatementNode statement) private void EmitStatement(StatementNode statement)
@@ -510,9 +501,11 @@ public class QBEGenerator
EmitCopyInto(assignment.Value, EmitAddressOf(assignment.Target)); EmitCopyInto(assignment.Value, EmitAddressOf(assignment.Target));
break; break;
case BreakNode: case BreakNode:
EmitScopeCleanup();
_writer.Indented($"jmp {_breakLabels.Peek()}"); _writer.Indented($"jmp {_breakLabels.Peek()}");
break; break;
case ContinueNode: case ContinueNode:
EmitScopeCleanup();
_writer.Indented($"jmp {_continueLabels.Peek()}"); _writer.Indented($"jmp {_continueLabels.Peek()}");
break; break;
case IfNode ifStatement: case IfNode ifStatement:
@@ -527,6 +520,9 @@ public class QBEGenerator
case VariableDeclarationNode variableDeclaration: case VariableDeclarationNode variableDeclaration:
EmitVariableDeclaration(variableDeclaration); EmitVariableDeclaration(variableDeclaration);
break; break;
case DeferNode defer:
Scope.DeferredStatements.Push(defer.Statement);
break;
case WhileNode whileStatement: case WhileNode whileStatement:
EmitWhile(whileStatement); EmitWhile(whileStatement);
break; break;
@@ -538,6 +534,26 @@ public class QBEGenerator
} }
} }
private void EmitScopeCleanup()
{
while (Scope.DeferredStatements.TryPop(out var deferredStatement))
{
EmitStatement(deferredStatement);
}
while (Scope.Variables.TryPop(out var variable))
{
if (variable.Type is StructTypeNode structType)
{
var destroyFunc = structType.Functions.FirstOrDefault(x => x.Hook == "ondestroy");
if (destroyFunc != null)
{
_writer.Indented($"call {StructFuncName(structType.Module, structType.Name, destroyFunc.Name)}(l %{variable.Name})");
}
}
}
}
private void EmitIf(IfNode ifStatement) private void EmitIf(IfNode ifStatement)
{ {
var trueLabel = LabelName(); var trueLabel = LabelName();
@@ -552,7 +568,7 @@ public class QBEGenerator
_writer.WriteLine(falseLabel); _writer.WriteLine(falseLabel);
if (ifStatement.Else.HasValue) if (ifStatement.Else.HasValue)
{ {
ifStatement.Else.Value.Match(EmitIf, EmitBlock); ifStatement.Else.Value.Match(EmitIf, b => EmitBlock(b));
} }
_writer.WriteLine(endLabel); _writer.WriteLine(endLabel);
@@ -563,10 +579,12 @@ public class QBEGenerator
if (@return.Value.HasValue) if (@return.Value.HasValue)
{ {
var result = EmitExpression(@return.Value.Value); var result = EmitExpression(@return.Value.Value);
EmitScopeCleanup();
_writer.Indented($"ret {result}"); _writer.Indented($"ret {result}");
} }
else else
{ {
EmitScopeCleanup();
_writer.Indented("ret"); _writer.Indented("ret");
} }
} }
@@ -580,6 +598,8 @@ public class QBEGenerator
{ {
EmitCopyInto(variableDeclaration.Assignment.Value, name); EmitCopyInto(variableDeclaration.Assignment.Value, name);
} }
Scope.Variables.Push(new Variable(variableDeclaration.Name, variableDeclaration.Type));
} }
private void EmitWhile(WhileNode whileStatement) private void EmitWhile(WhileNode whileStatement)
@@ -677,7 +697,6 @@ public class QBEGenerator
BinaryExpressionNode expr => EmitBinaryExpression(expr), BinaryExpressionNode expr => EmitBinaryExpression(expr),
ConvertFloatNode expr => EmitConvertFloat(expr), ConvertFloatNode expr => EmitConvertFloat(expr),
ConvertIntNode expr => EmitConvertInt(expr), ConvertIntNode expr => EmitConvertInt(expr),
DereferenceNode expr => EmitLoad(expr.Type, EmitExpression(expr.Expression)),
FuncCallNode expr => EmitFuncCall(expr), FuncCallNode expr => EmitFuncCall(expr),
FuncIdentifierNode expr => FuncName(expr.Module, expr.Name, expr.ExternSymbol), FuncIdentifierNode expr => FuncName(expr.Module, expr.Name, expr.ExternSymbol),
FuncParameterIdentifierNode expr => $"%{expr.Name}", FuncParameterIdentifierNode expr => $"%{expr.Name}",
@@ -740,6 +759,7 @@ public class QBEGenerator
return lval switch return lval switch
{ {
ArrayIndexAccessNode arrayIndexAccess => EmitAddressOfArrayIndexAccess(arrayIndexAccess), ArrayIndexAccessNode arrayIndexAccess => EmitAddressOfArrayIndexAccess(arrayIndexAccess),
DereferenceNode dereference => EmitExpression(dereference.Expression),
StructFieldAccessNode structFieldAccess => EmitAddressOfStructFieldAccess(structFieldAccess), StructFieldAccessNode structFieldAccess => EmitAddressOfStructFieldAccess(structFieldAccess),
VariableIdentifierNode variableIdent => $"%{variableIdent.Name}", VariableIdentifierNode variableIdent => $"%{variableIdent.Name}",
_ => throw new ArgumentOutOfRangeException(nameof(lval)) _ => throw new ArgumentOutOfRangeException(nameof(lval))
@@ -948,6 +968,12 @@ public class QBEGenerator
_writer.Indented($"{destination} =l alloc8 {size}"); _writer.Indented($"{destination} =l alloc8 {size}");
_writer.Indented($"call {StructCtorName(structInitializer.StructType.Module, structInitializer.StructType.Name)}(l {destination})"); _writer.Indented($"call {StructCtorName(structInitializer.StructType.Module, structInitializer.StructType.Name)}(l {destination})");
var createFunc = structInitializer.StructType.Functions.FirstOrDefault(x => x.Hook == "oncreate");
if (createFunc != null)
{
_writer.Indented($"call {StructFuncName(structInitializer.StructType.Module, structInitializer.StructType.Name, createFunc.Name)}(l {destination})");
}
foreach (var (field, value) in structInitializer.Initializers) foreach (var (field, value) in structInitializer.Initializers)
{ {
var offset = TmpName(); var offset = TmpName();
@@ -1015,8 +1041,18 @@ public class QBEGenerator
foreach (var parameter in structFuncCall.Parameters) foreach (var parameter in structFuncCall.Parameters)
{ {
var copy = EmitCopy(parameter); var value = EmitExpression(parameter);
parameterStrings.Add($"{FuncQBETypeName(parameter.Type)} {copy}");
if (parameter.Type is StructTypeNode structType)
{
var copyFunc = structType.Functions.FirstOrDefault(x => x.Hook == "oncopy");
if (copyFunc != null)
{
_writer.Indented($"call {StructFuncName(structType.Module, structType.Name, copyFunc.Name)}(l {value})");
}
}
parameterStrings.Add($"{FuncQBETypeName(parameter.Type)} {value}");
} }
if (structFuncCall.Type is VoidTypeNode) if (structFuncCall.Type is VoidTypeNode)
@@ -1093,8 +1129,18 @@ public class QBEGenerator
foreach (var parameter in funcCall.Parameters) foreach (var parameter in funcCall.Parameters)
{ {
var copy = EmitCopy(parameter); var value = EmitExpression(parameter);
parameterStrings.Add($"{FuncQBETypeName(parameter.Type)} {copy}");
if (parameter.Type is StructTypeNode structType)
{
var copyFunc = structType.Functions.FirstOrDefault(x => x.Hook == "oncopy");
if (copyFunc != null)
{
_writer.Indented($"call {StructFuncName(structType.Module, structType.Name, copyFunc.Name)}(l {value})");
}
}
parameterStrings.Add($"{FuncQBETypeName(parameter.Type)} {value}");
} }
if (funcCall.Type is VoidTypeNode) if (funcCall.Type is VoidTypeNode)
@@ -1242,6 +1288,20 @@ public class QBEGenerator
} }
} }
// todo(nub31): Parent is not used when getting variables and deferred statements
public class Scope(Scope? parent = null)
{
public readonly Stack<StatementNode> DeferredStatements = [];
public readonly Stack<Variable> Variables = [];
public Scope SubScope()
{
return new Scope(this);
}
}
public record Variable(string Name, TypeNode Type);
public class StringLiteral(string value, string name) public class StringLiteral(string value, string name)
{ {
public string Value { get; } = value; public string Value { get; } = value;

View File

@@ -32,7 +32,7 @@ public record ModuleStructField(string Name, TypeSyntax Type, bool HasDefaultVal
public record ModuleStructFunctionParameter(string Name, TypeSyntax Type); public record ModuleStructFunctionParameter(string Name, TypeSyntax Type);
public record ModuleStructFunction(string Name, List<ModuleStructFunctionParameter> Parameters, TypeSyntax ReturnType); public record ModuleStructFunction(string Name, string? Hook, List<ModuleStructFunctionParameter> Parameters, TypeSyntax ReturnType);
public record ModuleStruct(bool Exported, string Name, List<ModuleStructField> Fields, List<ModuleStructFunction> Functions); public record ModuleStruct(bool Exported, string Name, List<ModuleStructField> Fields, List<ModuleStructFunction> Functions);

View File

@@ -36,7 +36,7 @@ public class ModuleRepository
foreach (var function in structDef.Functions) foreach (var function in structDef.Functions)
{ {
var parameters = function.Signature.Parameters.Select(x => new ModuleStructFunctionParameter(x.Name, x.Type)).ToList(); var parameters = function.Signature.Parameters.Select(x => new ModuleStructFunctionParameter(x.Name, x.Type)).ToList();
functions.AddRange(new ModuleStructFunction(function.Name, parameters, function.Signature.ReturnType)); functions.AddRange(new ModuleStructFunction(function.Name, function.Hook, parameters, function.Signature.ReturnType));
} }
module.RegisterStruct(structDef.Exported, structDef.Name, fields, functions); module.RegisterStruct(structDef.Exported, structDef.Name, fields, functions);

View File

@@ -175,13 +175,19 @@ public sealed class Parser
{ {
var memberStartIndex = _tokenIndex; var memberStartIndex = _tokenIndex;
string? hook = null;
if (TryExpectSymbol(Symbol.At))
{
hook = ExpectIdentifier().Value;
}
if (TryExpectSymbol(Symbol.Func)) if (TryExpectSymbol(Symbol.Func))
{ {
var funcName = ExpectIdentifier().Value; var funcName = ExpectIdentifier().Value;
var funcSignature = ParseFuncSignature(); var funcSignature = ParseFuncSignature();
var funcBody = ParseBlock(); var funcBody = ParseBlock();
funcs.Add(new StructFuncSyntax(GetTokens(memberStartIndex), funcName, funcSignature, funcBody)); funcs.Add(new StructFuncSyntax(GetTokens(memberStartIndex), funcName, hook, funcSignature, funcBody));
} }
else else
{ {
@@ -209,6 +215,8 @@ public sealed class Parser
{ {
switch (symbol.Symbol) switch (symbol.Symbol)
{ {
case Symbol.OpenBrace:
return ParseBlock();
case Symbol.Return: case Symbol.Return:
return ParseReturn(); return ParseReturn();
case Symbol.If: case Symbol.If:
@@ -219,6 +227,8 @@ public sealed class Parser
return ParseFor(); return ParseFor();
case Symbol.Let: case Symbol.Let:
return ParseVariableDeclaration(); return ParseVariableDeclaration();
case Symbol.Defer:
return ParseDefer();
case Symbol.Break: case Symbol.Break:
return ParseBreak(); return ParseBreak();
case Symbol.Continue: case Symbol.Continue:
@@ -264,14 +274,22 @@ public sealed class Parser
return new VariableDeclarationSyntax(GetTokens(startIndex), name, explicitType, assignment); return new VariableDeclarationSyntax(GetTokens(startIndex), name, explicitType, assignment);
} }
private StatementSyntax ParseBreak() private DeferSyntax ParseDefer()
{
var startIndex = _tokenIndex;
ExpectSymbol(Symbol.Defer);
var statement = ParseStatement();
return new DeferSyntax(GetTokens(startIndex), statement);
}
private BreakSyntax ParseBreak()
{ {
var startIndex = _tokenIndex; var startIndex = _tokenIndex;
ExpectSymbol(Symbol.Break); ExpectSymbol(Symbol.Break);
return new BreakSyntax(GetTokens(startIndex)); return new BreakSyntax(GetTokens(startIndex));
} }
private StatementSyntax ParseContinue() private ContinueSyntax ParseContinue()
{ {
var startIndex = _tokenIndex; var startIndex = _tokenIndex;
ExpectSymbol(Symbol.Continue); ExpectSymbol(Symbol.Continue);
@@ -303,9 +321,14 @@ public sealed class Parser
var elseStatement = Optional<Variant<IfSyntax, BlockSyntax>>.Empty(); var elseStatement = Optional<Variant<IfSyntax, BlockSyntax>>.Empty();
if (TryExpectSymbol(Symbol.Else)) if (TryExpectSymbol(Symbol.Else))
{ {
elseStatement = TryExpectSymbol(Symbol.If) if (CurrentToken is SymbolToken { Symbol: Symbol.If })
? (Variant<IfSyntax, BlockSyntax>)ParseIf() {
: (Variant<IfSyntax, BlockSyntax>)ParseBlock(); elseStatement = (Variant<IfSyntax, BlockSyntax>)ParseIf();
}
else
{
elseStatement = (Variant<IfSyntax, BlockSyntax>)ParseBlock();
}
} }
return new IfSyntax(GetTokens(startIndex), condition, body, elseStatement); return new IfSyntax(GetTokens(startIndex), condition, body, elseStatement);

View File

@@ -12,6 +12,6 @@ public record FuncSyntax(IEnumerable<Token> Tokens, string Name, bool Exported,
public record StructFieldSyntax(IEnumerable<Token> Tokens, string Name, TypeSyntax Type, Optional<ExpressionSyntax> Value) : SyntaxNode(Tokens); public record StructFieldSyntax(IEnumerable<Token> Tokens, string Name, TypeSyntax Type, Optional<ExpressionSyntax> Value) : SyntaxNode(Tokens);
public record StructFuncSyntax(IEnumerable<Token> Tokens, string Name, FuncSignatureSyntax Signature, BlockSyntax Body) : SyntaxNode(Tokens); public record StructFuncSyntax(IEnumerable<Token> Tokens, string Name, string? Hook, FuncSignatureSyntax Signature, BlockSyntax Body) : SyntaxNode(Tokens);
public record StructSyntax(IEnumerable<Token> Tokens, string Name, bool Exported, List<StructFieldSyntax> Fields, List<StructFuncSyntax> Functions) : DefinitionSyntax(Tokens, Name, Exported); public record StructSyntax(IEnumerable<Token> Tokens, string Name, bool Exported, List<StructFieldSyntax> Fields, List<StructFuncSyntax> Functions) : DefinitionSyntax(Tokens, Name, Exported);

View File

@@ -4,6 +4,8 @@ namespace NubLang.Parsing.Syntax;
public abstract record StatementSyntax(IEnumerable<Token> Tokens) : SyntaxNode(Tokens); public abstract record StatementSyntax(IEnumerable<Token> Tokens) : SyntaxNode(Tokens);
public record BlockSyntax(IEnumerable<Token> Tokens, List<StatementSyntax> Statements) : StatementSyntax(Tokens);
public record StatementExpressionSyntax(IEnumerable<Token> Tokens, ExpressionSyntax Expression) : StatementSyntax(Tokens); public record StatementExpressionSyntax(IEnumerable<Token> Tokens, ExpressionSyntax Expression) : StatementSyntax(Tokens);
public record ReturnSyntax(IEnumerable<Token> Tokens, Optional<ExpressionSyntax> Value) : StatementSyntax(Tokens); public record ReturnSyntax(IEnumerable<Token> Tokens, Optional<ExpressionSyntax> Value) : StatementSyntax(Tokens);
@@ -18,6 +20,8 @@ public record ContinueSyntax(IEnumerable<Token> Tokens) : StatementSyntax(Tokens
public record BreakSyntax(IEnumerable<Token> Tokens) : StatementSyntax(Tokens); public record BreakSyntax(IEnumerable<Token> Tokens) : StatementSyntax(Tokens);
public record DeferSyntax(IEnumerable<Token> Tokens, StatementSyntax Statement) : StatementSyntax(Tokens);
public record WhileSyntax(IEnumerable<Token> Tokens, ExpressionSyntax Condition, BlockSyntax Body) : StatementSyntax(Tokens); public record WhileSyntax(IEnumerable<Token> Tokens, ExpressionSyntax Condition, BlockSyntax Body) : StatementSyntax(Tokens);
public record ForSyntax(IEnumerable<Token> Tokens, string ElementIdent, string? IndexIdent, ExpressionSyntax Target, BlockSyntax Body) : StatementSyntax(Tokens); public record ForSyntax(IEnumerable<Token> Tokens, string ElementIdent, string? IndexIdent, ExpressionSyntax Target, BlockSyntax Body) : StatementSyntax(Tokens);

View File

@@ -7,5 +7,3 @@ public abstract record SyntaxNode(IEnumerable<Token> Tokens);
public record SyntaxTreeMetadata(string ModuleName, List<string> Imports); public record SyntaxTreeMetadata(string ModuleName, List<string> Imports);
public record SyntaxTree(List<DefinitionSyntax> Definitions, SyntaxTreeMetadata Metadata); public record SyntaxTree(List<DefinitionSyntax> Definitions, SyntaxTreeMetadata Metadata);
public record BlockSyntax(IEnumerable<Token> Tokens, List<StatementSyntax> Statements) : SyntaxNode(Tokens);

View File

@@ -80,4 +80,6 @@ public enum Symbol
Module, Module,
Import, Import,
Export, Export,
Defer,
At,
} }

View File

@@ -23,6 +23,7 @@ public sealed class Tokenizer
["module"] = Symbol.Module, ["module"] = Symbol.Module,
["export"] = Symbol.Export, ["export"] = Symbol.Export,
["import"] = Symbol.Import, ["import"] = Symbol.Import,
["defer"] = Symbol.Defer,
}; };
private static readonly Dictionary<char[], Symbol> Symbols = new() private static readonly Dictionary<char[], Symbol> Symbols = new()
@@ -58,6 +59,7 @@ public sealed class Tokenizer
[[';']] = Symbol.Semi, [[';']] = Symbol.Semi,
[['%']] = Symbol.Percent, [['%']] = Symbol.Percent,
[['|']] = Symbol.Pipe, [['|']] = Symbol.Pipe,
[['@']] = Symbol.At,
}; };
private static readonly (char[] Pattern, Symbol Symbol)[] OrderedSymbols = Symbols private static readonly (char[] Pattern, Symbol Symbol)[] OrderedSymbols = Symbols

View File

@@ -72,7 +72,7 @@ public record StructFieldAccessNode(TypeNode Type, StructTypeNode StructType, Ex
public record StructInitializerNode(StructTypeNode StructType, Dictionary<string, ExpressionNode> Initializers) : RValueExpressionNode(StructType); public record StructInitializerNode(StructTypeNode StructType, Dictionary<string, ExpressionNode> Initializers) : RValueExpressionNode(StructType);
public record DereferenceNode(TypeNode Type, ExpressionNode Expression) : RValueExpressionNode(Type); public record DereferenceNode(TypeNode Type, ExpressionNode Expression) : LValueExpressionNode(Type);
public record ConvertIntNode(TypeNode Type, ExpressionNode Value, IntTypeNode ValueType, IntTypeNode TargetType) : RValueExpressionNode(Type); public record ConvertIntNode(TypeNode Type, ExpressionNode Value, IntTypeNode ValueType, IntTypeNode TargetType) : RValueExpressionNode(Type);

View File

@@ -1,5 +1,3 @@
namespace NubLang.TypeChecking.Node; namespace NubLang.TypeChecking.Node;
public abstract record Node; public abstract record Node;
public record BlockNode(List<StatementNode> Statements) : Node;

View File

@@ -2,9 +2,13 @@
public abstract record StatementNode : Node; public abstract record StatementNode : Node;
public abstract record TerminalStatementNode : StatementNode;
public record BlockNode(List<StatementNode> Statements) : StatementNode;
public record StatementExpressionNode(ExpressionNode Expression) : StatementNode; public record StatementExpressionNode(ExpressionNode Expression) : StatementNode;
public record ReturnNode(Optional<ExpressionNode> Value) : StatementNode; public record ReturnNode(Optional<ExpressionNode> Value) : TerminalStatementNode;
public record AssignmentNode(LValueExpressionNode Target, ExpressionNode Value) : StatementNode; public record AssignmentNode(LValueExpressionNode Target, ExpressionNode Value) : StatementNode;
@@ -12,9 +16,11 @@ public record IfNode(ExpressionNode Condition, BlockNode Body, Optional<Variant<
public record VariableDeclarationNode(string Name, Optional<ExpressionNode> Assignment, TypeNode Type) : StatementNode; public record VariableDeclarationNode(string Name, Optional<ExpressionNode> Assignment, TypeNode Type) : StatementNode;
public record ContinueNode : StatementNode; public record ContinueNode : TerminalStatementNode;
public record BreakNode : StatementNode; public record BreakNode : TerminalStatementNode;
public record DeferNode(StatementNode Statement) : StatementNode;
public record WhileNode(ExpressionNode Condition, BlockNode Body) : StatementNode; public record WhileNode(ExpressionNode Condition, BlockNode Body) : StatementNode;

View File

@@ -119,9 +119,10 @@ public class StructTypeField(string name, TypeNode type, bool hasDefaultValue)
public bool HasDefaultValue { get; } = hasDefaultValue; public bool HasDefaultValue { get; } = hasDefaultValue;
} }
public class StructTypeFunc(string name, FuncTypeNode type) public class StructTypeFunc(string name, string? hook, FuncTypeNode type)
{ {
public string Name { get; } = name; public string Name { get; } = name;
public string? Hook { get; set; } = hook;
public FuncTypeNode Type { get; } = type; public FuncTypeNode Type { get; } = type;
} }

View File

@@ -73,7 +73,7 @@ public sealed class TypeChecker
{ {
var parameters = function.Signature.Parameters.Select(x => ResolveType(x.Type)).ToList(); var parameters = function.Signature.Parameters.Select(x => ResolveType(x.Type)).ToList();
var funcType = new FuncTypeNode(parameters, ResolveType(function.Signature.ReturnType)); var funcType = new FuncTypeNode(parameters, ResolveType(function.Signature.ReturnType));
functionTypes.Add(new StructTypeFunc(function.Name, funcType)); functionTypes.Add(new StructTypeFunc(function.Name, function.Hook, funcType));
} }
var type = new StructTypeNode(_syntaxTree.Metadata.ModuleName, node.Name, fieldTypes, functionTypes); var type = new StructTypeNode(_syntaxTree.Metadata.ModuleName, node.Name, fieldTypes, functionTypes);
@@ -127,11 +127,20 @@ public sealed class TypeChecker
body = CheckBlock(node.Body, scope); body = CheckBlock(node.Body, scope);
// Insert implicit return for void functions if (!AlwaysReturns(body))
if (signature.ReturnType is VoidTypeNode && body.Statements.LastOrDefault() is not ReturnNode) {
if (signature.ReturnType is VoidTypeNode)
{ {
body.Statements.Add(new ReturnNode(Optional<ExpressionNode>.Empty())); body.Statements.Add(new ReturnNode(Optional<ExpressionNode>.Empty()));
} }
else
{
Diagnostics.Add(Diagnostic
.Error("Not all code paths return a value")
.At(node.Body.Tokens.LastOrDefault())
.Build());
}
}
_funcReturnTypes.Pop(); _funcReturnTypes.Pop();
} }
@@ -150,6 +159,7 @@ public sealed class TypeChecker
ReturnSyntax statement => CheckReturn(statement), ReturnSyntax statement => CheckReturn(statement),
StatementExpressionSyntax statement => CheckStatementExpression(statement), StatementExpressionSyntax statement => CheckStatementExpression(statement),
VariableDeclarationSyntax statement => CheckVariableDeclaration(statement), VariableDeclarationSyntax statement => CheckVariableDeclaration(statement),
DeferSyntax statement => CheckDefer(statement),
WhileSyntax statement => CheckWhile(statement), WhileSyntax statement => CheckWhile(statement),
ForSyntax statement => CheckFor(statement), ForSyntax statement => CheckFor(statement),
_ => throw new ArgumentOutOfRangeException(nameof(node)) _ => throw new ArgumentOutOfRangeException(nameof(node))
@@ -224,6 +234,11 @@ public sealed class TypeChecker
return new VariableDeclarationNode(statement.Name, Optional.OfNullable(assignmentNode), type); return new VariableDeclarationNode(statement.Name, Optional.OfNullable(assignmentNode), type);
} }
private DeferNode CheckDefer(DeferSyntax statement)
{
return new DeferNode(CheckStatement(statement.Statement));
}
private WhileNode CheckWhile(WhileSyntax statement) private WhileNode CheckWhile(WhileSyntax statement)
{ {
var condition = CheckExpression(statement.Condition, new BoolTypeNode()); var condition = CheckExpression(statement.Condition, new BoolTypeNode());
@@ -734,8 +749,8 @@ public sealed class TypeChecker
if (missingFields.Length != 0) if (missingFields.Length != 0)
{ {
throw new TypeCheckerException(Diagnostic Diagnostics.Add(Diagnostic
.Error($"Fields {string.Join(", ", missingFields)} are not initialized") .Warning($"Fields {string.Join(", ", missingFields)} are not initialized")
.At(expression) .At(expression)
.Build()); .Build());
} }
@@ -759,7 +774,7 @@ public sealed class TypeChecker
if (reachable) if (reachable)
{ {
statements.Add(checkedStatement); statements.Add(checkedStatement);
if (checkedStatement is ReturnNode or BreakNode or ContinueNode) if (checkedStatement is TerminalStatementNode)
{ {
reachable = false; reachable = false;
} }
@@ -779,6 +794,17 @@ public sealed class TypeChecker
return new BlockNode(statements); return new BlockNode(statements);
} }
private bool AlwaysReturns(StatementNode statement)
{
return statement switch
{
ReturnNode => true,
BlockNode block => block.Statements.Count != 0 && AlwaysReturns(block.Statements.Last()),
IfNode ifNode => AlwaysReturns(ifNode.Body) && ifNode.Else.TryGetValue(out var elseStatement) ? elseStatement.Match(AlwaysReturns, AlwaysReturns) : true,
_ => false
};
}
private TypeNode ResolveType(TypeSyntax type) private TypeNode ResolveType(TypeSyntax type)
{ {
return type switch return type switch
@@ -838,7 +864,7 @@ public sealed class TypeChecker
{ {
var parameters = function.Parameters.Select(x => ResolveType(x.Type)).ToList(); var parameters = function.Parameters.Select(x => ResolveType(x.Type)).ToList();
var type = new FuncTypeNode(parameters, ResolveType(function.ReturnType)); var type = new FuncTypeNode(parameters, ResolveType(function.ReturnType));
result.Functions.Add(new StructTypeFunc(function.Name, type)); result.Functions.Add(new StructTypeFunc(function.Name, function.Hook, type));
} }
ReferencedStructTypes.Add(result); ReferencedStructTypes.Add(result);

View File

@@ -1,6 +1,8 @@
module "main" module "main"
extern "puts" func puts(text: cstring) extern "puts" func puts(text: cstring)
extern "malloc" func malloc(size: u64): ^u64
extern "free" func free(address: ^u64)
struct Human struct Human
{ {
@@ -9,28 +11,49 @@ struct Human
extern "main" func main(args: []cstring): i64 extern "main" func main(args: []cstring): i64
{ {
let x = [2]cstring let x: ref = {}
x[0] = "test1" test(x)
x[1] = "test2"
for u in x
{
puts(u)
}
let me: Human = {
name = "test"
}
puts(me.name)
test(32)
return 0 return 0
} }
func test(x: ref)
func test(x: u8)
{ {
}
struct ref
{
value: ^u64
count: ^u64
@oncreate
func on_create()
{
puts("on_create")
this.value = malloc(8)
this.count = malloc(8)
this.count^ = 1
}
@oncopy
func on_copy()
{
puts("on_copy")
this.count^ = this.count^ + 1
}
@ondestroy
func on_destroy()
{
puts("on_destroy")
this.count^ = this.count^ - 1
if this.count^ <= 0
{
puts("free")
free(this.value)
free(this.count)
}
}
} }