164 lines
6.3 KiB
Python
164 lines
6.3 KiB
Python
# encoding: utf-8
|
||
'''
|
||
@author: 温进
|
||
@file: code.py.py
|
||
@time: 2023/10/23 下午5:31
|
||
@desc:
|
||
'''
|
||
|
||
import streamlit as st
|
||
import os
|
||
import time
|
||
import traceback
|
||
from typing import Literal, Dict, Tuple
|
||
from st_aggrid import AgGrid, JsCode
|
||
from st_aggrid.grid_options_builder import GridOptionsBuilder
|
||
import pandas as pd
|
||
|
||
# from configs.model_config import embedding_model_dict, kbs_config, EMBEDDING_MODEL, DEFAULT_VS_TYPE, WEB_CRAWL_PATH
|
||
from .utils import *
|
||
from coagent.utils.path_utils import *
|
||
from coagent.service.service_factory import get_cb_details, get_cb_details_by_cb_name
|
||
from coagent.orm import table_init
|
||
|
||
|
||
from configs.model_config import EMBEDDING_DEVICE, EMBEDDING_ENGINE, EMBEDDING_MODEL, embedding_model_dict,llm_model_dict
|
||
# SENTENCE_SIZE = 100
|
||
|
||
cell_renderer = JsCode("""function(params) {if(params.value==true){return '✓'}else{return '×'}}""")
|
||
|
||
|
||
def file_exists(cb: str, selected_rows: List) -> Tuple[str, str]:
|
||
'''
|
||
check whether the dir exist in local file
|
||
return the dir's name and path if it exists.
|
||
'''
|
||
if selected_rows:
|
||
file_name = selected_rows[0]["code_name"]
|
||
file_path = get_file_path(cb, file_name, KB_ROOT_PATH)
|
||
if os.path.isfile(file_path):
|
||
return file_name, file_path
|
||
return "", ""
|
||
|
||
|
||
def code_page(api: ApiRequest):
|
||
# 判断表是否存在并进行初始化
|
||
table_init()
|
||
|
||
try:
|
||
logger.info(get_cb_details())
|
||
cb_list = {x["code_name"]: x for x in get_cb_details()}
|
||
except Exception as e:
|
||
logger.exception(e)
|
||
st.error("获取知识库信息错误,请检查是否已按照 `README.md` 中 `4 知识库初始化与迁移` 步骤完成初始化或迁移,或是否为数据库连接错误。")
|
||
st.stop()
|
||
cb_names = list(cb_list.keys())
|
||
|
||
if "selected_cb_name" in st.session_state and st.session_state["selected_cb_name"] in cb_names:
|
||
selected_cb_index = cb_names.index(st.session_state["selected_cb_name"])
|
||
else:
|
||
selected_cb_index = 0
|
||
|
||
def format_selected_cb(cb_name: str) -> str:
|
||
if cb := cb_list.get(cb_name):
|
||
return f"{cb_name} ({cb['code_path']})"
|
||
else:
|
||
return cb_name
|
||
|
||
selected_cb = st.selectbox(
|
||
"请选择或新建代码知识库:",
|
||
cb_names + ["新建代码知识库"],
|
||
format_func=format_selected_cb,
|
||
index=selected_cb_index
|
||
)
|
||
|
||
if selected_cb == "新建代码知识库":
|
||
with st.form("新建代码知识库"):
|
||
|
||
cb_name = st.text_input(
|
||
"新建代码知识库名称",
|
||
placeholder="新代码知识库名称,不支持中文命名",
|
||
key="cb_name",
|
||
)
|
||
|
||
file = st.file_uploader("上传代码库 zip 文件",
|
||
['.zip'],
|
||
accept_multiple_files=False,
|
||
)
|
||
|
||
do_interpret = st.checkbox('**代码解读**', value=False, help='代码解读会针对每个代码文件通过 LLM 获取解释并且向量化存储。当代码文件较多时,\
|
||
导入速度会变慢,且如果使用收费 API 的话可能会造成较大花费。如果要使用基于描述的代码问答模式,此项必须勾选', key='do_interpret')
|
||
|
||
logger.info(f'do_interpret={do_interpret}')
|
||
submit_create_kb = st.form_submit_button(
|
||
"新建",
|
||
use_container_width=True,
|
||
)
|
||
|
||
if submit_create_kb:
|
||
# unzip file
|
||
logger.info('files={}'.format(file))
|
||
|
||
if not cb_name or not cb_name.strip():
|
||
st.error(f"知识库名称不能为空!")
|
||
elif cb_name in cb_list:
|
||
st.error(f"名为 {cb_name} 的知识库已经存在!")
|
||
elif file.type not in ['application/zip', 'application/x-zip-compressed']:
|
||
logger.error(f"{file.type}")
|
||
st.error('请先上传 zip 文件,再新建代码知识库')
|
||
else:
|
||
ret = api.create_code_base(
|
||
cb_name,
|
||
file,
|
||
do_interpret,
|
||
no_remote_api=True,
|
||
embed_engine=EMBEDDING_ENGINE,
|
||
embed_model=EMBEDDING_MODEL,
|
||
embed_model_path=embedding_model_dict[EMBEDDING_MODEL],
|
||
embedding_device=EMBEDDING_DEVICE,
|
||
llm_model=LLM_MODEL,
|
||
api_key=llm_model_dict[LLM_MODEL]["api_key"],
|
||
api_base_url=llm_model_dict[LLM_MODEL]["api_base_url"],
|
||
)
|
||
st.toast(ret.get("msg", " "))
|
||
st.session_state["selected_cb_name"] = cb_name
|
||
st.experimental_rerun()
|
||
elif selected_cb:
|
||
cb = selected_cb
|
||
|
||
# 知识库详情
|
||
cb_details = get_cb_details_by_cb_name(cb)
|
||
if not len(cb_details):
|
||
st.info(f"代码知识库 `{cb}` 中暂无信息")
|
||
else:
|
||
logger.info(cb_details)
|
||
st.write(f"代码知识库 `{cb}` 加载成功,中含有以下信息:")
|
||
|
||
st.write('代码知识库 `{}` 代码文件数=`{}`'.format(cb_details['code_name'],
|
||
cb_details.get('code_file_num', 'unknown')))
|
||
|
||
st.write('代码知识库 `{}` 知识图谱节点数=`{}`'.format(cb_details['code_name'], cb_details['code_graph_node_num']))
|
||
st.write('代码知识库 `{}` 是否进行代码解读=`{}`'.format(cb_details['code_name'], cb_details['do_interpret']))
|
||
|
||
st.divider()
|
||
|
||
cols = st.columns(3)
|
||
|
||
if cols[2].button(
|
||
"删除知识库",
|
||
use_container_width=True,
|
||
):
|
||
ret = api.delete_code_base(cb,
|
||
no_remote_api=True,
|
||
embed_engine=EMBEDDING_ENGINE,
|
||
embed_model=EMBEDDING_MODEL,
|
||
embed_model_path=embedding_model_dict[EMBEDDING_MODEL],
|
||
embedding_device=EMBEDDING_DEVICE,
|
||
llm_model=LLM_MODEL,
|
||
api_key=llm_model_dict[LLM_MODEL]["api_key"],
|
||
api_base_url=llm_model_dict[LLM_MODEL]["api_base_url"],
|
||
)
|
||
st.toast(ret.get("msg", "删除成功"))
|
||
time.sleep(0.05)
|
||
st.experimental_rerun()
|