Alan Rabello committed
Commit d400378 · 1 Parent(s): 0d413ec
Add img extensions
Files changed: caption.py (+22 −20)
caption.py
CHANGED
@@ -11,7 +11,7 @@ CHECKPOINT_PATH = Path("./checkpoint")
 LLMA_CHECKPOINT = "John6666/Llama-3.1-8B-Lexi-Uncensored-V2-nf4"
 WORDS=200
 PROMPT = "In one paragraph, write a very descriptive caption for this image, describe all objects, characters and their actions, describe in detail what is happening and their emotions. Include information about lighting, the style of this image and information about camera angle within {word_count} words. Don't create any title for the image."
-
+IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.bmp', '.webp')
 HF_TOKEN = os.environ.get("HF_TOKEN", None)
 
 class ImageAdapter(nn.Module):
@@ -130,9 +130,6 @@ def proc_img(input_image):
     ], dim=1).to(device)
     attention_mask = torch.ones_like(input_ids)
 
-    # Debugging
-    #print(f"Input to model: {repr(tokenizer.decode(input_ids[0]))}")
-
     #generate_ids = text_model.generate(input_ids, inputs_embeds=inputs_embeds, attention_mask=attention_mask, max_new_tokens=300, do_sample=True, top_k=10, temperature=0.5, suppress_tokens=None)
     generate_ids = text_model.generate(input_ids, inputs_embeds=input_embeds, attention_mask=attention_mask, max_new_tokens=300, do_sample=True, suppress_tokens=None) # Uses the default which is temp=0.6, top_p=0.9
 
@@ -150,10 +147,6 @@ def describe_image(image_path):
         print(f"File not found: {image_path}")
         return
 
-    if not image_path.lower().endswith(".png"):
-        print("File must be PNG.")
-        return
-
     image = Image.open(image_path).convert("RGB")
 
     description = proc_img(image)
@@ -173,27 +166,36 @@ if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Caption all PNG image files in a folder")
     parser.add_argument("folder_path", type=str, help="Folder containing images.")
     parser.add_argument("--prompt", type=str, help="Prompt to ask a caption.", default=None, required=False)
+    parser.add_argument("--output_dir", type=str, help="Output dir.", default=None, required=False)
     args = parser.parse_args()
 
-    #
+    # Prompt
+    if args.prompt is None:
+        prompt_str = PROMPT.format(word_count=WORDS)
+    else:
+        prompt_str = args.prompt
+
+    # Process all images in the folder
     folder_path = Path(args.folder_path)
     if not folder_path.is_dir():
         print(f"Error: {folder_path} is not a valid directory.")
         exit(1)
 
-    png_files = list(folder_path.glob("*.png"))
-    if not png_files:
-        print(f"No PNG files found in the directory: {folder_path}")
-        exit(1)
-
     # Prompt
-    if args.prompt is None:
-        prompt_str = PROMPT.format(word_count=WORDS)
+    if args.output_dir is None:
+        output_dir = folder_path
     else:
-        prompt_str = args.prompt
+        output_dir = args.output_dir
+
+    img_files = [f for f in folder_path.iterdir() if f.suffix.lower() in IMAGE_EXTENSIONS]
+    img_files = [f for f in img_files if not Path(output_dir,f"{f.stem}.txt").exists()]
+
+    if not img_files:
+        print(f"No image files without caption found in the directory: {folder_path}")
+        exit(1)
 
-    total = len(png_files)
-    print(f"Found {total} PNG files. Processing...")
+    total = len(img_files)
+    print(f"Found {total} IMAGE files without caption. Processing...")
 
     device = "cuda" if torch.cuda.is_available() else "cpu"
 
@@ -228,7 +230,7 @@ if __name__ == "__main__":
     image_adapter.to(device)
 
     curr = 1
-    for image_path in png_files:
+    for image_path in img_files:
         print(f"Processing image {curr} of {total}: {image_path}")
         curr += 1
         describe_image(str(image_path))
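For reference, the file-selection logic added in this commit can be read in isolation. Below is a minimal, self-contained sketch of what the new lines do; the "./images" folder name is hypothetical, while IMAGE_EXTENSIONS and the two list comprehensions mirror the added code:

    from pathlib import Path

    IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.bmp', '.webp')

    folder_path = Path("./images")   # hypothetical input folder
    output_dir = folder_path         # default when --output_dir is not given

    # Keep every file whose extension is a supported image type (case-insensitive).
    img_files = [f for f in folder_path.iterdir() if f.suffix.lower() in IMAGE_EXTENSIONS]
    # Drop images that already have a "<stem>.txt" caption in output_dir.
    img_files = [f for f in img_files if not Path(output_dir, f"{f.stem}.txt").exists()]

    print(f"{len(img_files)} image(s) still need captions")

Because the match uses f.suffix.lower(), mixed-case extensions such as .JPG are also picked up, and skipping files that already have a caption text file lets an interrupted run be resumed without redoing finished images.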