1*99c15eb4SIngo Müller //===- Utils.cpp - Utils related to the transform dialect -------*- C++ -*-===//
2*99c15eb4SIngo Müller //
3*99c15eb4SIngo Müller // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4*99c15eb4SIngo Müller // See https://llvm.org/LICENSE.txt for license information.
5*99c15eb4SIngo Müller // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6*99c15eb4SIngo Müller //
7*99c15eb4SIngo Müller //===----------------------------------------------------------------------===//
8*99c15eb4SIngo Müller
9*99c15eb4SIngo Müller #include "mlir/Dialect/Transform/IR/Utils.h"
10*99c15eb4SIngo Müller #include "mlir/Dialect/Transform/IR/TransformDialect.h"
11*99c15eb4SIngo Müller #include "mlir/IR/Verifier.h"
12*99c15eb4SIngo Müller #include "mlir/Interfaces/FunctionInterfaces.h"
13*99c15eb4SIngo Müller #include "llvm/Support/Debug.h"
14*99c15eb4SIngo Müller
15*99c15eb4SIngo Müller using namespace mlir;
16*99c15eb4SIngo Müller
17*99c15eb4SIngo Müller #define DEBUG_TYPE "transform-dialect-utils"
18*99c15eb4SIngo Müller #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
19*99c15eb4SIngo Müller
20*99c15eb4SIngo Müller /// Return whether `func1` can be merged into `func2`. For that to work
21*99c15eb4SIngo Müller /// `func1` has to be a declaration (aka has to be external) and `func2`
22*99c15eb4SIngo Müller /// either has to be a declaration as well, or it has to be public (otherwise,
23*99c15eb4SIngo Müller /// it wouldn't be visible by `func1`).
canMergeInto(FunctionOpInterface func1,FunctionOpInterface func2)24*99c15eb4SIngo Müller static bool canMergeInto(FunctionOpInterface func1, FunctionOpInterface func2) {
25*99c15eb4SIngo Müller return func1.isExternal() && (func2.isPublic() || func2.isExternal());
26*99c15eb4SIngo Müller }
27*99c15eb4SIngo Müller
28*99c15eb4SIngo Müller /// Merge `func1` into `func2`. The two ops must be inside the same parent op
29*99c15eb4SIngo Müller /// and mergable according to `canMergeInto`. The function erases `func1` such
30*99c15eb4SIngo Müller /// that only `func2` exists when the function returns.
mergeInto(FunctionOpInterface func1,FunctionOpInterface func2)31*99c15eb4SIngo Müller static InFlightDiagnostic mergeInto(FunctionOpInterface func1,
32*99c15eb4SIngo Müller FunctionOpInterface func2) {
33*99c15eb4SIngo Müller assert(canMergeInto(func1, func2));
34*99c15eb4SIngo Müller assert(func1->getParentOp() == func2->getParentOp() &&
35*99c15eb4SIngo Müller "expected func1 and func2 to be in the same parent op");
36*99c15eb4SIngo Müller
37*99c15eb4SIngo Müller // Check that function signatures match.
38*99c15eb4SIngo Müller if (func1.getFunctionType() != func2.getFunctionType()) {
39*99c15eb4SIngo Müller return func1.emitError()
40*99c15eb4SIngo Müller << "external definition has a mismatching signature ("
41*99c15eb4SIngo Müller << func2.getFunctionType() << ")";
42*99c15eb4SIngo Müller }
43*99c15eb4SIngo Müller
44*99c15eb4SIngo Müller // Check and merge argument attributes.
45*99c15eb4SIngo Müller MLIRContext *context = func1->getContext();
46*99c15eb4SIngo Müller auto *td = context->getLoadedDialect<transform::TransformDialect>();
47*99c15eb4SIngo Müller StringAttr consumedName = td->getConsumedAttrName();
48*99c15eb4SIngo Müller StringAttr readOnlyName = td->getReadOnlyAttrName();
49*99c15eb4SIngo Müller for (unsigned i = 0, e = func1.getNumArguments(); i < e; ++i) {
50*99c15eb4SIngo Müller bool isExternalConsumed = func2.getArgAttr(i, consumedName) != nullptr;
51*99c15eb4SIngo Müller bool isExternalReadonly = func2.getArgAttr(i, readOnlyName) != nullptr;
52*99c15eb4SIngo Müller bool isConsumed = func1.getArgAttr(i, consumedName) != nullptr;
53*99c15eb4SIngo Müller bool isReadonly = func1.getArgAttr(i, readOnlyName) != nullptr;
54*99c15eb4SIngo Müller if (!isExternalConsumed && !isExternalReadonly) {
55*99c15eb4SIngo Müller if (isConsumed)
56*99c15eb4SIngo Müller func2.setArgAttr(i, consumedName, UnitAttr::get(context));
57*99c15eb4SIngo Müller else if (isReadonly)
58*99c15eb4SIngo Müller func2.setArgAttr(i, readOnlyName, UnitAttr::get(context));
59*99c15eb4SIngo Müller continue;
60*99c15eb4SIngo Müller }
61*99c15eb4SIngo Müller
62*99c15eb4SIngo Müller if ((isExternalConsumed && !isConsumed) ||
63*99c15eb4SIngo Müller (isExternalReadonly && !isReadonly)) {
64*99c15eb4SIngo Müller return func1.emitError()
65*99c15eb4SIngo Müller << "external definition has mismatching consumption "
66*99c15eb4SIngo Müller "annotations for argument #"
67*99c15eb4SIngo Müller << i;
68*99c15eb4SIngo Müller }
69*99c15eb4SIngo Müller }
70*99c15eb4SIngo Müller
71*99c15eb4SIngo Müller // `func1` is the external one, so we can remove it.
72*99c15eb4SIngo Müller assert(func1.isExternal());
73*99c15eb4SIngo Müller func1->erase();
74*99c15eb4SIngo Müller
75*99c15eb4SIngo Müller return InFlightDiagnostic();
76*99c15eb4SIngo Müller }
77*99c15eb4SIngo Müller
78*99c15eb4SIngo Müller InFlightDiagnostic
mergeSymbolsInto(Operation * target,OwningOpRef<Operation * > other)79*99c15eb4SIngo Müller transform::detail::mergeSymbolsInto(Operation *target,
80*99c15eb4SIngo Müller OwningOpRef<Operation *> other) {
81*99c15eb4SIngo Müller assert(target->hasTrait<OpTrait::SymbolTable>() &&
82*99c15eb4SIngo Müller "requires target to implement the 'SymbolTable' trait");
83*99c15eb4SIngo Müller assert(other->hasTrait<OpTrait::SymbolTable>() &&
84*99c15eb4SIngo Müller "requires target to implement the 'SymbolTable' trait");
85*99c15eb4SIngo Müller
86*99c15eb4SIngo Müller SymbolTable targetSymbolTable(target);
87*99c15eb4SIngo Müller SymbolTable otherSymbolTable(*other);
88*99c15eb4SIngo Müller
89*99c15eb4SIngo Müller // Step 1:
90*99c15eb4SIngo Müller //
91*99c15eb4SIngo Müller // Rename private symbols in both ops in order to resolve conflicts that can
92*99c15eb4SIngo Müller // be resolved that way.
93*99c15eb4SIngo Müller LLVM_DEBUG(DBGS() << "renaming private symbols to resolve conflicts:\n");
94*99c15eb4SIngo Müller // TODO: Do we *actually* need to test in both directions?
95*99c15eb4SIngo Müller for (auto &&[symbolTable, otherSymbolTable] : llvm::zip(
96*99c15eb4SIngo Müller SmallVector<SymbolTable *, 2>{&targetSymbolTable, &otherSymbolTable},
97*99c15eb4SIngo Müller SmallVector<SymbolTable *, 2>{&otherSymbolTable,
98*99c15eb4SIngo Müller &targetSymbolTable})) {
99*99c15eb4SIngo Müller Operation *symbolTableOp = symbolTable->getOp();
100*99c15eb4SIngo Müller for (Operation &op : symbolTableOp->getRegion(0).front()) {
101*99c15eb4SIngo Müller auto symbolOp = dyn_cast<SymbolOpInterface>(op);
102*99c15eb4SIngo Müller if (!symbolOp)
103*99c15eb4SIngo Müller continue;
104*99c15eb4SIngo Müller StringAttr name = symbolOp.getNameAttr();
105*99c15eb4SIngo Müller LLVM_DEBUG(DBGS() << " found @" << name.getValue() << "\n");
106*99c15eb4SIngo Müller
107*99c15eb4SIngo Müller // Check if there is a colliding op in the other module.
108*99c15eb4SIngo Müller auto collidingOp =
109*99c15eb4SIngo Müller cast_or_null<SymbolOpInterface>(otherSymbolTable->lookup(name));
110*99c15eb4SIngo Müller if (!collidingOp)
111*99c15eb4SIngo Müller continue;
112*99c15eb4SIngo Müller
113*99c15eb4SIngo Müller LLVM_DEBUG(DBGS() << " collision found for @" << name.getValue());
114*99c15eb4SIngo Müller
115*99c15eb4SIngo Müller // Collisions are fine if both opt are functions and can be merged.
116*99c15eb4SIngo Müller if (auto funcOp = dyn_cast<FunctionOpInterface>(op),
117*99c15eb4SIngo Müller collidingFuncOp =
118*99c15eb4SIngo Müller dyn_cast<FunctionOpInterface>(collidingOp.getOperation());
119*99c15eb4SIngo Müller funcOp && collidingFuncOp) {
120*99c15eb4SIngo Müller if (canMergeInto(funcOp, collidingFuncOp) ||
121*99c15eb4SIngo Müller canMergeInto(collidingFuncOp, funcOp)) {
122*99c15eb4SIngo Müller LLVM_DEBUG(llvm::dbgs() << " but both ops are functions and "
123*99c15eb4SIngo Müller "will be merged\n");
124*99c15eb4SIngo Müller continue;
125*99c15eb4SIngo Müller }
126*99c15eb4SIngo Müller
127*99c15eb4SIngo Müller // If they can't be merged, proceed like any other collision.
128*99c15eb4SIngo Müller LLVM_DEBUG(llvm::dbgs() << " and both ops are function definitions");
129*99c15eb4SIngo Müller }
130*99c15eb4SIngo Müller
131*99c15eb4SIngo Müller // Collision can be resolved by renaming if one of the ops is private.
132*99c15eb4SIngo Müller auto renameToUnique =
133*99c15eb4SIngo Müller [&](SymbolOpInterface op, SymbolOpInterface otherOp,
134*99c15eb4SIngo Müller SymbolTable &symbolTable,
135*99c15eb4SIngo Müller SymbolTable &otherSymbolTable) -> InFlightDiagnostic {
136*99c15eb4SIngo Müller LLVM_DEBUG(llvm::dbgs() << ", renaming\n");
137*99c15eb4SIngo Müller FailureOr<StringAttr> maybeNewName =
138*99c15eb4SIngo Müller symbolTable.renameToUnique(op, {&otherSymbolTable});
139*99c15eb4SIngo Müller if (failed(maybeNewName)) {
140*99c15eb4SIngo Müller InFlightDiagnostic diag = op->emitError("failed to rename symbol");
141*99c15eb4SIngo Müller diag.attachNote(otherOp->getLoc())
142*99c15eb4SIngo Müller << "attempted renaming due to collision with this op";
143*99c15eb4SIngo Müller return diag;
144*99c15eb4SIngo Müller }
145*99c15eb4SIngo Müller LLVM_DEBUG(DBGS() << " renamed to @" << maybeNewName->getValue()
146*99c15eb4SIngo Müller << "\n");
147*99c15eb4SIngo Müller return InFlightDiagnostic();
148*99c15eb4SIngo Müller };
149*99c15eb4SIngo Müller
150*99c15eb4SIngo Müller if (symbolOp.isPrivate()) {
151*99c15eb4SIngo Müller InFlightDiagnostic diag = renameToUnique(
152*99c15eb4SIngo Müller symbolOp, collidingOp, *symbolTable, *otherSymbolTable);
153*99c15eb4SIngo Müller if (failed(diag))
154*99c15eb4SIngo Müller return diag;
155*99c15eb4SIngo Müller continue;
156*99c15eb4SIngo Müller }
157*99c15eb4SIngo Müller if (collidingOp.isPrivate()) {
158*99c15eb4SIngo Müller InFlightDiagnostic diag = renameToUnique(
159*99c15eb4SIngo Müller collidingOp, symbolOp, *otherSymbolTable, *symbolTable);
160*99c15eb4SIngo Müller if (failed(diag))
161*99c15eb4SIngo Müller return diag;
162*99c15eb4SIngo Müller continue;
163*99c15eb4SIngo Müller }
164*99c15eb4SIngo Müller LLVM_DEBUG(llvm::dbgs() << ", emitting error\n");
165*99c15eb4SIngo Müller InFlightDiagnostic diag = symbolOp.emitError()
166*99c15eb4SIngo Müller << "doubly defined symbol @" << name.getValue();
167*99c15eb4SIngo Müller diag.attachNote(collidingOp->getLoc()) << "previously defined here";
168*99c15eb4SIngo Müller return diag;
169*99c15eb4SIngo Müller }
170*99c15eb4SIngo Müller }
171*99c15eb4SIngo Müller
172*99c15eb4SIngo Müller // TODO: This duplicates pass infrastructure. We should split this pass into
173*99c15eb4SIngo Müller // several and let the pass infrastructure do the verification.
174*99c15eb4SIngo Müller for (auto *op : SmallVector<Operation *>{target, *other}) {
175*99c15eb4SIngo Müller if (failed(mlir::verify(op)))
176*99c15eb4SIngo Müller return op->emitError() << "failed to verify input op after renaming";
177*99c15eb4SIngo Müller }
178*99c15eb4SIngo Müller
179*99c15eb4SIngo Müller // Step 2:
180*99c15eb4SIngo Müller //
181*99c15eb4SIngo Müller // Move all ops from `other` into target and merge public symbols.
182*99c15eb4SIngo Müller LLVM_DEBUG(DBGS() << "moving all symbols into target\n");
183*99c15eb4SIngo Müller {
184*99c15eb4SIngo Müller SmallVector<SymbolOpInterface> opsToMove;
185*99c15eb4SIngo Müller for (Operation &op : other->getRegion(0).front()) {
186*99c15eb4SIngo Müller if (auto symbol = dyn_cast<SymbolOpInterface>(op))
187*99c15eb4SIngo Müller opsToMove.push_back(symbol);
188*99c15eb4SIngo Müller }
189*99c15eb4SIngo Müller
190*99c15eb4SIngo Müller for (SymbolOpInterface op : opsToMove) {
191*99c15eb4SIngo Müller // Remember potentially colliding op in the target module.
192*99c15eb4SIngo Müller auto collidingOp = cast_or_null<SymbolOpInterface>(
193*99c15eb4SIngo Müller targetSymbolTable.lookup(op.getNameAttr()));
194*99c15eb4SIngo Müller
195*99c15eb4SIngo Müller // Move op even if we get a collision.
196*99c15eb4SIngo Müller LLVM_DEBUG(DBGS() << " moving @" << op.getName());
197*99c15eb4SIngo Müller op->moveBefore(&target->getRegion(0).front(),
198*99c15eb4SIngo Müller target->getRegion(0).front().end());
199*99c15eb4SIngo Müller
200*99c15eb4SIngo Müller // If there is no collision, we are done.
201*99c15eb4SIngo Müller if (!collidingOp) {
202*99c15eb4SIngo Müller LLVM_DEBUG(llvm::dbgs() << " without collision\n");
203*99c15eb4SIngo Müller continue;
204*99c15eb4SIngo Müller }
205*99c15eb4SIngo Müller
206*99c15eb4SIngo Müller // The two colliding ops must both be functions because we have already
207*99c15eb4SIngo Müller // emitted errors otherwise earlier.
208*99c15eb4SIngo Müller auto funcOp = cast<FunctionOpInterface>(op.getOperation());
209*99c15eb4SIngo Müller auto collidingFuncOp =
210*99c15eb4SIngo Müller cast<FunctionOpInterface>(collidingOp.getOperation());
211*99c15eb4SIngo Müller
212*99c15eb4SIngo Müller // Both ops are in the target module now and can be treated
213*99c15eb4SIngo Müller // symmetrically, so w.l.o.g. we can reduce to merging `funcOp` into
214*99c15eb4SIngo Müller // `collidingFuncOp`.
215*99c15eb4SIngo Müller if (!canMergeInto(funcOp, collidingFuncOp)) {
216*99c15eb4SIngo Müller std::swap(funcOp, collidingFuncOp);
217*99c15eb4SIngo Müller }
218*99c15eb4SIngo Müller assert(canMergeInto(funcOp, collidingFuncOp));
219*99c15eb4SIngo Müller
220*99c15eb4SIngo Müller LLVM_DEBUG(llvm::dbgs() << " with collision, trying to keep op at "
221*99c15eb4SIngo Müller << collidingFuncOp.getLoc() << ":\n"
222*99c15eb4SIngo Müller << collidingFuncOp << "\n");
223*99c15eb4SIngo Müller
224*99c15eb4SIngo Müller // Update symbol table. This works with or without the previous `swap`.
225*99c15eb4SIngo Müller targetSymbolTable.remove(funcOp);
226*99c15eb4SIngo Müller targetSymbolTable.insert(collidingFuncOp);
227*99c15eb4SIngo Müller assert(targetSymbolTable.lookup(funcOp.getName()) == collidingFuncOp);
228*99c15eb4SIngo Müller
229*99c15eb4SIngo Müller // Do the actual merging.
230*99c15eb4SIngo Müller {
231*99c15eb4SIngo Müller InFlightDiagnostic diag = mergeInto(funcOp, collidingFuncOp);
232*99c15eb4SIngo Müller if (failed(diag))
233*99c15eb4SIngo Müller return diag;
234*99c15eb4SIngo Müller }
235*99c15eb4SIngo Müller }
236*99c15eb4SIngo Müller }
237*99c15eb4SIngo Müller
238*99c15eb4SIngo Müller if (failed(mlir::verify(target)))
239*99c15eb4SIngo Müller return target->emitError()
240*99c15eb4SIngo Müller << "failed to verify target op after merging symbols";
241*99c15eb4SIngo Müller
242*99c15eb4SIngo Müller LLVM_DEBUG(DBGS() << "done merging ops\n");
243*99c15eb4SIngo Müller return InFlightDiagnostic();
244*99c15eb4SIngo Müller }
245