| import asyncio |
| import json |
| import os |
| import shutil |
| import subprocess |
| import tempfile |
| import time |
| from pathlib import Path |
| from unittest.mock import patch |
|
|
| import gradio_client as grc |
| import pytest |
| from gradio_client import media_data |
| from gradio_client import utils as client_utils |
| from pydub import AudioSegment |
| from starlette.testclient import TestClient |
| from tqdm import tqdm |
|
|
| import gradio as gr |
| from gradio import utils |
|
|
|
|
@patch("gradio.utils.get_cache_folder", return_value=Path(tempfile.mkdtemp()))
class TestExamples:
    """Tests for how ``gr.Examples`` normalises, loads and caches example rows.

    The class-level ``patch`` replaces ``gradio.utils.get_cache_folder`` with a
    mock returning a temporary directory; every test receives that mock as
    ``patched_cache_folder``.
    """

    def test_handle_single_input(self, patched_cache_folder):
        # A flat list of scalars is wrapped so that every row is itself a list.
        flat = gr.Examples(["hello", "hi"], gr.Textbox())
        assert flat.processed_examples == [["hello"], ["hi"]]

        # Rows that are already lists come back unchanged.
        nested = gr.Examples([["hello"]], gr.Textbox())
        assert nested.processed_examples == [["hello"]]

        # An image example becomes a dict whose "path" holds the same bytes
        # as the source file.
        image_ex = gr.Examples(["test/test_files/bus.png"], gr.Image())
        image_path = image_ex.processed_examples[0][0]["path"]
        assert client_utils.encode_file_to_base64(image_path) == media_data.BASE64_IMAGE

    def test_handle_multiple_inputs(self, patched_cache_folder):
        ex = gr.Examples(
            [["hello", "test/test_files/bus.png"]], [gr.Textbox(), gr.Image()]
        )
        first_row = ex.processed_examples[0]
        assert first_row[0] == "hello"
        encoded_image = client_utils.encode_file_to_base64(first_row[1]["path"])
        assert encoded_image == media_data.BASE64_IMAGE

    def test_handle_directory(self, patched_cache_folder):
        # A directory of images yields one example row per image file.
        ex = gr.Examples("test/test_files/images", gr.Image())
        assert len(ex.processed_examples) == 2
        encoded = [
            client_utils.encode_file_to_base64(cell["path"])
            for row in ex.processed_examples
            for cell in row
        ]
        assert all(value == media_data.BASE64_IMAGE for value in encoded)

    def test_handle_directory_with_log_file(self, patched_cache_folder):
        # The directory's log file supplies values for both input components.
        ex = gr.Examples(
            "test/test_files/images_log", [gr.Image(label="im"), gr.Text()]
        )

        def is_existing_file(node):
            return isinstance(node, dict) and Path(node["path"]).exists()

        def to_base64(file_dict):
            return client_utils.encode_file_to_base64(file_dict["path"])

        encoded = client_utils.traverse(
            ex.processed_examples, to_base64, is_existing_file
        )
        assert encoded == [
            [media_data.BASE64_IMAGE, "hello"],
            [media_data.BASE64_IMAGE, "hi"],
        ]
        # Dataset samples must reference their files by absolute path.
        assert all(os.path.isabs(sample[0]["path"]) for sample in ex.dataset.samples)

    def test_examples_per_page(self, patched_cache_folder):
        ex = gr.Examples(["hello", "hi"], gr.Textbox(), examples_per_page=2)
        dataset_config = ex.dataset.get_config()
        assert dataset_config["samples_per_page"] == 2

    def test_no_preprocessing(self, patched_cache_folder):
        # With preprocess=False the fn indexes the raw input with ["path"], and
        # the cached output is a path whose bytes match the example image.
        with gr.Blocks():
            image_in = gr.Image()
            text_out = gr.Textbox()

            ex = gr.Examples(
                examples=["test/test_files/bus.png"],
                inputs=image_in,
                outputs=text_out,
                fn=lambda x: x["path"],
                cache_examples=True,
                preprocess=False,
            )

        cached = ex.load_from_cache(0)
        assert client_utils.encode_file_to_base64(cached[0]) == media_data.BASE64_IMAGE

    def test_no_postprocessing(self, patched_cache_folder):
        def gallery_fn(x):
            # Handed to the Gallery as-is because postprocess=False below.
            return [{"image": {"path": "test/test_files/bus.png"}, "caption": "hi"}]

        with gr.Blocks():
            text_in = gr.Textbox()
            gallery_out = gr.Gallery()

            ex = gr.Examples(
                examples=["hi"],
                inputs=text_in,
                outputs=gallery_out,
                fn=gallery_fn,
                cache_examples=True,
                postprocess=False,
            )

        cached = ex.load_from_cache(0)
        cached_file = cached[0].root[0].image.path
        expected = client_utils.encode_url_or_file_to_base64("test/test_files/bus.png")
        assert client_utils.encode_url_or_file_to_base64(cached_file) == expected
