13b2004e1SMatthias Springer //===- TensorCopyInsertion.cpp - Resolve Bufferization Conflicts w/ Copies ===//
23b2004e1SMatthias Springer //
33b2004e1SMatthias Springer // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
43b2004e1SMatthias Springer // See https://llvm.org/LICENSE.txt for license information.
53b2004e1SMatthias Springer // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
63b2004e1SMatthias Springer //
73b2004e1SMatthias Springer //===----------------------------------------------------------------------===//
83b2004e1SMatthias Springer
967d0d7acSMichele Scuttari #include "mlir/Dialect/Bufferization/Transforms/Passes.h"
103b2004e1SMatthias Springer
113b2004e1SMatthias Springer #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
123b2004e1SMatthias Springer #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
133b2004e1SMatthias Springer #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
143b2004e1SMatthias Springer #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
153b2004e1SMatthias Springer #include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h"
1628b2f792SMatthias Springer #include "mlir/Dialect/Bufferization/Transforms/Transforms.h"
1767d0d7acSMichele Scuttari #include "mlir/Dialect/Func/IR/FuncOps.h"
1867d0d7acSMichele Scuttari
1967d0d7acSMichele Scuttari namespace mlir {
2067d0d7acSMichele Scuttari namespace bufferization {
2167d0d7acSMichele Scuttari #define GEN_PASS_DEF_TENSORCOPYINSERTION
2267d0d7acSMichele Scuttari #include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc"
2367d0d7acSMichele Scuttari } // namespace bufferization
2467d0d7acSMichele Scuttari } // namespace mlir
253b2004e1SMatthias Springer
263b2004e1SMatthias Springer using namespace mlir;
273b2004e1SMatthias Springer using namespace mlir::bufferization;
283b2004e1SMatthias Springer
insertTensorCopies(Operation * op,const OneShotBufferizationOptions & options,BufferizationStatistics * statistics)293b2004e1SMatthias Springer LogicalResult mlir::bufferization::insertTensorCopies(
30*ae05bd99SMatthias Springer Operation *op, const OneShotBufferizationOptions &options,
31*ae05bd99SMatthias Springer BufferizationStatistics *statistics) {
323b2004e1SMatthias Springer OneShotAnalysisState state(op, options);
333b2004e1SMatthias Springer // Run normal One-Shot Bufferize analysis or One-Shot Module Bufferize
343b2004e1SMatthias Springer // analysis depending on whether function boundary bufferization is enabled or
353b2004e1SMatthias Springer // not.
363b2004e1SMatthias Springer if (options.bufferizeFunctionBoundaries) {
37*ae05bd99SMatthias Springer if (failed(analyzeModuleOp(cast<ModuleOp>(op), state, statistics)))
383b2004e1SMatthias Springer return failure();
393b2004e1SMatthias Springer } else {
40*ae05bd99SMatthias Springer if (failed(analyzeOp(op, state, statistics)))
413b2004e1SMatthias Springer return failure();
423b2004e1SMatthias Springer }
433b2004e1SMatthias Springer
443b2004e1SMatthias Springer if (options.testAnalysisOnly)
453b2004e1SMatthias Springer return success();
463b2004e1SMatthias Springer
473b2004e1SMatthias Springer return insertTensorCopies(op, state);
483b2004e1SMatthias Springer }
493b2004e1SMatthias Springer
503b2004e1SMatthias Springer LogicalResult
insertTensorCopies(Operation * op,const AnalysisState & state)513b2004e1SMatthias Springer mlir::bufferization::insertTensorCopies(Operation *op,
523b2004e1SMatthias Springer const AnalysisState &state) {
5387c770bbSMatthias Springer IRRewriter rewriter(op->getContext());
543474d10eSMatthias Springer
553b2004e1SMatthias Springer WalkResult result = op->walk([&](Operation *op) {
563b2004e1SMatthias Springer auto bufferizableOp = state.getOptions().dynCastBufferizableOp(op);
573b2004e1SMatthias Springer if (!bufferizableOp)
583b2004e1SMatthias Springer return WalkResult::skip();
593b2004e1SMatthias Springer
6087c770bbSMatthias Springer // Find inplacability conflicts and resolve them. (Typically with explicit
6187c770bbSMatthias Springer // tensor copies in the form of AllocTensorOps.)
6287c770bbSMatthias Springer rewriter.setInsertionPoint(op);
6387c770bbSMatthias Springer if (failed(bufferizableOp.resolveConflicts(rewriter, state)))
643b2004e1SMatthias Springer return WalkResult::interrupt();
653b2004e1SMatthias Springer
663b2004e1SMatthias Springer return WalkResult::advance();
673b2004e1SMatthias Springer });
683b2004e1SMatthias Springer
693b2004e1SMatthias Springer return failure(result.wasInterrupted());
703b2004e1SMatthias Springer }
71