Yosemat committed
Commit fd8a4dd · verified · 1 Parent(s): 8eeaa19

Upload folder using huggingface_hub

Files changed (1):
  modeling_internlm_xcomposer2.py (+57 -22)
modeling_internlm_xcomposer2.py CHANGED
@@ -35,6 +35,7 @@ from transformers import (
     StoppingCriteriaList,
     set_seed,
 )
+from transformers import PreTrainedTokenizer
 from transformers.generation.streamers import BaseStreamer
 from transformers.modeling_outputs import CausalLMOutputWithPast
 from transformers.utils import (
@@ -52,6 +53,8 @@ from .modeling_internlm2 import (
 )
 
 _CONFIG_FOR_DOC = "InternLMXcomposer2Config"
+FROM_TOKEN_1 = "[UNUSED_TOKEN_146]"
+FROM_TOKEN_2 = "[UNUSED_TOKEN_145]"
 
 image_extensions = {".jpg", ".jpeg", ".png", ".gif", ".bmp", ".webp"}
 video_extensions = {".mp4", ".avi", ".mkv", ".mov", ".wmv"}
@@ -103,7 +106,7 @@ class InternLMXComposer2ForCausalLM(InternLM2PreTrainedModel):
         self.model = InternLM2Model(config)
         self.vocab_size = config.vocab_size
         self.output = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
-        self.tokenizer = None
+        self.tokenizer: PreTrainedTokenizer = None  # type: ignore
         self.hd_num = 25
         self.font = get_font()
 
@@ -245,12 +248,12 @@ class InternLMXComposer2ForCausalLM(InternLM2PreTrainedModel):
         self.max_length = max_length
         prompt = ""
         if meta_instruction:
-            prompt += (
-                f"""[UNUSED_TOKEN_146]system\n{meta_instruction}[UNUSED_TOKEN_145]\n"""
-            )
+            prompt += f"""{FROM_TOKEN_1}system\n{meta_instruction}{FROM_TOKEN_2}\n"""
         for record in history:
-            prompt += f"""[UNUSED_TOKEN_146]user\n{record[0]}[UNUSED_TOKEN_145]\n[UNUSED_TOKEN_146]assistant\n{record[1]}[UNUSED_TOKEN_145]\n"""
-        prompt += f"""[UNUSED_TOKEN_146]user\n{query}[UNUSED_TOKEN_145]\n[UNUSED_TOKEN_146]assistant\n"""
+            prompt += f"""{FROM_TOKEN_1}user\n{record[0]}{FROM_TOKEN_2}\n{FROM_TOKEN_1}assistant\n{record[1]}{FROM_TOKEN_2}\n"""
+        prompt += (
+            f"""{FROM_TOKEN_1}user\n{query}{FROM_TOKEN_2}\n{FROM_TOKEN_1}assistant\n"""
+        )
 
         image_nums = len(image)
         if image_nums == 1 and prompt.find("<ImageHere>") == -1:
@@ -587,7 +590,7 @@ class InternLMXComposer2ForCausalLM(InternLM2PreTrainedModel):
         shift_labels = labels[..., 1:].contiguous()
         # Flatten the tokens
         loss_fct = CrossEntropyLoss()
-        shift_logits = shift_logits.view(-1, self.config.vocab_size)
+        shift_logits = shift_logits.view(-1, self.vocab_size)
         shift_labels = shift_labels.view(-1)
         # Enable model parallelism
         shift_labels = shift_labels.to(shift_logits.device)
@@ -676,12 +679,14 @@ class InternLMXComposer2ForCausalLM(InternLM2PreTrainedModel):
     ):
         prompt = ""
         if meta_instruction:
-            prompt += f"""<s>[UNUSED_TOKEN_146]system\n{meta_instruction}[UNUSED_TOKEN_145]\n"""
+            prompt += f"""<s>{FROM_TOKEN_1}system\n{meta_instruction}{FROM_TOKEN_2}\n"""
         else:
             prompt += "<s>"
         for record in history:
-            prompt += f"""[UNUSED_TOKEN_146]user\n{record[0]}[UNUSED_TOKEN_145]\n[UNUSED_TOKEN_146]assistant\n{record[1]}[UNUSED_TOKEN_145]\n"""
-        prompt += f"""[UNUSED_TOKEN_146]user\n{query}[UNUSED_TOKEN_145]\n[UNUSED_TOKEN_146]assistant\n"""
+            prompt += f"""{FROM_TOKEN_1}user\n{record[0]}{FROM_TOKEN_2}\n{FROM_TOKEN_1}assistant\n{record[1]}{FROM_TOKEN_2}\n"""
+        prompt += (
+            f"""{FROM_TOKEN_1}user\n{query}{FROM_TOKEN_2}\n{FROM_TOKEN_1}assistant\n"""
+        )
         return tokenizer([prompt], return_tensors="pt")
 
     @torch.no_grad()
@@ -724,7 +729,7 @@ class InternLMXComposer2ForCausalLM(InternLM2PreTrainedModel):
         # also add end-of-assistant token in eos token id to avoid unnecessary generation
         eos_token_id = [
             tokenizer.eos_token_id,
-            tokenizer.convert_tokens_to_ids(["[UNUSED_TOKEN_145]"])[0],
+            tokenizer.convert_tokens_to_ids([FROM_TOKEN_2])[0],
         ]
         outputs = self.generate(
             **inputs,
@@ -745,7 +750,7 @@ class InternLMXComposer2ForCausalLM(InternLM2PreTrainedModel):
         else:
             outputs = outputs[0].cpu().tolist()
         response = tokenizer.decode(outputs, skip_special_tokens=True)
-        response = response.split("[UNUSED_TOKEN_145]")[0]
+        response = response.split(FROM_TOKEN_2)[0]
         history = history + [(query, response)]
         return response, history
 
@@ -807,8 +812,8 @@ class InternLMXComposer2ForCausalLM(InternLM2PreTrainedModel):
         response = generate[0].tolist()
         response = self.tokenizer.decode(response, skip_special_tokens=True)  # type: ignore
         # remove eoa
-        response = response.replace("[UNUSED_TOKEN_145]", "")
-        response = response.replace("[UNUSED_TOKEN_146]", "")
+        response = response.replace(FROM_TOKEN_2, "")
+        response = response.replace(FROM_TOKEN_1, "")
 
         return response
 
@@ -847,8 +852,8 @@ class InternLMXComposer2ForCausalLM(InternLM2PreTrainedModel):
         response = generate[0].tolist()
         response = self.tokenizer.decode(response, skip_special_tokens=True)  # type: ignore
         # remove eoa
-        response = response.replace("[UNUSED_TOKEN_145]", "")
-        out = response.replace("[UNUSED_TOKEN_146]", "")
+        response = response.replace(FROM_TOKEN_2, "")
+        out = response.replace(FROM_TOKEN_1, "")
         image_type = "random"
         pattern = r"""https://source\.unsplash\.com/random/(\d+)x(\d+)/\?([^'"]+)"""
         if image_type == "placeholder":
@@ -900,8 +905,8 @@ class InternLMXComposer2ForCausalLM(InternLM2PreTrainedModel):
         response = generate[0].tolist()
         response = self.tokenizer.decode(response, skip_special_tokens=True)  # type: ignore
         # remove eoa
-        response = response.replace("[UNUSED_TOKEN_145]", "")
-        html = response.replace("[UNUSED_TOKEN_146]", "")
+        response = response.replace(FROM_TOKEN_2, "")
+        html = response.replace(FROM_TOKEN_1, "")
 
         if seed != -1:
             set_random_seed(seed, set_cudnn=True)
@@ -923,8 +928,8 @@ class InternLMXComposer2ForCausalLM(InternLM2PreTrainedModel):
         response = generate[0].tolist()
         response = self.tokenizer.decode(response, skip_special_tokens=True)  # type: ignore
         # remove eoa
-        response = response.replace("[UNUSED_TOKEN_145]", "")
-        js = response.replace("[UNUSED_TOKEN_146]", "")
+        response = response.replace(FROM_TOKEN_2, "")
+        js = response.replace(FROM_TOKEN_1, "")
 
         if re.search(r"</script>", html):
             js = re.findall(r"<script>([\s\S]*?)<\/script>", js)
@@ -983,8 +988,8 @@ class InternLMXComposer2ForCausalLM(InternLM2PreTrainedModel):
         response = generate[0].tolist()
         response = self.tokenizer.decode(response, skip_special_tokens=True)  # type: ignore
         # remove eoa
-        response = response.replace("[UNUSED_TOKEN_145]", "")
-        out = response.replace("[UNUSED_TOKEN_146]", "")
+        response = response.replace(FROM_TOKEN_2, "")
+        out = response.replace(FROM_TOKEN_1, "")
         image_type = "random"
         pattern = r"""https://source\.unsplash\.com/random/(\d+)x(\d+)/\?([^'"]+)"""
         if image_type == "placeholder":
@@ -995,3 +1000,33 @@
         with open(task.replace(" ", "_") + ".html", "w") as f:
             f.write(out)
         return out
+
+    def add_tokens(self, new_tokens: list[str]):
+        self.tokenizer.add_tokens(new_tokens)  # type: ignore
+        self.model.resize_token_embeddings(len(self.tokenizer))
+        self.vocab_size = len(self.tokenizer)
+
+        # self.output needs to be resized accordingly, but without losing the weights
+        new_output = nn.Linear(
+            self.model.config.hidden_size,
+            self.vocab_size,
+            bias=False,
+            dtype=self.output.weight.dtype,
+            device=self.output.weight.device,
+        ).to(self.device)
+        new_output.weight.data[: self.output.weight.shape[0]] = self.output.weight.data
+        dummy_input_for_output = torch.zeros(
+            1,
+            1,
+            self.model.config.hidden_size,
+            device=new_output.weight.device,
+            dtype=new_output.weight.dtype,
+        ).type_as(new_output.weight)
+        # Check if output has same behavior
+        dummy_old_output: torch.Tensor = self.output(dummy_input_for_output)
+        dummy_new_output = new_output(dummy_input_for_output)
+        assert dummy_old_output.allclose(
+            dummy_new_output[:, :, : self.output.weight.shape[0]]
+        )
+
+        self.output = new_output
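
The two module-level constants only give names to the chat-template markers that were previously repeated as string literals: [UNUSED_TOKEN_146] opens a role block and [UNUSED_TOKEN_145] ends an utterance. The prompt layout itself is unchanged. A standalone sketch of the string the rewritten loops build for a one-turn history (the example messages are made up):

# Standalone sketch of the prompt layout built by the rewritten chat code.
FROM_TOKEN_1 = "[UNUSED_TOKEN_146]"  # role-start marker
FROM_TOKEN_2 = "[UNUSED_TOKEN_145]"  # end-of-utterance marker

history = [("What is in this image?", "A cat sleeping on a sofa.")]  # example turn
query = "What color is the cat?"

prompt = ""
for user_msg, assistant_msg in history:
    prompt += (
        f"{FROM_TOKEN_1}user\n{user_msg}{FROM_TOKEN_2}\n"
        f"{FROM_TOKEN_1}assistant\n{assistant_msg}{FROM_TOKEN_2}\n"
    )
prompt += f"{FROM_TOKEN_1}user\n{query}{FROM_TOKEN_2}\n{FROM_TOKEN_1}assistant\n"
print(prompt)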
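
The new add_tokens method grows the input embeddings via resize_token_embeddings, then rebuilds self.output by hand, because the output projection is a plain nn.Linear on the wrapper class rather than a tied LM head inside InternLM2Model. This is also why the loss computation above now reshapes logits with self.vocab_size instead of self.config.vocab_size: add_tokens updates the instance attribute but not the config. A minimal usage sketch, assuming the checkpoint loads with trust_remote_code=True; the local path and the added token strings are illustrative:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

path = "./internlm-xcomposer2-checkpoint"  # hypothetical local checkpoint
tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    path, torch_dtype=torch.float16, trust_remote_code=True
)
model.tokenizer = tokenizer  # add_tokens reads self.tokenizer, so attach it first

old_vocab = model.vocab_size
model.add_tokens(["[NEW_TOKEN_1]", "[NEW_TOKEN_2]"])  # hypothetical new tokens

# Tokenizer, input embeddings, and output head should all agree afterwards.
assert model.vocab_size == old_vocab + 2
assert model.output.weight.shape[0] == len(model.tokenizer)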