using LLVMSharp.Interop;
using NubLang.Ast;

namespace NubLang.Generation;

/// <summary>
/// Lowers a <see cref="CompilationUnit"/> into an LLVM module using LLVMSharp.
/// </summary>
/// <remarks>
/// Several statement and expression kinds (break/continue/defer/if/while, conversions, dereference,
/// address-of, slices, size builtin, string literals) are not implemented yet and throw
/// <see cref="NotImplementedException"/> when encountered.
/// </remarks>
public class Generator
{
    private readonly CompilationUnit _compilationUnit;

    // Set on the first call to Generate(); later calls return the already-built module.
    private bool _done;

    // Cache of LLVM types created for struct, slice and string Nub types, keyed by a
    // generator-internal string, so each such type is created only once per module.
    private readonly Dictionary<string, LLVMTypeRef> _namedTypes = new();

    // Maps a local variable or parameter name to the stack slot that holds its value.
    // Cleared before each function body is emitted.
    private readonly Dictionary<string, LLVMValueRef> _namedValues = [];

    private readonly LLVMModuleRef _module;
    private LLVMBuilderRef _builder;
    private LLVMValueRef _currentFunction;

    /// <summary>Creates a generator for <paramref name="compilationUnit"/>.</summary>
    /// <param name="compilationUnit">The unit whose functions will be emitted.</param>
    /// <param name="fileName">Used as the name of the created LLVM module.</param>
    public Generator(CompilationUnit compilationUnit, string fileName)
    {
        _compilationUnit = compilationUnit;
        _module = LLVMModuleRef.CreateWithName(fileName);
    }

    /// <summary>Maps a Nub type to the LLVM type used to hold values of that type.</summary>
    /// <exception cref="ArgumentOutOfRangeException">The type (or float width) is not supported.</exception>
    private LLVMTypeRef MapType(NubType nubType)
    {
        return nubType switch
        {
            NubArrayType nubArrayType => LLVMTypeRef.CreatePointer(MapType(nubArrayType.ElementType), 0),
            NubBoolType => LLVMTypeRef.Int1,
            NubConstArrayType nubConstArrayType => LLVMTypeRef.CreatePointer(MapType(nubConstArrayType.ElementType), 0),
            NubCStringType => LLVMTypeRef.CreatePointer(LLVMTypeRef.Int8, 0),
            NubFloatType nubFloatType => nubFloatType.Width switch
            {
                32 => LLVMTypeRef.Float,
                64 => LLVMTypeRef.Double,
                _ => throw new ArgumentOutOfRangeException(nameof(nubFloatType.Width))
            },
            // Function-typed values (parameters, locals, struct fields) are function pointers:
            // a bare LLVM function type is not first-class and cannot be allocated, stored or passed.
            NubFuncType nubFuncType => LLVMTypeRef.CreatePointer(MapFunctionSignature(nubFuncType), 0),
            NubIntType nubIntType => LLVMTypeRef.CreateInt((uint)nubIntType.Width),
            NubPointerType nubPointerType => LLVMTypeRef.CreatePointer(MapType(nubPointerType.BaseType), 0),
            NubSliceType nubSliceType => GetOrCreateSliceType(nubSliceType),
            NubStringType => GetOrCreateStringType(),
            NubStructType nubStructType => GetOrCreateStructType(nubStructType),
            NubVoidType => LLVMTypeRef.Void,
            _ => throw new ArgumentOutOfRangeException(nameof(nubType))
        };
    }

    /// <summary>Builds the LLVM function type (signature) described by a Nub function type.</summary>
    private LLVMTypeRef MapFunctionSignature(NubFuncType nubFuncType)
    {
        var parameterTypes = nubFuncType.Parameters.Select(MapType).ToArray();
        return LLVMTypeRef.CreateFunction(MapType(nubFuncType.ReturnType), parameterTypes);
    }

    /// <summary>Returns the named LLVM struct for <paramref name="structType"/>, creating it on first use.</summary>
    private LLVMTypeRef GetOrCreateStructType(NubStructType structType)
    {
        var key = $"struct.{structType.Module}.{structType.Name}";
        if (!_namedTypes.TryGetValue(key, out var llvmType))
        {
            llvmType = _module.Context.CreateNamedStruct(structType.Name);

            // Register before mapping the fields so a field that refers back to this struct
            // (e.g. through a pointer) resolves to the same named type instead of recursing forever.
            _namedTypes[key] = llvmType;

            var fields = structType.Fields.Select(f => MapType(f.Type)).ToArray();
            llvmType.StructSetBody(fields, false);
        }

        return llvmType;
    }

