This commit is contained in:
nub31
2025-09-10 22:33:58 +02:00
parent 32190b3bea
commit df6665a096
21 changed files with 587 additions and 880 deletions

2
example/.gitignore vendored
View File

@@ -1,2 +1,2 @@
build
.build
out

View File

@@ -1,34 +1,15 @@
// c
module main
extern func puts(text: cstring)
struct Name
struct Test
{
first: cstring
last: cstring
}
struct Human
{
name: Name
age: cstring
}
func main(args: []cstring): i64
{
let x: Human = {
name = {
first = "bob"
last = "the builder"
}
age = "23"
}
puts(x.age)
puts("test")
return 0
}
// func test(human: ^Human)
// {
// puts(human^.name.last)
// }

View File

@@ -1,13 +1,11 @@
using NubLang.CLI;
using NubLang.Code;
using NubLang.Diagnostics;
using NubLang.Generation;
using NubLang.Generation.QBE;
using NubLang.Parsing;
using NubLang.Parsing.Syntax;
using NubLang.Tokenization;
using NubLang.TypeChecking;
using NubLang.TypeChecking.Node;
var options = new Options();
@@ -35,9 +33,6 @@ for (var i = 0; i < args.Length; i++)
}
}
var diagnostics = new List<Diagnostic>();
var syntaxTrees = new List<SyntaxTree>();
foreach (var file in options.Files)
{
if (!File.Exists(file.Path))
@@ -47,6 +42,9 @@ foreach (var file in options.Files)
}
}
var diagnostics = new List<Diagnostic>();
var syntaxTrees = new List<SyntaxTree>();
foreach (var file in options.Files)
{
var tokenizer = new Tokenizer(file);
@@ -61,16 +59,17 @@ foreach (var file in options.Files)
syntaxTrees.Add(syntaxTree);
}
var definitionTable = new DefinitionTable(syntaxTrees);
var moduleSignatures = ModuleSignature.CollectFromSyntaxTrees(syntaxTrees);
var modules = Module.CollectFromSyntaxTrees(syntaxTrees);
var typedSyntaxTrees = new List<TypedSyntaxTree>();
var typedModules = new List<TypedModule>();
foreach (var syntaxTree in syntaxTrees)
foreach (var module in modules)
{
var typeChecker = new TypeChecker(syntaxTree, definitionTable);
var typedSyntaxTree = typeChecker.Check();
var typeChecker = new TypeChecker(module, moduleSignatures);
var typedModule = typeChecker.CheckModule();
diagnostics.AddRange(typeChecker.GetDiagnostics());
typedSyntaxTrees.Add(typedSyntaxTree);
typedModules.Add(typedModule);
}
foreach (var diagnostic in diagnostics)
@@ -83,13 +82,11 @@ if (diagnostics.Any(diagnostic => diagnostic.Severity == DiagnosticSeverity.Erro
return 1;
}
var typedDefinitionTable = new TypedDefinitionTable(typedSyntaxTrees);
var objectFiles = new List<string>();
for (var i = 0; i < typedSyntaxTrees.Count; i++)
for (var i = 0; i < typedModules.Count; i++)
{
var syntaxTree = typedSyntaxTrees[i];
var typedModule = typedModules[i];
var outFileName = Path.Combine(".build", "code", Path.ChangeExtension(options.Files[i].Path, null));
var outFileDir = Path.GetDirectoryName(outFileName);
@@ -98,7 +95,7 @@ for (var i = 0; i < typedSyntaxTrees.Count; i++)
Directory.CreateDirectory(outFileDir);
}
var generator = new QBEGenerator(syntaxTree, typedDefinitionTable, options.Files[i].Path);
var generator = new QBEGenerator(typedModule, moduleSignatures);
var ssa = generator.Emit();
var ssaFilePath = Path.ChangeExtension(outFileName, "ssa");

View File

@@ -2,16 +2,16 @@
using System.Globalization;
using System.Text;
using NubLang.Tokenization;
using NubLang.TypeChecking;
using NubLang.TypeChecking.Node;
namespace NubLang.Generation.QBE;
public class QBEGenerator
{
private readonly TypedSyntaxTree _syntaxTree;
private readonly TypedDefinitionTable _definitionTable;
private readonly string _sourceFileName;
private readonly QBEWriter _writer;
private readonly TypedModule _module;
private readonly IReadOnlyList<ModuleSignature> _moduleSignatures;
private readonly List<CStringLiteral> _cStringLiterals = [];
private readonly List<StringLiteral> _stringLiterals = [];
@@ -23,11 +23,10 @@ public class QBEGenerator
private int _stringLiteralIndex;
private bool _codeIsReachable = true;
public QBEGenerator(TypedSyntaxTree syntaxTree, TypedDefinitionTable definitionTable, string sourceFileName)
public QBEGenerator(TypedModule module, IReadOnlyList<ModuleSignature> moduleSignatures)
{
_syntaxTree = syntaxTree;
_definitionTable = definitionTable;
_sourceFileName = sourceFileName;
_module = module;
_moduleSignatures = moduleSignatures;
_writer = new QBEWriter();
}
@@ -43,41 +42,42 @@ public class QBEGenerator
_stringLiteralIndex = 0;
_codeIsReachable = true;
_writer.WriteLine($"dbgfile \"{_sourceFileName}\"");
foreach (var structDef in _definitionTable.GetStructs())
foreach (var moduleSignature in _moduleSignatures)
{
EmitStructTypeDefinition(structDef);
foreach (var structType in moduleSignature.Symbols.Values.OfType<StructTypeNode>())
{
EmitStructType(moduleSignature.Name, structType);
_writer.NewLine();
}
}
foreach (var structDef in _syntaxTree.Definitions.OfType<StructNode>())
foreach (var structDef in _module.Definitions.OfType<StructNode>())
{
EmitStructDefinition(structDef);
_writer.NewLine();
}
foreach (var funcDef in _syntaxTree.Definitions.OfType<LocalFuncNode>())
foreach (var funcDef in _module.Definitions.OfType<LocalFuncNode>())
{
EmitLocalFuncDefinition(funcDef);
_writer.NewLine();
}
foreach (var structDef in _syntaxTree.Definitions.OfType<StructNode>().Where(x => x.InterfaceImplementations.Count > 0))
{
_writer.Write($"data {StructVtableName(structDef.Name)} = {{ ");
foreach (var interfaceImplementation in structDef.InterfaceImplementations)
{
var interfaceDef = _definitionTable.LookupInterface(interfaceImplementation.Name);
foreach (var func in interfaceDef.Functions)
{
_writer.Write($"l {StructFuncName(structDef.Name, func.Name)}, ");
}
}
_writer.WriteLine("}");
}
// foreach (var structDef in _module.Definitions.OfType<StructNode>().Where(x => x.InterfaceImplementations.Count > 0))
// {
// _writer.Write($"data {StructVtableName(_module.Name, structDef.Name)} = {{ ");
//
// foreach (var interfaceImplementation in structDef.InterfaceImplementations)
// {
// var interfaceDef = _definitionTable.LookupInterface(interfaceImplementation.Name);
// foreach (var func in interfaceDef.Functions)
// {
// _writer.Write($"l {StructFuncName(_module.Name, structDef.Name, func.Name)}, ");
// }
// }
//
// _writer.WriteLine("}");
// }
foreach (var cStringLiteral in _cStringLiterals)
{
@@ -366,7 +366,7 @@ public class QBEGenerator
if (complexType is StructTypeNode structType)
{
return StructTypeName(structType.Name);
return StructTypeName(structType.Module, structType.Name);
}
return "l";
@@ -384,7 +384,7 @@ public class QBEGenerator
_writer.Write(FuncQBETypeName(funcDef.Signature.ReturnType) + ' ');
}
_writer.Write(LocalFuncName(funcDef));
_writer.Write(LocalFuncName(_module.Name, funcDef));
_writer.Write("(");
foreach (var parameter in funcDef.Signature.Parameters)
@@ -408,12 +408,19 @@ public class QBEGenerator
private void EmitStructDefinition(StructNode structDef)
{
_writer.WriteLine($"export function {StructCtorName(_module.Name, structDef.Name)}() {{");
_writer.WriteLine("@start");
_writer.Indented($"%struct =l alloc8 {SizeOf(structDef.)}");
_writer.Indented("ret %struct");
_writer.WriteLine("}");
for (var i = 0; i < structDef.Functions.Count; i++)
{
var function = structDef.Functions[i];
_labelIndex = 0;
_tmpIndex = 0;
_writer.NewLine();
_writer.Write("export function ");
if (function.Signature.ReturnType is not VoidTypeNode)
@@ -421,7 +428,7 @@ public class QBEGenerator
_writer.Write(FuncQBETypeName(function.Signature.ReturnType) + ' ');
}
_writer.Write(StructFuncName(structDef.Name, function.Name));
_writer.Write(StructFuncName(_module.Name, structDef.Name, function.Name));
_writer.Write("(l %this, ");
foreach (var parameter in function.Signature.Parameters)
@@ -441,38 +448,24 @@ public class QBEGenerator
}
_writer.WriteLine("}");
if (i != structDef.Functions.Count - 1)
{
_writer.NewLine();
}
}
}
private void EmitStructTypeDefinition(StructNode structDef)
private void EmitStructType(string module, StructTypeNode structType)
{
_writer.WriteLine($"type {StructTypeName(structDef.Name)} = {{ ");
_writer.WriteLine($"type {StructTypeName(module, structType.Name)} = {{ ");
var types = new Dictionary<string, string>();
foreach (var field in structDef.Fields)
foreach (var field in structType.Fields)
{
types.Add(field.Name, StructDefQBEType(field));
}
var longest = types.Values.Max(x => x.Length);
foreach (var (name, type) in types)
{
var padding = longest - type.Length;
_writer.Indented($"{type},{new string(' ', padding)} # {name}");
_writer.Indented($"{StructDefQBEType(field)},");
}
_writer.WriteLine("}");
return;
string StructDefQBEType(StructFieldNode field)
string StructDefQBEType(TypeNode type)
{
if (field.Type.IsSimpleType(out var simpleType, out var complexType))
if (type.IsSimpleType(out var simpleType, out var complexType))
{
return simpleType.StorageSize switch
{
@@ -488,7 +481,7 @@ public class QBEGenerator
if (complexType is StructTypeNode structType)
{
return StructTypeName(structType.Name);
return StructTypeName(structType.Module, structType.Name);
}
return "l";
@@ -649,10 +642,7 @@ public class QBEGenerator
ConvertToInterfaceNode convertToInterface => EmitConvertToInterface(convertToInterface),
ConvertIntNode convertInt => EmitConvertInt(convertInt),
ConvertFloatNode convertFloat => EmitConvertFloat(convertFloat),
ExternFuncIdentNode externFuncIdent => EmitExternFuncIdent(externFuncIdent),
LocalFuncIdentNode localFuncIdent => EmitLocalFuncIdent(localFuncIdent),
VariableIdentNode variableIdent => EmitVariableIdent(variableIdent),
FuncParameterIdentNode funcParameterIdent => EmitFuncParameterIdent(funcParameterIdent),
VariableIdentifierNode identifier => EmitIdentifier(identifier),
LiteralNode literal => EmitLiteral(literal),
UnaryExpressionNode unaryExpression => EmitUnaryExpression(unaryExpression),
StructFieldAccessNode structFieldAccess => EmitStructFieldAccess(structFieldAccess),
@@ -662,15 +652,21 @@ public class QBEGenerator
};
}
private string EmitArrayIndexAccess(ArrayIndexAccessNode arrayIndexAccess)
private string EmitIdentifier(VariableIdentifierNode variableIdentifier)
{
var address = EmitAddressOfArrayIndexAccess(arrayIndexAccess);
if (arrayIndexAccess.Type is StructTypeNode)
{
return address;
throw new NotImplementedException();
}
return EmitLoad(arrayIndexAccess.Type, address);
private string EmitArrayIndexAccess(ArrayIndexAccessNode arrayIndexAccess)
{
// var address = EmitAddressOfArrayIndexAccess(arrayIndexAccess);
// if (arrayIndexAccess.Type is StructTypeNode)
// {
// return address;
// }
//
// return EmitLoad(arrayIndexAccess.Type, address);
throw new NotImplementedException();
}
private string EmitArrayInitializer(ArrayInitializerNode arrayInitializer)
@@ -714,43 +710,43 @@ public class QBEGenerator
{
return addressOf switch
{
ArrayIndexAccessNode arrayIndexAccess => EmitAddressOfArrayIndexAccess(arrayIndexAccess),
StructFieldAccessNode structFieldAccess => EmitAddressOfStructFieldAccess(structFieldAccess),
VariableIdentNode variableIdent => EmitAddressOfVariableIdent(variableIdent),
// ArrayIndexAccessNode arrayIndexAccess => EmitAddressOfArrayIndexAccess(arrayIndexAccess),
// StructFieldAccessNode structFieldAccess => EmitAddressOfStructFieldAccess(structFieldAccess),
// VariableIdentNode variableIdent => EmitAddressOfVariableIdent(variableIdent),
_ => throw new ArgumentOutOfRangeException(nameof(addressOf))
};
}
private string EmitAddressOfArrayIndexAccess(ArrayIndexAccessNode arrayIndexAccess)
{
var array = EmitExpression(arrayIndexAccess.Target);
var index = EmitExpression(arrayIndexAccess.Index);
var elementType = ((ArrayTypeNode)arrayIndexAccess.Target.Type).ElementType;
var offset = TmpName();
_writer.Indented($"{offset} =l mul {index}, {SizeOf(elementType)}");
_writer.Indented($"{offset} =l add {offset}, 8");
_writer.Indented($"{offset} =l add {array}, {offset}");
return offset;
}
private string EmitAddressOfStructFieldAccess(StructFieldAccessNode structFieldAccess)
{
var target = EmitExpression(structFieldAccess.Target);
var structDef = _definitionTable.LookupStruct(structFieldAccess.StructType.Name);
var offset = OffsetOf(structDef, structFieldAccess.Field);
var address = TmpName();
_writer.Indented($"{address} =l add {target}, {offset}");
return address;
}
private string EmitAddressOfVariableIdent(VariableIdentNode variableIdent)
{
return "%" + variableIdent.Name;
}
// private string EmitAddressOfArrayIndexAccess(ArrayIndexAccessNode arrayIndexAccess)
// {
// var array = EmitExpression(arrayIndexAccess.Target);
// var index = EmitExpression(arrayIndexAccess.Index);
//
// var elementType = ((ArrayTypeNode)arrayIndexAccess.Target.Type).ElementType;
//
// var offset = TmpName();
// _writer.Indented($"{offset} =l mul {index}, {SizeOf(elementType)}");
// _writer.Indented($"{offset} =l add {offset}, 8");
// _writer.Indented($"{offset} =l add {array}, {offset}");
// return offset;
// }
//
// private string EmitAddressOfStructFieldAccess(StructFieldAccessNode structFieldAccess)
// {
// var target = EmitExpression(structFieldAccess.Target);
//
// var structDef = _definitionTable.LookupStruct(structFieldAccess.StructType.Name);
// var offset = OffsetOf(structDef, structFieldAccess.Field);
//
// var address = TmpName();
// _writer.Indented($"{address} =l add {target}, {offset}");
// return address;
// }
//
// private string EmitAddressOfVariableIdent(VariableIdentNode variableIdent)
// {
// return "%" + variableIdent.Name;
// }
private string EmitBinaryExpression(BinaryExpressionNode binaryExpression)
{
@@ -922,32 +918,6 @@ public class QBEGenerator
};
}
private string EmitExternFuncIdent(ExternFuncIdentNode externFuncIdent)
{
return ExternFuncName(_definitionTable.LookupExternFunc(externFuncIdent.Name));
}
private string EmitLocalFuncIdent(LocalFuncIdentNode localFuncIdent)
{
return LocalFuncName(_definitionTable.LookupLocalFunc(localFuncIdent.Name));
}
private string EmitVariableIdent(VariableIdentNode variableIdent)
{
var address = EmitAddressOfVariableIdent(variableIdent);
if (variableIdent.Type is StructTypeNode)
{
return address;
}
return EmitLoad(variableIdent.Type, address);
}
private string EmitFuncParameterIdent(FuncParameterIdentNode funcParameterIdent)
{
return "%" + funcParameterIdent.Name;
}
private string EmitLiteral(LiteralNode literal)
{
switch (literal.Kind)
@@ -1032,13 +1002,11 @@ public class QBEGenerator
private string EmitStructInitializer(StructInitializerNode structInitializer)
{
var structDef = _definitionTable.LookupStruct(structInitializer.StructType.Name);
var destination = TmpName();
var size = SizeOf(structInitializer.StructType);
_writer.Indented($"{destination} =l alloc8 {size}");
foreach (var field in structDef.Fields)
foreach (var field in structInitializer.StructType.Fields)
{
if (!structInitializer.Initializers.TryGetValue(field.Name, out var valueExpression))
{
@@ -1417,9 +1385,9 @@ public class QBEGenerator
return $"$string{++_stringLiteralIndex}";
}
private string LocalFuncName(LocalFuncNode funcDef)
private string LocalFuncName(string module, LocalFuncNode funcDef)
{
return $"${funcDef.Name}";
return $"${module}.{funcDef.Name}";
}
private string ExternFuncName(ExternFuncNode funcDef)
@@ -1427,19 +1395,24 @@ public class QBEGenerator
return $"${funcDef.CallName}";
}
private string StructTypeName(string name)
private string StructTypeName(string module, string name)
{
return $":{name}";
return $":{module}.{name}";
}
private string StructFuncName(string structName, string funcName)
private string StructFuncName(string module, string structName, string funcName)
{
return $"${structName}_{funcName}";
return $"${module}.{structName}_func.{funcName}";
}
private string StructVtableName(string structName)
private string StructCtorName(string module, string structName)
{
return $"${structName}_vtable";
return $"${module}.{structName}_ctor";
}
private string StructVtableName(string module, string structName)
{
return $"${module}.{structName}_vtable";
}
#endregion

View File

@@ -1,51 +0,0 @@
using NubLang.TypeChecking.Node;
namespace NubLang.Generation;
public sealed class TypedDefinitionTable
{
private readonly List<DefinitionNode> _definitions;
public TypedDefinitionTable(IEnumerable<TypedSyntaxTree> syntaxTrees)
{
_definitions = syntaxTrees.SelectMany(x => x.Definitions).ToList();
}
public LocalFuncNode LookupLocalFunc(string name)
{
return _definitions
.OfType<LocalFuncNode>()
.First(x => x.Name == name);
}
public ExternFuncNode LookupExternFunc(string name)
{
return _definitions
.OfType<ExternFuncNode>()
.First(x => x.Name == name);
}
public StructNode LookupStruct(string name)
{
return _definitions
.OfType<StructNode>()
.First(x => x.Name == name);
}
public InterfaceNode LookupInterface(string name)
{
return _definitions
.OfType<InterfaceNode>()
.First(x => x.Name == name);
}
public IEnumerable<StructNode> GetStructs()
{
return _definitions.OfType<StructNode>();
}
public IEnumerable<InterfaceNode> GetInterfaces()
{
return _definitions.OfType<InterfaceNode>();
}
}

View File

@@ -10,6 +10,7 @@ public sealed class Parser
private readonly List<Diagnostic> _diagnostics = [];
private IReadOnlyList<Token> _tokens = [];
private int _tokenIndex;
private string _moduleName = string.Empty;
private Token? CurrentToken => _tokenIndex < _tokens.Count ? _tokens[_tokenIndex] : null;
private bool HasToken => CurrentToken != null;
@@ -24,7 +25,47 @@ public sealed class Parser
_diagnostics.Clear();
_tokens = tokens;
_tokenIndex = 0;
_moduleName = string.Empty;
var metadata = ParseMetadata();
var definitions = ParseDefinitions();
return new SyntaxTree(definitions, metadata);
}
private SyntaxTreeMetadata ParseMetadata()
{
var imports = new List<string>();
try
{
ExpectSymbol(Symbol.Module);
_moduleName = ExpectIdentifier().Value;
while (TryExpectSymbol(Symbol.Import))
{
imports.Add(ExpectIdentifier().Value);
}
}
catch (ParseException e)
{
_diagnostics.Add(e.Diagnostic);
while (HasToken)
{
if (CurrentToken is SymbolToken { Symbol: Symbol.Module or Symbol.Import })
{
break;
}
Next();
}
}
return new SyntaxTreeMetadata(_moduleName, imports);
}
private List<DefinitionSyntax> ParseDefinitions()
{
var definitions = new List<DefinitionSyntax>();
while (HasToken)
@@ -48,9 +89,9 @@ public sealed class Parser
definitions.Add(definition);
}
catch (ParseException ex)
catch (ParseException e)
{
_diagnostics.Add(ex.Diagnostic);
_diagnostics.Add(e.Diagnostic);
while (HasToken)
{
if (CurrentToken is SymbolToken { Symbol: Symbol.Extern or Symbol.Func or Symbol.Struct or Symbol.Interface })
@@ -63,7 +104,7 @@ public sealed class Parser
}
}
return new SyntaxTree(GetTokens(_tokenIndex), definitions);
return definitions;
}
private FuncSignatureSyntax ParseFuncSignature()
@@ -129,13 +170,13 @@ public sealed class Parser
return new ExternFuncSyntax(GetTokens(startIndex), name.Value, callName, signature);
}
private LocalFuncSyntax ParseFunc(int startIndex)
private FuncSyntax ParseFunc(int startIndex)
{
var name = ExpectIdentifier();
var signature = ParseFuncSignature();
var body = ParseBlock();
return new LocalFuncSyntax(GetTokens(startIndex), name.Value, signature, body);
return new FuncSyntax(GetTokens(startIndex), name.Value, signature, body);
}
private DefinitionSyntax ParseStruct(int startIndex)
@@ -448,7 +489,7 @@ public sealed class Parser
var expr = token switch
{
LiteralToken literal => new LiteralSyntax(GetTokens(startIndex), literal.Value, literal.Kind),
IdentifierToken identifier => new IdentifierSyntax(GetTokens(startIndex), identifier.Value),
IdentifierToken identifier => new IdentifierSyntax(GetTokens(startIndex), Optional<string>.Empty(), identifier.Value),
SymbolToken symbolToken => symbolToken.Symbol switch
{
Symbol.OpenParen => ParseParenthesizedExpression(),
@@ -664,7 +705,7 @@ public sealed class Parser
"string" => new StringTypeSyntax(GetTokens(startIndex)),
"cstring" => new CStringTypeSyntax(GetTokens(startIndex)),
"bool" => new BoolTypeSyntax(GetTokens(startIndex)),
_ => new CustomTypeSyntax(GetTokens(startIndex), name.Value)
_ => new CustomTypeSyntax(GetTokens(startIndex), _moduleName, name.Value)
};
}

View File

@@ -2,22 +2,23 @@ using NubLang.Tokenization;
namespace NubLang.Parsing.Syntax;
public abstract record DefinitionSyntax(IEnumerable<Token> Tokens) : SyntaxNode(Tokens);
// todo(nub31): Check export modifier instead of harcoding true
public abstract record DefinitionSyntax(IEnumerable<Token> Tokens, string Name, bool Exported = true) : SyntaxNode(Tokens);
public record FuncParameterSyntax(IEnumerable<Token> Tokens, string Name, TypeSyntax Type) : SyntaxNode(Tokens);
public record FuncSignatureSyntax(IEnumerable<Token> Tokens, IReadOnlyList<FuncParameterSyntax> Parameters, TypeSyntax ReturnType) : SyntaxNode(Tokens);
public record LocalFuncSyntax(IEnumerable<Token> Tokens, string Name, FuncSignatureSyntax Signature, BlockSyntax Body) : DefinitionSyntax(Tokens);
public record FuncSyntax(IEnumerable<Token> Tokens, string Name, FuncSignatureSyntax Signature, BlockSyntax Body) : DefinitionSyntax(Tokens, Name);
public record ExternFuncSyntax(IEnumerable<Token> Tokens, string Name, string CallName, FuncSignatureSyntax Signature) : DefinitionSyntax(Tokens);
public record ExternFuncSyntax(IEnumerable<Token> Tokens, string Name, string CallName, FuncSignatureSyntax Signature) : DefinitionSyntax(Tokens, Name);
public record StructFieldSyntax(IEnumerable<Token> Tokens, int Index, string Name, TypeSyntax Type, Optional<ExpressionSyntax> Value) : SyntaxNode(Tokens);
public record StructFuncSyntax(IEnumerable<Token> Tokens, string Name, FuncSignatureSyntax Signature, BlockSyntax Body) : SyntaxNode(Tokens);
public record StructSyntax(IEnumerable<Token> Tokens, string Name, IReadOnlyList<StructFieldSyntax> Fields, IReadOnlyList<StructFuncSyntax> Functions, IReadOnlyList<TypeSyntax> InterfaceImplementations) : DefinitionSyntax(Tokens);
public record StructSyntax(IEnumerable<Token> Tokens, string Name, IReadOnlyList<StructFieldSyntax> Fields, IReadOnlyList<StructFuncSyntax> Functions, IReadOnlyList<TypeSyntax> InterfaceImplementations) : DefinitionSyntax(Tokens, Name);
public record InterfaceFuncSyntax(IEnumerable<Token> Tokens, string Name, FuncSignatureSyntax Signature) : SyntaxNode(Tokens);
public record InterfaceSyntax(IEnumerable<Token> Tokens, string Name, IReadOnlyList<InterfaceFuncSyntax> Functions) : DefinitionSyntax(Tokens);
public record InterfaceSyntax(IEnumerable<Token> Tokens, string Name, IReadOnlyList<InterfaceFuncSyntax> Functions) : DefinitionSyntax(Tokens, Name);

View File

@@ -40,7 +40,7 @@ public record FuncCallSyntax(IEnumerable<Token> Tokens, ExpressionSyntax Express
public record DotFuncCallSyntax(IEnumerable<Token> Tokens, string Name, ExpressionSyntax ThisParameter, IReadOnlyList<ExpressionSyntax> Parameters) : ExpressionSyntax(Tokens);
public record IdentifierSyntax(IEnumerable<Token> Tokens, string Name) : ExpressionSyntax(Tokens);
public record IdentifierSyntax(IEnumerable<Token> Tokens, Optional<string> Module, string Name) : ExpressionSyntax(Tokens);
public record ArrayInitializerSyntax(IEnumerable<Token> Tokens, ExpressionSyntax Capacity, TypeSyntax ElementType) : ExpressionSyntax(Tokens);

View File

@@ -4,6 +4,8 @@ namespace NubLang.Parsing.Syntax;
public abstract record SyntaxNode(IEnumerable<Token> Tokens);
public record SyntaxTree(IEnumerable<Token> Tokens, IReadOnlyList<DefinitionSyntax> Definitions) : SyntaxNode(Tokens);
public record SyntaxTreeMetadata(string? ModuleName, IReadOnlyList<string> Imports);
public record SyntaxTree(IReadOnlyList<DefinitionSyntax> Definitions, SyntaxTreeMetadata Metadata);
public record BlockSyntax(IEnumerable<Token> Tokens, IReadOnlyList<StatementSyntax> Statements) : SyntaxNode(Tokens);

View File

@@ -22,4 +22,4 @@ public record CStringTypeSyntax(IEnumerable<Token> Tokens) : TypeSyntax(Tokens);
public record ArrayTypeSyntax(IEnumerable<Token> Tokens, TypeSyntax BaseType) : TypeSyntax(Tokens);
public record CustomTypeSyntax(IEnumerable<Token> Tokens, string Name) : TypeSyntax(Tokens);
public record CustomTypeSyntax(IEnumerable<Token> Tokens, string Module, string Name) : TypeSyntax(Tokens);

View File

@@ -76,4 +76,6 @@ public enum Symbol
Pipe,
And,
Or,
Module,
Import,
}

View File

@@ -20,6 +20,7 @@ public sealed class Tokenizer
["interface"] = Symbol.Interface,
["for"] = Symbol.For,
["extern"] = Symbol.Extern,
["module"] = Symbol.Module,
};
private static readonly Dictionary<char[], Symbol> Symbols = new()

