Lines Matching defs:axis
165 for (auto axis : axes) {
166 if (axis >= rank || axis < 0) {
168 << "0-based mesh axis index " << axis
496 for (MeshAxis axis : axesArray) {
497 if (axis < 0)
498 return emitError() << "mesh axis is expected to be non-negative";
499 if (!visitedAxes.insert(axis).second)
500 return emitError() << "mesh axis duplicated";
733 for (auto [i, axis] : llvm::enumerate(split_axes_)) {
735 MeshAxesAttr::get(mesh_.getContext(), axis.asArrayRef());
930 return emitError(loc) << "Dimension size mismatch for result axis "
947 << "Gather axis " << gatherAxis << " is out of bounds [0, "
955 for (int64_t axis = 0; axis < operandType.getRank(); ++axis) {
956 auto operandDimSize = DimensionSize(operandType.getDimSize(axis));
957 auto resultDimSize = DimensionSize(resultType.getDimSize(axis));
959 axis == gatherAxis ? deviceGroupSize * operandDimSize : operandDimSize;
961 result.getLoc(), expectedResultDimSize, resultDimSize, axis))) {
973 for (int64_t axis = 0; axis < operandType.getRank(); ++axis) {
974 if ((axis != splitAxis && axis != concatAxis) || splitAxis == concatAxis) {
976 result.getLoc(), operandType.getDimSize(axis),
977 resultType.getDimSize(axis), axis))) {
1018 for (int64_t axis = 0; axis < operandType.getRank(); ++axis) {
1019 if (axis != tensorAxis) {
1021 result.getLoc(), operandType.getDimSize(axis),
1022 resultType.getDimSize(axis), axis))) {
1037 << int64_t(deviceGroupSize) << " for tensor axis " << tensorAxis
1394 return emitError() << "Invalid shift axis " << shiftAxis