186 lines
6.6 KiB
Python
186 lines
6.6 KiB
Python
# encoding: utf-8
|
||
'''
|
||
@author: 温进
|
||
@file: codebase_handler.py
|
||
@time: 2023/11/21 下午2:25
|
||
@desc:
|
||
'''
|
||
import time
|
||
from loguru import logger
|
||
|
||
# from configs.server_config import NEBULA_HOST, NEBULA_PORT, NEBULA_USER, NEBULA_PASSWORD, NEBULA_STORAGED_PORT
|
||
# from configs.server_config import CHROMA_PERSISTENT_PATH
|
||
# from configs.model_config import EMBEDDING_ENGINE
|
||
|
||
from coagent.base_configs.env_config import (
|
||
NEBULA_HOST, NEBULA_PORT, NEBULA_USER, NEBULA_PASSWORD, NEBULA_STORAGED_PORT,
|
||
CHROMA_PERSISTENT_PATH
|
||
)
|
||
|
||
|
||
from coagent.db_handler.graph_db_handler.nebula_handler import NebulaHandler
|
||
from coagent.db_handler.vector_db_handler.chroma_handler import ChromaHandler
|
||
from coagent.codechat.code_crawler.zip_crawler import *
|
||
from coagent.codechat.code_analyzer.code_analyzer import CodeAnalyzer
|
||
from coagent.codechat.codebase_handler.code_importer import CodeImporter
|
||
from coagent.codechat.code_search.code_search import CodeSearch
|
||
from coagent.llm_models.llm_config import EmbedConfig, LLMConfig
|
||
|
||
|
||
class CodeBaseHandler:
|
||
def __init__(
|
||
self,
|
||
codebase_name: str,
|
||
code_path: str = '',
|
||
language: str = 'java',
|
||
crawl_type: str = 'ZIP',
|
||
embed_config: EmbedConfig = EmbedConfig(),
|
||
llm_config: LLMConfig = LLMConfig()
|
||
):
|
||
self.codebase_name = codebase_name
|
||
self.code_path = code_path
|
||
self.language = language
|
||
self.crawl_type = crawl_type
|
||
self.embed_config = embed_config
|
||
self.llm_config = llm_config
|
||
|
||
self.nh = NebulaHandler(host=NEBULA_HOST, port=NEBULA_PORT, username=NEBULA_USER,
|
||
password=NEBULA_PASSWORD, space_name=codebase_name)
|
||
self.nh.add_host(NEBULA_HOST, NEBULA_STORAGED_PORT)
|
||
time.sleep(1)
|
||
|
||
self.ch = ChromaHandler(path=CHROMA_PERSISTENT_PATH, collection_name=codebase_name)
|
||
|
||
def import_code(self, zip_file='', do_interpret=True):
|
||
'''
|
||
analyze code and save it to codekg and codedb
|
||
@return:
|
||
'''
|
||
# init graph to init tag and edge
|
||
code_importer = CodeImporter(embed_config=self.embed_config, codebase_name=self.codebase_name,
|
||
nh=self.nh, ch=self.ch)
|
||
code_importer.init_graph()
|
||
time.sleep(5)
|
||
|
||
# crawl code
|
||
st0 = time.time()
|
||
logger.info('start crawl')
|
||
code_dict = self.crawl_code(zip_file)
|
||
logger.debug('crawl done, rt={}'.format(time.time() - st0))
|
||
|
||
# analyze code
|
||
logger.info('start analyze')
|
||
st1 = time.time()
|
||
code_analyzer = CodeAnalyzer(language=self.language, llm_config = self.llm_config)
|
||
static_analysis_res, interpretation = code_analyzer.analyze(code_dict, do_interpret=do_interpret)
|
||
logger.debug('analyze done, rt={}'.format(time.time() - st1))
|
||
|
||
# add info to nebula and chroma
|
||
st2 = time.time()
|
||
code_importer.import_code(static_analysis_res, interpretation, do_interpret=do_interpret)
|
||
logger.debug('update codebase done, rt={}'.format(time.time() - st2))
|
||
|
||
# get KG info
|
||
stat = self.nh.get_stat()
|
||
vertices_num, edges_num = stat['vertices'], stat['edges']
|
||
|
||
# get chroma info
|
||
file_num = self.ch.count()['result']
|
||
|
||
return vertices_num, edges_num, file_num
|
||
|
||
def delete_codebase(self, codebase_name: str):
|
||
'''
|
||
delete codebase
|
||
@param codebase_name: name of codebase
|
||
@return:
|
||
'''
|
||
self.nh.drop_space(space_name=codebase_name)
|
||
self.ch.delete_collection(collection_name=codebase_name)
|
||
|
||
def crawl_code(self, zip_file=''):
|
||
'''
|
||
@return:
|
||
'''
|
||
if self.language == 'java':
|
||
suffix = 'java'
|
||
|
||
logger.info(f'crawl_type={self.crawl_type}')
|
||
|
||
code_dict = {}
|
||
if self.crawl_type.lower() == 'zip':
|
||
code_dict = ZipCrawler.crawl(zip_file, output_path=self.code_path, suffix=suffix)
|
||
elif self.crawl_type.lower() == 'dir':
|
||
code_dict = DirCrawler.crawl(self.code_path, suffix)
|
||
|
||
return code_dict
|
||
|
||
def search_code(self, query: str, search_type: str, limit: int = 3):
|
||
'''
|
||
search code from codebase
|
||
@param limit:
|
||
@param engine:
|
||
@param query: query from user
|
||
@param search_type: ['cypher', 'graph', 'vector']
|
||
@return:
|
||
'''
|
||
assert search_type in ['cypher', 'tag', 'description']
|
||
|
||
code_search = CodeSearch(llm_config=self.llm_config, nh=self.nh, ch=self.ch, limit=limit)
|
||
|
||
if search_type == 'cypher':
|
||
search_res = code_search.search_by_cypher(query=query)
|
||
elif search_type == 'tag':
|
||
search_res = code_search.search_by_tag(query=query)
|
||
elif search_type == 'description':
|
||
search_res = code_search.search_by_desciption(
|
||
query=query, engine=self.embed_config.embed_engine, model_path=self.embed_config.embed_model_path, embedding_device=self.embed_config.model_device)
|
||
|
||
context, related_vertice = self.format_search_res(search_res, search_type)
|
||
return context, related_vertice
|
||
|
||
def format_search_res(self, search_res: str, search_type: str):
|
||
'''
|
||
format search_res
|
||
@param search_res:
|
||
@param search_type:
|
||
@return:
|
||
'''
|
||
CYPHER_QA_PROMPT = '''
|
||
执行的 Cypher 是: {cypher}
|
||
Cypher 的结果是: {result}
|
||
'''
|
||
|
||
if search_type == 'cypher':
|
||
context = CYPHER_QA_PROMPT.format(cypher=search_res['cypher'], result=search_res['cypher_res'])
|
||
related_vertice = []
|
||
elif search_type == 'tag':
|
||
context = ''
|
||
related_vertice = []
|
||
for code in search_res:
|
||
context = context + code['code_text'] + '\n'
|
||
related_vertice.append(code['vertex'])
|
||
elif search_type == 'description':
|
||
context = ''
|
||
related_vertice = []
|
||
for code in search_res:
|
||
context = context + code['code_text'] + '\n'
|
||
related_vertice.append(code['vertex'])
|
||
|
||
return context, related_vertice
|
||
|
||
|
||
if __name__ == '__main__':
|
||
codebase_name = 'testing'
|
||
code_path = '/Users/bingxu/Desktop/工作/大模型/chatbot/test_code_repo/client'
|
||
cbh = CodeBaseHandler(codebase_name, code_path, crawl_type='dir')
|
||
|
||
# query = '使用不同的HTTP请求类型(GET、POST、DELETE等)来执行不同的操作'
|
||
# query = '代码中一共有多少个类'
|
||
|
||
query = 'intercept 函数作用是什么'
|
||
search_type = 'graph'
|
||
limit = 2
|
||
res = cbh.search_code(query, search_type, limit)
|
||
logger.debug(res)
|