Spaces:
Sleeping
Sleeping
import streamlit as st | |
import os | |
import re | |
from claude import embed_base64_for_claude, create_claude_image_request_for_image_captioning, \ | |
create_claude_request_for_text_completion, extract_data_from_text_xml | |
from prompts import prompts | |
from constants import JSON_SCHEMA_FOR_GPT, UPDATED_MODEL_ONLY_SCHEMA, JSON_SCHEMA_FOR_LOC_ONLY | |
from gpt import runAssistant, checkRunStatus, retrieveThread, createAssistant, saveFileOpenAI, startAssistantThread, \ | |
create_chat_completion_request_open_ai_for_summary, addMessageToThread, create_image_completion_request_gpt | |
from summarizer import create_brand_html, create_langchain_openai_query, create_screenshot_from_scrap_fly | |
from theme import flux_generated_image, flux_generated_image_seed | |
import time | |
from PIL import Image | |
import io | |
from streamlit_gsheets import GSheetsConnection | |
# conn = st.connection("gsheets", type=GSheetsConnection) | |
def process_run(st, thread_id, assistant_id, poll_interval=20):
    """Run an OpenAI assistant on a thread, wait for it to finish, and return its reply.

    Args:
        st: The Streamlit module; only used to show a spinner while polling.
        thread_id: ID of the thread to run the assistant on.
        assistant_id: ID of the assistant to run.
        poll_interval: Seconds to sleep before each status check (default 20,
            matching the previous hard-coded value).

    Returns:
        The ``content`` of the first message from ``retrieveThread`` whose role is
        not ``'user'``, or ``None`` if there is no such message.

    Raises:
        RuntimeError: If the run reaches a terminal status other than ``'completed'``
            (previously the loop polled such runs forever).
    """
    # Terminal run states that will never turn into 'completed'.
    # NOTE(review): status names follow the OpenAI Assistants run lifecycle — confirm
    # checkRunStatus returns these raw strings.
    failed_statuses = {"failed", "cancelled", "expired", "incomplete"}

    run_id = runAssistant(thread_id, assistant_id)
    with st.spinner('. . .'):
        while True:
            time.sleep(poll_interval)
            status = checkRunStatus(thread_id, run_id)
            if status == 'completed':
                break
            if status in failed_statuses:
                raise RuntimeError(f"Assistant run {run_id} ended with status {status!r}")

    for message in retrieveThread(thread_id):
        if message['role'] != 'user':
            return message["content"]
    return None
def _page5_save_prompts(system_prompt, caption_system_prompt, caption_prompt, brand_prompt):
    """Store the edited prompts under the session keys that page2 and page3 read."""
    st.session_state["system_prompt"] = system_prompt
    st.session_state["caption_system_prompt"] = caption_system_prompt
    st.session_state["caption_prompt"] = caption_prompt
    st.session_state["brand_prompt"] = brand_prompt


def page5():
    """Render the prompt-setup page (the app's entry page).

    Shows editable text areas pre-filled from ``prompts``. "Save the Prompt" stores
    them in session state. "Start Testing!" also stores them, then switches
    ``st.session_state['page']`` to "Page 1".
    """
    st.title('Initialize your preferences!')
    system_prompt_passed = st.text_area("System Prompt", value=prompts["PROMPT_FOR_MOOD_AND_IDEA"],
                                        key="System Prompt")
    caption_system_prompt = st.text_area("Captioning System Prompt", value=prompts["CAPTION_SYSTEM_PROMPT"],
                                         key="Caption Generation System Prompt")
    caption_prompt = st.text_area("Caption Prompt", value=prompts["CAPTION_PROMPT"],
                                  key="Caption Generation Prompt")
    brand_summary_prompt = st.text_area("Prompt for Brand Summary", value=prompts["BRAND_SUMMARY_PROMPT"],
                                        key="Brand summary prompt")
    st.text("Running on Claude")

    col1, col2 = st.columns([1, 2])
    with col1:
        if st.button("Save the Prompt"):
            _page5_save_prompts(system_prompt_passed, caption_system_prompt, caption_prompt,
                                brand_summary_prompt)
            st.success("Saved your prompts")
    with col2:
        if st.button("Start Testing!"):
            # Save here too. Later pages index these keys directly, so skipping "Save"
            # used to raise KeyError on page2/page3.
            _page5_save_prompts(system_prompt_passed, caption_system_prompt, caption_prompt,
                                brand_summary_prompt)
            st.session_state['page'] = "Page 1"
