ardalan.mehrani committed
Commit 0a70842 · 1 Parent(s): b4837db

add embedding model

configuration_internvl_chat.py CHANGED
@@ -49,7 +49,7 @@ class InternVLChatConfig(PretrainedConfig):
         self.vision_config = InternVisionConfig(**vision_config)
         if llm_config.get('architectures')[0] == 'LlamaForCausalLM':
             self.llm_config = LlamaConfig(**llm_config)
-        elif llm_config.get('architectures')[0] == 'InternLM2ForCausalLM':
+        elif llm_config.get('architectures')[0] in ['InternLM2ForCausalLM', 'InternLM2ForSequenceClassification']:
             self.llm_config = InternLM2Config(**llm_config)
         else:
             raise ValueError('Unsupported architecture: {}'.format(llm_config.get('architectures')[0]))
modeling_internvl_chat.py CHANGED
@@ -20,7 +20,7 @@ from transformers.utils import ModelOutput, logging
 from .configuration_internvl_chat import InternVLChatConfig
 from .conversation import get_conv_template
 from .modeling_intern_vit import InternVisionModel, has_flash_attn
-from .modeling_internlm2 import InternLM2ForCausalLM
+from .modeling_internlm2 import InternLM2ForCausalLM, InternLM2ForSequenceClassification
 
 logger = logging.get_logger(__name__)
 
@@ -69,6 +69,8 @@ class InternVLChatModel(PreTrainedModel):
             self.language_model = LlamaForCausalLM(config.llm_config)
         elif config.llm_config.architectures[0] == 'InternLM2ForCausalLM':
             self.language_model = InternLM2ForCausalLM(config.llm_config)
+        elif config.llm_config.architectures[0] == 'InternLM2ForSequenceClassification':
+            self.language_model = InternLM2ForSequenceClassification(config.llm_config)
         else:
             raise NotImplementedError(f'{config.llm_config.architectures[0]} is not implemented.')
 
@@ -289,10 +291,10 @@ class InternVLChatModel(PreTrainedModel):
         return response
 
     def build_query(self, question, history, num_patches_list=None, IMG_START_TOKEN='<img>',
-                    IMG_END_TOKEN='</img>', IMG_CONTEXT_TOKEN='<IMG_CONTEXT>'):
+                    IMG_END_TOKEN='</img>', IMG_CONTEXT_TOKEN='<IMG_CONTEXT>', system_message=None):
 
         template = get_conv_template(self.template)
-        template.system_message = self.system_message
+        template.system_message = system_message or self.system_message
 
         for (old_question, old_answer) in history:
             template.append_message(template.roles[0], old_question)
@@ -308,6 +310,48 @@ class InternVLChatModel(PreTrainedModel):
         return query
 
 
+    def batch_embedding(self, tokenizer, pixel_values, questions, num_patches_list=None,
+                        IMG_START_TOKEN='<img>', IMG_END_TOKEN='</img>',
+                        IMG_CONTEXT_TOKEN='<IMG_CONTEXT>'):
+
+        img_context_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
+        self.img_context_token_id = img_context_token_id
+        assert self.img_context_token_id is not None
+
+        queries = []
+        for q, num_patches in zip(questions, num_patches_list):
+            query = self.build_query(q, [], num_patches, IMG_START_TOKEN, IMG_END_TOKEN, IMG_CONTEXT_TOKEN, system_message='')
+            query = query[30:-23]
+            queries.append(query)
+
+        tokenizer.padding_side = 'left'
+        model_inputs = tokenizer(queries, return_tensors='pt', padding=True)
+        input_ids = model_inputs['input_ids'].to(self.device)
+        attention_mask = model_inputs['attention_mask'].to(self.device)
+        template = get_conv_template(self.template)
+        eos_token_id = tokenizer.convert_tokens_to_ids(template.sep.strip())
+
+        vit_embeds = self.extract_feature(pixel_values)
+        input_embeds = self.language_model.get_input_embeddings()(input_ids)
+        B, N, C = input_embeds.shape
+        input_embeds = input_embeds.reshape(B * N, C)
+
+        input_ids = input_ids.reshape(B * N)
+        selected = (input_ids == self.img_context_token_id)
+        assert selected.sum() != 0
+        input_embeds[selected] = vit_embeds.reshape(-1, C).to(input_embeds.device)
+
+        input_embeds = input_embeds.reshape(B, N, C)
+
+        output = self.language_model(
+            inputs_embeds=input_embeds,
+            attention_mask=attention_mask,
+            output_attentions=True,
+            output_hidden_states=True,
+            return_dict=True
+        )
+        return output
+
     @torch.no_grad()
     def generate(
         self,
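
A minimal usage sketch of the new batch_embedding path, not part of the commit itself. The repo id, the dummy 448x448 pixel_values, the '<image>\n' placeholder in the questions, and the choice to pool the last hidden position are all illustrative assumptions; what is grounded in the diff is the method signature, that batch_embedding pads queries on the left, and that it returns the language-model output with output_hidden_states=True, so the final hidden state of the last (non-padded) position can serve as a per-query embedding.

import torch
from transformers import AutoModel, AutoTokenizer

# 'your-org/internvl-embedding' is a placeholder; substitute the actual repo id.
path = 'your-org/internvl-embedding'
tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True, use_fast=False)
model = AutoModel.from_pretrained(
    path, torch_dtype=torch.bfloat16, trust_remote_code=True
).eval().cuda()

# Two dummy single-tile images (1 patch each); in real use pixel_values would come
# from the repo's image preprocessing (448x448 tiles normalized with ImageNet stats).
pixel_values = torch.randn(2, 3, 448, 448, dtype=torch.bfloat16).cuda()
num_patches_list = [1, 1]
questions = ['<image>\nDescribe the scene.', '<image>\nWhat object is shown?']

with torch.no_grad():
    output = model.batch_embedding(tokenizer, pixel_values, questions, num_patches_list)

# batch_embedding tokenizes with padding_side='left', so the last position of the
# final hidden state corresponds to the last real token of every query.
embeddings = output.hidden_states[-1][:, -1, :]            # (batch, hidden_size)
embeddings = torch.nn.functional.normalize(embeddings, dim=-1)
print(embeddings.shape)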