Commit ·
e417b0c
1
Parent(s): 0eb0b1e
use python script to clone repo and trigger train
Browse files
- run_sm.py +16 -0
- run_speech_recognition_seq2seq_streaming.py +1 -2
- sm.py +8 -5
run_sm.py
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
""" Python script that triggers sagemaker flow"""
|
| 2 |
+
|
| 3 |
+
import sys
|
| 4 |
+
import subprocess
|
| 5 |
+
|
| 6 |
+
def main():
    """Build the shell command that clones a repo and runs a script inside it.

    Arguments are read positionally from ``sys.argv``:

    * ``sys.argv[2]`` -- URL of the git repository to clone.
    * ``sys.argv[4]`` -- name of the shell script to run from the clone.

    ``sys.argv[1]`` and ``sys.argv[3]`` are skipped; presumably they are the
    flag names (e.g. ``--repo`` / ``--entrypoint``) -- TODO confirm against
    the caller.

    Raises:
        IndexError: if fewer than five entries are present in ``sys.argv``.
        ValueError: always, carrying the assembled command as its message.
            NOTE(review): the actual ``subprocess`` call is commented out, so
            this looks like a debugging aid -- confirm before re-enabling
            execution.
    """
    import shlex

    # Let's skip arg names
    repo = sys.argv[2]
    script_name = sys.argv[4]

    # Derive the directory `git clone` creates: git drops trailing slashes
    # and a trailing ".git" from the last path component of the URL.
    repo_name = repo.rstrip('/').split('/')[-1]
    if repo_name.endswith('.git'):
        repo_name = repo_name[:-len('.git')]

    # Quote every interpolated value because the command is meant to be run
    # through a shell; plain URLs and file names are left unchanged by quote().
    cmd = (
        f'git clone {shlex.quote(repo)}'
        f' && cd {shlex.quote(repo_name)}'
        f' && sh {shlex.quote(script_name)}'
    )
    # subprocess.call(cmd, shell=True)
    raise ValueError(cmd)
|
| 14 |
+
|
| 15 |
+
if __name__ == '__main__':
|
| 16 |
+
main()
|
run_speech_recognition_seq2seq_streaming.py
CHANGED
|
@@ -19,8 +19,7 @@ with 🤗 Datasets' streaming mode.
|
|
| 19 |
"""
|
| 20 |
# You can also adapt this script for your own sequence to sequence speech
|
| 21 |
# recognition task. Pointers for this are left as comments.
|
| 22 |
-
|
| 23 |
-
raise RuntimeError(f"{os.getcwd()}")
|
| 24 |
import json
|
| 25 |
import logging
|
| 26 |
import os
|
|
|
|
| 19 |
"""
|
| 20 |
# You can also adapt this script for your own sequence to sequence speech
|
| 21 |
# recognition task. Pointers for this are left as comments.
|
| 22 |
+
|
|
|
|
| 23 |
import json
|
| 24 |
import logging
|
| 25 |
import os
|
sm.py
CHANGED
|
@@ -13,7 +13,7 @@ TEST = True
|
|
| 13 |
|
| 14 |
|
| 15 |
test_sm_instances = {
|
| 16 |
-
"ml.g4dn.
|
| 17 |
{
|
| 18 |
"num_instances": 1,
|
| 19 |
"num_gpus": 1
|
|
@@ -30,7 +30,7 @@ full_sm_instances = {
|
|
| 30 |
|
| 31 |
sm_instances = test_sm_instances if TEST else full_sm_instances
|
| 32 |
|
| 33 |
-
ENTRY_POINT = "
|
| 34 |
RUN_SCRIPT = "test_run.sh" if TEST else "run.sh"
|
| 35 |
IMAGE_URI = "116817510867.dkr.ecr.eu-west-1.amazonaws.com/huggingface-pytorch-training:whisper-finetuning-0223e276db78adf4ea4dc5f874793cb2"
|
| 36 |
if IMAGE_URI is None:
|
|
@@ -66,7 +66,6 @@ def parse_run_script():
|
|
| 66 |
.replace("\n", "") \
|
| 67 |
.replace('"', "")
|
| 68 |
line = line.split("=")
|
| 69 |
-
# remove '\t--'
|
| 70 |
key = str(line[0])
|
| 71 |
try:
|
| 72 |
value = line[1]
|
|
@@ -78,8 +77,8 @@ def parse_run_script():
|
|
| 78 |
|
| 79 |
|
| 80 |
set_creds()
|
| 81 |
-
hyperparameters = parse_run_script()
|
| 82 |
-
pprint(hyperparameters)
|
| 83 |
|
| 84 |
hf_token = os.environ.get("HF_TOKEN")
|
| 85 |
if hf_token is None:
|
|
@@ -93,6 +92,10 @@ env_vars = {
|
|
| 93 |
}
|
| 94 |
pprint(env_vars)
|
| 95 |
repo = f"https://huggingface.co/marinone94/{os.getcwd().split('/')[-1]}"
|
|
|
|
|
|
|
|
|
|
|
|
|
| 96 |
for sm_instance_name, sm_instance_values in sm_instances.items():
|
| 97 |
num_instances: int = \
|
| 98 |
int(sm_instance_values["num_instances"])
|
|
|
|
| 13 |
|
| 14 |
|
| 15 |
test_sm_instances = {
|
| 16 |
+
"ml.g4dn.xlarge":
|
| 17 |
{
|
| 18 |
"num_instances": 1,
|
| 19 |
"num_gpus": 1
|
|
|
|
| 30 |
|
| 31 |
sm_instances = test_sm_instances if TEST else full_sm_instances
|
| 32 |
|
| 33 |
+
ENTRY_POINT = "run_sm.py"
|
| 34 |
RUN_SCRIPT = "test_run.sh" if TEST else "run.sh"
|
| 35 |
IMAGE_URI = "116817510867.dkr.ecr.eu-west-1.amazonaws.com/huggingface-pytorch-training:whisper-finetuning-0223e276db78adf4ea4dc5f874793cb2"
|
| 36 |
if IMAGE_URI is None:
|
|
|
|
| 66 |
.replace("\n", "") \
|
| 67 |
.replace('"', "")
|
| 68 |
line = line.split("=")
|
|
|
|
| 69 |
key = str(line[0])
|
| 70 |
try:
|
| 71 |
value = line[1]
|
|
|
|
| 77 |
|
| 78 |
|
| 79 |
set_creds()
|
| 80 |
+
# hyperparameters = parse_run_script()
|
| 81 |
+
# pprint(hyperparameters)
|
| 82 |
|
| 83 |
hf_token = os.environ.get("HF_TOKEN")
|
| 84 |
if hf_token is None:
|
|
|
|
| 92 |
}
|
| 93 |
pprint(env_vars)
|
| 94 |
repo = f"https://huggingface.co/marinone94/{os.getcwd().split('/')[-1]}"
|
| 95 |
+
hyperparameters = {
|
| 96 |
+
"repo": repo,
|
| 97 |
+
"entrypoint": RUN_SCRIPT
|
| 98 |
+
}
|
| 99 |
for sm_instance_name, sm_instance_values in sm_instances.items():
|
| 100 |
num_instances: int = \
|
| 101 |
int(sm_instance_values["num_instances"])
|