154db8cc7SThomas Raoux //===- VectorUnrollDistribute.cpp - patterns to do vector unrolling -------===// 254db8cc7SThomas Raoux // 354db8cc7SThomas Raoux // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 454db8cc7SThomas Raoux // See https://llvm.org/LICENSE.txt for license information. 554db8cc7SThomas Raoux // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 654db8cc7SThomas Raoux // 754db8cc7SThomas Raoux //===----------------------------------------------------------------------===// 854db8cc7SThomas Raoux // 954db8cc7SThomas Raoux // This file implements patterns to do vector unrolling and vector distribution. 1054db8cc7SThomas Raoux // 1154db8cc7SThomas Raoux //===----------------------------------------------------------------------===// 1254db8cc7SThomas Raoux 1354db8cc7SThomas Raoux #include "mlir/Dialect/Affine/IR/AffineOps.h" 1454db8cc7SThomas Raoux #include "mlir/Dialect/Utils/IndexingUtils.h" 1554db8cc7SThomas Raoux #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h" 1654db8cc7SThomas Raoux #include "mlir/IR/ImplicitLocOpBuilder.h" 1754db8cc7SThomas Raoux #include "mlir/Interfaces/VectorInterfaces.h" 1854db8cc7SThomas Raoux #include "llvm/ADT/MapVector.h" 1954db8cc7SThomas Raoux #include "llvm/ADT/STLExtras.h" 20e35ff260SNicolas Vasilache #include "llvm/Support/Debug.h" 2154db8cc7SThomas Raoux #include <numeric> 22a1fe1f5fSKazu Hirata #include <optional> 2354db8cc7SThomas Raoux 24e35ff260SNicolas Vasilache #define DEBUG_TYPE "vector-unroll" 25e35ff260SNicolas Vasilache #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") 26e35ff260SNicolas Vasilache #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") 2754db8cc7SThomas Raoux 2854db8cc7SThomas Raoux using namespace mlir; 2954db8cc7SThomas Raoux using namespace mlir::vector; 3054db8cc7SThomas Raoux 3154db8cc7SThomas Raoux /// Compute the indices of the slice `index` for a tranfer op. 
3254db8cc7SThomas Raoux static SmallVector<Value> sliceTransferIndices(ArrayRef<int64_t> elementOffsets, 3354db8cc7SThomas Raoux ArrayRef<Value> indices, 3454db8cc7SThomas Raoux AffineMap permutationMap, 3554db8cc7SThomas Raoux Location loc, 3654db8cc7SThomas Raoux OpBuilder &builder) { 3754db8cc7SThomas Raoux MLIRContext *ctx = builder.getContext(); 3854db8cc7SThomas Raoux auto isBroadcast = [](AffineExpr expr) { 391609f1c2Slong.chen if (auto constExpr = dyn_cast<AffineConstantExpr>(expr)) 4054db8cc7SThomas Raoux return constExpr.getValue() == 0; 4154db8cc7SThomas Raoux return false; 4254db8cc7SThomas Raoux }; 4354db8cc7SThomas Raoux // Compute 'sliceIndices' by adding 'sliceOffsets[i]' to 'indices[i]'. 44*5262865aSKazu Hirata SmallVector<Value> slicedIndices(indices); 4554db8cc7SThomas Raoux for (const auto &dim : llvm::enumerate(permutationMap.getResults())) { 4654db8cc7SThomas Raoux if (isBroadcast(dim.value())) 4754db8cc7SThomas Raoux continue; 481609f1c2Slong.chen unsigned pos = cast<AffineDimExpr>(dim.value()).getPosition(); 4954db8cc7SThomas Raoux auto expr = getAffineDimExpr(0, builder.getContext()) + 5054db8cc7SThomas Raoux getAffineConstantExpr(elementOffsets[dim.index()], ctx); 5154db8cc7SThomas Raoux auto map = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, expr); 524c48f016SMatthias Springer slicedIndices[pos] = 534c48f016SMatthias Springer builder.create<affine::AffineApplyOp>(loc, map, indices[pos]); 5454db8cc7SThomas Raoux } 5554db8cc7SThomas Raoux return slicedIndices; 5654db8cc7SThomas Raoux } 5754db8cc7SThomas Raoux 5854db8cc7SThomas Raoux // Clones `op` into a new operations that takes `operands` and returns 5954db8cc7SThomas Raoux // `resultTypes`. 
6054db8cc7SThomas Raoux static Operation *cloneOpWithOperandsAndTypes(OpBuilder &builder, Location loc, 6154db8cc7SThomas Raoux Operation *op, 6254db8cc7SThomas Raoux ArrayRef<Value> operands, 6354db8cc7SThomas Raoux ArrayRef<Type> resultTypes) { 6454db8cc7SThomas Raoux return builder.create(loc, op->getName().getIdentifier(), operands, 6554db8cc7SThomas Raoux resultTypes, op->getAttrs()); 6654db8cc7SThomas Raoux } 6754db8cc7SThomas Raoux 6870c73d1bSKazu Hirata /// Return the target shape for unrolling for the given `op`. Return 6970c73d1bSKazu Hirata /// std::nullopt if the op shouldn't be or cannot be unrolled. 700a81ace0SKazu Hirata static std::optional<SmallVector<int64_t>> 7154db8cc7SThomas Raoux getTargetShape(const vector::UnrollVectorOptions &options, Operation *op) { 72e35ff260SNicolas Vasilache LDBG(""); 73e35ff260SNicolas Vasilache LDBG("Get unroll shape for op " << op->getName().getStringRef()); 74e35ff260SNicolas Vasilache if (options.filterConstraint && failed(options.filterConstraint(op))) { 75e35ff260SNicolas Vasilache LDBG("--no filter constraint -> BAIL"); 761a36588eSKazu Hirata return std::nullopt; 77e35ff260SNicolas Vasilache } 7854db8cc7SThomas Raoux assert(options.nativeShape && 7954db8cc7SThomas Raoux "vector unrolling expects the native shape or native" 8054db8cc7SThomas Raoux "shape call back function to be set"); 8154db8cc7SThomas Raoux auto unrollableVectorOp = dyn_cast<VectorUnrollOpInterface>(op); 82e35ff260SNicolas Vasilache if (!unrollableVectorOp) { 83e35ff260SNicolas Vasilache LDBG("--not an unrollable op -> BAIL"); 841a36588eSKazu Hirata return std::nullopt; 85e35ff260SNicolas Vasilache } 8654db8cc7SThomas Raoux auto maybeUnrollShape = unrollableVectorOp.getShapeForUnroll(); 87e35ff260SNicolas Vasilache if (!maybeUnrollShape) { 88e35ff260SNicolas Vasilache LDBG("--could not get shape of op " << *op << " -> BAIL"); 891a36588eSKazu Hirata return std::nullopt; 90e35ff260SNicolas Vasilache } 91e35ff260SNicolas Vasilache LLVM_DEBUG( 
92e35ff260SNicolas Vasilache llvm::interleaveComma(*maybeUnrollShape, DBGS() << "--vector op shape: "); 93e35ff260SNicolas Vasilache llvm::dbgs() << "\n";); 94e35ff260SNicolas Vasilache 950a81ace0SKazu Hirata std::optional<SmallVector<int64_t>> targetShape = options.nativeShape(op); 96e35ff260SNicolas Vasilache if (!targetShape) { 97e35ff260SNicolas Vasilache LDBG("--no unrolling target shape defined " << *op << "-> SKIP"); 981a36588eSKazu Hirata return std::nullopt; 99e35ff260SNicolas Vasilache } 100e35ff260SNicolas Vasilache LLVM_DEBUG(llvm::interleaveComma(*targetShape, DBGS() << "--target shape: "); 101e35ff260SNicolas Vasilache llvm::dbgs() << "\n";); 102e35ff260SNicolas Vasilache 1037a69a9d7SNicolas Vasilache auto maybeShapeRatio = computeShapeRatio(*maybeUnrollShape, *targetShape); 104e35ff260SNicolas Vasilache if (!maybeShapeRatio) { 105e35ff260SNicolas Vasilache LDBG("--could not compute integral shape ratio -> BAIL"); 1061a36588eSKazu Hirata return std::nullopt; 107e35ff260SNicolas Vasilache } 108e35ff260SNicolas Vasilache if (llvm::all_of(*maybeShapeRatio, [](int64_t v) { return v == 1; })) { 109e35ff260SNicolas Vasilache LDBG("--no unrolling needed -> SKIP"); 110e35ff260SNicolas Vasilache return std::nullopt; 111e35ff260SNicolas Vasilache } 112e35ff260SNicolas Vasilache LDBG("--found an integral shape ratio to unroll to -> SUCCESS"); 11354db8cc7SThomas Raoux return targetShape; 11454db8cc7SThomas Raoux } 11554db8cc7SThomas Raoux 11654db8cc7SThomas Raoux static SmallVector<int64_t> 11754db8cc7SThomas Raoux getUnrollOrder(unsigned numLoops, Operation *op, 11854db8cc7SThomas Raoux const vector::UnrollVectorOptions &options) { 11954db8cc7SThomas Raoux SmallVector<int64_t> loopOrder = 12054db8cc7SThomas Raoux llvm::to_vector(llvm::seq<int64_t>(0, static_cast<int64_t>(numLoops))); 12154db8cc7SThomas Raoux if (options.traversalOrderCallback != nullptr) { 1220a81ace0SKazu Hirata std::optional<SmallVector<int64_t>> order = 1230a81ace0SKazu Hirata 
options.traversalOrderCallback(op); 12454db8cc7SThomas Raoux if (order) { 12554db8cc7SThomas Raoux loopOrder = std::move(*order); 12654db8cc7SThomas Raoux } 12754db8cc7SThomas Raoux } 12854db8cc7SThomas Raoux return loopOrder; 12954db8cc7SThomas Raoux } 13054db8cc7SThomas Raoux 13154db8cc7SThomas Raoux namespace { 13254db8cc7SThomas Raoux 13354db8cc7SThomas Raoux struct UnrollTransferReadPattern 13454db8cc7SThomas Raoux : public OpRewritePattern<vector::TransferReadOp> { 13554db8cc7SThomas Raoux UnrollTransferReadPattern(MLIRContext *context, 13654db8cc7SThomas Raoux const vector::UnrollVectorOptions &options, 13754db8cc7SThomas Raoux PatternBenefit benefit = 1) 13854db8cc7SThomas Raoux : OpRewritePattern<vector::TransferReadOp>(context, benefit), 13954db8cc7SThomas Raoux options(options) {} 14054db8cc7SThomas Raoux 14154db8cc7SThomas Raoux LogicalResult matchAndRewrite(vector::TransferReadOp readOp, 14254db8cc7SThomas Raoux PatternRewriter &rewriter) const override { 14354db8cc7SThomas Raoux // TODO: support 0-d corner case. 
14454db8cc7SThomas Raoux if (readOp.getTransferRank() == 0) 14554db8cc7SThomas Raoux return failure(); 14654db8cc7SThomas Raoux if (readOp.getMask()) 14754db8cc7SThomas Raoux return failure(); 14854db8cc7SThomas Raoux auto targetShape = getTargetShape(options, readOp); 14954db8cc7SThomas Raoux if (!targetShape) 15054db8cc7SThomas Raoux return failure(); 15154db8cc7SThomas Raoux auto sourceVectorType = readOp.getVectorType(); 1527a69a9d7SNicolas Vasilache SmallVector<int64_t> strides(targetShape->size(), 1); 15354db8cc7SThomas Raoux Location loc = readOp.getLoc(); 15454db8cc7SThomas Raoux ArrayRef<int64_t> originalSize = readOp.getVectorType().getShape(); 15554db8cc7SThomas Raoux 15654db8cc7SThomas Raoux // Prepare the result vector; 15754db8cc7SThomas Raoux Value result = rewriter.create<arith::ConstantOp>( 15854db8cc7SThomas Raoux loc, sourceVectorType, rewriter.getZeroAttr(sourceVectorType)); 15954db8cc7SThomas Raoux auto targetType = 16054db8cc7SThomas Raoux VectorType::get(*targetShape, sourceVectorType.getElementType()); 1617a69a9d7SNicolas Vasilache SmallVector<Value> originalIndices(readOp.getIndices().begin(), 16254db8cc7SThomas Raoux readOp.getIndices().end()); 16354db8cc7SThomas Raoux SmallVector<int64_t> loopOrder = 16454db8cc7SThomas Raoux getUnrollOrder(originalSize.size(), readOp, options); 165831041beSChristopher Bate for (SmallVector<int64_t> elementOffsets : 166831041beSChristopher Bate StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) { 1677a69a9d7SNicolas Vasilache SmallVector<Value> indices = 16854db8cc7SThomas Raoux sliceTransferIndices(elementOffsets, originalIndices, 16954db8cc7SThomas Raoux readOp.getPermutationMap(), loc, rewriter); 17054db8cc7SThomas Raoux auto slicedRead = rewriter.create<vector::TransferReadOp>( 17154db8cc7SThomas Raoux loc, targetType, readOp.getSource(), indices, 17254db8cc7SThomas Raoux readOp.getPermutationMapAttr(), readOp.getPadding(), readOp.getMask(), 17354db8cc7SThomas Raoux 
readOp.getInBoundsAttr()); 17454db8cc7SThomas Raoux 17554db8cc7SThomas Raoux result = rewriter.create<vector::InsertStridedSliceOp>( 17654db8cc7SThomas Raoux loc, slicedRead, result, elementOffsets, strides); 17754db8cc7SThomas Raoux } 17854db8cc7SThomas Raoux rewriter.replaceOp(readOp, result); 17954db8cc7SThomas Raoux return success(); 18054db8cc7SThomas Raoux } 18154db8cc7SThomas Raoux 18254db8cc7SThomas Raoux private: 18354db8cc7SThomas Raoux vector::UnrollVectorOptions options; 18454db8cc7SThomas Raoux }; 18554db8cc7SThomas Raoux 18654db8cc7SThomas Raoux struct UnrollTransferWritePattern 18754db8cc7SThomas Raoux : public OpRewritePattern<vector::TransferWriteOp> { 18854db8cc7SThomas Raoux UnrollTransferWritePattern(MLIRContext *context, 18954db8cc7SThomas Raoux const vector::UnrollVectorOptions &options, 19054db8cc7SThomas Raoux PatternBenefit benefit = 1) 19154db8cc7SThomas Raoux : OpRewritePattern<vector::TransferWriteOp>(context, benefit), 19254db8cc7SThomas Raoux options(options) {} 19354db8cc7SThomas Raoux 19454db8cc7SThomas Raoux LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp, 19554db8cc7SThomas Raoux PatternRewriter &rewriter) const override { 19654db8cc7SThomas Raoux // TODO: support 0-d corner case. 
19754db8cc7SThomas Raoux if (writeOp.getTransferRank() == 0) 19854db8cc7SThomas Raoux return failure(); 19954db8cc7SThomas Raoux 20054db8cc7SThomas Raoux if (writeOp.getMask()) 20154db8cc7SThomas Raoux return failure(); 20254db8cc7SThomas Raoux auto targetShape = getTargetShape(options, writeOp); 20354db8cc7SThomas Raoux if (!targetShape) 20454db8cc7SThomas Raoux return failure(); 20554db8cc7SThomas Raoux auto sourceVectorType = writeOp.getVectorType(); 2067a69a9d7SNicolas Vasilache SmallVector<int64_t> strides(targetShape->size(), 1); 20754db8cc7SThomas Raoux Location loc = writeOp.getLoc(); 20854db8cc7SThomas Raoux ArrayRef<int64_t> originalSize = sourceVectorType.getShape(); 2097a69a9d7SNicolas Vasilache SmallVector<Value> originalIndices(writeOp.getIndices().begin(), 21054db8cc7SThomas Raoux writeOp.getIndices().end()); 21154db8cc7SThomas Raoux SmallVector<int64_t> loopOrder = 21254db8cc7SThomas Raoux getUnrollOrder(originalSize.size(), writeOp, options); 21354db8cc7SThomas Raoux Value resultTensor; 214831041beSChristopher Bate for (SmallVector<int64_t> elementOffsets : 215831041beSChristopher Bate StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) { 21654db8cc7SThomas Raoux Value slicedVector = rewriter.create<vector::ExtractStridedSliceOp>( 21754db8cc7SThomas Raoux loc, writeOp.getVector(), elementOffsets, *targetShape, strides); 2187a69a9d7SNicolas Vasilache SmallVector<Value> indices = 21954db8cc7SThomas Raoux sliceTransferIndices(elementOffsets, originalIndices, 22054db8cc7SThomas Raoux writeOp.getPermutationMap(), loc, rewriter); 22154db8cc7SThomas Raoux Operation *slicedWrite = rewriter.create<vector::TransferWriteOp>( 22254db8cc7SThomas Raoux loc, slicedVector, resultTensor ? resultTensor : writeOp.getSource(), 22354db8cc7SThomas Raoux indices, writeOp.getPermutationMapAttr(), writeOp.getInBoundsAttr()); 22454db8cc7SThomas Raoux // For the tensor case update the destination for the next transfer write. 
22554db8cc7SThomas Raoux if (!slicedWrite->getResults().empty()) 22654db8cc7SThomas Raoux resultTensor = slicedWrite->getResult(0); 22754db8cc7SThomas Raoux } 22854db8cc7SThomas Raoux if (resultTensor) 22954db8cc7SThomas Raoux rewriter.replaceOp(writeOp, resultTensor); 23054db8cc7SThomas Raoux else 23154db8cc7SThomas Raoux rewriter.eraseOp(writeOp); 23254db8cc7SThomas Raoux return success(); 23354db8cc7SThomas Raoux } 23454db8cc7SThomas Raoux 23554db8cc7SThomas Raoux private: 23654db8cc7SThomas Raoux vector::UnrollVectorOptions options; 23754db8cc7SThomas Raoux }; 23854db8cc7SThomas Raoux 23954db8cc7SThomas Raoux struct OffsetMapInfo { 24054db8cc7SThomas Raoux static SmallVector<int64_t> getEmptyKey() { return {int64_t(-1)}; } 24154db8cc7SThomas Raoux 24254db8cc7SThomas Raoux static SmallVector<int64_t> getTombstoneKey() { return {int64_t(-2)}; } 24354db8cc7SThomas Raoux 24454db8cc7SThomas Raoux static unsigned getHashValue(const SmallVector<int64_t> &v) { 24554db8cc7SThomas Raoux return static_cast<unsigned>(llvm::hash_combine_range(v.begin(), v.end())); 24654db8cc7SThomas Raoux } 24754db8cc7SThomas Raoux 24854db8cc7SThomas Raoux static bool isEqual(const SmallVector<int64_t> &lhs, 24954db8cc7SThomas Raoux const SmallVector<int64_t> &rhs) { 25054db8cc7SThomas Raoux return lhs == rhs; 25154db8cc7SThomas Raoux } 25254db8cc7SThomas Raoux }; 25354db8cc7SThomas Raoux 25454db8cc7SThomas Raoux struct UnrollContractionPattern 25554db8cc7SThomas Raoux : public OpRewritePattern<vector::ContractionOp> { 25654db8cc7SThomas Raoux UnrollContractionPattern(MLIRContext *context, 25754db8cc7SThomas Raoux const vector::UnrollVectorOptions &options, 25854db8cc7SThomas Raoux PatternBenefit benefit = 1) 25954db8cc7SThomas Raoux : OpRewritePattern<vector::ContractionOp>(context, benefit), 26054db8cc7SThomas Raoux options(options) {} 26154db8cc7SThomas Raoux 26254db8cc7SThomas Raoux LogicalResult matchAndRewrite(vector::ContractionOp contractOp, 26354db8cc7SThomas Raoux PatternRewriter 
&rewriter) const override { 26454db8cc7SThomas Raoux auto targetShape = getTargetShape(options, contractOp); 26554db8cc7SThomas Raoux if (!targetShape) 26654db8cc7SThomas Raoux return failure(); 2675550c821STres Popp auto dstVecType = cast<VectorType>(contractOp.getResultType()); 2687a69a9d7SNicolas Vasilache SmallVector<int64_t> originalSize = *contractOp.getShapeForUnroll(); 26954db8cc7SThomas Raoux 27054db8cc7SThomas Raoux Location loc = contractOp.getLoc(); 27154db8cc7SThomas Raoux unsigned accIndex = vector::ContractionOp::getAccOperandIndex(); 27254db8cc7SThomas Raoux AffineMap dstAffineMap = contractOp.getIndexingMapsArray()[accIndex]; 27354db8cc7SThomas Raoux llvm::MapVector< 27454db8cc7SThomas Raoux SmallVector<int64_t>, Value, 27554db8cc7SThomas Raoux llvm::DenseMap<SmallVector<int64_t>, unsigned, OffsetMapInfo>> 27654db8cc7SThomas Raoux accCache; 27754db8cc7SThomas Raoux 27854db8cc7SThomas Raoux SmallVector<int64_t> loopOrder = getUnrollOrder( 27954db8cc7SThomas Raoux contractOp.getIteratorTypes().size(), contractOp, options); 280831041beSChristopher Bate 281831041beSChristopher Bate for (SmallVector<int64_t> offsets : 282831041beSChristopher Bate StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) { 2837a69a9d7SNicolas Vasilache SmallVector<Value> slicesOperands(contractOp.getNumOperands()); 28454db8cc7SThomas Raoux 2857a69a9d7SNicolas Vasilache // Helper to compute the new shape of each operand and extract the slice. 
28654db8cc7SThomas Raoux auto extractOperand = [&](unsigned index, Value operand, 28754db8cc7SThomas Raoux AffineMap permutationMap, 28854db8cc7SThomas Raoux ArrayRef<int64_t> operandOffets) { 28954db8cc7SThomas Raoux SmallVector<int64_t> operandShape = applyPermutationMap( 29054db8cc7SThomas Raoux permutationMap, ArrayRef<int64_t>(*targetShape)); 2917a69a9d7SNicolas Vasilache SmallVector<int64_t> operandStrides(operandOffets.size(), 1); 29254db8cc7SThomas Raoux slicesOperands[index] = rewriter.create<vector::ExtractStridedSliceOp>( 29354db8cc7SThomas Raoux loc, operand, operandOffets, operandShape, operandStrides); 29454db8cc7SThomas Raoux }; 29554db8cc7SThomas Raoux 29654db8cc7SThomas Raoux // Extract the new lhs operand. 29754db8cc7SThomas Raoux AffineMap lhsPermutationMap = contractOp.getIndexingMapsArray()[0]; 29854db8cc7SThomas Raoux SmallVector<int64_t> lhsOffets = 29954db8cc7SThomas Raoux applyPermutationMap(lhsPermutationMap, ArrayRef<int64_t>(offsets)); 30054db8cc7SThomas Raoux extractOperand(0, contractOp.getLhs(), lhsPermutationMap, lhsOffets); 30154db8cc7SThomas Raoux 30254db8cc7SThomas Raoux // Extract the new rhs operand. 30354db8cc7SThomas Raoux AffineMap rhsPermutationMap = contractOp.getIndexingMapsArray()[1]; 30454db8cc7SThomas Raoux SmallVector<int64_t> rhsOffets = 30554db8cc7SThomas Raoux applyPermutationMap(rhsPermutationMap, ArrayRef<int64_t>(offsets)); 30654db8cc7SThomas Raoux extractOperand(1, contractOp.getRhs(), rhsPermutationMap, rhsOffets); 30754db8cc7SThomas Raoux 30854db8cc7SThomas Raoux AffineMap accPermutationMap = contractOp.getIndexingMapsArray()[2]; 30954db8cc7SThomas Raoux SmallVector<int64_t> accOffets = 31054db8cc7SThomas Raoux applyPermutationMap(accPermutationMap, ArrayRef<int64_t>(offsets)); 31154db8cc7SThomas Raoux // If a version of the accumulator has already been computed, use it 31254db8cc7SThomas Raoux // otherwise extract the first version from the original operand. 
31389dc313aSMehdi Amini auto *accIt = accCache.find(accOffets); 31454db8cc7SThomas Raoux if (accIt != accCache.end()) 31554db8cc7SThomas Raoux slicesOperands[2] = accIt->second; 31654db8cc7SThomas Raoux else 31754db8cc7SThomas Raoux extractOperand(2, contractOp.getAcc(), accPermutationMap, accOffets); 31854db8cc7SThomas Raoux 31954db8cc7SThomas Raoux SmallVector<int64_t> dstShape = 32054db8cc7SThomas Raoux applyPermutationMap(dstAffineMap, ArrayRef<int64_t>(*targetShape)); 32154db8cc7SThomas Raoux auto targetType = VectorType::get(dstShape, dstVecType.getElementType()); 32254db8cc7SThomas Raoux Operation *newOp = cloneOpWithOperandsAndTypes( 32354db8cc7SThomas Raoux rewriter, loc, contractOp, slicesOperands, targetType); 32454db8cc7SThomas Raoux 32554db8cc7SThomas Raoux SmallVector<int64_t> dstOffets = 32654db8cc7SThomas Raoux applyPermutationMap(dstAffineMap, ArrayRef<int64_t>(offsets)); 32754db8cc7SThomas Raoux // Save the accumulated value untill all the loops are unrolled since 32854db8cc7SThomas Raoux // reduction loop keep updating the accumulator. 32954db8cc7SThomas Raoux accCache[dstOffets] = newOp->getResult(0); 33054db8cc7SThomas Raoux } 33154db8cc7SThomas Raoux // Assemble back the accumulator into a single vector. 
33254db8cc7SThomas Raoux Value result = rewriter.create<arith::ConstantOp>( 33354db8cc7SThomas Raoux loc, dstVecType, rewriter.getZeroAttr(dstVecType)); 33454db8cc7SThomas Raoux for (const auto &it : accCache) { 33554db8cc7SThomas Raoux SmallVector<int64_t> dstStrides(it.first.size(), 1); 33654db8cc7SThomas Raoux result = rewriter.create<vector::InsertStridedSliceOp>( 33754db8cc7SThomas Raoux loc, it.second, result, it.first, dstStrides); 33854db8cc7SThomas Raoux } 33954db8cc7SThomas Raoux rewriter.replaceOp(contractOp, result); 34054db8cc7SThomas Raoux return success(); 34154db8cc7SThomas Raoux } 34254db8cc7SThomas Raoux 34354db8cc7SThomas Raoux private: 34454db8cc7SThomas Raoux vector::UnrollVectorOptions options; 34554db8cc7SThomas Raoux }; 34654db8cc7SThomas Raoux 34754db8cc7SThomas Raoux struct UnrollMultiReductionPattern 34854db8cc7SThomas Raoux : public OpRewritePattern<vector::MultiDimReductionOp> { 34954db8cc7SThomas Raoux UnrollMultiReductionPattern(MLIRContext *context, 35054db8cc7SThomas Raoux const vector::UnrollVectorOptions &options, 35154db8cc7SThomas Raoux PatternBenefit benefit = 1) 35254db8cc7SThomas Raoux : OpRewritePattern<vector::MultiDimReductionOp>(context, benefit), 35354db8cc7SThomas Raoux options(options) {} 35454db8cc7SThomas Raoux 35554db8cc7SThomas Raoux LogicalResult matchAndRewrite(vector::MultiDimReductionOp reductionOp, 35654db8cc7SThomas Raoux PatternRewriter &rewriter) const override { 3570a81ace0SKazu Hirata std::optional<SmallVector<int64_t>> targetShape = 35854db8cc7SThomas Raoux getTargetShape(options, reductionOp); 35954db8cc7SThomas Raoux if (!targetShape) 36054db8cc7SThomas Raoux return failure(); 3617a69a9d7SNicolas Vasilache SmallVector<int64_t> originalSize = *reductionOp.getShapeForUnroll(); 36254db8cc7SThomas Raoux llvm::MapVector< 36354db8cc7SThomas Raoux SmallVector<int64_t>, Value, 36454db8cc7SThomas Raoux llvm::DenseMap<SmallVector<int64_t>, unsigned, OffsetMapInfo>> 36554db8cc7SThomas Raoux accCache; 
36654db8cc7SThomas Raoux Location loc = reductionOp.getLoc(); 3677a69a9d7SNicolas Vasilache 3687a69a9d7SNicolas Vasilache // Stride of the ratios, this gives us the offsets of sliceCount in a basis 3697a69a9d7SNicolas Vasilache // of multiples of the targetShape. 370831041beSChristopher Bate for (SmallVector<int64_t> offsets : 371831041beSChristopher Bate StaticTileOffsetRange(originalSize, *targetShape)) { 37254db8cc7SThomas Raoux SmallVector<Value> operands; 3737a69a9d7SNicolas Vasilache SmallVector<int64_t> operandStrides(offsets.size(), 1); 37454db8cc7SThomas Raoux Value slicedOperand = rewriter.create<vector::ExtractStridedSliceOp>( 37554db8cc7SThomas Raoux loc, reductionOp.getSource(), offsets, *targetShape, operandStrides); 37654db8cc7SThomas Raoux operands.push_back(slicedOperand); 37754db8cc7SThomas Raoux SmallVector<int64_t> dstShape; 37854db8cc7SThomas Raoux SmallVector<int64_t> destOffset; 37954db8cc7SThomas Raoux for (size_t i : llvm::seq(size_t(0), targetShape->size())) { 38054db8cc7SThomas Raoux if (!reductionOp.isReducedDim(i)) { 38154db8cc7SThomas Raoux destOffset.push_back(offsets[i]); 38254db8cc7SThomas Raoux dstShape.push_back((*targetShape)[i]); 38354db8cc7SThomas Raoux } 38454db8cc7SThomas Raoux } 38554db8cc7SThomas Raoux Value acc; 3867a69a9d7SNicolas Vasilache SmallVector<int64_t> accStrides(destOffset.size(), 1); 38754db8cc7SThomas Raoux // If a version of the accumulator has already been computed, use it 38854db8cc7SThomas Raoux // otherwise extract the first version from the original operand. 
38989dc313aSMehdi Amini auto *accIt = accCache.find(destOffset); 39054db8cc7SThomas Raoux if (accIt != accCache.end()) 39154db8cc7SThomas Raoux acc = accIt->second; 39254db8cc7SThomas Raoux else 39354db8cc7SThomas Raoux acc = rewriter.create<vector::ExtractStridedSliceOp>( 39454db8cc7SThomas Raoux loc, reductionOp.getAcc(), destOffset, dstShape, accStrides); 39554db8cc7SThomas Raoux operands.push_back(acc); 39654db8cc7SThomas Raoux auto targetType = VectorType::get( 39754db8cc7SThomas Raoux dstShape, reductionOp.getSourceVectorType().getElementType()); 39854db8cc7SThomas Raoux Operation *newOp = cloneOpWithOperandsAndTypes(rewriter, loc, reductionOp, 39954db8cc7SThomas Raoux operands, targetType); 40054db8cc7SThomas Raoux Value result = newOp->getResult(0); 40154db8cc7SThomas Raoux accCache[destOffset] = result; 40254db8cc7SThomas Raoux } 40354db8cc7SThomas Raoux // Assemble back the accumulator into a single vector. 40454db8cc7SThomas Raoux Value result = rewriter.create<arith::ConstantOp>( 40554db8cc7SThomas Raoux loc, reductionOp.getDestType(), 40654db8cc7SThomas Raoux rewriter.getZeroAttr(reductionOp.getDestType())); 40754db8cc7SThomas Raoux for (const auto &it : accCache) { 40854db8cc7SThomas Raoux SmallVector<int64_t> dstStrides(it.first.size(), 1); 40954db8cc7SThomas Raoux result = rewriter.create<vector::InsertStridedSliceOp>( 41054db8cc7SThomas Raoux loc, it.second, result, it.first, dstStrides); 41154db8cc7SThomas Raoux } 41254db8cc7SThomas Raoux rewriter.replaceOp(reductionOp, result); 41354db8cc7SThomas Raoux return success(); 41454db8cc7SThomas Raoux } 41554db8cc7SThomas Raoux 41654db8cc7SThomas Raoux private: 41754db8cc7SThomas Raoux vector::UnrollVectorOptions options; 41854db8cc7SThomas Raoux }; 41954db8cc7SThomas Raoux 42054db8cc7SThomas Raoux struct UnrollElementwisePattern : public RewritePattern { 42154db8cc7SThomas Raoux UnrollElementwisePattern(MLIRContext *context, 42254db8cc7SThomas Raoux const vector::UnrollVectorOptions &options, 
42354db8cc7SThomas Raoux PatternBenefit benefit = 1) 42454db8cc7SThomas Raoux : RewritePattern(MatchAnyOpTypeTag(), benefit, context), 42554db8cc7SThomas Raoux options(options) {} 42654db8cc7SThomas Raoux 42754db8cc7SThomas Raoux LogicalResult matchAndRewrite(Operation *op, 42854db8cc7SThomas Raoux PatternRewriter &rewriter) const override { 42954db8cc7SThomas Raoux if (!OpTrait::hasElementwiseMappableTraits(op) || op->getNumResults() != 1) 43054db8cc7SThomas Raoux return failure(); 43154db8cc7SThomas Raoux auto targetShape = getTargetShape(options, op); 43254db8cc7SThomas Raoux if (!targetShape) 43354db8cc7SThomas Raoux return failure(); 4345550c821STres Popp auto dstVecType = cast<VectorType>(op->getResult(0).getType()); 4357a69a9d7SNicolas Vasilache SmallVector<int64_t> originalSize = 43654db8cc7SThomas Raoux *cast<VectorUnrollOpInterface>(op).getShapeForUnroll(); 43754db8cc7SThomas Raoux Location loc = op->getLoc(); 43854db8cc7SThomas Raoux // Prepare the result vector. 43954db8cc7SThomas Raoux Value result = rewriter.create<arith::ConstantOp>( 44054db8cc7SThomas Raoux loc, dstVecType, rewriter.getZeroAttr(dstVecType)); 4417a69a9d7SNicolas Vasilache SmallVector<int64_t> strides(targetShape->size(), 1); 44254db8cc7SThomas Raoux VectorType newVecType = 44354db8cc7SThomas Raoux VectorType::get(*targetShape, dstVecType.getElementType()); 4447a69a9d7SNicolas Vasilache 445831041beSChristopher Bate // Create the unrolled computation. 
446831041beSChristopher Bate for (SmallVector<int64_t> offsets : 447831041beSChristopher Bate StaticTileOffsetRange(originalSize, *targetShape)) { 4487a69a9d7SNicolas Vasilache SmallVector<Value> extractOperands; 44954db8cc7SThomas Raoux for (OpOperand &operand : op->getOpOperands()) { 4505550c821STres Popp auto vecType = dyn_cast<VectorType>(operand.get().getType()); 45154db8cc7SThomas Raoux if (!vecType) { 45254db8cc7SThomas Raoux extractOperands.push_back(operand.get()); 45354db8cc7SThomas Raoux continue; 45454db8cc7SThomas Raoux } 45554db8cc7SThomas Raoux extractOperands.push_back( 45654db8cc7SThomas Raoux rewriter.create<vector::ExtractStridedSliceOp>( 45754db8cc7SThomas Raoux loc, operand.get(), offsets, *targetShape, strides)); 45854db8cc7SThomas Raoux } 45954db8cc7SThomas Raoux Operation *newOp = cloneOpWithOperandsAndTypes( 46054db8cc7SThomas Raoux rewriter, loc, op, extractOperands, newVecType); 46154db8cc7SThomas Raoux result = rewriter.create<vector::InsertStridedSliceOp>( 46254db8cc7SThomas Raoux loc, newOp->getResult(0), result, offsets, strides); 46354db8cc7SThomas Raoux } 46454db8cc7SThomas Raoux rewriter.replaceOp(op, result); 46554db8cc7SThomas Raoux return success(); 46654db8cc7SThomas Raoux } 46754db8cc7SThomas Raoux 46854db8cc7SThomas Raoux private: 46954db8cc7SThomas Raoux vector::UnrollVectorOptions options; 47054db8cc7SThomas Raoux }; 47154db8cc7SThomas Raoux

