codefuse-chatbot/coagent/codechat/code_search/code_search.py

# encoding: utf-8
'''
@author: 温进
@file: code_search.py
@time: 2023/11/21 下午2:35
@desc:
'''
import json
import time
from loguru import logger
from collections import defaultdict
from coagent.db_handler.graph_db_handler.nebula_handler import NebulaHandler
from coagent.db_handler.vector_db_handler.chroma_handler import ChromaHandler
from coagent.codechat.code_search.cypher_generator import CypherGenerator
from coagent.codechat.code_search.tagger import Tagger
from coagent.embeddings.get_embedding import get_embedding
from coagent.llm_models.llm_config import LLMConfig, EmbedConfig
# from configs.model_config import EMBEDDING_DEVICE, EMBEDDING_MODEL

# search_by_tag
VERTEX_SCORE = 10
HISTORY_VERTEX_SCORE = 5
VERTEX_MERGE_RATIO = 0.5

# search_by_description
MAX_DISTANCE = 1000
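
# How these constants are used below (see search_by_tag and search_by_desciption):
# - a vertex scores VERTEX_SCORE for every query tag contained in its id;
# - scores of adjacent vertices are merged in, weighted by VERTEX_MERGE_RATIO;
# - MAX_DISTANCE is the cut-off applied to Chroma distances in search_by_desciption.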

class CodeSearch:
    def __init__(self, llm_config: LLMConfig, nh: NebulaHandler, ch: ChromaHandler, limit: int = 3,
                 local_graph_file_path: str = ''):
        '''
        init
        @param llm_config: LLMConfig, used by search_by_cypher to generate cypher
        @param nh: NebulaHandler; if None, the graph is loaded from local_graph_file_path
        @param ch: ChromaHandler
        @param limit: maximum number of results to return
        @param local_graph_file_path: path to a JSON dump of the code graph, used when nh is not given
        '''
        self.llm_config = llm_config
        self.nh = nh

        if not self.nh:
            with open(local_graph_file_path, 'r') as f:
                self.graph = json.load(f)

        self.ch = ch
        self.limit = limit

    def search_by_tag(self, query: str):
        '''
        search code by tag
        @param query: str
        @return: list of {"vertex": str, "code_text": str}
        '''
        tagger = Tagger()
        tag_list = tagger.generate_tag_query(query)
        logger.info(f'query tag={tag_list}')

        # get all vertices
        vertex_list = self.nh.get_vertices().get('v', [])
        vertex_vid_list = [i.as_node().get_id().as_string() for i in vertex_list]

        # update score: a vertex scores VERTEX_SCORE for every tag contained in its id
        vertex_score_dict = defaultdict(lambda: 0)
        for vid in vertex_vid_list:
            for tag in tag_list:
                if tag in vid:
                    vertex_score_dict[vid] += VERTEX_SCORE

        # merge in the scores of adjacent vertices, weighted by VERTEX_MERGE_RATIO
        vertex_score_dict_final = {}
        for vertex in vertex_score_dict:
            cypher = f'''MATCH (v1)-[e]-(v2) where id(v1) == "{vertex}" RETURN v2'''
            cypher_res = self.nh.execute_cypher(cypher, self.nh.space_name)
            cypher_res_dict = self.nh.result_to_dict(cypher_res)

            adj_vertex_list = [i.as_node().get_id().as_string() for i in cypher_res_dict.get('v2', [])]

            score = vertex_score_dict.get(vertex, 0)
            for adj_vertex in adj_vertex_list:
                score += vertex_score_dict.get(adj_vertex, 0) * VERTEX_MERGE_RATIO

            if score > 0:
                vertex_score_dict_final[vertex] = score

        # get most prominent package tag
        package_score_dict = defaultdict(lambda: 0)
        for vertex, score in vertex_score_dict_final.items():
            if '#' in vertex:
                # get class name first
                cypher = f'''MATCH (v1:class)-[e:contain]->(v2) WHERE id(v2) == '{vertex}' RETURN id(v1) as id;'''
                cypher_res = self.nh.execute_cypher(cypher=cypher, format_res=True)
                class_vertices = cypher_res.get('id', [])
                if not class_vertices:
                    continue
                vertex = class_vertices[0].as_string()

            # get package name
            cypher = f'''MATCH (v1:package)-[e:contain]->(v2) WHERE id(v2) == '{vertex}' RETURN id(v1) as id;'''
            cypher_res = self.nh.execute_cypher(cypher=cypher, format_res=True)
            pac_vertices = cypher_res.get('id', [])
            if not pac_vertices:
                continue
            package = pac_vertices[0].as_string()

            package_score_dict[package] += score

        # get respective code
        res = []
        package_score_tuple = list(package_score_dict.items())
        package_score_tuple.sort(key=lambda x: x[1], reverse=True)

        ids = [i[0] for i in package_score_tuple]
        logger.info(f'ids={ids}')
        chroma_res = self.ch.get(ids=ids, include=['metadatas'])

        for vertex, score in package_score_tuple:
            index = chroma_res['result']['ids'].index(vertex)
            code_text = chroma_res['result']['metadatas'][index]['code_text']
            res.append({
                "vertex": vertex,
                "code_text": code_text
            })
            if len(res) >= self.limit:
                break

        # logger.info(f'retrieved code={res}')
        return res

    def search_by_tag_by_graph(self, query: str):
        '''
        search code by tag, using the local graph file instead of Nebula
        @param query: str
        @return: list of {"vertex": str, "code_text": str}
        '''
        tagger = Tagger()
        tag_list = tagger.generate_tag_query(query)
        logger.info(f'query tag={tag_list}')

        # loop over graph nodes: score a package by how many tags appear in its class and function names
        package_score_dict = {}
        for code, structure in self.graph.items():
            score = 0
            for class_name in structure['class_name_list']:
                for tag in tag_list:
                    if tag.lower() in class_name.lower():
                        score += 1

            for func_name_list in structure['func_name_dict'].values():
                for func_name in func_name_list:
                    for tag in tag_list:
                        if tag.lower() in func_name.lower():
                            score += 1

            package_score_dict[structure['pac_name']] = score

        # get respective code
        res = []
        package_score_tuple = list(package_score_dict.items())
        package_score_tuple.sort(key=lambda x: x[1], reverse=True)

        ids = [i[0] for i in package_score_tuple]
        logger.info(f'ids={ids}')
        chroma_res = self.ch.get(ids=ids, include=['metadatas'])
        # logger.info(chroma_res)

        for vertex, score in package_score_tuple:
            index = chroma_res['result']['ids'].index(vertex)
            code_text = chroma_res['result']['metadatas'][index]['code_text']
            res.append({
                "vertex": vertex,
                "code_text": code_text
            })
            if len(res) >= self.limit:
                break

        # logger.info(f'retrieved code={res}')
        return res

    def search_by_desciption(self, query: str, engine: str, model_path: str = "text2vec-base-chinese",
                             embedding_device: str = "cpu", embed_config: EmbedConfig = None):
        '''
        search code by performing a similarity search with the query embedding
        @param query: str
        @return: list of {"vertex": str, "code_text": str}
        '''
        query = query.replace(',', '')
        query_emb = get_embedding(engine=engine, text_list=[query], model_path=model_path,
                                  embedding_device=embedding_device, embed_config=embed_config)
        query_emb = query_emb[query]

        query_embeddings = [query_emb]
        query_result = self.ch.query(query_embeddings=query_embeddings, n_results=self.limit,
                                     include=['metadatas', 'distances'])

        res = []
        for idx, distance in enumerate(query_result['result']['distances'][0]):
            if distance < MAX_DISTANCE:
                vertex = query_result['result']['ids'][0][idx]
                code_text = query_result['result']['metadatas'][0][idx]['code_text']
                res.append({
                    "vertex": vertex,
                    "code_text": code_text
                })
        return res

    def search_by_cypher(self, query: str):
        '''
        search by generating a cypher query from the natural-language query
        @param query: str
        @return: dict with 'cypher' and 'cypher_res', or None if no cypher could be generated
        '''
        cg = CypherGenerator(self.llm_config)
        cypher = cg.get_cypher(query)

        if not cypher:
            return None

        cypher_res = self.nh.execute_cypher(cypher, self.nh.space_name)
        logger.info(f'cypher execution result={cypher_res}')
        if not cypher_res.is_succeeded():
            return {
                'cypher': '',
                'cypher_res': ''
            }

        res = {
            'cypher': cypher,
            'cypher_res': cypher_res
        }
        return res


if __name__ == '__main__':
    # from configs.server_config import NEBULA_HOST, NEBULA_PORT, NEBULA_USER, NEBULA_PASSWORD, NEBULA_STORAGED_PORT
    # from configs.server_config import CHROMA_PERSISTENT_PATH
    from coagent.base_configs.env_config import (
        NEBULA_HOST, NEBULA_PORT, NEBULA_USER, NEBULA_PASSWORD, NEBULA_STORAGED_PORT,
        CHROMA_PERSISTENT_PATH
    )

    codebase_name = 'testing'

    nh = NebulaHandler(host=NEBULA_HOST, port=NEBULA_PORT, username=NEBULA_USER,
                       password=NEBULA_PASSWORD, space_name=codebase_name)
    nh.add_host(NEBULA_HOST, NEBULA_STORAGED_PORT)
    time.sleep(0.5)

    ch = ChromaHandler(path=CHROMA_PERSISTENT_PATH, collection_name=codebase_name)

    # llm_config is only needed by search_by_cypher; pass None for the tag / similarity searches
    cs = CodeSearch(llm_config=None, nh=nh, ch=ch)

    # res = cs.search_by_tag('createFineTuneCompletion OpenAiApi')
    # logger.debug(res)

    # query: "how many classes are there in the code in total?"
    # res = cs.search_by_cypher('代码中一共有多少个类')
    # logger.debug(res)

    # query: "use different HTTP request types such as GET, POST and DELETE to perform different operations"
    res = cs.search_by_desciption('使用不同的HTTP请求类型GET、POST、DELETE等来执行不同的操作', 'openai')
    logger.debug(res)
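
    # A minimal sketch (not in the original demo) of running tag search without Nebula,
    # using only the local graph JSON. Based on how __init__ and search_by_tag_by_graph read
    # the file, each entry is keyed by a code snippet and holds 'pac_name', 'class_name_list'
    # and 'func_name_dict'; the file name 'local_graph.json' and the example values below are
    # assumptions for illustration only.
    #
    # local_cs = CodeSearch(llm_config=None, nh=None, ch=ch, local_graph_file_path='local_graph.json')
    # res = local_cs.search_by_tag_by_graph('OpenAiApi completion')
    # logger.debug(res)
    #
    # Assumed shape of local_graph.json:
    # {
    #     "<code snippet>": {
    #         "pac_name": "com.example.openai.OpenAiApi",
    #         "class_name_list": ["OpenAiApi"],
    #         "func_name_dict": {"OpenAiApi": ["createFineTuneCompletion"]}
    #     }
    # }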