|
|
|
def test_setting_cache_dir_env_variable(monkeypatch):
    """Cached example files land under ``$GRADIO_EXAMPLES_CACHE`` when it is set."""
    cache_root = tempfile.mkdtemp()
    monkeypatch.setenv("GRADIO_EXAMPLES_CACHE", cache_root)

    with gr.Blocks():
        source_image = gr.Image()
        target_image = gr.Image()

        ex = gr.Examples(
            examples=["test/test_files/bus.png"],
            inputs=source_image,
            outputs=target_image,
            fn=lambda x: x,
            cache_examples=True,
        )

    cached_output = ex.load_from_cache(0)[0]
    assert utils.is_in_or_equal(Path(cached_output.path), cache_root)
    monkeypatch.delenv("GRADIO_EXAMPLES_CACHE", raising=False)
|
|
|
|
@patch("gradio.utils.get_cache_folder", return_value=Path(tempfile.mkdtemp()))
class TestExamplesDataset:
    """Tests for the column headers of the dataset that ``gr.Examples`` builds.

    As asserted below, headers are the input components' labels: an unlabeled
    component contributes ``""``, and with no labels at all the list is empty.
    """

    def test_no_headers(self, patched_cache_folder):
        unlabeled = [gr.Image(), gr.Text()]
        ex = gr.Examples("test/test_files/images_log", unlabeled)
        assert ex.dataset.headers == []

    def test_all_headers(self, patched_cache_folder):
        labeled = [gr.Image(label="im"), gr.Text(label="your text")]
        ex = gr.Examples("test/test_files/images_log", labeled)
        assert ex.dataset.headers == ["im", "your text"]

    def test_some_headers(self, patched_cache_folder):
        partly_labeled = [gr.Image(label="im"), gr.Text()]
        ex = gr.Examples("test/test_files/images_log", partly_labeled)
        assert ex.dataset.headers == ["im", ""]
|
|
|
|
def test_example_caching_relaunch(connect):
    """A cached example is served identically across two ``connect`` sessions."""

    def combine(a, b):
        return a + " " + b

    with gr.Blocks() as demo:
        first = gr.Textbox(label="Input")
        second = gr.Textbox(label="Input 2")
        joined = gr.Textbox(value="", label="Output")
        submit = gr.Button(value="Submit")
        submit.click(combine, inputs=[first, second], outputs=[joined])
        gr.Examples(
            [["hi", "Adam"], ["hello", "Eve"]],
            [first, second],
            joined,
            combine,
            cache_examples=True,
            api_name="examples",
        )

    expected = ("hello", "Eve", "hello Eve")
    for attempt in range(2):
        if attempt:
            # Brief pause between the first and the second session.
            time.sleep(1)
        with connect(demo) as client:
            assert client.predict(1, api_name="/examples") == expected