View File

@@ -1,56 +0,0 @@
using NubLang.Parsing.Syntax;
namespace NubLang.TypeChecking;
public class DefinitionTable
{
private readonly List<DefinitionSyntax> _definitions;
public DefinitionTable(IEnumerable<SyntaxTree> syntaxTrees)
{
_definitions = syntaxTrees.SelectMany(x => x.Definitions).ToList();
}
public IEnumerable<LocalFuncSyntax> LookupLocalFunc(string name)
{
return _definitions
.OfType<LocalFuncSyntax>()
.Where(x => x.Name == name);
}
public IEnumerable<ExternFuncSyntax> LookupExternFunc(string name)
{
return _definitions
.OfType<ExternFuncSyntax>()
.Where(x => x.Name == name);
}
public IEnumerable<StructSyntax> LookupStruct(string name)
{
return _definitions
.OfType<StructSyntax>()
.Where(x => x.Name == name);
}
public IEnumerable<StructFieldSyntax> LookupStructField(StructSyntax @struct, string field)
{
return @struct.Fields.Where(x => x.Name == field);
}
public IEnumerable<StructFuncSyntax> LookupStructFunc(StructSyntax @struct, string func)
{
return @struct.Functions.Where(x => x.Name == func);
}
public IEnumerable<InterfaceSyntax> LookupInterface(string name)
{
return _definitions
.OfType<InterfaceSyntax>()
.Where(x => x.Name == name);
}
public IEnumerable<InterfaceFuncSyntax> LookupInterfaceFunc(InterfaceSyntax @interface, string name)
{
return @interface.Functions.Where(x => x.Name == name);
}
}