    /// <summary>Returns the <c>{ ptr, i64 }</c> struct used for slices of the given element type.</summary>
    private LLVMTypeRef GetOrCreateSliceType(NubSliceType sliceType)
    {
        var key = $"slice.{sliceType.ElementType}";
        if (!_namedTypes.TryGetValue(key, out var llvmType))
        {
            llvmType = LLVMTypeRef.CreateStruct([LLVMTypeRef.CreatePointer(MapType(sliceType.ElementType), 0), LLVMTypeRef.Int64], false);
            _namedTypes[key] = llvmType;
        }

        return llvmType;
    }

    /// <summary>Returns the named <c>string</c> struct <c>{ ptr, i64 }</c>, creating it on first use.</summary>
    private LLVMTypeRef GetOrCreateStringType()
    {
        const string key = "string";
        if (!_namedTypes.TryGetValue(key, out var llvmType))
        {
            llvmType = _module.Context.CreateNamedStruct("string");
            llvmType.StructSetBody([LLVMTypeRef.CreatePointer(LLVMTypeRef.Int8, 0), LLVMTypeRef.Int64], false);
            _namedTypes[key] = llvmType;
        }

        return llvmType;
    }

    /// <summary>
    /// Emits every function of the compilation unit into the module, verifies it and returns it.
    /// Subsequent calls return the same module without re-emitting.
    /// </summary>
    public LLVMModuleRef Generate()
    {
        if (_done)
        {
            return _module;
        }

        _done = true;
        _builder = _module.Context.CreateBuilder();

        try
        {
            // Declare every prototype first so a call can reference a function that is
            // defined later in the compilation unit.
            foreach (var funcNode in _compilationUnit.Functions)
            {
                DeclareFunction(funcNode);
            }

            foreach (var funcNode in _compilationUnit.Functions)
            {
                EmitFunction(funcNode);
            }
        }
        finally
        {
            // The builder is only used while generating; release its native handle.
            _builder.Dispose();
        }

        if (!_module.TryVerify(LLVMVerifierFailureAction.LLVMPrintMessageAction, out var error))
        {
            // NOTE(review): verification failures are only reported, not fatal. Consider throwing
            // InvalidOperationException once code generation is complete.
            Console.WriteLine($"Invalid LLVM module: {error}");
        }

        return _module;
    }

    /// <summary>Returns the module's function for <paramref name="funcNode"/>, adding a declaration if absent.</summary>
    private LLVMValueRef DeclareFunction(FuncNode funcNode)
    {
        var existing = _module.GetNamedFunction(funcNode.Name);
        if (existing.Handle != IntPtr.Zero)
        {
            return existing;
        }

        var parameters = funcNode.Prototype.Parameters.Select(x => MapType(x.Type)).ToArray();
        var functionType = LLVMTypeRef.CreateFunction(MapType(funcNode.Prototype.ReturnType), parameters);
        return _module.AddFunction(funcNode.Name, functionType);
    }

    /// <summary>
    /// Emits the body of <paramref name="funcNode"/>. Functions without a body stay declarations.
    /// Each parameter is spilled to its own stack slot so it can be addressed like a local.
    /// </summary>
    private void EmitFunction(FuncNode funcNode)
    {
        _currentFunction = DeclareFunction(funcNode);

        if (funcNode.Body == null)
        {
            return;
        }

        _namedValues.Clear();

        var entryBlock = _currentFunction.AppendBasicBlock("entry");
        _builder.PositionAtEnd(entryBlock);

        for (var i = 0; i < funcNode.Prototype.Parameters.Count; i++)
        {
            var param = funcNode.Prototype.Parameters[i];
            var llvmParam = _currentFunction.GetParam((uint)i);
            llvmParam.Name = param.Name;

            var alloca = _builder.BuildAlloca(MapType(param.Type));
            _builder.BuildStore(llvmParam, alloca);
            _namedValues[param.Name] = alloca;
        }

        EmitBlock(funcNode.Body);

        // A void function may fall off the end of its body without an explicit `return`;
        // LLVM requires the final block to end in a terminator.
        if (funcNode.Prototype.ReturnType is NubVoidType && _builder.InsertBlock.Terminator.Handle == IntPtr.Zero)
        {
            _builder.BuildRetVoid();
        }
    }

