Lines Matching defs:packOp

4363 static bool hasSameInnerOuterAttribute(PackOp packOp, UnPackOp unPackOp) {
4364 if (packOp.getInnerDimsPos() != unPackOp.getInnerDimsPos())
4366 if (packOp.getOuterDimsPerm() == unPackOp.getOuterDimsPerm())
4371 return isIdentityPermutation(packOp.getOuterDimsPerm()) &&
4377 static bool haveSameTiles(PackOp packOp, UnPackOp unPackOp) {
4378 auto packTiles = packOp.getMixedTiles();
4403 /// `packOp` and populates each with the inferred static shape.
4404 static bool inferStaticShape(PackOp packOp, SmallVectorImpl<int64_t> &srcShape,
4407 srcShape.assign(packOp.getSourceType().getShape().begin(),
4408 packOp.getSourceType().getShape().end());
4409 destShape.assign(packOp.getDestType().getShape().begin(),
4410 packOp.getDestType().getShape().end());
4412 innerDims.insert(packOp.getInnerDimsPos().begin(),
4413 packOp.getInnerDimsPos().end());
4415 if (!packOp.getOuterDimsPerm().empty())
4416 inverseOuterDimsPerm = invertPermutationVector(packOp.getOuterDimsPerm());
4417 int srcRank = packOp.getSourceRank();
4439 LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
4441 if (auto unPackOp = packOp.getSource().getDefiningOp<UnPackOp>()) {
4442 if (unPackOp.getSourceType() != packOp.getDestType())
4444 if (packOp.getPaddingValue() ||
4445 !hasSameInnerOuterAttribute(packOp, unPackOp) ||
4446 !haveSameTiles(packOp, unPackOp))
4448 rewriter.replaceOp(packOp, unPackOp.getSource());
4453 if (packOp.getPaddingValue() && paddingIsNotNeeded(packOp)) {
4454 rewriter.startOpModification(packOp);
4455 packOp.getPaddingValueMutable().clear();
4456 rewriter.finalizeOpModification(packOp);
4462 if (inferStaticShape(packOp, srcShape, destShape)) {
4463 Location loc = packOp.getLoc();
4464 Value source = packOp.getSource();
4465 if (srcShape != packOp.getSourceType().getShape()) {
4466 auto newSrcType = packOp.getSourceType().clone(srcShape);
4468 rewriter.create<tensor::CastOp>(loc, newSrcType, packOp.getSource());
4470 Value dest = packOp.getDest();
4471 RankedTensorType originalResultType = packOp.getDestType();
4474 auto newDestType = packOp.getDestType().clone(destShape);
4476 rewriter.create<tensor::CastOp>(loc, newDestType, packOp.getDest());
4478 rewriter.modifyOpInPlace(packOp, [&] {
4479 packOp.getSourceMutable().assign(source);
4480 packOp.getDestMutable().assign(dest);
4481 packOp.getResult().setType(cast<RankedTensorType>(dest.getType()));
4485 rewriter.setInsertionPointAfter(packOp);
4487 rewriter.create<tensor::CastOp>(loc, originalResultType, packOp);
4488 rewriter.replaceAllUsesExcept(packOp, castOp, castOp);
4497 static bool isLikePadUnPad(PackOrUnpackOp packOp,
4505 ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
4701 if (PackOp packOp = unPackOp.getSource().getDefiningOp<tensor::PackOp>()) {
4702 if (packOp.getSourceType() != unPackOp.getDestType())
4704 if (packOp.getPaddingValue() ||
4705 !hasSameInnerOuterAttribute(packOp, unPackOp) ||
4706 !haveSameTiles(packOp, unPackOp))
4708 rewriter.replaceOp(unPackOp, packOp.getSource());