I was recently in a situation where I wanted some PyTorch code to run a lot faster. Specifically, I had a function (let’s call it f) that I wanted to optimize. I was using torch’s autograd to determine the gradient of this function, and optimizing over some parameters with the BFGS optimizer (using SciPy’s minimize). This was all working fine – but it was super slow. The function was relatively simple, so autograd introduced a lot of overhead. What do I do?
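For concreteness, here’s a rough sketch of that kind of setup. The objective below is a made-up quadratic standing in for the real f, and value_and_grad is a hypothetical helper name. It turns SciPy’s NumPy vector into a tensor, runs autograd, and returns the loss and gradient together, which is what minimize expects when you pass jac=True.

```python
import numpy as np
import torch
from scipy.optimize import minimize

# Hypothetical stand-in for the real objective f: a simple quadratic.
TARGET = torch.tensor([1.0, -2.0, 3.0], dtype=torch.float64)

def f(params: torch.Tensor) -> torch.Tensor:
    return ((params - TARGET) ** 2).sum()

def value_and_grad(x: np.ndarray) -> tuple[float, np.ndarray]:
    # Each call builds a fresh autograd graph; when f itself is cheap,
    # this per-call bookkeeping can be a large share of the runtime.
    params = torch.tensor(x, dtype=torch.float64, requires_grad=True)
    loss = f(params)
    (grad,) = torch.autograd.grad(loss, params)
    return loss.item(), grad.numpy()

# jac=True tells minimize that the objective returns (value, gradient).
result = minimize(value_and_grad, x0=np.zeros(3), jac=True, method="BFGS")
print(result.x)
```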

Well, nowadays the correct thing to do is use the new torch.compile. In fact, it was the correct thing to do when I initially had this problem – it was already available in pre-release builds, just new and buggy. I should have done my darnedest to modify my code until torch.compile worked. I’m sure with some simplifications I could have gotten it working.

Unfortunately, this is not what I did. And now I’m paying for it.

It is well known that JAX is super fast – can I just rewrite the function to use JAX primitives? Use jax.grad and jax.jit, and bang! It should work wonderfully. Alas, this was not an option – I was using f as part of a PyTorch model I was training, so I would need to either rewrite the entire model and training code to use JAX (a massive task considering the size of my codebase) or keep two versions of f around – one using PyTorch, one using JAX. This seemed like a recipe for disaster, considering that I was tweaking f pretty frequently.
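For comparison, here’s a sketch of what the pure-JAX route would look like for the silly example function that appears later in this post. The body is rewritten with jnp primitives, and jax.value_and_grad plus jax.jit give a compiled function that returns the value and the gradient together. It is also exactly the second copy of f I didn’t want to maintain.

```python
import jax
import jax.numpy as jnp

def f_jax(x):
    """The silly example function, rewritten with jnp primitives."""
    a = jnp.sin(x)
    b = jnp.exp(x * 2)
    return (a + b).sum()

# value_and_grad returns (f(x), df/dx); jit compiles the combined function
# with XLA on the first call for each input shape and dtype.
f_and_grad = jax.jit(jax.value_and_grad(f_jax))

value, grad = f_and_grad(jnp.zeros((10, 10)))
print(value)       # 100.0
print(grad[0, 0])  # cos(0) + 2 * exp(0) = 3.0
```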

So here’s the idea – why not JAXify a PyTorch function? Something like this:

```python
import torch
import jax
import jax.numpy as jnp
from jorch import to_jax

def f(x):
    """ Silly example torch function """
    a = torch.sin(x)
    b = torch.exp(x*2)
    return (a + b).sum()

# Use with standard torch tensors
x_torch = torch.zeros((10, 10))
print("torch f:", f(x_torch))

# AND with JAX arrays. Whaaaa....
f_jax = to_jax(f)
x_jax = jnp.zeros((10, 10))
print("jax f:", f_jax(x_jax))
```

```
torch f: tensor(100.)
jax f: 100.0
```

Turns out this is totally possible! It’s a fun little exercise to figure out how to do this. But it’s also a horrible idea to use this in anything important, as I hope this post conveys.

So how does jorch work?

It turns out that PyTorch has some incredible extensibility features. Namely, if you call a torch function (say, torch.sin) with an object of a new class, PyTorch checks whether that class defines a __torch_function__ method. If it does, PyTorch delegates responsibility for computing the function to that method, passing along the original function and its arguments. It’s a fancy way of implementing multiple dispatch in a language that doesn’t natively support it.
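To see the protocol in isolation, here’s a toy class (Spy is a made-up name) that doesn’t subclass torch.Tensor at all. Its __torch_function__ ignores the math entirely and just reports what it intercepted, which shows that func is the original torch function and args holds the call’s positional arguments, whichever position the Spy object was in.

```python
import torch

class Spy:
    """A non-Tensor class that intercepts torch functions (illustrative only)."""

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        # Compute nothing; just report which torch function was called
        # and how many positional arguments it received.
        return f"{func.__name__} called with {len(args)} arg(s)"

print(torch.sin(Spy()))                 # sin called with 1 arg(s)
print(torch.add(torch.ones(2), Spy()))  # add called with 2 arg(s)
```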

Here’s a minimal version of what we want to do:

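```python
from functools import wraps

import jax.numpy as jnp
from torch.utils._pytree import tree_map

def jorch_unwrap(x):
    # JorchTensor -> the JAX array it holds; anything else passes through
    if isinstance(x, JorchTensor):
        return x.arr
    return x

def jorch_wrap(x):
    # JAX array -> JorchTensor; anything else passes through
    if isinstance(x, jnp.ndarray):
        return JorchTensor(x)
    return x

class JorchTensor:

    def __init__(self, arr):
        self.arr = arr

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:  # avoid a shared mutable default
            kwargs = {}
        args = tree_map(jorch_unwrap, args)
        kwargs = tree_map(jorch_unwrap, kwargs)
        # Find the jnp function with the same name, e.g. torch.sin -> jnp.sin
        new_func = getattr(jnp, func.__name__)
        out = new_func(*args, **kwargs)
        return tree_map(jorch_wrap, out)

def to_jax(f):
    @wraps(f)
    def wrapper(*args, **kwargs):
        # Wrap JAX inputs so torch functions dispatch to __torch_function__
        args = tree_map(jorch_wrap, args)
        kwargs = tree_map(jorch_wrap, kwargs)
        out = f(*args, **kwargs)
        # Hand plain JAX arrays back to the caller
        return tree_map(jorch_unwrap, out)
    return wrapper
```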

The basic idea is to create a wrapper class JorchTensor that pretends to be a PyTorch tensor but is secretly holding a JAX array (and is using only JAX primitives).

We first define the functions jorch_unwrap and jorch_wrap to convert between JorchTensors and regular JAX arrays. When the __torch_function__ method is called, we use tree_map to recursively call jorch_unwrap on all the arguments to convert them to JAX arrays. Then we determine the JAX version of the PyTorch function in question with getattr(jnp, func.__name__). This converts, e.g., torch.sin to jnp.sin. We call that function on the unwrapped arguments, wrap the result in a JorchTensor, and voilà!
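Why tree_map rather than a plain loop over args? Torch functions can receive tensors nested inside containers (torch.stack, for example, takes a list of tensors), and tree_map applies a function to every leaf while rebuilding the same lists, tuples, and dicts around the results. A quick illustration; note that torch.utils._pytree is a private PyTorch module, so its interface could change between releases.

```python
from torch.utils._pytree import tree_map

# A nested structure, like the arguments a torch function might receive.
nested = (1.0, [2.0, 3.0], {"scale": 4.0})

# tree_map applies the function to every leaf and rebuilds the
# same containers around the results.
doubled = tree_map(lambda leaf: leaf * 2, nested)
print(doubled)  # (2.0, [4.0, 6.0], {'scale': 8.0})
```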

And now the to_jax function is pretty simple – just wrap the arguments in JorchTensors, run the PyTorch function, and return the unwrapped result.

```python
x = jnp.ones((1,))
print(to_jax(torch.sin)(x))
```

```
[0.84147096]
```
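The nice part is that, as far as JAX is concerned, the wrapped function is now ordinary JAX code, so it should compose with jax.grad and jax.jit. Here’s a self-contained sketch that repeats the minimal classes from above (defaulting kwargs to None, the convention PyTorch’s extension docs use) and differentiates a made-up function g. It sticks to torch.* calls because this minimal JorchTensor has no operators yet.

```python
from functools import wraps

import jax
import jax.numpy as jnp
import torch
from torch.utils._pytree import tree_map

class JorchTensor:
    def __init__(self, arr):
        self.arr = arr

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = {} if kwargs is None else kwargs
        args = tree_map(jorch_unwrap, args)
        kwargs = tree_map(jorch_unwrap, kwargs)
        out = getattr(jnp, func.__name__)(*args, **kwargs)
        return tree_map(jorch_wrap, out)

def jorch_unwrap(x):
    return x.arr if isinstance(x, JorchTensor) else x

def jorch_wrap(x):
    # isinstance(x, jax.Array) is also True for tracers under jit/grad,
    # so traced values get wrapped exactly like concrete arrays.
    return JorchTensor(x) if isinstance(x, jax.Array) else x

def to_jax(f):
    @wraps(f)
    def wrapper(*args, **kwargs):
        out = f(*tree_map(jorch_wrap, args), **tree_map(jorch_wrap, kwargs))
        return tree_map(jorch_unwrap, out)
    return wrapper

def g(x):
    # Hypothetical test function using only torch.* calls.
    return torch.sum(torch.sin(x))

g_jax = to_jax(g)
dg = jax.jit(jax.grad(g_jax))  # gradient of a PyTorch-written function

x = jnp.ones((3,))
print(g_jax(x))  # 3 * sin(1), roughly 2.5244
print(dg(x))     # cos(1), roughly 0.5403, for each element
```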

Now there are still a lot of things to add to the JorchTensor class – especially operators. All the arithmetic operators are pretty simple. The only spicy one is __getattribute__, but the idea is pretty much the same as the __torch_function__ method – unwrap the arguments, find the relevant JAX method, call the method, and wrap the result. I’m not going to get into the details because I want to point out a big, glaring flaw in my strategy. Do you see it?
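For the curious, here’s roughly what that plumbing could look like, as a sketch rather than the actual jorch code. It uses __getattr__ instead of __getattribute__, since __getattr__ only fires when normal lookup fails, which keeps self.arr and the class’s own methods out of the way. _binary is a hypothetical helper. With it, the silly f from the top of the post runs unchanged, gradient included.

```python
from functools import wraps

import jax
import jax.numpy as jnp
import torch
from torch.utils._pytree import tree_map

def jorch_unwrap(x):
    return x.arr if isinstance(x, JorchTensor) else x

def jorch_wrap(x):
    return JorchTensor(x) if isinstance(x, jax.Array) else x

def _binary(jnp_op, reflected=False):
    """Hypothetical helper: build a dunder like __add__ (or __radd__)."""
    def method(self, other):
        a, b = self.arr, jorch_unwrap(other)
        return jorch_wrap(jnp_op(b, a) if reflected else jnp_op(a, b))
    return method

class JorchTensor:
    def __init__(self, arr):
        self.arr = arr

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = {} if kwargs is None else kwargs
        args, kwargs = tree_map(jorch_unwrap, (args, kwargs))
        return tree_map(jorch_wrap, getattr(jnp, func.__name__)(*args, **kwargs))

    # Python looks operators up on the type (never via __getattr__),
    # so each one has to be defined explicitly.
    __add__, __radd__ = _binary(jnp.add), _binary(jnp.add, reflected=True)
    __sub__, __rsub__ = _binary(jnp.subtract), _binary(jnp.subtract, reflected=True)
    __mul__, __rmul__ = _binary(jnp.multiply), _binary(jnp.multiply, reflected=True)
    __truediv__, __rtruediv__ = _binary(jnp.divide), _binary(jnp.divide, reflected=True)

    def __getattr__(self, name):
        # Only reached when normal lookup fails, e.g. x.sum(). Refuse "arr"
        # so a half-initialized object can't recurse forever.
        if name == "arr":
            raise AttributeError(name)
        attr = getattr(self.arr, name)
        if not callable(attr):
            return jorch_wrap(attr)  # e.g. x.shape
        def method(*args, **kwargs):
            args, kwargs = tree_map(jorch_unwrap, (args, kwargs))
            return tree_map(jorch_wrap, attr(*args, **kwargs))
        return method

def to_jax(fn):
    @wraps(fn)
    def wrapper(*args, **kwargs):
        args, kwargs = tree_map(jorch_wrap, (args, kwargs))
        return tree_map(jorch_unwrap, fn(*args, **kwargs))
    return wrapper

def f(x):
    """The silly example torch function, unchanged."""
    a = torch.sin(x)
    b = torch.exp(x*2)
    return (a + b).sum()

f_jax = to_jax(f)
x = jnp.zeros((10, 10))
print(f_jax(x))                  # 100.0
print(jax.grad(f_jax)(x)[0, 0])  # 3.0
```

Notice that every method call like .sum() is forwarded blindly to the JAX method of the same name. That’s convenient, but it quietly leans on the same assumption as getattr(jnp, func.__name__).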

That’s right! We can’t just naively assume that the jnp and torch APIs are the same. Sure, this will mostly work for basic functions like torch.sin and torch.exp, but what about torch.linalg.norm? JAX has an equivalent function, jnp.linalg.norm, but it lives in the jnp.linalg submodule, so getattr(jnp, func.__name__) won’t find it. You need to special-case it manually in the conversion process.
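One way to handle such cases is an explicit override table consulted before the name lookup, keyed by the torch function objects themselves (the same pattern PyTorch’s extension docs use with a HANDLED_FUNCTIONS dict). A sketch; JAX_OVERRIDES is a hypothetical name, and the table holds only the one entry discussed here.

```python
import jax
import jax.numpy as jnp
import torch
from torch.utils._pytree import tree_map

# Hypothetical override table: torch functions whose JAX counterpart is not
# found by getattr(jnp, func.__name__). Keys are the torch function objects.
JAX_OVERRIDES = {
    torch.linalg.norm: jnp.linalg.norm,
}

def jorch_unwrap(x):
    return x.arr if isinstance(x, JorchTensor) else x

def jorch_wrap(x):
    return JorchTensor(x) if isinstance(x, jax.Array) else x

class JorchTensor:
    def __init__(self, arr):
        self.arr = arr

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = {} if kwargs is None else kwargs
        args, kwargs = tree_map(jorch_unwrap, (args, kwargs))
        # Explicit overrides win; otherwise fall back to the same-name lookup.
        jax_func = JAX_OVERRIDES.get(func)
        if jax_func is None:
            jax_func = getattr(jnp, func.__name__)
        return tree_map(jorch_wrap, jax_func(*args, **kwargs))

v = JorchTensor(jnp.array([3.0, 4.0]))
print(jorch_unwrap(torch.linalg.norm(v)))                  # 5.0
print(jorch_unwrap(torch.sin(JorchTensor(jnp.zeros(2)))))  # [0. 0.]
```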

It gets even worse – even for basic functions like sum, JAX and PyTorch use slightly different argument names. Whereas JAX would use jnp.sum(x, axis=0, keepdims=True), PyTorch prefers torch.sum(x, dim=0, keepdim=True). In the final version of the jorch code, I wrote these truly cursed conditions:

```python
if "dim" in kwargs:
    kwargs["axis"] = kwargs["dim"]
    del kwargs["dim"]
if "keepdim" in kwargs:
    kwargs["keepdims"] = kwargs["keepdim"]
    del kwargs["keepdim"]
```
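If you must do this, a lookup table at least keeps the cursedness in one place. A sketch with hypothetical names that covers only the two renames above and returns a new dict instead of mutating kwargs; arguments that differ in meaning rather than just in name would still need hand-written handling.

```python
# Hypothetical table of PyTorch -> JAX keyword renames (only the two above).
TORCH_TO_JAX_KWARGS = {"dim": "axis", "keepdim": "keepdims"}

def rename_torch_kwargs(kwargs: dict) -> dict:
    """Return a copy of kwargs with PyTorch argument names mapped to JAX ones."""
    return {TORCH_TO_JAX_KWARGS.get(key, key): value for key, value in kwargs.items()}

print(rename_torch_kwargs({"dim": 0, "keepdim": True}))
# {'axis': 0, 'keepdims': True}
```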

In the end, I did manage to create a system that successfully JAXified my PyTorch function in question – and it is quite fast after compilation. However, the final code for doing all this ended up becoming an unmaintainable mess. I’m currently in the process of getting torch.compile to work instead.

There’s an important lesson here: it’s cool to write wack code to do wack things. But don’t depend on it for anything important!

The full jorch code I used in my project is here if you’re interested in playing around with it.