// xref: /llvm-project/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td (revision b3ce6dc7232c566c21b84ac5d5795341a355ff79)
//===- SCFOps.td - Structured Control Flow operations ------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Defines MLIR structured control flow operations.
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_SCF_SCFOPS
#define MLIR_DIALECT_SCF_SCFOPS

include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/Interfaces/LoopLikeInterface.td"
include "mlir/IR/RegionKindInterface.td"
include "mlir/Dialect/SCF/IR/DeviceMappingInterface.td"
include "mlir/Interfaces/DestinationStyleOpInterface.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/ParallelCombiningOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Interfaces/ViewLikeInterface.td"
25
def SCF_Dialect : Dialect {
  let name = "scf";
  let cppNamespace = "::mlir::scf";

  let description = [{
    The `scf` (structured control flow) dialect contains operations that
    represent control flow constructs such as `if` and `for`. Being
    _structured_ means that the control flow has a structure unlike, for
    example, `goto`s or `assert`s. Unstructured control flow operations are
    located in the `cf` (control flow) dialect.

    Originally, this dialect was developed as a common lowering stage for the
    `affine` and `linalg` dialects. Both convert to SCF loops instead of
    targeting branch-based CFGs directly. Typically, `scf` is lowered to `cf`
    and then lowered to some final target like LLVM or SPIR-V.
  }];

  // SCF ops create/fold arith ops (e.g. for loop bounds and steps), so the
  // arith dialect must be loaded whenever scf is.
  let dependentDialects = ["arith::ArithDialect"];
}
45
// Base class for SCF dialect ops: binds each op defined below to SCF_Dialect.
class SCF_Op<string mnemonic, list<Trait> traits = []> :
    Op<SCF_Dialect, mnemonic, traits>;
49
//===----------------------------------------------------------------------===//
// ConditionOp
//===----------------------------------------------------------------------===//

def ConditionOp : SCF_Op<"condition", [
  HasParent<"WhileOp">,
  DeclareOpInterfaceMethods<RegionBranchTerminatorOpInterface,
    ["getSuccessorRegions"]>,
  Pure,
  Terminator
]> {
  let summary = "loop continuation condition";
  let description = [{
    This operation accepts the continuation (i.e., inverse of exit) condition
    of the `scf.while` construct. If its first argument is true, the "after"
    region of `scf.while` is executed, with the remaining arguments forwarded
    to the entry block of the region. Otherwise, the loop terminates.
  }];

  // $args are forwarded to the "after" region's entry block when $condition
  // is true (see description above).
  let arguments = (ins I1:$condition, Variadic<AnyType>:$args);

  let assemblyFormat =
      [{ `(` $condition `)` attr-dict ($args^ `:` type($args))? }];
}
74
//===----------------------------------------------------------------------===//
// ExecuteRegionOp
//===----------------------------------------------------------------------===//

def ExecuteRegionOp : SCF_Op<"execute_region", [
    DeclareOpInterfaceMethods<RegionBranchOpInterface>]> {
  let summary = "operation that executes its region exactly once";
  let description = [{
    The `scf.execute_region` operation is used to allow multiple blocks within SCF
    and other operations which can hold only one block.  The `scf.execute_region`
    operation executes the region held exactly once and cannot have any operands.
    As such, its region has no arguments. All SSA values that dominate the op can
    be accessed inside the op. The op's region can have multiple blocks and the
    blocks can have multiple distinct terminators. Values returned from this op's
    region define the op's results.

    Example:

    ```mlir
    scf.for %i = 0 to 128 step %c1 {
      %y = scf.execute_region -> i32 {
        %x = load %A[%i] : memref<128xi32>
        scf.yield %x : i32
      }
    }

    affine.for %i = 0 to 100 {
      "foo"() : () -> ()
      %v = scf.execute_region -> i64 {
        cf.cond_br %cond, ^bb1, ^bb2

      ^bb1:
        %c1 = arith.constant 1 : i64
        cf.br ^bb3(%c1 : i64)

      ^bb2:
        %c2 = arith.constant 2 : i64
        cf.br ^bb3(%c2 : i64)

      ^bb3(%x : i64):
        scf.yield %x : i64
      }
      "bar"(%v) : (i64) -> ()
    }
    ```
  }];

  // Results are defined by the values yielded from the region; the op itself
  // takes no operands.
  let results = (outs Variadic<AnyType>);

  let regions = (region AnyRegion:$region);

  let hasCanonicalizer = 1;
  let hasCustomAssemblyFormat = 1;

  let hasVerifier = 1;
}
131
//===----------------------------------------------------------------------===//
// ForOp
//===----------------------------------------------------------------------===//

