1 //===- VectorTransformOps.cpp - Implementation of Vector transform ops ----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.h" 10 11 #include "mlir/Conversion/LLVMCommon/TypeConverter.h" 12 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" 13 #include "mlir/Conversion/VectorToSCF/VectorToSCF.h" 14 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 15 #include "mlir/Dialect/Transform/IR/TransformDialect.h" 16 #include "mlir/Dialect/Transform/IR/TransformOps.h" 17 #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" 18 #include "mlir/Dialect/Vector/IR/VectorOps.h" 19 #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" 20 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h" 21 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h" 22 #include "mlir/Dialect/X86Vector/Transforms.h" 23 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 24 25 using namespace mlir; 26 using namespace mlir::vector; 27 using namespace mlir::transform; 28 29 //===----------------------------------------------------------------------===// 30 // Apply...ConversionPatternsOp 31 //===----------------------------------------------------------------------===// 32 33 void transform::ApplyVectorToLLVMConversionPatternsOp::populatePatterns( 34 TypeConverter &typeConverter, RewritePatternSet &patterns) { 35 populateVectorToLLVMConversionPatterns( 36 static_cast<LLVMTypeConverter &>(typeConverter), patterns, 37 getReassociateFpReductions(), getForce_32bitVectorIndices()); 38 } 39 40 LogicalResult 41 transform::ApplyVectorToLLVMConversionPatternsOp::verifyTypeConverter( 42 transform::TypeConverterBuilderOpInterface builder) { 43 if (builder.getTypeConverterType() != "LLVMTypeConverter") 44 return emitOpError("expected LLVMTypeConverter"); 45 return success(); 46 } 47 48 //===----------------------------------------------------------------------===// 49 // Apply...PatternsOp 50 //===----------------------------------------------------------------------===// 51 52 void transform::ApplyCastAwayVectorLeadingOneDimPatternsOp::populatePatterns( 53 RewritePatternSet &patterns) { 54 vector::populateCastAwayVectorLeadingOneDimPatterns(patterns); 55 } 56 57 void transform::ApplyFoldArithExtensionPatternsOp::populatePatterns( 58 RewritePatternSet &patterns) { 59 vector::populateFoldArithExtensionPatterns(patterns); 60 } 61 62 void transform::ApplyFoldElementwiseToVectorPatternsOp::populatePatterns( 63 RewritePatternSet &patterns) { 64 vector::populateElementwiseToVectorOpsPatterns(patterns); 65 } 66 67 void transform::ApplyVectorReductionToContractPatternsOp::populatePatterns( 68 RewritePatternSet &patterns) { 69 vector::populateVectorReductionToContractPatterns(patterns); 70 vector::populateSinkVectorOpsPatterns(patterns); 71 } 72 73 void transform::ApplyLowerCreateMaskPatternsOp::populatePatterns( 74 RewritePatternSet &patterns) { 75 vector::populateVectorMaskOpLoweringPatterns(patterns); 76 } 77 78 void transform::ApplyRankReducingSubviewPatternsOp::populatePatterns( 79 RewritePatternSet &patterns) { 80 vector::populateVectorTransferDropUnitDimsPatterns(patterns); 81 } 82 83 void transform::ApplyTransferPermutationPatternsOp::populatePatterns( 84 RewritePatternSet &patterns) { 85 vector::populateVectorTransferPermutationMapLoweringPatterns(patterns); 86 } 87 88 void transform::ApplyDropUnitDimWithShapeCastPatternsOp::populatePatterns( 89 RewritePatternSet &patterns) { 90 vector::populateDropUnitDimWithShapeCastPatterns(patterns); 91 } 92 93 void transform::ApplyLowerBitCastPatternsOp::populatePatterns( 94 RewritePatternSet &patterns) { 95 vector::populateVectorBitCastLoweringPatterns(patterns); 96 } 97 98 void transform::ApplyLowerBroadcastPatternsOp::populatePatterns( 99 RewritePatternSet &patterns) { 100 populateVectorBroadcastLoweringPatterns(patterns); 101 } 102 103 void transform::ApplyLowerContractionPatternsOp::populatePatterns( 104 RewritePatternSet &patterns) { 105 vector::VectorTransformsOptions vectorTransformOptions; 106 vectorTransformOptions.setVectorTransformsOptions(getLoweringStrategy()); 107 populateVectorContractLoweringPatterns(patterns, vectorTransformOptions, 108 /*benefit=*/1, 109 /*disableOuterProductLowering=*/true); 110 } 111 112 void transform::ApplyLowerMasksPatternsOp::populatePatterns( 113 RewritePatternSet &patterns) { 114 populateVectorMaskOpLoweringPatterns(patterns); 115 } 116 117 void transform::ApplyLowerMaskedTransfersPatternsOp::populatePatterns( 118 RewritePatternSet &patterns) { 119 populateVectorMaskLoweringPatternsForSideEffectingOps(patterns); 120 } 121 122 void transform::ApplyMaterializeMasksPatternsOp::populatePatterns( 123 RewritePatternSet &patterns) { 124 populateVectorMaskMaterializationPatterns(patterns, 125 /*force32BitVectorIndices=*/false); 126 } 127 128 void transform::ApplyLowerMultiReductionPatternsOp::populatePatterns( 129 RewritePatternSet &patterns) { 130 vector::VectorTransformsOptions vectorTransformOptions; 131 vectorTransformOptions.setVectorMultiReductionLowering(getLoweringStrategy()); 132 vector::populateVectorMultiReductionLoweringPatterns( 133 patterns, vectorTransformOptions.vectorMultiReductionLowering); 134 } 135 136 void transform::ApplyLowerOuterProductPatternsOp::populatePatterns( 137 RewritePatternSet &patterns) { 138 populateVectorOuterProductLoweringPatterns(patterns); 139 } 140 141 void transform::ApplyLowerGatherPatternsOp::populatePatterns( 142 RewritePatternSet &patterns) { 143 vector::populateVectorGatherLoweringPatterns(patterns); 144 } 145 146 void transform::ApplyLowerScanPatternsOp::populatePatterns( 147 RewritePatternSet &patterns) { 148 vector::populateVectorScanLoweringPatterns(patterns); 149 } 150 151 void transform::ApplyLowerShapeCastPatternsOp::populatePatterns( 152 RewritePatternSet &patterns) { 153 vector::populateVectorShapeCastLoweringPatterns(patterns); 154 } 155 156 void transform::ApplyLowerTransferPatternsOp::populatePatterns( 157 RewritePatternSet &patterns) { 158 vector::populateVectorTransferLoweringPatterns(patterns, 159 getMaxTransferRank()); 160 } 161 162 void transform::ApplyLowerTransposePatternsOp::populatePatterns( 163 RewritePatternSet &patterns) { 164 vector::populateVectorTransposeLoweringPatterns( 165 patterns, vector::VectorTransformsOptions().setVectorTransposeLowering( 166 getLoweringStrategy())); 167 if (getAvx2LoweringStrategy()) { 168 auto avx2LoweringOptions = 169 x86vector::avx2::LoweringOptions().setTransposeOptions( 170 x86vector::avx2::TransposeLoweringOptions() 171 .lower4x8xf32(true) 172 .lower8x8xf32(true)); 173 x86vector::avx2::populateSpecializedTransposeLoweringPatterns( 174 patterns, avx2LoweringOptions, /*benefit=*/10); 175 } 176 } 177 178 void transform::ApplyLowerInterleavePatternsOp::populatePatterns( 179 RewritePatternSet &patterns) { 180 vector::populateVectorInterleaveLoweringPatterns(patterns); 181 } 182 183 void transform::ApplyInterleaveToShufflePatternsOp::populatePatterns( 184 RewritePatternSet &patterns) { 185 vector::populateVectorInterleaveToShufflePatterns(patterns); 186 } 187 188 void transform::ApplyRewriteNarrowTypePatternsOp::populatePatterns( 189 RewritePatternSet &patterns) { 190 populateVectorNarrowTypeRewritePatterns(patterns); 191 populateVectorTransposeNarrowTypeRewritePatterns(patterns); 192 } 193 194 void transform::ApplySplitTransferFullPartialPatternsOp::populatePatterns( 195 RewritePatternSet &patterns) { 196 vector::VectorTransformsOptions vectorTransformOptions; 197 vectorTransformOptions.setVectorTransferSplit(getSplitTransferStrategy()); 198 populateVectorTransferFullPartialPatterns(patterns, vectorTransformOptions); 199 } 200 201 void transform::ApplyTransferToScfPatternsOp::populatePatterns( 202 RewritePatternSet &patterns) { 203 VectorTransferToSCFOptions vectorTransferToSCFOptions = 204 VectorTransferToSCFOptions() 205 .enableFullUnroll(getFullUnroll()) 206 .setTargetRank(getMaxTransferRank()); 207 populateVectorToSCFConversionPatterns(patterns, vectorTransferToSCFOptions); 208 } 209 210 //===----------------------------------------------------------------------===// 211 // Transform op registration 212 //===----------------------------------------------------------------------===// 213 214 namespace { 215 /// Registers new ops and declares PDL as dependent dialect since the additional 216 /// ops are using PDL types for operands and results. 217 class VectorTransformDialectExtension 218 : public transform::TransformDialectExtension< 219 VectorTransformDialectExtension> { 220 public: 221 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(VectorTransformDialectExtension) 222 223 VectorTransformDialectExtension() { 224 declareGeneratedDialect<vector::VectorDialect>(); 225 declareGeneratedDialect<LLVM::LLVMDialect>(); 226 registerTransformOps< 227 #define GET_OP_LIST 228 #include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.cpp.inc" 229 >(); 230 } 231 }; 232 } // namespace 233 234 #define GET_OP_CLASSES 235 #include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.cpp.inc" 236 237 void mlir::vector::registerTransformDialectExtension( 238 DialectRegistry ®istry) { 239 registry.addExtensions<VectorTransformDialectExtension>(); 240 } 241