using LLVMSharp.Interop;
using NubLang.Ast;

namespace NubLang.Generation;

/// <summary>
/// Lowers a <see cref="CompilationUnit"/> into an LLVM IR module and returns its textual form.
/// </summary>
/// <remarks>
/// Currently only function signatures and parameter spilling are lowered; the statement and
/// expression emitters are stubs that throw <see cref="NotImplementedException"/>.
/// A generator instance produces its module once; see <see cref="Generate"/>.
/// </remarks>
public class Generator
{
    private readonly CompilationUnit _compilationUnit;
    private readonly LLVMModuleRef _module;

    // Cache of LLVM types for aggregate Nub types, keyed by the strings built in the
    // GetOrCreate* helpers so each Nub type maps to a single LLVM type.
    private readonly Dictionary<string, LLVMTypeRef> _namedTypes = new();

    // Stack slots (allocas) for the named values of the function currently being emitted.
    private readonly Dictionary<string, LLVMValueRef> _namedValues = [];

    private LLVMBuilderRef _builder;
    private LLVMValueRef _currentFunction;

    // Set as soon as Generate starts emitting; the module cannot be re-emitted afterwards.
    private bool _generationStarted;

    // IR text from the first successful Generate call; null until then.
    private string? _generatedIr;

    /// <summary>Creates a generator for <paramref name="compilationUnit"/>.</summary>
    /// <param name="compilationUnit">The unit whose functions will be lowered.</param>
    /// <param name="moduleName">Name given to the LLVM module (defaults to <c>"test"</c>).</param>
    public Generator(CompilationUnit compilationUnit, string moduleName = "test")
    {
        _compilationUnit = compilationUnit;

        LLVM.LinkInMCJIT();
        LLVM.InitializeX86TargetInfo();
        LLVM.InitializeX86Target();
        LLVM.InitializeX86TargetMC();
        LLVM.InitializeX86AsmPrinter();

        _module = LLVMModuleRef.CreateWithName(moduleName);
    }

    /// <summary>
    /// Emits every function of the compilation unit, verifies the module and returns its IR text.
    /// Subsequent calls after a successful run return the same text without re-emitting.
    /// </summary>
    /// <returns>The textual LLVM IR of the module.</returns>
    /// <exception cref="InvalidOperationException">
    /// The module fails LLVM verification, or a previous call failed and left the module incomplete.
    /// </exception>
    public string Generate()
    {
        if (_generatedIr is not null)
        {
            return _generatedIr;
        }

        if (_generationStarted)
        {
            // Re-emitting into a partially populated module would duplicate (and rename) functions,
            // and returning it unverified would hide the original failure.
            throw new InvalidOperationException("A previous call to Generate failed; the LLVM module is incomplete and cannot be reused.");
        }

        _generationStarted = true;
        _builder = _module.Context.CreateBuilder();
        try
        {
            foreach (var funcNode in _compilationUnit.Functions)
            {
                EmitFunction(funcNode);
            }
        }
        finally
        {
            // The builder is only needed during emission; release the native handle even on failure.
            _builder.Dispose();
        }

        if (!_module.TryVerify(LLVMVerifierFailureAction.LLVMPrintMessageAction, out var error))
        {
            throw new InvalidOperationException($"Invalid LLVM module: {error}");
        }

        _generatedIr = _module.PrintToString();
        return _generatedIr;
    }

    /// <summary>Maps a Nub type to the LLVM type used for values of that type.</summary>
    /// <exception cref="ArgumentOutOfRangeException">Unsupported type or float width.</exception>
    private LLVMTypeRef MapType(NubType nubType) => nubType switch
    {
        NubArrayType nubArrayType => LLVMTypeRef.CreatePointer(MapType(nubArrayType.ElementType), 0),
        NubBoolType => LLVMTypeRef.Int1,
        NubConstArrayType nubConstArrayType => LLVMTypeRef.CreateArray(MapType(nubConstArrayType.ElementType), (uint)nubConstArrayType.Size),
        NubCStringType => LLVMTypeRef.CreatePointer(LLVMTypeRef.Int8, 0),
        NubFloatType nubFloatType => nubFloatType.Width switch
        {
            32 => LLVMTypeRef.Float,
            64 => LLVMTypeRef.Double,
            _ => throw new ArgumentOutOfRangeException(nameof(nubFloatType.Width))
        },
        // Functions are not first-class values in LLVM (they cannot be parameters, allocas or
        // struct fields), so a value of function type is represented as a pointer to the function.
        NubFuncType nubFuncType => LLVMTypeRef.CreatePointer(MapFunctionType(nubFuncType), 0),
        NubIntType nubIntType => LLVMTypeRef.CreateInt((uint)nubIntType.Width),
        NubPointerType nubPointerType => LLVMTypeRef.CreatePointer(MapType(nubPointerType.BaseType), 0),
        NubSliceType nubSliceType => GetOrCreateSliceType(nubSliceType),
        NubStringType => GetOrCreateStringType(),
        NubStructType nubStructType => GetOrCreateStructType(nubStructType),
        NubVoidType => LLVMTypeRef.Void,
        _ => throw new ArgumentOutOfRangeException(nameof(nubType))
    };

