'iree_tensor_ext' Dialect

IREE Tensor Extensions.

A dialect designed for experimenting with tensor operations beyond what is currently available in the Tensor Dialect.

Operations

iree_tensor_ext.dispatch.tensor.load (TensorExt::DispatchTensorLoadOp)

Loads a tensor from a dispatch input placeholder

Syntax:

operation ::= `iree_tensor_ext.dispatch.tensor.load` $source
              `,` `offsets` `=` custom<DynamicIndexList>(
              $offsets, $static_offsets)
              `,` `sizes` `=` custom<DynamicIndexList>(
              $sizes, $static_sizes)
              `,` `strides` `=` custom<DynamicIndexList>(
              $strides, $static_strides)
              attr-dict `:` type($source) (`{` $source_dims^ `}`)?  `->` type($result)

Loads an input tensor or subtensor from an input placeholder. Because workgroups execute concurrently, workgroups that load overlapping regions all receive identical results for those regions.
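The meaning of the offsets, sizes, and strides operands can be modeled in plain Python. This is a conceptual sketch over nested lists, not IREE code: `load_slice` is a hypothetical helper, and it assumes static, in-bounds offsets, sizes, and strides lists of equal length (one entry per dimension).

```python
from typing import Any, Sequence


def load_slice(source: Any, offsets: Sequence[int], sizes: Sequence[int],
               strides: Sequence[int]) -> Any:
    """Return the strided sub-tensor of `source` (nested lists).

    Element i_k along dimension k of the result is read from index
    offsets[k] + i_k * strides[k] of `source`. Reading never mutates
    `source`, so overlapping reads all observe the same values.
    """
    if not offsets:  # all dimensions consumed: `source` is a scalar element
        return source
    off, size, stride = offsets[0], sizes[0], strides[0]
    return [
        load_slice(source[off + i * stride], offsets[1:], sizes[1:], strides[1:])
        for i in range(size)
    ]


# A 4x8 input where element (r, c) holds r * 10 + c.
source = [[r * 10 + c for c in range(8)] for r in range(4)]

# offsets = [1, 2], sizes = [2, 3], strides = [1, 2]
tile = load_slice(source, [1, 2], [2, 3], [1, 2])
print(tile)  # [[12, 14, 16], [22, 24, 26]]
```

Two hypothetical workgroups loading rows 0-1 and rows 1-2 would both see the same values for row 1, which is the guarantee the description above relies on.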

Traits: AlwaysSpeculatableImplTrait, AttrSizedOperandSegments

Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface), OffsetSizeAndStrideOpInterface, ReifyRankedShapedTypeOpInterface, TiedOpInterface, Util_ShapeAwareOp

Effects: MemoryEffects::Effect{}

Attributes:
Attribute       MLIR Type                  Description
static_offsets  ::mlir::DenseI64ArrayAttr  i64 dense array attribute
static_sizes    ::mlir::DenseI64ArrayAttr  i64 dense array attribute
static_strides  ::mlir::DenseI64ArrayAttr  i64 dense array attribute

Operands:
Operand      Description
source       dispatch.tensor
source_dims  variadic of index
offsets      variadic of index
sizes        variadic of index
strides      variadic of index

Results:
Result  Description
result  ranked tensor of any type values

iree_tensor_ext.dispatch.tensor.store (TensorExt::DispatchTensorStoreOp)

Stores a tensor into a dispatch output placeholder

Syntax:

operation ::= `iree_tensor_ext.dispatch.tensor.store` $value `,` $target
              `,` `offsets` `=` custom<DynamicIndexList>(
              $offsets, $static_offsets)
              `,` `sizes` `=` custom<DynamicIndexList>(
              $sizes, $static_sizes)
              `,` `strides` `=` custom<DynamicIndexList>(
              $strides, $static_strides)
              attr-dict `:` type($value) `->` type($target) (`{` $target_dims^ `}`)?

Stores a tensor or subtensor into an output tensor placeholder. Because workgroups execute concurrently, the behavior is undefined if more than one workgroup stores into overlapping regions of the full output tensor.
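The store semantics and the non-overlap requirement can be sketched in Python as well. This is an illustrative model only: `store_slice`, `touched_indices`, and `stores_are_disjoint` are hypothetical helpers, the target is a nested list, and the sizes of each store are implied by the shape of the stored value.

```python
import itertools
from typing import Any, Iterable, Sequence, Set, Tuple

Region = Tuple[Sequence[int], Sequence[int], Sequence[int]]  # (offsets, sizes, strides)


def store_slice(target: Any, value: Any, offsets: Sequence[int],
                strides: Sequence[int]) -> None:
    """Write nested-list `value` into `target` in place at a strided region."""
    off, stride = offsets[0], strides[0]
    for i, sub in enumerate(value):
        if len(offsets) == 1:  # innermost dimension: write the element
            target[off + i * stride] = sub
        else:
            store_slice(target[off + i * stride], sub, offsets[1:], strides[1:])


def touched_indices(offsets: Sequence[int], sizes: Sequence[int],
                    strides: Sequence[int]) -> Set[Tuple[int, ...]]:
    """All index tuples of the full tensor covered by one region."""
    axes = [[o + i * s for i in range(n)] for o, n, s in zip(offsets, sizes, strides)]
    return set(itertools.product(*axes))


def stores_are_disjoint(regions: Iterable[Region]) -> bool:
    """True if no two regions write the same element (no race in this model)."""
    seen: Set[Tuple[int, ...]] = set()
    for offsets, sizes, strides in regions:
        idx = touched_indices(offsets, sizes, strides)
        if seen & idx:
            return False
        seen |= idx
    return True


out = [[0] * 4 for _ in range(4)]
store_slice(out, [[1] * 4, [1] * 4], [0, 0], [1, 1])  # "workgroup 0" writes rows 0-1
store_slice(out, [[2] * 4, [2] * 4], [2, 0], [1, 1])  # "workgroup 1" writes rows 2-3
print(out)  # [[1, 1, 1, 1], [1, 1, 1, 1], [2, 2, 2, 2], [2, 2, 2, 2]]

print(stores_are_disjoint([([0, 0], [2, 4], [1, 1]), ([2, 0], [2, 4], [1, 1])]))  # True
print(stores_are_disjoint([([0, 0], [3, 4], [1, 1]), ([2, 0], [2, 4], [1, 1])]))  # False
```

In the second check the first region covers rows 0-2 and the second covers rows 2-3, so row 2 would be written by two workgroups: exactly the case the op leaves undefined.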

Traits: AttrSizedOperandSegments

Interfaces: OffsetSizeAndStrideOpInterface, Util_ShapeAwareOp

Attributes:
Attribute       MLIR Type                  Description
static_offsets  ::mlir::DenseI64ArrayAttr  i64 dense array attribute
static_sizes    ::mlir::DenseI64ArrayAttr  i64 dense array attribute
static_strides  ::mlir::DenseI64ArrayAttr  i64 dense array attribute

Operands:
Operand      Description
value        ranked tensor of any type values
target       dispatch.tensor
target_dims  variadic of index
offsets      variadic of index
sizes        variadic of index
strides      variadic of index

iree_tensor_ext.dispatch.workgroup_count_from_dag_root (TensorExt::DispatchWorkgroupCountFromDagRootOp)

Workgroup count computed from the iteration range of the root of the DAG of ops within the dispatch.

Syntax:

operation ::= `iree_tensor_ext.dispatch.workgroup_count_from_dag_root` attr-dict $operands

This op is used when the root of the DAG (directed acyclic graph) of ops within the dispatch is tiled and distributed to split the work among workgroups. The captured workload is the size of the iteration space of the root of the DAG. The op represents the computation that, given the workload, returns the number of workgroups to use. Backends are responsible for lowering this op into actual computation, typically based on the tile sizes used to tile and distribute the root of the DAG.
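As a rough illustration of the "typically based on the tile sizes" lowering, the Python sketch below computes one workgroup per tile of each distributed dimension. It is a simplified model, not an actual IREE backend: `workgroup_count` is hypothetical, a tile size of 0 is assumed to mean "not distributed", and the assignment of the innermost distributed dimension to x is an assumption of this sketch.

```python
from typing import Sequence, Tuple


def ceil_div(a: int, b: int) -> int:
    """Integer ceiling division for positive b."""
    return -(-a // b)


def workgroup_count(workload: Sequence[int],
                    tile_sizes: Sequence[int]) -> Tuple[int, int, int]:
    """Hypothetical lowering: one workgroup per tile of each distributed dim.

    A tile size of 0 marks a dimension that is not distributed (for example
    a reduction dimension). The innermost distributed dimension is assigned
    to x, the next to y, then z.
    """
    if len(workload) != len(tile_sizes):
        raise ValueError("workload and tile_sizes must have the same length")
    counts = [ceil_div(w, t) for w, t in zip(workload, tile_sizes) if t > 0]
    if len(counts) > 3:
        raise ValueError("this sketch handles at most 3 distributed dimensions")
    counts.reverse()                   # innermost distributed dim first -> x
    counts += [1] * (3 - len(counts))  # unused grid dimensions get 1
    x, y, z = counts
    return x, y, z


# A matmul root with iteration space (M, N, K) = (128, 100, 64),
# tiled 64x32 over M and N, with K left untiled.
print(workgroup_count([128, 100, 64], [64, 32, 0]))  # (4, 2, 1)
```

Ceiling division is what makes partial tiles count: N = 100 with a tile of 32 needs 4 workgroups, the last one covering only 4 columns.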

Traits: AlwaysSpeculatableImplTrait

Interfaces: ConditionallySpeculatable, InferTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface)

Effects: MemoryEffects::Effect{}

Operands:
Operand   Description
operands  variadic of index

Results:
Result  Description
x       index
y       index
z       index

iree_tensor_ext.dispatch.workgroup_count_from_slice (TensorExt::DispatchWorkgroupCountFromSliceOp)

Placeholder to signify the default workgroup count calculation.

Syntax:

operation ::= `iree_tensor_ext.dispatch.workgroup_count_from_slice` attr-dict $operands

The default computation of the number of workgroups (or workgroup count) assumes that the dispatch and its captured values are sufficient to compute the workgroup count. It does so by using a program slice of the values within the dispatch that represent the number of workgroups, when such values are available within the dispatch. Currently, the index-typed arguments captured by flow.dispatch.workgroups are treated as the workload for the operation. The leaves of the program slice that computes the number of workgroups are required to be these captured values.
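The leaf requirement can be illustrated with a small backward-slice walk in Python. This is a hypothetical model, not IREE's implementation: SSA values are plain strings, `defs` maps each value defined in the dispatch to the operands of its defining op, and the check follows a literal reading of the requirement above, so even a constant leaf is rejected.

```python
from typing import Dict, Iterable, List, Set


def slice_leaves(root: str, defs: Dict[str, List[str]]) -> Set[str]:
    """Walk the backward slice of `root` and return its leaves.

    Values without an entry in `defs` (e.g. block arguments) and values whose
    defining op has no operands (e.g. constants) are treated as leaves.
    """
    leaves: Set[str] = set()
    visited: Set[str] = set()
    stack = [root]
    while stack:
        value = stack.pop()
        if value in visited:
            continue
        visited.add(value)
        operands = defs.get(value, [])
        if operands:
            stack.extend(operands)  # keep walking toward the producers
        else:
            leaves.add(value)
    return leaves


def is_valid_workload_slice(roots: Iterable[str], defs: Dict[str, List[str]],
                            captured: Set[str]) -> bool:
    """Check that every leaf of every root's slice is a captured workload value."""
    return all(slice_leaves(r, defs) <= captured for r in roots)


captured = {"%arg0", "%arg1"}
# %n = arith.muli %arg0, %arg1: both leaves are captured values.
print(is_valid_workload_slice(["%n"], {"%n": ["%arg0", "%arg1"]}, captured))  # True
# %m = arith.muli %arg0, %c4 with %c4 a constant: %c4 is not captured.
print(is_valid_workload_slice(["%m"], {"%m": ["%arg0", "%c4"], "%c4": []}, captured))  # False
```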

TODO: This could be generalized in the future to allow the slices to encompass arbitrary computation. The workgroup count could then be computed on the device itself when it is data-dependent. In such cases the workload could include more than just values of index type.

Traits: AlwaysSpeculatableImplTrait

Interfaces: ConditionallySpeculatable, InferTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface)

Effects: MemoryEffects::Effect{}

Operands:
Operand   Description
operands  variadic of index

Results:
Result  Description
x       index
y       index
z       index

iree_tensor_ext.dispatch.workload.ordinal (TensorExt::DispatchWorkloadOrdinalOp)

Annotates the values captured as workload within the body of a flow.dispatch.workgroups op.

Syntax:

operation ::= `iree_tensor_ext.dispatch.workload.ordinal` attr-dict $operand `,` $ordinal `:` type($operand)

The signature of the body of a flow.dispatch.workgroups op (that is, the arguments that represent its captured and returned values) is not preserved during IREE's compilation. Since the workloads are derived from the operands captured by the operation, this op marks the values captured as workloads. Backends can use it to map back to the workload values when materializing the workgroup count computation.
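Conceptually, the ordinal lets a backend replace an annotated in-dispatch value with the matching entry of the workload list. The Python sketch below models only that lookup; `resolve_workload` and the value names are hypothetical, and the bounds check is an assumption of the sketch rather than documented verifier behavior.

```python
from typing import Dict, Sequence


def resolve_workload(ordinals: Dict[str, int],
                     workload: Sequence[int]) -> Dict[str, int]:
    """Map each annotated in-dispatch value to its workload entry.

    `ordinals` plays the role of the `ordinal` attribute on each annotated
    value; `workload` is the list of workload values at the dispatch site.
    """
    resolved: Dict[str, int] = {}
    for name, ordinal in ordinals.items():
        if not 0 <= ordinal < len(workload):
            raise IndexError(f"ordinal {ordinal} of {name} is out of range")
        resolved[name] = workload[ordinal]
    return resolved


# Two dynamic sizes annotated inside the body with ordinals 0 and 1.
ordinals = {"%d0": 0, "%d1": 1}
print(resolve_workload(ordinals, [256, 100]))  # {'%d0': 256, '%d1': 100}
```

Any workgroup count expression written in terms of `%d0` and `%d1` inside the dispatch can then be evaluated with the resolved numbers.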

TODO: Find a better way to represent this information, either by propagating the signature of the created dispatch workgroup op through the compilation stack to the codegen backends, or as a separate list/attribute that can be plumbed through without using explicit ops.

Traits: AlwaysSpeculatableImplTrait

Interfaces: ConditionallySpeculatable, InferIntDivisibilityOpInterface, InferIntRangeInterface, InferTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface)

Effects: MemoryEffects::Effect{}

Attributes:
Attribute  MLIR Type            Description
ordinal    ::mlir::IntegerAttr  index attribute

Operands:
Operand  Description
operand  index

Results:
Result  Description
result  index

Type constraints

dispatch.tensor

A placeholder for a dispatch region input/output operand. This can be used to query metadata about the tensor (such as its shape) and to both load from and store to the backing tensor representation.

dispatch.tensor

A placeholder for a dispatch region input operand. This can be used to query metadata about the tensor (such as its shape) and to load from the backing tensor representation.

dispatch.tensor

A placeholder for a dispatch region output operand. This can be used to query metadata about the tensor (such as its shape) and to store to the backing tensor representation.
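The three constraints differ in which accesses they permit: input placeholders can be loaded from, output placeholders can be stored to, and input/output placeholders allow both. The Python sketch below models that rule with a hypothetical `Access` enum; the mode names `readonly`, `writeonly`, and `readwrite` are assumed to mirror the access qualifiers of the dispatch tensor type and are not taken from this page.

```python
from enum import Enum


class Access(Enum):
    """Access mode of a dispatch tensor placeholder (names assumed)."""
    READONLY = "readonly"    # input operand
    WRITEONLY = "writeonly"  # output operand
    READWRITE = "readwrite"  # input/output operand


def can_load(access: Access) -> bool:
    """Input (readable) placeholders may be loaded from."""
    return access in (Access.READONLY, Access.READWRITE)


def can_store(access: Access) -> bool:
    """Output (writable) placeholders may be stored to."""
    return access in (Access.WRITEONLY, Access.READWRITE)


for mode in Access:
    print(mode.value, can_load(mode), can_store(mode))
# readonly True False
# writeonly False True
# readwrite True True
```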