summaryrefslogtreecommitdiff
path: root/pkgs/development/python-modules/jax-cuda12-pjrt/default.nix
blob: 160158d0ed397f0f69ee64d3c74677c8483e4a4c (about) (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
{
  # Generic nixpkgs / Python packaging helpers.
  lib,
  stdenv,
  python,
  buildPythonPackage,
  fetchPypi,

  # Hooks placed in nativeBuildInputs to unpack, install and patch the wheel.
  wheelUnpackHook,
  pypaInstallHook,
  autoPatchelfHook,

  # CUDA libraries and the driver link directory added to the plugin's rpath.
  cudaPackages,
  addDriverRunpath,

  # Supplies the version this package is pinned to.
  jaxlib,
}:
let
  # This package reuses jaxlib's version.
  version = jaxlib.version;
  cudaVersion = cudaPackages.cudaVersion;

  # Library search path covering every shared object that the plugin loads
  # at runtime. Each entry is resolved to its library output via lib.getLib.
  cudaLibPath = lib.makeLibraryPath (
    map lib.getLib [
      cudaPackages.libcublas # libcublas.so
      cudaPackages.cuda_cupti # libcupti.so
      cudaPackages.cuda_cudart # libcudart.so
      cudaPackages.cudnn # libcudnn.so
      cudaPackages.libcufft # libcufft.so
      cudaPackages.libcusolver # libcusolver.so
      cudaPackages.libcusparse # libcusparse.so
      cudaPackages.nccl # libnccl.so
      cudaPackages.libnvjitlink # libnvJitLink.so
      addDriverRunpath.driverLink # libcuda.so
    ]
  );

in
buildPythonPackage {
  pname = "jax-cuda12-pjrt";
  inherit version;
  # The source is a prebuilt wheel (see `format = "wheel"` below).
  pyproject = false;

  # Prebuilt, architecture-specific wheel from PyPI. The wheel platform tag
  # and its hash are both selected by the host system; any other system
  # fails evaluation with an explicit message.
  src = fetchPypi {
    pname = "jax_cuda12_pjrt";
    inherit version;
    format = "wheel";
    python = "py3";
    dist = "py3";
    platform =
      {
        x86_64-linux = "manylinux2014_x86_64";
        aarch64-linux = "manylinux2014_aarch64";
      }
      .${stdenv.hostPlatform.system}
        or (throw "jax-cuda12-pjrt: unsupported system: ${stdenv.hostPlatform.system}");
    hash =
      {
        x86_64-linux = "sha256-aDcb2cE1JEuJZjA5viCCVWmKdb7JhU1BnqPD+VfKRkY=";
        aarch64-linux = "sha256-m/67BqOWFMtomfdzDqhWHxEVasgcuz7GiEpir7OxX/M=";
      }
      .${stdenv.hostPlatform.system}
        or (throw "jax-cuda12-pjrt: unsupported system: ${stdenv.hostPlatform.system}");
  };

  nativeBuildInputs = [
    autoPatchelfHook
    pypaInstallHook
    wheelUnpackHook
  ];

  # jax-cuda12-pjrt looks for ptxas, nvlink and nvvm at runtime, eg when running `jax.random.PRNGKey(0)`.
  # Linking into $out is the least bad solution. See
  # * https://github.com/NixOS/nixpkgs/pull/164176#discussion_r828801621
  # * https://github.com/NixOS/nixpkgs/pull/288829#discussion_r1493852211
  # for more info.
  postInstall = ''
    mkdir -p $out/${python.sitePackages}/jax_plugins/nvidia/cuda_nvcc/bin
    ln -s ${lib.getExe' cudaPackages.cuda_nvcc "ptxas"} $out/${python.sitePackages}/jax_plugins/nvidia/cuda_nvcc/bin/ptxas
    ln -s ${lib.getExe' cudaPackages.cuda_nvcc "nvlink"} $out/${python.sitePackages}/jax_plugins/nvidia/cuda_nvcc/bin/nvlink
    ln -s ${cudaPackages.cuda_nvcc}/nvvm $out/${python.sitePackages}/jax_plugins/nvidia/cuda_nvcc/nvvm
  '';

  # jax-cuda12-pjrt contains shared libraries that open other shared libraries via dlopen
  # and these implicit dependencies are not recognized by ldd or
  # autoPatchelfHook. That means we need to sneak them into rpath. This step
  # must be done after autoPatchelfHook and the automatic stripping of
  # artifacts. autoPatchelfHook runs in postFixup and auto-stripping runs in the
  # fixupPhase, so the rpath is extended in preInstallCheck, which runs after
  # both.
  preInstallCheck = ''
    patchelf --add-rpath "${cudaLibPath}" $out/${python.sitePackages}/jax_plugins/xla_cuda12/xla_cuda_plugin.so
  '';

  # FIXME: there are no tests, but we need to run preInstallCheck above
  doCheck = true;

  pythonImportsCheck = [ "jax_plugins" ];

  # Expose the computed library path as an attribute of the derivation.
  inherit cudaLibPath;

  meta = {
    description = "JAX XLA PJRT Plugin for NVIDIA GPUs";
    homepage = "https://github.com/jax-ml/jax/tree/main/jax_plugins/cuda";
    sourceProvenance = [ lib.sourceTypes.binaryNativeCode ];
    license = lib.licenses.asl20;
    maintainers = with lib.maintainers; [ natsukium ];
    platforms = lib.platforms.linux;
    # see CUDA compatibility matrix
    # https://jax.readthedocs.io/en/latest/installation.html#pip-installation-nvidia-gpu-cuda-installed-locally-harder
    broken =
      !(lib.versionAtLeast cudaVersion "12.1") || !(lib.versionAtLeast cudaPackages.cudnn.version "9.1");
  };
}