def _page1_store_product(uploaded_files, product_description):
    """Persist the product uploads and description in session state for later pages."""
    st.session_state['uploaded_files'] = uploaded_files
    st.session_state['product_description'] = product_description


def page1():
    """Render the product-upload page.

    "Save" stores the current uploads and description. "Add product and move to next
    page" additionally requires at least one image and a non-blank description, then
    switches ``st.session_state['page']`` to "Page 2".
    """
    st.title("Upload Product")
    st.markdown("<h2 style='color:#FF5733; font-weight:bold;'>Add a Product</h2>", unsafe_allow_html=True)
    st.markdown("<p style='color:#444;'>Upload your product images, more images you upload better the AI learns</p>",
                unsafe_allow_html=True)
    uploaded_files = st.file_uploader("Upload Images", accept_multiple_files=True, key="uploaded_files_key")
    product_description = st.text_area("Describe the product", value=st.session_state.get("product_description", ""))

    col1, col2 = st.columns([1, 2])
    with col1:
        if st.button("Save"):
            _page1_store_product(uploaded_files, product_description)
            st.success("Product information saved!")
    with col2:
        if st.button("Add product and move to next page"):
            if not uploaded_files:
                st.warning("Please upload at least one image.")
            elif not product_description.strip():
                # A whitespace-only description carries no information for the prompts.
                st.warning("Please provide a description for the product.")
            else:
                _page1_store_product(uploaded_files, product_description)
                st.session_state['page'] = "Page 2"
def _page2_fetch_brand_summary(brand_link):
    """Return a brand summary for *brand_link*.

    First tries a Scrapfly screenshot captioned by Claude. If the screenshot result
    is not marked ``success``, falls back to a LangChain/OpenAI query over the
    page HTML.
    """
    st.text("Using Scrapfly")
    screenshot = create_screenshot_from_scrap_fly(brand_link)
    if screenshot["success"]:
        st.image(screenshot["location"])
        brand_image_embed = embed_base64_for_claude(screenshot["location"])
        # Fall back to the default prompt if page5's prompts were never saved.
        brand_prompt = st.session_state.get("brand_prompt", prompts["BRAND_SUMMARY_PROMPT"])
        return create_claude_image_request_for_image_captioning(
            "Fashion expert of understanding brand details", brand_prompt, brand_image_embed)

    st.text(f"Scrapfly failed due to: {screenshot}")
    st.text("Using Langchain")
    brand_html = create_brand_html(brand_link)
    return create_langchain_openai_query(brand_html)


