xref: /llvm-project/mlir/test/lib/Dialect/ArmSME/TestLowerToArmSME.cpp (revision d319fc41d0e35bfea8368ad91dc15ab319cddcb7)
//===- TestLowerToArmSME.cpp - Test lowering to ArmSME as a sink pass -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a pass for testing the lowering to ArmSME as a
// generally usable sink pass.
//
//===----------------------------------------------------------------------===//
13b39f5660SCullen Rhodes 
14b39f5660SCullen Rhodes #include "mlir/Conversion/ArithToArmSME/ArithToArmSME.h"
15b39f5660SCullen Rhodes #include "mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h"
16b39f5660SCullen Rhodes #include "mlir/Conversion/ArmSMEToSCF/ArmSMEToSCF.h"
17041baf2fSBenjamin Maxwell #include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
18b39f5660SCullen Rhodes #include "mlir/Conversion/VectorToArmSME/VectorToArmSME.h"
19b39f5660SCullen Rhodes #include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
20b39f5660SCullen Rhodes #include "mlir/Dialect/ArmSME/Transforms/Passes.h"
21b39f5660SCullen Rhodes #include "mlir/Dialect/ArmSVE/Transforms/Passes.h"
22041baf2fSBenjamin Maxwell #include "mlir/Dialect/Func/IR/FuncOps.h"
23b39f5660SCullen Rhodes #include "mlir/IR/DialectRegistry.h"
24b39f5660SCullen Rhodes #include "mlir/Pass/Pass.h"
25b39f5660SCullen Rhodes #include "mlir/Pass/PassManager.h"
26b39f5660SCullen Rhodes #include "mlir/Pass/PassOptions.h"
27b39f5660SCullen Rhodes #include "mlir/Transforms/Passes.h"
28b39f5660SCullen Rhodes 
29b39f5660SCullen Rhodes using namespace mlir;
30b39f5660SCullen Rhodes 
31b39f5660SCullen Rhodes namespace {
32b39f5660SCullen Rhodes struct TestLowerToArmSMEOptions
33b39f5660SCullen Rhodes     : public PassPipelineOptions<TestLowerToArmSMEOptions> {
34b39f5660SCullen Rhodes   PassOptions::Option<bool> fuseOuterProducts{
35b39f5660SCullen Rhodes       *this, "fuse-outer-products",
36b39f5660SCullen Rhodes       llvm::cl::desc("Fuse outer product operations via "
37b39f5660SCullen Rhodes                      "'-arm-sme-outer-product-fusion' pass"),
38b39f5660SCullen Rhodes       llvm::cl::init(true)};
39041baf2fSBenjamin Maxwell   PassOptions::Option<bool> dumpTileLiveRanges{
40041baf2fSBenjamin Maxwell       *this, "dump-tile-live-ranges",
41041baf2fSBenjamin Maxwell       llvm::cl::desc("Dump the live ranges of SME tiles (for debugging)"),
42041baf2fSBenjamin Maxwell       llvm::cl::init(false)};
43b39f5660SCullen Rhodes };
44b39f5660SCullen Rhodes 
buildTestLowerToArmSME(OpPassManager & pm,const TestLowerToArmSMEOptions & options)45b39f5660SCullen Rhodes void buildTestLowerToArmSME(OpPassManager &pm,
46b39f5660SCullen Rhodes                             const TestLowerToArmSMEOptions &options) {
47b39f5660SCullen Rhodes   // Legalize vector operations so they can be converted to ArmSME.
48b39f5660SCullen Rhodes   pm.addPass(arm_sme::createVectorLegalizationPass());
49b39f5660SCullen Rhodes 
50b39f5660SCullen Rhodes   // Sprinkle some cleanups.
51b39f5660SCullen Rhodes   pm.addPass(createCanonicalizerPass());
52b39f5660SCullen Rhodes   pm.addPass(createCSEPass());
53b39f5660SCullen Rhodes 
54b39f5660SCullen Rhodes   // Passes that convert operations on vectors to ArmSME operations.
55b39f5660SCullen Rhodes 
56b39f5660SCullen Rhodes   // Convert Arith to ArmSME.
57b39f5660SCullen Rhodes   pm.addPass(createArithToArmSMEConversionPass());
58b39f5660SCullen Rhodes   // Convert Vector to ArmSME.
59b39f5660SCullen Rhodes   pm.addPass(createConvertVectorToArmSMEPass());
60b39f5660SCullen Rhodes 
61b39f5660SCullen Rhodes   // Fuse outer products.
62b39f5660SCullen Rhodes   if (options.fuseOuterProducts)
63b39f5660SCullen Rhodes     pm.addPass(arm_sme::createOuterProductFusionPass());
64b39f5660SCullen Rhodes 
65b39f5660SCullen Rhodes   // Convert operations on high-level vectors to loops.
66b39f5660SCullen Rhodes 
67b39f5660SCullen Rhodes   // Convert ArmSME to SCF.
68b39f5660SCullen Rhodes   pm.addPass(createConvertArmSMEToSCFPass());
69b39f5660SCullen Rhodes 
70b39f5660SCullen Rhodes   // Convert Vector to SCF (with full unroll enabled).
71b39f5660SCullen Rhodes   pm.addPass(createConvertVectorToSCFPass(
72b39f5660SCullen Rhodes       VectorTransferToSCFOptions().enableFullUnroll()));
73b39f5660SCullen Rhodes 
74b39f5660SCullen Rhodes   // Enable streaming-mode and ZA.
75b39f5660SCullen Rhodes   pm.addPass(arm_sme::createEnableArmStreamingPass(
76b39f5660SCullen Rhodes       arm_sme::ArmStreamingMode::StreamingLocally, arm_sme::ArmZaMode::NewZA,
77*d319fc41SBenjamin Maxwell       /*ifRequiredByOps=*/true));
78b39f5660SCullen Rhodes 
79041baf2fSBenjamin Maxwell   // Convert SCF to CF (required for ArmSME tile allocation).
80041baf2fSBenjamin Maxwell   pm.addPass(createConvertSCFToCFPass());
81041baf2fSBenjamin Maxwell 
82b39f5660SCullen Rhodes   // Convert ArmSME to LLVM.
83041baf2fSBenjamin Maxwell   pm.addNestedPass<func::FuncOp>(
84041baf2fSBenjamin Maxwell       createConvertArmSMEToLLVMPass(options.dumpTileLiveRanges));
85b39f5660SCullen Rhodes 
86b39f5660SCullen Rhodes   // Sprinkle some cleanups.
87b39f5660SCullen Rhodes   pm.addPass(createCanonicalizerPass());
88b39f5660SCullen Rhodes   pm.addPass(createCSEPass());
89b39f5660SCullen Rhodes }
90b39f5660SCullen Rhodes } // namespace
91b39f5660SCullen Rhodes 
92b39f5660SCullen Rhodes namespace mlir {
93b39f5660SCullen Rhodes namespace test {
registerTestLowerToArmSME()94b39f5660SCullen Rhodes void registerTestLowerToArmSME() {
95b39f5660SCullen Rhodes   PassPipelineRegistration<TestLowerToArmSMEOptions>(
96b39f5660SCullen Rhodes       "test-lower-to-arm-sme",
97b39f5660SCullen Rhodes       "An example pipeline to lower operations on vectors (arith, vector) to "
98b39f5660SCullen Rhodes       "LLVM via ArmSME.",
99b39f5660SCullen Rhodes       buildTestLowerToArmSME);
100b39f5660SCullen Rhodes }
101b39f5660SCullen Rhodes } // namespace test
102b39f5660SCullen Rhodes } // namespace mlir
103