# encoding: utf-8
'''
@author: 温进
@file: code_search.py
@time: 2023/11/21 2:35 PM
@desc: Code retrieval over a graph database (Nebula) and a vector store
       (Chroma): search by tag, by description similarity, or by generated cypher.
'''
import time
from loguru import logger
from collections import defaultdict

from coagent.db_handler.graph_db_handler.nebula_handler import NebulaHandler
from coagent.db_handler.vector_db_handler.chroma_handler import ChromaHandler

from coagent.codechat.code_search.cypher_generator import CypherGenerator
from coagent.codechat.code_search.tagger import Tagger
from coagent.embeddings.get_embedding import get_embedding
from coagent.llm_models.llm_config import LLMConfig
# from configs.model_config import EMBEDDING_DEVICE, EMBEDDING_MODEL

# search_by_tag
VERTEX_SCORE = 10
HISTORY_VERTEX_SCORE = 5
VERTEX_MERGE_RATIO = 0.5

# search_by_description
MAX_DISTANCE = 1000


class CodeSearch:
    def __init__(self, llm_config: LLMConfig, nh: NebulaHandler, ch: ChromaHandler, limit: int = 3):
        '''
        init
        @param llm_config: LLMConfig, only read by search_by_cypher
        @param nh: NebulaHandler
        @param ch: ChromaHandler
        @param limit: limit of result
        '''
        self.llm_config = llm_config
        self.nh = nh
        self.ch = ch
        self.limit = limit

    def _get_container_id(self, container_tag: str, vertex: str):
        '''
        get the id of the first vertex tagged `container_tag` that has a `contain` edge to `vertex`
        @param container_tag: tag of the containing vertex, e.g. 'class' or 'package'
        @param vertex: id of the contained vertex
        @return: id string of the first containing vertex, or None if the query returns none
        '''
        cypher = f'''MATCH (v1:{container_tag})-[e:contain]->(v2) WHERE id(v2) == '{vertex}' RETURN id(v1) as id;'''
        cypher_res = self.nh.execute_cypher(cypher=cypher, format_res=True)
        containers = cypher_res.get('id', [])
        if not containers:
            return None
        return containers[0].as_string()

    def search_by_tag(self, query: str) -> list:
        '''
        search_code_res by tag
        @param query: str
        @return: list of {"vertex": package id, "code_text": str}, at most self.limit items,
                 ordered by descending package score
        '''
        tagger = Tagger()
        tag_list = tagger.generate_tag_query(query)
        logger.info(f'query tag={tag_list}')

        # get all vertices
        vertex_list = self.nh.get_vertices().get('v', [])
        vertex_vid_list = [i.as_node().get_id().as_string() for i in vertex_list]

        # update score: every tag that occurs in a vertex id adds VERTEX_SCORE
        vertex_score_dict = defaultdict(int)
        for vid in vertex_vid_list:
            for tag in tag_list:
                if tag in vid:
                    vertex_score_dict[vid] += VERTEX_SCORE

        # merge depend adj score: add a share of each adjacent vertex's own score
        vertex_score_dict_final = {}
        for vertex in vertex_score_dict:
            cypher = f'''MATCH (v1)-[e]-(v2) where id(v1) == "{vertex}" RETURN v2'''
            cypher_res = self.nh.execute_cypher(cypher, self.nh.space_name)
            cypher_res_dict = self.nh.result_to_dict(cypher_res)

            adj_vertex_list = [i.as_node().get_id().as_string() for i in cypher_res_dict.get('v2', [])]

            # .get() keeps the defaultdict from growing while it is being iterated
            score = vertex_score_dict.get(vertex, 0)
            for adj_vertex in adj_vertex_list:
                score += vertex_score_dict.get(adj_vertex, 0) * VERTEX_MERGE_RATIO

            if score > 0:
                vertex_score_dict_final[vertex] = score

        # get most prominent package tag
        # the merged scores are used here; previously the un-merged dict was iterated,
        # which made the adjacency merge above dead code
        package_score_dict = defaultdict(int)
        for vertex, score in vertex_score_dict_final.items():
            if '#' in vertex:
                # a '#' in the id is handled as a class member: get class name first
                vertex = self._get_container_id('class', vertex)
                if vertex is None:
                    continue

            # get package name
            package = self._get_container_id('package', vertex)
            if package is None:
                continue

            package_score_dict[package] += score

        # get respective code
        res = []
        package_score_tuple = sorted(package_score_dict.items(), key=lambda x: x[1], reverse=True)

        ids = [i[0] for i in package_score_tuple]
        logger.info(f'ids={ids}')
        if not ids:
            # nothing scored; do not query the vector store with an empty id list
            logger.info(f'retrival code={res}')
            return res

        chroma_res = self.ch.get(ids=ids, include=['metadatas'])
        # logger.info(chroma_res)
        chroma_ids = chroma_res['result']['ids']
        chroma_metadatas = chroma_res['result']['metadatas']

        # first position of every returned id (same choice list.index would make)
        id_to_index = {}
        for index, chroma_id in enumerate(chroma_ids):
            id_to_index.setdefault(chroma_id, index)

        for vertex, score in package_score_tuple:
            index = id_to_index.get(vertex)
            if index is None:
                # the package exists in the graph but was not returned by the vector store
                logger.warning(f'vertex={vertex} not found in chroma, skipped')
                continue

            code_text = chroma_metadatas[index]['code_text']
            res.append({
                "vertex": vertex,
                "code_text": code_text}
            )
            if len(res) >= self.limit:
                break

        logger.info(f'retrival code={res}')
        return res

    def search_by_desciption(self, query: str, engine: str, model_path: str = "text2vec-base-chinese", embedding_device: str = "cpu") -> list:
        '''
        search by perform sim search
        @param query: description text to embed
        @param engine: passed through to get_embedding
        @param model_path: passed through to get_embedding
        @param embedding_device: passed through to get_embedding
        @return: list of {"vertex": id, "code_text": str} whose distance is below MAX_DISTANCE
        '''
        # normalize full-width commas (U+FF0C) to ASCII commas before embedding;
        # the previous call replaced ',' with ',' and was a no-op
        # NOTE(review): full-width comma is the assumed intent - confirm
        query = query.replace('\uff0c', ',')
        query_emb = get_embedding(engine=engine, text_list=[query], model_path=model_path, embedding_device=embedding_device,)
        # get_embedding result is indexed by the input text
        query_emb = query_emb[query]

        query_embeddings = [query_emb]
        query_result = self.ch.query(query_embeddings=query_embeddings, n_results=self.limit,
                                     include=['metadatas', 'distances'])

        logger.debug(query_result)

        # one query embedding was sent, so only the first row of each field is read
        result = query_result['result']
        res = []
        for idx, distance in enumerate(result['distances'][0]):
            if distance < MAX_DISTANCE:
                vertex = result['ids'][0][idx]
                code_text = result['metadatas'][0][idx]['code_text']
                res.append({
                    "vertex": vertex,
                    "code_text": code_text
                })

        return res

    def search_by_cypher(self, query: str):
        '''
        search by generating cypher
        @param query: natural language question handed to CypherGenerator
        @return: None if no cypher was generated;
                 {'cypher': '', 'cypher_res': ''} if the execution did not succeed;
                 otherwise {'cypher': generated cypher, 'cypher_res': execution result}
        '''
        cg = CypherGenerator(self.llm_config)
        cypher = cg.get_cypher(query)

        if not cypher:
            return None

        cypher_res = self.nh.execute_cypher(cypher, self.nh.space_name)
        logger.info(f'cypher execution result={cypher_res}')
        if not cypher_res.is_succeeded():
            return {
                'cypher': '',
                'cypher_res': ''
            }

        return {
            'cypher': cypher,
            'cypher_res': cypher_res
        }


