Lines Matching defs:multiReductionOp
51 LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
56 cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
62 rootOp = multiReductionOp;
65 auto src = multiReductionOp.getSource();
66 auto loc = multiReductionOp.getLoc();
67 auto srcRank = multiReductionOp.getSourceVectorType().getRank();
70 ArrayRef<int64_t> reductionDims = multiReductionOp.getReductionDims();
121 multiReductionOp.getLoc(), transposeOp.getResult(),
122 multiReductionOp.getAcc(), reductionMask, multiReductionOp.getKind());
148 LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
153 cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
159 rootOp = multiReductionOp;
162 auto srcRank = multiReductionOp.getSourceVectorType().getRank();
163 auto srcShape = multiReductionOp.getSourceVectorType().getShape();
165 multiReductionOp.getSourceVectorType().getScalableDims();
166 auto loc = multiReductionOp.getLoc();
178 SmallVector<bool> reductionMask = multiReductionOp.getReductionMask();
263 vectorShape, multiReductionOp.getSourceVectorType().getElementType(),
266 loc, castedType, multiReductionOp.getSource());
268 Value acc = multiReductionOp.getAcc();
272 multiReductionOp.getSourceVectorType().getElementType(),
279 loc, cast, acc, mask, multiReductionOp.getKind());
292 parallelShapes, multiReductionOp.getSourceVectorType().getElementType(),
309 LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
312 cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
317 auto srcRank = multiReductionOp.getSourceVectorType().getRank();
322 if (multiReductionOp.isReducedDim(1) || !multiReductionOp.isReducedDim(0))
325 auto loc = multiReductionOp.getLoc();
327 multiReductionOp.getSourceVectorType().getShape();
329 Type elementType = getElementTypeOrSelf(multiReductionOp.getDestType());
333 Value result = multiReductionOp.getAcc();
336 loc, multiReductionOp.getSource(), i);
337 result = makeArithReduction(rewriter, loc, multiReductionOp.getKind(),
341 rewriter.replaceOp(multiReductionOp, result);
352 LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
354 auto srcRank = multiReductionOp.getSourceVectorType().getRank();
358 if (multiReductionOp.isReducedDim(0) || !multiReductionOp.isReducedDim(1))
364 cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
370 rootOp = multiReductionOp;
373 auto loc = multiReductionOp.getLoc();
375 loc, multiReductionOp.getDestType(),
376 rewriter.getZeroAttr(multiReductionOp.getDestType()));
377 int outerDim = multiReductionOp.getSourceVectorType().getShape()[0];
381 loc, multiReductionOp.getSource(), ArrayRef<int64_t>{i});
383 loc, multiReductionOp.getAcc(), ArrayRef<int64_t>{i});
385 loc, multiReductionOp.getKind(), v, acc);
412 LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
414 auto srcRank = multiReductionOp.getSourceVectorType().getRank();
422 cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
430 rootOp = multiReductionOp;
433 auto loc = multiReductionOp.getLoc();
434 auto srcVectorType = multiReductionOp.getSourceVectorType();
442 assert(!llvm::isa<VectorType>(multiReductionOp.getDestType()) &&
451 loc, castedType, multiReductionOp.getSource());
453 loc, accType, multiReductionOp.getAcc());
465 loc, cast, castAcc, reductionMask, multiReductionOp.getKind());