nix-ai/scripts/wvls.py
2024-07-25 22:30:48 +02:00

230 lines
6.7 KiB
Python

import os
import sys
import git
import json
import signal
import base64
import keyring
import argparse
import requests
from getpass import getpass
from string import whitespace
# Two-stage argument parsing.
# Stage 1 only discovers which sub-command was requested; help handling is
# disabled (add_help=False) so that "-h" can be forwarded to the stage 2
# parser of the selected sub-command.
parser = argparse.ArgumentParser(prog='wvls', description='Wavelens CLI', add_help=False)
parser.add_argument('action', type=str, nargs='*', help='Possible values: train, dl')
parser.add_argument('-h', '--help', action='store_true', help='Show this help message and exit')
args = parser.parse_args()

# nargs='*' yields an empty list when no positional argument is given, so the
# first element must not be indexed unconditionally (plain `wvls` or `wvls -h`
# would otherwise die with an IndexError before reaching the help branches).
requested_action = args.action[0] if args.action else None

if requested_action == 'train':
    # Stage 2: re-parse the full command line with the train-specific parser.
    parser = argparse.ArgumentParser(prog='wvls', description='Wavelens CLI', add_help=True)
    parser.add_argument('train', type=str, nargs='?')
    parser.add_argument('trainfile', type=str, nargs='?', help='Add file to train')
    args = parser.parse_args()
elif requested_action == 'dl':
    # Stage 2: re-parse the full command line with the download-specific parser.
    parser = argparse.ArgumentParser(prog='wvls', description='Wavelens CLI', add_help=True)
    parser.add_argument('dl', type=str, nargs='?')
    parser.add_argument('model', type=str, nargs='?', help='Add model to download')
    args = parser.parse_args()
elif args.help:
    # Help was explicitly requested without a known sub-command: success.
    parser.print_help()
    sys.exit(0)
else:
    # Missing or unknown sub-command: show usage and signal an error.
    parser.print_help()
    sys.exit(1)
def signal_handler(sig, frame):
    """Terminate the process with exit status 0 when a signal arrives.

    Both arguments are supplied by the ``signal`` machinery and are unused.
    """
    raise SystemExit(0)


# Make Ctrl-C (SIGINT) end the program through the handler above.
signal.signal(signal.SIGINT, signal_handler)
def output_ok():
    """Print the ``[OK]`` marker wrapped in ANSI bright-green colour codes."""
    print("\033[92m[OK]\033[0m")
def output_failed():
    """Print the ``[Failed]`` marker in ANSI bright red and exit with status 1.

    This function never returns: it raises ``SystemExit(1)``.
    """
    print("\033[91m" + "[Failed]" + "\033[0m")
    # sys.exit() instead of the bare exit() helper: the latter is injected by
    # the `site` module for interactive use and is not guaranteed to exist.
    sys.exit(1)
def input_yes_no(prompt, default=True):
    """Ask a yes/no question on stdin until a valid answer is given.

    Args:
        prompt: Question text; a ``[Y/n]`` / ``[y/N]`` hint and ``": "`` are
            appended before it is shown.
        default: Value returned when the user just presses Enter.

    Returns:
        True for ``y``/``yes``, False for ``n``/``no`` (case-insensitive,
        surrounding whitespace ignored), or ``default`` for an empty answer.
    """
    default_str = " [Y/n]" if default else " [y/N]"
    while True:
        # Strip so that answers such as "yes " or " y" are not rejected.
        response = input(prompt + default_str + ": ").strip().lower()
        if not response:
            return default
        if response in ('y', 'yes'):
            return True
        if response in ('n', 'no'):
            return False
        print("Invalid input. Please enter 'yes' or 'no'.")
def input_str(prompt):
    """Prompt on stdin until a non-blank value is entered.

    Args:
        prompt: Text shown to the user; ``": "`` is appended.

    Returns:
        The entered text with surrounding whitespace removed.
    """
    while True:
        # Whitespace-only answers are treated as empty, and stray spaces
        # around the value are dropped.
        response = input(prompt + ": ").strip()
        if response:
            return response
        # Generic wording: this helper is used for URLs, project names,
        # jobset names and usernames alike.
        print("Invalid input. Please enter a value.")
def input_pass(prompt):
    """Read a secret via ``getpass`` and keep asking until it is non-empty.

    Args:
        prompt: Text shown to the user; ``": "`` is appended.

    Returns:
        The first non-empty string returned by ``getpass``.
    """
    secret = getpass(prompt + ": ")
    while secret == '':
        print("Invalid input. Please enter a password.")
        secret = getpass(prompt + ": ")
    return secret