|
|
|
|
@patch("gradio.utils.get_cache_folder", return_value=Path(tempfile.mkdtemp()))
class TestProcessExamples:
    """Tests for caching example outputs and loading them back.

    The class-level ``patch`` points ``gradio.utils.get_cache_folder`` at a
    temporary directory; each test receives the mock as
    ``patched_cache_folder``. Tests that launch a server close it in a
    ``finally`` block so a failing assertion does not leave it running for
    the rest of the session.
    """

    def test_caching(self, patched_cache_folder):
        io = gr.Interface(
            lambda x: f"Hello {x}",
            "text",
            "text",
            examples=[["World"], ["Dunya"], ["Monde"]],
            cache_examples=True,
        )
        prediction = io.examples_handler.load_from_cache(1)
        assert prediction[0] == "Hello Dunya"

    def test_example_caching_relaunch(self, patched_cache_folder, connect):
        def combine(a, b):
            return a + " " + b

        with gr.Blocks() as demo:
            txt = gr.Textbox(label="Input")
            txt_2 = gr.Textbox(label="Input 2")
            txt_3 = gr.Textbox(value="", label="Output")
            btn = gr.Button(value="Submit")
            btn.click(combine, inputs=[txt, txt_2], outputs=[txt_3])
            gr.Examples(
                [["hi", "Adam"], ["hello", "Eve"]],
                [txt, txt_2],
                txt_3,
                combine,
                cache_examples=True,
                api_name="examples",
            )

        # The cached result must be served identically on a second connection.
        with connect(demo) as client:
            assert client.predict(1, api_name="/examples") == (
                "hello",
                "Eve",
                "hello Eve",
            )

        with connect(demo) as client:
            assert client.predict(1, api_name="/examples") == (
                "hello",
                "Eve",
                "hello Eve",
            )

    def test_caching_image(self, patched_cache_folder):
        io = gr.Interface(
            lambda x: x,
            "image",
            "image",
            examples=[["test/test_files/bus.png"]],
            cache_examples=True,
        )
        prediction = io.examples_handler.load_from_cache(0)
        assert client_utils.encode_url_or_file_to_base64(prediction[0].path).startswith(
            "data:image/png;base64,iVBORw0KGgoAAA"
        )

    def test_caching_audio(self, patched_cache_folder):
        io = gr.Interface(
            lambda x: x,
            "audio",
            "audio",
            examples=[["test/test_files/audio_sample.wav"]],
            cache_examples=True,
        )
        prediction = io.examples_handler.load_from_cache(0)
        file = prediction[0].path
        assert client_utils.encode_url_or_file_to_base64(file).startswith(
            "data:audio/wav;base64,UklGRgA/"
        )

    def test_caching_with_update(self, patched_cache_folder):
        # A gr.update(...) return value is cached as a plain update dict.
        io = gr.Interface(
            lambda x: gr.update(visible=False),
            "text",
            "image",
            examples=[["World"], ["Dunya"], ["Monde"]],
            cache_examples=True,
        )
        prediction = io.examples_handler.load_from_cache(1)
        assert prediction[0] == {
            "visible": False,
            "__type__": "update",
        }

    def test_caching_with_mix_update(self, patched_cache_folder):
        io = gr.Interface(
            lambda x: [gr.update(lines=4, value="hello"), "test/test_files/bus.png"],
            "text",
            ["text", "image"],
            examples=[["World"], ["Dunya"], ["Monde"]],
            cache_examples=True,
        )
        prediction = io.examples_handler.load_from_cache(1)
        assert prediction[0] == {
            "lines": 4,
            "value": "hello",
            "__type__": "update",
        }

    def test_caching_with_dict(self, patched_cache_folder):
        text = gr.Textbox()
        out = gr.Label()

        # The fn returns a {component: value} dict instead of a list.
        io = gr.Interface(
            lambda _: {text: gr.update(lines=4, interactive=False), out: "lion"},
            "textbox",
            [text, out],
            examples=["abc"],
            cache_examples=True,
        )
        prediction = io.examples_handler.load_from_cache(0)
        assert prediction == [
            {"lines": 4, "__type__": "update", "interactive": False},
            gr.Label.data_model(**{"label": "lion", "confidences": None}),
        ]

    def test_caching_with_generators(self, patched_cache_folder):
        def test_generator(x):
            for y in range(len(x)):
                yield "Your output: " + x[: y + 1]

        io = gr.Interface(
            test_generator,
            "textbox",
            "textbox",
            examples=["abcdef"],
            cache_examples=True,
        )
        prediction = io.examples_handler.load_from_cache(0)
        # The cached value is the generator's final yield.
        assert prediction[0] == "Your output: abcdef"

    def test_caching_with_generators_and_streamed_output(self, patched_cache_folder):
        file_dir = Path(Path(__file__).parent, "test_files")
        audio = str(file_dir / "audio_sample.wav")

        def test_generator(x):
            for y in range(int(x)):
                yield audio, y * 5

        io = gr.Interface(
            test_generator,
            "number",
            [gr.Audio(streaming=True), "number"],
            examples=[3],
            cache_examples=True,
        )
        prediction = io.examples_handler.load_from_cache(0)
        # The streamed audio output is expected to be ~3x the input length
        # (three yields), while the number output keeps the last value (2 * 5).
        len_input_audio = len(AudioSegment.from_wav(audio))
        len_output_audio = len(AudioSegment.from_wav(prediction[0].path))
        length_ratio = len_output_audio / len_input_audio
        assert round(length_ratio, 1) == 3.0
        assert float(prediction[1]) == 10.0

    def test_caching_with_async_generators(self, patched_cache_folder):
        async def test_generator(x):
            for y in range(len(x)):
                yield "Your output: " + x[: y + 1]

        io = gr.Interface(
            test_generator,
            "textbox",
            "textbox",
            examples=["abcdef"],
            cache_examples=True,
        )
        prediction = io.examples_handler.load_from_cache(0)
        assert prediction[0] == "Your output: abcdef"

    def test_raise_helpful_error_message_if_providing_partial_examples(
        self, patched_cache_folder, tmp_path
    ):
        def foo(a, b):
            return a + b

        # Example rows that omit a required argument: caching warns, then fails.
        with pytest.warns(
            UserWarning,
            match="^Examples are being cached but not all input components have",
        ):
            with pytest.raises(Exception):
                gr.Interface(
                    foo,
                    inputs=["text", "text"],
                    outputs=["text"],
                    examples=[["foo"], ["bar"]],
                    cache_examples=True,
                )

        with pytest.warns(
            UserWarning,
            match="^Examples are being cached but not all input components have",
        ):
            with pytest.raises(Exception):
                gr.Interface(
                    foo,
                    inputs=["text", "text"],
                    outputs=["text"],
                    examples=[["foo", "bar"], ["bar", None]],
                    cache_examples=True,
                )

        def foo_no_exception(a, b=2):
            return a * b

        # Expected to succeed: the missing argument has a default value.
        gr.Interface(
            foo_no_exception,
            inputs=["text", "number"],
            outputs=["text"],
            examples=[["foo"], ["bar"]],
            cache_examples=True,
        )

        def many_missing(a, b, c):
            return a * b

        with pytest.warns(
            UserWarning,
            match="^Examples are being cached but not all input components have",
        ):
            with pytest.raises(Exception):
                gr.Interface(
                    many_missing,
                    inputs=["text", "number", "number"],
                    outputs=["text"],
                    examples=[["foo", None, None], ["bar", 2, 3]],
                    cache_examples=True,
                )

    def test_caching_with_batch(self, patched_cache_folder):
        # With batch=True the fn receives one list per input and returns one
        # list per output.
        def trim_words(words, lens):
            trimmed_words = [word[:length] for word, length in zip(words, lens)]
            return [trimmed_words]

        io = gr.Interface(
            trim_words,
            ["textbox", gr.Number(precision=0)],
            ["textbox"],
            batch=True,
            max_batch_size=16,
            examples=[["hello", 3], ["hi", 4]],
            cache_examples=True,
        )
        prediction = io.examples_handler.load_from_cache(0)
        assert prediction == ["hel"]

    def test_caching_with_batch_multiple_outputs(self, patched_cache_folder):
        def trim_words(words, lens):
            trimmed_words = [word[:length] for word, length in zip(words, lens)]
            return trimmed_words, lens

        io = gr.Interface(
            trim_words,
            ["textbox", gr.Number(precision=0)],
            ["textbox", gr.Number(precision=0)],
            batch=True,
            max_batch_size=16,
            examples=[["hello", 3], ["hi", 4]],
            cache_examples=True,
        )
        prediction = io.examples_handler.load_from_cache(0)
        assert prediction == ["hel", "3"]

    def test_caching_with_non_io_component(self, patched_cache_folder):
        def predict(name):
            return name, gr.update(visible=True)

        with gr.Blocks():
            t1 = gr.Textbox()
            with gr.Column(visible=False) as c:
                t2 = gr.Textbox()

            # One output is a layout Column rather than an IO component.
            examples = gr.Examples(
                [["John"], ["Mary"]],
                fn=predict,
                inputs=[t1],
                outputs=[t2, c],
                cache_examples=True,
            )

        prediction = examples.load_from_cache(0)
        assert prediction == ["John", {"visible": True, "__type__": "update"}]

    def test_end_to_end(self, patched_cache_folder):
        def concatenate(str1, str2):
            return str1 + str2

        with gr.Blocks() as demo:
            t1 = gr.Textbox()
            t2 = gr.Textbox()
            t1.submit(concatenate, [t1, t2], t2)

            gr.Examples(
                [["Hello,", None], ["Michael", None]],
                inputs=[t1, t2],
                api_name="load_example",
            )

        def expected_update(value):
            """Textbox update payload expected from /api/load_example for *value*."""
            return {
                "lines": 1,
                "max_lines": 20,
                "show_label": True,
                "container": True,
                "min_width": 160,
                "autofocus": False,
                "autoscroll": True,
                "elem_classes": [],
                "rtl": False,
                "show_copy_button": False,
                "__type__": "update",
                "visible": True,
                "value": value,
                "type": "text",
            }

        app, _, _ = demo.launch(prevent_thread_lock=True)
        try:
            client = TestClient(app)

            response = client.post("/api/load_example/", json={"data": [0]})
            assert response.json()["data"] == [expected_update("Hello,")]

            response = client.post("/api/load_example/", json={"data": [1]})
            assert response.json()["data"] == [expected_update("Michael")]
        finally:
            # Stop the background server even if an assertion failed.
            demo.close()

    def test_end_to_end_cache_examples(self, patched_cache_folder):
        def concatenate(str1, str2):
            return f"{str1} {str2}"

        with gr.Blocks() as demo:
            t1 = gr.Textbox()
            t2 = gr.Textbox()
            t1.submit(concatenate, [t1, t2], t2)

            gr.Examples(
                examples=[["Hello,", "World"], ["Michael", "Jordan"]],
                inputs=[t1, t2],
                outputs=[t2],
                fn=concatenate,
                cache_examples=True,
                api_name="load_example",
            )

        app, _, _ = demo.launch(prevent_thread_lock=True)
        try:
            client = TestClient(app)

            # With caching the endpoint returns the inputs plus the cached output.
            response = client.post("/api/load_example/", json={"data": [0]})
            assert response.json()["data"] == ["Hello,", "World", "Hello, World"]

            response = client.post("/api/load_example/", json={"data": [1]})
            assert response.json()["data"] == ["Michael", "Jordan", "Michael Jordan"]
        finally:
            # Stop the background server even if an assertion failed.
            demo.close()
