1 //===- VectorTransformOps.cpp - Implementation of Vector transform ops ----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.h" 10 #include "mlir/Conversion/VectorToSCF/VectorToSCF.h" 11 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 12 #include "mlir/Dialect/Transform/IR/TransformDialect.h" 13 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" 14 #include "mlir/Dialect/Vector/IR/VectorOps.h" 15 #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" 16 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h" 17 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h" 18 #include "mlir/Dialect/X86Vector/Transforms.h" 19 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 20 21 using namespace mlir; 22 using namespace mlir::vector; 23 using namespace mlir::transform; 24 25 //===----------------------------------------------------------------------===// 26 // ApplyRankReducingSubviewPatternsOp 27 //===----------------------------------------------------------------------===// 28 29 void transform::ApplyRankReducingSubviewPatternsOp::populatePatterns( 30 RewritePatternSet &patterns) { 31 vector::populateVectorTransferDropUnitDimsPatterns(patterns); 32 } 33 34 //===----------------------------------------------------------------------===// 35 // ApplyTransferPermutationPatternsOp 36 //===----------------------------------------------------------------------===// 37 38 void transform::ApplyTransferPermutationPatternsOp::populatePatterns( 39 RewritePatternSet &patterns) { 40 vector::populateVectorTransferPermutationMapLoweringPatterns(patterns); 41 } 42 43 //===----------------------------------------------------------------------===// 44 // LowerBroadcastOp 45 //===----------------------------------------------------------------------===// 46 47 void transform::LowerBroadcastOp::populatePatterns( 48 RewritePatternSet &patterns) { 49 populateVectorBroadcastLoweringPatterns(patterns); 50 } 51 52 //===----------------------------------------------------------------------===// 53 // LowerContractionOp 54 //===----------------------------------------------------------------------===// 55 56 void transform::LowerContractionOp::populatePatterns( 57 RewritePatternSet &patterns) { 58 vector::VectorTransformsOptions vectorTransformOptions; 59 vectorTransformOptions.setVectorTransformsOptions(getLoweringStrategy()); 60 populateVectorContractLoweringPatterns(patterns, vectorTransformOptions, 61 /*benefit=*/1, 62 /*disableOuterProductLowering=*/true); 63 } 64 65 //===----------------------------------------------------------------------===// 66 // LowerMasksOp 67 //===----------------------------------------------------------------------===// 68 69 void transform::LowerMasksOp::populatePatterns(RewritePatternSet &patterns) { 70 populateVectorMaskOpLoweringPatterns(patterns); 71 } 72 73 //===----------------------------------------------------------------------===// 74 // LowerMaskedTransfersOp 75 //===----------------------------------------------------------------------===// 76 77 void transform::LowerMaskedTransfersOp::populatePatterns( 78 RewritePatternSet &patterns) { 79 populateVectorMaskLoweringPatternsForSideEffectingOps(patterns); 80 } 81 82 //===----------------------------------------------------------------------===// 83 // LowerMultiReductionOp 84 //===----------------------------------------------------------------------===// 85 86 void transform::LowerMultiReductionOp::populatePatterns( 87 RewritePatternSet &patterns) { 88 vector::VectorTransformsOptions vectorTransformOptions; 89 vectorTransformOptions.setVectorMultiReductionLowering(getLoweringStrategy()); 90 vector::populateVectorMultiReductionLoweringPatterns( 91 patterns, vectorTransformOptions.vectorMultiReductionLowering); 92 } 93 94 //===----------------------------------------------------------------------===// 95 // LowerOuterProductOp 96 //===----------------------------------------------------------------------===// 97 98 void transform::LowerOuterProductOp::populatePatterns( 99 RewritePatternSet &patterns) { 100 populateVectorOuterProductLoweringPatterns(patterns); 101 } 102 103 //===----------------------------------------------------------------------===// 104 // LowerShapeCastOp 105 //===----------------------------------------------------------------------===// 106 107 void transform::LowerShapeCastOp::populatePatterns( 108 RewritePatternSet &patterns) { 109 vector::populateVectorShapeCastLoweringPatterns(patterns); 110 } 111 112 //===----------------------------------------------------------------------===// 113 // LowerTransferOp 114 //===----------------------------------------------------------------------===// 115 116 void transform::LowerTransferOp::populatePatterns(RewritePatternSet &patterns) { 117 vector::populateVectorTransferLoweringPatterns(patterns, 118 getMaxTransferRank()); 119 } 120 121 //===----------------------------------------------------------------------===// 122 // LowerTransposeOp 123 //===----------------------------------------------------------------------===// 124 125 void transform::LowerTransposeOp::populatePatterns( 126 RewritePatternSet &patterns) { 127 vector::populateVectorTransposeLoweringPatterns( 128 patterns, vector::VectorTransformsOptions().setVectorTransposeLowering( 129 getLoweringStrategy())); 130 if (getAvx2LoweringStrategy()) { 131 auto avx2LoweringOptions = 132 x86vector::avx2::LoweringOptions().setTransposeOptions( 133 x86vector::avx2::TransposeLoweringOptions() 134 .lower4x8xf32(true) 135 .lower8x8xf32(true)); 136 x86vector::avx2::populateSpecializedTransposeLoweringPatterns( 137 patterns, avx2LoweringOptions, /*benefit=*/10); 138 } 139 } 140 141 //===----------------------------------------------------------------------===// 142 // SplitTransferFullPartialOp 143 //===----------------------------------------------------------------------===// 144 145 void transform::SplitTransferFullPartialOp::populatePatterns( 146 RewritePatternSet &patterns) { 147 vector::VectorTransformsOptions vectorTransformOptions; 148 vectorTransformOptions.setVectorTransferSplit(getSplitTransferStrategy()); 149 populateVectorTransferFullPartialPatterns(patterns, vectorTransformOptions); 150 } 151 152 //===----------------------------------------------------------------------===// 153 // TransferToScfOp 154 //===----------------------------------------------------------------------===// 155 156 void transform::TransferToScfOp::populatePatterns(RewritePatternSet &patterns) { 157 VectorTransferToSCFOptions vectorTransferToSCFOptions = 158 VectorTransferToSCFOptions() 159 .enableFullUnroll(getFullUnroll()) 160 .setTargetRank(getMaxTransferRank()); 161 populateVectorToSCFConversionPatterns(patterns, vectorTransferToSCFOptions); 162 } 163 164 //===----------------------------------------------------------------------===// 165 // Transform op registration 166 //===----------------------------------------------------------------------===// 167 168 namespace { 169 /// Registers new ops and declares PDL as dependent dialect since the additional 170 /// ops are using PDL types for operands and results. 171 class VectorTransformDialectExtension 172 : public transform::TransformDialectExtension< 173 VectorTransformDialectExtension> { 174 public: 175 VectorTransformDialectExtension() { 176 declareGeneratedDialect<vector::VectorDialect>(); 177 declareGeneratedDialect<LLVM::LLVMDialect>(); 178 registerTransformOps< 179 #define GET_OP_LIST 180 #include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.cpp.inc" 181 >(); 182 } 183 }; 184 } // namespace 185 186 #define GET_OP_CLASSES 187 #include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.cpp.inc" 188 189 void mlir::vector::registerTransformDialectExtension( 190 DialectRegistry ®istry) { 191 registry.addExtensions<VectorTransformDialectExtension>(); 192 } 193