import logging
import os

import requests
import torch
from dotenv import load_dotenv
from transformers import BlipProcessor, BlipForConditionalGeneration, AutoTokenizer, AutoModelForSeq2SeqLM


class ImageCaptioning:
    """Caption an image with BLIP and brainstorm short topics with a seq2seq model.

    Both models are loaded once at construction time and switched to eval mode.
    """

    def __init__(
        self,
        caption_model_name="Salesforce/blip-image-captioning-base",
        topic_model_name="google/flan-t5-large",
    ):
        """Load the captioning and topic-generation models.

        Args:
            caption_model_name: Hugging Face model id (or local path) of the
                BLIP image-captioning checkpoint.
            topic_model_name: Hugging Face model id (or local path) of the
                seq2seq model used to generate topics from a caption.
        """
        # Image -> caption model.
        self.blip_processor = BlipProcessor.from_pretrained(caption_model_name)
        self.blip_model = BlipForConditionalGeneration.from_pretrained(caption_model_name)
        # Prompt -> topics model. The attribute keeps its historical
        # "processor" name even though it holds a tokenizer.
        self.topic_generator_processor = AutoTokenizer.from_pretrained(topic_model_name)
        self.topic_generator_model = AutoModelForSeq2SeqLM.from_pretrained(topic_model_name)
        # Inference only: put both models in eval mode.
        self.blip_model.eval()
        self.topic_generator_model.eval()

    def generate_caption(self, image):
        """Generate caption token ids for *image* with BLIP.

        Sampling is enabled, so repeated calls on the same image may yield
        different captions.

        Args:
            image: An image in a form accepted by ``BlipProcessor``.

        Returns:
            The raw output of ``blip_model.generate`` (token ids, not text).
            Decode it with ``blip_processor.decode``, or call
            :meth:`combo_model` to get a decoded caption.
        """
        pixel_inputs = self.blip_processor(image, return_tensors="pt")
        return self.blip_model.generate(
            pixel_values=pixel_inputs["pixel_values"],
            max_new_tokens=128,
            do_sample=True,
            temperature=0.5,
            top_k=50,
            top_p=0.95,
        )

    def generate_topics(self, caption, additional_text=None, num_topics=3):
        """Ask the seq2seq model for short, creative topics describing an image.

        Args:
            caption: Text description of the image (e.g. a decoded caption
                from :meth:`generate_caption`).
            additional_text: Optional extra context blended into the prompt;
                ignored when empty or ``None``.
            num_topics: Number of candidate sequences to request. Must be >= 1.

        Returns:
            The decoded, stripped candidates that contain more than one word.
            Because of that filter the list can hold fewer than
            ``num_topics`` entries.

        Raises:
            ValueError: If ``num_topics`` is less than 1.
        """
        if num_topics < 1:
            raise ValueError(f"num_topics must be >= 1, got {num_topics}")

        base_prompt = "Generate short, creative titles or topics based on the detailed information provided:"

        # The two prompt variants differ only in the optional context line and
        # in the goal phrase of the task sentence; the rest is shared.
        if additional_text:
            context_section = f"Additional context: {additional_text}\n\n"
            task_goal = "blend the essence of the image with the additional context"
        else:
            context_section = ""
            task_goal = "encapsulate the essence of the image"

        full_prompt = (
            f"{base_prompt}\n\n"
            f"Image description: {caption}\n\n"
            f"{context_section}"
            f"Task: Create {num_topics} inventive titles or topics (2-5 words each) that {task_goal}. "
            # Trailing spaces keep the implicitly concatenated sentences apart.
            "These titles should be imaginative and suitable for use as hashtags, image titles, or starting points for discussions. "
            "IMPORTANT: Be imaginative and concise in your responses. Avoid repeating the same ideas in different words. "
            "Also make sure to provide a title/topic that relates to every context provided while following the examples listed below as a way of being creative and intuitive."
        )

        # Provide creative examples to inspire the model
        examples = """
        Creative examples to inspire your titles/topics:
        - "Misty Peaks at Dawn"
        - "Graffiti Lanes of Urbania"
        - "Chef’s Secret Ingredients"
        - "Neon Future Skylines"
        - "Puppy’s First Snow"
        - "Edge of Adventure"
        """

        # Append the examples to the prompt with a clear creative directive
        full_prompt += f"\n{examples}\nNow, inspired by these examples, create {num_topics} short and descriptive titles/topics based on the information provided.\n"

        # Debug-level log instead of an unconditional print to stdout.
        logging.getLogger(__name__).debug("Topic generation prompt:\n%s", full_prompt)

        # Prompts longer than 512 tokens are truncated. With the tokenizer's
        # default (presumably right-side) truncation, the examples and the
        # closing instruction would be dropped first; verify if captions or
        # context can get long.
        inputs = self.topic_generator_processor(
            full_prompt, return_tensors="pt", padding=True, truncation=True, max_length=512
        )
        outputs = self.topic_generator_model.generate(
            **inputs,
            num_return_sequences=num_topics,
            do_sample=True,
            temperature=0.7,
            max_length=32,  # Reduced for shorter outputs
            top_k=50,
            top_p=0.95,
            # Beam sampling requires num_return_sequences <= num_beams, so the
            # beam is widened whenever more than 5 topics are requested.
            num_beams=max(5, num_topics),
            no_repeat_ngram_size=2,
        )

        topics = [
            self.topic_generator_processor.decode(output, skip_special_tokens=True).strip()
            for output in outputs
        ]
        # Drop empty and single-word candidates.
        return [topic for topic in topics if topic and len(topic.split()) > 1]

    def combo_model(self, image, additional_text=None, num_topics=3):
        """Caption *image* and derive topics from that caption in one call.

        Args:
            image: An image in a form accepted by ``BlipProcessor``.
            additional_text: Optional extra context forwarded to
                :meth:`generate_topics`.
            num_topics: Number of topic candidates to request.

        Returns:
            A dict with ``"caption"`` (the first generated caption, decoded)
            and ``"topics"`` (the list returned by :meth:`generate_topics`).
        """
        caption_ids = self.generate_caption(image)
        caption = self.blip_processor.decode(caption_ids[0], skip_special_tokens=True)
        topics = self.generate_topics(caption, additional_text, num_topics=num_topics)

        return {
            "caption": caption,
            "topics": topics,
        }