
ember internals

this is a line-by-line walkthrough of the data structures, algorithms, and memory layouts that make ember work. it assumes you've read the architectural overview and want the raw details.

tensor

CpuTensor is the only data structure every other component touches. three fields:

pub struct CpuTensor {
    pub shape: Vec<usize>,      // e.g. [12, 64, 768]
    pub strides: Vec<usize>,    // contiguous row-major strides
    pub data: Vec<f32>,         // flat f32 buffer
}

strides are computed but the compute kernels never use them. the compute_strides function builds a standard row-major stride array (e.g. [768, 1] for shape [64, 768]), but every hot-path access goes through direct offset math like r * cols + c. the strides exist purely for get(&[usize]), the multi-dimensional indexer used only in tests.
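a minimal sketch of what compute_strides and the get(&[usize]) path compute, assuming plain row-major layout. flat_offset is a hypothetical helper name; the point is only that the stride dot product reduces to r * cols + c for a 2-d tensor:

```rust
/// row-major strides: the last dimension has stride 1, and each earlier
/// dimension's stride is the product of every dimension after it.
fn compute_strides(shape: &[usize]) -> Vec<usize> {
    let mut strides = vec![1; shape.len()];
    for i in (0..shape.len().saturating_sub(1)).rev() {
        strides[i] = strides[i + 1] * shape[i + 1];
    }
    strides
}

/// hypothetical helper: the flat offset a get(&[usize]) call would read.
fn flat_offset(index: &[usize], strides: &[usize]) -> usize {
    assert_eq!(index.len(), strides.len());
    index.iter().zip(strides).map(|(i, s)| i * s).sum()
}
```

for shape [64, 768] this yields [768, 1], so flat_offset(&[r, c], ..) is exactly r * 768 + c, the same arithmetic the kernels do by hand.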

every operation allocates: no views, no in-place mutation. add() returns a new CpuTensor, and so do softmax() and matmul(). during decode every step processes exactly one new token, so each step requests the same sequence of buffer sizes, and jemalloc or the system allocator can hand back the blocks freed on the previous step. this is not accidental: it is the reason the hot path can get away with per-op allocations without profiling as a malloc storm.

#[must_use] is on every pure op. if you write x.softmax(); without binding the result, the new tensor is dropped silently and the model keeps running on the stale x. the attribute makes the compiler flag that line: an unused_must_use warning by default, a hard error under #![deny(unused_must_use)] or -D warnings. it caught a forgotten assignment of the layer norm output back to x in the transformer block loop, a bug that produces no panic and no NaN, just progressively degraded text.
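a cut-down illustration of the pattern, using a hypothetical MiniTensor in place of CpuTensor:

```rust
/// hypothetical stand-in for CpuTensor: just enough to show #[must_use]
/// on a pure op that allocates and returns a new tensor.
#[derive(Clone, Debug, PartialEq)]
struct MiniTensor {
    data: Vec<f32>,
}

impl MiniTensor {
    /// pure op: leaves self untouched and returns a fresh tensor.
    /// a bare `x.scale(2.0);` statement trips the unused_must_use lint.
    #[must_use]
    fn scale(&self, s: f32) -> MiniTensor {
        MiniTensor { data: self.data.iter().map(|v| v * s).collect() }
    }
}
```

rebinding, as in let x = x.scale(2.0);, is the shape the transformer loop needs. adding #![deny(unused_must_use)] at the crate root turns a dropped result into a compile error even without -D warnings.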

matmul: matrixmultiply::sgemm

the matmul delegates to bluss's matrixmultiply crate, a pure-rust sgemm with no blas linking. the call site:

unsafe {
    matrixmultiply::sgemm(
        m, k, n,                // dimensions
        1.0,                    // alpha
        a.as_ptr(), k as isize, 1,     // A: m×k row-major: row stride = k, col stride = 1
        b.as_ptr(), n as isize, 1,     // B: k×n row-major: row stride = n, col stride = 1
        0.0,                           // beta
        c.as_mut_ptr(), n as isize, 1, // C: m×n row-major: row stride = n, col stride = 1
    );
}

matrixmultiply is a portable, cache-blocked kernel (recent releases also pick avx/fma microkernels at runtime on x86_64), and this call is the throughput bottleneck. the Backend trait is the insertion point for a dedicated simd path: implement SimdBackend with the same trait, swap the backend type, and every matmul in the model uses avx2 or neon without touching model code.
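to make the stride arguments concrete, here is a naive reference with the same argument order as matrixmultiply::sgemm (row stride, then column stride, for each matrix). sgemm_ref is a hypothetical name and a correctness oracle only: it is the triple loop that the real kernel blocks and vectorizes. it takes usize strides where the real call takes isize, and it skips reading c when beta is 0.0:

```rust
/// naive reference for C = alpha * A·B + beta * C, where element (i, j)
/// of matrix X lives at x[i * rs + j * cs].
#[allow(clippy::too_many_arguments)]
fn sgemm_ref(
    m: usize, k: usize, n: usize,
    alpha: f32,
    a: &[f32], rsa: usize, csa: usize,
    b: &[f32], rsb: usize, csb: usize,
    beta: f32,
    c: &mut [f32], rsc: usize, csc: usize,
) {
    for i in 0..m {
        for j in 0..n {
            let mut acc = 0.0_f32;
            for p in 0..k {
                acc += a[i * rsa + p * csa] * b[p * rsb + j * csb];
            }
            let dst = &mut c[i * rsc + j * csc];
            // beta == 0.0: ignore old contents, which may be uninitialized garbage
            *dst = if beta == 0.0 { alpha * acc } else { alpha * acc + beta * *dst };
        }
    }
}
```

the same B can be passed column-major just by swapping its strides (rsb = 1, csb = k), which is the kind of layout flexibility the strided interface buys.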

softmax

applied along the last dimension of a tensor of any rank. the batch prefix is flattened: for shape [12, 64, 768], there are 12 × 64 = 768 rows of 768 elements each. the algorithm for each row:

let max = slice.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));

// branch: all-masked row
if max == f32::NEG_INFINITY {
    let uniform = 1.0 / (last_dim as f32);
    for i in 0..last_dim { out[offset + i] = uniform; }
    continue;
}

// stable softmax
let mut sum = 0.0;
for i in 0..last_dim {
    let e = (slice[i] - max).exp();
    out[offset + i] = e;
    sum += e;
}
let inv_sum = sum.recip();
for i in 0..last_dim { out[offset + i] *= inv_sum; }

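three passes per row (max, exp-and-sum, normalize), two for the all-masked edge case (max, then uniform fill). the inner loop is a scalar exp(), roughly 10-15 cycles on modern x86. for a 768-element row that is ~10k cycles per softmax; during decode each of the 12 heads produces one such row, so ~120k cycles per transformer block and ~1.4M cycles per token across 12 layers, growing linearly with sequence length. in the measured decode breakdown below, softmax accounts for ~3 ms, about 4% of decode time.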
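the fragment above, wrapped into a self-contained function over a flat row-major buffer (softmax_last_dim is a hypothetical name), makes the row flattening and the all-masked guard testable:

```rust
/// softmax over the last dimension of a flat row-major buffer:
/// every `last_dim` consecutive elements form one row.
fn softmax_last_dim(data: &[f32], last_dim: usize) -> Vec<f32> {
    assert!(last_dim > 0 && data.len() % last_dim == 0);
    let mut out = vec![0.0_f32; data.len()];
    for (row, dst) in data.chunks_exact(last_dim).zip(out.chunks_exact_mut(last_dim)) {
        let max = row.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
        if max == f32::NEG_INFINITY {
            // all-masked row: uniform instead of NaN from (-inf - -inf)
            dst.fill(1.0 / last_dim as f32);
            continue;
        }
        // exp-and-sum pass (max subtraction keeps exp from overflowing)
        let mut sum = 0.0_f32;
        for (d, &x) in dst.iter_mut().zip(row) {
            *d = (x - max).exp();
            sum += *d;
        }
        // normalize pass
        let inv_sum = sum.recip();
        for d in dst.iter_mut() {
            *d *= inv_sum;
        }
    }
    out
}
```

a partially masked row is handled by the normal path: exp(-inf) is exactly 0.0, so masked positions get zero weight without any extra branch.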

gelu

let inv_sqrt_2 = 0.707_106_77_f32;
let data: Vec<f32> = self.data.iter()
    .map(|&x| {
        let z = x * inv_sqrt_2;
        0.5 * x * (1.0 + libm::erff(z))
    })
    .collect();

libm::erff is a pure-rust port of musl's erff: no system libm to link, no target-specific intrinsics. the constant 1/√2 is precomputed as 0.70710677 (the same f32 value as std::f32::consts::FRAC_1_SQRT_2), turning a division per element into a multiplication. gelu runs on the 3072-element mlp hidden activation in every layer, so each forward pass makes 3072 × 12 = 36,864 erff calls per token.
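libm::erff is not in std, so this sketch swaps in the abramowitz & stegun 7.1.26 approximation (absolute error around 1.5e-7) purely to stay dependency-free; erf_approx is a hypothetical helper and is not what ember uses. the formula around it is the same erf-based gelu as above:

```rust
use std::f32::consts::FRAC_1_SQRT_2;

/// erf via abramowitz & stegun 7.1.26; stand-in for libm::erff.
fn erf_approx(x: f32) -> f32 {
    let sign = if x < 0.0 { -1.0 } else { 1.0 };
    let x = x.abs();
    let t = 1.0 / (1.0 + 0.327_591_1 * x);
    // horner form of a1*t + a2*t^2 + a3*t^3 + a4*t^4 + a5*t^5
    let poly = t * (0.254_829_592
        + t * (-0.284_496_736
        + t * (1.421_413_741
        + t * (-1.453_152_027
        + t * 1.061_405_429))));
    sign * (1.0 - poly * (-x * x).exp())
}

/// exact (erf-based) gelu: 0.5 * x * (1 + erf(x / sqrt(2))).
fn gelu(x: f32) -> f32 {
    0.5 * x * (1.0 + erf_approx(x * FRAC_1_SQRT_2))
}
```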

layer norm

for b in 0..batch {
    let offset = b * features;
    let slice = &self.data[offset..offset + features];

    let mean = slice.iter().sum::<f32>() / features as f32;
    let var = slice.iter().map(|x| (x - mean).powi(2)).sum::<f32>() / features as f32;
    let std = (var + eps).sqrt();

    for i in 0..features {
        let normalized = (slice[i] - mean) / std;
        out[offset + i] = normalized * weight.data[i] + bias.data[i];
    }
}

three passes per row: one for the mean, one for the variance, one for the normalization. epsilon is 1e-5, the value gpt-2 uses and the default for pytorch's nn.LayerNorm; layer normalization itself comes from ba et al. (2016). the weight and bias are learned per-feature vectors loaded from the gguf file, identical to the pytorch nn.LayerNorm(768) parameters.
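a std-only version of the same three passes over a flat buffer, with hypothetical names. like nn.LayerNorm it uses the biased variance (divide by features, not features - 1):

```rust
/// layer norm over the last dimension of a flat row-major buffer.
fn layer_norm_rows(x: &[f32], weight: &[f32], bias: &[f32], eps: f32) -> Vec<f32> {
    let features = weight.len();
    assert!(features > 0 && bias.len() == features && x.len() % features == 0);
    let n = features as f32;
    let mut out = vec![0.0_f32; x.len()];
    for (row, dst) in x.chunks_exact(features).zip(out.chunks_exact_mut(features)) {
        let mean = row.iter().sum::<f32>() / n; // pass 1
        let var = row.iter().map(|v| (v - mean).powi(2)).sum::<f32>() / n; // pass 2
        let inv_std = 1.0 / (var + eps).sqrt();
        for (i, d) in dst.iter_mut().enumerate() {
            // pass 3: normalize, then per-feature scale and shift
            *d = (row[i] - mean) * inv_std * weight[i] + bias[i];
        }
    }
    out
}
```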

kv cache

memory layout

layer 0 ──┬─ head 0 ──┬─ pos 0: [d0 d1 d2 ... d63]
          │           ├─ pos 1: [d0 d1 d2 ... d63]
          │           └─ ...
          ├─ head 1 ──┬─ pos 0: [d0 d1 d2 ... d63]
          │           └─ ...
          └─ ...
layer 1 ──┬─ ...
...

the flat offset for k[layer][head][pos][dim] is:

offset = layer * n_heads * max_seq_len * head_dim
       + head  * max_seq_len * head_dim
       + pos   * head_dim
       + dim

each of k and v takes 12 layers × 12 heads × 2048 positions × 64 dim × 4 bytes = 75,497,472 bytes ≈ 72 MiB, so the pair comes to 150,994,944 bytes ≈ 144 MiB. both are allocated once in KVCache::new() via vec![0.0; len], never grown, never reallocated.
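the offset formula and the allocation arithmetic, checked with a small hypothetical KvLayout helper. the nested form ((layer * n_heads + head) * max_seq_len + pos) * head_dim + dim expands to exactly the four-term sum above:

```rust
/// hypothetical helper mirroring the [layer][head][pos][dim] layout.
struct KvLayout {
    n_layers: usize,
    n_heads: usize,
    max_seq_len: usize,
    head_dim: usize,
}

impl KvLayout {
    fn offset(&self, layer: usize, head: usize, pos: usize, dim: usize) -> usize {
        debug_assert!(layer < self.n_layers && head < self.n_heads);
        debug_assert!(pos < self.max_seq_len && dim < self.head_dim);
        ((layer * self.n_heads + head) * self.max_seq_len + pos) * self.head_dim + dim
    }

    /// number of f32 slots in one of the two buffers (k or v).
    fn floats_per_tensor(&self) -> usize {
        self.n_layers * self.n_heads * self.max_seq_len * self.head_dim
    }
}
```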

append

the caller supplies k and v as flat [n_heads * head_dim] slices, the concatenated per-head projections for a single token at a single layer. append scatters them into the cache:

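let (max_seq_len, head_dim) = (self.max_seq_len, self.head_dim);
let layer_offset = layer * self.n_heads * max_seq_len * head_dim;
for h in 0..self.n_heads {
    let dst = layer_offset + h * max_seq_len * head_dim + pos * head_dim;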
    let src = h * head_dim;

    self.k[dst..dst + head_dim].copy_from_slice(&k_new[src..src + head_dim]);
    self.v[dst..dst + head_dim].copy_from_slice(&v_new[src..src + head_dim]);
}

the cursor bug (bug #5). the original implementation ignored the pos parameter and wrote every entry at self.cursor. but cursor is deliberately not advanced until Gpt2::forward_with_cache finishes all 12 layers. during prefill, the attention loop called append once per prompt token, and every call landed at cursor = 0, overwriting the previous one; only the final token's k/v survived. the fix makes append write at the pos it is given, which the attention forward pass computes as cache.cursor() + token_index.
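a minimal sketch of the fixed append, assuming a cut-down KvCache with only the fields the scatter needs (new, advance, and the exact field set are hypothetical). prefill passes pos = cursor + token_index, and cursor only moves once the whole forward pass is done:

```rust
/// hypothetical, cut-down cache: flat k and v buffers laid out
/// [layer][head][pos][dim].
struct KvCache {
    k: Vec<f32>,
    v: Vec<f32>,
    n_heads: usize,
    max_seq_len: usize,
    head_dim: usize,
    cursor: usize, // advanced only after all layers have run
}

impl KvCache {
    fn new(n_layers: usize, n_heads: usize, max_seq_len: usize, head_dim: usize) -> Self {
        let len = n_layers * n_heads * max_seq_len * head_dim;
        KvCache { k: vec![0.0; len], v: vec![0.0; len], n_heads, max_seq_len, head_dim, cursor: 0 }
    }

    /// scatter one token's concatenated per-head k/v into slot `pos`.
    fn append(&mut self, layer: usize, pos: usize, k_new: &[f32], v_new: &[f32]) {
        assert!(pos < self.max_seq_len, "kv cache full");
        let hd = self.head_dim;
        let layer_offset = layer * self.n_heads * self.max_seq_len * hd;
        for h in 0..self.n_heads {
            let dst = layer_offset + h * self.max_seq_len * hd + pos * hd;
            let src = h * hd;
            self.k[dst..dst + hd].copy_from_slice(&k_new[src..src + hd]);
            self.v[dst..dst + hd].copy_from_slice(&v_new[src..src + hd]);
        }
    }

    fn advance(&mut self, n: usize) {
        self.cursor += n;
    }
}
```

with the buggy version (always writing at cursor), the three prefill tokens below would all land in slot 0 and only the last would survive.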

qk_scratch

a separate Vec<f32> pre-allocated to max_seq_len (2048). reused across every head and every token in a decode step:

pub fn qk_scratch_mut(&mut self) -> &mut Vec<f32> {
    &mut self.qk_scratch
}

the caller does:

let scratch = cache.qk_scratch_mut();
scratch.clear();
scratch.resize(total_seq_len, f32::NEG_INFINITY);

// fill scratch[pos] = dot(q, k[pos]) for pos in 0..total_seq_len
// softmax in-place on scratch
// weighted sum of v[pos] by scratch[pos]

because scratch was allocated to max_seq_len, the resize never re-allocates as long as total_seq_len ≤ max_seq_len. this eliminates 144 small heap allocations per generated token (12 layers × 12 heads).
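the caller's three steps for a single head, as one hypothetical function. it takes the scratch buffer as a separate &mut Vec<f32> so it can read k and v without also borrowing the cache mutably, and it applies the usual 1/√head_dim scaling of scaled dot-product attention. there are no masked positions in this sketch:

```rust
/// single-query attention for one head over `seq_len` cached positions.
/// k and v hold seq_len rows of head_dim floats each (row-major).
fn attend_one_head(
    q: &[f32],
    k: &[f32],
    v: &[f32],
    seq_len: usize,
    scratch: &mut Vec<f32>,
    out: &mut [f32],
) {
    let head_dim = q.len();
    let scale = 1.0 / (head_dim as f32).sqrt();

    // reuse the buffer: no reallocation while seq_len <= capacity
    scratch.clear();
    scratch.resize(seq_len, f32::NEG_INFINITY);

    // scores[pos] = (q · k[pos]) * scale
    for pos in 0..seq_len {
        let k_row = &k[pos * head_dim..(pos + 1) * head_dim];
        scratch[pos] = q.iter().zip(k_row).map(|(a, b)| a * b).sum::<f32>() * scale;
    }

    // stable softmax in place on scratch
    let max = scratch.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
    let mut sum = 0.0_f32;
    for s in scratch.iter_mut() {
        *s = (*s - max).exp();
        sum += *s;
    }
    for s in scratch.iter_mut() {
        *s /= sum;
    }

    // out = sum over pos of weight[pos] * v[pos]
    out.fill(0.0);
    for (pos, &w) in scratch.iter().enumerate() {
        let v_row = &v[pos * head_dim..(pos + 1) * head_dim];
        for (o, &x) in out.iter_mut().zip(v_row) {
            *o += w * x;
        }
    }
}
```

the buffer's data pointer is the same before and after the call when its capacity already covers seq_len, which is the no-allocation property the text relies on.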

gguf format & quantization

file structure

offset   field
0x00     magic: 0x47 0x47 0x55 0x46 ("GGUF")
0x04     version: u32 (3)
0x08     n_tensors: u64
0x10     n_metadata: u64
0x18     metadata kv pairs (n_metadata entries)
...      tensor info table (n_tensors entries)
...      tensor data (raw bytes at the info table's offsets, which are relative to the start of the data section; that section is aligned to general.alignment, default 32)
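a std-only sketch of reading the fixed 24-byte prefix; GgufHeader, le_u32, le_u64, and parse_gguf_header are hypothetical names, and the metadata and tensor info that start at 0x18 are not parsed here:

```rust
#[derive(Debug, PartialEq)]
struct GgufHeader {
    version: u32,
    n_tensors: u64,
    n_metadata: u64,
}

fn le_u32(b: &[u8]) -> u32 {
    let mut a = [0u8; 4];
    a.copy_from_slice(b);
    u32::from_le_bytes(a)
}

fn le_u64(b: &[u8]) -> u64 {
    let mut a = [0u8; 8];
    a.copy_from_slice(b);
    u64::from_le_bytes(a)
}

/// parse magic, version, tensor count, and metadata count (all little-endian).
fn parse_gguf_header(buf: &[u8]) -> Result<GgufHeader, String> {
    if buf.len() < 0x18 {
        return Err(format!("need 24 header bytes, got {}", buf.len()));
    }
    if !buf.starts_with(b"GGUF") {
        return Err("bad magic: expected 0x47 0x47 0x55 0x46".to_string());
    }
    Ok(GgufHeader {
        version: le_u32(&buf[0x04..0x08]),
        n_tensors: le_u64(&buf[0x08..0x10]),
        n_metadata: le_u64(&buf[0x10..0x18]),
    })
}
```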

the loader walks the metadata to find model hyperparameters (gpt2.block_count, gpt2.context_length, gpt2.embedding_length, etc.), then reads the tensor info table to get each weight's name, shape, dtype, and byte offset. weights are matched by hardcoded gpt-2 tensor names ("tok_embeddings.weight", "blk.{i}.attn.output.weight", etc.).

q8_0 block quantization

each block encodes 32 f32 values into 34 bytes:

bytes 0-1:      d (fp16 scale factor)
bytes 2-33:     q[0..32] (32 int8 quantized values)
reconstruction: dst[j] = q[j] as f32 * d.to_f32()

the dequantization loop:

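for i in 0..n_blocks {
    let block_start = i * (2 + Q8_0_BLOCK_SIZE);  // 34 bytes per block
    let out_start = i * Q8_0_BLOCK_SIZE;          // 32 floats per block
    let d_bits = u16::from_le_bytes([src[block_start], src[block_start + 1]]);
    let d = f16::from_bits(d_bits).to_f32();      // half::f16

    for j in 0..Q8_0_BLOCK_SIZE {
        let q = src[block_start + 2 + j] as i8;  // signed int8
        dst[out_start + j] = q as f32 * d;
    }
}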

the int8 values are signed (i8), range [-128, 127], matching ggml's block_q8_0 (an fp16 scale d followed by int8_t qs[32]). the fp16 scale means the representable range per block is roughly [-128 × d, 127 × d], where d varies per block. for gpt-2 weights, typical scales are in the range 0.001-0.05, giving roughly ±0.1 to ±6.4 per weight element.
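a dependency-free version of the dequantization loop. the fp16 conversion is hand-written here as a stand-in for half::f16::to_f32, and f16_to_f32, dequantize_q8_0, and Q8_0_BLOCK_BYTES are hypothetical names:

```rust
const Q8_0_BLOCK_SIZE: usize = 32;
const Q8_0_BLOCK_BYTES: usize = 2 + Q8_0_BLOCK_SIZE; // fp16 scale + 32 int8

/// ieee 754 binary16 -> f32; stand-in for half::f16::from_bits(..).to_f32().
fn f16_to_f32(bits: u16) -> f32 {
    let sign = if bits & 0x8000 != 0 { -1.0_f32 } else { 1.0 };
    let exp = ((bits >> 10) & 0x1f) as i32;
    let frac = (bits & 0x03ff) as f32;
    match exp {
        0 => sign * frac * 2f32.powi(-24), // zero or subnormal
        31 if frac == 0.0 => sign * f32::INFINITY,
        31 => f32::NAN,
        _ => sign * (1.0 + frac / 1024.0) * 2f32.powi(exp - 15), // normal, bias 15
    }
}

/// dequantize whole q8_0 blocks from `src` into `dst`.
fn dequantize_q8_0(src: &[u8], dst: &mut [f32]) {
    assert_eq!(src.len() % Q8_0_BLOCK_BYTES, 0, "partial block");
    assert_eq!(dst.len(), src.len() / Q8_0_BLOCK_BYTES * Q8_0_BLOCK_SIZE);
    let blocks = src.chunks_exact(Q8_0_BLOCK_BYTES);
    for (block, out) in blocks.zip(dst.chunks_exact_mut(Q8_0_BLOCK_SIZE)) {
        let d = f16_to_f32(u16::from_le_bytes([block[0], block[1]]));
        for (o, &q) in out.iter_mut().zip(&block[2..]) {
            *o = (q as i8) as f32 * d; // reinterpret the byte as signed int8
        }
    }
}
```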

the column-major trap

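gguf stores q8_0 and f16 tensors in column-major order relative to the [rows, cols] shape in the tensor info header: ggml lists dimensions innermost first. the q8_0 block size is 32, and blocks run along that innermost (contiguous) dimension, which must be a multiple of 32. for a weight matrix shaped [768, 50257], the block axis is the 768 dimension (24 blocks per column); 50257 is not a multiple of 32, so it cannot be the block axis. the raw bytes are laid out as all 768 rows of column 0, then all 768 rows of column 1, and so on.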

the tensor info header reports the logical shape [768, 50257] in row-major convention, but the dequantized buffer is column-major. the loader originally called reshape(&[768, 50257]), which assumes row-major, producing a transposed weight matrix. the fix is one line in the loader: reverse dims before calling reshape. after dequantizing or converting f16, the loader does:

// dims = [rows, cols] from the tensor info header
// but data is column-major, so reverse to match:
dims.reverse();
let tensor = CpuTensor::from_data(dims, flat_f32_data);

then in Gpt2::from_loader, linear layer weights are transposed back to restore [in_features, out_features] for matmul. embeddings skip the transpose since the dim reversal already makes index_select pick contiguous rows.
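a toy 2×3 example of the trap, assuming the layout described above; transpose is a hypothetical helper. reading the column-major bytes with reversed dims gives the transpose of the logical matrix, whose rows are contiguous (the embedding case), and one transpose restores the original orientation (the linear-weight case):

```rust
/// row-major transpose of a [rows, cols] buffer into [cols, rows].
fn transpose(data: &[f32], rows: usize, cols: usize) -> Vec<f32> {
    assert_eq!(data.len(), rows * cols);
    let mut out = vec![0.0; data.len()];
    for r in 0..rows {
        for c in 0..cols {
            out[c * rows + r] = data[r * cols + c];
        }
    }
    out
}
```

for the logical matrix [[1, 2, 3], [4, 5, 6]] with header dims [2, 3], the bytes on disk are [1, 4, 2, 5, 3, 6]. read as row-major [3, 2] (dims reversed), row 1 is the contiguous slice [2, 5], and transpose(&disk, 3, 2) gives back [1, 2, 3, 4, 5, 6].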

sampler

the sampling pipeline is five stages applied in order:

1. temperature

if temperature > 0.0 {
    for l in &mut logits { *l /= temperature; }
}

temperature = 0 means greedy argmax: the pipeline skips temperature scaling and goes straight to argmax in categorical_sample. any positive value divides the logits: t < 1.0 sharpens the distribution (peaks get more relative weight), t > 1.0 flattens it toward uniform.
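a sketch of the two temperature paths with hypothetical names; t == 0 goes to argmax and never reaches the division:

```rust
/// greedy pick for temperature == 0. total_cmp never panics,
/// unlike partial_cmp(..).unwrap() on a NaN.
fn argmax(logits: &[f32]) -> usize {
    logits
        .iter()
        .copied()
        .enumerate()
        .max_by(|(_, a), (_, b)| a.total_cmp(b))
        .map(|(i, _)| i)
        .unwrap_or(0)
}

/// softmax(logits / t) for t > 0: t < 1 sharpens, t > 1 flattens.
fn softmax_with_temperature(logits: &[f32], t: f32) -> Vec<f32> {
    assert!(t > 0.0, "temperature 0 should take the argmax path instead");
    let scaled: Vec<f32> = logits.iter().map(|l| l / t).collect();
    let max = scaled.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
    let exps: Vec<f32> = scaled.iter().map(|s| (s - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    exps.iter().map(|e| e / sum).collect()
}
```

with logits [1, 2, 3], the top token's probability is about 0.87 at t = 0.5, 0.67 at t = 1.0, and 0.51 at t = 2.0.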

2. top-k

use std::cmp::Ordering;

let mut indexed: Vec<(usize, f32)> = logits.iter().cloned().enumerate().collect();
indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(Ordering::Equal));

let threshold = indexed[k - 1].1;
for l in logits.iter_mut() {
    if *l < threshold { *l = f32::NEG_INFINITY; }
}

sorts a copy of (index, logit) pairs descending, reads the k-th largest value, masks everything below it. O(V log V) where V = 50,257 for gpt-2. the sort is the bottleneck in the sampling path, consuming ~60% of sampling time for typical k = 40-50.
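if the sort ever matters, the threshold can be found with slice::select_nth_unstable_by, which partially orders a copy in average O(V) instead of sorting it. this is a swapped-in technique, not what ember does; top_k_mask is a hypothetical name, and treating k == 0 as "disabled" is an assumption:

```rust
/// mask every logit below the k-th largest to -inf (ties at the threshold survive).
fn top_k_mask(logits: &mut [f32], k: usize) {
    if k == 0 || k >= logits.len() {
        return; // assumption: nothing to filter
    }
    let mut vals = logits.to_vec();
    // descending comparator: index k - 1 ends up holding the k-th largest value
    let (_, kth, _) = vals.select_nth_unstable_by(k - 1, |a, b| b.total_cmp(a));
    let threshold = *kth;
    for l in logits.iter_mut() {
        if *l < threshold {
            *l = f32::NEG_INFINITY;
        }
    }
}
```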

3. top-p (nucleus)

fn top_p_filter(logits: &mut [f32], p: f32) {
    let soft = softmax_1d(logits);
    let cutoff = nucleus_cutoff(&soft, p);
    for (i, s) in soft.iter().enumerate() {
        if *s < cutoff { logits[i] = f32::NEG_INFINITY; }
    }
}

nucleus_cutoff sorts the softmax probabilities in descending order, accumulates from the top, and returns the probability of the token at which the cumulative sum first reaches p; every token with a smaller probability is masked. the caller then computes a second softmax over the filtered logits in sample_token. this is intentional: the first softmax finds the cutoff, the second renormalizes the surviving tokens into the final distribution. computing softmax twice is cheaper than the alternative of tracking which indices were filtered.
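one plausible shape for nucleus_cutoff, following the description above; the body is an assumption, not ember's code:

```rust
/// probability of the token at which the descending running sum first reaches p.
fn nucleus_cutoff(probs: &[f32], p: f32) -> f32 {
    let mut sorted = probs.to_vec();
    sorted.sort_unstable_by(|a, b| b.total_cmp(a)); // descending
    let mut cum = 0.0_f32;
    for &q in &sorted {
        cum += q;
        if cum >= p {
            return q;
        }
    }
    // rounding kept cum below p: keep every token
    sorted.last().copied().unwrap_or(0.0)
}
```

with probabilities [0.15, 0.5, 0.05, 0.3] and p = 0.7, the running sum is 0.5 then 0.8, so the cutoff is 0.3 and only the tokens at 0.5 and 0.3 survive.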

4. softmax

the softmax_1d helper is identical to the tensor softmax but operates on a plain &[f32] slice: same max-subtraction for stability, same all -inf branch with the uniform fallback.

5. inverse cdf sampling

let r: f32 = rng.gen();           // uniform [0, 1)
let mut cum = 0.0;
for (i, &p) in dist.iter().enumerate() {
    cum += p;
    if r < cum { return i; }
}
// fallback: argmax
dist.iter().enumerate()
    .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
    .map(|(i, _)| i).unwrap_or(0)

the fallback to argmax handles the (rare) case where floating-point rounding leaves the final cumulative sum slightly below 1.0, so a draw of r just under 1.0 is never below cum.

the backend trait

pub trait Backend {
    type Tensor: Clone + Send + Sync;
    type Error: core::error::Error;

    fn zeroes(&self, shape: &[usize]) -> Result<Self::Tensor, Self::Error>;
    fn matmul(&self, a: &Self::Tensor, b: &Self::Tensor) -> Result<Self::Tensor, Self::Error>;
    fn add(&self, a: &Self::Tensor, b: &Self::Tensor) -> Result<Self::Tensor, Self::Error>;
    fn softmax(&self, x: &Self::Tensor) -> Result<Self::Tensor, Self::Error>;
    fn gelu(&self, x: &Self::Tensor) -> Result<Self::Tensor, Self::Error>;
    fn layer_norm(&self, x: &Self::Tensor, weight: &Self::Tensor,
                  bias: &Self::Tensor, eps: f32) -> Result<Self::Tensor, Self::Error>;
    fn index_select(&self, t: &Self::Tensor, index: usize) -> Result<Self::Tensor, Self::Error>;
    fn assign_row(&self, dst: &mut Self::Tensor, index: usize, src: &Self::Tensor);
    fn slice_cols(&self, x: &Self::Tensor, start: usize, end: usize) -> Self::Tensor;
    fn shape<'a>(&self, x: &'a Self::Tensor) -> &'a [usize];
    fn data<'a>(&self, x: &'a Self::Tensor) -> &'a [f32];
    fn load_from_cpu(&self, data: Vec<f32>, shape: &[usize])
        -> Result<Self::Tensor, Self::Error>;
    fn add_broadcast(&self, x: &Self::Tensor, bias: &Self::Tensor)
        -> Result<Self::Tensor, Self::Error>;
}

why load_from_cpu and not from_cpu. clippy's wrong_self_convention expects from_* methods to be constructors without &self. since backends are invoked through a shared reference (&self), the method takes &self and is named load_from_cpu to avoid the lint. for CpuBackend, this is a thin wrapper around CpuTensor::from_data. for a gpu backend, it would copy data to device memory.

n_layers is stored in KVCache but never read. it exists only to compute the allocation size in new(). removing it would require threading the layer count through every method that touches the cache, or hardcoding it. storing it is the more explicit path, and the field is annotated #[allow(dead_code)].

the Module companion trait. each transformer component implements Module<B: Backend> with a single method, forward(&self, backend: &B, x: &B::Tensor) -> Result<B::Tensor, B::Error>. the backend reference flows through every layer; it is the object that knows how to allocate, compute, and move data. swapping backends is a one-line change at the top of Gpt2::new.
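a cut-down, hypothetical version of the two traits (three backend methods instead of thirteen, a Debug error type instead of core::error::Error) with a Linear layer that talks to the backend only through the trait. swapping CpuBackend for another Backend impl would not touch Linear:

```rust
use std::fmt;

trait Backend {
    type Tensor: Clone;
    type Error: fmt::Debug;
    fn load_from_cpu(&self, data: Vec<f32>, shape: &[usize]) -> Result<Self::Tensor, Self::Error>;
    fn matmul(&self, a: &Self::Tensor, b: &Self::Tensor) -> Result<Self::Tensor, Self::Error>;
    fn add_broadcast(&self, x: &Self::Tensor, bias: &Self::Tensor) -> Result<Self::Tensor, Self::Error>;
}

trait Module<B: Backend> {
    fn forward(&self, backend: &B, x: &B::Tensor) -> Result<B::Tensor, B::Error>;
}

#[derive(Clone, Debug, PartialEq)]
struct CpuTensor {
    shape: Vec<usize>,
    data: Vec<f32>,
}

struct CpuBackend;

impl Backend for CpuBackend {
    type Tensor = CpuTensor;
    type Error = String;

    fn load_from_cpu(&self, data: Vec<f32>, shape: &[usize]) -> Result<CpuTensor, String> {
        if shape.iter().product::<usize>() != data.len() {
            return Err("shape/data mismatch".into());
        }
        Ok(CpuTensor { shape: shape.to_vec(), data })
    }

    fn matmul(&self, a: &CpuTensor, b: &CpuTensor) -> Result<CpuTensor, String> {
        let (m, k, n) = (a.shape[0], a.shape[1], b.shape[1]);
        if b.shape[0] != k {
            return Err("inner dims differ".into());
        }
        let mut out = vec![0.0; m * n];
        for i in 0..m {
            for p in 0..k {
                let aip = a.data[i * k + p];
                for j in 0..n {
                    out[i * n + j] += aip * b.data[p * n + j];
                }
            }
        }
        Ok(CpuTensor { shape: vec![m, n], data: out })
    }

    fn add_broadcast(&self, x: &CpuTensor, bias: &CpuTensor) -> Result<CpuTensor, String> {
        let n = bias.data.len();
        if x.shape.last() != Some(&n) {
            return Err("bias length mismatch".into());
        }
        let data = x.data.iter().enumerate().map(|(i, v)| v + bias.data[i % n]).collect();
        Ok(CpuTensor { shape: x.shape.clone(), data })
    }
}

/// linear layer, weight stored as [in_features, out_features].
struct Linear<B: Backend> {
    weight: B::Tensor,
    bias: B::Tensor,
}

impl<B: Backend> Module<B> for Linear<B> {
    fn forward(&self, backend: &B, x: &B::Tensor) -> Result<B::Tensor, B::Error> {
        let y = backend.matmul(x, &self.weight)?;
        backend.add_broadcast(&y, &self.bias)
    }
}
```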

performance

metric                                     before kv cache    after kv cache
forward passes (50-tok prompt + 20 gen)    71 (O(L²) each)    1 prefill + 20 decode (O(L) each)
decode throughput                          ~2 tok/s           ~12 tok/s
attention hot-path allocations             144 per token      0
kv cache memory                            n/a                ~72 MiB
model weight memory (q8_0)                 ~130 MiB           ~130 MiB
total runtime memory                       ~140 MiB           ~210 MiB

where time goes in decode (per token, 12 tok/s on i7-12700H):

matmul (12 layers × 4 matmuls):    ~55 ms (66%)
    c_attn (768→2304):             ~14 ms
    attn c_proj (768→768):         ~6 ms
    c_fc (768→3072):               ~19 ms
    mlp c_proj (3072→768):         ~16 ms
attention (scalar loops):          ~12 ms (14%)
gelu (12×3072 elements):           ~5 ms (6%)
layer_norm (12×2×768 elements):    ~4 ms (5%)
softmax (12 heads × 64 dim):       ~3 ms (4%)
other (adds, copies, sampling):    ~4 ms (5%)

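matmul dominates. the mlp's c_fc and c_proj are the two heaviest matmuls, 768×3072 and 3072×768 respectively, each 2,359,296 multiply-adds per token per layer. if simd dropped total matmul time to roughly 8-10 ms with everything else unchanged, a decode step would take ~36-38 ms instead of ~83 ms, about 26-28 tok/s; reaching 40-50 tok/s would also require speeding up the ~28 ms spent in attention, gelu, layer norm, softmax, and the rest.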
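the multiply-add counts behind the matmul rows of the breakdown, tallied for one token. this covers only the four per-layer projections listed there, and the names are hypothetical:

```rust
/// multiply-adds for one token through one projection.
const fn macs(in_f: usize, out_f: usize) -> usize {
    in_f * out_f
}

const D: usize = 768;

/// the four per-layer projections from the breakdown.
const PER_LAYER: [(&str, usize); 4] = [
    ("c_attn", macs(D, 3 * D)),
    ("attn c_proj", macs(D, D)),
    ("c_fc", macs(D, 4 * D)),
    ("mlp c_proj", macs(4 * D, D)),
];

fn macs_per_token(n_layers: usize) -> usize {
    n_layers * PER_LAYER.iter().map(|(_, m)| m).sum::<usize>()
}
```

per layer that is 7,077,888 multiply-adds, about 85 million per token across 12 layers; the ms split in the breakdown roughly tracks these ratios.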

edge cases worth knowing

all-masked softmax rows. the causal mask sets future positions to -inf. on the first decode step (only one token in the sequence), every row after the first is entirely -inf. standard softmax produces NaN ((-inf - -inf).exp() evaluates to NaN per ieee 754). one branch, if max == f32::NEG_INFINITY { return uniform }, prevents NaN from propagating through 12 layers into the output logits.

temperature = 0. the sampler checks temperature > 0.0 and skips scaling, since dividing by zero would turn the logits into ±inf and NaN. categorical_sample then takes the argmax directly instead of drawing from the distribution, which is why --temperature 0 produces identical output on every run.

scratch buffer resize. the qk_scratch buffer is allocated with capacity max_seq_len (2048). during decode at position 47, the caller resizes it to 48. because 48 ≤ 2048, the resize only writes 48 elements into capacity reserved at construction time; it never reallocates. this is a deliberate invariant: the attention inner loop never calls the global allocator.

transposed embeddings. the gguf tensor info header lists token embeddings as [768, 50257] ([embed, vocab]) in ggml's innermost-first order, with each token's 768 values stored contiguously. reversing the dims in the loader yields a row-major [vocab, embed] (50257 × 768) tensor, so index_select(token_id) picks a contiguous 768-element row rather than gathering 768 elements strided by vocab_size. cost: nothing beyond the dim reversal the loader already does, since embeddings skip the transpose applied to linear weights. benefit: every inference step does a memcpy instead of a gather.