//===- MeshBase.td - Mesh Dialect --------------------------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Base ODS definitions for the Mesh dialect:
//   - the dialect record itself,
//   - constraint/attribute aliases for mesh axes and shard shapes,
//   - the ReductionKind enum and its attribute wrapper,
//   - the Mesh_Type base class and the `sharding` type,
//   - the `axisarray` (MeshAxesArray) attribute.
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_MESH_IR_MESHBASE_TD
#define MLIR_DIALECT_MESH_IR_MESHBASE_TD

include "mlir/IR/OpBase.td"
include "mlir/IR/AttrTypeBase.td"
include "mlir/IR/BuiltinTypeInterfaces.td"
include "mlir/IR/CommonAttrConstraints.td"
include "mlir/IR/EnumAttr.td"

//===----------------------------------------------------------------------===//
// Mesh Dialect
//===----------------------------------------------------------------------===//

// The `mesh` dialect. Generated C++ for this dialect is placed in the
// ::mlir::mesh namespace.
def Mesh_Dialect : Dialect {
  let name = "mesh";
  let cppNamespace = "::mlir::mesh";

  let description = [{
    See [Mesh dialect documentation](mlir/docs/Dialects/Mesh.md).
  }];

  // Dialects loaded together with `mesh`. arith is listed because the
  // constant materializer (enabled by hasConstantMaterializer below) needs it.
  let dependentDialects = [
    "arith::ArithDialect" // For materializeConstant()
  ];

  // Use the ODS-generated attribute/type parse and print dispatch, keyed on
  // each attribute's/type's mnemonic.
  let useDefaultAttributePrinterParser = 1;
  let useDefaultTypePrinterParser = 1;
  // Declare a materializeConstant() hook on the dialect class; its definition
  // lives in C++ outside this file.
  let hasConstantMaterializer = 1;
}

// A single mesh axis index: a 16-bit signless integer type constraint.
def Mesh_MeshAxis : I<16>;
// A list of mesh axes stored as a dense i16 array attribute (int16_t
// elements), matching the 16-bit width of Mesh_MeshAxis.
def Mesh_MeshAxesAttr : DenseArrayAttrBase<"DenseI16ArrayAttr", "int16_t", "i16">;
// A shard shape stored as a dense i64 array attribute (int64_t elements).
def Mesh_ShardShapeAttr : DenseArrayAttrBase<"DenseI64ArrayAttr", "int64_t", "i64">;

//===----------------------------------------------------------------------===//
// Mesh Enums.
//===----------------------------------------------------------------------===//

// The kinds of reduction that can be applied over an iterator or mesh
// dimension. Each case is listed as (C++ enumerator, integer value, assembly
// keyword).
// NOTE(review): `generic` uses value 100 rather than continuing at 9. This
// presumably marks a reduction outside the named builtin kinds; confirm
// against the ops that consume it.
def Mesh_ReductionKind : I32EnumAttr<"ReductionKind",
  "Reduction of an iterator/mesh dimension.", [
  I32EnumAttrCase<"Sum", 1, "sum">,
  I32EnumAttrCase<"Max", 2, "max">,
  I32EnumAttrCase<"Min", 3, "min">,
  I32EnumAttrCase<"Product", 4, "product">,
  // Arithmetic mean.
  I32EnumAttrCase<"Average", 5, "average">,
  I32EnumAttrCase<"BitwiseAnd", 6, "bitwise_and">,
  I32EnumAttrCase<"BitwiseOr", 7, "bitwise_or">,
  I32EnumAttrCase<"BitwiseXor", 8, "bitwise_xor">,
  I32EnumAttrCase<"Generic", 100, "generic">
]> {
  // Only the C++ enum and its helpers are generated here. The attribute
  // wrapper is Mesh_ReductionKindAttr below, not an auto-generated
  // specialized attribute class.
  let genSpecializedAttr = 0;
  let cppNamespace = "::mlir::mesh";
}

// Dialect attribute wrapping a ReductionKind value, registered under the
// `partial` mnemonic. Its assembly format is the bare enum keyword
// (e.g. `sum`), without the default angle brackets.
def Mesh_ReductionKindAttr : EnumAttr<Mesh_Dialect, Mesh_ReductionKind, "partial"> {
  let assemblyFormat = "$value";
}

// Base class for all Mesh dialect types.
//   name          - the C++ class name passed to TypeDef.
//   typeMnemonic  - keyword used by the default dialect type parser/printer.
//   traits        - traits attached to the type.
//   baseCppClass  - C++ base class of the generated type (::mlir::Type by
//                   default).
class Mesh_Type<string name, string typeMnemonic, list<Trait> traits = [],
                string baseCppClass = "::mlir::Type">
    : TypeDef<Mesh_Dialect, name, traits, baseCppClass> {
  let mnemonic = typeMnemonic;
}

// The `sharding` type. It declares no parameters, so its (empty) assembly
// format prints nothing after the mnemonic.
def Mesh_Sharding : Mesh_Type<"Sharding", "sharding"> {
  let summary = "sharding definition";
  let assemblyFormat = "";
}

//===----------------------------------------------------------------------===//
// Mesh Attribute
//===----------------------------------------------------------------------===//

// An array of MeshAxesAttr values (i.e. a list of axis lists), written as
// `[ ... ]` under the `axisarray` mnemonic.
def Mesh_MeshAxesArrayAttr : AttrDef<Mesh_Dialect, "MeshAxesArray"> {
  let mnemonic = "axisarray";
  let parameters = (ins ArrayRefParameter<"MeshAxesAttr">:$axes);
  let assemblyFormat = "`[` $axes `]`";
  // Forward size/begin/end to the `axes` ArrayRef so the attribute can be
  // used directly as a range (e.g. in a range-based for loop).
  let extraClassDeclaration = [{
    size_t size() const { return getAxes().size(); }
    auto begin() const { return getAxes().begin(); }
    auto end() const { return getAxes().end(); }
  }];
}

#endif // MLIR_DIALECT_MESH_IR_MESHBASE_TD