// RUN: mlir-opt %s -transform-interpreter -canonicalize --split-input-file | FileCheck %s

// Check that we can tile softmax on tensors.
// The tiling here is 2x3.
// So the shape used in the inner loop should be 2x3x256, however since 3
// doesn't divide the second dimension (64), we should see a '?' in the shape.
// The actual size, used through extract_slice/insert_slice, should be computed
// as `min(64 - current iteration index, 3)`.
// The loop results are matched against the previously captured SSA values so
// that the yield/return chain (inner loop -> outer loop -> return) is verified.

// CHECK: #[[$MIN_MAP:.*]] = affine_map<(d0) -> (-d0 + 64, 3)>
// CHECK-LABEL: func.func @softmax(
// CHECK-SAME:      %[[VAL_0:.*]]: tensor<16x64x256xf32>) -> tensor<16x64x256xf32> {
// CHECK-DAG:       %[[C3:.*]] = arith.constant 3 : index
// CHECK-DAG:       %[[C2:.*]] = arith.constant 2 : index
// CHECK-DAG:       %[[C64:.*]] = arith.constant 64 : index
// CHECK-DAG:       %[[C16:.*]] = arith.constant 16 : index
// CHECK-DAG:       %[[C0:.*]] = arith.constant 0 : index
// CHECK:           %[[TENSOR_EMPTY:.*]] = tensor.empty() : tensor<16x64x256xf32>
// CHECK:           %[[VAL_7:.*]] = scf.for %[[VAL_8:.*]] = %[[C0]] to %[[C16]] step %[[C2]] iter_args(%[[VAL_9:.*]] = %[[TENSOR_EMPTY]]) -> (tensor<16x64x256xf32>) {
// CHECK:             %[[VAL_10:.*]] = scf.for %[[VAL_11:.*]] = %[[C0]] to %[[C64]] step %[[C3]] iter_args(%[[VAL_12:.*]] = %[[VAL_9]]) -> (tensor<16x64x256xf32>) {
// CHECK:               %[[VAL_13:.*]] = affine.min #[[$MIN_MAP]](%[[VAL_11]])
// CHECK:               %[[VAL_14:.*]] = tensor.extract_slice %[[VAL_0]]{{\[}}%[[VAL_8]], %[[VAL_11]], 0] [2, %[[VAL_13]], 256] [1, 1, 1] : tensor<16x64x256xf32> to tensor<2x?x256xf32>
// CHECK:               %[[VAL_15:.*]] = tensor.extract_slice %[[VAL_12]]{{\[}}%[[VAL_8]], %[[VAL_11]], 0] [2, %[[VAL_13]], 256] [1, 1, 1] : tensor<16x64x256xf32> to tensor<2x?x256xf32>
// CHECK:               %[[VAL_16:.*]] = linalg.softmax dimension(1) ins(%[[VAL_14]] : tensor<2x?x256xf32>) outs(%[[VAL_15]] : tensor<2x?x256xf32>) -> tensor<2x?x256xf32>
// CHECK:               %[[VAL_17:.*]] = tensor.insert_slice %[[VAL_16]] into %[[VAL_12]]{{\[}}%[[VAL_8]], %[[VAL_11]], 0] [2, %[[VAL_13]], 256] [1, 1, 1] : tensor<2x?x256xf32> into tensor<16x64x256xf32>
// CHECK:               scf.yield %[[VAL_17]] : tensor<16x64x256xf32>
// CHECK:             }
// CHECK:             scf.yield %[[VAL_10]] : tensor<16x64x256xf32>
// CHECK:           }
// CHECK:           return %[[VAL_7]] : tensor<16x64x256xf32>
// CHECK:         }
func.func @softmax(%arg0: tensor<16x64x256xf32>) -> tensor<16x64x256xf32> {
  %0 = tensor.empty() : tensor<16x64x256xf32>
  %1 = linalg.softmax
         dimension(1) ins(%arg0 : tensor<16x64x256xf32>) outs(%0 : tensor<16x64x256xf32>) -> tensor<16x64x256xf32>
  return %1 : tensor<16x64x256xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["linalg.softmax"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    %1, %loop:2 = transform.structured.tile_using_for %0 tile_sizes [2, 3] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
    transform.yield
  }
}

// -----

