Compare commits

...

10 Commits

Author SHA1 Message Date
nub31
c3d64c4ea9 Update lsp to have set root path 2025-11-05 15:53:08 +01:00
nub31
d3822bc9b4 ... 2025-11-05 15:20:45 +01:00
nub31
36622755a9 ... 2025-11-03 19:54:41 +01:00
nub31
47fef6bc9f ... 2025-11-03 17:10:15 +01:00
nub31
7d49bf43b7 ... 2025-11-03 16:01:20 +01:00
nub31
7ce451768d reformat generator 2025-11-03 14:02:47 +01:00
nub31
f231a45285 ... 2025-11-03 13:46:25 +01:00
nub31
085f7a1a6a ... 2025-11-03 12:52:17 +01:00
nub31
40d500fddd ... 2025-10-31 15:18:18 +01:00
nub31
7c7624b1bc ... 2025-10-31 14:42:58 +01:00
31 changed files with 2133 additions and 2110 deletions

View File

@@ -1,31 +1,57 @@
using NubLang.Ast; using NubLang.Ast;
using NubLang.Diagnostics; using NubLang.Diagnostics;
using NubLang.Generation; using NubLang.Generation;
using NubLang.Modules;
using NubLang.Syntax; using NubLang.Syntax;
var diagnostics = new List<Diagnostic>(); var diagnostics = new List<Diagnostic>();
var syntaxTrees = new List<SyntaxTree>(); var syntaxTrees = new List<SyntaxTree>();
var tokenizer = new Tokenizer();
var parser = new Parser();
var generator = new LlvmSharpGenerator();
foreach (var file in args) foreach (var file in args)
{ {
var tokenizer = new Tokenizer(file, File.ReadAllText(file)); var tokens = tokenizer.Tokenize(file, File.ReadAllText(file));
tokenizer.Tokenize();
diagnostics.AddRange(tokenizer.Diagnostics); diagnostics.AddRange(tokenizer.Diagnostics);
var parser = new Parser(); var syntaxTree = parser.Parse(tokens);
var syntaxTree = parser.Parse(tokenizer.Tokens);
diagnostics.AddRange(parser.Diagnostics); diagnostics.AddRange(parser.Diagnostics);
syntaxTrees.Add(syntaxTree); syntaxTrees.Add(syntaxTree);
} }
var modules = Module.Collect(syntaxTrees); foreach (var diagnostic in diagnostics)
{
Console.Error.WriteLine(diagnostic.FormatANSI());
}
if (diagnostics.Any(diagnostic => diagnostic.Severity == DiagnosticSeverity.Error))
{
return 1;
}
diagnostics.Clear();
ModuleRepository moduleRepository;
try
{
moduleRepository = ModuleRepository.Create(syntaxTrees);
}
catch (CompileException e)
{
Console.Error.WriteLine(e.Diagnostic.FormatANSI());
return 1;
}
var compilationUnits = new List<List<TopLevelNode>>(); var compilationUnits = new List<List<TopLevelNode>>();
for (var i = 0; i < args.Length; i++) for (var i = 0; i < args.Length; i++)
{ {
var typeChecker = new TypeChecker(syntaxTrees[i], modules); var typeChecker = new TypeChecker();
var compilationUnit = typeChecker.Check(); var compilationUnit = typeChecker.Check(syntaxTrees[i], moduleRepository);
compilationUnits.Add(compilationUnit); compilationUnits.Add(compilationUnit);
diagnostics.AddRange(typeChecker.Diagnostics); diagnostics.AddRange(typeChecker.Diagnostics);
@@ -41,6 +67,8 @@ if (diagnostics.Any(diagnostic => diagnostic.Severity == DiagnosticSeverity.Erro
return 1; return 1;
} }
diagnostics.Clear();
Directory.CreateDirectory(".build"); Directory.CreateDirectory(".build");
for (var i = 0; i < args.Length; i++) for (var i = 0; i < args.Length; i++)
@@ -48,7 +76,6 @@ for (var i = 0; i < args.Length; i++)
var file = args[i]; var file = args[i];
var compilationUnit = compilationUnits[i]; var compilationUnit = compilationUnits[i];
var generator = new LlvmGenerator();
var directory = Path.GetDirectoryName(file); var directory = Path.GetDirectoryName(file);
if (!string.IsNullOrWhiteSpace(directory)) if (!string.IsNullOrWhiteSpace(directory))
{ {
@@ -56,7 +83,7 @@ for (var i = 0; i < args.Length; i++)
} }
var path = Path.Combine(".build", Path.ChangeExtension(file, "ll")); var path = Path.Combine(".build", Path.ChangeExtension(file, "ll"));
File.WriteAllText(path, generator.Emit(compilationUnit)); generator.Emit(compilationUnit, moduleRepository, file, path);
} }
return 0; return 0;

View File

@@ -1,4 +1,5 @@
using NubLang.Ast; using NubLang.Ast;
using NubLang.Syntax;
using OmniSharp.Extensions.LanguageServer.Protocol.Models; using OmniSharp.Extensions.LanguageServer.Protocol.Models;
using Range = OmniSharp.Extensions.LanguageServer.Protocol.Models.Range; using Range = OmniSharp.Extensions.LanguageServer.Protocol.Models.Range;
@@ -15,11 +16,49 @@ public static class AstExtensions
return new Location return new Location
{ {
Uri = node.Tokens.First().Span.FilePath, Uri = node.Tokens.First().Span.SourcePath,
Range = new Range(node.Tokens.First().Span.Start.Line - 1, node.Tokens.First().Span.Start.Column - 1, node.Tokens.Last().Span.End.Line - 1, node.Tokens.Last().Span.End.Column - 1) Range = new Range(node.Tokens.First().Span.StartLine - 1, node.Tokens.First().Span.StartColumn - 1, node.Tokens.Last().Span.EndLine - 1, node.Tokens.Last().Span.EndColumn - 1)
}; };
} }
public static Location ToLocation(this Token token)
{
return new Location
{
Uri = token.Span.SourcePath,
Range = new Range(token.Span.StartLine - 1, token.Span.StartColumn - 1, token.Span.EndLine - 1, token.Span.EndColumn - 1)
};
}
public static bool ContainsPosition(this Token token, int line, int character)
{
var startLine = token.Span.StartLine - 1;
var startChar = token.Span.StartColumn - 1;
var endLine = token.Span.EndLine - 1;
var endChar = token.Span.EndColumn - 1;
if (line < startLine || line > endLine) return false;
if (line > startLine && line < endLine) return true;
if (startLine == endLine)
{
return character >= startChar && character <= endChar;
}
if (line == startLine)
{
return character >= startChar;
}
if (line == endLine)
{
return character <= endChar;
}
return false;
}
public static bool ContainsPosition(this Node node, int line, int character) public static bool ContainsPosition(this Node node, int line, int character)
{ {
if (node.Tokens.Count == 0) if (node.Tokens.Count == 0)
@@ -27,13 +66,12 @@ public static class AstExtensions
return false; return false;
} }
var start = node.Tokens.First().Span.Start; var span = node.Tokens.First().Span;
var end = node.Tokens.Last().Span.End;
var startLine = start.Line - 1; var startLine = span.StartLine - 1;
var startChar = start.Column - 1; var startChar = span.StartColumn - 1;
var endLine = end.Line - 1; var endLine = span.EndLine - 1;
var endChar = end.Column - 1; var endChar = span.EndColumn - 1;
if (line < startLine || line > endLine) return false; if (line < startLine || line > endLine) return false;
@@ -69,8 +107,8 @@ public static class AstExtensions
return compilationUnit return compilationUnit
.SelectMany(x => x.DescendantsAndSelf()) .SelectMany(x => x.DescendantsAndSelf())
.Where(n => n.ContainsPosition(line, character)) .Where(n => n.ContainsPosition(line, character))
.OrderBy(n => n.Tokens.First().Span.Start.Line) .OrderBy(n => n.Tokens.First().Span.StartLine)
.ThenBy(n => n.Tokens.First().Span.Start.Column) .ThenBy(n => n.Tokens.First().Span.StartColumn)
.LastOrDefault(); .LastOrDefault();
} }
} }

View File

@@ -1,4 +1,5 @@
using NubLang.Ast; using NubLang.Ast;
using NubLang.Modules;
using OmniSharp.Extensions.LanguageServer.Protocol.Client.Capabilities; using OmniSharp.Extensions.LanguageServer.Protocol.Client.Capabilities;
using OmniSharp.Extensions.LanguageServer.Protocol.Document; using OmniSharp.Extensions.LanguageServer.Protocol.Document;
using OmniSharp.Extensions.LanguageServer.Protocol.Models; using OmniSharp.Extensions.LanguageServer.Protocol.Models;
@@ -29,13 +30,6 @@ internal class CompletionHandler(WorkspaceManager workspaceManager) : Completion
Label = "module", Label = "module",
InsertTextFormat = InsertTextFormat.Snippet, InsertTextFormat = InsertTextFormat.Snippet,
InsertText = "module \"$0\"", InsertText = "module \"$0\"",
},
new()
{
Kind = CompletionItemKind.Keyword,
Label = "import",
InsertTextFormat = InsertTextFormat.Snippet,
InsertText = "import \"$0\"",
} }
]; ];
@@ -112,65 +106,76 @@ internal class CompletionHandler(WorkspaceManager workspaceManager) : Completion
private CompletionList HandleSync(CompletionParams request, CancellationToken cancellationToken) private CompletionList HandleSync(CompletionParams request, CancellationToken cancellationToken)
{ {
var completions = new List<CompletionItem>(); var completions = new List<CompletionItem>();
var position = request.Position;
var uri = request.TextDocument.Uri; var compilationUnit = workspaceManager.GetTopLevelNodes(request.TextDocument.Uri.GetFileSystemPath());
var compilationUnit = workspaceManager.GetCompilationUnit(uri);
if (compilationUnit != null) var repository = workspaceManager.GetModuleRepository();
var function = compilationUnit
.OfType<FuncNode>()
.FirstOrDefault(x => x.Body != null && x.Body.ContainsPosition(request.Position.Line, request.Position.Character));
if (function != null)
{ {
var function = compilationUnit.OfType<FuncNode>().FirstOrDefault(x => x.Body != null && x.Body.ContainsPosition(position.Line, position.Character)); completions.AddRange(_statementSnippets);
if (function != null)
foreach (var module in repository.GetAll())
{ {
completions.AddRange(_statementSnippets); foreach (var prototype in module.FunctionPrototypes)
// foreach (var (module, prototypes) in compilationUnit.ImportedFunctions)
// {
// foreach (var prototype in prototypes)
// {
// var parameterStrings = new List<string>();
// foreach (var (index, parameter) in prototype.Parameters.Index())
// {
// parameterStrings.AddRange($"${{{index + 1}:{parameter.NameToken.Value}}}");
// }
//
// completions.Add(new CompletionItem
// {
// Kind = CompletionItemKind.Function,
// Label = $"{module.Value}::{prototype.NameToken.Value}",
// InsertTextFormat = InsertTextFormat.Snippet,
// InsertText = $"{module.Value}::{prototype.NameToken.Value}({string.Join(", ", parameterStrings)})",
// });
// }
// }
foreach (var parameter in function.Prototype.Parameters)
{ {
var parameterStrings = new List<string>();
foreach (var (index, parameter) in prototype.Parameters.Index())
{
parameterStrings.AddRange($"${{{index + 1}:{parameter.NameToken.Value}}}");
}
var isCurrentModule = false;
var moduleDecl = compilationUnit.OfType<ModuleNode>().FirstOrDefault();
if (moduleDecl != null)
{
if (moduleDecl.NameToken.Value == module.Name)
{
isCurrentModule = true;
}
}
completions.Add(new CompletionItem completions.Add(new CompletionItem
{ {
Kind = CompletionItemKind.Variable, Kind = CompletionItemKind.Function,
Label = parameter.NameToken.Value, Label = isCurrentModule ? prototype.NameToken.Value : $"{module.Name}::{prototype.NameToken.Value}",
InsertText = parameter.NameToken.Value, InsertTextFormat = InsertTextFormat.Snippet,
}); InsertText = $"{(isCurrentModule ? "" : $"{module.Name}::")}{prototype.NameToken.Value}({string.Join(", ", parameterStrings)})",
}
var variables = function.Body!
.Descendants()
.OfType<VariableDeclarationNode>();
foreach (var variable in variables)
{
completions.Add(new CompletionItem
{
Kind = CompletionItemKind.Variable,
Label = variable.NameToken.Value,
InsertText = variable.NameToken.Value,
}); });
} }
} }
else
foreach (var parameter in function.Prototype.Parameters)
{ {
completions.AddRange(_definitionSnippets); completions.Add(new CompletionItem
{
Kind = CompletionItemKind.Variable,
Label = parameter.NameToken.Value,
InsertText = parameter.NameToken.Value,
});
} }
var variables = function.Body!
.Descendants()
.OfType<VariableDeclarationNode>();
foreach (var variable in variables)
{
completions.Add(new CompletionItem
{
Kind = CompletionItemKind.Variable,
Label = variable.NameToken.Value,
InsertText = variable.NameToken.Value,
});
}
}
else
{
completions.AddRange(_definitionSnippets);
} }
return new CompletionList(completions, false); return new CompletionList(completions, false);

View File

@@ -20,53 +20,59 @@ internal class DefinitionHandler(WorkspaceManager workspaceManager) : Definition
private LocationOrLocationLinks? HandleSync(DefinitionParams request, CancellationToken cancellationToken) private LocationOrLocationLinks? HandleSync(DefinitionParams request, CancellationToken cancellationToken)
{ {
var uri = request.TextDocument.Uri; var uri = request.TextDocument.Uri;
var compilationUnit = workspaceManager.GetCompilationUnit(uri); var topLevelNodes = workspaceManager.GetTopLevelNodes(uri.GetFileSystemPath());
if (compilationUnit == null)
{
return null;
}
var line = request.Position.Line; var line = request.Position.Line;
var character = request.Position.Character; var character = request.Position.Character;
var node = compilationUnit.DeepestNodeAtPosition(line, character); var node = topLevelNodes.DeepestNodeAtPosition(line, character);
switch (node) switch (node)
{ {
case VariableIdentifierNode variableIdentifierNode: case VariableIdentifierNode variableIdentifierNode:
{ {
var function = compilationUnit.FunctionAtPosition(line, character); var funcNode = topLevelNodes.FunctionAtPosition(line, character);
var parameter = function?.Prototype.Parameters.FirstOrDefault(x => x.NameToken.Value == variableIdentifierNode.NameToken.Value); var parameter = funcNode?.Prototype.Parameters.FirstOrDefault(x => x.NameToken.Value == variableIdentifierNode.NameToken.Value);
if (parameter != null) if (parameter != null)
{ {
return new LocationOrLocationLinks(parameter.ToLocation()); return new LocationOrLocationLinks(parameter.NameToken.ToLocation());
} }
var variable = function?.Body? var variable = funcNode?.Body?
.Descendants() .Descendants()
.OfType<VariableDeclarationNode>() .OfType<VariableDeclarationNode>()
.FirstOrDefault(x => x.NameToken.Value == variableIdentifierNode.NameToken.Value); .FirstOrDefault(x => x.NameToken.Value == variableIdentifierNode.NameToken.Value);
if (variable != null) if (variable != null)
{ {
return new LocationOrLocationLinks(variable.ToLocation()); return new LocationOrLocationLinks(variable.NameToken.ToLocation());
} }
return null; return null;
} }
case FuncIdentifierNode funcIdentifierNode: case LocalFuncIdentifierNode localFuncIdentifierNode:
{ {
// var prototype = compilationUnit var funcNode = topLevelNodes.OfType<FuncNode>().FirstOrDefault(x => x.NameToken.Value == localFuncIdentifierNode.NameToken.Value);
// .ImportedFunctions if (funcNode != null)
// .Where(x => x.Key.Value == funcIdentifierNode.ModuleToken.Value) {
// .SelectMany(x => x.Value) return new LocationOrLocationLinks(funcNode.NameToken.ToLocation());
// .FirstOrDefault(x => x.NameToken.Value == funcIdentifierNode.NameToken.Value); }
//
// if (prototype != null) return null;
// { }
// return new LocationOrLocationLinks(prototype.ToLocation()); case ModuleFuncIdentifierNode localFuncIdentifierNode:
// } {
var repository = workspaceManager.GetModuleRepository();
if (!repository.TryGet(localFuncIdentifierNode.ModuleToken, out var module))
{
return null;
}
if (module.TryResolveFunc(localFuncIdentifierNode.NameToken, out var func, out _))
{
return new LocationOrLocationLinks(func.NameToken.ToLocation());
}
return null; return null;
} }

View File

@@ -37,7 +37,7 @@ public class DiagnosticsPublisher
}, },
Message = $"{nubDiagnostic.Message}\n{(nubDiagnostic.Help == null ? "" : $"help: {nubDiagnostic.Help}")}", Message = $"{nubDiagnostic.Message}\n{(nubDiagnostic.Help == null ? "" : $"help: {nubDiagnostic.Help}")}",
Range = nubDiagnostic.Span.HasValue Range = nubDiagnostic.Span.HasValue
? new Range(nubDiagnostic.Span.Value.Start.Line - 1, nubDiagnostic.Span.Value.Start.Column - 1, nubDiagnostic.Span.Value.End.Line - 1, nubDiagnostic.Span.Value.End.Column - 1) ? new Range(nubDiagnostic.Span.Value.StartLine - 1, nubDiagnostic.Span.Value.StartColumn - 1, nubDiagnostic.Span.Value.EndLine - 1, nubDiagnostic.Span.Value.EndColumn - 1)
: new Range(), : new Range(),
}; };
} }

View File

