1 //===- TestTransformStateExtension.h - Test Utility -------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file defines an TransformState extension for the purpose of testing the 10 // relevant APIs. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #ifndef MLIR_TEST_LIB_DIALECT_TRANSFORM_TESTTRANSFORMSTATEEXTENSION_H 15 #define MLIR_TEST_LIB_DIALECT_TRANSFORM_TESTTRANSFORMSTATEEXTENSION_H 16 17 #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" 18 19 using namespace mlir; 20 21 namespace mlir { 22 namespace test { 23 class TestTransformStateExtension 24 : public transform::TransformState::Extension { 25 public: 26 TestTransformStateExtension(transform::TransformState &state, 27 StringAttr message) 28 : Extension(state), message(message) {} 29 30 StringRef getMessage() const { return message.getValue(); } 31 32 LogicalResult updateMapping(Operation *previous, Operation *updated); 33 34 private: 35 StringAttr message; 36 }; 37 38 class TransformStateInitializerExtension 39 : public transform::TransformState::Extension { 40 public: 41 TransformStateInitializerExtension(transform::TransformState &state, 42 int numOp, 43 SmallVector<std::string> ®isteredOps) 44 : Extension(state), numOp(numOp), registeredOps(registeredOps) {} 45 46 int getNumOp() { return numOp; } 47 void setNumOp(int num) { numOp = num; } 48 SmallVector<std::string> getRegisteredOps() { return registeredOps; } 49 void pushRegisteredOps(const std::string &newOp) { 50 registeredOps.push_back(newOp); 51 } 52 std::string printMessage() const { 53 std::string message = "Registered transformOps are: "; 54 for (const auto &op : registeredOps) { 55 message += op + " | "; 56 } 57 return message; 58 } 59 60 private: 61 int numOp; 62 SmallVector<std::string> registeredOps; 63 }; 64 65 } // namespace test 66 } // namespace mlir 67 68 #endif // MLIR_TEST_LIB_DIALECT_TRANSFORM_TESTTRANSFORMSTATEEXTENSION_H 69