    /// <summary>Emits each statement of <paramref name="blockNode"/> in order.</summary>
    private void EmitBlock(BlockNode blockNode)
    {
        foreach (var statement in blockNode.Statements)
        {
            EmitStatement(statement);
        }
    }

    /// <summary>Dispatches a statement to its emitter.</summary>
    /// <exception cref="ArgumentOutOfRangeException">Unknown statement kind.</exception>
    private void EmitStatement(StatementNode statement)
    {
        switch (statement)
        {
            case AssignmentNode assignmentNode:
                EmitAssignment(assignmentNode);
                break;
            case BlockNode blockNode:
                EmitBlock(blockNode);
                break;
            case BreakNode breakNode:
                EmitBreak(breakNode);
                break;
            case ContinueNode continueNode:
                EmitContinue(continueNode);
                break;
            case DeferNode deferNode:
                EmitDefer(deferNode);
                break;
            case IfNode ifNode:
                EmitIf(ifNode);
                break;
            case ReturnNode returnNode:
                EmitReturn(returnNode);
                break;
            case StatementFuncCallNode statementFuncCallNode:
                EmitStatementFuncCall(statementFuncCallNode);
                break;
            case VariableDeclarationNode variableDeclarationNode:
                EmitVariableDeclaration(variableDeclarationNode);
                break;
            case WhileNode whileNode:
                EmitWhile(whileNode);
                break;
            default:
                throw new ArgumentOutOfRangeException(nameof(statement));
        }
    }

    /// <summary>Stores the value of the right-hand side into the target's address (value copy).</summary>
    private void EmitAssignment(AssignmentNode assignmentNode)
    {
        var target = EmitLValue(assignmentNode.Target);
        var value = EmitExpression(assignmentNode.Value);
        _builder.BuildStore(value, target);
    }

    private void EmitBreak(BreakNode breakNode)
    {
        throw new NotImplementedException();
    }

    private void EmitContinue(ContinueNode continueNode)
    {
        throw new NotImplementedException();
    }

    private void EmitDefer(DeferNode deferNode)
    {
        throw new NotImplementedException();
    }

    private void EmitIf(IfNode ifNode)
    {
        throw new NotImplementedException();
    }

    /// <summary>Emits <c>ret void</c> or <c>ret value</c>.</summary>
    private void EmitReturn(ReturnNode returnNode)
    {
        if (returnNode.Value == null)
        {
            _builder.BuildRetVoid();
        }
        else
        {
            var returnValue = EmitExpression(returnNode.Value);
            _builder.BuildRet(returnValue);
        }
    }

    /// <summary>Emits a call whose result is discarded.</summary>
    private void EmitStatementFuncCall(StatementFuncCallNode statementFuncCallNode)
    {
        EmitFuncCall(statementFuncCallNode.FuncCall);
    }

    /// <summary>Allocates a stack slot for a new local and stores its initializer, if any.</summary>
    private void EmitVariableDeclaration(VariableDeclarationNode variableDeclarationNode)
    {
        // A struct initializer already materialises a fresh stack slot, so bind the name to it
        // instead of copying. Any other initializer (including l-values such as another variable or
        // an array element) is loaded and copied into the variable's own slot. This way writes to the
        // new variable never alias the source, matching the copy semantics of EmitAssignment.
        if (variableDeclarationNode.Assignment is StructInitializerNode structInitializerNode)
        {
            _namedValues[variableDeclarationNode.Name] = EmitStructInitializer(structInitializerNode);
            return;
        }

        var allocaType = MapType(variableDeclarationNode.Type);
        var alloca = _builder.BuildAlloca(allocaType);

        if (variableDeclarationNode.Assignment != null)
        {
            var initValue = EmitExpression(variableDeclarationNode.Assignment);
            _builder.BuildStore(initValue, alloca);
        }

        _namedValues[variableDeclarationNode.Name] = alloca;
    }

    private void EmitWhile(WhileNode whileNode)
    {
        throw new NotImplementedException();
    }

    /// <summary>
    /// Emits an expression and returns its value: l-values are emitted as addresses and then loaded,
    /// r-values are emitted directly.
    /// </summary>
    private LLVMValueRef EmitExpression(ExpressionNode expressionNode)
    {
        switch (expressionNode)
        {
            case LValueExpressionNode lvalue:
            {
                var value = EmitLValue(lvalue);
                return _builder.BuildLoad2(MapType(lvalue.Type), value);
            }
            case RValueExpressionNode rvalue:
            {
                return EmitRValue(rvalue);
            }
            default:
            {
                throw new ArgumentOutOfRangeException(nameof(expressionNode));
            }
        }
    }

