xref: /llvm-project/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td (revision 5c8ce6d5761ed6a9a39ef5a77aa45d8b6095e0f5)
1include "mlir/IR/PatternBase.td"
2include "mlir/Dialect/Shape/IR/ShapeOps.td"
3include "mlir/Dialect/Tensor/IR/TensorOps.td"
4
// Holds when every value in the bound range is the same SSA value, as decided
// by llvm::all_equal. This compares value identity, not structural equality
// of the shapes; an empty or single-element range also satisfies it.
def AllInputShapesEq : Constraint<CPred<[{
  ::llvm::all_equal($0)
}]>, "all input shapes are the same value">;
8
// Holds when the bound range (e.g. a variadic operand list) contains exactly
// one element.
def HasSingleElement : Constraint<CPred<[{
  $0.size() == 1
}]>, "range has exactly one element">;
12
// Holds when the bound value's type is a ShapedType whose shape is fully
// static. Values of non-shaped type are rejected instead of dereferencing a
// null dyn_cast result.
def HasStaticShape : Constraint<CPred<[{
  ::llvm::isa<::mlir::ShapedType>($0.getType()) &&
  ::llvm::cast<::mlir::ShapedType>($0.getType()).hasStaticShape()
}]>, "value has a statically shaped type">;
16
// Result-side native-code helper: yields the leading element of the bound
// range. The range is assumed to be non-empty at the point of use.
def TakeFront : NativeCodeCall<[{$0.front()}]>;
19
20// Canonicalization patterns.
21
// shape.assuming_all with a single witness operand: forward that operand in
// place of the op.
def AssumingAllOneOp : Pat<
  (Shape_AssumingAllOp $args),
  (replaceWithValue $args),
  [(HasSingleElement $args)]>;
25
// shape.cstr_broadcastable whose operands are all the same value is always
// satisfied: replace it with a constant `true` witness.
def CstrBroadcastableEqOps : Pat<(Shape_CstrBroadcastableOp $shapes),
  (Shape_ConstWitnessOp ConstBoolAttrTrue),
  [(AllInputShapesEq $shapes)]>;
29
// shape.cstr_eq whose operands are all the same value is always satisfied:
// replace it with a constant `true` witness.
def CstrEqEqOps : Pat<(Shape_CstrEqOp $shapes),
  (Shape_ConstWitnessOp ConstBoolAttrTrue),
  [(AllInputShapesEq $shapes)]>;
33
// Cancel a size_to_index(index_to_size(x)) round trip by forwarding x.
def IndexToSizeToIndexCanonicalization
    : Pat<(Shape_SizeToIndexOp (Shape_IndexToSizeOp $arg)),
          (replaceWithValue $arg)>;
37
// Holds when the two bound values have identical types.
def ShapeRoundTripTypesMatch : Constraint<CPred<[{
  $0.getType() == $1.getType()
}]>, "values have the same type">;

// Cancel an index_to_size(size_to_index(x)) round trip by forwarding x.
// shape.size_to_index accepts `index` as well as `!shape.size` operands
// (Shape_SizeOrIndexType in ShapeOps.td), so x may be an `index` while the
// replaced index_to_size result is a `!shape.size`. Only forward x when its
// type matches the result being replaced, otherwise the rewrite would
// produce ill-typed IR.
def SizeToIndexToSizeCanonicalization : Pat<
  (Shape_IndexToSizeOp:$res (Shape_SizeToIndexOp $arg)),
  (replaceWithValue $arg),
  [(ShapeRoundTripTypesMatch $res, $arg)]>;
41
// tensor.cast(const_shape) -> const_shape: a fresh const_shape built from the
// same extents takes the place of the cast, so the constant now carries the
// cast's destination type. Applied only when that result type is statically
// shaped.
def TensorCastConstShape : Pat<
  (Tensor_CastOp:$res (Shape_ConstShapeOp $arg)),
  (Shape_ConstShapeOp $arg),
  [(HasStaticShape $res)]>;
47
// Holds when the bound value is an unranked tensor or a ranked tensor of
// non-zero rank, i.e. a value tensor.dim accepts as its source.
def ShapeOfArgIsDimCompatibleTensor : Constraint<CPred<[{
  ::llvm::isa<::mlir::TensorType>($0.getType()) &&
  (!::llvm::cast<::mlir::TensorType>($0.getType()).hasRank() ||
   ::llvm::cast<::mlir::TensorType>($0.getType()).getRank() != 0)
}]>, "unranked tensor or ranked tensor of non-zero rank">;

// tensor.extract from shape_of -> tensor.dim. Taking the first index is valid
// because a shape_of result consumable by tensor.extract is a 1-D extent
// tensor.
// shape_of's operand is not restricted to tensors (ShapeOps.td allows any
// shaped type, e.g. memref) and may be a rank-0 tensor, while tensor.dim only
// accepts unranked or non-zero-rank tensors; the guard keeps the rewrite from
// emitting an invalid tensor.dim.
def ExtractFromShapeOfExtentTensor : Pat<
  (Tensor_ExtractOp (Shape_ShapeOfOp $arg), $indices),
  (Tensor_DimOp $arg, (TakeFront $indices)),
  [(ShapeOfArgIsDimCompatibleTensor $arg)]>;
53