xref: /llvm-project/mlir/lib/Dialect/Transform/IR/Utils.cpp (revision 99c15eb49ba0b607314b3bd221f0760049130d97)
1 //===- Utils.cpp - Utils related to the transform dialect -------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Transform/IR/Utils.h"
10 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
11 #include "mlir/IR/Verifier.h"
12 #include "mlir/Interfaces/FunctionInterfaces.h"
13 #include "llvm/Support/Debug.h"
14 
15 using namespace mlir;
16 
17 #define DEBUG_TYPE "transform-dialect-utils"
18 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
19 
20 /// Return whether `func1` can be merged into `func2`. For that to work
21 /// `func1` has to be a declaration (aka has to be external) and `func2`
22 /// either has to be a declaration as well, or it has to be public (otherwise,
23 /// it wouldn't be visible by `func1`).
canMergeInto(FunctionOpInterface func1,FunctionOpInterface func2)24 static bool canMergeInto(FunctionOpInterface func1, FunctionOpInterface func2) {
25   return func1.isExternal() && (func2.isPublic() || func2.isExternal());
26 }
27 
28 /// Merge `func1` into `func2`. The two ops must be inside the same parent op
29 /// and mergable according to `canMergeInto`. The function erases `func1` such
30 /// that only `func2` exists when the function returns.
mergeInto(FunctionOpInterface func1,FunctionOpInterface func2)31 static InFlightDiagnostic mergeInto(FunctionOpInterface func1,
32                                     FunctionOpInterface func2) {
33   assert(canMergeInto(func1, func2));
34   assert(func1->getParentOp() == func2->getParentOp() &&
35          "expected func1 and func2 to be in the same parent op");
36 
37   // Check that function signatures match.
38   if (func1.getFunctionType() != func2.getFunctionType()) {
39     return func1.emitError()
40            << "external definition has a mismatching signature ("
41            << func2.getFunctionType() << ")";
42   }
43 
44   // Check and merge argument attributes.
45   MLIRContext *context = func1->getContext();
46   auto *td = context->getLoadedDialect<transform::TransformDialect>();
47   StringAttr consumedName = td->getConsumedAttrName();
48   StringAttr readOnlyName = td->getReadOnlyAttrName();
49   for (unsigned i = 0, e = func1.getNumArguments(); i < e; ++i) {
50     bool isExternalConsumed = func2.getArgAttr(i, consumedName) != nullptr;
51     bool isExternalReadonly = func2.getArgAttr(i, readOnlyName) != nullptr;
52     bool isConsumed = func1.getArgAttr(i, consumedName) != nullptr;
53     bool isReadonly = func1.getArgAttr(i, readOnlyName) != nullptr;
54     if (!isExternalConsumed && !isExternalReadonly) {
55       if (isConsumed)
56         func2.setArgAttr(i, consumedName, UnitAttr::get(context));
57       else if (isReadonly)
58         func2.setArgAttr(i, readOnlyName, UnitAttr::get(context));
59       continue;
60     }
61 
62     if ((isExternalConsumed && !isConsumed) ||
63         (isExternalReadonly && !isReadonly)) {
64       return func1.emitError()
65              << "external definition has mismatching consumption "
66                 "annotations for argument #"
67              << i;
68     }
69   }
70 
71   // `func1` is the external one, so we can remove it.
72   assert(func1.isExternal());
73   func1->erase();
74 
75   return InFlightDiagnostic();
76 }
77 
78 InFlightDiagnostic
mergeSymbolsInto(Operation * target,OwningOpRef<Operation * > other)79 transform::detail::mergeSymbolsInto(Operation *target,
80                                     OwningOpRef<Operation *> other) {
81   assert(target->hasTrait<OpTrait::SymbolTable>() &&
82          "requires target to implement the 'SymbolTable' trait");
83   assert(other->hasTrait<OpTrait::SymbolTable>() &&
84          "requires target to implement the 'SymbolTable' trait");
85 
86   SymbolTable targetSymbolTable(target);
87   SymbolTable otherSymbolTable(*other);
88 
89   // Step 1:
90   //
91   // Rename private symbols in both ops in order to resolve conflicts that can
92   // be resolved that way.
93   LLVM_DEBUG(DBGS() << "renaming private symbols to resolve conflicts:\n");
94   // TODO: Do we *actually* need to test in both directions?
95   for (auto &&[symbolTable, otherSymbolTable] : llvm::zip(
96            SmallVector<SymbolTable *, 2>{&targetSymbolTable, &otherSymbolTable},
97            SmallVector<SymbolTable *, 2>{&otherSymbolTable,
98                                          &targetSymbolTable})) {
99     Operation *symbolTableOp = symbolTable->getOp();
100     for (Operation &op : symbolTableOp->getRegion(0).front()) {
101       auto symbolOp = dyn_cast<SymbolOpInterface>(op);
102       if (!symbolOp)
103         continue;
104       StringAttr name = symbolOp.getNameAttr();
105       LLVM_DEBUG(DBGS() << "  found @" << name.getValue() << "\n");
106 
107       // Check if there is a colliding op in the other module.
108       auto collidingOp =
109           cast_or_null<SymbolOpInterface>(otherSymbolTable->lookup(name));
110       if (!collidingOp)
111         continue;
112 
113       LLVM_DEBUG(DBGS() << "    collision found for @" << name.getValue());
114 
115       // Collisions are fine if both opt are functions and can be merged.
116       if (auto funcOp = dyn_cast<FunctionOpInterface>(op),
117           collidingFuncOp =
118               dyn_cast<FunctionOpInterface>(collidingOp.getOperation());
119           funcOp && collidingFuncOp) {
120         if (canMergeInto(funcOp, collidingFuncOp) ||
121             canMergeInto(collidingFuncOp, funcOp)) {
122           LLVM_DEBUG(llvm::dbgs() << " but both ops are functions and "
123                                      "will be merged\n");
124           continue;
125         }
126 
127         // If they can't be merged, proceed like any other collision.
128         LLVM_DEBUG(llvm::dbgs() << " and both ops are function definitions");
129       }
130 
131       // Collision can be resolved by renaming if one of the ops is private.
132       auto renameToUnique =
133           [&](SymbolOpInterface op, SymbolOpInterface otherOp,
134               SymbolTable &symbolTable,
135               SymbolTable &otherSymbolTable) -> InFlightDiagnostic {
136         LLVM_DEBUG(llvm::dbgs() << ", renaming\n");
137         FailureOr<StringAttr> maybeNewName =
138             symbolTable.renameToUnique(op, {&otherSymbolTable});
139         if (failed(maybeNewName)) {
140           InFlightDiagnostic diag = op->emitError("failed to rename symbol");
141           diag.attachNote(otherOp->getLoc())
142               << "attempted renaming due to collision with this op";
143           return diag;
144         }
145         LLVM_DEBUG(DBGS() << "      renamed to @" << maybeNewName->getValue()
146                           << "\n");
147         return InFlightDiagnostic();
148       };
149 
150       if (symbolOp.isPrivate()) {
151         InFlightDiagnostic diag = renameToUnique(
152             symbolOp, collidingOp, *symbolTable, *otherSymbolTable);
153         if (failed(diag))
154           return diag;
155         continue;
156       }
157       if (collidingOp.isPrivate()) {
158         InFlightDiagnostic diag = renameToUnique(
159             collidingOp, symbolOp, *otherSymbolTable, *symbolTable);
160         if (failed(diag))
161           return diag;
162         continue;
163       }
164       LLVM_DEBUG(llvm::dbgs() << ", emitting error\n");
165       InFlightDiagnostic diag = symbolOp.emitError()
166                                 << "doubly defined symbol @" << name.getValue();
167       diag.attachNote(collidingOp->getLoc()) << "previously defined here";
168       return diag;
169     }
170   }
171 
172   // TODO: This duplicates pass infrastructure. We should split this pass into
173   //       several and let the pass infrastructure do the verification.
174   for (auto *op : SmallVector<Operation *>{target, *other}) {
175     if (failed(mlir::verify(op)))
176       return op->emitError() << "failed to verify input op after renaming";
177   }
178 
179   // Step 2:
180   //
181   // Move all ops from `other` into target and merge public symbols.
182   LLVM_DEBUG(DBGS() << "moving all symbols into target\n");
183   {
184     SmallVector<SymbolOpInterface> opsToMove;
185     for (Operation &op : other->getRegion(0).front()) {
186       if (auto symbol = dyn_cast<SymbolOpInterface>(op))
187         opsToMove.push_back(symbol);
188     }
189 
190     for (SymbolOpInterface op : opsToMove) {
191       // Remember potentially colliding op in the target module.
192       auto collidingOp = cast_or_null<SymbolOpInterface>(
193           targetSymbolTable.lookup(op.getNameAttr()));
194 
195       // Move op even if we get a collision.
196       LLVM_DEBUG(DBGS() << "  moving @" << op.getName());
197       op->moveBefore(&target->getRegion(0).front(),
198                      target->getRegion(0).front().end());
199 
200       // If there is no collision, we are done.
201       if (!collidingOp) {
202         LLVM_DEBUG(llvm::dbgs() << " without collision\n");
203         continue;
204       }
205 
206       // The two colliding ops must both be functions because we have already
207       // emitted errors otherwise earlier.
208       auto funcOp = cast<FunctionOpInterface>(op.getOperation());
209       auto collidingFuncOp =
210           cast<FunctionOpInterface>(collidingOp.getOperation());
211 
212       // Both ops are in the target module now and can be treated
213       // symmetrically, so w.l.o.g. we can reduce to merging `funcOp` into
214       // `collidingFuncOp`.
215       if (!canMergeInto(funcOp, collidingFuncOp)) {
216         std::swap(funcOp, collidingFuncOp);
217       }
218       assert(canMergeInto(funcOp, collidingFuncOp));
219 
220       LLVM_DEBUG(llvm::dbgs() << " with collision, trying to keep op at "
221                               << collidingFuncOp.getLoc() << ":\n"
222                               << collidingFuncOp << "\n");
223 
224       // Update symbol table. This works with or without the previous `swap`.
225       targetSymbolTable.remove(funcOp);
226       targetSymbolTable.insert(collidingFuncOp);
227       assert(targetSymbolTable.lookup(funcOp.getName()) == collidingFuncOp);
228 
229       // Do the actual merging.
230       {
231         InFlightDiagnostic diag = mergeInto(funcOp, collidingFuncOp);
232         if (failed(diag))
233           return diag;
234       }
235     }
236   }
237 
238   if (failed(mlir::verify(target)))
239     return target->emitError()
240            << "failed to verify target op after merging symbols";
241 
242   LLVM_DEBUG(DBGS() << "done merging ops\n");
243   return InFlightDiagnostic();
244 }
245