xref: /llvm-project/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp (revision 09dfc5713d7e2342bea4c8447d1ed76c85eb8225)
134ff8573SNicolas Vasilache //===- TestVectorTransforms.cpp - Test Vector transforms and lowerings ----===//
23fef2d26SRiver Riddle //
33fef2d26SRiver Riddle // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
43fef2d26SRiver Riddle // See https://llvm.org/LICENSE.txt for license information.
53fef2d26SRiver Riddle // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
63fef2d26SRiver Riddle //
73fef2d26SRiver Riddle //===----------------------------------------------------------------------===//
83fef2d26SRiver Riddle 
9a1fe1f5fSKazu Hirata #include <optional>
1073bec2b2SNicolas Vasilache #include <type_traits>
113fef2d26SRiver Riddle 
123fef2d26SRiver Riddle #include "mlir/Analysis/SliceAnalysis.h"
133fef2d26SRiver Riddle #include "mlir/Dialect/Affine/IR/AffineOps.h"
14fb7ef637SJakub Kuderski #include "mlir/Dialect/Arith/IR/Arith.h"
1523aa5a74SRiver Riddle #include "mlir/Dialect/Func/IR/FuncOps.h"
16d7ef488bSMogball #include "mlir/Dialect/GPU/IR/GPUDialect.h"
17b2729fdaSNicolas Vasilache #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
18b7f2c108Sgysit #include "mlir/Dialect/Linalg/IR/Linalg.h"
1934ff8573SNicolas Vasilache #include "mlir/Dialect/Linalg/Passes.h"
2034ff8573SNicolas Vasilache #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
213fef2d26SRiver Riddle #include "mlir/Dialect/MemRef/IR/MemRef.h"
229a795f0cSManish Gupta #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
238b68da2cSAlex Zinenko #include "mlir/Dialect/SCF/IR/SCF.h"
24f80a976aSJakub Kuderski #include "mlir/Dialect/Tensor/IR/Tensor.h"
259f122152SChristopher Bate #include "mlir/Dialect/Vector/IR/VectorOps.h"
262bc4c3e9SNicolas Vasilache #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
27d02f10d9SThomas Raoux #include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
2839c80656SLei Zhang #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
2999ef9eebSMatthias Springer #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
303fef2d26SRiver Riddle #include "mlir/Pass/Pass.h"
3134ff8573SNicolas Vasilache #include "mlir/Pass/PassManager.h"
329f122152SChristopher Bate #include "mlir/Support/LLVM.h"
333fef2d26SRiver Riddle #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
343fef2d26SRiver Riddle 
353fef2d26SRiver Riddle using namespace mlir;
3634ff8573SNicolas Vasilache using namespace mlir::linalg;
373fef2d26SRiver Riddle using namespace mlir::vector;
38d054b80bSNicolas Vasilache 
393fef2d26SRiver Riddle namespace {
403fef2d26SRiver Riddle 
4134ff8573SNicolas Vasilache struct TestVectorToVectorLowering
4258ceae95SRiver Riddle     : public PassWrapper<TestVectorToVectorLowering,
4358ceae95SRiver Riddle                          OperationPass<func::FuncOp>> {
445e50dd04SRiver Riddle   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorToVectorLowering)
455e50dd04SRiver Riddle 
4634ff8573SNicolas Vasilache   TestVectorToVectorLowering() = default;
473bab9d4eSMehdi Amini   TestVectorToVectorLowering(const TestVectorToVectorLowering &pass)
483bab9d4eSMehdi Amini       : PassWrapper(pass) {}
49b5e22e6dSMehdi Amini   StringRef getArgument() const final {
5034ff8573SNicolas Vasilache     return "test-vector-to-vector-lowering";
51b5e22e6dSMehdi Amini   }
52b5e22e6dSMehdi Amini   StringRef getDescription() const final {
5334ff8573SNicolas Vasilache     return "Test lowering patterns between ops in the vector dialect";
54b5e22e6dSMehdi Amini   }
553fef2d26SRiver Riddle 
563fef2d26SRiver Riddle   void getDependentDialects(DialectRegistry &registry) const override {
574c48f016SMatthias Springer     registry.insert<affine::AffineDialect>();
58576b184dSAndrzej Warzynski     registry.insert<vector::VectorDialect>();
593fef2d26SRiver Riddle   }
603fef2d26SRiver Riddle 
613fef2d26SRiver Riddle   Option<bool> unroll{*this, "unroll", llvm::cl::desc("Include unrolling"),
623fef2d26SRiver Riddle                       llvm::cl::init(false)};
633fef2d26SRiver Riddle 
6441574554SRiver Riddle   void runOnOperation() override {
653fef2d26SRiver Riddle     auto *ctx = &getContext();
663fef2d26SRiver Riddle     RewritePatternSet patterns(ctx);
673fef2d26SRiver Riddle     if (unroll) {
6829102538Sthomasraoux       populateVectorUnrollPatterns(
6929102538Sthomasraoux           patterns,
703fef2d26SRiver Riddle           UnrollVectorOptions().setNativeShapeFn(getShape).setFilterConstraint(
713fef2d26SRiver Riddle               filter));
723fef2d26SRiver Riddle     }
733fef2d26SRiver Riddle     populateVectorToVectorCanonicalizationPatterns(patterns);
743fef2d26SRiver Riddle     populateBubbleVectorBitCastOpPatterns(patterns);
753fef2d26SRiver Riddle     populateCastAwayVectorLeadingOneDimPatterns(patterns);
76*09dfc571SJacques Pienaar     (void)applyPatternsGreedily(getOperation(), std::move(patterns));
773fef2d26SRiver Riddle   }
783fef2d26SRiver Riddle 
793fef2d26SRiver Riddle private:
803fef2d26SRiver Riddle   // Return the target shape based on op type.
810a81ace0SKazu Hirata   static std::optional<SmallVector<int64_t>> getShape(Operation *op) {
82dec8af70SRiver Riddle     if (isa<arith::AddFOp, arith::SelectOp, arith::CmpFOp>(op))
837a69a9d7SNicolas Vasilache       return SmallVector<int64_t>(2, 2);
843fef2d26SRiver Riddle     if (isa<vector::ContractionOp>(op))
857a69a9d7SNicolas Vasilache       return SmallVector<int64_t>(3, 2);
8629102538Sthomasraoux     // For transfer ops, just propagate the shape coming from
8729102538Sthomasraoux     // InsertStridedSlices/ExtractStridedSlices.
8829102538Sthomasraoux     if (auto readOp = dyn_cast<vector::TransferReadOp>(op)) {
8929102538Sthomasraoux       VectorType dstVec;
9029102538Sthomasraoux       for (Operation *users : readOp->getUsers()) {
9129102538Sthomasraoux         auto extract = dyn_cast<ExtractStridedSliceOp>(users);
9229102538Sthomasraoux         if (!extract)
931a36588eSKazu Hirata           return std::nullopt;
945550c821STres Popp         auto vecType = cast<VectorType>(extract.getResult().getType());
9529102538Sthomasraoux         if (dstVec && dstVec != vecType)
961a36588eSKazu Hirata           return std::nullopt;
9729102538Sthomasraoux         dstVec = vecType;
9829102538Sthomasraoux       }
995262865aSKazu Hirata       return SmallVector<int64_t>(dstVec.getShape());
10029102538Sthomasraoux     }
10129102538Sthomasraoux     if (auto writeOp = dyn_cast<vector::TransferWriteOp>(op)) {
1027c38fd60SJacques Pienaar       auto insert = writeOp.getVector().getDefiningOp<InsertStridedSliceOp>();
10329102538Sthomasraoux       if (!insert)
1041a36588eSKazu Hirata         return std::nullopt;
10529102538Sthomasraoux       ArrayRef<int64_t> shape = insert.getSourceVectorType().getShape();
1065262865aSKazu Hirata       return SmallVector<int64_t>(shape);
10729102538Sthomasraoux     }
1081a36588eSKazu Hirata     return std::nullopt;
1093fef2d26SRiver Riddle   }
1103fef2d26SRiver Riddle 
1113fef2d26SRiver Riddle   static LogicalResult filter(Operation *op) {
112dec8af70SRiver Riddle     return success(isa<arith::AddFOp, arith::SelectOp, arith::CmpFOp,
113dec8af70SRiver Riddle                        ContractionOp, TransferReadOp, TransferWriteOp>(op));
1143fef2d26SRiver Riddle   }
1153fef2d26SRiver Riddle };
1163fef2d26SRiver Riddle 
117fb7ef637SJakub Kuderski struct TestVectorContractionPrepareForMMTLowering
118fb7ef637SJakub Kuderski     : public PassWrapper<TestVectorContractionPrepareForMMTLowering,
119fb7ef637SJakub Kuderski                          OperationPass<func::FuncOp>> {
120fb7ef637SJakub Kuderski   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
121fb7ef637SJakub Kuderski       TestVectorContractionPrepareForMMTLowering)
122fb7ef637SJakub Kuderski 
123fb7ef637SJakub Kuderski   StringRef getArgument() const final {
124fb7ef637SJakub Kuderski     return "test-vector-contraction-prepare-for-mmt-lowering";
125fb7ef637SJakub Kuderski   }
126fb7ef637SJakub Kuderski   StringRef getDescription() const final {
127fb7ef637SJakub Kuderski     return "Test vector.contraction matmul canonicalization for MMT lowering.";
128fb7ef637SJakub Kuderski   }
129fb7ef637SJakub Kuderski   TestVectorContractionPrepareForMMTLowering() = default;
130fb7ef637SJakub Kuderski 
131fb7ef637SJakub Kuderski   void getDependentDialects(DialectRegistry &registry) const override {
1324c48f016SMatthias Springer     registry.insert<affine::AffineDialect, arith::ArithDialect,
1334c48f016SMatthias Springer                     vector::VectorDialect>();
134fb7ef637SJakub Kuderski   }
135fb7ef637SJakub Kuderski 
136fb7ef637SJakub Kuderski   void runOnOperation() override {
137fb7ef637SJakub Kuderski     MLIRContext *ctx = &getContext();
138fb7ef637SJakub Kuderski     RewritePatternSet patterns(ctx);
139fb7ef637SJakub Kuderski     vector::populateVectorContractCanonicalizeMatmulToMMT(patterns);
140*09dfc571SJacques Pienaar     (void)applyPatternsGreedily(getOperation(), std::move(patterns));
141fb7ef637SJakub Kuderski   }
142fb7ef637SJakub Kuderski };
143fb7ef637SJakub Kuderski 
1443fef2d26SRiver Riddle struct TestVectorUnrollingPatterns
14558ceae95SRiver Riddle     : public PassWrapper<TestVectorUnrollingPatterns,
14658ceae95SRiver Riddle                          OperationPass<func::FuncOp>> {
1475e50dd04SRiver Riddle   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorUnrollingPatterns)
1485e50dd04SRiver Riddle 
149b5e22e6dSMehdi Amini   StringRef getArgument() const final {
150b5e22e6dSMehdi Amini     return "test-vector-unrolling-patterns";
151b5e22e6dSMehdi Amini   }
152b5e22e6dSMehdi Amini   StringRef getDescription() const final {
15334ff8573SNicolas Vasilache     return "Test lowering patterns to unroll contract ops in the vector "
154b5e22e6dSMehdi Amini            "dialect";
155b5e22e6dSMehdi Amini   }
1563fef2d26SRiver Riddle   TestVectorUnrollingPatterns() = default;
1573bab9d4eSMehdi Amini   TestVectorUnrollingPatterns(const TestVectorUnrollingPatterns &pass)
1583bab9d4eSMehdi Amini       : PassWrapper(pass) {}
15941574554SRiver Riddle   void runOnOperation() override {
1603fef2d26SRiver Riddle     MLIRContext *ctx = &getContext();
1613fef2d26SRiver Riddle     RewritePatternSet patterns(ctx);
16229102538Sthomasraoux     populateVectorUnrollPatterns(
16329102538Sthomasraoux         patterns, UnrollVectorOptions()
1643fef2d26SRiver Riddle                       .setNativeShape(ArrayRef<int64_t>{2, 2})
1653fef2d26SRiver Riddle                       .setFilterConstraint([](Operation *op) {
166f69175b1SThomas Raoux                         return success(isa<arith::AddFOp, vector::FMAOp,
167f69175b1SThomas Raoux                                            vector::MultiDimReductionOp>(op));
1683fef2d26SRiver Riddle                       }));
169de5022c7SMatthias Springer     populateVectorUnrollPatterns(
170de5022c7SMatthias Springer         patterns, UnrollVectorOptions()
171de5022c7SMatthias Springer                       .setNativeShape(ArrayRef<int64_t>{2})
172de5022c7SMatthias Springer                       .setFilterConstraint([](Operation *op) {
173de5022c7SMatthias Springer                         return success(isa<vector::ReductionOp>(op));
174de5022c7SMatthias Springer                       }));
1755b1b7108SThomas Raoux     populateVectorUnrollPatterns(
1765b1b7108SThomas Raoux         patterns, UnrollVectorOptions()
1775b1b7108SThomas Raoux                       .setNativeShape(ArrayRef<int64_t>{1, 3, 4, 2})
1785b1b7108SThomas Raoux                       .setFilterConstraint([](Operation *op) {
1795b1b7108SThomas Raoux                         return success(isa<vector::TransposeOp>(op));
1805b1b7108SThomas Raoux                       }));
1813fef2d26SRiver Riddle 
1823fef2d26SRiver Riddle     if (unrollBasedOnType) {
1833fef2d26SRiver Riddle       UnrollVectorOptions::NativeShapeFnType nativeShapeFn =
1840a81ace0SKazu Hirata           [](Operation *op) -> std::optional<SmallVector<int64_t>> {
1853fef2d26SRiver Riddle         vector::ContractionOp contractOp = cast<vector::ContractionOp>(op);
1867a69a9d7SNicolas Vasilache         SmallVector<int64_t> nativeShape(contractOp.getIteratorTypes().size(),
1877a69a9d7SNicolas Vasilache                                          4);
1889f122152SChristopher Bate         Type lhsType = contractOp.getLhsType().getElementType();
1899f122152SChristopher Bate         nativeShape[nativeShape.size() - 1] = lhsType.isF16() ? 4 : 2;
1903fef2d26SRiver Riddle         return nativeShape;
1913fef2d26SRiver Riddle       };
1929f122152SChristopher Bate 
1939f122152SChristopher Bate       UnrollVectorOptions opts;
1949f122152SChristopher Bate       opts.setNativeShapeFn(nativeShapeFn)
1959f122152SChristopher Bate           .setFilterConstraint(
1969f122152SChristopher Bate               [](Operation *op) { return success(isa<ContractionOp>(op)); });
1979f122152SChristopher Bate 
1989f122152SChristopher Bate       if (!unrollOrder.empty()) {
1990a81ace0SKazu Hirata         opts.setUnrollTraversalOrderFn(
2000a81ace0SKazu Hirata             [this](Operation *op) -> std::optional<SmallVector<int64_t>> {
2010a81ace0SKazu Hirata               vector::ContractionOp contractOp =
2020a81ace0SKazu Hirata                   cast<vector::ContractionOp>(op);
2039f122152SChristopher Bate               if (contractOp.getIteratorTypes().size() == unrollOrder.size())
2040a81ace0SKazu Hirata                 return SmallVector<int64_t>(unrollOrder.begin(),
2050a81ace0SKazu Hirata                                             unrollOrder.end());
2061a36588eSKazu Hirata               return std::nullopt;
2079f122152SChristopher Bate             });
2089f122152SChristopher Bate       }
2099f122152SChristopher Bate       populateVectorUnrollPatterns(patterns, opts);
2109f122152SChristopher Bate     } else {
2110a81ace0SKazu Hirata       auto nativeShapeFn =
2120a81ace0SKazu Hirata           [](Operation *op) -> std::optional<SmallVector<int64_t>> {
2139f122152SChristopher Bate         auto contractOp = dyn_cast<ContractionOp>(op);
2149f122152SChristopher Bate         if (!contractOp)
2151a36588eSKazu Hirata           return std::nullopt;
2167a69a9d7SNicolas Vasilache         return SmallVector<int64_t>(contractOp.getIteratorTypes().size(), 2);
2179f122152SChristopher Bate       };
21853fe155bSChristopher Bate       populateVectorUnrollPatterns(patterns,
21953fe155bSChristopher Bate                                    UnrollVectorOptions()
22053fe155bSChristopher Bate                                        .setNativeShapeFn(nativeShapeFn)
22153fe155bSChristopher Bate                                        .setFilterConstraint([](Operation *op) {
22253fe155bSChristopher Bate                                          return success(isa<ContractionOp>(op));
22353fe155bSChristopher Bate                                        }));
2243fef2d26SRiver Riddle     }
2253fef2d26SRiver Riddle     populateVectorToVectorCanonicalizationPatterns(patterns);
226*09dfc571SJacques Pienaar     (void)applyPatternsGreedily(getOperation(), std::move(patterns));
2273fef2d26SRiver Riddle   }
2283fef2d26SRiver Riddle 
2299f122152SChristopher Bate   ListOption<int64_t> unrollOrder{*this, "unroll-order",
23062a4e6abSFangrui Song                                   llvm::cl::desc("set the unroll order")};
2319f122152SChristopher Bate 
2323fef2d26SRiver Riddle   Option<bool> unrollBasedOnType{
2333fef2d26SRiver Riddle       *this, "unroll-based-on-type",
2343fef2d26SRiver Riddle       llvm::cl::desc("Set the unroll factor based on type of the operation"),
2353fef2d26SRiver Riddle       llvm::cl::init(false)};
2363fef2d26SRiver Riddle };
2373fef2d26SRiver Riddle 
2383fef2d26SRiver Riddle struct TestVectorTransferUnrollingPatterns
23941574554SRiver Riddle     : public PassWrapper<TestVectorTransferUnrollingPatterns,
24058ceae95SRiver Riddle                          OperationPass<func::FuncOp>> {
2415e50dd04SRiver Riddle   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
2425e50dd04SRiver Riddle       TestVectorTransferUnrollingPatterns)
2435e50dd04SRiver Riddle 
2449f122152SChristopher Bate   TestVectorTransferUnrollingPatterns() = default;
2459f122152SChristopher Bate   TestVectorTransferUnrollingPatterns(
2469f122152SChristopher Bate       const TestVectorTransferUnrollingPatterns &pass)
2479f122152SChristopher Bate       : PassWrapper(pass) {}
2489f122152SChristopher Bate 
2493fef2d26SRiver Riddle   void getDependentDialects(DialectRegistry &registry) const override {
2504c48f016SMatthias Springer     registry.insert<affine::AffineDialect>();
2513fef2d26SRiver Riddle   }
252b5e22e6dSMehdi Amini   StringRef getArgument() const final {
253b5e22e6dSMehdi Amini     return "test-vector-transfer-unrolling-patterns";
254b5e22e6dSMehdi Amini   }
255b5e22e6dSMehdi Amini   StringRef getDescription() const final {
25634ff8573SNicolas Vasilache     return "Test lowering patterns to unroll transfer ops in the vector "
257b5e22e6dSMehdi Amini            "dialect";
258b5e22e6dSMehdi Amini   }
25941574554SRiver Riddle   void runOnOperation() override {
2603fef2d26SRiver Riddle     MLIRContext *ctx = &getContext();
2613fef2d26SRiver Riddle     RewritePatternSet patterns(ctx);
2629f122152SChristopher Bate     UnrollVectorOptions opts;
2639f122152SChristopher Bate     opts.setNativeShape(ArrayRef<int64_t>{2, 2})
2643fef2d26SRiver Riddle         .setFilterConstraint([](Operation *op) {
265435f7d4cSQuinn Dawkins           return success(isa<vector::TransferReadOp, vector::TransferWriteOp,
266435f7d4cSQuinn Dawkins                              vector::GatherOp>(op));
2679f122152SChristopher Bate         });
2689f122152SChristopher Bate     if (reverseUnrollOrder.getValue()) {
2699f122152SChristopher Bate       opts.setUnrollTraversalOrderFn(
2700a81ace0SKazu Hirata           [](Operation *op) -> std::optional<SmallVector<int64_t>> {
2719f122152SChristopher Bate             int64_t numLoops = 0;
2729f122152SChristopher Bate             if (auto readOp = dyn_cast<vector::TransferReadOp>(op))
2739f122152SChristopher Bate               numLoops = readOp.getVectorType().getRank();
2749f122152SChristopher Bate             else if (auto writeOp = dyn_cast<vector::TransferWriteOp>(op))
2759f122152SChristopher Bate               numLoops = writeOp.getVectorType().getRank();
276435f7d4cSQuinn Dawkins             else if (auto gatherOp = dyn_cast<vector::GatherOp>(op))
277435f7d4cSQuinn Dawkins               numLoops = gatherOp.getVectorType().getRank();
2789f122152SChristopher Bate             else
2791a36588eSKazu Hirata               return std::nullopt;
2809f122152SChristopher Bate             auto order = llvm::reverse(llvm::seq<int64_t>(0, numLoops));
2819f122152SChristopher Bate             return llvm::to_vector(order);
2829f122152SChristopher Bate           });
2839f122152SChristopher Bate     }
2849f122152SChristopher Bate     populateVectorUnrollPatterns(patterns, opts);
2853fef2d26SRiver Riddle     populateVectorToVectorCanonicalizationPatterns(patterns);
286*09dfc571SJacques Pienaar     (void)applyPatternsGreedily(getOperation(), std::move(patterns));
2873fef2d26SRiver Riddle   }
2889f122152SChristopher Bate 
2899f122152SChristopher Bate   Option<bool> reverseUnrollOrder{
2909f122152SChristopher Bate       *this, "reverse-unroll-order",
2919f122152SChristopher Bate       llvm::cl::desc(
2929f122152SChristopher Bate           "reverse the order of unrolling of vector transfer operations"),
2939f122152SChristopher Bate       llvm::cl::init(false)};
2943fef2d26SRiver Riddle };
2953fef2d26SRiver Riddle 
2962ec98ffbSMatthias Springer struct TestScalarVectorTransferLoweringPatterns
2972ec98ffbSMatthias Springer     : public PassWrapper<TestScalarVectorTransferLoweringPatterns,
2982ec98ffbSMatthias Springer                          OperationPass<func::FuncOp>> {
2992ec98ffbSMatthias Springer   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
3002ec98ffbSMatthias Springer       TestScalarVectorTransferLoweringPatterns)
3012ec98ffbSMatthias Springer 
30214726cd6SDiego Caballero   TestScalarVectorTransferLoweringPatterns() = default;
30314726cd6SDiego Caballero   TestScalarVectorTransferLoweringPatterns(
30414726cd6SDiego Caballero       const TestScalarVectorTransferLoweringPatterns &pass)
30514726cd6SDiego Caballero       : PassWrapper(pass) {}
30614726cd6SDiego Caballero 
3072ec98ffbSMatthias Springer   StringRef getArgument() const final {
3082ec98ffbSMatthias Springer     return "test-scalar-vector-transfer-lowering";
3092ec98ffbSMatthias Springer   }
3102ec98ffbSMatthias Springer   StringRef getDescription() const final {
3112ec98ffbSMatthias Springer     return "Test lowering of scalar vector transfers to memref loads/stores.";
3122ec98ffbSMatthias Springer   }
3132ec98ffbSMatthias Springer 
3142ec98ffbSMatthias Springer   void getDependentDialects(DialectRegistry &registry) const override {
3154c48f016SMatthias Springer     registry.insert<affine::AffineDialect, memref::MemRefDialect,
3164c48f016SMatthias Springer                     tensor::TensorDialect, vector::VectorDialect>();
3172ec98ffbSMatthias Springer   }
3182ec98ffbSMatthias Springer 
31914726cd6SDiego Caballero   Option<bool> allowMultipleUses{
32014726cd6SDiego Caballero       *this, "allow-multiple-uses",
32114726cd6SDiego Caballero       llvm::cl::desc("Fold transfer operations with multiple uses"),
32214726cd6SDiego Caballero       llvm::cl::init(false)};
32314726cd6SDiego Caballero 
3242ec98ffbSMatthias Springer   void runOnOperation() override {
3252ec98ffbSMatthias Springer     MLIRContext *ctx = &getContext();
3262ec98ffbSMatthias Springer     RewritePatternSet patterns(ctx);
32714726cd6SDiego Caballero     vector::populateScalarVectorTransferLoweringPatterns(
32814726cd6SDiego Caballero         patterns, /*benefit=*/1, allowMultipleUses.getValue());
329*09dfc571SJacques Pienaar     (void)applyPatternsGreedily(getOperation(), std::move(patterns));
3302ec98ffbSMatthias Springer   }
3312ec98ffbSMatthias Springer };
3322ec98ffbSMatthias Springer 
3333fef2d26SRiver Riddle struct TestVectorTransferOpt
33458ceae95SRiver Riddle     : public PassWrapper<TestVectorTransferOpt, OperationPass<func::FuncOp>> {
3355e50dd04SRiver Riddle   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorTransferOpt)
3365e50dd04SRiver Riddle 
337b5e22e6dSMehdi Amini   StringRef getArgument() const final { return "test-vector-transferop-opt"; }
338b5e22e6dSMehdi Amini   StringRef getDescription() const final {
339b5e22e6dSMehdi Amini     return "Test optimization transformations for transfer ops";
340b5e22e6dSMehdi Amini   }
341553cebdeSNicolas Vasilache   void runOnOperation() override {
342553cebdeSNicolas Vasilache     IRRewriter rewriter(&getContext());
343553cebdeSNicolas Vasilache     transferOpflowOpt(rewriter, getOperation());
344553cebdeSNicolas Vasilache   }
3453fef2d26SRiver Riddle };
3463fef2d26SRiver Riddle 
347a3dd4e77SAhmed S. Taei struct TestVectorTransferCollapseInnerMostContiguousDims
348a3dd4e77SAhmed S. Taei     : public PassWrapper<TestVectorTransferCollapseInnerMostContiguousDims,
34958ceae95SRiver Riddle                          OperationPass<func::FuncOp>> {
3505e50dd04SRiver Riddle   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
3515e50dd04SRiver Riddle       TestVectorTransferCollapseInnerMostContiguousDims)
3525e50dd04SRiver Riddle 
353a3dd4e77SAhmed S. Taei   TestVectorTransferCollapseInnerMostContiguousDims() = default;
354a3dd4e77SAhmed S. Taei   TestVectorTransferCollapseInnerMostContiguousDims(
355322c8914SMehdi Amini       const TestVectorTransferCollapseInnerMostContiguousDims &pass) = default;
356a3dd4e77SAhmed S. Taei 
357a3dd4e77SAhmed S. Taei   void getDependentDialects(DialectRegistry &registry) const override {
3584c48f016SMatthias Springer     registry.insert<memref::MemRefDialect, affine::AffineDialect>();
359a3dd4e77SAhmed S. Taei   }
360a3dd4e77SAhmed S. Taei 
361a3dd4e77SAhmed S. Taei   StringRef getArgument() const final {
362a3dd4e77SAhmed S. Taei     return "test-vector-transfer-collapse-inner-most-dims";
363a3dd4e77SAhmed S. Taei   }
364a3dd4e77SAhmed S. Taei 
365a3dd4e77SAhmed S. Taei   StringRef getDescription() const final {
366017c75bfSKai Sasaki     return "Test lowering patterns that reduces the rank of the vector "
367a3dd4e77SAhmed S. Taei            "transfer memory and vector operands.";
368a3dd4e77SAhmed S. Taei   }
369a3dd4e77SAhmed S. Taei 
37041574554SRiver Riddle   void runOnOperation() override {
371a3dd4e77SAhmed S. Taei     RewritePatternSet patterns(&getContext());
372a3dd4e77SAhmed S. Taei     populateVectorTransferCollapseInnerMostContiguousDimsPatterns(patterns);
373*09dfc571SJacques Pienaar     (void)applyPatternsGreedily(getOperation(), std::move(patterns));
374a3dd4e77SAhmed S. Taei   }
375a3dd4e77SAhmed S. Taei };
376a3dd4e77SAhmed S. Taei 
37742944da5SAndrzej Warzyński struct TestVectorSinkPatterns
37842944da5SAndrzej Warzyński     : public PassWrapper<TestVectorSinkPatterns, OperationPass<func::FuncOp>> {
37942944da5SAndrzej Warzyński   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorSinkPatterns)
3804d339ec9SAndrzej Warzynski 
38142944da5SAndrzej Warzyński   TestVectorSinkPatterns() = default;
38242944da5SAndrzej Warzyński   TestVectorSinkPatterns(const TestVectorSinkPatterns &pass) = default;
3834d339ec9SAndrzej Warzynski 
3844d339ec9SAndrzej Warzynski   void getDependentDialects(DialectRegistry &registry) const override {
3854d339ec9SAndrzej Warzynski     registry.insert<memref::MemRefDialect, affine::AffineDialect>();
3864d339ec9SAndrzej Warzynski   }
3874d339ec9SAndrzej Warzynski 
38842944da5SAndrzej Warzyński   StringRef getArgument() const final { return "test-vector-sink-patterns"; }
3894d339ec9SAndrzej Warzynski 
3904d339ec9SAndrzej Warzynski   StringRef getDescription() const final {
391017c75bfSKai Sasaki     return "Test lowering patterns that eliminate redundant broadcast "
39242944da5SAndrzej Warzyński            "and transpose operations.";
3934d339ec9SAndrzej Warzynski   }
3944d339ec9SAndrzej Warzynski 
3954d339ec9SAndrzej Warzynski   void runOnOperation() override {
3964d339ec9SAndrzej Warzynski     RewritePatternSet patterns(&getContext());
39742944da5SAndrzej Warzyński     populateSinkVectorOpsPatterns(patterns);
398*09dfc571SJacques Pienaar     (void)applyPatternsGreedily(getOperation(), std::move(patterns));
3994d339ec9SAndrzej Warzynski   }
4004d339ec9SAndrzej Warzynski };
4014d339ec9SAndrzej Warzynski 
4021d8cc45bSthomasraoux struct TestVectorReduceToContractPatternsPatterns
4031d8cc45bSthomasraoux     : public PassWrapper<TestVectorReduceToContractPatternsPatterns,
40458ceae95SRiver Riddle                          OperationPass<func::FuncOp>> {
4055e50dd04SRiver Riddle   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
4065e50dd04SRiver Riddle       TestVectorReduceToContractPatternsPatterns)
4075e50dd04SRiver Riddle 
4081d8cc45bSthomasraoux   StringRef getArgument() const final {
4091d8cc45bSthomasraoux     return "test-vector-reduction-to-contract-patterns";
4101d8cc45bSthomasraoux   }
4111d8cc45bSthomasraoux   StringRef getDescription() const final {
4121d8cc45bSthomasraoux     return "Test patterns to convert multireduce op to contract and combine "
4131d8cc45bSthomasraoux            "broadcast/transpose to contract";
4141d8cc45bSthomasraoux   }
41541574554SRiver Riddle   void runOnOperation() override {
4161d8cc45bSthomasraoux     RewritePatternSet patterns(&getContext());
417d054b80bSNicolas Vasilache     populateVectorReductionToContractPatterns(patterns);
418*09dfc571SJacques Pienaar     (void)applyPatternsGreedily(getOperation(), std::move(patterns));
4191d8cc45bSthomasraoux   }
4201d8cc45bSthomasraoux };
4211d8cc45bSthomasraoux 
422d33bad66SJakub Kuderski struct TestVectorChainedReductionFoldingPatterns
423d33bad66SJakub Kuderski     : public PassWrapper<TestVectorChainedReductionFoldingPatterns,
424d33bad66SJakub Kuderski                          OperationPass<func::FuncOp>> {
425d33bad66SJakub Kuderski   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
426d33bad66SJakub Kuderski       TestVectorChainedReductionFoldingPatterns)
427d33bad66SJakub Kuderski 
428d33bad66SJakub Kuderski   StringRef getArgument() const final {
429d33bad66SJakub Kuderski     return "test-vector-chained-reduction-folding-patterns";
430d33bad66SJakub Kuderski   }
431d33bad66SJakub Kuderski   StringRef getDescription() const final {
432d33bad66SJakub Kuderski     return "Test patterns to fold chained vector reductions";
433d33bad66SJakub Kuderski   }
434d33bad66SJakub Kuderski   void runOnOperation() override {
435d33bad66SJakub Kuderski     RewritePatternSet patterns(&getContext());
436d33bad66SJakub Kuderski     populateChainedVectorReductionFoldingPatterns(patterns);
437*09dfc571SJacques Pienaar     (void)applyPatternsGreedily(getOperation(), std::move(patterns));
438d33bad66SJakub Kuderski   }
439d33bad66SJakub Kuderski };
440d33bad66SJakub Kuderski 
44107677113SJakub Kuderski struct TestVectorBreakDownReductionPatterns
44207677113SJakub Kuderski     : public PassWrapper<TestVectorBreakDownReductionPatterns,
44307677113SJakub Kuderski                          OperationPass<func::FuncOp>> {
44407677113SJakub Kuderski   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
44507677113SJakub Kuderski       TestVectorBreakDownReductionPatterns)
44607677113SJakub Kuderski 
44707677113SJakub Kuderski   StringRef getArgument() const final {
44807677113SJakub Kuderski     return "test-vector-break-down-reduction-patterns";
44907677113SJakub Kuderski   }
45007677113SJakub Kuderski   StringRef getDescription() const final {
45107677113SJakub Kuderski     return "Test patterns to break down vector reductions into arith "
45207677113SJakub Kuderski            "reductions";
45307677113SJakub Kuderski   }
45407677113SJakub Kuderski   void runOnOperation() override {
45507677113SJakub Kuderski     RewritePatternSet patterns(&getContext());
45607677113SJakub Kuderski     populateBreakDownVectorReductionPatterns(patterns,
45707677113SJakub Kuderski                                              /*maxNumElementsToExtract=*/2);
458*09dfc571SJacques Pienaar     (void)applyPatternsGreedily(getOperation(), std::move(patterns));
45907677113SJakub Kuderski   }
46007677113SJakub Kuderski };
46107677113SJakub Kuderski 
462aba437ceSBenoit Jacob struct TestFlattenVectorTransferPatterns
46341574554SRiver Riddle     : public PassWrapper<TestFlattenVectorTransferPatterns,
46458ceae95SRiver Riddle                          OperationPass<func::FuncOp>> {
4655e50dd04SRiver Riddle   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
4665e50dd04SRiver Riddle       TestFlattenVectorTransferPatterns)
4675e50dd04SRiver Riddle 
46871441ed1SDiego Caballero   TestFlattenVectorTransferPatterns() = default;
46971441ed1SDiego Caballero   TestFlattenVectorTransferPatterns(
47071441ed1SDiego Caballero       const TestFlattenVectorTransferPatterns &pass)
47171441ed1SDiego Caballero       : PassWrapper(pass) {}
47271441ed1SDiego Caballero 
473aba437ceSBenoit Jacob   StringRef getArgument() const final {
474aba437ceSBenoit Jacob     return "test-vector-transfer-flatten-patterns";
475aba437ceSBenoit Jacob   }
47671441ed1SDiego Caballero 
477aba437ceSBenoit Jacob   StringRef getDescription() const final {
478aba437ceSBenoit Jacob     return "Test patterns to rewrite contiguous row-major N-dimensional "
479aba437ceSBenoit Jacob            "vector.transfer_{read,write} ops into 1D transfers";
480aba437ceSBenoit Jacob   }
48171441ed1SDiego Caballero 
482aba437ceSBenoit Jacob   void getDependentDialects(DialectRegistry &registry) const override {
483aba437ceSBenoit Jacob     registry.insert<memref::MemRefDialect>();
4842eb9e33cSAndrzej Warzyński     registry.insert<affine::AffineDialect>();
485c02d07fdSAndrzej Warzyński     registry.insert<vector::VectorDialect>();
486aba437ceSBenoit Jacob   }
48771441ed1SDiego Caballero 
48871441ed1SDiego Caballero   Option<unsigned> targetVectorBitwidth{
48971441ed1SDiego Caballero       *this, "target-vector-bitwidth",
49071441ed1SDiego Caballero       llvm::cl::desc(
491d3aa92edSAndrzej Warzyński           "Minimum vector bitwidth to enable the flattening transformation. "
492d3aa92edSAndrzej Warzyński           "For scalable vectors this is the base size, i.e. the size "
493d3aa92edSAndrzej Warzyński           "corresponding to vscale=1."),
49471441ed1SDiego Caballero       llvm::cl::init(std::numeric_limits<unsigned>::max())};
49571441ed1SDiego Caballero 
49641574554SRiver Riddle   void runOnOperation() override {
497aba437ceSBenoit Jacob     RewritePatternSet patterns(&getContext());
49871441ed1SDiego Caballero     populateFlattenVectorTransferPatterns(patterns, targetVectorBitwidth);
499*09dfc571SJacques Pienaar     (void)applyPatternsGreedily(getOperation(), std::move(patterns));
500aba437ceSBenoit Jacob   }
501aba437ceSBenoit Jacob };
502aba437ceSBenoit Jacob 
50380e0bf1aSharsh struct TestVectorScanLowering
50458ceae95SRiver Riddle     : public PassWrapper<TestVectorScanLowering, OperationPass<func::FuncOp>> {
5055e50dd04SRiver Riddle   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorScanLowering)
5065e50dd04SRiver Riddle 
50780e0bf1aSharsh   StringRef getArgument() const final { return "test-vector-scan-lowering"; }
50880e0bf1aSharsh   StringRef getDescription() const final {
50980e0bf1aSharsh     return "Test lowering patterns that lower the scan op in the vector "
51080e0bf1aSharsh            "dialect";
51180e0bf1aSharsh   }
51280e0bf1aSharsh   void runOnOperation() override {
51380e0bf1aSharsh     RewritePatternSet patterns(&getContext());
51480e0bf1aSharsh     populateVectorScanLoweringPatterns(patterns);
515*09dfc571SJacques Pienaar     (void)applyPatternsGreedily(getOperation(), std::move(patterns));
51680e0bf1aSharsh   }
51780e0bf1aSharsh };
51880e0bf1aSharsh 
519d02f10d9SThomas Raoux /// Allocate shared memory for a single warp to test lowering of
520d02f10d9SThomas Raoux /// WarpExecuteOnLane0Op.
521d02f10d9SThomas Raoux static Value allocateGlobalSharedMemory(Location loc, OpBuilder &builder,
522ecaf2c33SPetr Kurapov                                         gpu::WarpExecuteOnLane0Op warpOp,
523d02f10d9SThomas Raoux                                         Type type) {
524d02f10d9SThomas Raoux   static constexpr int64_t kSharedMemorySpace = 3;
525d02f10d9SThomas Raoux   // Compute type of shared memory buffer.
526d02f10d9SThomas Raoux   MemRefType memrefType;
5275550c821STres Popp   if (auto vectorType = dyn_cast<VectorType>(type)) {
528d02f10d9SThomas Raoux     memrefType =
529d02f10d9SThomas Raoux         MemRefType::get(vectorType.getShape(), vectorType.getElementType(), {},
530d02f10d9SThomas Raoux                         kSharedMemorySpace);
531d02f10d9SThomas Raoux   } else {
532d02f10d9SThomas Raoux     memrefType = MemRefType::get({1}, type, {}, kSharedMemorySpace);
533d02f10d9SThomas Raoux   }
534d02f10d9SThomas Raoux 
535d02f10d9SThomas Raoux   // Get symbol table holding all shared memory globals.
536d02f10d9SThomas Raoux   ModuleOp moduleOp = warpOp->getParentOfType<ModuleOp>();
537d02f10d9SThomas Raoux   SymbolTable symbolTable(moduleOp);
538d02f10d9SThomas Raoux 
539d02f10d9SThomas Raoux   // Create a pretty name.
540d02f10d9SThomas Raoux   SmallString<64> buf;
541d02f10d9SThomas Raoux   llvm::raw_svector_ostream os(buf);
542d02f10d9SThomas Raoux   interleave(memrefType.getShape(), os, "x");
543d02f10d9SThomas Raoux   os << "x" << memrefType.getElementType();
544d02f10d9SThomas Raoux   std::string symbolName = (Twine("__shared_") + os.str()).str();
545d02f10d9SThomas Raoux 
546d02f10d9SThomas Raoux   auto ip = builder.saveInsertionPoint();
547d02f10d9SThomas Raoux   builder.setInsertionPoint(moduleOp);
548d02f10d9SThomas Raoux   auto global = builder.create<memref::GlobalOp>(
549d02f10d9SThomas Raoux       loc,
550d02f10d9SThomas Raoux       /*sym_name=*/symbolName,
551d02f10d9SThomas Raoux       /*sym_visibility=*/builder.getStringAttr("private"),
552d02f10d9SThomas Raoux       /*type=*/memrefType,
553d02f10d9SThomas Raoux       /*initial_value=*/Attribute(),
554d02f10d9SThomas Raoux       /*constant=*/false,
555d02f10d9SThomas Raoux       /*alignment=*/IntegerAttr());
556d02f10d9SThomas Raoux   symbolTable.insert(global);
557d02f10d9SThomas Raoux   // The symbol table inserts at the end of the module, but globals are a bit
558d02f10d9SThomas Raoux   // nicer if they are at the beginning.
559d02f10d9SThomas Raoux   global->moveBefore(&moduleOp.front());
560d02f10d9SThomas Raoux 
561d02f10d9SThomas Raoux   builder.restoreInsertionPoint(ip);
562d02f10d9SThomas Raoux   return builder.create<memref::GetGlobalOp>(loc, memrefType, symbolName);
563d02f10d9SThomas Raoux }
564d02f10d9SThomas Raoux 
5656834803cSThomas Raoux static Value warpReduction(Location loc, OpBuilder &builder, Value input,
5666834803cSThomas Raoux                            CombiningKind kind, uint32_t size) {
567d2061530Sstanley-nod   // First reduce on a single thread to get per lane reduction value.
568d2061530Sstanley-nod   Value laneVal = builder.create<vector::ReductionOp>(loc, kind, input);
5696834803cSThomas Raoux   // Parallel reduction using butterfly shuffles.
5706834803cSThomas Raoux   for (uint64_t i = 1; i < size; i <<= 1) {
5716834803cSThomas Raoux     Value shuffled = builder
5726834803cSThomas Raoux                          .create<gpu::ShuffleOp>(loc, laneVal, i,
5736834803cSThomas Raoux                                                  /*width=*/size,
5746834803cSThomas Raoux                                                  /*mode=*/gpu::ShuffleMode::XOR)
575986b5c56SRiver Riddle                          .getShuffleResult();
5766834803cSThomas Raoux     laneVal = makeArithReduction(builder, loc, kind, laneVal, shuffled);
5776834803cSThomas Raoux   }
5786834803cSThomas Raoux   return laneVal;
5796834803cSThomas Raoux }
5806834803cSThomas Raoux 
581d02f10d9SThomas Raoux struct TestVectorDistribution
582d02f10d9SThomas Raoux     : public PassWrapper<TestVectorDistribution, OperationPass<func::FuncOp>> {
583d02f10d9SThomas Raoux   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorDistribution)
584d02f10d9SThomas Raoux 
585d02f10d9SThomas Raoux   void getDependentDialects(DialectRegistry &registry) const override {
586ecaf2c33SPetr Kurapov     registry
587ecaf2c33SPetr Kurapov         .insert<vector::VectorDialect, scf::SCFDialect, memref::MemRefDialect,
588ecaf2c33SPetr Kurapov                 gpu::GPUDialect, affine::AffineDialect>();
589d02f10d9SThomas Raoux   }
590d02f10d9SThomas Raoux 
591d02f10d9SThomas Raoux   StringRef getArgument() const final { return "test-vector-warp-distribute"; }
592d02f10d9SThomas Raoux   StringRef getDescription() const final {
593d02f10d9SThomas Raoux     return "Test vector warp distribute transformation and lowering patterns";
594d02f10d9SThomas Raoux   }
595d02f10d9SThomas Raoux   TestVectorDistribution() = default;
596d02f10d9SThomas Raoux   TestVectorDistribution(const TestVectorDistribution &pass)
597d02f10d9SThomas Raoux       : PassWrapper(pass) {}
598d02f10d9SThomas Raoux 
599d02f10d9SThomas Raoux   Option<bool> warpOpToSCF{
600d02f10d9SThomas Raoux       *this, "rewrite-warp-ops-to-scf-if",
601d02f10d9SThomas Raoux       llvm::cl::desc("Lower vector.warp_execute_on_lane0 to scf.if op"),
602d02f10d9SThomas Raoux       llvm::cl::init(false)};
603d02f10d9SThomas Raoux 
604ed0288f7SThomas Raoux   Option<bool> distributeTransferWriteOps{
605ed0288f7SThomas Raoux       *this, "distribute-transfer-write",
606ed0288f7SThomas Raoux       llvm::cl::desc("Test distribution of transfer write"),
607ed0288f7SThomas Raoux       llvm::cl::init(false)};
608ed0288f7SThomas Raoux 
60980636227SJakub Kuderski   Option<unsigned> maxTransferWriteElements{
61080636227SJakub Kuderski       *this, "max-transfer-write-elements",
61180636227SJakub Kuderski       llvm::cl::desc("Maximum number of transfer write elements to distribute"),
61280636227SJakub Kuderski       llvm::cl::init(1)};
61380636227SJakub Kuderski 
614ed0288f7SThomas Raoux   Option<bool> hoistUniform{*this, "hoist-uniform",
615ed0288f7SThomas Raoux                             llvm::cl::desc("Test hoist uniform"),
616ed0288f7SThomas Raoux                             llvm::cl::init(false)};
617ed0288f7SThomas Raoux 
61876cf33daSThomas Raoux   Option<bool> propagateDistribution{
61976cf33daSThomas Raoux       *this, "propagate-distribution",
620017c75bfSKai Sasaki       llvm::cl::desc("Test distribution propagation"), llvm::cl::init(false)};
62176cf33daSThomas Raoux 
622d02f10d9SThomas Raoux   void runOnOperation() override {
623d02f10d9SThomas Raoux     RewritePatternSet patterns(&getContext());
624ed0288f7SThomas Raoux 
625ed0288f7SThomas Raoux     getOperation().walk([&](Operation *op) {
626ecaf2c33SPetr Kurapov       if (auto warpOp = dyn_cast<gpu::WarpExecuteOnLane0Op>(op)) {
627ed0288f7SThomas Raoux         if (hoistUniform) {
628ed0288f7SThomas Raoux           moveScalarUniformCode(warpOp);
629ed0288f7SThomas Raoux         }
630ed0288f7SThomas Raoux         WalkResult::interrupt();
631ed0288f7SThomas Raoux       }
632ed0288f7SThomas Raoux     });
633ed0288f7SThomas Raoux     MLIRContext *ctx = &getContext();
63491f62f0eSThomas Raoux     auto distributionFn = [](Value val) {
635c2b95292SQuinn Dawkins       // Create an identity dim map of the same rank as the vector.
6365550c821STres Popp       VectorType vecType = dyn_cast<VectorType>(val.getType());
63791f62f0eSThomas Raoux       int64_t vecRank = vecType ? vecType.getRank() : 0;
63891f62f0eSThomas Raoux       OpBuilder builder(val.getContext());
63991f62f0eSThomas Raoux       if (vecRank == 0)
64091f62f0eSThomas Raoux         return AffineMap::get(val.getContext());
641c2b95292SQuinn Dawkins       return AffineMap::getMultiDimIdentityMap(vecRank, val.getContext());
642ed0288f7SThomas Raoux     };
6439d51b4e4SMatthias Springer     auto shuffleFn = [](Location loc, OpBuilder &builder, Value val,
6449d51b4e4SMatthias Springer                         Value srcIdx, int64_t warpSz) {
6459d51b4e4SMatthias Springer       assert((val.getType().isF32() || val.getType().isInteger(32)) &&
6469d51b4e4SMatthias Springer              "unsupported shuffle type");
6479d51b4e4SMatthias Springer       Type i32Type = builder.getIntegerType(32);
6489d51b4e4SMatthias Springer       Value srcIdxI32 =
6499d51b4e4SMatthias Springer           builder.create<arith::IndexCastOp>(loc, i32Type, srcIdx);
6509d51b4e4SMatthias Springer       Value warpSzI32 = builder.create<arith::ConstantOp>(
6519d51b4e4SMatthias Springer           loc, builder.getIntegerAttr(i32Type, warpSz));
6529d51b4e4SMatthias Springer       Value result = builder
6539d51b4e4SMatthias Springer                          .create<gpu::ShuffleOp>(loc, val, srcIdxI32, warpSzI32,
6549d51b4e4SMatthias Springer                                                  gpu::ShuffleMode::IDX)
6559d51b4e4SMatthias Springer                          .getResult(0);
6569d51b4e4SMatthias Springer       return result;
6579d51b4e4SMatthias Springer     };
658df49a97aSQuinn Dawkins     if (distributeTransferWriteOps && propagateDistribution) {
659df49a97aSQuinn Dawkins       RewritePatternSet patterns(ctx);
660df49a97aSQuinn Dawkins       vector::populatePropagateWarpVectorDistributionPatterns(
661df49a97aSQuinn Dawkins           patterns, distributionFn, shuffleFn, /*benefit=*/1,
662df49a97aSQuinn Dawkins           /*readBenefit=*/0);
663df49a97aSQuinn Dawkins       vector::populateDistributeReduction(patterns, warpReduction, 1);
664df49a97aSQuinn Dawkins       populateDistributeTransferWriteOpPatterns(patterns, distributionFn, 2);
665*09dfc571SJacques Pienaar       (void)applyPatternsGreedily(getOperation(), std::move(patterns));
666df49a97aSQuinn Dawkins     } else if (distributeTransferWriteOps) {
667ed0288f7SThomas Raoux       RewritePatternSet patterns(ctx);
66880636227SJakub Kuderski       populateDistributeTransferWriteOpPatterns(patterns, distributionFn,
66980636227SJakub Kuderski                                                 maxTransferWriteElements);
670*09dfc571SJacques Pienaar       (void)applyPatternsGreedily(getOperation(), std::move(patterns));
671df49a97aSQuinn Dawkins     } else if (propagateDistribution) {
67276cf33daSThomas Raoux       RewritePatternSet patterns(ctx);
6739d51b4e4SMatthias Springer       vector::populatePropagateWarpVectorDistributionPatterns(
6749d51b4e4SMatthias Springer           patterns, distributionFn, shuffleFn);
6756834803cSThomas Raoux       vector::populateDistributeReduction(patterns, warpReduction);
676*09dfc571SJacques Pienaar       (void)applyPatternsGreedily(getOperation(), std::move(patterns));
67776cf33daSThomas Raoux     }
678d02f10d9SThomas Raoux     WarpExecuteOnLane0LoweringOptions options;
679d02f10d9SThomas Raoux     options.warpAllocationFn = allocateGlobalSharedMemory;
680d02f10d9SThomas Raoux     options.warpSyncronizationFn = [](Location loc, OpBuilder &builder,
681ecaf2c33SPetr Kurapov                                       gpu::WarpExecuteOnLane0Op warpOp) {
682d02f10d9SThomas Raoux       builder.create<gpu::BarrierOp>(loc);
683d02f10d9SThomas Raoux     };
684d02f10d9SThomas Raoux     // Test on one pattern in isolation.
685d02f10d9SThomas Raoux     if (warpOpToSCF) {
686d02f10d9SThomas Raoux       populateWarpExecuteOnLane0OpToScfForPattern(patterns, options);
687*09dfc571SJacques Pienaar       (void)applyPatternsGreedily(getOperation(), std::move(patterns));
688d02f10d9SThomas Raoux       return;
689d02f10d9SThomas Raoux     }
690d02f10d9SThomas Raoux   }
691d02f10d9SThomas Raoux };
692d02f10d9SThomas Raoux 
69339c80656SLei Zhang struct TestVectorExtractStridedSliceLowering
69439c80656SLei Zhang     : public PassWrapper<TestVectorExtractStridedSliceLowering,
69539c80656SLei Zhang                          OperationPass<func::FuncOp>> {
69639c80656SLei Zhang   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
69739c80656SLei Zhang       TestVectorExtractStridedSliceLowering)
69839c80656SLei Zhang 
69939c80656SLei Zhang   StringRef getArgument() const final {
70039c80656SLei Zhang     return "test-vector-extract-strided-slice-lowering";
70139c80656SLei Zhang   }
70239c80656SLei Zhang   StringRef getDescription() const final {
70339c80656SLei Zhang     return "Test lowering patterns that converts vector.extract_strided_slice "
70439c80656SLei Zhang            "into a chain of vector.extract and vector.insert ops";
70539c80656SLei Zhang   }
70639c80656SLei Zhang   void runOnOperation() override {
70739c80656SLei Zhang     RewritePatternSet patterns(&getContext());
70839c80656SLei Zhang     populateVectorExtractStridedSliceToExtractInsertChainPatterns(patterns);
709*09dfc571SJacques Pienaar     (void)applyPatternsGreedily(getOperation(), std::move(patterns));
71039c80656SLei Zhang   }
71139c80656SLei Zhang };
71239c80656SLei Zhang 
713650f04feSQuinn Dawkins struct TestVectorBreakDownBitCast
714650f04feSQuinn Dawkins     : public PassWrapper<TestVectorBreakDownBitCast,
715650f04feSQuinn Dawkins                          OperationPass<func::FuncOp>> {
716650f04feSQuinn Dawkins   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorBreakDownBitCast)
717650f04feSQuinn Dawkins 
718650f04feSQuinn Dawkins   StringRef getArgument() const final {
719650f04feSQuinn Dawkins     return "test-vector-break-down-bitcast";
720650f04feSQuinn Dawkins   }
721650f04feSQuinn Dawkins   StringRef getDescription() const final {
722650f04feSQuinn Dawkins     return "Test pattern that breaks down vector.bitcast ops ";
723650f04feSQuinn Dawkins   }
724650f04feSQuinn Dawkins   void runOnOperation() override {
725650f04feSQuinn Dawkins     RewritePatternSet patterns(&getContext());
726650f04feSQuinn Dawkins     populateBreakDownVectorBitCastOpPatterns(patterns, [](BitCastOp op) {
727650f04feSQuinn Dawkins       return op.getSourceVectorType().getShape().back() > 4;
728650f04feSQuinn Dawkins     });
729*09dfc571SJacques Pienaar     (void)applyPatternsGreedily(getOperation(), std::move(patterns));
730650f04feSQuinn Dawkins   }
731650f04feSQuinn Dawkins };
732650f04feSQuinn Dawkins 
733de13eedaSNicolas Vasilache struct TestCreateVectorBroadcast
734de13eedaSNicolas Vasilache     : public PassWrapper<TestCreateVectorBroadcast,
735de13eedaSNicolas Vasilache                          OperationPass<func::FuncOp>> {
736de13eedaSNicolas Vasilache   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestCreateVectorBroadcast)
737de13eedaSNicolas Vasilache 
738de13eedaSNicolas Vasilache   StringRef getArgument() const final { return "test-create-vector-broadcast"; }
739de13eedaSNicolas Vasilache   StringRef getDescription() const final {
740de13eedaSNicolas Vasilache     return "Test optimization transformations for transfer ops";
741de13eedaSNicolas Vasilache   }
742de13eedaSNicolas Vasilache   void getDependentDialects(DialectRegistry &registry) const override {
743de13eedaSNicolas Vasilache     registry.insert<vector::VectorDialect>();
744de13eedaSNicolas Vasilache   }
745de13eedaSNicolas Vasilache 
746de13eedaSNicolas Vasilache   void runOnOperation() override {
747de13eedaSNicolas Vasilache     getOperation()->walk([](Operation *op) {
748de13eedaSNicolas Vasilache       if (op->getName().getStringRef() != "test_create_broadcast")
749de13eedaSNicolas Vasilache         return;
750de13eedaSNicolas Vasilache       auto targetShape =
7515550c821STres Popp           cast<VectorType>(op->getResult(0).getType()).getShape();
752de13eedaSNicolas Vasilache       auto arrayAttr =
753830b9b07SMehdi Amini           cast<DenseI64ArrayAttr>(op->getDiscardableAttr("broadcast_dims"))
754830b9b07SMehdi Amini               .asArrayRef();
755de13eedaSNicolas Vasilache       llvm::SetVector<int64_t> broadcastedDims;
756de13eedaSNicolas Vasilache       broadcastedDims.insert(arrayAttr.begin(), arrayAttr.end());
757de13eedaSNicolas Vasilache       OpBuilder b(op);
758de13eedaSNicolas Vasilache       Value bcast = vector::BroadcastOp::createOrFoldBroadcastOp(
759de13eedaSNicolas Vasilache           b, op->getOperand(0), targetShape, broadcastedDims);
760de13eedaSNicolas Vasilache       op->getResult(0).replaceAllUsesWith(bcast);
761de13eedaSNicolas Vasilache       op->erase();
762de13eedaSNicolas Vasilache     });
763de13eedaSNicolas Vasilache   }
764de13eedaSNicolas Vasilache };
765de13eedaSNicolas Vasilache 
766f80a976aSJakub Kuderski struct TestVectorGatherLowering
767f80a976aSJakub Kuderski     : public PassWrapper<TestVectorGatherLowering,
768f80a976aSJakub Kuderski                          OperationPass<func::FuncOp>> {
769f80a976aSJakub Kuderski   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorGatherLowering)
770f80a976aSJakub Kuderski 
771f80a976aSJakub Kuderski   StringRef getArgument() const final { return "test-vector-gather-lowering"; }
772f80a976aSJakub Kuderski   StringRef getDescription() const final {
773f80a976aSJakub Kuderski     return "Test patterns that lower the gather op in the vector conditional "
774f80a976aSJakub Kuderski            "loads";
775f80a976aSJakub Kuderski   }
776f80a976aSJakub Kuderski   void getDependentDialects(DialectRegistry &registry) const override {
777f80a976aSJakub Kuderski     registry.insert<arith::ArithDialect, func::FuncDialect,
778f80a976aSJakub Kuderski                     memref::MemRefDialect, scf::SCFDialect,
779f80a976aSJakub Kuderski                     tensor::TensorDialect, vector::VectorDialect>();
780f80a976aSJakub Kuderski   }
781f80a976aSJakub Kuderski 
782f80a976aSJakub Kuderski   void runOnOperation() override {
783f80a976aSJakub Kuderski     RewritePatternSet patterns(&getContext());
784f80a976aSJakub Kuderski     populateVectorGatherLoweringPatterns(patterns);
785*09dfc571SJacques Pienaar     (void)applyPatternsGreedily(getOperation(), std::move(patterns));
786f80a976aSJakub Kuderski   }
787f80a976aSJakub Kuderski };
788f80a976aSJakub Kuderski 
7899a795f0cSManish Gupta struct TestFoldArithExtensionIntoVectorContractPatterns
7909a795f0cSManish Gupta     : public PassWrapper<TestFoldArithExtensionIntoVectorContractPatterns,
7919a795f0cSManish Gupta                          OperationPass<func::FuncOp>> {
7929a795f0cSManish Gupta   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
7939a795f0cSManish Gupta       TestFoldArithExtensionIntoVectorContractPatterns)
7949a795f0cSManish Gupta 
7959a795f0cSManish Gupta   StringRef getArgument() const final {
7969a795f0cSManish Gupta     return "test-fold-arith-extf-into-vector-contract-patterns";
7979a795f0cSManish Gupta   }
7989a795f0cSManish Gupta   StringRef getDescription() const final {
7999a795f0cSManish Gupta     return "Test patterns that fold arithmetic extension ops into vector "
8009a795f0cSManish Gupta            "contract ops";
8019a795f0cSManish Gupta   }
8029a795f0cSManish Gupta 
8039a795f0cSManish Gupta   void getDependentDialects(DialectRegistry &registry) const override {
8049a795f0cSManish Gupta     registry.insert<arith::ArithDialect, func::FuncDialect, nvgpu::NVGPUDialect,
8059a795f0cSManish Gupta                     memref::MemRefDialect, scf::SCFDialect,
8069a795f0cSManish Gupta                     tensor::TensorDialect, vector::VectorDialect>();
8079a795f0cSManish Gupta   }
8089a795f0cSManish Gupta 
8099a795f0cSManish Gupta   void runOnOperation() override {
8109a795f0cSManish Gupta     RewritePatternSet patterns(&getContext());
8119a795f0cSManish Gupta     populateFoldArithExtensionPatterns(patterns);
812*09dfc571SJacques Pienaar     (void)applyPatternsGreedily(getOperation(), std::move(patterns));
8139a795f0cSManish Gupta   }
8149a795f0cSManish Gupta };
815f643eec8SHsiangkai Wang 
816f643eec8SHsiangkai Wang struct TestVectorEmulateMaskedLoadStore final
817f643eec8SHsiangkai Wang     : public PassWrapper<TestVectorEmulateMaskedLoadStore,
818f643eec8SHsiangkai Wang                          OperationPass<func::FuncOp>> {
819f643eec8SHsiangkai Wang   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorEmulateMaskedLoadStore)
820f643eec8SHsiangkai Wang 
821f643eec8SHsiangkai Wang   StringRef getArgument() const override {
822f643eec8SHsiangkai Wang     return "test-vector-emulate-masked-load-store";
823f643eec8SHsiangkai Wang   }
824f643eec8SHsiangkai Wang   StringRef getDescription() const override {
825f643eec8SHsiangkai Wang     return "Test patterns that emulate the maskedload/maskedstore op by "
826f643eec8SHsiangkai Wang            " memref.load/store and scf.if";
827f643eec8SHsiangkai Wang   }
828f643eec8SHsiangkai Wang   void getDependentDialects(DialectRegistry &registry) const override {
829f643eec8SHsiangkai Wang     registry
830f643eec8SHsiangkai Wang         .insert<arith::ArithDialect, func::FuncDialect, memref::MemRefDialect,
831f643eec8SHsiangkai Wang                 scf::SCFDialect, vector::VectorDialect>();
832f643eec8SHsiangkai Wang   }
833f643eec8SHsiangkai Wang 
834f643eec8SHsiangkai Wang   void runOnOperation() override {
835f643eec8SHsiangkai Wang     RewritePatternSet patterns(&getContext());
836f643eec8SHsiangkai Wang     populateVectorMaskedLoadStoreEmulationPatterns(patterns);
837*09dfc571SJacques Pienaar     (void)applyPatternsGreedily(getOperation(), std::move(patterns));
838f643eec8SHsiangkai Wang   }
839f643eec8SHsiangkai Wang };
84035ef3994SIvan Butygin 
84135ef3994SIvan Butygin struct TestVectorLinearize final
84235ef3994SIvan Butygin     : public PassWrapper<TestVectorLinearize, OperationPass<>> {
84335ef3994SIvan Butygin   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorLinearize)
84435ef3994SIvan Butygin 
8456f5c4f2eSBalaji V. Iyer   TestVectorLinearize() = default;
8466f5c4f2eSBalaji V. Iyer   TestVectorLinearize(const TestVectorLinearize &pass) : PassWrapper(pass) {}
8476f5c4f2eSBalaji V. Iyer 
84835ef3994SIvan Butygin   StringRef getArgument() const override { return "test-vector-linearize"; }
84935ef3994SIvan Butygin   StringRef getDescription() const override {
85035ef3994SIvan Butygin     return "Linearizes ND vectors for N >= 2 into 1D vectors";
85135ef3994SIvan Butygin   }
85235ef3994SIvan Butygin   void getDependentDialects(DialectRegistry &registry) const override {
85335ef3994SIvan Butygin     registry.insert<vector::VectorDialect>();
85435ef3994SIvan Butygin   }
85535ef3994SIvan Butygin 
8566f5c4f2eSBalaji V. Iyer   Option<unsigned> targetVectorBitwidth{
8576f5c4f2eSBalaji V. Iyer       *this, "target-vector-bitwidth",
8586f5c4f2eSBalaji V. Iyer       llvm::cl::desc(
8596f5c4f2eSBalaji V. Iyer           "Minimum vector bitwidth to enable the flattening transformation"),
8606f5c4f2eSBalaji V. Iyer       llvm::cl::init(std::numeric_limits<unsigned>::max())};
86135ef3994SIvan Butygin   void runOnOperation() override {
86235ef3994SIvan Butygin     auto *context = &getContext();
86335ef3994SIvan Butygin 
86435ef3994SIvan Butygin     TypeConverter typeConverter;
86535ef3994SIvan Butygin     RewritePatternSet patterns(context);
86635ef3994SIvan Butygin     ConversionTarget target(*context);
86735ef3994SIvan Butygin 
8686f5c4f2eSBalaji V. Iyer     vector::populateVectorLinearizeTypeConversionsAndLegality(
8696f5c4f2eSBalaji V. Iyer         typeConverter, patterns, target, targetVectorBitwidth);
870c577f91dSCharitha Saumya     vector::populateVectorLinearizeShuffleLikeOpsPatterns(
871c577f91dSCharitha Saumya         typeConverter, patterns, target, targetVectorBitwidth);
87235ef3994SIvan Butygin     if (failed(applyPartialConversion(getOperation(), target,
87335ef3994SIvan Butygin                                       std::move(patterns))))
87435ef3994SIvan Butygin       return signalPassFailure();
87535ef3994SIvan Butygin   }
87635ef3994SIvan Butygin };
8779b06e25eSBenjamin Maxwell 
8789b06e25eSBenjamin Maxwell struct TestEliminateVectorMasks
8799b06e25eSBenjamin Maxwell     : public PassWrapper<TestEliminateVectorMasks,
8809b06e25eSBenjamin Maxwell                          OperationPass<func::FuncOp>> {
8819b06e25eSBenjamin Maxwell   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestEliminateVectorMasks)
8829b06e25eSBenjamin Maxwell 
8839b06e25eSBenjamin Maxwell   TestEliminateVectorMasks() = default;
8849b06e25eSBenjamin Maxwell   TestEliminateVectorMasks(const TestEliminateVectorMasks &pass)
8859b06e25eSBenjamin Maxwell       : PassWrapper(pass) {}
8869b06e25eSBenjamin Maxwell 
8879b06e25eSBenjamin Maxwell   Option<unsigned> vscaleMin{
8889b06e25eSBenjamin Maxwell       *this, "vscale-min", llvm::cl::desc("Minimum possible value of vscale."),
8899b06e25eSBenjamin Maxwell       llvm::cl::init(1)};
8909b06e25eSBenjamin Maxwell   Option<unsigned> vscaleMax{
8919b06e25eSBenjamin Maxwell       *this, "vscale-max", llvm::cl::desc("Maximum possible value of vscale."),
8929b06e25eSBenjamin Maxwell       llvm::cl::init(16)};
8939b06e25eSBenjamin Maxwell 
8949b06e25eSBenjamin Maxwell   StringRef getArgument() const final { return "test-eliminate-vector-masks"; }
8959b06e25eSBenjamin Maxwell   StringRef getDescription() const final {
8969b06e25eSBenjamin Maxwell     return "Test eliminating vector masks";
8979b06e25eSBenjamin Maxwell   }
8989b06e25eSBenjamin Maxwell   void runOnOperation() override {
8999b06e25eSBenjamin Maxwell     IRRewriter rewriter(&getContext());
9009b06e25eSBenjamin Maxwell     eliminateVectorMasks(rewriter, getOperation(),
9019b06e25eSBenjamin Maxwell                          VscaleRange{vscaleMin, vscaleMax});
9029b06e25eSBenjamin Maxwell   }
9039b06e25eSBenjamin Maxwell };
904be0a7e9fSMehdi Amini } // namespace
9053fef2d26SRiver Riddle 
9063fef2d26SRiver Riddle namespace mlir {
9073fef2d26SRiver Riddle namespace test {
90834ff8573SNicolas Vasilache void registerTestVectorLowerings() {
90934ff8573SNicolas Vasilache   PassRegistration<TestVectorToVectorLowering>();
9103fef2d26SRiver Riddle 
911fb7ef637SJakub Kuderski   PassRegistration<TestVectorContractionPrepareForMMTLowering>();
912fb7ef637SJakub Kuderski 
913b5e22e6dSMehdi Amini   PassRegistration<TestVectorUnrollingPatterns>();
9143fef2d26SRiver Riddle 
915b5e22e6dSMehdi Amini   PassRegistration<TestVectorTransferUnrollingPatterns>();
9163fef2d26SRiver Riddle 
9172ec98ffbSMatthias Springer   PassRegistration<TestScalarVectorTransferLoweringPatterns>();
9182ec98ffbSMatthias Springer 
919b5e22e6dSMehdi Amini   PassRegistration<TestVectorTransferOpt>();
9203fef2d26SRiver Riddle 
921a3dd4e77SAhmed S. Taei   PassRegistration<TestVectorTransferCollapseInnerMostContiguousDims>();
9221d8cc45bSthomasraoux 
92342944da5SAndrzej Warzyński   PassRegistration<TestVectorSinkPatterns>();
9244d339ec9SAndrzej Warzynski 
9251d8cc45bSthomasraoux   PassRegistration<TestVectorReduceToContractPatternsPatterns>();
9260aea49a7SBenoit Jacob 
927d33bad66SJakub Kuderski   PassRegistration<TestVectorChainedReductionFoldingPatterns>();
928d33bad66SJakub Kuderski 
92907677113SJakub Kuderski   PassRegistration<TestVectorBreakDownReductionPatterns>();
93007677113SJakub Kuderski 
931aba437ceSBenoit Jacob   PassRegistration<TestFlattenVectorTransferPatterns>();
93280e0bf1aSharsh 
93380e0bf1aSharsh   PassRegistration<TestVectorScanLowering>();
934d02f10d9SThomas Raoux 
935d02f10d9SThomas Raoux   PassRegistration<TestVectorDistribution>();
93639c80656SLei Zhang 
93739c80656SLei Zhang   PassRegistration<TestVectorExtractStridedSliceLowering>();
938de13eedaSNicolas Vasilache 
939650f04feSQuinn Dawkins   PassRegistration<TestVectorBreakDownBitCast>();
940650f04feSQuinn Dawkins 
941de13eedaSNicolas Vasilache   PassRegistration<TestCreateVectorBroadcast>();
942f80a976aSJakub Kuderski 
943f80a976aSJakub Kuderski   PassRegistration<TestVectorGatherLowering>();
944e000b62aSLei Zhang 
9459a795f0cSManish Gupta   PassRegistration<TestFoldArithExtensionIntoVectorContractPatterns>();
946f643eec8SHsiangkai Wang 
947f643eec8SHsiangkai Wang   PassRegistration<TestVectorEmulateMaskedLoadStore>();
94835ef3994SIvan Butygin 
94935ef3994SIvan Butygin   PassRegistration<TestVectorLinearize>();
9509b06e25eSBenjamin Maxwell 
9519b06e25eSBenjamin Maxwell   PassRegistration<TestEliminateVectorMasks>();
9523fef2d26SRiver Riddle }
9533fef2d26SRiver Riddle } // namespace test
9543fef2d26SRiver Riddle } // namespace mlir
955