# MoodCamera / app.py
# Author: Anustup — last commit "Update app.py" (4c2d9c7, verified), 18.6 kB.
# (Header captured from the Hugging Face file viewer: raw · history · blame.)
import streamlit as st
import os
import re
from claude import embed_base64_for_claude, create_claude_image_request_for_image_captioning, \
create_claude_request_for_text_completion, extract_data_from_text_xml
from prompts import prompts
from constants import JSON_SCHEMA_FOR_GPT, UPDATED_MODEL_ONLY_SCHEMA, JSON_SCHEMA_FOR_LOC_ONLY
from gpt import runAssistant, checkRunStatus, retrieveThread, createAssistant, saveFileOpenAI, startAssistantThread, \
create_chat_completion_request_open_ai_for_summary, addMessageToThread, create_image_completion_request_gpt
from summarizer import create_brand_html, create_langchain_openai_query, create_screenshot_from_scrap_fly
from theme import flux_generated_image, flux_generated_image_seed
import time
from PIL import Image
import io
from streamlit_gsheets import GSheetsConnection
# conn = st.connection("gsheets", type=GSheetsConnection)
def process_run(st, thread_id, assistant_id, poll_interval=20):
    """Run an OpenAI assistant on a thread, wait for it to finish, and return its reply.

    Args:
        st: The Streamlit module (used only for the spinner shown while polling).
        thread_id: Id of the thread to run the assistant on.
        assistant_id: Id of the assistant to run.
        poll_interval: Seconds to sleep between status checks (default 20).

    Returns:
        The ``"content"`` of the first non-user message returned by
        ``retrieveThread``, or ``None`` if every message is from the user.

    Raises:
        RuntimeError: if the run reaches a terminal failure status; previously the
            loop would poll forever in that case.
    """
    run_id = runAssistant(thread_id, assistant_id)
    # NOTE(review): assumes checkRunStatus returns the OpenAI run ``status`` string
    # ("queued", "in_progress", "completed", "failed", ...) — confirm in gpt.py.
    failed_statuses = {"failed", "cancelled", "expired", "incomplete"}
    status = 'running'
    while status != 'completed':
        if status in failed_statuses:
            raise RuntimeError(f"Assistant run {run_id} ended with status {status!r}")
        with st.spinner('. . .'):
            time.sleep(poll_interval)
            status = checkRunStatus(thread_id, run_id)
    for message in retrieveThread(thread_id):
        if message['role'] != 'user':
            return message["content"]
    return None
def page5():
    """Prompt-setup page: edit the system, captioning and brand-summary prompts.

    "Save the Prompt" stores all four text areas in ``st.session_state``.
    "Start Testing!" switches to "Page 1"; any prompt the user never saved is
    filled from its current text-area value first, because pages 2 and 3 read
    these session-state keys directly.
    """
    st.title('Initialize your preferences!')
    system_prompt_passed = st.text_area("System Prompt", value=prompts["PROMPT_FOR_MOOD_AND_IDEA"],
                                        key="System Prompt")
    caption_system_prompt = st.text_area("Captioning System Prompt", value=prompts["CAPTION_SYSTEM_PROMPT"],
                                         key="Caption Generation System Prompt")
    caption_prompt = st.text_area("Caption Prompt", value=prompts["CAPTION_PROMPT"],
                                  key="Caption Generation Prompt")
    brand_summary_prompt = st.text_area("Prompt for Brand Summary", value=prompts["BRAND_SUMMARY_PROMPT"],
                                        key="Brand summary prompt")
    st.text("Running on Claude")
    # Session-state key -> current text-area value.
    current_prompts = {
        "system_prompt": system_prompt_passed,
        "caption_system_prompt": caption_system_prompt,
        "caption_prompt": caption_prompt,
        "brand_prompt": brand_summary_prompt,
    }
    col1, col2 = st.columns([1, 2])
    with col1:
        if st.button("Save the Prompt"):
            for name, value in current_prompts.items():
                st.session_state[name] = value
            print(st.session_state["system_prompt"])
            st.success("Saved your prompts")
    with col2:
        if st.button("Start Testing!"):
            # Without this, skipping "Save the Prompt" causes KeyErrors on later pages.
            # Explicitly saved prompts are left untouched.
            for name, value in current_prompts.items():
                if name not in st.session_state:
                    st.session_state[name] = value
            st.session_state['page'] = "Page 1"
