diff --git a/compiler/NubLang/Parsing/Syntax/ExpressionSyntax.cs b/compiler/NubLang/Parsing/Syntax/ExpressionSyntax.cs index cdf35f3..ede9750 100644 --- a/compiler/NubLang/Parsing/Syntax/ExpressionSyntax.cs +++ b/compiler/NubLang/Parsing/Syntax/ExpressionSyntax.cs @@ -38,7 +38,7 @@ public record UnaryExpressionSyntax(IEnumerable Tokens, UnaryOperatorSynt public record FuncCallSyntax(IEnumerable Tokens, ExpressionSyntax Expression, List Parameters) : ExpressionSyntax(Tokens); -public record DotFuncCallSyntax(IEnumerable Tokens, string Name, ExpressionSyntax ThisParameter, List Parameters) : ExpressionSyntax(Tokens); +public record DotFuncCallSyntax(IEnumerable Tokens, string Name, ExpressionSyntax Target, List Parameters) : ExpressionSyntax(Tokens); public record LocalIdentifierSyntax(IEnumerable Tokens, string Name) : ExpressionSyntax(Tokens); @@ -48,7 +48,7 @@ public record ArrayInitializerSyntax(IEnumerable Tokens, ExpressionSyntax public record ArrayIndexAccessSyntax(IEnumerable Tokens, ExpressionSyntax Target, ExpressionSyntax Index) : ExpressionSyntax(Tokens); -public record AddressOfSyntax(IEnumerable Tokens, ExpressionSyntax Expression) : ExpressionSyntax(Tokens); +public record AddressOfSyntax(IEnumerable Tokens, ExpressionSyntax Target) : ExpressionSyntax(Tokens); public record LiteralSyntax(IEnumerable Tokens, string Value, LiteralKind Kind) : ExpressionSyntax(Tokens); @@ -56,4 +56,4 @@ public record StructFieldAccessSyntax(IEnumerable Tokens, ExpressionSynta public record StructInitializerSyntax(IEnumerable Tokens, Optional StructType, Dictionary Initializers) : ExpressionSyntax(Tokens); -public record DereferenceSyntax(IEnumerable Tokens, ExpressionSyntax Expression) : ExpressionSyntax(Tokens); \ No newline at end of file +public record DereferenceSyntax(IEnumerable Tokens, ExpressionSyntax Target) : ExpressionSyntax(Tokens); \ No newline at end of file diff --git a/compiler/NubLang/TypeChecking/TypeChecker.cs b/compiler/NubLang/TypeChecking/TypeChecker.cs index 2fa08dd..31fb0bc 100644 --- a/compiler/NubLang/TypeChecking/TypeChecker.cs +++ b/compiler/NubLang/TypeChecking/TypeChecker.cs @@ -66,6 +66,20 @@ public sealed class TypeChecker private StructNode CheckStructDefinition(StructSyntax node) { + var fieldTypes = node.Fields + .Select(x => new StructTypeField(x.Name, ResolveType(x.Type), x.Value.HasValue)) + .ToList(); + + var functionTypes = new List(); + foreach (var function in node.Functions) + { + var parameters = function.Signature.Parameters.Select(x => ResolveType(x.Type)).ToList(); + var funcType = new FuncTypeNode(parameters, ResolveType(function.Signature.ReturnType)); + functionTypes.Add(new StructTypeFunc(function.Name, funcType)); + } + + var type = new StructTypeNode(_syntaxTree.Metadata.ModuleName, node.Name, fieldTypes, functionTypes); + var fields = new List(); foreach (var field in node.Fields) { @@ -82,7 +96,8 @@ public sealed class TypeChecker foreach (var function in node.Functions) { var scope = new Scope(); - // todo(nub31): Add this parameter + scope.Declare(new Identifier("this", type, IdentifierKind.FunctionParameter)); + foreach (var parameter in function.Signature.Parameters) { scope.Declare(new Identifier(parameter.Name, ResolveType(parameter.Type), IdentifierKind.FunctionParameter)); @@ -94,20 +109,6 @@ public sealed class TypeChecker functions.Add(new StructFuncNode(function.Name, CheckFuncSignature(function.Signature), body)); } - var fieldTypes = node.Fields - .Select(x => new StructTypeField(x.Name, ResolveType(x.Type), x.Value.HasValue)) - .ToList(); - - var functionTypes = new List(); - foreach (var function in node.Functions) - { - var parameters = function.Signature.Parameters.Select(x => ResolveType(x.Type)).ToList(); - var funcType = new FuncTypeNode(parameters, ResolveType(function.Signature.ReturnType)); - functionTypes.Add(new StructTypeFunc(function.Name, funcType)); - } - - var type = new StructTypeNode(_syntaxTree.Metadata.ModuleName, node.Name, fieldTypes, functionTypes); - return new StructNode(type, _syntaxTree.Metadata.ModuleName, node.Name, fields, functions); } @@ -146,23 +147,38 @@ public sealed class TypeChecker }; } - private StatementNode CheckAssignment(AssignmentSyntax statement) + private AssignmentNode CheckAssignment(AssignmentSyntax statement) { - throw new NotImplementedException(); + var target = CheckExpression(statement.Target); + if (target is not LValueExpressionNode lValue) + { + throw new TypeCheckerException(Diagnostic.Error("Cannot assign to an rvalue").At(statement).Build()); + } + + var value = CheckExpression(statement.Value, lValue.Type); + return new AssignmentNode(lValue, value); } private IfNode CheckIf(IfSyntax statement) { - throw new NotImplementedException(); + var condition = CheckExpression(statement.Condition, new BoolTypeNode()); + var body = CheckBlock(statement.Body); + var elseStatement = Optional.Empty>(); + if (statement.Else.TryGetValue(out var elseSyntax)) + { + elseStatement = elseSyntax.Match>(elif => CheckIf(elif), el => CheckBlock(el)); + } + + return new IfNode(condition, body, elseStatement); } private ReturnNode CheckReturn(ReturnSyntax statement) { var value = Optional.Empty(); - if (statement.Value.HasValue) + if (statement.Value.TryGetValue(out var valueExpression)) { - value = CheckExpression(statement.Value.Value, _funcReturnTypes.Peek()); + value = CheckExpression(valueExpression, _funcReturnTypes.Peek()); } return new ReturnNode(value); @@ -201,7 +217,9 @@ public sealed class TypeChecker private WhileNode CheckWhile(WhileSyntax statement) { - throw new NotImplementedException(); + var condition = CheckExpression(statement.Condition, new BoolTypeNode()); + var body = CheckBlock(statement.Body); + return new WhileNode(condition, body); } private FuncSignatureNode CheckFuncSignature(FuncSignatureSyntax statement) @@ -223,6 +241,7 @@ public sealed class TypeChecker ArrayIndexAccessSyntax expression => CheckArrayIndexAccess(expression), ArrayInitializerSyntax expression => CheckArrayInitializer(expression), BinaryExpressionSyntax expression => CheckBinaryExpression(expression), + UnaryExpressionSyntax expression => CheckUnaryExpression(expression), DereferenceSyntax expression => CheckDereference(expression), DotFuncCallSyntax expression => CheckDotFuncCall(expression), FuncCallSyntax expression => CheckFuncCall(expression), @@ -231,7 +250,6 @@ public sealed class TypeChecker LiteralSyntax expression => CheckLiteral(expression, expectedType), StructFieldAccessSyntax expression => CheckStructFieldAccess(expression), StructInitializerSyntax expression => CheckStructInitializer(expression, expectedType), - UnaryExpressionSyntax expression => CheckUnaryExpression(expression), _ => throw new ArgumentOutOfRangeException(nameof(node)) }; @@ -261,27 +279,171 @@ public sealed class TypeChecker private AddressOfNode CheckAddressOf(AddressOfSyntax expression) { - throw new NotImplementedException(); + var target = CheckExpression(expression.Target); + if (target is not LValueExpressionNode lvalue) + { + throw new TypeCheckerException(Diagnostic.Error("Cannot take address of an rvalue").Build()); + } + + var type = new PointerTypeNode(target.Type); + return new AddressOfNode(type, lvalue); } private ArrayIndexAccessNode CheckArrayIndexAccess(ArrayIndexAccessSyntax expression) { - throw new NotImplementedException(); + var index = CheckExpression(expression.Index, new IntTypeNode(false, 64)); + var target = CheckExpression(expression.Target); + if (target.Type is not ArrayTypeNode arrayType) + { + throw new TypeCheckerException(Diagnostic.Error($"Cannot use array indexer on type {target.Type}").At(expression).Build()); + } + + return new ArrayIndexAccessNode(arrayType.ElementType, target, index); } private ArrayInitializerNode CheckArrayInitializer(ArrayInitializerSyntax expression) { - throw new NotImplementedException(); + var elementType = ResolveType(expression.ElementType); + var type = new ArrayTypeNode(elementType); + var capacity = CheckExpression(expression.Capacity); + return new ArrayInitializerNode(type, capacity, elementType); } private BinaryExpressionNode CheckBinaryExpression(BinaryExpressionSyntax expression) { - throw new NotImplementedException(); + var op = expression.Operator switch + { + BinaryOperatorSyntax.Equal => BinaryOperator.Equal, + BinaryOperatorSyntax.NotEqual => BinaryOperator.NotEqual, + BinaryOperatorSyntax.GreaterThan => BinaryOperator.GreaterThan, + BinaryOperatorSyntax.GreaterThanOrEqual => BinaryOperator.GreaterThanOrEqual, + BinaryOperatorSyntax.LessThan => BinaryOperator.LessThan, + BinaryOperatorSyntax.LessThanOrEqual => BinaryOperator.LessThanOrEqual, + BinaryOperatorSyntax.LogicalAnd => BinaryOperator.LogicalAnd, + BinaryOperatorSyntax.LogicalOr => BinaryOperator.LogicalOr, + BinaryOperatorSyntax.Plus => BinaryOperator.Plus, + BinaryOperatorSyntax.Minus => BinaryOperator.Minus, + BinaryOperatorSyntax.Multiply => BinaryOperator.Multiply, + BinaryOperatorSyntax.Divide => BinaryOperator.Divide, + BinaryOperatorSyntax.Modulo => BinaryOperator.Modulo, + BinaryOperatorSyntax.LeftShift => BinaryOperator.LeftShift, + BinaryOperatorSyntax.RightShift => BinaryOperator.RightShift, + BinaryOperatorSyntax.BitwiseAnd => BinaryOperator.BitwiseAnd, + BinaryOperatorSyntax.BitwiseXor => BinaryOperator.BitwiseXor, + BinaryOperatorSyntax.BitwiseOr => BinaryOperator.BitwiseOr, + _ => throw new ArgumentOutOfRangeException() + }; + + switch (expression.Operator) + { + case BinaryOperatorSyntax.Equal: + case BinaryOperatorSyntax.NotEqual: + case BinaryOperatorSyntax.GreaterThan: + case BinaryOperatorSyntax.GreaterThanOrEqual: + case BinaryOperatorSyntax.LessThan: + case BinaryOperatorSyntax.LessThanOrEqual: + case BinaryOperatorSyntax.LogicalAnd: + case BinaryOperatorSyntax.LogicalOr: + { + var left = CheckExpression(expression.Left); + if (left.Type is not IntTypeNode or FloatTypeNode) + { + throw new TypeCheckerException(Diagnostic.Error("Logical operators must must be used with int or float types").At(expression.Left).Build()); + } + + var right = CheckExpression(expression.Right, left.Type); + + return new BinaryExpressionNode(new BoolTypeNode(), left, op, right); + } + case BinaryOperatorSyntax.Plus: + { + var left = CheckExpression(expression.Left); + if (left.Type is IntTypeNode or FloatTypeNode or StringTypeNode or StringTypeNode) + { + var right = CheckExpression(expression.Right, left.Type); + return new BinaryExpressionNode(left.Type, left, op, right); + } + + throw new TypeCheckerException(Diagnostic.Error("The plus operator must be used with int, float or string types").At(expression.Left).Build()); + } + case BinaryOperatorSyntax.Minus: + case BinaryOperatorSyntax.Multiply: + case BinaryOperatorSyntax.Divide: + case BinaryOperatorSyntax.Modulo: + { + var left = CheckExpression(expression.Left); + if (left.Type is not IntTypeNode or FloatTypeNode) + { + throw new TypeCheckerException(Diagnostic.Error("Math operators must be used with int or float types").At(expression.Left).Build()); + } + + var right = CheckExpression(expression.Right, left.Type); + + return new BinaryExpressionNode(left.Type, left, op, right); + } + case BinaryOperatorSyntax.LeftShift: + case BinaryOperatorSyntax.RightShift: + case BinaryOperatorSyntax.BitwiseAnd: + case BinaryOperatorSyntax.BitwiseXor: + case BinaryOperatorSyntax.BitwiseOr: + { + var left = CheckExpression(expression.Left); + if (left.Type is not IntTypeNode) + { + throw new TypeCheckerException(Diagnostic.Error("Bitwise operators must be used with int types").At(expression.Left).Build()); + } + + var right = CheckExpression(expression.Right, left.Type); + + return new BinaryExpressionNode(left.Type, left, op, right); + } + default: + { + throw new ArgumentOutOfRangeException(); + } + } + } + + private UnaryExpressionNode CheckUnaryExpression(UnaryExpressionSyntax expression) + { + switch (expression.Operator) + { + case UnaryOperatorSyntax.Negate: + { + var operand = CheckExpression(expression.Operand); + if (operand.Type is not IntTypeNode { Signed: false } or FloatTypeNode) + { + throw new TypeCheckerException(Diagnostic.Error("Negation operator must be used with signed integer or float types").Build()); + } + + return new UnaryExpressionNode(operand.Type, UnaryOperator.Negate, operand); + } + case UnaryOperatorSyntax.Invert: + { + var operand = CheckExpression(expression.Operand); + if (operand.Type is not BoolTypeNode) + { + throw new TypeCheckerException(Diagnostic.Error("Invert operator must be used with booleans").Build()); + } + + return new UnaryExpressionNode(operand.Type, UnaryOperator.Invert, operand); + } + default: + { + throw new ArgumentOutOfRangeException(); + } + } } private DereferenceNode CheckDereference(DereferenceSyntax expression) { - throw new NotImplementedException(); + var target = CheckExpression(expression.Target); + if (target.Type is not PointerTypeNode pointerType) + { + throw new TypeCheckerException(Diagnostic.Error($"Cannot dereference non-pointer type {target.Type}").At(expression).Build()); + } + + return new DereferenceNode(pointerType.BaseType, target); } private FuncCallNode CheckFuncCall(FuncCallSyntax expression) @@ -315,9 +477,37 @@ public sealed class TypeChecker return new FuncCallNode(funcType.ReturnType, accessor, parameters); } - private ExpressionNode CheckDotFuncCall(DotFuncCallSyntax expression) + private StructFuncCallNode CheckDotFuncCall(DotFuncCallSyntax expression) { - throw new NotImplementedException(); + // todo(nub31): When adding interfaces, also support other types than structs + var target = CheckExpression(expression.Target); + if (target.Type is StructTypeNode structType) + { + var function = structType.Functions.FirstOrDefault(x => x.Name == expression.Name); + if (function == null) + { + throw new TypeCheckerException(Diagnostic.Error($"Function {expression.Name} not found on struct {structType}").At(expression).Build()); + } + + var parameters = new List(); + for (var i = 0; i < expression.Parameters.Count; i++) + { + var parameter = expression.Parameters[i]; + var expectedType = function.Type.Parameters[i]; + + var parameterExpression = CheckExpression(parameter, expectedType); + if (parameterExpression.Type != expectedType) + { + throw new TypeCheckerException(Diagnostic.Error($"Parameter {i + 1} does not match the type {expectedType} for function {function}").At(parameter).Build()); + } + + parameters.Add(parameterExpression); + } + + return new StructFuncCallNode(function.Type.ReturnType, expression.Name, structType, target, parameters); + } + + throw new TypeCheckerException(Diagnostic.Error($"No function {expression.Name} exists on type {target.Type}").Build()); } private ExpressionNode CheckLocalIdentifier(LocalIdentifierSyntax expression) @@ -479,11 +669,6 @@ public sealed class TypeChecker return new StructInitializerNode(structType, initializers); } - private UnaryExpressionNode CheckUnaryExpression(UnaryExpressionSyntax expression) - { - throw new NotImplementedException(); - } - private BlockNode CheckBlock(BlockSyntax node, Scope? scope = null) { var statements = new List(); @@ -555,7 +740,12 @@ public sealed class TypeChecker var fields = structType.Fields.Select(x => new StructTypeField(x.Name, ResolveType(x.Type), x.HasDefaultValue)).ToList(); result.Fields.AddRange(fields); - // todo(nub31): Function implementations + foreach (var function in structType.Functions) + { + var parameters = function.Parameters.Select(x => ResolveType(x.Type)).ToList(); + var type = new FuncTypeNode(parameters, ResolveType(function.ReturnType)); + result.Functions.Add(new StructTypeFunc(function.Name, type)); + } _referencedStructTypes.Add(result); return result; diff --git a/example/src/main.nub b/example/src/main.nub index 648ee60..cef75c5 100644 --- a/example/src/main.nub +++ b/example/src/main.nub @@ -9,10 +9,12 @@ export struct Human extern "main" func main(args: []cstring): i64 { - let x: Human = { - name = "test" + let x = [1]Human + + x[0] = { + name = "oliver" } - puts(x.name) + puts(x[0].name) return 0 }