180 lines
5.8 KiB
Python
180 lines
5.8 KiB
Python
|
# encoding: utf-8
|
|||
|
'''
|
|||
|
@author: 温进
|
|||
|
@file: code_search.py
|
|||
|
@time: 2023/11/21 下午2:35
|
|||
|
@desc:
|
|||
|
'''
|
|||
|
import time
|
|||
|
from loguru import logger
|
|||
|
from collections import defaultdict
|
|||
|
|
|||
|
from dev_opsgpt.db_handler.graph_db_handler.nebula_handler import NebulaHandler
|
|||
|
from dev_opsgpt.db_handler.vector_db_handler.chroma_handler import ChromaHandler
|
|||
|
|
|||
|
from dev_opsgpt.codechat.code_search.cypher_generator import CypherGenerator
|
|||
|
from dev_opsgpt.codechat.code_search.tagger import Tagger
|
|||
|
from dev_opsgpt.embeddings.get_embedding import get_embedding
|
|||
|
|
|||
|
# Scoring constants used by CodeSearch.search_by_tag.
VERTEX_SCORE: int = 10  # points a vertex id earns for every query tag it contains
HISTORY_VERTEX_SCORE: int = 5  # not referenced by the code in this module
VERTEX_MERGE_RATIO: float = 0.5  # share of each neighbouring vertex's score merged into a vertex

# Threshold used by CodeSearch.search_by_desciption.
MAX_DISTANCE: float = 0.5  # only hits whose vector distance is strictly below this are kept
|
|||
|
|
|||
|
|
|||
|
class CodeSearch:
    '''
    Retrieve code snippets for a codebase that is indexed both in a Nebula graph
    (vertices and their adjacency) and in a Chroma collection (code text and
    embeddings).

    Three strategies are offered:
      * search_by_tag        - match query tags against graph vertex ids
      * search_by_desciption - embedding similarity search in Chroma
      * search_by_cypher     - run a Cypher query generated from the question
    '''

    def __init__(self, nh: NebulaHandler, ch: ChromaHandler, limit: int = 3):
        '''
        init

        @param nh: NebulaHandler used to list vertices and execute Cypher queries
        @param ch: ChromaHandler used to fetch code text and run similarity queries
        @param limit: maximum number of results returned by search_by_tag and
            search_by_desciption
        '''
        self.nh = nh
        self.ch = ch
        self.limit = limit

    def search_by_tag(self, query: str):
        '''
        search_code_res by tag

        Tags are generated from the query, and every graph vertex whose id contains a
        tag is scored. Each scored vertex is boosted by a fraction of its neighbours'
        scores. The boosted scores are summed per package (the first two
        '#'-separated parts of the vertex id), and the code text of the
        highest-scoring packages is fetched from Chroma.

        @param query: natural-language query
        @return: list of {"vertex": package_id, "code_text": str}, ordered by
            descending package score, with at most self.limit items
        '''
        tagger = Tagger()
        tag_list = tagger.generate_tag_query(query)
        logger.info(f'query tag={tag_list}')

        # get all vertices
        vertex_list = self.nh.get_vertices().get('v', [])
        vertex_vid_list = [i.as_node().get_id().as_string() for i in vertex_list]
        logger.debug(vertex_vid_list)

        vertex_score_dict = self._score_vertices(vertex_vid_list, tag_list)
        vertex_score_dict_final = self._merge_adjacent_scores(vertex_score_dict)
        # Aggregate the *merged* scores; previously the merge result was computed
        # and then discarded in favour of the raw tag scores.
        package_score_dict = self._aggregate_package_scores(vertex_score_dict_final)
        return self._fetch_package_code(package_score_dict)

    @staticmethod
    def _score_vertices(vertex_vid_list, tag_list):
        '''Give each vertex id VERTEX_SCORE points per tag it contains (substring match).'''
        vertex_score_dict = defaultdict(int)
        for vid in vertex_vid_list:
            for tag in tag_list:
                if tag in vid:
                    vertex_score_dict[vid] += VERTEX_SCORE
        return vertex_score_dict

    @staticmethod
    def _escape_ngql_string(value: str) -> str:
        '''Escape backslashes and double quotes so value can sit inside a "..." literal.'''
        return value.replace('\\', '\\\\').replace('"', '\\"')

    def _merge_adjacent_scores(self, vertex_score_dict):
        '''
        Add VERTEX_MERGE_RATIO times each neighbour's score to every scored vertex.

        @param vertex_score_dict: vertex id -> tag score (only matched vertices)
        @return: dict vertex id -> merged score, restricted to scores > 0
        '''
        vertex_score_dict_final = {}
        for vertex, score in vertex_score_dict.items():
            vid_literal = self._escape_ngql_string(vertex)
            cypher = f'''MATCH (v1)-[e]-(v2) where id(v1) == "{vid_literal}" RETURN v2'''
            cypher_res = self.nh.execute_cypher(cypher, self.nh.space_name)
            cypher_res_dict = self.nh.result_to_dict(cypher_res)

            adj_vertex_list = [i.as_node().get_id().as_string() for i in cypher_res_dict.get('v2', [])]
            for adj_vertex in adj_vertex_list:
                # .get (not []) so looking up neighbours does not insert into the defaultdict
                score += vertex_score_dict.get(adj_vertex, 0) * VERTEX_MERGE_RATIO

            if score > 0:
                vertex_score_dict_final[vertex] = score
        return vertex_score_dict_final

    @staticmethod
    def _aggregate_package_scores(vertex_score_dict):
        '''Sum vertex scores per package id (first two '#'-separated parts of the vertex id).'''
        package_score_dict = defaultdict(int)
        for vertex, score in vertex_score_dict.items():
            package = '#'.join(vertex.split('#')[0:2])
            package_score_dict[package] += score
        return package_score_dict

    def _fetch_package_code(self, package_score_dict):
        '''
        Fetch code text for packages in descending score order, stopping at self.limit.

        Packages missing from the Chroma result are skipped rather than raising.
        '''
        package_score_tuple = sorted(package_score_dict.items(), key=lambda x: x[1], reverse=True)
        if not package_score_tuple:
            # nothing matched: avoid querying Chroma with an empty id list
            return []

        ids = [package for package, _ in package_score_tuple]
        chroma_res = self.ch.get(ids=ids, include=['metadatas'])
        # Map id -> metadata once instead of an O(n) list.index per package.
        metadata_by_id = dict(zip(chroma_res['result']['ids'], chroma_res['result']['metadatas']))

        res = []
        for vertex, score in package_score_tuple:
            metadata = metadata_by_id.get(vertex)
            if metadata is None:
                logger.warning(f'package {vertex} not found in chroma, skipped')
                continue
            res.append({
                "vertex": vertex,
                "code_text": metadata['code_text']}
            )
            if len(res) >= self.limit:
                break
        return res

    def search_by_desciption(self, query: str, engine: str):
        '''
        search by perform sim search

        Embed the query and return the Chroma hits whose distance is below
        MAX_DISTANCE.

        @param query: natural-language description of the wanted code
        @param engine: embedding engine name passed to get_embedding
        @return: list of {"vertex": id, "code_text": str}, with at most self.limit
            items, in the order returned by the vector store
        '''
        # NOTE(review): both arguments are the ASCII comma as written, so this is a
        # no-op; it was presumably meant to normalise the full-width comma -- confirm.
        query = query.replace(',', ',')
        query_emb = get_embedding(engine=engine, text_list=[query])
        # the embedding result is indexed by the input text
        query_emb = query_emb[query]

        query_embeddings = [query_emb]
        query_result = self.ch.query(query_embeddings=query_embeddings, n_results=self.limit,
                                     include=['metadatas', 'distances'])
        logger.debug(query_result)

        res = []
        # index [0]: results for the single query embedding sent above
        for idx, distance in enumerate(query_result['result']['distances'][0]):
            if distance < MAX_DISTANCE:
                vertex = query_result['result']['ids'][0][idx]
                code_text = query_result['result']['metadatas'][0][idx]['code_text']
                res.append({
                    "vertex": vertex,
                    "code_text": code_text
                })

        return res

    # Correctly spelled alias; the misspelled name is kept for existing callers.
    search_by_description = search_by_desciption

    def search_by_cypher(self, query: str):
        '''
        search by generating cypher

        @param query: natural-language question turned into Cypher by CypherGenerator
        @return: None if no Cypher could be generated;
            {'cypher': '', 'cypher_res': ''} if execution did not succeed;
            otherwise {'cypher': str, 'cypher_res': raw Nebula execution result}
        '''
        cg = CypherGenerator()
        cypher = cg.get_cypher(query)

        if not cypher:
            return None

        cypher_res = self.nh.execute_cypher(cypher, self.nh.space_name)
        logger.info(f'cypher execution result={cypher_res}')
        if not cypher_res.is_succeeded():
            return {
                'cypher': '',
                'cypher_res': ''
            }

        res = {
            'cypher': cypher,
            'cypher_res': cypher_res
        }

        return res