def page1():
    """Product-upload page: collect product images and a product description.

    "Save" stores both values in session state as they are. "Add product and
    move to next page" stores them and switches to "Page 2", but only when at
    least one image was uploaded and the description is non-empty; otherwise
    it shows a warning.
    """
    st.title("Upload Product")
    st.markdown("<h2 style='color:#FF5733; font-weight:bold;'>Add a Product</h2>", unsafe_allow_html=True)
    st.markdown(
        "<p style='color:#444;'>Upload your product images, more images you upload better the AI learns</p>",
        unsafe_allow_html=True,
    )
    images = st.file_uploader("Upload Images", accept_multiple_files=True, key="uploaded_files_key")
    description = st.text_area("Describe the product", value=st.session_state.get("product_description", ""))

    def remember_product():
        # Copy the current widget values into session state for the later pages.
        st.session_state['uploaded_files'] = images
        st.session_state['product_description'] = description

    save_col, next_col = st.columns([1, 2])
    with save_col:
        if st.button("Save"):
            remember_product()
            st.success("Product information saved!")
    with next_col:
        if not st.button("Add product and move to next page"):
            return
        if not images:
            st.warning("Please upload at least one image.")
        elif not description:
            st.warning("Please provide a description for the product.")
        else:
            remember_product()
            st.session_state['page'] = "Page 2"
def page2():
    """Shoot-preference page: shoot type, brand summary, product notes and reference images.

    Every widget value is copied into ``st.session_state`` on each rerun so page 3
    can read it. "Get Brand Summary" has Claude caption a Scrapfly screenshot of the
    brand link. When the Scrapfly result's ``"success"`` entry is falsy, it falls
    back to the LangChain/OpenAI summarizer instead. "Give Me Ideas" switches to
    "Page 3" and makes it regenerate ideas from the current inputs.
    """
    st.title("Tell us about your shoot preference")
    st.markdown("<h3 style='color:#444;'>What are you shooting today?</h3>", unsafe_allow_html=True)
    shoot_type = st.radio("Select your shoot type:", ["Editorial", "Catalogue"], index=0)
    st.session_state['shoot_type'] = shoot_type
    brand_link = st.text_input("Add your brand link:", value=st.session_state.get("brand_link", ""))
    st.session_state['brand_link'] = brand_link
    if st.button("Get Brand Summary"):
        if brand_link:
            st.text("Using Scrapfly")
            brand_summary_html = create_screenshot_from_scrap_fly(brand_link)
            if brand_summary_html["success"]:
                st.image(brand_summary_html["location"])
                brand_image_embed = embed_base64_for_claude(brand_summary_html["location"])
                # Use the default brand prompt if page 5's prompts were never saved.
                brand_prompt = st.session_state.get("brand_prompt", prompts["BRAND_SUMMARY_PROMPT"])
                st.session_state['brand_summary'] = create_claude_image_request_for_image_captioning(
                    "Fashion expert of understanding brand details",
                    brand_prompt, brand_image_embed)
            else:
                st.text(f"Scrapfly failed due to: {brand_summary_html}")
                st.text("Using Langchain")
                brand_summary_html = create_brand_html(brand_link)
                st.session_state['brand_summary'] = create_langchain_openai_query(brand_summary_html)
            st.success("Brand summary fetched!")
        else:
            st.warning("Please add a brand link.")
    brand_summary_value = st.session_state.get('brand_summary', "")
    editable_summary = st.text_area("Brand Summary:", value=brand_summary_value, height=100)
    st.session_state['brand_summary'] = editable_summary
    product_info = st.text_area("Tell us something about your product:", value=st.session_state.get("product_info", ""))
    st.session_state['product_info'] = product_info
    reference_images = st.file_uploader("Upload Reference Images", accept_multiple_files=True,
                                        key="reference_images_key")
    st.session_state['reference_images'] = reference_images
    if st.button("Give Me Ideas"):
        # Page 3 generates ideas only while this flag is falsy. Reset it so a return
        # visit (e.g. after "Go Back!") uses the inputs above, not stale results.
        st.session_state['assistant_initialized'] = False
        st.session_state['page'] = "Page 3"
def _claude_text_message(role, text):
    """Return one Claude Messages-API message whose content is a single text part."""
    return {"role": role, "content": [{"type": "text", "text": text}]}


def _caption_uploaded_images(uploaded_files, prefix):
    """Caption each uploaded image with Claude; return the captions in upload order.

    Each upload is validated with PIL and written to a temporary ``.png`` path,
    because ``embed_base64_for_claude`` takes a file path. It is then captioned
    with the caption prompts from session state, and the temporary file is
    always deleted afterwards.
    """
    import tempfile

    captions = []
    for uploaded_file in uploaded_files or []:
        bytes_data = uploaded_file.getvalue()
        # verify() raises if the bytes are not a readable image.
        with Image.open(io.BytesIO(bytes_data)) as image:
            image.verify()
        with tempfile.NamedTemporaryFile(prefix=prefix, suffix=".png", delete=False) as tmp:
            tmp.write(bytes_data)
        try:
            embed = embed_base64_for_claude(tmp.name)
            captions.append(create_claude_image_request_for_image_captioning(
                st.session_state["caption_system_prompt"], st.session_state["caption_prompt"], embed))
        finally:
            os.remove(tmp.name)
    return captions


