|
import inspect |
|
import random |
|
import time |
|
|
|
import gradio as gr |
|
import pytest |
|
from pydub import AudioSegment |
|
|
|
|
|
def pytest_configure(config):
    """Register the custom ``flaky`` marker with pytest.

    Args:
        config: Object exposing ``addinivalue_line(name, line)``; a single
            line is appended to its ``markers`` ini value.
    """
    # The previous description was cut off mid-word ("...will not cause te");
    # spell the sentence out so the marker listing is readable.
    config.addinivalue_line(
        "markers",
        "flaky: mark test as flaky. Failure will not cause test suite to fail.",
    )
|
|
|
|
|
@pytest.fixture
def calculator_demo():
    """Return a ``gr.Interface`` wrapping a four-operation calculator.

    The interface takes a number, a radio choice of operation and a second
    number, outputs a number, and ships with four example rows.
    """

    def calculator(num1, operation, num2):
        # Each recognised operation returns immediately.
        if operation == "add":
            return num1 + num2
        if operation == "subtract":
            return num1 - num2
        if operation == "multiply":
            return num1 * num2
        if operation == "divide":
            if num2 == 0:
                raise gr.Error("Cannot divide by zero!")
            return num1 / num2
        # Any other operation value produces None.
        return None

    operation_choices = ["add", "subtract", "multiply", "divide"]
    example_rows = [
        [5, "add", 3],
        [4, "divide", 2],
        [-4, "multiply", 2.5],
        [0, "subtract", 1.2],
    ]

    return gr.Interface(
        fn=calculator,
        inputs=["number", gr.Radio(operation_choices), "number"],
        outputs="number",
        examples=example_rows,
    )
|
|
|
|
|
@pytest.fixture
def calculator_demo_with_defaults():
    """Return a calculator ``gr.Interface`` whose function has default arguments.

    ``operation`` defaults to ``None`` (handled like ``"add"``) and ``num2``
    defaults to ``100``; the first number input is pre-filled with ``10``.
    """

    def calculator(num1, operation=None, num2=100):
        # A missing operation is handled the same way as "add".
        if operation == "add" or operation is None:
            return num1 + num2
        if operation == "subtract":
            return num1 - num2
        if operation == "multiply":
            return num1 * num2
        if operation == "divide":
            if num2 == 0:
                raise gr.Error("Cannot divide by zero!")
            return num1 / num2
        # Any other operation value produces None.
        return None

    input_components = [
        gr.Number(value=10),
        gr.Radio(["add", "subtract", "multiply", "divide"]),
        gr.Number(),
    ]
    example_rows = [
        [5, "add", 3],
        [4, "divide", 2],
        [-4, "multiply", 2.5],
        [0, "subtract", 1.2],
    ]

    return gr.Interface(
        fn=calculator,
        inputs=input_components,
        outputs="number",
        examples=example_rows,
    )
|
|
|
|
|
@pytest.fixture
def state_demo():
    """Return a ``gr.Interface`` that echoes a textbox and a ``gr.State`` value.

    The state component is created with a ``delete_callback`` that prints
    ``STATE DELETED``; the same component is used as both input and output.
    """
    session_state = gr.State(delete_callback=lambda x: print("STATE DELETED"))
    return gr.Interface(
        fn=lambda x, y: (x, y),
        inputs=["textbox", session_state],
        outputs=["textbox", session_state],
    )
|
|
|
|
|
@pytest.fixture
def increment_demo():
    """Return a Blocks demo with three buttons that each bump a shared counter.

    Every handler adds one to the ``gr.State`` value and writes the new value to
    both the state and the number component. The buttons differ only in how
    they are registered: a named endpoint, a named endpoint with
    ``queue=False``, and an event with ``api_name=False``.
    """
    with gr.Blocks() as demo:
        queued_btn = gr.Button("Increment")
        unqueued_btn = gr.Button("Increment")
        hidden_btn = gr.Button("Increment")
        number_box = gr.Number()

        counter = gr.State(0)

        queued_btn.click(
            fn=lambda x: (x + 1, x + 1),
            inputs=counter,
            outputs=[counter, number_box],
            api_name="increment_with_queue",
        )
        unqueued_btn.click(
            fn=lambda x: (x + 1, x + 1),
            inputs=counter,
            outputs=[counter, number_box],
            queue=False,
            api_name="increment_without_queue",
        )
        hidden_btn.click(
            fn=lambda x: (x + 1, x + 1),
            inputs=counter,
            outputs=[counter, number_box],
            api_name=False,
        )

    return demo
