//===- LowerVectorTranspose.cpp - Lower 'vector.transpose' operation ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements target-independent rewrites and utilities to lower the
// 'vector.transpose' operation.
//
//===----------------------------------------------------------------------===//
132bc4c3e9SNicolas Vasilache
142bc4c3e9SNicolas Vasilache #include "mlir/Dialect/Affine/IR/AffineOps.h"
152bc4c3e9SNicolas Vasilache #include "mlir/Dialect/Arith/IR/Arith.h"
162bc4c3e9SNicolas Vasilache #include "mlir/Dialect/Arith/Utils/Utils.h"
172bc4c3e9SNicolas Vasilache #include "mlir/Dialect/Linalg/IR/Linalg.h"
182bc4c3e9SNicolas Vasilache #include "mlir/Dialect/MemRef/IR/MemRef.h"
192bc4c3e9SNicolas Vasilache #include "mlir/Dialect/SCF/IR/SCF.h"
202bc4c3e9SNicolas Vasilache #include "mlir/Dialect/Tensor/IR/Tensor.h"
212bc4c3e9SNicolas Vasilache #include "mlir/Dialect/Utils/IndexingUtils.h"
222bc4c3e9SNicolas Vasilache #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
232bc4c3e9SNicolas Vasilache #include "mlir/Dialect/Vector/IR/VectorOps.h"
242bc4c3e9SNicolas Vasilache #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
252bc4c3e9SNicolas Vasilache #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
262bc4c3e9SNicolas Vasilache #include "mlir/IR/BuiltinAttributeInterfaces.h"
272bc4c3e9SNicolas Vasilache #include "mlir/IR/BuiltinTypes.h"
282bc4c3e9SNicolas Vasilache #include "mlir/IR/ImplicitLocOpBuilder.h"
292bc4c3e9SNicolas Vasilache #include "mlir/IR/Location.h"
302bc4c3e9SNicolas Vasilache #include "mlir/IR/Matchers.h"
312bc4c3e9SNicolas Vasilache #include "mlir/IR/PatternMatch.h"
322bc4c3e9SNicolas Vasilache #include "mlir/IR/TypeUtilities.h"
332bc4c3e9SNicolas Vasilache #include "mlir/Interfaces/VectorInterfaces.h"
342bc4c3e9SNicolas Vasilache
35eb7f9feeSDiego Caballero #define DEBUG_TYPE "lower-vector-transpose"
362bc4c3e9SNicolas Vasilache
372bc4c3e9SNicolas Vasilache using namespace mlir;
382bc4c3e9SNicolas Vasilache using namespace mlir::vector;
392bc4c3e9SNicolas Vasilache
402bc4c3e9SNicolas Vasilache /// Given a 'transpose' pattern, prune the rightmost dimensions that are not
412bc4c3e9SNicolas Vasilache /// transposed.
pruneNonTransposedDims(ArrayRef<int64_t> transpose,SmallVectorImpl<int64_t> & result)422bc4c3e9SNicolas Vasilache static void pruneNonTransposedDims(ArrayRef<int64_t> transpose,
432bc4c3e9SNicolas Vasilache SmallVectorImpl<int64_t> &result) {
442bc4c3e9SNicolas Vasilache size_t numTransposedDims = transpose.size();
452bc4c3e9SNicolas Vasilache for (size_t transpDim : llvm::reverse(transpose)) {
462bc4c3e9SNicolas Vasilache if (transpDim != numTransposedDims - 1)
472bc4c3e9SNicolas Vasilache break;
482bc4c3e9SNicolas Vasilache numTransposedDims--;
492bc4c3e9SNicolas Vasilache }
502bc4c3e9SNicolas Vasilache
512bc4c3e9SNicolas Vasilache result.append(transpose.begin(), transpose.begin() + numTransposedDims);
522bc4c3e9SNicolas Vasilache }
532bc4c3e9SNicolas Vasilache
548d163e50SHanhan Wang /// Returns true if the lowering option is a vector shuffle based approach.
isShuffleLike(VectorTransposeLowering lowering)558d163e50SHanhan Wang static bool isShuffleLike(VectorTransposeLowering lowering) {
568d163e50SHanhan Wang return lowering == VectorTransposeLowering::Shuffle1D ||
578d163e50SHanhan Wang lowering == VectorTransposeLowering::Shuffle16x16;
588d163e50SHanhan Wang }
598d163e50SHanhan Wang
608d163e50SHanhan Wang /// Returns a shuffle mask that builds on `vals`. `vals` is the offset base of
618d163e50SHanhan Wang /// shuffle ops, i.e., the unpack pattern. The method iterates with `vals` to
628d163e50SHanhan Wang /// create the mask for `numBits` bits vector. The `numBits` have to be a
638d163e50SHanhan Wang /// multiple of 128. For example, if `vals` is {0, 1, 16, 17} and `numBits` is
648d163e50SHanhan Wang /// 512, there should be 16 elements in the final result. It constructs the
658d163e50SHanhan Wang /// below mask to get the unpack elements.
668d163e50SHanhan Wang /// [0, 1, 16, 17,
678d163e50SHanhan Wang /// 0+4, 1+4, 16+4, 17+4,
688d163e50SHanhan Wang /// 0+8, 1+8, 16+8, 17+8,
698d163e50SHanhan Wang /// 0+12, 1+12, 16+12, 17+12]
708d163e50SHanhan Wang static SmallVector<int64_t>
getUnpackShufflePermFor128Lane(ArrayRef<int64_t> vals,int numBits)718d163e50SHanhan Wang getUnpackShufflePermFor128Lane(ArrayRef<int64_t> vals, int numBits) {
728d163e50SHanhan Wang assert(numBits % 128 == 0 && "expected numBits is a multiple of 128");
738d163e50SHanhan Wang int numElem = numBits / 32;
748d163e50SHanhan Wang SmallVector<int64_t> res;
758d163e50SHanhan Wang for (int i = 0; i < numElem; i += 4)
768d163e50SHanhan Wang for (int64_t v : vals)
778d163e50SHanhan Wang res.push_back(v + i);
788d163e50SHanhan Wang return res;
798d163e50SHanhan Wang }
808d163e50SHanhan Wang
818d163e50SHanhan Wang /// Lower to vector.shuffle on v1 and v2 with UnpackLoPd shuffle mask. For
828d163e50SHanhan Wang /// example, if it is targeting 512 bit vector, returns
838d163e50SHanhan Wang /// vector.shuffle on v1, v2, [0, 1, 16, 17,
848d163e50SHanhan Wang /// 0+4, 1+4, 16+4, 17+4,
858d163e50SHanhan Wang /// 0+8, 1+8, 16+8, 17+8,
868d163e50SHanhan Wang /// 0+12, 1+12, 16+12, 17+12].
createUnpackLoPd(ImplicitLocOpBuilder & b,Value v1,Value v2,int numBits)878d163e50SHanhan Wang static Value createUnpackLoPd(ImplicitLocOpBuilder &b, Value v1, Value v2,
888d163e50SHanhan Wang int numBits) {
898d163e50SHanhan Wang int numElem = numBits / 32;
908d163e50SHanhan Wang return b.create<vector::ShuffleOp>(
918d163e50SHanhan Wang v1, v2,
928d163e50SHanhan Wang getUnpackShufflePermFor128Lane({0, 1, numElem, numElem + 1}, numBits));
938d163e50SHanhan Wang }
948d163e50SHanhan Wang
958d163e50SHanhan Wang /// Lower to vector.shuffle on v1 and v2 with UnpackHiPd shuffle mask. For
968d163e50SHanhan Wang /// example, if it is targeting 512 bit vector, returns
978d163e50SHanhan Wang /// vector.shuffle, v1, v2, [2, 3, 18, 19,
988d163e50SHanhan Wang /// 2+4, 3+4, 18+4, 19+4,
998d163e50SHanhan Wang /// 2+8, 3+8, 18+8, 19+8,
1008d163e50SHanhan Wang /// 2+12, 3+12, 18+12, 19+12].
createUnpackHiPd(ImplicitLocOpBuilder & b,Value v1,Value v2,int numBits)1018d163e50SHanhan Wang static Value createUnpackHiPd(ImplicitLocOpBuilder &b, Value v1, Value v2,
1028d163e50SHanhan Wang int numBits) {
1038d163e50SHanhan Wang int numElem = numBits / 32;
1048d163e50SHanhan Wang return b.create<vector::ShuffleOp>(
1058d163e50SHanhan Wang v1, v2,
1068d163e50SHanhan Wang getUnpackShufflePermFor128Lane({2, 3, numElem + 2, numElem + 3},
1078d163e50SHanhan Wang numBits));
1088d163e50SHanhan Wang }
1098d163e50SHanhan Wang
1108d163e50SHanhan Wang /// Lower to vector.shuffle on v1 and v2 with UnpackLoPs shuffle mask. For
1118d163e50SHanhan Wang /// example, if it is targeting 512 bit vector, returns
1128d163e50SHanhan Wang /// vector.shuffle, v1, v2, [0, 16, 1, 17,
1138d163e50SHanhan Wang /// 0+4, 16+4, 1+4, 17+4,
1148d163e50SHanhan Wang /// 0+8, 16+8, 1+8, 17+8,
1158d163e50SHanhan Wang /// 0+12, 16+12, 1+12, 17+12].
createUnpackLoPs(ImplicitLocOpBuilder & b,Value v1,Value v2,int numBits)1168d163e50SHanhan Wang static Value createUnpackLoPs(ImplicitLocOpBuilder &b, Value v1, Value v2,
1178d163e50SHanhan Wang int numBits) {
1188d163e50SHanhan Wang int numElem = numBits / 32;
1198d163e50SHanhan Wang auto shuffle = b.create<vector::ShuffleOp>(
1208d163e50SHanhan Wang v1, v2,
1218d163e50SHanhan Wang getUnpackShufflePermFor128Lane({0, numElem, 1, numElem + 1}, numBits));
1228d163e50SHanhan Wang return shuffle;
1238d163e50SHanhan Wang }
1248d163e50SHanhan Wang
1258d163e50SHanhan Wang /// Lower to vector.shuffle on v1 and v2 with UnpackHiPs shuffle mask. For
1268d163e50SHanhan Wang /// example, if it is targeting 512 bit vector, returns
1278d163e50SHanhan Wang /// vector.shuffle, v1, v2, [2, 18, 3, 19,
1288d163e50SHanhan Wang /// 2+4, 18+4, 3+4, 19+4,
1298d163e50SHanhan Wang /// 2+8, 18+8, 3+8, 19+8,
1308d163e50SHanhan Wang /// 2+12, 18+12, 3+12, 19+12].
createUnpackHiPs(ImplicitLocOpBuilder & b,Value v1,Value v2,int numBits)1318d163e50SHanhan Wang static Value createUnpackHiPs(ImplicitLocOpBuilder &b, Value v1, Value v2,
1328d163e50SHanhan Wang int numBits) {
1338d163e50SHanhan Wang int numElem = numBits / 32;
1348d163e50SHanhan Wang return b.create<vector::ShuffleOp>(
1358d163e50SHanhan Wang v1, v2,
1368d163e50SHanhan Wang getUnpackShufflePermFor128Lane({2, numElem + 2, 3, numElem + 3},
1378d163e50SHanhan Wang numBits));
1388d163e50SHanhan Wang }
1398d163e50SHanhan Wang
1408d163e50SHanhan Wang /// Returns a vector.shuffle that shuffles 128-bit lanes (composed of 4 32-bit
1418d163e50SHanhan Wang /// elements) selected by `mask` from `v1` and `v2`. I.e.,
1428d163e50SHanhan Wang ///
1438d163e50SHanhan Wang /// DEFINE SELECT4(src, control) {
1448d163e50SHanhan Wang /// CASE(control[1:0]) OF
1458d163e50SHanhan Wang /// 0: tmp[127:0] := src[127:0]
1468d163e50SHanhan Wang /// 1: tmp[127:0] := src[255:128]
1478d163e50SHanhan Wang /// 2: tmp[127:0] := src[383:256]
1488d163e50SHanhan Wang /// 3: tmp[127:0] := src[511:384]
1498d163e50SHanhan Wang /// ESAC
1508d163e50SHanhan Wang /// RETURN tmp[127:0]
1518d163e50SHanhan Wang /// }
1528d163e50SHanhan Wang /// dst[127:0] := SELECT4(v1[511:0], mask[1:0])
1538d163e50SHanhan Wang /// dst[255:128] := SELECT4(v1[511:0], mask[3:2])
1548d163e50SHanhan Wang /// dst[383:256] := SELECT4(v2[511:0], mask[5:4])
1558d163e50SHanhan Wang /// dst[511:384] := SELECT4(v2[511:0], mask[7:6])
create4x128BitSuffle(ImplicitLocOpBuilder & b,Value v1,Value v2,uint8_t mask)1568d163e50SHanhan Wang static Value create4x128BitSuffle(ImplicitLocOpBuilder &b, Value v1, Value v2,
1578d163e50SHanhan Wang uint8_t mask) {
1585550c821STres Popp assert(cast<VectorType>(v1.getType()).getShape()[0] == 16 &&
1598d163e50SHanhan Wang "expected a vector with length=16");
1608d163e50SHanhan Wang SmallVector<int64_t> shuffleMask;
1618d163e50SHanhan Wang auto appendToMask = [&](int64_t base, uint8_t control) {
1628d163e50SHanhan Wang switch (control) {
1638d163e50SHanhan Wang case 0:
1648d163e50SHanhan Wang llvm::append_range(shuffleMask, ArrayRef<int64_t>{base + 0, base + 1,
1658d163e50SHanhan Wang base + 2, base + 3});
1668d163e50SHanhan Wang break;
1678d163e50SHanhan Wang case 1:
1688d163e50SHanhan Wang llvm::append_range(shuffleMask, ArrayRef<int64_t>{base + 4, base + 5,
1698d163e50SHanhan Wang base + 6, base + 7});
1708d163e50SHanhan Wang break;
1718d163e50SHanhan Wang case 2:
1728d163e50SHanhan Wang llvm::append_range(shuffleMask, ArrayRef<int64_t>{base + 8, base + 9,
1738d163e50SHanhan Wang base + 10, base + 11});
1748d163e50SHanhan Wang break;
1758d163e50SHanhan Wang case 3:
1768d163e50SHanhan Wang llvm::append_range(shuffleMask, ArrayRef<int64_t>{base + 12, base + 13,
1778d163e50SHanhan Wang base + 14, base + 15});
1788d163e50SHanhan Wang break;
1798d163e50SHanhan Wang default:
1808d163e50SHanhan Wang llvm_unreachable("control > 3 : overflow");
1818d163e50SHanhan Wang }
1828d163e50SHanhan Wang };
1838d163e50SHanhan Wang uint8_t b01 = mask & 0x3;
1848d163e50SHanhan Wang uint8_t b23 = (mask >> 2) & 0x3;
1858d163e50SHanhan Wang uint8_t b45 = (mask >> 4) & 0x3;
1868d163e50SHanhan Wang uint8_t b67 = (mask >> 6) & 0x3;
1878d163e50SHanhan Wang appendToMask(0, b01);
1888d163e50SHanhan Wang appendToMask(0, b23);
1898d163e50SHanhan Wang appendToMask(16, b45);
1908d163e50SHanhan Wang appendToMask(16, b67);
1918d163e50SHanhan Wang return b.create<vector::ShuffleOp>(v1, v2, shuffleMask);
1928d163e50SHanhan Wang }
1938d163e50SHanhan Wang
1948d163e50SHanhan Wang /// Lowers the value to a vector.shuffle op. The `source` is expected to be a
1958d163e50SHanhan Wang /// 1-D vector and have `m`x`n` elements.
transposeToShuffle1D(OpBuilder & b,Value source,int m,int n)1968d163e50SHanhan Wang static Value transposeToShuffle1D(OpBuilder &b, Value source, int m, int n) {
1978d163e50SHanhan Wang SmallVector<int64_t> mask;
1988d163e50SHanhan Wang mask.reserve(m * n);
1998d163e50SHanhan Wang for (int64_t j = 0; j < n; ++j)
2008d163e50SHanhan Wang for (int64_t i = 0; i < m; ++i)
2018d163e50SHanhan Wang mask.push_back(i * n + j);
2028d163e50SHanhan Wang return b.create<vector::ShuffleOp>(source.getLoc(), source, source, mask);
2038d163e50SHanhan Wang }
2048d163e50SHanhan Wang
2058d163e50SHanhan Wang /// Lowers the value to a sequence of vector.shuffle ops. The `source` is
2068d163e50SHanhan Wang /// expected to be a 16x16 vector.
transposeToShuffle16x16(OpBuilder & builder,Value source,int m,int n)2078d163e50SHanhan Wang static Value transposeToShuffle16x16(OpBuilder &builder, Value source, int m,
2088d163e50SHanhan Wang int n) {
2098d163e50SHanhan Wang ImplicitLocOpBuilder b(source.getLoc(), builder);
2108d163e50SHanhan Wang SmallVector<Value> vs;
2118d163e50SHanhan Wang for (int64_t i = 0; i < m; ++i)
2128d163e50SHanhan Wang vs.push_back(b.create<vector::ExtractOp>(source, i));
2138d163e50SHanhan Wang
2148d163e50SHanhan Wang // Interleave 32-bit lanes using
2158d163e50SHanhan Wang // 8x _mm512_unpacklo_epi32
2168d163e50SHanhan Wang // 8x _mm512_unpackhi_epi32
2178d163e50SHanhan Wang Value t0 = createUnpackLoPs(b, vs[0x0], vs[0x1], 512);
2188d163e50SHanhan Wang Value t1 = createUnpackHiPs(b, vs[0x0], vs[0x1], 512);
2198d163e50SHanhan Wang Value t2 = createUnpackLoPs(b, vs[0x2], vs[0x3], 512);
2208d163e50SHanhan Wang Value t3 = createUnpackHiPs(b, vs[0x2], vs[0x3], 512);
2218d163e50SHanhan Wang Value t4 = createUnpackLoPs(b, vs[0x4], vs[0x5], 512);
2228d163e50SHanhan Wang Value t5 = createUnpackHiPs(b, vs[0x4], vs[0x5], 512);
2238d163e50SHanhan Wang Value t6 = createUnpackLoPs(b, vs[0x6], vs[0x7], 512);
2248d163e50SHanhan Wang Value t7 = createUnpackHiPs(b, vs[0x6], vs[0x7], 512);
2258d163e50SHanhan Wang Value t8 = createUnpackLoPs(b, vs[0x8], vs[0x9], 512);
2268d163e50SHanhan Wang Value t9 = createUnpackHiPs(b, vs[0x8], vs[0x9], 512);
2278d163e50SHanhan Wang Value ta = createUnpackLoPs(b, vs[0xa], vs[0xb], 512);
2288d163e50SHanhan Wang Value tb = createUnpackHiPs(b, vs[0xa], vs[0xb], 512);
2298d163e50SHanhan Wang Value tc = createUnpackLoPs(b, vs[0xc], vs[0xd], 512);
2308d163e50SHanhan Wang Value td = createUnpackHiPs(b, vs[0xc], vs[0xd], 512);
2318d163e50SHanhan Wang Value te = createUnpackLoPs(b, vs[0xe], vs[0xf], 512);
2328d163e50SHanhan Wang Value tf = createUnpackHiPs(b, vs[0xe], vs[0xf], 512);
2338d163e50SHanhan Wang
2348d163e50SHanhan Wang // Interleave 64-bit lanes using
2358d163e50SHanhan Wang // 8x _mm512_unpacklo_epi64
2368d163e50SHanhan Wang // 8x _mm512_unpackhi_epi64
2378d163e50SHanhan Wang Value r0 = createUnpackLoPd(b, t0, t2, 512);
2388d163e50SHanhan Wang Value r1 = createUnpackHiPd(b, t0, t2, 512);
2398d163e50SHanhan Wang Value r2 = createUnpackLoPd(b, t1, t3, 512);
2408d163e50SHanhan Wang Value r3 = createUnpackHiPd(b, t1, t3, 512);
2418d163e50SHanhan Wang Value r4 = createUnpackLoPd(b, t4, t6, 512);
2428d163e50SHanhan Wang Value r5 = createUnpackHiPd(b, t4, t6, 512);
2438d163e50SHanhan Wang Value r6 = createUnpackLoPd(b, t5, t7, 512);
2448d163e50SHanhan Wang Value r7 = createUnpackHiPd(b, t5, t7, 512);
2458d163e50SHanhan Wang Value r8 = createUnpackLoPd(b, t8, ta, 512);
2468d163e50SHanhan Wang Value r9 = createUnpackHiPd(b, t8, ta, 512);
2478d163e50SHanhan Wang Value ra = createUnpackLoPd(b, t9, tb, 512);
2488d163e50SHanhan Wang Value rb = createUnpackHiPd(b, t9, tb, 512);
2498d163e50SHanhan Wang Value rc = createUnpackLoPd(b, tc, te, 512);
2508d163e50SHanhan Wang Value rd = createUnpackHiPd(b, tc, te, 512);
2518d163e50SHanhan Wang Value re = createUnpackLoPd(b, td, tf, 512);
2528d163e50SHanhan Wang Value rf = createUnpackHiPd(b, td, tf, 512);
2538d163e50SHanhan Wang
2548d163e50SHanhan Wang // Permute 128-bit lanes using
2558d163e50SHanhan Wang // 16x _mm512_shuffle_i32x4
2568d163e50SHanhan Wang t0 = create4x128BitSuffle(b, r0, r4, 0x88);
2578d163e50SHanhan Wang t1 = create4x128BitSuffle(b, r1, r5, 0x88);
2588d163e50SHanhan Wang t2 = create4x128BitSuffle(b, r2, r6, 0x88);
2598d163e50SHanhan Wang t3 = create4x128BitSuffle(b, r3, r7, 0x88);
2608d163e50SHanhan Wang t4 = create4x128BitSuffle(b, r0, r4, 0xdd);
2618d163e50SHanhan Wang t5 = create4x128BitSuffle(b, r1, r5, 0xdd);
2628d163e50SHanhan Wang t6 = create4x128BitSuffle(b, r2, r6, 0xdd);
2638d163e50SHanhan Wang t7 = create4x128BitSuffle(b, r3, r7, 0xdd);
2648d163e50SHanhan Wang t8 = create4x128BitSuffle(b, r8, rc, 0x88);
2658d163e50SHanhan Wang t9 = create4x128BitSuffle(b, r9, rd, 0x88);
2668d163e50SHanhan Wang ta = create4x128BitSuffle(b, ra, re, 0x88);
2678d163e50SHanhan Wang tb = create4x128BitSuffle(b, rb, rf, 0x88);
2688d163e50SHanhan Wang tc = create4x128BitSuffle(b, r8, rc, 0xdd);
2698d163e50SHanhan Wang td = create4x128BitSuffle(b, r9, rd, 0xdd);
2708d163e50SHanhan Wang te = create4x128BitSuffle(b, ra, re, 0xdd);
2718d163e50SHanhan Wang tf = create4x128BitSuffle(b, rb, rf, 0xdd);
2728d163e50SHanhan Wang
2738d163e50SHanhan Wang // Permute 256-bit lanes using again
2748d163e50SHanhan Wang // 16x _mm512_shuffle_i32x4
2758d163e50SHanhan Wang vs[0x0] = create4x128BitSuffle(b, t0, t8, 0x88);
2768d163e50SHanhan Wang vs[0x1] = create4x128BitSuffle(b, t1, t9, 0x88);
2778d163e50SHanhan Wang vs[0x2] = create4x128BitSuffle(b, t2, ta, 0x88);
2788d163e50SHanhan Wang vs[0x3] = create4x128BitSuffle(b, t3, tb, 0x88);
2798d163e50SHanhan Wang vs[0x4] = create4x128BitSuffle(b, t4, tc, 0x88);
2808d163e50SHanhan Wang vs[0x5] = create4x128BitSuffle(b, t5, td, 0x88);
2818d163e50SHanhan Wang vs[0x6] = create4x128BitSuffle(b, t6, te, 0x88);
2828d163e50SHanhan Wang vs[0x7] = create4x128BitSuffle(b, t7, tf, 0x88);
2838d163e50SHanhan Wang vs[0x8] = create4x128BitSuffle(b, t0, t8, 0xdd);
2848d163e50SHanhan Wang vs[0x9] = create4x128BitSuffle(b, t1, t9, 0xdd);
2858d163e50SHanhan Wang vs[0xa] = create4x128BitSuffle(b, t2, ta, 0xdd);
2868d163e50SHanhan Wang vs[0xb] = create4x128BitSuffle(b, t3, tb, 0xdd);
2878d163e50SHanhan Wang vs[0xc] = create4x128BitSuffle(b, t4, tc, 0xdd);
2888d163e50SHanhan Wang vs[0xd] = create4x128BitSuffle(b, t5, td, 0xdd);
2898d163e50SHanhan Wang vs[0xe] = create4x128BitSuffle(b, t6, te, 0xdd);
2908d163e50SHanhan Wang vs[0xf] = create4x128BitSuffle(b, t7, tf, 0xdd);
2918d163e50SHanhan Wang
2928d163e50SHanhan Wang auto reshInputType = VectorType::get(
2935550c821STres Popp {m, n}, cast<VectorType>(source.getType()).getElementType());
2948d163e50SHanhan Wang Value res =
2958d163e50SHanhan Wang b.create<arith::ConstantOp>(reshInputType, b.getZeroAttr(reshInputType));
2968d163e50SHanhan Wang for (int64_t i = 0; i < m; ++i)
2978d163e50SHanhan Wang res = b.create<vector::InsertOp>(vs[i], res, i);
2988d163e50SHanhan Wang return res;
2998d163e50SHanhan Wang }
3008d163e50SHanhan Wang
3012bc4c3e9SNicolas Vasilache namespace {
3022bc4c3e9SNicolas Vasilache /// Progressive lowering of TransposeOp.
3032bc4c3e9SNicolas Vasilache /// One:
3042bc4c3e9SNicolas Vasilache /// %x = vector.transpose %y, [1, 0]
3052bc4c3e9SNicolas Vasilache /// is replaced by:
3062bc4c3e9SNicolas Vasilache /// %z = arith.constant dense<0.000000e+00>
3072bc4c3e9SNicolas Vasilache /// %0 = vector.extract %y[0, 0]
3082bc4c3e9SNicolas Vasilache /// %1 = vector.insert %0, %z [0, 0]
3092bc4c3e9SNicolas Vasilache /// ..
3102bc4c3e9SNicolas Vasilache /// %x = vector.insert .., .. [.., ..]
3112bc4c3e9SNicolas Vasilache class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
3122bc4c3e9SNicolas Vasilache public:
3132bc4c3e9SNicolas Vasilache using OpRewritePattern::OpRewritePattern;
3142bc4c3e9SNicolas Vasilache
TransposeOpLowering(vector::VectorTransformsOptions vectorTransformOptions,MLIRContext * context,PatternBenefit benefit=1)3152bc4c3e9SNicolas Vasilache TransposeOpLowering(vector::VectorTransformsOptions vectorTransformOptions,
3162bc4c3e9SNicolas Vasilache MLIRContext *context, PatternBenefit benefit = 1)
3172bc4c3e9SNicolas Vasilache : OpRewritePattern<vector::TransposeOp>(context, benefit),
3182bc4c3e9SNicolas Vasilache vectorTransformOptions(vectorTransformOptions) {}
3192bc4c3e9SNicolas Vasilache
matchAndRewrite(vector::TransposeOp op,PatternRewriter & rewriter) const3202bc4c3e9SNicolas Vasilache LogicalResult matchAndRewrite(vector::TransposeOp op,
3212bc4c3e9SNicolas Vasilache PatternRewriter &rewriter) const override {
3222bc4c3e9SNicolas Vasilache auto loc = op.getLoc();
3232bc4c3e9SNicolas Vasilache
3242bc4c3e9SNicolas Vasilache Value input = op.getVector();
3252bc4c3e9SNicolas Vasilache VectorType inputType = op.getSourceVectorType();
3262bc4c3e9SNicolas Vasilache VectorType resType = op.getResultVectorType();
3272bc4c3e9SNicolas Vasilache
328*cbd72cb0SAndrzej Warzyński if (inputType.isScalable())
329*cbd72cb0SAndrzej Warzyński return rewriter.notifyMatchFailure(
330*cbd72cb0SAndrzej Warzyński op, "This lowering does not support scalable vectors");
331*cbd72cb0SAndrzej Warzyński
3322bc4c3e9SNicolas Vasilache // Set up convenience transposition table.
33332c3decbSMatthias Springer ArrayRef<int64_t> transp = op.getPermutation();
3342bc4c3e9SNicolas Vasilache
3358d163e50SHanhan Wang if (isShuffleLike(vectorTransformOptions.vectorTransposeLowering) &&
33625cc5a71SHanhan Wang succeeded(isTranspose2DSlice(op)))
3372bc4c3e9SNicolas Vasilache return rewriter.notifyMatchFailure(
3382bc4c3e9SNicolas Vasilache op, "Options specifies lowering to shuffle");
3392bc4c3e9SNicolas Vasilache
3402bc4c3e9SNicolas Vasilache // Handle a true 2-D matrix transpose differently when requested.
3412bc4c3e9SNicolas Vasilache if (vectorTransformOptions.vectorTransposeLowering ==
3422bc4c3e9SNicolas Vasilache vector::VectorTransposeLowering::Flat &&
3432bc4c3e9SNicolas Vasilache resType.getRank() == 2 && transp[0] == 1 && transp[1] == 0) {
3442bc4c3e9SNicolas Vasilache Type flattenedType =
3452bc4c3e9SNicolas Vasilache VectorType::get(resType.getNumElements(), resType.getElementType());
3462bc4c3e9SNicolas Vasilache auto matrix =
3472bc4c3e9SNicolas Vasilache rewriter.create<vector::ShapeCastOp>(loc, flattenedType, input);
3482bc4c3e9SNicolas Vasilache auto rows = rewriter.getI32IntegerAttr(resType.getShape()[0]);
3492bc4c3e9SNicolas Vasilache auto columns = rewriter.getI32IntegerAttr(resType.getShape()[1]);
3502bc4c3e9SNicolas Vasilache Value trans = rewriter.create<vector::FlatTransposeOp>(
3512bc4c3e9SNicolas Vasilache loc, flattenedType, matrix, rows, columns);
3522bc4c3e9SNicolas Vasilache rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, resType, trans);
3532bc4c3e9SNicolas Vasilache return success();
3542bc4c3e9SNicolas Vasilache }
3552bc4c3e9SNicolas Vasilache
3562bc4c3e9SNicolas Vasilache // Generate unrolled extract/insert ops. We do not unroll the rightmost
3572bc4c3e9SNicolas Vasilache // (i.e., highest-order) dimensions that are not transposed and leave them
3582bc4c3e9SNicolas Vasilache // in vector form to improve performance. Therefore, we prune those
3592bc4c3e9SNicolas Vasilache // dimensions from the shape/transpose data structures used to generate the
3602bc4c3e9SNicolas Vasilache // extract/insert ops.
3612bc4c3e9SNicolas Vasilache SmallVector<int64_t> prunedTransp;
3622bc4c3e9SNicolas Vasilache pruneNonTransposedDims(transp, prunedTransp);
3632bc4c3e9SNicolas Vasilache size_t numPrunedDims = transp.size() - prunedTransp.size();
3642bc4c3e9SNicolas Vasilache auto prunedInShape = inputType.getShape().drop_back(numPrunedDims);
3652bc4c3e9SNicolas Vasilache auto prunedInStrides = computeStrides(prunedInShape);
3662bc4c3e9SNicolas Vasilache
3672bc4c3e9SNicolas Vasilache // Generates the extract/insert operations for every scalar/vector element
3682bc4c3e9SNicolas Vasilache // of the leftmost transposed dimensions. We traverse every transpose
3692bc4c3e9SNicolas Vasilache // element using a linearized index that we delinearize to generate the
3702bc4c3e9SNicolas Vasilache // appropriate indices for the extract/insert operations.
3712bc4c3e9SNicolas Vasilache Value result = rewriter.create<arith::ConstantOp>(
3722bc4c3e9SNicolas Vasilache loc, resType, rewriter.getZeroAttr(resType));
3732bc4c3e9SNicolas Vasilache int64_t numTransposedElements = ShapedType::getNumElements(prunedInShape);
3742bc4c3e9SNicolas Vasilache
3752bc4c3e9SNicolas Vasilache for (int64_t linearIdx = 0; linearIdx < numTransposedElements;
3762bc4c3e9SNicolas Vasilache ++linearIdx) {
3772bc4c3e9SNicolas Vasilache auto extractIdxs = delinearize(linearIdx, prunedInStrides);
3782bc4c3e9SNicolas Vasilache SmallVector<int64_t> insertIdxs(extractIdxs);
3792bc4c3e9SNicolas Vasilache applyPermutationToVector(insertIdxs, prunedTransp);
3802bc4c3e9SNicolas Vasilache Value extractOp =
3812bc4c3e9SNicolas Vasilache rewriter.create<vector::ExtractOp>(loc, input, extractIdxs);
3822bc4c3e9SNicolas Vasilache result =
3832bc4c3e9SNicolas Vasilache rewriter.create<vector::InsertOp>(loc, extractOp, result, insertIdxs);
3842bc4c3e9SNicolas Vasilache }
3852bc4c3e9SNicolas Vasilache
3862bc4c3e9SNicolas Vasilache rewriter.replaceOp(op, result);
3872bc4c3e9SNicolas Vasilache return success();
3882bc4c3e9SNicolas Vasilache }
3892bc4c3e9SNicolas Vasilache
3902bc4c3e9SNicolas Vasilache private:
3912bc4c3e9SNicolas Vasilache /// Options to control the vector patterns.
3922bc4c3e9SNicolas Vasilache vector::VectorTransformsOptions vectorTransformOptions;
3932bc4c3e9SNicolas Vasilache };
3942bc4c3e9SNicolas Vasilache
395*cbd72cb0SAndrzej Warzyński /// Rewrites vector.transpose as vector.shape_cast. This pattern is only applied
396*cbd72cb0SAndrzej Warzyński /// to 2D vectors with at least one unit dim. For example:
397*cbd72cb0SAndrzej Warzyński ///
398*cbd72cb0SAndrzej Warzyński /// Replace:
399*cbd72cb0SAndrzej Warzyński /// vector.transpose %0, [1, 0] : vector<4x1xi32>> to
400*cbd72cb0SAndrzej Warzyński /// vector<1x4xi32>
401*cbd72cb0SAndrzej Warzyński /// with:
402*cbd72cb0SAndrzej Warzyński /// vector.shape_cast %0 : vector<4x1xi32> to vector<1x4xi32>
403*cbd72cb0SAndrzej Warzyński ///
404*cbd72cb0SAndrzej Warzyński /// Source with leading unit dim (inverse) is also replaced. Unit dim must
405*cbd72cb0SAndrzej Warzyński /// be fixed. Non-unit dim can be scalable.
406*cbd72cb0SAndrzej Warzyński ///
407*cbd72cb0SAndrzej Warzyński /// TODO: This pattern was introduced specifically to help lower scalable
408*cbd72cb0SAndrzej Warzyński /// vectors. In hindsight, a more specialised canonicalization (for shape_cast's
409*cbd72cb0SAndrzej Warzyński /// to cancel out) would be preferable:
410*cbd72cb0SAndrzej Warzyński ///
411*cbd72cb0SAndrzej Warzyński /// BEFORE:
412*cbd72cb0SAndrzej Warzyński /// %0 = some_op
413*cbd72cb0SAndrzej Warzyński /// %1 = vector.shape_cast %0 : vector<[4]xf32> to vector<[4]x1xf32>
414*cbd72cb0SAndrzej Warzyński /// %2 = vector.transpose %1 [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
415*cbd72cb0SAndrzej Warzyński /// AFTER:
416*cbd72cb0SAndrzej Warzyński /// %0 = some_op
417*cbd72cb0SAndrzej Warzyński /// %1 = vector.shape_cast %0 : vector<[4]xf32> to vector<1x[4]xf32>
418*cbd72cb0SAndrzej Warzyński ///
419*cbd72cb0SAndrzej Warzyński /// Given the context above, we may want to consider (re-)moving this pattern
420*cbd72cb0SAndrzej Warzyński /// at some later time. I am leaving it for now in case there are other users
421*cbd72cb0SAndrzej Warzyński /// that I am not aware of.
422*cbd72cb0SAndrzej Warzyński class Transpose2DWithUnitDimToShapeCast
423*cbd72cb0SAndrzej Warzyński : public OpRewritePattern<vector::TransposeOp> {
424*cbd72cb0SAndrzej Warzyński public:
425*cbd72cb0SAndrzej Warzyński using OpRewritePattern::OpRewritePattern;
426*cbd72cb0SAndrzej Warzyński
Transpose2DWithUnitDimToShapeCast(MLIRContext * context,PatternBenefit benefit=1)427*cbd72cb0SAndrzej Warzyński Transpose2DWithUnitDimToShapeCast(MLIRContext *context,
428*cbd72cb0SAndrzej Warzyński PatternBenefit benefit = 1)
429*cbd72cb0SAndrzej Warzyński : OpRewritePattern<vector::TransposeOp>(context, benefit) {}
430*cbd72cb0SAndrzej Warzyński
matchAndRewrite(vector::TransposeOp op,PatternRewriter & rewriter) const431*cbd72cb0SAndrzej Warzyński LogicalResult matchAndRewrite(vector::TransposeOp op,
432*cbd72cb0SAndrzej Warzyński PatternRewriter &rewriter) const override {
433*cbd72cb0SAndrzej Warzyński Value input = op.getVector();
434*cbd72cb0SAndrzej Warzyński VectorType resType = op.getResultVectorType();
435*cbd72cb0SAndrzej Warzyński
436*cbd72cb0SAndrzej Warzyński // Set up convenience transposition table.
437*cbd72cb0SAndrzej Warzyński ArrayRef<int64_t> transp = op.getPermutation();
438*cbd72cb0SAndrzej Warzyński
439*cbd72cb0SAndrzej Warzyński if (resType.getRank() == 2 &&
440*cbd72cb0SAndrzej Warzyński ((resType.getShape().front() == 1 &&
441*cbd72cb0SAndrzej Warzyński !resType.getScalableDims().front()) ||
442*cbd72cb0SAndrzej Warzyński (resType.getShape().back() == 1 &&
443*cbd72cb0SAndrzej Warzyński !resType.getScalableDims().back())) &&
444*cbd72cb0SAndrzej Warzyński transp == ArrayRef<int64_t>({1, 0})) {
445*cbd72cb0SAndrzej Warzyński rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, resType, input);
446*cbd72cb0SAndrzej Warzyński return success();
447*cbd72cb0SAndrzej Warzyński }
448*cbd72cb0SAndrzej Warzyński
449*cbd72cb0SAndrzej Warzyński return failure();
450*cbd72cb0SAndrzej Warzyński }
451*cbd72cb0SAndrzej Warzyński };
452*cbd72cb0SAndrzej Warzyński
4538d163e50SHanhan Wang /// Rewrite a 2-D vector.transpose as a sequence of shuffle ops.
4548d163e50SHanhan Wang /// If the strategy is Shuffle1D, it will be lowered to:
4552bc4c3e9SNicolas Vasilache /// vector.shape_cast 2D -> 1D
4562bc4c3e9SNicolas Vasilache /// vector.shuffle
4572bc4c3e9SNicolas Vasilache /// vector.shape_cast 1D -> 2D
4588d163e50SHanhan Wang /// If the strategy is Shuffle16x16, it will be lowered to a sequence of shuffle
4598d163e50SHanhan Wang /// ops on 16xf32 vectors.
4602bc4c3e9SNicolas Vasilache class TransposeOp2DToShuffleLowering
4612bc4c3e9SNicolas Vasilache : public OpRewritePattern<vector::TransposeOp> {
4622bc4c3e9SNicolas Vasilache public:
4632bc4c3e9SNicolas Vasilache using OpRewritePattern::OpRewritePattern;
4642bc4c3e9SNicolas Vasilache
TransposeOp2DToShuffleLowering(vector::VectorTransformsOptions vectorTransformOptions,MLIRContext * context,PatternBenefit benefit=1)4652bc4c3e9SNicolas Vasilache TransposeOp2DToShuffleLowering(
4662bc4c3e9SNicolas Vasilache vector::VectorTransformsOptions vectorTransformOptions,
4672bc4c3e9SNicolas Vasilache MLIRContext *context, PatternBenefit benefit = 1)
4682bc4c3e9SNicolas Vasilache : OpRewritePattern<vector::TransposeOp>(context, benefit),
4692bc4c3e9SNicolas Vasilache vectorTransformOptions(vectorTransformOptions) {}
4702bc4c3e9SNicolas Vasilache
matchAndRewrite(vector::TransposeOp op,PatternRewriter & rewriter) const4712bc4c3e9SNicolas Vasilache LogicalResult matchAndRewrite(vector::TransposeOp op,
4722bc4c3e9SNicolas Vasilache PatternRewriter &rewriter) const override {
47325cc5a71SHanhan Wang if (!isShuffleLike(vectorTransformOptions.vectorTransposeLowering))
47425cc5a71SHanhan Wang return rewriter.notifyMatchFailure(
47525cc5a71SHanhan Wang op, "not using vector shuffle based lowering");
47625cc5a71SHanhan Wang
47788610b79SBenjamin Maxwell if (op.getSourceVectorType().isScalable())
47888610b79SBenjamin Maxwell return rewriter.notifyMatchFailure(
47988610b79SBenjamin Maxwell op, "vector shuffle lowering not supported for scalable vectors");
48088610b79SBenjamin Maxwell
48125cc5a71SHanhan Wang auto srcGtOneDims = isTranspose2DSlice(op);
48225cc5a71SHanhan Wang if (failed(srcGtOneDims))
48325cc5a71SHanhan Wang return rewriter.notifyMatchFailure(
48425cc5a71SHanhan Wang op, "expected transposition on a 2D slice");
4852bc4c3e9SNicolas Vasilache
4862bc4c3e9SNicolas Vasilache VectorType srcType = op.getSourceVectorType();
48725cc5a71SHanhan Wang int64_t m = srcType.getDimSize(std::get<0>(srcGtOneDims.value()));
48825cc5a71SHanhan Wang int64_t n = srcType.getDimSize(std::get<1>(srcGtOneDims.value()));
4892bc4c3e9SNicolas Vasilache
49025cc5a71SHanhan Wang // Reshape the n-D input vector with only two dimensions greater than one
49125cc5a71SHanhan Wang // to a 2-D vector.
49225cc5a71SHanhan Wang Location loc = op.getLoc();
49325cc5a71SHanhan Wang auto flattenedType = VectorType::get({n * m}, srcType.getElementType());
49425cc5a71SHanhan Wang auto reshInputType = VectorType::get({m, n}, srcType.getElementType());
49525cc5a71SHanhan Wang auto reshInput = rewriter.create<vector::ShapeCastOp>(loc, flattenedType,
49625cc5a71SHanhan Wang op.getVector());
4972bc4c3e9SNicolas Vasilache
4988d163e50SHanhan Wang Value res;
49925cc5a71SHanhan Wang if (vectorTransformOptions.vectorTransposeLowering ==
50025cc5a71SHanhan Wang VectorTransposeLowering::Shuffle16x16 &&
50125cc5a71SHanhan Wang m == 16 && n == 16) {
50225cc5a71SHanhan Wang reshInput =
50325cc5a71SHanhan Wang rewriter.create<vector::ShapeCastOp>(loc, reshInputType, reshInput);
50425cc5a71SHanhan Wang res = transposeToShuffle16x16(rewriter, reshInput, m, n);
50525cc5a71SHanhan Wang } else {
50625cc5a71SHanhan Wang // Fallback to shuffle on 1D approach.
50725cc5a71SHanhan Wang res = transposeToShuffle1D(rewriter, reshInput, m, n);
5088d163e50SHanhan Wang }
5092bc4c3e9SNicolas Vasilache
5102bc4c3e9SNicolas Vasilache rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
5118d163e50SHanhan Wang op, op.getResultVectorType(), res);
5122bc4c3e9SNicolas Vasilache
5132bc4c3e9SNicolas Vasilache return success();
5142bc4c3e9SNicolas Vasilache }
5152bc4c3e9SNicolas Vasilache
5162bc4c3e9SNicolas Vasilache private:
5172bc4c3e9SNicolas Vasilache /// Options to control the vector patterns.
5182bc4c3e9SNicolas Vasilache vector::VectorTransformsOptions vectorTransformOptions;
5192bc4c3e9SNicolas Vasilache };
5202bc4c3e9SNicolas Vasilache } // namespace
5212bc4c3e9SNicolas Vasilache
populateVectorTransposeLoweringPatterns(RewritePatternSet & patterns,VectorTransformsOptions options,PatternBenefit benefit)5222bc4c3e9SNicolas Vasilache void mlir::vector::populateVectorTransposeLoweringPatterns(
5232bc4c3e9SNicolas Vasilache RewritePatternSet &patterns, VectorTransformsOptions options,
5242bc4c3e9SNicolas Vasilache PatternBenefit benefit) {
525*cbd72cb0SAndrzej Warzyński patterns.add<Transpose2DWithUnitDimToShapeCast>(patterns.getContext(),
526*cbd72cb0SAndrzej Warzyński benefit);
5272bc4c3e9SNicolas Vasilache patterns.add<TransposeOpLowering, TransposeOp2DToShuffleLowering>(
5282bc4c3e9SNicolas Vasilache options, patterns.getContext(), benefit);
5292bc4c3e9SNicolas Vasilache }
530