From 560e6428ff50b8ba313a28ba07f4590d9421dfce Mon Sep 17 00:00:00 2001 From: nub31 Date: Sun, 26 Oct 2025 22:28:48 +0100 Subject: [PATCH] ... --- compiler/NubLang.CLI/Program.cs | 45 ++- compiler/NubLang.LSP/AstExtensions.cs | 9 +- compiler/NubLang.LSP/CompletionHandler.cs | 40 +-- compiler/NubLang.LSP/DefinitionHandler.cs | 19 +- compiler/NubLang.LSP/HoverHandler.cs | 204 ++++++------ compiler/NubLang.LSP/WorkspaceManager.cs | 22 +- compiler/NubLang/Ast/CompilationUnit.cs | 15 +- compiler/NubLang/Ast/Node.cs | 53 ++- compiler/NubLang/Ast/TypeChecker.cs | 308 +++++++----------- compiler/NubLang/Ast/TypeResolver.cs | 97 ++++++ .../NubLang/Diagnostics/CompileException.cs | 11 + compiler/NubLang/Generation/CType.cs | 6 +- compiler/NubLang/Generation/Generator.cs | 126 ++++--- .../NubLang/Generation/HeaderGenerator.cs | 49 +++ compiler/NubLang/Syntax/Parser.cs | 44 +-- compiler/NubLang/Syntax/Tokenizer.cs | 24 +- compiler/NubLang/Syntax/TypedModule.cs | 50 +++ examples/playgroud/main.nub | 24 +- 18 files changed, 663 insertions(+), 483 deletions(-) create mode 100644 compiler/NubLang/Ast/TypeResolver.cs create mode 100644 compiler/NubLang/Diagnostics/CompileException.cs create mode 100644 compiler/NubLang/Generation/HeaderGenerator.cs create mode 100644 compiler/NubLang/Syntax/TypedModule.cs diff --git a/compiler/NubLang.CLI/Program.cs b/compiler/NubLang.CLI/Program.cs index ba72d53..314327a 100644 --- a/compiler/NubLang.CLI/Program.cs +++ b/compiler/NubLang.CLI/Program.cs @@ -21,7 +21,7 @@ foreach (var file in args) } var modules = Module.Collect(syntaxTrees); -var compilationUnits = new List(); +var compilationUnits = new List>(); for (var i = 0; i < args.Length; i++) { @@ -46,16 +46,48 @@ var cPaths = new List(); Directory.CreateDirectory(".build"); +var typedModules = modules.Select(x => (x.Key, TypedModule.FromModule(x.Key, x.Value, modules))).ToDictionary(); + +var moduleHeaders = new List(); + +var commonHeaderOut = Path.Combine(".build", "runtime.h"); + +File.WriteAllText(commonHeaderOut, """ + #include + + void *rc_alloc(size_t size, void (*destructor)(void *self)); + void rc_retain(void *obj); + void rc_release(void *obj); + + typedef struct + { + unsigned long long length; + char *data; + } nub_string; + + typedef struct + { + unsigned long long length; + void *data; + } nub_slice; + """); + +moduleHeaders.Add(commonHeaderOut); + +foreach (var typedModule in typedModules) +{ + var header = HeaderGenerator.Generate(typedModule.Key, typedModule.Value); + var headerOut = Path.Combine(".build", "modules", typedModule.Key + ".h"); + Directory.CreateDirectory(Path.Combine(".build", "modules")); + File.WriteAllText(headerOut, header); + moduleHeaders.Add(headerOut); +} + for (var i = 0; i < args.Length; i++) { var file = args[i]; var compilationUnit = compilationUnits[i]; - if (compilationUnit == null) - { - continue; - } - var generator = new Generator(compilationUnit); var directory = Path.GetDirectoryName(file); if (!string.IsNullOrWhiteSpace(directory)) @@ -74,6 +106,7 @@ foreach (var cPath in cPaths) { var objectPath = Path.ChangeExtension(cPath, "o"); using var compileProcess = Process.Start("clang", [ + ..moduleHeaders.SelectMany(x => new[] { "-include", x }), "-ffreestanding", "-std=c23", "-g", "-c", "-o", objectPath, diff --git a/compiler/NubLang.LSP/AstExtensions.cs b/compiler/NubLang.LSP/AstExtensions.cs index e0b2477..e26675b 100644 --- a/compiler/NubLang.LSP/AstExtensions.cs +++ b/compiler/NubLang.LSP/AstExtensions.cs @@ -1,5 +1,4 @@ using NubLang.Ast; -using NubLang.Syntax; using OmniSharp.Extensions.LanguageServer.Protocol.Models; using Range = OmniSharp.Extensions.LanguageServer.Protocol.Models.Range; @@ -58,16 +57,16 @@ public static class AstExtensions return false; } - public static FuncNode? FunctionAtPosition(this CompilationUnit compilationUnit, int line, int character) + public static FuncNode? FunctionAtPosition(this List compilationUnit, int line, int character) { return compilationUnit - .Functions + .OfType() .FirstOrDefault(x => x.ContainsPosition(line, character)); } - public static Node? DeepestNodeAtPosition(this CompilationUnit compilationUnit, int line, int character) + public static Node? DeepestNodeAtPosition(this List compilationUnit, int line, int character) { - return compilationUnit.Functions + return compilationUnit .SelectMany(x => x.DescendantsAndSelf()) .Where(n => n.ContainsPosition(line, character)) .OrderBy(n => n.Tokens.First().Span.Start.Line) diff --git a/compiler/NubLang.LSP/CompletionHandler.cs b/compiler/NubLang.LSP/CompletionHandler.cs index 297e858..4dde192 100644 --- a/compiler/NubLang.LSP/CompletionHandler.cs +++ b/compiler/NubLang.LSP/CompletionHandler.cs @@ -118,30 +118,30 @@ internal class CompletionHandler(WorkspaceManager workspaceManager) : Completion var compilationUnit = workspaceManager.GetCompilationUnit(uri); if (compilationUnit != null) { - var function = compilationUnit.Functions.FirstOrDefault(x => x.Body != null && x.Body.ContainsPosition(position.Line, position.Character)); + var function = compilationUnit.OfType().FirstOrDefault(x => x.Body != null && x.Body.ContainsPosition(position.Line, position.Character)); if (function != null) { completions.AddRange(_statementSnippets); - foreach (var (module, prototypes) in compilationUnit.ImportedFunctions) - { - foreach (var prototype in prototypes) - { - var parameterStrings = new List(); - foreach (var (index, parameter) in prototype.Parameters.Index()) - { - parameterStrings.AddRange($"${{{index + 1}:{parameter.NameToken.Value}}}"); - } - - completions.Add(new CompletionItem - { - Kind = CompletionItemKind.Function, - Label = $"{module.Value}::{prototype.NameToken.Value}", - InsertTextFormat = InsertTextFormat.Snippet, - InsertText = $"{module.Value}::{prototype.NameToken.Value}({string.Join(", ", parameterStrings)})", - }); - } - } + // foreach (var (module, prototypes) in compilationUnit.ImportedFunctions) + // { + // foreach (var prototype in prototypes) + // { + // var parameterStrings = new List(); + // foreach (var (index, parameter) in prototype.Parameters.Index()) + // { + // parameterStrings.AddRange($"${{{index + 1}:{parameter.NameToken.Value}}}"); + // } + // + // completions.Add(new CompletionItem + // { + // Kind = CompletionItemKind.Function, + // Label = $"{module.Value}::{prototype.NameToken.Value}", + // InsertTextFormat = InsertTextFormat.Snippet, + // InsertText = $"{module.Value}::{prototype.NameToken.Value}({string.Join(", ", parameterStrings)})", + // }); + // } + // } foreach (var parameter in function.Prototype.Parameters) { diff --git a/compiler/NubLang.LSP/DefinitionHandler.cs b/compiler/NubLang.LSP/DefinitionHandler.cs index d3ce660..a86779b 100644 --- a/compiler/NubLang.LSP/DefinitionHandler.cs +++ b/compiler/NubLang.LSP/DefinitionHandler.cs @@ -57,15 +57,16 @@ internal class DefinitionHandler(WorkspaceManager workspaceManager) : Definition } case FuncIdentifierNode funcIdentifierNode: { - var prototype = compilationUnit.ImportedFunctions - .Where(x => x.Key.Value == funcIdentifierNode.ModuleToken.Value) - .SelectMany(x => x.Value) - .FirstOrDefault(x => x.NameToken.Value == funcIdentifierNode.NameToken.Value); - - if (prototype != null) - { - return new LocationOrLocationLinks(prototype.ToLocation()); - } + // var prototype = compilationUnit + // .ImportedFunctions + // .Where(x => x.Key.Value == funcIdentifierNode.ModuleToken.Value) + // .SelectMany(x => x.Value) + // .FirstOrDefault(x => x.NameToken.Value == funcIdentifierNode.NameToken.Value); + // + // if (prototype != null) + // { + // return new LocationOrLocationLinks(prototype.ToLocation()); + // } return null; } diff --git a/compiler/NubLang.LSP/HoverHandler.cs b/compiler/NubLang.LSP/HoverHandler.cs index c97b090..d6c248e 100644 --- a/compiler/NubLang.LSP/HoverHandler.cs +++ b/compiler/NubLang.LSP/HoverHandler.cs @@ -39,108 +39,110 @@ internal class HoverHandler(WorkspaceManager workspaceManager) : HoverHandlerBas return null; } - var message = CreateMessage(hoveredNode, compilationUnit); - if (message == null) - { - return null; - } + // var message = CreateMessage(hoveredNode, compilationUnit); + // if (message == null) + // { + // return null; + // } + // + // return new Hover + // { + // Contents = new MarkedStringsOrMarkupContent(new MarkupContent + // { + // Value = message, + // Kind = MarkupKind.Markdown, + // }) + // }; - return new Hover - { - Contents = new MarkedStringsOrMarkupContent(new MarkupContent - { - Value = message, - Kind = MarkupKind.Markdown, - }) - }; + return null; } - private static string? CreateMessage(Node hoveredNode, CompilationUnit compilationUnit) - { - return hoveredNode switch - { - FuncNode funcNode => CreateFuncPrototypeMessage(funcNode.Prototype), - FuncPrototypeNode funcPrototypeNode => CreateFuncPrototypeMessage(funcPrototypeNode), - FuncIdentifierNode funcIdentifierNode => CreateFuncIdentifierMessage(funcIdentifierNode, compilationUnit), - FuncParameterNode funcParameterNode => CreateTypeNameMessage("Function parameter", funcParameterNode.NameToken.Value, funcParameterNode.Type), - VariableIdentifierNode variableIdentifierNode => CreateTypeNameMessage("Variable", variableIdentifierNode.NameToken.Value, variableIdentifierNode.Type), - VariableDeclarationNode variableDeclarationNode => CreateTypeNameMessage("Variable declaration", variableDeclarationNode.NameToken.Value, variableDeclarationNode.Type), - StructFieldAccessNode structFieldAccessNode => CreateTypeNameMessage("Struct field", $"{structFieldAccessNode.Target.Type}.{structFieldAccessNode.FieldToken.Value}", structFieldAccessNode.Type), - CStringLiteralNode cStringLiteralNode => CreateLiteralMessage(cStringLiteralNode.Type, '"' + cStringLiteralNode.Value + '"'), - StringLiteralNode stringLiteralNode => CreateLiteralMessage(stringLiteralNode.Type, '"' + stringLiteralNode.Value + '"'), - BoolLiteralNode boolLiteralNode => CreateLiteralMessage(boolLiteralNode.Type, boolLiteralNode.Value.ToString()), - Float32LiteralNode float32LiteralNode => CreateLiteralMessage(float32LiteralNode.Type, float32LiteralNode.Value.ToString(CultureInfo.InvariantCulture)), - Float64LiteralNode float64LiteralNode => CreateLiteralMessage(float64LiteralNode.Type, float64LiteralNode.Value.ToString(CultureInfo.InvariantCulture)), - I8LiteralNode i8LiteralNode => CreateLiteralMessage(i8LiteralNode.Type, i8LiteralNode.Value.ToString()), - I16LiteralNode i16LiteralNode => CreateLiteralMessage(i16LiteralNode.Type, i16LiteralNode.Value.ToString()), - I32LiteralNode i32LiteralNode => CreateLiteralMessage(i32LiteralNode.Type, i32LiteralNode.Value.ToString()), - I64LiteralNode i64LiteralNode => CreateLiteralMessage(i64LiteralNode.Type, i64LiteralNode.Value.ToString()), - U8LiteralNode u8LiteralNode => CreateLiteralMessage(u8LiteralNode.Type, u8LiteralNode.Value.ToString()), - U16LiteralNode u16LiteralNode => CreateLiteralMessage(u16LiteralNode.Type, u16LiteralNode.Value.ToString()), - U32LiteralNode u32LiteralNode => CreateLiteralMessage(u32LiteralNode.Type, u32LiteralNode.Value.ToString()), - U64LiteralNode u64LiteralNode => CreateLiteralMessage(u64LiteralNode.Type, u64LiteralNode.Value.ToString()), - // Expressions can have a generic fallback showing the resulting type - ExpressionNode expressionNode => $""" - **Expression** `{expressionNode.GetType().Name}` - ```nub - {expressionNode.Type} - ``` - """, - BlockNode => null, - _ => hoveredNode.GetType().Name - }; - } - - private static string CreateLiteralMessage(NubType type, string value) - { - return $""" - **Literal** `{type}` - ```nub - {value}: {type} - ``` - """; - } - - private static string CreateTypeNameMessage(string description, string name, NubType type) - { - return $""" - **{description}** `{name}` - ```nub - {name}: {type} - ``` - """; - } - - private static string CreateFuncIdentifierMessage(FuncIdentifierNode funcIdentifierNode, CompilationUnit compilationUnit) - { - var func = compilationUnit.ImportedFunctions - .Where(x => x.Key.Value == funcIdentifierNode.ModuleToken.Value) - .SelectMany(x => x.Value) - .FirstOrDefault(x => x.NameToken.Value == funcIdentifierNode.NameToken.Value); - - if (func == null) - { - return $""" - **Function** `{funcIdentifierNode.ModuleToken.Value}::{funcIdentifierNode.NameToken.Value}` - ```nub - // Declaration not found - ``` - """; - } - - return CreateFuncPrototypeMessage(func); - } - - private static string CreateFuncPrototypeMessage(FuncPrototypeNode funcPrototypeNode) - { - var parameterText = string.Join(", ", funcPrototypeNode.Parameters.Select(x => $"{x.NameToken.Value}: {x.Type}")); - var externText = funcPrototypeNode.ExternSymbolToken != null ? $"extern \"{funcPrototypeNode.ExternSymbolToken.Value}\" " : ""; - - return $""" - **Function** `{funcPrototypeNode.NameToken.Value}` - ```nub - {externText}func {funcPrototypeNode.NameToken.Value}({parameterText}): {funcPrototypeNode.ReturnType} - ``` - """; - } +// private static string? CreateMessage(Node hoveredNode, CompilationUnit compilationUnit) +// { +// return hoveredNode switch +// { +// FuncNode funcNode => CreateFuncPrototypeMessage(funcNode.Prototype), +// FuncPrototypeNode funcPrototypeNode => CreateFuncPrototypeMessage(funcPrototypeNode), +// FuncIdentifierNode funcIdentifierNode => CreateFuncIdentifierMessage(funcIdentifierNode, compilationUnit), +// FuncParameterNode funcParameterNode => CreateTypeNameMessage("Function parameter", funcParameterNode.NameToken.Value, funcParameterNode.Type), +// VariableIdentifierNode variableIdentifierNode => CreateTypeNameMessage("Variable", variableIdentifierNode.NameToken.Value, variableIdentifierNode.Type), +// VariableDeclarationNode variableDeclarationNode => CreateTypeNameMessage("Variable declaration", variableDeclarationNode.NameToken.Value, variableDeclarationNode.Type), +// StructFieldAccessNode structFieldAccessNode => CreateTypeNameMessage("Struct field", $"{structFieldAccessNode.Target.Type}.{structFieldAccessNode.FieldToken.Value}", structFieldAccessNode.Type), +// CStringLiteralNode cStringLiteralNode => CreateLiteralMessage(cStringLiteralNode.Type, '"' + cStringLiteralNode.Value + '"'), +// StringLiteralNode stringLiteralNode => CreateLiteralMessage(stringLiteralNode.Type, '"' + stringLiteralNode.Value + '"'), +// BoolLiteralNode boolLiteralNode => CreateLiteralMessage(boolLiteralNode.Type, boolLiteralNode.Value.ToString()), +// Float32LiteralNode float32LiteralNode => CreateLiteralMessage(float32LiteralNode.Type, float32LiteralNode.Value.ToString(CultureInfo.InvariantCulture)), +// Float64LiteralNode float64LiteralNode => CreateLiteralMessage(float64LiteralNode.Type, float64LiteralNode.Value.ToString(CultureInfo.InvariantCulture)), +// I8LiteralNode i8LiteralNode => CreateLiteralMessage(i8LiteralNode.Type, i8LiteralNode.Value.ToString()), +// I16LiteralNode i16LiteralNode => CreateLiteralMessage(i16LiteralNode.Type, i16LiteralNode.Value.ToString()), +// I32LiteralNode i32LiteralNode => CreateLiteralMessage(i32LiteralNode.Type, i32LiteralNode.Value.ToString()), +// I64LiteralNode i64LiteralNode => CreateLiteralMessage(i64LiteralNode.Type, i64LiteralNode.Value.ToString()), +// U8LiteralNode u8LiteralNode => CreateLiteralMessage(u8LiteralNode.Type, u8LiteralNode.Value.ToString()), +// U16LiteralNode u16LiteralNode => CreateLiteralMessage(u16LiteralNode.Type, u16LiteralNode.Value.ToString()), +// U32LiteralNode u32LiteralNode => CreateLiteralMessage(u32LiteralNode.Type, u32LiteralNode.Value.ToString()), +// U64LiteralNode u64LiteralNode => CreateLiteralMessage(u64LiteralNode.Type, u64LiteralNode.Value.ToString()), +// // Expressions can have a generic fallback showing the resulting type +// ExpressionNode expressionNode => $""" +// **Expression** `{expressionNode.GetType().Name}` +// ```nub +// {expressionNode.Type} +// ``` +// """, +// BlockNode => null, +// _ => hoveredNode.GetType().Name +// }; +// } +// +// private static string CreateLiteralMessage(NubType type, string value) +// { +// return $""" +// **Literal** `{type}` +// ```nub +// {value}: {type} +// ``` +// """; +// } +// +// private static string CreateTypeNameMessage(string description, string name, NubType type) +// { +// return $""" +// **{description}** `{name}` +// ```nub +// {name}: {type} +// ``` +// """; +// } +// +// private static string CreateFuncIdentifierMessage(FuncIdentifierNode funcIdentifierNode, CompilationUnit compilationUnit) +// { +// var func = compilationUnit.ImportedFunctions +// .Where(x => x.Key.Value == funcIdentifierNode.ModuleToken.Value) +// .SelectMany(x => x.Value) +// .FirstOrDefault(x => x.NameToken.Value == funcIdentifierNode.NameToken.Value); +// +// if (func == null) +// { +// return $""" +// **Function** `{funcIdentifierNode.ModuleToken.Value}::{funcIdentifierNode.NameToken.Value}` +// ```nub +// // Declaration not found +// ``` +// """; +// } +// +// return CreateFuncPrototypeMessage(func); +// } +// +// private static string CreateFuncPrototypeMessage(FuncPrototypeNode funcPrototypeNode) +// { +// var parameterText = string.Join(", ", funcPrototypeNode.Parameters.Select(x => $"{x.NameToken.Value}: {x.Type}")); +// var externText = funcPrototypeNode.ExternSymbolToken != null ? $"extern \"{funcPrototypeNode.ExternSymbolToken.Value}\" " : ""; +// +// return $""" +// **Function** `{funcPrototypeNode.NameToken.Value}` +// ```nub +// {externText}func {funcPrototypeNode.NameToken.Value}({parameterText}): {funcPrototypeNode.ReturnType} +// ``` +// """; +// } } \ No newline at end of file diff --git a/compiler/NubLang.LSP/WorkspaceManager.cs b/compiler/NubLang.LSP/WorkspaceManager.cs index 3124513..ced47f3 100644 --- a/compiler/NubLang.LSP/WorkspaceManager.cs +++ b/compiler/NubLang.LSP/WorkspaceManager.cs @@ -7,7 +7,7 @@ namespace NubLang.LSP; public class WorkspaceManager(DiagnosticsPublisher diagnosticsPublisher) { private readonly Dictionary _syntaxTrees = new(); - private readonly Dictionary _compilationUnits = new(); + private readonly Dictionary> _compilationUnits = new(); public void Init(string rootPath) { @@ -35,14 +35,7 @@ public class WorkspaceManager(DiagnosticsPublisher diagnosticsPublisher) var result = typeChecker.Check(); diagnosticsPublisher.Publish(fsPath, typeChecker.Diagnostics); - if (result == null) - { - _compilationUnits.Remove(fsPath); - } - else - { - _compilationUnits[fsPath] = result; - } + _compilationUnits[fsPath] = result; } } @@ -66,14 +59,7 @@ public class WorkspaceManager(DiagnosticsPublisher diagnosticsPublisher) var result = typeChecker.Check(); diagnosticsPublisher.Publish(fsPath, typeChecker.Diagnostics); - if (result == null) - { - _compilationUnits.Remove(fsPath); - } - else - { - _compilationUnits[fsPath] = result; - } + _compilationUnits[fsPath] = result; } public void RemoveFile(DocumentUri path) @@ -83,7 +69,7 @@ public class WorkspaceManager(DiagnosticsPublisher diagnosticsPublisher) _compilationUnits.Remove(fsPath); } - public CompilationUnit? GetCompilationUnit(DocumentUri path) + public List? GetCompilationUnit(DocumentUri path) { return _compilationUnits.GetValueOrDefault(path.GetFileSystemPath()); } diff --git a/compiler/NubLang/Ast/CompilationUnit.cs b/compiler/NubLang/Ast/CompilationUnit.cs index 0dd548d..4cf2489 100644 --- a/compiler/NubLang/Ast/CompilationUnit.cs +++ b/compiler/NubLang/Ast/CompilationUnit.cs @@ -2,10 +2,11 @@ using NubLang.Syntax; namespace NubLang.Ast; -public sealed class CompilationUnit(IdentifierToken module, List functions, Dictionary> importedStructTypes, Dictionary> importedFunctions) -{ - public IdentifierToken Module { get; } = module; - public List Functions { get; } = functions; - public Dictionary> ImportedStructTypes { get; } = importedStructTypes; - public Dictionary> ImportedFunctions { get; } = importedFunctions; -} \ No newline at end of file +// public sealed class CompilationUnit(IdentifierToken module, List functions, List structTypes, Dictionary> importedStructTypes, Dictionary> importedFunctions) +// { +// public IdentifierToken Module { get; } = module; +// public List Functions { get; } = functions; +// public List Structs { get; } = structTypes; +// public Dictionary> ImportedStructTypes { get; } = importedStructTypes; +// public Dictionary> ImportedFunctions { get; } = importedFunctions; +// } \ No newline at end of file diff --git a/compiler/NubLang/Ast/Node.cs b/compiler/NubLang/Ast/Node.cs index df42b4a..0d18e84 100644 --- a/compiler/NubLang/Ast/Node.cs +++ b/compiler/NubLang/Ast/Node.cs @@ -29,9 +29,31 @@ public abstract class Node(List tokens) } } +public abstract class TopLevelNode(List tokens) : Node(tokens); + +public class ImportNode(List tokens, IdentifierToken nameToken) : TopLevelNode(tokens) +{ + public IdentifierToken NameToken { get; } = nameToken; + + public override IEnumerable Children() + { + return []; + } +} + +public class ModuleNode(List tokens, IdentifierToken nameToken) : TopLevelNode(tokens) +{ + public IdentifierToken NameToken { get; } = nameToken; + + public override IEnumerable Children() + { + return []; + } +} + #region Definitions -public abstract class DefinitionNode(List tokens, IdentifierToken nameToken) : Node(tokens) +public abstract class DefinitionNode(List tokens, IdentifierToken nameToken) : TopLevelNode(tokens) { public IdentifierToken NameToken { get; } = nameToken; } @@ -75,6 +97,35 @@ public class FuncNode(List tokens, FuncPrototypeNode prototype, BlockNode } } +public class StructFieldNode(List tokens, IdentifierToken nameToken, NubType type, ExpressionNode? value) : Node(tokens) +{ + public IdentifierToken NameToken { get; } = nameToken; + public NubType Type { get; } = type; + public ExpressionNode? Value { get; } = value; + + public override IEnumerable Children() + { + if (Value != null) + { + yield return Value; + } + } +} + +public class StructNode(List tokens, IdentifierToken name, NubStructType structType, List fields) : DefinitionNode(tokens, name) +{ + public NubStructType StructType { get; } = structType; + public List Fields { get; } = fields; + + public override IEnumerable Children() + { + foreach (var field in Fields) + { + yield return field; + } + } +} + #endregion #region Statements diff --git a/compiler/NubLang/Ast/TypeChecker.cs b/compiler/NubLang/Ast/TypeChecker.cs index 147dc83..9f0578d 100644 --- a/compiler/NubLang/Ast/TypeChecker.cs +++ b/compiler/NubLang/Ast/TypeChecker.cs @@ -10,8 +10,8 @@ public sealed class TypeChecker private readonly Dictionary _modules; private readonly Stack _scopes = []; - private readonly Dictionary<(string Module, string Name), NubType> _typeCache = new(); - private readonly HashSet<(string Module, string Name)> _resolvingTypes = []; + + private readonly TypeResolver _typeResolver; private Scope Scope => _scopes.Peek(); @@ -21,19 +21,18 @@ public sealed class TypeChecker { _syntaxTree = syntaxTree; _modules = modules; + _typeResolver = new TypeResolver(_modules); } - public CompilationUnit? Check() + public List Check() { _scopes.Clear(); - _typeCache.Clear(); - _resolvingTypes.Clear(); var moduleDeclarations = _syntaxTree.TopLevelSyntaxNodes.OfType().ToList(); if (moduleDeclarations.Count == 0) { Diagnostics.Add(Diagnostic.Error("Missing module declaration").WithHelp("module \"main\"").Build()); - return null; + return []; } if (moduleDeclarations.Count > 1) @@ -79,72 +78,45 @@ public sealed class TypeChecker .At(last) .Build()); - return null; + return []; } } - var functions = new List(); + var topLevelNodes = new List(); using (BeginRootScope(moduleName)) { - foreach (var funcSyntax in _syntaxTree.TopLevelSyntaxNodes.OfType()) + foreach (var topLevelSyntaxNode in _syntaxTree.TopLevelSyntaxNodes) { - try + switch (topLevelSyntaxNode) { - functions.Add(CheckFuncDefinition(funcSyntax)); - } - catch (TypeCheckerException e) - { - Diagnostics.Add(e.Diagnostic); + case EnumSyntax: + break; + case FuncSyntax funcSyntax: + topLevelNodes.Add(CheckFuncDefinition(funcSyntax)); + break; + case StructSyntax structSyntax: + topLevelNodes.Add(CheckStructDefinition(structSyntax)); + break; + case ImportSyntax importSyntax: + topLevelNodes.Add(new ImportNode(importSyntax.Tokens, importSyntax.NameToken)); + break; + case ModuleSyntax moduleSyntax: + topLevelNodes.Add(new ModuleNode(moduleSyntax.Tokens, moduleSyntax.NameToken)); + break; + default: + throw new ArgumentOutOfRangeException(nameof(topLevelSyntaxNode)); } } } - var importedStructTypes = new Dictionary>(); - var importedFunctions = new Dictionary>(); + return topLevelNodes; + } - foreach (var (name, module) in GetImportedModules()) - { - var moduleStructs = new List(); - var moduleFunctions = new List(); - - using (BeginRootScope(name)) - { - foreach (var structSyntax in module.Structs(true)) - { - try - { - var fields = structSyntax.Fields - .Select(f => new NubStructFieldType(f.NameToken.Value, ResolveType(f.Type), f.Value != null)) - .ToList(); - - moduleStructs.Add(new NubStructType(name.Value, structSyntax.NameToken.Value, fields)); - } - catch (TypeCheckerException e) - { - Diagnostics.Add(e.Diagnostic); - } - } - - importedStructTypes[name] = moduleStructs; - - foreach (var funcSyntax in module.Functions(true)) - { - try - { - moduleFunctions.Add(CheckFuncPrototype(funcSyntax.Prototype)); - } - catch (TypeCheckerException e) - { - Diagnostics.Add(e.Diagnostic); - } - } - - importedFunctions[name] = moduleFunctions; - } - } - - return new CompilationUnit(moduleName, functions, importedStructTypes, importedFunctions); + private (IdentifierToken Name, Module Module) GetCurrentModule() + { + var currentModule = _syntaxTree.TopLevelSyntaxNodes.OfType().First().NameToken; + return (currentModule, _modules[currentModule.Value]); } private List<(IdentifierToken Name, Module Module)> GetImportedModules() @@ -225,19 +197,48 @@ public sealed class TypeChecker } } + private StructNode CheckStructDefinition(StructSyntax structSyntax) + { + var fields = new List(); + + foreach (var field in structSyntax.Fields) + { + var fieldType = _typeResolver.ResolveType(field.Type, Scope.Module.Value); + ExpressionNode? value = null; + if (field.Value != null) + { + value = CheckExpression(field.Value, fieldType); + if (value.Type != fieldType) + { + throw new CompileException(Diagnostic + .Error($"Type {value.Type} is not assignable to {field.Type} for field {field.NameToken.Value}") + .At(field) + .Build()); + } + } + + fields.Add(new StructFieldNode(field.Tokens, field.NameToken, fieldType, value)); + } + + var currentModule = GetCurrentModule(); + var type = new NubStructType(currentModule.Name.Value, structSyntax.NameToken.Value, fields.Select(x => new NubStructFieldType(x.NameToken.Value, x.Type, x.Value != null)).ToList()); + + return new StructNode(structSyntax.Tokens, structSyntax.NameToken, type, fields); + } + private AssignmentNode CheckAssignment(AssignmentSyntax statement) { var target = CheckExpression(statement.Target); if (target is not LValueExpressionNode lValue) { - throw new TypeCheckerException(Diagnostic.Error("Cannot assign to an rvalue").At(statement).Build()); + throw new CompileException(Diagnostic.Error("Cannot assign to an rvalue").At(statement).Build()); } var value = CheckExpression(statement.Value, lValue.Type); if (value.Type != lValue.Type) { - throw new TypeCheckerException(Diagnostic + throw new CompileException(Diagnostic .Error($"Cannot assign {value.Type} to {lValue.Type}") .At(statement.Value) .Build()); @@ -279,7 +280,7 @@ public sealed class TypeChecker return expression switch { FuncCallNode funcCall => new StatementFuncCallNode(statement.Tokens, funcCall), - _ => throw new TypeCheckerException(Diagnostic.Error("Expressions statements can only be function calls").At(statement).Build()) + _ => throw new CompileException(Diagnostic.Error("Expressions statements can only be function calls").At(statement).Build()) }; } @@ -290,7 +291,7 @@ public sealed class TypeChecker if (statement.ExplicitType != null) { - type = ResolveType(statement.ExplicitType); + type = _typeResolver.ResolveType(statement.ExplicitType, Scope.Module.Value); } if (statement.Assignment != null) @@ -303,7 +304,7 @@ public sealed class TypeChecker } else if (assignmentNode.Type != type) { - throw new TypeCheckerException(Diagnostic + throw new CompileException(Diagnostic .Error($"Cannot assign {assignmentNode.Type} to variable of type {type}") .At(statement.Assignment) .Build()); @@ -312,7 +313,7 @@ public sealed class TypeChecker if (type == null) { - throw new TypeCheckerException(Diagnostic + throw new CompileException(Diagnostic .Error($"Cannot infer type of variable {statement.NameToken.Value}") .At(statement) .Build()); @@ -367,7 +368,7 @@ public sealed class TypeChecker } default: { - throw new TypeCheckerException(Diagnostic + throw new CompileException(Diagnostic .Error($"Cannot iterate over type {target.Type} which does not have size information") .At(forSyntax.Target) .Build()); @@ -380,10 +381,10 @@ public sealed class TypeChecker var parameters = new List(); foreach (var parameter in statement.Parameters) { - parameters.Add(new FuncParameterNode(parameter.Tokens, parameter.NameToken, ResolveType(parameter.Type))); + parameters.Add(new FuncParameterNode(parameter.Tokens, parameter.NameToken, _typeResolver.ResolveType(parameter.Type, Scope.Module.Value))); } - return new FuncPrototypeNode(statement.Tokens, statement.NameToken, statement.ExternSymbolToken, parameters, ResolveType(statement.ReturnType)); + return new FuncPrototypeNode(statement.Tokens, statement.NameToken, statement.ExternSymbolToken, parameters, _typeResolver.ResolveType(statement.ReturnType, Scope.Module.Value)); } private ExpressionNode CheckExpression(ExpressionSyntax node, NubType? expectedType = null) @@ -405,7 +406,7 @@ public sealed class TypeChecker FloatLiteralSyntax expression => CheckFloatLiteral(expression, expectedType), MemberAccessSyntax expression => CheckMemberAccess(expression, expectedType), StructInitializerSyntax expression => CheckStructInitializer(expression, expectedType), - SizeSyntax expression => new SizeNode(node.Tokens, ResolveType(expression.Type)), + SizeSyntax expression => new SizeNode(node.Tokens, _typeResolver.ResolveType(expression.Type, Scope.Module.Value)), CastSyntax expression => CheckCast(expression, expectedType), _ => throw new ArgumentOutOfRangeException(nameof(node)) }; @@ -430,7 +431,7 @@ public sealed class TypeChecker { if (expectedType == null) { - throw new TypeCheckerException(Diagnostic + throw new CompileException(Diagnostic .Error("Unable to infer target type of cast") .At(expression) .WithHelp("Specify target type where value is used") @@ -451,7 +452,7 @@ public sealed class TypeChecker if (!IsCastAllowed(value.Type, expectedType, false)) { - throw new TypeCheckerException(Diagnostic + throw new CompileException(Diagnostic .Error($"Cannot cast from {value.Type} to {expectedType}") .Build()); } @@ -500,7 +501,7 @@ public sealed class TypeChecker var target = CheckExpression(expression.Target, (expectedType as NubPointerType)?.BaseType); if (target is not LValueExpressionNode lvalue) { - throw new TypeCheckerException(Diagnostic.Error("Cannot take address of an rvalue").At(expression).Build()); + throw new CompileException(Diagnostic.Error("Cannot take address of an rvalue").At(expression).Build()); } var type = new NubPointerType(target.Type); @@ -512,7 +513,7 @@ public sealed class TypeChecker var index = CheckExpression(expression.Index); if (index.Type is not NubIntType) { - throw new TypeCheckerException(Diagnostic + throw new CompileException(Diagnostic .Error("Array indexer must be of type int") .At(expression.Index) .Build()); @@ -525,7 +526,7 @@ public sealed class TypeChecker NubArrayType arrayType => new ArrayIndexAccessNode(expression.Tokens, arrayType.ElementType, target, index), NubConstArrayType constArrayType => new ConstArrayIndexAccessNode(expression.Tokens, constArrayType.ElementType, target, index), NubSliceType sliceType => new SliceIndexAccessNode(expression.Tokens, sliceType.ElementType, target, index), - _ => throw new TypeCheckerException(Diagnostic.Error($"Cannot use array indexer on type {target.Type}").At(expression).Build()) + _ => throw new CompileException(Diagnostic.Error($"Cannot use array indexer on type {target.Type}").At(expression).Build()) }; } @@ -550,7 +551,7 @@ public sealed class TypeChecker if (elementType == null) { - throw new TypeCheckerException(Diagnostic + throw new CompileException(Diagnostic .Error("Unable to infer type of array initializer") .At(expression) .WithHelp("Provide a type for a variable assignment") @@ -563,7 +564,7 @@ public sealed class TypeChecker var value = CheckExpression(valueExpression, elementType); if (value.Type != elementType) { - throw new TypeCheckerException(Diagnostic + throw new CompileException(Diagnostic .Error("Value in array initializer is not the same as the array type") .At(valueExpression) .Build()); @@ -613,7 +614,7 @@ public sealed class TypeChecker var left = CheckExpression(expression.Left); if (left.Type is not NubIntType and not NubFloatType and not NubBoolType) { - throw new TypeCheckerException(Diagnostic + throw new CompileException(Diagnostic .Error("Equal and not equal operators must must be used with int, float or bool types") .At(expression.Left) .Build()); @@ -622,7 +623,7 @@ public sealed class TypeChecker var right = CheckExpression(expression.Right, left.Type); if (right.Type != left.Type) { - throw new TypeCheckerException(Diagnostic + throw new CompileException(Diagnostic .Error($"Expected type {left.Type} from left side of binary expression, but got {right.Type}") .At(expression.Right) .Build()); @@ -638,7 +639,7 @@ public sealed class TypeChecker var left = CheckExpression(expression.Left); if (left.Type is not NubIntType and not NubFloatType) { - throw new TypeCheckerException(Diagnostic + throw new CompileException(Diagnostic .Error("Greater than and less than operators must must be used with int or float types") .At(expression.Left) .Build()); @@ -647,7 +648,7 @@ public sealed class TypeChecker var right = CheckExpression(expression.Right, left.Type); if (right.Type != left.Type) { - throw new TypeCheckerException(Diagnostic + throw new CompileException(Diagnostic .Error($"Expected type {left.Type} from left side of binary expression, but got {right.Type}") .At(expression.Right) .Build()); @@ -661,7 +662,7 @@ public sealed class TypeChecker var left = CheckExpression(expression.Left); if (left.Type is not NubBoolType) { - throw new TypeCheckerException(Diagnostic + throw new CompileException(Diagnostic .Error("Logical and/or must must be used with bool types") .At(expression.Left) .Build()); @@ -670,7 +671,7 @@ public sealed class TypeChecker var right = CheckExpression(expression.Right, left.Type); if (right.Type != left.Type) { - throw new TypeCheckerException(Diagnostic + throw new CompileException(Diagnostic .Error($"Expected type {left.Type} from left side of binary expression, but got {right.Type}") .At(expression.Right) .Build()); @@ -683,7 +684,7 @@ public sealed class TypeChecker var left = CheckExpression(expression.Left, expectedType); if (left.Type is not NubIntType and not NubFloatType) { - throw new TypeCheckerException(Diagnostic + throw new CompileException(Diagnostic .Error("The plus operator must only be used with int and float types") .At(expression.Left) .Build()); @@ -692,7 +693,7 @@ public sealed class TypeChecker var right = CheckExpression(expression.Right, left.Type); if (right.Type != left.Type) { - throw new TypeCheckerException(Diagnostic + throw new CompileException(Diagnostic .Error($"Expected type {left.Type} from left side of binary expression, but got {right.Type}") .At(expression.Right) .Build()); @@ -708,7 +709,7 @@ public sealed class TypeChecker var left = CheckExpression(expression.Left, expectedType); if (left.Type is not NubIntType and not NubFloatType) { - throw new TypeCheckerException(Diagnostic + throw new CompileException(Diagnostic .Error("Math operators must be used with int or float types") .At(expression.Left) .Build()); @@ -717,7 +718,7 @@ public sealed class TypeChecker var right = CheckExpression(expression.Right, left.Type); if (right.Type != left.Type) { - throw new TypeCheckerException(Diagnostic + throw new CompileException(Diagnostic .Error($"Expected type {left.Type} from left side of binary expression, but got {right.Type}") .At(expression.Right) .Build()); @@ -734,7 +735,7 @@ public sealed class TypeChecker var left = CheckExpression(expression.Left, expectedType); if (left.Type is not NubIntType) { - throw new TypeCheckerException(Diagnostic + throw new CompileException(Diagnostic .Error("Bitwise operators must be used with int types") .At(expression.Left) .Build()); @@ -743,7 +744,7 @@ public sealed class TypeChecker var right = CheckExpression(expression.Right, left.Type); if (right.Type != left.Type) { - throw new TypeCheckerException(Diagnostic + throw new CompileException(Diagnostic .Error($"Expected type {left.Type} from left side of binary expression, but got {right.Type}") .At(expression.Right) .Build()); @@ -767,7 +768,7 @@ public sealed class TypeChecker var operand = CheckExpression(expression.Operand, expectedType); if (operand.Type is not NubIntType { Signed: true } and not NubFloatType) { - throw new TypeCheckerException(Diagnostic + throw new CompileException(Diagnostic .Error("Negation operator must be used with signed integer or float types") .At(expression) .Build()); @@ -780,7 +781,7 @@ public sealed class TypeChecker var operand = CheckExpression(expression.Operand, expectedType); if (operand.Type is not NubBoolType) { - throw new TypeCheckerException(Diagnostic + throw new CompileException(Diagnostic .Error("Invert operator must be used with booleans") .At(expression) .Build()); @@ -803,7 +804,7 @@ public sealed class TypeChecker { NubPointerType pointerType => new DereferenceNode(expression.Tokens, pointerType.BaseType, target), NubRefType refType => new RefDereferenceNode(expression.Tokens, refType.BaseType, target), - _ => throw new TypeCheckerException(Diagnostic.Error($"Cannot dereference non-pointer type {target.Type}").At(expression).Build()) + _ => throw new CompileException(Diagnostic.Error($"Cannot dereference non-pointer type {target.Type}").At(expression).Build()) }; } @@ -812,12 +813,12 @@ public sealed class TypeChecker var accessor = CheckExpression(expression.Expression); if (accessor.Type is not NubFuncType funcType) { - throw new TypeCheckerException(Diagnostic.Error($"Cannot call non-function type {accessor.Type}").At(expression.Expression).Build()); + throw new CompileException(Diagnostic.Error($"Cannot call non-function type {accessor.Type}").At(expression.Expression).Build()); } if (expression.Parameters.Count != funcType.Parameters.Count) { - throw new TypeCheckerException(Diagnostic + throw new CompileException(Diagnostic .Error($"Function {funcType} expects {funcType.Parameters.Count} parameters but got {expression.Parameters.Count}") .At(expression.Parameters.LastOrDefault(expression)) .Build()); @@ -832,7 +833,7 @@ public sealed class TypeChecker var parameterExpression = CheckExpression(parameter, expectedParameterType); if (parameterExpression.Type != expectedParameterType) { - throw new TypeCheckerException(Diagnostic + throw new CompileException(Diagnostic .Error($"Parameter {i + 1} does not match the type {expectedParameterType} for function {funcType}") .At(parameter) .Build()); @@ -858,8 +859,8 @@ public sealed class TypeChecker var function = module.Functions(true).FirstOrDefault(x => x.NameToken.Value == expression.NameToken.Value); if (function != null) { - var parameters = function.Prototype.Parameters.Select(x => ResolveType(x.Type)).ToList(); - var type = new NubFuncType(parameters, ResolveType(function.Prototype.ReturnType)); + var parameters = function.Prototype.Parameters.Select(x => _typeResolver.ResolveType(x.Type, Scope.Module.Value)).ToList(); + var type = new NubFuncType(parameters, _typeResolver.ResolveType(function.Prototype.ReturnType, Scope.Module.Value)); return new FuncIdentifierNode(expression.Tokens, type, Scope.Module, expression.NameToken, function.Prototype.ExternSymbolToken); } @@ -869,7 +870,7 @@ public sealed class TypeChecker return new EnumReferenceIntermediateNode(expression.Tokens, Scope.Module, expression.NameToken); } - throw new TypeCheckerException(Diagnostic + throw new CompileException(Diagnostic .Error($"There is no identifier named {expression.NameToken.Value}") .At(expression) .Build()); @@ -880,7 +881,7 @@ public sealed class TypeChecker var module = GetImportedModule(expression.ModuleToken.Value); if (module == null) { - throw new TypeCheckerException(Diagnostic + throw new CompileException(Diagnostic .Error($"Module {expression.ModuleToken.Value} not found") .WithHelp($"import \"{expression.ModuleToken.Value}\"") .At(expression.ModuleToken) @@ -892,8 +893,8 @@ public sealed class TypeChecker { using (BeginRootScope(expression.ModuleToken)) { - var parameters = function.Prototype.Parameters.Select(x => ResolveType(x.Type)).ToList(); - var type = new NubFuncType(parameters, ResolveType(function.Prototype.ReturnType)); + var parameters = function.Prototype.Parameters.Select(x => _typeResolver.ResolveType(x.Type, Scope.Module.Value)).ToList(); + var type = new NubFuncType(parameters, _typeResolver.ResolveType(function.Prototype.ReturnType, Scope.Module.Value)); return new FuncIdentifierNode(expression.Tokens, type, expression.ModuleToken, expression.NameToken, function.Prototype.ExternSymbolToken); } } @@ -904,7 +905,7 @@ public sealed class TypeChecker return new EnumReferenceIntermediateNode(expression.Tokens, expression.ModuleToken, expression.NameToken); } - throw new TypeCheckerException(Diagnostic + throw new CompileException(Diagnostic .Error($"Module {expression.ModuleToken.Value} does not export a member named {expression.NameToken.Value}") .At(expression) .Build()); @@ -982,16 +983,16 @@ public sealed class TypeChecker var field = enumDef.Fields.FirstOrDefault(x => x.NameToken.Value == expression.MemberToken.Value); if (field == null) { - throw new TypeCheckerException(Diagnostic + throw new CompileException(Diagnostic .Error($"Enum {Scope.Module.Value}::{enumReferenceIntermediate.NameToken.Value} does not have a field named {expression.MemberToken.Value}") .At(enumDef) .Build()); } - var enumType = enumDef.Type != null ? ResolveType(enumDef.Type) : new NubIntType(false, 64); + var enumType = enumDef.Type != null ? _typeResolver.ResolveType(enumDef.Type, Scope.Module.Value) : new NubIntType(false, 64); if (enumType is not NubIntType enumIntType) { - throw new TypeCheckerException(Diagnostic.Error("Enum type must be an int type").At(enumDef.Type).Build()); + throw new CompileException(Diagnostic.Error("Enum type must be an int type").At(enumDef.Type).Build()); } if (enumIntType.Signed) @@ -1027,7 +1028,7 @@ public sealed class TypeChecker var field = structType.Fields.FirstOrDefault(x => x.Name == expression.MemberToken.Value); if (field == null) { - throw new TypeCheckerException(Diagnostic + throw new CompileException(Diagnostic .Error($"Struct {target.Type} does not have a field with the name {expression.MemberToken.Value}") .At(expression) .Build()); @@ -1037,7 +1038,7 @@ public sealed class TypeChecker } default: { - throw new TypeCheckerException(Diagnostic + throw new CompileException(Diagnostic .Error($"Cannot access struct member {expression.MemberToken.Value} on type {target.Type}") .At(expression) .Build()); @@ -1095,7 +1096,7 @@ public sealed class TypeChecker if (expression.StructType != null) { - var checkedType = ResolveType(expression.StructType); + var checkedType = _typeResolver.ResolveType(expression.StructType, Scope.Module.Value); if (checkedType is not NubStructType checkedStructType) { throw new UnreachableException("Parser fucked up"); @@ -1115,7 +1116,7 @@ public sealed class TypeChecker if (structType == null) { - throw new TypeCheckerException(Diagnostic + throw new CompileException(Diagnostic .Error("Cannot get implicit type of struct") .WithHelp("Specify struct type with struct {type_name} syntax") .At(expression) @@ -1174,7 +1175,7 @@ public sealed class TypeChecker { statements.Add(CheckStatement(statement)); } - catch (TypeCheckerException e) + catch (CompileException e) { Diagnostics.Add(e.Diagnostic); } @@ -1202,85 +1203,6 @@ public sealed class TypeChecker _ => throw new ArgumentOutOfRangeException(nameof(statement)) }; } - - private NubType ResolveType(TypeSyntax type) - { - return type switch - { - ArrayTypeSyntax arr => new NubArrayType(ResolveType(arr.BaseType)), - BoolTypeSyntax => new NubBoolType(), - IntTypeSyntax i => new NubIntType(i.Signed, i.Width), - FloatTypeSyntax f => new NubFloatType(f.Width), - FuncTypeSyntax func => new NubFuncType(func.Parameters.Select(ResolveType).ToList(), ResolveType(func.ReturnType)), - SliceTypeSyntax slice => new NubSliceType(ResolveType(slice.BaseType)), - ConstArrayTypeSyntax arr => new NubConstArrayType(ResolveType(arr.BaseType), arr.Size), - PointerTypeSyntax ptr => new NubPointerType(ResolveType(ptr.BaseType)), - RefTypeSyntax r => new NubRefType(ResolveType(r.BaseType)), - StringTypeSyntax => new NubStringType(), - CustomTypeSyntax c => ResolveCustomType(c), - VoidTypeSyntax => new NubVoidType(), - _ => throw new NotSupportedException($"Unknown type syntax: {type}") - }; - } - - private NubType ResolveCustomType(CustomTypeSyntax customType) - { - var module = GetImportedModule(customType.ModuleToken?.Value ?? Scope.Module.Value); - if (module == null) - { - throw new TypeCheckerException(Diagnostic - .Error($"Module {customType.ModuleToken?.Value ?? Scope.Module.Value} not found") - .WithHelp($"import \"{customType.ModuleToken?.Value ?? Scope.Module.Value}\"") - .At(customType) - .Build()); - } - - var enumDef = module.Enums(IsCurrentModule(customType.ModuleToken)).FirstOrDefault(x => x.NameToken.Value == customType.NameToken.Value); - if (enumDef != null) - { - return enumDef.Type != null ? ResolveType(enumDef.Type) : new NubIntType(false, 64); - } - - var structDef = module.Structs(IsCurrentModule(customType.ModuleToken)).FirstOrDefault(x => x.NameToken.Value == customType.NameToken.Value); - if (structDef != null) - { - var key = (customType.ModuleToken?.Value ?? Scope.Module.Value, customType.NameToken.Value); - - if (_typeCache.TryGetValue(key, out var cachedType)) - { - return cachedType; - } - - if (!_resolvingTypes.Add(key)) - { - var placeholder = new NubStructType(customType.ModuleToken?.Value ?? Scope.Module.Value, customType.NameToken.Value, []); - _typeCache[key] = placeholder; - return placeholder; - } - - try - { - var result = new NubStructType(customType.ModuleToken?.Value ?? Scope.Module.Value, structDef.NameToken.Value, []); - _typeCache[key] = result; - - var fields = structDef.Fields - .Select(x => new NubStructFieldType(x.NameToken.Value, ResolveType(x.Type), x.Value != null)) - .ToList(); - - result.Fields.AddRange(fields); - return result; - } - finally - { - _resolvingTypes.Remove(key); - } - } - - throw new TypeCheckerException(Diagnostic - .Error($"Type {customType.NameToken.Value} not found in module {customType.ModuleToken?.Value ?? Scope.Module.Value}") - .At(customType) - .Build()); - } } public record Variable(IdentifierToken Name, NubType Type); @@ -1321,14 +1243,4 @@ public class Scope(IdentifierToken module, Scope? parent = null) { return new Scope(Module, this); } -} - -public class TypeCheckerException : Exception -{ - public Diagnostic Diagnostic { get; } - - public TypeCheckerException(Diagnostic diagnostic) : base(diagnostic.Message) - { - Diagnostic = diagnostic; - } } \ No newline at end of file diff --git a/compiler/NubLang/Ast/TypeResolver.cs b/compiler/NubLang/Ast/TypeResolver.cs new file mode 100644 index 0000000..effeaf1 --- /dev/null +++ b/compiler/NubLang/Ast/TypeResolver.cs @@ -0,0 +1,97 @@ +using NubLang.Diagnostics; +using NubLang.Syntax; + +namespace NubLang.Ast; + +public class TypeResolver +{ + private readonly Dictionary _modules; + private readonly Dictionary<(string Module, string Name), NubType> _typeCache = new(); + private readonly HashSet<(string Module, string Name)> _resolvingTypes = []; + + public TypeResolver(Dictionary modules) + { + _modules = modules; + } + + public NubType ResolveType(TypeSyntax type, string currentModule) + { + return type switch + { + ArrayTypeSyntax arr => new NubArrayType(ResolveType(arr.BaseType, currentModule)), + BoolTypeSyntax => new NubBoolType(), + IntTypeSyntax i => new NubIntType(i.Signed, i.Width), + FloatTypeSyntax f => new NubFloatType(f.Width), + FuncTypeSyntax func => new NubFuncType(func.Parameters.Select(x => ResolveType(x, currentModule)).ToList(), ResolveType(func.ReturnType, currentModule)), + SliceTypeSyntax slice => new NubSliceType(ResolveType(slice.BaseType, currentModule)), + ConstArrayTypeSyntax arr => new NubConstArrayType(ResolveType(arr.BaseType, currentModule), arr.Size), + PointerTypeSyntax ptr => new NubPointerType(ResolveType(ptr.BaseType, currentModule)), + RefTypeSyntax r => new NubRefType(ResolveType(r.BaseType, currentModule)), + StringTypeSyntax => new NubStringType(), + CustomTypeSyntax c => ResolveCustomType(c, currentModule), + VoidTypeSyntax => new NubVoidType(), + _ => throw new NotSupportedException($"Unknown type syntax: {type}") + }; + } + + private NubType ResolveCustomType(CustomTypeSyntax customType, string currentModule) + { + var module = _modules[customType.ModuleToken?.Value ?? currentModule]; + + var enumDef = module.Enums(true).FirstOrDefault(x => x.NameToken.Value == customType.NameToken.Value); + if (enumDef != null) + { + return enumDef.Type != null ? ResolveType(enumDef.Type, currentModule) : new NubIntType(false, 64); + } + + var structDef = module.Structs(true).FirstOrDefault(x => x.NameToken.Value == customType.NameToken.Value); + if (structDef != null) + { + var key = (customType.ModuleToken?.Value ?? currentModule, customType.NameToken.Value); + + if (_typeCache.TryGetValue(key, out var cachedType)) + { + return cachedType; + } + + if (!_resolvingTypes.Add(key)) + { + var placeholder = new NubStructType(customType.ModuleToken?.Value ?? currentModule, customType.NameToken.Value, []); + _typeCache[key] = placeholder; + return placeholder; + } + + try + { + var result = new NubStructType(customType.ModuleToken?.Value ?? currentModule, structDef.NameToken.Value, []); + _typeCache[key] = result; + + var fields = structDef.Fields + .Select(x => new NubStructFieldType(x.NameToken.Value, ResolveType(x.Type, currentModule), x.Value != null)) + .ToList(); + + result.Fields.AddRange(fields); + return result; + } + finally + { + _resolvingTypes.Remove(key); + } + } + + throw new TypeResolverException(Diagnostic + .Error($"Type {customType.NameToken.Value} not found in module {customType.ModuleToken?.Value ?? currentModule}") + .At(customType) + .Build()); + } +} + +public class TypeResolverException : Exception +{ + public Diagnostic Diagnostic { get; } + + public TypeResolverException(Diagnostic diagnostic) : base(diagnostic.Message) + { + Diagnostic = diagnostic; + } +} \ No newline at end of file diff --git a/compiler/NubLang/Diagnostics/CompileException.cs b/compiler/NubLang/Diagnostics/CompileException.cs new file mode 100644 index 0000000..4e77451 --- /dev/null +++ b/compiler/NubLang/Diagnostics/CompileException.cs @@ -0,0 +1,11 @@ +namespace NubLang.Diagnostics; + +public class CompileException : Exception +{ + public Diagnostic Diagnostic { get; } + + public CompileException(Diagnostic diagnostic) : base(diagnostic.Message) + { + Diagnostic = diagnostic; + } +} \ No newline at end of file diff --git a/compiler/NubLang/Generation/CType.cs b/compiler/NubLang/Generation/CType.cs index 5da4c67..daad988 100644 --- a/compiler/NubLang/Generation/CType.cs +++ b/compiler/NubLang/Generation/CType.cs @@ -14,12 +14,12 @@ public static class CType NubFloatType f => CreateFloatType(f, variableName), NubPointerType p => CreatePointerType(p, variableName), NubRefType r => CreateRefType(r, variableName), - NubSliceType => "struct nub_slice" + (variableName != null ? $" {variableName}" : ""), - NubStringType => "struct nub_string" + (variableName != null ? $" {variableName}" : ""), + NubSliceType => "nub_slice" + (variableName != null ? $" {variableName}" : ""), + NubStringType => "nub_string" + (variableName != null ? $" {variableName}" : ""), NubConstArrayType a => CreateConstArrayType(a, variableName, constArraysAsPointers), NubArrayType a => CreateArrayType(a, variableName), NubFuncType f => CreateFuncType(f, variableName), - NubStructType s => $"struct {s.Module}_{s.Name}_{NameMangler.Mangle(s)}" + (variableName != null ? $" {variableName}" : ""), + NubStructType s => $"{s.Module}_{s.Name}_{NameMangler.Mangle(s)}" + (variableName != null ? $" {variableName}" : ""), _ => throw new NotSupportedException($"C type generation not supported for: {type}") }; } diff --git a/compiler/NubLang/Generation/Generator.cs b/compiler/NubLang/Generation/Generator.cs index 4c5aa97..acc3f45 100644 --- a/compiler/NubLang/Generation/Generator.cs +++ b/compiler/NubLang/Generation/Generator.cs @@ -7,14 +7,14 @@ namespace NubLang.Generation; public class Generator { - private readonly CompilationUnit _compilationUnit; + private readonly List _compilationUnit; private readonly IndentedTextWriter _writer; private readonly Stack _scopes = []; private int _tmpIndex; private Scope Scope => _scopes.Peek(); - public Generator(CompilationUnit compilationUnit) + public Generator(List compilationUnit) { _compilationUnit = compilationUnit; _writer = new IndentedTextWriter(); @@ -31,66 +31,51 @@ public class Generator return externSymbol ?? $"{module}_{name}"; } + private string GetModuleName() + { + return _compilationUnit.OfType().First().NameToken.Value; + } + public string Emit() { - _writer.WriteLine(""" - #include - - void *rc_alloc(size_t size, void (*destructor)(void *self)); - void rc_retain(void *obj); - void rc_release(void *obj); - - struct nub_string - { - unsigned long long length; - char *data; - }; - - struct nub_slice - { - unsigned long long length; - void *data; - }; - - """); - - foreach (var (_, structTypes) in _compilationUnit.ImportedStructTypes) + foreach (var structType in _compilationUnit.OfType()) { - foreach (var structType in structTypes) + _writer.WriteLine($"void {CType.Create(structType.StructType)}_create({CType.Create(structType.StructType)} *self)"); + _writer.WriteLine("{"); + using (_writer.Indent()) { - _writer.WriteLine(CType.Create(structType)); - _writer.WriteLine("{"); - using (_writer.Indent()) + foreach (var field in structType.Fields) { - foreach (var field in structType.Fields) + if (field.Value != null) { - _writer.WriteLine($"{CType.Create(field.Type, field.Name, constArraysAsPointers: false)};"); + var value = EmitExpression(field.Value); + _writer.WriteLine($"self->{field.NameToken.Value} = {value}"); } } - - _writer.WriteLine("};"); - _writer.WriteLine(); } - } - // note(nub31): Forward declarations - foreach (var (module, prototypes) in _compilationUnit.ImportedFunctions) - { - foreach (var prototype in prototypes) + _writer.WriteLine("}"); + _writer.WriteLine(); + + _writer.WriteLine($"void {CType.Create(structType.StructType)}_destroy({CType.Create(structType.StructType)} *self)"); + _writer.WriteLine("{"); + using (_writer.Indent()) { - EmitLine(prototype.Tokens.FirstOrDefault()); - var parameters = prototype.Parameters.Count != 0 - ? string.Join(", ", prototype.Parameters.Select(x => CType.Create(x.Type, x.NameToken.Value))) - : "void"; - - var name = FuncName(module.Value, prototype.NameToken.Value, prototype.ExternSymbolToken?.Value); - _writer.WriteLine($"{CType.Create(prototype.ReturnType, name)}({parameters});"); - _writer.WriteLine(); + foreach (var field in structType.Fields) + { + if (field.Type is NubRefType) + { + _writer.WriteLine($"rc_release(self->{field.NameToken.Value});"); + } + } } + + _writer.WriteLine("}"); + _writer.WriteLine(); } // note(nub31): Normal functions - foreach (var funcNode in _compilationUnit.Functions) + foreach (var funcNode in _compilationUnit.OfType()) { if (funcNode.Body == null) continue; @@ -99,7 +84,7 @@ public class Generator ? string.Join(", ", funcNode.Prototype.Parameters.Select(x => CType.Create(x.Type, x.NameToken.Value))) : "void"; - var name = FuncName(_compilationUnit.Module.Value, funcNode.NameToken.Value, funcNode.Prototype.ExternSymbolToken?.Value); + var name = FuncName(GetModuleName(), funcNode.NameToken.Value, funcNode.Prototype.ExternSymbolToken?.Value); _writer.WriteLine($"{CType.Create(funcNode.Prototype.ReturnType, name)}({parameters})"); _writer.WriteLine("{"); using (_writer.Indent()) @@ -314,8 +299,7 @@ public class Generator private void EmitStatementFuncCall(StatementFuncCallNode statementFuncCallNode) { - var funcCall = EmitFuncCall(statementFuncCallNode.FuncCall); - _writer.WriteLine($"{funcCall};"); + EmitFuncCall(statementFuncCallNode.FuncCall); } private void EmitVariableDeclaration(VariableDeclarationNode variableDeclarationNode) @@ -526,7 +510,16 @@ public class Generator parameterNames.Add(result); } - return $"{name}({string.Join(", ", parameterNames)})"; + var tmp = NewTmp(); + + _writer.WriteLine($"{CType.Create(funcCallNode.Type)} {tmp} = {name}({string.Join(", ", parameterNames)});"); + + if (funcCallNode.Type is NubRefType) + { + Scope.Defer(() => _writer.WriteLine($"rc_release({tmp});")); + } + + return tmp; } private string EmitAddressOf(AddressOfNode addressOfNode) @@ -547,22 +540,18 @@ public class Generator var structType = (NubStructType)type.BaseType; var tmp = NewTmp(); - _writer.WriteLine($"{CType.Create(type)} {tmp} = ({CType.Create(type)})rc_alloc(sizeof({CType.Create(structType)}), NULL);"); + _writer.WriteLine($"{CType.Create(type)} {tmp} = ({CType.Create(type)})rc_alloc(sizeof({CType.Create(structType)}), (void (*)(void *)){CType.Create(structType)}_destroy);"); Scope.Defer(() => _writer.WriteLine($"rc_release({tmp});")); - var initValues = new List(); + _writer.WriteLine($"*{tmp} = ({CType.Create(structType)}){{{0}}};"); + _writer.WriteLine($"{CType.Create(structType)}_create({tmp});"); + foreach (var initializer in refStructInitializerNode.Initializers) { var value = EmitExpression(initializer.Value); - initValues.Add($".{initializer.Key.Value} = {value}"); + _writer.WriteLine($"{tmp}->{initializer.Key} = {value};"); } - var initString = initValues.Count == 0 - ? "0" - : string.Join(", ", initValues); - - _writer.WriteLine($"*{tmp} = ({CType.Create(structType)}){{{initString}}};"); - return tmp; } @@ -578,7 +567,7 @@ public class Generator private string EmitStringLiteral(StringLiteralNode stringLiteralNode) { var length = Encoding.UTF8.GetByteCount(stringLiteralNode.Value); - return $"(nub_string){{.length = {length}, .data = \"{stringLiteralNode.Value}\"}}"; + return $"({CType.Create(stringLiteralNode.Type)}){{.length = {length}, .data = \"{stringLiteralNode.Value}\"}}"; } private string EmitStructFieldAccess(StructFieldAccessNode structFieldAccessNode) @@ -589,18 +578,19 @@ public class Generator private string EmitStructInitializer(StructInitializerNode structInitializerNode) { - var initValues = new List(); + var structType = (NubStructType)structInitializerNode.Type; + + var tmp = NewTmp(); + _writer.WriteLine($"{CType.Create(structType)} {tmp} = ({CType.Create(structType)}){{0}};"); + _writer.WriteLine($"{CType.Create(structType)}_create(&{tmp});"); + foreach (var initializer in structInitializerNode.Initializers) { var value = EmitExpression(initializer.Value); - initValues.Add($".{initializer.Key.Value} = {value}"); + _writer.WriteLine($"{tmp}.{initializer.Key} = {value};"); } - var initString = initValues.Count == 0 - ? "0" - : string.Join(", ", initValues); - - return $"({CType.Create(structInitializerNode.Type)}){{{initString}}}"; + return tmp; } private string EmitI8Literal(I8LiteralNode i8LiteralNode) diff --git a/compiler/NubLang/Generation/HeaderGenerator.cs b/compiler/NubLang/Generation/HeaderGenerator.cs new file mode 100644 index 0000000..059cb22 --- /dev/null +++ b/compiler/NubLang/Generation/HeaderGenerator.cs @@ -0,0 +1,49 @@ +using NubLang.Ast; +using NubLang.Syntax; + +namespace NubLang.Generation; + +public static class HeaderGenerator +{ + private static string FuncName(string module, string name, string? externSymbol) + { + return externSymbol ?? $"{module}_{name}"; + } + + public static string Generate(string name, TypedModule module) + { + var writer = new IndentedTextWriter(); + + writer.WriteLine(); + + foreach (var structType in module.StructTypes) + { + writer.WriteLine("typedef struct"); + writer.WriteLine("{"); + using (writer.Indent()) + { + foreach (var field in structType.Fields) + { + writer.WriteLine($"{CType.Create(field.Type)} {field.Name};"); + } + } + + writer.WriteLine($"}} {CType.Create(structType)};"); + writer.WriteLine($"void {CType.Create(structType)}_create({CType.Create(structType)} *self);"); + writer.WriteLine($"void {CType.Create(structType)}_destroy({CType.Create(structType)} *self);"); + writer.WriteLine(); + } + + foreach (var prototype in module.FunctionPrototypes) + { + var parameters = prototype.Parameters.Count != 0 + ? string.Join(", ", prototype.Parameters.Select(x => CType.Create(x.Type, x.NameToken.Value))) + : "void"; + + var funcName = FuncName(name, prototype.NameToken.Value, prototype.ExternSymbolToken?.Value); + writer.WriteLine($"{CType.Create(prototype.ReturnType, funcName)}({parameters});"); + } + + return writer.ToString(); + } +} \ No newline at end of file diff --git a/compiler/NubLang/Syntax/Parser.cs b/compiler/NubLang/Syntax/Parser.cs index cf05deb..253ce07 100644 --- a/compiler/NubLang/Syntax/Parser.cs +++ b/compiler/NubLang/Syntax/Parser.cs @@ -45,7 +45,7 @@ public sealed class Parser Symbol.Func => ParseFunc(startIndex, exported, null), Symbol.Struct => ParseStruct(startIndex, exported), Symbol.Enum => ParseEnum(startIndex, exported), - _ => throw new ParseException(Diagnostic + _ => throw new CompileException(Diagnostic .Error($"Expected 'func', 'struct', 'enum', 'import' or 'module' but found '{keyword.Symbol}'") .WithHelp("Valid top level statements are 'func', 'struct', 'enum', 'import' and 'module'") .At(keyword) @@ -54,7 +54,7 @@ public sealed class Parser topLevelSyntaxNodes.Add(definition); } - catch (ParseException e) + catch (CompileException e) { Diagnostics.Add(e.Diagnostic); while (HasToken) @@ -180,7 +180,7 @@ public sealed class Parser { if (!TryExpectIntLiteral(out var intLiteralToken)) { - throw new ParseException(Diagnostic + throw new CompileException(Diagnostic .Error("Value of enum field must be an integer literal") .At(CurrentToken) .Build()); @@ -451,13 +451,13 @@ public sealed class Parser Symbol.OpenBrace => new StructInitializerSyntax(GetTokens(startIndex), null, ParseStructInitializerBody()), Symbol.Struct => ParseStructInitializer(startIndex), Symbol.At => ParseBuiltinFunction(startIndex), - _ => throw new ParseException(Diagnostic + _ => throw new CompileException(Diagnostic .Error($"Unexpected symbol '{symbolToken.Symbol}' in expression") .WithHelp("Expected '(', '-', '!', '[' or '{'") .At(symbolToken) .Build()) }, - _ => throw new ParseException(Diagnostic + _ => throw new CompileException(Diagnostic .Error($"Unexpected token '{token.GetType().Name}' in expression") .WithHelp("Expected literal, identifier, or parenthesized expression") .At(token) @@ -488,7 +488,7 @@ public sealed class Parser } default: { - throw new ParseException(Diagnostic.Error($"Unknown builtin {name.Value}").At(name).Build()); + throw new CompileException(Diagnostic.Error($"Unknown builtin {name.Value}").At(name).Build()); } } } @@ -628,7 +628,7 @@ public sealed class Parser { statements.Add(ParseStatement()); } - catch (ParseException ex) + catch (CompileException ex) { Diagnostics.Add(ex.Diagnostic); if (HasToken) @@ -654,7 +654,7 @@ public sealed class Parser { if (size is not 8 and not 16 and not 32 and not 64) { - throw new ParseException(Diagnostic + throw new CompileException(Diagnostic .Error("Arbitrary uint size is not supported") .WithHelp("Use u8, u16, u32 or u64") .At(name) @@ -668,7 +668,7 @@ public sealed class Parser { if (size is not 8 and not 16 and not 32 and not 64) { - throw new ParseException(Diagnostic + throw new CompileException(Diagnostic .Error("Arbitrary int size is not supported") .WithHelp("Use i8, i16, i32 or i64") .At(name) @@ -682,7 +682,7 @@ public sealed class Parser { if (size is not 32 and not 64) { - throw new ParseException(Diagnostic + throw new CompileException(Diagnostic .Error("Arbitrary float size is not supported") .WithHelp("Use f32 or f64") .At(name) @@ -772,7 +772,7 @@ public sealed class Parser } } - throw new ParseException(Diagnostic + throw new CompileException(Diagnostic .Error("Invalid type syntax") .WithHelp("Expected type name, '^' for pointer, or '[]' for array") .At(CurrentToken) @@ -783,7 +783,7 @@ public sealed class Parser { if (!HasToken) { - throw new ParseException(Diagnostic + throw new CompileException(Diagnostic .Error("Unexpected end of file") .WithHelp("Expected more tokens to complete the syntax") .At(_tokens[^1]) @@ -800,7 +800,7 @@ public sealed class Parser var token = ExpectToken(); if (token is not SymbolToken symbol) { - throw new ParseException(Diagnostic + throw new CompileException(Diagnostic .Error($"Expected symbol, but found {token.GetType().Name}") .WithHelp("This position requires a symbol like '(', ')', '{', '}', etc.") .At(token) @@ -815,7 +815,7 @@ public sealed class Parser var token = ExpectSymbol(); if (token.Symbol != expectedSymbol) { - throw new ParseException(Diagnostic + throw new CompileException(Diagnostic .Error($"Expected '{expectedSymbol}', but found '{token.Symbol}'") .WithHelp($"Insert '{expectedSymbol}' here") .At(token) @@ -865,7 +865,7 @@ public sealed class Parser var token = ExpectToken(); if (token is not IdentifierToken identifier) { - throw new ParseException(Diagnostic + throw new CompileException(Diagnostic .Error($"Expected identifier, but found {token.GetType().Name}") .WithHelp("Provide a valid identifier name here") .At(token) @@ -893,7 +893,7 @@ public sealed class Parser var token = ExpectToken(); if (token is not StringLiteralToken identifier) { - throw new ParseException(Diagnostic + throw new CompileException(Diagnostic .Error($"Expected string literal, but found {token.GetType().Name}") .WithHelp("Provide a valid string literal") .At(token) @@ -914,14 +914,4 @@ public sealed class Parser } } -public record SyntaxTree(List TopLevelSyntaxNodes); - -public class ParseException : Exception -{ - public Diagnostic Diagnostic { get; } - - public ParseException(Diagnostic diagnostic) : base(diagnostic.Message) - { - Diagnostic = diagnostic; - } -} \ No newline at end of file +public record SyntaxTree(List TopLevelSyntaxNodes); \ No newline at end of file diff --git a/compiler/NubLang/Syntax/Tokenizer.cs b/compiler/NubLang/Syntax/Tokenizer.cs index 3102eb9..1734c91 100644 --- a/compiler/NubLang/Syntax/Tokenizer.cs +++ b/compiler/NubLang/Syntax/Tokenizer.cs @@ -58,7 +58,7 @@ public sealed class Tokenizer Tokens.Add(ParseToken(current, _line, _column)); } - catch (TokenizerException e) + catch (CompileException e) { Diagnostics.Add(e.Diagnostic); Next(); @@ -95,7 +95,7 @@ public sealed class Tokenizer return ParseIdentifier(lineStart, columnStart); } - throw new TokenizerException(Diagnostic.Error($"Unknown token '{current}'").Build()); + throw new CompileException(Diagnostic.Error($"Unknown token '{current}'").Build()); } private Token ParseNumber(int lineStart, int columnStart) @@ -116,7 +116,7 @@ public sealed class Tokenizer if (_index == digitStart) { - throw new TokenizerException(Diagnostic + throw new CompileException(Diagnostic .Error("Invalid hex literal, no digits found") .At(_fileName, _line, _column) .Build()); @@ -141,7 +141,7 @@ public sealed class Tokenizer if (_index == digitStart) { - throw new TokenizerException(Diagnostic + throw new CompileException(Diagnostic .Error("Invalid binary literal, no digits found") .At(_fileName, _line, _column) .Build()); @@ -163,7 +163,7 @@ public sealed class Tokenizer { if (isFloat) { - throw new TokenizerException(Diagnostic + throw new CompileException(Diagnostic .Error("More than one period found in float literal") .At(_fileName, _line, _column) .Build()); @@ -198,7 +198,7 @@ public sealed class Tokenizer { if (_index >= _content.Length) { - throw new TokenizerException(Diagnostic + throw new CompileException(Diagnostic .Error("Unclosed string literal") .At(_fileName, _line, _column) .Build()); @@ -208,7 +208,7 @@ public sealed class Tokenizer if (next == '\n') { - throw new TokenizerException(Diagnostic + throw new CompileException(Diagnostic .Error("Unclosed string literal (newline found)") .At(_fileName, _line, _column) .Build()); @@ -375,14 +375,4 @@ public sealed class Tokenizer _index += count; _column += count; } -} - -public class TokenizerException : Exception -{ - public Diagnostic Diagnostic { get; } - - public TokenizerException(Diagnostic diagnostic) : base(diagnostic.Message) - { - Diagnostic = diagnostic; - } } \ No newline at end of file diff --git a/compiler/NubLang/Syntax/TypedModule.cs b/compiler/NubLang/Syntax/TypedModule.cs new file mode 100644 index 0000000..ba336e0 --- /dev/null +++ b/compiler/NubLang/Syntax/TypedModule.cs @@ -0,0 +1,50 @@ +using NubLang.Ast; + +namespace NubLang.Syntax; + +public sealed class TypedModule +{ + public static TypedModule FromModule(string name, Module module, Dictionary modules) + { + var typeResolver = new TypeResolver(modules); + + var functionPrototypes = new List(); + + foreach (var funcSyntax in module.Functions(true)) + { + var parameters = new List(); + foreach (var parameter in funcSyntax.Prototype.Parameters) + { + parameters.Add(new FuncParameterNode(parameter.Tokens, parameter.NameToken, typeResolver.ResolveType(parameter.Type, name))); + } + + var returnType = typeResolver.ResolveType(funcSyntax.Prototype.ReturnType, name); + + functionPrototypes.Add(new FuncPrototypeNode(funcSyntax.Tokens, funcSyntax.Prototype.NameToken, funcSyntax.Prototype.ExternSymbolToken, parameters, returnType)); + } + + var structTypes = new List(); + + foreach (var structSyntax in module.Structs(true)) + { + var fields = new List(); + foreach (var field in structSyntax.Fields) + { + fields.Add(new NubStructFieldType(field.NameToken.Value, typeResolver.ResolveType(field.Type, name), field.Value != null)); + } + + structTypes.Add(new NubStructType(name, structSyntax.NameToken.Value, fields)); + } + + return new TypedModule(functionPrototypes, structTypes); + } + + public TypedModule(List functionPrototypes, List structTypes) + { + FunctionPrototypes = functionPrototypes; + StructTypes = structTypes; + } + + public List FunctionPrototypes { get; set; } + public List StructTypes { get; set; } +} \ No newline at end of file diff --git a/examples/playgroud/main.nub b/examples/playgroud/main.nub index 44658a3..d3c4d54 100644 --- a/examples/playgroud/main.nub +++ b/examples/playgroud/main.nub @@ -2,17 +2,34 @@ module main extern "puts" func puts(text: ^i8) +struct Name +{ + first: ^i8 + last: ^i8 +} + struct Human { age: u64 - name: ^i8 + name: &Name } extern "main" func main(argc: i64, argv: [?]^i8): i64 { let x: &Human = { age = 23 - name = "test" + name = { + first = "oliver" + last = "stene" + } + } + + let z: Human = { + age = 23 + name = { + first = "oliver" + last = "stene" + } } test(x) @@ -22,6 +39,7 @@ extern "main" func main(argc: i64, argv: [?]^i8): i64 return 0 } -func test(x: &Human) +func test(x: &Human): &Human { + return x } \ No newline at end of file