codefuse-chatbot/coagent/codechat/codebase_handler/codebase_handler.py

276 lines
10 KiB
Python
Raw Normal View History

# encoding: utf-8
'''
@author: 温进
@file: codebase_handler.py
@time: 2023/11/21 下午2:25
@desc:
'''
import os
import time
import json
from typing import List
from loguru import logger
from coagent.base_configs.env_config import (
NEBULA_HOST, NEBULA_PORT, NEBULA_USER, NEBULA_PASSWORD, NEBULA_STORAGED_PORT,
CHROMA_PERSISTENT_PATH, CB_ROOT_PATH
)
from coagent.db_handler.graph_db_handler.nebula_handler import NebulaHandler
from coagent.db_handler.vector_db_handler.chroma_handler import ChromaHandler
from coagent.codechat.code_crawler.zip_crawler import *
from coagent.codechat.code_analyzer.code_analyzer import CodeAnalyzer
from coagent.codechat.codebase_handler.code_importer import CodeImporter
from coagent.codechat.code_search.code_search import CodeSearch
from coagent.llm_models.llm_config import EmbedConfig, LLMConfig
class CodeBaseHandler:
def __init__(
self,
codebase_name: str,
code_path: str = '',
language: str = 'java',
crawl_type: str = 'ZIP',
embed_config: EmbedConfig = EmbedConfig(),
llm_config: LLMConfig = LLMConfig(),
use_nh: bool = True,
local_graph_path: str = CB_ROOT_PATH
):
self.codebase_name = codebase_name
self.code_path = code_path
self.language = language
self.crawl_type = crawl_type
self.embed_config = embed_config
self.llm_config = llm_config
self.local_graph_file_path = local_graph_path + os.sep + f'{self.codebase_name}_graph.json'
if use_nh:
try:
self.nh = NebulaHandler(host=NEBULA_HOST, port=NEBULA_PORT, username=NEBULA_USER,
password=NEBULA_PASSWORD, space_name=codebase_name)
self.nh.add_host(NEBULA_HOST, NEBULA_STORAGED_PORT)
time.sleep(1)
except:
self.nh = None
try:
with open(self.local_graph_file_path, 'r') as f:
self.graph = json.load(f)
except:
pass
elif local_graph_path:
self.nh = None
try:
with open(self.local_graph_file_path, 'r') as f:
self.graph = json.load(f)
except:
pass
self.ch = ChromaHandler(path=CHROMA_PERSISTENT_PATH, collection_name=codebase_name)
def import_code(self, zip_file='', do_interpret=True):
'''
analyze code and save it to codekg and codedb
@return:
'''
# init graph to init tag and edge
code_importer = CodeImporter(embed_config=self.embed_config, codebase_name=self.codebase_name,
nh=self.nh, ch=self.ch, local_graph_file_path=self.local_graph_file_path)
if self.nh:
code_importer.init_graph()
time.sleep(5)
# crawl code
st0 = time.time()
logger.info('start crawl')
code_dict = self.crawl_code(zip_file)
logger.debug('crawl done, rt={}'.format(time.time() - st0))
# analyze code
logger.info('start analyze')
st1 = time.time()
code_analyzer = CodeAnalyzer(language=self.language, llm_config=self.llm_config)
static_analysis_res, interpretation = code_analyzer.analyze(code_dict, do_interpret=do_interpret)
logger.debug('analyze done, rt={}'.format(time.time() - st1))
# add info to nebula and chroma
st2 = time.time()
code_importer.import_code(static_analysis_res, interpretation, do_interpret=do_interpret)
logger.debug('update codebase done, rt={}'.format(time.time() - st2))
# get KG info
if self.nh:
time.sleep(10) # aviod nebula staus didn't complete
stat = self.nh.get_stat()
vertices_num, edges_num = stat['vertices'], stat['edges']
else:
vertices_num = 0
edges_num = 0
# get chroma info
file_num = self.ch.count()['result']
return vertices_num, edges_num, file_num
def delete_codebase(self, codebase_name: str):
'''
delete codebase
@param codebase_name: name of codebase
@return:
'''
if self.nh:
self.nh.drop_space(space_name=codebase_name)
elif self.local_graph_file_path and os.path.isfile(self.local_graph_file_path):
os.remove(self.local_graph_file_path)
self.ch.delete_collection(collection_name=codebase_name)
def crawl_code(self, zip_file=''):
'''
@return:
'''
if self.language == 'java':
suffix = 'java'
logger.info(f'crawl_type={self.crawl_type}')
code_dict = {}
if self.crawl_type.lower() == 'zip':
code_dict = ZipCrawler.crawl(zip_file, output_path=self.code_path, suffix=suffix)
elif self.crawl_type.lower() == 'dir':
code_dict = DirCrawler.crawl(self.code_path, suffix)
return code_dict
def search_code(self, query: str, search_type: str, limit: int = 3):
'''
search code from codebase
@param limit:
@param engine:
@param query: query from user
@param search_type: ['cypher', 'graph', 'vector']
@return:
'''
if self.nh:
assert search_type in ['cypher', 'tag', 'description']
else:
if search_type == 'tag':
search_type = 'tag_by_local_graph'
assert search_type in ['tag_by_local_graph', 'description']
code_search = CodeSearch(llm_config=self.llm_config, nh=self.nh, ch=self.ch, limit=limit,
local_graph_file_path=self.local_graph_file_path)
if search_type == 'cypher':
search_res = code_search.search_by_cypher(query=query)
elif search_type == 'tag':
search_res = code_search.search_by_tag(query=query)
elif search_type == 'description':
search_res = code_search.search_by_desciption(
query=query, engine=self.embed_config.embed_engine, model_path=self.embed_config.embed_model_path,
embedding_device=self.embed_config.model_device, embed_config=self.embed_config)
elif search_type == 'tag_by_local_graph':
search_res = code_search.search_by_tag_by_graph(query=query)
context, related_vertice = self.format_search_res(search_res, search_type)
return context, related_vertice
def format_search_res(self, search_res: str, search_type: str):
'''
format search_res
@param search_res:
@param search_type:
@return:
'''
CYPHER_QA_PROMPT = '''
执行的 Cypher : {cypher}
Cypher 的结果是: {result}
'''
if search_type == 'cypher':
context = CYPHER_QA_PROMPT.format(cypher=search_res['cypher'], result=search_res['cypher_res'])
related_vertice = []
elif search_type == 'tag':
context = ''
related_vertice = []
for code in search_res:
context = context + code['code_text'] + '\n'
related_vertice.append(code['vertex'])
elif search_type == 'tag_by_local_graph':
context = ''
related_vertice = []
for code in search_res:
context = context + code['code_text'] + '\n'
related_vertice.append(code['vertex'])
elif search_type == 'description':
context = ''
related_vertice = []
for code in search_res:
context = context + code['code_text'] + '\n'
related_vertice.append(code['vertex'])
return context, related_vertice
def search_vertices(self, vertex_type="class") -> List[str]:
'''
通过 method/class 来搜索所有的节点
'''
vertices = []
if self.nh:
vertices = self.nh.get_all_vertices()
vertices = [str(v.as_node().get_id()) for v in vertices["v"] if vertex_type in v.as_node().tags()]
# for v in vertices["v"]:
# logger.debug(f"{v.as_node().get_id()}, {v.as_node().tags()}")
else:
if vertex_type == "class":
vertices = [str(class_name) for code, structure in self.graph.items() for class_name in structure['class_name_list']]
elif vertex_type == "method":
vertices = [
str(methods_name)
for code, structure in self.graph.items()
for methods_names in structure['func_name_dict'].values()
for methods_name in methods_names
]
# logger.debug(vertices)
return vertices
if __name__ == '__main__':
from configs.model_config import KB_ROOT_PATH, JUPYTER_WORK_PATH
from configs.server_config import SANDBOX_SERVER
LLM_MODEL = "gpt-3.5-turbo"
llm_config = LLMConfig(
model_name=LLM_MODEL, model_device="cpu", api_key=os.environ["OPENAI_API_KEY"],
api_base_url=os.environ["API_BASE_URL"], temperature=0.3
)
src_dir = '/Users/bingxu/Desktop/工作/大模型/chatbot/Codefuse-chatbot-antcode'
embed_config = EmbedConfig(
embed_engine="model", embed_model="text2vec-base-chinese",
embed_model_path=os.path.join(src_dir, "embedding_models/text2vec-base-chinese")
)
codebase_name = 'client_local'
code_path = '/Users/bingxu/Desktop/工作/大模型/chatbot/test_code_repo/client'
use_nh = False
local_graph_path = '/Users/bingxu/Desktop/工作/大模型/chatbot/Codefuse-chatbot-antcode/code_base'
CHROMA_PERSISTENT_PATH = '/Users/bingxu/Desktop/工作/大模型/chatbot/Codefuse-chatbot-antcode/data/chroma_data'
cbh = CodeBaseHandler(codebase_name, code_path, crawl_type='dir', use_nh=use_nh, local_graph_path=local_graph_path,
llm_config=llm_config, embed_config=embed_config)
# test import code
# cbh.import_code(do_interpret=True)
# query = '使用不同的HTTP请求类型GET、POST、DELETE等来执行不同的操作'
# query = '代码中一共有多少个类'
# query = 'remove 这个函数是用来做什么的'
query = '有没有函数是从字符串中删除指定字符串的功能'
search_type = 'description'
limit = 2
res = cbh.search_code(query, search_type, limit)
logger.debug(res)