Usage example did not work because of a typo and torch not being imported
README.md
CHANGED
@@ -1,45 +1,45 @@
 ---
 tags:
 - deep-reinforcement-learning
 - reinforcement-learning
 
 ---
 
 Find here pretrained model weights for the [Decision Transformer] (https://github.com/kzl/decision-transformer).
 Weights are available for 4 Atari games: Breakout, Pong, Qbert and Seaquest. Found in the checkpoints directory.
 We share models trained for one seed (123), whereas the paper contained weights for 3 random seeds.
 
 
 ### Usage
 
 ```
 git clone https://huggingface.co/edbeeching/decision_transformer_atari
 conda env create -f conda_env.yml
 ```
 
 Then, you can use the model like this:
 
 ```python
-
-from
+import torch
+from decision_transformer_atari import GPTConfig, GPT
 
 vocab_size = 4
 block_size = 90
 model_type = "reward_conditioned"
 timesteps = 2654
 
 mconf = GPTConfig(
     vocab_size,
     block_size,
     n_layer=6,
     n_head=8,
     n_embd=128,
     model_type=model_type,
     max_timestep=timesteps,
 )
 model = GPT(mconf)
 
 checkpoint_path = "checkpoints/Breakout_123.pth"  # or Pong, Qbert, Seaquest
 checkpoint = torch.load(checkpoint_path)
 model.load_state_dict(checkpoint)
 ```
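With the fixed imports, the usage snippet runs end to end. As a quick sanity check, the sketch below loads the Breakout checkpoint on a CPU-only machine and runs a dummy forward pass. The forward signature and input shapes are assumptions carried over from the Atari GPT in the upstream kzl/decision-transformer code (atari/mingpt/model_atari.py), not something stated in this model card, so verify them against that file if the call does not match your copy of the code.

```python
# Minimal sanity-check sketch. Assumptions (not stated in the model card):
# - GPTConfig/GPT are the Atari classes from https://github.com/kzl/decision-transformer
#   (atari/mingpt/model_atari.py), importable here as decision_transformer_atari
# - GPT.forward takes (states, actions, targets=None, rtgs=None, timesteps=None)
import torch
from decision_transformer_atari import GPTConfig, GPT

mconf = GPTConfig(
    4,                                 # vocab_size: Breakout has 4 actions
    90,                                # block_size: 30 timesteps x 3 tokens (rtg, state, action)
    n_layer=6, n_head=8, n_embd=128,
    model_type="reward_conditioned",
    max_timestep=2654,
)
model = GPT(mconf)

# map_location lets a checkpoint saved on GPU load on a CPU-only machine
checkpoint = torch.load("checkpoints/Breakout_123.pth", map_location="cpu")
model.load_state_dict(checkpoint)
model.eval()

# Dummy inputs for a context of 30 timesteps (shapes assumed from the upstream code):
# each state is four stacked 84x84 frames, flattened and scaled to [0, 1]
context = 30
states = torch.zeros(1, context, 4 * 84 * 84)            # (batch, context, 4*84*84), float
actions = torch.zeros(1, context, 1, dtype=torch.long)   # (batch, context, 1), action indices
rtgs = torch.full((1, context, 1), 90.0)                  # (batch, context, 1), returns-to-go
timesteps = torch.zeros(1, 1, 1, dtype=torch.int64)       # (batch, 1, 1), episode start index

with torch.no_grad():
    logits, _ = model(states, actions, targets=None, rtgs=rtgs, timesteps=timesteps)
print(logits.shape)  # expected (1, 30, 4): one action distribution per timestep
```

Feeding real observations would mean stacking the last four grayscale 84x84 frames per step and scaling them to [0, 1], as in the upstream Atari preprocessing.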