xref: /llvm-project/mlir/test/Dialect/Vector/shape-cast-folder.mlir (revision 1f5e8263b920f591c517a5dc562cccad39dd6ec7)
// RUN: mlir-opt %s --transform-interpreter --split-input-file | FileCheck %s

///----------------------------------------------------------------------------------------
/// [Pattern: ShapeCastOpFolder]
///----------------------------------------------------------------------------------------

// Round trip 2x4 -> 8 -> 2x4 on a fixed-width vector: both shape_casts are
// expected to disappear, leaving the function returning its argument directly.
// CHECK-LABEL: func @fixed_width
//  CHECK-SAME:   %[[IN:.*]]: vector<2x4xf32>
//   CHECK-NOT:   vector.shape_cast
//       CHECK:   return %[[IN]] : vector<2x4xf32>
func.func @fixed_width(%input: vector<2x4xf32>) -> vector<2x4xf32> {
  %flattened = vector.shape_cast %input : vector<2x4xf32> to vector<8xf32>
  %restored = vector.shape_cast %flattened : vector<8xf32> to vector<2x4xf32>
  return %restored : vector<2x4xf32>
}

// Round trip 2x[4] -> [8] -> 2x[4] with a scalable trailing dimension: both
// shape_casts are expected to disappear, leaving the argument returned directly.
// CHECK-LABEL: func @scalable
//  CHECK-SAME:   %[[IN:.*]]: vector<2x[4]xf32>
//   CHECK-NOT:   vector.shape_cast
//       CHECK:   return %[[IN]] : vector<2x[4]xf32>
func.func @scalable(%input: vector<2x[4]xf32>) -> vector<2x[4]xf32> {
  %flattened = vector.shape_cast %input : vector<2x[4]xf32> to vector<[8]xf32>
  %restored = vector.shape_cast %flattened : vector<[8]xf32> to vector<2x[4]xf32>
  return %restored : vector<2x[4]xf32>
}

// ============================================================================
//  Transform dialect (TD) sequence
// ============================================================================
// Entry point for --transform-interpreter: collect every func.func in the
// payload and greedily apply the vector drop_unit_dims_with_shape_cast patterns
// to each one.
// NOTE(review): this pattern set is presumably what brings in ShapeCastOpFolder
// for the tests in this file — confirm against VectorTransformOps.td.
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%payload : !transform.any_op {transform.readonly}) {
    %funcs = transform.structured.match ops{["func.func"]} in %payload
        : (!transform.any_op) -> !transform.op<"func.func">
    transform.apply_patterns to %funcs {
      transform.apply_patterns.vector.drop_unit_dims_with_shape_cast
    } : !transform.op<"func.func">
    transform.yield
  }
}
