'iree_codegen' Dialect

A dialect for common functionality used by IREE code generation.

This dialect is primarily meant to hold attributes that carry the state of the compilation as it is lowered to scalar code for an architecture. Typically, a backend starts by analyzing the entry point functions within the hal.executable.variant and deciding which compilation pipeline to choose. During this step, values for parameters such as tile sizes are also decided. The rest of the compilation flow makes no heuristic decisions; it simply reads the decisions recorded in attributes that belong to this dialect. This allows an external search to easily override the heuristics that are hard-coded within a backend.

Operations

iree_codegen.extract_strided_metadata (Codegen::ExtractStridedMetadataOp)

Extracts a buffer base with offset and strides.

Syntax:

operation ::= `iree_codegen.extract_strided_metadata` $source `:` type($source) `->` type(results) attr-dict

This op is implemented similarly to the upstream MemRef::ExtractStridedMetadataOp with the following differences.

  1. It does not fold away static offset/stride information. Hence unlike the upstream Op the link between the memref and consumers of the metadata is not broken when later passes change this information. A common example in IREE of this is buffer binding optimizations.

  2. Helper functions getConstifiedMixed{Offset|Strides|Sizes} are not implemented; the expectation is that you lower to the upstream op before using those functions if you need them.

Copy of MemRef::ExtractStridedMetadataOp description for reference below. Extracts a base buffer, offset and strides. This op allows additional layers of transformations and foldings to be added as lowering progresses from higher-level dialect to lower-level dialects such as the LLVM dialect.

The op requires a strided memref source operand. If the source operand is not a strided memref, then verification fails.

This operation also complements the existing memref.dim op: while there is otherwise no way to access the strides, offset, and base pointer independently, it composes naturally with its complement op, memref.reinterpret_cast.
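
For example, following the declared syntax above, extracting the metadata of a statically shaped 2-D memref might look like the sketch below (the %src value is a hypothetical input):

%base, %offset, %sizes:2, %strides:2 = iree_codegen.extract_strided_metadata %src
  : memref<8x16xf32> -> memref<f32>, index, index, index, index, index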

Intended Use Cases:

The main use case is to expose the logic for manipulating memref metadata at a higher level than the LLVM dialect. This makes lowering more progressive and brings the following benefits:

  - Not all users of MLIR want to lower to LLVM, and the information needed to lower to, e.g., library calls (like libxsmm) or to SPIR-V was not otherwise available.
  - Foldings and canonicalizations can happen at a higher level in MLIR: before this op existed, lowering to LLVM would create large amounts of LLVM IR. Even when LLVM does a good job of folding the low-level IR from a performance perspective, it is unnecessarily opaque and inefficient to send unkempt IR to LLVM.

Traits: AlwaysSpeculatableImplTrait, InferTypeOpAdaptor, SameVariadicResultSize

Interfaces: ConditionallySpeculatable, InferTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), OpAsmOpInterface, ViewLikeOpInterface

Effects: MemoryEffects::Effect{}

Operands:
Operand Description
source strided memref of any type values
Results:
Result Description
base_buffer strided memref of any type values of rank 0
offset index
sizes variadic of index
strides variadic of index

iree_codegen.inner_tiled (Codegen::InnerTiledOp)

Models an inner-tiled operation that may perform contractions.

Syntax:

operation ::= `iree_codegen.inner_tiled` `ins` `(` $inputs `)` `outs` `(` $outputs `)` attr-dict
              `:` type($inputs) `into` type($outputs)

Represents an operation that updates "inner tiles" (vector slices) of its accumulator outputs using inner tiles of its input operands. The way the inner tiles are used is determined by the semantics of the kind attribute. These inner tiles are the trailing dimensions of the tensor/vector operands that are not present in the indexing maps.

In the case of a matrix-multiply-accumulate (MMA) inner tiled operation, the semantics logically match vector.contraction. However, instead of a combiner type, it has an intrinsic description that specifies how the inner tiles are combined.

Similar to vector.contract, an iterator type attribute list must be specified, where each element of the list represents an iterator over one of the outer dimensions. Iteration of inner dimensions is defined solely by the intrinsic and may be opaque.

An indexing map attribute list must be specified with an entry for each input and output argument. An indexing map attribute specifies a mapping from each outer loop iterator in the iterator type list to each dimension of each operand.

The combiner type is defined by the intrinsic.

Example:

#contraction_accesses = [
 affine_map<(i, j, k) -> (i, k)>,
 affine_map<(i, j, k) -> (k, j)>,
 affine_map<(i, j, k) -> (i, j)>
]
#contraction_trait = {
  indexing_maps = #contraction_accesses,
  iterator_types = ["parallel", "parallel", "reduction"],
  kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>
}
%3 = iree_codegen.inner_tiled ins(%0, %1) outs(%2) #contraction_trait
  : vector<2x3x4xf16>, vector<3x5x4xf16> into vector<2x5x4xf32>

// Takes tensors as well, however the inner dimensions must always be
// static.
%7 = iree_codegen.inner_tiled ins(%4, %5) outs(%6) #contraction_trait
  : tensor<?x?x4xf16>, tensor<?x?x4xf16> into tensor<?x?x4xf32>

The example above can be logically lowered directly to loops like this (ignoring type conversions from tensor to vector needed for the mfma).

%outer_m = tensor.dim %6, %c0 : tensor<?x?x4xf32>
%outer_n = tensor.dim %6, %c1 : tensor<?x?x4xf32>
%outer_k = tensor.dim %4, %c1 : tensor<?x?x4xf16>
%7 = scf.for %i = %c0 to %outer_m step %c1 iter_args(%arg0 = %6) -> (tensor<?x?x4xf32>) {
  %8 = scf.for %j = %c0 to %outer_n step %c1 iter_args(%arg1 = %arg0) -> (tensor<?x?x4xf32>) {
    %9 = scf.for %k = %c0 to %outer_k step %c1 iter_args(%arg2 = %arg1) -> (tensor<?x?x4xf32>) {
      %lhs = tensor.extract_slice %4[%i, %k, 0] [1, 1, 4] [1, 1, 1] : tensor<?x?x4xf16> to tensor<4xf16>
      %rhs = tensor.extract_slice %5[%k, %j, 0] [1, 1, 4] [1, 1, 1] : tensor<?x?x4xf16> to tensor<4xf16>
      %acc = tensor.extract_slice %arg2[%i, %j, 0] [1, 1, 4] [1, 1, 1] : tensor<?x?x4xf32> to tensor<4xf32>
      // Logical stand-in: the real amdgpu.mfma operates on vectors, not tensors.
      %res = amdgpu.mfma %lhs, %rhs, %acc : tensor<4xf32>
      %ret = tensor.insert_slice %res into %arg2[%i, %j, 0] [1, 1, 4] [1, 1, 1] : tensor<4xf32> into tensor<?x?x4xf32>
      scf.yield %ret : tensor<?x?x4xf32>
    }
    scf.yield %9 : tensor<?x?x4xf32>
  }
  scf.yield %8 : tensor<?x?x4xf32>
}

Or, alternatively, unrolled to a single intrinsic when operating on vectors:

#contraction_accesses = [
 affine_map<() -> ()>,
 affine_map<() -> ()>,
 affine_map<() -> ()>
]
#contraction_trait = {
  indexing_maps = #contraction_accesses,
  iterator_types = [],
  kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>
}
%3 = iree_codegen.inner_tiled ins(%0, %1) outs(%2) #contraction_trait
  : vector<4xf16>, vector<4xf16> into vector<4xf32>

This operation can represent an intrinsic at both the undistributed (workgroup/subgroup/warp) and distributed (thread) levels. The descriptor attribute specifies the inner tile sizes for both the undistributed form (where the operands represent the data to be processed by an entire group of parallel workers) and the distributed form (where the operands represent the data processed by a single workitem/thread).

In some cases, variations on the inner tiled operations can be expressed with the permutations attribute. This attribute represents the permutation from that intrinsic's "canonical" layout (in the case of matrix multiplication, this is row-major storage of the inner tile) to the format of the inner tile in the arguments, with a permutation specified for each argument.

Since the canonical dimensionality of the inner dimensions is somewhat intrinsic-specific, verification of this op requires only that the element counts of the inner dimensions match the intrinsic.

For example, an MMT product of inner dimensions with warp semantics can be represented with the following. Permutations are only allowed for ops with undistributed semantics and must be resolved before distribution.

#contraction_accesses = [
 affine_map<(i, j, k) -> (i, k)>,
 affine_map<(i, j, k) -> (k, j)>,
 affine_map<(i, j, k) -> (i, j)>
]
#contraction_trait = {
  indexing_maps = #contraction_accesses,
  iterator_types = ["parallel", "parallel", "reduction"],
  kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>,
  permutations = [[0, 1], [1, 0], [0, 1]]
}
%7 = iree_codegen.inner_tiled ins(%4, %5) outs(%6) #contraction_trait
  : tensor<?x?x16x16xf16>, tensor<?x?x16x16xf16> into tensor<?x?x16x16xf32>

Motivation, Design Choices, and Pitfalls

This operation grew out of a general representation for matrix multiplication intrinsics on GPUs, where the inner tiles would be the tiles of the A, B, and C matrices computed by an entire GPU workgroup or subgroup. It is now used for generalizations of such multiplications. Currently, the only such usage is scaled matrix-multiply-accumulate, where block scales must be passed in as additional inputs, but more uses may be added in the future.

The idea behind this operation is to decouple the layout setting/tiling required to target certain intrinsics from the lowering to them. Tiling of this sort typically happens on tensor operands, while the target intrinsics operate on vectors, so this operation bridges the gap. The choice of a shared operation is intended to ease the lowering process and allow for different transformations at different stages of the pipeline without needing to essentially clone this op.

The choice to let the inner dimensions required to compute the intrinsic be implicit, based on the indexing maps, was made to make this operation easier to generate and to avoid the need for type conversion ops. However, this comes at the expense of ease of verification for the operation. It is also implicitly linked to a lane-level parent scf.forall operation.

Traits: AlwaysSpeculatableImplTrait, AttrSizedOperandSegments, InferTypeOpAdaptor

Interfaces: ConditionallySpeculatable, DestinationStyleOpInterface, InferTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), TilingInterface, VectorUnrollOpInterface

Effects: MemoryEffects::Effect{}

Attributes:
Attribute       MLIR Type                                   Description
indexing_maps   ::mlir::ArrayAttr                           indexing affine maps
iterator_types  ::mlir::ArrayAttr                           Iterator type should be an enum.
kind            IREE::Codegen::InnerTileDescAttrInterface   buffer-like constant attribute values
permutations    ::mlir::ArrayAttr                           permutations
Operands:
Operand Description
inputs variadic of ranked tensor or vector of any type values
outputs variadic of ranked tensor or vector of any type values
Results:
Result Description
results variadic of ranked tensor or vector of any type values

iree_codegen.load_from_buffer (Codegen::LoadFromBufferOp)

Loads a tensor from a memref.

Syntax:

operation ::= `iree_codegen.load_from_buffer` $buffer attr-dict `:` type($buffer) `->` type($tensor)

Loads a tensor from a memref with a compatible shape and the same element type.
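
For example, a minimal sketch following the declared syntax (the %buffer value and its static shape are assumptions):

%tensor = iree_codegen.load_from_buffer %buffer : memref<8x16xf32> -> tensor<8x16xf32>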

Interfaces: MemoryEffectOpInterface, ReifyRankedShapedTypeOpInterface

Operands:
Operand Description
buffer strided memref of any type values
Results:
Result Description
tensor ranked tensor of any type values

iree_codegen.null_pointer (Codegen::NullPointerOp)

Returns a null_pointer value.

Syntax:

operation ::= `iree_codegen.null_pointer` attr-dict

This is meant to be used only as arguments to microkernels.
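
For example, producing a null-pointer value to forward to a microkernel call (a sketch; the consuming call is omitted, and the result type is inferred):

%null = iree_codegen.null_pointer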

Traits: AlwaysSpeculatableImplTrait

Interfaces: ConditionallySpeculatable, InferTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface)

Effects: MemoryEffects::Effect{}

Results:
Result Description
result Pseudo null-pointer type. Lowers to a null pointer.

iree_codegen.query_tile_sizes (Codegen::QueryTileSizesOp)

Yields tile sizes for the specified tensor type.

Syntax:

operation ::= `iree_codegen.query_tile_sizes` attr-dict $tensor_type `->` type($results)

For targets where tile sizes can't be resolved at compile time, this operation allows querying the sizes at runtime. Today this only applies to VMVX.
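
A sketch of the assembly for a rank-2 query (the queried tensor type, which in practice typically carries an encoding, is simplified here):

%tile_sizes:2 = iree_codegen.query_tile_sizes tensor<?x?xf32> -> index, index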

Traits: AlwaysSpeculatableImplTrait

Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface)

Effects: MemoryEffects::Effect{}

Attributes:
Attribute    MLIR Type         Description
tensor_type  ::mlir::TypeAttr  Tensor type attribute
Results:
Result Description
results variadic of index

iree_codegen.store_to_buffer (Codegen::StoreToBufferOp)

Stores a tensor into a memref.

Syntax:

operation ::= `iree_codegen.store_to_buffer` $tensor `,` $buffer
              attr-dict `:` type($tensor) `into` type($buffer)

Stores a tensor into a memref with a compatible shape and the same element type.
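
For example, a minimal sketch following the declared syntax (the values and the static shape are assumptions):

iree_codegen.store_to_buffer %tensor, %buffer : tensor<8x16xf32> into memref<8x16xf32>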

Interfaces: MemoryEffectOpInterface

Operands:
Operand Description
tensor ranked tensor of any type values
buffer strided memref of any type values

iree_codegen.swizzle_hint (Codegen::SwizzleHintOp)

Hint to swizzle accesses according to an access pattern.

Syntax:

operation ::= `iree_codegen.swizzle_hint` $operand `[` $swizzle attr-dict `]` `:` type($result)

Optimization hint to swizzle all accesses to the memref this takes a view of. This only affects reads/writes immediately consuming this operation and is best effort. If the desired swizzling is not apparently possible, this op will no-op. As a result, it should not be relied on for correctness.

Any subviews on this operation will cause swizzling to fail. The expectation is for all view like operations to fold into the accessing ops (loads/stores) before this op takes effect.
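
For example, a sketch that hints a rotate-rows swizzle (see RotateRowsAttr below) on a 1-D allocation; the allocation shape is an assumption:

%0 = memref.alloc() : memref<4096xf32>
%1 = iree_codegen.swizzle_hint %0[#iree_codegen.rotate_rows<16, 4>] : memref<4096xf32>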

Note that this only rewrites direct users. If there are any aliased loads or stores of the data from/to the |src| memref of a hint op, those accesses will not be swizzled. This allows reusing an allocation with different swizzled access patterns, as long as there is no data dependency between memory with different layouts. For example:

%0 = alloc()
%1 = iree_codegen.swizzle_hint %0, #layout_0
%2 = iree_codegen.swizzle_hint %0, #layout_1
{
   vector.store %1
   vector.load %1
     ^
     |
    unrelated
     |
     v
   vector.store %2
   vector.load %2
}

If there is a data dependency between the accesses of %1 and %2, for example a value stored to %1 is loaded from %2, this is undefined behavior. Aliasing is otherwise perfectly legal.

Traits: AlwaysSpeculatableImplTrait, SameOperandsAndResultType

Interfaces: ConditionallySpeculatable, InferTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface)

Effects: MemoryEffects::Effect{}

Attributes:
Attribute  MLIR Type                            Description
swizzle    IREE::Codegen::SwizzleAttrInterface  swizzling descriptor attributes
Operands:
Operand Description
operand 1D memref of any type values
Results:
Result Description
result 1D memref of any type values

Attributes

DispatchLoweringPassPipelineAttr

Identifier for the pass pipeline used to lower a dispatch region

Syntax:

#iree_codegen.<
  ::mlir::iree_compiler::IREE::Codegen::DispatchLoweringPassPipeline   # value
>
Parameters:
Parameter C++ type Description
value ::mlir::iree_compiler::IREE::Codegen::DispatchLoweringPassPipeline an enum of type DispatchLoweringPassPipeline

CompilationInfoAttr

Drive lowering of an operation from input dialect.

Syntax:

#iree_codegen.compilation_info<
  LoweringConfigAttrInterface,   # loweringConfig
  TranslationInfoAttr   # translationInfo
>

Specifies the information that allows controlling the compilation of operations like linalg.matmul/linalg.*conv within IREE. This information is used to override the defaults used by the IREE compiler. If set on the input to the compiler, there is no guarantee that the config survives until codegen. Named operations like linalg.matmul/linalg.*conv* are more likely to retain their lowering configurations.

TODO: It is expected that the TranslationInfoAttr and the LoweringConfigAttr are specified. Currently there is no verification that the values of the LoweringConfigAttr fully specifies the behaviour of the compilation path chosen with TranslationInfoAttr. This could be added in the future.
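
As an illustrative sketch of attaching this attribute to input IR (the specific tile sizes, pipeline name, and exact attribute key spellings are assumptions and may differ across IREE versions):

#config = #iree_codegen.lowering_config<tile_sizes = [[64, 64, 0], [8, 32, 0], [0, 0, 16]]>
#translation = #iree_codegen.translation_info<pipeline = CPUDoubleTilingExpert>
#compilation = #iree_codegen.compilation_info<lowering_config = #config, translation_info = #translation>
%result = linalg.matmul {compilation_info = #compilation}
  ins(%lhs, %rhs : tensor<128x256xf32>, tensor<256x512xf32>)
  outs(%acc : tensor<128x512xf32>) -> tensor<128x512xf32>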

Parameters:
Parameter C++ type Description
loweringConfig LoweringConfigAttrInterface
translationInfo TranslationInfoAttr

EncodingNopLayoutAttr

An attribute with implementation that treats encoding as nop.

Syntax: #iree_codegen.encoding_nop_layout

An attribute whose interface-method implementations discard the encodings. It can serve as a default attribute when a backend does not implement encoding details.

ExportConfigAttr

User defined workgroup size specification.

Syntax:

#iree_codegen.export_config<
  ::llvm::ArrayRef<int64_t>   # workgroup_size
>

Allows setting workgroup size for pre-formed dispatches.
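
For example (a sketch; the size values and the printed key syntax are assumptions):

#iree_codegen.export_config<workgroup_size = [32, 1, 1]>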

Parameters:
Parameter C++ type Description
workgroup_size ::llvm::ArrayRef<int64_t> Workgroup Size to use

LoweringConfigAttr

Drive lowering of an operation within dispatch region.

Syntax:

#iree_codegen.lowering_config<
  LoweringConfigTilingLevelsAttr,   # tilingLevels
  ::llvm::ArrayRef<int64_t>   # nativeVectorSize
>

Default implementation of a lowering configuration attribute. It includes only tiling and, optionally, vectorization information. The interpretation of the tile sizes is backend dependent.

TODO: Currently there is no verification that the configuration specifies everything needed for a pass-pipeline. The values to set for these parameters depend on the pass-pipeline implementation. In the future, each pass pipeline could verify that the lowering configuration has all the necessary attributes for the pipeline.

Parameters:
Parameter C++ type Description
tilingLevels LoweringConfigTilingLevelsAttr The lowering config at different levels
nativeVectorSize ::llvm::ArrayRef<int64_t> The native vector size to use for the given operation

LoweringConfigTilingLevelAttr

Parameters:
Parameter C++ type Description
sizes ::llvm::ArrayRef<int64_t> The tile sizes to use for this level of tiling
interchange ::llvm::ArrayRef<int64_t> The tile interchange to use for this level of tiling
scalableFlags ::llvm::ArrayRef<bool> The scalable tile flags for this level of tiling

LoweringConfigTilingLevelsAttr

Syntax:

#iree_codegen.lowering_config_levels<
  ::llvm::ArrayRef<LoweringConfigTilingLevelAttr>   # value
>
Parameters:
Parameter C++ type Description
value ::llvm::ArrayRef<LoweringConfigTilingLevelAttr>

RotateRowsAttr

An attribute that describes a swizzling pattern for rotating rows.

Syntax:

#iree_codegen.rotate_rows<
  int64_t,   # row_width
  int64_t   # access_width
>

This attribute rotates accesses of |access_width| elements within rows of size |row_width|. Viewing memory as a logical memref of shape memref<...xNx|access_width|x!eltype>, where N = row_width / access_width, an access at position (i, j, 0) is rotated to (i, (i + j) % N, 0). For example,

row_width = 16, access_width = 4

0000 1111 2222 3333 /// 0 1 2 3
4444 5555 6666 7777 /// 0 1 2 3
8888 9999 AAAA BBBB /// 0 1 2 3
CCCC DDDD EEEE FFFF /// 0 1 2 3

is swizzled to

0000 1111 2222 3333 /// 0 1 2 3
7777 4444 5555 6666 /// 3 0 1 2
BBBB AAAA 8888 9999 /// 2 3 0 1
FFFF EEEE DDDD CCCC /// 1 2 3 0

The pattern repeats for subsequent rows.
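
Concretely, with N = 16 / 4 = 4, the access at (1, 3, 0) rotates to (1, (1 + 3) mod 4, 0) = (1, 0, 0), which is why the group 7777 (j = 3) appears first in the second row above.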

Parameters:
Parameter C++ type Description
row_width int64_t
access_width int64_t

TranslationInfoAttr

Drive dispatch entry point lowering.

Syntax:

#iree_codegen.translation_info<
  IREE::Codegen::DispatchLoweringPassPipelineAttr,   # passPipeline
  SymbolRefAttr,   # codegenSpec
  ::llvm::ArrayRef<int64_t>,   # workgroupSize
  int64_t,   # subgroupSize
  DictionaryAttr   # configuration
>

Specifies the information that is used to drive the translation of an entry point function using Linalg based structured-op lowering. During executable translation this is attached to the hal.executable.export operation.

If this operation is already set on the root operation (as part of iree_codegen.compilation_info) that drives the compilation of a dispatch region (like linalg.matmul/linalg.*conv*), this attribute gets propagated to the entry point function.

The fields are described in the parameters below; passPipeline selects the pass pipeline to use.
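
An illustrative sketch (the pipeline name and sizes are placeholders drawn from the enum below, and the exact printed syntax may differ across IREE versions):

#iree_codegen.translation_info<pipeline = LLVMGPUVectorDistribute workgroup_size = [256, 1, 1] subgroup_size = 64>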

Parameters:
Parameter C++ type Description
passPipeline IREE::Codegen::DispatchLoweringPassPipelineAttr Name of the pipeline to be invoked on the translation unit.
codegenSpec SymbolRefAttr The symbol pointing to the transform dialect codegen spec to be used
workgroupSize ::llvm::ArrayRef<int64_t> The workgroup size to use
subgroupSize int64_t The subgroup size to use
configuration DictionaryAttr Pipeline specific configuration

WorkgroupMappingAttr

Syntax:

#iree_codegen.workgroup_mapping<
  ::mlir::iree_compiler::IREE::Codegen::WorkgroupId,   # id
  int64_t   # delinearizedDim
>

Attribute that eventually will be used to map distributed loop iterations to hal.workgroup.ids.

The x,y and z values for id map to hal.workgroup.id[0], hal.workgroup.id[1] and hal.workgroup.id[2] respectively.

In addition it is possible to specify if the z dimension is to be delinearized on mapping. For example if the list of mapping attributes is [workgroup_mapping<z:1>, workgroup_mapping<z:0>], then the z dimension is delinearized to map to workgroup_mapping<z:1> and workgroup_mapping<z:0>. In other words if the number of logical parallel workers along the z:0 dimension is W, then

workgroup_mapping<z:0> = hal.workgroup.id[2] mod W,
workgroup_mapping<z:1> = hal.workgroup.id[2] div W

Note: It is expected that this attribute is always used in a list of mapping attributes (a single element being a list of size 1). It is illegal for a list to have workgroup_mapping<z:a> without workgroup_mapping<z:b> if a > b. In the same way, it is illegal for the list to:

  - have workgroup_mapping<y> but not workgroup_mapping<x>
  - have workgroup_mapping<z:*> but not both workgroup_mapping<x> and workgroup_mapping<y>
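
For example, a sketch of a distributed loop carrying this mapping (the scf.forall bounds and body are assumptions):

scf.forall (%z, %y, %x) in (4, 8, 8) {
  // ... workgroup-level work ...
} {mapping = [#iree_codegen.workgroup_mapping<z>, #iree_codegen.workgroup_mapping<y>, #iree_codegen.workgroup_mapping<x>]}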

Parameters:
Parameter C++ type Description
id ::mlir::iree_compiler::IREE::Codegen::WorkgroupId an enum of type WorkgroupId
delinearizedDim int64_t

Types

NullPointerType

Pseudo null-pointer type. Lowers to a null pointer.

Syntax: !iree_codegen.null_pointer

This is meant to be used only as arguments to microkernels.

Enums

BinaryFn

Allowed 32-bit signless integer cases: 0, 1, 2, 3, 4, 5, 6, 7, 8, 9

Cases:

Symbol Value String
add 0 add
sub 1 sub
mul 2 mul
div 3 div
div_unsigned 4 div_unsigned
max_signed 5 max_signed
min_signed 6 min_signed
max_unsigned 7 max_unsigned
min_unsigned 8 min_unsigned
powf 9 powf

DispatchLoweringPassPipeline

Identifier for the pass pipeline used to lower a dispatch region

Cases:

Symbol Value String
CPUDefault 0 CPUDefault
CPUDoubleTilingExpert 1 CPUDoubleTilingExpert
CPUConvTileAndDecomposeExpert 2 CPUConvTileAndDecomposeExpert
Mmt4dTilingExpert 3 Mmt4dTilingExpert
CPUBufferOpsTileAndVectorize 4 CPUBufferOpsTileAndVectorize
CPUDataTiling 5 CPUDataTiling
CPULinalgExtTileAndVectorize 6 CPULinalgExtTileAndVectorize
LLVMGPUDefault 100 LLVMGPUDefault
LLVMGPUBaseLowering 101 LLVMGPUBaseLowering
LLVMGPUDistribute 102 LLVMGPUDistribute
LLVMGPUVectorize 103 LLVMGPUVectorize
LLVMGPUMatmulTensorCore 104 LLVMGPUMatmulTensorCore
LLVMGPUTransposeSharedMem 105 LLVMGPUTransposeSharedMem
LLVMGPUWarpReduction 106 LLVMGPUWarpReduction
LLVMGPUMatmulTensorCoreMmaSync 107 LLVMGPUMatmulTensorCoreMmaSync
LLVMGPUVectorDistribute 108 LLVMGPUVectorDistribute
LLVMGPUWinogradVectorize 109 LLVMGPUWinogradVectorize
LLVMGPUTileAndFuse 110 LLVMGPUTileAndFuse
SPIRVBaseLowering 200 SPIRVBaseLowering
SPIRVBaseDistribute 201 SPIRVBaseDistribute
SPIRVBaseVectorize 202 SPIRVBaseVectorize
SPIRVSubgroupReduce 203 SPIRVSubgroupReduce
SPIRVMatmulPromoteVectorize 204 SPIRVMatmulPromoteVectorize
SPIRVCooperativeMatrixVectorize 205 SPIRVCooperativeMatrixVectorize
SPIRVWinogradVectorize 206 SPIRVWinogradVectorize
VMVXDefault 300 VMVXDefault
TransformDialectCodegen 1000 TransformDialectCodegen
Custom 1001 Custom
None 65535 None

ElementwiseArityGroup

Allowed 32-bit signless integer cases: 1, 2, 3

Cases:

Symbol Value String
Unary 1 Unary
Binary 2 Binary
Ternary 3 Ternary

ElementwiseCaseLimits

Allowed 32-bit signless integer cases:

Cases:

Symbol Value String
LastUnary 13 LastUnary
LastBinary 23 LastBinary
LastTernary 24 LastTernary

ElementwiseKind

Allowed 32-bit signless integer cases: 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23

Cases:

Symbol Value String
exp 0 exp
log 1 log
abs 2 abs
ceil 3 ceil
floor 4 floor
negf 5 negf
reciprocal 6 reciprocal
round 7 round
sqrt 8 sqrt
rsqrt 9 rsqrt
square 10 square
tanh 11 tanh
erf 12 erf
add 13 add
sub 14 sub
mul 15 mul
div 16 div
div_unsigned 17 div_unsigned
max_signed 18 max_signed
min_signed 19 min_signed
max_unsigned 20 max_unsigned
min_unsigned 21 min_unsigned
powf 22 powf
select 23 select

IteratorType

Iterator type

Cases:

Symbol Value String
parallel 0 parallel
reduction 1 reduction

TernaryFn

Allowed 32-bit signless integer cases: 0

Cases:

Symbol Value String
select 0 select

TypeFn

Allowed 32-bit signless integer cases: 0, 1

Cases:

Symbol Value String
cast_signed 0 cast_signed
cast_unsigned 1 cast_unsigned

UnaryFn

Allowed 32-bit signless integer cases: 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12

Cases:

Symbol Value String
exp 0 exp
log 1 log
abs 2 abs
ceil 3 ceil
floor 4 floor
negf 5 negf
reciprocal 6 reciprocal
round 7 round
sqrt 8 sqrt
rsqrt 9 rsqrt
square 10 square
tanh 11 tanh
erf 12 erf

WorkgroupId

Attribute that maps to hal.workgroup.ids

Cases:

Symbol Value String
IdX 0 x
IdY 1 y
IdZ 2 z