xref: /llvm-project/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp (revision d4fd20258f63d30be638b04f10eaa469707759f0)
//===- ShapeToStandard.cpp - conversion from Shape to Standard dialect ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
83713314bSFrederik Gossen 
93713314bSFrederik Gossen #include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h"
103713314bSFrederik Gossen 
11abc362a1SJakub Kuderski #include "mlir/Dialect/Arith/IR/Arith.h"
1236550692SRiver Riddle #include "mlir/Dialect/Func/IR/FuncOps.h"
138b68da2cSAlex Zinenko #include "mlir/Dialect/SCF/IR/SCF.h"
143713314bSFrederik Gossen #include "mlir/Dialect/Shape/IR/Shape.h"
15444822d7SSean Silva #include "mlir/Dialect/Tensor/IR/Tensor.h"
164d67b278SJeff Niu #include "mlir/IR/IRMapping.h"
17f30f347dSTres Popp #include "mlir/IR/ImplicitLocOpBuilder.h"
1867d0d7acSMichele Scuttari #include "mlir/Pass/Pass.h"
193713314bSFrederik Gossen #include "mlir/Transforms/DialectConversion.h"
20f30f347dSTres Popp #include "llvm/ADT/STLExtras.h"
213713314bSFrederik Gossen 
2267d0d7acSMichele Scuttari namespace mlir {
2367d0d7acSMichele Scuttari #define GEN_PASS_DEF_CONVERTSHAPETOSTANDARD
2467d0d7acSMichele Scuttari #include "mlir/Conversion/Passes.h.inc"
2567d0d7acSMichele Scuttari } // namespace mlir
2667d0d7acSMichele Scuttari 
2724edbdf9SFrederik Gossen using namespace mlir;
2880be54c0SAlexander Belyaev using namespace mlir::shape;
29a70f2eb3SFrederik Gossen using namespace mlir::scf;
3024edbdf9SFrederik Gossen 
313713314bSFrederik Gossen /// Conversion patterns.
324baf18dbSFrederik Gossen namespace {
339df6afbbSFrederik Gossen class AnyOpConversion : public OpConversionPattern<AnyOp> {
349df6afbbSFrederik Gossen public:
359df6afbbSFrederik Gossen   using OpConversionPattern<AnyOp>::OpConversionPattern;
369df6afbbSFrederik Gossen 
379df6afbbSFrederik Gossen   LogicalResult
38b54c724bSRiver Riddle   matchAndRewrite(AnyOp op, OpAdaptor adaptor,
394baf18dbSFrederik Gossen                   ConversionPatternRewriter &rewriter) const override;
404baf18dbSFrederik Gossen };
414baf18dbSFrederik Gossen } // namespace
424baf18dbSFrederik Gossen 
434baf18dbSFrederik Gossen LogicalResult
matchAndRewrite(AnyOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const44b54c724bSRiver Riddle AnyOpConversion::matchAndRewrite(AnyOp op, OpAdaptor adaptor,
454baf18dbSFrederik Gossen                                  ConversionPatternRewriter &rewriter) const {
469df6afbbSFrederik Gossen   // Replace `any` with its first operand.
479df6afbbSFrederik Gossen   // Any operand would be a valid substitution.
48cfb72fd3SJacques Pienaar   rewriter.replaceOp(op, {adaptor.getInputs().front()});
499df6afbbSFrederik Gossen   return success();
509df6afbbSFrederik Gossen }
519df6afbbSFrederik Gossen 
524baf18dbSFrederik Gossen namespace {
5380be54c0SAlexander Belyaev template <typename SrcOpTy, typename DstOpTy>
5480be54c0SAlexander Belyaev class BinaryOpConversion : public OpConversionPattern<SrcOpTy> {
5580be54c0SAlexander Belyaev public:
5680be54c0SAlexander Belyaev   using OpConversionPattern<SrcOpTy>::OpConversionPattern;
5780be54c0SAlexander Belyaev 
5880be54c0SAlexander Belyaev   LogicalResult
matchAndRewrite(SrcOpTy op,typename SrcOpTy::Adaptor adaptor,ConversionPatternRewriter & rewriter) const59b54c724bSRiver Riddle   matchAndRewrite(SrcOpTy op, typename SrcOpTy::Adaptor adaptor,
6080be54c0SAlexander Belyaev                   ConversionPatternRewriter &rewriter) const override {
616673c6cdSFrederik Gossen     // For now, only error-free types are supported by this lowering.
625550c821STres Popp     if (isa<SizeType>(op.getType()))
636673c6cdSFrederik Gossen       return failure();
646673c6cdSFrederik Gossen 
65cfb72fd3SJacques Pienaar     rewriter.replaceOpWithNewOp<DstOpTy>(op, adaptor.getLhs(),
66cfb72fd3SJacques Pienaar                                          adaptor.getRhs());
6780be54c0SAlexander Belyaev     return success();
6880be54c0SAlexander Belyaev   }
6980be54c0SAlexander Belyaev };
704baf18dbSFrederik Gossen } // namespace
7180be54c0SAlexander Belyaev 
724baf18dbSFrederik Gossen namespace {
73a70f2eb3SFrederik Gossen struct BroadcastOpConverter : public OpConversionPattern<BroadcastOp> {
74a70f2eb3SFrederik Gossen   using OpConversionPattern<BroadcastOp>::OpConversionPattern;
755d9f33aaSStephan Herhut 
765d9f33aaSStephan Herhut   LogicalResult
77b54c724bSRiver Riddle   matchAndRewrite(BroadcastOp op, OpAdaptor adaptor,
784baf18dbSFrederik Gossen                   ConversionPatternRewriter &rewriter) const override;
794baf18dbSFrederik Gossen };
80f30f347dSTres Popp 
81f30f347dSTres Popp // Get the resulting extent in a given dimension. This is computed with any
82f30f347dSTres Popp // number of extent tensors and shifted offsets into them.
getBroadcastedDim(ImplicitLocOpBuilder lb,ValueRange extentTensors,ValueRange rankDiffs,Value outputDimension)83f30f347dSTres Popp Value getBroadcastedDim(ImplicitLocOpBuilder lb, ValueRange extentTensors,
84f30f347dSTres Popp                         ValueRange rankDiffs, Value outputDimension) {
85a54f4eaeSMogball   Value one = lb.create<arith::ConstantIndexOp>(1);
86f30f347dSTres Popp   Value broadcastedDim = one;
87f30f347dSTres Popp   for (auto tup : llvm::zip(extentTensors, rankDiffs)) {
88f30f347dSTres Popp     Value shape = std::get<0>(tup);
89f30f347dSTres Popp     Value rankDiff = std::get<1>(tup);
90a54f4eaeSMogball     Value outOfBounds = lb.create<arith::CmpIOp>(arith::CmpIPredicate::ult,
91a54f4eaeSMogball                                                  outputDimension, rankDiff);
92f30f347dSTres Popp     Type indexTy = lb.getIndexType();
93f30f347dSTres Popp     broadcastedDim =
94f30f347dSTres Popp         lb.create<IfOp>(
951125c5c0SFrederik Gossen               outOfBounds,
96f30f347dSTres Popp               [&](OpBuilder &b, Location loc) {
97f30f347dSTres Popp                 b.create<scf::YieldOp>(loc, broadcastedDim);
98f30f347dSTres Popp               },
99f30f347dSTres Popp               [&](OpBuilder &b, Location loc) {
100f30f347dSTres Popp                 // The broadcasting logic is:
101f30f347dSTres Popp                 // - if one extent (here we arbitrarily choose the
102f30f347dSTres Popp                 // extent from the greater-rank operand) is equal to 1,
103f30f347dSTres Popp                 // then take the extent from the other operand
104f30f347dSTres Popp                 // - otherwise, take the extent as-is.
105f30f347dSTres Popp                 // Note that this logic remains correct in the presence
106f30f347dSTres Popp                 // of dimensions of zero extent.
107a54f4eaeSMogball                 Value lesserRankOperandDimension = b.create<arith::SubIOp>(
108a54f4eaeSMogball                     loc, indexTy, outputDimension, rankDiff);
109f30f347dSTres Popp                 Value lesserRankOperandExtent = b.create<tensor::ExtractOp>(
110f30f347dSTres Popp                     loc, shape, ValueRange{lesserRankOperandDimension});
111f30f347dSTres Popp 
112a54f4eaeSMogball                 Value dimIsOne =
113a54f4eaeSMogball                     b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
114f30f347dSTres Popp                                             lesserRankOperandExtent, one);
115dec8af70SRiver Riddle                 Value dim = b.create<arith::SelectOp>(
116dec8af70SRiver Riddle                     loc, dimIsOne, broadcastedDim, lesserRankOperandExtent);
117f30f347dSTres Popp                 b.create<scf::YieldOp>(loc, dim);
118f30f347dSTres Popp               })
119f30f347dSTres Popp             .getResult(0);
120f30f347dSTres Popp   }
121f30f347dSTres Popp   return broadcastedDim;
122f30f347dSTres Popp }
1234baf18dbSFrederik Gossen } // namespace
1244baf18dbSFrederik Gossen 
matchAndRewrite(BroadcastOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const125a70f2eb3SFrederik Gossen LogicalResult BroadcastOpConverter::matchAndRewrite(
126b54c724bSRiver Riddle     BroadcastOp op, OpAdaptor adaptor,
1274baf18dbSFrederik Gossen     ConversionPatternRewriter &rewriter) const {
128a70f2eb3SFrederik Gossen   // For now, this lowering is only defined on `tensor<?xindex>` operands, not
129a70f2eb3SFrederik Gossen   // on shapes.
1305550c821STres Popp   if (isa<ShapeType>(op.getType()))
1316673c6cdSFrederik Gossen     return failure();
132ac3e5c4dSFrederik Gossen 
1336673c6cdSFrederik Gossen   auto loc = op.getLoc();
134f30f347dSTres Popp   ImplicitLocOpBuilder lb(loc, rewriter);
135ac3e5c4dSFrederik Gossen 
136a54f4eaeSMogball   Value zero = lb.create<arith::ConstantIndexOp>(0);
137f30f347dSTres Popp   Type indexTy = lb.getIndexType();
138a70f2eb3SFrederik Gossen 
139f30f347dSTres Popp   // Save all the ranks for bounds checking. Because this is a tensor
140f30f347dSTres Popp   // representing the shape extents, the rank is the extent of the only
141f30f347dSTres Popp   // dimension in the tensor.
142f30f347dSTres Popp   SmallVector<Value> ranks, rankDiffs;
143cfb72fd3SJacques Pienaar   llvm::append_range(ranks, llvm::map_range(adaptor.getShapes(), [&](Value v) {
144c0a6318dSMatthias Springer                        return lb.create<tensor::DimOp>(v, zero);
145f30f347dSTres Popp                      }));
146f30f347dSTres Popp 
147f30f347dSTres Popp   // Find the maximum rank
148f30f347dSTres Popp   Value maxRank = ranks.front();
149f30f347dSTres Popp   for (Value v : llvm::drop_begin(ranks, 1)) {
150*d4fd2025Smlevesquedion     maxRank = lb.create<arith::MaxUIOp>(v, maxRank);
151f30f347dSTres Popp   }
152f30f347dSTres Popp 
153f30f347dSTres Popp   // Calculate the difference of ranks and the maximum rank for later offsets.
154f30f347dSTres Popp   llvm::append_range(rankDiffs, llvm::map_range(ranks, [&](Value v) {
155a54f4eaeSMogball                        return lb.create<arith::SubIOp>(indexTy, maxRank, v);
156f30f347dSTres Popp                      }));
157f30f347dSTres Popp 
158eb56fa97SFrederik Gossen   Value replacement = lb.create<tensor::GenerateOp>(
159f30f347dSTres Popp       getExtentTensorType(lb.getContext()), ValueRange{maxRank},
16057211fd2SSean Silva       [&](OpBuilder &b, Location loc, ValueRange args) {
161cfb72fd3SJacques Pienaar         Value broadcastedDim =
162cfb72fd3SJacques Pienaar             getBroadcastedDim(ImplicitLocOpBuilder(loc, b), adaptor.getShapes(),
163cfb72fd3SJacques Pienaar                               rankDiffs, args[0]);
164f30f347dSTres Popp 
165f30f347dSTres Popp         b.create<tensor::YieldOp>(loc, broadcastedDim);
166eb56fa97SFrederik Gossen       });
167eb56fa97SFrederik Gossen   if (replacement.getType() != op.getType())
168eb56fa97SFrederik Gossen     replacement = lb.create<tensor::CastOp>(op.getType(), replacement);
169eb56fa97SFrederik Gossen   rewriter.replaceOp(op, replacement);
170ac3e5c4dSFrederik Gossen   return success();
171ac3e5c4dSFrederik Gossen }
172ac3e5c4dSFrederik Gossen 
1734baf18dbSFrederik Gossen namespace {
174dfcc0989SFrederik Gossen class ConstShapeOpConverter : public OpConversionPattern<ConstShapeOp> {
175dfcc0989SFrederik Gossen public:
176dfcc0989SFrederik Gossen   using OpConversionPattern<ConstShapeOp>::OpConversionPattern;
177dfcc0989SFrederik Gossen 
178dfcc0989SFrederik Gossen   LogicalResult
179b54c724bSRiver Riddle   matchAndRewrite(ConstShapeOp op, OpAdaptor adaptor,
180dfcc0989SFrederik Gossen                   ConversionPatternRewriter &rewriter) const override;
181dfcc0989SFrederik Gossen };
182dfcc0989SFrederik Gossen } // namespace
183dfcc0989SFrederik Gossen 
matchAndRewrite(ConstShapeOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const184dfcc0989SFrederik Gossen LogicalResult ConstShapeOpConverter::matchAndRewrite(
185b54c724bSRiver Riddle     ConstShapeOp op, OpAdaptor adaptor,
186dfcc0989SFrederik Gossen     ConversionPatternRewriter &rewriter) const {
187dfcc0989SFrederik Gossen 
188dfcc0989SFrederik Gossen   // For now, this lowering supports only extent tensors, not `shape.shape`
189dfcc0989SFrederik Gossen   // types.
1905550c821STres Popp   if (isa<ShapeType>(op.getType()))
191dfcc0989SFrederik Gossen     return failure();
192dfcc0989SFrederik Gossen 
193dfcc0989SFrederik Gossen   auto loc = op.getLoc();
194dfcc0989SFrederik Gossen   SmallVector<Value, 4> extentOperands;
195cfb72fd3SJacques Pienaar   for (auto extent : op.getShape()) {
196dfcc0989SFrederik Gossen     extentOperands.push_back(
197a54f4eaeSMogball         rewriter.create<arith::ConstantIndexOp>(loc, extent.getLimitedValue()));
198dfcc0989SFrederik Gossen   }
199f77e9f87SAlexander Belyaev   Type resultTy =
200f77e9f87SAlexander Belyaev       RankedTensorType::get({op.getShape().size()}, rewriter.getIndexType());
20184a6da67SSean Silva   Value tensor =
202f77e9f87SAlexander Belyaev       rewriter.create<tensor::FromElementsOp>(loc, resultTy, extentOperands);
203129d6e55SSean Silva   rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultTy, tensor);
204dfcc0989SFrederik Gossen   return success();
205dfcc0989SFrederik Gossen }
206dfcc0989SFrederik Gossen 
207dfcc0989SFrederik Gossen namespace {
208a70f2eb3SFrederik Gossen class ConstSizeOpConversion : public OpConversionPattern<ConstSizeOp> {
2095d9f33aaSStephan Herhut public:
210a70f2eb3SFrederik Gossen   using OpConversionPattern<ConstSizeOp>::OpConversionPattern;
2115d9f33aaSStephan Herhut 
2125d9f33aaSStephan Herhut   LogicalResult
213b54c724bSRiver Riddle   matchAndRewrite(ConstSizeOp op, OpAdaptor adaptor,
214a70f2eb3SFrederik Gossen                   ConversionPatternRewriter &rewriter) const override;
215973800dcSDavid Truby };
216973800dcSDavid Truby } // namespace
21715acdd75SFrederik Gossen 
matchAndRewrite(ConstSizeOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const218a70f2eb3SFrederik Gossen LogicalResult ConstSizeOpConversion::matchAndRewrite(
219b54c724bSRiver Riddle     ConstSizeOp op, OpAdaptor adaptor,
220a70f2eb3SFrederik Gossen     ConversionPatternRewriter &rewriter) const {
221a54f4eaeSMogball   rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(
222cfb72fd3SJacques Pienaar       op, op.getValue().getSExtValue());
223a70f2eb3SFrederik Gossen   return success();
224a70f2eb3SFrederik Gossen }
225a70f2eb3SFrederik Gossen 
2265d9f33aaSStephan Herhut namespace {
227511484f2STres Popp struct IsBroadcastableOpConverter
228511484f2STres Popp     : public OpConversionPattern<IsBroadcastableOp> {
229511484f2STres Popp   using OpConversionPattern<IsBroadcastableOp>::OpConversionPattern;
230511484f2STres Popp 
231511484f2STres Popp   LogicalResult
232b54c724bSRiver Riddle   matchAndRewrite(IsBroadcastableOp op, OpAdaptor adaptor,
233511484f2STres Popp                   ConversionPatternRewriter &rewriter) const override;
234511484f2STres Popp };
235511484f2STres Popp } // namespace
236511484f2STres Popp 
matchAndRewrite(IsBroadcastableOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const237511484f2STres Popp LogicalResult IsBroadcastableOpConverter::matchAndRewrite(
238b54c724bSRiver Riddle     IsBroadcastableOp op, OpAdaptor adaptor,
239511484f2STres Popp     ConversionPatternRewriter &rewriter) const {
240511484f2STres Popp   // For now, this lowering is only defined on `tensor<?xindex>` operands, not
241511484f2STres Popp   // on shapes.
242cfb72fd3SJacques Pienaar   if (!llvm::all_of(op.getShapes(),
2435550c821STres Popp                     [](Value v) { return !isa<ShapeType>(v.getType()); }))
244511484f2STres Popp     return failure();
245511484f2STres Popp 
246511484f2STres Popp   auto loc = op.getLoc();
2473842d4b6STres Popp   ImplicitLocOpBuilder lb(loc, rewriter);
248a54f4eaeSMogball   Value zero = lb.create<arith::ConstantIndexOp>(0);
249a54f4eaeSMogball   Value one = lb.create<arith::ConstantIndexOp>(1);
2503842d4b6STres Popp   Type indexTy = lb.getIndexType();
251511484f2STres Popp 
2523842d4b6STres Popp   // Save all the ranks for bounds checking. Because this is a tensor
2533842d4b6STres Popp   // representing the shape extents, the rank is the extent of the only
2543842d4b6STres Popp   // dimension in the tensor.
2553842d4b6STres Popp   SmallVector<Value> ranks, rankDiffs;
256cfb72fd3SJacques Pienaar   llvm::append_range(ranks, llvm::map_range(adaptor.getShapes(), [&](Value v) {
257c0a6318dSMatthias Springer                        return lb.create<tensor::DimOp>(v, zero);
2583842d4b6STres Popp                      }));
2593842d4b6STres Popp 
2603842d4b6STres Popp   // Find the maximum rank
2613842d4b6STres Popp   Value maxRank = ranks.front();
2623842d4b6STres Popp   for (Value v : llvm::drop_begin(ranks, 1)) {
263*d4fd2025Smlevesquedion     maxRank = lb.create<arith::MaxUIOp>(v, maxRank);
2643842d4b6STres Popp   }
2653842d4b6STres Popp 
2663842d4b6STres Popp   // Calculate the difference of ranks and the maximum rank for later offsets.
2673842d4b6STres Popp   llvm::append_range(rankDiffs, llvm::map_range(ranks, [&](Value v) {
268a54f4eaeSMogball                        return lb.create<arith::SubIOp>(indexTy, maxRank, v);
2693842d4b6STres Popp                      }));
2703842d4b6STres Popp 
271511484f2STres Popp   Type i1Ty = rewriter.getI1Type();
2723842d4b6STres Popp   Value trueVal =
273a54f4eaeSMogball       rewriter.create<arith::ConstantOp>(loc, i1Ty, rewriter.getBoolAttr(true));
274511484f2STres Popp 
2753842d4b6STres Popp   auto reduceResult = lb.create<ForOp>(
2763842d4b6STres Popp       loc, zero, maxRank, one, ValueRange{trueVal},
277511484f2STres Popp       [&](OpBuilder &b, Location loc, Value iv, ValueRange iterArgs) {
2783842d4b6STres Popp         // Find a non-1 dim, if it exists. Note that the first part of this
2793842d4b6STres Popp         // could reuse the Broadcast lowering entirely, but we redo the work
2803842d4b6STres Popp         // here to make optimizations easier between the two loops.
2813842d4b6STres Popp         Value broadcastedDim = getBroadcastedDim(
282cfb72fd3SJacques Pienaar             ImplicitLocOpBuilder(loc, b), adaptor.getShapes(), rankDiffs, iv);
2833842d4b6STres Popp 
2843842d4b6STres Popp         Value broadcastable = iterArgs[0];
285cfb72fd3SJacques Pienaar         for (auto tup : llvm::zip(adaptor.getShapes(), rankDiffs)) {
2863842d4b6STres Popp           Value shape, rankDiff;
2873842d4b6STres Popp           std::tie(shape, rankDiff) = tup;
288a54f4eaeSMogball           Value outOfBounds = b.create<arith::CmpIOp>(
289a54f4eaeSMogball               loc, arith::CmpIPredicate::ult, iv, rankDiff);
2903842d4b6STres Popp           broadcastable =
2913842d4b6STres Popp               b.create<IfOp>(
2921125c5c0SFrederik Gossen                    loc, outOfBounds,
2933842d4b6STres Popp                    [&](OpBuilder &b, Location loc) {
2943842d4b6STres Popp                      // Non existent dimensions are always broadcastable
2953842d4b6STres Popp                      b.create<scf::YieldOp>(loc, broadcastable);
2963842d4b6STres Popp                    },
2973842d4b6STres Popp                    [&](OpBuilder &b, Location loc) {
2983842d4b6STres Popp                      // Every value needs to be either 1, or the same non-1
2993842d4b6STres Popp                      // value to be broadcastable in this dim.
3003842d4b6STres Popp                      Value operandDimension =
301a54f4eaeSMogball                          b.create<arith::SubIOp>(loc, indexTy, iv, rankDiff);
3023842d4b6STres Popp                      Value dimensionExtent = b.create<tensor::ExtractOp>(
3033842d4b6STres Popp                          loc, shape, ValueRange{operandDimension});
3043842d4b6STres Popp 
305a54f4eaeSMogball                      Value equalOne = b.create<arith::CmpIOp>(
306a54f4eaeSMogball                          loc, arith::CmpIPredicate::eq, dimensionExtent, one);
307a54f4eaeSMogball                      Value equalBroadcasted = b.create<arith::CmpIOp>(
308a54f4eaeSMogball                          loc, arith::CmpIPredicate::eq, dimensionExtent,
309a54f4eaeSMogball                          broadcastedDim);
310a54f4eaeSMogball                      Value result = b.create<arith::AndIOp>(
3113842d4b6STres Popp                          loc, broadcastable,
312a54f4eaeSMogball                          b.create<arith::OrIOp>(loc, equalOne,
313a54f4eaeSMogball                                                 equalBroadcasted));
3143842d4b6STres Popp                      b.create<scf::YieldOp>(loc, result);
3153842d4b6STres Popp                    })
3163842d4b6STres Popp                   .getResult(0);
3173842d4b6STres Popp         }
3183842d4b6STres Popp 
3193842d4b6STres Popp         b.create<scf::YieldOp>(loc, broadcastable);
320511484f2STres Popp       });
321511484f2STres Popp 
322c0342a2dSJacques Pienaar   rewriter.replaceOp(op, reduceResult.getResults().front());
323511484f2STres Popp   return success();
324511484f2STres Popp }
325511484f2STres Popp 
326511484f2STres Popp namespace {
3272f025e0eSJacques Pienaar class DimOpConverter : public OpConversionPattern<DimOp> {
3282f025e0eSJacques Pienaar   using OpConversionPattern<DimOp>::OpConversionPattern;
3292f025e0eSJacques Pienaar 
3302f025e0eSJacques Pienaar   LogicalResult
3312f025e0eSJacques Pienaar   matchAndRewrite(DimOp op, OpAdaptor adaptor,
3322f025e0eSJacques Pienaar                   ConversionPatternRewriter &rewriter) const override;
3332f025e0eSJacques Pienaar };
3342f025e0eSJacques Pienaar } // namespace
3352f025e0eSJacques Pienaar 
3362f025e0eSJacques Pienaar LogicalResult
matchAndRewrite(DimOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const3372f025e0eSJacques Pienaar DimOpConverter::matchAndRewrite(DimOp op, OpAdaptor adaptor,
3382f025e0eSJacques Pienaar                                 ConversionPatternRewriter &rewriter) const {
3392f025e0eSJacques Pienaar   // Lower to dim(X, i) to get_extent(shape_of(X), i) and rely on further
3402f025e0eSJacques Pienaar   // lowerings. This can be further optimized if needed to avoid intermediate
3412f025e0eSJacques Pienaar   // steps.
3422f025e0eSJacques Pienaar   auto shapeOf = rewriter.create<shape::ShapeOfOp>(op.getLoc(), op.getValue());
3432f025e0eSJacques Pienaar   rewriter.replaceOpWithNewOp<shape::GetExtentOp>(op, op.getType(), shapeOf,
344b23e72a2SJacques Pienaar                                                   op.getIndex());
3452f025e0eSJacques Pienaar   return success();
3462f025e0eSJacques Pienaar }
3472f025e0eSJacques Pienaar 
3482f025e0eSJacques Pienaar namespace {
3498577a090SFrederik Gossen class GetExtentOpConverter : public OpConversionPattern<GetExtentOp> {
3508577a090SFrederik Gossen   using OpConversionPattern<GetExtentOp>::OpConversionPattern;
3518577a090SFrederik Gossen 
3528577a090SFrederik Gossen   LogicalResult
353b54c724bSRiver Riddle   matchAndRewrite(GetExtentOp op, OpAdaptor adaptor,
3544baf18dbSFrederik Gossen                   ConversionPatternRewriter &rewriter) const override;
3554baf18dbSFrederik Gossen };
3564baf18dbSFrederik Gossen } // namespace
3574baf18dbSFrederik Gossen 
matchAndRewrite(GetExtentOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const3584baf18dbSFrederik Gossen LogicalResult GetExtentOpConverter::matchAndRewrite(
359b54c724bSRiver Riddle     GetExtentOp op, OpAdaptor adaptor,
3604baf18dbSFrederik Gossen     ConversionPatternRewriter &rewriter) const {
3616673c6cdSFrederik Gossen   // For now, only error-free types are supported by this lowering.
3625550c821STres Popp   if (isa<SizeType>(op.getType()))
3636673c6cdSFrederik Gossen     return failure();
3646673c6cdSFrederik Gossen 
3656673c6cdSFrederik Gossen   // Derive shape extent directly from shape origin if possible. This
3666673c6cdSFrederik Gossen   // circumvents the necessity to materialize the shape in memory.
367cfb72fd3SJacques Pienaar   if (auto shapeOfOp = op.getShape().getDefiningOp<ShapeOfOp>()) {
3685550c821STres Popp     if (isa<ShapedType>(shapeOfOp.getArg().getType())) {
369cfb72fd3SJacques Pienaar       rewriter.replaceOpWithNewOp<tensor::DimOp>(op, shapeOfOp.getArg(),
370cfb72fd3SJacques Pienaar                                                  adaptor.getDim());
3718577a090SFrederik Gossen       return success();
3728577a090SFrederik Gossen     }
3736673c6cdSFrederik Gossen   }
3748577a090SFrederik Gossen 
375cfb72fd3SJacques Pienaar   rewriter.replaceOpWithNewOp<tensor::ExtractOp>(op, rewriter.getIndexType(),
376cfb72fd3SJacques Pienaar                                                  adaptor.getShape(),
377cfb72fd3SJacques Pienaar                                                  ValueRange{adaptor.getDim()});
3788577a090SFrederik Gossen   return success();
3798577a090SFrederik Gossen }
3808577a090SFrederik Gossen 
3814baf18dbSFrederik Gossen namespace {
38224debf5aSFrederik Gossen class RankOpConverter : public OpConversionPattern<shape::RankOp> {
38324debf5aSFrederik Gossen public:
38424debf5aSFrederik Gossen   using OpConversionPattern<shape::RankOp>::OpConversionPattern;
38524debf5aSFrederik Gossen 
38624debf5aSFrederik Gossen   LogicalResult
387b54c724bSRiver Riddle   matchAndRewrite(shape::RankOp op, OpAdaptor adaptor,
3884baf18dbSFrederik Gossen                   ConversionPatternRewriter &rewriter) const override;
3894baf18dbSFrederik Gossen };
3904baf18dbSFrederik Gossen } // namespace
3914baf18dbSFrederik Gossen 
3924baf18dbSFrederik Gossen LogicalResult
matchAndRewrite(shape::RankOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const393b54c724bSRiver Riddle RankOpConverter::matchAndRewrite(shape::RankOp op, OpAdaptor adaptor,
3944baf18dbSFrederik Gossen                                  ConversionPatternRewriter &rewriter) const {
395a97940d4SFrederik Gossen   // For now, this lowering supports only error-free types.
3965550c821STres Popp   if (isa<SizeType>(op.getType()))
397a97940d4SFrederik Gossen     return failure();
398a97940d4SFrederik Gossen 
399cfb72fd3SJacques Pienaar   rewriter.replaceOpWithNewOp<tensor::DimOp>(op, adaptor.getShape(), 0);
40024debf5aSFrederik Gossen   return success();
40124debf5aSFrederik Gossen }
40224debf5aSFrederik Gossen 
4034baf18dbSFrederik Gossen namespace {
404a70f2eb3SFrederik Gossen /// Converts `shape.reduce` to `scf.for`.
405a70f2eb3SFrederik Gossen struct ReduceOpConverter : public OpConversionPattern<shape::ReduceOp> {
406a70f2eb3SFrederik Gossen public:
407a70f2eb3SFrederik Gossen   using OpConversionPattern::OpConversionPattern;
408a70f2eb3SFrederik Gossen 
409a70f2eb3SFrederik Gossen   LogicalResult
410b54c724bSRiver Riddle   matchAndRewrite(shape::ReduceOp op, OpAdaptor adaptor,
411a70f2eb3SFrederik Gossen                   ConversionPatternRewriter &rewriter) const final;
412a70f2eb3SFrederik Gossen };
413a70f2eb3SFrederik Gossen } // namespace
414a70f2eb3SFrederik Gossen 
415a70f2eb3SFrederik Gossen LogicalResult
matchAndRewrite(shape::ReduceOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const416b54c724bSRiver Riddle ReduceOpConverter::matchAndRewrite(shape::ReduceOp op, OpAdaptor adaptor,
417a70f2eb3SFrederik Gossen                                    ConversionPatternRewriter &rewriter) const {
418a70f2eb3SFrederik Gossen   // For now, this lowering is only defined on `tensor<?xindex>` operands.
4195550c821STres Popp   if (isa<ShapeType>(op.getShape().getType()))
420a70f2eb3SFrederik Gossen     return failure();
421a70f2eb3SFrederik Gossen 
422a70f2eb3SFrederik Gossen   auto loc = op.getLoc();
423a70f2eb3SFrederik Gossen 
424a54f4eaeSMogball   Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
425a54f4eaeSMogball   Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
426a70f2eb3SFrederik Gossen   Type indexTy = rewriter.getIndexType();
427e2310704SJulian Gross   Value rank =
428cfb72fd3SJacques Pienaar       rewriter.create<tensor::DimOp>(loc, indexTy, adaptor.getShape(), zero);
429a70f2eb3SFrederik Gossen 
430a70f2eb3SFrederik Gossen   auto loop = rewriter.create<scf::ForOp>(
431cfb72fd3SJacques Pienaar       loc, zero, rank, one, op.getInitVals(),
432a70f2eb3SFrederik Gossen       [&](OpBuilder &b, Location loc, Value iv, ValueRange args) {
433cfb72fd3SJacques Pienaar         Value extent = b.create<tensor::ExtractOp>(loc, adaptor.getShape(), iv);
434a70f2eb3SFrederik Gossen 
435a70f2eb3SFrederik Gossen         SmallVector<Value, 2> mappedValues{iv, extent};
436a70f2eb3SFrederik Gossen         mappedValues.append(args.begin(), args.end());
437a70f2eb3SFrederik Gossen 
4384d67b278SJeff Niu         IRMapping mapping;
439a70f2eb3SFrederik Gossen         Block *reduceBody = op.getBody();
440a70f2eb3SFrederik Gossen         mapping.map(reduceBody->getArguments(), mappedValues);
441a70f2eb3SFrederik Gossen         for (auto &nested : reduceBody->without_terminator())
442a70f2eb3SFrederik Gossen           b.clone(nested, mapping);
443a70f2eb3SFrederik Gossen 
444a70f2eb3SFrederik Gossen         SmallVector<Value, 2> mappedResults;
445a70f2eb3SFrederik Gossen         for (auto result : reduceBody->getTerminator()->getOperands())
446a70f2eb3SFrederik Gossen           mappedResults.push_back(mapping.lookup(result));
447a70f2eb3SFrederik Gossen         b.create<scf::YieldOp>(loc, mappedResults);
448a70f2eb3SFrederik Gossen       });
449a70f2eb3SFrederik Gossen 
450a70f2eb3SFrederik Gossen   rewriter.replaceOp(op, loop.getResults());
451a70f2eb3SFrederik Gossen   return success();
452a70f2eb3SFrederik Gossen }
453a70f2eb3SFrederik Gossen 
454a70f2eb3SFrederik Gossen namespace {
455a70f2eb3SFrederik Gossen /// Converts `shape.shape_eq` to an `scf.for` loop. For now, the lowering is
456a70f2eb3SFrederik Gossen /// only defined on `tensor<?xindex>` operands. The test for equality first
457a70f2eb3SFrederik Gossen /// compares their size and, if equal, checks every extent for equality.
458a70f2eb3SFrederik Gossen ///
459a70f2eb3SFrederik Gossen /// Example:
460a70f2eb3SFrederik Gossen ///
461a70f2eb3SFrederik Gossen /// %result = shape.shape_eq %a, %b : tensor<?xindex>, tensor<?xindex>
462a70f2eb3SFrederik Gossen ///
463a70f2eb3SFrederik Gossen /// becomes
464a70f2eb3SFrederik Gossen ///
465cb3aa49eSMogball /// %c0 = arith.constant 0 : index
466a70f2eb3SFrederik Gossen /// %0 = dim %arg0, %c0 : tensor<?xindex>
467a70f2eb3SFrederik Gossen /// %1 = dim %arg1, %c0 : tensor<?xindex>
468a54f4eaeSMogball /// %2 = arith.cmpi "eq", %0, %1 : index
469a70f2eb3SFrederik Gossen /// %result = scf.if %2 -> (i1) {
470a54f4eaeSMogball ///   %c1 = arith.constant 1 : index
471a54f4eaeSMogball ///   %true = arith.constant true
472a70f2eb3SFrederik Gossen ///   %4 = scf.for %arg2 = %c0 to %0 step %c1 iter_args(%arg3 = %true) -> (i1) {
473444822d7SSean Silva ///     %5 = tensor.extract %arg0[%arg2] : tensor<?xindex>
474444822d7SSean Silva ///     %6 = tensor.extract %arg1[%arg2] : tensor<?xindex>
475a54f4eaeSMogball ///     %7 = arith.cmpi "eq", %5, %6 : index
476a54f4eaeSMogball ///     %8 = arith.andi %arg3, %7 : i1
477a70f2eb3SFrederik Gossen ///     scf.yield %8 : i1
478a70f2eb3SFrederik Gossen ///   }
479a70f2eb3SFrederik Gossen ///   scf.yield %4 : i1
480a70f2eb3SFrederik Gossen /// } else {
481a54f4eaeSMogball ///   %false = arith.constant false
482a70f2eb3SFrederik Gossen ///   scf.yield %false : i1
483a70f2eb3SFrederik Gossen /// }
484a70f2eb3SFrederik Gossen ///
485a70f2eb3SFrederik Gossen struct ShapeEqOpConverter : public OpConversionPattern<ShapeEqOp> {
486a70f2eb3SFrederik Gossen   using OpConversionPattern<ShapeEqOp>::OpConversionPattern;
487a70f2eb3SFrederik Gossen 
488a70f2eb3SFrederik Gossen   LogicalResult
489b54c724bSRiver Riddle   matchAndRewrite(ShapeEqOp op, OpAdaptor adaptor,
490a70f2eb3SFrederik Gossen                   ConversionPatternRewriter &rewriter) const override;
491a70f2eb3SFrederik Gossen };
492a70f2eb3SFrederik Gossen } // namespace
493a70f2eb3SFrederik Gossen 
494a70f2eb3SFrederik Gossen LogicalResult
matchAndRewrite(ShapeEqOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const495b54c724bSRiver Riddle ShapeEqOpConverter::matchAndRewrite(ShapeEqOp op, OpAdaptor adaptor,
496a70f2eb3SFrederik Gossen                                     ConversionPatternRewriter &rewriter) const {
497cfb72fd3SJacques Pienaar   if (!llvm::all_of(op.getShapes(),
4985550c821STres Popp                     [](Value v) { return !isa<ShapeType>(v.getType()); }))
499a70f2eb3SFrederik Gossen     return failure();
50024acadefSBenjamin Kramer 
50124acadefSBenjamin Kramer   Type i1Ty = rewriter.getI1Type();
502cfb72fd3SJacques Pienaar   if (op.getShapes().size() <= 1) {
503a54f4eaeSMogball     rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, i1Ty,
50424acadefSBenjamin Kramer                                                    rewriter.getBoolAttr(true));
50524acadefSBenjamin Kramer     return success();
506a70f2eb3SFrederik Gossen   }
507a70f2eb3SFrederik Gossen 
508a70f2eb3SFrederik Gossen   auto loc = op.getLoc();
509a70f2eb3SFrederik Gossen   Type indexTy = rewriter.getIndexType();
510a54f4eaeSMogball   Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
511cfb72fd3SJacques Pienaar   Value firstShape = adaptor.getShapes().front();
512e2310704SJulian Gross   Value firstRank =
513c0a6318dSMatthias Springer       rewriter.create<tensor::DimOp>(loc, indexTy, firstShape, zero);
51424acadefSBenjamin Kramer   Value result = nullptr;
51524acadefSBenjamin Kramer   // Generate a linear sequence of compares, all with firstShape as lhs.
516cfb72fd3SJacques Pienaar   for (Value shape : adaptor.getShapes().drop_front(1)) {
517c0a6318dSMatthias Springer     Value rank = rewriter.create<tensor::DimOp>(loc, indexTy, shape, zero);
518a54f4eaeSMogball     Value eqRank = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
519a54f4eaeSMogball                                                   firstRank, rank);
52024acadefSBenjamin Kramer     auto same = rewriter.create<IfOp>(
5211125c5c0SFrederik Gossen         loc, eqRank,
522a70f2eb3SFrederik Gossen         [&](OpBuilder &b, Location loc) {
523a54f4eaeSMogball           Value one = b.create<arith::ConstantIndexOp>(loc, 1);
524a54f4eaeSMogball           Value init =
525a54f4eaeSMogball               b.create<arith::ConstantOp>(loc, i1Ty, b.getBoolAttr(true));
526a70f2eb3SFrederik Gossen           auto loop = b.create<scf::ForOp>(
52724acadefSBenjamin Kramer               loc, zero, firstRank, one, ValueRange{init},
528a70f2eb3SFrederik Gossen               [&](OpBuilder &b, Location nestedLoc, Value iv, ValueRange args) {
529a70f2eb3SFrederik Gossen                 Value conj = args[0];
530a70f2eb3SFrederik Gossen                 Value lhsExtent =
53124acadefSBenjamin Kramer                     b.create<tensor::ExtractOp>(loc, firstShape, iv);
53224acadefSBenjamin Kramer                 Value rhsExtent = b.create<tensor::ExtractOp>(loc, shape, iv);
533a54f4eaeSMogball                 Value eqExtent = b.create<arith::CmpIOp>(
534a54f4eaeSMogball                     loc, arith::CmpIPredicate::eq, lhsExtent, rhsExtent);
535a54f4eaeSMogball                 Value conjNext = b.create<arith::AndIOp>(loc, conj, eqExtent);
536a70f2eb3SFrederik Gossen                 b.create<scf::YieldOp>(loc, ValueRange({conjNext}));
537a70f2eb3SFrederik Gossen               });
538a70f2eb3SFrederik Gossen           b.create<scf::YieldOp>(loc, loop.getResults());
539a70f2eb3SFrederik Gossen         },
540a70f2eb3SFrederik Gossen         [&](OpBuilder &b, Location loc) {
541a54f4eaeSMogball           Value result =
542a54f4eaeSMogball               b.create<arith::ConstantOp>(loc, i1Ty, b.getBoolAttr(false));
543a70f2eb3SFrederik Gossen           b.create<scf::YieldOp>(loc, result);
544a70f2eb3SFrederik Gossen         });
54524acadefSBenjamin Kramer     result = !result ? same.getResult(0)
546a54f4eaeSMogball                      : rewriter.create<arith::AndIOp>(loc, result,
547a54f4eaeSMogball                                                       same.getResult(0));
54824acadefSBenjamin Kramer   }
54924acadefSBenjamin Kramer   rewriter.replaceOp(op, result);
550a70f2eb3SFrederik Gossen   return success();
551a70f2eb3SFrederik Gossen }
552a70f2eb3SFrederik Gossen 
553a70f2eb3SFrederik Gossen namespace {
554a70f2eb3SFrederik Gossen class ShapeOfOpConversion : public OpConversionPattern<ShapeOfOp> {
555a70f2eb3SFrederik Gossen public:
556a70f2eb3SFrederik Gossen   using OpConversionPattern<ShapeOfOp>::OpConversionPattern;
557a70f2eb3SFrederik Gossen 
558a70f2eb3SFrederik Gossen   LogicalResult
559b54c724bSRiver Riddle   matchAndRewrite(ShapeOfOp op, OpAdaptor adaptor,
560a70f2eb3SFrederik Gossen                   ConversionPatternRewriter &rewriter) const override;
561a70f2eb3SFrederik Gossen };
562a70f2eb3SFrederik Gossen } // namespace
563a70f2eb3SFrederik Gossen 
matchAndRewrite(ShapeOfOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const564a70f2eb3SFrederik Gossen LogicalResult ShapeOfOpConversion::matchAndRewrite(
565b54c724bSRiver Riddle     ShapeOfOp op, OpAdaptor adaptor,
566a70f2eb3SFrederik Gossen     ConversionPatternRewriter &rewriter) const {
567a70f2eb3SFrederik Gossen 
568a70f2eb3SFrederik Gossen   // For now, only error-free types are supported by this lowering.
5695550c821STres Popp   if (isa<ShapeType>(op.getType()))
570a70f2eb3SFrederik Gossen     return failure();
571a70f2eb3SFrederik Gossen 
572be7352c0SSean Silva   // For ranked tensor arguments, lower to `tensor.from_elements`.
5735106a8b8SFrederik Gossen   auto loc = op.getLoc();
574cfb72fd3SJacques Pienaar   Value tensor = adaptor.getArg();
575a70f2eb3SFrederik Gossen   Type tensorTy = tensor.getType();
5765550c821STres Popp   if (isa<RankedTensorType>(tensorTy)) {
577a70f2eb3SFrederik Gossen 
578a70f2eb3SFrederik Gossen     // Build values for individual extents.
579a70f2eb3SFrederik Gossen     SmallVector<Value, 8> extentValues;
5805550c821STres Popp     RankedTensorType rankedTensorTy = cast<RankedTensorType>(tensorTy);
581a70f2eb3SFrederik Gossen     int64_t rank = rankedTensorTy.getRank();
582a70f2eb3SFrederik Gossen     for (int64_t i = 0; i < rank; i++) {
583a70f2eb3SFrederik Gossen       if (rankedTensorTy.isDynamicDim(i)) {
584c0a6318dSMatthias Springer         Value extent = rewriter.create<tensor::DimOp>(loc, tensor, i);
585a70f2eb3SFrederik Gossen         extentValues.push_back(extent);
586a70f2eb3SFrederik Gossen       } else {
587a54f4eaeSMogball         Value extent = rewriter.create<arith::ConstantIndexOp>(
588a54f4eaeSMogball             loc, rankedTensorTy.getDimSize(i));
589a70f2eb3SFrederik Gossen         extentValues.push_back(extent);
590a70f2eb3SFrederik Gossen       }
591a70f2eb3SFrederik Gossen     }
592a70f2eb3SFrederik Gossen 
593a70f2eb3SFrederik Gossen     // Materialize extent tensor.
594be7352c0SSean Silva     Value staticExtentTensor = rewriter.create<tensor::FromElementsOp>(
595f77e9f87SAlexander Belyaev         loc, RankedTensorType::get({rank}, rewriter.getIndexType()),
596f77e9f87SAlexander Belyaev         extentValues);
597129d6e55SSean Silva     rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(),
598129d6e55SSean Silva                                                 staticExtentTensor);
599a70f2eb3SFrederik Gossen     return success();
600a70f2eb3SFrederik Gossen   }
601a70f2eb3SFrederik Gossen 
602be7352c0SSean Silva   // Lower to `tensor.generate` otherwise.
6035106a8b8SFrederik Gossen   auto *ctx = rewriter.getContext();
60415f8f3e2SAlexander Belyaev   Value rank = rewriter.create<tensor::RankOp>(loc, tensor);
605be7352c0SSean Silva   rewriter.replaceOpWithNewOp<tensor::GenerateOp>(
6065106a8b8SFrederik Gossen       op, getExtentTensorType(ctx), ValueRange{rank},
6075106a8b8SFrederik Gossen       [&](OpBuilder &b, Location loc, ValueRange args) {
6085106a8b8SFrederik Gossen         Value dim = args.front();
609c0a6318dSMatthias Springer         Value extent = b.create<tensor::DimOp>(loc, tensor, dim);
610be7352c0SSean Silva         b.create<tensor::YieldOp>(loc, extent);
611a70f2eb3SFrederik Gossen       });
612a70f2eb3SFrederik Gossen 
613a70f2eb3SFrederik Gossen   return success();
614a70f2eb3SFrederik Gossen }
615a70f2eb3SFrederik Gossen 
616a70f2eb3SFrederik Gossen namespace {
61742c195f0SBenjamin Kramer class SplitAtOpConversion : public OpConversionPattern<SplitAtOp> {
61842c195f0SBenjamin Kramer public:
61942c195f0SBenjamin Kramer   using OpConversionPattern<SplitAtOp>::OpConversionPattern;
62042c195f0SBenjamin Kramer 
62142c195f0SBenjamin Kramer   LogicalResult
622b54c724bSRiver Riddle   matchAndRewrite(SplitAtOp op, OpAdaptor adaptor,
62342c195f0SBenjamin Kramer                   ConversionPatternRewriter &rewriter) const override;
62442c195f0SBenjamin Kramer };
62542c195f0SBenjamin Kramer } // namespace
62642c195f0SBenjamin Kramer 
matchAndRewrite(SplitAtOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const62742c195f0SBenjamin Kramer LogicalResult SplitAtOpConversion::matchAndRewrite(
628b54c724bSRiver Riddle     SplitAtOp op, OpAdaptor adaptor,
62942c195f0SBenjamin Kramer     ConversionPatternRewriter &rewriter) const {
63042c195f0SBenjamin Kramer   // Error conditions are not implemented, only lower if all operands and
63142c195f0SBenjamin Kramer   // results are extent tensors.
632cfb72fd3SJacques Pienaar   if (llvm::any_of(ValueRange{op.getOperand(), op.getHead(), op.getTail()},
6335550c821STres Popp                    [](Value v) { return isa<ShapeType>(v.getType()); }))
63442c195f0SBenjamin Kramer     return failure();
63542c195f0SBenjamin Kramer 
63642c195f0SBenjamin Kramer   ImplicitLocOpBuilder b(op.getLoc(), rewriter);
637a54f4eaeSMogball   Value zero = b.create<arith::ConstantIndexOp>(0);
638cfb72fd3SJacques Pienaar   Value rank = b.create<tensor::DimOp>(adaptor.getOperand(), zero);
63942c195f0SBenjamin Kramer 
64042c195f0SBenjamin Kramer   // index < 0 ? index + rank : index
641cfb72fd3SJacques Pienaar   Value originalIndex = adaptor.getIndex();
642a54f4eaeSMogball   Value add = b.create<arith::AddIOp>(originalIndex, rank);
64342c195f0SBenjamin Kramer   Value indexIsNegative =
644a54f4eaeSMogball       b.create<arith::CmpIOp>(arith::CmpIPredicate::slt, originalIndex, zero);
645dec8af70SRiver Riddle   Value index = b.create<arith::SelectOp>(indexIsNegative, add, originalIndex);
64642c195f0SBenjamin Kramer 
647a54f4eaeSMogball   Value one = b.create<arith::ConstantIndexOp>(1);
648060208b4SMatthias Springer   Value head =
649cfb72fd3SJacques Pienaar       b.create<tensor::ExtractSliceOp>(adaptor.getOperand(), zero, index, one);
650a54f4eaeSMogball   Value tailSize = b.create<arith::SubIOp>(rank, index);
651cfb72fd3SJacques Pienaar   Value tail = b.create<tensor::ExtractSliceOp>(adaptor.getOperand(), index,
652cfb72fd3SJacques Pienaar                                                 tailSize, one);
65342c195f0SBenjamin Kramer   rewriter.replaceOp(op, {head, tail});
65442c195f0SBenjamin Kramer   return success();
65542c195f0SBenjamin Kramer }
65642c195f0SBenjamin Kramer 
65742c195f0SBenjamin Kramer namespace {
658a70f2eb3SFrederik Gossen class ToExtentTensorOpConversion
659a70f2eb3SFrederik Gossen     : public OpConversionPattern<ToExtentTensorOp> {
660a70f2eb3SFrederik Gossen public:
661a70f2eb3SFrederik Gossen   using OpConversionPattern<ToExtentTensorOp>::OpConversionPattern;
662a70f2eb3SFrederik Gossen 
663a70f2eb3SFrederik Gossen   LogicalResult
matchAndRewrite(ToExtentTensorOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const664b54c724bSRiver Riddle   matchAndRewrite(ToExtentTensorOp op, OpAdaptor adaptor,
665a70f2eb3SFrederik Gossen                   ConversionPatternRewriter &rewriter) const override {
6665550c821STres Popp     if (!isa<RankedTensorType>(adaptor.getInput().getType()))
667a70f2eb3SFrederik Gossen       return rewriter.notifyMatchFailure(op, "input needs to be a tensor");
668a70f2eb3SFrederik Gossen 
669129d6e55SSean Silva     rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(),
670cfb72fd3SJacques Pienaar                                                 adaptor.getInput());
671a70f2eb3SFrederik Gossen     return success();
672a70f2eb3SFrederik Gossen   }
673a70f2eb3SFrederik Gossen };
674a70f2eb3SFrederik Gossen } // namespace
675a70f2eb3SFrederik Gossen 
namespace {
/// Import the generated Shape-to-Standard rewrite patterns. They are
/// registered through `populateWithGenerated` in
/// `populateShapeToStandardConversionPatterns`. The include sits in an
/// anonymous namespace to give the generated definitions internal linkage.
#include "ShapeToStandard.cpp.inc"
} // namespace
680d05d4219STres Popp 
681d05d4219STres Popp namespace {
6823713314bSFrederik Gossen /// Conversion pass.
6833713314bSFrederik Gossen class ConvertShapeToStandardPass
68467d0d7acSMichele Scuttari     : public impl::ConvertShapeToStandardBase<ConvertShapeToStandardPass> {
685eaf49130SFrederik Gossen 
6864baf18dbSFrederik Gossen   void runOnOperation() override;
6874baf18dbSFrederik Gossen };
6884baf18dbSFrederik Gossen } // namespace
6894baf18dbSFrederik Gossen 
runOnOperation()6904baf18dbSFrederik Gossen void ConvertShapeToStandardPass::runOnOperation() {
6913713314bSFrederik Gossen   // Setup target legality.
692b6b9d3eaSFrederik Gossen   MLIRContext &ctx = getContext();
6933713314bSFrederik Gossen   ConversionTarget target(ctx);
694abc362a1SJakub Kuderski   target.addLegalDialect<arith::ArithDialect, SCFDialect,
6951f971e23SRiver Riddle                          tensor::TensorDialect>();
69658ceae95SRiver Riddle   target.addLegalOp<CstrRequireOp, func::FuncOp, ModuleOp>();
6973713314bSFrederik Gossen 
6983713314bSFrederik Gossen   // Setup conversion patterns.
699dc4e913bSChris Lattner   RewritePatternSet patterns(&ctx);
7003a506b31SChris Lattner   populateShapeToStandardConversionPatterns(patterns);
7013713314bSFrederik Gossen 
7023713314bSFrederik Gossen   // Apply conversion.
7033713314bSFrederik Gossen   auto module = getOperation();
7043fffffa8SRiver Riddle   if (failed(applyPartialConversion(module, target, std::move(patterns))))
7053713314bSFrederik Gossen     signalPassFailure();
7063713314bSFrederik Gossen }
7073713314bSFrederik Gossen 
populateShapeToStandardConversionPatterns(RewritePatternSet & patterns)70824edbdf9SFrederik Gossen void mlir::populateShapeToStandardConversionPatterns(
709dc4e913bSChris Lattner     RewritePatternSet &patterns) {
7103713314bSFrederik Gossen   // clang-format off
7111d909c9aSChris Lattner   populateWithGenerated(patterns);
712dc4e913bSChris Lattner   patterns.add<
7139df6afbbSFrederik Gossen       AnyOpConversion,
714a54f4eaeSMogball       BinaryOpConversion<AddOp, arith::AddIOp>,
715a54f4eaeSMogball       BinaryOpConversion<MulOp, arith::MulIOp>,
716a70f2eb3SFrederik Gossen       BroadcastOpConverter,
717a70f2eb3SFrederik Gossen       ConstShapeOpConverter,
7185d9f33aaSStephan Herhut       ConstSizeOpConversion,
7192f025e0eSJacques Pienaar       DimOpConverter,
720511484f2STres Popp       IsBroadcastableOpConverter,
7218577a090SFrederik Gossen       GetExtentOpConverter,
72224debf5aSFrederik Gossen       RankOpConverter,
723a70f2eb3SFrederik Gossen       ReduceOpConverter,
724a70f2eb3SFrederik Gossen       ShapeEqOpConverter,
7255d9f33aaSStephan Herhut       ShapeOfOpConversion,
72642c195f0SBenjamin Kramer       SplitAtOpConversion,
7273a506b31SChris Lattner       ToExtentTensorOpConversion>(patterns.getContext());
7283713314bSFrederik Gossen   // clang-format on
7293713314bSFrederik Gossen }
7303713314bSFrederik Gossen 
73124edbdf9SFrederik Gossen std::unique_ptr<OperationPass<ModuleOp>>
createConvertShapeToStandardPass()73224edbdf9SFrederik Gossen mlir::createConvertShapeToStandardPass() {
7333713314bSFrederik Gossen   return std::make_unique<ConvertShapeToStandardPass>();
7343713314bSFrederik Gossen }
735