def _generate_and_store_ideas(message_schema_for_claude):
    """Ask Claude for campaign ideas, render each one with Flux, and store them in session state.

    Each ``<campaign_idea>`` in Claude's reply becomes one Flux prompt: its
    ``model_prompt`` followed by its ``background_prompt``. The prompts go to
    ``"descriptions"``, the generated image file names to ``"images"``, and the
    raw reply to ``"claude_context"``. Returns the list of image file names.
    """
    response_from_claude = create_claude_request_for_text_completion(message_schema_for_claude)
    campaigns = re.findall(r"<campaign_idea>(.*?)</campaign_idea>", response_from_claude, re.DOTALL)
    concat_prompt_list = [
        extract_data_from_text_xml(campaign, "model_prompt")
        + extract_data_from_text_xml(campaign, "background_prompt")
        for campaign in campaigns
    ]
    flux_generated_theme_image = [flux_generated_image(prompt)["file_name"] for prompt in concat_prompt_list]
    print(flux_generated_theme_image)
    st.session_state["descriptions"] = concat_prompt_list
    st.session_state["claude_context"] = response_from_claude
    st.session_state["images"] = flux_generated_theme_image
    return flux_generated_theme_image


def page3():
    """Scene-suggestion page: generate campaign ideas, refine them via feedback, pick one.

    While ``assistant_initialized`` is unset or falsy, the page:
    1. Captions the product and reference images with Claude.
    2. Fills the page-5 system prompt with the brand summary, product details and
       reference captions.
    3. Renders Claude's campaign ideas with Flux.

    Chat feedback re-asks Claude with the original prompt, its previous reply and the
    feedback, and replaces the displayed ideas. Choosing an idea and pressing
    "Refine" switches to "Page 4"; "Go Back!" switches to "Page 2".
    """
    import contextlib

    st.title("Scene Suggestions")
    st.write("Based on your uploaded product and references!")
    feedback = st.chat_input("Provide feedback:")

    if not st.session_state.get("assistant_initialized", False):
        string_caption_list_product = str(
            _caption_uploaded_images(st.session_state['uploaded_files'], "temp_image_"))
        string_caption_list = str(
            _caption_uploaded_images(st.session_state['reference_images'], "temp2_image_"))
        st.session_state["caption_product"] = string_caption_list_product
        st.session_state["additional_caption"] = string_caption_list
        additional_info_param_for_prompt = f"Brand have provided reference images whose details are:" \
                                           f"```{string_caption_list}```. Apart from this brand needs " \
                                           f"{st.session_state['shoot_type']}"
        product_info = string_caption_list_product + "\n" + st.session_state['product_info']
        updated_prompt_for_claude = st.session_state["system_prompt"].format(
            BRAND_DETAILS=st.session_state['brand_summary'],
            PRODUCT_DETAILS=product_info,
            ADDITIONAL_INFO=additional_info_param_for_prompt
        )
        st.session_state["updated_prompt"] = updated_prompt_for_claude
        _generate_and_store_ideas([_claude_text_message("user", updated_prompt_for_claude)])
        # Set the flag only after ideas are stored. If anything above fails, the next
        # rerun retries instead of leaving the page without "images".
        st.session_state.assistant_initialized = True

    if feedback:
        previous_images = st.session_state.get("images", [])
        new_images = _generate_and_store_ideas([
            _claude_text_message("user", st.session_state["updated_prompt"]),
            _claude_text_message("assistant", st.session_state["claude_context"]),
            _claude_text_message("user", feedback),
        ])
        # Delete superseded renders only after the new batch exists, and never a
        # file name the new batch reuses.
        for image_path in set(previous_images) - set(new_images):
            with contextlib.suppress(FileNotFoundError):
                os.remove(image_path)

    images = st.session_state.get("images", [])
    descriptions = st.session_state.get("descriptions", [])
    cols = st.columns(4)
    for i, image_path in enumerate(images):
        # Wrap around the four columns so more than four ideas still fit.
        with cols[i % len(cols)]:
            st.image(image_path, caption=descriptions[i], use_column_width=True)

    selected_image_index = None
    if images:
        # One radio over all ideas, so exactly one can be chosen.
        selected_image_index = st.radio(
            "Select an image to refine",
            options=list(range(len(images))),
            format_func=lambda i: f"Select Image {i + 1}",
            horizontal=True,
        )
    if selected_image_index is not None and st.button("Refine"):
        st.session_state.selected_image_index = selected_image_index
        st.session_state.selected_image = images[selected_image_index]
        st.session_state.selected_text = descriptions[selected_image_index]
        st.session_state['page'] = "Page 4"
    if st.button("Go Back!"):
        st.session_state.page = "Page 2"