|
|
|
|
|
@pytest.fixture
def progress_demo():
    """Return a textbox-to-textbox ``gr.Interface`` whose function reports progress.

    The function announces a start message, iterates 20 steps through
    ``progress.tqdm`` sleeping 0.1 s per step, then returns its input unchanged.
    """

    def my_function(x, progress=gr.Progress()):
        progress(0, desc="Starting...")
        steps = range(20)
        for _ in progress.tqdm(steps):
            time.sleep(0.1)
        return x

    return gr.Interface(fn=my_function, inputs=gr.Textbox(), outputs=gr.Textbox())
|
|
|
|
|
@pytest.fixture
def yield_demo():
    """Return a textbox-to-textbox ``gr.Interface`` backed by a generator.

    The generator yields successively longer prefixes of its input, pausing
    0.5 s before each one.
    """

    def spell(x):
        # Prefix lengths run from 0 to len(x) - 1, so the first yield is the
        # empty prefix and the complete input itself is never yielded.
        for prefix_len in range(len(x)):
            time.sleep(0.5)
            yield x[:prefix_len]

    return gr.Interface(fn=spell, inputs="textbox", outputs="textbox")
|
|
|
|
|
@pytest.fixture
def cancel_from_client_demo():
    """Return a Blocks demo with a slow generator endpoint and a slow plain one.

    ``iterate`` yields 0..19, printing each value and sleeping 0.5 s after each
    yield; ``long`` sleeps 10 s, prints ``DONE!`` and returns 10. Both write to
    the same number component.
    """

    def iteration():
        for step in range(20):
            print(f"i: {step}")
            yield step
            time.sleep(0.5)

    def long_process():
        time.sleep(10)
        print("DONE!")
        return 10

    with gr.Blocks() as demo:
        result = gr.Number()

        iterate_btn = gr.Button(value="Iterate")
        iterate_btn.click(fn=iteration, inputs=None, outputs=result, api_name="iterate")
        long_btn = gr.Button(value="Long Process")
        long_btn.click(fn=long_process, inputs=None, outputs=result, api_name="long")

    return demo
|
|
|
|
|
@pytest.fixture
def sentiment_classification_demo():
    """Return a themed Blocks demo with a fake classifier and a sleep endpoint.

    ``classify`` sleeps 1 s and returns a random score for each of three
    sentiment labels; ``sleep`` sleeps 10 s and returns 2.
    """

    def classifier(text):
        time.sleep(1)
        sentiments = ("POSITIVE", "NEGATIVE", "NEUTRAL")
        # One random score per label, drawn in label order.
        return {sentiment: random.random() for sentiment in sentiments}

    def sleep_for_test():
        time.sleep(10)
        return 2

    with gr.Blocks(theme="gstaff/xkcd") as demo:
        with gr.Row():
            with gr.Column():
                text_input = gr.Textbox(label="Input Text")
                with gr.Row():
                    classify_btn = gr.Button("Classify Sentiment")
            with gr.Column():
                prediction = gr.Label(label="Predicted Sentiment")
            sleep_result = gr.Number()
            sleep_btn = gr.Button("Sleep then print")
        classify_btn.click(
            fn=classifier, inputs=text_input, outputs=prediction, api_name="classify"
        )
        sleep_btn.click(
            fn=sleep_for_test, inputs=None, outputs=sleep_result, api_name="sleep"
        )

    return demo
