Saving weights of epoch 1 at step 92
Browse files
__pycache__/model_file.cpython-38.pyc
CHANGED
Binary files a/__pycache__/model_file.cpython-38.pyc and b/__pycache__/model_file.cpython-38.pyc differ
|
|
flax_model.msgpack
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
size 1419367919
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:51d86bd352715e1623b69a8451f8c752c314bb6cf7669a5d9bb2f7589261d8c3
|
3 |
size 1419367919
|
model_file.py
CHANGED
@@ -190,7 +190,7 @@ class FlaxGPT2ForMultipleChoiceModule(nn.Module):
|
|
190 |
dtype: jnp.dtype = jnp.float32
|
191 |
def setup(self):
|
192 |
self.transformer = FlaxGPT2Module(config=self.config, dtype=self.dtype)
|
193 |
-
self.dropout = nn.Dropout(rate=0.
|
194 |
self.classifier = nn.Dense(4, dtype=self.dtype)
|
195 |
|
196 |
def __call__(self,input_ids,attention_mask,position_ids,return_dict=True,deterministic=True,*args):
|
|
|
190 |
dtype: jnp.dtype = jnp.float32
|
191 |
def setup(self):
|
192 |
self.transformer = FlaxGPT2Module(config=self.config, dtype=self.dtype)
|
193 |
+
self.dropout = nn.Dropout(rate=0.2)
|
194 |
self.classifier = nn.Dense(4, dtype=self.dtype)
|
195 |
|
196 |
def __call__(self,input_ids,attention_mask,position_ids,return_dict=True,deterministic=True,*args):
|
results_tensorboard/events.out.tfevents.1626339960.t1v-n-8cb15980-w-0.776261.3.v2
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:322d306ebcf9d805c02057d9c2c761e63a887231db3d665df7d6dc88bed92174
|
3 |
+
size 25038
|
train.py
CHANGED
@@ -74,7 +74,7 @@ def main():
|
|
74 |
per_device_batch_size=4
|
75 |
seed=0
|
76 |
num_train_epochs=3
|
77 |
-
learning_rate=
|
78 |
|
79 |
|
80 |
total_batch_size = per_device_batch_size * jax.local_device_count()
|
|
|
74 |
per_device_batch_size=4
|
75 |
seed=0
|
76 |
num_train_epochs=3
|
77 |
+
learning_rate=2e-5
|
78 |
|
79 |
|
80 |
total_batch_size = per_device_batch_size * jax.local_device_count()
|