xref: /llvm-project/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp (revision 35ef3994bf738318b59ce640910fb1ccd3bb7dcb)
1 //===- VectorLinearize.cpp - vector linearization transforms --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements patterns and pass for linearizing ND vectors into 1D.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/Arith/IR/Arith.h"
14 #include "mlir/Dialect/Vector/IR/VectorOps.h"
15 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
16 #include "mlir/IR/PatternMatch.h"
17 #include "mlir/IR/TypeUtilities.h"
18 #include "mlir/Transforms/DialectConversion.h"
19 
20 using namespace mlir;
21 
22 namespace {
23 struct LinearizeConstant final : OpConversionPattern<arith::ConstantOp> {
24   using OpConversionPattern::OpConversionPattern;
25 
26   LogicalResult
27   matchAndRewrite(arith::ConstantOp constOp, OpAdaptor adaptor,
28                   ConversionPatternRewriter &rewriter) const override {
29     Location loc = constOp.getLoc();
30     auto resType =
31         getTypeConverter()->convertType<VectorType>(constOp.getType());
32     if (!resType)
33       return rewriter.notifyMatchFailure(loc, "can't convert return type");
34 
35     auto dstElementsAttr = dyn_cast<DenseElementsAttr>(constOp.getValue());
36     if (!dstElementsAttr)
37       return rewriter.notifyMatchFailure(loc, "unsupported attr type");
38 
39     dstElementsAttr = dstElementsAttr.reshape(resType);
40     rewriter.replaceOpWithNewOp<arith::ConstantOp>(constOp, resType,
41                                                    dstElementsAttr);
42     return success();
43   }
44 };
45 
46 struct LinearizeVectorizable final
47     : OpTraitConversionPattern<OpTrait::Vectorizable> {
48   using OpTraitConversionPattern::OpTraitConversionPattern;
49 
50   LogicalResult
51   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
52                   ConversionPatternRewriter &rewriter) const override {
53     FailureOr<Operation *> newOp =
54         convertOpResultTypes(op, operands, *getTypeConverter(), rewriter);
55     if (failed(newOp))
56       return failure();
57 
58     rewriter.replaceOp(op, (*newOp)->getResults());
59     return success();
60   }
61 };
62 } // namespace
63 
64 void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
65     TypeConverter &typeConverter, RewritePatternSet &patterns,
66     ConversionTarget &target) {
67   typeConverter.addConversion([](VectorType type) -> std::optional<Type> {
68     // Ignore scalable vectors for now.
69     if (type.getRank() <= 1 || type.isScalable())
70       return type;
71 
72     return VectorType::get(type.getNumElements(), type.getElementType());
73   });
74 
75   auto materializeCast = [](OpBuilder &builder, Type type, ValueRange inputs,
76                             Location loc) -> Value {
77     if (inputs.size() != 1 || !isa<VectorType>(inputs.front().getType()) ||
78         !isa<VectorType>(type))
79       return nullptr;
80 
81     return builder.create<vector::ShapeCastOp>(loc, type, inputs.front());
82   };
83   typeConverter.addArgumentMaterialization(materializeCast);
84   typeConverter.addSourceMaterialization(materializeCast);
85   typeConverter.addTargetMaterialization(materializeCast);
86 
87   target.markUnknownOpDynamicallyLegal(
88       [&](Operation *op) -> std::optional<bool> {
89         if (isa<arith::ConstantOp>(op) || op->hasTrait<OpTrait::Vectorizable>())
90           return typeConverter.isLegal(op);
91 
92         return std::nullopt;
93       });
94 
95   patterns.add<LinearizeConstant, LinearizeVectorizable>(typeConverter,
96                                                          patterns.getContext());
97 }
98