xref: /llvm-project/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp (revision a02ad6c1773368c9ce67d3a28578bf6284c6c1be)
1 //===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.h"
10 
11 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
12 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
13 #include "mlir/Dialect/Shape/IR/Shape.h"
14 #include "mlir/IR/Dialect.h"
15 #include "mlir/IR/Operation.h"
16 #include "mlir/IR/PatternMatch.h"
17 
18 using namespace mlir;
19 using namespace mlir::bufferization;
20 using namespace mlir::shape;
21 
22 namespace mlir {
23 namespace shape {
24 namespace {
25 
26 /// Bufferization of shape.assuming.
27 struct AssumingOpInterface
28     : public BufferizableOpInterface::ExternalModel<AssumingOpInterface,
29                                                     shape::AssumingOp> {
30   AliasingOpOperandList
getAliasingOpOperandsmlir::shape::__anon221796b20111::AssumingOpInterface31   getAliasingOpOperands(Operation *op, Value value,
32                         const AnalysisState &state) const {
33     // AssumingOps do not have tensor OpOperands. The yielded value can be any
34     // SSA value that is in scope. To allow for use-def chain traversal through
35     // AssumingOps in the analysis, the corresponding yield value is considered
36     // to be aliasing with the result.
37     auto assumingOp = cast<shape::AssumingOp>(op);
38     size_t resultNum = std::distance(op->getOpResults().begin(),
39                                      llvm::find(op->getOpResults(), value));
40     // TODO: Support multiple blocks.
41     assert(assumingOp.getDoRegion().getBlocks().size() == 1 &&
42            "expected exactly 1 block");
43     auto yieldOp = dyn_cast<shape::AssumingYieldOp>(
44         assumingOp.getDoRegion().front().getTerminator());
45     assert(yieldOp && "expected shape.assuming_yield terminator");
46     return {{&yieldOp->getOpOperand(resultNum), BufferRelation::Equivalent}};
47   }
48 
bufferizemlir::shape::__anon221796b20111::AssumingOpInterface49   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
50                           const BufferizationOptions &options) const {
51     auto assumingOp = cast<shape::AssumingOp>(op);
52     assert(assumingOp.getDoRegion().getBlocks().size() == 1 &&
53            "only 1 block supported");
54     auto yieldOp = cast<shape::AssumingYieldOp>(
55         assumingOp.getDoRegion().front().getTerminator());
56 
57     // Create new op and move over region.
58     TypeRange newResultTypes(yieldOp.getOperands());
59     auto newOp = rewriter.create<shape::AssumingOp>(
60         op->getLoc(), newResultTypes, assumingOp.getWitness());
61     newOp.getDoRegion().takeBody(assumingOp.getRegion());
62 
63     // Update all uses of the old op.
64     rewriter.setInsertionPointAfter(newOp);
65     SmallVector<Value> newResults;
66     for (const auto &it : llvm::enumerate(assumingOp->getResultTypes())) {
67       if (isa<TensorType>(it.value())) {
68         newResults.push_back(rewriter.create<bufferization::ToTensorOp>(
69             assumingOp.getLoc(), newOp->getResult(it.index())));
70       } else {
71         newResults.push_back(newOp->getResult(it.index()));
72       }
73     }
74 
75     // Replace old op.
76     rewriter.replaceOp(assumingOp, newResults);
77 
78     return success();
79   }
80 };
81 
82 /// Bufferization of shape.assuming_yield. Bufferized as part of their enclosing
83 /// ops, so this is for analysis only.
84 struct AssumingYieldOpInterface
85     : public BufferizableOpInterface::ExternalModel<AssumingYieldOpInterface,
86                                                     shape::AssumingYieldOp> {
bufferizesToMemoryReadmlir::shape::__anon221796b20111::AssumingYieldOpInterface87   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
88                               const AnalysisState &state) const {
89     return true;
90   }
91 
bufferizesToMemoryWritemlir::shape::__anon221796b20111::AssumingYieldOpInterface92   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
93                                const AnalysisState &state) const {
94     return false;
95   }
96 
getAliasingValuesmlir::shape::__anon221796b20111::AssumingYieldOpInterface97   AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
98                                       const AnalysisState &state) const {
99     assert(isa<shape::AssumingOp>(op->getParentOp()) &&
100            "expected that parent is an AssumingOp");
101     OpResult opResult =
102         op->getParentOp()->getResult(opOperand.getOperandNumber());
103     return {{opResult, BufferRelation::Equivalent}};
104   }
105 
mustBufferizeInPlacemlir::shape::__anon221796b20111::AssumingYieldOpInterface106   bool mustBufferizeInPlace(Operation *op, OpOperand &opOperand,
107                             const AnalysisState &state) const {
108     // Yield operands always bufferize inplace. Otherwise, an alloc + copy
109     // may be generated inside the block. We should not return/yield allocations
110     // when possible.
111     return true;
112   }
113 
bufferizemlir::shape::__anon221796b20111::AssumingYieldOpInterface114   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
115                           const BufferizationOptions &options) const {
116     auto yieldOp = cast<shape::AssumingYieldOp>(op);
117     SmallVector<Value> newResults;
118     for (Value value : yieldOp.getOperands()) {
119       if (isa<TensorType>(value.getType())) {
120         FailureOr<Value> buffer = getBuffer(rewriter, value, options);
121         if (failed(buffer))
122           return failure();
123         newResults.push_back(*buffer);
124       } else {
125         newResults.push_back(value);
126       }
127     }
128     replaceOpWithNewBufferizedOp<shape::AssumingYieldOp>(rewriter, op,
129                                                          newResults);
130     return success();
131   }
132 };
133 
134 } // namespace
135 } // namespace shape
136 } // namespace mlir
137 
registerBufferizableOpInterfaceExternalModels(DialectRegistry & registry)138 void mlir::shape::registerBufferizableOpInterfaceExternalModels(
139     DialectRegistry &registry) {
140   registry.addExtension(+[](MLIRContext *ctx, shape::ShapeDialect *dialect) {
141     shape::AssumingOp::attachInterface<AssumingOpInterface>(*ctx);
142     shape::AssumingYieldOp::attachInterface<AssumingYieldOpInterface>(*ctx);
143   });
144 }
145