xref: /llvm-project/mlir/lib/Dialect/Mesh/Transforms/TransformsDetail.h (revision dc3258c617420e83caff63c93d548e0923b10791)
1 //===- TransformsDetail.h - -------------------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_DIALECT_MESH_TRANSFORMS_TRANSFORMSDETAIL_H
10 #define MLIR_DIALECT_MESH_TRANSFORMS_TRANSFORMSDETAIL_H
11 
12 #include "mlir/IR/PatternMatch.h"
13 #include "mlir/IR/SymbolTable.h"
14 
15 namespace mlir {
16 namespace mesh {
17 
18 template <typename Op>
19 struct OpRewritePatternWithSymbolTableCollection : OpRewritePattern<Op> {
20   template <typename... OpRewritePatternArgs>
OpRewritePatternWithSymbolTableCollectionOpRewritePatternWithSymbolTableCollection21   OpRewritePatternWithSymbolTableCollection(
22       SymbolTableCollection &symbolTableCollection,
23       OpRewritePatternArgs &&...opRewritePatternArgs)
24       : OpRewritePattern<Op>(
25             std::forward<OpRewritePatternArgs...>(opRewritePatternArgs)...),
26         symbolTableCollection(symbolTableCollection) {}
27 
28 protected:
29   SymbolTableCollection &symbolTableCollection;
30 };
31 
32 } // namespace mesh
33 } // namespace mlir
34 
35 #endif // MLIR_DIALECT_MESH_TRANSFORMS_TRANSFORMSDETAIL_H
36