1 //===- VectorTransformOps.cpp - Implementation of Vector transform ops ----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.h" 10 #include "mlir/Conversion/VectorToSCF/VectorToSCF.h" 11 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 12 #include "mlir/Dialect/Transform/IR/TransformDialect.h" 13 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" 14 #include "mlir/Dialect/Transform/IR/TransformOps.h" 15 #include "mlir/Dialect/Vector/IR/VectorOps.h" 16 #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" 17 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h" 18 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h" 19 #include "mlir/Dialect/X86Vector/Transforms.h" 20 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 21 22 using namespace mlir; 23 using namespace mlir::vector; 24 using namespace mlir::transform; 25 26 //===----------------------------------------------------------------------===// 27 // ApplyRankReducingSubviewPatternsOp 28 //===----------------------------------------------------------------------===// 29 30 void transform::ApplyRankReducingSubviewPatternsOp::populatePatterns( 31 RewritePatternSet &patterns) { 32 vector::populateVectorTransferDropUnitDimsPatterns(patterns); 33 } 34 35 //===----------------------------------------------------------------------===// 36 // ApplyTransferPermutationPatternsOp 37 //===----------------------------------------------------------------------===// 38 39 void transform::ApplyTransferPermutationPatternsOp::populatePatterns( 40 RewritePatternSet &patterns) { 41 vector::populateVectorTransferPermutationMapLoweringPatterns(patterns); 42 } 43 44 //===----------------------------------------------------------------------===// 45 // LowerBroadcastOp 46 //===----------------------------------------------------------------------===// 47 48 void transform::LowerBroadcastOp::populatePatterns( 49 RewritePatternSet &patterns) { 50 populateVectorBroadcastLoweringPatterns(patterns); 51 } 52 53 //===----------------------------------------------------------------------===// 54 // LowerContractionOp 55 //===----------------------------------------------------------------------===// 56 57 void transform::LowerContractionOp::populatePatterns( 58 RewritePatternSet &patterns) { 59 vector::VectorTransformsOptions vectorTransformOptions; 60 vectorTransformOptions.setVectorTransformsOptions(getLoweringStrategy()); 61 populateVectorContractLoweringPatterns(patterns, vectorTransformOptions, 62 /*benefit=*/1, 63 /*disableOuterProductLowering=*/true); 64 } 65 66 //===----------------------------------------------------------------------===// 67 // LowerMasksOp 68 //===----------------------------------------------------------------------===// 69 70 void transform::LowerMasksOp::populatePatterns(RewritePatternSet &patterns) { 71 populateVectorMaskOpLoweringPatterns(patterns); 72 } 73 74 //===----------------------------------------------------------------------===// 75 // LowerMaskedTransfersOp 76 //===----------------------------------------------------------------------===// 77 78 void transform::LowerMaskedTransfersOp::populatePatterns( 79 RewritePatternSet &patterns) { 80 populateVectorMaskLoweringPatternsForSideEffectingOps(patterns); 81 } 82 83 //===----------------------------------------------------------------------===// 84 // MaterializeMasksOp 85 //===----------------------------------------------------------------------===// 86 87 void transform::MaterializeMasksOp::populatePatterns(RewritePatternSet &patterns) { 88 populateVectorMaskMaterializationPatterns(patterns, 89 /*force32BitVectorIndices=*/false); 90 } 91 92 //===----------------------------------------------------------------------===// 93 // LowerMultiReductionOp 94 //===----------------------------------------------------------------------===// 95 96 void transform::LowerMultiReductionOp::populatePatterns( 97 RewritePatternSet &patterns) { 98 vector::VectorTransformsOptions vectorTransformOptions; 99 vectorTransformOptions.setVectorMultiReductionLowering(getLoweringStrategy()); 100 vector::populateVectorMultiReductionLoweringPatterns( 101 patterns, vectorTransformOptions.vectorMultiReductionLowering); 102 } 103 104 //===----------------------------------------------------------------------===// 105 // LowerOuterProductOp 106 //===----------------------------------------------------------------------===// 107 108 void transform::LowerOuterProductOp::populatePatterns( 109 RewritePatternSet &patterns) { 110 populateVectorOuterProductLoweringPatterns(patterns); 111 } 112 113 //===----------------------------------------------------------------------===// 114 // LowerShapeCastOp 115 //===----------------------------------------------------------------------===// 116 117 void transform::LowerShapeCastOp::populatePatterns( 118 RewritePatternSet &patterns) { 119 vector::populateVectorShapeCastLoweringPatterns(patterns); 120 } 121 122 //===----------------------------------------------------------------------===// 123 // LowerTransferOp 124 //===----------------------------------------------------------------------===// 125 126 void transform::LowerTransferOp::populatePatterns(RewritePatternSet &patterns) { 127 vector::populateVectorTransferLoweringPatterns(patterns, 128 getMaxTransferRank()); 129 } 130 131 //===----------------------------------------------------------------------===// 132 // LowerTransposeOp 133 //===----------------------------------------------------------------------===// 134 135 void transform::LowerTransposeOp::populatePatterns( 136 RewritePatternSet &patterns) { 137 vector::populateVectorTransposeLoweringPatterns( 138 patterns, vector::VectorTransformsOptions().setVectorTransposeLowering( 139 getLoweringStrategy())); 140 if (getAvx2LoweringStrategy()) { 141 auto avx2LoweringOptions = 142 x86vector::avx2::LoweringOptions().setTransposeOptions( 143 x86vector::avx2::TransposeLoweringOptions() 144 .lower4x8xf32(true) 145 .lower8x8xf32(true)); 146 x86vector::avx2::populateSpecializedTransposeLoweringPatterns( 147 patterns, avx2LoweringOptions, /*benefit=*/10); 148 } 149 } 150 151 //===----------------------------------------------------------------------===// 152 // SplitTransferFullPartialOp 153 //===----------------------------------------------------------------------===// 154 155 void transform::SplitTransferFullPartialOp::populatePatterns( 156 RewritePatternSet &patterns) { 157 vector::VectorTransformsOptions vectorTransformOptions; 158 vectorTransformOptions.setVectorTransferSplit(getSplitTransferStrategy()); 159 populateVectorTransferFullPartialPatterns(patterns, vectorTransformOptions); 160 } 161 162 //===----------------------------------------------------------------------===// 163 // TransferToScfOp 164 //===----------------------------------------------------------------------===// 165 166 void transform::TransferToScfOp::populatePatterns(RewritePatternSet &patterns) { 167 VectorTransferToSCFOptions vectorTransferToSCFOptions = 168 VectorTransferToSCFOptions() 169 .enableFullUnroll(getFullUnroll()) 170 .setTargetRank(getMaxTransferRank()); 171 populateVectorToSCFConversionPatterns(patterns, vectorTransferToSCFOptions); 172 } 173 174 //===----------------------------------------------------------------------===// 175 // Transform op registration 176 //===----------------------------------------------------------------------===// 177 178 namespace { 179 /// Registers new ops and declares PDL as dependent dialect since the additional 180 /// ops are using PDL types for operands and results. 181 class VectorTransformDialectExtension 182 : public transform::TransformDialectExtension< 183 VectorTransformDialectExtension> { 184 public: 185 VectorTransformDialectExtension() { 186 declareGeneratedDialect<vector::VectorDialect>(); 187 declareGeneratedDialect<LLVM::LLVMDialect>(); 188 registerTransformOps< 189 #define GET_OP_LIST 190 #include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.cpp.inc" 191 >(); 192 193 addDialectDataInitializer<transform::PatternRegistry>( 194 [&](transform::PatternRegistry ®istry) { 195 registry.registerPatterns("vector.outer_product_lowering", 196 populateVectorOuterProductLoweringPatterns); 197 registry.registerPatterns("vector.broadcast_lowering", 198 populateVectorBroadcastLoweringPatterns); 199 registry.registerPatterns("vector.mask_op_lowering", 200 populateVectorMaskOpLoweringPatterns); 201 registry.registerPatterns("vector.shape_cast_lowering", 202 populateVectorShapeCastLoweringPatterns); 203 registry.registerPatterns( 204 "vector.transfer_lowering", 205 [&](RewritePatternSet &set, PatternBenefit benefit) { 206 return populateVectorTransferLoweringPatterns( 207 set, /*maxTransferRank=*/std::nullopt, benefit); 208 }); 209 registry.registerPatterns( 210 "vector.transfer_permutation_map_lowering", 211 populateVectorTransferPermutationMapLoweringPatterns); 212 registry.registerPatterns("vector.scan_lowering", 213 populateVectorScanLoweringPatterns); 214 registry.registerPatterns("vector.vector_gather_lowering", 215 populateVectorGatherLoweringPatterns); 216 registry.registerPatterns( 217 "vector.mask_lowering_for_side_effecting_ops", 218 populateVectorMaskLoweringPatternsForSideEffectingOps); 219 }); 220 } 221 }; 222 } // namespace 223 224 #define GET_OP_CLASSES 225 #include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.cpp.inc" 226 227 void mlir::vector::registerTransformDialectExtension( 228 DialectRegistry ®istry) { 229 registry.addExtensions<VectorTransformDialectExtension>(); 230 } 231