Lines Matching defs:linalgOp

524 ///  The reshape can be folded into the `linalgOp` if its loop dimensionality
526 /// The indexing_map of the fused tensor in the `linalgOp` and the
563 static bool isFusableWithReshapeByDimExpansion(LinalgOp linalgOp,
571 linalgOp.getIteratorTypesArray();
572 AffineMap operandMap = linalgOp.getMatchingIndexingMap(fusableOpOperand);
573 return linalgOp.hasPureTensorSemantics() &&
574 llvm::all_of(linalgOp.getIndexingMaps().getValue(),
596 LogicalResult compute(LinalgOp linalgOp, OpOperand *fusableOpOperand,
624 LogicalResult ExpansionInfo::compute(LinalgOp linalgOp,
632 AffineMap fusedIndexMap = linalgOp.getMatchingIndexingMap(fusableOpOperand);
634 SmallVector<int64_t, 4> originalLoopRange = linalgOp.getStaticLoopRanges();
676 static LogicalResult isLinalgOpExpandable(LinalgOp linalgOp,
679 if (!linalgOp.hasIndexSemantics())
688 linalgOp, "cannot expand due to index semantics and dynamic dims");
796 validateDynamicDimExpansion(LinalgOp linalgOp,
809 linalgOp, "cannot infer expanded shape with multiple dynamic "
822 fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp,
825 assert(isFusableWithReshapeByDimExpansion(linalgOp, fusableOpOperand) &&
828 Location loc = linalgOp.getLoc();
842 linalgOp, fusableOpOperand,
850 if (failed(validateDynamicDimExpansion(linalgOp, expansionInfo, rewriter)))
853 if (failed(isLinalgOpExpandable(linalgOp, expansionInfo, rewriter)))
857 llvm::map_range(linalgOp.getIndexingMapsArray(), [&](AffineMap m) {
863 rewriter.setInsertionPoint(linalgOp);
866 expandedOpOperands.reserve(linalgOp.getNumDpsInputs());
867 for (OpOperand *opOperand : linalgOp.getDpsInputOperands()) {
875 AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand);
884 return rewriter.notifyMatchFailure(linalgOp, msg);
899 for (OpOperand &opOperand : linalgOp.getDpsInitsMutable()) {
900 AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
909 return rewriter.notifyMatchFailure(linalgOp, msg);
925 for (auto [i, type] : llvm::enumerate(linalgOp.getIteratorTypesArray()))
931 rewriter.create<GenericOp>(linalgOp.getLoc(), resultTypes,
935 Region &originalRegion = linalgOp->getRegion(0);
944 for (OpResult opResult : linalgOp->getOpResults()) {
949 linalgOp.getMatchingIndexingMap(
950 linalgOp.getDpsInitOperand(resultNumber)),
953 linalgOp.getLoc(), opResult.getType(),
977 LogicalResult matchAndRewrite(LinalgOp linalgOp,
979 for (OpOperand *opOperand : linalgOp.getDpsInputOperands()) {
987 if (!isFusableWithReshapeByDimExpansion(linalgOp, opOperand) ||
992 fuseWithReshapeByExpansion(linalgOp, reshapeOp, opOperand, rewriter);
995 rewriter.replaceOp(linalgOp, *replacementValues);