2023-11-07 19:44:47 +08:00
|
|
|
# encoding: utf-8
|
|
|
|
'''
|
|
|
|
@author: 温进
|
|
|
|
@file: cb_api.py
|
|
|
|
@time: 2023/10/23 下午7:08
|
|
|
|
@desc:
|
|
|
|
'''
|
|
|
|
|
|
|
|
import urllib, os, json, traceback
|
|
|
|
from typing import List, Dict
|
|
|
|
import shutil
|
|
|
|
|
|
|
|
from fastapi.responses import StreamingResponse, FileResponse
|
|
|
|
from fastapi import File, Form, Body, Query, UploadFile
|
|
|
|
from langchain.docstore.document import Document
|
|
|
|
|
|
|
|
from .service_factory import KBServiceFactory
|
2024-01-26 14:03:25 +08:00
|
|
|
from coagent.utils.server_utils import BaseResponse, ListResponse
|
|
|
|
from coagent.utils.path_utils import *
|
|
|
|
from coagent.orm.commands import *
|
|
|
|
from coagent.db_handler.graph_db_handler.nebula_handler import NebulaHandler
|
|
|
|
from coagent.db_handler.vector_db_handler.chroma_handler import ChromaHandler
|
|
|
|
from coagent.base_configs.env_config import (
|
|
|
|
CB_ROOT_PATH,
|
|
|
|
NEBULA_HOST, NEBULA_PORT, NEBULA_USER, NEBULA_PASSWORD, NEBULA_STORAGED_PORT,
|
|
|
|
CHROMA_PERSISTENT_PATH
|
2023-11-07 19:44:47 +08:00
|
|
|
)
|
|
|
|
|
2023-12-07 20:17:21 +08:00
|
|
|
|
2024-01-26 14:03:25 +08:00
|
|
|
# from configs.model_config import (
|
|
|
|
# CB_ROOT_PATH
|
|
|
|
# )
|
|
|
|
|
|
|
|
# from coagent.codebase_handler.codebase_handler import CodeBaseHandler
|
|
|
|
from coagent.llm_models.llm_config import EmbedConfig, LLMConfig
|
|
|
|
from coagent.codechat.codebase_handler.codebase_handler import CodeBaseHandler
|
2023-11-07 19:44:47 +08:00
|
|
|
|
|
|
|
from loguru import logger
|
|
|
|
|
|
|
|
|
|
|
|
async def list_cbs():
    """Return every code knowledge base recorded in the database.

    The records come from ``list_cbs_from_db()`` and are wrapped unchanged in
    a ``ListResponse``.
    """
    cb_records = list_cbs_from_db()
    return ListResponse(data=cb_records)
