1 //===- VectorTransformOps.cpp - Implementation of Vector transform ops ----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.h" 10 #include "mlir/Conversion/VectorToSCF/VectorToSCF.h" 11 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 12 #include "mlir/Dialect/Transform/IR/TransformDialect.h" 13 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" 14 #include "mlir/Dialect/Transform/IR/TransformOps.h" 15 #include "mlir/Dialect/Vector/IR/VectorOps.h" 16 #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" 17 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h" 18 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h" 19 #include "mlir/Dialect/X86Vector/Transforms.h" 20 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 21 22 using namespace mlir; 23 using namespace mlir::vector; 24 using namespace mlir::transform; 25 26 //===----------------------------------------------------------------------===// 27 // Apply...PatternsOp 28 //===----------------------------------------------------------------------===// 29 30 void transform::ApplyRankReducingSubviewPatternsOp::populatePatterns( 31 RewritePatternSet &patterns) { 32 vector::populateVectorTransferDropUnitDimsPatterns(patterns); 33 } 34 35 void transform::ApplyTransferPermutationPatternsOp::populatePatterns( 36 RewritePatternSet &patterns) { 37 vector::populateVectorTransferPermutationMapLoweringPatterns(patterns); 38 } 39 40 void transform::ApplyLowerBroadcastPatternsOp::populatePatterns( 41 RewritePatternSet &patterns) { 42 populateVectorBroadcastLoweringPatterns(patterns); 43 } 44 45 void transform::ApplyLowerContractionPatternsOp::populatePatterns( 46 RewritePatternSet &patterns) { 47 vector::VectorTransformsOptions vectorTransformOptions; 48 vectorTransformOptions.setVectorTransformsOptions(getLoweringStrategy()); 49 populateVectorContractLoweringPatterns(patterns, vectorTransformOptions, 50 /*benefit=*/1, 51 /*disableOuterProductLowering=*/true); 52 } 53 54 void transform::ApplyLowerMasksPatternsOp::populatePatterns( 55 RewritePatternSet &patterns) { 56 populateVectorMaskOpLoweringPatterns(patterns); 57 } 58 59 void transform::ApplyLowerMaskedTransfersPatternsOp::populatePatterns( 60 RewritePatternSet &patterns) { 61 populateVectorMaskLoweringPatternsForSideEffectingOps(patterns); 62 } 63 64 void transform::ApplyMaterializeMasksPatternsOp::populatePatterns( 65 RewritePatternSet &patterns) { 66 populateVectorMaskMaterializationPatterns(patterns, 67 /*force32BitVectorIndices=*/false); 68 } 69 70 void transform::ApplyLowerMultiReductionPatternsOp::populatePatterns( 71 RewritePatternSet &patterns) { 72 vector::VectorTransformsOptions vectorTransformOptions; 73 vectorTransformOptions.setVectorMultiReductionLowering(getLoweringStrategy()); 74 vector::populateVectorMultiReductionLoweringPatterns( 75 patterns, vectorTransformOptions.vectorMultiReductionLowering); 76 } 77 78 void transform::ApplyLowerOuterProductPatternsOp::populatePatterns( 79 RewritePatternSet &patterns) { 80 populateVectorOuterProductLoweringPatterns(patterns); 81 } 82 83 void transform::ApplyLowerGatherPatternsOp::populatePatterns( 84 RewritePatternSet &patterns) { 85 vector::populateVectorGatherLoweringPatterns(patterns); 86 } 87 88 void transform::ApplyLowerScanPatternsOp::populatePatterns( 89 RewritePatternSet &patterns) { 90 vector::populateVectorScanLoweringPatterns(patterns); 91 } 92 93 void transform::ApplyLowerShapeCastPatternsOp::populatePatterns( 94 RewritePatternSet &patterns) { 95 vector::populateVectorShapeCastLoweringPatterns(patterns); 96 } 97 98 void transform::ApplyLowerTransferPatternsOp::populatePatterns( 99 RewritePatternSet &patterns) { 100 vector::populateVectorTransferLoweringPatterns(patterns, 101 getMaxTransferRank()); 102 } 103 104 void transform::ApplyLowerTransposePatternsOp::populatePatterns( 105 RewritePatternSet &patterns) { 106 vector::populateVectorTransposeLoweringPatterns( 107 patterns, vector::VectorTransformsOptions().setVectorTransposeLowering( 108 getLoweringStrategy())); 109 if (getAvx2LoweringStrategy()) { 110 auto avx2LoweringOptions = 111 x86vector::avx2::LoweringOptions().setTransposeOptions( 112 x86vector::avx2::TransposeLoweringOptions() 113 .lower4x8xf32(true) 114 .lower8x8xf32(true)); 115 x86vector::avx2::populateSpecializedTransposeLoweringPatterns( 116 patterns, avx2LoweringOptions, /*benefit=*/10); 117 } 118 } 119 120 void transform::ApplySplitTransferFullPartialPatternsOp::populatePatterns( 121 RewritePatternSet &patterns) { 122 vector::VectorTransformsOptions vectorTransformOptions; 123 vectorTransformOptions.setVectorTransferSplit(getSplitTransferStrategy()); 124 populateVectorTransferFullPartialPatterns(patterns, vectorTransformOptions); 125 } 126 127 void transform::ApplyTransferToScfPatternsOp::populatePatterns( 128 RewritePatternSet &patterns) { 129 VectorTransferToSCFOptions vectorTransferToSCFOptions = 130 VectorTransferToSCFOptions() 131 .enableFullUnroll(getFullUnroll()) 132 .setTargetRank(getMaxTransferRank()); 133 populateVectorToSCFConversionPatterns(patterns, vectorTransferToSCFOptions); 134 } 135 136 //===----------------------------------------------------------------------===// 137 // Transform op registration 138 //===----------------------------------------------------------------------===// 139 140 namespace { 141 /// Registers new ops and declares PDL as dependent dialect since the additional 142 /// ops are using PDL types for operands and results. 143 class VectorTransformDialectExtension 144 : public transform::TransformDialectExtension< 145 VectorTransformDialectExtension> { 146 public: 147 VectorTransformDialectExtension() { 148 declareGeneratedDialect<vector::VectorDialect>(); 149 declareGeneratedDialect<LLVM::LLVMDialect>(); 150 registerTransformOps< 151 #define GET_OP_LIST 152 #include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.cpp.inc" 153 >(); 154 } 155 }; 156 } // namespace 157 158 #define GET_OP_CLASSES 159 #include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.cpp.inc" 160 161 void mlir::vector::registerTransformDialectExtension( 162 DialectRegistry ®istry) { 163 registry.addExtensions<VectorTransformDialectExtension>(); 164 } 165