//===- BufferDeallocationOpInterfaceImpl.cpp ------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
88160bce9SMartin Erhart
98160bce9SMartin Erhart #include "mlir/Dialect/ControlFlow/Transforms/BufferDeallocationOpInterfaceImpl.h"
108160bce9SMartin Erhart #include "mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h"
118160bce9SMartin Erhart #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
128160bce9SMartin Erhart #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
138160bce9SMartin Erhart #include "mlir/Dialect/MemRef/IR/MemRef.h"
148160bce9SMartin Erhart #include "mlir/IR/Dialect.h"
158160bce9SMartin Erhart #include "mlir/IR/Operation.h"
168160bce9SMartin Erhart
178160bce9SMartin Erhart using namespace mlir;
188160bce9SMartin Erhart using namespace mlir::bufferization;
198160bce9SMartin Erhart
isMemref(Value v)20*a5757c5bSChristian Sigg static bool isMemref(Value v) { return isa<BaseMemRefType>(v.getType()); }
218160bce9SMartin Erhart
228160bce9SMartin Erhart namespace {
238160bce9SMartin Erhart /// While CondBranchOp also implement the BranchOpInterface, we add a
248160bce9SMartin Erhart /// special-case implementation here because the BranchOpInterface does not
258160bce9SMartin Erhart /// offer all of the functionallity we need to insert dealloc oeprations in an
268160bce9SMartin Erhart /// efficient way. More precisely, there is no way to extract the branch
278160bce9SMartin Erhart /// condition without casting to CondBranchOp specifically. It is still
288160bce9SMartin Erhart /// possible to implement deallocation for cases where we don't know to which
298160bce9SMartin Erhart /// successor the terminator branches before the actual branch happens by
308160bce9SMartin Erhart /// inserting auxiliary blocks and putting the dealloc op there, however, this
318160bce9SMartin Erhart /// can lead to less efficient code.
328160bce9SMartin Erhart /// This function inserts two dealloc operations (one for each successor) and
338160bce9SMartin Erhart /// adjusts the dealloc conditions according to the branch condition, then the
348160bce9SMartin Erhart /// ownerships of the retained MemRefs are updated by combining the result
358160bce9SMartin Erhart /// values of the two dealloc operations.
368160bce9SMartin Erhart ///
378160bce9SMartin Erhart /// Example:
388160bce9SMartin Erhart /// ```
398160bce9SMartin Erhart /// ^bb1:
408160bce9SMartin Erhart /// <more ops...>
418160bce9SMartin Erhart /// cf.cond_br cond, ^bb2(<forward-to-bb2>), ^bb3(<forward-to-bb2>)
428160bce9SMartin Erhart /// ```
438160bce9SMartin Erhart /// becomes
448160bce9SMartin Erhart /// ```
458160bce9SMartin Erhart /// // let (m, c) = getMemrefsAndConditionsToDeallocate(bb1)
468160bce9SMartin Erhart /// // let r0 = getMemrefsToRetain(bb1, bb2, <forward-to-bb2>)
478160bce9SMartin Erhart /// // let r1 = getMemrefsToRetain(bb1, bb3, <forward-to-bb3>)
488160bce9SMartin Erhart /// ^bb1:
498160bce9SMartin Erhart /// <more ops...>
508160bce9SMartin Erhart /// let thenCond = map(c, (c) -> arith.andi cond, c)
518160bce9SMartin Erhart /// let elseCond = map(c, (c) -> arith.andi (arith.xori cond, true), c)
528160bce9SMartin Erhart /// o0 = bufferization.dealloc m if thenCond retain r0
538160bce9SMartin Erhart /// o1 = bufferization.dealloc m if elseCond retain r1
548160bce9SMartin Erhart /// // replace ownership(r0) with o0 element-wise
558160bce9SMartin Erhart /// // replace ownership(r1) with o1 element-wise
568160bce9SMartin Erhart /// // let ownership0 := (r) -> o in o0 corresponding to r
578160bce9SMartin Erhart /// // let ownership1 := (r) -> o in o1 corresponding to r
588160bce9SMartin Erhart /// // let cmn := intersection(r0, r1)
598160bce9SMartin Erhart /// foreach (a, b) in zip(map(cmn, ownership0), map(cmn, ownership1)):
608160bce9SMartin Erhart /// forall r in r0: replace ownership0(r) with arith.select cond, a, b)
618160bce9SMartin Erhart /// forall r in r1: replace ownership1(r) with arith.select cond, a, b)
628160bce9SMartin Erhart /// cf.cond_br cond, ^bb2(<forward-to-bb2>, o0), ^bb3(<forward-to-bb3>, o1)
638160bce9SMartin Erhart /// ```
648160bce9SMartin Erhart struct CondBranchOpInterface
658160bce9SMartin Erhart : public BufferDeallocationOpInterface::ExternalModel<CondBranchOpInterface,
668160bce9SMartin Erhart cf::CondBranchOp> {
process__anone39b70920111::CondBranchOpInterface678160bce9SMartin Erhart FailureOr<Operation *> process(Operation *op, DeallocationState &state,
688160bce9SMartin Erhart const DeallocationOptions &options) const {
698160bce9SMartin Erhart OpBuilder builder(op);
708160bce9SMartin Erhart auto condBr = cast<cf::CondBranchOp>(op);
718160bce9SMartin Erhart
728160bce9SMartin Erhart // The list of memrefs to deallocate in this block is independent of which
738160bce9SMartin Erhart // branch is taken.
748160bce9SMartin Erhart SmallVector<Value> memrefs, conditions;
758160bce9SMartin Erhart if (failed(state.getMemrefsAndConditionsToDeallocate(
768160bce9SMartin Erhart builder, condBr.getLoc(), condBr->getBlock(), memrefs, conditions)))
778160bce9SMartin Erhart return failure();
788160bce9SMartin Erhart
798160bce9SMartin Erhart // Helper lambda to factor out common logic for inserting the dealloc
808160bce9SMartin Erhart // operations for each successor.
818160bce9SMartin Erhart auto insertDeallocForBranch =
828160bce9SMartin Erhart [&](Block *target, MutableOperandRange destOperands,
838160bce9SMartin Erhart const std::function<Value(Value)> &conditionModifier,
848160bce9SMartin Erhart DenseMap<Value, Value> &mapping) -> DeallocOp {
858160bce9SMartin Erhart SmallVector<Value> toRetain;
868160bce9SMartin Erhart state.getMemrefsToRetain(condBr->getBlock(), target,
87a9304edfSThomas Preud'homme destOperands.getAsOperandRange(), toRetain);
888160bce9SMartin Erhart SmallVector<Value> adaptedConditions(
898160bce9SMartin Erhart llvm::map_range(conditions, conditionModifier));
908160bce9SMartin Erhart auto deallocOp = builder.create<bufferization::DeallocOp>(
918160bce9SMartin Erhart condBr.getLoc(), memrefs, adaptedConditions, toRetain);
928160bce9SMartin Erhart state.resetOwnerships(deallocOp.getRetained(), condBr->getBlock());
938160bce9SMartin Erhart for (auto [retained, ownership] : llvm::zip(
948160bce9SMartin Erhart deallocOp.getRetained(), deallocOp.getUpdatedConditions())) {
958160bce9SMartin Erhart state.updateOwnership(retained, ownership, condBr->getBlock());
968160bce9SMartin Erhart mapping[retained] = ownership;
978160bce9SMartin Erhart }
988160bce9SMartin Erhart SmallVector<Value> replacements, ownerships;
996923a315SMatthias Springer for (OpOperand &operand : destOperands) {
1006923a315SMatthias Springer replacements.push_back(operand.get());
1016923a315SMatthias Springer if (isMemref(operand.get())) {
1026923a315SMatthias Springer assert(mapping.contains(operand.get()) &&
1038160bce9SMartin Erhart "Should be contained at this point");
1046923a315SMatthias Springer ownerships.push_back(mapping[operand.get()]);
1058160bce9SMartin Erhart }
1068160bce9SMartin Erhart }
1078160bce9SMartin Erhart replacements.append(ownerships);
1088160bce9SMartin Erhart destOperands.assign(replacements);
1098160bce9SMartin Erhart return deallocOp;
1108160bce9SMartin Erhart };
1118160bce9SMartin Erhart
1128160bce9SMartin Erhart // Call the helper lambda and make sure the dealloc conditions are properly
1138160bce9SMartin Erhart // modified to reflect the branch condition as well.
1148160bce9SMartin Erhart DenseMap<Value, Value> thenMapping, elseMapping;
1158160bce9SMartin Erhart DeallocOp thenTakenDeallocOp = insertDeallocForBranch(
1168160bce9SMartin Erhart condBr.getTrueDest(), condBr.getTrueDestOperandsMutable(),
1178160bce9SMartin Erhart [&](Value cond) {
1188160bce9SMartin Erhart return builder.create<arith::AndIOp>(condBr.getLoc(), cond,
1198160bce9SMartin Erhart condBr.getCondition());
1208160bce9SMartin Erhart },
1218160bce9SMartin Erhart thenMapping);
1228160bce9SMartin Erhart DeallocOp elseTakenDeallocOp = insertDeallocForBranch(
1238160bce9SMartin Erhart condBr.getFalseDest(), condBr.getFalseDestOperandsMutable(),
1248160bce9SMartin Erhart [&](Value cond) {
1258160bce9SMartin Erhart Value trueVal = builder.create<arith::ConstantOp>(
1268160bce9SMartin Erhart condBr.getLoc(), builder.getBoolAttr(true));
1278160bce9SMartin Erhart Value negation = builder.create<arith::XOrIOp>(
1288160bce9SMartin Erhart condBr.getLoc(), trueVal, condBr.getCondition());
1298160bce9SMartin Erhart return builder.create<arith::AndIOp>(condBr.getLoc(), cond, negation);
1308160bce9SMartin Erhart },
1318160bce9SMartin Erhart elseMapping);
1328160bce9SMartin Erhart
1338160bce9SMartin Erhart // We specifically need to update the ownerships of values that are retained
1348160bce9SMartin Erhart // in both dealloc operations again to get a combined 'Unique' ownership
1358160bce9SMartin Erhart // instead of an 'Unknown' ownership.
1368160bce9SMartin Erhart SmallPtrSet<Value, 16> thenValues(thenTakenDeallocOp.getRetained().begin(),
1378160bce9SMartin Erhart thenTakenDeallocOp.getRetained().end());
1388160bce9SMartin Erhart SetVector<Value> commonValues;
1398160bce9SMartin Erhart for (Value val : elseTakenDeallocOp.getRetained()) {
1408160bce9SMartin Erhart if (thenValues.contains(val))
1418160bce9SMartin Erhart commonValues.insert(val);
1428160bce9SMartin Erhart }
1438160bce9SMartin Erhart
1448160bce9SMartin Erhart for (Value retained : commonValues) {
1458160bce9SMartin Erhart state.resetOwnerships(retained, condBr->getBlock());
1468160bce9SMartin Erhart Value combinedOwnership = builder.create<arith::SelectOp>(
1478160bce9SMartin Erhart condBr.getLoc(), condBr.getCondition(), thenMapping[retained],
1488160bce9SMartin Erhart elseMapping[retained]);
1498160bce9SMartin Erhart state.updateOwnership(retained, combinedOwnership, condBr->getBlock());
1508160bce9SMartin Erhart }
1518160bce9SMartin Erhart
1528160bce9SMartin Erhart return condBr.getOperation();
1538160bce9SMartin Erhart }
1548160bce9SMartin Erhart };
1558160bce9SMartin Erhart
1568160bce9SMartin Erhart } // namespace
1578160bce9SMartin Erhart
registerBufferDeallocationOpInterfaceExternalModels(DialectRegistry & registry)1588160bce9SMartin Erhart void mlir::cf::registerBufferDeallocationOpInterfaceExternalModels(
1598160bce9SMartin Erhart DialectRegistry ®istry) {
1608160bce9SMartin Erhart registry.addExtension(+[](MLIRContext *ctx, ControlFlowDialect *dialect) {
1618160bce9SMartin Erhart CondBranchOp::attachInterface<CondBranchOpInterface>(*ctx);
1628160bce9SMartin Erhart });
1638160bce9SMartin Erhart }
164