def page2():
    """Render the shoot-preference page.

    Collects shoot type, brand link, an editable brand summary, product notes and
    reference images into session state. "Give Me Ideas" switches
    ``st.session_state['page']`` to "Page 3".
    """
    st.title("Tell us about your shoot preference")
    st.markdown("<h3 style='color:#444;'>What are you shooting today?</h3>", unsafe_allow_html=True)

    # Restore the previous choice like the other fields do, instead of resetting to
    # "Editorial" every time the page is revisited.
    shoot_options = ["Editorial", "Catalogue"]
    previous_shoot = st.session_state.get("shoot_type", shoot_options[0])
    shoot_index = shoot_options.index(previous_shoot) if previous_shoot in shoot_options else 0
    shoot_type = st.radio("Select your shoot type:", shoot_options, index=shoot_index)
    st.session_state['shoot_type'] = shoot_type

    brand_link = st.text_input("Add your brand link:", value=st.session_state.get("brand_link", ""))
    st.session_state['brand_link'] = brand_link
    if st.button("Get Brand Summary"):
        if brand_link.strip():
            st.session_state['brand_summary'] = _page2_fetch_brand_summary(brand_link.strip())
            st.success("Brand summary fetched!")
        else:
            st.warning("Please add a brand link.")

    brand_summary_value = st.session_state.get('brand_summary', "")
    editable_summary = st.text_area("Brand Summary:", value=brand_summary_value, height=100)
    st.session_state['brand_summary'] = editable_summary

    product_info = st.text_area("Tell us something about your product:", value=st.session_state.get("product_info", ""))
    st.session_state['product_info'] = product_info

    reference_images = st.file_uploader("Upload Reference Images", accept_multiple_files=True,
                                        key="reference_images_key")
    st.session_state['reference_images'] = reference_images

    if st.button("Give Me Ideas"):
        # Force page3 to regenerate from the (possibly edited) inputs. Otherwise
        # coming back via "Go Back!" would show the stale ideas.
        st.session_state["assistant_initialized"] = False
        st.session_state['page'] = "Page 3"
def _page3_caption_images(uploaded_files):
    """Caption each uploaded image with Claude and return the captions in upload order.

    Each image is validated with Pillow and written to a temporary ``.png`` path,
    because ``embed_base64_for_claude`` takes a file path. The temporary file is
    removed right after embedding, so runs no longer litter the working directory.
    """
    import tempfile

    system_prompt = st.session_state.get("caption_system_prompt", prompts["CAPTION_SYSTEM_PROMPT"])
    caption_prompt = st.session_state.get("caption_prompt", prompts["CAPTION_PROMPT"])
    captions = []
    for uploaded_file in uploaded_files or []:
        bytes_data = uploaded_file.getvalue()
        # verify() raises on non-image / corrupt data before any API call is made.
        with Image.open(io.BytesIO(bytes_data)) as image:
            image.verify()
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
            tmp.write(bytes_data)
            location = tmp.name
        try:
            # NOTE(review): assumes embed_base64_for_claude reads the file eagerly.
            image_embed = embed_base64_for_claude(location)
        finally:
            os.remove(location)
        captions.append(
            create_claude_image_request_for_image_captioning(system_prompt, caption_prompt, image_embed))
    return captions


def _page3_generate_scenes(messages):
    """Send *messages* to Claude and render one Flux image per returned campaign idea.

    Returns:
        ``(raw Claude response, scene prompts, image file names)``. The two lists
        are index-aligned.
    """
    response = create_claude_request_for_text_completion(messages)
    campaigns = re.findall(r"<campaign_idea>(.*?)</campaign_idea>", response, re.DOTALL)
    scene_prompts = []
    for campaign in campaigns:
        model_prompt = extract_data_from_text_xml(campaign, "model_prompt")
        background_prompt = extract_data_from_text_xml(campaign, "background_prompt")
        # Separate the parts so the last model word and first background word don't fuse.
        scene_prompts.append(model_prompt + " " + background_prompt)
    image_files = [flux_generated_image(prompt)["file_name"] for prompt in scene_prompts]
    return response, scene_prompts, image_files


def _page3_store_scenes(response, scene_prompts, image_files):
    """Publish a generated scene set to session state (read by the picker and page4)."""
    st.session_state["descriptions"] = scene_prompts
    st.session_state["claude_context"] = response
    st.session_state["images"] = image_files


