diff --git a/xla/mlir/runtime/transforms/BUILD b/xla/mlir/runtime/transforms/BUILD index 680d49333ed3a..6969f7bdbf853 100644 --- a/xla/mlir/runtime/transforms/BUILD +++ b/xla/mlir/runtime/transforms/BUILD @@ -119,7 +119,6 @@ cc_library( ":compilation_pipeline_options", ":compiler", ":passes", - "//xla/mlir/backends/cpu/transforms:passes", "//xla/mlir/memref/transforms:passes", "//xla/mlir/runtime/ir:rt", "//xla/mlir_hlo:transforms_passes", diff --git a/xla/mlir/runtime/transforms/compilation_pipeline_cpu.cc b/xla/mlir/runtime/transforms/compilation_pipeline_cpu.cc index c5e845c22516d..7515ebd9c51f5 100644 --- a/xla/mlir/runtime/transforms/compilation_pipeline_cpu.cc +++ b/xla/mlir/runtime/transforms/compilation_pipeline_cpu.cc @@ -55,7 +55,6 @@ limitations under the License. #include "mlir/Target/LLVMIR/Dialect/X86Vector/X86VectorToLLVMIRTranslation.h" // from @llvm-project #endif #include "mlir/Transforms/Passes.h" // from @llvm-project -#include "xla/mlir/backends/cpu/transforms/passes.h" #include "xla/mlir/memref/transforms/passes.h" #include "xla/mlir/runtime/ir/rt_dialect.h" #include "xla/mlir/runtime/transforms/compilation_pipeline_options.h" @@ -106,7 +105,6 @@ static void CreateXlaCpuCompilationPipeline(mlir::OpPassManager& pm, // Convert entry function to the XLA entrypoint. pm.addPass(CreateExportRuntimeFunctionsPass()); - pm.addPass(cpu::createConvertXlaCpuToCpuRuntimePass()); pm.addPass(CreateConvertCustomCallsPass()); pm.addPass(CreateConvertAssertsPass()); diff --git a/xla/service/cpu/BUILD b/xla/service/cpu/BUILD index 2ce4fc0133431..81148df9e2714 100644 --- a/xla/service/cpu/BUILD +++ b/xla/service/cpu/BUILD @@ -223,7 +223,6 @@ cc_library( ":cpu_options", ":dot_op_emitter", ":executable_proto_cc", - ":hlo_xla_runtime_pipeline", ":ir_emission_utils", ":ir_emitter", ":onednn_matmul_rewriter", @@ -469,56 +468,6 @@ cc_library( deps = [":xla_framework_proto_cc"], ) -cc_library( - name = "hlo_xla_runtime_pipeline", - srcs = ["hlo_xla_runtime_pipeline.cc"], - hdrs = ["hlo_xla_runtime_pipeline.h"], - local_defines = select({ - ":experimental_mlir_gpu_enabled": ["EXPERIMENTAL_MLIR_GPU=1"], - "//conditions:default": [], - }), - deps = [ - "//xla:status", - "//xla/mlir/backends/cpu/transforms:passes", - "//xla/mlir/runtime/transforms:compiler", - "//xla/mlir_hlo:all_passes", - "//xla/mlir_hlo:mhlo_passes", - "//xla/mlir_hlo:transforms_passes", - "//xla/runtime:compiler", - "@llvm-project//mlir:ArithTransforms", - "@llvm-project//mlir:BufferizationToMemRef", - "@llvm-project//mlir:BufferizationTransforms", - "@llvm-project//mlir:ComplexToStandard", - "@llvm-project//mlir:FuncDialect", - "@llvm-project//mlir:FuncTransforms", - "@llvm-project//mlir:LinalgTransforms", - "@llvm-project//mlir:MemRefTransforms", - "@llvm-project//mlir:Pass", - "@llvm-project//mlir:ReconcileUnrealizedCasts", - "@llvm-project//mlir:SCFToControlFlow", - "@llvm-project//mlir:SCFTransforms", - "@llvm-project//mlir:ShapeToStandard", - "@llvm-project//mlir:ShapeTransforms", - "@llvm-project//mlir:SparseTensorTransforms", - "@llvm-project//mlir:Support", - "@llvm-project//mlir:TensorToLinalg", - "@llvm-project//mlir:TensorTransforms", - "@llvm-project//mlir:Transforms", - "@llvm-project//mlir:VectorToLLVM", - "@llvm-project//mlir:VectorToSCF", - "@llvm-project//mlir:VectorTransforms", - "@tsl//tsl/platform:errors", - "@tsl//tsl/platform:logging", - ] + select({ - ":experimental_mlir_gpu_enabled": [ - "@llvm-project//mlir:GPUDialect", - "@llvm-project//mlir:GPUToNVVMTransforms", - ], - "//conditions:default": [], - }), - alwayslink = 1, # has pipeline registration -) - cc_library( name = "simple_orc_jit", srcs = [ diff --git a/xla/service/cpu/cpu_compiler.cc b/xla/service/cpu/cpu_compiler.cc index ed21241ea4262..2ef92ea90a2b3 100644 --- a/xla/service/cpu/cpu_compiler.cc +++ b/xla/service/cpu/cpu_compiler.cc @@ -66,29 +66,10 @@ limitations under the License. #ifdef TF_LLVM_X86_AVAILABLE #include "llvm/TargetParser/X86TargetParser.h" #endif -#include "mlir/Conversion/AffineToStandard/AffineToStandard.h" // from @llvm-project -#include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h" // from @llvm-project -#include "mlir/Dialect/Affine/IR/AffineOps.h" // from @llvm-project -#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project -#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h" // from @llvm-project -#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/Dialect/LLVMIR/LLVMDialect.h" // from @llvm-project -#include "mlir/Dialect/Linalg/IR/Linalg.h" // from @llvm-project -#include "mlir/Dialect/MemRef/Transforms/AllocationOpInterfaceImpl.h" // from @llvm-project -#include "mlir/Dialect/MemRef/Transforms/Passes.h" // from @llvm-project -#include "mlir/Dialect/SCF/IR/SCF.h" // from @llvm-project -#include "mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project #include "mlir/Dialect/Vector/IR/VectorOps.h" // from @llvm-project -#include "mlir/IR/Attributes.h" // from @llvm-project -#include "mlir/IR/Builders.h" // from @llvm-project -#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project -#include "mlir/IR/BuiltinTypes.h" // from @llvm-project -#include "mlir/IR/OperationSupport.h" // from @llvm-project -#include "mlir/IR/OwningOpRef.h" // from @llvm-project #include "mlir/Pass/PassManager.h" // from @llvm-project -#include "mlir/Pass/PassRegistry.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project -#include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h" // from @llvm-project #include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" // from @llvm-project #include "mlir/Target/LLVMIR/Export.h" // from @llvm-project @@ -106,18 +87,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_schedule.h" #include "xla/layout_util.h" #include "xla/map_util.h" -#include "xla/mlir/framework/ir/xla_framework.h" -#include "xla/mlir/runtime/ir/rt_dialect.h" -#include "xla/mlir/runtime/transforms/calling_convention.h" -#include "xla/mlir/runtime/transforms/compilation_pipeline_cpu.h" -#include "xla/mlir/runtime/transforms/compiler.h" -#include "xla/mlir/runtime/transforms/jit_compiler.h" -#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" -#include "xla/mlir_hlo/mhlo/transforms/passes.h" #include "xla/mlir_hlo/transforms/passes.h" -#include "xla/runtime/custom_call_registry.h" -#include "xla/runtime/executable.h" -#include "xla/runtime/jit_executable.h" #include "xla/service/algebraic_simplifier.h" #include "xla/service/all_reduce_promotion.h" #include "xla/service/all_to_all_decomposer.h" @@ -145,15 +115,8 @@ limitations under the License. #include "xla/service/cpu/cpu_layout_assignment.h" #include "xla/service/cpu/cpu_options.h" #include "xla/service/cpu/dot_op_emitter.h" -#include "xla/service/cpu/hlo_xla_runtime_pipeline.h" #include "xla/service/cpu/ir_emitter.h" #include "xla/service/cpu/parallel_task_assignment.h" -#include "xla/service/cpu/runtime/collectives.h" -#include "xla/service/cpu/runtime/convolution_call.h" -#include "xla/service/cpu/runtime/custom_call.h" -#include "xla/service/cpu/runtime/fft_call.h" -#include "xla/service/cpu/runtime/rng_call.h" -#include "xla/service/cpu/runtime/xfeed.h" #include "xla/service/cpu/simple_orc_jit.h" #include "xla/service/cpu/target_machine_features.h" #include "xla/service/cpu_gpu_shape_verifier.h" @@ -244,58 +207,6 @@ limitations under the License. #include "xla/service/cpu/onednn_ops_rewriter.h" #endif -namespace { - -// We need to explicitly load all the dialects we will involved in emitting the -// IR. This is only needed because of how MLIR is bolted into XLA and does not -// make use of the MLIR infrastructure (like using a proper pass pipeline). -// Hopefully this will all go away at some point in favor of a better -// integration. -void LoadMLIRDialects(mlir::MLIRContext& context) { - context.loadDialect(); - mlir::registerBuiltinDialectTranslation(context); - mlir::registerLLVMDialectTranslation(context); - - mlir::DialectRegistry registry; - mlir::memref::registerAllocationOpInterfaceExternalModels(registry); - context.appendDialectRegistry(registry); -} - -xla::cpu::HloXlaRuntimePipelineOptions GetHloXlaRuntimePipelineOptions( - llvm::Triple target_triple, llvm::StringRef cpu_name) { - xla::cpu::HloXlaRuntimePipelineOptions options; - options.enable_tiling_and_fusion = false; - if (xla::GetDebugOptionsFromFlags().xla_cpu_enable_custom_matmul_tiling()) { - options.matmul_tile_sizes = { - xla::GetDebugOptionsFromFlags().xla_cpu_matmul_tiling_m_dim(), - xla::GetDebugOptionsFromFlags().xla_cpu_matmul_tiling_n_dim(), - xla::GetDebugOptionsFromFlags().xla_cpu_matmul_tiling_k_dim()}; - } -#ifdef TF_LLVM_X86_AVAILABLE - options.enable_avx2 = [&] { - // Derive whether this is an x86 CPU with AVX2 enabled. - if (!target_triple.isX86()) return false; - llvm::SmallVector cpu_features; - llvm::X86::getFeaturesForCPU(cpu_name, cpu_features); - return llvm::is_contained(cpu_features, "avx2"); - }(); -#else - options.enable_avx2 = false; -#endif - options.cpu_name = cpu_name; - if (xla::GetDebugOptionsFromFlags().xla_cpu_enable_mlir_fusion_outlining()) { - options.enable_fusion_outlining = true; - } - return options; -} - -} // namespace - namespace xla { namespace { @@ -359,102 +270,6 @@ se::Platform::Id CpuAotCompilationOptions::PlatformId() const { return se::host::kHostPlatformId; } -namespace { - -namespace runtime = ::xla::runtime; - -class FlattenTuplesAndBufferizeTypeConverter : public mlir::TypeConverter { - public: - FlattenTuplesAndBufferizeTypeConverter() { - addConversion( - [](mlir::Type type, mlir::SmallVectorImpl& converted) - -> mlir::LogicalResult { - mlir::bufferization::BufferizeTypeConverter bufferize; - auto tuple_type = mlir::dyn_cast(type); - if (!tuple_type) { - converted.push_back(bufferize.convertType(type)); - return mlir::success(); - } - // TODO(b/249078472): update this expansion to support nested tuples. - converted.append(llvm::to_vector(llvm::map_range( - tuple_type.getTypes(), - [&](mlir::Type t) { return bufferize.convertType(t); }))); - return mlir::success(); - }); - } -}; - -runtime::JitExecutable::Options GetXlaRuntimeJitExecutableOptions( - const HloModule& module, mlir::DialectRegistry* custom_registry) { - runtime::CpuPipelineOptions copts; - runtime::JitExecutable::Options opts; - copts.xla_cpu_sparse_cuda_threads = - GetDebugOptionsFromFlags().xla_cpu_sparse_cuda_threads(); - std::optional maybeOverriddenPipeline = - options::ExperimentalOverriddenPipeline(module.config()); - opts.specialization = runtime::JitExecutable::Specialization::kDisabled; - opts.compiler.register_dialects = - [custom_registry](xla::runtime::DialectRegistry& dialects) { - dialects->insert(); - runtime::RegisterDefaultXlaCpuRuntimeDialects(dialects); - RegisterHloXlaRuntimePipelineDialects(*dialects); - if (custom_registry) { - custom_registry->appendTo(*dialects); - } - }; - opts.compiler.symbols_binding = runtime::ToSymbolsBinding( - [](runtime::DirectCustomCallRegistry& registry) { - PopulateXlaCpuCollectivesCall(registry); - PopulateXlaCpuConvolutionCall(registry); - PopulateXlaCpuCustomCall(registry); - PopulateXlaCpuFftCall(registry); - PopulateXlaCpuRngCall(registry); - PopulateXlaXfeedCall(registry); - }); - opts.compiler.create_compilation_pipeline = - [copts, maybeOverriddenPipeline = std::move(maybeOverriddenPipeline)]( - xla::runtime::PassManager& passes) { - if (maybeOverriddenPipeline.has_value()) { - std::string error_message; - llvm::raw_string_ostream error_stream(error_message); - mlir::LogicalResult result = mlir::parsePassPipeline( - maybeOverriddenPipeline.value(), *passes, error_stream); - if (mlir::failed(result)) { - LOG(ERROR) - << "Failed to parse experimental CPU compilation pipeline: " - << error_stream.str(); - return absl::InternalError( - "Failed to parse experimental CPU compilation pipeline."); - } - LOG(INFO) << "Experimental CPU compilation pipeline: " - << maybeOverriddenPipeline.value(); - return absl::OkStatus(); - } - - HloXlaRuntimePipelineOptions options = GetHloXlaRuntimePipelineOptions( - llvm::Triple(llvm::sys::getProcessTriple()), - llvm::sys::getHostCPUName()); - options.xla_cpu_sparse_cuda_threads = - GetDebugOptionsFromFlags().xla_cpu_sparse_cuda_threads(); - - Status status = CreateHloXlaRuntimePipeline(passes, options); - if (!status.ok()) { - LOG(ERROR) << "HLO-XLA Runtime pipeline failed with: " - << status.message(); - return status; - } - runtime::CreateDefaultXlaCpuRuntimeCompilationPipeline(passes, copts); - return absl::OkStatus(); - }; - opts.compiler.calling_convention = runtime::ResultsToOutsCallingConvention( - FlattenTuplesAndBufferizeTypeConverter()); - opts.compiler.embed_ir_in_executable = - module.config().debug_options().xla_embed_ir_in_executable(); - return opts; -} - -} // namespace - CpuAotCompilationResult::CpuAotCompilationResult( ObjectFileData object_file_data, std::vector buffer_infos, int64_t result_buffer_index, std::unique_ptr module, @@ -1138,129 +953,6 @@ void InitializeLLVMCommandLineOptions(const HloModuleConfig& config) { config.debug_options().xla_backend_extra_options()); } -Status LowerMLIRModule(HloModule* module, mlir::ModuleOp mlir_module, - mlir::MLIRContext& mlir_context, - const llvm::TargetMachine& target) { - LoadMLIRDialects(mlir_context); - mlir::PassManager pm(&mlir_context); - if (VLOG_IS_ON(5)) { - mlir_context.disableMultithreading(); - // Do not print large constants. - mlir::OpPrintingFlags printing_flags; - printing_flags.elideLargeElementsAttrs(32); - pm.enableIRPrinting( - [](mlir::Pass* pass, mlir::Operation* op) { return true; }, - [](mlir::Pass* pass, mlir::Operation* op) { return true; }, - /*printModuleScope=*/true, /*printAfterOnlyOnChange=*/true, - /*printAfterOnlyOnFailure=*/false, llvm::errs(), printing_flags); - } - - xla::runtime::PassManager xla_pm(&pm); - HloXlaRuntimePipelineOptions options = GetHloXlaRuntimePipelineOptions( - target.getTargetTriple(), target.getTargetCPU()); - options.sparse_bufferization = false; - TF_RETURN_IF_ERROR(CreateHloXlaRuntimePipeline(xla_pm, options)); - - runtime::CpuPipelineOptions cpu_pipeline_opts; - CreateDefaultXlaCpuRuntimeCompilationPipeline(xla_pm, cpu_pipeline_opts); - - if (pm.run(mlir_module).failed()) { - mlir_module->dump(); - return tsl::errors::Internal("Failed to compile through MLIR pipeline"); - } - return absl::OkStatus(); -} - -absl::StatusOr> createMLIRModule( - HloModule* module, mlir::MLIRContext& mlir_context, - BufferAssignment* assignment, - XlaFrameworkMapping* export_mapping = nullptr) { - LoadMLIRDialects(mlir_context); - mlir::OpBuilder builder(&mlir_context); - auto mlir_module = builder.create(builder.getUnknownLoc()); - TF_RETURN_IF_ERROR(ConvertHloToMlirHlo(mlir_module, module)); - - // Flatten tuples before we set up the input mapping. The flattening pass - // doesn't preserve attributes so we'd lose some in the process. - mlir::PassManager pm(mlir_module.getOperation()->getName(), - mlir::PassManager::Nesting::Implicit); - pm.addPass(mlir::mhlo::createExpandHloTuplesPass("main")); - if (failed(pm.run(mlir_module.getOperation()))) { - return tsl::errors::Internal("Failed to flatten tuples"); - } - - // Add buffer mappings. The first attribute is the index of the slice, the - // second is a boolean attribute on whether the allocation is writeable. - llvm::SmallVector> - operand_mapping; - for (auto i : module->entry_computation()->parameter_instructions()) { - ShapeUtil::ForEachSubshape( - i->shape(), [&](const Shape& subshape, const ShapeIndex& index) { - if (subshape.IsTuple()) { - return; - } - auto slice = assignment->GetUniqueSlice(i, index); - operand_mapping.emplace_back( - builder.getI32IntegerAttr(static_cast(slice->index())), - builder.getBoolAttr(!slice->allocation()->is_readonly())); - }); - } - - auto root_instr = module->entry_computation()->root_instruction(); - auto output_allocation = assignment->GetUniqueTopLevelOutputSlice(); - - // Gather mappings to each element in the tuple if necessary - llvm::SmallVector result_inner_mapping; - if (output_allocation->allocation()->is_tuple()) { - ShapeUtil::ForEachSubshape( - root_instr->shape(), - [&](const Shape& subshape, const ShapeIndex& index) { - if (subshape.IsTuple()) { - return; - } - int64_t result_index = - assignment->GetUniqueSlice(root_instr, index)->index(); - result_inner_mapping.push_back( - builder.getI64IntegerAttr(result_index)); - if (export_mapping != nullptr) { - export_mapping->flattened_outputs.push_back(result_index); - } - }); - } - - int output_index = static_cast(output_allocation->index()); - auto result_mapping = builder.getI32IntegerAttr(output_index); - mlir_module->walk([&](mlir::func::FuncOp f) { - if (f.getSymName() == "main") { - for (const auto& p : llvm::enumerate(operand_mapping)) { - f.setArgAttr(p.index(), "xla_framework.input_mapping", p.value().first); - if (export_mapping != nullptr) { - auto index_attr = mlir::dyn_cast(p.value().first); - if (index_attr) { - export_mapping->inputs.push_back(index_attr.getInt()); - } - } - // Mark argument as (non-)writeable for bufferization. This ensures that - // entry parameters are not overwritten. - f.setArgAttr(p.index(), "bufferization.writable", p.value().second); - } - f->setAttr("xla_framework.result_mapping", result_mapping); - if (export_mapping != nullptr) { - export_mapping->result = output_index; - } - } - - if (output_allocation->allocation()->is_tuple()) { - f->setAttr("xla_framework.result_inner_mapping", - mlir::ArrayAttr::get(f.getContext(), result_inner_mapping)); - if (export_mapping != nullptr) { - export_mapping->output_is_tuple = true; - } - } - }); - return {mlir_module}; -} - struct ComputationToEmit { HloComputation* computation; @@ -1335,7 +1027,6 @@ CpuCompiler::CompileLegacyCpuExecutable(std::unique_ptr module) { // Compile must be thread-safe so create a new LLVM context for the module. mlir::MLIRContext mlir_context; - LoadMLIRDialects(mlir_context); auto llvm_context = std::make_unique(); auto llvm_module = std::make_unique("__compute_module", *llvm_context); @@ -1484,96 +1175,6 @@ CpuCompiler::CompileLegacyCpuExecutable(std::unique_ptr module) { return cpu_executable; } -namespace { - -absl::StatusOr> -GetXlaRuntimeCpuExecutable(const HloModule& hlo_module, - mlir::ModuleOp mlir_module, - absl::string_view entry_point, - const XlaFrameworkMapping& xla_framework_mapping, - mlir::DialectRegistry* registry) { - runtime::JitExecutable::Options opts = - GetXlaRuntimeJitExecutableOptions(hlo_module, registry); - std::string serialized_mlir = llvm_ir::DumpToString(mlir_module); - - absl::StatusOr jit_executable = - runtime::JitExecutable::Instantiate(serialized_mlir, entry_point, opts); - if (!jit_executable.ok()) { - return Internal("Failed to compile XLA Runtime program: %s", - jit_executable.status().message()); - } - - return std::make_unique( - std::make_unique(std::move(*jit_executable)), - xla_framework_mapping); -} -} // namespace - -absl::StatusOr> -CpuCompiler::CompileXlaRuntimeCpuExecutable( - std::unique_ptr hlo_module, mlir::DialectRegistry* registry) { - // Select an order for emitting the HLO instructions for each - // computation. Using this sequence enables tighter buffer liveness analysis - // and reduced memory usage (as compared to using DependencyHloOrdering). - TF_ASSIGN_OR_RETURN( - HloSchedule schedule, - ScheduleModule( - hlo_module.get(), BufferSizeBytesFunction(), - ComputationSchedulerToModuleScheduler(DFSMemoryScheduler))); - - // Run buffer allocation on the HLO graph. - TF_ASSIGN_OR_RETURN( - std::unique_ptr assignment, - BufferAssigner::Run(hlo_module.get(), - std::make_unique(schedule), - BufferSizeBytesFunction(), memory_alignment, - /*allocate_buffers_for_constants=*/true)); - VLOG(1) << "Buffer Assignment Stats for " << hlo_module->name() << "\n" - << assignment->GetStats().ToString(); - DumpHloModuleIfEnabled(*hlo_module, *assignment, "cpu_after_optimizations"); - - // TODO(ecg): these are just being collected here to be passed to - // CpuExecutable's constructor; we should actually do something with them. - absl::flat_hash_map - instruction_to_profile_idx; - absl::flat_hash_map - computation_to_profile_idx; - std::unique_ptr hlo_profile_index_map; - std::unique_ptr hlo_profile_printer_data; - if (hlo_module->config().hlo_profiling_enabled()) { - TF_RETURN_IF_ERROR(CreateHloProfilingArtifacts( - *hlo_module, &instruction_to_profile_idx, &computation_to_profile_idx, - &hlo_profile_index_map, &hlo_profile_printer_data)); - } - - mlir::MLIRContext mlir_context; - if (registry) { - mlir_context.appendDialectRegistry(*registry); - } - XlaFrameworkMapping xla_framework_mapping; - TF_ASSIGN_OR_RETURN( - auto mlir_module, - createMLIRModule(hlo_module.get(), mlir_context, assignment.get(), - &xla_framework_mapping)); - - TF_ASSIGN_OR_RETURN( - auto xla_runtime_executable, - GetXlaRuntimeCpuExecutable(*hlo_module, *mlir_module, "main", - xla_framework_mapping, registry)); - - if (DumpingEnabledForHloModule(*hlo_module)) { - TF_ASSIGN_OR_RETURN(std::string_view obj_file, - xla_runtime_executable->GetObjFile()); - DumpToFileInDir(*hlo_module, /*file_prefix=*/"", /*file_suffix=*/"o", - obj_file); - } - - return CpuExecutable::Create( - std::move(hlo_module), std::move(hlo_profile_printer_data), - std::move(hlo_profile_index_map), std::move(assignment), - std::move(xla_runtime_executable)); -} - absl::StatusOr> CpuCompiler::RunBackend( std::unique_ptr module, [[maybe_unused]] se::StreamExecutor* stream_exec, @@ -1688,9 +1289,7 @@ CpuCompiler::CompileAheadOfTime(std::unique_ptr module_group, // Compile must be thread-safe so create a new LLVM context for the module. mlir::MLIRContext mlir_context; - LoadMLIRDialects(mlir_context); llvm::LLVMContext llvm_context; - std::unique_ptr llvm_module; std::vector> results; for (size_t i = 0; i < modules.size(); ++i) { @@ -1741,79 +1340,52 @@ CpuCompiler::CompileAheadOfTime(std::unique_ptr module_group, CreateBufferInfosFromBufferAssignment(*module, *assignment); HloComputation* computation = module->entry_computation(); - if (options.use_mlir_hlo_lowering()) { - TF_ASSIGN_OR_RETURN( - auto mlir_module, - createMLIRModule(module, mlir_context, assignment.get())); - TF_RETURN_IF_ERROR( - xla::runtime::ExportMainWithOrdinal0(*mlir_module, mlir_context)); - TF_RETURN_IF_ERROR(LowerMLIRModule(module, *mlir_module, mlir_context, - *target_machine)); - - llvm::cast(mlir_module->lookupSymbol("main")) - .setName(options.entry_point_name()); - - llvm_module = mlir::translateModuleToLLVMIR(*mlir_module, llvm_context); - if (!llvm_module) { - return Internal("Failed to translate module to LLVM IR"); - } - // Set missing information - llvm_module->setDataLayout(target_machine->createDataLayout()); - llvm_module->setTargetTriple(triple.getTriple()); - if (pic_level != llvm::PICLevel::NotPIC) { - llvm_module->setPICLevel(pic_level); - } - if (pie_level != llvm::PIELevel::Default) { - llvm_module->setPIELevel(pie_level); - } - } else { - // Set required information before emitting IR - llvm_module = - std::make_unique("__compute_module", llvm_context); - llvm_module->setDataLayout(target_machine->createDataLayout()); - llvm_module->setTargetTriple(triple.getTriple()); - if (pic_level != llvm::PICLevel::NotPIC) { - llvm_module->setPICLevel(pic_level); - } - if (pie_level != llvm::PIELevel::Default) { - llvm_module->setPIELevel(pie_level); - } - IrEmitter ir_emitter( - &mlir_context, *module, *assignment, llvm_module.get(), - std::move(instruction_to_profile_idx), - std::move(computation_to_profile_idx), - ModuleComputationsTransitivelyContainCustomCall(*module), - &target_machine_features, - // TODO(b/66051036): Run full msan for AOT. - /*emit_code_for_msan=*/false); - - TF_RETURN_IF_ERROR(ir_emitter.EmitConstantGlobals()); - - for (ComputationToEmit subcomputation : - SubcomputationEmissionOrder(computation)) { - if (subcomputation.computation->IsFusionComputation()) { - continue; - } - TF_RETURN_IF_ERROR( - ir_emitter - .EmitComputation(subcomputation.computation, - subcomputation.computation->name(), - /*is_top_level_computation=*/false, - schedule.sequence(subcomputation.computation) - .instructions(), - subcomputation.allow_reassociation) - .status()); + // Set required information before emitting IR + auto llvm_module = + std::make_unique("__compute_module", llvm_context); + llvm_module->setDataLayout(target_machine->createDataLayout()); + llvm_module->setTargetTriple(triple.getTriple()); + if (pic_level != llvm::PICLevel::NotPIC) { + llvm_module->setPICLevel(pic_level); + } + if (pie_level != llvm::PIELevel::Default) { + llvm_module->setPIELevel(pie_level); + } + IrEmitter ir_emitter( + &mlir_context, *module, *assignment, llvm_module.get(), + std::move(instruction_to_profile_idx), + std::move(computation_to_profile_idx), + ModuleComputationsTransitivelyContainCustomCall(*module), + &target_machine_features, + // TODO(b/66051036): Run full msan for AOT. + /*emit_code_for_msan=*/false); + + TF_RETURN_IF_ERROR(ir_emitter.EmitConstantGlobals()); + + for (ComputationToEmit subcomputation : + SubcomputationEmissionOrder(computation)) { + if (subcomputation.computation->IsFusionComputation()) { + continue; } - const std::string& entry_point_name = options.entry_point_name(); - TF_ASSIGN_OR_RETURN(llvm::Function * entry_function, - ir_emitter.EmitComputation( - computation, entry_point_name, - /*is_top_level_computation=*/true, - schedule.sequence(computation).instructions(), - /*allow_reassociation=*/false)); - - CHECK(entry_function->getName() == entry_point_name); + TF_RETURN_IF_ERROR( + ir_emitter + .EmitComputation(subcomputation.computation, + subcomputation.computation->name(), + /*is_top_level_computation=*/false, + schedule.sequence(subcomputation.computation) + .instructions(), + subcomputation.allow_reassociation) + .status()); } + const std::string& entry_point_name = options.entry_point_name(); + TF_ASSIGN_OR_RETURN(llvm::Function * entry_function, + ir_emitter.EmitComputation( + computation, entry_point_name, + /*is_top_level_computation=*/true, + schedule.sequence(computation).instructions(), + /*allow_reassociation=*/false)); + + CHECK(entry_function->getName() == entry_point_name); ModuleHook pre_optimization_ir_hook; ModuleHook post_optimization_ir_hook; diff --git a/xla/service/cpu/hlo_xla_runtime_pipeline.cc b/xla/service/cpu/hlo_xla_runtime_pipeline.cc deleted file mode 100644 index c13f9c20bb7e5..0000000000000 --- a/xla/service/cpu/hlo_xla_runtime_pipeline.cc +++ /dev/null @@ -1,263 +0,0 @@ -/* Copyright 2022 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/service/cpu/hlo_xla_runtime_pipeline.h" - -#include - -#include "mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h" -#include "mlir/Conversion/BufferizationToMemRef/BufferizationToMemRef.h" // from @llvm-project -#include "mlir/Conversion/ComplexToStandard/ComplexToStandard.h" // from @llvm-project -#include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h" // from @llvm-project -#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h" // from @llvm-project -#include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h" // from @llvm-project -#include "mlir/Conversion/TensorToLinalg/TensorToLinalgPass.h" // from @llvm-project -#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" // from @llvm-project -#include "mlir/Conversion/VectorToSCF/VectorToSCF.h" // from @llvm-project -#include "mlir/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.h" // from @llvm-project -#include "mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h" // from @llvm-project -#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h" // from @llvm-project -#include "mlir/Dialect/Bufferization/Transforms/Passes.h" // from @llvm-project -#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/Dialect/Func/Transforms/Passes.h" // from @llvm-project -#include "mlir/Dialect/Linalg/Passes.h" // from @llvm-project -#include "mlir/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.h" // from @llvm-project -#include "mlir/Dialect/MemRef/Transforms/AllocationOpInterfaceImpl.h" // from @llvm-project -#include "mlir/Dialect/MemRef/Transforms/Passes.h" // from @llvm-project -#include "mlir/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.h" // from @llvm-project -#include "mlir/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.h" // from @llvm-project -#include "mlir/Dialect/Shape/Transforms/Passes.h" // from @llvm-project -#include "mlir/Dialect/SparseTensor/Transforms/BufferizableOpInterfaceImpl.h" // from @llvm-project -#include "mlir/Dialect/SparseTensor/Transforms/Passes.h" // from @llvm-project -#include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h" // from @llvm-project -#include "mlir/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.h" // from @llvm-project -#include "mlir/Pass/PassManager.h" // from @llvm-project -#include "mlir/Support/LLVM.h" // from @llvm-project -#include "mlir/Support/LogicalResult.h" // from @llvm-project -#include "mlir/Transforms/Passes.h" // from @llvm-project -#include "xla/mlir/backends/cpu/transforms/passes.h" -#include "xla/mlir/runtime/transforms/compiler.h" -#include "xla/mlir_hlo/mhlo/interfaces/bufferizable_op_interface_impl.h" -#include "xla/mlir_hlo/mhlo/transforms/passes.h" -#include "xla/mlir_hlo/transforms/passes.h" -#include "xla/status.h" -#include "tsl/platform/errors.h" -#include "tsl/platform/logging.h" - -#ifdef EXPERIMENTAL_MLIR_GPU -#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h" // from @llvm-project -#include "mlir/Dialect/GPU/IR/GPUDialect.h" // from @llvm-project -#endif // EXPERIMENTAL_MLIR_GPU - -namespace xla { -namespace cpu { -namespace { - -using mlir::func::FuncOp; - -mlir::bufferization::OneShotBufferizationOptions GetBufferizationOptions( - bool new_deallocator) { - using mlir::bufferization::BufferizationOptions; - using mlir::bufferization::LayoutMapOption; - using mlir::bufferization::OneShotBufferizationOptions; - - OneShotBufferizationOptions options; - options.bufferizeFunctionBoundaries = true; - options.allowReturnAllocsFromLoops = true; - options.setFunctionBoundaryTypeConversion(LayoutMapOption::IdentityLayoutMap); - options.unknownTypeConverterFn = [](mlir::Value value, - mlir::Attribute memorySpace, - const BufferizationOptions& options) { - return mlir::bufferization::getMemRefTypeWithStaticIdentityLayout( - mlir::cast(value.getType()), memorySpace); - }; - return options; -} - -} // namespace - -// -------------------------------------------------------------------------- // -// Assemble a HLO XLA Runtime pipeline to lower from HLO to Linalg on buffers. -// -------------------------------------------------------------------------- // - -static Status CreateHloXlaPipeline( - mlir::OpPassManager& pm, const HloXlaRuntimePipelineOptions& options) { - // Resolve all shape constraints (e.g. broadcast constraints that can be - // proved statically and changed to const witness) early to allow more - // efficient broadcast operations moving. - // Move up broadcasting operations to allow for more fusion opportunities. - pm.addPass(mlir::createInlinerPass()); - pm.addPass(mlir::mhlo::createExpandHloTuplesPass("main")); - // TODO(b/233771980): Remove once custom_call doesn't use tuples. - pm.addNestedPass(mlir::mhlo::createFlattenTuplePass()); - pm.addPass(createXlaAbiLegalizationPass()); - pm.addNestedPass( - mlir::mhlo::createLegalizeGeneralDotPass()); - pm.addNestedPass( - mlir::mhlo::createBroadcastPropagationPass()); - pm.addPass(mlir::createCSEPass()); - pm.addPass(mlir::createCanonicalizerPass()); - - // Transform HLO operations to Linalg. - pm.addNestedPass( - mlir::mhlo::createLegalizeControlFlowPass()); - pm.addNestedPass(mlir::mhlo::createLegalizeDotGeneralToDotPass()); - pm.addPass(::mlir::mhlo::createLegalizeToArithmeticPass()); - pm.addNestedPass( - xla::cpu::createLegalizeLibraryOpsPass()); - pm.addNestedPass( - mlir::mhlo::createMhloExpandOpsSimplifierPass()); - pm.addNestedPass( - mlir::mhlo::createHloCanonicalizeScatterPass()); - pm.addNestedPass(mlir::mhlo::createHloCanonicalizeDotPass()); - pm.addNestedPass(mlir::mhlo::createGroupReductionDimensionsPass()); - pm.addNestedPass( - mlir::mhlo::createLegalizeHloToLinalgPass()); - - // Lower index cast on tensors to tensor.generate. - pm.addNestedPass(mlir::createLowerIndexCastPass()); - - pm.addPass(mlir::mhlo::createConvertToSignlessPass()); - - // Lower shape dialect to standard to enable linalg canonicalizations (e.g. - // use linalg inputs instead of outputs for memref.dim operations). - pm.addNestedPass(mlir::mhlo::createShapeSimplification()); - pm.addNestedPass(mlir::createShapeToShapeLowering()); - pm.addPass(mlir::createConvertShapeToStandardPass()); - pm.addNestedPass( - mlir::createConvertShapeConstraintsPass()); - - // Fuse Linalg on tensors operations. - pm.addPass(mlir::createCSEPass()); - pm.addPass(mlir::memref::createResolveShapedTypeResultDimsPass()); - pm.addPass(mlir::createCanonicalizerPass()); - pm.addNestedPass( - mlir::createLinalgElementwiseOpFusionPass()); - pm.addPass(mlir::createReconcileUnrealizedCastsPass()); - pm.addPass(mlir::createConvertTensorToLinalgPass()); - - // Detensorize SCF iter args. - pm.addNestedPass(mlir::createDetensorizeScfOpsPass()); - // mhlo ops on unit tensors generate trivial linalg.generics, which - // one-shot-bufferize generates unnecessary allocs for. The detensorize pass - // replaces these linalg.generics with scalar ops. - auto detensorize = mlir::createLinalgDetensorizePass(); - if (detensorize - ->initializeOptions( - "aggressive-mode=true", - [](const mlir::Twine&) { return mlir::failure(); }) - .failed()) { - return tsl::errors::Internal("Failed to set up detensorize pass."); - } - pm.addNestedPass(std::move(detensorize)); - pm.addPass(mlir::bufferization::createEmptyTensorEliminationPass()); - pm.addNestedPass( - mlir::bufferization::createEmptyTensorToAllocTensorPass()); - - // Always run canonicalizer (which does dead code removal) before - // bufferizing anything. - pm.addPass(mlir::createCanonicalizerPass()); - pm.addPass(mlir::hlo::createOneShotBufferizePass()); - pm.addNestedPass(createRewriteReallocToAllocPass()); - pm.addNestedPass(mlir::createVectorizeCopyPass()); - pm.addNestedPass(mlir::createNaiveCopyRemovalPass()); - - // This should be unified. It exists, because the async runtime tests expect - // parallel loops. - if (options.sparse_bufferization) { - pm.addNestedPass( - mlir::createConvertLinalgToLoopsPass()); - } else { - pm.addNestedPass( - mlir::createConvertLinalgToParallelLoopsPass()); - } - pm.addPass(mlir::createCSEPass()); - pm.addPass(mlir::createCanonicalizerPass()); - mlir::bufferization::BufferResultsToOutParamsOpts out_params_opts; - out_params_opts.filterFn = [](mlir::func::FuncOp* func) { - // Only transform the entry point. - return func->getSymName() == "main"; - }; - pm.addPass( - mlir::bufferization::createBufferResultsToOutParamsPass(out_params_opts)); - - pm.addNestedPass( - mlir::bufferization::createPromoteBuffersToStackPass(nullptr)); - pm.addNestedPass( - mlir::bufferization::createBufferDeallocationPass()); - pm.addPass(mlir::createBufferizationToMemRefPass()); - if (options.remove_copies_to_outparams) { - pm.addNestedPass( - xla::cpu::createRemoveCopiesToOutParamsPass()); - } - - // Specialize linalg.matmul to linalg.dot, linalg.matvec or linalg.vecmat, - // and immediately canonicalize to clean up not taken branches. - // pm.addNestedPass(CreateLinalgMatmulSpecializationPass()); - pm.addPass(mlir::createCanonicalizerPass()); - - // TODO(tpopp): Move hits to mlir::hlo::createGenericHostToLLVMPass? - pm.addNestedPass( - mlir::createConvertComplexToStandardPass()); - - pm.addPass(mlir::createCSEPass()); - pm.addPass(mlir::createCanonicalizerPass()); - pm.addNestedPass(mlir::createConvertVectorToSCFPass()); - pm.addNestedPass(xla::cpu::createLegalizeI1VectorTransferOpsPass()); - pm.addNestedPass( - xla::cpu::createConvertXlaCpuMemRefElementCastToLLVMPass()); - return OkStatus(); -} - -Status CreateHloXlaRuntimePipeline( - xla::runtime::PassManager& passes, - const HloXlaRuntimePipelineOptions& options) { - return CreateHloXlaPipeline(*passes, options); -} - -Status CreateDefaultHloXlaRuntimePipeline(xla::runtime::PassManager& passes) { - HloXlaRuntimePipelineOptions options; - return CreateHloXlaPipeline(*passes, options); -} - -void RegisterHloXlaRuntimePipelineDialects(mlir::DialectRegistry& dialects) { - mlir::arith::registerBufferizableOpInterfaceExternalModels(dialects); - mlir::bufferization::func_ext::registerBufferizableOpInterfaceExternalModels( - dialects); - mlir::memref::registerAllocationOpInterfaceExternalModels(dialects); - mlir::linalg::registerBufferizableOpInterfaceExternalModels(dialects); - mlir::linalg::registerTilingInterfaceExternalModels(dialects); - mlir::mhlo::registerBufferizableOpInterfaceExternalModels(dialects); - mlir::scf::registerBufferizableOpInterfaceExternalModels(dialects); - mlir::shape::registerBufferizableOpInterfaceExternalModels(dialects); - mlir::sparse_tensor::registerBufferizableOpInterfaceExternalModels(dialects); - mlir::tensor::registerBufferizableOpInterfaceExternalModels(dialects); - mlir::vector::registerBufferizableOpInterfaceExternalModels(dialects); -} - -static mlir::PassPipelineRegistration<> hlo_xla_runtime_pipeline( - "hlo-xla-runtime-pipeline", - "Convert HLO dialect to XLA Runtime compatible dialects", - [](mlir::OpPassManager& pm) { - HloXlaRuntimePipelineOptions options; - Status status = CreateHloXlaPipeline(pm, options); - if (!status.ok()) { - LOG(FATAL) << "HLO-XLA Runtime pipeline failed with: " - << status.message(); - } - }); - -} // namespace cpu -} // namespace xla diff --git a/xla/service/cpu/hlo_xla_runtime_pipeline.h b/xla/service/cpu/hlo_xla_runtime_pipeline.h deleted file mode 100644 index 5b5a970d1352f..0000000000000 --- a/xla/service/cpu/hlo_xla_runtime_pipeline.h +++ /dev/null @@ -1,55 +0,0 @@ -/* Copyright 2022 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_CPU_HLO_XLA_RUNTIME_PIPELINE_H_ -#define XLA_SERVICE_CPU_HLO_XLA_RUNTIME_PIPELINE_H_ - -#include -#include - -#include "xla/runtime/compiler.h" -#include "xla/status.h" - -namespace mlir { -class DialectRegistry; -} // namespace mlir - -namespace xla { -namespace cpu { - -struct HloXlaRuntimePipelineOptions { - bool enable_tiling_and_fusion = false; - bool enable_fusion_outlining = true; - bool remove_copies_to_outparams = true; - bool sparse_bufferization = true; - bool enable_avx2 = true; - // Accelerate sparse computations with CUDA threading. - // This is an experimental feature, so off by default. - int32_t xla_cpu_sparse_cuda_threads = 0; - // Optional CPU name, similar to llc's -mcpu flag. - std::string cpu_name = ""; - std::vector matmul_tile_sizes = {}; -}; - -// Creates a pipeline that lowers modules from HLO to Linalg on buffers. -Status CreateHloXlaRuntimePipeline(xla::runtime::PassManager& passes, - const HloXlaRuntimePipelineOptions& options); -Status CreateDefaultHloXlaRuntimePipeline(xla::runtime::PassManager& passes); - -void RegisterHloXlaRuntimePipelineDialects(mlir::DialectRegistry& dialects); -} // namespace cpu -} // namespace xla - -#endif // XLA_SERVICE_CPU_HLO_XLA_RUNTIME_PIPELINE_H_ diff --git a/xla/translate/BUILD b/xla/translate/BUILD index 0eb802744596f..293d0580747f2 100644 --- a/xla/translate/BUILD +++ b/xla/translate/BUILD @@ -52,7 +52,6 @@ xla_cc_binary( "//xla/mlir/framework/transforms:passes", "//xla/mlir_hlo:hlo_dialect_registration", "//xla/service:cpu_plugin", - "//xla/service/cpu:hlo_xla_runtime_pipeline", # buildcleaner: keep "@llvm-project//llvm:Support", "@llvm-project//mlir:AllPassesAndDialects", "@llvm-project//mlir:MlirOptLib",