diff --git a/example/.gitignore b/example/.gitignore index 54f46fc..145bd1e 100644 --- a/example/.gitignore +++ b/example/.gitignore @@ -1,2 +1,2 @@ -build +.build out \ No newline at end of file diff --git a/example/src/main.nub b/example/src/main.nub index 1195e4f..0372339 100644 --- a/example/src/main.nub +++ b/example/src/main.nub @@ -1,34 +1,15 @@ -// c +module main + extern func puts(text: cstring) -struct Name +struct Test { - first: cstring - last: cstring -} - -struct Human -{ - name: Name - age: cstring + } func main(args: []cstring): i64 { - let x: Human = { - name = { - first = "bob" - last = "the builder" - } - age = "23" - } - - puts(x.age) + puts("test") return 0 } - -// func test(human: ^Human) -// { -// puts(human^.name.last) -// } \ No newline at end of file diff --git a/src/compiler/NubLang.CLI/Program.cs b/src/compiler/NubLang.CLI/Program.cs index 13a1fdc..98406b2 100644 --- a/src/compiler/NubLang.CLI/Program.cs +++ b/src/compiler/NubLang.CLI/Program.cs @@ -1,13 +1,11 @@ using NubLang.CLI; using NubLang.Code; using NubLang.Diagnostics; -using NubLang.Generation; using NubLang.Generation.QBE; using NubLang.Parsing; using NubLang.Parsing.Syntax; using NubLang.Tokenization; using NubLang.TypeChecking; -using NubLang.TypeChecking.Node; var options = new Options(); @@ -35,9 +33,6 @@ for (var i = 0; i < args.Length; i++) } } -var diagnostics = new List(); -var syntaxTrees = new List(); - foreach (var file in options.Files) { if (!File.Exists(file.Path)) @@ -47,6 +42,9 @@ foreach (var file in options.Files) } } +var diagnostics = new List(); + +var syntaxTrees = new List(); foreach (var file in options.Files) { var tokenizer = new Tokenizer(file); @@ -61,16 +59,17 @@ foreach (var file in options.Files) syntaxTrees.Add(syntaxTree); } -var definitionTable = new DefinitionTable(syntaxTrees); +var moduleSignatures = ModuleSignature.CollectFromSyntaxTrees(syntaxTrees); +var modules = Module.CollectFromSyntaxTrees(syntaxTrees); -var typedSyntaxTrees = new List(); +var typedModules = new List(); -foreach (var syntaxTree in syntaxTrees) +foreach (var module in modules) { - var typeChecker = new TypeChecker(syntaxTree, definitionTable); - var typedSyntaxTree = typeChecker.Check(); + var typeChecker = new TypeChecker(module, moduleSignatures); + var typedModule = typeChecker.CheckModule(); diagnostics.AddRange(typeChecker.GetDiagnostics()); - typedSyntaxTrees.Add(typedSyntaxTree); + typedModules.Add(typedModule); } foreach (var diagnostic in diagnostics) @@ -83,13 +82,11 @@ if (diagnostics.Any(diagnostic => diagnostic.Severity == DiagnosticSeverity.Erro return 1; } -var typedDefinitionTable = new TypedDefinitionTable(typedSyntaxTrees); - var objectFiles = new List(); -for (var i = 0; i < typedSyntaxTrees.Count; i++) +for (var i = 0; i < typedModules.Count; i++) { - var syntaxTree = typedSyntaxTrees[i]; + var typedModule = typedModules[i]; var outFileName = Path.Combine(".build", "code", Path.ChangeExtension(options.Files[i].Path, null)); var outFileDir = Path.GetDirectoryName(outFileName); @@ -98,7 +95,7 @@ for (var i = 0; i < typedSyntaxTrees.Count; i++) Directory.CreateDirectory(outFileDir); } - var generator = new QBEGenerator(syntaxTree, typedDefinitionTable, options.Files[i].Path); + var generator = new QBEGenerator(typedModule, moduleSignatures); var ssa = generator.Emit(); var ssaFilePath = Path.ChangeExtension(outFileName, "ssa"); diff --git a/src/compiler/NubLang/Generation/QBE/QBEGenerator.cs b/src/compiler/NubLang/Generation/QBE/QBEGenerator.cs index 3f1f580..a93754d 100644 --- a/src/compiler/NubLang/Generation/QBE/QBEGenerator.cs +++ b/src/compiler/NubLang/Generation/QBE/QBEGenerator.cs @@ -2,16 +2,16 @@ using System.Globalization; using System.Text; using NubLang.Tokenization; +using NubLang.TypeChecking; using NubLang.TypeChecking.Node; namespace NubLang.Generation.QBE; public class QBEGenerator { - private readonly TypedSyntaxTree _syntaxTree; - private readonly TypedDefinitionTable _definitionTable; - private readonly string _sourceFileName; private readonly QBEWriter _writer; + private readonly TypedModule _module; + private readonly IReadOnlyList _moduleSignatures; private readonly List _cStringLiterals = []; private readonly List _stringLiterals = []; @@ -23,11 +23,10 @@ public class QBEGenerator private int _stringLiteralIndex; private bool _codeIsReachable = true; - public QBEGenerator(TypedSyntaxTree syntaxTree, TypedDefinitionTable definitionTable, string sourceFileName) + public QBEGenerator(TypedModule module, IReadOnlyList moduleSignatures) { - _syntaxTree = syntaxTree; - _definitionTable = definitionTable; - _sourceFileName = sourceFileName; + _module = module; + _moduleSignatures = moduleSignatures; _writer = new QBEWriter(); } @@ -43,41 +42,42 @@ public class QBEGenerator _stringLiteralIndex = 0; _codeIsReachable = true; - _writer.WriteLine($"dbgfile \"{_sourceFileName}\""); - - foreach (var structDef in _definitionTable.GetStructs()) + foreach (var moduleSignature in _moduleSignatures) { - EmitStructTypeDefinition(structDef); - _writer.NewLine(); + foreach (var structType in moduleSignature.Symbols.Values.OfType()) + { + EmitStructType(moduleSignature.Name, structType); + _writer.NewLine(); + } } - foreach (var structDef in _syntaxTree.Definitions.OfType()) + foreach (var structDef in _module.Definitions.OfType()) { EmitStructDefinition(structDef); _writer.NewLine(); } - foreach (var funcDef in _syntaxTree.Definitions.OfType()) + foreach (var funcDef in _module.Definitions.OfType()) { EmitLocalFuncDefinition(funcDef); _writer.NewLine(); } - foreach (var structDef in _syntaxTree.Definitions.OfType().Where(x => x.InterfaceImplementations.Count > 0)) - { - _writer.Write($"data {StructVtableName(structDef.Name)} = {{ "); - - foreach (var interfaceImplementation in structDef.InterfaceImplementations) - { - var interfaceDef = _definitionTable.LookupInterface(interfaceImplementation.Name); - foreach (var func in interfaceDef.Functions) - { - _writer.Write($"l {StructFuncName(structDef.Name, func.Name)}, "); - } - } - - _writer.WriteLine("}"); - } + // foreach (var structDef in _module.Definitions.OfType().Where(x => x.InterfaceImplementations.Count > 0)) + // { + // _writer.Write($"data {StructVtableName(_module.Name, structDef.Name)} = {{ "); + // + // foreach (var interfaceImplementation in structDef.InterfaceImplementations) + // { + // var interfaceDef = _definitionTable.LookupInterface(interfaceImplementation.Name); + // foreach (var func in interfaceDef.Functions) + // { + // _writer.Write($"l {StructFuncName(_module.Name, structDef.Name, func.Name)}, "); + // } + // } + // + // _writer.WriteLine("}"); + // } foreach (var cStringLiteral in _cStringLiterals) { @@ -366,7 +366,7 @@ public class QBEGenerator if (complexType is StructTypeNode structType) { - return StructTypeName(structType.Name); + return StructTypeName(structType.Module, structType.Name); } return "l"; @@ -384,7 +384,7 @@ public class QBEGenerator _writer.Write(FuncQBETypeName(funcDef.Signature.ReturnType) + ' '); } - _writer.Write(LocalFuncName(funcDef)); + _writer.Write(LocalFuncName(_module.Name, funcDef)); _writer.Write("("); foreach (var parameter in funcDef.Signature.Parameters) @@ -408,12 +408,19 @@ public class QBEGenerator private void EmitStructDefinition(StructNode structDef) { + _writer.WriteLine($"export function {StructCtorName(_module.Name, structDef.Name)}() {{"); + _writer.WriteLine("@start"); + _writer.Indented($"%struct =l alloc8 {SizeOf(structDef.)}"); + _writer.Indented("ret %struct"); + _writer.WriteLine("}"); + for (var i = 0; i < structDef.Functions.Count; i++) { var function = structDef.Functions[i]; _labelIndex = 0; _tmpIndex = 0; + _writer.NewLine(); _writer.Write("export function "); if (function.Signature.ReturnType is not VoidTypeNode) @@ -421,7 +428,7 @@ public class QBEGenerator _writer.Write(FuncQBETypeName(function.Signature.ReturnType) + ' '); } - _writer.Write(StructFuncName(structDef.Name, function.Name)); + _writer.Write(StructFuncName(_module.Name, structDef.Name, function.Name)); _writer.Write("(l %this, "); foreach (var parameter in function.Signature.Parameters) @@ -441,38 +448,24 @@ public class QBEGenerator } _writer.WriteLine("}"); - - if (i != structDef.Functions.Count - 1) - { - _writer.NewLine(); - } } } - private void EmitStructTypeDefinition(StructNode structDef) + private void EmitStructType(string module, StructTypeNode structType) { - _writer.WriteLine($"type {StructTypeName(structDef.Name)} = {{ "); + _writer.WriteLine($"type {StructTypeName(module, structType.Name)} = {{ "); - var types = new Dictionary(); - - foreach (var field in structDef.Fields) + foreach (var field in structType.Fields) { - types.Add(field.Name, StructDefQBEType(field)); - } - - var longest = types.Values.Max(x => x.Length); - foreach (var (name, type) in types) - { - var padding = longest - type.Length; - _writer.Indented($"{type},{new string(' ', padding)} # {name}"); + _writer.Indented($"{StructDefQBEType(field)},"); } _writer.WriteLine("}"); return; - string StructDefQBEType(StructFieldNode field) + string StructDefQBEType(TypeNode type) { - if (field.Type.IsSimpleType(out var simpleType, out var complexType)) + if (type.IsSimpleType(out var simpleType, out var complexType)) { return simpleType.StorageSize switch { @@ -488,7 +481,7 @@ public class QBEGenerator if (complexType is StructTypeNode structType) { - return StructTypeName(structType.Name); + return StructTypeName(structType.Module, structType.Name); } return "l"; @@ -649,10 +642,7 @@ public class QBEGenerator ConvertToInterfaceNode convertToInterface => EmitConvertToInterface(convertToInterface), ConvertIntNode convertInt => EmitConvertInt(convertInt), ConvertFloatNode convertFloat => EmitConvertFloat(convertFloat), - ExternFuncIdentNode externFuncIdent => EmitExternFuncIdent(externFuncIdent), - LocalFuncIdentNode localFuncIdent => EmitLocalFuncIdent(localFuncIdent), - VariableIdentNode variableIdent => EmitVariableIdent(variableIdent), - FuncParameterIdentNode funcParameterIdent => EmitFuncParameterIdent(funcParameterIdent), + VariableIdentifierNode identifier => EmitIdentifier(identifier), LiteralNode literal => EmitLiteral(literal), UnaryExpressionNode unaryExpression => EmitUnaryExpression(unaryExpression), StructFieldAccessNode structFieldAccess => EmitStructFieldAccess(structFieldAccess), @@ -662,15 +652,21 @@ public class QBEGenerator }; } + private string EmitIdentifier(VariableIdentifierNode variableIdentifier) + { + throw new NotImplementedException(); + } + private string EmitArrayIndexAccess(ArrayIndexAccessNode arrayIndexAccess) { - var address = EmitAddressOfArrayIndexAccess(arrayIndexAccess); - if (arrayIndexAccess.Type is StructTypeNode) - { - return address; - } - - return EmitLoad(arrayIndexAccess.Type, address); + // var address = EmitAddressOfArrayIndexAccess(arrayIndexAccess); + // if (arrayIndexAccess.Type is StructTypeNode) + // { + // return address; + // } + // + // return EmitLoad(arrayIndexAccess.Type, address); + throw new NotImplementedException(); } private string EmitArrayInitializer(ArrayInitializerNode arrayInitializer) @@ -714,43 +710,43 @@ public class QBEGenerator { return addressOf switch { - ArrayIndexAccessNode arrayIndexAccess => EmitAddressOfArrayIndexAccess(arrayIndexAccess), - StructFieldAccessNode structFieldAccess => EmitAddressOfStructFieldAccess(structFieldAccess), - VariableIdentNode variableIdent => EmitAddressOfVariableIdent(variableIdent), + // ArrayIndexAccessNode arrayIndexAccess => EmitAddressOfArrayIndexAccess(arrayIndexAccess), + // StructFieldAccessNode structFieldAccess => EmitAddressOfStructFieldAccess(structFieldAccess), + // VariableIdentNode variableIdent => EmitAddressOfVariableIdent(variableIdent), _ => throw new ArgumentOutOfRangeException(nameof(addressOf)) }; } - private string EmitAddressOfArrayIndexAccess(ArrayIndexAccessNode arrayIndexAccess) - { - var array = EmitExpression(arrayIndexAccess.Target); - var index = EmitExpression(arrayIndexAccess.Index); - - var elementType = ((ArrayTypeNode)arrayIndexAccess.Target.Type).ElementType; - - var offset = TmpName(); - _writer.Indented($"{offset} =l mul {index}, {SizeOf(elementType)}"); - _writer.Indented($"{offset} =l add {offset}, 8"); - _writer.Indented($"{offset} =l add {array}, {offset}"); - return offset; - } - - private string EmitAddressOfStructFieldAccess(StructFieldAccessNode structFieldAccess) - { - var target = EmitExpression(structFieldAccess.Target); - - var structDef = _definitionTable.LookupStruct(structFieldAccess.StructType.Name); - var offset = OffsetOf(structDef, structFieldAccess.Field); - - var address = TmpName(); - _writer.Indented($"{address} =l add {target}, {offset}"); - return address; - } - - private string EmitAddressOfVariableIdent(VariableIdentNode variableIdent) - { - return "%" + variableIdent.Name; - } + // private string EmitAddressOfArrayIndexAccess(ArrayIndexAccessNode arrayIndexAccess) + // { + // var array = EmitExpression(arrayIndexAccess.Target); + // var index = EmitExpression(arrayIndexAccess.Index); + // + // var elementType = ((ArrayTypeNode)arrayIndexAccess.Target.Type).ElementType; + // + // var offset = TmpName(); + // _writer.Indented($"{offset} =l mul {index}, {SizeOf(elementType)}"); + // _writer.Indented($"{offset} =l add {offset}, 8"); + // _writer.Indented($"{offset} =l add {array}, {offset}"); + // return offset; + // } + // + // private string EmitAddressOfStructFieldAccess(StructFieldAccessNode structFieldAccess) + // { + // var target = EmitExpression(structFieldAccess.Target); + // + // var structDef = _definitionTable.LookupStruct(structFieldAccess.StructType.Name); + // var offset = OffsetOf(structDef, structFieldAccess.Field); + // + // var address = TmpName(); + // _writer.Indented($"{address} =l add {target}, {offset}"); + // return address; + // } + // + // private string EmitAddressOfVariableIdent(VariableIdentNode variableIdent) + // { + // return "%" + variableIdent.Name; + // } private string EmitBinaryExpression(BinaryExpressionNode binaryExpression) { @@ -922,32 +918,6 @@ public class QBEGenerator }; } - private string EmitExternFuncIdent(ExternFuncIdentNode externFuncIdent) - { - return ExternFuncName(_definitionTable.LookupExternFunc(externFuncIdent.Name)); - } - - private string EmitLocalFuncIdent(LocalFuncIdentNode localFuncIdent) - { - return LocalFuncName(_definitionTable.LookupLocalFunc(localFuncIdent.Name)); - } - - private string EmitVariableIdent(VariableIdentNode variableIdent) - { - var address = EmitAddressOfVariableIdent(variableIdent); - if (variableIdent.Type is StructTypeNode) - { - return address; - } - - return EmitLoad(variableIdent.Type, address); - } - - private string EmitFuncParameterIdent(FuncParameterIdentNode funcParameterIdent) - { - return "%" + funcParameterIdent.Name; - } - private string EmitLiteral(LiteralNode literal) { switch (literal.Kind) @@ -1032,13 +1002,11 @@ public class QBEGenerator private string EmitStructInitializer(StructInitializerNode structInitializer) { - var structDef = _definitionTable.LookupStruct(structInitializer.StructType.Name); - var destination = TmpName(); var size = SizeOf(structInitializer.StructType); _writer.Indented($"{destination} =l alloc8 {size}"); - foreach (var field in structDef.Fields) + foreach (var field in structInitializer.StructType.Fields) { if (!structInitializer.Initializers.TryGetValue(field.Name, out var valueExpression)) { @@ -1417,9 +1385,9 @@ public class QBEGenerator return $"$string{++_stringLiteralIndex}"; } - private string LocalFuncName(LocalFuncNode funcDef) + private string LocalFuncName(string module, LocalFuncNode funcDef) { - return $"${funcDef.Name}"; + return $"${module}.{funcDef.Name}"; } private string ExternFuncName(ExternFuncNode funcDef) @@ -1427,19 +1395,24 @@ public class QBEGenerator return $"${funcDef.CallName}"; } - private string StructTypeName(string name) + private string StructTypeName(string module, string name) { - return $":{name}"; + return $":{module}.{name}"; } - private string StructFuncName(string structName, string funcName) + private string StructFuncName(string module, string structName, string funcName) { - return $"${structName}_{funcName}"; + return $"${module}.{structName}_func.{funcName}"; } - private string StructVtableName(string structName) + private string StructCtorName(string module, string structName) { - return $"${structName}_vtable"; + return $"${module}.{structName}_ctor"; + } + + private string StructVtableName(string module, string structName) + { + return $"${module}.{structName}_vtable"; } #endregion diff --git a/src/compiler/NubLang/Generation/TypedDefinitionTable.cs b/src/compiler/NubLang/Generation/TypedDefinitionTable.cs deleted file mode 100644 index e54fb1c..0000000 --- a/src/compiler/NubLang/Generation/TypedDefinitionTable.cs +++ /dev/null @@ -1,51 +0,0 @@ -using NubLang.TypeChecking.Node; - -namespace NubLang.Generation; - -public sealed class TypedDefinitionTable -{ - private readonly List _definitions; - - public TypedDefinitionTable(IEnumerable syntaxTrees) - { - _definitions = syntaxTrees.SelectMany(x => x.Definitions).ToList(); - } - - public LocalFuncNode LookupLocalFunc(string name) - { - return _definitions - .OfType() - .First(x => x.Name == name); - } - - public ExternFuncNode LookupExternFunc(string name) - { - return _definitions - .OfType() - .First(x => x.Name == name); - } - - public StructNode LookupStruct(string name) - { - return _definitions - .OfType() - .First(x => x.Name == name); - } - - public InterfaceNode LookupInterface(string name) - { - return _definitions - .OfType() - .First(x => x.Name == name); - } - - public IEnumerable GetStructs() - { - return _definitions.OfType(); - } - - public IEnumerable GetInterfaces() - { - return _definitions.OfType(); - } -} \ No newline at end of file diff --git a/src/compiler/NubLang/Parsing/Parser.cs b/src/compiler/NubLang/Parsing/Parser.cs index 079ed15..c6d84ff 100644 --- a/src/compiler/NubLang/Parsing/Parser.cs +++ b/src/compiler/NubLang/Parsing/Parser.cs @@ -10,6 +10,7 @@ public sealed class Parser private readonly List _diagnostics = []; private IReadOnlyList _tokens = []; private int _tokenIndex; + private string _moduleName = string.Empty; private Token? CurrentToken => _tokenIndex < _tokens.Count ? _tokens[_tokenIndex] : null; private bool HasToken => CurrentToken != null; @@ -24,7 +25,47 @@ public sealed class Parser _diagnostics.Clear(); _tokens = tokens; _tokenIndex = 0; + _moduleName = string.Empty; + var metadata = ParseMetadata(); + var definitions = ParseDefinitions(); + + return new SyntaxTree(definitions, metadata); + } + + private SyntaxTreeMetadata ParseMetadata() + { + var imports = new List(); + + try + { + ExpectSymbol(Symbol.Module); + _moduleName = ExpectIdentifier().Value; + + while (TryExpectSymbol(Symbol.Import)) + { + imports.Add(ExpectIdentifier().Value); + } + } + catch (ParseException e) + { + _diagnostics.Add(e.Diagnostic); + while (HasToken) + { + if (CurrentToken is SymbolToken { Symbol: Symbol.Module or Symbol.Import }) + { + break; + } + + Next(); + } + } + + return new SyntaxTreeMetadata(_moduleName, imports); + } + + private List ParseDefinitions() + { var definitions = new List(); while (HasToken) @@ -48,9 +89,9 @@ public sealed class Parser definitions.Add(definition); } - catch (ParseException ex) + catch (ParseException e) { - _diagnostics.Add(ex.Diagnostic); + _diagnostics.Add(e.Diagnostic); while (HasToken) { if (CurrentToken is SymbolToken { Symbol: Symbol.Extern or Symbol.Func or Symbol.Struct or Symbol.Interface }) @@ -63,7 +104,7 @@ public sealed class Parser } } - return new SyntaxTree(GetTokens(_tokenIndex), definitions); + return definitions; } private FuncSignatureSyntax ParseFuncSignature() @@ -129,13 +170,13 @@ public sealed class Parser return new ExternFuncSyntax(GetTokens(startIndex), name.Value, callName, signature); } - private LocalFuncSyntax ParseFunc(int startIndex) + private FuncSyntax ParseFunc(int startIndex) { var name = ExpectIdentifier(); var signature = ParseFuncSignature(); var body = ParseBlock(); - return new LocalFuncSyntax(GetTokens(startIndex), name.Value, signature, body); + return new FuncSyntax(GetTokens(startIndex), name.Value, signature, body); } private DefinitionSyntax ParseStruct(int startIndex) @@ -448,7 +489,7 @@ public sealed class Parser var expr = token switch { LiteralToken literal => new LiteralSyntax(GetTokens(startIndex), literal.Value, literal.Kind), - IdentifierToken identifier => new IdentifierSyntax(GetTokens(startIndex), identifier.Value), + IdentifierToken identifier => new IdentifierSyntax(GetTokens(startIndex), Optional.Empty(), identifier.Value), SymbolToken symbolToken => symbolToken.Symbol switch { Symbol.OpenParen => ParseParenthesizedExpression(), @@ -664,7 +705,7 @@ public sealed class Parser "string" => new StringTypeSyntax(GetTokens(startIndex)), "cstring" => new CStringTypeSyntax(GetTokens(startIndex)), "bool" => new BoolTypeSyntax(GetTokens(startIndex)), - _ => new CustomTypeSyntax(GetTokens(startIndex), name.Value) + _ => new CustomTypeSyntax(GetTokens(startIndex), _moduleName, name.Value) }; } diff --git a/src/compiler/NubLang/Parsing/Syntax/DefinitionSyntax.cs b/src/compiler/NubLang/Parsing/Syntax/DefinitionSyntax.cs index 3674739..524a8ef 100644 --- a/src/compiler/NubLang/Parsing/Syntax/DefinitionSyntax.cs +++ b/src/compiler/NubLang/Parsing/Syntax/DefinitionSyntax.cs @@ -2,22 +2,23 @@ using NubLang.Tokenization; namespace NubLang.Parsing.Syntax; -public abstract record DefinitionSyntax(IEnumerable Tokens) : SyntaxNode(Tokens); +// todo(nub31): Check export modifier instead of harcoding true +public abstract record DefinitionSyntax(IEnumerable Tokens, string Name, bool Exported = true) : SyntaxNode(Tokens); public record FuncParameterSyntax(IEnumerable Tokens, string Name, TypeSyntax Type) : SyntaxNode(Tokens); public record FuncSignatureSyntax(IEnumerable Tokens, IReadOnlyList Parameters, TypeSyntax ReturnType) : SyntaxNode(Tokens); -public record LocalFuncSyntax(IEnumerable Tokens, string Name, FuncSignatureSyntax Signature, BlockSyntax Body) : DefinitionSyntax(Tokens); +public record FuncSyntax(IEnumerable Tokens, string Name, FuncSignatureSyntax Signature, BlockSyntax Body) : DefinitionSyntax(Tokens, Name); -public record ExternFuncSyntax(IEnumerable Tokens, string Name, string CallName, FuncSignatureSyntax Signature) : DefinitionSyntax(Tokens); +public record ExternFuncSyntax(IEnumerable Tokens, string Name, string CallName, FuncSignatureSyntax Signature) : DefinitionSyntax(Tokens, Name); public record StructFieldSyntax(IEnumerable Tokens, int Index, string Name, TypeSyntax Type, Optional Value) : SyntaxNode(Tokens); public record StructFuncSyntax(IEnumerable Tokens, string Name, FuncSignatureSyntax Signature, BlockSyntax Body) : SyntaxNode(Tokens); -public record StructSyntax(IEnumerable Tokens, string Name, IReadOnlyList Fields, IReadOnlyList Functions, IReadOnlyList InterfaceImplementations) : DefinitionSyntax(Tokens); +public record StructSyntax(IEnumerable Tokens, string Name, IReadOnlyList Fields, IReadOnlyList Functions, IReadOnlyList InterfaceImplementations) : DefinitionSyntax(Tokens, Name); public record InterfaceFuncSyntax(IEnumerable Tokens, string Name, FuncSignatureSyntax Signature) : SyntaxNode(Tokens); -public record InterfaceSyntax(IEnumerable Tokens, string Name, IReadOnlyList Functions) : DefinitionSyntax(Tokens); \ No newline at end of file +public record InterfaceSyntax(IEnumerable Tokens, string Name, IReadOnlyList Functions) : DefinitionSyntax(Tokens, Name); \ No newline at end of file diff --git a/src/compiler/NubLang/Parsing/Syntax/ExpressionSyntax.cs b/src/compiler/NubLang/Parsing/Syntax/ExpressionSyntax.cs index cb42096..f5f4a30 100644 --- a/src/compiler/NubLang/Parsing/Syntax/ExpressionSyntax.cs +++ b/src/compiler/NubLang/Parsing/Syntax/ExpressionSyntax.cs @@ -40,7 +40,7 @@ public record FuncCallSyntax(IEnumerable Tokens, ExpressionSyntax Express public record DotFuncCallSyntax(IEnumerable Tokens, string Name, ExpressionSyntax ThisParameter, IReadOnlyList Parameters) : ExpressionSyntax(Tokens); -public record IdentifierSyntax(IEnumerable Tokens, string Name) : ExpressionSyntax(Tokens); +public record IdentifierSyntax(IEnumerable Tokens, Optional Module, string Name) : ExpressionSyntax(Tokens); public record ArrayInitializerSyntax(IEnumerable Tokens, ExpressionSyntax Capacity, TypeSyntax ElementType) : ExpressionSyntax(Tokens); diff --git a/src/compiler/NubLang/Parsing/Syntax/SyntaxNode.cs b/src/compiler/NubLang/Parsing/Syntax/SyntaxNode.cs index 4fb6807..7435d18 100644 --- a/src/compiler/NubLang/Parsing/Syntax/SyntaxNode.cs +++ b/src/compiler/NubLang/Parsing/Syntax/SyntaxNode.cs @@ -4,6 +4,8 @@ namespace NubLang.Parsing.Syntax; public abstract record SyntaxNode(IEnumerable Tokens); -public record SyntaxTree(IEnumerable Tokens, IReadOnlyList Definitions) : SyntaxNode(Tokens); +public record SyntaxTreeMetadata(string? ModuleName, IReadOnlyList Imports); + +public record SyntaxTree(IReadOnlyList Definitions, SyntaxTreeMetadata Metadata); public record BlockSyntax(IEnumerable Tokens, IReadOnlyList Statements) : SyntaxNode(Tokens); \ No newline at end of file diff --git a/src/compiler/NubLang/Parsing/Syntax/TypeSyntax.cs b/src/compiler/NubLang/Parsing/Syntax/TypeSyntax.cs index a29704d..29ff911 100644 --- a/src/compiler/NubLang/Parsing/Syntax/TypeSyntax.cs +++ b/src/compiler/NubLang/Parsing/Syntax/TypeSyntax.cs @@ -22,4 +22,4 @@ public record CStringTypeSyntax(IEnumerable Tokens) : TypeSyntax(Tokens); public record ArrayTypeSyntax(IEnumerable Tokens, TypeSyntax BaseType) : TypeSyntax(Tokens); -public record CustomTypeSyntax(IEnumerable Tokens, string Name) : TypeSyntax(Tokens); \ No newline at end of file +public record CustomTypeSyntax(IEnumerable Tokens, string Module, string Name) : TypeSyntax(Tokens); \ No newline at end of file diff --git a/src/compiler/NubLang/Tokenization/Token.cs b/src/compiler/NubLang/Tokenization/Token.cs index 297c043..f300e04 100644 --- a/src/compiler/NubLang/Tokenization/Token.cs +++ b/src/compiler/NubLang/Tokenization/Token.cs @@ -76,4 +76,6 @@ public enum Symbol Pipe, And, Or, + Module, + Import, } \ No newline at end of file diff --git a/src/compiler/NubLang/Tokenization/Tokenizer.cs b/src/compiler/NubLang/Tokenization/Tokenizer.cs index f18c109..dda7393 100644 --- a/src/compiler/NubLang/Tokenization/Tokenizer.cs +++ b/src/compiler/NubLang/Tokenization/Tokenizer.cs @@ -20,6 +20,7 @@ public sealed class Tokenizer ["interface"] = Symbol.Interface, ["for"] = Symbol.For, ["extern"] = Symbol.Extern, + ["module"] = Symbol.Module, }; private static readonly Dictionary Symbols = new() diff --git a/src/compiler/NubLang/TypeChecking/DefinitionTable.cs b/src/compiler/NubLang/TypeChecking/DefinitionTable.cs deleted file mode 100644 index 182a0e6..0000000 --- a/src/compiler/NubLang/TypeChecking/DefinitionTable.cs +++ /dev/null @@ -1,56 +0,0 @@ -using NubLang.Parsing.Syntax; - -namespace NubLang.TypeChecking; - -public class DefinitionTable -{ - private readonly List _definitions; - - public DefinitionTable(IEnumerable syntaxTrees) - { - _definitions = syntaxTrees.SelectMany(x => x.Definitions).ToList(); - } - - public IEnumerable LookupLocalFunc(string name) - { - return _definitions - .OfType() - .Where(x => x.Name == name); - } - - public IEnumerable LookupExternFunc(string name) - { - return _definitions - .OfType() - .Where(x => x.Name == name); - } - - public IEnumerable LookupStruct(string name) - { - return _definitions - .OfType() - .Where(x => x.Name == name); - } - - public IEnumerable LookupStructField(StructSyntax @struct, string field) - { - return @struct.Fields.Where(x => x.Name == field); - } - - public IEnumerable LookupStructFunc(StructSyntax @struct, string func) - { - return @struct.Functions.Where(x => x.Name == func); - } - - public IEnumerable LookupInterface(string name) - { - return _definitions - .OfType() - .Where(x => x.Name == name); - } - - public IEnumerable LookupInterfaceFunc(InterfaceSyntax @interface, string name) - { - return @interface.Functions.Where(x => x.Name == name); - } -} \ No newline at end of file diff --git a/src/compiler/NubLang/TypeChecking/Module.cs b/src/compiler/NubLang/TypeChecking/Module.cs new file mode 100644 index 0000000..598604e --- /dev/null +++ b/src/compiler/NubLang/TypeChecking/Module.cs @@ -0,0 +1,174 @@ +using NubLang.Parsing.Syntax; +using NubLang.TypeChecking.Node; + +namespace NubLang.TypeChecking; + +public class Module +{ + public static IReadOnlyList CollectFromSyntaxTrees(IReadOnlyList syntaxTrees) + { + var modules = new Dictionary(); + + foreach (var syntaxTree in syntaxTrees) + { + var name = syntaxTree.Metadata.ModuleName; + if (name == null) + { + continue; + } + + if (!modules.TryGetValue(name, out var module)) + { + module = new Module(name, syntaxTree.Metadata.Imports); + modules[name] = module; + } + + foreach (var definition in syntaxTree.Definitions) + { + module.AddDefinition(definition); + } + } + + return modules.Values.ToList(); + } + + private readonly List _definitions = []; + + public Module(string name, IReadOnlyList imports) + { + Name = name; + Imports = imports; + } + + public string Name { get; } + public IReadOnlyList Imports { get; } + + public IReadOnlyList Definitions => _definitions; + + private void AddDefinition(DefinitionSyntax syntax) + { + _definitions.Add(syntax); + } +} + +public class TypedModule +{ + public TypedModule(string name, IReadOnlyList definitions) + { + Name = name; + Definitions = definitions; + } + + public string Name { get; } + public IReadOnlyList Definitions { get; } +} + +public class ModuleSignature +{ + public static IReadOnlyDictionary CollectFromSyntaxTrees(IReadOnlyList syntaxTrees) + { + var modules = new Dictionary(); + + foreach (var syntaxTree in syntaxTrees) + { + var moduleName = syntaxTree.Metadata.ModuleName; + if (moduleName == null) + { + continue; + } + + if (!modules.TryGetValue(moduleName, out var module)) + { + module = new ModuleSignature(); + modules[moduleName] = module; + } + + foreach (var def in syntaxTree.Definitions) + { + if (def.Exported) + { + switch (def) + { + case ExternFuncSyntax externFuncDef: + { + var parameters = externFuncDef.Signature.Parameters.Select(p => TypeResolver.ResolveType(p.Type, modules)).ToList(); + var returnType = TypeResolver.ResolveType(externFuncDef.Signature.ReturnType, modules); + var type = new FuncTypeNode(parameters, returnType); + module._functions.Add(externFuncDef.Name, type); + break; + } + case FuncSyntax funcDef: + { + var parameters = funcDef.Signature.Parameters.Select(p => TypeResolver.ResolveType(p.Type, modules)).ToList(); + var returnType = TypeResolver.ResolveType(funcDef.Signature.ReturnType, modules); + var type = new FuncTypeNode(parameters, returnType); + module._functions.Add(funcDef.Name, type); + break; + } + case InterfaceSyntax interfaceDef: + { + var functions = new Dictionary(); + foreach (var function in interfaceDef.Functions) + { + var parameters = function.Signature.Parameters.Select(p => TypeResolver.ResolveType(p.Type, modules)).ToList(); + var returnType = TypeResolver.ResolveType(function.Signature.ReturnType, modules); + functions.Add(function.Name, new FuncTypeNode(parameters, returnType)); + } + + var type = new InterfaceTypeNode(moduleName, interfaceDef.Name, functions); + module._interfaces.Add(type); + break; + } + case StructSyntax structDef: + { + var fields = structDef.Fields.Select(x => new StructTypeField(x.Name, TypeResolver.ResolveType(x.Type, modules), x.Value.HasValue)).ToList(); + + var functions = new Dictionary(); + foreach (var function in structDef.Functions) + { + var parameters = function.Signature.Parameters.Select(p => TypeResolver.ResolveType(p.Type, modules)).ToList(); + var returnType = TypeResolver.ResolveType(function.Signature.ReturnType, modules); + functions.Add(function.Name, new FuncTypeNode(parameters, returnType)); + } + + var interfaceImplementations = new List(); + foreach (var interfaceImplementation in structDef.InterfaceImplementations) + { + if (interfaceImplementation is not CustomTypeSyntax customType) + { + throw new Exception("Interface implementation is not a custom type"); + } + + var resolvedType = TypeResolver.ResolveCustomType(customType.Module, customType.Name, modules); + if (resolvedType is not InterfaceTypeNode interfaceType) + { + throw new Exception("Interface implementation is not a interface"); + } + + interfaceImplementations.Add(interfaceType); + } + + var type = new StructTypeNode(moduleName, structDef.Name, fields, functions, interfaceImplementations); + module._structs.Add(type); + break; + } + default: + { + throw new ArgumentOutOfRangeException(nameof(def)); + } + } + } + } + } + + return modules; + } + + private readonly List _structs = []; + private readonly List _interfaces = []; + private readonly Dictionary _functions = []; + + public IReadOnlyList StructTypes => _structs; + public IReadOnlyList InterfaceTypes => _interfaces; + public IReadOnlyDictionary Functions => _functions; +} \ No newline at end of file diff --git a/src/compiler/NubLang/TypeChecking/Node/DefinitionNode.cs b/src/compiler/NubLang/TypeChecking/Node/DefinitionNode.cs index 1afa33d..3a42cd3 100644 --- a/src/compiler/NubLang/TypeChecking/Node/DefinitionNode.cs +++ b/src/compiler/NubLang/TypeChecking/Node/DefinitionNode.cs @@ -1,23 +1,21 @@ -using NubLang.Tokenization; +namespace NubLang.TypeChecking.Node; -namespace NubLang.TypeChecking.Node; +public abstract record DefinitionNode : Node; -public abstract record DefinitionNode(IEnumerable Tokens) : Node(Tokens); +public record FuncParameterNode(string Name, TypeNode Type) : Node; -public record FuncParameterNode(string Name, TypeNode Type, IEnumerable Tokens) : Node(Tokens); +public record FuncSignatureNode(IReadOnlyList Parameters, TypeNode ReturnType) : Node; -public record FuncSignatureNode(IReadOnlyList Parameters, TypeNode ReturnType, IEnumerable Tokens) : Node(Tokens); +public record LocalFuncNode(string Name, FuncSignatureNode Signature, BlockNode Body) : DefinitionNode; -public record LocalFuncNode(string Name, FuncSignatureNode Signature, BlockNode Body, IEnumerable Tokens) : DefinitionNode(Tokens); +public record ExternFuncNode(string Name, string CallName, FuncSignatureNode Signature) : DefinitionNode; -public record ExternFuncNode(string Name, string CallName, FuncSignatureNode Signature, IEnumerable Tokens) : DefinitionNode(Tokens); +public record StructFieldNode(int Index, string Name, TypeNode Type, Optional Value) : Node; -public record StructFieldNode(int Index, string Name, TypeNode Type, Optional Value, IEnumerable Tokens) : Node(Tokens); +public record StructFuncNode(string Name, FuncSignatureNode Signature, BlockNode Body) : Node; -public record StructFuncNode(string Name, FuncSignatureNode Signature, BlockNode Body, IEnumerable Tokens) : Node(Tokens); +public record StructNode(string Name, IReadOnlyList Fields, IReadOnlyList Functions, IReadOnlyList InterfaceImplementations) : DefinitionNode; -public record StructNode(string Name, StructTypeNode Type, IReadOnlyList Fields, IReadOnlyList Functions, IReadOnlyList InterfaceImplementations, IEnumerable Tokens) : DefinitionNode(Tokens); +public record InterfaceFuncNode(string Name, FuncSignatureNode Signature) : Node; -public record InterfaceFuncNode(string Name, FuncSignatureNode Signature, IEnumerable Tokens) : Node(Tokens); - -public record InterfaceNode(string Name, IReadOnlyList Functions, IEnumerable Tokens) : DefinitionNode(Tokens); \ No newline at end of file +public record InterfaceNode(string Name, IReadOnlyList Functions) : DefinitionNode; \ No newline at end of file diff --git a/src/compiler/NubLang/TypeChecking/Node/ExpressionNode.cs b/src/compiler/NubLang/TypeChecking/Node/ExpressionNode.cs index 4f11b26..df480cd 100644 --- a/src/compiler/NubLang/TypeChecking/Node/ExpressionNode.cs +++ b/src/compiler/NubLang/TypeChecking/Node/ExpressionNode.cs @@ -30,45 +30,43 @@ public enum BinaryOperator BitwiseOr } -public abstract record ExpressionNode(TypeNode Type, IEnumerable Tokens) : Node(Tokens); +public abstract record ExpressionNode(TypeNode Type) : Node; -public abstract record LValueExpressionNode(TypeNode Type, IEnumerable Tokens) : RValueExpressionNode(Type, Tokens); -public abstract record RValueExpressionNode(TypeNode Type, IEnumerable Tokens) : ExpressionNode(Type, Tokens); +public abstract record LValueExpressionNode(TypeNode Type) : RValueExpressionNode(Type); +public abstract record RValueExpressionNode(TypeNode Type) : ExpressionNode(Type); -public record BinaryExpressionNode(TypeNode Type, ExpressionNode Left, BinaryOperator Operator, ExpressionNode Right, IEnumerable Tokens) : RValueExpressionNode(Type, Tokens); +public record BinaryExpressionNode(TypeNode Type, ExpressionNode Left, BinaryOperator Operator, ExpressionNode Right) : RValueExpressionNode(Type); -public record UnaryExpressionNode(TypeNode Type, UnaryOperator Operator, ExpressionNode Operand, IEnumerable Tokens) : RValueExpressionNode(Type, Tokens); +public record UnaryExpressionNode(TypeNode Type, UnaryOperator Operator, ExpressionNode Operand) : RValueExpressionNode(Type); -public record FuncCallNode(TypeNode Type, ExpressionNode Expression, IReadOnlyList Parameters, IEnumerable Tokens) : RValueExpressionNode(Type, Tokens); +public record FuncCallNode(TypeNode Type, ExpressionNode Expression, IReadOnlyList Parameters) : RValueExpressionNode(Type); -public record StructFuncCallNode(TypeNode Type, string Name, StructTypeNode StructType, ExpressionNode StructExpression, IReadOnlyList Parameters, IEnumerable Tokens) : RValueExpressionNode(Type, Tokens); +public record StructFuncCallNode(TypeNode Type, string Name, StructTypeNode StructType, ExpressionNode StructExpression, IReadOnlyList Parameters) : RValueExpressionNode(Type); -public record InterfaceFuncCallNode(TypeNode Type, string Name, InterfaceTypeNode InterfaceType, ExpressionNode InterfaceExpression, IReadOnlyList Parameters, IEnumerable Tokens) : RValueExpressionNode(Type, Tokens); +public record InterfaceFuncCallNode(TypeNode Type, string Name, InterfaceTypeNode InterfaceType, ExpressionNode InterfaceExpression, IReadOnlyList Parameters) : RValueExpressionNode(Type); -public record VariableIdentNode(TypeNode Type, string Name, IEnumerable Tokens) : LValueExpressionNode(Type, Tokens); +public record VariableIdentifierNode(TypeNode Type, string Name) : LValueExpressionNode(Type); -public record FuncParameterIdentNode(TypeNode Type, string Name, IEnumerable Tokens) : RValueExpressionNode(Type, Tokens); +public record FuncParameterIdentifierNode(TypeNode Type, string Name) : RValueExpressionNode(Type); -public record LocalFuncIdentNode(TypeNode Type, string Name, IEnumerable Tokens) : RValueExpressionNode(Type, Tokens); +public record FuncIdentifierNode(TypeNode Type, string Module, string Name) : RValueExpressionNode(Type); -public record ExternFuncIdentNode(TypeNode Type, string Name, IEnumerable Tokens) : RValueExpressionNode(Type, Tokens); +public record ArrayInitializerNode(TypeNode Type, ExpressionNode Capacity, TypeNode ElementType) : RValueExpressionNode(Type); -public record ArrayInitializerNode(TypeNode Type, ExpressionNode Capacity, TypeNode ElementType, IEnumerable Tokens) : RValueExpressionNode(Type, Tokens); +public record ArrayIndexAccessNode(TypeNode Type, ExpressionNode Target, ExpressionNode Index) : LValueExpressionNode(Type); -public record ArrayIndexAccessNode(TypeNode Type, ExpressionNode Target, ExpressionNode Index, IEnumerable Tokens) : LValueExpressionNode(Type, Tokens); +public record AddressOfNode(TypeNode Type, LValueExpressionNode LValue) : RValueExpressionNode(Type); -public record AddressOfNode(TypeNode Type, LValueExpressionNode LValue, IEnumerable Tokens) : RValueExpressionNode(Type, Tokens); +public record LiteralNode(TypeNode Type, string Value, LiteralKind Kind) : RValueExpressionNode(Type); -public record LiteralNode(TypeNode Type, string Value, LiteralKind Kind, IEnumerable Tokens) : RValueExpressionNode(Type, Tokens); +public record StructFieldAccessNode(TypeNode Type, StructTypeNode StructType, ExpressionNode Target, string Field) : LValueExpressionNode(Type); -public record StructFieldAccessNode(TypeNode Type, StructTypeNode StructType, ExpressionNode Target, string Field, IEnumerable Tokens) : LValueExpressionNode(Type, Tokens); +public record StructInitializerNode(StructTypeNode StructType, Dictionary Initializers) : RValueExpressionNode(StructType); -public record StructInitializerNode(StructTypeNode StructType, Dictionary Initializers, IEnumerable Tokens) : RValueExpressionNode(StructType, Tokens); +public record DereferenceNode(TypeNode Type, ExpressionNode Expression) : RValueExpressionNode(Type); -public record DereferenceNode(TypeNode Type, ExpressionNode Expression, IEnumerable Tokens) : RValueExpressionNode(Type, Tokens); +public record ConvertToInterfaceNode(TypeNode Type, InterfaceTypeNode InterfaceType, StructTypeNode StructType, ExpressionNode Implementation) : RValueExpressionNode(Type); -public record ConvertToInterfaceNode(TypeNode Type, InterfaceTypeNode InterfaceType, StructTypeNode StructType, ExpressionNode Implementation, IEnumerable Tokens) : RValueExpressionNode(Type, Tokens); +public record ConvertIntNode(TypeNode Type, ExpressionNode Value, IntTypeNode ValueType, IntTypeNode TargetType) : RValueExpressionNode(Type); -public record ConvertIntNode(TypeNode Type, ExpressionNode Value, IntTypeNode ValueType, IntTypeNode TargetType, IEnumerable Tokens) : RValueExpressionNode(Type, Tokens); - -public record ConvertFloatNode(TypeNode Type, ExpressionNode Value, FloatTypeNode ValueType, FloatTypeNode TargetType, IEnumerable Tokens) : RValueExpressionNode(Type, Tokens); +public record ConvertFloatNode(TypeNode Type, ExpressionNode Value, FloatTypeNode ValueType, FloatTypeNode TargetType) : RValueExpressionNode(Type); diff --git a/src/compiler/NubLang/TypeChecking/Node/Node.cs b/src/compiler/NubLang/TypeChecking/Node/Node.cs index b2ce1af..4b75960 100644 --- a/src/compiler/NubLang/TypeChecking/Node/Node.cs +++ b/src/compiler/NubLang/TypeChecking/Node/Node.cs @@ -1,9 +1,5 @@ -using NubLang.Tokenization; +namespace NubLang.TypeChecking.Node; -namespace NubLang.TypeChecking.Node; +public abstract record Node; -public abstract record Node(IEnumerable Tokens); - -public record TypedSyntaxTree(IReadOnlyList Definitions); - -public record BlockNode(IReadOnlyList Statements, IEnumerable Tokens) : Node(Tokens); \ No newline at end of file +public record BlockNode(IReadOnlyList Statements) : Node; \ No newline at end of file diff --git a/src/compiler/NubLang/TypeChecking/Node/StatementNode.cs b/src/compiler/NubLang/TypeChecking/Node/StatementNode.cs index 69de9da..bd58462 100644 --- a/src/compiler/NubLang/TypeChecking/Node/StatementNode.cs +++ b/src/compiler/NubLang/TypeChecking/Node/StatementNode.cs @@ -1,21 +1,19 @@ -using NubLang.Tokenization; +namespace NubLang.TypeChecking.Node; -namespace NubLang.TypeChecking.Node; +public record StatementNode : Node; -public record StatementNode(IEnumerable Tokens) : Node(Tokens); +public record StatementExpressionNode(ExpressionNode Expression) : StatementNode; -public record StatementExpressionNode(ExpressionNode Expression, IEnumerable Tokens) : StatementNode(Tokens); +public record ReturnNode(Optional Value) : StatementNode; -public record ReturnNode(Optional Value, IEnumerable Tokens) : StatementNode(Tokens); +public record AssignmentNode(LValueExpressionNode Target, ExpressionNode Value) : StatementNode; -public record AssignmentNode(LValueExpressionNode Target, ExpressionNode Value, IEnumerable Tokens) : StatementNode(Tokens); +public record IfNode(ExpressionNode Condition, BlockNode Body, Optional> Else) : StatementNode; -public record IfNode(ExpressionNode Condition, BlockNode Body, Optional> Else, IEnumerable Tokens) : StatementNode(Tokens); +public record VariableDeclarationNode(string Name, Optional Assignment, TypeNode Type) : StatementNode; -public record VariableDeclarationNode(string Name, Optional Assignment, TypeNode Type, IEnumerable Tokens) : StatementNode(Tokens); +public record ContinueNode : StatementNode; -public record ContinueNode(IEnumerable Tokens) : StatementNode(Tokens); +public record BreakNode : StatementNode; -public record BreakNode(IEnumerable Tokens) : StatementNode(Tokens); - -public record WhileNode(ExpressionNode Condition, BlockNode Body, IEnumerable Tokens) : StatementNode(Tokens); \ No newline at end of file +public record WhileNode(ExpressionNode Condition, BlockNode Body) : StatementNode; \ No newline at end of file diff --git a/src/compiler/NubLang/TypeChecking/Node/TypeNode.cs b/src/compiler/NubLang/TypeChecking/Node/TypeNode.cs index df0ab49..2f5bbdd 100644 --- a/src/compiler/NubLang/TypeChecking/Node/TypeNode.cs +++ b/src/compiler/NubLang/TypeChecking/Node/TypeNode.cs @@ -174,25 +174,34 @@ public class StringTypeNode : ComplexTypeNode public override int GetHashCode() => HashCode.Combine(typeof(StringTypeNode)); } -public class StructTypeNode(string name, IReadOnlyList fields, IReadOnlyList functions, IReadOnlyList interfaceImplementations) : ComplexTypeNode +public class StructTypeField(string name, TypeNode type, bool hasDefaultValue) { public string Name { get; } = name; - public IReadOnlyList Fields { get; set; } = fields; - public IReadOnlyList Functions { get; set; } = functions; + public TypeNode Type { get; } = type; + public bool HasDefaultValue { get; } = hasDefaultValue; +} + +public class StructTypeNode(string module, string name, IReadOnlyList fields, IReadOnlyDictionary functions, IReadOnlyList interfaceImplementations) : ComplexTypeNode +{ + public string Module { get; } = module; + public string Name { get; } = name; + public IReadOnlyList Fields { get; set; } = fields; + public IReadOnlyDictionary Functions { get; set; } = functions; public IReadOnlyList InterfaceImplementations { get; set; } = interfaceImplementations; public override string ToString() => Name; - public override bool Equals(TypeNode? other) => other is StructTypeNode structType && Name == structType.Name; + public override bool Equals(TypeNode? other) => other is StructTypeNode structType && Name == structType.Name && Module == structType.Module; public override int GetHashCode() => HashCode.Combine(typeof(StructTypeNode), Name); } -public class InterfaceTypeNode(string name, IReadOnlyList functions) : ComplexTypeNode +public class InterfaceTypeNode(string module, string name, IReadOnlyDictionary functions) : ComplexTypeNode { + public string Module { get; } = module; public string Name { get; } = name; - public IReadOnlyList Functions { get; set; } = functions; + public IReadOnlyDictionary Functions { get; set; } = functions; public override string ToString() => Name; - public override bool Equals(TypeNode? other) => other is InterfaceTypeNode interfaceType && Name == interfaceType.Name; + public override bool Equals(TypeNode? other) => other is InterfaceTypeNode interfaceType && Name == interfaceType.Name && Module == interfaceType.Module; public override int GetHashCode() => HashCode.Combine(typeof(InterfaceTypeNode), Name); } diff --git a/src/compiler/NubLang/TypeChecking/TypeChecker.cs b/src/compiler/NubLang/TypeChecking/TypeChecker.cs index 388087d..d636b43 100644 --- a/src/compiler/NubLang/TypeChecking/TypeChecker.cs +++ b/src/compiler/NubLang/TypeChecking/TypeChecker.cs @@ -7,33 +7,31 @@ namespace NubLang.TypeChecking; public sealed class TypeChecker { - private readonly SyntaxTree _syntaxTree; - private readonly DefinitionTable _definitionTable; + private readonly Module _currentModule; + private readonly IReadOnlyDictionary _moduleSignatures; private readonly Stack _scopes = []; private readonly Stack _funcReturnTypes = []; private readonly List _diagnostics = []; - private readonly Dictionary _typeCache = new(); private Scope Scope => _scopes.Peek(); - public TypeChecker(SyntaxTree syntaxTree, DefinitionTable definitionTable) + public TypeChecker(Module currentModule, IReadOnlyDictionary moduleSignatures) { - _syntaxTree = syntaxTree; - _definitionTable = definitionTable; + _currentModule = currentModule; + _moduleSignatures = moduleSignatures.Where(x => currentModule.Imports.Contains(x.Key) || _currentModule.Name == x.Key).ToDictionary(); } public IReadOnlyList GetDiagnostics() => _diagnostics; - public TypedSyntaxTree Check() + public TypedModule CheckModule() { _diagnostics.Clear(); - _funcReturnTypes.Clear(); _scopes.Clear(); var definitions = new List(); - foreach (var definition in _syntaxTree.Definitions) + foreach (var definition in _currentModule.Definitions) { try { @@ -45,7 +43,7 @@ public sealed class TypeChecker } } - return new TypedSyntaxTree(definitions); + return new TypedModule(_currentModule.Name, definitions); } private DefinitionNode CheckDefinition(DefinitionSyntax node) @@ -54,7 +52,7 @@ public sealed class TypeChecker { ExternFuncSyntax definition => CheckExternFuncDefinition(definition), InterfaceSyntax definition => CheckInterfaceDefinition(definition), - LocalFuncSyntax definition => CheckLocalFuncDefinition(definition), + FuncSyntax definition => CheckLocalFuncDefinition(definition), StructSyntax definition => CheckStructDefinition(definition), _ => throw new ArgumentOutOfRangeException(nameof(node)) }; @@ -62,102 +60,72 @@ public sealed class TypeChecker private InterfaceNode CheckInterfaceDefinition(InterfaceSyntax node) { - var functions = new List(); - - foreach (var function in node.Functions) - { - functions.Add(new InterfaceFuncNode(function.Name, CheckFuncSignature(function.Signature), function.Tokens)); - } - - return new InterfaceNode(node.Name, functions, node.Tokens); + throw new NotImplementedException(); } private StructNode CheckStructDefinition(StructSyntax node) { - var structFields = new List(); - + var fields = new List(); foreach (var field in node.Fields) { var value = Optional.Empty(); - if (field.Value.HasValue) { - value = CheckExpression(field.Value.Value, CheckType(field.Type)); + value = CheckExpression(field.Value.Value); } - structFields.Add(new StructFieldNode(field.Index, field.Name, CheckType(field.Type), value, field.Tokens)); + fields.Add(new StructFieldNode(field.Index, field.Name, ResolveType(field.Type), value)); } - var funcs = new List(); - - foreach (var func in node.Functions) + var functions = new List(); + foreach (var function in node.Functions) { var scope = new Scope(); - - scope.Declare(new Identifier("this", GetStructType(node), IdentifierKind.FunctionParameter)); - foreach (var parameter in func.Signature.Parameters) + // todo(nub31): Add this parameter + foreach (var parameter in function.Signature.Parameters) { - scope.Declare(new Identifier(parameter.Name, CheckType(parameter.Type), IdentifierKind.FunctionParameter)); + scope.Declare(new Identifier(parameter.Name, ResolveType(parameter.Type), IdentifierKind.FunctionParameter)); } - _funcReturnTypes.Push(CheckType(func.Signature.ReturnType)); - var body = CheckBlock(func.Body, scope); + _funcReturnTypes.Push(ResolveType(function.Signature.ReturnType)); + var body = CheckBlock(function.Body, scope); _funcReturnTypes.Pop(); - - funcs.Add(new StructFuncNode(func.Name, CheckFuncSignature(func.Signature), body, func.Tokens)); + functions.Add(new StructFuncNode(function.Name, CheckFuncSignature(function.Signature), body)); } var interfaceImplementations = new List(); - foreach (var interfaceImplementation in node.InterfaceImplementations) { - var type = CheckType(interfaceImplementation); + var type = ResolveType(interfaceImplementation); if (type is not InterfaceTypeNode interfaceType) { - _diagnostics.Add(Diagnostic.Error("Interface implementation is not a custom type").Build()); - continue; - } - - var interfaceDefs = _definitionTable.LookupInterface(interfaceType.Name).ToArray(); - - if (interfaceDefs.Length == 0) - { - _diagnostics.Add(Diagnostic.Error($"Interface {interfaceType.Name} is not defined").Build()); - continue; - } - - if (interfaceDefs.Length > 1) - { - _diagnostics.Add(Diagnostic.Error($"Interface {interfaceType.Name} has multiple definitions").Build()); + _diagnostics.Add(Diagnostic.Error($"Struct {node.Name} cannot implement non-struct type {interfaceImplementation}").At(interfaceImplementation).Build()); continue; } interfaceImplementations.Add(interfaceType); } - return new StructNode(node.Name, GetStructType(node), structFields, funcs, interfaceImplementations, node.Tokens); + return new StructNode(node.Name, fields, functions, interfaceImplementations); } private ExternFuncNode CheckExternFuncDefinition(ExternFuncSyntax node) { - return new ExternFuncNode(node.Name, node.CallName, CheckFuncSignature(node.Signature), node.Tokens); + return new ExternFuncNode(node.Name, node.CallName, CheckFuncSignature(node.Signature)); } - private LocalFuncNode CheckLocalFuncDefinition(LocalFuncSyntax node) + private LocalFuncNode CheckLocalFuncDefinition(FuncSyntax node) { - var signature = CheckFuncSignature(node.Signature); - var scope = new Scope(); - foreach (var parameter in signature.Parameters) + foreach (var parameter in node.Signature.Parameters) { - scope.Declare(new Identifier(parameter.Name, parameter.Type, IdentifierKind.FunctionParameter)); + scope.Declare(new Identifier(parameter.Name, ResolveType(parameter.Type), IdentifierKind.FunctionParameter)); } - _funcReturnTypes.Push(signature.ReturnType); + _funcReturnTypes.Push(ResolveType(node.Signature.ReturnType)); var body = CheckBlock(node.Body, scope); _funcReturnTypes.Pop(); - - return new LocalFuncNode(node.Name, signature, body, node.Tokens); + return new LocalFuncNode(node.Name, CheckFuncSignature(node.Signature), body); } private StatementNode CheckStatement(StatementSyntax node) @@ -165,8 +133,8 @@ public sealed class TypeChecker return node switch { AssignmentSyntax statement => CheckAssignment(statement), - BreakSyntax => new BreakNode(node.Tokens), - ContinueSyntax => new ContinueNode(node.Tokens), + BreakSyntax => new BreakNode(), + ContinueSyntax => new ContinueNode(), IfSyntax statement => CheckIf(statement), ReturnSyntax statement => CheckReturn(statement), StatementExpressionSyntax statement => CheckStatementExpression(statement), @@ -178,30 +146,12 @@ public sealed class TypeChecker private StatementNode CheckAssignment(AssignmentSyntax statement) { - var target = CheckExpression(statement.Target); - if (target is not LValueExpressionNode targetLValue) - { - throw new TypeCheckerException(Diagnostic.Error("Cannot assign to rvalue").Build()); - } - - var value = CheckExpression(statement.Value, target.Type); - return new AssignmentNode(targetLValue, value, statement.Tokens); + throw new NotImplementedException(); } private IfNode CheckIf(IfSyntax statement) { - var elseStatement = Optional.Empty>(); - - if (statement.Else.HasValue) - { - elseStatement = statement.Else.Value.Match> - ( - elseIf => CheckIf(elseIf), - @else => CheckBlock(@else) - ); - } - - return new IfNode(CheckExpression(statement.Condition, new BoolTypeNode()), CheckBlock(statement.Body), elseStatement, statement.Tokens); + throw new NotImplementedException(); } private ReturnNode CheckReturn(ReturnSyntax statement) @@ -213,58 +163,54 @@ public sealed class TypeChecker value = CheckExpression(statement.Value.Value, _funcReturnTypes.Peek()); } - return new ReturnNode(value, statement.Tokens); + return new ReturnNode(value); } private StatementExpressionNode CheckStatementExpression(StatementExpressionSyntax statement) { - return new StatementExpressionNode(CheckExpression(statement.Expression), statement.Tokens); + return new StatementExpressionNode(CheckExpression(statement.Expression)); } private VariableDeclarationNode CheckVariableDeclaration(VariableDeclarationSyntax statement) { TypeNode? type = null; + ExpressionNode? assignmentNode = null; - if (statement.ExplicitType.HasValue) + if (statement.ExplicitType.TryGetValue(out var explicitType)) { - type = CheckType(statement.ExplicitType.Value); + type = ResolveType(explicitType); } - var assignment = Optional.Empty(); - if (statement.Assignment.HasValue) + if (statement.Assignment.TryGetValue(out var assignment)) { - var boundValue = CheckExpression(statement.Assignment.Value, type); - assignment = boundValue; - - if (type != null) - { - if (boundValue.Type != type) - { - throw new TypeCheckerException(Diagnostic.Error($"{boundValue.Type} is not assignable to {type}").Build()); - } - } - else - { - if (type == null) - { - type = boundValue.Type; - } - } + assignmentNode = CheckExpression(assignment, type); + type ??= assignmentNode.Type; } if (type == null) { - throw new TypeCheckerException(Diagnostic.Error($"Unknown type of variable {statement.Name}").Build()); + throw new TypeCheckerException(Diagnostic.Error($"Cannot infer type of variable {statement.Name}").At(statement).Build()); } Scope.Declare(new Identifier(statement.Name, type, IdentifierKind.Variable)); - return new VariableDeclarationNode(statement.Name, assignment, type, statement.Tokens); + return new VariableDeclarationNode(statement.Name, Optional.OfNullable(assignmentNode), type); } private WhileNode CheckWhile(WhileSyntax statement) { - return new WhileNode(CheckExpression(statement.Condition, new BoolTypeNode()), CheckBlock(statement.Body), statement.Tokens); + throw new NotImplementedException(); + } + + private FuncSignatureNode CheckFuncSignature(FuncSignatureSyntax statement) + { + var parameters = new List(); + foreach (var parameter in statement.Parameters) + { + parameters.Add(new FuncParameterNode(parameter.Name, ResolveType(parameter.Type))); + } + + return new FuncSignatureNode(parameters, ResolveType(statement.ReturnType)); } private ExpressionNode CheckExpression(ExpressionSyntax node, TypeNode? expectedType = null) @@ -293,14 +239,14 @@ public sealed class TypeChecker if (result.Type is StructTypeNode structType && expectedType is InterfaceTypeNode interfaceType) { - return new ConvertToInterfaceNode(interfaceType, interfaceType, structType, result, node.Tokens); + return new ConvertToInterfaceNode(interfaceType, interfaceType, structType, result); } if (result.Type is IntTypeNode sourceIntType && expectedType is IntTypeNode targetIntType) { if (sourceIntType.Signed == targetIntType.Signed && sourceIntType.Width < targetIntType.Width) { - return new ConvertIntNode(targetIntType, result, sourceIntType, targetIntType, node.Tokens); + return new ConvertIntNode(targetIntType, result, sourceIntType, targetIntType); } } @@ -308,245 +254,115 @@ public sealed class TypeChecker { if (sourceFloatType.Width < targetFloatType.Width) { - return new ConvertFloatNode(targetFloatType, result, sourceFloatType, targetFloatType, node.Tokens); + return new ConvertFloatNode(targetFloatType, result, sourceFloatType, targetFloatType); } } - throw new TypeCheckerException(Diagnostic.Error($"Cannot convert {result.Type} to {expectedType}").Build()); + throw new TypeCheckerException(Diagnostic.Error($"Cannot convert {result.Type} to {expectedType}").At(node).Build()); } private AddressOfNode CheckAddressOf(AddressOfSyntax expression) { - var inner = CheckExpression(expression.Expression); - - if (inner is not LValueExpressionNode lValueInner) - { - throw new TypeCheckerException(Diagnostic.Error("Cannot take address of rvalue").Build()); - } - - return new AddressOfNode(new PointerTypeNode(inner.Type), lValueInner, expression.Tokens); + throw new NotImplementedException(); } private ArrayIndexAccessNode CheckArrayIndexAccess(ArrayIndexAccessSyntax expression) { - var boundArray = CheckExpression(expression.Target); - var elementType = ((ArrayTypeNode)boundArray.Type).ElementType; - return new ArrayIndexAccessNode(elementType, boundArray, CheckExpression(expression.Index, new IntTypeNode(false, 64)), expression.Tokens); + throw new NotImplementedException(); } private ArrayInitializerNode CheckArrayInitializer(ArrayInitializerSyntax expression) { - var capacity = CheckExpression(expression.Capacity, new IntTypeNode(false, 64)); - var type = new ArrayTypeNode(CheckType(expression.ElementType)); - return new ArrayInitializerNode(type, capacity, CheckType(expression.ElementType), expression.Tokens); + throw new NotImplementedException(); } private BinaryExpressionNode CheckBinaryExpression(BinaryExpressionSyntax expression) { - var boundLeft = CheckExpression(expression.Left); - var boundRight = CheckExpression(expression.Right, boundLeft.Type); - - var op = expression.Operator switch - { - BinaryOperatorSyntax.Equal => BinaryOperator.Equal, - BinaryOperatorSyntax.NotEqual => BinaryOperator.NotEqual, - BinaryOperatorSyntax.GreaterThan => BinaryOperator.GreaterThan, - BinaryOperatorSyntax.GreaterThanOrEqual => BinaryOperator.GreaterThanOrEqual, - BinaryOperatorSyntax.LessThan => BinaryOperator.LessThan, - BinaryOperatorSyntax.LessThanOrEqual => BinaryOperator.LessThanOrEqual, - BinaryOperatorSyntax.Plus => BinaryOperator.Plus, - BinaryOperatorSyntax.Minus => BinaryOperator.Minus, - BinaryOperatorSyntax.Multiply => BinaryOperator.Multiply, - BinaryOperatorSyntax.Divide => BinaryOperator.Divide, - BinaryOperatorSyntax.Modulo => BinaryOperator.Modulo, - BinaryOperatorSyntax.LeftShift => BinaryOperator.LeftShift, - BinaryOperatorSyntax.RightShift => BinaryOperator.RightShift, - BinaryOperatorSyntax.BitwiseAnd => BinaryOperator.BitwiseAnd, - BinaryOperatorSyntax.BitwiseXor => BinaryOperator.BitwiseXor, - BinaryOperatorSyntax.BitwiseOr => BinaryOperator.BitwiseOr, - BinaryOperatorSyntax.LogicalAnd => BinaryOperator.LogicalAnd, - BinaryOperatorSyntax.LogicalOr => BinaryOperator.LogicalOr, - _ => throw new ArgumentOutOfRangeException(nameof(expression.Operator), expression.Operator, null) - }; - - var resultingType = op switch - { - BinaryOperator.Equal => new BoolTypeNode(), - BinaryOperator.NotEqual => new BoolTypeNode(), - BinaryOperator.GreaterThan => new BoolTypeNode(), - BinaryOperator.GreaterThanOrEqual => new BoolTypeNode(), - BinaryOperator.LessThan => new BoolTypeNode(), - BinaryOperator.LessThanOrEqual => new BoolTypeNode(), - BinaryOperator.LogicalAnd => new BoolTypeNode(), - BinaryOperator.LogicalOr => new BoolTypeNode(), - BinaryOperator.Plus => boundLeft.Type, - BinaryOperator.Minus => boundLeft.Type, - BinaryOperator.Multiply => boundLeft.Type, - BinaryOperator.Divide => boundLeft.Type, - BinaryOperator.Modulo => boundLeft.Type, - BinaryOperator.LeftShift => boundLeft.Type, - BinaryOperator.RightShift => boundLeft.Type, - BinaryOperator.BitwiseAnd => boundLeft.Type, - BinaryOperator.BitwiseXor => boundLeft.Type, - BinaryOperator.BitwiseOr => boundLeft.Type, - _ => throw new ArgumentOutOfRangeException() - }; - - return new BinaryExpressionNode(resultingType, boundLeft, op, boundRight, expression.Tokens); + throw new NotImplementedException(); } private DereferenceNode CheckDereference(DereferenceSyntax expression) { - var boundExpression = CheckExpression(expression.Expression); - var dereferencedType = ((PointerTypeNode)boundExpression.Type).BaseType; - return new DereferenceNode(dereferencedType, boundExpression, expression.Tokens); + throw new NotImplementedException(); } private FuncCallNode CheckFuncCall(FuncCallSyntax expression) { - var boundExpression = CheckExpression(expression.Expression); - - if (boundExpression.Type is not FuncTypeNode funcType) + var accessor = CheckExpression(expression.Expression); + if (accessor.Type is not FuncTypeNode funcType) { - throw new TypeCheckerException(Diagnostic.Error($"Cannot call non-function type {boundExpression.Type}").Build()); + throw new TypeCheckerException(Diagnostic.Error($"Cannot call non-function type {accessor.Type}").At(expression.Expression).Build()); + } + + if (expression.Parameters.Count != funcType.Parameters.Count) + { + throw new TypeCheckerException(Diagnostic.Error($"Function {funcType} expects {funcType.Parameters} but got {expression.Parameters.Count} parameters").At(expression.Expression).Build()); } var parameters = new List(); - - foreach (var (i, parameter) in expression.Parameters.Index()) + for (var i = 0; i < expression.Parameters.Count; i++) { - if (i >= funcType.Parameters.Count) - { - _diagnostics.Add(Diagnostic.Error($"Expected {funcType.Parameters.Count} parameters").Build()); - } - + var parameter = expression.Parameters[i]; var expectedType = funcType.Parameters[i]; - parameters.Add(CheckExpression(parameter, expectedType)); + var parameterExpression = CheckExpression(parameter, expectedType); + if (parameterExpression.Type != expectedType) + { + throw new Exception($"Parameter {i + 1} does not match the type {expectedType} for function {funcType}"); + } + + parameters.Add(parameterExpression); } - return new FuncCallNode(funcType.ReturnType, boundExpression, parameters, expression.Tokens); + return new FuncCallNode(funcType.ReturnType, accessor, parameters); } private ExpressionNode CheckDotFuncCall(DotFuncCallSyntax expression) { - var thisParameter = CheckExpression(expression.ThisParameter); - - if (thisParameter.Type is InterfaceTypeNode interfaceType) - { - var interfaceDefinitions = _definitionTable.LookupInterface(interfaceType.Name).ToArray(); - if (interfaceDefinitions.Length == 0) - { - throw new TypeCheckerException(Diagnostic.Error($"Interface {interfaceType.Name} is not defined").Build()); - } - - if (interfaceDefinitions.Length > 1) - { - throw new TypeCheckerException(Diagnostic.Error($"Interface {interfaceType.Name} has multiple definitions").Build()); - } - - var function = interfaceDefinitions[0].Functions.FirstOrDefault(x => x.Name == expression.Name); - if (function == null) - { - throw new TypeCheckerException(Diagnostic.Error($"Interface {interfaceType.Name} does not have a function with the name {expression.Name}").Build()); - } - - var parameters = new List(); - for (var i = 0; i < expression.Parameters.Count; i++) - { - var parameter = expression.Parameters[i]; - var expectedType = CheckType(function.Signature.Parameters[i].Type); - parameters.Add(CheckExpression(parameter, expectedType)); - } - - var returnType = CheckType(function.Signature.ReturnType); - return new InterfaceFuncCallNode(returnType, expression.Name, interfaceType, thisParameter, parameters, expression.Tokens); - } - - if (thisParameter.Type is StructTypeNode structType) - { - var structDefinitions = _definitionTable.LookupStruct(structType.Name).ToArray(); - if (structDefinitions.Length == 0) - { - throw new TypeCheckerException(Diagnostic.Error($"Struct {structType.Name} is not defined").Build()); - } - - if (structDefinitions.Length > 1) - { - throw new TypeCheckerException(Diagnostic.Error($"Struct {structType.Name} has multiple definitions").Build()); - } - - var function = structDefinitions[0].Functions.FirstOrDefault(x => x.Name == expression.Name); - if (function == null) - { - throw new TypeCheckerException(Diagnostic.Error($"Struct {structType.Name} does not have a function with the name {expression.Name}").Build()); - } - - var parameters = new List(); - for (var i = 0; i < expression.Parameters.Count; i++) - { - var parameter = expression.Parameters[i]; - var expectedType = CheckType(function.Signature.Parameters[i].Type); - parameters.Add(CheckExpression(parameter, expectedType)); - } - - var returnType = CheckType(function.Signature.ReturnType); - return new StructFuncCallNode(returnType, expression.Name, structType, thisParameter, parameters, expression.Tokens); - } - - throw new TypeCheckerException(Diagnostic.Error($"Cannot call dot function on type {thisParameter.Type}").Build()); + throw new NotImplementedException(); } private ExpressionNode CheckIdentifier(IdentifierSyntax expression) { - var identifier = Scope.Lookup(expression.Name); - if (identifier != null) + // If the identifier does not have a module specified, first check if a local variable or function parameter with that identifier exists + if (!expression.Module.TryGetValue(out var moduleName)) { - return identifier.Kind switch + var scopeIdent = Scope.Lookup(expression.Name); + if (scopeIdent != null) { - IdentifierKind.Variable => new VariableIdentNode(identifier.Type, identifier.Name, expression.Tokens), - IdentifierKind.FunctionParameter => new FuncParameterIdentNode(identifier.Type, identifier.Name, expression.Tokens), - _ => throw new ArgumentOutOfRangeException() - }; - } - - var localFuncs = _definitionTable.LookupLocalFunc(expression.Name).ToArray(); - if (localFuncs.Length > 0) - { - if (localFuncs.Length > 1) - { - throw new TypeCheckerException(Diagnostic.Error($"Extern func {expression.Name} has multiple definitions").Build()); + switch (scopeIdent.Kind) + { + case IdentifierKind.Variable: + { + return new VariableIdentifierNode(scopeIdent.Type, expression.Name); + } + case IdentifierKind.FunctionParameter: + { + return new FuncParameterIdentifierNode(scopeIdent.Type, expression.Name); + } + default: + { + throw new ArgumentOutOfRangeException(); + } + } } - - var localFunc = localFuncs[0]; - - var returnType = CheckType(localFunc.Signature.ReturnType); - var parameterTypes = localFunc.Signature.Parameters.Select(p => CheckType(p.Type)).ToList(); - var type = new FuncTypeNode(parameterTypes, returnType); - return new LocalFuncIdentNode(type, expression.Name, expression.Tokens); } - var externFuncs = _definitionTable.LookupExternFunc(expression.Name).ToArray(); - if (externFuncs.Length > 0) + moduleName ??= _currentModule.Name; + if (_moduleSignatures.TryGetValue(moduleName, out var module)) { - if (externFuncs.Length > 1) + if (module.Functions.TryGetValue(expression.Name, out var function)) { - throw new TypeCheckerException(Diagnostic.Error($"Extern func {expression.Name} has multiple definitions").Build()); + return new FuncIdentifierNode(function, moduleName, expression.Name); } - - var externFunc = externFuncs[0]; - - var returnType = CheckType(externFunc.Signature.ReturnType); - var parameterTypes = externFunc.Signature.Parameters.Select(p => CheckType(p.Type)).ToList(); - var type = new FuncTypeNode(parameterTypes, returnType); - return new ExternFuncIdentNode(type, expression.Name, expression.Tokens); } - throw new TypeCheckerException(Diagnostic.Error($"No identifier with the name {expression.Name} exists").Build()); + throw new TypeCheckerException(Diagnostic.Error($"Identifier {expression.Name} not found").At(expression).Build()); } private LiteralNode CheckLiteral(LiteralSyntax expression, TypeNode? expectedType) { + // todo(nub31): Check if the types can actually be represented as another one. For example, an int should be passed when a string is expected var type = expectedType ?? expression.Kind switch { LiteralKind.Integer => new IntTypeNode(true, 64), @@ -556,150 +372,22 @@ public sealed class TypeChecker _ => throw new ArgumentOutOfRangeException() }; - return new LiteralNode(type, expression.Value, expression.Kind, expression.Tokens); + return new LiteralNode(type, expression.Value, expression.Kind); } private StructFieldAccessNode CheckStructFieldAccess(StructFieldAccessSyntax expression) { - var boundExpression = CheckExpression(expression.Target); - - if (boundExpression.Type is not StructTypeNode structType) - { - throw new Exception($"Cannot access struct field on non-struct type {boundExpression.Type}"); - } - - var structs = _definitionTable.LookupStruct(structType.Name).ToArray(); - if (structs.Length > 0) - { - if (structs.Length > 1) - { - throw new TypeCheckerException(Diagnostic.Error($"Struct {structType} has multiple definitions").Build()); - } - - var fields = _definitionTable.LookupStructField(structs[0], expression.Member).ToArray(); - if (fields.Length > 0) - { - if (fields.Length > 1) - { - throw new TypeCheckerException(Diagnostic.Error($"Struct {structType} has multiple fields with the name {expression.Member}").Build()); - } - - var field = fields[0]; - - return new StructFieldAccessNode(CheckType(field.Type), structType, boundExpression, expression.Member, expression.Tokens); - } - } - - throw new TypeCheckerException(Diagnostic.Error($"{boundExpression.Type} does not have a member with the name {expression.Member}").Build()); + throw new NotImplementedException(); } private StructInitializerNode CheckStructInitializer(StructInitializerSyntax expression, TypeNode? expectedType) { - var type = expectedType; - if (expression.StructType.HasValue) - { - type = CheckType(expression.StructType.Value); - } - - if (type == null) - { - throw new TypeCheckerException(Diagnostic.Error("Cannot determine type of struct").Build()); - } - - if (type is not StructTypeNode structType) - { - throw new TypeCheckerException(Diagnostic.Error($"Cannot initialize non-struct type {expression.StructType}").Build()); - } - - var structs = _definitionTable.LookupStruct(structType.Name).ToArray(); - - if (structs.Length == 0) - { - throw new TypeCheckerException(Diagnostic.Error($"Struct {structType} is not defined").Build()); - } - - if (structs.Length > 1) - { - throw new TypeCheckerException(Diagnostic.Error($"Struct {structType} has multiple definitions").Build()); - } - - var @struct = structs[0]; - - var initializers = new Dictionary(); - - foreach (var (field, initializer) in expression.Initializers) - { - var fields = _definitionTable.LookupStructField(@struct, field).ToArray(); - - if (fields.Length == 0) - { - throw new TypeCheckerException(Diagnostic.Error($"Struct {structType} does not have a field with the name {field}").Build()); - } - - if (fields.Length > 1) - { - throw new TypeCheckerException(Diagnostic.Error($"Struct {structType} has multiple fields with the name {field}").Build()); - } - - initializers[field] = CheckExpression(initializer, CheckType(fields[0].Type)); - } - - return new StructInitializerNode(structType, initializers, expression.Tokens); + throw new NotImplementedException(); } private UnaryExpressionNode CheckUnaryExpression(UnaryExpressionSyntax expression) { - var boundOperand = CheckExpression(expression.Operand); - - TypeNode? type = null; - - switch (expression.Operator) - { - case UnaryOperatorSyntax.Negate: - { - boundOperand = CheckExpression(expression.Operand, new IntTypeNode(true, 64)); - - if (boundOperand.Type is IntTypeNode or FloatTypeNode) - { - type = boundOperand.Type; - } - - break; - } - case UnaryOperatorSyntax.Invert: - { - boundOperand = CheckExpression(expression.Operand, new BoolTypeNode()); - - type = new BoolTypeNode(); - break; - } - } - - if (type == null) - { - throw new TypeCheckerException(Diagnostic.Error($"Cannot perform unary operation {expression.Operand} on type {boundOperand.Type}").Build()); - } - - var op = expression.Operator switch - { - UnaryOperatorSyntax.Negate => UnaryOperator.Negate, - UnaryOperatorSyntax.Invert => UnaryOperator.Invert, - _ => throw new ArgumentOutOfRangeException(nameof(expression.Operator), expression.Operator, null) - }; - - return new UnaryExpressionNode(type, op, boundOperand, expression.Tokens); - } - - private FuncSignatureNode CheckFuncSignature(FuncSignatureSyntax node) - { - var parameters = new List(); - - foreach (var parameter in node.Parameters) - { - parameters.Add(new FuncParameterNode(parameter.Name, CheckType(parameter.Type), parameter.Tokens)); - } - - return new FuncSignatureNode(parameters, CheckType(node.ReturnType), node.Tokens); + throw new NotImplementedException(); } private BlockNode CheckBlock(BlockSyntax node, Scope? scope = null) @@ -715,105 +403,12 @@ public sealed class TypeChecker _scopes.Pop(); - return new BlockNode(statements, node.Tokens); + return new BlockNode(statements); } - private TypeNode CheckType(TypeSyntax node) + private TypeNode ResolveType(TypeSyntax fieldType) { - return node switch - { - ArrayTypeSyntax type => new ArrayTypeNode(CheckType(type.BaseType)), - BoolTypeSyntax => new BoolTypeNode(), - CStringTypeSyntax => new CStringTypeNode(), - CustomTypeSyntax type => CheckCustomType(type), - FloatTypeSyntax @float => new FloatTypeNode(@float.Width), - FuncTypeSyntax type => new FuncTypeNode(type.Parameters.Select(CheckType).ToList(), CheckType(type.ReturnType)), - IntTypeSyntax @int => new IntTypeNode(@int.Signed, @int.Width), - PointerTypeSyntax type => new PointerTypeNode(CheckType(type.BaseType)), - StringTypeSyntax => new StringTypeNode(), - VoidTypeSyntax => new VoidTypeNode(), - _ => throw new ArgumentOutOfRangeException(nameof(node)) - }; - } - - private TypeNode CheckCustomType(CustomTypeSyntax type) - { - var structs = _definitionTable.LookupStruct(type.Name).ToArray(); - if (structs.Length > 0) - { - if (structs.Length > 1) - { - throw new TypeCheckerException(Diagnostic.Error($"Struct {type.Name} has multiple definitions").Build()); - } - - return GetStructType(structs[0]); - } - - var interfaces = _definitionTable.LookupInterface(type.Name).ToArray(); - if (interfaces.Length > 0) - { - if (interfaces.Length > 1) - { - throw new TypeCheckerException(Diagnostic.Error($"Interface {type.Name} has multiple definitions").Build()); - } - - return GetInterfaceType(interfaces[0]); - } - - throw new TypeCheckerException(Diagnostic.Error($"Type {type.Name} is not defined").Build()); - } - - private StructTypeNode GetStructType(StructSyntax structDef) - { - if (_typeCache.TryGetValue(structDef.Name, out var cachedType)) - { - return (StructTypeNode)cachedType; - } - - var result = new StructTypeNode(structDef.Name, [], [], []); - _typeCache.Add(structDef.Name, result); - - var fields = structDef.Fields.Select(x => CheckType(x.Type)).ToList(); - - var funcs = structDef.Functions - .Select(x => new FuncTypeNode(x.Signature.Parameters.Select(p => CheckType(p.Type)).ToList(), CheckType(x.Signature.ReturnType))) - .ToList(); - - var interfaceImplementations = new List(); - - foreach (var structInterfaceImplementation in structDef.InterfaceImplementations) - { - var checkedInterfaceType = CheckType(structInterfaceImplementation); - if (checkedInterfaceType is not InterfaceTypeNode interfaceType) - { - throw new TypeCheckerException(Diagnostic.Error($"{structDef.Name} cannot implement non-interface type {checkedInterfaceType}").Build()); - } - - interfaceImplementations.Add(interfaceType); - } - - result.Fields = fields; - result.Functions = funcs; - result.InterfaceImplementations = interfaceImplementations; - return result; - } - - private InterfaceTypeNode GetInterfaceType(InterfaceSyntax interfaceDef) - { - if (_typeCache.TryGetValue(interfaceDef.Name, out var cachedType)) - { - return (InterfaceTypeNode)cachedType; - } - - var result = new InterfaceTypeNode(interfaceDef.Name, []); - _typeCache.Add(interfaceDef.Name, result); - - var functions = interfaceDef.Functions - .Select(x => new FuncTypeNode(x.Signature.Parameters.Select(y => CheckType(y.Type)).ToList(), CheckType(x.Signature.ReturnType))) - .ToList(); - - result.Functions = functions; - return result; + return TypeResolver.ResolveType(fieldType, _moduleSignatures); } } diff --git a/src/compiler/NubLang/TypeResolver.cs b/src/compiler/NubLang/TypeResolver.cs new file mode 100644 index 0000000..7e49649 --- /dev/null +++ b/src/compiler/NubLang/TypeResolver.cs @@ -0,0 +1,48 @@ +using NubLang.Parsing.Syntax; +using NubLang.TypeChecking; +using NubLang.TypeChecking.Node; + +namespace NubLang; + +public static class TypeResolver +{ + public static TypeNode ResolveType(TypeSyntax type, IReadOnlyDictionary modules) + { + return type switch + { + BoolTypeSyntax => new BoolTypeNode(), + CStringTypeSyntax => new CStringTypeNode(), + IntTypeSyntax i => new IntTypeNode(i.Signed, i.Width), + CustomTypeSyntax c => ResolveCustomType(c.Module, c.Name, modules), + FloatTypeSyntax f => new FloatTypeNode(f.Width), + FuncTypeSyntax func => new FuncTypeNode(func.Parameters.Select(x => ResolveType(x, modules)).ToList(), ResolveType(func.ReturnType, modules)), + ArrayTypeSyntax arr => new ArrayTypeNode(ResolveType(arr.BaseType, modules)), + PointerTypeSyntax ptr => new PointerTypeNode(ResolveType(ptr.BaseType, modules)), + StringTypeSyntax => new StringTypeNode(), + VoidTypeSyntax => new VoidTypeNode(), + _ => throw new NotSupportedException($"Unknown type syntax: {type}") + }; + } + + public static TypeNode ResolveCustomType(string moduleName, string typeName, IReadOnlyDictionary modules) + { + if (!modules.TryGetValue(moduleName, out var module)) + { + throw new Exception("Module not found: " + moduleName); + } + + var structType = module.StructTypes.FirstOrDefault(x => x.Name == typeName); + if (structType != null) + { + return structType; + } + + var interfaceType = module.InterfaceTypes.FirstOrDefault(x => x.Name == typeName); + if (interfaceType != null) + { + return interfaceType; + } + + throw new Exception("Type not found: " + typeName); + } +} \ No newline at end of file