|
|
|
|
|
|
|
|
|
2023-12-07 20:17:21 +08:00
|
|
|
async def create_cb(zip_file,
                    cb_name: str = Body(..., examples=["samples"]),
                    code_path: str = Body(..., examples=["samples"]),
                    do_interpret: bool = Body(..., examples=["samples"]),
                    # NOTE(review): the model/embedding parameters below are annotated
                    # ``bool`` but look like strings/floats; annotations kept unchanged
                    # for interface compatibility — confirm with callers.
                    api_key: bool = Body(..., examples=["samples"]),
                    api_base_url: bool = Body(..., examples=["samples"]),
                    embed_model: bool = Body(..., examples=["samples"]),
                    embed_model_path: bool = Body(..., examples=["samples"]),
                    embed_engine: bool = Body(..., examples=["samples"]),
                    model_name: bool = Body(..., examples=["samples"]),
                    temperature: bool = Body(..., examples=["samples"]),
                    model_device: bool = Body(..., examples=["samples"]),
                    embed_config: EmbedConfig = None,
                    ) -> BaseResponse:
    """Build a new code knowledge base and register it in the database.

    Args:
        zip_file: passed straight to ``CodeBaseHandler.import_code``.
        cb_name: name of the code knowledge base to create.
        code_path: source location handed to ``CodeBaseHandler``.
        do_interpret: forwarded to ``import_code`` and stored with the record.
        api_key .. model_device: request parameters forwarded (via ``locals()``)
            to ``EmbedConfig`` / ``LLMConfig``.
        embed_config: pre-built embedding config; built from the request
            parameters when ``None``.

    Returns:
        ``BaseResponse`` with code 200 on success, 403 when ``validate_kb_name``
        rejects the name, 404 for an empty name or an existing code base, and
        500 when building or registering the code base fails.
    """
    logger.info('cb_name={}, zip_path={}, do_interpret={}'.format(cb_name, code_path, do_interpret))

    # ``locals()`` deliberately forwards every request parameter to the config
    # constructors, so no other local name may be bound before these two lines.
    embed_config: EmbedConfig = EmbedConfig(**locals()) if embed_config is None else embed_config
    llm_config: LLMConfig = LLMConfig(**locals())

    # Reject missing/blank names before calling validate_kb_name, whose
    # handling of None is not visible here.
    if cb_name is None or cb_name.strip() == "":
        return BaseResponse(code=404, msg="知识库名称不能为空,请重新填写知识库名称")
    if not validate_kb_name(cb_name):
        return BaseResponse(code=403, msg="Don't attack me")

    if cb_exists(cb_name):
        return BaseResponse(code=404, msg=f"已存在同名代码知识库 {cb_name}")

    try:
        logger.info('start build code base')
        cbh = CodeBaseHandler(cb_name, code_path, embed_config=embed_config, llm_config=llm_config)
        vertices_num, edge_num, file_num = cbh.import_code(zip_file=zip_file, do_interpret=do_interpret)
        logger.info('build code base done')

        # register the new code base in the database table
        add_cb_to_db(cb_name, cbh.code_path, vertices_num, file_num, do_interpret)
        logger.info('add cb to mysql table success')
    except Exception as e:
        logger.exception(e)
        return BaseResponse(code=500, msg=f"创建代码知识库出错: {e}")

    return BaseResponse(code=200, msg=f"已新增代码知识库 {cb_name}")
|
|
|
|
|
|
|
|
|
2024-01-26 14:03:25 +08:00
|
|
|
async def delete_cb(
        cb_name: str = Body(..., examples=["samples"]),
        # NOTE(review): the model/embedding parameters below are annotated
        # ``bool`` but look like strings/floats; annotations kept unchanged
        # for interface compatibility — confirm with callers.
        api_key: bool = Body(..., examples=["samples"]),
        api_base_url: bool = Body(..., examples=["samples"]),
        embed_model: bool = Body(..., examples=["samples"]),
        embed_model_path: bool = Body(..., examples=["samples"]),
        embed_engine: bool = Body(..., examples=["samples"]),
        model_name: bool = Body(..., examples=["samples"]),
        temperature: bool = Body(..., examples=["samples"]),
        model_device: bool = Body(..., examples=["samples"]),
        embed_config: EmbedConfig = None,
) -> BaseResponse:
    """Delete a code knowledge base: its DB record, local files and stored code base.

    Args:
        cb_name: name of the code knowledge base to delete.
        api_key .. model_device: request parameters forwarded (via ``locals()``)
            to ``EmbedConfig`` / ``LLMConfig``.
        embed_config: pre-built embedding config; built from the request
            parameters when ``None``.

    Returns:
        ``BaseResponse`` with code 200 (also when no such code base exists),
        403 when ``validate_kb_name`` rejects the name, 404 for an empty name,
        and 500 when any deletion step raises.
    """
    logger.info('cb_name={}'.format(cb_name))

    # ``locals()`` deliberately forwards every request parameter to the config
    # constructors, so no other local name may be bound before these two lines.
    embed_config: EmbedConfig = EmbedConfig(**locals()) if embed_config is None else embed_config
    llm_config: LLMConfig = LLMConfig(**locals())

    # Reject missing/blank names before calling validate_kb_name, whose
    # handling of None is not visible here.
    if cb_name is None or cb_name.strip() == "":
        return BaseResponse(code=404, msg="知识库名称不能为空,请重新填写知识库名称")
    if not validate_kb_name(cb_name):
        return BaseResponse(code=403, msg="Don't attack me")

    if cb_exists(cb_name):
        try:
            delete_cb_from_db(cb_name)

            # Delete the local copy only if it is there. A missing directory
            # must not abort the store cleanup below, because the DB record
            # is already gone and a retry would find nothing to delete.
            local_path = CB_ROOT_PATH + os.sep + cb_name
            if os.path.isdir(local_path):
                shutil.rmtree(local_path)

            # delete the stored code base itself
            cbh = CodeBaseHandler(cb_name, embed_config=embed_config, llm_config=llm_config)
            cbh.delete_codebase(codebase_name=cb_name)
        except Exception as e:
            logger.exception(e)
            return BaseResponse(code=500, msg=f"删除代码知识库出错: {e}")

    return BaseResponse(code=200, msg=f"已删除代码知识库 {cb_name}")
