import os
import re
import torch
from threading import Thread
from typing import Iterator
from mongoengine import connect, Document, StringField, SequenceField
import gradio as gr
import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TextIteratorStreamer
from peft import PeftModel
import openai
from openai import OpenAI
import logging

openai.api_key = os.environ.get("OPENAI_KEY")
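# The API key is read from the OPENAI_KEY environment variable (for example, a Space secret).
# If it is missing or invalid, generate_image below surfaces the failure as a gr.Error.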
# Set up logging configuration
logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')

# Image generation via the OpenAI API, with logging at each step
def generate_image(text):
    try:
        logging.debug("Generating image with prompt: %s", text)
        response = openai.images.generate(
            model="dall-e-3",
            prompt="Create a 4 panel pixar style illustration that accurately depicts the character and the setting of a story:" + text,
            n=1,
            size="1024x1024"
        )
        image_url = response.data[0].url
        logging.info("Image generated successfully: %s", image_url)
        return image_url
    except Exception as error:
        logging.error("Failed to generate image: %s", str(error))
        raise gr.Error("An error occurred while generating the image. Please check your API key and try again.")
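# Example usage (assumes a valid OPENAI_KEY is configured; the result is a temporary image URL):
#   url = generate_image("A curious robot explores an ancient library.")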
rope_scaling = {
    'type': 'linear',  # Adjust the type to the appropriate scaling type for your model.
    'factor': 8.0      # Use the intended scaling factor.
}
# Constants
MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 1024
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))

LICENSE = """
---
As a derivative work of [Llama-2-7b-hf](https://huggingface.co/meta-llama/Llama-2-7b-hf) by Meta,
this demo is governed by the original [license](https://huggingface.co/spaces/huggingface-projects/llama-2-7b-chat/blob/main/LICENSE.txt) and [acceptable use policy](https://huggingface.co/spaces/huggingface-projects/llama-2-7b-chat/blob/main/USE_POLICY.md).
"""
# GPU check: add a CPU warning when no GPU is available
DESCRIPTION = ""
if not torch.cuda.is_available():
    DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
if torch.cuda.is_available():
    # Model and tokenizer configuration
    model_id = "meta-llama/Llama-3.1-8B"
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=False,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16
    )
    base_model = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map="auto",
        quantization_config=bnb_config,
        rope_scaling=rope_scaling  # Add this only if your model specifically requires it.
    )
    model = PeftModel.from_pretrained(base_model, "ranamhamoud/storytellai-2.0")
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    tokenizer.pad_token = tokenizer.eos_token
def make_prompt(entry):
    return f"### Human: When asked to explain, use a story. Don't repeat the assessments; limit to 500 words. However, keep context in mind if edits to the content are required. {entry} ### Assistant:"
# NOTE: this definition is superseded by the second process_text defined further down;
# only that later definition is used at runtime.
def process_text(text):
    text = re.sub(r'\[answer:\]\s*', 'Answer: ', text)
    text = re.sub(r'\[.*?\](?<!Answer: )', '', text)
    return text
custom_css = """ | |
body, input, button, textarea, label { | |
font-family: Arial, sans-serif; | |
font-size: 24px; | |
} | |
.gr-chat-interface .gr-chat-message-container { | |
font-size: 14px; | |
} | |
.gr-button { | |
font-size: 14px; | |
padding: 12px 24px; | |
} | |
.gr-input { | |
font-size: 14px; | |
} | |
""" | |
def process_text(text):
    text = re.sub(r'\[assessment;[^\]]*\]', '', text, flags=re.DOTALL)
    text = re.sub(r'\[.*?\]', '', text, flags=re.DOTALL)
    return text
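# Example (hypothetical model output): process_text("Once upon a time [assessment; what is HTTP?] the web was born.")
# returns "Once upon a time  the web was born." since bracketed assessment tags are stripped before streaming.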
def generate(
    message: str,
    chat_history: list[tuple[str, str]],
    max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
    temperature: float = 0.8,
    top_p: float = 0.7,
    top_k: int = 30,
    repetition_penalty: float = 1.0,
) -> Iterator[str]:
    conversation = []
    for user, assistant in chat_history:
        conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
    conversation.append({"role": "user", "content": make_prompt(message)})
    # Note: only the current prompt is tokenized below; the conversation list is built
    # but not currently passed to the model.
    enc = tokenizer(make_prompt(message), return_tensors="pt", padding=True, truncation=True)
    input_ids = enc.input_ids.to(model.device)
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=False)
    generate_kwargs = dict(
        {"input_ids": input_ids},
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        num_beams=1,
        repetition_penalty=repetition_penalty,
    )
    # Run generation in a background thread so tokens can be streamed as they arrive.
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()
    outputs = []
    for text in streamer:
        processed_text = process_text(text)
        outputs.append(processed_text)
        output = "".join(outputs)
        yield output
    final_story = "".join(outputs)
    final_story_trimmed = remove_last_sentence(final_story)
    image_url = generate_image(final_story_trimmed)
    # Yield (not return) the final message so the generated image is actually shown in the chat;
    # a return value from a generator would be discarded by Gradio.
    yield f"{final_story}\n\n![Generated Image]({image_url})"
def remove_last_sentence(text):
    sentences = re.split(r'(?<=\.)\s', text)
    return ' '.join(sentences[:-1]) if sentences else text
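# Example: remove_last_sentence("The fox ran. The dog slept. The end") returns "The fox ran. The dog slept."
# This drops the (possibly truncated) trailing sentence before the story is used as an image prompt.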
chat_interface = gr.ChatInterface(
    fn=generate,
    fill_height=True,
    stop_btn=None,
    examples=[
        ["Tell me about HTTP."],
        ["Can you briefly explain what the Python programming language is?"],
        ["Could you please provide an explanation about Data Science?"],
        ["Could you explain what a URL is?"]
    ],
    theme='shivi/calm_seafoam',
    autofocus=True,
)
js_func = """ | |
function refresh() { | |
const url = new URL(window.location); | |
if (url.searchParams.get('__theme') !== 'light') { | |
url.searchParams.set('__theme', 'light'); | |
window.location.href = url.href; | |
} | |
} | |
""" | |
# Gradio Web Interface
# js_func is passed to Blocks so the light-theme refresh defined above actually runs on page load.
with gr.Blocks(css=custom_css, fill_height=True, theme="shivi/calm_seafoam", js=js_func) as demo:
    chat_interface.render()
    # gr.Markdown(LICENSE)
# Main Execution
if __name__ == "__main__":
    demo.queue(max_size=20)
    demo.launch(share=True)