include "mlir/IR/PatternBase.td"
include "mlir/Dialect/Shape/IR/ShapeOps.td"
include "mlir/Dialect/Tensor/IR/TensorOps.td"

//===----------------------------------------------------------------------===//
// Constraints and native helpers.
//===----------------------------------------------------------------------===//

// Holds when every element of the bound range compares equal
// (as decided by llvm::all_equal).
def AllInputShapesEq : Constraint<CPred<[{
  llvm::all_equal($0)
}]>>;

// Holds when the bound range contains exactly one element.
def HasSingleElement : Constraint<CPred<[{
  $0.size() == 1
}]>>;

// Holds when the type of the bound value is a ShapedType with a static shape.
// The type is checked with isa<> before it is cast, so a value whose type is
// not a ShapedType makes the constraint fail instead of invoking
// hasStaticShape() on a null ShapedType.
def HasStaticShape : Constraint<CPred<[{
  ::llvm::isa<ShapedType>($0.getType()) &&
  ::llvm::cast<ShapedType>($0.getType()).hasStaticShape()
}]>>;

// Helper that takes the first element of a range.
def TakeFront : NativeCodeCall<"$0.front()">;

//===----------------------------------------------------------------------===//
// Canonicalization patterns.
//===----------------------------------------------------------------------===//

// assuming_all with exactly one operand is replaced by that operand.
def AssumingAllOneOp : Pat<
  (Shape_AssumingAllOp $args),
  (replaceWithValue $args),
  [(HasSingleElement $args)]>;

// cstr_broadcastable whose operands are all equal is replaced by a constant
// `true` witness.
def CstrBroadcastableEqOps : Pat<
  (Shape_CstrBroadcastableOp:$op $shapes),
  (Shape_ConstWitnessOp ConstBoolAttrTrue),
  [(AllInputShapesEq $shapes)]>;

// cstr_eq whose operands are all equal is replaced by a constant `true`
// witness.
def CstrEqEqOps : Pat<
  (Shape_CstrEqOp:$op $shapes),
  (Shape_ConstWitnessOp ConstBoolAttrTrue),
  [(AllInputShapesEq $shapes)]>;

// size_to_index(index_to_size(x)) is replaced by x.
def IndexToSizeToIndexCanonicalization : Pat<
  (Shape_SizeToIndexOp (Shape_IndexToSizeOp $arg)),
  (replaceWithValue $arg)>;

// index_to_size(size_to_index(x)) is replaced by x.
def SizeToIndexToSizeCanonicalization : Pat<
  (Shape_IndexToSizeOp (Shape_SizeToIndexOp $arg)),
  (replaceWithValue $arg)>;

// Fold tensor.cast(const_shape) to const_shape. This changes the type of
// const_shape to the destination type of the cast. Only applied when the
// result of the cast has a static shape.
def TensorCastConstShape : Pat<
  (Tensor_CastOp:$res (Shape_ConstShapeOp $arg)),
  (Shape_ConstShapeOp $arg),
  [(HasStaticShape $res)]>;

// tensor.extract from shape_of -> tensor.dim. We can take the first index
// because shape_of always returns a 1D tensor.
def ExtractFromShapeOfExtentTensor : Pat<
  (Tensor_ExtractOp (Shape_ShapeOfOp $arg), $indices),
  (Tensor_DimOp $arg, (TakeFront $indices))>;