Lines Matching defs:tensorAxis
155 for (size_t tensorAxis = 0; tensorAxis < targetSharding.getSplitAxes().size();
156 ++tensorAxis) {
157 if (sourceSharding.getSplitAxes().size() > tensorAxis) {
158 if (sourceSharding.getSplitAxes()[tensorAxis].size() + 1 !=
159 targetSharding.getSplitAxes()[tensorAxis].size()) {
163 sourceSharding.getSplitAxes()[tensorAxis].asArrayRef(),
165 targetSharding.getSplitAxes()[tensorAxis]
168 targetSharding.getSplitAxes()[tensorAxis].asArrayRef().end() -
173 if (targetSharding.getSplitAxes()[tensorAxis].size() != 1) {
178 tensorAxis,
179 targetSharding.getSplitAxes()[tensorAxis].asArrayRef().back());
191 auto [tensorAxis, meshAxis] = detectRes.value();
193 tensorAxis, meshAxis);
205 for (size_t tensorAxis = 0; tensorAxis < sourceSharding.getSplitAxes().size();
206 ++tensorAxis) {
207 if (targetSharding.getSplitAxes().size() > tensorAxis) {
208 if (sourceSharding.getSplitAxes()[tensorAxis].size() !=
209 targetSharding.getSplitAxes()[tensorAxis].size() + 1)
213 sourceSharding.getSplitAxes()[tensorAxis]
216 sourceSharding.getSplitAxes()[tensorAxis].asArrayRef().end() -
218 targetSharding.getSplitAxes()[tensorAxis].asArrayRef()))
221 if (sourceSharding.getSplitAxes()[tensorAxis].size() != 1)
225 tensorAxis,
226 sourceSharding.getSplitAxes()[tensorAxis].asArrayRef().back());
290 auto [tensorAxis, meshAxis] = detectRes.value();
293 tensorAxis, meshAxis);