View File

@@ -0,0 +1,174 @@
using NubLang.Parsing.Syntax;
using NubLang.TypeChecking.Node;
namespace NubLang.TypeChecking;
public class Module
{
public static IReadOnlyList<Module> CollectFromSyntaxTrees(IReadOnlyList<SyntaxTree> syntaxTrees)
{
var modules = new Dictionary<string, Module>();
foreach (var syntaxTree in syntaxTrees)
{
var name = syntaxTree.Metadata.ModuleName;
if (name == null)
{
continue;
}
if (!modules.TryGetValue(name, out var module))
{
module = new Module(name, syntaxTree.Metadata.Imports);
modules[name] = module;
}
foreach (var definition in syntaxTree.Definitions)
{
module.AddDefinition(definition);
}
}
return modules.Values.ToList();
}
private readonly List<DefinitionSyntax> _definitions = [];
public Module(string name, IReadOnlyList<string> imports)
{
Name = name;
Imports = imports;
}
public string Name { get; }
public IReadOnlyList<string> Imports { get; }
public IReadOnlyList<DefinitionSyntax> Definitions => _definitions;
private void AddDefinition(DefinitionSyntax syntax)
{
_definitions.Add(syntax);
}
}
public class TypedModule
{
public TypedModule(string name, IReadOnlyList<DefinitionNode> definitions)
{
Name = name;
Definitions = definitions;
}
public string Name { get; }
public IReadOnlyList<DefinitionNode> Definitions { get; }
}
public class ModuleSignature
{
public static IReadOnlyDictionary<string, ModuleSignature> CollectFromSyntaxTrees(IReadOnlyList<SyntaxTree> syntaxTrees)
{
var modules = new Dictionary<string, ModuleSignature>();
foreach (var syntaxTree in syntaxTrees)
{
var moduleName = syntaxTree.Metadata.ModuleName;
if (moduleName == null)
{
continue;
}
if (!modules.TryGetValue(moduleName, out var module))
{
module = new ModuleSignature();
modules[moduleName] = module;
}
foreach (var def in syntaxTree.Definitions)
{
if (def.Exported)
{
switch (def)
{
case ExternFuncSyntax externFuncDef:
{
var parameters = externFuncDef.Signature.Parameters.Select(p => TypeResolver.ResolveType(p.Type, modules)).ToList();
var returnType = TypeResolver.ResolveType(externFuncDef.Signature.ReturnType, modules);
var type = new FuncTypeNode(parameters, returnType);
module._functions.Add(externFuncDef.Name, type);
break;
}
case FuncSyntax funcDef:
{
var parameters = funcDef.Signature.Parameters.Select(p => TypeResolver.ResolveType(p.Type, modules)).ToList();
var returnType = TypeResolver.ResolveType(funcDef.Signature.ReturnType, modules);
var type = new FuncTypeNode(parameters, returnType);
module._functions.Add(funcDef.Name, type);
break;
}
case InterfaceSyntax interfaceDef:
{
var functions = new Dictionary<string, FuncTypeNode>();
foreach (var function in interfaceDef.Functions)
{
var parameters = function.Signature.Parameters.Select(p => TypeResolver.ResolveType(p.Type, modules)).ToList();
var returnType = TypeResolver.ResolveType(function.Signature.ReturnType, modules);
functions.Add(function.Name, new FuncTypeNode(parameters, returnType));
}
var type = new InterfaceTypeNode(moduleName, interfaceDef.Name, functions);
module._interfaces.Add(type);
break;
}
case StructSyntax structDef:
{
var fields = structDef.Fields.Select(x => new StructTypeField(x.Name, TypeResolver.ResolveType(x.Type, modules), x.Value.HasValue)).ToList();
var functions = new Dictionary<string, FuncTypeNode>();
foreach (var function in structDef.Functions)
{
var parameters = function.Signature.Parameters.Select(p => TypeResolver.ResolveType(p.Type, modules)).ToList();
var returnType = TypeResolver.ResolveType(function.Signature.ReturnType, modules);
functions.Add(function.Name, new FuncTypeNode(parameters, returnType));
}
var interfaceImplementations = new List<InterfaceTypeNode>();
foreach (var interfaceImplementation in structDef.InterfaceImplementations)
{
if (interfaceImplementation is not CustomTypeSyntax customType)
{
throw new Exception("Interface implementation is not a custom type");
}
var resolvedType = TypeResolver.ResolveCustomType(customType.Module, customType.Name, modules);
if (resolvedType is not InterfaceTypeNode interfaceType)
{
throw new Exception("Interface implementation is not a interface");
}
interfaceImplementations.Add(interfaceType);
}
var type = new StructTypeNode(moduleName, structDef.Name, fields, functions, interfaceImplementations);
module._structs.Add(type);
break;
}
default:
{
throw new ArgumentOutOfRangeException(nameof(def));
}
}
}
}
}
return modules;
}
private readonly List<StructTypeNode> _structs = [];
private readonly List<InterfaceTypeNode> _interfaces = [];
private readonly Dictionary<string, FuncTypeNode> _functions = [];
public IReadOnlyList<StructTypeNode> StructTypes => _structs;
public IReadOnlyList<InterfaceTypeNode> InterfaceTypes => _interfaces;
public IReadOnlyDictionary<string, FuncTypeNode> Functions => _functions;
}

