//===- TestLowerToArmSME.cpp - Test lowering to ArmSME as a sink pass -----===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file implements a pass for testing the lowering to ArmSME as a // generally usable sink pass. // //===----------------------------------------------------------------------===// #include "mlir/Conversion/ArithToArmSME/ArithToArmSME.h" #include "mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h" #include "mlir/Conversion/ArmSMEToSCF/ArmSMEToSCF.h" #include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h" #include "mlir/Conversion/VectorToArmSME/VectorToArmSME.h" #include "mlir/Conversion/VectorToSCF/VectorToSCF.h" #include "mlir/Dialect/ArmSME/Transforms/Passes.h" #include "mlir/Dialect/ArmSVE/Transforms/Passes.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/DialectRegistry.h" #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" #include "mlir/Pass/PassOptions.h" #include "mlir/Transforms/Passes.h" using namespace mlir; namespace { struct TestLowerToArmSMEOptions : public PassPipelineOptions { PassOptions::Option fuseOuterProducts{ *this, "fuse-outer-products", llvm::cl::desc("Fuse outer product operations via " "'-arm-sme-outer-product-fusion' pass"), llvm::cl::init(true)}; PassOptions::Option dumpTileLiveRanges{ *this, "dump-tile-live-ranges", llvm::cl::desc("Dump the live ranges of SME tiles (for debugging)"), llvm::cl::init(false)}; }; void buildTestLowerToArmSME(OpPassManager &pm, const TestLowerToArmSMEOptions &options) { // Legalize vector operations so they can be converted to ArmSME. pm.addPass(arm_sme::createVectorLegalizationPass()); // Sprinkle some cleanups. pm.addPass(createCanonicalizerPass()); pm.addPass(createCSEPass()); // Passes that convert operations on vectors to ArmSME operations. // Convert Arith to ArmSME. pm.addPass(createArithToArmSMEConversionPass()); // Convert Vector to ArmSME. pm.addPass(createConvertVectorToArmSMEPass()); // Fuse outer products. if (options.fuseOuterProducts) pm.addPass(arm_sme::createOuterProductFusionPass()); // Convert operations on high-level vectors to loops. // Convert ArmSME to SCF. pm.addPass(createConvertArmSMEToSCFPass()); // Convert Vector to SCF (with full unroll enabled). pm.addPass(createConvertVectorToSCFPass( VectorTransferToSCFOptions().enableFullUnroll())); // Enable streaming-mode and ZA. pm.addPass(arm_sme::createEnableArmStreamingPass( arm_sme::ArmStreamingMode::StreamingLocally, arm_sme::ArmZaMode::NewZA, /*ifRequiredByOps=*/true)); // Convert SCF to CF (required for ArmSME tile allocation). pm.addPass(createConvertSCFToCFPass()); // Convert ArmSME to LLVM. pm.addNestedPass( createConvertArmSMEToLLVMPass(options.dumpTileLiveRanges)); // Sprinkle some cleanups. pm.addPass(createCanonicalizerPass()); pm.addPass(createCSEPass()); } } // namespace namespace mlir { namespace test { void registerTestLowerToArmSME() { PassPipelineRegistration( "test-lower-to-arm-sme", "An example pipeline to lower operations on vectors (arith, vector) to " "LLVM via ArmSME.", buildTestLowerToArmSME); } } // namespace test } // namespace mlir