Rotary Position Embedding, or RoPE, is a critical operation used in Large Language Models. It is responsible for encoding positional information (where the token is in the sentence) into the embedding, so the model can distinguish between “the cat eats the bird” and “the bird eats the cat”. Ever since I started adding a Tenstorrent backend to GGML, RoPE has always fallen back to the CPU. Although TTNN has its own RoPE implementation, it is fundamentally incompatible with the semantics GGML wants. I had been putting off writing my own RoPE support, as I expected it to be a pain, and was hoping that, Deus Ex Machina, someone would hand me a working version in a PR out of nowhere.
Suffice it to say, I was wrong. I improved the Tenstorrent GGML backend to the point where the missing RoPE support is now the main bottleneck for inference performance. What else can I do... Time to get my hands dirty.
IMPORTANT: As you can see from the scrollbar on the right-hand side of this page, the post is VERY long. I have neither the energy nor the space to guide you through the entire process. The raw Gemtext file of this post is about 117KB! And I wrote every word.
It is assumed that the reader has a very solid understanding of parallel computing patterns, computer organization and hardware architecture, and the imagination to fill in the gaps.
I hope this blog post can serve as a future reference for “how the stuff works” and, for software vendors interested in Tenstorrent’s processors, for how to program them.
Source code (the NeoX variant) is available at the following link (licensed under Zero-Clause BSD):
The actual integration lives in my llama.cpp fork. Kernels are linked below:
RoPE
I highly recommend my friend Fleetwood’s blog on the mathematics and design philosophy of RoPE.
Done? Good. I am not a mathematician, and my brain goes blank as soon as I see complicated equations. Let’s decompose RoPE into plain computation:
- For each RoPE operation, a vector and a position value are provided.
- For each pair of values in the vector:
  - Define i as the index of the first element (of the pair) in the vector, and D as the size of the vector
  - Calculate the frequency of rotation: frequency = 1 / 10000^(i / D)
  - Calculate the rotation angle: angle = position * frequency
  - Apply a 2D rotation of angle to the pair
I chose to implement the NeoX variant of RoPE, as it seemed easier. It splits the vector into 2 halves, then rotates pairwise between the 2 halves: element i is paired with element i + D/2.
#include <cassert>
#include <cmath>
#include <vector>
std::vector<float> rope(const std::vector<float>& vec, int pos)
{
    size_t D = vec.size();
    assert(D % 2 == 0 && "Dimension must be even");
    std::vector<float> result(D);
    for (size_t i = 0; i < D/2; ++i) {
        float exponent = 2.0f * i / D;
        float freq = 1.0f / std::pow(10000.0f, exponent);
        // Calculate the angle of rotation
        float angle = pos * freq;
        float cos_angle = std::cos(angle);
        float sin_angle = std::sin(angle);
        // Apply the 2D rotation to the pair (vec[i], vec[i + D/2])
        float x = vec[i];
        float y = vec[i + D/2];
        result[i] = x * cos_angle - y * sin_angle;
        result[i + D/2] = x * sin_angle + y * cos_angle;
    }
    return result;
}
Really simple, isn’t it? The mechanism is clear as soon as you’re not looking at vectorized equations.
Additionally, we ought to support multiple vectors in a single operation (denoted N here). If you look into GGML’s API, it sometimes wants to limit how many dimensions RoPE is applied to. So new parameters are introduced: D denotes the width of the input data (vec.size() == D * N), and D_active denotes the number of dimensions the operation is applied to. The following implementation can be derived:
std::vector<float> rope(const std::vector<float>& vec, const std::vector<int>& pos, int D, int D_active)
{
    // D_active == -1 means "rotate the whole vector"
    const size_t rotate_dim = D_active == -1 ? D : D_active;
    const size_t N = pos.size(); // one position per vector
    assert(rotate_dim % 2 == 0 && "Active dimension must be even");
    assert(rotate_dim <= (size_t)D && "Active dimension must be less than or equal to total dimension");
    assert(vec.size() == (size_t)D * N && "Input must hold N vectors of width D");
    std::vector<float> result(vec.size());
    for (size_t n = 0; n < N; ++n) {
        size_t offset = n * D;
        // Active region: rotate element i with element i + rotate_dim/2
        for (size_t i = 0; i < rotate_dim/2; i++) {
            float exponent = 2.f * (float)i / rotate_dim;
            float freq = 1.0f / std::pow(10000.0f, exponent);
            float angle = pos[n] * freq;
            float cos_angle = std::cos(angle);
            float sin_angle = std::sin(angle);
            float x = vec[offset + i];
            float y = vec[offset + i + rotate_dim/2];
            result[offset + i] = x * cos_angle - y * sin_angle;
            result[offset + i + rotate_dim/2] = x * sin_angle + y * cos_angle;
        }
        // Passive region: copied through untouched
        for (size_t i = rotate_dim; i < (size_t)D; i++) {
            result[offset + i] = vec[offset + i];
        }
    }
    return result;
}
We can check our implementation against GGML with some simple code:
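// GGML
#include "ggml.h"
#include "ggml-cpu.h" // ggml_graph_compute_with_ctx is declared here in recent GGML versions
#include <cstdio>
#include <cstring>
#include <vector>
int main(void) {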
const int D = 32;
const int N = 1;
std::vector<float> input_vec(D * N);
for(int i = 0; i < D * N; i++) {
input_vec[i] = (float)i;
}
std::vector<int> pos(N);
for(int i = 0; i < N; i++) {
pos[i] = i+1;
}
// Arbitrary amount of allowed memory as the example is small
size_t ctx_size = 5 * 1024 * 1024; // 5 MB
struct ggml_init_params params = {
/*.mem_size =*/ ctx_size,
/*.mem_buffer =*/ NULL,
/*.no_alloc =*/ false,
};
struct ggml_context * ctx = ggml_init(params);
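// ggml_rope expects one position per entry along the third dimension (ne[2]),
// so the N vectors are laid out as [D, 1, N] (same memory layout as [D, N])
struct ggml_tensor * tensor_a = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, D, 1, N);
struct ggml_tensor * tensor_b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, N);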
memcpy(tensor_a->data, input_vec.data(), ggml_nbytes(tensor_a));
memcpy(tensor_b->data, pos.data(), ggml_nbytes(tensor_b));
struct ggml_cgraph * gf = ggml_new_graph(ctx);
struct ggml_tensor * result = ggml_rope(ctx, tensor_a, tensor_b, D/2, GGML_ROPE_TYPE_NEOX);
ggml_build_forward_expand(gf, result);
ggml_graph_compute_with_ctx(ctx, gf, 1);
// Print result
float * result_data = (float *) result->data;
for(size_t i = 0; i < N; i++) {
printf("vector %zu:\n", i);
for(size_t j = 0; j < D; j++) {
printf("%f ", result_data[i * D + j]);
}
printf("\n");
}
ggml_free(ctx);
}
// Ours
#include <iostream>
int main() {
    constexpr size_t D = 32;
    std::vector<float> vec(D);
    for(size_t i = 0; i < D; ++i) {
        vec[i] = i;
    }
    // One vector at position 1, rotating only the first D/2 dimensions (GGML's n_dims = D/2)
    auto res = rope(vec, {1}, D, D/2);
    for (size_t i = 0; i < res.size(); ++i) {
        std::cout << res[i] << (i < res.size() - 1 ? ", " : "\n");
    }
}
Which prints
GGML:
-6.731768 -1.848437 0.991674 2.650707 3.879802 4.958866 5.985997 6.995256 4.322418 8.864721 10.149709 11.089353 12.039399
13.015746 14.005993 15.002213 16.000000 17.000000 18.000000 19.000000 20.000000 21.000000 22.000000 23.000000 24.000000
25.000000 26.000000 27.000000 28.000000 29.000000 30.000000 31.000000
Ours:
-6.73177, -1.84844, 0.991674, 2.65071, 3.8798, 4.95887, 5.986, 6.99526, 4.32242, 8.86472, 10.1497, 11.0894, 12.0394, 13.0157,
14.006, 15.0022, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31
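The first values can also be checked by hand. With D = 32 and D_active = D/2 = 16, element 0 (value 0) is paired with element 8 (value 8). Since i = 0, the frequency is 1, so the angle equals pos = 1. Element 1 pairs with element 9 at frequency 10000^(-2/16). The following is a minimal sketch of that arithmetic, reusing the rotation formulas from rope() above. The helper names are illustrative only:

```cpp
#include <cassert>
#include <cmath>

// First component of the 2D rotation: result[i] = x * cos(angle) - y * sin(angle)
float rotate_first(float x, float y, float angle) {
    return x * std::cos(angle) - y * std::sin(angle);
}

// Second component: result[i + D_active/2] = x * sin(angle) + y * cos(angle)
float rotate_second(float x, float y, float angle) {
    return x * std::sin(angle) + y * std::cos(angle);
}

// NeoX frequency for pair index i: 1 / 10000^(2i / D_active)
float neox_freq(int i, int D_active) {
    assert(D_active > 0 && i >= 0 && 2 * i < D_active);
    return 1.0f / std::pow(10000.0f, 2.0f * i / D_active);
}
```

For example, rotate_first(0, 8, 1 * neox_freq(0, 16)) is about -6.73177 and rotate_second(0, 8, 1 * neox_freq(0, 16)) is about 4.32242. These match positions 0 and 8 in both printouts.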
Good - we have a working reference and mental model now.
Implementing RoPE
Let’s understand some design constraints before diving into code. RoPE is not a trivial element-wise operation, in two different ways:
- The heavy but predictable use of sin() and cos()
- Running RoPE requires knowledge of which element in the vector is being processed
sin() and cos() are expensive to compute - no surprise there. Most backends dodge this by precomputing the values on the host, dumping them into device DRAM as a “sin/cos cache,” and then just indexing into that at runtime. Since the inputs are always the same for a given set of hyperparameters, this works fine. But not on Tenstorrent. Here, everything runs in tiles, so I’d have to write a gather just to pull the right values from DRAM, and then hack up the data movement kernel to handle sub-tile access (forget about using helpers like TensorAccessor or InterleavedAddressGen{,Fast} - they won’t know what to do). This design would be prohibitively difficult to write. So the trigonometric functions have to be computed on the fly.
That said, some amount of manual addressing is still required, especially since we will later need to extract the position index in order to apply it across the entire vector.
I recommend reviewing the Metalium Programming Model guide if you are not already familiar with it, as it serves as the base knowledge assumed in this post.
RoPE also needs each pair of elements to be rotated by a different amount, based on their index. So which element lands in which SFPU lane actually matters (RoPE is not a simple element-wise operation), and the internal format of a tile must be taken into account. The official documentation (I wrote most of it :p) should help with understanding:
My previous experiment programming the SFPU (which disregarded the tile structure at the time) may also be an interesting read for readers into weird programming practices.
For a baseline, the kernel runs on a single (Tensix) core, utilizing the SFPU (vector engine), and supports applying RoPE to only part of the input tensor (D_active from earlier, called n_dims in GGML; supporting it is a hard requirement of GGML). I’ll simplify things by skipping the need to read varying token positions from an integer tensor, and instead apply a single position value across all vectors. It’s easier this way, and loading an array of integers later should be straightforward enough.
Implementation strategy
Let’s visualize the input vector. There is an active and a passive region. The active region is where RoPE is performed, and it is split into 2 equal-sized subregions, which are then iterated together pairwise to perform the actual calculation. The passive region is simply not touched.
Image: Diagram of how a vector is split conceptually
In the active region, where RoPE is performed, we:
- Split the active region into two halves
- Data Movement kernel 0 (DM0) moves a tile from each half into the input Circular Buffer, handing them to the compute kernel
- The compute kernel moves the tiles into the Dst registers
- Perform RoPE using the SFPU, against the data in the Dst registers
- Store the result back into the Dst registers
- Move the result to the output Circular Buffer
- Data Movement kernel 1 (DM1) moves the result back into DRAM
- Repeat until all active tiles are processed
The following diagram illustrates the flow of processing active tiles:
Image: Diagram of the dataflow processing active tiles
NOTE: The diagram only shows one pair of active tiles due to space constraints. In real-world scenarios, you’ll typically encounter multiple pairs of active tiles within a single row.
After all active tiles are processed, the kernel switches to processing passive tiles. In this phase, the compute kernel is effectively idle, waiting for the next round (if any). Data Movement kernel 0 passes the input tiles directly to Data Movement kernel 1, which writes them out into the result tensor. Arguably, it would be more efficient for Data Movement kernel 0 to write the passive tiles into the result tensor itself, but I consider that too much complication for too little benefit.
Image: Diagram of the dataflow processing passive tiles
To set things up on the host side (I have written some wrappers to make the Metalium API easier to use):
constexpr size_t D = 64; // Length of the vector
constexpr size_t D_active = 64; // Active region size
constexpr size_t N = 32; // "Batch" size
// The above variables in number of tiles
constexpr uint32_t Dt = D/32;
constexpr uint32_t Nt = N/32;
constexpr uint32_t D_activet = D_active/32;
// Buffers for input and output
auto src = MakeBuffer(device, Dt * Nt, sizeof(float));
auto dst = MakeBuffer(device, Dt * Nt, sizeof(float));
// The 3 circular buffers for data passing
MakeCircularBufferFP32(program, core, tt::CBIndex::c_0, 4); // Data into the compute kernel (reader/DM0 -> Compute)
MakeCircularBufferFP32(program, core, tt::CBIndex::c_16, 4); // Data out of compute kernel (Compute -> writer/DM1)
MakeCircularBufferFP32(program, core, tt::CBIndex::c_17, 4); // Moves passive tiles (reader/DM0 -> writer/DM1)
The reader implements the reading part of our strategy: read the pairs of active tiles and send them to compute, then read the passive tiles and send them to the writer.
// Runtime kernel parameters
uint32_t src_addr = get_arg_val<uint32_t>(0);
uint32_t n_tiles_width_active = get_arg_val<uint32_t>(1);
uint32_t n_tiles_width = get_arg_val<uint32_t>(2);
uint32_t n_tiles_height = get_arg_val<uint32_t>(3);
// Circular buffer indices (declared before get_tile_size() uses cb_in0)
constexpr uint32_t cb_in0 = tt::CBIndex::c_0;
constexpr uint32_t cb_bypass = tt::CBIndex::c_17;
const uint32_t tile_size_bytes = get_tile_size(cb_in0);
const auto src = TensorAccessor(TensorAccessorArgs<0>(), src_addr, tile_size_bytes);
for(uint32_t h = 0; h < n_tiles_height; h++) {
// Reads the active tile pairs and send them to the compute kernel
for(uint32_t w = 0; w < n_tiles_width_active/2; w++) {
uint32_t tile_idx = h * n_tiles_width + w;
uint32_t tile_idx2 = h * n_tiles_width + (w + n_tiles_width_active/2);
cb_reserve_back(cb_in0, 2);
uint32_t cb_src_addr = get_write_ptr(cb_in0);
noc_async_read_tile(tile_idx, src, cb_src_addr);
noc_async_read_tile(tile_idx2, src, cb_src_addr + tile_size_bytes);
noc_async_read_barrier();
cb_push_back(cb_in0, 2);
}
// Reads the passive tiles and forward to writer
for(uint32_t w = n_tiles_width_active; w < n_tiles_width; w++) {
uint32_t tile_idx = h * n_tiles_width + w;
cb_reserve_back(cb_bypass, 1);
noc_async_read_tile(tile_idx, src, get_write_ptr(cb_bypass));
noc_async_read_barrier();
cb_push_back(cb_bypass, 1);
}
}
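To double-check which tiles the reader sends where, its loop structure can be modeled on the host. This is a hypothetical sketch: schedule_row and TileSchedule are illustrative names, not Metalium APIs. It assumes the same row-major tile numbering as the kernel above, tile_idx = h * n_tiles_width + w. It also makes explicit an assumption the kernel relies on: the active width must be an even number of tiles (a multiple of 64 elements), so that each half consists of whole tiles.

```cpp
#include <cassert>
#include <cstdint>
#include <utility>
#include <vector>

// Hypothetical host-side model of the reader kernel's loops for tile row h.
struct TileSchedule {
    std::vector<std::pair<uint32_t, uint32_t>> active_pairs; // tile pairs pushed to cb_in0 (compute)
    std::vector<uint32_t> passive;                           // tiles pushed to cb_bypass (writer)
};

TileSchedule schedule_row(uint32_t h, uint32_t n_tiles_width_active, uint32_t n_tiles_width) {
    // The active region must split into two whole-tile halves
    assert(n_tiles_width_active % 2 == 0 && n_tiles_width_active <= n_tiles_width);
    TileSchedule s;
    const uint32_t half = n_tiles_width_active / 2;
    for (uint32_t w = 0; w < half; w++) {
        // Tile w of the first half is paired with tile w of the second half
        s.active_pairs.push_back({h * n_tiles_width + w, h * n_tiles_width + w + half});
    }
    for (uint32_t w = n_tiles_width_active; w < n_tiles_width; w++) {
        s.passive.push_back(h * n_tiles_width + w);
    }
    return s;
}
```

Consider the profiling configuration used later (D = 2048 and D_active = 256, so 64 tiles per row, 8 of them active). Row 0 yields the active pairs (0, 4), (1, 5), (2, 6) and (3, 7), plus 56 passive tiles, 8 through 63.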
Compute is where the complexity lives. The kernel itself is as straightforward and typical as they come: wait for 2 tiles, move them into Dst, do math, move them from Dst back into a Circular Buffer, loop.
uint32_t n_tiles_width_active = get_arg_val<uint32_t>(0);
uint32_t n_tiles_width = get_arg_val<uint32_t>(1);
uint32_t n_tiles_height = get_arg_val<uint32_t>(2);
constexpr uint32_t cb_in0 = tt::CBIndex::c_0;
constexpr uint32_t cb_out0 = tt::CBIndex::c_16;
init_sfpu(tt::CBIndex::c_0, tt::CBIndex::c_16);
for(uint32_t i = 0; i < n_tiles_height; i++) {
for(uint32_t j = 0; j < n_tiles_width_active/2; j++) {
cb_wait_front(cb_in0, 2);
tile_regs_acquire();
copy_tile_init(cb_in0);
copy_tile(cb_in0, 0, 0);
copy_tile(cb_in0, 1, 1);
MATH(rope_tile(1000, n_tiles_width_active*32, j*32));
tile_regs_commit();
tile_regs_wait();
cb_reserve_back(cb_out0, 2);
pack_reconfig_data_format(cb_out0);
pack_tile(0, cb_out0, 0);
pack_tile(1, cb_out0, 1);
tile_regs_release();
cb_push_back(cb_out0, 2);
cb_pop_front(cb_in0, 2);
}
}
The tricky part lies in the custom rope_tile function. Remember (and if you don’t, go read the documentation on Tiles - shame on you and your cow for skipping it) that each tile is made up of 4 faces, with each face being a 16x16 row-major matrix. rope_tile configures the SFPU - don’t ask me what exactly it does; I borrowed the setup from elsewhere, and it just works - and then runs the actual computation 4 times, once for each face.
inline void rope_tile(int pos, int D_active, int vec_offset)
{
math::set_dst_write_addr<DstTileLayout::Default, DstTileShape::Tile32x32>(0);
math::set_addr_mod_base();
TTI_STALLWAIT(p_stall::STALL_SFPU, p_stall::MATH);
for (int face = 0; face < 4; face++) {
rope_face(pos, D_active, vec_offset + ((face % 2 == 0) ? 0 : 16)); // real work happens here!
TTI_SETRWC(p_setrwc::CLR_NONE, p_setrwc::CR_D, 8, 0, 0, p_setrwc::SET_D);
TTI_SETRWC(p_setrwc::CLR_NONE, p_setrwc::CR_D, 8, 0, 0, p_setrwc::SET_D);
}
math::clear_dst_reg_addr();
TTI_STALLWAIT(p_stall::STALL_CFG, p_stall::WAIT_SFPU);
math::clear_addr_mod_base();
}
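The face layout that rope_tile walks through can be written out as a small host-side helper. This sketch assumes the tile layout from the tile documentation: four 16x16 row-major faces, numbered 0 = top-left, 1 = top-right, 2 = bottom-left and 3 = bottom-right. That numbering is why odd faces get the +16 column offset in the call to rope_face. The helper names are illustrative.

```cpp
#include <cassert>

struct TileCoord {
    int row;
    int col;
};

// Position of element (r, c) of face `face` inside the 32x32 tile.
// Faces are 16x16 and row-major: 0 = top-left, 1 = top-right, 2 = bottom-left, 3 = bottom-right.
TileCoord face_to_tile(int face, int r, int c) {
    assert(face >= 0 && face < 4 && r >= 0 && r < 16 && c >= 0 && c < 16);
    return {(face / 2) * 16 + r, (face % 2) * 16 + c};
}

// Column offset rope_tile adds for each face: odd faces cover columns 16..31.
int face_col_offset(int face) { return (face % 2 == 0) ? 0 : 16; }
```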
rope_face handles the actual computation. Implementing it is quite the journey. The first challenge is accessing different tiles stored via copy_tile. Tiles are spaced 32 dst_reg apart. The TTI commands in the caller function adjust the pointer so that each call begins at the start of a new face. To access the same chunk on another tile, simply add 32 to the current index.
Image: Register space of the Dst registers viewed from the SFPU
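The 32-entry spacing follows from the tile size. A face holds 16 x 16 = 256 values, which is 8 SFPU loads of 32 values, so a 32x32 tile spans 4 x 8 = 32 dst_reg entries. Below is a minimal sketch of that arithmetic. dst_reg_index is a hypothetical helper, indexing relative to the first tile copied into Dst. In the kernel, the face offset is applied by the SETRWC instructions in rope_tile, so rope_face only needs i for the first tile and i + 32 for the second.

```cpp
#include <cassert>

constexpr int kLanes = 32;                            // values per SFPU load
constexpr int kFaceValues = 16 * 16;                  // values per face
constexpr int kEntriesPerFace = kFaceValues / kLanes; // 8 dst_reg entries per face
constexpr int kEntriesPerTile = 4 * kEntriesPerFace;  // 32 dst_reg entries per tile

// dst_reg entry holding chunk `i` of face `face` of tile `tile`
int dst_reg_index(int tile, int face, int i) {
    assert(face >= 0 && face < 4 && i >= 0 && i < kEntriesPerFace);
    return tile * kEntriesPerTile + face * kEntriesPerFace + i;
}
```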
And how are the faces loaded into the SFPU? You might expect that, since the SFPU is 32 lanes wide and each face is 16 wide, each load would simply bring 2 rows into the SFPU. That is not the case. For some reason (well, there are legitimate reasons, but I will not get into them here), the load is interleaved in Dst.
Image: dst_reg is interleaved, loading 4 rows of 8 elements into the 32-wide LReg
Therefore, to calculate an element’s index within the vector (its column in the tile), we take the vConstTileId variable (which contains [0, 2, 4... 30]) and mod it by 16. We then add the column offset of the face, and add 1 if we are processing the odd lanes.
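That index calculation can be modeled on the host. This sketch assumes, as stated above, that vConstTileId holds 2 * lane in each of the 32 lanes. element_index is a hypothetical helper mirroring the expression used in rope_face below. Here "odd lanes" corresponds to odd dst_reg entries (i % 2 == 1).

```cpp
#include <cassert>

// Hypothetical host model of the element index computed in rope_face for one SFPU lane.
// vec_offset is the column offset of the tile pair (j * 32 in the compute kernel).
int element_index(int lane, int dst_reg_idx, int face, int vec_offset) {
    assert(lane >= 0 && lane < 32 && dst_reg_idx >= 0 && dst_reg_idx < 8);
    int tile_id = 2 * lane;                              // value of vConstTileId in this lane
    int col_in_face = (tile_id & 15) + dst_reg_idx % 2;  // & 15 acts as % 16 for non-negative values
    int face_offset = (face % 2 == 0) ? 0 : 16;          // odd faces hold columns 16..31
    return vec_offset + face_offset + col_in_face;
}
```

Across the 32 lanes, an even dst_reg entry covers each even column 0, 2, ..., 14 four times (once per row it holds), and an odd entry does the same for the odd columns.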
The standard calculations used in RoPE to arrive at the rotation angle are very inefficient on the SFPU. The SFPU is designed as a small but fast vector unit, and it lacks many complex function evaluation capabilities: no arbitrary pow(x, y), no division, just basic arithmetic. There are some built-in functions, but they are not designed to be part of the public interface.
The original formula freq = 1.0f / std::pow(10000.0f, exponent) has to be rewritten to avoid unsupported calculations:
- Replace the power with exp: freq = 1.0f / exp(exponent * log(10000.0f))
- Turn the reciprocal into a negative exponent: freq = exp(-exponent * log(10000.0f))
- Evaluate the log ahead of time: freq = exp(-exponent * 9.21034037f)
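The rewrite relies only on the identity 10000^e = exp(e * ln 10000). The following is a quick host-side check of the two forms, with ln(10000) = 4 ln(10) ≈ 9.21034037 hardcoded as in the kernel. The function names are illustrative.

```cpp
#include <cassert>
#include <cmath>

// Reference form used on the CPU: 1 / 10000^exponent
float freq_pow(float exponent) { return 1.0f / std::pow(10000.0f, exponent); }

// Rewritten form using only exp and a precomputed constant ln(10000)
float freq_exp(float exponent) {
    // In NeoX RoPE the exponent 2i / D_active lies in [0, 1)
    assert(exponent >= 0.0f && exponent < 1.0f);
    return std::exp(-exponent * 9.21034037f);
}
```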
The standard exp_tile API won’t work here, as it expects a full tile in Dst as its input. However, the public API internally calls ckernel::sfpu::_sfpu_exp2_21f_, which does accept an SFPU variable as a parameter. Likewise, sin_tile is not directly usable for the same reason, but ckernel::sfpu::calculate_sine is very close to what we need. We can yank it and modify it to suit our needs.
inline vFloat vector_sin(vFloat x)
{
// handles input not in [-pi, pi]
vFloat v = x * ckernel::sfpu::FRAC_1_PI;
vInt whole_v = float_to_int16(v, 0);
v -= int32_to_float(whole_v, 0);
v = ckernel::sfpu::sfpu_sinpi<false>(v);
v_if(whole_v & 1) { v = -v; }
v_endif;
return v;
}
inline void rope_face(int pos, int D_active, int vec_offset)
{
float inv_d = 1.f/D_active;
for (int i = 0; i < 8; i++) {
// No mod operator on SFPI, use bit hack
vFloat block_lane_id = int32_to_float((vConstTileId & 15) + vec_offset + i % 2);
vFloat exponent = 2.f * block_lane_id * inv_d;
vFloat term_to_exp = -exponent * 9.21034037f;
vFloat freq = ckernel::sfpu::_sfpu_exp2_21f_<true>(term_to_exp);
// Standard RoPE math
vFloat angle = int32_to_float(pos) * freq;
vFloat sin_value = vector_sin(angle);
vFloat cos_value = vector_sin(ckernel::sfpu::PI_2 - angle);
vFloat x = dst_reg[i];
vFloat y = dst_reg[i+32];
dst_reg[i] = x * cos_value - y * sin_value;
dst_reg[i+32] = x * sin_value + y * cos_value;
}
}
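The range reduction in vector_sin relies on the identity sin(pi * v) = (-1)^n * sin(pi * (v - n)), which holds for any integer n. Below is a host-side model of it. std::sin stands in for the device's sfpu_sinpi polynomial, and truncation is used for the integer part. The identity holds whichever rounding mode float_to_int16 uses, so the model does not depend on it. model_vector_sin is an illustrative name.

```cpp
#include <cassert>
#include <cmath>

// Host model of vector_sin: reduce x / pi to its fractional part, evaluate sin(pi * f),
// and flip the sign when the integer part is odd.
float model_vector_sin(float x) {
    assert(std::isfinite(x));                 // range reduction assumes a finite input
    const float kPi = 3.14159265f;
    float v = x * (1.0f / kPi);               // FRAC_1_PI in the kernel
    long whole = static_cast<long>(v);        // integer part, truncated toward zero
    float f = v - static_cast<float>(whole);  // fractional part in (-1, 1)
    float s = std::sin(kPi * f);              // stands in for sfpu_sinpi
    if (whole % 2 != 0) s = -s;               // sin(pi * (f + n)) = (-1)^n * sin(pi * f)
    return s;
}
```

Cosine comes from the same function as model_vector_sin(pi/2 - x), mirroring the PI_2 - angle trick in rope_face.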
Numeric accuracy
The implementation works - for a while. However, numerical errors become significant when moving from pos = 1 to pos = 1000. The values deviate wildly, with some even being the negative of what they should be! What the actual... it’s just a rotation of the original value, so how could this happen? The algorithm works perfectly for lower positions, and the logic does not change with the position value, so the issue has to be related to numerical precision.
Much debugging later, I pinpointed freq as the culprit. The exponent matched the CPU’s calculation exactly, bit for bit, but freq did not. Multiplying the small error in freq by 10000 and then feeding the result into sin() and cos() explained the issue: some values ended up as the negative of the reference, likely because they were half a period off, which inverts the trigonometric outputs. I dumped the variable into a CSV file, and the data confirmed my suspicion.
Here is some quick ROOT code to visualize what is going on. The following code is slightly truncated to fit in a blog post. I chose ROOT because I don’t want to deal with Pandas. Plus, I love Minuit2’s curve fitting and boundary handling capabilities; no need to worry about numerical edge cases.
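// analysis.cpp
// Invoke with: root analysis.cpp
// Results are kept at file scope so the plots outlive analysis()
ROOT::RDF::RResultPtr<TGraph> graph, graph_err, graph_err2;
std::shared_ptr<TF1> func_err, func_err2;
void analysis()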
{
auto* c1 = new TCanvas;
c1->Divide(2, 2);
c1->cd(1);
auto df = ROOT::RDF::FromCSV("outputdata.csv");
std::string col1 = "cpu_freq";
std::string col2 = "device_freq";
graph = df.Graph(col1, col2);
graph->SetMarkerStyle(3);
graph->Draw("ap");
double error = 0.0;
double max_error = 0.0;
TH1D* h1 = new TH1D("h1", "Error Distribution;|CPU value - Device value|;Counts", 100, 0, max_error);
df.Foreach([&](double cpu_angle, double device_angle) {
double err = std::abs(cpu_angle - device_angle);
error += err;
if(err != 0)
h1->Fill(err);
max_error = std::max(max_error, err);
}, {col1, col2});
c1->cd(2);
h1->Draw();
auto err_df = df.Define("error", [&](double a, double b) {
return std::abs(a - b);
}, {col1, col2});
graph_err = err_df.Graph(col1, "error");
graph_err->SetMarkerStyle(3);
c1->cd(3);
graph_err->Draw("ap");
func_err = std::make_shared<TF1>("func_err", "pol1", 0, 100);
graph_err->Fit(func_err.get());
func_err->SetLineColor(kRed);
func_err->Draw("same");
auto err_df2 = df.Define("error_normalized", [&](double a, double b) {
return std::abs(a - b) / a;
}, {col1, col2});
graph_err2 = err_df2.Graph(col1, "error_normalized");
graph_err2->SetMarkerStyle(3);
c1->cd(4);
graph_err2->Draw("ap");
func_err2 = std::make_shared<TF1>("func_err2", "pol1", 0, 100);
graph_err2->Fit(func_err2.get());
gPad->SetLogx();
func_err2->SetLineColor(kRed);
func_err2->Draw("same");
}
Running the above script yields the following plots. The bottom-left plot, titled error vs device_freq, shows the absolute error between what the device calculated and what the CPU calculated. This error ought to be as low as possible. Now the problem makes sense. Given that the max error is ~0.0018, multiplying it by 10000 gives 18. That is more than enough room for the sine and cosine to be half a period apart.
Image: Error plot of the builtin exp function in Metalium
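Here is the arithmetic behind that argument. The error in freq is multiplied by the position before sine and cosine see it. Once the resulting angle error approaches pi (half a period), the sign of the output can flip. Below is a minimal sketch using the numbers above; angle_error and half_periods are illustrative helpers.

```cpp
#include <cassert>
#include <cmath>

constexpr float kPi = 3.14159265f;

// Absolute error in the rotation angle caused by an absolute error in freq.
float angle_error(float freq_error, float position) {
    assert(position >= 0.0f);
    return freq_error * position;
}

// How many half-periods (pi radians) an angle error spans.
float half_periods(float angle_err) { return angle_err / kPi; }
```

An error of 0.0018 multiplied by 10000 gives 18 radians, or more than 5 half-periods. Being exactly one half-period off is already enough to negate a sine.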
I looked into the paper describing the exponential algorithm used by Metalium (I had to borrow someone’s IEEE account). exp_21f is one of the algorithms it proposes to approximate the exponential function. It strikes a balance between speed and “good enough” accuracy. Apparently, it’s not good enough for my use case. The exp_61f and exp_24f algorithms look quite nice in the table: their 50~100x better accuracy should reduce the maximum error going into the trigonometry from 18 down to 0.18.
Image: Table of accuracy of the proposed algorithms in the paper
I ended up using exp_24f. I might have implemented it wrong, but although exp_61f looks nice on paper, it does not behave well outside the [-1, 1] range, while exp_24f behaves very well even at x = -10.
Fun fact: somehow, in the paper, the author wrote:
float y,d1,d2,d3;
....
if(zif > 0x00200000)
{//second segment
d1 = 0.37120473e-7f;
d2 = 0x1113a74+zif;
d3 = 0x9f16+zif;
}
else
{//first segment
d1 = 0.31214472e-7f;
d2 = 0x151d842+zif;
d3 = 328.83582f+zif;
}
What? d3 is declared as a floating-point variable but is treated as an integer everywhere else (including in the code Metalium uses; see the 21f algorithm in the same paper). The aim is to manipulate the mantissa, so it makes sense for d3 to be an integer. However, through trial and error, it turns out the author actually intended these to be floating-point values, i.e. d3 = 328.83582f. The integer-like values above all fit nicely into floating-point representation. The following is the core of the 24f algorithm in SFPU code. It treats d2 and d3 as integers, which works well enough and allows me to save on SFPU instructions.
v_if(zif > 0x00600000) {
// Fourth segment (highest values of the mantissa)
POLY_D1 = 0.52496276e-7f;
POLY_D2 = 0x81354a;
POLY_D3 = 0x10a440;
}
v_elseif(zif > 0x00400000) {
// Third segment
POLY_D1 = 0.4414393e-7f;
POLY_D2 = 0xcdf4b4;
POLY_D3 = 0x3e4d6;
}
v_elseif(zif > 0x00200000) {
// Second segment
POLY_D1 = 0.37120473e-7f;
POLY_D2 = 0x1113a74;
POLY_D3 = 0x9f16;
}
v_else {
// First segment
POLY_D1 = 0.31214472e-7f;
POLY_D2 = 0x151d842;
// Note: The original C code has a float constant here
POLY_D3 = 328;
}
v_endif;
vFloat d1 = vFloat(POLY_D1);
vFloat d2 = int32_to_float(vInt(POLY_D2) + zif, 0);
vFloat d3 = int32_to_float(vInt(POLY_D3) + zif, 0);
Woohoo! exp_24f dramatically reduces the error compared to the builtin implementation. The max error is now somewhere in the 2x10^-5 range. It is so low that floating-point blockiness can be seen in the plot itself. However, the data was printed via std::to_string, which keeps only 6 digits after the decimal point, so part of that blockiness may come from the printing. See the following diagram for a visual comparison:
Image: Error plot of the exp_24f algorithm in Metalium
Switching to exp_24f solves the accuracy issue single-handedly. Now pos=1000 has a 100% pass rate, with a maximum error of 0.03 in the final result. Improvements could be made by using a better sin() approximation than the one Metalium uses, but I consider what is achieved here good enough.
Profiling & Base Optimization
The code works. Now the question is: how slow is my implementation, and how fast can I make it? Fortunately, we have tools to help. Tenstorrent has forked the Tracy profiler, an excellent tool for GPU applications. To use it, the BUILD_TRACY CMake flag must be enabled in the host Metalium build. Since I am developing on a remote machine over SSH, I also need a local copy of Tracy running on my machine.
First, the Metalium build you are using has to have Tracy enabled. This can be achieved in 2 ways: either add the flag manually to the CMake command, or pass the --enable-tracy flag to the build script.
# If you are using CMake
cd /path/to/your/tt-metal/build
cmake . -DBUILD_TRACY=ON
ninja install
# If you are using the build script
cd /path/to/your/tt-metal
./build_metal.sh --enable-tracy
To create a capture, start the capture tool tt-metal/build/tools/profiler/bin/capture-release, then start your application. The capture tool should connect once the application starts and save the trace information automatically.
Video: How to create a capture of a Metalium application
The GUI needs to be built separately. Note that prior to this blog post, the GUI didn’t build properly on Linux; I have upstreamed a patch to make it work.
Reference build commands:
git clone https://github.com/Tenstorrent/tracy.git
cd tracy
cd profiler/build/unix
make -j8 release
This should produce a Tracy-release binary in the build folder. Run it, and an empty window should show up. The GUI can load and display the capture created earlier, or connect to a running process and display the information on the fly.
Image: The default window of Tenstorrent’s Tracy fork
IMPORTANT: The program being profiled acts as the server, and the capture tool (or GUI) acts as the client. For remote development, create a forward tunnel to port 8086 on the remote machine. I use the following SSH command to connect the Tracy GUI to the program being profiled. Normally the Tracy GUI waits for the server to come online. Through an SSH tunnel, however, the connection is rejected if the application is not running at the time of connection, so you must connect after starting the application.
I highly recommend adding the -C flag (compression), since the data stream from Tracy is very compressible.
ssh -C -NL 8086:localhost:8086 user@remote-machine
Or use a VPN that puts both machines under the same network. There’s a million guides out there.
To profile device kernels, we must set the TT_METAL_DEVICE_PROFILER environment variable. Without it, kernel profiling becomes a no-op and incurs no overhead. Then add a DeviceZoneScopedN object to the region of interest. The profiler records the time between the object’s construction and destruction.
// in device kernel
inline void rope_tile(int pos, int D_active, int vec_offset)
{
DeviceZoneScopedN("ROPE-TILE"); // insert this to create a profiling region
// rest of the SFPI code
math::set_dst_write_addr<DstTileLayout::Default, DstTileShape::Tile32x32>(0);
math::set_addr_mod_base();
...
}
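DeviceZoneScopedN follows the usual scope-based (RAII) profiling pattern. The following host-side analogue illustrates that pattern with std::chrono. It is not how the device profiler is implemented, and ScopedZone and ZoneRecord are hypothetical names.

```cpp
#include <cassert>
#include <chrono>
#include <string>
#include <utility>
#include <vector>

struct ZoneRecord {
    std::string name;
    long long micros;
};

// Records the time between construction and destruction into `sink`.
class ScopedZone {
public:
    ScopedZone(std::string name, std::vector<ZoneRecord>& sink)
        : name_(std::move(name)), sink_(sink), start_(std::chrono::steady_clock::now()) {
        assert(!name_.empty());
    }
    ~ScopedZone() {
        auto elapsed = std::chrono::steady_clock::now() - start_;
        long long us = std::chrono::duration_cast<std::chrono::microseconds>(elapsed).count();
        sink_.push_back({name_, us});
    }
    ScopedZone(const ScopedZone&) = delete;
    ScopedZone& operator=(const ScopedZone&) = delete;

private:
    std::string name_;
    std::vector<ZoneRecord>& sink_;
    std::chrono::steady_clock::time_point start_;
};
```

Nothing is recorded while the zone is open. The measurement is taken when the enclosing scope ends, which is why placing the object at the top of rope_tile covers the whole function.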
Either connect to a running program or load a captured trace. The window should now display the program’s runtime information.
Image: Tracy showing profiled information
Zoom into the top left, where the kernel timing is. We can see that the compute kernel (to be precise, the SFPI code) takes 44us to process 4 pairs of tiles, or 11us per pair. That is SLOW, but not particularly unexpected. In rope_face, the line float inv_d = 1.f/D_active performs a floating-point division. The RISC-V cores do not support this in hardware, so it has to be done in software (softfp), which is slow. It currently happens once per face, i.e. 4 times per tile.
Image: The profiled result of the baseline kernel (with N = 32, D = 2048, D_active=256)
IMPORTANT: The overhead of profiling is nontrivial. You should not take microbenchmarking results and consider them the exact cycle count needed for the operation. That said, the delta in cycle count before and after code changes is still very useful.
Hoisting the floating point division
That division is the most obvious inefficiency and the easiest to address. Since parameters can be passed around like in any C program, and the D parameter is a single scalar (not a tensor), we can move the division to the beginning of the kernel. This allows the result to be shared across all pairs of tiles.
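namespace NAMESPACE {
void MAIN {
...
// Computed once per kernel launch: 1 / D_active, with D_active = n_tiles_width_active * 32
float inv_d = 1.f/(n_tiles_width_active * 32);
...
for(...) {
...
// during invocation: rope_tile (and rope_face) now take the precomputed inv_d instead of D_active
MATH(rope_tile(1000, inv_d, j*32));
}
...
}
}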
Re-profiling shows that the overall kernel execution time dropped by 5us, from 44us to 39us. Pretty good for a simple optimization!
Image: Hoisting the division reduces kernel execution time down to 39us
Reuse exponentiation
For accuracy, we implemented a custom exponential function. Our new implementation is much larger than the default one, and judging by the sheer amount of code that forms it, it is bound to be slow. Luckily, every element in a given column of a tile (and thus of each face) has the same index in the input vector. We can reuse everything up to the point where the rows diverge. Keep in mind that we still need to support a different token position per row, as the GGML API asks.
for (int h = 0; h < 2; h++) {
// No mod operator on SFPI, use bit hack
vFloat block_lane_id = int32_to_float((vConstTileId & 15) + vec_offset + h);
vFloat exponent = 2.f * block_lane_id * inv_d;
vFloat term_to_exp = -exponent * 9.21034037f;
vFloat freq = vector_exp(term_to_exp);
for (int i = 0; i < 4; i++) {
// Standard RoPE math
vFloat angle = int32_to_float(pos) * freq;
vFloat sin_value = vector_sin(angle);
vFloat cos_value = vector_sin(ckernel::sfpu::PI_2 - angle);
int idx = i*2+h;
vFloat x = dst_reg[idx];
vFloat y = dst_reg[idx+32];
dst_reg[idx] = x * cos_value - y * sin_value;
dst_reg[idx+32] = x * sin_value + y * cos_value;
}
}
A whopping 17us saved!
Image: Reusing the result of the exponential function shaves 17us from execution time
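The reuse is safe for the following reason. In rope_face, the element index is (vConstTileId & 15) + vec_offset + i % 2, so for a given lane it depends only on the parity of the dst_reg index. The host-side check below confirms two things about the restructured loops. First, they still touch every dst_reg entry of a face exactly once. Second, the inner index always has the parity h that the frequency was computed for, which cuts the exp evaluations per face from 8 to 2. visit_counts is a hypothetical helper.

```cpp
#include <array>
#include <cassert>

// Mirrors the loop nest of the reuse-exponentiation rope_face:
// the outer loop over h computes freq once, the inner loop touches dst_reg[i*2 + h].
std::array<int, 8> visit_counts() {
    std::array<int, 8> counts{};
    for (int h = 0; h < 2; h++) {
        for (int i = 0; i < 4; i++) {
            int idx = i * 2 + h;
            assert(idx % 2 == h);  // parity selects the column, which is all freq depends on
            counts[idx]++;
        }
    }
    return counts;
}

constexpr int exp_calls_before = 8;  // one exp per dst_reg entry in the original loop
constexpr int exp_calls_after = 2;   // one exp per column parity h
```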
Numerical tricks
The next optimization concerns the sine function. It is most likely only a few cycles faster, but in HPC even one cycle is a lot. vector_sin always multiplies the input value by 1/pi in order to convert the expected range from [-pi, pi] to [-1, 1]. It then applies the antisymmetric property of sine to further reduce the range down to [0, 1], which the polynomial approximation uses. That’s a load and a multiplication we could avoid on every invocation of the function. Let’s turn vector_sin into vector_sin_phase, provided we can amortize that multiplication away earlier in the calculation:
inline vFloat vector_sin_phase(vFloat x)
{
// was
// vFloat v = x * ckernel::sfpu::FRAC_1_PI;
vFloat v = x;
vInt whole_v = float_to_int16(v, 0);
v -= int32_to_float(whole_v, 0);
v = ckernel::sfpu::sfpu_sinpi<false>(v);
v_if(whole_v & 1) { v = -v; }
v_endif;
return v;
}
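A host-side model of the range reduction may help. This is plain C++ with std::sin standing in for sfpu_sinpi (an illustrative substitution, not the SFPU polynomial): the input is already in units of pi, the integer part only decides the sign, and the fractional part goes through the core sine. Because sin(pi * (n + f)) = (-1)^n * sin(pi * f) for any integer n, truncation toward zero is as good a split as any.

```cpp
#include <cassert>
#include <cmath>

constexpr double kPi = 3.14159265358979323846;

// Model of vector_sin_phase: v is the angle divided by pi.
double sin_phase_model(double v) {
    double whole = std::trunc(v);          // integer part (any integer split works)
    double frac = v - whole;               // remaining fraction of a half period
    double s = std::sin(kPi * frac);       // stand-in for sfpu_sinpi
    long n = static_cast<long>(whole);
    return (n % 2 != 0) ? -s : s;          // odd half periods flip the sign
}
```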
Let's review the current logic feeding the sine function, keeping in mind that there is now an implicit division by pi:
// current logic
vFloat term_to_exp = -exponent * 9.21034037f;
vFloat freq = vector_exp(term_to_exp);
vFloat angle = int32_to_float(pos) * freq;
vFloat sin_value = vector_sin(angle);
vFloat cos_value = vector_sin(ckernel::sfpu::PI_2 - angle);
// What we need to turn it into
vFloat sin_value = vector_sin_phase(angle * ckernel::sfpu::FRAC_1_PI);
vFloat cos_value = vector_sin_phase((ckernel::sfpu::PI_2 - angle) * ckernel::sfpu::FRAC_1_PI);
The PI_2 term becomes 0.5, and angle gets multiplied by 1/pi. We can keep pushing that multiplication upstream and rewrite the equations step by step:
- We need a new angle_phase = angle/pi
- Expanding that: angle_phase = int32_to_float(pos) * freq / pi
- Evaluate the final two terms together: freq/pi = vector_exp(term_to_exp) / pi
- Move the division into the exponentiation: freq/pi = vector_exp(term_to_exp - ln(pi))
- Finally, expand the term being exponentiated: term_to_exp - ln(pi) = -exponent * 9.21034037f - ln(pi)
- As ln(pi) is a constant: new_term_to_exp = -exponent * 9.21034037f - 1.14472988585f
- Also, PI_2 / PI is 1/2, or 0.5
The core numerical calculations become
vFloat term_to_exp = -exponent * 9.21034037f - 1.14472988585f;
vFloat freq = vector_exp(term_to_exp);
vFloat angle = int32_to_float(pos) * freq;
vFloat sin_value = vector_sin_phase(angle);
vFloat cos_value = vector_sin_phase(0.5f - angle);
This optimization reduces execution time by another 1.13us.
Image: Using sine in the phase form reduces compute execution time by 1.13us
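The derivation can be checked numerically on the host. The sketch below (plain C++ in double precision, with std::sin as the sine, not the SFPU code) compares the original angle pipeline against the folded one, where ln(pi) is subtracted inside the exponent and the cosine comes from the 0.5 - angle_phase trick. The sample positions, inv_d and tolerance are arbitrary choices for this check.

```cpp
#include <cassert>
#include <cmath>

constexpr double kPi = 3.14159265358979323846;

struct SinCos { double s, c; };

// Original: angle = pos * exp(-(2i/d) * ln 10000); cos via sin(pi/2 - angle).
SinCos rope_sincos_original(int pos, int i, double inv_d) {
    double exponent = 2.0 * i * inv_d;
    double freq = std::exp(-exponent * 9.21034037);
    double angle = pos * freq;
    return {std::sin(angle), std::sin(kPi / 2 - angle)};
}

// Folded: ln(pi) moves into the exponent, so angle_phase = angle / pi,
// and the cosine argument becomes 0.5 - angle_phase (in units of pi).
SinCos rope_sincos_phase(int pos, int i, double inv_d) {
    double exponent = 2.0 * i * inv_d;
    double freq_over_pi = std::exp(-exponent * 9.21034037 - 1.14472988585);
    double angle_phase = pos * freq_over_pi;
    return {std::sin(kPi * angle_phase), std::sin(kPi * (0.5 - angle_phase))};
}
```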
Using constant registers
Finally, the Low Level Kernels document on the official documentation site lists several constants that are backed by hardware registers. Among them, vConst{Float,Int}Prgm[0-2] can be set at runtime to hold a single floating-point or integer value (e.g. vConstFloatPrgm0 = 3.1415f or vConstIntPrgm1 = 1000).
Constant registers are implemented as objects which can be referenced wherever a vector can be used. On Wormhole and Blackhole the following variables are defined:
- vConst0
- vConst1
- vConst0p8373
- vConstNeg1
- vConstTileId, counts by two through the vector elements: [0, 2, 4..62]
- vConstFloatPrgm0, vConstIntPrgm0
- vConstFloatPrgm1, vConstIntPrgm1
- vConstFloatPrgm2, vConstIntPrgm2
Let's use them, then. The kernel happens to use several constants, and keeping them in these registers removes the need to load them during kernel execution, which should make it slightly faster.
inline void rope_tile_init(float inv_d)
{
vConstFloatPrgm0 = 9.21034037f;
vConstFloatPrgm1 = 1.14472988585f;
vConstFloatPrgm2 = inv_d;
}
inline void rope_face(int pos, float inv_d, int vec_offset)
{
...
for (int h = 0; h < 2; h++) {
vFloat block_lane_id = int32_to_float((vConstTileId & 15) + vec_offset + h);
vFloat exponent = 2.f * block_lane_id * vConstFloatPrgm2; // Use them like any other variables
vFloat term_to_exp = -exponent * vConstFloatPrgm0 - vConstFloatPrgm1; // Use them like any other variables
vFloat freq = vector_exp(term_to_exp);
...
}
}
Another 0.5us faster!
Image: Programmable constants avoid loading values with instructions, 0.5us faster
Supporting per-row token position
2x faster is quite nice, and there are more tricks we could pull to maybe drop the execution time by another 30%. However, we have to stop here: we haven't fulfilled the GGML API yet. Currently we share the same token position across the entire tile, but GGML supports per-row positions.
NOTE: I made a blunder of universal scale here. Due to many factors (including me being stupid), I didn't notice that GGML uses per-**batch** positions, which sent me down the road of optimizing per-row data movement and loading. I only realized my mistake when I tried integrating with GGML.
The goal, then, is for rope_face to stop loading a single integer as the token position, like we have been doing:
vFloat angle = int32_to_float(pos) * freq;
and instead load 4 integers into different rows of the vFloat variable:
vFloat vpos = int32_to_float(load_into_row(pos_ptr+offset));
vFloat angle = vpos * freq;
To do so, we must read the position indices from DRAM, store them in local SRAM in linear order, and pass that pointer to the compute cores. This presents a problem: so far, every operation we discussed works at the tile level; everything is a tile. Position indices are not. They are raw vectors. Whatever should we do...
TTNN supports row-major layouts too. See how they are handled in the official documentation (you'll find the document very familiar if you are a reader of this blog).
To store row-major tensors, TTNN splits the tensor into rows and distributes them round-robin across memory controllers, like so:
Image: A row major tensor of shape (48, 1024)
The standard noc_async_read_page API will happily read the rows into SRAM for us. But we ought to read only 32 elements at a time, just in case someone decides to send in a 50K-word prompt that turns into a 200KB buffer, let alone the fact that we need 2 of them to double buffer. And we only have 1.5MB of total SRAM per Tensix core. Oh well, hacking time.
Instead of reading whole pages, we manually calculate the addresses and read 128 bytes (32 four-byte ints) at a time. On the host side, we create a buffer with the same scheme TTNN would use, plus a circular buffer holding 32 integers at a time. Note that in this particular case, as there's only one row of indices, the page size and the total size of the index buffer are the same; this will change when we support batching.
auto idxs = MakeBuffer(device,
/*total_size=*/N*sizeof(int32_t),
/*page_size=*/N*sizeof(int32_t),
/*sram=*/false);
MakeCircularBuffer(program, core, tt::CBIndex::c_1,
/*page_size=*/ 32*sizeof(int32_t), // one page = one chunk of 32 indices
/*total_size=*/ 32*sizeof(int32_t),
/*dtype=*/tt::DataFormat::Int32);
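The SRAM budget argument is easy to put in numbers. A small sketch, assuming 4-byte indices and two buffers for double buffering as described above; the helper names are illustrative only. chunk_offset mirrors the byte offset the reader kernel computes for tile row h.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

constexpr std::size_t kIdxPerChunk = 32;  // one tile row worth of positions

// CB bytes if the reader pulls the whole index page (all N positions) at once.
constexpr std::size_t whole_page_cb_bytes(std::size_t n_tokens, std::size_t n_buffers) {
    return n_tokens * sizeof(int32_t) * n_buffers;
}

// CB bytes if it pulls 32 positions (128 bytes) per tile row instead.
constexpr std::size_t chunked_cb_bytes(std::size_t n_buffers) {
    return kIdxPerChunk * sizeof(int32_t) * n_buffers;
}

// Byte offset into the single index page for tile row h.
constexpr std::size_t chunk_offset(std::size_t h) {
    return kIdxPerChunk * sizeof(int32_t) * h;
}
```

With a 50K-token prompt, double-buffering whole pages costs 400,000 bytes of the 1.5MB, while chunked reads need 256 bytes regardless of prompt length.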
In the reader kernel, subregions of the index buffer are accessed through the same TensorAccessor object, but we manually set the offset and read size when reading from DRAM.
constexpr auto idx_args = TensorAccessorArgs<src_args.next_compile_time_args_offset()>();
// Cannot rely on get_tile_size() since we are not reading a tile at a time. Manual
// calculation of page size (or just pass in as kernel parameter).
uint32_t idx_page_size = n_tiles_height*32*sizeof(int32_t);
const auto idx = TensorAccessor(idx_args, idx_addr, idx_page_size);
constexpr uint32_t cb_in1 = tt::CBIndex::c_1; // where we will be pushing index data
for(uint32_t h = 0; h < n_tiles_height; h++) {
cb_reserve_back(cb_in1, 1);
uint32_t cb_idx_addr = get_write_ptr(cb_in1);
// Recall we only have one row of indices now. So always on page 0
// v
uint64_t read_addr = idx.get_noc_addr(0, /*offset=*/32*sizeof(int)*h);
// Read 32 integers at a time
noc_async_read(read_addr, cb_idx_addr, 32*sizeof(int));
noc_async_read_barrier();
cb_push_back(cb_in1, 1);
...
// Rest of the reader kernel
}
Now, in the compute kernel, we don't use get_read_ptr to access the SRAM address. Instead, we use cb_get_tile; this API also handles CB synchronization between the 3 compute cores.
for(uint32_t i = 0; i < n_tiles_height; i++) {
// Read the address of the CB pushed by the reader kernel
cb_wait_front(cb_in1, 1);
int* idxs = nullptr;
cb_get_tile(cb_in1, 0, &idxs);
// IDK why, but documented to need shift of 16 bytes
idxs += 4;
// Rest of the compute loop for RoPE
...
for(uint32_t j = 0; j < n_tiles_width_active/2; j++) {
...
// Pass the indices pointer to the SFPI function
MATH(rope_tile(idxs, inv_d, j*32));
...
}
cb_pop_front(cb_in1, 1);
}
Pass it down the chain of compute calls, applying the correct offset so each face gets the right pointer.
inline void rope_tile(int* pos, float inv_d, int vec_offset)
{
...
for (int face = 0; face < 4; face++) {
// Where are we along the vector
int internal_offset = ((face % 2 == 0) ? 0 : 16);
// Which index we should read from
int idx_offset = face > 1 ? 16 : 0;
rope_face(pos + idx_offset, inv_d, vec_offset + internal_offset);
TTI_SETRWC(p_setrwc::CLR_NONE, p_setrwc::CR_D, 8, 0, 0, p_setrwc::SET_D);
TTI_SETRWC(p_setrwc::CLR_NONE, p_setrwc::CR_D, 8, 0, 0, p_setrwc::SET_D);
}
...
}
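The two offsets in rope_tile are easy to mix up, so here they are in isolation. This sketch assumes the usual face order inside a 32x32 tile (0 top-left, 1 top-right, 2 bottom-left, 3 bottom-right), which is what the face % 2 and face > 1 tests imply. The column offset picks which 16 frequency lanes a face covers; the position offset picks which 16 token positions (rows) it needs.

```cpp
#include <cassert>

struct FaceOffsets {
    int column_offset;    // internal_offset: which 16 of the 32 frequency lanes
    int position_offset;  // idx_offset: which 16 of the 32 token positions (rows)
};

// Face order assumed: 0 top-left, 1 top-right, 2 bottom-left, 3 bottom-right.
constexpr FaceOffsets face_offsets(int face) {
    return FaceOffsets{(face % 2 == 0) ? 0 : 16, face > 1 ? 16 : 0};
}
```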
Now each face has an int* pointer to its position values in SRAM. The custom load_into_row function takes a pointer, loads the first 4 values, and returns a vInt. Bit tricks again: as there is no division operation available, we mask away the lower 4 bits and check the remaining bits to determine which row to load.
inline vInt load_into_row(int* ptr)
{
// Recall vConstTileId:
// contains [0, 2, 4, 6, 8, 10, 12, 14, 16, 18 ... 62]
// By ANDing it with ~15 (bitwise inverse of 15) we mask away the lower 4 bits. Thus
// remains [0, 0, 0, 0, 0, 0, 0, 0, 16, 16, ... 48]
// Which we can then check the remaining bits to determine which row to load.
vInt row_mask = vConstTileId & (~15);
vInt v = ptr[0]; // Unconditional load (avoid one comparison)
v_if(row_mask == 16) { // Load from the next row
v = ptr[1];
}
v_elseif(row_mask == 32) { // Load from row 2
v = ptr[2];
}
v_elseif(row_mask == 48) { // Load from row 3
v = ptr[3];
}
v_endif;
return v;
}
inline void rope_face(int* pos, float inv_d, int vec_offset)
{
...
for(...) {
vFloat vpos = int32_to_float(load_into_row(pos+i*4));
vFloat angle_phase = vpos * freq;
...
}
}
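Here is a host-side emulation of load_into_row, written only to illustrate the masking trick. It assumes 32 SFPU lanes with vConstTileId = 2 * lane, as in the documentation quoted earlier, and models the v_if/v_elseif chain with ordinary if statements. Every group of 8 lanes ends up holding one of the 4 positions.

```cpp
#include <array>
#include <cassert>

std::array<int, 32> load_into_row_emulated(const int* ptr) {
    std::array<int, 32> v{};
    for (int lane = 0; lane < 32; lane++) {
        int tile_id = 2 * lane;         // vConstTileId for this lane: 0, 2, ..., 62
        int row_mask = tile_id & ~15;   // clear the low 4 bits: 0, 16, 32 or 48
        int value = ptr[0];             // unconditional load, as in the kernel
        if (row_mask == 16) value = ptr[1];
        else if (row_mask == 32) value = ptr[2];
        else if (row_mask == 48) value = ptr[3];
        v[lane] = value;
    }
    return v;
}
```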
It’s a long path to getting independent indices. Unfortunately now we are back to 32us to process 4 pairs of tiles. Yikes.
Image: Back to 32us with added per row handling
Optimization: Abusing Dst as data storage
It needs some optimization. I asked around and even looked at the ISA documentation. I couldn’t come up with a better solution than what I implemented to load integers into rows. Fine, I’ll have to amortize the cost of the loads. Currently, the token positions are loaded for each pair of vectors to perform RoPE, which doesn’t make much sense since the entire face shares the same set of 16 positions. Furthermore, the entire row of tiles shares the same 32 token positions. If only there were a way to persist vector values across multiple iterations (NOTE: SFPI does not support spilling to SRAM, and SFPU can only communicate with SRAM using the packer).
Image: Position indices are shared across face and tiles
There is one place we could store data: the Dst registers. However, there are some very important caveats:
- Position indices can reach values in the millions, necessitating the use of FP32 in Dst (as neither bfloat16 nor int16 can represent values in this range). This requires setting fp32_dest_acc_en to true in the compute kernel configuration.
- For performance reasons, each time tile_regs_acquire is called, only half of the actual Dst register is made available. This means we need to generate the position indices twice.
- Our kernel currently auto-increments dst_reg using TTI_SETRWC. This behavior needs to be disabled.
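To make the first caveat concrete, the sketch below round-trips integer positions through bfloat16 (emulated by keeping the top 16 bits of an FP32 value with round-to-nearest-even, ignoring NaN handling) and through FP32. bfloat16 has only 8 significant bits, so it stops being exact after 256, while FP32 stays exact up to 2^24 = 16,777,216; int16 simply tops out at 32,767.

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>

// Emulated FP32 -> bfloat16 -> FP32 round trip (round-to-nearest-even).
float to_bf16_and_back(float f) {
    uint32_t bits;
    std::memcpy(&bits, &f, sizeof bits);
    uint32_t rounding = 0x7FFFu + ((bits >> 16) & 1u);
    bits = (bits + rounding) & 0xFFFF0000u;  // keep sign, exponent, 7 mantissa bits
    float out;
    std::memcpy(&out, &bits, sizeof out);
    return out;
}

bool exact_in_fp32(int32_t pos) {
    return static_cast<int32_t>(static_cast<float>(pos)) == pos;
}

bool exact_in_bf16(int32_t pos) {
    return static_cast<int32_t>(to_bf16_and_back(static_cast<float>(pos))) == pos;
}
```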
Let's do that. First, a function to generate and store the position values for us. Since tiles 0 and 1 are currently used for data input/output, the position indices will be stored in tile 2. And since each tile occupies 32 dst_reg vectors, tile 2 starts at offset 64.
inline void rope_tile_precompute_pos(int* pos)
{
DeviceZoneScopedN("ROPE-TILE-PRECOMP-POS");
math::set_dst_write_addr<DstTileLayout::Default, DstTileShape::Tile32x32>(0);
math::set_addr_mod_base();
TTI_STALLWAIT(p_stall::STALL_SFPU, p_stall::MATH);
for (int i=0;i<8;i++) {
vFloat vpos = int32_to_float(load_into_row(pos+i*4));
dst_reg[64+i] = vpos;
}
math::clear_dst_reg_addr();
TTI_STALLWAIT(p_stall::STALL_CFG, p_stall::WAIT_SFPU);
math::clear_addr_mod_base();
}
Next, we have to disable the automatic dst_reg increment, which is meant to help operator writers keep track of which face they are on. Invoking rope_face now looks like this:
for (int face = 0; face < 4; face++) {
int internal_offset = ((face % 2 == 0) ? 0 : 16);
int idx_offset = face > 1 ? 16 : 0;
// Now we had to tell rope_face which face it is processing now
// vvvv
rope_face(pos + idx_offset, inv_d, vec_offset + internal_offset, face);
// These are deleted
// TTI_SETRWC(p_setrwc::CLR_NONE, p_setrwc::CR_D, 8, 0, 0, p_setrwc::SET_D);
// TTI_SETRWC(p_setrwc::CLR_NONE, p_setrwc::CR_D, 8, 0, 0, p_setrwc::SET_D);
}
And in rope_face, two things have to change: dst_reg is accessed with a manual offset, and instead of loading the integers from SRAM, we load the precomputed positions directly from Dst.
int face_row = face_idx / 2; // 0 for the top row of faces, 1 for the bottom row; selects which positions to load
int dst_offset = face_idx*8; // Auto dst_reg increment is disabled. Calculate the offset of the current face
for (int h = 0; h < 2; h++) {
vFloat block_lane_id = int32_to_float((vConstTileId & 15) + (vec_offset + h));
vFloat exponent = block_lane_id * vConstFloatPrgm2;
vFloat term_to_exp = -exponent * vConstFloatPrgm0 - vConstFloatPrgm1;
vFloat freq = vector_exp(term_to_exp);
for (int i = 0; i < 4; i++) {
vFloat vpos = dst_reg[64+face_row*4+i]; // Load the precomputed position values
// Same RoPE as before
vFloat angle_phase = vpos * freq;
vFloat sin_value = vector_sin_phase(angle_phase);
vFloat cos_value = vector_sin_phase(0.5f - angle_phase);
size_t idx = i*2+h;
// Auto dst_reg increment is disabled. `dst_offset` is added to manually
// calculate where to read and write
vFloat x = dst_reg[dst_offset+idx];
vFloat y = dst_reg[dst_offset+idx+32];
dst_reg[dst_offset+idx] = x * cos_value - y * sin_value;
dst_reg[dst_offset+idx+32] = x * sin_value + y * cos_value;
}
}
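With auto-increment gone, every dst_reg index is computed by hand, so the layout is worth sanity-checking. The helper names below are made up for this sketch; they restate the index arithmetic from rope_face and rope_tile_precompute_pos in plain C++ and check that the x and y accesses of the 4 faces cover dst_reg[0..63] exactly once, while the positions sit in tile 2.

```cpp
#include <cassert>
#include <set>

// Tile 0 (x) is dst_reg[0..31], tile 1 (y) is dst_reg[32..63].
constexpr int x_index(int face, int i, int h) { return face * 8 + i * 2 + h; }
constexpr int y_index(int face, int i, int h) { return x_index(face, i, h) + 32; }

// Precomputed positions live in tile 2: 4 vectors per row of faces.
constexpr int pos_index(int face, int i) { return 64 + (face / 2) * 4 + i; }
```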
Finally, the precomputation has to be invoked during the first 2 iterations of each row in the main kernel, to populate both halves of the double-buffered Dst with the precomputed position values.
for(uint32_t j = 0; j < n_tiles_width_active/2; j++) {
cb_wait_front(cb_in0, 2);
tile_regs_acquire();
if(j<2) {
MATH(rope_tile_precompute_pos(idxs));
}
...
// Rest is the same
}
Phew... That's a lot of small changes. And we are back down to 22.6us.
Image: Down to a more reasonable speed amortizing the cost of vector load
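The "precompute only when j < 2" rule relies on how the Dst halves are handed out. Below is a toy model under an explicit assumption (not taken from TT-Metal documentation): successive iterations of a tile row alternate between the two halves, and values written into a half survive until that half is handed out again. Under that assumption, precomputing on the first two iterations covers both halves for the rest of the row.

```cpp
#include <cassert>
#include <vector>

// Returns, per iteration j, whether the position slots in the Dst half
// handed to that iteration (assumed to be j % 2) were already populated.
std::vector<bool> positions_ready(int n_iters, int precompute_first_n) {
    bool populated[2] = {false, false};
    std::vector<bool> ready;
    for (int j = 0; j < n_iters; j++) {
        int half = j % 2;
        if (j < precompute_first_n) {
            populated[half] = true;  // stands in for rope_tile_precompute_pos(idxs)
        }
        ready.push_back(populated[half]);
    }
    return ready;
}
```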
Reusing exponent within a tile
In the same spirit of reusing token positions, the exponent is also consistent within the same column of a face within a tile. We have already used this fact within a face to reduce the number of exponential operations. By extending this optimization across faces, we can further enhance performance.
Image: Exponents are the same on the same column
NOTE: Similar to reusing token positions across tiles, it is possible to reuse the exponents across tiles as well. However, reusing exponents potentially requires significantly more space than reusing token positions. Each tile needs 4 vectors of exponents. Since there are only 2 free tiles (or 64 free vectors) in Dst and we already use 8 vectors to hold positions, this leaves 64 - 8 = 56 vectors available for exponents. This allows for at most 896 active dimensions to apply RoPE to. This is sufficient for LLMs; for example, Gemma 2 2B uses 128 active dimensions. While it is possible to work around this limitation with additional logic, in the name of not losing generality I will avoid reusing exponents across tiles.
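The 896 figure in the note follows from simple arithmetic. A sketch of it, using the note's numbers and assuming each rope_tile call (one x/y tile pair, 32 frequency lanes) covers 64 rotary dimensions, which is what makes the figure come out:

```cpp
#include <cassert>

constexpr int kFreeVectors = 2 * 32;  // 2 free tiles of 32 vectors each
constexpr int kPositionVectors = 8;   // already holding token positions
constexpr int kExpVectorsPerTile = 4; // exponent vectors per tile pair
constexpr int kDimsPerTile = 64;      // rotary dims per tile pair (assumption)

constexpr int cached_tile_pairs() {
    return (kFreeVectors - kPositionVectors) / kExpVectorsPerTile;
}

constexpr int max_cached_active_dims() {
    return cached_tile_pairs() * kDimsPerTile;
}
```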
Before invoking the actual RoPE computation, we generate the exponents and store them in Dst. Since the first 8 vectors of tile 2 hold the token positions, the exponents are stored starting from vector 8 (dst_reg offset 64+8).
// Generate the exponents for the current tile and store them in Dst
for(int i=0;i<4;i++) {
int internal_offset = ((i / 2 == 0) ? 0 : 16);
int pos_in_vector = vec_offset + internal_offset;
vFloat block_lane_id = int32_to_float((vConstTileId & 15) + (pos_in_vector + i % 2));
vFloat exponent = block_lane_id * vConstFloatPrgm2;
vFloat term_to_exp = -exponent * vConstFloatPrgm0 - vConstFloatPrgm1;
vFloat freq = vector_exp(term_to_exp);
dst_reg[64+8+i] = freq;
}
// The same rope_face invocation
for (int face = 0; face < 4; face++) {
int internal_offset = ((face % 2 == 0) ? 0 : 16);
int idx_offset = face > 1 ? 16 : 0;
rope_face(pos + idx_offset, inv_d, vec_offset + internal_offset, face);
}
And just load the exponents from Dst in the compute loop.
for (int h = 0; h < 2; h++) {
vFloat freq = dst_reg[64+8+face_col*2+h];
...
// Rest of RoPE
}
Image: Sharing of exponent
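The generation loop and the lookup in rope_face have to agree on which Dst slot holds which column's frequency. The check below restates both index formulas in plain C++. face_col is not defined in the excerpt, so it is assumed here to be face % 2, matching internal_offset. It also counts exponentials per tile pair: 8 when each face computes its own (2 per face), 4 when they are shared.

```cpp
#include <cassert>

// Lane offset used when generating exponent slot i: internal_offset + (i % 2).
constexpr int generated_lane_offset(int slot) {
    return ((slot / 2 == 0) ? 0 : 16) + slot % 2;
}

// Lane offset rope_face needs for a given face and h: internal_offset + h.
constexpr int needed_lane_offset(int face, int h) {
    return ((face % 2 == 0) ? 0 : 16) + h;
}

// Slot read by rope_face: dst_reg[64 + 8 + face_col * 2 + h], face_col assumed face % 2.
constexpr int slot_for(int face, int h) { return (face % 2) * 2 + h; }

constexpr int exp_calls_per_face() { return 4 * 2; }  // 4 faces x 2 values of h
constexpr int exp_calls_shared() { return 4; }        // 4 precomputed slots
```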
Support for batching
Batching support is straightforward compared to the complexities discussed earlier in this article. The process involves adding an extra parameter and an outer loop. On the host side, a new parameter B is introduced to represent the batch size. The main adjustment required is changing the shape of the index tensor to [B, N]. Since row-major tensors on TTNN store an entire row as a single page, the total size of the tensor is updated, while the page size remains unchanged.
constexpr size_t B = 4; // New parameter
constexpr size_t D = 2048;
constexpr size_t D_active = 256;
constexpr size