10ea1271eSHan-Chung Wang //===- LowerVectorBitCast.cpp - Lower 'vector.bitcast' operation ----------===//
20ea1271eSHan-Chung Wang //
30ea1271eSHan-Chung Wang // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
40ea1271eSHan-Chung Wang // See https://llvm.org/LICENSE.txt for license information.
50ea1271eSHan-Chung Wang // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
60ea1271eSHan-Chung Wang //
70ea1271eSHan-Chung Wang //===----------------------------------------------------------------------===//
80ea1271eSHan-Chung Wang //
90ea1271eSHan-Chung Wang // This file implements target-independent rewrites and utilities to lower the
100ea1271eSHan-Chung Wang // 'vector.bitcast' operation.
110ea1271eSHan-Chung Wang //
120ea1271eSHan-Chung Wang //===----------------------------------------------------------------------===//
130ea1271eSHan-Chung Wang
140ea1271eSHan-Chung Wang #include "mlir/Dialect/Vector/IR/VectorOps.h"
150ea1271eSHan-Chung Wang #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
160ea1271eSHan-Chung Wang #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
170ea1271eSHan-Chung Wang #include "mlir/IR/BuiltinTypes.h"
180ea1271eSHan-Chung Wang #include "mlir/IR/PatternMatch.h"
190ea1271eSHan-Chung Wang
200ea1271eSHan-Chung Wang #define DEBUG_TYPE "vector-bitcast-lowering"
210ea1271eSHan-Chung Wang
220ea1271eSHan-Chung Wang using namespace mlir;
230ea1271eSHan-Chung Wang using namespace mlir::vector;
240ea1271eSHan-Chung Wang
250ea1271eSHan-Chung Wang namespace {
260ea1271eSHan-Chung Wang
270ea1271eSHan-Chung Wang /// A one-shot unrolling of vector.bitcast to the `targetRank`.
280ea1271eSHan-Chung Wang ///
290ea1271eSHan-Chung Wang /// Example:
300ea1271eSHan-Chung Wang ///
310ea1271eSHan-Chung Wang /// vector.bitcast %a, %b : vector<1x2x3x4xi64> to vector<1x2x3x8xi32>
320ea1271eSHan-Chung Wang ///
330ea1271eSHan-Chung Wang /// Would be unrolled to:
340ea1271eSHan-Chung Wang ///
350ea1271eSHan-Chung Wang /// %result = arith.constant dense<0> : vector<1x2x3x8xi32>
360ea1271eSHan-Chung Wang /// %0 = vector.extract %a[0, 0, 0] ─┐
370ea1271eSHan-Chung Wang /// : vector<4xi64> from vector<1x2x3x4xi64> |
380ea1271eSHan-Chung Wang /// %1 = vector.bitcast %0 | - Repeated 6x for
390ea1271eSHan-Chung Wang /// : vector<4xi64> to vector<8xi32> | all leading positions
400ea1271eSHan-Chung Wang /// %2 = vector.insert %1, %result [0, 0, 0] |
410ea1271eSHan-Chung Wang /// : vector<8xi64> into vector<1x2x3x8xi32> ─┘
420ea1271eSHan-Chung Wang ///
430ea1271eSHan-Chung Wang /// Note: If any leading dimension before the `targetRank` is scalable the
440ea1271eSHan-Chung Wang /// unrolling will stop before the scalable dimension.
450ea1271eSHan-Chung Wang class UnrollBitCastOp final : public OpRewritePattern<vector::BitCastOp> {
460ea1271eSHan-Chung Wang public:
UnrollBitCastOp(int64_t targetRank,MLIRContext * context,PatternBenefit benefit=1)470ea1271eSHan-Chung Wang UnrollBitCastOp(int64_t targetRank, MLIRContext *context,
480ea1271eSHan-Chung Wang PatternBenefit benefit = 1)
490ea1271eSHan-Chung Wang : OpRewritePattern(context, benefit), targetRank(targetRank) {};
500ea1271eSHan-Chung Wang
matchAndRewrite(vector::BitCastOp op,PatternRewriter & rewriter) const510ea1271eSHan-Chung Wang LogicalResult matchAndRewrite(vector::BitCastOp op,
520ea1271eSHan-Chung Wang PatternRewriter &rewriter) const override {
530ea1271eSHan-Chung Wang VectorType resultType = op.getResultVectorType();
540ea1271eSHan-Chung Wang auto unrollIterator = vector::createUnrollIterator(resultType, targetRank);
550ea1271eSHan-Chung Wang if (!unrollIterator)
560ea1271eSHan-Chung Wang return failure();
570ea1271eSHan-Chung Wang
58*dc5d5410SBenjamin Maxwell auto unrollRank = unrollIterator->getRank();
59*dc5d5410SBenjamin Maxwell ArrayRef<int64_t> shape = resultType.getShape().drop_front(unrollRank);
60*dc5d5410SBenjamin Maxwell ArrayRef<bool> scalableDims =
61*dc5d5410SBenjamin Maxwell resultType.getScalableDims().drop_front(unrollRank);
62*dc5d5410SBenjamin Maxwell auto bitcastResType =
63*dc5d5410SBenjamin Maxwell VectorType::get(shape, resultType.getElementType(), scalableDims);
640ea1271eSHan-Chung Wang
650ea1271eSHan-Chung Wang Location loc = op.getLoc();
660ea1271eSHan-Chung Wang Value result = rewriter.create<arith::ConstantOp>(
670ea1271eSHan-Chung Wang loc, resultType, rewriter.getZeroAttr(resultType));
680ea1271eSHan-Chung Wang for (auto position : *unrollIterator) {
690ea1271eSHan-Chung Wang Value extract =
700ea1271eSHan-Chung Wang rewriter.create<vector::ExtractOp>(loc, op.getSource(), position);
710ea1271eSHan-Chung Wang Value bitcast =
720ea1271eSHan-Chung Wang rewriter.create<vector::BitCastOp>(loc, bitcastResType, extract);
730ea1271eSHan-Chung Wang result =
740ea1271eSHan-Chung Wang rewriter.create<vector::InsertOp>(loc, bitcast, result, position);
750ea1271eSHan-Chung Wang }
760ea1271eSHan-Chung Wang
770ea1271eSHan-Chung Wang rewriter.replaceOp(op, result);
780ea1271eSHan-Chung Wang return success();
790ea1271eSHan-Chung Wang }
800ea1271eSHan-Chung Wang
810ea1271eSHan-Chung Wang private:
820ea1271eSHan-Chung Wang int64_t targetRank = 1;
830ea1271eSHan-Chung Wang };
840ea1271eSHan-Chung Wang
850ea1271eSHan-Chung Wang } // namespace
860ea1271eSHan-Chung Wang
populateVectorBitCastLoweringPatterns(RewritePatternSet & patterns,int64_t targetRank,PatternBenefit benefit)870ea1271eSHan-Chung Wang void mlir::vector::populateVectorBitCastLoweringPatterns(
880ea1271eSHan-Chung Wang RewritePatternSet &patterns, int64_t targetRank, PatternBenefit benefit) {
890ea1271eSHan-Chung Wang patterns.add<UnrollBitCastOp>(targetRank, patterns.getContext(), benefit);
900ea1271eSHan-Chung Wang }
91