1 //===- VectorTransformOps.cpp - Implementation of Vector transform ops ----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.h" 10 11 #include "mlir/Conversion/LLVMCommon/TypeConverter.h" 12 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" 13 #include "mlir/Conversion/VectorToSCF/VectorToSCF.h" 14 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 15 #include "mlir/Dialect/Transform/IR/TransformDialect.h" 16 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" 17 #include "mlir/Dialect/Transform/IR/TransformOps.h" 18 #include "mlir/Dialect/Vector/IR/VectorOps.h" 19 #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" 20 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h" 21 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h" 22 #include "mlir/Dialect/X86Vector/Transforms.h" 23 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 24 25 using namespace mlir; 26 using namespace mlir::vector; 27 using namespace mlir::transform; 28 29 //===----------------------------------------------------------------------===// 30 // Apply...ConversionPatternsOp 31 //===----------------------------------------------------------------------===// 32 33 void transform::ApplyVectorToLLVMConversionPatternsOp::populatePatterns( 34 TypeConverter &typeConverter, RewritePatternSet &patterns) { 35 populateVectorToLLVMConversionPatterns( 36 static_cast<LLVMTypeConverter &>(typeConverter), patterns, 37 getReassociateFpReductions(), getForce_32bitVectorIndices()); 38 } 39 40 LogicalResult 41 transform::ApplyVectorToLLVMConversionPatternsOp::verifyTypeConverter( 42 transform::TypeConverterBuilderOpInterface builder) { 43 if (builder.getTypeConverterType() != "LLVMTypeConverter") 44 return emitOpError("expected LLVMTypeConverter"); 45 return success(); 46 } 47 48 //===----------------------------------------------------------------------===// 49 // Apply...PatternsOp 50 //===----------------------------------------------------------------------===// 51 52 void transform::ApplyCastAwayVectorLeadingOneDimPatternsOp::populatePatterns( 53 RewritePatternSet &patterns) { 54 vector::populateCastAwayVectorLeadingOneDimPatterns(patterns); 55 } 56 57 void transform::ApplyFoldArithExtensionPatternsOp::populatePatterns( 58 RewritePatternSet &patterns) { 59 vector::populateFoldArithExtensionPatterns(patterns); 60 } 61 62 void transform::ApplyVectorReductionToContractPatternsOp::populatePatterns( 63 RewritePatternSet &patterns) { 64 vector::populateVectorReductionToContractPatterns(patterns); 65 } 66 67 void transform::ApplyRankReducingSubviewPatternsOp::populatePatterns( 68 RewritePatternSet &patterns) { 69 vector::populateVectorTransferDropUnitDimsPatterns(patterns); 70 } 71 72 void transform::ApplyTransferPermutationPatternsOp::populatePatterns( 73 RewritePatternSet &patterns) { 74 vector::populateVectorTransferPermutationMapLoweringPatterns(patterns); 75 } 76 77 void transform::ApplyLowerBroadcastPatternsOp::populatePatterns( 78 RewritePatternSet &patterns) { 79 populateVectorBroadcastLoweringPatterns(patterns); 80 } 81 82 void transform::ApplyLowerContractionPatternsOp::populatePatterns( 83 RewritePatternSet &patterns) { 84 vector::VectorTransformsOptions vectorTransformOptions; 85 vectorTransformOptions.setVectorTransformsOptions(getLoweringStrategy()); 86 populateVectorContractLoweringPatterns(patterns, vectorTransformOptions, 87 /*benefit=*/1, 88 /*disableOuterProductLowering=*/true); 89 } 90 91 void transform::ApplyLowerMasksPatternsOp::populatePatterns( 92 RewritePatternSet &patterns) { 93 populateVectorMaskOpLoweringPatterns(patterns); 94 } 95 96 void transform::ApplyLowerMaskedTransfersPatternsOp::populatePatterns( 97 RewritePatternSet &patterns) { 98 populateVectorMaskLoweringPatternsForSideEffectingOps(patterns); 99 } 100 101 void transform::ApplyMaterializeMasksPatternsOp::populatePatterns( 102 RewritePatternSet &patterns) { 103 populateVectorMaskMaterializationPatterns(patterns, 104 /*force32BitVectorIndices=*/false); 105 } 106 107 void transform::ApplyLowerMultiReductionPatternsOp::populatePatterns( 108 RewritePatternSet &patterns) { 109 vector::VectorTransformsOptions vectorTransformOptions; 110 vectorTransformOptions.setVectorMultiReductionLowering(getLoweringStrategy()); 111 vector::populateVectorMultiReductionLoweringPatterns( 112 patterns, vectorTransformOptions.vectorMultiReductionLowering); 113 } 114 115 void transform::ApplyLowerOuterProductPatternsOp::populatePatterns( 116 RewritePatternSet &patterns) { 117 populateVectorOuterProductLoweringPatterns(patterns); 118 } 119 120 void transform::ApplyLowerGatherPatternsOp::populatePatterns( 121 RewritePatternSet &patterns) { 122 vector::populateVectorGatherLoweringPatterns(patterns); 123 } 124 125 void transform::ApplyLowerScanPatternsOp::populatePatterns( 126 RewritePatternSet &patterns) { 127 vector::populateVectorScanLoweringPatterns(patterns); 128 } 129 130 void transform::ApplyLowerShapeCastPatternsOp::populatePatterns( 131 RewritePatternSet &patterns) { 132 vector::populateVectorShapeCastLoweringPatterns(patterns); 133 } 134 135 void transform::ApplyLowerTransferPatternsOp::populatePatterns( 136 RewritePatternSet &patterns) { 137 vector::populateVectorTransferLoweringPatterns(patterns, 138 getMaxTransferRank()); 139 } 140 141 void transform::ApplyLowerTransposePatternsOp::populatePatterns( 142 RewritePatternSet &patterns) { 143 vector::populateVectorTransposeLoweringPatterns( 144 patterns, vector::VectorTransformsOptions().setVectorTransposeLowering( 145 getLoweringStrategy())); 146 if (getAvx2LoweringStrategy()) { 147 auto avx2LoweringOptions = 148 x86vector::avx2::LoweringOptions().setTransposeOptions( 149 x86vector::avx2::TransposeLoweringOptions() 150 .lower4x8xf32(true) 151 .lower8x8xf32(true)); 152 x86vector::avx2::populateSpecializedTransposeLoweringPatterns( 153 patterns, avx2LoweringOptions, /*benefit=*/10); 154 } 155 } 156 157 void transform::ApplySplitTransferFullPartialPatternsOp::populatePatterns( 158 RewritePatternSet &patterns) { 159 vector::VectorTransformsOptions vectorTransformOptions; 160 vectorTransformOptions.setVectorTransferSplit(getSplitTransferStrategy()); 161 populateVectorTransferFullPartialPatterns(patterns, vectorTransformOptions); 162 } 163 164 void transform::ApplyTransferToScfPatternsOp::populatePatterns( 165 RewritePatternSet &patterns) { 166 VectorTransferToSCFOptions vectorTransferToSCFOptions = 167 VectorTransferToSCFOptions() 168 .enableFullUnroll(getFullUnroll()) 169 .setTargetRank(getMaxTransferRank()); 170 populateVectorToSCFConversionPatterns(patterns, vectorTransferToSCFOptions); 171 } 172 173 //===----------------------------------------------------------------------===// 174 // Transform op registration 175 //===----------------------------------------------------------------------===// 176 177 namespace { 178 /// Registers new ops and declares PDL as dependent dialect since the additional 179 /// ops are using PDL types for operands and results. 180 class VectorTransformDialectExtension 181 : public transform::TransformDialectExtension< 182 VectorTransformDialectExtension> { 183 public: 184 VectorTransformDialectExtension() { 185 declareGeneratedDialect<vector::VectorDialect>(); 186 declareGeneratedDialect<LLVM::LLVMDialect>(); 187 registerTransformOps< 188 #define GET_OP_LIST 189 #include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.cpp.inc" 190 >(); 191 } 192 }; 193 } // namespace 194 195 #define GET_OP_CLASSES 196 #include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.cpp.inc" 197 198 void mlir::vector::registerTransformDialectExtension( 199 DialectRegistry ®istry) { 200 registry.addExtensions<VectorTransformDialectExtension>(); 201 } 202