// Test the softmax tiling interface with the tile_using_forall transform and
// check that it composes properly with the fuse transform.
// This should sink the linalg.generic inside the scf.forall and run that
// generic on 2x4x256 tensors (2 == 16/8, 4 == 64/16).

// The softmax input is matched against the tiled generic's result (VAL_16)
// rather than a fresh wildcard, so the test actually verifies that the fused
// producer feeds the tiled softmax.

// CHECK: #[[$TIMES2_MAP:.*]] = affine_map<(d0) -> (d0 * 2)>
// CHECK: #[[$TIMES4_MAP:.*]] = affine_map<(d0) -> (d0 * 4)>
// CHECK-LABEL: func.func @softmax_tile_n_fuse(
// CHECK-SAME:      %[[VAL_0:.*]]: tensor<16x64x256xf32>) -> tensor<16x64x256xf32> {
// CHECK:           %[[VAL_1:.*]] = arith.constant 1.000000e+00 : f32
// CHECK:           %[[VAL_2:.*]] = tensor.empty() : tensor<16x64x256xf32>
// CHECK:           %[[VAL_3:.*]] = tensor.empty() : tensor<16x64x256xf32>
// CHECK:           %[[VAL_4:.*]] = scf.forall (%[[VAL_5:.*]], %[[VAL_6:.*]]) in (8, 16) shared_outs(%[[VAL_7:.*]] = %[[VAL_3]]) -> (tensor<16x64x256xf32>) {
// CHECK:             %[[VAL_8:.*]] = affine.apply #[[$TIMES2_MAP]](%[[VAL_5]])
// CHECK:             %[[VAL_9:.*]] = affine.apply #[[$TIMES4_MAP]](%[[VAL_6]])
// CHECK:             %[[VAL_10:.*]] = affine.apply #[[$TIMES2_MAP]](%[[VAL_5]])
// CHECK:             %[[VAL_11:.*]] = affine.apply #[[$TIMES4_MAP]](%[[VAL_6]])
// CHECK:             %[[VAL_12:.*]] = affine.apply #[[$TIMES2_MAP]](%[[VAL_5]])
// CHECK:             %[[VAL_13:.*]] = affine.apply #[[$TIMES4_MAP]](%[[VAL_6]])
// CHECK:             %[[VAL_14:.*]] = tensor.extract_slice %[[VAL_0]]{{\[}}%[[VAL_10]], %[[VAL_11]], 0] [2, 4, 256] [1, 1, 1] : tensor<16x64x256xf32> to tensor<2x4x256xf32>
// CHECK:             %[[VAL_15:.*]] = tensor.extract_slice %[[VAL_2]]{{\[}}%[[VAL_12]], %[[VAL_13]], 0] [2, 4, 256] [1, 1, 1] : tensor<16x64x256xf32> to tensor<2x4x256xf32>
// CHECK:             %[[VAL_16:.*]] = linalg.generic {indexing_maps = [#{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[VAL_14]] : tensor<2x4x256xf32>) outs(%[[VAL_15]] : tensor<2x4x256xf32>) {
// CHECK:             ^bb0(%[[VAL_17:.*]]: f32, %[[VAL_18:.*]]: f32):
// CHECK:               %[[VAL_19:.*]] = arith.addf %[[VAL_18]], %[[VAL_1]] : f32
// CHECK:               linalg.yield %[[VAL_19]] : f32
// CHECK:             } -> tensor<2x4x256xf32>
// CHECK:             %[[VAL_20:.*]] = tensor.extract_slice %[[VAL_7]]{{\[}}%[[VAL_8]], %[[VAL_9]], 0] [2, 4, 256] [1, 1, 1] : tensor<16x64x256xf32> to tensor<2x4x256xf32>
// CHECK:             %[[VAL_21:.*]] = linalg.softmax dimension(1) ins(%[[VAL_16]] : tensor<2x4x256xf32>) outs(%[[VAL_20]] : tensor<2x4x256xf32>) -> tensor<2x4x256xf32>
// CHECK:             scf.forall.in_parallel {
// CHECK:               tensor.parallel_insert_slice %[[VAL_21]] into %[[VAL_7]]{{\[}}%[[VAL_8]], %[[VAL_9]], 0] [2, 4, 256] [1, 1, 1] : tensor<2x4x256xf32> into tensor<16x64x256xf32>
// CHECK:             }
// CHECK:           }
// CHECK:           return %[[VAL_4]] : tensor<16x64x256xf32>
// CHECK:         }

