1 //===- VectorTransformOps.cpp - Implementation of Vector transform ops ----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.h" 10 #include "mlir/Conversion/VectorToSCF/VectorToSCF.h" 11 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 12 #include "mlir/Dialect/Transform/IR/TransformDialect.h" 13 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" 14 #include "mlir/Dialect/Vector/IR/VectorOps.h" 15 #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" 16 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h" 17 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h" 18 #include "mlir/Dialect/X86Vector/Transforms.h" 19 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 20 21 using namespace mlir; 22 using namespace mlir::vector; 23 using namespace mlir::transform; 24 25 //===----------------------------------------------------------------------===// 26 // ApplyRankReducingSubviewPatternsOp 27 //===----------------------------------------------------------------------===// 28 29 void transform::ApplyRankReducingSubviewPatternsOp::populatePatterns( 30 RewritePatternSet &patterns) { 31 vector::populateVectorTransferDropUnitDimsPatterns(patterns); 32 } 33 34 //===----------------------------------------------------------------------===// 35 // ApplyTransferPermutationPatternsOp 36 //===----------------------------------------------------------------------===// 37 38 void transform::ApplyTransferPermutationPatternsOp::populatePatterns( 39 RewritePatternSet &patterns) { 40 vector::populateVectorTransferPermutationMapLoweringPatterns(patterns); 41 } 42 43 //===----------------------------------------------------------------------===// 44 // LowerBroadcastOp 45 //===----------------------------------------------------------------------===// 46 47 void transform::LowerBroadcastOp::populatePatterns( 48 RewritePatternSet &patterns) { 49 populateVectorBroadcastLoweringPatterns(patterns); 50 } 51 52 //===----------------------------------------------------------------------===// 53 // LowerContractionOp 54 //===----------------------------------------------------------------------===// 55 56 void transform::LowerContractionOp::populatePatterns( 57 RewritePatternSet &patterns) { 58 vector::VectorTransformsOptions vectorTransformOptions; 59 vectorTransformOptions.setVectorTransformsOptions(getLoweringStrategy()); 60 populateVectorContractLoweringPatterns(patterns, vectorTransformOptions, 61 /*benefit=*/1, 62 /*disableOuterProductLowering=*/true); 63 } 64 65 //===----------------------------------------------------------------------===// 66 // LowerMasksOp 67 //===----------------------------------------------------------------------===// 68 69 void transform::LowerMasksOp::populatePatterns(RewritePatternSet &patterns) { 70 populateVectorMaskOpLoweringPatterns(patterns); 71 } 72 73 //===----------------------------------------------------------------------===// 74 // LowerMaskedTransfersOp 75 //===----------------------------------------------------------------------===// 76 77 void transform::LowerMaskedTransfersOp::populatePatterns( 78 RewritePatternSet &patterns) { 79 populateVectorMaskLoweringPatternsForSideEffectingOps(patterns); 80 } 81 82 //===----------------------------------------------------------------------===// 83 // MaterializeMasksOp 84 //===----------------------------------------------------------------------===// 85 86 void transform::MaterializeMasksOp::populatePatterns(RewritePatternSet &patterns) { 87 populateVectorMaskMaterializationPatterns(patterns, 88 /*force32BitVectorIndices=*/false); 89 } 90 91 //===----------------------------------------------------------------------===// 92 // LowerMultiReductionOp 93 //===----------------------------------------------------------------------===// 94 95 void transform::LowerMultiReductionOp::populatePatterns( 96 RewritePatternSet &patterns) { 97 vector::VectorTransformsOptions vectorTransformOptions; 98 vectorTransformOptions.setVectorMultiReductionLowering(getLoweringStrategy()); 99 vector::populateVectorMultiReductionLoweringPatterns( 100 patterns, vectorTransformOptions.vectorMultiReductionLowering); 101 } 102 103 //===----------------------------------------------------------------------===// 104 // LowerOuterProductOp 105 //===----------------------------------------------------------------------===// 106 107 void transform::LowerOuterProductOp::populatePatterns( 108 RewritePatternSet &patterns) { 109 populateVectorOuterProductLoweringPatterns(patterns); 110 } 111 112 //===----------------------------------------------------------------------===// 113 // LowerShapeCastOp 114 //===----------------------------------------------------------------------===// 115 116 void transform::LowerShapeCastOp::populatePatterns( 117 RewritePatternSet &patterns) { 118 vector::populateVectorShapeCastLoweringPatterns(patterns); 119 } 120 121 //===----------------------------------------------------------------------===// 122 // LowerTransferOp 123 //===----------------------------------------------------------------------===// 124 125 void transform::LowerTransferOp::populatePatterns(RewritePatternSet &patterns) { 126 vector::populateVectorTransferLoweringPatterns(patterns, 127 getMaxTransferRank()); 128 } 129 130 //===----------------------------------------------------------------------===// 131 // LowerTransposeOp 132 //===----------------------------------------------------------------------===// 133 134 void transform::LowerTransposeOp::populatePatterns( 135 RewritePatternSet &patterns) { 136 vector::populateVectorTransposeLoweringPatterns( 137 patterns, vector::VectorTransformsOptions().setVectorTransposeLowering( 138 getLoweringStrategy())); 139 if (getAvx2LoweringStrategy()) { 140 auto avx2LoweringOptions = 141 x86vector::avx2::LoweringOptions().setTransposeOptions( 142 x86vector::avx2::TransposeLoweringOptions() 143 .lower4x8xf32(true) 144 .lower8x8xf32(true)); 145 x86vector::avx2::populateSpecializedTransposeLoweringPatterns( 146 patterns, avx2LoweringOptions, /*benefit=*/10); 147 } 148 } 149 150 //===----------------------------------------------------------------------===// 151 // SplitTransferFullPartialOp 152 //===----------------------------------------------------------------------===// 153 154 void transform::SplitTransferFullPartialOp::populatePatterns( 155 RewritePatternSet &patterns) { 156 vector::VectorTransformsOptions vectorTransformOptions; 157 vectorTransformOptions.setVectorTransferSplit(getSplitTransferStrategy()); 158 populateVectorTransferFullPartialPatterns(patterns, vectorTransformOptions); 159 } 160 161 //===----------------------------------------------------------------------===// 162 // TransferToScfOp 163 //===----------------------------------------------------------------------===// 164 165 void transform::TransferToScfOp::populatePatterns(RewritePatternSet &patterns) { 166 VectorTransferToSCFOptions vectorTransferToSCFOptions = 167 VectorTransferToSCFOptions() 168 .enableFullUnroll(getFullUnroll()) 169 .setTargetRank(getMaxTransferRank()); 170 populateVectorToSCFConversionPatterns(patterns, vectorTransferToSCFOptions); 171 } 172 173 //===----------------------------------------------------------------------===// 174 // Transform op registration 175 //===----------------------------------------------------------------------===// 176 177 namespace { 178 /// Registers new ops and declares PDL as dependent dialect since the additional 179 /// ops are using PDL types for operands and results. 180 class VectorTransformDialectExtension 181 : public transform::TransformDialectExtension< 182 VectorTransformDialectExtension> { 183 public: 184 VectorTransformDialectExtension() { 185 declareGeneratedDialect<vector::VectorDialect>(); 186 declareGeneratedDialect<LLVM::LLVMDialect>(); 187 registerTransformOps< 188 #define GET_OP_LIST 189 #include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.cpp.inc" 190 >(); 191 } 192 }; 193 } // namespace 194 195 #define GET_OP_CLASSES 196 #include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.cpp.inc" 197 198 void mlir::vector::registerTransformDialectExtension( 199 DialectRegistry ®istry) { 200 registry.addExtensions<VectorTransformDialectExtension>(); 201 } 202