---
tags:
- text-generation-inference
- gemma
---

# Gemma 2B instruct with Key-Value-Cache enabled in ONNX fp16 format
- Model creator: [Google](https://huggingface.co/google)
- Original model: [Gemma 2B instruct](https://huggingface.co/google/gemma-2b-it)

<!-- description start -->
## Description

This repo contains the ONNX files for the conversion of Gemma 2B instruct done by Esperanto Technologies.
The model is in fp16 format and has the key-value cache (KVC) enabled.

<!-- description end -->

## How to download ONNX model and weight files

The easiest way to obtain the model is to clone this whole repo.
Alternatively, you can download the files using the `huggingface-hub` Python library.

```shell
pip3 install huggingface-hub>=0.17.1
```

Then you can download any individual model file to the current directory, at high speed, with a command like this:

```shell
huggingface-cli download Esperanto/gemma-2b-it-kvc-fp16-onnx --local-dir gemma-2b-it-kvc-fp16-onnx --local-dir-use-symlinks False
```

For more documentation on downloading with `huggingface-cli`, please see: [HF -> Hub Python Library -> Download files -> Download from the CLI](https://huggingface.co/docs/huggingface_hub/guides/download#download-from-the-cli).

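If you prefer to stay in Python, the whole repo can also be fetched with `snapshot_download` from the `huggingface_hub` library. A minimal sketch (the target directory is just an example and matches the CLI command above):

```python
from huggingface_hub import snapshot_download

# Download the ONNX graph and its external weight files into a local folder.
snapshot_download(
    repo_id="Esperanto/gemma-2b-it-kvc-fp16-onnx",
    local_dir="gemma-2b-it-kvc-fp16-onnx",
)
```
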
## How to run from Python code using ONNXRuntime

This model can easily be run on a CPU using [ONNXRuntime](https://onnxruntime.ai/).

#### First install the packages

```bash
pip3 install onnx==1.16.1
pip3 install onnxruntime==1.17.1
```

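As an optional sanity check, you can confirm that ONNX Runtime is installed and that the CPU execution provider is available:

```python
import onnxruntime

# A successful install should print the version and list CPUExecutionProvider.
print(onnxruntime.__version__)
print(onnxruntime.get_available_providers())
```
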
#### Example code: generate text with this model

We define the generation loop with greedy decoding:

```python
import numpy as np
import onnxruntime
import onnx
from transformers import AutoTokenizer

def generate_text(model_path, prompt, tokenizer, max_gen_tokens, total_sequence, window, context):
    model = onnx.load(model_path)

    # we create the inputs for the first iteration
    input_tensor = tokenizer(prompt, return_tensors="pt")
    prompt_size = len(input_tensor['input_ids'][0])
    actual_input = input_tensor['input_ids'].numpy()
    if prompt_size < window:
        # left-pad with BOS tokens so the first window is full
        actual_input = np.concatenate((tokenizer.bos_token_id * np.ones([1, window - prompt_size], dtype='int64'),
                                       actual_input), axis=1)
    if prompt_size + max_gen_tokens > total_sequence:
        print("ERROR: Longer total sequence is needed!")
        return
    first_attention = np.concatenate((np.zeros([1, total_sequence - window], dtype='int64'),
                                      np.ones((1, window), dtype='int64')), axis=1)
    max_gen_tokens += prompt_size  # we need to generate on top of parsing the prompt
    inputs_names = [node.name for node in model.graph.input]
    output_names = [node.name for node in model.graph.output]
    n_heads = 1  # gqa-heads of the kvc
    inputs_dict = {}
    inputs_dict['input_ids'] = actual_input[:, :window].reshape(1, window)
    inputs_dict['attention_mask'] = first_attention
    # empty key-value cache for the first inference
    for name in inputs_names:
        if name == 'input_ids' or name == 'attention_mask': continue
        inputs_dict[name] = np.zeros([1, n_heads, context - window, 256], dtype="float16")
    new_token = np.array([10])  # dummy token so the EOS check does not trigger before the first inference
    next_index = window
    old_j = 0
    total_input = actual_input.copy()

    rt_session = onnxruntime.InferenceSession(model_path)
    ## We run the inferences
    while next_index < max_gen_tokens:
        # stop as soon as an EOS token has been produced
        if (new_token == tokenizer.eos_token_id).any():
            break
        # inference
        output = rt_session.run(output_names, inputs_dict)
        outs_dictionary = {name: content for (name, content) in zip(output_names, output)}
        # we prepare the inputs for the next inference
        for name in inputs_names:
            if name == 'input_ids':
                old_j = next_index
                if next_index < prompt_size:
                    # still parsing the prompt: advance by a full window (or up to the end of the prompt)
                    if prompt_size - next_index >= window: next_index += window
                    else: next_index = prompt_size
                    j = next_index - window
                else:
                    # generating: advance one token at a time
                    next_index += 1
                    j = next_index - window
                new_token = outs_dictionary['logits'].argmax(-1).reshape(1, window)
                total_input = np.concatenate((total_input, new_token[:, -1:]), axis=1)
                inputs_dict['input_ids'] = total_input[:, j:next_index].reshape(1, window)
            elif name == 'attention_mask':
                inputs_dict['attention_mask'] = np.concatenate((np.zeros((1, total_sequence - next_index), dtype='int64'),
                                                                np.ones((1, next_index), dtype='int64')), axis=1)
            else:
                # feed the freshest (context - window) positions of the key-value cache back as past_key_values
                old_name = name.replace("past_key_values", "present")
                inputs_dict[name] = outs_dictionary[old_name][:, :, next_index - old_j:context - window + (next_index - old_j), :]

    answer = tokenizer.decode(total_input[0], skip_special_tokens=True, clean_up_tokenization_spaces=False)
    return answer
```
We now run the inference:

```python
tokenizer = AutoTokenizer.from_pretrained("Esperanto/gemma-2b-it-kvc-fp16-onnx")
model_path = "gemma-2b-it-kvc-fp16-onnx/model.onnx"

max_gen_tokens = 20    # number of tokens we want to generate
total_sequence = 128   # total sequence length
context = 1024         # the context to extend the kvc
window = 16            # number of tokens we want to parse at a time
messages = [
    {"role": "system", "content": "You are a pirate chatbot who always responds in pirate speak!"},
    {"role": "user", "content": "Who are you?"},
]

prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

generated = generate_text(model_path, prompt, tokenizer, max_gen_tokens, total_sequence, window, context)
print(generated)
```
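
The loop above assumes the layout of this particular export: a single KV-cache head (`n_heads = 1`) with a head size of 256, and past/present tensors whose names differ only by the `past_key_values` / `present` prefix. If you want to double-check those assumptions, or adapt the loop to another Esperanto ONNX export, you can list the graph inputs and their shapes. A minimal sketch, assuming the repo was downloaded to `gemma-2b-it-kvc-fp16-onnx/` as above:

```python
import onnx

# Read only the graph definition; the external fp16 weight files are not needed for this.
model = onnx.load("gemma-2b-it-kvc-fp16-onnx/model.onnx", load_external_data=False)

for node in model.graph.input:
    dims = [d.dim_param or d.dim_value for d in node.type.tensor_type.shape.dim]
    print(node.name, dims)
```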