summaryrefslogtreecommitdiff
path: root/pkgs/development/python-modules/dm-haiku/default.nix
diff options
context:
space:
mode:
Diffstat (limited to 'pkgs/development/python-modules/dm-haiku/default.nix')
-rw-r--r--pkgs/development/python-modules/dm-haiku/default.nix11
1 files changed, 11 insertions, 0 deletions
diff --git a/pkgs/development/python-modules/dm-haiku/default.nix b/pkgs/development/python-modules/dm-haiku/default.nix
index 5ea929f8570e..0a818b0b8f36 100644
--- a/pkgs/development/python-modules/dm-haiku/default.nix
+++ b/pkgs/development/python-modules/dm-haiku/default.nix
@@ -58,6 +58,17 @@ let
})
];
+ # AttributeError: jax.core.Var was removed in JAX v0.6.0. Use jax.extend.core.Var instead, and
+ # see https://docs.jax.dev/en/latest/jax.extend.html for details.
+ # Alrady on master: https://github.com/google-deepmind/dm-haiku/commit/cfe8480d253a93100bf5e2d24c40435a95399c96
+ # TODO: remove at the next release
+ postPatch = ''
+ substituteInPlace haiku/_src/jaxpr_info.py \
+ --replace-fail "jax.core.JaxprEqn" "jax.extend.core.JaxprEqn" \
+ --replace-fail "jax.core.Var" "jax.extend.core.Var" \
+ --replace-fail "jax.core.Jaxpr" "jax.extend.core.Jaxpr"
+ '';
+
build-system = [ setuptools ];
dependencies = [