nix-ai/flake.nix
2024-11-15 11:59:52 +01:00

174 lines
5.4 KiB
Nix

{
description = "train AI-models conveniently";
inputs = {
  nixpkgs.url = "github:NixOS/nixpkgs/nixos-unstable";
  # nixpkgs fork; the overlays below take selected packages (e.g. the
  # Python package stable-baselines3) from it.
  nixpkgs-overwrites.url = "github:DerDennisOP/nixpkgs/comp";
  flake-utils.url = "github:numtide/flake-utils";
  # Used to build the containers (`extra-container.lib.buildContainers`);
  # pinned to the same nixpkgs and flake-utils as this flake.
  extra-container.url = "github:erikarvstedt/extra-container";
  extra-container.inputs.nixpkgs.follows = "nixpkgs";
  extra-container.inputs.flake-utils.follows = "flake-utils";
};
outputs = { self, nixpkgs, nixpkgs-overwrites, flake-utils, extra-container, ... }:
let
lib = nixpkgs.lib;
# The single platform the outputs of this flake are built for.
system = "x86_64-linux";
# `system.stateVersion` for the systems in `nixosConfigurations`.
stateVersion = "24.05";
# Every overlay exported by this flake, ordered by attribute name
# (`attrValues` returns values in the same order as `attrNames`).
getOverlays = builtins.attrValues self.overlays;
# Instantiate nixpkgs for `system` with all overlays of this flake and the
# given nixpkgs `config`.
importNixpkgs = config: import nixpkgs {
  inherit system config;
  overlays = getOverlays;
};
# Config shared by both package sets below.
# NOTE(review): apart from `allowUnfree`/`cudaSupport`, these `enable*` keys do
# not look like standard nixpkgs config options; they presumably only matter
# where something reads `pkgs.config` explicitly — confirm they are still needed.
commonConfig = {
  allowUnfree = true;
  enableFfmpeg = true;
  enableTesseract = true;
  enableGtk2 = true;
  enableQt = true;
};
# Package set with CUDA (and cuDNN) enabled.
pkgs = importNixpkgs (commonConfig // {
  cudaSupport = true;
  enableCudnn = true;
});
# The same package set without the CUDA settings.
pkgsNoCuda = importNixpkgs commonConfig;
# Evaluate `module` with `lib.evalModules` and return the resulting
# `config.flake`.
#   module: the module to evaluate (e.g. ./mk-flake.nix).
#   attrs:  configuration values supplied by the caller.
# `lib`, `flake-utils`, `pkgs` and `pkgsNoCuda` are passed as module arguments.
# The arguments and `attrs` are separate modules so the module system merges
# them. A shallow `//` would drop the module arguments whenever `attrs`
# contains a `_module` attribute.
flake-overlay = module: attrs:
  (lib.evalModules {
    modules = [
      module
      { config._module.args = { inherit lib flake-utils pkgs pkgsNoCuda; }; }
      { config = attrs; }
    ];
  }).config.flake;
# --- Containers -------------------------------------------------------------
# `containerConfigs` maps container names to `containers.<name>` definitions
# read from ./containers/*.nix. The helpers below turn each definition into a
# standalone NixOS system or an extra-container build. Every path passes this
# flake's `nixosModules` through the same `specialArgs` helper.

# Add this flake's `nixosModules` to a container definition's `specialArgs`.
# Any `specialArgs` the definition already sets are kept. The name argument is
# unused; it exists so the function can be given to `mapAttrs`.
containerConfigsWithSpecialArgs = _containerName: containerConfig:
  containerConfig // {
    specialArgs = (containerConfig.specialArgs or { }) // { inherit (self) nixosModules; };
  };

# Standalone NixOS system (for `nixosConfigurations`) built from a container
# definition. It uses the container's `config` module, marks it as a
# container and pins `stateVersion`.
nixosConfigurationsFromContainerConfigs = containerName: containerConfig:
  let
    withArgs = containerConfigsWithSpecialArgs containerName containerConfig;
  in
  lib.nixosSystem {
    inherit system;
    inherit (withArgs) specialArgs;
    modules = [
      withArgs.config
      {
        boot.isContainer = true;
        system.stateVersion = stateVersion;
      }
    ];
  };

# Name/value pair `buildContainer_<name>` whose value is an extra-container
# build containing only this one container.
buildContainerFromContainerConfigs = containerName: containerConfig:
  lib.attrsets.nameValuePair "buildContainer_${containerName}" (extra-container.lib.buildContainers {
    inherit system nixpkgs;
    config.containers.${containerName} = containerConfigsWithSpecialArgs containerName containerConfig;
  });

# All container definitions, merged from the files below. On a name clash the
# later file wins (`mergeAttrsList` merges with `//`).
containerConfigs = lib.attrsets.mergeAttrsList (map (file: (import file { inherit lib; }).containers) [
  ./containers/buildMachine.nix
  ./containers/hydra.nix
]);

