wenhu commited on
Commit
b19f216
·
verified ·
1 Parent(s): 014c434

Update model/model_manager.py

Browse files
Files changed (1) hide show
  1. model/model_manager.py +22 -18
model/model_manager.py CHANGED
@@ -18,6 +18,7 @@ class ModelManager:
18
  self.model_vg_list = VIDEO_GENERATION_MODELS
19
  self.excluding_model_list = MUSEUM_UNSUPPORTED_MODELS
20
  self.desired_model_list = DESIRED_APPEAR_MODEL
 
21
  self.loaded_models = {}
22
 
23
  def load_model_pipe(self, model_name):
@@ -28,35 +29,38 @@ class ModelManager:
28
  pipe = self.loaded_models[model_name]
29
  return pipe
30
 
31
- @spaces.GPU(duration=20)
32
- def NSFW_filter(self, prompt):
33
  model_id = "meta-llama/Meta-Llama-Guard-2-8B"
34
- device = "cuda"
35
  dtype = torch.bfloat16
36
- tokenizer = AutoTokenizer.from_pretrained(model_id, token=os.environ['HF_GUARD'])
37
- model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=dtype, device_map=device, token=os.environ['HF_GUARD'])
 
 
 
38
  chat = [{"role": "user", "content": prompt}]
39
- input_ids = tokenizer.apply_chat_template(chat, return_tensors="pt").to(device)
40
- output = model.generate(input_ids=input_ids, max_new_tokens=100, pad_token_id=0)
 
41
  prompt_len = input_ids.shape[-1]
42
- result = tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True)
43
  return result
44
 
45
  @spaces.GPU(duration=120)
46
  def generate_image_ig(self, prompt, model_name):
47
- #if self.NSFW_filter(prompt) == 'safe':
48
- pipe = self.load_model_pipe(model_name)
49
- result = pipe(prompt=prompt)
50
- # else:
51
- # result = ''
52
  return result
53
 
54
  def generate_image_ig_api(self, prompt, model_name):
55
- # if self.NSFW_filter(prompt) == 'safe':
56
- pipe = self.load_model_pipe(model_name)
57
- result = pipe(prompt=prompt)
58
- # else:
59
- # result = ''
60
  return result
61
 
62
  def generate_image_ig_museum(self, model_name):
 
18
  self.model_vg_list = VIDEO_GENERATION_MODELS
19
  self.excluding_model_list = MUSEUM_UNSUPPORTED_MODELS
20
  self.desired_model_list = DESIRED_APPEAR_MODEL
21
+ self.load_guard()
22
  self.loaded_models = {}
23
 
24
  def load_model_pipe(self, model_name):
 
29
  pipe = self.loaded_models[model_name]
30
  return pipe
31
 
32
+ def load_guard(self)
 
33
  model_id = "meta-llama/Meta-Llama-Guard-2-8B"
34
+ device = "cuda" if torch.cuda.is_available() else "cpu"
35
  dtype = torch.bfloat16
36
+ self.tokenizer = AutoTokenizer.from_pretrained(model_id, token=os.environ['HF_GUARD'])
37
+ self.guard = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=dtype, device_map=device, token=os.environ['HF_GUARD'])
38
+
39
+ @spaces.GPU(duration=30)
40
+ def NSFW_filter(self, prompt):
41
  chat = [{"role": "user", "content": prompt}]
42
+ input_ids = self.tokenizer.apply_chat_template(chat, return_tensors="pt").to(device)
43
+ self.guard.cuda()
44
+ output = self.guard.generate(input_ids=input_ids, max_new_tokens=100, pad_token_id=0)
45
  prompt_len = input_ids.shape[-1]
46
+ result = self.tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True)
47
  return result
48
 
49
  @spaces.GPU(duration=120)
50
  def generate_image_ig(self, prompt, model_name):
51
+ if self.NSFW_filter(prompt) == 'safe':
52
+ pipe = self.load_model_pipe(model_name)
53
+ result = pipe(prompt=prompt)
54
+ else:
55
+ result = ''
56
  return result
57
 
58
  def generate_image_ig_api(self, prompt, model_name):
59
+ if self.NSFW_filter(prompt) == 'safe':
60
+ pipe = self.load_model_pipe(model_name)
61
+ result = pipe(prompt=prompt)
62
+ else:
63
+ result = ''
64
  return result
65
 
66
  def generate_image_ig_museum(self, model_name):