This commit is contained in:
nub31
2025-06-12 22:49:58 +02:00
parent 7effe36988
commit e919d81737
15 changed files with 413 additions and 458 deletions

View File

@@ -6,9 +6,12 @@ dotnet build src/lang/Nub.Lang.CLI
mkdir -p bin-int bin mkdir -p bin-int bin
rm -rf bin-int/* bin/* rm -rf bin-int/* bin/*
nub example > bin-int/out.ssa nub example
qbe bin-int/out.ssa > bin-int/out.s
as -o bin-int/out.o bin-int/out.s find bin-int -name '*.ssa' | while read -r file; do
qbe "$file" > "bin-int/$(basename "${file}" .ssa).s"
as "bin-int/$(basename "${file}" .ssa).s" -o "bin-int/$(basename "${file}" .ssa).o"
done
find src/runtime -name '*.s' | while read -r file; do find src/runtime -name '*.s' | while read -r file; do
as "$file" -o "bin-int/$(basename "${file}" .s).o" as "$file" -o "bin-int/$(basename "${file}" .s).o"

View File

@@ -1,4 +1,5 @@
using Nub.Lang; using Nub.Lang;
using Nub.Lang.Frontend;
using Nub.Lang.Frontend.Generation; using Nub.Lang.Frontend.Generation;
using Nub.Lang.Frontend.Lexing; using Nub.Lang.Frontend.Lexing;
using Nub.Lang.Frontend.Parsing; using Nub.Lang.Frontend.Parsing;
@@ -20,38 +21,43 @@ if (!Directory.Exists(srcDir))
} }
var error = false; var error = false;
var lexer = new Lexer();
var parser = new Parser();
var typeChecker = new TypeChecker();
List<SourceFile> files = []; List<CompilationUnit> compilationUnits = [];
foreach (var file in Directory.EnumerateFiles(srcDir, "*.nub", SearchOption.AllDirectories)) foreach (var file in Directory.EnumerateFiles(srcDir, "*.nub", SearchOption.AllDirectories))
{ {
var content = File.ReadAllText(file); var content = File.ReadAllText(file);
var tokenizeResult = lexer.Tokenize(new SourceText(file, content)); var tokenizeResult = Lexer.Tokenize(new SourceText(file, content));
tokenizeResult.PrintAllDiagnostics(); tokenizeResult.PrintAllDiagnostics();
error = error || tokenizeResult.HasErrors; error = error || tokenizeResult.HasErrors;
var parseResult = parser.ParseModule(tokenizeResult.Value); var parseResult = Parser.ParseFile(tokenizeResult.Value);
parseResult.PrintAllDiagnostics(); parseResult.PrintAllDiagnostics();
error = error || parseResult.HasErrors; error = error || parseResult.HasErrors;
if (parseResult.Value != null) if (parseResult.Value != null)
{ {
files.Add(parseResult.Value); compilationUnits.Add(parseResult.Value);
} }
} }
var typeCheckResult = typeChecker.TypeCheck(files); if (error)
typeCheckResult.PrintAllDiagnostics(); {
error = error || typeCheckResult.HasErrors; return 1;
}
if (error) return 1; var definitionTable = new DefinitionTable(compilationUnits);
var generator = new QBEGenerator(); foreach (var compilationUnit in compilationUnits)
var result = generator.Generate(files); {
var typeCheckResult = TypeChecker.Check(compilationUnit, definitionTable);
typeCheckResult.PrintAllDiagnostics();
error = error || typeCheckResult.HasErrors;
Console.Out.Write(result); var result = QBEGenerator.Generate(compilationUnit, definitionTable);
return 0; File.WriteAllText($"bin-int/{Guid.NewGuid():N}.ssa", result);
}
return error ? 1 : 0;

View File

@@ -46,8 +46,6 @@ public static class ConsoleColors
public const string BrightCyan = "\e[96m"; public const string BrightCyan = "\e[96m";
public const string BrightWhite = "\e[97m"; public const string BrightWhite = "\e[97m";
private static readonly Lexer Lexer = new();
private static bool IsColorSupported() private static bool IsColorSupported()
{ {
var term = Environment.GetEnvironmentVariable("TERM"); var term = Environment.GetEnvironmentVariable("TERM");

View File

@@ -0,0 +1,43 @@
using Nub.Lang.Frontend.Parsing;
using Nub.Lang.Frontend.Parsing.Definitions;
namespace Nub.Lang.Frontend;
public class DefinitionTable
{
private readonly IEnumerable<CompilationUnit> _compilationUnits;
public DefinitionTable(IEnumerable<CompilationUnit> compilationUnits)
{
_compilationUnits = compilationUnits;
}
public Optional<IFuncSignature> LookupFunc(string @namespace, string name)
{
var definition = _compilationUnits
.Where(c => c.Namespace == @namespace)
.SelectMany(c => c.Definitions)
.OfType<IFuncSignature>()
.SingleOrDefault(f => f.Name == name);
return Optional.OfNullable(definition);
}
public Optional<StructDefinitionNode> LookupStruct(string @namespace, string name)
{
var definition = _compilationUnits
.Where(c => c.Namespace == @namespace)
.SelectMany(c => c.Definitions)
.OfType<StructDefinitionNode>()
.SingleOrDefault(f => f.Name == name);
return Optional.OfNullable(definition);
}
public IEnumerable<StructDefinitionNode> GetStructs()
{
return _compilationUnits
.SelectMany(c => c.Definitions)
.OfType<StructDefinitionNode>();
}
}

View File

@@ -10,52 +10,57 @@ using Nub.Lang.Frontend.Typing;
namespace Nub.Lang.Frontend.Generation; namespace Nub.Lang.Frontend.Generation;
public class QBEGenerator public static class QBEGenerator
{ {
private const string OutOfBoundsMessage = "Index is out of bounds\\n"; private const string OutOfBoundsMessage = "Index is out of bounds\\n";
private List<SourceFile> _sourceFiles = []; private static CompilationUnit _compilationUnit = null!;
private StringBuilder _builder = new(); private static DefinitionTable _definitionTable = null!;
private List<string> _strings = [];
private Stack<string> _breakLabels = [];
private Stack<string> _continueLabels = [];
private int _variableIndex;
private int _labelIndex;
private bool _codeIsReachable = true;
private Dictionary<AnonymousFuncNode, string> _anonymousFunctions = [];
private int _anonymousFuncIndex;
private SymbolTable _symbolTable = new([]);
public string Generate(List<SourceFile> sourceFiles) private static StringBuilder _builder = new();
private static List<string> _strings = [];
private static Stack<string> _breakLabels = [];
private static Stack<string> _continueLabels = [];
private static Queue<(AnonymousFuncNode Func, string Name)> _anonymousFunctions = [];
private static Stack<(string Name, string Pointer)> _variables = [];
private static Stack<int> _variableScopes = [];
private static int _variableIndex;
private static int _labelIndex;
private static int _anonymousFuncIndex;
private static bool _codeIsReachable = true;
public static string Generate(CompilationUnit compilationUnit, DefinitionTable definitionTable)
{ {
_sourceFiles = sourceFiles; _compilationUnit = compilationUnit;
_definitionTable = definitionTable;
_builder = new StringBuilder(); _builder = new StringBuilder();
_strings = []; _strings = [];
_breakLabels = []; _breakLabels = [];
_continueLabels = []; _continueLabels = [];
_anonymousFunctions = []; _anonymousFunctions = [];
_variables = [];
_variableScopes = [];
_variableIndex = 0; _variableIndex = 0;
_labelIndex = 0; _labelIndex = 0;
_anonymousFuncIndex = 0; _anonymousFuncIndex = 0;
_codeIsReachable = true; _codeIsReachable = true;
_symbolTable = new SymbolTable(_sourceFiles.SelectMany(f => f.Definitions.OfType<IFuncSignature>()));
foreach (var structDef in _sourceFiles.SelectMany(f => f.Definitions).OfType<StructDefinitionNode>()) foreach (var structDef in _definitionTable.GetStructs())
{ {
GenerateStructDefinition(structDef); GenerateStructDefinition(structDef);
_builder.AppendLine(); _builder.AppendLine();
} }
foreach (var funcDef in _sourceFiles.SelectMany(f => f.Definitions).OfType<LocalFuncDefinitionNode>()) foreach (var funcDef in _compilationUnit.Definitions.OfType<LocalFuncDefinitionNode>())
{ {
var symbol = _symbolTable.LookupFunc(funcDef.Namespace, funcDef.Name); GenerateFuncDefinition(FuncName(funcDef), funcDef.Parameters, funcDef.ReturnType, funcDef.Body, funcDef.Exported);
GenerateFuncDefinition(symbol.GeneratedName, funcDef.Parameters, funcDef.ReturnType, funcDef.Body, funcDef.Exported);
_builder.AppendLine(); _builder.AppendLine();
} }
foreach (var (func, name) in _anonymousFunctions) while (_anonymousFunctions.TryDequeue(out var anon))
{ {
GenerateFuncDefinition(name, func.Parameters, func.ReturnType, func.Body, false); GenerateFuncDefinition(anon.Name, anon.Func.Parameters, anon.Func.ReturnType, anon.Func.Body, false);
_builder.AppendLine(); _builder.AppendLine();
} }
@@ -70,6 +75,18 @@ public class QBEGenerator
return _builder.ToString(); return _builder.ToString();
} }
private static string FuncName(IFuncSignature func)
{
return func switch
{
ExternFuncDefinitionNode externFuncDefinition => $"${externFuncDefinition.CallName}",
LocalFuncDefinitionNode localFuncDefinition => localFuncDefinition.Exported
? $"${localFuncDefinition.Name}"
: $"${localFuncDefinition.Namespace}_{localFuncDefinition.Name}",
_ => throw new ArgumentOutOfRangeException(nameof(func))
};
}
private static string QBEStore(NubType type) private static string QBEStore(NubType type)
{ {
return $"store{type switch return $"store{type switch
@@ -154,7 +171,7 @@ public class QBEGenerator
}}"; }}";
} }
private int AlignmentOf(NubType type) private static int AlignmentOf(NubType type)
{ {
switch (type) switch (type)
{ {
@@ -183,7 +200,7 @@ public class QBEGenerator
} }
case NubStructType nubStructType: case NubStructType nubStructType:
{ {
var definition = LookupStructDefinition(nubStructType.Namespace, nubStructType.Name); var definition = _definitionTable.LookupStruct(nubStructType.Namespace, nubStructType.Name).GetValue();
return definition.Fields.Max(f => AlignmentOf(f.Type)); return definition.Fields.Max(f => AlignmentOf(f.Type));
} }
case NubPointerType: case NubPointerType:
@@ -203,12 +220,12 @@ public class QBEGenerator
} }
} }
private int AlignTo(int offset, int alignment) private static int AlignTo(int offset, int alignment)
{ {
return (offset + alignment - 1) & ~(alignment - 1); return (offset + alignment - 1) & ~(alignment - 1);
} }
private int SizeOf(NubType type) private static int SizeOf(NubType type)
{ {
switch (type) switch (type)
{ {
@@ -239,7 +256,7 @@ public class QBEGenerator
} }
case NubStructType nubStructType: case NubStructType nubStructType:
{ {
var definition = LookupStructDefinition(nubStructType.Namespace, nubStructType.Name); var definition = _definitionTable.LookupStruct(nubStructType.Namespace, nubStructType.Name).GetValue();
var size = 0; var size = 0;
var maxAlignment = 1; var maxAlignment = 1;
@@ -275,7 +292,7 @@ public class QBEGenerator
} }
} }
private bool IsLargeType(NubType type) private static bool IsLargeType(NubType type)
{ {
return type switch return type switch
{ {
@@ -289,10 +306,11 @@ public class QBEGenerator
}; };
} }
private void GenerateFuncDefinition(string name, List<FuncParameter> parameters, NubType returnType, BlockNode body, bool exported) private static void GenerateFuncDefinition(string name, List<FuncParameter> parameters, NubType returnType, BlockNode body, bool exported)
{ {
_symbolTable.Reset(); _variables.Clear();
_variableScopes.Clear();
if (exported) if (exported)
{ {
_builder.Append("export "); _builder.Append("export ");
@@ -358,6 +376,8 @@ public class QBEGenerator
_builder.AppendLine($"({string.Join(", ", parameterStrings)}) {{"); _builder.AppendLine($"({string.Join(", ", parameterStrings)}) {{");
_builder.AppendLine("@start"); _builder.AppendLine("@start");
List<(string Name, string Pointer)> parameterVars = [];
foreach (var parameter in parameters) foreach (var parameter in parameters)
{ {
var parameterName = "%" + parameter.Name; var parameterName = "%" + parameter.Name;
@@ -385,10 +405,10 @@ public class QBEGenerator
} }
} }
_symbolTable.DeclareVariable(parameter.Name, parameterName); parameterVars.Add((parameter.Name, parameterName));
} }
GenerateBlock(body); GenerateBlock(body, parameterVars);
if (body.Statements.LastOrDefault() is not ReturnNode) if (body.Statements.LastOrDefault() is not ReturnNode)
{ {
@@ -401,7 +421,7 @@ public class QBEGenerator
_builder.AppendLine("}"); _builder.AppendLine("}");
} }
private void GenerateStructDefinition(StructDefinitionNode structDefinition) private static void GenerateStructDefinition(StructDefinitionNode structDefinition)
{ {
_builder.Append($"type :{structDefinition.Namespace}_{structDefinition.Name} = {{ "); _builder.Append($"type :{structDefinition.Namespace}_{structDefinition.Name} = {{ ");
foreach (var structDefinitionField in structDefinition.Fields) foreach (var structDefinitionField in structDefinition.Fields)
@@ -435,7 +455,7 @@ public class QBEGenerator
_builder.AppendLine("}"); _builder.AppendLine("}");
} }
private void GenerateStatement(StatementNode statement) private static void GenerateStatement(StatementNode statement)
{ {
switch (statement) switch (statement)
{ {
@@ -477,7 +497,7 @@ public class QBEGenerator
} }
} }
private void GenerateArrayIndexAssignment(ArrayIndexAssignmentNode arrayIndexAssignment) private static void GenerateArrayIndexAssignment(ArrayIndexAssignmentNode arrayIndexAssignment)
{ {
var array = GenerateExpression(arrayIndexAssignment.ArrayIndexAccess.Array); var array = GenerateExpression(arrayIndexAssignment.ArrayIndexAccess.Array);
var index = GenerateExpression(arrayIndexAssignment.ArrayIndexAccess.Index); var index = GenerateExpression(arrayIndexAssignment.ArrayIndexAccess.Index);
@@ -521,38 +541,51 @@ public class QBEGenerator
} }
} }
private void GenerateBlock(BlockNode block) private static void GenerateBlock(BlockNode block, List<(string Name, string Pointer)>? variables = null)
{ {
_symbolTable.StartScope(); _variableScopes.Push(_variables.Count);
if (variables != null)
{
foreach (var variable in variables)
{
_variables.Push(variable);
}
}
foreach (var statement in block.Statements.Where(_ => _codeIsReachable)) foreach (var statement in block.Statements.Where(_ => _codeIsReachable))
{ {
GenerateStatement(statement); GenerateStatement(statement);
} }
_symbolTable.EndScope();
var count = _variableScopes.Pop();
while (_variableScopes.Count > count)
{
_variableScopes.Pop();
}
_codeIsReachable = true; _codeIsReachable = true;
} }
private void GenerateBreak() private static void GenerateBreak()
{ {
_builder.AppendLine($" jmp {_breakLabels.Peek()}"); _builder.AppendLine($" jmp {_breakLabels.Peek()}");
_codeIsReachable = false; _codeIsReachable = false;
} }
private void GenerateContinue() private static void GenerateContinue()
{ {
_builder.AppendLine($" jmp {_continueLabels.Peek()}"); _builder.AppendLine($" jmp {_continueLabels.Peek()}");
_codeIsReachable = false; _codeIsReachable = false;
} }
private void GenerateDereferenceAssignment(DereferenceAssignmentNode dereferenceAssignment) private static void GenerateDereferenceAssignment(DereferenceAssignmentNode dereferenceAssignment)
{ {
var pointer = GenerateExpression(dereferenceAssignment.Dereference.Expression); var pointer = GenerateExpression(dereferenceAssignment.Dereference.Expression);
var value = GenerateExpression(dereferenceAssignment.Value); var value = GenerateExpression(dereferenceAssignment.Value);
GenerateCopy(dereferenceAssignment.Value.Type, value, pointer); GenerateCopy(dereferenceAssignment.Value.Type, value, pointer);
} }
private void GenerateIf(IfNode ifStatement) private static void GenerateIf(IfNode ifStatement)
{ {
var trueLabel = GenLabelName(); var trueLabel = GenLabelName();
var falseLabel = GenLabelName(); var falseLabel = GenLabelName();
@@ -568,20 +601,20 @@ public class QBEGenerator
{ {
ifStatement.Else.Value.Match ifStatement.Else.Value.Match
( (
GenerateIf, elseIfNode => GenerateIf(elseIfNode),
GenerateBlock elseNode => GenerateBlock(elseNode)
); );
} }
_builder.AppendLine(endLabel); _builder.AppendLine(endLabel);
} }
private void GenerateMemberAssignment(MemberAssignmentNode memberAssignment) private static void GenerateMemberAssignment(MemberAssignmentNode memberAssignment)
{ {
var structType = memberAssignment.MemberAccess.Expression.Type as NubStructType; var structType = memberAssignment.MemberAccess.Expression.Type as NubStructType;
Debug.Assert(structType != null); Debug.Assert(structType != null);
var structDefinition = LookupStructDefinition(structType.Namespace, structType.Name); var structDefinition = _definitionTable.LookupStruct(structType.Namespace, structType.Name).GetValue();
var offset = LookupMemberOffset(structDefinition, memberAssignment.MemberAccess.Member); var offset = LookupMemberOffset(structDefinition, memberAssignment.MemberAccess.Member);
var item = GenerateExpression(memberAssignment.MemberAccess.Expression); var item = GenerateExpression(memberAssignment.MemberAccess.Expression);
@@ -594,7 +627,7 @@ public class QBEGenerator
GenerateCopy(memberAssignment.Value.Type, value, pointer); GenerateCopy(memberAssignment.Value.Type, value, pointer);
} }
private void GenerateReturn(ReturnNode @return) private static void GenerateReturn(ReturnNode @return)
{ {
if (@return.Value.HasValue) if (@return.Value.HasValue)
{ {
@@ -607,7 +640,7 @@ public class QBEGenerator
} }
} }
private void GenerateVariableDeclaration(VariableDeclarationNode variableDeclaration) private static void GenerateVariableDeclaration(VariableDeclarationNode variableDeclaration)
{ {
var type = variableDeclaration.ExplicitType.Value ?? variableDeclaration.Value.Value!.Type; var type = variableDeclaration.ExplicitType.Value ?? variableDeclaration.Value.Value!.Type;
var pointer = GenVarName(); var pointer = GenVarName();
@@ -621,20 +654,20 @@ public class QBEGenerator
else else
{ {
var pointerName = GenVarName(); var pointerName = GenVarName();
_symbolTable.DeclareVariable(variableDeclaration.Name, pointerName); _variables.Push((variableDeclaration.Name, pointerName));
} }
_symbolTable.DeclareVariable(variableDeclaration.Name, pointer); _variables.Push((variableDeclaration.Name, pointer));
} }
private void GenerateVariableAssignment(VariableAssignmentNode variableAssignment) private static void GenerateVariableAssignment(VariableAssignmentNode variableAssignment)
{ {
var value = GenerateExpression(variableAssignment.Value); var value = GenerateExpression(variableAssignment.Value);
var variable = _symbolTable.LookupVariable(variableAssignment.Identifier.Name); var variable = _variables.Single(x => x.Name == variableAssignment.Identifier.Name);
GenerateCopy(variableAssignment.Value.Type, value, variable.Pointer); GenerateCopy(variableAssignment.Value.Type, value, variable.Pointer);
} }
private void GenerateWhile(WhileNode whileStatement) private static void GenerateWhile(WhileNode whileStatement)
{ {
var conditionLabel = GenLabelName(); var conditionLabel = GenLabelName();
var iterationLabel = GenLabelName(); var iterationLabel = GenLabelName();
@@ -655,7 +688,7 @@ public class QBEGenerator
_breakLabels.Pop(); _breakLabels.Pop();
} }
private string GenerateExpression(ExpressionNode expression) private static string GenerateExpression(ExpressionNode expression)
{ {
return expression switch return expression switch
{ {
@@ -676,17 +709,17 @@ public class QBEGenerator
}; };
} }
private string GenerateAnonymousFunc(AnonymousFuncNode anonymousFunc) private static string GenerateAnonymousFunc(AnonymousFuncNode anonymousFunc)
{ {
var name = $"$anon_func{++_anonymousFuncIndex}"; var name = $"$anon_func{++_anonymousFuncIndex}";
_anonymousFunctions[anonymousFunc] = name; _anonymousFunctions.Enqueue((anonymousFunc, name));
var pointer = GenVarName(); var pointer = GenVarName();
_builder.AppendLine($" {pointer} =l alloc8 8"); _builder.AppendLine($" {pointer} =l alloc8 8");
_builder.AppendLine($" storel {name}, {pointer}"); _builder.AppendLine($" storel {name}, {pointer}");
return pointer; return pointer;
} }
private string GenerateArrayIndexPointer(ArrayIndexAccessNode arrayIndexAccess) private static string GenerateArrayIndexPointer(ArrayIndexAccessNode arrayIndexAccess)
{ {
var array = GenerateExpression(arrayIndexAccess.Array); var array = GenerateExpression(arrayIndexAccess.Array);
var index = GenerateExpression(arrayIndexAccess.Index); var index = GenerateExpression(arrayIndexAccess.Index);
@@ -721,7 +754,7 @@ public class QBEGenerator
} }
} }
private string GenerateArrayAccessIndex(ArrayIndexAccessNode arrayIndexAccess) private static string GenerateArrayAccessIndex(ArrayIndexAccessNode arrayIndexAccess)
{ {
var pointerName = GenerateArrayIndexPointer(arrayIndexAccess); var pointerName = GenerateArrayIndexPointer(arrayIndexAccess);
@@ -737,7 +770,7 @@ public class QBEGenerator
} }
} }
private void GenerateArrayBoundsCheck(string array, string index) private static void GenerateArrayBoundsCheck(string array, string index)
{ {
var countName = GenVarName(); var countName = GenVarName();
_builder.AppendLine($" {countName} =l loadl {array}"); _builder.AppendLine($" {countName} =l loadl {array}");
@@ -761,7 +794,7 @@ public class QBEGenerator
_builder.AppendLine(notOobLabel); _builder.AppendLine(notOobLabel);
} }
private string GenerateArrayInitializer(ArrayInitializerNode arrayInitializer) private static string GenerateArrayInitializer(ArrayInitializerNode arrayInitializer)
{ {
var capacity = GenerateExpression(arrayInitializer.Capacity); var capacity = GenerateExpression(arrayInitializer.Capacity);
var elementSize = SizeOf(arrayInitializer.ElementType); var elementSize = SizeOf(arrayInitializer.ElementType);
@@ -780,7 +813,7 @@ public class QBEGenerator
return outputName; return outputName;
} }
private string GenerateDereference(DereferenceNode dereference) private static string GenerateDereference(DereferenceNode dereference)
{ {
var result = GenerateExpression(dereference.Expression); var result = GenerateExpression(dereference.Expression);
var outputName = GenVarName(); var outputName = GenVarName();
@@ -788,7 +821,7 @@ public class QBEGenerator
return outputName; return outputName;
} }
private string GenerateAddressOf(AddressOfNode addressOf) private static string GenerateAddressOf(AddressOfNode addressOf)
{ {
switch (addressOf.Expression) switch (addressOf.Expression)
{ {
@@ -797,11 +830,11 @@ public class QBEGenerator
case DereferenceNode dereference: case DereferenceNode dereference:
return GenerateExpression(dereference.Expression); return GenerateExpression(dereference.Expression);
case IdentifierNode identifier: case IdentifierNode identifier:
if (identifier.Namespace != null) if (identifier.Namespace.HasValue)
{ {
throw new NotSupportedException("There is nothing to address in another namespace"); throw new NotSupportedException("There is nothing to address in another namespace");
} }
return _symbolTable.LookupVariable(identifier.Name).Pointer; return _variables.Single(x => x.Name == identifier.Name).Pointer;
case MemberAccessNode memberAccess: case MemberAccessNode memberAccess:
return GenerateMemberAccessPointer(memberAccess); return GenerateMemberAccessPointer(memberAccess);
default: default:
@@ -809,7 +842,7 @@ public class QBEGenerator
} }
} }
private string GenerateBinaryExpression(BinaryExpressionNode binaryExpression) private static string GenerateBinaryExpression(BinaryExpressionNode binaryExpression)
{ {
var left = GenerateExpression(binaryExpression.Left); var left = GenerateExpression(binaryExpression.Left);
var right = GenerateExpression(binaryExpression.Right); var right = GenerateExpression(binaryExpression.Right);
@@ -1035,25 +1068,26 @@ public class QBEGenerator
throw new NotSupportedException($"Binary operator {binaryExpression.Operator} for types {binaryExpression.Left.Type} and {binaryExpression.Right.Type} not supported"); throw new NotSupportedException($"Binary operator {binaryExpression.Operator} for types {binaryExpression.Left.Type} and {binaryExpression.Right.Type} not supported");
} }
private string GenerateIdentifier(IdentifierNode identifier) private static string GenerateIdentifier(IdentifierNode identifier)
{ {
var symbol = _symbolTable.Lookup(identifier.Namespace, identifier.Name); if (_definitionTable.LookupFunc(identifier.Namespace.Or(_compilationUnit.Namespace), identifier.Name).TryGetValue(out var func))
switch (symbol)
{ {
case SymbolTable.Func func: var pointer = GenVarName();
var pointer = GenVarName(); _builder.AppendLine($" {pointer} =l alloc8 8");
_builder.AppendLine($" {pointer} =l alloc8 8"); _builder.AppendLine($" storel {FuncName(func)}, {pointer}");
_builder.AppendLine($" storel {func.GeneratedName}, {pointer}"); return pointer;
return pointer;
case SymbolTable.Variable variable:
return GenerateDereference(identifier.Type, variable.Pointer);
default:
throw new ArgumentOutOfRangeException(nameof(symbol));
} }
if (!identifier.Namespace.HasValue)
{
var variable = _variables.Single(v => v.Name == identifier.Name);
return GenerateDereference(identifier.Type, variable.Pointer);
}
throw new UnreachableException();
} }
private string GenerateLiteral(LiteralNode literal) private static string GenerateLiteral(LiteralNode literal)
{ {
if (literal.Type.IsInteger) if (literal.Type.IsInteger)
{ {
@@ -1095,9 +1129,9 @@ public class QBEGenerator
} }
private string GenerateStructInitializer(StructInitializerNode structInitializer) private static string GenerateStructInitializer(StructInitializerNode structInitializer)
{ {
var structDefinition = LookupStructDefinition(structInitializer.StructType.Namespace, structInitializer.StructType.Name); var structDefinition = _definitionTable.LookupStruct(structInitializer.StructType.Namespace, structInitializer.StructType.Name).GetValue();
var structVar = GenVarName(); var structVar = GenVarName();
var size = SizeOf(structInitializer.StructType); var size = SizeOf(structInitializer.StructType);
@@ -1130,7 +1164,7 @@ public class QBEGenerator
return structVar; return structVar;
} }
private string GenerateUnaryExpression(UnaryExpressionNode unaryExpression) private static string GenerateUnaryExpression(UnaryExpressionNode unaryExpression)
{ {
var operand = GenerateExpression(unaryExpression.Operand); var operand = GenerateExpression(unaryExpression.Operand);
var outputName = GenVarName(); var outputName = GenVarName();
@@ -1177,7 +1211,7 @@ public class QBEGenerator
throw new NotSupportedException($"Unary operator {unaryExpression.Operator} for type {unaryExpression.Operand.Type} not supported"); throw new NotSupportedException($"Unary operator {unaryExpression.Operator} for type {unaryExpression.Operand.Type} not supported");
} }
private string GenerateMemberAccessPointer(MemberAccessNode memberAccess) private static string GenerateMemberAccessPointer(MemberAccessNode memberAccess)
{ {
var item = GenerateExpression(memberAccess.Expression); var item = GenerateExpression(memberAccess.Expression);
switch (memberAccess.Expression.Type) switch (memberAccess.Expression.Type)
@@ -1193,7 +1227,7 @@ public class QBEGenerator
} }
case NubStructType structType: case NubStructType structType:
{ {
var structDefinition = LookupStructDefinition(structType.Namespace, structType.Name); var structDefinition = _definitionTable.LookupStruct(structType.Namespace, structType.Name).GetValue();
var offset = LookupMemberOffset(structDefinition, memberAccess.Member); var offset = LookupMemberOffset(structDefinition, memberAccess.Member);
var offsetName = GenVarName(); var offsetName = GenVarName();
@@ -1207,7 +1241,7 @@ public class QBEGenerator
} }
} }
private string GenerateMemberAccess(MemberAccessNode memberAccess) private static string GenerateMemberAccess(MemberAccessNode memberAccess)
{ {
var pointer = GenerateMemberAccessPointer(memberAccess); var pointer = GenerateMemberAccessPointer(memberAccess);
@@ -1223,7 +1257,7 @@ public class QBEGenerator
} }
} }
private string GenerateFixedArrayInitializer(FixedArrayInitializerNode fixedArrayInitializer) private static string GenerateFixedArrayInitializer(FixedArrayInitializerNode fixedArrayInitializer)
{ {
var totalSize = SizeOf(fixedArrayInitializer.Type); var totalSize = SizeOf(fixedArrayInitializer.Type);
var outputName = GenVarName(); var outputName = GenVarName();
@@ -1240,7 +1274,7 @@ public class QBEGenerator
return outputName; return outputName;
} }
private string GenerateFuncCall(FuncCallNode funcCall) private static string GenerateFuncCall(FuncCallNode funcCall)
{ {
var funcType = (NubFuncType)funcCall.Expression.Type; var funcType = (NubFuncType)funcCall.Expression.Type;
@@ -1279,9 +1313,9 @@ public class QBEGenerator
} }
string funcTarget; string funcTarget;
if (funcCall.Expression is IdentifierNode identifier && _symbolTable.Lookup(identifier.Namespace, identifier.Name) is SymbolTable.Func func) if (funcCall.Expression is IdentifierNode identifier && _definitionTable.LookupFunc(identifier.Namespace.Or(_compilationUnit.Namespace), identifier.Name).TryGetValue(out var func))
{ {
funcTarget = func.GeneratedName; funcTarget = FuncName(func);
} }
else else
{ {
@@ -1304,7 +1338,7 @@ public class QBEGenerator
} }
} }
private void GenerateCopy(NubType type, string value, string destinationPointer) private static void GenerateCopy(NubType type, string value, string destinationPointer)
{ {
if (IsLargeType(type)) if (IsLargeType(type))
{ {
@@ -1316,7 +1350,7 @@ public class QBEGenerator
} }
} }
private string GenerateDereference(NubType type, string pointer) private static string GenerateDereference(NubType type, string pointer)
{ {
if (IsLargeType(type)) if (IsLargeType(type))
{ {
@@ -1330,26 +1364,17 @@ public class QBEGenerator
} }
} }
private string GenVarName() private static string GenVarName()
{ {
return $"%v{++_variableIndex}"; return $"%v{++_variableIndex}";
} }
private string GenLabelName() private static string GenLabelName()
{ {
return $"@l{++_labelIndex}"; return $"@l{++_labelIndex}";
} }
private StructDefinitionNode LookupStructDefinition(string @namespace, string name) private static int LookupMemberOffset(StructDefinitionNode structDefinition, string member)
{
return _sourceFiles
.Where(f => f.Namespace == @namespace)
.SelectMany(f => f.Definitions)
.OfType<StructDefinitionNode>()
.Single(s => s.Name == name);
}
private int LookupMemberOffset(StructDefinitionNode structDefinition, string member)
{ {
var offset = 0; var offset = 0;

View File

@@ -1,109 +0,0 @@
using Nub.Lang.Frontend.Parsing.Definitions;
namespace Nub.Lang.Frontend.Generation;
public class SymbolTable
{
private readonly List<Func> _functions = [];
private readonly Stack<Variable> _variables = [];
private readonly Stack<int> _scopes = [];
public SymbolTable(IEnumerable<IFuncSignature> functions)
{
foreach (var func in functions)
{
string name;
switch (func)
{
case ExternFuncDefinitionNode externFuncDefinitionNode:
{
name = "$" + externFuncDefinitionNode.CallName;
break;
}
case LocalFuncDefinitionNode localFuncDefinitionNode:
{
if (localFuncDefinitionNode.Exported)
{
name = "$" + localFuncDefinitionNode.Name;
}
else
{
name = "$" + localFuncDefinitionNode.Namespace + "_" + localFuncDefinitionNode.Name;
}
break;
}
default:
{
throw new ArgumentOutOfRangeException(nameof(func));
}
}
_functions.Add(new Func(func.Namespace, func.Name, name));
}
}
public void Reset()
{
_variables.Clear();
}
public void StartScope()
{
_scopes.Push(_variables.Count);
}
public void EndScope()
{
var count = _scopes.Pop();
while (count > _variables.Count)
{
_variables.Pop();
}
}
public Symbol Lookup(string? @namespace, string name)
{
if (@namespace == null)
{
return LookupVariable(name);
}
return LookupFunc(@namespace, name);
}
public Func LookupFunc(string @namespace, string name)
{
return _functions.Single(x => x.Name == name && x.Namespace == @namespace);
}
public Variable LookupVariable(string name)
{
return _variables.Single(x => x.Name == name);
}
public void DeclareVariable(string name, string pointer)
{
_variables.Push(new Variable(name, pointer));
}
public abstract class Symbol(string name)
{
public string Name { get; } = name;
}
public class Variable(string name, string pointer) : Symbol(name)
{
public string Pointer { get; set; } = pointer;
}
public class Func(string @namespace, string name, string generatedName) : Symbol(name)
{
public string Namespace { get; } = @namespace;
public string GeneratedName { get; } = generatedName;
}
public class Struct(string @namespace, string name) : Symbol(name)
{
public string Namespace { get; } = @namespace;
public string GeneratedName => $"{Namespace}_{Name}";
}
}

View File

@@ -2,7 +2,7 @@
namespace Nub.Lang.Frontend.Lexing; namespace Nub.Lang.Frontend.Lexing;
public class Lexer public static class Lexer
{ {
private static readonly Dictionary<string, Symbol> Keywords = new() private static readonly Dictionary<string, Symbol> Keywords = new()
{ {
@@ -59,10 +59,10 @@ public class Lexer
['&'] = Symbol.Ampersand, ['&'] = Symbol.Ampersand,
}; };
private SourceText _sourceText; private static SourceText _sourceText;
private int _index; private static int _index;
public DiagnosticsResult<List<Token>> Tokenize(SourceText sourceText) public static DiagnosticsResult<List<Token>> Tokenize(SourceText sourceText)
{ {
_sourceText = sourceText; _sourceText = sourceText;
_index = 0; _index = 0;
@@ -76,7 +76,7 @@ public class Lexer
return new DiagnosticsResult<List<Token>>([], tokens); return new DiagnosticsResult<List<Token>>([], tokens);
} }
private void ConsumeWhitespace() private static void ConsumeWhitespace()
{ {
while (Peek().TryGetValue(out var character) && char.IsWhiteSpace(character)) while (Peek().TryGetValue(out var character) && char.IsWhiteSpace(character))
{ {
@@ -84,7 +84,7 @@ public class Lexer
} }
} }
private Optional<Token> ParseToken() private static Optional<Token> ParseToken()
{ {
ConsumeWhitespace(); ConsumeWhitespace();
var startIndex = _index; var startIndex = _index;
@@ -236,7 +236,7 @@ public class Lexer
throw new Exception($"Unknown character {current}"); throw new Exception($"Unknown character {current}");
} }
private SourceLocation CreateLocation(int index) private static SourceLocation CreateLocation(int index)
{ {
var line = 1; var line = 1;
var column = 1; var column = 1;
@@ -256,12 +256,12 @@ public class Lexer
return new SourceLocation(line, column); return new SourceLocation(line, column);
} }
private SourceSpan CreateSpan(int startIndex) private static SourceSpan CreateSpan(int startIndex)
{ {
return new SourceSpan(_sourceText, CreateLocation(startIndex), CreateLocation(_index)); return new SourceSpan(_sourceText, CreateLocation(startIndex), CreateLocation(_index));
} }
private Optional<char> Peek(int offset = 0) private static Optional<char> Peek(int offset = 0)
{ {
if (_index + offset < _sourceText.Content.Length) if (_index + offset < _sourceText.Content.Length)
{ {
@@ -271,7 +271,7 @@ public class Lexer
return Optional<char>.Empty(); return Optional<char>.Empty();
} }
private void Next() private static void Next()
{ {
_index++; _index++;
} }

View File

@@ -2,7 +2,7 @@
namespace Nub.Lang.Frontend.Parsing; namespace Nub.Lang.Frontend.Parsing;
public class SourceFile(string @namespace, List<DefinitionNode> definitions) public class CompilationUnit(string @namespace, List<DefinitionNode> definitions)
{ {
public string Namespace { get; } = @namespace; public string Namespace { get; } = @namespace;
public List<DefinitionNode> Definitions { get; } = definitions; public List<DefinitionNode> Definitions { get; } = definitions;

View File

@@ -2,7 +2,8 @@
namespace Nub.Lang.Frontend.Parsing.Definitions; namespace Nub.Lang.Frontend.Parsing.Definitions;
public abstract class DefinitionNode(IReadOnlyList<Token> tokens, Optional<string> documentation) : Node(tokens) public abstract class DefinitionNode(IReadOnlyList<Token> tokens, Optional<string> documentation, string @namespace) : Node(tokens)
{ {
public Optional<string> Documentation { get; set; } = documentation; public Optional<string> Documentation { get; } = documentation;
public string Namespace { get; set; } = @namespace;
} }

View File

@@ -15,17 +15,15 @@ public class FuncParameter(string name, NubType type)
public interface IFuncSignature public interface IFuncSignature
{ {
public string Name { get; } public string Name { get; }
public string Namespace { get; }
public List<FuncParameter> Parameters { get; } public List<FuncParameter> Parameters { get; }
public NubType ReturnType { get; } public NubType ReturnType { get; }
public string ToString() => $"{Name}({string.Join(", ", Parameters.Select(p => p.ToString()))}){": " + ReturnType}"; public string ToString() => $"{Name}({string.Join(", ", Parameters.Select(p => p.ToString()))}){": " + ReturnType}";
} }
public class LocalFuncDefinitionNode(IReadOnlyList<Token> tokens, Optional<string> documentation, string name, string @namespace, List<FuncParameter> parameters, BlockNode body, NubType returnType, bool exported) : DefinitionNode(tokens, documentation), IFuncSignature public class LocalFuncDefinitionNode(IReadOnlyList<Token> tokens, Optional<string> documentation, string @namespace, string name, List<FuncParameter> parameters, BlockNode body, NubType returnType, bool exported) : DefinitionNode(tokens, documentation, @namespace), IFuncSignature
{ {
public string Name { get; } = name; public string Name { get; } = name;
public string Namespace { get; } = @namespace;
public List<FuncParameter> Parameters { get; } = parameters; public List<FuncParameter> Parameters { get; } = parameters;
public BlockNode Body { get; } = body; public BlockNode Body { get; } = body;
public NubType ReturnType { get; } = returnType; public NubType ReturnType { get; } = returnType;
@@ -34,10 +32,9 @@ public class LocalFuncDefinitionNode(IReadOnlyList<Token> tokens, Optional<strin
public override string ToString() => $"{Name}({string.Join(", ", Parameters.Select(p => p.ToString()))}){": " + ReturnType}"; public override string ToString() => $"{Name}({string.Join(", ", Parameters.Select(p => p.ToString()))}){": " + ReturnType}";
} }
public class ExternFuncDefinitionNode(IReadOnlyList<Token> tokens, Optional<string> documentation, string name, string @namespace, string callName, List<FuncParameter> parameters, NubType returnType) : DefinitionNode(tokens, documentation), IFuncSignature public class ExternFuncDefinitionNode(IReadOnlyList<Token> tokens, Optional<string> documentation, string @namespace, string name, string callName, List<FuncParameter> parameters, NubType returnType) : DefinitionNode(tokens, documentation, @namespace), IFuncSignature
{ {
public string Name { get; } = name; public string Name { get; } = name;
public string Namespace { get; } = @namespace;
public string CallName { get; } = callName; public string CallName { get; } = callName;
public List<FuncParameter> Parameters { get; } = parameters; public List<FuncParameter> Parameters { get; } = parameters;
public NubType ReturnType { get; } = returnType; public NubType ReturnType { get; } = returnType;

View File

@@ -11,9 +11,8 @@ public class StructField(string name, NubType type, Optional<ExpressionNode> val
public Optional<ExpressionNode> Value { get; } = value; public Optional<ExpressionNode> Value { get; } = value;
} }
public class StructDefinitionNode(IReadOnlyList<Token> tokens, Optional<string> documentation, string name, string @namespace, List<StructField> fields) : DefinitionNode(tokens, documentation) public class StructDefinitionNode(IReadOnlyList<Token> tokens, Optional<string> documentation, string @namespace, string name, List<StructField> fields) : DefinitionNode(tokens, documentation, @namespace)
{ {
public string Name { get; } = name; public string Name { get; } = name;
public string Namespace { get; } = @namespace;
public List<StructField> Fields { get; } = fields; public List<StructField> Fields { get; } = fields;
} }

View File

@@ -2,10 +2,10 @@
namespace Nub.Lang.Frontend.Parsing.Expressions; namespace Nub.Lang.Frontend.Parsing.Expressions;
public class IdentifierNode(IReadOnlyList<Token> tokens, string? @namespace, string name) : LValueNode(tokens) public class IdentifierNode(IReadOnlyList<Token> tokens, Optional<string> @namespace, string name) : LValueNode(tokens)
{ {
public string? Namespace { get; } = @namespace; public Optional<string> Namespace { get; } = @namespace;
public string Name { get; } = name; public string Name { get; } = name;
public override string ToString() => Name; public override string ToString() => Namespace.HasValue ? $"{Namespace.Value}::{Name}" : Name;
} }

View File

@@ -9,33 +9,43 @@ using Nub.Lang.Frontend.Typing;
namespace Nub.Lang.Frontend.Parsing; namespace Nub.Lang.Frontend.Parsing;
public class Parser public static class Parser
{ {
private List<Diagnostic> _diagnostics = []; private static string _namespace = null!;
private List<Token> _tokens = []; private static List<Diagnostic> _diagnostics = [];
private int _index; private static List<Token> _tokens = [];
private string _namespace = string.Empty; private static int _index;
public DiagnosticsResult<SourceFile?> ParseModule(List<Token> tokens) public static DiagnosticsResult<CompilationUnit?> ParseFile(List<Token> tokens)
{ {
_diagnostics = [];
_tokens = tokens; _tokens = tokens;
_namespace = null!;
_diagnostics = [];
_index = 0; _index = 0;
_namespace = string.Empty;
try try
{ {
ExpectSymbol(Symbol.Namespace); ExpectSymbol(Symbol.Namespace);
_namespace = ExpectIdentifier().Value; var @namespace = ExpectIdentifier();
_namespace = @namespace.Value;
}
catch (ParseException ex)
{
_diagnostics.Add(ex.Diagnostic);
return new DiagnosticsResult<CompilationUnit?>(_diagnostics, null);
}
try
{
List<DefinitionNode> definitions = []; List<DefinitionNode> definitions = [];
while (Peek().HasValue) while (Peek().HasValue)
{ {
definitions.Add(ParseDefinition()); var definition = ParseDefinition();
definitions.Add(definition);
} }
return new DiagnosticsResult<SourceFile?>(_diagnostics, new SourceFile(_namespace, definitions)); return new DiagnosticsResult<CompilationUnit?>(_diagnostics, new CompilationUnit(_namespace, definitions));
} }
catch (ParseException ex) catch (ParseException ex)
{ {
@@ -43,10 +53,10 @@ public class Parser
RecoverToNextDefinition(); RecoverToNextDefinition();
} }
return new DiagnosticsResult<SourceFile?>(_diagnostics, null); return new DiagnosticsResult<CompilationUnit?>(_diagnostics, null);
} }
private DefinitionNode ParseDefinition() private static DefinitionNode ParseDefinition()
{ {
var startIndex = _index; var startIndex = _index;
List<ModifierToken> modifiers = []; List<ModifierToken> modifiers = [];
@@ -78,7 +88,7 @@ public class Parser
}; };
} }
private DefinitionNode ParseFuncDefinition(int startIndex, List<ModifierToken> modifiers, Optional<string> documentation) private static DefinitionNode ParseFuncDefinition(int startIndex, List<ModifierToken> modifiers, Optional<string> documentation)
{ {
var name = ExpectIdentifier(); var name = ExpectIdentifier();
List<FuncParameter> parameters = []; List<FuncParameter> parameters = [];
@@ -120,7 +130,7 @@ public class Parser
callName = ExpectIdentifier().Value; callName = ExpectIdentifier().Value;
} }
return new ExternFuncDefinitionNode(GetTokensForNode(startIndex), documentation, name.Value, _namespace, callName, parameters, returnType); return new ExternFuncDefinitionNode(GetTokensForNode(startIndex), documentation, _namespace, name.Value, callName, parameters, returnType);
} }
var body = ParseBlock(); var body = ParseBlock();
@@ -135,10 +145,10 @@ public class Parser
.Build()); .Build());
} }
return new LocalFuncDefinitionNode(GetTokensForNode(startIndex), documentation, name.Value, _namespace, parameters, body, returnType, exported); return new LocalFuncDefinitionNode(GetTokensForNode(startIndex), documentation, _namespace, name.Value, parameters, body, returnType, exported);
} }
private StructDefinitionNode ParseStruct(int startIndex, List<ModifierToken> _, Optional<string> documentation) private static StructDefinitionNode ParseStruct(int startIndex, List<ModifierToken> _, Optional<string> documentation)
{ {
var name = ExpectIdentifier().Value; var name = ExpectIdentifier().Value;
@@ -162,10 +172,10 @@ public class Parser
variables.Add(new StructField(variableName, variableType, variableValue)); variables.Add(new StructField(variableName, variableType, variableValue));
} }
return new StructDefinitionNode(GetTokensForNode(startIndex), documentation, name, _namespace, variables); return new StructDefinitionNode(GetTokensForNode(startIndex), documentation, _namespace, name, variables);
} }
private FuncParameter ParseFuncParameter() private static FuncParameter ParseFuncParameter()
{ {
var name = ExpectIdentifier(); var name = ExpectIdentifier();
ExpectSymbol(Symbol.Colon); ExpectSymbol(Symbol.Colon);
@@ -174,7 +184,7 @@ public class Parser
return new FuncParameter(name.Value, type); return new FuncParameter(name.Value, type);
} }
private StatementNode ParseStatement() private static StatementNode ParseStatement()
{ {
var startIndex = _index; var startIndex = _index;
if (!Peek().TryGetValue(out var token)) if (!Peek().TryGetValue(out var token))
@@ -207,7 +217,7 @@ public class Parser
return ParseStatementExpression(startIndex); return ParseStatementExpression(startIndex);
} }
private StatementNode ParseStatementExpression(int startIndex) private static StatementNode ParseStatementExpression(int startIndex)
{ {
var expr = ParseExpression(); var expr = ParseExpression();
@@ -256,7 +266,7 @@ public class Parser
return new StatementExpressionNode(GetTokensForNode(startIndex), expr); return new StatementExpressionNode(GetTokensForNode(startIndex), expr);
} }
private VariableDeclarationNode ParseVariableDeclaration(int startIndex) private static VariableDeclarationNode ParseVariableDeclaration(int startIndex)
{ {
ExpectSymbol(Symbol.Let); ExpectSymbol(Symbol.Let);
var name = ExpectIdentifier().Value; var name = ExpectIdentifier().Value;
@@ -275,20 +285,20 @@ public class Parser
return new VariableDeclarationNode(GetTokensForNode(startIndex), name, type, value); return new VariableDeclarationNode(GetTokensForNode(startIndex), name, type, value);
} }
private StatementNode ParseBreak(int startIndex) private static StatementNode ParseBreak(int startIndex)
{ {
ExpectSymbol(Symbol.Break); ExpectSymbol(Symbol.Break);
Next(); Next();
return new BreakNode(GetTokensForNode(startIndex)); return new BreakNode(GetTokensForNode(startIndex));
} }
private StatementNode ParseContinue(int startIndex) private static StatementNode ParseContinue(int startIndex)
{ {
ExpectSymbol(Symbol.Continue); ExpectSymbol(Symbol.Continue);
return new ContinueNode(GetTokensForNode(startIndex)); return new ContinueNode(GetTokensForNode(startIndex));
} }
private ReturnNode ParseReturn(int startIndex) private static ReturnNode ParseReturn(int startIndex)
{ {
ExpectSymbol(Symbol.Return); ExpectSymbol(Symbol.Return);
var value = Optional<ExpressionNode>.Empty(); var value = Optional<ExpressionNode>.Empty();
@@ -300,7 +310,7 @@ public class Parser
return new ReturnNode(GetTokensForNode(startIndex), value); return new ReturnNode(GetTokensForNode(startIndex), value);
} }
private IfNode ParseIf(int startIndex) private static IfNode ParseIf(int startIndex)
{ {
ExpectSymbol(Symbol.If); ExpectSymbol(Symbol.If);
var condition = ParseExpression(); var condition = ParseExpression();
@@ -318,7 +328,7 @@ public class Parser
return new IfNode(GetTokensForNode(startIndex), condition, body, elseStatement); return new IfNode(GetTokensForNode(startIndex), condition, body, elseStatement);
} }
private WhileNode ParseWhile(int startIndex) private static WhileNode ParseWhile(int startIndex)
{ {
ExpectSymbol(Symbol.While); ExpectSymbol(Symbol.While);
var condition = ParseExpression(); var condition = ParseExpression();
@@ -326,7 +336,7 @@ public class Parser
return new WhileNode(GetTokensForNode(startIndex), condition, body); return new WhileNode(GetTokensForNode(startIndex), condition, body);
} }
private ExpressionNode ParseExpression(int precedence = 0) private static ExpressionNode ParseExpression(int precedence = 0)
{ {
var startIndex = _index; var startIndex = _index;
var left = ParsePrimaryExpression(); var left = ParsePrimaryExpression();
@@ -407,7 +417,7 @@ public class Parser
} }
} }
private ExpressionNode ParsePrimaryExpression() private static ExpressionNode ParsePrimaryExpression()
{ {
var startIndex = _index; var startIndex = _index;
ExpressionNode expr; ExpressionNode expr;
@@ -422,7 +432,7 @@ public class Parser
} }
case IdentifierToken identifier: case IdentifierToken identifier:
{ {
string? @namespace = null; var @namespace = Optional<string>.Empty();
var name = identifier.Value; var name = identifier.Value;
if (TryExpectSymbol(Symbol.DoubleColon)) if (TryExpectSymbol(Symbol.DoubleColon))
{ {
@@ -577,7 +587,7 @@ public class Parser
return ParsePostfixOperators(startIndex, expr); return ParsePostfixOperators(startIndex, expr);
} }
private ExpressionNode ParsePostfixOperators(int startIndex, ExpressionNode expr) private static ExpressionNode ParsePostfixOperators(int startIndex, ExpressionNode expr)
{ {
while (true) while (true)
{ {
@@ -628,7 +638,7 @@ public class Parser
return expr; return expr;
} }
private BlockNode ParseBlock() private static BlockNode ParseBlock()
{ {
var startIndex = _index; var startIndex = _index;
ExpectSymbol(Symbol.OpenBrace); ExpectSymbol(Symbol.OpenBrace);
@@ -649,21 +659,21 @@ public class Parser
return new BlockNode(GetTokensForNode(startIndex), statements); return new BlockNode(GetTokensForNode(startIndex), statements);
} }
private NubType ParseType() private static NubType ParseType()
{ {
if (TryExpectIdentifier(out var name)) if (TryExpectIdentifier(out var name))
{ {
if (name == "any") if (name.Value == "any")
{ {
return new NubAnyType(); return new NubAnyType();
} }
if (name == "void") if (name.Value == "void")
{ {
return new NubVoidType(); return new NubVoidType();
} }
if (NubPrimitiveType.TryParse(name, out var primitiveTypeKind)) if (NubPrimitiveType.TryParse(name.Value, out var primitiveTypeKind))
{ {
return new NubPrimitiveType(primitiveTypeKind.Value); return new NubPrimitiveType(primitiveTypeKind.Value);
} }
@@ -674,7 +684,16 @@ public class Parser
@namespace = ExpectIdentifier().Value; @namespace = ExpectIdentifier().Value;
} }
return new NubStructType(@namespace, name); if (@namespace == null)
{
throw new ParseException(Diagnostic
.Error($"Struct '{name.Value}' does not belong to a namespace")
.WithHelp("Make sure you have specified a namespace at the top of the file")
.At(name)
.Build());
}
return new NubStructType(@namespace , name.Value);
} }
if (TryExpectSymbol(Symbol.Caret)) if (TryExpectSymbol(Symbol.Caret))
@@ -749,7 +768,7 @@ public class Parser
.Build()); .Build());
} }
private Token ExpectToken() private static Token ExpectToken()
{ {
if (!Peek().TryGetValue(out var token)) if (!Peek().TryGetValue(out var token))
{ {
@@ -764,7 +783,7 @@ public class Parser
return token; return token;
} }
private SymbolToken ExpectSymbol() private static SymbolToken ExpectSymbol()
{ {
var token = ExpectToken(); var token = ExpectToken();
if (token is not SymbolToken symbol) if (token is not SymbolToken symbol)
@@ -779,7 +798,7 @@ public class Parser
return symbol; return symbol;
} }
private void ExpectSymbol(Symbol expectedSymbol) private static void ExpectSymbol(Symbol expectedSymbol)
{ {
var token = ExpectSymbol(); var token = ExpectSymbol();
if (token.Symbol != expectedSymbol) if (token.Symbol != expectedSymbol)
@@ -792,7 +811,7 @@ public class Parser
} }
} }
private bool TryExpectSymbol(Symbol symbol) private static bool TryExpectSymbol(Symbol symbol)
{ {
if (Peek() is { Value: SymbolToken symbolToken } && symbolToken.Symbol == symbol) if (Peek() is { Value: SymbolToken symbolToken } && symbolToken.Symbol == symbol)
{ {
@@ -803,7 +822,7 @@ public class Parser
return false; return false;
} }
private bool TryExpectModifier([NotNullWhen(true)] out ModifierToken? modifier) private static bool TryExpectModifier([NotNullWhen(true)] out ModifierToken? modifier)
{ {
if (Peek() is { Value: ModifierToken modifierToken }) if (Peek() is { Value: ModifierToken modifierToken })
{ {
@@ -816,11 +835,11 @@ public class Parser
return false; return false;
} }
private bool TryExpectIdentifier([NotNullWhen(true)] out string? identifier) private static bool TryExpectIdentifier([NotNullWhen(true)] out IdentifierToken? identifier)
{ {
if (Peek() is { Value: IdentifierToken identifierToken }) if (Peek() is { Value: IdentifierToken identifierToken })
{ {
identifier = identifierToken.Value; identifier = identifierToken;
Next(); Next();
return true; return true;
} }
@@ -829,7 +848,7 @@ public class Parser
return false; return false;
} }
private IdentifierToken ExpectIdentifier() private static IdentifierToken ExpectIdentifier()
{ {
var token = ExpectToken(); var token = ExpectToken();
if (token is not IdentifierToken identifier) if (token is not IdentifierToken identifier)
@@ -844,12 +863,12 @@ public class Parser
return identifier; return identifier;
} }
private void RecoverToNextDefinition() private static void RecoverToNextDefinition()
{ {
while (Peek().HasValue) while (Peek().HasValue)
{ {
var token = Peek().Value; var token = Peek().Value;
if (token is SymbolToken { Symbol: Symbol.Func or Symbol.Struct }) if (token is SymbolToken { Symbol: Symbol.Func or Symbol.Struct } or ModifierToken)
{ {
break; break;
} }
@@ -858,7 +877,7 @@ public class Parser
} }
} }
private void RecoverToNextStatement() private static void RecoverToNextStatement()
{ {
while (Peek().TryGetValue(out var token)) while (Peek().TryGetValue(out var token))
{ {
@@ -874,7 +893,7 @@ public class Parser
} }
} }
private Optional<Token> Peek(int offset = 0) private static Optional<Token> Peek(int offset = 0)
{ {
var peekIndex = _index + offset; var peekIndex = _index + offset;
while (peekIndex < _tokens.Count && _tokens[peekIndex] is DocumentationToken) while (peekIndex < _tokens.Count && _tokens[peekIndex] is DocumentationToken)
@@ -890,7 +909,7 @@ public class Parser
return Optional<Token>.Empty(); return Optional<Token>.Empty();
} }
private void Next() private static void Next()
{ {
while (_index < _tokens.Count && _tokens[_index] is DocumentationToken) while (_index < _tokens.Count && _tokens[_index] is DocumentationToken)
{ {
@@ -900,7 +919,7 @@ public class Parser
_index++; _index++;
} }
private IReadOnlyList<Token> GetTokensForNode(int startIndex) private static IReadOnlyList<Token> GetTokensForNode(int startIndex)
{ {
return _tokens[startIndex..Math.Min(_index, _tokens.Count - 1)]; return _tokens[startIndex..Math.Min(_index, _tokens.Count - 1)];
} }

View File

@@ -8,63 +8,46 @@ using Nub.Lang.Frontend.Parsing.Statements;
namespace Nub.Lang.Frontend.Typing; namespace Nub.Lang.Frontend.Typing;
public class TypeChecker public static class TypeChecker
{ {
private Dictionary<string, NubType> _variables = new(); private static CompilationUnit _compilationUnit = null!;
private List<SourceFile> _sourceFiles = []; private static DefinitionTable _definitionTable = null!;
private List<Diagnostic> _diagnostics = [];
private NubType? _currentFunctionReturnType;
private Queue<AnonymousFuncNode> _anonymousFunctions = [];
public DiagnosticsResult TypeCheck(List<SourceFile> sourceFiles) private static Dictionary<string, NubType> _variables = new();
private static List<Diagnostic> _diagnostics = [];
private static NubType? _currentFunctionReturnType;
private static Queue<AnonymousFuncNode> _anonymousFunctions = [];
public static DiagnosticsResult Check(CompilationUnit compilationUnit, DefinitionTable definitionTable)
{ {
_compilationUnit = compilationUnit;
_definitionTable = definitionTable;
_variables = new Dictionary<string, NubType>(); _variables = new Dictionary<string, NubType>();
_diagnostics = []; _diagnostics = [];
_currentFunctionReturnType = null; _currentFunctionReturnType = null;
_sourceFiles = sourceFiles;
_anonymousFunctions = []; _anonymousFunctions = [];
var externFuncDefinitions = _sourceFiles foreach (var structDef in compilationUnit.Definitions.OfType<StructDefinitionNode>())
.SelectMany(f => f.Definitions)
.OfType<ExternFuncDefinitionNode>()
.ToArray();
foreach (var funcName in externFuncDefinitions.Where(x => externFuncDefinitions.Count(y => x.Name == y.Name) > 1))
{ {
ReportError($"Extern function '{funcName}' has been declared more than once", funcName); CheckStructDef(structDef);
} }
var exportedLocalFuncDefinitions = _sourceFiles foreach (var funcDef in compilationUnit.Definitions.OfType<LocalFuncDefinitionNode>())
.SelectMany(f => f.Definitions)
.OfType<LocalFuncDefinitionNode>()
.Where(f => f.Exported)
.ToArray();
foreach (var funcName in exportedLocalFuncDefinitions.Where(x => exportedLocalFuncDefinitions.Count(y => x.Name == y.Name) > 1))
{ {
ReportError($"Exported function '{funcName}' has been declared more than once", funcName); CheckFuncDef(funcDef.Parameters, funcDef.Body, funcDef.ReturnType);
}
foreach (var structDef in _sourceFiles.SelectMany(f => f.Definitions).OfType<StructDefinitionNode>())
{
TypeCheckStructDef(structDef);
}
foreach (var funcDef in _sourceFiles.SelectMany(f => f.Definitions).OfType<LocalFuncDefinitionNode>())
{
TypeCheckFuncDef(funcDef.Parameters, funcDef.Body, funcDef.ReturnType);
} }
while (_anonymousFunctions.TryDequeue(out var func)) while (_anonymousFunctions.TryDequeue(out var func))
{ {
TypeCheckFuncDef(func.Parameters, func.Body, func.ReturnType); CheckFuncDef(func.Parameters, func.Body, func.ReturnType);
} }
return new DiagnosticsResult(_diagnostics); return new DiagnosticsResult(_diagnostics);
} }
private void TypeCheckStructDef(StructDefinitionNode structDef) private static void CheckStructDef(StructDefinitionNode structDef)
{ {
var fields = new Dictionary<string, NubType>(); var fields = new Dictionary<string, NubType>();
foreach (var field in structDef.Fields) foreach (var field in structDef.Fields)
@@ -77,7 +60,7 @@ public class TypeChecker
if (field.Value.HasValue) if (field.Value.HasValue)
{ {
var fieldType = TypeCheckExpression(field.Value.Value, field.Type); var fieldType = CheckExpression(field.Value.Value, field.Type);
if (fieldType != null && !fieldType.Equals(field.Type)) if (fieldType != null && !fieldType.Equals(field.Type))
{ {
ReportError("Default field initializer does not match the defined type", field.Value.Value); ReportError("Default field initializer does not match the defined type", field.Value.Value);
@@ -88,7 +71,7 @@ public class TypeChecker
} }
} }
private void TypeCheckFuncDef(List<FuncParameter> parameters, BlockNode body, NubType returnType) private static void CheckFuncDef(List<FuncParameter> parameters, BlockNode body, NubType returnType)
{ {
_variables.Clear(); _variables.Clear();
_currentFunctionReturnType = returnType; _currentFunctionReturnType = returnType;
@@ -98,50 +81,50 @@ public class TypeChecker
_variables[param.Name] = param.Type; _variables[param.Name] = param.Type;
} }
TypeCheckBlock(body); CheckBlock(body);
} }
private void TypeCheckBlock(BlockNode block) private static void CheckBlock(BlockNode block)
{ {
foreach (var statement in block.Statements) foreach (var statement in block.Statements)
{ {
TypeCheckStatement(statement); CheckStatement(statement);
} }
} }
private void TypeCheckStatement(StatementNode statement) private static void CheckStatement(StatementNode statement)
{ {
switch (statement) switch (statement)
{ {
case ArrayIndexAssignmentNode arrayIndexAssignment: case ArrayIndexAssignmentNode arrayIndexAssignment:
TypeCheckArrayIndexAssignment(arrayIndexAssignment); CheckArrayIndexAssignment(arrayIndexAssignment);
break; break;
case VariableAssignmentNode variableAssignment: case VariableAssignmentNode variableAssignment:
TypeCheckVariableAssignment(variableAssignment); CheckVariableAssignment(variableAssignment);
break; break;
case VariableDeclarationNode variableDeclaration: case VariableDeclarationNode variableDeclaration:
TypeCheckVariableVariableDeclaration(variableDeclaration); CheckVariableVariableDeclaration(variableDeclaration);
break; break;
case IfNode ifNode: case IfNode ifNode:
TypeCheckIf(ifNode); CheckIf(ifNode);
break; break;
case MemberAssignmentNode memberAssignment: case MemberAssignmentNode memberAssignment:
TypeCheckMemberAssignment(memberAssignment); CheckMemberAssignment(memberAssignment);
break; break;
case WhileNode whileNode: case WhileNode whileNode:
TypeCheckWhile(whileNode); CheckWhile(whileNode);
break; break;
case ReturnNode returnNode: case ReturnNode returnNode:
TypeCheckReturn(returnNode); CheckReturn(returnNode);
break; break;
case StatementExpressionNode statementExpression: case StatementExpressionNode statementExpression:
TypeCheckExpression(statementExpression.Expression); CheckExpression(statementExpression.Expression);
break; break;
case BreakNode: case BreakNode:
case ContinueNode: case ContinueNode:
break; break;
case DereferenceAssignmentNode dereferenceAssignment: case DereferenceAssignmentNode dereferenceAssignment:
TypeCheckDereferenceAssignment(dereferenceAssignment); CheckDereferenceAssignment(dereferenceAssignment);
break; break;
default: default:
ReportError($"Unsupported statement type: {statement.GetType().Name}", statement); ReportError($"Unsupported statement type: {statement.GetType().Name}", statement);
@@ -149,11 +132,11 @@ public class TypeChecker
} }
} }
private void TypeCheckMemberAssignment(MemberAssignmentNode memberAssignment) private static void CheckMemberAssignment(MemberAssignmentNode memberAssignment)
{ {
var memberType = TypeCheckExpression(memberAssignment.MemberAccess); var memberType = CheckExpression(memberAssignment.MemberAccess);
if (memberType == null) return; if (memberType == null) return;
var valueType = TypeCheckExpression(memberAssignment.Value, memberType); var valueType = CheckExpression(memberAssignment.Value, memberType);
if (valueType == null) return; if (valueType == null) return;
if (!NubType.IsCompatibleWith(memberType, valueType)) if (!NubType.IsCompatibleWith(memberType, valueType))
@@ -162,11 +145,11 @@ public class TypeChecker
} }
} }
private void TypeCheckArrayIndexAssignment(ArrayIndexAssignmentNode arrayIndexAssignment) private static void CheckArrayIndexAssignment(ArrayIndexAssignmentNode arrayIndexAssignment)
{ {
var itemType = TypeCheckExpression(arrayIndexAssignment.ArrayIndexAccess); var itemType = CheckExpression(arrayIndexAssignment.ArrayIndexAccess);
if (itemType == null) return; if (itemType == null) return;
var valueType = TypeCheckExpression(arrayIndexAssignment.Value, itemType); var valueType = CheckExpression(arrayIndexAssignment.Value, itemType);
if (valueType == null) return; if (valueType == null) return;
if (!NubType.IsCompatibleWith(itemType, valueType)) if (!NubType.IsCompatibleWith(itemType, valueType))
@@ -175,7 +158,7 @@ public class TypeChecker
} }
} }
private void TypeCheckVariableAssignment(VariableAssignmentNode variableAssignment) private static void CheckVariableAssignment(VariableAssignmentNode variableAssignment)
{ {
if (!_variables.TryGetValue(variableAssignment.Identifier.Name, out var variable)) if (!_variables.TryGetValue(variableAssignment.Identifier.Name, out var variable))
{ {
@@ -183,7 +166,7 @@ public class TypeChecker
return; return;
} }
var valueType = TypeCheckExpression(variableAssignment.Value, variable); var valueType = CheckExpression(variableAssignment.Value, variable);
if (valueType == null) return; if (valueType == null) return;
if (!NubType.IsCompatibleWith(variableAssignment.Value.Type, variable)) if (!NubType.IsCompatibleWith(variableAssignment.Value.Type, variable))
@@ -192,7 +175,7 @@ public class TypeChecker
} }
} }
private void TypeCheckVariableVariableDeclaration(VariableDeclarationNode variableDeclaration) private static void CheckVariableVariableDeclaration(VariableDeclarationNode variableDeclaration)
{ {
NubType? type = null; NubType? type = null;
@@ -203,7 +186,7 @@ public class TypeChecker
if (variableDeclaration.Value.HasValue) if (variableDeclaration.Value.HasValue)
{ {
var valueType = TypeCheckExpression(variableDeclaration.Value.Value, variableDeclaration.ExplicitType.Value); var valueType = CheckExpression(variableDeclaration.Value.Value, variableDeclaration.ExplicitType.Value);
if (valueType == null) return; if (valueType == null) return;
type = valueType; type = valueType;
} }
@@ -232,9 +215,9 @@ public class TypeChecker
_variables[variableDeclaration.Name] = type; _variables[variableDeclaration.Name] = type;
} }
private NubType? TypeCheckDereference(DereferenceNode dereference) private static NubType? CheckDereference(DereferenceNode dereference)
{ {
var exprType = TypeCheckExpression(dereference.Expression); var exprType = CheckExpression(dereference.Expression);
if (exprType == null) return null; if (exprType == null) return null;
if (exprType is not NubPointerType nubPointerType) if (exprType is not NubPointerType nubPointerType)
@@ -246,14 +229,14 @@ public class TypeChecker
return nubPointerType.BaseType; return nubPointerType.BaseType;
} }
private NubType TypeCheckFixedInitializerArray(FixedArrayInitializerNode fixedArrayInitializer) private static NubType CheckFixedInitializerArray(FixedArrayInitializerNode fixedArrayInitializer)
{ {
return new NubFixedArrayType(fixedArrayInitializer.ElementType, fixedArrayInitializer.Capacity); return new NubFixedArrayType(fixedArrayInitializer.ElementType, fixedArrayInitializer.Capacity);
} }
private NubType? TypeCheckFuncCall(FuncCallNode funcCall) private static NubType? CheckFuncCall(FuncCallNode funcCall)
{ {
var identType = TypeCheckExpression(funcCall.Expression); var identType = CheckExpression(funcCall.Expression);
if (identType is not NubFuncType funcType) if (identType is not NubFuncType funcType)
{ {
ReportError("Cannot call function on non-function type", funcCall); ReportError("Cannot call function on non-function type", funcCall);
@@ -268,7 +251,7 @@ public class TypeChecker
for (var i = 0; i < funcCall.Parameters.Count; i++) for (var i = 0; i < funcCall.Parameters.Count; i++)
{ {
var parameter = funcCall.Parameters[i]; var parameter = funcCall.Parameters[i];
var parameterType = TypeCheckExpression(parameter); var parameterType = CheckExpression(parameter);
if (parameterType == null) return null; if (parameterType == null) return null;
if (!NubType.IsCompatibleWith(parameterType, funcType.Parameters[i])) if (!NubType.IsCompatibleWith(parameterType, funcType.Parameters[i]))
@@ -281,39 +264,39 @@ public class TypeChecker
return funcType.ReturnType; return funcType.ReturnType;
} }
private void TypeCheckIf(IfNode ifNode) private static void CheckIf(IfNode ifNode)
{ {
var conditionType = TypeCheckExpression(ifNode.Condition, NubPrimitiveType.Bool); var conditionType = CheckExpression(ifNode.Condition, NubPrimitiveType.Bool);
if (conditionType != null && !conditionType.Equals(NubPrimitiveType.Bool)) if (conditionType != null && !conditionType.Equals(NubPrimitiveType.Bool))
{ {
ReportError($"If condition must be a boolean expression, got '{conditionType}'", ifNode.Condition); ReportError($"If condition must be a boolean expression, got '{conditionType}'", ifNode.Condition);
} }
TypeCheckBlock(ifNode.Body); CheckBlock(ifNode.Body);
if (ifNode.Else.HasValue) if (ifNode.Else.HasValue)
{ {
var elseValue = ifNode.Else.Value; var elseValue = ifNode.Else.Value;
elseValue.Match(TypeCheckIf, TypeCheckBlock); elseValue.Match(CheckIf, CheckBlock);
} }
} }
private void TypeCheckWhile(WhileNode whileNode) private static void CheckWhile(WhileNode whileNode)
{ {
var conditionType = TypeCheckExpression(whileNode.Condition, NubPrimitiveType.Bool); var conditionType = CheckExpression(whileNode.Condition, NubPrimitiveType.Bool);
if (conditionType != null && !conditionType.Equals(NubPrimitiveType.Bool)) if (conditionType != null && !conditionType.Equals(NubPrimitiveType.Bool))
{ {
ReportError($"While condition must be a boolean expression, got '{conditionType}'", whileNode.Condition); ReportError($"While condition must be a boolean expression, got '{conditionType}'", whileNode.Condition);
} }
TypeCheckBlock(whileNode.Body); CheckBlock(whileNode.Body);
} }
private void TypeCheckReturn(ReturnNode returnNode) private static void CheckReturn(ReturnNode returnNode)
{ {
if (returnNode.Value.HasValue) if (returnNode.Value.HasValue)
{ {
var returnType = TypeCheckExpression(returnNode.Value.Value, _currentFunctionReturnType); var returnType = CheckExpression(returnNode.Value.Value, _currentFunctionReturnType);
if (returnType == null) return; if (returnType == null) return;
if (_currentFunctionReturnType == null) if (_currentFunctionReturnType == null)
@@ -333,11 +316,11 @@ public class TypeChecker
} }
} }
private void TypeCheckDereferenceAssignment(DereferenceAssignmentNode dereferenceAssignment) private static void CheckDereferenceAssignment(DereferenceAssignmentNode dereferenceAssignment)
{ {
var dereferenceType = TypeCheckExpression(dereferenceAssignment.Dereference); var dereferenceType = CheckExpression(dereferenceAssignment.Dereference);
if (dereferenceType == null) return; if (dereferenceType == null) return;
var valueType = TypeCheckExpression(dereferenceAssignment.Value, dereferenceType); var valueType = CheckExpression(dereferenceAssignment.Value, dereferenceType);
if (valueType == null) return; if (valueType == null) return;
if (!NubType.IsCompatibleWith(dereferenceType, valueType)) if (!NubType.IsCompatibleWith(dereferenceType, valueType))
@@ -346,23 +329,23 @@ public class TypeChecker
} }
} }
private NubType? TypeCheckExpression(ExpressionNode expression, NubType? expectedType = null) private static NubType? CheckExpression(ExpressionNode expression, NubType? expectedType = null)
{ {
var resultType = expression switch var resultType = expression switch
{ {
AddressOfNode addressOf => TypeCheckAddressOf(addressOf), AddressOfNode addressOf => CheckAddressOf(addressOf),
AnonymousFuncNode anonymousFunc => TypeCheckAnonymousFunc(anonymousFunc), AnonymousFuncNode anonymousFunc => CheckAnonymousFunc(anonymousFunc),
ArrayIndexAccessNode arrayIndex => TypeCheckArrayIndex(arrayIndex), ArrayIndexAccessNode arrayIndex => CheckArrayIndex(arrayIndex),
ArrayInitializerNode arrayInitializer => TypeCheckArrayInitializer(arrayInitializer), ArrayInitializerNode arrayInitializer => CheckArrayInitializer(arrayInitializer),
LiteralNode literal => TypeCheckLiteral(literal, expectedType), LiteralNode literal => CheckLiteral(literal, expectedType),
IdentifierNode identifier => TypeCheckIdentifier(identifier), IdentifierNode identifier => CheckIdentifier(identifier),
BinaryExpressionNode binaryExpr => TypeCheckBinaryExpression(binaryExpr), BinaryExpressionNode binaryExpr => CheckBinaryExpression(binaryExpr),
DereferenceNode dereference => TypeCheckDereference(dereference), DereferenceNode dereference => CheckDereference(dereference),
FixedArrayInitializerNode fixedArray => TypeCheckFixedInitializerArray(fixedArray), FixedArrayInitializerNode fixedArray => CheckFixedInitializerArray(fixedArray),
FuncCallNode funcCallExpr => TypeCheckFuncCall(funcCallExpr), FuncCallNode funcCallExpr => CheckFuncCall(funcCallExpr),
StructInitializerNode structInit => TypeCheckStructInitializer(structInit), StructInitializerNode structInit => CheckStructInitializer(structInit),
UnaryExpressionNode unaryExpression => TypeCheckUnaryExpression(unaryExpression), UnaryExpressionNode unaryExpression => CheckUnaryExpression(unaryExpression),
MemberAccessNode memberAccess => TypeCheckMemberAccess(memberAccess), MemberAccessNode memberAccess => CheckMemberAccess(memberAccess),
_ => throw new UnreachableException() _ => throw new UnreachableException()
}; };
@@ -374,13 +357,13 @@ public class TypeChecker
return resultType; return resultType;
} }
private NubType TypeCheckAnonymousFunc(AnonymousFuncNode anonymousFunc) private static NubType CheckAnonymousFunc(AnonymousFuncNode anonymousFunc)
{ {
_anonymousFunctions.Enqueue(anonymousFunc); _anonymousFunctions.Enqueue(anonymousFunc);
return new NubFuncType(anonymousFunc.ReturnType, anonymousFunc.Parameters.Select(p => p.Type).ToList()); return new NubFuncType(anonymousFunc.ReturnType, anonymousFunc.Parameters.Select(p => p.Type).ToList());
} }
private NubType? TypeCheckLiteral(LiteralNode literal, NubType? expectedType = null) private static NubType? CheckLiteral(LiteralNode literal, NubType? expectedType = null)
{ {
if (expectedType != null) if (expectedType != null)
{ {
@@ -411,11 +394,11 @@ public class TypeChecker
}; };
} }
private NubType? TypeCheckArrayIndex(ArrayIndexAccessNode arrayIndexAccess) private static NubType? CheckArrayIndex(ArrayIndexAccessNode arrayIndexAccess)
{ {
var expressionType = TypeCheckExpression(arrayIndexAccess.Array); var expressionType = CheckExpression(arrayIndexAccess.Array);
if (expressionType == null) return null; if (expressionType == null) return null;
var indexType = TypeCheckExpression(arrayIndexAccess.Index, NubPrimitiveType.U64); var indexType = CheckExpression(arrayIndexAccess.Index, NubPrimitiveType.U64);
if (indexType is { IsInteger: false }) if (indexType is { IsInteger: false })
{ {
ReportError("Array index type must be a number", arrayIndexAccess.Index); ReportError("Array index type must be a number", arrayIndexAccess.Index);
@@ -435,9 +418,9 @@ public class TypeChecker
return null; return null;
} }
private NubType TypeCheckArrayInitializer(ArrayInitializerNode arrayInitializer) private static NubType CheckArrayInitializer(ArrayInitializerNode arrayInitializer)
{ {
var capacityType = TypeCheckExpression(arrayInitializer.Capacity, NubPrimitiveType.U64); var capacityType = CheckExpression(arrayInitializer.Capacity, NubPrimitiveType.U64);
if (capacityType is { IsInteger: false }) if (capacityType is { IsInteger: false })
{ {
ReportError("Array capacity type must be an integer", arrayInitializer.Capacity); ReportError("Array capacity type must be an integer", arrayInitializer.Capacity);
@@ -446,33 +429,26 @@ public class TypeChecker
return new NubArrayType(arrayInitializer.ElementType); return new NubArrayType(arrayInitializer.ElementType);
} }
private NubType? TypeCheckIdentifier(IdentifierNode identifier) private static NubType? CheckIdentifier(IdentifierNode identifier)
{ {
if (identifier.Namespace == null) var definition = _definitionTable.LookupFunc(identifier.Namespace.Or(_compilationUnit.Namespace), identifier.Name);
if (definition.HasValue)
{ {
var result = _variables.GetValueOrDefault(identifier.Name); return new NubFuncType(definition.Value.ReturnType, definition.Value.Parameters.Select(p => p.Type).ToList());
if (result == null) }
{
ReportError($"Variable '{identifier.Name}' is not defined", identifier);
return null;
}
return result; if (!identifier.Namespace.HasValue)
}
var func = LookupFuncSignature(identifier.Namespace, identifier.Name);
if (func == null)
{ {
ReportError($"Identifier '{identifier.Name}' is not defined", identifier); return _variables[identifier.Name];
return null;
} }
return new NubFuncType(func.ReturnType, func.Parameters.Select(p => p.Type).ToList()); ReportError($"Identifier '{identifier}' not found", identifier);
return null;
} }
private NubType? TypeCheckAddressOf(AddressOfNode addressOf) private static NubType? CheckAddressOf(AddressOfNode addressOf)
{ {
var exprType = TypeCheckExpression(addressOf.Expression); var exprType = CheckExpression(addressOf.Expression);
if (exprType == null) return null; if (exprType == null) return null;
if (addressOf.Expression is not (IdentifierNode or MemberAccessNode)) if (addressOf.Expression is not (IdentifierNode or MemberAccessNode))
@@ -484,10 +460,10 @@ public class TypeChecker
return new NubPointerType(exprType); return new NubPointerType(exprType);
} }
private NubType? TypeCheckBinaryExpression(BinaryExpressionNode binaryExpr) private static NubType? CheckBinaryExpression(BinaryExpressionNode binaryExpr)
{ {
var leftType = TypeCheckExpression(binaryExpr.Left); var leftType = CheckExpression(binaryExpr.Left);
var rightType = TypeCheckExpression(binaryExpr.Right); var rightType = CheckExpression(binaryExpr.Right);
if (leftType == null || rightType == null) return null; if (leftType == null || rightType == null) return null;
@@ -530,12 +506,12 @@ public class TypeChecker
} }
} }
private NubType? TypeCheckStructInitializer(StructInitializerNode structInit) private static NubType? CheckStructInitializer(StructInitializerNode structInit)
{ {
var initialized = new HashSet<string>(); var initialized = new HashSet<string>();
var definition = LookupStructDefinition(structInit.StructType.Namespace, structInit.StructType.Name); var defOpt = _definitionTable.LookupStruct(structInit.StructType.Namespace, structInit.StructType.Name);
if (definition == null) if (!defOpt.TryGetValue(out var definition))
{ {
ReportError($"Struct type '{structInit.StructType.Name}' is not defined", structInit); ReportError($"Struct type '{structInit.StructType.Name}' is not defined", structInit);
return null; return null;
@@ -550,7 +526,7 @@ public class TypeChecker
continue; continue;
} }
var initializerType = TypeCheckExpression(initializer.Value, definitionField.Type); var initializerType = CheckExpression(initializer.Value, definitionField.Type);
if (initializerType != null && !NubType.IsCompatibleWith(initializerType, definitionField.Type)) if (initializerType != null && !NubType.IsCompatibleWith(initializerType, definitionField.Type))
{ {
ReportError($"Cannot initialize field '{initializer.Key}' of type '{definitionField.Type}' with expression of type '{initializerType}'", initializer.Value); ReportError($"Cannot initialize field '{initializer.Key}' of type '{definitionField.Type}' with expression of type '{initializerType}'", initializer.Value);
@@ -575,9 +551,9 @@ public class TypeChecker
return structInit.StructType; return structInit.StructType;
} }
private NubType? TypeCheckUnaryExpression(UnaryExpressionNode unaryExpression) private static NubType? CheckUnaryExpression(UnaryExpressionNode unaryExpression)
{ {
var operandType = TypeCheckExpression(unaryExpression.Operand); var operandType = CheckExpression(unaryExpression.Operand);
if (operandType == null) return null; if (operandType == null) return null;
switch (unaryExpression.Operator) switch (unaryExpression.Operator)
@@ -615,9 +591,9 @@ public class TypeChecker
} }
} }
private NubType? TypeCheckMemberAccess(MemberAccessNode memberAccess) private static NubType? CheckMemberAccess(MemberAccessNode memberAccess)
{ {
var expressionType = TypeCheckExpression(memberAccess.Expression); var expressionType = CheckExpression(memberAccess.Expression);
if (expressionType == null) return null; if (expressionType == null) return null;
switch (expressionType) switch (expressionType)
@@ -633,8 +609,8 @@ public class TypeChecker
} }
case NubStructType structType: case NubStructType structType:
{ {
var definition = LookupStructDefinition(structType.Namespace, structType.Name); var defOpt = _definitionTable.LookupStruct(structType.Namespace, structType.Name);
if (definition == null) if (!defOpt.TryGetValue(out var definition))
{ {
ReportError($"Struct type '{structType.Name}' is not defined", memberAccess); ReportError($"Struct type '{structType.Name}' is not defined", memberAccess);
return null; return null;
@@ -655,13 +631,13 @@ public class TypeChecker
return null; return null;
} }
private void ReportError(string message, Node node) private static void ReportError(string message, Node node)
{ {
var diagnostic = Diagnostic.Error(message).At(node).Build(); var diagnostic = Diagnostic.Error(message).At(node).Build();
_diagnostics.Add(diagnostic); _diagnostics.Add(diagnostic);
} }
private void ReportWarning(string message, Node node) private static void ReportWarning(string message, Node node)
{ {
var diagnostic = Diagnostic.Warning(message).At(node).Build(); var diagnostic = Diagnostic.Warning(message).At(node).Build();
_diagnostics.Add(diagnostic); _diagnostics.Add(diagnostic);
@@ -691,22 +667,4 @@ public class TypeChecker
return false; return false;
} }
} }
private IFuncSignature? LookupFuncSignature(string @namespace, string name)
{
return _sourceFiles
.Where(f => f.Namespace == @namespace)
.SelectMany(f => f.Definitions)
.OfType<IFuncSignature>()
.FirstOrDefault(f => f.Name == name);
}
private StructDefinitionNode? LookupStructDefinition(string @namespace, string name)
{
return _sourceFiles
.Where(f => f.Namespace == @namespace)
.SelectMany(f => f.Definitions)
.OfType<StructDefinitionNode>()
.SingleOrDefault(d => d.Name == name);
}
} }

View File

@@ -58,5 +58,20 @@ public readonly struct Optional<TValue>
return false; return false;
} }
public TValue GetValue()
{
return Value ?? throw new InvalidOperationException("Value is not set");
}
public static implicit operator Optional<TValue>(TValue value) => new(value); public static implicit operator Optional<TValue>(TValue value) => new(value);
public TValue Or(TValue other)
{
if (HasValue)
{
return Value;
}
return other;
}
} }