// RUN: mlir-opt %s \
// RUN:   -canonicalize="test-convergence" \
// RUN:   --split-input-file -allow-unregistered-dialect | \
// RUN:   FileCheck %s

// Basic folding of to_tensor(to_memref(t)) -> t
// CHECK-LABEL: func @tensor_load_of_buffer_cast(
func.func @tensor_load_of_buffer_cast(%arg0: tensor<?xf32>) -> tensor<?xf32> {
  %0 = bufferization.to_memref %arg0 : tensor<?xf32> to memref<?xf32>
  %1 = bufferization.to_tensor %0 : memref<?xf32> to tensor<?xf32>
  return %1 : tensor<?xf32>
}
// CHECK-SAME: %[[TENSOR:.*]]: tensor<?xf32>) -> tensor<?xf32> {
// CHECK: return %[[TENSOR]]

// -----

// Basic folding of to_memref(to_tensor(m)) -> m
// CHECK-LABEL: func @buffer_cast_of_tensor_load(
func.func @buffer_cast_of_tensor_load(%arg0: memref<?xf32>) -> memref<?xf32> {
  %0 = bufferization.to_tensor %arg0 : memref<?xf32> to tensor<?xf32>
  %1 = bufferization.to_memref %0 : tensor<?xf32> to memref<?xf32>
  return %1 : memref<?xf32>
}
// CHECK-SAME: %[[MEMREF:.*]]: memref<?xf32>) -> memref<?xf32> {
// CHECK: return %[[MEMREF]]

// -----

// If the memrefs are not the same type, don't fold them.
// If the memrefs are not cast-compatible (e.g. different address space),
// don't canonicalize them either.
// CHECK-LABEL: func @no_fold_buffer_cast_of_tensor_load(
// CHECK-SAME: %[[MEMREF_ADDRSPACE2:.*]]: memref<?xf32, 2>)
// CHECK-SAME: -> memref<?xf32, 7> {
// CHECK: %[[TENSOR:.*]] = bufferization.to_tensor
// CHECK-SAME: %[[MEMREF_ADDRSPACE2]] : memref<?xf32, 2> to tensor<?xf32, 7 : i64>
// CHECK: %[[MEMREF_ADDRSPACE7:.*]] = bufferization.to_memref
// CHECK-SAME: %[[TENSOR]] : tensor<?xf32, 7 : i64> to memref<?xf32, 7>
// CHECK: return %[[MEMREF_ADDRSPACE7]]
func.func @no_fold_buffer_cast_of_tensor_load(%arg0: memref<?xf32, 2>)
    -> memref<?xf32, 7> {
  %0 = bufferization.to_tensor %arg0 : memref<?xf32, 2> to tensor<?xf32, 7>
  %1 = bufferization.to_memref %0 : tensor<?xf32, 7> to memref<?xf32, 7>
  return %1 : memref<?xf32, 7>
}

// -----

// If the memrefs are definitely cast-compatible, canonicalize to cast.
// CHECK-LABEL: func @canonicalize_buffer_cast_of_tensor_load(
// CHECK-SAME: %[[M:.*]]: memref<?xf32, strided<[1], offset: 3>>)
// CHECK-SAME: -> memref<?xf32, strided<[1], offset: ?>> {
// CHECK-NOT: bufferization.to_tensor
// CHECK-NOT: bufferization.to_memref
// CHECK: %[[R:.*]] = memref.cast %[[M]]
// CHECK-SAME: memref<?xf32, strided<[1], offset: 3>> to memref<?xf32, strided<[1], offset: ?>>
// CHECK: return %[[R]]
func.func @canonicalize_buffer_cast_of_tensor_load(
    %arg0: memref<?xf32, strided<[1], offset: 3>>)
    -> memref<?xf32, strided<[1], offset: ?>>
{
  %0 = bufferization.to_tensor %arg0 : memref<?xf32, strided<[1], offset: 3>> to tensor<?xf32>
  %1 = bufferization.to_memref %0 : tensor<?xf32> to memref<?xf32, strided<[1], offset: ?>>
  return %1 : memref<?xf32, strided<[1], offset: ?>>
}
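
// The distinction between this test and the next one (an informal note, not
// checked by FileCheck): erasing static layout information, as in
//   %c = memref.cast %m : memref<?xf32, strided<[1], offset: 3>>
//                      to memref<?xf32, strided<[1], offset: ?>>
// is always valid, so the round-trip above folds to a plain memref.cast.
// In the other direction the dynamic offset cannot be proven to equal 3, so
// a cast might be wrong at runtime; the canonicalizer instead allocates a
// buffer with the requested layout and copies into it, as checked below.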

// -----

// If the memrefs are potentially cast-compatible, canonicalize to copy.
// CHECK-LABEL: func @canonicalize_buffer_cast_of_tensor_load_to_copy(
func.func @canonicalize_buffer_cast_of_tensor_load_to_copy(
    %arg0: memref<?xf32, strided<[1], offset: ?>>)
    -> memref<?xf32, strided<[1], offset: 3>> {
  %0 = bufferization.to_tensor %arg0 : memref<?xf32, strided<[1], offset: ?>> to tensor<?xf32>
  %1 = bufferization.to_memref %0 : tensor<?xf32> to memref<?xf32, strided<[1], offset: 3>>
  return %1 : memref<?xf32, strided<[1], offset: 3>>
}
// CHECK-SAME: %[[M:.*]]: memref<?xf32, strided<[1], offset: ?>>)
// CHECK-SAME: -> memref<?xf32, strided<[1], offset: 3>> {
// CHECK-NOT: bufferization.to_tensor
// CHECK-NOT: bufferization.to_memref
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[DIM:.*]] = memref.dim %[[M]], %[[C0]] : memref<?xf32, strided<[1], offset: ?>>
// CHECK: %[[ALLOC:.*]] = memref.alloc(%[[DIM]]) : memref<?xf32, strided<[1], offset: 3>>
// CHECK: memref.copy %[[M]], %[[ALLOC]]
// CHECK-SAME: memref<?xf32, strided<[1], offset: ?>> to memref<?xf32, strided<[1], offset: 3>>
// CHECK: return %[[ALLOC]]

// -----

// Basic folding of tensor.dim(to_tensor(m)) -> memref.dim(m).
// CHECK-LABEL: func @dim_of_tensor_load(
// CHECK-SAME: %[[MEMREF:[0-9a-z]*]]: memref<?xf32>
// CHECK: %[[C0:.*]] = arith.constant 0
// CHECK: %[[D:.*]] = memref.dim %[[MEMREF]], %[[C0]]
// CHECK: return %[[D]] : index
func.func @dim_of_tensor_load(%arg0: memref<?xf32>) -> index {
  %c0 = arith.constant 0 : index
  %0 = bufferization.to_tensor %arg0 : memref<?xf32> to tensor<?xf32>
  %1 = tensor.dim %0, %c0 : tensor<?xf32>
  return %1 : index
}

// -----

// A clone whose source is deallocated right after folds to the source.
// CHECK-LABEL: @clone_before_dealloc
func.func @clone_before_dealloc(%arg0: memref<?xf32>) -> memref<?xf32> {
  %0 = bufferization.clone %arg0 : memref<?xf32> to memref<?xf32>
  memref.dealloc %arg0 : memref<?xf32>
  return %0 : memref<?xf32>
}
// CHECK-SAME: %[[ARG:.*]]: memref<?xf32>
// CHECK-NEXT: return %[[ARG]]

// -----

// A clone that is deallocated after its last use folds away; its uses are
// redirected to the source.
// CHECK-LABEL: @clone_before_dealloc
func.func @clone_before_dealloc(%arg0: memref<?xf32>) -> memref<?xf32> {
  %0 = bufferization.clone %arg0 : memref<?xf32> to memref<?xf32>
  "use"(%0) : (memref<?xf32>) -> ()
  memref.dealloc %0 : memref<?xf32>
  return %arg0 : memref<?xf32>
}
// CHECK-SAME: %[[ARG:.*]]: memref<?xf32>
// CHECK-NEXT: "use"(%[[ARG]])
// CHECK-NEXT: return %[[ARG]]

// -----

// A memref.cast feeding a clone folds into the clone itself.
// CHECK-LABEL: @clone_after_cast
func.func @clone_after_cast(%arg0: memref<?xf32>) -> memref<32xf32> {
  %0 = memref.cast %arg0 : memref<?xf32> to memref<32xf32>
  %1 = bufferization.clone %0 : memref<32xf32> to memref<32xf32>
  return %1 : memref<32xf32>
}
// CHECK-SAME: %[[ARG:.*]]: memref<?xf32>
// CHECK-NEXT: bufferization.clone %[[ARG]] : memref<?xf32> to memref<32xf32>
// CHECK-NOT: memref.cast

// -----

// A clone/dealloc pair with differing but cast-compatible types folds to a
// memref.cast of the source.
// CHECK-LABEL: @clone_and_cast
func.func @clone_and_cast(%arg0: memref<?xf32>) -> memref<32xf32> {
  %0 = bufferization.clone %arg0 : memref<?xf32> to memref<32xf32>
  memref.dealloc %arg0 : memref<?xf32>
  return %0 : memref<32xf32>
}
// CHECK-SAME: %[[ARG:.*]]: memref<?xf32>
// CHECK-NEXT: %[[RES:.*]] = memref.cast %[[ARG]]
// CHECK-SAME: memref<?xf32> to memref<32xf32>
// CHECK-NEXT: return %[[RES]]
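
// The clone tests above and below exercise the SimplifyClones pattern. As a
// rough sketch of its effect (the real preconditions are more involved):
// when a bufferization.clone is paired with exactly one memref.dealloc of
// either the source or the clone, and the deallocated value has no uses
// after that point, the clone/dealloc pair folds away, inserting a
// memref.cast where the two types differ, e.g.
//   %0 = bufferization.clone %arg0 : memref<?xf32> to memref<32xf32>
//   memref.dealloc %arg0 : memref<?xf32>
// becomes
//   %0 = memref.cast %arg0 : memref<?xf32> to memref<32xf32>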

// -----

// The clone is kept: memref<32xf32, strided<[2]>> and memref<32xf32> are not
// cast-compatible, so no memref.cast can replace it.
// CHECK-LABEL: @clone_incompatible
func.func @clone_incompatible(%arg0: memref<32xf32, strided<[2]>>) -> memref<32xf32> {
  %0 = bufferization.clone %arg0 : memref<32xf32, strided<[2]>> to memref<32xf32>
  memref.dealloc %arg0 : memref<32xf32, strided<[2]>>
  return %0 : memref<32xf32>
}
// CHECK-SAME: %[[ARG:.*]]: memref<32xf32, strided<[2]>>
// CHECK-NEXT: bufferization.clone %[[ARG]] : memref<32xf32, strided<[2]>> to memref<32xf32>
// CHECK-NOT: memref.cast

// -----

// The clone must stay: the buffer aliased by its source is freed before the
// clone's result is used.
// CHECK-LABEL: @alias_is_freed
func.func @alias_is_freed(%arg0 : memref<?xf32>) {
  %0 = memref.cast %arg0 : memref<?xf32> to memref<32xf32>
  %1 = bufferization.clone %0 : memref<32xf32> to memref<32xf32>
  memref.dealloc %arg0 : memref<?xf32>
  "use"(%1) : (memref<32xf32>) -> ()
  memref.dealloc %1 : memref<32xf32>
  return
}
// CHECK: bufferization.clone
// CHECK: memref.dealloc
// CHECK: memref.dealloc

// -----

// Verify SimplifyClones skips clones whose source is deallocated more than
// once.
// CHECK-LABEL: @clone_multiple_dealloc_of_source
func.func @clone_multiple_dealloc_of_source(%arg0: memref<?xf32>) -> memref<?xf32> {
  %0 = bufferization.clone %arg0 : memref<?xf32> to memref<?xf32>
  "if_else"() ({
    memref.dealloc %arg0 : memref<?xf32>
  }, {
    memref.dealloc %arg0 : memref<?xf32>
  }) : () -> ()
  return %0 : memref<?xf32>
}
// CHECK-SAME: %[[ARG:.*]]: memref<?xf32>
// CHECK-NEXT: %[[RES:.*]] = bufferization.clone %[[ARG]]
// CHECK: memref.dealloc %[[ARG]]
// CHECK: memref.dealloc %[[ARG]]
// CHECK: return %[[RES]]

// -----

// Likewise for clones that are themselves deallocated more than once.
// CHECK-LABEL: @clone_multiple_dealloc_of_clone
// CHECK-SAME: %[[ARG:.*]]: memref<?xf32>
func.func @clone_multiple_dealloc_of_clone(%arg0: memref<?xf32>) -> memref<?xf32> {
  // CHECK-NEXT: %[[CLONE:.*]] = bufferization.clone %[[ARG]]
  // CHECK: memref.dealloc %[[CLONE]]
  // CHECK: memref.dealloc %[[CLONE]]
  // CHECK: return %[[ARG]]
  %0 = bufferization.clone %arg0 : memref<?xf32> to memref<?xf32>
  "use"(%0) : (memref<?xf32>) -> ()
  "if_else"() ({
    memref.dealloc %0 : memref<?xf32>
  }, {
    memref.dealloc %0 : memref<?xf32>
  }) : () -> ()
  return %arg0 : memref<?xf32>
}

// -----

// Verify SimplifyClones skips clones followed by a realloc.
// CHECK-LABEL: @clone_and_realloc
func.func @clone_and_realloc(%arg0: memref<?xf32>) {
  %0 = bufferization.clone %arg0 : memref<?xf32> to memref<32xf32>
  "use"(%0) : (memref<32xf32>) -> ()
  %1 = memref.realloc %0 : memref<32xf32> to memref<64xf32>
  memref.dealloc %1 : memref<64xf32>
  return
}
// CHECK-SAME: %[[ARG:.*]]: memref<?xf32>
// CHECK-NOT: %cast = memref.cast %[[ARG]]
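
// Why the pattern bails out here (an informal rationale, not checked by
// FileCheck): if the clone were folded into a cast of %arg0, the
// realloc/dealloc pair would resize and free a buffer that is still owned by
// the caller, i.e. the folded form
//   %0 = memref.cast %arg0 : memref<?xf32> to memref<32xf32>
//   %1 = memref.realloc %0 : memref<32xf32> to memref<64xf32>
// would operate on the caller's allocation rather than on a private copy.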

// -----

// Verify SimplifyClones skips clones with a preceding deallocation of the
// source: the buffer is already freed when the clone executes, so the clone
// cannot be replaced by a cast of it.
// CHECK-LABEL: @clone_and_preceding_dealloc
func.func @clone_and_preceding_dealloc(%arg0: memref<?xf32>) -> memref<32xf32> {
  memref.dealloc %arg0 : memref<?xf32>
  %0 = bufferization.clone %arg0 : memref<?xf32> to memref<32xf32>
  return %0 : memref<32xf32>
}
// CHECK-SAME: %[[ARG:.*]]: memref<?xf32>
// CHECK-NOT: %cast = memref.cast %[[ARG]]

// -----

// Folding of to_memref(tensor.cast(t)) -> memref.cast(to_memref(t)).
// CHECK-LABEL: func @tensor_cast_to_memref
// CHECK-SAME: %[[ARG0:.+]]: tensor<4x6x16x32xi8>
func.func @tensor_cast_to_memref(%arg0 : tensor<4x6x16x32xi8>) ->
    memref<?x?x16x32xi8> {
  %0 = tensor.cast %arg0 : tensor<4x6x16x32xi8> to tensor<?x?x16x32xi8>
  %1 = bufferization.to_memref %0 : tensor<?x?x16x32xi8> to memref<?x?x16x32xi8>
  return %1 : memref<?x?x16x32xi8>
}
// CHECK: %[[M:.+]] = bufferization.to_memref %[[ARG0]] : tensor<4x6x16x32xi8>
// CHECK: %[[M1:.+]] = memref.cast %[[M]]
// CHECK-SAME: memref<4x6x16x32xi8> to memref<?x?x16x32xi8>
// CHECK: return %[[M1]] : memref<?x?x16x32xi8>

// -----

// Basic folding of memref.load(to_memref(%v), %idxs) -> tensor.extract(%v, %idxs)
// CHECK-LABEL: func @load_from_buffer_cast(
func.func @load_from_buffer_cast(%arg0: index, %arg1: index,
                                 %arg2: tensor<?x?xf32>) -> f32 {
  %0 = bufferization.to_memref %arg2 : tensor<?x?xf32> to memref<?x?xf32>
  %1 = memref.load %0[%arg0, %arg1] : memref<?x?xf32>
  return %1 : f32
}
// CHECK-SAME: %[[IDX0:[0-9a-z]+]]: index, %[[IDX1:[0-9a-z]+]]: index
// CHECK-SAME: %[[TENSOR:[0-9a-z]+]]: tensor<?x?xf32>
// CHECK: %[[RES:.*]] = tensor.extract %[[TENSOR]][%[[IDX0]], %[[IDX1]]]
// CHECK-NOT: memref.load
// CHECK: return %[[RES]] : f32

// -----

// Constant sizes passed to alloc_tensor fold into static dimensions of the
// result type.
func.func @alloc_tensor_canonicalize() -> (tensor<4x5x?xf32>) {
  %c6 = arith.constant 6 : index
  %0 = bufferization.alloc_tensor(%c6) : tensor<4x5x?xf32>
  return %0 : tensor<4x5x?xf32>
}
// CHECK: func @alloc_tensor_canonicalize
// CHECK: %[[T0:.+]] = bufferization.alloc_tensor() : tensor<4x5x6xf32>
// CHECK: %[[T1:.+]] = tensor.cast %[[T0]] : tensor<4x5x6xf32> to tensor<4x5x?xf32>
// CHECK: return %[[T1]]

// -----

// A clone of a locally allocated buffer that is deallocated afterwards folds
// away together with the dealloc.
func.func @dealloc_canonicalize_clone_removal(%arg0: memref<?xindex>) -> memref<*xf32> {
  %c1 = arith.constant 1 : index
  %0 = memref.alloc(%c1) : memref<?xf32>
  %1 = memref.reshape %0(%arg0) : (memref<?xf32>, memref<?xindex>) -> memref<*xf32>
  %2 = bufferization.clone %1 : memref<*xf32> to memref<*xf32>
  memref.dealloc %0 : memref<?xf32>
  return %2 : memref<*xf32>
}
// CHECK-LABEL: @dealloc_canonicalize_clone_removal
// CHECK-NOT: bufferization.clone
// CHECK-NOT: memref.dealloc
// CHECK: return {{.*}}
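
// The remaining tests exercise bufferization.dealloc. As a rough sketch of
// its semantics (the op documentation is authoritative): each memref in the
// first operand list is deallocated iff its paired i1 condition is true and
// the memref does not alias one of the retained memrefs; one i1 result is
// produced per retained memref, indicating whether it aliases a buffer whose
// condition was true. For example,
//   %r = bufferization.dealloc (%buf : memref<2xi32>) if (%cond)
//          retain (%keep : memref<2xi32>)
// frees %buf when %cond is true and %buf does not alias %keep, and %r
// reports whether %keep aliases the conditionally freed buffer.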

// -----

// Duplicate memrefs in the dealloc list and duplicate retained values are
// deduplicated; the conditions of merged entries are OR'ed together.
func.func @dealloc_canonicalize_duplicates(%arg0: memref<2xi32>, %arg1: i1, %arg2: i1, %arg3: memref<2xi32>, %arg4: memref<2xi32>, %arg5: memref<2xi32>) -> (i1, i1, i1) {
  %0:3 = bufferization.dealloc (%arg4, %arg0, %arg0 : memref<2xi32>, memref<2xi32>, memref<2xi32>) if (%arg1, %arg1, %arg1) retain (%arg3, %arg5, %arg3 : memref<2xi32>, memref<2xi32>, memref<2xi32>)
  bufferization.dealloc (%arg0, %arg0 : memref<2xi32>, memref<2xi32>) if (%arg1, %arg2)
  return %0#0, %0#1, %0#2 : i1, i1, i1
}

// CHECK-LABEL: func @dealloc_canonicalize_duplicates
// CHECK-SAME: ([[ARG0:%.+]]: memref<2xi32>, [[ARG1:%.+]]: i1, [[ARG2:%.+]]: i1, [[ARG3:%.+]]: memref<2xi32>, [[ARG4:%.+]]: memref<2xi32>, [[ARG5:%.+]]: memref<2xi32>)
// CHECK-NEXT: [[V0:%.+]]:2 = bufferization.dealloc ([[ARG4]], [[ARG0]] : memref<2xi32>, memref<2xi32>) if ([[ARG1]], [[ARG1]]) retain ([[ARG3]], [[ARG5]] : memref<2xi32>, memref<2xi32>)
// CHECK-NEXT: [[NEW_COND:%.+]] = arith.ori [[ARG1]], [[ARG2]] : i1
// CHECK-NEXT: bufferization.dealloc ([[ARG0]] : memref<2xi32>) if ([[NEW_COND]])
// CHECK-NEXT: return [[V0]]#0, [[V0]]#1, [[V0]]#0 :

// -----

// A dealloc without memref operands is erased; a retain-only dealloc folds
// its results to constant false.
func.func @dealloc_erase_empty(%arg0: memref<2xi32>, %arg1: i1, %arg2: memref<2xi32>) -> i1 {
  bufferization.dealloc
  %0 = bufferization.dealloc retain (%arg0 : memref<2xi32>)
  return %0 : i1
}

// CHECK-LABEL: func @dealloc_erase_empty
// CHECK-SAME: ([[ARG0:%.+]]: memref<2xi32>, [[ARG1:%.+]]: i1, [[ARG2:%.+]]: memref<2xi32>)
// CHECK-NEXT: [[FALSE:%.+]] = arith.constant false
// CHECK-NEXT: return [[FALSE]] :

// -----

// Entries whose condition is statically false are dropped from the dealloc.
func.func @dealloc_always_false_condition(%arg0: memref<2xi32>, %arg1: memref<2xi32>, %arg2: i1) {
  %false = arith.constant false
  bufferization.dealloc (%arg0, %arg1 : memref<2xi32>, memref<2xi32>) if (%false, %arg2)
  return
}

// CHECK-LABEL: func @dealloc_always_false_condition
// CHECK-SAME: ([[ARG0:%.+]]: memref<2xi32>, [[ARG1:%.+]]: memref<2xi32>, [[ARG2:%.+]]: i1)
// CHECK-NEXT: bufferization.dealloc ([[ARG1]] : {{.*}}) if ([[ARG2]])
// CHECK-NEXT: return

// -----

// The base memref of extract_strided_metadata(%alloc) is known to be the
// allocation itself, so the dealloc can use %alloc directly; for the
// function argument the extracted base must be kept.
func.func @dealloc_base_memref_extract_of_alloc(%arg0: memref<2xi32>, %arg1: i1, %arg2: i1, %arg3: memref<2xi32>) -> memref<2xi32> {
  %alloc = memref.alloc() : memref<2xi32>
  %base0, %size0, %stride0, %offset0 = memref.extract_strided_metadata %alloc : memref<2xi32> -> memref<i32>, index, index, index
  %base1, %size1, %stride1, %offset1 = memref.extract_strided_metadata %arg3 : memref<2xi32> -> memref<i32>, index, index, index
  bufferization.dealloc (%base0, %arg0, %base1 : memref<i32>, memref<2xi32>, memref<i32>) if (%arg1, %arg2, %arg2)
  return %alloc : memref<2xi32>
}

// CHECK-LABEL: func @dealloc_base_memref_extract_of_alloc
// CHECK-SAME: ([[ARG0:%.+]]: memref<2xi32>, [[ARG1:%.+]]: i1, [[ARG2:%.+]]: i1, [[ARG3:%.+]]: memref<2xi32>)
// CHECK-NEXT: [[ALLOC:%.+]] = memref.alloc() : memref<2xi32>
// CHECK-NEXT: [[BASE:%[a-zA-Z0-9_]+]]{{.*}} = memref.extract_strided_metadata [[ARG3]] :
// CHECK-NEXT: bufferization.dealloc ([[ALLOC]], [[ARG0]], [[BASE]] : memref<2xi32>, memref<2xi32>, memref<i32>) if ([[ARG1]], [[ARG2]], [[ARG2]])
// CHECK-NEXT: return

// -----

// An alloc that is unconditionally deallocated and has no other uses is
// removed together with its dealloc entry.
func.func @dealloc_remove_dead_alloc(%arg0: memref<2xi32>) {
  %true = arith.constant true
  %alloc = memref.alloc() : memref<2xi32>
  bufferization.dealloc (%alloc, %arg0 : memref<2xi32>, memref<2xi32>) if (%true, %true)
  return
}

// CHECK-LABEL: func @dealloc_remove_dead_alloc
// CHECK-SAME: ([[ARG0:%.+]]: memref<2xi32>)
// CHECK-NOT: memref.alloc(
// CHECK: bufferization.dealloc ([[ARG0]] : memref<2xi32>) if (%true

// -----

// A negative constant passed as a size must not be folded into a static
// dimension of the alloc_tensor result type.
// CHECK-LABEL: func @negative_input
func.func @negative_input() -> tensor<?x?x?xf16> {
  %idx27 = index.constant 27
  %idx-3 = index.constant -3
  %c10 = arith.constant 10 : index
// CHECK: bufferization.alloc_tensor
// CHECK-SAME: tensor<10x?x27xf16>
  %11 = bufferization.alloc_tensor(%c10, %idx-3, %idx27) : tensor<?x?x?xf16>
  return %11 : tensor<?x?x?xf16>
}
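
// For reference (an informal sketch of the folding exercised above and in
// @alloc_tensor_canonicalize, not a checked test): alloc_tensor folds
// non-negative constant sizes into its result type, e.g.
//   %t = bufferization.alloc_tensor(%c6) : tensor<4x5x?xf32>
// canonicalizes to
//   %s = bufferization.alloc_tensor() : tensor<4x5x6xf32>
//   %t = tensor.cast %s : tensor<4x5x6xf32> to tensor<4x5x?xf32>
// whereas the negative %idx-3 above is left as a dynamic dimension, yielding
// tensor<10x?x27xf16>.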