xref: /llvm-project/mlir/lib/Dialect/Transform/IR/Utils.cpp (revision 99c15eb49ba0b607314b3bd221f0760049130d97)
1*99c15eb4SIngo Müller //===- Utils.cpp - Utils related to the transform dialect -------*- C++ -*-===//
2*99c15eb4SIngo Müller //
3*99c15eb4SIngo Müller // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4*99c15eb4SIngo Müller // See https://llvm.org/LICENSE.txt for license information.
5*99c15eb4SIngo Müller // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6*99c15eb4SIngo Müller //
7*99c15eb4SIngo Müller //===----------------------------------------------------------------------===//
8*99c15eb4SIngo Müller 
9*99c15eb4SIngo Müller #include "mlir/Dialect/Transform/IR/Utils.h"
10*99c15eb4SIngo Müller #include "mlir/Dialect/Transform/IR/TransformDialect.h"
11*99c15eb4SIngo Müller #include "mlir/IR/Verifier.h"
12*99c15eb4SIngo Müller #include "mlir/Interfaces/FunctionInterfaces.h"
13*99c15eb4SIngo Müller #include "llvm/Support/Debug.h"
14*99c15eb4SIngo Müller 
15*99c15eb4SIngo Müller using namespace mlir;
16*99c15eb4SIngo Müller 
17*99c15eb4SIngo Müller #define DEBUG_TYPE "transform-dialect-utils"
18*99c15eb4SIngo Müller #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
19*99c15eb4SIngo Müller 
20*99c15eb4SIngo Müller /// Return whether `func1` can be merged into `func2`. For that to work
21*99c15eb4SIngo Müller /// `func1` has to be a declaration (aka has to be external) and `func2`
22*99c15eb4SIngo Müller /// either has to be a declaration as well, or it has to be public (otherwise,
23*99c15eb4SIngo Müller /// it wouldn't be visible by `func1`).
canMergeInto(FunctionOpInterface func1,FunctionOpInterface func2)24*99c15eb4SIngo Müller static bool canMergeInto(FunctionOpInterface func1, FunctionOpInterface func2) {
25*99c15eb4SIngo Müller   return func1.isExternal() && (func2.isPublic() || func2.isExternal());
26*99c15eb4SIngo Müller }
27*99c15eb4SIngo Müller 
28*99c15eb4SIngo Müller /// Merge `func1` into `func2`. The two ops must be inside the same parent op
29*99c15eb4SIngo Müller /// and mergable according to `canMergeInto`. The function erases `func1` such
30*99c15eb4SIngo Müller /// that only `func2` exists when the function returns.
mergeInto(FunctionOpInterface func1,FunctionOpInterface func2)31*99c15eb4SIngo Müller static InFlightDiagnostic mergeInto(FunctionOpInterface func1,
32*99c15eb4SIngo Müller                                     FunctionOpInterface func2) {
33*99c15eb4SIngo Müller   assert(canMergeInto(func1, func2));
34*99c15eb4SIngo Müller   assert(func1->getParentOp() == func2->getParentOp() &&
35*99c15eb4SIngo Müller          "expected func1 and func2 to be in the same parent op");
36*99c15eb4SIngo Müller 
37*99c15eb4SIngo Müller   // Check that function signatures match.
38*99c15eb4SIngo Müller   if (func1.getFunctionType() != func2.getFunctionType()) {
39*99c15eb4SIngo Müller     return func1.emitError()
40*99c15eb4SIngo Müller            << "external definition has a mismatching signature ("
41*99c15eb4SIngo Müller            << func2.getFunctionType() << ")";
42*99c15eb4SIngo Müller   }
43*99c15eb4SIngo Müller 
44*99c15eb4SIngo Müller   // Check and merge argument attributes.
45*99c15eb4SIngo Müller   MLIRContext *context = func1->getContext();
46*99c15eb4SIngo Müller   auto *td = context->getLoadedDialect<transform::TransformDialect>();
47*99c15eb4SIngo Müller   StringAttr consumedName = td->getConsumedAttrName();
48*99c15eb4SIngo Müller   StringAttr readOnlyName = td->getReadOnlyAttrName();
49*99c15eb4SIngo Müller   for (unsigned i = 0, e = func1.getNumArguments(); i < e; ++i) {
50*99c15eb4SIngo Müller     bool isExternalConsumed = func2.getArgAttr(i, consumedName) != nullptr;
51*99c15eb4SIngo Müller     bool isExternalReadonly = func2.getArgAttr(i, readOnlyName) != nullptr;
52*99c15eb4SIngo Müller     bool isConsumed = func1.getArgAttr(i, consumedName) != nullptr;
53*99c15eb4SIngo Müller     bool isReadonly = func1.getArgAttr(i, readOnlyName) != nullptr;
54*99c15eb4SIngo Müller     if (!isExternalConsumed && !isExternalReadonly) {
55*99c15eb4SIngo Müller       if (isConsumed)
56*99c15eb4SIngo Müller         func2.setArgAttr(i, consumedName, UnitAttr::get(context));
57*99c15eb4SIngo Müller       else if (isReadonly)
58*99c15eb4SIngo Müller         func2.setArgAttr(i, readOnlyName, UnitAttr::get(context));
59*99c15eb4SIngo Müller       continue;
60*99c15eb4SIngo Müller     }
61*99c15eb4SIngo Müller 
62*99c15eb4SIngo Müller     if ((isExternalConsumed && !isConsumed) ||
63*99c15eb4SIngo Müller         (isExternalReadonly && !isReadonly)) {
64*99c15eb4SIngo Müller       return func1.emitError()
65*99c15eb4SIngo Müller              << "external definition has mismatching consumption "
66*99c15eb4SIngo Müller                 "annotations for argument #"
67*99c15eb4SIngo Müller              << i;
68*99c15eb4SIngo Müller     }
69*99c15eb4SIngo Müller   }
70*99c15eb4SIngo Müller 
71*99c15eb4SIngo Müller   // `func1` is the external one, so we can remove it.
72*99c15eb4SIngo Müller   assert(func1.isExternal());
73*99c15eb4SIngo Müller   func1->erase();
74*99c15eb4SIngo Müller 
75*99c15eb4SIngo Müller   return InFlightDiagnostic();
76*99c15eb4SIngo Müller }
77*99c15eb4SIngo Müller 
78*99c15eb4SIngo Müller InFlightDiagnostic
mergeSymbolsInto(Operation * target,OwningOpRef<Operation * > other)79*99c15eb4SIngo Müller transform::detail::mergeSymbolsInto(Operation *target,
80*99c15eb4SIngo Müller                                     OwningOpRef<Operation *> other) {
81*99c15eb4SIngo Müller   assert(target->hasTrait<OpTrait::SymbolTable>() &&
82*99c15eb4SIngo Müller          "requires target to implement the 'SymbolTable' trait");
83*99c15eb4SIngo Müller   assert(other->hasTrait<OpTrait::SymbolTable>() &&
84*99c15eb4SIngo Müller          "requires target to implement the 'SymbolTable' trait");
85*99c15eb4SIngo Müller 
86*99c15eb4SIngo Müller   SymbolTable targetSymbolTable(target);
87*99c15eb4SIngo Müller   SymbolTable otherSymbolTable(*other);
88*99c15eb4SIngo Müller 
89*99c15eb4SIngo Müller   // Step 1:
90*99c15eb4SIngo Müller   //
91*99c15eb4SIngo Müller   // Rename private symbols in both ops in order to resolve conflicts that can
92*99c15eb4SIngo Müller   // be resolved that way.
93*99c15eb4SIngo Müller   LLVM_DEBUG(DBGS() << "renaming private symbols to resolve conflicts:\n");
94*99c15eb4SIngo Müller   // TODO: Do we *actually* need to test in both directions?
95*99c15eb4SIngo Müller   for (auto &&[symbolTable, otherSymbolTable] : llvm::zip(
96*99c15eb4SIngo Müller            SmallVector<SymbolTable *, 2>{&targetSymbolTable, &otherSymbolTable},
97*99c15eb4SIngo Müller            SmallVector<SymbolTable *, 2>{&otherSymbolTable,
98*99c15eb4SIngo Müller                                          &targetSymbolTable})) {
99*99c15eb4SIngo Müller     Operation *symbolTableOp = symbolTable->getOp();
100*99c15eb4SIngo Müller     for (Operation &op : symbolTableOp->getRegion(0).front()) {
101*99c15eb4SIngo Müller       auto symbolOp = dyn_cast<SymbolOpInterface>(op);
102*99c15eb4SIngo Müller       if (!symbolOp)
103*99c15eb4SIngo Müller         continue;
104*99c15eb4SIngo Müller       StringAttr name = symbolOp.getNameAttr();
105*99c15eb4SIngo Müller       LLVM_DEBUG(DBGS() << "  found @" << name.getValue() << "\n");
106*99c15eb4SIngo Müller 
107*99c15eb4SIngo Müller       // Check if there is a colliding op in the other module.
108*99c15eb4SIngo Müller       auto collidingOp =
109*99c15eb4SIngo Müller           cast_or_null<SymbolOpInterface>(otherSymbolTable->lookup(name));
110*99c15eb4SIngo Müller       if (!collidingOp)
111*99c15eb4SIngo Müller         continue;
112*99c15eb4SIngo Müller 
113*99c15eb4SIngo Müller       LLVM_DEBUG(DBGS() << "    collision found for @" << name.getValue());
114*99c15eb4SIngo Müller 
115*99c15eb4SIngo Müller       // Collisions are fine if both opt are functions and can be merged.
116*99c15eb4SIngo Müller       if (auto funcOp = dyn_cast<FunctionOpInterface>(op),
117*99c15eb4SIngo Müller           collidingFuncOp =
118*99c15eb4SIngo Müller               dyn_cast<FunctionOpInterface>(collidingOp.getOperation());
119*99c15eb4SIngo Müller           funcOp && collidingFuncOp) {
120*99c15eb4SIngo Müller         if (canMergeInto(funcOp, collidingFuncOp) ||
121*99c15eb4SIngo Müller             canMergeInto(collidingFuncOp, funcOp)) {
122*99c15eb4SIngo Müller           LLVM_DEBUG(llvm::dbgs() << " but both ops are functions and "
123*99c15eb4SIngo Müller                                      "will be merged\n");
124*99c15eb4SIngo Müller           continue;
125*99c15eb4SIngo Müller         }
126*99c15eb4SIngo Müller 
127*99c15eb4SIngo Müller         // If they can't be merged, proceed like any other collision.
128*99c15eb4SIngo Müller         LLVM_DEBUG(llvm::dbgs() << " and both ops are function definitions");
129*99c15eb4SIngo Müller       }
130*99c15eb4SIngo Müller 
131*99c15eb4SIngo Müller       // Collision can be resolved by renaming if one of the ops is private.
132*99c15eb4SIngo Müller       auto renameToUnique =
133*99c15eb4SIngo Müller           [&](SymbolOpInterface op, SymbolOpInterface otherOp,
134*99c15eb4SIngo Müller               SymbolTable &symbolTable,
135*99c15eb4SIngo Müller               SymbolTable &otherSymbolTable) -> InFlightDiagnostic {
136*99c15eb4SIngo Müller         LLVM_DEBUG(llvm::dbgs() << ", renaming\n");
137*99c15eb4SIngo Müller         FailureOr<StringAttr> maybeNewName =
138*99c15eb4SIngo Müller             symbolTable.renameToUnique(op, {&otherSymbolTable});
139*99c15eb4SIngo Müller         if (failed(maybeNewName)) {
140*99c15eb4SIngo Müller           InFlightDiagnostic diag = op->emitError("failed to rename symbol");
141*99c15eb4SIngo Müller           diag.attachNote(otherOp->getLoc())
142*99c15eb4SIngo Müller               << "attempted renaming due to collision with this op";
143*99c15eb4SIngo Müller           return diag;
144*99c15eb4SIngo Müller         }
145*99c15eb4SIngo Müller         LLVM_DEBUG(DBGS() << "      renamed to @" << maybeNewName->getValue()
146*99c15eb4SIngo Müller                           << "\n");
147*99c15eb4SIngo Müller         return InFlightDiagnostic();
148*99c15eb4SIngo Müller       };
149*99c15eb4SIngo Müller 
150*99c15eb4SIngo Müller       if (symbolOp.isPrivate()) {
151*99c15eb4SIngo Müller         InFlightDiagnostic diag = renameToUnique(
152*99c15eb4SIngo Müller             symbolOp, collidingOp, *symbolTable, *otherSymbolTable);
153*99c15eb4SIngo Müller         if (failed(diag))
154*99c15eb4SIngo Müller           return diag;
155*99c15eb4SIngo Müller         continue;
156*99c15eb4SIngo Müller       }
157*99c15eb4SIngo Müller       if (collidingOp.isPrivate()) {
158*99c15eb4SIngo Müller         InFlightDiagnostic diag = renameToUnique(
159*99c15eb4SIngo Müller             collidingOp, symbolOp, *otherSymbolTable, *symbolTable);
160*99c15eb4SIngo Müller         if (failed(diag))
161*99c15eb4SIngo Müller           return diag;
162*99c15eb4SIngo Müller         continue;
163*99c15eb4SIngo Müller       }
164*99c15eb4SIngo Müller       LLVM_DEBUG(llvm::dbgs() << ", emitting error\n");
165*99c15eb4SIngo Müller       InFlightDiagnostic diag = symbolOp.emitError()
166*99c15eb4SIngo Müller                                 << "doubly defined symbol @" << name.getValue();
167*99c15eb4SIngo Müller       diag.attachNote(collidingOp->getLoc()) << "previously defined here";
168*99c15eb4SIngo Müller       return diag;
169*99c15eb4SIngo Müller     }
170*99c15eb4SIngo Müller   }
171*99c15eb4SIngo Müller 
172*99c15eb4SIngo Müller   // TODO: This duplicates pass infrastructure. We should split this pass into
173*99c15eb4SIngo Müller   //       several and let the pass infrastructure do the verification.
174*99c15eb4SIngo Müller   for (auto *op : SmallVector<Operation *>{target, *other}) {
175*99c15eb4SIngo Müller     if (failed(mlir::verify(op)))
176*99c15eb4SIngo Müller       return op->emitError() << "failed to verify input op after renaming";
177*99c15eb4SIngo Müller   }
178*99c15eb4SIngo Müller 
179*99c15eb4SIngo Müller   // Step 2:
180*99c15eb4SIngo Müller   //
181*99c15eb4SIngo Müller   // Move all ops from `other` into target and merge public symbols.
182*99c15eb4SIngo Müller   LLVM_DEBUG(DBGS() << "moving all symbols into target\n");
183*99c15eb4SIngo Müller   {
184*99c15eb4SIngo Müller     SmallVector<SymbolOpInterface> opsToMove;
185*99c15eb4SIngo Müller     for (Operation &op : other->getRegion(0).front()) {
186*99c15eb4SIngo Müller       if (auto symbol = dyn_cast<SymbolOpInterface>(op))
187*99c15eb4SIngo Müller         opsToMove.push_back(symbol);
188*99c15eb4SIngo Müller     }
189*99c15eb4SIngo Müller 
190*99c15eb4SIngo Müller     for (SymbolOpInterface op : opsToMove) {
191*99c15eb4SIngo Müller       // Remember potentially colliding op in the target module.
192*99c15eb4SIngo Müller       auto collidingOp = cast_or_null<SymbolOpInterface>(
193*99c15eb4SIngo Müller           targetSymbolTable.lookup(op.getNameAttr()));
194*99c15eb4SIngo Müller 
195*99c15eb4SIngo Müller       // Move op even if we get a collision.
196*99c15eb4SIngo Müller       LLVM_DEBUG(DBGS() << "  moving @" << op.getName());
197*99c15eb4SIngo Müller       op->moveBefore(&target->getRegion(0).front(),
198*99c15eb4SIngo Müller                      target->getRegion(0).front().end());
199*99c15eb4SIngo Müller 
200*99c15eb4SIngo Müller       // If there is no collision, we are done.
201*99c15eb4SIngo Müller       if (!collidingOp) {
202*99c15eb4SIngo Müller         LLVM_DEBUG(llvm::dbgs() << " without collision\n");
203*99c15eb4SIngo Müller         continue;
204*99c15eb4SIngo Müller       }
205*99c15eb4SIngo Müller 
206*99c15eb4SIngo Müller       // The two colliding ops must both be functions because we have already
207*99c15eb4SIngo Müller       // emitted errors otherwise earlier.
208*99c15eb4SIngo Müller       auto funcOp = cast<FunctionOpInterface>(op.getOperation());
209*99c15eb4SIngo Müller       auto collidingFuncOp =
210*99c15eb4SIngo Müller           cast<FunctionOpInterface>(collidingOp.getOperation());
211*99c15eb4SIngo Müller 
212*99c15eb4SIngo Müller       // Both ops are in the target module now and can be treated
213*99c15eb4SIngo Müller       // symmetrically, so w.l.o.g. we can reduce to merging `funcOp` into
214*99c15eb4SIngo Müller       // `collidingFuncOp`.
215*99c15eb4SIngo Müller       if (!canMergeInto(funcOp, collidingFuncOp)) {
216*99c15eb4SIngo Müller         std::swap(funcOp, collidingFuncOp);
217*99c15eb4SIngo Müller       }
218*99c15eb4SIngo Müller       assert(canMergeInto(funcOp, collidingFuncOp));
219*99c15eb4SIngo Müller 
220*99c15eb4SIngo Müller       LLVM_DEBUG(llvm::dbgs() << " with collision, trying to keep op at "
221*99c15eb4SIngo Müller                               << collidingFuncOp.getLoc() << ":\n"
222*99c15eb4SIngo Müller                               << collidingFuncOp << "\n");
223*99c15eb4SIngo Müller 
224*99c15eb4SIngo Müller       // Update symbol table. This works with or without the previous `swap`.
225*99c15eb4SIngo Müller       targetSymbolTable.remove(funcOp);
226*99c15eb4SIngo Müller       targetSymbolTable.insert(collidingFuncOp);
227*99c15eb4SIngo Müller       assert(targetSymbolTable.lookup(funcOp.getName()) == collidingFuncOp);
228*99c15eb4SIngo Müller 
229*99c15eb4SIngo Müller       // Do the actual merging.
230*99c15eb4SIngo Müller       {
231*99c15eb4SIngo Müller         InFlightDiagnostic diag = mergeInto(funcOp, collidingFuncOp);
232*99c15eb4SIngo Müller         if (failed(diag))
233*99c15eb4SIngo Müller           return diag;
234*99c15eb4SIngo Müller       }
235*99c15eb4SIngo Müller     }
236*99c15eb4SIngo Müller   }
237*99c15eb4SIngo Müller 
238*99c15eb4SIngo Müller   if (failed(mlir::verify(target)))
239*99c15eb4SIngo Müller     return target->emitError()
240*99c15eb4SIngo Müller            << "failed to verify target op after merging symbols";
241*99c15eb4SIngo Müller 
242*99c15eb4SIngo Müller   LLVM_DEBUG(DBGS() << "done merging ops\n");
243*99c15eb4SIngo Müller   return InFlightDiagnostic();
244*99c15eb4SIngo Müller }
245