I remember PyTorch has some pytree capability, no? So is it safe to say that any pytree-compatible modules here are already compatible w/ PyTorch?
Author here! I didn't know PyTorch had its own pytree system. It looks like it's separate from JAX's pytree registry, though, so Penzai's tooling probably won't work with PyTorch out of the box.
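For the curious, here's roughly what that registry split looks like in practice. A minimal sketch, assuming a toy `Point` container (made up for illustration, not from Penzai): registering it with `jax.tree_util` only teaches JAX's registry about it, while PyTorch keeps its own.

```python
import jax

class Point:
    def __init__(self, x, y):
        self.x = x
        self.y = y

# Tell JAX how to decompose a Point into children + static metadata,
# and how to rebuild it. This only updates JAX's registry.
jax.tree_util.register_pytree_node(
    Point,
    lambda p: ((p.x, p.y), None),           # flatten: (children, aux_data)
    lambda aux, children: Point(*children),  # unflatten
)

# JAX tree utilities now traverse Point...
print(jax.tree_util.tree_map(lambda v: v * 2, Point(1.0, 2.0)).x)  # 2.0

# ...but PyTorch maintains a separate registry, so the same class would
# need a second registration there (torch.utils._pytree in recent versions).
```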
I implemented Jax’s pytrees in pure Python. You can use it with whatever you want: https://github.com/shawwn/pytreez The readme is a todo, but the tests are complete. They’re the same ones that Jax itself uses, but with zero dependencies: https://github.com/shawwn/pytreez/blob/master/tests/test_pyt...

The concept is simple. The hard part is cross-pollination. Suppose you wanted to literally use Jax pytrees with PyTorch. Now you have to import Jax, or my library, and register your modules with it. But anything else that ever uses pytrees needs to use the same pytree library, because the registry (the thing that keeps track of pytree-compatible classes) lives in the library you choose. They don’t share registries.

A better way of phrasing it is that if you use a Jax-style pytree interface, it should work with any other pytree library. But to my knowledge, the only pytree library besides Jax itself is mine, and only I use it. So when you ask whether pytree-compatible modules are compatible with PyTorch, it’s equivalent to asking whether PyTorch projects use Jax, and the answer tends to be no.

EDIT: perhaps I’m outdated. OP says that PyTorch has pytree functionality now: https://news.ycombinator.com/item?id=40109662 I guess yet again I was ahead of the times by a couple of years; happy to see other ecosystems catch up. Hopefully seeing a simple implementation will clarify the tradeoffs.

The best approach for a universal pytree library would be to assume that any class with tree_flatten and tree_unflatten methods is pytreeable, and not require those classes to be explicitly registered (see the sketch below). That way you don’t have to worry about whether you’re using Jax or PyTorch pytrees.

But I gave up trying to make library-agnostic ML modules; in practice it’s better to just choose Jax or PyTorch and be done with it, since making PyTorch modules run in Jax automatically (and vice versa) is a fool’s errand (I was the fool, and it was an errand) for many reasons, not the least of which is that Jax builds an explicit computation graph via jax.jit, a feature PyTorch has only recently (and reluctantly) embraced. But of course, that means if you pick the wrong ecosystem, you’ll miss out on the best tools — hello React vs Vue, or Unreal Engine vs Unity, or dozens of other examples.
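To make the duck-typing idea concrete, here is a minimal registry-free sketch (illustrative only, not actual pytreez code): any object exposing `tree_flatten`/`tree_unflatten` in the Jax style is treated as an internal node, built-in containers are handled directly, and everything else is a leaf.

```python
def tree_map(fn, tree):
    """Registry-free tree_map: duck-types on tree_flatten/tree_unflatten."""
    if hasattr(tree, "tree_flatten") and hasattr(type(tree), "tree_unflatten"):
        # Jax-style protocol: flatten to (children, aux), rebuild via classmethod.
        children, aux = tree.tree_flatten()
        return type(tree).tree_unflatten(aux, [tree_map(fn, c) for c in children])
    if isinstance(tree, dict):
        return {k: tree_map(fn, v) for k, v in tree.items()}
    if isinstance(tree, (list, tuple)):
        # Simplified: namedtuples and other corner cases are omitted.
        return type(tree)(tree_map(fn, c) for c in tree)
    return fn(tree)  # anything else is a leaf

# Usage:
# tree_map(lambda x: x + 1, {"a": [1, 2], "b": 3})  ->  {"a": [2, 3], "b": 4}
```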
There are a couple more such libraries. One lived inside TensorFlow (nest) and was then extracted into the standalone dm-tree: https://github.com/deepmind/tree There is also optree: https://github.com/metaopt/optree

Ideally you would stick mostly to standard types (dict, list, tuple, etc.), which all of these libraries support in mostly the same way, so it's easy to switch. You do have to be careful about some small differences, though: which basic types are supported (e.g. dataclass, namedtuple, subclasses of dict or tuple), or how None is handled.
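The None case in particular is easy to trip over. A quick comparison, assuming dm-tree is installed (it imports as `tree`):

```python
import jax
import tree  # dm-tree

nested = {"a": [1, 2], "b": None}

# JAX treats None as an empty subtree by default: it vanishes from the leaves.
print(jax.tree_util.tree_leaves(nested))  # [1, 2]

# dm-tree treats None as an ordinary leaf.
print(tree.flatten(nested))               # [1, 2, None]

# optree makes the choice explicit via its none_is_leaf flag.
```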
Isn't JAX the most widely used framework in the GenAI space? Most companies there use it -- Cohere, Anthropic, CharacterAI, xAI, Midjourney, etc.
Most of the GenAI players use both PyTorch and JAX, depending on the hardware they're running on. Character, Anthro, Midjourney, etc. are dual shops (they use both). xAI only uses JAX, AFAIK.
MindSpore's meteoric rise started well before the chip embargo. I've looked into it: it liberally borrows ideas from other frameworks, both PyTorch and Jax, and adds some of its own. You lose some of the conceptual purity, but it makes up for it in practical usability, assuming it works as it says on the tin, which it may or may not.

PyTorch also has Ascend support as far as I can tell (https://github.com/Ascend/pytorch), so hardware support alone does not explain MindSpore's relative success. Why MindSpore is rising so rapidly is not entirely clear to me. It could be something as simple as preferring a domestic alternative that is adequate to the task and has better documentation in Chinese. It could be the cost of compute. It could be both. Nowadays, though, I do agree that the various embargoes help it (as well as Huawei) a great deal.

As a side note, I wish Huawei could export its silicon to the West. I bet that would result in dramatically cheaper compute.
I use it all the time, and there are also a few classes at my uni that use Jax. It's really great for experimentation and research; you can do a lot of things in Jax that you just can't in, say, PyTorch.
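The canonical example of that flexibility is composing transforms, e.g. per-example gradients by nesting `grad` inside `vmap` and `jit` (a standard JAX pattern; the tiny loss function here is just for illustration):

```python
import jax
import jax.numpy as jnp

def loss(w, x, y):
    # Squared error for a single example.
    return (jnp.dot(x, w) - y) ** 2

# Gradient w.r.t. w for one example, vectorized over a batch,
# then compiled -- three composable transforms in one line.
per_example_grads = jax.jit(jax.vmap(jax.grad(loss), in_axes=(None, 0, 0)))

w = jnp.ones(3)
xs = jnp.ones((8, 3))
ys = jnp.zeros(8)
print(per_example_grads(w, xs, ys).shape)  # (8, 3)
```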
There's too much fragmentation within the JAX NN library space, and Penzai isn't helping with that. I wish everyone using JAX could agree on a single set of libraries for NN, optimization, and data loading.
PyTorch code can't be called from JAX, which means a lot of reimplementation is needed when extending and iterating on prior work, and that's the common case in research. Custom CUDA kernels are a bit fiddly too; I haven't been able to bring Gaussian Splatting to JAX yet.