1 //===- VectorTransformOps.cpp - Implementation of Vector transform ops ----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.h" 10 #include "mlir/Conversion/VectorToSCF/VectorToSCF.h" 11 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 12 #include "mlir/Dialect/Transform/IR/TransformDialect.h" 13 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" 14 #include "mlir/Dialect/Transform/IR/TransformOps.h" 15 #include "mlir/Dialect/Vector/IR/VectorOps.h" 16 #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" 17 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h" 18 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h" 19 #include "mlir/Dialect/X86Vector/Transforms.h" 20 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 21 22 using namespace mlir; 23 using namespace mlir::vector; 24 using namespace mlir::transform; 25 26 //===----------------------------------------------------------------------===// 27 // Apply...PatternsOp 28 //===----------------------------------------------------------------------===// 29 30 void transform::ApplyCastAwayVectorLeadingOneDimPatternsOp::populatePatterns( 31 RewritePatternSet &patterns) { 32 vector::populateCastAwayVectorLeadingOneDimPatterns(patterns); 33 } 34 35 void transform::ApplyFoldArithExtensionPatternsOp::populatePatterns( 36 RewritePatternSet &patterns) { 37 vector::populateFoldArithExtensionPatterns(patterns); 38 } 39 40 void transform::ApplyRankReducingSubviewPatternsOp::populatePatterns( 41 RewritePatternSet &patterns) { 42 vector::populateVectorTransferDropUnitDimsPatterns(patterns); 43 } 44 45 void transform::ApplyTransferPermutationPatternsOp::populatePatterns( 46 RewritePatternSet &patterns) { 47 vector::populateVectorTransferPermutationMapLoweringPatterns(patterns); 48 } 49 50 void transform::ApplyLowerBroadcastPatternsOp::populatePatterns( 51 RewritePatternSet &patterns) { 52 populateVectorBroadcastLoweringPatterns(patterns); 53 } 54 55 void transform::ApplyLowerContractionPatternsOp::populatePatterns( 56 RewritePatternSet &patterns) { 57 vector::VectorTransformsOptions vectorTransformOptions; 58 vectorTransformOptions.setVectorTransformsOptions(getLoweringStrategy()); 59 populateVectorContractLoweringPatterns(patterns, vectorTransformOptions, 60 /*benefit=*/1, 61 /*disableOuterProductLowering=*/true); 62 } 63 64 void transform::ApplyLowerMasksPatternsOp::populatePatterns( 65 RewritePatternSet &patterns) { 66 populateVectorMaskOpLoweringPatterns(patterns); 67 } 68 69 void transform::ApplyLowerMaskedTransfersPatternsOp::populatePatterns( 70 RewritePatternSet &patterns) { 71 populateVectorMaskLoweringPatternsForSideEffectingOps(patterns); 72 } 73 74 void transform::ApplyMaterializeMasksPatternsOp::populatePatterns( 75 RewritePatternSet &patterns) { 76 populateVectorMaskMaterializationPatterns(patterns, 77 /*force32BitVectorIndices=*/false); 78 } 79 80 void transform::ApplyLowerMultiReductionPatternsOp::populatePatterns( 81 RewritePatternSet &patterns) { 82 vector::VectorTransformsOptions vectorTransformOptions; 83 vectorTransformOptions.setVectorMultiReductionLowering(getLoweringStrategy()); 84 vector::populateVectorMultiReductionLoweringPatterns( 85 patterns, vectorTransformOptions.vectorMultiReductionLowering); 86 } 87 88 void transform::ApplyLowerOuterProductPatternsOp::populatePatterns( 89 RewritePatternSet &patterns) { 90 populateVectorOuterProductLoweringPatterns(patterns); 91 } 92 93 void transform::ApplyLowerGatherPatternsOp::populatePatterns( 94 RewritePatternSet &patterns) { 95 vector::populateVectorGatherLoweringPatterns(patterns); 96 } 97 98 void transform::ApplyLowerScanPatternsOp::populatePatterns( 99 RewritePatternSet &patterns) { 100 vector::populateVectorScanLoweringPatterns(patterns); 101 } 102 103 void transform::ApplyLowerShapeCastPatternsOp::populatePatterns( 104 RewritePatternSet &patterns) { 105 vector::populateVectorShapeCastLoweringPatterns(patterns); 106 } 107 108 void transform::ApplyLowerTransferPatternsOp::populatePatterns( 109 RewritePatternSet &patterns) { 110 vector::populateVectorTransferLoweringPatterns(patterns, 111 getMaxTransferRank()); 112 } 113 114 void transform::ApplyLowerTransposePatternsOp::populatePatterns( 115 RewritePatternSet &patterns) { 116 vector::populateVectorTransposeLoweringPatterns( 117 patterns, vector::VectorTransformsOptions().setVectorTransposeLowering( 118 getLoweringStrategy())); 119 if (getAvx2LoweringStrategy()) { 120 auto avx2LoweringOptions = 121 x86vector::avx2::LoweringOptions().setTransposeOptions( 122 x86vector::avx2::TransposeLoweringOptions() 123 .lower4x8xf32(true) 124 .lower8x8xf32(true)); 125 x86vector::avx2::populateSpecializedTransposeLoweringPatterns( 126 patterns, avx2LoweringOptions, /*benefit=*/10); 127 } 128 } 129 130 void transform::ApplySplitTransferFullPartialPatternsOp::populatePatterns( 131 RewritePatternSet &patterns) { 132 vector::VectorTransformsOptions vectorTransformOptions; 133 vectorTransformOptions.setVectorTransferSplit(getSplitTransferStrategy()); 134 populateVectorTransferFullPartialPatterns(patterns, vectorTransformOptions); 135 } 136 137 void transform::ApplyTransferToScfPatternsOp::populatePatterns( 138 RewritePatternSet &patterns) { 139 VectorTransferToSCFOptions vectorTransferToSCFOptions = 140 VectorTransferToSCFOptions() 141 .enableFullUnroll(getFullUnroll()) 142 .setTargetRank(getMaxTransferRank()); 143 populateVectorToSCFConversionPatterns(patterns, vectorTransferToSCFOptions); 144 } 145 146 //===----------------------------------------------------------------------===// 147 // Transform op registration 148 //===----------------------------------------------------------------------===// 149 150 namespace { 151 /// Registers new ops and declares PDL as dependent dialect since the additional 152 /// ops are using PDL types for operands and results. 153 class VectorTransformDialectExtension 154 : public transform::TransformDialectExtension< 155 VectorTransformDialectExtension> { 156 public: 157 VectorTransformDialectExtension() { 158 declareGeneratedDialect<vector::VectorDialect>(); 159 declareGeneratedDialect<LLVM::LLVMDialect>(); 160 registerTransformOps< 161 #define GET_OP_LIST 162 #include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.cpp.inc" 163 >(); 164 } 165 }; 166 } // namespace 167 168 #define GET_OP_CLASSES 169 #include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.cpp.inc" 170 171 void mlir::vector::registerTransformDialectExtension( 172 DialectRegistry ®istry) { 173 registry.addExtensions<VectorTransformDialectExtension>(); 174 } 175