Searched defs:meshShape (Results 1 – 3 of 3) sorted by relevance
/llvm-project/mlir/lib/Dialect/Mesh/Transforms/ |
H A D | Transforms.cpp | 53 ValueRange meshShape = builder.create<MeshShapeOp>(mesh).getResults(); in matchAndRewrite() local 203 Operation::result_range meshShape = in createCollectiveProcessGroupSize() local
|
/llvm-project/mlir/include/mlir/Dialect/Mesh/IR/ |
H A D | MeshOps.h | 91 collectiveProcessGroupSize(MeshAxesRange && meshAxes,MeshShapeRange && meshShape) collectiveProcessGroupSize() argument
|
/llvm-project/mlir/lib/Dialect/Mesh/IR/ |
H A D | MeshOps.cpp | 152 shardShape(const InShape & inShape,const MeshShape & meshShape,const SplitAxes & splitAxes,OutShape & outShape) shardShape() argument 511 verifyInGroupDevice(Location loc,StringRef deviceName,ArrayRef<int64_t> device,Operation::operand_range deviceDynamic,ArrayRef<MeshAxis> meshAxes,ArrayRef<int64_t> meshShape) verifyInGroupDevice() argument 578 verifyGatherOperandAndResultShape(Value operand,Value result,int64_t gatherAxis,ArrayRef<MeshAxis> meshAxes,ArrayRef<int64_t> meshShape) verifyGatherOperandAndResultShape() argument 605 verifyAllToAllOperandAndResultShape(Value operand,Value result,int64_t splitAxis,int64_t concatAxis,ArrayRef<MeshAxis> meshAxes,ArrayRef<int64_t> meshShape) verifyAllToAllOperandAndResultShape() argument 650 verifyScatterOrSliceOperandAndResultShape(Value operand,Value result,int64_t tensorAxis,ArrayRef<MeshAxis> meshAxes,ArrayRef<int64_t> meshShape) verifyScatterOrSliceOperandAndResultShape() argument [all...] |