diff --git a/compiler/NubLang.CLI/Program.cs b/compiler/NubLang.CLI/Program.cs index adcebad..9c5a563 100644 --- a/compiler/NubLang.CLI/Program.cs +++ b/compiler/NubLang.CLI/Program.cs @@ -9,7 +9,7 @@ var syntaxTrees = new List(); var tokenizer = new Tokenizer(); var parser = new Parser(); -var generator = new LlvmGenerator(); +var generator = new LlvmSharpGenerator(); foreach (var file in args) { @@ -83,7 +83,7 @@ for (var i = 0; i < args.Length; i++) } var path = Path.Combine(".build", Path.ChangeExtension(file, "ll")); - File.WriteAllText(path, generator.Emit(compilationUnit, moduleRepository)); + generator.Emit(compilationUnit, moduleRepository, file, path); } return 0; \ No newline at end of file diff --git a/compiler/NubLang/Ast/Node.cs b/compiler/NubLang/Ast/Node.cs index 2e48c0a..9cffe4c 100644 --- a/compiler/NubLang/Ast/Node.cs +++ b/compiler/NubLang/Ast/Node.cs @@ -596,6 +596,8 @@ public class CastNode(List tokens, NubType type, ExpressionNode value, Ca ConstArrayToArray, ConstArrayToSlice, + + StringToCString } public ExpressionNode Value { get; } = value; @@ -607,7 +609,7 @@ public class CastNode(List tokens, NubType type, ExpressionNode value, Ca } } -public class StructInitializerNode(List tokens, NubType type, Dictionary initializers) : RValue(tokens, type) +public class StructInitializerNode(List tokens, NubType type, Dictionary initializers) : LValue(tokens, type) { public Dictionary Initializers { get; } = initializers; @@ -620,7 +622,7 @@ public class StructInitializerNode(List tokens, NubType type, Dictionary< } } -public class ConstArrayInitializerNode(List tokens, NubType type, List values) : RValue(tokens, type) +public class ConstArrayInitializerNode(List tokens, NubType type, List values) : LValue(tokens, type) { public List Values { get; } = values; diff --git a/compiler/NubLang/Ast/TypeChecker.cs b/compiler/NubLang/Ast/TypeChecker.cs index 8b96de6..dfe380a 100644 --- a/compiler/NubLang/Ast/TypeChecker.cs +++ b/compiler/NubLang/Ast/TypeChecker.cs @@ -389,6 +389,11 @@ public sealed class TypeChecker conversion = CastNode.Conversion.ConstArrayToSlice; return true; } + case NubStringType when to is NubPointerType { BaseType: NubIntType { Signed: true, Width: 8 } }: + { + conversion = CastNode.Conversion.StringToCString; + return true; + } } if (!strict) diff --git a/compiler/NubLang/Generation/LlvmGenerator.cs b/compiler/NubLang/Generation/LlvmGenerator.cs deleted file mode 100644 index 42d78bf..0000000 --- a/compiler/NubLang/Generation/LlvmGenerator.cs +++ /dev/null @@ -1,1085 +0,0 @@ -using System.Diagnostics; -using System.Text; -using NubLang.Ast; -using NubLang.Modules; -using NubLang.Types; - -namespace NubLang.Generation; - -public class LlvmGenerator -{ - private string _module = string.Empty; - private int _tmpIndex; - private int _labelIndex; - private List<(string Name, int Size, string Text)> _stringLiterals = []; - private Stack<(string breakLabel, string continueLabel)> _loopStack = []; - - public string Emit(List topLevelNodes, ModuleRepository repository) - { - _stringLiterals = []; - _loopStack = []; - - var writer = new IndentedTextWriter(); - - _module = topLevelNodes.OfType().First().NameToken.Value; - - writer.WriteLine($$""" - ; Module {{_module}} - - target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-f80:128-n8:16:32:64-S128" - target triple = "x86_64-pc-linux-gnu" - - %nub.slice = type { i64, ptr } - %nub.string = type { i64, ptr } - - declare void @llvm.memcpy.p0.p0.i64(ptr, ptr, i64, i1) - - """); - - var declaredExternFunctions = new HashSet(); - - writer.WriteLine("; == Function declarations =="); - foreach (var module in repository.GetAll()) - { - writer.WriteLine($"; ==== {module.Name} ===="); - foreach (var prototype in module.FunctionPrototypes) - { - // note(nub31): If we are in the current module and the function has a body, we skip it to prevent duplicate definition - if (module.Name == _module && topLevelNodes.OfType().First(x => x.NameToken.Value == prototype.NameToken.Value).Body != null) - { - continue; - } - - if (prototype.ExternSymbolToken != null && !declaredExternFunctions.Add(prototype.ExternSymbolToken.Value)) - { - continue; - } - - writer.WriteLine($"declare {CreateFunctionPrototype(prototype, module.Name)}"); - } - } - - writer.WriteLine(); - - writer.WriteLine("; == Struct declarations =="); - foreach (var module in repository.GetAll()) - { - writer.WriteLine($"; ==== {module.Name} ===="); - foreach (var structType in module.StructTypes) - { - var fieldTypes = structType.Fields.Select(x => MapType(x.Type)); - writer.WriteLine($"%{StructName(structType.Module, structType.Name)} = type {{ {string.Join(", ", fieldTypes)} }}"); - } - } - - writer.WriteLine(); - - foreach (var structNode in topLevelNodes.OfType()) - { - _tmpIndex = 0; - _labelIndex = 0; - - writer.WriteLine("; == Struct constructors =="); - writer.WriteLine($"define void @{StructName(structNode)}.new(ptr %self) {{"); - using (writer.Indent()) - { - foreach (var field in structNode.Fields) - { - if (field.Value != null) - { - var index = structNode.StructType.GetFieldIndex(field.NameToken.Value); - var fieldTmp = NewTmp($"struct.field.{field.NameToken.Value}"); - writer.WriteLine($"{fieldTmp} = getelementptr %{StructName(structNode)}, ptr %self, i32 0, i32 {index}"); - - EmitExpressionInto(writer, field.Value, fieldTmp); - } - } - - writer.WriteLine("ret void"); - } - - writer.WriteLine("}"); - writer.WriteLine(); - } - - writer.WriteLine("; == Function definitions =="); - foreach (var funcNode in topLevelNodes.OfType()) - { - if (funcNode.Body == null) continue; - - _tmpIndex = 0; - _labelIndex = 0; - - writer.WriteLine($"define {CreateFunctionPrototype(funcNode.Prototype, _module)} {{"); - - using (writer.Indent()) - { - EmitBlock(writer, funcNode.Body); - - // note(nub31): Implicit return for void functions - if (funcNode.Prototype.ReturnType is NubVoidType) - { - writer.WriteLine("ret void"); - } - } - - writer.WriteLine("}"); - writer.WriteLine(); - } - - writer.WriteLine("; == String literals =="); - foreach (var stringLiteral in _stringLiterals) - { - writer.WriteLine($"{stringLiteral.Name} = private unnamed_addr constant [{stringLiteral.Size} x i8] c\"{stringLiteral.Text}\\00\", align 1"); - } - - return writer.ToString(); - } - - private string CreateFunctionPrototype(FuncPrototypeNode prototypeNode, string module) - { - var parameterStrings = new List(); - - foreach (var parameter in prototypeNode.Parameters) - { - var llvmType = MapType(parameter.Type); - var name = parameter.NameToken.Value; - - if (parameter.Type is NubStructType) - { - var alignment = parameter.Type.GetAlignment(); - parameterStrings.Add($"{llvmType}* byval({llvmType}) align {alignment} %{name}"); - } - else - { - parameterStrings.Add($"{llvmType} %{name}"); - } - } - - var funcName = FuncName(module, prototypeNode.NameToken.Value, prototypeNode.ExternSymbolToken?.Value); - var returnType = MapType(prototypeNode.ReturnType); - - if (prototypeNode.ReturnType is NubStructType) - { - var alignment = prototypeNode.ReturnType.GetAlignment(); - var parameters = ""; - if (parameterStrings.Count != 0) - { - parameters = ", " + string.Join(", ", parameterStrings); - } - - return $"ccc void @{funcName}({returnType}* sret({returnType}) align {alignment}{parameters})"; - } - else - { - return $"ccc {returnType} @{funcName}({string.Join(", ", parameterStrings)})"; - } - } - - private void EmitStatement(IndentedTextWriter writer, StatementNode statementNode) - { - switch (statementNode) - { - case AssignmentNode assignmentNode: - EmitAssignment(writer, assignmentNode); - break; - case BlockNode blockNode: - EmitBlock(writer, blockNode); - break; - case BreakNode: - EmitBreak(writer); - break; - case ContinueNode: - EmitContinue(writer); - break; - case DeferNode deferNode: - EmitDefer(writer, deferNode); - break; - case ForConstArrayNode forConstArrayNode: - EmitForConstArray(writer, forConstArrayNode); - break; - case ForSliceNode forSliceNode: - EmitForSlice(writer, forSliceNode); - break; - case IfNode ifNode: - EmitIf(writer, ifNode); - break; - case ReturnNode returnNode: - EmitReturn(writer, returnNode); - break; - case StatementFuncCallNode statementFuncCallNode: - EmitStatementFuncCall(writer, statementFuncCallNode); - break; - case VariableDeclarationNode variableDeclarationNode: - EmitVariableDeclaration(writer, variableDeclarationNode); - break; - case WhileNode whileNode: - EmitWhile(writer, whileNode); - break; - default: - throw new ArgumentOutOfRangeException(nameof(statementNode)); - } - } - - private void EmitAssignment(IndentedTextWriter writer, AssignmentNode assignmentNode) - { - var target = EmitExpression(writer, assignmentNode.Target); - var value = Unwrap(writer, EmitExpression(writer, assignmentNode.Value)); - writer.WriteLine($"store {MapType(assignmentNode.Value.Type)} {value}, ptr {target.Ident}"); - } - - private void EmitBlock(IndentedTextWriter writer, BlockNode blockNode) - { - foreach (var statementNode in blockNode.Statements) - { - EmitStatement(writer, statementNode); - } - } - - private void EmitBreak(IndentedTextWriter writer) - { - var (breakLabel, _) = _loopStack.Peek(); - writer.WriteLine($"br label %{breakLabel}"); - } - - private void EmitContinue(IndentedTextWriter writer) - { - var (_, continueLabel) = _loopStack.Peek(); - writer.WriteLine($"br label %{continueLabel}"); - } - - private void EmitDefer(IndentedTextWriter writer, DeferNode deferNode) - { - throw new NotImplementedException(); - } - - private void EmitForConstArray(IndentedTextWriter writer, ForConstArrayNode forConstArrayNode) - { - throw new NotImplementedException(); - } - - private void EmitForSlice(IndentedTextWriter writer, ForSliceNode forSliceNode) - { - throw new NotImplementedException(); - } - - private void EmitIf(IndentedTextWriter writer, IfNode ifNode) - { - var endLabel = NewLabel("if.end"); - EmitIf(writer, ifNode, endLabel); - writer.WriteLine($"{endLabel}:"); - } - - private void EmitIf(IndentedTextWriter writer, IfNode ifNode, string endLabel) - { - var condition = Unwrap(writer, EmitExpression(writer, ifNode.Condition)); - var thenLabel = NewLabel("if.then"); - var elseLabel = ifNode.Else.HasValue ? NewLabel("if.else") : endLabel; - - writer.WriteLine($"br i1 {condition}, label %{thenLabel}, label %{elseLabel}"); - - writer.WriteLine($"{thenLabel}:"); - using (writer.Indent()) - { - EmitBlock(writer, ifNode.Body); - writer.WriteLine($"br label %{endLabel}"); - } - - if (!ifNode.Else.HasValue) return; - - writer.WriteLine($"{elseLabel}:"); - using (writer.Indent()) - { - ifNode.Else.Value.Match( - nestedElseIf => EmitIf(writer, nestedElseIf, endLabel), - finalElse => - { - EmitBlock(writer, finalElse); - writer.WriteLine($"br label %{endLabel}"); - } - ); - } - } - - private void EmitReturn(IndentedTextWriter writer, ReturnNode returnNode) - { - if (returnNode.Value != null) - { - var returnValue = Unwrap(writer, EmitExpression(writer, returnNode.Value)); - writer.WriteLine($"ret {MapType(returnNode.Value.Type)} {returnValue}"); - } - else - { - writer.WriteLine("ret void"); - } - } - - private void EmitStatementFuncCall(IndentedTextWriter writer, StatementFuncCallNode statementFuncCallNode) - { - EmitFuncCall(writer, statementFuncCallNode.FuncCall); - } - - private void EmitVariableDeclaration(IndentedTextWriter writer, VariableDeclarationNode variableDeclarationNode) - { - writer.WriteLine($"%{variableDeclarationNode.NameToken.Value} = alloca {MapType(variableDeclarationNode.Type)}"); - if (variableDeclarationNode.Assignment != null) - { - EmitExpressionInto(writer, variableDeclarationNode.Assignment, $"%{variableDeclarationNode.NameToken.Value}"); - } - } - - private void EmitWhile(IndentedTextWriter writer, WhileNode whileNode) - { - var conditionLabel = NewLabel("while.condition"); - var bodyLabel = NewLabel("while.body"); - var endLabel = NewLabel("while.end"); - - _loopStack.Push((endLabel, conditionLabel)); - - writer.WriteLine($"br label %{conditionLabel}"); - - writer.WriteLine($"{conditionLabel}:"); - using (writer.Indent()) - { - var condition = Unwrap(writer, EmitExpression(writer, whileNode.Condition)); - writer.WriteLine($"br i1 {condition}, label %{bodyLabel}, label %{endLabel}"); - } - - writer.WriteLine($"{bodyLabel}:"); - using (writer.Indent()) - { - EmitBlock(writer, whileNode.Body); - writer.WriteLine($"br label %{conditionLabel}"); - } - - _loopStack.Pop(); - - writer.WriteLine($"{endLabel}:"); - } - - private Tmp EmitExpression(IndentedTextWriter writer, ExpressionNode expressionNode) - { - return expressionNode switch - { - AddressOfNode addressOfNode => EmitAddressOf(writer, addressOfNode), - DereferenceNode dereferenceNode => EmitDereference(writer, dereferenceNode), - - UnaryExpressionNode unaryExpressionNode => EmitUnaryExpression(writer, unaryExpressionNode), - BinaryExpressionNode binaryExpressionNode => EmitBinaryExpression(writer, binaryExpressionNode), - - ConstArrayInitializerNode constArrayInitializerNode => EmitConstArrayInitializer(writer, constArrayInitializerNode), - StructInitializerNode structInitializerNode => EmitStructInitializer(writer, structInitializerNode), - - ConstArrayIndexAccessNode constArrayIndexAccessNode => EmitConstArrayIndexAccess(writer, constArrayIndexAccessNode), - ArrayIndexAccessNode arrayIndexAccessNode => EmitArrayIndexAccess(writer, arrayIndexAccessNode), - SliceIndexAccessNode sliceIndexAccessNode => EmitSliceIndexAccess(writer, sliceIndexAccessNode), - - StructFieldAccessNode structFieldAccessNode => EmitStructFieldAccess(writer, structFieldAccessNode), - - CStringLiteralNode cStringLiteralNode => EmitCStringLiteral(writer, cStringLiteralNode), - StringLiteralNode stringLiteralNode => EmitStringLiteral(writer, stringLiteralNode), - BoolLiteralNode boolLiteralNode => EmitBoolLiteral(boolLiteralNode), - Float32LiteralNode float32LiteralNode => EmitFloat32Literal(float32LiteralNode), - Float64LiteralNode float64LiteralNode => EmitFloat64Literal(float64LiteralNode), - U8LiteralNode u8LiteralNode => EmitU8Literal(u8LiteralNode), - U16LiteralNode u16LiteralNode => EmitU16Literal(u16LiteralNode), - U32LiteralNode u32LiteralNode => EmitU32Literal(u32LiteralNode), - U64LiteralNode u64LiteralNode => EmitU64Literal(u64LiteralNode), - I8LiteralNode i8LiteralNode => EmitI8Literal(i8LiteralNode), - I16LiteralNode i16LiteralNode => EmitI16Literal(i16LiteralNode), - I32LiteralNode i32LiteralNode => EmitI32Literal(i32LiteralNode), - I64LiteralNode i64LiteralNode => EmitI64Literal(i64LiteralNode), - - LocalFuncIdentifierNode localFuncIdentifierNode => EmitLocalFuncIdentifier(writer, localFuncIdentifierNode), - ModuleFuncIdentifierNode moduleFuncIdentifierNode => EmitModuleFuncIdentifier(writer, moduleFuncIdentifierNode), - VariableIdentifierNode variableIdentifierNode => EmitVariableIdentifier(writer, variableIdentifierNode), - - FuncCallNode funcCallNode => EmitFuncCall(writer, funcCallNode), - SizeNode sizeNode => EmitSize(sizeNode), - CastNode castNode => EmitCast(writer, castNode), - - _ => throw new ArgumentOutOfRangeException(nameof(expressionNode)) - }; - } - - private void EmitExpressionInto(IndentedTextWriter writer, ExpressionNode expressionNode, string destination) - { - switch (expressionNode) - { - case StructInitializerNode structInitializerNode: - { - EmitStructInitializer(writer, structInitializerNode, destination); - return; - } - case ConstArrayInitializerNode constArrayInitializerNode: - { - EmitConstArrayInitializer(writer, constArrayInitializerNode, destination); - return; - } - } - - var value = Unwrap(writer, EmitExpression(writer, expressionNode)); - - if (expressionNode.Type.IsAggregate()) - { - // note(nub31): Fall back to slow method creating a copy - writer.WriteLine("; Slow rvalue copy instead of direct memory write"); - writer.WriteLine($"call void @llvm.memcpy.p0.p0.i64(ptr {destination}, ptr {value}, i64 {expressionNode.Type.GetSize()}, i1 false)"); - } - else - { - writer.WriteLine($"store {MapType(expressionNode.Type)} {value}, ptr {destination}"); - } - } - - private Tmp EmitAddressOf(IndentedTextWriter writer, AddressOfNode addressOfNode) - { - var target = EmitExpression(writer, addressOfNode.Target); - return new Tmp(target.Ident, addressOfNode.Type, false); - } - - private Tmp EmitDereference(IndentedTextWriter writer, DereferenceNode dereferenceNode) - { - var target = EmitExpression(writer, dereferenceNode.Target); - return new Tmp(target.Ident, dereferenceNode.Type, true); - } - - private Tmp EmitUnaryExpression(IndentedTextWriter writer, UnaryExpressionNode unaryExpressionNode) - { - var operand = Unwrap(writer, EmitExpression(writer, unaryExpressionNode.Operand)); - var result = NewTmp("unary"); - - switch (unaryExpressionNode.Operator) - { - case UnaryOperator.Negate: - { - switch (unaryExpressionNode.Operand.Type) - { - case NubIntType intType: - writer.WriteLine($"{result} = sub {MapType(intType)} 0, {operand}"); - break; - case NubFloatType floatType: - writer.WriteLine($"{result} = fneg {MapType(floatType)} {operand}"); - break; - default: - throw new UnreachableException(); - } - - break; - } - case UnaryOperator.Invert: - { - writer.WriteLine($"{result} = xor i1 {operand}, true"); - break; - } - default: - { - throw new ArgumentOutOfRangeException(); - } - } - - return new Tmp(result, unaryExpressionNode.Type, false); - } - - private Tmp EmitBinaryExpression(IndentedTextWriter writer, BinaryExpressionNode binaryExpressionNode) - { - var left = Unwrap(writer, EmitExpression(writer, binaryExpressionNode.Left)); - var right = Unwrap(writer, EmitExpression(writer, binaryExpressionNode.Right)); - var result = NewTmp("binop"); - - var leftType = binaryExpressionNode.Left.Type; - var op = binaryExpressionNode.Operator; - - switch (op) - { - case BinaryOperator.Equal: - case BinaryOperator.NotEqual: - case BinaryOperator.GreaterThan: - case BinaryOperator.GreaterThanOrEqual: - case BinaryOperator.LessThan: - case BinaryOperator.LessThanOrEqual: - { - var cmpOp = leftType switch - { - NubIntType intType => GenerateIntComparison(op, intType.Signed), - NubFloatType => GenerateFloatComparison(op), - NubBoolType => GenerateBoolComparison(op), - NubPointerType => GeneratePointerComparison(op), - _ => throw new InvalidOperationException($"Unexpected type for comparison: {leftType}") - }; - - writer.WriteLine($"{result} = {cmpOp} {MapType(leftType)} {left}, {right}"); - break; - } - case BinaryOperator.LogicalAnd: - { - writer.WriteLine($"{result} = and i1 {left}, {right}"); - break; - } - case BinaryOperator.LogicalOr: - { - writer.WriteLine($"{result} = or i1 {left}, {right}"); - break; - } - case BinaryOperator.Plus: - { - var instruction = leftType switch - { - NubIntType => "add", - NubFloatType => "fadd", - _ => throw new InvalidOperationException($"Unexpected type for plus: {leftType}") - }; - writer.WriteLine($"{result} = {instruction} {MapType(leftType)} {left}, {right}"); - break; - } - case BinaryOperator.Minus: - { - var instruction = leftType switch - { - NubIntType => "sub", - NubFloatType => "fsub", - _ => throw new InvalidOperationException($"Unexpected type for minus: {leftType}") - }; - writer.WriteLine($"{result} = {instruction} {MapType(leftType)} {left}, {right}"); - break; - } - case BinaryOperator.Multiply: - { - var instruction = leftType switch - { - NubIntType => "mul", - NubFloatType => "fmul", - _ => throw new InvalidOperationException($"Unexpected type for multiply: {leftType}") - }; - writer.WriteLine($"{result} = {instruction} {MapType(leftType)} {left}, {right}"); - break; - } - case BinaryOperator.Divide: - { - var instruction = leftType switch - { - NubIntType intType => intType.Signed ? "sdiv" : "udiv", - NubFloatType => "fdiv", - _ => throw new InvalidOperationException($"Unexpected type for divide: {leftType}") - }; - writer.WriteLine($"{result} = {instruction} {MapType(leftType)} {left}, {right}"); - break; - } - case BinaryOperator.Modulo: - { - var instruction = leftType switch - { - NubIntType intType => intType.Signed ? "srem" : "urem", - NubFloatType => "frem", - _ => throw new InvalidOperationException($"Unexpected type for modulo: {leftType}") - }; - writer.WriteLine($"{result} = {instruction} {MapType(leftType)} {left}, {right}"); - break; - } - case BinaryOperator.LeftShift: - { - writer.WriteLine($"{result} = shl {MapType(leftType)} {left}, {right}"); - break; - } - case BinaryOperator.RightShift: - { - var intType = (NubIntType)leftType; - var instruction = intType.Signed ? "ashr" : "lshr"; - writer.WriteLine($"{result} = {instruction} {MapType(leftType)} {left}, {right}"); - break; - } - case BinaryOperator.BitwiseAnd: - { - writer.WriteLine($"{result} = and {MapType(leftType)} {left}, {right}"); - break; - } - case BinaryOperator.BitwiseXor: - { - writer.WriteLine($"{result} = xor {MapType(leftType)} {left}, {right}"); - break; - } - case BinaryOperator.BitwiseOr: - { - writer.WriteLine($"{result} = or {MapType(leftType)} {left}, {right}"); - break; - } - default: - throw new ArgumentOutOfRangeException(nameof(op), op, null); - } - - return new Tmp(result, binaryExpressionNode.Type, false); - } - - private string GenerateIntComparison(BinaryOperator op, bool signed) - { - return op switch - { - BinaryOperator.Equal => "icmp eq", - BinaryOperator.NotEqual => "icmp ne", - BinaryOperator.GreaterThan => signed ? "icmp sgt" : "icmp ugt", - BinaryOperator.GreaterThanOrEqual => signed ? "icmp sge" : "icmp uge", - BinaryOperator.LessThan => signed ? "icmp slt" : "icmp ult", - BinaryOperator.LessThanOrEqual => signed ? "icmp sle" : "icmp ule", - _ => throw new ArgumentOutOfRangeException(nameof(op), op, null) - }; - } - - private string GenerateFloatComparison(BinaryOperator op) - { - return op switch - { - BinaryOperator.Equal => "fcmp oeq", - BinaryOperator.NotEqual => "fcmp one", - BinaryOperator.GreaterThan => "fcmp ogt", - BinaryOperator.GreaterThanOrEqual => "fcmp oge", - BinaryOperator.LessThan => "fcmp olt", - BinaryOperator.LessThanOrEqual => "fcmp ole", - _ => throw new ArgumentOutOfRangeException(nameof(op), op, null) - }; - } - - private static string GenerateBoolComparison(BinaryOperator op) - { - return op switch - { - BinaryOperator.Equal => "icmp eq", - BinaryOperator.NotEqual => "icmp ne", - _ => throw new ArgumentOutOfRangeException(nameof(op), op, null) - }; - } - - private static string GeneratePointerComparison(BinaryOperator op) - { - return op switch - { - BinaryOperator.Equal => "icmp eq", - BinaryOperator.NotEqual => "icmp ne", - BinaryOperator.GreaterThan => "icmp ugt", - BinaryOperator.GreaterThanOrEqual => "icmp uge", - BinaryOperator.LessThan => "icmp ult", - BinaryOperator.LessThanOrEqual => "icmp ule", - _ => throw new ArgumentOutOfRangeException(nameof(op), op, null) - }; - } - - #region Initializers - - private Tmp EmitConstArrayInitializer(IndentedTextWriter writer, ConstArrayInitializerNode constArrayInitializerNode, string? destination = null) - { - var arrayType = (NubConstArrayType)constArrayInitializerNode.Type; - - if (destination == null) - { - destination = NewTmp("array"); - writer.WriteLine($"{destination} = alloca {MapType(arrayType)}"); - } - - for (var i = 0; i < constArrayInitializerNode.Values.Count; i++) - { - var value = constArrayInitializerNode.Values[i]; - var indexTmp = NewTmp("array.element"); - writer.WriteLine($"{indexTmp} = getelementptr {MapType(arrayType)}, ptr {destination}, i32 0, i32 {i}"); - EmitExpressionInto(writer, value, indexTmp); - } - - return new Tmp(destination, constArrayInitializerNode.Type, false); - } - - private Tmp EmitStructInitializer(IndentedTextWriter writer, StructInitializerNode structInitializerNode, string? destination = null) - { - if (destination == null) - { - destination = NewTmp("struct"); - writer.WriteLine($"{destination} = alloca {MapType(structInitializerNode.Type)}"); - } - - var structType = (NubStructType)structInitializerNode.Type; - - writer.WriteLine($"call void @{StructName(structType.Module, structType.Name)}.new(ptr {destination})"); - - foreach (var (name, value) in structInitializerNode.Initializers) - { - var index = structType.GetFieldIndex(name.Value); - var fieldTmp = NewTmp($"struct.field.{name}"); - writer.WriteLine($"{fieldTmp} = getelementptr %{StructName(structType.Module, structType.Name)}, ptr {destination}, i32 0, i32 {index}"); - - EmitExpressionInto(writer, value, fieldTmp); - } - - return new Tmp(destination, structInitializerNode.Type, false); - } - - #endregion - - #region Array indexing - - private Tmp EmitConstArrayIndexAccess(IndentedTextWriter writer, ConstArrayIndexAccessNode constArrayIndexAccessNode) - { - var arrayPtr = Unwrap(writer, EmitExpression(writer, constArrayIndexAccessNode.Target)); - var index = Unwrap(writer, EmitExpression(writer, constArrayIndexAccessNode.Index)); - - var elementType = ((NubConstArrayType)constArrayIndexAccessNode.Target.Type).ElementType; - var ptrTmp = NewTmp("array.element"); - writer.WriteLine($"{ptrTmp} = getelementptr {MapType(elementType)}, ptr {arrayPtr}, {MapType(constArrayIndexAccessNode.Index.Type)} {index}"); - - return new Tmp(ptrTmp, constArrayIndexAccessNode.Type, true); - } - - private Tmp EmitArrayIndexAccess(IndentedTextWriter writer, ArrayIndexAccessNode arrayIndexAccessNode) - { - var arrayPtr = Unwrap(writer, EmitExpression(writer, arrayIndexAccessNode.Target)); - var index = Unwrap(writer, EmitExpression(writer, arrayIndexAccessNode.Index)); - - var elementType = ((NubArrayType)arrayIndexAccessNode.Target.Type).ElementType; - var ptrTmp = NewTmp("array.element"); - writer.WriteLine($"{ptrTmp} = getelementptr {MapType(elementType)}, ptr {arrayPtr}, {MapType(arrayIndexAccessNode.Index.Type)} {index}"); - - return new Tmp(ptrTmp, arrayIndexAccessNode.Type, true); - } - - private Tmp EmitSliceIndexAccess(IndentedTextWriter writer, SliceIndexAccessNode sliceIndexAccessNode) - { - throw new NotImplementedException(); - } - - #endregion - - private Tmp EmitStructFieldAccess(IndentedTextWriter writer, StructFieldAccessNode structFieldAccessNode) - { - var target = Unwrap(writer, EmitExpression(writer, structFieldAccessNode.Target)); - - var structType = (NubStructType)structFieldAccessNode.Target.Type; - var index = structType.GetFieldIndex(structFieldAccessNode.FieldToken.Value); - - var ptrTmp = NewTmp($"struct.field.{structFieldAccessNode.FieldToken.Value}"); - writer.WriteLine($"{ptrTmp} = getelementptr %{StructName(structType.Module, structType.Name)}, ptr {target}, i32 0, i32 {index}"); - - return new Tmp(ptrTmp, structFieldAccessNode.Type, true); - } - - #region Literals - - private Tmp EmitCStringLiteral(IndentedTextWriter writer, CStringLiteralNode cStringLiteralNode) - { - var escaped = new StringBuilder(); - foreach (var c in cStringLiteralNode.Value) - { - switch (c) - { - case '\0': - escaped.Append("\\00"); - break; - case '\n': - escaped.Append("\\0A"); - break; - case '\r': - escaped.Append("\\0D"); - break; - case '\t': - escaped.Append("\\09"); - break; - case '\\': - escaped.Append("\\\\"); - break; - case '"': - escaped.Append("\\22"); - break; - default: - { - if (c < 32 || c > 126) - escaped.Append($"\\{(int)c:X2}"); - else - escaped.Append(c); - - break; - } - } - } - - var stringWithNull = cStringLiteralNode.Value + "\0"; - var length = stringWithNull.Length; - - var globalName = $"@.str.{_stringLiterals.Count}"; - - _stringLiterals.Add((globalName, length, escaped.ToString())); - - var gepTmp = NewTmp("str.ptr"); - writer.WriteLine($"{gepTmp} = getelementptr [{length} x i8], ptr {globalName}, i32 0, i32 0"); - - return new Tmp(gepTmp, cStringLiteralNode.Type, false); - } - - private static Tmp EmitStringLiteral(IndentedTextWriter writer, StringLiteralNode stringLiteralNode) - { - throw new NotImplementedException(); - } - - private static Tmp EmitBoolLiteral(BoolLiteralNode boolLiteralNode) - { - return new Tmp(boolLiteralNode.Value ? "1" : "0", boolLiteralNode.Type, false); - } - - private static Tmp EmitFloat32Literal(Float32LiteralNode float32LiteralNode) - { - var literal = ((double)float32LiteralNode.Value).ToString("R", System.Globalization.CultureInfo.InvariantCulture); - if (!literal.Contains('.')) - { - literal += ".0"; - } - - return new Tmp(literal, float32LiteralNode.Type, false); - } - - private static Tmp EmitFloat64Literal(Float64LiteralNode float64LiteralNode) - { - var literal = float64LiteralNode.Value.ToString("R", System.Globalization.CultureInfo.InvariantCulture); - if (!literal.Contains('.')) - { - literal += ".0"; - } - - return new Tmp(literal, float64LiteralNode.Type, false); - } - - private static Tmp EmitU8Literal(U8LiteralNode u8LiteralNode) - { - return new Tmp(u8LiteralNode.Value.ToString(), u8LiteralNode.Type, false); - } - - private static Tmp EmitU16Literal(U16LiteralNode u16LiteralNode) - { - return new Tmp(u16LiteralNode.Value.ToString(), u16LiteralNode.Type, false); - } - - private static Tmp EmitU32Literal(U32LiteralNode u32LiteralNode) - { - return new Tmp(u32LiteralNode.Value.ToString(), u32LiteralNode.Type, false); - } - - private static Tmp EmitU64Literal(U64LiteralNode u64LiteralNode) - { - return new Tmp(u64LiteralNode.Value.ToString(), u64LiteralNode.Type, false); - } - - private static Tmp EmitI8Literal(I8LiteralNode i8LiteralNode) - { - return new Tmp(i8LiteralNode.Value.ToString(), i8LiteralNode.Type, false); - } - - private static Tmp EmitI16Literal(I16LiteralNode i16LiteralNode) - { - return new Tmp(i16LiteralNode.Value.ToString(), i16LiteralNode.Type, false); - } - - private static Tmp EmitI32Literal(I32LiteralNode i32LiteralNode) - { - return new Tmp(i32LiteralNode.Value.ToString(), i32LiteralNode.Type, false); - } - - private static Tmp EmitI64Literal(I64LiteralNode i64LiteralNode) - { - return new Tmp(i64LiteralNode.Value.ToString(), i64LiteralNode.Type, false); - } - - #endregion - - #region Identifiers - - private Tmp EmitLocalFuncIdentifier(IndentedTextWriter writer, LocalFuncIdentifierNode localFuncIdentifierNode) - { - var name = FuncName(_module, localFuncIdentifierNode.NameToken.Value, localFuncIdentifierNode.ExternSymbolToken?.Value); - return new Tmp($"@{name}", localFuncIdentifierNode.Type, false); - } - - private Tmp EmitModuleFuncIdentifier(IndentedTextWriter writer, ModuleFuncIdentifierNode moduleFuncIdentifierNode) - { - var name = FuncName(moduleFuncIdentifierNode.ModuleToken.Value, moduleFuncIdentifierNode.NameToken.Value, moduleFuncIdentifierNode.ExternSymbolToken?.Value); - return new Tmp($"@{name}", moduleFuncIdentifierNode.Type, false); - } - - private Tmp EmitVariableIdentifier(IndentedTextWriter writer, VariableIdentifierNode variableIdentifierNode) - { - return new Tmp($"%{variableIdentifierNode.NameToken.Value}", variableIdentifierNode.Type, true); - } - - #endregion - - private Tmp EmitFuncCall(IndentedTextWriter writer, FuncCallNode funcCallNode) - { - var result = NewTmp(); - - var parameterStrings = new List(); - - foreach (var parameter in funcCallNode.Parameters) - { - var value = Unwrap(writer, EmitExpression(writer, parameter)); - parameterStrings.Add($"{MapType(parameter.Type)} {value}"); - } - - var functionPtr = Unwrap(writer, EmitExpression(writer, funcCallNode.Expression)); - - if (funcCallNode.Type is NubVoidType) - { - writer.WriteLine($"call ccc {MapType(funcCallNode.Type)} {functionPtr}({string.Join(", ", parameterStrings)})"); - } - else - { - writer.WriteLine($"{result} = call ccc {MapType(funcCallNode.Type)} {functionPtr}({string.Join(", ", parameterStrings)})"); - } - - return new Tmp(result, funcCallNode.Type, false); - } - - private static Tmp EmitSize(SizeNode sizeNode) - { - return new Tmp(sizeNode.TargetType.GetSize().ToString(), sizeNode.Type, false); - } - - private Tmp EmitCast(IndentedTextWriter writer, CastNode castNode) - { - var source = Unwrap(writer, EmitExpression(writer, castNode.Value)); - var result = NewTmp("cast"); - - switch (castNode.ConversionType) - { - case CastNode.Conversion.IntToInt: - { - var sourceInt = (NubIntType)castNode.Value.Type; - var targetInt = (NubIntType)castNode.Type; - - var op = sourceInt.Width < targetInt.Width - ? sourceInt.Signed - ? "sext" - : "zext" - : sourceInt.Width > targetInt.Width - ? "trunc" - : "bitcast"; - - writer.WriteLine($"{result} = {op} {MapType(sourceInt)} {source} to {MapType(targetInt)}"); - break; - } - case CastNode.Conversion.FloatToFloat: - { - var sourceFloat = (NubFloatType)castNode.Value.Type; - var targetFloat = (NubFloatType)castNode.Type; - - var op = sourceFloat.Width < targetFloat.Width ? "fpext" : "fptrunc"; - writer.WriteLine($"{result} = {op} {MapType(sourceFloat)} {source} to {MapType(targetFloat)}"); - break; - } - case CastNode.Conversion.IntToFloat: - { - var sourceInt = (NubIntType)castNode.Value.Type; - var targetFloat = (NubFloatType)castNode.Type; - - var op = sourceInt.Signed ? "sitofp" : "uitofp"; - writer.WriteLine($"{result} = {op} {MapType(sourceInt)} {source} to {MapType(targetFloat)}"); - break; - } - case CastNode.Conversion.FloatToInt: - { - var sourceFloat = (NubFloatType)castNode.Value.Type; - var targetInt = (NubIntType)castNode.Type; - - var op = targetInt.Signed ? "fptosi" : "fptoui"; - writer.WriteLine($"{result} = {op} {MapType(sourceFloat)} {source} to {MapType(targetInt)}"); - break; - } - case CastNode.Conversion.PointerToPointer: - case CastNode.Conversion.PointerToUInt64: - case CastNode.Conversion.UInt64ToPointer: - { - writer.WriteLine($"{result} = inttoptr {MapType(castNode.Value.Type)} {source} to {MapType(castNode.Type)}"); - break; - } - case CastNode.Conversion.ConstArrayToArray: - { - var sourceConstArrayType = (NubConstArrayType)castNode.Value.Type; - var targetArrayType = (NubArrayType)castNode.Type; - - writer.WriteLine($"{result} = getelementptr {MapType(sourceConstArrayType)}, {MapType(targetArrayType)} {source}, i32 0, i32 0"); - break; - } - case CastNode.Conversion.ConstArrayToSlice: - { - throw new NotImplementedException(); - } - default: - { - throw new UnreachableException(); - } - } - - return new Tmp(result, castNode.Type, false); - } - - private string StructName(StructNode structNode) - { - return StructName(_module, structNode.NameToken.Value); - } - - private string StructName(string module, string name) - { - return $"struct.{module}.{name}"; - } - - private string FuncName(string module, string name, string? externSymbol) - { - if (externSymbol != null) - { - return externSymbol; - } - - return $"{module}.{name}"; - } - - private string MapType(NubType type) - { - return type switch - { - NubArrayType arrayType => $"{MapType(arrayType.ElementType)}*", - NubBoolType => "i1", - NubConstArrayType constArrayType => $"[{constArrayType.Size} x {MapType(constArrayType.ElementType)}]", - NubFloatType floatType => floatType.Width == 32 ? "float" : "double", - NubFuncType funcType => MapFuncType(funcType), - NubIntType intType => $"i{intType.Width}", - NubPointerType pointerType => $"{MapType(pointerType.BaseType)}*", - NubSliceType sliceType => throw new NotImplementedException(), - NubStringType stringType => throw new NotImplementedException(), - NubStructType structType => $"%{StructName(structType.Module, structType.Name)}", - NubVoidType => "void", - _ => throw new ArgumentOutOfRangeException(nameof(type)) - }; - } - - private string MapFuncType(NubFuncType funcType) - { - var paramTypes = string.Join(", ", funcType.Parameters.Select(MapType)); - var returnType = MapType(funcType.ReturnType); - return $"{returnType} ({paramTypes})*"; - } - - private record Tmp(string Ident, NubType Type, bool LValue); - - private string Unwrap(IndentedTextWriter writer, Tmp tmp) - { - if (tmp.LValue && !tmp.Type.IsAggregate()) - { - var newTmp = NewTmp("deref"); - writer.WriteLine($"{newTmp} = load {MapType(tmp.Type)}, ptr {tmp.Ident}"); - return newTmp; - } - - return tmp.Ident; - } - - private string NewTmp(string name = "t") - { - return $"%{name}.{++_tmpIndex}"; - } - - private string NewLabel(string name = "l") - { - return $"{name}.{++_labelIndex}"; - } -} \ No newline at end of file diff --git a/compiler/NubLang/Generation/LlvmSharpGenerator.cs b/compiler/NubLang/Generation/LlvmSharpGenerator.cs new file mode 100644 index 0000000..fb2a071 --- /dev/null +++ b/compiler/NubLang/Generation/LlvmSharpGenerator.cs @@ -0,0 +1,781 @@ +using System.Text; +using LLVMSharp.Interop; +using NubLang.Ast; +using NubLang.Modules; +using NubLang.Types; + +namespace NubLang.Generation; + +public class LlvmSharpGenerator +{ + private string _module = string.Empty; + private LLVMContextRef _context; + private LLVMModuleRef _llvmModule; + private LLVMBuilderRef _builder; + private readonly Dictionary _structTypes = new(); + private readonly Dictionary _functions = new(); + private readonly Dictionary _locals = new(); + private readonly Stack<(LLVMBasicBlockRef breakBlock, LLVMBasicBlockRef continueBlock)> _loopStack = new(); + + public void Emit(List topLevelNodes, ModuleRepository repository, string sourceFileName, string outputPath) + { + _module = topLevelNodes.OfType().First().NameToken.Value; + + _context = LLVMContextRef.Global; + _llvmModule = _context.CreateModuleWithName(sourceFileName); + _llvmModule.Target = "x86_64-pc-linux-gnu"; + _llvmModule.DataLayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-f80:128-n8:16:32:64-S128"; + + _builder = _context.CreateBuilder(); + + _structTypes.Clear(); + _functions.Clear(); + _locals.Clear(); + _loopStack.Clear(); + + var stringType = _context.CreateNamedStruct("nub.string"); + stringType.StructSetBody([LLVMTypeRef.Int64, LLVMTypeRef.CreatePointer(LLVMTypeRef.Int8, 0)], false); + _structTypes["nub.string"] = stringType; + + foreach (var module in repository.GetAll()) + { + foreach (var structType in module.StructTypes) + { + var structName = StructName(structType.Module, structType.Name); + var llvmStructType = _context.CreateNamedStruct(structName); + _structTypes[structName] = llvmStructType; + } + } + + foreach (var module in repository.GetAll()) + { + foreach (var structType in module.StructTypes) + { + var structName = StructName(structType.Module, structType.Name); + var llvmStructType = _structTypes[structName]; + var fieldTypes = structType.Fields.Select(f => MapType(f.Type)).ToArray(); + llvmStructType.StructSetBody(fieldTypes, false); + } + } + + foreach (var module in repository.GetAll()) + { + foreach (var prototype in module.FunctionPrototypes) + { + CreateFunctionDeclaration(prototype, module.Name); + } + } + + foreach (var structNode in topLevelNodes.OfType()) + { + EmitStructConstructor(structNode); + } + + foreach (var funcNode in topLevelNodes.OfType()) + { + if (funcNode.Body != null) + { + EmitFunction(funcNode); + } + } + + if (!_llvmModule.TryVerify(LLVMVerifierFailureAction.LLVMPrintMessageAction, out var error)) + { + // throw new Exception($"LLVM module verification failed: {error}"); + } + + _llvmModule.PrintToFile(outputPath); + + _builder.Dispose(); + } + + private void CreateFunctionDeclaration(FuncPrototypeNode prototype, string moduleName) + { + var funcName = FuncName(moduleName, prototype.NameToken.Value, prototype.ExternSymbolToken?.Value); + + var paramTypes = prototype.Parameters.Select(p => MapType(p.Type)).ToArray(); + var returnType = MapType(prototype.ReturnType); + + var funcType = LLVMTypeRef.CreateFunction(returnType, paramTypes); + var func = _llvmModule.AddFunction(funcName, funcType); + + func.FunctionCallConv = (uint)LLVMCallConv.LLVMCCallConv; + + for (var i = 0; i < prototype.Parameters.Count; i++) + { + func.GetParam((uint)i).Name = prototype.Parameters[i].NameToken.Value; + } + + _functions[funcName] = func; + } + + private void EmitStructConstructor(StructNode structNode) + { + var structType = _structTypes[StructName(_module, structNode.NameToken.Value)]; + var ptrType = LLVMTypeRef.CreatePointer(structType, 0); + + var funcType = LLVMTypeRef.CreateFunction(LLVMTypeRef.Void, [ptrType]); + var funcName = StructConstructorName(_module, structNode.NameToken.Value); + var func = _llvmModule.AddFunction(funcName, funcType); + func.FunctionCallConv = (uint)LLVMCallConv.LLVMCCallConv; + + var entryBlock = func.AppendBasicBlock("entry"); + _builder.PositionAtEnd(entryBlock); + + var selfParam = func.GetParam(0); + selfParam.Name = "self"; + + _locals.Clear(); + + foreach (var field in structNode.Fields) + { + if (field.Value != null) + { + var index = structNode.StructType.GetFieldIndex(field.NameToken.Value); + var fieldPtr = _builder.BuildStructGEP2(structType, selfParam, (uint)index); + EmitExpressionInto(field.Value, fieldPtr); + } + } + + _builder.BuildRetVoid(); + _functions[funcName] = func; + } + + private void EmitFunction(FuncNode funcNode) + { + var funcName = FuncName(_module, funcNode.Prototype.NameToken.Value, funcNode.Prototype.ExternSymbolToken?.Value); + var func = _functions[funcName]; + + var entryBlock = func.AppendBasicBlock("entry"); + _builder.PositionAtEnd(entryBlock); + + _locals.Clear(); + + for (uint i = 0; i < funcNode.Prototype.Parameters.Count; i++) + { + var param = func.GetParam(i); + var paramNode = funcNode.Prototype.Parameters[(int)i]; + var alloca = _builder.BuildAlloca(MapType(paramNode.Type), paramNode.NameToken.Value); + _builder.BuildStore(param, alloca); + _locals[paramNode.NameToken.Value] = alloca; + } + + EmitBlock(funcNode.Body!); + + if (funcNode.Prototype.ReturnType is NubVoidType) + { + if (_builder.InsertBlock.Terminator.Handle == IntPtr.Zero) + { + _builder.BuildRetVoid(); + } + } + } + + private void EmitBlock(BlockNode blockNode) + { + foreach (var statement in blockNode.Statements) + { + EmitStatement(statement); + } + } + + private void EmitStatement(StatementNode statement) + { + switch (statement) + { + case AssignmentNode assignment: + EmitAssignment(assignment); + break; + case BlockNode block: + EmitBlock(block); + break; + case BreakNode: + EmitBreak(); + break; + case ContinueNode: + EmitContinue(); + break; + case IfNode ifNode: + EmitIf(ifNode); + break; + case ReturnNode returnNode: + EmitReturn(returnNode); + break; + case StatementFuncCallNode funcCall: + EmitExpression(funcCall.FuncCall); + break; + case VariableDeclarationNode varDecl: + EmitVariableDeclaration(varDecl); + break; + case WhileNode whileNode: + EmitWhile(whileNode); + break; + default: + throw new NotImplementedException($"Statement type {statement.GetType()} not implemented"); + } + } + + private void EmitAssignment(AssignmentNode assignment) + { + var targetPtr = EmitExpression(assignment.Target, asLValue: true); + var value = EmitExpression(assignment.Value); + _builder.BuildStore(value, targetPtr); + } + + private void EmitBreak() + { + var (breakBlock, _) = _loopStack.Peek(); + _builder.BuildBr(breakBlock); + } + + private void EmitContinue() + { + var (_, continueBlock) = _loopStack.Peek(); + _builder.BuildBr(continueBlock); + } + + private void EmitIf(IfNode ifNode) + { + var condition = EmitExpression(ifNode.Condition); + + var func = _builder.InsertBlock.Parent; + var thenBlock = func.AppendBasicBlock("if.then"); + var elseBlock = ifNode.Else.HasValue ? func.AppendBasicBlock("if.else") : default; + var endBlock = func.AppendBasicBlock("if.end"); + + _builder.BuildCondBr(condition, thenBlock, ifNode.Else.HasValue ? elseBlock : endBlock); + + _builder.PositionAtEnd(thenBlock); + EmitBlock(ifNode.Body); + if (_builder.InsertBlock.Terminator.Handle == IntPtr.Zero) + { + _builder.BuildBr(endBlock); + } + + if (ifNode.Else.HasValue) + { + _builder.PositionAtEnd(elseBlock); + ifNode.Else.Value.Match(EmitIf, EmitBlock); + if (_builder.InsertBlock.Terminator.Handle == IntPtr.Zero) + { + _builder.BuildBr(endBlock); + } + } + + _builder.PositionAtEnd(endBlock); + } + + private void EmitReturn(ReturnNode returnNode) + { + if (returnNode.Value != null) + { + var value = EmitExpression(returnNode.Value); + _builder.BuildRet(value); + } + else + { + _builder.BuildRetVoid(); + } + } + + private void EmitVariableDeclaration(VariableDeclarationNode varDecl) + { + var alloca = _builder.BuildAlloca(MapType(varDecl.Type), varDecl.NameToken.Value); + _locals[varDecl.NameToken.Value] = alloca; + + if (varDecl.Assignment != null) + { + EmitExpressionInto(varDecl.Assignment, alloca); + } + } + + private void EmitWhile(WhileNode whileNode) + { + var func = _builder.InsertBlock.Parent; + var condBlock = func.AppendBasicBlock("while.cond"); + var bodyBlock = func.AppendBasicBlock("while.body"); + var endBlock = func.AppendBasicBlock("while.end"); + + _loopStack.Push((endBlock, condBlock)); + + _builder.BuildBr(condBlock); + + _builder.PositionAtEnd(condBlock); + var condition = EmitExpression(whileNode.Condition); + _builder.BuildCondBr(condition, bodyBlock, endBlock); + + _builder.PositionAtEnd(bodyBlock); + EmitBlock(whileNode.Body); + if (_builder.InsertBlock.Terminator.Handle == IntPtr.Zero) + { + _builder.BuildBr(condBlock); + } + + _loopStack.Pop(); + + _builder.PositionAtEnd(endBlock); + } + + private LLVMValueRef EmitExpression(ExpressionNode expr, bool asLValue = false) + { + var result = expr switch + { + StringLiteralNode stringLiteralNode => EmitStringLiteral(stringLiteralNode), + CStringLiteralNode cStringLiteralNode => EmitCStringLiteral(cStringLiteralNode), + BoolLiteralNode b => LLVMValueRef.CreateConstInt(LLVMTypeRef.Int1, b.Value ? 1UL : 0UL), + I8LiteralNode i => LLVMValueRef.CreateConstInt(LLVMTypeRef.Int8, (ulong)i.Value, true), + I16LiteralNode i => LLVMValueRef.CreateConstInt(LLVMTypeRef.Int16, (ulong)i.Value, true), + I32LiteralNode i => LLVMValueRef.CreateConstInt(LLVMTypeRef.Int32, (ulong)i.Value, true), + I64LiteralNode i => LLVMValueRef.CreateConstInt(LLVMTypeRef.Int64, (ulong)i.Value, true), + U8LiteralNode u => LLVMValueRef.CreateConstInt(LLVMTypeRef.Int8, u.Value), + U16LiteralNode u => LLVMValueRef.CreateConstInt(LLVMTypeRef.Int16, u.Value), + U32LiteralNode u => LLVMValueRef.CreateConstInt(LLVMTypeRef.Int32, u.Value), + U64LiteralNode u => LLVMValueRef.CreateConstInt(LLVMTypeRef.Int64, u.Value), + Float32LiteralNode f => LLVMValueRef.CreateConstReal(LLVMTypeRef.Float, f.Value), + Float64LiteralNode f => LLVMValueRef.CreateConstReal(LLVMTypeRef.Double, f.Value), + + VariableIdentifierNode v => EmitVariableIdentifier(v), + LocalFuncIdentifierNode localFuncIdentifierNode => EmitLocalFuncIdentifier(localFuncIdentifierNode), + ModuleFuncIdentifierNode moduleFuncIdentifierNode => EmitModuleFuncIdentifier(moduleFuncIdentifierNode), + + BinaryExpressionNode bin => EmitBinaryExpression(bin), + UnaryExpressionNode unary => EmitUnaryExpression(unary), + + StructFieldAccessNode field => EmitStructFieldAccess(field), + ConstArrayIndexAccessNode arr => EmitConstArrayIndexAccess(arr), + SliceIndexAccessNode sliceIndexAccessNode => EmitSliceIndexAccess(sliceIndexAccessNode), + ArrayIndexAccessNode arrayIndexAccessNode => EmitArrayIndexAccess(arrayIndexAccessNode), + + ConstArrayInitializerNode constArrayInitializerNode => EmitConstArrayInitializer(constArrayInitializerNode), + StructInitializerNode structInitializerNode => EmitStructInitializer(structInitializerNode), + + AddressOfNode addr => EmitAddressOf(addr), + DereferenceNode deref => EmitDereference(deref), + + FuncCallNode funcCall => EmitFuncCall(funcCall), + CastNode cast => EmitCast(cast), + SizeNode size => LLVMValueRef.CreateConstInt(LLVMTypeRef.Int64, size.TargetType.GetSize()), + + _ => throw new ArgumentOutOfRangeException(nameof(expr), expr, null) + }; + + if (expr is LValue) + { + if (asLValue) + { + return result; + } + + return _builder.BuildLoad2(MapType(expr.Type), result); + } + + if (asLValue) + { + throw new InvalidOperationException($"Expression of type {expr.GetType().Name} is not an lvalue and cannot be used where an address is required"); + } + + return result; + } + + private void EmitExpressionInto(ExpressionNode expr, LLVMValueRef destPtr) + { + switch (expr) + { + case StructInitializerNode structInit: + EmitStructInitializer(structInit, destPtr); + return; + case ConstArrayInitializerNode arrayInit: + EmitConstArrayInitializer(arrayInit, destPtr); + return; + default: + { + var result = EmitExpression(expr); + _builder.BuildStore(result, destPtr); + break; + } + } + } + + private LLVMValueRef EmitStringLiteral(StringLiteralNode stringLiteralNode) + { + var strValue = stringLiteralNode.Value; + var length = (ulong)Encoding.UTF8.GetByteCount(strValue); + var globalStr = _builder.BuildGlobalStringPtr(strValue); + var llvmStringType = MapType(stringLiteralNode.Type); + + var strAlloca = _builder.BuildAlloca(llvmStringType); + + var lengthPtr = _builder.BuildStructGEP2(llvmStringType, strAlloca, 0); + var lengthConst = LLVMValueRef.CreateConstInt(LLVMTypeRef.Int64, length); + _builder.BuildStore(lengthConst, lengthPtr); + + var dataPtr = _builder.BuildStructGEP2(llvmStringType, strAlloca, 1); + _builder.BuildStore(globalStr, dataPtr); + + return _builder.BuildLoad2(llvmStringType, strAlloca); + } + + private LLVMValueRef EmitCStringLiteral(CStringLiteralNode cStringLiteralNode) + { + return _builder.BuildGlobalStringPtr(cStringLiteralNode.Value); + } + + private LLVMValueRef EmitVariableIdentifier(VariableIdentifierNode v) + { + return _locals[v.NameToken.Value]; + } + + private LLVMValueRef EmitLocalFuncIdentifier(LocalFuncIdentifierNode localFuncIdentifierNode) + { + return _functions[FuncName(_module, localFuncIdentifierNode.NameToken.Value, localFuncIdentifierNode.ExternSymbolToken?.Value)]; + } + + private LLVMValueRef EmitModuleFuncIdentifier(ModuleFuncIdentifierNode moduleFuncIdentifierNode) + { + return _functions[FuncName(moduleFuncIdentifierNode.ModuleToken.Value, moduleFuncIdentifierNode.NameToken.Value, moduleFuncIdentifierNode.ExternSymbolToken?.Value)]; + } + + private LLVMValueRef EmitBinaryExpression(BinaryExpressionNode bin) + { + var left = EmitExpression(bin.Left); + var right = EmitExpression(bin.Right); + + var leftType = bin.Left.Type; + + var result = bin.Operator switch + { + BinaryOperator.Plus when leftType is NubIntType => _builder.BuildAdd(left, right), + BinaryOperator.Plus when leftType is NubFloatType => _builder.BuildFAdd(left, right), + + BinaryOperator.Minus when leftType is NubIntType => _builder.BuildSub(left, right), + BinaryOperator.Minus when leftType is NubFloatType => _builder.BuildFSub(left, right), + + BinaryOperator.Multiply when leftType is NubIntType => _builder.BuildMul(left, right), + BinaryOperator.Multiply when leftType is NubFloatType => _builder.BuildFMul(left, right), + + BinaryOperator.Divide when leftType is NubIntType intType => intType.Signed ? _builder.BuildSDiv(left, right) : _builder.BuildUDiv(left, right), + BinaryOperator.Divide when leftType is NubFloatType => _builder.BuildFDiv(left, right), + + BinaryOperator.Modulo when leftType is NubIntType intType => intType.Signed ? _builder.BuildSRem(left, right) : _builder.BuildURem(left, right), + BinaryOperator.Modulo when leftType is NubFloatType => _builder.BuildFRem(left, right), + + BinaryOperator.LogicalAnd => _builder.BuildAnd(left, right), + BinaryOperator.LogicalOr => _builder.BuildOr(left, right), + + BinaryOperator.Equal when leftType is NubIntType or NubBoolType or NubPointerType => _builder.BuildICmp(LLVMIntPredicate.LLVMIntEQ, left, right), + BinaryOperator.Equal when leftType is NubFloatType => _builder.BuildFCmp(LLVMRealPredicate.LLVMRealOEQ, left, right), + + BinaryOperator.NotEqual when leftType is NubIntType or NubBoolType or NubPointerType => _builder.BuildICmp(LLVMIntPredicate.LLVMIntNE, left, right), + BinaryOperator.NotEqual when leftType is NubFloatType => _builder.BuildFCmp(LLVMRealPredicate.LLVMRealONE, left, right), + + BinaryOperator.GreaterThan when leftType is NubIntType intType => _builder.BuildICmp(intType.Signed ? LLVMIntPredicate.LLVMIntSGT : LLVMIntPredicate.LLVMIntUGT, left, right), + BinaryOperator.GreaterThan when leftType is NubFloatType => _builder.BuildFCmp(LLVMRealPredicate.LLVMRealOGT, left, right), + + BinaryOperator.GreaterThanOrEqual when leftType is NubIntType intType => _builder.BuildICmp(intType.Signed ? LLVMIntPredicate.LLVMIntSGE : LLVMIntPredicate.LLVMIntUGE, left, right), + BinaryOperator.GreaterThanOrEqual when leftType is NubFloatType => _builder.BuildFCmp(LLVMRealPredicate.LLVMRealOGE, left, right), + + BinaryOperator.LessThan when leftType is NubIntType intType => _builder.BuildICmp(intType.Signed ? LLVMIntPredicate.LLVMIntSLT : LLVMIntPredicate.LLVMIntULT, left, right), + BinaryOperator.LessThan when leftType is NubFloatType => _builder.BuildFCmp(LLVMRealPredicate.LLVMRealOLT, left, right), + + BinaryOperator.LessThanOrEqual when leftType is NubIntType intType => _builder.BuildICmp(intType.Signed ? LLVMIntPredicate.LLVMIntSLE : LLVMIntPredicate.LLVMIntULE, left, right), + BinaryOperator.LessThanOrEqual when leftType is NubFloatType => _builder.BuildFCmp(LLVMRealPredicate.LLVMRealOLE, left, right), + + BinaryOperator.LeftShift => _builder.BuildShl(left, right), + BinaryOperator.RightShift when leftType is NubIntType intType => intType.Signed ? _builder.BuildAShr(left, right) : _builder.BuildLShr(left, right), + + BinaryOperator.BitwiseAnd => _builder.BuildAnd(left, right), + BinaryOperator.BitwiseXor => _builder.BuildXor(left, right), + BinaryOperator.BitwiseOr => _builder.BuildOr(left, right), + + _ => throw new ArgumentOutOfRangeException(nameof(bin.Operator)) + }; + + return result; + } + + private LLVMValueRef EmitUnaryExpression(UnaryExpressionNode unary) + { + var operand = EmitExpression(unary.Operand); + + var result = unary.Operator switch + { + UnaryOperator.Negate when unary.Operand.Type is NubIntType => _builder.BuildNeg(operand), + UnaryOperator.Negate when unary.Operand.Type is NubFloatType => _builder.BuildFNeg(operand), + UnaryOperator.Invert => _builder.BuildXor(operand, LLVMValueRef.CreateConstInt(LLVMTypeRef.Int1, 1)), + _ => throw new ArgumentOutOfRangeException(nameof(unary.Operator)) + }; + + return result; + } + + private LLVMValueRef EmitFuncCall(FuncCallNode funcCall) + { + var funcPtr = EmitExpression(funcCall.Expression); + var args = funcCall.Parameters.Select(x => EmitExpression(x)).ToArray(); + return _builder.BuildCall2(MapType(funcCall.Expression.Type), funcPtr, args, funcCall.Type is NubVoidType ? "" : "call"); + } + + private LLVMValueRef EmitStructFieldAccess(StructFieldAccessNode field) + { + var target = EmitExpression(field.Target, asLValue: true); + var structType = (NubStructType)field.Target.Type; + var index = structType.GetFieldIndex(field.FieldToken.Value); + + var llvmStructType = _structTypes[StructName(structType.Module, structType.Name)]; + return _builder.BuildStructGEP2(llvmStructType, target, (uint)index); + } + + private LLVMValueRef EmitConstArrayIndexAccess(ConstArrayIndexAccessNode constArrayIndexAccessNode) + { + var arrayPtr = EmitExpression(constArrayIndexAccessNode.Target, asLValue: true); + var index = EmitExpression(constArrayIndexAccessNode.Index); + var indices = new[] { LLVMValueRef.CreateConstInt(LLVMTypeRef.Int32, 0), index }; + return _builder.BuildInBoundsGEP2(MapType(constArrayIndexAccessNode.Target.Type), arrayPtr, indices); + } + + private LLVMValueRef EmitSliceIndexAccess(SliceIndexAccessNode sliceIndexAccessNode) + { + var slicePtr = EmitExpression(sliceIndexAccessNode.Target, asLValue: true); + var index = EmitExpression(sliceIndexAccessNode.Index); + + var sliceType = (NubSliceType)sliceIndexAccessNode.Target.Type; + var llvmSliceType = MapType(sliceType); + var elementType = MapType(sliceType.ElementType); + + var dataPtrPtr = _builder.BuildStructGEP2(llvmSliceType, slicePtr, 1); + var dataPtr = _builder.BuildLoad2(LLVMTypeRef.CreatePointer(elementType, 0), dataPtrPtr); + return _builder.BuildInBoundsGEP2(elementType, dataPtr, [index]); + } + + private LLVMValueRef EmitArrayIndexAccess(ArrayIndexAccessNode arrayIndexAccessNode) + { + var arrayPtr = EmitExpression(arrayIndexAccessNode.Target); + var index = EmitExpression(arrayIndexAccessNode.Index); + return _builder.BuildGEP2(MapType(arrayIndexAccessNode.Target.Type), arrayPtr, [index]); + } + + private LLVMValueRef EmitConstArrayInitializer(ConstArrayInitializerNode constArrayInitializerNode, LLVMValueRef? destination = null) + { + var arrayType = (NubConstArrayType)constArrayInitializerNode.Type; + var llvmType = MapType(arrayType); + + destination ??= _builder.BuildAlloca(llvmType); + + for (var i = 0; i < constArrayInitializerNode.Values.Count; i++) + { + var indices = new[] + { + LLVMValueRef.CreateConstInt(LLVMTypeRef.Int32, 0), + LLVMValueRef.CreateConstInt(LLVMTypeRef.Int32, (ulong)i) + }; + + var elementPtr = _builder.BuildInBoundsGEP2(llvmType, destination.Value, indices); + + EmitExpressionInto(constArrayInitializerNode.Values[i], elementPtr); + } + + return destination.Value; + } + + private LLVMValueRef EmitStructInitializer(StructInitializerNode structInitializerNode, LLVMValueRef? destination = null) + { + var type = (NubStructType)structInitializerNode.Type; + var llvmType = MapType(type); + + destination ??= _builder.BuildAlloca(llvmType); + + var constructorType = LLVMTypeRef.CreateFunction(LLVMTypeRef.Void, [LLVMTypeRef.CreatePointer(_structTypes[StructName(type.Module, type.Name)], 0)]); + var constructor = _functions[StructConstructorName(type.Module, type.Name)]; + _builder.BuildCall2(constructorType, constructor, [destination.Value]); + + foreach (var (name, value) in structInitializerNode.Initializers) + { + var fieldIndex = type.GetFieldIndex(name.Value); + var fieldPtr = _builder.BuildStructGEP2(llvmType, destination.Value, (uint)fieldIndex); + EmitExpressionInto(value, fieldPtr); + } + + return destination.Value; + } + + private LLVMValueRef EmitAddressOf(AddressOfNode addr) + { + return EmitExpression(addr.Target, asLValue: true); + } + + private LLVMValueRef EmitDereference(DereferenceNode deref) + { + return EmitExpression(deref.Target, asLValue: false); + } + + private LLVMValueRef EmitCast(CastNode castNode) + { + return castNode.ConversionType switch + { + CastNode.Conversion.IntToInt => EmitIntToIntCast(castNode), + CastNode.Conversion.FloatToFloat => EmitFloatToFloatCast(castNode), + CastNode.Conversion.IntToFloat => EmitIntToFloatCast(castNode), + CastNode.Conversion.FloatToInt => EmitFloatToIntCast(castNode), + CastNode.Conversion.PointerToPointer or CastNode.Conversion.PointerToUInt64 or CastNode.Conversion.UInt64ToPointer => _builder.BuildIntToPtr(EmitExpression(castNode.Value), MapType(castNode.Type)), + CastNode.Conversion.ConstArrayToSlice => EmitConstArrayToSliceCast(castNode), + CastNode.Conversion.ConstArrayToArray => EmitConstArrayToArrayCast(castNode), + CastNode.Conversion.StringToCString => EmitStringToCStringCast(castNode), + _ => throw new ArgumentOutOfRangeException(nameof(castNode.ConversionType)) + }; + } + + private LLVMValueRef EmitIntToIntCast(CastNode castNode) + { + var sourceInt = (NubIntType)castNode.Value.Type; + var targetInt = (NubIntType)castNode.Type; + var source = EmitExpression(castNode.Value); + + if (sourceInt.Width < targetInt.Width) + { + return sourceInt.Signed + ? _builder.BuildSExt(source, MapType(targetInt)) + : _builder.BuildZExt(source, MapType(targetInt)); + } + + if (sourceInt.Width > targetInt.Width) + { + return _builder.BuildTrunc(source, MapType(targetInt)); + } + + return _builder.BuildBitCast(source, MapType(targetInt)); + } + + private LLVMValueRef EmitFloatToFloatCast(CastNode castNode) + { + var sourceFloat = (NubFloatType)castNode.Value.Type; + var targetFloat = (NubFloatType)castNode.Type; + var source = EmitExpression(castNode.Value); + + return sourceFloat.Width < targetFloat.Width + ? _builder.BuildFPExt(source, MapType(castNode.Type)) + : _builder.BuildFPTrunc(source, MapType(castNode.Type)); + } + + private LLVMValueRef EmitIntToFloatCast(CastNode castNode) + { + var sourceInt = (NubIntType)castNode.Value.Type; + var source = EmitExpression(castNode.Value); + + return sourceInt.Signed + ? _builder.BuildSIToFP(source, MapType(castNode.Type)) + : _builder.BuildUIToFP(source, MapType(castNode.Type)); + } + + private LLVMValueRef EmitFloatToIntCast(CastNode castNode) + { + var targetInt = (NubIntType)castNode.Type; + var source = EmitExpression(castNode.Value); + + return targetInt.Signed + ? _builder.BuildFPToSI(source, MapType(targetInt)) + : _builder.BuildFPToUI(source, MapType(targetInt)); + } + + private LLVMValueRef EmitConstArrayToSliceCast(CastNode castNode) + { + var sourceArrayType = (NubConstArrayType)castNode.Value.Type; + var targetSliceType = (NubSliceType)castNode.Type; + var source = EmitExpression(castNode.Value, asLValue: true); + + var llvmArrayType = MapType(sourceArrayType); + var llvmSliceType = MapType(targetSliceType); + + var indices = new[] + { + LLVMValueRef.CreateConstInt(LLVMTypeRef.Int32, 0), + LLVMValueRef.CreateConstInt(LLVMTypeRef.Int32, 0) + }; + + var firstElementPtr = _builder.BuildInBoundsGEP2(llvmArrayType, source, indices); + + var slicePtr = _builder.BuildAlloca(llvmSliceType); + + var lengthPtr = _builder.BuildStructGEP2(llvmSliceType, slicePtr, 0); + var length = LLVMValueRef.CreateConstInt(LLVMTypeRef.Int64, sourceArrayType.Size); + _builder.BuildStore(length, lengthPtr); + + var dataPtrPtr = _builder.BuildStructGEP2(llvmSliceType, slicePtr, 1); + _builder.BuildStore(firstElementPtr, dataPtrPtr); + + return _builder.BuildLoad2(llvmSliceType, slicePtr); + } + + private LLVMValueRef EmitConstArrayToArrayCast(CastNode castNode) + { + var sourceArrayType = (NubConstArrayType)castNode.Value.Type; + var source = EmitExpression(castNode.Value, asLValue: true); + + var indices = new[] + { + LLVMValueRef.CreateConstInt(LLVMTypeRef.Int32, 0), + LLVMValueRef.CreateConstInt(LLVMTypeRef.Int32, 0) + }; + + return _builder.BuildInBoundsGEP2(MapType(sourceArrayType), source, indices); + } + + private LLVMValueRef EmitStringToCStringCast(CastNode castNode) + { + var source = EmitExpression(castNode.Value, asLValue: true); + var dataPtrPtr = _builder.BuildStructGEP2(MapType(castNode.Value.Type), source, 1); + return _builder.BuildLoad2(LLVMTypeRef.CreatePointer(LLVMTypeRef.Int8, 0), dataPtrPtr); + } + + private LLVMTypeRef MapType(NubType type) + { + return type switch + { + NubBoolType => LLVMTypeRef.Int1, + NubIntType intType => LLVMTypeRef.CreateInt((uint)intType.Width), + NubFloatType floatType => floatType.Width == 32 ? LLVMTypeRef.Float : LLVMTypeRef.Double, + NubFuncType funcType => LLVMTypeRef.CreateFunction(MapType(funcType.ReturnType), funcType.Parameters.Select(MapType).ToArray()), + NubPointerType ptrType => LLVMTypeRef.CreatePointer(MapType(ptrType.BaseType), 0), + NubSliceType nubSliceType => MapSliceType(nubSliceType), + NubStringType => _structTypes["nub.string"], + NubArrayType arrType => LLVMTypeRef.CreatePointer(MapType(arrType.ElementType), 0), + NubConstArrayType constArr => LLVMTypeRef.CreateArray(MapType(constArr.ElementType), (uint)constArr.Size), + NubStructType structType => _structTypes[StructName(structType.Module, structType.Name)], + NubVoidType => LLVMTypeRef.Void, + _ => throw new ArgumentOutOfRangeException(nameof(type), type, null) + }; + } + + private LLVMTypeRef MapSliceType(NubSliceType nubSliceType) + { + var mangledName = NameMangler.Mangle(nubSliceType.ElementType); + var name = $"nub.slice.{mangledName}"; + if (!_structTypes.TryGetValue(name, out var type)) + { + type = _context.CreateNamedStruct(name); + type.StructSetBody([LLVMTypeRef.Int64, LLVMTypeRef.CreatePointer(MapType(nubSliceType.ElementType), 0)], false); + _structTypes[name] = type; + } + + return type; + } + + private static string StructName(string module, string name) + { + return $"struct.{module}.{name}"; + } + + private static string StructConstructorName(string module, string name) + { + return $"{StructName(module, name)}.new"; + } + + private static string FuncName(string module, string name, string? externSymbol) + { + if (externSymbol != null) + { + return externSymbol; + } + + return $"{module}.{name}"; + } +} \ No newline at end of file diff --git a/compiler/NubLang/NubLang.csproj b/compiler/NubLang/NubLang.csproj index b682a68..7bda37f 100644 --- a/compiler/NubLang/NubLang.csproj +++ b/compiler/NubLang/NubLang.csproj @@ -7,4 +7,8 @@ true + + + + diff --git a/examples/playgroud/build.sh b/examples/playgroud/build.sh index 1df6fc6..1ca616f 100755 --- a/examples/playgroud/build.sh +++ b/examples/playgroud/build.sh @@ -2,5 +2,5 @@ set -euo pipefail -nubc main.nub test.nub -clang .build/main.ll .build/test.ll -o .build/out \ No newline at end of file +nubc main.nub +clang .build/main.ll -o .build/out \ No newline at end of file diff --git a/examples/playgroud/main.nub b/examples/playgroud/main.nub index da133f4..eb22e6b 100644 --- a/examples/playgroud/main.nub +++ b/examples/playgroud/main.nub @@ -2,20 +2,13 @@ module main extern "puts" func puts(text: ^i8) -struct Test -{ - field: u32 +struct Test { + test: ^i8 = "test1" } extern "main" func main(argc: i64, argv: [?]^i8) { - let x: ^i8 = "test" - // test - ^x^ = "uwu" - puts(x) -} + let x = "test" -func test(test: Test): Test -{ - return test + puts(x) } \ No newline at end of file diff --git a/examples/playgroud/test.nub b/examples/playgroud/test.nub deleted file mode 100644 index b8f9d6a..0000000 --- a/examples/playgroud/test.nub +++ /dev/null @@ -1,8 +0,0 @@ -module test - -extern "puts" func puts(text: ^i8) - -func test() -{ - puts("uwu") -} \ No newline at end of file