'iree_vector_ext' Dialect
IREE Vector Extensions.
A dialect designed for experimenting with vector operations beyond what is currently available in the Vector Dialect.
- 'iree_vector_ext' Dialect
- Operations
- iree_vector_ext.arg_compare (VectorExt::ArgCompareOp)
- iree_vector_ext.to_layout (VectorExt::ToLayoutOp)
- iree_vector_ext.to_simd (VectorExt::ToSIMDOp)
- iree_vector_ext.to_simt (VectorExt::ToSIMTOp)
- iree_vector_ext.transfer_gather (VectorExt::TransferGatherOp)
- iree_vector_ext.yield (VectorExt::YieldOp)
- Attributes
Operations
iree_vector_ext.arg_compare (VectorExt::ArgCompareOp)
Vectorized arg-reduction using a user-defined comparator.
Syntax:
operation ::= `iree_vector_ext.arg_compare` attr-dict
`dimension` `(` $dimension `)`
`ins` `(` $input_value (`,` $input_index^)? `:` type($input_value) (`,` type($input_index)^)? `)`
`inits` `(` $init_value `,` $init_index `:` type($init_value) `,` type($init_index) `)`
(`index_base` `(` $index_base^ `:` type($index_base) `)`)?
$region `->` type($result_value) `,` type($result_index)
The iree_vector_ext.arg_compare op is the vectorized form of
iree_linalg_ext.arg_compare. It performs a reduction over a given
dimension of a vector, returning both the selected value and its
corresponding index.
Comparator region semantics:
The comparator region defines the predication logic that determines the
selection rule for the reduction (e.g., "greater than" for argmax or
"less than" for argmin). The region takes two scalar values of the input
element type as arguments and returns a single i1 result via
iree_vector_ext.yield.
The region is invoked during the reduction to determine which element to
select: when the comparison yields true, the first argument is selected;
otherwise, the second argument is selected.
The region must contain only pure operations (operations with the Pure
trait). This ensures the comparator can be safely executed in any order
during the reduction.
Example (implicit-index mode - argmax over dim 1):
%input_vec = vector<4x128xf32>
%out_val_vec = vector<4xf32>
%out_idx_vec = vector<4xi32>
%result:2 = iree_vector_ext.arg_compare
dimension(1)
ins(%input_vec : vector<4x128xf32>)
inits(%out_val_vec, %out_idx_vec : vector<4xf32>, vector<4xi32>) {
^bb0(%a: f32, %b: f32):
%cmp = arith.cmpf ogt, %a, %b : f32
iree_vector_ext.yield %cmp : i1
} -> vector<4xf32>, vector<4xi32>
Example (explicit-index mode):
%partial_vals = vector<4x32xf32>
%partial_idxs = vector<4x32xi32>
%out_val = vector<4xf32>
%out_idx = vector<4xi32>
%result:2 = iree_vector_ext.arg_compare
dimension(1)
ins(%partial_vals, %partial_idxs : vector<4x32xf32>, vector<4x32xi32>)
inits(%out_val, %out_idx : vector<4xf32>, vector<4xi32>) {
^bb0(%a: f32, %b: f32):
%cmp = arith.cmpf ogt, %a, %b : f32
iree_vector_ext.yield %cmp : i1
} -> vector<4xf32>, vector<4xi32>
The index_base is an optional offset value that, when specified, is added
to the computed indices in the result. This is useful when reducing over a
sliced subregion where the indices need to be adjusted to reflect their
position in the original vector. The index_base can only be used in
implicit-index mode (single input).
The inits operands provide the initial accumulator values for the
reduction (initial max/min value and index).
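The reduction semantics described above can be sketched as a minimal Python model. This is an illustrative reference implementation, not IREE API: `arg_compare`, its argument names, and the row-major iteration order are assumptions for exposition; the op itself imposes no evaluation order (the comparator must be pure precisely so any order is valid).

```python
# Hedged sketch: a pure-Python model of arg_compare in implicit-index mode.
# `input_rows` plays the role of the input vector with the reduction running
# over the inner dimension; `cmp(a, b) == True` selects the first argument,
# matching the comparator-region semantics described above.
def arg_compare(input_rows, init_vals, init_idxs, cmp, index_base=0):
    """Return (selected values, selected indices + index_base) per row."""
    out_vals, out_idxs = list(init_vals), list(init_idxs)
    for r, row in enumerate(input_rows):
        for i, v in enumerate(row):
            # When the comparison yields true, the candidate (first arg) wins.
            if cmp(v, out_vals[r]):
                out_vals[r] = v
                out_idxs[r] = i + index_base
    return out_vals, out_idxs

# Argmax over dim 1 (comparator `arith.cmpf ogt` from the example above):
vals, idxs = arg_compare(
    [[1.0, 5.0, 3.0], [2.0, 0.0, 4.0]],
    init_vals=[float("-inf")] * 2,
    init_idxs=[0] * 2,
    cmp=lambda a, b: a > b,
)
# vals == [5.0, 4.0], idxs == [1, 2]
```

Passing a nonzero `index_base` shifts the reported indices, modeling a reduction over a slice whose indices must be rebased into the original vector.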
Traits: AlwaysSpeculatableImplTrait, AttrSizedOperandSegments, SingleBlockImplicitTerminator<IREE::VectorExt::YieldOp>, SingleBlock
Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface)
Effects: MemoryEffects::Effect{}
Attributes:
| Attribute | MLIR Type | Description |
|---|---|---|
| dimension | ::mlir::IntegerAttr | 64-bit signless integer attribute |
Operands:
| Operand | Description |
|---|---|
| input_value | vector of any type values |
| input_index | vector of any type values |
| init_value | vector of any type values |
| init_index | vector of any type values |
| index_base | index |
Results:
| Result | Description |
|---|---|
| result_value | vector of any type values |
| result_index | vector of any type values |
iree_vector_ext.to_layout (VectorExt::ToLayoutOp)
Layout conversion operator.
Syntax:
operation ::= `iree_vector_ext.to_layout` $input `to` `layout` `(` $layout `)` attr-dict `:` type($input)
The layout conversion operator takes a shaped value and a layout and transforms the value to have that layout.
If the "shared_memory_conversion" attribute is set, then this layout change has to be materialized through shared memory.
Traits: AlwaysSpeculatableImplTrait
Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface)
Effects: MemoryEffects::Effect{}
Attributes:
| Attribute | MLIR Type | Description |
|---|---|---|
| layout | ::mlir::iree_compiler::IREE::VectorExt::VectorLayoutInterface | VectorLayoutInterface instance |
| shared_memory_conversion | ::mlir::UnitAttr | unit attribute |
| mma_kind | ::mlir::Attribute | any attribute |
Operands:
| Operand | Description |
|---|---|
| input | shaped of any type values |
Results:
| Result | Description |
|---|---|
| output | shaped of any type values |
iree_vector_ext.to_simd (VectorExt::ToSIMDOp)
SIMT to SIMD conversion operation.
Syntax:
operation ::= `iree_vector_ext.to_simd` $input attr-dict `:` type($input) `->` type($output)
This operation is a temporary operation useful for source/target materializations when doing type conversions between distributed and undistributed vectors.
Traits: AlwaysSpeculatableImplTrait, SameOperandsAndResultElementType
Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface)
Effects: MemoryEffects::Effect{}
Operands:
| Operand | Description |
|---|---|
| input | vector of any type values |
Results:
| Result | Description |
|---|---|
| output | vector of any type values |
iree_vector_ext.to_simt (VectorExt::ToSIMTOp)
SIMD to SIMT conversion operation.
Syntax:
operation ::= `iree_vector_ext.to_simt` $input attr-dict `:` type($input) `->` type($output)
This operation is a temporary operation useful for source/target materializations when doing type conversions between distributed and undistributed vectors.
Traits: AlwaysSpeculatableImplTrait, SameOperandsAndResultElementType
Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface)
Effects: MemoryEffects::Effect{}
Operands:
| Operand | Description |
|---|---|
| input | vector of any type values |
Results:
| Result | Description |
|---|---|
| output | vector of any type values |
iree_vector_ext.transfer_gather (VectorExt::TransferGatherOp)
Gathers a supervector from a shaped source into an SSA vector value.
Syntax:
operation ::= `iree_vector_ext.transfer_gather` $base `[` $offsets `]` (`[` $index_vecs^ `:` type($index_vecs) `]`)? `,` $padding (`,` $mask^)? attr-dict `:` type($base) `,` type($vector) (`,` type($mask)^)?
The transfer_gather operation reads elements from a shaped source (memref
or tensor) into a vector, where each source dimension can be independently
contiguous, gathered, or broadcast.
Semantically, for each position in the result vector:
result[d0, d1, ...] = base[offsets[0] + f0(d, s), offsets[1] + f1(d, s), ...]
where each fi is the i-th result of the source indexing map evaluated at
the vector position d = (d0, d1, ...) and gathered index values
s = (s0, s1, ...).
The indexing_maps attribute describes all indexing. Every map has
numDims = result vector rank and numSymbols = number of index vecs:
- Map 0 (source map): (vector_dims)[symbols] -> (source_dims). A dim expr
  means the source dimension is contiguous (iterated in lockstep with the
  vector dimension). A symbol expr means the source dimension is gathered
  (looked up via the corresponding index vector). A constant 0 means the
  source dimension is broadcast (always reads at the base offset).
- Maps 1..N (index vec maps): (vector_dims)[symbols] -> (index_vec_dims).
  Describes how each index vector is indexed from the vector iteration
  space. Only dim exprs are allowed.
- Optional last map (mask map): present only when a mask operand is
  provided. Only dim exprs are allowed.
Example (embedding lookup): reading rows from a 3D source where the row index is gathered from an index vector, while the column is contiguous:
// result[i, j] = base[0, indices[i], j]
%result = iree_vector_ext.transfer_gather %base[%c0, %c0, %c0]
[%indices : vector<16xindex>], %pad {
indexing_maps = [
affine_map<(d0, d1)[s0] -> (0, s0, d1)>,
affine_map<(d0, d1)[s0] -> (d0)>
]
} : memref<4096x512x8xf16>, vector<16x8xf16>
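The embedding-lookup example above can be evaluated by a short Python sketch of the gather semantics. The helper name and the nested-list encoding of the source are illustrative assumptions, not part of the dialect; the point is only that dim 0 is broadcast (always index 0), dim 1 is gathered through the index vector, and dim 2 is contiguous:

```python
# Hedged model of result[i, j] = base[0, indices[i], j], per the source
# indexing map (d0, d1)[s0] -> (0, s0, d1) in the example above.
def transfer_gather_rows(base, indices, cols):
    # base: 3-D nested list; indices: gathered row ids; cols: contiguous extent.
    return [[base[0][indices[i]][j] for j in range(cols)]
            for i in range(len(indices))]

# Toy 1x3x4 source where element value encodes (row, col) as r*10 + c.
base = [[[r * 10 + c for c in range(4)] for r in range(3)]]
out = transfer_gather_rows(base, indices=[2, 0], cols=4)
# out == [[20, 21, 22, 23], [0, 1, 2, 3]]
```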
Traits: AttrSizedOperandSegments
Interfaces: ConditionallySpeculatable, MemoryEffectOpInterface
Attributes:
| Attribute | MLIR Type | Description |
|---|---|---|
| indexing_maps | ::mlir::ArrayAttr | AffineMap array attribute |
Operands:
| Operand | Description |
|---|---|
| base | shaped of any type values |
| offsets | variadic of index |
| index_vecs | variadic of index or vector of index values |
| padding | any type |
| mask | vector of 1-bit signless integer values |
Results:
| Result | Description |
|---|---|
| vector | vector of any type values |
iree_vector_ext.yield (VectorExt::YieldOp)
Yield operation for VectorExt operations with regions.
Syntax:
operation ::= `iree_vector_ext.yield` attr-dict ($values^ `:` type($values))?
Yields values from regions in VectorExt operations.
Traits: AlwaysSpeculatableImplTrait, HasParent<ArgCompareOp>, ReturnLike, Terminator
Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface), RegionBranchTerminatorOpInterface
Effects: MemoryEffects::Effect{}
Operands:
| Operand | Description |
|---|---|
| values | variadic of any type |
Attributes
NestedLayoutAttr
A layout representing a mapping from GPU thread hierarchy to a shape.
Syntax:
#iree_vector_ext.nested_layout<
::llvm::ArrayRef<int64_t>, # subgroupTile
::llvm::ArrayRef<int64_t>, # batchTile
::llvm::ArrayRef<int64_t>, # outerTile
::llvm::ArrayRef<int64_t>, # threadTile
::llvm::ArrayRef<int64_t>, # elementTile
::llvm::ArrayRef<int64_t>, # subgroupStrides
::llvm::ArrayRef<int64_t> # threadStrides
>
This layout explicitly defines how the shape of the associated vector is mapped to a compute hierarchy. We consider the following levels of hierarchy, inspired by GPUs:
- Subgroups per workgroup
- Threads per subgroup
- Elements per thread
Note that the elements per thread are themselves conceptually viewed as
three dimensions, i.e. elements per thread = batch x outer x element.
However, the final order of sub-dimensions does not follow that hierarchy
exactly. For example, a one-dimensional vector, say vector<n x f16>,
is viewed as a five-dimensional
vector<subgroup x batch x outer x thread x element>. For a two-dimensional
vector, each of the above sub-dimensions is doubled, i.e.
vector<n1 x n2 x f16> is viewed as
vector<subgroup1 x subgroup2 x batch1 x batch2 x ... x element1 x element2>
The subgroup and thread sub-dimensions of this view do not refer directly
to the subgroup_id and thread_id of the GPU context. Let us define them
as virtual_subgroup_id and virtual_thread_id, which hold the following
definitions:
virtual_subgroup_id[i] = (subgroup_id / subgroup_stride[i]) % subgroup_tile_size[i]
virtual_thread_id[i] = (thread_id / thread_stride[i]) % thread_tile_size[i]
the inverse mapping would be:
subgroup_id = sum_i(subgroup_stride[i] * virtual_subgroup_id[i]) % mul_i(subgroup_tile_size[i])
thread_id = sum_i(thread_stride[i] * virtual_thread_id[i]) % mul_i(thread_tile_size[i])
for i = [0 : rank(undistributed_vector)]
NOTE: a stride of zero means that dimension is not distributed at that level of the hierarchy.
We now describe each level of tiling. Each level of tiling represents a count of tiles over the next level (rather than a list of tile sizes).
Subgroups per Workgroup
This level of tiling is also known as "subgroup/warp distribution". It represents how the vector is distributed into subgroups.
For example, distributing vector<4x2xf16> with
subgroup_tile=[4, 2] and subgroup_strides=[1, 4]
arranges the subgroups in the order:
virtual_subgroup_ids:
[0][0], [0][1], [1][0], [1][1], [2][0], [2][1], [3][0], [3][1]
subgroup_ids:
0, 4, 1, 5, 2, 6, 3, 7
The subgroups are placed contiguously with their shape and ordering
determined by:
- subgroup_tile: Sizes of this level of tiling
- subgroup_strides: Stride of this level of tiling. 0 if not distributed.
Tiling levels must not overlap.
The total number of subgroups used (the product of the dims in subgroup_tile) should be a multiple of the number of subgroups on the hardware. If the total number of subgroups used exceeds the number of hardware subgroups, then the subgroup actually used (for a computed id x) is x mod num_subgroups:
num_subgroups = 4
0, 4, 1, 5, 2, 6, 3, 7
| mod 4
V
0, 0, 1, 1, 2, 2, 3, 3
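The virtual-id mapping and its inverse can be checked with a few lines of Python. This is a hedged model for exposition: the helper names `virtual_id` and `physical_id` are not IREE API, and the tile/stride values are the ones from the example above.

```python
import math

# Model of:
#   virtual_id[i] = (id / stride[i]) % tile[i]          (stride 0: undistributed)
#   id = sum_i(stride[i] * virtual_id[i]) % prod_i(tile[i])
def virtual_id(pid, strides, tile):
    return [(pid // s) % t if s != 0 else 0 for s, t in zip(strides, tile)]

def physical_id(vids, strides, tile):
    return sum(s * v for s, v in zip(strides, vids)) % math.prod(tile)

tile, strides = [4, 2], [1, 4]
# Enumerating virtual ids row-major recovers the ordering shown above.
order = [physical_id([v0, v1], strides, tile)
         for v0 in range(tile[0]) for v1 in range(tile[1])]
# order == [0, 4, 1, 5, 2, 6, 3, 7]
wrapped = [sg % 4 for sg in order]  # num_subgroups = 4
# wrapped == [0, 0, 1, 1, 2, 2, 3, 3]
```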
Threads per Subgroup
This level of tiling is also known as "thread distribution" within a subgroup. The logic is quite similar to subgroup distribution, using the tile sizes and the thread_strides.
Element distribution on a thread
After the vector is distributed per thread within a subgroup, it is viewed as [batch] x [outer] x [element], where each sub-dimension group has rank equal to the original rank of the undistributed vector.
The first level, batches, is a way to represent instruction unrolling. For example, an intrinsic which can only take a 4x4 shape at a time uses batches to unroll a 16x16 shape to the native intrinsic shape.
The second level, outers, is a way to represent thread layout duplication required by a particular intrinsic. For example, some AMDGPU matrix multiplication variants require threads to be distributed like:
E.g., with outer_tile=[2, 1] and thread_tile=[2, 5],
the thread layout of shape 2x5 is duplicated 2 times to get a layout of shape 4x5:
outer = 0,0 :
[0 1 2 3 4]
[5 6 7 8 9]
outer = 1,0 :
[0 1 2 3 4]
[5 6 7 8 9]
outer_tile represents the number of outers in a batch.
The final level of tiling represents the minimum shape of vector that is treated as an atom.
element_tile represents the native size of the vector.
A full example
Vector to be distributed: vector<64x64xf16>
NestedLayout : <
subgroup_tile = [2, 1],
batch_tile = [2, 4],
outer_tile = [1, 1],
thread_tile = [16, 4],
element_tile = [1, 4],
subgroup_strides = [1, 0],
thread_strides = [1, 16]
>
This is conceptually viewed as a: vector<[2x1]x[2x4]x[1x1]x[16x4]x[1x4]>
where the first groups of sub-dimensions
represent the distribution into subgroups.
The subgroup_strides being [1, 0] means
each subgroup is going to get a vector
as follows:
subgroup0 : vector<[2x4]x[1x1]x[16x4]x[1x4]>
from vector<[2x1]x[2x4]x[1x1]x[16x4]x[1x4]>[0,:,:,:,:,:,:,:,:,:]
subgroup1 : vector<[2x4]x[1x1]x[16x4]x[1x4]>
from vector<[2x1]x[2x4]x[1x1]x[16x4]x[1x4]>[1,:,:,:,:,:,:,:,:,:]
subgroup2 : vector<[2x4]x[1x1]x[16x4]x[1x4]>
from vector<[2x1]x[2x4]x[1x1]x[16x4]x[1x4]>[0,:,:,:,:,:,:,:,:,:]
subgroup3 : vector<[2x4]x[1x1]x[16x4]x[1x4]>
from vector<[2x1]x[2x4]x[1x1]x[16x4]x[1x4]>[1,:,:,:,:,:,:,:,:,:]
Then each vector<[2x4]x[1x1]x[16x4]x[1x4]> is distributed to the threads in a subgroup using thread_strides = [1, 16].
recall: thread_id = sum_i(thread_stride[i] * virtual_thread_id[i]) % mul_i(thread_tile_size[i])
thread0 : vector<[2x4]x[1x1]x[1x4]>
from vector<[2x4]x[1x1]x[16x4]x[1x4]>[:,:,:,:,0,0,:,:]
thread1 : vector<[2x4]x[1x1]x[1x4]>
from vector<[2x4]x[1x1]x[16x4]x[1x4]>[:,:,:,:,1,0,:,:]
...
...
thread16 : vector<[2x4]x[1x1]x[1x4]>
from vector<[2x4]x[1x1]x[16x4]x[1x4]>[:,:,:,:,0,1,:,:]
Finally we are left with a distributed vector
of conceptual view : vector<[2x4]x[1x1]x[1x4]>
where the actual shape is : vector<2x16>.
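The shape bookkeeping in this example can be verified mechanically: per dimension, the product of all five tiling levels must recover the undistributed shape, and the per-thread shape is the elementwise product of the batch, outer, and element tiles. A hedged sanity check (plain Python, variable names chosen to mirror the attribute fields):

```python
# Tiles from the NestedLayout example above.
subgroup_tile = [2, 1]
batch_tile    = [2, 4]
outer_tile    = [1, 1]
thread_tile   = [16, 4]
element_tile  = [1, 4]

# Product over all levels per dimension recovers the full vector<64x64xf16>.
full_shape = [sg * b * o * t * e for sg, b, o, t, e in
              zip(subgroup_tile, batch_tile, outer_tile, thread_tile, element_tile)]
# Batch x outer x element collapses to the per-thread shape vector<2x16>.
per_thread = [b * o * e for b, o, e in zip(batch_tile, outer_tile, element_tile)]
# full_shape == [64, 64], per_thread == [2, 16]
```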
Parameters:
| Parameter | C++ type | Description |
|---|---|---|
| subgroupTile | ::llvm::ArrayRef<int64_t> | subgroup_tile |
| batchTile | ::llvm::ArrayRef<int64_t> | batch_tile |
| outerTile | ::llvm::ArrayRef<int64_t> | outer_tile |
| threadTile | ::llvm::ArrayRef<int64_t> | thread_tile |
| elementTile | ::llvm::ArrayRef<int64_t> | element_tile |
| subgroupStrides | ::llvm::ArrayRef<int64_t> | subgroup_strides |
| threadStrides | ::llvm::ArrayRef<int64_t> | thread_strides |