diff options
Diffstat (limited to 'pkgs/development/python-modules/jax-cuda12-pjrt/default.nix')
| -rw-r--r-- | pkgs/development/python-modules/jax-cuda12-pjrt/default.nix | 43 |
1 files changed, 22 insertions, 21 deletions
diff --git a/pkgs/development/python-modules/jax-cuda12-pjrt/default.nix b/pkgs/development/python-modules/jax-cuda12-pjrt/default.nix index 14c59787f06a..160158d0ed39 100644 --- a/pkgs/development/python-modules/jax-cuda12-pjrt/default.nix +++ b/pkgs/development/python-modules/jax-cuda12-pjrt/default.nix @@ -2,7 +2,7 @@ lib, stdenv, buildPythonPackage, - fetchurl, + fetchPypi, addDriverRunpath, autoPatchelfHook, pypaInstallHook, @@ -31,30 +31,31 @@ let ] ); - # Find new releases at https://storage.googleapis.com/jax-releases - # When upgrading, you can get these hashes from jaxlib/prefetch.sh. See - # https://github.com/google/jax/issues/12879 as to why this specific URL is the correct index. - - # upstream does not distribute jax-cuda12-pjrt binaries for aarch64-linux - srcs = { - "x86_64-linux" = fetchurl { - url = "https://storage.googleapis.com/jax-releases/cuda12_plugin/jax_cuda12_pjrt-${version}-py3-none-manylinux2014_x86_64.whl"; - hash = "sha256-xTeDBlaLoMgbIwp3ndMZTJ3RAzmrY2CugJKBCNN+f3U="; - }; - # "aarch64-linux" = fetchurl { - # url = "https://storage.googleapis.com/jax-releases/cuda12_plugin/jax_cuda12_pjrt-${version}-py3-none-manylinux2014_aarch64.whl"; - # hash = "sha256-AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA="; - # }; - }; in -buildPythonPackage { +buildPythonPackage rec { pname = "jax-cuda12-pjrt"; inherit version; pyproject = false; - src = - srcs.${stdenv.hostPlatform.system} - or (throw "jax-cuda12-pjrt: No src for ${stdenv.hostPlatform.system}"); + src = fetchPypi { + pname = "jax_cuda12_pjrt"; + inherit version; + format = "wheel"; + python = "py3"; + dist = "py3"; + platform = + { + x86_64-linux = "manylinux2014_x86_64"; + aarch64-linux = "manylinux2014_aarch64"; + } + .${stdenv.hostPlatform.system}; + hash = + { + x86_64-linux = "sha256-aDcb2cE1JEuJZjA5viCCVWmKdb7JhU1BnqPD+VfKRkY= "; + aarch64-linux = "sha256-m/67BqOWFMtomfdzDqhWHxEVasgcuz7GiEpir7OxX/M="; + } + .${stdenv.hostPlatform.system}; + }; nativeBuildInputs = [ autoPatchelfHook @@ -97,7 +98,7 @@ buildPythonPackage { sourceProvenance = [ lib.sourceTypes.binaryNativeCode ]; license = lib.licenses.asl20; maintainers = with lib.maintainers; [ natsukium ]; - platforms = lib.attrNames srcs; + platforms = lib.platforms.linux; # see CUDA compatibility matrix # https://jax.readthedocs.io/en/latest/installation.html#pip-installation-nvidia-gpu-cuda-installed-locally-harder broken = |