def page4():
    """Refinement page: iterate on the chosen scene idea via chat and regenerate images.

    The sidebar shows the product images and the selected scene. It collects the
    dimensions, the seed preference and the editable model/background prompt.

    When all three sidebar inputs are filled:
    - "Start Chat" renders the prompt as-is.
    - Each chat message asks Claude to rewrite the prompt using
      ``prompts["PROMPT_TO_UPDATE_IDEA_OR_MOOD"]``; the ``<updated_prompt>`` it
      returns is rendered.

    Renders use ``flux_generated_image_seed(prompt, seed_number, dimensions)`` for a
    "Fixed" seed and ``flux_generated_image(prompt)`` otherwise. The conversation
    lives in ``st.session_state["mood_chat_messages"]``.
    """
    selected_theme_text_by_user = st.session_state.descriptions[st.session_state.selected_image_index]
    print(selected_theme_text_by_user)
    with st.sidebar:
        st.title(st.session_state["product_info"])
        st.write("Product Image")
        st.image(st.session_state['uploaded_files'])
        st.text("Scene Suggestion:")
        st.image(st.session_state.selected_image)
        dimensions = st.text_input("Enter Dimensions e.g 3:4, 1:2", key="Dimensions")
        seed = st.selectbox(
            "Seed Preference",
            ("Fixed", "Random"),
        )
        if seed == "Fixed":
            seed_number = st.number_input("Enter an integer:", min_value=1, max_value=100000, value=10, step=1)
        else:
            seed_number = 0
            st.text("Thanks will take care")
        model__bg_preference = st.text_area("Edit Model & BG Idea", value=selected_theme_text_by_user,
                                            key="Model & BG Idea")
        start_chat = st.button("Start Chat")

    def render(prompt):
        """Generate one Flux image for *prompt*, honouring the seed preference; return its file name."""
        if seed == "Fixed":
            return flux_generated_image_seed(prompt, seed_number, dimensions)["file_name"]
        return flux_generated_image(prompt)["file_name"]

    if "mood_chat_messages" not in st.session_state:
        st.session_state["mood_chat_messages"] = []
    if seed and dimensions and model__bg_preference:
        if start_chat:
            st.session_state["mood_chat_messages"].append({
                "role": "AI",
                "message": model__bg_preference,
                "image": render(model__bg_preference),
            })
        user_input = st.chat_input("Type your message here...")
        if user_input:
            st.session_state["mood_chat_messages"].append({"role": "User", "message": user_input})
            updated_flux_prompt = prompts["PROMPT_TO_UPDATE_IDEA_OR_MOOD"].format(
                EXISTING_MODEL_BG_PROMPT=model__bg_preference,
                USER_INSTRUCTIONS=user_input
            )
            message_schema_for_claude = [
                {"role": "user", "content": [{"type": "text", "text": updated_flux_prompt}]},
                # The whole chat history so far (including this user turn) is sent as context.
                {"role": "assistant",
                 "content": [{"type": "text", "text": str(st.session_state["mood_chat_messages"])}]},
                {"role": "user",
                 "content": [{"type": "text",
                              "text": user_input + " Reference of previous conversation is also added."}]},
            ]
            response_from_claude = create_claude_request_for_text_completion(message_schema_for_claude)
            cleaned_prompt = extract_data_from_text_xml(response_from_claude, "updated_prompt")
            st.session_state["mood_chat_messages"].append({
                "role": "AI",
                "message": cleaned_prompt,
                "actual_response": response_from_claude,
                "image": render(cleaned_prompt),
            })
        for message in st.session_state["mood_chat_messages"]:
            if message["role"] == "AI":
                st.write(f"**AI**: {message['message']}")
                st.image(message['image'])
            else:
                st.write(f"**You**: {message['message']}")
        print(seed_number)
# Page router: render the page named by st.session_state.page, starting at "Page 5".
if 'page' not in st.session_state:
    st.session_state.page = "Page 5"

# "Page 5" is checked on its own, outside the lookup below. When its
# "Start Testing!" button switches to "Page 1", the upload page then renders
# in this same script run.
if st.session_state.page == "Page 5":
    page5()

_PAGE_RENDERERS = {
    "Page 1": page1,
    "Page 2": page2,
    "Page 3": page3,
    "Page 4": page4,
}
_render_page = _PAGE_RENDERERS.get(st.session_state.page)
if _render_page is not None:
    _render_page()