Update README.md
README.md CHANGED
@@ -5,7 +5,11 @@ datasets:
 pipeline_tag: text-classification
 ---
 
-**Paper:** Coming soon
+- **Paper:** Coming soon
+
+- **Model:** [URM-LLaMa-3.1-8B](https://huggingface.co/LxzGordon/URM-LLaMa-3.1-8B)
+
+- Fine-tuned from [Skywork-Reward-Llama-3.1-8B](https://huggingface.co/Skywork/Skywork-Reward-Llama-3.1-8B)
 
 # Architecture
 <div align=center>
@@ -15,7 +19,7 @@ URM is one of the RMs in the figure.
 
 # Brief
 
-[URM-
+[URM-LLaMa-3.1-8B](https://huggingface.co/LxzGordon/URM-LLaMa-3.1-8B) is an uncertainty-aware reward model.
 This RM consists of a base model and an uncertainty-aware and attribute-specific value head. The base model of this RM is from [Skywork-Reward-Llama-3.1-8B](https://huggingface.co/Skywork/Skywork-Reward-Llama-3.1-8B).
 
 URM involves two-stage training: 1. **attributes regression** and 2. **gating layer learning**.
@@ -39,7 +43,7 @@ During this process, the value head and base model are kept frozen.
 import torch
 from transformers import AutoModelForSequenceClassification, AutoTokenizer
 
-model_name = "LxzGordon/URM-
+model_name = "LxzGordon/URM-LLaMa-3.1-8B"
 model = AutoModelForSequenceClassification.from_pretrained(
     model_name,
     device_map='auto',
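The first stage named in the diff, **attributes regression**, trains the value head to regress per-attribute reward labels while also modeling its own uncertainty. As a rough illustration only (not the authors' code — the exact objective is in the URM paper), a common uncertainty-aware regression setup has the head predict a mean and a log-variance per attribute and minimize the Gaussian negative log-likelihood. All names and numbers below are hypothetical:

```python
import math

def gaussian_nll(mean: float, log_var: float, target: float) -> float:
    """Gaussian negative log-likelihood: the squared error is scaled by the
    predicted variance, so the head can admit uncertainty instead of being
    forced to a single point estimate."""
    var = math.exp(log_var)
    return 0.5 * (log_var + (target - mean) ** 2 / var + math.log(2 * math.pi))

# One training example with labels for two made-up attributes
# (e.g. helpfulness, correctness).
pred_mean = [0.8, 0.1]
pred_log_var = [0.0, 0.0]
labels = [1.0, 0.0]

loss = sum(gaussian_nll(m, lv, t)
           for m, lv, t in zip(pred_mean, pred_log_var, labels)) / len(labels)
```

With unit variance (`log_var = 0`) this reduces to a scaled squared error plus a constant; predicting a larger variance trades a smaller error penalty for a larger `log_var` term.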
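The second stage, **gating layer learning**, learns how to combine the frozen head's attribute scores into a single scalar reward (the hunk context above notes the value head and base model are kept frozen during it). A minimal sketch of one plausible gating mechanism — softmax weights over attributes, trained with a Bradley–Terry-style pairwise loss as is common for reward models. This is an illustration under those assumptions, not URM's actual gating network, which may condition on the input:

```python
import math

def softmax(logits):
    """Numerically stable softmax: shifts by the max before exponentiating."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    z = sum(exps)
    return [e / z for e in exps]

def gated_reward(attribute_scores, gating_logits):
    """Combine per-attribute scores into one scalar reward using
    convex (softmax) gating weights."""
    weights = softmax(gating_logits)
    return sum(w * s for w, s in zip(weights, attribute_scores))

def preference_loss(r_chosen, r_rejected):
    """Bradley-Terry style pairwise loss commonly used to fit reward
    models on preference pairs: minimized when chosen >> rejected."""
    return -math.log(1.0 / (1.0 + math.exp(-(r_chosen - r_rejected))))

# With equal gating logits the reward is the plain average of attributes.
r = gated_reward([1.0, 0.0, 0.5], [0.0, 0.0, 0.0])  # → 0.5
```

Only the gating logits would be updated against the preference loss; the attribute scores come from the frozen value head.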
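Because an uncertainty-aware head produces a spread as well as a point estimate, downstream consumers can discount rewards the model is unsure about. One generic pattern (an illustration of how such estimates are often used, not something this model card prescribes) is a lower confidence bound:

```python
import math

def conservative_reward(mean: float, var: float, k: float = 1.0) -> float:
    """Lower-confidence-bound reward: subtract k standard deviations,
    so uncertain predictions score lower than confident ones."""
    return mean - k * math.sqrt(var)
```

For example, a predicted mean of 1.0 with variance 0.25 and `k = 1` yields a conservative reward of 0.5, while the same mean with zero variance keeps its full value.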