1 //===- BufferDeallocationOpInterfaceImpl.cpp ------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Dialect/SCF/Transforms/BufferDeallocationOpInterfaceImpl.h"
10 #include "mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h"
11 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
12 #include "mlir/Dialect/SCF/IR/SCF.h"
13
14 using namespace mlir;
15 using namespace mlir::bufferization;
16
17 namespace {
18 /// The `scf.forall.in_parallel` terminator is special in a few ways:
19 /// * It does not implement the BranchOpInterface or
20 /// RegionBranchTerminatorOpInterface, but the ParallelCombiningOpInterface
21 /// which is not supported by BufferDeallocation.
22 /// * It has a graph-like region which only allows one specific tensor op
23 /// * After bufferization the nested region is always empty
24 /// For these reasons we provide custom deallocation logic via this external
25 /// model.
26 ///
27 /// Example:
28 /// ```mlir
29 /// scf.forall (%arg1) in (%arg0) {
30 /// %alloc = memref.alloc() : memref<2xf32>
31 /// ...
32 /// <implicit in_parallel terminator here>
33 /// }
34 /// ```
35 /// gets transformed to
36 /// ```mlir
37 /// scf.forall (%arg1) in (%arg0) {
38 /// %alloc = memref.alloc() : memref<2xf32>
39 /// ...
40 /// bufferization.dealloc (%alloc : memref<2xf32>) if (%true)
41 /// <implicit in_parallel terminator here>
42 /// }
43 /// ```
44 struct InParallelOpInterface
45 : public BufferDeallocationOpInterface::ExternalModel<InParallelOpInterface,
46 scf::InParallelOp> {
process__anon22d8e8550111::InParallelOpInterface47 FailureOr<Operation *> process(Operation *op, DeallocationState &state,
48 const DeallocationOptions &options) const {
49 auto inParallelOp = cast<scf::InParallelOp>(op);
50 if (!inParallelOp.getBody()->empty())
51 return op->emitError("only supported when nested region is empty");
52
53 SmallVector<Value> updatedOperandOwnership;
54 return deallocation_impl::insertDeallocOpForReturnLike(
55 state, op, {}, updatedOperandOwnership);
56 }
57 };
58
59 struct ReduceReturnOpInterface
60 : public BufferDeallocationOpInterface::ExternalModel<
61 ReduceReturnOpInterface, scf::ReduceReturnOp> {
process__anon22d8e8550111::ReduceReturnOpInterface62 FailureOr<Operation *> process(Operation *op, DeallocationState &state,
63 const DeallocationOptions &options) const {
64 auto reduceReturnOp = cast<scf::ReduceReturnOp>(op);
65 if (isa<BaseMemRefType>(reduceReturnOp.getOperand().getType()))
66 return op->emitError("only supported when operand is not a MemRef");
67
68 SmallVector<Value> updatedOperandOwnership;
69 return deallocation_impl::insertDeallocOpForReturnLike(
70 state, op, {}, updatedOperandOwnership);
71 }
72 };
73
74 } // namespace
75
registerBufferDeallocationOpInterfaceExternalModels(DialectRegistry & registry)76 void mlir::scf::registerBufferDeallocationOpInterfaceExternalModels(
77 DialectRegistry ®istry) {
78 registry.addExtension(+[](MLIRContext *ctx, SCFDialect *dialect) {
79 InParallelOp::attachInterface<InParallelOpInterface>(*ctx);
80 ReduceReturnOp::attachInterface<ReduceReturnOpInterface>(*ctx);
81 });
82 }
83