1 //===- VectorTransformOps.cpp - Implementation of Vector transform ops ----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.h" 10 11 #include "mlir/Conversion/LLVMCommon/TypeConverter.h" 12 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" 13 #include "mlir/Conversion/VectorToSCF/VectorToSCF.h" 14 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 15 #include "mlir/Dialect/Transform/IR/TransformDialect.h" 16 #include "mlir/Dialect/Transform/IR/TransformOps.h" 17 #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" 18 #include "mlir/Dialect/Vector/IR/VectorOps.h" 19 #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" 20 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h" 21 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h" 22 #include "mlir/Dialect/X86Vector/Transforms.h" 23 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 24 25 using namespace mlir; 26 using namespace mlir::vector; 27 using namespace mlir::transform; 28 29 //===----------------------------------------------------------------------===// 30 // Apply...ConversionPatternsOp 31 //===----------------------------------------------------------------------===// 32 33 void transform::ApplyVectorToLLVMConversionPatternsOp::populatePatterns( 34 TypeConverter &typeConverter, RewritePatternSet &patterns) { 35 populateVectorToLLVMConversionPatterns( 36 static_cast<LLVMTypeConverter &>(typeConverter), patterns, 37 getReassociateFpReductions(), getForce_32bitVectorIndices()); 38 } 39 40 LogicalResult 41 transform::ApplyVectorToLLVMConversionPatternsOp::verifyTypeConverter( 42 transform::TypeConverterBuilderOpInterface builder) { 43 if (builder.getTypeConverterType() != "LLVMTypeConverter") 44 return emitOpError("expected LLVMTypeConverter"); 45 return success(); 46 } 47 48 //===----------------------------------------------------------------------===// 49 // Apply...PatternsOp 50 //===----------------------------------------------------------------------===// 51 52 void transform::ApplyCastAwayVectorLeadingOneDimPatternsOp::populatePatterns( 53 RewritePatternSet &patterns) { 54 vector::populateCastAwayVectorLeadingOneDimPatterns(patterns); 55 } 56 57 void transform::ApplyFoldArithExtensionPatternsOp::populatePatterns( 58 RewritePatternSet &patterns) { 59 vector::populateFoldArithExtensionPatterns(patterns); 60 } 61 62 void transform::ApplyVectorReductionToContractPatternsOp::populatePatterns( 63 RewritePatternSet &patterns) { 64 vector::populateVectorReductionToContractPatterns(patterns); 65 } 66 67 void transform::ApplyLowerCreateMaskPatternsOp::populatePatterns( 68 RewritePatternSet &patterns) { 69 vector::populateVectorMaskOpLoweringPatterns(patterns); 70 } 71 72 void transform::ApplyRankReducingSubviewPatternsOp::populatePatterns( 73 RewritePatternSet &patterns) { 74 vector::populateVectorTransferDropUnitDimsPatterns(patterns); 75 } 76 77 void transform::ApplyTransferPermutationPatternsOp::populatePatterns( 78 RewritePatternSet &patterns) { 79 vector::populateVectorTransferPermutationMapLoweringPatterns(patterns); 80 } 81 82 void transform::ApplyLowerBroadcastPatternsOp::populatePatterns( 83 RewritePatternSet &patterns) { 84 populateVectorBroadcastLoweringPatterns(patterns); 85 } 86 87 void transform::ApplyLowerContractionPatternsOp::populatePatterns( 88 RewritePatternSet &patterns) { 89 vector::VectorTransformsOptions vectorTransformOptions; 90 vectorTransformOptions.setVectorTransformsOptions(getLoweringStrategy()); 91 populateVectorContractLoweringPatterns(patterns, vectorTransformOptions, 92 /*benefit=*/1, 93 /*disableOuterProductLowering=*/true); 94 } 95 96 void transform::ApplyLowerMasksPatternsOp::populatePatterns( 97 RewritePatternSet &patterns) { 98 populateVectorMaskOpLoweringPatterns(patterns); 99 } 100 101 void transform::ApplyLowerMaskedTransfersPatternsOp::populatePatterns( 102 RewritePatternSet &patterns) { 103 populateVectorMaskLoweringPatternsForSideEffectingOps(patterns); 104 } 105 106 void transform::ApplyMaterializeMasksPatternsOp::populatePatterns( 107 RewritePatternSet &patterns) { 108 populateVectorMaskMaterializationPatterns(patterns, 109 /*force32BitVectorIndices=*/false); 110 } 111 112 void transform::ApplyLowerMultiReductionPatternsOp::populatePatterns( 113 RewritePatternSet &patterns) { 114 vector::VectorTransformsOptions vectorTransformOptions; 115 vectorTransformOptions.setVectorMultiReductionLowering(getLoweringStrategy()); 116 vector::populateVectorMultiReductionLoweringPatterns( 117 patterns, vectorTransformOptions.vectorMultiReductionLowering); 118 } 119 120 void transform::ApplyLowerOuterProductPatternsOp::populatePatterns( 121 RewritePatternSet &patterns) { 122 populateVectorOuterProductLoweringPatterns(patterns); 123 } 124 125 void transform::ApplyLowerGatherPatternsOp::populatePatterns( 126 RewritePatternSet &patterns) { 127 vector::populateVectorGatherLoweringPatterns(patterns); 128 } 129 130 void transform::ApplyLowerScanPatternsOp::populatePatterns( 131 RewritePatternSet &patterns) { 132 vector::populateVectorScanLoweringPatterns(patterns); 133 } 134 135 void transform::ApplyLowerShapeCastPatternsOp::populatePatterns( 136 RewritePatternSet &patterns) { 137 vector::populateVectorShapeCastLoweringPatterns(patterns); 138 } 139 140 void transform::ApplyLowerTransferPatternsOp::populatePatterns( 141 RewritePatternSet &patterns) { 142 vector::populateVectorTransferLoweringPatterns(patterns, 143 getMaxTransferRank()); 144 } 145 146 void transform::ApplyLowerTransposePatternsOp::populatePatterns( 147 RewritePatternSet &patterns) { 148 vector::populateVectorTransposeLoweringPatterns( 149 patterns, vector::VectorTransformsOptions().setVectorTransposeLowering( 150 getLoweringStrategy())); 151 if (getAvx2LoweringStrategy()) { 152 auto avx2LoweringOptions = 153 x86vector::avx2::LoweringOptions().setTransposeOptions( 154 x86vector::avx2::TransposeLoweringOptions() 155 .lower4x8xf32(true) 156 .lower8x8xf32(true)); 157 x86vector::avx2::populateSpecializedTransposeLoweringPatterns( 158 patterns, avx2LoweringOptions, /*benefit=*/10); 159 } 160 } 161 162 void transform::ApplyLowerInterleavePatternsOp::populatePatterns( 163 RewritePatternSet &patterns) { 164 vector::populateVectorInterleaveLoweringPatterns(patterns); 165 } 166 167 void transform::ApplyInterleaveToShufflePatternsOp::populatePatterns( 168 RewritePatternSet &patterns) { 169 vector::populateVectorInterleaveToShufflePatterns(patterns); 170 } 171 172 void transform::ApplyRewriteNarrowTypePatternsOp::populatePatterns( 173 RewritePatternSet &patterns) { 174 populateVectorNarrowTypeRewritePatterns(patterns); 175 populateVectorTransposeNarrowTypeRewritePatterns(patterns); 176 } 177 178 void transform::ApplySplitTransferFullPartialPatternsOp::populatePatterns( 179 RewritePatternSet &patterns) { 180 vector::VectorTransformsOptions vectorTransformOptions; 181 vectorTransformOptions.setVectorTransferSplit(getSplitTransferStrategy()); 182 populateVectorTransferFullPartialPatterns(patterns, vectorTransformOptions); 183 } 184 185 void transform::ApplyTransferToScfPatternsOp::populatePatterns( 186 RewritePatternSet &patterns) { 187 VectorTransferToSCFOptions vectorTransferToSCFOptions = 188 VectorTransferToSCFOptions() 189 .enableFullUnroll(getFullUnroll()) 190 .setTargetRank(getMaxTransferRank()); 191 populateVectorToSCFConversionPatterns(patterns, vectorTransferToSCFOptions); 192 } 193 194 //===----------------------------------------------------------------------===// 195 // Transform op registration 196 //===----------------------------------------------------------------------===// 197 198 namespace { 199 /// Registers new ops and declares PDL as dependent dialect since the additional 200 /// ops are using PDL types for operands and results. 201 class VectorTransformDialectExtension 202 : public transform::TransformDialectExtension< 203 VectorTransformDialectExtension> { 204 public: 205 VectorTransformDialectExtension() { 206 declareGeneratedDialect<vector::VectorDialect>(); 207 declareGeneratedDialect<LLVM::LLVMDialect>(); 208 registerTransformOps< 209 #define GET_OP_LIST 210 #include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.cpp.inc" 211 >(); 212 } 213 }; 214 } // namespace 215 216 #define GET_OP_CLASSES 217 #include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.cpp.inc" 218 219 void mlir::vector::registerTransformDialectExtension( 220 DialectRegistry ®istry) { 221 registry.addExtensions<VectorTransformDialectExtension>(); 222 } 223