|
|
|
|
|
@pytest.fixture
def count_generator_demo():
    """Return a Blocks demo with a counting generator and a list-rendering handler.

    ``count`` yields 0..n-1 with a 0.5 s pause before each value; ``show``
    returns the string form of ``list(range(n))``. Neither event is given an
    explicit ``api_name``.
    """

    def count(n):
        upper = int(n)
        for value in range(upper):
            time.sleep(0.5)
            yield value

    def show(n):
        values = list(range(int(n)))
        return str(values)

    with gr.Blocks() as demo:
        with gr.Column():
            limit_box = gr.Number(value=10)
            with gr.Row():
                count_btn = gr.Button("Count")
                list_btn = gr.Button("List")
        with gr.Column():
            output_box = gr.Textbox()

        count_btn.click(fn=count, inputs=limit_box, outputs=output_box)
        list_btn.click(fn=show, inputs=limit_box, outputs=output_box)

    return demo
|
|
|
|
|
@pytest.fixture
def count_generator_no_api():
    """Return the counting Blocks demo with both events registered as ``api_name=False``.

    ``count`` yields 0..n-1 with a 0.5 s pause before each value; ``show``
    returns the string form of ``list(range(n))``.
    """

    def count(n):
        upper = int(n)
        for value in range(upper):
            time.sleep(0.5)
            yield value

    def show(n):
        values = list(range(int(n)))
        return str(values)

    with gr.Blocks() as demo:
        with gr.Column():
            limit_box = gr.Number(value=10)
            with gr.Row():
                count_btn = gr.Button("Count")
                list_btn = gr.Button("List")
        with gr.Column():
            output_box = gr.Textbox()

        count_btn.click(fn=count, inputs=limit_box, outputs=output_box, api_name=False)
        list_btn.click(fn=show, inputs=limit_box, outputs=output_box, api_name=False)

    return demo
|
|
|
|
|
@pytest.fixture
def count_generator_demo_exception():
    """Return a Blocks demo whose counting generator fails part-way through.

    ``count`` yields 0..4 (0.01 s pause before each) and raises
    ``ValueError("Oh no!")`` when it reaches 5. ``show`` is defined but not
    wired to any event.
    """

    def count(n):
        upper = int(n)
        for value in range(upper):
            time.sleep(0.01)
            if value == 5:
                raise ValueError("Oh no!")
            yield value

    def show(n):
        values = list(range(int(n)))
        return str(values)

    with gr.Blocks() as demo:
        with gr.Column():
            limit_box = gr.Number(value=10)
            with gr.Row():
                count_btn = gr.Button("Count")
        with gr.Column():
            output_box = gr.Textbox()

        count_btn.click(fn=count, inputs=limit_box, outputs=output_box, api_name="count")
    return demo
|
|
|
|
|
@pytest.fixture
def file_io_demo():
    """Return a ``gr.Interface`` with a multi-file and a single-file component on each side.

    The wrapped function only prints ``foox``.
    """
    # NOTE(review): the lambda accepts a single argument although two inputs
    # are declared -- confirm this fixture is never actually invoked.
    file_inputs = [gr.File(file_count="multiple"), "file"]
    file_outputs = [gr.File(file_count="multiple"), "file"]

    return gr.Interface(
        fn=lambda _: print("foox"),
        inputs=file_inputs,
        outputs=file_outputs,
    )
|
|
|
|
|
@pytest.fixture
def stateful_chatbot():
    """Return a chatbot Blocks demo whose submit handler also receives a ``gr.State``.

    The state is initialised to ``[1, 2, 3]`` and the handler asserts it still
    holds those values before appending a fixed reply to the chat history.
    """
    with gr.Blocks() as demo:
        chat_window = gr.Chatbot()
        message_box = gr.Textbox()
        clear_btn = gr.Button("Clear")
        initial_state = gr.State([1, 2, 3])

        def respond(message, st, chat_history):
            assert st[0] == 1 and st[1] == 2 and st[2] == 3
            bot_message = "I love you"
            exchange = (message, bot_message)
            chat_history.append(exchange)
            # Empty string clears the textbox; the history refreshes the chatbot.
            return "", chat_history

        message_box.submit(
            fn=respond,
            inputs=[message_box, initial_state, chat_window],
            outputs=[message_box, chat_window],
            api_name="submit",
        )
        clear_btn.click(fn=lambda: None, inputs=None, outputs=chat_window, queue=False)
    return demo
