Update README.md
Browse files
README.md
CHANGED
@@ -38,25 +38,33 @@ class StopOnTokenCriteria(StoppingCriteria):
|
|
38 |
return input_ids[0, -1] == self.stop_token_id
|
39 |
|
40 |
|
41 |
-
stop_on_token_criteria = StopOnTokenCriteria(stop_token_id=2)
|
42 |
-
|
43 |
pipe = pipeline(
|
44 |
"text-generation",
|
45 |
-
"AI-Sweden-Models/gpt-sw3-6.7b-v2-translator",
|
46 |
device=device
|
47 |
)
|
|
|
|
|
|
|
|
|
48 |
# This will translate English to Swedish
|
49 |
# To translate from Swedish to English the prompt would be:
|
50 |
# prompt = f"<|endoftext|><s>User: Översätt till Engelska från Svenska\n{text}<s>Bot:"
|
51 |
-
|
52 |
prompt = f"<|endoftext|><s>User: Översätt till Svenska från Engelska\n{text}<s>Bot:"
|
53 |
|
|
|
|
|
|
|
|
|
54 |
response = pipe(
|
55 |
prompt,
|
56 |
-
max_length=
|
|
|
57 |
stopping_criteria=StoppingCriteriaList([stop_on_token_criteria])
|
58 |
)
|
59 |
|
|
|
60 |
print(response[0]["generated_text"].split("<s>Bot: ")[-1])
|
61 |
```
|
62 |
```python
|
|
|
38 |
return input_ids[0, -1] == self.stop_token_id
|
39 |
|
40 |
|
|
|
|
|
41 |
pipe = pipeline(
|
42 |
"text-generation",
|
43 |
+
model="AI-Sweden-Models/gpt-sw3-6.7b-v2-translator",
|
44 |
device=device
|
45 |
)
|
46 |
+
|
47 |
+
stop_on_token_criteria = StopOnTokenCriteria(stop_token_id=pipe.tokenizer.bos_token_id)
|
48 |
+
text = "I like to eat ice cream in the summer."
|
49 |
+
|
50 |
# This will translate English to Swedish
|
51 |
# To translate from Swedish to English the prompt would be:
|
52 |
# prompt = f"<|endoftext|><s>User: Översätt till Engelska från Svenska\n{text}<s>Bot:"
|
53 |
+
|
54 |
prompt = f"<|endoftext|><s>User: Översätt till Svenska från Engelska\n{text}<s>Bot:"
|
55 |
|
56 |
+
input_tokens = pipe.tokenizer(prompt, return_tensors="pt").input_ids.to(device)
|
57 |
+
max_model_length = 2048
|
58 |
+
dynamic_max_length = max_model_length - input_tokens.shape[1]
|
59 |
+
|
60 |
response = pipe(
|
61 |
prompt,
|
62 |
+
max_length=dynamic_max_length,
|
63 |
+
truncation=True,
|
64 |
stopping_criteria=StoppingCriteriaList([stop_on_token_criteria])
|
65 |
)
|
66 |
|
67 |
+
# Extract and print the generated translation
|
68 |
print(response[0]["generated_text"].split("<s>Bot: ")[-1])
|
69 |
```
|
70 |
```python
|