From cf228eda6e9ecab1ca8eef58452c959c915cc126 Mon Sep 17 00:00:00 2001 From: Peder Bergebakken Sundt Date: Thu, 24 Apr 2025 19:38:13 +0200 Subject: [PATCH] overlays.withCudaOrRocm WIP --- overlays.nix | 69 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 69 insertions(+) diff --git a/overlays.nix b/overlays.nix index 825e3b9..1014210 100755 --- a/overlays.nix +++ b/overlays.nix @@ -148,6 +148,75 @@ let } ); + # TODO: this way of overriding does not work + # usage: nix-build ./dev.nix -A .withCuda + # usage: nix-build ./dev.nix -A .withRocm + overlays.withCudaOrRocm = final: prev: { + lib = prev.lib.extend ( + finalLib: prevLib: { + # makeOverridable :: (AttrSet -> a) -> AttrSet -> a + makeOverridable = + f: args: + let + fArgs = lib.functionArgs f; + # Creates a functor with the same arguments as f + mirrorArgs = lib.mirrorFunctionArgs f; + fWithCuda = mirrorArgs ( + origArgs: + f' ( + origArgs + // lib.optionalAttrs (fArgs ? cudaSupport) { + cudaSupport = true; + } + // lib.optionalAttrs (fArgs ? config) { + config = fArgs.config // { + cudaSupport = true; + }; + } + ) + ); + fWithRocm = mirrorArgs ( + origArgs: + f' ( + origArgs + // lib.optionalAttrs (fArgs ? rocmSupport) { + rocmSupport = true; + } + // lib.optionalAttrs (fArgs ? config) { + config = fArgs.config // { + rocmSupport = true; + }; + } + ) + ); + f' = mirrorArgs ( + origArgs: + let + result = f origArgs; + in + if lib.isAttrs result then + if result ? overrideAttrs then + result.overrideAttrs (old: { + passthru = { + withCuda = fWithCuda origArgs; + withRocm = fWithRocm origArgs; + } // old.passthru or { }; + }) + else + { + withCuda = fWithCuda origArgs; + withRocm = fWithRocm origArgs; + } + // result + else + result + ); + in + prevLib.makeOverridable f' args; + } + ); + }; + # very hacky, not guaranteed to work, but may save a lot of .whl rebuilds # usage: nix-build ./dev.nix -A python3Packages..twostage # usage: nix-build ./dev.nix -A python3Packages..twostage.first