codefuse-chatbot/dev_opsgpt/codechat/code_search/code_search.py

180 lines
5.8 KiB
Python
Raw Permalink Normal View History

# encoding: utf-8
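'''
@author: 温进
@file: code_search.py
@time: 2023/11/21 2:35 PM
@desc: Code search over a Nebula graph and a Chroma collection: by tag, by description embedding, or by generated cypher.
'''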
import time
from loguru import logger
from collections import defaultdict
from dev_opsgpt.db_handler.graph_db_handler.nebula_handler import NebulaHandler
from dev_opsgpt.db_handler.vector_db_handler.chroma_handler import ChromaHandler
from dev_opsgpt.codechat.code_search.cypher_generator import CypherGenerator
from dev_opsgpt.codechat.code_search.tagger import Tagger
from dev_opsgpt.embeddings.get_embedding import get_embedding
# --- search_by_tag tuning ---
# Score added to a vertex each time one of the query tags is a substring of
# the vertex id.
VERTEX_SCORE = 10
# NOTE(review): not referenced anywhere in the visible part of this file;
# presumably reserved for scoring vertices from earlier queries -- confirm
# before removing.
HISTORY_VERTEX_SCORE = 5
# Weight applied to an adjacent vertex's score when it is merged into the
# score of the vertex being evaluated.
VERTEX_MERGE_RATIO = 0.5
# --- search_by_desciption tuning ---
# Similarity hits are kept only when their distance is strictly below this.
MAX_DISTANCE = 1000
class CodeSearch:
    '''
    Search code stored in a Nebula graph and a Chroma collection.

    Three strategies are offered:
      * search_by_tag        - match generated tags against vertex ids
      * search_by_desciption - embedding similarity search in Chroma
      * search_by_cypher     - run a generated cypher statement on Nebula
    '''

    def __init__(self, nh: NebulaHandler, ch: ChromaHandler, limit: int = 3):
        '''
        init
        @param nh: NebulaHandler used for all graph queries
        @param ch: ChromaHandler used to fetch code text and run similarity queries
        @param limit: maximum number of results returned by a search
        '''
        self.nh = nh
        self.ch = ch
        self.limit = limit

    def search_by_tag(self, query: str):
        '''
        search code by tag
        @param query: str, text handed to Tagger.generate_tag_query to obtain tags
        @return: list of {"vertex": <package id>, "code_text": <code>} dicts,
                 best scoring package first, at most self.limit entries
        '''
        tagger = Tagger()
        tag_list = tagger.generate_tag_query(query)
        logger.info(f'query tag={tag_list}')

        # get all vertices
        vertex_list = self.nh.get_vertices().get('v', [])
        vertex_vid_list = [i.as_node().get_id().as_string() for i in vertex_list]
        logger.debug(vertex_vid_list)

        # update score: every tag contained in a vertex id adds VERTEX_SCORE
        vertex_score_dict = defaultdict(int)
        for vid in vertex_vid_list:
            for tag in tag_list:
                if tag in vid:
                    vertex_score_dict[vid] += VERTEX_SCORE

        # merge depend adj score: a vertex also earns a share of the score of
        # each directly connected vertex
        vertex_score_dict_final = {}
        for vertex, own_score in vertex_score_dict.items():
            cypher = f'''MATCH (v1)-[e]-(v2) where id(v1) == "{vertex}" RETURN v2'''
            cypher_res = self.nh.execute_cypher(cypher, self.nh.space_name)
            cypher_res_dict = self.nh.result_to_dict(cypher_res)
            adj_vertex_list = [i.as_node().get_id().as_string() for i in cypher_res_dict.get('v2', [])]

            score = own_score
            for adj_vertex in adj_vertex_list:
                # .get() (not []) so the defaultdict is not grown while iterating
                score += vertex_score_dict.get(adj_vertex, 0) * VERTEX_MERGE_RATIO

            if score > 0:
                vertex_score_dict_final[vertex] = score

        # get most prominent package tag; the package id is the first two
        # '#'-separated fields of the vertex id. The merged scores are used
        # here, otherwise the neighbour merge above would have no effect.
        package_score_dict = defaultdict(int)
        for vertex, score in vertex_score_dict_final.items():
            package = '#'.join(vertex.split('#')[0:2])
            package_score_dict[package] += score

        # nothing matched any tag: skip the vector store round trip
        if not package_score_dict:
            return []

        # get respective code
        package_score_tuple = sorted(package_score_dict.items(), key=lambda x: x[1], reverse=True)
        ids = [i[0] for i in package_score_tuple]
        chroma_res = self.ch.get(ids=ids, include=['metadatas'])

        # index the returned rows by id once; keep the first row for an id
        metadata_by_id = {}
        for found_id, metadata in zip(chroma_res['result']['ids'], chroma_res['result']['metadatas']):
            metadata_by_id.setdefault(found_id, metadata)

        res = []
        for vertex, _score in package_score_tuple:
            metadata = metadata_by_id.get(vertex)
            if metadata is None:
                # a package derived from the graph may have no stored document
                logger.warning(f'no code_text stored for vertex={vertex}')
                continue

            res.append({
                "vertex": vertex,
                "code_text": metadata['code_text']
            })
            if len(res) >= self.limit:
                break
        return res

    def search_by_desciption(self, query: str, engine: str):
        '''
        search by performing a similarity search on the query embedding
        @param query: str, description text; ASCII commas are removed before embedding
        @param engine: str, embedding engine name passed to get_embedding
        @return: list of {"vertex": <id>, "code_text": <code>} dicts whose
                 distance is below MAX_DISTANCE, in the order returned by the
                 vector store
        '''
        query = query.replace(',', '')
        # get_embedding returns a mapping keyed by the input text
        query_emb = get_embedding(engine=engine, text_list=[query])[query]

        query_result = self.ch.query(query_embeddings=[query_emb], n_results=self.limit,
                                     include=['metadatas', 'distances'])
        logger.debug(query_result)

        result = query_result['result']
        distances = result['distances']
        if not distances:
            return []

        # one query embedding was sent, so only the first row of each field applies
        return [
            {
                "vertex": vertex,
                "code_text": metadata['code_text']
            }
            for distance, vertex, metadata in zip(distances[0], result['ids'][0], result['metadatas'][0])
            if distance < MAX_DISTANCE
        ]

    def search_by_cypher(self, query: str):
        '''
        search by generating a cypher statement and executing it
        @param query: str, text handed to CypherGenerator.get_cypher
        @return: None when no cypher was generated;
                 {'cypher': '', 'cypher_res': ''} when execution did not succeed;
                 otherwise {'cypher': <statement>, 'cypher_res': <execution result>}
        '''
        cg = CypherGenerator()
        cypher = cg.get_cypher(query)

        if not cypher:
            return None

        cypher_res = self.nh.execute_cypher(cypher, self.nh.space_name)
        logger.info(f'cypher execution result={cypher_res}')

        if not cypher_res.is_succeeded():
            return {
                'cypher': '',
                'cypher_res': ''
            }

        return {
            'cypher': cypher,
            'cypher_res': cypher_res
        }
if __name__ == '__main__':
    # Manual smoke test: connect to Nebula and Chroma using the server
    # configuration, then run one description search and log the result.
    from configs.server_config import (
        NEBULA_HOST,
        NEBULA_PORT,
        NEBULA_USER,
        NEBULA_PASSWORD,
        NEBULA_STORAGED_PORT,
        CHROMA_PERSISTENT_PATH,
    )

    space = 'testing'

    graph_handler = NebulaHandler(
        host=NEBULA_HOST,
        port=NEBULA_PORT,
        username=NEBULA_USER,
        password=NEBULA_PASSWORD,
        space_name=space,
    )
    graph_handler.add_host(NEBULA_HOST, NEBULA_STORAGED_PORT)
    # short pause after registering the storage host before continuing
    time.sleep(0.5)

    vector_handler = ChromaHandler(path=CHROMA_PERSISTENT_PATH, collection_name=space)

    searcher = CodeSearch(graph_handler, vector_handler)

    # The other entry points can be exercised the same way; both take a
    # single query string:
    #   searcher.search_by_tag(<query text>)
    #   searcher.search_by_cypher(<question text>)

    outcome = searcher.search_by_desciption('使用不同的HTTP请求类型GET、POST、DELETE等来执行不同的操作', 'openai')
    logger.debug(outcome)