xref: /llvm-project/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp (revision f18c3e4e7335df282c468b6dff3d29be1822a96d)
1 //===- TensorTransformOps.cpp - Implementation of tensor transform ops ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Tensor/TransformOps/TensorTransformOps.h"
10 
11 #include "mlir/Dialect/Affine/IR/AffineOps.h"
12 #include "mlir/Dialect/SCF/IR/SCF.h"
13 #include "mlir/Dialect/Tensor/IR/Tensor.h"
14 #include "mlir/Dialect/Tensor/Transforms/Transforms.h"
15 #include "mlir/Dialect/Tensor/Utils/Utils.h"
16 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
17 #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
18 #include "mlir/IR/Builders.h"
19 #include "mlir/Transforms/DialectConversion.h"
20 
21 using namespace mlir;
22 using namespace tensor;
23 
24 //===----------------------------------------------------------------------===//
25 // FindPayloadReplacementOpInterface implementations
26 //===----------------------------------------------------------------------===//
27 
28 namespace {
29 struct ExtractSliceOpReplacementInterface
30     : public transform::FindPayloadReplacementOpInterface::ExternalModel<
31           ExtractSliceOpReplacementInterface, tensor::ExtractSliceOp> {
32   SmallVector<Value> getNextOperands(Operation *op) const {
33     auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
34     if (!isCastLikeExtractSliceOp(extractSliceOp))
35       return {};
36     return {extractSliceOp.getSource()};
37   }
38 };
39 
40 struct InsertSliceOpReplacementInterface
41     : public transform::FindPayloadReplacementOpInterface::ExternalModel<
42           InsertSliceOpReplacementInterface, tensor::InsertSliceOp> {
43   SmallVector<Value> getNextOperands(Operation *op) const {
44     auto insertSliceOp = cast<tensor::InsertSliceOp>(op);
45     if (!isCastLikeInsertSliceOp(insertSliceOp))
46       return {};
47     return {insertSliceOp.getSource()};
48   }
49 };
50 
51 struct ReshapeOpReplacementInterface
52     : public transform::FindPayloadReplacementOpInterface::ExternalModel<
53           ReshapeOpReplacementInterface, tensor::ReshapeOp> {
54   SmallVector<Value> getNextOperands(Operation *op) const {
55     auto reshapeOp = cast<tensor::ReshapeOp>(op);
56     return {reshapeOp.getSource()};
57   }
58 };
59 
60 template <typename ConcreteOp>
61 struct ReassociativeReshapeOpReplacementInterface
62     : public transform::FindPayloadReplacementOpInterface::ExternalModel<
63           ReassociativeReshapeOpReplacementInterface<ConcreteOp>, ConcreteOp> {
64   SmallVector<Value> getNextOperands(Operation *op) const {
65     auto reshapeOp = cast<ConcreteOp>(op);
66     return {reshapeOp.getSrc()};
67   }
68 };
69 } // namespace
70 
71 void tensor::registerFindPayloadReplacementOpInterfaceExternalModels(
72     DialectRegistry &registry) {
73   registry.addExtension(+[](MLIRContext *ctx, tensor::TensorDialect *dialect) {
74     CollapseShapeOp::attachInterface<
75         ReassociativeReshapeOpReplacementInterface<CollapseShapeOp>>(*ctx);
76     ExpandShapeOp::attachInterface<
77         ReassociativeReshapeOpReplacementInterface<ExpandShapeOp>>(*ctx);
78     ExtractSliceOp::attachInterface<ExtractSliceOpReplacementInterface>(*ctx);
79     InsertSliceOp::attachInterface<InsertSliceOpReplacementInterface>(*ctx);
80     ReshapeOp::attachInterface<ReshapeOpReplacementInterface>(*ctx);
81   });
82 }
83 
84 //===----------------------------------------------------------------------===//
85 // Apply...PatternsOp
86 //===----------------------------------------------------------------------===//
87 
88 void transform::ApplyDecomposeTensorConcatPatternsOp::populatePatterns(
89     RewritePatternSet &patterns) {
90   tensor::populateDecomposeTensorConcatPatterns(patterns);
91 }
92 
93 void transform::ApplyDropRedundantInsertSliceRankExpansionPatternsOp::
94     populatePatterns(RewritePatternSet &patterns) {
95   tensor::populateDropRedundantInsertSliceRankExpansionPatterns(patterns);
96 }
97 
98 void transform::ApplyFoldTensorEmptyPatternsOp::populatePatterns(
99     RewritePatternSet &patterns) {
100   tensor::populateFoldTensorEmptyPatterns(patterns, getFoldSingleUseOnly());
101 }
102 
103 void transform::ApplyFoldIntoPackAndUnpackPatternsOp::populatePatterns(
104     RewritePatternSet &patterns) {
105   tensor::populateFoldIntoPackAndUnpackPatterns(patterns);
106 }
107 
108 void transform::ApplyFoldTensorSubsetOpsPatternsOp::populatePatterns(
109     RewritePatternSet &patterns) {
110   tensor::populateFoldTensorSubsetOpPatterns(patterns);
111 }
112 
113 void transform::ApplyFoldTensorSubsetOpsIntoVectorTransfersPatternsOp::
114     populatePatterns(RewritePatternSet &patterns) {
115   tensor::populateFoldTensorSubsetIntoVectorTransferPatterns(patterns);
116 }
117 
118 void transform::ApplyMergeConsecutiveInsertExtractSlicePatternsOp::
119     populatePatterns(RewritePatternSet &patterns) {
120   tensor::populateMergeConsecutiveInsertExtractSlicePatterns(patterns);
121 }
122 
123 void transform::ApplyReassociativeReshapeFoldingPatternsOp::populatePatterns(
124     RewritePatternSet &patterns) {
125   tensor::populateReassociativeReshapeFoldingPatterns(patterns);
126 }
127 
128 void transform::ApplyRewriteTensorOpsAsConstantPatternsOp::populatePatterns(
129     RewritePatternSet &patterns) {
130   ControlFoldFn defaultControlFn = [](OpOperand *fusedOperand) {
131     Operation *producer = fusedOperand->get().getDefiningOp();
132     return producer && producer->hasOneUse();
133   };
134 
135   ControlFoldFn aggressiveControlFn = [](OpOperand *fusedOperand) {
136     return true;
137   };
138 
139   // Add folding with reshape by expansion patterns.
140   if (getAggressive())
141     tensor::populateRewriteAsConstantPatterns(patterns, aggressiveControlFn);
142   else
143     tensor::populateRewriteAsConstantPatterns(patterns, defaultControlFn);
144 }
145 
146 //===----------------------------------------------------------------------===//
// TypeConversionCastShapeDynamicDimsOp
148 //===----------------------------------------------------------------------===//
149 
150 void transform::TypeConversionCastShapeDynamicDimsOp::
151     populateTypeMaterializations(TypeConverter &converter) {
152   bool ignoreDynamicInfo = getIgnoreDynamicInfo();
153   converter.addSourceMaterialization([ignoreDynamicInfo](
154                                          OpBuilder &builder, Type resultType,
155                                          ValueRange inputs,
156                                          Location loc) -> Value {
157     if (inputs.size() != 1) {
158       return Value();
159     }
160     Value input = inputs[0];
161     if (!ignoreDynamicInfo &&
162         !tensor::preservesStaticInformation(resultType, input.getType())) {
163       return Value();
164     }
165     if (!tensor::CastOp::areCastCompatible(input.getType(), resultType)) {
166       return Value();
167     }
168     return builder.create<tensor::CastOp>(loc, resultType, input).getResult();
169   });
170   converter.addTargetMaterialization([](OpBuilder &builder, Type resultType,
171                                         ValueRange inputs,
172                                         Location loc) -> Value {
173     if (inputs.size() != 1) {
174       return Value();
175     }
176     Value input = inputs[0];
177     if (!tensor::CastOp::areCastCompatible(input.getType(), resultType)) {
178       return Value();
179     }
180     return builder.create<tensor::CastOp>(loc, resultType, input).getResult();
181   });
182 }
183 
184 //===----------------------------------------------------------------------===//
185 // MakeLoopIndependentOp
186 //===----------------------------------------------------------------------===//
187 
188 DiagnosedSilenceableFailure transform::MakeLoopIndependentOp::applyToOne(
189     transform::TransformRewriter &rewriter, Operation *target,
190     transform::ApplyToEachResultList &results,
191     transform::TransformState &state) {
192   // Gather IVs.
193   SmallVector<Value> ivs;
194   Operation *nextOp = target;
195   for (uint64_t i = 0, e = getNumLoops(); i < e; ++i) {
196     nextOp = nextOp->getParentOfType<scf::ForOp>();
197     if (!nextOp) {
198       DiagnosedSilenceableFailure diag = emitSilenceableError()
199                                          << "could not find " << i
200                                          << "-th enclosing loop";
201       diag.attachNote(target->getLoc()) << "target op";
202       return diag;
203     }
204     ivs.push_back(cast<scf::ForOp>(nextOp).getInductionVar());
205   }
206 
207   // Rewrite IR.
208   FailureOr<Value> replacement = failure();
209   if (auto padOp = dyn_cast<tensor::PadOp>(target)) {
210     replacement = tensor::buildIndependentOp(rewriter, padOp, ivs);
211   } else if (auto emptyOp = dyn_cast<tensor::EmptyOp>(target)) {
212     replacement = tensor::buildIndependentOp(rewriter, emptyOp, ivs);
213   } else {
214     DiagnosedSilenceableFailure diag = emitSilenceableError()
215                                        << "unsupported target op";
216     diag.attachNote(target->getLoc()) << "target op";
217     return diag;
218   }
219   if (failed(replacement)) {
220     DiagnosedSilenceableFailure diag =
221         emitSilenceableError() << "could not make target op loop-independent";
222     diag.attachNote(target->getLoc()) << "target op";
223     return diag;
224   }
225   rewriter.replaceOp(target, *replacement);
226   results.push_back(replacement->getDefiningOp());
227   return DiagnosedSilenceableFailure::success();
228 }
229 
230 //===----------------------------------------------------------------------===//
231 // Transform op registration
232 //===----------------------------------------------------------------------===//
233 
234 namespace {
235 class TensorTransformDialectExtension
236     : public transform::TransformDialectExtension<
237           TensorTransformDialectExtension> {
238 public:
239   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TensorTransformDialectExtension)
240 
241   using Base::Base;
242 
243   void init() {
244     declareGeneratedDialect<affine::AffineDialect>();
245     declareGeneratedDialect<tensor::TensorDialect>();
246 
247     registerTransformOps<
248 #define GET_OP_LIST
249 #include "mlir/Dialect/Tensor/TransformOps/TensorTransformOps.cpp.inc"
250         >();
251   }
252 };
253 } // namespace
254 
255 #define GET_OP_CLASSES
256 #include "mlir/Dialect/Tensor/TransformOps/TensorTransformOps.cpp.inc"
257 
258 void mlir::tensor::registerTransformDialectExtension(
259     DialectRegistry &registry) {
260   registry.addExtensions<TensorTransformDialectExtension>();
261 }
262