Lines Matching defs:xferOp

49 static bool isTensorOp(VectorTransferOpInterface xferOp) {
50 if (isa<RankedTensorType>(xferOp.getShapedType())) {
51 if (isa<vector::TransferWriteOp>(xferOp)) {
53 assert(xferOp->getNumResults() > 0);
68 LogicalResult checkLowerTensors(VectorTransferOpInterface xferOp,
70 if (isTensorOp(xferOp) && !options.lowerTensors) {
72 xferOp, "lowering tensor transfers is disabled");
84 static std::optional<int64_t> unpackedDim(OpTy xferOp) {
86 assert(xferOp.getTransferRank() > 0 && "unexpected 0-d transfer");
87 auto map = xferOp.getPermutationMap();
91 assert(xferOp.isBroadcastDim(0) &&
100 static AffineMap unpackedPermutationMap(OpBuilder &b, OpTy xferOp) {
102 assert(xferOp.getTransferRank() > 0 && "unexpected 0-d transfer");
103 auto map = xferOp.getPermutationMap();
115 static void getXferIndices(OpBuilder &b, OpTy xferOp, Value iv,
117 typename OpTy::Adaptor adaptor(xferOp);
119 auto dim = unpackedDim(xferOp);
123 Location loc = xferOp.getLoc();
127 bindDims(xferOp.getContext(), d0, d1);
144 /// Generates a boolean Value that is true if the iv-th bit in xferOp's mask
146 /// * xferOp does not have a mask.
147 /// * xferOp's mask is not 1D. (In case of (N>1)-D, a subvector of the mask is
149 /// * The to-be-unpacked dim of xferOp is a broadcast.
151 static Value generateMaskCheck(OpBuilder &b, OpTy xferOp, Value iv) {
152 if (!xferOp.getMask())
154 if (xferOp.getMaskType().getRank() != 1)
156 if (xferOp.isBroadcastDim(0))
159 Location loc = xferOp.getLoc();
160 return b.create<vector::ExtractElementOp>(loc, xferOp.getMask(), iv);
189 OpBuilder &b, OpTy xferOp, Value iv, std::optional<int64_t> dim,
198 Location loc = xferOp.getLoc();
199 ImplicitLocOpBuilder lb(xferOp.getLoc(), b);
200 if (!xferOp.isDimInBounds(0) && !isBroadcast) {
202 vector::createOrFoldDimOp(b, loc, xferOp.getSource(), *dim);
204 bindDims(xferOp.getContext(), d0, d1);
205 Value base = xferOp.getIndices()[*dim];
213 if (auto maskCond = generateMaskCheck(b, xferOp, iv)) {
248 OpBuilder &b, OpTy xferOp, Value iv, std::optional<int64_t> dim,
252 b, xferOp, iv, dim, /*resultTypes=*/TypeRange(),
300 static BufferAllocs allocBuffers(OpBuilder &b, OpTy xferOp) {
301 Location loc = xferOp.getLoc();
303 Operation *scope = getAutomaticAllocationScope(xferOp);
309 auto bufferType = MemRefType::get({}, xferOp.getVectorType());
312 if (xferOp.getMask()) {
313 auto maskType = MemRefType::get({}, xferOp.getMask().getType());
315 b.setInsertionPoint(xferOp);
316 b.create<memref::StoreOp>(loc, xferOp.getMask(), maskBuffer);
344 static Value getMaskBuffer(OpTy xferOp) {
345 assert(xferOp.getMask() && "Expected that transfer op has mask");
346 auto loadOp = xferOp.getMask().template getDefiningOp<memref::LoadOp>();
360 static memref::StoreOp getStoreOp(TransferReadOp xferOp) {
361 assert(xferOp->hasOneUse() && "Expected exactly one use of TransferReadOp");
362 auto storeOp = dyn_cast<memref::StoreOp>((*xferOp->use_begin()).getOwner());
374 static Value getBuffer(TransferReadOp xferOp) {
375 return getStoreOp(xferOp).getMemRef();
379 static void getBufferIndices(TransferReadOp xferOp,
381 auto storeOp = getStoreOp(xferOp);
415 TransferReadOp xferOp, Value buffer, Value iv,
418 getBufferIndices(xferOp, storeIndices);
422 getXferIndices(b, xferOp, iv, xferIndices);
424 Location loc = xferOp.getLoc();
427 auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr());
429 loc, vecType, xferOp.getSource(), xferIndices,
430 AffineMapAttr::get(unpackedPermutationMap(b, xferOp)),
431 xferOp.getPadding(), Value(), inBoundsAttr);
441 static Value handleOutOfBoundsDim(OpBuilder &b, TransferReadOp xferOp,
445 getBufferIndices(xferOp, storeIndices);
448 Location loc = xferOp.getLoc();
451 auto vec = b.create<vector::SplatOp>(loc, vecType, xferOp.getPadding());
458 static void cleanup(PatternRewriter &rewriter, TransferReadOp xferOp,
460 rewriter.eraseOp(getStoreOp(xferOp));
461 rewriter.eraseOp(xferOp);
465 static Value initialLoopState(TransferReadOp xferOp) { return Value(); }
478 static Value getBuffer(TransferWriteOp xferOp) {
479 auto loadOp = xferOp.getVector().getDefiningOp<memref::LoadOp>();
485 static void getBufferIndices(TransferWriteOp xferOp,
487 auto loadOp = xferOp.getVector().getDefiningOp<memref::LoadOp>();
503 TransferWriteOp xferOp, Value buffer,
506 getBufferIndices(xferOp, loadIndices);
510 getXferIndices(b, xferOp, iv, xferIndices);
512 Location loc = xferOp.getLoc();
514 auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr());
515 auto source = loopState.empty() ? xferOp.getSource() : loopState[0];
516 Type type = isTensorOp(xferOp) ? xferOp.getShapedType() : Type();
519 AffineMapAttr::get(unpackedPermutationMap(b, xferOp)), Value(),
528 static Value handleOutOfBoundsDim(OpBuilder &b, TransferWriteOp xferOp,
531 return isTensorOp(xferOp) ? loopState[0] : Value();
535 static void cleanup(PatternRewriter &rewriter, TransferWriteOp xferOp,
537 if (isTensorOp(xferOp)) {
539 rewriter.replaceOp(xferOp, forOp->getResult(0));
541 rewriter.eraseOp(xferOp);
546 static Value initialLoopState(TransferWriteOp xferOp) {
547 return isTensorOp(xferOp) ? xferOp.getSource() : Value();
552 LogicalResult checkPrepareXferOp(OpTy xferOp,
554 if (xferOp->hasAttr(kPassLabel))
556 if (xferOp.getVectorType().getRank() <= options.targetRank)
560 if (xferOp.getVectorType().getScalableDims().front())
562 if (isTensorOp(xferOp) && !options.lowerTensors)
565 if (xferOp.getVectorType().getElementType() !=
566 xferOp.getShapedType().getElementType())
598 LogicalResult matchAndRewrite(TransferReadOp xferOp,
600 if (checkPrepareXferOp(xferOp, options).failed())
603 auto buffers = allocBuffers(rewriter, xferOp);
604 auto *newXfer = rewriter.clone(*xferOp.getOperation());
606 if (xferOp.getMask()) {
611 Location loc = xferOp.getLoc();
614 rewriter.replaceOpWithNewOp<memref::LoadOp>(xferOp, buffers.dataBuffer);
647 LogicalResult matchAndRewrite(TransferWriteOp xferOp,
649 if (checkPrepareXferOp(xferOp, options).failed())
652 Location loc = xferOp.getLoc();
653 auto buffers = allocBuffers(rewriter, xferOp);
654 rewriter.create<memref::StoreOp>(loc, xferOp.getVector(),
657 rewriter.modifyOpInPlace(xferOp, [&]() {
658 xferOp.getVectorMutable().assign(loadedVec);
659 xferOp->setAttr(kPassLabel, rewriter.getUnitAttr());
662 if (xferOp.getMask()) {
663 rewriter.modifyOpInPlace(xferOp, [&]() {
664 xferOp.getMaskMutable().assign(buffers.maskBuffer);
878 static void getMaskBufferLoadIndices(OpTy xferOp, Value castedMaskBuffer,
881 assert(xferOp.getMask() && "Expected transfer op to have mask");
887 Value maskBuffer = getMaskBuffer(xferOp);
899 if (!xferOp.isBroadcastDim(0))
903 LogicalResult matchAndRewrite(OpTy xferOp,
905 if (!xferOp->hasAttr(kPassLabel))
909 ImplicitLocOpBuilder locB(xferOp.getLoc(), rewriter);
910 Value dataBuffer = Strategy<OpTy>::getBuffer(xferOp);
919 // If the xferOp has a mask: Find and cast mask buffer.
921 if (xferOp.getMask()) {
922 Value maskBuffer = getMaskBuffer(xferOp);
923 if (xferOp.isBroadcastDim(0) || xferOp.getMaskType().getRank() == 1) {
947 auto loopState = Strategy<OpTy>::initialLoopState(xferOp);
956 b, xferOp, iv, unpackedDim(xferOp),
962 b, this->options, xferOp, castedDataBuffer, iv, loopState);
968 if (xferOp.getMask() && (xferOp.isBroadcastDim(0) ||
969 xferOp.getMaskType().getRank() > 1)) {
974 getMaskBufferLoadIndices(xferOp, castedMaskBuffer,
988 b, xferOp, castedDataBuffer, iv, loopState);
994 Strategy<OpTy>::cleanup(rewriter, xferOp, result);
1189 static void maybeAssignMask(OpBuilder &b, OpTy xferOp, OpTy newXferOp,
1191 if (!xferOp.getMask())
1194 if (xferOp.isBroadcastDim(0)) {
1197 newXferOp.getMaskMutable().assign(xferOp.getMask());
1201 if (xferOp.getMaskType().getRank() > 1) {
1207 Location loc = xferOp.getLoc();
1208 auto newMask = b.create<vector::ExtractOp>(loc, xferOp.getMask(), indices);
1258 TransferReadOp xferOp) const {
1259 if (auto insertOp = getInsertOp(xferOp))
1261 Location loc = xferOp.getLoc();
1262 return rewriter.create<vector::SplatOp>(loc, xferOp.getVectorType(),
1263 xferOp.getPadding());
1268 vector::InsertOp getInsertOp(TransferReadOp xferOp) const {
1269 if (xferOp->hasOneUse()) {
1270 Operation *xferOpUser = *xferOp->getUsers().begin();
1280 void getInsertionIndices(TransferReadOp xferOp,
1282 if (auto insertOp = getInsertOp(xferOp)) {
1290 LogicalResult matchAndRewrite(TransferReadOp xferOp,
1292 if (xferOp.getVectorType().getRank() <= options.targetRank)
1294 xferOp, "vector rank is less or equal to target rank");
1295 if (failed(checkLowerTensors(xferOp, rewriter)))
1298 if (xferOp.getVectorType().getElementType() !=
1299 xferOp.getShapedType().getElementType())
1301 xferOp, "not yet supported: element type mismatch");
1302 auto xferVecType = xferOp.getVectorType();
1306 xferOp, "scalable dimensions cannot be unrolled");
1309 auto insertOp = getInsertOp(xferOp);
1310 auto vec = buildResultVector(rewriter, xferOp);
1318 Location loc = xferOp.getLoc();
1323 rewriter, xferOp, iv, unpackedDim(xferOp), TypeRange(vecType),
1328 getXferIndices(b, xferOp, iv, xferIndices);
1332 getInsertionIndices(xferOp, insertionIndices);
1335 auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr());
1337 loc, newXferVecType, xferOp.getSource(), xferIndices,
1338 AffineMapAttr::get(unpackedPermutationMap(b, xferOp)),
1339 xferOp.getPadding(), Value(), inBoundsAttr);
1340 maybeAssignMask(b, xferOp, newXferOp, i);
1354 rewriter.eraseOp(xferOp);
1356 rewriter.replaceOp(xferOp, vec);
1400 Value getDataVector(TransferWriteOp xferOp) const {
1401 if (auto extractOp = getExtractOp(xferOp))
1403 return xferOp.getVector();
1407 vector::ExtractOp getExtractOp(TransferWriteOp xferOp) const {
1408 if (auto *op = xferOp.getVector().getDefiningOp())
1415 void getExtractionIndices(TransferWriteOp xferOp,
1417 if (auto extractOp = getExtractOp(xferOp)) {
1425 LogicalResult matchAndRewrite(TransferWriteOp xferOp,
1427 VectorType inputVectorTy = xferOp.getVectorType();
1432 if (failed(checkLowerTensors(xferOp, rewriter)))
1436 xferOp.getShapedType().getElementType())
1439 auto vec = getDataVector(xferOp);
1446 Value source = xferOp.getSource(); // memref or tensor to be written to.
1447 auto sourceType = isTensorOp(xferOp) ? xferOp.getShapedType() : Type();
1450 Location loc = xferOp.getLoc();
1455 rewriter, xferOp, iv, unpackedDim(xferOp),
1456 isTensorOp(xferOp) ? TypeRange(sourceType) : TypeRange(),
1461 getXferIndices(b, xferOp, iv, xferIndices);
1465 getExtractionIndices(xferOp, extractionIndices);
1470 auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr());
1483 AffineMapAttr::get(unpackedPermutationMap(b, xferOp)), Value(),
1486 maybeAssignMask(b, xferOp, newXferOp, i);
1488 return isTensorOp(xferOp) ? newXferOp->getResult(0) : Value();
1492 return isTensorOp(xferOp) ? source : Value();
1495 if (isTensorOp(xferOp))
1499 if (isTensorOp(xferOp))
1500 rewriter.replaceOp(xferOp, source);
1502 rewriter.eraseOp(xferOp);
1518 get1dMemrefIndices(OpBuilder &b, OpTy xferOp, Value iv,
1520 auto indices = xferOp.getIndices();
1521 auto map = xferOp.getPermutationMap();
1522 assert(xferOp.getTransferRank() > 0 && "unexpected 0-d transfer");
1528 Location loc = xferOp.getLoc();
1531 bindDims(xferOp.getContext(), d0, d1);
1538 assert(xferOp.isBroadcastDim(0) &&
1552 TransferReadOp xferOp, Value iv,
1555 auto dim = get1dMemrefIndices(b, xferOp, iv, indices);
1561 b, xferOp, iv, dim, TypeRange(xferOp.getVectorType()),
1565 b.create<memref::LoadOp>(loc, xferOp.getSource(), indices);
1573 static Value initialLoopState(OpBuilder &b, TransferReadOp xferOp) {
1575 Location loc = xferOp.getLoc();
1576 return b.create<vector::SplatOp>(loc, xferOp.getVectorType(),
1577 xferOp.getPadding());
1585 TransferWriteOp xferOp, Value iv,
1588 auto dim = get1dMemrefIndices(b, xferOp, iv, indices);
1592 b, xferOp, iv, dim,
1595 b.create<vector::ExtractElementOp>(loc, xferOp.getVector(), iv);
1596 b.create<memref::StoreOp>(loc, val, xferOp.getSource(), indices);
1601 static Value initialLoopState(OpBuilder &b, TransferWriteOp xferOp) {
1641 LogicalResult matchAndRewrite(OpTy xferOp,
1644 if (xferOp.getTransferRank() == 0)
1646 auto map = xferOp.getPermutationMap();
1647 auto memRefType = dyn_cast<MemRefType>(xferOp.getShapedType());
1651 if (xferOp.getVectorType().getRank() != 1)
1657 Location loc = xferOp.getLoc();
1658 auto vecType = xferOp.getVectorType();
1668 auto loopState = Strategy1d<OpTy>::initialLoopState(rewriter, xferOp);
1672 xferOp, lb, ub, step, loopState ? ValueRange(loopState) : ValueRange(),
1674 Strategy1d<OpTy>::generateForLoopBody(b, loc, xferOp, iv, loopState);