|
|
|
|
|
|
|
|
|
|
|
|
def _log_verbose_level() -> int:
    """Return the ``log_verbose`` environment variable as an int (0 if unset or non-numeric)."""
    try:
        return int(os.environ.get("log_verbose", "0"))
    except ValueError:
        return 0


def search_code(cb_name: str = Body(..., examples=["sofaboot"]),
                query: str = Body(..., examples=['你好']),
                code_limit: int = Body(..., examples=['1']),
                search_type: str = Body(..., examples=['你好']),
                history_node_list: list = Body(...),
                # NOTE(review): the model/embedding parameters below are annotated
                # ``bool`` but look like strings/floats; annotations kept unchanged
                # for interface compatibility — confirm with callers.
                api_key: bool = Body(..., examples=["samples"]),
                api_base_url: bool = Body(..., examples=["samples"]),
                embed_model: bool = Body(..., examples=["samples"]),
                embed_model_path: bool = Body(..., examples=["samples"]),
                embed_engine: bool = Body(..., examples=["samples"]),
                model_name: bool = Body(..., examples=["samples"]),
                temperature: bool = Body(..., examples=["samples"]),
                model_device: bool = Body(..., examples=["samples"]),
                use_nh: bool = True,
                local_graph_path: str = '',
                embed_config: EmbedConfig = None,
                ) -> dict:
    """Search a code knowledge base for code relevant to ``query``.

    Args:
        cb_name: name of the code knowledge base to search.
        query: search text passed to ``CodeBaseHandler.search_code``.
        code_limit: forwarded as ``limit`` to ``search_code``.
        search_type: forwarded to ``search_code``.
        history_node_list: only logged here.
        api_key .. model_device: request parameters forwarded (via ``locals()``)
            to ``EmbedConfig`` / ``LLMConfig``.
        use_nh, local_graph_path: forwarded to ``CodeBaseHandler``.
        embed_config: pre-built embedding config; built from the request
            parameters when ``None``.

    Returns:
        ``{'context': ..., 'related_vertices': ...}`` on success, or ``{}``
        when loading or searching the code base raises (the error is logged).
    """
    # Compare the verbosity numerically: the previous string comparison
    # treated e.g. "10" as lower than "2". The helper's result is not bound to
    # a local name, so it does not leak into the ``locals()`` calls below.
    if _log_verbose_level() >= 2:
        logger.info(f'local_graph_path={local_graph_path}')
        logger.info('cb_name={}'.format(cb_name))
        logger.info('query={}'.format(query))
        logger.info('code_limit={}'.format(code_limit))
        logger.info('search_type={}'.format(search_type))
        logger.info('history_node_list={}'.format(history_node_list))

    # ``locals()`` deliberately forwards every request parameter to the config
    # constructors, so no other local name may be bound before these two lines.
    embed_config: EmbedConfig = EmbedConfig(**locals()) if embed_config is None else embed_config
    llm_config: LLMConfig = LLMConfig(**locals())

    try:
        # load codebase
        cbh = CodeBaseHandler(codebase_name=cb_name, embed_config=embed_config, llm_config=llm_config,
                              use_nh=use_nh, local_graph_path=local_graph_path)

        # search code
        context, related_vertices = cbh.search_code(query, search_type=search_type, limit=code_limit)

        return {
            'context': context,
            'related_vertices': related_vertices
        }
    except Exception as e:
        logger.exception(e)
        return {}
