Use it in ComfyUI.

#4
by razvanab - opened

I made a ComfyUI node (with the help of AI) that uses this model to enhance your prompts.

import torch
import random
import hashlib
from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM

class PromptEnhancer:
    """ComfyUI node that rewrites a short prompt into a more detailed one.

    Wraps the ``gokaygokay/Flux-Prompt-Enhance`` seq2seq model. Each input is
    prefixed with ``"enhance prompt: "``, a continuation is generated, and the
    decoded text is returned as ``"Enhanced Prompt: <text>"``.
    """

    # Checkpoint loaded by every instance of this node.
    _MODEL_CHECKPOINT = "gokaygokay/Flux-Prompt-Enhance"

    # Task prefix prepended to every prompt before tokenization.
    _PROMPT_PREFIX = "enhance prompt: "

    # (tokenizer, model) pairs shared across instances, keyed by
    # (checkpoint, device). This keeps several nodes in one workflow from each
    # loading a separate copy of the same model.
    _loaded = {}

    def __init__(self):
        # Prefer the GPU when one is available.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        self.model_checkpoint = self._MODEL_CHECKPOINT
        self.tokenizer, self.model = self._load(self.model_checkpoint, self.device)

        # Informational state updated after each call (see NODE_TITLE /
        # GENERATED_PROMPT below).
        self.node_title = "Prompt Enhancer"
        self.generated_prompt = ""

    @classmethod
    def _load(cls, checkpoint, device):
        """Return the cached ``(tokenizer, model)`` pair, loading it on first use."""
        key = (checkpoint, device)
        if key not in cls._loaded:
            tokenizer = AutoTokenizer.from_pretrained(checkpoint)
            model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint).to(device)
            cls._loaded[key] = (tokenizer, model)
        return cls._loaded[key]

    @classmethod
    def INPUT_TYPES(cls):
        """Declare the node's input sockets and widgets for ComfyUI."""
        return {
            "required": {
                "prompt": ("STRING",),
                "seed": ("INT", {"default": 42, "min": 0, "max": 4294967295}),  # Default seed, larger range
                # NOTE(review): transformers rejects a repetition penalty of 0.0;
                # consider a strictly positive minimum.
                "repetition_penalty": ("FLOAT", {"default": 1.2, "min": 0.0, "max": 10.0}),  # Default repetition penalty
                "max_target_length": ("INT", {"default": 256, "min": 1, "max": 1024}),  # Default max target length
                # 0.0 selects greedy decoding (see _enhance_one).
                "temperature": ("FLOAT", {"default": 0.7, "min": 0.0, "max": 1.0}),  # Default temperature
                "top_k": ("INT", {"default": 50, "min": 1, "max": 1000}),  # Default top-k
                "top_p": ("FLOAT", {"default": 0.9, "min": 0.0, "max": 1.0}),  # Default top-p
            },
            "optional": {
                # NOTE(review): "LIST" is a custom socket type; only outputs that
                # declare the same type string can connect here.
                "prompts_list": ("LIST",),  # List of prompts
            }
        }

    RETURN_TYPES = ("STRING",)  # Return only one string: the enhanced prompt
    FUNCTION = "enhance_prompt"
    CATEGORY = "TextEnhancement"

    def generate_large_seed(self, seed, prompt):
        """Derive a 32-bit RNG seed from ``seed`` and ``prompt``.

        Both values are hashed together with SHA-256, so the same user seed
        gives different (but reproducible) samples for different prompts.
        """
        digest = hashlib.sha256(f"{seed}_{prompt}".encode()).hexdigest()
        return int(digest, 16) % (2**32)

    def enhance_prompt(self, prompt, seed=42, repetition_penalty=1.2, max_target_length=256, temperature=0.7, top_k=50, top_p=0.9, prompts_list=None):
        """Enhance ``prompt``, or every entry of ``prompts_list`` when given.

        Args:
            prompt: Prompt to enhance. It is always mixed into the RNG seed,
                even when ``prompts_list`` is used.
            seed: User seed, combined with ``prompt`` by ``generate_large_seed``.
            repetition_penalty: Forwarded to ``model.generate``.
            max_target_length: Forwarded to ``model.generate`` as ``max_length``.
            temperature: Sampling temperature. Values ``<= 0`` switch to greedy
                decoding, because sampling requires a strictly positive
                temperature.
            top_k: Top-k sampling cutoff (used only when sampling).
            top_p: Nucleus sampling cutoff (used only when sampling).
            prompts_list: Optional prompts to enhance instead of ``prompt``.
                ``None`` or an empty list means "enhance ``prompt`` only".

        Returns:
            A 1-tuple containing one string, matching ``RETURN_TYPES``. When
            several prompts are processed, their enhanced texts are joined
            with newlines.
        """
        # Seed the global RNGs so a given (seed, prompt) pair is reproducible.
        large_seed = self.generate_large_seed(seed, prompt)
        torch.manual_seed(large_seed)
        random.seed(large_seed)

        single = not prompts_list
        prompts = [prompt] if single else list(prompts_list)

        enhanced = [
            self._enhance_one(p, repetition_penalty, max_target_length, temperature, top_k, top_p)
            for p in prompts
        ]

        # The model exposes no score here; 1.0 is a fixed placeholder.
        confidence_score = 1.0

        if single:
            self.node_title = f"Prompt Enhancer (Confidence: {confidence_score:.2f})"
            self.generated_prompt = enhanced[0]
            return (enhanced[0],)

        self.node_title = "Prompt Enhancer (Multiple Prompts)"
        self.generated_prompt = "Multiple Prompts"
        # ComfyUI expects one value per RETURN_TYPES entry, so the list is
        # collapsed into a single STRING output.
        return ("\n".join(enhanced),)

    def _enhance_one(self, text, repetition_penalty, max_target_length, temperature, top_k, top_p):
        """Generate an enhancement for one prompt; return ``"Enhanced Prompt: <text>"``."""
        input_ids = self.tokenizer(self._PROMPT_PREFIX + text, return_tensors="pt").input_ids.to(self.device)

        # Draw a per-prompt seed from the RNG seeded in enhance_prompt and
        # reseed with it. This is kept as-is so existing seeds keep producing
        # the same output.
        sub_seed = torch.randint(0, 2**32 - 1, (1,)).item()
        torch.manual_seed(sub_seed)
        random.seed(sub_seed)

        if temperature > 0:
            decoding = dict(do_sample=True, temperature=float(temperature), top_k=top_k, top_p=top_p)
        else:
            # Sampling rejects temperature <= 0; use greedy decoding instead.
            decoding = dict(do_sample=False)

        outputs = self.model.generate(
            input_ids,
            max_length=max_target_length,
            num_return_sequences=1,
            repetition_penalty=repetition_penalty,
            **decoding,
        )

        final_answer = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        confidence_score = 1.0  # Placeholder; no score is computed.
        print(f"Generated Prompt: {final_answer} (Confidence: {confidence_score:.2f})")
        return f"Enhanced Prompt: {final_answer}"

    @property
    def NODE_TITLE(self):
        """Title reflecting the most recent call to ``enhance_prompt``."""
        return self.node_title

    @property
    def GENERATED_PROMPT(self):
        """Enhanced prompt from the last single-prompt call, or ``"Multiple Prompts"``."""
        return self.generated_prompt

# Registry read by ComfyUI: internal node name -> node class to instantiate.
NODE_CLASS_MAPPINGS = dict(
    PromptEnhancer=PromptEnhancer,
)

# Titles shown to users in the ComfyUI node menu, keyed by the same internal names.
NODE_DISPLAY_NAME_MAPPINGS = dict(
    PromptEnhancer="Prompt Enhancer",
)

Sign up or log in to comment