    /// <summary>Emits an l-value expression and returns its address (a pointer).</summary>
    private LLVMValueRef EmitLValue(LValueExpressionNode lValueNode)
    {
        return lValueNode switch
        {
            ArrayIndexAccessNode arrayIndexAccessNode => EmitArrayIndexAccess(arrayIndexAccessNode),
            ConstArrayIndexAccessNode constArrayIndexAccessNode => EmitConstArrayIndexAccess(constArrayIndexAccessNode),
            DereferenceNode dereferenceNode => EmitDereference(dereferenceNode),
            LValueIdentifierNode lValueIdentifierNode => EmitLValueIdentifier(lValueIdentifierNode),
            SliceIndexAccessNode sliceIndexAccessNode => EmitSliceIndexAccess(sliceIndexAccessNode),
            StructFieldAccessNode structFieldAccessNode => EmitStructFieldAccess(structFieldAccessNode),
            StructInitializerNode structInitializerNode => EmitStructInitializer(structInitializerNode),
            _ => throw new ArgumentOutOfRangeException(nameof(lValueNode))
        };
    }

    /// <summary>Emits an r-value expression and returns its value.</summary>
    private LLVMValueRef EmitRValue(RValueExpressionNode rValueNode)
    {
        return rValueNode switch
        {
            AddressOfNode addressOfNode => EmitAddressOf(addressOfNode),
            ArrayInitializerNode arrayInitializerNode => EmitArrayInitializer(arrayInitializerNode),
            BinaryExpressionNode binaryExpressionNode => EmitBinaryExpression(binaryExpressionNode),
            BoolLiteralNode boolLiteralNode => EmitBoolLiteral(boolLiteralNode),
            ConvertFloatNode convertFloatNode => EmitConvertFloat(convertFloatNode),
            ConvertIntNode convertIntNode => EmitConvertInt(convertIntNode),
            CStringLiteralNode cStringLiteralNode => EmitCStringLiteral(cStringLiteralNode),
            Float32LiteralNode float32LiteralNode => EmitFloat32Literal(float32LiteralNode),
            Float64LiteralNode float64LiteralNode => EmitFloat64Literal(float64LiteralNode),
            FloatToIntBuiltinNode floatToIntBuiltinNode => EmitFloatToIntBuiltin(floatToIntBuiltinNode),
            FuncCallNode funcCallNode => EmitFuncCall(funcCallNode),
            FuncIdentifierNode funcIdentifierNode => EmitFuncIdentifier(funcIdentifierNode),
            IntLiteralNode intLiteralNode => EmitIntLiteral(intLiteralNode),
            RValueIdentifierNode rValueIdentifierNode => EmitRValueIdentifier(rValueIdentifierNode),
            SizeBuiltinNode sizeBuiltinNode => EmitSizeBuiltin(sizeBuiltinNode),
            StringLiteralNode stringLiteralNode => EmitStringLiteral(stringLiteralNode),
            UIntLiteralNode uIntLiteralNode => EmitUIntLiteral(uIntLiteralNode),
            UnaryExpressionNode unaryExpressionNode => EmitUnaryExpression(unaryExpressionNode),
            _ => throw new ArgumentOutOfRangeException(nameof(rValueNode))
        };
    }

    /// <summary>Returns the address of <c>array[index]</c>; arrays are lowered to element pointers.</summary>
    private LLVMValueRef EmitArrayIndexAccess(ArrayIndexAccessNode arrayIndexAccessNode)
    {
        var arrayType = (NubArrayType)arrayIndexAccessNode.Target.Type;
        var elementType = MapType(arrayType.ElementType);
        var target = EmitExpression(arrayIndexAccessNode.Target);
        var index = EmitExpression(arrayIndexAccessNode.Index);
        return _builder.BuildGEP2(elementType, target, [index]);
    }

    /// <summary>Allocates <c>Capacity</c> elements and returns a pointer to the first one.</summary>
    private LLVMValueRef EmitArrayInitializer(ArrayInitializerNode arrayInitializerNode)
    {
        // NOTE(review): this is a stack allocation in the current function's frame. The pointer
        // becomes invalid once the function returns, so confirm the language never lets it escape.
        var capacity = EmitExpression(arrayInitializerNode.Capacity);
        return _builder.BuildArrayAlloca(MapType(arrayInitializerNode.ElementType), capacity);
    }

