//===- VectorUnrollDistribute.cpp - patterns to do vector unrolling -------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file implements patterns to do vector unrolling and vector distribution. // //===----------------------------------------------------------------------===// #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Utils/IndexingUtils.h" #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h" #include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/Interfaces/VectorInterfaces.h" #include "llvm/ADT/MapVector.h" #include "llvm/ADT/STLExtras.h" #include "llvm/Support/Debug.h" #include #include #define DEBUG_TYPE "vector-unroll" #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") using namespace mlir; using namespace mlir::vector; /// Compute the indices of the slice `index` for a tranfer op. static SmallVector sliceTransferIndices(ArrayRef elementOffsets, ArrayRef indices, AffineMap permutationMap, Location loc, OpBuilder &builder) { MLIRContext *ctx = builder.getContext(); auto isBroadcast = [](AffineExpr expr) { if (auto constExpr = dyn_cast(expr)) return constExpr.getValue() == 0; return false; }; // Compute 'sliceIndices' by adding 'sliceOffsets[i]' to 'indices[i]'. SmallVector slicedIndices(indices); for (const auto &dim : llvm::enumerate(permutationMap.getResults())) { if (isBroadcast(dim.value())) continue; unsigned pos = cast(dim.value()).getPosition(); auto expr = getAffineDimExpr(0, builder.getContext()) + getAffineConstantExpr(elementOffsets[dim.index()], ctx); auto map = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, expr); slicedIndices[pos] = builder.create(loc, map, indices[pos]); } return slicedIndices; } // Clones `op` into a new operations that takes `operands` and returns // `resultTypes`. static Operation *cloneOpWithOperandsAndTypes(OpBuilder &builder, Location loc, Operation *op, ArrayRef operands, ArrayRef resultTypes) { return builder.create(loc, op->getName().getIdentifier(), operands, resultTypes, op->getAttrs()); } /// Return the target shape for unrolling for the given `op`. Return /// std::nullopt if the op shouldn't be or cannot be unrolled. static std::optional> getTargetShape(const vector::UnrollVectorOptions &options, Operation *op) { LDBG(""); LDBG("Get unroll shape for op " << op->getName().getStringRef()); if (options.filterConstraint && failed(options.filterConstraint(op))) { LDBG("--no filter constraint -> BAIL"); return std::nullopt; } assert(options.nativeShape && "vector unrolling expects the native shape or native" "shape call back function to be set"); auto unrollableVectorOp = dyn_cast(op); if (!unrollableVectorOp) { LDBG("--not an unrollable op -> BAIL"); return std::nullopt; } auto maybeUnrollShape = unrollableVectorOp.getShapeForUnroll(); if (!maybeUnrollShape) { LDBG("--could not get shape of op " << *op << " -> BAIL"); return std::nullopt; } LLVM_DEBUG( llvm::interleaveComma(*maybeUnrollShape, DBGS() << "--vector op shape: "); llvm::dbgs() << "\n";); std::optional> targetShape = options.nativeShape(op); if (!targetShape) { LDBG("--no unrolling target shape defined " << *op << "-> SKIP"); return std::nullopt; } LLVM_DEBUG(llvm::interleaveComma(*targetShape, DBGS() << "--target shape: "); llvm::dbgs() << "\n";); auto maybeShapeRatio = computeShapeRatio(*maybeUnrollShape, *targetShape); if (!maybeShapeRatio) { LDBG("--could not compute integral shape ratio -> BAIL"); return std::nullopt; } if (llvm::all_of(*maybeShapeRatio, [](int64_t v) { return v == 1; })) { LDBG("--no unrolling needed -> SKIP"); return std::nullopt; } LDBG("--found an integral shape ratio to unroll to -> SUCCESS"); return targetShape; } static SmallVector getUnrollOrder(unsigned numLoops, Operation *op, const vector::UnrollVectorOptions &options) { SmallVector loopOrder = llvm::to_vector(llvm::seq(0, static_cast(numLoops))); if (options.traversalOrderCallback != nullptr) { std::optional> order = options.traversalOrderCallback(op); if (order) { loopOrder = std::move(*order); } } return loopOrder; } namespace { struct UnrollTransferReadPattern : public OpRewritePattern { UnrollTransferReadPattern(MLIRContext *context, const vector::UnrollVectorOptions &options, PatternBenefit benefit = 1) : OpRewritePattern(context, benefit), options(options) {} LogicalResult matchAndRewrite(vector::TransferReadOp readOp, PatternRewriter &rewriter) const override { // TODO: support 0-d corner case. if (readOp.getTransferRank() == 0) return failure(); if (readOp.getMask()) return failure(); auto targetShape = getTargetShape(options, readOp); if (!targetShape) return failure(); auto sourceVectorType = readOp.getVectorType(); SmallVector strides(targetShape->size(), 1); Location loc = readOp.getLoc(); ArrayRef originalSize = readOp.getVectorType().getShape(); // Prepare the result vector; Value result = rewriter.create( loc, sourceVectorType, rewriter.getZeroAttr(sourceVectorType)); auto targetType = VectorType::get(*targetShape, sourceVectorType.getElementType()); SmallVector originalIndices(readOp.getIndices().begin(), readOp.getIndices().end()); SmallVector loopOrder = getUnrollOrder(originalSize.size(), readOp, options); for (SmallVector elementOffsets : StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) { SmallVector indices = sliceTransferIndices(elementOffsets, originalIndices, readOp.getPermutationMap(), loc, rewriter); auto slicedRead = rewriter.create( loc, targetType, readOp.getSource(), indices, readOp.getPermutationMapAttr(), readOp.getPadding(), readOp.getMask(), readOp.getInBoundsAttr()); result = rewriter.create( loc, slicedRead, result, elementOffsets, strides); } rewriter.replaceOp(readOp, result); return success(); } private: vector::UnrollVectorOptions options; }; struct UnrollTransferWritePattern : public OpRewritePattern { UnrollTransferWritePattern(MLIRContext *context, const vector::UnrollVectorOptions &options, PatternBenefit benefit = 1) : OpRewritePattern(context, benefit), options(options) {} LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp, PatternRewriter &rewriter) const override { // TODO: support 0-d corner case. if (writeOp.getTransferRank() == 0) return failure(); if (writeOp.getMask()) return failure(); auto targetShape = getTargetShape(options, writeOp); if (!targetShape) return failure(); auto sourceVectorType = writeOp.getVectorType(); SmallVector strides(targetShape->size(), 1); Location loc = writeOp.getLoc(); ArrayRef originalSize = sourceVectorType.getShape(); SmallVector originalIndices(writeOp.getIndices().begin(), writeOp.getIndices().end()); SmallVector loopOrder = getUnrollOrder(originalSize.size(), writeOp, options); Value resultTensor; for (SmallVector elementOffsets : StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) { Value slicedVector = rewriter.create( loc, writeOp.getVector(), elementOffsets, *targetShape, strides); SmallVector indices = sliceTransferIndices(elementOffsets, originalIndices, writeOp.getPermutationMap(), loc, rewriter); Operation *slicedWrite = rewriter.create( loc, slicedVector, resultTensor ? resultTensor : writeOp.getSource(), indices, writeOp.getPermutationMapAttr(), writeOp.getInBoundsAttr()); // For the tensor case update the destination for the next transfer write. if (!slicedWrite->getResults().empty()) resultTensor = slicedWrite->getResult(0); } if (resultTensor) rewriter.replaceOp(writeOp, resultTensor); else rewriter.eraseOp(writeOp); return success(); } private: vector::UnrollVectorOptions options; }; struct OffsetMapInfo { static SmallVector getEmptyKey() { return {int64_t(-1)}; } static SmallVector getTombstoneKey() { return {int64_t(-2)}; } static unsigned getHashValue(const SmallVector &v) { return static_cast(llvm::hash_combine_range(v.begin(), v.end())); } static bool isEqual(const SmallVector &lhs, const SmallVector &rhs) { return lhs == rhs; } }; struct UnrollContractionPattern : public OpRewritePattern { UnrollContractionPattern(MLIRContext *context, const vector::UnrollVectorOptions &options, PatternBenefit benefit = 1) : OpRewritePattern(context, benefit), options(options) {} LogicalResult matchAndRewrite(vector::ContractionOp contractOp, PatternRewriter &rewriter) const override { auto targetShape = getTargetShape(options, contractOp); if (!targetShape) return failure(); auto dstVecType = cast(contractOp.getResultType()); SmallVector originalSize = *contractOp.getShapeForUnroll(); Location loc = contractOp.getLoc(); unsigned accIndex = vector::ContractionOp::getAccOperandIndex(); AffineMap dstAffineMap = contractOp.getIndexingMapsArray()[accIndex]; llvm::MapVector< SmallVector, Value, llvm::DenseMap, unsigned, OffsetMapInfo>> accCache; SmallVector loopOrder = getUnrollOrder( contractOp.getIteratorTypes().size(), contractOp, options); for (SmallVector offsets : StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) { SmallVector slicesOperands(contractOp.getNumOperands()); // Helper to compute the new shape of each operand and extract the slice. auto extractOperand = [&](unsigned index, Value operand, AffineMap permutationMap, ArrayRef operandOffets) { SmallVector operandShape = applyPermutationMap( permutationMap, ArrayRef(*targetShape)); SmallVector operandStrides(operandOffets.size(), 1); slicesOperands[index] = rewriter.create( loc, operand, operandOffets, operandShape, operandStrides); }; // Extract the new lhs operand. AffineMap lhsPermutationMap = contractOp.getIndexingMapsArray()[0]; SmallVector lhsOffets = applyPermutationMap(lhsPermutationMap, ArrayRef(offsets)); extractOperand(0, contractOp.getLhs(), lhsPermutationMap, lhsOffets); // Extract the new rhs operand. AffineMap rhsPermutationMap = contractOp.getIndexingMapsArray()[1]; SmallVector rhsOffets = applyPermutationMap(rhsPermutationMap, ArrayRef(offsets)); extractOperand(1, contractOp.getRhs(), rhsPermutationMap, rhsOffets); AffineMap accPermutationMap = contractOp.getIndexingMapsArray()[2]; SmallVector accOffets = applyPermutationMap(accPermutationMap, ArrayRef(offsets)); // If a version of the accumulator has already been computed, use it // otherwise extract the first version from the original operand. auto *accIt = accCache.find(accOffets); if (accIt != accCache.end()) slicesOperands[2] = accIt->second; else extractOperand(2, contractOp.getAcc(), accPermutationMap, accOffets); SmallVector dstShape = applyPermutationMap(dstAffineMap, ArrayRef(*targetShape)); auto targetType = VectorType::get(dstShape, dstVecType.getElementType()); Operation *newOp = cloneOpWithOperandsAndTypes( rewriter, loc, contractOp, slicesOperands, targetType); SmallVector dstOffets = applyPermutationMap(dstAffineMap, ArrayRef(offsets)); // Save the accumulated value untill all the loops are unrolled since // reduction loop keep updating the accumulator. accCache[dstOffets] = newOp->getResult(0); } // Assemble back the accumulator into a single vector. Value result = rewriter.create( loc, dstVecType, rewriter.getZeroAttr(dstVecType)); for (const auto &it : accCache) { SmallVector dstStrides(it.first.size(), 1); result = rewriter.create( loc, it.second, result, it.first, dstStrides); } rewriter.replaceOp(contractOp, result); return success(); } private: vector::UnrollVectorOptions options; }; struct UnrollMultiReductionPattern : public OpRewritePattern { UnrollMultiReductionPattern(MLIRContext *context, const vector::UnrollVectorOptions &options, PatternBenefit benefit = 1) : OpRewritePattern(context, benefit), options(options) {} LogicalResult matchAndRewrite(vector::MultiDimReductionOp reductionOp, PatternRewriter &rewriter) const override { std::optional> targetShape = getTargetShape(options, reductionOp); if (!targetShape) return failure(); SmallVector originalSize = *reductionOp.getShapeForUnroll(); llvm::MapVector< SmallVector, Value, llvm::DenseMap, unsigned, OffsetMapInfo>> accCache; Location loc = reductionOp.getLoc(); // Stride of the ratios, this gives us the offsets of sliceCount in a basis // of multiples of the targetShape. for (SmallVector offsets : StaticTileOffsetRange(originalSize, *targetShape)) { SmallVector operands; SmallVector operandStrides(offsets.size(), 1); Value slicedOperand = rewriter.create( loc, reductionOp.getSource(), offsets, *targetShape, operandStrides); operands.push_back(slicedOperand); SmallVector dstShape; SmallVector destOffset; for (size_t i : llvm::seq(size_t(0), targetShape->size())) { if (!reductionOp.isReducedDim(i)) { destOffset.push_back(offsets[i]); dstShape.push_back((*targetShape)[i]); } } Value acc; SmallVector accStrides(destOffset.size(), 1); // If a version of the accumulator has already been computed, use it // otherwise extract the first version from the original operand. auto *accIt = accCache.find(destOffset); if (accIt != accCache.end()) acc = accIt->second; else acc = rewriter.create( loc, reductionOp.getAcc(), destOffset, dstShape, accStrides); operands.push_back(acc); auto targetType = VectorType::get( dstShape, reductionOp.getSourceVectorType().getElementType()); Operation *newOp = cloneOpWithOperandsAndTypes(rewriter, loc, reductionOp, operands, targetType); Value result = newOp->getResult(0); accCache[destOffset] = result; } // Assemble back the accumulator into a single vector. Value result = rewriter.create( loc, reductionOp.getDestType(), rewriter.getZeroAttr(reductionOp.getDestType())); for (const auto &it : accCache) { SmallVector dstStrides(it.first.size(), 1); result = rewriter.create( loc, it.second, result, it.first, dstStrides); } rewriter.replaceOp(reductionOp, result); return success(); } private: vector::UnrollVectorOptions options; }; struct UnrollElementwisePattern : public RewritePattern { UnrollElementwisePattern(MLIRContext *context, const vector::UnrollVectorOptions &options, PatternBenefit benefit = 1) : RewritePattern(MatchAnyOpTypeTag(), benefit, context), options(options) {} LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override { if (!OpTrait::hasElementwiseMappableTraits(op) || op->getNumResults() != 1) return failure(); auto targetShape = getTargetShape(options, op); if (!targetShape) return failure(); auto dstVecType = cast(op->getResult(0).getType()); SmallVector originalSize = *cast(op).getShapeForUnroll(); Location loc = op->getLoc(); // Prepare the result vector. Value result = rewriter.create( loc, dstVecType, rewriter.getZeroAttr(dstVecType)); SmallVector strides(targetShape->size(), 1); VectorType newVecType = VectorType::get(*targetShape, dstVecType.getElementType()); // Create the unrolled computation. for (SmallVector offsets : StaticTileOffsetRange(originalSize, *targetShape)) { SmallVector extractOperands; for (OpOperand &operand : op->getOpOperands()) { auto vecType = dyn_cast(operand.get().getType()); if (!vecType) { extractOperands.push_back(operand.get()); continue; } extractOperands.push_back( rewriter.create( loc, operand.get(), offsets, *targetShape, strides)); } Operation *newOp = cloneOpWithOperandsAndTypes( rewriter, loc, op, extractOperands, newVecType); result = rewriter.create( loc, newOp->getResult(0), result, offsets, strides); } rewriter.replaceOp(op, result); return success(); } private: vector::UnrollVectorOptions options; }; struct UnrollReductionPattern : public OpRewritePattern { UnrollReductionPattern(MLIRContext *context, const vector::UnrollVectorOptions &options, PatternBenefit benefit = 1) : OpRewritePattern(context, benefit), options(options) {} LogicalResult matchAndRewrite(vector::ReductionOp reductionOp, PatternRewriter &rewriter) const override { std::optional> targetShape = getTargetShape(options, reductionOp); if (!targetShape) return failure(); SmallVector originalSize = *reductionOp.getShapeForUnroll(); // Create unrolled vector reduction. Location loc = reductionOp.getLoc(); Value accumulator = nullptr; for (SmallVector offsets : StaticTileOffsetRange(originalSize, *targetShape)) { SmallVector strides(offsets.size(), 1); Value slicedOperand = rewriter.create( loc, reductionOp.getVector(), offsets, *targetShape, strides); Operation *newOp = cloneOpWithOperandsAndTypes( rewriter, loc, reductionOp, slicedOperand, reductionOp.getType()); Value result = newOp->getResult(0); if (!accumulator) { // This is the first reduction. accumulator = result; } else { // On subsequent reduction, combine with the accumulator. accumulator = makeArithReduction(rewriter, loc, reductionOp.getKind(), accumulator, result); } } rewriter.replaceOp(reductionOp, accumulator); return success(); } private: const vector::UnrollVectorOptions options; }; struct UnrollTransposePattern : public OpRewritePattern { UnrollTransposePattern(MLIRContext *context, const vector::UnrollVectorOptions &options, PatternBenefit benefit = 1) : OpRewritePattern(context, benefit), options(options) {} LogicalResult matchAndRewrite(vector::TransposeOp transposeOp, PatternRewriter &rewriter) const override { if (transposeOp.getResultVectorType().getRank() == 0) return failure(); auto targetShape = getTargetShape(options, transposeOp); if (!targetShape) return failure(); auto originalVectorType = transposeOp.getResultVectorType(); SmallVector strides(targetShape->size(), 1); Location loc = transposeOp.getLoc(); ArrayRef originalSize = originalVectorType.getShape(); // Prepare the result vector; Value result = rewriter.create( loc, originalVectorType, rewriter.getZeroAttr(originalVectorType)); ArrayRef permutation = transposeOp.getPermutation(); // Unroll the computation. for (SmallVector elementOffsets : StaticTileOffsetRange(originalSize, *targetShape)) { SmallVector permutedOffsets(elementOffsets.size()); SmallVector permutedShape(elementOffsets.size()); // Compute the source offsets and shape. for (auto indices : llvm::enumerate(permutation)) { permutedOffsets[indices.value()] = elementOffsets[indices.index()]; permutedShape[indices.value()] = (*targetShape)[indices.index()]; } Value slicedOperand = rewriter.create( loc, transposeOp.getVector(), permutedOffsets, permutedShape, strides); Value transposedSlice = rewriter.create(loc, slicedOperand, permutation); result = rewriter.create( loc, transposedSlice, result, elementOffsets, strides); } rewriter.replaceOp(transposeOp, result); return success(); } private: vector::UnrollVectorOptions options; }; struct UnrollGatherPattern : public OpRewritePattern { UnrollGatherPattern(MLIRContext *context, const vector::UnrollVectorOptions &options, PatternBenefit benefit = 1) : OpRewritePattern(context, benefit), options(options) { } LogicalResult matchAndRewrite(vector::GatherOp gatherOp, PatternRewriter &rewriter) const override { VectorType sourceVectorType = gatherOp.getVectorType(); if (sourceVectorType.getRank() == 0) return failure(); auto targetShape = getTargetShape(options, gatherOp); if (!targetShape) return failure(); SmallVector strides(targetShape->size(), 1); Location loc = gatherOp.getLoc(); ArrayRef originalSize = gatherOp.getVectorType().getShape(); // Prepare the result vector; Value result = rewriter.create( loc, sourceVectorType, rewriter.getZeroAttr(sourceVectorType)); auto targetType = VectorType::get(*targetShape, sourceVectorType.getElementType()); SmallVector loopOrder = getUnrollOrder(originalSize.size(), gatherOp, options); for (SmallVector elementOffsets : StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) { // To get the unrolled gather, extract the same slice based on the // decomposed shape from each of the index, mask, and pass-through // vectors. Value indexSubVec = rewriter.create( loc, gatherOp.getIndexVec(), elementOffsets, *targetShape, strides); Value maskSubVec = rewriter.create( loc, gatherOp.getMask(), elementOffsets, *targetShape, strides); Value passThruSubVec = rewriter.create( loc, gatherOp.getPassThru(), elementOffsets, *targetShape, strides); auto slicedGather = rewriter.create( loc, targetType, gatherOp.getBase(), gatherOp.getIndices(), indexSubVec, maskSubVec, passThruSubVec); result = rewriter.create( loc, slicedGather, result, elementOffsets, strides); } rewriter.replaceOp(gatherOp, result); return success(); } private: vector::UnrollVectorOptions options; }; } // namespace void mlir::vector::populateVectorUnrollPatterns( RewritePatternSet &patterns, const UnrollVectorOptions &options, PatternBenefit benefit) { patterns.add( patterns.getContext(), options, benefit); }