summaryrefslogtreecommitdiff
path: root/pkgs/development/python-modules/jax-cuda12-pjrt/default.nix
diff options
context:
space:
mode:
Diffstat (limited to 'pkgs/development/python-modules/jax-cuda12-pjrt/default.nix')
-rw-r--r--pkgs/development/python-modules/jax-cuda12-pjrt/default.nix43
1 files changed, 22 insertions, 21 deletions
diff --git a/pkgs/development/python-modules/jax-cuda12-pjrt/default.nix b/pkgs/development/python-modules/jax-cuda12-pjrt/default.nix
index 14c59787f06a..160158d0ed39 100644
--- a/pkgs/development/python-modules/jax-cuda12-pjrt/default.nix
+++ b/pkgs/development/python-modules/jax-cuda12-pjrt/default.nix
@@ -2,7 +2,7 @@
lib,
stdenv,
buildPythonPackage,
- fetchurl,
+ fetchPypi,
addDriverRunpath,
autoPatchelfHook,
pypaInstallHook,
@@ -31,30 +31,31 @@ let
]
);
- # Find new releases at https://storage.googleapis.com/jax-releases
- # When upgrading, you can get these hashes from jaxlib/prefetch.sh. See
- # https://github.com/google/jax/issues/12879 as to why this specific URL is the correct index.
-
- # upstream does not distribute jax-cuda12-pjrt binaries for aarch64-linux
- srcs = {
- "x86_64-linux" = fetchurl {
- url = "https://storage.googleapis.com/jax-releases/cuda12_plugin/jax_cuda12_pjrt-${version}-py3-none-manylinux2014_x86_64.whl";
- hash = "sha256-xTeDBlaLoMgbIwp3ndMZTJ3RAzmrY2CugJKBCNN+f3U=";
- };
- # "aarch64-linux" = fetchurl {
- # url = "https://storage.googleapis.com/jax-releases/cuda12_plugin/jax_cuda12_pjrt-${version}-py3-none-manylinux2014_aarch64.whl";
- # hash = "sha256-AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=";
- # };
- };
in
-buildPythonPackage {
+buildPythonPackage rec {
pname = "jax-cuda12-pjrt";
inherit version;
pyproject = false;
- src =
- srcs.${stdenv.hostPlatform.system}
- or (throw "jax-cuda12-pjrt: No src for ${stdenv.hostPlatform.system}");
+ src = fetchPypi {
+ pname = "jax_cuda12_pjrt";
+ inherit version;
+ format = "wheel";
+ python = "py3";
+ dist = "py3";
+ platform =
+ {
+ x86_64-linux = "manylinux2014_x86_64";
+ aarch64-linux = "manylinux2014_aarch64";
+ }
+ .${stdenv.hostPlatform.system};
+ hash =
+ {
+ x86_64-linux = "sha256-aDcb2cE1JEuJZjA5viCCVWmKdb7JhU1BnqPD+VfKRkY= ";
+ aarch64-linux = "sha256-m/67BqOWFMtomfdzDqhWHxEVasgcuz7GiEpir7OxX/M=";
+ }
+ .${stdenv.hostPlatform.system};
+ };
nativeBuildInputs = [
autoPatchelfHook
@@ -97,7 +98,7 @@ buildPythonPackage {
sourceProvenance = [ lib.sourceTypes.binaryNativeCode ];
license = lib.licenses.asl20;
maintainers = with lib.maintainers; [ natsukium ];
- platforms = lib.attrNames srcs;
+ platforms = lib.platforms.linux;
# see CUDA compatibility matrix
# https://jax.readthedocs.io/en/latest/installation.html#pip-installation-nvidia-gpu-cuda-installed-locally-harder
broken =