// RUN: mlir-opt %s -transform-interpreter -canonicalize --split-input-file | FileCheck %s

// Check that we can tile softmax on tensors.
// The tiling here is 2x3.
// So the shape used in the inner loop should be 2x3x256; however, since 3
// doesn't divide the second dimension (64), we should see a '?' in the shape.
// The actual size, used through extract_slice/insert_slice, should come from
// `min(64 - current iteration index, 3)`.
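// For example, the inner loop visits offsets 0, 3, ..., 63 along that
// dimension; on the last iteration min(64 - 63, 3) = 1, so the slices shrink
// from 2x3x256 down to 2x1x256.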

// CHECK: #[[$MIN_MAP:.*]] = affine_map<(d0) -> (-d0 + 64, 3)>
// CHECK-LABEL:   func.func @softmax(
// CHECK-SAME:                       %[[VAL_0:.*]]: tensor<16x64x256xf32>) -> tensor<16x64x256xf32> {
// CHECK-DAG:       %[[C3:.*]] = arith.constant 3 : index
// CHECK-DAG:       %[[C2:.*]] = arith.constant 2 : index
// CHECK-DAG:       %[[C64:.*]] = arith.constant 64 : index
// CHECK-DAG:       %[[C16:.*]] = arith.constant 16 : index
// CHECK-DAG:       %[[C0:.*]] = arith.constant 0 : index
// CHECK:           %[[TENSOR_EMPTY:.*]] = tensor.empty() : tensor<16x64x256xf32>
// CHECK:           %[[VAL_7:.*]] = scf.for %[[VAL_8:.*]] = %[[C0]] to %[[C16]] step %[[C2]] iter_args(%[[VAL_9:.*]] = %[[TENSOR_EMPTY]]) -> (tensor<16x64x256xf32>) {
// CHECK:             %[[VAL_10:.*]] = scf.for %[[VAL_11:.*]] = %[[C0]] to %[[C64]] step %[[C3]] iter_args(%[[VAL_12:.*]] = %[[VAL_9]]) -> (tensor<16x64x256xf32>) {
// CHECK:               %[[VAL_13:.*]] = affine.min #[[$MIN_MAP]](%[[VAL_11]])
// CHECK:               %[[VAL_14:.*]] = tensor.extract_slice %[[VAL_0]]{{\[}}%[[VAL_8]], %[[VAL_11]], 0] [2, %[[VAL_13]], 256] [1, 1, 1] : tensor<16x64x256xf32> to tensor<2x?x256xf32>
// CHECK:               %[[VAL_15:.*]] = tensor.extract_slice %[[VAL_12]]{{\[}}%[[VAL_8]], %[[VAL_11]], 0] [2, %[[VAL_13]], 256] [1, 1, 1] : tensor<16x64x256xf32> to tensor<2x?x256xf32>
// CHECK:               %[[VAL_16:.*]] = linalg.softmax dimension(1) ins(%[[VAL_14]] : tensor<2x?x256xf32>) outs(%[[VAL_15]] : tensor<2x?x256xf32>) -> tensor<2x?x256xf32>
// CHECK:               %[[VAL_17:.*]] = tensor.insert_slice %[[VAL_16]] into %[[VAL_12]]{{\[}}%[[VAL_8]], %[[VAL_11]], 0] [2, %[[VAL_13]], 256] [1, 1, 1] : tensor<2x?x256xf32> into tensor<16x64x256xf32>
// CHECK:               scf.yield %[[VAL_17]] : tensor<16x64x256xf32>
// CHECK:             }
// CHECK:             scf.yield %[[VAL_18:.*]] : tensor<16x64x256xf32>
// CHECK:           }
// CHECK:           return %[[VAL_19:.*]] : tensor<16x64x256xf32>
// CHECK:         }
func.func @softmax(%arg0: tensor<16x64x256xf32>) -> tensor<16x64x256xf32> {
  %0 = tensor.empty() : tensor<16x64x256xf32>
  %1 = linalg.softmax
         dimension(1) ins(%arg0 : tensor<16x64x256xf32>) outs(%0 : tensor<16x64x256xf32>) -> tensor<16x64x256xf32>
  return %1 : tensor<16x64x256xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["linalg.softmax"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    %1, %loop:2 = transform.structured.tile_using_for %0 tile_sizes [2, 3] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
    transform.yield
  }
}

// -----

// Test the softmax tiling interface with the tile_using_forall transform and
// check that it composes properly with the fuse transform.
// This should sink the linalg.generic into the scf.forall and run that
// generic on 2x4x256 tensors (2 == 16/8, 4 == 64/16).
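// For example, forall thread (1, 3) maps through the d0 * 2 and d0 * 4
// affine maps to offsets (2, 12), i.e. it owns the 2x4x256 slab starting at
// [2, 12, 0] of the 16x64x256 output.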

// CHECK: #[[$TIMES2_MAP:.*]] = affine_map<(d0) -> (d0 * 2)>
// CHECK: #[[$TIMES4_MAP:.*]] = affine_map<(d0) -> (d0 * 4)>
// CHECK-LABEL:   func.func @softmax_tile_n_fuse(
// CHECK-SAME:                       %[[VAL_0:.*]]: tensor<16x64x256xf32>) -> tensor<16x64x256xf32> {
// CHECK:           %[[VAL_1:.*]] = arith.constant 1.000000e+00 : f32
// CHECK:           %[[VAL_2:.*]] = tensor.empty() : tensor<16x64x256xf32>
// CHECK:           %[[VAL_3:.*]] = tensor.empty() : tensor<16x64x256xf32>
// CHECK:           %[[VAL_4:.*]] = scf.forall (%[[VAL_5:.*]], %[[VAL_6:.*]]) in (8, 16) shared_outs(%[[VAL_7:.*]] = %[[VAL_3]]) -> (tensor<16x64x256xf32>) {
// CHECK:             %[[VAL_8:.*]] = affine.apply #[[$TIMES2_MAP]](%[[VAL_5]])
// CHECK:             %[[VAL_9:.*]] = affine.apply #[[$TIMES4_MAP]](%[[VAL_6]])
// CHECK:             %[[VAL_10:.*]] = affine.apply #[[$TIMES2_MAP]](%[[VAL_5]])
// CHECK:             %[[VAL_11:.*]] = affine.apply #[[$TIMES4_MAP]](%[[VAL_6]])
// CHECK:             %[[VAL_12:.*]] = affine.apply #[[$TIMES2_MAP]](%[[VAL_5]])
// CHECK:             %[[VAL_13:.*]] = affine.apply #[[$TIMES4_MAP]](%[[VAL_6]])
// CHECK:             %[[VAL_14:.*]] = tensor.extract_slice %[[VAL_0]]{{\[}}%[[VAL_10]], %[[VAL_11]], 0] [2, 4, 256] [1, 1, 1] : tensor<16x64x256xf32> to tensor<2x4x256xf32>
// CHECK:             %[[VAL_15:.*]] = tensor.extract_slice %[[VAL_2]]{{\[}}%[[VAL_12]], %[[VAL_13]], 0] [2, 4, 256] [1, 1, 1] : tensor<16x64x256xf32> to tensor<2x4x256xf32>
// CHECK:             %[[VAL_16:.*]] = linalg.generic {indexing_maps = [#{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[VAL_14]] : tensor<2x4x256xf32>) outs(%[[VAL_15]] : tensor<2x4x256xf32>) {
// CHECK:             ^bb0(%[[VAL_17:.*]]: f32, %[[VAL_18:.*]]: f32):
// CHECK:               %[[VAL_19:.*]] = arith.addf %[[VAL_18]], %[[VAL_1]] : f32
// CHECK:               linalg.yield %[[VAL_19]] : f32
// CHECK:             } -> tensor<2x4x256xf32>
// CHECK:             %[[VAL_20:.*]] = tensor.extract_slice %[[VAL_7]]{{\[}}%[[VAL_8]], %[[VAL_9]], 0] [2, 4, 256] [1, 1, 1] : tensor<16x64x256xf32> to tensor<2x4x256xf32>
// CHECK:             %[[VAL_21:.*]] = linalg.softmax dimension(1) ins(%[[VAL_22:.*]] : tensor<2x4x256xf32>) outs(%[[VAL_20]] : tensor<2x4x256xf32>) -> tensor<2x4x256xf32>
// CHECK:             scf.forall.in_parallel {
// CHECK:               tensor.parallel_insert_slice %[[VAL_21]] into %[[VAL_7]]{{\[}}%[[VAL_8]], %[[VAL_9]], 0] [2, 4, 256] [1, 1, 1] : tensor<2x4x256xf32> into tensor<16x64x256xf32>
// CHECK:             }
// CHECK:           }
// CHECK:           return %[[VAL_23:.*]] : tensor<16x64x256xf32>
// CHECK:         }

func.func @softmax_tile_n_fuse(%arg0: tensor<16x64x256xf32>) -> tensor<16x64x256xf32> {
  %empty = tensor.empty() : tensor<16x64x256xf32>
  %cst = arith.constant 1.000000e+00 : f32
  %eltwise = linalg.generic
      {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
                        affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
       iterator_types = ["parallel", "parallel", "parallel"]
      }
      ins(%arg0 : tensor<16x64x256xf32>)
      outs(%empty : tensor<16x64x256xf32>) {
    ^bb0(%arg2: f32, %arg3: f32):
      %arg3Plus1 = arith.addf %arg3, %cst : f32
      linalg.yield %arg3Plus1 : f32
    } -> tensor<16x64x256xf32>

  %0 = tensor.empty() : tensor<16x64x256xf32>
  %1 = linalg.softmax
         dimension(1) ins(%eltwise : tensor<16x64x256xf32>) outs(%0 : tensor<16x64x256xf32>) -> tensor<16x64x256xf32>
  return %1 : tensor<16x64x256xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["linalg.softmax"]} in %arg1 : (!transform.any_op) -> !transform.any_op

    // Tile the root.
    %tiled_op, %forall_op = transform.structured.tile_using_forall %0 num_threads [8, 16]
         : (!transform.any_op) -> (!transform.any_op, !transform.any_op)

    // Fuse all producers.
    %1 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    transform.structured.fuse_into_containing_op %1 into %forall_op
      : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
    transform.yield
  }
}

// -----

// Same as the previous test but on memrefs.
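// Since buffers are updated in place, the tiled loops carry no iter_args and
// the slices are taken with memref.subview instead of
// tensor.extract_slice/insert_slice.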

// CHECK: #[[$MIN_MAP:.*]] = affine_map<(d0) -> (-d0 + 64, 3)>
// CHECK-LABEL:   func.func @softmax_memref(
// CHECK-SAME:                              %[[VAL_0:.*]]: memref<16x64x256xf32>,
// CHECK-SAME:                              %[[VAL_1:.*]]: memref<16x64x256xf32>) {
// CHECK-DAG:       %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG:       %[[C16:.*]] = arith.constant 16 : index
// CHECK-DAG:       %[[C64:.*]] = arith.constant 64 : index
// CHECK-DAG:       %[[C2:.*]] = arith.constant 2 : index
// CHECK-DAG:       %[[C3:.*]] = arith.constant 3 : index
// CHECK:           scf.for %[[VAL_7:.*]] = %[[C0]] to %[[C16]] step %[[C2]] {
// CHECK:             scf.for %[[VAL_8:.*]] = %[[C0]] to %[[C64]] step %[[C3]] {
// CHECK:               %[[VAL_9:.*]] = affine.min #[[$MIN_MAP]](%[[VAL_8]])
// CHECK:               %[[VAL_10:.*]] = memref.subview %[[VAL_0]]{{\[}}%[[VAL_7]], %[[VAL_8]], 0] [2, %[[VAL_9]], 256] [1, 1, 1] : memref<16x64x256xf32> to memref<2x?x256xf32, strided<[16384, 256, 1], offset: ?>>
// CHECK:               %[[VAL_11:.*]] = memref.subview %[[VAL_1]]{{\[}}%[[VAL_7]], %[[VAL_8]], 0] [2, %[[VAL_9]], 256] [1, 1, 1] : memref<16x64x256xf32> to memref<2x?x256xf32, strided<[16384, 256, 1], offset: ?>>
// CHECK:               linalg.softmax dimension(1) ins(%[[VAL_10]] : memref<2x?x256xf32, strided<[16384, 256, 1], offset: ?>>) outs(%[[VAL_11]] : memref<2x?x256xf32, strided<[16384, 256, 1], offset: ?>>)
// CHECK:             }
// CHECK:           }
// CHECK:           return
// CHECK:         }
func.func @softmax_memref(%arg0: memref<16x64x256xf32>, %arg1: memref<16x64x256xf32>) {
  linalg.softmax
    dimension(1) ins(%arg0 : memref<16x64x256xf32>) outs(%arg1 : memref<16x64x256xf32>)
  return
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["linalg.softmax"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    %1, %loop:2 = transform.structured.tile_using_for %0 tile_sizes [2, 3] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
    transform.yield
  }
}
