cutechicken commited on
Commit
2db4e16
·
verified ·
1 Parent(s): 5416ad2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +38 -45
app.py CHANGED
@@ -1,78 +1,71 @@
1
  import os
2
  from dotenv import load_dotenv
3
  import gradio as gr
4
- from huggingface_hub import InferenceClient
5
  import pandas as pd
6
  import json
7
  from datetime import datetime
8
  import torch
9
- from transformers import AutoModelForCausalLM, AutoTokenizer
10
 
11
  # 환경 변수 설정
12
  HF_TOKEN = os.getenv("HF_TOKEN")
13
  MODEL_ID = "CohereForAI/c4ai-command-r7b-12-2024"
14
 
15
- from transformers import pipeline
16
-
17
  class ModelManager:
18
  def __init__(self):
19
- self.pipe = None
20
- self.setup_pipeline()
 
21
 
22
- def setup_pipeline(self):
23
  try:
24
- print("파이프라인 초기화 시작...")
25
- self.pipe = pipeline(
26
- "text-generation",
27
- model=MODEL_ID,
 
 
 
28
  token=HF_TOKEN,
29
- device_map="auto",
30
- torch_dtype=torch.float16
31
  )
32
- print("파이프라인 초기화 완료")
33
  except Exception as e:
34
- print(f"파이프라인 초기화 중 오류 발생: {e}")
35
- raise Exception(f"파이프라인 초기화 실패: {e}")
36
 
37
  def generate_response(self, messages, max_tokens=4000, temperature=0.7, top_p=0.9):
38
  try:
39
- # 메시지 형식 변환
40
- prompt = ""
41
- for msg in messages:
42
- role = msg["role"]
43
- content = msg["content"]
44
- if role == "system":
45
- prompt += f"System: {content}\n"
46
- elif role == "user":
47
- prompt += f"User: {content}\n"
48
- elif role == "assistant":
49
- prompt += f"Assistant: {content}\n"
50
-
51
- # 응답 생성
52
- response = self.pipe(
53
- prompt,
54
  max_new_tokens=max_tokens,
 
55
  temperature=temperature,
56
  top_p=top_p,
57
- do_sample=True,
58
- num_return_sequences=1,
59
- pad_token_id=self.pipe.tokenizer.eos_token_id
60
  )
61
-
62
- # 응답 텍스트 추출 및 스트리밍 시뮬레이션
63
- generated_text = response[0]['generated_text'][len(prompt):].strip()
64
- words = generated_text.split()
65
-
66
- # 단어 단위로 스트리밍
67
- partial_response = ""
68
- for word in words:
69
- partial_response += word + " "
70
  yield type('Response', (), {
71
  'choices': [type('Choice', (), {
72
- 'delta': {'content': word + " "}
73
  })()]
74
  })()
75
-
76
  except Exception as e:
77
  raise Exception(f"응답 생성 실패: {e}")
78
 
 
1
  import os
2
  from dotenv import load_dotenv
3
  import gradio as gr
 
4
  import pandas as pd
5
  import json
6
  from datetime import datetime
7
  import torch
8
+ from transformers import AutoTokenizer, AutoModelForCausalLM
9
 
10
  # 환경 변수 설정
11
  HF_TOKEN = os.getenv("HF_TOKEN")
12
  MODEL_ID = "CohereForAI/c4ai-command-r7b-12-2024"
13
 
 
 
14
  class ModelManager:
15
  def __init__(self):
16
+ self.tokenizer = None
17
+ self.model = None
18
+ self.setup_model()
19
 
20
+ def setup_model(self):
21
  try:
22
+ print("토크나이저 로딩 시작...")
23
+ self.tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=HF_TOKEN)
24
+ print("토크나이저 로딩 완료")
25
+
26
+ print("모델 로딩 시작...")
27
+ self.model = AutoModelForCausalLM.from_pretrained(
28
+ MODEL_ID,
29
  token=HF_TOKEN,
30
+ torch_dtype=torch.float16,
31
+ device_map="auto"
32
  )
33
+ print("모델 로딩 완료")
34
  except Exception as e:
35
+ print(f"모델 로딩 중 오류 발생: {e}")
36
+ raise Exception(f"모델 로딩 실패: {e}")
37
 
38
  def generate_response(self, messages, max_tokens=4000, temperature=0.7, top_p=0.9):
39
  try:
40
+ # 채팅 템플릿 적용
41
+ input_ids = self.tokenizer.apply_chat_template(
42
+ messages,
43
+ tokenize=True,
44
+ add_generation_prompt=True,
45
+ return_tensors="pt"
46
+ ).to(self.model.device)
47
+
48
+ # 토큰 생성
49
+ gen_tokens = self.model.generate(
50
+ input_ids,
 
 
 
 
51
  max_new_tokens=max_tokens,
52
+ do_sample=True,
53
  temperature=temperature,
54
  top_p=top_p,
55
+ pad_token_id=self.tokenizer.eos_token_id,
56
+ streamer=TextIteratorStreamer(self.tokenizer, skip_special_tokens=True)
 
57
  )
58
+
59
+ # 응답 디코딩 및 스트리밍
60
+ response_text = ""
61
+ for new_text in self.tokenizer.decode(gen_tokens[0], skip_special_tokens=True):
62
+ response_text += new_text
 
 
 
 
63
  yield type('Response', (), {
64
  'choices': [type('Choice', (), {
65
+ 'delta': {'content': new_text}
66
  })()]
67
  })()
68
+
69
  except Exception as e:
70
  raise Exception(f"응답 생성 실패: {e}")
71