    /// <summary>Emits an arithmetic, comparison, bitwise or logical binary operation.</summary>
    /// <exception cref="NotSupportedException">The operator is not defined for the operand type.</exception>
    private LLVMValueRef EmitBinaryExpression(BinaryExpressionNode binaryExpressionNode)
    {
        // Both operands are always evaluated, so LogicalAnd/LogicalOr below do not short-circuit.
        var left = EmitExpression(binaryExpressionNode.Left);
        var right = EmitExpression(binaryExpressionNode.Right);
        var leftType = binaryExpressionNode.Left.Type;

        if (leftType is NubIntType)
        {
            // NOTE(review): division, remainder, ordering comparisons and right shift are always
            // emitted as signed operations. Confirm NubIntType is always signed, or branch on signedness.
            return binaryExpressionNode.Operator switch
            {
                BinaryOperator.Plus => _builder.BuildAdd(left, right),
                BinaryOperator.Minus => _builder.BuildSub(left, right),
                BinaryOperator.Multiply => _builder.BuildMul(left, right),
                BinaryOperator.Divide => _builder.BuildSDiv(left, right),
                BinaryOperator.Modulo => _builder.BuildSRem(left, right),
                BinaryOperator.Equal => _builder.BuildICmp(LLVMIntPredicate.LLVMIntEQ, left, right),
                BinaryOperator.NotEqual => _builder.BuildICmp(LLVMIntPredicate.LLVMIntNE, left, right),
                BinaryOperator.LessThan => _builder.BuildICmp(LLVMIntPredicate.LLVMIntSLT, left, right),
                BinaryOperator.LessThanOrEqual => _builder.BuildICmp(LLVMIntPredicate.LLVMIntSLE, left, right),
                BinaryOperator.GreaterThan => _builder.BuildICmp(LLVMIntPredicate.LLVMIntSGT, left, right),
                BinaryOperator.GreaterThanOrEqual => _builder.BuildICmp(LLVMIntPredicate.LLVMIntSGE, left, right),
                BinaryOperator.BitwiseAnd => _builder.BuildAnd(left, right),
                BinaryOperator.BitwiseOr => _builder.BuildOr(left, right),
                BinaryOperator.BitwiseXor => _builder.BuildXor(left, right),
                BinaryOperator.LeftShift => _builder.BuildShl(left, right),
                BinaryOperator.RightShift => _builder.BuildAShr(left, right),
                _ => throw new NotSupportedException($"Binary operator {binaryExpressionNode.Operator} not supported for int")
            };
        }

        if (leftType is NubFloatType)
        {
            // Ordered predicates: every comparison (including NotEqual) is false if either operand is NaN.
            return binaryExpressionNode.Operator switch
            {
                BinaryOperator.Plus => _builder.BuildFAdd(left, right),
                BinaryOperator.Minus => _builder.BuildFSub(left, right),
                BinaryOperator.Multiply => _builder.BuildFMul(left, right),
                BinaryOperator.Divide => _builder.BuildFDiv(left, right),
                BinaryOperator.Modulo => _builder.BuildFRem(left, right),
                BinaryOperator.Equal => _builder.BuildFCmp(LLVMRealPredicate.LLVMRealOEQ, left, right),
                BinaryOperator.NotEqual => _builder.BuildFCmp(LLVMRealPredicate.LLVMRealONE, left, right),
                BinaryOperator.LessThan => _builder.BuildFCmp(LLVMRealPredicate.LLVMRealOLT, left, right),
                BinaryOperator.LessThanOrEqual => _builder.BuildFCmp(LLVMRealPredicate.LLVMRealOLE, left, right),
                BinaryOperator.GreaterThan => _builder.BuildFCmp(LLVMRealPredicate.LLVMRealOGT, left, right),
                BinaryOperator.GreaterThanOrEqual => _builder.BuildFCmp(LLVMRealPredicate.LLVMRealOGE, left, right),
                _ => throw new NotSupportedException($"Binary operator {binaryExpressionNode.Operator} not supported for float")
            };
        }

        if (leftType is NubBoolType)
        {
            return binaryExpressionNode.Operator switch
            {
                BinaryOperator.LogicalAnd => _builder.BuildAnd(left, right),
                BinaryOperator.LogicalOr => _builder.BuildOr(left, right),
                BinaryOperator.Equal => _builder.BuildICmp(LLVMIntPredicate.LLVMIntEQ, left, right),
                BinaryOperator.NotEqual => _builder.BuildICmp(LLVMIntPredicate.LLVMIntNE, left, right),
                _ => throw new NotSupportedException($"Binary operator {binaryExpressionNode.Operator} not supported for bool")
            };
        }

        throw new NotSupportedException($"Binary operations for type {leftType} not supported");
    }