/// Rewrites a `vector::ReductionOp` as a sequence of reductions over slices
/// of its vector operand. One slice of shape `targetShape` is extracted per
/// tile offset, a clone of the reduction is applied to it, and the partial
/// results are folded together with `makeArithReduction` using the kind of
/// the original op. Fails to match when `getTargetShape` yields no shape.
struct UnrollReductionPattern : public OpRewritePattern<vector::ReductionOp> {
  UnrollReductionPattern(MLIRContext *context,
                         const vector::UnrollVectorOptions &options,
                         PatternBenefit benefit = 1)
      : OpRewritePattern<vector::ReductionOp>(context, benefit),
        options(options) {}

  LogicalResult matchAndRewrite(vector::ReductionOp reductionOp,
                                PatternRewriter &rewriter) const override {
    std::optional<SmallVector<int64_t>> targetShape =
        getTargetShape(options, reductionOp);
    if (!targetShape)
      return failure();
    SmallVector<int64_t> originalSize = *reductionOp.getShapeForUnroll();

    // Create unrolled vector reduction.
    Location loc = reductionOp.getLoc();
    // Stays null until the first slice has been reduced.
    Value accumulator = nullptr;
    for (SmallVector<int64_t> offsets :
         StaticTileOffsetRange(originalSize, *targetShape)) {
      SmallVector<int64_t> strides(offsets.size(), 1);
      Value slicedOperand = rewriter.create<vector::ExtractStridedSliceOp>(
          loc, reductionOp.getVector(), offsets, *targetShape, strides);
      // NOTE(review): only the sliced vector is forwarded as an operand of the
      // cloned op; if the reduction can carry further operands they are not
      // passed along here -- confirm this is intended.
      Operation *newOp = cloneOpWithOperandsAndTypes(
          rewriter, loc, reductionOp, slicedOperand, reductionOp.getType());
      Value result = newOp->getResult(0);

      if (!accumulator) {
        // This is the first reduction.
        accumulator = result;
      } else {
        // On subsequent reduction, combine with the accumulator.
        accumulator = makeArithReduction(rewriter, loc, reductionOp.getKind(),
                                         accumulator, result);
      }
    }

    rewriter.replaceOp(reductionOp, accumulator);
    return success();
  }

private:
  const vector::UnrollVectorOptions options;
};