def get_git_root(path):
    """Return the top-level directory of the git work tree containing ``path``.

    Parent directories are searched for the repository. If no repository with
    a working tree is found, "Not a git repository" is printed and the process
    exits with status 1.

    Args:
        path: Directory from which the repository lookup starts.

    Returns:
        The output of ``git rev-parse --show-toplevel`` for that repository.
    """
    try:
        git_repo = git.Repo(path, search_parent_directories=True)
    except (git.exc.InvalidGitRepositoryError, git.exc.NoSuchPathError):
        # git.Repo raises instead of returning an object without a work tree,
        # so without this handler the user would get a traceback.
        git_repo = None
    if git_repo is None or not git_repo.working_tree_dir:
        print("Not a git repository")
        sys.exit(1)
    return git_repo.git.rev_parse("--show-toplevel")
def keyring_has(key):
    """Tell whether the ``nix-ai`` keyring service holds a value for ``key``."""
    stored_value = keyring.get_password("nix-ai", key)
    if stored_value is None:
        return False
    return True
def hydra_request(endpoint, headers, param=None, put=False, timeout=60):
    """Send a GET or PUT request to the Hydra server from the module config.

    Args:
        endpoint: Path appended to ``config['hydra_url']``.
        headers: HTTP headers sent with the request.
        param: Pre-formatted query string (without the leading ``?``). It is
            only used for PUT requests; GET requests ignore it.
        put: Send a PUT request when True, otherwise a GET request.
        timeout: Seconds to wait for the server before ``requests`` raises a
            timeout error, so that an unreachable host cannot block forever.

    Returns:
        The ``requests`` response object.
    """
    url = f"{config['hydra_url']}{endpoint}"
    if not put:
        return requests.get(url, headers=headers, timeout=timeout)
    if param:
        # Only add the query part when there is one; formatting None into the
        # URL would produce a literal "?None".
        url = f"{url}?{param}"
    return requests.put(url, headers=headers, timeout=timeout)
def flake_check_bracket(flake):
    """Check that the curly braces in ``flake`` are balanced.

    Args:
        flake: Text to scan character by character.

    Returns:
        True when every ``}`` has an earlier matching ``{`` and no ``{`` is
        left open at the end; False otherwise.
    """
    depth = 0
    for ch in flake:
        # +1 for an opening brace, -1 for a closing one, 0 for anything else.
        depth += (ch == "{") - (ch == "}")
        if depth < 0:
            # A closing brace appeared before its opening counterpart.
            return False
    return depth == 0
# Locate the repository root and the files this run works with.
root_path = get_git_root(os.getcwd())
config = { }
trainfile = ""
configfile = ""
flakefile = os.path.join(root_path, "flake.nix")
quickfile = os.path.join(root_path, "quick.nix")

# Remove any existing quick.nix up front; it is written again further down.
if os.path.isfile(quickfile):
    os.remove(quickfile)

if not os.path.isfile(flakefile):
    print("\033[91m" + "[Error]" + "\033[0m", "No flake.nix file found")
    sys.exit(1)

# Only the `train` parser defines `trainfile`; when the `dl` parser was used
# the attribute does not exist, so look it up defensively instead of raising
# an AttributeError.
requested_trainfile = getattr(args, "trainfile", None)
if requested_trainfile is None:
    print("\033[91m" + "[Error]" + "\033[0m", "No trainfile specified")
    sys.exit(1)

trainfile = os.path.join(os.getcwd(), requested_trainfile)
if not os.path.isfile(trainfile):
    print("\033[91m" + "[Error]" + "\033[0m", f"File {trainfile} not found")
    sys.exit(1)

# config.json is expected in the same directory as the training script.
configfile = os.path.join(os.path.dirname(trainfile), "config.json")
if not os.path.isfile(configfile):
    print("\033[91m" + "[Error]" + "\033[0m", f"File {configfile} not found")
    sys.exit(1)
# Load the per-repository settings from .nix-ai.json, or ask for them
# interactively and create the file when it does not exist yet.
nix_ai_config_path = os.path.join(root_path, ".nix-ai.json")
if os.path.exists(nix_ai_config_path):
    with open(nix_ai_config_path) as f:
        config = json.load(f)
    # A config file without this key means "no basic auth"; setting it here
    # keeps later config["basic_auth"] lookups from raising KeyError.
    config.setdefault("basic_auth", False)