    /// <summary>Emits an <c>i1</c> constant.</summary>
    private LLVMValueRef EmitBoolLiteral(BoolLiteralNode boolLiteralNode)
    {
        return LLVMValueRef.CreateConstInt(LLVMTypeRef.Int1, boolLiteralNode.Value ? 1ul : 0ul);
    }

    /// <summary>Returns the address of <c>constArray[index]</c>.</summary>
    private LLVMValueRef EmitConstArrayIndexAccess(ConstArrayIndexAccessNode constArrayIndexAccessNode)
    {
        // MapType lowers const arrays to a pointer to their first element, so (like dynamic arrays)
        // indexing is a single-index GEP over the element type. A two-index GEP over the pointer
        // type would index into a non-aggregate, which is invalid IR.
        var arrayType = (NubConstArrayType)constArrayIndexAccessNode.Target.Type;
        var elementType = MapType(arrayType.ElementType);
        var target = EmitExpression(constArrayIndexAccessNode.Target);
        var index = EmitExpression(constArrayIndexAccessNode.Index);
        return _builder.BuildGEP2(elementType, target, [index]);
    }

    private LLVMValueRef EmitConvertFloat(ConvertFloatNode convertFloatNode)
    {
        throw new NotImplementedException();
    }

    private LLVMValueRef EmitConvertInt(ConvertIntNode convertIntNode)
    {
        throw new NotImplementedException();
    }

    /// <summary>Emits a global NUL-terminated string constant and returns a pointer to it.</summary>
    private LLVMValueRef EmitCStringLiteral(CStringLiteralNode cStringLiteralNode)
    {
        return _builder.BuildGlobalStringPtr(cStringLiteralNode.Value, "str");
    }

    private LLVMValueRef EmitDereference(DereferenceNode dereferenceNode)
    {
        throw new NotImplementedException();
    }

    /// <summary>Emits a <c>float</c> constant.</summary>
    private LLVMValueRef EmitFloat32Literal(Float32LiteralNode float32LiteralNode)
    {
        return LLVMValueRef.CreateConstReal(LLVMTypeRef.Float, float32LiteralNode.Value);
    }

    /// <summary>Emits a <c>double</c> constant.</summary>
    private LLVMValueRef EmitFloat64Literal(Float64LiteralNode float64LiteralNode)
    {
        return LLVMValueRef.CreateConstReal(LLVMTypeRef.Double, float64LiteralNode.Value);
    }

    private LLVMValueRef EmitFloatToIntBuiltin(FloatToIntBuiltinNode floatToIntBuiltinNode)
    {
        throw new NotImplementedException();
    }

    /// <summary>Emits a direct or indirect call and returns its result value.</summary>
    /// <exception cref="NotSupportedException">The callee expression is not of a function type.</exception>
    private LLVMValueRef EmitFuncCall(FuncCallNode funcCallNode)
    {
        if (funcCallNode.Expression.Type is not NubFuncType funcType)
        {
            throw new NotSupportedException($"Cannot call a value of type {funcCallNode.Expression.Type}");
        }

        var function = EmitExpression(funcCallNode.Expression);
        var args = funcCallNode.Parameters.Select(EmitExpression).ToArray();

        // The call needs the function signature, not the function-pointer type MapType returns.
        return _builder.BuildCall2(MapFunctionSignature(funcType), function, args);
    }

    /// <summary>Looks up a module function by name (all prototypes are declared up front by Generate).</summary>
    private LLVMValueRef EmitFuncIdentifier(FuncIdentifierNode funcIdentifierNode)
    {
        return _module.GetNamedFunction(funcIdentifierNode.Name);
    }

    /// <summary>Emits a signed integer constant of the literal's mapped width.</summary>
    private LLVMValueRef EmitIntLiteral(IntLiteralNode intLiteralNode)
    {
        return LLVMValueRef.CreateConstInt(MapType(intLiteralNode.Type), (ulong)intLiteralNode.Value, true);
    }

