1 //===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.h"
10
11 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
12 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
13 #include "mlir/Dialect/Shape/IR/Shape.h"
14 #include "mlir/IR/Dialect.h"
15 #include "mlir/IR/Operation.h"
16 #include "mlir/IR/PatternMatch.h"
17
18 using namespace mlir;
19 using namespace mlir::bufferization;
20 using namespace mlir::shape;
21
22 namespace mlir {
23 namespace shape {
24 namespace {
25
26 /// Bufferization of shape.assuming.
27 struct AssumingOpInterface
28 : public BufferizableOpInterface::ExternalModel<AssumingOpInterface,
29 shape::AssumingOp> {
30 AliasingOpOperandList
getAliasingOpOperandsmlir::shape::__anon221796b20111::AssumingOpInterface31 getAliasingOpOperands(Operation *op, Value value,
32 const AnalysisState &state) const {
33 // AssumingOps do not have tensor OpOperands. The yielded value can be any
34 // SSA value that is in scope. To allow for use-def chain traversal through
35 // AssumingOps in the analysis, the corresponding yield value is considered
36 // to be aliasing with the result.
37 auto assumingOp = cast<shape::AssumingOp>(op);
38 size_t resultNum = std::distance(op->getOpResults().begin(),
39 llvm::find(op->getOpResults(), value));
40 // TODO: Support multiple blocks.
41 assert(assumingOp.getDoRegion().getBlocks().size() == 1 &&
42 "expected exactly 1 block");
43 auto yieldOp = dyn_cast<shape::AssumingYieldOp>(
44 assumingOp.getDoRegion().front().getTerminator());
45 assert(yieldOp && "expected shape.assuming_yield terminator");
46 return {{&yieldOp->getOpOperand(resultNum), BufferRelation::Equivalent}};
47 }
48
bufferizemlir::shape::__anon221796b20111::AssumingOpInterface49 LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
50 const BufferizationOptions &options) const {
51 auto assumingOp = cast<shape::AssumingOp>(op);
52 assert(assumingOp.getDoRegion().getBlocks().size() == 1 &&
53 "only 1 block supported");
54 auto yieldOp = cast<shape::AssumingYieldOp>(
55 assumingOp.getDoRegion().front().getTerminator());
56
57 // Create new op and move over region.
58 TypeRange newResultTypes(yieldOp.getOperands());
59 auto newOp = rewriter.create<shape::AssumingOp>(
60 op->getLoc(), newResultTypes, assumingOp.getWitness());
61 newOp.getDoRegion().takeBody(assumingOp.getRegion());
62
63 // Update all uses of the old op.
64 rewriter.setInsertionPointAfter(newOp);
65 SmallVector<Value> newResults;
66 for (const auto &it : llvm::enumerate(assumingOp->getResultTypes())) {
67 if (isa<TensorType>(it.value())) {
68 newResults.push_back(rewriter.create<bufferization::ToTensorOp>(
69 assumingOp.getLoc(), newOp->getResult(it.index())));
70 } else {
71 newResults.push_back(newOp->getResult(it.index()));
72 }
73 }
74
75 // Replace old op.
76 rewriter.replaceOp(assumingOp, newResults);
77
78 return success();
79 }
80 };
81
82 /// Bufferization of shape.assuming_yield. Bufferized as part of their enclosing
83 /// ops, so this is for analysis only.
84 struct AssumingYieldOpInterface
85 : public BufferizableOpInterface::ExternalModel<AssumingYieldOpInterface,
86 shape::AssumingYieldOp> {
bufferizesToMemoryReadmlir::shape::__anon221796b20111::AssumingYieldOpInterface87 bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
88 const AnalysisState &state) const {
89 return true;
90 }
91
bufferizesToMemoryWritemlir::shape::__anon221796b20111::AssumingYieldOpInterface92 bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
93 const AnalysisState &state) const {
94 return false;
95 }
96
getAliasingValuesmlir::shape::__anon221796b20111::AssumingYieldOpInterface97 AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
98 const AnalysisState &state) const {
99 assert(isa<shape::AssumingOp>(op->getParentOp()) &&
100 "expected that parent is an AssumingOp");
101 OpResult opResult =
102 op->getParentOp()->getResult(opOperand.getOperandNumber());
103 return {{opResult, BufferRelation::Equivalent}};
104 }
105
mustBufferizeInPlacemlir::shape::__anon221796b20111::AssumingYieldOpInterface106 bool mustBufferizeInPlace(Operation *op, OpOperand &opOperand,
107 const AnalysisState &state) const {
108 // Yield operands always bufferize inplace. Otherwise, an alloc + copy
109 // may be generated inside the block. We should not return/yield allocations
110 // when possible.
111 return true;
112 }
113
bufferizemlir::shape::__anon221796b20111::AssumingYieldOpInterface114 LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
115 const BufferizationOptions &options) const {
116 auto yieldOp = cast<shape::AssumingYieldOp>(op);
117 SmallVector<Value> newResults;
118 for (Value value : yieldOp.getOperands()) {
119 if (isa<TensorType>(value.getType())) {
120 FailureOr<Value> buffer = getBuffer(rewriter, value, options);
121 if (failed(buffer))
122 return failure();
123 newResults.push_back(*buffer);
124 } else {
125 newResults.push_back(value);
126 }
127 }
128 replaceOpWithNewBufferizedOp<shape::AssumingYieldOp>(rewriter, op,
129 newResults);
130 return success();
131 }
132 };
133
134 } // namespace
135 } // namespace shape
136 } // namespace mlir
137
registerBufferizableOpInterfaceExternalModels(DialectRegistry & registry)138 void mlir::shape::registerBufferizableOpInterfaceExternalModels(
139 DialectRegistry ®istry) {
140 registry.addExtension(+[](MLIRContext *ctx, shape::ShapeDialect *dialect) {
141 shape::AssumingOp::attachInterface<AssumingOpInterface>(*ctx);
142 shape::AssumingYieldOp::attachInterface<AssumingYieldOpInterface>(*ctx);
143 });
144 }
145