Usage
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch
# Hub checkpoint for this example. The tokenizer and the seq2seq model are both
# loaded from it via `from_pretrained`; `create_queries` below reads the
# module-level `tokenizer` and `model` names.
model_name = 'csdc-atl/doc2query'
tokenizer, model = (
    AutoTokenizer.from_pretrained(model_name),
    AutoModelForSeq2SeqLM.from_pretrained(model_name),
)
# Token ids spliced into the prompt by `create_queries`. In the prompt layout
# below, _QUESTION_TOKEN_ID precedes every user question (each history question
# and the new question) and _ANSWER_TOKEN_ID precedes every history answer.
# NOTE(review): these look like reserved ids in the model's embedding table
# rather than tokenizer-visible tokens -- confirm against the model's training
# format before changing them.
_QUESTION_TOKEN_ID = 32127
_ANSWER_TOKEN_ID = 32126
# Appended once at the very end of the prompt. Presumably the T5-style `</s>`
# id -- TODO confirm it equals `tokenizer.eos_token_id` for this checkpoint.
_EOS_TOKEN_ID = 1


def create_queries(history, next_question, *, num_return_sequences=None):
    """Rewrite a follow-up question into a stand-alone query using chat history.

    The prompt is assembled as::

        [Q] q_1 [A] a_1 ... [Q] q_k [A] a_k [Q] next_question [EOS]

    where [Q]/[A]/[EOS] are the special ids defined above. It is fed to the
    module-level seq2seq ``model`` with top-k / nucleus sampling.

    Args:
        history: Iterable of turns, oldest first. Each turn is indexable, with
            the question at index 0 and the answer at index 1. May be empty.
        next_question: The follow-up question to rewrite.
        num_return_sequences: Optional number of sampled rewrites to request
            from ``model.generate``. When ``None`` (the default) the argument is
            not passed, so the model's generation config decides.

    Returns:
        A list of the decoded rewrites in generation order. Each one is also
        printed, numbered from 1, under a "Sampling Outputs:" header.
    """
    input_ids = []
    for turn in history:
        question, answer = turn[0], turn[1]
        input_ids.append(_QUESTION_TOKEN_ID)
        input_ids.extend(tokenizer.encode(question, add_special_tokens=False))
        input_ids.append(_ANSWER_TOKEN_ID)
        input_ids.extend(tokenizer.encode(answer, add_special_tokens=False))
    input_ids.append(_QUESTION_TOKEN_ID)
    input_ids.extend(tokenizer.encode(next_question, add_special_tokens=False))
    input_ids.append(_EOS_TOKEN_ID)

    # Build an int64 tensor directly. torch.Tensor(...) would first materialise
    # float32 values. Place it on the model's device so the call keeps working
    # after the model is moved (e.g. to a GPU).
    input_tensor = torch.tensor([input_ids], dtype=torch.long, device=model.device)

    generate_kwargs = {}
    if num_return_sequences is not None:
        generate_kwargs["num_return_sequences"] = num_return_sequences

    with torch.no_grad():
        sampling_outputs = model.generate(
            input_ids=input_tensor,
            max_length=512,
            do_sample=True,
            top_p=0.95,
            top_k=10,
            **generate_kwargs,
        )

    rewrites = []
    print("\nSampling Outputs:")
    for rank, sequence in enumerate(sampling_outputs, start=1):
        rewrite_question = tokenizer.decode(sequence, skip_special_tokens=True)
        print(f'{rank}: {rewrite_question}')
        rewrites.append(rewrite_question)
    return rewrites
# Example conversation: one earlier (question, answer) turn explaining what
# loghub is, then a follow-up ("What are its advantages?") whose pronoun
# refers back to that turn.
history = [
    [
        'loghub是什么',
        'AWS 上的loghub解决方案可帮助组织在单个控制面板上收集、分析和显示 Amazon CloudWatch Logs。该解决方案可整合、管理和分析来自各种来源的日志文件,例如访问、配置更改和计费事件的审计日志。您也可以从多个账户和 AWS 区域收集 Amazon CloudWatch Logs。',
    ]
]
next_question = '它的优点是什么?'
create_queries(history=history, next_question=next_question)
# Sample output (generation uses sampling, so it can differ between runs):
# 1: loghub解决方案的优点是什么?
#    ("What are the advantages of the loghub solution?")
- Downloads last month: 8
This model does not have enough activity to be deployed to the Inference API (serverless) yet. Increase its social visibility and check back later, or deploy to Inference Endpoints (dedicated) instead.