Compare commits

...

2 Commits

Author SHA1 Message Date
nub31
560e6428ff ... 2025-10-26 22:28:48 +01:00
nub31
27bc4da4fd ... 2025-10-26 20:04:57 +01:00
19 changed files with 673 additions and 490 deletions

View File

@@ -21,7 +21,7 @@ foreach (var file in args)
} }
var modules = Module.Collect(syntaxTrees); var modules = Module.Collect(syntaxTrees);
var compilationUnits = new List<CompilationUnit?>(); var compilationUnits = new List<List<TopLevelNode>>();
for (var i = 0; i < args.Length; i++) for (var i = 0; i < args.Length; i++)
{ {
@@ -46,16 +46,48 @@ var cPaths = new List<string>();
Directory.CreateDirectory(".build"); Directory.CreateDirectory(".build");
var typedModules = modules.Select(x => (x.Key, TypedModule.FromModule(x.Key, x.Value, modules))).ToDictionary();
var moduleHeaders = new List<string>();
var commonHeaderOut = Path.Combine(".build", "runtime.h");
File.WriteAllText(commonHeaderOut, """
#include <stddef.h>
void *rc_alloc(size_t size, void (*destructor)(void *self));
void rc_retain(void *obj);
void rc_release(void *obj);
typedef struct
{
unsigned long long length;
char *data;
} nub_string;
typedef struct
{
unsigned long long length;
void *data;
} nub_slice;
""");
moduleHeaders.Add(commonHeaderOut);
foreach (var typedModule in typedModules)
{
var header = HeaderGenerator.Generate(typedModule.Key, typedModule.Value);
var headerOut = Path.Combine(".build", "modules", typedModule.Key + ".h");
Directory.CreateDirectory(Path.Combine(".build", "modules"));
File.WriteAllText(headerOut, header);
moduleHeaders.Add(headerOut);
}
for (var i = 0; i < args.Length; i++) for (var i = 0; i < args.Length; i++)
{ {
var file = args[i]; var file = args[i];
var compilationUnit = compilationUnits[i]; var compilationUnit = compilationUnits[i];
if (compilationUnit == null)
{
continue;
}
var generator = new Generator(compilationUnit); var generator = new Generator(compilationUnit);
var directory = Path.GetDirectoryName(file); var directory = Path.GetDirectoryName(file);
if (!string.IsNullOrWhiteSpace(directory)) if (!string.IsNullOrWhiteSpace(directory))
@@ -74,6 +106,7 @@ foreach (var cPath in cPaths)
{ {
var objectPath = Path.ChangeExtension(cPath, "o"); var objectPath = Path.ChangeExtension(cPath, "o");
using var compileProcess = Process.Start("clang", [ using var compileProcess = Process.Start("clang", [
..moduleHeaders.SelectMany(x => new[] { "-include", x }),
"-ffreestanding", "-std=c23", "-ffreestanding", "-std=c23",
"-g", "-c", "-g", "-c",
"-o", objectPath, "-o", objectPath,

View File

@@ -1,5 +1,4 @@
using NubLang.Ast; using NubLang.Ast;
using NubLang.Syntax;
using OmniSharp.Extensions.LanguageServer.Protocol.Models; using OmniSharp.Extensions.LanguageServer.Protocol.Models;
using Range = OmniSharp.Extensions.LanguageServer.Protocol.Models.Range; using Range = OmniSharp.Extensions.LanguageServer.Protocol.Models.Range;
@@ -58,16 +57,16 @@ public static class AstExtensions
return false; return false;
} }
public static FuncNode? FunctionAtPosition(this CompilationUnit compilationUnit, int line, int character) public static FuncNode? FunctionAtPosition(this List<TopLevelNode> compilationUnit, int line, int character)
{ {
return compilationUnit return compilationUnit
.Functions .OfType<FuncNode>()
.FirstOrDefault(x => x.ContainsPosition(line, character)); .FirstOrDefault(x => x.ContainsPosition(line, character));
} }
public static Node? DeepestNodeAtPosition(this CompilationUnit compilationUnit, int line, int character) public static Node? DeepestNodeAtPosition(this List<TopLevelNode> compilationUnit, int line, int character)
{ {
return compilationUnit.Functions return compilationUnit
.SelectMany(x => x.DescendantsAndSelf()) .SelectMany(x => x.DescendantsAndSelf())
.Where(n => n.ContainsPosition(line, character)) .Where(n => n.ContainsPosition(line, character))
.OrderBy(n => n.Tokens.First().Span.Start.Line) .OrderBy(n => n.Tokens.First().Span.Start.Line)

View File

@@ -118,30 +118,30 @@ internal class CompletionHandler(WorkspaceManager workspaceManager) : Completion
var compilationUnit = workspaceManager.GetCompilationUnit(uri); var compilationUnit = workspaceManager.GetCompilationUnit(uri);
if (compilationUnit != null) if (compilationUnit != null)
{ {
var function = compilationUnit.Functions.FirstOrDefault(x => x.Body != null && x.Body.ContainsPosition(position.Line, position.Character)); var function = compilationUnit.OfType<FuncNode>().FirstOrDefault(x => x.Body != null && x.Body.ContainsPosition(position.Line, position.Character));
if (function != null) if (function != null)
{ {
completions.AddRange(_statementSnippets); completions.AddRange(_statementSnippets);
foreach (var (module, prototypes) in compilationUnit.ImportedFunctions) // foreach (var (module, prototypes) in compilationUnit.ImportedFunctions)
{ // {
foreach (var prototype in prototypes) // foreach (var prototype in prototypes)
{ // {
var parameterStrings = new List<string>(); // var parameterStrings = new List<string>();
foreach (var (index, parameter) in prototype.Parameters.Index()) // foreach (var (index, parameter) in prototype.Parameters.Index())
{ // {
parameterStrings.AddRange($"${{{index + 1}:{parameter.NameToken.Value}}}"); // parameterStrings.AddRange($"${{{index + 1}:{parameter.NameToken.Value}}}");
} // }
//
completions.Add(new CompletionItem // completions.Add(new CompletionItem
{ // {
Kind = CompletionItemKind.Function, // Kind = CompletionItemKind.Function,
Label = $"{module.Value}::{prototype.NameToken.Value}", // Label = $"{module.Value}::{prototype.NameToken.Value}",
InsertTextFormat = InsertTextFormat.Snippet, // InsertTextFormat = InsertTextFormat.Snippet,
InsertText = $"{module.Value}::{prototype.NameToken.Value}({string.Join(", ", parameterStrings)})", // InsertText = $"{module.Value}::{prototype.NameToken.Value}({string.Join(", ", parameterStrings)})",
}); // });
} // }
} // }
foreach (var parameter in function.Prototype.Parameters) foreach (var parameter in function.Prototype.Parameters)
{ {

View File

@@ -57,15 +57,16 @@ internal class DefinitionHandler(WorkspaceManager workspaceManager) : Definition
} }
case FuncIdentifierNode funcIdentifierNode: case FuncIdentifierNode funcIdentifierNode:
{ {
var prototype = compilationUnit.ImportedFunctions // var prototype = compilationUnit
.Where(x => x.Key.Value == funcIdentifierNode.ModuleToken.Value) // .ImportedFunctions
.SelectMany(x => x.Value) // .Where(x => x.Key.Value == funcIdentifierNode.ModuleToken.Value)
.FirstOrDefault(x => x.NameToken.Value == funcIdentifierNode.NameToken.Value); // .SelectMany(x => x.Value)
// .FirstOrDefault(x => x.NameToken.Value == funcIdentifierNode.NameToken.Value);
if (prototype != null) //
{ // if (prototype != null)
return new LocationOrLocationLinks(prototype.ToLocation()); // {
} // return new LocationOrLocationLinks(prototype.ToLocation());
// }
return null; return null;
} }

View File

@@ -39,108 +39,110 @@ internal class HoverHandler(WorkspaceManager workspaceManager) : HoverHandlerBas
return null; return null;
} }
var message = CreateMessage(hoveredNode, compilationUnit); // var message = CreateMessage(hoveredNode, compilationUnit);
if (message == null) // if (message == null)
{ // {
// return null;
// }
//
// return new Hover
// {
// Contents = new MarkedStringsOrMarkupContent(new MarkupContent
// {
// Value = message,
// Kind = MarkupKind.Markdown,
// })
// };
return null; return null;
} }
return new Hover // private static string? CreateMessage(Node hoveredNode, CompilationUnit compilationUnit)
{ // {
Contents = new MarkedStringsOrMarkupContent(new MarkupContent // return hoveredNode switch
{ // {
Value = message, // FuncNode funcNode => CreateFuncPrototypeMessage(funcNode.Prototype),
Kind = MarkupKind.Markdown, // FuncPrototypeNode funcPrototypeNode => CreateFuncPrototypeMessage(funcPrototypeNode),
}) // FuncIdentifierNode funcIdentifierNode => CreateFuncIdentifierMessage(funcIdentifierNode, compilationUnit),
}; // FuncParameterNode funcParameterNode => CreateTypeNameMessage("Function parameter", funcParameterNode.NameToken.Value, funcParameterNode.Type),
} // VariableIdentifierNode variableIdentifierNode => CreateTypeNameMessage("Variable", variableIdentifierNode.NameToken.Value, variableIdentifierNode.Type),
// VariableDeclarationNode variableDeclarationNode => CreateTypeNameMessage("Variable declaration", variableDeclarationNode.NameToken.Value, variableDeclarationNode.Type),
private static string? CreateMessage(Node hoveredNode, CompilationUnit compilationUnit) // StructFieldAccessNode structFieldAccessNode => CreateTypeNameMessage("Struct field", $"{structFieldAccessNode.Target.Type}.{structFieldAccessNode.FieldToken.Value}", structFieldAccessNode.Type),
{ // CStringLiteralNode cStringLiteralNode => CreateLiteralMessage(cStringLiteralNode.Type, '"' + cStringLiteralNode.Value + '"'),
return hoveredNode switch // StringLiteralNode stringLiteralNode => CreateLiteralMessage(stringLiteralNode.Type, '"' + stringLiteralNode.Value + '"'),
{ // BoolLiteralNode boolLiteralNode => CreateLiteralMessage(boolLiteralNode.Type, boolLiteralNode.Value.ToString()),
FuncNode funcNode => CreateFuncPrototypeMessage(funcNode.Prototype), // Float32LiteralNode float32LiteralNode => CreateLiteralMessage(float32LiteralNode.Type, float32LiteralNode.Value.ToString(CultureInfo.InvariantCulture)),
FuncPrototypeNode funcPrototypeNode => CreateFuncPrototypeMessage(funcPrototypeNode), // Float64LiteralNode float64LiteralNode => CreateLiteralMessage(float64LiteralNode.Type, float64LiteralNode.Value.ToString(CultureInfo.InvariantCulture)),
FuncIdentifierNode funcIdentifierNode => CreateFuncIdentifierMessage(funcIdentifierNode, compilationUnit), // I8LiteralNode i8LiteralNode => CreateLiteralMessage(i8LiteralNode.Type, i8LiteralNode.Value.ToString()),
FuncParameterNode funcParameterNode => CreateTypeNameMessage("Function parameter", funcParameterNode.NameToken.Value, funcParameterNode.Type), // I16LiteralNode i16LiteralNode => CreateLiteralMessage(i16LiteralNode.Type, i16LiteralNode.Value.ToString()),
VariableIdentifierNode variableIdentifierNode => CreateTypeNameMessage("Variable", variableIdentifierNode.NameToken.Value, variableIdentifierNode.Type), // I32LiteralNode i32LiteralNode => CreateLiteralMessage(i32LiteralNode.Type, i32LiteralNode.Value.ToString()),
VariableDeclarationNode variableDeclarationNode => CreateTypeNameMessage("Variable declaration", variableDeclarationNode.NameToken.Value, variableDeclarationNode.Type), // I64LiteralNode i64LiteralNode => CreateLiteralMessage(i64LiteralNode.Type, i64LiteralNode.Value.ToString()),
StructFieldAccessNode structFieldAccessNode => CreateTypeNameMessage("Struct field", $"{structFieldAccessNode.Target.Type}.{structFieldAccessNode.FieldToken.Value}", structFieldAccessNode.Type), // U8LiteralNode u8LiteralNode => CreateLiteralMessage(u8LiteralNode.Type, u8LiteralNode.Value.ToString()),
CStringLiteralNode cStringLiteralNode => CreateLiteralMessage(cStringLiteralNode.Type, '"' + cStringLiteralNode.Value + '"'), // U16LiteralNode u16LiteralNode => CreateLiteralMessage(u16LiteralNode.Type, u16LiteralNode.Value.ToString()),
StringLiteralNode stringLiteralNode => CreateLiteralMessage(stringLiteralNode.Type, '"' + stringLiteralNode.Value + '"'), // U32LiteralNode u32LiteralNode => CreateLiteralMessage(u32LiteralNode.Type, u32LiteralNode.Value.ToString()),
BoolLiteralNode boolLiteralNode => CreateLiteralMessage(boolLiteralNode.Type, boolLiteralNode.Value.ToString()), // U64LiteralNode u64LiteralNode => CreateLiteralMessage(u64LiteralNode.Type, u64LiteralNode.Value.ToString()),
Float32LiteralNode float32LiteralNode => CreateLiteralMessage(float32LiteralNode.Type, float32LiteralNode.Value.ToString(CultureInfo.InvariantCulture)), // // Expressions can have a generic fallback showing the resulting type
Float64LiteralNode float64LiteralNode => CreateLiteralMessage(float64LiteralNode.Type, float64LiteralNode.Value.ToString(CultureInfo.InvariantCulture)), // ExpressionNode expressionNode => $"""
I8LiteralNode i8LiteralNode => CreateLiteralMessage(i8LiteralNode.Type, i8LiteralNode.Value.ToString()), // **Expression** `{expressionNode.GetType().Name}`
I16LiteralNode i16LiteralNode => CreateLiteralMessage(i16LiteralNode.Type, i16LiteralNode.Value.ToString()), // ```nub
I32LiteralNode i32LiteralNode => CreateLiteralMessage(i32LiteralNode.Type, i32LiteralNode.Value.ToString()), // {expressionNode.Type}
I64LiteralNode i64LiteralNode => CreateLiteralMessage(i64LiteralNode.Type, i64LiteralNode.Value.ToString()), // ```
U8LiteralNode u8LiteralNode => CreateLiteralMessage(u8LiteralNode.Type, u8LiteralNode.Value.ToString()), // """,
U16LiteralNode u16LiteralNode => CreateLiteralMessage(u16LiteralNode.Type, u16LiteralNode.Value.ToString()), // BlockNode => null,
U32LiteralNode u32LiteralNode => CreateLiteralMessage(u32LiteralNode.Type, u32LiteralNode.Value.ToString()), // _ => hoveredNode.GetType().Name
U64LiteralNode u64LiteralNode => CreateLiteralMessage(u64LiteralNode.Type, u64LiteralNode.Value.ToString()), // };
// Expressions can have a generic fallback showing the resulting type // }
ExpressionNode expressionNode => $""" //
**Expression** `{expressionNode.GetType().Name}` // private static string CreateLiteralMessage(NubType type, string value)
```nub // {
{expressionNode.Type} // return $"""
``` // **Literal** `{type}`
""", // ```nub
BlockNode => null, // {value}: {type}
_ => hoveredNode.GetType().Name // ```
}; // """;
} // }
//
private static string CreateLiteralMessage(NubType type, string value) // private static string CreateTypeNameMessage(string description, string name, NubType type)
{ // {
return $""" // return $"""
**Literal** `{type}` // **{description}** `{name}`
```nub // ```nub
{value}: {type} // {name}: {type}
``` // ```
"""; // """;
} // }
//
private static string CreateTypeNameMessage(string description, string name, NubType type) // private static string CreateFuncIdentifierMessage(FuncIdentifierNode funcIdentifierNode, CompilationUnit compilationUnit)
{ // {
return $""" // var func = compilationUnit.ImportedFunctions
**{description}** `{name}` // .Where(x => x.Key.Value == funcIdentifierNode.ModuleToken.Value)
```nub // .SelectMany(x => x.Value)
{name}: {type} // .FirstOrDefault(x => x.NameToken.Value == funcIdentifierNode.NameToken.Value);
``` //
"""; // if (func == null)
} // {
// return $"""
private static string CreateFuncIdentifierMessage(FuncIdentifierNode funcIdentifierNode, CompilationUnit compilationUnit) // **Function** `{funcIdentifierNode.ModuleToken.Value}::{funcIdentifierNode.NameToken.Value}`
{ // ```nub
var func = compilationUnit.ImportedFunctions // // Declaration not found
.Where(x => x.Key.Value == funcIdentifierNode.ModuleToken.Value) // ```
.SelectMany(x => x.Value) // """;
.FirstOrDefault(x => x.NameToken.Value == funcIdentifierNode.NameToken.Value); // }
//
if (func == null) // return CreateFuncPrototypeMessage(func);
{ // }
return $""" //
**Function** `{funcIdentifierNode.ModuleToken.Value}::{funcIdentifierNode.NameToken.Value}` // private static string CreateFuncPrototypeMessage(FuncPrototypeNode funcPrototypeNode)
```nub // {
// Declaration not found // var parameterText = string.Join(", ", funcPrototypeNode.Parameters.Select(x => $"{x.NameToken.Value}: {x.Type}"));
``` // var externText = funcPrototypeNode.ExternSymbolToken != null ? $"extern \"{funcPrototypeNode.ExternSymbolToken.Value}\" " : "";
"""; //
} // return $"""
// **Function** `{funcPrototypeNode.NameToken.Value}`
return CreateFuncPrototypeMessage(func); // ```nub
} // {externText}func {funcPrototypeNode.NameToken.Value}({parameterText}): {funcPrototypeNode.ReturnType}
// ```
private static string CreateFuncPrototypeMessage(FuncPrototypeNode funcPrototypeNode) // """;
{ // }
var parameterText = string.Join(", ", funcPrototypeNode.Parameters.Select(x => $"{x.NameToken.Value}: {x.Type}"));
var externText = funcPrototypeNode.ExternSymbolToken != null ? $"extern \"{funcPrototypeNode.ExternSymbolToken.Value}\" " : "";
return $"""
**Function** `{funcPrototypeNode.NameToken.Value}`
```nub
{externText}func {funcPrototypeNode.NameToken.Value}({parameterText}): {funcPrototypeNode.ReturnType}
```
""";
}
} }

View File

@@ -7,7 +7,7 @@ namespace NubLang.LSP;
public class WorkspaceManager(DiagnosticsPublisher diagnosticsPublisher) public class WorkspaceManager(DiagnosticsPublisher diagnosticsPublisher)
{ {
private readonly Dictionary<string, SyntaxTree> _syntaxTrees = new(); private readonly Dictionary<string, SyntaxTree> _syntaxTrees = new();
private readonly Dictionary<string, CompilationUnit> _compilationUnits = new(); private readonly Dictionary<string, List<TopLevelNode>> _compilationUnits = new();
public void Init(string rootPath) public void Init(string rootPath)
{ {
@@ -35,16 +35,9 @@ public class WorkspaceManager(DiagnosticsPublisher diagnosticsPublisher)
var result = typeChecker.Check(); var result = typeChecker.Check();
diagnosticsPublisher.Publish(fsPath, typeChecker.Diagnostics); diagnosticsPublisher.Publish(fsPath, typeChecker.Diagnostics);
if (result == null)
{
_compilationUnits.Remove(fsPath);
}
else
{
_compilationUnits[fsPath] = result; _compilationUnits[fsPath] = result;
} }
} }
}
public void UpdateFile(DocumentUri path) public void UpdateFile(DocumentUri path)
{ {
@@ -66,15 +59,8 @@ public class WorkspaceManager(DiagnosticsPublisher diagnosticsPublisher)
var result = typeChecker.Check(); var result = typeChecker.Check();
diagnosticsPublisher.Publish(fsPath, typeChecker.Diagnostics); diagnosticsPublisher.Publish(fsPath, typeChecker.Diagnostics);
if (result == null)
{
_compilationUnits.Remove(fsPath);
}
else
{
_compilationUnits[fsPath] = result; _compilationUnits[fsPath] = result;
} }
}
public void RemoveFile(DocumentUri path) public void RemoveFile(DocumentUri path)
{ {
@@ -83,7 +69,7 @@ public class WorkspaceManager(DiagnosticsPublisher diagnosticsPublisher)
_compilationUnits.Remove(fsPath); _compilationUnits.Remove(fsPath);
} }
public CompilationUnit? GetCompilationUnit(DocumentUri path) public List<TopLevelNode>? GetCompilationUnit(DocumentUri path)
{ {
return _compilationUnits.GetValueOrDefault(path.GetFileSystemPath()); return _compilationUnits.GetValueOrDefault(path.GetFileSystemPath());
} }

View File

@@ -2,10 +2,11 @@ using NubLang.Syntax;
namespace NubLang.Ast; namespace NubLang.Ast;
public sealed class CompilationUnit(IdentifierToken module, List<FuncNode> functions, Dictionary<IdentifierToken, List<NubStructType>> importedStructTypes, Dictionary<IdentifierToken, List<FuncPrototypeNode>> importedFunctions) // public sealed class CompilationUnit(IdentifierToken module, List<FuncNode> functions, List<StructNode> structTypes, Dictionary<IdentifierToken, List<NubStructType>> importedStructTypes, Dictionary<IdentifierToken, List<FuncPrototypeNode>> importedFunctions)
{ // {
public IdentifierToken Module { get; } = module; // public IdentifierToken Module { get; } = module;
public List<FuncNode> Functions { get; } = functions; // public List<FuncNode> Functions { get; } = functions;
public Dictionary<IdentifierToken, List<NubStructType>> ImportedStructTypes { get; } = importedStructTypes; // public List<StructNode> Structs { get; } = structTypes;
public Dictionary<IdentifierToken, List<FuncPrototypeNode>> ImportedFunctions { get; } = importedFunctions; // public Dictionary<IdentifierToken, List<NubStructType>> ImportedStructTypes { get; } = importedStructTypes;
} // public Dictionary<IdentifierToken, List<FuncPrototypeNode>> ImportedFunctions { get; } = importedFunctions;
// }

View File

@@ -29,9 +29,31 @@ public abstract class Node(List<Token> tokens)
} }
} }
public abstract class TopLevelNode(List<Token> tokens) : Node(tokens);
public class ImportNode(List<Token> tokens, IdentifierToken nameToken) : TopLevelNode(tokens)
{
public IdentifierToken NameToken { get; } = nameToken;
public override IEnumerable<Node> Children()
{
return [];
}
}
public class ModuleNode(List<Token> tokens, IdentifierToken nameToken) : TopLevelNode(tokens)
{
public IdentifierToken NameToken { get; } = nameToken;
public override IEnumerable<Node> Children()
{
return [];
}
}
#region Definitions #region Definitions
public abstract class DefinitionNode(List<Token> tokens, IdentifierToken nameToken) : Node(tokens) public abstract class DefinitionNode(List<Token> tokens, IdentifierToken nameToken) : TopLevelNode(tokens)
{ {
public IdentifierToken NameToken { get; } = nameToken; public IdentifierToken NameToken { get; } = nameToken;
} }
@@ -75,6 +97,35 @@ public class FuncNode(List<Token> tokens, FuncPrototypeNode prototype, BlockNode
} }
} }
public class StructFieldNode(List<Token> tokens, IdentifierToken nameToken, NubType type, ExpressionNode? value) : Node(tokens)
{
public IdentifierToken NameToken { get; } = nameToken;
public NubType Type { get; } = type;
public ExpressionNode? Value { get; } = value;
public override IEnumerable<Node> Children()
{
if (Value != null)
{
yield return Value;
}
}
}
public class StructNode(List<Token> tokens, IdentifierToken name, NubStructType structType, List<StructFieldNode> fields) : DefinitionNode(tokens, name)
{
public NubStructType StructType { get; } = structType;
public List<StructFieldNode> Fields { get; } = fields;
public override IEnumerable<Node> Children()
{
foreach (var field in Fields)
{
yield return field;
}
}
}
#endregion #endregion
#region Statements #region Statements

View File

@@ -10,8 +10,8 @@ public sealed class TypeChecker
private readonly Dictionary<string, Module> _modules; private readonly Dictionary<string, Module> _modules;
private readonly Stack<Scope> _scopes = []; private readonly Stack<Scope> _scopes = [];
private readonly Dictionary<(string Module, string Name), NubType> _typeCache = new();
private readonly HashSet<(string Module, string Name)> _resolvingTypes = []; private readonly TypeResolver _typeResolver;
private Scope Scope => _scopes.Peek(); private Scope Scope => _scopes.Peek();
@@ -21,19 +21,18 @@ public sealed class TypeChecker
{ {
_syntaxTree = syntaxTree; _syntaxTree = syntaxTree;
_modules = modules; _modules = modules;
_typeResolver = new TypeResolver(_modules);
} }
public CompilationUnit? Check() public List<TopLevelNode> Check()
{ {
_scopes.Clear(); _scopes.Clear();
_typeCache.Clear();
_resolvingTypes.Clear();
var moduleDeclarations = _syntaxTree.TopLevelSyntaxNodes.OfType<ModuleSyntax>().ToList(); var moduleDeclarations = _syntaxTree.TopLevelSyntaxNodes.OfType<ModuleSyntax>().ToList();
if (moduleDeclarations.Count == 0) if (moduleDeclarations.Count == 0)
{ {
Diagnostics.Add(Diagnostic.Error("Missing module declaration").WithHelp("module \"main\"").Build()); Diagnostics.Add(Diagnostic.Error("Missing module declaration").WithHelp("module \"main\"").Build());
return null; return [];
} }
if (moduleDeclarations.Count > 1) if (moduleDeclarations.Count > 1)
@@ -79,72 +78,45 @@ public sealed class TypeChecker
.At(last) .At(last)
.Build()); .Build());
return null; return [];
} }
} }
var functions = new List<FuncNode>(); var topLevelNodes = new List<TopLevelNode>();
using (BeginRootScope(moduleName)) using (BeginRootScope(moduleName))
{ {
foreach (var funcSyntax in _syntaxTree.TopLevelSyntaxNodes.OfType<FuncSyntax>()) foreach (var topLevelSyntaxNode in _syntaxTree.TopLevelSyntaxNodes)
{ {
try switch (topLevelSyntaxNode)
{ {
functions.Add(CheckFuncDefinition(funcSyntax)); case EnumSyntax:
} break;
catch (TypeCheckerException e) case FuncSyntax funcSyntax:
{ topLevelNodes.Add(CheckFuncDefinition(funcSyntax));
Diagnostics.Add(e.Diagnostic); break;
case StructSyntax structSyntax:
topLevelNodes.Add(CheckStructDefinition(structSyntax));
break;
case ImportSyntax importSyntax:
topLevelNodes.Add(new ImportNode(importSyntax.Tokens, importSyntax.NameToken));
break;
case ModuleSyntax moduleSyntax:
topLevelNodes.Add(new ModuleNode(moduleSyntax.Tokens, moduleSyntax.NameToken));
break;
default:
throw new ArgumentOutOfRangeException(nameof(topLevelSyntaxNode));
} }
} }
} }
var importedStructTypes = new Dictionary<IdentifierToken, List<NubStructType>>(); return topLevelNodes;
var importedFunctions = new Dictionary<IdentifierToken, List<FuncPrototypeNode>>();
foreach (var (name, module) in GetImportedModules())
{
var moduleStructs = new List<NubStructType>();
var moduleFunctions = new List<FuncPrototypeNode>();
using (BeginRootScope(name))
{
foreach (var structSyntax in module.Structs(true))
{
try
{
var fields = structSyntax.Fields
.Select(f => new NubStructFieldType(f.NameToken.Value, ResolveType(f.Type), f.Value != null))
.ToList();
moduleStructs.Add(new NubStructType(name.Value, structSyntax.NameToken.Value, fields));
}
catch (TypeCheckerException e)
{
Diagnostics.Add(e.Diagnostic);
}
} }
importedStructTypes[name] = moduleStructs; private (IdentifierToken Name, Module Module) GetCurrentModule()
foreach (var funcSyntax in module.Functions(true))
{ {
try var currentModule = _syntaxTree.TopLevelSyntaxNodes.OfType<ModuleSyntax>().First().NameToken;
{ return (currentModule, _modules[currentModule.Value]);
moduleFunctions.Add(CheckFuncPrototype(funcSyntax.Prototype));
}
catch (TypeCheckerException e)
{
Diagnostics.Add(e.Diagnostic);
}
}
importedFunctions[name] = moduleFunctions;
}
}
return new CompilationUnit(moduleName, functions, importedStructTypes, importedFunctions);
} }
private List<(IdentifierToken Name, Module Module)> GetImportedModules() private List<(IdentifierToken Name, Module Module)> GetImportedModules()
@@ -225,19 +197,48 @@ public sealed class TypeChecker
} }
} }
private StructNode CheckStructDefinition(StructSyntax structSyntax)
{
var fields = new List<StructFieldNode>();
foreach (var field in structSyntax.Fields)
{
var fieldType = _typeResolver.ResolveType(field.Type, Scope.Module.Value);
ExpressionNode? value = null;
if (field.Value != null)
{
value = CheckExpression(field.Value, fieldType);
if (value.Type != fieldType)
{
throw new CompileException(Diagnostic
.Error($"Type {value.Type} is not assignable to {field.Type} for field {field.NameToken.Value}")
.At(field)
.Build());
}
}
fields.Add(new StructFieldNode(field.Tokens, field.NameToken, fieldType, value));
}
var currentModule = GetCurrentModule();
var type = new NubStructType(currentModule.Name.Value, structSyntax.NameToken.Value, fields.Select(x => new NubStructFieldType(x.NameToken.Value, x.Type, x.Value != null)).ToList());
return new StructNode(structSyntax.Tokens, structSyntax.NameToken, type, fields);
}
private AssignmentNode CheckAssignment(AssignmentSyntax statement) private AssignmentNode CheckAssignment(AssignmentSyntax statement)
{ {
var target = CheckExpression(statement.Target); var target = CheckExpression(statement.Target);
if (target is not LValueExpressionNode lValue) if (target is not LValueExpressionNode lValue)
{ {
throw new TypeCheckerException(Diagnostic.Error("Cannot assign to an rvalue").At(statement).Build()); throw new CompileException(Diagnostic.Error("Cannot assign to an rvalue").At(statement).Build());
} }
var value = CheckExpression(statement.Value, lValue.Type); var value = CheckExpression(statement.Value, lValue.Type);
if (value.Type != lValue.Type) if (value.Type != lValue.Type)
{ {
throw new TypeCheckerException(Diagnostic throw new CompileException(Diagnostic
.Error($"Cannot assign {value.Type} to {lValue.Type}") .Error($"Cannot assign {value.Type} to {lValue.Type}")
.At(statement.Value) .At(statement.Value)
.Build()); .Build());
@@ -279,7 +280,7 @@ public sealed class TypeChecker
return expression switch return expression switch
{ {
FuncCallNode funcCall => new StatementFuncCallNode(statement.Tokens, funcCall), FuncCallNode funcCall => new StatementFuncCallNode(statement.Tokens, funcCall),
_ => throw new TypeCheckerException(Diagnostic.Error("Expressions statements can only be function calls").At(statement).Build()) _ => throw new CompileException(Diagnostic.Error("Expressions statements can only be function calls").At(statement).Build())
}; };
} }
@@ -290,7 +291,7 @@ public sealed class TypeChecker
if (statement.ExplicitType != null) if (statement.ExplicitType != null)
{ {
type = ResolveType(statement.ExplicitType); type = _typeResolver.ResolveType(statement.ExplicitType, Scope.Module.Value);
} }
if (statement.Assignment != null) if (statement.Assignment != null)
@@ -303,7 +304,7 @@ public sealed class TypeChecker
} }
else if (assignmentNode.Type != type) else if (assignmentNode.Type != type)
{ {
throw new TypeCheckerException(Diagnostic throw new CompileException(Diagnostic
.Error($"Cannot assign {assignmentNode.Type} to variable of type {type}") .Error($"Cannot assign {assignmentNode.Type} to variable of type {type}")
.At(statement.Assignment) .At(statement.Assignment)
.Build()); .Build());
@@ -312,7 +313,7 @@ public sealed class TypeChecker
if (type == null) if (type == null)
{ {
throw new TypeCheckerException(Diagnostic throw new CompileException(Diagnostic
.Error($"Cannot infer type of variable {statement.NameToken.Value}") .Error($"Cannot infer type of variable {statement.NameToken.Value}")
.At(statement) .At(statement)
.Build()); .Build());
@@ -367,7 +368,7 @@ public sealed class TypeChecker
} }
default: default:
{ {
throw new TypeCheckerException(Diagnostic throw new CompileException(Diagnostic
.Error($"Cannot iterate over type {target.Type} which does not have size information") .Error($"Cannot iterate over type {target.Type} which does not have size information")
.At(forSyntax.Target) .At(forSyntax.Target)
.Build()); .Build());
@@ -380,10 +381,10 @@ public sealed class TypeChecker
var parameters = new List<FuncParameterNode>(); var parameters = new List<FuncParameterNode>();
foreach (var parameter in statement.Parameters) foreach (var parameter in statement.Parameters)
{ {
parameters.Add(new FuncParameterNode(parameter.Tokens, parameter.NameToken, ResolveType(parameter.Type))); parameters.Add(new FuncParameterNode(parameter.Tokens, parameter.NameToken, _typeResolver.ResolveType(parameter.Type, Scope.Module.Value)));
} }
return new FuncPrototypeNode(statement.Tokens, statement.NameToken, statement.ExternSymbolToken, parameters, ResolveType(statement.ReturnType)); return new FuncPrototypeNode(statement.Tokens, statement.NameToken, statement.ExternSymbolToken, parameters, _typeResolver.ResolveType(statement.ReturnType, Scope.Module.Value));
} }
private ExpressionNode CheckExpression(ExpressionSyntax node, NubType? expectedType = null) private ExpressionNode CheckExpression(ExpressionSyntax node, NubType? expectedType = null)
@@ -405,7 +406,7 @@ public sealed class TypeChecker
FloatLiteralSyntax expression => CheckFloatLiteral(expression, expectedType), FloatLiteralSyntax expression => CheckFloatLiteral(expression, expectedType),
MemberAccessSyntax expression => CheckMemberAccess(expression, expectedType), MemberAccessSyntax expression => CheckMemberAccess(expression, expectedType),
StructInitializerSyntax expression => CheckStructInitializer(expression, expectedType), StructInitializerSyntax expression => CheckStructInitializer(expression, expectedType),
SizeSyntax expression => new SizeNode(node.Tokens, ResolveType(expression.Type)), SizeSyntax expression => new SizeNode(node.Tokens, _typeResolver.ResolveType(expression.Type, Scope.Module.Value)),
CastSyntax expression => CheckCast(expression, expectedType), CastSyntax expression => CheckCast(expression, expectedType),
_ => throw new ArgumentOutOfRangeException(nameof(node)) _ => throw new ArgumentOutOfRangeException(nameof(node))
}; };
@@ -430,7 +431,7 @@ public sealed class TypeChecker
{ {
if (expectedType == null) if (expectedType == null)
{ {
throw new TypeCheckerException(Diagnostic throw new CompileException(Diagnostic
.Error("Unable to infer target type of cast") .Error("Unable to infer target type of cast")
.At(expression) .At(expression)
.WithHelp("Specify target type where value is used") .WithHelp("Specify target type where value is used")
@@ -451,7 +452,7 @@ public sealed class TypeChecker
if (!IsCastAllowed(value.Type, expectedType, false)) if (!IsCastAllowed(value.Type, expectedType, false))
{ {
throw new TypeCheckerException(Diagnostic throw new CompileException(Diagnostic
.Error($"Cannot cast from {value.Type} to {expectedType}") .Error($"Cannot cast from {value.Type} to {expectedType}")
.Build()); .Build());
} }
@@ -500,7 +501,7 @@ public sealed class TypeChecker
var target = CheckExpression(expression.Target, (expectedType as NubPointerType)?.BaseType); var target = CheckExpression(expression.Target, (expectedType as NubPointerType)?.BaseType);
if (target is not LValueExpressionNode lvalue) if (target is not LValueExpressionNode lvalue)
{ {
throw new TypeCheckerException(Diagnostic.Error("Cannot take address of an rvalue").At(expression).Build()); throw new CompileException(Diagnostic.Error("Cannot take address of an rvalue").At(expression).Build());
} }
var type = new NubPointerType(target.Type); var type = new NubPointerType(target.Type);
@@ -512,7 +513,7 @@ public sealed class TypeChecker
var index = CheckExpression(expression.Index); var index = CheckExpression(expression.Index);
if (index.Type is not NubIntType) if (index.Type is not NubIntType)
{ {
throw new TypeCheckerException(Diagnostic throw new CompileException(Diagnostic
.Error("Array indexer must be of type int") .Error("Array indexer must be of type int")
.At(expression.Index) .At(expression.Index)
.Build()); .Build());
@@ -525,7 +526,7 @@ public sealed class TypeChecker
NubArrayType arrayType => new ArrayIndexAccessNode(expression.Tokens, arrayType.ElementType, target, index), NubArrayType arrayType => new ArrayIndexAccessNode(expression.Tokens, arrayType.ElementType, target, index),
NubConstArrayType constArrayType => new ConstArrayIndexAccessNode(expression.Tokens, constArrayType.ElementType, target, index), NubConstArrayType constArrayType => new ConstArrayIndexAccessNode(expression.Tokens, constArrayType.ElementType, target, index),
NubSliceType sliceType => new SliceIndexAccessNode(expression.Tokens, sliceType.ElementType, target, index), NubSliceType sliceType => new SliceIndexAccessNode(expression.Tokens, sliceType.ElementType, target, index),
_ => throw new TypeCheckerException(Diagnostic.Error($"Cannot use array indexer on type {target.Type}").At(expression).Build()) _ => throw new CompileException(Diagnostic.Error($"Cannot use array indexer on type {target.Type}").At(expression).Build())
}; };
} }
@@ -550,7 +551,7 @@ public sealed class TypeChecker
if (elementType == null) if (elementType == null)
{ {
throw new TypeCheckerException(Diagnostic throw new CompileException(Diagnostic
.Error("Unable to infer type of array initializer") .Error("Unable to infer type of array initializer")
.At(expression) .At(expression)
.WithHelp("Provide a type for a variable assignment") .WithHelp("Provide a type for a variable assignment")
@@ -563,7 +564,7 @@ public sealed class TypeChecker
var value = CheckExpression(valueExpression, elementType); var value = CheckExpression(valueExpression, elementType);
if (value.Type != elementType) if (value.Type != elementType)
{ {
throw new TypeCheckerException(Diagnostic throw new CompileException(Diagnostic
.Error("Value in array initializer is not the same as the array type") .Error("Value in array initializer is not the same as the array type")
.At(valueExpression) .At(valueExpression)
.Build()); .Build());
@@ -613,7 +614,7 @@ public sealed class TypeChecker
var left = CheckExpression(expression.Left); var left = CheckExpression(expression.Left);
if (left.Type is not NubIntType and not NubFloatType and not NubBoolType) if (left.Type is not NubIntType and not NubFloatType and not NubBoolType)
{ {
throw new TypeCheckerException(Diagnostic throw new CompileException(Diagnostic
.Error("Equal and not equal operators must must be used with int, float or bool types") .Error("Equal and not equal operators must must be used with int, float or bool types")
.At(expression.Left) .At(expression.Left)
.Build()); .Build());
@@ -622,7 +623,7 @@ public sealed class TypeChecker
var right = CheckExpression(expression.Right, left.Type); var right = CheckExpression(expression.Right, left.Type);
if (right.Type != left.Type) if (right.Type != left.Type)
{ {
throw new TypeCheckerException(Diagnostic throw new CompileException(Diagnostic
.Error($"Expected type {left.Type} from left side of binary expression, but got {right.Type}") .Error($"Expected type {left.Type} from left side of binary expression, but got {right.Type}")
.At(expression.Right) .At(expression.Right)
.Build()); .Build());
@@ -638,7 +639,7 @@ public sealed class TypeChecker
var left = CheckExpression(expression.Left); var left = CheckExpression(expression.Left);
if (left.Type is not NubIntType and not NubFloatType) if (left.Type is not NubIntType and not NubFloatType)
{ {
throw new TypeCheckerException(Diagnostic throw new CompileException(Diagnostic
.Error("Greater than and less than operators must must be used with int or float types") .Error("Greater than and less than operators must must be used with int or float types")
.At(expression.Left) .At(expression.Left)
.Build()); .Build());
@@ -647,7 +648,7 @@ public sealed class TypeChecker
var right = CheckExpression(expression.Right, left.Type); var right = CheckExpression(expression.Right, left.Type);
if (right.Type != left.Type) if (right.Type != left.Type)
{ {
throw new TypeCheckerException(Diagnostic throw new CompileException(Diagnostic
.Error($"Expected type {left.Type} from left side of binary expression, but got {right.Type}") .Error($"Expected type {left.Type} from left side of binary expression, but got {right.Type}")
.At(expression.Right) .At(expression.Right)
.Build()); .Build());
@@ -661,7 +662,7 @@ public sealed class TypeChecker
var left = CheckExpression(expression.Left); var left = CheckExpression(expression.Left);
if (left.Type is not NubBoolType) if (left.Type is not NubBoolType)
{ {
throw new TypeCheckerException(Diagnostic throw new CompileException(Diagnostic
.Error("Logical and/or must must be used with bool types") .Error("Logical and/or must must be used with bool types")
.At(expression.Left) .At(expression.Left)
.Build()); .Build());
@@ -670,7 +671,7 @@ public sealed class TypeChecker
var right = CheckExpression(expression.Right, left.Type); var right = CheckExpression(expression.Right, left.Type);
if (right.Type != left.Type) if (right.Type != left.Type)
{ {
throw new TypeCheckerException(Diagnostic throw new CompileException(Diagnostic
.Error($"Expected type {left.Type} from left side of binary expression, but got {right.Type}") .Error($"Expected type {left.Type} from left side of binary expression, but got {right.Type}")
.At(expression.Right) .At(expression.Right)
.Build()); .Build());
@@ -683,7 +684,7 @@ public sealed class TypeChecker
var left = CheckExpression(expression.Left, expectedType); var left = CheckExpression(expression.Left, expectedType);
if (left.Type is not NubIntType and not NubFloatType) if (left.Type is not NubIntType and not NubFloatType)
{ {
throw new TypeCheckerException(Diagnostic throw new CompileException(Diagnostic
.Error("The plus operator must only be used with int and float types") .Error("The plus operator must only be used with int and float types")
.At(expression.Left) .At(expression.Left)
.Build()); .Build());
@@ -692,7 +693,7 @@ public sealed class TypeChecker
var right = CheckExpression(expression.Right, left.Type); var right = CheckExpression(expression.Right, left.Type);
if (right.Type != left.Type) if (right.Type != left.Type)
{ {
throw new TypeCheckerException(Diagnostic throw new CompileException(Diagnostic
.Error($"Expected type {left.Type} from left side of binary expression, but got {right.Type}") .Error($"Expected type {left.Type} from left side of binary expression, but got {right.Type}")
.At(expression.Right) .At(expression.Right)
.Build()); .Build());
@@ -708,7 +709,7 @@ public sealed class TypeChecker
var left = CheckExpression(expression.Left, expectedType); var left = CheckExpression(expression.Left, expectedType);
if (left.Type is not NubIntType and not NubFloatType) if (left.Type is not NubIntType and not NubFloatType)
{ {
throw new TypeCheckerException(Diagnostic throw new CompileException(Diagnostic
.Error("Math operators must be used with int or float types") .Error("Math operators must be used with int or float types")
.At(expression.Left) .At(expression.Left)
.Build()); .Build());
@@ -717,7 +718,7 @@ public sealed class TypeChecker
var right = CheckExpression(expression.Right, left.Type); var right = CheckExpression(expression.Right, left.Type);
if (right.Type != left.Type) if (right.Type != left.Type)
{ {
throw new TypeCheckerException(Diagnostic throw new CompileException(Diagnostic
.Error($"Expected type {left.Type} from left side of binary expression, but got {right.Type}") .Error($"Expected type {left.Type} from left side of binary expression, but got {right.Type}")
.At(expression.Right) .At(expression.Right)
.Build()); .Build());
@@ -734,7 +735,7 @@ public sealed class TypeChecker
var left = CheckExpression(expression.Left, expectedType); var left = CheckExpression(expression.Left, expectedType);
if (left.Type is not NubIntType) if (left.Type is not NubIntType)
{ {
throw new TypeCheckerException(Diagnostic throw new CompileException(Diagnostic
.Error("Bitwise operators must be used with int types") .Error("Bitwise operators must be used with int types")
.At(expression.Left) .At(expression.Left)
.Build()); .Build());
@@ -743,7 +744,7 @@ public sealed class TypeChecker
var right = CheckExpression(expression.Right, left.Type); var right = CheckExpression(expression.Right, left.Type);
if (right.Type != left.Type) if (right.Type != left.Type)
{ {
throw new TypeCheckerException(Diagnostic throw new CompileException(Diagnostic
.Error($"Expected type {left.Type} from left side of binary expression, but got {right.Type}") .Error($"Expected type {left.Type} from left side of binary expression, but got {right.Type}")
.At(expression.Right) .At(expression.Right)
.Build()); .Build());
@@ -767,7 +768,7 @@ public sealed class TypeChecker
var operand = CheckExpression(expression.Operand, expectedType); var operand = CheckExpression(expression.Operand, expectedType);
if (operand.Type is not NubIntType { Signed: true } and not NubFloatType) if (operand.Type is not NubIntType { Signed: true } and not NubFloatType)
{ {
throw new TypeCheckerException(Diagnostic throw new CompileException(Diagnostic
.Error("Negation operator must be used with signed integer or float types") .Error("Negation operator must be used with signed integer or float types")
.At(expression) .At(expression)
.Build()); .Build());
@@ -780,7 +781,7 @@ public sealed class TypeChecker
var operand = CheckExpression(expression.Operand, expectedType); var operand = CheckExpression(expression.Operand, expectedType);
if (operand.Type is not NubBoolType) if (operand.Type is not NubBoolType)
{ {
throw new TypeCheckerException(Diagnostic throw new CompileException(Diagnostic
.Error("Invert operator must be used with booleans") .Error("Invert operator must be used with booleans")
.At(expression) .At(expression)
.Build()); .Build());
@@ -803,7 +804,7 @@ public sealed class TypeChecker
{ {
NubPointerType pointerType => new DereferenceNode(expression.Tokens, pointerType.BaseType, target), NubPointerType pointerType => new DereferenceNode(expression.Tokens, pointerType.BaseType, target),
NubRefType refType => new RefDereferenceNode(expression.Tokens, refType.BaseType, target), NubRefType refType => new RefDereferenceNode(expression.Tokens, refType.BaseType, target),
_ => throw new TypeCheckerException(Diagnostic.Error($"Cannot dereference non-pointer type {target.Type}").At(expression).Build()) _ => throw new CompileException(Diagnostic.Error($"Cannot dereference non-pointer type {target.Type}").At(expression).Build())
}; };
} }
@@ -812,12 +813,12 @@ public sealed class TypeChecker
var accessor = CheckExpression(expression.Expression); var accessor = CheckExpression(expression.Expression);
if (accessor.Type is not NubFuncType funcType) if (accessor.Type is not NubFuncType funcType)
{ {
throw new TypeCheckerException(Diagnostic.Error($"Cannot call non-function type {accessor.Type}").At(expression.Expression).Build()); throw new CompileException(Diagnostic.Error($"Cannot call non-function type {accessor.Type}").At(expression.Expression).Build());
} }
if (expression.Parameters.Count != funcType.Parameters.Count) if (expression.Parameters.Count != funcType.Parameters.Count)
{ {
throw new TypeCheckerException(Diagnostic throw new CompileException(Diagnostic
.Error($"Function {funcType} expects {funcType.Parameters.Count} parameters but got {expression.Parameters.Count}") .Error($"Function {funcType} expects {funcType.Parameters.Count} parameters but got {expression.Parameters.Count}")
.At(expression.Parameters.LastOrDefault(expression)) .At(expression.Parameters.LastOrDefault(expression))
.Build()); .Build());
@@ -832,7 +833,7 @@ public sealed class TypeChecker
var parameterExpression = CheckExpression(parameter, expectedParameterType); var parameterExpression = CheckExpression(parameter, expectedParameterType);
if (parameterExpression.Type != expectedParameterType) if (parameterExpression.Type != expectedParameterType)
{ {
throw new TypeCheckerException(Diagnostic throw new CompileException(Diagnostic
.Error($"Parameter {i + 1} does not match the type {expectedParameterType} for function {funcType}") .Error($"Parameter {i + 1} does not match the type {expectedParameterType} for function {funcType}")
.At(parameter) .At(parameter)
.Build()); .Build());
@@ -858,8 +859,8 @@ public sealed class TypeChecker
var function = module.Functions(true).FirstOrDefault(x => x.NameToken.Value == expression.NameToken.Value); var function = module.Functions(true).FirstOrDefault(x => x.NameToken.Value == expression.NameToken.Value);
if (function != null) if (function != null)
{ {
var parameters = function.Prototype.Parameters.Select(x => ResolveType(x.Type)).ToList(); var parameters = function.Prototype.Parameters.Select(x => _typeResolver.ResolveType(x.Type, Scope.Module.Value)).ToList();
var type = new NubFuncType(parameters, ResolveType(function.Prototype.ReturnType)); var type = new NubFuncType(parameters, _typeResolver.ResolveType(function.Prototype.ReturnType, Scope.Module.Value));
return new FuncIdentifierNode(expression.Tokens, type, Scope.Module, expression.NameToken, function.Prototype.ExternSymbolToken); return new FuncIdentifierNode(expression.Tokens, type, Scope.Module, expression.NameToken, function.Prototype.ExternSymbolToken);
} }
@@ -869,7 +870,7 @@ public sealed class TypeChecker
return new EnumReferenceIntermediateNode(expression.Tokens, Scope.Module, expression.NameToken); return new EnumReferenceIntermediateNode(expression.Tokens, Scope.Module, expression.NameToken);
} }
throw new TypeCheckerException(Diagnostic throw new CompileException(Diagnostic
.Error($"There is no identifier named {expression.NameToken.Value}") .Error($"There is no identifier named {expression.NameToken.Value}")
.At(expression) .At(expression)
.Build()); .Build());
@@ -880,7 +881,7 @@ public sealed class TypeChecker
var module = GetImportedModule(expression.ModuleToken.Value); var module = GetImportedModule(expression.ModuleToken.Value);
if (module == null) if (module == null)
{ {
throw new TypeCheckerException(Diagnostic throw new CompileException(Diagnostic
.Error($"Module {expression.ModuleToken.Value} not found") .Error($"Module {expression.ModuleToken.Value} not found")
.WithHelp($"import \"{expression.ModuleToken.Value}\"") .WithHelp($"import \"{expression.ModuleToken.Value}\"")
.At(expression.ModuleToken) .At(expression.ModuleToken)
@@ -892,8 +893,8 @@ public sealed class TypeChecker
{ {
using (BeginRootScope(expression.ModuleToken)) using (BeginRootScope(expression.ModuleToken))
{ {
var parameters = function.Prototype.Parameters.Select(x => ResolveType(x.Type)).ToList(); var parameters = function.Prototype.Parameters.Select(x => _typeResolver.ResolveType(x.Type, Scope.Module.Value)).ToList();
var type = new NubFuncType(parameters, ResolveType(function.Prototype.ReturnType)); var type = new NubFuncType(parameters, _typeResolver.ResolveType(function.Prototype.ReturnType, Scope.Module.Value));
return new FuncIdentifierNode(expression.Tokens, type, expression.ModuleToken, expression.NameToken, function.Prototype.ExternSymbolToken); return new FuncIdentifierNode(expression.Tokens, type, expression.ModuleToken, expression.NameToken, function.Prototype.ExternSymbolToken);
} }
} }
@@ -904,7 +905,7 @@ public sealed class TypeChecker
return new EnumReferenceIntermediateNode(expression.Tokens, expression.ModuleToken, expression.NameToken); return new EnumReferenceIntermediateNode(expression.Tokens, expression.ModuleToken, expression.NameToken);
} }
throw new TypeCheckerException(Diagnostic throw new CompileException(Diagnostic
.Error($"Module {expression.ModuleToken.Value} does not export a member named {expression.NameToken.Value}") .Error($"Module {expression.ModuleToken.Value} does not export a member named {expression.NameToken.Value}")
.At(expression) .At(expression)
.Build()); .Build());
@@ -982,16 +983,16 @@ public sealed class TypeChecker
var field = enumDef.Fields.FirstOrDefault(x => x.NameToken.Value == expression.MemberToken.Value); var field = enumDef.Fields.FirstOrDefault(x => x.NameToken.Value == expression.MemberToken.Value);
if (field == null) if (field == null)
{ {
throw new TypeCheckerException(Diagnostic throw new CompileException(Diagnostic
.Error($"Enum {Scope.Module.Value}::{enumReferenceIntermediate.NameToken.Value} does not have a field named {expression.MemberToken.Value}") .Error($"Enum {Scope.Module.Value}::{enumReferenceIntermediate.NameToken.Value} does not have a field named {expression.MemberToken.Value}")
.At(enumDef) .At(enumDef)
.Build()); .Build());
} }
var enumType = enumDef.Type != null ? ResolveType(enumDef.Type) : new NubIntType(false, 64); var enumType = enumDef.Type != null ? _typeResolver.ResolveType(enumDef.Type, Scope.Module.Value) : new NubIntType(false, 64);
if (enumType is not NubIntType enumIntType) if (enumType is not NubIntType enumIntType)
{ {
throw new TypeCheckerException(Diagnostic.Error("Enum type must be an int type").At(enumDef.Type).Build()); throw new CompileException(Diagnostic.Error("Enum type must be an int type").At(enumDef.Type).Build());
} }
if (enumIntType.Signed) if (enumIntType.Signed)
@@ -1027,7 +1028,7 @@ public sealed class TypeChecker
var field = structType.Fields.FirstOrDefault(x => x.Name == expression.MemberToken.Value); var field = structType.Fields.FirstOrDefault(x => x.Name == expression.MemberToken.Value);
if (field == null) if (field == null)
{ {
throw new TypeCheckerException(Diagnostic throw new CompileException(Diagnostic
.Error($"Struct {target.Type} does not have a field with the name {expression.MemberToken.Value}") .Error($"Struct {target.Type} does not have a field with the name {expression.MemberToken.Value}")
.At(expression) .At(expression)
.Build()); .Build());
@@ -1037,7 +1038,7 @@ public sealed class TypeChecker
} }
default: default:
{ {
throw new TypeCheckerException(Diagnostic throw new CompileException(Diagnostic
.Error($"Cannot access struct member {expression.MemberToken.Value} on type {target.Type}") .Error($"Cannot access struct member {expression.MemberToken.Value} on type {target.Type}")
.At(expression) .At(expression)
.Build()); .Build());
@@ -1095,7 +1096,7 @@ public sealed class TypeChecker
if (expression.StructType != null) if (expression.StructType != null)
{ {
var checkedType = ResolveType(expression.StructType); var checkedType = _typeResolver.ResolveType(expression.StructType, Scope.Module.Value);
if (checkedType is not NubStructType checkedStructType) if (checkedType is not NubStructType checkedStructType)
{ {
throw new UnreachableException("Parser fucked up"); throw new UnreachableException("Parser fucked up");
@@ -1115,7 +1116,7 @@ public sealed class TypeChecker
if (structType == null) if (structType == null)
{ {
throw new TypeCheckerException(Diagnostic throw new CompileException(Diagnostic
.Error("Cannot get implicit type of struct") .Error("Cannot get implicit type of struct")
.WithHelp("Specify struct type with struct {type_name} syntax") .WithHelp("Specify struct type with struct {type_name} syntax")
.At(expression) .At(expression)
@@ -1174,7 +1175,7 @@ public sealed class TypeChecker
{ {
statements.Add(CheckStatement(statement)); statements.Add(CheckStatement(statement));
} }
catch (TypeCheckerException e) catch (CompileException e)
{ {
Diagnostics.Add(e.Diagnostic); Diagnostics.Add(e.Diagnostic);
} }
@@ -1202,85 +1203,6 @@ public sealed class TypeChecker
_ => throw new ArgumentOutOfRangeException(nameof(statement)) _ => throw new ArgumentOutOfRangeException(nameof(statement))
}; };
} }
private NubType ResolveType(TypeSyntax type)
{
return type switch
{
ArrayTypeSyntax arr => new NubArrayType(ResolveType(arr.BaseType)),
BoolTypeSyntax => new NubBoolType(),
IntTypeSyntax i => new NubIntType(i.Signed, i.Width),
FloatTypeSyntax f => new NubFloatType(f.Width),
FuncTypeSyntax func => new NubFuncType(func.Parameters.Select(ResolveType).ToList(), ResolveType(func.ReturnType)),
SliceTypeSyntax slice => new NubSliceType(ResolveType(slice.BaseType)),
ConstArrayTypeSyntax arr => new NubConstArrayType(ResolveType(arr.BaseType), arr.Size),
PointerTypeSyntax ptr => new NubPointerType(ResolveType(ptr.BaseType)),
RefTypeSyntax r => new NubRefType(ResolveType(r.BaseType)),
StringTypeSyntax => new NubStringType(),
CustomTypeSyntax c => ResolveCustomType(c),
VoidTypeSyntax => new NubVoidType(),
_ => throw new NotSupportedException($"Unknown type syntax: {type}")
};
}
private NubType ResolveCustomType(CustomTypeSyntax customType)
{
var module = GetImportedModule(customType.ModuleToken?.Value ?? Scope.Module.Value);
if (module == null)
{
throw new TypeCheckerException(Diagnostic
.Error($"Module {customType.ModuleToken?.Value ?? Scope.Module.Value} not found")
.WithHelp($"import \"{customType.ModuleToken?.Value ?? Scope.Module.Value}\"")
.At(customType)
.Build());
}
var enumDef = module.Enums(IsCurrentModule(customType.ModuleToken)).FirstOrDefault(x => x.NameToken.Value == customType.NameToken.Value);
if (enumDef != null)
{
return enumDef.Type != null ? ResolveType(enumDef.Type) : new NubIntType(false, 64);
}
var structDef = module.Structs(IsCurrentModule(customType.ModuleToken)).FirstOrDefault(x => x.NameToken.Value == customType.NameToken.Value);
if (structDef != null)
{
var key = (customType.ModuleToken?.Value ?? Scope.Module.Value, customType.NameToken.Value);
if (_typeCache.TryGetValue(key, out var cachedType))
{
return cachedType;
}
if (!_resolvingTypes.Add(key))
{
var placeholder = new NubStructType(customType.ModuleToken?.Value ?? Scope.Module.Value, customType.NameToken.Value, []);
_typeCache[key] = placeholder;
return placeholder;
}
try
{
var result = new NubStructType(customType.ModuleToken?.Value ?? Scope.Module.Value, structDef.NameToken.Value, []);
_typeCache[key] = result;
var fields = structDef.Fields
.Select(x => new NubStructFieldType(x.NameToken.Value, ResolveType(x.Type), x.Value != null))
.ToList();
result.Fields.AddRange(fields);
return result;
}
finally
{
_resolvingTypes.Remove(key);
}
}
throw new TypeCheckerException(Diagnostic
.Error($"Type {customType.NameToken.Value} not found in module {customType.ModuleToken?.Value ?? Scope.Module.Value}")
.At(customType)
.Build());
}
} }
public record Variable(IdentifierToken Name, NubType Type); public record Variable(IdentifierToken Name, NubType Type);
@@ -1322,13 +1244,3 @@ public class Scope(IdentifierToken module, Scope? parent = null)
return new Scope(Module, this); return new Scope(Module, this);
} }
} }
public class TypeCheckerException : Exception
{
public Diagnostic Diagnostic { get; }
public TypeCheckerException(Diagnostic diagnostic) : base(diagnostic.Message)
{
Diagnostic = diagnostic;
}
}

View File

@@ -0,0 +1,97 @@
using NubLang.Diagnostics;
using NubLang.Syntax;
namespace NubLang.Ast;
public class TypeResolver
{
private readonly Dictionary<string, Module> _modules;
private readonly Dictionary<(string Module, string Name), NubType> _typeCache = new();
private readonly HashSet<(string Module, string Name)> _resolvingTypes = [];
public TypeResolver(Dictionary<string, Module> modules)
{
_modules = modules;
}
public NubType ResolveType(TypeSyntax type, string currentModule)
{
return type switch
{
ArrayTypeSyntax arr => new NubArrayType(ResolveType(arr.BaseType, currentModule)),
BoolTypeSyntax => new NubBoolType(),
IntTypeSyntax i => new NubIntType(i.Signed, i.Width),
FloatTypeSyntax f => new NubFloatType(f.Width),
FuncTypeSyntax func => new NubFuncType(func.Parameters.Select(x => ResolveType(x, currentModule)).ToList(), ResolveType(func.ReturnType, currentModule)),
SliceTypeSyntax slice => new NubSliceType(ResolveType(slice.BaseType, currentModule)),
ConstArrayTypeSyntax arr => new NubConstArrayType(ResolveType(arr.BaseType, currentModule), arr.Size),
PointerTypeSyntax ptr => new NubPointerType(ResolveType(ptr.BaseType, currentModule)),
RefTypeSyntax r => new NubRefType(ResolveType(r.BaseType, currentModule)),
StringTypeSyntax => new NubStringType(),
CustomTypeSyntax c => ResolveCustomType(c, currentModule),
VoidTypeSyntax => new NubVoidType(),
_ => throw new NotSupportedException($"Unknown type syntax: {type}")
};
}
private NubType ResolveCustomType(CustomTypeSyntax customType, string currentModule)
{
var module = _modules[customType.ModuleToken?.Value ?? currentModule];
var enumDef = module.Enums(true).FirstOrDefault(x => x.NameToken.Value == customType.NameToken.Value);
if (enumDef != null)
{
return enumDef.Type != null ? ResolveType(enumDef.Type, currentModule) : new NubIntType(false, 64);
}
var structDef = module.Structs(true).FirstOrDefault(x => x.NameToken.Value == customType.NameToken.Value);
if (structDef != null)
{
var key = (customType.ModuleToken?.Value ?? currentModule, customType.NameToken.Value);
if (_typeCache.TryGetValue(key, out var cachedType))
{
return cachedType;
}
if (!_resolvingTypes.Add(key))
{
var placeholder = new NubStructType(customType.ModuleToken?.Value ?? currentModule, customType.NameToken.Value, []);
_typeCache[key] = placeholder;
return placeholder;
}
try
{
var result = new NubStructType(customType.ModuleToken?.Value ?? currentModule, structDef.NameToken.Value, []);
_typeCache[key] = result;
var fields = structDef.Fields
.Select(x => new NubStructFieldType(x.NameToken.Value, ResolveType(x.Type, currentModule), x.Value != null))
.ToList();
result.Fields.AddRange(fields);
return result;
}
finally
{
_resolvingTypes.Remove(key);
}
}
throw new TypeResolverException(Diagnostic
.Error($"Type {customType.NameToken.Value} not found in module {customType.ModuleToken?.Value ?? currentModule}")
.At(customType)
.Build());
}
}
public class TypeResolverException : Exception
{
public Diagnostic Diagnostic { get; }
public TypeResolverException(Diagnostic diagnostic) : base(diagnostic.Message)
{
Diagnostic = diagnostic;
}
}

View File

@@ -0,0 +1,11 @@
namespace NubLang.Diagnostics;
public class CompileException : Exception
{
public Diagnostic Diagnostic { get; }
public CompileException(Diagnostic diagnostic) : base(diagnostic.Message)
{
Diagnostic = diagnostic;
}
}

View File

@@ -14,12 +14,12 @@ public static class CType
NubFloatType f => CreateFloatType(f, variableName), NubFloatType f => CreateFloatType(f, variableName),
NubPointerType p => CreatePointerType(p, variableName), NubPointerType p => CreatePointerType(p, variableName),
NubRefType r => CreateRefType(r, variableName), NubRefType r => CreateRefType(r, variableName),
NubSliceType => "struct nub_slice" + (variableName != null ? $" {variableName}" : ""), NubSliceType => "nub_slice" + (variableName != null ? $" {variableName}" : ""),
NubStringType => "struct nub_string" + (variableName != null ? $" {variableName}" : ""), NubStringType => "nub_string" + (variableName != null ? $" {variableName}" : ""),
NubConstArrayType a => CreateConstArrayType(a, variableName, constArraysAsPointers), NubConstArrayType a => CreateConstArrayType(a, variableName, constArraysAsPointers),
NubArrayType a => CreateArrayType(a, variableName), NubArrayType a => CreateArrayType(a, variableName),
NubFuncType f => CreateFuncType(f, variableName), NubFuncType f => CreateFuncType(f, variableName),
NubStructType s => $"struct {s.Module}_{s.Name}_{NameMangler.Mangle(s)}" + (variableName != null ? $" {variableName}" : ""), NubStructType s => $"{s.Module}_{s.Name}_{NameMangler.Mangle(s)}" + (variableName != null ? $" {variableName}" : ""),
_ => throw new NotSupportedException($"C type generation not supported for: {type}") _ => throw new NotSupportedException($"C type generation not supported for: {type}")
}; };
} }

View File

@@ -7,14 +7,14 @@ namespace NubLang.Generation;
public class Generator public class Generator
{ {
private readonly CompilationUnit _compilationUnit; private readonly List<TopLevelNode> _compilationUnit;
private readonly IndentedTextWriter _writer; private readonly IndentedTextWriter _writer;
private readonly Stack<Scope> _scopes = []; private readonly Stack<Scope> _scopes = [];
private int _tmpIndex; private int _tmpIndex;
private Scope Scope => _scopes.Peek(); private Scope Scope => _scopes.Peek();
public Generator(CompilationUnit compilationUnit) public Generator(List<TopLevelNode> compilationUnit)
{ {
_compilationUnit = compilationUnit; _compilationUnit = compilationUnit;
_writer = new IndentedTextWriter(); _writer = new IndentedTextWriter();
@@ -31,66 +31,51 @@ public class Generator
return externSymbol ?? $"{module}_{name}"; return externSymbol ?? $"{module}_{name}";
} }
private string GetModuleName()
{
return _compilationUnit.OfType<ModuleNode>().First().NameToken.Value;
}
public string Emit() public string Emit()
{ {
_writer.WriteLine(""" foreach (var structType in _compilationUnit.OfType<StructNode>())
#include <stddef.h>
void *rc_alloc(size_t size, void (*destructor)(void *self));
void rc_retain(void *obj);
void rc_release(void *obj);
struct nub_string
{ {
unsigned long long length; _writer.WriteLine($"void {CType.Create(structType.StructType)}_create({CType.Create(structType.StructType)} *self)");
char *data;
};
struct nub_slice
{
unsigned long long length;
void *data;
};
""");
foreach (var (_, structTypes) in _compilationUnit.ImportedStructTypes)
{
foreach (var structType in structTypes)
{
_writer.WriteLine(CType.Create(structType));
_writer.WriteLine("{"); _writer.WriteLine("{");
using (_writer.Indent()) using (_writer.Indent())
{ {
foreach (var field in structType.Fields) foreach (var field in structType.Fields)
{ {
_writer.WriteLine($"{CType.Create(field.Type, field.Name, constArraysAsPointers: false)};"); if (field.Value != null)
}
}
_writer.WriteLine("};");
_writer.WriteLine();
}
}
// note(nub31): Forward declarations
foreach (var (module, prototypes) in _compilationUnit.ImportedFunctions)
{ {
foreach (var prototype in prototypes) var value = EmitExpression(field.Value);
{ _writer.WriteLine($"self->{field.NameToken.Value} = {value}");
EmitLine(prototype.Tokens.FirstOrDefault());
var parameters = prototype.Parameters.Count != 0
? string.Join(", ", prototype.Parameters.Select(x => CType.Create(x.Type, x.NameToken.Value)))
: "void";
var name = FuncName(module.Value, prototype.NameToken.Value, prototype.ExternSymbolToken?.Value);
_writer.WriteLine($"{CType.Create(prototype.ReturnType, name)}({parameters});");
_writer.WriteLine();
} }
} }
}
_writer.WriteLine("}");
_writer.WriteLine();
_writer.WriteLine($"void {CType.Create(structType.StructType)}_destroy({CType.Create(structType.StructType)} *self)");
_writer.WriteLine("{");
using (_writer.Indent())
{
foreach (var field in structType.Fields)
{
if (field.Type is NubRefType)
{
_writer.WriteLine($"rc_release(self->{field.NameToken.Value});");
}
}
}
_writer.WriteLine("}");
_writer.WriteLine();
}
// note(nub31): Normal functions // note(nub31): Normal functions
foreach (var funcNode in _compilationUnit.Functions) foreach (var funcNode in _compilationUnit.OfType<FuncNode>())
{ {
if (funcNode.Body == null) continue; if (funcNode.Body == null) continue;
@@ -99,7 +84,7 @@ public class Generator
? string.Join(", ", funcNode.Prototype.Parameters.Select(x => CType.Create(x.Type, x.NameToken.Value))) ? string.Join(", ", funcNode.Prototype.Parameters.Select(x => CType.Create(x.Type, x.NameToken.Value)))
: "void"; : "void";
var name = FuncName(_compilationUnit.Module.Value, funcNode.NameToken.Value, funcNode.Prototype.ExternSymbolToken?.Value); var name = FuncName(GetModuleName(), funcNode.NameToken.Value, funcNode.Prototype.ExternSymbolToken?.Value);
_writer.WriteLine($"{CType.Create(funcNode.Prototype.ReturnType, name)}({parameters})"); _writer.WriteLine($"{CType.Create(funcNode.Prototype.ReturnType, name)}({parameters})");
_writer.WriteLine("{"); _writer.WriteLine("{");
using (_writer.Indent()) using (_writer.Indent())
@@ -196,6 +181,7 @@ public class Generator
if (assignmentNode.Target.Type is NubRefType) if (assignmentNode.Target.Type is NubRefType)
{ {
_writer.WriteLine($"rc_retain({value});"); _writer.WriteLine($"rc_retain({value});");
Scope.Defer(() => _writer.WriteLine($"rc_release({value});"));
_writer.WriteLine($"rc_release({target});"); _writer.WriteLine($"rc_release({target});");
} }
@@ -313,8 +299,7 @@ public class Generator
private void EmitStatementFuncCall(StatementFuncCallNode statementFuncCallNode) private void EmitStatementFuncCall(StatementFuncCallNode statementFuncCallNode)
{ {
var funcCall = EmitFuncCall(statementFuncCallNode.FuncCall); EmitFuncCall(statementFuncCallNode.FuncCall);
_writer.WriteLine($"{funcCall};");
} }
private void EmitVariableDeclaration(VariableDeclarationNode variableDeclarationNode) private void EmitVariableDeclaration(VariableDeclarationNode variableDeclarationNode)
@@ -322,11 +307,14 @@ public class Generator
if (variableDeclarationNode.Assignment != null) if (variableDeclarationNode.Assignment != null)
{ {
var value = EmitExpression(variableDeclarationNode.Assignment); var value = EmitExpression(variableDeclarationNode.Assignment);
_writer.WriteLine($"{CType.Create(variableDeclarationNode.Type, variableDeclarationNode.NameToken.Value)} = {value};");
if (variableDeclarationNode.Type is NubRefType) if (variableDeclarationNode.Type is NubRefType)
{ {
_writer.WriteLine($"rc_retain({variableDeclarationNode.NameToken.Value});"); _writer.WriteLine($"rc_retain({value});");
Scope.Defer(() => _writer.WriteLine($"rc_release({value});"));
} }
_writer.WriteLine($"{CType.Create(variableDeclarationNode.Type, variableDeclarationNode.NameToken.Value)} = {value};");
} }
else else
{ {
@@ -522,7 +510,16 @@ public class Generator
parameterNames.Add(result); parameterNames.Add(result);
} }
return $"{name}({string.Join(", ", parameterNames)})"; var tmp = NewTmp();
_writer.WriteLine($"{CType.Create(funcCallNode.Type)} {tmp} = {name}({string.Join(", ", parameterNames)});");
if (funcCallNode.Type is NubRefType)
{
Scope.Defer(() => _writer.WriteLine($"rc_release({tmp});"));
}
return tmp;
} }
private string EmitAddressOf(AddressOfNode addressOfNode) private string EmitAddressOf(AddressOfNode addressOfNode)
@@ -543,23 +540,18 @@ public class Generator
var structType = (NubStructType)type.BaseType; var structType = (NubStructType)type.BaseType;
var tmp = NewTmp(); var tmp = NewTmp();
_writer.WriteLine($"{CType.Create(type)} {tmp} = ({CType.Create(type)})rc_alloc(sizeof({CType.Create(structType)}), NULL);"); _writer.WriteLine($"{CType.Create(type)} {tmp} = ({CType.Create(type)})rc_alloc(sizeof({CType.Create(structType)}), (void (*)(void *)){CType.Create(structType)}_destroy);");
Scope.Defer(() => _writer.WriteLine($"rc_release({tmp});"));
_writer.WriteLine($"*{tmp} = ({CType.Create(structType)}){{{0}}};");
_writer.WriteLine($"{CType.Create(structType)}_create({tmp});");
var initValues = new List<string>();
foreach (var initializer in refStructInitializerNode.Initializers) foreach (var initializer in refStructInitializerNode.Initializers)
{ {
var value = EmitExpression(initializer.Value); var value = EmitExpression(initializer.Value);
initValues.Add($".{initializer.Key.Value} = {value}"); _writer.WriteLine($"{tmp}->{initializer.Key} = {value};");
} }
var initString = initValues.Count == 0
? "0"
: string.Join(", ", initValues);
_writer.WriteLine($"*{tmp} = ({CType.Create(structType)}){{{initString}}};");
Scope.Defer(() => _writer.WriteLine($"rc_release({tmp});"));
return tmp; return tmp;
} }
@@ -575,7 +567,7 @@ public class Generator
private string EmitStringLiteral(StringLiteralNode stringLiteralNode) private string EmitStringLiteral(StringLiteralNode stringLiteralNode)
{ {
var length = Encoding.UTF8.GetByteCount(stringLiteralNode.Value); var length = Encoding.UTF8.GetByteCount(stringLiteralNode.Value);
return $"(nub_string){{.length = {length}, .data = \"{stringLiteralNode.Value}\"}}"; return $"({CType.Create(stringLiteralNode.Type)}){{.length = {length}, .data = \"{stringLiteralNode.Value}\"}}";
} }
private string EmitStructFieldAccess(StructFieldAccessNode structFieldAccessNode) private string EmitStructFieldAccess(StructFieldAccessNode structFieldAccessNode)
@@ -586,18 +578,19 @@ public class Generator
private string EmitStructInitializer(StructInitializerNode structInitializerNode) private string EmitStructInitializer(StructInitializerNode structInitializerNode)
{ {
var initValues = new List<string>(); var structType = (NubStructType)structInitializerNode.Type;
var tmp = NewTmp();
_writer.WriteLine($"{CType.Create(structType)} {tmp} = ({CType.Create(structType)}){{0}};");
_writer.WriteLine($"{CType.Create(structType)}_create(&{tmp});");
foreach (var initializer in structInitializerNode.Initializers) foreach (var initializer in structInitializerNode.Initializers)
{ {
var value = EmitExpression(initializer.Value); var value = EmitExpression(initializer.Value);
initValues.Add($".{initializer.Key.Value} = {value}"); _writer.WriteLine($"{tmp}.{initializer.Key} = {value};");
} }
var initString = initValues.Count == 0 return tmp;
? "0"
: string.Join(", ", initValues);
return $"({CType.Create(structInitializerNode.Type)}){{{initString}}}";
} }
private string EmitI8Literal(I8LiteralNode i8LiteralNode) private string EmitI8Literal(I8LiteralNode i8LiteralNode)

View File

@@ -0,0 +1,49 @@
using NubLang.Ast;
using NubLang.Syntax;
namespace NubLang.Generation;
public static class HeaderGenerator
{
private static string FuncName(string module, string name, string? externSymbol)
{
return externSymbol ?? $"{module}_{name}";
}
public static string Generate(string name, TypedModule module)
{
var writer = new IndentedTextWriter();
writer.WriteLine();
foreach (var structType in module.StructTypes)
{
writer.WriteLine("typedef struct");
writer.WriteLine("{");
using (writer.Indent())
{
foreach (var field in structType.Fields)
{
writer.WriteLine($"{CType.Create(field.Type)} {field.Name};");
}
}
writer.WriteLine($"}} {CType.Create(structType)};");
writer.WriteLine($"void {CType.Create(structType)}_create({CType.Create(structType)} *self);");
writer.WriteLine($"void {CType.Create(structType)}_destroy({CType.Create(structType)} *self);");
writer.WriteLine();
}
foreach (var prototype in module.FunctionPrototypes)
{
var parameters = prototype.Parameters.Count != 0
? string.Join(", ", prototype.Parameters.Select(x => CType.Create(x.Type, x.NameToken.Value)))
: "void";
var funcName = FuncName(name, prototype.NameToken.Value, prototype.ExternSymbolToken?.Value);
writer.WriteLine($"{CType.Create(prototype.ReturnType, funcName)}({parameters});");
}
return writer.ToString();
}
}

View File

@@ -45,7 +45,7 @@ public sealed class Parser
Symbol.Func => ParseFunc(startIndex, exported, null), Symbol.Func => ParseFunc(startIndex, exported, null),
Symbol.Struct => ParseStruct(startIndex, exported), Symbol.Struct => ParseStruct(startIndex, exported),
Symbol.Enum => ParseEnum(startIndex, exported), Symbol.Enum => ParseEnum(startIndex, exported),
_ => throw new ParseException(Diagnostic _ => throw new CompileException(Diagnostic
.Error($"Expected 'func', 'struct', 'enum', 'import' or 'module' but found '{keyword.Symbol}'") .Error($"Expected 'func', 'struct', 'enum', 'import' or 'module' but found '{keyword.Symbol}'")
.WithHelp("Valid top level statements are 'func', 'struct', 'enum', 'import' and 'module'") .WithHelp("Valid top level statements are 'func', 'struct', 'enum', 'import' and 'module'")
.At(keyword) .At(keyword)
@@ -54,7 +54,7 @@ public sealed class Parser
topLevelSyntaxNodes.Add(definition); topLevelSyntaxNodes.Add(definition);
} }
catch (ParseException e) catch (CompileException e)
{ {
Diagnostics.Add(e.Diagnostic); Diagnostics.Add(e.Diagnostic);
while (HasToken) while (HasToken)
@@ -180,7 +180,7 @@ public sealed class Parser
{ {
if (!TryExpectIntLiteral(out var intLiteralToken)) if (!TryExpectIntLiteral(out var intLiteralToken))
{ {
throw new ParseException(Diagnostic throw new CompileException(Diagnostic
.Error("Value of enum field must be an integer literal") .Error("Value of enum field must be an integer literal")
.At(CurrentToken) .At(CurrentToken)
.Build()); .Build());
@@ -451,13 +451,13 @@ public sealed class Parser
Symbol.OpenBrace => new StructInitializerSyntax(GetTokens(startIndex), null, ParseStructInitializerBody()), Symbol.OpenBrace => new StructInitializerSyntax(GetTokens(startIndex), null, ParseStructInitializerBody()),
Symbol.Struct => ParseStructInitializer(startIndex), Symbol.Struct => ParseStructInitializer(startIndex),
Symbol.At => ParseBuiltinFunction(startIndex), Symbol.At => ParseBuiltinFunction(startIndex),
_ => throw new ParseException(Diagnostic _ => throw new CompileException(Diagnostic
.Error($"Unexpected symbol '{symbolToken.Symbol}' in expression") .Error($"Unexpected symbol '{symbolToken.Symbol}' in expression")
.WithHelp("Expected '(', '-', '!', '[' or '{'") .WithHelp("Expected '(', '-', '!', '[' or '{'")
.At(symbolToken) .At(symbolToken)
.Build()) .Build())
}, },
_ => throw new ParseException(Diagnostic _ => throw new CompileException(Diagnostic
.Error($"Unexpected token '{token.GetType().Name}' in expression") .Error($"Unexpected token '{token.GetType().Name}' in expression")
.WithHelp("Expected literal, identifier, or parenthesized expression") .WithHelp("Expected literal, identifier, or parenthesized expression")
.At(token) .At(token)
@@ -488,7 +488,7 @@ public sealed class Parser
} }
default: default:
{ {
throw new ParseException(Diagnostic.Error($"Unknown builtin {name.Value}").At(name).Build()); throw new CompileException(Diagnostic.Error($"Unknown builtin {name.Value}").At(name).Build());
} }
} }
} }
@@ -628,7 +628,7 @@ public sealed class Parser
{ {
statements.Add(ParseStatement()); statements.Add(ParseStatement());
} }
catch (ParseException ex) catch (CompileException ex)
{ {
Diagnostics.Add(ex.Diagnostic); Diagnostics.Add(ex.Diagnostic);
if (HasToken) if (HasToken)
@@ -654,7 +654,7 @@ public sealed class Parser
{ {
if (size is not 8 and not 16 and not 32 and not 64) if (size is not 8 and not 16 and not 32 and not 64)
{ {
throw new ParseException(Diagnostic throw new CompileException(Diagnostic
.Error("Arbitrary uint size is not supported") .Error("Arbitrary uint size is not supported")
.WithHelp("Use u8, u16, u32 or u64") .WithHelp("Use u8, u16, u32 or u64")
.At(name) .At(name)
@@ -668,7 +668,7 @@ public sealed class Parser
{ {
if (size is not 8 and not 16 and not 32 and not 64) if (size is not 8 and not 16 and not 32 and not 64)
{ {
throw new ParseException(Diagnostic throw new CompileException(Diagnostic
.Error("Arbitrary int size is not supported") .Error("Arbitrary int size is not supported")
.WithHelp("Use i8, i16, i32 or i64") .WithHelp("Use i8, i16, i32 or i64")
.At(name) .At(name)
@@ -682,7 +682,7 @@ public sealed class Parser
{ {
if (size is not 32 and not 64) if (size is not 32 and not 64)
{ {
throw new ParseException(Diagnostic throw new CompileException(Diagnostic
.Error("Arbitrary float size is not supported") .Error("Arbitrary float size is not supported")
.WithHelp("Use f32 or f64") .WithHelp("Use f32 or f64")
.At(name) .At(name)
@@ -772,7 +772,7 @@ public sealed class Parser
} }
} }
throw new ParseException(Diagnostic throw new CompileException(Diagnostic
.Error("Invalid type syntax") .Error("Invalid type syntax")
.WithHelp("Expected type name, '^' for pointer, or '[]' for array") .WithHelp("Expected type name, '^' for pointer, or '[]' for array")
.At(CurrentToken) .At(CurrentToken)
@@ -783,7 +783,7 @@ public sealed class Parser
{ {
if (!HasToken) if (!HasToken)
{ {
throw new ParseException(Diagnostic throw new CompileException(Diagnostic
.Error("Unexpected end of file") .Error("Unexpected end of file")
.WithHelp("Expected more tokens to complete the syntax") .WithHelp("Expected more tokens to complete the syntax")
.At(_tokens[^1]) .At(_tokens[^1])
@@ -800,7 +800,7 @@ public sealed class Parser
var token = ExpectToken(); var token = ExpectToken();
if (token is not SymbolToken symbol) if (token is not SymbolToken symbol)
{ {
throw new ParseException(Diagnostic throw new CompileException(Diagnostic
.Error($"Expected symbol, but found {token.GetType().Name}") .Error($"Expected symbol, but found {token.GetType().Name}")
.WithHelp("This position requires a symbol like '(', ')', '{', '}', etc.") .WithHelp("This position requires a symbol like '(', ')', '{', '}', etc.")
.At(token) .At(token)
@@ -815,7 +815,7 @@ public sealed class Parser
var token = ExpectSymbol(); var token = ExpectSymbol();
if (token.Symbol != expectedSymbol) if (token.Symbol != expectedSymbol)
{ {
throw new ParseException(Diagnostic throw new CompileException(Diagnostic
.Error($"Expected '{expectedSymbol}', but found '{token.Symbol}'") .Error($"Expected '{expectedSymbol}', but found '{token.Symbol}'")
.WithHelp($"Insert '{expectedSymbol}' here") .WithHelp($"Insert '{expectedSymbol}' here")
.At(token) .At(token)
@@ -865,7 +865,7 @@ public sealed class Parser
var token = ExpectToken(); var token = ExpectToken();
if (token is not IdentifierToken identifier) if (token is not IdentifierToken identifier)
{ {
throw new ParseException(Diagnostic throw new CompileException(Diagnostic
.Error($"Expected identifier, but found {token.GetType().Name}") .Error($"Expected identifier, but found {token.GetType().Name}")
.WithHelp("Provide a valid identifier name here") .WithHelp("Provide a valid identifier name here")
.At(token) .At(token)
@@ -893,7 +893,7 @@ public sealed class Parser
var token = ExpectToken(); var token = ExpectToken();
if (token is not StringLiteralToken identifier) if (token is not StringLiteralToken identifier)
{ {
throw new ParseException(Diagnostic throw new CompileException(Diagnostic
.Error($"Expected string literal, but found {token.GetType().Name}") .Error($"Expected string literal, but found {token.GetType().Name}")
.WithHelp("Provide a valid string literal") .WithHelp("Provide a valid string literal")
.At(token) .At(token)
@@ -915,13 +915,3 @@ public sealed class Parser
} }
public record SyntaxTree(List<TopLevelSyntaxNode> TopLevelSyntaxNodes); public record SyntaxTree(List<TopLevelSyntaxNode> TopLevelSyntaxNodes);
public class ParseException : Exception
{
public Diagnostic Diagnostic { get; }
public ParseException(Diagnostic diagnostic) : base(diagnostic.Message)
{
Diagnostic = diagnostic;
}
}

View File

@@ -58,7 +58,7 @@ public sealed class Tokenizer
Tokens.Add(ParseToken(current, _line, _column)); Tokens.Add(ParseToken(current, _line, _column));
} }
catch (TokenizerException e) catch (CompileException e)
{ {
Diagnostics.Add(e.Diagnostic); Diagnostics.Add(e.Diagnostic);
Next(); Next();
@@ -95,7 +95,7 @@ public sealed class Tokenizer
return ParseIdentifier(lineStart, columnStart); return ParseIdentifier(lineStart, columnStart);
} }
throw new TokenizerException(Diagnostic.Error($"Unknown token '{current}'").Build()); throw new CompileException(Diagnostic.Error($"Unknown token '{current}'").Build());
} }
private Token ParseNumber(int lineStart, int columnStart) private Token ParseNumber(int lineStart, int columnStart)
@@ -116,7 +116,7 @@ public sealed class Tokenizer
if (_index == digitStart) if (_index == digitStart)
{ {
throw new TokenizerException(Diagnostic throw new CompileException(Diagnostic
.Error("Invalid hex literal, no digits found") .Error("Invalid hex literal, no digits found")
.At(_fileName, _line, _column) .At(_fileName, _line, _column)
.Build()); .Build());
@@ -141,7 +141,7 @@ public sealed class Tokenizer
if (_index == digitStart) if (_index == digitStart)
{ {
throw new TokenizerException(Diagnostic throw new CompileException(Diagnostic
.Error("Invalid binary literal, no digits found") .Error("Invalid binary literal, no digits found")
.At(_fileName, _line, _column) .At(_fileName, _line, _column)
.Build()); .Build());
@@ -163,7 +163,7 @@ public sealed class Tokenizer
{ {
if (isFloat) if (isFloat)
{ {
throw new TokenizerException(Diagnostic throw new CompileException(Diagnostic
.Error("More than one period found in float literal") .Error("More than one period found in float literal")
.At(_fileName, _line, _column) .At(_fileName, _line, _column)
.Build()); .Build());
@@ -198,7 +198,7 @@ public sealed class Tokenizer
{ {
if (_index >= _content.Length) if (_index >= _content.Length)
{ {
throw new TokenizerException(Diagnostic throw new CompileException(Diagnostic
.Error("Unclosed string literal") .Error("Unclosed string literal")
.At(_fileName, _line, _column) .At(_fileName, _line, _column)
.Build()); .Build());
@@ -208,7 +208,7 @@ public sealed class Tokenizer
if (next == '\n') if (next == '\n')
{ {
throw new TokenizerException(Diagnostic throw new CompileException(Diagnostic
.Error("Unclosed string literal (newline found)") .Error("Unclosed string literal (newline found)")
.At(_fileName, _line, _column) .At(_fileName, _line, _column)
.Build()); .Build());
@@ -376,13 +376,3 @@ public sealed class Tokenizer
_column += count; _column += count;
} }
} }
public class TokenizerException : Exception
{
public Diagnostic Diagnostic { get; }
public TokenizerException(Diagnostic diagnostic) : base(diagnostic.Message)
{
Diagnostic = diagnostic;
}
}

View File

@@ -0,0 +1,50 @@
using NubLang.Ast;
namespace NubLang.Syntax;
public sealed class TypedModule
{
public static TypedModule FromModule(string name, Module module, Dictionary<string, Module> modules)
{
var typeResolver = new TypeResolver(modules);
var functionPrototypes = new List<FuncPrototypeNode>();
foreach (var funcSyntax in module.Functions(true))
{
var parameters = new List<FuncParameterNode>();
foreach (var parameter in funcSyntax.Prototype.Parameters)
{
parameters.Add(new FuncParameterNode(parameter.Tokens, parameter.NameToken, typeResolver.ResolveType(parameter.Type, name)));
}
var returnType = typeResolver.ResolveType(funcSyntax.Prototype.ReturnType, name);
functionPrototypes.Add(new FuncPrototypeNode(funcSyntax.Tokens, funcSyntax.Prototype.NameToken, funcSyntax.Prototype.ExternSymbolToken, parameters, returnType));
}
var structTypes = new List<NubStructType>();
foreach (var structSyntax in module.Structs(true))
{
var fields = new List<NubStructFieldType>();
foreach (var field in structSyntax.Fields)
{
fields.Add(new NubStructFieldType(field.NameToken.Value, typeResolver.ResolveType(field.Type, name), field.Value != null));
}
structTypes.Add(new NubStructType(name, structSyntax.NameToken.Value, fields));
}
return new TypedModule(functionPrototypes, structTypes);
}
public TypedModule(List<FuncPrototypeNode> functionPrototypes, List<NubStructType> structTypes)
{
FunctionPrototypes = functionPrototypes;
StructTypes = structTypes;
}
public List<FuncPrototypeNode> FunctionPrototypes { get; set; }
public List<NubStructType> StructTypes { get; set; }
}

View File

@@ -2,26 +2,44 @@ module main
extern "puts" func puts(text: ^i8) extern "puts" func puts(text: ^i8)
struct Name
{
first: ^i8
last: ^i8
}
struct Human struct Human
{ {
age: u64 age: u64
name: ^i8 name: &Name
} }
extern "main" func main(argc: i64, argv: [?]^i8): i64 extern "main" func main(argc: i64, argv: [?]^i8): i64
{ {
let x: &Human = { let x: &Human = {
age = 23 age = 23
name = "test" name = {
first = "oliver"
last = "stene"
}
} }
puts(x^.name) let z: Human = {
age = 23
name = {
first = "oliver"
last = "stene"
}
}
test(x) test(x)
let y = x
return 0 return 0
} }
func test(x: &Human) func test(x: &Human): &Human
{ {
return x
} }

View File

@@ -12,7 +12,7 @@ void *rc_alloc(size_t size, void (*destructor)(void *self))
exit(69); exit(69);
} }
header->ref_count = 0; header->ref_count = 1;
header->destructor = destructor; header->destructor = destructor;
return (void *)(header + 1); return (void *)(header + 1);
@@ -29,7 +29,7 @@ void rc_release(void *obj)
{ {
ref_header *header = ((ref_header *)obj) - 1; ref_header *header = ((ref_header *)obj) - 1;
printf("rc_release\n"); printf("rc_release\n");
if (--header->ref_count <= 0) if (--header->ref_count == 0)
{ {
if (header->destructor) if (header->destructor)
{ {