codefuse-chatbot/coagent/service/cb_api.py

256 lines
9.8 KiB
Python

# encoding: utf-8
'''
@author: 温进
@file: cb_api.py
@time: 2023/10/23 下午7:08
@desc:
'''
import urllib, os, json, traceback
from typing import List, Dict
import shutil
from fastapi.responses import StreamingResponse, FileResponse
from fastapi import File, Form, Body, Query, UploadFile
from langchain.docstore.document import Document
from .service_factory import KBServiceFactory
from coagent.utils.server_utils import BaseResponse, ListResponse
from coagent.utils.path_utils import *
from coagent.orm.commands import *
from coagent.db_handler.graph_db_handler.nebula_handler import NebulaHandler
from coagent.db_handler.vector_db_handler.chroma_handler import ChromaHandler
from coagent.base_configs.env_config import (
CB_ROOT_PATH,
NEBULA_HOST, NEBULA_PORT, NEBULA_USER, NEBULA_PASSWORD, NEBULA_STORAGED_PORT,
CHROMA_PERSISTENT_PATH
)
# from configs.model_config import (
# CB_ROOT_PATH
# )
# from coagent.codebase_handler.codebase_handler import CodeBaseHandler
from coagent.llm_models.llm_config import EmbedConfig, LLMConfig
from coagent.codechat.codebase_handler.codebase_handler import CodeBaseHandler
from loguru import logger
async def list_cbs():
    """Return all code bases recorded in the database, wrapped in a ``ListResponse``."""
    code_bases = list_cbs_from_db()
    return ListResponse(data=code_bases)
async def create_cb(zip_file,
                    cb_name: str = Body(..., examples=["samples"]),
                    code_path: str = Body(..., examples=["samples"]),
                    do_interpret: bool = Body(..., examples=["samples"]),
                    api_key: str = Body(..., examples=["samples"]),
                    api_base_url: str = Body(..., examples=["samples"]),
                    embed_model: str = Body(..., examples=["samples"]),
                    embed_model_path: str = Body(..., examples=["samples"]),
                    embed_engine: str = Body(..., examples=["samples"]),
                    model_name: str = Body(..., examples=["samples"]),
                    temperature: float = Body(..., examples=["samples"]),
                    model_device: str = Body(..., examples=["samples"]),
                    embed_config: EmbedConfig = None,
                    ) -> BaseResponse:
    """Build a new code base and register it in the code-base table.

    The code is imported with ``CodeBaseHandler.import_code`` (from ``zip_file`` /
    ``code_path``), then a row is written with ``add_cb_to_db``.

    Args:
        zip_file: passed through to ``CodeBaseHandler.import_code(zip_file=...)``.
        cb_name: name of the code base to create.
        code_path: code location handed to ``CodeBaseHandler``.
        do_interpret: forwarded to ``import_code`` and stored in the table.
        api_key, api_base_url, embed_model, embed_model_path, embed_engine,
        model_name, temperature, model_device: model settings; they reach
            ``EmbedConfig`` / ``LLMConfig`` through ``locals()``.
        embed_config: prebuilt embedding config; built from the parameters when None.

    Returns:
        BaseResponse: 200 on success; 404 for an empty name or an existing code base;
        403 when ``validate_kb_name`` rejects the name; 500 if building fails.
    """
    logger.info('cb_name={}, code_path={}, do_interpret={}'.format(cb_name, code_path, do_interpret))
    # NOTE: both ``locals()`` calls forward every parameter by name to the config
    # constructors, so no other local variable may be bound before them.
    embed_config: EmbedConfig = EmbedConfig(**locals()) if embed_config is None else embed_config
    llm_config: LLMConfig = LLMConfig(**locals())

    # Reject a missing/blank name first; otherwise validate_kb_name would be handed None.
    if cb_name is None or cb_name.strip() == "":
        return BaseResponse(code=404, msg="知识库名称不能为空,请重新填写知识库名称")
    if not validate_kb_name(cb_name):
        return BaseResponse(code=403, msg="Don't attack me")
    if cb_exists(cb_name):
        return BaseResponse(code=404, msg=f"已存在同名代码知识库 {cb_name}")

    try:
        logger.info('start build code base')
        cbh = CodeBaseHandler(cb_name, code_path, embed_config=embed_config, llm_config=llm_config)
        vertices_num, _edge_num, file_num = cbh.import_code(zip_file=zip_file, do_interpret=do_interpret)
        logger.info('build code base done')
        # record the new code base in the table
        add_cb_to_db(cb_name, cbh.code_path, vertices_num, file_num, do_interpret)
        logger.info('add cb to mysql table success')
    except Exception as e:
        # logger.exception already records the traceback; no separate print needed
        logger.exception(e)
        return BaseResponse(code=500, msg=f"创建代码知识库出错: {e}")
    return BaseResponse(code=200, msg=f"已新增代码知识库 {cb_name}")
