JAXifying PyTorch
I was recently in a situation where I wanted some PyTorch code to run a lot faster. Specifically, I had a function (let’s call it f) that I wanted to optimize. I was using torch’s autograd to determine the gradient of this function, and optimizing over some parameters with the BFGS optimizer (using SciPy’s minimize). This was all working fine – but it was super slow. The function was relatively simple, so autograd introduced a lot of overhead. What do I do?
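For context, the setup looked roughly like this (a sketch with a toy objective standing in for the real f, and names invented purely for illustration):

import numpy as np
import torch
from scipy.optimize import minimize

def f(theta):
    # Toy stand-in for the real objective.
    return (torch.sin(theta) + torch.exp(2 * theta)).sum()

def value_and_grad(theta_np):
    # SciPy hands us a float64 numpy vector; push it through autograd and
    # return (value, gradient) as plain Python/numpy values.
    theta = torch.tensor(theta_np, requires_grad=True)
    loss = f(theta)
    loss.backward()
    return loss.item(), theta.grad.numpy()

result = minimize(value_and_grad, x0=np.zeros(10), jac=True, method="BFGS")

Every call here goes through autograd’s dynamic graph machinery, which is exactly where the overhead comes from.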
Well, nowadays the correct thing to do is use the new torch.compile. In fact, it was the correct thing to do when I initially had this problem – it was around in the pre-release, just new and buggy. I should have done my darnedest to modify my code until torch compilation worked. I’m sure with some simplifications I could have got it working.
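In other words, the fix would (in principle) have been a one-liner, assuming f compiled cleanly. Continuing the sketch above:

# The torch.compile route: wrap f once and call the compiled version.
theta = torch.zeros(10, requires_grad=True)
f_compiled = torch.compile(f)   # compile once...
loss = f_compiled(theta)        # ...then call with the same signature
loss.backward()                 # autograd still works through the compiled function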
Unfortunately, this is not what I did. And now I’m paying for it.
It is well known that JAX is super fast – can I just rewrite the function to use JAX primitives? Use jax.grad and jax.jit, and bang! It should work wonderfully. Alas, this was not an option – I was using f as part of a PyTorch model I was training, so I would need to either rewrite the entire model and training code to use JAX (a massive task considering my codebase size) or keep two versions of f around – one using PyTorch, one using JAX. This seemed like a recipe for disaster, considering that I was tweaking f pretty frequently.
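For reference, the pure-JAX route I couldn’t take would have looked roughly like this (again a sketch, with a toy objective standing in for a JAX rewrite of f):

import numpy as np
import jax
import jax.numpy as jnp
from scipy.optimize import minimize

def f_jaxversion(theta):
    # Toy stand-in for a JAX rewrite of f.
    return jnp.sum(jnp.sin(theta) + jnp.exp(2 * theta))

# One compiled function that returns the value and the gradient together.
f_vg = jax.jit(jax.value_and_grad(f_jaxversion))

def value_and_grad_np(theta_np):
    value, grad = f_vg(theta_np)
    return float(value), np.asarray(grad, dtype=np.float64)

result = minimize(value_and_grad_np, x0=np.zeros(10), jac=True, method="BFGS")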
So here’s the idea – why not JAXify a PyTorch function? Something like this:
import torch
import jax
import jax.numpy as jnp
from jorch import to_jax

def f(x):
    """ Silly example torch function """
    a = torch.sin(x)
    b = torch.exp(x*2)
    return (a + b).sum()

# Use with standard torch tensors
x_torch = torch.zeros((10, 10))
print("torch f:", f(x_torch))

# AND with JAX arrays. Whaaaa....
f_jax = to_jax(f)
x_jax = jnp.zeros((10, 10))
print("jax f:", f_jax(x_jax))
torch f: tensor(100.)
jax f: 100.0
Turns out this is totally possible! It’s a fun little exercise to figure out how to do this. But it’s also a horrible idea to use this in anything important, as I hope this post conveys.
So how does jorch work?
It turns out that PyTorch has some incredible extensibility features. Namely, if you call a torch function (say, torch.sin) with a new class, it will check to see if that class defines a __torch_function__ method. If so, it will delegate responsibility for computing the function to that __torch_function__ method. It’s a fancy way of implementing multiple dispatch in a language that doesn’t natively support it.
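Here’s a tiny, self-contained illustration of the hook before we get to the jorch-specific version (Spy is just a throwaway name for this demo):

import torch

class Spy:
    # Not a tensor at all; it just reports which torch function was intercepted.
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        print("intercepted:", func.__name__)
        return "anything we like"

print(torch.sin(Spy()))
# intercepted: sin
# anything we like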
Here’s a minimal version of what we want to do:
from functools import wraps
from torch.utils._pytree import tree_map

def jorch_unwrap(x):
    # JorchTensor -> underlying JAX array; leave anything else untouched.
    if isinstance(x, JorchTensor):
        return x.arr
    return x

def jorch_wrap(x):
    # JAX array -> JorchTensor; leave anything else untouched.
    if isinstance(x, jnp.ndarray):
        return JorchTensor(x)
    return x

class JorchTensor:
    def __init__(self, arr):
        self.arr = arr

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        # Unwrap JorchTensors to JAX arrays, look up the JAX function with
        # the same name, call it, and wrap any resulting arrays back up.
        kwargs = kwargs or {}
        args = tree_map(jorch_unwrap, args)
        kwargs = tree_map(jorch_unwrap, kwargs)
        new_func = getattr(jnp, func.__name__)
        out = new_func(*args, **kwargs)
        return tree_map(jorch_wrap, out)

def to_jax(f):
    @wraps(f)
    def wrapper(*args, **kwargs):
        # Wrap the incoming JAX arrays, run the torch code, unwrap the outputs.
        args = tree_map(jorch_wrap, args)
        kwargs = tree_map(jorch_wrap, kwargs)
        out = f(*args, **kwargs)
        return tree_map(jorch_unwrap, out)
    return wrapper
The basic idea is to create a wrapper class JorchTensor that pretends to be a PyTorch tensor but is secretly holding a JAX array (and is using only JAX primitives). We first define the functions jorch_unwrap and jorch_wrap to convert between JorchTensors and regular JAX arrays. When the __torch_function__ method is called, we use tree_map to recursively call jorch_unwrap on all the arguments to convert them to JAX arrays. Then we determine the JAX version of the PyTorch function in question with getattr(jnp, func.__name__). This converts e.g. torch.sin to jnp.sin. We call that function on the arguments, wrap the result in a JorchTensor, and voilà!
And now the to_jax function is pretty simple – just wrap the arguments in JorchTensors, run the PyTorch function, and return the unwrapped result.
x = jnp.ones((1,))
print(to_jax(torch.sin)(x))
[0.84147096]
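And because everything inside is now real JAX, the converted function can go through JAX transformations, which was the whole point. (This relies on jorch_wrap also recognizing JAX tracers; in recent JAX versions tracers are instances of jnp.ndarray, a.k.a. jax.Array, so the isinstance check above covers them.)

# Differentiate and jit-compile straight through the torch-written function.
g = jax.grad(lambda x: to_jax(torch.sin)(x).sum())
print(g(jnp.ones((3,))))          # ~[0.5403 0.5403 0.5403], i.e. cos(1)

fast_sin = jax.jit(to_jax(torch.sin))
print(fast_sin(jnp.zeros((3,))))  # [0. 0. 0.]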
Now there are still a lot of things to add to the JorchTensor class – especially operators. All the arithmetic operators are pretty simple. The only spicy one is __getattribute__, but the idea is pretty much the same as the __torch_function__ method – unwrap the arguments, find the relevant JAX method, call the method, and wrap the result.
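To give a flavor, here is roughly what that plumbing could look like, building on the code above. This is a sketch only; it uses __getattr__ and made-up helper names rather than the real jorch code (which overrides __getattribute__):

def _jorch_call(fn, *args, **kwargs):
    # Shared helper: unwrap JorchTensors, call into JAX, re-wrap the result.
    args = tree_map(jorch_unwrap, args)
    kwargs = tree_map(jorch_unwrap, kwargs)
    return tree_map(jorch_wrap, fn(*args, **kwargs))

def _add(self, other):
    return _jorch_call(jnp.add, self, other)

def _mul(self, other):
    return _jorch_call(jnp.multiply, self, other)

def _getattr(self, name):
    # Forward e.g. x.sum(...) or x.shape to the wrapped JAX array.
    attr = getattr(self.arr, name)
    if callable(attr):
        return lambda *a, **kw: _jorch_call(attr, *a, **kw)
    return tree_map(jorch_wrap, attr)

JorchTensor.__add__ = JorchTensor.__radd__ = _add
JorchTensor.__mul__ = JorchTensor.__rmul__ = _mul
JorchTensor.__getattr__ = _getattr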
I’m not going to get into the rest of the details, because I want to point out a big, glaring flaw in my strategy. Do you see it?
That’s right! We can’t just naively assume that the jnp and torch APIs are the same. Sure, this will mostly work for basic functions like torch.sin and torch.exp, but what about torch.linalg.norm? JAX has an equivalent function, jnp.linalg.norm, but getattr(jnp, func.__name__) will never find it – you’ll need to handle that mapping manually in the conversion process.
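One way to handle this (a sketch; not necessarily how the real jorch code does it) is an explicit override table consulted before the getattr(jnp, ...) fallback:

# Hypothetical override table for torch functions whose JAX equivalents
# don't live at jnp.<same name>.
JAX_OVERRIDES = {
    torch.linalg.norm: jnp.linalg.norm,
    torch.clamp: jnp.clip,
}

def lookup_jax_fn(func):
    if func in JAX_OVERRIDES:
        return JAX_OVERRIDES[func]
    return getattr(jnp, func.__name__)

Inside __torch_function__, new_func = lookup_jax_fn(func) would then replace the bare getattr.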
It gets even worse – even for basic functions like sum, JAX and PyTorch have slightly different arguments. Whereas JAX would use jnp.sum(x, axis=0, keepdims=True), PyTorch prefers torch.sum(x, dim=0, keepdim=True). In the final version of the jorch code, I wrote these truly cursed conditions:
if "dim" in kwargs:
kwargs["axis"] = kwargs["dim"]
del kwargs["dim"]
if "keepdim" in kwargs:
kwargs["keepdims"] = kwargs["keepdim"]
del kwargs["keepdim"]
In the end, I did manage to create a system that successfully JAXified my PyTorch function in question – and it is quite fast after compilation. However, the final code for doing all this ended up becoming an unmaintainable mess. I’m currently in the process of getting torch.compile to work instead.
There’s an important lesson here: it’s cool to write wack code to do wack things. But don’t depend on it for anything important!
The full jorch code I used in my project is here if you’re interested in playing around with it.