    /// <summary>
    /// Returns the raw LLVM function type (not a pointer) for a Nub function type, as needed
    /// when building calls through a function pointer.
    /// </summary>
    private LLVMTypeRef MapFunctionType(NubFuncType funcType) =>
        LLVMTypeRef.CreateFunction(MapType(funcType.ReturnType), funcType.Parameters.Select(MapType).ToArray());

    /// <summary>Returns the cached named LLVM struct for a Nub struct, creating it on first use.</summary>
    private LLVMTypeRef GetOrCreateStructType(NubStructType structType)
    {
        var key = $"struct.{structType.Module}.{structType.Name}";
        if (_namedTypes.TryGetValue(key, out var llvmType))
        {
            return llvmType;
        }

        llvmType = _module.Context.CreateNamedStruct(structType.Name);
        // Register before mapping fields so a field referring back to this struct
        // (e.g. through a pointer) resolves to the same type instead of recursing.
        _namedTypes[key] = llvmType;

        var fields = structType.Fields.Select(f => MapType(f.Type)).ToArray();
        llvmType.StructSetBody(fields, false);
        return llvmType;
    }

    /// <summary>Returns the cached literal struct <c>{ element*, i64 }</c> used for a slice type.</summary>
    private LLVMTypeRef GetOrCreateSliceType(NubSliceType sliceType)
    {
        // NOTE(review): the key relies on NubType.ToString() distinguishing element types — confirm.
        var key = $"slice.{sliceType.ElementType}";
        if (_namedTypes.TryGetValue(key, out var llvmType))
        {
            return llvmType;
        }

        llvmType = LLVMTypeRef.CreateStruct([LLVMTypeRef.CreatePointer(MapType(sliceType.ElementType), 0), LLVMTypeRef.Int64], false);
        _namedTypes[key] = llvmType;
        return llvmType;
    }

    /// <summary>Returns the cached named struct <c>%string = { i8*, i64 }</c>.</summary>
    private LLVMTypeRef GetOrCreateStringType()
    {
        const string key = "string";
        if (_namedTypes.TryGetValue(key, out var llvmType))
        {
            return llvmType;
        }

        llvmType = _module.Context.CreateNamedStruct("string");
        llvmType.StructSetBody([LLVMTypeRef.CreatePointer(LLVMTypeRef.Int8, 0), LLVMTypeRef.Int64], false);
        _namedTypes[key] = llvmType;
        return llvmType;
    }

    /// <summary>
    /// Adds <paramref name="funcNode"/> to the module. Bodiless functions become declarations;
    /// otherwise each parameter is spilled to an alloca in the entry block and the body is emitted.
    /// </summary>
    private void EmitFunction(FuncNode funcNode)
    {
        var prototype = funcNode.Prototype;
        var parameterTypes = prototype.Parameters.Select(x => MapType(x.Type)).ToArray();
        var functionType = LLVMTypeRef.CreateFunction(MapType(prototype.ReturnType), parameterTypes);
        _currentFunction = _module.AddFunction(funcNode.Name, functionType);

        if (funcNode.Body is null)
        {
            return;
        }

        _namedValues.Clear();
        var entryBlock = _currentFunction.AppendBasicBlock("entry");
        _builder.PositionAtEnd(entryBlock);

        for (var i = 0; i < prototype.Parameters.Count; i++)
        {
            var param = prototype.Parameters[i];
            var llvmParam = _currentFunction.GetParam((uint)i);
            llvmParam.Name = param.Name;

            // Spill each parameter to a stack slot so it can be addressed/assigned like a local.
            var alloca = _builder.BuildAlloca(MapType(param.Type), param.Name);
            _builder.BuildStore(llvmParam, alloca);
            _namedValues[param.Name] = alloca;
        }

        EmitBlock(funcNode.Body);

        // Every basic block must end in a terminator; a void function may fall off its end
        // without an explicit return, so close the final block with `ret void`.
        if (prototype.ReturnType is NubVoidType && _builder.InsertBlock.Terminator.Handle == IntPtr.Zero)
        {
            _builder.BuildRetVoid();
        }
    }

