# Page residue from the hosting site, kept as comments so the file parses:
#   hadrakey's picture
#   Training in progress, step 1000
#   3f86748 verified
# Example adapted from https://huggingface.co/microsoft/Phi-3-vision-128k-instruct
# Import necessary libraries
from PIL import Image
import requests
from transformers import AutoModelForCausalLM
from transformers import AutoProcessor
from transformers import BitsAndBytesConfig
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
import torch
import pandas as pd
from torchmetrics.text import CharErrorRate
from peft import PeftModel, PeftConfig
# --- Model identifiers --------------------------------------------------------
# Phi-3-vision is evaluated zero-shot. The two PEFT ids are referenced only by
# the disabled fine-tuned experiments listed at the end of this section.
model_id = "microsoft/Phi-3-vision-128k-instruct"
peft_model_id = "hadrakey/alphapen_phi3"
peft_model_id_new = "hadrakey/alphapen_new_large"
trocr_base_id = "microsoft/trocr-base-handwritten"

# Phi-3-vision processor: applies the chat template and preprocesses images.
processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)

# --- TrOCR models -------------------------------------------------------------
# Fine-tuned Alphapen checkpoint and the public handwritten baseline; both are
# driven by the baseline checkpoint's processor.
model_finetune = VisionEncoderDecoderModel.from_pretrained("hadrakey/alphapen_large")
model_base = VisionEncoderDecoderModel.from_pretrained(trocr_base_id)
processor_ocr = TrOCRProcessor.from_pretrained(trocr_base_id)

# --- Phi-3-vision, 4-bit quantized --------------------------------------------
# NF4 weights with double quantization; matmuls are computed in bfloat16.
nf4_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=nf4_config,
    torch_dtype="auto",
    trust_remote_code=True,
    device_map="cuda",
)

# --- Disabled experiments (kept for reference) --------------------------------
# Fine-tuned Phi-3 through a PEFT adapter:
#   config = PeftConfig.from_pretrained(peft_model_id)
#   processor_fine = AutoProcessor.from_pretrained(config.base_model_name_or_path, trust_remote_code=True)
#   base_model = AutoModelForCausalLM.from_pretrained(config.base_model_name_or_path, device_map="auto", trust_remote_code=True, torch_dtype="auto")
#   model_finetune_phi3 = PeftModel.from_pretrained(base_model, peft_model_id)
#   model_finetune_phi3 = AutoModelForCausalLM.from_pretrained("hadrakey/alphapen_phi3", trust_remote_code=True)
# Second fine-tuned model:
#   config_new = PeftConfig.from_pretrained(peft_model_id_new)
#   model_new_finetune = AutoModelForCausalLM.from_pretrained(config_new.base_model_name_or_path, device_map="auto", trust_remote_code=True, torch_dtype="auto")
#   processor_ocr_new = AutoProcessor.from_pretrained(config_new.base_model_name_or_path, device_map="auto", trust_remote_code=True, torch_dtype="auto")
# Single-turn chat message for Phi-3-vision. "<|image_1|>" is the placeholder
# for the image passed to the processor alongside the rendered prompt.
messages = [{"role": "user", "content": """<|image_1|>\nThis image contains handwritten French characters forming a complete or partial word. The image is blurred, which makes recognition challenging. Please analyze the image to the best of your ability and provide your best guess of the French word or partial word shown, even if you're not certain. Follow these guidelines:
1. Examine the overall shape and any discernible character features.
2. Consider common French letter combinations and word patterns.
3. If you can only identify some characters, provide those as a partial word.
4. Make an educated guess based on what you can see, even if it's just a few letters.
5. If you can see any characters at all, avoid responding with "indiscernible."
Your response should be only the predicted French word or partial word, using lowercase letters unless capital letters are clearly visible. If you can see any characters or shapes at all, provide the OCR from the image.
"""}]
# Alternative (disabled) prompt:
# messages = [{"role": "user", "content": """<|image_1|>\nWhat is shown is this images ? You should only output only your guess otherwise output the OCR.
# """}]

# Demo image URL from the upstream Phi-3-vision example; the evaluation below
# does not use it.
url = "https://images.unsplash.com/photo-1528834342297-fdefb9a5a92b?ixlib=rb-4.0.3&q=85&fm=jpg&crop=entropy&cs=srgb&dl=roonz-nl-vjDbHCjHlEY-unsplash.jpg&w=640"
# image = Image.open(requests.get(url, stream=True).raw)