else:
    print("\033[93m" + "[Warning]" + "\033[0m No .nix-ai.json file found")
    config["hydra_url"] = input_str("Enter Hydra URL")
    config["hydra_project"] = input_str("Enter Hydra Project Name")
    config["hydra_jobset"] = input_str("Enter Hydra Jobset Name")
    config["basic_auth"] = input_yes_no("Do you have Hydra secured with Basic Auth?", False)
    with open(nix_ai_config_path, "w") as f:
        json.dump(config, f)

# Secrets are kept in the system keyring (service "nix-ai"), not in the
# JSON file; ask for them only when they are not stored yet.
if not keyring_has("hydra_session"):
    hydra_session = input_pass("Enter Hydra session id")
    keyring.set_password("nix-ai", "hydra_session", hydra_session)

# Basic-auth credentials are only needed when basic auth is enabled; do not
# ask for a username and password that would never be sent.
if config["basic_auth"] and not keyring_has("basic_auth"):
    username = input_str("Enter username")
    password = input_pass("Enter password")
    # Stored already encoded in the form used by the "Authorization: Basic" header.
    basic_auth = base64.b64encode(f"{username}:{password}".encode()).decode()
    keyring.set_password("nix-ai", "basic_auth", basic_auth)
# Verify that the configured jobset can be reached with the stored credentials.
# flush=True: the line has no trailing newline, so without an explicit flush
# it may only become visible after the (blocking) request has finished.
print("Checking Hydra connection... " , end="", flush=True)
headers = { "Accept": "application/json", "Content-Type": "application/json" }
if config.get("basic_auth"):
    headers["Authorization"] = f"Basic {keyring.get_password('nix-ai', 'basic_auth')}"
headers["Cookie"] = f"hydra_session={keyring.get_password('nix-ai', 'hydra_session')}"
try:
    req = hydra_request(f"/jobset/{config['hydra_project']}/{config['hydra_jobset']}", headers)
except requests.RequestException as err:
    # Network-level problems (DNS, refused connection, timeout, ...) are
    # reported like any other failed check instead of as a traceback.
    print(err, file=sys.stderr)
    output_failed()  # exits the process
if req.status_code == 200:
    output_ok()
else:
    output_failed()
# Generate quick.nix for the selected training script and validate the flake.
# flush=True so the progress text is shown before nix starts printing.
print("Checking Flake file... " , end="", flush=True)
relative_trainfile = os.path.relpath(trainfile, root_path)
relative_trainfile_split = relative_trainfile.split("/")
relative_dir = "/".join(relative_trainfile_split[:-1])
# A script located directly in the repository root has an empty relative
# directory; "./" alone is not a usable Nix path, "./." is.
nix_directory = relative_dir or "."
with open(configfile) as f:
    configfile_text = f.read()
# NOTE(review): configfile_text is inserted verbatim into a Nix '' ... ''
# string; a config containing '' or ${ would not survive as literal text.
# Confirm that config.json files never contain these sequences.
flake = f"""{{
directoryPath = ./{ nix_directory };
commands = ''
python { relative_trainfile_split[-1] }
'';
config = ''{ configfile_text }'';
}}"""
with open(quickfile, "w") as f:
    f.write(flake)
os.chdir(root_path)
# os.system returns a non-zero status when the command line fails; only
# report success when both nix commands actually succeeded.
nix_status = os.system("nix flake check && nix flake update")
if nix_status != 0:
    output_failed()  # exits the process
output_ok()
# Stage everything in the repository, commit it as "QuickTrain" and push.
# flush=True so the progress text is visible while git is running.
print("Commiting Git changes... " , end="", flush=True)
git_repo = git.Repo(root_path)
try:
    git_repo.git.add("*")
    git_repo.git.commit("-m", "QuickTrain")
    git_repo.git.push()
except git.exc.GitCommandError as err:
    # A failing git command (for example a rejected push) is reported as a
    # failed step with git's message instead of an unhandled traceback.
    print(err, file=sys.stderr)
    output_failed()  # exits the process
# Continue from the directory that contains the training script.
os.chdir(os.path.join(root_path, relative_dir))
output_ok()
# Ask Hydra to evaluate the configured jobset.
# flush=True so the progress text is visible while the request is running.
print("Pushing to Hydra... " , end="", flush=True)
try:
    req = hydra_request("/api/push", headers, f"jobsets={config['hydra_project']}:{config['hydra_jobset']}", put=True)
except requests.RequestException as err:
    print(err, file=sys.stderr)
    req = None
finally:
    # Remove the generated quick.nix whether or not the push worked;
    # output_failed() exits the process, so cleanup placed after the status
    # report would be skipped on failure.
    if os.path.isfile(quickfile):
        os.remove(quickfile)
if req is not None and req.status_code == 200:
    output_ok()
else:
    output_failed()