275 lines
10 KiB
Python
275 lines
10 KiB
Python
# encoding: utf-8
|
||
'''
|
||
@author: 温进
|
||
@file: codebase_handler.py
|
||
@time: 2023/11/21 下午2:25
|
||
@desc:
|
||
'''
|
||
import os
|
||
import time
|
||
import json
|
||
from typing import List
|
||
from loguru import logger
|
||
|
||
from coagent.base_configs.env_config import (
|
||
NEBULA_HOST, NEBULA_PORT, NEBULA_USER, NEBULA_PASSWORD, NEBULA_STORAGED_PORT,
|
||
CHROMA_PERSISTENT_PATH, CB_ROOT_PATH
|
||
)
|
||
|
||
|
||
from coagent.db_handler.graph_db_handler.nebula_handler import NebulaHandler
|
||
from coagent.db_handler.vector_db_handler.chroma_handler import ChromaHandler
|
||
from coagent.codechat.code_crawler.zip_crawler import *
|
||
from coagent.codechat.code_analyzer.code_analyzer import CodeAnalyzer
|
||
from coagent.codechat.codebase_handler.code_importer import CodeImporter
|
||
from coagent.codechat.code_search.code_search import CodeSearch
|
||
from coagent.llm_models.llm_config import EmbedConfig, LLMConfig
|
||
|
||
|
||
class CodeBaseHandler:
|
||
def __init__(
|
||
self,
|
||
codebase_name: str,
|
||
code_path: str = '',
|
||
language: str = 'java',
|
||
crawl_type: str = 'ZIP',
|
||
embed_config: EmbedConfig = EmbedConfig(),
|
||
llm_config: LLMConfig = LLMConfig(),
|
||
use_nh: bool = True,
|
||
local_graph_path: str = CB_ROOT_PATH
|
||
):
|
||
self.codebase_name = codebase_name
|
||
self.code_path = code_path
|
||
self.language = language
|
||
self.crawl_type = crawl_type
|
||
self.embed_config = embed_config
|
||
self.llm_config = llm_config
|
||
self.local_graph_file_path = local_graph_path + os.sep + f'{self.codebase_name}_graph.json'
|
||
|
||
if use_nh:
|
||
try:
|
||
self.nh = NebulaHandler(host=NEBULA_HOST, port=NEBULA_PORT, username=NEBULA_USER,
|
||
password=NEBULA_PASSWORD, space_name=codebase_name)
|
||
self.nh.add_host(NEBULA_HOST, NEBULA_STORAGED_PORT)
|
||
time.sleep(1)
|
||
except:
|
||
self.nh = None
|
||
try:
|
||
with open(self.local_graph_file_path, 'r') as f:
|
||
self.graph = json.load(f)
|
||
except:
|
||
pass
|
||
elif local_graph_path:
|
||
self.nh = None
|
||
try:
|
||
with open(self.local_graph_file_path, 'r') as f:
|
||
self.graph = json.load(f)
|
||
except:
|
||
pass
|
||
|
||
self.ch = ChromaHandler(path=CHROMA_PERSISTENT_PATH, collection_name=codebase_name)
|
||
|
||
def import_code(self, zip_file='', do_interpret=True):
|
||
'''
|
||
analyze code and save it to codekg and codedb
|
||
@return:
|
||
'''
|
||
# init graph to init tag and edge
|
||
code_importer = CodeImporter(embed_config=self.embed_config, codebase_name=self.codebase_name,
|
||
nh=self.nh, ch=self.ch, local_graph_file_path=self.local_graph_file_path)
|
||
if self.nh:
|
||
code_importer.init_graph()
|
||
time.sleep(5)
|
||
|
||
# crawl code
|
||
st0 = time.time()
|
||
logger.info('start crawl')
|
||
code_dict = self.crawl_code(zip_file)
|
||
logger.debug('crawl done, rt={}'.format(time.time() - st0))
|
||
|
||
# analyze code
|
||
logger.info('start analyze')
|
||
st1 = time.time()
|
||
code_analyzer = CodeAnalyzer(language=self.language, llm_config=self.llm_config)
|
||
static_analysis_res, interpretation = code_analyzer.analyze(code_dict, do_interpret=do_interpret)
|
||
logger.debug('analyze done, rt={}'.format(time.time() - st1))
|
||
|
||
# add info to nebula and chroma
|
||
st2 = time.time()
|
||
code_importer.import_code(static_analysis_res, interpretation, do_interpret=do_interpret)
|
||
logger.debug('update codebase done, rt={}'.format(time.time() - st2))
|
||
|
||
# get KG info
|
||
if self.nh:
|
||
stat = self.nh.get_stat()
|
||
vertices_num, edges_num = stat['vertices'], stat['edges']
|
||
else:
|
||
vertices_num = 0
|
||
edges_num = 0
|
||
|
||
# get chroma info
|
||
file_num = self.ch.count()['result']
|
||
|
||
return vertices_num, edges_num, file_num
|
||
|
||
def delete_codebase(self, codebase_name: str):
|
||
'''
|
||
delete codebase
|
||
@param codebase_name: name of codebase
|
||
@return:
|
||
'''
|
||
if self.nh:
|
||
self.nh.drop_space(space_name=codebase_name)
|
||
elif self.local_graph_file_path and os.path.isfile(self.local_graph_file_path):
|
||
os.remove(self.local_graph_file_path)
|
||
|
||
self.ch.delete_collection(collection_name=codebase_name)
|
||
|
||
def crawl_code(self, zip_file=''):
|
||
'''
|
||
@return:
|
||
'''
|
||
if self.language == 'java':
|
||
suffix = 'java'
|
||
|
||
logger.info(f'crawl_type={self.crawl_type}')
|
||
|
||
code_dict = {}
|
||
if self.crawl_type.lower() == 'zip':
|
||
code_dict = ZipCrawler.crawl(zip_file, output_path=self.code_path, suffix=suffix)
|
||
elif self.crawl_type.lower() == 'dir':
|
||
code_dict = DirCrawler.crawl(self.code_path, suffix)
|
||
|
||
return code_dict
|
||
|
||
def search_code(self, query: str, search_type: str, limit: int = 3):
|
||
'''
|
||
search code from codebase
|
||
@param limit:
|
||
@param engine:
|
||
@param query: query from user
|
||
@param search_type: ['cypher', 'graph', 'vector']
|
||
@return:
|
||
'''
|
||
if self.nh:
|
||
assert search_type in ['cypher', 'tag', 'description']
|
||
else:
|
||
if search_type == 'tag':
|
||
search_type = 'tag_by_local_graph'
|
||
assert search_type in ['tag_by_local_graph', 'description']
|
||
|
||
code_search = CodeSearch(llm_config=self.llm_config, nh=self.nh, ch=self.ch, limit=limit,
|
||
local_graph_file_path=self.local_graph_file_path)
|
||
|
||
if search_type == 'cypher':
|
||
search_res = code_search.search_by_cypher(query=query)
|
||
elif search_type == 'tag':
|
||
search_res = code_search.search_by_tag(query=query)
|
||
elif search_type == 'description':
|
||
search_res = code_search.search_by_desciption(
|
||
query=query, engine=self.embed_config.embed_engine, model_path=self.embed_config.embed_model_path,
|
||
embedding_device=self.embed_config.model_device, embed_config=self.embed_config)
|
||
elif search_type == 'tag_by_local_graph':
|
||
search_res = code_search.search_by_tag_by_graph(query=query)
|
||
|
||
|
||
context, related_vertice = self.format_search_res(search_res, search_type)
|
||
return context, related_vertice
|
||
|
||
def format_search_res(self, search_res: str, search_type: str):
|
||
'''
|
||
format search_res
|
||
@param search_res:
|
||
@param search_type:
|
||
@return:
|
||
'''
|
||
CYPHER_QA_PROMPT = '''
|
||
执行的 Cypher 是: {cypher}
|
||
Cypher 的结果是: {result}
|
||
'''
|
||
|
||
if search_type == 'cypher':
|
||
context = CYPHER_QA_PROMPT.format(cypher=search_res['cypher'], result=search_res['cypher_res'])
|
||
related_vertice = []
|
||
elif search_type == 'tag':
|
||
context = ''
|
||
related_vertice = []
|
||
for code in search_res:
|
||
context = context + code['code_text'] + '\n'
|
||
related_vertice.append(code['vertex'])
|
||
elif search_type == 'tag_by_local_graph':
|
||
context = ''
|
||
related_vertice = []
|
||
for code in search_res:
|
||
context = context + code['code_text'] + '\n'
|
||
related_vertice.append(code['vertex'])
|
||
elif search_type == 'description':
|
||
context = ''
|
||
related_vertice = []
|
||
for code in search_res:
|
||
context = context + code['code_text'] + '\n'
|
||
related_vertice.append(code['vertex'])
|
||
|
||
return context, related_vertice
|
||
|
||
def search_vertices(self, vertex_type="class") -> List[str]:
|
||
'''
|
||
通过 method/class 来搜索所有的节点
|
||
'''
|
||
vertices = []
|
||
if self.nh:
|
||
vertices = self.nh.get_all_vertices()
|
||
vertices = [str(v.as_node().get_id()) for v in vertices["v"] if vertex_type in v.as_node().tags()]
|
||
# for v in vertices["v"]:
|
||
# logger.debug(f"{v.as_node().get_id()}, {v.as_node().tags()}")
|
||
else:
|
||
if vertex_type == "class":
|
||
vertices = [str(class_name) for code, structure in self.graph.items() for class_name in structure['class_name_list']]
|
||
elif vertex_type == "method":
|
||
vertices = [
|
||
str(methods_name)
|
||
for code, structure in self.graph.items()
|
||
for methods_names in structure['func_name_dict'].values()
|
||
for methods_name in methods_names
|
||
]
|
||
# logger.debug(vertices)
|
||
return vertices
|
||
|
||
|
||
if __name__ == '__main__':
|
||
from configs.model_config import KB_ROOT_PATH, JUPYTER_WORK_PATH
|
||
from configs.server_config import SANDBOX_SERVER
|
||
|
||
LLM_MODEL = "gpt-3.5-turbo"
|
||
llm_config = LLMConfig(
|
||
model_name=LLM_MODEL, model_device="cpu", api_key=os.environ["OPENAI_API_KEY"],
|
||
api_base_url=os.environ["API_BASE_URL"], temperature=0.3
|
||
)
|
||
src_dir = '/Users/bingxu/Desktop/工作/大模型/chatbot/Codefuse-chatbot-antcode'
|
||
embed_config = EmbedConfig(
|
||
embed_engine="model", embed_model="text2vec-base-chinese",
|
||
embed_model_path=os.path.join(src_dir, "embedding_models/text2vec-base-chinese")
|
||
)
|
||
|
||
codebase_name = 'client_local'
|
||
code_path = '/Users/bingxu/Desktop/工作/大模型/chatbot/test_code_repo/client'
|
||
use_nh = False
|
||
local_graph_path = '/Users/bingxu/Desktop/工作/大模型/chatbot/Codefuse-chatbot-antcode/code_base'
|
||
CHROMA_PERSISTENT_PATH = '/Users/bingxu/Desktop/工作/大模型/chatbot/Codefuse-chatbot-antcode/data/chroma_data'
|
||
|
||
cbh = CodeBaseHandler(codebase_name, code_path, crawl_type='dir', use_nh=use_nh, local_graph_path=local_graph_path,
|
||
llm_config=llm_config, embed_config=embed_config)
|
||
|
||
# test import code
|
||
# cbh.import_code(do_interpret=True)
|
||
|
||
# query = '使用不同的HTTP请求类型(GET、POST、DELETE等)来执行不同的操作'
|
||
# query = '代码中一共有多少个类'
|
||
# query = 'remove 这个函数是用来做什么的'
|
||
query = '有没有函数是从字符串中删除指定字符串的功能'
|
||
|
||
search_type = 'description'
|
||
limit = 2
|
||
res = cbh.search_code(query, search_type, limit)
|
||
logger.debug(res)
|