import os
import re
import torch
from threading import Thread
from typing import Iterator
from mongoengine import connect, Document, StringField, SequenceField
import gradio as gr
import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TextIteratorStreamer
from peft import PeftModel
import openai
from openai import OpenAI
import logging
openai.api_key = os.environ.get("OPENAI_KEY")
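# Note: OPENAI_KEY is read from the environment (a Space secret); if it is unset,
# the DALL-E call in generate_image() below will fail with an authentication error.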
# Set up logging configuration
logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')
# Generate an illustration for the finished story with DALL-E 3 (logs progress and errors).
def generate_image(text):
    try:
        logging.debug("Generating image with prompt: %s", text)
        response = openai.images.generate(
            model="dall-e-3",
            prompt="Create a 4 panel pixar style illustration that accurately depicts the character and the setting of a story:" + text,
            n=1,
            size="1024x1024"
        )
        image_url = response.data[0].url
        logging.info("Image generated successfully: %s", image_url)
        return image_url
    except Exception as error:
        logging.error("Failed to generate image: %s", str(error))
        raise gr.Error("An error occurred while generating the image. Please check your API key and try again.")
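# Note (assumption): DALL-E 3 responses return a temporary hosted URL; if the image
# needs to outlive the chat session, it should be downloaded and re-hosted rather
# than hot-linked as done here.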
rope_scaling = {
    'type': 'linear',  # Adjust the type to the appropriate scaling type for your model.
    'factor': 8.0      # Use the intended scaling factor.
}
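# Linear RoPE scaling divides position indices by `factor` before the rotary embedding,
# stretching the usable context window at some cost to positional resolution. The values
# above are illustrative; Llama 3.1 checkpoints usually ship their own rope_scaling in
# config.json, in which case this override may be unnecessary or rejected.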
# Constants
MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 1024
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
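# MAX_INPUT_TOKEN_LENGTH caps the prompt fed to the model; generate() below keeps only
# the trailing MAX_INPUT_TOKEN_LENGTH tokens when the tokenized prompt is longer, so the
# most recent context is what survives trimming.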
LICENSE = """
---
As a derivative work of [Llama-2-7b-hf](https://huggingface.co/meta-llama/Llama-2-7b-hf) by Meta,
this demo is governed by the original [license](https://huggingface.co/spaces/huggingface-projects/llama-2-7b-chat/blob/main/LICENSE.txt) and [acceptable use policy](https://huggingface.co/spaces/huggingface-projects/llama-2-7b-chat/blob/main/USE_POLICY.md).
"""
# GPU check and CPU warning
DESCRIPTION = ""
if not torch.cuda.is_available():
    DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
if torch.cuda.is_available():
    # Model and Tokenizer Configuration
    model_id = "meta-llama/Llama-3.1-8B"
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=False,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16
    )
    base_model = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map="auto",
        quantization_config=bnb_config,
        rope_scaling=rope_scaling  # Add this only if your model specifically requires it.
    )
    model = PeftModel.from_pretrained(base_model, "ranamhamoud/storytellai-2.0")
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    tokenizer.pad_token = tokenizer.eos_token
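# PeftModel.from_pretrained layers the LoRA adapter weights from the Hub repo on top of
# the 4-bit quantized base model, so only the small adapter tensors are added to the
# memory footprint of the base checkpoint.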
def make_prompt(entry):
    return f"### Human: When asked to explain, use a story. Don't repeat the assessments, limit to 500 words. However, keep context in mind if edits to the content are required. {entry} ### Assistant:"
def process_text(text):
    # Keep explicit answers, then strip bracketed assessment blocks and any other bracketed markers.
    text = re.sub(r'\[answer:\]\s*', 'Answer: ', text)
    text = re.sub(r'\[assessment;[^\]]*\]', '', text, flags=re.DOTALL)
    text = re.sub(r'\[.*?\]', '', text, flags=re.DOTALL)
    return text
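# Illustrative example (not executed): process_text("[assessment; q1] Once upon a time [hint] there was...")
# strips the bracketed annotations and returns " Once upon a time  there was...".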
custom_css = """
body, input, button, textarea, label {
    font-family: Arial, sans-serif;
    font-size: 24px;
}
.gr-chat-interface .gr-chat-message-container {
    font-size: 14px;
}
.gr-button {
    font-size: 14px;
    padding: 12px 24px;
}
.gr-input {
    font-size: 14px;
}
"""
@spaces.GPU
def generate(
    message: str,
    chat_history: list[tuple[str, str]],
    max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
    temperature: float = 0.8,
    top_p: float = 0.7,
    top_k: int = 30,
    repetition_penalty: float = 1.0,
) -> Iterator[str]:
    conversation = []
    for user, assistant in chat_history:
        conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
    conversation.append({"role": "user", "content": make_prompt(message)})

    enc = tokenizer(make_prompt(message), return_tensors="pt", padding=True, truncation=True)
    input_ids = enc.input_ids.to(model.device)
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")

    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=False)
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        num_beams=1,
        repetition_penalty=repetition_penalty,
    )
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    outputs = []
    for text in streamer:
        processed_text = process_text(text)
        outputs.append(processed_text)
        output = "".join(outputs)
        yield output

    # Append a generated illustration to the finished story; a `return` value from a
    # generator would be discarded by Gradio, so the final message is yielded instead.
    final_story = "".join(outputs)
    final_story_trimmed = remove_last_sentence(final_story)
    image_url = generate_image(final_story_trimmed)
    yield f"{final_story}\n\n![Generated Image]({image_url})"
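# Note: generation runs on a background thread while TextIteratorStreamer yields decoded
# chunks; each yield re-emits the full accumulated text, which Gradio's ChatInterface
# renders as the in-progress assistant message.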
def remove_last_sentence(text):
    sentences = re.split(r'(?<=\.)\s', text)
    return ' '.join(sentences[:-1]) if sentences else text
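# Illustrative example (not executed): remove_last_sentence("The hero won. The end was near. And the")
# drops the trailing fragment and returns "The hero won. The end was near."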
chat_interface = gr.ChatInterface(
    fn=generate,
    fill_height=True,
    stop_btn=None,
    examples=[
        ["Tell me about HTTP."],
        ["Can you briefly explain what the Python programming language is?"],
        ["Could you please provide an explanation about Data Science?"],
        ["Could you explain what a URL is?"]
    ],
    theme='shivi/calm_seafoam',
    autofocus=True,
)
js_func = """
function refresh() {
    const url = new URL(window.location);
    if (url.searchParams.get('__theme') !== 'light') {
        url.searchParams.set('__theme', 'light');
        window.location.href = url.href;
    }
}
"""
# Gradio Web Interface
with gr.Blocks(css=custom_css, fill_height=True, theme="shivi/calm_seafoam") as demo:
    chat_interface.render()
    # gr.Markdown(LICENSE)

# Main Execution
if __name__ == "__main__":
    demo.queue(max_size=20)
    demo.launch(share=True)