xref: /llvm-project/mlir/lib/Dialect/SCF/Transforms/BufferDeallocationOpInterfaceImpl.cpp (revision ba727ac2199c3c1cecfdaaa487cca0ffc29d2e64)
1 //===- BufferDeallocationOpInterfaceImpl.cpp ------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/SCF/Transforms/BufferDeallocationOpInterfaceImpl.h"
10 #include "mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h"
11 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
12 #include "mlir/Dialect/SCF/IR/SCF.h"
13 
14 using namespace mlir;
15 using namespace mlir::bufferization;
16 
17 namespace {
18 /// The `scf.forall.in_parallel` terminator is special in a few ways:
19 /// * It does not implement the BranchOpInterface or
20 ///   RegionBranchTerminatorOpInterface, but the ParallelCombiningOpInterface
21 ///   which is not supported by BufferDeallocation.
22 /// * It has a graph-like region which only allows one specific tensor op
23 /// * After bufferization the nested region is always empty
24 /// For these reasons we provide custom deallocation logic via this external
25 /// model.
26 ///
27 /// Example:
28 /// ```mlir
29 /// scf.forall (%arg1) in (%arg0) {
30 ///   %alloc = memref.alloc() : memref<2xf32>
31 ///   ...
32 ///   <implicit in_parallel terminator here>
33 /// }
34 /// ```
35 /// gets transformed to
36 /// ```mlir
37 /// scf.forall (%arg1) in (%arg0) {
38 ///   %alloc = memref.alloc() : memref<2xf32>
39 ///   ...
40 ///   bufferization.dealloc (%alloc : memref<2xf32>) if (%true)
41 ///   <implicit in_parallel terminator here>
42 /// }
43 /// ```
44 struct InParallelOpInterface
45     : public BufferDeallocationOpInterface::ExternalModel<InParallelOpInterface,
46                                                           scf::InParallelOp> {
process__anon22d8e8550111::InParallelOpInterface47   FailureOr<Operation *> process(Operation *op, DeallocationState &state,
48                                  const DeallocationOptions &options) const {
49     auto inParallelOp = cast<scf::InParallelOp>(op);
50     if (!inParallelOp.getBody()->empty())
51       return op->emitError("only supported when nested region is empty");
52 
53     SmallVector<Value> updatedOperandOwnership;
54     return deallocation_impl::insertDeallocOpForReturnLike(
55         state, op, {}, updatedOperandOwnership);
56   }
57 };
58 
59 struct ReduceReturnOpInterface
60     : public BufferDeallocationOpInterface::ExternalModel<
61           ReduceReturnOpInterface, scf::ReduceReturnOp> {
process__anon22d8e8550111::ReduceReturnOpInterface62   FailureOr<Operation *> process(Operation *op, DeallocationState &state,
63                                  const DeallocationOptions &options) const {
64     auto reduceReturnOp = cast<scf::ReduceReturnOp>(op);
65     if (isa<BaseMemRefType>(reduceReturnOp.getOperand().getType()))
66       return op->emitError("only supported when operand is not a MemRef");
67 
68     SmallVector<Value> updatedOperandOwnership;
69     return deallocation_impl::insertDeallocOpForReturnLike(
70         state, op, {}, updatedOperandOwnership);
71   }
72 };
73 
74 } // namespace
75 
registerBufferDeallocationOpInterfaceExternalModels(DialectRegistry & registry)76 void mlir::scf::registerBufferDeallocationOpInterfaceExternalModels(
77     DialectRegistry &registry) {
78   registry.addExtension(+[](MLIRContext *ctx, SCFDialect *dialect) {
79     InParallelOp::attachInterface<InParallelOpInterface>(*ctx);
80     ReduceReturnOp::attachInterface<ReduceReturnOpInterface>(*ctx);
81   });
82 }
83