summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorUri Baghin <uri@canva.com>2023-03-16 10:36:37 +1100
committerUri Baghin <uri@canva.com>2023-04-08 10:22:20 +1000
commitc734173bee8c35acc0cf6a4d5d17fe21b1e878be (patch)
tree0e0f500951b376a429637c0dacc2cf48a798e93d
parentbind: replace hard-coded `allow-query` zone setting with a real zone paramete... (diff)
downloadnixpkgs-origin/uri/jax.tar.gz
python3Packages.jaxlib-build: share fetch derivation between different build derivationsorigin/uri/jax
-rw-r--r--pkgs/development/python-modules/jaxlib/default.nix133
1 files changed, 74 insertions, 59 deletions
diff --git a/pkgs/development/python-modules/jaxlib/default.nix b/pkgs/development/python-modules/jaxlib/default.nix
index 95f276ec2451..a2a0493657bc 100644
--- a/pkgs/development/python-modules/jaxlib/default.nix
+++ b/pkgs/development/python-modules/jaxlib/default.nix
@@ -86,7 +86,50 @@ let
];
};
- bazel-build = buildBazelPackage {
+ # Copy-paste from TF derivation.
+ # Most of these are not really used in jaxlib compilation but it's simpler to keep it
+ # 'as is' so that it's more compatible with TF derivation.
+ tf_system_libs = [
+ "absl_py"
+ "astor_archive"
+ "astunparse_archive"
+ "boringssl"
+ # Not packaged in nixpkgs
+ # "com_github_googleapis_googleapis"
+ # "com_github_googlecloudplatform_google_cloud_cpp"
+ "com_github_grpc_grpc"
+ "com_google_protobuf"
+ # Fails with the error: external/org_tensorflow/tensorflow/core/profiler/utils/tf_op_utils.cc:46:49: error: no matching function for call to 're2::RE2::FullMatch(absl::lts_2020_02_25::string_view&, re2::RE2&)'
+ # "com_googlesource_code_re2"
+ "curl"
+ "cython"
+ "dill_archive"
+ "double_conversion"
+ "flatbuffers"
+ "functools32_archive"
+ "gast_archive"
+ "gif"
+ "hwloc"
+ "icu"
+ "jsoncpp_git"
+ "libjpeg_turbo"
+ "lmdb"
+ "nasm"
+ "opt_einsum_archive"
+ "org_sqlite"
+ "pasta"
+ "png"
+ "pybind11"
+ "six_archive"
+ "snappy"
+ "tblib_archive"
+ "termcolor_archive"
+ "typing_extensions_archive"
+ "wrapt"
+ "zlib"
+ ];
+
+ bazel-build = buildBazelPackage rec {
name = "bazel-build-${pname}-${version}";
bazel = bazel_5;
@@ -169,61 +212,10 @@ let
CFG
'';
- # Copy-paste from TF derivation.
- # Most of these are not really used in jaxlib compilation but it's simpler to keep it
- # 'as is' so that it's more compatible with TF derivation.
- TF_SYSTEM_LIBS = lib.concatStringsSep "," ([
- "absl_py"
- "astor_archive"
- "astunparse_archive"
- "boringssl"
- # Not packaged in nixpkgs
- # "com_github_googleapis_googleapis"
- # "com_github_googlecloudplatform_google_cloud_cpp"
- "com_github_grpc_grpc"
- "com_google_protobuf"
- # Fails with the error: external/org_tensorflow/tensorflow/core/profiler/utils/tf_op_utils.cc:46:49: error: no matching function for call to 're2::RE2::FullMatch(absl::lts_2020_02_25::string_view&, re2::RE2&)'
- # "com_googlesource_code_re2"
- "curl"
- "cython"
- "dill_archive"
- "double_conversion"
- "flatbuffers"
- "functools32_archive"
- "gast_archive"
- "gif"
- "hwloc"
- "icu"
- "jsoncpp_git"
- "libjpeg_turbo"
- "lmdb"
- "nasm"
- "opt_einsum_archive"
- "org_sqlite"
- "pasta"
- "png"
- "pybind11"
- "six_archive"
- "snappy"
- "tblib_archive"
- "termcolor_archive"
- "typing_extensions_archive"
- "wrapt"
- "zlib"
- ] ++ lib.optionals (!stdenv.isDarwin) [
- "nsync" # fails to build on darwin
- ]);
-
# Make sure Bazel knows about our configuration flags during fetching so that the
# relevant dependencies can be downloaded.
bazelFlags = [
"-c opt"
- ] ++ lib.optionals (stdenv.targetPlatform.isx86_64 && stdenv.targetPlatform.isUnix) [
- "--config=avx_posix"
- ] ++ lib.optionals cudaSupport [
- "--config=cuda"
- ] ++ lib.optionals mklSupport [
- "--config=mkl_open_source_only"
] ++ lib.optionals stdenv.cc.isClang [
# bazel depends on the compiler frontend automatically selecting these flags based on file
# extension but our clang doesn't.
@@ -231,21 +223,44 @@ let
"--cxxopt=-x" "--cxxopt=c++" "--host_cxxopt=-x" "--host_cxxopt=c++"
];
+ # We intentionally overfetch so we can share the fetch derivation across all the different configurations
fetchAttrs = {
+ TF_SYSTEM_LIBS = lib.concatStringsSep "," tf_system_libs;
+ # we have to force @mkl_dnn_v1 since it's not needed on darwin
+ bazelTargets = bazelTargets ++ [ "@mkl_dnn_v1//:mkl_dnn" ];
+ bazelFlags = bazelFlags ++ [
+ "--config=avx_posix"
+ ] ++ lib.optionals cudaSupport [
+ # ideally we'd add this unconditionally too, but it doesn't work on darwin
+ # we make this conditional on `cudaSupport` instead of the system, so that the hash for both
+ # the cuda and the non-cuda deps can be computed on linux, since a lot of contributors don't
+ # have access to darwin machines
+ "--config=cuda"
+ ] ++ [
+ "--config=mkl_open_source_only"
+ ];
+
sha256 =
if cudaSupport then
- "sha256-n8wo+hD9ZYO1SsJKgyJzUmjRlsz45WT6tt5ZLleGvGY="
- else {
- x86_64-linux = "sha256-A0A18kxgGNGHNQ67ZPUzh3Yq2LEcRV7CqR9EfP80NQk=";
- aarch64-linux = "sha256-mU2jzuDu89jVmaG/M5bA3jSd7n7lDi+h8sdhs1z8p1A=";
- x86_64-darwin = "sha256-9nNTpetvjyipD/l8vKlregl1j/OnZKAcOCoZQeRBvts=";
- aarch64-darwin = "sha256-FqYwI1YC5eqSv+DYj09DC5IaBfFDUCO97y+TFhGiWAA=";
- }.${stdenv.system} or (throw "unsupported system ${stdenv.system}");
+ "sha256-4yu4y4SwSQoeaOz9yojhvCRGSC6jp61ycVDIKyIK/l8="
+ else
+ "sha256-CyRfPfJc600M7VzR3/SQX/EAyeaXRJwDQWot5h2XnFU=";
};
buildAttrs = {
outputs = [ "out" ];
+ TF_SYSTEM_LIBS = lib.concatStringsSep "," (tf_system_libs ++ lib.optionals (!stdenv.isDarwin) [
+ "nsync" # fails to build on darwin
+ ]);
+
+ bazelFlags = bazelFlags ++ lib.optionals (stdenv.targetPlatform.isx86_64 && stdenv.targetPlatform.isUnix) [
+ "--config=avx_posix"
+ ] ++ lib.optionals cudaSupport [
+ "--config=cuda"
+ ] ++ lib.optionals mklSupport [
+ "--config=mkl_open_source_only"
+ ];
# Note: we cannot do most of this patching at `patch` phase as the deps are not available yet.
# 1) Fix pybind11 include paths.
# 2) Link protobuf from nixpkgs (through TF_SYSTEM_LIBS when using gcc) to prevent crashes on