1 //===- VectorLinearize.cpp - vector linearization transforms --------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements patterns and pass for linearizing ND vectors into 1D. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/Arith/IR/Arith.h" 14 #include "mlir/Dialect/Vector/IR/VectorOps.h" 15 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h" 16 #include "mlir/IR/PatternMatch.h" 17 #include "mlir/IR/TypeUtilities.h" 18 #include "mlir/Transforms/DialectConversion.h" 19 20 using namespace mlir; 21 22 namespace { 23 struct LinearizeConstant final : OpConversionPattern<arith::ConstantOp> { 24 using OpConversionPattern::OpConversionPattern; 25 26 LogicalResult 27 matchAndRewrite(arith::ConstantOp constOp, OpAdaptor adaptor, 28 ConversionPatternRewriter &rewriter) const override { 29 Location loc = constOp.getLoc(); 30 auto resType = 31 getTypeConverter()->convertType<VectorType>(constOp.getType()); 32 if (!resType) 33 return rewriter.notifyMatchFailure(loc, "can't convert return type"); 34 35 auto dstElementsAttr = dyn_cast<DenseElementsAttr>(constOp.getValue()); 36 if (!dstElementsAttr) 37 return rewriter.notifyMatchFailure(loc, "unsupported attr type"); 38 39 dstElementsAttr = dstElementsAttr.reshape(resType); 40 rewriter.replaceOpWithNewOp<arith::ConstantOp>(constOp, resType, 41 dstElementsAttr); 42 return success(); 43 } 44 }; 45 46 struct LinearizeVectorizable final 47 : OpTraitConversionPattern<OpTrait::Vectorizable> { 48 using OpTraitConversionPattern::OpTraitConversionPattern; 49 50 LogicalResult 51 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 52 ConversionPatternRewriter &rewriter) const override { 53 FailureOr<Operation *> newOp = 54 convertOpResultTypes(op, operands, *getTypeConverter(), rewriter); 55 if (failed(newOp)) 56 return failure(); 57 58 rewriter.replaceOp(op, (*newOp)->getResults()); 59 return success(); 60 } 61 }; 62 } // namespace 63 64 void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality( 65 TypeConverter &typeConverter, RewritePatternSet &patterns, 66 ConversionTarget &target) { 67 typeConverter.addConversion([](VectorType type) -> std::optional<Type> { 68 // Ignore scalable vectors for now. 69 if (type.getRank() <= 1 || type.isScalable()) 70 return type; 71 72 return VectorType::get(type.getNumElements(), type.getElementType()); 73 }); 74 75 auto materializeCast = [](OpBuilder &builder, Type type, ValueRange inputs, 76 Location loc) -> Value { 77 if (inputs.size() != 1 || !isa<VectorType>(inputs.front().getType()) || 78 !isa<VectorType>(type)) 79 return nullptr; 80 81 return builder.create<vector::ShapeCastOp>(loc, type, inputs.front()); 82 }; 83 typeConverter.addArgumentMaterialization(materializeCast); 84 typeConverter.addSourceMaterialization(materializeCast); 85 typeConverter.addTargetMaterialization(materializeCast); 86 87 target.markUnknownOpDynamicallyLegal( 88 [&](Operation *op) -> std::optional<bool> { 89 if (isa<arith::ConstantOp>(op) || op->hasTrait<OpTrait::Vectorizable>()) 90 return typeConverter.isLegal(op); 91 92 return std::nullopt; 93 }); 94 95 patterns.add<LinearizeConstant, LinearizeVectorizable>(typeConverter, 96 patterns.getContext()); 97 } 98