xref: /llvm-project/mlir/test/lib/Dialect/Transform/TestTransformStateExtension.cpp (revision 7b4a7552719b7720b9c8ccb4bc04a9e6fa1ec0b6)
//===- TestTransformStateExtension.cpp ------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8*7b4a7552SMatthias Springer 
9*7b4a7552SMatthias Springer #include "TestTransformStateExtension.h"
10*7b4a7552SMatthias Springer 
11*7b4a7552SMatthias Springer using namespace mlir;
12*7b4a7552SMatthias Springer 
13*7b4a7552SMatthias Springer LogicalResult
updateMapping(Operation * previous,Operation * updated)14*7b4a7552SMatthias Springer test::TestTransformStateExtension::updateMapping(Operation *previous,
15*7b4a7552SMatthias Springer                                                  Operation *updated) {
16*7b4a7552SMatthias Springer   // Update value handles. The new ops should have at least as many results as
17*7b4a7552SMatthias Springer   // the replacement op. Fewer results are acceptable, if those results are not
18*7b4a7552SMatthias Springer   // mapped to any handle.
19*7b4a7552SMatthias Springer   for (auto r = updated->getNumResults(); r < previous->getNumResults(); ++r) {
20*7b4a7552SMatthias Springer     SmallVector<Value> handles;
21*7b4a7552SMatthias Springer     (void)getTransformState().getHandlesForPayloadValue(previous->getResult(r),
22*7b4a7552SMatthias Springer                                                         handles);
23*7b4a7552SMatthias Springer     if (!handles.empty())
24*7b4a7552SMatthias Springer       return emitError(previous->getLoc())
25*7b4a7552SMatthias Springer              << "cannot replace an op with another op producing fewer results "
26*7b4a7552SMatthias Springer                 "while tracking handles";
27*7b4a7552SMatthias Springer   }
28*7b4a7552SMatthias Springer 
29*7b4a7552SMatthias Springer   for (auto [oldValue, newValue] :
30*7b4a7552SMatthias Springer        llvm::zip(previous->getResults(), updated->getResults()))
31*7b4a7552SMatthias Springer     if (failed(replacePayloadValue(oldValue, newValue)))
32*7b4a7552SMatthias Springer       return failure();
33*7b4a7552SMatthias Springer 
34*7b4a7552SMatthias Springer   // Update op handle.
35*7b4a7552SMatthias Springer   return replacePayloadOp(previous, updated);
36*7b4a7552SMatthias Springer }
37