# One `buildContainer_<name>` package per container.
buildContainer_ = lib.attrsets.mapAttrs' buildContainerFromContainerConfigs containerConfigs;
in
{
# `lib.mkFlake attrs` evaluates ./mk-flake.nix through `flake-overlay`, with
# `attrs` as configuration, and returns the resulting `flake` attribute set.
lib = {
  mkFlake = flake-overlay ./mk-flake.nix;
};
# Overlays exported by this flake. `getOverlays` also applies all of them to
# `pkgs` and `pkgsNoCuda`.
#   wvls                                - exposes this flake's `wvls` package.
#   python3Overlays / python311Overlays - patch the matching Python package set.
overlays =
  let
    # Overrides for one Python package set (`pyPackages`, e.g.
    # "python3Packages") on `hostSystem`, in `packageOverrides` form.
    # stable-baselines3: https://github.com/NixOS/nixpkgs/pull/355954
    # NOTE(review): tensorflow and fairseq come from `nixpkgs.legacyPackages`,
    # presumably the default instantiation (no cudaSupport, no overlays), not
    # from `python-prev`. Inside the CUDA `pkgs` set they may therefore depend
    # on different builds (e.g. of torch) than the rest of that set. Confirm
    # this is intended.
    pythonPackageOverrides = hostSystem: pyPackages: python-final: python-prev:
      let
        patched = nixpkgs-overwrites.legacyPackages.${hostSystem}.${pyPackages};
        upstream = nixpkgs.legacyPackages.${hostSystem}.${pyPackages};
      in
      {
        stable-baselines3 = patched.stable-baselines3;
        # Provide `tensorflow-bin` under the `tensorflow` name.
        tensorflow = upstream.tensorflow-bin;
        # fairseq with its checks turned off.
        fairseq = upstream.fairseq.overrideAttrs (oldAttrs: {
          doCheck = false;
          preCheck = "";
          disabledTestPaths = [
            "examples/simultaneous_translation/tests/test_text_models.py"
            "tests/"
          ];
        });
      };

    # Overlay for interpreter `py`:
    #   - appends the overrides above to `<py>PackagesOverlays`;
    #   - rebuilds `<py>` with all of those overlays composed;
    #   - points `<py>Packages` at the rebuilt interpreter's package set.
    mkPythonOverlay = py: final: prev:
      let
        pyPackages = "${py}Packages";
        overlayList = "${py}PackagesOverlays";
      in
      {
        ${overlayList} = (prev.${overlayList} or [ ]) ++ [
          (pythonPackageOverrides final.system pyPackages)
        ];
        ${py} =
          let
            interpreter = prev.${py}.override {
              self = interpreter;
              packageOverrides = prev.lib.composeManyExtensions final.${overlayList};
            };
          in
          interpreter;
        ${pyPackages} = final.${py}.pkgs;
      };

    # Top-level packages to take unchanged from nixpkgs-overwrites (currently none).
    topLevelOverwrites = [ ];
  in
  {
    wvls = final: prev: { inherit (self.packages.${final.system}) wvls; };
  }
  // builtins.listToAttrs (map (py: lib.attrsets.nameValuePair "${py}Overlays" (mkPythonOverlay py)) [
    "python3"
    "python311"
  ])
  // lib.genAttrs topLevelOverwrites (name: final: prev: {
    ${name} = nixpkgs-overwrites.legacyPackages.${final.system}.${name};
  });
# One NixOS system per container definition.
nixosConfigurations = lib.mapAttrs nixosConfigurationsFromContainerConfigs containerConfigs;

# Packages:
#   - one `buildContainer_<name>` per container;
#   - `buildContainers`, all containers in one build (also the default);
#   - `wvls`, built from the non-CUDA package set.
packages.${system} =
  let
    allContainers = extra-container.lib.buildContainers {
      inherit system nixpkgs;
      config.containers = lib.mapAttrs containerConfigsWithSpecialArgs containerConfigs;
    };
  in
  buildContainer_ // {
    buildContainers = allContainers;
    default = allContainers;
    wvls = pkgsNoCuda.callPackage ./modules/wvls.nix { };
  };
# NixOS modules exported by this flake. The same set is handed to the
# containers and NixOS systems above as the `nixosModules` special argument.
nixosModules = {
  buildMachine = ./modules/buildMachine.nix;
  hydra = ./modules/hydra.nix;
  mkFlake = ./mk-flake.nix;
};
# Development shell built from the non-CUDA package set. It provides the
# `wvls` package plus a few Python libraries. Uses `system` rather than a
# hard-coded platform string so it stays in sync with `packages`.
devShells.${system}.default = pkgsNoCuda.mkShell {
  buildInputs = [
    pkgsNoCuda.stdenv.cc.cc.lib
    pkgsNoCuda.pam
  ] ++ (with pkgsNoCuda.python3Packages; [
    gitpython
    keyring
    requests
  ]);
  packages = [
    self.packages.${system}.wvls
  ];
  # NOTE(review): /usr/include is outside the Nix store and usually absent on
  # NixOS hosts; confirm this flag is still needed.
  EXTRA_CCFLAGS = "-I/usr/include";
};
# Hydra jobset: every package of this flake for `system`, each wrapped with
# `lib.hydraJob`.
hydraJobs = {
  build.${system} = lib.mapAttrs (_: lib.hydraJob) self.packages.${system};
};
};
}