Spaces:
Runtime error
Runtime error
import os | |
import string | |
import copy | |
import gradio as gr | |
import PIL.Image | |
import torch | |
from transformers import BitsAndBytesConfig, pipeline | |
import re | |
import time | |
import random | |
# Page title rendered as Markdown. The emoji (U+1F30B volcano, U+1F4AA flexed
# biceps) are written as escapes so a mis-decoded copy of this file cannot
# garble them again.
DESCRIPTION = "# LLaVA \U0001F30B\U0001F4AA - Now with Arnold Mode and Bodybuilding Coaching Expertise!"
# Hugging Face Hub checkpoint served by this demo.
model_id = "llava-hf/llava-1.5-7b-hf"

# Load the weights in 4-bit, with float16 as the compute dtype.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
)

# Built once at import time; infer() calls this shared pipeline with an image
# plus a text prompt.
pipe = pipeline(
    "image-to-text",
    model=model_id,
    model_kwargs=dict(quantization_config=quantization_config),
)
# Judging criterion name -> coaching note; arnold_speak() picks one entry at
# random. Built from ordered (name, note) pairs, and dict() keeps that order.
bodybuilding_criteria = dict(
    [
        ("Muscular Size", "Focus on overall muscle mass and development."),
        ("Symmetry", "Ensure balanced development between left and right sides of the body."),
        ("Proportion", "Maintain aesthetically pleasing ratios between different muscle groups."),
        ("Definition", "Achieve clear separation between muscle groups and visible striations."),
        ("Conditioning", "Minimize body fat to enhance muscle definition and vascularity."),
        ("Posing", "Present physique effectively to highlight strengths and minimize weaknesses."),
    ]
)
# Short coaching tips; arnold_speak() appends one of these at random.
bodybuilding_tips = list(
    (
        "Train each muscle group at least twice a week for optimal growth.",
        "Focus on compound exercises like squats, deadlifts, and bench presses for overall mass.",
        "Don't neglect your legs! They're half your physique.",
        "Proper nutrition is key. Eat clean and maintain a caloric surplus for growth.",
        "Get enough rest. Muscles grow during recovery, not in the gym.",
        "Practice your posing regularly. It's not just for shows, it helps mind-muscle connection.",
        "Stay hydrated. Water is crucial for muscle function and recovery.",
    )
)
def extract_response_pairs(text):
    """Split a ``USER:`` / ``ASSISTANT:`` transcript into ``[user, assistant]`` pairs.

    Any text before the first marker (e.g. a system prompt) is discarded. Each
    ``ASSISTANT:`` turn is paired with the most recent preceding ``USER:``
    turn; an assistant turn with no pending user turn, or a trailing user turn
    with no answer, is dropped. Turns may be empty strings.

    Args:
        text: Decoded model output or prompt containing the role markers.

    Returns:
        A list of ``[user_text, assistant_text]`` lists, in transcript order.
    """
    # With a capturing group, re.split returns
    # [prefix, marker, content, marker, content, ...]; drop the prefix.
    parts = re.split(r"(USER:|ASSISTANT:)", text)[1:]
    conv_list = []
    pending_user = None
    # Walk (marker, content) pairs without filtering empty contents, so an
    # empty turn cannot shift the marker/content alignment.
    for marker, content in zip(parts[0::2], parts[1::2]):
        content = content.strip().lstrip(":")
        if marker == "USER:":
            pending_user = content
        elif pending_user is not None:
            conv_list.append([pending_user, content])
            pending_user = None
    return conv_list
def add_text(history, text):
    """Append a pending user turn and clear the input box.

    Args:
        history: Chat history as a list of ``[user, bot]`` pairs; it is not
            modified (a new list is built).
        text: The user's message.

    Returns:
        ``(new_history, "")``: the history with ``[text, None]`` appended, and
        an empty string used to clear the input field.
    """
    pending_turn = [text, None]
    updated_history = history + [pending_turn]
    cleared_box = ""
    return updated_history, cleared_box
