# Listing metadata from the original paste: 230 lines, 6.7 KiB, Python.
import argparse
import base64
import json
import os
import signal
import subprocess
import sys
from getpass import getpass
from string import whitespace

import git
import keyring
import requests


# Top-level parser. Help is handled manually (add_help=False) so that
# `wvls train -h` is not swallowed here and reaches the train-specific parser below.
parser = argparse.ArgumentParser(prog='wvls', description='Wavelens CLI', add_help=False)
parser.add_argument('action', type=str, nargs='*', help='Possible values: train, dl')
parser.add_argument('-h', '--help', action='store_true', help='Show this help message and exit')
args = parser.parse_args()

# nargs='*' yields [] when no positional argument is given (`wvls`, `wvls -h`);
# indexing args.action[0] unconditionally raised IndexError there and made the
# help branch below unreachable.
action = args.action[0] if args.action else None

if action == 'train':
    # Re-parse argv with a train-specific parser: the first positional is the
    # literal "train", the second is the training script.
    parser = argparse.ArgumentParser(prog='wvls', description='Wavelens CLI', add_help=True)
    parser.add_argument('train', type=str, nargs='?')
    parser.add_argument('trainfile', type=str, nargs='?', help='Add file to train')
    args = parser.parse_args()
elif action == 'dl':
    # Same approach for `dl`: the second positional names the model.
    parser = argparse.ArgumentParser(prog='wvls', description='Wavelens CLI', add_help=True)
    parser.add_argument('dl', type=str, nargs='?')
    parser.add_argument('model', type=str, nargs='?', help='Add model to download')
    args = parser.parse_args()
elif args.help:
    parser.print_help()
    sys.exit(0)
else:
    # Missing or unknown action.
    parser.print_help()
    sys.exit(1)


def signal_handler(sig, frame):
    """Handle SIGINT (Ctrl-C) by terminating the CLI with exit status 0.

    ``sig`` and ``frame`` are the arguments the ``signal`` module passes to
    every handler; neither is needed here.
    """
    del sig, frame  # required by the handler signature only
    raise SystemExit(0)


# Install the handler for the rest of the script's run.
signal.signal(signal.SIGINT, signal_handler)


def output_ok():
    """Finish a status line with a green "[OK]" marker (ANSI color 92) and a newline."""
    green, reset = "\033[92m", "\033[0m"
    print(f"{green}[OK]{reset}")


def output_failed():
    """Finish a status line with a red "[Failed]" marker and exit with status 1.

    Raises:
        SystemExit: always, with code 1.
    """
    print("\033[91m" + "[Failed]" + "\033[0m")
    # sys.exit rather than the builtin `exit`: the latter is injected by the
    # `site` module (absent under `python -S` / frozen builds) and also closes stdin.
    sys.exit(1)


def input_yes_no(prompt, default=True, read=input):
    """Ask a yes/no question and return the answer as a bool.

    Args:
        prompt: Question text; " [Y/n]: " or " [y/N]: " is appended depending
            on ``default``.
        default: Value returned when the user just presses Enter.
        read: Function used to read one line of input (defaults to the builtin
            ``input``); injectable so the prompt can be driven without a terminal.

    Returns:
        True for "y"/"yes", False for "n"/"no" (case-insensitive, surrounding
        whitespace ignored), ``default`` for an empty answer. Any other answer
        prints a hint and asks again.
    """
    suffix = " [Y/n]" if default else " [y/N]"
    while True:
        # strip() so a stray space (" y", "yes ") is not rejected as invalid.
        response = read(prompt + suffix + ": ").strip().lower()
        if response in ("y", "yes"):
            return True
        if response in ("n", "no"):
            return False
        if not response:
            return default
        print("Invalid input. Please enter 'yes' or 'no'.")


def input_str(prompt, read=input):
    """Prompt until a non-blank value is entered and return it stripped.

    Args:
        prompt: Prompt text; ": " is appended.
        read: Function used to read one line of input (defaults to the builtin
            ``input``); injectable so the prompt can be driven without a terminal.

    Returns:
        The entered text with surrounding whitespace removed (a stray space
        would otherwise end up in URLs / project names written to the config).
    """
    while True:
        response = read(prompt + ": ").strip()
        if response:
            return response
        # Generic wording: this helper is used for URLs, project/jobset names
        # and usernames, not only URLs.
        print("Invalid input. Please enter a value.")


