From 942b38992beb4bdcc15711d4b853b2751a0e4e42 Mon Sep 17 00:00:00 2001 From: nub31 Date: Sat, 17 May 2025 15:53:28 +0200 Subject: [PATCH] ... --- example/program.nub | 4 +- src/compiler/Nub.Lang/Backend/Generator.cs | 101 ++++++------- .../Nub.Lang/Frontend/Lexing/Lexer.cs | 2 + .../Nub.Lang/Frontend/Lexing/SymbolToken.cs | 4 +- .../Nub.Lang/Frontend/Parsing/Parser.cs | 126 ++++++++-------- .../Nub.Lang/Frontend/Typing/TypeChecker.cs | 134 ++++++++++-------- src/compiler/Nub.Lang/NubType.cs | 13 +- src/compiler/Nub.Lang/Program.cs | 4 +- 8 files changed, 194 insertions(+), 194 deletions(-) diff --git a/example/program.nub b/example/program.nub index aa00e3a..b8a5f59 100644 --- a/example/program.nub +++ b/example/program.nub @@ -3,6 +3,6 @@ import c global func main(argc: i64, argv: i64) { printf("args: %d, starts at %p\n", argc, argv) - x: i8 = (i8)320000 - printf("%d\n", x) + + printf("%s\n", list.text) } \ No newline at end of file diff --git a/src/compiler/Nub.Lang/Backend/Generator.cs b/src/compiler/Nub.Lang/Backend/Generator.cs index d17a2cf..13abfbb 100644 --- a/src/compiler/Nub.Lang/Backend/Generator.cs +++ b/src/compiler/Nub.Lang/Backend/Generator.cs @@ -8,11 +8,11 @@ public class Generator { private readonly List _definitions; private readonly StringBuilder _builder = new(); - private readonly Dictionary _prefixIndexes = new(); private readonly Dictionary _variables = new(); private readonly List _strings = []; private readonly Stack _breakLabels = new(); private readonly Stack _continueLabels = new(); + private int _variableIndex; private bool _codeIsReachable = true; public Generator(List definitions) @@ -72,7 +72,7 @@ public class Generator throw new ArgumentOutOfRangeException(); } } - case NubCustomType nubCustomType: + case NubStructType: { return "l"; } @@ -114,7 +114,7 @@ public class Generator throw new ArgumentOutOfRangeException(); } } - case NubCustomType nubCustomType: + case NubStructType nubCustomType: { return ":" + nubCustomType.Name; } @@ -159,7 +159,7 @@ public class Generator throw new ArgumentOutOfRangeException(); } } - case NubCustomType nubCustomType: + case NubStructType nubCustomType: { return ":" + nubCustomType.Name; } @@ -170,7 +170,7 @@ public class Generator } } - private static int QbeTypeSize(NubType type) + private int QbeTypeSize(NubType type) { switch (type) { @@ -201,9 +201,14 @@ public class Generator throw new ArgumentOutOfRangeException(); } } - case NubCustomType nubCustomType: + case NubStructType nubCustomType: { - return 8; + var definition = _definitions.OfType().FirstOrDefault(s => s.Name == nubCustomType.Name); + if (definition == null) + { + throw new Exception($"Cannot determine size of non-existent type {nubCustomType}"); + } + return definition.Fields.Sum(f => QbeTypeSize(f.Type)); } default: { @@ -246,19 +251,19 @@ public class Generator switch (FQT(parameter.Type)) { case "sb": - parameterName = GenName("c"); + parameterName = GenName(); _builder.AppendLine($" %{parameterName} =w extsb %{parameter.Name}"); break; case "ub": - parameterName = GenName("c"); + parameterName = GenName(); _builder.AppendLine($" %{parameterName} =w extub %{parameter.Name}"); break; case "sh": - parameterName = GenName("c"); + parameterName = GenName(); _builder.AppendLine($" %{parameterName} =w extsh %{parameter.Name}"); break; case "uh": - parameterName = GenName("c"); + parameterName = GenName(); _builder.AppendLine($" %{parameterName} =w extuh %{parameter.Name}"); break; } @@ -403,9 +408,9 @@ public class Generator private void GenerateIf(IfNode ifStatement) { - var trueLabel = GenName("true"); - var falseLabel = GenName("false"); - var endLabel = GenName("endif"); + var trueLabel = GenName(); + var falseLabel = GenName(); + var endLabel = GenName(); var result = GenerateExpression(ifStatement.Condition); _builder.AppendLine($" jnz {result}, @{trueLabel}, @{falseLabel}"); @@ -450,9 +455,9 @@ public class Generator private void GenerateWhile(WhileNode whileStatement) { - var conditionLabel = GenName("condition"); - var iterationLabel = GenName("iteration"); - var endLabel = GenName("endloop"); + var conditionLabel = GenName(); + var iterationLabel = GenName(); + var endLabel = GenName(); _breakLabels.Push(endLabel); _continueLabels.Push(conditionLabel); @@ -734,7 +739,7 @@ public class Generator throw new NotSupportedException("Casting is only supported for primitive types"); } - var outputLabel = GenName("c"); + var outputLabel = GenName(); switch (primitiveInputType.Kind) { @@ -779,7 +784,7 @@ public class Generator case PrimitiveTypeKind.U8: return input; case PrimitiveTypeKind.F64: - var extLabel = GenName("ext"); + var extLabel = GenName(); _builder.AppendLine($" %{extLabel} =l extsw {input}"); _builder.AppendLine($" %{outputLabel} =d sltof %{extLabel}"); return $"%{outputLabel}"; @@ -790,7 +795,7 @@ public class Generator _builder.AppendLine($" %{outputLabel} =l call $nub_i32_to_string(w {input})"); return $"%{outputLabel}"; case PrimitiveTypeKind.Any: - var extAnyLabel = GenName("ext"); + var extAnyLabel = GenName(); _builder.AppendLine($" %{extAnyLabel} =l extsw {input}"); return $"%{extAnyLabel}"; case PrimitiveTypeKind.Bool: @@ -815,28 +820,28 @@ public class Generator return input; case PrimitiveTypeKind.F64: { - var extLabel = GenName("ext"); + var extLabel = GenName(); _builder.AppendLine($" %{extLabel} =l extsh {input}"); _builder.AppendLine($" %{outputLabel} =d sltof %{extLabel}"); return $"%{outputLabel}"; } case PrimitiveTypeKind.F32: { - var extLabel = GenName("ext"); + var extLabel = GenName(); _builder.AppendLine($" %{extLabel} =w extsh {input}"); _builder.AppendLine($" %{outputLabel} =s swtof %{extLabel}"); return $"%{outputLabel}"; } case PrimitiveTypeKind.String: { - var extLabel = GenName("ext"); + var extLabel = GenName(); _builder.AppendLine($" %{extLabel} =w extsh {input}"); _builder.AppendLine($" %{outputLabel} =l call $nub_i32_to_string(w %{extLabel})"); return $"%{outputLabel}"; } case PrimitiveTypeKind.Any: { - var extLabel = GenName("ext"); + var extLabel = GenName(); _builder.AppendLine($" %{extLabel} =l extsh {input}"); return $"%{extLabel}"; } @@ -862,28 +867,28 @@ public class Generator return input; case PrimitiveTypeKind.F64: { - var extLabel = GenName("ext"); + var extLabel = GenName(); _builder.AppendLine($" %{extLabel} =l extsb {input}"); _builder.AppendLine($" %{outputLabel} =d sltof %{extLabel}"); return $"%{outputLabel}"; } case PrimitiveTypeKind.F32: { - var extLabel = GenName("ext"); + var extLabel = GenName(); _builder.AppendLine($" %{extLabel} =w extsb {input}"); _builder.AppendLine($" %{outputLabel} =s swtof %{extLabel}"); return $"%{outputLabel}"; } case PrimitiveTypeKind.String: { - var extLabel = GenName("ext"); + var extLabel = GenName(); _builder.AppendLine($" %{extLabel} =w extsb {input}"); _builder.AppendLine($" %{outputLabel} =l call $nub_i32_to_string(w %{extLabel})"); return $"%{outputLabel}"; } case PrimitiveTypeKind.Any: { - var extLabel = GenName("ext"); + var extLabel = GenName(); _builder.AppendLine($" %{extLabel} =l extsb {input}"); return $"%{extLabel}"; } @@ -933,7 +938,7 @@ public class Generator case PrimitiveTypeKind.U8: return input; case PrimitiveTypeKind.F64: - var extLabel = GenName("ext"); + var extLabel = GenName(); _builder.AppendLine($" %{extLabel} =l extuw {input}"); _builder.AppendLine($" %{outputLabel} =d ultof %{extLabel}"); return $"%{outputLabel}"; @@ -944,7 +949,7 @@ public class Generator _builder.AppendLine($" %{outputLabel} =l call $nub_u32_to_string(w {input})"); return $"%{outputLabel}"; case PrimitiveTypeKind.Any: - var extAnyLabel = GenName("ext"); + var extAnyLabel = GenName(); _builder.AppendLine($" %{extAnyLabel} =l extuw {input}"); return $"%{extAnyLabel}"; case PrimitiveTypeKind.Bool: @@ -969,28 +974,28 @@ public class Generator return input; case PrimitiveTypeKind.F64: { - var extLabel = GenName("ext"); + var extLabel = GenName(); _builder.AppendLine($" %{extLabel} =l extuh {input}"); _builder.AppendLine($" %{outputLabel} =d ultof %{extLabel}"); return $"%{outputLabel}"; } case PrimitiveTypeKind.F32: { - var extLabel = GenName("ext"); + var extLabel = GenName(); _builder.AppendLine($" %{extLabel} =w extuh {input}"); _builder.AppendLine($" %{outputLabel} =s uwtof %{extLabel}"); return $"%{outputLabel}"; } case PrimitiveTypeKind.String: { - var extLabel = GenName("ext"); + var extLabel = GenName(); _builder.AppendLine($" %{extLabel} =w extuh {input}"); _builder.AppendLine($" %{outputLabel} =l call $nub_u32_to_string(w %{extLabel})"); return $"%{outputLabel}"; } case PrimitiveTypeKind.Any: { - var extLabel = GenName("ext"); + var extLabel = GenName(); _builder.AppendLine($" %{extLabel} =l extuh {input}"); return $"%{extLabel}"; } @@ -1016,28 +1021,28 @@ public class Generator return input; case PrimitiveTypeKind.F64: { - var extLabel = GenName("ext"); + var extLabel = GenName(); _builder.AppendLine($" %{extLabel} =l extub {input}"); _builder.AppendLine($" %{outputLabel} =d ultof %{extLabel}"); return $"%{outputLabel}"; } case PrimitiveTypeKind.F32: { - var extLabel = GenName("ext"); + var extLabel = GenName(); _builder.AppendLine($" %{extLabel} =w extub {input}"); _builder.AppendLine($" %{outputLabel} =s uwtof %{extLabel}"); return $"%{outputLabel}"; } case PrimitiveTypeKind.String: { - var extLabel = GenName("ext"); + var extLabel = GenName(); _builder.AppendLine($" %{extLabel} =w extub {input}"); _builder.AppendLine($" %{outputLabel} =l call $nub_u32_to_string(w %{extLabel})"); return $"%{outputLabel}"; } case PrimitiveTypeKind.Any: { - var extLabel = GenName("ext"); + var extLabel = GenName(); _builder.AppendLine($" %{extLabel} =l extub {input}"); return $"%{extLabel}"; } @@ -1196,15 +1201,13 @@ public class Generator private string GenerateStructInitializer(StructInitializerNode structInitializer) { - var structDefinition = _definitions.OfType() - .FirstOrDefault(s => s.Name == structInitializer.StructType.Name); - + var structDefinition = _definitions.OfType().FirstOrDefault(s => s.Name == structInitializer.StructType.Name); if (structDefinition == null) { throw new Exception($"Struct {structInitializer.StructType.Name} is not defined"); } - var structVar = GenName("struct"); + var structVar = GenName(); var size = structDefinition.Fields.Sum(x => QbeTypeSize(x.Type)); _builder.AppendLine($" %{structVar} =l alloc8 {size}"); @@ -1216,14 +1219,14 @@ public class Generator if (structInitializer.Initializers.TryGetValue(field.Name, out var fieldValue)) { var var = GenerateExpression(fieldValue); - var offsetLabel = GenName("offset"); + var offsetLabel = GenName(); _builder.AppendLine($" %{offsetLabel} =l add %{structVar}, {i * QbeTypeSize(field.Type)}"); _builder.AppendLine($" store{SQT(field.Type)} {var}, %{offsetLabel}"); } else if (field.Value.HasValue) { var var = GenerateExpression(field.Value.Value); - var offsetLabel = GenName("offset"); + var offsetLabel = GenName(); _builder.AppendLine($" %{offsetLabel} =l add %{structVar}, {i * QbeTypeSize(field.Type)}"); _builder.AppendLine($" store{SQT(field.Type)} {var}, %{offsetLabel}"); } @@ -1265,10 +1268,10 @@ public class Generator throw new Exception($"Field {structFieldAccessor.Field} is not defined in struct {structType.Name}"); } - var offsetLabel = GenName("offset"); + var offsetLabel = GenName(); _builder.AppendLine($" %{offsetLabel} =l add {@struct}, {fieldIndex * QbeTypeSize(structFieldAccessor.Type)}"); - var outputLabel = GenName("field"); + var outputLabel = GenName(); _builder.AppendLine($" %{outputLabel} ={SQT(structFieldAccessor.Type)} load{SQT(structFieldAccessor.Type)} %{offsetLabel}"); return $"%{outputLabel}"; @@ -1282,11 +1285,9 @@ public class Generator return $"%{outputLabel}"; } - private string GenName(string prefix = "v") + private string GenName() { - var index = _prefixIndexes.GetValueOrDefault(prefix, 0); - _prefixIndexes[prefix] = index + 1; - return $"{prefix}{index}"; + return $"v{++_variableIndex}"; } private class Variable diff --git a/src/compiler/Nub.Lang/Frontend/Lexing/Lexer.cs b/src/compiler/Nub.Lang/Frontend/Lexing/Lexer.cs index e36c490..5255371 100644 --- a/src/compiler/Nub.Lang/Frontend/Lexing/Lexer.cs +++ b/src/compiler/Nub.Lang/Frontend/Lexing/Lexer.cs @@ -50,6 +50,8 @@ public class Lexer ['*'] = Symbol.Star, ['/'] = Symbol.ForwardSlash, ['!'] = Symbol.Bang, + ['^'] = Symbol.Caret, + ['&'] = Symbol.Ampersand, }; private string _src = string.Empty; diff --git a/src/compiler/Nub.Lang/Frontend/Lexing/SymbolToken.cs b/src/compiler/Nub.Lang/Frontend/Lexing/SymbolToken.cs index 068b52f..882f751 100644 --- a/src/compiler/Nub.Lang/Frontend/Lexing/SymbolToken.cs +++ b/src/compiler/Nub.Lang/Frontend/Lexing/SymbolToken.cs @@ -39,5 +39,7 @@ public enum Symbol Star, ForwardSlash, New, - Struct + Struct, + Caret, + Ampersand } \ No newline at end of file diff --git a/src/compiler/Nub.Lang/Frontend/Parsing/Parser.cs b/src/compiler/Nub.Lang/Frontend/Parsing/Parser.cs index ce52b19..1ff5e83 100644 --- a/src/compiler/Nub.Lang/Frontend/Parsing/Parser.cs +++ b/src/compiler/Nub.Lang/Frontend/Parsing/Parser.cs @@ -35,12 +35,12 @@ public class Parser private DefinitionNode ParseDefinition() { List modifiers = []; - + while (TryExpectModifier(out var modifier)) { modifiers.Add(modifier); } - + var keyword = ExpectSymbol(); return keyword.Symbol switch { @@ -76,20 +76,22 @@ public class Parser { throw new Exception($"Modifiers: {string.Join(", ", modifiers)} is not valid for an extern function"); } + return new ExternFuncDefinitionNode(name.Value, parameters, returnType); } - + var body = ParseBlock(); var global = modifiers.Remove(Modifier.Global); - + if (modifiers.Count != 0) { throw new Exception($"Modifiers: {string.Join(", ", modifiers)} is not valid for a local function"); } + return new LocalFuncDefinitionNode(name.Value, parameters, body, returnType, global); } - private StructDefinitionNode ParseStruct(List modifiers) + private StructDefinitionNode ParseStruct(List _) { var name = ExpectIdentifier().Value; @@ -157,14 +159,14 @@ public class Parser case Symbol.Assign: { var value = ParseExpression(); - return new VariableAssignmentNode(identifier.Value, Optional.Empty(), value); + return new VariableAssignmentNode(identifier.Value, Optional.Empty(), value); } case Symbol.Colon: { var type = ParseType(); ExpectSymbol(Symbol.Assign); var value = ParseExpression(); - return new VariableAssignmentNode(identifier.Value,type, value); + return new VariableAssignmentNode(identifier.Value, type, value); } default: { @@ -304,70 +306,69 @@ public class Parser } private ExpressionNode ParsePrimaryExpression() -{ - var token = ExpectToken(); - switch (token) { - case LiteralToken literal: + var token = ExpectToken(); + switch (token) { - return new LiteralNode(literal.Value, literal.Type); - } - case IdentifierToken identifier: - { - return ParseExpressionIdentifier(identifier); - } - case SymbolToken symbolToken: - { - switch (symbolToken.Symbol) + case LiteralToken literal: { - case Symbol.OpenParen: + return new LiteralNode(literal.Value, literal.Type); + } + case IdentifierToken identifier: + { + return ParseExpressionIdentifier(identifier); + } + case SymbolToken symbolToken: + { + switch (symbolToken.Symbol) { - // This is ugly - var nextToken = Peek(); - if (nextToken is { Value: IdentifierToken }) + case Symbol.OpenParen: { - var startIndex = _index; - var identifierToken = ExpectIdentifier(); - var type = NubType.Parse(identifierToken.Value); - - if (TryExpectSymbol(Symbol.CloseParen)) + // This is ugly + var nextToken = Peek(); + if (nextToken is { Value: IdentifierToken or SymbolToken { Symbol: Symbol.Caret } }) { - var expressionToCast = ParsePrimaryExpression(); - return new CastNode(type, expressionToCast); + var startIndex = _index; + var type = ParseType(); + + if (TryExpectSymbol(Symbol.CloseParen)) + { + var expressionToCast = ParsePrimaryExpression(); + return new CastNode(type, expressionToCast); + } + + _index = startIndex; } - _index = startIndex; + var expression = ParseExpression(); + ExpectSymbol(Symbol.CloseParen); + return expression; } - - var expression = ParseExpression(); - ExpectSymbol(Symbol.CloseParen); - return expression; - } - case Symbol.New: - { - var type = ParseType(); - Dictionary initializers = []; - ExpectSymbol(Symbol.OpenBrace); - while (!TryExpectSymbol(Symbol.CloseBrace)) + case Symbol.New: { - var name = ExpectIdentifier().Value; - ExpectSymbol(Symbol.Assign); - var value = ParseExpression(); - initializers.Add(name, value); - } + var type = ParseType(); + Dictionary initializers = []; + ExpectSymbol(Symbol.OpenBrace); + while (!TryExpectSymbol(Symbol.CloseBrace)) + { + var name = ExpectIdentifier().Value; + ExpectSymbol(Symbol.Assign); + var value = ParseExpression(); + initializers.Add(name, value); + } - return new StructInitializerNode(type, initializers); - } - default: - { - throw new Exception($"Unknown symbol: {symbolToken.Symbol}"); + return new StructInitializerNode(type, initializers); + } + default: + { + throw new Exception($"Unknown symbol: {symbolToken.Symbol}"); + } } } + default: + throw new Exception($"Unexpected token type {token.GetType().Name}"); } - default: - throw new Exception($"Unexpected token type {token.GetType().Name}"); } -} private ExpressionNode ParseExpressionIdentifier(IdentifierToken identifier) { @@ -482,7 +483,7 @@ public class Parser Next(); return true; } - + modifier = default; return false; } @@ -498,17 +499,6 @@ public class Parser return identifier; } - private LiteralToken ExpectLiteral() - { - var token = ExpectToken(); - if (token is not LiteralToken literal) - { - throw new Exception($"Expected {nameof(LiteralToken)} but got {token.GetType().Name}"); - } - - return literal; - } - private Optional Peek() { while (_index < _tokens.Count && _tokens.ElementAt(_index) is SymbolToken { Symbol: Symbol.Whitespace }) diff --git a/src/compiler/Nub.Lang/Frontend/Typing/TypeChecker.cs b/src/compiler/Nub.Lang/Frontend/Typing/TypeChecker.cs index 47791e1..b5a59ba 100644 --- a/src/compiler/Nub.Lang/Frontend/Typing/TypeChecker.cs +++ b/src/compiler/Nub.Lang/Frontend/Typing/TypeChecker.cs @@ -10,44 +10,29 @@ public class TypeCheckingException : Exception public class TypeChecker { private readonly Dictionary _variables = new(); - private readonly Dictionary Parameters, Optional ReturnType)> _functions = new(); - private readonly Dictionary> _structs = new(); + private readonly List _definitions; private NubType? _currentFunctionReturnType; private bool _hasReturnStatement; - public void TypeCheck(List definitions) + public TypeChecker(List definitions) { - CollectDefinitions(definitions); + _definitions = definitions; + } - foreach (var definition in definitions) + public void TypeCheck() + { + foreach (var structDef in _definitions.OfType()) { - if (definition is LocalFuncDefinitionNode funcDef) - { - TypeCheckFunction(funcDef); - } + TypeCheckStructDef(structDef); + } + + foreach (var funcDef in _definitions.OfType()) + { + TypeCheckFuncDef(funcDef); } } - private void CollectDefinitions(List definitions) - { - foreach (var definition in definitions) - { - switch (definition) - { - case StructDefinitionNode structDef: - RegisterStruct(structDef); - break; - case LocalFuncDefinitionNode funcDef: - RegisterFunction(funcDef); - break; - case ExternFuncDefinitionNode externFuncDef: - RegisterExternFunction(externFuncDef); - break; - } - } - } - - private void RegisterStruct(StructDefinitionNode structDef) + private void TypeCheckStructDef(StructDefinitionNode structDef) { var fields = new Dictionary(); foreach (var field in structDef.Fields) @@ -56,22 +41,20 @@ public class TypeChecker { throw new TypeCheckingException($"Duplicate field '{field.Name}' in struct '{structDef.Name}'"); } + + if (field.Value.HasValue) + { + if (!TypeCheckExpression(field.Value.Value).Equals(field.Type)) + { + throw new TypeCheckingException("Default field initializer does not match the defined type"); + } + } + fields[field.Name] = field.Type; } - _structs[structDef.Name] = fields; } - private void RegisterFunction(LocalFuncDefinitionNode funcDef) - { - _functions[funcDef.Name] = (funcDef.Parameters, funcDef.ReturnType); - } - - private void RegisterExternFunction(ExternFuncDefinitionNode funcDef) - { - _functions[funcDef.Name] = (funcDef.Parameters, funcDef.ReturnType); - } - - private void TypeCheckFunction(LocalFuncDefinitionNode funcDef) + private void TypeCheckFuncDef(LocalFuncDefinitionNode funcDef) { _variables.Clear(); _currentFunctionReturnType = funcDef.ReturnType.HasValue ? funcDef.ReturnType.Value : null; @@ -146,13 +129,28 @@ public class TypeChecker private NubType TypeCheckFuncCall(FuncCall funcCall) { - if (!_functions.TryGetValue(funcCall.Name, out var funcSignature)) + var localFuncDef = _definitions.OfType().FirstOrDefault(f => f.Name == funcCall.Name); + var externFuncDef = _definitions.OfType().FirstOrDefault(f => f.Name == funcCall.Name); + + List parameters; + Optional returnType; + if (localFuncDef != null) + { + parameters = localFuncDef.Parameters; + returnType = localFuncDef.ReturnType; + } + else if (externFuncDef != null) + { + parameters = externFuncDef.Parameters; + returnType = externFuncDef.ReturnType; + } + + else { throw new TypeCheckingException($"Function '{funcCall.Name}' is not defined"); } - var paramTypes = funcSignature.Parameters; - if (paramTypes.Take(paramTypes.Count - 1).Any(x => x.Variadic)) + if (parameters.Take(parameters.Count - 1).Any(x => x.Variadic)) { throw new TypeCheckingException($"Function '{funcCall.Name}' has multiple variadic parameters"); } @@ -162,13 +160,13 @@ public class TypeChecker var argType = TypeCheckExpression(funcCall.Parameters[i]); NubType paramType; - if (i < paramTypes.Count) + if (i < parameters.Count) { - paramType = paramTypes[i].Type; + paramType = parameters[i].Type; } - else if (paramTypes.LastOrDefault()?.Variadic ?? false) + else if (parameters.LastOrDefault()?.Variadic ?? false) { - return paramTypes[^1].Type; + return parameters[^1].Type; } else { @@ -181,7 +179,7 @@ public class TypeChecker } } - return funcSignature.ReturnType.HasValue ? funcSignature.ReturnType.Value : NubPrimitiveType.Any; + return returnType.HasValue ? returnType.Value : NubPrimitiveType.Any; } private void TypeCheckIf(IfNode ifNode) @@ -322,36 +320,47 @@ public class TypeChecker private NubType TypeCheckStructInitializer(StructInitializerNode structInit) { + var initialized = new HashSet(); + var structType = structInit.StructType; - if (structType is not NubCustomType customType) + if (structType is not NubStructType customType) { throw new TypeCheckingException($"Type '{structType}' is not a struct type"); } - if (!_structs.TryGetValue(customType.Name, out var fields)) + var definition = _definitions.OfType().FirstOrDefault(s => s.Name == structInit.StructType.Name); + if (definition == null) { throw new TypeCheckingException($"Struct type '{customType.Name}' is not defined"); } foreach (var initializer in structInit.Initializers) { - if (!fields.TryGetValue(initializer.Key, out var fieldType)) + var definitionField = definition.Fields.FirstOrDefault(f => f.Name == initializer.Key); + if (definitionField == null) { throw new TypeCheckingException($"Field '{initializer.Key}' does not exist in struct '{customType.Name}'"); } var initializerType = TypeCheckExpression(initializer.Value); - if (!AreTypesCompatible(initializerType, fieldType)) + if (!AreTypesCompatible(initializerType, definitionField.Type)) { - throw new TypeCheckingException($"Cannot initialize field '{initializer.Key}' of type '{fieldType}' with expression of type '{initializerType}'"); + throw new TypeCheckingException($"Cannot initialize field '{initializer.Key}' of type '{definitionField.Type}' with expression of type '{initializerType}'"); } + + initialized.Add(initializer.Key); + } + + foreach (var field in definition.Fields.Where(f => f.Value.HasValue)) + { + initialized.Add(field.Name); } - foreach (var field in fields) + foreach (var field in definition.Fields) { - if (!structInit.Initializers.ContainsKey(field.Key)) + if (!initialized.Contains(field.Name)) { - throw new TypeCheckingException($"Field '{field.Key}' of struct '{customType.Name}' is not initialized"); + throw new TypeCheckingException($"Struct field '{field.Name}' is not initialized on type '{customType.Name}'"); } } @@ -361,23 +370,24 @@ public class TypeChecker private NubType TypeCheckStructFieldAccess(StructFieldAccessorNode fieldAccess) { var structType = TypeCheckExpression(fieldAccess.Struct); - - if (structType is not NubCustomType customType) + if (structType is not NubStructType customType) { throw new TypeCheckingException($"Cannot access field '{fieldAccess.Field}' on non-struct type '{structType}'"); } - if (!_structs.TryGetValue(customType.Name, out var fields)) + var definition = _definitions.OfType().FirstOrDefault(s => s.Name == customType.Name); + if (definition == null) { throw new TypeCheckingException($"Struct type '{customType.Name}' is not defined"); } - if (!fields.TryGetValue(fieldAccess.Field, out var fieldType)) + var field = definition.Fields.FirstOrDefault(f => f.Name == fieldAccess.Field); + if (field == null) { throw new TypeCheckingException($"Field '{fieldAccess.Field}' does not exist in struct '{customType.Name}'"); } - return fieldType; + return field.Type; } private static bool AreTypesCompatible(NubType sourceType, NubType targetType) diff --git a/src/compiler/Nub.Lang/NubType.cs b/src/compiler/Nub.Lang/NubType.cs index c69f582..75aa151 100644 --- a/src/compiler/Nub.Lang/NubType.cs +++ b/src/compiler/Nub.Lang/NubType.cs @@ -18,7 +18,7 @@ public abstract class NubType return new NubPrimitiveType(kind.Value); } - return new NubCustomType(s); + return new NubStructType(s); } public override bool Equals(object? obj) => obj is NubType item && Name.Equals(item.Name); @@ -26,16 +26,11 @@ public abstract class NubType public override string ToString() => Name; } -public class NubCustomType(string name) : NubType(name); +public class NubStructType(string name) : NubType(name); -public class NubPrimitiveType : NubType +public class NubPrimitiveType(PrimitiveTypeKind kind) : NubType(KindToString(kind)) { - public NubPrimitiveType(PrimitiveTypeKind kind) : base(KindToString(kind)) - { - Kind = kind; - } - - public PrimitiveTypeKind Kind { get; } + public PrimitiveTypeKind Kind { get; } = kind; public static NubPrimitiveType I64 => new(PrimitiveTypeKind.I64); public static NubPrimitiveType I32 => new(PrimitiveTypeKind.I32); diff --git a/src/compiler/Nub.Lang/Program.cs b/src/compiler/Nub.Lang/Program.cs index a7b94cd..a7335b8 100644 --- a/src/compiler/Nub.Lang/Program.cs +++ b/src/compiler/Nub.Lang/Program.cs @@ -44,8 +44,8 @@ internal static class Program var modules = RunFrontend(input); var definitions = modules.SelectMany(f => f.Definitions).ToList(); - var typeChecker = new TypeChecker(); - typeChecker.TypeCheck(definitions); + var typeChecker = new TypeChecker(definitions); + typeChecker.TypeCheck(); var generator = new Generator(definitions); var result = generator.Generate();