FaYo
model
d8d694f
raw
history blame
3.3 kB
import shutil
from pathlib import Path
import streamlit as st
import torch
import yaml
from utils.rag.feature_store import gen_vector_db
from utils.rag.retriever import CacheRetriever
from utils.web_configs import WEB_CONFIGS
# ---- Basic configuration ----
# Overall context budget; build_rag_prompt subtracts twice the template length
# from it before passing it to the retriever as context_max_length.
CONTEXT_MAX_LENGTH = 3000
# RAG prompt template. Slot 1: retrieved product-manual context; slot 2: the
# customer's question. (Text asks the model to read the manual and answer in character.)
GENERATE_TEMPLATE = "这是说明书:“{}”\n 客户的问题:“{}” \n 请阅读说明并运用你的性格进行解答。"
def build_rag_prompt(rag_retriever: CacheRetriever, product_name, prompt):
    """Augment a customer question with product-manual context retrieved via RAG.

    Args:
        rag_retriever: Retriever cache; the ``"default"`` feature store is queried.
        product_name: Product name, prepended to the retrieval query.
        prompt: The customer's original question.

    Returns:
        ``GENERATE_TEMPLATE`` filled with the retrieved context and ``prompt``
        when usable context was found; otherwise ``prompt`` unchanged.
    """
    real_retriever = rag_retriever.get(fs_id="default")

    # NOTE(review): get() apparently returns a tuple when no usable retriever is
    # available — confirm against CacheRetriever. Fall back to the bare question
    # (same as the empty-context path below) instead of returning "" and
    # silently dropping the user's prompt.
    if isinstance(real_retriever, tuple):
        print(f" @@@ GOT real_retriever == tuple : {real_retriever}")
        return prompt

    # Leave room for the template text itself inside the overall context budget.
    _chunk, db_context, references = real_retriever.query(
        f"商品名:{product_name}{prompt}", context_max_length=CONTEXT_MAX_LENGTH - 2 * len(GENERATE_TEMPLATE)
    )
    print(f"db_context = {db_context}")

    if db_context is not None and len(db_context) > 1:
        prompt_rag = GENERATE_TEMPLATE.format(db_context, prompt)
    else:
        # No usable context retrieved: send the question without augmentation.
        print("db_context get error")
        prompt_rag = prompt

    print(f"RAG reference = {references}")
    print("=" * 20)

    return prompt_rag
def init_rag_retriever(rag_config: str, db_path: str):
    """Create a CacheRetriever and pre-load its ``"default"`` feature store.

    Args:
        rag_config: Path to the RAG configuration file.
        db_path: Directory of the generated vector database (used as work_dir).

    Returns:
        The CacheRetriever instance, with the default store already requested once.
    """
    # Release memory held by PyTorch's CUDA caching allocator before loading models.
    torch.cuda.empty_cache()

    cache = CacheRetriever(config_path=rag_config)
    # Request the default store up front so it is created/cached now rather
    # than on the first user query.
    cache.get(fs_id="default", config_path=rag_config, work_dir=db_path)
    return cache
def gen_rag_db(force_gen=False):
    """Generate the RAG vector database from the product instruction manuals.

    The file referenced by each product's ``instruction`` entry in
    ``WEB_CONFIGS.PRODUCT_INFO_YAML_PATH`` is copied into a scratch directory,
    which is then indexed by ``gen_vector_db`` into ``WEB_CONFIGS.RAG_VECTOR_DB_DIR``.

    Args:
        force_gen: When True, delete and rebuild the database even if its
            directory already exists. When False (default), return immediately
            if the directory exists.

    Raises:
        Re-raises any error from reading the YAML, copying manuals or building
        the index. In that case the scratch directory and any partially written
        database directory are removed first.
    """
    db_dir = Path(WEB_CONFIGS.RAG_VECTOR_DB_DIR)
    tmp_dir = Path(WEB_CONFIGS.PRODUCT_INSTRUCTION_DIR_GEN_DB_TMP)

    # An existing DB directory is treated as complete unless a rebuild is forced.
    if db_dir.exists():
        if not force_gen:
            return
        shutil.rmtree(db_dir)

    # Start from a clean scratch directory so only the listed manuals get indexed.
    if tmp_dir.exists():
        shutil.rmtree(tmp_dir)
    tmp_dir.mkdir(exist_ok=True, parents=True)

    try:
        # Read the product YAML and copy every manual into the scratch directory.
        with open(WEB_CONFIGS.PRODUCT_INFO_YAML_PATH, "r", encoding="utf-8") as f:
            product_info_dict = yaml.safe_load(f)

        for info in product_info_dict.values():
            src = Path(info["instruction"])
            # NOTE: manuals that share a file name overwrite each other here.
            shutil.copyfile(src, tmp_dir / src.name)

        print("Generating rag database, pls wait ...")
        gen_vector_db(
            WEB_CONFIGS.RAG_CONFIG_PATH,
            str(tmp_dir.absolute()),
            WEB_CONFIGS.RAG_VECTOR_DB_DIR,
        )
    except BaseException:
        # Any DB directory present now was created by this failed run. The
        # existence check above would otherwise mistake it for a complete DB
        # and skip regeneration forever, so remove it along with the scratch copies.
        shutil.rmtree(db_dir, ignore_errors=True)
        shutil.rmtree(tmp_dir, ignore_errors=True)
        raise

    # Remove the scratch copies of the manuals.
    shutil.rmtree(tmp_dir)
@st.cache_resource
def load_rag_model():
    """Ensure the vector database exists, then return an initialised retriever.

    Wrapped in ``st.cache_resource`` so Streamlit reuses the returned retriever
    instead of rebuilding it on every rerun.
    """
    # Builds the DB only if its directory is missing (force_gen defaults to False).
    gen_rag_db()

    return init_rag_retriever(
        rag_config=WEB_CONFIGS.RAG_CONFIG_PATH,
        db_path=WEB_CONFIGS.RAG_VECTOR_DB_DIR,
    )