def input_pass(prompt, read=getpass):
    """Prompt for a secret until a non-empty value is entered.

    Args:
        prompt: Prompt text; ": " is appended.
        read: Function that reads the secret; defaults to ``getpass.getpass``
            (no terminal echo). Injectable, mirroring ``input_str``/``input_yes_no``.

    Returns:
        The entered value, unmodified — whitespace is kept because it may be
        part of the secret.
    """
    while True:
        secret = read(prompt + ": ")
        if secret:
            return secret
        print("Invalid input. Please enter a password.")


def get_git_root(path):
    """Return the top-level directory of the git work tree containing ``path``.

    Parent directories are searched, so ``path`` may be any directory inside
    the repository.

    Raises:
        SystemExit: with code 1 (after printing "Not a git repository") when
            ``path`` is not inside a repository, does not exist, or belongs to
            a repository without a work tree.
    """
    try:
        git_repo = git.Repo(path, search_parent_directories=True)
    except (git.exc.InvalidGitRepositoryError, git.exc.NoSuchPathError):
        # git.Repo raises instead of returning an empty repo, so this is the
        # ordinary "not in a repository" case the message below is meant for.
        git_repo = None

    if git_repo is None or not git_repo.working_tree_dir:
        print("Not a git repository")
        sys.exit(1)

    return git_repo.git.rev_parse("--show-toplevel")


def keyring_has(key):
    """Tell whether the "nix-ai" keyring service holds a value for ``key``."""
    stored = keyring.get_password("nix-ai", key)
    return stored is not None


def hydra_request(endpoint, headers, param=None, put=False, timeout=30):
    """Send a GET (or PUT) to the Hydra server named by the module-level ``config``.

    Args:
        endpoint: Path appended to ``config['hydra_url']``; expected to start with "/".
        headers: HTTP headers to send with the request.
        param: Optional pre-encoded query string, e.g. "jobsets=project:jobset".
            Omitted entirely when None (a PUT without it used to hit "?None").
        put: Send a PUT instead of a GET.
        timeout: Seconds ``requests`` waits to connect/read before raising;
            without a timeout an unresponsive server hangs the CLI indefinitely.

    Returns:
        The ``requests.Response``.

    Raises:
        requests.RequestException: on connection errors or timeouts.
    """
    # Tolerate a configured URL with a trailing slash ("https://hydra/" + "/api/...").
    url = config["hydra_url"].rstrip("/") + endpoint
    send = requests.put if put else requests.get
    return send(url, headers=headers, params=param, timeout=timeout)


def flake_check_bracket(flake):
    """Report whether the curly braces in ``flake`` are balanced.

    The text is scanned left to right while tracking nesting depth. A '}' with
    no earlier unmatched '{' fails immediately; otherwise the result is True
    only if every '{' was closed. All other characters are ignored, including
    braces that appear inside Nix strings or comments.
    """
    depth = 0
    for char in flake:
        if char == "}":
            depth -= 1
            if depth < 0:
                return False
        elif char == "{":
            depth += 1
    return depth == 0


# Resolve repository paths and validate the training inputs.
root_path = get_git_root(os.getcwd())
config = {}
trainfile = ""
configfile = ""
flakefile = os.path.join(root_path, "flake.nix")
quickfile = os.path.join(root_path, "quick.nix")

# A quick.nix left over from an earlier run is regenerated from scratch below.
if os.path.isfile(quickfile):
    os.remove(quickfile)

if not os.path.isfile(flakefile):
    print("\033[91m" + "[Error]" + "\033[0m", "No flake.nix file found")
    sys.exit(1)

# Only the `train` sub-parser defines `trainfile`; the `dl` namespace has no
# such attribute, which previously crashed here with AttributeError.
trainfile_arg = getattr(args, "trainfile", None)
if trainfile_arg is None:
    print("\033[91m" + "[Error]" + "\033[0m", "No trainfile specified")
    sys.exit(1)

# Relative paths are taken from the current directory (an absolute path is kept as-is).
trainfile = os.path.join(os.getcwd(), trainfile_arg)
if not os.path.isfile(trainfile):
    print("\033[91m" + "[Error]" + "\033[0m", f"File {trainfile} not found")
    sys.exit(1)

# config.json is expected next to the training script.
configfile = os.path.join(os.path.dirname(trainfile), "config.json")
if not os.path.isfile(configfile):
    print("\033[91m" + "[Error]" + "\033[0m", f"File {configfile} not found")
    sys.exit(1)


# Per-repository settings live in <repo>/.nix-ai.json; create them interactively
# on first use.
nix_ai_config_path = os.path.join(root_path, ".nix-ai.json")

if os.path.exists(nix_ai_config_path):
    with open(nix_ai_config_path, encoding="utf-8") as f:
        config = json.load(f)
