// DEFINE: %{compile} = mlir-opt %s \
// DEFINE:  -transform-interpreter -test-transform-dialect-erase-schedule \
// DEFINE:  -one-shot-bufferize="bufferize-function-boundaries" -buffer-deallocation-pipeline -cse -canonicalize -convert-vector-to-scf -test-lower-to-llvm -o %t
// DEFINE: %{entry_point} = mmt4d
// DEFINE: %{run} = mlir-runner %t -e %{entry_point} -entry-point-result=void \
// DEFINE:    -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils

// RUN: %{compile}

// RUN: %{run} | FileCheck %s

// End-to-end test for linalg.mmt4d: computes mmt4d(A, B) accumulated into the
// constant %C_in, after the op has been tiled and vectorized by the transform
// schedule in @transforms, and prints the result for FileCheck.
func.func @mmt4d() {
  // Shape-only destination tensors for the two fills below (tensor.empty
  // carries no data). The accumulator %C_in is a constant, so it needs no
  // separate destination.
  %A_alloc = tensor.empty() : tensor<2x2x3x1xi32>
  %B_alloc = tensor.empty() : tensor<2x2x3x1xi32>
  %C_in = arith.constant dense<[
    [[[ 1,  2,  3],
      [ 4,  5,  6],
      [ 7,  8,  9]],
     [[11, 12, 13],
      [14, 15, 16],
      [17, 18, 19]]],
    [[[21, 22, 23],
      [24, 25, 26],
      [27, 28, 29]],
     [[31, 32, 33],
      [34, 35, 36],
      [37, 38, 39]]]
  ]> : tensor<2x2x3x3xi32>

  // Initialise the inputs: every element of A is 3 and every element of B is 4.
  %three = arith.constant 3 : i32
  %four = arith.constant 4 : i32
  %A = linalg.fill ins(%three : i32) outs(%A_alloc : tensor<2x2x3x1xi32>) -> tensor<2x2x3x1xi32>
  %B = linalg.fill ins(%four : i32) outs(%B_alloc : tensor<2x2x3x1xi32>) -> tensor<2x2x3x1xi32>

  // Matmul (accumulates into %C_in)
  %C_out = linalg.mmt4d ins(%A, %B: tensor<2x2x3x1xi32>, tensor<2x2x3x1xi32>) outs(%C_in: tensor<2x2x3x3xi32>) -> tensor<2x2x3x3xi32>

  // Print and verify the output.
  // The reduction covers K = 2 outer x K0 = 1 inner elements, so every output
  // element is C_in + 2 * (3 * 4) = C_in + 24.
  // CHECK: Unranked Memref {{.*}} rank = 4 offset = 0 sizes = [2, 2, 3, 3] strides = [18, 9, 3, 1] data =
  // C[0, 0]
  // CHECK-NEXT: [25, 26, 27]
  // CHECK-NEXT: [28, 29, 30]
  // CHECK-NEXT: [31, 32, 33]
  // C[0, 1]
  // CHECK-NEXT: [35, 36, 37]
  // CHECK-NEXT: [38, 39, 40]
  // CHECK-NEXT: [41, 42, 43]
  // C[1, 0]
  // CHECK-NEXT: [45, 46, 47]
  // CHECK-NEXT: [48, 49, 50]
  // CHECK-NEXT: [51, 52, 53]
  // C[1, 1]
  // CHECK-NEXT: [55, 56, 57]
  // CHECK-NEXT: [58, 59, 60]
  // CHECK-NEXT: [61, 62, 63]

  %xf = tensor.cast %C_out : tensor<2x2x3x3xi32> to tensor<*xi32>
  call @printMemrefI32(%xf) : (tensor<*xi32>) -> ()

  return
}

module @transforms attributes { transform.with_named_sequence } {
  transform.named_sequence @__transform_main(%module: !transform.any_op {transform.readonly}) {
    %mmt4d = transform.collect_matching @match_mmt4d in %module : (!transform.any_op) -> (!transform.any_op)
    %func = transform.get_parent_op %mmt4d {isolated_from_above} : (!transform.any_op) -> !transform.op<"func.func">

    // Step 1: Tile
    // The mmt4d iteration space is (M, N, K, M0, N0, K0).
    // Tile the parallel dims: M and N by 1, M0 and N0 by 3 (the full inner tile).
    %tiled_linalg_op_p, %loops:4 = transform.structured.tile_using_for %mmt4d tile_sizes [1, 1, 0, 3, 3, 0]
      : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
    // Tile the reduction dims: K and K0 by 1.
    %tiled_linalg_op_r, %loops2:2 = transform.structured.tile_using_for %tiled_linalg_op_p tile_sizes [0, 0, 1, 0, 0, 1]
      : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)

    // Step 2: Vectorize
    transform.structured.vectorize %tiled_linalg_op_r : !transform.any_op

    // Step 3: Simplify
    // vector.multi_reduction --> vector.contract
    // Generates a 6-dim vector.contract whose dims match the iteration space of
    // the original linalg.mmt4d op, with the following split into parallel and
    // reduction dims:
    //   * parallel, parallel, reduction, parallel, parallel, reduction
    transform.apply_patterns to %func {
      transform.apply_patterns.vector.reduction_to_contract
      // Reduce the rank of xfer ops. This transforms vector.contract to be
      // more matmul-like and enables lowering to outer-product ops.
      transform.apply_patterns.vector.transfer_permutation_patterns
    } : !transform.op<"func.func">

    // Hoisting and LICM - not strictly required
    %func_h = transform.structured.hoist_redundant_vector_transfers %func
      : (!transform.op<"func.func">) -> !transform.op<"func.func">
    %all_loops = transform.structured.match interface{LoopLikeInterface} in %func_h
      : (!transform.op<"func.func">) -> !transform.any_op
    transform.apply_licm to %all_loops : !transform.any_op
    transform.loop.hoist_loop_invariant_subsets %all_loops : !transform.any_op

    // Simplify the 6-dim vector.contract into a 3-dim matmul-like
    // vector.contract with the following split into parallel and reduction
    // dims:
    //   * parallel, parallel, reduction
    transform.apply_patterns to %func_h {
      transform.apply_patterns.vector.reduction_to_contract
      transform.apply_patterns.vector.cast_away_vector_leading_one_dim
      transform.apply_patterns.canonicalization
    } : !transform.op<"func.func">
    transform.yield
  }

  // Matcher used by transform.collect_matching: yields %entry when it is a
  // linalg.mmt4d op.
  transform.named_sequence @match_mmt4d(
      %entry: !transform.any_op {transform.readonly}) -> !transform.any_op {
    transform.match.operation_name %entry ["linalg.mmt4d"] : !transform.any_op
    transform.yield %entry : !transform.any_op
  }
}

func.func private @printMemrefI32(%ptr : tensor<*xi32>)