import os
import torch
import time
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
from tld.diffusion import DiffusionTransformer
from tld.configs import LTDConfig, DenoiserConfig, DenoiserLoad
import numpy as np
from PIL import Image
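# Two models run side by side on CPU: a small latent diffusion transformer
# (tld) for image generation and PhoneLM-1.5B-Instruct for chat.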
# Image Generation Model Setup
denoiser_cfg = DenoiserConfig(
    image_size=32,
    noise_embed_dims=256,
    patch_size=2,
    embed_dim=768,
    dropout=0,
    n_layers=12,
    text_emb_size=768,
)
denoiser_load = DenoiserLoad(**{
    'dtype': torch.float32,
    'file_url': 'https://huggingface.co/apapiu/small_ldt/resolve/main/state_dict_378000.pth',
    'local_filename': 'state_dict_378000.pth',
})
cfg = LTDConfig(denoiser_cfg=denoiser_cfg, denoiser_load=denoiser_load)
diffusion_transformer = DiffusionTransformer(cfg)
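# Note: instantiating DiffusionTransformer should download and load the
# checkpoint referenced in denoiser_load above on the first run.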
# Set PyTorch to use all available CPU cores
num_cores = os.cpu_count()
torch.set_num_threads(num_cores)
print(f"Using {num_cores} CPU cores.")
# Text Model Setup
model_name = 'mllmTeam/PhoneLM-1.5B-Instruct'
model = AutoModelForCausalLM.from_pretrained(model_name, device_map='cpu', trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(model_name)
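# PhoneLM ships custom modeling code on the Hub, so trust_remote_code=True
# is required when loading it.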
def generate_text_response(question):
    start_time = time.time()
    # Build a chat-formatted prompt for the instruct model
    prompt = [{"role": "user", "content": question}]
    input_text = tokenizer.apply_chat_template(prompt, tokenize=False, add_generation_prompt=True)
    inp = tokenizer(input_text, return_tensors="pt")
    inp = {k: v.to('cpu') for k, v in inp.items()}
    out = model.generate(**inp, max_new_tokens=256, do_sample=True, temperature=0.7, top_p=0.7)
    # Decode only the newly generated tokens, skipping the echoed prompt
    text = tokenizer.decode(out[0][inp['input_ids'].shape[1]:], skip_special_tokens=True).strip()
    elapsed_time = time.time() - start_time
    print(f"Text generation time: {elapsed_time:.2f} seconds")
    return text
def generate_image(prompt, class_guidance=6, num_imgs=1, seed=11):
    start_time = time.time()
    try:
        # Generate the image
        out = diffusion_transformer.generate_image_from_text(
            prompt=prompt,
            class_guidance=class_guidance,
            num_imgs=num_imgs,
            seed=seed
        )
        # Convert a tensor output (CHW) to an HWC numpy array
        if isinstance(out, torch.Tensor):
            out = out.squeeze().permute(1, 2, 0).numpy()
        # Ensure the image is in the right format for Gradio
        if isinstance(out, np.ndarray):
            # Normalize pixel values to the 0-255 range (guard against a constant image)
            out = ((out - out.min()) / max(out.max() - out.min(), 1e-8) * 255).astype('uint8')
            out = Image.fromarray(out)
        end_time = time.time()
        print(f"Image generation time: {end_time - start_time:.2f} seconds")
        return out
    except Exception as e:
        print(f"Image generation error: {e}")
        return None
def chat_with_ai(message, history):
    # Outputs map to [msg, chatbot, img_output]: clear the textbox,
    # update the chat history, and optionally show a generated image.
    max_history_length = 1  # Keep only the most recent exchange; adjust as needed
    history = history[-max_history_length:]
    if message.startswith('@imagine'):
        # Extract the prompt after '@imagine'
        image_prompt = message.split('@imagine', 1)[1].strip()
        image = generate_image(image_prompt)
        if image:
            return "", history + [[message, "Image generated below."]], image
        else:
            return "", history + [[message, "Failed to generate image."]], None
    else:
        response = generate_text_response(message)
        return "", history + [[message, response]], None
# Create the Gradio interface
with gr.Blocks(title="BlazeChat Image Generator") as demo:
    gr.Markdown("# ⚡Fast CPU-Powered Chat & Image Generation")
    gr.Markdown("Generate text and images using advanced AI models on CPU. Use `@imagine [prompt]` to create images or chat naturally.")
    gr.Markdown("https://github.com/SanshruthR/CPU_BlazeChat")

    chatbot = gr.Chatbot()
    msg = gr.Textbox(label="Enter your message")
    submit_button = gr.Button("Submit")
    clear = gr.Button("Clear")
    img_output = gr.Image(label="Generated Image")

    # Send a message on Enter or via the Submit button
    msg.submit(chat_with_ai, [msg, chatbot], [msg, chatbot, img_output])
    submit_button.click(chat_with_ai, [msg, chatbot], [msg, chatbot, img_output])
    # Reset the chat history
    clear.click(lambda: None, None, chatbot, queue=False)

# Launch the demo
demo.launch(debug=True, ssr_mode=False)