def arnold_speak(text, *, rng=None):
    """Rewrite *text* in an over-the-top "Arnold" voice.

    Periods become exclamation marks, except decimal points between digits
    such as "2.5". The whole words "gym", "exercise" and "workout" are swapped
    for Arnold-isms. Then, at random, a bodybuilding criterion (70% chance)
    and a pro tip (50% chance) are appended, followed by one catchphrase.

    Args:
        text: The model's reply.
        rng: Optional randomness source exposing ``random()`` and ``choice()``
            (e.g. ``random.Random(seed)``) for reproducible output. Defaults to
            the module-level ``random`` functions.

    Returns:
        The rewritten text.
    """
    rng = random if rng is None else rng
    arnold_phrases = [
        "Come with me if you want to lift!",
        "I'll be back... after my protein shake.",
        "Hasta la vista, baby weight!",
        "Get to da choppa... I mean, da squat rack!",
        "You lack discipline! But don't worry, I'm here to pump you up!",
    ]
    # More enthusiastic punctuation, but leave decimals like "2.5" intact.
    text = re.sub(r"(?<!\d)\.|\.(?!\d)", "!", text)
    # Whole-word swaps only, so e.g. "gymnastics" is not mangled.
    for word, replacement in (
        ("gym", "iron paradise"),
        ("exercise", "pump iron"),
        ("workout", "sculpt your physique"),
    ):
        text = re.sub(rf"\b{word}\b", replacement, text)
    if rng.random() < 0.7:  # 70% chance to add bodybuilding advice
        criterion, description = rng.choice(list(bodybuilding_criteria.items()))
        text += f" Remember, in bodybuilding, {criterion} is crucial! {description}"
    if rng.random() < 0.5:  # 50% chance to add a tip
        tip = rng.choice(bodybuilding_tips)
        text += f" Here's a pro tip: {tip}"
    # Always finish with a signature catchphrase.
    text += " " + rng.choice(arnold_phrases)
    return text
def infer(image, prompt, temperature, length_penalty, repetition_penalty, max_length, min_length, top_p):
    """Run one generation with the module-level ``pipe`` and return its text.

    Args:
        image: Input image passed to the pipeline (a PIL image from the UI).
        prompt: Full LLaVA prompt, including the ``<image>`` placeholder.
        temperature: Sampling temperature.
        length_penalty: Forwarded to ``generate``.
        repetition_penalty: Forwarded to ``generate``.
        max_length: Maximum number of *new* tokens to generate.
        min_length: Minimum number of *new* tokens to generate (clamped to
            ``max_length``).
        top_p: Nucleus-sampling probability mass.

    Returns:
        The pipeline's ``generated_text``. On any failure, the string
        ``"An error occurred during inference: <error>"`` instead; this is a
        UI boundary, so errors are reported rather than raised.
    """
    try:
        max_new_tokens = int(max_length)
        generate_kwargs = {
            # temperature/top_p only take effect when sampling is enabled.
            "do_sample": True,
            "temperature": temperature,
            "length_penalty": length_penalty,
            "repetition_penalty": repetition_penalty,
            # Bound the reply itself: max_length/min_length would also count
            # the prompt tokens (system prompt + image tokens).
            "max_new_tokens": max_new_tokens,
            "min_new_tokens": min(int(min_length), max_new_tokens),
            "top_p": top_p,
        }
        outputs = pipe(images=image, prompt=prompt, generate_kwargs=generate_kwargs)
        return outputs[0]["generated_text"]
    except Exception as e:
        print(f"Error during inference: {str(e)}")
        return f"An error occurred during inference: {str(e)}"
