codefuse-chatbot/dev_opsgpt/codechat/code_search/code_search.py

180 lines
5.8 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# encoding: utf-8
'''
@author: 温进
@file: code_search.py
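@time: 2023/11/21 2:35 PM
@desc: Code search over the code graph (Nebula) and the vector store (Chroma): by tag, by description similarity, and by generated Cypher.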
'''
import time
from loguru import logger
from collections import defaultdict
from dev_opsgpt.db_handler.graph_db_handler.nebula_handler import NebulaHandler
from dev_opsgpt.db_handler.vector_db_handler.chroma_handler import ChromaHandler
from dev_opsgpt.codechat.code_search.cypher_generator import CypherGenerator
from dev_opsgpt.codechat.code_search.tagger import Tagger
from dev_opsgpt.embeddings.get_embedding import get_embedding
# search_by_tag
# Points added to a vertex's score for every query tag contained in its id.
VERTEX_SCORE = 10
# Not referenced in the visible part of this module.
# NOTE(review): presumably the score for vertices taken from chat history — confirm.
HISTORY_VERTEX_SCORE = 5
# Fraction of each adjacent vertex's score that is added to a vertex's own score.
VERTEX_MERGE_RATIO = 0.5
# search_by_description
# Similarity hits are kept only when their distance is strictly below this value.
MAX_DISTANCE = 1000
class CodeSearch:
    '''
    Look up code for a natural-language query, using the graph handler for
    structure-based search and the Chroma handler for stored code text.
    '''

    def __init__(self, nh: 'NebulaHandler', ch: 'ChromaHandler', limit: int = 3):
        '''
        init
        @param nh: NebulaHandler used to read vertices and run cypher
        @param ch: ChromaHandler used to fetch / query stored code
        @param limit: maximum number of results returned by a search
        '''
        self.nh = nh
        self.ch = ch
        self.limit = limit

    def search_by_tag(self, query: str) -> list:
        '''
        search code by tags extracted from the query
        @param query: str, natural-language query handed to the Tagger
        @return: list of {"vertex": package id, "code_text": str}, best match
                 first, at most self.limit entries; [] when nothing matches
        '''
        tagger = Tagger()
        tag_list = tagger.generate_tag_query(query)
        logger.info(f'query tag={tag_list}')
        if not tag_list:
            # Nothing to match against, so skip the graph and Chroma round trips.
            return []

        # get all vertices
        vertex_list = self.nh.get_vertices().get('v', [])
        vertex_vid_list = [i.as_node().get_id().as_string() for i in vertex_list]
        logger.debug(vertex_vid_list)

        # base score: VERTEX_SCORE for every tag that occurs in the vertex id
        vertex_score_dict = defaultdict(int)
        for vid in vertex_vid_list:
            for tag in tag_list:
                if tag in vid:
                    vertex_score_dict[vid] += VERTEX_SCORE

        # The ranking must use the adjacency-merged scores, not the base ones.
        merged_score_dict = self._merge_adjacent_scores(vertex_score_dict)
        package_score_dict = self._aggregate_package_scores(merged_score_dict)
        if not package_score_dict:
            return []

        # get respective code, highest scoring package first
        ranked = sorted(package_score_dict.items(), key=lambda x: x[1], reverse=True)
        ids = [package for package, _ in ranked]
        chroma_res = self.ch.get(ids=ids, include=['metadatas'])
        return self._collect_code(chroma_res, ids, self.limit)

    def _merge_adjacent_scores(self, vertex_score_dict) -> dict:
        '''
        add a share of every neighbour's score to each scored vertex
        @param vertex_score_dict: mapping vertex id -> base score
        @return: dict vertex id -> merged score (only entries with score > 0)
        '''
        merged = {}
        for vertex in vertex_score_dict:
            # NOTE: the vertex id is interpolated unescaped; an id containing
            # a double quote would break this statement.
            cypher = f'''MATCH (v1)-[e]-(v2) where id(v1) == "{vertex}" RETURN v2'''
            cypher_res = self.nh.execute_cypher(cypher, self.nh.space_name)
            cypher_res_dict = self.nh.result_to_dict(cypher_res)
            adj_vertex_list = [i.as_node().get_id().as_string() for i in cypher_res_dict.get('v2', [])]

            score = vertex_score_dict.get(vertex, 0)
            for adj_vertex in adj_vertex_list:
                # .get() so that unscored neighbours are not inserted into the dict
                score += vertex_score_dict.get(adj_vertex, 0) * VERTEX_MERGE_RATIO

            if score > 0:
                merged[vertex] = score
        return merged

    @staticmethod
    def _aggregate_package_scores(vertex_score_dict) -> dict:
        '''
        sum vertex scores per package
        @param vertex_score_dict: mapping vertex id -> score
        @return: dict package id -> summed score, where the package id is the
                 first two '#'-separated segments of the vertex id
        '''
        package_score_dict = defaultdict(int)
        for vertex, score in vertex_score_dict.items():
            package = '#'.join(vertex.split('#')[0:2])
            package_score_dict[package] += score
        return dict(package_score_dict)

    @staticmethod
    def _collect_code(chroma_res, ranked_ids, limit) -> list:
        '''
        pair ranked ids with the code text returned by Chroma
        @param chroma_res: handler response; reads chroma_res['result']['ids']
                           and chroma_res['result']['metadatas']
        @param ranked_ids: ids in the order they should be returned
        @param limit: maximum number of entries to return
        @return: list of {"vertex": id, "code_text": str}; ids that are absent
                 from the response are skipped
        '''
        result = chroma_res['result']
        # First occurrence wins, matching list.index(), with one pass over the ids.
        index_by_id = {}
        for idx, stored_id in enumerate(result['ids']):
            index_by_id.setdefault(stored_id, idx)

        res = []
        for vertex in ranked_ids:
            # Checked before appending so that limit <= 0 yields no results.
            if len(res) >= limit:
                break
            idx = index_by_id.get(vertex)
            if idx is None:
                continue
            res.append({
                "vertex": vertex,
                "code_text": result['metadatas'][idx]['code_text'],
            })
        return res

    def search_by_desciption(self, query: str, engine: str) -> list:
        '''
        search by performing a similarity search on the query embedding
        @param query: str, natural-language description of the wanted code
        @param engine: str, passed through to get_embedding
        @return: list of {"vertex": id, "code_text": str} for hits whose
                 distance is below MAX_DISTANCE, in the order Chroma returned them
        '''
        # Commas are stripped before embedding; the embedding result is keyed
        # by this stripped text.
        # NOTE(review): the reason for stripping commas is not visible here — confirm.
        query = query.replace(',', '')
        query_emb = get_embedding(engine=engine, text_list=[query])
        query_emb = query_emb[query]

        query_embeddings = [query_emb]
        query_result = self.ch.query(query_embeddings=query_embeddings, n_results=self.limit,
                                     include=['metadatas', 'distances'])
        logger.debug(query_result)

        # A single embedding was submitted, so only the first row of each field is read.
        result = query_result['result']
        hits = zip(result['ids'][0], result['metadatas'][0], result['distances'][0])
        return [
            {"vertex": vertex, "code_text": metadata['code_text']}
            for vertex, metadata, distance in hits
            if distance < MAX_DISTANCE
        ]

    # Correctly spelled, backward-compatible alias for search_by_desciption.
    search_by_description = search_by_desciption

    def search_by_cypher(self, query: str):
        '''
        search by generating a cypher statement from the query and running it
        @param query: str, natural-language question
        @return: None when no cypher could be generated;
                 {'cypher': '', 'cypher_res': ''} when execution did not succeed;
                 otherwise {'cypher': generated statement, 'cypher_res': result}
        '''
        cg = CypherGenerator()
        cypher = cg.get_cypher(query)
        if not cypher:
            return None

        cypher_res = self.nh.execute_cypher(cypher, self.nh.space_name)
        logger.info(f'cypher execution result={cypher_res}')
        if not cypher_res.is_succeeded():
            return {
                'cypher': '',
                'cypher_res': ''
            }

        return {
            'cypher': cypher,
            'cypher_res': cypher_res
        }
if __name__ == '__main__':
    # Manual smoke test against the services configured in configs.server_config.
    from configs.server_config import (
        CHROMA_PERSISTENT_PATH,
        NEBULA_HOST,
        NEBULA_PASSWORD,
        NEBULA_PORT,
        NEBULA_STORAGED_PORT,
        NEBULA_USER,
    )

    codebase_name = 'testing'

    nh = NebulaHandler(
        host=NEBULA_HOST,
        port=NEBULA_PORT,
        username=NEBULA_USER,
        password=NEBULA_PASSWORD,
        space_name=codebase_name,
    )
    nh.add_host(NEBULA_HOST, NEBULA_STORAGED_PORT)
    # NOTE(review): presumably gives the add_host call time to take effect — confirm.
    time.sleep(0.5)

    ch = ChromaHandler(path=CHROMA_PERSISTENT_PATH, collection_name=codebase_name)
    cs = CodeSearch(nh, ch)

    # Other entry points can be exercised the same way, e.g.:
    #   cs.search_by_tag('<natural-language query>')
    #   cs.search_by_cypher('<natural-language question>')

    res = cs.search_by_desciption('使用不同的HTTP请求类型GET、POST、DELETE等来执行不同的操作', 'openai')
    logger.debug(res)