/// Rewrites a `vector::TransposeOp` tile by tile. For every tile offset of the
/// result, the corresponding source slice is extracted, transposed with the
/// original permutation, and inserted into a zero-initialized vector of the
/// result type. Fails to match for rank-0 results or when `getTargetShape`
/// yields no shape.
struct UnrollTransposePattern : public OpRewritePattern<vector::TransposeOp> {
  UnrollTransposePattern(MLIRContext *context,
                         const vector::UnrollVectorOptions &options,
                         PatternBenefit benefit = 1)
      : OpRewritePattern<vector::TransposeOp>(context, benefit),
        options(options) {}

  LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
                                PatternRewriter &rewriter) const override {
    if (transposeOp.getResultVectorType().getRank() == 0)
      return failure();
    auto targetShape = getTargetShape(options, transposeOp);
    if (!targetShape)
      return failure();
    auto originalVectorType = transposeOp.getResultVectorType();
    SmallVector<int64_t> strides(targetShape->size(), 1);
    Location loc = transposeOp.getLoc();
    ArrayRef<int64_t> originalSize = originalVectorType.getShape();

    // Prepare the result vector.
    Value result = rewriter.create<arith::ConstantOp>(
        loc, originalVectorType, rewriter.getZeroAttr(originalVectorType));
    ArrayRef<int64_t> permutation = transposeOp.getPermutation();

    // Unroll the computation.
    for (SmallVector<int64_t> elementOffsets :
         StaticTileOffsetRange(originalSize, *targetShape)) {
      SmallVector<int64_t> permutedOffsets(elementOffsets.size());
      SmallVector<int64_t> permutedShape(elementOffsets.size());
      // Compute the source offsets and shape: entry `i` of the result-side
      // offsets/shape is stored at position `permutation[i]` on the source
      // side.
      for (auto indices : llvm::enumerate(permutation)) {
        permutedOffsets[indices.value()] = elementOffsets[indices.index()];
        permutedShape[indices.value()] = (*targetShape)[indices.index()];
      }
      Value slicedOperand = rewriter.create<vector::ExtractStridedSliceOp>(
          loc, transposeOp.getVector(), permutedOffsets, permutedShape,
          strides);
      Value transposedSlice =
          rewriter.create<vector::TransposeOp>(loc, slicedOperand, permutation);
      result = rewriter.create<vector::InsertStridedSliceOp>(
          loc, transposedSlice, result, elementOffsets, strides);
    }
    rewriter.replaceOp(transposeOp, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};

/// Rewrites a `vector::GatherOp` as gathers over slices of shape
/// `targetShape`. The index, mask and pass-through vectors are sliced at the
/// same offsets, a smaller gather is created from the original base and
/// indices, and its result is inserted into a zero-initialized vector of the
/// original vector type. Tiles are visited in the order returned by
/// `getUnrollOrder`. Fails to match for rank-0 vectors or when
/// `getTargetShape` yields no shape.
struct UnrollGatherPattern : public OpRewritePattern<vector::GatherOp> {
  UnrollGatherPattern(MLIRContext *context,
                      const vector::UnrollVectorOptions &options,
                      PatternBenefit benefit = 1)
      : OpRewritePattern<vector::GatherOp>(context, benefit), options(options) {
  }

  LogicalResult matchAndRewrite(vector::GatherOp gatherOp,
                                PatternRewriter &rewriter) const override {
    VectorType sourceVectorType = gatherOp.getVectorType();
    if (sourceVectorType.getRank() == 0)
      return failure();
    auto targetShape = getTargetShape(options, gatherOp);
    if (!targetShape)
      return failure();
    SmallVector<int64_t> strides(targetShape->size(), 1);
    Location loc = gatherOp.getLoc();
    ArrayRef<int64_t> originalSize = gatherOp.getVectorType().getShape();

    // Prepare the result vector.
    Value result = rewriter.create<arith::ConstantOp>(
        loc, sourceVectorType, rewriter.getZeroAttr(sourceVectorType));
    auto targetType =
        VectorType::get(*targetShape, sourceVectorType.getElementType());

    SmallVector<int64_t> loopOrder =
        getUnrollOrder(originalSize.size(), gatherOp, options);
    for (SmallVector<int64_t> elementOffsets :
         StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) {
      // To get the unrolled gather, extract the same slice based on the
      // decomposed shape from each of the index, mask, and pass-through
      // vectors.
      Value indexSubVec = rewriter.create<vector::ExtractStridedSliceOp>(
          loc, gatherOp.getIndexVec(), elementOffsets, *targetShape, strides);
      Value maskSubVec = rewriter.create<vector::ExtractStridedSliceOp>(
          loc, gatherOp.getMask(), elementOffsets, *targetShape, strides);
      Value passThruSubVec = rewriter.create<vector::ExtractStridedSliceOp>(
          loc, gatherOp.getPassThru(), elementOffsets, *targetShape, strides);
      auto slicedGather = rewriter.create<vector::GatherOp>(
          loc, targetType, gatherOp.getBase(), gatherOp.getIndices(),
          indexSubVec, maskSubVec, passThruSubVec);

      result = rewriter.create<vector::InsertStridedSliceOp>(
          loc, slicedGather, result, elementOffsets, strides);
    }
    rewriter.replaceOp(gatherOp, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};

} // namespace

/// Adds the vector unrolling patterns listed below to `patterns`. Each pattern
/// is constructed with the context of `patterns`, `options`, and `benefit`.
void mlir::vector::populateVectorUnrollPatterns(
    RewritePatternSet &patterns, const UnrollVectorOptions &options,
    PatternBenefit benefit) {
  patterns.add<UnrollTransferReadPattern, UnrollTransferWritePattern,
               UnrollContractionPattern, UnrollElementwisePattern,
               UnrollReductionPattern, UnrollMultiReductionPattern,
               UnrollTransposePattern, UnrollGatherPattern>(
      patterns.getContext(), options, benefit);
}