1 //===- TensorTransformOps.cpp - Implementation of tensor transform ops ----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Tensor/TransformOps/TensorTransformOps.h" 10 11 #include "mlir/Dialect/Affine/IR/AffineOps.h" 12 #include "mlir/Dialect/SCF/IR/SCF.h" 13 #include "mlir/Dialect/Tensor/IR/Tensor.h" 14 #include "mlir/Dialect/Tensor/Transforms/Transforms.h" 15 #include "mlir/Dialect/Tensor/Utils/Utils.h" 16 #include "mlir/Dialect/Transform/IR/TransformDialect.h" 17 #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" 18 #include "mlir/IR/Builders.h" 19 #include "mlir/Transforms/DialectConversion.h" 20 21 using namespace mlir; 22 using namespace tensor; 23 24 //===----------------------------------------------------------------------===// 25 // FindPayloadReplacementOpInterface implementations 26 //===----------------------------------------------------------------------===// 27 28 namespace { 29 struct ExtractSliceOpReplacementInterface 30 : public transform::FindPayloadReplacementOpInterface::ExternalModel< 31 ExtractSliceOpReplacementInterface, tensor::ExtractSliceOp> { 32 SmallVector<Value> getNextOperands(Operation *op) const { 33 auto extractSliceOp = cast<tensor::ExtractSliceOp>(op); 34 if (!isCastLikeExtractSliceOp(extractSliceOp)) 35 return {}; 36 return {extractSliceOp.getSource()}; 37 } 38 }; 39 40 struct InsertSliceOpReplacementInterface 41 : public transform::FindPayloadReplacementOpInterface::ExternalModel< 42 InsertSliceOpReplacementInterface, tensor::InsertSliceOp> { 43 SmallVector<Value> getNextOperands(Operation *op) const { 44 auto insertSliceOp = cast<tensor::InsertSliceOp>(op); 45 if (!isCastLikeInsertSliceOp(insertSliceOp)) 46 return {}; 47 return {insertSliceOp.getSource()}; 48 } 49 }; 50 51 struct ReshapeOpReplacementInterface 52 : public transform::FindPayloadReplacementOpInterface::ExternalModel< 53 ReshapeOpReplacementInterface, tensor::ReshapeOp> { 54 SmallVector<Value> getNextOperands(Operation *op) const { 55 auto reshapeOp = cast<tensor::ReshapeOp>(op); 56 return {reshapeOp.getSource()}; 57 } 58 }; 59 60 template <typename ConcreteOp> 61 struct ReassociativeReshapeOpReplacementInterface 62 : public transform::FindPayloadReplacementOpInterface::ExternalModel< 63 ReassociativeReshapeOpReplacementInterface<ConcreteOp>, ConcreteOp> { 64 SmallVector<Value> getNextOperands(Operation *op) const { 65 auto reshapeOp = cast<ConcreteOp>(op); 66 return {reshapeOp.getSrc()}; 67 } 68 }; 69 } // namespace 70 71 void tensor::registerFindPayloadReplacementOpInterfaceExternalModels( 72 DialectRegistry ®istry) { 73 registry.addExtension(+[](MLIRContext *ctx, tensor::TensorDialect *dialect) { 74 CollapseShapeOp::attachInterface< 75 ReassociativeReshapeOpReplacementInterface<CollapseShapeOp>>(*ctx); 76 ExpandShapeOp::attachInterface< 77 ReassociativeReshapeOpReplacementInterface<ExpandShapeOp>>(*ctx); 78 ExtractSliceOp::attachInterface<ExtractSliceOpReplacementInterface>(*ctx); 79 InsertSliceOp::attachInterface<InsertSliceOpReplacementInterface>(*ctx); 80 ReshapeOp::attachInterface<ReshapeOpReplacementInterface>(*ctx); 81 }); 82 } 83 84 //===----------------------------------------------------------------------===// 85 // Apply...PatternsOp 86 //===----------------------------------------------------------------------===// 87 88 void transform::ApplyDecomposeTensorConcatPatternsOp::populatePatterns( 89 RewritePatternSet &patterns) { 90 tensor::populateDecomposeTensorConcatPatterns(patterns); 91 } 92 93 void transform::ApplyDropRedundantInsertSliceRankExpansionPatternsOp:: 94 populatePatterns(RewritePatternSet &patterns) { 95 tensor::populateDropRedundantInsertSliceRankExpansionPatterns(patterns); 96 } 97 98 void transform::ApplyFoldTensorEmptyPatternsOp::populatePatterns( 99 RewritePatternSet &patterns) { 100 tensor::populateFoldTensorEmptyPatterns(patterns, getFoldSingleUseOnly()); 101 } 102 103 void transform::ApplyFoldIntoPackAndUnpackPatternsOp::populatePatterns( 104 RewritePatternSet &patterns) { 105 tensor::populateFoldIntoPackAndUnpackPatterns(patterns); 106 } 107 108 void transform::ApplyFoldTensorSubsetOpsPatternsOp::populatePatterns( 109 RewritePatternSet &patterns) { 110 tensor::populateFoldTensorSubsetOpPatterns(patterns); 111 } 112 113 void transform::ApplyFoldTensorSubsetOpsIntoVectorTransfersPatternsOp:: 114 populatePatterns(RewritePatternSet &patterns) { 115 tensor::populateFoldTensorSubsetIntoVectorTransferPatterns(patterns); 116 } 117 118 void transform::ApplyMergeConsecutiveInsertExtractSlicePatternsOp:: 119 populatePatterns(RewritePatternSet &patterns) { 120 tensor::populateMergeConsecutiveInsertExtractSlicePatterns(patterns); 121 } 122 123 void transform::ApplyReassociativeReshapeFoldingPatternsOp::populatePatterns( 124 RewritePatternSet &patterns) { 125 tensor::populateReassociativeReshapeFoldingPatterns(patterns); 126 } 127 128 void transform::ApplyRewriteTensorOpsAsConstantPatternsOp::populatePatterns( 129 RewritePatternSet &patterns) { 130 ControlFoldFn defaultControlFn = [](OpOperand *fusedOperand) { 131 Operation *producer = fusedOperand->get().getDefiningOp(); 132 return producer && producer->hasOneUse(); 133 }; 134 135 ControlFoldFn aggressiveControlFn = [](OpOperand *fusedOperand) { 136 return true; 137 }; 138 139 // Add folding with reshape by expansion patterns. 140 if (getAggressive()) 141 tensor::populateRewriteAsConstantPatterns(patterns, aggressiveControlFn); 142 else 143 tensor::populateRewriteAsConstantPatterns(patterns, defaultControlFn); 144 } 145 146 //===----------------------------------------------------------------------===// 147 // TypeConversionCastTensorShapeOp 148 //===----------------------------------------------------------------------===// 149 150 void transform::TypeConversionCastShapeDynamicDimsOp:: 151 populateTypeMaterializations(TypeConverter &converter) { 152 bool ignoreDynamicInfo = getIgnoreDynamicInfo(); 153 converter.addSourceMaterialization([ignoreDynamicInfo]( 154 OpBuilder &builder, Type resultType, 155 ValueRange inputs, 156 Location loc) -> Value { 157 if (inputs.size() != 1) { 158 return Value(); 159 } 160 Value input = inputs[0]; 161 if (!ignoreDynamicInfo && 162 !tensor::preservesStaticInformation(resultType, input.getType())) { 163 return Value(); 164 } 165 if (!tensor::CastOp::areCastCompatible(input.getType(), resultType)) { 166 return Value(); 167 } 168 return builder.create<tensor::CastOp>(loc, resultType, input).getResult(); 169 }); 170 converter.addTargetMaterialization([](OpBuilder &builder, Type resultType, 171 ValueRange inputs, 172 Location loc) -> Value { 173 if (inputs.size() != 1) { 174 return Value(); 175 } 176 Value input = inputs[0]; 177 if (!tensor::CastOp::areCastCompatible(input.getType(), resultType)) { 178 return Value(); 179 } 180 return builder.create<tensor::CastOp>(loc, resultType, input).getResult(); 181 }); 182 } 183 184 //===----------------------------------------------------------------------===// 185 // MakeLoopIndependentOp 186 //===----------------------------------------------------------------------===// 187 188 DiagnosedSilenceableFailure transform::MakeLoopIndependentOp::applyToOne( 189 transform::TransformRewriter &rewriter, Operation *target, 190 transform::ApplyToEachResultList &results, 191 transform::TransformState &state) { 192 // Gather IVs. 193 SmallVector<Value> ivs; 194 Operation *nextOp = target; 195 for (uint64_t i = 0, e = getNumLoops(); i < e; ++i) { 196 nextOp = nextOp->getParentOfType<scf::ForOp>(); 197 if (!nextOp) { 198 DiagnosedSilenceableFailure diag = emitSilenceableError() 199 << "could not find " << i 200 << "-th enclosing loop"; 201 diag.attachNote(target->getLoc()) << "target op"; 202 return diag; 203 } 204 ivs.push_back(cast<scf::ForOp>(nextOp).getInductionVar()); 205 } 206 207 // Rewrite IR. 208 FailureOr<Value> replacement = failure(); 209 if (auto padOp = dyn_cast<tensor::PadOp>(target)) { 210 replacement = tensor::buildIndependentOp(rewriter, padOp, ivs); 211 } else if (auto emptyOp = dyn_cast<tensor::EmptyOp>(target)) { 212 replacement = tensor::buildIndependentOp(rewriter, emptyOp, ivs); 213 } else { 214 DiagnosedSilenceableFailure diag = emitSilenceableError() 215 << "unsupported target op"; 216 diag.attachNote(target->getLoc()) << "target op"; 217 return diag; 218 } 219 if (failed(replacement)) { 220 DiagnosedSilenceableFailure diag = 221 emitSilenceableError() << "could not make target op loop-independent"; 222 diag.attachNote(target->getLoc()) << "target op"; 223 return diag; 224 } 225 rewriter.replaceOp(target, *replacement); 226 results.push_back(replacement->getDefiningOp()); 227 return DiagnosedSilenceableFailure::success(); 228 } 229 230 //===----------------------------------------------------------------------===// 231 // Transform op registration 232 //===----------------------------------------------------------------------===// 233 234 namespace { 235 class TensorTransformDialectExtension 236 : public transform::TransformDialectExtension< 237 TensorTransformDialectExtension> { 238 public: 239 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TensorTransformDialectExtension) 240 241 using Base::Base; 242 243 void init() { 244 declareGeneratedDialect<affine::AffineDialect>(); 245 declareGeneratedDialect<tensor::TensorDialect>(); 246 247 registerTransformOps< 248 #define GET_OP_LIST 249 #include "mlir/Dialect/Tensor/TransformOps/TensorTransformOps.cpp.inc" 250 >(); 251 } 252 }; 253 } // namespace 254 255 #define GET_OP_CLASSES 256 #include "mlir/Dialect/Tensor/TransformOps/TensorTransformOps.cpp.inc" 257 258 void mlir::tensor::registerTransformDialectExtension( 259 DialectRegistry ®istry) { 260 registry.addExtensions<TensorTransformDialectExtension>(); 261 } 262