From aa5bf0b5680e9c7e8fa5b4e3d49fa49150669e66 Mon Sep 17 00:00:00 2001 From: nub31 Date: Thu, 26 Feb 2026 20:00:31 +0100 Subject: [PATCH] ... --- compiler/Parser.cs | 30 ++++- compiler/TypeChecker.cs | 239 ++++++++++++++++++++++++++++------------ examples/math/math.nub | 6 + 3 files changed, 196 insertions(+), 79 deletions(-) diff --git a/compiler/Parser.cs b/compiler/Parser.cs index a592e11..586081e 100644 --- a/compiler/Parser.cs +++ b/compiler/Parser.cs @@ -193,10 +193,14 @@ public class Parser if (TryExpectKeyword(Keyword.Let)) { var name = ExpectIdent(); - ExpectSymbol(Symbol.Colon); - var type = ParseType(); + + NodeType? type = null; + if (TryExpectSymbol(Symbol.Colon)) + type = ParseType(); + ExpectSymbol(Symbol.Equal); var value = ParseExpression(); + return new NodeStatementVariableDeclaration(TokensFrom(startIndex), name, type, value); } @@ -315,6 +319,20 @@ public class Parser var target = ParseExpression(); expr = new NodeExpressionUnary(TokensFrom(startIndex), target, NodeExpressionUnary.Op.Negate); } + else if (TryExpectSymbol(Symbol.OpenCurly)) + { + var initializers = new List(); + while (!TryExpectSymbol(Symbol.CloseCurly)) + { + var initializerStartIndex = startIndex; + var fieldName = ExpectIdent(); + ExpectSymbol(Symbol.Equal); + var fieldValue = ParseExpression(); + initializers.Add(new NodeExpressionStructLiteral.Initializer(TokensFrom(initializerStartIndex), fieldName, fieldValue)); + } + + expr = new NodeExpressionStructLiteral(TokensFrom(startIndex), null, initializers); + } else if (TryExpectSymbol(Symbol.Bang)) { var target = ParseExpression(); @@ -760,10 +778,10 @@ public class NodeStatementReturn(List tokens, NodeExpression value) : Nod public NodeExpression Value { get; } = value; } -public class NodeStatementVariableDeclaration(List tokens, TokenIdent name, NodeType type, NodeExpression value) : NodeStatement(tokens) +public class NodeStatementVariableDeclaration(List tokens, TokenIdent name, NodeType? type, NodeExpression value) : NodeStatement(tokens) { public TokenIdent Name { get; } = name; - public NodeType Type { get; } = type; + public NodeType? Type { get; } = type; public NodeExpression Value { get; } = value; } @@ -816,9 +834,9 @@ public class NodeExpressionBoolLiteral(List tokens, TokenBoolLiteral valu public TokenBoolLiteral Value { get; } = value; } -public class NodeExpressionStructLiteral(List tokens, NodeType type, List initializers) : NodeExpression(tokens) +public class NodeExpressionStructLiteral(List tokens, NodeType? type, List initializers) : NodeExpression(tokens) { - public NodeType Type { get; } = type; + public NodeType? Type { get; } = type; public List Initializers { get; } = initializers; public class Initializer(List tokens, TokenIdent name, NodeExpression value) : Node(tokens) diff --git a/compiler/TypeChecker.cs b/compiler/TypeChecker.cs index 6191c9a..97c1532 100644 --- a/compiler/TypeChecker.cs +++ b/compiler/TypeChecker.cs @@ -1,5 +1,6 @@ using System.Data.Common; using System.Diagnostics.CodeAnalysis; +using System.Formats.Tar; namespace Compiler; @@ -21,6 +22,7 @@ public class TypeChecker private readonly string fileName; private readonly string currentModule; private readonly NodeDefinitionFunc function; + private NubType functionReturnType = null!; private readonly ModuleGraph moduleGraph; private readonly Scope scope = new(); @@ -31,7 +33,16 @@ public class TypeChecker var parameters = new List(); var invalidParameter = false; TypedNodeStatement? body = null; - NubType? returnType = null; + + try + { + functionReturnType = ResolveType(function.ReturnType); + } + catch (CompileException e) + { + diagnostics.Add(e.Diagnostic); + return null; + } using (scope.EnterScope()) { @@ -63,19 +74,10 @@ public class TypeChecker diagnostics.Add(e.Diagnostic); } - try - { - returnType = ResolveType(function.ReturnType); - } - catch (CompileException e) - { - diagnostics.Add(e.Diagnostic); - } - - if (body == null || returnType is null || invalidParameter) + if (body == null || invalidParameter) return null; - return new TypedNodeDefinitionFunc(function.Tokens, currentModule, function.Name, parameters, body, returnType); + return new TypedNodeDefinitionFunc(function.Tokens, currentModule, function.Name, parameters, body, functionReturnType); } } @@ -97,7 +99,10 @@ public class TypeChecker private TypedNodeStatementAssignment CheckStatementAssignment(NodeStatementAssignment statement) { - return new TypedNodeStatementAssignment(statement.Tokens, CheckExpression(statement.Target), CheckExpression(statement.Value)); + var target = CheckExpression(statement.Target, null); + var value = CheckExpression(statement.Value, target.Type); + + return new TypedNodeStatementAssignment(statement.Tokens, target, value); } private TypedNodeStatementBlock CheckStatementBlock(NodeStatementBlock statement) @@ -114,27 +119,56 @@ public class TypeChecker if (statement.Expression is not NodeExpressionFuncCall funcCall) throw BasicError("Expected statement or function call", statement); - return new TypedNodeStatementFuncCall(statement.Tokens, CheckExpression(funcCall.Target), funcCall.Parameters.Select(CheckExpression).ToList()); + var target = CheckExpression(funcCall.Target, null); + if (target.Type is not NubTypeFunc funcType) + throw BasicError("Expected a function type", target); + + if (funcType.Parameters.Count != funcCall.Parameters.Count) + throw BasicError($"Expected {funcType.Parameters.Count} parameters but got {funcCall.Parameters.Count}", funcCall); + + var parameters = new List(); + for (int i = 0; i < funcCall.Parameters.Count; i++) + { + parameters.Add(CheckExpression(funcCall.Parameters[i], funcType.Parameters[i])); + } + + return new TypedNodeStatementFuncCall(statement.Tokens, target, parameters); } private TypedNodeStatementIf CheckStatementIf(NodeStatementIf statement) { - return new TypedNodeStatementIf(statement.Tokens, CheckExpression(statement.Condition), CheckStatement(statement.ThenBlock), statement.ElseBlock == null ? null : CheckStatement(statement.ElseBlock)); + var condition = CheckExpression(statement.Condition, NubTypeBool.Instance); + if (!condition.Type.IsAssignableTo(NubTypeBool.Instance)) + throw BasicError("Condition part of if statement must be a boolean", condition); + + var thenBlock = CheckStatement(statement.ThenBlock); + var elseBlock = statement.ElseBlock == null ? null : CheckStatement(statement.ElseBlock); + + return new TypedNodeStatementIf(statement.Tokens, condition, thenBlock, elseBlock); } private TypedNodeStatementReturn CheckStatementReturn(NodeStatementReturn statement) { - return new TypedNodeStatementReturn(statement.Tokens, CheckExpression(statement.Value)); + var value = CheckExpression(statement.Value, functionReturnType); + if (!value.Type.IsAssignableTo(functionReturnType)) + throw BasicError($"Type of returned value ({value.Type}) is not assignable to the return type of the function ({functionReturnType})", value); + + return new TypedNodeStatementReturn(statement.Tokens, value); } private TypedNodeStatementVariableDeclaration CheckStatementVariableDeclaration(NodeStatementVariableDeclaration statement) { - var type = ResolveType(statement.Type); - var value = CheckExpression(statement.Value); + NubType? type = null; + if (statement.Type != null) + type = ResolveType(statement.Type); - if (!value.Type.IsAssignableTo(type)) + var value = CheckExpression(statement.Value, type); + + if (type is not null && !value.Type.IsAssignableTo(type)) throw BasicError("Type of variable does match type of assigned value", value); + type ??= value.Type; + scope.DeclareIdentifier(statement.Name.Ident, type); return new TypedNodeStatementVariableDeclaration(statement.Tokens, statement.Name, type, value); @@ -142,15 +176,22 @@ public class TypeChecker private TypedNodeStatementWhile CheckStatementWhile(NodeStatementWhile statement) { - return new TypedNodeStatementWhile(statement.Tokens, CheckExpression(statement.Condition), CheckStatement(statement.Body)); + var condition = CheckExpression(statement.Condition, NubTypeBool.Instance); + if (!condition.Type.IsAssignableTo(NubTypeBool.Instance)) + throw BasicError("Condition part of if statement must be a boolean", condition); + + var body = CheckStatement(statement.Body); + + return new TypedNodeStatementWhile(statement.Tokens, condition, body); } private TypedNodeStatementMatch CheckStatementMatch(NodeStatementMatch statement) { - var cases = new List(); - var target = CheckExpression(statement.Target); - var enumType = (NubTypeEnum)target.Type; + var target = CheckExpression(statement.Target, null); + if (target.Type is not NubTypeEnum enumType) + throw BasicError("A match statement can only be used on enum types", target); + var cases = new List(); foreach (var @case in statement.Cases) { using (scope.EnterScope()) @@ -164,28 +205,29 @@ public class TypeChecker return new TypedNodeStatementMatch(statement.Tokens, target, cases); } - private TypedNodeExpression CheckExpression(NodeExpression node) + private TypedNodeExpression CheckExpression(NodeExpression node, NubType? expectedType) { return node switch { - NodeExpressionBinary expression => CheckExpressionBinary(expression), - NodeExpressionUnary expression => CheckExpressionUnary(expression), - NodeExpressionBoolLiteral expression => CheckExpressionBoolLiteral(expression), - NodeExpressionIdent expression => CheckExpressionIdent(expression), - NodeExpressionIntLiteral expression => CheckExpressionIntLiteral(expression), - NodeExpressionMemberAccess expression => CheckExpressionMemberAccess(expression), - NodeExpressionFuncCall expression => CheckExpressionFuncCall(expression), - NodeExpressionStringLiteral expression => CheckExpressionStringLiteral(expression), - NodeExpressionStructLiteral expression => CheckExpressionStructLiteral(expression), - NodeExpressionEnumLiteral expression => CheckExpressionEnumLiteral(expression), + NodeExpressionBinary expression => CheckExpressionBinary(expression, expectedType), + NodeExpressionUnary expression => CheckExpressionUnary(expression, expectedType), + NodeExpressionBoolLiteral expression => CheckExpressionBoolLiteral(expression, expectedType), + NodeExpressionIdent expression => CheckExpressionIdent(expression, expectedType), + NodeExpressionIntLiteral expression => CheckExpressionIntLiteral(expression, expectedType), + NodeExpressionMemberAccess expression => CheckExpressionMemberAccess(expression, expectedType), + NodeExpressionFuncCall expression => CheckExpressionFuncCall(expression, expectedType), + NodeExpressionStringLiteral expression => CheckExpressionStringLiteral(expression, expectedType), + NodeExpressionStructLiteral expression => CheckExpressionStructLiteral(expression, expectedType), + NodeExpressionEnumLiteral expression => CheckExpressionEnumLiteral(expression, expectedType), _ => throw new ArgumentOutOfRangeException(nameof(node)) }; } - private TypedNodeExpressionBinary CheckExpressionBinary(NodeExpressionBinary expression) + private TypedNodeExpressionBinary CheckExpressionBinary(NodeExpressionBinary expression, NubType? expectedType) { - var left = CheckExpression(expression.Left); - var right = CheckExpression(expression.Right); + // todo(nub31): Add proper inference here + var left = CheckExpression(expression.Left, null); + var right = CheckExpression(expression.Right, null); NubType type; switch (expression.Operation) @@ -275,9 +317,10 @@ public class TypeChecker }; } - private TypedNodeExpressionUnary CheckExpressionUnary(NodeExpressionUnary expression) + private TypedNodeExpressionUnary CheckExpressionUnary(NodeExpressionUnary expression, NubType? expectedType) { - var target = CheckExpression(expression.Target); + // todo(nub31): Add proper inference here + var target = CheckExpression(expression.Target, null); NubType type; switch (expression.Operation) @@ -315,12 +358,12 @@ public class TypeChecker }; } - private TypedNodeExpressionBoolLiteral CheckExpressionBoolLiteral(NodeExpressionBoolLiteral expression) + private TypedNodeExpressionBoolLiteral CheckExpressionBoolLiteral(NodeExpressionBoolLiteral expression, NubType? expectedType) { return new TypedNodeExpressionBoolLiteral(expression.Tokens, NubTypeBool.Instance, expression.Value); } - private TypedNodeExpression CheckExpressionIdent(NodeExpressionIdent expression) + private TypedNodeExpression CheckExpressionIdent(NodeExpressionIdent expression, NubType? expectedType) { if (expression.Sections.Count == 1) { @@ -345,14 +388,14 @@ public class TypeChecker throw BasicError($"Unknown identifier '{string.Join("::", expression.Sections.Select(x => x.Ident))}'", expression); } - private TypedNodeExpressionIntLiteral CheckExpressionIntLiteral(NodeExpressionIntLiteral expression) + private TypedNodeExpressionIntLiteral CheckExpressionIntLiteral(NodeExpressionIntLiteral expression, NubType? expectedType) { return new TypedNodeExpressionIntLiteral(expression.Tokens, NubTypeSInt.Get(32), expression.Value); } - private TypedNodeExpressionMemberAccess CheckExpressionMemberAccess(NodeExpressionMemberAccess expression) + private TypedNodeExpressionMemberAccess CheckExpressionMemberAccess(NodeExpressionMemberAccess expression, NubType? expectedType) { - var target = CheckExpression(expression.Target); + var target = CheckExpression(expression.Target, null); switch (target.Type) { @@ -395,55 +438,105 @@ public class TypeChecker } } - private TypedNodeExpressionFuncCall CheckExpressionFuncCall(NodeExpressionFuncCall expression) + private TypedNodeExpressionFuncCall CheckExpressionFuncCall(NodeExpressionFuncCall expression, NubType? expectedType) { - var target = CheckExpression(expression.Target); + var target = CheckExpression(expression.Target, null); if (target.Type is not NubTypeFunc funcType) - throw BasicError($"Cannot invoke function call on type '{target.Type}'", target); + throw BasicError("Expected a function type", target); - var parameters = expression.Parameters.Select(CheckExpression).ToList(); + if (funcType.Parameters.Count != expression.Parameters.Count) + throw BasicError($"Expected {funcType.Parameters.Count} parameters but got {expression.Parameters.Count}", expression); + + var parameters = new List(); + for (int i = 0; i < expression.Parameters.Count; i++) + { + parameters.Add(CheckExpression(expression.Parameters[i], funcType.Parameters[i])); + } return new TypedNodeExpressionFuncCall(expression.Tokens, funcType.ReturnType, target, parameters); } - private TypedNodeExpressionStringLiteral CheckExpressionStringLiteral(NodeExpressionStringLiteral expression) + private TypedNodeExpressionStringLiteral CheckExpressionStringLiteral(NodeExpressionStringLiteral expression, NubType? expectedType) { return new TypedNodeExpressionStringLiteral(expression.Tokens, NubTypeString.Instance, expression.Value); } - private TypedNodeExpressionStructLiteral CheckExpressionStructLiteral(NodeExpressionStructLiteral expression) + private TypedNodeExpressionStructLiteral CheckExpressionStructLiteral(NodeExpressionStructLiteral expression, NubType? expectedType) { - var type = ResolveType(expression.Type); - if (type is not NubTypeStruct structType) - throw BasicError("Type of struct literal is not a struct", expression.Type); - - if (!moduleGraph.TryResolveType(structType.Module, structType.Name, structType.Module == currentModule, out var info)) - throw BasicError($"Type '{structType}' struct literal not found", expression.Type); - - if (info is not Module.TypeInfoStruct structInfo) - throw BasicError($"Type '{structType}' is not a struct", expression.Type); - - var initializers = new List(); - foreach (var initializer in expression.Initializers) + if (expression.Type != null) { - var field = structInfo.Fields.FirstOrDefault(x => x.Name == initializer.Name.Ident); - if (field == null) - throw BasicError($"Field '{initializer.Name.Ident}' does not exist on struct '{structType.Module}::{structType.Name}'", initializer.Name); + var type = ResolveType(expression.Type); + if (type is not NubTypeStruct structType) + throw BasicError("Type of struct literal is not a struct", expression); + + if (!moduleGraph.TryResolveType(structType.Module, structType.Name, structType.Module == currentModule, out var info)) + throw BasicError($"Type '{structType}' struct literal not found", expression); - var value = CheckExpression(initializer.Value); - if (!value.Type.IsAssignableTo(field.Type)) - throw BasicError($"Type of assignment ({value.Type}) does not match expected type of field '{field.Name}' ({field.Type})", initializer.Name); + if (info is not Module.TypeInfoStruct structInfo) + throw BasicError($"Type '{structType}' is not a struct", expression.Type); - initializers.Add(new TypedNodeExpressionStructLiteral.Initializer(initializer.Tokens, initializer.Name, value)); + var initializers = new List(); + foreach (var initializer in expression.Initializers) + { + var field = structInfo.Fields.FirstOrDefault(x => x.Name == initializer.Name.Ident); + if (field == null) + throw BasicError($"Field '{initializer.Name.Ident}' does not exist on struct '{structType.Module}::{structType.Name}'", initializer.Name); + + var value = CheckExpression(initializer.Value, field.Type); + if (!value.Type.IsAssignableTo(field.Type)) + throw BasicError($"Type of assignment ({value.Type}) does not match expected type of field '{field.Name}' ({field.Type})", initializer.Name); + + initializers.Add(new TypedNodeExpressionStructLiteral.Initializer(initializer.Tokens, initializer.Name, value)); + } + + return new TypedNodeExpressionStructLiteral(expression.Tokens, structType, initializers); } + else if (expectedType is NubTypeStruct structType) + { + if (!moduleGraph.TryResolveType(structType.Module, structType.Name, structType.Module == currentModule, out var info)) + throw BasicError($"Type '{structType}' struct literal not found", expression); - return new TypedNodeExpressionStructLiteral(expression.Tokens, NubTypeStruct.Get(structType.Module, structType.Name), initializers); + if (info is not Module.TypeInfoStruct structInfo) + throw BasicError($"Type '{structType}' is not a struct", expression); + + var initializers = new List(); + foreach (var initializer in expression.Initializers) + { + var field = structInfo.Fields.FirstOrDefault(x => x.Name == initializer.Name.Ident); + if (field == null) + throw BasicError($"Field '{initializer.Name.Ident}' does not exist on struct '{structType.Module}::{structType.Name}'", initializer.Name); + + var value = CheckExpression(initializer.Value, field.Type); + if (!value.Type.IsAssignableTo(field.Type)) + throw BasicError($"Type of assignment ({value.Type}) does not match expected type of field '{field.Name}' ({field.Type})", initializer.Name); + + initializers.Add(new TypedNodeExpressionStructLiteral.Initializer(initializer.Tokens, initializer.Name, value)); + } + + return new TypedNodeExpressionStructLiteral(expression.Tokens, structType, initializers); + } + // todo(nub31): Infer anonymous struct types if expectedType is anonymous struct + else + { + var initializers = new List(); + foreach (var initializer in expression.Initializers) + { + var value = CheckExpression(initializer.Value, null); + initializers.Add(new TypedNodeExpressionStructLiteral.Initializer(initializer.Tokens, initializer.Name, value)); + } + + var type = NubTypeAnonymousStruct.Get(initializers.Select(x => new NubTypeAnonymousStruct.Field(x.Name.Ident, x.Value.Type)).ToList()); + + return new TypedNodeExpressionStructLiteral(expression.Tokens, type, initializers); + } } - private TypedNodeExpressionEnumLiteral CheckExpressionEnumLiteral(NodeExpressionEnumLiteral expression) + private TypedNodeExpressionEnumLiteral CheckExpressionEnumLiteral(NodeExpressionEnumLiteral expression, NubType? expectedType) { - var value = CheckExpression(expression.Value); - return new TypedNodeExpressionEnumLiteral(expression.Tokens, NubTypeEnumVariant.Get(NubTypeEnum.Get(expression.Module.Ident, expression.EnumName.Ident), expression.VariantName.Ident), value); + // todo(nub31): Infer type of enum variant + var type = NubTypeEnumVariant.Get(NubTypeEnum.Get(expression.Module.Ident, expression.EnumName.Ident), expression.VariantName.Ident); + var value = CheckExpression(expression.Value, null); + return new TypedNodeExpressionEnumLiteral(expression.Tokens, type, value); } private NubType ResolveType(NodeType node) diff --git a/examples/math/math.nub b/examples/math/math.nub index 3652799..1ae7fb2 100644 --- a/examples/math/math.nub +++ b/examples/math/math.nub @@ -40,6 +40,12 @@ export func add(a: i32 b: i32): i32 } } + let color: color = { + r = 23 + g = 23 + b = 23 + } + return add_internal(a b) }