rdyro commited on
Commit
98ae8db
·
verified ·
1 Parent(s): b609628

Upload README.md

Browse files
Files changed (1) hide show
  1. README.md +9 -7
README.md CHANGED
@@ -24,19 +24,20 @@ model = FlaxMistralForCausalLM.from_pretrained("rdyro/Mistral-7B-Instruct-v0.1",
24
 
25
  tokenizer = AutoTokenizer.from_pretrained("rdyro/Mistral-7B-Instruct-v0.1")
26
 
27
- torch_model_id = "mistralai/Mistral-7B-Instruct-v0.1"
28
- torch_model = AutoModelForCausalLM.from_pretrained(
29
- torch_model_id, device_map="cpu", torch_dtype=torch.float32
30
- )
31
- torch_tokenizer = AutoTokenizer.from_pretrained(torch_model_id)
32
  out_jax = model(input_jax)
33
  ```
34
 
35
  We can compare the outputs to the original PyTorch version.
36
 
37
  ```python
38
- messages = [{"role": "user", "content": "what's your name?"}]
39
- input_jax = tokenizer.apply_chat_template(messages, return_tensors="jax")
 
 
 
 
40
  input_pt = torch_tokenizer.apply_chat_template(messages, return_tensors="pt")
41
 
42
  with torch.no_grad():
@@ -46,6 +47,7 @@ err = jnp.linalg.norm(jnp.array(out_pt.logits) - out_jax.logits) / jnp.linalg.no
46
  jnp.array(out_pt.logits)
47
  )
48
  print(f"Error is numerical precision level: {err:.4e}")
 
49
  ```
50
 
51
  <p align="center">
 
24
 
25
  tokenizer = AutoTokenizer.from_pretrained("rdyro/Mistral-7B-Instruct-v0.1")
26
 
27
+ messages = [{"role": "user", "content": "what's your name?"}]
28
+ input_jax = tokenizer.apply_chat_template(messages, return_tensors="jax")
 
 
 
29
  out_jax = model(input_jax)
30
  ```
31
 
32
  We can compare the outputs to the original PyTorch version.
33
 
34
  ```python
35
+ torch_model_id = "mistralai/Mistral-7B-Instruct-v0.1"
36
+ torch_model = AutoModelForCausalLM.from_pretrained(
37
+ torch_model_id, device_map="cpu", torch_dtype=torch.float32
38
+ )
39
+ torch_tokenizer = AutoTokenizer.from_pretrained(torch_model_id)
40
+
41
  input_pt = torch_tokenizer.apply_chat_template(messages, return_tensors="pt")
42
 
43
  with torch.no_grad():
 
47
  jnp.array(out_pt.logits)
48
  )
49
  print(f"Error is numerical precision level: {err:.4e}")
50
+ # prints: Error is numerical precision level: 1.0205e-06
51
  ```
52
 
53
  <p align="center">