|
---
|
|
license: mit
|
|
inference: False
|
|
---
|
|
|
|
# Training logs
|
|
- https://wandb.ai/junyu/huggingface/runs/1jg2jlgt
|
|
|
|
# Installation
|
|
- Install the `flash` package (imported by the usage example below) from https://github.com/JunnYu/FLASHQuad_pytorch
|
|
|
|
# Usage
|
|
```python
|
|
import torch
from flash import FLASHForMaskedLM
from transformers import BertTokenizerFast

# Fill in the [MASK] tokens of a Chinese sentence with the pretrained FLASH
# model and print the top-k candidate tokens (with their softmax
# probabilities) at every masked position.

MODEL_NAME = "junnyu/flash_small_wwm_cluecorpussmall"
TOP_K = 5

tokenizer = BertTokenizerFast.from_pretrained(MODEL_NAME)
model = FLASHForMaskedLM.from_pretrained(MODEL_NAME)
model.eval()

text = "天气预报说今天的天[MASK]很好,那么我[MASK]一起去公园玩吧!"

# max_length must be 512 here; otherwise the results may be incorrect.
inputs = tokenizer(
    text,
    return_tensors="pt",
    padding="max_length",
    max_length=512,
    return_token_type_ids=False,
)

with torch.no_grad():
    pt_outputs = model(**inputs).logits[0]

# Positions are taken from the unpadded encoding of `text`; this relies on
# padding being added on the right, so index i still addresses the same
# token in the padded model output.
parts = ["pytorch: "]
for i, token_id in enumerate(tokenizer.encode(text)):
    if token_id == tokenizer.mask_token_id:
        probs, top_ids = pt_outputs[i].softmax(-1).topk(k=TOP_K)
        candidates = tokenizer.convert_ids_to_tokens(top_ids)
        scored = [
            f"{tok}+{round(p.item(), 4)}"
            for p, tok in zip(probs.cpu(), candidates)
        ]
        parts.append("[" + "||".join(scored) + "]")
    else:
        # skip_special_tokens=True keeps [CLS]/[SEP] out of the printed text.
        parts.append(
            "".join(
                tokenizer.convert_ids_to_tokens([token_id], skip_special_tokens=True)
            )
        )

print("".join(parts))

# Expected output:
# pytorch: 天气预报说今天的天[气+0.994||天+0.0015||空+0.0014||晴+0.0005||阳+0.0003]很好,那么我[们+0.9563||就+0.0381||也+0.0032||俩+0.0004||来+0.0002]一起去公园玩吧!
|
|
```