import streamlit as st
import os
import time
import traceback
from typing import Literal, Dict, List, Tuple

from st_aggrid import AgGrid, JsCode
from st_aggrid.grid_options_builder import GridOptionsBuilder
import pandas as pd

from .utils import *
from coagent.utils.path_utils import *
from coagent.service.service_factory import get_kb_details, get_kb_doc_details
from coagent.orm import table_init
from configs.model_config import (
    KB_ROOT_PATH, kbs_config, DEFAULT_VS_TYPE, WEB_CRAWL_PATH,
    EMBEDDING_DEVICE, EMBEDDING_ENGINE, EMBEDDING_MODEL,
    embedding_model_dict, llm_model_dict, LLM_MODEL,
)

# SENTENCE_SIZE = 100

# Render boolean grid cells as check marks instead of raw True/False values.
cell_renderer = JsCode("""function(params) {if(params.value==true){return '✓'}else{return '×'}}""")


def config_aggrid(
        df: pd.DataFrame,
        columns: Dict[Tuple[str, str], Dict] = {},
        selection_mode: Literal["single", "multiple", "disabled"] = "single",
        use_checkbox: bool = False,
) -> GridOptionsBuilder:
    gb = GridOptionsBuilder.from_dataframe(df)
    gb.configure_column("No", width=40)
    for (col, header), kw in columns.items():
        gb.configure_column(col, header, wrapHeaderText=True, **kw)
    gb.configure_selection(
        selection_mode=selection_mode,
        use_checkbox=use_checkbox,
        # pre_selected_rows=st.session_state.get("selected_rows", [0]),
    )
    return gb


def file_exists(kb: str, selected_rows: List) -> Tuple[str, str]:
    '''
    Check whether a doc file exists in the local knowledge base folder.
    Return the file's name and path if it exists.
    '''
    if selected_rows:
        file_name = selected_rows[0]["file_name"]
        file_path = get_file_path(kb, file_name, KB_ROOT_PATH)
        if os.path.isfile(file_path):
            return file_name, file_path
    return "", ""


def knowledge_page(
        api: ApiRequest,
        embedding_model_dict: dict = embedding_model_dict,
        kbs_config: dict = kbs_config,
        embedding_model: str = EMBEDDING_MODEL,
        default_vs_type: str = DEFAULT_VS_TYPE,
        web_crawl_path: str = WEB_CRAWL_PATH,
):
    # Make sure the database tables exist and initialize them if needed.
    table_init()

    try:
        kb_list = {x["kb_name"]: x for x in get_kb_details(KB_ROOT_PATH)}
    except Exception as e:
        st.error("获取知识库信息错误,请检查是否已按照 `README.md` 中 `4 知识库初始化与迁移` 步骤完成初始化或迁移,或是否为数据库连接错误。")
        st.stop()
    kb_names = list(kb_list.keys())

    if "selected_kb_name" in st.session_state and st.session_state["selected_kb_name"] in kb_names:
        selected_kb_index = kb_names.index(st.session_state["selected_kb_name"])
    else:
        selected_kb_index = 0

    def format_selected_kb(kb_name: str) -> str:
        if kb := kb_list.get(kb_name):
            return f"{kb_name} ({kb['vs_type']} @ {kb['embed_model']})"
        else:
            return kb_name

    selected_kb = st.selectbox(
        "请选择或新建知识库:",
        kb_names + ["新建知识库"],
        format_func=format_selected_kb,
        index=selected_kb_index,
    )

    if selected_kb == "新建知识库":
        with st.form("新建知识库"):
            kb_name = st.text_input(
                "新建知识库名称",
                placeholder="新知识库名称,不支持中文命名",
                key="kb_name",
            )

            cols = st.columns(2)

            vs_types = list(kbs_config.keys())
            vs_type = cols[0].selectbox(
                "向量库类型",
                vs_types,
                index=vs_types.index(default_vs_type),
                key="vs_type",
            )

            embed_models = list(embedding_model_dict.keys())
            embed_model = cols[1].selectbox(
                "Embedding 模型",
                embed_models,
                index=embed_models.index(embedding_model),
                key="embed_model",
            )

            submit_create_kb = st.form_submit_button(
                "新建",
                # disabled=not bool(kb_name),
                use_container_width=True,
            )

        if submit_create_kb:
            if not kb_name or not kb_name.strip():
                st.error(f"知识库名称不能为空!")
            elif kb_name in kb_list:
                st.error(f"名为 {kb_name} 的知识库已经存在!")
            else:
                ret = api.create_knowledge_base(
                    knowledge_base_name=kb_name,
                    vector_store_type=vs_type,
                    embed_model=embed_model,
                    embed_engine=EMBEDDING_ENGINE,
                    embedding_device=EMBEDDING_DEVICE,
                    embed_model_path=embedding_model_dict[embed_model],
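                    # The LLM API credentials are forwarded alongside the embedding settings;
                    # presumably they are only needed when the embedding backend is an
                    # API-based service rather than a local model.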
                    api_key=llm_model_dict[LLM_MODEL]["api_key"],
                    api_base_url=llm_model_dict[LLM_MODEL]["api_base_url"],
                )
                st.toast(ret.get("msg", " "))
                st.session_state["selected_kb_name"] = kb_name
                st.experimental_rerun()

    elif selected_kb:
        kb = selected_kb

        # Upload files.
        # sentence_size = st.slider("文本入库分句长度限制", 1, 1000, SENTENCE_SIZE, disabled=True)
        files = st.file_uploader(
            "上传知识文件",
            [i for ls in LOADER2EXT_DICT.values() for i in ls],
            accept_multiple_files=True,
        )

        if st.button(
                "添加文件到知识库",
                # help="请先上传文件,再点击添加",
                # use_container_width=True,
                disabled=len(files) == 0,
        ):
            data = [
                {
                    "file": f,
                    "knowledge_base_name": kb,
                    "not_refresh_vs_cache": True,
                    "embed_model": EMBEDDING_MODEL,
                    "embed_model_path": embedding_model_dict[EMBEDDING_MODEL],
                    "model_device": EMBEDDING_DEVICE,
                    "embed_engine": EMBEDDING_ENGINE,
                    "api_key": llm_model_dict[LLM_MODEL]["api_key"],
                    "api_base_url": llm_model_dict[LLM_MODEL]["api_base_url"],
                }
                for f in files
            ]
            # Only refresh the vector store cache once, after the last file is uploaded.
            data[-1]["not_refresh_vs_cache"] = False
            for k in data:
                ret = api.upload_kb_doc(**k)
                if msg := check_success_msg(ret):
                    st.toast(msg, icon="✔")
                elif msg := check_error_msg(ret):
                    st.toast(msg, icon="✖")
            st.session_state.files = []

        base_url = st.text_input(
            "待获取内容的URL地址",
            placeholder="请填写正确可打开的URL地址",
            key="base_url",
        )
        if st.button(
                "添加URL内容到知识库",
                disabled=base_url is None or base_url == "",
        ):
            # Derive crawl output file names from the URL.
            filename = base_url.replace("https://", " ") \
                .replace("http://", " ").replace("/", " ") \
                .replace("?", " ").replace("=", " ").replace(".", " ").strip()
            html_name = "_".join(filename.split(" ") + ["html.jsonl"])
            text_name = "_".join(filename.split(" ") + ["text.jsonl"])
            html_path = os.path.join(web_crawl_path, html_name)
            text_path = os.path.join(web_crawl_path, text_name)
            # if not os.path.exists(text_dir) or :
            st.toast(base_url)
            st.toast(html_path)
            st.toast(text_path)
            res = api.web_crawl(
                base_url=base_url,
                html_dir=html_path,
                text_dir=text_path,
                do_dfs=False,
                reptile_lib="requests",
                method="get",
                time_sleep=2,
            )
            if res["status"] == 200:
                st.toast(res["response"], icon="✔")
                data = [{
                    "file": text_path,
                    "filename": text_name,
                    "knowledge_base_name": kb,
                    "not_refresh_vs_cache": False,
                    "embed_model": EMBEDDING_MODEL,
                    "embed_model_path": embedding_model_dict[EMBEDDING_MODEL],
                    "model_device": EMBEDDING_DEVICE,
                    "embed_engine": EMBEDDING_ENGINE,
                    "api_key": llm_model_dict[LLM_MODEL]["api_key"],
                    "api_base_url": llm_model_dict[LLM_MODEL]["api_base_url"],
                }]
                for k in data:
                    ret = api.upload_kb_doc(**k)
                    logger.info(ret)
                    if msg := check_success_msg(ret):
                        st.toast(msg, icon="✔")
                    elif msg := check_error_msg(ret):
                        st.toast(msg, icon="✖")
                st.session_state.files = []
            else:
                st.toast(res["response"], icon="✖")

            # Remove the intermediate HTML dump once the crawl result has been handled.
            if os.path.exists(html_path):
                os.remove(html_path)

        st.divider()

        # Knowledge base details.
        # st.info("请选择文件,点击按钮进行操作。")
        doc_details = pd.DataFrame(get_kb_doc_details(kb, KB_ROOT_PATH))
        if not len(doc_details):
            st.info(f"知识库 `{kb}` 中暂无文件")
        else:
            st.write(f"知识库 `{kb}` 中已有文件:")
            st.info("知识库中包含源文件与向量库,请从下表中选择文件后操作")
            doc_details.drop(columns=["kb_name"], inplace=True)
            doc_details = doc_details[[
                "No", "file_name", "document_loader", "text_splitter", "in_folder", "in_db",
            ]]
            # doc_details["in_folder"] = doc_details["in_folder"].replace(True, "✓").replace(False, "×")
            # doc_details["in_db"] = doc_details["in_db"].replace(True, "✓").replace(False, "×")
            gb = config_aggrid(
                doc_details,
                {
                    ("No", "序号"): {},
                    ("file_name", "文档名称"): {},
                    # ("file_ext", "文档类型"): {},
                    # ("file_version", "文档版本"): {},
                    ("document_loader", "文档加载器"): {},
                    ("text_splitter", "分词器"): {},
                    # ("create_time", "创建时间"): {},
                    ("in_folder", "源文件"): {"cellRenderer": cell_renderer},
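                    # "in_db" below reuses the same ✓/× JsCode renderer as "in_folder",
                    # so both boolean columns display as check marks in the AgGrid table.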
                    ("in_db", "向量库"): {"cellRenderer": cell_renderer},
                },
                "multiple",
            )

            doc_grid = AgGrid(
                doc_details,
                gb.build(),
                columns_auto_size_mode="FIT_CONTENTS",
                theme="alpine",
                custom_css={
                    "#gridToolBar": {"display": "none"},
                },
                allow_unsafe_jscode=True,
            )

            selected_rows = doc_grid.get("selected_rows", [])

            cols = st.columns(4)
            file_name, file_path = file_exists(kb, selected_rows)
            if file_path:
                with open(file_path, "rb") as fp:
                    cols[0].download_button(
                        "下载选中文档",
                        fp,
                        file_name=file_name,
                        use_container_width=True,
                    )
            else:
                cols[0].download_button(
                    "下载选中文档",
                    "",
                    disabled=True,
                    use_container_width=True,
                )

            st.write()
            # Split the selected files and load them into the vector store.
            if cols[1].button(
                    "重新添加至向量库" if selected_rows and (pd.DataFrame(selected_rows)["in_db"]).any() else "添加至向量库",
                    disabled=not file_exists(kb, selected_rows)[0],
                    use_container_width=True,
            ):
                for row in selected_rows:
                    api.update_kb_doc(
                        kb, row["file_name"],
                        embed_engine=EMBEDDING_ENGINE,
                        embed_model=EMBEDDING_MODEL,
                        embed_model_path=embedding_model_dict[EMBEDDING_MODEL],
                        model_device=EMBEDDING_DEVICE,
                        api_key=llm_model_dict[LLM_MODEL]["api_key"],
                        api_base_url=llm_model_dict[LLM_MODEL]["api_base_url"],
                    )
                st.experimental_rerun()

            # Remove the selected files from the vector store without deleting the source files.
            if cols[2].button(
                    "从向量库删除",
                    disabled=not (selected_rows and selected_rows[0]["in_db"]),
                    use_container_width=True,
            ):
                for row in selected_rows:
                    api.delete_kb_doc(
                        kb, row["file_name"],
                        embed_engine=EMBEDDING_ENGINE,
                        embed_model=EMBEDDING_MODEL,
                        embed_model_path=embedding_model_dict[EMBEDDING_MODEL],
                        model_device=EMBEDDING_DEVICE,
                        api_key=llm_model_dict[LLM_MODEL]["api_key"],
                        api_base_url=llm_model_dict[LLM_MODEL]["api_base_url"],
                    )
                st.experimental_rerun()

            if cols[3].button(
                    "从知识库中删除",
                    type="primary",
                    use_container_width=True,
            ):
                # The extra positional True asks the service to remove the source file
                # from the knowledge base folder as well, not only its vectors.
                for row in selected_rows:
                    ret = api.delete_kb_doc(
                        kb, row["file_name"], True,
                        embed_engine=EMBEDDING_ENGINE,
                        embed_model=EMBEDDING_MODEL,
                        embed_model_path=embedding_model_dict[EMBEDDING_MODEL],
                        model_device=EMBEDDING_DEVICE,
                        api_key=llm_model_dict[LLM_MODEL]["api_key"],
                        api_base_url=llm_model_dict[LLM_MODEL]["api_base_url"],
                    )
                    st.toast(ret.get("msg", " "))
                st.experimental_rerun()

        st.divider()

        cols = st.columns(3)
        # todo: freezed
        if cols[0].button(
                "依据源文件重建向量库",
                # help="无需上传文件,通过其它方式将文档拷贝到对应知识库content目录下,点击本按钮即可重建知识库。",
                use_container_width=True,
                type="primary",
        ):
            with st.spinner("向量库重构中,请耐心等待,勿刷新或关闭页面。"):
                empty = st.empty()
                empty.progress(0.0, "")
                for d in api.recreate_vector_store(
                        kb,
                        vs_type=default_vs_type,
                        embed_model=embedding_model,
                        embedding_device=EMBEDDING_DEVICE,
                        embed_model_path=embedding_model_dict[EMBEDDING_MODEL],
                        embed_engine=EMBEDDING_ENGINE,
                        api_key=llm_model_dict[LLM_MODEL]["api_key"],
                        api_base_url=llm_model_dict[LLM_MODEL]["api_base_url"],
                ):
                    if msg := check_error_msg(d):
                        st.toast(msg)
                    else:
                        empty.progress(d["finished"] / d["total"], f"正在处理: {d['doc']}")
                st.experimental_rerun()

        if cols[2].button(
                "删除知识库",
                use_container_width=True,
        ):
            ret = api.delete_knowledge_base(kb)
            st.toast(ret.get("msg", " "))
            time.sleep(1)
            st.experimental_rerun()
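

# Usage sketch (assumptions: ApiRequest is the HTTP client exported by .utils and
# accepts the service base URL; the address below is a placeholder for wherever the
# coagent API service actually listens):
#
#     from .utils import ApiRequest
#
#     api = ApiRequest(base_url="http://127.0.0.1:7861")
#     knowledge_page(api)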