    /// <summary>Emits the statements of a block in order.</summary>
    private void EmitBlock(BlockNode blockNode)
    {
        foreach (var statement in blockNode.Statements)
        {
            EmitStatement(statement);
        }
    }

    /// <summary>Dispatches a statement to its emitter.</summary>
    /// <exception cref="ArgumentOutOfRangeException">Unknown statement node type.</exception>
    private void EmitStatement(StatementNode statement)
    {
        switch (statement)
        {
            case AssignmentNode assignmentNode:
                EmitAssignment(assignmentNode);
                break;
            case BlockNode blockNode:
                EmitBlock(blockNode);
                break;
            case BreakNode breakNode:
                EmitBreak(breakNode);
                break;
            case ContinueNode continueNode:
                EmitContinue(continueNode);
                break;
            case DeferNode deferNode:
                EmitDefer(deferNode);
                break;
            case IfNode ifNode:
                EmitIf(ifNode);
                break;
            case ReturnNode returnNode:
                EmitReturn(returnNode);
                break;
            case StatementFuncCallNode statementFuncCallNode:
                EmitStatementFuncCall(statementFuncCallNode);
                break;
            case VariableDeclarationNode variableDeclarationNode:
                EmitVariableDeclaration(variableDeclarationNode);
                break;
            case WhileNode whileNode:
                EmitWhile(whileNode);
                break;
            default:
                throw new ArgumentOutOfRangeException(nameof(statement));
        }
    }

    // Statement emitters — not yet implemented.
    private void EmitAssignment(AssignmentNode assignmentNode) => throw new NotImplementedException();
    private void EmitBreak(BreakNode breakNode) => throw new NotImplementedException();
    private void EmitContinue(ContinueNode continueNode) => throw new NotImplementedException();
    private void EmitDefer(DeferNode deferNode) => throw new NotImplementedException();
    private void EmitIf(IfNode ifNode) => throw new NotImplementedException();
    private void EmitReturn(ReturnNode returnNode) => throw new NotImplementedException();
    private void EmitStatementFuncCall(StatementFuncCallNode statementFuncCallNode) => throw new NotImplementedException();
    private void EmitVariableDeclaration(VariableDeclarationNode variableDeclarationNode) => throw new NotImplementedException();
    private void EmitWhile(WhileNode whileNode) => throw new NotImplementedException();

    /// <summary>Dispatches an expression to its emitter and returns the resulting LLVM value.</summary>
    /// <exception cref="ArgumentOutOfRangeException">Unknown expression node type.</exception>
    private LLVMValueRef EmitExpression(ExpressionNode expressionNode) => expressionNode switch
    {
        ArrayIndexAccessNode arrayIndexAccessNode => EmitArrayIndexAccess(arrayIndexAccessNode),
        ArrayInitializerNode arrayInitializerNode => EmitArrayInitializer(arrayInitializerNode),
        BinaryExpressionNode binaryExpressionNode => EmitBinaryExpression(binaryExpressionNode),
        BoolLiteralNode boolLiteralNode => EmitBoolLiteral(boolLiteralNode),
        ConstArrayIndexAccessNode constArrayIndexAccessNode => EmitConstArrayIndexAccess(constArrayIndexAccessNode),
        ConstArrayInitializerNode constArrayInitializerNode => EmitConstArrayInitializer(constArrayInitializerNode),
        ConvertFloatNode convertFloatNode => EmitConvertFloat(convertFloatNode),
        ConvertIntNode convertIntNode => EmitConvertInt(convertIntNode),
        CStringLiteralNode cStringLiteralNode => EmitCStringLiteral(cStringLiteralNode),
        DereferenceNode dereferenceNode => EmitDereference(dereferenceNode),
        Float32LiteralNode float32LiteralNode => EmitFloat32Literal(float32LiteralNode),
        Float64LiteralNode float64LiteralNode => EmitFloat64Literal(float64LiteralNode),
        FloatToIntBuiltinNode floatToIntBuiltinNode => EmitFloatToIntBuiltin(floatToIntBuiltinNode),
        FuncCallNode funcCallNode => EmitFuncCall(funcCallNode),
        FuncIdentifierNode funcIdentifierNode => EmitFuncIdentifier(funcIdentifierNode),
        IntLiteralNode intLiteralNode => EmitIntLiteral(intLiteralNode),
        AddressOfNode addressOfNode => EmitAddressOf(addressOfNode),
        LValueIdentifierNode lValueIdentifierNode => EmitLValueIdentifier(lValueIdentifierNode),
        RValueIdentifierNode rValueIdentifierNode => EmitRValueIdentifier(rValueIdentifierNode),
        SizeBuiltinNode sizeBuiltinNode => EmitSizeBuiltin(sizeBuiltinNode),
        SliceIndexAccessNode sliceIndexAccessNode => EmitSliceIndexAccess(sliceIndexAccessNode),
        StringLiteralNode stringLiteralNode => EmitStringLiteral(stringLiteralNode),
        StructFieldAccessNode structFieldAccessNode => EmitStructFieldAccess(structFieldAccessNode),
        StructInitializerNode structInitializerNode => EmitStructInitializer(structInitializerNode),
        UIntLiteralNode uIntLiteralNode => EmitUIntLiteral(uIntLiteralNode),
        UnaryExpressionNode unaryExpressionNode => EmitUnaryExpression(unaryExpressionNode),
        _ => throw new ArgumentOutOfRangeException(nameof(expressionNode))
    };

