Spaces:
Sleeping
Sleeping
import requests | |
from langchain.chat_models import ChatOpenAI #model server | |
from langchain_groq import ChatGroq | |
from langchain.chains import LLMChain | |
from langchain.prompts import ( | |
PromptTemplate, | |
SystemMessagePromptTemplate, | |
HumanMessagePromptTemplate, | |
ChatPromptTemplate, | |
) | |
from config import app_config | |
import mongo_utils as mongo | |
import os

# The Groq API key is a secret and must never be committed to source control.
# It is read from the GROQ_API_KEY environment variable (set it as a secret in
# the deployment environment); it is an empty string when the variable is unset.
# NOTE(review): the key that was previously hard-coded on this line has been
# exposed publicly and should be revoked/rotated.
GROQ_API_KEY = os.environ.get("GROQ_API_KEY", "")
def __image2text(image):
    """Generate a short text description (caption) of an image.

    Posts the image to the image-to-text endpoint configured in
    ``app_config.I2T_API_URL`` (authorised with ``app_config.HF_TOKEN``) and
    extracts ``generated_text`` from the first element of the JSON reply.

    Args:
        image: Image content sent as the request body (the caller in this
            file passes the raw bytes read from the image file).

    Returns:
        The generated caption, or an empty string when the request fails or
        the reply does not have the expected ``[{"generated_text": ...}]``
        shape. The error is printed in that case.
    """
    headers = {"Authorization": app_config.HF_TOKEN}
    try:
        # A timeout keeps a stalled endpoint from blocking the caller forever.
        reply = requests.post(
            app_config.I2T_API_URL, headers=headers, data=image, timeout=60
        )
        return reply.json()[0]["generated_text"]
    except (requests.RequestException, ValueError, LookupError, TypeError) as e:
        # Network failure, non-JSON body, or an error payload (e.g. a dict
        # instead of a list). Return a well-defined value instead of an
        # unbound local or the raw Response object.
        print(e)
        return ""
def __text2story(image_desc, genre, style, word_count, creativity):
    """Generate a short story based on an image description.

    Builds a system + human chat prompt, runs it through a Groq-hosted chat
    model via an ``LLMChain`` and returns the model's reply.

    Args:
        image_desc: Text description of the image; used as the story context.
        genre: Genre name substituted into the system prompt.
        style: Writing style substituted into the system prompt.
        word_count: Maximum story length (in words) requested in the prompt.
        creativity: Sampling temperature passed to the chat model.

    Returns:
        The story text produced by the chain.
    """
    # Chat LLM model. `creativity` drives the temperature; it was previously
    # accepted but ignored because the temperature was pinned to 0.0.
    story_model = ChatGroq(
        model="llama3-8b-8192",
        temperature=creativity,
        api_key=GROQ_API_KEY,
    )

    # System message: fixes the writer persona, length, genre and style.
    sys_prompt = PromptTemplate(
        template=(
            "You are an expert story writer, write a maximum of {word_count}\n"
            "words long story in {genre} genre in {style} writing style, based on the user\n"
            "provided story-context.\n"
        ),
        input_variables=["word_count", "genre", "style"],
    )
    system_msg_prompt = SystemMessagePromptTemplate(prompt=sys_prompt)

    # Human message: carries the image description as the story context.
    human_prompt = PromptTemplate(
        template="story-context: {context}", input_variables=["context"]
    )
    human_msg_prompt = HumanMessagePromptTemplate(prompt=human_prompt)

    chat_prompt = ChatPromptTemplate.from_messages(
        [system_msg_prompt, human_msg_prompt]
    )

    # LLM chain: fill the prompt variables and run the model.
    story_chain = LLMChain(llm=story_model, prompt=chat_prompt)
    response = story_chain.run(
        genre=genre, style=style, word_count=word_count, context=image_desc
    )
    return response
def generate_story(image_file, genre, style, word_count, creativity):
    """Produce a story for the image stored at ``image_file``.

    The image is read from disk, captioned with ``__image2text`` and the
    caption is turned into a story with ``__text2story``. Afterwards the
    access counter is incremented through ``mongo`` and the counter values
    are read from ``app_config``.

    Args:
        image_file: Path of the image file to open in binary mode.
        genre: Story genre forwarded to ``__text2story``.
        style: Writing style forwarded to ``__text2story``.
        word_count: Maximum word count forwarded to ``__text2story``.
        creativity: Creativity setting forwarded to ``__text2story``.

    Returns:
        A tuple ``(story, max_count, curr_count, available_count)`` where
        ``available_count`` is ``max_count - curr_count``.
    """
    # Load the raw image bytes.
    with open(image_file, "rb") as image_handle:
        image_bytes = image_handle.read()

    # Caption the image and echo the caption between two banner lines.
    caption = __image2text(image=image_bytes)
    banner = "++++++++++++++++++++++++++++++++++++++"
    print(banner)
    print(caption)
    print(banner)

    # Turn the caption into a story.
    story_text = __text2story(
        image_desc=caption,
        genre=genre,
        style=style,
        word_count=word_count,
        creativity=creativity,
    )

    # Bump the access counter, then report the limit, the usage and what is left.
    mongo.increment_curr_access_count()
    limit = app_config.openai_max_access_count
    used = app_config.openai_curr_access_count
    return story_text, limit, used, limit - used