|
|
|
|
|
@pytest.fixture
def hello_world_with_group():
    """Return a greeting Blocks demo whose handlers also update a ``gr.Group``.

    The group starts hidden; ``greeting`` returns the greeting text together
    with a visible group, and ``show_group`` returns a group update only.
    """
    with gr.Blocks() as demo:
        name_box = gr.Textbox(label="name")
        greeting_box = gr.Textbox(label="greeting")
        greet_btn = gr.Button("Greet")
        show_group_btn = gr.Button("Show group")
        with gr.Group(visible=False) as hidden_group:
            gr.Textbox("Hello!")

        def greeting(name):
            return f"Hello {name}", gr.Group(visible=True)

        greet_btn.click(
            fn=greeting,
            inputs=[name_box],
            outputs=[greeting_box, hidden_group],
            api_name="greeting",
        )
        # NOTE(review): despite the button label, this handler returns
        # visible=False -- confirm that is intended.
        show_group_btn.click(
            fn=lambda: gr.Group(visible=False),
            inputs=None,
            outputs=hidden_group,
            api_name="show_group",
        )
    return demo
|
|
|
|
|
@pytest.fixture
def hello_world_with_state_and_accordion():
    """Return a greeting Blocks demo that counts clicks and toggles an accordion.

    Each of the three endpoints (``greeting``, ``open``, ``close``) increments
    the ``gr.State`` counter, mirrors it into the number component and returns
    an accordion update.
    """
    with gr.Blocks() as demo:
        with gr.Row():
            name_box = gr.Textbox(label="name")
            greeting_box = gr.Textbox(label="greeting")
            count_box = gr.Number(label="count")
        with gr.Row():
            click_count = gr.State(value=0)
            greet_btn = gr.Button("Greet")
            open_btn = gr.Button("Open acc")
            close_btn = gr.Button("Close acc")
        with gr.Accordion(label="Extra stuff", open=False) as extras:
            gr.Textbox("Hello!")

        def greeting(name, state):
            state += 1
            # The incremented count goes to both the state and the number box.
            return state, f"Hello {name}", state, gr.Accordion(open=False)

        greet_btn.click(
            fn=greeting,
            inputs=[name_box, click_count],
            outputs=[click_count, greeting_box, count_box, extras],
            api_name="greeting",
        )
        open_btn.click(
            fn=lambda state: (state + 1, state + 1, gr.Accordion(open=True)),
            inputs=[click_count],
            outputs=[click_count, count_box, extras],
            api_name="open",
        )
        close_btn.click(
            fn=lambda state: (state + 1, state + 1, gr.Accordion(open=False)),
            inputs=[click_count],
            outputs=[click_count, count_box, extras],
            api_name="close",
        )
    return demo
|
|
|
|
|
@pytest.fixture
def stream_audio():
    """Return a ``gr.Interface`` that streams an audio file back in WAV chunks.

    The input file is loaded with ``AudioSegment.from_mp3`` and sliced in steps
    of 3000; each non-empty slice is exported as ``<n>.wav`` (n starting at 1)
    in the system temp directory and its path is yielded.
    """
    import pathlib
    import tempfile

    def _stream_audio(audio_file):
        audio = AudioSegment.from_mp3(audio_file)
        chunk_size = 3000

        # Chunk numbers start at 1 because the file name uses the counter
        # value after it has been advanced.
        for chunk_no, start in enumerate(range(0, len(audio), chunk_size), start=1):
            chunk = audio[start : start + chunk_size]
            if not chunk:
                continue
            file = str(pathlib.Path(tempfile.gettempdir()) / f"{chunk_no}.wav")
            chunk.export(file, format="wav")
            yield file

    return gr.Interface(
        fn=_stream_audio,
        inputs=gr.Audio(type="filepath", label="Audio file to stream"),
        outputs=gr.Audio(autoplay=True, streaming=True),
    )
|
|
|
|
|
@pytest.fixture
def video_component():
    """Return an identity ``gr.Interface`` with a ``gr.Video`` input and output."""
    video_in = gr.Video()
    video_out = gr.Video()
    return gr.Interface(fn=lambda x: x, inputs=video_in, outputs=video_out)