func.func @softmax_tile_n_fuse(%arg0: tensor<16x64x256xf32>) -> tensor<16x64x256xf32> {
  %empty = tensor.empty() : tensor<16x64x256xf32>
  %cst = arith.constant 1.000000e+00 : f32
  %eltwise = linalg.generic
      {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
                        affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
       iterator_types = ["parallel", "parallel", "parallel"]
      }
      ins(%arg0 : tensor<16x64x256xf32>)
      outs(%empty : tensor<16x64x256xf32>) {
    ^bb0(%arg2: f32, %arg3: f32):
      %arg3Plus1 = arith.addf %arg3, %cst : f32
      linalg.yield %arg3Plus1 : f32
    } -> tensor<16x64x256xf32>

  %0 = tensor.empty() : tensor<16x64x256xf32>
  %1 = linalg.softmax
         dimension(1) ins(%eltwise : tensor<16x64x256xf32>) outs(%0 : tensor<16x64x256xf32>) -> tensor<16x64x256xf32>
  return %1 : tensor<16x64x256xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["linalg.softmax"]} in %arg1 : (!transform.any_op) -> !transform.any_op

    // Tile the root.
    %tiled_op, %forall_op = transform.structured.tile_using_forall %0 num_threads [8, 16]
         : (!transform.any_op) -> (!transform.any_op, !transform.any_op)

    // Fuse all producers.
    %1 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    transform.structured.fuse_into_containing_op %1 into %forall_op
      : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
    transform.yield
  }
}

// -----

// Same as the previous test but on memrefs.

// CHECK: #[[$MIN_MAP:.*]] = affine_map<(d0) -> (-d0 + 64, 3)>
// CHECK-LABEL: func.func @softmax_memref(
// CHECK-SAME:      %[[VAL_0:.*]]: memref<16x64x256xf32>,
// CHECK-SAME:      %[[VAL_1:.*]]: memref<16x64x256xf32>) {
// CHECK-DAG:       %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG:       %[[C16:.*]] = arith.constant 16 : index
// CHECK-DAG:       %[[C64:.*]] = arith.constant 64 : index
// CHECK-DAG:       %[[C2:.*]] = arith.constant 2 : index
// CHECK-DAG:       %[[C3:.*]] = arith.constant 3 : index
// CHECK:           scf.for %[[VAL_7:.*]] = %[[C0]] to %[[C16]] step %[[C2]] {
// CHECK:             scf.for %[[VAL_8:.*]] = %[[C0]] to %[[C64]] step %[[C3]] {
// CHECK:               %[[VAL_9:.*]] = affine.min #[[$MIN_MAP]](%[[VAL_8]])
// CHECK:               %[[VAL_10:.*]] = memref.subview %[[VAL_0]]{{\[}}%[[VAL_7]], %[[VAL_8]], 0] [2, %[[VAL_9]], 256] [1, 1, 1] : memref<16x64x256xf32> to memref<2x?x256xf32, strided<[16384, 256, 1], offset: ?>>
// CHECK:               %[[VAL_11:.*]] = memref.subview %[[VAL_1]]{{\[}}%[[VAL_7]], %[[VAL_8]], 0] [2, %[[VAL_9]], 256] [1, 1, 1] : memref<16x64x256xf32> to memref<2x?x256xf32, strided<[16384, 256, 1], offset: ?>>
// CHECK:               linalg.softmax dimension(1) ins(%[[VAL_10]] : memref<2x?x256xf32, strided<[16384, 256, 1], offset: ?>>) outs(%[[VAL_11]] : memref<2x?x256xf32, strided<[16384, 256, 1], offset: ?>>)
// CHECK:             }
// CHECK:           }
// CHECK:           return
// CHECK:         }
func.func @softmax_memref(%arg0: memref<16x64x256xf32>, %arg1: memref<16x64x256xf32>) {
  linalg.softmax
    dimension(1) ins(%arg0 : memref<16x64x256xf32>) outs(%arg1 : memref<16x64x256xf32>)
  return
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["linalg.softmax"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    %1, %loop:2 = transform.structured.tile_using_for %0 tile_sizes [2, 3] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
    transform.yield
  }
}