// RUN: mlir-opt %s -split-input-file --sparse-buffer-rewrite --canonicalize --cse | FileCheck %s

// CHECK-LABEL: func @sparse_push_back(
// CHECK-SAME: %[[A:.*]]: index,
// CHECK-SAME: %[[B:.*]]: memref<?xf64>,
// CHECK-SAME: %[[C:.*]]: f64) -> (memref<?xf64>, index) {
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[P1:.*]] = memref.dim %[[B]], %[[C0]]
// CHECK: %[[S2:.*]] = arith.addi %[[A]], %[[C1]] : index
// CHECK: %[[T:.*]] = arith.cmpi ugt, %[[S2]], %[[P1]]
// CHECK: %[[M:.*]] = scf.if %[[T]] -> (memref<?xf64>) {
// CHECK: %[[P2:.*]] = arith.muli %[[P1]], %[[C2]]
// CHECK: %[[M2:.*]] = memref.realloc %[[B]](%[[P2]])
// CHECK: scf.yield %[[M2]] : memref<?xf64>
// CHECK: } else {
// CHECK: scf.yield %[[B]] : memref<?xf64>
// CHECK: }
// CHECK: memref.store %[[C]], %[[M]]{{\[}}%[[A]]]
// CHECK: return %[[M]], %[[S2]]
func.func @sparse_push_back(%arg0: index, %arg1: memref<?xf64>, %arg2: f64) -> (memref<?xf64>, index) {
  %0:2 = sparse_tensor.push_back %arg0, %arg1, %arg2 : index, memref<?xf64>, f64
  return %0#0, %0#1 : memref<?xf64>, index
}

// -----

// CHECK-LABEL: func @sparse_push_back_n(
// CHECK-SAME: %[[S1:.*]]: index,
// CHECK-SAME: %[[B:.*]]: memref<?xf64>,
// CHECK-SAME: %[[C:.*]]: f64,
// CHECK-SAME: %[[D:.*]]: index) -> (memref<?xf64>, index) {
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
// CHECK: %[[P1:.*]] = memref.dim %[[B]], %[[C0]]
// CHECK: %[[S2:.*]] = arith.addi %[[S1]], %[[D]] : index
// CHECK: %[[T:.*]] = arith.cmpi ugt, %[[S2]], %[[P1]]
// CHECK: %[[M:.*]] = scf.if %[[T]] -> (memref<?xf64>) {
// CHECK: %[[P2:.*]] = scf.while (%[[I:.*]] = %[[P1]]) : (index) -> index {
// CHECK: %[[P3:.*]] = arith.muli %[[I]], %[[C2]] : index
// CHECK: %[[T2:.*]] = arith.cmpi ugt, %[[S2]], %[[P3]] : index
// CHECK: scf.condition(%[[T2]]) %[[P3]] : index
// CHECK: } do {
// CHECK: ^bb0(%[[I2:.*]]: index):
// CHECK: scf.yield %[[I2]] : index
// CHECK: }
// CHECK: %[[M2:.*]] = memref.realloc %[[B]](%[[P2]])
// CHECK: scf.yield %[[M2]] : memref<?xf64>
// CHECK: } else {
// CHECK: scf.yield %[[B]] : memref<?xf64>
// CHECK: }
// CHECK: %[[S:.*]] = memref.subview %[[M]]{{\[}}%[[S1]]] {{\[}}%[[D]]] [1]
// CHECK: linalg.fill ins(%[[C]] : f64) outs(%[[S]]
// CHECK: return %[[M]], %[[S2]] : memref<?xf64>, index
func.func @sparse_push_back_n(%arg0: index, %arg1: memref<?xf64>, %arg2: f64, %arg3: index) -> (memref<?xf64>, index) {
  %0:2 = sparse_tensor.push_back %arg0, %arg1, %arg2, %arg3 : index, memref<?xf64>, f64, index
  return %0#0, %0#1 : memref<?xf64>, index
}

// -----

// CHECK-LABEL: func @sparse_push_back_inbound(
// CHECK-SAME: %[[S1:.*]]: index,
// CHECK-SAME: %[[B:.*]]: memref<?xf64>,
// CHECK-SAME: %[[C:.*]]: f64) -> (memref<?xf64>, index) {
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK: %[[S2:.*]] = arith.addi %[[S1]], %[[C1]]
// CHECK: memref.store %[[C]], %[[B]]{{\[}}%[[S1]]]
// CHECK: return %[[B]], %[[S2]] : memref<?xf64>, index
func.func @sparse_push_back_inbound(%arg0: index, %arg1: memref<?xf64>, %arg2: f64) -> (memref<?xf64>, index) {
  %0:2 = sparse_tensor.push_back inbounds %arg0, %arg1, %arg2 : index, memref<?xf64>, f64
  return %0#0, %0#1 : memref<?xf64>, index
}
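
// Note on the rewrite exercised above (an informal sketch, not a test
// directive): sparse_tensor.push_back appends at index %size and grows the
// buffer only when the new size would exceed the current capacity, doubling
// the capacity until the data fits, roughly:
//
//   %newSize = arith.addi %size, %n            // %n is 1 for the scalar form
//   %needGrow = arith.cmpi ugt, %newSize, %capacity
//   scf.if %needGrow { memref.realloc to a doubled capacity }
//   else { reuse the old buffer }
//
// With the `inbounds` keyword the caller guarantees the capacity is already
// sufficient, so the comparison and realloc are elided entirely, as the
// last case above verifies.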

// -----

#ID_MAP = affine_map<(d0, d1) -> (d0, d1)>

// Only check the generated supporting functions. We have integration tests to
// verify the correctness of the generated code.
//
// CHECK-DAG: func.func private @_sparse_partition_0_1_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) -> index {
// CHECK-DAG: func.func private @_sparse_qsort_0_1_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) {
// CHECK-LABEL: func.func @sparse_sort_coo_quick
func.func @sparse_sort_coo_quick(%arg0: index, %arg1: memref<100xindex>, %arg2: memref<?xf32>, %arg3: memref<10xi32>) -> (memref<100xindex>, memref<?xf32>, memref<10xi32>) {
  sparse_tensor.sort quick_sort %arg0, %arg1 jointly %arg2, %arg3 {perm_map = #ID_MAP, ny = 1 : index} : memref<100xindex> jointly memref<?xf32>, memref<10xi32>
  return %arg1, %arg2, %arg3 : memref<100xindex>, memref<?xf32>, memref<10xi32>
}

// -----

#ID_MAP = affine_map<(d0, d1) -> (d0, d1)>

// Only check the generated supporting functions. We have integration tests to
// verify the correctness of the generated code.
//
// CHECK-DAG: func.func private @_sparse_binary_search_0_1_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) -> index {
// CHECK-DAG: func.func private @_sparse_sort_stable_0_1_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) {
// CHECK-DAG: func.func private @_sparse_shift_down_0_1_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>, %arg5: index) {
// CHECK-DAG: func.func private @_sparse_heap_sort_0_1_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) {
// CHECK-DAG: func.func private @_sparse_partition_0_1_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) -> index {
// CHECK-DAG: func.func private @_sparse_hybrid_qsort_0_1_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>, %arg5: i64) {
// CHECK-LABEL: func.func @sparse_sort_coo_hybrid
func.func @sparse_sort_coo_hybrid(%arg0: index, %arg1: memref<100xindex>, %arg2: memref<?xf32>, %arg3: memref<10xi32>) -> (memref<100xindex>, memref<?xf32>, memref<10xi32>) {
  sparse_tensor.sort hybrid_quick_sort %arg0, %arg1 jointly %arg2, %arg3 {perm_map = #ID_MAP, ny = 1 : index} : memref<100xindex> jointly memref<?xf32>, memref<10xi32>
  return %arg1, %arg2, %arg3 : memref<100xindex>, memref<?xf32>, memref<10xi32>
}
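
// Note on the outlined helpers (an observation, not a test directive): the
// rewriter emits one private func.func per sorting primitive and mangles the
// operand signature into the symbol name, so
// @_sparse_qsort_0_1_index_coo_1_f32_i32 is the quick sort specialized for
// this op's index-typed COO buffer (with ny = 1) and its jointly sorted f32
// and i32 buffers. hybrid_quick_sort pulls in the full set of helpers
// because, in the manner of a classic introsort, it presumably falls back
// to stable insertion sort for small ranges and to heap sort when the quick
// sort recursion gets too deep (hence the extra i64 depth-style parameter
// on @_sparse_hybrid_qsort above).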

// -----

#ID_MAP = affine_map<(d0, d1) -> (d0, d1)>

// Only check the generated supporting functions. We have integration tests to
// verify the correctness of the generated code.
//
// CHECK-DAG: func.func private @_sparse_binary_search_0_1_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) -> index {
// CHECK-DAG: func.func private @_sparse_sort_stable_0_1_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) {
// CHECK-LABEL: func.func @sparse_sort_coo_stable
func.func @sparse_sort_coo_stable(%arg0: index, %arg1: memref<100xindex>, %arg2: memref<?xf32>, %arg3: memref<10xi32>) -> (memref<100xindex>, memref<?xf32>, memref<10xi32>) {
  sparse_tensor.sort insertion_sort_stable %arg0, %arg1 jointly %arg2, %arg3 {perm_map = #ID_MAP, ny = 1 : index} : memref<100xindex> jointly memref<?xf32>, memref<10xi32>
  return %arg1, %arg2, %arg3 : memref<100xindex>, memref<?xf32>, memref<10xi32>
}

// -----

#ID_MAP = affine_map<(d0, d1) -> (d0, d1)>

// Only check the generated supporting functions. We have integration tests to
// verify the correctness of the generated code.
//
// CHECK-DAG: func.func private @_sparse_shift_down_0_1_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>, %arg5: index) {
// CHECK-DAG: func.func private @_sparse_heap_sort_0_1_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) {
// CHECK-LABEL: func.func @sparse_sort_coo_heap
func.func @sparse_sort_coo_heap(%arg0: index, %arg1: memref<100xindex>, %arg2: memref<?xf32>, %arg3: memref<10xi32>) -> (memref<100xindex>, memref<?xf32>, memref<10xi32>) {
  sparse_tensor.sort heap_sort %arg0, %arg1 jointly %arg2, %arg3 {perm_map = #ID_MAP, ny = 1 : index} : memref<100xindex> jointly memref<?xf32>, memref<10xi32>
  return %arg1, %arg2, %arg3 : memref<100xindex>, memref<?xf32>, memref<10xi32>
}
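
// Note on the last two variants (an informal summary, not a test directive):
// insertion_sort_stable preserves the relative order of equal keys and uses
// @_sparse_binary_search to locate each insertion point, while heap_sort
// restores the heap property after each extraction via @_sparse_shift_down;
// both therefore outline only the two helpers listed above.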