Lines Matching defs:tensorAxis
203 for (auto [tensorAxis, innerSplitAxes] : llvm::enumerate(splitAxes)) {
224 outShape[tensorAxis] = same ? sz : ShapedType::kDynamic;
228 for (auto [tensorAxis, innerSplitAxes] : llvm::enumerate(splitAxes)) {
229 outShape[tensorAxis] = shardDimension(
230 inShape[tensorAxis],
237 for (auto [tensorAxis, innerSplitAxes] : llvm::enumerate(splitAxes)) {
238 if (!ShapedType::isDynamic(outShape[tensorAxis]) &&
242 outShape[tensorAxis] +=
246 outShape[tensorAxis] = ShapedType::kDynamic;
554 for (auto [tensorAxis, innerSplitAxes] : llvm::enumerate(getSplitAxes())) {
1014 Value operand, Value result, int64_t tensorAxis,
1019 if (axis != tensorAxis) {
1031 DimensionSize(operandType.getDimSize(tensorAxis));
1037 << int64_t(deviceGroupSize) << " for tensor axis " << tensorAxis
1044 resultType.getDimSize(tensorAxis), tensorAxis))) {