When working with a GraphTensor in LuminAIR, you may notice that its type includes a specific generic parameter. This parameter encodes the tensor’s shape directly at the type level, enabling compile-time shape tracking.

What Does the Shape Parameter Represent?

Consider the following example:

let a: GraphTensor<R1<3>> = cx.tensor();

Here, R1<3> is the shape parameter of the GraphTensor. Let’s break it down:

  • R1: Represents the rank of the tensor, which is the number of dimensions it has. In this case, R1 denotes a rank-1 tensor (a one-dimensional tensor).
  • <3>: Specifies the size of the tensor along its single dimension. Here, <3> means the tensor contains 3 elements.

Together, R1<3> describes a one-dimensional tensor with 3 elements, that is, a vector of length 3.

How Does Shape Tracking Work?

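The shape parameter (R1<N>) is implemented as a type alias for (Const<N>,), a one-element tuple in which Const<N> is a type that carries the dimension size N as a compile-time constant. The size of the tensor's single dimension is therefore part of its type.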
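
To make the alias concrete, here is a minimal standalone sketch of the idea. The definitions below are simplified stand-ins written for illustration, not LuminAIR's actual source. In particular, the Shape trait, its dims method, and the dims_of helper are hypothetical names introduced only for this example.

```rust
// Simplified stand-ins for illustration only; not LuminAIR's actual definitions.

/// Zero-sized type that carries a dimension size as a const generic.
#[derive(Clone, Copy, Debug, Default, PartialEq)]
pub struct Const<const N: usize>;

/// Rank-1 shape: a one-element tuple holding the single dimension.
pub type R1<const N: usize> = (Const<N>,);

/// Hypothetical trait for reading a shape back out of its type.
pub trait Shape {
    const RANK: usize;
    fn dims() -> Vec<usize>;
}

impl<const N: usize> Shape for (Const<N>,) {
    const RANK: usize = 1;

    fn dims() -> Vec<usize> {
        vec![N]
    }
}

/// Returns the dimensions encoded in a shape type; no value is needed.
pub fn dims_of<S: Shape>() -> Vec<usize> {
    S::dims()
}
```

With these definitions, dims_of::<R1<3>>() yields a one-element vector containing 3, and R1<3> and (Const<3>,) are the same type. Because Const<N> has no fields, the shape occupies no memory at run time; it exists only for the type checker.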

This design allows Rust to enforce tensor shapes at compile time, ensuring that:

  1. Shape Mismatches Are Caught Early: Tensor operations are only allowed between tensors with compatible shapes, so an incompatible pair is rejected by the compiler instead of failing at run time.
  2. No Runtime Shape Calculations: Since shapes are encoded in types, there’s no need to compute or validate shapes during execution.
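
Both guarantees can be sketched with a toy tensor type. This is an illustrative assumption, not how GraphTensor is implemented: ShapedTensor, from_array, to_vec, and add_example are hypothetical names, and the toy type stores its data eagerly rather than building a computation graph.

```rust
use std::marker::PhantomData;
use std::ops::Add;

// Hypothetical stand-ins for illustration only; not LuminAIR's actual types.
#[derive(Clone, Copy, Debug, Default, PartialEq)]
pub struct Const<const N: usize>;
pub type R1<const N: usize> = (Const<N>,);

/// Toy tensor: the shape lives only in the type parameter `S`.
pub struct ShapedTensor<S> {
    data: Vec<f32>,
    _shape: PhantomData<S>,
}

impl<const N: usize> ShapedTensor<R1<N>> {
    /// Taking a `[f32; N]` ties the data length to the shape type.
    pub fn from_array(values: [f32; N]) -> Self {
        Self { data: values.to_vec(), _shape: PhantomData }
    }

    pub fn to_vec(&self) -> Vec<f32> {
        self.data.clone()
    }
}

/// Addition is defined only when both operands share the same shape `S`.
impl<S> Add for ShapedTensor<S> {
    type Output = ShapedTensor<S>;

    fn add(self, rhs: Self) -> Self::Output {
        // No length check here: equal types already imply equal shapes.
        let data = self.data.iter().zip(&rhs.data).map(|(a, b)| a + b).collect();
        ShapedTensor { data, _shape: PhantomData }
    }
}

pub fn add_example() -> Vec<f32> {
    let a = ShapedTensor::<R1<3>>::from_array([1.0, 2.0, 3.0]);
    let b = ShapedTensor::<R1<3>>::from_array([10.0, 20.0, 30.0]);

    // Uncommenting the next two lines makes compilation fail, because
    // ShapedTensor<R1<4>> is a different type from ShapedTensor<R1<3>>:
    // let c = ShapedTensor::<R1<4>>::from_array([0.0; 4]);
    // let _ = a + c;

    (a + b).to_vec()
}
```

In this sketch the add method contains no shape validation at all. The only check is the type check, which happens before the program runs.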