xref: /llvm-project/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td (revision baabcb28983edf8f20e39b89e2b1745412073b44)
//===- MeshBase.td - Mesh Dialect --------------------------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8
9#ifndef MLIR_DIALECT_MESH_IR_MESHBASE_TD
10#define MLIR_DIALECT_MESH_IR_MESHBASE_TD
11
12include "mlir/IR/OpBase.td"
13include "mlir/IR/AttrTypeBase.td"
14include "mlir/IR/BuiltinTypeInterfaces.td"
15include "mlir/IR/CommonAttrConstraints.td"
16include "mlir/IR/EnumAttr.td"
17
18//===----------------------------------------------------------------------===//
19// Mesh Dialect
20//===----------------------------------------------------------------------===//
21
// The `mesh` dialect definition. Attributes and types use the default
// generated printer/parser, and the dialect provides a constant
// materializer hook.
def Mesh_Dialect : Dialect {
  let name = "mesh";
  let cppNamespace = "::mlir::mesh";

  let description = [{
    See [Mesh dialect documentation](mlir/docs/Dialects/Mesh.md).
  }];

  // materializeConstant() relies on the arith dialect, so it is declared
  // as a dependent dialect.
  let dependentDialects = ["arith::ArithDialect"];
  let hasConstantMaterializer = 1;

  let useDefaultAttributePrinterParser = 1;
  let useDefaultTypePrinterParser = 1;
}
38
// Mesh axis indices are 16-bit integers. A list of axes is a dense i16
// array (int16_t elements); a shard shape is a dense i64 array (int64_t).
def Mesh_MeshAxis : I<16>;
def Mesh_MeshAxesAttr
    : DenseArrayAttrBase<"DenseI16ArrayAttr", "int16_t", "i16">;
def Mesh_ShardShapeAttr
    : DenseArrayAttrBase<"DenseI64ArrayAttr", "int64_t", "i64">;
42
43//===----------------------------------------------------------------------===//
44// Mesh Enums.
45//===----------------------------------------------------------------------===//
46
// Kinds of reduction over an iterator or mesh dimension. The numbered
// cases start at 1; `Generic` uses the value 100, apart from the rest.
def Mesh_ReductionKind : I32EnumAttr<
    "ReductionKind", "Reduction of an iterator/mesh dimension.",
    [
      I32EnumAttrCase<"Sum",        1,   "sum">,
      I32EnumAttrCase<"Max",        2,   "max">,
      I32EnumAttrCase<"Min",        3,   "min">,
      I32EnumAttrCase<"Product",    4,   "product">,
      // `average` is the arithmetic mean.
      I32EnumAttrCase<"Average",    5,   "average">,
      I32EnumAttrCase<"BitwiseAnd", 6,   "bitwise_and">,
      I32EnumAttrCase<"BitwiseOr",  7,   "bitwise_or">,
      I32EnumAttrCase<"BitwiseXor", 8,   "bitwise_xor">,
      I32EnumAttrCase<"Generic",    100, "generic">
    ]> {
  let cppNamespace = "::mlir::mesh";
  // Only the C++ enum is generated from this record; the dialect attribute
  // wrapping it is declared separately via EnumAttr.
  let genSpecializedAttr = 0;
}
63
// Mesh-dialect attribute wrapping a ReductionKind, with mnemonic `partial`.
// Its assembly format prints only the `$value` parameter.
def Mesh_ReductionKindAttr
    : EnumAttr<Mesh_Dialect, Mesh_ReductionKind, "partial"> {
  let assemblyFormat = "$value";
}
67
// Base class for mesh-dialect types. `name`, `traits` and `baseCppClass`
// are forwarded to TypeDef; `typeMnemonic` becomes the type's mnemonic.
class Mesh_Type<string name, string typeMnemonic,
                list<Trait> traits = [],
                string baseCppClass = "::mlir::Type">
    : TypeDef<Mesh_Dialect, name, traits, baseCppClass> {
  let mnemonic = typeMnemonic;
}
73
// `!mesh.sharding`: a type with no parameters and an empty assembly
// format, so only its mnemonic appears in the textual IR.
def Mesh_Sharding : Mesh_Type<"Sharding", "sharding"> {
  let assemblyFormat = "";
  let summary = "sharding definition";
}
78
79//===----------------------------------------------------------------------===//
80// Mesh Attribute
81//===----------------------------------------------------------------------===//
82
// `#mesh.axisarray<[...]>`-style attribute holding a list of mesh-axis lists.
// Each element is a MeshAxesAttr (a dense i16 array of axis indices). The
// C++ class exposes a small container-like API so the lists can be counted,
// indexed and iterated without going through getAxes() explicitly.
def Mesh_MeshAxesArrayAttr : AttrDef<Mesh_Dialect, "MeshAxesArray"> {
  let mnemonic = "axisarray";
  let summary = "array of mesh axis lists";
  let parameters = (ins ArrayRefParameter<"MeshAxesAttr">:$axes);
  let assemblyFormat = "`[` $axes `]`";
  let extraClassDeclaration = [{
    // Number of axis lists held by this attribute.
    size_t size() const { return getAxes().size(); }
    // True when the attribute holds no axis lists.
    bool empty() const { return getAxes().empty(); }
    // Iteration over the held axis lists.
    auto begin() const { return getAxes().begin(); }
    auto end() const { return getAxes().end(); }
    // Indexed access to the i-th axis list; `i` must be less than size().
    MeshAxesAttr operator[](size_t i) const { return getAxes()[i]; }
  }];
}
93
94#endif // MLIR_DIALECT_MESH_IR_MESHBASE_TD
95