def ForOp : SCF_Op<"for",
      [AutomaticAllocationScope, DeclareOpInterfaceMethods<LoopLikeOpInterface,
       ["getInitsMutable", "getLoopResults", "getRegionIterArgs",
        "getLoopInductionVars", "getLoopLowerBounds", "getLoopSteps",
        "getLoopUpperBounds", "getYieldedValuesMutable",
        "promoteIfSingleIteration", "replaceWithAdditionalYields",
        "yieldTiledValuesAndReplace"]>,
       AllTypesMatch<["lowerBound", "upperBound", "step"]>,
       ConditionallySpeculatable,
       DeclareOpInterfaceMethods<RegionBranchOpInterface,
        ["getEntrySuccessorOperands"]>,
       SingleBlockImplicitTerminator<"scf::YieldOp">,
       RecursiveMemoryEffects]> {
  let summary = "for operation";
  let description = [{
    The `scf.for` operation represents a loop taking 3 SSA values as operands
    that represent the lower bound, upper bound and step respectively. The
    operation defines an SSA value for its induction variable. It has one
    region capturing the loop body. The induction variable is represented as an
    argument of this region. This SSA value is a signless integer or index.
    The step is a value of the same type but required to be positive, the lower
    and upper bounds can be also negative or zero. The lower and upper bounds
    specify a half-open range: the iteration is executed iff the signed
    comparison of the induction variable value is less than the upper bound and
    bigger or equal to the lower bound.

    The body region must contain exactly one block that terminates with
    `scf.yield`. Calling ForOp::build will create such a region and insert
    the terminator implicitly if none is defined, so will the parsing even in
    cases when it is absent from the custom format. For example:

    ```mlir
    // Index case.
    scf.for %iv = %lb to %ub step %step {
      ... // body
    }
    ...
    // Integer case.
    scf.for %iv_32 = %lb_32 to %ub_32 step %step_32 : i32 {
      ... // body
    }
    ```

    `scf.for` can also operate on loop-carried variables and returns the final
    values after loop termination. The initial values of the variables are
    passed as additional SSA operands to the `scf.for` following the 3 loop
    control SSA values mentioned above (lower bound, upper bound and step). The
    operation region has an argument for the induction variable, followed by
    one argument for each loop-carried variable, representing the value of the
    variable at the current iteration.

    The region must terminate with a `scf.yield` that passes the current
    values of all loop-carried variables to the next iteration, or to the
    `scf.for` result, if at the last iteration. The static type of a
    loop-carried variable may not change with iterations; its runtime type is
    allowed to change. Note, that when the loop-carried variables are present,
    calling ForOp::build will not insert the terminator implicitly. The caller
    must insert `scf.yield` in that case.

    `scf.for` results hold the final values after the last iteration.
    For example, to sum-reduce a memref:

    ```mlir
    func.func @reduce(%buffer: memref<1024xf32>, %lb: index,
                      %ub: index, %step: index) -> (f32) {
      // Initial sum set to 0.
      %sum_0 = arith.constant 0.0 : f32
      // iter_args binds initial values to the loop's region arguments.
      %sum = scf.for %iv = %lb to %ub step %step
          iter_args(%sum_iter = %sum_0) -> (f32) {
        %t = load %buffer[%iv] : memref<1024xf32>
        %sum_next = arith.addf %sum_iter, %t : f32
        // Yield current iteration sum to next iteration %sum_iter or to %sum
        // if final iteration.
        scf.yield %sum_next : f32
      }
      return %sum : f32
    }
    ```

    If the `scf.for` defines any values, a yield must be explicitly present.
    The number and types of the `scf.for` results must match the initial
    values in the `iter_args` binding and the yield operands.

    Another example with a nested `scf.if` (see `scf.if` for details) to
    perform conditional reduction:

    ```mlir
    func.func @conditional_reduce(%buffer: memref<1024xf32>, %lb: index,
                                  %ub: index, %step: index) -> (f32) {
      %sum_0 = arith.constant 0.0 : f32
      %c0 = arith.constant 0.0 : f32
      %sum = scf.for %iv = %lb to %ub step %step
          iter_args(%sum_iter = %sum_0) -> (f32) {
        %t = load %buffer[%iv] : memref<1024xf32>
        %cond = arith.cmpf "ugt", %t, %c0 : f32
        %sum_next = scf.if %cond -> (f32) {
          %new_sum = arith.addf %sum_iter, %t : f32
          scf.yield %new_sum : f32
        } else {
          scf.yield %sum_iter : f32
        }
        scf.yield %sum_next : f32
      }
      return %sum : f32
    }
    ```
  }];
  let arguments = (ins AnySignlessIntegerOrIndex:$lowerBound,
                       AnySignlessIntegerOrIndex:$upperBound,
                       AnySignlessIntegerOrIndex:$step,
                       Variadic<AnyType>:$initArgs);
  let results = (outs Variadic<AnyType>:$results);
  let regions = (region SizedRegion<1>:$region);

  let skipDefaultBuilders = 1;
  let builders = [
    OpBuilder<(ins "Value":$lowerBound, "Value":$upperBound, "Value":$step,
      CArg<"ValueRange", "std::nullopt">:$initArgs,
      CArg<"function_ref<void(OpBuilder &, Location, Value, ValueRange)>",
           "nullptr">)>
  ];

  let extraClassDeclaration = [{
    /// Callback invoked by the builder to populate the loop body; receives
    /// the induction variable and the region iter args.
    using BodyBuilderFn =
        function_ref<void(OpBuilder &, Location, Value, ValueRange)>;

    /// Return the induction variable, i.e. the first body block argument.
    Value getInductionVar() { return getBody()->getArgument(0); }

    /// Return the `index`-th region iteration argument.
    BlockArgument getRegionIterArg(unsigned index) {
      assert(index < getNumRegionIterArgs() &&
        "expected an index less than the number of region iter args");
      return getBody()->getArguments().drop_front(getNumInductionVars())[index];
    }

    /// In-place replacement of the loop control operands (operands 0..2).
    void setLowerBound(Value bound) { getOperation()->setOperand(0, bound); }
    void setUpperBound(Value bound) { getOperation()->setOperand(1, bound); }
    void setStep(Value step) { getOperation()->setOperand(2, step); }

    /// Number of induction variables, always 1 for scf::ForOp.
    unsigned getNumInductionVars() { return 1; }
    /// Number of region arguments for loop-carried values.
    unsigned getNumRegionIterArgs() {
      return getBody()->getNumArguments() - getNumInductionVars();
    }
    /// Number of operands controlling the loop: lb, ub, step.
    unsigned getNumControlOperands() { return 3; }

    /// Returns the step as an `APInt` if it is constant.
    std::optional<APInt> getConstantStep();

    /// Interface method for ConditionallySpeculatable.
    Speculation::Speculatability getSpeculatability();
  }];

  let hasCanonicalizer = 1;
  let hasCustomAssemblyFormat = 1;
  let hasVerifier = 1;
  let hasRegionVerifier = 1;
}
296
//===----------------------------------------------------------------------===//
// ForallOp
//===----------------------------------------------------------------------===//