|
|
|
|
def test_multiple_file_flagging(tmp_path):
    """Caches an example whose fn returns both input paths into a ``gr.Files`` output."""
    with patch("gradio.utils.get_cache_folder", return_value=tmp_path):
        frames = [
            gr.Image(type="filepath", label="frame 1"),
            gr.Image(type="filepath", label="frame 2"),
        ]
        interface = gr.Interface(
            fn=lambda *x: list(x),
            inputs=frames,
            outputs=[gr.Files()],
            examples=[["test/test_files/cheetah1.jpg", "test/test_files/bus.png"]],
            cache_examples=True,
        )
        cached_files = interface.examples_handler.load_from_cache(0)[0].root

        assert len(cached_files) == 2
        assert all(isinstance(item, gr.FileData) for item in cached_files)
|
|
|
|
def test_examples_keep_all_suffixes(tmp_path):
    """Cached example files keep multi-dot names such as ``foo.bar.txt`` intact.

    Two files with the same name but different directories and contents are
    used, so each cached copy must also keep its own contents.
    """
    with patch("gradio.utils.get_cache_folder", return_value=Path(tempfile.mkdtemp())):
        top_level = tmp_path / "foo.bar.txt"
        top_level.write_text("file 1")
        nested_dir = tmp_path / "file_2"
        nested_dir.mkdir(parents=True)
        nested = nested_dir / "foo.bar.txt"
        nested.write_text("file 2")

        interface = gr.Interface(
            fn=lambda x: x.name,
            inputs=gr.File(),
            outputs=[gr.File()],
            examples=[[str(top_level)], [str(nested)]],
            cache_examples=True,
        )

        for index, expected_text in enumerate(["file 1", "file 2"]):
            cached = interface.examples_handler.load_from_cache(index)[0]
            assert Path(cached.path).read_text() == expected_text
            assert cached.orig_name == "foo.bar.txt"
            assert cached.path.endswith("foo.bar.txt")