|
|
|
|
|
@pytest.fixture
def all_components():
    """Return the ``gr.components.Component`` descendants that accept ``value``.

    Walks the subclass tree from ``Component`` and keeps each class whose
    call signature has a ``value`` parameter, that is not ``Component``
    itself, and whose ``is_template`` attribute is missing or falsy.
    """
    base_component = gr.components.Component
    pending = base_component.__subclasses__()
    matching = []

    while pending:
        candidate = pending.pop()
        # Queue descendants so the whole tree is visited (no-op when empty).
        pending.extend(candidate.__subclasses__())

        if "value" not in inspect.signature(candidate).parameters:
            continue
        if candidate == base_component:
            continue
        if getattr(candidate, "is_template", False):
            continue
        matching.append(candidate)

    return matching
|
|
|
|
|
@pytest.fixture(autouse=True)
def gradio_temp_dir(monkeypatch, tmp_path):
    """Point the ``GRADIO_TEMP_DIR`` environment variable at ``tmp_path``.

    ``tmp_path`` is unique to each test function and, per the pytest docs, is
    cleaned up automatically:
    https://docs.pytest.org/en/6.2.x/reference.html#tmp-path

    Returns the same ``tmp_path`` it was given.
    """
    temp_dir = tmp_path
    monkeypatch.setenv(name="GRADIO_TEMP_DIR", value=str(temp_dir))
    return temp_dir
|
|
|
|
|
@pytest.fixture
def long_response_with_info():
    """Return a ``gr.Interface`` whose function is slow and returns a very long string.

    The function emits a ``gr.Info`` message, sleeps 17 s, emits a second
    message and returns a 4-character pattern repeated 90000 times.
    """

    def long_response(_):
        gr.Info("Beginning long response")
        time.sleep(17)
        gr.Info("Done!")
        repeated_pattern = "\ta\nb"
        return repeated_pattern * 90000

    return gr.Interface(
        fn=long_response,
        inputs=None,
        outputs=gr.Textbox(label="Output"),
    )
|
|
|
|
|
@pytest.fixture
def many_endpoint_demo():
    """Return a Blocks demo with 1000 textbox/button pairs, two events per pair.

    Every event uses the same identity handler, with the pair's textbox as
    both input and output.
    """
    with gr.Blocks() as demo:

        def noop(x):
            return x

        n_elements = 1000
        for _ in range(n_elements):
            textbox = gr.Textbox()
            textbox.submit(fn=noop, inputs=textbox, outputs=textbox)
            button = gr.Button()
            button.click(fn=noop, inputs=textbox, outputs=textbox)

    return demo
|
|
|
|
|
@pytest.fixture
def max_file_size_demo():
    """Return a Blocks demo with one file upload event named ``upload_1b``.

    The upload handler ignores the file and writes ``Upload successful`` to the
    status textbox.
    """
    with gr.Blocks() as demo:
        uploader = gr.File()
        status_box = gr.Textbox()

        uploader.upload(
            fn=lambda x: "Upload successful",
            inputs=uploader,
            outputs=status_box,
            api_name="upload_1b",
        )

    return demo
|
|
|
|
|
@pytest.fixture
def chatbot_message_format():
    """Return a Blocks chat demo that uses ``gr.Chatbot(type="messages")``.

    The ``chat`` endpoint appends the user message and a randomly chosen canned
    reply to the history as role/content dictionaries and clears the textbox.
    """
    with gr.Blocks() as demo:
        chat_window = gr.Chatbot(type="messages")
        message_box = gr.Textbox()

        def respond(message, chat_history: list):
            canned_replies = ["How are you?", "I love you", "I'm very hungry"]
            bot_message = random.choice(canned_replies)
            user_turn = {"role": "user", "content": message}
            assistant_turn = {"role": "assistant", "content": bot_message}
            chat_history.extend([user_turn, assistant_turn])
            # Empty string clears the textbox; the history refreshes the chatbot.
            return "", chat_history

        message_box.submit(
            fn=respond,
            inputs=[message_box, chat_window],
            outputs=[message_box, chat_window],
            api_name="chat",
        )

    return demo
|
|