TLDR: Article describes a method to define “registers”, “heads” and non-linear “compute” functions which are compiled and distilled into a standard Transformer model. This is used to develop a simple and deterministic RPN interpreter.
Code accompanying this project can be found here.
AI disclaimer: code ≈ 50%, non-agentic. Article writing = 0%.
A few weeks ago I read a blog post from Percepta.ai entitled Can LLMs be computers?1. The article frames this as an alternative to tool calling: instead of writing out a program and using an external tool to evaluate it, just build an exact, deterministic interpreter right into the Transformer network. This framing drew some criticism, particularly “what is the point?” and “where are the benchmarks?” These are fair questions, but I think they unfortunately detracted from what I thought was a rather fascinating topic, which is: how do you build a deterministic interpreter into a Transformer?
Since Percepta didn’t really give away the secret sauce on that, I couldn’t resist making my own attempt. In the process I learned a lot about Transformers and thought it might be interesting to write a bit about it and post some code. Note that Percepta claims to have built a full WASM interpreter, and their article goes into some interesting details on taking advantage of 2D attention heads for fast sampling; I don’t address either of these here, though they are interesting in their own right.
Instead I just focused on the basic implementation of a simple RPN calculator. I like RPN because it’s just complex enough to formally require a stack, but is pretty much devoid of complicated syntax parsing issues. This meant I could really focus on just the basics: how to actually model a stack using an append-only tape mechanism, how to determine what operator to execute next and what the operands are, and how to determine when we’re done.
But first, let’s talk about how Transformers work, how to think of one as a “machine” and compilation target, and what determinism even means in this context.
By the way this article obviously got quite long as I wanted to describe in some detail how the final program works. Feel free to skip down to Conclusions if it gets boring.
Note: I do not want to claim that this view of a Transformer-as-register-machine is “actually how LLMs work.” Please think of it the other way around: this is an affordance that Transformers have. It is one possible state they can achieve, as proven constructively here, but it does not mean that this is the structure that they learn when trained on trillions of tokens.
At best, it may exist as a subspace that comes from enough training on math and algorithmic problems, and I do think the emergence of such structures could partially explain double descent and the grokking phenomenon2, and, who knows, maybe even by association relate to emergent abilities3, but this is nothing but conjecture.
The Transformer as a register machine
You probably know the (decoder-only) Transformer as a stack of layers alternating between self-attention and feedforward networks (FFNs), with layer normalization and dropout sprinkled somewhere in there. The FFNs are usually MLPs, i.e., single-hidden-layer neural networks with a non-linearity, typically GELU or SwiGLU (used in Llama, for example) these days. I’m not going to dive into the particularities of different activation functions, but overall we’re going to stick to this model, and we’re even going to remove layer normalization. Dropout doesn’t apply since we are not training.
Our target is to evaluate expressions like this one:
3 4 + 3 3 + *
and output 42.
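For reference, the target semantics are just those of a plain stack evaluator. Here’s a minimal Python version of what we want the Transformer to do (purely illustrative; it plays no part in the construction):

def eval_rpn(expr: str) -> int:
    """Reference semantics: evaluate a space-separated RPN expression."""
    stack = []
    for tok in expr.split():
        if tok == '+':
            b, a = stack.pop(), stack.pop()
            stack.append(a + b)
        elif tok == '*':
            b, a = stack.pop(), stack.pop()
            stack.append(a * b)
        else:
            stack.append(int(tok))
    return stack.pop()

assert eval_rpn("3 4 + 3 3 + *") == 42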
The implementation is mostly in numpy because we don’t even need to calculate any gradients here:
class BasicTransformer:
    def forward(self, X_input: np.ndarray) -> np.ndarray:
        X = self.embed_input(X_input)
        for layer in range(self.L):
            # Calculate attention for each head
            H_out = [self.mha(X, self.Q[layer][head], self.K[layer][head], self.V[layer][head])
                     for head in range(self.H)]
            # Stack results and mix it together with O
            X = X + np.hstack(H_out) @ self.O[layer]
            # Apply FFNs
            X = X + np.asarray([self.MLP[layer](x) for x in X])
        return X
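The mha call above is ordinary single-head, causally masked, scaled dot-product attention. For completeness, here is a minimal numpy sketch of it (the repository version may differ in details):

import numpy as np

def mha(X, Wq, Wk, Wv):
    # Project into query/key/value spaces
    Q, K, V = X @ Wq, X @ Wk, X @ Wv
    scores = Q @ K.T / np.sqrt(Q.shape[-1])
    # Causal mask: a token can only attend to itself and earlier tokens
    scores[np.triu(np.ones(scores.shape, dtype=bool), k=1)] = -1e30
    alpha = np.exp(scores - scores.max(axis=-1, keepdims=True))
    alpha /= alpha.sum(axis=-1, keepdims=True)
    return alpha @ V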
So, our goal is to figure out how to target the above, i.e., how to design a series of \(W_Q, W_K, W_V\), and \(W_O\) matrices, and an accompanying set of MLP blocks, to do what we want.
Immediately, a first lesson that I hadn’t appreciated before this project: the \(W_O\) matrix. It made sense to me that after performing per-head value lookups you might need to “remix” things but I hadn’t really appreciated why. In fact I sort of forgot that \(W_O\) existed. But, the need for it has everything to do with the residual structure.
To understand this, we have to rewind a bit and think about what is in that residual, and therefore what is in \(X\), and in turn, what structure is present in the embeddings?
Previously I considered embeddings and the residual pretty much like a “cloud”, clearly some sort of latent embedding space whose structure is not worth worrying about too much, as long as it gets the job done. It maps tokens to a semantic space, that’s all I need to know. But if we want to build a program interpreter, we need values, and operators that mutate those values. We need structure.
Another misconception: I had this tendency to think about neural networks as immutable, functional programs. After all, the “tape” is indeed immutable, we can only produce new tokens, not delete old ones. And each layer just processes the input and generates its own optimal latent space. Now I understand an alternate view: we can instead think of \(X\) as having a persistent structure with state that is modified at each layer. In this view, each layer is a “modify” operation on local storage \(X\).
The residual structure of Transformers even emphasizes that. If the residuals are deltas that modify the embeddings, i.e., \( X_{l+1} = X_{l} + \Delta_{l} \), then we can assume \(\Delta_{l}\) comprises a state update on some subspace of \(X_{l}\). Each of these subspaces is effectively a register, and each \(\Delta_{l}\) writes the result of a calculation.
Figure 1: View of a transformer layer reading and writing a set of “registers” like a CPU. Registers are just subsets of our embedding space.
With this view, every attention and MLP block has an opportunity to “execute” a CPU instruction. The caveat is that, as long as we’re sticking to standard Transformers, we have registers but we don’t have program flow, i.e., no loops: if we get to the last layer, we are forced to output something. So we have a choice: have enough layers for the calculations we need to do, or output a breadcrumb that helps us pick up where we left off the next time around.
Seeing \(X\) as a collection of independent registers also allows a certain decoupling. An attention head in Layer 2 can prepare some input for an MLP in Layer 8, as long as the register location (subspace of \(X\)) is not overwritten in between. This realization led me to think about the depth-wise structure of the Transformer more as a DAG of functions that operate on registers, storing values for later consumption, very much like a computer. Thinking of every layer as having its own blobby, blurry representation space is incorrect from this point of view: instead, layers can read and write specific memory locations. This might explain why we can get away with looping layers, allowing for some truly crazy model surgery4.
Back to \(W_O\): while the MLP receives the full \(X\) and emits a full \(\Delta\), this is not so for attention heads. In multihead attention, we divide up the dimension space. For example, if \(D_{model}=1024\) and we have 4 heads, then \(D_{head}=256\). Things naturally have to get re-arranged. This is what the \(W_Q\), \(W_K\) and \(W_V\) matrices are for: pick out parts of \(X\) that we are interested in. Once we’ve used \(W_V\) to return, say, registers 2, 5, 1, and 3 (averaged over matching locations in the sequence) in our 4 attention heads, they cannot just be concatenated to produce \(\Delta\), because the 4 heads are independent and they do not line up with the structure of \(X\). That is where \(W_O\) comes in: to take the results of each head and move them to the target register locations.
Or at least this is true as long as we stick to block-identity matrices. Later, I discovered more interesting uses for \(W_Q, W_K, W_V,\) and \(W_O\). More on that in the next sections.
Attention as an information router
In the above, I had mainly thought of attention as essentially a routing layer: pick out interesting tokens in the prefix according to certain registers using \(W_Q\) and \(W_K\), and then retrieve another register via \(W_V\) as the value. Deliver averaged \(V\) to the MLP where the real work gets done.
Initially I figured I’d keep things simple for this experiment and just stick to one-hot encodings for everything, so I was just populating \(W_Q, W_K, W_V\) with block-identity matrices. But it was a lot of tracking of dimension offsets in \(X\), and I wanted to be able to identify registers by name, so I created a class for this purpose:
class Register:
    """A memory mapping over the residual stream defining data type and size constraints."""
    def __init__(self, name: str, size: int = 0):
        self.name = name
        self.size = size
        self.offset = 0
        self.type: Optional[RegisterType] = None
Registers just give a name to a subset of our embedding dimensions, from offset to
offset+size. By ensuring that registers don’t share dimensions, they automatically
represent orthogonal subspaces in the embedding space. We also give each register a
type which is the representation to use. Assume it is “one hot” for now, I’ll get into
representations in the next section.
A head, which “looks up” registers in \(X\), can now refer to them by name:
class Head:
    """Defines Query/Key matching constraints for an Attention Head."""
    def __init__(self, match_on: Optional[Dict[str, int|str|Modifier]] = None,
                 provide: Optional[List[str]] = None,
                 out_regs: Optional[List[str]] = None):
        ...
The job of the compiler then is to assign the registers actual locations (i.e. fill in
Register.offset) and then populate \(W_Q, W_K, W_V\) accordingly. I decided to name the
parameters after what I want to do rather than these cryptic letters. But they directly
map. When scanning the sequence, we want to match two registers, maybe based on value or
position or some other property, so we provide a mapping, a
dict[str,int|str|Modifier]. (int is for when we want to match a specific value, str
is when we want to match the contents of a register; I’ll dig into Modifier a bit
later.) This information is used to generate \(W_Q\) and \(W_K\). Then we want to
provide some other register as a value, which generates \(W_V\).
Since we’re just selecting, we can see \(W_Q, W_K, W_V\) as having 1s in the locations where we want to copy a register and 0s everywhere else; this is illustrated in Figure 2.
Figure 2: How \(W_V\) extracts registers, and \(W_O\) copies the result to the destination. Note that \(W_Q\) and \(W_K\) work the same way as \(W_V\) here to extract registers to produce \(Q\) and \(K\), and \(\alpha=\textrm{softmax}(\frac{QK^T}{\sqrt{d}})\) performs the match by producing scores for the weighted sum over \(V_i\).
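To make Figure 2 concrete, here is a minimal numpy sketch (the sizes and offsets are made up for illustration) of a block-identity \(W_V\) that extracts one register, and the corresponding slice of \(W_O\) that drops the head’s output into a different register:

import numpy as np

D_MODEL, D_HEAD = 16, 4
src_off, dst_off = 8, 12   # hypothetical source/destination register offsets

# W_V: copy dims [src_off, src_off+D_HEAD) of X into the head's D_HEAD dims
W_V = np.zeros((D_MODEL, D_HEAD))
W_V[src_off:src_off + D_HEAD, :] = np.eye(D_HEAD)

# This head's rows of W_O: copy the head output into dims [dst_off, dst_off+D_HEAD)
W_O_head = np.zeros((D_HEAD, D_MODEL))
W_O_head[:, dst_off:dst_off + D_HEAD] = np.eye(D_HEAD)

x = np.zeros(D_MODEL)
x[src_off + 2] = 1.0               # source register holds one-hot value 2
delta = (x @ W_V) @ W_O_head       # the value lands in the destination register
assert delta[dst_off + 2] == 1.0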
Finally, when aggregating across heads, we’ll want to shift that provided register, now
summed across tokens, into a new position, identified by out_regs, which lets us
construct \(W_O\).
Now Layer is very simple: it just collects heads and associates them with a processing
function:
class Layer:
    def __init__(self, name: str):
        self.name = name
        self.heads: List[Head] = []

    # Attention heads to find relevant information
    def add_head(self, head: Head):
        self.heads.append(head)

    # Perform the work
    def mlp(self, input: np.ndarray) -> np.ndarray:
        pass
The idea here was that, as seen in forward in the initial code snippet above, we just
execute all the heads, construct the \(\Delta\) by distributing relevant outputs to the
corresponding register locations via \(W_O\), and pass our updated state into the MLP,
which in turn can update any register with its own \(\Delta\).
Compute blocks
After performing attention and updating the state, we now want to execute some non-linear
computation in the MLP. The approach I took was to write out the operation I wanted to
perform in Python code. I made a decorator @compute that declares to the compiler what
registers are read and written by that compute block, and actual reading and writing goes
through a State class that allows access by name.
@graph.compute(inputs=['exec_pos', 'pos'], outputs=['is_prompt'])
def check_prompt(state):
    # code that reads `exec_pos` and `pos` and writes the
    # delta for `is_prompt` through the `state` interface
    ...
The compiler then assigns attention heads and compute blocks to layers, according to the
dependency graph between variables. My goal was to generate an actual Transformer though,
so after compiling the network into Layers, I added a mechanism to call each layer’s
compute function inside a PyTorch Dataset class. I used this to distill the logic into
an actual MLP.
At inference time, I can choose to run the Python function or replace it with a PyTorch call on the trained MLP.
This might sound like I am cheating here by going back to training. But this is quite different from training the full network end to end! We are training each MLP in isolation, layer by layer, with full knowledge of the input and output domain. Since we cover the full domain, we are also trying to overfit in order to produce an exact, deterministic value for any given input. We are trying to create an exact copy of the Python logic. For this reason, I continue training until the outputs are not only accurate, but also very close to exactly the target value, usually 0, 1 or -1; otherwise we risk building up error in deeper layers. I used a Huber (smooth L1) loss and cosine annealing on the learning rate.
I will note that ensuring my compute functions were actually learnable was much less straightforward than I expected. Multiple times I had to simplify them or split them up. In the future I’d like to explore a constructive approach, building the MLP weights directly instead of distilling them this way. I suspect it is possible for some operations, but I didn’t go there yet. Maybe in a future post.
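For the curious, the per-layer distillation loop is conceptually nothing more than the following sketch (hedged: layer_fn, sample_inputs, and the sizes are placeholders, not the repo’s actual API):

import torch
import torch.nn as nn

def distill_layer(layer_fn, sample_inputs, d_model, hidden, steps=10_000):
    """Overfit one MLP to exactly reproduce a layer's Python compute logic."""
    mlp = nn.Sequential(nn.Linear(d_model, hidden), nn.GELU(),
                        nn.Linear(hidden, d_model))
    opt = torch.optim.Adam(mlp.parameters(), lr=1e-3)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=steps)
    loss_fn = nn.SmoothL1Loss()  # Huber-style loss, as described above
    for _ in range(steps):
        X = sample_inputs()      # batch covering the layer's full input domain
        target = layer_fn(X)     # ground truth comes from the Python function
        loss = loss_fn(mlp(X), target)
        opt.zero_grad()
        loss.backward()
        opt.step()
        sched.step()
    return mlp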
Liveness analysis
Since each attention and MLP block is independent and can only communicate through \(X\), we need registers for every value that is transmitted from one layer to another. Without care, these add up quickly, forcing a huge dimensionality for \(X\), which leads to massive MLPs. However, many of these intermediate values are “temporary variables” whose subspace can be reused later.
Since we define a dependency graph over our named references, we can figure out when a register is no longer needed, so we know we can safely re-use its location. This starts to decouple the concept of logical “variables” from the specific dimensions used, which is really a good analogy to how a C compiler will assign variables to hardware registers that are later reused for a different purpose.
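A minimal sketch of the idea (not the repo’s actual compiler code, and treating each register as a single lane for simplicity): record each register’s last reader, then return its lane to a free pool once that layer has passed.

def allocate_lanes(writes, reads):
    """writes/reads: {layer_index: [register names]}.
    Returns {register: lane}, reusing a lane once its register is dead."""
    last_read = {}
    for layer, regs in reads.items():
        for r in regs:
            last_read[r] = max(last_read.get(r, -1), layer)
    assignment, live, free, next_lane = {}, {}, [], 0
    for layer in sorted(writes):
        # Reclaim lanes whose register was last read before this layer
        for r in [r for r in live if last_read.get(r, -1) < layer]:
            free.append(live.pop(r))
        for r in writes[layer]:
            if free:
                lane = free.pop()
            else:
                lane = next_lane
                next_lane += 1
            assignment[r] = live[r] = lane
    return assignment

For example, a register written in layer 1 and last read in layer 3 frees its lane for anything written in layer 4 or later.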
Attention as a linear processor
While I initially got something working with the above, I was still using Python functions for my “MLP” to calculate the per-token transformation. When I started trying to distill these functions into actual MLPs, I quickly hit some walls that took me much longer to solve than I’d like to admit. I was doing too much in my functions, and the MLP couldn’t model it. One of the solutions was using representations other than one-hot encoding.
Take a requirement like the following:
@graph.compute(inputs=['exec_pos', 'pos'], outputs=['is_prompt'])
def check_prompt(state):
    # We are in the prompt if exec_pos hasn't been hit yet
    # or if current position is before the exec_pos.
    # Note that exec_pos==0 for the first token since attention is empty!
    state.write_bool('is_prompt', (
        state.read_idx('exec_pos') <= 0
        or state.read_idx('pos') <= state.read_idx('exec_pos')
    ))
Essentially, one-hot codes are not good for something like an inequality comparison, because the number “categories” are completely orthogonal. The MLP basically has to memorize which numbers are greater than which, forcing it to build an \(O(N^2)\)-space lookup table for \(N\) values. It would be better to use a representation that affords direct comparison. One option is the so-called “thermometer” representation, which is just a “left-filled” version of the one-hot encoding. For example:
| number | one hot (allow_zero=True) | thermometer |
|---|---|---|
| null | 0 0 0 0 0 | 0 0 0 0 0 |
| 1 | 0 1 0 0 0 | 1 1 0 0 0 |
| 3 | 0 0 0 1 0 | 1 1 1 1 0 |
The value \(v\) of some thermometer value \(t\) can be recovered by \(v=(\Sigma{t})-1\). The advantage is that we can perform an inequality comparison very easily by just comparing the values digit by digit and summing the results. This is much easier for an MLP to learn. (The -1 comes from allowing “no number” to be represented by the value -1, corresponding to all one-hot digits being 0. Not sure this was the best possible choice but it works.)
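A quick illustrative sketch of the encoding and why comparisons become easy (mirroring the table above):

import numpy as np

def thermometer(v, n=5):
    """v = -1 encodes 'null' (all zeros); otherwise set the v+1 leading digits."""
    t = np.zeros(n)
    t[:v + 1] = 1.0
    return t

a, b = thermometer(1), thermometer(3)          # [1 1 0 0 0] and [1 1 1 1 0]
assert a.sum() - 1 == 1 and b.sum() - 1 == 3   # v = (sum of t) - 1
# a <= b iff b has a 1 everywhere a does: a dot product an MLP can learn easily
assert (a * b).sum() == a.sum()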
I had a bit of an aha moment here, when I understood that this conversion from one-hot to
thermometer could actually be done by the \(W_V\) matrix in the attention block. Since the
one-hots are basically just lookups, instead of using block-identity matrices, we can
convert the one-hots into any representation we want, performing a local embedding
specific to the needs of a given MLP block. So for my example of a one-hot register with
value 3, above, we can program \(W_V\) as a thermometer lookup table like this:
     reg              Wᵥ                 V
[ 0 0 0 1 0 ]   [ 1 0 0 0 0 ]     [ 1 1 1 1 0 ]
                [ 1 1 0 0 0 ]
                [ 1 1 1 0 0 ]  =
                [ 1 1 1 1 0 ]
                [ 1 1 1 1 1 ]
Effectively \(W_V\) becomes an extra linear layer for the MLP, performing a change of representation to simplify the downstream job.
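Constructing that \(W_V\) is a one-liner: row \(i\) of the matrix is the thermometer code for value \(i\), so a one-hot row vector simply selects it. A sketch with the 5-value example from above:

import numpy as np

N = 5
W_V = np.tril(np.ones((N, N)))   # row i = thermometer code for value i

one_hot_3 = np.array([0., 0., 0., 1., 0.])
assert np.array_equal(one_hot_3 @ W_V, np.array([1., 1., 1., 1., 0.]))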
This could also be used for conversion to scalars, for encoding into quadrature (sin/cos) representation, or anything really. I started seeing this as the “type” of the variable, and the remapping of the one-hots as a type conversion.
So, to address the full generality and also have some basic compatibility checking in
place, I introduced “types” to the registers. I realized that type conversion could also
be useful in building queries. I added a general class called Modifier that can wrap
the register name at the moment it is specified to Head. It modifies the associated
linear mapping, also annotating, when appropriate, the result with a type conversion.
# Locate the EXEC token and fetch a Thermometer vector to handle inequalities
graph.fetch(Head(match_on={'val': TokenID.EXEC.value},
                 provide=['pos', Thermometer('pos')],
                 out_regs=['exec_pos', 'exec_thermo']))
Modifiers are also useful for things like tie breaking (e.g. finding the “right most” token, if multiple match the query), performing classification, or converting to scalars. They basically just change what is written to \(Q, K\) or \(V\) from a block-identity to something else.
Currently I have several RegisterType classes defined: one-hot, thermometer, constant
value, boolean, scalar; and several Modifier classes also. In addition to
Thermometer, I used Shift for modifying position during a query, TieBreak for
biasing a position query to the left or right, and Embed for embedding knowledge about
specific tokens.
Deterministic Sampling
This is all well and good but we haven’t addressed the question of determinism. The Transformer still generates a multinomial distribution that we sample from. To make it deterministic, we can’t just focus on the “most likely token”. We have to think about how to make it the only token. Since this post is already getting long, and this topic got so interesting, I think I will dedicate a future blog post to it, instead of elaborating here. But to summarize, these are the topics I plan to cover:
- Scaling \(Q\) to produce a large logit gap — this is found in the implementation as ATTN_MATCH_BOOST. A large amplitude of \(Q\) causes softmax to act as a hardmax function (a tiny demo of this follows the list).
- Determinism and orthogonality — there is an interesting relationship between how orthogonal we need to keep our “registers” and the achievable entropy of the predicted distribution.
- Reducing orthogonality constraints by considering a top-P threshold — the previous problem can be relaxed by determining how much orthogonality we need to preserve to stay under some probability threshold; if the sampler’s threshold corresponds, we can maintain our determinism guarantees.
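As a teaser, the first mechanism fits in a few lines of numpy: boosting the logit of the matching position drives softmax arbitrarily close to hardmax, which is what ATTN_MATCH_BOOST exploits.

import numpy as np

def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()

logits = np.array([1.0, 0.0, 0.0])   # one matching position, two non-matching
for boost in (1, 10, 100):
    print(boost, softmax(boost * logits).round(6))
# 1   [0.576117 0.211942 0.211942]
# 10  [0.999909 0.000045 0.000045]
# 100 [1. 0. 0.]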
Compiling the RPN Calculator (The Code)
I’ve described most of the high level concepts needed to understand what is in transformer.py and compiler.py. Here I’ll go through the RPN program that I came up with. The source can be found in prog2_rpn_crumbs.py.
Without further ado, the output looks like:
$ uv run python src/test.py prog2_rpn_crumbs
[*] Compiling logic architecture for prog2_rpn_crumbs...
[*] Optimization: Aliasing 'exec_pos' into existing lane at offset 301
[*] Optimization: Aliasing 'exec_thermo' into existing lane at offset 351
[*] Optimization: Aliasing 'pos_scalar' into existing lane at offset 457
[*] Optimization: Aliasing 'is_prompt' into existing lane at offset 457
[*] Optimization: Aliasing 'log_pos' into existing lane at offset 301
[*] Optimization: Aliasing 'log_pos_gated' into existing lane at offset 0
[*] Optimization: Aliasing 'V_num' into existing lane at offset 251
[*] Optimization: Aliasing 'V_op' into existing lane at offset 351
[*] Optimization: Aliasing 'left_val' into existing lane at offset 401
[*] Optimization: Aliasing 'next_emit_plus' into existing lane at offset 151
[*] Optimization: Aliasing 'active_ops_map' into existing lane at offset 201
[*] Optimization: Aliasing 'found_first_op' into existing lane at offset 251
[*] Optimization: Aliasing 'next_emit_mult' into existing lane at offset 351
[*] Optimization: Aliasing 'active_nums_map' into existing lane at offset 458
[*] Optimization: Aliasing 'found_closest_num' into existing lane at offset 201
[*] Optimization: Aliasing 'next_emit1' into existing lane at offset 51
[*] Optimization: Aliasing 'active_nums_raw' into existing lane at offset 0
[*] Optimization: Aliasing 'active_ops_raw' into existing lane at offset 0
[*] GraphCompiler topologically sorted DAG into L=8 Layers and H=4 Heads.
-----------------------------------------------------------------
Prompt: 3 4 + EXEC
Output: c2 c1 c0 7
Prompt: 3 4 + 3 3 + * EXEC
Output: c2 c1 c0 7 c5 c4 c3 6 c6 c5 c2 42
Prompt: 10 2 3 * + 2 + EXEC
Output: c3 c2 c1 6 c4 c3 c0 16 c6 c5 c4 18
The first half is the compiler figuring out how to allocate registers to the embedding dimensions.
The second half is actually processing the prompt. The program outputs interim tokens to track progress as it evaluates the operators, which allows us to evaluate one operator at a time instead of trying to have enough “depth” of layers to process an entire expression in one shot. Essentially it spends some rounds figuring out what positions to consider as operands and operators that go together, and it lays them out in triplets. These special “c” tokens are pointers to the numerical tokens to evaluate. If an incomplete triplet is detected, it emits the next token in the triplet; if a full triplet is detected, the result of the operation (addition or multiplication) is emitted.
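To make the traces easier to read, here is a tiny hypothetical helper (not part of the repo) that decodes the pointer tokens of the first trace against its prompt. Note that in longer traces the pointers refer to logical positions, which can land on intermediate results rather than prompt tokens, so this naive decoder only covers the simplest case:

def decode_trace(prompt: str, output: str):
    toks = prompt.split()
    for tok in output.split():
        if tok.startswith('c'):
            pos = int(tok[1:])
            print(f"{tok}: consume '{toks[pos]}' at position {pos}")
        else:
            print(f"{tok}: emit result")

decode_trace("3 4 + EXEC", "c2 c1 c0 7")
# c2: consume '+' at position 2
# c1: consume '4' at position 1
# c0: consume '3' at position 0
# 7: emit result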
I don’t think this use of outputting tokens for intermediate values and tracking what we have consumed is automatically related to “chain of thought”, but… maybe it’s not unrelated? Certainly, using the tape as temporary storage is a trade-off compared to allowing actual depth-wise “loops” inside the network, so it’s interesting that both depth recursion and “thought tokens” have been important topics of LLM research in recent years.
Walk-through
The program is not that big, but big enough that I need a good strategy for walking through it. I’ll try to go through the @compute functions, of which there are 14, and then refer back to the attention heads that they require. I will say, though, that there is probably some more efficiency we could squeeze out of this by moving some of these compute functions into attention heads.
I already showed check_prompt above; it outputs is_prompt, a simple boolean telling us whether we are inside the equation to be executed or after it, i.e., is_prompt is only 1 before the | divider here:
3 3 + 2 * EXEC | c2 c1 c0 6 c4 c3 12
is_prompt=1 | is_prompt=0
So the “prompt” includes the EXEC token that indicates where we want to start
execution. By the way pos is a one-hot register giving us the current position in the
sequence. It simply increments by one at each token.
Next, we have calculation of the logical position. The logical position is “where are we in the partially-evaluated equation.”
First we have some lookups that provide us with these classifications, which I packed into a single head:
graph.fetch(Head(match_on={'pos': 'pos'}, provide=[
    Embed('val', 1, is_num_fn),
    Embed('val', 1, is_op_fn),
    Embed('val', 1, is_ptr_fn),
    Embed('val', 1, is_exec_fn),
    Embed('val', MAX_TAPE_LEN, ptr_val_fn)
], out_regs=['v_is_num', 'v_is_op', 'v_is_ptr', 'v_is_exec', 'v_ptr_val']))
We also have some heads that find the 3 tokens to the left of the current position:
# Extract properties of T-1 token
graph.fetch(Head(match_on={'pos': Shift('pos', -1)}, provide=[
    Embed('val', 1, is_ptr_fn),
    Embed('val', 1, is_lt_end_fn),
    Embed('val', MAX_TAPE_LEN, ptr_val_fn)
], out_regs=['m1_is_ptr', 'm1_lt_end', 'm1_ptr_val']))

# Extract properties of T-2 token
graph.fetch(Head(match_on={'pos': Shift('pos', -2)}, provide=[
    Embed('val', 1, is_ptr_fn),
    Embed('val', MAX_TAPE_LEN, ptr_val_fn)
], out_regs=['m2_is_ptr', 'm2_ptr_val']))

# Extract properties of T-3 token
graph.fetch(Head(match_on={'pos': Shift('pos', -3)}, provide=[
    Embed('val', MAX_TAPE_LEN, ptr_val_fn)
], out_regs=['m3_ptr_val']))
Here, Shift is a modifier that matches precisely on some shift of the current position.
The is_* functions provide these lookups for various token types. You can see how I’ve
laid out the token IDs here: IDs up to 42 are numeric values, then we have,
class TokenID(IntEnum):
    """Shared vocabulary integer IDs for structural tokens."""
    PLUS = 43
    MULTIPLY = 44
    EXEC = 45
    END = 46
    START = 47
From 50 (PTR_OFFSET) to 100, tokens are considered pointer tokens (also called “consume”
or “c” tokens, but I like to call them “breadcrumbs”). So here are the functions passed to
the Embed modifier above:
def is_num_fn(x):
    return 1.0 if x < TokenID.PLUS.value else 0.0

def is_op_fn(x):
    return 1.0 if x in (TokenID.PLUS.value, TokenID.MULTIPLY.value) else 0.0

def is_ptr_fn(x):
    return 1.0 if x >= PTR_OFFSET else 0.0

def is_exec_fn(x):
    return 1.0 if x == TokenID.EXEC.value else 0.0

def is_lt_end_fn(x):
    # Identifies "Value" tokens (Numbers, Operators, EXEC) as opposed to Pointers
    return 1.0 if x < TokenID.END.value else 0.0

def ptr_val_fn(x):
    # Projects a pointer token (e.g. 53) back into an un-offset index (e.g. 3)
    res = [0.0] * MAX_TAPE_LEN
    if x >= PTR_OFFSET and (x - PTR_OFFSET) < MAX_TAPE_LEN:
        res[x - PTR_OFFSET] = 1.0
    return res
Finally we can actually calculate it according to the following rules:
- If we are in the prompt, then the logical position is the current position.
- If we are after the EXEC token, then:
  - If we have a number, the logical position is that of the operator that calculated it.
  - If we have a pointer, the logical position is wherever we point.
@graph.compute(inputs=['is_prompt', 'pos', 'v_is_num', 'v_is_ptr', 'v_ptr_val', 'm3_ptr_val'],
               outputs=['log_pos'])
def calc_log_pos(state):
    if state.is_true('is_prompt'):
        pos = state.read_idx('pos')
        if pos != -1:
            state.write_idx('log_pos', pos)
    elif state.is_true('v_is_num'):
        # A newly generated math result inherits the position of the Operator
        m3 = state.read_idx('m3_ptr_val')
        if m3 != -1:
            state.write_idx('log_pos', m3)
        else:
            state.write_idx('log_pos', 0)
    elif state.is_true('v_is_ptr'):
        # Pointers inherit the logical position of what they point to
        vp = state.read_idx('v_ptr_val')
        if vp != -1:
            state.write_idx('log_pos', vp)
We also need a version of log_pos that only indicates non-pointer values. (So in c2 c1 c0 6, only 6 would get a position. For the “c” tokens, log_pos_gated is -1.)
@graph.compute(inputs=['log_pos', 'v_is_ptr'], outputs=['log_pos_gated'])
def gate_log_pos(state):
    # Excludes pointers so future queries only match concrete values
    lp = state.read_idx('log_pos')
    if lp != -1:
        if not state.is_true('v_is_ptr'):
            state.write_idx('log_pos_gated', lp)
Now a trick: when a pointer to a number or operator is encountered, we “annihilate” the original value so that it doesn’t get returned. This allows “c” tokens to “consume” values that have already been processed. Remember, we are always averaging across the sequence, so when we have matches on both the original and pointer tokens, the 1.0 and -1.0 will cancel out, marking the value as “consumed”!
@graph.compute(inputs=['log_pos', 'v_is_num', 'v_is_op', 'v_is_ptr', 'm1_is_ptr', 'm1_lt_end'],
               outputs=['V_num', 'V_op'])
def track_annihilations(state):
    lp = state.read_idx('log_pos')
    if lp != -1:
        # 1. Mark creation of nodes
        if state.is_true('v_is_num'):
            state.write_idx('V_num', lp, 1.0)
        elif state.is_true('v_is_op'):
            state.write_idx('V_op', lp, 1.0)
        # 2. Mark annihilation of nodes when a pointer replaces them
        if state.is_true('v_is_ptr'):
            if state.is_true('m1_is_ptr'):
                # The 2nd and 3rd pointers consume Numbers (Right and Left operands)
                state.write_idx('V_num', lp, -1.0)
            elif state.is_true('m1_lt_end'):
                # The 1st pointer consumes the Operator
                state.write_idx('V_op', lp, -1.0)
So we can now just average V_num and V_op across the whole sequence and get only the
“active” tokens:
graph.fetch(Head(match_on={}, provide=['V_num'], out_regs=['avg_nums']))
graph.fetch(Head(match_on={}, provide=['V_op'], out_regs=['avg_ops']))
Notice the match_on={} here. Since we are just matching zeros, this causes softmax to
produce a full (unweighted) average across the whole sequence.
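A quick numeric check of both tricks at once (illustrative): with all-zero queries every score is equal, so softmax returns uniform weights, and a +1.0 creation cancels a -1.0 annihilation in the averaged lane.

import numpy as np

T = 6                                  # tokens seen so far
scores = np.zeros(T)                   # match_on={} means Q.K = 0 everywhere
alpha = np.exp(scores) / np.exp(scores).sum()
assert np.allclose(alpha, 1.0 / T)     # softmax of zeros = plain average

# V_num lane for one logical position: created (+1.0), later consumed (-1.0)
V_num = np.array([0., 1., 0., 0., -1., 0.])
assert np.isclose(alpha @ V_num, 0.0)  # the consumed token vanishes from the average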
We convert these to “maps” of active tokens, i.e., layouts similar to the one-hot position
encoding, but with a 1 where the token is active and a 0 where it is not. We call this a
multi-hot representation. Note that since avg_ops and avg_nums are averages across all
tokens up to the current one, we have to scale by the current position before comparing
to the threshold.
For operators:
@graph.compute(inputs=['pos_thermo', 'avg_ops'], outputs=['active_ops_map'])
def prepare_ops_map(state, token_num, num_tokens):
    p = state.read_thermo('pos_thermo')
    if p != -1:
        a_ops = state.read('avg_ops') * (p + 1)
        ops_map = (a_ops > ONE_HOT_THRESHOLD).astype(float)
        state.write_vec('active_ops_map', ops_map)
Same for numbers, but just a note that I added an extra, seemingly useless dependency on
active_ops_map here. This is not strictly necessary but I found that prepare_ops_map
and prepare_active_nums_map could not be distilled into the same MLP, so this link
forces the compiler to put these functions in separate layers.
@graph.compute(inputs=['pos_thermo', 'avg_nums', 'active_ops_map'], outputs=['active_nums_map'])
def prepare_active_nums_map(state):
    p = state.read_thermo('pos_thermo')
    if p != -1:
        a_nums = state.read('avg_nums') * (p + 1)
        nums_map = (a_nums > ONE_HOT_THRESHOLD).astype(float)
        state.write_vec('active_nums_map', nums_map)
At this stage we have identified all non-canceled tokens, but we need to decide what to do next. If the next thing to consume is a number or operator, we cancel it with a pointer; if we have a full triplet of two numbers and an operator, we emit the result of the operation. Remember that all of these functions are always executed, so they have to store their results; the specific result we need for the next token is selected from them later.
If we are starting a new execution, we first need to identify the operator, so store 1 in
search_first_op:
@graph.compute(inputs=['v_is_exec', 'v_is_num', 'is_prompt'], outputs=['search_first_op'])
def decide_first_op(state):
    # We start an execution chain when hitting EXEC or finishing a math step
    if state.is_true('v_is_exec'):
        state.write_bool('search_first_op', True)
    elif state.is_true('v_is_num'):
        if not state.is_true('is_prompt'):
            state.write_bool('search_first_op', True)
If we are looking for operands, dereference the pointers:
@graph.compute(inputs=['v_is_ptr', 'm1_is_ptr', 'm2_is_ptr', 'm1_lt_end', 'v_ptr_val'],
               outputs=['search_num_tgt_thermo', 'search_num_tgt'])
def decide_search_num(state):
    # If we have 1 or 2 pointers, we must search left for the next operand
    if state.is_true('v_is_ptr'):
        if not (state.is_true('m1_is_ptr') and state.is_true('m2_is_ptr')):
            v_ptr = state.read_idx('v_ptr_val')
            if v_ptr != -1:
                state.write_idx('search_num_tgt', v_ptr)
                state.write_thermo('search_num_tgt_thermo', v_ptr)
Recognize the situation where we are ready to emit a calculation! In this case we emit the left and right operands and the operator into dedicated registers.
@graph.compute(inputs=['v_is_ptr', 'm1_is_ptr', 'm2_is_ptr', 'v_ptr_val', 'm1_ptr_val', 'm2_ptr_val'],
               outputs=['q_left', 'q_right', 'q_op'])
def decide_execute(state):
    # If we have 3 pointers in a row, the instruction is fully built!
    if state.is_true('v_is_ptr'):
        if state.is_true('m1_is_ptr'):
            if state.is_true('m2_is_ptr'):
                v_ptr = state.read_idx('v_ptr_val')
                m1_ptr = state.read_idx('m1_ptr_val')
                m2_ptr = state.read_idx('m2_ptr_val')
                if v_ptr != -1:
                    state.write_idx('q_left', v_ptr)
                if m1_ptr != -1:
                    state.write_idx('q_right', m1_ptr)
                if m2_ptr != -1:
                    state.write_idx('q_op', m2_ptr)
Finally we apply a “tie breaker” modifier which weights the logical position by the token position, allowing us to select the left-most non-consumed token of the desired type.
tie_op = {
    'log_pos_gated': 'active_ops_map',
    TieBreak('log_pos'): TieBreak('log_pos', weight=-10.0)
}
graph.fetch(Head(match_on=tie_op, provide=['log_pos_gated'], out_regs=['found_first_op']))
tie_num = {
    # Ensure the token is an active number
    'log_pos_gated': 'active_nums_map',
    # Ensure log_pos < search_num_tgt
    'log_pos': 'search_num_tgt_thermo',
    # Tie-Break by favoring the highest (right-most) logical position
    TieBreak('log_pos'): TieBreak('log_pos', weight=10.0)
}
graph.fetch(Head(match_on=tie_num, provide=['log_pos_gated'], out_regs=['found_closest_num']))
With this we are ready to emit pointer tokens or end the sequence. We store this result in
a temporary register next_emit1 for later routing.
@graph.compute(inputs=['search_first_op', 'search_num_tgt', 'found_first_op', 'found_closest_num'],
               outputs=['next_emit1'])
def route_found_targets(state):
    if state.is_true('search_first_op'):
        f_op = state.read_idx('found_first_op')
        if f_op != -1:
            state.write_idx('next_emit1', PTR_OFFSET + f_op)
        else:
            state.write_idx('next_emit1', TokenID.END.value)
    elif state.read_idx('search_num_tgt') != -1:
        f_num = state.read_idx('found_closest_num')
        if f_num != -1:
            state.write_idx('next_emit1', PTR_OFFSET + f_num)
When appropriate, we also calculate the addition and multiplication of the active operator and operands, if they are all found. If the result is out of range, we just leave it unwritten.
graph.fetch(Head(match_on={'log_pos_gated': 'q_left', 'v_is_num': 'c1'},
                 provide=['val'], out_regs=['left_val']))
graph.fetch(Head(match_on={'log_pos_gated': 'q_right', 'v_is_num': 'c1'},
                 provide=['val'], out_regs=['right_val']))
graph.fetch(Head(match_on={'log_pos_gated': 'q_op', 'v_is_op': 'c1'},
                 provide=[ProjectOffset('val', TokenID.PLUS.value, 2)], out_regs=['op_val']))
@graph.compute(inputs=['left_val', 'right_val'], outputs=['next_emit_plus'])
def execute_plus(state, token_num, num_tokens):
    l_val = state.read_idx('left_val')
    r_val = state.read_idx('right_val')
    res = l_val + r_val
    # Only write if the result fits in the numeric token range
    if res < TokenID.PLUS.value:
        state.write_idx('next_emit_plus', res)
@graph.compute(inputs=['left_val', 'right_val', 'next_emit_plus'], outputs=['next_emit_mult'])
def execute_mult(state, token_num, num_tokens):
    l_val = state.read_idx('left_val')
    r_val = state.read_idx('right_val')
    res = l_val * r_val
    # Only write if the result fits in the numeric token range
    if res < TokenID.PLUS.value:
        state.write_idx('next_emit_mult', res)
Finally we decide what token to emit! Either a previously determined pointer, or an operator calculation result:
@graph.compute(inputs=['next_emit1', 'next_emit_plus', 'next_emit_mult', 'op_val'],
               outputs=['next_emit'])
def route_emit(state):
    nxt = state.read_idx('next_emit1')
    if nxt != -1:
        state.write_idx('next_emit', nxt)
        return
    op = state.read_idx('op_val') + TokenID.PLUS.value
    if op == TokenID.PLUS.value:
        state.write_idx('next_emit', state.read_idx('next_emit_plus'))
    elif op == TokenID.MULTIPLY.value:
        state.write_idx('next_emit', state.read_idx('next_emit_mult'))
Well, that actually does it. Our program now outputs either a pointer token, a calculation result, or ends the sequence, as appropriate, dereferencing pointers as it goes. The logic above is of course quite complicated to follow, and believe me, it wasn’t obvious to write. This leads me to believe that a higher-level language might be appropriate (more on that later), but the purpose here is at least to show that it’s possible to explicitly construct a Transformer network that can interpret a program.
Distillation
As mentioned, we can now “train” this network, i.e., distill the compute functions into actual MLP networks and generate the \(W_Q, W_K, W_V\) attention weights.
python src/train.py prog2_rpn_crumbs crumbs.pt
I say “train” in quotes because it’s not the same as end-to-end training. We are precisely, exactly copying the logic of the Python functions, and doing this layer-by-layer. Each layer takes around half an hour on my measly 4GB 3050 laptop GPU, so the whole 8-layer network trains in about 5 hours.
By specifying this .pt file when running the program, the trained MLP weights are loaded
and executed with PyTorch instead of the Python functions given above:
python src/test.py prog2_rpn_crumbs --weights crumbs.pt
Now in addition to the above output, you will see:
[*] Injecting continuous PyTorch Neural Networks crumbs.pt into MLP layer definitions...
[*] Using 1 hidden layers of size 3528.
[*] Replaced logic for Layer 'Layer_0' with Neural Network from crumbs.pt
[*] Using 1 hidden layers of size 3528.
[*] Replaced logic for Layer 'Layer_1' with Neural Network from crumbs.pt
[*] Using 1 hidden layers of size 3528.
[*] Replaced logic for Layer 'Layer_2' with Neural Network from crumbs.pt
[*] Using 1 hidden layers of size 3528.
[*] Replaced logic for Layer 'Layer_3' with Neural Network from crumbs.pt
[*] Using 1 hidden layers of size 3528.
[*] Replaced logic for Layer 'Layer_4' with Neural Network from crumbs.pt
[*] Using 1 hidden layers of size 3528.
[*] Replaced logic for Layer 'Layer_5' with Neural Network from crumbs.pt
[*] Using 1 hidden layers of size 3528.
[*] Replaced logic for Layer 'Layer_6' with Neural Network from crumbs.pt
[*] Using 1 hidden layers of size 3528.
[*] Replaced logic for Layer 'Layer_7' with Neural Network from crumbs.pt
Conclusions
What I find most confusing about programming this way is thinking about these operations as happening “all the time” to “all tokens” at every step. In that way it’s a lot more like writing a hardware description language like VHDL than it is like imperative, stateful programming. One has to constantly think, “okay, what happens if the current token is X and to the left I have a Y and a Z?” while simultaneously considering how any changes will also affect what Y and Z are. This is a bit circular and kind of hard to keep in your head all at once.
I find it also quite funny how inefficient it is. For a simple RPN interpreter, this must
win some kind of award. The crumbs.pt file clocks in at 1.1G. Of course, it could
probably be made a lot smaller by not including unused registers, but considering that all
attention heads and MLPs are executed on every token really makes me think about the
difference in efficiency between an immutable, append-only tape model like an
autoregressive Transformer and a stateful CPU+RAM computer program. Perhaps there is a
lesson in there somewhere.
There are plenty of limitations here, of course. We’re only handling single-token operators and operands, and our embedding strictly limits the length because we are using one dimension per position. I believe several of the computes and heads I defined could be optimized away with some more work, if the token embeddings were improved (the E matrix in the code).
I think there’s room for that but I actually think some of it could be done automatically, more on that below.
What’s next
There’s lots of possibilities from here, so I’ll just make some notes for the future.
Some things to finish
Before getting into questions about theory that this brings up, I’ll just mention a couple of obvious missing pieces that would be nice to address.
Firstly there is Layer Normalization. In a Transformer this is typically placed either on the input to each block (i.e., applied over \(X\) on input to attention and MLP) or over the output after summing with \(\Delta\). I avoided it here because it would have complicated my encodings: one-hot encodings are very sparse, and if everything is amplified to make the variance equal to 1.0, then either the thresholds have to be adjusted or things need to be scaled back down to compensate.
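To see the problem, normalize a one-hot vector: LayerNorm inflates the hot dimension to roughly \(\sqrt{D}\) (a quick check, not code from the repo):

import numpy as np

D = 512
x = np.zeros(D)
x[3] = 1.0                      # a sparse one-hot register value
y = (x - x.mean()) / x.std()    # LayerNorm without affine parameters
print(y[3])                     # ~22.6, i.e. sqrt(D - 1)

We can provide this compensation to layernorm via the affine parameters \(\gamma\) and \(\beta\), but at best this would let us set it to an “average”, and I’m not sure what to do if it needs to be dynamically set. So I need to experiment and figure out what to do about it. In any case, the only reason to add it would be to be compatible with a fully standard Transformer… which brings me to…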
ggufs when?
If the layernorm problem could be figured out, then I could probably actually generate weights compatible with some standard model like Llama and publish a model to HuggingFace. Although the model would be a pretty useless, super limited RPN calculator, it would be a cool and satisfying way to round off this project.
Tokenization is also a blocker, because we are not using BPE here like most models, so it
would also require developing some kind of custom simplified tokenizer just for demo
purposes. I absolutely did not want to deal with multi-digit logic here (e.g. if “40” were
broken up into two tokens “4” and “0”.) So there are definitely some things to figure out
before I could upload a transformers-compatible model.
Would be nice though.
Higher level language design, and MLP construction
So far I’ve managed to develop this basic level of ‘compiler’ that takes a few attention operations and non-linear computations, builds a dependency graph of them, and figures out how to arrange the embedding dimensions and lay out the operations into Transformer layers in a non-conflicting way.
But I definitely still find this a bit too low-level. In particular, I find it hard to
reason about what a single MLP can handle in those @compute functions, and experimenting
by distilling for an hour between changes was really inefficient. So one thing I’d like to
try is to move towards a constructive approach to the MLPs. Since we know exactly what
operations they should be implementing, is it possible to design the MLP weights directly,
just like I am doing for the attention weights? At least for a ReLU activation, I am
thinking something should be possible along these lines.
But furthermore, I realized that some things can be moved around between attention and MLPs. For instance the addition operator could theoretically be taken care of directly in attention, whereas the multiply operator could not. So I am wondering if the compiler can actually figure this kind of thing out.
If we understand better how to build certain logical operations into MLPs or attention, and know which operations can be implemented in which case (e.g. require an activation function or require aggregation over the sequence, or are just a linear operator like addition), then we can dynamically make decisions about where to put things and what representations they should use.
That would lead towards some kind of “language” that we can use to express programs for Transformers. For example, just riffing with Gemini on this we came up with something like this possible embedded DSL:
# ---------------------------------------------------------
# 1. Declare Global Token Variables (The "Registers")
# ---------------------------------------------------------
is_op = Var('is_op', node_type='Boolean')
is_num = Var('is_num', node_type='Boolean')
val = Var('val', node_type='OneHot')
pos = Var('pos', node_type='Scalar')
is_active = Var('is_active', node_type='Boolean')  # Handles annihilation/consumption

PLUS_VAL = 50      # Example Token IDs
MULT_VAL = 51
PTR_OFFSET = 100

# ---------------------------------------------------------
# 2. Sequence Operations: Find Targets
# ---------------------------------------------------------
# Find the first active operator (anywhere)
first_op_ptr = SequenceFind(
    condition=(is_op & is_active),
    direction='any',
    tiebreak='first'
)

# Find the closest active numbers to the left of the found operator
# Note: In our DSL, `first_op_ptr.pos` implicitly fetches the pos of the found op.
closest_num_1_ptr = SequenceFind(
    condition=(is_num & is_active & (pos < first_op_ptr.pos)),
    direction='left',
    tiebreak='closest'
)
closest_num_2_ptr = SequenceFind(
    condition=(is_num & is_active & (pos < closest_num_1_ptr.pos)),
    direction='left',
    tiebreak='closest'
)

# ---------------------------------------------------------
# 3. Read (Fetch) Values into Local Context
# ---------------------------------------------------------
op_val = Read(first_op_ptr, val)
l_val = Read(closest_num_2_ptr, val)
r_val = Read(closest_num_1_ptr, val)

# ---------------------------------------------------------
# 4. Compute Math Operations
# ---------------------------------------------------------
sum_res = l_val + r_val
mult_res = l_val * r_val

# ---------------------------------------------------------
# 5. Routing / Output Selection
# ---------------------------------------------------------
# We want to emit a trace of what we found, or the math result if we are done tracing.
# Let's assume some boolean flags dictating the state: `search_first_op`
search_first_op = Var('search_first_op', node_type='Boolean')
search_num_tgt = Var('search_num_tgt', node_type='Boolean')

output = Select(
    cases=[
        # If we are looking for the op, emit a pointer to it
        (search_first_op, Read(first_op_ptr, pos) + PTR_OFFSET),
        # If we are looking for the number, emit a pointer to it
        (search_num_tgt, Read(closest_num_1_ptr, pos) + PTR_OFFSET),
        # Otherwise, emit the result of the math!
        (op_val == PLUS_VAL, sum_res),
        (op_val == MULT_VAL, mult_res),
    ],
    default=0  # End of calculation
)

# Terminate AST
program = Emit(output)
I have to admit this looks a lot more compact and easier to deal with. Supposedly this is equivalent to the program I previously described, but I haven’t worked through it completely to verify. No mention of “heads” or “compute”, just operations that we need, expressed more or less semantically, and it would be up to the compiler to figure out how to build the Transformer that implements this, assigning operations to attention or MLP layers as necessary. I like that it’s declarative instead of imperative, feels more appropriate to the task.
Research questions
All fun and games, but what is the point? As I mentioned this is a super inefficient RPN calculator. And it was a lot of work to get this far. Making it more complicated and more capable would have to be justified.
This means research, and unfortunately I probably don’t have the time or resources to dig into these topics myself. But I wish I could, so I’ll jot them down for posterity.
Three research questions come to mind:
- Can programs be superimposed? — while the above method for constructing programs works, the one-hot representations I’m using are, in general, super sparse compared to the dimensionality. Which, as I said, leads to a massively inefficient runtime for something that can only do “one thing”. It feels like a lot more could be packed in there. One way to do that would be to “rotate” everything. If you perform a PCA over the weights, I bet you could compress things down quite drastically. Uncompressing it would involve rotating activations, then rotating back and projecting. Then, I wonder if you could sort of overlay multiple programs by choosing different orthogonal rotations for them, and select between them by choosing which rotation and projection you do to extract it.
- Programs as initialization — in any case, what’s the point of this? It’s not like a Transformer is a “good” runtime environment for something like an interpreter, so the feedback on this idea saying that it just isn’t useful compared to executing external tools has a point. But, what we are doing here is imbuing a certain knowledge of arithmetic into the Transformer itself. This is something you normally have to train over many thousands of examples to teach it. Well, a Transformer that only does arithmetic is not so useful, but what if we defined many more programs? Numerical solvers, Prolog, graph algorithms.. and then used the resulting weights as initialization for end-to-end pretraining? Would the resulting network learn faster, or generalize better? Hard to say. I can see all sorts of roadblocks here, such as the scaling of \(Q\) that might lead to bad gradients, but overall I think the idea has some merit to be explored.
- Programs in an MoE configuration: if we can backpropagate through these program models, but leave them frozen, could we select between them? For instance, could we just drop this RPN interpreter into a branch of an existing mixture-of-experts model and teach the router that it gets the answer of these particular expressions right with high certainty? Is there any benefit to doing this? Maybe it would have some advantage over depending on tool or skill descriptions to know when to make use of a “specialist”.
- Combining the above 3 ideas could be a powerful initialization method for MoE models.
Overall I think these are some interesting aspects to explore, but likely I just don’t have the resources to get into them. Sadness. But maybe a small LM could test some ideas here without requiring massive GPU.
Anyways, this post got long, (looks back, very long indeed!) but I am considering it documentation and a project report as well as a blog post. I found it pretty hard to completely describe this topic in a succinct way, or know when to stop developing and just write something, so I decided to do so once I had a fully working RPN interpreter, but I am sure there is a lot that I could still improve. It seems that I even promised to write a follow-up post talking about the determinism aspects, which I’ll work on in due time. I’m also curious to explore the possible development of that DSL, which feels a lot more approachable to me right now than the research questions. If you got this far, I hope you found it compelling! I’ll be watching forums like HN and reddit for any discussions on this and connected topics.
1. Percepta.ai, Can LLMs be computers?, 2026. ↩︎
2. Davis et al., Unifying Grokking and Double Descent, 2023. ↩︎
3. Huang et al., Unified View of Grokking, Double Descent and Emergent Abilities: A Perspective from Circuits Competition, 2024. ↩︎
4. DN Ng, LLM Neuroanatomy: How I Topped the LLM Leaderboard Without Changing a Single Weight, 2026. ↩︎