def _page3_initialize_scenes():
    """First visit: caption product/reference images, build the prompt, generate scenes."""
    product_captions = str(_page3_caption_images(st.session_state.get('uploaded_files')))
    reference_captions = str(_page3_caption_images(st.session_state.get('reference_images')))
    st.session_state["caption_product"] = product_captions
    st.session_state["additional_caption"] = reference_captions

    additional_info = (
        "Brand have provided reference images whose details are:"
        f"```{reference_captions}```. Apart from this brand needs "
        f"{st.session_state['shoot_type']}"
    )
    product_info = product_captions + " " + st.session_state['product_info']
    system_prompt = st.session_state.get("system_prompt", prompts["PROMPT_FOR_MOOD_AND_IDEA"])
    updated_prompt = system_prompt.format(
        BRAND_DETAILS=st.session_state['brand_summary'],
        PRODUCT_DETAILS=product_info,
        ADDITIONAL_INFO=additional_info,
    )
    st.session_state["updated_prompt"] = updated_prompt

    messages = [{"role": "user", "content": [{"type": "text", "text": updated_prompt}]}]
    _page3_store_scenes(*_page3_generate_scenes(messages))
    # Mark as initialised only after everything succeeded. A failed API call is then
    # retried on the next rerun instead of leaving the page with no images.
    st.session_state.assistant_initialized = True


def _page3_apply_feedback(feedback):
    """Regenerate the scenes from user *feedback*, continuing the previous Claude exchange."""
    import contextlib

    messages = [
        {"role": "user", "content": [{"type": "text", "text": st.session_state["updated_prompt"]}]},
        {"role": "assistant", "content": [{"type": "text", "text": st.session_state["claude_context"]}]},
        {"role": "user", "content": [{"type": "text", "text": feedback}]},
    ]
    previous_images = st.session_state.get("images", [])
    response, scene_prompts, image_files = _page3_generate_scenes(messages)
    # Delete superseded files only after the new set exists, so a failed request keeps
    # the old suggestions. Never delete a path the new set happens to reuse.
    fresh = set(image_files)
    for image_path in previous_images:
        if image_path not in fresh:
            with contextlib.suppress(FileNotFoundError):
                os.remove(image_path)
    _page3_store_scenes(response, scene_prompts, image_files)


def _page3_render_scene_picker():
    """Show the scenes in a 4-column grid and let the user pick one to refine on page4."""
    images = st.session_state.get("images", [])
    descriptions = st.session_state.get("descriptions", [])
    if not images:
        st.warning("No scene suggestions were generated. Try sending feedback to regenerate.")
        return

    cols = st.columns(4)
    for i, (image_path, description) in enumerate(zip(images, descriptions)):
        # Wrap round the 4 columns instead of indexing past the last one.
        with cols[i % len(cols)]:
            st.image(image_path, caption=description, use_column_width=True)

    # One radio over all scenes. The previous per-image single-option radios were
    # always "selected", so Refine silently picked the last image.
    selected_image_index = st.radio(
        "Select a scene to refine",
        options=list(range(len(images))),
        format_func=lambda i: f"Select Image {i + 1}",
        horizontal=True,
    )
    if selected_image_index is not None and st.button("Refine"):
        st.session_state.selected_image_index = selected_image_index
        st.session_state.selected_image = images[selected_image_index]
        st.session_state.selected_text = descriptions[selected_image_index]
        st.session_state['page'] = "Page 4"


def page3():
    """Render the scene-suggestion page.

    On the first visit (``assistant_initialized`` falsy), captions the product and
    reference images and asks Claude for campaign ideas, rendering each with Flux.
    Chat feedback regenerates the ideas. The user then picks one scene to refine
    on page4, or goes back to page2.
    """
    st.title("Scene Suggestions")
    st.write("Based on your uploaded product and references!")
    feedback = st.chat_input("Provide feedback:")

    if not st.session_state.get("assistant_initialized", False):
        _page3_initialize_scenes()
    if feedback:
        _page3_apply_feedback(feedback)

    _page3_render_scene_picker()

    if st.button("Go Back!"):
        st.session_state.page = "Page 2"
def _page4_generate_flux_image(prompt, seed, seed_number, dimensions):
    """Render *prompt* with Flux and return the generator's result (callers read ``"file_name"``).

    With the "Fixed" seed preference, uses the seeded generator with the user's seed
    and dimensions. Otherwise uses the unseeded generator, which takes neither.
    """
    if seed == "Fixed":
        return flux_generated_image_seed(prompt, seed_number, dimensions)
    return flux_generated_image(prompt)


