'iree_vector_ext' Dialect

IREE Vector Extensions.

A dialect designed for experimenting with vector operations beyond what is currently available in the Vector Dialect.

Operations

iree_vector_ext.arg_compare (VectorExt::ArgCompareOp)

Vectorized arg-reduction using a user-defined comparator.

Syntax:

operation ::= `iree_vector_ext.arg_compare` attr-dict
              `dimension` `(` $dimension `)`
              `ins` `(` $input_value (`,` $input_index^)? `:` type($input_value) (`,` type($input_index)^)? `)`
              `inits` `(` $init_value `,` $init_index `:` type($init_value) `,` type($init_index) `)`
              (`index_base` `(` $index_base^ `:` type($index_base) `)`)?
              $region `->` type($result_value) `,` type($result_index)

The iree_vector_ext.arg_compare op is the vectorized form of iree_linalg_ext.arg_compare. It performs a reduction over a given dimension of a vector, returning both the selected value and its corresponding index.

Comparator region semantics:

The comparator region defines the predication logic that determines the selection rule for the reduction (e.g., "greater than" for argmax or "less than" for argmin). The region takes two scalar values of the input element type as arguments and returns a single i1 result via iree_vector_ext.yield.

The region is invoked during the reduction to determine which element to select: when the comparison yields true, the first argument is selected; otherwise, the second argument is selected.

The region must contain only pure operations (operations with the Pure trait). This ensures the comparator can be safely executed in any order during the reduction.

Example (implicit-index mode - argmax over dim 1):

%input_vec = vector<4x128xf32>
%out_val_vec = vector<4xf32>
%out_idx_vec = vector<4xi32>
%result:2 = iree_vector_ext.arg_compare
  dimension(1)
  ins(%input_vec : vector<4x128xf32>)
  inits(%out_val_vec, %out_idx_vec : vector<4xf32>, vector<4xi32>) {
^bb0(%a: f32, %b: f32):
  %cmp = arith.cmpf ogt, %a, %b : f32
  iree_vector_ext.yield %cmp : i1
} -> vector<4xf32>, vector<4xi32>

Example (explicit-index mode):

%partial_vals = vector<4x32xf32>
%partial_idxs = vector<4x32xi32>
%out_val = vector<4xf32>
%out_idx = vector<4xi32>
%result:2 = iree_vector_ext.arg_compare
  dimension(1)
  ins(%partial_vals, %partial_idxs : vector<4x32xf32>, vector<4x32xi32>)
  inits(%out_val, %out_idx : vector<4xf32>, vector<4xi32>) {
^bb0(%a: f32, %b: f32):
  %cmp = arith.cmpf ogt, %a, %b : f32
  iree_vector_ext.yield %cmp : i1
} -> vector<4xf32>, vector<4xi32>

The index_base is an optional offset value that, when specified, is added to the computed indices in the result. This is useful when reducing over a sliced subregion where the indices need to be adjusted to reflect their position in the original vector. The index_base can only be used in implicit-index mode (single input).

The inits operands provide the initial accumulator values for the reduction (initial max/min value and index).
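
The reduction described above can be modeled in plain Python for the rank-2, dimension(1) case. This is an illustrative sketch only, not IREE code: the names arg_compare_dim1 and argmax_cmp are hypothetical. It assumes that the comparator's first argument is the candidate element and its second argument is the running accumulator, and that elements are visited in increasing index order. The text above specifies neither.

```python
from typing import Callable, List, Optional, Sequence, Tuple


def arg_compare_dim1(
    values: Sequence[Sequence[float]],
    init_values: Sequence[float],
    init_indices: Sequence[int],
    comparator: Callable[[float, float], bool],
    indices: Optional[Sequence[Sequence[int]]] = None,
    index_base: int = 0,
) -> Tuple[List[float], List[int]]:
    """Hypothetical reference model of arg_compare reducing over dimension 1."""
    if indices is not None and index_base != 0:
        # Mirrors the rule that index_base is only valid in implicit-index mode.
        raise ValueError("index_base requires implicit-index mode")
    out_values = list(init_values)
    out_indices = list(init_indices)
    for row, row_values in enumerate(values):
        for col, candidate in enumerate(row_values):
            # Assumption: the candidate is the comparator's first argument.
            if comparator(candidate, out_values[row]):
                out_values[row] = candidate
                if indices is None:
                    # Implicit-index mode: position along the reduced dimension.
                    out_indices[row] = index_base + col
                else:
                    # Explicit-index mode: carry the supplied index through.
                    out_indices[row] = indices[row][col]
    return out_values, out_indices


def argmax_cmp(a: float, b: float) -> bool:
    """Plays the role of `arith.cmpf ogt` in the examples above."""
    return a > b
```

With a strict "greater than" comparator and this visiting order, ties keep the earlier element, and an init value that no input beats is returned unchanged together with its init index.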

Traits: AlwaysSpeculatableImplTrait, AttrSizedOperandSegments, SingleBlockImplicitTerminator<IREE::VectorExt::YieldOp>, SingleBlock

Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface)

Effects: MemoryEffects::Effect{}

Attributes:
Attribute  MLIR Type            Description
dimension  ::mlir::IntegerAttr  64-bit signless integer attribute
Operands:
Operand      Description
input_value  vector of any type values
input_index  vector of any type values
init_value   vector of any type values
init_index   vector of any type values
index_base   index
Results:
Result        Description
result_value  vector of any type values
result_index  vector of any type values

iree_vector_ext.to_layout (VectorExt::ToLayoutOp)

Layout conversion operator.

Syntax:

operation ::= `iree_vector_ext.to_layout` $input `to` `layout` `(` $layout `)` attr-dict `:` type($input)

The layout conversion operator takes a shaped value and a layout and transforms the value to have that layout.

If the "shared_memory_conversion" attribute is set, then this layout change has to be materialized through shared memory.

Traits: AlwaysSpeculatableImplTrait

Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface)

Effects: MemoryEffects::Effect{}

Attributes:
Attribute                 MLIR Type                                                      Description
layout                    ::mlir::iree_compiler::IREE::VectorExt::VectorLayoutInterface  VectorLayoutInterface instance
shared_memory_conversion  ::mlir::UnitAttr                                               unit attribute
mma_kind                  ::mlir::Attribute                                              any attribute
Operands:
Operand  Description
input    shaped of any type values
Results:
Result  Description
output  shaped of any type values

iree_vector_ext.to_simd (VectorExt::ToSIMDOp)

SIMT to SIMD conversion operation.

Syntax:

operation ::= `iree_vector_ext.to_simd` $input attr-dict `:` type($input) `->` type($output)

This is a temporary operation, useful for source/target materializations when doing type conversions between distributed and non-distributed vectors.

Traits: AlwaysSpeculatableImplTrait, SameOperandsAndResultElementType

Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface)

Effects: MemoryEffects::Effect{}

Operands:
Operand  Description
input    vector of any type values
Results:
Result  Description
output  vector of any type values

iree_vector_ext.to_simt (VectorExt::ToSIMTOp)

SIMD to SIMT conversion operation.

Syntax:

operation ::= `iree_vector_ext.to_simt` $input attr-dict `:` type($input) `->` type($output)

This is a temporary operation, useful for source/target materializations when doing type conversions between distributed and non-distributed vectors.

Traits: AlwaysSpeculatableImplTrait, SameOperandsAndResultElementType

Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface)

Effects: MemoryEffects::Effect{}

Operands:
Operand  Description
input    vector of any type values
Results:
Result  Description
output  vector of any type values

iree_vector_ext.transfer_gather (VectorExt::TransferGatherOp)

Gathers a supervector from a shaped source into an SSA vector value.

Syntax:

operation ::= `iree_vector_ext.transfer_gather` $base `[` $offsets `]` (`[` $index_vecs^ `:` type($index_vecs) `]`)? `,` $padding (`,` $mask^)? attr-dict `:` type($base) `,` type($vector) (`,` type($mask)^)?

The transfer_gather operation reads elements from a shaped source (memref or tensor) into a vector, where each source dimension can be independently contiguous, gathered, or broadcast.

Semantically, for each position in the result vector:

result[d0, d1, ...] = base[offsets[0] + f0(d, s), offsets[1] + f1(d, s), ...]

where each fi is the i-th result of the source indexing map evaluated at the vector position d = (d0, d1, ...) and gathered index values s = (s0, s1, ...).

The indexing_maps attribute describes all indexing. Every map has numDims = result vector rank and numSymbols = number of index vecs:

  • Map 0 (source map): (vector_dims)[symbols] -> (source_dims). A dim expr means the source dimension is contiguous (iterated in lockstep with the vector dimension). A symbol expr means the source dimension is gathered (looked up via the corresponding index vector). A constant 0 means the source dimension is broadcast (always reads at the base offset).
  • Maps 1..N (index vec maps): (vector_dims)[symbols] -> (index_vec_dims). Describes how each index vector is indexed from the vector iteration space. Only dim exprs are allowed.
  • Optional last map (mask map): present only when a mask operand is provided. Only dim exprs are allowed.

Example — embedding lookup: reading rows from a 3D source where the row index is gathered from an index vector, while the column is contiguous:

// result[i, j] = base[0, indices[i], j]
%result = iree_vector_ext.transfer_gather %base[%c0, %c0, %c0]
  [%indices : vector<16xindex>], %pad {
    indexing_maps = [
      affine_map<(d0, d1)[s0] -> (0, s0, d1)>,
      affine_map<(d0, d1)[s0] -> (d0)>
    ]
  } : memref<4096x512x8xf16>, vector<16x8xf16>
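
The indexing in the embedding lookup example can be modeled with nested Python lists. This is a sketch under stated assumptions, not the op's implementation: gather_rows is a hypothetical helper that hard-codes the source map (d0, d1)[s0] -> (0, s0, d1) and the index vec map (d0, d1)[s0] -> (d0), and it leaves out the padding and mask operands entirely.

```python
from typing import List, Sequence

Tensor3D = Sequence[Sequence[Sequence[float]]]


def gather_rows(
    base: Tensor3D,
    offsets: Sequence[int],
    indices: Sequence[int],
    num_cols: int,
) -> List[List[float]]:
    """Model of result[d0, d1] = base[off0 + 0, off1 + indices[d0], off2 + d1]."""
    off0, off1, off2 = offsets
    result = []
    for d0 in range(len(indices)):
        s0 = indices[d0]  # index vec map: (d0, d1)[s0] -> (d0)
        # Source dim 0 is broadcast (constant 0), dim 1 is gathered (s0),
        # and dim 2 is contiguous (d1).
        result.append(
            [base[off0 + 0][off1 + s0][off2 + d1] for d1 in range(num_cols)]
        )
    return result
```

Calling gather_rows(base, (0, 0, 0), indices, 8) with 16 indices corresponds to the vector<16x8xf16> result in the example, with every row read from base[0].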

Traits: AttrSizedOperandSegments

Interfaces: ConditionallySpeculatable, MemoryEffectOpInterface

Attributes:
Attribute      MLIR Type          Description
indexing_maps  ::mlir::ArrayAttr  AffineMap array attribute
Operands:
Operand     Description
base        shaped of any type values
offsets     variadic of index
index_vecs  variadic of index or vector of index values
padding     any type
mask        vector of 1-bit signless integer values
Results:
Result  Description
vector  vector of any type values

iree_vector_ext.yield (VectorExt::YieldOp)

Yield operation for VectorExt operations with regions.

Syntax:

operation ::= `iree_vector_ext.yield` attr-dict ($values^ `:` type($values))?

Yields values from regions in VectorExt operations.

Traits: AlwaysSpeculatableImplTrait, HasParent<ArgCompareOp>, ReturnLike, Terminator

Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface), RegionBranchTerminatorOpInterface

Effects: MemoryEffects::Effect{}

Operands:
Operand  Description
values   variadic of any type

Attributes

NestedLayoutAttr

A layout representing a mapping from GPU thread hierarchy to a shape.

Syntax:

#iree_vector_ext.nested_layout<
  ::llvm::ArrayRef<int64_t>,   # subgroupTile
  ::llvm::ArrayRef<int64_t>,   # batchTile
  ::llvm::ArrayRef<int64_t>,   # outerTile
  ::llvm::ArrayRef<int64_t>,   # threadTile
  ::llvm::ArrayRef<int64_t>,   # elementTile
  ::llvm::ArrayRef<int64_t>,   # subgroupStrides
  ::llvm::ArrayRef<int64_t>   # threadStrides
>

This layout explicitly defines how the shape of the associated vector is mapped to a compute hierarchy. We consider the following levels of hierarchy, inspired by GPUs:

  1. Subgroups per workgroup
  2. Threads per subgroup
  3. Elements per thread

Note that the elements in a thread are also conceptually viewed as three dimensions, i.e. elements per thread = batch x outer x element. However, the final order of the sub-dimensions does not follow that hierarchy exactly. For example, a one-dimensional vector, say vector<n x f16>, is viewed as a 5-dimensional vector<subgroup x batch x outer x thread x element>. For a two-dimensional vector, each sub-dimension above is doubled, i.e. vector<n1 x n2 x f16> is viewed as vector<subgroup1 x subgroup2 x batch1 x batch2 x ... x element1 x element2>.

When the vector is indexed, the 'subgroup' and 'thread' indices do not refer directly to the subgroup_id and thread_id in the GPU context. We call them virtual_subgroup_id and virtual_thread_id and define them as follows:

virtual_subgroup_id[i] = (subgroup_id / subgroup_stride[i]) % subgroup_tile_size[i]
virtual_thread_id[i]   = (thread_id   / thread_stride[i]) % thread_tile_size[i]

The inverse mapping is:

subgroup_id = sum_i(subgroup_stride[i] * virtual_subgroup_id[i]) % mul_i(subgroup_tile_size[i])
thread_id = sum_i(thread_stride[i] * virtual_thread_id[i]) % mul_i(thread_tile_size[i])
    for i = [0 : rank(undistributed_vector)]

NOTE: a stride of zero means that the dimension is not distributed at that level of the hierarchy.
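
The two mappings above can be written as a pair of small Python functions. This is a minimal sketch with hypothetical names (virtual_ids, hardware_id). Because the forward formula divides by the stride, the sketch assumes that a zero stride maps to virtual id 0, which is one reading of "not distributed" and is not stated in the text.

```python
from math import prod
from typing import List, Sequence


def virtual_ids(hw_id: int, tile: Sequence[int], strides: Sequence[int]) -> List[int]:
    """virtual_id[i] = (hw_id / stride[i]) % tile[i], per the formulas above."""
    # Assumption: a zero stride (dimension not distributed) maps to virtual id 0.
    return [0 if s == 0 else (hw_id // s) % t for t, s in zip(tile, strides)]


def hardware_id(virtual: Sequence[int], tile: Sequence[int], strides: Sequence[int]) -> int:
    """Inverse mapping: sum_i(stride[i] * virtual_id[i]) % mul_i(tile[i])."""
    return sum(s * v for s, v in zip(strides, virtual)) % prod(tile)
```

For instance, with a tile of [16, 4] and strides of [1, 16], id 16 maps to the virtual id [0, 1], and hardware_id maps [0, 1] back to 16.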

We now describe each level of tiling. Each level of tiling represents a count of tiles over the next level (rather than a list of tile sizes).

Subgroups per Workgroup

This level of tiling is also known as "subgroup/warp distribution". It represents how the vector is distributed into subgroups.

For example, distributing vector<4x2xf16> with subgroup_tile=[4, 2] and subgroup_strides=[1, 4] arranges the subgroups in the order:

virtual_subgroups_ids:
[0][0] , [0][1] , [1][0], [1][1], [2][0], [2][1], [3][0], [3][1]
subgroups_ids:
0, 4, 1, 5, 2, 6, 3, 7

The subgroups are placed contiguously with their shape and ordering determined by:

  • subgroup_tile: Sizes of this level of tiling.
  • subgroup_strides: Stride of this level of tiling. 0 if not distributed.

Tiling levels must not overlap.

The total number of subgroups used (computed by multiplying each dim in subgroup_tile) should be a multiple of the number of subgroups in the hardware. If the total number of subgroups used exceeds the number of subgroups in the hardware, then the subgroup used (say x) is x mod num_subgroups:

num_subgroups = 4

0, 4, 1, 5, 2, 6, 3, 7
| mod 4
V
0, 0, 1, 1, 2, 2, 3, 3
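
The ordering and the wrap-around above can be reproduced with a short Python sketch. subgroup_order is a hypothetical helper; it enumerates the virtual subgroup ids in row-major order and applies sum_i(stride[i] * virtual_id[i]) to each.

```python
from itertools import product
from typing import List, Sequence


def subgroup_order(tile: Sequence[int], strides: Sequence[int]) -> List[int]:
    """Subgroup id for each virtual subgroup id, enumerated in row-major order."""
    return [
        sum(s * v for s, v in zip(strides, virtual))
        for virtual in product(*(range(t) for t in tile))
    ]


ids = subgroup_order(tile=[4, 2], strides=[1, 4])
wrapped = [x % 4 for x in ids]  # num_subgroups = 4
print(ids)      # [0, 4, 1, 5, 2, 6, 3, 7]
print(wrapped)  # [0, 0, 1, 1, 2, 2, 3, 3]
```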

Threads per Subgroup

This level of tiling is also known as "thread distribution" within a subgroup. The logic is similar to subgroup distribution, using the tile sizes and the thread_strides.

Element distribution on a thread

After the vector is distributed across the threads of a subgroup, it is viewed as [batch] x [outer] x [element], where each group of sub-dimensions has as many dimensions as the rank of the undistributed vector.

The first level, batches, is a way to represent instruction unrolling. For example, an intrinsic that can only take a 4x4 shape at a time uses batches to unroll a 16x16 shape to the native intrinsic shape.

The second level, outers, is a way to represent thread layout duplication required by a particular intrinsic. For example, some AMDGPU matrix multiplication variants require threads to be distributed as shown below.

E.g.: with outer_tile=[2, 1] and thread_tile=[2, 5], the thread layout of shape 2x5 is duplicated 2 times to get a layout of shape 4x5:

outer = 0,0 :
[0 1 2 3 4]
[5 6 7 8 9]

outer = 1,0 :
[0 1 2 3 4]
[5 6 7 8 9]

outer_tile represents the number of outers in a batch.
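
The duplicated layout above can be generated with a small Python sketch. thread_layout is a hypothetical helper for the rank-2 case. It assumes thread_strides=[5, 1], which is not given in the example but matches the row-major thread numbering shown, and it assumes the outer index is the slower-varying part of each dimension.

```python
from typing import List, Sequence


def thread_layout(
    outer_tile: Sequence[int],
    thread_tile: Sequence[int],
    thread_strides: Sequence[int],
) -> List[List[int]]:
    """Thread id owning each position of a rank-2 (outer x thread) layout."""
    rows = outer_tile[0] * thread_tile[0]
    cols = outer_tile[1] * thread_tile[1]
    grid = []
    for r in range(rows):
        # The outer index is the slow-varying part of each dimension, so the
        # virtual thread id is the position modulo the thread tile.
        v0 = r % thread_tile[0]
        grid.append(
            [v0 * thread_strides[0] + (c % thread_tile[1]) * thread_strides[1]
             for c in range(cols)]
        )
    return grid


layout = thread_layout(outer_tile=[2, 1], thread_tile=[2, 5], thread_strides=[5, 1])
for row in layout:
    print(row)
```

The first two printed rows correspond to outer = 0,0 and the last two to outer = 1,0, each covering threads 0 through 9.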

The final level of tiling, elements, represents the minimum shape of vector that is treated as an atom.

element_tile represents the native size of the vector.

A full example

Vector to be distributed: vector<64x64xf16>

NestedLayout : <
    subgroup_tile = [2, 1],
    batch_tile = [2, 4],
    outer_tile = [1, 1],
    thread_tile = [16, 4],
    element_tile = [1, 4],
    subgroup_strides = [1, 0],
    thread_strides = [1, 16]
>

This is conceptually viewed as a vector<[2x1]x[2x4]x[1x1]x[16x4]x[1x4]>, where the first group of sub-dimensions represents the distribution into subgroups. The subgroup_strides being [1, 0] means each subgroup is going to get a vector as follows:

subgroup0 : vector<[2x4]x[1x1]x[16x4]x[1x4]>
from vector<[2x1]x[2x4]x[1x1]x[16x4]x[1x4]>[0,:,:,:,:,:,:,:,:,:]
subgroup1 : vector<[2x4]x[1x1]x[16x4]x[1x4]>
from vector<[2x1]x[2x4]x[1x1]x[16x4]x[1x4]>[1,:,:,:,:,:,:,:,:,:]
subgroup2 : vector<[2x4]x[1x1]x[16x4]x[1x4]>
from vector<[2x1]x[2x4]x[1x1]x[16x4]x[1x4]>[0,:,:,:,:,:,:,:,:,:]
subgroup3 : vector<[2x4]x[1x1]x[16x4]x[1x4]>
from vector<[2x1]x[2x4]x[1x1]x[16x4]x[1x4]>[1,:,:,:,:,:,:,:,:,:]

Then each vector<[2x4]x[1x1]x[16x4]x[1x4]> is distributed across the threads in a subgroup using thread_strides = [1, 16].

recall: thread_id = sum_i(thread_stride[i] * virtual_thread_id[i]) % mul_i(thread_tile_size[i])

thread0 : vector<[2x4]x[1x1]x[1x4]>
from vector<[2x4]x[1x1]x[16x4]x[1x4]>[:,:,:,:,0,0,:,:]
thread1 : vector<[2x4]x[1x1]x[1x4]>
from vector<[2x4]x[1x1]x[16x4]x[1x4]>[:,:,:,:,1,0,:,:]
...
...
thread16 : vector<[2x4]x[1x1]x[1x4]>
from vector<[2x4]x[1x1]x[16x4]x[1x4]>[:,:,:,:,0,1,:,:]

Finally we are left with a distributed vector of conceptual view : vector<[2x4]x[1x1]x[1x4]> where the actual shape is : vector<2x16>.
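
The shapes in this example can be checked with a few lines of Python. This is a sketch with hypothetical helper names. The delinearize function assumes that, within one dimension, the sub-dimensions are ordered subgroup x batch x outer x thread x element from slowest to fastest varying, following the conceptual view described earlier.

```python
from math import prod
from typing import Dict, List

LEVELS = ["subgroup_tile", "batch_tile", "outer_tile", "thread_tile", "element_tile"]

layout: Dict[str, List[int]] = {
    "subgroup_tile": [2, 1],
    "batch_tile": [2, 4],
    "outer_tile": [1, 1],
    "thread_tile": [16, 4],
    "element_tile": [1, 4],
}


def undistributed_shape(layout: Dict[str, List[int]]) -> List[int]:
    """Per-dimension product of all five tiling levels."""
    rank = len(layout["subgroup_tile"])
    return [prod(layout[level][d] for level in LEVELS) for d in range(rank)]


def distributed_shape(layout: Dict[str, List[int]]) -> List[int]:
    """Per-thread shape: batch x outer x element in each dimension."""
    per_thread = ["batch_tile", "outer_tile", "element_tile"]
    rank = len(layout["subgroup_tile"])
    return [prod(layout[level][d] for level in per_thread) for d in range(rank)]


def delinearize(index: int, dim: int, layout: Dict[str, List[int]]) -> Dict[str, int]:
    """Split an index along one dimension into per-level coordinates."""
    coords = {}
    for level in reversed(LEVELS):  # peel off the fastest-varying level first
        size = layout[level][dim]
        coords[level] = index % size
        index //= size
    return coords


print(undistributed_shape(layout))  # [64, 64]
print(distributed_shape(layout))    # [2, 16]
```

Under the stated ordering assumption, delinearize(37, 0, layout) places row 37 in virtual subgroup 1, batch 0, virtual thread 5.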

Parameters:
Parameter        C++ type                   Description
subgroupTile     ::llvm::ArrayRef<int64_t>  subgroup_tile
batchTile        ::llvm::ArrayRef<int64_t>  batch_tile
outerTile        ::llvm::ArrayRef<int64_t>  outer_tile
threadTile       ::llvm::ArrayRef<int64_t>  thread_tile
elementTile      ::llvm::ArrayRef<int64_t>  element_tile
subgroupStrides  ::llvm::ArrayRef<int64_t>  subgroup_strides
threadStrides    ::llvm::ArrayRef<int64_t>  thread_strides