# Labelled test split. Reading `text` as str prevents purely numeric labels
# from being parsed as numbers, which would break the `.lower()` used when
# scoring predictions against them.
df_path = "/mnt/data1/Datasets/AlphaPen/" + "testing_data.csv"
data = pd.read_csv(df_path, dtype={"text": str})
data.dropna(inplace=True)
# Renumber rows 0..n-1. The previous index is kept as an "index" column and
# is therefore written to the output CSV as well.
data.reset_index(inplace=True)
# Evaluate on at most the first 5000 rows. `.copy()` makes `sample` an
# independent DataFrame, so result columns added later never write into a
# view of `data` (avoids pandas' SettingWithCopyWarning).
sample = data.iloc[:5000, :].copy()
root_dir = "/mnt/data1/Datasets/OCR/Alphapen/clean_data/"

# Render the chat template to a prompt string (not token ids), with the
# assistant turn opened so generation starts with the answer.
prompt = processor.tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)
# --- Evaluation ---------------------------------------------------------------
# Character error rate. Calling the metric returns the score for that call's
# (prediction, target) pair; its accumulated internal state is never read.
cer_metric = CharErrorRate()

# Per-sample predictions and CER scores, one entry per row of `sample`.
# The *_finetune / *_new lists belong to disabled experiments and stay empty.
phi_output = []
phi_finetune_output = []
inf_baseline = []
inf_finetune = []
inf_finetune_new = []
cer_phi = []
cer_phi_finetune = []
cer_trocr_fine_new = []
cer_trocr_fine = []
cer_trocr_base = []


def _normalize(text):
    """Return *text* as a lowercase string with surrounding whitespace removed.

    Applied to predictions and labels alike, so CER reflects recognition
    errors rather than case or stray leading/trailing whitespace/newlines
    emitted during generation.
    """
    return str(text).strip().lower()


def _cer(prediction, target):
    """Return the character error rate of one prediction as a Python float."""
    return cer_metric(_normalize(prediction), _normalize(target)).item()


def _phi3_predict(image):
    """Run Phi-3-vision on *image* with the shared prompt and return its reply."""
    inputs = processor(prompt, [image], return_tensors="pt").to("cuda:0")
    generate_ids = model.generate(
        **inputs,
        eos_token_id=processor.tokenizer.eos_token_id,
        max_new_tokens=500,
        do_sample=False,
    )
    # generate() returns prompt + continuation; keep only the new tokens.
    generate_ids = generate_ids[:, inputs["input_ids"].shape[1]:]
    return processor.batch_decode(
        generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )[0]


def _trocr_predict(ocr_model, pixel_values):
    """Generate and decode one TrOCR prediction from preprocessed *pixel_values*."""
    # Move inputs to wherever the model lives (no-op while both are on CPU).
    generated_ids = ocr_model.generate(pixel_values.to(ocr_model.device))
    return processor_ocr.batch_decode(generated_ids, skip_special_tokens=True)[0]


# Iterate `sample` itself rather than positional lookups into `data`, so the
# loop stays correct regardless of how `sample` was selected or indexed.
for filename, target in zip(sample["filename"], sample["text"]):
    # Decode to RGB and release the file handle immediately.
    with Image.open(root_dir + "final_cropped_rotated_" + filename) as img:
        image = img.convert("RGB")

    # Zero-shot Phi-3-vision.
    response = _phi3_predict(image)
    phi_output.append(response)
    cer_phi.append(_cer(response, target))

    # Disabled: fine-tuned Phi-3 (needs model_finetune_phi3 / processor_fine):
    #   response = <decode of model_finetune_phi3.generate(**inputs, ...)>
    #   phi_finetune_output.append(response)
    #   cer_phi_finetune.append(_cer(response, target))

    # TrOCR baseline and fine-tuned model share one preprocessed input.
    pixel_values = processor_ocr(image, return_tensors="pt").pixel_values
    generated_text_base = _trocr_predict(model_base, pixel_values)
    generated_text_fine = _trocr_predict(model_finetune, pixel_values)
    # Disabled: generated_text_fine_new via model_finetune_new / processor_ocr_new.
    inf_baseline.append(generated_text_base)
    inf_finetune.append(generated_text_fine)
    cer_trocr_fine.append(_cer(generated_text_fine, target))
    cer_trocr_base.append(_cer(generated_text_base, target))
# --- Save results -------------------------------------------------------------
# Attach predictions and per-sample CER as new columns and write the table.
# `assign` returns a new DataFrame, so nothing is ever written into a view of
# `data` (no SettingWithCopyWarning even when `sample` is a slice). Column
# order matches the keyword order below.
# Disabled experiments would add: phi3_fine, Finetune_new, cer_phi_fine,
# cer_trocr_fine_new.
sample = sample.assign(
    phi3=phi_output,
    Baseline=inf_baseline,
    Finetune=inf_finetune,
    cer_phi=cer_phi,
    cer_trocr_base=cer_trocr_base,
    cer_trocr_fine=cer_trocr_fine,
)
sample.to_csv("/mnt/data1/Datasets/AlphaPen/" + "sample_data.csv")