// DEFINE: %{compile} = mlir-opt %s \
// DEFINE:   -transform-interpreter -test-transform-dialect-erase-schedule |\
// DEFINE: mlir-opt \
// DEFINE:   -test-lower-to-llvm -o %t
// DEFINE: %{entry_point} = main
// DEFINE: %{run} = mlir-runner %t -e %{entry_point} -entry-point-result=void \
// DEFINE:   -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils

// RUN: rm -f %t && %{compile} && %{run} | FileCheck %s

/// End-to-end test for tensor.unpack where one of the inner tile sizes is
/// dynamic.

func.func @main() {
  // Initialise the input. The 123 entries are padding; tensor.unpack below is
  // expected to discard them.
  %A = arith.constant dense<[
    [[[1],
      [2],
      [3],
      [4],
      [5],
      [6],
      [7],
      [123]],
     [[8],
      [9],
      [10],
      [11],
      [12],
      [13],
      [14],
      [123]],
     [[15],
      [16],
      [17],
      [18],
      [19],
      [20],
      [21],
      [123]]]
  ]> : tensor<1x3x8x1xi32>

  %A_cast = tensor.cast %A : tensor<1x3x8x1xi32> to tensor<?x3x?x1xi32>
  func.call @unpack(%A_cast) : (tensor<?x3x?x1xi32>) -> ()

  return
}

func.func private @unpack(%A: tensor<?x3x?x1xi32>) {
  // Dynamic tile size
  %tile_size = arith.constant 8 : index
  %A_unpack_empty = tensor.empty() : tensor<7x3xi32>

  %A_unpack = tensor.unpack %A
    inner_dims_pos = [0, 1]
    inner_tiles = [%tile_size, 1]
    into %A_unpack_empty : tensor<?x3x?x1xi32> -> tensor<7x3xi32>
  %A_cast = tensor.cast %A_unpack : tensor<7x3xi32> to tensor<*xi32>

  // Print the results
  // CHECK: Unranked Memref base@ = 0x{{.*}} rank = 2 offset = 0 sizes = [7, 3] strides = [3, 1] data =
  // CHECK-NEXT: [1, 8, 15],
  // CHECK-NEXT: [2, 9, 16],
  // CHECK-NEXT: [3, 10, 17],
  // CHECK-NEXT: [4, 11, 18],
  // CHECK-NEXT: [5, 12, 19],
  // CHECK-NEXT: [6, 13, 20],
  // CHECK-NEXT: [7, 14, 21]
  call @printMemrefI32(%A_cast) : (tensor<*xi32>) -> ()

  return
}

module @transforms attributes { transform.with_named_sequence } {
  transform.named_sequence @__transform_main(%module: !transform.any_op {transform.consumed}) {
    %unpack = transform.structured.match ops{["tensor.unpack"]} in %module : (!transform.any_op) -> !transform.any_op

    // 1. Tile so that we can decompose the tensor.unpack Op (see step 2). The
    // tile size matching the dynamic inner tile is passed as a transform
    // parameter.
    %c8 = transform.param.constant 8 : i64 -> !transform.param<i64>
    %tiled_unpack_op, %loops:2 = transform.structured.tile_using_for %unpack tile_sizes [%c8, 1]
      : (!transform.any_op, !transform.param<i64>) -> (!transform.any_op, !transform.any_op, !transform.any_op)

    // 2. Decompose the tiled unpack Op into tensor.extract_slice + tensor.insert_slice:
    %func_op = transform.get_parent_op %tiled_unpack_op {isolated_from_above} : (!transform.any_op) -> !transform.op<"func.func">
    transform.apply_patterns to %func_op {
      transform.apply_patterns.linalg.decompose_pack_unpack
      transform.apply_patterns.linalg.decompose_pad
    } : !transform.op<"func.func">

    // 3. Bufferize before lowering to LLVM
    %bufferize = transform.bufferization.one_shot_bufferize %module
      {bufferize_function_boundaries=true} : (!transform.any_op) -> !transform.any_op

    // 4. Canonicalize
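    // Bufferization can leave behind redundant memref casts, copies and dead
    // ops; canonicalizing here folds them away before lowering to LLVM.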
    %func_op_bufferized = transform.structured.match ops{["func.func"]} in %bufferize : (!transform.any_op) -> !transform.op<"func.func">
    transform.apply_patterns to %func_op_bufferized {
      transform.apply_patterns.canonicalization
    } : !transform.op<"func.func">

    transform.yield
  }
}

func.func private @printMemrefI32(%ptr : tensor<*xi32>)
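
/// Note: after step 1 (tiling) and step 2 (decomposition), each loop iteration
/// is expected to move data via plain slice ops rather than a tensor.unpack.
/// A rough per-iteration sketch (illustrative only; value names, offsets and
/// sizes here are assumptions, not verified compiler output):
///
///   %tile = tensor.extract_slice %A[...] [1, 1, %tile_size, 1] [1, 1, 1, 1]
///   %res  = tensor.insert_slice %tile into %acc[...] [...] [...]
///
/// The dynamic inner tile size surfaces as a dynamic size in these slice ops.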