// RUN: mlir-opt %s \
// RUN:   -transform-interpreter -test-transform-dialect-erase-schedule \
// RUN:   -lower-vector-mask \
// RUN:   -one-shot-bufferize="bufferize-function-boundaries" \
// RUN:   -test-lower-to-arm-sme -test-lower-to-llvm | \
// RUN: %mcr_aarch64_cmd \
// RUN:   -e=entry -entry-point-result=void \
// RUN:   -march=aarch64 -mattr="+sve,+sme" \
// RUN:   -shared-libs=%native_mlir_runner_utils,%native_mlir_c_runner_utils,%native_arm_sme_abi_shlib | \
// RUN: FileCheck %s
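//
// The pipeline above vectorizes the linalg.fill below using the transform
// script at the end of this file, bufferizes the tensors, lowers to ArmSME
// and then to LLVM, runs the compiled @entry function, and verifies the
// printed output with FileCheck.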

func.func @entry() {
  %c0 = arith.constant 0 : index
  %c4 = arith.constant 4 : index
  %step = arith.constant 1 : index

  %c123_f32 = arith.constant 123.0 : f32

  // "svl" refers to the Streaming Vector Length and "svl_s" to the number of
  // 32-bit elements in a vector of SVL bits.
  %svl_s = arm_sme.streaming_vl <word>
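  // For example, with the minimum SVL of 128 bits %svl_s is 4; with an SVL of
  // 256 bits it would be 8.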

  %tile_init = bufferization.alloc_tensor(%svl_s, %svl_s) : tensor<?x?xf32>

  // Initialize tile with "123.0".
  // TODO: this could be simplified to tensor.splat + tensor.insert_slice once
  // splat supports dynamically shaped tensors.
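  // The nested scf.for loops thread the partially updated tensor through
  // their iter_args, so each tensor.insert builds on the previous value.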
  %tile_0 = scf.for %i = %c0 to %svl_s step %step iter_args(%tile_partial = %tile_init) -> tensor<?x?xf32> {
    %inner_tile = scf.for %j = %c0 to %svl_s step %step iter_args(%inner_tile_partial = %tile_partial) -> tensor<?x?xf32> {
      %tile_update = tensor.insert %c123_f32 into %inner_tile_partial[%i, %j] : tensor<?x?xf32>
      scf.yield %tile_update : tensor<?x?xf32>
    }
    scf.yield %inner_tile : tensor<?x?xf32>
  }

  // Print the tile after initialization. The smallest SVL is 128 bits, so the
  // tile will be at least 4x4xf32.
  //
  // CHECK:      ( 123, 123, 123, 123
  // CHECK-NEXT: ( 123, 123, 123, 123
  // CHECK-NEXT: ( 123, 123, 123, 123
  // CHECK-NEXT: ( 123, 123, 123, 123
  scf.for %i = %c0 to %svl_s step %step {
    vector.print punctuation <open>
    scf.for %j = %c0 to %svl_s step %step {
      %element = tensor.extract %tile_0[%i, %j] : tensor<?x?xf32>
      vector.print %element : f32 punctuation <no_punctuation>

      // Print a comma unless this is the last element.
      %c1_index = arith.constant 1 : index
      %last_i = arith.subi %svl_s, %c1_index : index
      %isNotLastIter = arith.cmpi ult, %j, %last_i : index
      scf.if %isNotLastIter {
        vector.print punctuation <comma>
      }
    }
    vector.print punctuation <close>
    vector.print punctuation <newline>
  }

  // Fill tile with pi.
  %pi = arith.constant 3.14 : f32
  %tile_1 = linalg.fill ins(%pi : f32) outs(%tile_0 : tensor<?x?xf32>) -> tensor<?x?xf32>
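  // This is the op matched by the transform script at the end of this file;
  // it is vectorized with scalable sizes so that it can lower to an SME tile
  // operation.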

  // Print the tile after filling with pi. The smallest SVL is 128 bits, so
  // the tile will be at least 4x4xf32.
  //
  // CHECK:      ( 3.14, 3.14, 3.14, 3.14
  // CHECK-NEXT: ( 3.14, 3.14, 3.14, 3.14
  // CHECK-NEXT: ( 3.14, 3.14, 3.14, 3.14
  // CHECK-NEXT: ( 3.14, 3.14, 3.14, 3.14
  scf.for %i = %c0 to %svl_s step %step {
    vector.print punctuation <open>
    scf.for %j = %c0 to %svl_s step %step {
      %element = tensor.extract %tile_1[%i, %j] : tensor<?x?xf32>
      vector.print %element : f32 punctuation <no_punctuation>

      // Print a comma unless this is the last element.
      %c1_index = arith.constant 1 : index
      %last_i = arith.subi %svl_s, %c1_index : index
      %isNotLastIter = arith.cmpi ult, %j, %last_i : index
      scf.if %isNotLastIter {
        vector.print punctuation <comma>
      }
    }
    vector.print punctuation <close>
    vector.print punctuation <newline>
  }

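  // A final marker makes it easy to confirm the program ran to completion
  // and produced all of the output above.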
  // CHECK: SME: END OF TEST OUTPUT
  vector.print str "SME: END OF TEST OUTPUT\n"

  return
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["linalg.fill"]} in %arg1 : (!transform.any_op) -> !transform.any_op
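    // vector_sizes [[4], [4]] requests scalable sizes in both dimensions,
    // i.e. vectorization to vector<[4]x[4]xf32>, the shape of a 32-bit SME
    // tile.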
    transform.structured.vectorize %0 vector_sizes [[4], [4]] : !transform.any_op
    transform.yield
  }
}