@@ -1,5 +1,7 @@
using System.Globalization; using System.Globalization;
using NubLang.Ast; using NubLang.Ast;
using NubLang.Modules;
using NubLang.Types;
using OmniSharp.Extensions.LanguageServer.Protocol.Client.Capabilities; using OmniSharp.Extensions.LanguageServer.Protocol.Client.Capabilities;
using OmniSharp.Extensions.LanguageServer.Protocol.Document; using OmniSharp.Extensions.LanguageServer.Protocol.Document;
using OmniSharp.Extensions.LanguageServer.Protocol.Models; using OmniSharp.Extensions.LanguageServer.Protocol.Models;
@@ -23,8 +25,16 @@ internal class HoverHandler(WorkspaceManager workspaceManager) : HoverHandlerBas
private Hover? HandleSync(HoverParams request, CancellationToken cancellationToken) private Hover? HandleSync(HoverParams request, CancellationToken cancellationToken)
{ {
var compilationUnit = workspaceManager.GetCompilationUnit(request.TextDocument.Uri); var topLevelNodes = workspaceManager.GetTopLevelNodes(request.TextDocument.Uri.GetFileSystemPath());
if (compilationUnit == null)
var moduleDecl = topLevelNodes.OfType<ModuleNode>().FirstOrDefault();
if (moduleDecl == null)
{
return null;
}
var moduleRepository = workspaceManager.GetModuleRepository();
if (!moduleRepository.TryGet(moduleDecl.NameToken, out var module))
{ {
return null; return null;
} }
@@ -32,117 +42,162 @@ internal class HoverHandler(WorkspaceManager workspaceManager) : HoverHandlerBas
var line = request.Position.Line; var line = request.Position.Line;
var character = request.Position.Character; var character = request.Position.Character;
var hoveredNode = compilationUnit.DeepestNodeAtPosition(line, character); var hoveredNode = topLevelNodes.DeepestNodeAtPosition(line, character);
if (hoveredNode == null) if (hoveredNode == null)
{ {
return null; return null;
} }
// var message = CreateMessage(hoveredNode, compilationUnit); var message = CreateMessage(hoveredNode, moduleRepository, module, line, character);
// if (message == null) if (message == null)
// { {
// return null; return null;
// } }
//
// return new Hover
// {
// Contents = new MarkedStringsOrMarkupContent(new MarkupContent
// {
// Value = message,
// Kind = MarkupKind.Markdown,
// })
// };
return null; return new Hover
{
Contents = new MarkedStringsOrMarkupContent(new MarkupContent
{
Value = message,
Kind = MarkupKind.Markdown,
})
};
} }
// private static string? CreateMessage(Node hoveredNode, CompilationUnit compilationUnit) private static string? CreateMessage(Node hoveredNode, ModuleRepository repository, ModuleRepository.Module currentModule, int line, int character)
// { {
// return hoveredNode switch return hoveredNode switch
// { {
// FuncNode funcNode => CreateFuncPrototypeMessage(funcNode.Prototype), FuncNode funcNode => CreateFuncPrototypeMessage(funcNode.Prototype),
// FuncPrototypeNode funcPrototypeNode => CreateFuncPrototypeMessage(funcPrototypeNode), FuncPrototypeNode funcPrototypeNode => CreateFuncPrototypeMessage(funcPrototypeNode),
// FuncIdentifierNode funcIdentifierNode => CreateFuncIdentifierMessage(funcIdentifierNode, compilationUnit), LocalFuncIdentifierNode funcIdentifierNode => CreateLocalFuncIdentifierMessage(funcIdentifierNode, currentModule),
// FuncParameterNode funcParameterNode => CreateTypeNameMessage("Function parameter", funcParameterNode.NameToken.Value, funcParameterNode.Type), ModuleFuncIdentifierNode funcIdentifierNode => CreateModuleFuncIdentifierMessage(funcIdentifierNode, repository),
// VariableIdentifierNode variableIdentifierNode => CreateTypeNameMessage("Variable", variableIdentifierNode.NameToken.Value, variableIdentifierNode.Type), FuncParameterNode funcParameterNode => CreateTypeNameMessage("Function parameter", funcParameterNode.NameToken.Value, funcParameterNode.Type),
// VariableDeclarationNode variableDeclarationNode => CreateTypeNameMessage("Variable declaration", variableDeclarationNode.NameToken.Value, variableDeclarationNode.Type), VariableIdentifierNode variableIdentifierNode => CreateTypeNameMessage("Variable", variableIdentifierNode.NameToken.Value, variableIdentifierNode.Type),
// StructFieldAccessNode structFieldAccessNode => CreateTypeNameMessage("Struct field", $"{structFieldAccessNode.Target.Type}.{structFieldAccessNode.FieldToken.Value}", structFieldAccessNode.Type), VariableDeclarationNode variableDeclarationNode => CreateTypeNameMessage("Variable declaration", variableDeclarationNode.NameToken.Value, variableDeclarationNode.Type),
// CStringLiteralNode cStringLiteralNode => CreateLiteralMessage(cStringLiteralNode.Type, '"' + cStringLiteralNode.Value + '"'), StructFieldAccessNode structFieldAccessNode => CreateTypeNameMessage("Struct field", $"{structFieldAccessNode.Target.Type}.{structFieldAccessNode.FieldToken.Value}", structFieldAccessNode.Type),
// StringLiteralNode stringLiteralNode => CreateLiteralMessage(stringLiteralNode.Type, '"' + stringLiteralNode.Value + '"'), CStringLiteralNode cStringLiteralNode => CreateLiteralMessage(cStringLiteralNode.Type, '"' + cStringLiteralNode.Value + '"'),
// BoolLiteralNode boolLiteralNode => CreateLiteralMessage(boolLiteralNode.Type, boolLiteralNode.Value.ToString()), StringLiteralNode stringLiteralNode => CreateLiteralMessage(stringLiteralNode.Type, '"' + stringLiteralNode.Value + '"'),
// Float32LiteralNode float32LiteralNode => CreateLiteralMessage(float32LiteralNode.Type, float32LiteralNode.Value.ToString(CultureInfo.InvariantCulture)), BoolLiteralNode boolLiteralNode => CreateLiteralMessage(boolLiteralNode.Type, boolLiteralNode.Value.ToString()),
// Float64LiteralNode float64LiteralNode => CreateLiteralMessage(float64LiteralNode.Type, float64LiteralNode.Value.ToString(CultureInfo.InvariantCulture)), Float32LiteralNode float32LiteralNode => CreateLiteralMessage(float32LiteralNode.Type, float32LiteralNode.Value.ToString(CultureInfo.InvariantCulture)),
// I8LiteralNode i8LiteralNode => CreateLiteralMessage(i8LiteralNode.Type, i8LiteralNode.Value.ToString()), Float64LiteralNode float64LiteralNode => CreateLiteralMessage(float64LiteralNode.Type, float64LiteralNode.Value.ToString(CultureInfo.InvariantCulture)),
// I16LiteralNode i16LiteralNode => CreateLiteralMessage(i16LiteralNode.Type, i16LiteralNode.Value.ToString()), I8LiteralNode i8LiteralNode => CreateLiteralMessage(i8LiteralNode.Type, i8LiteralNode.Value.ToString()),
// I32LiteralNode i32LiteralNode => CreateLiteralMessage(i32LiteralNode.Type, i32LiteralNode.Value.ToString()), I16LiteralNode i16LiteralNode => CreateLiteralMessage(i16LiteralNode.Type, i16LiteralNode.Value.ToString()),
// I64LiteralNode i64LiteralNode => CreateLiteralMessage(i64LiteralNode.Type, i64LiteralNode.Value.ToString()), I32LiteralNode i32LiteralNode => CreateLiteralMessage(i32LiteralNode.Type, i32LiteralNode.Value.ToString()),
// U8LiteralNode u8LiteralNode => CreateLiteralMessage(u8LiteralNode.Type, u8LiteralNode.Value.ToString()), I64LiteralNode i64LiteralNode => CreateLiteralMessage(i64LiteralNode.Type, i64LiteralNode.Value.ToString()),
// U16LiteralNode u16LiteralNode => CreateLiteralMessage(u16LiteralNode.Type, u16LiteralNode.Value.ToString()), U8LiteralNode u8LiteralNode => CreateLiteralMessage(u8LiteralNode.Type, u8LiteralNode.Value.ToString()),
// U32LiteralNode u32LiteralNode => CreateLiteralMessage(u32LiteralNode.Type, u32LiteralNode.Value.ToString()), U16LiteralNode u16LiteralNode => CreateLiteralMessage(u16LiteralNode.Type, u16LiteralNode.Value.ToString()),
// U64LiteralNode u64LiteralNode => CreateLiteralMessage(u64LiteralNode.Type, u64LiteralNode.Value.ToString()), U32LiteralNode u32LiteralNode => CreateLiteralMessage(u32LiteralNode.Type, u32LiteralNode.Value.ToString()),
// // Expressions can have a generic fallback showing the resulting type U64LiteralNode u64LiteralNode => CreateLiteralMessage(u64LiteralNode.Type, u64LiteralNode.Value.ToString()),
// ExpressionNode expressionNode => $""" StructInitializerNode structInitializerNode => CreateStructInitializerMessage(structInitializerNode, line, character),
// **Expression** `{expressionNode.GetType().Name}` // Expressions can have a generic fallback showing the resulting type
// ```nub ExpressionNode expressionNode => $"""
// {expressionNode.Type} **Expression** `{expressionNode.GetType().Name}`
// ``` ```nub
// """, {expressionNode.Type}
// BlockNode => null, ```
// _ => hoveredNode.GetType().Name """,
// }; BlockNode => null,
// } _ => hoveredNode.GetType().Name
// };
// private static string CreateLiteralMessage(NubType type, string value) }
// {
// return $""" private static string CreateStructInitializerMessage(StructInitializerNode structInitializerNode, int line, int character)
// **Literal** `{type}` {
// ```nub var hoveredInitializerName = structInitializerNode
// {value}: {type} .Initializers
// ``` .Select(x => x.Key)
// """; .FirstOrDefault(x => x.ContainsPosition(line, character));
// }
// var structType = (NubStructType)structInitializerNode.Type;
// private static string CreateTypeNameMessage(string description, string name, NubType type)
// { if (hoveredInitializerName != null)
// return $""" {
// **{description}** `{name}` var field = structType.Fields.FirstOrDefault(x => x.Name == hoveredInitializerName.Value);
// ```nub if (field != null)
// {name}: {type} {
// ``` return $"""
// """; **Field** in `{structType}`
// } ```nub
// {hoveredInitializerName.Value}: {field.Type}
// private static string CreateFuncIdentifierMessage(FuncIdentifierNode funcIdentifierNode, CompilationUnit compilationUnit) ```
// { """;
// var func = compilationUnit.ImportedFunctions }
// .Where(x => x.Key.Value == funcIdentifierNode.ModuleToken.Value) else
// .SelectMany(x => x.Value) {
// .FirstOrDefault(x => x.NameToken.Value == funcIdentifierNode.NameToken.Value); return $"""
// **Field** in `{structType}`
// if (func == null) ```nub
// { // Field not found
// return $""" ```
// **Function** `{funcIdentifierNode.ModuleToken.Value}::{funcIdentifierNode.NameToken.Value}` """;
// ```nub }
// // Declaration not found }
// ```
// """; return $"**Struct initializer** `{structType}`";
// } }
//
// return CreateFuncPrototypeMessage(func); private static string CreateLiteralMessage(NubType type, string value)
// } {
// return $"""
// private static string CreateFuncPrototypeMessage(FuncPrototypeNode funcPrototypeNode) **Literal** `{type}`
// { ```nub
// var parameterText = string.Join(", ", funcPrototypeNode.Parameters.Select(x => $"{x.NameToken.Value}: {x.Type}")); {value}: {type}
// var externText = funcPrototypeNode.ExternSymbolToken != null ? $"extern \"{funcPrototypeNode.ExternSymbolToken.Value}\" " : ""; ```
// """;
// return $""" }
// **Function** `{funcPrototypeNode.NameToken.Value}`
// ```nub private static string CreateTypeNameMessage(string description, string name, NubType type)
// {externText}func {funcPrototypeNode.NameToken.Value}({parameterText}): {funcPrototypeNode.ReturnType} {
// ``` return $"""
// """; **{description}** `{name}`
// } ```nub
{name}: {type}
```
""";
}
private static string CreateLocalFuncIdentifierMessage(LocalFuncIdentifierNode funcIdentifierNode, ModuleRepository.Module currentModule)
{
if (!currentModule.TryResolveFunc(funcIdentifierNode.NameToken, out var func, out _))
{
return $"""
**Function** `{funcIdentifierNode.NameToken.Value}`
```nub
// Declaration not found
```
""";
}
return CreateFuncPrototypeMessage(func);
}
private static string CreateModuleFuncIdentifierMessage(ModuleFuncIdentifierNode funcIdentifierNode, ModuleRepository repository)
{
if (!repository.TryGet(funcIdentifierNode.ModuleToken, out var module) || !module.TryResolveFunc(funcIdentifierNode.NameToken, out var func, out _))
{
return $"""
**Function** `{funcIdentifierNode.ModuleToken.Value}::{funcIdentifierNode.NameToken.Value}`
```nub
// Declaration not found
```
""";
}
return CreateFuncPrototypeMessage(func);
}
private static string CreateFuncPrototypeMessage(FuncPrototypeNode funcPrototypeNode)
{
var parameterText = string.Join(", ", funcPrototypeNode.Parameters.Select(x => $"{x.NameToken.Value}: {x.Type}"));
var externText = funcPrototypeNode.ExternSymbolToken != null ? $"extern \"{funcPrototypeNode.ExternSymbolToken.Value}\" " : "";
return $"""
**Function** `{funcPrototypeNode.NameToken.Value}`
```nub
{externText}func {funcPrototypeNode.NameToken.Value}({parameterText}): {funcPrototypeNode.ReturnType}
```
""";
}
} }

View File

@@ -18,17 +18,7 @@ var server = await LanguageServer.From(options => options
.WithHandler<HoverHandler>() .WithHandler<HoverHandler>()
.WithHandler<CompletionHandler>() .WithHandler<CompletionHandler>()
.WithHandler<DefinitionHandler>() .WithHandler<DefinitionHandler>()
.OnInitialize((server, request, ct) => .WithHandler<SetRootPathCommandHandler>()
{
var workspaceManager = server.GetRequiredService<WorkspaceManager>();
if (request.RootPath != null)
{
workspaceManager.Init(request.RootPath);
}
return Task.CompletedTask;
})
); );
await server.WaitForExit; await server.WaitForExit;

View File

@@ -0,0 +1,31 @@
using MediatR;
using OmniSharp.Extensions.LanguageServer.Protocol.Client.Capabilities;
using OmniSharp.Extensions.LanguageServer.Protocol.Models;
using OmniSharp.Extensions.LanguageServer.Protocol.Workspace;
namespace NubLang.LSP;
public class SetRootPathCommandHandler(WorkspaceManager workspaceManager) : ExecuteCommandHandlerBase
{
protected override ExecuteCommandRegistrationOptions CreateRegistrationOptions(ExecuteCommandCapability capability, ClientCapabilities clientCapabilities)
{
return new ExecuteCommandRegistrationOptions
{
Commands = new Container<string>("nub.setRootPath")
};
}
public override Task<Unit> Handle(ExecuteCommandParams request, CancellationToken cancellationToken)
{
if (request is { Command: "nub.setRootPath", Arguments.Count: > 0 })
{
var newRoot = request.Arguments[0].ToString();
if (!string.IsNullOrEmpty(newRoot))
{
workspaceManager.SetRootPath(newRoot);
}
}
return Unit.Task;
}
}

View File

@@ -15,25 +15,25 @@ internal class TextDocumentSyncHandler(WorkspaceManager workspaceManager) : Text
public override Task<Unit> Handle(DidOpenTextDocumentParams request, CancellationToken cancellationToken) public override Task<Unit> Handle(DidOpenTextDocumentParams request, CancellationToken cancellationToken)
{ {
workspaceManager.UpdateFile(request.TextDocument.Uri.GetFileSystemPath()); workspaceManager.Update();
return Unit.Task; return Unit.Task;
} }
public override Task<Unit> Handle(DidChangeTextDocumentParams request, CancellationToken cancellationToken) public override Task<Unit> Handle(DidChangeTextDocumentParams request, CancellationToken cancellationToken)
{ {
workspaceManager.UpdateFile(request.TextDocument.Uri.GetFileSystemPath()); workspaceManager.Update();
return Unit.Task; return Unit.Task;
} }
public override Task<Unit> Handle(DidSaveTextDocumentParams request, CancellationToken cancellationToken) public override Task<Unit> Handle(DidSaveTextDocumentParams request, CancellationToken cancellationToken)
{ {
workspaceManager.UpdateFile(request.TextDocument.Uri.GetFileSystemPath()); workspaceManager.Update();
return Unit.Task; return Unit.Task;
} }
public override Task<Unit> Handle(DidCloseTextDocumentParams request, CancellationToken cancellationToken) public override Task<Unit> Handle(DidCloseTextDocumentParams request, CancellationToken cancellationToken)
{ {
workspaceManager.UpdateFile(request.TextDocument.Uri.GetFileSystemPath()); workspaceManager.Update();
return Unit.Task; return Unit.Task;
} }

View File

@@ -1,77 +1,68 @@
using NubLang.Ast; using NubLang.Ast;
using NubLang.Diagnostics;
using NubLang.Modules;
using NubLang.Syntax; using NubLang.Syntax;
using OmniSharp.Extensions.LanguageServer.Protocol;
namespace NubLang.LSP; namespace NubLang.LSP;
public class WorkspaceManager(DiagnosticsPublisher diagnosticsPublisher) public class WorkspaceManager(DiagnosticsPublisher diagnosticsPublisher)
{ {
private readonly Dictionary<string, SyntaxTree> _syntaxTrees = new(); private record Unit(SyntaxTree SyntaxTree, DateTimeOffset FileTimestamp, List<Diagnostic> Diagnostics);
private readonly Dictionary<string, List<TopLevelNode>> _compilationUnits = new();
private readonly Dictionary<string, TypedModule> _modules = new();
public void Init(string rootPath) private readonly Tokenizer _tokenizer = new();
private readonly Parser _parser = new();
private readonly TypeChecker _typeChecker = new();
private string? _rootPath;
private readonly Dictionary<string, Unit> _units = [];
private readonly Dictionary<string, List<TopLevelNode>> _possiblyOutdatedTopLevelNodes = [];
private ModuleRepository _repository = new([]);
public void SetRootPath(string rootPath)
{ {
var files = Directory.GetFiles(rootPath, "*.nub", SearchOption.AllDirectories); _rootPath = rootPath;
foreach (var path in files) Update();
}
public void Update()
{
if (_rootPath == null) return;
var files = Directory.GetFiles(_rootPath, "*.nub", SearchOption.AllDirectories);
foreach (var file in files)
{ {
var text = File.ReadAllText(path); var lastUpdated = File.GetLastWriteTimeUtc(file);
var tokenizer = new Tokenizer(path, text); var unit = _units.GetValueOrDefault(file);
if (unit == null || lastUpdated > unit.FileTimestamp)
tokenizer.Tokenize(); {
diagnosticsPublisher.Publish(path, tokenizer.Diagnostics); _units[file] = Update(file, lastUpdated);
}
var parser = new Parser();
var parseResult = parser.Parse(tokenizer.Tokens);
diagnosticsPublisher.Publish(path, parser.Diagnostics);
_syntaxTrees[path] = parseResult;
} }
foreach (var (fsPath, syntaxTree) in _syntaxTrees) _repository = ModuleRepository.Create(_units.Select(x => x.Value.SyntaxTree).ToList());
foreach (var (file, unit) in _units)
{ {
var modules = Module.Collect(_syntaxTrees.Select(x => x.Value).ToList()); var topLevelNodes = _typeChecker.Check(unit.SyntaxTree, _repository);
_possiblyOutdatedTopLevelNodes[file] = topLevelNodes;
var typeChecker = new TypeChecker(syntaxTree, modules); diagnosticsPublisher.Publish(file, [..unit.Diagnostics, .._typeChecker.Diagnostics]);
var result = typeChecker.Check();
diagnosticsPublisher.Publish(fsPath, typeChecker.Diagnostics);
_compilationUnits[fsPath] = result;
} }
} }
public void UpdateFile(DocumentUri path) private Unit Update(string file, DateTimeOffset lastUpdated)
{ {
var fsPath = path.GetFileSystemPath(); var text = File.ReadAllText(file);
var tokens = _tokenizer.Tokenize(file, text);
var syntaxTree = _parser.Parse(tokens);
var text = File.ReadAllText(fsPath); return new Unit(syntaxTree, lastUpdated, [.._tokenizer.Diagnostics, .._parser.Diagnostics]);
var tokenizer = new Tokenizer(fsPath, text);
tokenizer.Tokenize();
diagnosticsPublisher.Publish(path, tokenizer.Diagnostics);
var parser = new Parser();
var syntaxTree = parser.Parse(tokenizer.Tokens);
diagnosticsPublisher.Publish(path, parser.Diagnostics);
_syntaxTrees[fsPath] = syntaxTree;
var modules = Module.Collect(_syntaxTrees.Select(x => x.Value).ToList());
var typeChecker = new TypeChecker(syntaxTree, modules);
var result = typeChecker.Check();
diagnosticsPublisher.Publish(fsPath, typeChecker.Diagnostics);
_compilationUnits[fsPath] = result;
} }
public void RemoveFile(DocumentUri path) public List<TopLevelNode> GetTopLevelNodes(string path)
{ {
var fsPath = path.GetFileSystemPath(); return _possiblyOutdatedTopLevelNodes.GetValueOrDefault(path, []);
_syntaxTrees.Remove(fsPath);
_compilationUnits.Remove(fsPath);
} }
public List<TopLevelNode>? GetCompilationUnit(DocumentUri path) public ModuleRepository GetModuleRepository()
{ {
return _compilationUnits.GetValueOrDefault(path.GetFileSystemPath()); return _repository;
} }
} }

View File

@@ -1,12 +0,0 @@
using NubLang.Syntax;
namespace NubLang.Ast;
// public sealed class CompilationUnit(IdentifierToken module, List<FuncNode> functions, List<StructNode> structTypes, Dictionary<IdentifierToken, List<NubStructType>> importedStructTypes, Dictionary<IdentifierToken, List<FuncPrototypeNode>> importedFunctions)
// {
// public IdentifierToken Module { get; } = module;
// public List<FuncNode> Functions { get; } = functions;
// public List<StructNode> Structs { get; } = structTypes;
// public Dictionary<IdentifierToken, List<NubStructType>> ImportedStructTypes { get; } = importedStructTypes;
// public Dictionary<IdentifierToken, List<FuncPrototypeNode>> ImportedFunctions { get; } = importedFunctions;
// }

View File

@@ -1,4 +1,5 @@
using NubLang.Syntax; using NubLang.Syntax;
using NubLang.Types;
namespace NubLang.Ast; namespace NubLang.Ast;
@@ -31,16 +32,6 @@ public abstract class Node(List<Token> tokens)
public abstract class TopLevelNode(List<Token> tokens) : Node(tokens); public abstract class TopLevelNode(List<Token> tokens) : Node(tokens);
public class ImportNode(List<Token> tokens, IdentifierToken nameToken) : TopLevelNode(tokens)
{
public IdentifierToken NameToken { get; } = nameToken;
public override IEnumerable<Node> Children()
{
return [];
}
}
public class ModuleNode(List<Token> tokens, IdentifierToken nameToken) : TopLevelNode(tokens) public class ModuleNode(List<Token> tokens, IdentifierToken nameToken) : TopLevelNode(tokens)
{ {
public IdentifierToken NameToken { get; } = nameToken; public IdentifierToken NameToken { get; } = nameToken;
@@ -490,7 +481,18 @@ public class VariableIdentifierNode(List<Token> tokens, NubType type, Identifier
} }
} }
public class FuncIdentifierNode(List<Token> tokens, NubType type, IdentifierToken moduleToken, IdentifierToken nameToken, StringLiteralToken? externSymbolToken) : RValue(tokens, type) public class LocalFuncIdentifierNode(List<Token> tokens, NubType type, IdentifierToken nameToken, StringLiteralToken? externSymbolToken) : RValue(tokens, type)
{
public IdentifierToken NameToken { get; } = nameToken;
public StringLiteralToken? ExternSymbolToken { get; } = externSymbolToken;
public override IEnumerable<Node> Children()
{
return [];
}
}
public class ModuleFuncIdentifierNode(List<Token> tokens, NubType type, IdentifierToken moduleToken, IdentifierToken nameToken, StringLiteralToken? externSymbolToken) : RValue(tokens, type)
{ {
public IdentifierToken ModuleToken { get; } = moduleToken; public IdentifierToken ModuleToken { get; } = moduleToken;
public IdentifierToken NameToken { get; } = nameToken; public IdentifierToken NameToken { get; } = nameToken;
@@ -579,9 +581,27 @@ public class SizeNode(List<Token> tokens, NubType targetType) : RValue(tokens, n
} }
} }
public class CastNode(List<Token> tokens, NubType type, ExpressionNode value) : RValue(tokens, type) public class CastNode(List<Token> tokens, NubType type, ExpressionNode value, CastNode.Conversion conversionType) : RValue(tokens, type)
{ {
public enum Conversion
{
IntToInt,
FloatToFloat,
IntToFloat,
FloatToInt,
PointerToPointer,
PointerToUInt64,
UInt64ToPointer,
ConstArrayToArray,
ConstArrayToSlice,
StringToCString
}
public ExpressionNode Value { get; } = value; public ExpressionNode Value { get; } = value;
public Conversion ConversionType { get; } = conversionType;
public override IEnumerable<Node> Children() public override IEnumerable<Node> Children()
{ {
@@ -589,7 +609,7 @@ public class CastNode(List<Token> tokens, NubType type, ExpressionNode value) :
} }
} }
public class StructInitializerNode(List<Token> tokens, NubType type, Dictionary<IdentifierToken, ExpressionNode> initializers) : RValue(tokens, type) public class StructInitializerNode(List<Token> tokens, NubType type, Dictionary<IdentifierToken, ExpressionNode> initializers) : LValue(tokens, type)
{ {
public Dictionary<IdentifierToken, ExpressionNode> Initializers { get; } = initializers; public Dictionary<IdentifierToken, ExpressionNode> Initializers { get; } = initializers;
@@ -602,7 +622,7 @@ public class StructInitializerNode(List<Token> tokens, NubType type, Dictionary<
} }
} }
public class ConstArrayInitializerNode(List<Token> tokens, NubType type, List<ExpressionNode> values) : RValue(tokens, type) public class ConstArrayInitializerNode(List<Token> tokens, NubType type, List<ExpressionNode> values) : LValue(tokens, type)
{ {
public List<ExpressionNode> Values { get; } = values; public List<ExpressionNode> Values { get; } = values;
@@ -612,17 +632,4 @@ public class ConstArrayInitializerNode(List<Token> tokens, NubType type, List<Ex
} }
} }
public abstract class IntermediateExpression(List<Token> tokens) : ExpressionNode(tokens, new NubVoidType());
public class EnumReferenceIntermediateNode(List<Token> tokens, IdentifierToken moduleToken, IdentifierToken nameToken) : IntermediateExpression(tokens)
{
public IdentifierToken ModuleToken { get; } = moduleToken;
public IdentifierToken NameToken { get; } = nameToken;
public override IEnumerable<Node> Children()
{
return [];
}
}
#endregion #endregion

View File

@@ -1,32 +1,28 @@
using System.Diagnostics; using System.Diagnostics;
using NubLang.Diagnostics; using NubLang.Diagnostics;
using NubLang.Modules;
using NubLang.Syntax; using NubLang.Syntax;
using NubLang.Types;
namespace NubLang.Ast; namespace NubLang.Ast;
public sealed class TypeChecker public sealed class TypeChecker
{ {
private readonly SyntaxTree _syntaxTree; private SyntaxTree _syntaxTree = null!;
private readonly Dictionary<string, Module> _modules; private ModuleRepository _repository = null!;
private readonly Stack<Scope> _scopes = []; private Stack<Scope> _scopes = [];
private readonly TypeResolver _typeResolver; private Scope CurrentScope => _scopes.Peek();
private Scope Scope => _scopes.Peek(); public List<Diagnostic> Diagnostics { get; set; } = [];
public List<Diagnostic> Diagnostics { get; } = []; public List<TopLevelNode> Check(SyntaxTree syntaxTree, ModuleRepository repository)
public TypeChecker(SyntaxTree syntaxTree, Dictionary<string, Module> modules)
{ {
_syntaxTree = syntaxTree; _syntaxTree = syntaxTree;
_modules = modules; _repository = repository;
_typeResolver = new TypeResolver(_modules); Diagnostics = [];
} _scopes = [];
public List<TopLevelNode> Check()
{
_scopes.Clear();
var moduleDeclarations = _syntaxTree.TopLevelSyntaxNodes.OfType<ModuleSyntax>().ToList(); var moduleDeclarations = _syntaxTree.TopLevelSyntaxNodes.OfType<ModuleSyntax>().ToList();
if (moduleDeclarations.Count == 0) if (moduleDeclarations.Count == 0)
@@ -40,51 +36,11 @@ public sealed class TypeChecker
Diagnostics.Add(Diagnostic.Error("Multiple module declarations").WithHelp("Remove extra module declarations").Build()); Diagnostics.Add(Diagnostic.Error("Multiple module declarations").WithHelp("Remove extra module declarations").Build());
} }
var moduleName = moduleDeclarations[0].NameToken; var module = _repository.Get(moduleDeclarations[0].NameToken);
var importDeclarations = _syntaxTree.TopLevelSyntaxNodes.OfType<ImportSyntax>().ToList();
foreach (var importDeclaration in importDeclarations)
{
var name = importDeclaration.NameToken.Value;
var last = importDeclarations.Last(x => x.NameToken.Value == name);
if (importDeclaration != last)
{
Diagnostics.Add(Diagnostic
.Warning($"Module \"{last.NameToken.Value}\" is imported twice")
.WithHelp($"Remove duplicate import \"{last.NameToken.Value}\"")
.At(last)
.Build());
}
var exists = _modules.ContainsKey(name);
if (!exists)
{
var suggestions = _modules.Keys
.Select(m => new { Name = m, Distance = Utils.LevenshteinDistance(name, m) })
.OrderBy(x => x.Distance)
.Take(3)
.Where(x => x.Distance <= 3)
.Select(x => $"\"{x.Name}\"")
.ToArray();
var suggestionText = suggestions.Length != 0
? $"Did you mean {string.Join(", ", suggestions)}?"
: "Check that the module name is correct.";
Diagnostics.Add(Diagnostic
.Error($"Module \"{name}\" does not exist")
.WithHelp(suggestionText)
.At(last)
.Build());
return [];
}
}
var topLevelNodes = new List<TopLevelNode>(); var topLevelNodes = new List<TopLevelNode>();
using (BeginRootScope(moduleName)) using (BeginRootScope(module))
{ {
foreach (var topLevelSyntaxNode in _syntaxTree.TopLevelSyntaxNodes) foreach (var topLevelSyntaxNode in _syntaxTree.TopLevelSyntaxNodes)
{ {
@@ -98,9 +54,6 @@ public sealed class TypeChecker
case StructSyntax structSyntax: case StructSyntax structSyntax:
topLevelNodes.Add(CheckStructDefinition(structSyntax)); topLevelNodes.Add(CheckStructDefinition(structSyntax));
break; break;
case ImportSyntax importSyntax:
topLevelNodes.Add(new ImportNode(importSyntax.Tokens, importSyntax.NameToken));
break;
case ModuleSyntax moduleSyntax: case ModuleSyntax moduleSyntax:
topLevelNodes.Add(new ModuleNode(moduleSyntax.Tokens, moduleSyntax.NameToken)); topLevelNodes.Add(new ModuleNode(moduleSyntax.Tokens, moduleSyntax.NameToken));
break; break;
@@ -113,58 +66,15 @@ public sealed class TypeChecker
return topLevelNodes; return topLevelNodes;
} }
private (IdentifierToken Name, Module Module) GetCurrentModule()
{
var currentModule = _syntaxTree.TopLevelSyntaxNodes.OfType<ModuleSyntax>().First().NameToken;
return (currentModule, _modules[currentModule.Value]);
}
private List<(IdentifierToken Name, Module Module)> GetImportedModules()
{
var currentModule = _syntaxTree.TopLevelSyntaxNodes.OfType<ModuleSyntax>().First().NameToken;
return _syntaxTree.TopLevelSyntaxNodes
.OfType<ImportSyntax>()
.Select(x => (Name: x.NameToken, Module: _modules[x.NameToken.Value]))
.Concat([(Name: currentModule, Module: _modules[currentModule.Value])])
.ToList();
}
private bool IsCurrentModule(IdentifierToken? module)
{
if (module == null)
{
return true;
}
return module.Value == Scope.Module.Value;
}
private Module? GetImportedModule(string module)
{
var currentModule = _syntaxTree.TopLevelSyntaxNodes.OfType<ModuleSyntax>().First().NameToken;
if (module == currentModule.Value)
{
return _modules[currentModule.Value];
}
var import = _syntaxTree.TopLevelSyntaxNodes.OfType<ImportSyntax>().FirstOrDefault(x => x.NameToken.Value == module);
if (import != null)
{
return _modules[import.NameToken.Value];
}
return null;
}
private ScopeDisposer BeginScope() private ScopeDisposer BeginScope()
{ {
_scopes.Push(Scope.SubScope()); _scopes.Push(CurrentScope.SubScope());
return new ScopeDisposer(this); return new ScopeDisposer(this);
} }
private ScopeDisposer BeginRootScope(IdentifierToken moduleName) private ScopeDisposer BeginRootScope(ModuleRepository.Module module)
{ {
_scopes.Push(new Scope(moduleName)); _scopes.Push(new Scope(module));
return new ScopeDisposer(this); return new ScopeDisposer(this);
} }
@@ -186,10 +96,10 @@ public sealed class TypeChecker
{ {
var prototype = CheckFuncPrototype(node.Prototype); var prototype = CheckFuncPrototype(node.Prototype);
Scope.SetReturnType(prototype.ReturnType); CurrentScope.SetReturnType(prototype.ReturnType);
foreach (var parameter in prototype.Parameters) foreach (var parameter in prototype.Parameters)
{ {
Scope.DeclareVariable(new Variable(parameter.NameToken, parameter.Type)); CurrentScope.DeclareVariable(new Variable(parameter.NameToken, parameter.Type));
} }
var body = node.Body == null ? null : CheckBlock(node.Body); var body = node.Body == null ? null : CheckBlock(node.Body);
@@ -203,7 +113,7 @@ public sealed class TypeChecker
foreach (var field in structSyntax.Fields) foreach (var field in structSyntax.Fields)
{ {
var fieldType = _typeResolver.ResolveType(field.Type, Scope.Module.Value); var fieldType = ResolveType(field.Type);
ExpressionNode? value = null; ExpressionNode? value = null;
if (field.Value != null) if (field.Value != null)
{ {
@@ -212,7 +122,7 @@ public sealed class TypeChecker
{ {
throw new CompileException(Diagnostic throw new CompileException(Diagnostic
.Error($"Type {value.Type} is not assignable to {field.Type} for field {field.NameToken.Value}") .Error($"Type {value.Type} is not assignable to {field.Type} for field {field.NameToken.Value}")
.At(field) .At(field, _syntaxTree.Tokens)
.Build()); .Build());
} }
} }
@@ -220,8 +130,7 @@ public sealed class TypeChecker
fields.Add(new StructFieldNode(field.Tokens, field.NameToken, fieldType, value)); fields.Add(new StructFieldNode(field.Tokens, field.NameToken, fieldType, value));
} }
var currentModule = GetCurrentModule(); var type = new NubStructType(CurrentScope.Module.Name, structSyntax.NameToken.Value, structSyntax.Packed, fields.Select(x => new NubStructFieldType(x.NameToken.Value, x.Type, x.Value != null)).ToList());
var type = new NubStructType(currentModule.Name.Value, structSyntax.NameToken.Value, structSyntax.Packed, fields.Select(x => new NubStructFieldType(x.NameToken.Value, x.Type, x.Value != null)).ToList());
return new StructNode(structSyntax.Tokens, structSyntax.NameToken, type, structSyntax.Packed, fields); return new StructNode(structSyntax.Tokens, structSyntax.NameToken, type, structSyntax.Packed, fields);
} }
@@ -235,7 +144,7 @@ public sealed class TypeChecker
{ {
throw new CompileException(Diagnostic throw new CompileException(Diagnostic
.Error($"Cannot assign {value.Type} to {target.Type}") .Error($"Cannot assign {value.Type} to {target.Type}")
.At(statement.Value) .At(statement.Value, _syntaxTree.Tokens)
.Build()); .Build());
} }
@@ -261,7 +170,7 @@ public sealed class TypeChecker
if (statement.Value != null) if (statement.Value != null)
{ {
var expectedReturnType = Scope.GetReturnType(); var expectedReturnType = CurrentScope.GetReturnType();
value = CheckExpression(statement.Value, expectedReturnType); value = CheckExpression(statement.Value, expectedReturnType);
} }
@@ -275,7 +184,7 @@ public sealed class TypeChecker
return expression switch return expression switch
{ {
FuncCallNode funcCall => new StatementFuncCallNode(statement.Tokens, funcCall), FuncCallNode funcCall => new StatementFuncCallNode(statement.Tokens, funcCall),
_ => throw new CompileException(Diagnostic.Error("Expressions statements can only be function calls").At(statement).Build()) _ => throw new CompileException(Diagnostic.Error("Expressions statements can only be function calls").At(statement, _syntaxTree.Tokens).Build())
}; };
} }
@@ -286,7 +195,7 @@ public sealed class TypeChecker
if (statement.ExplicitType != null) if (statement.ExplicitType != null)
{ {
type = _typeResolver.ResolveType(statement.ExplicitType, Scope.Module.Value); type = ResolveType(statement.ExplicitType);
} }
if (statement.Assignment != null) if (statement.Assignment != null)
@@ -301,7 +210,7 @@ public sealed class TypeChecker
{ {
throw new CompileException(Diagnostic throw new CompileException(Diagnostic
.Error($"Cannot assign {assignmentNode.Type} to variable of type {type}") .Error($"Cannot assign {assignmentNode.Type} to variable of type {type}")
.At(statement.Assignment) .At(statement.Assignment, _syntaxTree.Tokens)
.Build()); .Build());
} }
} }
@@ -310,11 +219,11 @@ public sealed class TypeChecker
{ {
throw new CompileException(Diagnostic throw new CompileException(Diagnostic
.Error($"Cannot infer type of variable {statement.NameToken.Value}") .Error($"Cannot infer type of variable {statement.NameToken.Value}")
.At(statement) .At(statement, _syntaxTree.Tokens)
.Build()); .Build());
} }
Scope.DeclareVariable(new Variable(statement.NameToken, type)); CurrentScope.DeclareVariable(new Variable(statement.NameToken, type));
return new VariableDeclarationNode(statement.Tokens, statement.NameToken, assignmentNode, type); return new VariableDeclarationNode(statement.Tokens, statement.NameToken, assignmentNode, type);
} }
@@ -337,10 +246,10 @@ public sealed class TypeChecker
{ {
using (BeginScope()) using (BeginScope())
{ {
Scope.DeclareVariable(new Variable(forSyntax.ElementNameToken, sliceType.ElementType)); CurrentScope.DeclareVariable(new Variable(forSyntax.ElementNameToken, sliceType.ElementType));
if (forSyntax.IndexNameToken != null) if (forSyntax.IndexNameToken != null)
{ {
Scope.DeclareVariable(new Variable(forSyntax.IndexNameToken, new NubIntType(false, 64))); CurrentScope.DeclareVariable(new Variable(forSyntax.IndexNameToken, new NubIntType(false, 64)));
} }
var body = CheckBlock(forSyntax.Body); var body = CheckBlock(forSyntax.Body);
@@ -351,10 +260,10 @@ public sealed class TypeChecker
{ {
using (BeginScope()) using (BeginScope())
{ {
Scope.DeclareVariable(new Variable(forSyntax.ElementNameToken, constArrayType.ElementType)); CurrentScope.DeclareVariable(new Variable(forSyntax.ElementNameToken, constArrayType.ElementType));
if (forSyntax.IndexNameToken != null) if (forSyntax.IndexNameToken != null)
{ {
Scope.DeclareVariable(new Variable(forSyntax.IndexNameToken, new NubIntType(false, 64))); CurrentScope.DeclareVariable(new Variable(forSyntax.IndexNameToken, new NubIntType(false, 64)));
} }
var body = CheckBlock(forSyntax.Body); var body = CheckBlock(forSyntax.Body);
@@ -365,7 +274,7 @@ public sealed class TypeChecker
{ {
throw new CompileException(Diagnostic throw new CompileException(Diagnostic
.Error($"Cannot iterate over type {target.Type} which does not have size information") .Error($"Cannot iterate over type {target.Type} which does not have size information")
.At(forSyntax.Target) .At(forSyntax.Target, _syntaxTree.Tokens)
.Build()); .Build());
} }
} }
@@ -376,10 +285,10 @@ public sealed class TypeChecker
var parameters = new List<FuncParameterNode>(); var parameters = new List<FuncParameterNode>();
foreach (var parameter in statement.Parameters) foreach (var parameter in statement.Parameters)
{ {
parameters.Add(new FuncParameterNode(parameter.Tokens, parameter.NameToken, _typeResolver.ResolveType(parameter.Type, Scope.Module.Value))); parameters.Add(new FuncParameterNode(parameter.Tokens, parameter.NameToken, ResolveType(parameter.Type)));
} }
return new FuncPrototypeNode(statement.Tokens, statement.NameToken, statement.ExternSymbolToken, parameters, _typeResolver.ResolveType(statement.ReturnType, Scope.Module.Value)); return new FuncPrototypeNode(statement.Tokens, statement.NameToken, statement.ExternSymbolToken, parameters, ResolveType(statement.ReturnType));
} }
private ExpressionNode CheckExpression(ExpressionSyntax node, NubType? expectedType = null) private ExpressionNode CheckExpression(ExpressionSyntax node, NubType? expectedType = null)
@@ -401,7 +310,7 @@ public sealed class TypeChecker
FloatLiteralSyntax expression => CheckFloatLiteral(expression, expectedType), FloatLiteralSyntax expression => CheckFloatLiteral(expression, expectedType),
MemberAccessSyntax expression => CheckMemberAccess(expression, expectedType), MemberAccessSyntax expression => CheckMemberAccess(expression, expectedType),
StructInitializerSyntax expression => CheckStructInitializer(expression, expectedType), StructInitializerSyntax expression => CheckStructInitializer(expression, expectedType),
SizeSyntax expression => new SizeNode(node.Tokens, _typeResolver.ResolveType(expression.Type, Scope.Module.Value)), SizeSyntax expression => new SizeNode(node.Tokens, ResolveType(expression.Type)),
CastSyntax expression => CheckCast(expression, expectedType), CastSyntax expression => CheckCast(expression, expectedType),
_ => throw new ArgumentOutOfRangeException(nameof(node)) _ => throw new ArgumentOutOfRangeException(nameof(node))
}; };
@@ -413,9 +322,9 @@ public sealed class TypeChecker
return result; return result;
} }
if (IsCastAllowed(result.Type, expectedType)) if (IsCastAllowed(result.Type, expectedType, out var conversion))
{ {
return new CastNode(result.Tokens, expectedType, result); return new CastNode(result.Tokens, expectedType, result, conversion);
} }
} }
@@ -428,7 +337,7 @@ public sealed class TypeChecker
{ {
throw new CompileException(Diagnostic throw new CompileException(Diagnostic
.Error("Unable to infer target type of cast") .Error("Unable to infer target type of cast")
.At(expression) .At(expression, _syntaxTree.Tokens)
.WithHelp("Specify target type where value is used") .WithHelp("Specify target type where value is used")
.Build()); .Build());
} }
@@ -439,32 +348,50 @@ public sealed class TypeChecker
{ {
Diagnostics.Add(Diagnostic Diagnostics.Add(Diagnostic
.Warning("Target type of cast is same as the value. Cast is unnecessary") .Warning("Target type of cast is same as the value. Cast is unnecessary")
.At(expression) .At(expression, _syntaxTree.Tokens)
.Build()); .Build());
return value; return value;
} }
if (!IsCastAllowed(value.Type, expectedType, false)) if (!IsCastAllowed(value.Type, expectedType, out var conversion, false))
{ {
throw new CompileException(Diagnostic throw new CompileException(Diagnostic
.Error($"Cannot cast from {value.Type} to {expectedType}") .Error($"Cannot cast from {value.Type} to {expectedType}")
.Build()); .Build());
} }
return new CastNode(expression.Tokens, expectedType, value); return new CastNode(expression.Tokens, expectedType, value, conversion);
} }
private static bool IsCastAllowed(NubType from, NubType to, bool strict = true) private static bool IsCastAllowed(NubType from, NubType to, out CastNode.Conversion conversion, bool strict = true)
{ {
// note(nub31): Implicit casts // note(nub31): Implicit casts
switch (from) switch (from)
{ {
case NubIntType fromInt when to is NubIntType toInt && fromInt.Width < toInt.Width: case NubIntType fromInt when to is NubIntType toInt && fromInt.Width < toInt.Width:
{
conversion = CastNode.Conversion.IntToInt;
return true;
}
case NubPointerType when to is NubPointerType { BaseType: NubVoidType }: case NubPointerType when to is NubPointerType { BaseType: NubVoidType }:
{
conversion = CastNode.Conversion.PointerToPointer;
return true;
}
case NubConstArrayType constArrayType1 when to is NubArrayType arrayType && constArrayType1.ElementType == arrayType.ElementType: case NubConstArrayType constArrayType1 when to is NubArrayType arrayType && constArrayType1.ElementType == arrayType.ElementType:
{
conversion = CastNode.Conversion.ConstArrayToArray;
return true;
}
case NubConstArrayType constArrayType3 when to is NubSliceType sliceType2 && constArrayType3.ElementType == sliceType2.ElementType: case NubConstArrayType constArrayType3 when to is NubSliceType sliceType2 && constArrayType3.ElementType == sliceType2.ElementType:
{ {
conversion = CastNode.Conversion.ConstArrayToSlice;
return true;
}
case NubStringType when to is NubPointerType { BaseType: NubIntType { Signed: true, Width: 8 } }:
{
conversion = CastNode.Conversion.StringToCString;
return true; return true;
} }
} }
@@ -475,19 +402,44 @@ public sealed class TypeChecker
switch (from) switch (from)
{ {
case NubIntType when to is NubIntType: case NubIntType when to is NubIntType:
case NubIntType when to is NubFloatType:
case NubFloatType when to is NubIntType:
case NubFloatType when to is NubFloatType:
case NubPointerType when to is NubPointerType:
case NubPointerType when to is NubIntType:
case NubIntType when to is NubPointerType:
{ {
conversion = CastNode.Conversion.IntToInt;
return true;
}
case NubIntType when to is NubFloatType:
{
conversion = CastNode.Conversion.IntToFloat;
return true;
}
case NubFloatType when to is NubIntType:
{
conversion = CastNode.Conversion.FloatToInt;
return true;
}
case NubFloatType when to is NubFloatType:
{
conversion = CastNode.Conversion.FloatToFloat;
return true;
}
case NubPointerType when to is NubPointerType:
{
conversion = CastNode.Conversion.PointerToPointer;
return true;
}
case NubPointerType when to is NubIntType { Signed: false, Width: 64 }:
{
conversion = CastNode.Conversion.PointerToUInt64;
return true;
}
case NubIntType { Signed: false, Width: 64 } when to is NubPointerType:
{
conversion = CastNode.Conversion.UInt64ToPointer;
return true; return true;
} }
} }
} }
conversion = default;
return false; return false;
} }
@@ -505,7 +457,7 @@ public sealed class TypeChecker
{ {
throw new CompileException(Diagnostic throw new CompileException(Diagnostic
.Error("Array indexer must be of type int") .Error("Array indexer must be of type int")
.At(expression.Index) .At(expression.Index, _syntaxTree.Tokens)
.Build()); .Build());
} }
@@ -516,7 +468,7 @@ public sealed class TypeChecker
NubArrayType arrayType => new ArrayIndexAccessNode(expression.Tokens, arrayType.ElementType, target, index), NubArrayType arrayType => new ArrayIndexAccessNode(expression.Tokens, arrayType.ElementType, target, index),
NubConstArrayType constArrayType => new ConstArrayIndexAccessNode(expression.Tokens, constArrayType.ElementType, target, index), NubConstArrayType constArrayType => new ConstArrayIndexAccessNode(expression.Tokens, constArrayType.ElementType, target, index),
NubSliceType sliceType => new SliceIndexAccessNode(expression.Tokens, sliceType.ElementType, target, index), NubSliceType sliceType => new SliceIndexAccessNode(expression.Tokens, sliceType.ElementType, target, index),
_ => throw new CompileException(Diagnostic.Error($"Cannot use array indexer on type {target.Type}").At(expression).Build()) _ => throw new CompileException(Diagnostic.Error($"Cannot use array indexer on type {target.Type}").At(expression, _syntaxTree.Tokens).Build())
}; };
} }
@@ -543,7 +495,7 @@ public sealed class TypeChecker
{ {
throw new CompileException(Diagnostic throw new CompileException(Diagnostic
.Error("Unable to infer type of array initializer") .Error("Unable to infer type of array initializer")
.At(expression) .At(expression, _syntaxTree.Tokens)
.WithHelp("Provide a type for a variable assignment") .WithHelp("Provide a type for a variable assignment")
.Build()); .Build());
} }
@@ -556,7 +508,7 @@ public sealed class TypeChecker
{ {
throw new CompileException(Diagnostic throw new CompileException(Diagnostic
.Error("Value in array initializer is not the same as the array type") .Error("Value in array initializer is not the same as the array type")
.At(valueExpression) .At(valueExpression, _syntaxTree.Tokens)
.Build()); .Build());
} }
@@ -601,7 +553,7 @@ public sealed class TypeChecker
{ {
throw new CompileException(Diagnostic throw new CompileException(Diagnostic
.Error("Equal and not equal operators must must be used with int, float or bool types") .Error("Equal and not equal operators must must be used with int, float or bool types")
.At(expression.Left) .At(expression.Left, _syntaxTree.Tokens)
.Build()); .Build());
} }
@@ -610,7 +562,7 @@ public sealed class TypeChecker
{ {
throw new CompileException(Diagnostic throw new CompileException(Diagnostic
.Error($"Expected type {left.Type} from left side of binary expression, but got {right.Type}") .Error($"Expected type {left.Type} from left side of binary expression, but got {right.Type}")
.At(expression.Right) .At(expression.Right, _syntaxTree.Tokens)
.Build()); .Build());
} }
@@ -626,7 +578,7 @@ public sealed class TypeChecker
{ {
throw new CompileException(Diagnostic throw new CompileException(Diagnostic
.Error("Greater than and less than operators must must be used with int or float types") .Error("Greater than and less than operators must must be used with int or float types")
.At(expression.Left) .At(expression.Left, _syntaxTree.Tokens)
.Build()); .Build());
} }
@@ -635,7 +587,7 @@ public sealed class TypeChecker
{ {
throw new CompileException(Diagnostic throw new CompileException(Diagnostic
.Error($"Expected type {left.Type} from left side of binary expression, but got {right.Type}") .Error($"Expected type {left.Type} from left side of binary expression, but got {right.Type}")
.At(expression.Right) .At(expression.Right, _syntaxTree.Tokens)
.Build()); .Build());
} }
@@ -649,7 +601,7 @@ public sealed class TypeChecker
{ {
throw new CompileException(Diagnostic throw new CompileException(Diagnostic
.Error("Logical and/or must must be used with bool types") .Error("Logical and/or must must be used with bool types")
.At(expression.Left) .At(expression.Left, _syntaxTree.Tokens)
.Build()); .Build());
} }
@@ -658,7 +610,7 @@ public sealed class TypeChecker
{ {
throw new CompileException(Diagnostic throw new CompileException(Diagnostic
.Error($"Expected type {left.Type} from left side of binary expression, but got {right.Type}") .Error($"Expected type {left.Type} from left side of binary expression, but got {right.Type}")
.At(expression.Right) .At(expression.Right, _syntaxTree.Tokens)
.Build()); .Build());
} }
@@ -671,7 +623,7 @@ public sealed class TypeChecker
{ {
throw new CompileException(Diagnostic throw new CompileException(Diagnostic
.Error("The plus operator must only be used with int and float types") .Error("The plus operator must only be used with int and float types")
.At(expression.Left) .At(expression.Left, _syntaxTree.Tokens)
.Build()); .Build());
} }
@@ -696,7 +648,7 @@ public sealed class TypeChecker
{ {
throw new CompileException(Diagnostic throw new CompileException(Diagnostic
.Error("Math operators must be used with int or float types") .Error("Math operators must be used with int or float types")
.At(expression.Left) .At(expression.Left, _syntaxTree.Tokens)
.Build()); .Build());
} }
@@ -705,7 +657,7 @@ public sealed class TypeChecker
{ {
throw new CompileException(Diagnostic throw new CompileException(Diagnostic
.Error($"Expected type {left.Type} from left side of binary expression, but got {right.Type}") .Error($"Expected type {left.Type} from left side of binary expression, but got {right.Type}")
.At(expression.Right) .At(expression.Right, _syntaxTree.Tokens)
.Build()); .Build());
} }
@@ -719,7 +671,7 @@ public sealed class TypeChecker
{ {
throw new CompileException(Diagnostic throw new CompileException(Diagnostic
.Error("Bitwise operators must be used with int types") .Error("Bitwise operators must be used with int types")
.At(expression.Left) .At(expression.Left, _syntaxTree.Tokens)
.Build()); .Build());
} }
@@ -728,7 +680,7 @@ public sealed class TypeChecker
{ {
throw new CompileException(Diagnostic throw new CompileException(Diagnostic
.Error("Bitwise operators must be used with int types") .Error("Bitwise operators must be used with int types")
.At(expression.Right) .At(expression.Right, _syntaxTree.Tokens)
.Build()); .Build());
} }
@@ -743,7 +695,7 @@ public sealed class TypeChecker
{ {
throw new CompileException(Diagnostic throw new CompileException(Diagnostic
.Error("Bitwise operators must be used with int types") .Error("Bitwise operators must be used with int types")
.At(expression.Left) .At(expression.Left, _syntaxTree.Tokens)
.Build()); .Build());
} }
@@ -752,7 +704,7 @@ public sealed class TypeChecker
{ {
throw new CompileException(Diagnostic throw new CompileException(Diagnostic
.Error($"Expected type {left.Type} from left side of binary expression, but got {right.Type}") .Error($"Expected type {left.Type} from left side of binary expression, but got {right.Type}")
.At(expression.Right) .At(expression.Right, _syntaxTree.Tokens)
.Build()); .Build());
} }
@@ -776,7 +728,7 @@ public sealed class TypeChecker
{ {
throw new CompileException(Diagnostic throw new CompileException(Diagnostic
.Error("Negation operator must be used with signed integer or float types") .Error("Negation operator must be used with signed integer or float types")
.At(expression) .At(expression, _syntaxTree.Tokens)
.Build()); .Build());
} }
@@ -789,7 +741,7 @@ public sealed class TypeChecker
{ {
throw new CompileException(Diagnostic throw new CompileException(Diagnostic
.Error("Invert operator must be used with booleans") .Error("Invert operator must be used with booleans")
.At(expression) .At(expression, _syntaxTree.Tokens)
.Build()); .Build());
} }
@@ -809,7 +761,7 @@ public sealed class TypeChecker
{ {
throw new CompileException(Diagnostic throw new CompileException(Diagnostic
.Error($"Cannot dereference non-pointer type {target.Type}") .Error($"Cannot dereference non-pointer type {target.Type}")
.At(expression) .At(expression, _syntaxTree.Tokens)
.Build()); .Build());
} }
@@ -821,14 +773,14 @@ public sealed class TypeChecker
var accessor = CheckExpression(expression.Expression); var accessor = CheckExpression(expression.Expression);
if (accessor.Type is not NubFuncType funcType) if (accessor.Type is not NubFuncType funcType)
{ {
throw new CompileException(Diagnostic.Error($"Cannot call non-function type {accessor.Type}").At(expression.Expression).Build()); throw new CompileException(Diagnostic.Error($"Cannot call non-function type {accessor.Type}").At(expression.Expression, _syntaxTree.Tokens).Build());
} }
if (expression.Parameters.Count != funcType.Parameters.Count) if (expression.Parameters.Count != funcType.Parameters.Count)
{ {
throw new CompileException(Diagnostic throw new CompileException(Diagnostic
.Error($"Function {funcType} expects {funcType.Parameters.Count} parameters but got {expression.Parameters.Count}") .Error($"Function {funcType} expects {funcType.Parameters.Count} parameters but got {expression.Parameters.Count}")
.At(expression.Parameters.LastOrDefault(expression)) .At(expression.Parameters.LastOrDefault(expression), _syntaxTree.Tokens)
.Build()); .Build());
} }
@@ -843,7 +795,7 @@ public sealed class TypeChecker
{ {
throw new CompileException(Diagnostic throw new CompileException(Diagnostic
.Error($"Parameter {i + 1} does not match the type {expectedParameterType} for function {funcType}") .Error($"Parameter {i + 1} does not match the type {expectedParameterType} for function {funcType}")
.At(parameter) .At(parameter, _syntaxTree.Tokens)
.Build()); .Build());
} }
@@ -856,67 +808,40 @@ public sealed class TypeChecker
private ExpressionNode CheckLocalIdentifier(LocalIdentifierSyntax expression, NubType? _) private ExpressionNode CheckLocalIdentifier(LocalIdentifierSyntax expression, NubType? _)
{ {
// note(nub31): Local identifiers can be variables or a symbol in a module // note(nub31): Local identifiers can be variables or a symbol in a module
var scopeIdent = Scope.LookupVariable(expression.NameToken); var scopeIdent = CurrentScope.LookupVariable(expression.NameToken);
if (scopeIdent != null) if (scopeIdent != null)
{ {
return new VariableIdentifierNode(expression.Tokens, scopeIdent.Type, expression.NameToken); return new VariableIdentifierNode(expression.Tokens, scopeIdent.Type, expression.NameToken);
} }
var module = GetImportedModule(Scope.Module.Value)!; if (CurrentScope.Module.TryResolveFunc(expression.NameToken, out var function, out var _))
var function = module.Functions(true).FirstOrDefault(x => x.NameToken.Value == expression.NameToken.Value);
if (function != null)
{ {
var parameters = function.Prototype.Parameters.Select(x => _typeResolver.ResolveType(x.Type, Scope.Module.Value)).ToList(); var type = new NubFuncType(function.Parameters.Select(x => x.Type).ToList(), function.ReturnType);
var type = new NubFuncType(parameters, _typeResolver.ResolveType(function.Prototype.ReturnType, Scope.Module.Value)); return new LocalFuncIdentifierNode(expression.Tokens, type, expression.NameToken, function.ExternSymbolToken);
return new FuncIdentifierNode(expression.Tokens, type, Scope.Module, expression.NameToken, function.Prototype.ExternSymbolToken);
}
var enumDef = module.Enums(true).FirstOrDefault(x => x.NameToken.Value == expression.NameToken.Value);
if (enumDef != null)
{
return new EnumReferenceIntermediateNode(expression.Tokens, Scope.Module, expression.NameToken);
} }
throw new CompileException(Diagnostic throw new CompileException(Diagnostic
.Error($"There is no identifier named {expression.NameToken.Value}") .Error($"There is no identifier named {expression.NameToken.Value}")
.At(expression) .At(expression, _syntaxTree.Tokens)
.Build()); .Build());
} }
private ExpressionNode CheckModuleIdentifier(ModuleIdentifierSyntax expression, NubType? _) private ExpressionNode CheckModuleIdentifier(ModuleIdentifierSyntax expression, NubType? _)
{ {
var module = GetImportedModule(expression.ModuleToken.Value); var module = _repository.Get(expression.ModuleToken);
if (module == null) using (BeginRootScope(module))
{ {
if (module.TryResolveFunc(expression.NameToken, out var function, out var _))
{
var type = new NubFuncType(function.Parameters.Select(x => x.Type).ToList(), function.ReturnType);
return new ModuleFuncIdentifierNode(expression.Tokens, type, expression.ModuleToken, expression.NameToken, function.ExternSymbolToken);
}
throw new CompileException(Diagnostic throw new CompileException(Diagnostic
.Error($"Module {expression.ModuleToken.Value} not found") .Error($"Module {expression.ModuleToken.Value} does not export a member named {expression.NameToken.Value}")
.WithHelp($"import \"{expression.ModuleToken.Value}\"") .At(expression, _syntaxTree.Tokens)
.At(expression.ModuleToken)
.Build()); .Build());
} }
var function = module.Functions(false).FirstOrDefault(x => x.NameToken.Value == expression.NameToken.Value);
if (function != null)
{
using (BeginRootScope(expression.ModuleToken))
{
var parameters = function.Prototype.Parameters.Select(x => _typeResolver.ResolveType(x.Type, Scope.Module.Value)).ToList();
var type = new NubFuncType(parameters, _typeResolver.ResolveType(function.Prototype.ReturnType, Scope.Module.Value));
return new FuncIdentifierNode(expression.Tokens, type, expression.ModuleToken, expression.NameToken, function.Prototype.ExternSymbolToken);
}
}
var enumDef = module.Enums(false).FirstOrDefault(x => x.NameToken.Value == expression.NameToken.Value);
if (enumDef != null)
{
return new EnumReferenceIntermediateNode(expression.Tokens, expression.ModuleToken, expression.NameToken);
}
throw new CompileException(Diagnostic
.Error($"Module {expression.ModuleToken.Value} does not export a member named {expression.NameToken.Value}")
.At(expression)
.Build());
} }
private ExpressionNode CheckStringLiteral(StringLiteralSyntax expression, NubType? expectedType) private ExpressionNode CheckStringLiteral(StringLiteralSyntax expression, NubType? expectedType)
@@ -980,55 +905,6 @@ public sealed class TypeChecker
{ {
var target = CheckExpression(expression.Target); var target = CheckExpression(expression.Target);
if (target is EnumReferenceIntermediateNode enumReferenceIntermediate)
{
var enumDef = GetImportedModules()
.First(x => x.Name.Value == enumReferenceIntermediate.ModuleToken.Value)
.Module
.Enums(IsCurrentModule(enumReferenceIntermediate.ModuleToken))
.First(x => x.NameToken.Value == enumReferenceIntermediate.NameToken.Value);
var field = enumDef.Fields.FirstOrDefault(x => x.NameToken.Value == expression.MemberToken.Value);
if (field == null)
{
throw new CompileException(Diagnostic
.Error($"Enum {Scope.Module.Value}::{enumReferenceIntermediate.NameToken.Value} does not have a field named {expression.MemberToken.Value}")
.At(enumDef)
.Build());
}
var enumType = enumDef.Type != null ? _typeResolver.ResolveType(enumDef.Type, Scope.Module.Value) : new NubIntType(false, 64);
if (enumType is not NubIntType enumIntType)
{
throw new CompileException(Diagnostic.Error("Enum type must be an int type").At(enumDef.Type).Build());
}
if (enumIntType.Signed)
{
var fieldValue = CalculateSignedEnumFieldValue(enumDef, field);
return enumIntType.Width switch
{
8 => new I8LiteralNode(expression.Tokens, (sbyte)fieldValue),
16 => new I16LiteralNode(expression.Tokens, (short)fieldValue),
32 => new I32LiteralNode(expression.Tokens, (int)fieldValue),
64 => new I64LiteralNode(expression.Tokens, fieldValue),
_ => throw new ArgumentOutOfRangeException()
};
}
else
{
var fieldValue = CalculateUnsignedEnumFieldValue(enumDef, field);
return enumIntType.Width switch
{
8 => new U8LiteralNode(expression.Tokens, (byte)fieldValue),
16 => new U16LiteralNode(expression.Tokens, (ushort)fieldValue),
32 => new U32LiteralNode(expression.Tokens, (uint)fieldValue),
64 => new U64LiteralNode(expression.Tokens, fieldValue),
_ => throw new ArgumentOutOfRangeException()
};
}
}
switch (target.Type) switch (target.Type)
{ {
case NubStructType structType: case NubStructType structType:
@@ -1038,7 +914,7 @@ public sealed class TypeChecker
{ {
throw new CompileException(Diagnostic throw new CompileException(Diagnostic
.Error($"Struct {target.Type} does not have a field with the name {expression.MemberToken.Value}") .Error($"Struct {target.Type} does not have a field with the name {expression.MemberToken.Value}")
.At(expression) .At(expression, _syntaxTree.Tokens)
.Build()); .Build());
} }
@@ -1048,63 +924,19 @@ public sealed class TypeChecker
{ {
throw new CompileException(Diagnostic throw new CompileException(Diagnostic
.Error($"Cannot access struct member {expression.MemberToken.Value} on type {target.Type}") .Error($"Cannot access struct member {expression.MemberToken.Value} on type {target.Type}")
.At(expression) .At(expression, _syntaxTree.Tokens)
.Build()); .Build());
} }
} }
} }
private static long CalculateSignedEnumFieldValue(EnumSyntax enumDef, EnumFieldSyntax field)
{
long currentValue = 0;
foreach (var f in enumDef.Fields)
{
if (f.ValueToken != null)
{
currentValue = f.ValueToken.AsI64;
}
if (f == field)
{
return currentValue;
}
currentValue++;
}
throw new UnreachableException();
}
private static ulong CalculateUnsignedEnumFieldValue(EnumSyntax enumDef, EnumFieldSyntax field)
{
ulong currentValue = 0;
foreach (var f in enumDef.Fields)
{
if (f.ValueToken != null)
{
currentValue = f.ValueToken.AsU64;
}
if (f == field)
{
return currentValue;
}
currentValue++;
}
throw new UnreachableException();
}
private StructInitializerNode CheckStructInitializer(StructInitializerSyntax expression, NubType? expectedType) private StructInitializerNode CheckStructInitializer(StructInitializerSyntax expression, NubType? expectedType)
{ {
NubStructType? structType = null; NubStructType? structType = null;
if (expression.StructType != null) if (expression.StructType != null)
{ {
var checkedType = _typeResolver.ResolveType(expression.StructType, Scope.Module.Value); var checkedType = ResolveType(expression.StructType);
if (checkedType is not NubStructType checkedStructType) if (checkedType is not NubStructType checkedStructType)
{ {
throw new UnreachableException("Parser fucked up"); throw new UnreachableException("Parser fucked up");
@@ -1123,7 +955,7 @@ public sealed class TypeChecker
throw new CompileException(Diagnostic throw new CompileException(Diagnostic
.Error("Cannot get implicit type of struct") .Error("Cannot get implicit type of struct")
.WithHelp("Specify struct type with struct {type_name} syntax") .WithHelp("Specify struct type with struct {type_name} syntax")
.At(expression) .At(expression, _syntaxTree.Tokens)
.Build()); .Build());
} }
@@ -1136,7 +968,7 @@ public sealed class TypeChecker
{ {
Diagnostics.AddRange(Diagnostic Diagnostics.AddRange(Diagnostic
.Error($"Struct {structType.Name} does not have a field named {initializer.Key}") .Error($"Struct {structType.Name} does not have a field named {initializer.Key}")
.At(initializer.Value) .At(initializer.Value, _syntaxTree.Tokens)
.Build()); .Build());
continue; continue;
@@ -1154,7 +986,7 @@ public sealed class TypeChecker
{ {
Diagnostics.Add(Diagnostic Diagnostics.Add(Diagnostic
.Warning($"Fields {string.Join(", ", missingFields)} are not initialized") .Warning($"Fields {string.Join(", ", missingFields)} are not initialized")
.At(expression) .At(expression, _syntaxTree.Tokens)
.Build()); .Build());
} }
@@ -1200,44 +1032,85 @@ public sealed class TypeChecker
_ => throw new ArgumentOutOfRangeException(nameof(statement)) _ => throw new ArgumentOutOfRangeException(nameof(statement))
}; };
} }
}
public record Variable(IdentifierToken Name, NubType Type); private NubType ResolveType(TypeSyntax type)
public class Scope(IdentifierToken module, Scope? parent = null)
{
private NubType? _returnType;
private readonly List<Variable> _variables = [];
public IdentifierToken Module { get; } = module;
public void DeclareVariable(Variable variable)
{ {
_variables.Add(variable); return type switch
}
public void SetReturnType(NubType returnType)
{
_returnType = returnType;
}
public NubType? GetReturnType()
{
return _returnType ?? parent?.GetReturnType();
}
public Variable? LookupVariable(IdentifierToken name)
{
var variable = _variables.FirstOrDefault(x => x.Name.Value == name.Value);
if (variable != null)
{ {
return variable; ArrayTypeSyntax arr => new NubArrayType(ResolveType(arr.BaseType)),
BoolTypeSyntax => new NubBoolType(),
IntTypeSyntax i => new NubIntType(i.Signed, i.Width),
FloatTypeSyntax f => new NubFloatType(f.Width),
FuncTypeSyntax func => new NubFuncType(func.Parameters.Select(ResolveType).ToList(), ResolveType(func.ReturnType)),
SliceTypeSyntax slice => new NubSliceType(ResolveType(slice.BaseType)),
ConstArrayTypeSyntax arr => new NubConstArrayType(ResolveType(arr.BaseType), arr.Size),
PointerTypeSyntax ptr => new NubPointerType(ResolveType(ptr.BaseType)),
StringTypeSyntax => new NubStringType(),
CustomTypeSyntax c => ResolveCustomType(c),
VoidTypeSyntax => new NubVoidType(),
_ => throw new NotSupportedException($"Unknown type syntax: {type}")
};
}
private NubType ResolveCustomType(CustomTypeSyntax customType)
{
var module = customType.ModuleToken != null ? _repository.Get(customType.ModuleToken) : CurrentScope.Module;
var structType = module.StructTypes.FirstOrDefault(x => x.Name == customType.NameToken.Value);
if (structType != null)
{
return structType;
} }
return parent?.LookupVariable(name); var enumType = module.EnumTypes.GetValueOrDefault(customType.NameToken.Value);
if (enumType != null)
{
return enumType;
}
throw new CompileException(Diagnostic
.Error($"Type {customType.NameToken.Value} not found in module {module.Name}")
.At(customType, _syntaxTree.Tokens)
.Build());
} }
public Scope SubScope() private record Variable(IdentifierToken Name, NubType Type);
private class Scope(ModuleRepository.Module module, Scope? parent = null)
{ {
return new Scope(Module, this); private NubType? _returnType;
private readonly List<Variable> _variables = [];
public ModuleRepository.Module Module { get; } = module;
public void DeclareVariable(Variable variable)
{
_variables.Add(variable);
}
public void SetReturnType(NubType returnType)
{
_returnType = returnType;
}
public NubType? GetReturnType()
{
return _returnType ?? parent?.GetReturnType();
}
public Variable? LookupVariable(IdentifierToken name)
{
var variable = _variables.FirstOrDefault(x => x.Name.Value == name.Value);
if (variable != null)
{
return variable;
}
return parent?.LookupVariable(name);
}
public Scope SubScope()
{
return new Scope(Module, this);
}
} }
} }

View File

@@ -1,96 +0,0 @@
using NubLang.Diagnostics;
using NubLang.Syntax;
namespace NubLang.Ast;
public class TypeResolver
{
private readonly Dictionary<string, Module> _modules;
private readonly Dictionary<(string Module, string Name), NubType> _typeCache = new();
private readonly HashSet<(string Module, string Name)> _resolvingTypes = [];
public TypeResolver(Dictionary<string, Module> modules)
{
_modules = modules;
}
public NubType ResolveType(TypeSyntax type, string currentModule)
{
return type switch
{
ArrayTypeSyntax arr => new NubArrayType(ResolveType(arr.BaseType, currentModule)),
BoolTypeSyntax => new NubBoolType(),
IntTypeSyntax i => new NubIntType(i.Signed, i.Width),
FloatTypeSyntax f => new NubFloatType(f.Width),
FuncTypeSyntax func => new NubFuncType(func.Parameters.Select(x => ResolveType(x, currentModule)).ToList(), ResolveType(func.ReturnType, currentModule)),
SliceTypeSyntax slice => new NubSliceType(ResolveType(slice.BaseType, currentModule)),
ConstArrayTypeSyntax arr => new NubConstArrayType(ResolveType(arr.BaseType, currentModule), arr.Size),
PointerTypeSyntax ptr => new NubPointerType(ResolveType(ptr.BaseType, currentModule)),
StringTypeSyntax => new NubStringType(),
CustomTypeSyntax c => ResolveCustomType(c, currentModule),
VoidTypeSyntax => new NubVoidType(),
_ => throw new NotSupportedException($"Unknown type syntax: {type}")
};
}
private NubType ResolveCustomType(CustomTypeSyntax customType, string currentModule)
{
var module = _modules[customType.ModuleToken?.Value ?? currentModule];
var enumDef = module.Enums(true).FirstOrDefault(x => x.NameToken.Value == customType.NameToken.Value);
if (enumDef != null)
{
return enumDef.Type != null ? ResolveType(enumDef.Type, currentModule) : new NubIntType(false, 64);
}
var structDef = module.Structs(true).FirstOrDefault(x => x.NameToken.Value == customType.NameToken.Value);
if (structDef != null)
{
var key = (customType.ModuleToken?.Value ?? currentModule, customType.NameToken.Value);
if (_typeCache.TryGetValue(key, out var cachedType))
{
return cachedType;
}
if (!_resolvingTypes.Add(key))
{
var placeholder = new NubStructType(customType.ModuleToken?.Value ?? currentModule, customType.NameToken.Value, structDef.Packed, []);
_typeCache[key] = placeholder;
return placeholder;
}
try
{
var result = new NubStructType(customType.ModuleToken?.Value ?? currentModule, structDef.NameToken.Value, structDef.Packed, []);
_typeCache[key] = result;
var fields = structDef.Fields
.Select(x => new NubStructFieldType(x.NameToken.Value, ResolveType(x.Type, currentModule), x.Value != null))
.ToList();
result.Fields.AddRange(fields);
return result;
}
finally
{
_resolvingTypes.Remove(key);
}
}
throw new TypeResolverException(Diagnostic
.Error($"Type {customType.NameToken.Value} not found in module {customType.ModuleToken?.Value ?? currentModule}")
.At(customType)
.Build());
}
}
public class TypeResolverException : Exception
{
public Diagnostic Diagnostic { get; }
public TypeResolverException(Diagnostic diagnostic) : base(diagnostic.Message)
{
Diagnostic = diagnostic;
}
}

View File

@@ -11,6 +11,7 @@ public class Diagnostic
private readonly string _message; private readonly string _message;
private SourceSpan? _span; private SourceSpan? _span;
private string? _help; private string? _help;
private List<Token>? _tokens;
public DiagnosticBuilder(DiagnosticSeverity severity, string message) public DiagnosticBuilder(DiagnosticSeverity severity, string message)
{ {
@@ -18,18 +19,32 @@ public class Diagnostic
_message = message; _message = message;
} }
public DiagnosticBuilder At(SyntaxNode? node) public DiagnosticBuilder At(SyntaxNode? node, List<Token>? tokens = null)
{ {
if (tokens != null)
{
_tokens = tokens;
}
if (node != null) if (node != null)
{ {
_span = SourceSpan.Merge(node.Tokens.Select(x => x.Span)); var first = node.Tokens.FirstOrDefault();
if (first != null)
{
_span = SourceSpan.Merge(node.Tokens.Select(x => x.Span));
}
} }
return this; return this;
} }
public DiagnosticBuilder At(Token? token) public DiagnosticBuilder At(Token? token, List<Token>? tokens = null)
{ {
if (tokens != null)
{
_tokens = tokens;
}
if (token != null) if (token != null)
{ {
At(token.Span); At(token.Span);
@@ -48,11 +63,11 @@ public class Diagnostic
return this; return this;
} }
public DiagnosticBuilder At(string filePath, int line, int column) // public DiagnosticBuilder At(string filePath, int line, int column)
{ // {
_span = new SourceSpan(filePath, new SourceLocation(line, column), new SourceLocation(line, column)); // _span = new SourceSpan(filePath, new SourceLocation(line, column), new SourceLocation(line, column));
return this; // return this;
} // }
public DiagnosticBuilder WithHelp(string help) public DiagnosticBuilder WithHelp(string help)
{ {
@@ -60,20 +75,23 @@ public class Diagnostic
return this; return this;
} }
public Diagnostic Build() => new(_severity, _message, _help, _span); public Diagnostic Build() => new(_severity, _message, _help, _span, _tokens);
} }
public static DiagnosticBuilder Error(string message) => new(DiagnosticSeverity.Error, message); public static DiagnosticBuilder Error(string message) => new(DiagnosticSeverity.Error, message);
public static DiagnosticBuilder Warning(string message) => new(DiagnosticSeverity.Warning, message); public static DiagnosticBuilder Warning(string message) => new(DiagnosticSeverity.Warning, message);
public static DiagnosticBuilder Info(string message) => new(DiagnosticSeverity.Info, message); public static DiagnosticBuilder Info(string message) => new(DiagnosticSeverity.Info, message);
private readonly List<Token>? _tokens;
public DiagnosticSeverity Severity { get; } public DiagnosticSeverity Severity { get; }
public string Message { get; } public string Message { get; }
public string? Help { get; } public string? Help { get; }
public SourceSpan? Span { get; } public SourceSpan? Span { get; }
private Diagnostic(DiagnosticSeverity severity, string message, string? help, SourceSpan? span) private Diagnostic(DiagnosticSeverity severity, string message, string? help, SourceSpan? span, List<Token>? tokens)
{ {
_tokens = tokens;
Severity = severity; Severity = severity;
Message = message; Message = message;
Help = help; Help = help;
@@ -103,15 +121,12 @@ public class Diagnostic
if (Span.HasValue) if (Span.HasValue)
{ {
sb.AppendLine(); sb.AppendLine();
var text = File.ReadAllText(Span.Value.FilePath); var text = Span.Value.Source;
var tokenizer = new Tokenizer(Span.Value.FilePath, text);
tokenizer.Tokenize();
var lines = text.Split('\n'); var lines = text.Split('\n');
var startLine = Span.Value.Start.Line; var startLine = Span.Value.StartLine;
var endLine = Span.Value.End.Line; var endLine = Span.Value.EndLine;
const int CONTEXT_LINES = 3; const int CONTEXT_LINES = 3;
@@ -144,24 +159,31 @@ public class Diagnostic
sb.Append("│ "); sb.Append("│ ");
sb.Append(i.ToString().PadRight(numberPadding)); sb.Append(i.ToString().PadRight(numberPadding));
sb.Append(" │ "); sb.Append(" │ ");
sb.Append(ApplySyntaxHighlighting(line.PadRight(codePadding), i, tokenizer.Tokens)); if (_tokens != null)
// sb.Append(line.PadRight(codePadding)); {
sb.Append(ApplySyntaxHighlighting(line.PadRight(codePadding), i, _tokens));
}
else
{
sb.Append(line.PadRight(codePadding));
}
sb.Append(" │"); sb.Append(" │");
sb.AppendLine(); sb.AppendLine();
if (i >= startLine && i <= endLine) if (i >= startLine && i <= endLine)
{ {
var markerStartColumn = 1; var markerStartColumn = 1;
var markerEndColumn = line.Length; var markerEndColumn = line.Length + 1;
if (i == startLine) if (i == startLine)
{ {
markerStartColumn = Span.Value.Start.Column; markerStartColumn = Span.Value.StartColumn;
} }
if (i == endLine) if (i == endLine)
{ {
markerEndColumn = Span.Value.End.Column; markerEndColumn = Span.Value.EndColumn;
} }
var markerLength = markerEndColumn - markerStartColumn; var markerLength = markerEndColumn - markerStartColumn;
@@ -206,8 +228,8 @@ public class Diagnostic
{ {
var sb = new StringBuilder(); var sb = new StringBuilder();
var lineTokens = tokens var lineTokens = tokens
.Where(t => t.Span.Start.Line == lineNumber) .Where(t => t.Span.StartLine == lineNumber)
.OrderBy(t => t.Span.Start.Column) .OrderBy(t => t.Span.StartColumn)
.ToList(); .ToList();
if (lineTokens.Count == 0) if (lineTokens.Count == 0)
@@ -219,8 +241,10 @@ public class Diagnostic
foreach (var token in lineTokens) foreach (var token in lineTokens)
{ {
var tokenStart = token.Span.Start.Column; if (token is WhitespaceToken) continue;
var tokenEnd = token.Span.End.Column;
var tokenStart = token.Span.StartColumn;
var tokenEnd = token.Span.EndColumn;
if (tokenStart > currentColumn && currentColumn - 1 < line.Length) if (tokenStart > currentColumn && currentColumn - 1 < line.Length)
{ {
@@ -262,6 +286,10 @@ public class Diagnostic
{ {
switch (token) switch (token)
{ {
case CommentToken:
{
return ConsoleColors.Colorize(tokenText, ConsoleColors.Green);
}
case IdentifierToken: case IdentifierToken:
{ {
return ConsoleColors.Colorize(tokenText, ConsoleColors.BrightWhite); return ConsoleColors.Colorize(tokenText, ConsoleColors.BrightWhite);

View File

@@ -1,112 +1,56 @@
namespace NubLang.Diagnostics; namespace NubLang.Diagnostics;
public readonly struct SourceSpan : IEquatable<SourceSpan>, IComparable<SourceSpan> public readonly struct SourceSpan
{ {
private readonly int _startIndex;
private readonly int _endIndex;
public static SourceSpan Merge(params IEnumerable<SourceSpan> spans) public static SourceSpan Merge(params IEnumerable<SourceSpan> spans)
{ {
var spanArray = spans as SourceSpan[] ?? spans.ToArray(); var spanArray = spans as SourceSpan[] ?? spans.ToArray();
if (spanArray.Length == 0) if (spanArray.Length == 0)
{ {
return new SourceSpan(string.Empty, new SourceLocation(0, 0), new SourceLocation(0, 0)); return new SourceSpan(string.Empty, string.Empty, 0, 0, 0, 0, 0, 0);
} }
var minStart = spanArray.Min(s => s.Start); var first = spanArray.MinBy(x => x._startIndex);
var maxEnd = spanArray.Max(s => s.End); var last = spanArray.MaxBy(x => x._endIndex);
return new SourceSpan(spanArray[0].FilePath, minStart, maxEnd); return new SourceSpan(first.SourcePath, first.Source, first._startIndex, last._endIndex, first.StartLine, last.EndLine, first.StartColumn, last.EndColumn);
} }
public SourceSpan(string filePath, SourceLocation start, SourceLocation end) public SourceSpan(string sourcePath, string source, int startIndex, int endIndex, int startLine, int startColumn, int endLine, int endColumn)
{ {
if (start > end) _startIndex = startIndex;
{ _endIndex = endIndex;
throw new ArgumentException("Start location cannot be after end location"); SourcePath = sourcePath;
} Source = source;
StartLine = startLine;
FilePath = filePath; StartColumn = startColumn;
Start = start; EndLine = endLine;
End = end; EndColumn = endColumn;
} }
public string FilePath { get; } public int StartLine { get; }
public SourceLocation Start { get; } public int StartColumn { get; }
public SourceLocation End { get; } public int EndLine { get; }
public int EndColumn { get; }
public string SourcePath { get; }
public string Source { get; }
public override string ToString() public override string ToString()
{ {
if (Start == End) if (StartLine == EndLine && StartColumn == EndColumn)
{ {
return $"{FilePath}:{Start}"; return $"{SourcePath}:{StartColumn}:{StartColumn}";
} }
if (Start.Line == End.Line) if (StartLine == EndLine)
{ {
return Start.Column == End.Column ? $"{FilePath}:{Start}" : $"{FilePath}:{Start.Line}:{Start.Column}-{End.Column}"; return $"{SourcePath}:{StartLine}:{StartColumn}-{EndColumn}";
} }
return $"{FilePath}:{Start}-{End}"; return $"{SourcePath}:{StartLine}:{StartColumn}-{EndLine}:{EndColumn}";
}
public bool Equals(SourceSpan other) => Start == other.Start && End == other.End;
public override bool Equals(object? obj) => obj is SourceSpan other && Equals(other);
public override int GetHashCode() => HashCode.Combine(typeof(SourceSpan), Start, End);
public static bool operator ==(SourceSpan left, SourceSpan right) => Equals(left, right);
public static bool operator !=(SourceSpan left, SourceSpan right) => !Equals(left, right);
public static bool operator <(SourceSpan left, SourceSpan right) => left.CompareTo(right) < 0;
public static bool operator <=(SourceSpan left, SourceSpan right) => left.CompareTo(right) <= 0;
public static bool operator >(SourceSpan left, SourceSpan right) => left.CompareTo(right) > 0;
public static bool operator >=(SourceSpan left, SourceSpan right) => left.CompareTo(right) >= 0;
public int CompareTo(SourceSpan other)
{
var startComparison = Start.CompareTo(other.Start);
return startComparison != 0 ? startComparison : End.CompareTo(other.End);
}
}
public readonly struct SourceLocation : IEquatable<SourceLocation>, IComparable<SourceLocation>
{
public SourceLocation(int line, int column)
{
Line = line;
Column = column;
}
public int Line { get; }
public int Column { get; }
public override string ToString()
{
return $"{Line}:{Column}";
}
public override bool Equals(object? obj)
{
return obj is SourceLocation other && Equals(other);
}
public bool Equals(SourceLocation other)
{
return Line == other.Line && Column == other.Column;
}
public override int GetHashCode()
{
return HashCode.Combine(typeof(SourceLocation), Line, Column);
}
public static bool operator ==(SourceLocation left, SourceLocation right) => Equals(left, right);
public static bool operator !=(SourceLocation left, SourceLocation right) => !Equals(left, right);
public static bool operator <(SourceLocation left, SourceLocation right) => left.Line < right.Line || (left.Line == right.Line && left.Column < right.Column);
public static bool operator >(SourceLocation left, SourceLocation right) => left.Line > right.Line || (left.Line == right.Line && left.Column > right.Column);
public static bool operator <=(SourceLocation left, SourceLocation right) => left.Line <= right.Line || (left.Line == right.Line && left.Column <= right.Column);
public static bool operator >=(SourceLocation left, SourceLocation right) => left.Line >= right.Line || (left.Line == right.Line && left.Column >= right.Column);
public int CompareTo(SourceLocation other)
{
var lineComparison = Line.CompareTo(other.Line);
return lineComparison != 0 ? lineComparison : Column.CompareTo(other.Column);
} }
} }

View File

@@ -1,979 +0,0 @@
using System.Text;
using NubLang.Ast;
namespace NubLang.Generation;
public class LlvmGenerator
{
private string _module = string.Empty;
private int _tmpIndex;
private int _labelIndex;
private List<(string Name, int Size, string Text)> _stringLiterals = [];
private Stack<(string breakLabel, string continueLabel)> _loopStack = [];
public string Emit(List<TopLevelNode> topLevelNodes)
{
_stringLiterals = [];
_loopStack = [];
var writer = new IndentedTextWriter();
_module = topLevelNodes.OfType<ModuleNode>().First().NameToken.Value;
writer.WriteLine($"; Module {_module}");
writer.WriteLine();
writer.WriteLine("declare void @llvm.memcpy.p0.p0.i64(ptr, ptr, i64, i1)");
writer.WriteLine();
foreach (var structNode in topLevelNodes.OfType<StructNode>())
{
var types = structNode.Fields.Select(x => MapType(x.Type));
writer.WriteLine($"%{StructName(structNode)} = type {{ {string.Join(", ", types)} }}");
writer.WriteLine();
_tmpIndex = 0;
_labelIndex = 0;
writer.WriteLine($"define void @{StructName(structNode)}.new(ptr %self) {{");
using (writer.Indent())
{
foreach (var field in structNode.Fields)
{
if (field.Value != null)
{
var index = structNode.StructType.GetFieldIndex(field.NameToken.Value);
var fieldTmp = NewTmp($"struct.field.{field.NameToken.Value}");
writer.WriteLine($"{fieldTmp} = getelementptr %{StructName(structNode)}, ptr %self, i32 0, i32 {index}");
EmitExpressionInto(writer, field.Value, fieldTmp);
}
}
writer.WriteLine("ret void");
}
writer.WriteLine("}");
writer.WriteLine();
}
foreach (var funcNode in topLevelNodes.OfType<FuncNode>())
{
if (funcNode.Body != null) continue;
var parameters = funcNode.Prototype.Parameters.Select(x => $"{MapType(x.Type)} %{x.NameToken.Value}");
writer.WriteLine($"declare {MapType(funcNode.Prototype.ReturnType)} @{FuncName(funcNode.Prototype)}({string.Join(", ", parameters)})");
writer.WriteLine();
}
foreach (var funcNode in topLevelNodes.OfType<FuncNode>())
{
if (funcNode.Body == null) continue;
_tmpIndex = 0;
_labelIndex = 0;
var parameters = funcNode.Prototype.Parameters.Select(x => $"{MapType(x.Type)} %{x.NameToken.Value}");
writer.WriteLine($"define {MapType(funcNode.Prototype.ReturnType)} @{FuncName(funcNode.Prototype)}({string.Join(", ", parameters)}) {{");
using (writer.Indent())
{
EmitBlock(writer, funcNode.Body);
// note(nub31): Implicit return for void functions
if (funcNode.Prototype.ReturnType is NubVoidType)
{
writer.WriteLine("ret void");
}
}
writer.WriteLine("}");
writer.WriteLine();
}
foreach (var stringLiteral in _stringLiterals)
{
writer.WriteLine($"{stringLiteral.Name} = private unnamed_addr constant [{stringLiteral.Size} x i8] c\"{stringLiteral.Text}\\00\", align 1");
}
return writer.ToString();
}
private void EmitStatement(IndentedTextWriter writer, StatementNode statementNode)
{
switch (statementNode)
{
case AssignmentNode assignmentNode:
EmitAssignment(writer, assignmentNode);
break;
case BlockNode blockNode:
EmitBlock(writer, blockNode);
break;
case BreakNode breakNode:
EmitBreak(writer, breakNode);
break;
case ContinueNode continueNode:
EmitContinue(writer, continueNode);
break;
case DeferNode deferNode:
EmitDefer(writer, deferNode);
break;
case ForConstArrayNode forConstArrayNode:
EmitForConstArray(writer, forConstArrayNode);
break;
case ForSliceNode forSliceNode:
EmitForSlice(writer, forSliceNode);
break;
case IfNode ifNode:
EmitIf(writer, ifNode);
break;
case ReturnNode returnNode:
EmitReturn(writer, returnNode);
break;
case StatementFuncCallNode statementFuncCallNode:
EmitStatementFuncCall(writer, statementFuncCallNode);
break;
case VariableDeclarationNode variableDeclarationNode:
EmitVariableDeclaration(writer, variableDeclarationNode);
break;
case WhileNode whileNode:
EmitWhile(writer, whileNode);
break;
default:
{
throw new NotImplementedException();
}
}
}
private void EmitAssignment(IndentedTextWriter writer, AssignmentNode assignmentNode)
{
var target = EmitExpression(writer, assignmentNode.Target);
var value = Unwrap(writer, EmitExpression(writer, assignmentNode.Value));
writer.WriteLine($"store {MapType(assignmentNode.Value.Type)} {value}, ptr {target.Ident}");
}
private void EmitBlock(IndentedTextWriter writer, BlockNode blockNode)
{
foreach (var statementNode in blockNode.Statements)
{
EmitStatement(writer, statementNode);
}
}
private void EmitBreak(IndentedTextWriter writer, BreakNode breakNode)
{
var (breakLabel, _) = _loopStack.Peek();
writer.WriteLine($"br label %{breakLabel}");
}
private void EmitContinue(IndentedTextWriter writer, ContinueNode continueNode)
{
var (_, continueLabel) = _loopStack.Peek();
writer.WriteLine($"br label %{continueLabel}");
}
private void EmitDefer(IndentedTextWriter writer, DeferNode deferNode)
{
throw new NotImplementedException();
}
private void EmitForConstArray(IndentedTextWriter writer, ForConstArrayNode forConstArrayNode)
{
throw new NotImplementedException();
}
private void EmitForSlice(IndentedTextWriter writer, ForSliceNode forSliceNode)
{
throw new NotImplementedException();
}
private void EmitIf(IndentedTextWriter writer, IfNode ifNode)
{
var endLabel = NewLabel("if.end");
EmitIf(writer, ifNode, endLabel);
writer.WriteLine($"{endLabel}:");
}
private void EmitIf(IndentedTextWriter writer, IfNode ifNode, string endLabel)
{
var condition = Unwrap(writer, EmitExpression(writer, ifNode.Condition));
var thenLabel = NewLabel("if.then");
var elseLabel = ifNode.Else.HasValue ? NewLabel("if.else") : endLabel;
writer.WriteLine($"br i1 {condition}, label %{thenLabel}, label %{elseLabel}");
writer.WriteLine($"{thenLabel}:");
using (writer.Indent())
{
EmitBlock(writer, ifNode.Body);
writer.WriteLine($"br label %{endLabel}");
}
if (!ifNode.Else.HasValue) return;
writer.WriteLine($"{elseLabel}:");
using (writer.Indent())
{
ifNode.Else.Value.Match(
nestedElseIf => EmitIf(writer, nestedElseIf, endLabel),
finalElse =>
{
EmitBlock(writer, finalElse);
writer.WriteLine($"br label %{endLabel}");
}
);
}
}
private void EmitReturn(IndentedTextWriter writer, ReturnNode returnNode)
{
if (returnNode.Value != null)
{
var returnValue = Unwrap(writer, EmitExpression(writer, returnNode.Value));
writer.WriteLine($"ret {MapType(returnNode.Value.Type)} {returnValue}");
}
else
{
writer.WriteLine("ret void");
}
}
private void EmitStatementFuncCall(IndentedTextWriter writer, StatementFuncCallNode statementFuncCallNode)
{
EmitFuncCall(writer, statementFuncCallNode.FuncCall);
}
private void EmitVariableDeclaration(IndentedTextWriter writer, VariableDeclarationNode variableDeclarationNode)
{
writer.WriteLine($"%{variableDeclarationNode.NameToken.Value} = alloca {MapType(variableDeclarationNode.Type)}");
if (variableDeclarationNode.Assignment != null)
{
EmitExpressionInto(writer, variableDeclarationNode.Assignment, $"%{variableDeclarationNode.NameToken.Value}");
}
}
private void EmitWhile(IndentedTextWriter writer, WhileNode whileNode)
{
var conditionLabel = NewLabel("while.condition");
var bodyLabel = NewLabel("while.body");
var endLabel = NewLabel("while.end");
_loopStack.Push((endLabel, conditionLabel));
writer.WriteLine($"br label %{conditionLabel}");
writer.WriteLine($"{conditionLabel}:");
using (writer.Indent())
{
var condition = Unwrap(writer, EmitExpression(writer, whileNode.Condition));
writer.WriteLine($"br i1 {condition}, label %{bodyLabel}, label %{endLabel}");
}
writer.WriteLine($"{bodyLabel}:");
using (writer.Indent())
{
EmitBlock(writer, whileNode.Body);
writer.WriteLine($"br label %{conditionLabel}");
}
_loopStack.Pop();
writer.WriteLine($"{endLabel}:");
}
private Tmp EmitExpression(IndentedTextWriter writer, ExpressionNode expressionNode)
{
return expressionNode switch
{
RValue rValue => EmitRValue(writer, rValue),
LValue lValue => EmitLValue(writer, lValue),
_ => throw new ArgumentOutOfRangeException(nameof(expressionNode))
};
}
private void EmitExpressionInto(IndentedTextWriter writer, ExpressionNode expressionNode, string destination)
{
switch (expressionNode)
{
case StructInitializerNode structInitializerNode:
{
EmitStructInitializer(writer, structInitializerNode, destination);
return;
}
case ConstArrayInitializerNode constArrayInitializerNode:
{
EmitConstArrayInitializer(writer, constArrayInitializerNode, destination);
return;
}
}
var value = Unwrap(writer, EmitExpression(writer, expressionNode));
if (expressionNode.Type.IsAggregate())
{
// note(nub31): Fall back to slow method creating a copy
writer.WriteLine("; Slow rvalue copy instead of direct memory write");
writer.WriteLine($"call void @llvm.memcpy.p0.p0.i64(ptr {destination}, ptr {value}, i64 {expressionNode.Type.GetSize()}, i1 false)");
}
else
{
writer.WriteLine($"store {MapType(expressionNode.Type)} {value}, ptr {destination}");
}
}
private Tmp EmitRValue(IndentedTextWriter writer, RValue rValue)
{
return rValue switch
{
AddressOfNode addressOfNode => EmitAddressOf(writer, addressOfNode),
BinaryExpressionNode binaryExpressionNode => EmitBinaryExpression(writer, binaryExpressionNode),
BoolLiteralNode boolLiteralNode => EmitBoolLiteral(writer, boolLiteralNode),
CastNode castNode => EmitCast(writer, castNode),
ConstArrayInitializerNode constArrayInitializerNode => EmitConstArrayInitializer(writer, constArrayInitializerNode),
CStringLiteralNode cStringLiteralNode => EmitCStringLiteral(writer, cStringLiteralNode),
Float32LiteralNode float32LiteralNode => EmitFloat32Literal(writer, float32LiteralNode),
Float64LiteralNode float64LiteralNode => EmitFloat64Literal(writer, float64LiteralNode),
FuncCallNode funcCallNode => EmitFuncCall(writer, funcCallNode),
FuncIdentifierNode funcIdentifierNode => EmitFuncIdentifier(writer, funcIdentifierNode),
I16LiteralNode i16LiteralNode => EmitI16Literal(writer, i16LiteralNode),
I32LiteralNode i32LiteralNode => EmitI32Literal(writer, i32LiteralNode),
I64LiteralNode i64LiteralNode => EmitI64Literal(writer, i64LiteralNode),
I8LiteralNode i8LiteralNode => EmitI8Literal(writer, i8LiteralNode),
SizeNode sizeNode => EmitSize(writer, sizeNode),
StringLiteralNode stringLiteralNode => EmitStringLiteral(writer, stringLiteralNode),
StructInitializerNode structInitializerNode => EmitStructInitializer(writer, structInitializerNode),
U16LiteralNode u16LiteralNode => EmitU16Literal(writer, u16LiteralNode),
U32LiteralNode u32LiteralNode => EmitU32Literal(writer, u32LiteralNode),
U64LiteralNode u64LiteralNode => EmitU64Literal(writer, u64LiteralNode),
U8LiteralNode u8LiteralNode => EmitU8Literal(writer, u8LiteralNode),
UnaryExpressionNode unaryExpressionNode => EmitUnaryExpression(writer, unaryExpressionNode),
_ => throw new ArgumentOutOfRangeException(nameof(rValue), rValue, null)
};
}
private Tmp EmitLValue(IndentedTextWriter writer, LValue lValue)
{
return lValue switch
{
ArrayIndexAccessNode arrayIndexAccessNode => EmitArrayIndexAccess(writer, arrayIndexAccessNode),
ConstArrayIndexAccessNode constArrayIndexAccessNode => EmitConstArrayIndexAccess(writer, constArrayIndexAccessNode),
DereferenceNode dereferenceNode => EmitDereference(writer, dereferenceNode),
SliceIndexAccessNode sliceIndexAccessNode => EmitSliceIndexAccess(writer, sliceIndexAccessNode),
StructFieldAccessNode structFieldAccessNode => EmitStructFieldAccess(writer, structFieldAccessNode),
VariableIdentifierNode variableIdentifierNode => EmitVariableIdentifier(writer, variableIdentifierNode),
_ => throw new ArgumentOutOfRangeException(nameof(lValue), lValue, null)
};
}
private Tmp EmitAddressOf(IndentedTextWriter writer, AddressOfNode addressOfNode)
{
var target = EmitExpression(writer, addressOfNode.Target);
return new Tmp(target.Ident, addressOfNode.Type, false);
}
private Tmp EmitArrayIndexAccess(IndentedTextWriter writer, ArrayIndexAccessNode arrayIndexAccessNode)
{
var arrayPtr = Unwrap(writer, EmitExpression(writer, arrayIndexAccessNode.Target));
var index = Unwrap(writer, EmitExpression(writer, arrayIndexAccessNode.Index));
var elementType = ((NubArrayType)arrayIndexAccessNode.Target.Type).ElementType;
var ptrTmp = NewTmp("array.element");
writer.WriteLine($"{ptrTmp} = getelementptr {MapType(elementType)}, ptr {arrayPtr}, {MapType(arrayIndexAccessNode.Index.Type)} {index}");
return new Tmp(ptrTmp, arrayIndexAccessNode.Type, true);
}
private Tmp EmitBinaryExpression(IndentedTextWriter writer, BinaryExpressionNode binaryExpressionNode)
{
var left = Unwrap(writer, EmitExpression(writer, binaryExpressionNode.Left));
var right = Unwrap(writer, EmitExpression(writer, binaryExpressionNode.Right));
var result = NewTmp("binop");
var leftType = binaryExpressionNode.Left.Type;
var op = binaryExpressionNode.Operator;
switch (op)
{
case BinaryOperator.Equal:
case BinaryOperator.NotEqual:
case BinaryOperator.GreaterThan:
case BinaryOperator.GreaterThanOrEqual:
case BinaryOperator.LessThan:
case BinaryOperator.LessThanOrEqual:
{
var cmpOp = leftType switch
{
NubIntType intType => GenerateIntComparison(op, intType.Signed),
NubFloatType => GenerateFloatComparison(op),
NubBoolType => GenerateBoolComparison(op),
NubPointerType => GeneratePointerComparison(op),
_ => throw new InvalidOperationException($"Unexpected type for comparison: {leftType}")
};
writer.WriteLine($"{result} = {cmpOp} {MapType(leftType)} {left}, {right}");
break;
}
case BinaryOperator.LogicalAnd:
{
writer.WriteLine($"{result} = and i1 {left}, {right}");
break;
}
case BinaryOperator.LogicalOr:
{
writer.WriteLine($"{result} = or i1 {left}, {right}");
break;
}
case BinaryOperator.Plus:
{
var instruction = leftType switch
{
NubIntType => "add",
NubFloatType => "fadd",
_ => throw new InvalidOperationException($"Unexpected type for plus: {leftType}")
};
writer.WriteLine($"{result} = {instruction} {MapType(leftType)} {left}, {right}");
break;
}
case BinaryOperator.Minus:
{
var instruction = leftType switch
{
NubIntType => "sub",
NubFloatType => "fsub",
_ => throw new InvalidOperationException($"Unexpected type for minus: {leftType}")
};
writer.WriteLine($"{result} = {instruction} {MapType(leftType)} {left}, {right}");
break;
}
case BinaryOperator.Multiply:
{
var instruction = leftType switch
{
NubIntType => "mul",
NubFloatType => "fmul",
_ => throw new InvalidOperationException($"Unexpected type for multiply: {leftType}")
};
writer.WriteLine($"{result} = {instruction} {MapType(leftType)} {left}, {right}");
break;
}
case BinaryOperator.Divide:
{
var instruction = leftType switch
{
NubIntType intType => intType.Signed ? "sdiv" : "udiv",
NubFloatType => "fdiv",
_ => throw new InvalidOperationException($"Unexpected type for divide: {leftType}")
};
writer.WriteLine($"{result} = {instruction} {MapType(leftType)} {left}, {right}");
break;
}
case BinaryOperator.Modulo:
{
var instruction = leftType switch
{
NubIntType intType => intType.Signed ? "srem" : "urem",
NubFloatType => "frem",
_ => throw new InvalidOperationException($"Unexpected type for modulo: {leftType}")
};
writer.WriteLine($"{result} = {instruction} {MapType(leftType)} {left}, {right}");
break;
}
case BinaryOperator.LeftShift:
{
writer.WriteLine($"{result} = shl {MapType(leftType)} {left}, {right}");
break;
}
case BinaryOperator.RightShift:
{
var intType = (NubIntType)leftType;
var instruction = intType.Signed ? "ashr" : "lshr";
writer.WriteLine($"{result} = {instruction} {MapType(leftType)} {left}, {right}");
break;
}
case BinaryOperator.BitwiseAnd:
{
writer.WriteLine($"{result} = and {MapType(leftType)} {left}, {right}");
break;
}
case BinaryOperator.BitwiseXor:
{
writer.WriteLine($"{result} = xor {MapType(leftType)} {left}, {right}");
break;
}
case BinaryOperator.BitwiseOr:
{
writer.WriteLine($"{result} = or {MapType(leftType)} {left}, {right}");
break;
}
default:
throw new ArgumentOutOfRangeException(nameof(op), op, null);
}
return new Tmp(result, binaryExpressionNode.Type, false);
}
private string GenerateIntComparison(BinaryOperator op, bool signed)
{
return op switch
{
BinaryOperator.Equal => "icmp eq",
BinaryOperator.NotEqual => "icmp ne",
BinaryOperator.GreaterThan => signed ? "icmp sgt" : "icmp ugt",
BinaryOperator.GreaterThanOrEqual => signed ? "icmp sge" : "icmp uge",
BinaryOperator.LessThan => signed ? "icmp slt" : "icmp ult",
BinaryOperator.LessThanOrEqual => signed ? "icmp sle" : "icmp ule",
_ => throw new ArgumentOutOfRangeException(nameof(op), op, null)
};
}
private string GenerateFloatComparison(BinaryOperator op)
{
return op switch
{
BinaryOperator.Equal => "fcmp oeq",
BinaryOperator.NotEqual => "fcmp one",
BinaryOperator.GreaterThan => "fcmp ogt",
BinaryOperator.GreaterThanOrEqual => "fcmp oge",
BinaryOperator.LessThan => "fcmp olt",
BinaryOperator.LessThanOrEqual => "fcmp ole",
_ => throw new ArgumentOutOfRangeException(nameof(op), op, null)
};
}
private string GenerateBoolComparison(BinaryOperator op)
{
return op switch
{
BinaryOperator.Equal => "icmp eq",
BinaryOperator.NotEqual => "icmp ne",
_ => throw new ArgumentOutOfRangeException(nameof(op), op, null)
};
}
private string GeneratePointerComparison(BinaryOperator op)
{
return op switch
{
BinaryOperator.Equal => "icmp eq",
BinaryOperator.NotEqual => "icmp ne",
BinaryOperator.GreaterThan => "icmp ugt",
BinaryOperator.GreaterThanOrEqual => "icmp uge",
BinaryOperator.LessThan => "icmp ult",
BinaryOperator.LessThanOrEqual => "icmp ule",
_ => throw new ArgumentOutOfRangeException(nameof(op), op, null)
};
}
private Tmp EmitBoolLiteral(IndentedTextWriter writer, BoolLiteralNode boolLiteralNode)
{
return new Tmp(boolLiteralNode.Value ? "1" : "0", boolLiteralNode.Type, false);
}
private Tmp EmitCast(IndentedTextWriter writer, CastNode castNode)
{
var source = Unwrap(writer, EmitExpression(writer, castNode.Value));
var sourceType = castNode.Value.Type;
var targetType = castNode.Type;
var result = NewTmp("cast");
switch (sourceType, targetType)
{
case (NubIntType sourceInt, NubIntType targetInt):
{
if (sourceInt.Width < targetInt.Width)
{
var op = sourceInt.Signed ? "sext" : "zext";
writer.WriteLine($"{result} = {op} {MapType(sourceType)} {source} to {MapType(targetType)}");
}
else if (sourceInt.Width > targetInt.Width)
{
writer.WriteLine($"{result} = trunc {MapType(sourceType)} {source} to {MapType(targetType)}");
}
else
{
writer.WriteLine($"{result} = bitcast {MapType(sourceType)} {source} to {MapType(targetType)}");
}
break;
}
case (NubFloatType sourceFloat, NubFloatType targetFloat):
{
if (sourceFloat.Width < targetFloat.Width)
{
writer.WriteLine($"{result} = fpext {MapType(sourceType)} {source} to {MapType(targetType)}");
}
else
{
writer.WriteLine($"{result} = fptrunc {MapType(sourceType)} {source} to {MapType(targetType)}");
}
break;
}
case (NubIntType intType, NubFloatType):
{
var intToFloatOp = intType.Signed ? "sitofp" : "uitofp";
writer.WriteLine($"{result} = {intToFloatOp} {MapType(sourceType)} {source} to {MapType(targetType)}");
break;
}
case (NubFloatType, NubIntType targetInt):
{
var floatToIntOp = targetInt.Signed ? "fptosi" : "fptoui";
writer.WriteLine($"{result} = {floatToIntOp} {MapType(sourceType)} {source} to {MapType(targetType)}");
break;
}
case (NubPointerType, NubPointerType):
case (NubPointerType, NubIntType):
case (NubIntType, NubPointerType):
{
writer.WriteLine($"{result} = inttoptr {MapType(sourceType)} {source} to {MapType(targetType)}");
break;
}
default:
{
throw new NotImplementedException($"Cast from {sourceType} to {targetType} not implemented");
}
}
return new Tmp(result, castNode.Type, false);
}
private Tmp EmitConstArrayIndexAccess(IndentedTextWriter writer, ConstArrayIndexAccessNode constArrayIndexAccessNode)
{
var arrayPtr = Unwrap(writer, EmitExpression(writer, constArrayIndexAccessNode.Target));
var index = Unwrap(writer, EmitExpression(writer, constArrayIndexAccessNode.Index));
var elementType = ((NubConstArrayType)constArrayIndexAccessNode.Target.Type).ElementType;
var ptrTmp = NewTmp("array.element");
writer.WriteLine($"{ptrTmp} = getelementptr {MapType(elementType)}, ptr {arrayPtr}, {MapType(constArrayIndexAccessNode.Index.Type)} {index}");
return new Tmp(ptrTmp, constArrayIndexAccessNode.Type, true);
}
private Tmp EmitConstArrayInitializer(IndentedTextWriter writer, ConstArrayInitializerNode constArrayInitializerNode, string? destination = null)
{
var arrayType = (NubConstArrayType)constArrayInitializerNode.Type;
if (destination == null)
{
destination = NewTmp("array");
writer.WriteLine($"{destination} = alloca {MapType(arrayType)}");
}
for (var i = 0; i < constArrayInitializerNode.Values.Count; i++)
{
var value = constArrayInitializerNode.Values[i];
var indexTmp = NewTmp("array.element");
writer.WriteLine($"{indexTmp} = getelementptr {MapType(arrayType)}, ptr {destination}, i32 0, i32 {i}");
EmitExpressionInto(writer, value, indexTmp);
}
return new Tmp(destination, constArrayInitializerNode.Type, false);
}
private Tmp EmitCStringLiteral(IndentedTextWriter writer, CStringLiteralNode cStringLiteralNode)
{
var escaped = EscapeStringForLLVM(cStringLiteralNode.Value);
var stringWithNull = cStringLiteralNode.Value + "\0";
var length = stringWithNull.Length;
var globalName = $"@.str.{_stringLiterals.Count}";
_stringLiterals.Add((globalName, length, escaped));
var gepTmp = NewTmp("str.ptr");
writer.WriteLine($"{gepTmp} = getelementptr [{length} x i8], ptr {globalName}, i32 0, i32 0");
return new Tmp(gepTmp, cStringLiteralNode.Type, false);
}
private string EscapeStringForLLVM(string input)
{
var result = new StringBuilder();
foreach (var c in input)
{
if (c == '\0')
result.Append("\\00");
else if (c == '\n')
result.Append("\\0A");
else if (c == '\r')
result.Append("\\0D");
else if (c == '\t')
result.Append("\\09");
else if (c == '\\')
result.Append("\\\\");
else if (c == '"')
result.Append("\\22");
else if (c < 32 || c > 126)
result.Append($"\\{(int)c:X2}");
else
result.Append(c);
}
return result.ToString();
}
private Tmp EmitDereference(IndentedTextWriter writer, DereferenceNode dereferenceNode)
{
throw new NotImplementedException();
}
private Tmp EmitFloat32Literal(IndentedTextWriter writer, Float32LiteralNode float32LiteralNode)
{
var literal = ((double)float32LiteralNode.Value).ToString("R", System.Globalization.CultureInfo.InvariantCulture);
if (!literal.Contains('.'))
{
literal += ".0";
}
return new Tmp(literal, float32LiteralNode.Type, false);
}
private Tmp EmitFloat64Literal(IndentedTextWriter writer, Float64LiteralNode float64LiteralNode)
{
var literal = float64LiteralNode.Value.ToString("R", System.Globalization.CultureInfo.InvariantCulture);
if (!literal.Contains('.'))
{
literal += ".0";
}
return new Tmp(literal, float64LiteralNode.Type, false);
}
private Tmp EmitFuncCall(IndentedTextWriter writer, FuncCallNode funcCallNode)
{
var result = NewTmp();
var parameterStrings = new List<string>();
foreach (var parameter in funcCallNode.Parameters)
{
var value = Unwrap(writer, EmitExpression(writer, parameter));
parameterStrings.Add($"{MapType(parameter.Type)} {value}");
}
var functionPtr = Unwrap(writer, EmitExpression(writer, funcCallNode.Expression));
if (funcCallNode.Type is NubVoidType)
{
writer.WriteLine($"call {MapType(funcCallNode.Type)} {functionPtr}({string.Join(", ", parameterStrings)})");
}
else
{
writer.WriteLine($"{result} = call {MapType(funcCallNode.Type)} {functionPtr}({string.Join(", ", parameterStrings)})");
}
return new Tmp(result, funcCallNode.Type, false);
}
private Tmp EmitFuncIdentifier(IndentedTextWriter writer, FuncIdentifierNode funcIdentifierNode)
{
var name = FuncName(funcIdentifierNode.ModuleToken.Value, funcIdentifierNode.NameToken.Value, funcIdentifierNode.ExternSymbolToken?.Value);
return new Tmp($"@{name}", funcIdentifierNode.Type, false);
}
private Tmp EmitI16Literal(IndentedTextWriter writer, I16LiteralNode i16LiteralNode)
{
return new Tmp(i16LiteralNode.Value.ToString(), i16LiteralNode.Type, false);
}
private Tmp EmitI32Literal(IndentedTextWriter writer, I32LiteralNode i32LiteralNode)
{
return new Tmp(i32LiteralNode.Value.ToString(), i32LiteralNode.Type, false);
}
private Tmp EmitI64Literal(IndentedTextWriter writer, I64LiteralNode i64LiteralNode)
{
return new Tmp(i64LiteralNode.Value.ToString(), i64LiteralNode.Type, false);
}
private Tmp EmitI8Literal(IndentedTextWriter writer, I8LiteralNode i8LiteralNode)
{
return new Tmp(i8LiteralNode.Value.ToString(), i8LiteralNode.Type, false);
}
private Tmp EmitSize(IndentedTextWriter writer, SizeNode sizeNode)
{
return new Tmp(sizeNode.TargetType.GetSize().ToString(), sizeNode.Type, false);
}
private Tmp EmitSliceIndexAccess(IndentedTextWriter writer, SliceIndexAccessNode sliceIndexAccessNode)
{
throw new NotImplementedException();
}
private Tmp EmitStringLiteral(IndentedTextWriter writer, StringLiteralNode stringLiteralNode)
{
throw new NotImplementedException();
}
private Tmp EmitStructFieldAccess(IndentedTextWriter writer, StructFieldAccessNode structFieldAccessNode)
{
var target = Unwrap(writer, EmitExpression(writer, structFieldAccessNode.Target));
var structType = (NubStructType)structFieldAccessNode.Target.Type;
var index = structType.GetFieldIndex(structFieldAccessNode.FieldToken.Value);
var ptrTmp = NewTmp($"struct.field.{structFieldAccessNode.FieldToken.Value}");
writer.WriteLine($"{ptrTmp} = getelementptr %{StructName(structType.Module, structType.Name)}, ptr {target}, i32 0, i32 {index}");
return new Tmp(ptrTmp, structFieldAccessNode.Type, true);
}
private Tmp EmitStructInitializer(IndentedTextWriter writer, StructInitializerNode structInitializerNode, string? destination = null)
{
if (destination == null)
{
destination = NewTmp("struct");
writer.WriteLine($"{destination} = alloca {MapType(structInitializerNode.Type)}");
}
var structType = (NubStructType)structInitializerNode.Type;
writer.WriteLine($"call void @{StructName(structType.Module, structType.Name)}.new(ptr {destination})");
foreach (var (name, value) in structInitializerNode.Initializers)
{
var index = structType.GetFieldIndex(name.Value);
var fieldTmp = NewTmp($"struct.field.{name}");
writer.WriteLine($"{fieldTmp} = getelementptr %{StructName(structType.Module, structType.Name)}, ptr {destination}, i32 0, i32 {index}");
EmitExpressionInto(writer, value, fieldTmp);
}
return new Tmp(destination, structInitializerNode.Type, false);
}
private Tmp EmitU16Literal(IndentedTextWriter writer, U16LiteralNode u16LiteralNode)
{
return new Tmp(u16LiteralNode.Value.ToString(), u16LiteralNode.Type, false);
}
private Tmp EmitU32Literal(IndentedTextWriter writer, U32LiteralNode u32LiteralNode)
{
return new Tmp(u32LiteralNode.Value.ToString(), u32LiteralNode.Type, false);
}
private Tmp EmitU64Literal(IndentedTextWriter writer, U64LiteralNode u64LiteralNode)
{
return new Tmp(u64LiteralNode.Value.ToString(), u64LiteralNode.Type, false);
}
private Tmp EmitU8Literal(IndentedTextWriter writer, U8LiteralNode u8LiteralNode)
{
return new Tmp(u8LiteralNode.Value.ToString(), u8LiteralNode.Type, false);
}
private Tmp EmitUnaryExpression(IndentedTextWriter writer, UnaryExpressionNode unaryExpressionNode)
{
var operand = Unwrap(writer, EmitExpression(writer, unaryExpressionNode.Operand));
var result = NewTmp("unary");
switch (unaryExpressionNode.Operator)
{
case UnaryOperator.Negate:
switch (unaryExpressionNode.Operand.Type)
{
case NubIntType intType:
writer.WriteLine($"{result} = sub {MapType(intType)} 0, {operand}");
break;
case NubFloatType floatType:
writer.WriteLine($"{result} = fneg {MapType(floatType)} {operand}");
break;
default:
throw new ArgumentOutOfRangeException();
}
break;
case UnaryOperator.Invert:
writer.WriteLine($"{result} = xor i1 {operand}, true");
break;
default:
throw new ArgumentOutOfRangeException();
}
return new Tmp(result, unaryExpressionNode.Type, false);
}
private Tmp EmitVariableIdentifier(IndentedTextWriter writer, VariableIdentifierNode variableIdentifierNode)
{
return new Tmp($"%{variableIdentifierNode.NameToken.Value}", variableIdentifierNode.Type, true);
}
private string StructName(StructNode structNode)
{
return StructName(_module, structNode.NameToken.Value);
}
private string StructName(string module, string name)
{
return $"struct.{module}.{name}";
}
private string FuncName(FuncPrototypeNode funcNodePrototype)
{
return FuncName(_module, funcNodePrototype.NameToken.Value, funcNodePrototype.ExternSymbolToken?.Value);
}
private string FuncName(string module, string name, string? externSymbol)
{
if (externSymbol != null)
{
return externSymbol;
}
return $"{module}.{name}";
}
private string MapType(NubType type)
{
return type switch
{
NubArrayType arrayType => $"{MapType(arrayType.ElementType)}*",
NubBoolType => "i1",
NubConstArrayType constArrayType => $"[{constArrayType.Size} x {MapType(constArrayType.ElementType)}]",
NubFloatType floatType => floatType.Width == 32 ? "float" : "double",
NubFuncType funcType => MapFuncType(funcType),
NubIntType intType => $"i{intType.Width}",
NubPointerType pointerType => $"{MapType(pointerType.BaseType)}*",
NubSliceType sliceType => throw new NotImplementedException(),
NubStringType stringType => throw new NotImplementedException(),
NubStructType structType => $"%{StructName(structType.Module, structType.Name)}",
NubVoidType => "void",
_ => throw new ArgumentOutOfRangeException(nameof(type))
};
}
private string MapFuncType(NubFuncType funcType)
{
var paramTypes = string.Join(", ", funcType.Parameters.Select(MapType));
var returnType = MapType(funcType.ReturnType);
return $"{returnType} ({paramTypes})*";
}
private record Tmp(string Ident, NubType Type, bool LValue);
private string Unwrap(IndentedTextWriter writer, Tmp tmp)
{
if (tmp.LValue && !tmp.Type.IsAggregate())
{
var newTmp = NewTmp("deref");
writer.WriteLine($"{newTmp} = load {MapType(tmp.Type)}, ptr {tmp.Ident}");
return newTmp;
}
return tmp.Ident;
}
private string NewTmp(string name = "t")
{
return $"%{name}.{++_tmpIndex}";
}
private string NewLabel(string name = "l")
{
return $"{name}.{++_labelIndex}";
}
}

View File

@@ -0,0 +1,781 @@
using System.Text;
using LLVMSharp.Interop;
using NubLang.Ast;
using NubLang.Modules;
using NubLang.Types;
namespace NubLang.Generation;
public class LlvmSharpGenerator
{
private string _module = string.Empty;
private LLVMContextRef _context;
private LLVMModuleRef _llvmModule;
private LLVMBuilderRef _builder;
private readonly Dictionary<string, LLVMTypeRef> _structTypes = new();
private readonly Dictionary<string, LLVMValueRef> _functions = new();
private readonly Dictionary<string, LLVMValueRef> _locals = new();
private readonly Stack<(LLVMBasicBlockRef breakBlock, LLVMBasicBlockRef continueBlock)> _loopStack = new();
public void Emit(List<TopLevelNode> topLevelNodes, ModuleRepository repository, string sourceFileName, string outputPath)
{
_module = topLevelNodes.OfType<ModuleNode>().First().NameToken.Value;
_context = LLVMContextRef.Global;
_llvmModule = _context.CreateModuleWithName(sourceFileName);
_llvmModule.Target = "x86_64-pc-linux-gnu";
_llvmModule.DataLayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-f80:128-n8:16:32:64-S128";
_builder = _context.CreateBuilder();
_structTypes.Clear();
_functions.Clear();
_locals.Clear();
_loopStack.Clear();
var stringType = _context.CreateNamedStruct("nub.string");
stringType.StructSetBody([LLVMTypeRef.Int64, LLVMTypeRef.CreatePointer(LLVMTypeRef.Int8, 0)], false);
_structTypes["nub.string"] = stringType;
foreach (var module in repository.GetAll())
{
foreach (var structType in module.StructTypes)
{
var structName = StructName(structType.Module, structType.Name);
var llvmStructType = _context.CreateNamedStruct(structName);
_structTypes[structName] = llvmStructType;
}
}
foreach (var module in repository.GetAll())
{
foreach (var structType in module.StructTypes)
{
var structName = StructName(structType.Module, structType.Name);
var llvmStructType = _structTypes[structName];
var fieldTypes = structType.Fields.Select(f => MapType(f.Type)).ToArray();
llvmStructType.StructSetBody(fieldTypes, false);
}
}
foreach (var module in repository.GetAll())
{
foreach (var prototype in module.FunctionPrototypes)
{
CreateFunctionDeclaration(prototype, module.Name);
}
}
foreach (var structNode in topLevelNodes.OfType<StructNode>())
{
EmitStructConstructor(structNode);
}
foreach (var funcNode in topLevelNodes.OfType<FuncNode>())
{
if (funcNode.Body != null)
{
EmitFunction(funcNode);
}
}
if (!_llvmModule.TryVerify(LLVMVerifierFailureAction.LLVMPrintMessageAction, out var error))
{
// throw new Exception($"LLVM module verification failed: {error}");
}
_llvmModule.PrintToFile(outputPath);
_builder.Dispose();
}
private void CreateFunctionDeclaration(FuncPrototypeNode prototype, string moduleName)
{
var funcName = FuncName(moduleName, prototype.NameToken.Value, prototype.ExternSymbolToken?.Value);
var paramTypes = prototype.Parameters.Select(p => MapType(p.Type)).ToArray();
var returnType = MapType(prototype.ReturnType);
var funcType = LLVMTypeRef.CreateFunction(returnType, paramTypes);
var func = _llvmModule.AddFunction(funcName, funcType);
func.FunctionCallConv = (uint)LLVMCallConv.LLVMCCallConv;
for (var i = 0; i < prototype.Parameters.Count; i++)
{
func.GetParam((uint)i).Name = prototype.Parameters[i].NameToken.Value;
}
_functions[funcName] = func;
}
private void EmitStructConstructor(StructNode structNode)
{
var structType = _structTypes[StructName(_module, structNode.NameToken.Value)];
var ptrType = LLVMTypeRef.CreatePointer(structType, 0);
var funcType = LLVMTypeRef.CreateFunction(LLVMTypeRef.Void, [ptrType]);
var funcName = StructConstructorName(_module, structNode.NameToken.Value);
var func = _llvmModule.AddFunction(funcName, funcType);
func.FunctionCallConv = (uint)LLVMCallConv.LLVMCCallConv;
var entryBlock = func.AppendBasicBlock("entry");
_builder.PositionAtEnd(entryBlock);
var selfParam = func.GetParam(0);
selfParam.Name = "self";
_locals.Clear();
foreach (var field in structNode.Fields)
{
if (field.Value != null)
{
var index = structNode.StructType.GetFieldIndex(field.NameToken.Value);
var fieldPtr = _builder.BuildStructGEP2(structType, selfParam, (uint)index);
EmitExpressionInto(field.Value, fieldPtr);
}
}
_builder.BuildRetVoid();
_functions[funcName] = func;
}
private void EmitFunction(FuncNode funcNode)
{
var funcName = FuncName(_module, funcNode.Prototype.NameToken.Value, funcNode.Prototype.ExternSymbolToken?.Value);
var func = _functions[funcName];
var entryBlock = func.AppendBasicBlock("entry");
_builder.PositionAtEnd(entryBlock);
_locals.Clear();
for (uint i = 0; i < funcNode.Prototype.Parameters.Count; i++)
{
var param = func.GetParam(i);
var paramNode = funcNode.Prototype.Parameters[(int)i];
var alloca = _builder.BuildAlloca(MapType(paramNode.Type), paramNode.NameToken.Value);
_builder.BuildStore(param, alloca);
_locals[paramNode.NameToken.Value] = alloca;
}
EmitBlock(funcNode.Body!);
if (funcNode.Prototype.ReturnType is NubVoidType)
{
if (_builder.InsertBlock.Terminator.Handle == IntPtr.Zero)
{
_builder.BuildRetVoid();
}
}
}
private void EmitBlock(BlockNode blockNode)
{
foreach (var statement in blockNode.Statements)
{
EmitStatement(statement);
}
}
private void EmitStatement(StatementNode statement)
{
switch (statement)
{
case AssignmentNode assignment:
EmitAssignment(assignment);
break;
case BlockNode block:
EmitBlock(block);
break;
case BreakNode:
EmitBreak();
break;
case ContinueNode:
EmitContinue();
break;
case IfNode ifNode:
EmitIf(ifNode);
break;
case ReturnNode returnNode:
EmitReturn(returnNode);
break;
case StatementFuncCallNode funcCall:
EmitExpression(funcCall.FuncCall);
break;
case VariableDeclarationNode varDecl:
EmitVariableDeclaration(varDecl);
break;
case WhileNode whileNode:
EmitWhile(whileNode);
break;
default:
throw new NotImplementedException($"Statement type {statement.GetType()} not implemented");
}
}
private void EmitAssignment(AssignmentNode assignment)
{
var targetPtr = EmitExpression(assignment.Target, asLValue: true);
var value = EmitExpression(assignment.Value);
_builder.BuildStore(value, targetPtr);
}
private void EmitBreak()
{
var (breakBlock, _) = _loopStack.Peek();
_builder.BuildBr(breakBlock);
}
private void EmitContinue()
{
var (_, continueBlock) = _loopStack.Peek();
_builder.BuildBr(continueBlock);
}
private void EmitIf(IfNode ifNode)
{
var condition = EmitExpression(ifNode.Condition);
var func = _builder.InsertBlock.Parent;
var thenBlock = func.AppendBasicBlock("if.then");
var elseBlock = ifNode.Else.HasValue ? func.AppendBasicBlock("if.else") : default;
var endBlock = func.AppendBasicBlock("if.end");
_builder.BuildCondBr(condition, thenBlock, ifNode.Else.HasValue ? elseBlock : endBlock);
_builder.PositionAtEnd(thenBlock);
EmitBlock(ifNode.Body);
if (_builder.InsertBlock.Terminator.Handle == IntPtr.Zero)
{
_builder.BuildBr(endBlock);
}
if (ifNode.Else.HasValue)
{
_builder.PositionAtEnd(elseBlock);
ifNode.Else.Value.Match(EmitIf, EmitBlock);
if (_builder.InsertBlock.Terminator.Handle == IntPtr.Zero)
{
_builder.BuildBr(endBlock);
}
}
_builder.PositionAtEnd(endBlock);
}
private void EmitReturn(ReturnNode returnNode)
{
if (returnNode.Value != null)
{
var value = EmitExpression(returnNode.Value);
_builder.BuildRet(value);
}
else
{
_builder.BuildRetVoid();
}
}
private void EmitVariableDeclaration(VariableDeclarationNode varDecl)
{
var alloca = _builder.BuildAlloca(MapType(varDecl.Type), varDecl.NameToken.Value);
_locals[varDecl.NameToken.Value] = alloca;
if (varDecl.Assignment != null)
{
EmitExpressionInto(varDecl.Assignment, alloca);
}
}
private void EmitWhile(WhileNode whileNode)
{
var func = _builder.InsertBlock.Parent;
var condBlock = func.AppendBasicBlock("while.cond");
var bodyBlock = func.AppendBasicBlock("while.body");
var endBlock = func.AppendBasicBlock("while.end");
_loopStack.Push((endBlock, condBlock));
_builder.BuildBr(condBlock);
_builder.PositionAtEnd(condBlock);
var condition = EmitExpression(whileNode.Condition);
_builder.BuildCondBr(condition, bodyBlock, endBlock);
_builder.PositionAtEnd(bodyBlock);
EmitBlock(whileNode.Body);
if (_builder.InsertBlock.Terminator.Handle == IntPtr.Zero)
{
_builder.BuildBr(condBlock);
}
_loopStack.Pop();
_builder.PositionAtEnd(endBlock);
}
private LLVMValueRef EmitExpression(ExpressionNode expr, bool asLValue = false)
{
var result = expr switch
{
StringLiteralNode stringLiteralNode => EmitStringLiteral(stringLiteralNode),
CStringLiteralNode cStringLiteralNode => EmitCStringLiteral(cStringLiteralNode),
BoolLiteralNode b => LLVMValueRef.CreateConstInt(LLVMTypeRef.Int1, b.Value ? 1UL : 0UL),
I8LiteralNode i => LLVMValueRef.CreateConstInt(LLVMTypeRef.Int8, (ulong)i.Value, true),
I16LiteralNode i => LLVMValueRef.CreateConstInt(LLVMTypeRef.Int16, (ulong)i.Value, true),
I32LiteralNode i => LLVMValueRef.CreateConstInt(LLVMTypeRef.Int32, (ulong)i.Value, true),
I64LiteralNode i => LLVMValueRef.CreateConstInt(LLVMTypeRef.Int64, (ulong)i.Value, true),
U8LiteralNode u => LLVMValueRef.CreateConstInt(LLVMTypeRef.Int8, u.Value),
U16LiteralNode u => LLVMValueRef.CreateConstInt(LLVMTypeRef.Int16, u.Value),
U32LiteralNode u => LLVMValueRef.CreateConstInt(LLVMTypeRef.Int32, u.Value),
U64LiteralNode u => LLVMValueRef.CreateConstInt(LLVMTypeRef.Int64, u.Value),
Float32LiteralNode f => LLVMValueRef.CreateConstReal(LLVMTypeRef.Float, f.Value),
Float64LiteralNode f => LLVMValueRef.CreateConstReal(LLVMTypeRef.Double, f.Value),
VariableIdentifierNode v => EmitVariableIdentifier(v),
LocalFuncIdentifierNode localFuncIdentifierNode => EmitLocalFuncIdentifier(localFuncIdentifierNode),
ModuleFuncIdentifierNode moduleFuncIdentifierNode => EmitModuleFuncIdentifier(moduleFuncIdentifierNode),
BinaryExpressionNode bin => EmitBinaryExpression(bin),
UnaryExpressionNode unary => EmitUnaryExpression(unary),
StructFieldAccessNode field => EmitStructFieldAccess(field),
ConstArrayIndexAccessNode arr => EmitConstArrayIndexAccess(arr),
SliceIndexAccessNode sliceIndexAccessNode => EmitSliceIndexAccess(sliceIndexAccessNode),
ArrayIndexAccessNode arrayIndexAccessNode => EmitArrayIndexAccess(arrayIndexAccessNode),
ConstArrayInitializerNode constArrayInitializerNode => EmitConstArrayInitializer(constArrayInitializerNode),
StructInitializerNode structInitializerNode => EmitStructInitializer(structInitializerNode),
AddressOfNode addr => EmitAddressOf(addr),
DereferenceNode deref => EmitDereference(deref),
FuncCallNode funcCall => EmitFuncCall(funcCall),
CastNode cast => EmitCast(cast),
SizeNode size => LLVMValueRef.CreateConstInt(LLVMTypeRef.Int64, size.TargetType.GetSize()),
_ => throw new ArgumentOutOfRangeException(nameof(expr), expr, null)
};
if (expr is LValue)
{
if (asLValue)
{
return result;
}
return _builder.BuildLoad2(MapType(expr.Type), result);
}
if (asLValue)
{
throw new InvalidOperationException($"Expression of type {expr.GetType().Name} is not an lvalue and cannot be used where an address is required");
}
return result;
}
private void EmitExpressionInto(ExpressionNode expr, LLVMValueRef destPtr)
{
switch (expr)
{
case StructInitializerNode structInit:
EmitStructInitializer(structInit, destPtr);
return;
case ConstArrayInitializerNode arrayInit:
EmitConstArrayInitializer(arrayInit, destPtr);
return;
default:
{
var result = EmitExpression(expr);
_builder.BuildStore(result, destPtr);
break;
}
}
}
private LLVMValueRef EmitStringLiteral(StringLiteralNode stringLiteralNode)
{
var strValue = stringLiteralNode.Value;
var length = (ulong)Encoding.UTF8.GetByteCount(strValue);
var globalStr = _builder.BuildGlobalStringPtr(strValue);
var llvmStringType = MapType(stringLiteralNode.Type);
var strAlloca = _builder.BuildAlloca(llvmStringType);
var lengthPtr = _builder.BuildStructGEP2(llvmStringType, strAlloca, 0);
var lengthConst = LLVMValueRef.CreateConstInt(LLVMTypeRef.Int64, length);
_builder.BuildStore(lengthConst, lengthPtr);
var dataPtr = _builder.BuildStructGEP2(llvmStringType, strAlloca, 1);
_builder.BuildStore(globalStr, dataPtr);
return _builder.BuildLoad2(llvmStringType, strAlloca);
}
private LLVMValueRef EmitCStringLiteral(CStringLiteralNode cStringLiteralNode)
{
return _builder.BuildGlobalStringPtr(cStringLiteralNode.Value);
}
private LLVMValueRef EmitVariableIdentifier(VariableIdentifierNode v)
{
return _locals[v.NameToken.Value];
}
private LLVMValueRef EmitLocalFuncIdentifier(LocalFuncIdentifierNode localFuncIdentifierNode)
{
return _functions[FuncName(_module, localFuncIdentifierNode.NameToken.Value, localFuncIdentifierNode.ExternSymbolToken?.Value)];
}
private LLVMValueRef EmitModuleFuncIdentifier(ModuleFuncIdentifierNode moduleFuncIdentifierNode)
{
return _functions[FuncName(moduleFuncIdentifierNode.ModuleToken.Value, moduleFuncIdentifierNode.NameToken.Value, moduleFuncIdentifierNode.ExternSymbolToken?.Value)];
}
private LLVMValueRef EmitBinaryExpression(BinaryExpressionNode bin)
{
var left = EmitExpression(bin.Left);
var right = EmitExpression(bin.Right);
var leftType = bin.Left.Type;
var result = bin.Operator switch
{
BinaryOperator.Plus when leftType is NubIntType => _builder.BuildAdd(left, right),
BinaryOperator.Plus when leftType is NubFloatType => _builder.BuildFAdd(left, right),
BinaryOperator.Minus when leftType is NubIntType => _builder.BuildSub(left, right),
BinaryOperator.Minus when leftType is NubFloatType => _builder.BuildFSub(left, right),
BinaryOperator.Multiply when leftType is NubIntType => _builder.BuildMul(left, right),
BinaryOperator.Multiply when leftType is NubFloatType => _builder.BuildFMul(left, right),
BinaryOperator.Divide when leftType is NubIntType intType => intType.Signed ? _builder.BuildSDiv(left, right) : _builder.BuildUDiv(left, right),
BinaryOperator.Divide when leftType is NubFloatType => _builder.BuildFDiv(left, right),
BinaryOperator.Modulo when leftType is NubIntType intType => intType.Signed ? _builder.BuildSRem(left, right) : _builder.BuildURem(left, right),
BinaryOperator.Modulo when leftType is NubFloatType => _builder.BuildFRem(left, right),
BinaryOperator.LogicalAnd => _builder.BuildAnd(left, right),
BinaryOperator.LogicalOr => _builder.BuildOr(left, right),
BinaryOperator.Equal when leftType is NubIntType or NubBoolType or NubPointerType => _builder.BuildICmp(LLVMIntPredicate.LLVMIntEQ, left, right),
BinaryOperator.Equal when leftType is NubFloatType => _builder.BuildFCmp(LLVMRealPredicate.LLVMRealOEQ, left, right),
BinaryOperator.NotEqual when leftType is NubIntType or NubBoolType or NubPointerType => _builder.BuildICmp(LLVMIntPredicate.LLVMIntNE, left, right),
BinaryOperator.NotEqual when leftType is NubFloatType => _builder.BuildFCmp(LLVMRealPredicate.LLVMRealONE, left, right),
BinaryOperator.GreaterThan when leftType is NubIntType intType => _builder.BuildICmp(intType.Signed ? LLVMIntPredicate.LLVMIntSGT : LLVMIntPredicate.LLVMIntUGT, left, right),
BinaryOperator.GreaterThan when leftType is NubFloatType => _builder.BuildFCmp(LLVMRealPredicate.LLVMRealOGT, left, right),
BinaryOperator.GreaterThanOrEqual when leftType is NubIntType intType => _builder.BuildICmp(intType.Signed ? LLVMIntPredicate.LLVMIntSGE : LLVMIntPredicate.LLVMIntUGE, left, right),
BinaryOperator.GreaterThanOrEqual when leftType is NubFloatType => _builder.BuildFCmp(LLVMRealPredicate.LLVMRealOGE, left, right),
BinaryOperator.LessThan when leftType is NubIntType intType => _builder.BuildICmp(intType.Signed ? LLVMIntPredicate.LLVMIntSLT : LLVMIntPredicate.LLVMIntULT, left, right),
BinaryOperator.LessThan when leftType is NubFloatType => _builder.BuildFCmp(LLVMRealPredicate.LLVMRealOLT, left, right),
BinaryOperator.LessThanOrEqual when leftType is NubIntType intType => _builder.BuildICmp(intType.Signed ? LLVMIntPredicate.LLVMIntSLE : LLVMIntPredicate.LLVMIntULE, left, right),
BinaryOperator.LessThanOrEqual when leftType is NubFloatType => _builder.BuildFCmp(LLVMRealPredicate.LLVMRealOLE, left, right),
BinaryOperator.LeftShift => _builder.BuildShl(left, right),
BinaryOperator.RightShift when leftType is NubIntType intType => intType.Signed ? _builder.BuildAShr(left, right) : _builder.BuildLShr(left, right),
BinaryOperator.BitwiseAnd => _builder.BuildAnd(left, right),
BinaryOperator.BitwiseXor => _builder.BuildXor(left, right),
BinaryOperator.BitwiseOr => _builder.BuildOr(left, right),
_ => throw new ArgumentOutOfRangeException(nameof(bin.Operator))
};
return result;
}
private LLVMValueRef EmitUnaryExpression(UnaryExpressionNode unary)
{
var operand = EmitExpression(unary.Operand);
var result = unary.Operator switch
{
UnaryOperator.Negate when unary.Operand.Type is NubIntType => _builder.BuildNeg(operand),
UnaryOperator.Negate when unary.Operand.Type is NubFloatType => _builder.BuildFNeg(operand),
UnaryOperator.Invert => _builder.BuildXor(operand, LLVMValueRef.CreateConstInt(LLVMTypeRef.Int1, 1)),
_ => throw new ArgumentOutOfRangeException(nameof(unary.Operator))
};
return result;
}
private LLVMValueRef EmitFuncCall(FuncCallNode funcCall)
{
var funcPtr = EmitExpression(funcCall.Expression);
var args = funcCall.Parameters.Select(x => EmitExpression(x)).ToArray();
return _builder.BuildCall2(MapType(funcCall.Expression.Type), funcPtr, args, funcCall.Type is NubVoidType ? "" : "call");
}
private LLVMValueRef EmitStructFieldAccess(StructFieldAccessNode field)
{
var target = EmitExpression(field.Target, asLValue: true);
var structType = (NubStructType)field.Target.Type;
var index = structType.GetFieldIndex(field.FieldToken.Value);
var llvmStructType = _structTypes[StructName(structType.Module, structType.Name)];
return _builder.BuildStructGEP2(llvmStructType, target, (uint)index);
}
private LLVMValueRef EmitConstArrayIndexAccess(ConstArrayIndexAccessNode constArrayIndexAccessNode)
{
var arrayPtr = EmitExpression(constArrayIndexAccessNode.Target, asLValue: true);
var index = EmitExpression(constArrayIndexAccessNode.Index);
var indices = new[] { LLVMValueRef.CreateConstInt(LLVMTypeRef.Int32, 0), index };
return _builder.BuildInBoundsGEP2(MapType(constArrayIndexAccessNode.Target.Type), arrayPtr, indices);
}
private LLVMValueRef EmitSliceIndexAccess(SliceIndexAccessNode sliceIndexAccessNode)
{
var slicePtr = EmitExpression(sliceIndexAccessNode.Target, asLValue: true);
var index = EmitExpression(sliceIndexAccessNode.Index);
var sliceType = (NubSliceType)sliceIndexAccessNode.Target.Type;
var llvmSliceType = MapType(sliceType);
var elementType = MapType(sliceType.ElementType);
var dataPtrPtr = _builder.BuildStructGEP2(llvmSliceType, slicePtr, 1);
var dataPtr = _builder.BuildLoad2(LLVMTypeRef.CreatePointer(elementType, 0), dataPtrPtr);
return _builder.BuildInBoundsGEP2(elementType, dataPtr, [index]);
}
private LLVMValueRef EmitArrayIndexAccess(ArrayIndexAccessNode arrayIndexAccessNode)
{
var arrayPtr = EmitExpression(arrayIndexAccessNode.Target);
var index = EmitExpression(arrayIndexAccessNode.Index);
return _builder.BuildGEP2(MapType(arrayIndexAccessNode.Target.Type), arrayPtr, [index]);
}
private LLVMValueRef EmitConstArrayInitializer(ConstArrayInitializerNode constArrayInitializerNode, LLVMValueRef? destination = null)
{
var arrayType = (NubConstArrayType)constArrayInitializerNode.Type;
var llvmType = MapType(arrayType);
destination ??= _builder.BuildAlloca(llvmType);
for (var i = 0; i < constArrayInitializerNode.Values.Count; i++)
{
var indices = new[]
{
LLVMValueRef.CreateConstInt(LLVMTypeRef.Int32, 0),
LLVMValueRef.CreateConstInt(LLVMTypeRef.Int32, (ulong)i)
};
var elementPtr = _builder.BuildInBoundsGEP2(llvmType, destination.Value, indices);
EmitExpressionInto(constArrayInitializerNode.Values[i], elementPtr);
}
return destination.Value;
}
private LLVMValueRef EmitStructInitializer(StructInitializerNode structInitializerNode, LLVMValueRef? destination = null)
{
var type = (NubStructType)structInitializerNode.Type;
var llvmType = MapType(type);
destination ??= _builder.BuildAlloca(llvmType);
var constructorType = LLVMTypeRef.CreateFunction(LLVMTypeRef.Void, [LLVMTypeRef.CreatePointer(_structTypes[StructName(type.Module, type.Name)], 0)]);
var constructor = _functions[StructConstructorName(type.Module, type.Name)];
_builder.BuildCall2(constructorType, constructor, [destination.Value]);
foreach (var (name, value) in structInitializerNode.Initializers)
{
var fieldIndex = type.GetFieldIndex(name.Value);
var fieldPtr = _builder.BuildStructGEP2(llvmType, destination.Value, (uint)fieldIndex);
EmitExpressionInto(value, fieldPtr);
}
return destination.Value;
}
private LLVMValueRef EmitAddressOf(AddressOfNode addr)
{
return EmitExpression(addr.Target, asLValue: true);
}
private LLVMValueRef EmitDereference(DereferenceNode deref)
{
return EmitExpression(deref.Target, asLValue: false);
}
private LLVMValueRef EmitCast(CastNode castNode)
{
return castNode.ConversionType switch
{
CastNode.Conversion.IntToInt => EmitIntToIntCast(castNode),
CastNode.Conversion.FloatToFloat => EmitFloatToFloatCast(castNode),
CastNode.Conversion.IntToFloat => EmitIntToFloatCast(castNode),
CastNode.Conversion.FloatToInt => EmitFloatToIntCast(castNode),
CastNode.Conversion.PointerToPointer or CastNode.Conversion.PointerToUInt64 or CastNode.Conversion.UInt64ToPointer => _builder.BuildIntToPtr(EmitExpression(castNode.Value), MapType(castNode.Type)),
CastNode.Conversion.ConstArrayToSlice => EmitConstArrayToSliceCast(castNode),
CastNode.Conversion.ConstArrayToArray => EmitConstArrayToArrayCast(castNode),
CastNode.Conversion.StringToCString => EmitStringToCStringCast(castNode),
_ => throw new ArgumentOutOfRangeException(nameof(castNode.ConversionType))
};
}
private LLVMValueRef EmitIntToIntCast(CastNode castNode)
{
var sourceInt = (NubIntType)castNode.Value.Type;
var targetInt = (NubIntType)castNode.Type;
var source = EmitExpression(castNode.Value);
if (sourceInt.Width < targetInt.Width)
{
return sourceInt.Signed
? _builder.BuildSExt(source, MapType(targetInt))
: _builder.BuildZExt(source, MapType(targetInt));
}
if (sourceInt.Width > targetInt.Width)
{
return _builder.BuildTrunc(source, MapType(targetInt));
}
return _builder.BuildBitCast(source, MapType(targetInt));
}
private LLVMValueRef EmitFloatToFloatCast(CastNode castNode)
{
var sourceFloat = (NubFloatType)castNode.Value.Type;
var targetFloat = (NubFloatType)castNode.Type;
var source = EmitExpression(castNode.Value);
return sourceFloat.Width < targetFloat.Width
? _builder.BuildFPExt(source, MapType(castNode.Type))
: _builder.BuildFPTrunc(source, MapType(castNode.Type));
}
private LLVMValueRef EmitIntToFloatCast(CastNode castNode)
{
var sourceInt = (NubIntType)castNode.Value.Type;
var source = EmitExpression(castNode.Value);
return sourceInt.Signed
? _builder.BuildSIToFP(source, MapType(castNode.Type))
: _builder.BuildUIToFP(source, MapType(castNode.Type));
}
private LLVMValueRef EmitFloatToIntCast(CastNode castNode)
{
var targetInt = (NubIntType)castNode.Type;
var source = EmitExpression(castNode.Value);
return targetInt.Signed
? _builder.BuildFPToSI(source, MapType(targetInt))
: _builder.BuildFPToUI(source, MapType(targetInt));
}
private LLVMValueRef EmitConstArrayToSliceCast(CastNode castNode)
{
var sourceArrayType = (NubConstArrayType)castNode.Value.Type;
var targetSliceType = (NubSliceType)castNode.Type;
var source = EmitExpression(castNode.Value, asLValue: true);
var llvmArrayType = MapType(sourceArrayType);
var llvmSliceType = MapType(targetSliceType);
var indices = new[]
{
LLVMValueRef.CreateConstInt(LLVMTypeRef.Int32, 0),
LLVMValueRef.CreateConstInt(LLVMTypeRef.Int32, 0)
};
var firstElementPtr = _builder.BuildInBoundsGEP2(llvmArrayType, source, indices);
var slicePtr = _builder.BuildAlloca(llvmSliceType);
var lengthPtr = _builder.BuildStructGEP2(llvmSliceType, slicePtr, 0);
var length = LLVMValueRef.CreateConstInt(LLVMTypeRef.Int64, sourceArrayType.Size);
_builder.BuildStore(length, lengthPtr);
var dataPtrPtr = _builder.BuildStructGEP2(llvmSliceType, slicePtr, 1);
_builder.BuildStore(firstElementPtr, dataPtrPtr);
return _builder.BuildLoad2(llvmSliceType, slicePtr);
}
private LLVMValueRef EmitConstArrayToArrayCast(CastNode castNode)
{
var sourceArrayType = (NubConstArrayType)castNode.Value.Type;
var source = EmitExpression(castNode.Value, asLValue: true);
var indices = new[]
{
LLVMValueRef.CreateConstInt(LLVMTypeRef.Int32, 0),
LLVMValueRef.CreateConstInt(LLVMTypeRef.Int32, 0)
};
return _builder.BuildInBoundsGEP2(MapType(sourceArrayType), source, indices);
}
private LLVMValueRef EmitStringToCStringCast(CastNode castNode)
{
var source = EmitExpression(castNode.Value, asLValue: true);
var dataPtrPtr = _builder.BuildStructGEP2(MapType(castNode.Value.Type), source, 1);
return _builder.BuildLoad2(LLVMTypeRef.CreatePointer(LLVMTypeRef.Int8, 0), dataPtrPtr);
}
private LLVMTypeRef MapType(NubType type)
{
return type switch
{
NubBoolType => LLVMTypeRef.Int1,
NubIntType intType => LLVMTypeRef.CreateInt((uint)intType.Width),
NubFloatType floatType => floatType.Width == 32 ? LLVMTypeRef.Float : LLVMTypeRef.Double,
NubFuncType funcType => LLVMTypeRef.CreateFunction(MapType(funcType.ReturnType), funcType.Parameters.Select(MapType).ToArray()),
NubPointerType ptrType => LLVMTypeRef.CreatePointer(MapType(ptrType.BaseType), 0),
NubSliceType nubSliceType => MapSliceType(nubSliceType),
NubStringType => _structTypes["nub.string"],
NubArrayType arrType => LLVMTypeRef.CreatePointer(MapType(arrType.ElementType), 0),
NubConstArrayType constArr => LLVMTypeRef.CreateArray(MapType(constArr.ElementType), (uint)constArr.Size),
NubStructType structType => _structTypes[StructName(structType.Module, structType.Name)],
NubVoidType => LLVMTypeRef.Void,
_ => throw new ArgumentOutOfRangeException(nameof(type), type, null)
};
}
private LLVMTypeRef MapSliceType(NubSliceType nubSliceType)
{
var mangledName = NameMangler.Mangle(nubSliceType.ElementType);
var name = $"nub.slice.{mangledName}";
if (!_structTypes.TryGetValue(name, out var type))
{
type = _context.CreateNamedStruct(name);
type.StructSetBody([LLVMTypeRef.Int64, LLVMTypeRef.CreatePointer(MapType(nubSliceType.ElementType), 0)], false);
_structTypes[name] = type;
}
return type;
}
private static string StructName(string module, string name)
{
return $"struct.{module}.{name}";
}
private static string StructConstructorName(string module, string name)
{
return $"{StructName(module, name)}.new";
}
private static string FuncName(string module, string name, string? externSymbol)
{
if (externSymbol != null)
{
return externSymbol;
}
return $"{module}.{name}";
}
}

View File

@@ -0,0 +1,308 @@
using System.Diagnostics.CodeAnalysis;
using NubLang.Ast;
using NubLang.Diagnostics;
using NubLang.Syntax;
using NubLang.Types;
namespace NubLang.Modules;
public sealed class ModuleRepository
{
public static ModuleRepository Create(List<SyntaxTree> syntaxTrees)
{
var structTypes = new Dictionary<(string module, string name), NubStructType>();
var enumTypes = new Dictionary<(string module, string name), NubIntType>();
foreach (var syntaxTree in syntaxTrees)
{
var module = syntaxTree.TopLevelSyntaxNodes.OfType<ModuleSyntax>().FirstOrDefault();
if (module == null)
{
throw new CompileException(Diagnostic.Error("Module declaration missing").WithHelp("module \"main\"").Build());
}
foreach (var structSyntax in syntaxTree.TopLevelSyntaxNodes.OfType<StructSyntax>())
{
// note(nub31): Since not all struct types are registered yet, we cannot register field types as they might reference unregistered structs
var key = (module.NameToken.Value, structSyntax.NameToken.Value);
structTypes.Add(key, new NubStructType(module.NameToken.Value, structSyntax.NameToken.Value, structSyntax.Packed, []));
}
foreach (var enumSyntax in syntaxTree.TopLevelSyntaxNodes.OfType<EnumSyntax>())
{
NubIntType? underlyingType = null;
if (enumSyntax.Type != null)
{
if (enumSyntax.Type is not IntTypeSyntax intType)
{
throw new CompileException(Diagnostic.Error("Underlying type of enum must be an integer type").At(enumSyntax.Type).Build());
}
underlyingType = new NubIntType(intType.Signed, intType.Width);
}
underlyingType ??= new NubIntType(false, 64);
var key = (module.NameToken.Value, enumSyntax.NameToken.Value);
enumTypes.Add(key, underlyingType);
}
}
// note(nub31): Since all struct types are now registered, we can safely resolve the field types
foreach (var syntaxTree in syntaxTrees)
{
var module = syntaxTree.TopLevelSyntaxNodes.OfType<ModuleSyntax>().FirstOrDefault();
if (module == null)
{
throw new CompileException(Diagnostic.Error("Module declaration missing").WithHelp("module \"main\"").Build());
}
foreach (var structSyntax in syntaxTree.TopLevelSyntaxNodes.OfType<StructSyntax>())
{
var key = (module.NameToken.Value, structSyntax.NameToken.Value);
structTypes[key].Fields = structSyntax.Fields
.Select(x => new NubStructFieldType(x.NameToken.Value, ResolveType(x.Type, module.NameToken.Value), x.Value != null))
.ToList();
}
}
var modules = new Dictionary<string, Module>();
foreach (var syntaxTree in syntaxTrees)
{
var moduleDecl = syntaxTree.TopLevelSyntaxNodes.OfType<ModuleSyntax>().FirstOrDefault();
if (moduleDecl == null)
{
throw new CompileException(Diagnostic.Error("Module declaration missing").WithHelp("module \"main\"").Build());
}
var functionPrototypes = new List<FuncPrototypeNode>();
foreach (var funcSyntax in syntaxTree.TopLevelSyntaxNodes.OfType<FuncSyntax>())
{
var returnType = ResolveType(funcSyntax.Prototype.ReturnType, moduleDecl.NameToken.Value);
var parameters = funcSyntax.Prototype.Parameters.Select(x => new FuncParameterNode(x.Tokens, x.NameToken, ResolveType(x.Type, moduleDecl.NameToken.Value))).ToList();
functionPrototypes.Add(new FuncPrototypeNode(funcSyntax.Prototype.Tokens, funcSyntax.Prototype.NameToken, funcSyntax.Prototype.ExternSymbolToken, parameters, returnType));
}
var module = new Module
{
Name = moduleDecl.NameToken.Value,
StructTypes = structTypes.Where(x => x.Key.module == moduleDecl.NameToken.Value).Select(x => x.Value).ToList(),
EnumTypes = enumTypes
.Where(x => x.Key.module == moduleDecl.NameToken.Value)
.ToDictionary(x => x.Key.name, x => x.Value),
FunctionPrototypes = functionPrototypes
};
modules.Add(moduleDecl.NameToken.Value, module);
}
return new ModuleRepository(modules);
NubType ResolveType(TypeSyntax type, string currentModule)
{
return type switch
{
ArrayTypeSyntax arr => new NubArrayType(ResolveType(arr.BaseType, currentModule)),
BoolTypeSyntax => new NubBoolType(),
IntTypeSyntax i => new NubIntType(i.Signed, i.Width),
FloatTypeSyntax f => new NubFloatType(f.Width),
FuncTypeSyntax func => new NubFuncType(func.Parameters.Select(x => ResolveType(x, currentModule)).ToList(), ResolveType(func.ReturnType, currentModule)),
SliceTypeSyntax slice => new NubSliceType(ResolveType(slice.BaseType, currentModule)),
ConstArrayTypeSyntax arr => new NubConstArrayType(ResolveType(arr.BaseType, currentModule), arr.Size),
PointerTypeSyntax ptr => new NubPointerType(ResolveType(ptr.BaseType, currentModule)),
StringTypeSyntax => new NubStringType(),
CustomTypeSyntax c => ResolveCustomType(c, currentModule),
VoidTypeSyntax => new NubVoidType(),
_ => throw new NotSupportedException($"Unknown type syntax: {type}")
};
}
NubType ResolveCustomType(CustomTypeSyntax customType, string currentModule)
{
var customTypeKey = (customType.ModuleToken?.Value ?? currentModule, customType.NameToken.Value);
var resolvedStructType = structTypes.GetValueOrDefault(customTypeKey);
if (resolvedStructType != null)
{
return resolvedStructType;
}
var resolvedEnumType = enumTypes.GetValueOrDefault(customTypeKey);
if (resolvedEnumType != null)
{
return resolvedEnumType;
}
throw new CompileException(Diagnostic
.Error($"Type {customType.NameToken.Value} not found in module {customType.ModuleToken?.Value ?? currentModule}")
.At(customType)
.Build());
}
}
public ModuleRepository(Dictionary<string, Module> modules)
{
_modules = modules;
}
private readonly Dictionary<string, Module> _modules;
public Module Get(IdentifierToken ident)
{
var module = _modules.GetValueOrDefault(ident.Value);
if (module == null)
{
throw new CompileException(Diagnostic.Error($"Module {ident.Value} was not found").At(ident).Build());
}
return module;
}
public bool TryGet(IdentifierToken ident, [NotNullWhen(true)] out Module? module)
{
module = _modules.GetValueOrDefault(ident.Value);
return module != null;
}
public bool TryGet(string name, [NotNullWhen(true)] out Module? module)
{
module = _modules.GetValueOrDefault(name);
return module != null;
}
public List<Module> GetAll()
{
return _modules.Values.ToList();
}
public sealed class Module
{
public required string Name { get; init; }
public required List<FuncPrototypeNode> FunctionPrototypes { get; init; } = [];
public required List<NubStructType> StructTypes { get; init; } = [];
public required Dictionary<string, NubIntType> EnumTypes { get; init; } = [];
public bool TryResolveFunc(string name, [NotNullWhen(true)] out FuncPrototypeNode? value, [NotNullWhen(false)] out Diagnostic? diagnostic)
{
value = FunctionPrototypes.FirstOrDefault(x => x.NameToken.Value == name);
if (value == null)
{
value = null;
diagnostic = Diagnostic.Error($"Func {name} not found in module {Name}").Build();
return false;
}
diagnostic = null;
return true;
}
public bool TryResolveFunc(IdentifierToken name, [NotNullWhen(true)] out FuncPrototypeNode? value, [NotNullWhen(false)] out Diagnostic? diagnostic)
{
value = FunctionPrototypes.FirstOrDefault(x => x.NameToken.Value == name.Value);
if (value == null)
{
value = null;
diagnostic = Diagnostic.Error($"Func {name.Value} not found in module {Name}").At(name).Build();
return false;
}
diagnostic = null;
return true;
}
public FuncPrototypeNode ResolveFunc(IdentifierToken name)
{
if (!TryResolveFunc(name, out var value, out var diagnostic))
{
throw new CompileException(diagnostic);
}
return value;
}
public bool TryResolveStruct(string name, [NotNullWhen(true)] out NubStructType? value, [NotNullWhen(false)] out Diagnostic? diagnostic)
{
value = StructTypes.FirstOrDefault(x => x.Name == name);
if (value == null)
{
value = null;
diagnostic = Diagnostic.Error($"Struct {name} not found in module {Name}").Build();
return false;
}
diagnostic = null;
return true;
}
public bool TryResolveStruct(IdentifierToken name, [NotNullWhen(true)] out NubStructType? value, [NotNullWhen(false)] out Diagnostic? diagnostic)
{
value = StructTypes.FirstOrDefault(x => x.Name == name.Value);
if (value == null)
{
value = null;
diagnostic = Diagnostic.Error($"Struct {name.Value} not found in module {Name}").At(name).Build();
return false;
}
diagnostic = null;
return true;
}
public NubStructType ResolveStruct(IdentifierToken name)
{
if (!TryResolveStruct(name, out var value, out var diagnostic))
{
throw new CompileException(diagnostic);
}
return value;
}
public bool TryResolveEnum(string name, [NotNullWhen(true)] out NubIntType? value, [NotNullWhen(false)] out Diagnostic? diagnostic)
{
value = EnumTypes.GetValueOrDefault(name);
if (value == null)
{
value = null;
diagnostic = Diagnostic.Error($"Enum {name} not found in module {Name}").Build();
return false;
}
diagnostic = null;
return true;
}
public bool TryResolveEnum(IdentifierToken name, [NotNullWhen(true)] out NubIntType? value, [NotNullWhen(false)] out Diagnostic? diagnostic)
{
value = EnumTypes.GetValueOrDefault(name.Value);
if (value == null)
{
value = null;
diagnostic = Diagnostic.Error($"Enum {name.Value} not found in module {Name}").At(name).Build();
return false;
}
diagnostic = null;
return true;
}
public NubIntType ResolveEnum(IdentifierToken name)
{
if (!TryResolveEnum(name, out var value, out var diagnostic))
{
throw new CompileException(diagnostic);
}
return value;
}
}
}

View File

@@ -7,4 +7,8 @@
<IsAotCompatible>true</IsAotCompatible> <IsAotCompatible>true</IsAotCompatible>
</PropertyGroup> </PropertyGroup>
<ItemGroup>
<PackageReference Include="LLVMSharp" Version="20.1.2" />
</ItemGroup>
</Project> </Project>

View File

@@ -1,60 +0,0 @@
namespace NubLang.Syntax;
public sealed class Module
{
public static Dictionary<string, Module> Collect(List<SyntaxTree> syntaxTrees)
{
var modules = new Dictionary<string, Module>();
foreach (var syntaxTree in syntaxTrees)
{
var moduleDeclaration = syntaxTree.TopLevelSyntaxNodes.OfType<ModuleSyntax>().FirstOrDefault();
if (moduleDeclaration != null)
{
if (!modules.TryGetValue(moduleDeclaration.NameToken.Value, out var module))
{
module = new Module();
modules.Add(moduleDeclaration.NameToken.Value, module);
}
module._definitions.AddRange(syntaxTree.TopLevelSyntaxNodes);
}
}
return modules;
}
private readonly List<TopLevelSyntaxNode> _definitions = [];
public List<StructSyntax> Structs(bool includePrivate)
{
return _definitions
.OfType<StructSyntax>()
.Where(x => x.Exported || includePrivate)
.ToList();
}
public List<FuncSyntax> Functions(bool includePrivate)
{
return _definitions
.OfType<FuncSyntax>()
.Where(x => x.Exported || includePrivate)
.ToList();
}
public List<EnumSyntax> Enums(bool includePrivate)
{
return _definitions
.OfType<EnumSyntax>()
.Where(x => x.Exported || includePrivate)
.ToList();
}
public List<string> Imports()
{
return _definitions
.OfType<ImportSyntax>()
.Select(x => x.NameToken.Value)
.Distinct()
.ToList();
}
}

View File

@@ -9,16 +9,30 @@ public sealed class Parser
private int _tokenIndex; private int _tokenIndex;
private Token? CurrentToken => _tokenIndex < _tokens.Count ? _tokens[_tokenIndex] : null; private Token? CurrentToken => _tokenIndex < _tokens.Count ? _tokens[_tokenIndex] : null;
private bool HasTrailingWhitespace(Token token)
{
var index = _tokens.IndexOf(token);
return index + 1 < _tokens.Count && _tokens[index + 1] is WhitespaceToken or CommentToken;
}
private bool HasLeadingWhitespace(Token token)
{
var index = _tokens.IndexOf(token);
return index - 1 < _tokens.Count && _tokens[index - 1] is WhitespaceToken or CommentToken;
}
private bool HasToken => CurrentToken != null; private bool HasToken => CurrentToken != null;
public List<Diagnostic> Diagnostics { get; } = []; public List<Diagnostic> Diagnostics { get; set; } = [];
public SyntaxTree Parse(List<Token> tokens) public SyntaxTree Parse(List<Token> tokens)
{ {
Diagnostics.Clear();
_tokens = tokens; _tokens = tokens;
_tokenIndex = 0; _tokenIndex = 0;
Diagnostics = [];
var topLevelSyntaxNodes = new List<TopLevelSyntaxNode>(); var topLevelSyntaxNodes = new List<TopLevelSyntaxNode>();
while (HasToken) while (HasToken)
@@ -42,14 +56,13 @@ public sealed class Parser
TopLevelSyntaxNode definition = keyword.Symbol switch TopLevelSyntaxNode definition = keyword.Symbol switch
{ {
Symbol.Module => ParseModule(startIndex), Symbol.Module => ParseModule(startIndex),
Symbol.Import => ParseImport(startIndex),
Symbol.Func => ParseFunc(startIndex, exported, null), Symbol.Func => ParseFunc(startIndex, exported, null),
Symbol.Struct => ParseStruct(startIndex, exported, packed), Symbol.Struct => ParseStruct(startIndex, exported, packed),
Symbol.Enum => ParseEnum(startIndex, exported), Symbol.Enum => ParseEnum(startIndex, exported),
_ => throw new CompileException(Diagnostic _ => throw new CompileException(Diagnostic
.Error($"Expected 'func', 'struct', 'enum', 'import' or 'module' but found '{keyword.Symbol}'") .Error($"Expected 'func', 'struct', 'enum', 'import' or 'module' but found '{keyword.Symbol}'")
.WithHelp("Valid top level statements are 'func', 'struct', 'enum', 'import' and 'module'") .WithHelp("Valid top level statements are 'func', 'struct', 'enum', 'import' and 'module'")
.At(keyword) .At(keyword, _tokens)
.Build()) .Build())
}; };
@@ -70,13 +83,7 @@ public sealed class Parser
} }
} }
return new SyntaxTree(topLevelSyntaxNodes); return new SyntaxTree(topLevelSyntaxNodes, _tokens);
}
private ImportSyntax ParseImport(int startIndex)
{
var name = ExpectIdentifier();
return new ImportSyntax(GetTokens(startIndex), name);
} }
private ModuleSyntax ParseModule(int startIndex) private ModuleSyntax ParseModule(int startIndex)
@@ -183,7 +190,7 @@ public sealed class Parser
{ {
throw new CompileException(Diagnostic throw new CompileException(Diagnostic
.Error("Value of enum field must be an integer literal") .Error("Value of enum field must be an integer literal")
.At(CurrentToken) .At(CurrentToken, _tokens)
.Build()); .Build());
} }
@@ -200,27 +207,36 @@ public sealed class Parser
{ {
var startIndex = _tokenIndex; var startIndex = _tokenIndex;
if (TryExpectSymbol(out var symbol)) if (CurrentToken is SymbolToken symbolToken)
{ {
switch (symbol) switch (symbolToken.Symbol)
{ {
case Symbol.OpenBrace: case Symbol.OpenBrace:
Next();
return ParseBlock(startIndex); return ParseBlock(startIndex);
case Symbol.Return: case Symbol.Return:
Next();
return ParseReturn(startIndex); return ParseReturn(startIndex);
case Symbol.If: case Symbol.If:
Next();
return ParseIf(startIndex); return ParseIf(startIndex);
case Symbol.While: case Symbol.While:
Next();
return ParseWhile(startIndex); return ParseWhile(startIndex);
case Symbol.For: case Symbol.For:
Next();
return ParseFor(startIndex); return ParseFor(startIndex);
case Symbol.Let: case Symbol.Let:
Next();
return ParseVariableDeclaration(startIndex); return ParseVariableDeclaration(startIndex);
case Symbol.Defer: case Symbol.Defer:
Next();
return ParseDefer(startIndex); return ParseDefer(startIndex);
case Symbol.Break: case Symbol.Break:
Next();
return new BreakSyntax(GetTokens(startIndex)); return new BreakSyntax(GetTokens(startIndex));
case Symbol.Continue: case Symbol.Continue:
Next();
return new ContinueSyntax(GetTokens(startIndex)); return new ContinueSyntax(GetTokens(startIndex));
} }
} }
@@ -325,7 +341,7 @@ public sealed class Parser
var startIndex = _tokenIndex; var startIndex = _tokenIndex;
var left = ParsePrimaryExpression(); var left = ParsePrimaryExpression();
while (CurrentToken is SymbolToken symbolToken && TryGetBinaryOperator(symbolToken.Symbol, out var op) && GetBinaryOperatorPrecedence(op.Value) >= precedence) while (CurrentToken is SymbolToken symbolToken && HasLeadingWhitespace(symbolToken) && HasTrailingWhitespace(symbolToken) && TryGetBinaryOperator(symbolToken.Symbol, out var op) && GetBinaryOperatorPrecedence(op.Value) >= precedence)
{ {
Next(); Next();
var right = ParseExpression(GetBinaryOperatorPrecedence(op.Value) + 1); var right = ParseExpression(GetBinaryOperatorPrecedence(op.Value) + 1);
@@ -423,7 +439,7 @@ public sealed class Parser
case Symbol.Pipe: case Symbol.Pipe:
binaryExpressionOperator = BinaryOperatorSyntax.BitwiseOr; binaryExpressionOperator = BinaryOperatorSyntax.BitwiseOr;
return true; return true;
case Symbol.XOr: case Symbol.Tilde:
binaryExpressionOperator = BinaryOperatorSyntax.BitwiseXor; binaryExpressionOperator = BinaryOperatorSyntax.BitwiseXor;
return true; return true;
default: default:
@@ -445,76 +461,37 @@ public sealed class Parser
IdentifierToken identifier => ParseIdentifier(startIndex, identifier), IdentifierToken identifier => ParseIdentifier(startIndex, identifier),
SymbolToken symbolToken => symbolToken.Symbol switch SymbolToken symbolToken => symbolToken.Symbol switch
{ {
Symbol.Ampersand => new AddressOfSyntax(GetTokens(startIndex), ParsePrimaryExpression()), Symbol.Caret => ParseAddressOf(startIndex),
Symbol.OpenParen => ParseParenthesizedExpression(), Symbol.OpenParen => ParseParenthesizedExpression(),
Symbol.Minus => new UnaryExpressionSyntax(GetTokens(startIndex), UnaryOperatorSyntax.Negate, ParsePrimaryExpression()), Symbol.Minus => ParseUnaryNegate(startIndex),
Symbol.Bang => new UnaryExpressionSyntax(GetTokens(startIndex), UnaryOperatorSyntax.Invert, ParsePrimaryExpression()), Symbol.Bang => ParseUnaryInvert(startIndex),
Symbol.OpenBracket => ParseArrayInitializer(startIndex), Symbol.OpenBracket => ParseArrayInitializer(startIndex),
Symbol.OpenBrace => new StructInitializerSyntax(GetTokens(startIndex), null, ParseStructInitializerBody()), Symbol.OpenBrace => ParseUnnamedStructInitializer(startIndex),
Symbol.Struct => ParseStructInitializer(startIndex), Symbol.Struct => ParseStructInitializer(startIndex),
Symbol.At => ParseBuiltinFunction(startIndex), Symbol.At => ParseBuiltinFunction(startIndex),
_ => throw new CompileException(Diagnostic _ => throw new CompileException(Diagnostic
.Error($"Unexpected symbol '{symbolToken.Symbol}' in expression") .Error($"Unexpected symbol '{symbolToken.Symbol}' in expression")
.WithHelp("Expected '(', '-', '!', '[' or '{'") .WithHelp("Expected '(', '-', '!', '[' or '{'")
.At(symbolToken) .At(symbolToken, _tokens)
.Build()) .Build())
}, },
_ => throw new CompileException(Diagnostic _ => throw new CompileException(Diagnostic
.Error($"Unexpected token '{token.GetType().Name}' in expression") .Error($"Unexpected token '{token.GetType().Name}' in expression")
.WithHelp("Expected literal, identifier, or parenthesized expression") .WithHelp("Expected literal, identifier, or parenthesized expression")
.At(token) .At(token, _tokens)
.Build()) .Build())
}; };
return ParsePostfixOperators(expr); return ParsePostfixOperators(expr);
} }
private ExpressionSyntax ParseBuiltinFunction(int startIndex)
{
var name = ExpectIdentifier();
ExpectSymbol(Symbol.OpenParen);
switch (name.Value)
{
case "size":
{
var type = ParseType();
ExpectSymbol(Symbol.CloseParen);
return new SizeSyntax(GetTokens(startIndex), type);
}
case "cast":
{
var expression = ParseExpression();
ExpectSymbol(Symbol.CloseParen);
return new CastSyntax(GetTokens(startIndex), expression);
}
default:
{
throw new CompileException(Diagnostic.Error($"Unknown builtin {name.Value}").At(name).Build());
}
}
}
private ExpressionSyntax ParseIdentifier(int startIndex, IdentifierToken identifier)
{
if (TryExpectSymbol(Symbol.DoubleColon))
{
var name = ExpectIdentifier();
return new ModuleIdentifierSyntax(GetTokens(startIndex), identifier, name);
}
return new LocalIdentifierSyntax(GetTokens(startIndex), identifier);
}
private ExpressionSyntax ParseParenthesizedExpression()
{
var expression = ParseExpression();
ExpectSymbol(Symbol.CloseParen);
return expression;
}
private ExpressionSyntax ParsePostfixOperators(ExpressionSyntax expr) private ExpressionSyntax ParsePostfixOperators(ExpressionSyntax expr)
{ {
if (CurrentToken == null || HasLeadingWhitespace(CurrentToken))
{
return expr;
}
var startIndex = _tokenIndex; var startIndex = _tokenIndex;
while (HasToken) while (HasToken)
{ {
@@ -563,6 +540,68 @@ public sealed class Parser
return expr; return expr;
} }
private ExpressionSyntax ParseParenthesizedExpression()
{
var expression = ParseExpression();
ExpectSymbol(Symbol.CloseParen);
return expression;
}
private AddressOfSyntax ParseAddressOf(int startIndex)
{
var expression = ParsePrimaryExpression();
return new AddressOfSyntax(GetTokens(startIndex), expression);
}
private UnaryExpressionSyntax ParseUnaryInvert(int startIndex)
{
var expression = ParsePrimaryExpression();
return new UnaryExpressionSyntax(GetTokens(startIndex), UnaryOperatorSyntax.Invert, expression);
}
private UnaryExpressionSyntax ParseUnaryNegate(int startIndex)
{
var expression = ParsePrimaryExpression();
return new UnaryExpressionSyntax(GetTokens(startIndex), UnaryOperatorSyntax.Negate, expression);
}
private ExpressionSyntax ParseBuiltinFunction(int startIndex)
{
var name = ExpectIdentifier();
ExpectSymbol(Symbol.OpenParen);
switch (name.Value)
{
case "size":
{
var type = ParseType();
ExpectSymbol(Symbol.CloseParen);
return new SizeSyntax(GetTokens(startIndex), type);
}
case "cast":
{
var expression = ParseExpression();
ExpectSymbol(Symbol.CloseParen);
return new CastSyntax(GetTokens(startIndex), expression);
}
default:
{
throw new CompileException(Diagnostic.Error($"Unknown builtin {name.Value}").At(name).Build());
}
}
}
private ExpressionSyntax ParseIdentifier(int startIndex, IdentifierToken identifier)
{
if (TryExpectSymbol(Symbol.DoubleColon))
{
var name = ExpectIdentifier();
return new ModuleIdentifierSyntax(GetTokens(startIndex), identifier, name);
}
return new LocalIdentifierSyntax(GetTokens(startIndex), identifier);
}
private ExpressionSyntax ParseArrayInitializer(int startIndex) private ExpressionSyntax ParseArrayInitializer(int startIndex)
{ {
var values = new List<ExpressionSyntax>(); var values = new List<ExpressionSyntax>();
@@ -593,6 +632,12 @@ public sealed class Parser
return new StructInitializerSyntax(GetTokens(startIndex), type, initializers); return new StructInitializerSyntax(GetTokens(startIndex), type, initializers);
} }
private StructInitializerSyntax ParseUnnamedStructInitializer(int startIndex)
{
var body = ParseStructInitializerBody();
return new StructInitializerSyntax(GetTokens(startIndex), null, body);
}
private Dictionary<IdentifierToken, ExpressionSyntax> ParseStructInitializerBody() private Dictionary<IdentifierToken, ExpressionSyntax> ParseStructInitializerBody()
{ {
Dictionary<IdentifierToken, ExpressionSyntax> initializers = []; Dictionary<IdentifierToken, ExpressionSyntax> initializers = [];
@@ -653,7 +698,7 @@ public sealed class Parser
throw new CompileException(Diagnostic throw new CompileException(Diagnostic
.Error("Arbitrary uint size is not supported") .Error("Arbitrary uint size is not supported")
.WithHelp("Use u8, u16, u32 or u64") .WithHelp("Use u8, u16, u32 or u64")
.At(name) .At(name, _tokens)
.Build()); .Build());
} }
@@ -667,7 +712,7 @@ public sealed class Parser
throw new CompileException(Diagnostic throw new CompileException(Diagnostic
.Error("Arbitrary int size is not supported") .Error("Arbitrary int size is not supported")
.WithHelp("Use i8, i16, i32 or i64") .WithHelp("Use i8, i16, i32 or i64")
.At(name) .At(name, _tokens)
.Build()); .Build());
} }
@@ -681,7 +726,7 @@ public sealed class Parser
throw new CompileException(Diagnostic throw new CompileException(Diagnostic
.Error("Arbitrary float size is not supported") .Error("Arbitrary float size is not supported")
.WithHelp("Use f32 or f64") .WithHelp("Use f32 or f64")
.At(name) .At(name, _tokens)
.Build()); .Build());
} }
@@ -765,7 +810,7 @@ public sealed class Parser
throw new CompileException(Diagnostic throw new CompileException(Diagnostic
.Error("Invalid type syntax") .Error("Invalid type syntax")
.WithHelp("Expected type name, '^' for pointer, or '[]' for array") .WithHelp("Expected type name, '^' for pointer, or '[]' for array")
.At(CurrentToken) .At(CurrentToken, _tokens)
.Build()); .Build());
} }
@@ -776,7 +821,7 @@ public sealed class Parser
throw new CompileException(Diagnostic throw new CompileException(Diagnostic
.Error("Unexpected end of file") .Error("Unexpected end of file")
.WithHelp("Expected more tokens to complete the syntax") .WithHelp("Expected more tokens to complete the syntax")
.At(_tokens[^1]) .At(_tokens[^1], _tokens)
.Build()); .Build());
} }
@@ -793,7 +838,7 @@ public sealed class Parser
throw new CompileException(Diagnostic throw new CompileException(Diagnostic
.Error($"Expected symbol, but found {token.GetType().Name}") .Error($"Expected symbol, but found {token.GetType().Name}")
.WithHelp("This position requires a symbol like '(', ')', '{', '}', etc.") .WithHelp("This position requires a symbol like '(', ')', '{', '}', etc.")
.At(token) .At(token, _tokens)
.Build()); .Build());
} }
@@ -808,7 +853,7 @@ public sealed class Parser
throw new CompileException(Diagnostic throw new CompileException(Diagnostic
.Error($"Expected '{expectedSymbol}', but found '{token.Symbol}'") .Error($"Expected '{expectedSymbol}', but found '{token.Symbol}'")
.WithHelp($"Insert '{expectedSymbol}' here") .WithHelp($"Insert '{expectedSymbol}' here")
.At(token) .At(token, _tokens)
.Build()); .Build());
} }
} }
@@ -858,7 +903,7 @@ public sealed class Parser
throw new CompileException(Diagnostic throw new CompileException(Diagnostic
.Error($"Expected identifier, but found {token.GetType().Name}") .Error($"Expected identifier, but found {token.GetType().Name}")
.WithHelp("Provide a valid identifier name here") .WithHelp("Provide a valid identifier name here")
.At(token) .At(token, _tokens)
.Build()); .Build());
} }
@@ -886,7 +931,7 @@ public sealed class Parser
throw new CompileException(Diagnostic throw new CompileException(Diagnostic
.Error($"Expected string literal, but found {token.GetType().Name}") .Error($"Expected string literal, but found {token.GetType().Name}")
.WithHelp("Provide a valid string literal") .WithHelp("Provide a valid string literal")
.At(token) .At(token, _tokens)
.Build()); .Build());
} }
@@ -896,6 +941,10 @@ public sealed class Parser
private void Next() private void Next()
{ {
_tokenIndex++; _tokenIndex++;
while (_tokenIndex < _tokens.Count && _tokens[_tokenIndex] is WhitespaceToken or CommentToken)
{
_tokenIndex++;
}
} }
private List<Token> GetTokens(int tokenStartIndex) private List<Token> GetTokens(int tokenStartIndex)
@@ -904,4 +953,4 @@ public sealed class Parser
} }
} }
public record SyntaxTree(List<TopLevelSyntaxNode> TopLevelSyntaxNodes); public record SyntaxTree(List<TopLevelSyntaxNode> TopLevelSyntaxNodes, List<Token> Tokens);

View File

@@ -8,8 +8,6 @@ public record TopLevelSyntaxNode(List<Token> Tokens) : SyntaxNode(Tokens);
public record ModuleSyntax(List<Token> Tokens, IdentifierToken NameToken) : TopLevelSyntaxNode(Tokens); public record ModuleSyntax(List<Token> Tokens, IdentifierToken NameToken) : TopLevelSyntaxNode(Tokens);
public record ImportSyntax(List<Token> Tokens, IdentifierToken NameToken) : TopLevelSyntaxNode(Tokens);
public abstract record DefinitionSyntax(List<Token> Tokens, IdentifierToken NameToken, bool Exported) : TopLevelSyntaxNode(Tokens); public abstract record DefinitionSyntax(List<Token> Tokens, IdentifierToken NameToken, bool Exported) : TopLevelSyntaxNode(Tokens);
public record FuncParameterSyntax(List<Token> Tokens, IdentifierToken NameToken, TypeSyntax Type) : SyntaxNode(Tokens); public record FuncParameterSyntax(List<Token> Tokens, IdentifierToken NameToken, TypeSyntax Type) : SyntaxNode(Tokens);

View File

@@ -2,29 +2,45 @@
namespace NubLang.Syntax; namespace NubLang.Syntax;
public abstract record Token(SourceSpan Span); public abstract class Token(SourceSpan span)
public record IdentifierToken(SourceSpan Span, string Value) : Token(Span)
{ {
public SourceSpan Span { get; } = span;
}
public class WhitespaceToken(SourceSpan span) : Token(span);
public class CommentToken(SourceSpan span, string comment) : Token(span)
{
public string Comment { get; } = comment;
public override string ToString()
{
return "// " + Comment;
}
}
public class IdentifierToken(SourceSpan span, string value) : Token(span)
{
public string Value { get; } = value;
public override string ToString() public override string ToString()
{ {
return Value; return Value;
} }
} }
public record IntLiteralToken(SourceSpan Span, string Value, int Base) : Token(Span) public class IntLiteralToken(SourceSpan span, string value, int @base) : Token(span)
{ {
public string Value { get; } = value;
public int Base { get; } = @base;
private string GetNumericValue() private string GetNumericValue()
{ {
// Strip base prefixes: 0b, 0o, 0x
return Base switch return Base switch
{ {
2 when Value.StartsWith("0b", StringComparison.OrdinalIgnoreCase) 2 when Value.StartsWith("0b", StringComparison.OrdinalIgnoreCase) => Value[2..],
=> Value.Substring(2), 8 when Value.StartsWith("0o", StringComparison.OrdinalIgnoreCase) => Value[2..],
8 when Value.StartsWith("0o", StringComparison.OrdinalIgnoreCase) 16 when Value.StartsWith("0x", StringComparison.OrdinalIgnoreCase) => Value[2..],
=> Value.Substring(2),
16 when Value.StartsWith("0x", StringComparison.OrdinalIgnoreCase)
=> Value.Substring(2),
_ => Value _ => Value
}; };
} }
@@ -47,24 +63,30 @@ public record IntLiteralToken(SourceSpan Span, string Value, int Base) : Token(S
} }
} }
public record StringLiteralToken(SourceSpan Span, string Value) : Token(Span) public class StringLiteralToken(SourceSpan span, string value) : Token(span)
{ {
public string Value { get; } = value;
public override string ToString() public override string ToString()
{ {
return $"\"{Value}\""; return $"\"{Value}\"";
} }
} }
public record BoolLiteralToken(SourceSpan Span, bool Value) : Token(Span) public class BoolLiteralToken(SourceSpan span, bool value) : Token(span)
{ {
public bool Value { get; } = value;
public override string ToString() public override string ToString()
{ {
return Value ? "true" : "false"; return Value ? "true" : "false";
} }
} }
public record FloatLiteralToken(SourceSpan Span, string Value) : Token(Span) public class FloatLiteralToken(SourceSpan span, string value) : Token(span)
{ {
public string Value { get; } = value;
public float AsF32 => Convert.ToSingle(Value); public float AsF32 => Convert.ToSingle(Value);
public double AsF64 => Convert.ToDouble(Value); public double AsF64 => Convert.ToDouble(Value);
@@ -95,7 +117,6 @@ public enum Symbol
Func, Func,
Struct, Struct,
Enum, Enum,
Import,
Module, Module,
// Modifier // Modifier
@@ -126,6 +147,7 @@ public enum Symbol
Star, Star,
ForwardSlash, ForwardSlash,
Caret, Caret,
Tilde,
Ampersand, Ampersand,
Semi, Semi,
Percent, Percent,
@@ -134,13 +156,14 @@ public enum Symbol
Pipe, Pipe,
And, And,
Or, Or,
XOr,
At, At,
QuestionMark, QuestionMark,
} }
public record SymbolToken(SourceSpan Span, Symbol Symbol) : Token(Span) public class SymbolToken(SourceSpan span, Symbol symbol) : Token(span)
{ {
public Symbol Symbol { get; } = symbol;
public override string ToString() public override string ToString()
{ {
return Symbol switch return Symbol switch
@@ -159,7 +182,6 @@ public record SymbolToken(SourceSpan Span, Symbol Symbol) : Token(Span)
Symbol.Extern => "extern", Symbol.Extern => "extern",
Symbol.Module => "module", Symbol.Module => "module",
Symbol.Export => "export", Symbol.Export => "export",
Symbol.Import => "import",
Symbol.Defer => "defer", Symbol.Defer => "defer",
Symbol.Enum => "enum", Symbol.Enum => "enum",
Symbol.Equal => "==", Symbol.Equal => "==",

View File

@@ -4,59 +4,32 @@ namespace NubLang.Syntax;
public sealed class Tokenizer public sealed class Tokenizer
{ {
private readonly string _fileName; private string _fileName = null!;
private readonly string _content; private string _content = null!;
private int _index; private int _index;
private int _line = 1; private int _line = 1;
private int _column = 1; private int _column = 1;
public Tokenizer(string fileName, string content) public List<Diagnostic> Diagnostics { get; set; } = new(16);
public List<Token> Tokenize(string fileName, string content)
{ {
_fileName = fileName; _fileName = fileName;
_content = content; _content = content;
}
public List<Diagnostic> Diagnostics { get; } = new(16); Diagnostics = [];
public List<Token> Tokens { get; } = new(256);
public void Tokenize()
{
Diagnostics.Clear();
Tokens.Clear();
_index = 0; _index = 0;
_line = 1; _line = 1;
_column = 1; _column = 1;
var tokens = new List<Token>();
while (_index < _content.Length) while (_index < _content.Length)
{ {
try try
{ {
var current = _content[_index]; tokens.Add(ParseToken());
if (char.IsWhiteSpace(current))
{
if (current == '\n')
{
_line += 1;
_column = 0;
}
Next();
continue;
}
if (current == '/' && _index + 1 < _content.Length && _content[_index + 1] == '/')
{
Next(2);
while (_index < _content.Length && _content[_index] != '\n')
{
Next();
}
continue;
}
Tokens.Add(ParseToken(current, _line, _column));
} }
catch (CompileException e) catch (CompileException e)
{ {
@@ -64,38 +37,67 @@ public sealed class Tokenizer
Next(); Next();
} }
} }
return tokens;
} }
private Token ParseToken(char current, int lineStart, int columnStart) private Token ParseToken()
{ {
if (char.IsDigit(current)) var indexStart = _index;
var lineStart = _line;
var columnStart = _column;
if (char.IsWhiteSpace(_content[_index]))
{ {
return ParseNumber(lineStart, columnStart); while (_index < _content.Length && char.IsWhiteSpace(_content[_index]))
{
Next();
}
return new WhitespaceToken(CreateSpan(indexStart, lineStart, columnStart));
} }
if (current == '"') if (_content[_index] == '/' && _index + 1 < _content.Length && _content[_index + 1] == '/')
{ {
return ParseString(lineStart, columnStart); var startIndex = _index;
Next(2);
while (_index < _content.Length && _content[_index] != '\n')
{
Next();
}
return new CommentToken(CreateSpan(indexStart, lineStart, columnStart), _content.AsSpan(startIndex, _index - startIndex).ToString());
}
if (char.IsDigit(_content[_index]))
{
return ParseNumber(indexStart, lineStart, columnStart);
}
if (_content[_index] == '"')
{
return ParseString(indexStart, lineStart, columnStart);
} }
// note(nub31): Look for keywords (longest first in case a keyword fits partially in a larger keyword) // note(nub31): Look for keywords (longest first in case a keyword fits partially in a larger keyword)
for (var i = 8; i >= 1; i--) for (var i = 8; i >= 1; i--)
{ {
if (TryMatchSymbol(i, lineStart, columnStart, out var token)) if (TryMatchSymbol(i, indexStart, lineStart, columnStart, out var token))
{ {
return token; return token;
} }
} }
if (char.IsLetter(current) || current == '_') if (char.IsLetter(_content[_index]) || _content[_index] == '_')
{ {
return ParseIdentifier(lineStart, columnStart); return ParseIdentifier(indexStart, lineStart, columnStart);
} }
throw new CompileException(Diagnostic.Error($"Unknown token '{current}'").Build()); throw new CompileException(Diagnostic.Error($"Unknown token '{_content[_index]}'").Build());
} }
private Token ParseNumber(int lineStart, int columnStart) private Token ParseNumber(int indexStart, int lineStart, int columnStart)
{ {
var start = _index; var start = _index;
var current = _content[_index]; var current = _content[_index];
@@ -115,12 +117,12 @@ public sealed class Tokenizer
{ {
throw new CompileException(Diagnostic throw new CompileException(Diagnostic
.Error("Invalid hex literal, no digits found") .Error("Invalid hex literal, no digits found")
.At(_fileName, _line, _column) .At(CreateSpan(_index, _line, _column))
.Build()); .Build());
} }
return new IntLiteralToken( return new IntLiteralToken(
CreateSpan(lineStart, columnStart), CreateSpan(indexStart, lineStart, columnStart),
_content.Substring(start, _index - start), _content.Substring(start, _index - start),
16); 16);
} }
@@ -140,12 +142,12 @@ public sealed class Tokenizer
{ {
throw new CompileException(Diagnostic throw new CompileException(Diagnostic
.Error("Invalid binary literal, no digits found") .Error("Invalid binary literal, no digits found")
.At(_fileName, _line, _column) .At(CreateSpan(_index, _line, _column))
.Build()); .Build());
} }
return new IntLiteralToken( return new IntLiteralToken(
CreateSpan(lineStart, columnStart), CreateSpan(indexStart, lineStart, columnStart),
_content.Substring(start, _index - start), _content.Substring(start, _index - start),
2); 2);
} }
@@ -162,7 +164,7 @@ public sealed class Tokenizer
{ {
throw new CompileException(Diagnostic throw new CompileException(Diagnostic
.Error("More than one period found in float literal") .Error("More than one period found in float literal")
.At(_fileName, _line, _column) .At(CreateSpan(_index, _line, _column))
.Build()); .Build());
} }
@@ -182,11 +184,11 @@ public sealed class Tokenizer
var buffer = _content.Substring(start, _index - start); var buffer = _content.Substring(start, _index - start);
return isFloat return isFloat
? new FloatLiteralToken(CreateSpan(lineStart, columnStart), buffer) ? new FloatLiteralToken(CreateSpan(indexStart, lineStart, columnStart), buffer)
: new IntLiteralToken(CreateSpan(lineStart, columnStart), buffer, 10); : new IntLiteralToken(CreateSpan(indexStart, lineStart, columnStart), buffer, 10);
} }
private StringLiteralToken ParseString(int lineStart, int columnStart) private StringLiteralToken ParseString(int indexStart, int lineStart, int columnStart)
{ {
Next(); Next();
var start = _index; var start = _index;
@@ -197,7 +199,7 @@ public sealed class Tokenizer
{ {
throw new CompileException(Diagnostic throw new CompileException(Diagnostic
.Error("Unclosed string literal") .Error("Unclosed string literal")
.At(_fileName, _line, _column) .At(CreateSpan(_index, _line, _column))
.Build()); .Build());
} }
@@ -207,7 +209,7 @@ public sealed class Tokenizer
{ {
throw new CompileException(Diagnostic throw new CompileException(Diagnostic
.Error("Unclosed string literal (newline found)") .Error("Unclosed string literal (newline found)")
.At(_fileName, _line, _column) .At(CreateSpan(_index, _line, _column))
.Build()); .Build());
} }
@@ -215,14 +217,14 @@ public sealed class Tokenizer
{ {
var buffer = _content.Substring(start, _index - start); var buffer = _content.Substring(start, _index - start);
Next(); Next();
return new StringLiteralToken(CreateSpan(lineStart, columnStart), buffer); return new StringLiteralToken(CreateSpan(indexStart, lineStart, columnStart), buffer);
} }
Next(); Next();
} }
} }
private bool TryMatchSymbol(int length, int lineStart, int columnStart, out Token token) private bool TryMatchSymbol(int length, int indexStart, int lineStart, int columnStart, out Token token)
{ {
token = null!; token = null!;
@@ -236,14 +238,14 @@ public sealed class Tokenizer
if (span is "true") if (span is "true")
{ {
Next(4); Next(4);
token = new BoolLiteralToken(CreateSpan(lineStart, columnStart), true); token = new BoolLiteralToken(CreateSpan(indexStart, lineStart, columnStart), true);
return true; return true;
} }
if (span is "false") if (span is "false")
{ {
Next(5); Next(5);
token = new BoolLiteralToken(CreateSpan(lineStart, columnStart), false); token = new BoolLiteralToken(CreateSpan(indexStart, lineStart, columnStart), false);
return true; return true;
} }
@@ -262,7 +264,6 @@ public sealed class Tokenizer
"packed" => Symbol.Packed, "packed" => Symbol.Packed,
"module" => Symbol.Module, "module" => Symbol.Module,
"export" => Symbol.Export, "export" => Symbol.Export,
"import" => Symbol.Import,
_ => Symbol.None _ => Symbol.None
}, },
5 => span switch 5 => span switch
@@ -298,7 +299,6 @@ public sealed class Tokenizer
"&&" => Symbol.And, "&&" => Symbol.And,
"||" => Symbol.Or, "||" => Symbol.Or,
"::" => Symbol.DoubleColon, "::" => Symbol.DoubleColon,
"x|" => Symbol.XOr,
_ => Symbol.None _ => Symbol.None
}, },
1 => span[0] switch 1 => span[0] switch
@@ -327,6 +327,7 @@ public sealed class Tokenizer
'|' => Symbol.Pipe, '|' => Symbol.Pipe,
'@' => Symbol.At, '@' => Symbol.At,
'?' => Symbol.QuestionMark, '?' => Symbol.QuestionMark,
'~' => Symbol.Tilde,
_ => Symbol.None _ => Symbol.None
}, },
_ => Symbol.None _ => Symbol.None
@@ -349,14 +350,14 @@ public sealed class Tokenizer
} }
Next(length); Next(length);
token = new SymbolToken(CreateSpan(lineStart, columnStart), symbol); token = new SymbolToken(CreateSpan(indexStart, lineStart, columnStart), symbol);
return true; return true;
} }
return false; return false;
} }
private IdentifierToken ParseIdentifier(int lineStart, int columnStart) private IdentifierToken ParseIdentifier(int indexStart, int lineStart, int columnStart)
{ {
var start = _index; var start = _index;
@@ -373,17 +374,36 @@ public sealed class Tokenizer
} }
} }
return new IdentifierToken(CreateSpan(lineStart, columnStart), _content.Substring(start, _index - start)); return new IdentifierToken(CreateSpan(indexStart, lineStart, columnStart), _content.Substring(start, _index - start));
} }
private SourceSpan CreateSpan(int lineStart, int columnStart) private SourceSpan CreateSpan(int indexStart, int lineStart, int columnStart)
{ {
return new SourceSpan(_fileName, new SourceLocation(lineStart, columnStart), new SourceLocation(_line, _column)); return new SourceSpan(_fileName, _content, indexStart, Math.Min(_index, _content.Length), lineStart, columnStart, _line, _column);
} }
private void Next(int count = 1) private void Next(int count = 1)
{ {
_index += count; for (var i = 0; i < count; i++)
_column += count; {
if (_index < _content.Length)
{
if (_content[_index] == '\n')
{
_line += 1;
_column = 1;
}
else
{
_column++;
}
}
else
{
_column++;
}
_index++;
}
} }
} }

View File

@@ -1,52 +0,0 @@
using NubLang.Ast;
namespace NubLang.Syntax;
public sealed class TypedModule
{
public static TypedModule FromModule(string name, Module module, Dictionary<string, Module> modules)
{
var typeResolver = new TypeResolver(modules);
var functionPrototypes = new List<FuncPrototypeNode>();
foreach (var funcSyntax in module.Functions(true))
{
var parameters = new List<FuncParameterNode>();
foreach (var parameter in funcSyntax.Prototype.Parameters)
{
parameters.Add(new FuncParameterNode(parameter.Tokens, parameter.NameToken, typeResolver.ResolveType(parameter.Type, name)));
}
var returnType = typeResolver.ResolveType(funcSyntax.Prototype.ReturnType, name);
functionPrototypes.Add(new FuncPrototypeNode(funcSyntax.Tokens, funcSyntax.Prototype.NameToken, funcSyntax.Prototype.ExternSymbolToken, parameters, returnType));
}
var structTypes = new List<NubStructType>();
foreach (var structSyntax in module.Structs(true))
{
var fields = new List<NubStructFieldType>();
foreach (var field in structSyntax.Fields)
{
fields.Add(new NubStructFieldType(field.NameToken.Value, typeResolver.ResolveType(field.Type, name), field.Value != null));
}
structTypes.Add(new NubStructType(name, structSyntax.NameToken.Value, structSyntax.Packed, fields));
}
return new TypedModule(functionPrototypes, structTypes, module.Imports());
}
public TypedModule(List<FuncPrototypeNode> functionPrototypes, List<NubStructType> structTypes, List<string> imports)
{
FunctionPrototypes = functionPrototypes;
StructTypes = structTypes;
Imports = imports;
}
public List<FuncPrototypeNode> FunctionPrototypes { get; }
public List<NubStructType> StructTypes { get; }
public List<string> Imports { get; }
}

View File

@@ -1,7 +1,7 @@
using System.Security.Cryptography; using System.Security.Cryptography;
using System.Text; using System.Text;
namespace NubLang.Ast; namespace NubLang.Types;
public abstract class NubType : IEquatable<NubType> public abstract class NubType : IEquatable<NubType>
{ {

View File

@@ -2,6 +2,13 @@ module main
extern "puts" func puts(text: ^i8) extern "puts" func puts(text: ^i8)
struct Test {
test: ^i8 = "test1"
}
extern "main" func main(argc: i64, argv: [?]^i8) extern "main" func main(argc: i64, argv: [?]^i8)
{ {
let x = "test"
puts(x)
} }

View File

@@ -1,5 +1,3 @@
import raylib
module main module main
extern "main" func main(argc: i64, argv: [?]^i8): i64 extern "main" func main(argc: i64, argv: [?]^i8): i64

View File

@@ -31,6 +31,12 @@
"configuration": "./language-configuration.json" "configuration": "./language-configuration.json"
} }
], ],
"commands": [
{
"command": "nub.setRootPath",
"title": "Set root path"
}
],
"grammars": [ "grammars": [
{ {
"language": "nub", "language": "nub",

View File

@@ -32,7 +32,19 @@ export async function activate(context: vscode.ExtensionContext) {
} }
); );
vscode.commands.registerCommand('nub.setRootPath', setRootPath);
client.start(); client.start();
const choice = await vscode.window.showInformationMessage(
'Do you want to set the root directory for the project',
'Yes',
'No'
);
if (choice === 'Yes') {
await setRootPath();
}
} }
export function deactivate(): Thenable<void> | undefined { export function deactivate(): Thenable<void> | undefined {
@@ -41,4 +53,26 @@ export function deactivate(): Thenable<void> | undefined {
} }
return client.stop(); return client.stop();
} }
async function setRootPath() {
if (!client) return;
const folder = await vscode.window.showOpenDialog({
canSelectFolders: true,
canSelectFiles: false,
canSelectMany: false,
openLabel: 'Select root location'
});
if (folder && folder.length > 0) {
const newRoot = folder[0].fsPath;
await client.sendRequest('workspace/executeCommand', {
command: 'nub.setRootPath',
arguments: [newRoot]
});
vscode.window.showInformationMessage(`Root path set to: ${newRoot}`);
}
}