if __name__ == '__main__':
    # from configs.server_config import NEBULA_HOST, NEBULA_PORT, NEBULA_USER, NEBULA_PASSWORD, NEBULA_STORAGED_PORT
    # from configs.server_config import CHROMA_PERSISTENT_PATH
    from coagent.base_configs.env_config import (
        NEBULA_HOST, NEBULA_PORT, NEBULA_USER, NEBULA_PASSWORD, NEBULA_STORAGED_PORT,
        CHROMA_PERSISTENT_PATH
    )
    codebase_name = 'testing'

    nh = NebulaHandler(host=NEBULA_HOST, port=NEBULA_PORT, username=NEBULA_USER,
                       password=NEBULA_PASSWORD, space_name=codebase_name)
    nh.add_host(NEBULA_HOST, NEBULA_STORAGED_PORT)
    time.sleep(0.5)

    ch = ChromaHandler(path=CHROMA_PERSISTENT_PATH, collection_name=codebase_name)

    # llm_config is only read by search_by_cypher, so None is enough for the
    # description search below; pass a real LLMConfig before enabling the cypher example
    cs = CodeSearch(llm_config=None, nh=nh, ch=ch)
    # res = cs.search_by_tag('createFineTuneCompletion OpenAiApi')
    # logger.debug(res)

    # example question: "how many classes are there in the code in total"
    # res = cs.search_by_cypher('代码中一共有多少个类')
    # logger.debug(res)

    res = cs.search_by_desciption('使用不同的HTTP请求类型(GET、POST、DELETE等)来执行不同的操作', 'openai')
    logger.debug(res)