Fusebox
Python · PyTorch · Triton · GPU Systems · 2026
Introduction
While GPUs are quick at performing arithmetic, they're glacially slow at moving data around. When a model runs many small operations in sequence, almost all the time is spent sending numbers to and from memory rather than computing. Fusebox automatically finds operations that can be fused into single GPU kernels, rewrites the computation graph, and recompiles the program with the fused implementations.
[Figure: Unfused vs. Fused Computation Graph. Unfused: relu → gelu → silu as separate kernels. Fused: a single kernel with intermediates in registers.]
I built a four-stage pipeline that takes a model and returns a faster, equivalent one. First, it traces the model with torch.fx to capture its computation as a graph of operations. Then it detects fusable patterns in that graph using two detectors: one that identifies chains of consecutive elementwise ops (like relu → gelu → silu), and one that identifies the Linear → activation pattern that dominates transformers. Next, it rewrites the graph by collapsing each detected group into a single node and validating that the result is still well formed. Finally, it substitutes a fused implementation for that node: a PyTorch fallback in general, or a Triton GPU kernel for specific op sequences. On a T4 GPU, the fused three-op activation chain ran 2.95× faster than the unfused version.
Background
Why are GPUs so slow?
Any operation has to first read its inputs from VRAM (the GPU's main memory), then write its results back when it's done. For simple elementwise operations like activations, that memory traffic, rather than the arithmetic itself, usually takes most of the time, so the compute cores spend much of it waiting on memory.
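A back-of-the-envelope roofline check makes this concrete. The sketch below computes the arithmetic intensity (FLOPs per byte moved) of an elementwise float32 op and compares it with the GPU's compute-to-bandwidth ratio. The T4 figures (about 8.1 TFLOPS FP32 and 320 GB/s memory bandwidth) are NVIDIA's published peak specs, used here as assumptions; achievable numbers in practice are lower.

```python
def arithmetic_intensity(flops_per_elem: float, bytes_per_elem: float) -> float:
    """FLOPs performed per byte moved between VRAM and the compute units."""
    return flops_per_elem / bytes_per_elem


def is_memory_bound(intensity: float, peak_flops: float, peak_bw: float) -> bool:
    # Roofline ridge point: below this intensity, memory bandwidth (not compute)
    # caps the achievable throughput.
    return intensity < peak_flops / peak_bw


# Assumed NVIDIA T4 peak figures (published specs, not measured here).
T4_FP32_FLOPS = 8.1e12  # FLOP/s
T4_BANDWIDTH = 320e9    # bytes/s

# relu on float32: ~1 FLOP per element, 4 bytes read + 4 bytes written.
relu_ai = arithmetic_intensity(flops_per_elem=1, bytes_per_elem=8)
ridge = T4_FP32_FLOPS / T4_BANDWIDTH

print(f"relu intensity: {relu_ai:.3f} FLOP/byte")
print(f"T4 ridge point: {ridge:.1f} FLOP/byte")
print("relu memory-bound on T4:", is_memory_bound(relu_ai, T4_FP32_FLOPS, T4_BANDWIDTH))
```

With these assumed numbers the ridge point is about 25 FLOP/byte, so a float32 elementwise op would need roughly 200 FLOPs per element (25 × 8 bytes) before arithmetic, not memory, became the limit. Activations like relu, gelu, and silu are far below that.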
In PyTorch's default eager mode, every operation runs as its own GPU kernel, and each kernel follows the same pattern: read from VRAM, compute, write to VRAM. So a sequence like relu → gelu → silu is three separate kernels and three full round-trips, with each intermediate result written out to memory and read straight back in by the next op.
[Figure: Per-Kernel VRAM Traffic, Unfused vs. Fused. In the fused version, relu, gelu, and silu are all computed in registers.]
The fix is kernel fusion, which replaces those three kernels with a single kernel that performs all three steps while the intermediate values stay in registers. The fused implementation reads the input once at the start and writes the result once at the end.
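A small accounting sketch shows the savings. It assumes unary float32 ops that each read their full input and write their full output when unfused; the tensor size is an arbitrary illustration, not a value from the benchmarks.

```python
def vram_bytes(n_elems: int, n_ops: int, fused: bool, bytes_per_elem: int = 4) -> int:
    """Bytes moved to/from VRAM by a chain of unary elementwise ops.

    Unfused: each op reads its full input and writes its full output.
    Fused: one read of the chain's input and one write of its final output;
    the intermediates stay in registers.
    """
    round_trips = 1 if fused else n_ops
    return round_trips * 2 * n_elems * bytes_per_elem


n = 1 << 24  # hypothetical size: 16M float32 elements (64 MiB per tensor)
unfused = vram_bytes(n, n_ops=3, fused=False)
fused = vram_bytes(n, n_ops=3, fused=True)
print(f"unfused: {unfused / 2**20:.0f} MiB, fused: {fused / 2**20:.0f} MiB")
print(f"traffic ratio: {unfused / fused:.1f}x")
```

For a k-op chain the traffic ratio is simply k, which is roughly the best speedup fusion can offer when runtime is dominated by memory traffic.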
The problem is that fusion is usually done by hand: someone with GPU expertise writes a custom fused kernel for one specific sequence of operations. General-purpose frameworks exist (torch.compile, TorchInductor, etc.), but those systems are black boxes. So, I built my own!
Two more pieces of background info:
- torch.fx captures a model's structure as a graph without actually running it on real data.
- Triton is a language for writing GPU kernels in a Python-like syntax instead of raw CUDA.
So, Fusebox generalizes and automates kernel fusion!
Architecture
Fusebox runs in four stages: tracing, pattern detection, graph rewriting, and fused kernel implementation.
Stage 1: Tracing (tracer.py)
Before optimizing a model, you need it in a form you can actually manipulate. So my first step was to capture the structure of a normal nn.Module as a graph, which is what the trace() function does! It calls torch.fx.symbolic_trace(model), which runs the model's forward() once with symbolic (fake) inputs and records every operation as a node in a graph.
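To see what a traced graph looks like, here is a minimal sketch using torch.fx directly on a toy activation chain. ActChain is a hypothetical model for illustration; the loop prints each node's opcode and target name.

```python
import torch.nn as nn
import torch.nn.functional as F
import torch.fx as fx


class ActChain(nn.Module):
    """Toy model: three elementwise activations in a row."""

    def forward(self, x):
        return F.silu(F.gelu(F.relu(x)))


# forward() is run with Proxy objects, not real tensors; each call becomes a node.
gm = fx.symbolic_trace(ActChain())
for node in gm.graph.nodes:
    # node.op is one of: placeholder, call_function, call_method,
    # call_module, get_attr, output
    print(f"{node.op:14} {getattr(node.target, '__name__', node.target)}")
```

The printout shows a placeholder for x, three call_function nodes (relu, gelu, silu), and an output node. The returned GraphModule is still an ordinary callable module, so it can be run and compared against the original.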
[Figure: torch.fx Traced Computation Graph. relu(x) → gelu(relu_out) → silu(gelu_out), with each operation captured as a manipulable node.]
The one caveat is that symbolic tracing fails on models with dynamic control flow: a forward() with if statements that branch on the input's values. Symbolic tracing pushes placeholder tensors through the model to record operations, but an if needs a concrete boolean to decide which branch to take, and a placeholder can't provide one. The try/except block in trace() catches tracing failures and falls back to a plain Tracer().trace(), building the GraphModule manually. (For genuinely dynamic models like BERT, even this fails. That case needs a model-specific tracer, which is why the BERT benchmark uses HuggingFace's own.) Finally, the helper get_call_nodes() filters out the boilerplate nodes (placeholder, output, get_attr), leaving just the actual operations.
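The sketch below reproduces both points with plain torch.fx. Branchy shows the data-dependent if that symbolic tracing rejects (torch.fx raises TraceError when a Proxy is used as a bool), and call_nodes is a hypothetical stand-in for get_call_nodes(), written from the description above rather than from the project's code. The Tracer() fallback path is not shown.

```python
import torch.nn as nn
import torch.nn.functional as F
import torch.fx as fx
from torch.fx.proxy import TraceError


class Branchy(nn.Module):
    """Data-dependent branch: the condition depends on the input's values."""

    def forward(self, x):
        if x.sum() > 0:  # needs a real bool; a Proxy can't supply one
            return F.relu(x)
        return F.gelu(x)


try:
    fx.symbolic_trace(Branchy())
except TraceError as err:
    print("tracing failed:", err)

CALL_OPS = {"call_function", "call_method", "call_module"}


def call_nodes(gm: fx.GraphModule):
    """Hypothetical helper: keep only nodes that perform an operation."""
    return [n for n in gm.graph.nodes if n.op in CALL_OPS]


class Small(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 4)

    def forward(self, x):
        return F.relu(self.fc(x))


gm = fx.symbolic_trace(Small())
print([n.op for n in gm.graph.nodes])  # includes placeholder and output
print([n.op for n in call_nodes(gm)])  # just the operations
```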
Now that the model is captured as a graph, it's time to determine which operations can be fused for efficiency.
Stage 2: Pattern Detection (patterns.py)
So, this stage finds sequences of operations that are safe and efficient to fuse. There are two detectors that do so.
1. Pointwise Chains
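First, some background: a pointwise (elementwise) operation computes each output element from the corresponding input element(s) alone, and a pointwise chain is a run of such operations where the output of one serves as the input of the next. Because no element ever depends on its neighbors, the whole chain can be computed in a single pass over the data, which is what makes it efficiently fusable. The function detect_pointwise_chains() looks for runs of consecutive pointwise operations (e.g., relu → gelu → silu).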
But how does it know which operations are fusable?
I hardcoded the set of operations the program should consider fusable in POINTWISE_OPS, near the top of the file (activations, plus basic arithmetic like add/mul). This makes the detection itself quite simple. The program starts at a pointwise node, then keeps hopping to the next node as long as two conditions hold: the next node is also pointwise, and the current node feeds exactly one consumer.
Why the second check?
Suppose, for example, that relu's output goes both into gelu and into some other op downstream. Then relu can't be fused away, because something else still needs its result: without the check, relu and gelu would be fused together and the downstream op would be left without an input. The check is the inline consumers = [...]; if len(consumers) != 1: break inside the walk loop (there's also a standalone _single_consumer() helper). So every chain the detector returns is a private, linear run of ops whose intermediate results are used by nothing outside the chain.
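Here is a sketch of that walk on a real torch.fx graph. POINTWISE, op_name, and detect_chains are hypothetical names, not the project's code: the op set is deliberately restricted to unary activations (unlike POINTWISE_OPS, which also includes add/mul) so every fused chain has a single input, and op_name ignores module calls for brevity.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.fx as fx

# Assumed op set: unary activations only, so every fused chain has one input.
POINTWISE = {"relu", "gelu", "silu", "sigmoid", "tanh"}


def op_name(node: fx.Node):
    """Simplified name lookup: functions and methods only (modules ignored)."""
    if node.op == "call_function":
        return getattr(node.target, "__name__", None)
    if node.op == "call_method":
        return node.target
    return None


def detect_chains(gm: fx.GraphModule, min_len: int = 2):
    chains, claimed = [], set()
    for node in gm.graph.nodes:  # fx keeps nodes in topological order
        if node in claimed or op_name(node) not in POINTWISE:
            continue
        chain = [node]
        # Extend only while the current node has exactly one consumer and that
        # consumer is also pointwise; a second consumer would need the value.
        while len(chain[-1].users) == 1:
            nxt = next(iter(chain[-1].users))
            if op_name(nxt) not in POINTWISE:
                break
            chain.append(nxt)
        claimed.update(chain)
        if len(chain) >= min_len:
            chains.append(chain)
    return chains


class Model(nn.Module):
    def forward(self, x):
        a = F.silu(F.gelu(F.relu(x)))    # private chain: fusable
        b = torch.tanh(x)                # b has two users below...
        return a + torch.sigmoid(b) + b  # ...so tanh -> sigmoid is not fused


gm = fx.symbolic_trace(Model())
print([[op_name(n) for n in chain] for chain in detect_chains(gm)])
```

Only the relu → gelu → silu run is reported; tanh is skipped because its output has two consumers, exactly the case the single-consumer check exists for.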
2. Linear Activations
The other detector is detect_linear_activation(), and it's the one that matters for real models. It looks for an nn.Linear immediately followed by an activation, which is the central pattern in transformer MLP blocks.
[Figure: Linear → Activation Pattern in Transformer MLPs. Linear → Activation (GELU / ReLU) pairs repeat throughout transformer MLPs, and each pair is a fusion target.]
This detector applies the same single-consumer check explained above. And just as the pointwise detector checks POINTWISE_OPS, this one checks the activation against SUPPORTED_ACTS, imported from the kernels module, so it only matches activations that Fusebox actually supports.
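A minimal sketch of this detector on torch.fx follows. ACTS is an assumed stand-in for SUPPORTED_ACTS, and act_name and detect_linear_act are hypothetical helpers; module activations are identified by their lowercased class name, which is a simplification.

```python
import torch.nn as nn
import torch.fx as fx

# Assumed stand-in for SUPPORTED_ACTS.
ACTS = {"relu", "gelu", "silu"}


def act_name(gm: fx.GraphModule, node: fx.Node):
    if node.op == "call_function":  # e.g. F.gelu(h)
        return getattr(node.target, "__name__", None)
    if node.op == "call_module":  # e.g. self.act(h) with self.act = nn.GELU()
        return type(gm.get_submodule(node.target)).__name__.lower()
    return None


def detect_linear_act(gm: fx.GraphModule):
    pairs = []
    for node in gm.graph.nodes:
        if node.op != "call_module":
            continue
        if not isinstance(gm.get_submodule(node.target), nn.Linear):
            continue
        if len(node.users) != 1:  # same single-consumer rule as the chain detector
            continue
        act = next(iter(node.users))
        if act_name(gm, act) in ACTS:
            pairs.append((node, act))
    return pairs


class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(8, 32)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(32, 8)

    def forward(self, x):
        return self.fc2(self.act(self.fc1(x)))


gm = fx.symbolic_trace(MLP())
print([(lin.target, act.target) for lin, act in detect_linear_act(gm)])
```

Only fc1 is matched: fc2 feeds the graph output rather than an activation, so it is left alone.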
Both detectors return FusableGroup objects. Each is a small dataclass holding the matched nodes, a kind string (which tells the compiler what type of pattern was detected), and the tuple of op names (e.g. ("relu", "gelu")) that is used later for the kernel lookup. In detect_all(), which runs both detectors, I deliberately run linear_activation first, since those are the higher-value targets.
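The sketch below is a hypothetical reconstruction of that data structure and the ordering, based only on the description above; the field names are guesses. The overlap check (a node claimed by an earlier detector is never handed to a later one) is an assumption added for the sketch, to show why running the higher-value detector first matters. Plain strings stand in for graph nodes so it runs without a model.

```python
from dataclasses import dataclass
from typing import Any, Callable, List, Tuple


@dataclass
class FusableGroup:
    nodes: List[Any]      # matched graph nodes, in execution order
    kind: str             # e.g. "linear_activation" or "pointwise" (assumed labels)
    ops: Tuple[str, ...]  # e.g. ("relu", "gelu"); used as the registry key


def detect_all(graph: Any, detectors: List[Callable]) -> List[FusableGroup]:
    """Run detectors in priority order; skip groups overlapping an earlier one.

    The overlap rule is an assumption of this sketch.
    """
    groups, claimed = [], set()
    for detect in detectors:
        for group in detect(graph):
            if claimed.isdisjoint(group.nodes):
                groups.append(group)
                claimed.update(group.nodes)
    return groups


# Toy detectors that both want the node "gelu_1".
def linear_act(graph):
    return [FusableGroup(["fc1", "gelu_1"], "linear_activation", ("linear", "gelu"))]


def pointwise(graph):
    return [FusableGroup(["gelu_1", "silu_1"], "pointwise", ("gelu", "silu"))]


groups = detect_all(None, [linear_act, pointwise])
print([g.kind for g in groups])
```

Because the Linear → activation detector runs first, it claims gelu_1, and the overlapping pointwise group is dropped instead of producing a conflicting rewrite.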
I also wrote a _resolved_op_name() helper which normalizes all three call styles (function call, method call, and module) and strips the trailing underscore off in-place variants like relu_ so detection and the registry always agree on what to call something.
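The normalization can be sketched on torch.fx as follows. resolved_op_name is a hypothetical helper modeled on that description, not the project's code; mapping modules through their lowercased class name works for nn.ReLU, nn.GELU, and nn.SiLU but is a simplification (nn.LeakyReLU would become "leakyrelu", not "leaky_relu").

```python
import torch.nn as nn
import torch.nn.functional as F
import torch.fx as fx


def resolved_op_name(gm: fx.GraphModule, node: fx.Node):
    """Normalize function, method, and module calls to one lowercase op name."""
    if node.op == "call_function":  # e.g. F.gelu(x)
        name = getattr(node.target, "__name__", str(node.target))
    elif node.op == "call_method":  # e.g. x.relu_()
        name = node.target
    elif node.op == "call_module":  # e.g. self.act(x) with self.act = nn.SiLU()
        name = type(gm.get_submodule(node.target)).__name__.lower()
    else:  # placeholder, get_attr, output
        return None
    # Strip the in-place suffix so relu_ and relu map to the same registry key.
    return name[:-1] if name.endswith("_") else name


class Mixed(nn.Module):
    def __init__(self):
        super().__init__()
        self.act = nn.SiLU()

    def forward(self, x):
        x = F.gelu(x)       # function call   -> call_function
        x = x.relu_()       # in-place method -> call_method "relu_"
        return self.act(x)  # module call     -> call_module "act"


gm = fx.symbolic_trace(Mixed())
print([resolved_op_name(gm, n) for n in gm.graph.nodes if n.op.startswith("call_")])
```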
Stage 3: Graph Rewriting (compiler.py)
Now, it's time to edit the symbolic graph. For each FusableGroup, the _rewrite_group() function chooses one of the two rewrite strategies depending on the kind (pointwise or linear activation).
1. _rewrite_pointwise()
In this stage, each pointwise chain is collapsed into a single function call: relu → gelu → silu becomes fused_relu_gelu_silu(x).
The program does this by looking up a fused kernel for the chain's list of operations in the registry. If no kernel is registered, it returns False and leaves the chain unmodified. If a kernel does exist, the program 1) inserts one new call_function node pointing at that kernel, 2) wires the chain's input into it, 3) redirects everything that used the last node's output to use the new node instead (last.replace_all_uses_with(fused_node)), and 4) erases the original nodes in reverse order. The order matters because torch.fx won't erase a node that still has users, so each consumer has to be deleted before the node it consumes.
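Here is a minimal sketch of those four steps with the torch.fx Graph API. rewrite_chain and fused_relu_gelu_silu are hypothetical names; the "fused" function is a Python-tier stand-in that still runs three kernels, since the point here is the graph surgery, not the speedup. The registry lookup is omitted.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.fx as fx


def fused_relu_gelu_silu(x):
    # Hypothetical Python-tier stand-in; a real kernel would do this in one pass.
    return F.silu(F.gelu(F.relu(x)))


def rewrite_chain(gm: fx.GraphModule, chain, fused_fn) -> None:
    first, last = chain[0], chain[-1]
    graph = gm.graph
    # 1-2) one new call_function node, fed by the chain's original input
    with graph.inserting_after(last):
        fused = graph.call_function(fused_fn, args=(first.args[0],))
    # 3) everything that consumed the chain's output now consumes the fused node
    last.replace_all_uses_with(fused)
    # 4) erase consumers before producers: fx refuses to erase a node with users
    for node in reversed(chain):
        graph.erase_node(node)
    graph.lint()
    gm.recompile()


class Model(nn.Module):
    def forward(self, x):
        return F.silu(F.gelu(F.relu(x))) * 2


gm = fx.symbolic_trace(Model())
chain = [n for n in gm.graph.nodes
         if n.op == "call_function"
         and getattr(n.target, "__name__", "") in {"relu", "gelu", "silu"}]
x = torch.randn(16)
before = gm(x)
rewrite_chain(gm, chain, fused_relu_gelu_silu)
print(gm.code)  # forward() now makes one fused call, then the multiply
```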
2. _rewrite_linear_activation()
The Linear + activation rewrite works differently from the pointwise rewrite. The pointwise ops are all free functions, operations with no internal state, but a Linear layer owns weights. So instead of inserting a function call, this path 1) builds a FusedLinearAct module that wraps the original Linear, 2) registers it on the GraphModule with gm.add_submodule(), 3) swaps in a single call_module node, 4) redirects uses, 5) erases the two old nodes, and 6) finally deletes the original Linear submodule.
[Figure: Graph Rewrite, Linear + Activation → FusedLinearAct. The call_module Linear and call_function gelu nodes are replaced by a single call_module FusedLinearAct node, which wraps the original Linear's weights so no parameters are lost.]
After that six-step process, two more lines finish the stage. First, gm.graph.lint() validates that the rewritten graph is well formed (no dangling references, no use before definition), and then gm.recompile() regenerates the actual Python forward() method from the modified graph.
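The six steps plus lint() and recompile() can be sketched with torch.fx as below. To keep the example short, nn.Sequential(linear, activation) stands in for FusedLinearAct, and rewrite_linear_act and the "_act" naming scheme are hypothetical. The key detail is that the wrapper holds the same Linear object, so parameters are neither copied nor lost.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.fx as fx


class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(16, 64)
        self.fc2 = nn.Linear(64, 16)

    def forward(self, x):
        return self.fc2(F.gelu(self.fc1(x)))


def rewrite_linear_act(gm, lin_node, act_node, act_module) -> None:
    lin_name = lin_node.target
    linear = gm.get_submodule(lin_name)
    fused_name = f"{lin_name}_act"  # hypothetical naming scheme
    # 1-2) wrap the *same* Linear object (no weight copy) and register it
    gm.add_submodule(fused_name, nn.Sequential(linear, act_module))
    # 3) one call_module node that takes the Linear's original input
    with gm.graph.inserting_after(act_node):
        fused = gm.graph.call_module(fused_name, args=lin_node.args)
    # 4) redirect users of the activation's output
    act_node.replace_all_uses_with(fused)
    # 5) erase the old nodes, consumer first
    gm.graph.erase_node(act_node)
    gm.graph.erase_node(lin_node)
    # 6) drop the old attribute; the Linear lives on inside the wrapper
    gm.delete_submodule(lin_name)
    gm.graph.lint()  # structural sanity check of the rewritten graph
    gm.recompile()   # regenerate forward() from the graph


torch.manual_seed(0)
gm = fx.symbolic_trace(MLP())
x = torch.randn(4, 16)
before = gm(x)
n_params = sum(p.numel() for p in gm.parameters())
lin = next(n for n in gm.graph.nodes if n.op == "call_module" and n.target == "fc1")
rewrite_linear_act(gm, lin, next(iter(lin.users)), nn.GELU())
print(gm.code)
```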
Stage 4: Fused Kernels (kernels/)
Stages 1–3 decided what to fuse and rewrote the graph to call the fused implementations. This final stage is the implementations themselves.
First, kernels/registry.py is the lookup table that maps tuples of operation names to callables, and lookup() is what the compiler calls. I built in a safe fallback here: if a sequence of operations isn't in the registry, lookup() returns None and the compiler leaves that chain alone. That way, the worst case is just an unfused model rather than a broken one.
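A registry with this contract can be very small. The sketch below is hypothetical: register and _REGISTRY are invented names, lookup() mirrors the behavior described above but its signature is a guess, and the "kernel" is a scalar stand-in so the example runs without PyTorch.

```python
import math
from typing import Callable, Dict, Optional, Tuple

OpKey = Tuple[str, ...]
_REGISTRY: Dict[OpKey, Callable] = {}  # hypothetical: op-name tuple -> fused callable


def register(*ops: str):
    """Decorator that registers a fused implementation for an exact op sequence."""
    def wrap(fn: Callable) -> Callable:
        _REGISTRY[tuple(ops)] = fn
        return fn
    return wrap


def lookup(ops: OpKey) -> Optional[Callable]:
    # None means "no fused implementation": the caller leaves the chain unfused.
    return _REGISTRY.get(tuple(ops))


@register("relu", "gelu")
def fused_relu_gelu(x: float) -> float:
    # Scalar stand-in for a kernel: relu result kept in a local, then exact gelu.
    r = max(x, 0.0)
    return 0.5 * r * (1.0 + math.erf(r / math.sqrt(2.0)))


for key in [("relu", "gelu"), ("gelu", "tanh")]:
    fn = lookup(key)
    print(key, "->", fn.__name__ if fn else "None (leave unfused)")
```

Keying on the exact tuple keeps lookups unambiguous: ("relu", "gelu") and ("gelu", "relu") are different computations and would need separate entries.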
1. Python fallback (registry.py · kernels/linear_act.py)
The first tier is a set of pure PyTorch implementations that are always available. One example is FusedLinearAct. It shares the original Linear's weight and bias tensors, and its forward() is simply self._act(F.linear(...)). Unfortunately, this still runs the matmul and the activation as separate kernels with the intermediate written to VRAM, so it neither reduces memory traffic nor makes the program run faster.
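A module matching that description might look like the sketch below. LinearActFallback is a hypothetical name and this is not the project's FusedLinearAct; it re-registers the original Linear's Parameter objects (no copy) and defaults to F.gelu as the activation, which is an assumption.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class LinearActFallback(nn.Module):
    """Python-tier Linear + activation: correct, but two kernels and a VRAM intermediate."""

    def __init__(self, linear: nn.Linear, act=F.gelu):
        super().__init__()
        # Re-register the *same* Parameter objects: no copy, no new memory.
        self.weight = linear.weight
        self.bias = linear.bias  # may be None if the Linear has no bias
        self._act = act

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # F.linear writes its full output to memory; the activation then reads
        # it back. Nothing is actually fused here.
        return self._act(F.linear(x, self.weight, self.bias))


lin = nn.Linear(8, 4)
fused = LinearActFallback(lin)
x = torch.randn(2, 8)
print(fused.weight is lin.weight, torch.allclose(fused(x), F.gelu(lin(x))))
```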
2. Triton kernel (kernels/pointwise.py · benchmarks/triton_linear_gelu.py)
The second tier is actual GPU kernels that perform the fusion. pointwise.py has _relu_gelu_kernel, a Triton kernel that loads the input once, computes relu and gelu while the values sit in registers, and writes once. This tier is the only one that actually delivers a speedup, because it's the only tier where the intermediate never touches VRAM.
So memory traffic drops, and runtime with it, only when the fused node is backed by a kernel that executes the entire sequence in one pass, skipping the intermediate writes. The Python tier only guarantees that the rewritten graph is correct and well structured; the Triton tier is what reduces memory traffic.
Build Process
First, I sketched a rough skeleton of the code and set up the environment, then worked on the fusion mechanism for a pointwise chain. Once the environment was behaving properly, I got my first real end-to-end path working: the program detected a pointwise chain, rewrote the graph, and ran it, passing the correctness check. The initial benchmark, on a CPU, was 1.05×, which was basically negligible but also expected. The whole premise of fusion is eliminating memory round-trips, and on a CPU the intermediate result largely stays in cache, so there was little for fusion to win back.
Next, I wanted to use Fusebox on a real model, so I tried BERT from HuggingFace. Two things happened. First, torch.fx tracing failed: BERT's forward() has input-dependent if statements (dynamic control flow), which symbolic tracing can't represent, as described earlier. The fix was to swap in HuggingFace's own transformers.utils.fx.symbolic_trace, which is purpose-built to handle BERT's branching.
Second, once it traced, Fusebox found only one fusable group in all of BERT: a pair of consecutive add operations. That's because BERT's activations sit directly after Linear layers (the GELU in each feed-forward block is sandwiched between two dense layers), so they never form the long pure-pointwise chains that my first detector was looking for.
This is when I built the second detector, targeting the Linear → activation pattern that dominates transformers. When I tested it on a TransformerMLP (512 → 2048 → 512), it correctly detected and fused both Linear → activation pairs, so my next step was to see whether this actually sped things up.
I moved to Google Colab for a T4 GPU. This pass had its own round of environment issues, but they were all Colab/Triton-specific.
| Experiment | Speedup |
|---|---|
| Full model, Linear + activation fused | 1.00× |
| Single activation, Triton gelu vs PyTorch gelu | 1.00× |
| 3-op chain, relu → gelu → silu, Triton vs PyTorch | 2.95× |
That last row, in numbers: unfused relu → gelu → silu ran in 0.855 ms, and the single fused Triton kernel ran in 0.290 ms, a 2.95× speedup on the T4. Three full VRAM round-trips collapsed into one.
Another note: my first instinct for the Linear → activation case was to write a Triton kernel that did the matmul and the activation in one pass. That kernel ran at 0.40×, i.e., 2.5 times as slow as just calling PyTorch's F.linear. The reason is that PyTorch's Linear calls into NVIDIA's cuBLAS, which is hardware-tuned, can use tensor cores, and has years of optimization behind it. My handwritten matmul couldn't compete. So I learned to let cuBLAS do the matmul and fuse only the cheap elementwise ops around it. That's why the benchmark ended up split the way it is, and why the real gains are in the activation chain rather than the Linear fusion.
Results
| Experiment | Hardware | Speedup |
|---|---|---|
| Pointwise chain, relu → gelu | CPU | 1.05× |
| Linear + activation fused | T4 GPU | 1.00× |
| Single activation, gelu | T4 GPU | 1.00× |
| 3-op chain, relu → gelu → silu | T4 GPU | 2.95× |
The 3-op activation chain is the result the whole project was built to produce. The unfused version was running relu, then gelu, then silu as three separate PyTorch kernels. That process took 0.855 ms. The single fused Triton kernel, doing all three with the intermediates held in registers, took 0.290 ms. That's a 2.95× speedup on an NVIDIA T4.
Each of those three activations, run separately, costs a full read from VRAM and a full write back, while the math in between is almost free. So three ops means six trips across the slow memory bus. The fused kernel does one read at the start and one write at the end, with both intermediate results living in on-chip registers: two of the three round-trips are simply deleted. Same output, a third of the memory traffic, roughly a third of the time. That's not a trick or an artifact; it's the entire mechanism of kernel fusion, isolated and measured.
The three that didn't
1. Linear + activation fusion: 1.00×
Even though the graph rewrite worked perfectly, FusedLinearAct is the Python-tier implementation: it still calls F.linear and then the activation in sequence, so the intermediate tensor still gets written to VRAM and read back.
2. Single activation: 1.00×
I wondered if I could beat PyTorch's gelu with a Triton gelu. I couldn't, because PyTorch's gelu already runs as a single elementwise CUDA kernel that reads its input once and writes its output once. With only one op, there's no intermediate round-trip to eliminate.
3. Pointwise chain on CPU: 1.05×
The same relu → gelu logic that wins on GPU does basically nothing on CPU, because the CPU largely keeps the intermediate in its cache and there's no expensive VRAM bus to avoid. Fusion only pays off where memory traffic is the bottleneck, which here means the GPU.
Fusebox does little on CPU (no memory bottleneck), nothing for a single op (no round-trip to remove), and nothing if the fused implementation still writes the intermediate to memory anyway. Where it works is a chain of memory-bound operations on a GPU, backed by a kernel that keeps the intermediates on chip.