xref: /llvm-project/mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp (revision c95fcd343d405e659190b746052a9fcac573f8ac)
13b2004e1SMatthias Springer //===- TensorCopyInsertion.cpp - Resolve Bufferization Conflicts w/ Copies ===//
23b2004e1SMatthias Springer //
33b2004e1SMatthias Springer // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
43b2004e1SMatthias Springer // See https://llvm.org/LICENSE.txt for license information.
53b2004e1SMatthias Springer // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
63b2004e1SMatthias Springer //
73b2004e1SMatthias Springer //===----------------------------------------------------------------------===//
83b2004e1SMatthias Springer 
967d0d7acSMichele Scuttari #include "mlir/Dialect/Bufferization/Transforms/Passes.h"
103b2004e1SMatthias Springer 
113b2004e1SMatthias Springer #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
123b2004e1SMatthias Springer #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
133b2004e1SMatthias Springer #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
143b2004e1SMatthias Springer #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
153b2004e1SMatthias Springer #include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h"
1628b2f792SMatthias Springer #include "mlir/Dialect/Bufferization/Transforms/Transforms.h"
1767d0d7acSMichele Scuttari #include "mlir/Dialect/Func/IR/FuncOps.h"
1867d0d7acSMichele Scuttari 
1967d0d7acSMichele Scuttari namespace mlir {
2067d0d7acSMichele Scuttari namespace bufferization {
2167d0d7acSMichele Scuttari #define GEN_PASS_DEF_TENSORCOPYINSERTION
2267d0d7acSMichele Scuttari #include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc"
2367d0d7acSMichele Scuttari } // namespace bufferization
2467d0d7acSMichele Scuttari } // namespace mlir
253b2004e1SMatthias Springer 
263b2004e1SMatthias Springer using namespace mlir;
273b2004e1SMatthias Springer using namespace mlir::bufferization;
283b2004e1SMatthias Springer 
insertTensorCopies(Operation * op,const OneShotBufferizationOptions & options,BufferizationStatistics * statistics)293b2004e1SMatthias Springer LogicalResult mlir::bufferization::insertTensorCopies(
30*ae05bd99SMatthias Springer     Operation *op, const OneShotBufferizationOptions &options,
31*ae05bd99SMatthias Springer     BufferizationStatistics *statistics) {
323b2004e1SMatthias Springer   OneShotAnalysisState state(op, options);
333b2004e1SMatthias Springer   // Run normal One-Shot Bufferize analysis or One-Shot Module Bufferize
343b2004e1SMatthias Springer   // analysis depending on whether function boundary bufferization is enabled or
353b2004e1SMatthias Springer   // not.
363b2004e1SMatthias Springer   if (options.bufferizeFunctionBoundaries) {
37*ae05bd99SMatthias Springer     if (failed(analyzeModuleOp(cast<ModuleOp>(op), state, statistics)))
383b2004e1SMatthias Springer       return failure();
393b2004e1SMatthias Springer   } else {
40*ae05bd99SMatthias Springer     if (failed(analyzeOp(op, state, statistics)))
413b2004e1SMatthias Springer       return failure();
423b2004e1SMatthias Springer   }
433b2004e1SMatthias Springer 
443b2004e1SMatthias Springer   if (options.testAnalysisOnly)
453b2004e1SMatthias Springer     return success();
463b2004e1SMatthias Springer 
473b2004e1SMatthias Springer   return insertTensorCopies(op, state);
483b2004e1SMatthias Springer }
493b2004e1SMatthias Springer 
503b2004e1SMatthias Springer LogicalResult
insertTensorCopies(Operation * op,const AnalysisState & state)513b2004e1SMatthias Springer mlir::bufferization::insertTensorCopies(Operation *op,
523b2004e1SMatthias Springer                                         const AnalysisState &state) {
5387c770bbSMatthias Springer   IRRewriter rewriter(op->getContext());
543474d10eSMatthias Springer 
553b2004e1SMatthias Springer   WalkResult result = op->walk([&](Operation *op) {
563b2004e1SMatthias Springer     auto bufferizableOp = state.getOptions().dynCastBufferizableOp(op);
573b2004e1SMatthias Springer     if (!bufferizableOp)
583b2004e1SMatthias Springer       return WalkResult::skip();
593b2004e1SMatthias Springer 
6087c770bbSMatthias Springer     // Find inplacability conflicts and resolve them. (Typically with explicit
6187c770bbSMatthias Springer     // tensor copies in the form of AllocTensorOps.)
6287c770bbSMatthias Springer     rewriter.setInsertionPoint(op);
6387c770bbSMatthias Springer     if (failed(bufferizableOp.resolveConflicts(rewriter, state)))
643b2004e1SMatthias Springer       return WalkResult::interrupt();
653b2004e1SMatthias Springer 
663b2004e1SMatthias Springer     return WalkResult::advance();
673b2004e1SMatthias Springer   });
683b2004e1SMatthias Springer 
693b2004e1SMatthias Springer   return failure(result.wasInterrupted());
703b2004e1SMatthias Springer }
71