View File

@@ -1,23 +1,21 @@
using NubLang.Tokenization;
namespace NubLang.TypeChecking.Node;
namespace NubLang.TypeChecking.Node;
public abstract record DefinitionNode : Node;
public abstract record DefinitionNode(IEnumerable<Token> Tokens) : Node(Tokens);
public record FuncParameterNode(string Name, TypeNode Type) : Node;
public record FuncParameterNode(string Name, TypeNode Type, IEnumerable<Token> Tokens) : Node(Tokens);
public record FuncSignatureNode(IReadOnlyList<FuncParameterNode> Parameters, TypeNode ReturnType) : Node;
public record FuncSignatureNode(IReadOnlyList<FuncParameterNode> Parameters, TypeNode ReturnType, IEnumerable<Token> Tokens) : Node(Tokens);
public record LocalFuncNode(string Name, FuncSignatureNode Signature, BlockNode Body) : DefinitionNode;
public record LocalFuncNode(string Name, FuncSignatureNode Signature, BlockNode Body, IEnumerable<Token> Tokens) : DefinitionNode(Tokens);
public record ExternFuncNode(string Name, string CallName, FuncSignatureNode Signature) : DefinitionNode;
public record ExternFuncNode(string Name, string CallName, FuncSignatureNode Signature, IEnumerable<Token> Tokens) : DefinitionNode(Tokens);
public record StructFieldNode(int Index, string Name, TypeNode Type, Optional<ExpressionNode> Value) : Node;
public record StructFieldNode(int Index, string Name, TypeNode Type, Optional<ExpressionNode> Value, IEnumerable<Token> Tokens) : Node(Tokens);
public record StructFuncNode(string Name, FuncSignatureNode Signature, BlockNode Body) : Node;
public record StructFuncNode(string Name, FuncSignatureNode Signature, BlockNode Body, IEnumerable<Token> Tokens) : Node(Tokens);
public record StructNode(string Name, IReadOnlyList<StructFieldNode> Fields, IReadOnlyList<StructFuncNode> Functions, IReadOnlyList<InterfaceTypeNode> InterfaceImplementations) : DefinitionNode;
public record StructNode(string Name, StructTypeNode Type, IReadOnlyList<StructFieldNode> Fields, IReadOnlyList<StructFuncNode> Functions, IReadOnlyList<InterfaceTypeNode> InterfaceImplementations, IEnumerable<Token> Tokens) : DefinitionNode(Tokens);
public record InterfaceFuncNode(string Name, FuncSignatureNode Signature) : Node;
public record InterfaceFuncNode(string Name, FuncSignatureNode Signature, IEnumerable<Token> Tokens) : Node(Tokens);
public record InterfaceNode(string Name, IReadOnlyList<InterfaceFuncNode> Functions, IEnumerable<Token> Tokens) : DefinitionNode(Tokens);
public record InterfaceNode(string Name, IReadOnlyList<InterfaceFuncNode> Functions) : DefinitionNode;

View File

@@ -30,45 +30,43 @@ public enum BinaryOperator
BitwiseOr
}
public abstract record ExpressionNode(TypeNode Type, IEnumerable<Token> Tokens) : Node(Tokens);
public abstract record ExpressionNode(TypeNode Type) : Node;
public abstract record LValueExpressionNode(TypeNode Type, IEnumerable<Token> Tokens) : RValueExpressionNode(Type, Tokens);
public abstract record RValueExpressionNode(TypeNode Type, IEnumerable<Token> Tokens) : ExpressionNode(Type, Tokens);
public abstract record LValueExpressionNode(TypeNode Type) : RValueExpressionNode(Type);
public abstract record RValueExpressionNode(TypeNode Type) : ExpressionNode(Type);
public record BinaryExpressionNode(TypeNode Type, ExpressionNode Left, BinaryOperator Operator, ExpressionNode Right, IEnumerable<Token> Tokens) : RValueExpressionNode(Type, Tokens);
public record BinaryExpressionNode(TypeNode Type, ExpressionNode Left, BinaryOperator Operator, ExpressionNode Right) : RValueExpressionNode(Type);
public record UnaryExpressionNode(TypeNode Type, UnaryOperator Operator, ExpressionNode Operand, IEnumerable<Token> Tokens) : RValueExpressionNode(Type, Tokens);
public record UnaryExpressionNode(TypeNode Type, UnaryOperator Operator, ExpressionNode Operand) : RValueExpressionNode(Type);
public record FuncCallNode(TypeNode Type, ExpressionNode Expression, IReadOnlyList<ExpressionNode> Parameters, IEnumerable<Token> Tokens) : RValueExpressionNode(Type, Tokens);
public record FuncCallNode(TypeNode Type, ExpressionNode Expression, IReadOnlyList<ExpressionNode> Parameters) : RValueExpressionNode(Type);
public record StructFuncCallNode(TypeNode Type, string Name, StructTypeNode StructType, ExpressionNode StructExpression, IReadOnlyList<ExpressionNode> Parameters, IEnumerable<Token> Tokens) : RValueExpressionNode(Type, Tokens);
public record StructFuncCallNode(TypeNode Type, string Name, StructTypeNode StructType, ExpressionNode StructExpression, IReadOnlyList<ExpressionNode> Parameters) : RValueExpressionNode(Type);
public record InterfaceFuncCallNode(TypeNode Type, string Name, InterfaceTypeNode InterfaceType, ExpressionNode InterfaceExpression, IReadOnlyList<ExpressionNode> Parameters, IEnumerable<Token> Tokens) : RValueExpressionNode(Type, Tokens);
public record InterfaceFuncCallNode(TypeNode Type, string Name, InterfaceTypeNode InterfaceType, ExpressionNode InterfaceExpression, IReadOnlyList<ExpressionNode> Parameters) : RValueExpressionNode(Type);
public record VariableIdentNode(TypeNode Type, string Name, IEnumerable<Token> Tokens) : LValueExpressionNode(Type, Tokens);
public record VariableIdentifierNode(TypeNode Type, string Name) : LValueExpressionNode(Type);
public record FuncParameterIdentNode(TypeNode Type, string Name, IEnumerable<Token> Tokens) : RValueExpressionNode(Type, Tokens);
public record FuncParameterIdentifierNode(TypeNode Type, string Name) : RValueExpressionNode(Type);
public record LocalFuncIdentNode(TypeNode Type, string Name, IEnumerable<Token> Tokens) : RValueExpressionNode(Type, Tokens);
public record FuncIdentifierNode(TypeNode Type, string Module, string Name) : RValueExpressionNode(Type);
public record ExternFuncIdentNode(TypeNode Type, string Name, IEnumerable<Token> Tokens) : RValueExpressionNode(Type, Tokens);
public record ArrayInitializerNode(TypeNode Type, ExpressionNode Capacity, TypeNode ElementType) : RValueExpressionNode(Type);
public record ArrayInitializerNode(TypeNode Type, ExpressionNode Capacity, TypeNode ElementType, IEnumerable<Token> Tokens) : RValueExpressionNode(Type, Tokens);
public record ArrayIndexAccessNode(TypeNode Type, ExpressionNode Target, ExpressionNode Index) : LValueExpressionNode(Type);
public record ArrayIndexAccessNode(TypeNode Type, ExpressionNode Target, ExpressionNode Index, IEnumerable<Token> Tokens) : LValueExpressionNode(Type, Tokens);
public record AddressOfNode(TypeNode Type, LValueExpressionNode LValue) : RValueExpressionNode(Type);
public record AddressOfNode(TypeNode Type, LValueExpressionNode LValue, IEnumerable<Token> Tokens) : RValueExpressionNode(Type, Tokens);
public record LiteralNode(TypeNode Type, string Value, LiteralKind Kind) : RValueExpressionNode(Type);
public record LiteralNode(TypeNode Type, string Value, LiteralKind Kind, IEnumerable<Token> Tokens) : RValueExpressionNode(Type, Tokens);
public record StructFieldAccessNode(TypeNode Type, StructTypeNode StructType, ExpressionNode Target, string Field) : LValueExpressionNode(Type);
public record StructFieldAccessNode(TypeNode Type, StructTypeNode StructType, ExpressionNode Target, string Field, IEnumerable<Token> Tokens) : LValueExpressionNode(Type, Tokens);
public record StructInitializerNode(StructTypeNode StructType, Dictionary<string, ExpressionNode> Initializers) : RValueExpressionNode(StructType);
public record StructInitializerNode(StructTypeNode StructType, Dictionary<string, ExpressionNode> Initializers, IEnumerable<Token> Tokens) : RValueExpressionNode(StructType, Tokens);
public record DereferenceNode(TypeNode Type, ExpressionNode Expression) : RValueExpressionNode(Type);
public record DereferenceNode(TypeNode Type, ExpressionNode Expression, IEnumerable<Token> Tokens) : RValueExpressionNode(Type, Tokens);
public record ConvertToInterfaceNode(TypeNode Type, InterfaceTypeNode InterfaceType, StructTypeNode StructType, ExpressionNode Implementation) : RValueExpressionNode(Type);
public record ConvertToInterfaceNode(TypeNode Type, InterfaceTypeNode InterfaceType, StructTypeNode StructType, ExpressionNode Implementation, IEnumerable<Token> Tokens) : RValueExpressionNode(Type, Tokens);
public record ConvertIntNode(TypeNode Type, ExpressionNode Value, IntTypeNode ValueType, IntTypeNode TargetType) : RValueExpressionNode(Type);
public record ConvertIntNode(TypeNode Type, ExpressionNode Value, IntTypeNode ValueType, IntTypeNode TargetType, IEnumerable<Token> Tokens) : RValueExpressionNode(Type, Tokens);
public record ConvertFloatNode(TypeNode Type, ExpressionNode Value, FloatTypeNode ValueType, FloatTypeNode TargetType, IEnumerable<Token> Tokens) : RValueExpressionNode(Type, Tokens);
public record ConvertFloatNode(TypeNode Type, ExpressionNode Value, FloatTypeNode ValueType, FloatTypeNode TargetType) : RValueExpressionNode(Type);

