Shape Tracking
Compile-time shape tracking
When working with a GraphTensor in LuminAIR, you may notice that its type includes a specific generic parameter. This parameter encodes the tensor’s shape directly at the type level, enabling compile-time shape tracking.
What Does the Shape Parameter Represent?
Consider a tensor whose type is GraphTensor<R1<3>>. Here, R1<3> is the shape parameter of the GraphTensor.
Let’s break it down:
- R1: Represents the rank of the tensor, which is the number of dimensions it has. In this case, R1 denotes a rank-1 tensor (a one-dimensional tensor).
- <3>: Specifies the size of the tensor along its single dimension. Here, <3> means the tensor contains 3 elements.
Together, R1<3> defines a one-dimensional tensor with 3 elements, that is, a vector with 3 components (not a rank-3 tensor).
How Does Shape Tracking Work?
The shape parameter R1<N> is implemented as a type alias for the one-element tuple (Const<N>,), where Const<N> is a compile-time constant representing the size of the tensor along its single dimension.
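The alias described above can be sketched in a few lines of standalone Rust. This is a minimal illustration, not LuminAIR's source: Const and R1 are re-declared locally so the snippet compiles on its own, and the Shape trait with its RANK and dims members is a hypothetical helper added only to read the shape back out of the type.

```rust
// Local stand-ins for illustration; these are NOT LuminAIR's definitions.

/// Zero-sized marker type that carries a dimension size `N` in the type.
#[derive(Debug, Clone, Copy, Default)]
pub struct Const<const N: usize>;

/// A rank-1 shape is a one-element tuple holding the single dimension.
pub type R1<const N: usize> = (Const<N>,);

/// Hypothetical helper trait for recovering the shape from the type.
pub trait Shape {
    const RANK: usize;
    fn dims() -> Vec<usize>;
}

impl<const N: usize> Shape for (Const<N>,) {
    const RANK: usize = 1;
    fn dims() -> Vec<usize> {
        vec![N]
    }
}

fn main() {
    // The shape is known from the type alone; no tensor value is needed.
    println!(
        "rank = {}, dims = {:?}",
        <R1<3> as Shape>::RANK,
        <R1<3> as Shape>::dims()
    );
    // The shape type occupies no memory at runtime.
    println!("size_of = {}", std::mem::size_of::<R1<3>>());
}
```

Because the shape lives entirely in the type, R1<3> and R1<4> are distinct types, and neither occupies any memory at runtime.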
This design allows Rust to enforce tensor shapes at compile time, ensuring that:
- Shape Mismatches Are Caught Early: Tensor operations are only allowed between tensors with compatible shapes.
- No Runtime Shape Calculations: Since shapes are encoded in types, there’s no need to compute or validate shapes during execution.
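Both guarantees can be seen in a small standalone sketch. ToyTensor below is a hypothetical stand-in, not LuminAIR's GraphTensor, and its new, add, and data methods are assumptions made for illustration. The point is that add only accepts a right-hand side with the same shape type S, so it needs no length check of its own.

```rust
use std::marker::PhantomData;

// Local stand-ins for illustration; these are NOT LuminAIR's definitions.
pub struct Const<const N: usize>;
pub type R1<const N: usize> = (Const<N>,);

/// Hypothetical toy tensor whose shape `S` exists only at the type level.
pub struct ToyTensor<S> {
    data: Vec<f32>,
    _shape: PhantomData<S>,
}

impl<const N: usize> ToyTensor<R1<N>> {
    /// Accepting `[f32; N]` ties the element count to the shape type.
    pub fn new(data: [f32; N]) -> Self {
        ToyTensor { data: data.to_vec(), _shape: PhantomData }
    }
}

impl<S> ToyTensor<S> {
    /// Element-wise addition; both operands must share the same shape type,
    /// so no runtime length comparison is performed here.
    pub fn add(&self, rhs: &ToyTensor<S>) -> ToyTensor<S> {
        let data = self.data.iter().zip(&rhs.data).map(|(a, b)| a + b).collect();
        ToyTensor { data, _shape: PhantomData }
    }

    pub fn data(&self) -> &[f32] {
        &self.data
    }
}

fn main() {
    let a = ToyTensor::<R1<3>>::new([1.0, 2.0, 3.0]);
    let b = ToyTensor::<R1<3>>::new([10.0, 20.0, 30.0]);
    let c = a.add(&b);
    println!("{:?}", c.data());

    // Uncommenting the next two lines makes compilation fail with a
    // mismatched-types error, because R1<4> is a different type from R1<3>:
    // let d = ToyTensor::<R1<4>>::new([0.0; 4]);
    // let e = a.add(&d);
}
```

The mismatch is rejected by the type checker before the program ever runs, which is the behaviour the two points above describe.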