2023-12-07 20:17:21 +08:00
|
|
|
|
# encoding: utf-8
|
|
|
|
|
'''
|
|
|
|
|
@author: 温进
|
|
|
|
|
@file: codebase_handler.py
|
|
|
|
|
@time: 2023/11/21 下午2:25
|
|
|
|
|
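@desc: Build, search and delete a codebase's code knowledge base, stored as a NebulaGraph space and a Chroma collection named after the codebase.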
|
|
|
|
|
'''
|
|
|
|
|
import time
|
|
|
|
|
from loguru import logger
|
|
|
|
|
|
2024-01-26 14:03:25 +08:00
|
|
|
|
# from configs.server_config import NEBULA_HOST, NEBULA_PORT, NEBULA_USER, NEBULA_PASSWORD, NEBULA_STORAGED_PORT
|
|
|
|
|
# from configs.server_config import CHROMA_PERSISTENT_PATH
|
|
|
|
|
# from configs.model_config import EMBEDDING_ENGINE
|
2023-12-07 20:17:21 +08:00
|
|
|
|
|
2024-01-26 14:03:25 +08:00
|
|
|
|
from coagent.base_configs.env_config import (
|
|
|
|
|
NEBULA_HOST, NEBULA_PORT, NEBULA_USER, NEBULA_PASSWORD, NEBULA_STORAGED_PORT,
|
|
|
|
|
CHROMA_PERSISTENT_PATH
|
|
|
|
|
)
|
2023-12-07 20:17:21 +08:00
|
|
|
|
|
2024-01-26 14:03:25 +08:00
|
|
|
|
|
|
|
|
|
from coagent.db_handler.graph_db_handler.nebula_handler import NebulaHandler
|
|
|
|
|
from coagent.db_handler.vector_db_handler.chroma_handler import ChromaHandler
|
|
|
|
|
from coagent.codechat.code_crawler.zip_crawler import *
|
|
|
|
|
from coagent.codechat.code_analyzer.code_analyzer import CodeAnalyzer
|
|
|
|
|
from coagent.codechat.codebase_handler.code_importer import CodeImporter
|
|
|
|
|
from coagent.codechat.code_search.code_search import CodeSearch
|
|
|
|
|
from coagent.llm_models.llm_config import EmbedConfig, LLMConfig
|
2023-12-07 20:17:21 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class CodeBaseHandler:
    '''
    Build, search and delete the code knowledge base of one codebase.

    The codebase is stored in two backends that share the codebase name:
    a NebulaGraph space (vertices/edges) and a Chroma collection.
    '''

    # Search modes accepted by search_code / format_search_res.
    _SEARCH_TYPES = ('cypher', 'tag', 'description')

    def __init__(
            self,
            codebase_name: str,
            code_path: str = '',
            language: str = 'java',
            crawl_type: str = 'ZIP',
            embed_config: "EmbedConfig | None" = None,
            llm_config: "LLMConfig | None" = None,
    ):
        '''
        Connect to the graph and vector stores for `codebase_name`.

        @param codebase_name: used as both the Nebula space name and the Chroma collection name
        @param code_path: for crawl_type 'dir', the directory to crawl; for 'zip', the
            output path passed to ZipCrawler.crawl
        @param language: source language; crawl_code currently only supports 'java'
        @param crawl_type: 'zip' or 'dir' (compared case-insensitively)
        @param embed_config: embedding settings; a fresh EmbedConfig() when omitted
        @param llm_config: LLM settings; a fresh LLMConfig() when omitted
        '''
        self.codebase_name = codebase_name
        self.code_path = code_path
        self.language = language
        self.crawl_type = crawl_type
        # Build the defaults per instance: EmbedConfig()/LLMConfig() written as default
        # arguments would be evaluated once and shared by every handler.
        self.embed_config = embed_config if embed_config is not None else EmbedConfig()
        self.llm_config = llm_config if llm_config is not None else LLMConfig()

        self.nh = NebulaHandler(host=NEBULA_HOST, port=NEBULA_PORT, username=NEBULA_USER,
                                password=NEBULA_PASSWORD, space_name=codebase_name)
        self.nh.add_host(NEBULA_HOST, NEBULA_STORAGED_PORT)
        # NOTE(review): presumably gives the newly added storage host time to register
        # before the space is used — confirm.
        time.sleep(1)

        self.ch = ChromaHandler(path=CHROMA_PERSISTENT_PATH, collection_name=codebase_name)

    def import_code(self, zip_file='', do_interpret=True):
        '''
        Crawl, analyze and save the codebase into the code graph (Nebula) and code db (Chroma).

        @param zip_file: archive to crawl when crawl_type is 'zip' (unused for 'dir')
        @param do_interpret: forwarded to CodeAnalyzer.analyze and CodeImporter.import_code
        @return: (vertices_num, edges_num, file_num) — vertex/edge counts from the Nebula
            stats and the Chroma collection count
        '''
        # Init the graph first so its tags and edge types exist before importing.
        code_importer = CodeImporter(embed_config=self.embed_config, codebase_name=self.codebase_name,
                                     nh=self.nh, ch=self.ch)
        code_importer.init_graph()
        # NOTE(review): presumably waits for the new schema to take effect — confirm.
        time.sleep(5)

        # crawl code
        st0 = time.time()
        logger.info('start crawl')
        code_dict = self.crawl_code(zip_file)
        logger.debug('crawl done, rt={}'.format(time.time() - st0))

        # analyze code
        logger.info('start analyze')
        st1 = time.time()
        code_analyzer = CodeAnalyzer(language=self.language, llm_config=self.llm_config)
        static_analysis_res, interpretation = code_analyzer.analyze(code_dict, do_interpret=do_interpret)
        logger.debug('analyze done, rt={}'.format(time.time() - st1))

        # add info to nebula and chroma
        st2 = time.time()
        code_importer.import_code(static_analysis_res, interpretation, do_interpret=do_interpret)
        logger.debug('update codebase done, rt={}'.format(time.time() - st2))

        # get KG info
        stat = self.nh.get_stat()
        vertices_num, edges_num = stat['vertices'], stat['edges']

        # get chroma info
        file_num = self.ch.count()['result']

        return vertices_num, edges_num, file_num

    def delete_codebase(self, codebase_name: str):
        '''
        Delete a codebase from both stores.

        Note that this drops the space/collection named by the argument, which need not
        be this handler's own codebase_name.

        @param codebase_name: name of the Nebula space and Chroma collection to drop
        @return: None
        '''
        self.nh.drop_space(space_name=codebase_name)
        self.ch.delete_collection(collection_name=codebase_name)

    def crawl_code(self, zip_file=''):
        '''
        Collect the source files of the codebase using the configured crawler.

        @param zip_file: archive to extract and crawl when crawl_type is 'zip'
        @return: the dict returned by ZipCrawler.crawl / DirCrawler.crawl; an empty dict
            when crawl_type is neither 'zip' nor 'dir'
        @raise ValueError: if self.language is not supported
        '''
        # Supported languages and the file suffix the crawlers filter on.
        language_suffixes = {'java': 'java'}
        suffix = language_suffixes.get(self.language)
        if suffix is None:
            # Fail clearly here instead of later with an unbound `suffix`.
            raise ValueError(
                f'unsupported language: {self.language!r}, expected one of {sorted(language_suffixes)}')

        logger.info(f'crawl_type={self.crawl_type}')

        code_dict = {}
        crawl_type = self.crawl_type.lower()
        if crawl_type == 'zip':
            code_dict = ZipCrawler.crawl(zip_file, output_path=self.code_path, suffix=suffix)
        elif crawl_type == 'dir':
            code_dict = DirCrawler.crawl(self.code_path, suffix)
        else:
            # Kept best-effort (empty result), but no longer silent.
            logger.warning(f'unknown crawl_type={self.crawl_type}, nothing crawled')

        return code_dict

    def search_code(self, query: str, search_type: str, limit: int = 3):
        '''
        Search code from the codebase.

        @param query: query from the user
        @param search_type: one of 'cypher', 'tag', 'description'
        @param limit: passed to CodeSearch as its result limit
        @return: (context, related_vertice) as produced by format_search_res
        @raise AssertionError: if search_type is not supported
        '''
        if search_type not in self._SEARCH_TYPES:
            # Raised explicitly (not with `assert`) so the check survives `python -O`;
            # the AssertionError type is kept for existing callers.
            raise AssertionError(
                f'unsupported search_type: {search_type!r}, expected one of {self._SEARCH_TYPES}')

        code_search = CodeSearch(llm_config=self.llm_config, nh=self.nh, ch=self.ch, limit=limit)

        if search_type == 'cypher':
            search_res = code_search.search_by_cypher(query=query)
        elif search_type == 'tag':
            search_res = code_search.search_by_tag(query=query)
        else:  # 'description'
            search_res = code_search.search_by_desciption(
                query=query, engine=self.embed_config.embed_engine,
                model_path=self.embed_config.embed_model_path,
                embedding_device=self.embed_config.model_device)

        context, related_vertice = self.format_search_res(search_res, search_type)
        return context, related_vertice

    def format_search_res(self, search_res: "dict | list", search_type: str):
        '''
        Turn raw search results into a context string and a list of related vertices.

        @param search_res: for 'cypher', a mapping with 'cypher' and 'cypher_res' keys;
            for 'tag' / 'description', an iterable of mappings with 'code_text' and 'vertex' keys
        @param search_type: one of 'cypher', 'tag', 'description'
        @return: (context, related_vertice); related_vertice is empty for 'cypher'
        @raise ValueError: if search_type is not supported
        '''
        # Prompt text (Chinese): "The executed Cypher is: ..." / "The Cypher result is: ..."
        CYPHER_QA_PROMPT = '''
执行的 Cypher 是: {cypher}
Cypher 的结果是: {result}
'''

        if search_type == 'cypher':
            context = CYPHER_QA_PROMPT.format(cypher=search_res['cypher'], result=search_res['cypher_res'])
            return context, []

        if search_type in ('tag', 'description'):
            # Both result kinds share one shape: a code snippet plus its graph vertex per hit.
            # Single pass, so one-shot iterables (e.g. generators) are handled too.
            texts = []
            related_vertice = []
            for code in search_res:
                texts.append(code['code_text'] + '\n')
                related_vertice.append(code['vertex'])
            return ''.join(texts), related_vertice

        raise ValueError(f'unsupported search_type: {search_type!r}, expected one of {self._SEARCH_TYPES}')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    codebase_name = 'testing'
    # NOTE(review): developer-local path; adjust before running.
    code_path = '/Users/bingxu/Desktop/工作/大模型/chatbot/test_code_repo/client'
    cbh = CodeBaseHandler(codebase_name, code_path, crawl_type='dir')

    # query = '使用不同的HTTP请求类型(GET、POST、DELETE等)来执行不同的操作'
    # query = '代码中一共有多少个类'

    query = 'intercept 函数作用是什么'
    # search_code only accepts 'cypher', 'tag' or 'description'; the previous value
    # 'graph' was always rejected. 'tag' is assumed to be the intended mode.
    search_type = 'tag'
    limit = 2
    res = cbh.search_code(query, search_type, limit)
    logger.debug(res)
|