diff --git a/compiler/Generator.cs b/compiler/Generator.cs index b6c0515..82a482d 100644 --- a/compiler/Generator.cs +++ b/compiler/Generator.cs @@ -200,6 +200,7 @@ public sealed class Generator(List functions, ModuleGra TypedNodeExpressionMemberAccess expression => EmitExpressionMemberAccess(expression), TypedNodeExpressionLocalIdent expression => expression.Value.Ident, TypedNodeExpressionModuleIdent expression => expression.Value.Ident, + TypedNodeExpressionFuncCall expression => EmitExpressionFuncCall(expression), _ => throw new ArgumentOutOfRangeException(nameof(node), node, null) }; } @@ -263,6 +264,13 @@ public sealed class Generator(List functions, ModuleGra return $"{target}.{expression.Name.Ident}"; } + private string EmitExpressionFuncCall(TypedNodeExpressionFuncCall expression) + { + var name = EmitExpression(expression.Target); + var parameterValues = expression.Parameters.Select(EmitExpression).ToList(); + return $"{name}({string.Join(", ", parameterValues)})"; + } + private string CType(NubType node, string? varName = null) { return node switch diff --git a/compiler/Parser.cs b/compiler/Parser.cs index 45bdb07..1d9fe6d 100644 --- a/compiler/Parser.cs +++ b/compiler/Parser.cs @@ -135,23 +135,13 @@ public sealed class Parser(string fileName, List tokens) var target = ParseExpression(); - if (TryExpectSymbol(Symbol.OpenParen)) - { - var parameters = new List(); - - while (!TryExpectSymbol(Symbol.CloseParen)) - parameters.Add(ParseExpression()); - - return new NodeStatementFuncCall(TokensFrom(startIndex), target, parameters); - } - if (TryExpectSymbol(Symbol.Equal)) { var value = ParseExpression(); return new NodeStatementAssignment(TokensFrom(startIndex), target, value); } - throw new CompileException(Diagnostic.Error("Cannot use expression in statement context unless called as a function or used in assignment").At(fileName, target).Build()); + return new NodeStatementExpression(TokensFrom(startIndex), target); } private NodeExpression ParseExpression(int minPrecedence = -1) @@ -273,10 +263,26 @@ public sealed class Parser(string fileName, List tokens) throw new CompileException(Diagnostic.Error("Expected start of expression").At(fileName, Peek()).Build()); } - if (TryExpectSymbol(Symbol.Period)) + while (true) { - var name = ExpectIdent(); - expr = new NodeExpressionMemberAccess(TokensFrom(startIndex), expr, name); + if (TryExpectSymbol(Symbol.Period)) + { + var name = ExpectIdent(); + expr = new NodeExpressionMemberAccess(TokensFrom(startIndex), expr, name); + } + else if (TryExpectSymbol(Symbol.OpenParen)) + { + var parameters = new List(); + + while (!TryExpectSymbol(Symbol.CloseParen)) + parameters.Add(ParseExpression()); + + expr = new NodeExpressionFuncCall(TokensFrom(startIndex), expr, parameters); + } + else + { + break; + } } return expr; @@ -574,10 +580,9 @@ public sealed class NodeStatementBlock(List tokens, List s public List Statements { get; } = statements; } -public sealed class NodeStatementFuncCall(List tokens, NodeExpression target, List parameters) : NodeStatement(tokens) +public sealed class NodeStatementExpression(List tokens, NodeExpression expression) : NodeStatement(tokens) { - public NodeExpression Target { get; } = target; - public List Parameters { get; } = parameters; + public NodeExpression Expression { get; } = expression; } public sealed class NodeStatementReturn(List tokens, NodeExpression value) : NodeStatement(tokens) @@ -647,6 +652,12 @@ public sealed class NodeExpressionMemberAccess(List tokens, NodeExpressio public TokenIdent Name { get; } = name; } +public sealed class NodeExpressionFuncCall(List tokens, NodeExpression target, List parameters) : NodeExpression(tokens) +{ + public NodeExpression Target { get; } = target; + public List Parameters { get; } = parameters; +} + public sealed class NodeExpressionLocalIdent(List tokens, TokenIdent value) : NodeExpression(tokens) { public TokenIdent Value { get; } = value; diff --git a/compiler/TypeChecker.cs b/compiler/TypeChecker.cs index 0825e93..27f835e 100644 --- a/compiler/TypeChecker.cs +++ b/compiler/TypeChecker.cs @@ -66,7 +66,7 @@ public sealed class TypeChecker(string fileName, NodeDefinitionFunc function, Mo { NodeStatementAssignment statement => CheckStatementAssignment(statement), NodeStatementBlock statement => CheckStatementBlock(statement), - NodeStatementFuncCall statement => CheckStatementFuncCall(statement), + NodeStatementExpression statement => CheckStatementExpression(statement), NodeStatementIf statement => CheckStatementIf(statement), NodeStatementReturn statement => CheckStatementReturn(statement), NodeStatementVariableDeclaration statement => CheckStatementVariableDeclaration(statement), @@ -85,9 +85,12 @@ public sealed class TypeChecker(string fileName, NodeDefinitionFunc function, Mo return new TypedNodeStatementBlock(statement.Tokens, statement.Statements.Select(CheckStatement).ToList()); } - private TypedNodeStatementFuncCall CheckStatementFuncCall(NodeStatementFuncCall statement) + private TypedNodeStatementFuncCall CheckStatementExpression(NodeStatementExpression statement) { - return new TypedNodeStatementFuncCall(statement.Tokens, CheckExpression(statement.Target), statement.Parameters.Select(CheckExpression).ToList()); + if (statement.Expression is not NodeExpressionFuncCall funcCall) + throw new CompileException(Diagnostic.Error("Expected statement or function call").At(fileName, statement).Build()); + + return new TypedNodeStatementFuncCall(statement.Tokens, CheckExpression(funcCall.Target), funcCall.Parameters.Select(CheckExpression).ToList()); } private TypedNodeStatementIf CheckStatementIf(NodeStatementIf statement) @@ -129,6 +132,7 @@ public sealed class TypeChecker(string fileName, NodeDefinitionFunc function, Mo NodeExpressionModuleIdent expression => CheckExpressionModuleIdent(expression), NodeExpressionIntLiteral expression => CheckExpressionIntLiteral(expression), NodeExpressionMemberAccess expression => CheckExpressionMemberAccess(expression), + NodeExpressionFuncCall expression => CheckExpressionFuncCall(expression), NodeExpressionStringLiteral expression => CheckExpressionStringLiteral(expression), NodeExpressionStructLiteral expression => CheckExpressionStructLiteral(expression), _ => throw new ArgumentOutOfRangeException(nameof(node)) @@ -311,6 +315,17 @@ public sealed class TypeChecker(string fileName, NodeDefinitionFunc function, Mo return new TypedNodeExpressionMemberAccess(expression.Tokens, field.Type, target, expression.Name); } + private TypedNodeExpressionFuncCall CheckExpressionFuncCall(NodeExpressionFuncCall expression) + { + var target = CheckExpression(expression.Target); + if (target.Type is not NubTypeFunc funcType) + throw new CompileException(Diagnostic.Error($"Cannot invoke function call on type '{target.Type}'").At(fileName, target).Build()); + + var parameters = expression.Parameters.Select(CheckExpression).ToList(); + + return new TypedNodeExpressionFuncCall(expression.Tokens, funcType.ReturnType, target, parameters); + } + private TypedNodeExpressionStringLiteral CheckExpressionStringLiteral(NodeExpressionStringLiteral expression) { return new TypedNodeExpressionStringLiteral(expression.Tokens, NubTypeString.Instance, expression.Value); @@ -491,6 +506,12 @@ public sealed class TypedNodeExpressionMemberAccess(List tokens, NubType public TokenIdent Name { get; } = name; } +public sealed class TypedNodeExpressionFuncCall(List tokens, NubType type, TypedNodeExpression target, List parameters) : TypedNodeExpression(tokens, type) +{ + public TypedNodeExpression Target { get; } = target; + public List Parameters { get; } = parameters; +} + public sealed class TypedNodeExpressionLocalIdent(List tokens, NubType type, TokenIdent value) : TypedNodeExpression(tokens, type) { public TokenIdent Value { get; } = value; diff --git a/compiler/test.nub b/compiler/test.nub index ab245b3..0f6cd4c 100644 --- a/compiler/test.nub +++ b/compiler/test.nub @@ -21,6 +21,7 @@ func main(): i32 { let me: test::person = struct test::person { age = 21 name = "Oliver" } + x = test::do_something(me.name) test::do_something(me.name) return x } \ No newline at end of file diff --git a/compiler/test2.nub b/compiler/test2.nub index 467364e..62d9828 100644 --- a/compiler/test2.nub +++ b/compiler/test2.nub @@ -5,5 +5,6 @@ struct person { name: string } -func do_something(name: string): void { +func do_something(name: string): i32 { + return 3 } \ No newline at end of file