1 //===- VectorTransformOps.cpp - Implementation of Vector transform ops ----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.h" 10 #include "mlir/Conversion/VectorToSCF/VectorToSCF.h" 11 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 12 #include "mlir/Dialect/Transform/IR/TransformDialect.h" 13 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" 14 #include "mlir/Dialect/Transform/IR/TransformOps.h" 15 #include "mlir/Dialect/Vector/IR/VectorOps.h" 16 #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" 17 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h" 18 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h" 19 #include "mlir/Dialect/X86Vector/Transforms.h" 20 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 21 22 using namespace mlir; 23 using namespace mlir::vector; 24 using namespace mlir::transform; 25 26 //===----------------------------------------------------------------------===// 27 // Apply...PatternsOp 28 //===----------------------------------------------------------------------===// 29 30 void transform::ApplyCastAwayVectorLeadingOneDimPatternsOp::populatePatterns( 31 RewritePatternSet &patterns) { 32 vector::populateCastAwayVectorLeadingOneDimPatterns(patterns); 33 } 34 35 void transform::ApplyFoldArithExtensionPatternsOp::populatePatterns( 36 RewritePatternSet &patterns) { 37 vector::populateFoldArithExtensionPatterns(patterns); 38 } 39 40 void transform::ApplyVectorReductionToContractPatternsOp::populatePatterns( 41 RewritePatternSet &patterns) { 42 vector::populateVectorReductionToContractPatterns(patterns); 43 } 44 45 void transform::ApplyRankReducingSubviewPatternsOp::populatePatterns( 46 RewritePatternSet &patterns) { 47 vector::populateVectorTransferDropUnitDimsPatterns(patterns); 48 } 49 50 void transform::ApplyTransferPermutationPatternsOp::populatePatterns( 51 RewritePatternSet &patterns) { 52 vector::populateVectorTransferPermutationMapLoweringPatterns(patterns); 53 } 54 55 void transform::ApplyLowerBroadcastPatternsOp::populatePatterns( 56 RewritePatternSet &patterns) { 57 populateVectorBroadcastLoweringPatterns(patterns); 58 } 59 60 void transform::ApplyLowerContractionPatternsOp::populatePatterns( 61 RewritePatternSet &patterns) { 62 vector::VectorTransformsOptions vectorTransformOptions; 63 vectorTransformOptions.setVectorTransformsOptions(getLoweringStrategy()); 64 populateVectorContractLoweringPatterns(patterns, vectorTransformOptions, 65 /*benefit=*/1, 66 /*disableOuterProductLowering=*/true); 67 } 68 69 void transform::ApplyLowerMasksPatternsOp::populatePatterns( 70 RewritePatternSet &patterns) { 71 populateVectorMaskOpLoweringPatterns(patterns); 72 } 73 74 void transform::ApplyLowerMaskedTransfersPatternsOp::populatePatterns( 75 RewritePatternSet &patterns) { 76 populateVectorMaskLoweringPatternsForSideEffectingOps(patterns); 77 } 78 79 void transform::ApplyMaterializeMasksPatternsOp::populatePatterns( 80 RewritePatternSet &patterns) { 81 populateVectorMaskMaterializationPatterns(patterns, 82 /*force32BitVectorIndices=*/false); 83 } 84 85 void transform::ApplyLowerMultiReductionPatternsOp::populatePatterns( 86 RewritePatternSet &patterns) { 87 vector::VectorTransformsOptions vectorTransformOptions; 88 vectorTransformOptions.setVectorMultiReductionLowering(getLoweringStrategy()); 89 vector::populateVectorMultiReductionLoweringPatterns( 90 patterns, vectorTransformOptions.vectorMultiReductionLowering); 91 } 92 93 void transform::ApplyLowerOuterProductPatternsOp::populatePatterns( 94 RewritePatternSet &patterns) { 95 populateVectorOuterProductLoweringPatterns(patterns); 96 } 97 98 void transform::ApplyLowerGatherPatternsOp::populatePatterns( 99 RewritePatternSet &patterns) { 100 vector::populateVectorGatherLoweringPatterns(patterns); 101 } 102 103 void transform::ApplyLowerScanPatternsOp::populatePatterns( 104 RewritePatternSet &patterns) { 105 vector::populateVectorScanLoweringPatterns(patterns); 106 } 107 108 void transform::ApplyLowerShapeCastPatternsOp::populatePatterns( 109 RewritePatternSet &patterns) { 110 vector::populateVectorShapeCastLoweringPatterns(patterns); 111 } 112 113 void transform::ApplyLowerTransferPatternsOp::populatePatterns( 114 RewritePatternSet &patterns) { 115 vector::populateVectorTransferLoweringPatterns(patterns, 116 getMaxTransferRank()); 117 } 118 119 void transform::ApplyLowerTransposePatternsOp::populatePatterns( 120 RewritePatternSet &patterns) { 121 vector::populateVectorTransposeLoweringPatterns( 122 patterns, vector::VectorTransformsOptions().setVectorTransposeLowering( 123 getLoweringStrategy())); 124 if (getAvx2LoweringStrategy()) { 125 auto avx2LoweringOptions = 126 x86vector::avx2::LoweringOptions().setTransposeOptions( 127 x86vector::avx2::TransposeLoweringOptions() 128 .lower4x8xf32(true) 129 .lower8x8xf32(true)); 130 x86vector::avx2::populateSpecializedTransposeLoweringPatterns( 131 patterns, avx2LoweringOptions, /*benefit=*/10); 132 } 133 } 134 135 void transform::ApplySplitTransferFullPartialPatternsOp::populatePatterns( 136 RewritePatternSet &patterns) { 137 vector::VectorTransformsOptions vectorTransformOptions; 138 vectorTransformOptions.setVectorTransferSplit(getSplitTransferStrategy()); 139 populateVectorTransferFullPartialPatterns(patterns, vectorTransformOptions); 140 } 141 142 void transform::ApplyTransferToScfPatternsOp::populatePatterns( 143 RewritePatternSet &patterns) { 144 VectorTransferToSCFOptions vectorTransferToSCFOptions = 145 VectorTransferToSCFOptions() 146 .enableFullUnroll(getFullUnroll()) 147 .setTargetRank(getMaxTransferRank()); 148 populateVectorToSCFConversionPatterns(patterns, vectorTransferToSCFOptions); 149 } 150 151 //===----------------------------------------------------------------------===// 152 // Transform op registration 153 //===----------------------------------------------------------------------===// 154 155 namespace { 156 /// Registers new ops and declares PDL as dependent dialect since the additional 157 /// ops are using PDL types for operands and results. 158 class VectorTransformDialectExtension 159 : public transform::TransformDialectExtension< 160 VectorTransformDialectExtension> { 161 public: 162 VectorTransformDialectExtension() { 163 declareGeneratedDialect<vector::VectorDialect>(); 164 declareGeneratedDialect<LLVM::LLVMDialect>(); 165 registerTransformOps< 166 #define GET_OP_LIST 167 #include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.cpp.inc" 168 >(); 169 } 170 }; 171 } // namespace 172 173 #define GET_OP_CLASSES 174 #include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.cpp.inc" 175 176 void mlir::vector::registerTransformDialectExtension( 177 DialectRegistry ®istry) { 178 registry.addExtensions<VectorTransformDialectExtension>(); 179 } 180