Spaces:
Runtime error
Runtime error
Merge pull request #15 from borisdayma/feat-fix-lr
Browse files- requirements.txt +3 -0
- seq2seq/run_seq2seq_flax.py +6 -3
- seq2seq/sweep.yaml +3 -3
requirements.txt
CHANGED
@@ -7,3 +7,6 @@ jax[tpu]>=0.2.16
|
|
7 |
-e git+https://github.com/huggingface/datasets.git@master#egg=datasets
|
8 |
flax
|
9 |
jupyter
|
|
|
|
|
|
|
|
7 |
-e git+https://github.com/huggingface/datasets.git@master#egg=datasets
|
8 |
flax
|
9 |
jupyter
|
10 |
+
# for logging
|
11 |
+
tensorboard
|
12 |
+
tetnsorflow
|
seq2seq/run_seq2seq_flax.py
CHANGED
@@ -19,8 +19,11 @@ Script adapted from run_summarization_flax.py
|
|
19 |
"""
|
20 |
# You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments.
|
21 |
|
22 |
-
import logging as pylogging # To avoid collision with transformers.utils.logging
|
23 |
import os
|
|
|
|
|
|
|
|
|
24 |
import sys
|
25 |
import time
|
26 |
from dataclasses import dataclass, field
|
@@ -673,12 +676,12 @@ def main():
|
|
673 |
grads = jax.tree_map(lambda x: x / training_args.gradient_accumulation_steps, grad_accum)
|
674 |
grads = jax.lax.pmean(grads, "batch")
|
675 |
new_state = state.apply_gradients(
|
676 |
-
grads=grads, grad_accum=jax.tree_map(jnp.zeros_like, grads), optimizer_step=state.optimizer_step
|
677 |
)
|
678 |
return new_state
|
679 |
|
680 |
new_state = jax.lax.cond(
|
681 |
-
state.step % training_args.gradient_accumulation_steps == 0,
|
682 |
lambda _: update_fn(),
|
683 |
lambda _: state.replace(grad_accum=grad_accum, step=state.step + 1),
|
684 |
None,
|
|
|
19 |
"""
|
20 |
# You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments.
|
21 |
|
|
|
22 |
import os
|
23 |
+
# set a common huggingface cache folder (used with datasets and transformers)
|
24 |
+
os.environ['HF_HOME'] = '/data/huggingface/' # required before importing transformers & datasets
|
25 |
+
|
26 |
+
import logging as pylogging # To avoid collision with transformers.utils.logging
|
27 |
import sys
|
28 |
import time
|
29 |
from dataclasses import dataclass, field
|
|
|
676 |
grads = jax.tree_map(lambda x: x / training_args.gradient_accumulation_steps, grad_accum)
|
677 |
grads = jax.lax.pmean(grads, "batch")
|
678 |
new_state = state.apply_gradients(
|
679 |
+
grads=grads, grad_accum=jax.tree_map(jnp.zeros_like, grads), optimizer_step=state.optimizer_step + 1
|
680 |
)
|
681 |
return new_state
|
682 |
|
683 |
new_state = jax.lax.cond(
|
684 |
+
(state.step + 1) % training_args.gradient_accumulation_steps == 0,
|
685 |
lambda _: update_fn(),
|
686 |
lambda _: state.replace(grad_accum=grad_accum, step=state.step + 1),
|
687 |
None,
|
seq2seq/sweep.yaml
CHANGED
@@ -8,9 +8,9 @@ metric:
|
|
8 |
parameters:
|
9 |
learning_rate:
|
10 |
distribution: log_uniform
|
11 |
-
# from exp(min) to exp(max), ie 1e-
|
12 |
-
min: -
|
13 |
-
max: -
|
14 |
gradient_accumulation_steps:
|
15 |
value: 8
|
16 |
warmup_steps:
|
|
|
8 |
parameters:
|
9 |
learning_rate:
|
10 |
distribution: log_uniform
|
11 |
+
# from exp(min) to exp(max), ie 1e-4 to 5e-3 on log scale
|
12 |
+
min: -9.2
|
13 |
+
max: -5.3
|
14 |
gradient_accumulation_steps:
|
15 |
value: 8
|
16 |
warmup_steps:
|