arnocandel
committed on
Commit
·
d1aea17
1
Parent(s):
3dadeda
https://github.com/h2oai/h2ogpt/issues/125#issuecomment-1548239108
Browse files
- config.json +1 -7
- h2oai_pipeline.py +8 -0
- pytorch_model-00001-of-00005.bin +1 -1
- pytorch_model-00002-of-00005.bin +1 -1
- pytorch_model-00003-of-00005.bin +1 -1
- pytorch_model-00004-of-00005.bin +1 -1
- pytorch_model-00005-of-00005.bin +1 -1
config.json
CHANGED
@@ -1,5 +1,5 @@
|
|
1 |
{
|
2 |
-
"_name_or_path": "
|
3 |
"architectures": [
|
4 |
"GPTNeoXForCausalLM"
|
5 |
],
|
@@ -10,12 +10,6 @@
|
|
10 |
"pt": "AutoModelForCausalLM"
|
11 |
}
|
12 |
},
|
13 |
-
"custom_pipelines": {
|
14 |
-
"text-generation": {
|
15 |
-
"impl": "h2oai_pipeline.H2OTextGenerationPipeline",
|
16 |
-
"pt": "AutoModelForCausalLM"
|
17 |
-
}
|
18 |
-
},
|
19 |
"eos_token_id": 0,
|
20 |
"hidden_act": "gelu",
|
21 |
"hidden_size": 5120,
|
|
|
1 |
{
|
2 |
+
"_name_or_path": "EleutherAI/pythia-12b-deduped",
|
3 |
"architectures": [
|
4 |
"GPTNeoXForCausalLM"
|
5 |
],
|
|
|
10 |
"pt": "AutoModelForCausalLM"
|
11 |
}
|
12 |
},
|
|
|
|
|
|
|
|
|
|
|
|
|
13 |
"eos_token_id": 0,
|
14 |
"hidden_act": "gelu",
|
15 |
"hidden_size": 5120,
|
h2oai_pipeline.py
CHANGED
@@ -1,6 +1,9 @@
|
|
1 |
from transformers import TextGenerationPipeline
|
2 |
from transformers.pipelines.text_generation import ReturnType
|
3 |
|
|
|
|
|
|
|
4 |
human = "<human>:"
|
5 |
bot = "<bot>:"
|
6 |
|
@@ -28,3 +31,8 @@ class H2OTextGenerationPipeline(TextGenerationPipeline):
|
|
28 |
for rec in records:
|
29 |
rec['generated_text'] = rec['generated_text'].split(bot)[1].strip().split(human)[0].strip()
|
30 |
return records
|
|
|
|
|
|
|
|
|
|
|
|
1 |
from transformers import TextGenerationPipeline
|
2 |
from transformers.pipelines.text_generation import ReturnType
|
3 |
|
4 |
+
from stopping import get_stopping
|
5 |
+
|
6 |
+
prompt_type = "human_bot"
|
7 |
human = "<human>:"
|
8 |
bot = "<bot>:"
|
9 |
|
|
|
31 |
for rec in records:
|
32 |
rec['generated_text'] = rec['generated_text'].split(bot)[1].strip().split(human)[0].strip()
|
33 |
return records
|
34 |
+
|
35 |
+
def _forward(self, model_inputs, **generate_kwargs):
|
36 |
+
stopping_criteria = get_stopping(prompt_type, self.tokenizer, self.device, human=human, bot=bot)
|
37 |
+
generate_kwargs['stopping_criteria'] = stopping_criteria
|
38 |
+
return super()._forward(model_inputs, **generate_kwargs)
|
pytorch_model-00001-of-00005.bin
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
size 4957630318
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:64691fa6fa33a63aa2fad165e6215a17e79dac4a203b9f8c887907a72278660b
|
3 |
size 4957630318
|
pytorch_model-00002-of-00005.bin
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
size 4853861544
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:75dd532c4cb4c3649e80191dac7f0120ce0d3a0f573f66da11f61290936eeb46
|
3 |
size 4853861544
|
pytorch_model-00003-of-00005.bin
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
size 4858068625
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:7ace0311fd3b140629c0bda15e5d6ebf23987d4905124da10fac7c0ff11e583e
|
3 |
size 4858068625
|
pytorch_model-00004-of-00005.bin
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
size 5015385889
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:ea6b0c3b72599fc88f0328bb9c1f5058cd33eff5a5897c0863cd09d713ffbea1
|
3 |
size 5015385889
|
pytorch_model-00005-of-00005.bin
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
size 4158379959
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:17f12d59a0255f9b07e531081be6666b3e9507baa82d25ac412536f6badaffdd
|
3 |
size 4158379959
|