View File

@@ -1,9 +1,5 @@
using NubLang.Tokenization;
namespace NubLang.TypeChecking.Node;
namespace NubLang.TypeChecking.Node;
public abstract record Node;
public abstract record Node(IEnumerable<Token> Tokens);
public record TypedSyntaxTree(IReadOnlyList<DefinitionNode> Definitions);
public record BlockNode(IReadOnlyList<StatementNode> Statements, IEnumerable<Token> Tokens) : Node(Tokens);
public record BlockNode(IReadOnlyList<StatementNode> Statements) : Node;

View File

@@ -1,21 +1,19 @@
using NubLang.Tokenization;
namespace NubLang.TypeChecking.Node;
namespace NubLang.TypeChecking.Node;
public record StatementNode : Node;
public record StatementNode(IEnumerable<Token> Tokens) : Node(Tokens);
public record StatementExpressionNode(ExpressionNode Expression) : StatementNode;
public record StatementExpressionNode(ExpressionNode Expression, IEnumerable<Token> Tokens) : StatementNode(Tokens);
public record ReturnNode(Optional<ExpressionNode> Value) : StatementNode;
public record ReturnNode(Optional<ExpressionNode> Value, IEnumerable<Token> Tokens) : StatementNode(Tokens);
public record AssignmentNode(LValueExpressionNode Target, ExpressionNode Value) : StatementNode;
public record AssignmentNode(LValueExpressionNode Target, ExpressionNode Value, IEnumerable<Token> Tokens) : StatementNode(Tokens);
public record IfNode(ExpressionNode Condition, BlockNode Body, Optional<Variant<IfNode, BlockNode>> Else) : StatementNode;
public record IfNode(ExpressionNode Condition, BlockNode Body, Optional<Variant<IfNode, BlockNode>> Else, IEnumerable<Token> Tokens) : StatementNode(Tokens);
public record VariableDeclarationNode(string Name, Optional<ExpressionNode> Assignment, TypeNode Type) : StatementNode;
public record VariableDeclarationNode(string Name, Optional<ExpressionNode> Assignment, TypeNode Type, IEnumerable<Token> Tokens) : StatementNode(Tokens);
public record ContinueNode : StatementNode;
public record ContinueNode(IEnumerable<Token> Tokens) : StatementNode(Tokens);
public record BreakNode : StatementNode;
public record BreakNode(IEnumerable<Token> Tokens) : StatementNode(Tokens);
public record WhileNode(ExpressionNode Condition, BlockNode Body, IEnumerable<Token> Tokens) : StatementNode(Tokens);
public record WhileNode(ExpressionNode Condition, BlockNode Body) : StatementNode;

View File

@@ -174,25 +174,34 @@ public class StringTypeNode : ComplexTypeNode
public override int GetHashCode() => HashCode.Combine(typeof(StringTypeNode));
}
public class StructTypeNode(string name, IReadOnlyList<TypeNode> fields, IReadOnlyList<FuncTypeNode> functions, IReadOnlyList<InterfaceTypeNode> interfaceImplementations) : ComplexTypeNode
public class StructTypeField(string name, TypeNode type, bool hasDefaultValue)
{
public string Name { get; } = name;
public IReadOnlyList<TypeNode> Fields { get; set; } = fields;
public IReadOnlyList<FuncTypeNode> Functions { get; set; } = functions;
public TypeNode Type { get; } = type;
public bool HasDefaultValue { get; } = hasDefaultValue;
}
public class StructTypeNode(string module, string name, IReadOnlyList<StructTypeField> fields, IReadOnlyDictionary<string, FuncTypeNode> functions, IReadOnlyList<InterfaceTypeNode> interfaceImplementations) : ComplexTypeNode
{
public string Module { get; } = module;
public string Name { get; } = name;
public IReadOnlyList<StructTypeField> Fields { get; set; } = fields;
public IReadOnlyDictionary<string, FuncTypeNode> Functions { get; set; } = functions;
public IReadOnlyList<InterfaceTypeNode> InterfaceImplementations { get; set; } = interfaceImplementations;
public override string ToString() => Name;
public override bool Equals(TypeNode? other) => other is StructTypeNode structType && Name == structType.Name;
public override bool Equals(TypeNode? other) => other is StructTypeNode structType && Name == structType.Name && Module == structType.Module;
public override int GetHashCode() => HashCode.Combine(typeof(StructTypeNode), Name);
}
public class InterfaceTypeNode(string name, IReadOnlyList<FuncTypeNode> functions) : ComplexTypeNode
public class InterfaceTypeNode(string module, string name, IReadOnlyDictionary<string, FuncTypeNode> functions) : ComplexTypeNode
{
public string Module { get; } = module;
public string Name { get; } = name;
public IReadOnlyList<FuncTypeNode> Functions { get; set; } = functions;
public IReadOnlyDictionary<string, FuncTypeNode> Functions { get; set; } = functions;
public override string ToString() => Name;
public override bool Equals(TypeNode? other) => other is InterfaceTypeNode interfaceType && Name == interfaceType.Name;
public override bool Equals(TypeNode? other) => other is InterfaceTypeNode interfaceType && Name == interfaceType.Name && Module == interfaceType.Module;
public override int GetHashCode() => HashCode.Combine(typeof(InterfaceTypeNode), Name);
}

View File

