171c10803SAlexander Belyaev //===- ShapeToStandard.cpp - conversion from Shape to Standard dialect ----===//
23713314bSFrederik Gossen //
33713314bSFrederik Gossen // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
43713314bSFrederik Gossen // See https://llvm.org/LICENSE.txt for license information.
53713314bSFrederik Gossen // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
63713314bSFrederik Gossen //
73713314bSFrederik Gossen //===----------------------------------------------------------------------===//
83713314bSFrederik Gossen
93713314bSFrederik Gossen #include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h"
103713314bSFrederik Gossen
11abc362a1SJakub Kuderski #include "mlir/Dialect/Arith/IR/Arith.h"
1236550692SRiver Riddle #include "mlir/Dialect/Func/IR/FuncOps.h"
138b68da2cSAlex Zinenko #include "mlir/Dialect/SCF/IR/SCF.h"
143713314bSFrederik Gossen #include "mlir/Dialect/Shape/IR/Shape.h"
15444822d7SSean Silva #include "mlir/Dialect/Tensor/IR/Tensor.h"
164d67b278SJeff Niu #include "mlir/IR/IRMapping.h"
17f30f347dSTres Popp #include "mlir/IR/ImplicitLocOpBuilder.h"
1867d0d7acSMichele Scuttari #include "mlir/Pass/Pass.h"
193713314bSFrederik Gossen #include "mlir/Transforms/DialectConversion.h"
20f30f347dSTres Popp #include "llvm/ADT/STLExtras.h"
213713314bSFrederik Gossen
2267d0d7acSMichele Scuttari namespace mlir {
2367d0d7acSMichele Scuttari #define GEN_PASS_DEF_CONVERTSHAPETOSTANDARD
2467d0d7acSMichele Scuttari #include "mlir/Conversion/Passes.h.inc"
2567d0d7acSMichele Scuttari } // namespace mlir
2667d0d7acSMichele Scuttari
2724edbdf9SFrederik Gossen using namespace mlir;
2880be54c0SAlexander Belyaev using namespace mlir::shape;
29a70f2eb3SFrederik Gossen using namespace mlir::scf;
3024edbdf9SFrederik Gossen
313713314bSFrederik Gossen /// Conversion patterns.
324baf18dbSFrederik Gossen namespace {
339df6afbbSFrederik Gossen class AnyOpConversion : public OpConversionPattern<AnyOp> {
349df6afbbSFrederik Gossen public:
359df6afbbSFrederik Gossen using OpConversionPattern<AnyOp>::OpConversionPattern;
369df6afbbSFrederik Gossen
379df6afbbSFrederik Gossen LogicalResult
38b54c724bSRiver Riddle matchAndRewrite(AnyOp op, OpAdaptor adaptor,
394baf18dbSFrederik Gossen ConversionPatternRewriter &rewriter) const override;
404baf18dbSFrederik Gossen };
414baf18dbSFrederik Gossen } // namespace
424baf18dbSFrederik Gossen
434baf18dbSFrederik Gossen LogicalResult
matchAndRewrite(AnyOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const44b54c724bSRiver Riddle AnyOpConversion::matchAndRewrite(AnyOp op, OpAdaptor adaptor,
454baf18dbSFrederik Gossen ConversionPatternRewriter &rewriter) const {
469df6afbbSFrederik Gossen // Replace `any` with its first operand.
479df6afbbSFrederik Gossen // Any operand would be a valid substitution.
48cfb72fd3SJacques Pienaar rewriter.replaceOp(op, {adaptor.getInputs().front()});
499df6afbbSFrederik Gossen return success();
509df6afbbSFrederik Gossen }
519df6afbbSFrederik Gossen
524baf18dbSFrederik Gossen namespace {
5380be54c0SAlexander Belyaev template <typename SrcOpTy, typename DstOpTy>
5480be54c0SAlexander Belyaev class BinaryOpConversion : public OpConversionPattern<SrcOpTy> {
5580be54c0SAlexander Belyaev public:
5680be54c0SAlexander Belyaev using OpConversionPattern<SrcOpTy>::OpConversionPattern;
5780be54c0SAlexander Belyaev
5880be54c0SAlexander Belyaev LogicalResult
matchAndRewrite(SrcOpTy op,typename SrcOpTy::Adaptor adaptor,ConversionPatternRewriter & rewriter) const59b54c724bSRiver Riddle matchAndRewrite(SrcOpTy op, typename SrcOpTy::Adaptor adaptor,
6080be54c0SAlexander Belyaev ConversionPatternRewriter &rewriter) const override {
616673c6cdSFrederik Gossen // For now, only error-free types are supported by this lowering.
625550c821STres Popp if (isa<SizeType>(op.getType()))
636673c6cdSFrederik Gossen return failure();
646673c6cdSFrederik Gossen
65cfb72fd3SJacques Pienaar rewriter.replaceOpWithNewOp<DstOpTy>(op, adaptor.getLhs(),
66cfb72fd3SJacques Pienaar adaptor.getRhs());
6780be54c0SAlexander Belyaev return success();
6880be54c0SAlexander Belyaev }
6980be54c0SAlexander Belyaev };
704baf18dbSFrederik Gossen } // namespace
7180be54c0SAlexander Belyaev
724baf18dbSFrederik Gossen namespace {
73a70f2eb3SFrederik Gossen struct BroadcastOpConverter : public OpConversionPattern<BroadcastOp> {
74a70f2eb3SFrederik Gossen using OpConversionPattern<BroadcastOp>::OpConversionPattern;
755d9f33aaSStephan Herhut
765d9f33aaSStephan Herhut LogicalResult
77b54c724bSRiver Riddle matchAndRewrite(BroadcastOp op, OpAdaptor adaptor,
784baf18dbSFrederik Gossen ConversionPatternRewriter &rewriter) const override;
794baf18dbSFrederik Gossen };
80f30f347dSTres Popp
81f30f347dSTres Popp // Get the resulting extent in a given dimension. This is computed with any
82f30f347dSTres Popp // number of extent tensors and shifted offsets into them.
getBroadcastedDim(ImplicitLocOpBuilder lb,ValueRange extentTensors,ValueRange rankDiffs,Value outputDimension)83f30f347dSTres Popp Value getBroadcastedDim(ImplicitLocOpBuilder lb, ValueRange extentTensors,
84f30f347dSTres Popp ValueRange rankDiffs, Value outputDimension) {
85a54f4eaeSMogball Value one = lb.create<arith::ConstantIndexOp>(1);
86f30f347dSTres Popp Value broadcastedDim = one;
87f30f347dSTres Popp for (auto tup : llvm::zip(extentTensors, rankDiffs)) {
88f30f347dSTres Popp Value shape = std::get<0>(tup);
89f30f347dSTres Popp Value rankDiff = std::get<1>(tup);
90a54f4eaeSMogball Value outOfBounds = lb.create<arith::CmpIOp>(arith::CmpIPredicate::ult,
91a54f4eaeSMogball outputDimension, rankDiff);
92f30f347dSTres Popp Type indexTy = lb.getIndexType();
93f30f347dSTres Popp broadcastedDim =
94f30f347dSTres Popp lb.create<IfOp>(
951125c5c0SFrederik Gossen outOfBounds,
96f30f347dSTres Popp [&](OpBuilder &b, Location loc) {
97f30f347dSTres Popp b.create<scf::YieldOp>(loc, broadcastedDim);
98f30f347dSTres Popp },
99f30f347dSTres Popp [&](OpBuilder &b, Location loc) {
100f30f347dSTres Popp // The broadcasting logic is:
101f30f347dSTres Popp // - if one extent (here we arbitrarily choose the
102f30f347dSTres Popp // extent from the greater-rank operand) is equal to 1,
103f30f347dSTres Popp // then take the extent from the other operand
104f30f347dSTres Popp // - otherwise, take the extent as-is.
105f30f347dSTres Popp // Note that this logic remains correct in the presence
106f30f347dSTres Popp // of dimensions of zero extent.
107a54f4eaeSMogball Value lesserRankOperandDimension = b.create<arith::SubIOp>(
108a54f4eaeSMogball loc, indexTy, outputDimension, rankDiff);
109f30f347dSTres Popp Value lesserRankOperandExtent = b.create<tensor::ExtractOp>(
110f30f347dSTres Popp loc, shape, ValueRange{lesserRankOperandDimension});
111f30f347dSTres Popp
112a54f4eaeSMogball Value dimIsOne =
113a54f4eaeSMogball b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
114f30f347dSTres Popp lesserRankOperandExtent, one);
115dec8af70SRiver Riddle Value dim = b.create<arith::SelectOp>(
116dec8af70SRiver Riddle loc, dimIsOne, broadcastedDim, lesserRankOperandExtent);
117f30f347dSTres Popp b.create<scf::YieldOp>(loc, dim);
118f30f347dSTres Popp })
119f30f347dSTres Popp .getResult(0);
120f30f347dSTres Popp }
121f30f347dSTres Popp return broadcastedDim;
122f30f347dSTres Popp }
1234baf18dbSFrederik Gossen } // namespace
1244baf18dbSFrederik Gossen
matchAndRewrite(BroadcastOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const125a70f2eb3SFrederik Gossen LogicalResult BroadcastOpConverter::matchAndRewrite(
126b54c724bSRiver Riddle BroadcastOp op, OpAdaptor adaptor,
1274baf18dbSFrederik Gossen ConversionPatternRewriter &rewriter) const {
128a70f2eb3SFrederik Gossen // For now, this lowering is only defined on `tensor<?xindex>` operands, not
129a70f2eb3SFrederik Gossen // on shapes.
1305550c821STres Popp if (isa<ShapeType>(op.getType()))
1316673c6cdSFrederik Gossen return failure();
132ac3e5c4dSFrederik Gossen
1336673c6cdSFrederik Gossen auto loc = op.getLoc();
134f30f347dSTres Popp ImplicitLocOpBuilder lb(loc, rewriter);
135ac3e5c4dSFrederik Gossen
136a54f4eaeSMogball Value zero = lb.create<arith::ConstantIndexOp>(0);
137f30f347dSTres Popp Type indexTy = lb.getIndexType();
138a70f2eb3SFrederik Gossen
139f30f347dSTres Popp // Save all the ranks for bounds checking. Because this is a tensor
140f30f347dSTres Popp // representing the shape extents, the rank is the extent of the only
141f30f347dSTres Popp // dimension in the tensor.
142f30f347dSTres Popp SmallVector<Value> ranks, rankDiffs;
143cfb72fd3SJacques Pienaar llvm::append_range(ranks, llvm::map_range(adaptor.getShapes(), [&](Value v) {
144c0a6318dSMatthias Springer return lb.create<tensor::DimOp>(v, zero);
145f30f347dSTres Popp }));
146f30f347dSTres Popp
147f30f347dSTres Popp // Find the maximum rank
148f30f347dSTres Popp Value maxRank = ranks.front();
149f30f347dSTres Popp for (Value v : llvm::drop_begin(ranks, 1)) {
150*d4fd2025Smlevesquedion maxRank = lb.create<arith::MaxUIOp>(v, maxRank);
151f30f347dSTres Popp }
152f30f347dSTres Popp
153f30f347dSTres Popp // Calculate the difference of ranks and the maximum rank for later offsets.
154f30f347dSTres Popp llvm::append_range(rankDiffs, llvm::map_range(ranks, [&](Value v) {
155a54f4eaeSMogball return lb.create<arith::SubIOp>(indexTy, maxRank, v);
156f30f347dSTres Popp }));
157f30f347dSTres Popp
158eb56fa97SFrederik Gossen Value replacement = lb.create<tensor::GenerateOp>(
159f30f347dSTres Popp getExtentTensorType(lb.getContext()), ValueRange{maxRank},
16057211fd2SSean Silva [&](OpBuilder &b, Location loc, ValueRange args) {
161cfb72fd3SJacques Pienaar Value broadcastedDim =
162cfb72fd3SJacques Pienaar getBroadcastedDim(ImplicitLocOpBuilder(loc, b), adaptor.getShapes(),
163cfb72fd3SJacques Pienaar rankDiffs, args[0]);
164f30f347dSTres Popp
165f30f347dSTres Popp b.create<tensor::YieldOp>(loc, broadcastedDim);
166eb56fa97SFrederik Gossen });
167eb56fa97SFrederik Gossen if (replacement.getType() != op.getType())
168eb56fa97SFrederik Gossen replacement = lb.create<tensor::CastOp>(op.getType(), replacement);
169eb56fa97SFrederik Gossen rewriter.replaceOp(op, replacement);
170ac3e5c4dSFrederik Gossen return success();
171ac3e5c4dSFrederik Gossen }
172ac3e5c4dSFrederik Gossen
1734baf18dbSFrederik Gossen namespace {
174dfcc0989SFrederik Gossen class ConstShapeOpConverter : public OpConversionPattern<ConstShapeOp> {
175dfcc0989SFrederik Gossen public:
176dfcc0989SFrederik Gossen using OpConversionPattern<ConstShapeOp>::OpConversionPattern;
177dfcc0989SFrederik Gossen
178dfcc0989SFrederik Gossen LogicalResult
179b54c724bSRiver Riddle matchAndRewrite(ConstShapeOp op, OpAdaptor adaptor,
180dfcc0989SFrederik Gossen ConversionPatternRewriter &rewriter) const override;
181dfcc0989SFrederik Gossen };
182dfcc0989SFrederik Gossen } // namespace
183dfcc0989SFrederik Gossen
matchAndRewrite(ConstShapeOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const184dfcc0989SFrederik Gossen LogicalResult ConstShapeOpConverter::matchAndRewrite(
185b54c724bSRiver Riddle ConstShapeOp op, OpAdaptor adaptor,
186dfcc0989SFrederik Gossen ConversionPatternRewriter &rewriter) const {
187dfcc0989SFrederik Gossen
188dfcc0989SFrederik Gossen // For now, this lowering supports only extent tensors, not `shape.shape`
189dfcc0989SFrederik Gossen // types.
1905550c821STres Popp if (isa<ShapeType>(op.getType()))
191dfcc0989SFrederik Gossen return failure();
192dfcc0989SFrederik Gossen
193dfcc0989SFrederik Gossen auto loc = op.getLoc();
194dfcc0989SFrederik Gossen SmallVector<Value, 4> extentOperands;
195cfb72fd3SJacques Pienaar for (auto extent : op.getShape()) {
196dfcc0989SFrederik Gossen extentOperands.push_back(
197a54f4eaeSMogball rewriter.create<arith::ConstantIndexOp>(loc, extent.getLimitedValue()));
198dfcc0989SFrederik Gossen }
199f77e9f87SAlexander Belyaev Type resultTy =
200f77e9f87SAlexander Belyaev RankedTensorType::get({op.getShape().size()}, rewriter.getIndexType());
20184a6da67SSean Silva Value tensor =
202f77e9f87SAlexander Belyaev rewriter.create<tensor::FromElementsOp>(loc, resultTy, extentOperands);
203129d6e55SSean Silva rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultTy, tensor);
204dfcc0989SFrederik Gossen return success();
205dfcc0989SFrederik Gossen }
206dfcc0989SFrederik Gossen
207dfcc0989SFrederik Gossen namespace {
208a70f2eb3SFrederik Gossen class ConstSizeOpConversion : public OpConversionPattern<ConstSizeOp> {
2095d9f33aaSStephan Herhut public:
210a70f2eb3SFrederik Gossen using OpConversionPattern<ConstSizeOp>::OpConversionPattern;
2115d9f33aaSStephan Herhut
2125d9f33aaSStephan Herhut LogicalResult
213b54c724bSRiver Riddle matchAndRewrite(ConstSizeOp op, OpAdaptor adaptor,
214a70f2eb3SFrederik Gossen ConversionPatternRewriter &rewriter) const override;
215973800dcSDavid Truby };
216973800dcSDavid Truby } // namespace
21715acdd75SFrederik Gossen
matchAndRewrite(ConstSizeOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const218a70f2eb3SFrederik Gossen LogicalResult ConstSizeOpConversion::matchAndRewrite(
219b54c724bSRiver Riddle ConstSizeOp op, OpAdaptor adaptor,
220a70f2eb3SFrederik Gossen ConversionPatternRewriter &rewriter) const {
221a54f4eaeSMogball rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(
222cfb72fd3SJacques Pienaar op, op.getValue().getSExtValue());
223a70f2eb3SFrederik Gossen return success();
224a70f2eb3SFrederik Gossen }
225a70f2eb3SFrederik Gossen
2265d9f33aaSStephan Herhut namespace {
227511484f2STres Popp struct IsBroadcastableOpConverter
228511484f2STres Popp : public OpConversionPattern<IsBroadcastableOp> {
229511484f2STres Popp using OpConversionPattern<IsBroadcastableOp>::OpConversionPattern;
230511484f2STres Popp
231511484f2STres Popp LogicalResult
232b54c724bSRiver Riddle matchAndRewrite(IsBroadcastableOp op, OpAdaptor adaptor,
233511484f2STres Popp ConversionPatternRewriter &rewriter) const override;
234511484f2STres Popp };
235511484f2STres Popp } // namespace
236511484f2STres Popp
matchAndRewrite(IsBroadcastableOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const237511484f2STres Popp LogicalResult IsBroadcastableOpConverter::matchAndRewrite(
238b54c724bSRiver Riddle IsBroadcastableOp op, OpAdaptor adaptor,
239511484f2STres Popp ConversionPatternRewriter &rewriter) const {
240511484f2STres Popp // For now, this lowering is only defined on `tensor<?xindex>` operands, not
241511484f2STres Popp // on shapes.
242cfb72fd3SJacques Pienaar if (!llvm::all_of(op.getShapes(),
2435550c821STres Popp [](Value v) { return !isa<ShapeType>(v.getType()); }))
244511484f2STres Popp return failure();
245511484f2STres Popp
246511484f2STres Popp auto loc = op.getLoc();
2473842d4b6STres Popp ImplicitLocOpBuilder lb(loc, rewriter);
248a54f4eaeSMogball Value zero = lb.create<arith::ConstantIndexOp>(0);
249a54f4eaeSMogball Value one = lb.create<arith::ConstantIndexOp>(1);
2503842d4b6STres Popp Type indexTy = lb.getIndexType();
251511484f2STres Popp
2523842d4b6STres Popp // Save all the ranks for bounds checking. Because this is a tensor
2533842d4b6STres Popp // representing the shape extents, the rank is the extent of the only
2543842d4b6STres Popp // dimension in the tensor.
2553842d4b6STres Popp SmallVector<Value> ranks, rankDiffs;
256cfb72fd3SJacques Pienaar llvm::append_range(ranks, llvm::map_range(adaptor.getShapes(), [&](Value v) {
257c0a6318dSMatthias Springer return lb.create<tensor::DimOp>(v, zero);
2583842d4b6STres Popp }));
2593842d4b6STres Popp
2603842d4b6STres Popp // Find the maximum rank
2613842d4b6STres Popp Value maxRank = ranks.front();
2623842d4b6STres Popp for (Value v : llvm::drop_begin(ranks, 1)) {
263*d4fd2025Smlevesquedion maxRank = lb.create<arith::MaxUIOp>(v, maxRank);
2643842d4b6STres Popp }
2653842d4b6STres Popp
2663842d4b6STres Popp // Calculate the difference of ranks and the maximum rank for later offsets.
2673842d4b6STres Popp llvm::append_range(rankDiffs, llvm::map_range(ranks, [&](Value v) {
268a54f4eaeSMogball return lb.create<arith::SubIOp>(indexTy, maxRank, v);
2693842d4b6STres Popp }));
2703842d4b6STres Popp
271511484f2STres Popp Type i1Ty = rewriter.getI1Type();
2723842d4b6STres Popp Value trueVal =
273a54f4eaeSMogball rewriter.create<arith::ConstantOp>(loc, i1Ty, rewriter.getBoolAttr(true));
274511484f2STres Popp
2753842d4b6STres Popp auto reduceResult = lb.create<ForOp>(
2763842d4b6STres Popp loc, zero, maxRank, one, ValueRange{trueVal},
277511484f2STres Popp [&](OpBuilder &b, Location loc, Value iv, ValueRange iterArgs) {
2783842d4b6STres Popp // Find a non-1 dim, if it exists. Note that the first part of this
2793842d4b6STres Popp // could reuse the Broadcast lowering entirely, but we redo the work
2803842d4b6STres Popp // here to make optimizations easier between the two loops.
2813842d4b6STres Popp Value broadcastedDim = getBroadcastedDim(
282cfb72fd3SJacques Pienaar ImplicitLocOpBuilder(loc, b), adaptor.getShapes(), rankDiffs, iv);
2833842d4b6STres Popp
2843842d4b6STres Popp Value broadcastable = iterArgs[0];
285cfb72fd3SJacques Pienaar for (auto tup : llvm::zip(adaptor.getShapes(), rankDiffs)) {
2863842d4b6STres Popp Value shape, rankDiff;
2873842d4b6STres Popp std::tie(shape, rankDiff) = tup;
288a54f4eaeSMogball Value outOfBounds = b.create<arith::CmpIOp>(
289a54f4eaeSMogball loc, arith::CmpIPredicate::ult, iv, rankDiff);
2903842d4b6STres Popp broadcastable =
2913842d4b6STres Popp b.create<IfOp>(
2921125c5c0SFrederik Gossen loc, outOfBounds,
2933842d4b6STres Popp [&](OpBuilder &b, Location loc) {
2943842d4b6STres Popp // Non existent dimensions are always broadcastable
2953842d4b6STres Popp b.create<scf::YieldOp>(loc, broadcastable);
2963842d4b6STres Popp },
2973842d4b6STres Popp [&](OpBuilder &b, Location loc) {
2983842d4b6STres Popp // Every value needs to be either 1, or the same non-1
2993842d4b6STres Popp // value to be broadcastable in this dim.
3003842d4b6STres Popp Value operandDimension =
301a54f4eaeSMogball b.create<arith::SubIOp>(loc, indexTy, iv, rankDiff);
3023842d4b6STres Popp Value dimensionExtent = b.create<tensor::ExtractOp>(
3033842d4b6STres Popp loc, shape, ValueRange{operandDimension});
3043842d4b6STres Popp
305a54f4eaeSMogball Value equalOne = b.create<arith::CmpIOp>(
306a54f4eaeSMogball loc, arith::CmpIPredicate::eq, dimensionExtent, one);
307a54f4eaeSMogball Value equalBroadcasted = b.create<arith::CmpIOp>(
308a54f4eaeSMogball loc, arith::CmpIPredicate::eq, dimensionExtent,
309a54f4eaeSMogball broadcastedDim);
310a54f4eaeSMogball Value result = b.create<arith::AndIOp>(
3113842d4b6STres Popp loc, broadcastable,
312a54f4eaeSMogball b.create<arith::OrIOp>(loc, equalOne,
313a54f4eaeSMogball equalBroadcasted));
3143842d4b6STres Popp b.create<scf::YieldOp>(loc, result);
3153842d4b6STres Popp })
3163842d4b6STres Popp .getResult(0);
3173842d4b6STres Popp }
3183842d4b6STres Popp
3193842d4b6STres Popp b.create<scf::YieldOp>(loc, broadcastable);
320511484f2STres Popp });
321511484f2STres Popp
322c0342a2dSJacques Pienaar rewriter.replaceOp(op, reduceResult.getResults().front());
323511484f2STres Popp return success();
324511484f2STres Popp }
325511484f2STres Popp
326511484f2STres Popp namespace {
3272f025e0eSJacques Pienaar class DimOpConverter : public OpConversionPattern<DimOp> {
3282f025e0eSJacques Pienaar using OpConversionPattern<DimOp>::OpConversionPattern;
3292f025e0eSJacques Pienaar
3302f025e0eSJacques Pienaar LogicalResult
3312f025e0eSJacques Pienaar matchAndRewrite(DimOp op, OpAdaptor adaptor,
3322f025e0eSJacques Pienaar ConversionPatternRewriter &rewriter) const override;
3332f025e0eSJacques Pienaar };
3342f025e0eSJacques Pienaar } // namespace
3352f025e0eSJacques Pienaar
3362f025e0eSJacques Pienaar LogicalResult
matchAndRewrite(DimOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const3372f025e0eSJacques Pienaar DimOpConverter::matchAndRewrite(DimOp op, OpAdaptor adaptor,
3382f025e0eSJacques Pienaar ConversionPatternRewriter &rewriter) const {
3392f025e0eSJacques Pienaar // Lower to dim(X, i) to get_extent(shape_of(X), i) and rely on further
3402f025e0eSJacques Pienaar // lowerings. This can be further optimized if needed to avoid intermediate
3412f025e0eSJacques Pienaar // steps.
3422f025e0eSJacques Pienaar auto shapeOf = rewriter.create<shape::ShapeOfOp>(op.getLoc(), op.getValue());
3432f025e0eSJacques Pienaar rewriter.replaceOpWithNewOp<shape::GetExtentOp>(op, op.getType(), shapeOf,
344b23e72a2SJacques Pienaar op.getIndex());
3452f025e0eSJacques Pienaar return success();
3462f025e0eSJacques Pienaar }
3472f025e0eSJacques Pienaar
3482f025e0eSJacques Pienaar namespace {
3498577a090SFrederik Gossen class GetExtentOpConverter : public OpConversionPattern<GetExtentOp> {
3508577a090SFrederik Gossen using OpConversionPattern<GetExtentOp>::OpConversionPattern;
3518577a090SFrederik Gossen
3528577a090SFrederik Gossen LogicalResult
353b54c724bSRiver Riddle matchAndRewrite(GetExtentOp op, OpAdaptor adaptor,
3544baf18dbSFrederik Gossen ConversionPatternRewriter &rewriter) const override;
3554baf18dbSFrederik Gossen };
3564baf18dbSFrederik Gossen } // namespace
3574baf18dbSFrederik Gossen
matchAndRewrite(GetExtentOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const3584baf18dbSFrederik Gossen LogicalResult GetExtentOpConverter::matchAndRewrite(
359b54c724bSRiver Riddle GetExtentOp op, OpAdaptor adaptor,
3604baf18dbSFrederik Gossen ConversionPatternRewriter &rewriter) const {
3616673c6cdSFrederik Gossen // For now, only error-free types are supported by this lowering.
3625550c821STres Popp if (isa<SizeType>(op.getType()))
3636673c6cdSFrederik Gossen return failure();
3646673c6cdSFrederik Gossen
3656673c6cdSFrederik Gossen // Derive shape extent directly from shape origin if possible. This
3666673c6cdSFrederik Gossen // circumvents the necessity to materialize the shape in memory.
367cfb72fd3SJacques Pienaar if (auto shapeOfOp = op.getShape().getDefiningOp<ShapeOfOp>()) {
3685550c821STres Popp if (isa<ShapedType>(shapeOfOp.getArg().getType())) {
369cfb72fd3SJacques Pienaar rewriter.replaceOpWithNewOp<tensor::DimOp>(op, shapeOfOp.getArg(),
370cfb72fd3SJacques Pienaar adaptor.getDim());
3718577a090SFrederik Gossen return success();
3728577a090SFrederik Gossen }
3736673c6cdSFrederik Gossen }
3748577a090SFrederik Gossen
375cfb72fd3SJacques Pienaar rewriter.replaceOpWithNewOp<tensor::ExtractOp>(op, rewriter.getIndexType(),
376cfb72fd3SJacques Pienaar adaptor.getShape(),
377cfb72fd3SJacques Pienaar ValueRange{adaptor.getDim()});
3788577a090SFrederik Gossen return success();
3798577a090SFrederik Gossen }
3808577a090SFrederik Gossen
3814baf18dbSFrederik Gossen namespace {
38224debf5aSFrederik Gossen class RankOpConverter : public OpConversionPattern<shape::RankOp> {
38324debf5aSFrederik Gossen public:
38424debf5aSFrederik Gossen using OpConversionPattern<shape::RankOp>::OpConversionPattern;
38524debf5aSFrederik Gossen
38624debf5aSFrederik Gossen LogicalResult
387b54c724bSRiver Riddle matchAndRewrite(shape::RankOp op, OpAdaptor adaptor,
3884baf18dbSFrederik Gossen ConversionPatternRewriter &rewriter) const override;
3894baf18dbSFrederik Gossen };
3904baf18dbSFrederik Gossen } // namespace
3914baf18dbSFrederik Gossen
3924baf18dbSFrederik Gossen LogicalResult
matchAndRewrite(shape::RankOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const393b54c724bSRiver Riddle RankOpConverter::matchAndRewrite(shape::RankOp op, OpAdaptor adaptor,
3944baf18dbSFrederik Gossen ConversionPatternRewriter &rewriter) const {
395a97940d4SFrederik Gossen // For now, this lowering supports only error-free types.
3965550c821STres Popp if (isa<SizeType>(op.getType()))
397a97940d4SFrederik Gossen return failure();
398a97940d4SFrederik Gossen
399cfb72fd3SJacques Pienaar rewriter.replaceOpWithNewOp<tensor::DimOp>(op, adaptor.getShape(), 0);
40024debf5aSFrederik Gossen return success();
40124debf5aSFrederik Gossen }
40224debf5aSFrederik Gossen
4034baf18dbSFrederik Gossen namespace {
404a70f2eb3SFrederik Gossen /// Converts `shape.reduce` to `scf.for`.
405a70f2eb3SFrederik Gossen struct ReduceOpConverter : public OpConversionPattern<shape::ReduceOp> {
406a70f2eb3SFrederik Gossen public:
407a70f2eb3SFrederik Gossen using OpConversionPattern::OpConversionPattern;
408a70f2eb3SFrederik Gossen
409a70f2eb3SFrederik Gossen LogicalResult
410b54c724bSRiver Riddle matchAndRewrite(shape::ReduceOp op, OpAdaptor adaptor,
411a70f2eb3SFrederik Gossen ConversionPatternRewriter &rewriter) const final;
412a70f2eb3SFrederik Gossen };
413a70f2eb3SFrederik Gossen } // namespace
414a70f2eb3SFrederik Gossen
415a70f2eb3SFrederik Gossen LogicalResult
matchAndRewrite(shape::ReduceOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const416b54c724bSRiver Riddle ReduceOpConverter::matchAndRewrite(shape::ReduceOp op, OpAdaptor adaptor,
417a70f2eb3SFrederik Gossen ConversionPatternRewriter &rewriter) const {
418a70f2eb3SFrederik Gossen // For now, this lowering is only defined on `tensor<?xindex>` operands.
4195550c821STres Popp if (isa<ShapeType>(op.getShape().getType()))
420a70f2eb3SFrederik Gossen return failure();
421a70f2eb3SFrederik Gossen
422a70f2eb3SFrederik Gossen auto loc = op.getLoc();
423a70f2eb3SFrederik Gossen
424a54f4eaeSMogball Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
425a54f4eaeSMogball Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
426a70f2eb3SFrederik Gossen Type indexTy = rewriter.getIndexType();
427e2310704SJulian Gross Value rank =
428cfb72fd3SJacques Pienaar rewriter.create<tensor::DimOp>(loc, indexTy, adaptor.getShape(), zero);
429a70f2eb3SFrederik Gossen
430a70f2eb3SFrederik Gossen auto loop = rewriter.create<scf::ForOp>(
431cfb72fd3SJacques Pienaar loc, zero, rank, one, op.getInitVals(),
432a70f2eb3SFrederik Gossen [&](OpBuilder &b, Location loc, Value iv, ValueRange args) {
433cfb72fd3SJacques Pienaar Value extent = b.create<tensor::ExtractOp>(loc, adaptor.getShape(), iv);
434a70f2eb3SFrederik Gossen
435a70f2eb3SFrederik Gossen SmallVector<Value, 2> mappedValues{iv, extent};
436a70f2eb3SFrederik Gossen mappedValues.append(args.begin(), args.end());
437a70f2eb3SFrederik Gossen
4384d67b278SJeff Niu IRMapping mapping;
439a70f2eb3SFrederik Gossen Block *reduceBody = op.getBody();
440a70f2eb3SFrederik Gossen mapping.map(reduceBody->getArguments(), mappedValues);
441a70f2eb3SFrederik Gossen for (auto &nested : reduceBody->without_terminator())
442a70f2eb3SFrederik Gossen b.clone(nested, mapping);
443a70f2eb3SFrederik Gossen
444a70f2eb3SFrederik Gossen SmallVector<Value, 2> mappedResults;
445a70f2eb3SFrederik Gossen for (auto result : reduceBody->getTerminator()->getOperands())
446a70f2eb3SFrederik Gossen mappedResults.push_back(mapping.lookup(result));
447a70f2eb3SFrederik Gossen b.create<scf::YieldOp>(loc, mappedResults);
448a70f2eb3SFrederik Gossen });
449a70f2eb3SFrederik Gossen
450a70f2eb3SFrederik Gossen rewriter.replaceOp(op, loop.getResults());
451a70f2eb3SFrederik Gossen return success();
452a70f2eb3SFrederik Gossen }
453a70f2eb3SFrederik Gossen
454a70f2eb3SFrederik Gossen namespace {
455a70f2eb3SFrederik Gossen /// Converts `shape.shape_eq` to an `scf.for` loop. For now, the lowering is
456a70f2eb3SFrederik Gossen /// only defined on `tensor<?xindex>` operands. The test for equality first
457a70f2eb3SFrederik Gossen /// compares their size and, if equal, checks every extent for equality.
458a70f2eb3SFrederik Gossen ///
459a70f2eb3SFrederik Gossen /// Example:
460a70f2eb3SFrederik Gossen ///
461a70f2eb3SFrederik Gossen /// %result = shape.shape_eq %a, %b : tensor<?xindex>, tensor<?xindex>
462a70f2eb3SFrederik Gossen ///
463a70f2eb3SFrederik Gossen /// becomes
464a70f2eb3SFrederik Gossen ///
465cb3aa49eSMogball /// %c0 = arith.constant 0 : index
466a70f2eb3SFrederik Gossen /// %0 = dim %arg0, %c0 : tensor<?xindex>
467a70f2eb3SFrederik Gossen /// %1 = dim %arg1, %c0 : tensor<?xindex>
468a54f4eaeSMogball /// %2 = arith.cmpi "eq", %0, %1 : index
469a70f2eb3SFrederik Gossen /// %result = scf.if %2 -> (i1) {
470a54f4eaeSMogball /// %c1 = arith.constant 1 : index
471a54f4eaeSMogball /// %true = arith.constant true
472a70f2eb3SFrederik Gossen /// %4 = scf.for %arg2 = %c0 to %0 step %c1 iter_args(%arg3 = %true) -> (i1) {
473444822d7SSean Silva /// %5 = tensor.extract %arg0[%arg2] : tensor<?xindex>
474444822d7SSean Silva /// %6 = tensor.extract %arg1[%arg2] : tensor<?xindex>
475a54f4eaeSMogball /// %7 = arith.cmpi "eq", %5, %6 : index
476a54f4eaeSMogball /// %8 = arith.andi %arg3, %7 : i1
477a70f2eb3SFrederik Gossen /// scf.yield %8 : i1
478a70f2eb3SFrederik Gossen /// }
479a70f2eb3SFrederik Gossen /// scf.yield %4 : i1
480a70f2eb3SFrederik Gossen /// } else {
481a54f4eaeSMogball /// %false = arith.constant false
482a70f2eb3SFrederik Gossen /// scf.yield %false : i1
483a70f2eb3SFrederik Gossen /// }
484a70f2eb3SFrederik Gossen ///
485a70f2eb3SFrederik Gossen struct ShapeEqOpConverter : public OpConversionPattern<ShapeEqOp> {
486a70f2eb3SFrederik Gossen using OpConversionPattern<ShapeEqOp>::OpConversionPattern;
487a70f2eb3SFrederik Gossen
488a70f2eb3SFrederik Gossen LogicalResult
489b54c724bSRiver Riddle matchAndRewrite(ShapeEqOp op, OpAdaptor adaptor,
490a70f2eb3SFrederik Gossen ConversionPatternRewriter &rewriter) const override;
491a70f2eb3SFrederik Gossen };
492a70f2eb3SFrederik Gossen } // namespace
493a70f2eb3SFrederik Gossen
494a70f2eb3SFrederik Gossen LogicalResult
matchAndRewrite(ShapeEqOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const495b54c724bSRiver Riddle ShapeEqOpConverter::matchAndRewrite(ShapeEqOp op, OpAdaptor adaptor,
496a70f2eb3SFrederik Gossen ConversionPatternRewriter &rewriter) const {
497cfb72fd3SJacques Pienaar if (!llvm::all_of(op.getShapes(),
4985550c821STres Popp [](Value v) { return !isa<ShapeType>(v.getType()); }))
499a70f2eb3SFrederik Gossen return failure();
50024acadefSBenjamin Kramer
50124acadefSBenjamin Kramer Type i1Ty = rewriter.getI1Type();
502cfb72fd3SJacques Pienaar if (op.getShapes().size() <= 1) {
503a54f4eaeSMogball rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, i1Ty,
50424acadefSBenjamin Kramer rewriter.getBoolAttr(true));
50524acadefSBenjamin Kramer return success();
506a70f2eb3SFrederik Gossen }
507a70f2eb3SFrederik Gossen
508a70f2eb3SFrederik Gossen auto loc = op.getLoc();
509a70f2eb3SFrederik Gossen Type indexTy = rewriter.getIndexType();
510a54f4eaeSMogball Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
511cfb72fd3SJacques Pienaar Value firstShape = adaptor.getShapes().front();
512e2310704SJulian Gross Value firstRank =
513c0a6318dSMatthias Springer rewriter.create<tensor::DimOp>(loc, indexTy, firstShape, zero);
51424acadefSBenjamin Kramer Value result = nullptr;
51524acadefSBenjamin Kramer // Generate a linear sequence of compares, all with firstShape as lhs.
516cfb72fd3SJacques Pienaar for (Value shape : adaptor.getShapes().drop_front(1)) {
517c0a6318dSMatthias Springer Value rank = rewriter.create<tensor::DimOp>(loc, indexTy, shape, zero);
518a54f4eaeSMogball Value eqRank = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
519a54f4eaeSMogball firstRank, rank);
52024acadefSBenjamin Kramer auto same = rewriter.create<IfOp>(
5211125c5c0SFrederik Gossen loc, eqRank,
522a70f2eb3SFrederik Gossen [&](OpBuilder &b, Location loc) {
523a54f4eaeSMogball Value one = b.create<arith::ConstantIndexOp>(loc, 1);
524a54f4eaeSMogball Value init =
525a54f4eaeSMogball b.create<arith::ConstantOp>(loc, i1Ty, b.getBoolAttr(true));
526a70f2eb3SFrederik Gossen auto loop = b.create<scf::ForOp>(
52724acadefSBenjamin Kramer loc, zero, firstRank, one, ValueRange{init},
528a70f2eb3SFrederik Gossen [&](OpBuilder &b, Location nestedLoc, Value iv, ValueRange args) {
529a70f2eb3SFrederik Gossen Value conj = args[0];
530a70f2eb3SFrederik Gossen Value lhsExtent =
53124acadefSBenjamin Kramer b.create<tensor::ExtractOp>(loc, firstShape, iv);
53224acadefSBenjamin Kramer Value rhsExtent = b.create<tensor::ExtractOp>(loc, shape, iv);
533a54f4eaeSMogball Value eqExtent = b.create<arith::CmpIOp>(
534a54f4eaeSMogball loc, arith::CmpIPredicate::eq, lhsExtent, rhsExtent);
535a54f4eaeSMogball Value conjNext = b.create<arith::AndIOp>(loc, conj, eqExtent);
536a70f2eb3SFrederik Gossen b.create<scf::YieldOp>(loc, ValueRange({conjNext}));
537a70f2eb3SFrederik Gossen });
538a70f2eb3SFrederik Gossen b.create<scf::YieldOp>(loc, loop.getResults());
539a70f2eb3SFrederik Gossen },
540a70f2eb3SFrederik Gossen [&](OpBuilder &b, Location loc) {
541a54f4eaeSMogball Value result =
542a54f4eaeSMogball b.create<arith::ConstantOp>(loc, i1Ty, b.getBoolAttr(false));
543a70f2eb3SFrederik Gossen b.create<scf::YieldOp>(loc, result);
544a70f2eb3SFrederik Gossen });
54524acadefSBenjamin Kramer result = !result ? same.getResult(0)
546a54f4eaeSMogball : rewriter.create<arith::AndIOp>(loc, result,
547a54f4eaeSMogball same.getResult(0));
54824acadefSBenjamin Kramer }
54924acadefSBenjamin Kramer rewriter.replaceOp(op, result);
550a70f2eb3SFrederik Gossen return success();
551a70f2eb3SFrederik Gossen }
552a70f2eb3SFrederik Gossen
553a70f2eb3SFrederik Gossen namespace {
554a70f2eb3SFrederik Gossen class ShapeOfOpConversion : public OpConversionPattern<ShapeOfOp> {
555a70f2eb3SFrederik Gossen public:
556a70f2eb3SFrederik Gossen using OpConversionPattern<ShapeOfOp>::OpConversionPattern;
557a70f2eb3SFrederik Gossen
558a70f2eb3SFrederik Gossen LogicalResult
559b54c724bSRiver Riddle matchAndRewrite(ShapeOfOp op, OpAdaptor adaptor,
560a70f2eb3SFrederik Gossen ConversionPatternRewriter &rewriter) const override;
561a70f2eb3SFrederik Gossen };
562a70f2eb3SFrederik Gossen } // namespace
563a70f2eb3SFrederik Gossen
matchAndRewrite(ShapeOfOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const564a70f2eb3SFrederik Gossen LogicalResult ShapeOfOpConversion::matchAndRewrite(
565b54c724bSRiver Riddle ShapeOfOp op, OpAdaptor adaptor,
566a70f2eb3SFrederik Gossen ConversionPatternRewriter &rewriter) const {
567a70f2eb3SFrederik Gossen
568a70f2eb3SFrederik Gossen // For now, only error-free types are supported by this lowering.
5695550c821STres Popp if (isa<ShapeType>(op.getType()))
570a70f2eb3SFrederik Gossen return failure();
571a70f2eb3SFrederik Gossen
572be7352c0SSean Silva // For ranked tensor arguments, lower to `tensor.from_elements`.
5735106a8b8SFrederik Gossen auto loc = op.getLoc();
574cfb72fd3SJacques Pienaar Value tensor = adaptor.getArg();
575a70f2eb3SFrederik Gossen Type tensorTy = tensor.getType();
5765550c821STres Popp if (isa<RankedTensorType>(tensorTy)) {
577a70f2eb3SFrederik Gossen
578a70f2eb3SFrederik Gossen // Build values for individual extents.
579a70f2eb3SFrederik Gossen SmallVector<Value, 8> extentValues;
5805550c821STres Popp RankedTensorType rankedTensorTy = cast<RankedTensorType>(tensorTy);
581a70f2eb3SFrederik Gossen int64_t rank = rankedTensorTy.getRank();
582a70f2eb3SFrederik Gossen for (int64_t i = 0; i < rank; i++) {
583a70f2eb3SFrederik Gossen if (rankedTensorTy.isDynamicDim(i)) {
584c0a6318dSMatthias Springer Value extent = rewriter.create<tensor::DimOp>(loc, tensor, i);
585a70f2eb3SFrederik Gossen extentValues.push_back(extent);
586a70f2eb3SFrederik Gossen } else {
587a54f4eaeSMogball Value extent = rewriter.create<arith::ConstantIndexOp>(
588a54f4eaeSMogball loc, rankedTensorTy.getDimSize(i));
589a70f2eb3SFrederik Gossen extentValues.push_back(extent);
590a70f2eb3SFrederik Gossen }
591a70f2eb3SFrederik Gossen }
592a70f2eb3SFrederik Gossen
593a70f2eb3SFrederik Gossen // Materialize extent tensor.
594be7352c0SSean Silva Value staticExtentTensor = rewriter.create<tensor::FromElementsOp>(
595f77e9f87SAlexander Belyaev loc, RankedTensorType::get({rank}, rewriter.getIndexType()),
596f77e9f87SAlexander Belyaev extentValues);
597129d6e55SSean Silva rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(),
598129d6e55SSean Silva staticExtentTensor);
599a70f2eb3SFrederik Gossen return success();
600a70f2eb3SFrederik Gossen }
601a70f2eb3SFrederik Gossen
602be7352c0SSean Silva // Lower to `tensor.generate` otherwise.
6035106a8b8SFrederik Gossen auto *ctx = rewriter.getContext();
60415f8f3e2SAlexander Belyaev Value rank = rewriter.create<tensor::RankOp>(loc, tensor);
605be7352c0SSean Silva rewriter.replaceOpWithNewOp<tensor::GenerateOp>(
6065106a8b8SFrederik Gossen op, getExtentTensorType(ctx), ValueRange{rank},
6075106a8b8SFrederik Gossen [&](OpBuilder &b, Location loc, ValueRange args) {
6085106a8b8SFrederik Gossen Value dim = args.front();
609c0a6318dSMatthias Springer Value extent = b.create<tensor::DimOp>(loc, tensor, dim);
610be7352c0SSean Silva b.create<tensor::YieldOp>(loc, extent);
611a70f2eb3SFrederik Gossen });
612a70f2eb3SFrederik Gossen
613a70f2eb3SFrederik Gossen return success();
614a70f2eb3SFrederik Gossen }
615a70f2eb3SFrederik Gossen
616a70f2eb3SFrederik Gossen namespace {
61742c195f0SBenjamin Kramer class SplitAtOpConversion : public OpConversionPattern<SplitAtOp> {
61842c195f0SBenjamin Kramer public:
61942c195f0SBenjamin Kramer using OpConversionPattern<SplitAtOp>::OpConversionPattern;
62042c195f0SBenjamin Kramer
62142c195f0SBenjamin Kramer LogicalResult
622b54c724bSRiver Riddle matchAndRewrite(SplitAtOp op, OpAdaptor adaptor,
62342c195f0SBenjamin Kramer ConversionPatternRewriter &rewriter) const override;
62442c195f0SBenjamin Kramer };
62542c195f0SBenjamin Kramer } // namespace
62642c195f0SBenjamin Kramer
matchAndRewrite(SplitAtOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const62742c195f0SBenjamin Kramer LogicalResult SplitAtOpConversion::matchAndRewrite(
628b54c724bSRiver Riddle SplitAtOp op, OpAdaptor adaptor,
62942c195f0SBenjamin Kramer ConversionPatternRewriter &rewriter) const {
63042c195f0SBenjamin Kramer // Error conditions are not implemented, only lower if all operands and
63142c195f0SBenjamin Kramer // results are extent tensors.
632cfb72fd3SJacques Pienaar if (llvm::any_of(ValueRange{op.getOperand(), op.getHead(), op.getTail()},
6335550c821STres Popp [](Value v) { return isa<ShapeType>(v.getType()); }))
63442c195f0SBenjamin Kramer return failure();
63542c195f0SBenjamin Kramer
63642c195f0SBenjamin Kramer ImplicitLocOpBuilder b(op.getLoc(), rewriter);
637a54f4eaeSMogball Value zero = b.create<arith::ConstantIndexOp>(0);
638cfb72fd3SJacques Pienaar Value rank = b.create<tensor::DimOp>(adaptor.getOperand(), zero);
63942c195f0SBenjamin Kramer
64042c195f0SBenjamin Kramer // index < 0 ? index + rank : index
641cfb72fd3SJacques Pienaar Value originalIndex = adaptor.getIndex();
642a54f4eaeSMogball Value add = b.create<arith::AddIOp>(originalIndex, rank);
64342c195f0SBenjamin Kramer Value indexIsNegative =
644a54f4eaeSMogball b.create<arith::CmpIOp>(arith::CmpIPredicate::slt, originalIndex, zero);
645dec8af70SRiver Riddle Value index = b.create<arith::SelectOp>(indexIsNegative, add, originalIndex);
64642c195f0SBenjamin Kramer
647a54f4eaeSMogball Value one = b.create<arith::ConstantIndexOp>(1);
648060208b4SMatthias Springer Value head =
649cfb72fd3SJacques Pienaar b.create<tensor::ExtractSliceOp>(adaptor.getOperand(), zero, index, one);
650a54f4eaeSMogball Value tailSize = b.create<arith::SubIOp>(rank, index);
651cfb72fd3SJacques Pienaar Value tail = b.create<tensor::ExtractSliceOp>(adaptor.getOperand(), index,
652cfb72fd3SJacques Pienaar tailSize, one);
65342c195f0SBenjamin Kramer rewriter.replaceOp(op, {head, tail});
65442c195f0SBenjamin Kramer return success();
65542c195f0SBenjamin Kramer }
65642c195f0SBenjamin Kramer
65742c195f0SBenjamin Kramer namespace {
658a70f2eb3SFrederik Gossen class ToExtentTensorOpConversion
659a70f2eb3SFrederik Gossen : public OpConversionPattern<ToExtentTensorOp> {
660a70f2eb3SFrederik Gossen public:
661a70f2eb3SFrederik Gossen using OpConversionPattern<ToExtentTensorOp>::OpConversionPattern;
662a70f2eb3SFrederik Gossen
663a70f2eb3SFrederik Gossen LogicalResult
matchAndRewrite(ToExtentTensorOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const664b54c724bSRiver Riddle matchAndRewrite(ToExtentTensorOp op, OpAdaptor adaptor,
665a70f2eb3SFrederik Gossen ConversionPatternRewriter &rewriter) const override {
6665550c821STres Popp if (!isa<RankedTensorType>(adaptor.getInput().getType()))
667a70f2eb3SFrederik Gossen return rewriter.notifyMatchFailure(op, "input needs to be a tensor");
668a70f2eb3SFrederik Gossen
669129d6e55SSean Silva rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(),
670cfb72fd3SJacques Pienaar adaptor.getInput());
671a70f2eb3SFrederik Gossen return success();
672a70f2eb3SFrederik Gossen }
673a70f2eb3SFrederik Gossen };
674a70f2eb3SFrederik Gossen } // namespace
675a70f2eb3SFrederik Gossen
namespace {
/// Import the Shape Ops to Std Patterns.
/// These are the generated patterns added via `populateWithGenerated` in
/// `populateShapeToStandardConversionPatterns`. They presumably come from
/// TableGen DRR definitions; verify against the build rules. The include sits
/// inside an anonymous namespace so the generated definitions stay local to
/// this translation unit.
#include "ShapeToStandard.cpp.inc"
} // namespace
680d05d4219STres Popp
681d05d4219STres Popp namespace {
6823713314bSFrederik Gossen /// Conversion pass.
6833713314bSFrederik Gossen class ConvertShapeToStandardPass
68467d0d7acSMichele Scuttari : public impl::ConvertShapeToStandardBase<ConvertShapeToStandardPass> {
685eaf49130SFrederik Gossen
6864baf18dbSFrederik Gossen void runOnOperation() override;
6874baf18dbSFrederik Gossen };
6884baf18dbSFrederik Gossen } // namespace
6894baf18dbSFrederik Gossen
runOnOperation()6904baf18dbSFrederik Gossen void ConvertShapeToStandardPass::runOnOperation() {
6913713314bSFrederik Gossen // Setup target legality.
692b6b9d3eaSFrederik Gossen MLIRContext &ctx = getContext();
6933713314bSFrederik Gossen ConversionTarget target(ctx);
694abc362a1SJakub Kuderski target.addLegalDialect<arith::ArithDialect, SCFDialect,
6951f971e23SRiver Riddle tensor::TensorDialect>();
69658ceae95SRiver Riddle target.addLegalOp<CstrRequireOp, func::FuncOp, ModuleOp>();
6973713314bSFrederik Gossen
6983713314bSFrederik Gossen // Setup conversion patterns.
699dc4e913bSChris Lattner RewritePatternSet patterns(&ctx);
7003a506b31SChris Lattner populateShapeToStandardConversionPatterns(patterns);
7013713314bSFrederik Gossen
7023713314bSFrederik Gossen // Apply conversion.
7033713314bSFrederik Gossen auto module = getOperation();
7043fffffa8SRiver Riddle if (failed(applyPartialConversion(module, target, std::move(patterns))))
7053713314bSFrederik Gossen signalPassFailure();
7063713314bSFrederik Gossen }
7073713314bSFrederik Gossen
populateShapeToStandardConversionPatterns(RewritePatternSet & patterns)70824edbdf9SFrederik Gossen void mlir::populateShapeToStandardConversionPatterns(
709dc4e913bSChris Lattner RewritePatternSet &patterns) {
7103713314bSFrederik Gossen // clang-format off
7111d909c9aSChris Lattner populateWithGenerated(patterns);
712dc4e913bSChris Lattner patterns.add<
7139df6afbbSFrederik Gossen AnyOpConversion,
714a54f4eaeSMogball BinaryOpConversion<AddOp, arith::AddIOp>,
715a54f4eaeSMogball BinaryOpConversion<MulOp, arith::MulIOp>,
716a70f2eb3SFrederik Gossen BroadcastOpConverter,
717a70f2eb3SFrederik Gossen ConstShapeOpConverter,
7185d9f33aaSStephan Herhut ConstSizeOpConversion,
7192f025e0eSJacques Pienaar DimOpConverter,
720511484f2STres Popp IsBroadcastableOpConverter,
7218577a090SFrederik Gossen GetExtentOpConverter,
72224debf5aSFrederik Gossen RankOpConverter,
723a70f2eb3SFrederik Gossen ReduceOpConverter,
724a70f2eb3SFrederik Gossen ShapeEqOpConverter,
7255d9f33aaSStephan Herhut ShapeOfOpConversion,
72642c195f0SBenjamin Kramer SplitAtOpConversion,
7273a506b31SChris Lattner ToExtentTensorOpConversion>(patterns.getContext());
7283713314bSFrederik Gossen // clang-format on
7293713314bSFrederik Gossen }
7303713314bSFrederik Gossen
73124edbdf9SFrederik Gossen std::unique_ptr<OperationPass<ModuleOp>>
createConvertShapeToStandardPass()73224edbdf9SFrederik Gossen mlir::createConvertShapeToStandardPass() {
7333713314bSFrederik Gossen return std::make_unique<ConvertShapeToStandardPass>();
7343713314bSFrederik Gossen }
735