|
|||
|
|
|||
|
|
|||
|
if __name__ == '__main__':
    # Manual smoke test against a local Nebula + Chroma setup; connection
    # settings come from configs.server_config.
    from configs.server_config import NEBULA_HOST, NEBULA_PORT, NEBULA_USER, NEBULA_PASSWORD, NEBULA_STORAGED_PORT
    from configs.server_config import CHROMA_PERSISTENT_PATH

    codebase_name = 'testing'

    nh = NebulaHandler(host=NEBULA_HOST, port=NEBULA_PORT, username=NEBULA_USER,
                       password=NEBULA_PASSWORD, space_name=codebase_name)
    nh.add_host(NEBULA_HOST, NEBULA_STORAGED_PORT)
    # NOTE(review): presumably gives Nebula time to register the storaged host
    # before the space is queried -- confirm the delay is sufficient.
    time.sleep(0.5)

    ch = ChromaHandler(path=CHROMA_PERSISTENT_PATH, collection_name=codebase_name)

    cs = CodeSearch(nh, ch)

    # Other strategies (uncomment to try). search_by_tag takes a query string and
    # derives the tags itself; search_by_cypher takes only the question:
    # res = cs.search_by_tag('createFineTuneCompletion OpenAiApi')
    # logger.debug(res)

    # res = cs.search_by_cypher('How many classes are there in the code?')
    # logger.debug(res)

    res = cs.search_by_desciption('使用不同的HTTP请求类型(GET、POST、DELETE等)来执行不同的操作', 'openai')
    logger.debug(res)
|