def ForallOp : SCF_Op<"forall", [
       AttrSizedOperandSegments,
       AutomaticAllocationScope,
       DeclareOpInterfaceMethods<LoopLikeOpInterface,
          ["getInitsMutable", "getRegionIterArgs", "getLoopInductionVars",
           "getLoopLowerBounds", "getLoopUpperBounds", "getLoopSteps",
           "promoteIfSingleIteration", "yieldTiledValuesAndReplace"]>,
       RecursiveMemoryEffects,
       SingleBlockImplicitTerminator<"scf::InParallelOp">,
       DeclareOpInterfaceMethods<RegionBranchOpInterface>,
       DestinationStyleOpInterface,
       HasParallelRegion
     ]> {
  let summary = "evaluate a block multiple times in parallel";
  let description = [{
    `scf.forall` is a target-independent multi-dimensional parallel
    region application operation. It has exactly one block that represents the
    parallel body and it takes index operands that specify lower bounds, upper
    bounds and steps.

    The op also takes a variadic number of tensor operands (`shared_outs`).
    The future buffers corresponding to these tensors are shared among all
    threads. Shared tensors should be accessed via their corresponding block
    arguments. If multiple threads write to a shared buffer in a racy
    fashion, these writes will execute in some unspecified order. Tensors that
    are not shared can be used inside the body (i.e., the op is not isolated
    from above); however, if a use of such a tensor bufferizes to a memory
    write, the tensor is privatized, i.e., a thread-local copy of the tensor is
    used. This ensures that memory side effects of a thread are not visible to
    other threads (or in the parent body), apart from explicitly shared tensors.

    The name "thread" conveys the fact that the parallel execution is mapped
    (i.e. distributed) to a set of virtual threads of execution, one function
    application per thread. Further lowerings are responsible for specifying
    how this is materialized on concrete hardware resources.

    An optional `mapping` is an attribute array that specifies processing units
    with their dimension, how it remaps 1-1 to a set of concrete processing
    element resources (e.g. a CUDA grid dimension or a level of concrete nested
    async parallelism). It is expressed via any attribute that implements the
    device mapping interface. It is the responsibility of the lowering mechanism
    to interpret the `mapping` attributes in the context of the concrete target
    the op is lowered to, or to ignore it when the specification is ill-formed
    or unsupported for a particular target.

    The only allowed terminator is `scf.forall.in_parallel`.
    `scf.forall` returns one value per `shared_out` operand. The
    actions of the `scf.forall.in_parallel` terminators specify how to combine the
    partial results of all parallel invocations into a full value, in some
    unspecified order. The "destination" of each such op must be a `shared_out`
    block argument of the `scf.forall` op.

    The actions involved in constructing the return values are further described
    by `tensor.parallel_insert_slice`.

    `scf.forall` acts as an implicit synchronization point.

    When the parallel function body has side effects, their order is unspecified
    across threads.

    `scf.forall` can be printed in two different ways depending on
    whether the loop is normalized or not. The loop is 'normalized' when all
    lower bounds are equal to zero and steps are equal to one. In that case,
    `lowerBound` and `step` operands will be omitted during printing.

    Normalized loop example:

    ```mlir
    //
    // Sequential context.
    //
    %matmul_and_pointwise:2 = scf.forall (%thread_id_1, %thread_id_2) in
        (%num_threads_1, %num_threads_2) shared_outs(%o1 = %C, %o2 = %pointwise)
      -> (tensor<?x?xT>, tensor<?xT>) {
      //
      // Parallel context, each thread with id = (%thread_id_1, %thread_id_2)
      // runs its version of the code.
      //
      %sA = tensor.extract_slice %A[f((%thread_id_1, %thread_id_2))]:
        tensor<?x?xT> to tensor<?x?xT>
      %sB = tensor.extract_slice %B[g((%thread_id_1, %thread_id_2))]:
        tensor<?x?xT> to tensor<?x?xT>
      %sC = tensor.extract_slice %o1[h((%thread_id_1, %thread_id_2))]:
        tensor<?x?xT> to tensor<?x?xT>
      %sD = linalg.matmul
        ins(%sA, %sB : tensor<?x?xT>, tensor<?x?xT>)
        outs(%sC : tensor<?x?xT>)

      %spointwise = subtensor %o2[i((%thread_id_1, %thread_id_2))]:
        tensor<?xT> to tensor<?xT>
      %sE = linalg.add ins(%spointwise : tensor<?xT>) outs(%sD : tensor<?xT>)

      scf.forall.in_parallel {
        tensor.parallel_insert_slice %sD into %o1[h((%thread_id_1, %thread_id_2))]:
          tensor<?x?xT> into tensor<?x?xT>

        tensor.parallel_insert_slice %spointwise into %o2[i((%thread_id_1, %thread_id_2))]:
          tensor<?xT> into tensor<?xT>
      }
    }
    // Implicit synchronization point.
    // Sequential context.
    //
    ```

    Loop with loop bounds example:

    ```mlir
    //
    // Sequential context.
    //
    %pointwise = scf.forall (%i, %j) = (0, 0) to (%dim1, %dim2)
      step (%tileSize1, %tileSize2) shared_outs(%o1 = %out)
      -> (tensor<?x?xT>) {
      //
      // Parallel context.
      //
      %sA = tensor.extract_slice %A[%i, %j][%tileSize1, %tileSize2][1, 1]
        : tensor<?x?xT> to tensor<?x?xT>
      %sB = tensor.extract_slice %B[%i, %j][%tileSize1, %tileSize2][1, 1]
        : tensor<?x?xT> to tensor<?x?xT>
      %sC = tensor.extract_slice %o1[%i, %j][%tileSize1, %tileSize2][1, 1]
        : tensor<?x?xT> to tensor<?x?xT>

      %add = linalg.map {"arith.addf"}
        ins(%sA, %sB : tensor<?x?xT>, tensor<?x?xT>)
        outs(%sC : tensor<?x?xT>)

      scf.forall.in_parallel {
        tensor.parallel_insert_slice %add into
          %o1[%i, %j][%tileSize1, %tileSize2][1, 1]
          : tensor<?x?xT> into tensor<?x?xT>
      }
    }
    // Implicit synchronization point.
    // Sequential context.
    //
    ```

    Example with mapping attribute:

    ```mlir
    //
    // Sequential context. Here `mapping` is expressed as GPU thread mapping
    // attributes
    //
    %matmul_and_pointwise:2 = scf.forall (%thread_id_1, %thread_id_2) in
        (%num_threads_1, %num_threads_2) shared_outs(...)
      -> (tensor<?x?xT>, tensor<?xT>) {
      //
      // Parallel context, each thread with id = **(%thread_id_2, %thread_id_1)**
      // runs its version of the code.
      //
       scf.forall.in_parallel {
         ...
      }
    } { mapping = [#gpu.thread<y>, #gpu.thread<x>] }
    // Implicit synchronization point.
    // Sequential context.
    //
    ```

    Example with privatized tensors:

    ```mlir
    %t0 = ...
    %t1 = ...
    %r = scf.forall ... shared_outs(%o = %t0) -> tensor<?xf32> {
      // %t0 and %t1 are privatized. %t0 is definitely copied for each thread
      // because the scf.forall op's %t0 use bufferizes to a memory
      // write. In the absence of other conflicts, %t1 is copied only if there
      // are uses of %t1 in the body that bufferize to a memory read and to a
      // memory write.
      "some_use"(%t0)
      "some_use"(%t1)
    }
    ```
  }];
  // Bounds/steps are kept in mixed static/dynamic form: static values live in
  // the DenseI64ArrayAttrs, dynamic ones in the Variadic<Index> operands.
  let arguments = (ins
    Variadic<Index>:$dynamicLowerBound,
    Variadic<Index>:$dynamicUpperBound,
    Variadic<Index>:$dynamicStep,
    DenseI64ArrayAttr:$staticLowerBound,
    DenseI64ArrayAttr:$staticUpperBound,
    DenseI64ArrayAttr:$staticStep,
    Variadic<AnyRankedTensor>:$outputs,
    OptionalAttr<DeviceMappingArrayAttr>:$mapping);

  let results = (outs Variadic<AnyType>:$results);
  let regions = (region SizedRegion<1>:$region);

  let hasCanonicalizer = 1;
  let hasCustomAssemblyFormat = 1;
  let hasVerifier = 1;

  // The default builder does not add the proper body BBargs, roll our own.
  let skipDefaultBuilders = 1;
  let builders = [
    // Builder that takes loop bounds.
    OpBuilder<(ins "ArrayRef<OpFoldResult>":$lbs,
       "ArrayRef<OpFoldResult>":$ubs, "ArrayRef<OpFoldResult>":$steps,
       "ValueRange":$outputs, "std::optional<ArrayAttr>":$mapping,
       CArg<"function_ref<void(OpBuilder &, Location, ValueRange)>",
            "nullptr"> :$bodyBuilderFn)>,

    // Builder for normalized loop that takes only upper bounds.
    OpBuilder<(ins "ArrayRef<OpFoldResult>":$ubs,
       "ValueRange":$outputs, "std::optional<ArrayAttr>":$mapping,
       CArg<"function_ref<void(OpBuilder &, Location, ValueRange)>",
            "nullptr"> :$bodyBuilderFn)>,
  ];

  let extraClassDeclaration = [{
    /// Get induction variables.
    SmallVector<Value> getInductionVars() {
      std::optional<SmallVector<Value>> maybeInductionVars = getLoopInductionVars();
      assert(maybeInductionVars.has_value() && "expected values");
      return *maybeInductionVars;
    }
    /// Get lower bounds as OpFoldResult.
    SmallVector<OpFoldResult> getMixedLowerBound() {
      std::optional<SmallVector<OpFoldResult>> maybeLowerBounds = getLoopLowerBounds();
      assert(maybeLowerBounds.has_value() && "expected values");
      return *maybeLowerBounds;
    }

    /// Get upper bounds as OpFoldResult.
    SmallVector<OpFoldResult> getMixedUpperBound() {
      std::optional<SmallVector<OpFoldResult>> maybeUpperBounds = getLoopUpperBounds();
      assert(maybeUpperBounds.has_value() && "expected values");
      return *maybeUpperBounds;
    }

    /// Get steps as OpFoldResult.
    SmallVector<OpFoldResult> getMixedStep() {
      std::optional<SmallVector<OpFoldResult>> maybeSteps = getLoopSteps();
      assert(maybeSteps.has_value() && "expected values");
      return *maybeSteps;
    }

    /// Get lower bounds as values, materializing constants if needed.
    SmallVector<Value> getLowerBound(OpBuilder &b) {
      return getValueOrCreateConstantIndexOp(b, getLoc(), getMixedLowerBound());
    }

    /// Get upper bounds as values, materializing constants if needed.
    SmallVector<Value> getUpperBound(OpBuilder &b) {
      return getValueOrCreateConstantIndexOp(b, getLoc(), getMixedUpperBound());
    }

    /// Get steps as values, materializing constants if needed.
    SmallVector<Value> getStep(OpBuilder &b) {
      return getValueOrCreateConstantIndexOp(b, getLoc(), getMixedStep());
    }

    /// Number of loop dimensions (one induction variable per dimension).
    int64_t getRank() { return getStaticLowerBound().size(); }

    /// Number of operands controlling the loop: lbs, ubs, steps
    unsigned getNumControlOperands() { return 3 * getRank(); }

    /// Number of dynamic operands controlling the loop: lbs, ubs, steps
    unsigned getNumDynamicControlOperands() {
      return getODSOperandIndexAndLength(3).first;
    }

    /// Return the result tied to the given shared_outs operand.
    OpResult getTiedOpResult(OpOperand *opOperand) {
      assert(opOperand->getOperandNumber() >= getNumDynamicControlOperands() &&
             "invalid operand");
      return getOperation()->getOpResult(
          opOperand->getOperandNumber() - getNumDynamicControlOperands());
    }

    /// Return the num_threads operand that is tied to the given thread id
    /// block argument.
    OpOperand *getTiedOpOperand(BlockArgument bbArg) {
      assert(bbArg.getArgNumber() >= getRank() && "invalid bbArg");

      return &getOperation()->getOpOperand(getNumDynamicControlOperands() +
                                           bbArg.getArgNumber() - getRank());
    }

    /// Return the shared_outs operand that is tied to the given OpResult.
    OpOperand *getTiedOpOperand(OpResult opResult) {
      assert(opResult.getDefiningOp() == getOperation() && "invalid OpResult");
      return &getOperation()->getOpOperand(getNumDynamicControlOperands() +
                                           opResult.getResultNumber());
    }

    /// Return the block argument tied to the given shared_outs operand.
    BlockArgument getTiedBlockArgument(OpOperand *opOperand) {
      assert(opOperand->getOperandNumber() >= getNumDynamicControlOperands() &&
             "invalid operand");

      return getBody()->getArgument(opOperand->getOperandNumber() -
                                    getNumDynamicControlOperands() + getRank());
    }

    /// Return the `idx`-th induction variable.
    ::mlir::Value getInductionVar(int64_t idx) {
      return getInductionVars()[idx];
    }

    /// Return the block arguments corresponding to the shared_outs (they
    /// follow the induction-variable arguments).
    ::mlir::Block::BlockArgListType getRegionOutArgs() {
      return getBody()->getArguments().drop_front(getRank());
    }

    /// Checks if the lbs are zeros and steps are ones.
    bool isNormalized();

    // The ensureTerminator method generated by SingleBlockImplicitTerminator is
    // unaware of the fact that our terminator also needs a region to be
    // well-formed. We override it here to ensure that we do the right thing.
    static void ensureTerminator(Region & region, OpBuilder & builder,
                                 Location loc);

    /// Return the terminating scf.forall.in_parallel op.
    InParallelOp getTerminator();

    // Declare the shared_outs as inits/outs to DestinationStyleOpInterface.
    MutableOperandRange getDpsInitsMutable() { return getOutputsMutable(); }

    /// Returns operations within scf.forall.in_parallel whose destination
    /// operand is the block argument `bbArg`.
    SmallVector<Operation*> getCombiningOps(BlockArgument bbArg);
  }];
}
624
//===----------------------------------------------------------------------===//
// InParallelOp
//===----------------------------------------------------------------------===//