|
|
|
|
def test_make_waveform_with_spaces_in_filename():
    """make_waveform accepts an audio path containing spaces and yields a 1000x400 mp4.

    The video resolution is read back with ``ffprobe``. When ``ffprobe`` is not
    on PATH that check is skipped explicitly. When ``ffprobe`` runs but fails on
    the generated file, the test fails: previously the error was only printed,
    so a broken video still passed.
    """
    with tempfile.TemporaryDirectory() as tmpdirname:
        audio = os.path.join(tmpdirname, "test audio.wav")
        shutil.copy("test/test_files/audio_sample.wav", audio)
        waveform = gr.make_waveform(audio)
        assert waveform.endswith(".mp4")

        if shutil.which("ffprobe") is None:
            pytest.skip("ffprobe is not available to check the waveform resolution")

        command = [
            "ffprobe",
            "-v",
            "error",
            "-select_streams",
            "v:0",
            "-show_entries",
            "stream=width,height",
            "-of",
            "json",
            waveform,
        ]
        try:
            result = subprocess.run(command, capture_output=True, text=True, check=True)
        except subprocess.CalledProcessError as e:
            # pytest.fail raises, so `result` is always bound below.
            pytest.fail(
                f"Error retrieving resolution of output waveform video: {e}\n{e.stderr}"
            )

        stream = json.loads(result.stdout)["streams"][0]
        assert stream["width"] == 1000
        assert stream["height"] == 400