    // Expression emitters — not yet implemented.
    private LLVMValueRef EmitArrayIndexAccess(ArrayIndexAccessNode arrayIndexAccessNode) => throw new NotImplementedException();
    private LLVMValueRef EmitArrayInitializer(ArrayInitializerNode arrayInitializerNode) => throw new NotImplementedException();
    private LLVMValueRef EmitBinaryExpression(BinaryExpressionNode binaryExpressionNode) => throw new NotImplementedException();
    private LLVMValueRef EmitBoolLiteral(BoolLiteralNode boolLiteralNode) => throw new NotImplementedException();
    private LLVMValueRef EmitConstArrayIndexAccess(ConstArrayIndexAccessNode constArrayIndexAccessNode) => throw new NotImplementedException();
    private LLVMValueRef EmitConstArrayInitializer(ConstArrayInitializerNode constArrayInitializerNode) => throw new NotImplementedException();
    private LLVMValueRef EmitConvertFloat(ConvertFloatNode convertFloatNode) => throw new NotImplementedException();
    private LLVMValueRef EmitConvertInt(ConvertIntNode convertIntNode) => throw new NotImplementedException();
    private LLVMValueRef EmitCStringLiteral(CStringLiteralNode cStringLiteralNode) => throw new NotImplementedException();
    private LLVMValueRef EmitDereference(DereferenceNode dereferenceNode) => throw new NotImplementedException();
    private LLVMValueRef EmitFloat32Literal(Float32LiteralNode float32LiteralNode) => throw new NotImplementedException();
    private LLVMValueRef EmitFloat64Literal(Float64LiteralNode float64LiteralNode) => throw new NotImplementedException();
    private LLVMValueRef EmitFloatToIntBuiltin(FloatToIntBuiltinNode floatToIntBuiltinNode) => throw new NotImplementedException();
    private LLVMValueRef EmitFuncCall(FuncCallNode funcCallNode) => throw new NotImplementedException();
    private LLVMValueRef EmitFuncIdentifier(FuncIdentifierNode funcIdentifierNode) => throw new NotImplementedException();
    private LLVMValueRef EmitIntLiteral(IntLiteralNode intLiteralNode) => throw new NotImplementedException();
    private LLVMValueRef EmitAddressOf(AddressOfNode addressOfNode) => throw new NotImplementedException();
    private LLVMValueRef EmitLValueIdentifier(LValueIdentifierNode lValueIdentifierNode) => throw new NotImplementedException();
    private LLVMValueRef EmitRValueIdentifier(RValueIdentifierNode rValueIdentifierNode) => throw new NotImplementedException();
    private LLVMValueRef EmitSizeBuiltin(SizeBuiltinNode sizeBuiltinNode) => throw new NotImplementedException();
    private LLVMValueRef EmitSliceIndexAccess(SliceIndexAccessNode sliceIndexAccessNode) => throw new NotImplementedException();
    private LLVMValueRef EmitStringLiteral(StringLiteralNode stringLiteralNode) => throw new NotImplementedException();
    private LLVMValueRef EmitStructFieldAccess(StructFieldAccessNode structFieldAccessNode) => throw new NotImplementedException();
    private LLVMValueRef EmitStructInitializer(StructInitializerNode structInitializerNode) => throw new NotImplementedException();
    private LLVMValueRef EmitUIntLiteral(UIntLiteralNode uIntLiteralNode) => throw new NotImplementedException();
    private LLVMValueRef EmitUnaryExpression(UnaryExpressionNode unaryExpressionNode) => throw new NotImplementedException();
}