@@ -7,33 +7,31 @@ namespace NubLang.TypeChecking;
public sealed class TypeChecker
{
private readonly SyntaxTree _syntaxTree;
private readonly DefinitionTable _definitionTable;
private readonly Module _currentModule;
private readonly IReadOnlyDictionary<string, ModuleSignature> _moduleSignatures;
private readonly Stack<Scope> _scopes = [];
private readonly Stack<TypeNode> _funcReturnTypes = [];
private readonly List<Diagnostic> _diagnostics = [];
private readonly Dictionary<string, TypeNode> _typeCache = new();
private Scope Scope => _scopes.Peek();
public TypeChecker(SyntaxTree syntaxTree, DefinitionTable definitionTable)
public TypeChecker(Module currentModule, IReadOnlyDictionary<string, ModuleSignature> moduleSignatures)
{
_syntaxTree = syntaxTree;
_definitionTable = definitionTable;
_currentModule = currentModule;
_moduleSignatures = moduleSignatures.Where(x => currentModule.Imports.Contains(x.Key) || _currentModule.Name == x.Key).ToDictionary();
}
public IReadOnlyList<Diagnostic> GetDiagnostics() => _diagnostics;
public TypedSyntaxTree Check()
public TypedModule CheckModule()
{
_diagnostics.Clear();
_funcReturnTypes.Clear();
_scopes.Clear();
var definitions = new List<DefinitionNode>();
foreach (var definition in _syntaxTree.Definitions)
foreach (var definition in _currentModule.Definitions)
{
try
{
@@ -45,7 +43,7 @@ public sealed class TypeChecker
}
}
return new TypedSyntaxTree(definitions);
return new TypedModule(_currentModule.Name, definitions);
}
private DefinitionNode CheckDefinition(DefinitionSyntax node)
@@ -54,7 +52,7 @@ public sealed class TypeChecker
{
ExternFuncSyntax definition => CheckExternFuncDefinition(definition),
InterfaceSyntax definition => CheckInterfaceDefinition(definition),
LocalFuncSyntax definition => CheckLocalFuncDefinition(definition),
FuncSyntax definition => CheckLocalFuncDefinition(definition),
StructSyntax definition => CheckStructDefinition(definition),
_ => throw new ArgumentOutOfRangeException(nameof(node))
};
@@ -62,102 +60,72 @@ public sealed class TypeChecker
private InterfaceNode CheckInterfaceDefinition(InterfaceSyntax node)
{
var functions = new List<InterfaceFuncNode>();
foreach (var function in node.Functions)
{
functions.Add(new InterfaceFuncNode(function.Name, CheckFuncSignature(function.Signature), function.Tokens));
}
return new InterfaceNode(node.Name, functions, node.Tokens);
throw new NotImplementedException();
}
private StructNode CheckStructDefinition(StructSyntax node)
{
var structFields = new List<StructFieldNode>();
var fields = new List<StructFieldNode>();
foreach (var field in node.Fields)
{
var value = Optional.Empty<ExpressionNode>();
if (field.Value.HasValue)
{
value = CheckExpression(field.Value.Value, CheckType(field.Type));
value = CheckExpression(field.Value.Value);
}
structFields.Add(new StructFieldNode(field.Index, field.Name, CheckType(field.Type), value, field.Tokens));
fields.Add(new StructFieldNode(field.Index, field.Name, ResolveType(field.Type), value));
}
var funcs = new List<StructFuncNode>();
foreach (var func in node.Functions)
var functions = new List<StructFuncNode>();
foreach (var function in node.Functions)
{
var scope = new Scope();
scope.Declare(new Identifier("this", GetStructType(node), IdentifierKind.FunctionParameter));
foreach (var parameter in func.Signature.Parameters)
// todo(nub31): Add this parameter
foreach (var parameter in function.Signature.Parameters)
{
scope.Declare(new Identifier(parameter.Name, CheckType(parameter.Type), IdentifierKind.FunctionParameter));
scope.Declare(new Identifier(parameter.Name, ResolveType(parameter.Type), IdentifierKind.FunctionParameter));
}
_funcReturnTypes.Push(CheckType(func.Signature.ReturnType));
var body = CheckBlock(func.Body, scope);
_funcReturnTypes.Push(ResolveType(function.Signature.ReturnType));
var body = CheckBlock(function.Body, scope);
_funcReturnTypes.Pop();
funcs.Add(new StructFuncNode(func.Name, CheckFuncSignature(func.Signature), body, func.Tokens));
functions.Add(new StructFuncNode(function.Name, CheckFuncSignature(function.Signature), body));
}
var interfaceImplementations = new List<InterfaceTypeNode>();
foreach (var interfaceImplementation in node.InterfaceImplementations)
{
var type = CheckType(interfaceImplementation);
var type = ResolveType(interfaceImplementation);
if (type is not InterfaceTypeNode interfaceType)
{
_diagnostics.Add(Diagnostic.Error("Interface implementation is not a custom type").Build());
continue;
}
var interfaceDefs = _definitionTable.LookupInterface(interfaceType.Name).ToArray();
if (interfaceDefs.Length == 0)
{
_diagnostics.Add(Diagnostic.Error($"Interface {interfaceType.Name} is not defined").Build());
continue;
}
if (interfaceDefs.Length > 1)
{
_diagnostics.Add(Diagnostic.Error($"Interface {interfaceType.Name} has multiple definitions").Build());
_diagnostics.Add(Diagnostic.Error($"Struct {node.Name} cannot implement non-struct type {interfaceImplementation}").At(interfaceImplementation).Build());
continue;
}
interfaceImplementations.Add(interfaceType);
}
return new StructNode(node.Name, GetStructType(node), structFields, funcs, interfaceImplementations, node.Tokens);
return new StructNode(node.Name, fields, functions, interfaceImplementations);
}
private ExternFuncNode CheckExternFuncDefinition(ExternFuncSyntax node)
{
return new ExternFuncNode(node.Name, node.CallName, CheckFuncSignature(node.Signature), node.Tokens);
return new ExternFuncNode(node.Name, node.CallName, CheckFuncSignature(node.Signature));
}
private LocalFuncNode CheckLocalFuncDefinition(LocalFuncSyntax node)
private LocalFuncNode CheckLocalFuncDefinition(FuncSyntax node)
{
var signature = CheckFuncSignature(node.Signature);
var scope = new Scope();
foreach (var parameter in signature.Parameters)
foreach (var parameter in node.Signature.Parameters)
{
scope.Declare(new Identifier(parameter.Name, parameter.Type, IdentifierKind.FunctionParameter));
scope.Declare(new Identifier(parameter.Name, ResolveType(parameter.Type), IdentifierKind.FunctionParameter));
}
_funcReturnTypes.Push(signature.ReturnType);
_funcReturnTypes.Push(ResolveType(node.Signature.ReturnType));
var body = CheckBlock(node.Body, scope);
_funcReturnTypes.Pop();
return new LocalFuncNode(node.Name, signature, body, node.Tokens);
return new LocalFuncNode(node.Name, CheckFuncSignature(node.Signature), body);
}
private StatementNode CheckStatement(StatementSyntax node)
@@ -165,8 +133,8 @@ public sealed class TypeChecker
return node switch
{
AssignmentSyntax statement => CheckAssignment(statement),
BreakSyntax => new BreakNode(node.Tokens),
ContinueSyntax => new ContinueNode(node.Tokens),
BreakSyntax => new BreakNode(),
ContinueSyntax => new ContinueNode(),
IfSyntax statement => CheckIf(statement),
ReturnSyntax statement => CheckReturn(statement),
StatementExpressionSyntax statement => CheckStatementExpression(statement),
@@ -178,30 +146,12 @@ public sealed class TypeChecker
private StatementNode CheckAssignment(AssignmentSyntax statement)
{
var target = CheckExpression(statement.Target);
if (target is not LValueExpressionNode targetLValue)
{
throw new TypeCheckerException(Diagnostic.Error("Cannot assign to rvalue").Build());
}
var value = CheckExpression(statement.Value, target.Type);
return new AssignmentNode(targetLValue, value, statement.Tokens);
throw new NotImplementedException();
}
private IfNode CheckIf(IfSyntax statement)
{
var elseStatement = Optional.Empty<Variant<IfNode, BlockNode>>();
if (statement.Else.HasValue)
{
elseStatement = statement.Else.Value.Match<Variant<IfNode, BlockNode>>
(
elseIf => CheckIf(elseIf),
@else => CheckBlock(@else)
);
}
return new IfNode(CheckExpression(statement.Condition, new BoolTypeNode()), CheckBlock(statement.Body), elseStatement, statement.Tokens);
throw new NotImplementedException();
}
private ReturnNode CheckReturn(ReturnSyntax statement)
@@ -213,58 +163,54 @@ public sealed class TypeChecker
value = CheckExpression(statement.Value.Value, _funcReturnTypes.Peek());
}
return new ReturnNode(value, statement.Tokens);
return new ReturnNode(value);
}
private StatementExpressionNode CheckStatementExpression(StatementExpressionSyntax statement)
{
return new StatementExpressionNode(CheckExpression(statement.Expression), statement.Tokens);
return new StatementExpressionNode(CheckExpression(statement.Expression));
}
private VariableDeclarationNode CheckVariableDeclaration(VariableDeclarationSyntax statement)
{
TypeNode? type = null;
ExpressionNode? assignmentNode = null;
if (statement.ExplicitType.HasValue)
if (statement.ExplicitType.TryGetValue(out var explicitType))
{
type = CheckType(statement.ExplicitType.Value);
type = ResolveType(explicitType);
}
var assignment = Optional<ExpressionNode>.Empty();
if (statement.Assignment.HasValue)
if (statement.Assignment.TryGetValue(out var assignment))
{
var boundValue = CheckExpression(statement.Assignment.Value, type);
assignment = boundValue;
if (type != null)
{
if (boundValue.Type != type)
{
throw new TypeCheckerException(Diagnostic.Error($"{boundValue.Type} is not assignable to {type}").Build());
}
}
else
{
if (type == null)
{
type = boundValue.Type;
}
}
assignmentNode = CheckExpression(assignment, type);
type ??= assignmentNode.Type;
}
if (type == null)
{
throw new TypeCheckerException(Diagnostic.Error($"Unknown type of variable {statement.Name}").Build());
throw new TypeCheckerException(Diagnostic.Error($"Cannot infer type of variable {statement.Name}").At(statement).Build());
}
Scope.Declare(new Identifier(statement.Name, type, IdentifierKind.Variable));
return new VariableDeclarationNode(statement.Name, assignment, type, statement.Tokens);
return new VariableDeclarationNode(statement.Name, Optional.OfNullable(assignmentNode), type);
}
private WhileNode CheckWhile(WhileSyntax statement)
{
return new WhileNode(CheckExpression(statement.Condition, new BoolTypeNode()), CheckBlock(statement.Body), statement.Tokens);
throw new NotImplementedException();
}
private FuncSignatureNode CheckFuncSignature(FuncSignatureSyntax statement)
{
var parameters = new List<FuncParameterNode>();
foreach (var parameter in statement.Parameters)
{
parameters.Add(new FuncParameterNode(parameter.Name, ResolveType(parameter.Type)));
}
return new FuncSignatureNode(parameters, ResolveType(statement.ReturnType));
}
private ExpressionNode CheckExpression(ExpressionSyntax node, TypeNode? expectedType = null)
@@ -293,14 +239,14 @@ public sealed class TypeChecker
if (result.Type is StructTypeNode structType && expectedType is InterfaceTypeNode interfaceType)
{
return new ConvertToInterfaceNode(interfaceType, interfaceType, structType, result, node.Tokens);
return new ConvertToInterfaceNode(interfaceType, interfaceType, structType, result);
}
if (result.Type is IntTypeNode sourceIntType && expectedType is IntTypeNode targetIntType)
{
if (sourceIntType.Signed == targetIntType.Signed && sourceIntType.Width < targetIntType.Width)
{
return new ConvertIntNode(targetIntType, result, sourceIntType, targetIntType, node.Tokens);
return new ConvertIntNode(targetIntType, result, sourceIntType, targetIntType);
}
}
@@ -308,245 +254,115 @@ public sealed class TypeChecker
{
if (sourceFloatType.Width < targetFloatType.Width)
{
return new ConvertFloatNode(targetFloatType, result, sourceFloatType, targetFloatType, node.Tokens);
return new ConvertFloatNode(targetFloatType, result, sourceFloatType, targetFloatType);
}
}
throw new TypeCheckerException(Diagnostic.Error($"Cannot convert {result.Type} to {expectedType}").Build());
throw new TypeCheckerException(Diagnostic.Error($"Cannot convert {result.Type} to {expectedType}").At(node).Build());
}
private AddressOfNode CheckAddressOf(AddressOfSyntax expression)
{
var inner = CheckExpression(expression.Expression);
if (inner is not LValueExpressionNode lValueInner)
{
throw new TypeCheckerException(Diagnostic.Error("Cannot take address of rvalue").Build());
}
return new AddressOfNode(new PointerTypeNode(inner.Type), lValueInner, expression.Tokens);
throw new NotImplementedException();
}
private ArrayIndexAccessNode CheckArrayIndexAccess(ArrayIndexAccessSyntax expression)
{
var boundArray = CheckExpression(expression.Target);
var elementType = ((ArrayTypeNode)boundArray.Type).ElementType;
return new ArrayIndexAccessNode(elementType, boundArray, CheckExpression(expression.Index, new IntTypeNode(false, 64)), expression.Tokens);
throw new NotImplementedException();
}
private ArrayInitializerNode CheckArrayInitializer(ArrayInitializerSyntax expression)
{
var capacity = CheckExpression(expression.Capacity, new IntTypeNode(false, 64));
var type = new ArrayTypeNode(CheckType(expression.ElementType));
return new ArrayInitializerNode(type, capacity, CheckType(expression.ElementType), expression.Tokens);
throw new NotImplementedException();
}
private BinaryExpressionNode CheckBinaryExpression(BinaryExpressionSyntax expression)
{
var boundLeft = CheckExpression(expression.Left);
var boundRight = CheckExpression(expression.Right, boundLeft.Type);
var op = expression.Operator switch
{
BinaryOperatorSyntax.Equal => BinaryOperator.Equal,
BinaryOperatorSyntax.NotEqual => BinaryOperator.NotEqual,
BinaryOperatorSyntax.GreaterThan => BinaryOperator.GreaterThan,
BinaryOperatorSyntax.GreaterThanOrEqual => BinaryOperator.GreaterThanOrEqual,
BinaryOperatorSyntax.LessThan => BinaryOperator.LessThan,
BinaryOperatorSyntax.LessThanOrEqual => BinaryOperator.LessThanOrEqual,
BinaryOperatorSyntax.Plus => BinaryOperator.Plus,
BinaryOperatorSyntax.Minus => BinaryOperator.Minus,
BinaryOperatorSyntax.Multiply => BinaryOperator.Multiply,
BinaryOperatorSyntax.Divide => BinaryOperator.Divide,
BinaryOperatorSyntax.Modulo => BinaryOperator.Modulo,
BinaryOperatorSyntax.LeftShift => BinaryOperator.LeftShift,
BinaryOperatorSyntax.RightShift => BinaryOperator.RightShift,
BinaryOperatorSyntax.BitwiseAnd => BinaryOperator.BitwiseAnd,
BinaryOperatorSyntax.BitwiseXor => BinaryOperator.BitwiseXor,
BinaryOperatorSyntax.BitwiseOr => BinaryOperator.BitwiseOr,
BinaryOperatorSyntax.LogicalAnd => BinaryOperator.LogicalAnd,
BinaryOperatorSyntax.LogicalOr => BinaryOperator.LogicalOr,
_ => throw new ArgumentOutOfRangeException(nameof(expression.Operator), expression.Operator, null)
};
var resultingType = op switch
{
BinaryOperator.Equal => new BoolTypeNode(),
BinaryOperator.NotEqual => new BoolTypeNode(),
BinaryOperator.GreaterThan => new BoolTypeNode(),
BinaryOperator.GreaterThanOrEqual => new BoolTypeNode(),
BinaryOperator.LessThan => new BoolTypeNode(),
BinaryOperator.LessThanOrEqual => new BoolTypeNode(),
BinaryOperator.LogicalAnd => new BoolTypeNode(),
BinaryOperator.LogicalOr => new BoolTypeNode(),
BinaryOperator.Plus => boundLeft.Type,
BinaryOperator.Minus => boundLeft.Type,
BinaryOperator.Multiply => boundLeft.Type,
BinaryOperator.Divide => boundLeft.Type,
BinaryOperator.Modulo => boundLeft.Type,
BinaryOperator.LeftShift => boundLeft.Type,
BinaryOperator.RightShift => boundLeft.Type,
BinaryOperator.BitwiseAnd => boundLeft.Type,
BinaryOperator.BitwiseXor => boundLeft.Type,
BinaryOperator.BitwiseOr => boundLeft.Type,
_ => throw new ArgumentOutOfRangeException()
};
return new BinaryExpressionNode(resultingType, boundLeft, op, boundRight, expression.Tokens);
throw new NotImplementedException();
}
private DereferenceNode CheckDereference(DereferenceSyntax expression)
{
var boundExpression = CheckExpression(expression.Expression);
var dereferencedType = ((PointerTypeNode)boundExpression.Type).BaseType;
return new DereferenceNode(dereferencedType, boundExpression, expression.Tokens);
throw new NotImplementedException();
}
private FuncCallNode CheckFuncCall(FuncCallSyntax expression)
{
var boundExpression = CheckExpression(expression.Expression);
if (boundExpression.Type is not FuncTypeNode funcType)
var accessor = CheckExpression(expression.Expression);
if (accessor.Type is not FuncTypeNode funcType)
{
throw new TypeCheckerException(Diagnostic.Error($"Cannot call non-function type {boundExpression.Type}").Build());
throw new TypeCheckerException(Diagnostic.Error($"Cannot call non-function type {accessor.Type}").At(expression.Expression).Build());
}
if (expression.Parameters.Count != funcType.Parameters.Count)
{
throw new TypeCheckerException(Diagnostic.Error($"Function {funcType} expects {funcType.Parameters} but got {expression.Parameters.Count} parameters").At(expression.Expression).Build());
}
var parameters = new List<ExpressionNode>();
foreach (var (i, parameter) in expression.Parameters.Index())
for (var i = 0; i < expression.Parameters.Count; i++)
{
if (i >= funcType.Parameters.Count)
{
_diagnostics.Add(Diagnostic.Error($"Expected {funcType.Parameters.Count} parameters").Build());
}
var parameter = expression.Parameters[i];
var expectedType = funcType.Parameters[i];
parameters.Add(CheckExpression(parameter, expectedType));
var parameterExpression = CheckExpression(parameter, expectedType);
if (parameterExpression.Type != expectedType)
{
throw new Exception($"Parameter {i + 1} does not match the type {expectedType} for function {funcType}");
}
return new FuncCallNode(funcType.ReturnType, boundExpression, parameters, expression.Tokens);
parameters.Add(parameterExpression);
}
return new FuncCallNode(funcType.ReturnType, accessor, parameters);
}
private ExpressionNode CheckDotFuncCall(DotFuncCallSyntax expression)
{
var thisParameter = CheckExpression(expression.ThisParameter);
if (thisParameter.Type is InterfaceTypeNode interfaceType)
{
var interfaceDefinitions = _definitionTable.LookupInterface(interfaceType.Name).ToArray();
if (interfaceDefinitions.Length == 0)
{
throw new TypeCheckerException(Diagnostic.Error($"Interface {interfaceType.Name} is not defined").Build());
}
if (interfaceDefinitions.Length > 1)
{
throw new TypeCheckerException(Diagnostic.Error($"Interface {interfaceType.Name} has multiple definitions").Build());
}
var function = interfaceDefinitions[0].Functions.FirstOrDefault(x => x.Name == expression.Name);
if (function == null)
{
throw new TypeCheckerException(Diagnostic.Error($"Interface {interfaceType.Name} does not have a function with the name {expression.Name}").Build());
}
var parameters = new List<ExpressionNode>();
for (var i = 0; i < expression.Parameters.Count; i++)
{
var parameter = expression.Parameters[i];
var expectedType = CheckType(function.Signature.Parameters[i].Type);
parameters.Add(CheckExpression(parameter, expectedType));
}
var returnType = CheckType(function.Signature.ReturnType);
return new InterfaceFuncCallNode(returnType, expression.Name, interfaceType, thisParameter, parameters, expression.Tokens);
}
if (thisParameter.Type is StructTypeNode structType)
{
var structDefinitions = _definitionTable.LookupStruct(structType.Name).ToArray();
if (structDefinitions.Length == 0)
{
throw new TypeCheckerException(Diagnostic.Error($"Struct {structType.Name} is not defined").Build());
}
if (structDefinitions.Length > 1)
{
throw new TypeCheckerException(Diagnostic.Error($"Struct {structType.Name} has multiple definitions").Build());
}
var function = structDefinitions[0].Functions.FirstOrDefault(x => x.Name == expression.Name);
if (function == null)
{
throw new TypeCheckerException(Diagnostic.Error($"Struct {structType.Name} does not have a function with the name {expression.Name}").Build());
}
var parameters = new List<ExpressionNode>();
for (var i = 0; i < expression.Parameters.Count; i++)
{
var parameter = expression.Parameters[i];
var expectedType = CheckType(function.Signature.Parameters[i].Type);
parameters.Add(CheckExpression(parameter, expectedType));
}
var returnType = CheckType(function.Signature.ReturnType);
return new StructFuncCallNode(returnType, expression.Name, structType, thisParameter, parameters, expression.Tokens);
}
throw new TypeCheckerException(Diagnostic.Error($"Cannot call dot function on type {thisParameter.Type}").Build());
throw new NotImplementedException();
}
private ExpressionNode CheckIdentifier(IdentifierSyntax expression)
{
var identifier = Scope.Lookup(expression.Name);
if (identifier != null)
// If the identifier does not have a module specified, first check if a local variable or function parameter with that identifier exists
if (!expression.Module.TryGetValue(out var moduleName))
{
return identifier.Kind switch
var scopeIdent = Scope.Lookup(expression.Name);
if (scopeIdent != null)
{
IdentifierKind.Variable => new VariableIdentNode(identifier.Type, identifier.Name, expression.Tokens),
IdentifierKind.FunctionParameter => new FuncParameterIdentNode(identifier.Type, identifier.Name, expression.Tokens),
_ => throw new ArgumentOutOfRangeException()
};
switch (scopeIdent.Kind)
{
case IdentifierKind.Variable:
{
return new VariableIdentifierNode(scopeIdent.Type, expression.Name);
}
case IdentifierKind.FunctionParameter:
{
return new FuncParameterIdentifierNode(scopeIdent.Type, expression.Name);
}
default:
{
throw new ArgumentOutOfRangeException();
}
}
}
}
var localFuncs = _definitionTable.LookupLocalFunc(expression.Name).ToArray();
if (localFuncs.Length > 0)
moduleName ??= _currentModule.Name;
if (_moduleSignatures.TryGetValue(moduleName, out var module))
{
if (localFuncs.Length > 1)
if (module.Functions.TryGetValue(expression.Name, out var function))
{
throw new TypeCheckerException(Diagnostic.Error($"Extern func {expression.Name} has multiple definitions").Build());
return new FuncIdentifierNode(function, moduleName, expression.Name);
}
}
var localFunc = localFuncs[0];
var returnType = CheckType(localFunc.Signature.ReturnType);
var parameterTypes = localFunc.Signature.Parameters.Select(p => CheckType(p.Type)).ToList();
var type = new FuncTypeNode(parameterTypes, returnType);
return new LocalFuncIdentNode(type, expression.Name, expression.Tokens);
}
var externFuncs = _definitionTable.LookupExternFunc(expression.Name).ToArray();
if (externFuncs.Length > 0)
{
if (externFuncs.Length > 1)
{
throw new TypeCheckerException(Diagnostic.Error($"Extern func {expression.Name} has multiple definitions").Build());
}
var externFunc = externFuncs[0];
var returnType = CheckType(externFunc.Signature.ReturnType);
var parameterTypes = externFunc.Signature.Parameters.Select(p => CheckType(p.Type)).ToList();
var type = new FuncTypeNode(parameterTypes, returnType);
return new ExternFuncIdentNode(type, expression.Name, expression.Tokens);
}
throw new TypeCheckerException(Diagnostic.Error($"No identifier with the name {expression.Name} exists").Build());
throw new TypeCheckerException(Diagnostic.Error($"Identifier {expression.Name} not found").At(expression).Build());
}
private LiteralNode CheckLiteral(LiteralSyntax expression, TypeNode? expectedType)
{
// todo(nub31): Check if the types can actually be represented as another one. For example, an int should be passed when a string is expected
var type = expectedType ?? expression.Kind switch
{
LiteralKind.Integer => new IntTypeNode(true, 64),
@@ -556,150 +372,22 @@ public sealed class TypeChecker
_ => throw new ArgumentOutOfRangeException()
};
return new LiteralNode(type, expression.Value, expression.Kind, expression.Tokens);
return new LiteralNode(type, expression.Value, expression.Kind);
}
private StructFieldAccessNode CheckStructFieldAccess(StructFieldAccessSyntax expression)
{
var boundExpression = CheckExpression(expression.Target);
if (boundExpression.Type is not StructTypeNode structType)
{
throw new Exception($"Cannot access struct field on non-struct type {boundExpression.Type}");
}
var structs = _definitionTable.LookupStruct(structType.Name).ToArray();
if (structs.Length > 0)
{
if (structs.Length > 1)
{
throw new TypeCheckerException(Diagnostic.Error($"Struct {structType} has multiple definitions").Build());
}
var fields = _definitionTable.LookupStructField(structs[0], expression.Member).ToArray();
if (fields.Length > 0)
{
if (fields.Length > 1)
{
throw new TypeCheckerException(Diagnostic.Error($"Struct {structType} has multiple fields with the name {expression.Member}").Build());
}
var field = fields[0];
return new StructFieldAccessNode(CheckType(field.Type), structType, boundExpression, expression.Member, expression.Tokens);
}
}
throw new TypeCheckerException(Diagnostic.Error($"{boundExpression.Type} does not have a member with the name {expression.Member}").Build());
throw new NotImplementedException();
}
private StructInitializerNode CheckStructInitializer(StructInitializerSyntax expression, TypeNode? expectedType)
{
var type = expectedType;
if (expression.StructType.HasValue)
{
type = CheckType(expression.StructType.Value);
}
if (type == null)
{
throw new TypeCheckerException(Diagnostic.Error("Cannot determine type of struct").Build());
}
if (type is not StructTypeNode structType)
{
throw new TypeCheckerException(Diagnostic.Error($"Cannot initialize non-struct type {expression.StructType}").Build());
}
var structs = _definitionTable.LookupStruct(structType.Name).ToArray();
if (structs.Length == 0)
{
throw new TypeCheckerException(Diagnostic.Error($"Struct {structType} is not defined").Build());
}
if (structs.Length > 1)
{
throw new TypeCheckerException(Diagnostic.Error($"Struct {structType} has multiple definitions").Build());
}
var @struct = structs[0];
var initializers = new Dictionary<string, ExpressionNode>();
foreach (var (field, initializer) in expression.Initializers)
{
var fields = _definitionTable.LookupStructField(@struct, field).ToArray();
if (fields.Length == 0)
{
throw new TypeCheckerException(Diagnostic.Error($"Struct {structType} does not have a field with the name {field}").Build());
}
if (fields.Length > 1)
{
throw new TypeCheckerException(Diagnostic.Error($"Struct {structType} has multiple fields with the name {field}").Build());
}
initializers[field] = CheckExpression(initializer, CheckType(fields[0].Type));
}
return new StructInitializerNode(structType, initializers, expression.Tokens);
throw new NotImplementedException();
}
private UnaryExpressionNode CheckUnaryExpression(UnaryExpressionSyntax expression)
{
var boundOperand = CheckExpression(expression.Operand);
TypeNode? type = null;
switch (expression.Operator)
{
case UnaryOperatorSyntax.Negate:
{
boundOperand = CheckExpression(expression.Operand, new IntTypeNode(true, 64));
if (boundOperand.Type is IntTypeNode or FloatTypeNode)
{
type = boundOperand.Type;
}
break;
}
case UnaryOperatorSyntax.Invert:
{
boundOperand = CheckExpression(expression.Operand, new BoolTypeNode());
type = new BoolTypeNode();
break;
}
}
if (type == null)
{
throw new TypeCheckerException(Diagnostic.Error($"Cannot perform unary operation {expression.Operand} on type {boundOperand.Type}").Build());
}
var op = expression.Operator switch
{
UnaryOperatorSyntax.Negate => UnaryOperator.Negate,
UnaryOperatorSyntax.Invert => UnaryOperator.Invert,
_ => throw new ArgumentOutOfRangeException(nameof(expression.Operator), expression.Operator, null)
};
return new UnaryExpressionNode(type, op, boundOperand, expression.Tokens);
}
private FuncSignatureNode CheckFuncSignature(FuncSignatureSyntax node)
{
var parameters = new List<FuncParameterNode>();
foreach (var parameter in node.Parameters)
{
parameters.Add(new FuncParameterNode(parameter.Name, CheckType(parameter.Type), parameter.Tokens));
}
return new FuncSignatureNode(parameters, CheckType(node.ReturnType), node.Tokens);
throw new NotImplementedException();
}
private BlockNode CheckBlock(BlockSyntax node, Scope? scope = null)
@@ -715,105 +403,12 @@ public sealed class TypeChecker
_scopes.Pop();
return new BlockNode(statements, node.Tokens);
return new BlockNode(statements);
}
private TypeNode CheckType(TypeSyntax node)
private TypeNode ResolveType(TypeSyntax fieldType)
{
return node switch
{
ArrayTypeSyntax type => new ArrayTypeNode(CheckType(type.BaseType)),
BoolTypeSyntax => new BoolTypeNode(),
CStringTypeSyntax => new CStringTypeNode(),
CustomTypeSyntax type => CheckCustomType(type),
FloatTypeSyntax @float => new FloatTypeNode(@float.Width),
FuncTypeSyntax type => new FuncTypeNode(type.Parameters.Select(CheckType).ToList(), CheckType(type.ReturnType)),
IntTypeSyntax @int => new IntTypeNode(@int.Signed, @int.Width),
PointerTypeSyntax type => new PointerTypeNode(CheckType(type.BaseType)),
StringTypeSyntax => new StringTypeNode(),
VoidTypeSyntax => new VoidTypeNode(),
_ => throw new ArgumentOutOfRangeException(nameof(node))
};
}
private TypeNode CheckCustomType(CustomTypeSyntax type)
{
var structs = _definitionTable.LookupStruct(type.Name).ToArray();
if (structs.Length > 0)
{
if (structs.Length > 1)
{
throw new TypeCheckerException(Diagnostic.Error($"Struct {type.Name} has multiple definitions").Build());
}
return GetStructType(structs[0]);
}
var interfaces = _definitionTable.LookupInterface(type.Name).ToArray();
if (interfaces.Length > 0)
{
if (interfaces.Length > 1)
{
throw new TypeCheckerException(Diagnostic.Error($"Interface {type.Name} has multiple definitions").Build());
}
return GetInterfaceType(interfaces[0]);
}
throw new TypeCheckerException(Diagnostic.Error($"Type {type.Name} is not defined").Build());
}
private StructTypeNode GetStructType(StructSyntax structDef)
{
if (_typeCache.TryGetValue(structDef.Name, out var cachedType))
{
return (StructTypeNode)cachedType;
}
var result = new StructTypeNode(structDef.Name, [], [], []);
_typeCache.Add(structDef.Name, result);
var fields = structDef.Fields.Select(x => CheckType(x.Type)).ToList();
var funcs = structDef.Functions
.Select(x => new FuncTypeNode(x.Signature.Parameters.Select(p => CheckType(p.Type)).ToList(), CheckType(x.Signature.ReturnType)))
.ToList();
var interfaceImplementations = new List<InterfaceTypeNode>();
foreach (var structInterfaceImplementation in structDef.InterfaceImplementations)
{
var checkedInterfaceType = CheckType(structInterfaceImplementation);
if (checkedInterfaceType is not InterfaceTypeNode interfaceType)
{
throw new TypeCheckerException(Diagnostic.Error($"{structDef.Name} cannot implement non-interface type {checkedInterfaceType}").Build());
}
interfaceImplementations.Add(interfaceType);
}
result.Fields = fields;
result.Functions = funcs;
result.InterfaceImplementations = interfaceImplementations;
return result;
}
private InterfaceTypeNode GetInterfaceType(InterfaceSyntax interfaceDef)
{
if (_typeCache.TryGetValue(interfaceDef.Name, out var cachedType))
{
return (InterfaceTypeNode)cachedType;
}
var result = new InterfaceTypeNode(interfaceDef.Name, []);
_typeCache.Add(interfaceDef.Name, result);
var functions = interfaceDef.Functions
.Select(x => new FuncTypeNode(x.Signature.Parameters.Select(y => CheckType(y.Type)).ToList(), CheckType(x.Signature.ReturnType)))
.ToList();
result.Functions = functions;
return result;
return TypeResolver.ResolveType(fieldType, _moduleSignatures);
}
}

