//===- VectorTransferPermutationMapRewritePatterns.cpp - Xfer map rewrite -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements rewrite patterns for the permutation_map attribute of
// vector.transfer operations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Interfaces/VectorInterfaces.h"

using namespace mlir;
using namespace mlir::vector;

/// Transpose a vector transfer op's `in_bounds` attribute by applying the
/// reverse permutation based on the given indices.
static ArrayAttr
inverseTransposeInBoundsAttr(OpBuilder &builder, ArrayAttr attr,
                             const SmallVector<unsigned> &permutation) {
  SmallVector<bool> newInBoundsValues(permutation.size());
  size_t index = 0;
  for (unsigned pos : permutation)
    newInBoundsValues[pos] =
        cast<BoolAttr>(attr.getValue()[index++]).getValue();
  return builder.getBoolArrayAttr(newInBoundsValues);
}

/// Extend the rank of a vector Value by `addedRank` by adding outer unit
/// dimensions.
static Value extendVectorRank(OpBuilder &builder, Location loc, Value vec,
                              int64_t addedRank) {
  auto originalVecType = cast<VectorType>(vec.getType());
  SmallVector<int64_t> newShape(addedRank, 1);
  newShape.append(originalVecType.getShape().begin(),
                  originalVecType.getShape().end());
  VectorType newVecType =
      VectorType::get(newShape, originalVecType.getElementType());
  return builder.create<vector::BroadcastOp>(loc, newVecType, vec);
}

/// Extend the rank of a vector Value by `addedRank` by adding inner unit
/// dimensions.
static Value extendMaskRank(OpBuilder &builder, Location loc, Value vec,
                            int64_t addedRank) {
  Value broadcasted = extendVectorRank(builder, loc, vec, addedRank);
  SmallVector<int64_t> permutation;
  for (int64_t i = addedRank,
               e = broadcasted.getType().cast<VectorType>().getRank();
       i < e; ++i)
    permutation.push_back(i);
  for (int64_t i = 0; i < addedRank; ++i)
    permutation.push_back(i);
  return builder.create<vector::TransposeOp>(loc, broadcasted, permutation);
}
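
// A minimal worked example of the two helpers above (shapes are illustrative
// assumptions, not taken from a specific test):
//   extendVectorRank(b, loc, %v : vector<2x3xf32>, /*addedRank=*/2)
//     -> vector.broadcast to vector<1x1x2x3xf32> (outer unit dims).
//   extendMaskRank(b, loc, %m : vector<2x3xi1>, /*addedRank=*/1)
//     -> vector.broadcast to vector<1x2x3xi1>, then vector.transpose with
//        permutation [1, 2, 0] to vector<2x3x1xi1> (the unit dim ends up
//        innermost).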

//===----------------------------------------------------------------------===//
// populateVectorTransferPermutationMapLoweringPatterns
//===----------------------------------------------------------------------===//

namespace {
/// Lower transfer_read op with permutation into a transfer_read with a
/// permutation map composed of leading zeros followed by a minor identity +
/// vector.transpose op.
/// Ex:
///     vector.transfer_read ...
///         permutation_map: (d0, d1, d2) -> (0, d1)
/// into:
///     %v = vector.transfer_read ...
///         permutation_map: (d0, d1, d2) -> (d1, 0)
///     vector.transpose %v, [1, 0]
///
///     vector.transfer_read ...
///         permutation_map: (d0, d1, d2, d3) -> (0, 0, 0, d1, d3)
/// into:
///     %v = vector.transfer_read ...
///         permutation_map: (d0, d1, d2, d3) -> (0, 0, d1, 0, d3)
///     vector.transpose %v, [0, 1, 3, 2, 4]
/// Note that an alternative is to transform it to linalg.transpose +
/// vector.transfer_read to do the transpose in memory instead.
struct TransferReadPermutationLowering
    : public OpRewritePattern<vector::TransferReadOp> {
  using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferReadOp op,
                                PatternRewriter &rewriter) const override {
    // TODO: support 0-d corner case.
    if (op.getTransferRank() == 0)
      return rewriter.notifyMatchFailure(op, "0-d corner case not supported");

    SmallVector<unsigned> permutation;
    AffineMap map = op.getPermutationMap();
    if (map.getNumResults() == 0)
      return rewriter.notifyMatchFailure(op, "0 result permutation map");
    if (!map.isPermutationOfMinorIdentityWithBroadcasting(permutation)) {
      return rewriter.notifyMatchFailure(
          op, "map is not permutable to minor identity, apply another pattern");
    }
    AffineMap permutationMap =
        map.getPermutationMap(permutation, op.getContext());
    if (permutationMap.isIdentity())
      return rewriter.notifyMatchFailure(op, "map is already the identity");

    // Calculate the map of the new read by applying the inverse permutation.
    permutationMap = inversePermutation(permutationMap);
    AffineMap newMap = permutationMap.compose(map);
    // Apply the reverse transpose to deduce the type of the transfer_read.
    ArrayRef<int64_t> originalShape = op.getVectorType().getShape();
    SmallVector<int64_t> newVectorShape(originalShape.size());
    ArrayRef<bool> originalScalableDims = op.getVectorType().getScalableDims();
    SmallVector<bool> newScalableDims(originalShape.size());
    for (const auto &pos : llvm::enumerate(permutation)) {
      newVectorShape[pos.value()] = originalShape[pos.index()];
      newScalableDims[pos.value()] = originalScalableDims[pos.index()];
    }

    // Transpose in_bounds attribute.
    ArrayAttr newInBoundsAttr =
        op.getInBounds() ? inverseTransposeInBoundsAttr(
                               rewriter, op.getInBounds().value(), permutation)
                         : ArrayAttr();

    // Generate new transfer_read operation.
    VectorType newReadType = VectorType::get(
        newVectorShape, op.getVectorType().getElementType(), newScalableDims);
    Value newRead = rewriter.create<vector::TransferReadOp>(
        op.getLoc(), newReadType, op.getSource(), op.getIndices(),
        AffineMapAttr::get(newMap), op.getPadding(), op.getMask(),
        newInBoundsAttr);

    // Transpose result of transfer_read.
    SmallVector<int64_t> transposePerm(permutation.begin(), permutation.end());
    rewriter.replaceOpWithNewOp<vector::TransposeOp>(op, newRead,
                                                     transposePerm);
    return success();
  }
};

/// Lower transfer_write op with permutation into a transfer_write with a
/// minor identity permutation map. (transfer_write ops cannot have broadcasts.)
/// Ex:
///     vector.transfer_write %v ...
///         permutation_map: (d0, d1, d2) -> (d2, d0, d1)
/// into:
///     %tmp = vector.transpose %v, [2, 0, 1]
///     vector.transfer_write %tmp ...
///         permutation_map: (d0, d1, d2) -> (d0, d1, d2)
///
///     vector.transfer_write %v ...
///         permutation_map: (d0, d1, d2, d3) -> (d3, d2)
/// into:
///     %tmp = vector.transpose %v, [1, 0]
///     %v = vector.transfer_write %tmp ...
///         permutation_map: (d0, d1, d2, d3) -> (d2, d3)
struct TransferWritePermutationLowering
    : public OpRewritePattern<vector::TransferWriteOp> {
  using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferWriteOp op,
                                PatternRewriter &rewriter) const override {
    // TODO: support 0-d corner case.
    if (op.getTransferRank() == 0)
      return rewriter.notifyMatchFailure(op, "0-d corner case not supported");

    SmallVector<unsigned> permutation;
    AffineMap map = op.getPermutationMap();
    if (map.isMinorIdentity())
      return rewriter.notifyMatchFailure(op, "map is already minor identity");

    if (!map.isPermutationOfMinorIdentityWithBroadcasting(permutation)) {
      return rewriter.notifyMatchFailure(
          op, "map is not permutable to minor identity, apply another pattern");
    }

    // Remove unused dims from the permutation map. E.g.:
    //   map  = (d0, d1, d2, d3, d4, d5) -> (d5, d3, d4)
    //   comp = (d0, d1, d2) -> (d2, d0, d1)
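    // Continuing that example (a worked sketch of the values computed below,
    // not taken from a specific test):
    //   inversePermutation(comp) = (d0, d1, d2) -> (d1, d2, d0)
    //   indices = [1, 2, 0], which becomes the vector.transpose permutation.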
    auto comp = compressUnusedDims(map);
    AffineMap permutationMap = inversePermutation(comp);
    // Get positions of remaining result dims.
    SmallVector<int64_t> indices;
    llvm::transform(permutationMap.getResults(), std::back_inserter(indices),
                    [](AffineExpr expr) {
                      return expr.dyn_cast<AffineDimExpr>().getPosition();
                    });

    // Transpose in_bounds attribute.
    ArrayAttr newInBoundsAttr =
        op.getInBounds() ? inverseTransposeInBoundsAttr(
                               rewriter, op.getInBounds().value(), permutation)
                         : ArrayAttr();

    // Generate new transfer_write operation.
    Value newVec = rewriter.create<vector::TransposeOp>(
        op.getLoc(), op.getVector(), indices);
    auto newMap = AffineMap::getMinorIdentityMap(
        map.getNumDims(), map.getNumResults(), rewriter.getContext());
    rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
        op, newVec, op.getSource(), op.getIndices(),
        AffineMapAttr::get(newMap), op.getMask(), newInBoundsAttr);

    return success();
  }
};

/// Convert a transfer_write op with a map which isn't the permutation of a
/// minor identity into a vector.broadcast + transfer_write with permutation
/// of minor identity map, by adding unit dims for the missing inner
/// dimensions.
/// Ex:
/// ```
///   vector.transfer_write %v
///     {permutation_map = affine_map<(d0, d1, d2, d3) -> (d1, d2)>} :
///     vector<8x16xf32>
/// ```
/// into:
/// ```
///   %v1 = vector.broadcast %v : vector<8x16xf32> to vector<1x8x16xf32>
///   vector.transfer_write %v1
///     {permutation_map = affine_map<(d0, d1, d2, d3) -> (d3, d1, d2)>} :
///     vector<1x8x16xf32>
/// ```
struct TransferWriteNonPermutationLowering
    : public OpRewritePattern<vector::TransferWriteOp> {
  using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferWriteOp op,
                                PatternRewriter &rewriter) const override {
    // TODO: support 0-d corner case.
    if (op.getTransferRank() == 0)
      return rewriter.notifyMatchFailure(op, "0-d corner case not supported");

    SmallVector<unsigned> permutation;
    AffineMap map = op.getPermutationMap();
    if (map.isPermutationOfMinorIdentityWithBroadcasting(permutation)) {
      return rewriter.notifyMatchFailure(
          op,
          "map is already permutable to minor identity, apply another pattern");
    }

    // Missing outer dimensions are allowed: find the outermost existing
    // dimension, then deduce the missing inner dimensions.
    SmallVector<bool> foundDim(map.getNumDims(), false);
    for (AffineExpr exp : map.getResults())
      foundDim[exp.cast<AffineDimExpr>().getPosition()] = true;
    SmallVector<AffineExpr> exprs;
    bool foundFirstDim = false;
    SmallVector<int64_t> missingInnerDim;
    for (size_t i = 0; i < foundDim.size(); i++) {
      if (foundDim[i]) {
        foundFirstDim = true;
        continue;
      }
      if (!foundFirstDim)
        continue;
      // Once we have found one outer dimension that exists in the map, keep
      // track of all the missing dimensions after it.
      missingInnerDim.push_back(i);
      exprs.push_back(rewriter.getAffineDimExpr(i));
    }
    // Vector: add unit dims at the beginning of the shape.
    Value newVec = extendVectorRank(rewriter, op.getLoc(), op.getVector(),
                                    missingInnerDim.size());
    // Mask: add unit dims at the end of the shape.
    Value newMask;
    if (op.getMask())
      newMask = extendMaskRank(rewriter, op.getLoc(), op.getMask(),
                               missingInnerDim.size());
    exprs.append(map.getResults().begin(), map.getResults().end());
    AffineMap newMap =
        AffineMap::get(map.getNumDims(), 0, exprs, op.getContext());
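    // E.g. tracing the documented example above through this code:
    //   (d0, d1, d2, d3) -> (d1, d2): d0 lies before the first used dim and
    //   is ignored; d3 is a missing inner dim, so exprs = [d3] before
    //   appending the original results, and the final map is
    //   (d0, d1, d2, d3) -> (d3, d1, d2).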
    // All the newly added dimensions are in-bounds.
    SmallVector<bool> newInBoundsValues(missingInnerDim.size(), true);
    for (int64_t i = 0, e = op.getVectorType().getRank(); i < e; ++i) {
      newInBoundsValues.push_back(op.isDimInBounds(i));
    }
    ArrayAttr newInBoundsAttr = rewriter.getBoolArrayAttr(newInBoundsValues);
    rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
        op, newVec, op.getSource(), op.getIndices(),
        AffineMapAttr::get(newMap), newMask, newInBoundsAttr);
    return success();
  }
};

/// Lower transfer_read op with broadcast in the leading dimensions into
/// transfer_read of lower rank + vector.broadcast.
/// Ex: vector.transfer_read ...
///         permutation_map: (d0, d1, d2, d3) -> (0, d1, 0, d3)
/// into:
///     %v = vector.transfer_read ...
///         permutation_map: (d0, d1, d2, d3) -> (d1, 0, d3)
///     vector.broadcast %v
struct TransferOpReduceRank : public OpRewritePattern<vector::TransferReadOp> {
  using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferReadOp op,
                                PatternRewriter &rewriter) const override {
    // TODO: support 0-d corner case.
    if (op.getTransferRank() == 0)
      return rewriter.notifyMatchFailure(op, "0-d corner case not supported");

    AffineMap map = op.getPermutationMap();
    unsigned numLeadingBroadcast = 0;
    for (auto expr : map.getResults()) {
      auto dimExpr = expr.dyn_cast<AffineConstantExpr>();
      if (!dimExpr || dimExpr.getValue() != 0)
        break;
      numLeadingBroadcast++;
    }
    // If there are no leading zeros in the map there is nothing to do.
    if (numLeadingBroadcast == 0)
      return rewriter.notifyMatchFailure(op, "no leading broadcasts in map");

    VectorType originalVecType = op.getVectorType();
    unsigned reducedShapeRank =
        originalVecType.getRank() - numLeadingBroadcast;
    // Calculate new map, vector type and masks without the leading zeros.
    AffineMap newMap = AffineMap::get(
        map.getNumDims(), 0, map.getResults().take_back(reducedShapeRank),
        op.getContext());
    // Only remove the leading zeros if the rest of the map is a minor identity
    // with broadcasting. Otherwise we first want to permute the map.
    if (!newMap.isMinorIdentityWithBroadcasting()) {
      return rewriter.notifyMatchFailure(
          op, "map is not a minor identity with broadcasting");
    }

    // TODO: support zero-dimension vectors natively. See:
    // https://llvm.discourse.group/t/should-we-have-0-d-vectors/3097.
    // In the meantime, lower these to a scalar load when they pop up.
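    // E.g. (an illustrative sketch, types assumed): a fully broadcast read
    //   %v = vector.transfer_read %m[%i, %j], %pad
    //       {permutation_map = affine_map<(d0, d1) -> (0, 0)>}
    //       : memref<?x?xf32>, vector<4x4xf32>
    // becomes a memref.load (or tensor.extract for tensor sources) followed
    // by vector.broadcast %s : f32 to vector<4x4xf32>.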
    if (reducedShapeRank == 0) {
      Value newRead;
      if (isa<TensorType>(op.getShapedType())) {
        newRead = rewriter.create<tensor::ExtractOp>(
            op.getLoc(), op.getSource(), op.getIndices());
      } else {
        newRead = rewriter.create<memref::LoadOp>(
            op.getLoc(), originalVecType.getElementType(), op.getSource(),
            op.getIndices());
      }
      rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, originalVecType,
                                                       newRead);
      return success();
    }

    SmallVector<int64_t> newShape(
        originalVecType.getShape().take_back(reducedShapeRank));
    SmallVector<bool> newScalableDims(
        originalVecType.getScalableDims().take_back(reducedShapeRank));
    // Vector rank cannot be zero. Handled by TransferReadToVectorLoadLowering.
    if (newShape.empty())
      return rewriter.notifyMatchFailure(op, "rank-reduced vector is 0-d");

    VectorType newReadType = VectorType::get(
        newShape, originalVecType.getElementType(), newScalableDims);
    ArrayAttr newInBoundsAttr =
        op.getInBounds()
            ? rewriter.getArrayAttr(
                  op.getInBoundsAttr().getValue().take_back(reducedShapeRank))
            : ArrayAttr();
    Value newRead = rewriter.create<vector::TransferReadOp>(
        op.getLoc(), newReadType, op.getSource(), op.getIndices(),
        AffineMapAttr::get(newMap), op.getPadding(), op.getMask(),
        newInBoundsAttr);
    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, originalVecType,
                                                     newRead);
    return success();
  }
};
} // namespace

void mlir::vector::populateVectorTransferPermutationMapLoweringPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns
      .add<TransferReadPermutationLowering, TransferWritePermutationLowering,
           TransferOpReduceRank, TransferWriteNonPermutationLowering>(
          patterns.getContext(), benefit);
}

//===----------------------------------------------------------------------===//
// populateVectorTransferLoweringPatterns
//===----------------------------------------------------------------------===//

namespace {
/// Progressive lowering of transfer_read. This pattern supports lowering of
/// `vector.transfer_read` to a combination of `vector.load` and
/// `vector.broadcast` if all of the following hold:
/// - Stride of most minor memref dimension must be 1.
/// - Out-of-bounds masking is not required.
/// - If the memref's element type is a vector type then it coincides with the
///   result type.
/// - The permutation map doesn't perform permutation (broadcasting is
///   allowed).
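///
/// E.g. (an illustrative sketch; types assumed, not from a specific test):
///     %v = vector.transfer_read %m[%i], %pad {in_bounds = [true]}
///         : memref<?xf32>, vector<4xf32>
/// is rewritten to:
///     %v = vector.load %m[%i] : memref<?xf32>, vector<4xf32>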
struct TransferReadToVectorLoadLowering
    : public OpRewritePattern<vector::TransferReadOp> {
  TransferReadToVectorLoadLowering(MLIRContext *context,
                                   std::optional<unsigned> maxRank,
                                   PatternBenefit benefit = 1)
      : OpRewritePattern<vector::TransferReadOp>(context, benefit),
        maxTransferRank(maxRank) {}

  LogicalResult matchAndRewrite(vector::TransferReadOp read,
                                PatternRewriter &rewriter) const override {
    if (maxTransferRank && read.getVectorType().getRank() > *maxTransferRank) {
      return rewriter.notifyMatchFailure(
          read, "vector type is greater than max transfer rank");
    }

    SmallVector<unsigned> broadcastedDims;
    // Permutations are handled by VectorToSCF or
    // populateVectorTransferPermutationMapLoweringPatterns.
    // We let the 0-d corner case pass-through as it is supported.
    if (!read.getPermutationMap().isMinorIdentityWithBroadcasting(
            &broadcastedDims))
      return rewriter.notifyMatchFailure(read, "not minor identity + bcast");

    auto memRefType = dyn_cast<MemRefType>(read.getShapedType());
    if (!memRefType)
      return rewriter.notifyMatchFailure(read, "not a memref source");

    // Non-unit strides are handled by VectorToSCF.
    if (!isLastMemrefDimUnitStride(memRefType))
      return rewriter.notifyMatchFailure(read, "!= 1 stride needs VectorToSCF");

    // If there is broadcasting involved then we first load the unbroadcasted
    // vector, and then broadcast it with `vector.broadcast`.
    ArrayRef<int64_t> vectorShape = read.getVectorType().getShape();
    SmallVector<int64_t> unbroadcastedVectorShape(vectorShape.begin(),
                                                  vectorShape.end());
    for (unsigned i : broadcastedDims)
      unbroadcastedVectorShape[i] = 1;
    VectorType unbroadcastedVectorType = read.getVectorType().cloneWith(
        unbroadcastedVectorShape, read.getVectorType().getElementType());

    // `vector.load` supports vector types as memref's elements only when the
    // resulting vector type is the same as the element type.
    auto memrefElTy = memRefType.getElementType();
    if (isa<VectorType>(memrefElTy) && memrefElTy != unbroadcastedVectorType)
      return rewriter.notifyMatchFailure(read, "incompatible element type");

    // Otherwise, element types of the memref and the vector must match.
    if (!isa<VectorType>(memrefElTy) &&
        memrefElTy != read.getVectorType().getElementType())
      return rewriter.notifyMatchFailure(read, "non-matching element type");

    // Out-of-bounds dims are handled by MaterializeTransferMask.
    if (read.hasOutOfBoundsDim())
      return rewriter.notifyMatchFailure(read, "out-of-bounds needs mask");

    // Create vector load op.
    Operation *loadOp;
    if (read.getMask()) {
      Value fill = rewriter.create<vector::SplatOp>(
          read.getLoc(), unbroadcastedVectorType, read.getPadding());
      loadOp = rewriter.create<vector::MaskedLoadOp>(
          read.getLoc(), unbroadcastedVectorType, read.getSource(),
          read.getIndices(), read.getMask(), fill);
    } else {
      loadOp = rewriter.create<vector::LoadOp>(
          read.getLoc(), unbroadcastedVectorType, read.getSource(),
          read.getIndices());
    }

    // Insert a broadcasting op if required.
    if (!broadcastedDims.empty()) {
      rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
          read, read.getVectorType(), loadOp->getResult(0));
    } else {
      rewriter.replaceOp(read, loadOp->getResult(0));
    }

    return success();
  }

  std::optional<unsigned> maxTransferRank;
};

/// Replace a 0-d vector.load with a memref.load + vector.broadcast.
// TODO: we shouldn't cross the vector/scalar domains just for this
// but atm we lack the infra to avoid it. Possible solutions include:
// - go directly to LLVM + bitcast
// - introduce a bitcast op and likely a new pointer dialect
// - let memref.load/store additionally support the 0-d vector case
// There are still deeper data layout issues lingering even in this
// trivial case (for architectures for which this matters).
struct VectorLoadToMemrefLoadLowering
    : public OpRewritePattern<vector::LoadOp> {
  using OpRewritePattern<vector::LoadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::LoadOp loadOp,
                                PatternRewriter &rewriter) const override {
    auto vecType = loadOp.getVectorType();
    if (vecType.getNumElements() != 1)
      return rewriter.notifyMatchFailure(loadOp, "not a single element vector");
    auto memrefLoad = rewriter.create<memref::LoadOp>(
        loadOp.getLoc(), loadOp.getBase(), loadOp.getIndices());
    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(loadOp, vecType,
                                                     memrefLoad);
    return success();
  }
};

/// Replace a 0-d vector.store with a vector.extractelement + memref.store.
struct VectorStoreToMemrefStoreLowering
    : public OpRewritePattern<vector::StoreOp> {
  using OpRewritePattern<vector::StoreOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::StoreOp storeOp,
                                PatternRewriter &rewriter) const override {
    auto vecType = storeOp.getVectorType();
    if (vecType.getNumElements() != 1)
      return rewriter.notifyMatchFailure(storeOp, "not single element vector");
    Value extracted;
    if (vecType.getRank() == 0) {
      // TODO: Unify once ExtractOp supports 0-d vectors.
      extracted = rewriter.create<vector::ExtractElementOp>(
          storeOp.getLoc(), storeOp.getValueToStore());
    } else {
      SmallVector<int64_t> indices(vecType.getRank(), 0);
      extracted = rewriter.create<vector::ExtractOp>(
          storeOp.getLoc(), storeOp.getValueToStore(), indices);
    }

    rewriter.replaceOpWithNewOp<memref::StoreOp>(
        storeOp, extracted, storeOp.getBase(), storeOp.getIndices());
    return success();
  }
};
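
// E.g. for the two patterns above (illustrative single-element vectors,
// types assumed):
//   %v = vector.load %m[%i] : memref<?xf32>, vector<1xf32>
//     -> %s = memref.load %m[%i]
//        %v = vector.broadcast %s : f32 to vector<1xf32>
//   vector.store %v, %m[%i] : memref<?xf32>, vector<1xf32>
//     -> %s = vector.extract %v[0]
//        memref.store %s, %m[%i]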
/// Progressive lowering of transfer_write. This pattern supports lowering of
/// `vector.transfer_write` to `vector.store` if all of the following hold:
/// - Stride of most minor memref dimension must be 1.
/// - Out-of-bounds masking is not required.
/// - If the memref's element type is a vector type then it coincides with the
///   type of the written value.
/// - The permutation map is the minor identity map (neither permutation nor
///   broadcasting is allowed).
struct TransferWriteToVectorStoreLowering
    : public OpRewritePattern<vector::TransferWriteOp> {
  TransferWriteToVectorStoreLowering(MLIRContext *context,
                                     std::optional<unsigned> maxRank,
                                     PatternBenefit benefit = 1)
      : OpRewritePattern<vector::TransferWriteOp>(context, benefit),
        maxTransferRank(maxRank) {}

  LogicalResult matchAndRewrite(vector::TransferWriteOp write,
                                PatternRewriter &rewriter) const override {
    if (maxTransferRank && write.getVectorType().getRank() > *maxTransferRank) {
      return rewriter.notifyMatchFailure(
          write, "vector type is greater than max transfer rank");
    }

    // Permutations are handled by VectorToSCF or
    // populateVectorTransferPermutationMapLoweringPatterns.
    if ( // pass-through for the 0-d corner case.
        !write.getPermutationMap().isMinorIdentity())
      return rewriter.notifyMatchFailure(
          write.getLoc(), [=](Diagnostic &diag) {
            diag << "permutation map is not minor identity: " << write;
          });

    auto memRefType = dyn_cast<MemRefType>(write.getShapedType());
    if (!memRefType)
      return rewriter.notifyMatchFailure(
          write.getLoc(),
          [=](Diagnostic &diag) { diag << "not a memref type: " << write; });

    // Non-unit strides are handled by VectorToSCF.
    if (!isLastMemrefDimUnitStride(memRefType))
      return rewriter.notifyMatchFailure(
          write.getLoc(), [=](Diagnostic &diag) {
            diag << "most minor stride is not 1: " << write;
          });

    // `vector.store` supports vector types as memref's elements only when the
    // type of the vector value being written is the same as the element type.
    auto memrefElTy = memRefType.getElementType();
    if (isa<VectorType>(memrefElTy) && memrefElTy != write.getVectorType())
      return rewriter.notifyMatchFailure(
          write.getLoc(), [=](Diagnostic &diag) {
            diag << "elemental type mismatch: " << write;
          });

    // Otherwise, element types of the memref and the vector must match.
    if (!isa<VectorType>(memrefElTy) &&
        memrefElTy != write.getVectorType().getElementType())
      return rewriter.notifyMatchFailure(
          write.getLoc(), [=](Diagnostic &diag) {
            diag << "elemental type mismatch: " << write;
          });

    // Out-of-bounds dims are handled by MaterializeTransferMask.
    if (write.hasOutOfBoundsDim())
      return rewriter.notifyMatchFailure(
          write.getLoc(),
          [=](Diagnostic &diag) { diag << "out of bounds dim: " << write; });
    if (write.getMask()) {
      rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
          write, write.getSource(), write.getIndices(), write.getMask(),
          write.getVector());
    } else {
      rewriter.replaceOpWithNewOp<vector::StoreOp>(
          write, write.getVector(), write.getSource(), write.getIndices());
    }
    return success();
  }

  std::optional<unsigned> maxTransferRank;
};
} // namespace

void mlir::vector::populateVectorTransferLoweringPatterns(
    RewritePatternSet &patterns, std::optional<unsigned> maxTransferRank,
    PatternBenefit benefit) {
  patterns.add<TransferReadToVectorLoadLowering,
               TransferWriteToVectorStoreLowering>(patterns.getContext(),
                                                   maxTransferRank, benefit);
  patterns
      .add<VectorLoadToMemrefLoadLowering, VectorStoreToMemrefStoreLowering>(
          patterns.getContext(), benefit);
}
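
// Usage sketch (a minimal example assuming the greedy rewrite driver from
// mlir/Transforms/GreedyPatternRewriteDriver.h; `op` and `ctx` are assumed):
//   RewritePatternSet patterns(ctx);
//   vector::populateVectorTransferPermutationMapLoweringPatterns(patterns);
//   vector::populateVectorTransferLoweringPatterns(patterns,
//                                                  /*maxTransferRank=*/1);
//   (void)applyPatternsAndFoldGreedily(op, std::move(patterns));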