1*7b4a7552SMatthias Springer //===- TestTransformStateExtension.cpp ------------------------------------===// 2*7b4a7552SMatthias Springer // 3*7b4a7552SMatthias Springer // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4*7b4a7552SMatthias Springer // See https://llvm.org/LICENSE.txt for license information. 5*7b4a7552SMatthias Springer // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6*7b4a7552SMatthias Springer // 7*7b4a7552SMatthias Springer //===----------------------------------------------------------------------===// 8*7b4a7552SMatthias Springer 9*7b4a7552SMatthias Springer #include "TestTransformStateExtension.h" 10*7b4a7552SMatthias Springer 11*7b4a7552SMatthias Springer using namespace mlir; 12*7b4a7552SMatthias Springer 13*7b4a7552SMatthias Springer LogicalResult updateMapping(Operation * previous,Operation * updated)14*7b4a7552SMatthias Springertest::TestTransformStateExtension::updateMapping(Operation *previous, 15*7b4a7552SMatthias Springer Operation *updated) { 16*7b4a7552SMatthias Springer // Update value handles. The new ops should have at least as many results as 17*7b4a7552SMatthias Springer // the replacement op. Fewer results are acceptable, if those results are not 18*7b4a7552SMatthias Springer // mapped to any handle. 19*7b4a7552SMatthias Springer for (auto r = updated->getNumResults(); r < previous->getNumResults(); ++r) { 20*7b4a7552SMatthias Springer SmallVector<Value> handles; 21*7b4a7552SMatthias Springer (void)getTransformState().getHandlesForPayloadValue(previous->getResult(r), 22*7b4a7552SMatthias Springer handles); 23*7b4a7552SMatthias Springer if (!handles.empty()) 24*7b4a7552SMatthias Springer return emitError(previous->getLoc()) 25*7b4a7552SMatthias Springer << "cannot replace an op with another op producing fewer results " 26*7b4a7552SMatthias Springer "while tracking handles"; 27*7b4a7552SMatthias Springer } 28*7b4a7552SMatthias Springer 29*7b4a7552SMatthias Springer for (auto [oldValue, newValue] : 30*7b4a7552SMatthias Springer llvm::zip(previous->getResults(), updated->getResults())) 31*7b4a7552SMatthias Springer if (failed(replacePayloadValue(oldValue, newValue))) 32*7b4a7552SMatthias Springer return failure(); 33*7b4a7552SMatthias Springer 34*7b4a7552SMatthias Springer // Update op handle. 35*7b4a7552SMatthias Springer return replacePayloadOp(previous, updated); 36*7b4a7552SMatthias Springer } 37