134ff8573SNicolas Vasilache //===- TestVectorTransforms.cpp - Test Vector transforms and lowerings ----===// 23fef2d26SRiver Riddle // 33fef2d26SRiver Riddle // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 43fef2d26SRiver Riddle // See https://llvm.org/LICENSE.txt for license information. 53fef2d26SRiver Riddle // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 63fef2d26SRiver Riddle // 73fef2d26SRiver Riddle //===----------------------------------------------------------------------===// 83fef2d26SRiver Riddle 9a1fe1f5fSKazu Hirata #include <optional> 1073bec2b2SNicolas Vasilache #include <type_traits> 113fef2d26SRiver Riddle 123fef2d26SRiver Riddle #include "mlir/Analysis/SliceAnalysis.h" 133fef2d26SRiver Riddle #include "mlir/Dialect/Affine/IR/AffineOps.h" 14fb7ef637SJakub Kuderski #include "mlir/Dialect/Arith/IR/Arith.h" 1523aa5a74SRiver Riddle #include "mlir/Dialect/Func/IR/FuncOps.h" 16d7ef488bSMogball #include "mlir/Dialect/GPU/IR/GPUDialect.h" 17b2729fdaSNicolas Vasilache #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 18b7f2c108Sgysit #include "mlir/Dialect/Linalg/IR/Linalg.h" 1934ff8573SNicolas Vasilache #include "mlir/Dialect/Linalg/Passes.h" 2034ff8573SNicolas Vasilache #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 213fef2d26SRiver Riddle #include "mlir/Dialect/MemRef/IR/MemRef.h" 229a795f0cSManish Gupta #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h" 238b68da2cSAlex Zinenko #include "mlir/Dialect/SCF/IR/SCF.h" 24f80a976aSJakub Kuderski #include "mlir/Dialect/Tensor/IR/Tensor.h" 259f122152SChristopher Bate #include "mlir/Dialect/Vector/IR/VectorOps.h" 262bc4c3e9SNicolas Vasilache #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" 27d02f10d9SThomas Raoux #include "mlir/Dialect/Vector/Transforms/VectorDistribution.h" 2839c80656SLei Zhang #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h" 2999ef9eebSMatthias Springer #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h" 303fef2d26SRiver Riddle #include "mlir/Pass/Pass.h" 3134ff8573SNicolas Vasilache #include "mlir/Pass/PassManager.h" 329f122152SChristopher Bate #include "mlir/Support/LLVM.h" 333fef2d26SRiver Riddle #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 343fef2d26SRiver Riddle 353fef2d26SRiver Riddle using namespace mlir; 3634ff8573SNicolas Vasilache using namespace mlir::linalg; 373fef2d26SRiver Riddle using namespace mlir::vector; 38d054b80bSNicolas Vasilache 393fef2d26SRiver Riddle namespace { 403fef2d26SRiver Riddle 4134ff8573SNicolas Vasilache struct TestVectorToVectorLowering 4258ceae95SRiver Riddle : public PassWrapper<TestVectorToVectorLowering, 4358ceae95SRiver Riddle OperationPass<func::FuncOp>> { 445e50dd04SRiver Riddle MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorToVectorLowering) 455e50dd04SRiver Riddle 4634ff8573SNicolas Vasilache TestVectorToVectorLowering() = default; 473bab9d4eSMehdi Amini TestVectorToVectorLowering(const TestVectorToVectorLowering &pass) 483bab9d4eSMehdi Amini : PassWrapper(pass) {} 49b5e22e6dSMehdi Amini StringRef getArgument() const final { 5034ff8573SNicolas Vasilache return "test-vector-to-vector-lowering"; 51b5e22e6dSMehdi Amini } 52b5e22e6dSMehdi Amini StringRef getDescription() const final { 5334ff8573SNicolas Vasilache return "Test lowering patterns between ops in the vector dialect"; 54b5e22e6dSMehdi Amini } 553fef2d26SRiver Riddle 563fef2d26SRiver Riddle void getDependentDialects(DialectRegistry ®istry) const override { 574c48f016SMatthias Springer registry.insert<affine::AffineDialect>(); 58576b184dSAndrzej Warzynski registry.insert<vector::VectorDialect>(); 593fef2d26SRiver Riddle } 603fef2d26SRiver Riddle 613fef2d26SRiver Riddle Option<bool> unroll{*this, "unroll", llvm::cl::desc("Include unrolling"), 623fef2d26SRiver Riddle llvm::cl::init(false)}; 633fef2d26SRiver Riddle 6441574554SRiver Riddle void runOnOperation() override { 653fef2d26SRiver Riddle auto *ctx = &getContext(); 663fef2d26SRiver Riddle RewritePatternSet patterns(ctx); 673fef2d26SRiver Riddle if (unroll) { 6829102538Sthomasraoux populateVectorUnrollPatterns( 6929102538Sthomasraoux patterns, 703fef2d26SRiver Riddle UnrollVectorOptions().setNativeShapeFn(getShape).setFilterConstraint( 713fef2d26SRiver Riddle filter)); 723fef2d26SRiver Riddle } 733fef2d26SRiver Riddle populateVectorToVectorCanonicalizationPatterns(patterns); 743fef2d26SRiver Riddle populateBubbleVectorBitCastOpPatterns(patterns); 753fef2d26SRiver Riddle populateCastAwayVectorLeadingOneDimPatterns(patterns); 76*09dfc571SJacques Pienaar (void)applyPatternsGreedily(getOperation(), std::move(patterns)); 773fef2d26SRiver Riddle } 783fef2d26SRiver Riddle 793fef2d26SRiver Riddle private: 803fef2d26SRiver Riddle // Return the target shape based on op type. 810a81ace0SKazu Hirata static std::optional<SmallVector<int64_t>> getShape(Operation *op) { 82dec8af70SRiver Riddle if (isa<arith::AddFOp, arith::SelectOp, arith::CmpFOp>(op)) 837a69a9d7SNicolas Vasilache return SmallVector<int64_t>(2, 2); 843fef2d26SRiver Riddle if (isa<vector::ContractionOp>(op)) 857a69a9d7SNicolas Vasilache return SmallVector<int64_t>(3, 2); 8629102538Sthomasraoux // For transfer ops, just propagate the shape coming from 8729102538Sthomasraoux // InsertStridedSlices/ExtractStridedSlices. 8829102538Sthomasraoux if (auto readOp = dyn_cast<vector::TransferReadOp>(op)) { 8929102538Sthomasraoux VectorType dstVec; 9029102538Sthomasraoux for (Operation *users : readOp->getUsers()) { 9129102538Sthomasraoux auto extract = dyn_cast<ExtractStridedSliceOp>(users); 9229102538Sthomasraoux if (!extract) 931a36588eSKazu Hirata return std::nullopt; 945550c821STres Popp auto vecType = cast<VectorType>(extract.getResult().getType()); 9529102538Sthomasraoux if (dstVec && dstVec != vecType) 961a36588eSKazu Hirata return std::nullopt; 9729102538Sthomasraoux dstVec = vecType; 9829102538Sthomasraoux } 995262865aSKazu Hirata return SmallVector<int64_t>(dstVec.getShape()); 10029102538Sthomasraoux } 10129102538Sthomasraoux if (auto writeOp = dyn_cast<vector::TransferWriteOp>(op)) { 1027c38fd60SJacques Pienaar auto insert = writeOp.getVector().getDefiningOp<InsertStridedSliceOp>(); 10329102538Sthomasraoux if (!insert) 1041a36588eSKazu Hirata return std::nullopt; 10529102538Sthomasraoux ArrayRef<int64_t> shape = insert.getSourceVectorType().getShape(); 1065262865aSKazu Hirata return SmallVector<int64_t>(shape); 10729102538Sthomasraoux } 1081a36588eSKazu Hirata return std::nullopt; 1093fef2d26SRiver Riddle } 1103fef2d26SRiver Riddle 1113fef2d26SRiver Riddle static LogicalResult filter(Operation *op) { 112dec8af70SRiver Riddle return success(isa<arith::AddFOp, arith::SelectOp, arith::CmpFOp, 113dec8af70SRiver Riddle ContractionOp, TransferReadOp, TransferWriteOp>(op)); 1143fef2d26SRiver Riddle } 1153fef2d26SRiver Riddle }; 1163fef2d26SRiver Riddle 117fb7ef637SJakub Kuderski struct TestVectorContractionPrepareForMMTLowering 118fb7ef637SJakub Kuderski : public PassWrapper<TestVectorContractionPrepareForMMTLowering, 119fb7ef637SJakub Kuderski OperationPass<func::FuncOp>> { 120fb7ef637SJakub Kuderski MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( 121fb7ef637SJakub Kuderski TestVectorContractionPrepareForMMTLowering) 122fb7ef637SJakub Kuderski 123fb7ef637SJakub Kuderski StringRef getArgument() const final { 124fb7ef637SJakub Kuderski return "test-vector-contraction-prepare-for-mmt-lowering"; 125fb7ef637SJakub Kuderski } 126fb7ef637SJakub Kuderski StringRef getDescription() const final { 127fb7ef637SJakub Kuderski return "Test vector.contraction matmul canonicalization for MMT lowering."; 128fb7ef637SJakub Kuderski } 129fb7ef637SJakub Kuderski TestVectorContractionPrepareForMMTLowering() = default; 130fb7ef637SJakub Kuderski 131fb7ef637SJakub Kuderski void getDependentDialects(DialectRegistry ®istry) const override { 1324c48f016SMatthias Springer registry.insert<affine::AffineDialect, arith::ArithDialect, 1334c48f016SMatthias Springer vector::VectorDialect>(); 134fb7ef637SJakub Kuderski } 135fb7ef637SJakub Kuderski 136fb7ef637SJakub Kuderski void runOnOperation() override { 137fb7ef637SJakub Kuderski MLIRContext *ctx = &getContext(); 138fb7ef637SJakub Kuderski RewritePatternSet patterns(ctx); 139fb7ef637SJakub Kuderski vector::populateVectorContractCanonicalizeMatmulToMMT(patterns); 140*09dfc571SJacques Pienaar (void)applyPatternsGreedily(getOperation(), std::move(patterns)); 141fb7ef637SJakub Kuderski } 142fb7ef637SJakub Kuderski }; 143fb7ef637SJakub Kuderski 1443fef2d26SRiver Riddle struct TestVectorUnrollingPatterns 14558ceae95SRiver Riddle : public PassWrapper<TestVectorUnrollingPatterns, 14658ceae95SRiver Riddle OperationPass<func::FuncOp>> { 1475e50dd04SRiver Riddle MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorUnrollingPatterns) 1485e50dd04SRiver Riddle 149b5e22e6dSMehdi Amini StringRef getArgument() const final { 150b5e22e6dSMehdi Amini return "test-vector-unrolling-patterns"; 151b5e22e6dSMehdi Amini } 152b5e22e6dSMehdi Amini StringRef getDescription() const final { 15334ff8573SNicolas Vasilache return "Test lowering patterns to unroll contract ops in the vector " 154b5e22e6dSMehdi Amini "dialect"; 155b5e22e6dSMehdi Amini } 1563fef2d26SRiver Riddle TestVectorUnrollingPatterns() = default; 1573bab9d4eSMehdi Amini TestVectorUnrollingPatterns(const TestVectorUnrollingPatterns &pass) 1583bab9d4eSMehdi Amini : PassWrapper(pass) {} 15941574554SRiver Riddle void runOnOperation() override { 1603fef2d26SRiver Riddle MLIRContext *ctx = &getContext(); 1613fef2d26SRiver Riddle RewritePatternSet patterns(ctx); 16229102538Sthomasraoux populateVectorUnrollPatterns( 16329102538Sthomasraoux patterns, UnrollVectorOptions() 1643fef2d26SRiver Riddle .setNativeShape(ArrayRef<int64_t>{2, 2}) 1653fef2d26SRiver Riddle .setFilterConstraint([](Operation *op) { 166f69175b1SThomas Raoux return success(isa<arith::AddFOp, vector::FMAOp, 167f69175b1SThomas Raoux vector::MultiDimReductionOp>(op)); 1683fef2d26SRiver Riddle })); 169de5022c7SMatthias Springer populateVectorUnrollPatterns( 170de5022c7SMatthias Springer patterns, UnrollVectorOptions() 171de5022c7SMatthias Springer .setNativeShape(ArrayRef<int64_t>{2}) 172de5022c7SMatthias Springer .setFilterConstraint([](Operation *op) { 173de5022c7SMatthias Springer return success(isa<vector::ReductionOp>(op)); 174de5022c7SMatthias Springer })); 1755b1b7108SThomas Raoux populateVectorUnrollPatterns( 1765b1b7108SThomas Raoux patterns, UnrollVectorOptions() 1775b1b7108SThomas Raoux .setNativeShape(ArrayRef<int64_t>{1, 3, 4, 2}) 1785b1b7108SThomas Raoux .setFilterConstraint([](Operation *op) { 1795b1b7108SThomas Raoux return success(isa<vector::TransposeOp>(op)); 1805b1b7108SThomas Raoux })); 1813fef2d26SRiver Riddle 1823fef2d26SRiver Riddle if (unrollBasedOnType) { 1833fef2d26SRiver Riddle UnrollVectorOptions::NativeShapeFnType nativeShapeFn = 1840a81ace0SKazu Hirata [](Operation *op) -> std::optional<SmallVector<int64_t>> { 1853fef2d26SRiver Riddle vector::ContractionOp contractOp = cast<vector::ContractionOp>(op); 1867a69a9d7SNicolas Vasilache SmallVector<int64_t> nativeShape(contractOp.getIteratorTypes().size(), 1877a69a9d7SNicolas Vasilache 4); 1889f122152SChristopher Bate Type lhsType = contractOp.getLhsType().getElementType(); 1899f122152SChristopher Bate nativeShape[nativeShape.size() - 1] = lhsType.isF16() ? 4 : 2; 1903fef2d26SRiver Riddle return nativeShape; 1913fef2d26SRiver Riddle }; 1929f122152SChristopher Bate 1939f122152SChristopher Bate UnrollVectorOptions opts; 1949f122152SChristopher Bate opts.setNativeShapeFn(nativeShapeFn) 1959f122152SChristopher Bate .setFilterConstraint( 1969f122152SChristopher Bate [](Operation *op) { return success(isa<ContractionOp>(op)); }); 1979f122152SChristopher Bate 1989f122152SChristopher Bate if (!unrollOrder.empty()) { 1990a81ace0SKazu Hirata opts.setUnrollTraversalOrderFn( 2000a81ace0SKazu Hirata [this](Operation *op) -> std::optional<SmallVector<int64_t>> { 2010a81ace0SKazu Hirata vector::ContractionOp contractOp = 2020a81ace0SKazu Hirata cast<vector::ContractionOp>(op); 2039f122152SChristopher Bate if (contractOp.getIteratorTypes().size() == unrollOrder.size()) 2040a81ace0SKazu Hirata return SmallVector<int64_t>(unrollOrder.begin(), 2050a81ace0SKazu Hirata unrollOrder.end()); 2061a36588eSKazu Hirata return std::nullopt; 2079f122152SChristopher Bate }); 2089f122152SChristopher Bate } 2099f122152SChristopher Bate populateVectorUnrollPatterns(patterns, opts); 2109f122152SChristopher Bate } else { 2110a81ace0SKazu Hirata auto nativeShapeFn = 2120a81ace0SKazu Hirata [](Operation *op) -> std::optional<SmallVector<int64_t>> { 2139f122152SChristopher Bate auto contractOp = dyn_cast<ContractionOp>(op); 2149f122152SChristopher Bate if (!contractOp) 2151a36588eSKazu Hirata return std::nullopt; 2167a69a9d7SNicolas Vasilache return SmallVector<int64_t>(contractOp.getIteratorTypes().size(), 2); 2179f122152SChristopher Bate }; 21853fe155bSChristopher Bate populateVectorUnrollPatterns(patterns, 21953fe155bSChristopher Bate UnrollVectorOptions() 22053fe155bSChristopher Bate .setNativeShapeFn(nativeShapeFn) 22153fe155bSChristopher Bate .setFilterConstraint([](Operation *op) { 22253fe155bSChristopher Bate return success(isa<ContractionOp>(op)); 22353fe155bSChristopher Bate })); 2243fef2d26SRiver Riddle } 2253fef2d26SRiver Riddle populateVectorToVectorCanonicalizationPatterns(patterns); 226*09dfc571SJacques Pienaar (void)applyPatternsGreedily(getOperation(), std::move(patterns)); 2273fef2d26SRiver Riddle } 2283fef2d26SRiver Riddle 2299f122152SChristopher Bate ListOption<int64_t> unrollOrder{*this, "unroll-order", 23062a4e6abSFangrui Song llvm::cl::desc("set the unroll order")}; 2319f122152SChristopher Bate 2323fef2d26SRiver Riddle Option<bool> unrollBasedOnType{ 2333fef2d26SRiver Riddle *this, "unroll-based-on-type", 2343fef2d26SRiver Riddle llvm::cl::desc("Set the unroll factor based on type of the operation"), 2353fef2d26SRiver Riddle llvm::cl::init(false)}; 2363fef2d26SRiver Riddle }; 2373fef2d26SRiver Riddle 2383fef2d26SRiver Riddle struct TestVectorTransferUnrollingPatterns 23941574554SRiver Riddle : public PassWrapper<TestVectorTransferUnrollingPatterns, 24058ceae95SRiver Riddle OperationPass<func::FuncOp>> { 2415e50dd04SRiver Riddle MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( 2425e50dd04SRiver Riddle TestVectorTransferUnrollingPatterns) 2435e50dd04SRiver Riddle 2449f122152SChristopher Bate TestVectorTransferUnrollingPatterns() = default; 2459f122152SChristopher Bate TestVectorTransferUnrollingPatterns( 2469f122152SChristopher Bate const TestVectorTransferUnrollingPatterns &pass) 2479f122152SChristopher Bate : PassWrapper(pass) {} 2489f122152SChristopher Bate 2493fef2d26SRiver Riddle void getDependentDialects(DialectRegistry ®istry) const override { 2504c48f016SMatthias Springer registry.insert<affine::AffineDialect>(); 2513fef2d26SRiver Riddle } 252b5e22e6dSMehdi Amini StringRef getArgument() const final { 253b5e22e6dSMehdi Amini return "test-vector-transfer-unrolling-patterns"; 254b5e22e6dSMehdi Amini } 255b5e22e6dSMehdi Amini StringRef getDescription() const final { 25634ff8573SNicolas Vasilache return "Test lowering patterns to unroll transfer ops in the vector " 257b5e22e6dSMehdi Amini "dialect"; 258b5e22e6dSMehdi Amini } 25941574554SRiver Riddle void runOnOperation() override { 2603fef2d26SRiver Riddle MLIRContext *ctx = &getContext(); 2613fef2d26SRiver Riddle RewritePatternSet patterns(ctx); 2629f122152SChristopher Bate UnrollVectorOptions opts; 2639f122152SChristopher Bate opts.setNativeShape(ArrayRef<int64_t>{2, 2}) 2643fef2d26SRiver Riddle .setFilterConstraint([](Operation *op) { 265435f7d4cSQuinn Dawkins return success(isa<vector::TransferReadOp, vector::TransferWriteOp, 266435f7d4cSQuinn Dawkins vector::GatherOp>(op)); 2679f122152SChristopher Bate }); 2689f122152SChristopher Bate if (reverseUnrollOrder.getValue()) { 2699f122152SChristopher Bate opts.setUnrollTraversalOrderFn( 2700a81ace0SKazu Hirata [](Operation *op) -> std::optional<SmallVector<int64_t>> { 2719f122152SChristopher Bate int64_t numLoops = 0; 2729f122152SChristopher Bate if (auto readOp = dyn_cast<vector::TransferReadOp>(op)) 2739f122152SChristopher Bate numLoops = readOp.getVectorType().getRank(); 2749f122152SChristopher Bate else if (auto writeOp = dyn_cast<vector::TransferWriteOp>(op)) 2759f122152SChristopher Bate numLoops = writeOp.getVectorType().getRank(); 276435f7d4cSQuinn Dawkins else if (auto gatherOp = dyn_cast<vector::GatherOp>(op)) 277435f7d4cSQuinn Dawkins numLoops = gatherOp.getVectorType().getRank(); 2789f122152SChristopher Bate else 2791a36588eSKazu Hirata return std::nullopt; 2809f122152SChristopher Bate auto order = llvm::reverse(llvm::seq<int64_t>(0, numLoops)); 2819f122152SChristopher Bate return llvm::to_vector(order); 2829f122152SChristopher Bate }); 2839f122152SChristopher Bate } 2849f122152SChristopher Bate populateVectorUnrollPatterns(patterns, opts); 2853fef2d26SRiver Riddle populateVectorToVectorCanonicalizationPatterns(patterns); 286*09dfc571SJacques Pienaar (void)applyPatternsGreedily(getOperation(), std::move(patterns)); 2873fef2d26SRiver Riddle } 2889f122152SChristopher Bate 2899f122152SChristopher Bate Option<bool> reverseUnrollOrder{ 2909f122152SChristopher Bate *this, "reverse-unroll-order", 2919f122152SChristopher Bate llvm::cl::desc( 2929f122152SChristopher Bate "reverse the order of unrolling of vector transfer operations"), 2939f122152SChristopher Bate llvm::cl::init(false)}; 2943fef2d26SRiver Riddle }; 2953fef2d26SRiver Riddle 2962ec98ffbSMatthias Springer struct TestScalarVectorTransferLoweringPatterns 2972ec98ffbSMatthias Springer : public PassWrapper<TestScalarVectorTransferLoweringPatterns, 2982ec98ffbSMatthias Springer OperationPass<func::FuncOp>> { 2992ec98ffbSMatthias Springer MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( 3002ec98ffbSMatthias Springer TestScalarVectorTransferLoweringPatterns) 3012ec98ffbSMatthias Springer 30214726cd6SDiego Caballero TestScalarVectorTransferLoweringPatterns() = default; 30314726cd6SDiego Caballero TestScalarVectorTransferLoweringPatterns( 30414726cd6SDiego Caballero const TestScalarVectorTransferLoweringPatterns &pass) 30514726cd6SDiego Caballero : PassWrapper(pass) {} 30614726cd6SDiego Caballero 3072ec98ffbSMatthias Springer StringRef getArgument() const final { 3082ec98ffbSMatthias Springer return "test-scalar-vector-transfer-lowering"; 3092ec98ffbSMatthias Springer } 3102ec98ffbSMatthias Springer StringRef getDescription() const final { 3112ec98ffbSMatthias Springer return "Test lowering of scalar vector transfers to memref loads/stores."; 3122ec98ffbSMatthias Springer } 3132ec98ffbSMatthias Springer 3142ec98ffbSMatthias Springer void getDependentDialects(DialectRegistry ®istry) const override { 3154c48f016SMatthias Springer registry.insert<affine::AffineDialect, memref::MemRefDialect, 3164c48f016SMatthias Springer tensor::TensorDialect, vector::VectorDialect>(); 3172ec98ffbSMatthias Springer } 3182ec98ffbSMatthias Springer 31914726cd6SDiego Caballero Option<bool> allowMultipleUses{ 32014726cd6SDiego Caballero *this, "allow-multiple-uses", 32114726cd6SDiego Caballero llvm::cl::desc("Fold transfer operations with multiple uses"), 32214726cd6SDiego Caballero llvm::cl::init(false)}; 32314726cd6SDiego Caballero 3242ec98ffbSMatthias Springer void runOnOperation() override { 3252ec98ffbSMatthias Springer MLIRContext *ctx = &getContext(); 3262ec98ffbSMatthias Springer RewritePatternSet patterns(ctx); 32714726cd6SDiego Caballero vector::populateScalarVectorTransferLoweringPatterns( 32814726cd6SDiego Caballero patterns, /*benefit=*/1, allowMultipleUses.getValue()); 329*09dfc571SJacques Pienaar (void)applyPatternsGreedily(getOperation(), std::move(patterns)); 3302ec98ffbSMatthias Springer } 3312ec98ffbSMatthias Springer }; 3322ec98ffbSMatthias Springer 3333fef2d26SRiver Riddle struct TestVectorTransferOpt 33458ceae95SRiver Riddle : public PassWrapper<TestVectorTransferOpt, OperationPass<func::FuncOp>> { 3355e50dd04SRiver Riddle MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorTransferOpt) 3365e50dd04SRiver Riddle 337b5e22e6dSMehdi Amini StringRef getArgument() const final { return "test-vector-transferop-opt"; } 338b5e22e6dSMehdi Amini StringRef getDescription() const final { 339b5e22e6dSMehdi Amini return "Test optimization transformations for transfer ops"; 340b5e22e6dSMehdi Amini } 341553cebdeSNicolas Vasilache void runOnOperation() override { 342553cebdeSNicolas Vasilache IRRewriter rewriter(&getContext()); 343553cebdeSNicolas Vasilache transferOpflowOpt(rewriter, getOperation()); 344553cebdeSNicolas Vasilache } 3453fef2d26SRiver Riddle }; 3463fef2d26SRiver Riddle 347a3dd4e77SAhmed S. Taei struct TestVectorTransferCollapseInnerMostContiguousDims 348a3dd4e77SAhmed S. Taei : public PassWrapper<TestVectorTransferCollapseInnerMostContiguousDims, 34958ceae95SRiver Riddle OperationPass<func::FuncOp>> { 3505e50dd04SRiver Riddle MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( 3515e50dd04SRiver Riddle TestVectorTransferCollapseInnerMostContiguousDims) 3525e50dd04SRiver Riddle 353a3dd4e77SAhmed S. Taei TestVectorTransferCollapseInnerMostContiguousDims() = default; 354a3dd4e77SAhmed S. Taei TestVectorTransferCollapseInnerMostContiguousDims( 355322c8914SMehdi Amini const TestVectorTransferCollapseInnerMostContiguousDims &pass) = default; 356a3dd4e77SAhmed S. Taei 357a3dd4e77SAhmed S. Taei void getDependentDialects(DialectRegistry ®istry) const override { 3584c48f016SMatthias Springer registry.insert<memref::MemRefDialect, affine::AffineDialect>(); 359a3dd4e77SAhmed S. Taei } 360a3dd4e77SAhmed S. Taei 361a3dd4e77SAhmed S. Taei StringRef getArgument() const final { 362a3dd4e77SAhmed S. Taei return "test-vector-transfer-collapse-inner-most-dims"; 363a3dd4e77SAhmed S. Taei } 364a3dd4e77SAhmed S. Taei 365a3dd4e77SAhmed S. Taei StringRef getDescription() const final { 366017c75bfSKai Sasaki return "Test lowering patterns that reduces the rank of the vector " 367a3dd4e77SAhmed S. Taei "transfer memory and vector operands."; 368a3dd4e77SAhmed S. Taei } 369a3dd4e77SAhmed S. Taei 37041574554SRiver Riddle void runOnOperation() override { 371a3dd4e77SAhmed S. Taei RewritePatternSet patterns(&getContext()); 372a3dd4e77SAhmed S. Taei populateVectorTransferCollapseInnerMostContiguousDimsPatterns(patterns); 373*09dfc571SJacques Pienaar (void)applyPatternsGreedily(getOperation(), std::move(patterns)); 374a3dd4e77SAhmed S. Taei } 375a3dd4e77SAhmed S. Taei }; 376a3dd4e77SAhmed S. Taei 37742944da5SAndrzej Warzyński struct TestVectorSinkPatterns 37842944da5SAndrzej Warzyński : public PassWrapper<TestVectorSinkPatterns, OperationPass<func::FuncOp>> { 37942944da5SAndrzej Warzyński MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorSinkPatterns) 3804d339ec9SAndrzej Warzynski 38142944da5SAndrzej Warzyński TestVectorSinkPatterns() = default; 38242944da5SAndrzej Warzyński TestVectorSinkPatterns(const TestVectorSinkPatterns &pass) = default; 3834d339ec9SAndrzej Warzynski 3844d339ec9SAndrzej Warzynski void getDependentDialects(DialectRegistry ®istry) const override { 3854d339ec9SAndrzej Warzynski registry.insert<memref::MemRefDialect, affine::AffineDialect>(); 3864d339ec9SAndrzej Warzynski } 3874d339ec9SAndrzej Warzynski 38842944da5SAndrzej Warzyński StringRef getArgument() const final { return "test-vector-sink-patterns"; } 3894d339ec9SAndrzej Warzynski 3904d339ec9SAndrzej Warzynski StringRef getDescription() const final { 391017c75bfSKai Sasaki return "Test lowering patterns that eliminate redundant broadcast " 39242944da5SAndrzej Warzyński "and transpose operations."; 3934d339ec9SAndrzej Warzynski } 3944d339ec9SAndrzej Warzynski 3954d339ec9SAndrzej Warzynski void runOnOperation() override { 3964d339ec9SAndrzej Warzynski RewritePatternSet patterns(&getContext()); 39742944da5SAndrzej Warzyński populateSinkVectorOpsPatterns(patterns); 398*09dfc571SJacques Pienaar (void)applyPatternsGreedily(getOperation(), std::move(patterns)); 3994d339ec9SAndrzej Warzynski } 4004d339ec9SAndrzej Warzynski }; 4014d339ec9SAndrzej Warzynski 4021d8cc45bSthomasraoux struct TestVectorReduceToContractPatternsPatterns 4031d8cc45bSthomasraoux : public PassWrapper<TestVectorReduceToContractPatternsPatterns, 40458ceae95SRiver Riddle OperationPass<func::FuncOp>> { 4055e50dd04SRiver Riddle MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( 4065e50dd04SRiver Riddle TestVectorReduceToContractPatternsPatterns) 4075e50dd04SRiver Riddle 4081d8cc45bSthomasraoux StringRef getArgument() const final { 4091d8cc45bSthomasraoux return "test-vector-reduction-to-contract-patterns"; 4101d8cc45bSthomasraoux } 4111d8cc45bSthomasraoux StringRef getDescription() const final { 4121d8cc45bSthomasraoux return "Test patterns to convert multireduce op to contract and combine " 4131d8cc45bSthomasraoux "broadcast/transpose to contract"; 4141d8cc45bSthomasraoux } 41541574554SRiver Riddle void runOnOperation() override { 4161d8cc45bSthomasraoux RewritePatternSet patterns(&getContext()); 417d054b80bSNicolas Vasilache populateVectorReductionToContractPatterns(patterns); 418*09dfc571SJacques Pienaar (void)applyPatternsGreedily(getOperation(), std::move(patterns)); 4191d8cc45bSthomasraoux } 4201d8cc45bSthomasraoux }; 4211d8cc45bSthomasraoux 422d33bad66SJakub Kuderski struct TestVectorChainedReductionFoldingPatterns 423d33bad66SJakub Kuderski : public PassWrapper<TestVectorChainedReductionFoldingPatterns, 424d33bad66SJakub Kuderski OperationPass<func::FuncOp>> { 425d33bad66SJakub Kuderski MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( 426d33bad66SJakub Kuderski TestVectorChainedReductionFoldingPatterns) 427d33bad66SJakub Kuderski 428d33bad66SJakub Kuderski StringRef getArgument() const final { 429d33bad66SJakub Kuderski return "test-vector-chained-reduction-folding-patterns"; 430d33bad66SJakub Kuderski } 431d33bad66SJakub Kuderski StringRef getDescription() const final { 432d33bad66SJakub Kuderski return "Test patterns to fold chained vector reductions"; 433d33bad66SJakub Kuderski } 434d33bad66SJakub Kuderski void runOnOperation() override { 435d33bad66SJakub Kuderski RewritePatternSet patterns(&getContext()); 436d33bad66SJakub Kuderski populateChainedVectorReductionFoldingPatterns(patterns); 437*09dfc571SJacques Pienaar (void)applyPatternsGreedily(getOperation(), std::move(patterns)); 438d33bad66SJakub Kuderski } 439d33bad66SJakub Kuderski }; 440d33bad66SJakub Kuderski 44107677113SJakub Kuderski struct TestVectorBreakDownReductionPatterns 44207677113SJakub Kuderski : public PassWrapper<TestVectorBreakDownReductionPatterns, 44307677113SJakub Kuderski OperationPass<func::FuncOp>> { 44407677113SJakub Kuderski MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( 44507677113SJakub Kuderski TestVectorBreakDownReductionPatterns) 44607677113SJakub Kuderski 44707677113SJakub Kuderski StringRef getArgument() const final { 44807677113SJakub Kuderski return "test-vector-break-down-reduction-patterns"; 44907677113SJakub Kuderski } 45007677113SJakub Kuderski StringRef getDescription() const final { 45107677113SJakub Kuderski return "Test patterns to break down vector reductions into arith " 45207677113SJakub Kuderski "reductions"; 45307677113SJakub Kuderski } 45407677113SJakub Kuderski void runOnOperation() override { 45507677113SJakub Kuderski RewritePatternSet patterns(&getContext()); 45607677113SJakub Kuderski populateBreakDownVectorReductionPatterns(patterns, 45707677113SJakub Kuderski /*maxNumElementsToExtract=*/2); 458*09dfc571SJacques Pienaar (void)applyPatternsGreedily(getOperation(), std::move(patterns)); 45907677113SJakub Kuderski } 46007677113SJakub Kuderski }; 46107677113SJakub Kuderski 462aba437ceSBenoit Jacob struct TestFlattenVectorTransferPatterns 46341574554SRiver Riddle : public PassWrapper<TestFlattenVectorTransferPatterns, 46458ceae95SRiver Riddle OperationPass<func::FuncOp>> { 4655e50dd04SRiver Riddle MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( 4665e50dd04SRiver Riddle TestFlattenVectorTransferPatterns) 4675e50dd04SRiver Riddle 46871441ed1SDiego Caballero TestFlattenVectorTransferPatterns() = default; 46971441ed1SDiego Caballero TestFlattenVectorTransferPatterns( 47071441ed1SDiego Caballero const TestFlattenVectorTransferPatterns &pass) 47171441ed1SDiego Caballero : PassWrapper(pass) {} 47271441ed1SDiego Caballero 473aba437ceSBenoit Jacob StringRef getArgument() const final { 474aba437ceSBenoit Jacob return "test-vector-transfer-flatten-patterns"; 475aba437ceSBenoit Jacob } 47671441ed1SDiego Caballero 477aba437ceSBenoit Jacob StringRef getDescription() const final { 478aba437ceSBenoit Jacob return "Test patterns to rewrite contiguous row-major N-dimensional " 479aba437ceSBenoit Jacob "vector.transfer_{read,write} ops into 1D transfers"; 480aba437ceSBenoit Jacob } 48171441ed1SDiego Caballero 482aba437ceSBenoit Jacob void getDependentDialects(DialectRegistry ®istry) const override { 483aba437ceSBenoit Jacob registry.insert<memref::MemRefDialect>(); 4842eb9e33cSAndrzej Warzyński registry.insert<affine::AffineDialect>(); 485c02d07fdSAndrzej Warzyński registry.insert<vector::VectorDialect>(); 486aba437ceSBenoit Jacob } 48771441ed1SDiego Caballero 48871441ed1SDiego Caballero Option<unsigned> targetVectorBitwidth{ 48971441ed1SDiego Caballero *this, "target-vector-bitwidth", 49071441ed1SDiego Caballero llvm::cl::desc( 491d3aa92edSAndrzej Warzyński "Minimum vector bitwidth to enable the flattening transformation. " 492d3aa92edSAndrzej Warzyński "For scalable vectors this is the base size, i.e. the size " 493d3aa92edSAndrzej Warzyński "corresponding to vscale=1."), 49471441ed1SDiego Caballero llvm::cl::init(std::numeric_limits<unsigned>::max())}; 49571441ed1SDiego Caballero 49641574554SRiver Riddle void runOnOperation() override { 497aba437ceSBenoit Jacob RewritePatternSet patterns(&getContext()); 49871441ed1SDiego Caballero populateFlattenVectorTransferPatterns(patterns, targetVectorBitwidth); 499*09dfc571SJacques Pienaar (void)applyPatternsGreedily(getOperation(), std::move(patterns)); 500aba437ceSBenoit Jacob } 501aba437ceSBenoit Jacob }; 502aba437ceSBenoit Jacob 50380e0bf1aSharsh struct TestVectorScanLowering 50458ceae95SRiver Riddle : public PassWrapper<TestVectorScanLowering, OperationPass<func::FuncOp>> { 5055e50dd04SRiver Riddle MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorScanLowering) 5065e50dd04SRiver Riddle 50780e0bf1aSharsh StringRef getArgument() const final { return "test-vector-scan-lowering"; } 50880e0bf1aSharsh StringRef getDescription() const final { 50980e0bf1aSharsh return "Test lowering patterns that lower the scan op in the vector " 51080e0bf1aSharsh "dialect"; 51180e0bf1aSharsh } 51280e0bf1aSharsh void runOnOperation() override { 51380e0bf1aSharsh RewritePatternSet patterns(&getContext()); 51480e0bf1aSharsh populateVectorScanLoweringPatterns(patterns); 515*09dfc571SJacques Pienaar (void)applyPatternsGreedily(getOperation(), std::move(patterns)); 51680e0bf1aSharsh } 51780e0bf1aSharsh }; 51880e0bf1aSharsh 519d02f10d9SThomas Raoux /// Allocate shared memory for a single warp to test lowering of 520d02f10d9SThomas Raoux /// WarpExecuteOnLane0Op. 521d02f10d9SThomas Raoux static Value allocateGlobalSharedMemory(Location loc, OpBuilder &builder, 522ecaf2c33SPetr Kurapov gpu::WarpExecuteOnLane0Op warpOp, 523d02f10d9SThomas Raoux Type type) { 524d02f10d9SThomas Raoux static constexpr int64_t kSharedMemorySpace = 3; 525d02f10d9SThomas Raoux // Compute type of shared memory buffer. 526d02f10d9SThomas Raoux MemRefType memrefType; 5275550c821STres Popp if (auto vectorType = dyn_cast<VectorType>(type)) { 528d02f10d9SThomas Raoux memrefType = 529d02f10d9SThomas Raoux MemRefType::get(vectorType.getShape(), vectorType.getElementType(), {}, 530d02f10d9SThomas Raoux kSharedMemorySpace); 531d02f10d9SThomas Raoux } else { 532d02f10d9SThomas Raoux memrefType = MemRefType::get({1}, type, {}, kSharedMemorySpace); 533d02f10d9SThomas Raoux } 534d02f10d9SThomas Raoux 535d02f10d9SThomas Raoux // Get symbol table holding all shared memory globals. 536d02f10d9SThomas Raoux ModuleOp moduleOp = warpOp->getParentOfType<ModuleOp>(); 537d02f10d9SThomas Raoux SymbolTable symbolTable(moduleOp); 538d02f10d9SThomas Raoux 539d02f10d9SThomas Raoux // Create a pretty name. 540d02f10d9SThomas Raoux SmallString<64> buf; 541d02f10d9SThomas Raoux llvm::raw_svector_ostream os(buf); 542d02f10d9SThomas Raoux interleave(memrefType.getShape(), os, "x"); 543d02f10d9SThomas Raoux os << "x" << memrefType.getElementType(); 544d02f10d9SThomas Raoux std::string symbolName = (Twine("__shared_") + os.str()).str(); 545d02f10d9SThomas Raoux 546d02f10d9SThomas Raoux auto ip = builder.saveInsertionPoint(); 547d02f10d9SThomas Raoux builder.setInsertionPoint(moduleOp); 548d02f10d9SThomas Raoux auto global = builder.create<memref::GlobalOp>( 549d02f10d9SThomas Raoux loc, 550d02f10d9SThomas Raoux /*sym_name=*/symbolName, 551d02f10d9SThomas Raoux /*sym_visibility=*/builder.getStringAttr("private"), 552d02f10d9SThomas Raoux /*type=*/memrefType, 553d02f10d9SThomas Raoux /*initial_value=*/Attribute(), 554d02f10d9SThomas Raoux /*constant=*/false, 555d02f10d9SThomas Raoux /*alignment=*/IntegerAttr()); 556d02f10d9SThomas Raoux symbolTable.insert(global); 557d02f10d9SThomas Raoux // The symbol table inserts at the end of the module, but globals are a bit 558d02f10d9SThomas Raoux // nicer if they are at the beginning. 559d02f10d9SThomas Raoux global->moveBefore(&moduleOp.front()); 560d02f10d9SThomas Raoux 561d02f10d9SThomas Raoux builder.restoreInsertionPoint(ip); 562d02f10d9SThomas Raoux return builder.create<memref::GetGlobalOp>(loc, memrefType, symbolName); 563d02f10d9SThomas Raoux } 564d02f10d9SThomas Raoux 5656834803cSThomas Raoux static Value warpReduction(Location loc, OpBuilder &builder, Value input, 5666834803cSThomas Raoux CombiningKind kind, uint32_t size) { 567d2061530Sstanley-nod // First reduce on a single thread to get per lane reduction value. 568d2061530Sstanley-nod Value laneVal = builder.create<vector::ReductionOp>(loc, kind, input); 5696834803cSThomas Raoux // Parallel reduction using butterfly shuffles. 5706834803cSThomas Raoux for (uint64_t i = 1; i < size; i <<= 1) { 5716834803cSThomas Raoux Value shuffled = builder 5726834803cSThomas Raoux .create<gpu::ShuffleOp>(loc, laneVal, i, 5736834803cSThomas Raoux /*width=*/size, 5746834803cSThomas Raoux /*mode=*/gpu::ShuffleMode::XOR) 575986b5c56SRiver Riddle .getShuffleResult(); 5766834803cSThomas Raoux laneVal = makeArithReduction(builder, loc, kind, laneVal, shuffled); 5776834803cSThomas Raoux } 5786834803cSThomas Raoux return laneVal; 5796834803cSThomas Raoux } 5806834803cSThomas Raoux 581d02f10d9SThomas Raoux struct TestVectorDistribution 582d02f10d9SThomas Raoux : public PassWrapper<TestVectorDistribution, OperationPass<func::FuncOp>> { 583d02f10d9SThomas Raoux MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorDistribution) 584d02f10d9SThomas Raoux 585d02f10d9SThomas Raoux void getDependentDialects(DialectRegistry ®istry) const override { 586ecaf2c33SPetr Kurapov registry 587ecaf2c33SPetr Kurapov .insert<vector::VectorDialect, scf::SCFDialect, memref::MemRefDialect, 588ecaf2c33SPetr Kurapov gpu::GPUDialect, affine::AffineDialect>(); 589d02f10d9SThomas Raoux } 590d02f10d9SThomas Raoux 591d02f10d9SThomas Raoux StringRef getArgument() const final { return "test-vector-warp-distribute"; } 592d02f10d9SThomas Raoux StringRef getDescription() const final { 593d02f10d9SThomas Raoux return "Test vector warp distribute transformation and lowering patterns"; 594d02f10d9SThomas Raoux } 595d02f10d9SThomas Raoux TestVectorDistribution() = default; 596d02f10d9SThomas Raoux TestVectorDistribution(const TestVectorDistribution &pass) 597d02f10d9SThomas Raoux : PassWrapper(pass) {} 598d02f10d9SThomas Raoux 599d02f10d9SThomas Raoux Option<bool> warpOpToSCF{ 600d02f10d9SThomas Raoux *this, "rewrite-warp-ops-to-scf-if", 601d02f10d9SThomas Raoux llvm::cl::desc("Lower vector.warp_execute_on_lane0 to scf.if op"), 602d02f10d9SThomas Raoux llvm::cl::init(false)}; 603d02f10d9SThomas Raoux 604ed0288f7SThomas Raoux Option<bool> distributeTransferWriteOps{ 605ed0288f7SThomas Raoux *this, "distribute-transfer-write", 606ed0288f7SThomas Raoux llvm::cl::desc("Test distribution of transfer write"), 607ed0288f7SThomas Raoux llvm::cl::init(false)}; 608ed0288f7SThomas Raoux 60980636227SJakub Kuderski Option<unsigned> maxTransferWriteElements{ 61080636227SJakub Kuderski *this, "max-transfer-write-elements", 61180636227SJakub Kuderski llvm::cl::desc("Maximum number of transfer write elements to distribute"), 61280636227SJakub Kuderski llvm::cl::init(1)}; 61380636227SJakub Kuderski 614ed0288f7SThomas Raoux Option<bool> hoistUniform{*this, "hoist-uniform", 615ed0288f7SThomas Raoux llvm::cl::desc("Test hoist uniform"), 616ed0288f7SThomas Raoux llvm::cl::init(false)}; 617ed0288f7SThomas Raoux 61876cf33daSThomas Raoux Option<bool> propagateDistribution{ 61976cf33daSThomas Raoux *this, "propagate-distribution", 620017c75bfSKai Sasaki llvm::cl::desc("Test distribution propagation"), llvm::cl::init(false)}; 62176cf33daSThomas Raoux 622d02f10d9SThomas Raoux void runOnOperation() override { 623d02f10d9SThomas Raoux RewritePatternSet patterns(&getContext()); 624ed0288f7SThomas Raoux 625ed0288f7SThomas Raoux getOperation().walk([&](Operation *op) { 626ecaf2c33SPetr Kurapov if (auto warpOp = dyn_cast<gpu::WarpExecuteOnLane0Op>(op)) { 627ed0288f7SThomas Raoux if (hoistUniform) { 628ed0288f7SThomas Raoux moveScalarUniformCode(warpOp); 629ed0288f7SThomas Raoux } 630ed0288f7SThomas Raoux WalkResult::interrupt(); 631ed0288f7SThomas Raoux } 632ed0288f7SThomas Raoux }); 633ed0288f7SThomas Raoux MLIRContext *ctx = &getContext(); 63491f62f0eSThomas Raoux auto distributionFn = [](Value val) { 635c2b95292SQuinn Dawkins // Create an identity dim map of the same rank as the vector. 6365550c821STres Popp VectorType vecType = dyn_cast<VectorType>(val.getType()); 63791f62f0eSThomas Raoux int64_t vecRank = vecType ? vecType.getRank() : 0; 63891f62f0eSThomas Raoux OpBuilder builder(val.getContext()); 63991f62f0eSThomas Raoux if (vecRank == 0) 64091f62f0eSThomas Raoux return AffineMap::get(val.getContext()); 641c2b95292SQuinn Dawkins return AffineMap::getMultiDimIdentityMap(vecRank, val.getContext()); 642ed0288f7SThomas Raoux }; 6439d51b4e4SMatthias Springer auto shuffleFn = [](Location loc, OpBuilder &builder, Value val, 6449d51b4e4SMatthias Springer Value srcIdx, int64_t warpSz) { 6459d51b4e4SMatthias Springer assert((val.getType().isF32() || val.getType().isInteger(32)) && 6469d51b4e4SMatthias Springer "unsupported shuffle type"); 6479d51b4e4SMatthias Springer Type i32Type = builder.getIntegerType(32); 6489d51b4e4SMatthias Springer Value srcIdxI32 = 6499d51b4e4SMatthias Springer builder.create<arith::IndexCastOp>(loc, i32Type, srcIdx); 6509d51b4e4SMatthias Springer Value warpSzI32 = builder.create<arith::ConstantOp>( 6519d51b4e4SMatthias Springer loc, builder.getIntegerAttr(i32Type, warpSz)); 6529d51b4e4SMatthias Springer Value result = builder 6539d51b4e4SMatthias Springer .create<gpu::ShuffleOp>(loc, val, srcIdxI32, warpSzI32, 6549d51b4e4SMatthias Springer gpu::ShuffleMode::IDX) 6559d51b4e4SMatthias Springer .getResult(0); 6569d51b4e4SMatthias Springer return result; 6579d51b4e4SMatthias Springer }; 658df49a97aSQuinn Dawkins if (distributeTransferWriteOps && propagateDistribution) { 659df49a97aSQuinn Dawkins RewritePatternSet patterns(ctx); 660df49a97aSQuinn Dawkins vector::populatePropagateWarpVectorDistributionPatterns( 661df49a97aSQuinn Dawkins patterns, distributionFn, shuffleFn, /*benefit=*/1, 662df49a97aSQuinn Dawkins /*readBenefit=*/0); 663df49a97aSQuinn Dawkins vector::populateDistributeReduction(patterns, warpReduction, 1); 664df49a97aSQuinn Dawkins populateDistributeTransferWriteOpPatterns(patterns, distributionFn, 2); 665*09dfc571SJacques Pienaar (void)applyPatternsGreedily(getOperation(), std::move(patterns)); 666df49a97aSQuinn Dawkins } else if (distributeTransferWriteOps) { 667ed0288f7SThomas Raoux RewritePatternSet patterns(ctx); 66880636227SJakub Kuderski populateDistributeTransferWriteOpPatterns(patterns, distributionFn, 66980636227SJakub Kuderski maxTransferWriteElements); 670*09dfc571SJacques Pienaar (void)applyPatternsGreedily(getOperation(), std::move(patterns)); 671df49a97aSQuinn Dawkins } else if (propagateDistribution) { 67276cf33daSThomas Raoux RewritePatternSet patterns(ctx); 6739d51b4e4SMatthias Springer vector::populatePropagateWarpVectorDistributionPatterns( 6749d51b4e4SMatthias Springer patterns, distributionFn, shuffleFn); 6756834803cSThomas Raoux vector::populateDistributeReduction(patterns, warpReduction); 676*09dfc571SJacques Pienaar (void)applyPatternsGreedily(getOperation(), std::move(patterns)); 67776cf33daSThomas Raoux } 678d02f10d9SThomas Raoux WarpExecuteOnLane0LoweringOptions options; 679d02f10d9SThomas Raoux options.warpAllocationFn = allocateGlobalSharedMemory; 680d02f10d9SThomas Raoux options.warpSyncronizationFn = [](Location loc, OpBuilder &builder, 681ecaf2c33SPetr Kurapov gpu::WarpExecuteOnLane0Op warpOp) { 682d02f10d9SThomas Raoux builder.create<gpu::BarrierOp>(loc); 683d02f10d9SThomas Raoux }; 684d02f10d9SThomas Raoux // Test on one pattern in isolation. 685d02f10d9SThomas Raoux if (warpOpToSCF) { 686d02f10d9SThomas Raoux populateWarpExecuteOnLane0OpToScfForPattern(patterns, options); 687*09dfc571SJacques Pienaar (void)applyPatternsGreedily(getOperation(), std::move(patterns)); 688d02f10d9SThomas Raoux return; 689d02f10d9SThomas Raoux } 690d02f10d9SThomas Raoux } 691d02f10d9SThomas Raoux }; 692d02f10d9SThomas Raoux 69339c80656SLei Zhang struct TestVectorExtractStridedSliceLowering 69439c80656SLei Zhang : public PassWrapper<TestVectorExtractStridedSliceLowering, 69539c80656SLei Zhang OperationPass<func::FuncOp>> { 69639c80656SLei Zhang MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( 69739c80656SLei Zhang TestVectorExtractStridedSliceLowering) 69839c80656SLei Zhang 69939c80656SLei Zhang StringRef getArgument() const final { 70039c80656SLei Zhang return "test-vector-extract-strided-slice-lowering"; 70139c80656SLei Zhang } 70239c80656SLei Zhang StringRef getDescription() const final { 70339c80656SLei Zhang return "Test lowering patterns that converts vector.extract_strided_slice " 70439c80656SLei Zhang "into a chain of vector.extract and vector.insert ops"; 70539c80656SLei Zhang } 70639c80656SLei Zhang void runOnOperation() override { 70739c80656SLei Zhang RewritePatternSet patterns(&getContext()); 70839c80656SLei Zhang populateVectorExtractStridedSliceToExtractInsertChainPatterns(patterns); 709*09dfc571SJacques Pienaar (void)applyPatternsGreedily(getOperation(), std::move(patterns)); 71039c80656SLei Zhang } 71139c80656SLei Zhang }; 71239c80656SLei Zhang 713650f04feSQuinn Dawkins struct TestVectorBreakDownBitCast 714650f04feSQuinn Dawkins : public PassWrapper<TestVectorBreakDownBitCast, 715650f04feSQuinn Dawkins OperationPass<func::FuncOp>> { 716650f04feSQuinn Dawkins MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorBreakDownBitCast) 717650f04feSQuinn Dawkins 718650f04feSQuinn Dawkins StringRef getArgument() const final { 719650f04feSQuinn Dawkins return "test-vector-break-down-bitcast"; 720650f04feSQuinn Dawkins } 721650f04feSQuinn Dawkins StringRef getDescription() const final { 722650f04feSQuinn Dawkins return "Test pattern that breaks down vector.bitcast ops "; 723650f04feSQuinn Dawkins } 724650f04feSQuinn Dawkins void runOnOperation() override { 725650f04feSQuinn Dawkins RewritePatternSet patterns(&getContext()); 726650f04feSQuinn Dawkins populateBreakDownVectorBitCastOpPatterns(patterns, [](BitCastOp op) { 727650f04feSQuinn Dawkins return op.getSourceVectorType().getShape().back() > 4; 728650f04feSQuinn Dawkins }); 729*09dfc571SJacques Pienaar (void)applyPatternsGreedily(getOperation(), std::move(patterns)); 730650f04feSQuinn Dawkins } 731650f04feSQuinn Dawkins }; 732650f04feSQuinn Dawkins 733de13eedaSNicolas Vasilache struct TestCreateVectorBroadcast 734de13eedaSNicolas Vasilache : public PassWrapper<TestCreateVectorBroadcast, 735de13eedaSNicolas Vasilache OperationPass<func::FuncOp>> { 736de13eedaSNicolas Vasilache MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestCreateVectorBroadcast) 737de13eedaSNicolas Vasilache 738de13eedaSNicolas Vasilache StringRef getArgument() const final { return "test-create-vector-broadcast"; } 739de13eedaSNicolas Vasilache StringRef getDescription() const final { 740de13eedaSNicolas Vasilache return "Test optimization transformations for transfer ops"; 741de13eedaSNicolas Vasilache } 742de13eedaSNicolas Vasilache void getDependentDialects(DialectRegistry ®istry) const override { 743de13eedaSNicolas Vasilache registry.insert<vector::VectorDialect>(); 744de13eedaSNicolas Vasilache } 745de13eedaSNicolas Vasilache 746de13eedaSNicolas Vasilache void runOnOperation() override { 747de13eedaSNicolas Vasilache getOperation()->walk([](Operation *op) { 748de13eedaSNicolas Vasilache if (op->getName().getStringRef() != "test_create_broadcast") 749de13eedaSNicolas Vasilache return; 750de13eedaSNicolas Vasilache auto targetShape = 7515550c821STres Popp cast<VectorType>(op->getResult(0).getType()).getShape(); 752de13eedaSNicolas Vasilache auto arrayAttr = 753830b9b07SMehdi Amini cast<DenseI64ArrayAttr>(op->getDiscardableAttr("broadcast_dims")) 754830b9b07SMehdi Amini .asArrayRef(); 755de13eedaSNicolas Vasilache llvm::SetVector<int64_t> broadcastedDims; 756de13eedaSNicolas Vasilache broadcastedDims.insert(arrayAttr.begin(), arrayAttr.end()); 757de13eedaSNicolas Vasilache OpBuilder b(op); 758de13eedaSNicolas Vasilache Value bcast = vector::BroadcastOp::createOrFoldBroadcastOp( 759de13eedaSNicolas Vasilache b, op->getOperand(0), targetShape, broadcastedDims); 760de13eedaSNicolas Vasilache op->getResult(0).replaceAllUsesWith(bcast); 761de13eedaSNicolas Vasilache op->erase(); 762de13eedaSNicolas Vasilache }); 763de13eedaSNicolas Vasilache } 764de13eedaSNicolas Vasilache }; 765de13eedaSNicolas Vasilache 766f80a976aSJakub Kuderski struct TestVectorGatherLowering 767f80a976aSJakub Kuderski : public PassWrapper<TestVectorGatherLowering, 768f80a976aSJakub Kuderski OperationPass<func::FuncOp>> { 769f80a976aSJakub Kuderski MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorGatherLowering) 770f80a976aSJakub Kuderski 771f80a976aSJakub Kuderski StringRef getArgument() const final { return "test-vector-gather-lowering"; } 772f80a976aSJakub Kuderski StringRef getDescription() const final { 773f80a976aSJakub Kuderski return "Test patterns that lower the gather op in the vector conditional " 774f80a976aSJakub Kuderski "loads"; 775f80a976aSJakub Kuderski } 776f80a976aSJakub Kuderski void getDependentDialects(DialectRegistry ®istry) const override { 777f80a976aSJakub Kuderski registry.insert<arith::ArithDialect, func::FuncDialect, 778f80a976aSJakub Kuderski memref::MemRefDialect, scf::SCFDialect, 779f80a976aSJakub Kuderski tensor::TensorDialect, vector::VectorDialect>(); 780f80a976aSJakub Kuderski } 781f80a976aSJakub Kuderski 782f80a976aSJakub Kuderski void runOnOperation() override { 783f80a976aSJakub Kuderski RewritePatternSet patterns(&getContext()); 784f80a976aSJakub Kuderski populateVectorGatherLoweringPatterns(patterns); 785*09dfc571SJacques Pienaar (void)applyPatternsGreedily(getOperation(), std::move(patterns)); 786f80a976aSJakub Kuderski } 787f80a976aSJakub Kuderski }; 788f80a976aSJakub Kuderski 7899a795f0cSManish Gupta struct TestFoldArithExtensionIntoVectorContractPatterns 7909a795f0cSManish Gupta : public PassWrapper<TestFoldArithExtensionIntoVectorContractPatterns, 7919a795f0cSManish Gupta OperationPass<func::FuncOp>> { 7929a795f0cSManish Gupta MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( 7939a795f0cSManish Gupta TestFoldArithExtensionIntoVectorContractPatterns) 7949a795f0cSManish Gupta 7959a795f0cSManish Gupta StringRef getArgument() const final { 7969a795f0cSManish Gupta return "test-fold-arith-extf-into-vector-contract-patterns"; 7979a795f0cSManish Gupta } 7989a795f0cSManish Gupta StringRef getDescription() const final { 7999a795f0cSManish Gupta return "Test patterns that fold arithmetic extension ops into vector " 8009a795f0cSManish Gupta "contract ops"; 8019a795f0cSManish Gupta } 8029a795f0cSManish Gupta 8039a795f0cSManish Gupta void getDependentDialects(DialectRegistry ®istry) const override { 8049a795f0cSManish Gupta registry.insert<arith::ArithDialect, func::FuncDialect, nvgpu::NVGPUDialect, 8059a795f0cSManish Gupta memref::MemRefDialect, scf::SCFDialect, 8069a795f0cSManish Gupta tensor::TensorDialect, vector::VectorDialect>(); 8079a795f0cSManish Gupta } 8089a795f0cSManish Gupta 8099a795f0cSManish Gupta void runOnOperation() override { 8109a795f0cSManish Gupta RewritePatternSet patterns(&getContext()); 8119a795f0cSManish Gupta populateFoldArithExtensionPatterns(patterns); 812*09dfc571SJacques Pienaar (void)applyPatternsGreedily(getOperation(), std::move(patterns)); 8139a795f0cSManish Gupta } 8149a795f0cSManish Gupta }; 815f643eec8SHsiangkai Wang 816f643eec8SHsiangkai Wang struct TestVectorEmulateMaskedLoadStore final 817f643eec8SHsiangkai Wang : public PassWrapper<TestVectorEmulateMaskedLoadStore, 818f643eec8SHsiangkai Wang OperationPass<func::FuncOp>> { 819f643eec8SHsiangkai Wang MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorEmulateMaskedLoadStore) 820f643eec8SHsiangkai Wang 821f643eec8SHsiangkai Wang StringRef getArgument() const override { 822f643eec8SHsiangkai Wang return "test-vector-emulate-masked-load-store"; 823f643eec8SHsiangkai Wang } 824f643eec8SHsiangkai Wang StringRef getDescription() const override { 825f643eec8SHsiangkai Wang return "Test patterns that emulate the maskedload/maskedstore op by " 826f643eec8SHsiangkai Wang " memref.load/store and scf.if"; 827f643eec8SHsiangkai Wang } 828f643eec8SHsiangkai Wang void getDependentDialects(DialectRegistry ®istry) const override { 829f643eec8SHsiangkai Wang registry 830f643eec8SHsiangkai Wang .insert<arith::ArithDialect, func::FuncDialect, memref::MemRefDialect, 831f643eec8SHsiangkai Wang scf::SCFDialect, vector::VectorDialect>(); 832f643eec8SHsiangkai Wang } 833f643eec8SHsiangkai Wang 834f643eec8SHsiangkai Wang void runOnOperation() override { 835f643eec8SHsiangkai Wang RewritePatternSet patterns(&getContext()); 836f643eec8SHsiangkai Wang populateVectorMaskedLoadStoreEmulationPatterns(patterns); 837*09dfc571SJacques Pienaar (void)applyPatternsGreedily(getOperation(), std::move(patterns)); 838f643eec8SHsiangkai Wang } 839f643eec8SHsiangkai Wang }; 84035ef3994SIvan Butygin 84135ef3994SIvan Butygin struct TestVectorLinearize final 84235ef3994SIvan Butygin : public PassWrapper<TestVectorLinearize, OperationPass<>> { 84335ef3994SIvan Butygin MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorLinearize) 84435ef3994SIvan Butygin 8456f5c4f2eSBalaji V. Iyer TestVectorLinearize() = default; 8466f5c4f2eSBalaji V. Iyer TestVectorLinearize(const TestVectorLinearize &pass) : PassWrapper(pass) {} 8476f5c4f2eSBalaji V. Iyer 84835ef3994SIvan Butygin StringRef getArgument() const override { return "test-vector-linearize"; } 84935ef3994SIvan Butygin StringRef getDescription() const override { 85035ef3994SIvan Butygin return "Linearizes ND vectors for N >= 2 into 1D vectors"; 85135ef3994SIvan Butygin } 85235ef3994SIvan Butygin void getDependentDialects(DialectRegistry ®istry) const override { 85335ef3994SIvan Butygin registry.insert<vector::VectorDialect>(); 85435ef3994SIvan Butygin } 85535ef3994SIvan Butygin 8566f5c4f2eSBalaji V. Iyer Option<unsigned> targetVectorBitwidth{ 8576f5c4f2eSBalaji V. Iyer *this, "target-vector-bitwidth", 8586f5c4f2eSBalaji V. Iyer llvm::cl::desc( 8596f5c4f2eSBalaji V. Iyer "Minimum vector bitwidth to enable the flattening transformation"), 8606f5c4f2eSBalaji V. Iyer llvm::cl::init(std::numeric_limits<unsigned>::max())}; 86135ef3994SIvan Butygin void runOnOperation() override { 86235ef3994SIvan Butygin auto *context = &getContext(); 86335ef3994SIvan Butygin 86435ef3994SIvan Butygin TypeConverter typeConverter; 86535ef3994SIvan Butygin RewritePatternSet patterns(context); 86635ef3994SIvan Butygin ConversionTarget target(*context); 86735ef3994SIvan Butygin 8686f5c4f2eSBalaji V. Iyer vector::populateVectorLinearizeTypeConversionsAndLegality( 8696f5c4f2eSBalaji V. Iyer typeConverter, patterns, target, targetVectorBitwidth); 870c577f91dSCharitha Saumya vector::populateVectorLinearizeShuffleLikeOpsPatterns( 871c577f91dSCharitha Saumya typeConverter, patterns, target, targetVectorBitwidth); 87235ef3994SIvan Butygin if (failed(applyPartialConversion(getOperation(), target, 87335ef3994SIvan Butygin std::move(patterns)))) 87435ef3994SIvan Butygin return signalPassFailure(); 87535ef3994SIvan Butygin } 87635ef3994SIvan Butygin }; 8779b06e25eSBenjamin Maxwell 8789b06e25eSBenjamin Maxwell struct TestEliminateVectorMasks 8799b06e25eSBenjamin Maxwell : public PassWrapper<TestEliminateVectorMasks, 8809b06e25eSBenjamin Maxwell OperationPass<func::FuncOp>> { 8819b06e25eSBenjamin Maxwell MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestEliminateVectorMasks) 8829b06e25eSBenjamin Maxwell 8839b06e25eSBenjamin Maxwell TestEliminateVectorMasks() = default; 8849b06e25eSBenjamin Maxwell TestEliminateVectorMasks(const TestEliminateVectorMasks &pass) 8859b06e25eSBenjamin Maxwell : PassWrapper(pass) {} 8869b06e25eSBenjamin Maxwell 8879b06e25eSBenjamin Maxwell Option<unsigned> vscaleMin{ 8889b06e25eSBenjamin Maxwell *this, "vscale-min", llvm::cl::desc("Minimum possible value of vscale."), 8899b06e25eSBenjamin Maxwell llvm::cl::init(1)}; 8909b06e25eSBenjamin Maxwell Option<unsigned> vscaleMax{ 8919b06e25eSBenjamin Maxwell *this, "vscale-max", llvm::cl::desc("Maximum possible value of vscale."), 8929b06e25eSBenjamin Maxwell llvm::cl::init(16)}; 8939b06e25eSBenjamin Maxwell 8949b06e25eSBenjamin Maxwell StringRef getArgument() const final { return "test-eliminate-vector-masks"; } 8959b06e25eSBenjamin Maxwell StringRef getDescription() const final { 8969b06e25eSBenjamin Maxwell return "Test eliminating vector masks"; 8979b06e25eSBenjamin Maxwell } 8989b06e25eSBenjamin Maxwell void runOnOperation() override { 8999b06e25eSBenjamin Maxwell IRRewriter rewriter(&getContext()); 9009b06e25eSBenjamin Maxwell eliminateVectorMasks(rewriter, getOperation(), 9019b06e25eSBenjamin Maxwell VscaleRange{vscaleMin, vscaleMax}); 9029b06e25eSBenjamin Maxwell } 9039b06e25eSBenjamin Maxwell }; 904be0a7e9fSMehdi Amini } // namespace 9053fef2d26SRiver Riddle 9063fef2d26SRiver Riddle namespace mlir { 9073fef2d26SRiver Riddle namespace test { 90834ff8573SNicolas Vasilache void registerTestVectorLowerings() { 90934ff8573SNicolas Vasilache PassRegistration<TestVectorToVectorLowering>(); 9103fef2d26SRiver Riddle 911fb7ef637SJakub Kuderski PassRegistration<TestVectorContractionPrepareForMMTLowering>(); 912fb7ef637SJakub Kuderski 913b5e22e6dSMehdi Amini PassRegistration<TestVectorUnrollingPatterns>(); 9143fef2d26SRiver Riddle 915b5e22e6dSMehdi Amini PassRegistration<TestVectorTransferUnrollingPatterns>(); 9163fef2d26SRiver Riddle 9172ec98ffbSMatthias Springer PassRegistration<TestScalarVectorTransferLoweringPatterns>(); 9182ec98ffbSMatthias Springer 919b5e22e6dSMehdi Amini PassRegistration<TestVectorTransferOpt>(); 9203fef2d26SRiver Riddle 921a3dd4e77SAhmed S. Taei PassRegistration<TestVectorTransferCollapseInnerMostContiguousDims>(); 9221d8cc45bSthomasraoux 92342944da5SAndrzej Warzyński PassRegistration<TestVectorSinkPatterns>(); 9244d339ec9SAndrzej Warzynski 9251d8cc45bSthomasraoux PassRegistration<TestVectorReduceToContractPatternsPatterns>(); 9260aea49a7SBenoit Jacob 927d33bad66SJakub Kuderski PassRegistration<TestVectorChainedReductionFoldingPatterns>(); 928d33bad66SJakub Kuderski 92907677113SJakub Kuderski PassRegistration<TestVectorBreakDownReductionPatterns>(); 93007677113SJakub Kuderski 931aba437ceSBenoit Jacob PassRegistration<TestFlattenVectorTransferPatterns>(); 93280e0bf1aSharsh 93380e0bf1aSharsh PassRegistration<TestVectorScanLowering>(); 934d02f10d9SThomas Raoux 935d02f10d9SThomas Raoux PassRegistration<TestVectorDistribution>(); 93639c80656SLei Zhang 93739c80656SLei Zhang PassRegistration<TestVectorExtractStridedSliceLowering>(); 938de13eedaSNicolas Vasilache 939650f04feSQuinn Dawkins PassRegistration<TestVectorBreakDownBitCast>(); 940650f04feSQuinn Dawkins 941de13eedaSNicolas Vasilache PassRegistration<TestCreateVectorBroadcast>(); 942f80a976aSJakub Kuderski 943f80a976aSJakub Kuderski PassRegistration<TestVectorGatherLowering>(); 944e000b62aSLei Zhang 9459a795f0cSManish Gupta PassRegistration<TestFoldArithExtensionIntoVectorContractPatterns>(); 946f643eec8SHsiangkai Wang 947f643eec8SHsiangkai Wang PassRegistration<TestVectorEmulateMaskedLoadStore>(); 94835ef3994SIvan Butygin 94935ef3994SIvan Butygin PassRegistration<TestVectorLinearize>(); 9509b06e25eSBenjamin Maxwell 9519b06e25eSBenjamin Maxwell PassRegistration<TestEliminateVectorMasks>(); 9523fef2d26SRiver Riddle } 9533fef2d26SRiver Riddle } // namespace test 9543fef2d26SRiver Riddle } // namespace mlir 955