From c9064f0a4357e35abf6c53211e6eecd8bef6f5af Mon Sep 17 00:00:00 2001 From: nub31 Date: Fri, 19 Sep 2025 23:08:12 +0200 Subject: [PATCH] defer and struct hooks --- .../NubLang/Generation/QBE/QBEGenerator.cs | 162 ++++++++++++------ compiler/NubLang/Modules/Module.cs | 2 +- compiler/NubLang/Modules/ModuleRepository.cs | 2 +- compiler/NubLang/Parsing/Parser.cs | 35 +++- .../Parsing/Syntax/DefinitionSyntax.cs | 2 +- .../NubLang/Parsing/Syntax/StatementSyntax.cs | 4 + compiler/NubLang/Parsing/Syntax/SyntaxNode.cs | 4 +- compiler/NubLang/Tokenization/Token.cs | 2 + compiler/NubLang/Tokenization/Tokenizer.cs | 2 + .../TypeChecking/Node/ExpressionNode.cs | 2 +- compiler/NubLang/TypeChecking/Node/Node.cs | 4 +- .../TypeChecking/Node/StatementNode.cs | 12 +- .../NubLang/TypeChecking/Node/TypeNode.cs | 3 +- compiler/NubLang/TypeChecking/TypeChecker.cs | 42 ++++- example/src/main.nub | 59 +++++-- 15 files changed, 240 insertions(+), 97 deletions(-) diff --git a/compiler/NubLang/Generation/QBE/QBEGenerator.cs b/compiler/NubLang/Generation/QBE/QBEGenerator.cs index 0601506..be5e15b 100644 --- a/compiler/NubLang/Generation/QBE/QBEGenerator.cs +++ b/compiler/NubLang/Generation/QBE/QBEGenerator.cs @@ -16,11 +16,14 @@ public class QBEGenerator private readonly List _stringLiterals = []; private readonly Stack _breakLabels = []; private readonly Stack _continueLabels = []; + private readonly Stack _scopes = new(); private int _tmpIndex; private int _labelIndex; private int _cStringLiteralIndex; private int _stringLiteralIndex; + private Scope Scope => _scopes.Peek(); + public QBEGenerator(List definitions, HashSet structTypes) { _definitions = definitions; @@ -34,6 +37,7 @@ public class QBEGenerator _stringLiterals.Clear(); _breakLabels.Clear(); _continueLabels.Clear(); + _scopes.Clear(); _tmpIndex = 0; _labelIndex = 0; _cStringLiteralIndex = 0; @@ -307,6 +311,16 @@ public class QBEGenerator { var value = EmitExpression(source); _writer.Indented($"blit {value}, {destination}, {SizeOf(source.Type)}"); + + if (source.Type is StructTypeNode structType) + { + var copyFunc = structType.Functions.FirstOrDefault(x => x.Hook == "oncopy"); + if (copyFunc != null) + { + _writer.Indented($"call {StructFuncName(structType.Module, structType.Name, copyFunc.Name)}(l {destination})"); + } + } + return; } @@ -349,48 +363,6 @@ public class QBEGenerator } } - private string EmitCopy(ExpressionNode source) - { - if (source is RValueExpressionNode || source.Type.IsScalar) - { - return EmitExpression(source); - } - - var value = EmitExpression(source); - var destination = TmpName(); - - if (source.Type.IsValueType) - { - var size = SizeOf(source.Type); - _writer.Indented($"{destination} =l alloc8 {size}"); - _writer.Indented($"blit {value}, {destination}, {size}"); - return destination; - } - - switch (source.Type) - { - case ArrayTypeNode arrayType: - var arraySize = EmitArraySizeInBytes(arrayType, value); - _writer.Indented($"{destination} =l alloc8 {arraySize}"); - EmitMemcpy(value, destination, arraySize); - break; - case CStringTypeNode: - var cstrSize = EmitCStringSizeInBytes(value); - _writer.Indented($"{destination} =l alloc8 {cstrSize}"); - EmitMemcpy(value, destination, cstrSize); - break; - case StringTypeNode: - var strSize = EmitStringSizeInBytes(value); - _writer.Indented($"{destination} =l alloc8 {strSize}"); - EmitMemcpy(value, destination, strSize); - break; - default: - throw new InvalidOperationException($"Cannot copy type {source.Type}"); - } - - return destination; - } - private void EmitStructType(StructTypeNode structType) { // todo(nub31): qbe expects structs to be declared in order. We must Check the dependencies of the struct to see if a type need to be declared before this one @@ -450,7 +422,13 @@ public class QBEGenerator _writer.WriteLine(") {"); _writer.WriteLine("@start"); - EmitBlock(function.Body); + var scope = new Scope(); + foreach (var parameter in function.Signature.Parameters) + { + scope.Variables.Push(new Variable(parameter.Name, parameter.Type)); + } + + EmitBlock(function.Body, scope); // Implicit return for void functions if no explicit return has been set if (function.Signature.ReturnType is VoidTypeNode && function.Body.Statements.LastOrDefault() is not ReturnNode) @@ -488,18 +466,31 @@ public class QBEGenerator _writer.WriteLine(") {"); _writer.WriteLine("@start"); - EmitBlock(funcDef.Body); + var scope = new Scope(); + foreach (var parameter in funcDef.Signature.Parameters) + { + scope.Variables.Push(new Variable(parameter.Name, parameter.Type)); + } + + EmitBlock(funcDef.Body, scope); _writer.WriteLine("}"); _writer.NewLine(); } - private void EmitBlock(BlockNode block) + private void EmitBlock(BlockNode block, Scope? scope = null) { + scope ??= Scope.SubScope(); + _scopes.Push(scope); + foreach (var statement in block.Statements) { EmitStatement(statement); } + + EmitScopeCleanup(); + + _scopes.Pop(); } private void EmitStatement(StatementNode statement) @@ -510,9 +501,11 @@ public class QBEGenerator EmitCopyInto(assignment.Value, EmitAddressOf(assignment.Target)); break; case BreakNode: + EmitScopeCleanup(); _writer.Indented($"jmp {_breakLabels.Peek()}"); break; case ContinueNode: + EmitScopeCleanup(); _writer.Indented($"jmp {_continueLabels.Peek()}"); break; case IfNode ifStatement: @@ -527,6 +520,9 @@ public class QBEGenerator case VariableDeclarationNode variableDeclaration: EmitVariableDeclaration(variableDeclaration); break; + case DeferNode defer: + Scope.DeferredStatements.Push(defer.Statement); + break; case WhileNode whileStatement: EmitWhile(whileStatement); break; @@ -538,6 +534,26 @@ public class QBEGenerator } } + private void EmitScopeCleanup() + { + while (Scope.DeferredStatements.TryPop(out var deferredStatement)) + { + EmitStatement(deferredStatement); + } + + while (Scope.Variables.TryPop(out var variable)) + { + if (variable.Type is StructTypeNode structType) + { + var destroyFunc = structType.Functions.FirstOrDefault(x => x.Hook == "ondestroy"); + if (destroyFunc != null) + { + _writer.Indented($"call {StructFuncName(structType.Module, structType.Name, destroyFunc.Name)}(l %{variable.Name})"); + } + } + } + } + private void EmitIf(IfNode ifStatement) { var trueLabel = LabelName(); @@ -552,7 +568,7 @@ public class QBEGenerator _writer.WriteLine(falseLabel); if (ifStatement.Else.HasValue) { - ifStatement.Else.Value.Match(EmitIf, EmitBlock); + ifStatement.Else.Value.Match(EmitIf, b => EmitBlock(b)); } _writer.WriteLine(endLabel); @@ -563,10 +579,12 @@ public class QBEGenerator if (@return.Value.HasValue) { var result = EmitExpression(@return.Value.Value); + EmitScopeCleanup(); _writer.Indented($"ret {result}"); } else { + EmitScopeCleanup(); _writer.Indented("ret"); } } @@ -580,6 +598,8 @@ public class QBEGenerator { EmitCopyInto(variableDeclaration.Assignment.Value, name); } + + Scope.Variables.Push(new Variable(variableDeclaration.Name, variableDeclaration.Type)); } private void EmitWhile(WhileNode whileStatement) @@ -677,7 +697,6 @@ public class QBEGenerator BinaryExpressionNode expr => EmitBinaryExpression(expr), ConvertFloatNode expr => EmitConvertFloat(expr), ConvertIntNode expr => EmitConvertInt(expr), - DereferenceNode expr => EmitLoad(expr.Type, EmitExpression(expr.Expression)), FuncCallNode expr => EmitFuncCall(expr), FuncIdentifierNode expr => FuncName(expr.Module, expr.Name, expr.ExternSymbol), FuncParameterIdentifierNode expr => $"%{expr.Name}", @@ -740,6 +759,7 @@ public class QBEGenerator return lval switch { ArrayIndexAccessNode arrayIndexAccess => EmitAddressOfArrayIndexAccess(arrayIndexAccess), + DereferenceNode dereference => EmitExpression(dereference.Expression), StructFieldAccessNode structFieldAccess => EmitAddressOfStructFieldAccess(structFieldAccess), VariableIdentifierNode variableIdent => $"%{variableIdent.Name}", _ => throw new ArgumentOutOfRangeException(nameof(lval)) @@ -948,6 +968,12 @@ public class QBEGenerator _writer.Indented($"{destination} =l alloc8 {size}"); _writer.Indented($"call {StructCtorName(structInitializer.StructType.Module, structInitializer.StructType.Name)}(l {destination})"); + var createFunc = structInitializer.StructType.Functions.FirstOrDefault(x => x.Hook == "oncreate"); + if (createFunc != null) + { + _writer.Indented($"call {StructFuncName(structInitializer.StructType.Module, structInitializer.StructType.Name, createFunc.Name)}(l {destination})"); + } + foreach (var (field, value) in structInitializer.Initializers) { var offset = TmpName(); @@ -1015,8 +1041,18 @@ public class QBEGenerator foreach (var parameter in structFuncCall.Parameters) { - var copy = EmitCopy(parameter); - parameterStrings.Add($"{FuncQBETypeName(parameter.Type)} {copy}"); + var value = EmitExpression(parameter); + + if (parameter.Type is StructTypeNode structType) + { + var copyFunc = structType.Functions.FirstOrDefault(x => x.Hook == "oncopy"); + if (copyFunc != null) + { + _writer.Indented($"call {StructFuncName(structType.Module, structType.Name, copyFunc.Name)}(l {value})"); + } + } + + parameterStrings.Add($"{FuncQBETypeName(parameter.Type)} {value}"); } if (structFuncCall.Type is VoidTypeNode) @@ -1093,8 +1129,18 @@ public class QBEGenerator foreach (var parameter in funcCall.Parameters) { - var copy = EmitCopy(parameter); - parameterStrings.Add($"{FuncQBETypeName(parameter.Type)} {copy}"); + var value = EmitExpression(parameter); + + if (parameter.Type is StructTypeNode structType) + { + var copyFunc = structType.Functions.FirstOrDefault(x => x.Hook == "oncopy"); + if (copyFunc != null) + { + _writer.Indented($"call {StructFuncName(structType.Module, structType.Name, copyFunc.Name)}(l {value})"); + } + } + + parameterStrings.Add($"{FuncQBETypeName(parameter.Type)} {value}"); } if (funcCall.Type is VoidTypeNode) @@ -1242,6 +1288,20 @@ public class QBEGenerator } } +// todo(nub31): Parent is not used when getting variables and deferred statements +public class Scope(Scope? parent = null) +{ + public readonly Stack DeferredStatements = []; + public readonly Stack Variables = []; + + public Scope SubScope() + { + return new Scope(this); + } +} + +public record Variable(string Name, TypeNode Type); + public class StringLiteral(string value, string name) { public string Value { get; } = value; diff --git a/compiler/NubLang/Modules/Module.cs b/compiler/NubLang/Modules/Module.cs index bee7c02..c1cbb96 100644 --- a/compiler/NubLang/Modules/Module.cs +++ b/compiler/NubLang/Modules/Module.cs @@ -32,7 +32,7 @@ public record ModuleStructField(string Name, TypeSyntax Type, bool HasDefaultVal public record ModuleStructFunctionParameter(string Name, TypeSyntax Type); -public record ModuleStructFunction(string Name, List Parameters, TypeSyntax ReturnType); +public record ModuleStructFunction(string Name, string? Hook, List Parameters, TypeSyntax ReturnType); public record ModuleStruct(bool Exported, string Name, List Fields, List Functions); diff --git a/compiler/NubLang/Modules/ModuleRepository.cs b/compiler/NubLang/Modules/ModuleRepository.cs index 4a17926..c4a209c 100644 --- a/compiler/NubLang/Modules/ModuleRepository.cs +++ b/compiler/NubLang/Modules/ModuleRepository.cs @@ -36,7 +36,7 @@ public class ModuleRepository foreach (var function in structDef.Functions) { var parameters = function.Signature.Parameters.Select(x => new ModuleStructFunctionParameter(x.Name, x.Type)).ToList(); - functions.AddRange(new ModuleStructFunction(function.Name, parameters, function.Signature.ReturnType)); + functions.AddRange(new ModuleStructFunction(function.Name, function.Hook, parameters, function.Signature.ReturnType)); } module.RegisterStruct(structDef.Exported, structDef.Name, fields, functions); diff --git a/compiler/NubLang/Parsing/Parser.cs b/compiler/NubLang/Parsing/Parser.cs index 913ed87..0d9640b 100644 --- a/compiler/NubLang/Parsing/Parser.cs +++ b/compiler/NubLang/Parsing/Parser.cs @@ -175,13 +175,19 @@ public sealed class Parser { var memberStartIndex = _tokenIndex; + string? hook = null; + if (TryExpectSymbol(Symbol.At)) + { + hook = ExpectIdentifier().Value; + } + if (TryExpectSymbol(Symbol.Func)) { var funcName = ExpectIdentifier().Value; var funcSignature = ParseFuncSignature(); var funcBody = ParseBlock(); - funcs.Add(new StructFuncSyntax(GetTokens(memberStartIndex), funcName, funcSignature, funcBody)); + funcs.Add(new StructFuncSyntax(GetTokens(memberStartIndex), funcName, hook, funcSignature, funcBody)); } else { @@ -209,6 +215,8 @@ public sealed class Parser { switch (symbol.Symbol) { + case Symbol.OpenBrace: + return ParseBlock(); case Symbol.Return: return ParseReturn(); case Symbol.If: @@ -219,6 +227,8 @@ public sealed class Parser return ParseFor(); case Symbol.Let: return ParseVariableDeclaration(); + case Symbol.Defer: + return ParseDefer(); case Symbol.Break: return ParseBreak(); case Symbol.Continue: @@ -264,14 +274,22 @@ public sealed class Parser return new VariableDeclarationSyntax(GetTokens(startIndex), name, explicitType, assignment); } - private StatementSyntax ParseBreak() + private DeferSyntax ParseDefer() + { + var startIndex = _tokenIndex; + ExpectSymbol(Symbol.Defer); + var statement = ParseStatement(); + return new DeferSyntax(GetTokens(startIndex), statement); + } + + private BreakSyntax ParseBreak() { var startIndex = _tokenIndex; ExpectSymbol(Symbol.Break); return new BreakSyntax(GetTokens(startIndex)); } - private StatementSyntax ParseContinue() + private ContinueSyntax ParseContinue() { var startIndex = _tokenIndex; ExpectSymbol(Symbol.Continue); @@ -303,9 +321,14 @@ public sealed class Parser var elseStatement = Optional>.Empty(); if (TryExpectSymbol(Symbol.Else)) { - elseStatement = TryExpectSymbol(Symbol.If) - ? (Variant)ParseIf() - : (Variant)ParseBlock(); + if (CurrentToken is SymbolToken { Symbol: Symbol.If }) + { + elseStatement = (Variant)ParseIf(); + } + else + { + elseStatement = (Variant)ParseBlock(); + } } return new IfSyntax(GetTokens(startIndex), condition, body, elseStatement); diff --git a/compiler/NubLang/Parsing/Syntax/DefinitionSyntax.cs b/compiler/NubLang/Parsing/Syntax/DefinitionSyntax.cs index 0a929f3..1321433 100644 --- a/compiler/NubLang/Parsing/Syntax/DefinitionSyntax.cs +++ b/compiler/NubLang/Parsing/Syntax/DefinitionSyntax.cs @@ -12,6 +12,6 @@ public record FuncSyntax(IEnumerable Tokens, string Name, bool Exported, public record StructFieldSyntax(IEnumerable Tokens, string Name, TypeSyntax Type, Optional Value) : SyntaxNode(Tokens); -public record StructFuncSyntax(IEnumerable Tokens, string Name, FuncSignatureSyntax Signature, BlockSyntax Body) : SyntaxNode(Tokens); +public record StructFuncSyntax(IEnumerable Tokens, string Name, string? Hook, FuncSignatureSyntax Signature, BlockSyntax Body) : SyntaxNode(Tokens); public record StructSyntax(IEnumerable Tokens, string Name, bool Exported, List Fields, List Functions) : DefinitionSyntax(Tokens, Name, Exported); \ No newline at end of file diff --git a/compiler/NubLang/Parsing/Syntax/StatementSyntax.cs b/compiler/NubLang/Parsing/Syntax/StatementSyntax.cs index a3e421d..a663ab9 100644 --- a/compiler/NubLang/Parsing/Syntax/StatementSyntax.cs +++ b/compiler/NubLang/Parsing/Syntax/StatementSyntax.cs @@ -4,6 +4,8 @@ namespace NubLang.Parsing.Syntax; public abstract record StatementSyntax(IEnumerable Tokens) : SyntaxNode(Tokens); +public record BlockSyntax(IEnumerable Tokens, List Statements) : StatementSyntax(Tokens); + public record StatementExpressionSyntax(IEnumerable Tokens, ExpressionSyntax Expression) : StatementSyntax(Tokens); public record ReturnSyntax(IEnumerable Tokens, Optional Value) : StatementSyntax(Tokens); @@ -18,6 +20,8 @@ public record ContinueSyntax(IEnumerable Tokens) : StatementSyntax(Tokens public record BreakSyntax(IEnumerable Tokens) : StatementSyntax(Tokens); +public record DeferSyntax(IEnumerable Tokens, StatementSyntax Statement) : StatementSyntax(Tokens); + public record WhileSyntax(IEnumerable Tokens, ExpressionSyntax Condition, BlockSyntax Body) : StatementSyntax(Tokens); public record ForSyntax(IEnumerable Tokens, string ElementIdent, string? IndexIdent, ExpressionSyntax Target, BlockSyntax Body) : StatementSyntax(Tokens); \ No newline at end of file diff --git a/compiler/NubLang/Parsing/Syntax/SyntaxNode.cs b/compiler/NubLang/Parsing/Syntax/SyntaxNode.cs index 35d1c1d..65d3210 100644 --- a/compiler/NubLang/Parsing/Syntax/SyntaxNode.cs +++ b/compiler/NubLang/Parsing/Syntax/SyntaxNode.cs @@ -6,6 +6,4 @@ public abstract record SyntaxNode(IEnumerable Tokens); public record SyntaxTreeMetadata(string ModuleName, List Imports); -public record SyntaxTree(List Definitions, SyntaxTreeMetadata Metadata); - -public record BlockSyntax(IEnumerable Tokens, List Statements) : SyntaxNode(Tokens); \ No newline at end of file +public record SyntaxTree(List Definitions, SyntaxTreeMetadata Metadata); \ No newline at end of file diff --git a/compiler/NubLang/Tokenization/Token.cs b/compiler/NubLang/Tokenization/Token.cs index 764c528..9ead836 100644 --- a/compiler/NubLang/Tokenization/Token.cs +++ b/compiler/NubLang/Tokenization/Token.cs @@ -80,4 +80,6 @@ public enum Symbol Module, Import, Export, + Defer, + At, } \ No newline at end of file diff --git a/compiler/NubLang/Tokenization/Tokenizer.cs b/compiler/NubLang/Tokenization/Tokenizer.cs index 7478a5d..934a757 100644 --- a/compiler/NubLang/Tokenization/Tokenizer.cs +++ b/compiler/NubLang/Tokenization/Tokenizer.cs @@ -23,6 +23,7 @@ public sealed class Tokenizer ["module"] = Symbol.Module, ["export"] = Symbol.Export, ["import"] = Symbol.Import, + ["defer"] = Symbol.Defer, }; private static readonly Dictionary Symbols = new() @@ -58,6 +59,7 @@ public sealed class Tokenizer [[';']] = Symbol.Semi, [['%']] = Symbol.Percent, [['|']] = Symbol.Pipe, + [['@']] = Symbol.At, }; private static readonly (char[] Pattern, Symbol Symbol)[] OrderedSymbols = Symbols diff --git a/compiler/NubLang/TypeChecking/Node/ExpressionNode.cs b/compiler/NubLang/TypeChecking/Node/ExpressionNode.cs index 1aef4ed..df4824f 100644 --- a/compiler/NubLang/TypeChecking/Node/ExpressionNode.cs +++ b/compiler/NubLang/TypeChecking/Node/ExpressionNode.cs @@ -72,7 +72,7 @@ public record StructFieldAccessNode(TypeNode Type, StructTypeNode StructType, Ex public record StructInitializerNode(StructTypeNode StructType, Dictionary Initializers) : RValueExpressionNode(StructType); -public record DereferenceNode(TypeNode Type, ExpressionNode Expression) : RValueExpressionNode(Type); +public record DereferenceNode(TypeNode Type, ExpressionNode Expression) : LValueExpressionNode(Type); public record ConvertIntNode(TypeNode Type, ExpressionNode Value, IntTypeNode ValueType, IntTypeNode TargetType) : RValueExpressionNode(Type); diff --git a/compiler/NubLang/TypeChecking/Node/Node.cs b/compiler/NubLang/TypeChecking/Node/Node.cs index 2ee1062..d64afd5 100644 --- a/compiler/NubLang/TypeChecking/Node/Node.cs +++ b/compiler/NubLang/TypeChecking/Node/Node.cs @@ -1,5 +1,3 @@ namespace NubLang.TypeChecking.Node; -public abstract record Node; - -public record BlockNode(List Statements) : Node; \ No newline at end of file +public abstract record Node; \ No newline at end of file diff --git a/compiler/NubLang/TypeChecking/Node/StatementNode.cs b/compiler/NubLang/TypeChecking/Node/StatementNode.cs index 0e41a20..451e92c 100644 --- a/compiler/NubLang/TypeChecking/Node/StatementNode.cs +++ b/compiler/NubLang/TypeChecking/Node/StatementNode.cs @@ -2,9 +2,13 @@ public abstract record StatementNode : Node; +public abstract record TerminalStatementNode : StatementNode; + +public record BlockNode(List Statements) : StatementNode; + public record StatementExpressionNode(ExpressionNode Expression) : StatementNode; -public record ReturnNode(Optional Value) : StatementNode; +public record ReturnNode(Optional Value) : TerminalStatementNode; public record AssignmentNode(LValueExpressionNode Target, ExpressionNode Value) : StatementNode; @@ -12,9 +16,11 @@ public record IfNode(ExpressionNode Condition, BlockNode Body, Optional Assignment, TypeNode Type) : StatementNode; -public record ContinueNode : StatementNode; +public record ContinueNode : TerminalStatementNode; -public record BreakNode : StatementNode; +public record BreakNode : TerminalStatementNode; + +public record DeferNode(StatementNode Statement) : StatementNode; public record WhileNode(ExpressionNode Condition, BlockNode Body) : StatementNode; diff --git a/compiler/NubLang/TypeChecking/Node/TypeNode.cs b/compiler/NubLang/TypeChecking/Node/TypeNode.cs index 47bfdf3..032ce7b 100644 --- a/compiler/NubLang/TypeChecking/Node/TypeNode.cs +++ b/compiler/NubLang/TypeChecking/Node/TypeNode.cs @@ -119,9 +119,10 @@ public class StructTypeField(string name, TypeNode type, bool hasDefaultValue) public bool HasDefaultValue { get; } = hasDefaultValue; } -public class StructTypeFunc(string name, FuncTypeNode type) +public class StructTypeFunc(string name, string? hook, FuncTypeNode type) { public string Name { get; } = name; + public string? Hook { get; set; } = hook; public FuncTypeNode Type { get; } = type; } diff --git a/compiler/NubLang/TypeChecking/TypeChecker.cs b/compiler/NubLang/TypeChecking/TypeChecker.cs index 5c04497..f4baab4 100644 --- a/compiler/NubLang/TypeChecking/TypeChecker.cs +++ b/compiler/NubLang/TypeChecking/TypeChecker.cs @@ -73,7 +73,7 @@ public sealed class TypeChecker { var parameters = function.Signature.Parameters.Select(x => ResolveType(x.Type)).ToList(); var funcType = new FuncTypeNode(parameters, ResolveType(function.Signature.ReturnType)); - functionTypes.Add(new StructTypeFunc(function.Name, funcType)); + functionTypes.Add(new StructTypeFunc(function.Name, function.Hook, funcType)); } var type = new StructTypeNode(_syntaxTree.Metadata.ModuleName, node.Name, fieldTypes, functionTypes); @@ -127,10 +127,19 @@ public sealed class TypeChecker body = CheckBlock(node.Body, scope); - // Insert implicit return for void functions - if (signature.ReturnType is VoidTypeNode && body.Statements.LastOrDefault() is not ReturnNode) + if (!AlwaysReturns(body)) { - body.Statements.Add(new ReturnNode(Optional.Empty())); + if (signature.ReturnType is VoidTypeNode) + { + body.Statements.Add(new ReturnNode(Optional.Empty())); + } + else + { + Diagnostics.Add(Diagnostic + .Error("Not all code paths return a value") + .At(node.Body.Tokens.LastOrDefault()) + .Build()); + } } _funcReturnTypes.Pop(); @@ -150,6 +159,7 @@ public sealed class TypeChecker ReturnSyntax statement => CheckReturn(statement), StatementExpressionSyntax statement => CheckStatementExpression(statement), VariableDeclarationSyntax statement => CheckVariableDeclaration(statement), + DeferSyntax statement => CheckDefer(statement), WhileSyntax statement => CheckWhile(statement), ForSyntax statement => CheckFor(statement), _ => throw new ArgumentOutOfRangeException(nameof(node)) @@ -224,6 +234,11 @@ public sealed class TypeChecker return new VariableDeclarationNode(statement.Name, Optional.OfNullable(assignmentNode), type); } + private DeferNode CheckDefer(DeferSyntax statement) + { + return new DeferNode(CheckStatement(statement.Statement)); + } + private WhileNode CheckWhile(WhileSyntax statement) { var condition = CheckExpression(statement.Condition, new BoolTypeNode()); @@ -734,8 +749,8 @@ public sealed class TypeChecker if (missingFields.Length != 0) { - throw new TypeCheckerException(Diagnostic - .Error($"Fields {string.Join(", ", missingFields)} are not initialized") + Diagnostics.Add(Diagnostic + .Warning($"Fields {string.Join(", ", missingFields)} are not initialized") .At(expression) .Build()); } @@ -759,7 +774,7 @@ public sealed class TypeChecker if (reachable) { statements.Add(checkedStatement); - if (checkedStatement is ReturnNode or BreakNode or ContinueNode) + if (checkedStatement is TerminalStatementNode) { reachable = false; } @@ -779,6 +794,17 @@ public sealed class TypeChecker return new BlockNode(statements); } + private bool AlwaysReturns(StatementNode statement) + { + return statement switch + { + ReturnNode => true, + BlockNode block => block.Statements.Count != 0 && AlwaysReturns(block.Statements.Last()), + IfNode ifNode => AlwaysReturns(ifNode.Body) && ifNode.Else.TryGetValue(out var elseStatement) ? elseStatement.Match(AlwaysReturns, AlwaysReturns) : true, + _ => false + }; + } + private TypeNode ResolveType(TypeSyntax type) { return type switch @@ -838,7 +864,7 @@ public sealed class TypeChecker { var parameters = function.Parameters.Select(x => ResolveType(x.Type)).ToList(); var type = new FuncTypeNode(parameters, ResolveType(function.ReturnType)); - result.Functions.Add(new StructTypeFunc(function.Name, type)); + result.Functions.Add(new StructTypeFunc(function.Name, function.Hook, type)); } ReferencedStructTypes.Add(result); diff --git a/example/src/main.nub b/example/src/main.nub index 89fbe8a..fa30e46 100644 --- a/example/src/main.nub +++ b/example/src/main.nub @@ -1,6 +1,8 @@ module "main" extern "puts" func puts(text: cstring) +extern "malloc" func malloc(size: u64): ^u64 +extern "free" func free(address: ^u64) struct Human { @@ -9,28 +11,49 @@ struct Human extern "main" func main(args: []cstring): i64 { - let x = [2]cstring + let x: ref = {} - x[0] = "test1" - x[1] = "test2" - - for u in x - { - puts(u) - } - - let me: Human = { - name = "test" - } - - puts(me.name) - - test(32) + test(x) return 0 } - -func test(x: u8) +func test(x: ref) { + +} + +struct ref +{ + value: ^u64 + count: ^u64 + + @oncreate + func on_create() + { + puts("on_create") + this.value = malloc(8) + this.count = malloc(8) + this.count^ = 1 + } + + @oncopy + func on_copy() + { + puts("on_copy") + this.count^ = this.count^ + 1 + } + + @ondestroy + func on_destroy() + { + puts("on_destroy") + this.count^ = this.count^ - 1 + if this.count^ <= 0 + { + puts("free") + free(this.value) + free(this.count) + } + } } \ No newline at end of file