xref: /llvm-project/mlir/include/mlir/Dialect/Mesh/Transforms/Transforms.h (revision fb582b6ace781ff6991775d6dcd4df98aa16698f)
//===- Transforms.h - Mesh Transforms ---------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_DIALECT_MESH_TRANSFORMS_TRANSFORMS_H
10 #define MLIR_DIALECT_MESH_TRANSFORMS_TRANSFORMS_H
11 
12 #include "mlir/Dialect/Mesh/IR/MeshOps.h"
13 #include "mlir/IR/BuiltinTypes.h"
14 #include "mlir/IR/Value.h"
15 #include "mlir/Support/LLVM.h"
16 #include "llvm/ADT/ArrayRef.h"
17 
18 namespace mlir {
19 class RewritePatternSet;
20 class SymbolTableCollection;
21 class DialectRegistry;
22 class ImplicitLocOpBuilder;
23 namespace mesh {
24 
25 void populateProcessMultiIndexOpLoweringPatterns(
26     RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection);
27 void registerProcessMultiIndexOpLoweringDialects(DialectRegistry &registry);
28 
29 void populateAllSliceOpLoweringPatterns(
30     RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection);
31 void registerAllSliceOpLoweringDialects(DialectRegistry &registry);
32 
33 void populateAllOpLoweringPatterns(
34     RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection);
35 void registerAllOpLoweringDialects(DialectRegistry &registry);
36 
37 TypedValue<IndexType>
38 createCollectiveProcessGroupSize(MeshOp mesh, ArrayRef<MeshAxis> axes,
39                                  ImplicitLocOpBuilder &builder);
40 
41 // Get process linear index along the given mesh axes.
42 TypedValue<IndexType> createProcessLinearIndex(StringRef mesh,
43                                                ArrayRef<MeshAxis> meshAxes,
44                                                ImplicitLocOpBuilder &builder);
45 
46 } // namespace mesh
47 } // namespace mlir
48 
49 #endif // MLIR_DIALECT_MESH_TRANSFORMS_TRANSFORMS_H
50