    private LLVMValueRef EmitAddressOf(AddressOfNode addressOfNode)
    {
        throw new NotImplementedException();
    }

    /// <summary>Returns the stack slot bound to the identifier.</summary>
    private LLVMValueRef EmitLValueIdentifier(LValueIdentifierNode lValueIdentifierNode)
    {
        return _namedValues[lValueIdentifierNode.Name];
    }

    /// <summary>Loads the current value of a local variable or parameter.</summary>
    private LLVMValueRef EmitRValueIdentifier(RValueIdentifierNode rValueIdentifierNode)
    {
        return _builder.BuildLoad2(MapType(rValueIdentifierNode.Type), _namedValues[rValueIdentifierNode.Name], rValueIdentifierNode.Name);
    }

    private LLVMValueRef EmitSizeBuiltin(SizeBuiltinNode sizeBuiltinNode)
    {
        throw new NotImplementedException();
    }

    private LLVMValueRef EmitSliceIndexAccess(SliceIndexAccessNode sliceIndexAccessNode)
    {
        throw new NotImplementedException();
    }

    private LLVMValueRef EmitStringLiteral(StringLiteralNode stringLiteralNode)
    {
        throw new NotImplementedException();
    }

    /// <summary>Returns the address of a field within a struct l-value.</summary>
    private LLVMValueRef EmitStructFieldAccess(StructFieldAccessNode structFieldAccessNode)
    {
        var structType = (NubStructType)structFieldAccessNode.Target.Type;
        var target = EmitLValue(structFieldAccessNode.Target);
        var fieldIndex = GetFieldIndex(structType, structFieldAccessNode.Field);
        return _builder.BuildStructGEP2(MapType(structType), target, fieldIndex);
    }

    /// <summary>
    /// Allocates a stack slot for the struct, stores each listed initializer into its field and
    /// returns the slot's address. Fields without an initializer are left uninitialised.
    /// </summary>
    private LLVMValueRef EmitStructInitializer(StructInitializerNode structInitializerNode)
    {
        var type = MapType(structInitializerNode.StructType);
        var ptr = _builder.BuildAlloca(type);

        foreach (var initializer in structInitializerNode.Initializers)
        {
            var value = EmitExpression(initializer.Value);
            var fieldIndex = GetFieldIndex(structInitializerNode.StructType, initializer.Key);
            var fieldPtr = _builder.BuildStructGEP2(type, ptr, fieldIndex);
            _builder.BuildStore(value, fieldPtr);
        }

        return ptr;
    }

    /// <summary>Returns the zero-based position of <paramref name="fieldName"/> in <paramref name="structType"/>.</summary>
    /// <exception cref="InvalidOperationException">The struct has no field with that name.</exception>
    private static uint GetFieldIndex(NubStructType structType, string fieldName)
    {
        // FindIndex returns -1 when missing; casting that to uint would silently produce a huge index.
        var index = structType.Fields.FindIndex(x => x.Name == fieldName);
        if (index < 0)
        {
            throw new InvalidOperationException($"Struct {structType.Name} has no field {fieldName}");
        }

        return (uint)index;
    }

    /// <summary>Emits an unsigned integer constant of the literal's mapped width.</summary>
    private LLVMValueRef EmitUIntLiteral(UIntLiteralNode uIntLiteralNode)
    {
        return LLVMValueRef.CreateConstInt(MapType(uIntLiteralNode.Type), uIntLiteralNode.Value);
    }

    /// <summary>Emits integer/float negation or bitwise inversion.</summary>
    /// <exception cref="NotImplementedException">Unsupported operator/operand combination.</exception>
    private LLVMValueRef EmitUnaryExpression(UnaryExpressionNode unaryExpressionNode)
    {
        var operand = EmitExpression(unaryExpressionNode.Operand);
        return unaryExpressionNode.Operator switch
        {
            UnaryOperator.Negate when unaryExpressionNode.Operand.Type is NubIntType => _builder.BuildNeg(operand),
            UnaryOperator.Negate when unaryExpressionNode.Operand.Type is NubFloatType => _builder.BuildFNeg(operand),
            UnaryOperator.Invert => _builder.BuildNot(operand),
            _ => throw new NotImplementedException($"Unary operator {unaryExpressionNode.Operator} not implemented")
        };
    }
}