def page4():
    """Render the refinement chat for the scene chosen on page3.

    The sidebar shows the product images, the chosen scene, and the generation
    settings (dimensions, seed). "Start Chat" renders the edited model/background
    idea. Each chat message asks Claude to rewrite that idea (read back from the
    ``updated_prompt`` tag) and renders the result. The conversation lives in
    ``st.session_state["mood_chat_messages"]``.
    """
    selected_theme_text_by_user = st.session_state.descriptions[st.session_state.selected_image_index]

    with st.sidebar:
        st.title(st.session_state["product_info"])
        st.write("Product Image")
        st.image(st.session_state['uploaded_files'])
        st.text("Scene Suggestion:")
        st.image(st.session_state.selected_image)
        dimensions = st.text_input("Enter Dimensions e.g 3:4, 1:2", key="Dimensions")
        seed = st.selectbox(
            "Seed Preference",
            ("Fixed", "Random"),
        )
        if seed == "Fixed":
            seed_number = st.number_input("Enter an integer:", min_value=1, max_value=100000, value=10, step=1)
        else:
            seed_number = 0  # not used by the unseeded generator
            st.text("Thanks will take care")

    model__bg_preference = st.text_area("Edit Model & BG Idea", value=selected_theme_text_by_user,
                                        key="Model & BG Idea")
    start_chat = st.button("Start Chat")
    if "mood_chat_messages" not in st.session_state:
        st.session_state["mood_chat_messages"] = []
    chat_history = st.session_state["mood_chat_messages"]

    if seed and dimensions and model__bg_preference:
        if start_chat:
            first_image = _page4_generate_flux_image(model__bg_preference, seed, seed_number, dimensions)
            chat_history.append({
                "role": "AI",
                "message": model__bg_preference,
                "image": first_image["file_name"],
            })

        user_input = st.chat_input("Type your message here...")
        if user_input:
            chat_history.append({"role": "User", "message": user_input})
            updated_flux_prompt = prompts["PROMPT_TO_UPDATE_IDEA_OR_MOOD"].format(
                EXISTING_MODEL_BG_PROMPT=model__bg_preference,
                USER_INSTRUCTIONS=user_input,
            )
            message_schema_for_claude = [
                {"role": "user", "content": [{"type": "text", "text": updated_flux_prompt}]},
                # The whole chat so far (including the message just added) as context.
                {"role": "assistant", "content": [{"type": "text", "text": str(chat_history)}]},
                {"role": "user", "content": [{
                    "type": "text",
                    # Leading space keeps the user's last word from fusing with "Reference".
                    "text": user_input + " Reference of previous conversation is also added.",
                }]},
            ]
            response_from_claude = create_claude_request_for_text_completion(message_schema_for_claude)
            cleaned_prompt = extract_data_from_text_xml(response_from_claude, "updated_prompt")
            refined_image = _page4_generate_flux_image(cleaned_prompt, seed, seed_number, dimensions)
            chat_history.append({
                "role": "AI",
                "message": cleaned_prompt,
                "actual_response": response_from_claude,
                "image": refined_image["file_name"],
            })

    for message in chat_history:
        if message["role"] == "AI":
            st.write(f"**AI**: {message['message']}")
            st.image(message['image'])
        else:
            st.write(f"**You**: {message['message']}")
# Page router: st.session_state.page selects which page function renders this run.
# "Page 5" (prompt setup) is the entry page.
_PAGE_RENDERERS = {
    "Page 5": page5,
    "Page 1": page1,
    "Page 2": page2,
    "Page 3": page3,
    "Page 4": page4,
}

if 'page' not in st.session_state:
    st.session_state.page = "Page 5"

_rendered_page = st.session_state.page
_render_page = _PAGE_RENDERERS.get(_rendered_page)
if _render_page is not None:
    _render_page()

# Page buttons only change st.session_state.page. Rerun right away so the new page
# is drawn now instead of on the user's next interaction. This replaces the old
# separate `if` for "Page 5", which drew page1 underneath page5 in the same run.
if st.session_state.page != _rendered_page:
    st.rerun()