//===- ShapeToStandard.cpp - conversion from Shape to Standard dialect ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/STLExtras.h"

namespace mlir {
#define GEN_PASS_DEF_CONVERTSHAPETOSTANDARD
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;
using namespace mlir::shape;
using namespace mlir::scf;

/// Conversion patterns.
namespace {
class AnyOpConversion : public OpConversionPattern<AnyOp> {
public:
  using OpConversionPattern<AnyOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(AnyOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult
AnyOpConversion::matchAndRewrite(AnyOp op, OpAdaptor adaptor,
                                 ConversionPatternRewriter &rewriter) const {
  // Replace `any` with its first operand.
  // Any operand would be a valid substitution.
  rewriter.replaceOp(op, {adaptor.getInputs().front()});
  return success();
}

namespace {
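/// Converts error-free binary shape ops to their arith counterparts.
/// Illustrative sketch (IR abbreviated; the exact form depends on the operand
/// types):
///
///   %result = shape.add %a, %b : index, index -> index
///
/// becomes
///
///   %result = arith.addi %a, %b : index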
template <typename SrcOpTy, typename DstOpTy>
class BinaryOpConversion : public OpConversionPattern<SrcOpTy> {
public:
  using OpConversionPattern<SrcOpTy>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(SrcOpTy op, typename SrcOpTy::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // For now, only error-free types are supported by this lowering.
    if (isa<SizeType>(op.getType()))
      return failure();

    rewriter.replaceOpWithNewOp<DstOpTy>(op, adaptor.getLhs(),
                                         adaptor.getRhs());
    return success();
  }
};
} // namespace

namespace {
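/// Converts `shape.broadcast` on extent tensors to a `tensor.generate` loop.
/// Illustrative sketch (IR abbreviated):
///
///   %result = shape.broadcast %a, %b
///       : tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>
///
/// becomes a `tensor.generate` over max(rank(%a), rank(%b)) whose body yields,
/// per output dimension, the broadcasted extent as computed by
/// `getBroadcastedDim` below.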
struct BroadcastOpConverter : public OpConversionPattern<BroadcastOp> {
  using OpConversionPattern<BroadcastOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(BroadcastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

// Get the resulting extent in a given dimension. This is computed from any
// number of extent tensors, each offset by its rank difference from the
// maximum rank.
Value getBroadcastedDim(ImplicitLocOpBuilder lb, ValueRange extentTensors,
                        ValueRange rankDiffs, Value outputDimension) {
  Value one = lb.create<arith::ConstantIndexOp>(1);
  Value broadcastedDim = one;
  for (auto tup : llvm::zip(extentTensors, rankDiffs)) {
    Value shape = std::get<0>(tup);
    Value rankDiff = std::get<1>(tup);
    Value outOfBounds = lb.create<arith::CmpIOp>(arith::CmpIPredicate::ult,
                                                 outputDimension, rankDiff);
    Type indexTy = lb.getIndexType();
    broadcastedDim =
        lb.create<IfOp>(
              outOfBounds,
              [&](OpBuilder &b, Location loc) {
                b.create<scf::YieldOp>(loc, broadcastedDim);
              },
              [&](OpBuilder &b, Location loc) {
                // The broadcasting logic is:
                // - if one extent (here we arbitrarily choose the
                // extent from the greater-rank operand) is equal to 1,
                // then take the extent from the other operand
                // - otherwise, take the extent as-is.
                // Note that this logic remains correct in the presence
                // of dimensions of zero extent.
                Value lesserRankOperandDimension = b.create<arith::SubIOp>(
                    loc, indexTy, outputDimension, rankDiff);
                Value lesserRankOperandExtent = b.create<tensor::ExtractOp>(
                    loc, shape, ValueRange{lesserRankOperandDimension});

                Value dimIsOne =
                    b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
                                            lesserRankOperandExtent, one);
                Value dim = b.create<arith::SelectOp>(
                    loc, dimIsOne, broadcastedDim, lesserRankOperandExtent);
                b.create<scf::YieldOp>(loc, dim);
              })
            .getResult(0);
  }
  return broadcastedDim;
}
} // namespace

LogicalResult BroadcastOpConverter::matchAndRewrite(
    BroadcastOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  // For now, this lowering is only defined on `tensor<?xindex>` operands, not
  // on shapes.
  if (isa<ShapeType>(op.getType()))
    return failure();

  auto loc = op.getLoc();
  ImplicitLocOpBuilder lb(loc, rewriter);

  Value zero = lb.create<arith::ConstantIndexOp>(0);
  Type indexTy = lb.getIndexType();

  // Save all the ranks for bounds checking. Because this is a tensor
  // representing the shape extents, the rank is the extent of the only
  // dimension in the tensor.
  SmallVector<Value> ranks, rankDiffs;
  llvm::append_range(ranks, llvm::map_range(adaptor.getShapes(), [&](Value v) {
                       return lb.create<tensor::DimOp>(v, zero);
                     }));

  // Find the maximum rank.
  Value maxRank = ranks.front();
  for (Value v : llvm::drop_begin(ranks, 1)) {
    maxRank = lb.create<arith::MaxUIOp>(v, maxRank);
  }

  // Calculate the difference of ranks and the maximum rank for later offsets.
  llvm::append_range(rankDiffs, llvm::map_range(ranks, [&](Value v) {
                       return lb.create<arith::SubIOp>(indexTy, maxRank, v);
                     }));

  Value replacement = lb.create<tensor::GenerateOp>(
      getExtentTensorType(lb.getContext()), ValueRange{maxRank},
      [&](OpBuilder &b, Location loc, ValueRange args) {
        Value broadcastedDim =
            getBroadcastedDim(ImplicitLocOpBuilder(loc, b), adaptor.getShapes(),
                              rankDiffs, args[0]);

        b.create<tensor::YieldOp>(loc, broadcastedDim);
      });
  if (replacement.getType() != op.getType())
    replacement = lb.create<tensor::CastOp>(op.getType(), replacement);
  rewriter.replaceOp(op, replacement);
  return success();
}

namespace {
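/// Converts `shape.const_shape` to a `tensor.from_elements` of constant
/// indices. Illustrative sketch:
///
///   %shape = shape.const_shape [1, 2] : tensor<2xindex>
///
/// becomes
///
///   %c1 = arith.constant 1 : index
///   %c2 = arith.constant 2 : index
///   %shape = tensor.from_elements %c1, %c2 : tensor<2xindex>
///
/// followed by a `tensor.cast` that folds away when the types already match.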
class ConstShapeOpConverter : public OpConversionPattern<ConstShapeOp> {
public:
  using OpConversionPattern<ConstShapeOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ConstShapeOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult ConstShapeOpConverter::matchAndRewrite(
    ConstShapeOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {

  // For now, this lowering supports only extent tensors, not `shape.shape`
  // types.
  if (isa<ShapeType>(op.getType()))
    return failure();

  auto loc = op.getLoc();
  SmallVector<Value, 4> extentOperands;
  for (auto extent : op.getShape()) {
    extentOperands.push_back(
        rewriter.create<arith::ConstantIndexOp>(loc, extent.getLimitedValue()));
  }
  Type resultTy =
      RankedTensorType::get({op.getShape().size()}, rewriter.getIndexType());
  Value tensor =
      rewriter.create<tensor::FromElementsOp>(loc, resultTy, extentOperands);
  rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultTy, tensor);
  return success();
}

namespace {
class ConstSizeOpConversion : public OpConversionPattern<ConstSizeOp> {
public:
  using OpConversionPattern<ConstSizeOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ConstSizeOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult ConstSizeOpConversion::matchAndRewrite(
    ConstSizeOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(
      op, op.getValue().getSExtValue());
  return success();
}

namespace {
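/// Converts `shape.is_broadcastable` on extent tensors to an `scf.for`
/// reduction over the maximum rank that checks, per dimension, that every
/// extent is either 1 or equal to the broadcasted extent. Illustrative sketch
/// (IR abbreviated):
///
///   %ok = shape.is_broadcastable %a, %b : tensor<?xindex>, tensor<?xindex>
///
/// becomes an `scf.for` accumulating an i1 conjunction over all dimensions.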
struct IsBroadcastableOpConverter
    : public OpConversionPattern<IsBroadcastableOp> {
  using OpConversionPattern<IsBroadcastableOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(IsBroadcastableOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult IsBroadcastableOpConverter::matchAndRewrite(
    IsBroadcastableOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  // For now, this lowering is only defined on `tensor<?xindex>` operands, not
  // on shapes.
  if (!llvm::all_of(op.getShapes(),
                    [](Value v) { return !isa<ShapeType>(v.getType()); }))
    return failure();

  auto loc = op.getLoc();
  ImplicitLocOpBuilder lb(loc, rewriter);
  Value zero = lb.create<arith::ConstantIndexOp>(0);
  Value one = lb.create<arith::ConstantIndexOp>(1);
  Type indexTy = lb.getIndexType();

  // Save all the ranks for bounds checking. Because this is a tensor
  // representing the shape extents, the rank is the extent of the only
  // dimension in the tensor.
  SmallVector<Value> ranks, rankDiffs;
  llvm::append_range(ranks, llvm::map_range(adaptor.getShapes(), [&](Value v) {
                       return lb.create<tensor::DimOp>(v, zero);
                     }));

  // Find the maximum rank.
  Value maxRank = ranks.front();
  for (Value v : llvm::drop_begin(ranks, 1)) {
    maxRank = lb.create<arith::MaxUIOp>(v, maxRank);
  }

  // Calculate the difference of ranks and the maximum rank for later offsets.
  llvm::append_range(rankDiffs, llvm::map_range(ranks, [&](Value v) {
                       return lb.create<arith::SubIOp>(indexTy, maxRank, v);
                     }));

  Type i1Ty = rewriter.getI1Type();
  Value trueVal =
      rewriter.create<arith::ConstantOp>(loc, i1Ty, rewriter.getBoolAttr(true));

  auto reduceResult = lb.create<ForOp>(
      loc, zero, maxRank, one, ValueRange{trueVal},
      [&](OpBuilder &b, Location loc, Value iv, ValueRange iterArgs) {
        // Find a non-1 dim, if it exists. Note that the first part of this
        // could reuse the Broadcast lowering entirely, but we redo the work
        // here to make optimizations easier between the two loops.
        Value broadcastedDim = getBroadcastedDim(
            ImplicitLocOpBuilder(loc, b), adaptor.getShapes(), rankDiffs, iv);

        Value broadcastable = iterArgs[0];
        for (auto tup : llvm::zip(adaptor.getShapes(), rankDiffs)) {
          Value shape, rankDiff;
          std::tie(shape, rankDiff) = tup;
          Value outOfBounds = b.create<arith::CmpIOp>(
              loc, arith::CmpIPredicate::ult, iv, rankDiff);
          broadcastable =
              b.create<IfOp>(
                   loc, outOfBounds,
                   [&](OpBuilder &b, Location loc) {
                     // Non-existent dimensions are always broadcastable.
                     b.create<scf::YieldOp>(loc, broadcastable);
                   },
                   [&](OpBuilder &b, Location loc) {
                     // Every value needs to be either 1, or the same non-1
                     // value to be broadcastable in this dim.
                     Value operandDimension =
                         b.create<arith::SubIOp>(loc, indexTy, iv, rankDiff);
                     Value dimensionExtent = b.create<tensor::ExtractOp>(
                         loc, shape, ValueRange{operandDimension});

                     Value equalOne = b.create<arith::CmpIOp>(
                         loc, arith::CmpIPredicate::eq, dimensionExtent, one);
                     Value equalBroadcasted = b.create<arith::CmpIOp>(
                         loc, arith::CmpIPredicate::eq, dimensionExtent,
                         broadcastedDim);
                     Value result = b.create<arith::AndIOp>(
                         loc, broadcastable,
                         b.create<arith::OrIOp>(loc, equalOne,
                                                equalBroadcasted));
                     b.create<scf::YieldOp>(loc, result);
                   })
                  .getResult(0);
        }

        b.create<scf::YieldOp>(loc, broadcastable);
      });

  rewriter.replaceOp(op, reduceResult.getResults().front());
  return success();
}

namespace {
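/// Lowers `shape.dim` via the shape dialect itself (sketch; types omitted):
///
///   %extent = shape.dim %tensor, %idx
///
/// becomes
///
///   %shape = shape.shape_of %tensor
///   %extent = shape.get_extent %shape, %idx
///
/// and relies on the remaining patterns for further lowering.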
class DimOpConverter : public OpConversionPattern<DimOp> {
  using OpConversionPattern<DimOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(DimOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult
DimOpConverter::matchAndRewrite(DimOp op, OpAdaptor adaptor,
                                ConversionPatternRewriter &rewriter) const {
  // Lower dim(X, i) to get_extent(shape_of(X), i) and rely on further
  // lowerings. This can be further optimized if needed to avoid intermediate
  // steps.
  auto shapeOf = rewriter.create<shape::ShapeOfOp>(op.getLoc(), op.getValue());
  rewriter.replaceOpWithNewOp<shape::GetExtentOp>(op, op.getType(), shapeOf,
                                                  op.getIndex());
  return success();
}

namespace {
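/// Converts `shape.get_extent` to `tensor.extract`, or folds it to
/// `tensor.dim` when the shape stems directly from `shape.shape_of`.
/// Illustrative sketch of the folded case:
///
///   %shape = shape.shape_of %t : tensor<2x?xf32> -> tensor<2xindex>
///   %extent = shape.get_extent %shape, %idx : tensor<2xindex>, index -> index
///
/// becomes
///
///   %extent = tensor.dim %t, %idx : tensor<2x?xf32>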
class GetExtentOpConverter : public OpConversionPattern<GetExtentOp> {
  using OpConversionPattern<GetExtentOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(GetExtentOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult GetExtentOpConverter::matchAndRewrite(
    GetExtentOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  // For now, only error-free types are supported by this lowering.
  if (isa<SizeType>(op.getType()))
    return failure();

  // Derive shape extent directly from shape origin if possible. This
  // circumvents the necessity to materialize the shape in memory.
  if (auto shapeOfOp = op.getShape().getDefiningOp<ShapeOfOp>()) {
    if (isa<ShapedType>(shapeOfOp.getArg().getType())) {
      rewriter.replaceOpWithNewOp<tensor::DimOp>(op, shapeOfOp.getArg(),
                                                 adaptor.getDim());
      return success();
    }
  }

  rewriter.replaceOpWithNewOp<tensor::ExtractOp>(op, rewriter.getIndexType(),
                                                 adaptor.getShape(),
                                                 ValueRange{adaptor.getDim()});
  return success();
}

namespace {
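/// Converts `shape.rank` on an extent tensor to the extent of its sole
/// dimension, e.g. (sketch) `%r = shape.rank %s : tensor<?xindex> -> index`
/// becomes `%r = tensor.dim %s, %c0 : tensor<?xindex>`.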
class RankOpConverter : public OpConversionPattern<shape::RankOp> {
public:
  using OpConversionPattern<shape::RankOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(shape::RankOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult
RankOpConverter::matchAndRewrite(shape::RankOp op, OpAdaptor adaptor,
                                 ConversionPatternRewriter &rewriter) const {
  // For now, this lowering supports only error-free types.
  if (isa<SizeType>(op.getType()))
    return failure();

  rewriter.replaceOpWithNewOp<tensor::DimOp>(op, adaptor.getShape(), 0);
  return success();
}

namespace {
/// Converts `shape.reduce` to `scf.for`.
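///
/// Illustrative sketch (IR abbreviated):
///
///   %r = shape.reduce(%shape, %init) : tensor<?xindex> -> index {
///     ^bb0(%idx: index, %extent: index, %acc: index):
///       %n = arith.muli %acc, %extent : index
///       shape.yield %n : index
///   }
///
/// becomes an `scf.for` from 0 to the rank of %shape that extracts each
/// extent with `tensor.extract` and clones the reduction body into the loop.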
struct ReduceOpConverter : public OpConversionPattern<shape::ReduceOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(shape::ReduceOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final;
};
} // namespace

LogicalResult
ReduceOpConverter::matchAndRewrite(shape::ReduceOp op, OpAdaptor adaptor,
                                   ConversionPatternRewriter &rewriter) const {
  // For now, this lowering is only defined on `tensor<?xindex>` operands.
  if (isa<ShapeType>(op.getShape().getType()))
    return failure();

  auto loc = op.getLoc();

  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
  Type indexTy = rewriter.getIndexType();
  Value rank =
      rewriter.create<tensor::DimOp>(loc, indexTy, adaptor.getShape(), zero);

  auto loop = rewriter.create<scf::ForOp>(
      loc, zero, rank, one, op.getInitVals(),
      [&](OpBuilder &b, Location loc, Value iv, ValueRange args) {
        Value extent = b.create<tensor::ExtractOp>(loc, adaptor.getShape(), iv);

        SmallVector<Value, 2> mappedValues{iv, extent};
        mappedValues.append(args.begin(), args.end());

        IRMapping mapping;
        Block *reduceBody = op.getBody();
        mapping.map(reduceBody->getArguments(), mappedValues);
        for (auto &nested : reduceBody->without_terminator())
          b.clone(nested, mapping);

        SmallVector<Value, 2> mappedResults;
        for (auto result : reduceBody->getTerminator()->getOperands())
          mappedResults.push_back(mapping.lookup(result));
        b.create<scf::YieldOp>(loc, mappedResults);
      });

  rewriter.replaceOp(op, loop.getResults());
  return success();
}

namespace {
/// Converts `shape.shape_eq` to an `scf.for` loop. For now, the lowering is
/// only defined on `tensor<?xindex>` operands. The test for equality first
/// compares their size and, if equal, checks every extent for equality.
///
/// Example:
///
/// %result = shape.shape_eq %a, %b : tensor<?xindex>, tensor<?xindex>
///
/// becomes
///
/// %c0 = arith.constant 0 : index
/// %0 = tensor.dim %arg0, %c0 : tensor<?xindex>
/// %1 = tensor.dim %arg1, %c0 : tensor<?xindex>
/// %2 = arith.cmpi "eq", %0, %1 : index
/// %result = scf.if %2 -> (i1) {
///   %c1 = arith.constant 1 : index
///   %true = arith.constant true
///   %4 = scf.for %arg2 = %c0 to %0 step %c1 iter_args(%arg3 = %true) -> (i1) {
///     %5 = tensor.extract %arg0[%arg2] : tensor<?xindex>
///     %6 = tensor.extract %arg1[%arg2] : tensor<?xindex>
///     %7 = arith.cmpi "eq", %5, %6 : index
///     %8 = arith.andi %arg3, %7 : i1
///     scf.yield %8 : i1
///   }
///   scf.yield %4 : i1
/// } else {
///   %false = arith.constant false
///   scf.yield %false : i1
/// }
///
struct ShapeEqOpConverter : public OpConversionPattern<ShapeEqOp> {
  using OpConversionPattern<ShapeEqOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ShapeEqOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult
ShapeEqOpConverter::matchAndRewrite(ShapeEqOp op, OpAdaptor adaptor,
                                    ConversionPatternRewriter &rewriter) const {
  if (!llvm::all_of(op.getShapes(),
                    [](Value v) { return !isa<ShapeType>(v.getType()); }))
    return failure();

  Type i1Ty = rewriter.getI1Type();
  if (op.getShapes().size() <= 1) {
    rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, i1Ty,
                                                   rewriter.getBoolAttr(true));
    return success();
  }

  auto loc = op.getLoc();
  Type indexTy = rewriter.getIndexType();
  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  Value firstShape = adaptor.getShapes().front();
  Value firstRank =
      rewriter.create<tensor::DimOp>(loc, indexTy, firstShape, zero);
  Value result = nullptr;
  // Generate a linear sequence of compares, all with firstShape as lhs.
  for (Value shape : adaptor.getShapes().drop_front(1)) {
    Value rank = rewriter.create<tensor::DimOp>(loc, indexTy, shape, zero);
    Value eqRank = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
                                                  firstRank, rank);
    auto same = rewriter.create<IfOp>(
        loc, eqRank,
        [&](OpBuilder &b, Location loc) {
          Value one = b.create<arith::ConstantIndexOp>(loc, 1);
          Value init =
              b.create<arith::ConstantOp>(loc, i1Ty, b.getBoolAttr(true));
          auto loop = b.create<scf::ForOp>(
              loc, zero, firstRank, one, ValueRange{init},
              [&](OpBuilder &b, Location nestedLoc, Value iv, ValueRange args) {
                Value conj = args[0];
                Value lhsExtent =
                    b.create<tensor::ExtractOp>(loc, firstShape, iv);
                Value rhsExtent = b.create<tensor::ExtractOp>(loc, shape, iv);
                Value eqExtent = b.create<arith::CmpIOp>(
                    loc, arith::CmpIPredicate::eq, lhsExtent, rhsExtent);
                Value conjNext = b.create<arith::AndIOp>(loc, conj, eqExtent);
                b.create<scf::YieldOp>(loc, ValueRange({conjNext}));
              });
          b.create<scf::YieldOp>(loc, loop.getResults());
        },
        [&](OpBuilder &b, Location loc) {
          Value result =
              b.create<arith::ConstantOp>(loc, i1Ty, b.getBoolAttr(false));
          b.create<scf::YieldOp>(loc, result);
        });
    result = !result ? same.getResult(0)
                     : rewriter.create<arith::AndIOp>(loc, result,
                                                      same.getResult(0));
  }
  rewriter.replaceOp(op, result);
  return success();
}

namespace {
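/// Converts `shape.shape_of`. For ranked tensors the extents are materialized
/// with `tensor.from_elements`; for unranked tensors a `tensor.generate` over
/// `tensor.rank` is emitted instead. Illustrative sketch of the ranked case:
///
///   %shape = shape.shape_of %t : tensor<2x?xf32> -> tensor<2xindex>
///
/// becomes
///
///   %c1 = arith.constant 1 : index
///   %c2 = arith.constant 2 : index
///   %d1 = tensor.dim %t, %c1 : tensor<2x?xf32>
///   %shape = tensor.from_elements %c2, %d1 : tensor<2xindex>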
class ShapeOfOpConversion : public OpConversionPattern<ShapeOfOp> {
public:
  using OpConversionPattern<ShapeOfOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ShapeOfOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult ShapeOfOpConversion::matchAndRewrite(
    ShapeOfOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {

  // For now, only error-free types are supported by this lowering.
  if (isa<ShapeType>(op.getType()))
    return failure();

  // For ranked tensor arguments, lower to `tensor.from_elements`.
  auto loc = op.getLoc();
  Value tensor = adaptor.getArg();
  Type tensorTy = tensor.getType();
  if (isa<RankedTensorType>(tensorTy)) {

    // Build values for individual extents.
    SmallVector<Value, 8> extentValues;
    RankedTensorType rankedTensorTy = cast<RankedTensorType>(tensorTy);
    int64_t rank = rankedTensorTy.getRank();
    for (int64_t i = 0; i < rank; i++) {
      if (rankedTensorTy.isDynamicDim(i)) {
        Value extent = rewriter.create<tensor::DimOp>(loc, tensor, i);
        extentValues.push_back(extent);
      } else {
        Value extent = rewriter.create<arith::ConstantIndexOp>(
            loc, rankedTensorTy.getDimSize(i));
        extentValues.push_back(extent);
      }
    }

    // Materialize extent tensor.
    Value staticExtentTensor = rewriter.create<tensor::FromElementsOp>(
        loc, RankedTensorType::get({rank}, rewriter.getIndexType()),
        extentValues);
    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(),
                                                staticExtentTensor);
    return success();
  }

  // Lower to `tensor.generate` otherwise.
  auto *ctx = rewriter.getContext();
  Value rank = rewriter.create<tensor::RankOp>(loc, tensor);
  rewriter.replaceOpWithNewOp<tensor::GenerateOp>(
      op, getExtentTensorType(ctx), ValueRange{rank},
      [&](OpBuilder &b, Location loc, ValueRange args) {
        Value dim = args.front();
        Value extent = b.create<tensor::DimOp>(loc, tensor, dim);
        b.create<tensor::YieldOp>(loc, extent);
      });

  return success();
}

namespace {
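/// Converts `shape.split_at` on extent tensors to a pair of
/// `tensor.extract_slice` ops, normalizing a possibly negative split index to
/// `index + rank` first (sketch): for %s of rank r, the head is %s[0, index)
/// and the tail is %s[index, r).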
class SplitAtOpConversion : public OpConversionPattern<SplitAtOp> {
public:
  using OpConversionPattern<SplitAtOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(SplitAtOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult SplitAtOpConversion::matchAndRewrite(
    SplitAtOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  // Error conditions are not implemented; only lower if all operands and
  // results are extent tensors.
  if (llvm::any_of(ValueRange{op.getOperand(), op.getHead(), op.getTail()},
                   [](Value v) { return isa<ShapeType>(v.getType()); }))
    return failure();

  ImplicitLocOpBuilder b(op.getLoc(), rewriter);
  Value zero = b.create<arith::ConstantIndexOp>(0);
  Value rank = b.create<tensor::DimOp>(adaptor.getOperand(), zero);

  // index < 0 ? index + rank : index
  Value originalIndex = adaptor.getIndex();
  Value add = b.create<arith::AddIOp>(originalIndex, rank);
  Value indexIsNegative =
      b.create<arith::CmpIOp>(arith::CmpIPredicate::slt, originalIndex, zero);
  Value index = b.create<arith::SelectOp>(indexIsNegative, add, originalIndex);

  Value one = b.create<arith::ConstantIndexOp>(1);
  Value head =
      b.create<tensor::ExtractSliceOp>(adaptor.getOperand(), zero, index, one);
  Value tailSize = b.create<arith::SubIOp>(rank, index);
  Value tail = b.create<tensor::ExtractSliceOp>(adaptor.getOperand(), index,
                                                tailSize, one);
  rewriter.replaceOp(op, {head, tail});
  return success();
}

namespace {
class ToExtentTensorOpConversion
    : public OpConversionPattern<ToExtentTensorOp> {
public:
  using OpConversionPattern<ToExtentTensorOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ToExtentTensorOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (!isa<RankedTensorType>(adaptor.getInput().getType()))
      return rewriter.notifyMatchFailure(op, "input needs to be a tensor");

    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(),
                                                adaptor.getInput());
    return success();
  }
};
} // namespace

namespace {
/// Import the Shape Ops to Std Patterns.
#include "ShapeToStandard.cpp.inc"
} // namespace

namespace {
/// Conversion pass.
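/// Typically exercised via `mlir-opt --convert-shape-to-std` (flag name per
/// the generated pass registration; see Passes.td).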
class ConvertShapeToStandardPass
    : public impl::ConvertShapeToStandardBase<ConvertShapeToStandardPass> {

  void runOnOperation() override;
};
} // namespace

void ConvertShapeToStandardPass::runOnOperation() {
  // Setup target legality.
  MLIRContext &ctx = getContext();
  ConversionTarget target(ctx);
  target.addLegalDialect<arith::ArithDialect, SCFDialect,
                         tensor::TensorDialect>();
  target.addLegalOp<CstrRequireOp, func::FuncOp, ModuleOp>();

  // Setup conversion patterns.
  RewritePatternSet patterns(&ctx);
  populateShapeToStandardConversionPatterns(patterns);

  // Apply conversion.
  auto module = getOperation();
  if (failed(applyPartialConversion(module, target, std::move(patterns))))
    signalPassFailure();
}

void mlir::populateShapeToStandardConversionPatterns(
    RewritePatternSet &patterns) {
  // clang-format off
  populateWithGenerated(patterns);
  patterns.add<
      AnyOpConversion,
      BinaryOpConversion<AddOp, arith::AddIOp>,
      BinaryOpConversion<MulOp, arith::MulIOp>,
      BroadcastOpConverter,
      ConstShapeOpConverter,
      ConstSizeOpConversion,
      DimOpConverter,
      IsBroadcastableOpConverter,
      GetExtentOpConverter,
      RankOpConverter,
      ReduceOpConverter,
      ShapeEqOpConverter,
      ShapeOfOpConversion,
      SplitAtOpConversion,
      ToExtentTensorOpConversion>(patterns.getContext());
  // clang-format on
}

std::unique_ptr<OperationPass<ModuleOp>>
mlir::createConvertShapeToStandardPass() {
  return std::make_unique<ConvertShapeToStandardPass>();
}