def InParallelOp : SCF_Op<"forall.in_parallel", [
       Pure,
       Terminator,
       DeclareOpInterfaceMethods<ParallelCombiningOpInterface>,
       HasParent<"ForallOp">,
      ] # GraphRegionNoTerminator.traits> {
  let summary = "terminates a `forall` block";
  let description = [{
    The `scf.forall.in_parallel` is a designated terminator for
    the `scf.forall` operation.

    It has a single region with a single block that contains a flat list of ops.
    Each such op participates in the aggregate formation of a single result of
    the enclosing `scf.forall`.
    The result number corresponds to the position of the op in the terminator.
  }];

  let regions = (region SizedRegion<1>:$region);

  let hasCustomAssemblyFormat = 1;
  let hasVerifier = 1;

  // The default builder does not add a region with an empty body, add our own.
  let skipDefaultBuilders = 1;
  let builders = [
    OpBuilder<(ins)>,
  ];

  // TODO: Add a `InParallelOpInterface` interface for ops that can
  // appear inside in_parallel.
  let extraClassDeclaration = [{
    /// Return the destination block arguments of the contained combining ops.
    ::llvm::SmallVector<::mlir::BlockArgument> getDests();
    /// Iterate over the ops in the single body block.
    ::llvm::iterator_range<::mlir::Block::iterator> getYieldingOps();
    /// Return the parent scf.forall result at position `idx`.
    ::mlir::OpResult getParentResult(int64_t idx);
  }];
}
665
//===----------------------------------------------------------------------===//
// IfOp
//===----------------------------------------------------------------------===//

