// xref: /llvm-project/mlir/test/Dialect/Vector/vector-deinterleave-lowering-transforms.mlir (revision b87a80d4ebca9e1c065f0d2762e500078c4badca)
// RUN: mlir-opt %s --transform-interpreter | FileCheck %s

// The label is anchored on "(" so it can never match the
// @vector_deinterleave_2d_scalable function that shares this name as a prefix.
// CHECK-LABEL: @vector_deinterleave_2d(
// CHECK-SAME: %[[SRC:.*]]: vector<2x8xi32>) -> (vector<2x4xi32>, vector<2x4xi32>)
func.func @vector_deinterleave_2d(%a: vector<2x8xi32>) -> (vector<2x4xi32>, vector<2x4xi32>) {
  // The leading dim (2) is unrolled. Each row is extracted, deinterleaved as a
  // 1-D vector, and both halves are inserted into zero-initialized results.
  // CHECK: %[[CST:.*]] = arith.constant dense<0>
  // CHECK: %[[SRC_0:.*]] = vector.extract %[[SRC]][0]
  // CHECK: %[[UNZIP_0:.*]], %[[UNZIP_1:.*]] = vector.deinterleave %[[SRC_0]]
  // CHECK: %[[RES_0:.*]] = vector.insert %[[UNZIP_0]], %[[CST]] [0]
  // CHECK: %[[RES_1:.*]] = vector.insert %[[UNZIP_1]], %[[CST]] [0]
  // CHECK: %[[SRC_1:.*]] = vector.extract %[[SRC]][1]
  // CHECK: %[[UNZIP_2:.*]], %[[UNZIP_3:.*]] = vector.deinterleave %[[SRC_1]]
  // CHECK: %[[RES_2:.*]] = vector.insert %[[UNZIP_2]], %[[RES_0]] [1]
  // CHECK: %[[RES_3:.*]] = vector.insert %[[UNZIP_3]], %[[RES_1]] [1]
  // CHECK-NEXT: return %[[RES_2]], %[[RES_3]] : vector<2x4xi32>, vector<2x4xi32>
  %0, %1 = vector.deinterleave %a : vector<2x8xi32> -> vector<2x4xi32>
  return %0, %1 : vector<2x4xi32>, vector<2x4xi32>
}

// CHECK-LABEL: @vector_deinterleave_2d_scalable
// CHECK-SAME: %[[IN:.*]]: vector<2x[8]xi32>) -> (vector<2x[4]xi32>, vector<2x[4]xi32>)
func.func @vector_deinterleave_2d_scalable(%src: vector<2x[8]xi32>) -> (vector<2x[4]xi32>, vector<2x[4]xi32>) {
  // Only the fixed-size leading dim (2) is unrolled. Each extracted row keeps
  // its scalable trailing dim and is deinterleaved on its own; the halves are
  // accumulated into zero-initialized results.
  // CHECK: %[[ZERO:.*]] = arith.constant dense<0>
  //
  // Row 0: accumulate into the zero constant.
  // CHECK: %[[ROW0:.*]] = vector.extract %[[IN]][0]
  // CHECK: %[[EVEN0:.*]], %[[ODD0:.*]] = vector.deinterleave %[[ROW0]]
  // CHECK: %[[EVENS_A:.*]] = vector.insert %[[EVEN0]], %[[ZERO]] [0]
  // CHECK: %[[ODDS_A:.*]] = vector.insert %[[ODD0]], %[[ZERO]] [0]
  //
  // Row 1: accumulate into the row-0 partial results.
  // CHECK: %[[ROW1:.*]] = vector.extract %[[IN]][1]
  // CHECK: %[[EVEN1:.*]], %[[ODD1:.*]] = vector.deinterleave %[[ROW1]]
  // CHECK: %[[EVENS:.*]] = vector.insert %[[EVEN1]], %[[EVENS_A]] [1]
  // CHECK: %[[ODDS:.*]] = vector.insert %[[ODD1]], %[[ODDS_A]] [1]
  // CHECK-NEXT: return %[[EVENS]], %[[ODDS]] : vector<2x[4]xi32>, vector<2x[4]xi32>
  %evens, %odds = vector.deinterleave %src : vector<2x[8]xi32> -> vector<2x[4]xi32>
  return %evens, %odds : vector<2x[4]xi32>, vector<2x[4]xi32>
}

// CHECK-LABEL: @vector_deinterleave_4d
// CHECK-SAME: %[[SRC:.*]]: vector<1x2x3x8xi64>) -> (vector<1x2x3x4xi64>, vector<1x2x3x4xi64>)
func.func @vector_deinterleave_4d(%a: vector<1x2x3x8xi64>) -> (vector<1x2x3x4xi64>, vector<1x2x3x4xi64>) {
  // Every leading dim is fixed-size, so the op is fully unrolled into
  // 1 * 2 * 3 = 6 one-dimensional deinterleaves. The first is matched in
  // detail and the remaining five by count.
  // CHECK: %[[SRC_0:.*]] = vector.extract %[[SRC]][0, 0, 0] : vector<8xi64> from vector<1x2x3x8xi64>
  // CHECK: %[[UNZIP_0:.*]], %[[UNZIP_1:.*]] = vector.deinterleave %[[SRC_0]] : vector<8xi64> -> vector<4xi64>
  // CHECK: %[[RES_0:.*]] = vector.insert %[[UNZIP_0]], %{{.*}} [0, 0, 0] : vector<4xi64> into vector<1x2x3x4xi64>
  // CHECK: %[[RES_1:.*]] = vector.insert %[[UNZIP_1]], %{{.*}} [0, 0, 0] : vector<4xi64> into vector<1x2x3x4xi64>
  // CHECK-COUNT-5: vector.deinterleave %{{.*}} : vector<8xi64> -> vector<4xi64>
  // The count above is only a lower bound; pin the total at exactly six by
  // rejecting any further deinterleave before the function returns.
  // CHECK-NOT: vector.deinterleave
  // CHECK: return
  %0, %1 = vector.deinterleave %a : vector<1x2x3x8xi64> -> vector<1x2x3x4xi64>
  return %0, %1 : vector<1x2x3x4xi64>, vector<1x2x3x4xi64>
}

// CHECK-LABEL: @vector_deinterleave_nd_with_scalable_dim
func.func @vector_deinterleave_nd_with_scalable_dim(
  %a: vector<1x3x[2]x2x3x8xf16>) -> (vector<1x3x[2]x2x3x4xf16>, vector<1x3x[2]x2x3x4xf16>) {
  // A scalable dim cannot be unrolled, so unrolling stops at the first one.
  // Only the two leading fixed-size dims (1x3) are unrolled, which yields
  // exactly three deinterleaves that keep the scalable dim as their leading dim.
  // CHECK-COUNT-3: vector.deinterleave %{{.*}} : vector<[2]x2x3x8xf16>
  // The count above is only a lower bound; reject any further deinterleave.
  // CHECK-NOT: vector.deinterleave
  // CHECK: return
  %0, %1 = vector.deinterleave %a : vector<1x3x[2]x2x3x8xf16> -> vector<1x3x[2]x2x3x4xf16>
  return %0, %1 : vector<1x3x[2]x2x3x4xf16>, vector<1x3x[2]x2x3x4xf16>
}

module attributes {transform.with_named_sequence} {
  // Entry sequence run by --transform-interpreter. It matches every func.func
  // in the payload and applies the vector (de)interleave lowering patterns.
  transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) {
    %funcs = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.any_op
    transform.apply_patterns to %funcs {
      transform.apply_patterns.vector.lower_interleave
    } : !transform.any_op
    transform.yield
  }
}