codefuse-chatbot/dev_opsgpt/codechat/codebase_handler/codebase_handler.py

170 lines
6.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# encoding: utf-8
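'''
@author: 温进
@file: codebase_handler.py
@time: 2023/11/21 14:25
@desc: CodeBaseHandler - imports, searches and deletes a code base that is stored in a Nebula graph space and a Chroma collection
'''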
import time
from loguru import logger
from configs.server_config import NEBULA_HOST, NEBULA_PORT, NEBULA_USER, NEBULA_PASSWORD, NEBULA_STORAGED_PORT
from configs.server_config import CHROMA_PERSISTENT_PATH
from configs.model_config import EMBEDDING_ENGINE
from dev_opsgpt.db_handler.graph_db_handler.nebula_handler import NebulaHandler
from dev_opsgpt.db_handler.vector_db_handler.chroma_handler import ChromaHandler
from dev_opsgpt.codechat.code_crawler.zip_crawler import *
from dev_opsgpt.codechat.code_analyzer.code_analyzer import CodeAnalyzer
from dev_opsgpt.codechat.codebase_handler.code_importer import CodeImporter
from dev_opsgpt.codechat.code_search.code_search import CodeSearch
class CodeBaseHandler:
    '''
    Entry point for working with one code base.

    An instance owns a NebulaHandler (graph store, one space per code base) and a
    ChromaHandler (vector store, one collection per code base) and offers
    import, search and delete operations on top of them.
    '''

    # search modes understood by search_code / format_search_res
    _SEARCH_TYPES = ('cypher', 'tag', 'description')

    # file suffix handed to the crawlers for every supported language
    _LANGUAGE_SUFFIXES = {'java': 'java'}

    # Runtime prompt text, not a comment: it must stay byte-identical.
    # It is spelled with explicit '\n' so code indentation cannot leak into it.
    _CYPHER_QA_PROMPT = (
        '\n'
        '执行的 Cypher 是: {cypher}\n'
        'Cypher 的结果是: {result}\n'
    )

    def __init__(self, codebase_name: str, code_path: str = '',
                 language: str = 'java', crawl_type: str = 'ZIP'):
        '''
        connect to the graph store and the vector store of a code base
        @param codebase_name: used as Nebula space name and as Chroma collection name
        @param code_path: directory that is crawled ('dir') or used as output path ('zip')
        @param language: language of the code base, only 'java' has a known file suffix
        @param crawl_type: 'zip' or 'dir', compared case-insensitively
        '''
        self.codebase_name = codebase_name
        self.code_path = code_path
        self.language = language
        self.crawl_type = crawl_type

        self.nh = NebulaHandler(host=NEBULA_HOST, port=NEBULA_PORT, username=NEBULA_USER,
                                password=NEBULA_PASSWORD, space_name=codebase_name)
        self.nh.add_host(NEBULA_HOST, NEBULA_STORAGED_PORT)
        # NOTE(review): presumably gives Nebula time to register the storage host - confirm
        time.sleep(1)

        self.ch = ChromaHandler(path=CHROMA_PERSISTENT_PATH, collection_name=codebase_name)

    def import_code(self, zip_file='', do_interpret=True):
        '''
        analyze code and save it to codekg and codedb
        @param zip_file: forwarded to crawl_code, only used when crawl_type is 'zip'
        @param do_interpret: forwarded to CodeAnalyzer.analyze and CodeImporter.import_code
        @return: (vertices_num, edges_num, file_num) read back from Nebula and Chroma
        '''
        # init graph to init tag and edge
        code_importer = CodeImporter(engine=EMBEDDING_ENGINE, codebase_name=self.codebase_name,
                                     nh=self.nh, ch=self.ch)
        code_importer.init_graph()
        # NOTE(review): presumably waits until the freshly created schema is usable - confirm
        time.sleep(5)

        # crawl code
        crawl_start = time.time()
        logger.info('start crawl')
        code_dict = self.crawl_code(zip_file)
        logger.debug('crawl done, rt={}'.format(time.time() - crawl_start))

        # analyze code
        logger.info('start analyze')
        analyze_start = time.time()
        code_analyzer = CodeAnalyzer(language=self.language)
        static_analysis_res, interpretation = code_analyzer.analyze(code_dict, do_interpret=do_interpret)
        logger.debug('analyze done, rt={}'.format(time.time() - analyze_start))

        # add info to nebula and chroma
        import_start = time.time()
        code_importer.import_code(static_analysis_res, interpretation, do_interpret=do_interpret)
        logger.debug('update codebase done, rt={}'.format(time.time() - import_start))

        # get KG info
        stat = self.nh.get_stat()
        vertices_num, edges_num = stat['vertices'], stat['edges']

        # get chroma info
        file_num = self.ch.count()['result']
        return vertices_num, edges_num, file_num

    def delete_codebase(self, codebase_name: str):
        '''
        delete codebase
        @param codebase_name: name of codebase; this argument, not self.codebase_name,
                              selects the space and the collection that are dropped
        @return:
        '''
        self.nh.drop_space(space_name=codebase_name)
        self.ch.delete_collection(collection_name=codebase_name)

    def crawl_code(self, zip_file=''):
        '''
        collect the source files of the code base
        @param zip_file: archive handed to ZipCrawler when crawl_type is 'zip'
        @return: result of the crawler, or an empty dict for an unknown crawl_type
        @raise ValueError: if self.language has no known file suffix
        '''
        # Resolve the suffix first: previously any language other than 'java'
        # left `suffix` unbound and failed later with an UnboundLocalError.
        suffix = self._LANGUAGE_SUFFIXES.get(self.language)
        if suffix is None:
            raise ValueError(
                'unsupported language {!r}, expected one of {}'.format(
                    self.language, sorted(self._LANGUAGE_SUFFIXES)))

        logger.info(f'crawl_type={self.crawl_type}')

        crawl_type = self.crawl_type.lower()
        if crawl_type == 'zip':
            return ZipCrawler.crawl(zip_file, output_path=self.code_path, suffix=suffix)
        if crawl_type == 'dir':
            return DirCrawler.crawl(self.code_path, suffix)

        # unknown crawl types keep yielding an empty result, but no longer silently
        logger.warning(f'unknown crawl_type={self.crawl_type}, nothing crawled')
        return {}

    def search_code(self, query: str, search_type: str, limit: int = 3):
        '''
        search code from codebase
        @param query: query from user
        @param search_type: one of ['cypher', 'tag', 'description']
        @param limit: forwarded to CodeSearch
        @return: (context, related_vertice) as built by format_search_res
        @raise ValueError: if search_type is not one of the supported values
        '''
        # An explicit check instead of `assert`: asserts vanish under `python -O`,
        # which would let an unknown search_type run into an unbound `search_res`.
        if search_type not in self._SEARCH_TYPES:
            raise ValueError(
                'search_type must be one of {}, got {!r}'.format(
                    list(self._SEARCH_TYPES), search_type))

        code_search = CodeSearch(nh=self.nh, ch=self.ch, limit=limit)

        if search_type == 'cypher':
            search_res = code_search.search_by_cypher(query=query)
        elif search_type == 'tag':
            search_res = code_search.search_by_tag(query=query)
        else:
            # 'description'; the spelling `search_by_desciption` is the method name
            # as called by the original code and is kept on purpose
            search_res = code_search.search_by_desciption(query=query, engine=EMBEDDING_ENGINE)

        return self.format_search_res(search_res, search_type)

    def format_search_res(self, search_res, search_type: str):
        '''
        format search_res
        @param search_res: for 'cypher' a mapping with the keys 'cypher' and 'cypher_res';
                           for 'tag' / 'description' an iterable of mappings with the
                           keys 'code_text' and 'vertex'
        @param search_type: one of ['cypher', 'tag', 'description']
        @return: (context, related_vertice); context is the text block built from the
                 result, related_vertice the list of 'vertex' values (empty for 'cypher')
        @raise ValueError: if search_type is not one of the supported values
        '''
        if search_type == 'cypher':
            context = self._CYPHER_QA_PROMPT.format(cypher=search_res['cypher'],
                                                    result=search_res['cypher_res'])
            return context, []

        if search_type in ('tag', 'description'):
            # both modes share one result shape; a single pass keeps generators usable
            code_texts = []
            related_vertice = []
            for code in search_res:
                code_texts.append(code['code_text'])
                related_vertice.append(code['vertex'])
            context = ''.join(code_text + '\n' for code_text in code_texts)
            return context, related_vertice

        # previously an unknown type fell through to an UnboundLocalError
        raise ValueError(
            'search_type must be one of {}, got {!r}'.format(
                list(self._SEARCH_TYPES), search_type))
if __name__ == '__main__':
    # Manual smoke test: needs reachable Nebula / Chroma services and the
    # developer-local code directory below.
    codebase_name = 'testing'
    code_path = '/Users/bingxu/Desktop/工作/大模型/chatbot/test_code_repo/client'
    cbh = CodeBaseHandler(codebase_name, code_path, crawl_type='dir')

    # alternative sample queries
    # query = '使用不同的HTTP请求类型GET、POST、DELETE等来执行不同的操作'  # "use different HTTP request types (GET, POST, DELETE, ...) to perform different operations"
    # query = '代码中一共有多少个类'  # "how many classes are there in the code"
    query = 'intercept 函数作用是什么'  # "what does the intercept function do"

    # search_code only accepts 'cypher', 'tag' or 'description'; the former value
    # 'graph' was always rejected.
    # NOTE(review): 'tag' is assumed to be the renamed 'graph' mode - confirm
    search_type = 'tag'
    limit = 2
    res = cbh.search_code(query, search_type, limit)
    logger.debug(res)