xref: /llvm-project/mlir/test/lib/Dialect/Transform/TestTransformStateExtension.h (revision 6634d44e5e6079e19efe54c2de35e2e63108b085)
1 //===- TestTransformStateExtension.h - Test Utility -------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file defines a TransformState extension for the purpose of testing the
// relevant APIs.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #ifndef MLIR_TEST_LIB_DIALECT_TRANSFORM_TESTTRANSFORMSTATEEXTENSION_H
15 #define MLIR_TEST_LIB_DIALECT_TRANSFORM_TESTTRANSFORMSTATEEXTENSION_H
16 
17 #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
18 
19 using namespace mlir;
20 
21 namespace mlir {
22 namespace test {
23 class TestTransformStateExtension
24     : public transform::TransformState::Extension {
25 public:
26   TestTransformStateExtension(transform::TransformState &state,
27                               StringAttr message)
28       : Extension(state), message(message) {}
29 
30   StringRef getMessage() const { return message.getValue(); }
31 
32   LogicalResult updateMapping(Operation *previous, Operation *updated);
33 
34 private:
35   StringAttr message;
36 };
37 
38 class TransformStateInitializerExtension
39     : public transform::TransformState::Extension {
40 public:
41   TransformStateInitializerExtension(transform::TransformState &state,
42                                      int numOp,
43                                      SmallVector<std::string> &registeredOps)
44       : Extension(state), numOp(numOp), registeredOps(registeredOps) {}
45 
46   int getNumOp() { return numOp; }
47   void setNumOp(int num) { numOp = num; }
48   SmallVector<std::string> getRegisteredOps() { return registeredOps; }
49   void pushRegisteredOps(const std::string &newOp) {
50     registeredOps.push_back(newOp);
51   }
52   std::string printMessage() const {
53     std::string message = "Registered transformOps are: ";
54     for (const auto &op : registeredOps) {
55       message += op + " | ";
56     }
57     return message;
58   }
59 
60 private:
61   int numOp;
62   SmallVector<std::string> registeredOps;
63 };
64 
65 } // namespace test
66 } // namespace mlir
67 
68 #endif // MLIR_TEST_LIB_DIALECT_TRANSFORM_TESTTRANSFORMSTATEEXTENSION_H
69