def IfOp : SCF_Op<"if", [DeclareOpInterfaceMethods<RegionBranchOpInterface, [
    "getNumRegionInvocations", "getRegionInvocationBounds",
    "getEntrySuccessorRegions"]>,
    InferTypeOpAdaptor, SingleBlockImplicitTerminator<"scf::YieldOp">,
    RecursiveMemoryEffects, RecursivelySpeculatable, NoRegionArguments]> {
  let summary = "if-then-else operation";
  let description = [{
    The `scf.if` operation represents an if-then-else construct for
    conditionally executing two regions of code. The operand to an if operation
    is a boolean value. For example:

    ```mlir
    scf.if %b  {
      ...
    } else {
      ...
    }
    ```

    `scf.if` may also produce results. Which values are returned depends on
    which execution path is taken.

    Example:

    ```mlir
    %x, %y = scf.if %b -> (f32, f32) {
      %x_true = ...
      %y_true = ...
      scf.yield %x_true, %y_true : f32, f32
    } else {
      %x_false = ...
      %y_false = ...
      scf.yield %x_false, %y_false : f32, f32
    }
    ```

    The "then" region has exactly 1 block. The "else" region may have 0 or 1
    block. In case the `scf.if` produces results, the "else" region must also
    have exactly 1 block.

    The blocks are always terminated with `scf.yield`. If `scf.if` defines no
    values, the `scf.yield` can be left out, and will be inserted implicitly.
    Otherwise, it must be explicit.

    Example:

    ```mlir
    scf.if %b  {
      ...
    }
    ```

    The types of the yielded values must match the result types of the
    `scf.if`.
  }];
  let arguments = (ins I1:$condition);
  let results = (outs Variadic<AnyType>:$results);
  let regions = (region SizedRegion<1>:$thenRegion,
                        MaxSizedRegion<1>:$elseRegion);

  let skipDefaultBuilders = 1;
  let builders = [
    OpBuilder<(ins "TypeRange":$resultTypes, "Value":$cond)>,
    OpBuilder<(ins "TypeRange":$resultTypes, "Value":$cond,
      "bool":$addThenBlock, "bool":$addElseBlock)>,
    OpBuilder<(ins "Value":$cond, "bool":$withElseRegion)>,
    OpBuilder<(ins "TypeRange":$resultTypes, "Value":$cond,
      "bool":$withElseRegion)>,
    OpBuilder<(ins "Value":$cond,
      CArg<"function_ref<void(OpBuilder &, Location)>",
           "buildTerminatedBody">:$thenBuilder,
      CArg<"function_ref<void(OpBuilder &, Location)>",
           "nullptr">:$elseBuilder)>,
  ];

  let extraClassDeclaration = [{
    /// Builder positioned inside the "then" block: before the terminator when
    /// the op has no results (the implicit yield must stay last), else at end.
    OpBuilder getThenBodyBuilder(OpBuilder::Listener *listener = nullptr) {
      Block* body = getBody(0);
      return getResults().empty() ? OpBuilder::atBlockTerminator(body, listener)
                                  : OpBuilder::atBlockEnd(body, listener);
    }
    /// Builder positioned inside the "else" block (same placement rule).
    OpBuilder getElseBodyBuilder(OpBuilder::Listener *listener = nullptr) {
      Block* body = getBody(1);
      return getResults().empty() ? OpBuilder::atBlockTerminator(body, listener)
                                  : OpBuilder::atBlockEnd(body, listener);
    }
    Block* thenBlock();
    YieldOp thenYield();
    Block* elseBlock();
    YieldOp elseYield();
  }];
  let hasFolder = 1;
  let hasCanonicalizer = 1;
  let hasCustomAssemblyFormat = 1;
  let hasVerifier = 1;
}
766
767//===----------------------------------------------------------------------===//
768// ParallelOp
769//===----------------------------------------------------------------------===//
770
def ParallelOp : SCF_Op<"parallel",
    [AutomaticAllocationScope,
     AttrSizedOperandSegments,
     DeclareOpInterfaceMethods<LoopLikeOpInterface, ["getLoopInductionVars",
          "getLoopLowerBounds", "getLoopUpperBounds", "getLoopSteps"]>,
     RecursiveMemoryEffects,
     DeclareOpInterfaceMethods<RegionBranchOpInterface>,
     SingleBlockImplicitTerminator<"scf::ReduceOp">,
     HasParallelRegion]> {
  let summary = "parallel for operation";
  let description = [{
    The `scf.parallel` operation represents a loop nest taking 4 groups of SSA
    values as operands that represent the lower bounds, upper bounds, steps and
    initial values, respectively. The operation defines a variadic number of
    SSA values for its induction variables. It has one region capturing the
    loop body. The induction variables are represented as an argument of this
    region. These SSA values always have type index, which is the size of the
    machine word. The steps are values of type index, required to be positive.
    The lower and upper bounds specify a half-open range: the range includes
    the lower bound but does not include the upper bound. The initial values
    have the same types as results of `scf.parallel`. If there are no results,
    the keyword `init` can be omitted.

    Semantically we require that the iteration space can be iterated in any
    order, and the loop body can be executed in parallel. If there are data
    races, the behavior is undefined.

    The parallel loop operation supports reduction of values produced by
    individual iterations into a single result. This is modeled using the
    `scf.reduce` terminator operation (see `scf.reduce` for details). The i-th
    result of an `scf.parallel` operation is associated with the i-th initial
    value operand, the i-th operand of the `scf.reduce` operation (the value to
    be reduced) and the i-th region of the `scf.reduce` operation (the reduction
    function). Consequently, we require that the number of results of an
    `scf.parallel` op matches the number of initial values and the number of
    reductions in the `scf.reduce` terminator.

    The body region must contain exactly one block that terminates with a
    `scf.reduce` operation. If an `scf.parallel` op has no reductions, the
    terminator has no operands and no regions. The `scf.parallel` parser will
    automatically insert the terminator for ops that have no reductions if it is
    absent.

    Example:

    ```mlir
    %init = arith.constant 0.0 : f32
    %r:2 = scf.parallel (%iv) = (%lb) to (%ub) step (%step) init (%init, %init)
        -> f32, f32 {
      %elem_to_reduce1 = load %buffer1[%iv] : memref<100xf32>
      %elem_to_reduce2 = load %buffer2[%iv] : memref<100xf32>
      scf.reduce(%elem_to_reduce1, %elem_to_reduce2 : f32, f32) {
        ^bb0(%lhs : f32, %rhs: f32):
          %res = arith.addf %lhs, %rhs : f32
          scf.reduce.return %res : f32
      }, {
        ^bb0(%lhs : f32, %rhs: f32):
          %res = arith.mulf %lhs, %rhs : f32
          scf.reduce.return %res : f32
      }
    }
    ```
  }];

  let arguments = (ins Variadic<Index>:$lowerBound,
                       Variadic<Index>:$upperBound,
                       Variadic<Index>:$step,
                       Variadic<AnyType>:$initVals);
  let results = (outs Variadic<AnyType>:$results);
  let regions = (region SizedRegion<1>:$region);

  let skipDefaultBuilders = 1;
  let builders = [
    OpBuilder<(ins "ValueRange":$lowerBounds, "ValueRange":$upperBounds,
      "ValueRange":$steps, "ValueRange":$initVals,
      CArg<"function_ref<void (OpBuilder &, Location, ValueRange, ValueRange)>",
           "nullptr">:$bodyBuilderFn)>,
    OpBuilder<(ins "ValueRange":$lowerBounds, "ValueRange":$upperBounds,
      "ValueRange":$steps,
      CArg<"function_ref<void (OpBuilder &, Location, ValueRange)>",
           "nullptr">:$bodyBuilderFn)>,
  ];

  let extraClassDeclaration = [{
    /// Get induction variables. The loop-like interface returns an optional;
    /// scf.parallel always has induction variables, so unwrap unconditionally.
    SmallVector<Value> getInductionVars() {
      std::optional<SmallVector<Value>> maybeInductionVars =
          getLoopInductionVars();
      assert(maybeInductionVars.has_value() && "expected values");
      return *maybeInductionVars;
    }
    /// Get the number of loop nest dimensions (one step per dimension).
    unsigned getNumLoops() { return getStep().size(); }
    /// Get the number of reductions (one per initial value).
    unsigned getNumReductions() { return getInitVals().size(); }
  }];

  let hasCanonicalizer = 1;
  let hasCustomAssemblyFormat = 1;
  let hasVerifier = 1;
}
869
870//===----------------------------------------------------------------------===//
871// ReduceOp
872//===----------------------------------------------------------------------===//
873
def ReduceOp : SCF_Op<"reduce", [
    Terminator, HasParent<"ParallelOp">, RecursiveMemoryEffects,
    DeclareOpInterfaceMethods<RegionBranchTerminatorOpInterface>]> {
  let summary = "reduce operation for scf.parallel";
  let description = [{
    The `scf.reduce` operation is the terminator for `scf.parallel` operations.
    It can model an arbitrary number of reductions. It has one region per
    reduction. Each region has one block with two arguments which have the same
    type as the corresponding operand of `scf.reduce`. The operands of the op
    are the values that should be reduced; one value per reduction.

    The i-th reduction (i.e., the i-th region and the i-th operand) corresponds
    to the i-th initial value and the i-th result of the enclosing
    `scf.parallel` op.

    The `scf.reduce` operation contains regions whose entry blocks expect two
    arguments of the same type as the corresponding operand. As the iteration
    order of the enclosing parallel loop and hence reduction order is
    unspecified, the results of the reductions may be non-deterministic unless
    the reductions are associative and commutative.

    The result of a reduction region (`scf.reduce.return` operand) must have the
    same type as the corresponding `scf.reduce` operand and the corresponding
    `scf.parallel` initial value.

    Example:

    ```mlir
    %operand = arith.constant 1.0 : f32
    scf.reduce(%operand : f32) {
      ^bb0(%lhs : f32, %rhs: f32):
        %res = arith.addf %lhs, %rhs : f32
        scf.reduce.return %res : f32
    }
    ```
  }];

  let skipDefaultBuilders = 1;
  let builders = [
    OpBuilder<(ins "ValueRange":$operands)>,
    OpBuilder<(ins)>
  ];

  let arguments = (ins Variadic<AnyType>:$operands);
  let assemblyFormat = [{
    (`(` $operands^ `:` type($operands) `)`)? $reductions attr-dict
  }];
  let regions = (region VariadicRegion<SizedRegion<1>>:$reductions);
  let hasRegionVerifier = 1;
}
924
925//===----------------------------------------------------------------------===//
926// ReduceReturnOp
927//===----------------------------------------------------------------------===//
928
def ReduceReturnOp :
    SCF_Op<"reduce.return", [HasParent<"ReduceOp">, Pure, Terminator]> {
  let summary = "terminator for reduce operation";
  let description = [{
    The `scf.reduce.return` operation terminates the single-block region of a
    reduction inside an `scf.reduce` op. Its operand carries the result of the
    reduction body and must have the same type as the corresponding operand of
    the enclosing `scf.reduce` op.

    Example:

    ```mlir
    scf.reduce.return %res : f32
    ```
  }];

  let arguments = (ins AnyType:$result);
  let assemblyFormat = "$result attr-dict `:` type($result)";
  let hasVerifier = 1;
}
948
949//===----------------------------------------------------------------------===//
950// WhileOp
951//===----------------------------------------------------------------------===//
952
def WhileOp : SCF_Op<"while",
    [DeclareOpInterfaceMethods<RegionBranchOpInterface,
        ["getEntrySuccessorOperands"]>,
     DeclareOpInterfaceMethods<LoopLikeOpInterface,
        ["getRegionIterArgs", "getYieldedValuesMutable"]>,
     RecursiveMemoryEffects, SingleBlock]> {
  let summary = "a generic 'while' loop";
  let description = [{
    This operation represents a generic "while"/"do-while" loop that keeps
    iterating as long as a condition is satisfied. There is no restriction on
    the complexity of the condition. It consists of two regions (with single
    block each): "before" region and "after" region. The names of the regions
    indicate whether they execute before or after the condition check.
    Therefore, if the main loop payload is located in the "before" region, the
    operation is a "do-while" loop. Otherwise, it is a "while" loop.

    The "before" region terminates with a special operation, `scf.condition`,
    that accepts as its first operand an `i1` value indicating whether to
    proceed to the "after" region (value is `true`) or not. The two regions
    communicate by means of region arguments. Initially, the "before" region
    accepts as arguments the operands of the `scf.while` operation and uses them
    to evaluate the condition. It forwards the trailing, non-condition operands
    of the `scf.condition` terminator either to the "after" region if the
    control flow is transferred there or to results of the `scf.while` operation
    otherwise. The "after" region takes as arguments the values produced by the
    "before" region and uses `scf.yield` to supply new arguments for the
    "before" region, into which it transfers the control flow unconditionally.

    A simple "while" loop can be represented as follows.

    ```mlir
    %res = scf.while (%arg1 = %init1) : (f32) -> f32 {
      // "Before" region.
      // In a "while" loop, this region computes the condition.
      %condition = call @evaluate_condition(%arg1) : (f32) -> i1

      // Forward the argument (as result or "after" region argument).
      scf.condition(%condition) %arg1 : f32

    } do {
    ^bb0(%arg2: f32):
      // "After" region.
      // In a "while" loop, this region is the loop body.
      %next = call @payload(%arg2) : (f32) -> f32

      // Forward the new value to the "before" region.
      // The operand types must match the types of the `scf.while` operands.
      scf.yield %next : f32
    }
    ```

    A simple "do-while" loop can be represented by reducing the "after" block
    to a simple forwarder.

    ```mlir
    %res = scf.while (%arg1 = %init1) : (f32) -> f32 {
      // "Before" region.
      // In a "do-while" loop, this region contains the loop body.
      %next = call @payload(%arg1) : (f32) -> f32

      // And also evaluates the condition.
      %condition = call @evaluate_condition(%arg1) : (f32) -> i1

      // Loop through the "after" region.
      scf.condition(%condition) %next : f32

    } do {
    ^bb0(%arg2: f32):
      // "After" region.
      // Forwards the values back to "before" region unmodified.
      scf.yield %arg2 : f32
    }
    ```

    Note that the types of region arguments need not match with each other.
    The op expects the operand types to match with argument types of the
    "before" region; the result types to match with the trailing operand types
    of the terminator of the "before" region, and with the argument types of the
    "after" region. The following scheme can be used to share the results of
    some operations executed in the "before" region with the "after" region,
    avoiding the need to recompute them.

    ```mlir
    %res = scf.while (%arg1 = %init1) : (f32) -> i64 {
      // One can perform some computations, e.g., necessary to evaluate the
      // condition, in the "before" region and forward their results to the
      // "after" region.
      %shared = call @shared_compute(%arg1) : (f32) -> i64

      // Evaluate the condition.
      %condition = call @evaluate_condition(%arg1, %shared) : (f32, i64) -> i1

      // Forward the result of the shared computation to the "after" region.
      // The types must match the arguments of the "after" region as well as
      // those of the `scf.while` results.
      scf.condition(%condition) %shared : i64

    } do {
    ^bb0(%arg2: i64):
      // Use the partial result to compute the rest of the payload in the
      // "after" region.
      %res = call @payload(%arg2) : (i64) -> f32

      // Forward the new value to the "before" region.
      // The operand types must match the types of the `scf.while` operands.
      scf.yield %res : f32
    }
    ```

    The custom syntax for this operation is as follows.

    ```
    op ::= `scf.while` assignments `:` function-type region `do` region
           `attributes` attribute-dict
    assignments ::= /* empty */ | `(` assignment-list `)`
    assignment-list ::= assignment | assignment `,` assignment-list
    assignment ::= ssa-value `=` ssa-value
    ```
  }];

  let arguments = (ins Variadic<AnyType>:$inits);
  let results = (outs Variadic<AnyType>:$results);
  let regions = (region SizedRegion<1>:$before, SizedRegion<1>:$after);

  let builders = [
    OpBuilder<(ins "TypeRange":$resultTypes, "ValueRange":$inits,
      "function_ref<void(OpBuilder &, Location, ValueRange)>":$beforeBuilder,
      "function_ref<void(OpBuilder &, Location, ValueRange)>":$afterBuilder)>
  ];

  let extraClassDeclaration = [{
    using BodyBuilderFn =
        function_ref<void(OpBuilder &, Location, ValueRange)>;

    ConditionOp getConditionOp();
    YieldOp getYieldOp();

    Block::BlockArgListType getBeforeArguments();
    Block::BlockArgListType getAfterArguments();
    Block *getBeforeBody() { return &getBefore().front(); }
    Block *getAfterBody() { return &getAfter().front(); }
  }];

  let hasCanonicalizer = 1;
  let hasCustomAssemblyFormat = 1;
  let hasVerifier = 1;
}
1100
1101//===----------------------------------------------------------------------===//
1102// IndexSwitchOp
1103//===----------------------------------------------------------------------===//
1104
def IndexSwitchOp : SCF_Op<"index_switch", [RecursiveMemoryEffects,
    SingleBlockImplicitTerminator<"scf::YieldOp">,
    DeclareOpInterfaceMethods<RegionBranchOpInterface,
                              ["getRegionInvocationBounds",
                               "getEntrySuccessorRegions"]>]> {
  let summary = "switch-case operation on an index argument";
  let description = [{
    The `scf.index_switch` is a control-flow operation that branches to one of
    the given regions based on the value of the argument and the cases. The
    argument is always of type `index`.

    The operation always has a "default" region and any number of case regions
    denoted by integer constants. Control-flow transfers to the case region
    whose constant value equals the value of the argument. If the argument does
    not equal any of the case values, control-flow transfers to the "default"
    region.

    Example:

    ```mlir
    %0 = scf.index_switch %arg0 : index -> i32
    case 2 {
      %1 = arith.constant 10 : i32
      scf.yield %1 : i32
    }
    case 5 {
      %2 = arith.constant 20 : i32
      scf.yield %2 : i32
    }
    default {
      %3 = arith.constant 30 : i32
      scf.yield %3 : i32
    }
    ```
  }];

  let arguments = (ins Index:$arg, DenseI64ArrayAttr:$cases);
  let results = (outs Variadic<AnyType>:$results);
  let regions = (region SizedRegion<1>:$defaultRegion,
                        VariadicRegion<SizedRegion<1>>:$caseRegions);

  let assemblyFormat = [{
    $arg attr-dict (`->` type($results)^)?
    custom<SwitchCases>($cases, $caseRegions) `\n`
    `` `default` $defaultRegion
  }];

  let extraClassDeclaration = [{
    /// Get the number of cases.
    unsigned getNumCases();

    /// Get the default region body.
    Block &getDefaultBlock();

    /// Get the body of a case region.
    Block &getCaseBlock(unsigned idx);
  }];

  let hasCanonicalizer = 1;
  let hasVerifier = 1;
}
1166
1167//===----------------------------------------------------------------------===//
1168// YieldOp
1169//===----------------------------------------------------------------------===//
1170
def YieldOp : SCF_Op<"yield", [Pure, ReturnLike, Terminator,
    ParentOneOf<["ExecuteRegionOp", "ForOp", "IfOp", "IndexSwitchOp",
                 "WhileOp"]>]> {
  let summary = "loop yield and termination operation";
  let description = [{
    The `scf.yield` operation yields an SSA value from the SCF dialect op
    region and terminates the region. The semantics of how the values are
    yielded is defined by the parent operation.
    If `scf.yield` has any operands, the operands must match the parent
    operation's results.
    If the parent operation defines no values, then the `scf.yield` may be
    left out in the custom syntax and the builders will insert one implicitly.
    Otherwise, it has to be present in the syntax to indicate which values are
    yielded.
  }];

  let arguments = (ins Variadic<AnyType>:$results);
  let builders = [OpBuilder<(ins), [{ /* nothing to do */ }]>];

  let assemblyFormat =
      [{  attr-dict ($results^ `:` type($results))? }];
}
1193
1194#endif // MLIR_DIALECT_SCF_SCFOPS
1195