diff --git a/example/src/main.nub b/example/src/main.nub index eadfd8a..2f0975f 100644 --- a/example/src/main.nub +++ b/example/src/main.nub @@ -1,12 +1,12 @@ -extern func puts(fmt: cstring) +extern func puts(text: cstring) struct Human { name: cstring - func str(): cstring + func print() { - return this.name + puts(this.name) } } @@ -16,7 +16,7 @@ func main(args: []cstring): i64 name = "oliver" } - puts(human.str()) + human.print() return 0 } diff --git a/src/compiler/NubLang/Generation/QBE/QBEGenerator.Expression.cs b/src/compiler/NubLang/Generation/QBE/QBEGenerator.Expression.cs index 09d4c94..9c0252c 100644 --- a/src/compiler/NubLang/Generation/QBE/QBEGenerator.Expression.cs +++ b/src/compiler/NubLang/Generation/QBE/QBEGenerator.Expression.cs @@ -18,6 +18,7 @@ public partial class QBEGenerator BinaryExpressionNode binaryExpression => EmitBinaryExpression(binaryExpression), FuncCallNode funcCallExpression => EmitFuncCall(funcCallExpression), InterfaceFuncAccessNode interfaceFuncAccess => EmitInterfaceFuncAccess(interfaceFuncAccess), + InterfaceFuncCallNode interfaceFuncCall => EmitInterfaceFuncCall(interfaceFuncCall), InterfaceInitializerNode interfaceInitializer => EmitInterfaceInitializer(interfaceInitializer), ExternFuncIdentNode externFuncIdent => EmitExternFuncIdent(externFuncIdent), LocalFuncIdentNode localFuncIdent => EmitLocalFuncIdent(localFuncIdent), @@ -26,6 +27,7 @@ public partial class QBEGenerator UnaryExpressionNode unaryExpression => EmitUnaryExpression(unaryExpression), StructFieldAccessNode structFieldAccess => EmitStructFieldAccess(structFieldAccess), StructFuncAccessNode structFuncAccess => EmitStructFuncAccess(structFuncAccess), + StructFuncCallNode structFuncCall => EmitStructFuncCall(structFuncCall), ArrayIndexAccessNode arrayIndex => EmitArrayIndexAccess(arrayIndex), _ => throw new ArgumentOutOfRangeException(nameof(expression)) }; @@ -419,6 +421,33 @@ public partial class QBEGenerator return new Val(func, structFuncAccess.Type, ValKind.Direct); } + private Val EmitStructFuncCall(StructFuncCallNode structFuncCall) + { + var expression = EmitExpression(structFuncCall.Expression); + var thisParameter = EmitUnwrap(EmitExpression(structFuncCall.ThisParam)); + + List parameterStrings = [$"l {thisParameter}"]; + + foreach (var parameter in structFuncCall.Parameters) + { + var copy = EmitCreateCopyOrInitialize(parameter); + parameterStrings.Add($"{FuncQBETypeName(parameter.Type)} {copy}"); + } + + var funcPointer = EmitUnwrap(expression); + if (structFuncCall.Type is VoidTypeNode) + { + _writer.Indented($"call {funcPointer}({string.Join(", ", parameterStrings)})"); + return new Val(string.Empty, structFuncCall.Type, ValKind.Direct); + } + else + { + var outputName = TmpName(); + _writer.Indented($"{outputName} {QBEAssign(structFuncCall.Type)} call {funcPointer}({string.Join(", ", parameterStrings)})"); + return new Val(outputName, structFuncCall.Type, ValKind.Direct); + } + } + private Val EmitInterfaceFuncAccess(InterfaceFuncAccessNode interfaceFuncAccess) { var target = EmitUnwrap(EmitExpression(interfaceFuncAccess.Target)); @@ -439,6 +468,11 @@ public partial class QBEGenerator return new Val(func, interfaceFuncAccess.Type, ValKind.Direct); } + private Val EmitInterfaceFuncCall(InterfaceFuncCallNode interfaceFuncCall) + { + throw new NotImplementedException(); + } + private Val EmitInterfaceInitializer(InterfaceInitializerNode interfaceInitializer, string? destination = null) { var implementation = EmitUnwrap(EmitExpression(interfaceInitializer.Implementation)); diff --git a/src/compiler/NubLang/Generation/QBE/QBEGenerator.cs b/src/compiler/NubLang/Generation/QBE/QBEGenerator.cs index 6e35a44..983174c 100644 --- a/src/compiler/NubLang/Generation/QBE/QBEGenerator.cs +++ b/src/compiler/NubLang/Generation/QBE/QBEGenerator.cs @@ -53,17 +53,13 @@ public partial class QBEGenerator foreach (var structDef in _syntaxTree.Definitions.OfType()) { - foreach (var func in structDef.Functions) - { - var funcName = StructFuncName(structDef.Name, func.Name); - EmitFuncDefinition(funcName, func.Signature.Parameters, func.Signature.ReturnType, func.Body); - _writer.NewLine(); - } + EmitStructDefinition(structDef); + _writer.NewLine(); } foreach (var funcDef in _syntaxTree.Definitions.OfType()) { - EmitFuncDefinition(LocalFuncName(funcDef), funcDef.Signature.Parameters, funcDef.Signature.ReturnType, funcDef.Body); + EmitLocalFuncDefinition(funcDef); _writer.NewLine(); } @@ -343,37 +339,40 @@ public partial class QBEGenerator return "l"; } - private void EmitFuncDefinition(string name, IReadOnlyList parameters, TypeNode returnType, BlockNode body) + private void EmitLocalFuncDefinition(LocalFuncNode funcDef) { _labelIndex = 0; _tmpIndex = 0; _writer.Write("export function "); - if (returnType is not VoidTypeNode) + if (funcDef.Signature.ReturnType is not VoidTypeNode) { - _writer.Write(FuncQBETypeName(returnType) + ' '); + _writer.Write(FuncQBETypeName(funcDef.Signature.ReturnType) + ' '); } - _writer.Write(name); + _writer.Write(LocalFuncName(funcDef)); - var parameterStrings = parameters.Select(x => FuncQBETypeName(x.Type) + $" %{x.Name}"); + _writer.Write("("); + foreach (var parameter in funcDef.Signature.Parameters) + { + _writer.Write(FuncQBETypeName(parameter.Type) + $" %{parameter.Name}"); + } - _writer.Write($"({string.Join(", ", parameterStrings)})"); - _writer.WriteLine(" {"); + _writer.WriteLine(") {"); _writer.WriteLine("@start"); var scope = new Scope(); - foreach (var parameter in parameters) + foreach (var parameter in funcDef.Signature.Parameters) { scope.Declare(parameter.Name, new Val("%" + parameter.Name, parameter.Type, ValKind.Direct)); } - EmitBlock(body, scope); + EmitBlock(funcDef.Body, scope); // Implicit return for void functions if no explicit return has been set - if (returnType is VoidTypeNode && body.Statements is [.., not ReturnNode]) + if (funcDef.Signature.ReturnType is VoidTypeNode && funcDef.Body.Statements is [.., not ReturnNode]) { _writer.Indented("ret"); } @@ -381,6 +380,57 @@ public partial class QBEGenerator _writer.WriteLine("}"); } + private void EmitStructDefinition(StructNode structDef) + { + for (var i = 0; i < structDef.Functions.Count; i++) + { + var function = structDef.Functions[i]; + _labelIndex = 0; + _tmpIndex = 0; + + _writer.Write("export function "); + + if (function.Signature.ReturnType is not VoidTypeNode) + { + _writer.Write(FuncQBETypeName(function.Signature.ReturnType) + ' '); + } + + _writer.Write(StructFuncName(structDef.Name, function.Name)); + + _writer.Write("(l %this, "); + foreach (var parameter in function.Signature.Parameters) + { + _writer.Write(FuncQBETypeName(parameter.Type) + $" %{parameter.Name}, "); + } + + _writer.WriteLine(") {"); + _writer.WriteLine("@start"); + + var scope = new Scope(); + + scope.Declare("this", new Val("%this", structDef.Type, ValKind.Direct)); + foreach (var parameter in function.Signature.Parameters) + { + scope.Declare(parameter.Name, new Val("%" + parameter.Name, parameter.Type, ValKind.Direct)); + } + + EmitBlock(function.Body, scope); + + // Implicit return for void functions if no explicit return has been set + if (function.Signature.ReturnType is VoidTypeNode && function.Body.Statements is [.., not ReturnNode]) + { + _writer.Indented("ret"); + } + + _writer.WriteLine("}"); + + if (i != structDef.Functions.Count - 1) + { + _writer.NewLine(); + } + } + } + private void EmitStructTypeDefinition(StructNode structDef) { _writer.WriteLine($"type {StructTypeName(structDef.Name)} = {{ "); diff --git a/src/compiler/NubLang/TypeChecking/Node/DefinitionNode.cs b/src/compiler/NubLang/TypeChecking/Node/DefinitionNode.cs index 3a42cd3..ffabf2a 100644 --- a/src/compiler/NubLang/TypeChecking/Node/DefinitionNode.cs +++ b/src/compiler/NubLang/TypeChecking/Node/DefinitionNode.cs @@ -14,7 +14,7 @@ public record StructFieldNode(int Index, string Name, TypeNode Type, Optional Fields, IReadOnlyList Functions, IReadOnlyList InterfaceImplementations) : DefinitionNode; +public record StructNode(string Name, StructTypeNode Type, IReadOnlyList Fields, IReadOnlyList Functions, IReadOnlyList InterfaceImplementations) : DefinitionNode; public record InterfaceFuncNode(string Name, FuncSignatureNode Signature) : Node; diff --git a/src/compiler/NubLang/TypeChecking/Node/ExpressionNode.cs b/src/compiler/NubLang/TypeChecking/Node/ExpressionNode.cs index 43acb5f..6869f7b 100644 --- a/src/compiler/NubLang/TypeChecking/Node/ExpressionNode.cs +++ b/src/compiler/NubLang/TypeChecking/Node/ExpressionNode.cs @@ -30,6 +30,10 @@ public record UnaryExpressionNode(TypeNode Type, UnaryOperator Operator, Express public record FuncCallNode(TypeNode Type, ExpressionNode Expression, IReadOnlyList Parameters) : ExpressionNode(Type); +public record StructFuncCallNode(TypeNode Type, ExpressionNode Expression, ExpressionNode ThisParam, IReadOnlyList Parameters) : ExpressionNode(Type); + +public record InterfaceFuncCallNode(TypeNode Type, ExpressionNode Expression, ExpressionNode ThisParam, IReadOnlyList Parameters) : ExpressionNode(Type); + public record VariableIdentNode(TypeNode Type, string Name) : ExpressionNode(Type); public record LocalFuncIdentNode(TypeNode Type, string Name) : ExpressionNode(Type); diff --git a/src/compiler/NubLang/TypeChecking/TypeChecker.cs b/src/compiler/NubLang/TypeChecking/TypeChecker.cs index 69c2374..4051104 100644 --- a/src/compiler/NubLang/TypeChecking/TypeChecker.cs +++ b/src/compiler/NubLang/TypeChecking/TypeChecker.cs @@ -13,6 +13,7 @@ public sealed class TypeChecker private readonly Stack _scopes = []; private readonly Stack _funcReturnTypes = []; private readonly List _diagnostics = []; + private readonly Dictionary _typeCache = new(); private Scope Scope => _scopes.Peek(); @@ -91,13 +92,19 @@ public sealed class TypeChecker foreach (var func in node.Functions) { - var parameters = new List(); + var scope = new Scope(); + + scope.Declare(new Variable("this", GetStructType(node))); foreach (var parameter in func.Signature.Parameters) { - parameters.Add(new FuncParameterNode(parameter.Name, CheckType(parameter.Type))); + scope.Declare(new Variable(parameter.Name, CheckType(parameter.Type))); } - funcs.Add(new StructFuncNode(func.Name, CheckFuncSignature(func.Signature), CheckFuncBody(func.Body, CheckType(func.Signature.ReturnType), parameters))); + _funcReturnTypes.Push(CheckType(func.Signature.ReturnType)); + var body = CheckBlock(func.Body, scope); + _funcReturnTypes.Pop(); + + funcs.Add(new StructFuncNode(func.Name, CheckFuncSignature(func.Signature), body)); } var interfaceImplementations = new List(); @@ -128,7 +135,7 @@ public sealed class TypeChecker interfaceImplementations.Add(interfaceType); } - return new StructNode(node.Name, structFields, funcs, interfaceImplementations); + return new StructNode(node.Name, GetStructType(node), structFields, funcs, interfaceImplementations); } private ExternFuncNode CheckExternFuncDefinition(ExternFuncSyntax node) @@ -139,7 +146,16 @@ public sealed class TypeChecker private LocalFuncNode CheckLocalFuncDefinition(LocalFuncSyntax node) { var signature = CheckFuncSignature(node.Signature); - var body = CheckFuncBody(node.Body, signature.ReturnType, signature.Parameters); + + var scope = new Scope(); + foreach (var parameter in signature.Parameters) + { + scope.Declare(new Variable(parameter.Name, parameter.Type)); + } + + _funcReturnTypes.Push(signature.ReturnType); + var body = CheckBlock(node.Body, scope); + _funcReturnTypes.Pop(); return new LocalFuncNode(node.Name, signature, body); } @@ -320,11 +336,14 @@ public sealed class TypeChecker return new DereferenceNode(dereferencedType, boundExpression); } - private FuncCallNode CheckFuncCall(FuncCallSyntax expression) + private ExpressionNode CheckFuncCall(FuncCallSyntax expression) { var boundExpression = CheckExpression(expression.Expression); - var funcType = (FuncTypeNode)boundExpression.Type; + if (boundExpression.Type is not FuncTypeNode funcType) + { + throw new TypeCheckerException(Diagnostic.Error($"Cannot call non-function type {boundExpression.Type}").Build()); + } var parameters = new List(); @@ -332,7 +351,7 @@ public sealed class TypeChecker { if (i >= funcType.Parameters.Count) { - throw new NotImplementedException("Diagnostics not implemented"); + _diagnostics.Add(Diagnostic.Error($"Expected {funcType.Parameters.Count} parameters").Build()); } var expectedType = funcType.Parameters[i]; @@ -340,6 +359,16 @@ public sealed class TypeChecker parameters.Add(CheckExpression(parameter, expectedType)); } + if (boundExpression is StructFuncAccessNode structFuncAccess) + { + return new StructFuncCallNode(funcType.ReturnType, structFuncAccess, structFuncAccess.Target, parameters); + } + + if (boundExpression is InterfaceFuncAccessNode interfaceFuncAccess) + { + return new InterfaceFuncCallNode(funcType.ReturnType, interfaceFuncAccess, interfaceFuncAccess.Target, parameters); + } + return new FuncCallNode(funcType.ReturnType, boundExpression, parameters); } @@ -616,21 +645,6 @@ public sealed class TypeChecker return new BlockNode(statements); } - private BlockNode CheckFuncBody(BlockSyntax block, TypeNode returnType, IReadOnlyList parameters) - { - _funcReturnTypes.Push(returnType); - - var scope = new Scope(); - foreach (var parameter in parameters) - { - scope.Declare(new Variable(parameter.Name, parameter.Type)); - } - - var body = CheckBlock(block, scope); - _funcReturnTypes.Pop(); - return body; - } - private TypeNode CheckType(TypeSyntax node) { return node switch @@ -649,15 +663,8 @@ public sealed class TypeChecker }; } - private readonly Dictionary _typeCache = new(); - private TypeNode CheckCustomType(CustomTypeSyntax type) { - if (_typeCache.TryGetValue(type.Name, out var cachedType)) - { - return cachedType; - } - var structs = _definitionTable.LookupStruct(type.Name).ToArray(); if (structs.Length > 0) { @@ -666,34 +673,7 @@ public sealed class TypeChecker throw new TypeCheckerException(Diagnostic.Error($"Struct {type.Name} has multiple definitions").Build()); } - var @struct = structs[0]; - - var result = new StructTypeNode(type.Name, [], [], []); - _typeCache.Add(type.Name, result); - - var fields = @struct.Fields.Select(x => CheckType(x.Type)).ToList(); - - var funcs = @struct.Functions - .Select(x => new FuncTypeNode(x.Signature.Parameters.Select(p => CheckType(p.Type)).ToList(), CheckType(x.Signature.ReturnType))) - .ToList(); - - var interfaceImplementations = new List(); - - foreach (var structInterfaceImplementation in @struct.InterfaceImplementations) - { - var checkedInterfaceType = CheckType(structInterfaceImplementation); - if (checkedInterfaceType is not InterfaceTypeNode interfaceType) - { - throw new TypeCheckerException(Diagnostic.Error($"{type.Name} cannot implement non-interface type {checkedInterfaceType}").Build()); - } - - interfaceImplementations.Add(interfaceType); - } - - result.Fields = fields; - result.Functions = funcs; - result.InterfaceImplementations = interfaceImplementations; - return result; + return GetStructType(structs[0]); } var interfaces = _definitionTable.LookupInterface(type.Name).ToArray(); @@ -704,21 +684,64 @@ public sealed class TypeChecker throw new TypeCheckerException(Diagnostic.Error($"Interface {type.Name} has multiple definitions").Build()); } - var @interface = interfaces[0]; - - var result = new InterfaceTypeNode(type.Name, []); - _typeCache.Add(type.Name, result); - - var functions = @interface.Functions - .Select(x => new FuncTypeNode(x.Signature.Parameters.Select(y => CheckType(y.Type)).ToList(), CheckType(x.Signature.ReturnType))) - .ToList(); - - result.Functions = functions; - return result; + return GetInterfaceType(interfaces[0]); } throw new TypeCheckerException(Diagnostic.Error($"Type {type.Name} is not defined").Build()); } + + private StructTypeNode GetStructType(StructSyntax structDef) + { + if (_typeCache.TryGetValue(structDef.Name, out var cachedType)) + { + return (StructTypeNode)cachedType; + } + + var result = new StructTypeNode(structDef.Name, [], [], []); + _typeCache.Add(structDef.Name, result); + + var fields = structDef.Fields.Select(x => CheckType(x.Type)).ToList(); + + var funcs = structDef.Functions + .Select(x => new FuncTypeNode(x.Signature.Parameters.Select(p => CheckType(p.Type)).ToList(), CheckType(x.Signature.ReturnType))) + .ToList(); + + var interfaceImplementations = new List(); + + foreach (var structInterfaceImplementation in structDef.InterfaceImplementations) + { + var checkedInterfaceType = CheckType(structInterfaceImplementation); + if (checkedInterfaceType is not InterfaceTypeNode interfaceType) + { + throw new TypeCheckerException(Diagnostic.Error($"{structDef.Name} cannot implement non-interface type {checkedInterfaceType}").Build()); + } + + interfaceImplementations.Add(interfaceType); + } + + result.Fields = fields; + result.Functions = funcs; + result.InterfaceImplementations = interfaceImplementations; + return result; + } + + private InterfaceTypeNode GetInterfaceType(InterfaceSyntax interfaceDef) + { + if (_typeCache.TryGetValue(interfaceDef.Name, out var cachedType)) + { + return (InterfaceTypeNode)cachedType; + } + + var result = new InterfaceTypeNode(interfaceDef.Name, []); + _typeCache.Add(interfaceDef.Name, result); + + var functions = interfaceDef.Functions + .Select(x => new FuncTypeNode(x.Signature.Parameters.Select(y => CheckType(y.Type)).ToList(), CheckType(x.Signature.ReturnType))) + .ToList(); + + result.Functions = functions; + return result; + } } public record Variable(string Name, TypeNode Type);