# encoding: utf-8
'''
@author: 温进
@file: code_search.py
@time: 2023/11/21 2:35 PM
@desc:
'''

import json
import time
from collections import defaultdict

from loguru import logger

from coagent.db_handler.graph_db_handler.nebula_handler import NebulaHandler
from coagent.db_handler.vector_db_handler.chroma_handler import ChromaHandler
from coagent.codechat.code_search.cypher_generator import CypherGenerator
from coagent.codechat.code_search.tagger import Tagger
from coagent.embeddings.get_embedding import get_embedding
from coagent.llm_models.llm_config import LLMConfig, EmbedConfig

# from configs.model_config import EMBEDDING_DEVICE, EMBEDDING_MODEL

# search_by_tag
VERTEX_SCORE = 10           # score granted when a query tag appears in a vertex id
HISTORY_VERTEX_SCORE = 5    # score for historical matches (not referenced in this module)
VERTEX_MERGE_RATIO = 0.5    # fraction of an adjacent vertex's score merged into its neighbour

# search_by_description
MAX_DISTANCE = 1000         # maximum vector distance accepted as a similarity match


class CodeSearch:
    def __init__(self, llm_config: LLMConfig, nh: NebulaHandler, ch: ChromaHandler, limit: int = 3,
                 local_graph_file_path: str = ''):
        '''
        init
        @param llm_config: LLMConfig used by the cypher generator
        @param nh: NebulaHandler; when empty, a local graph file is used instead
        @param ch: ChromaHandler
        @param limit: limit of result
        @param local_graph_file_path: path to a JSON graph file, used only when no NebulaHandler is given
        '''
        self.llm_config = llm_config
        self.nh = nh

        # fall back to a locally exported code graph when no Nebula connection is available
        if not self.nh:
            with open(local_graph_file_path, 'r') as f:
                self.graph = json.load(f)

        self.ch = ch
        self.limit = limit
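
    # The local graph file loaded above is consumed by search_by_tag_by_graph, which expects each
    # entry to provide 'pac_name', 'class_name_list' and 'func_name_dict'. A minimal sketch of the
    # expected JSON, inferred from that method (field values are illustrative only):
    #
    # {
    #     "<code_key>": {
    #         "pac_name": "com.example.demo.HttpClient.java",
    #         "class_name_list": ["HttpClient"],
    #         "func_name_dict": {"HttpClient": ["get", "post", "delete"]}
    #     }
    # }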

    def search_by_tag(self, query: str):
        '''
        search code by tag
        @param query: str
        @return:
        '''
        tagger = Tagger()
        tag_list = tagger.generate_tag_query(query)
        logger.info(f'query tag={tag_list}')

        # get all vertices
        vertex_list = self.nh.get_vertices().get('v', [])
        vertex_vid_list = [i.as_node().get_id().as_string() for i in vertex_list]

        # update score: a vertex gains VERTEX_SCORE for every query tag contained in its id
        vertex_score_dict = defaultdict(lambda: 0)
        for vid in vertex_vid_list:
            for tag in tag_list:
                if tag in vid:
                    vertex_score_dict[vid] += VERTEX_SCORE

        # merge scores from adjacent vertices, weighted by VERTEX_MERGE_RATIO
        vertex_score_dict_final = {}
        for vertex in vertex_score_dict:
            cypher = f'''MATCH (v1)-[e]-(v2) where id(v1) == "{vertex}" RETURN v2'''
            cypher_res = self.nh.execute_cypher(cypher, self.nh.space_name)
            cypher_res_dict = self.nh.result_to_dict(cypher_res)

            adj_vertex_list = [i.as_node().get_id().as_string() for i in cypher_res_dict.get('v2', [])]

            score = vertex_score_dict.get(vertex, 0)
            for adj_vertex in adj_vertex_list:
                score += vertex_score_dict.get(adj_vertex, 0) * VERTEX_MERGE_RATIO

            if score > 0:
                vertex_score_dict_final[vertex] = score

        # roll scores up to the most prominent package vertex
        package_score_dict = defaultdict(lambda: 0)

        for vertex, score in vertex_score_dict_final.items():
            if '#' in vertex:
                # get the containing class first
                cypher = f'''MATCH (v1:class)-[e:contain]->(v2) WHERE id(v2) == '{vertex}' RETURN id(v1) as id;'''
                cypher_res = self.nh.execute_cypher(cypher=cypher, format_res=True)
                class_vertices = cypher_res.get('id', [])
                if not class_vertices:
                    continue

                vertex = class_vertices[0].as_string()

            # get the containing package
            cypher = f'''MATCH (v1:package)-[e:contain]->(v2) WHERE id(v2) == '{vertex}' RETURN id(v1) as id;'''
            cypher_res = self.nh.execute_cypher(cypher=cypher, format_res=True)
            pac_vertices = cypher_res.get('id', [])
            if not pac_vertices:
                continue

            package = pac_vertices[0].as_string()
            package_score_dict[package] += score

        # fetch the corresponding code from Chroma, best-scoring packages first
        res = []
        package_score_tuple = list(package_score_dict.items())
        package_score_tuple.sort(key=lambda x: x[1], reverse=True)

        ids = [i[0] for i in package_score_tuple]
        logger.info(f'ids={ids}')
        chroma_res = self.ch.get(ids=ids, include=['metadatas'])

        for vertex, score in package_score_tuple:
            index = chroma_res['result']['ids'].index(vertex)
            code_text = chroma_res['result']['metadatas'][index]['code_text']
            res.append({
                "vertex": vertex,
                "code_text": code_text
            })
            if len(res) >= self.limit:
                break

        # logger.info(f'retrieval code={res}')
        return res
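
    # Worked example of the tag scoring above (numbers are illustrative): a method vertex whose id
    # contains one query tag starts at VERTEX_SCORE = 10; if one adjacent vertex also scored 10,
    # it contributes 10 * VERTEX_MERGE_RATIO = 5, so the vertex ends at 15 before its score is
    # rolled up into the containing package.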

    def search_by_tag_by_graph(self, query: str):
        '''
        search code by tag against the locally loaded graph
        @param query:
        @return:
        '''
        tagger = Tagger()
        tag_list = tagger.generate_tag_query(query)
        logger.info(f'query tag={tag_list}')

        # score each package node by how many tags hit its class and function names
        package_score_dict = {}
        for code, structure in self.graph.items():
            score = 0
            for class_name in structure['class_name_list']:
                for tag in tag_list:
                    if tag.lower() in class_name.lower():
                        score += 1

            for func_name_list in structure['func_name_dict'].values():
                for func_name in func_name_list:
                    for tag in tag_list:
                        if tag.lower() in func_name.lower():
                            score += 1
            package_score_dict[structure['pac_name']] = score

        # fetch the corresponding code from Chroma, best-scoring packages first
        res = []
        package_score_tuple = list(package_score_dict.items())
        package_score_tuple.sort(key=lambda x: x[1], reverse=True)

        ids = [i[0] for i in package_score_tuple]
        logger.info(f'ids={ids}')
        chroma_res = self.ch.get(ids=ids, include=['metadatas'])

        # logger.info(chroma_res)
        for vertex, score in package_score_tuple:
            index = chroma_res['result']['ids'].index(vertex)
            code_text = chroma_res['result']['metadatas'][index]['code_text']
            res.append({
                "vertex": vertex,
                "code_text": code_text
            })
            if len(res) >= self.limit:
                break

        # logger.info(f'retrieval code={res}')
        return res
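
    # Minimal usage sketch (setup is assumed, not taken from this file): run tag search without a
    # Nebula connection by relying on the local graph file loaded in __init__:
    #   cs = CodeSearch(llm_config, nh=None, ch=ch, local_graph_file_path='code_graph.json')
    #   res = cs.search_by_tag_by_graph('how are HTTP requests sent')
    #   # -> [{'vertex': <package id>, 'code_text': <code>}, ...], capped at `limit`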

    def search_by_desciption(self, query: str, engine: str, model_path: str = "text2vec-base-chinese",
                             embedding_device: str = "cpu", embed_config: EmbedConfig = None):
        '''
        search code by embedding similarity to the query description
        @param query:
        @return:
        '''
        # normalise full-width commas before embedding
        query = query.replace(',', ',')
        query_emb = get_embedding(engine=engine, text_list=[query], model_path=model_path,
                                  embedding_device=embedding_device, embed_config=embed_config)
        # get_embedding returns a mapping from input text to its embedding
        query_emb = query_emb[query]

        query_embeddings = [query_emb]
        query_result = self.ch.query(query_embeddings=query_embeddings, n_results=self.limit,
                                     include=['metadatas', 'distances'])

        res = []
        for idx, distance in enumerate(query_result['result']['distances'][0]):
            if distance < MAX_DISTANCE:
                vertex = query_result['result']['ids'][0][idx]
                code_text = query_result['result']['metadatas'][0][idx]['code_text']
                res.append({
                    "vertex": vertex,
                    "code_text": code_text
                })

        return res
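
    # Minimal usage sketch (arguments are assumptions, not from this file):
    #   embed_config = EmbedConfig(...)   # embedding backend settings
    #   res = cs.search_by_desciption('send an HTTP GET request', engine='openai', embed_config=embed_config)
    #   # only hits whose vector distance is below MAX_DISTANCE are returned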

    def search_by_cypher(self, query: str):
        '''
        search by generating a cypher query from natural language
        @param query:
        @return:
        '''
        cg = CypherGenerator(self.llm_config)
        cypher = cg.get_cypher(query)

        if not cypher:
            return None

        cypher_res = self.nh.execute_cypher(cypher, self.nh.space_name)
        logger.info(f'cypher execution result={cypher_res}')
        if not cypher_res.is_succeeded():
            return {
                'cypher': '',
                'cypher_res': ''
            }

        res = {
            'cypher': cypher,
            'cypher_res': cypher_res
        }

        return res
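

# Notes on return shapes (derived from the methods above): search_by_tag, search_by_tag_by_graph and
# search_by_desciption return a list of {'vertex', 'code_text'} dicts; search_by_cypher returns either
# None (no cypher generated) or {'cypher': <generated statement>, 'cypher_res': <Nebula ResultSet>}.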


if __name__ == '__main__':
    # from configs.server_config import NEBULA_HOST, NEBULA_PORT, NEBULA_USER, NEBULA_PASSWORD, NEBULA_STORAGED_PORT
    # from configs.server_config import CHROMA_PERSISTENT_PATH
    from coagent.base_configs.env_config import (
        NEBULA_HOST, NEBULA_PORT, NEBULA_USER, NEBULA_PASSWORD, NEBULA_STORAGED_PORT,
        CHROMA_PERSISTENT_PATH
    )

    codebase_name = 'testing'

    nh = NebulaHandler(host=NEBULA_HOST, port=NEBULA_PORT, username=NEBULA_USER,
                       password=NEBULA_PASSWORD, space_name=codebase_name)
    nh.add_host(NEBULA_HOST, NEBULA_STORAGED_PORT)
    time.sleep(0.5)

    ch = ChromaHandler(path=CHROMA_PERSISTENT_PATH, collection_name=codebase_name)

    # NOTE: __init__ expects an LLMConfig as its first argument; it is only used by
    # search_by_cypher, so a placeholder is passed here for the similarity-search smoke test.
    cs = CodeSearch(llm_config=None, nh=nh, ch=ch)

    # res = cs.search_by_tag(tag_list=['createFineTuneCompletion', 'OpenAiApi'])
    # logger.debug(res)

    # res = cs.search_by_cypher('代码中一共有多少个类', 'openai')  # query: "how many classes are there in the code?"
    # logger.debug(res)

    # query: "use different HTTP request types (GET, POST, DELETE, etc.) to perform different operations"
    res = cs.search_by_desciption('使用不同的HTTP请求类型(GET、POST、DELETE等)来执行不同的操作', 'openai')
    logger.debug(res)