'iree_gpu' Dialect
A dialect for common functionality used by GPU focused IREE code generation.
This dialect provides operations and attributes to aid in code generation for GPU targets. The functionality in this dialect can be hardware specific, but is intended to be independent of the lowering target. Late lowerings to SPIR-V/LLVM are handled separately.
- 'iree_gpu' Dialect
- Operations
- Attributes
- ComputeBitwidthsAttr
- DataTiledMMAAttr
- DerivedThreadConfigAttr
- DotProductOpsAttr
- GPUEncodingLayoutAttr
- GPUPadLayoutAttr
- GPUPipelineOptionsAttr
- LaneIdAttr
- LoweringConfigAttr
- MMAAttr
- MMAIntrinsicAttr
- MMAOpsArrayAttr
- MMAScheduleAttr
- ReorderWorkgroupsStrategyAttr
- StorageBitwidthsAttr
- SubgroupOpsAttr
- TargetAttr
- TargetChipAttr
- TargetWgpAttr
- UKernelConfigAttr
- UseGlobalLoadDMAAttr
- VirtualMMAAttr
- VirtualMMAIntrinsicAttr
- Enums
Operations
iree_gpu.barrier_region (GPU::BarrierRegionOp)
Synchronizes workers on a region of shared code.
Syntax:
operation ::= `iree_gpu.barrier_region` (`ins` `(` $inputs^ `:` type($inputs) `)` )?
$region attr-dict `:` type($results)
This op is designed to represent synchronization of workers on the operands and results of the given region. This operation naturally arises when combining the regions of producer-consumer scf.forall operations that share a mapping type.
For example, consider the following pair of parallel loops.
%0 = scf.forall (%idy, %idx) in (2, 32) shared_outs(%init = %empty) -> (tensor<4x128xf32>) {
  %in = ...
  %2 = affine.apply affine_map<(d0) -> (d0 * 2)>(%idy)
  %3 = affine.apply affine_map<(d0) -> (d0 * 4)>(%idx)
  scf.forall.in_parallel {
    tensor.parallel_insert_slice %in into %init[%2, %3] [2, 4] [1, 1]
      : tensor<2x4xf32> into tensor<4x128xf32>
  }
} {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
%1 = scf.forall (%idy, %idx) in (8, 8) -> (tensor<128x128xf32>) {
  %4 = affine.apply affine_map<(d0) -> (d0 * 16)>(%idx)
  %extracted_slice = tensor.extract_slice %0[0, %4] [4, 16] [1, 1]
    : tensor<4x128xf32> to tensor<4x16xf32>
  ...
} {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
Because these loops share the same worker type and total worker count, the bodies of the two loops can be merged, with a barrier, an insert_slice, and a shuffle through shared memory placed where the boundary between the loops currently is.
%0 = scf.forall (%idy, %idx) in (8, 8) -> (tensor<4x128xf32>) {
  %alloc = bufferization.alloc_tensor {memory_space = #gpu.address_space<workgroup>}
    : tensor<4x128xf32>
  %barrier = iree_gpu.barrier_region ins(%alloc : tensor<4x128xf32>) {
  ^bb0(%shared: tensor<4x128xf32>):
    %lin = affine.linearize_index [%idy, %idx] by (8, 8) : index
    %ids:2 = affine.delinearize_index %lin into (2, 32) : index, index
    %in = ...
    %2 = affine.apply affine_map<(d0) -> (d0 * 2)>(%ids#0)
    %3 = affine.apply affine_map<(d0) -> (d0 * 4)>(%ids#1)
    %inserted_slice = tensor.insert_slice %in into %shared[%2, %3] [2, 4] [1, 1]
      : tensor<2x4xf32> into tensor<4x128xf32>
    iree_gpu.yield %inserted_slice : tensor<4x128xf32>
  } : tensor<4x128xf32>
  %4 = affine.apply affine_map<(d0) -> (d0 * 16)>(%idx)
  %slice = tensor.extract_slice %barrier[0, %4] [4, 16] [1, 1]
    : tensor<4x128xf32> to tensor<4x16xf32>
  ...
} {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
A barrier_region can be lowered to two barriers, one on the input operands and a second one on the results.
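For instance, after bufferization the barrier_region above can resolve to plain barriers around the shared memory accesses. A minimal sketch, assuming a vector-based lowering and reusing names from the example above (the exact lowered form is pipeline dependent):

%alloc = memref.alloc() : memref<4x128xf32, #gpu.address_space<workgroup>>
// Barrier on the inputs: wait until the shared allocation is safe to overwrite.
gpu.barrier
vector.transfer_write %in, %alloc[%2, %3]
  : vector<2x4xf32>, memref<4x128xf32, #gpu.address_space<workgroup>>
// Barrier on the results: make all writes visible before any worker reads.
gpu.barrier
%read = vector.transfer_read %alloc[%c0, %4], %cst
  : memref<4x128xf32, #gpu.address_space<workgroup>>, vector<4x16xf32>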
Motivation and Intended Use Cases:
The primary way this op is generated is when fusing parallel loops with tensor results. This operation helps to make lowerings more progressive and flexible.
* Lowering directly to an alloc plus reads and writes breaks the dependency chain, making transformations like barrier placement and pipelining potentially more difficult.
* It allows the option of non-vector-based lowering paths.
Traits: AlwaysSpeculatableImplTrait, SingleBlockImplicitTerminator<mlir::iree_compiler::IREE::GPU::YieldOp>, SingleBlock
Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface)
Effects: MemoryEffects::Effect{}
Operands:
Operand | Description
---|---
inputs | variadic of any type

Results:
Result | Description
---|---
results | variadic of any type
iree_gpu.buffer_resource_cast (GPU::BufferResourceCastOp)
Represents a cast to addr_space<7> (buffer resource) before bufferization.
Syntax:
operation ::= `iree_gpu.buffer_resource_cast` $input oilist (`cacheSwizzleStride` `(` $cache_swizzle_stride `)` )
attr-dict `:` type($result)
Nominal cast of a tensor to AMDGPU buffer resource memory space before bufferization. This op takes the parameters with which to perform the cast if |input| bufferizes to the storage_buffer memory space. If |input| resolves to any other memory space, this op is silently dropped and has no effect.
If |cache_swizzle_stride| is present, it is verified before bufferization that all producers of |input| are view-like with a single source and a single user (i.e., trivially free of aliasing). In all other cases this op is best effort and has no verification or failure modes.
// TODO: Add other parameters for casting as needed.
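A minimal usage sketch based on the assembly format above (the tensor type and stride value are illustrative):

%stride = arith.constant 64 : index
%cast = iree_gpu.buffer_resource_cast %input cacheSwizzleStride(%stride)
  : tensor<64x64xf32>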
Traits: AlwaysSpeculatableImplTrait
Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface)
Effects: MemoryEffects::Effect{}
Operands:
Operand | Description
---|---
input | ranked tensor of any type values
cache_swizzle_stride | index

Results:
Result | Description
---|---
result | ranked tensor of any type values
iree_gpu.global_load_dma (GPU::GlobalLoadDMAOp)
Performs a global load DMA operation.
Syntax:
operation ::= `iree_gpu.global_load_dma` $source`[` $sourceIndices `]` `->` $target `[` $targetIndices `]` attr-dict
`:` type($source) `->` type($target)
This operation represents a subgroup-level global load DMA operation. It represents a direct gather from global memory into workgroup memory: each thread gathers data from the global memory space at the designated indices, and stores it at its lane offset within the workgroup memref at the designated indices.
Specifically, if the thread's subgroup lane id is lane_id, the thread loads from $source[sourceIndices] and stores to $target[targetIndices] + lane_id. Collectively, all threads in the subgroup orchestrate the load DMA operation.
Note: each gather has a load width of 32 bits.
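A usage sketch following the assembly format above (shapes and the workgroup address space attribute are illustrative):

iree_gpu.global_load_dma %src[%i, %j] -> %dst[%k, %l]
  : memref<256x256xf32> -> memref<64x64xf32, #gpu.address_space<workgroup>>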
Traits: SameVariadicOperandSize
Operands:
Operand | Description
---|---
source | memref of any type values
sourceIndices | variadic of index
target | memref of any type values
targetIndices | variadic of index
iree_gpu.value_barrier (GPU::ValueBarrierOp)
Synchronizes workers on a value semantic tensor or vector.
Syntax:
operation ::= `iree_gpu.value_barrier` $inputs attr-dict `:` type($inputs)
This operation acts as a barrier on value semantic SSA values (tensors or vectors). It takes multiple operands and produces a value equivalent to each input. It does not have copy or data movement semantics: it simply represents a barrier on all writes in the tensor case, and a barrier until all threads acquire the input vector in the vector case.
The inputs must be either all tensors, or all vectors.
This operation is a no-op when not present in a parallel context. This operation is pure as it only requires synchronization for the value it produces.
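A usage sketch with two vector operands (types are illustrative):

%sync:2 = iree_gpu.value_barrier %a, %b : vector<4xf32>, vector<4xf32>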
Traits: AlwaysSpeculatableImplTrait
Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface)
Effects: MemoryEffects::Effect{}
Operands:
Operand | Description
---|---
inputs | variadic of ranked tensor or vector of any type values

Results:
Result | Description
---|---
results | variadic of ranked tensor or vector of any type values
iree_gpu.yield (GPU::YieldOp)
Yield values from an iree_gpu region.
Syntax:
operation ::= `iree_gpu.yield` attr-dict ($values^ `:` type($values))?
This operation is used to yield values from within a region.
Traits: AlwaysSpeculatableImplTrait, HasParent<::mlir::iree_compiler::IREE::GPU::BarrierRegionOp>, ReturnLike, Terminator
Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface), RegionBranchTerminatorOpInterface
Effects: MemoryEffects::Effect{}
Operands:
Operand | Description
---|---
values | variadic of any type
Attributes
ComputeBitwidthsAttr
Supported bitwidths for compute
Syntax:
#iree_gpu.compute_bitwidths<
  ::mlir::iree_compiler::IREE::GPU::ComputeBitwidths # value
>
Parameters:
Parameter | C++ type | Description
---|---|---
value | ::mlir::iree_compiler::IREE::GPU::ComputeBitwidths | an enum of type ComputeBitwidths
DataTiledMMAAttr
Syntax:
#iree_gpu.data_tiled_mma_layout<
::mlir::iree_compiler::IREE::GPU::MMAIntrinsic, # intrinsic
int64_t, # intrinsics_m
int64_t, # subgroups_m
int64_t, # intrinsics_n
int64_t, # subgroups_n
int64_t # intrinsics_k
>
This MMA variant represents MMA ops with data-tiling details. The |intrinsic| field specifies which particular MMA intrinsic is targeted by the data-tiling.
The other fields default to one. With those defaults, the attribute is equivalent to a single intrinsic, i.e., to MMAAttr. Values greater than one describe wider "kernels" consisting of multiple intrinsics, with the data layout already swizzled into a tile layout that lets each intrinsic access its data at an offset that is as simple a mapping from the thread ID as possible.
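An illustrative instance (the intrinsic and counts are example values, not a recommendation):

#iree_gpu.data_tiled_mma_layout<intrinsic = MFMA_F32_16x16x4_F32,
                                intrinsics_m = 4, subgroups_m = 2,
                                intrinsics_n = 4, subgroups_n = 2,
                                intrinsics_k = 4>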
Parameters:
Parameter | C++ type | Description
---|---|---
intrinsic | ::mlir::iree_compiler::IREE::GPU::MMAIntrinsic | an enum of type MMAIntrinsic
intrinsics_m | int64_t | Intrinsic count along the M dimension.
subgroups_m | int64_t | Subgroup count along the M dimension.
intrinsics_n | int64_t | Intrinsic count along the N dimension.
subgroups_n | int64_t | Subgroup count along the N dimension.
intrinsics_k | int64_t | Intrinsic count along the K dimension, with interleaved layout.
DerivedThreadConfigAttr
Drive lowering of an operation by deriving thread distribution when needed.
Syntax: #iree_gpu.derived_thread_config
Lowering config for a single thread tiling level that is inferred after previous (often reduction) levels of tile + fuse. This is intended for fused operations where it is much easier to compute the tile sizes to use after previous levels of tile + fuse, rather than trying to pre-propagate tiling configs.
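A sketch of tagging an op so that its thread-level tile sizes are derived later (the op and types are illustrative):

linalg.copy {lowering_config = #iree_gpu.derived_thread_config}
  ins(%src : tensor<64x64xf32>) outs(%dst : tensor<64x64xf32>) -> tensor<64x64xf32>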
DotProductOpsAttr
Supported dot product ops
Syntax:
#iree_gpu.dotproduct_ops<
  ::mlir::iree_compiler::IREE::GPU::DotProductOps # value
>
Parameters:
Parameter | C++ type | Description
---|---|---
value | ::mlir::iree_compiler::IREE::GPU::DotProductOps | an enum of type DotProductOps
GPUEncodingLayoutAttr
The encoding layout attribute for the GPU backend.
Syntax:
#iree_gpu.gpu_encoding_layout<
  DictionaryAttr # configuration
>
This attribute can implement any layout interface methods for encoding serialization and/or materialization, e.g., Encoding::LayoutMaterializerAttr, Codegen::PackedLayoutMaterializerAttr, etc. These should be implemented through the external model mechanism, because we do not want to relocate domain-specific logic into the dialect implementation, and keeping it external gives a better code structure. See the implementations in compiler/Codegen/ExternalInterfaces/*.
Parameters:
Parameter | C++ type | Description
---|---|---
configuration | DictionaryAttr | Executable target configuration. It is expected to be used in a pass scope, but not in the final IR output.
GPUPadLayoutAttr
The padded encoding layout attribute for GPU targets.
Syntax:
#iree_gpu.gpu_pad_layout<
  std::optional<uint32_t>, # cache_line_bytes
  std::optional<uint32_t> # cache_sets
>
Describes padding preferences for a given GPU target. This attribute can implement any encoding interface for data-tiling, e.g., Encoding::LayoutResolverAttr, etc. These should be implemented through the external model mechanism, because we do not want to relocate domain-specific logic into the dialect implementation, and keeping it external gives a better code structure. See the implementations in compiler/Codegen/ExternalInterfaces/*.
Parameters:
Parameter | C++ type | Description
---|---|---
cache_line_bytes | std::optional<uint32_t> |
cache_sets | std::optional<uint32_t> |
GPUPipelineOptionsAttr
Options attribute for linalg + tensors -> vector + memref GPU pipelines.
Syntax:
#iree_gpu.pipeline_options<
BoolAttr, # prefetch_shared_memory
BoolAttr, # no_reduce_shared_memory_bank_conflicts
BoolAttr, # use_igemm_convolution
ReorderWorkgroupsStrategyAttr # reorder_workgroups_strategy
>
This attribute describes lowering pipeline specific configuration options:

* prefetch_shared_memory: Boolean option indicating whether or not to run the loop prefetching pass in the lowering pipeline.
* no_reduce_shared_memory_bank_conflicts: Boolean option indicating whether or not to skip the bank conflict reduction pass in the lowering pipeline.
* use_igemm_convolution: Boolean option indicating whether or not to lower convolutions with the implicit GEMM (igemm) strategy.
* reorder_workgroups_strategy: Enum attribute indicating which strategy to choose for the workgroup reordering pass. Options are None and Transpose.
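For example, a fully specified instance could look like the following sketch (values are illustrative, and the exact printed form of the strategy enum may differ):

#iree_gpu.pipeline_options<prefetch_shared_memory = true,
                           no_reduce_shared_memory_bank_conflicts = false,
                           use_igemm_convolution = false,
                           reorder_workgroups_strategy = Transpose>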
Parameters:
Parameter | C++ type | Description
---|---|---
prefetch_shared_memory | BoolAttr |
no_reduce_shared_memory_bank_conflicts | BoolAttr |
use_igemm_convolution | BoolAttr |
reorder_workgroups_strategy | ReorderWorkgroupsStrategyAttr |
LaneIdAttr
Syntax:
#iree_gpu.lane_id<
int64_t # dim
>
An attribute for mapping scf.forall ops to subgroup lanes.
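For example, a forall op distributed over the 64 lanes of a subgroup (a sketch; the body is elided):

scf.forall (%lane) in (64) {
  ...
} {mapping = [#iree_gpu.lane_id<0>]}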
Parameters:
Parameter | C++ type | Description
---|---|---
dim | int64_t |
LoweringConfigAttr
Drive lowering of an operation for gpu compilation.
Syntax:
#iree_gpu.lowering_config<
DictionaryAttr # attributes
>
GPU specific implementation of a lowering config. This carries just a dictionary attribute to store any relevant fields. This is the simplest form of a lowering config, offering flexibility at the cost of structure.
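An illustrative instance carrying workgroup and reduction tile sizes (the field names follow the TilingLevel enum below; the values are examples):

#iree_gpu.lowering_config<{workgroup = [64, 64, 0], reduction = [0, 0, 4]}>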
Parameters:
Parameter | C++ type | Description
---|---|---
attributes | DictionaryAttr | The configured fields, including tiling levels
MMAAttr
Syntax:
#iree_gpu.mma_layout<
::mlir::iree_compiler::IREE::GPU::MMAIntrinsic, # intrinsic
bool # col_major
>
Attribute describing a particular shape of matrix-multiply and accumulate instruction. Abstractly, all attributes of this type represent the following unit of arithmetic for matrices A, B, and C.
C += A x B
The |intrinsic| field specifies which particular MMA intrinsic this refers to, with each intrinsic implicating a specific MNK shape and operand types. See IREEGPUEnums.td for the definition of the intrinsics.
If set to true, |col_major| indicates that the result should be produced column major. This is equivalent to instead computing:
C^T += B^T x A^T
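For example, an attribute selecting the 16x16x16 f16 MFMA intrinsic with f32 accumulation (one of the intrinsics listed in the MMAIntrinsic enum below):

#iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>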
Parameters:
Parameter | C++ type | Description
---|---|---
intrinsic | ::mlir::iree_compiler::IREE::GPU::MMAIntrinsic | an enum of type MMAIntrinsic
col_major | bool |
MMAIntrinsicAttr
Descriptor for different MMA intrinsics
Syntax:
#iree_gpu.mma_intrinsic<
  ::mlir::iree_compiler::IREE::GPU::MMAIntrinsic # value
>
Parameters:
Parameter | C++ type | Description
---|---|---
value | ::mlir::iree_compiler::IREE::GPU::MMAIntrinsic | an enum of type MMAIntrinsic
MMAOpsArrayAttr
Syntax:
#iree_gpu.mma_ops<
  ::llvm::ArrayRef<MMAAttr> # value
>
Parameters:
Parameter | C++ type | Description
---|---|---
value | ::llvm::ArrayRef<MMAAttr> |
MMAScheduleAttr
Syntax:
#iree_gpu.mma_schedule<
::mlir::iree_compiler::IREE::GPU::MmaInterfaceAttr, # intrinsic
int64_t, # subgroup_m_count
int64_t # subgroup_n_count
>
A schedule of MMA intrinsic instructions and various levels of tile sizes used to solve a specific contraction problem.
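An illustrative schedule distributing an MFMA intrinsic across a 2x2 grid of subgroups (values are examples):

#iree_gpu.mma_schedule<intrinsic = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>,
                       subgroup_m_count = 2, subgroup_n_count = 2>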
Parameters:
Parameter | C++ type | Description
---|---|---
intrinsic | ::mlir::iree_compiler::IREE::GPU::MmaInterfaceAttr |
subgroup_m_count | int64_t |
subgroup_n_count | int64_t |
ReorderWorkgroupsStrategyAttr
Strategy for workgroup reordering
Syntax:
#iree_gpu.reorder_workgroups_strategy<
  ::mlir::iree_compiler::IREE::GPU::ReorderWorkgroupsStrategy # value
>
Parameters:
Parameter | C++ type | Description
---|---|---
value | ::mlir::iree_compiler::IREE::GPU::ReorderWorkgroupsStrategy | an enum of type ReorderWorkgroupsStrategy
StorageBitwidthsAttr
Supported bitwidths for storage
Syntax:
#iree_gpu.storage_bitwidths<
  ::mlir::iree_compiler::IREE::GPU::StorageBitwidths # value
>
Parameters:
Parameter | C++ type | Description
---|---|---
value | ::mlir::iree_compiler::IREE::GPU::StorageBitwidths | an enum of type StorageBitwidths
SubgroupOpsAttr
Supported subgroup ops
Syntax:
#iree_gpu.subgroup_ops<
  ::mlir::iree_compiler::IREE::GPU::SubgroupOps # value
>
Parameters:
Parameter | C++ type | Description
---|---|---
value | ::mlir::iree_compiler::IREE::GPU::SubgroupOps | an enum of type SubgroupOps
TargetAttr
Full GPU target attribute.
Syntax:
#iree_gpu.target<
::llvm::StringRef, # arch
::llvm::StringRef, # features
TargetWgpAttr, # wgp
TargetChipAttr # chip
>
This attribute describes a full GPU target. It contains a few fields:
* The canonical target architecture for compilation, e.g., sm_80 for CUDA or gfx942 for HIP.
* A TargetWgpAttr describing the GPU features and limits in a single GPU workgroup processor (WGP), that is, an AMD compute unit or an NVIDIA streaming multiprocessor.
* An optional TargetChipAttr describing GPU features for the final chip or product, e.g., WGP count.
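An abbreviated example for an AMD gfx942 target (field values are illustrative and some are elided):

#iree_gpu.target<arch = "gfx942", features = "",
  wgp = <compute = fp64|fp32|fp16|int64|int32|int16|int8,
         storage = b64|b32|b16|b8,
         subgroup = shuffle|arithmetic, dot = dp4xi8toi32,
         mma = [<MFMA_F32_16x16x16_F16>, <MFMA_F32_32x32x8_F16>],
         subgroup_size_choices = [64],
         max_workgroup_sizes = [1024, 1024, 1024],
         max_thread_count_per_workgroup = 1024,
         max_workgroup_memory_bytes = 65536,
         max_workgroup_counts = [2147483647, 2147483647, 2147483647]>,
  chip = <wgp_count = 304>>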
Parameters:
Parameter | C++ type | Description
---|---|---
arch | ::llvm::StringRef | target architecture
features | ::llvm::StringRef | target features
wgp | TargetWgpAttr |
chip | TargetChipAttr |
TargetChipAttr
Chip level target description.
Syntax:
#iree_gpu.target_chip<
  uint32_t, # wgp_count
  StringAttr, # sku
  DictionaryAttr # extra
>
This attribute contains hardware features/limits at a single GPU chip level. Here a GPU chip means the hardware functionality scope onto which the whole software compute grid is scheduled. A chip typically contains many AMD compute units or NVIDIA streaming multiprocessors; it's the final SKU.
Parameters:
Parameter | C++ type | Description
---|---|---
wgp_count | uint32_t |
sku | StringAttr |
extra | DictionaryAttr |
TargetWgpAttr
Workgroup processor level target description.
Syntax:
#iree_gpu.target_wgp<
  ComputeBitwidthsAttr, # compute
  StorageBitwidthsAttr, # storage
  SubgroupOpsAttr, # subgroup
  DotProductOpsAttr, # dot
  MMAOpsArrayAttr, # mma
  DenseI32ArrayAttr, # subgroup_size_choices
  DenseI32ArrayAttr, # max_workgroup_sizes
  int32_t, # max_thread_count_per_workgroup
  int32_t, # max_workgroup_memory_bytes
  DenseI32ArrayAttr, # max_workgroup_counts
  std::optional<int32_t>, # max_load_instruction_bits
  std::optional<int32_t>, # simds_per_wgp
  std::optional<int32_t>, # vgpr_space_bits
  DictionaryAttr # extra
>
This attribute contains hardware features/limits at a single GPU workgroup processor (WGP) level. Here a GPU workgroup processor means the basic hardware functionality unit onto which a software workgroup is scheduled; that is, a compute unit for AMD GPUs or a streaming multiprocessor for NVIDIA GPUs.
Parameters:
Parameter | C++ type | Description
---|---|---
compute | ComputeBitwidthsAttr |
storage | StorageBitwidthsAttr |
subgroup | SubgroupOpsAttr |
dot | DotProductOpsAttr |
mma | MMAOpsArrayAttr |
subgroup_size_choices | DenseI32ArrayAttr |
max_workgroup_sizes | DenseI32ArrayAttr |
max_thread_count_per_workgroup | int32_t |
max_workgroup_memory_bytes | int32_t |
max_workgroup_counts | DenseI32ArrayAttr |
max_load_instruction_bits | std::optional<int32_t> |
simds_per_wgp | std::optional<int32_t> |
vgpr_space_bits | std::optional<int32_t> |
extra | DictionaryAttr |
UKernelConfigAttr
An attribute specifying a ukernel that an op can lower to.
Syntax:
#iree_gpu.ukernel_config<
StringAttr, # name
DictionaryAttr, # def_attrs
int64_t # shared_memory_bytes
>
An attribute that can be applied to any operation to specify that it has been matched with a ukernel that is a legal lowering for it.
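An illustrative instance (the ukernel name and definition attributes are example values, and optional fields are elided):

#iree_gpu.ukernel_config<name = "iree_uk_amdgpu_argmax_f32i64",
                         def_attrs = {vm.import.module = "rocm"}>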
Parameters:
Parameter | C++ type | Description
---|---|---
name | StringAttr |
def_attrs | DictionaryAttr |
shared_memory_bytes | int64_t | Size in bytes of shared memory workspace
UseGlobalLoadDMAAttr
Drive lowering of an operation by using global load DMA.
Syntax: #iree_gpu.use_global_load_dma
Lowering config for when using global load DMA is needed. This is intended for tagging operations that are known to be able to use global load DMA, which might also have their own configuration.
VirtualMMAAttr
Syntax:
#iree_gpu.virtual_mma_layout<
::mlir::iree_compiler::IREE::GPU::VirtualMMAIntrinsic # intrinsic
>
This MMA variant represents "virtual" MMA ops whose native layouts are modified by unrolling along K (intrinsicsK) and/or by interleaved reads. The |intrinsic| field selects among the kinds of virtual MMA ops we have found helpful.
These interleaving and/or unrolling changes to the layout are especially useful for coalescing reads from shared memory to registers, or for aligning layouts in a chained-matmul operation.
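An illustrative instance selecting one of the virtual intrinsics listed in the VirtualMMAIntrinsic enum below (a sketch; the exact printed form mirrors MMAAttr):

#iree_gpu.virtual_mma_layout<VMFMA_F32_16x16x32_F16>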
Parameters:
Parameter | C++ type | Description
---|---|---
intrinsic | ::mlir::iree_compiler::IREE::GPU::VirtualMMAIntrinsic | an enum of type VirtualMMAIntrinsic
VirtualMMAIntrinsicAttr
Descriptor for different Virtual MMA intrinsics
Syntax:
#iree_gpu.virtual_mma_intrinsic<
  ::mlir::iree_compiler::IREE::GPU::VirtualMMAIntrinsic # value
>
Parameters:
Parameter | C++ type | Description
---|---|---
value | ::mlir::iree_compiler::IREE::GPU::VirtualMMAIntrinsic | an enum of type VirtualMMAIntrinsic
Enums
ComputeBitwidths
Supported bitwidths for compute
Cases:
Symbol | Value | String
---|---|---
FP64 | 1 | fp64
FP32 | 2 | fp32
FP16 | 4 | fp16
Int64 | 8 | int64
Int32 | 16 | int32
Int16 | 32 | int16
Int8 | 64 | int8
DotProductOps
Supported dot product ops
Cases:
Symbol | Value | String
---|---|---
None | 0 | none
DP4xI8ToI32 | 1 | dp4xi8toi32
MMAFragment
Descriptor for a particular fragment of an MMA operation
Cases:
Symbol | Value | String
---|---|---
Lhs | 0 | Lhs
Rhs | 1 | Rhs
Acc | 2 | Acc
MMAIntrinsic
Descriptor for different MMA intrinsics
Cases:
Symbol | Value | String
---|---|---
MFMA_F32_16x16x4_F32 | 4112 | MFMA_F32_16x16x4_F32
MFMA_F32_16x16x16_F16 | 4128 | MFMA_F32_16x16x16_F16
MFMA_F32_32x32x8_F16 | 4129 | MFMA_F32_32x32x8_F16
MFMA_I32_16x16x16_I8 | 4288 | MFMA_I32_16x16x16_I8
MFMA_I32_32x32x8_I8 | 4289 | MFMA_I32_32x32x8_I8
MFMA_F32_16x16x8_BF16 | 4384 | MFMA_F32_16x16x8_BF16
MFMA_F32_32x32x4_BF16 | 4385 | MFMA_F32_32x32x4_BF16
MFMA_F64_16x16x4_F64 | 4352 | MFMA_F64_16x16x4_F64
MFMA_F32_16x16x16_BF16 | 4640 | MFMA_F32_16x16x16_BF16
MFMA_F32_32x32x8_BF16 | 4641 | MFMA_F32_32x32x8_BF16
MFMA_F32_16x16x32_F8E5M2FNUZ | 4656 | MFMA_F32_16x16x32_F8E5M2FNUZ
MFMA_F32_16x16x32_F8E5M2FNUZ_F8E4M3FNUZ | 4657 | MFMA_F32_16x16x32_F8E5M2FNUZ_F8E4M3FNUZ
MFMA_F32_16x16x32_F8E4M3FNUZ | 4658 | MFMA_F32_16x16x32_F8E4M3FNUZ
MFMA_F32_16x16x32_F8E4M3FNUZ_F8E5M2FNUZ | 4659 | MFMA_F32_16x16x32_F8E4M3FNUZ_F8E5M2FNUZ
MFMA_F32_32x32x16_F8E5M2FNUZ | 4660 | MFMA_F32_32x32x16_F8E5M2FNUZ
MFMA_F32_32x32x16_F8E5M2FNUZ_F8E4M3FNUZ | 4661 | MFMA_F32_32x32x16_F8E5M2FNUZ_F8E4M3FNUZ
MFMA_F32_32x32x16_F8E4M3FNUZ | 4662 | MFMA_F32_32x32x16_F8E4M3FNUZ
MFMA_F32_32x32x16_F8E4M3FNUZ_F8E5M2FNUZ | 4663 | MFMA_F32_32x32x16_F8E4M3FNUZ_F8E5M2FNUZ
MFMA_I32_16x16x32_I8 | 4800 | MFMA_I32_16x16x32_I8
MFMA_I32_32x32x16_I8 | 4801 | MFMA_I32_32x32x16_I8
MFMA_F32_16x16x32_F16 | 4896 | MFMA_F32_16x16x32_F16
MFMA_F32_32x32x16_F16 | 4897 | MFMA_F32_32x32x16_F16
MFMA_F32_16x16x32_BF16 | 4898 | MFMA_F32_16x16x32_BF16
MFMA_F32_32x32x16_BF16 | 4899 | MFMA_F32_32x32x16_BF16
MFMA_F32_16x16x32_F8E5M2 | 4912 | MFMA_F32_16x16x32_F8E5M2
MFMA_F32_16x16x32_F8E5M2_F8E4M3FN | 4913 | MFMA_F32_16x16x32_F8E5M2_F8E4M3FN
MFMA_F32_16x16x32_F8E4M3FN | 4914 | MFMA_F32_16x16x32_F8E4M3FN
MFMA_F32_16x16x32_F8E4M3FN_F8E5M2 | 4915 | MFMA_F32_16x16x32_F8E4M3FN_F8E5M2
MFMA_F32_32x32x16_F8E5M2 | 4916 | MFMA_F32_32x32x16_F8E5M2
MFMA_F32_32x32x16_F8E5M2_F8E4M3FN | 4917 | MFMA_F32_32x32x16_F8E5M2_F8E4M3FN
MFMA_F32_32x32x16_F8E4M3FN | 4918 | MFMA_F32_32x32x16_F8E4M3FN
MFMA_F32_32x32x16_F8E4M3FN_F8E5M2 | 4919 | MFMA_F32_32x32x16_F8E4M3FN_F8E5M2
MFMA_F32_16x16x128_F8E5M2 | 4920 | MFMA_F32_16x16x128_F8E5M2
MFMA_F32_16x16x128_F8E5M2_F8E4M3FN | 4921 | MFMA_F32_16x16x128_F8E5M2_F8E4M3FN
MFMA_F32_16x16x128_F8E4M3FN | 4922 | MFMA_F32_16x16x128_F8E4M3FN
MFMA_F32_16x16x128_F8E4M3FN_F8E5M2 | 4923 | MFMA_F32_16x16x128_F8E4M3FN_F8E5M2
MFMA_F32_32x32x64_F8E5M2 | 4924 | MFMA_F32_32x32x64_F8E5M2
MFMA_F32_32x32x64_F8E5M2_F8E4M3FN | 4925 | MFMA_F32_32x32x64_F8E5M2_F8E4M3FN
MFMA_F32_32x32x64_F8E4M3FN | 4926 | MFMA_F32_32x32x64_F8E4M3FN
MFMA_F32_32x32x64_F8E4M3FN_F8E5M2 | 4927 | MFMA_F32_32x32x64_F8E4M3FN_F8E5M2
MFMA_I32_16x16x64_I8 | 5056 | MFMA_I32_16x16x64_I8
MFMA_I32_32x32x32_I8 | 5057 | MFMA_I32_32x32x32_I8
WMMAR3_F32_16x16x16_F16 | 6176 | WMMAR3_F32_16x16x16_F16
WMMAR3_F16_16x16x16_F16 | 6177 | WMMAR3_F16_16x16x16_F16
WMMAR3_F32_16x16x16_BF16 | 6178 | WMMAR3_F32_16x16x16_BF16
WMMAR3_BF16_16x16x16_BF16 | 6179 | WMMAR3_BF16_16x16x16_BF16
WMMAR3_I32_16x16x16_I8 | 6336 | WMMAR3_I32_16x16x16_I8
WMMAR4_F32_16x16x16_F16 | 6432 | WMMAR4_F32_16x16x16_F16
WMMAR4_F16_16x16x16_F16 | 6433 | WMMAR4_F16_16x16x16_F16
WMMAR4_F32_16x16x16_BF16 | 6434 | WMMAR4_F32_16x16x16_BF16
WMMAR4_BF16_16x16x16_BF16 | 6435 | WMMAR4_BF16_16x16x16_BF16
WMMAR4_F32_16x16x16_F8E5M2 | 6448 | WMMAR4_F32_16x16x16_F8E5M2
WMMAR4_F32_16x16x16_F8E5M2_F8E4M3FN | 6449 | WMMAR4_F32_16x16x16_F8E5M2_F8E4M3FN
WMMAR4_F32_16x16x16_F8E4M3FN | 6450 | WMMAR4_F32_16x16x16_F8E4M3FN
WMMAR4_F32_16x16x16_F8E4M3FN_F8E5M2 | 6451 | WMMAR4_F32_16x16x16_F8E4M3FN_F8E5M2
WMMAR4_I32_16x16x16_I8 | 6592 | WMMAR4_I32_16x16x16_I8
NV_WMMA_F32_16x16x16_F16 | 8224 | NV_WMMA_F32_16x16x16_F16
NV_WMMA_F16_16x16x16_F16 | 8225 | NV_WMMA_F16_16x16x16_F16
ReorderWorkgroupsStrategy
Strategy for workgroup reordering
Cases:
Symbol | Value | String
---|---|---
None | 0 | None
Transpose | 1 | Transpose
StorageBitwidths
Supported bitwidths for storage
Cases:
Symbol | Value | String
---|---|---
B64 | 1 | b64
B32 | 2 | b32
B16 | 4 | b16
B8 | 8 | b8
SubgroupOps
Supported subgroup ops
Cases:
Symbol | Value | String
---|---|---
None | 0 | none
Shuffle | 1 | shuffle
Arithmetic | 2 | arithmetic
TilingLevel
Descriptor for tiling levels for GPU lowering configs
Cases:
Symbol | Value | String
---|---|---
Workgroup | 0 | Workgroup
Reduction | 1 | Reduction
PartialReduction | 2 | PartialReduction
Thread | 3 | Thread
Subgroup | 4 | Subgroup
Lane | 5 | Lane
VirtualMMAIntrinsic
Descriptor for different Virtual MMA intrinsics
Cases:
Symbol | Value | String
---|---|---
VMFMA_F32_16x16x32_F16 | 0 | VMFMA_F32_16x16x32_F16
VMFMA_F32_32x32x16_F16 | 1 | VMFMA_F32_32x32x16_F16
VMFMA_F32_16x16x32_F8E4M3FNUZ | 2 | VMFMA_F32_16x16x32_F8E4M3FNUZ
VMFMA_F32_32x32x16_F8E4M3FNUZ | 3 | VMFMA_F32_32x32x16_F8E4M3FNUZ
IteratorType
Iterator type
Cases:
Symbol | Value | String
---|---|---
parallel | 0 | parallel
reduction | 1 | reduction