|
|
|
|
def test_make_waveform_raises_if_ffmpeg_fails(tmp_path, monkeypatch):
    """make_waveform must raise, not return a bogus path, when ffmpeg fails.

    ``subprocess.call`` is replaced with a stub that always raises
    ``CalledProcessError``, simulating a failing ffmpeg invocation.
    """
    audio_path = tmp_path / "test audio.wav"
    shutil.copy("test/test_files/audio_sample.wav", audio_path)

    def broken_call(*args, **kwargs):
        raise subprocess.CalledProcessError(1, "ffmpeg")

    monkeypatch.setattr(subprocess, "call", broken_call)

    with pytest.raises(Exception):
        gr.make_waveform(str(audio_path))
|
|
|
|
class TestProgressBar:
    """End-to-end checks of progress and alert updates reported to a client.

    Each test launches a queued demo in the background, submits one job with
    ``gradio_client`` and polls ``job.status()`` until the job is done,
    recording each distinct update. Demos are closed in ``finally`` so a
    failure does not leave a server running for later tests.
    """

    @staticmethod
    def _collect_updates(job, extract, skip=None):
        """Poll *job* until done and return the de-duplicated updates seen.

        ``extract`` maps a status object to the value of interest. Values equal
        to ``skip`` are ignored, and a value equal to the last recorded one is
        not recorded again. Polls every 0.05 s.
        """
        updates = []
        while not job.done():
            update = extract(job.status())
            if update != skip and (not updates or updates[-1] != update):
                updates.append(update)
            time.sleep(0.05)
        return updates

    @staticmethod
    def _first_progress(status, attr):
        """Return ``(index, <attr>)`` of the first progress entry, or ``(None, None)``."""
        if not status.progress_data:
            return (None, None)
        first = status.progress_data[0]
        return (first.index, getattr(first, attr))

    @pytest.mark.asyncio
    async def test_progress_bar(self):
        with gr.Blocks() as demo:
            name = gr.Textbox()
            greeting = gr.Textbox()
            button = gr.Button(value="Greet")

            def greet(s, prog=gr.Progress()):
                prog(0, desc="start")
                time.sleep(0.15)
                for _ in prog.tqdm(range(4), unit="iter"):
                    time.sleep(0.15)
                time.sleep(0.15)
                # Without track_tqdm, this plain tqdm loop is not expected to
                # appear in the reported updates (see the assertion below).
                for _ in tqdm(["a", "b", "c"], desc="alphabet"):
                    time.sleep(0.15)
                return f"Hello, {s}!"

            button.click(greet, name, greeting)
        demo.queue(max_size=1).launch(prevent_thread_lock=True)

        try:
            client = grc.Client(demo.local_url)
            job = client.submit("Gradio")
            status_updates = self._collect_updates(
                job,
                lambda status: self._first_progress(status, "desc"),
                skip=(None, None),
            )
        finally:
            demo.close()

        assert status_updates == [
            (None, "start"),
            (0, None),
            (1, None),
            (2, None),
            (3, None),
            (4, None),
        ]

    @pytest.mark.asyncio
    async def test_progress_bar_track_tqdm(self):
        with gr.Blocks() as demo:
            name = gr.Textbox()
            greeting = gr.Textbox()
            button = gr.Button(value="Greet")

            def greet(s, prog=gr.Progress(track_tqdm=True)):
                prog(0, desc="start")
                time.sleep(0.15)
                for _ in prog.tqdm(range(4), unit="iter"):
                    time.sleep(0.15)
                time.sleep(0.15)
                # With track_tqdm=True the plain tqdm loop is reported as well.
                for _ in tqdm(["a", "b", "c"], desc="alphabet"):
                    time.sleep(0.15)
                return f"Hello, {s}!"

            button.click(greet, name, greeting)
        demo.queue(max_size=1).launch(prevent_thread_lock=True)

        try:
            client = grc.Client(demo.local_url)
            job = client.submit("Gradio")
            status_updates = self._collect_updates(
                job,
                lambda status: self._first_progress(status, "desc"),
                skip=(None, None),
            )
        finally:
            demo.close()

        assert status_updates == [
            (None, "start"),
            (0, None),
            (1, None),
            (2, None),
            (3, None),
            (4, None),
            (0, "alphabet"),
            (1, "alphabet"),
            (2, "alphabet"),
        ]

    @pytest.mark.asyncio
    async def test_progress_bar_track_tqdm_without_iterable(self):
        def greet(s, _=gr.Progress(track_tqdm=True)):
            # Manually-updated tqdm bar (total only, no iterable).
            with tqdm(total=len(s)) as progress_bar:
                for _c in s:
                    progress_bar.update()
                    time.sleep(0.15)
            return f"Hello, {s}!"

        demo = gr.Interface(greet, "text", "text")
        demo.queue().launch(prevent_thread_lock=True)

        try:
            client = grc.Client(demo.local_url)
            job = client.submit("Gradio")
            status_updates = self._collect_updates(
                job,
                lambda status: self._first_progress(status, "unit"),
                skip=(None, None),
            )
        finally:
            demo.close()

        assert status_updates == [
            (1, "steps"),
            (2, "steps"),
            (3, "steps"),
            (4, "steps"),
            (5, "steps"),
            (6, "steps"),
        ]

    @pytest.mark.asyncio
    async def test_info_and_warning_alerts(self):
        def greet(s):
            for _c in s:
                gr.Info(f"Letter {_c}")
                time.sleep(0.15)
            if len(s) < 5:
                gr.Warning("Too short!")
                time.sleep(0.15)
            return f"Hello, {s}!"

        demo = gr.Interface(greet, "text", "text")
        demo.queue().launch(prevent_thread_lock=True)

        try:
            client = grc.Client(demo.local_url)
            job = client.submit("Jon")
            status_updates = self._collect_updates(job, lambda status: status.log)
        finally:
            demo.close()

        assert status_updates == [
            ("Letter J", "info"),
            ("Letter o", "info"),
            ("Letter n", "info"),
            ("Too short!", "warning"),
        ]