else:
    print("\033[93m" + "[Warning]" + "\033[0m No .nix-ai.json file found")

    # Dict displays evaluate left to right, so the prompts keep their order.
    config = {
        "hydra_url": input_str("Enter Hydra URL"),
        "hydra_project": input_str("Enter Hydra Project Name"),
        "hydra_jobset": input_str("Enter Hydra Jobset Name"),
        "basic_auth": input_yes_no("Do you have Hydra secured with Basic Auth?", False),
    }

    with open(nix_ai_config_path, "w", encoding="utf-8") as f:
        json.dump(config, f)

# Secrets are kept in the OS keyring under the "nix-ai" service, never in the JSON file.
if not keyring_has("hydra_session"):
    hydra_session = input_pass("Enter Hydra session id")
    keyring.set_password("nix-ai", "hydra_session", hydra_session)

# Basic-auth credentials are only needed when the project uses Basic Auth;
# previously they were requested (and stored) unconditionally.
if config.get("basic_auth") and not keyring_has("basic_auth"):
    username = input_str("Enter username")
    password = input_pass("Enter password")
    basic_auth = base64.b64encode(f"{username}:{password}".encode()).decode()
    keyring.set_password("nix-ai", "basic_auth", basic_auth)


# flush=True: the status text must be visible before the (possibly slow) request.
print("Checking Hydra connection... ", end="", flush=True)

# Request headers reused for every Hydra call in this run.
headers = {"Accept": "application/json", "Content-Type": "application/json"}
if config["basic_auth"]:
    headers["Authorization"] = f"Basic {keyring.get_password('nix-ai', 'basic_auth')}"
headers["Cookie"] = f"hydra_session={keyring.get_password('nix-ai', 'hydra_session')}"

try:
    req = hydra_request(f"/jobset/{config['hydra_project']}/{config['hydra_jobset']}", headers)
except requests.RequestException:
    # Unreachable host / timeout: report it as a failed check instead of a traceback.
    req = None

if req is not None and req.status_code == 200:
    output_ok()
else:
    output_failed()


print("Checking Flake file... ", end="", flush=True)

relative_trainfile = os.path.relpath(trainfile, root_path)
# relative_dir is "" when the training script sits at the repository root.
relative_dir, train_script = os.path.split(relative_trainfile)

with open(configfile, encoding="utf-8") as f:
    configfile_text = f.read()

# Nix indented strings ('' ... '') give `''` and `${` special meaning; escape
# them ('' -> ''' and ${ -> ''${) so config.json is embedded verbatim.
nix_config_text = configfile_text.replace("''", "'''").replace("${", "''${")
# `./` on its own is not a valid Nix path; the repository root is `./.`.
nix_dir = relative_dir or "."

flake = f"""{{
  directoryPath = ./{nix_dir};
  commands = ''
    python {train_script}
  '';

  config = ''{nix_config_text}'';
}}"""

with open(quickfile, "w", encoding="utf-8") as f:
    f.write(flake)


def _command_succeeds(command):
    """Run ``command`` (an argv list) in the current directory; True iff it exits with 0."""
    try:
        return subprocess.run(command).returncode == 0
    except FileNotFoundError:  # e.g. `nix` is not installed or not on PATH
        return False


os.chdir(root_path)
# Same `check && update` sequence as before, but a failing step now reports
# [Failed] instead of the exit status being ignored.
if not (_command_succeeds(["nix", "flake", "check"])
        and _command_succeeds(["nix", "flake", "update"])):
    output_failed()

output_ok()


print("Committing Git changes... ", end="", flush=True)

git_repo = git.Repo(root_path)
try:
    # Stage everything in the repository, including the freshly written quick.nix.
    git_repo.git.add("*")
    # A re-run with unchanged inputs stages nothing, and `git commit` would then
    # fail with "nothing to commit"; only commit when the index differs from HEAD.
    if git_repo.is_dirty(index=True, working_tree=False, untracked_files=False):
        git_repo.git.commit("-m", "QuickTrain")
    git_repo.git.push()
except git.exc.GitCommandError as exc:
    print(f"\n{exc}", file=sys.stderr)
    output_failed()

os.chdir(os.path.join(root_path, relative_dir))

output_ok()


print("Pushing to Hydra... ", end="", flush=True)

try:
    req = hydra_request(
        "/api/push",
        headers,
        f"jobsets={config['hydra_project']}:{config['hydra_jobset']}",
        put=True,
    )
except requests.RequestException:
    # Unreachable host / timeout: report [Failed] instead of a traceback.
    req = None

# quick.nix has already been committed and pushed, so the local copy is removed
# whether or not Hydra accepted the trigger (previously it lingered on failure).
os.remove(quickfile)

if req is not None and req.status_code == 200:
    output_ok()
else:
    output_failed()