diff --git a/example/src/main.nub b/example/src/main.nub index 173fc0f..eb05ccd 100644 --- a/example/src/main.nub +++ b/example/src/main.nub @@ -3,6 +3,7 @@ extern func puts(text: cstring) interface Test { func print() + func test() } struct Human : Test @@ -13,6 +14,11 @@ struct Human : Test { puts(this.name) } + + func test() + { + puts("test") + } } func main(args: []cstring): i64 @@ -22,6 +28,7 @@ func main(args: []cstring): i64 } human.print() + human.test() return 0 } diff --git a/src/compiler/NubLang/Generation/QBE/QBEGenerator.cs b/src/compiler/NubLang/Generation/QBE/QBEGenerator.cs index f353be2..cdd9024 100644 --- a/src/compiler/NubLang/Generation/QBE/QBEGenerator.cs +++ b/src/compiler/NubLang/Generation/QBE/QBEGenerator.cs @@ -634,7 +634,6 @@ public class QBEGenerator DereferenceNode dereference => EmitDereference(dereference), BinaryExpressionNode binaryExpression => EmitBinaryExpression(binaryExpression), FuncCallNode funcCallExpression => EmitFuncCall(funcCallExpression), - InterfaceFuncAccessNode interfaceFuncAccess => EmitInterfaceFuncAccess(interfaceFuncAccess), InterfaceFuncCallNode interfaceFuncCall => EmitInterfaceFuncCall(interfaceFuncCall), InterfaceInitializerNode interfaceInitializer => EmitInterfaceInitializer(interfaceInitializer), ExternFuncIdentNode externFuncIdent => EmitExternFuncIdent(externFuncIdent), @@ -643,7 +642,6 @@ public class QBEGenerator LiteralNode literal => EmitLiteral(literal), UnaryExpressionNode unaryExpression => EmitUnaryExpression(unaryExpression), StructFieldAccessNode structFieldAccess => EmitStructFieldAccess(structFieldAccess), - StructFuncAccessNode structFuncAccess => EmitStructFuncAccess(structFuncAccess), StructFuncCallNode structFuncCall => EmitStructFuncCall(structFuncCall), ArrayIndexAccessNode arrayIndex => EmitArrayIndexAccess(arrayIndex), _ => throw new ArgumentOutOfRangeException(nameof(expression)) @@ -1030,17 +1028,11 @@ public class QBEGenerator return new Val(output, structFieldAccess.Type, ValKind.Pointer); } - private Val EmitStructFuncAccess(StructFuncAccessNode structFuncAccess) - { - var structDef = _definitionTable.LookupStruct(structFuncAccess.StructType.Name); - var func = StructFuncName(structDef.Name, structFuncAccess.Func); - - return new Val(func, structFuncAccess.Type, ValKind.Direct); - } - private Val EmitStructFuncCall(StructFuncCallNode structFuncCall) { - var expression = EmitExpression(structFuncCall.Expression); + var structDef = _definitionTable.LookupStruct(structFuncCall.StructType.Name); + var func = StructFuncName(structDef.Name, structFuncCall.Name); + var thisParameter = EmitUnwrap(EmitExpression(structFuncCall.StructExpression)); List parameterStrings = [$"l {thisParameter}"]; @@ -1051,26 +1043,25 @@ public class QBEGenerator parameterStrings.Add($"{FuncQBETypeName(parameter.Type)} {copy}"); } - var funcPointer = EmitUnwrap(expression); if (structFuncCall.Type is VoidTypeNode) { - _writer.Indented($"call {funcPointer}({string.Join(", ", parameterStrings)})"); + _writer.Indented($"call {func}({string.Join(", ", parameterStrings)})"); return new Val(string.Empty, structFuncCall.Type, ValKind.Direct); } else { var outputName = TmpName(); - _writer.Indented($"{outputName} {QBEAssign(structFuncCall.Type)} call {funcPointer}({string.Join(", ", parameterStrings)})"); + _writer.Indented($"{outputName} {QBEAssign(structFuncCall.Type)} call {func}({string.Join(", ", parameterStrings)})"); return new Val(outputName, structFuncCall.Type, ValKind.Direct); } } - private Val EmitInterfaceFuncAccess(InterfaceFuncAccessNode interfaceFuncAccess) + private Val EmitInterfaceFuncCall(InterfaceFuncCallNode interfaceFuncCall) { - var target = EmitUnwrap(EmitExpression(interfaceFuncAccess.Target)); + var target = EmitUnwrap(EmitExpression(interfaceFuncCall.InterfaceExpression)); - var interfaceDef = _definitionTable.LookupInterface(interfaceFuncAccess.InterfaceType.Name); - var functionIndex = interfaceDef.Functions.ToList().FindIndex(x => x.Name == interfaceFuncAccess.FuncName); + var interfaceDef = _definitionTable.LookupInterface(interfaceFuncCall.InterfaceType.Name); + var functionIndex = interfaceDef.Functions.ToList().FindIndex(x => x.Name == interfaceFuncCall.Name); var offset = functionIndex * 8; var vtable = TmpName(); @@ -1082,18 +1073,10 @@ public class QBEGenerator var func = TmpName(); _writer.Indented($"{func} =l loadl {funcOffset}"); - return new Val(func, interfaceFuncAccess.Type, ValKind.Direct); - } + _writer.Indented($"{target} =l add {target}, 8"); + _writer.Indented($"{target} =l loadl {target}"); - private Val EmitInterfaceFuncCall(InterfaceFuncCallNode interfaceFuncCall) - { - var expression = EmitExpression(interfaceFuncCall.Expression); - - var thisParameter = EmitUnwrap(EmitExpression(interfaceFuncCall.InterfaceExpression)); - _writer.Indented($"{thisParameter} =l add {thisParameter}, 8"); - _writer.Indented($"{thisParameter} =l loadl {thisParameter}"); - - List parameterStrings = [$"l {thisParameter}"]; + List parameterStrings = [$"l {target}"]; foreach (var parameter in interfaceFuncCall.Parameters) { @@ -1101,16 +1084,15 @@ public class QBEGenerator parameterStrings.Add($"{FuncQBETypeName(parameter.Type)} {copy}"); } - var funcPointer = EmitUnwrap(expression); if (interfaceFuncCall.Type is VoidTypeNode) { - _writer.Indented($"call {funcPointer}({string.Join(", ", parameterStrings)})"); + _writer.Indented($"call {func}({string.Join(", ", parameterStrings)})"); return new Val(string.Empty, interfaceFuncCall.Type, ValKind.Direct); } else { var outputName = TmpName(); - _writer.Indented($"{outputName} {QBEAssign(interfaceFuncCall.Type)} call {funcPointer}({string.Join(", ", parameterStrings)})"); + _writer.Indented($"{outputName} {QBEAssign(interfaceFuncCall.Type)} call {func}({string.Join(", ", parameterStrings)})"); return new Val(outputName, interfaceFuncCall.Type, ValKind.Direct); } } diff --git a/src/compiler/NubLang/Parsing/Parser.cs b/src/compiler/NubLang/Parsing/Parser.cs index feebccf..ea09d34 100644 --- a/src/compiler/NubLang/Parsing/Parser.cs +++ b/src/compiler/NubLang/Parsing/Parser.cs @@ -487,8 +487,29 @@ public sealed class Parser if (TryExpectSymbol(Symbol.Period)) { - var structMember = ExpectIdentifier().Value; - expr = new MemberAccessSyntax(GetTokens(startIndex), expr, structMember); + var member = ExpectIdentifier().Value; + if (TryExpectSymbol(Symbol.OpenParen)) + { + var parameters = new List(); + while (!TryExpectSymbol(Symbol.CloseParen)) + { + var parameter = ParseExpression(); + parameters.Add(parameter); + if (!TryExpectSymbol(Symbol.Comma) && CurrentToken is not SymbolToken { Symbol: Symbol.CloseParen }) + { + _diagnostics.Add(Diagnostic + .Warning("Missing comma between function arguments") + .WithHelp("Add a ',' to separate arguments") + .At(CurrentToken) + .Build()); + } + } + + expr = new DotFuncCallSyntax(GetTokens(startIndex), member, expr, parameters); + continue; + } + + expr = new StructFieldAccessSyntax(GetTokens(startIndex), expr, member); continue; } diff --git a/src/compiler/NubLang/Parsing/Syntax/ExpressionSyntax.cs b/src/compiler/NubLang/Parsing/Syntax/ExpressionSyntax.cs index 2bec916..476197a 100644 --- a/src/compiler/NubLang/Parsing/Syntax/ExpressionSyntax.cs +++ b/src/compiler/NubLang/Parsing/Syntax/ExpressionSyntax.cs @@ -30,6 +30,8 @@ public record UnaryExpressionSyntax(IEnumerable Tokens, UnaryOperatorSynt public record FuncCallSyntax(IEnumerable Tokens, ExpressionSyntax Expression, IReadOnlyList Parameters) : ExpressionSyntax(Tokens); +public record DotFuncCallSyntax(IEnumerable Tokens, string Name, ExpressionSyntax ThisParameter, IReadOnlyList Parameters) : ExpressionSyntax(Tokens); + public record IdentifierSyntax(IEnumerable Tokens, string Name) : ExpressionSyntax(Tokens); public record ArrayInitializerSyntax(IEnumerable Tokens, ExpressionSyntax Capacity, TypeSyntax ElementType) : ExpressionSyntax(Tokens); @@ -40,7 +42,7 @@ public record AddressOfSyntax(IEnumerable Tokens, ExpressionSyntax Expres public record LiteralSyntax(IEnumerable Tokens, string Value, LiteralKind Kind) : ExpressionSyntax(Tokens); -public record MemberAccessSyntax(IEnumerable Tokens, ExpressionSyntax Target, string Member) : ExpressionSyntax(Tokens); +public record StructFieldAccessSyntax(IEnumerable Tokens, ExpressionSyntax Target, string Member) : ExpressionSyntax(Tokens); public record StructInitializerSyntax(IEnumerable Tokens, TypeSyntax StructType, Dictionary Initializers) : ExpressionSyntax(Tokens); diff --git a/src/compiler/NubLang/TypeChecking/Node/ExpressionNode.cs b/src/compiler/NubLang/TypeChecking/Node/ExpressionNode.cs index 9fc17fe..f72bceb 100644 --- a/src/compiler/NubLang/TypeChecking/Node/ExpressionNode.cs +++ b/src/compiler/NubLang/TypeChecking/Node/ExpressionNode.cs @@ -30,9 +30,9 @@ public record UnaryExpressionNode(TypeNode Type, UnaryOperator Operator, Express public record FuncCallNode(TypeNode Type, ExpressionNode Expression, IReadOnlyList Parameters) : ExpressionNode(Type); -public record StructFuncCallNode(TypeNode Type, ExpressionNode Expression, ExpressionNode StructExpression, IReadOnlyList Parameters) : ExpressionNode(Type); +public record StructFuncCallNode(TypeNode Type, string Name, StructTypeNode StructType, ExpressionNode StructExpression, IReadOnlyList Parameters) : ExpressionNode(Type); -public record InterfaceFuncCallNode(TypeNode Type, ExpressionNode Expression, ExpressionNode InterfaceExpression, IReadOnlyList Parameters) : ExpressionNode(Type); +public record InterfaceFuncCallNode(TypeNode Type, string Name, InterfaceTypeNode InterfaceType, ExpressionNode InterfaceExpression, IReadOnlyList Parameters) : ExpressionNode(Type); public record VariableIdentNode(TypeNode Type, string Name) : ExpressionNode(Type); @@ -50,12 +50,8 @@ public record LiteralNode(TypeNode Type, string Value, LiteralKind Kind) : Expre public record StructFieldAccessNode(TypeNode Type, StructTypeNode StructType, ExpressionNode Target, string Field) : ExpressionNode(Type); -public record StructFuncAccessNode(TypeNode Type, StructTypeNode StructType, ExpressionNode Target, string Func) : ExpressionNode(Type); - -public record InterfaceFuncAccessNode(TypeNode Type, InterfaceTypeNode InterfaceType, ExpressionNode Target, string FuncName) : ExpressionNode(Type); - public record StructInitializerNode(StructTypeNode StructType, Dictionary Initializers) : ExpressionNode(StructType); public record DereferenceNode(TypeNode Type, ExpressionNode Expression) : ExpressionNode(Type); -public record InterfaceInitializerNode(TypeNode Type, InterfaceTypeNode InterfaceType, StructTypeNode StructType, ExpressionNode Implementation) : ExpressionNode(Type); +public record InterfaceInitializerNode(TypeNode Type, InterfaceTypeNode InterfaceType, StructTypeNode StructType, ExpressionNode Implementation) : ExpressionNode(Type); \ No newline at end of file diff --git a/src/compiler/NubLang/TypeChecking/TypeChecker.cs b/src/compiler/NubLang/TypeChecking/TypeChecker.cs index 4051104..e033977 100644 --- a/src/compiler/NubLang/TypeChecking/TypeChecker.cs +++ b/src/compiler/NubLang/TypeChecking/TypeChecker.cs @@ -271,10 +271,11 @@ public sealed class TypeChecker ArrayInitializerSyntax expression => CheckArrayInitializer(expression), BinaryExpressionSyntax expression => CheckBinaryExpression(expression), DereferenceSyntax expression => CheckDereference(expression), + DotFuncCallSyntax expression => CheckDotFuncCall(expression), FuncCallSyntax expression => CheckFuncCall(expression), IdentifierSyntax expression => CheckIdentifier(expression), LiteralSyntax expression => CheckLiteral(expression, expectedType), - MemberAccessSyntax expression => CheckMemberAccess(expression), + StructFieldAccessSyntax expression => CheckStructFieldAccess(expression), StructInitializerSyntax expression => CheckStructInitializer(expression), UnaryExpressionSyntax expression => CheckUnaryExpression(expression), _ => throw new ArgumentOutOfRangeException(nameof(node)) @@ -336,7 +337,7 @@ public sealed class TypeChecker return new DereferenceNode(dereferencedType, boundExpression); } - private ExpressionNode CheckFuncCall(FuncCallSyntax expression) + private FuncCallNode CheckFuncCall(FuncCallSyntax expression) { var boundExpression = CheckExpression(expression.Expression); @@ -359,19 +360,78 @@ public sealed class TypeChecker parameters.Add(CheckExpression(parameter, expectedType)); } - if (boundExpression is StructFuncAccessNode structFuncAccess) - { - return new StructFuncCallNode(funcType.ReturnType, structFuncAccess, structFuncAccess.Target, parameters); - } - - if (boundExpression is InterfaceFuncAccessNode interfaceFuncAccess) - { - return new InterfaceFuncCallNode(funcType.ReturnType, interfaceFuncAccess, interfaceFuncAccess.Target, parameters); - } - return new FuncCallNode(funcType.ReturnType, boundExpression, parameters); } + private ExpressionNode CheckDotFuncCall(DotFuncCallSyntax expression) + { + var thisParameter = CheckExpression(expression.ThisParameter); + + if (thisParameter.Type is InterfaceTypeNode interfaceType) + { + var interfaceDefinitions = _definitionTable.LookupInterface(interfaceType.Name).ToArray(); + if (interfaceDefinitions.Length == 0) + { + throw new TypeCheckerException(Diagnostic.Error($"Interface {interfaceType.Name} is not defined").Build()); + } + + if (interfaceDefinitions.Length > 1) + { + throw new TypeCheckerException(Diagnostic.Error($"Interface {interfaceType.Name} has multiple definitions").Build()); + } + + var function = interfaceDefinitions[0].Functions.FirstOrDefault(x => x.Name == expression.Name); + if (function == null) + { + throw new TypeCheckerException(Diagnostic.Error($"Interface {interfaceType.Name} does not have a function with the name {expression.Name}").Build()); + } + + var parameters = new List(); + for (var i = 0; i < expression.Parameters.Count; i++) + { + var parameter = expression.Parameters[i]; + var expectedType = CheckType(function.Signature.Parameters[i].Type); + parameters.Add(CheckExpression(parameter, expectedType)); + } + + var returnType = CheckType(function.Signature.ReturnType); + return new InterfaceFuncCallNode(returnType, expression.Name, interfaceType, thisParameter, parameters); + } + + if (thisParameter.Type is StructTypeNode structType) + { + var structDefinitions = _definitionTable.LookupStruct(structType.Name).ToArray(); + if (structDefinitions.Length == 0) + { + throw new TypeCheckerException(Diagnostic.Error($"Struct {structType.Name} is not defined").Build()); + } + + if (structDefinitions.Length > 1) + { + throw new TypeCheckerException(Diagnostic.Error($"Struct {structType.Name} has multiple definitions").Build()); + } + + var function = structDefinitions[0].Functions.FirstOrDefault(x => x.Name == expression.Name); + if (function == null) + { + throw new TypeCheckerException(Diagnostic.Error($"Struct {structType.Name} does not have a function with the name {expression.Name}").Build()); + } + + var parameters = new List(); + for (var i = 0; i < expression.Parameters.Count; i++) + { + var parameter = expression.Parameters[i]; + var expectedType = CheckType(function.Signature.Parameters[i].Type); + parameters.Add(CheckExpression(parameter, expectedType)); + } + + var returnType = CheckType(function.Signature.ReturnType); + return new StructFuncCallNode(returnType, expression.Name, structType, thisParameter, parameters); + } + + throw new TypeCheckerException(Diagnostic.Error($"Cannot call dot function on type {thisParameter.Type}").Build()); + } + private ExpressionNode CheckIdentifier(IdentifierSyntax expression) { var variable = Scope.Lookup(expression.Name); @@ -429,79 +489,34 @@ public sealed class TypeChecker return new LiteralNode(type, expression.Value, expression.Kind); } - private ExpressionNode CheckMemberAccess(MemberAccessSyntax expression) + private StructFieldAccessNode CheckStructFieldAccess(StructFieldAccessSyntax expression) { var boundExpression = CheckExpression(expression.Target); - if (boundExpression.Type is InterfaceTypeNode interfaceType) + if (boundExpression.Type is not StructTypeNode structType) { - var interfaces = _definitionTable.LookupInterface(interfaceType.Name).ToArray(); - if (interfaces.Length > 0) - { - if (interfaces.Length > 1) - { - throw new TypeCheckerException(Diagnostic.Error($"Interface {interfaceType} has multiple definitions").Build()); - } - - var @interface = interfaces[0]; - - var interfaceFuncs = _definitionTable.LookupInterfaceFunc(@interface, expression.Member).ToArray(); - if (interfaceFuncs.Length > 0) - { - if (interfaceFuncs.Length > 1) - { - throw new TypeCheckerException(Diagnostic.Error($"Interface {interfaceType} has multiple functions with the name {expression.Member}").Build()); - } - - var interfaceFunc = interfaceFuncs[0]; - - var returnType = CheckType(interfaceFunc.Signature.ReturnType); - var parameterTypes = interfaceFunc.Signature.Parameters.Select(p => CheckType(p.Type)).ToList(); - var type = new FuncTypeNode(parameterTypes, returnType); - return new InterfaceFuncAccessNode(type, interfaceType, boundExpression, expression.Member); - } - } + throw new Exception($"Cannot access struct field on non-struct type {boundExpression.Type}"); } - if (boundExpression.Type is StructTypeNode structType) + var structs = _definitionTable.LookupStruct(structType.Name).ToArray(); + if (structs.Length > 0) { - var structs = _definitionTable.LookupStruct(structType.Name).ToArray(); - if (structs.Length > 0) + if (structs.Length > 1) { - if (structs.Length > 1) + throw new TypeCheckerException(Diagnostic.Error($"Struct {structType} has multiple definitions").Build()); + } + + var fields = _definitionTable.LookupStructField(structs[0], expression.Member).ToArray(); + if (fields.Length > 0) + { + if (fields.Length > 1) { - throw new TypeCheckerException(Diagnostic.Error($"Struct {structType} has multiple definitions").Build()); + throw new TypeCheckerException(Diagnostic.Error($"Struct {structType} has multiple fields with the name {expression.Member}").Build()); } - var @struct = structs[0]; + var field = fields[0]; - var fields = _definitionTable.LookupStructField(@struct, expression.Member).ToArray(); - if (fields.Length > 0) - { - if (fields.Length > 1) - { - throw new TypeCheckerException(Diagnostic.Error($"Struct {structType} has multiple fields with the name {expression.Member}").Build()); - } - - var field = fields[0]; - - return new StructFieldAccessNode(CheckType(field.Type), structType, boundExpression, expression.Member); - } - - var funcs = _definitionTable.LookupStructFunc(@struct, expression.Member).ToArray(); - if (funcs.Length > 0) - { - if (funcs.Length > 1) - { - throw new TypeCheckerException(Diagnostic.Error($"Struct {structType} has multiple functions with the name {expression.Member}").Build()); - } - - var func = funcs[0]; - - var parameters = func.Signature.Parameters.Select(x => CheckType(x.Type)).ToList(); - var returnType = CheckType(func.Signature.ReturnType); - return new StructFuncAccessNode(new FuncTypeNode(parameters, returnType), structType, boundExpression, expression.Member); - } + return new StructFieldAccessNode(CheckType(field.Type), structType, boundExpression, expression.Member); } }