xref: /llvm-project/mlir/lib/Dialect/Async/Transforms/PassDetail.cpp (revision b537c5b4147b6966fda8d80ed291f6b1f3857b16)
1 //===- PassDetail.cpp - Async Pass class details ----------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "PassDetail.h"
10 #include "mlir/IR/Builders.h"
11 #include "mlir/Transforms/RegionUtils.h"
12 
13 using namespace mlir;
14 
cloneConstantsIntoTheRegion(Region & region)15 void mlir::async::cloneConstantsIntoTheRegion(Region &region) {
16   OpBuilder builder(&region);
17   cloneConstantsIntoTheRegion(region, builder);
18 }
19 
cloneConstantsIntoTheRegion(Region & region,OpBuilder & builder)20 void mlir::async::cloneConstantsIntoTheRegion(Region &region,
21                                               OpBuilder &builder) {
22   // Values implicitly captured by the region.
23   llvm::SetVector<Value> captures;
24   getUsedValuesDefinedAbove(region, region, captures);
25 
26   OpBuilder::InsertionGuard guard(builder);
27   builder.setInsertionPointToStart(&region.front());
28 
29   // Clone ConstantLike operations into the region.
30   for (Value capture : captures) {
31     Operation *op = capture.getDefiningOp();
32     if (!op || !op->hasTrait<OpTrait::ConstantLike>())
33       continue;
34 
35     Operation *cloned = builder.clone(*op);
36 
37     for (auto tuple : llvm::zip(op->getResults(), cloned->getResults())) {
38       Value orig = std::get<0>(tuple);
39       Value replacement = std::get<1>(tuple);
40       replaceAllUsesInRegionWith(orig, replacement, region);
41     }
42   }
43 }
44