|
|
|
|
|
|
|
|
|
2023-12-26 11:41:53 +08:00
|
|
|
def search_related_vertices(cb_name: str = Body(..., examples=["sofaboot"]),
                            vertex: str = Body(..., examples=['***'])) -> dict:
    """Return the ids of graph vertices directly connected to ``vertex``.

    For a vertex id ending in ``.java``, only neighbouring ``package`` vertices
    are returned. Otherwise every neighbour is returned.

    Args:
        cb_name: Nebula space name of the code knowledge base.
        vertex: id of the vertex whose neighbours are wanted (request input).

    Returns:
        ``{'vertices': [id, ...]}`` on success, or ``{}`` when the query fails
        (the error is logged).
    """
    logger.info('cb_name={}'.format(cb_name))
    logger.info('vertex={}'.format(vertex))

    try:
        # load codebase
        nh = NebulaHandler(host=NEBULA_HOST, port=NEBULA_PORT, username=NEBULA_USER,
                           password=NEBULA_PASSWORD, space_name=cb_name)

        # ``vertex`` comes from the request and is embedded in a single-quoted
        # nGQL literal. Escape backslashes and quotes so it cannot terminate
        # the string and inject query text.
        safe_vertex = vertex.replace('\\', '\\\\').replace("'", "\\'")

        if vertex.endswith(".java"):
            cypher = f'''MATCH (v1)--(v2:package) WHERE id(v1) == '{safe_vertex}' RETURN id(v2) as id;'''
        else:
            cypher = f'''MATCH (v1)--(v2) WHERE id(v1) == '{safe_vertex}' RETURN id(v2) as id;'''

        cypher_res = nh.execute_cypher(cypher=cypher, format_res=True)

        related_vertices = [i.as_string() for i in cypher_res.get('id', [])]

        return {
            'vertices': related_vertices
        }
    except Exception as e:
        logger.exception(e)
        return {}
|
|
|
|
|
|
|
|
|
|
|
|
def search_code_by_vertex(cb_name: str = Body(..., examples=["sofaboot"]),
                          vertex: str = Body(..., examples=['***'])) -> dict:
    """Return the source code of the package that contains ``vertex``.

    The containing ``package`` vertex is looked up in Nebula via a ``contain``
    edge. Its stored ``code_text`` metadata is then read from Chroma.

    Args:
        cb_name: Nebula space / Chroma collection name of the code knowledge base.
        vertex: id of the contained vertex (request input).

    Returns:
        ``{'code': code_text}``. ``code_text`` is ``''`` when no containing
        package or Chroma record is found, or when any step raises (the error
        is logged).
    """
    try:
        nh = NebulaHandler(host=NEBULA_HOST, port=NEBULA_PORT, username=NEBULA_USER,
                           password=NEBULA_PASSWORD, space_name=cb_name)

        # ``vertex`` comes from the request and is embedded in a single-quoted
        # nGQL literal. Escape backslashes and quotes so it cannot terminate
        # the string and inject query text.
        safe_vertex = vertex.replace('\\', '\\\\').replace("'", "\\'")
        cypher = f'''MATCH (v1:package)-[e:contain]->(v2) WHERE id(v2) == '{safe_vertex}' RETURN id(v1) as id;'''
        cypher_res = nh.execute_cypher(cypher=cypher, format_res=True)

        related_vertices = [i.as_string() for i in cypher_res.get('id', [])]
        if not related_vertices:
            return {'code': ''}

        ch = ChromaHandler(path=CHROMA_PERSISTENT_PATH, collection_name=cb_name)
        chroma_res = ch.get(ids=related_vertices, include=['metadatas'])

        # only the first matching record's code is returned
        if chroma_res['result']['ids']:
            code_text = chroma_res['result']['metadatas'][0]['code_text']
        else:
            code_text = ''

        return {
            'code': code_text
        }
    except Exception as e:
        logger.exception(e)
        return {'code': ""}
|
2023-12-26 11:41:53 +08:00
|
|
|
|
|
|
|
|
2023-11-07 19:44:47 +08:00
|
|
|
def cb_exists_api(cb_name: str = Body(..., examples=["sofaboot"])) -> bool:
    """Report whether a code knowledge base named ``cb_name`` exists.

    Returns whatever ``cb_exists`` returns. If the lookup raises, the error is
    logged and ``False`` is returned instead.
    """
    found = False
    try:
        found = cb_exists(cb_name)
    except Exception as e:
        logger.exception(e)
    return found
|
|
|
|
|
|
|
|
|
|
|
|
|