|
|
|
|
@pytest.mark.asyncio
@pytest.mark.parametrize("async_handler", [True, False])
async def test_info_isolation(async_handler: bool):
    """gr.Info messages from concurrently running jobs stay with their own job.

    Two sessions are gathered on the event loop, Bob's starting about one
    second after Alice's. With ``concurrency_limit=2`` and 3-second handlers,
    their jobs overlap in time. Each session's last log message must be its
    own greeting.
    """

    async def greet_async(name):
        await asyncio.sleep(2)
        gr.Info(f"Hello {name}")
        await asyncio.sleep(1)
        return name

    def greet_sync(name):
        time.sleep(2)
        gr.Info(f"Hello {name}")
        time.sleep(1)
        return name

    demo = gr.Interface(
        greet_async if async_handler else greet_sync,
        "text",
        "text",
        concurrency_limit=2,
    )
    demo.launch(prevent_thread_lock=True)

    async def session_interaction(name, delay=0):
        # Await instead of time.sleep: a blocking sleep never yields to the
        # event loop, which made asyncio.gather run the sessions back to back
        # and left `delay` unused.
        await asyncio.sleep(delay)
        client = grc.Client(demo.local_url)
        job = client.submit(name)

        status_updates = []
        while not job.done():
            status = job.status()
            update = status.log
            if update is not None and (
                len(status_updates) == 0 or status_updates[-1] != update
            ):
                status_updates.append(update)
            await asyncio.sleep(0.05)
        return status_updates[-1][0] if status_updates else None

    try:
        alice_logs, bob_logs = await asyncio.gather(
            session_interaction("Alice"),
            session_interaction("Bob", delay=1),
        )
    finally:
        # Stop the background server even if a session raised.
        demo.close()

    assert alice_logs == "Hello Alice"
    assert bob_logs == "Hello Bob"
|
|