View File

@@ -0,0 +1,48 @@
using NubLang.Parsing.Syntax;
using NubLang.TypeChecking;
using NubLang.TypeChecking.Node;
namespace NubLang;
public static class TypeResolver
{
public static TypeNode ResolveType(TypeSyntax type, IReadOnlyDictionary<string, ModuleSignature> modules)
{
return type switch
{
BoolTypeSyntax => new BoolTypeNode(),
CStringTypeSyntax => new CStringTypeNode(),
IntTypeSyntax i => new IntTypeNode(i.Signed, i.Width),
CustomTypeSyntax c => ResolveCustomType(c.Module, c.Name, modules),
FloatTypeSyntax f => new FloatTypeNode(f.Width),
FuncTypeSyntax func => new FuncTypeNode(func.Parameters.Select(x => ResolveType(x, modules)).ToList(), ResolveType(func.ReturnType, modules)),
ArrayTypeSyntax arr => new ArrayTypeNode(ResolveType(arr.BaseType, modules)),
PointerTypeSyntax ptr => new PointerTypeNode(ResolveType(ptr.BaseType, modules)),
StringTypeSyntax => new StringTypeNode(),
VoidTypeSyntax => new VoidTypeNode(),
_ => throw new NotSupportedException($"Unknown type syntax: {type}")
};
}
public static TypeNode ResolveCustomType(string moduleName, string typeName, IReadOnlyDictionary<string, ModuleSignature> modules)
{
if (!modules.TryGetValue(moduleName, out var module))
{
throw new Exception("Module not found: " + moduleName);
}
var structType = module.StructTypes.FirstOrDefault(x => x.Name == typeName);
if (structType != null)
{
return structType;
}
var interfaceType = module.InterfaceTypes.FirstOrDefault(x => x.Name == typeName);
if (interfaceType != null)
{
return interfaceType;
}
throw new Exception("Type not found: " + typeName);
}
}