# spark12x's picture
# Update app.py
# 1aee3db verified
import gradio as gr
from llm_inference import LLMInferenceNode
import random
from PIL import Image
import io
# Static HTML header shown at the top of the interface (passed to gr.HTML).
# Built line by line; the trailing empty entry keeps the final newline.
title = "\n".join(
    [
        '<h1 align="center">SD 3.5 Prompt Generator</h1>',
        "<p><center>",
        '<a href="https://x.com/gokayfem" target="_blank">[X gokaygokay]</a>',
        '<a href="https://github.com/gokayfem" target="_blank">[Github gokayfem]</a>',
        '<p align="center">Generate random prompts using powerful LLMs from Hugging Face and SambaNova.</p>',
        "</center></p>",
        "",
    ]
)
def create_interface():
    """Build the Gradio Blocks UI for the prompt generator and return it.

    The interface offers an optional custom input, a prompt-type selector,
    LLM options (long talk, compression, base prompt, provider, model) and a
    single button that first builds a prompt with ``LLMInferenceNode`` and
    then sends it to the selected LLM.

    Returns:
        The assembled ``gr.Blocks`` instance; the caller is responsible for
        launching it.
    """
    llm_node = LLMInferenceNode()

    prompt_types = ["Random", "Long", "Short", "Medium", "OnlyObjects", "NoFigure", "Landscape", "Fantasy"]
    # Every type that "Random" may resolve to.
    concrete_prompt_types = [t for t in prompt_types if t != "Random"]

    # Single source of truth for the models offered per provider, used both
    # for the initial dropdown and when the provider changes.
    provider_models = {
        "Hugging Face": [
            "Qwen/Qwen2.5-72B-Instruct",
            "meta-llama/Meta-Llama-3.1-70B-Instruct",
            "mistralai/Mixtral-8x7B-Instruct-v0.1",
            "mistralai/Mistral-7B-Instruct-v0.3",
        ],
        "SambaNova": [
            "Meta-Llama-3.1-70B-Instruct",
            "Meta-Llama-3.1-405B-Instruct",
            "Meta-Llama-3.1-8B-Instruct",
        ],
    }
    default_provider = "Hugging Face"

    def resolve_prompt_type(value):
        """Return ``value``, or a randomly chosen concrete type for "Random"."""
        if value != "Random":
            return value
        chosen = random.choice(concrete_prompt_types)
        print(f"Random prompt type selected: {chosen}")
        return chosen

    # NOTE(review): the nesting below was reconstructed from the order of the
    # `with` statements; in particular the provider / API key / model widgets
    # are assumed to live inside the accordion - confirm against the layout
    # that was originally intended.
    with gr.Blocks(theme='bethecloud/storj_theme') as demo:
        gr.HTML(title)

        with gr.Row():
            with gr.Column(scale=2):
                custom = gr.Textbox(label="Custom Input Prompt (optional)", lines=3)
                prompt_type = gr.Dropdown(
                    choices=prompt_types,
                    label="Prompt Type",
                    value="Random",
                    interactive=True,
                )
                # Holds the concrete type picked the last time the dropdown changed.
                prompt_type_state = gr.State("Random")

                def update_prompt_type(value, state):
                    """Keep the dropdown value and store the resolved type.

                    ``state`` is the previous stored type; it is received
                    because it is wired as an input but is not consulted.
                    Returns ``(dropdown value, resolved prompt type)``.
                    """
                    if value == "Random":
                        return value, resolve_prompt_type(value)
                    print(f"Updated prompt type: {value}")
                    return value, value

                prompt_type.change(
                    update_prompt_type,
                    inputs=[prompt_type, prompt_type_state],
                    outputs=[prompt_type, prompt_type_state],
                )

            with gr.Column(scale=2):
                with gr.Accordion("LLM Prompt Generation", open=False):
                    long_talk = gr.Checkbox(label="Long Talk", value=True)
                    compress = gr.Checkbox(label="Compress", value=True)
                    compression_level = gr.Dropdown(
                        choices=["soft", "medium", "hard"],
                        label="Compression Level",
                        value="hard",
                    )
                    custom_base_prompt = gr.Textbox(label="Custom Base Prompt", lines=5)

                    # LLM provider selection
                    llm_provider = gr.Dropdown(
                        choices=list(provider_models),
                        label="LLM Provider",
                        value=default_provider,
                    )
                    # Hidden: no API key is required for the selectable providers.
                    api_key = gr.Textbox(label="API Key", type="password", visible=False)
                    model = gr.Dropdown(
                        label="Model",
                        choices=provider_models[default_provider],
                        value=provider_models[default_provider][0],
                    )

        with gr.Row():
            # Single button that generates both the prompt and the LLM text.
            generate_button = gr.Button("Generate Prompt")

        with gr.Row():
            text_output = gr.Textbox(label="LLM Generated Text", lines=10, show_copy_button=True)

        def update_model_choices(provider):
            """Return a model dropdown listing the models of ``provider``."""
            models = provider_models.get(provider, [])
            # Fall back to no selection (not "") when the provider is unknown,
            # so the value is never something outside `choices`.
            return gr.Dropdown(choices=models, value=models[0] if models else None)

        def update_api_key_visibility(provider):
            """Keep the API key box hidden regardless of ``provider``."""
            return gr.update(visible=False)  # No API key required for selected providers

        llm_provider.change(
            update_model_choices,
            inputs=[llm_provider],
            outputs=[model],
        )
        llm_provider.change(
            update_api_key_visibility,
            inputs=[llm_provider],
            outputs=[api_key],
        )

        def generate_random_prompt_with_llm(custom_input, prompt_type, long_talk, compress, compression_level, custom_base_prompt, provider, api_key, model_selected, prompt_type_state):
            """Generate a prompt, then expand it with the selected LLM.

            Args:
                custom_input: Optional user text; when blank a generic
                    instruction mentioning the prompt type is used instead.
                prompt_type: Selected type; "Random" is resolved to a
                    concrete type on every call.
                long_talk, compress, compression_level, custom_base_prompt,
                provider, api_key, model_selected: Forwarded unchanged to
                    ``llm_node.generate``.
                prompt_type_state: Stored type from the dropdown handler;
                    accepted because it is part of the wired inputs (and of
                    the exposed API endpoint) but not used here.

            Returns:
                The text returned by the LLM, or an error message string if
                any step raised an exception.
            """
            # Broad catch is deliberate: this is the UI boundary and the error
            # is surfaced to the user in the output textbox.
            try:
                # Step 1: generate the prompt.
                dynamic_seed = random.randint(0, 1000000)
                prompt_type = resolve_prompt_type(prompt_type)

                if custom_input and custom_input.strip():
                    prompt = llm_node.generate_prompt(dynamic_seed, prompt_type, custom_input)
                    print("Using Custom Input Prompt.")
                else:
                    prompt = llm_node.generate_prompt(dynamic_seed, prompt_type, f"Create a random prompt based on the '{prompt_type}' type.")
                    print(f"No Custom Input Prompt provided. Generated prompt based on prompt_type: {prompt_type}")

                print(f"Generated Prompt: {prompt}")

                # Step 2: generate text with the LLM.
                poster = False  # Set a default value or modify as needed
                result = llm_node.generate(
                    input_text=prompt,
                    long_talk=long_talk,
                    compress=compress,
                    compression_level=compression_level,
                    poster=poster,
                    prompt_type=prompt_type,  # the resolved (non-"Random") type
                    custom_base_prompt=custom_base_prompt,
                    provider=provider,
                    api_key=api_key,
                    model=model_selected,
                )
                print(f"Generated Text: {result}")
                return result
            except Exception as e:
                print(f"An error occurred: {e}")
                return f"Error occurred while processing the request: {str(e)}"

        # Wire the unified generator to the single button.
        generate_button.click(
            generate_random_prompt_with_llm,
            inputs=[custom, prompt_type, long_talk, compress, compression_level, custom_base_prompt, llm_provider, api_key, model, prompt_type_state],
            outputs=[text_output],
            api_name="generate_random_prompt_with_llm",
        )

    return demo
if __name__ == "__main__":
    # Script entry point: build the interface and serve it. `share=True`
    # is passed to launch() exactly as before.
    demo = create_interface()
    demo.launch(share=True)