// RUN: mlir-opt -pass-pipeline="builtin.module(func.func(convert-elementwise-to-linalg))" -split-input-file %s | FileCheck %s

// In-depth checking of the linalg.generic op for a very trivial case.
// CHECK: #[[$MAP:.*]] = affine_map<() -> ()>
// CHECK-LABEL: func @addf_rank0
// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]*]]: tensor<f32>
// CHECK-SAME: %[[ARG1:[0-9a-zA-Z]*]]: tensor<f32>
func.func @addf_rank0(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<f32> {
  // CHECK: %{{.*}} = linalg.generic
  // CHECK-SAME: indexing_maps = [#[[$MAP]], #[[$MAP]], #[[$MAP]]]
  // CHECK-SAME: iterator_types = []
  // CHECK-SAME: ins(%[[ARG0]], %[[ARG1]]
  // CHECK-SAME: outs(%[[ARG0]]
  // CHECK: ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32, %{{.*}}: f32):
  // CHECK: %[[YIELD:.*]] = arith.addf %[[LHS]], %[[RHS]] : f32
  // CHECK: linalg.yield %[[YIELD]] : f32
  // CHECK: } -> tensor<f32>
  %0 = arith.addf %arg0, %arg1 : tensor<f32>
  return %0 : tensor<f32>
}

// -----

// Check indexing maps and iterator types for the rank > 0 case.
// The map alias is printed at the top of this split's module, before the
// function, so it is matched ahead of the label; `$MAP` is a global FileCheck
// variable and is simply redefined here for the rank-1 identity map.
// CHECK: #[[$MAP:.*]] = affine_map<(d0) -> (d0)>
// CHECK-LABEL: func @addf_rank1
// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]*]]: tensor<?xf32>
// CHECK-SAME: %[[ARG1:[0-9a-zA-Z]*]]: tensor<?xf32>
func.func @addf_rank1(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
  // CHECK: linalg.generic
  // CHECK-SAME: indexing_maps = [#[[$MAP]], #[[$MAP]], #[[$MAP]]]
  // CHECK-SAME: iterator_types = ["parallel"]
  // CHECK-SAME: ins(%[[ARG0]], %[[ARG1]]
  // CHECK-SAME: outs(%[[ARG0]]
  %0 = arith.addf %arg0, %arg1 : tensor<?xf32>
  return %0 : tensor<?xf32>
}

// -----

// Check a unary op.
// CHECK-LABEL: func @exp
// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]*]]: tensor<f32>
func.func @exp(%arg0: tensor<f32>) -> tensor<f32> {
  // CHECK: linalg.generic
  // CHECK-SAME: ins(%[[ARG0]]
  // CHECK-SAME: outs(%[[ARG0]]
  // CHECK: ^bb0(%[[SCALAR:.*]]: f32, %{{.*}}: f32):
  // CHECK: %[[YIELD:.*]] = math.exp %[[SCALAR]] : f32
  // CHECK: linalg.yield %[[YIELD]] : f32
  %0 = math.exp %arg0 : tensor<f32>
  return %0 : tensor<f32>
}

// -----

// Check a case with varying operand types.
// CHECK-LABEL: func @select
// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]*]]: tensor<i1>
// CHECK-SAME: %[[ARG1:[0-9a-zA-Z]*]]: tensor<i32>
// CHECK-SAME: %[[ARG2:[0-9a-zA-Z]*]]: tensor<i32>
func.func @select(%arg0: tensor<i1>, %arg1: tensor<i32>, %arg2: tensor<i32>) -> tensor<i32> {
  // CHECK: linalg.generic
  // CHECK-SAME: ins(%[[ARG0]], %[[ARG1]], %[[ARG2]]
  // CHECK-SAME: outs(%[[ARG1]]
  // CHECK: ^bb0(%[[PRED:.*]]: i1, %[[TRUE_VAL:.*]]: i32, %[[FALSE_VAL:.*]]: i32, %{{.*}}: i32):
  // CHECK: arith.select %[[PRED]], %[[TRUE_VAL]], %[[FALSE_VAL]] : i32
  %0 = arith.select %arg0, %arg1, %arg2 : tensor<i1>, tensor<i32>
  return %0 : tensor<i32>
}

// -----

// Spot-check an op that requires copying attributes properly to the created scalar op.
// Also checks that the i1 result gets a fresh tensor.empty as its init (outs)
// tensor instead of reusing one of the f32 operands.
// CHECK-LABEL: func @cmpf(
// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]*]]: tensor<f32>
// CHECK-SAME: %[[ARG1:[0-9a-zA-Z]*]]: tensor<f32>
func.func @cmpf(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<i1> {
  // CHECK: %[[INIT:.*]] = tensor.empty() : tensor<i1>
  // CHECK: linalg.generic
  // CHECK-SAME: ins(%[[ARG0]], %[[ARG1]]
  // CHECK-SAME: outs(%[[INIT]]
  // CHECK: ^bb0(%{{.*}}: f32, %{{.*}}: f32, %{{.*}}: i1):
  // CHECK: arith.cmpf olt, %{{.*}}, %{{.*}} : f32
  %0 = arith.cmpf olt, %arg0, %arg1 : tensor<f32>
  return %0 : tensor<i1>
}

// -----

// Check proper tensor.empty usage for the init tensor when the shape mixes
// static and dynamic dimensions.
// The three dynamic extents (dimensions 1, 2 and 5) are queried from the first
// operand with tensor.dim and passed to tensor.empty; the static extents
// (4, 8, 2) are carried by the result type alone.
// CHECK-LABEL: func @cmpf(
// CHECK-SAME: %[[LHS:[0-9a-zA-Z]*]]: tensor<4x?x?x8x2x?xf32>
// CHECK-SAME: %[[RHS:[0-9a-zA-Z]*]]: tensor<4x?x?x8x2x?xf32>
func.func @cmpf(%arg0: tensor<4x?x?x8x2x?xf32>, %arg1: tensor<4x?x?x8x2x?xf32>) -> tensor<4x?x?x8x2x?xi1> {
  // CHECK: %[[IDX1:.*]] = arith.constant 1 : index
  // CHECK: %[[DIM1:.*]] = tensor.dim %[[LHS]], %[[IDX1]] : tensor<4x?x?x8x2x?xf32>
  // CHECK: %[[IDX2:.*]] = arith.constant 2 : index
  // CHECK: %[[DIM2:.*]] = tensor.dim %[[LHS]], %[[IDX2]] : tensor<4x?x?x8x2x?xf32>
  // CHECK: %[[IDX5:.*]] = arith.constant 5 : index
  // CHECK: %[[DIM5:.*]] = tensor.dim %[[LHS]], %[[IDX5]] : tensor<4x?x?x8x2x?xf32>
  // CHECK: %[[EMPTY:.*]] = tensor.empty(%[[DIM1]], %[[DIM2]], %[[DIM5]]) : tensor<4x?x?x8x2x?xi1>
  // CHECK: linalg.generic
  // CHECK-SAME: ins(%[[LHS]], %[[RHS]]
  // CHECK-SAME: outs(%[[EMPTY]]
  // CHECK: ^bb0(%{{.*}}: f32, %{{.*}}: f32, %{{.*}}: i1):
  // CHECK: arith.cmpf olt, %{{.*}}, %{{.*}} : f32
  %0 = arith.cmpf olt, %arg0, %arg1 : tensor<4x?x?x8x2x?xf32>
  return %0 : tensor<4x?x?x8x2x?xi1>
}