1 //===- LowerVectorInterleave.cpp - Lower 'vector.interleave' operation ----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements target-independent rewrites and utilities to lower the 10 // 'vector.interleave' operation. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/Dialect/Vector/IR/VectorOps.h" 15 #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" 16 #include "mlir/Dialect/Vector/Utils/VectorUtils.h" 17 #include "mlir/IR/BuiltinTypes.h" 18 #include "mlir/IR/PatternMatch.h" 19 20 #define DEBUG_TYPE "vector-interleave-lowering" 21 22 using namespace mlir; 23 using namespace mlir::vector; 24 25 namespace { 26 27 /// A one-shot unrolling of vector.interleave to the `targetRank`. 28 /// 29 /// Example: 30 /// 31 /// ```mlir 32 /// vector.interleave %a, %b : vector<1x2x3x4xi64> 33 /// ``` 34 /// Would be unrolled to: 35 /// ```mlir 36 /// %result = arith.constant dense<0> : vector<1x2x3x8xi64> 37 /// %0 = vector.extract %a[0, 0, 0] ─┐ 38 /// : vector<4xi64> from vector<1x2x3x4xi64> | 39 /// %1 = vector.extract %b[0, 0, 0] | 40 /// : vector<4xi64> from vector<1x2x3x4xi64> | - Repeated 6x for 41 /// %2 = vector.interleave %0, %1 : vector<4xi64> | all leading positions 42 /// %3 = vector.insert %2, %result [0, 0, 0] | 43 /// : vector<8xi64> into vector<1x2x3x8xi64> ┘ 44 /// ``` 45 /// 46 /// Note: If any leading dimension before the `targetRank` is scalable the 47 /// unrolling will stop before the scalable dimension. 48 class UnrollInterleaveOp : public OpRewritePattern<vector::InterleaveOp> { 49 public: 50 UnrollInterleaveOp(int64_t targetRank, MLIRContext *context, 51 PatternBenefit benefit = 1) 52 : OpRewritePattern(context, benefit), targetRank(targetRank){}; 53 54 LogicalResult matchAndRewrite(vector::InterleaveOp op, 55 PatternRewriter &rewriter) const override { 56 VectorType resultType = op.getResultVectorType(); 57 auto unrollIterator = vector::createUnrollIterator(resultType, targetRank); 58 if (!unrollIterator) 59 return failure(); 60 61 auto loc = op.getLoc(); 62 Value result = rewriter.create<arith::ConstantOp>( 63 loc, resultType, rewriter.getZeroAttr(resultType)); 64 for (auto position : *unrollIterator) { 65 Value extractLhs = rewriter.create<ExtractOp>(loc, op.getLhs(), position); 66 Value extractRhs = rewriter.create<ExtractOp>(loc, op.getRhs(), position); 67 Value interleave = 68 rewriter.create<InterleaveOp>(loc, extractLhs, extractRhs); 69 result = rewriter.create<InsertOp>(loc, interleave, result, position); 70 } 71 72 rewriter.replaceOp(op, result); 73 return success(); 74 } 75 76 private: 77 int64_t targetRank = 1; 78 }; 79 80 } // namespace 81 82 void mlir::vector::populateVectorInterleaveLoweringPatterns( 83 RewritePatternSet &patterns, int64_t targetRank, PatternBenefit benefit) { 84 patterns.add<UnrollInterleaveOp>(targetRank, patterns.getContext(), benefit); 85 } 86