async def delete_cb(
        cb_name: str = Body(..., examples=["samples"]),
        api_key: str = Body(..., examples=["samples"]),
        api_base_url: str = Body(..., examples=["samples"]),
        embed_model: str = Body(..., examples=["samples"]),
        embed_model_path: str = Body(..., examples=["samples"]),
        embed_engine: str = Body(..., examples=["samples"]),
        model_name: str = Body(..., examples=["samples"]),
        temperature: float = Body(..., examples=["samples"]),
        model_device: str = Body(..., examples=["samples"]),
        embed_config: EmbedConfig = None,
) -> BaseResponse:
    """Delete a code base: its table row, its local directory and its stored code data.

    Args:
        cb_name: name of the code base to delete.
        api_key, api_base_url, embed_model, embed_model_path, embed_engine,
        model_name, temperature, model_device: model settings; they reach
            ``EmbedConfig`` / ``LLMConfig`` through ``locals()``.
        embed_config: prebuilt embedding config; built from the parameters when None.

    Returns:
        BaseResponse: 200 when deleted -- also when no code base with that name exists
        (nothing is removed then); 404 for an empty name; 403 when ``validate_kb_name``
        rejects the name; 500 if any deletion step fails.
    """
    logger.info('cb_name={}'.format(cb_name))
    # NOTE: both ``locals()`` calls forward every parameter by name to the config
    # constructors, so no other local variable may be bound before them.
    embed_config: EmbedConfig = EmbedConfig(**locals()) if embed_config is None else embed_config
    llm_config: LLMConfig = LLMConfig(**locals())

    # Validate the requested name; the blank check comes first so None never reaches
    # validate_kb_name.
    if cb_name is None or cb_name.strip() == "":
        return BaseResponse(code=404, msg="知识库名称不能为空,请重新填写知识库名称")
    if not validate_kb_name(cb_name):
        return BaseResponse(code=403, msg="Don't attack me")

    if cb_exists(cb_name):
        try:
            delete_cb_from_db(cb_name)
            # Remove the local copy. Plain concatenation (not os.path.join) keeps an absolute
            # cb_name under CB_ROOT_PATH. A missing directory is tolerated so the deletion does
            # not stop half-way with the table row already gone.
            local_path = CB_ROOT_PATH + os.sep + cb_name
            if os.path.isdir(local_path):
                shutil.rmtree(local_path)
            # delete the code base's stored data
            cbh = CodeBaseHandler(cb_name, embed_config=embed_config, llm_config=llm_config)
            cbh.delete_codebase(codebase_name=cb_name)
        except Exception as e:
            logger.exception(e)
            return BaseResponse(code=500, msg=f"删除代码知识库出错: {e}")
    return BaseResponse(code=200, msg=f"已删除代码知识库 {cb_name}")
def _log_verbose_level() -> int:
    """Return the ``log_verbose`` environment variable as an int (0 if unset or non-integer)."""
    try:
        return int(os.environ.get("log_verbose", "0"))
    except ValueError:
        return 0


def search_code(cb_name: str = Body(..., examples=["sofaboot"]),
                query: str = Body(..., examples=['你好']),
                code_limit: int = Body(..., examples=['1']),
                search_type: str = Body(..., examples=['你好']),
                history_node_list: list = Body(...),
                api_key: str = Body(..., examples=["samples"]),
                api_base_url: str = Body(..., examples=["samples"]),
                embed_model: str = Body(..., examples=["samples"]),
                embed_model_path: str = Body(..., examples=["samples"]),
                embed_engine: str = Body(..., examples=["samples"]),
                model_name: str = Body(..., examples=["samples"]),
                temperature: float = Body(..., examples=["samples"]),
                model_device: str = Body(..., examples=["samples"]),
                use_nh: bool = True,
                local_graph_path: str = '',
                embed_config: EmbedConfig = None,
                ) -> dict:
    """Search code base ``cb_name`` for ``query``.

    Args:
        cb_name: code base to search.
        query: search text passed to ``CodeBaseHandler.search_code``.
        code_limit: forwarded as ``limit``.
        search_type: forwarded as ``search_type``.
        history_node_list: only logged here.
        api_key, api_base_url, embed_model, embed_model_path, embed_engine,
        model_name, temperature, model_device: model settings; they reach
            ``EmbedConfig`` / ``LLMConfig`` through ``locals()``.
        use_nh, local_graph_path: forwarded to ``CodeBaseHandler``.
        embed_config: prebuilt embedding config; built from the parameters when None.

    Returns:
        dict: ``{'context': ..., 'related_vertices': ...}``, or ``{}`` if anything fails
        (the exception is logged).
    """
    # Compare numerically: the old string comparison ranked e.g. "10" below "2".
    # The level is not bound to a local because locals() is forwarded below.
    if _log_verbose_level() >= 2:
        logger.info(f'local_graph_path={local_graph_path}')
        logger.info('cb_name={}'.format(cb_name))
        logger.info('query={}'.format(query))
        logger.info('code_limit={}'.format(code_limit))
        logger.info('search_type={}'.format(search_type))
        logger.info('history_node_list={}'.format(history_node_list))
    # NOTE: both ``locals()`` calls forward every parameter by name to the config
    # constructors, so no other local variable may be bound before them.
    embed_config: EmbedConfig = EmbedConfig(**locals()) if embed_config is None else embed_config
    llm_config: LLMConfig = LLMConfig(**locals())
    try:
        # load codebase
        cbh = CodeBaseHandler(codebase_name=cb_name, embed_config=embed_config, llm_config=llm_config,
                              use_nh=use_nh, local_graph_path=local_graph_path)
        # search code
        context, related_vertices = cbh.search_code(query, search_type=search_type, limit=code_limit)
        return {
            'context': context,
            'related_vertices': related_vertices
        }
    except Exception as e:
        logger.exception(e)
        return {}
