Guchyos committed
Commit 5ff79ad · verified · 1 Parent(s): 8a5278b

Update app.py

Files changed (1)
  1. app.py +24 -21
app.py CHANGED
@@ -8,36 +8,39 @@ tokenizer = None
 def load_model():
     global model, tokenizer
     if model is None:
-        model_name = "line-corporation/japanese-large-lm-1.7b"  # LINE's lightweight model
+        model_name = "Guchyos/gemma-2b-elyza-task"
         tokenizer = AutoTokenizer.from_pretrained(model_name)
         model = AutoModelForCausalLM.from_pretrained(
             model_name,
-            device_map="cpu",
-            low_cpu_mem_usage=True
+            torch_dtype=torch.float32,  # use float32
+            device_map="cpu"
         )
     return model, tokenizer
 
-import gradio as gr
-from transformers import AutoModelForCausalLM, AutoTokenizer
-import torch
-import sentencepiece as spm
-
 def predict(message, history):
     try:
-        model_name = "rinna/japanese-gpt-neox-small"
-        tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
-        model = AutoModelForCausalLM.from_pretrained(
-            model_name,
-            device_map="cpu",
-            trust_remote_code=True
-        )
+        model, tokenizer = load_model()
+        prompt = f"質問: {message}\n\n回答:"  # "Question: … / Answer:" prompt template
+        inputs = tokenizer(prompt, return_tensors="pt")
 
-        inputs = tokenizer(message, return_tensors="pt")
-        outputs = model.generate(**inputs, max_length=64)
-        return tokenizer.decode(outputs[0], skip_special_tokens=True)
+        with torch.no_grad():
+            outputs = model.generate(
+                **inputs,
+                max_new_tokens=128,
+                do_sample=False
+            )
+
+        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
+        return response.replace(prompt, "").strip()
 
     except Exception as e:
-        return f"エラー: {str(e)}"  # "Error: …"
+        return f"エラーが発生しました: {str(e)}"  # "An error occurred: …"
+
+demo = gr.ChatInterface(
+    fn=predict,
+    title="💬 Gemma 2 for ELYZA-tasks",
+    description="ELYZA-tasks-100-TV用に最適化された日本語LLMです"  # "A Japanese LLM optimized for ELYZA-tasks-100-TV"
+)
 
-demo = gr.ChatInterface(fn=predict)
-demo.launch()
+if __name__ == "__main__":
+    demo.launch(share=True)
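
After this commit, app.py loads the model lazily (once, on the first request) and generates greedily with max_new_tokens=128, which, unlike the old max_length=64, does not count the prompt tokens toward the generation budget. A quick way to smoke-test the new predict path without the Gradio UI is a minimal sketch like the one below; it assumes the script runs next to app.py and that the Guchyos/gemma-2b-elyza-task weights are downloadable (the first call is slow on CPU):

# Minimal smoke test for the updated predict() path.
# The __main__ guard in app.py keeps the import from launching the Gradio server.
from app import predict

# predict() ignores the history argument, so an empty list is fine.
answer = predict("日本で一番高い山は何ですか？", [])
print(answer)  # decoded text with the "質問: … 回答:" prompt stripped

Note that the prompt is removed by string replacement; slicing the generated ids past inputs["input_ids"].shape[1] before decoding would achieve the same effect without relying on the decoded text reproducing the prompt exactly.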