import gradio as gr
from huggingface_hub import InferenceClient

# Initialize Inference API clients for all models.
# NOTE: the PaliGemma and Llama repo IDs are assumptions inferred from the
# variable names; swap in whichever checkpoints/endpoints you have access to.
paligemma224_client = InferenceClient("google/paligemma-3b-mix-224")
paligemma448_client = InferenceClient("google/paligemma-3b-mix-448")
paligemma896_client = InferenceClient("google/paligemma-3b-mix-896")
paligemma28b_client = InferenceClient("google/paligemma2-28b-pt-896")
llama_client = InferenceClient("meta-llama/Llama-3.2-1B-Instruct")
deepseek_client = InferenceClient("deepseek-ai/deepseek-vl2")
omniparser_client = InferenceClient("microsoft/OmniParser")
pixtral_client = InferenceClient("mistralai/Pixtral-12B-2409")
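
# The pipeline below assumes every model serves the "text-generation" task on
# the serverless Inference API, which is not guaranteed (several of these are
# vision checkpoints). An optional, illustrative helper using
# huggingface_hub.model_info to report the task a repo actually serves:
from huggingface_hub import model_info


def report_pipeline_tag(repo_id: str) -> str:
    """Return the task a repo serves, e.g. 'text-generation'."""
    return model_info(repo_id).pipeline_tag


# Example: report_pipeline_tag("mistralai/Pixtral-12B-2409")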


def enhance_prompt(prompt: str) -> str:
    # Use the three PaliGemma resolutions for prompt enhancement; the
    # text_generation() task method returns the generated string directly.
    prompt_224 = paligemma224_client.text_generation(prompt)
    prompt_448 = paligemma448_client.text_generation(prompt)
    prompt_896 = paligemma896_client.text_generation(prompt)

    # Combine all enhanced prompts into a single one.
    enhanced_prompt = (
        f"Enhanced (224): {prompt_224}\n"
        f"Enhanced (448): {prompt_448}\n"
        f"Enhanced (896): {prompt_896}"
    )

    # Refine the combined prompt once more with the 28B model.
    ultra_enhanced_prompt = paligemma28b_client.text_generation(enhanced_prompt)

    return ultra_enhanced_prompt


def generate_answer(enhanced_prompt: str) -> str:
    # Generate candidate answers with the three models: Llama, DeepSeek, and OmniParser.
    llama_answer = llama_client.text_generation(enhanced_prompt)
    deepseek_answer = deepseek_client.text_generation(enhanced_prompt)
    omniparser_answer = omniparser_client.text_generation(enhanced_prompt)

    # Combine the answers from all three models.
    combined_answer = (
        f"Llama: {llama_answer}\n"
        f"DeepSeek: {deepseek_answer}\n"
        f"OmniParser: {omniparser_answer}"
    )

    return combined_answer
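
# The three generation calls above run sequentially. Since they are
# independent, they could be issued concurrently, e.g. with the standard
# library's ThreadPoolExecutor (an optional sketch, not wired in):
#
#     from concurrent.futures import ThreadPoolExecutor
#     with ThreadPoolExecutor() as pool:
#         futures = [pool.submit(c.text_generation, enhanced_prompt)
#                    for c in (llama_client, deepseek_client, omniparser_client)]
#         llama_answer, deepseek_answer, omniparser_answer = [f.result() for f in futures]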


def enhance_answer(answer: str) -> str:
    # Polish the combined answer with the Pixtral model.
    enhanced_answer = pixtral_client.text_generation(answer)
    return enhanced_answer


def process(message: str) -> str:
    # Step 1: Enhance the prompt with the PaliGemma models
    enhanced_prompt = enhance_prompt(message)
    
    # Step 2: Generate an answer using the three models
    answer = generate_answer(enhanced_prompt)
    
    # Step 3: Enhance the generated answer using Pixtral
    final_answer = enhance_answer(answer)
    
    return final_answer
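
# Illustrative smoke test (assumes the endpoints above are reachable and an
# HF token with access is configured in the environment):
#
#     print(process("Summarize the benefits of unit testing"))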


# Chat handler for the Gradio interface.
def respond(message, history: list[tuple[str, str]], system_message, max_tokens, temperature, top_p):
    # Assemble the system message and history in chat-completion format.
    # NOTE: the pipeline currently uses only the latest message; the
    # conversation context and the sampling parameters (max_tokens,
    # temperature, top_p) are accepted to match the ChatInterface inputs
    # but are not yet forwarded to the models.
    messages = [{"role": "system", "content": system_message}]
    for user_msg, assistant_msg in history:
        if user_msg:
            messages.append({"role": "user", "content": user_msg})
        if assistant_msg:
            messages.append({"role": "assistant", "content": assistant_msg})

    # Run the three-stage pipeline on the latest message.
    final_answer = process(message)

    # Stream the answer character by character to simulate token streaming.
    response = ""
    for char in final_answer:
        response += char
        yield response


# Gradio interface setup
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(value="You are a helpful assistant.", label="System message"),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
    ],
)

if __name__ == "__main__":
    demo.launch()
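    # Outside a Hugging Face Space you may want an explicit bind, e.g.
    # demo.launch(server_name="0.0.0.0", server_port=7860); both are
    # standard Gradio launch() parameters.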