def bot(history, text_input, image, temperature, length_penalty, repetition_penalty, max_length, min_length, top_p, arnold_mode):
    """Answer the current turn and stream the reply into the chat history.

    Generator used as a Gradio event handler. It yields successive snapshots of
    the history, growing the reply one character at a time with a 0.05 s pause
    between characters.

    Args:
        history: Chat history as a list of ``[user, bot]`` pairs (``None`` is
            treated as empty). The caller's list is not mutated.
        text_input: The user's question; blank input is rejected.
        image: Uploaded image, or ``None`` if none was provided.
        temperature, length_penalty, repetition_penalty, max_length,
        min_length, top_p: Generation settings forwarded to ``infer``.
        arnold_mode: When true, use the Arnold persona prompt and post-process
            the reply with ``arnold_speak``.

    Yields:
        The updated history list.
    """
    history = list(history or [])
    if not text_input or not text_input.strip():
        yield history + [["Please input text", None]]
        return
    if image is None:
        yield history + [["Please input image or wait for image to be uploaded before clicking submit.", None]]
        return
    if arnold_mode:
        system_prompt = (
            "You are Arnold Schwarzenegger, the famous bodybuilder, actor, and former Mr. Olympia.\n"
            "Respond in his iconic style, using his catchphrases and focusing on fitness, bodybuilding, and motivation.\n"
            "Incorporate bodybuilding judging criteria and tips in your responses when relevant."
        )
    else:
        system_prompt = "You are a helpful AI assistant. Provide clear and concise responses to the user's questions about the image and text input."
    # Use only the current input for generating the response
    prompt = f"{system_prompt}\nUSER: <image>\n{text_input}\nASSISTANT:"
    response = infer(image, prompt, temperature, length_penalty, repetition_penalty, max_length, min_length, top_p)
    # The pipeline echoes the decoded prompt before the reply; keep only what
    # follows the final "ASSISTANT:" marker (error strings have no marker and
    # pass through unchanged).
    _, marker, reply = response.rpartition("ASSISTANT:")
    if marker:
        response = reply.strip()
    if arnold_mode:
        response = arnold_speak(response)
    history.append([text_input, ""])
    if not response:
        # Still surface the turn even when the model produced nothing.
        yield history
        return
    for end in range(1, len(response) + 1):
        history[-1][1] = response[:end]
        time.sleep(0.05)
        yield history
with gr.Blocks() as demo:
    # Header: app title plus a pointer to the Transformers LLaVA docs.
    gr.Markdown(DESCRIPTION)
    gr.Markdown(
        "## LLaVA, one of the greatest multimodal chat models is now available in "
        "Transformers with 4-bit quantization! \u26a1\ufe0f\n"
        "See the docs here: https://huggingface.co/docs/transformers/main/en/model_doc/llava."
    )
    chatbot = gr.Chatbot()
    with gr.Row():
        image = gr.Image(type="pil")
        with gr.Column():
            text_input = gr.Textbox(label="Chat Input", lines=3)
            arnold_mode = gr.Checkbox(label="Arnold Schwarzenegger Mode")
    # Generation settings forwarded to bot() -> infer().
    with gr.Accordion(label="Advanced settings", open=False):
        temperature = gr.Slider(label="Temperature", minimum=0.5, maximum=1.0, value=1.0, step=0.1)
        length_penalty = gr.Slider(label="Length Penalty", minimum=-1.0, maximum=2.0, value=1.0, step=0.2)
        repetition_penalty = gr.Slider(label="Repetition Penalty", minimum=1.0, maximum=5.0, value=1.5, step=0.5)
        max_length = gr.Slider(label="Max Length", minimum=1, maximum=500, value=200, step=1)
        min_length = gr.Slider(label="Minimum Length", minimum=1, maximum=100, value=1, step=1)
        top_p = gr.Slider(label="Top P", minimum=0.5, maximum=1.0, value=0.9, step=0.1)
    with gr.Row():
        clear_button = gr.Button("Clear")
        submit_button = gr.Button("Submit", variant="primary")
    # Stream the reply into the chat, then clear the input box.
    submit_button.click(
        fn=bot,
        inputs=[chatbot, text_input, image, temperature, length_penalty, repetition_penalty, max_length, min_length, top_p, arnold_mode],
        outputs=chatbot,
    ).then(
        fn=lambda: "",
        outputs=text_input,
    )
    # Reset the conversation and the uploaded image.
    clear_button.click(lambda: ([], None), outputs=[chatbot, image], queue=False)
    examples = [
        ["./examples/bodybuilder.jpeg", "What do you think of this physique?"],
        ["./examples/gym.jpeg", "How can I improve my workout routine?"],
    ]
    # Only offer examples whose image file is present; a missing path may
    # otherwise fail when Gradio loads the examples.
    examples = [example for example in examples if os.path.exists(example[0])]
    if examples:
        gr.Examples(examples=examples, inputs=[image, text_input])
if __name__ == "__main__":
    # Cap pending requests at 10, then start the server in debug mode.
    queued_demo = demo.queue(max_size=10)
    queued_demo.launch(debug=True)