Lines Matching defs:meshAxes
870 auto meshAxes = op.getMeshAxes();
871 if (!meshAxes.empty()) {
889 ArrayRef<MeshAxis> meshAxes,
891 if (device.size() != meshAxes.size()) {
894 << device.size() << ". Expected " << meshAxes.size()
900 !ShapedType::isDynamic(meshShape[meshAxes[i]]) &&
901 meshShape[meshAxes[i]] <= device[i]) {
906 << (meshShape[meshAxes[i]] - 1) << "].";
943 ArrayRef<MeshAxis> meshAxes, ArrayRef<int64_t> meshShape) {
954 DimensionSize(collectiveProcessGroupSize(meshAxes, meshShape));
970 ArrayRef<MeshAxis> meshAxes, ArrayRef<int64_t> meshShape) {
988 DimensionSize(collectiveProcessGroupSize(meshAxes, meshShape));
1015 ArrayRef<MeshAxis> meshAxes, ArrayRef<int64_t> meshShape) {
1029 DimensionSize(collectiveProcessGroupSize(meshAxes, meshShape));
1052 ArrayRef<MeshAxis> meshAxes,
1063 DimensionSize(collectiveProcessGroupSize(meshAxes, mesh));
1109 ArrayRef<MeshAxis> meshAxes, ReductionKind reduction) {
1110 build(odsBuilder, odsState, input.getType(), mesh, meshAxes, input,
1139 Value input, MeshOp mesh, ArrayRef<MeshAxis> meshAxes,
1141 Type resultType = sliceResultType(input.getType(), mesh, meshAxes, sliceAxis);
1142 build(odsBuilder, odsState, resultType, input, mesh.getSymName(), meshAxes,
1148 ArrayRef<MeshAxis> meshAxes, int64_t sliceAxis) {
1149 build(odsBuilder, odsState, resultType, mesh, meshAxes, input,
1391 auto meshAxes = getMeshAxes();
1393 if (llvm::find(meshAxes, shiftAxis) == meshAxes.end()) {