Alan Rabello commited on
Commit
d400378
·
1 Parent(s): 0d413ec

Add img extensions

Browse files
Files changed (1) hide show
  1. caption.py +22 -20
caption.py CHANGED
@@ -11,7 +11,7 @@ CHECKPOINT_PATH = Path("./checkpoint")
11
  LLMA_CHECKPOINT = "John6666/Llama-3.1-8B-Lexi-Uncensored-V2-nf4"
12
  WORDS=200
13
  PROMPT = "In one paragraph, write a very descriptive caption for this image, describe all objects, characters and their actions, describe in detail what is happening and their emotions. Include information about lighting, the style of this image and information about camera angle within {word_count} words. Don't create any title for the image."
14
-
15
  HF_TOKEN = os.environ.get("HF_TOKEN", None)
16
 
17
  class ImageAdapter(nn.Module):
@@ -130,9 +130,6 @@ def proc_img(input_image):
130
  ], dim=1).to(device)
131
  attention_mask = torch.ones_like(input_ids)
132
 
133
- # Debugging
134
- #print(f"Input to model: {repr(tokenizer.decode(input_ids[0]))}")
135
-
136
  #generate_ids = text_model.generate(input_ids, inputs_embeds=inputs_embeds, attention_mask=attention_mask, max_new_tokens=300, do_sample=True, top_k=10, temperature=0.5, suppress_tokens=None)
137
  generate_ids = text_model.generate(input_ids, inputs_embeds=input_embeds, attention_mask=attention_mask, max_new_tokens=300, do_sample=True, suppress_tokens=None) # Uses the default which is temp=0.6, top_p=0.9
138
 
@@ -150,10 +147,6 @@ def describe_image(image_path):
150
  print(f"File not found: {image_path}")
151
  return
152
 
153
- if not image_path.lower().endswith(".png"):
154
- print("File must be PNG.")
155
- return
156
-
157
  image = Image.open(image_path).convert("RGB")
158
 
159
  description = proc_img(image)
@@ -173,27 +166,36 @@ if __name__ == "__main__":
173
  parser = argparse.ArgumentParser(description="Caption all PNG image files in a folder")
174
  parser.add_argument("folder_path", type=str, help="Folder containing images.")
175
  parser.add_argument("--prompt", type=str, help="Prompt to ask a caption.", default=None, required=False)
 
176
  args = parser.parse_args()
177
 
178
- # Process all PNG images in the folder
 
 
 
 
 
 
179
  folder_path = Path(args.folder_path)
180
  if not folder_path.is_dir():
181
  print(f"Error: {folder_path} is not a valid directory.")
182
  exit(1)
183
 
184
- png_files = list(folder_path.glob("*.png"))
185
- if not png_files:
186
- print(f"No PNG files found in the directory: {folder_path}")
187
- exit(1)
188
-
189
  # Prompt
190
- if args.prompt is None:
191
- prompt_str = PROMPT.format(word_count=WORDS)
192
  else:
193
- prompt_str = args.prompt
 
 
 
 
 
 
 
194
 
195
- total = len(png_files)
196
- print(f"Found {total} PNG files. Processing...")
197
 
198
  device = "cuda" if torch.cuda.is_available() else "cpu"
199
 
@@ -228,7 +230,7 @@ if __name__ == "__main__":
228
  image_adapter.to(device)
229
 
230
  curr = 1
231
- for image_path in png_files:
232
  print(f"Processing image {curr} of {total}: {image_path}")
233
  curr += 1
234
  describe_image(str(image_path))
 
11
  LLMA_CHECKPOINT = "John6666/Llama-3.1-8B-Lexi-Uncensored-V2-nf4"
12
  WORDS=200
13
  PROMPT = "In one paragraph, write a very descriptive caption for this image, describe all objects, characters and their actions, describe in detail what is happening and their emotions. Include information about lighting, the style of this image and information about camera angle within {word_count} words. Don't create any title for the image."
14
+ IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.bmp', '.webp')
15
  HF_TOKEN = os.environ.get("HF_TOKEN", None)
16
 
17
  class ImageAdapter(nn.Module):
 
130
  ], dim=1).to(device)
131
  attention_mask = torch.ones_like(input_ids)
132
 
 
 
 
133
  #generate_ids = text_model.generate(input_ids, inputs_embeds=inputs_embeds, attention_mask=attention_mask, max_new_tokens=300, do_sample=True, top_k=10, temperature=0.5, suppress_tokens=None)
134
  generate_ids = text_model.generate(input_ids, inputs_embeds=input_embeds, attention_mask=attention_mask, max_new_tokens=300, do_sample=True, suppress_tokens=None) # Uses the default which is temp=0.6, top_p=0.9
135
 
 
147
  print(f"File not found: {image_path}")
148
  return
149
 
 
 
 
 
150
  image = Image.open(image_path).convert("RGB")
151
 
152
  description = proc_img(image)
 
166
  parser = argparse.ArgumentParser(description="Caption all PNG image files in a folder")
167
  parser.add_argument("folder_path", type=str, help="Folder containing images.")
168
  parser.add_argument("--prompt", type=str, help="Prompt to ask a caption.", default=None, required=False)
169
+ parser.add_argument("--output_dir", type=str, help="Output dir.", default=None, required=False)
170
  args = parser.parse_args()
171
 
172
+ # Prompt
173
+ if args.prompt is None:
174
+ prompt_str = PROMPT.format(word_count=WORDS)
175
+ else:
176
+ prompt_str = args.prompt
177
+
178
+ # Process all images in the folder
179
  folder_path = Path(args.folder_path)
180
  if not folder_path.is_dir():
181
  print(f"Error: {folder_path} is not a valid directory.")
182
  exit(1)
183
 
 
 
 
 
 
184
  # Prompt
185
+ if args.output_dir is None:
186
+ output_dir = folder_path
187
  else:
188
+ output_dir = args.output_dir
189
+
190
+ img_files = [f for f in folder_path.iterdir() if f.suffix.lower() in IMAGE_EXTENSIONS]
191
+ img_files = [f for f in img_files if not Path(output_dir,f"{f.stem}.txt").exists()]
192
+
193
+ if not img_files:
194
+ print(f"No image files without caption found in the directory: {folder_path}")
195
+ exit(1)
196
 
197
+ total = len(img_files)
198
+ print(f"Found {total} IMAGE files without caption. Processing...")
199
 
200
  device = "cuda" if torch.cuda.is_available() else "cpu"
201
 
 
230
  image_adapter.to(device)
231
 
232
  curr = 1
233
+ for image_path in img_files:
234
  print(f"Processing image {curr} of {total}: {image_path}")
235
  curr += 1
236
  describe_image(str(image_path))