diff --git a/compiler/NubLang/Ast/Node.cs b/compiler/NubLang/Ast/Node.cs index 9cffe4c..5b48e56 100644 --- a/compiler/NubLang/Ast/Node.cs +++ b/compiler/NubLang/Ast/Node.cs @@ -632,4 +632,12 @@ public class ConstArrayInitializerNode(List tokens, NubType type, List tokens, NubType type) : ExpressionNode(tokens, type) +{ + public override IEnumerable Children() + { + return []; + } +} + #endregion \ No newline at end of file diff --git a/compiler/NubLang/Ast/TypeChecker.cs b/compiler/NubLang/Ast/TypeChecker.cs index dfe380a..3d6678b 100644 --- a/compiler/NubLang/Ast/TypeChecker.cs +++ b/compiler/NubLang/Ast/TypeChecker.cs @@ -837,6 +837,11 @@ public sealed class TypeChecker return new ModuleFuncIdentifierNode(expression.Tokens, type, expression.ModuleToken, expression.NameToken, function.ExternSymbolToken); } + if (module.TryResolveEnum(expression.NameToken, out var enumType, out var _)) + { + return new EnumTypeIntermediateNode(expression.Tokens, enumType); + } + throw new CompileException(Diagnostic .Error($"Module {expression.ModuleToken.Value} does not export a member named {expression.NameToken.Value}") .At(expression, _syntaxTree.Tokens) @@ -920,6 +925,29 @@ public sealed class TypeChecker return new StructFieldAccessNode(expression.Tokens, field.Type, target, expression.MemberToken); } + case NubEnumType enumType: + { + if (!enumType.Members.TryGetValue(expression.MemberToken.Value, out var value)) + { + throw new CompileException(Diagnostic + .Error($"Enum {target.Type} does not have a member with the name {expression.MemberToken.Value}") + .At(expression, _syntaxTree.Tokens) + .Build()); + } + + return (enumType.UnderlyingType.Width, enumType.UnderlyingType.Signed) switch + { + (8, false) => new U8LiteralNode(expression.Tokens, (byte)value), + (16, false) => new U16LiteralNode(expression.Tokens, (ushort)value), + (32, false) => new U32LiteralNode(expression.Tokens, (uint)value), + (64, false) => new U64LiteralNode(expression.Tokens, value), + (8, true) => new I8LiteralNode(expression.Tokens, (sbyte)value), + (16, true) => new I16LiteralNode(expression.Tokens, (short)value), + (32, true) => new I32LiteralNode(expression.Tokens, (int)value), + (64, true) => new I64LiteralNode(expression.Tokens, (long)value), + _ => throw new ArgumentOutOfRangeException() + }; + } default: { throw new CompileException(Diagnostic @@ -1062,7 +1090,7 @@ public sealed class TypeChecker return structType; } - var enumType = module.EnumTypes.GetValueOrDefault(customType.NameToken.Value); + var enumType = module.EnumTypes.FirstOrDefault(x => x.Name == customType.NameToken.Value); if (enumType != null) { return enumType; diff --git a/compiler/NubLang/Generation/LlvmSharpGenerator.cs b/compiler/NubLang/Generation/LlvmSharpGenerator.cs index fb2a071..19a8bf2 100644 --- a/compiler/NubLang/Generation/LlvmSharpGenerator.cs +++ b/compiler/NubLang/Generation/LlvmSharpGenerator.cs @@ -16,6 +16,9 @@ public class LlvmSharpGenerator private readonly Dictionary _functions = new(); private readonly Dictionary _locals = new(); private readonly Stack<(LLVMBasicBlockRef breakBlock, LLVMBasicBlockRef continueBlock)> _loopStack = new(); + private readonly Stack _scopes = new(); + + private Scope CurrentScope => _scopes.Peek(); public void Emit(List topLevelNodes, ModuleRepository repository, string sourceFileName, string outputPath) { @@ -32,50 +35,97 @@ public class LlvmSharpGenerator _functions.Clear(); _locals.Clear(); _loopStack.Clear(); + _scopes.Clear(); var stringType = _context.CreateNamedStruct("nub.string"); stringType.StructSetBody([LLVMTypeRef.Int64, LLVMTypeRef.CreatePointer(LLVMTypeRef.Int8, 0)], false); _structTypes["nub.string"] = stringType; + // note(nub31): Declare all structs and functions foreach (var module in repository.GetAll()) { foreach (var structType in module.StructTypes) { - var structName = StructName(structType.Module, structType.Name); - var llvmStructType = _context.CreateNamedStruct(structName); - _structTypes[structName] = llvmStructType; - } - } + var llvmStructType = _context.CreateNamedStruct(StructName(structType.Module, structType.Name)); + llvmStructType.StructSetBody(structType.Fields.Select(f => MapType(f.Type)).ToArray(), structType.Packed); + _structTypes[StructName(structType.Module, structType.Name)] = llvmStructType; - foreach (var module in repository.GetAll()) - { - foreach (var structType in module.StructTypes) - { - var structName = StructName(structType.Module, structType.Name); - var llvmStructType = _structTypes[structName]; - var fieldTypes = structType.Fields.Select(f => MapType(f.Type)).ToArray(); - llvmStructType.StructSetBody(fieldTypes, false); - } - } + var constructorType = LLVMTypeRef.CreateFunction(LLVMTypeRef.Void, [LLVMTypeRef.CreatePointer(llvmStructType, 0)]); + var constructor = _llvmModule.AddFunction(StructConstructorName(structType.Module, structType.Name), constructorType); + + _functions[StructConstructorName(structType.Module, structType.Name)] = constructor; + } - foreach (var module in repository.GetAll()) - { foreach (var prototype in module.FunctionPrototypes) { - CreateFunctionDeclaration(prototype, module.Name); + var funcName = FuncName(module.Name, prototype.NameToken.Value, prototype.ExternSymbolToken?.Value); + + var paramTypes = prototype.Parameters.Select(p => MapType(p.Type)).ToArray(); + var funcType = LLVMTypeRef.CreateFunction(MapType(prototype.ReturnType), paramTypes); + var func = _llvmModule.AddFunction(funcName, funcType); + + func.FunctionCallConv = (uint)LLVMCallConv.LLVMCCallConv; + + _functions[funcName] = func; } } + // note(nub31): Define struct constructors foreach (var structNode in topLevelNodes.OfType()) { - EmitStructConstructor(structNode); + var structType = _structTypes[StructName(_module, structNode.NameToken.Value)]; + var constructor = _functions[StructConstructorName(_module, structNode.NameToken.Value)]; + + var entryBlock = constructor.AppendBasicBlock("entry"); + _builder.PositionAtEnd(entryBlock); + + var selfParam = constructor.GetParam(0); + selfParam.Name = "self"; + + _locals.Clear(); + + foreach (var field in structNode.Fields) + { + if (field.Value == null) continue; + + var index = structNode.StructType.GetFieldIndex(field.NameToken.Value); + var fieldPtr = _builder.BuildStructGEP2(structType, selfParam, (uint)index); + EmitExpressionInto(field.Value, fieldPtr); + } + + _builder.BuildRetVoid(); } + // note(nub31): Define function bodies foreach (var funcNode in topLevelNodes.OfType()) { - if (funcNode.Body != null) + if (funcNode.Body == null) continue; + + var funcName = FuncName(_module, funcNode.Prototype.NameToken.Value, funcNode.Prototype.ExternSymbolToken?.Value); + var func = _functions[funcName]; + + var entryBlock = func.AppendBasicBlock("entry"); + _builder.PositionAtEnd(entryBlock); + + _locals.Clear(); + + for (uint i = 0; i < funcNode.Prototype.Parameters.Count; i++) { - EmitFunction(funcNode); + var param = func.GetParam(i); + var paramNode = funcNode.Prototype.Parameters[(int)i]; + var alloca = _builder.BuildAlloca(MapType(paramNode.Type), paramNode.NameToken.Value); + _builder.BuildStore(param, alloca); + _locals[paramNode.NameToken.Value] = alloca; + } + + EmitBlock(funcNode.Body); + + if (funcNode.Prototype.ReturnType is NubVoidType) + { + if (_builder.InsertBlock.Terminator.Handle == IntPtr.Zero) + { + _builder.BuildRetVoid(); + } } } @@ -89,94 +139,16 @@ public class LlvmSharpGenerator _builder.Dispose(); } - private void CreateFunctionDeclaration(FuncPrototypeNode prototype, string moduleName) - { - var funcName = FuncName(moduleName, prototype.NameToken.Value, prototype.ExternSymbolToken?.Value); - - var paramTypes = prototype.Parameters.Select(p => MapType(p.Type)).ToArray(); - var returnType = MapType(prototype.ReturnType); - - var funcType = LLVMTypeRef.CreateFunction(returnType, paramTypes); - var func = _llvmModule.AddFunction(funcName, funcType); - - func.FunctionCallConv = (uint)LLVMCallConv.LLVMCCallConv; - - for (var i = 0; i < prototype.Parameters.Count; i++) - { - func.GetParam((uint)i).Name = prototype.Parameters[i].NameToken.Value; - } - - _functions[funcName] = func; - } - - private void EmitStructConstructor(StructNode structNode) - { - var structType = _structTypes[StructName(_module, structNode.NameToken.Value)]; - var ptrType = LLVMTypeRef.CreatePointer(structType, 0); - - var funcType = LLVMTypeRef.CreateFunction(LLVMTypeRef.Void, [ptrType]); - var funcName = StructConstructorName(_module, structNode.NameToken.Value); - var func = _llvmModule.AddFunction(funcName, funcType); - func.FunctionCallConv = (uint)LLVMCallConv.LLVMCCallConv; - - var entryBlock = func.AppendBasicBlock("entry"); - _builder.PositionAtEnd(entryBlock); - - var selfParam = func.GetParam(0); - selfParam.Name = "self"; - - _locals.Clear(); - - foreach (var field in structNode.Fields) - { - if (field.Value != null) - { - var index = structNode.StructType.GetFieldIndex(field.NameToken.Value); - var fieldPtr = _builder.BuildStructGEP2(structType, selfParam, (uint)index); - EmitExpressionInto(field.Value, fieldPtr); - } - } - - _builder.BuildRetVoid(); - _functions[funcName] = func; - } - - private void EmitFunction(FuncNode funcNode) - { - var funcName = FuncName(_module, funcNode.Prototype.NameToken.Value, funcNode.Prototype.ExternSymbolToken?.Value); - var func = _functions[funcName]; - - var entryBlock = func.AppendBasicBlock("entry"); - _builder.PositionAtEnd(entryBlock); - - _locals.Clear(); - - for (uint i = 0; i < funcNode.Prototype.Parameters.Count; i++) - { - var param = func.GetParam(i); - var paramNode = funcNode.Prototype.Parameters[(int)i]; - var alloca = _builder.BuildAlloca(MapType(paramNode.Type), paramNode.NameToken.Value); - _builder.BuildStore(param, alloca); - _locals[paramNode.NameToken.Value] = alloca; - } - - EmitBlock(funcNode.Body!); - - if (funcNode.Prototype.ReturnType is NubVoidType) - { - if (_builder.InsertBlock.Terminator.Handle == IntPtr.Zero) - { - _builder.BuildRetVoid(); - } - } - } - private void EmitBlock(BlockNode blockNode) { + _scopes.Push(new Scope()); foreach (var statement in blockNode.Statements) { EmitStatement(statement); } + + EmitScopeExit(); + _scopes.Pop(); } private void EmitStatement(StatementNode statement) @@ -195,6 +167,13 @@ public class LlvmSharpGenerator case ContinueNode: EmitContinue(); break; + case DeferNode deferNode: + CurrentScope.Defer(() => EmitStatement(deferNode.Statement)); + break; + case ForConstArrayNode forConstArrayNode: + throw new NotImplementedException(); + case ForSliceNode forSliceNode: + throw new NotImplementedException(); case IfNode ifNode: EmitIf(ifNode); break; @@ -211,7 +190,7 @@ public class LlvmSharpGenerator EmitWhile(whileNode); break; default: - throw new NotImplementedException($"Statement type {statement.GetType()} not implemented"); + throw new ArgumentOutOfRangeException(nameof(statement)); } } @@ -270,10 +249,12 @@ public class LlvmSharpGenerator if (returnNode.Value != null) { var value = EmitExpression(returnNode.Value); + EmitScopeExit(); _builder.BuildRet(value); } else { + EmitScopeExit(); _builder.BuildRetVoid(); } } @@ -512,7 +493,10 @@ public class LlvmSharpGenerator { var funcPtr = EmitExpression(funcCall.Expression); var args = funcCall.Parameters.Select(x => EmitExpression(x)).ToArray(); - return _builder.BuildCall2(MapType(funcCall.Expression.Type), funcPtr, args, funcCall.Type is NubVoidType ? "" : "call"); + + var functionType = (NubFuncType)funcCall.Expression.Type; + var llvmFunctionType = LLVMTypeRef.CreateFunction(MapType(functionType.ReturnType), functionType.Parameters.Select(MapType).ToArray()); + return _builder.BuildCall2(llvmFunctionType, funcPtr, args, funcCall.Type is NubVoidType ? "" : "call"); } private LLVMValueRef EmitStructFieldAccess(StructFieldAccessNode field) @@ -733,13 +717,14 @@ public class LlvmSharpGenerator NubBoolType => LLVMTypeRef.Int1, NubIntType intType => LLVMTypeRef.CreateInt((uint)intType.Width), NubFloatType floatType => floatType.Width == 32 ? LLVMTypeRef.Float : LLVMTypeRef.Double, - NubFuncType funcType => LLVMTypeRef.CreateFunction(MapType(funcType.ReturnType), funcType.Parameters.Select(MapType).ToArray()), + NubFuncType funcType => LLVMTypeRef.CreatePointer(LLVMTypeRef.CreateFunction(MapType(funcType.ReturnType), funcType.Parameters.Select(MapType).ToArray()), 0), NubPointerType ptrType => LLVMTypeRef.CreatePointer(MapType(ptrType.BaseType), 0), NubSliceType nubSliceType => MapSliceType(nubSliceType), NubStringType => _structTypes["nub.string"], NubArrayType arrType => LLVMTypeRef.CreatePointer(MapType(arrType.ElementType), 0), NubConstArrayType constArr => LLVMTypeRef.CreateArray(MapType(constArr.ElementType), (uint)constArr.Size), NubStructType structType => _structTypes[StructName(structType.Module, structType.Name)], + NubEnumType enumType => MapType(enumType.UnderlyingType), NubVoidType => LLVMTypeRef.Void, _ => throw new ArgumentOutOfRangeException(nameof(type), type, null) }; @@ -778,4 +763,28 @@ public class LlvmSharpGenerator return $"{module}.{name}"; } + + private void EmitScopeExit() + { + var deferredActions = CurrentScope.GetDeferredActions(); + while (deferredActions.TryPop(out var action)) + { + action.Invoke(); + } + } + + private class Scope + { + private readonly Stack _deferredActions = []; + + public Stack GetDeferredActions() + { + return _deferredActions; + } + + public void Defer(Action action) + { + _deferredActions.Push(action); + } + } } \ No newline at end of file diff --git a/compiler/NubLang/Modules/ModuleRepository.cs b/compiler/NubLang/Modules/ModuleRepository.cs index f3e2706..70ab75e 100644 --- a/compiler/NubLang/Modules/ModuleRepository.cs +++ b/compiler/NubLang/Modules/ModuleRepository.cs @@ -11,7 +11,7 @@ public sealed class ModuleRepository public static ModuleRepository Create(List syntaxTrees) { var structTypes = new Dictionary<(string module, string name), NubStructType>(); - var enumTypes = new Dictionary<(string module, string name), NubIntType>(); + var enumTypes = new Dictionary<(string module, string name), NubEnumType>(); foreach (var syntaxTree in syntaxTrees) { @@ -44,11 +44,27 @@ public sealed class ModuleRepository underlyingType ??= new NubIntType(false, 64); var key = (module.NameToken.Value, enumSyntax.NameToken.Value); - enumTypes.Add(key, underlyingType); + + var memberValues = new Dictionary(); + + ulong currentValue = 0; + + foreach (var member in enumSyntax.Members) + { + if (member.ValueToken != null) + { + currentValue = member.ValueToken.AsU64; + } + + memberValues[member.NameToken.Value] = currentValue; + currentValue++; + } + + enumTypes.Add(key, new NubEnumType(module.NameToken.Value, enumSyntax.NameToken.Value, underlyingType, memberValues)); } } - // note(nub31): Since all struct types are now registered, we can safely resolve the field types + // note(nub31): Since all struct and enum types are now registered, we can safely resolve the field types foreach (var syntaxTree in syntaxTrees) { var module = syntaxTree.TopLevelSyntaxNodes.OfType().FirstOrDefault(); @@ -90,9 +106,7 @@ public sealed class ModuleRepository { Name = moduleDecl.NameToken.Value, StructTypes = structTypes.Where(x => x.Key.module == moduleDecl.NameToken.Value).Select(x => x.Value).ToList(), - EnumTypes = enumTypes - .Where(x => x.Key.module == moduleDecl.NameToken.Value) - .ToDictionary(x => x.Key.name, x => x.Value), + EnumTypes = enumTypes.Where(x => x.Key.module == moduleDecl.NameToken.Value).Select(x => x.Value).ToList(), FunctionPrototypes = functionPrototypes }; @@ -167,12 +181,6 @@ public sealed class ModuleRepository return module != null; } - public bool TryGet(string name, [NotNullWhen(true)] out Module? module) - { - module = _modules.GetValueOrDefault(name); - return module != null; - } - public List GetAll() { return _modules.Values.ToList(); @@ -183,22 +191,7 @@ public sealed class ModuleRepository public required string Name { get; init; } public required List FunctionPrototypes { get; init; } = []; public required List StructTypes { get; init; } = []; - public required Dictionary EnumTypes { get; init; } = []; - - public bool TryResolveFunc(string name, [NotNullWhen(true)] out FuncPrototypeNode? value, [NotNullWhen(false)] out Diagnostic? diagnostic) - { - value = FunctionPrototypes.FirstOrDefault(x => x.NameToken.Value == name); - - if (value == null) - { - value = null; - diagnostic = Diagnostic.Error($"Func {name} not found in module {Name}").Build(); - return false; - } - - diagnostic = null; - return true; - } + public required List EnumTypes { get; init; } = []; public bool TryResolveFunc(IdentifierToken name, [NotNullWhen(true)] out FuncPrototypeNode? value, [NotNullWhen(false)] out Diagnostic? diagnostic) { @@ -215,31 +208,6 @@ public sealed class ModuleRepository return true; } - public FuncPrototypeNode ResolveFunc(IdentifierToken name) - { - if (!TryResolveFunc(name, out var value, out var diagnostic)) - { - throw new CompileException(diagnostic); - } - - return value; - } - - public bool TryResolveStruct(string name, [NotNullWhen(true)] out NubStructType? value, [NotNullWhen(false)] out Diagnostic? diagnostic) - { - value = StructTypes.FirstOrDefault(x => x.Name == name); - - if (value == null) - { - value = null; - diagnostic = Diagnostic.Error($"Struct {name} not found in module {Name}").Build(); - return false; - } - - diagnostic = null; - return true; - } - public bool TryResolveStruct(IdentifierToken name, [NotNullWhen(true)] out NubStructType? value, [NotNullWhen(false)] out Diagnostic? diagnostic) { value = StructTypes.FirstOrDefault(x => x.Name == name.Value); @@ -255,34 +223,9 @@ public sealed class ModuleRepository return true; } - public NubStructType ResolveStruct(IdentifierToken name) + public bool TryResolveEnum(IdentifierToken name, [NotNullWhen(true)] out NubEnumType? value, [NotNullWhen(false)] out Diagnostic? diagnostic) { - if (!TryResolveStruct(name, out var value, out var diagnostic)) - { - throw new CompileException(diagnostic); - } - - return value; - } - - public bool TryResolveEnum(string name, [NotNullWhen(true)] out NubIntType? value, [NotNullWhen(false)] out Diagnostic? diagnostic) - { - value = EnumTypes.GetValueOrDefault(name); - - if (value == null) - { - value = null; - diagnostic = Diagnostic.Error($"Enum {name} not found in module {Name}").Build(); - return false; - } - - diagnostic = null; - return true; - } - - public bool TryResolveEnum(IdentifierToken name, [NotNullWhen(true)] out NubIntType? value, [NotNullWhen(false)] out Diagnostic? diagnostic) - { - value = EnumTypes.GetValueOrDefault(name.Value); + value = EnumTypes.FirstOrDefault(x => x.Name == name.Value); if (value == null) { @@ -294,15 +237,5 @@ public sealed class ModuleRepository diagnostic = null; return true; } - - public NubIntType ResolveEnum(IdentifierToken name) - { - if (!TryResolveEnum(name, out var value, out var diagnostic)) - { - throw new CompileException(diagnostic); - } - - return value; - } } } \ No newline at end of file diff --git a/compiler/NubLang/Syntax/Parser.cs b/compiler/NubLang/Syntax/Parser.cs index b5c5873..a5c4fc5 100644 --- a/compiler/NubLang/Syntax/Parser.cs +++ b/compiler/NubLang/Syntax/Parser.cs @@ -174,7 +174,7 @@ public sealed class Parser type = ParseType(); } - List fields = []; + List fields = []; ExpectSymbol(Symbol.OpenBrace); @@ -197,7 +197,7 @@ public sealed class Parser value = intLiteralToken; } - fields.Add(new EnumFieldSyntax(GetTokens(memberStartIndex), fieldName, value)); + fields.Add(new EnumMemberSyntax(GetTokens(memberStartIndex), fieldName, value)); } return new EnumSyntax(GetTokens(startIndex), name, exported, type, fields); diff --git a/compiler/NubLang/Syntax/Syntax.cs b/compiler/NubLang/Syntax/Syntax.cs index 0c70250..add9fc6 100644 --- a/compiler/NubLang/Syntax/Syntax.cs +++ b/compiler/NubLang/Syntax/Syntax.cs @@ -20,9 +20,9 @@ public record StructFieldSyntax(List Tokens, IdentifierToken NameToken, T public record StructSyntax(List Tokens, IdentifierToken NameToken, bool Exported, bool Packed, List Fields) : DefinitionSyntax(Tokens, NameToken, Exported); -public record EnumFieldSyntax(List Tokens, IdentifierToken NameToken, IntLiteralToken? ValueToken) : SyntaxNode(Tokens); +public record EnumMemberSyntax(List Tokens, IdentifierToken NameToken, IntLiteralToken? ValueToken) : SyntaxNode(Tokens); -public record EnumSyntax(List Tokens, IdentifierToken NameToken, bool Exported, TypeSyntax? Type, List Fields) : DefinitionSyntax(Tokens, NameToken, Exported); +public record EnumSyntax(List Tokens, IdentifierToken NameToken, bool Exported, TypeSyntax? Type, List Members) : DefinitionSyntax(Tokens, NameToken, Exported); public enum UnaryOperatorSyntax { diff --git a/compiler/NubLang/Types/NubType.cs b/compiler/NubLang/Types/NubType.cs index de99b2e..0a74bdf 100644 --- a/compiler/NubLang/Types/NubType.cs +++ b/compiler/NubLang/Types/NubType.cs @@ -183,6 +183,22 @@ public class NubStructFieldType(string name, NubType type, bool hasDefaultValue) public bool HasDefaultValue { get; } = hasDefaultValue; } +public class NubEnumType(string module, string name, NubIntType underlyingType, Dictionary members) : NubType +{ + public string Module { get; } = module; + public string Name { get; } = name; + public NubIntType UnderlyingType { get; } = underlyingType; + public Dictionary Members { get; } = members; + + public override ulong GetSize() => UnderlyingType.GetSize(); + public override ulong GetAlignment() => UnderlyingType.GetSize(); + public override bool IsAggregate() => false; + + public override bool Equals(NubType? other) => other is NubEnumType enumType && Name == enumType.Name && Module == enumType.Module; + public override int GetHashCode() => HashCode.Combine(typeof(NubEnumType), Module, Name); + public override string ToString() => $"{Module}::{Name}"; +} + public class NubSliceType(NubType elementType) : NubType { public NubType ElementType { get; } = elementType;