xref: /llvm-project/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp (revision a02ad6c1773368c9ce67d3a28578bf6284c6c1be)
193e66327SMatthias Springer //===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
293e66327SMatthias Springer //
393e66327SMatthias Springer // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
493e66327SMatthias Springer // See https://llvm.org/LICENSE.txt for license information.
593e66327SMatthias Springer // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
693e66327SMatthias Springer //
793e66327SMatthias Springer //===----------------------------------------------------------------------===//
893e66327SMatthias Springer 
993e66327SMatthias Springer #include "mlir/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.h"
1093e66327SMatthias Springer 
1193e66327SMatthias Springer #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
1293e66327SMatthias Springer #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
1393e66327SMatthias Springer #include "mlir/Dialect/Shape/IR/Shape.h"
1493e66327SMatthias Springer #include "mlir/IR/Dialect.h"
1593e66327SMatthias Springer #include "mlir/IR/Operation.h"
1693e66327SMatthias Springer #include "mlir/IR/PatternMatch.h"
1793e66327SMatthias Springer 
1893e66327SMatthias Springer using namespace mlir;
1993e66327SMatthias Springer using namespace mlir::bufferization;
2093e66327SMatthias Springer using namespace mlir::shape;
2193e66327SMatthias Springer 
2293e66327SMatthias Springer namespace mlir {
2393e66327SMatthias Springer namespace shape {
2493e66327SMatthias Springer namespace {
2593e66327SMatthias Springer 
2693e66327SMatthias Springer /// Bufferization of shape.assuming.
2793e66327SMatthias Springer struct AssumingOpInterface
2893e66327SMatthias Springer     : public BufferizableOpInterface::ExternalModel<AssumingOpInterface,
2993e66327SMatthias Springer                                                     shape::AssumingOp> {
301ac248e4SMatthias Springer   AliasingOpOperandList
getAliasingOpOperandsmlir::shape::__anon221796b20111::AssumingOpInterface31*a02ad6c1SMatthias Springer   getAliasingOpOperands(Operation *op, Value value,
329597b16aSMatthias Springer                         const AnalysisState &state) const {
3393e66327SMatthias Springer     // AssumingOps do not have tensor OpOperands. The yielded value can be any
3493e66327SMatthias Springer     // SSA value that is in scope. To allow for use-def chain traversal through
3593e66327SMatthias Springer     // AssumingOps in the analysis, the corresponding yield value is considered
3693e66327SMatthias Springer     // to be aliasing with the result.
3793e66327SMatthias Springer     auto assumingOp = cast<shape::AssumingOp>(op);
3893e66327SMatthias Springer     size_t resultNum = std::distance(op->getOpResults().begin(),
39*a02ad6c1SMatthias Springer                                      llvm::find(op->getOpResults(), value));
4093e66327SMatthias Springer     // TODO: Support multiple blocks.
4193e66327SMatthias Springer     assert(assumingOp.getDoRegion().getBlocks().size() == 1 &&
4293e66327SMatthias Springer            "expected exactly 1 block");
4393e66327SMatthias Springer     auto yieldOp = dyn_cast<shape::AssumingYieldOp>(
4493e66327SMatthias Springer         assumingOp.getDoRegion().front().getTerminator());
4593e66327SMatthias Springer     assert(yieldOp && "expected shape.assuming_yield terminator");
469fa6b350SMatthias Springer     return {{&yieldOp->getOpOperand(resultNum), BufferRelation::Equivalent}};
4793e66327SMatthias Springer   }
4893e66327SMatthias Springer 
bufferizemlir::shape::__anon221796b20111::AssumingOpInterface4993e66327SMatthias Springer   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
50b55d55ecSMatthias Springer                           const BufferizationOptions &options) const {
5193e66327SMatthias Springer     auto assumingOp = cast<shape::AssumingOp>(op);
5219efb84cSMatthias Springer     assert(assumingOp.getDoRegion().getBlocks().size() == 1 &&
5319efb84cSMatthias Springer            "only 1 block supported");
5419efb84cSMatthias Springer     auto yieldOp = cast<shape::AssumingYieldOp>(
5519efb84cSMatthias Springer         assumingOp.getDoRegion().front().getTerminator());
5693e66327SMatthias Springer 
5793e66327SMatthias Springer     // Create new op and move over region.
58b74192b7SRiver Riddle     TypeRange newResultTypes(yieldOp.getOperands());
5993e66327SMatthias Springer     auto newOp = rewriter.create<shape::AssumingOp>(
6093e66327SMatthias Springer         op->getLoc(), newResultTypes, assumingOp.getWitness());
6193e66327SMatthias Springer     newOp.getDoRegion().takeBody(assumingOp.getRegion());
6293e66327SMatthias Springer 
6393e66327SMatthias Springer     // Update all uses of the old op.
6493e66327SMatthias Springer     rewriter.setInsertionPointAfter(newOp);
6593e66327SMatthias Springer     SmallVector<Value> newResults;
6693e66327SMatthias Springer     for (const auto &it : llvm::enumerate(assumingOp->getResultTypes())) {
675550c821STres Popp       if (isa<TensorType>(it.value())) {
6893e66327SMatthias Springer         newResults.push_back(rewriter.create<bufferization::ToTensorOp>(
6993e66327SMatthias Springer             assumingOp.getLoc(), newOp->getResult(it.index())));
7093e66327SMatthias Springer       } else {
7193e66327SMatthias Springer         newResults.push_back(newOp->getResult(it.index()));
7293e66327SMatthias Springer       }
7393e66327SMatthias Springer     }
7493e66327SMatthias Springer 
7593e66327SMatthias Springer     // Replace old op.
7693e66327SMatthias Springer     rewriter.replaceOp(assumingOp, newResults);
7793e66327SMatthias Springer 
7893e66327SMatthias Springer     return success();
7993e66327SMatthias Springer   }
8093e66327SMatthias Springer };
8193e66327SMatthias Springer 
8293e66327SMatthias Springer /// Bufferization of shape.assuming_yield. Bufferized as part of their enclosing
8393e66327SMatthias Springer /// ops, so this is for analysis only.
8493e66327SMatthias Springer struct AssumingYieldOpInterface
8593e66327SMatthias Springer     : public BufferizableOpInterface::ExternalModel<AssumingYieldOpInterface,
866ab1ed43SMatthias Springer                                                     shape::AssumingYieldOp> {
bufferizesToMemoryReadmlir::shape::__anon221796b20111::AssumingYieldOpInterface8793e66327SMatthias Springer   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
889597b16aSMatthias Springer                               const AnalysisState &state) const {
8993e66327SMatthias Springer     return true;
9093e66327SMatthias Springer   }
9193e66327SMatthias Springer 
bufferizesToMemoryWritemlir::shape::__anon221796b20111::AssumingYieldOpInterface9293e66327SMatthias Springer   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
939597b16aSMatthias Springer                                const AnalysisState &state) const {
9493e66327SMatthias Springer     return false;
9593e66327SMatthias Springer   }
9693e66327SMatthias Springer 
getAliasingValuesmlir::shape::__anon221796b20111::AssumingYieldOpInterface97*a02ad6c1SMatthias Springer   AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
989597b16aSMatthias Springer                                       const AnalysisState &state) const {
9993e66327SMatthias Springer     assert(isa<shape::AssumingOp>(op->getParentOp()) &&
10093e66327SMatthias Springer            "expected that parent is an AssumingOp");
1011ac248e4SMatthias Springer     OpResult opResult =
1021ac248e4SMatthias Springer         op->getParentOp()->getResult(opOperand.getOperandNumber());
1039fa6b350SMatthias Springer     return {{opResult, BufferRelation::Equivalent}};
10493e66327SMatthias Springer   }
10593e66327SMatthias Springer 
mustBufferizeInPlacemlir::shape::__anon221796b20111::AssumingYieldOpInterface10693e66327SMatthias Springer   bool mustBufferizeInPlace(Operation *op, OpOperand &opOperand,
1079597b16aSMatthias Springer                             const AnalysisState &state) const {
10893e66327SMatthias Springer     // Yield operands always bufferize inplace. Otherwise, an alloc + copy
10993e66327SMatthias Springer     // may be generated inside the block. We should not return/yield allocations
11093e66327SMatthias Springer     // when possible.
11193e66327SMatthias Springer     return true;
11293e66327SMatthias Springer   }
11393e66327SMatthias Springer 
bufferizemlir::shape::__anon221796b20111::AssumingYieldOpInterface11493e66327SMatthias Springer   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
115b55d55ecSMatthias Springer                           const BufferizationOptions &options) const {
11619efb84cSMatthias Springer     auto yieldOp = cast<shape::AssumingYieldOp>(op);
11719efb84cSMatthias Springer     SmallVector<Value> newResults;
118b74192b7SRiver Riddle     for (Value value : yieldOp.getOperands()) {
1195550c821STres Popp       if (isa<TensorType>(value.getType())) {
1205d50f51cSMatthias Springer         FailureOr<Value> buffer = getBuffer(rewriter, value, options);
1215d50f51cSMatthias Springer         if (failed(buffer))
1225d50f51cSMatthias Springer           return failure();
1235d50f51cSMatthias Springer         newResults.push_back(*buffer);
1245d50f51cSMatthias Springer       } else {
1255d50f51cSMatthias Springer         newResults.push_back(value);
1265d50f51cSMatthias Springer       }
1275d50f51cSMatthias Springer     }
12819efb84cSMatthias Springer     replaceOpWithNewBufferizedOp<shape::AssumingYieldOp>(rewriter, op,
12919efb84cSMatthias Springer                                                          newResults);
130ba9d886dSMatthias Springer     return success();
13193e66327SMatthias Springer   }
13293e66327SMatthias Springer };
13393e66327SMatthias Springer 
13493e66327SMatthias Springer } // namespace
13593e66327SMatthias Springer } // namespace shape
13693e66327SMatthias Springer } // namespace mlir
13793e66327SMatthias Springer 
registerBufferizableOpInterfaceExternalModels(DialectRegistry & registry)13893e66327SMatthias Springer void mlir::shape::registerBufferizableOpInterfaceExternalModels(
13993e66327SMatthias Springer     DialectRegistry &registry) {
14077eee579SRiver Riddle   registry.addExtension(+[](MLIRContext *ctx, shape::ShapeDialect *dialect) {
14177eee579SRiver Riddle     shape::AssumingOp::attachInterface<AssumingOpInterface>(*ctx);
14277eee579SRiver Riddle     shape::AssumingYieldOp::attachInterface<AssumingYieldOpInterface>(*ctx);
14377eee579SRiver Riddle   });
14493e66327SMatthias Springer }
145