def search_related_vertices(cb_name: str = Body(..., examples=["sofaboot"]),
                            vertex: str = Body(..., examples=['***'])) -> dict:
    """Return the ids of vertices adjacent to ``vertex`` in the Nebula space ``cb_name``.

    For a vertex id ending in ``.java``, only adjacent ``package`` vertices are returned.

    Returns:
        dict: ``{'vertices': [id, ...]}``, or ``{}`` if anything fails (the exception is logged).
    """
    logger.info('cb_name={}'.format(cb_name))
    logger.info('vertex={}'.format(vertex))
    try:
        nh = NebulaHandler(host=NEBULA_HOST, port=NEBULA_PORT, username=NEBULA_USER,
                           password=NEBULA_PASSWORD, space_name=cb_name)
        # ``vertex`` comes from the request body: escape backslashes and single quotes so
        # it cannot close the string literal and inject additional nGQL.
        quoted_vertex = vertex.replace('\\', '\\\\').replace("'", "\\'")
        neighbour_label = ':package' if vertex.endswith(".java") else ''
        cypher = f'''MATCH (v1)--(v2{neighbour_label}) WHERE id(v1) == '{quoted_vertex}' RETURN id(v2) as id;'''
        cypher_res = nh.execute_cypher(cypher=cypher, format_res=True)
        related_vertices = [i.as_string() for i in cypher_res.get('id', [])]
        return {
            'vertices': related_vertices
        }
    except Exception as e:
        logger.exception(e)
        return {}
def search_code_by_vertex(cb_name: str = Body(..., examples=["sofaboot"]),
                          vertex: str = Body(..., examples=['***'])) -> dict:
    """Return the code text of the ``package`` vertex that contains ``vertex``.

    Queries the Nebula space ``cb_name`` for ``package`` vertices with a ``contain`` edge
    to ``vertex``. It then reads the ``code_text`` metadata of those ids from the Chroma
    collection ``cb_name``.

    Returns:
        dict: ``{'code': code_text}``. ``code_text`` is ``''`` when no containing package or
        stored entry is found, or when any error occurs (the exception is logged).
    """
    try:
        nh = NebulaHandler(host=NEBULA_HOST, port=NEBULA_PORT, username=NEBULA_USER,
                           password=NEBULA_PASSWORD, space_name=cb_name)
        # ``vertex`` comes from the request body: escape backslashes and single quotes so
        # it cannot close the string literal and inject additional nGQL.
        quoted_vertex = vertex.replace('\\', '\\\\').replace("'", "\\'")
        cypher = f'''MATCH (v1:package)-[e:contain]->(v2) WHERE id(v2) == '{quoted_vertex}' RETURN id(v1) as id;'''
        cypher_res = nh.execute_cypher(cypher=cypher, format_res=True)
        related_vertices = [i.as_string() for i in cypher_res.get('id', [])]
        if not related_vertices:
            return {'code': ''}

        ch = ChromaHandler(path=CHROMA_PERSISTENT_PATH, collection_name=cb_name)
        chroma_res = ch.get(ids=related_vertices, include=['metadatas'])
        # if several package ids are found, the first returned entry's code is used
        if chroma_res['result']['ids']:
            code_text = chroma_res['result']['metadatas'][0]['code_text']
        else:
            code_text = ''
        return {
            'code': code_text
        }
    except Exception as e:
        logger.exception(e)
        return {'code': ""}
def cb_exists_api(cb_name: str = Body(..., examples=["sofaboot"])) -> bool:
    """Report whether a code base called ``cb_name`` exists.

    Any exception raised by the lookup is logged, and the answer is then ``False``.
    """
    try:
        return cb_exists(cb_name)
    except Exception as err:
        logger.exception(err)
    return False