180 lines
5.8 KiB
Python
180 lines
5.8 KiB
Python
# encoding: utf-8
|
||
'''
|
||
@author: 温进
|
||
@file: code_search.py
|
||
@time: 2023/11/21 下午2:35
|
||
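@desc: code search over a codebase indexed in NebulaGraph (code graph) and Chroma (code text + embeddings): search by tag, by description similarity, or by a generated Cypher query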
|
||
'''
|
||
import time
|
||
from loguru import logger
|
||
from collections import defaultdict
|
||
|
||
from dev_opsgpt.db_handler.graph_db_handler.nebula_handler import NebulaHandler
|
||
from dev_opsgpt.db_handler.vector_db_handler.chroma_handler import ChromaHandler
|
||
|
||
from dev_opsgpt.codechat.code_search.cypher_generator import CypherGenerator
|
||
from dev_opsgpt.codechat.code_search.tagger import Tagger
|
||
from dev_opsgpt.embeddings.get_embedding import get_embedding
|
||
|
||
# ---- Tuning knobs for CodeSearch.search_by_tag ----
# Points a graph vertex earns for every query tag contained in its id.
VERTEX_SCORE: int = 10
# Not referenced by the code in this module; kept for any external users.
HISTORY_VERTEX_SCORE: int = 5
# Weight applied to an adjacent vertex's score when it is merged into a vertex.
VERTEX_MERGE_RATIO: float = 0.5

# ---- Tuning knob for CodeSearch.search_by_desciption ----
# Similarity hits are kept only when their distance is strictly below this value.
MAX_DISTANCE: int = 1000
|
||
|
||
|
||
class CodeSearch:
    '''
    Search a codebase that has been indexed into a NebulaGraph space (the
    code graph, via NebulaHandler) and a Chroma collection (code text and
    embeddings, via ChromaHandler).

    Strategies:
      * search_by_tag         - query tags matched against graph vertex ids
      * search_by_desciption  - embedding similarity search in Chroma
                                (alias: search_by_description)
      * search_by_cypher      - run a Cypher query generated from the question
    '''

    def __init__(self, nh: NebulaHandler, ch: ChromaHandler, limit: int = 3):
        '''
        init

        @param nh: NebulaHandler used to read the code graph
        @param ch: ChromaHandler used to read code text / run vector queries
        @param limit: maximum number of results for the tag and description
                      searches
        '''
        self.nh = nh
        self.ch = ch
        self.limit = limit

    @staticmethod
    def _escape_cypher_string(value: str) -> str:
        '''Escape ``value`` for use inside a double-quoted Cypher string literal.'''
        # Backslash first so the escapes added for quotes are not doubled.
        return value.replace('\\', '\\\\').replace('"', '\\"')

    def _score_vertices_by_tag(self, tag_list) -> dict:
        '''
        Give every graph vertex VERTEX_SCORE points per tag contained in its id.

        @param tag_list: tags produced by Tagger.generate_tag_query
        @return: {vertex_id: score}; only vertices matching at least one tag
        '''
        vertex_list = self.nh.get_vertices().get('v', [])
        vertex_vid_list = [i.as_node().get_id().as_string() for i in vertex_list]
        logger.debug(vertex_vid_list)

        vertex_score_dict = defaultdict(int)
        for vid in vertex_vid_list:
            for tag in tag_list:
                if tag in vid:
                    vertex_score_dict[vid] += VERTEX_SCORE
        return vertex_score_dict

    def _merge_adjacent_scores(self, vertex_score_dict) -> dict:
        '''
        Boost each scored vertex by VERTEX_MERGE_RATIO * score of its neighbours.

        @param vertex_score_dict: {vertex_id: score} from _score_vertices_by_tag
        @return: {vertex_id: merged_score} for vertices whose merged score > 0
        '''
        vertex_score_dict_final = {}
        for vertex, own_score in vertex_score_dict.items():
            vid_literal = self._escape_cypher_string(vertex)
            cypher = f'''MATCH (v1)-[e]-(v2) where id(v1) == "{vid_literal}" RETURN v2'''
            cypher_res = self.nh.execute_cypher(cypher, self.nh.space_name)
            cypher_res_dict = self.nh.result_to_dict(cypher_res)
            adj_vertex_list = [i.as_node().get_id().as_string()
                               for i in cypher_res_dict.get('v2', [])]

            score = own_score
            for adj_vertex in adj_vertex_list:
                # .get (not []) so unscored neighbours add 0 and are not inserted.
                score += vertex_score_dict.get(adj_vertex, 0) * VERTEX_MERGE_RATIO

            if score > 0:
                vertex_score_dict_final[vertex] = score
        return vertex_score_dict_final

    @staticmethod
    def _aggregate_package_scores(vertex_score_dict) -> list:
        '''
        Sum vertex scores per package and rank the packages.

        The package key is the first two '#'-separated parts of the vertex id.

        @return: [(package_id, score), ...] sorted by score, highest first
        '''
        package_score_dict = defaultdict(int)
        for vertex, score in vertex_score_dict.items():
            package = '#'.join(vertex.split('#')[0:2])
            package_score_dict[package] += score
        return sorted(package_score_dict.items(), key=lambda x: x[1], reverse=True)

    def _fetch_package_code(self, package_score_list) -> list:
        '''
        Look up the code text stored in Chroma for the ranked packages.

        @param package_score_list: output of _aggregate_package_scores
        @return: [{"vertex": package_id, "code_text": str}, ...] in rank order
        '''
        if not package_score_list:
            # Nothing matched; avoid calling Chroma with an empty id list.
            return []

        ids = [package for package, _ in package_score_list]
        chroma_res = self.ch.get(ids=ids, include=['metadatas'])
        result = chroma_res['result']

        # id -> position map (first occurrence wins, like list.index).
        id_to_index = {}
        for idx, vid in enumerate(result['ids']):
            id_to_index.setdefault(vid, idx)

        res = []
        for package, _score in package_score_list:
            index = id_to_index.get(package)
            if index is None:
                # No Chroma document for this package id: skip instead of raising.
                logger.warning(f'no chroma document for package={package}')
                continue
            code_text = result['metadatas'][index]['code_text']
            res.append({
                "vertex": package,
                "code_text": code_text,
            })
            if len(res) >= self.limit:
                break
        return res

    def search_by_tag(self, query: str):
        '''
        search_code_res by tag

        Pipeline:
          1. extract tags from the query (Tagger)
          2. score every graph vertex whose id contains a tag
          3. add VERTEX_MERGE_RATIO * score of each adjacent scored vertex
          4. sum the merged scores per package (first two '#'-parts of the id)
          5. return the Chroma code text of the best-ranked packages

        @param query: natural-language query
        @return: list of {"vertex": package_id, "code_text": str}, best first;
                 stops once self.limit results have been collected
        '''
        tagger = Tagger()
        tag_list = tagger.generate_tag_query(query)
        logger.info(f'query tag={tag_list}')

        vertex_score_dict = self._score_vertices_by_tag(tag_list)
        vertex_score_dict_final = self._merge_adjacent_scores(vertex_score_dict)
        # Rank packages by the neighbour-merged scores.
        package_score_list = self._aggregate_package_scores(vertex_score_dict_final)
        return self._fetch_package_code(package_score_list)

    def search_by_desciption(self, query: str, engine: str):
        '''
        search by performing an embedding similarity search in Chroma

        @param query: natural-language description of the wanted code
        @param engine: embedding engine name passed to get_embedding
        @return: list of {"vertex": id, "code_text": str} for at most
                 self.limit hits whose distance is below MAX_DISTANCE
        '''
        # Normalize full-width (CJK) commas to ASCII commas before embedding.
        query = query.replace('\uff0c', ',')

        query_emb = get_embedding(engine=engine, text_list=[query])[query]
        query_result = self.ch.query(query_embeddings=[query_emb], n_results=self.limit,
                                     include=['metadatas', 'distances'])
        logger.debug(query_result)

        # Single query embedding -> every result list is at index 0.
        result = query_result['result']
        res = []
        for vertex, metadata, distance in zip(result['ids'][0],
                                              result['metadatas'][0],
                                              result['distances'][0]):
            if distance < MAX_DISTANCE:
                res.append({
                    "vertex": vertex,
                    "code_text": metadata['code_text'],
                })
        return res

    # Correctly spelled alias of search_by_desciption (whose misspelled name
    # is kept for existing callers).
    search_by_description = search_by_desciption

    def search_by_cypher(self, query: str):
        '''
        search by generating a Cypher query from the question and running it

        @param query: natural-language question
        @return: None when no Cypher could be generated;
                 {'cypher': '', 'cypher_res': ''} when execution failed;
                 otherwise {'cypher': <generated cypher>,
                            'cypher_res': <raw execute_cypher result>}
        '''
        cg = CypherGenerator()
        cypher = cg.get_cypher(query)

        if not cypher:
            return None

        cypher_res = self.nh.execute_cypher(cypher, self.nh.space_name)
        logger.info(f'cypher execution result={cypher_res}')

        if not cypher_res.is_succeeded():
            return {
                'cypher': '',
                'cypher_res': ''
            }

        return {
            'cypher': cypher,
            'cypher_res': cypher_res
        }
|
||
|
||
|
||
if __name__ == '__main__':
    # Manual smoke test against the NebulaGraph / Chroma services configured
    # in configs.server_config, using the 'testing' codebase.
    from configs.server_config import NEBULA_HOST, NEBULA_PORT, NEBULA_USER, NEBULA_PASSWORD, NEBULA_STORAGED_PORT
    from configs.server_config import CHROMA_PERSISTENT_PATH

    codebase_name = 'testing'

    nh = NebulaHandler(host=NEBULA_HOST, port=NEBULA_PORT, username=NEBULA_USER,
                       password=NEBULA_PASSWORD, space_name=codebase_name)
    nh.add_host(NEBULA_HOST, NEBULA_STORAGED_PORT)
    # NOTE(review): short pause after add_host, presumably to let the storage
    # host register before it is queried — confirm.
    time.sleep(0.5)

    ch = ChromaHandler(path=CHROMA_PERSISTENT_PATH, collection_name=codebase_name)

    cs = CodeSearch(nh, ch)

    # Other strategies (uncomment to try); signatures match CodeSearch:
    # res = cs.search_by_tag('createFineTuneCompletion OpenAiApi')
    # logger.debug(res)

    # res = cs.search_by_cypher('代码中一共有多少个类')
    # logger.debug(res)

    res = cs.search_by_desciption('使用不同的HTTP请求类型(GET、POST、DELETE等)来执行不同的操作', 'openai')
    logger.debug(res)
|