import os
import string
import copy
import gradio as gr
import PIL.Image
import torch
from transformers import BitsAndBytesConfig, pipeline
import re
import time
import random
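# Gradio demo: chat with LLaVA 1.5 7B (4-bit quantized) about an uploaded image.
# An optional "Arnold mode" rewrites each answer in Arnold Schwarzenegger's voice
# and mixes in bodybuilding judging criteria and training tips.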
DESCRIPTION = "# LLaVA 🌋💪 - Now with Arnold Mode and Bodybuilding Coaching Expertise!"
model_id = "llava-hf/llava-1.5-7b-hf"
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
)
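# Load the LLaVA checkpoint through the image-to-text pipeline; the 4-bit
# quantization config above lets the 7B model fit in a few GB of GPU memory.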
pipe = pipeline("image-to-text", model=model_id, model_kwargs={"quantization_config": quantization_config})
bodybuilding_criteria = {
    "Muscular Size": "Focus on overall muscle mass and development.",
    "Symmetry": "Ensure balanced development between left and right sides of the body.",
    "Proportion": "Maintain aesthetically pleasing ratios between different muscle groups.",
    "Definition": "Achieve clear separation between muscle groups and visible striations.",
    "Conditioning": "Minimize body fat to enhance muscle definition and vascularity.",
    "Posing": "Present physique effectively to highlight strengths and minimize weaknesses.",
}
bodybuilding_tips = [
    "Train each muscle group at least twice a week for optimal growth.",
    "Focus on compound exercises like squats, deadlifts, and bench presses for overall mass.",
    "Don't neglect your legs! They're half your physique.",
    "Proper nutrition is key. Eat clean and maintain a caloric surplus for growth.",
    "Get enough rest. Muscles grow during recovery, not in the gym.",
    "Practice your posing regularly. It's not just for shows, it helps mind-muscle connection.",
    "Stay hydrated. Water is crucial for muscle function and recovery.",
]
def extract_response_pairs(text):
    """Split a raw LLaVA transcript into [user, assistant] message pairs.

    Currently unused by the app, but kept as a helper for parsing the
    pipeline's full "USER: ... ASSISTANT: ..." output.
    """
    turns = re.split(r'(USER:|ASSISTANT:)', text)[1:]
    turns = [turn.strip() for turn in turns if turn.strip()]
    contents = turns[1::2]  # every other element is a message body; the rest are role markers
    conv_list = []
    for i in range(0, len(contents), 2):
        if i + 1 < len(contents):
            conv_list.append([contents[i].lstrip(":"), contents[i + 1].lstrip(":")])
    return conv_list
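# Illustrative example of the parsing above:
#   extract_response_pairs("USER: hi ASSISTANT: hello USER: bye ASSISTANT: later")
#   -> [["hi", "hello"], ["bye", "later"]]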
def add_text(history, text):
    history = history + [[text, None]]
    return history, ""  # Clear the input field after submission
def arnold_speak(text):
    arnold_phrases = [
        "Come with me if you want to lift!",
        "I'll be back... after my protein shake.",
        "Hasta la vista, baby weight!",
        "Get to da choppa... I mean, da squat rack!",
        "You lack discipline! But don't worry, I'm here to pump you up!"
    ]
    text = text.replace(".", "!")  # More enthusiastic punctuation
    text = text.replace("gym", "iron paradise")
    text = text.replace("exercise", "pump iron")
    text = text.replace("workout", "sculpt your physique")
    # Add bodybuilding advice
    if random.random() < 0.7:  # 70% chance to add bodybuilding advice
        advice = random.choice(list(bodybuilding_criteria.items()))
        text += f" Remember, in bodybuilding, {advice[0]} is crucial! {advice[1]}"
    # Add a bodybuilding tip
    if random.random() < 0.5:  # 50% chance to add a tip
        tip = random.choice(bodybuilding_tips)
        text += f" Here's a pro tip: {tip}"
    # Add a random Arnold phrase to the end
    text += " " + random.choice(arnold_phrases)
    return text
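# Illustrative example (advice and a tip are added at random; a catchphrase is always appended):
#   arnold_speak("Great workout at the gym.")
#   -> "Great sculpt your physique at the iron paradise! ... Come with me if you want to lift!"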
def infer(image, prompt, temperature, length_penalty, repetition_penalty, max_length, min_length, top_p):
    try:
        outputs = pipe(images=image, prompt=prompt,
                       generate_kwargs={"do_sample": True,  # needed for temperature/top_p to take effect
                                        "temperature": temperature,
                                        "length_penalty": length_penalty,
                                        "repetition_penalty": repetition_penalty,
                                        "max_length": max_length,
                                        "min_length": min_length,
                                        "top_p": top_p})
        inference_output = outputs[0]["generated_text"]
        # The pipeline's generated_text typically echoes the prompt; keep only the
        # text after the final "ASSISTANT:" marker so the chat shows just the answer.
        return inference_output.split("ASSISTANT:")[-1].strip()
    except Exception as e:
        print(f"Error during inference: {str(e)}")
        return f"An error occurred during inference: {str(e)}"
def bot(history, text_input, image, temperature, length_penalty, repetition_penalty, max_length, min_length, top_p, arnold_mode):
    if text_input == "":
        yield history + [["Please input text", None]]
        return
    if image is None:
        yield history + [["Please input image or wait for image to be uploaded before clicking submit.", None]]
        return
    if arnold_mode:
        system_prompt = """You are Arnold Schwarzenegger, the famous bodybuilder, actor, and former Mr. Olympia.
Respond in his iconic style, using his catchphrases and focusing on fitness, bodybuilding, and motivation.
Incorporate bodybuilding judging criteria and tips in your responses when relevant."""
    else:
        system_prompt = "You are a helpful AI assistant. Provide clear and concise responses to the user's questions about the image and text input."
    # Use only the current input for generating the response; <image> marks where LLaVA sees the picture.
    prompt = f"{system_prompt}\nUSER: <image>\n{text_input}\nASSISTANT:"
    response = infer(image, prompt, temperature, length_penalty, repetition_penalty, max_length, min_length, top_p)
    if arnold_mode:
        response = arnold_speak(response)
    # Stream the reply character by character so the chatbot appears to type.
    history.append([text_input, ""])
    for i in range(len(response)):
        history[-1][1] = response[:i + 1]
        time.sleep(0.05)
        yield history
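# Build the Gradio UI: chatbot on top, image and text inputs below,
# generation settings in a collapsible "Advanced settings" accordion.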
with gr.Blocks() as demo:
    gr.Markdown(DESCRIPTION)
    gr.Markdown("""## LLaVA, one of the greatest multimodal chat models, is now available in Transformers with 4-bit quantization! ⚡️
See the docs here: https://huggingface.co/docs/transformers/main/en/model_doc/llava.""")
    chatbot = gr.Chatbot()
    with gr.Row():
        image = gr.Image(type="pil")
        with gr.Column():
            text_input = gr.Textbox(label="Chat Input", lines=3)
            arnold_mode = gr.Checkbox(label="Arnold Schwarzenegger Mode")
    with gr.Accordion(label="Advanced settings", open=False):
        temperature = gr.Slider(label="Temperature", minimum=0.5, maximum=1.0, value=1.0, step=0.1)
        length_penalty = gr.Slider(label="Length Penalty", minimum=-1.0, maximum=2.0, value=1.0, step=0.2)
        repetition_penalty = gr.Slider(label="Repetition Penalty", minimum=1.0, maximum=5.0, value=1.5, step=0.5)
        max_length = gr.Slider(label="Max Length", minimum=1, maximum=500, value=200, step=1)
        min_length = gr.Slider(label="Minimum Length", minimum=1, maximum=100, value=1, step=1)
        top_p = gr.Slider(label="Top P", minimum=0.5, maximum=1.0, value=0.9, step=0.1)
    with gr.Row():
        clear_button = gr.Button("Clear")
        submit_button = gr.Button("Submit", variant="primary")
    submit_button.click(
        fn=bot,
        inputs=[chatbot, text_input, image, temperature, length_penalty, repetition_penalty, max_length, min_length, top_p, arnold_mode],
        outputs=chatbot,
    ).then(
        fn=lambda: "",  # clear the text box once the response has streamed
        outputs=text_input,
    )
    clear_button.click(lambda: ([], None), outputs=[chatbot, image], queue=False)
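    # Example prompts; the image paths assume an examples/ folder is bundled with the app.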
    examples = [
        ["./examples/bodybuilder.jpeg", "What do you think of this physique?"],
        ["./examples/gym.jpeg", "How can I improve my workout routine?"],
    ]
    gr.Examples(examples=examples, inputs=[image, text_input])
if __name__ == "__main__":
    demo.queue(max_size=10).launch(debug=True)