xref: /llvm-project/mlir/test/lib/Dialect/ArmSME/TestLowerToArmSME.cpp (revision d319fc41d0e35bfea8368ad91dc15ab319cddcb7)
1 //===- TestLowerToArmSME.cpp - Test lowering to ArmSME as a sink pass -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a pass for testing the lowering to ArmSME as a
10 // generally usable sink pass.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Conversion/ArithToArmSME/ArithToArmSME.h"
15 #include "mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h"
16 #include "mlir/Conversion/ArmSMEToSCF/ArmSMEToSCF.h"
17 #include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
18 #include "mlir/Conversion/VectorToArmSME/VectorToArmSME.h"
19 #include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
20 #include "mlir/Dialect/ArmSME/Transforms/Passes.h"
21 #include "mlir/Dialect/ArmSVE/Transforms/Passes.h"
22 #include "mlir/Dialect/Func/IR/FuncOps.h"
23 #include "mlir/IR/DialectRegistry.h"
24 #include "mlir/Pass/Pass.h"
25 #include "mlir/Pass/PassManager.h"
26 #include "mlir/Pass/PassOptions.h"
27 #include "mlir/Transforms/Passes.h"
28 
29 using namespace mlir;
30 
31 namespace {
32 struct TestLowerToArmSMEOptions
33     : public PassPipelineOptions<TestLowerToArmSMEOptions> {
34   PassOptions::Option<bool> fuseOuterProducts{
35       *this, "fuse-outer-products",
36       llvm::cl::desc("Fuse outer product operations via "
37                      "'-arm-sme-outer-product-fusion' pass"),
38       llvm::cl::init(true)};
39   PassOptions::Option<bool> dumpTileLiveRanges{
40       *this, "dump-tile-live-ranges",
41       llvm::cl::desc("Dump the live ranges of SME tiles (for debugging)"),
42       llvm::cl::init(false)};
43 };
44 
buildTestLowerToArmSME(OpPassManager & pm,const TestLowerToArmSMEOptions & options)45 void buildTestLowerToArmSME(OpPassManager &pm,
46                             const TestLowerToArmSMEOptions &options) {
47   // Legalize vector operations so they can be converted to ArmSME.
48   pm.addPass(arm_sme::createVectorLegalizationPass());
49 
50   // Sprinkle some cleanups.
51   pm.addPass(createCanonicalizerPass());
52   pm.addPass(createCSEPass());
53 
54   // Passes that convert operations on vectors to ArmSME operations.
55 
56   // Convert Arith to ArmSME.
57   pm.addPass(createArithToArmSMEConversionPass());
58   // Convert Vector to ArmSME.
59   pm.addPass(createConvertVectorToArmSMEPass());
60 
61   // Fuse outer products.
62   if (options.fuseOuterProducts)
63     pm.addPass(arm_sme::createOuterProductFusionPass());
64 
65   // Convert operations on high-level vectors to loops.
66 
67   // Convert ArmSME to SCF.
68   pm.addPass(createConvertArmSMEToSCFPass());
69 
70   // Convert Vector to SCF (with full unroll enabled).
71   pm.addPass(createConvertVectorToSCFPass(
72       VectorTransferToSCFOptions().enableFullUnroll()));
73 
74   // Enable streaming-mode and ZA.
75   pm.addPass(arm_sme::createEnableArmStreamingPass(
76       arm_sme::ArmStreamingMode::StreamingLocally, arm_sme::ArmZaMode::NewZA,
77       /*ifRequiredByOps=*/true));
78 
79   // Convert SCF to CF (required for ArmSME tile allocation).
80   pm.addPass(createConvertSCFToCFPass());
81 
82   // Convert ArmSME to LLVM.
83   pm.addNestedPass<func::FuncOp>(
84       createConvertArmSMEToLLVMPass(options.dumpTileLiveRanges));
85 
86   // Sprinkle some cleanups.
87   pm.addPass(createCanonicalizerPass());
88   pm.addPass(createCSEPass());
89 }
90 } // namespace
91 
92 namespace mlir {
93 namespace test {
registerTestLowerToArmSME()94 void registerTestLowerToArmSME() {
95   PassPipelineRegistration<TestLowerToArmSMEOptions>(
96       "test-lower-to-arm-sme",
97       "An example pipeline to lower operations on vectors (arith, vector) to "
98       "LLVM via ArmSME.",
99       buildTestLowerToArmSME);
100 }
101 } // namespace test
102 } // namespace mlir
103