YAML Metadata Warning: empty or missing YAML metadata in repo card (see https://huggingface.co/docs/hub/model-cards#model-card-metadata).
LLaVA-LoRA Adapter
This is a LoRA adapter for the LLaVA model, fine-tuned for spatial description tasks.
Base Model
This adapter was trained on top of `llava-hf/llava-1.5-7b-hf`; that base model must be loaded first and the adapter applied on top of it.
Training
The model was fine-tuned with LoRA using the following configuration:
- Rank (`r`): 8
- Alpha (`lora_alpha`): 32
- Target modules: q_proj, v_proj, k_proj
- Dataset: PersReFex validation set
Usage (requires `torch`, `transformers`, `peft`, and `Pillow`, plus a CUDA-capable GPU):
# Example: run the spatial-description LoRA adapter on a single image.
import torch
from peft import PeftModel
from PIL import Image
from transformers import AutoProcessor, LlavaForConditionalGeneration

BASE_MODEL_ID = "llava-hf/llava-1.5-7b-hf"
ADAPTER_ID = "ZinengTang/llava-lora-spatial"
IMAGE_PATH = "your_image_path"  # replace with the path to your own image
# Bounds only the generated continuation. Unlike `max_length`, it does not
# count the prompt, which already contains hundreds of image tokens.
MAX_NEW_TOKENS = 256

# Load the base model in bfloat16 and move it to the GPU.
base_model = LlavaForConditionalGeneration.from_pretrained(
    BASE_MODEL_ID,
    torch_dtype=torch.bfloat16,
).to("cuda")
processor = AutoProcessor.from_pretrained(BASE_MODEL_ID)

# Attach the LoRA adapter weights on top of the base model.
model = PeftModel.from_pretrained(base_model, ADAPTER_ID)

init_prompt_instruct = (
    "Describe the location of the blue sphere relative to the environment features."
)
conversation = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": init_prompt_instruct},
            # Placeholder only; the processor inserts the image tokens here.
            {"type": "image"},
        ],
    },
]
prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)

# convert("RGB") normalizes RGBA / grayscale / palette images to the 3-channel
# input the image processor expects.
speaker_image = Image.open(IMAGE_PATH).convert("RGB")

# No max_length/truncation here: truncating would cut image tokens out of the
# prompt. BatchFeature.to(device, dtype) moves every tensor to the GPU and casts
# only the floating-point ones (pixel_values) to the model's bfloat16 dtype.
inputs = processor(images=speaker_image, text=prompt, return_tensors="pt").to(
    "cuda", torch.bfloat16
)

with torch.no_grad():
    generated = model.generate(
        **inputs,
        max_new_tokens=MAX_NEW_TOKENS,
        num_beams=1,
        do_sample=True,  # sampling: output differs between runs
        temperature=0.7,
    )

# generate() returns prompt + continuation. Keep only the newly generated tokens
# so the answer no longer has to be string-split out of the decoded prompt.
new_tokens = generated[:, inputs["input_ids"].shape[1]:]
generated_message = processor.batch_decode(new_tokens, skip_special_tokens=True)[0].strip()
print(generated_message)