xref: /llvm-project/mlir/test/Integration/Dialect/Linalg/CPU/mmt4d.mlir (revision eb206e9ea84eff0a0596fed2de8316d924f946d1)
1// DEFINE: %{compile} =  mlir-opt %s \
2// DEFINE:    -transform-interpreter -test-transform-dialect-erase-schedule \
3// DEFINE:    -one-shot-bufferize="bufferize-function-boundaries" -buffer-deallocation-pipeline -cse -canonicalize -convert-vector-to-scf -test-lower-to-llvm -o %t
4// DEFINE: %{entry_point} = mmt4d
5// DEFINE: %{run} = mlir-runner %t -e %{entry_point} -entry-point-result=void \
6// DEFINE:    -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils
7
8// RUN: %{compile}
9
10// RUN: %{run} | FileCheck %s
11
func.func @mmt4d() {
  // Destination tensors for the two operands. The accumulator needs no
  // tensor.empty of its own: it is initialised directly from the constant
  // %C_in below.
  %A_alloc = tensor.empty() : tensor<2x2x3x1xi32>
  %B_alloc = tensor.empty() : tensor<2x2x3x1xi32>

  // Initial accumulator, viewed as a 2x2 grid of 3x3 tiles.
  %C_in = arith.constant dense<[
    [[[ 1, 2, 3],
     [ 4, 5, 6],
     [ 7, 8, 9]],
    [[ 11, 12, 13],
     [ 14, 15, 16],
     [ 17, 18, 19]]],
    [[[ 21, 22, 23],
     [ 24, 25, 26],
     [ 27, 28, 29]],
    [[ 31, 32, 33],
     [ 34, 35, 36],
     [ 37, 38, 39]]]
  ]> : tensor<2x2x3x3xi32>

  // Initialise the operands: every element of A is 3, every element of B is 4.
  %three = arith.constant 3 : i32
  %four = arith.constant 4 : i32
  %A = linalg.fill ins(%three : i32) outs(%A_alloc : tensor<2x2x3x1xi32>) -> tensor<2x2x3x1xi32>
  %B = linalg.fill ins(%four : i32) outs(%B_alloc : tensor<2x2x3x1xi32>) -> tensor<2x2x3x1xi32>

  // Matmul, accumulating into %C_in:
  //   C[m, n, m0, n0] += A[m, k, m0, k0] * B[n, k, n0, k0]
  // Here K = 2 and K0 = 1, so every output element equals the matching
  // C_in element plus 2 * 1 * (3 * 4) = 24.
  %C_out = linalg.mmt4d ins(%A, %B: tensor<2x2x3x1xi32>, tensor<2x2x3x1xi32>) outs(%C_in: tensor<2x2x3x3xi32>) -> tensor<2x2x3x3xi32>

  // Print and verify the output: each 3x3 tile is the matching C_in tile + 24.
  // CHECK:  Unranked Memref {{.*}} rank = 4 offset = 0 sizes = [2, 2, 3, 3] strides = [18, 9, 3, 1] data =
  // C[0, 0]
  // CHECK-NEXT: [25,  26, 27]
  // CHECK-NEXT: [28,  29, 30]
  // CHECK-NEXT: [31,  32, 33]
  // C[0, 1]
  // CHECK-NEXT: [35,  36, 37]
  // CHECK-NEXT: [38,  39, 40]
  // CHECK-NEXT: [41,  42, 43]
  // C[1, 0]
  // CHECK-NEXT: [45,  46, 47]
  // CHECK-NEXT: [48,  49, 50]
  // CHECK-NEXT: [51,  52, 53]
  // C[1, 1]
  // CHECK-NEXT: [55,  56, 57]
  // CHECK-NEXT: [58,  59, 60]
  // CHECK-NEXT: [61,  62, 63]

  // The runtime printer takes an unranked argument.
  %C_unranked = tensor.cast %C_out : tensor<2x2x3x3xi32> to tensor<*xi32>
  call @printMemrefI32(%C_unranked) : (tensor<*xi32>) -> ()

  return
}
65
module @transforms attributes { transform.with_named_sequence } {
  // Transform schedule applied to the payload by -transform-interpreter.
  // It tiles the matched linalg.mmt4d, vectorizes the tiles and then
  // simplifies the resulting vector code into a matmul-like vector.contract.
  transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) {
    %mmt4d_op = transform.collect_matching @match_mmt4d in %root
      : (!transform.any_op) -> (!transform.any_op)
    %payload_func = transform.get_parent_op %mmt4d_op {isolated_from_above}
      : (!transform.any_op) -> !transform.op<"func.func">

    // ---- Step 1: tiling ------------------------------------------------
    // The six mmt4d iteration dims are, in order:
    //   parallel, parallel, reduction, parallel, parallel, reduction.
    // First tile only the four parallel dims (sizes 1, 1, 3, 3); the zero
    // entries leave the reduction dims untouched.
    %par_tiled, %par_loops:4 = transform.structured.tile_using_for %mmt4d_op tile_sizes [1, 1, 0, 3, 3, 0]
      : (!transform.any_op)
      -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
    // Then tile the two reduction dims with unit tiles.
    %red_tiled, %red_loops:2 = transform.structured.tile_using_for %par_tiled tile_sizes [0, 0, 1, 0, 0, 1]
      : (!transform.any_op)
      -> (!transform.any_op, !transform.any_op, !transform.any_op)

    // ---- Step 2: vectorization -----------------------------------------
    transform.structured.vectorize %red_tiled : !transform.any_op

    // ---- Step 3: simplification ----------------------------------------
    // Turn vector.multi_reduction into a 6-dim vector.contract whose dims
    // mirror those of the original mmt4d op (same parallel/reduction split
    // as above). The transfer permutation patterns reduce the rank of the
    // transfer ops, making the contraction more matmul-like and enabling
    // its lowering to outer-product ops.
    transform.apply_patterns to %payload_func {
      transform.apply_patterns.vector.reduction_to_contract
      transform.apply_patterns.vector.transfer_permutation_patterns
    } : !transform.op<"func.func">

    // Hoisting and LICM. Not strictly required for correctness.
    %hoisted_func = transform.structured.hoist_redundant_vector_transfers %payload_func
      : (!transform.op<"func.func">) -> !transform.op<"func.func">
    %loop_ops = transform.structured.match interface{LoopLikeInterface} in %hoisted_func
      : (!transform.op<"func.func">) -> !transform.any_op
    transform.apply_licm to %loop_ops : !transform.any_op
    transform.loop.hoist_loop_invariant_subsets %loop_ops : !transform.any_op

    // Simplify the 6-dim vector.contract into a 3-dim, matmul-like
    // vector.contract with iterator kinds: parallel, parallel, reduction.
    transform.apply_patterns to %hoisted_func {
      transform.apply_patterns.vector.reduction_to_contract
      transform.apply_patterns.vector.cast_away_vector_leading_one_dim
      transform.apply_patterns.canonicalization
    } : !transform.op<"func.func">
    transform.yield
  }

  // Matcher used by collect_matching above: accepts a payload op only if its
  // name is linalg.mmt4d and yields it back unchanged.
  transform.named_sequence @match_mmt4d(
      %candidate: !transform.any_op {transform.readonly}) -> !transform.any_op {
    transform.match.operation_name %candidate ["linalg.mmt4d"] : !transform.any_op
    transform.yield %candidate : !transform.any_op
  }
}
120
// External printer from the runner utility libraries (see -shared-libs);
// it receives the result as an unranked tensor.
func.func private @printMemrefI32(tensor<*xi32>)
122