xref: /llvm-project/mlir/lib/Dialect/Vector/Transforms/LowerVectorBitCast.cpp (revision db791b278a414fb6df1acc1799adcf11d8fb9169)
10ea1271eSHan-Chung Wang //===- LowerVectorBitCast.cpp - Lower 'vector.bitcast' operation ----------===//
20ea1271eSHan-Chung Wang //
30ea1271eSHan-Chung Wang // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
40ea1271eSHan-Chung Wang // See https://llvm.org/LICENSE.txt for license information.
50ea1271eSHan-Chung Wang // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
60ea1271eSHan-Chung Wang //
70ea1271eSHan-Chung Wang //===----------------------------------------------------------------------===//
80ea1271eSHan-Chung Wang //
90ea1271eSHan-Chung Wang // This file implements target-independent rewrites and utilities to lower the
100ea1271eSHan-Chung Wang // 'vector.bitcast' operation.
110ea1271eSHan-Chung Wang //
120ea1271eSHan-Chung Wang //===----------------------------------------------------------------------===//
130ea1271eSHan-Chung Wang 
140ea1271eSHan-Chung Wang #include "mlir/Dialect/Vector/IR/VectorOps.h"
150ea1271eSHan-Chung Wang #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
160ea1271eSHan-Chung Wang #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
170ea1271eSHan-Chung Wang #include "mlir/IR/BuiltinTypes.h"
180ea1271eSHan-Chung Wang #include "mlir/IR/PatternMatch.h"
190ea1271eSHan-Chung Wang 
200ea1271eSHan-Chung Wang #define DEBUG_TYPE "vector-bitcast-lowering"
210ea1271eSHan-Chung Wang 
220ea1271eSHan-Chung Wang using namespace mlir;
230ea1271eSHan-Chung Wang using namespace mlir::vector;
240ea1271eSHan-Chung Wang 
250ea1271eSHan-Chung Wang namespace {
260ea1271eSHan-Chung Wang 
270ea1271eSHan-Chung Wang /// A one-shot unrolling of vector.bitcast to the `targetRank`.
280ea1271eSHan-Chung Wang ///
290ea1271eSHan-Chung Wang /// Example:
300ea1271eSHan-Chung Wang ///
310ea1271eSHan-Chung Wang ///   vector.bitcast %a, %b : vector<1x2x3x4xi64> to vector<1x2x3x8xi32>
320ea1271eSHan-Chung Wang ///
330ea1271eSHan-Chung Wang /// Would be unrolled to:
340ea1271eSHan-Chung Wang ///
350ea1271eSHan-Chung Wang /// %result = arith.constant dense<0> : vector<1x2x3x8xi32>
360ea1271eSHan-Chung Wang /// %0 = vector.extract %a[0, 0, 0]                 ─┐
370ea1271eSHan-Chung Wang ///        : vector<4xi64> from vector<1x2x3x4xi64>  |
380ea1271eSHan-Chung Wang /// %1 = vector.bitcast %0                           | - Repeated 6x for
390ea1271eSHan-Chung Wang ///        : vector<4xi64> to vector<8xi32>          |   all leading positions
400ea1271eSHan-Chung Wang /// %2 = vector.insert %1, %result [0, 0, 0]         |
410ea1271eSHan-Chung Wang ///        : vector<8xi64> into vector<1x2x3x8xi32> ─┘
420ea1271eSHan-Chung Wang ///
430ea1271eSHan-Chung Wang /// Note: If any leading dimension before the `targetRank` is scalable the
440ea1271eSHan-Chung Wang /// unrolling will stop before the scalable dimension.
450ea1271eSHan-Chung Wang class UnrollBitCastOp final : public OpRewritePattern<vector::BitCastOp> {
460ea1271eSHan-Chung Wang public:
UnrollBitCastOp(int64_t targetRank,MLIRContext * context,PatternBenefit benefit=1)470ea1271eSHan-Chung Wang   UnrollBitCastOp(int64_t targetRank, MLIRContext *context,
480ea1271eSHan-Chung Wang                   PatternBenefit benefit = 1)
490ea1271eSHan-Chung Wang       : OpRewritePattern(context, benefit), targetRank(targetRank) {};
500ea1271eSHan-Chung Wang 
matchAndRewrite(vector::BitCastOp op,PatternRewriter & rewriter) const510ea1271eSHan-Chung Wang   LogicalResult matchAndRewrite(vector::BitCastOp op,
520ea1271eSHan-Chung Wang                                 PatternRewriter &rewriter) const override {
530ea1271eSHan-Chung Wang     VectorType resultType = op.getResultVectorType();
540ea1271eSHan-Chung Wang     auto unrollIterator = vector::createUnrollIterator(resultType, targetRank);
550ea1271eSHan-Chung Wang     if (!unrollIterator)
560ea1271eSHan-Chung Wang       return failure();
570ea1271eSHan-Chung Wang 
58*dc5d5410SBenjamin Maxwell     auto unrollRank = unrollIterator->getRank();
59*dc5d5410SBenjamin Maxwell     ArrayRef<int64_t> shape = resultType.getShape().drop_front(unrollRank);
60*dc5d5410SBenjamin Maxwell     ArrayRef<bool> scalableDims =
61*dc5d5410SBenjamin Maxwell         resultType.getScalableDims().drop_front(unrollRank);
62*dc5d5410SBenjamin Maxwell     auto bitcastResType =
63*dc5d5410SBenjamin Maxwell         VectorType::get(shape, resultType.getElementType(), scalableDims);
640ea1271eSHan-Chung Wang 
650ea1271eSHan-Chung Wang     Location loc = op.getLoc();
660ea1271eSHan-Chung Wang     Value result = rewriter.create<arith::ConstantOp>(
670ea1271eSHan-Chung Wang         loc, resultType, rewriter.getZeroAttr(resultType));
680ea1271eSHan-Chung Wang     for (auto position : *unrollIterator) {
690ea1271eSHan-Chung Wang       Value extract =
700ea1271eSHan-Chung Wang           rewriter.create<vector::ExtractOp>(loc, op.getSource(), position);
710ea1271eSHan-Chung Wang       Value bitcast =
720ea1271eSHan-Chung Wang           rewriter.create<vector::BitCastOp>(loc, bitcastResType, extract);
730ea1271eSHan-Chung Wang       result =
740ea1271eSHan-Chung Wang           rewriter.create<vector::InsertOp>(loc, bitcast, result, position);
750ea1271eSHan-Chung Wang     }
760ea1271eSHan-Chung Wang 
770ea1271eSHan-Chung Wang     rewriter.replaceOp(op, result);
780ea1271eSHan-Chung Wang     return success();
790ea1271eSHan-Chung Wang   }
800ea1271eSHan-Chung Wang 
810ea1271eSHan-Chung Wang private:
820ea1271eSHan-Chung Wang   int64_t targetRank = 1;
830ea1271eSHan-Chung Wang };
840ea1271eSHan-Chung Wang 
850ea1271eSHan-Chung Wang } // namespace
860ea1271eSHan-Chung Wang 
populateVectorBitCastLoweringPatterns(RewritePatternSet & patterns,int64_t targetRank,PatternBenefit benefit)870ea1271eSHan-Chung Wang void mlir::vector::populateVectorBitCastLoweringPatterns(
880ea1271eSHan-Chung Wang     RewritePatternSet &patterns, int64_t targetRank, PatternBenefit benefit) {
890ea1271eSHan-Chung Wang   patterns.add<UnrollBitCastOp>(targetRank, patterns.getContext(), benefit);
900ea1271eSHan-Chung Wang }
91