[feature](coagent) <add antflow compatibility and a coagent demo>

parent c14b41ecec
commit 4d9b268a98
@@ -26,9 +26,12 @@ JUPYTER_WORK_PATH = os.environ.get("JUPYTER_WORK_PATH", None) or os.path.join(ex
 WEB_CRAWL_PATH = os.environ.get("WEB_CRAWL_PATH", None) or os.path.join(executable_path, "knowledge_base")
 
 # NEBULA_DATA storage path
-NELUBA_PATH = os.environ.get("NELUBA_PATH", None) or os.path.join(executable_path, "data/neluba_data")
+NEBULA_PATH = os.environ.get("NEBULA_PATH", None) or os.path.join(executable_path, "data/nebula_data")
 
-for _path in [LOG_PATH, SOURCE_PATH, KB_ROOT_PATH, NLTK_DATA_PATH, JUPYTER_WORK_PATH, WEB_CRAWL_PATH, NELUBA_PATH]:
+# CHROMA storage path
+CHROMA_PERSISTENT_PATH = os.environ.get("CHROMA_PERSISTENT_PATH", None) or os.path.join(executable_path, "data/chroma_data")
+
+for _path in [LOG_PATH, SOURCE_PATH, KB_ROOT_PATH, CB_ROOT_PATH, NLTK_DATA_PATH, JUPYTER_WORK_PATH, WEB_CRAWL_PATH, NEBULA_PATH, CHROMA_PERSISTENT_PATH]:
     if not os.path.exists(_path):
         os.makedirs(_path, exist_ok=True)
 
@@ -58,7 +61,8 @@ NEBULA_GRAPH_SERVER = {
 }
 
 # CHROMA CONFIG
-CHROMA_PERSISTENT_PATH = '/home/user/chatbot/data/chroma_data'
+# CHROMA_PERSISTENT_PATH = '/home/user/chatbot/data/chroma_data'
+# CHROMA_PERSISTENT_PATH = '/Users/bingxu/Desktop/工作/大模型/chatbot/codefuse-chatbot-antcode/data/chroma_data'
 
 
 # Default vector store type. Options: faiss, milvus, pg.
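Reviewer note: the Chroma path now comes from the environment (with a data/chroma_data default alongside the other storage paths) instead of a hard-coded absolute path. A minimal sketch of overriding it, assuming this hunk lives in coagent/base_configs/env_config.py (consistent with the imports later in this commit) and that the variable is set before the module is first imported:

    import os
    os.environ["CHROMA_PERSISTENT_PATH"] = "/tmp/my_chroma_data"  # hypothetical location

    from coagent.base_configs.env_config import CHROMA_PERSISTENT_PATH
    print(CHROMA_PERSISTENT_PATH)  # -> /tmp/my_chroma_data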
@@ -7,7 +7,7 @@ from langchain import LLMChain
 from langchain.callbacks import AsyncIteratorCallbackHandler
 from langchain.prompts.chat import ChatPromptTemplate
 
-from coagent.llm_models import getChatModel, getChatModelFromConfig
+from coagent.llm_models import getChatModelFromConfig
 from coagent.chat.utils import History, wrap_done
 from coagent.llm_models.llm_config import LLMConfig, EmbedConfig
 # from configs.model_config import (llm_model_dict, LLM_MODEL, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD)
@@ -22,7 +22,7 @@ from coagent.connector.configs.prompts import CODE_PROMPT_TEMPLATE
 from coagent.chat.utils import History, wrap_done
 from coagent.utils import BaseResponse
 from .base_chat import Chat
-from coagent.llm_models import getChatModel, getChatModelFromConfig
+from coagent.llm_models import getChatModelFromConfig
 
 from coagent.llm_models.llm_config import LLMConfig, EmbedConfig
 
@@ -67,6 +67,7 @@ class CodeChat(Chat):
                                 embed_model_path=embed_config.embed_model_path,
                                 embed_engine=embed_config.embed_engine,
                                 model_device=embed_config.model_device,
+                                embed_config=embed_config
                                 )
 
         context = codes_res['context']
@@ -12,7 +12,7 @@ from langchain.schema import (
 
 # from configs.model_config import CODE_INTERPERT_TEMPLATE
 from coagent.connector.configs.prompts import CODE_INTERPERT_TEMPLATE
-from coagent.llm_models.openai_model import getChatModel, getChatModelFromConfig
+from coagent.llm_models.openai_model import getChatModelFromConfig
 from coagent.llm_models.llm_config import LLMConfig
 
 
@@ -53,9 +53,15 @@ class CodeIntepreter:
             message = CODE_INTERPERT_TEMPLATE.format(code=code)
             messages.append(message)
 
-        chat_ress = chat_model.batch(messages)
+        try:
+            chat_ress = [chat_model(message) for message in messages]
+        except Exception:
+            chat_ress = chat_model.batch(messages)
         for chat_res, code in zip(chat_ress, code_list):
-            res[code] = chat_res.content
+            try:
+                res[code] = chat_res.content
+            except Exception:
+                res[code] = chat_res
         return res
 
 
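Reviewer note: the new branch probes per-message invocation first and only falls back to .batch(...), then tolerates results that are either message objects or plain strings; this is the antflow-compatibility seam. (The committed list comprehension passed `messages` instead of `message`; fixed above.) A compact sketch of the same fallback pattern, assuming a LangChain-style chat model:

    def interpret_all(chat_model, messages):
        # Prefer one call per prompt; some model wrappers do not implement .batch().
        try:
            results = [chat_model(message) for message in messages]
        except Exception:
            results = chat_model.batch(messages)
        # Each result may be an AIMessage-like object (.content) or a raw string.
        return [getattr(r, "content", r) for r in results]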
@@ -27,7 +27,7 @@ class DirCrawler:
         logger.info(java_file_list)
 
         for java_file in java_file_list:
-            with open(java_file) as f:
+            with open(java_file, encoding="utf-8") as f:
                 java_code = ''.join(f.readlines())
                 java_code_dict[java_file] = java_code
         return java_code_dict
@@ -5,6 +5,7 @@
 @time: 2023/11/21 2:35 PM
 @desc:
 '''
+import json
 import time
 from loguru import logger
 from collections import defaultdict
@@ -15,7 +16,7 @@ from coagent.db_handler.vector_db_handler.chroma_handler import ChromaHandler
 from coagent.codechat.code_search.cypher_generator import CypherGenerator
 from coagent.codechat.code_search.tagger import Tagger
 from coagent.embeddings.get_embedding import get_embedding
-from coagent.llm_models.llm_config import LLMConfig
+from coagent.llm_models.llm_config import LLMConfig, EmbedConfig
 
 
 # from configs.model_config import EMBEDDING_DEVICE, EMBEDDING_MODEL
@@ -29,7 +30,8 @@ MAX_DISTANCE = 1000
 
 
 class CodeSearch:
-    def __init__(self, llm_config: LLMConfig, nh: NebulaHandler, ch: ChromaHandler, limit: int = 3):
+    def __init__(self, llm_config: LLMConfig, nh: NebulaHandler, ch: ChromaHandler, limit: int = 3,
+                 local_graph_file_path: str = ''):
         '''
         init
         @param nh: NebulaHandler
@@ -37,7 +39,13 @@ class CodeSearch:
         @param limit: limit of result
         '''
         self.llm_config = llm_config
+
         self.nh = nh
+
+        if not self.nh:
+            with open(local_graph_file_path, 'r') as f:
+                self.graph = json.load(f)
+
         self.ch = ch
         self.limit = limit
 
@@ -51,7 +59,7 @@ class CodeSearch:
         tag_list = tagger.generate_tag_query(query)
         logger.info(f'query tag={tag_list}')
 
-        # get all verticex
+        # get all vertices
         vertex_list = self.nh.get_vertices().get('v', [])
         vertex_vid_list = [i.as_node().get_id().as_string() for i in vertex_list]
 
@@ -81,7 +89,7 @@ class CodeSearch:
         # get most prominent package tag
         package_score_dict = defaultdict(lambda: 0)
 
-        for vertex, score in vertex_score_dict.items():
+        for vertex, score in vertex_score_dict_final.items():
             if '#' in vertex:
                 # get class name first
                 cypher = f'''MATCH (v1:class)-[e:contain]->(v2) WHERE id(v2) == '{vertex}' RETURN id(v1) as id;'''
@@ -111,6 +119,53 @@ class CodeSearch:
         logger.info(f'ids={ids}')
         chroma_res = self.ch.get(ids=ids, include=['metadatas'])
 
+        for vertex, score in package_score_tuple:
+            index = chroma_res['result']['ids'].index(vertex)
+            code_text = chroma_res['result']['metadatas'][index]['code_text']
+            res.append({
+                "vertex": vertex,
+                "code_text": code_text}
+            )
+            if len(res) >= self.limit:
+                break
+        # logger.info(f'retrieval code={res}')
+        return res
+
+    def search_by_tag_by_graph(self, query: str):
+        '''
+        search code by tag using the local graph
+        @param query:
+        @return:
+        '''
+        tagger = Tagger()
+        tag_list = tagger.generate_tag_query(query)
+        logger.info(f'query tag={tag_list}')
+
+        # loop over files to score each package node
+        package_score_dict = {}
+        for code, structure in self.graph.items():
+            score = 0
+            for class_name in structure['class_name_list']:
+                for tag in tag_list:
+                    if tag.lower() in class_name.lower():
+                        score += 1
+
+            for func_name_list in structure['func_name_dict'].values():
+                for func_name in func_name_list:
+                    for tag in tag_list:
+                        if tag.lower() in func_name.lower():
+                            score += 1
+            package_score_dict[structure['pac_name']] = score
+
+        # get the corresponding code
+        res = []
+        package_score_tuple = list(package_score_dict.items())
+        package_score_tuple.sort(key=lambda x: x[1], reverse=True)
+
+        ids = [i[0] for i in package_score_tuple]
+        logger.info(f'ids={ids}')
+        chroma_res = self.ch.get(ids=ids, include=['metadatas'])
+
         # logger.info(chroma_res)
         for vertex, score in package_score_tuple:
             index = chroma_res['result']['ids'].index(vertex)
@@ -121,23 +176,22 @@ class CodeSearch:
             )
             if len(res) >= self.limit:
                 break
-        logger.info(f'retrival code={res}')
+        # logger.info(f'retrieval code={res}')
         return res
 
-    def search_by_desciption(self, query: str, engine: str, model_path: str = "text2vec-base-chinese", embedding_device: str = "cpu"):
+    def search_by_desciption(self, query: str, engine: str, model_path: str = "text2vec-base-chinese", embedding_device: str = "cpu", embed_config: EmbedConfig = None):
         '''
         search by performing a similarity search
         @param query:
         @return:
         '''
         query = query.replace('，', ',')
-        query_emb = get_embedding(engine=engine, text_list=[query], model_path=model_path, embedding_device= embedding_device,)
+        query_emb = get_embedding(engine=engine, text_list=[query], model_path=model_path, embedding_device=embedding_device, embed_config=embed_config)
         query_emb = query_emb[query]
 
         query_embeddings = [query_emb]
         query_result = self.ch.query(query_embeddings=query_embeddings, n_results=self.limit,
                                      include=['metadatas', 'distances'])
-        logger.debug(query_result)
 
         res = []
         for idx, distance in enumerate(query_result['result']['distances'][0]):
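Reviewer note: search_by_tag_by_graph assumes the local graph JSON (written by CodeImporter when Nebula is unavailable, see below) maps each source file to its static-analysis record with pac_name, class_name_list, and func_name_dict keys. A hand-built illustration of that shape; all names are hypothetical:

    local_graph = {
        "src/main/java/com/example/util/StringUtil.java": {
            "pac_name": "com.example.util.StringUtil",
            "class_name_list": ["StringUtil"],
            "func_name_dict": {"StringUtil": ["remove", "isBlank"]},
        },
    }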
@@ -8,7 +8,7 @@
 from langchain import PromptTemplate
 from loguru import logger
 
-from coagent.llm_models.openai_model import getChatModel, getChatModelFromConfig
+from coagent.llm_models.openai_model import getChatModelFromConfig
 from coagent.llm_models.llm_config import LLMConfig
 from coagent.utils.postprocess import replace_lt_gt
 from langchain.schema import (
@@ -6,11 +6,10 @@
 @desc:
 '''
 import time
+import json
+import os
 from loguru import logger
 
-# from configs.server_config import NEBULA_HOST, NEBULA_PORT, NEBULA_USER, NEBULA_PASSWORD, NEBULA_STORAGED_PORT
-# from configs.server_config import CHROMA_PERSISTENT_PATH
-# from configs.model_config import EMBEDDING_DEVICE, EMBEDDING_MODEL
 from coagent.db_handler.graph_db_handler.nebula_handler import NebulaHandler
 from coagent.db_handler.vector_db_handler.chroma_handler import ChromaHandler
 from coagent.embeddings.get_embedding import get_embedding
@@ -18,12 +17,14 @@ from coagent.llm_models.llm_config import EmbedConfig
 
 
 class CodeImporter:
-    def __init__(self, codebase_name: str, embed_config: EmbedConfig, nh: NebulaHandler, ch: ChromaHandler):
+    def __init__(self, codebase_name: str, embed_config: EmbedConfig, nh: NebulaHandler, ch: ChromaHandler,
+                 local_graph_file_path: str):
         self.codebase_name = codebase_name
         # self.engine = engine
-        self.embed_config: EmbedConfig= embed_config
+        self.embed_config: EmbedConfig = embed_config
         self.nh = nh
         self.ch = ch
+        self.local_graph_file_path = local_graph_file_path
 
     def import_code(self, static_analysis_res: dict, interpretation: dict, do_interpret: bool = True):
         '''
@@ -31,9 +32,14 @@ class CodeImporter:
         @return:
         '''
         static_analysis_res = self.filter_out_vertex(static_analysis_res, interpretation)
-        logger.info(f'static_analysis_res={static_analysis_res}')
 
-        self.analysis_res_to_graph(static_analysis_res)
+        if self.nh:
+            self.analysis_res_to_graph(static_analysis_res)
+        else:
+            # persist to local dir
+            with open(self.local_graph_file_path, 'w') as f:
+                json.dump(static_analysis_res, f)
+
         self.interpretation_to_db(static_analysis_res, interpretation, do_interpret)
 
     def filter_out_vertex(self, static_analysis_res, interpretation):
@@ -114,12 +120,12 @@ class CodeImporter:
         # create vertex
         for tag_name, value_dict in vertex_value_dict.items():
             res = self.nh.insert_vertex(tag_name, value_dict)
-            logger.debug(res.error_msg())
+            # logger.debug(res.error_msg())
 
         # create edge
         for tag_name, value_dict in edge_value_dict.items():
             res = self.nh.insert_edge(tag_name, value_dict)
-            logger.debug(res.error_msg())
+            # logger.debug(res.error_msg())
 
         return
 
@@ -132,7 +138,7 @@ class CodeImporter:
         if do_interpret:
             logger.info('start get embedding for interpretation')
             interp_list = list(interpretation.values())
-            emb = get_embedding(engine=self.embed_config.embed_engine, text_list=interp_list, model_path=self.embed_config.embed_model_path, embedding_device=self.embed_config.model_device)
+            emb = get_embedding(engine=self.embed_config.embed_engine, text_list=interp_list, model_path=self.embed_config.embed_model_path, embedding_device=self.embed_config.model_device, embed_config=self.embed_config)
             logger.info('get embedding done')
         else:
             emb = {i: [0] for i in list(interpretation.values())}
@@ -161,7 +167,7 @@ class CodeImporter:
 
         # add documents to chroma
         res = self.ch.add_data(ids=ids, embeddings=embeddings, documents=documents, metadatas=metadatas)
-        logger.debug(res)
+        # logger.debug(res)
 
     def init_graph(self):
         '''
@@ -169,7 +175,7 @@ class CodeImporter:
         @return:
         '''
         res = self.nh.create_space(space_name=self.codebase_name, vid_type='FIXED_STRING(1024)')
-        logger.debug(res.error_msg())
+        # logger.debug(res.error_msg())
         time.sleep(5)
 
         self.nh.set_space_name(self.codebase_name)
@@ -179,29 +185,29 @@ class CodeImporter:
         tag_name = 'package'
         prop_dict = {}
         res = self.nh.create_tag(tag_name, prop_dict)
-        logger.debug(res.error_msg())
+        # logger.debug(res.error_msg())
 
         tag_name = 'class'
         prop_dict = {}
         res = self.nh.create_tag(tag_name, prop_dict)
-        logger.debug(res.error_msg())
+        # logger.debug(res.error_msg())
 
         tag_name = 'method'
         prop_dict = {}
         res = self.nh.create_tag(tag_name, prop_dict)
-        logger.debug(res.error_msg())
+        # logger.debug(res.error_msg())
 
         # create edge type
         edge_type_name = 'contain'
         prop_dict = {}
         res = self.nh.create_edge_type(edge_type_name, prop_dict)
-        logger.debug(res.error_msg())
+        # logger.debug(res.error_msg())
 
         # create edge type
         edge_type_name = 'depend'
         prop_dict = {}
         res = self.nh.create_edge_type(edge_type_name, prop_dict)
-        logger.debug(res.error_msg())
+        # logger.debug(res.error_msg())
 
 
 if __name__ == '__main__':
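Reviewer note: get_embedding calls now pass the whole EmbedConfig through alongside the individual fields. A minimal sketch of the updated call shape, mirroring the interpretation_to_db hunk above (the text_list value is illustrative):

    emb = get_embedding(
        engine=embed_config.embed_engine,
        text_list=["interpretation text"],          # illustrative input
        model_path=embed_config.embed_model_path,
        embedding_device=embed_config.model_device,
        embed_config=embed_config,                  # newly threaded through
    )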
@@ -5,16 +5,15 @@
 @time: 2023/11/21 2:25 PM
 @desc:
 '''
+import os
 import time
+import json
+from typing import List
 from loguru import logger
 
-# from configs.server_config import NEBULA_HOST, NEBULA_PORT, NEBULA_USER, NEBULA_PASSWORD, NEBULA_STORAGED_PORT
-# from configs.server_config import CHROMA_PERSISTENT_PATH
-# from configs.model_config import EMBEDDING_ENGINE
-
 from coagent.base_configs.env_config import (
     NEBULA_HOST, NEBULA_PORT, NEBULA_USER, NEBULA_PASSWORD, NEBULA_STORAGED_PORT,
-    CHROMA_PERSISTENT_PATH
+    CHROMA_PERSISTENT_PATH, CB_ROOT_PATH
 )
 
 
@@ -35,7 +34,9 @@ class CodeBaseHandler:
             language: str = 'java',
             crawl_type: str = 'ZIP',
             embed_config: EmbedConfig = EmbedConfig(),
-            llm_config: LLMConfig = LLMConfig()
+            llm_config: LLMConfig = LLMConfig(),
+            use_nh: bool = True,
+            local_graph_path: str = CB_ROOT_PATH
         ):
         self.codebase_name = codebase_name
         self.code_path = code_path
@@ -43,11 +44,28 @@ class CodeBaseHandler:
         self.crawl_type = crawl_type
         self.embed_config = embed_config
         self.llm_config = llm_config
+        self.local_graph_file_path = local_graph_path + os.sep + f'{self.codebase_name}_graph.json'
 
-        self.nh = NebulaHandler(host=NEBULA_HOST, port=NEBULA_PORT, username=NEBULA_USER,
-                                password=NEBULA_PASSWORD, space_name=codebase_name)
-        self.nh.add_host(NEBULA_HOST, NEBULA_STORAGED_PORT)
-        time.sleep(1)
+        if use_nh:
+            try:
+                self.nh = NebulaHandler(host=NEBULA_HOST, port=NEBULA_PORT, username=NEBULA_USER,
+                                        password=NEBULA_PASSWORD, space_name=codebase_name)
+                self.nh.add_host(NEBULA_HOST, NEBULA_STORAGED_PORT)
+                time.sleep(1)
+            except Exception:
+                self.nh = None
+                try:
+                    with open(self.local_graph_file_path, 'r') as f:
+                        self.graph = json.load(f)
+                except Exception:
+                    pass
+        elif local_graph_path:
+            self.nh = None
+            try:
+                with open(self.local_graph_file_path, 'r') as f:
+                    self.graph = json.load(f)
+            except Exception:
+                pass
 
         self.ch = ChromaHandler(path=CHROMA_PERSISTENT_PATH, collection_name=codebase_name)
 
@@ -58,9 +76,10 @@ class CodeBaseHandler:
         '''
         # init graph to init tag and edge
         code_importer = CodeImporter(embed_config=self.embed_config, codebase_name=self.codebase_name,
-                                     nh=self.nh, ch=self.ch)
-        code_importer.init_graph()
-        time.sleep(5)
+                                     nh=self.nh, ch=self.ch, local_graph_file_path=self.local_graph_file_path)
+        if self.nh:
+            code_importer.init_graph()
+            time.sleep(5)
 
         # crawl code
         st0 = time.time()
@@ -71,7 +90,7 @@ class CodeBaseHandler:
         # analyze code
         logger.info('start analyze')
         st1 = time.time()
-        code_analyzer = CodeAnalyzer(language=self.language, llm_config = self.llm_config)
+        code_analyzer = CodeAnalyzer(language=self.language, llm_config=self.llm_config)
         static_analysis_res, interpretation = code_analyzer.analyze(code_dict, do_interpret=do_interpret)
         logger.debug('analyze done, rt={}'.format(time.time() - st1))
 
@@ -81,8 +100,12 @@ class CodeBaseHandler:
         logger.debug('update codebase done, rt={}'.format(time.time() - st2))
 
         # get KG info
-        stat = self.nh.get_stat()
-        vertices_num, edges_num = stat['vertices'], stat['edges']
+        if self.nh:
+            stat = self.nh.get_stat()
+            vertices_num, edges_num = stat['vertices'], stat['edges']
+        else:
+            vertices_num = 0
+            edges_num = 0
 
         # get chroma info
         file_num = self.ch.count()['result']
@@ -95,7 +118,11 @@ class CodeBaseHandler:
         @param codebase_name: name of codebase
         @return:
         '''
-        self.nh.drop_space(space_name=codebase_name)
+        if self.nh:
+            self.nh.drop_space(space_name=codebase_name)
+        elif self.local_graph_file_path and os.path.isfile(self.local_graph_file_path):
+            os.remove(self.local_graph_file_path)
+
         self.ch.delete_collection(collection_name=codebase_name)
 
     def crawl_code(self, zip_file=''):
@@ -124,9 +151,15 @@ class CodeBaseHandler:
         @param search_type: ['cypher', 'tag', 'description']
         @return:
         '''
-        assert search_type in ['cypher', 'tag', 'description']
+        if self.nh:
+            assert search_type in ['cypher', 'tag', 'description']
+        else:
+            if search_type == 'tag':
+                search_type = 'tag_by_local_graph'
+            assert search_type in ['tag_by_local_graph', 'description']
 
-        code_search = CodeSearch(llm_config=self.llm_config, nh=self.nh, ch=self.ch, limit=limit)
+        code_search = CodeSearch(llm_config=self.llm_config, nh=self.nh, ch=self.ch, limit=limit,
+                                 local_graph_file_path=self.local_graph_file_path)
 
         if search_type == 'cypher':
             search_res = code_search.search_by_cypher(query=query)
@@ -134,7 +167,11 @@ class CodeBaseHandler:
             search_res = code_search.search_by_tag(query=query)
         elif search_type == 'description':
             search_res = code_search.search_by_desciption(
-                query=query, engine=self.embed_config.embed_engine, model_path=self.embed_config.embed_model_path, embedding_device=self.embed_config.model_device)
+                query=query, engine=self.embed_config.embed_engine, model_path=self.embed_config.embed_model_path,
+                embedding_device=self.embed_config.model_device, embed_config=self.embed_config)
+        elif search_type == 'tag_by_local_graph':
+            search_res = code_search.search_by_tag_by_graph(query=query)
+
 
         context, related_vertice = self.format_search_res(search_res, search_type)
         return context, related_vertice
@@ -160,6 +197,12 @@ class CodeBaseHandler:
             for code in search_res:
                 context = context + code['code_text'] + '\n'
                 related_vertice.append(code['vertex'])
+        elif search_type == 'tag_by_local_graph':
+            context = ''
+            related_vertice = []
+            for code in search_res:
+                context = context + code['code_text'] + '\n'
+                related_vertice.append(code['vertex'])
         elif search_type == 'description':
             context = ''
             related_vertice = []
@@ -169,17 +212,63 @@ class CodeBaseHandler:
 
         return context, related_vertice
 
+    def search_vertices(self, vertex_type="class") -> List[str]:
+        '''
+        search all vertices of the given type (method/class)
+        '''
+        vertices = []
+        if self.nh:
+            vertices = self.nh.get_all_vertices()
+            vertices = [str(v.as_node().get_id()) for v in vertices["v"] if vertex_type in v.as_node().tags()]
+            # for v in vertices["v"]:
+            #     logger.debug(f"{v.as_node().get_id()}, {v.as_node().tags()}")
+        else:
+            if vertex_type == "class":
+                vertices = [str(class_name) for code, structure in self.graph.items() for class_name in structure['class_name_list']]
+            elif vertex_type == "method":
+                vertices = [
+                    str(methods_name)
+                    for code, structure in self.graph.items()
+                    for methods_names in structure['func_name_dict'].values()
+                    for methods_name in methods_names
+                ]
+        # logger.debug(vertices)
+        return vertices
+
 
 if __name__ == '__main__':
-    codebase_name = 'testing'
+    from configs.model_config import KB_ROOT_PATH, JUPYTER_WORK_PATH
+    from configs.server_config import SANDBOX_SERVER
+
+    LLM_MODEL = "gpt-3.5-turbo"
+    llm_config = LLMConfig(
+        model_name=LLM_MODEL, model_device="cpu", api_key=os.environ["OPENAI_API_KEY"],
+        api_base_url=os.environ["API_BASE_URL"], temperature=0.3
+    )
+    src_dir = '/Users/bingxu/Desktop/工作/大模型/chatbot/Codefuse-chatbot-antcode'
+    embed_config = EmbedConfig(
+        embed_engine="model", embed_model="text2vec-base-chinese",
+        embed_model_path=os.path.join(src_dir, "embedding_models/text2vec-base-chinese")
+    )
+
+    codebase_name = 'client_local'
     code_path = '/Users/bingxu/Desktop/工作/大模型/chatbot/test_code_repo/client'
-    cbh = CodeBaseHandler(codebase_name, code_path, crawl_type='dir')
+    use_nh = False
+    local_graph_path = '/Users/bingxu/Desktop/工作/大模型/chatbot/Codefuse-chatbot-antcode/code_base'
+    CHROMA_PERSISTENT_PATH = '/Users/bingxu/Desktop/工作/大模型/chatbot/Codefuse-chatbot-antcode/data/chroma_data'
+
+    cbh = CodeBaseHandler(codebase_name, code_path, crawl_type='dir', use_nh=use_nh, local_graph_path=local_graph_path,
+                          llm_config=llm_config, embed_config=embed_config)
+
+    # test import code
+    # cbh.import_code(do_interpret=True)
 
     # query = '使用不同的HTTP请求类型(GET、POST、DELETE等)来执行不同的操作'
     # query = '代码中一共有多少个类'
+    # query = 'remove 这个函数是用来做什么的'
+    query = '有没有函数是从字符串中删除指定字符串的功能'
 
-    query = 'intercept 函数作用是什么'
-    search_type = 'graph'
+    search_type = 'description'
     limit = 2
     res = cbh.search_code(query, search_type, limit)
     logger.debug(res)
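Reviewer note: with use_nh=False (or when the Nebula connection fails), CodeBaseHandler runs against the persisted JSON graph, and search_code silently remaps search_type='tag' to 'tag_by_local_graph', so existing callers can keep passing 'tag'. A short usage sketch building on the demo above:

    # 'tag' is remapped internally to 'tag_by_local_graph' when self.nh is None.
    context, related_vertices = cbh.search_code(query, 'tag', limit)

    # New helper: list every class (or method) vertex from Nebula or the local graph.
    class_vertices = cbh.search_vertices(vertex_type="class")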

coagent/connector/actions/__init__.py (new file)
@@ -0,0 +1,6 @@
+from .base_action import BaseAction
+
+
+__all__ = [
+    "BaseAction"
+]

coagent/connector/actions/base_action.py (new file)
@@ -0,0 +1,16 @@
+
+from langchain.schema import BaseRetriever, Document
+
+class BaseAction:
+
+
+    def __init__(self, ):
+        pass
+
+    def step(self, ):
+        pass
+
+    def astep(self, ):
+        pass
+
+
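Reviewer note: BaseAction is introduced as an empty extension point for the antflow integration (step/astep are stubs, and the BaseRetriever/Document import is currently unused). A minimal sketch of a concrete subclass; the name and behavior are hypothetical:

    class EchoAction(BaseAction):
        # Hypothetical example: a synchronous action that returns its input unchanged.
        def step(self, text: str) -> str:
            return text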
| @ -4,25 +4,25 @@ import re, os | |||||||
| import copy | import copy | ||||||
| from loguru import logger | from loguru import logger | ||||||
| 
 | 
 | ||||||
|  | from langchain.schema import BaseRetriever | ||||||
|  | 
 | ||||||
| from coagent.connector.schema import ( | from coagent.connector.schema import ( | ||||||
|     Memory, Task, Role, Message, PromptField, LogVerboseEnum |     Memory, Task, Role, Message, PromptField, LogVerboseEnum | ||||||
| ) | ) | ||||||
| from coagent.connector.memory_manager import BaseMemoryManager | from coagent.connector.memory_manager import BaseMemoryManager | ||||||
| from coagent.connector.configs.prompts import BEGIN_PROMPT_INPUT |  | ||||||
| from coagent.connector.message_process import MessageUtils | from coagent.connector.message_process import MessageUtils | ||||||
| from coagent.llm_models import getChatModel, getExtraModel, LLMConfig, getChatModelFromConfig, EmbedConfig | from coagent.llm_models import getExtraModel, LLMConfig, getChatModelFromConfig, EmbedConfig | ||||||
| from coagent.connector.prompt_manager import PromptManager | from coagent.connector.prompt_manager.prompt_manager import PromptManager | ||||||
| from coagent.connector.memory_manager import LocalMemoryManager | from coagent.connector.memory_manager import LocalMemoryManager | ||||||
| from coagent.connector.utils import parse_section | from coagent.base_configs.env_config import JUPYTER_WORK_PATH, KB_ROOT_PATH | ||||||
| # from configs.model_config import JUPYTER_WORK_PATH | 
 | ||||||
| # from configs.server_config import SANDBOX_SERVER |  | ||||||
| 
 | 
 | ||||||
| class BaseAgent: | class BaseAgent: | ||||||
| 
 | 
 | ||||||
|     def __init__( |     def __init__( | ||||||
|             self,  |             self,  | ||||||
|             role: Role, |             role: Role, | ||||||
|             prompt_config: [PromptField], |             prompt_config: List[PromptField], | ||||||
|             prompt_manager_type: str = "PromptManager", |             prompt_manager_type: str = "PromptManager", | ||||||
|             task: Task = None, |             task: Task = None, | ||||||
|             memory: Memory = None, |             memory: Memory = None, | ||||||
| @ -33,8 +33,11 @@ class BaseAgent: | |||||||
|             llm_config: LLMConfig = None, |             llm_config: LLMConfig = None, | ||||||
|             embed_config: EmbedConfig = None, |             embed_config: EmbedConfig = None, | ||||||
|             sandbox_server: dict = {}, |             sandbox_server: dict = {}, | ||||||
|             jupyter_work_path: str = "", |             jupyter_work_path: str = JUPYTER_WORK_PATH, | ||||||
|             kb_root_path: str = "", |             kb_root_path: str = KB_ROOT_PATH, | ||||||
|  |             doc_retrieval: Union[BaseRetriever] = None, | ||||||
|  |             code_retrieval = None, | ||||||
|  |             search_retrieval = None, | ||||||
|             log_verbose: str = "0" |             log_verbose: str = "0" | ||||||
|             ): |             ): | ||||||
|          |          | ||||||
| @ -43,7 +46,7 @@ class BaseAgent: | |||||||
|         self.sandbox_server = sandbox_server |         self.sandbox_server = sandbox_server | ||||||
|         self.jupyter_work_path = jupyter_work_path |         self.jupyter_work_path = jupyter_work_path | ||||||
|         self.kb_root_path = kb_root_path |         self.kb_root_path = kb_root_path | ||||||
|         self.message_utils = MessageUtils(role, sandbox_server, jupyter_work_path, embed_config, llm_config, kb_root_path, log_verbose) |         self.message_utils = MessageUtils(role, sandbox_server, jupyter_work_path, embed_config, llm_config, kb_root_path, doc_retrieval, code_retrieval, search_retrieval, log_verbose) | ||||||
|         self.memory = self.init_history(memory) |         self.memory = self.init_history(memory) | ||||||
|         self.llm_config: LLMConfig = llm_config |         self.llm_config: LLMConfig = llm_config | ||||||
|         self.embed_config: EmbedConfig = embed_config |         self.embed_config: EmbedConfig = embed_config | ||||||
| @ -82,12 +85,8 @@ class BaseAgent: | |||||||
|                 llm_config=self.embed_config |                 llm_config=self.embed_config | ||||||
|             ) |             ) | ||||||
|             memory_manager.append(query) |             memory_manager.append(query) | ||||||
|             memory_pool = memory_manager.current_memory |         memory_pool = memory_manager.get_memory_pool(query.user_name) | ||||||
|         else: |  | ||||||
|             memory_pool = memory_manager.current_memory |  | ||||||
| 
 | 
 | ||||||
| 
 |  | ||||||
|         logger.debug(f"memory_pool: {memory_pool}") |  | ||||||
|         prompt = self.prompt_manager.generate_full_prompt( |         prompt = self.prompt_manager.generate_full_prompt( | ||||||
|             previous_agent_message=query_c, agent_long_term_memory=self.memory, ui_history=history, chain_summary_messages=background, memory_pool=memory_pool) |             previous_agent_message=query_c, agent_long_term_memory=self.memory, ui_history=history, chain_summary_messages=background, memory_pool=memory_pool) | ||||||
|         content = self.llm.predict(prompt) |         content = self.llm.predict(prompt) | ||||||
| @ -99,6 +98,7 @@ class BaseAgent: | |||||||
|             logger.info(f"{self.role.role_name} content: {content}") |             logger.info(f"{self.role.role_name} content: {content}") | ||||||
| 
 | 
 | ||||||
|         output_message = Message( |         output_message = Message( | ||||||
|  |             user_name=query.user_name, | ||||||
|             role_name=self.role.role_name, |             role_name=self.role.role_name, | ||||||
|             role_type="assistant", #self.role.role_type, |             role_type="assistant", #self.role.role_type, | ||||||
|             role_content=content, |             role_content=content, | ||||||
| @ -151,10 +151,7 @@ class BaseAgent: | |||||||
|         self.memory = self.init_history() |         self.memory = self.init_history() | ||||||
|      |      | ||||||
|     def create_llm_engine(self, llm_config: LLMConfig = None, temperature=0.2, stop=None): |     def create_llm_engine(self, llm_config: LLMConfig = None, temperature=0.2, stop=None): | ||||||
|         if llm_config is None: |         return getChatModelFromConfig(llm_config=llm_config) | ||||||
|             return getChatModel(temperature=temperature, stop=stop) |  | ||||||
|         else: |  | ||||||
|             return getChatModelFromConfig(llm_config=llm_config) |  | ||||||
|      |      | ||||||
|     def registry_actions(self, actions): |     def registry_actions(self, actions): | ||||||
|         '''registry llm's actions''' |         '''registry llm's actions''' | ||||||
| @ -212,171 +209,3 @@ class BaseAgent: | |||||||
|      |      | ||||||
|     def get_memory_str(self, content_key="role_content"): |     def get_memory_str(self, content_key="role_content"): | ||||||
|         return "\n".join([": ".join(i) for i in self.memory.to_tuple_messages(content_key="step_content")]) |         return "\n".join([": ".join(i) for i in self.memory.to_tuple_messages(content_key="step_content")]) | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|     def create_prompt( |  | ||||||
|             self, query: Message, memory: Memory =None, history: Memory = None, background: Memory = None, memory_pool: Memory=None, prompt_mamnger=None) -> str: |  | ||||||
|         ''' |  | ||||||
|         prompt engineer, contains role\task\tools\docs\memory |  | ||||||
|         ''' |  | ||||||
|         #  |  | ||||||
|         doc_infos = self.create_doc_prompt(query) |  | ||||||
|         code_infos = self.create_codedoc_prompt(query) |  | ||||||
|         #  |  | ||||||
|         formatted_tools, tool_names, _ = self.create_tools_prompt(query) |  | ||||||
|         task_prompt = self.create_task_prompt(query) |  | ||||||
|         background_prompt = self.create_background_prompt(background, control_key="step_content") |  | ||||||
|         history_prompt = self.create_history_prompt(history) |  | ||||||
|         selfmemory_prompt = self.create_selfmemory_prompt(memory, control_key="step_content") |  | ||||||
|          |  | ||||||
|         # extra_system_prompt = self.role.role_prompt |  | ||||||
| 
 |  | ||||||
|          |  | ||||||
|         prompt = self.role.role_prompt.format(**{"formatted_tools": formatted_tools, "tool_names": tool_names}) |  | ||||||
|         # |  | ||||||
|         memory_pool_select_by_agent_key = self.select_memory_by_agent_key(memory_pool) |  | ||||||
|         memory_pool_select_by_agent_key_context = '\n\n'.join([f"*{k}*\n{v}" for parsed_output in memory_pool_select_by_agent_key.get_parserd_output_list() for k, v in parsed_output.items() if k not in ['Action Status']]) |  | ||||||
|          |  | ||||||
|         # input_query = query.input_query |  | ||||||
| 
 |  | ||||||
|         # # logger.debug(f"{self.role.role_name}  extra_system_prompt: {self.role.role_prompt}") |  | ||||||
|         # # logger.debug(f"{self.role.role_name}  input_query: {input_query}") |  | ||||||
|         # # logger.debug(f"{self.role.role_name}  doc_infos: {doc_infos}") |  | ||||||
|         # # logger.debug(f"{self.role.role_name}  tool_names: {tool_names}") |  | ||||||
|         # if "**Context:**" in self.role.role_prompt: |  | ||||||
|         #     # logger.debug(f"parsed_output_list: {query.parsed_output_list}") |  | ||||||
|         #     # input_query =  "'''" + "\n".join([f"###{k}###\n{v}" for i in query.parsed_output_list for k,v in i.items() if "Action Status" !=k]) + "'''" |  | ||||||
|         #     context =  "\n".join([f"*{k}*\n{v}" for i in query.parsed_output_list for k,v in i.items() if "Action Status" !=k]) |  | ||||||
|         #     # context = history_prompt or '""' |  | ||||||
|         #     # logger.debug(f"parsed_output_list: {t}") |  | ||||||
|         #     prompt += "\n" + QUERY_CONTEXT_PROMPT_INPUT.format(**{"context": context, "query": query.origin_query}) |  | ||||||
|         # else: |  | ||||||
|         #     prompt += "\n" + PLAN_PROMPT_INPUT.format(**{"query": input_query}) |  | ||||||
| 
 |  | ||||||
|         task = query.task or self.task |  | ||||||
|         if task_prompt is not None: |  | ||||||
|             prompt += "\n" + task.task_prompt |  | ||||||
| 
 |  | ||||||
|         DocInfos = "" |  | ||||||
|         if doc_infos is not None and doc_infos!="" and doc_infos!="不存在知识库辅助信息": |  | ||||||
|             DocInfos += f"\nDocument Information: {doc_infos}" |  | ||||||
| 
 |  | ||||||
|         if code_infos is not None and code_infos!="" and code_infos!="不存在代码库辅助信息": |  | ||||||
|             DocInfos += f"\nCodeBase Infomation: {code_infos}" |  | ||||||
|              |  | ||||||
|         # if selfmemory_prompt: |  | ||||||
|         #     prompt += "\n" + selfmemory_prompt |  | ||||||
| 
 |  | ||||||
|         # if background_prompt: |  | ||||||
|         #     prompt += "\n" + background_prompt |  | ||||||
| 
 |  | ||||||
|         # if history_prompt: |  | ||||||
|         #     prompt += "\n" + history_prompt |  | ||||||
| 
 |  | ||||||
|         input_query = query.input_query |  | ||||||
| 
 |  | ||||||
|         # logger.debug(f"{self.role.role_name}  extra_system_prompt: {self.role.role_prompt}") |  | ||||||
|         # logger.debug(f"{self.role.role_name}  input_query: {input_query}") |  | ||||||
|         # logger.debug(f"{self.role.role_name}  doc_infos: {doc_infos}") |  | ||||||
|         # logger.debug(f"{self.role.role_name}  tool_names: {tool_names}") |  | ||||||
| 
 |  | ||||||
|         # extra_system_prompt = self.role.role_prompt |  | ||||||
|         input_keys = parse_section(self.role.role_prompt, 'Input Format') |  | ||||||
|         prompt = self.role.role_prompt.format(**{"formatted_tools": formatted_tools, "tool_names": tool_names}) |  | ||||||
|         prompt += "\n" + BEGIN_PROMPT_INPUT |  | ||||||
|         for input_key in input_keys: |  | ||||||
|             if input_key == "Origin Query":  |  | ||||||
|                 prompt += "\n**Origin Query:**\n" + query.origin_query |  | ||||||
|             elif input_key == "Context": |  | ||||||
|                 context =  "\n".join([f"*{k}*\n{v}" for i in query.parsed_output_list for k,v in i.items() if "Action Status" !=k]) |  | ||||||
|                 if history: |  | ||||||
|                     context = history_prompt + "\n" + context |  | ||||||
|                 if not context: |  | ||||||
|                     context = "there is no context" |  | ||||||
| 
 |  | ||||||
|                 if self.focus_agents and memory_pool_select_by_agent_key_context: |  | ||||||
|                     context = memory_pool_select_by_agent_key_context |  | ||||||
|                 prompt += "\n**Context:**\n" + context + "\n" + input_query |  | ||||||
|             elif input_key == "DocInfos": |  | ||||||
|                 if DocInfos: |  | ||||||
|                     prompt += "\n**DocInfos:**\n" + DocInfos |  | ||||||
|                 else: |  | ||||||
|                     prompt += "\n**DocInfos:**\n" + "Empty" |  | ||||||
|             elif input_key == "Question": |  | ||||||
|                 prompt += "\n**Question:**\n" + input_query |  | ||||||
| 
 |  | ||||||
|         # if "**Context:**" in self.role.role_prompt: |  | ||||||
|         #     # logger.debug(f"parsed_output_list: {query.parsed_output_list}") |  | ||||||
|         #     # input_query =  "'''" + "\n".join([f"###{k}###\n{v}" for i in query.parsed_output_list for k,v in i.items() if "Action Status" !=k]) + "'''" |  | ||||||
|         #     context =  "\n".join([f"*{k}*\n{v}" for i in query.parsed_output_list for k,v in i.items() if "Action Status" !=k]) |  | ||||||
|         #     if history: |  | ||||||
|         #         context = history_prompt + "\n" + context |  | ||||||
| 
 |  | ||||||
|         #     if not context: |  | ||||||
|         #         context = "there is no context" |  | ||||||
| 
 |  | ||||||
|         #     # logger.debug(f"parsed_output_list: {t}") |  | ||||||
|         #     if "DocInfos" in prompt: |  | ||||||
|         #         prompt += "\n" + QUERY_CONTEXT_DOC_PROMPT_INPUT.format(**{"context": context, "query": query.origin_query, "DocInfos": DocInfos}) |  | ||||||
|         #     else: |  | ||||||
|         #         prompt += "\n" + QUERY_CONTEXT_PROMPT_INPUT.format(**{"context": context, "query": query.origin_query, "DocInfos": DocInfos}) |  | ||||||
|         # else: |  | ||||||
|         #     prompt += "\n" + BASE_PROMPT_INPUT.format(**{"query": input_query}) |  | ||||||
| 
 |  | ||||||
|         # prompt = extra_system_prompt.format(**{"query": input_query, "doc_infos": doc_infos, "formatted_tools": formatted_tools, "tool_names": tool_names}) |  | ||||||
|         while "{{" in prompt or "}}" in prompt: |  | ||||||
|             prompt = prompt.replace("{{", "{") |  | ||||||
|             prompt = prompt.replace("}}", "}") |  | ||||||
| 
 |  | ||||||
|         # logger.debug(f"{self.role.role_name}  prompt: {prompt}") |  | ||||||
|         return prompt |  | ||||||
|      |  | ||||||
|     def create_doc_prompt(self, message: Message) -> str:  |  | ||||||
|         '''concatenate db_docs and search_docs snippets into one doc-info block''' |  | ||||||
|         db_docs = message.db_docs |  | ||||||
|         search_docs = message.search_docs |  | ||||||
|         doc_infos = "\n".join([doc.get_snippet() for doc in db_docs] + [doc.get_snippet() for doc in search_docs]) |  | ||||||
|         return doc_infos or "不存在知识库辅助信息" |  | ||||||
|      |  | ||||||
|     def create_codedoc_prompt(self, message: Message) -> str:  |  | ||||||
|         '''concatenate code_docs snippets into one code-info block''' |  | ||||||
|         code_docs = message.code_docs |  | ||||||
|         doc_infos = "\n".join([doc.get_code() for doc in code_docs]) |  | ||||||
|         return doc_infos or "不存在代码库辅助信息" |  | ||||||
|      |  | ||||||
|     def create_tools_prompt(self, message: Message) -> str: |  | ||||||
|         tools = message.tools |  | ||||||
|         tool_strings = [] |  | ||||||
|         tools_descs = [] |  | ||||||
|         for tool in tools: |  | ||||||
|             args_schema = re.sub("}", "}}}}", re.sub("{", "{{{{", str(tool.args))) |  | ||||||
|             tool_strings.append(f"{tool.name}: {tool.description}, args: {args_schema}") |  | ||||||
|             tools_descs.append(f"{tool.name}: {tool.description}") |  | ||||||
|         formatted_tools = "\n".join(tool_strings) |  | ||||||
|         tools_desc_str = "\n".join(tools_descs) |  | ||||||
|         tool_names = ", ".join([tool.name for tool in tools]) |  | ||||||
|         return formatted_tools, tool_names, tools_desc_str |  | ||||||
|      |  | ||||||
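The quadruple-brace escaping in `create_tools_prompt` pairs with the `while` collapse loop in `create_prompt` above: each pass of the loop body halves doubled braces, so the quadrupled braces around tool args come back to single braces after two passes. A minimal sketch of the round trip (the args string is a hypothetical example):

import re

args = "{'query': 'string'}"                          # hypothetical tool-args repr
escaped = re.sub("}", "}}}}", re.sub("{", "{{{{", args))
# escaped == "{{{{'query': 'string'}}}}"
prompt = "Tools:\n" + escaped                         # injected via role_prompt.format(...)
while "{{" in prompt or "}}" in prompt:               # the collapse loop halves braces each pass
    prompt = prompt.replace("{{", "{").replace("}}", "}")
assert prompt == "Tools:\n" + args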
|     def create_task_prompt(self, message: Message) -> str: |  | ||||||
|         task = message.task or self.task |  | ||||||
|         return "\n任务目标: " + task.task_prompt if task is not None else None |  | ||||||
|      |  | ||||||
|     def create_background_prompt(self, background: Memory, control_key="role_content") -> str: |  | ||||||
|         background_message = None if background is None else background.to_str_messages(content_key=control_key) |  | ||||||
|         # logger.debug(f"background_message: {background_message}") |  | ||||||
|         if background_message: |  | ||||||
|             background_message = re.sub("}", "}}", re.sub("{", "{{", background_message)) |  | ||||||
|         return "\n背景信息: " + background_message if background_message else None |  | ||||||
|      |  | ||||||
|     def create_history_prompt(self, history: Memory, control_key="role_content") -> str: |  | ||||||
|         history_message = None if history is None else history.to_str_messages(content_key=control_key) |  | ||||||
|         if history_message: |  | ||||||
|             history_message = re.sub("}", "}}", re.sub("{", "{{", history_message)) |  | ||||||
|         return "\n补充对话信息: " + history_message if history_message else None |  | ||||||
|      |  | ||||||
|     def create_selfmemory_prompt(self, selfmemory: Memory, control_key="role_content") -> str: |  | ||||||
|         selfmemory_message = None if selfmemory is None else selfmemory.to_str_messages(content_key=control_key) |  | ||||||
|         if selfmemory_message: |  | ||||||
|             selfmemory_message = re.sub("}", "}}", re.sub("{", "{{", selfmemory_message)) |  | ||||||
|         return "\n补充自身对话信息: " + selfmemory_message if selfmemory_message else None |  | ||||||
|      |  | ||||||
| @ -2,14 +2,15 @@ from typing import List, Union | |||||||
| import copy | import copy | ||||||
| from loguru import logger | from loguru import logger | ||||||
| 
 | 
 | ||||||
|  | from langchain.schema import BaseRetriever | ||||||
|  | 
 | ||||||
| from coagent.connector.schema import ( | from coagent.connector.schema import ( | ||||||
|     Memory, Task, Env, Role, Message, ActionStatus, PromptField, LogVerboseEnum |     Memory, Task, Env, Role, Message, ActionStatus, PromptField, LogVerboseEnum | ||||||
| ) | ) | ||||||
| from coagent.connector.memory_manager import BaseMemoryManager | from coagent.connector.memory_manager import BaseMemoryManager | ||||||
| from coagent.connector.configs.prompts import BEGIN_PROMPT_INPUT |  | ||||||
| from coagent.llm_models import LLMConfig, EmbedConfig | from coagent.llm_models import LLMConfig, EmbedConfig | ||||||
| from coagent.connector.memory_manager import LocalMemoryManager | from coagent.connector.memory_manager import LocalMemoryManager | ||||||
| 
 | from coagent.base_configs.env_config import JUPYTER_WORK_PATH, KB_ROOT_PATH | ||||||
| from .base_agent import BaseAgent | from .base_agent import BaseAgent | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @ -17,7 +18,7 @@ class ExecutorAgent(BaseAgent): | |||||||
|     def __init__( |     def __init__( | ||||||
|             self,  |             self,  | ||||||
|             role: Role, |             role: Role, | ||||||
|             prompt_config: [PromptField], |             prompt_config: List[PromptField], | ||||||
|             prompt_manager_type: str= "PromptManager", |             prompt_manager_type: str= "PromptManager", | ||||||
|             task: Task = None, |             task: Task = None, | ||||||
|             memory: Memory = None, |             memory: Memory = None, | ||||||
| @ -28,14 +29,17 @@ class ExecutorAgent(BaseAgent): | |||||||
|             llm_config: LLMConfig = None, |             llm_config: LLMConfig = None, | ||||||
|             embed_config: EmbedConfig = None, |             embed_config: EmbedConfig = None, | ||||||
|             sandbox_server: dict = {}, |             sandbox_server: dict = {}, | ||||||
|             jupyter_work_path: str = "", |             jupyter_work_path: str = JUPYTER_WORK_PATH, | ||||||
|             kb_root_path: str = "", |             kb_root_path: str = KB_ROOT_PATH, | ||||||
|  |             doc_retrieval: Union[BaseRetriever] = None, | ||||||
|  |             code_retrieval = None, | ||||||
|  |             search_retrieval = None, | ||||||
|             log_verbose: str = "0" |             log_verbose: str = "0" | ||||||
|             ): |             ): | ||||||
|          |          | ||||||
|         super().__init__(role, prompt_config, prompt_manager_type, task, memory, chat_turn, |         super().__init__(role, prompt_config, prompt_manager_type, task, memory, chat_turn, | ||||||
|                          focus_agents, focus_message_keys, llm_config, embed_config, sandbox_server, |                          focus_agents, focus_message_keys, llm_config, embed_config, sandbox_server, | ||||||
|                          jupyter_work_path, kb_root_path, log_verbose |                          jupyter_work_path, kb_root_path, doc_retrieval, code_retrieval, search_retrieval, log_verbose | ||||||
|                          ) |                          ) | ||||||
|         self.do_all_task = True # run all tasks |         self.do_all_task = True # run all tasks | ||||||
| 
 | 
 | ||||||
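The new `doc_retrieval` / `code_retrieval` / `search_retrieval` parameters thread retrieval objects through the agent stack instead of having each agent reach into the knowledge base itself. A minimal sketch of handing a LangChain retriever to the new hook (the index contents and the surrounding variables are placeholders):

from langchain.vectorstores import FAISS

# `embeddings` is any langchain Embeddings instance; `role`, `prompt_config`,
# `llm_config` and `embed_config` are placeholders from your own setup
retriever = FAISS.from_texts(["doc snippet"], embeddings).as_retriever()
executor = ExecutorAgent(
    role=role, prompt_config=prompt_config,
    llm_config=llm_config, embed_config=embed_config,
    doc_retrieval=retriever,        # new in this commit
)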
| @ -45,6 +49,7 @@ class ExecutorAgent(BaseAgent): | |||||||
|         task_executor_memory = Memory(messages=[]) |         task_executor_memory = Memory(messages=[]) | ||||||
|         # insert query |         # insert query | ||||||
|         output_message = Message( |         output_message = Message( | ||||||
|  |                 user_name=query.user_name, | ||||||
|                 role_name=self.role.role_name, |                 role_name=self.role.role_name, | ||||||
|                 role_type="assistant", #self.role.role_type, |                 role_type="assistant", #self.role.role_type, | ||||||
|                 role_content=query.input_query, |                 role_content=query.input_query, | ||||||
| @ -115,7 +120,7 @@ class ExecutorAgent(BaseAgent): | |||||||
|             history: Memory, background: Memory, memory_manager: BaseMemoryManager,  |             history: Memory, background: Memory, memory_manager: BaseMemoryManager,  | ||||||
|             task_memory: Memory) -> Union[Message, Memory]: |             task_memory: Memory) -> Union[Message, Memory]: | ||||||
|         '''execute the llm predict by created prompt''' |         '''execute the llm predict by created prompt''' | ||||||
|         memory_pool = memory_manager.current_memory |         memory_pool = memory_manager.get_memory_pool(query.user_name) | ||||||
|         prompt = self.prompt_manager.generate_full_prompt( |         prompt = self.prompt_manager.generate_full_prompt( | ||||||
|             previous_agent_message=query, agent_long_term_memory=self_memory, ui_history=history, chain_summary_messages=background, memory_pool=memory_pool, |             previous_agent_message=query, agent_long_term_memory=self_memory, ui_history=history, chain_summary_messages=background, memory_pool=memory_pool, | ||||||
|             task_memory=task_memory) |             task_memory=task_memory) | ||||||
|  | |||||||
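The switch from `memory_manager.current_memory` to `memory_manager.get_memory_pool(query.user_name)` scopes the memory pool per user rather than per process. The assumed contract, sketched (method name from the diff; the user names are illustrative):

# each user name resolves to its own Memory instance
pool_alice = memory_manager.get_memory_pool("alice")
pool_bob = memory_manager.get_memory_pool("bob")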
| @ -3,23 +3,23 @@ import traceback | |||||||
| import copy | import copy | ||||||
| from loguru import logger | from loguru import logger | ||||||
| 
 | 
 | ||||||
|  | from langchain.schema import BaseRetriever | ||||||
|  | 
 | ||||||
| from coagent.connector.schema import ( | from coagent.connector.schema import ( | ||||||
|     Memory, Task, Env, Role, Message, ActionStatus, PromptField, LogVerboseEnum |     Memory, Task, Env, Role, Message, ActionStatus, PromptField, LogVerboseEnum | ||||||
| ) | ) | ||||||
| from coagent.connector.memory_manager import BaseMemoryManager | from coagent.connector.memory_manager import BaseMemoryManager | ||||||
| from coagent.connector.configs.agent_config import REACT_PROMPT_INPUT |  | ||||||
| from coagent.llm_models import LLMConfig, EmbedConfig | from coagent.llm_models import LLMConfig, EmbedConfig | ||||||
| from .base_agent import BaseAgent | from .base_agent import BaseAgent | ||||||
| from coagent.connector.memory_manager import LocalMemoryManager | from coagent.connector.memory_manager import LocalMemoryManager | ||||||
| 
 | from coagent.base_configs.env_config import JUPYTER_WORK_PATH, KB_ROOT_PATH | ||||||
| from coagent.connector.prompt_manager import PromptManager |  | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class ReactAgent(BaseAgent): | class ReactAgent(BaseAgent): | ||||||
|     def __init__( |     def __init__( | ||||||
|             self,  |             self,  | ||||||
|             role: Role, |             role: Role, | ||||||
|             prompt_config: [PromptField], |             prompt_config: List[PromptField], | ||||||
|             prompt_manager_type: str = "PromptManager", |             prompt_manager_type: str = "PromptManager", | ||||||
|             task: Task = None, |             task: Task = None, | ||||||
|             memory: Memory = None, |             memory: Memory = None, | ||||||
| @ -30,14 +30,17 @@ class ReactAgent(BaseAgent): | |||||||
|             llm_config: LLMConfig = None, |             llm_config: LLMConfig = None, | ||||||
|             embed_config: EmbedConfig = None, |             embed_config: EmbedConfig = None, | ||||||
|             sandbox_server: dict = {}, |             sandbox_server: dict = {}, | ||||||
|             jupyter_work_path: str = "", |             jupyter_work_path: str = JUPYTER_WORK_PATH, | ||||||
|             kb_root_path: str = "", |             kb_root_path: str = KB_ROOT_PATH, | ||||||
|  |             doc_retrieval: Union[BaseRetriever] = None, | ||||||
|  |             code_retrieval = None, | ||||||
|  |             search_retrieval = None, | ||||||
|             log_verbose: str = "0" |             log_verbose: str = "0" | ||||||
|             ): |             ): | ||||||
|          |          | ||||||
|         super().__init__(role, prompt_config, prompt_manager_type, task, memory, chat_turn,  |         super().__init__(role, prompt_config, prompt_manager_type, task, memory, chat_turn,  | ||||||
|                          focus_agents, focus_message_keys, llm_config, embed_config, sandbox_server, |                          focus_agents, focus_message_keys, llm_config, embed_config, sandbox_server, | ||||||
|                          jupyter_work_path, kb_root_path, log_verbose |                          jupyter_work_path, kb_root_path, doc_retrieval, code_retrieval, search_retrieval, log_verbose | ||||||
|                          ) |                          ) | ||||||
| 
 | 
 | ||||||
|     def step(self, query: Message, history: Memory = None, background: Memory = None, memory_manager: BaseMemoryManager = None) -> Message: |     def step(self, query: Message, history: Memory = None, background: Memory = None, memory_manager: BaseMemoryManager = None) -> Message: | ||||||
| @ -52,6 +55,7 @@ class ReactAgent(BaseAgent): | |||||||
|         react_memory = Memory(messages=[]) |         react_memory = Memory(messages=[]) | ||||||
|         # insert query |         # insert query | ||||||
|         output_message = Message( |         output_message = Message( | ||||||
|  |                 user_name=query.user_name, | ||||||
|                 role_name=self.role.role_name, |                 role_name=self.role.role_name, | ||||||
|                 role_type="assistant", #self.role.role_type, |                 role_type="assistant", #self.role.role_type, | ||||||
|                 role_content=query.input_query, |                 role_content=query.input_query, | ||||||
| @ -84,9 +88,7 @@ class ReactAgent(BaseAgent): | |||||||
|                     llm_config=self.embed_config |                     llm_config=self.embed_config | ||||||
|                 ) |                 ) | ||||||
|                 memory_manager.append(query) |                 memory_manager.append(query) | ||||||
|                 memory_pool = memory_manager.current_memory |             memory_pool = memory_manager.get_memory_pool(query_c.user_name) | ||||||
|             else: |  | ||||||
|                 memory_pool = memory_manager.current_memory |  | ||||||
| 
 | 
 | ||||||
|             prompt = self.prompt_manager.generate_full_prompt( |             prompt = self.prompt_manager.generate_full_prompt( | ||||||
|                 previous_agent_message=query_c, agent_long_term_memory=self.memory, ui_history=history, chain_summary_messages=background, react_memory=react_memory,  |                 previous_agent_message=query_c, agent_long_term_memory=self.memory, ui_history=history, chain_summary_messages=background, react_memory=react_memory,  | ||||||
| @ -142,82 +144,4 @@ class ReactAgent(BaseAgent): | |||||||
|         title = f"<<<<{self.role.role_name}'s prompt>>>>" |         title = f"<<<<{self.role.role_name}'s prompt>>>>" | ||||||
|         print("#"*len(title) + f"\n{title}\n"+ "#"*len(title)+ f"\n\n{prompt}\n\n") |         print("#"*len(title) + f"\n{title}\n"+ "#"*len(title)+ f"\n\n{prompt}\n\n") | ||||||
| 
 | 
 | ||||||
|     # def create_prompt( |  | ||||||
|     #         self, query: Message, memory: Memory =None, history: Memory = None, background: Memory = None, react_memory: Memory = None, memory_manager: BaseMemoryManager= None,  |  | ||||||
|     #         prompt_mamnger=None) -> str: |  | ||||||
|     #     prompt_mamnger = PromptManager() |  | ||||||
|     #     prompt_mamnger.register_standard_fields() |  | ||||||
|          |  | ||||||
|     #     # input_keys = parse_section(self.role.role_prompt, 'Agent Profile') |  | ||||||
|          |  | ||||||
|     #     data_dict = { |  | ||||||
|     #         "agent_profile": extract_section(self.role.role_prompt, 'Agent Profile'), |  | ||||||
|     #         "tool_information": query.tools, |  | ||||||
|     #         "session_records": memory_manager, |  | ||||||
|     #         "reference_documents": query, |  | ||||||
|     #         "output_format": extract_section(self.role.role_prompt, 'Response Output Format'), |  | ||||||
|     #         "response": "\n".join(["\n".join([f"**{k}:**\n{v}" for k,v in _dict.items()]) for _dict in react_memory.get_parserd_output()]), |  | ||||||
|     #     } |  | ||||||
|     #     # logger.debug(memory_pool) |  | ||||||
|          |  | ||||||
|     #     return prompt_mamnger.generate_full_prompt(data_dict) |  | ||||||
|      |  | ||||||
|     def create_prompt( |  | ||||||
|             self, query: Message, memory: Memory =None, history: Memory = None, background: Memory = None, react_memory: Memory = None, memory_pool: Memory= None,  |  | ||||||
|             prompt_mamnger=None) -> str: |  | ||||||
|         ''' |  | ||||||
|         assemble the prompt from role / task / tools / docs / memory |  | ||||||
|         ''' |  | ||||||
|         #  |  | ||||||
|         doc_infos = self.create_doc_prompt(query) |  | ||||||
|         code_infos = self.create_codedoc_prompt(query) |  | ||||||
|         #  |  | ||||||
|         formatted_tools, tool_names, _ = self.create_tools_prompt(query) |  | ||||||
|         task_prompt = self.create_task_prompt(query) |  | ||||||
|         background_prompt = self.create_background_prompt(background) |  | ||||||
|         history_prompt = self.create_history_prompt(history) |  | ||||||
|         selfmemory_prompt = self.create_selfmemory_prompt(memory, control_key="step_content") |  | ||||||
|         #  |  | ||||||
|         # extra_system_prompt = self.role.role_prompt |  | ||||||
|         prompt = self.role.role_prompt.format(**{"formatted_tools": formatted_tools, "tool_names": tool_names}) |  | ||||||
| 
 |  | ||||||
|         # the ReAct loop iterates on its own output; a second trigger should instead be fed in as history dialogue |  | ||||||
|         # input_query = react_memory.to_tuple_messages(content_key="step_content") |  | ||||||
|         # # input_query = query.input_query + "\n" + "\n".join([f"{v}" for k, v in input_query if v]) |  | ||||||
|         # input_query = "\n".join([f"{v}" for k, v in input_query if v]) |  | ||||||
|         input_query = "\n".join(["\n".join([f"**{k}:**\n{v}" for k,v in _dict.items()]) for _dict in react_memory.get_parserd_output()]) |  | ||||||
|         # logger.debug(f"input_query: {input_query}") |  | ||||||
|          |  | ||||||
|         prompt += "\n" + REACT_PROMPT_INPUT.format(**{"query": input_query}) |  | ||||||
| 
 |  | ||||||
|         task = query.task or self.task |  | ||||||
|         # if task_prompt is not None: |  | ||||||
|         #     prompt += "\n" + task.task_prompt |  | ||||||
| 
 |  | ||||||
|         # if doc_infos is not None and doc_infos!="" and doc_infos!="不存在知识库辅助信息": |  | ||||||
|         #     prompt += f"\n知识库信息: {doc_infos}" |  | ||||||
| 
 |  | ||||||
|         # if code_infos is not None and code_infos!="" and code_infos!="不存在代码库辅助信息": |  | ||||||
|         #     prompt += f"\n代码库信息: {code_infos}" |  | ||||||
| 
 |  | ||||||
|         # if background_prompt: |  | ||||||
|         #     prompt += "\n" + background_prompt |  | ||||||
| 
 |  | ||||||
|         # if history_prompt: |  | ||||||
|         #     prompt += "\n" + history_prompt |  | ||||||
| 
 |  | ||||||
|         # if selfmemory_prompt: |  | ||||||
|         #     prompt += "\n" + selfmemory_prompt |  | ||||||
| 
 |  | ||||||
|         # logger.debug(f"{self.role.role_name}  extra_system_prompt: {self.role.role_prompt}") |  | ||||||
|         # logger.debug(f"{self.role.role_name}  input_query: {input_query}") |  | ||||||
|         # logger.debug(f"{self.role.role_name}  doc_infos: {doc_infos}") |  | ||||||
|         # logger.debug(f"{self.role.role_name}  tool_names: {tool_names}") |  | ||||||
|         # prompt += "\n" + REACT_PROMPT_INPUT.format(**{"query": input_query}) |  | ||||||
| 
 |  | ||||||
|         # prompt = extra_system_prompt.format(**{"query": input_query, "doc_infos": doc_infos, "formatted_tools": formatted_tools, "tool_names": tool_names}) |  | ||||||
|         while "{{" in prompt or "}}" in prompt: |  | ||||||
|             prompt = prompt.replace("{{", "{") |  | ||||||
|             prompt = prompt.replace("}}", "}") |  | ||||||
|         return prompt |  | ||||||
|      |      | ||||||
|  | |||||||
| @ -3,13 +3,15 @@ import copy | |||||||
| import random | import random | ||||||
| from loguru import logger | from loguru import logger | ||||||
| 
 | 
 | ||||||
|  | from langchain.schema import BaseRetriever | ||||||
|  | 
 | ||||||
| from coagent.connector.schema import ( | from coagent.connector.schema import ( | ||||||
|     Memory, Task, Role, Message, PromptField, LogVerboseEnum |     Memory, Task, Role, Message, PromptField, LogVerboseEnum | ||||||
| ) | ) | ||||||
| from coagent.connector.memory_manager import BaseMemoryManager | from coagent.connector.memory_manager import BaseMemoryManager | ||||||
| from coagent.connector.configs.prompts import BEGIN_PROMPT_INPUT |  | ||||||
| from coagent.connector.memory_manager import LocalMemoryManager | from coagent.connector.memory_manager import LocalMemoryManager | ||||||
| from coagent.llm_models import LLMConfig, EmbedConfig | from coagent.llm_models import LLMConfig, EmbedConfig | ||||||
|  | from coagent.base_configs.env_config import JUPYTER_WORK_PATH, KB_ROOT_PATH | ||||||
| from .base_agent import BaseAgent | from .base_agent import BaseAgent | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @ -30,14 +32,17 @@ class SelectorAgent(BaseAgent): | |||||||
|             llm_config: LLMConfig = None, |             llm_config: LLMConfig = None, | ||||||
|             embed_config: EmbedConfig = None, |             embed_config: EmbedConfig = None, | ||||||
|             sandbox_server: dict = {}, |             sandbox_server: dict = {}, | ||||||
|             jupyter_work_path: str = "", |             jupyter_work_path: str = JUPYTER_WORK_PATH, | ||||||
|             kb_root_path: str = "", |             kb_root_path: str = KB_ROOT_PATH, | ||||||
|  |             doc_retrieval: Union[BaseRetriever] = None, | ||||||
|  |             code_retrieval = None, | ||||||
|  |             search_retrieval = None, | ||||||
|             log_verbose: str = "0" |             log_verbose: str = "0" | ||||||
|             ): |             ): | ||||||
|          |          | ||||||
|         super().__init__(role, prompt_config, prompt_manager_type, task, memory, chat_turn,  |         super().__init__(role, prompt_config, prompt_manager_type, task, memory, chat_turn,  | ||||||
|                          focus_agents, focus_message_keys, llm_config, embed_config, sandbox_server, |                          focus_agents, focus_message_keys, llm_config, embed_config, sandbox_server, | ||||||
|                          jupyter_work_path, kb_root_path, log_verbose |                          jupyter_work_path, kb_root_path, doc_retrieval, code_retrieval, search_retrieval, log_verbose | ||||||
|                          ) |                          ) | ||||||
|         self.group_agents = group_agents |         self.group_agents = group_agents | ||||||
| 
 | 
 | ||||||
| @ -56,9 +61,8 @@ class SelectorAgent(BaseAgent): | |||||||
|                 llm_config=self.embed_config |                 llm_config=self.embed_config | ||||||
|             ) |             ) | ||||||
|             memory_manager.append(query) |             memory_manager.append(query) | ||||||
|             memory_pool = memory_manager.current_memory |         memory_pool = memory_manager.get_memory_pool(query_c.user_name) | ||||||
|         else: | 
 | ||||||
|             memory_pool = memory_manager.current_memory |  | ||||||
|         prompt = self.prompt_manager.generate_full_prompt( |         prompt = self.prompt_manager.generate_full_prompt( | ||||||
|                 previous_agent_message=query_c, agent_long_term_memory=self.memory, ui_history=history, chain_summary_messages=background, react_memory=None,  |                 previous_agent_message=query_c, agent_long_term_memory=self.memory, ui_history=history, chain_summary_messages=background, react_memory=None,  | ||||||
|                 memory_pool=memory_pool, agents=self.group_agents) |                 memory_pool=memory_pool, agents=self.group_agents) | ||||||
| @ -90,6 +94,9 @@ class SelectorAgent(BaseAgent): | |||||||
|             for agent in self.group_agents: |             for agent in self.group_agents: | ||||||
|                 if agent.role.role_name == select_message.parsed_output.get("Role", ""): |                 if agent.role.role_name == select_message.parsed_output.get("Role", ""): | ||||||
|                     break |                     break | ||||||
|  |              | ||||||
|  |             # pass every parsed field except Role on to the next agent | ||||||
|  |             query_c.parsed_output.update({k:v for k,v in select_message.parsed_output.items() if k!="Role"}) | ||||||
|             for output_message in agent.astep(query_c, history, background=background, memory_manager=memory_manager): |             for output_message in agent.astep(query_c, history, background=background, memory_manager=memory_manager): | ||||||
|                 yield output_message or select_message |                 yield output_message or select_message | ||||||
|             # update self_memory |             # update self_memory | ||||||
| @ -103,6 +110,7 @@ class SelectorAgent(BaseAgent): | |||||||
|             memory_manager.append(output_message) |             memory_manager.append(output_message) | ||||||
| 
 | 
 | ||||||
|             select_message.parsed_output = output_message.parsed_output |             select_message.parsed_output = output_message.parsed_output | ||||||
|  |             select_message.spec_parsed_output.update(output_message.spec_parsed_output) | ||||||
|             select_message.parsed_output_list.extend(output_message.parsed_output_list) |             select_message.parsed_output_list.extend(output_message.parsed_output_list) | ||||||
|         yield select_message |         yield select_message | ||||||
| 
 | 
 | ||||||
| @ -115,76 +123,3 @@ class SelectorAgent(BaseAgent): | |||||||
| 
 | 
 | ||||||
|         for agent in self.group_agents: |         for agent in self.group_agents: | ||||||
|             agent.pre_print(query=query, history=history, background=background, memory_manager=memory_manager) |             agent.pre_print(query=query, history=history, background=background, memory_manager=memory_manager) | ||||||
| 
 |  | ||||||
|     # def create_prompt( |  | ||||||
|     #         self, query: Message, memory: Memory =None, history: Memory = None, background: Memory = None, memory_manager: BaseMemoryManager=None, prompt_mamnger=None) -> str: |  | ||||||
|     #     ''' |  | ||||||
|     #     role\task\tools\docs\memory |  | ||||||
|     #     ''' |  | ||||||
|     #     #  |  | ||||||
|     #     doc_infos = self.create_doc_prompt(query) |  | ||||||
|     #     code_infos = self.create_codedoc_prompt(query) |  | ||||||
|     #     #  |  | ||||||
|     #     formatted_tools, tool_names, tools_descs = self.create_tools_prompt(query) |  | ||||||
|     #     agent_names, agents = self.create_agent_names() |  | ||||||
|     #     task_prompt = self.create_task_prompt(query) |  | ||||||
|     #     background_prompt = self.create_background_prompt(background) |  | ||||||
|     #     history_prompt = self.create_history_prompt(history) |  | ||||||
|     #     selfmemory_prompt = self.create_selfmemory_prompt(memory, control_key="step_content") |  | ||||||
|          |  | ||||||
| 
 |  | ||||||
|     #     DocInfos = "" |  | ||||||
|     #     if doc_infos is not None and doc_infos!="" and doc_infos!="不存在知识库辅助信息": |  | ||||||
|     #         DocInfos += f"\nDocument Information: {doc_infos}" |  | ||||||
| 
 |  | ||||||
|     #     if code_infos is not None and code_infos!="" and code_infos!="不存在代码库辅助信息": |  | ||||||
|     #         DocInfos += f"\nCodeBase Infomation: {code_infos}" |  | ||||||
| 
 |  | ||||||
|     #     input_query = query.input_query |  | ||||||
|     #     logger.debug(f"{self.role.role_name}  input_query: {input_query}") |  | ||||||
|     #     prompt = self.role.role_prompt.format(**{"agent_names": agent_names, "agents": agents, "formatted_tools": tools_descs, "tool_names": tool_names}) |  | ||||||
|     #     # |  | ||||||
|     #     memory_pool_select_by_agent_key = self.select_memory_by_agent_key(memory_manager.current_memory) |  | ||||||
|     #     memory_pool_select_by_agent_key_context = '\n\n'.join([f"*{k}*\n{v}" for parsed_output in memory_pool_select_by_agent_key.get_parserd_output_list() for k, v in parsed_output.items() if k not in ['Action Status']]) |  | ||||||
| 
 |  | ||||||
|     #     input_keys = parse_section(self.role.role_prompt, 'Input Format') |  | ||||||
|     #     #  |  | ||||||
|     #     prompt += "\n" + BEGIN_PROMPT_INPUT |  | ||||||
|     #     for input_key in input_keys: |  | ||||||
|     #         if input_key == "Origin Query":  |  | ||||||
|     #             prompt += "\n**Origin Query:**\n" + query.origin_query |  | ||||||
|     #         elif input_key == "Context": |  | ||||||
|     #             context =  "\n".join([f"*{k}*\n{v}" for i in query.parsed_output_list for k,v in i.items() if "Action Status" !=k]) |  | ||||||
|     #             if history: |  | ||||||
|     #                 context = history_prompt + "\n" + context |  | ||||||
|     #             if not context: |  | ||||||
|     #                 context = "there is no context" |  | ||||||
| 
 |  | ||||||
|     #             if self.focus_agents and memory_pool_select_by_agent_key_context: |  | ||||||
|     #                 context = memory_pool_select_by_agent_key_context |  | ||||||
|     #             prompt += "\n**Context:**\n" + context + "\n" + input_query |  | ||||||
|     #         elif input_key == "DocInfos": |  | ||||||
|     #             prompt += "\n**DocInfos:**\n" + DocInfos |  | ||||||
|     #         elif input_key == "Question": |  | ||||||
|     #             prompt += "\n**Question:**\n" + input_query |  | ||||||
| 
 |  | ||||||
|     #     while "{{" in prompt or "}}" in prompt: |  | ||||||
|     #         prompt = prompt.replace("{{", "{") |  | ||||||
|     #         prompt = prompt.replace("}}", "}") |  | ||||||
| 
 |  | ||||||
|     #     # logger.debug(f"{self.role.role_name}  prompt: {prompt}") |  | ||||||
|     #     return prompt |  | ||||||
|      |  | ||||||
|     # def create_agent_names(self): |  | ||||||
|     #     random.shuffle(self.group_agents) |  | ||||||
|     #     agent_names = ", ".join([f'{agent.role.role_name}' for agent in self.group_agents]) |  | ||||||
|     #     agent_descs = [] |  | ||||||
|     #     for agent in self.group_agents: |  | ||||||
|     #         role_desc = agent.role.role_prompt.split("####")[1] |  | ||||||
|     #         while "\n\n" in role_desc: |  | ||||||
|     #             role_desc = role_desc.replace("\n\n", "\n") |  | ||||||
|     #         role_desc = role_desc.replace("\n", ",") |  | ||||||
| 
 |  | ||||||
|     #         agent_descs.append(f'"role name: {agent.role.role_name}\nrole description: {role_desc}"') |  | ||||||
| 
 |  | ||||||
|     #     return agent_names, "\n".join(agent_descs) |  | ||||||
7  coagent/connector/antflow/__init__.py  Normal file
							| @ -0,0 +1,7 @@ | |||||||
|  | from .flow import AgentFlow, PhaseFlow, ChainFlow | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | __all__ = [ | ||||||
|  |     "AgentFlow", "PhaseFlow", "ChainFlow" | ||||||
|  | ] | ||||||
255  coagent/connector/antflow/flow.py  Normal file
							| @ -0,0 +1,255 @@ | |||||||
|  | import importlib | ||||||
|  | from typing import List, Union, Dict, Any | ||||||
|  | from loguru import logger | ||||||
|  | import os | ||||||
|  | from langchain.embeddings.base import Embeddings | ||||||
|  | from langchain.agents import Tool | ||||||
|  | from langchain.llms.base import BaseLLM, LLM | ||||||
|  | 
 | ||||||
|  | from coagent.retrieval.base_retrieval import IMRertrieval | ||||||
|  | from coagent.llm_models.llm_config import EmbedConfig, LLMConfig | ||||||
|  | from coagent.connector.phase import BasePhase | ||||||
|  | from coagent.connector.agents import BaseAgent | ||||||
|  | from coagent.connector.chains import BaseChain | ||||||
|  | from coagent.connector.schema import Message, Role, PromptField, ChainConfig | ||||||
|  | from coagent.tools import toLangchainTools, TOOL_DICT, TOOL_SETS | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class AgentFlow: | ||||||
|  |     def __init__( | ||||||
|  |         self,  | ||||||
|  |         role_name: str, | ||||||
|  |         agent_type: str, | ||||||
|  |         role_type: str = "assistant", | ||||||
|  |         agent_index: int = 0, | ||||||
|  |         role_prompt: str = "", | ||||||
|  |         prompt_config: List[Dict[str, Any]] = [], | ||||||
|  |         prompt_manager_type: str = "PromptManager", | ||||||
|  |         chat_turn: int = 3, | ||||||
|  |         focus_agents: List[str] = [], | ||||||
|  |         focus_messages: List[str] = [], | ||||||
|  |         embeddings: Embeddings = None, | ||||||
|  |         llm: BaseLLM = None, | ||||||
|  |         doc_retrieval: IMRertrieval = None, | ||||||
|  |         code_retrieval: IMRertrieval = None, | ||||||
|  |         search_retrieval: IMRertrieval = None, | ||||||
|  |         **kwargs | ||||||
|  |     ): | ||||||
|  |         self.role_type = role_type | ||||||
|  |         self.role_name = role_name | ||||||
|  |         self.agent_type = agent_type | ||||||
|  |         self.role_prompt = role_prompt | ||||||
|  |         self.agent_index = agent_index | ||||||
|  | 
 | ||||||
|  |         self.prompt_config = prompt_config | ||||||
|  |         self.prompt_manager_type = prompt_manager_type | ||||||
|  | 
 | ||||||
|  |         self.chat_turn = chat_turn | ||||||
|  |         self.focus_agents = focus_agents | ||||||
|  |         self.focus_messages = focus_messages | ||||||
|  | 
 | ||||||
|  |         self.embeddings = embeddings | ||||||
|  |         self.llm = llm | ||||||
|  |         self.doc_retrieval = doc_retrieval | ||||||
|  |         self.code_retrieval = code_retrieval | ||||||
|  |         self.search_retrieval = search_retrieval | ||||||
|  |         # self.build_config() | ||||||
|  |         # self.build_agent() | ||||||
|  | 
 | ||||||
|  |     def build_config(self, embeddings: Embeddings = None, llm: BaseLLM = None): | ||||||
|  |         self.llm_config = LLMConfig(model_name="test", llm=self.llm or llm) | ||||||
|  |         self.embed_config = EmbedConfig(embed_model="test", langchain_embeddings=self.embeddings or embeddings) | ||||||
|  | 
 | ||||||
|  |     def build_agent(self,  | ||||||
|  |                     embeddings: Embeddings = None, llm: BaseLLM = None, | ||||||
|  |                     doc_retrieval: IMRertrieval = None,  | ||||||
|  |                     code_retrieval: IMRertrieval = None,  | ||||||
|  |                     search_retrieval: IMRertrieval = None, | ||||||
|  |         ): | ||||||
|  |         # a customized agent can be registered here, overriding only start_action_step / end_action_step | ||||||
|  |         # class ExtraAgent(BaseAgent): | ||||||
|  |         #     def start_action_step(self, message: Message) -> Message: | ||||||
|  |         #         pass | ||||||
|  | 
 | ||||||
|  |         #     def end_action_step(self, message: Message) -> Message: | ||||||
|  |         #         pass | ||||||
|  |         # agent_module = importlib.import_module("coagent.connector.agents") | ||||||
|  |         # setattr(agent_module, 'extraAgent', ExtraAgent) | ||||||
|  | 
 | ||||||
|  |         # a customized prompt-assembly scheme (a PromptManager subclass) can be registered the same way, | ||||||
|  |         # class CodeRetrievalPM(PromptManager): | ||||||
|  |         #     def handle_code_packages(self, **kwargs) -> str: | ||||||
|  |         #         if 'previous_agent_message' not in kwargs: | ||||||
|  |         #             return "" | ||||||
|  |         #         previous_agent_message: Message = kwargs['previous_agent_message'] | ||||||
|  |         #         # temporary handling: the two agents share the same manager | ||||||
|  |         #         vertices = previous_agent_message.customed_kargs.get("RelatedVerticesRetrivalRes", {}).get("vertices", []) | ||||||
|  |         #         return ", ".join([str(v) for v in vertices]) | ||||||
|  | 
 | ||||||
|  |         # prompt_manager_module = importlib.import_module("coagent.connector.prompt_manager") | ||||||
|  |         # setattr(prompt_manager_module, 'CodeRetrievalPM', CodeRetrievalPM) | ||||||
|  |          | ||||||
|  |         # instantiate the agent | ||||||
|  |         agent_module = importlib.import_module("coagent.connector.agents") | ||||||
|  |         baseAgent: BaseAgent = getattr(agent_module, self.agent_type) | ||||||
|  |         role = Role( | ||||||
|  |             role_type=self.agent_type, role_name=self.role_name,  | ||||||
|  |             agent_type=self.agent_type, role_prompt=self.role_prompt, | ||||||
|  |         ) | ||||||
|  | 
 | ||||||
|  |         self.build_config(embeddings, llm) | ||||||
|  |         self.agent = baseAgent( | ||||||
|  |                     role=role,  | ||||||
|  |                     prompt_config = [PromptField(**config) for config in self.prompt_config], | ||||||
|  |                     prompt_manager_type=self.prompt_manager_type, | ||||||
|  |                     chat_turn=self.chat_turn, | ||||||
|  |                     focus_agents=self.focus_agents, | ||||||
|  |                     focus_message_keys=self.focus_messages, | ||||||
|  |                     llm_config=self.llm_config, | ||||||
|  |                     embed_config=self.embed_config, | ||||||
|  |                     doc_retrieval=doc_retrieval or self.doc_retrieval, | ||||||
|  |                     code_retrieval=code_retrieval or self.code_retrieval, | ||||||
|  |                     search_retrieval=search_retrieval or self.search_retrieval, | ||||||
|  |                 ) | ||||||
|  |          | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class ChainFlow: | ||||||
|  |     def __init__( | ||||||
|  |         self,  | ||||||
|  |         chain_name: str, | ||||||
|  |         chain_index: int = 0, | ||||||
|  |         agent_flows: List[AgentFlow] = [], | ||||||
|  |         chat_turn: int = 5, | ||||||
|  |         do_checker: bool = False, | ||||||
|  |         embeddings: Embeddings = None, | ||||||
|  |         llm: BaseLLM = None, | ||||||
|  |         doc_retrieval: IMRertrieval = None, | ||||||
|  |         code_retrieval: IMRertrieval = None, | ||||||
|  |         search_retrieval: IMRertrieval = None, | ||||||
|  |         # chain_type: str = "BaseChain", | ||||||
|  |         **kwargs | ||||||
|  |     ): | ||||||
|  |         self.agent_flows = sorted(agent_flows, key=lambda x:x.agent_index) | ||||||
|  |         self.chat_turn = chat_turn | ||||||
|  |         self.do_checker = do_checker | ||||||
|  |         self.chain_name = chain_name | ||||||
|  |         self.chain_index = chain_index | ||||||
|  |         self.chain_type = "BaseChain" | ||||||
|  | 
 | ||||||
|  |         self.embeddings = embeddings | ||||||
|  |         self.llm = llm | ||||||
|  | 
 | ||||||
|  |         self.doc_retrieval = doc_retrieval | ||||||
|  |         self.code_retrieval = code_retrieval | ||||||
|  |         self.search_retrieval = search_retrieval | ||||||
|  |         # self.build_config() | ||||||
|  |         # self.build_chain() | ||||||
|  | 
 | ||||||
|  |     def build_config(self, embeddings: Embeddings = None, llm: BaseLLM = None): | ||||||
|  |         self.llm_config = LLMConfig(model_name="test", llm=self.llm or llm) | ||||||
|  |         self.embed_config = EmbedConfig(embed_model="test", langchain_embeddings=self.embeddings or embeddings) | ||||||
|  | 
 | ||||||
|  |     def build_chain(self,  | ||||||
|  |                     embeddings: Embeddings = None, llm: BaseLLM = None, | ||||||
|  |                     doc_retrieval: IMRertrieval = None,  | ||||||
|  |                     code_retrieval: IMRertrieval = None,  | ||||||
|  |                     search_retrieval: IMRertrieval = None, | ||||||
|  |         ): | ||||||
|  |         # instantiate the chain | ||||||
|  |         chain_module = importlib.import_module("coagent.connector.chains") | ||||||
|  |         baseChain: BaseChain = getattr(chain_module, self.chain_type) | ||||||
|  | 
 | ||||||
|  |         agent_names = [agent_flow.role_name for agent_flow in self.agent_flows] | ||||||
|  |         chain_config = ChainConfig(chain_name=self.chain_name, agents=agent_names, do_checker=self.do_checker, chat_turn=self.chat_turn) | ||||||
|  | 
 | ||||||
|  |         # instantiate the member agents | ||||||
|  |         self.build_config(embeddings, llm) | ||||||
|  |         for agent_flow in self.agent_flows: | ||||||
|  |             agent_flow.build_agent(embeddings, llm) | ||||||
|  | 
 | ||||||
|  |         self.chain = baseChain( | ||||||
|  |                 chain_config, | ||||||
|  |                 [agent_flow.agent for agent_flow in self.agent_flows],  | ||||||
|  |                 embed_config=self.embed_config, | ||||||
|  |                 llm_config=self.llm_config, | ||||||
|  |                 doc_retrieval=doc_retrieval or self.doc_retrieval, | ||||||
|  |                 code_retrieval=code_retrieval or self.code_retrieval, | ||||||
|  |                 search_retrieval=search_retrieval or self.search_retrieval, | ||||||
|  |                 ) | ||||||
|  | 
 | ||||||
|  | class PhaseFlow: | ||||||
|  |     def __init__( | ||||||
|  |         self,  | ||||||
|  |         phase_name: str, | ||||||
|  |         chain_flows: List[ChainFlow], | ||||||
|  |         embeddings: Embeddings = None, | ||||||
|  |         llm: BaseLLM = None, | ||||||
|  |         tools: List[Tool] = [], | ||||||
|  |         doc_retrieval: IMRertrieval = None, | ||||||
|  |         code_retrieval: IMRertrieval = None, | ||||||
|  |         search_retrieval: IMRertrieval = None, | ||||||
|  |         **kwargs | ||||||
|  |     ): | ||||||
|  |         self.phase_name = phase_name | ||||||
|  |         self.chain_flows = sorted(chain_flows, key=lambda x:x.chain_index) | ||||||
|  |         self.phase_type = "BasePhase" | ||||||
|  |         self.tools = tools | ||||||
|  | 
 | ||||||
|  |         self.embeddings = embeddings | ||||||
|  |         self.llm = llm | ||||||
|  | 
 | ||||||
|  |         self.doc_retrieval = doc_retrieval | ||||||
|  |         self.code_retrieval = code_retrieval | ||||||
|  |         self.search_retrieval = search_retrieval | ||||||
|  |         # self.build_config() | ||||||
|  |         self.build_phase() | ||||||
|  |      | ||||||
|  |     def __call__(self, params: dict) -> str: | ||||||
|  | 
 | ||||||
|  |         # tools = toLangchainTools([TOOL_DICT[i] for i in TOOL_SETS if i in TOOL_DICT]) | ||||||
|  |         # query_content = "Please check whether server 127.0.0.1 showed any anomaly at 10 o'clock and help me judge" | ||||||
|  |         try: | ||||||
|  |             logger.info(f"params: {params}") | ||||||
|  |             query_content = params.get("query") or params.get("input") | ||||||
|  |             search_type = params.get("search_type") | ||||||
|  |             query = Message( | ||||||
|  |                 role_name="human", role_type="user", tools=self.tools, | ||||||
|  |                 role_content=query_content, input_query=query_content, origin_query=query_content, | ||||||
|  |                 cb_search_type=search_type, | ||||||
|  |                 ) | ||||||
|  |             # phase.pre_print(query) | ||||||
|  |             output_message, output_memory = self.phase.step(query) | ||||||
|  |             output_content = "\n\n".join((output_memory.to_str_messages(return_all=True, content_key="parsed_output_list").split("\n\n")[1:])) or output_message.role_content | ||||||
|  |             return output_content | ||||||
|  |         except Exception as e: | ||||||
|  |             logger.exception(e) | ||||||
|  |             return f"Error {e}" | ||||||
|  | 
 | ||||||
|  |     def build_config(self, embeddings: Embeddings = None, llm: BaseLLM = None): | ||||||
|  |         self.llm_config = LLMConfig(model_name="test", llm=self.llm or llm) | ||||||
|  |         self.embed_config = EmbedConfig(embed_model="test", langchain_embeddings=self.embeddings or embeddings) | ||||||
|  | 
 | ||||||
|  |     def build_phase(self, embeddings: Embeddings = None, llm: BaseLLM = None): | ||||||
|  |         # instantiate the phase | ||||||
|  |         phase_module = importlib.import_module("coagent.connector.phase") | ||||||
|  |         basePhase: BasePhase = getattr(phase_module, self.phase_type) | ||||||
|  | 
 | ||||||
|  |         # instantiate the chains | ||||||
|  |         self.build_config(self.embeddings or embeddings, self.llm or llm) | ||||||
|  |         os.environ["log_verbose"] = "2" | ||||||
|  |         for chain_flow in self.chain_flows: | ||||||
|  |             chain_flow.build_chain( | ||||||
|  |                 self.embeddings or embeddings, self.llm or llm, | ||||||
|  |                 self.doc_retrieval, self.code_retrieval, self.search_retrieval | ||||||
|  |                 ) | ||||||
|  | 
 | ||||||
|  |         self.phase: BasePhase = basePhase( | ||||||
|  |                 phase_name=self.phase_name, | ||||||
|  |                 chains=[chain_flow.chain for chain_flow in self.chain_flows], | ||||||
|  |                 embed_config=self.embed_config, | ||||||
|  |                 llm_config=self.llm_config, | ||||||
|  |                 doc_retrieval=self.doc_retrieval, | ||||||
|  |                 code_retrieval=self.code_retrieval, | ||||||
|  |                 search_retrieval=self.search_retrieval | ||||||
|  |                 ) | ||||||
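Taken together, the three Flow classes wrap role configuration, chain assembly, and phase execution behind plain constructors, and `PhaseFlow.__call__` drives `phase.step`. A minimal wiring sketch (the role prompt, LLM, and embeddings below are hypothetical placeholders):

from coagent.connector.antflow import AgentFlow, ChainFlow, PhaseFlow
from coagent.connector.configs import BASE_PROMPT_CONFIGS

agent = AgentFlow(
    role_name="demo_agent",                  # hypothetical role name
    agent_type="BaseAgent",
    role_prompt="#### Agent Profile\n...",   # hypothetical prompt text
    prompt_config=BASE_PROMPT_CONFIGS,
    agent_index=0,
)
chain = ChainFlow(chain_name="demo_chain", chain_index=0, agent_flows=[agent])
phase = PhaseFlow("demo_phase", [chain], llm=my_llm, embeddings=my_embeddings)
answer = phase({"query": "..."})  # __call__ builds a Message and runs phase.step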
| @ -1,9 +1,10 @@ | |||||||
| from typing import List | from typing import List, Tuple, Union | ||||||
| from loguru import logger | from loguru import logger | ||||||
| import copy, os | import copy, os | ||||||
| 
 | 
 | ||||||
| from coagent.connector.agents import BaseAgent | from langchain.schema import BaseRetriever | ||||||
| 
 | 
 | ||||||
|  | from coagent.connector.agents import BaseAgent | ||||||
| from coagent.connector.schema import ( | from coagent.connector.schema import ( | ||||||
|     Memory, Role, Message, ActionStatus, ChainConfig, |     Memory, Role, Message, ActionStatus, ChainConfig, | ||||||
|     load_role_configs |     load_role_configs | ||||||
| @ -11,31 +12,32 @@ from coagent.connector.schema import ( | |||||||
| from coagent.connector.memory_manager import BaseMemoryManager | from coagent.connector.memory_manager import BaseMemoryManager | ||||||
| from coagent.connector.message_process import MessageUtils | from coagent.connector.message_process import MessageUtils | ||||||
| from coagent.llm_models.llm_config import LLMConfig, EmbedConfig | from coagent.llm_models.llm_config import LLMConfig, EmbedConfig | ||||||
|  | from coagent.base_configs.env_config import JUPYTER_WORK_PATH, KB_ROOT_PATH | ||||||
| from coagent.connector.configs.agent_config import AGETN_CONFIGS | from coagent.connector.configs.agent_config import AGETN_CONFIGS | ||||||
| role_configs = load_role_configs(AGETN_CONFIGS) | role_configs = load_role_configs(AGETN_CONFIGS) | ||||||
| 
 | 
 | ||||||
| # from configs.model_config import JUPYTER_WORK_PATH |  | ||||||
| # from configs.server_config import SANDBOX_SERVER |  | ||||||
| 
 |  | ||||||
| 
 | 
 | ||||||
| class BaseChain: | class BaseChain: | ||||||
|     def __init__( |     def __init__( | ||||||
|             self,  |             self,  | ||||||
|             # chainConfig: ChainConfig, |             chainConfig: ChainConfig, | ||||||
|             agents: List[BaseAgent], |             agents: List[BaseAgent], | ||||||
|             chat_turn: int = 1, |             # chat_turn: int = 1, | ||||||
|             do_checker: bool = False, |             # do_checker: bool = False, | ||||||
|             sandbox_server: dict = {}, |             sandbox_server: dict = {}, | ||||||
|             jupyter_work_path: str = "", |             jupyter_work_path: str = JUPYTER_WORK_PATH, | ||||||
|             kb_root_path: str = "", |             kb_root_path: str = KB_ROOT_PATH, | ||||||
|             llm_config: LLMConfig = LLMConfig(), |             llm_config: LLMConfig = LLMConfig(), | ||||||
|             embed_config: EmbedConfig = None, |             embed_config: EmbedConfig = None, | ||||||
|  |             doc_retrieval: Union[BaseRetriever] = None, | ||||||
|  |             code_retrieval = None, | ||||||
|  |             search_retrieval = None, | ||||||
|             log_verbose: str = "0" |             log_verbose: str = "0" | ||||||
|             ) -> None: |             ) -> None: | ||||||
|         # self.chainConfig = chainConfig |         self.chainConfig = chainConfig | ||||||
|         self.agents: List[BaseAgent] = agents |         self.agents: List[BaseAgent] = agents | ||||||
|         self.chat_turn = chat_turn |         self.chat_turn = chainConfig.chat_turn | ||||||
|         self.do_checker = do_checker |         self.do_checker = chainConfig.do_checker | ||||||
|         self.sandbox_server = sandbox_server |         self.sandbox_server = sandbox_server | ||||||
|         self.jupyter_work_path = jupyter_work_path |         self.jupyter_work_path = jupyter_work_path | ||||||
|         self.llm_config = llm_config |         self.llm_config = llm_config | ||||||
| @ -45,9 +47,11 @@ class BaseChain: | |||||||
|             task = None, memory = None, |             task = None, memory = None, | ||||||
|             llm_config=llm_config, embed_config=embed_config, |             llm_config=llm_config, embed_config=embed_config, | ||||||
|             sandbox_server=sandbox_server, jupyter_work_path=jupyter_work_path, |             sandbox_server=sandbox_server, jupyter_work_path=jupyter_work_path, | ||||||
|             kb_root_path=kb_root_path |             kb_root_path=kb_root_path, | ||||||
|  |             doc_retrieval=doc_retrieval, code_retrieval=code_retrieval, | ||||||
|  |             search_retrieval=search_retrieval | ||||||
|             ) |             ) | ||||||
|         self.messageUtils = MessageUtils(None, sandbox_server, self.jupyter_work_path, embed_config, llm_config, kb_root_path, log_verbose) |         self.messageUtils = MessageUtils(None, sandbox_server, self.jupyter_work_path, embed_config, llm_config, kb_root_path, doc_retrieval, code_retrieval, search_retrieval, log_verbose) | ||||||
|         # all memory created by the agents is kept until the instance is deleted |         # all memory created by the agents is kept until the instance is deleted | ||||||
|         self.global_memory = Memory(messages=[]) |         self.global_memory = Memory(messages=[]) | ||||||
| 
 | 
 | ||||||
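With `chainConfig` now required, `chat_turn` and `do_checker` come from the config object instead of loose keyword arguments. A minimal construction sketch (the agent list and names are placeholders):

from coagent.connector.schema import ChainConfig

config = ChainConfig(chain_name="demo_chain", agents=["planner"],
                     chat_turn=1, do_checker=False)
chain = BaseChain(config, agents, llm_config=llm_config, embed_config=embed_config)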
| @ -62,13 +66,16 @@ class BaseChain: | |||||||
|         for agent in self.agents: |         for agent in self.agents: | ||||||
|             agent.pre_print(query, history, background=background, memory_manager=memory_manager) |             agent.pre_print(query, history, background=background, memory_manager=memory_manager) | ||||||
|      |      | ||||||
|     def astep(self, query: Message, history: Memory = None, background: Memory = None, memory_manager: BaseMemoryManager = None) -> Message: |     def astep(self, query: Message, history: Memory = None, background: Memory = None, memory_manager: BaseMemoryManager = None) -> Tuple[Message, Memory]: | ||||||
|         '''execute chain''' |         '''execute chain''' | ||||||
|         local_memory = Memory(messages=[]) |         local_memory = Memory(messages=[]) | ||||||
|         input_message = copy.deepcopy(query) |         input_message = copy.deepcopy(query) | ||||||
|         step_nums = copy.deepcopy(self.chat_turn) |         step_nums = copy.deepcopy(self.chat_turn) | ||||||
|         check_message = None |         check_message = None | ||||||
| 
 | 
 | ||||||
|  |         # if input_message not in memory_manager: | ||||||
|  |         #     memory_manager.append(input_message) | ||||||
|  | 
 | ||||||
|         self.global_memory.append(input_message) |         self.global_memory.append(input_message) | ||||||
|         # local_memory.append(input_message) |         # local_memory.append(input_message) | ||||||
|         while step_nums > 0: |         while step_nums > 0: | ||||||
| @ -78,7 +85,7 @@ class BaseChain: | |||||||
|                     yield output_message, local_memory + output_message |                     yield output_message, local_memory + output_message | ||||||
|                 output_message = self.messageUtils.inherit_extrainfo(input_message, output_message) |                 output_message = self.messageUtils.inherit_extrainfo(input_message, output_message) | ||||||
|                 # according the output to choose one action for code_content or tool_content |                 # according the output to choose one action for code_content or tool_content | ||||||
|                 output_message = self.messageUtils.parser(output_message) |                 # output_message = self.messageUtils.parser(output_message) | ||||||
|                 yield output_message, local_memory + output_message |                 yield output_message, local_memory + output_message | ||||||
|                 # output_message = self.step_router(output_message) |                 # output_message = self.step_router(output_message) | ||||||
|                 input_message = output_message |                 input_message = output_message | ||||||
|  | |||||||
| @ -1,9 +1,10 @@ | |||||||
| from .agent_config import AGETN_CONFIGS | from .agent_config import AGETN_CONFIGS | ||||||
| from .chain_config import CHAIN_CONFIGS | from .chain_config import CHAIN_CONFIGS | ||||||
| from .phase_config import PHASE_CONFIGS | from .phase_config import PHASE_CONFIGS | ||||||
| from .prompt_config import BASE_PROMPT_CONFIGS, EXECUTOR_PROMPT_CONFIGS, SELECTOR_PROMPT_CONFIGS, BASE_NOTOOLPROMPT_CONFIGS | from .prompt_config import * | ||||||
| 
 | 
 | ||||||
| __all__ = [ | __all__ = [ | ||||||
|     "AGETN_CONFIGS", "CHAIN_CONFIGS", "PHASE_CONFIGS",  |     "AGETN_CONFIGS", "CHAIN_CONFIGS", "PHASE_CONFIGS",  | ||||||
|     "BASE_PROMPT_CONFIGS", "EXECUTOR_PROMPT_CONFIGS", "SELECTOR_PROMPT_CONFIGS", "BASE_NOTOOLPROMPT_CONFIGS" |     "BASE_PROMPT_CONFIGS", "EXECUTOR_PROMPT_CONFIGS", "SELECTOR_PROMPT_CONFIGS", "BASE_NOTOOLPROMPT_CONFIGS", | ||||||
|  |     "CODE2DOC_GROUP_PROMPT_CONFIGS", "CODE2DOC_PROMPT_CONFIGS", "CODE2TESTS_PROMPT_CONFIGS" | ||||||
|     ] |     ] | ||||||
| @ -1,19 +1,21 @@ | |||||||
| from enum import Enum | from enum import Enum | ||||||
| from .prompts import ( | from .prompts import * | ||||||
|     REACT_PROMPT_INPUT, CHECK_PROMPT_INPUT, EXECUTOR_PROMPT_INPUT, CONTEXT_PROMPT_INPUT, QUERY_CONTEXT_PROMPT_INPUT,PLAN_PROMPT_INPUT, | # from .prompts import ( | ||||||
|     RECOGNIZE_INTENTION_PROMPT, | #     REACT_PROMPT_INPUT, CHECK_PROMPT_INPUT, EXECUTOR_PROMPT_INPUT, CONTEXT_PROMPT_INPUT, QUERY_CONTEXT_PROMPT_INPUT,PLAN_PROMPT_INPUT, | ||||||
|     CHECKER_TEMPLATE_PROMPT, | #     RECOGNIZE_INTENTION_PROMPT, | ||||||
|     CONV_SUMMARY_PROMPT, | #     CHECKER_TEMPLATE_PROMPT, | ||||||
|     QA_PROMPT, CODE_QA_PROMPT, QA_TEMPLATE_PROMPT, | #     CONV_SUMMARY_PROMPT, | ||||||
|     EXECUTOR_TEMPLATE_PROMPT, | #     QA_PROMPT, CODE_QA_PROMPT, QA_TEMPLATE_PROMPT, | ||||||
|     REFINE_TEMPLATE_PROMPT, | #     EXECUTOR_TEMPLATE_PROMPT, | ||||||
|     SELECTOR_AGENT_TEMPLATE_PROMPT, | #     REFINE_TEMPLATE_PROMPT, | ||||||
|     PLANNER_TEMPLATE_PROMPT, GENERAL_PLANNER_PROMPT, DATA_PLANNER_PROMPT, TOOL_PLANNER_PROMPT, | #     SELECTOR_AGENT_TEMPLATE_PROMPT, | ||||||
|     PRD_WRITER_METAGPT_PROMPT, DESIGN_WRITER_METAGPT_PROMPT, TASK_WRITER_METAGPT_PROMPT, CODE_WRITER_METAGPT_PROMPT, | #     PLANNER_TEMPLATE_PROMPT, GENERAL_PLANNER_PROMPT, DATA_PLANNER_PROMPT, TOOL_PLANNER_PROMPT, | ||||||
|     REACT_TEMPLATE_PROMPT, | #     PRD_WRITER_METAGPT_PROMPT, DESIGN_WRITER_METAGPT_PROMPT, TASK_WRITER_METAGPT_PROMPT, CODE_WRITER_METAGPT_PROMPT, | ||||||
|     REACT_TOOL_PROMPT, REACT_CODE_PROMPT, REACT_TOOL_AND_CODE_PLANNER_PROMPT, REACT_TOOL_AND_CODE_PROMPT | #     REACT_TEMPLATE_PROMPT, | ||||||
| ) | #     REACT_TOOL_PROMPT, REACT_CODE_PROMPT, REACT_TOOL_AND_CODE_PLANNER_PROMPT, REACT_TOOL_AND_CODE_PROMPT | ||||||
| from .prompt_config import BASE_PROMPT_CONFIGS, EXECUTOR_PROMPT_CONFIGS, SELECTOR_PROMPT_CONFIGS, BASE_NOTOOLPROMPT_CONFIGS | # ) | ||||||
|  | from .prompt_config import * | ||||||
|  | # BASE_PROMPT_CONFIGS, EXECUTOR_PROMPT_CONFIGS, SELECTOR_PROMPT_CONFIGS, BASE_NOTOOLPROMPT_CONFIGS | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @ -261,4 +263,68 @@ AGETN_CONFIGS = { | |||||||
|         "focus_agents": ["metaGPT_DESIGN", "metaGPT_TASK"], |         "focus_agents": ["metaGPT_DESIGN", "metaGPT_TASK"], | ||||||
|         "focus_message_keys": [], |         "focus_message_keys": [], | ||||||
|     }, |     }, | ||||||
|  |     "class2Docer": { | ||||||
|  |         "role": { | ||||||
|  |             "role_prompt": Class2Doc_PROMPT, | ||||||
|  |             "role_type": "assistant", | ||||||
|  |             "role_name": "class2Docer", | ||||||
|  |             "role_desc": "", | ||||||
|  |             "agent_type": "CodeGenDocer" | ||||||
|  |         }, | ||||||
|  |         "prompt_config": CODE2DOC_PROMPT_CONFIGS, | ||||||
|  |         "prompt_manager_type": "Code2DocPM", | ||||||
|  |         "chat_turn": 1, | ||||||
|  |         "focus_agents": [], | ||||||
|  |         "focus_message_keys": [], | ||||||
|  |     }, | ||||||
|  |     "func2Docer": { | ||||||
|  |         "role": { | ||||||
|  |             "role_prompt": Func2Doc_PROMPT, | ||||||
|  |             "role_type": "assistant", | ||||||
|  |             "role_name": "func2Docer", | ||||||
|  |             "role_desc": "", | ||||||
|  |             "agent_type": "CodeGenDocer" | ||||||
|  |         }, | ||||||
|  |         "prompt_config": CODE2DOC_PROMPT_CONFIGS, | ||||||
|  |         "prompt_manager_type": "Code2DocPM", | ||||||
|  |         "chat_turn": 1, | ||||||
|  |         "focus_agents": [], | ||||||
|  |         "focus_message_keys": [], | ||||||
|  |     }, | ||||||
|  |     "code2DocsGrouper": { | ||||||
|  |         "role": { | ||||||
|  |             "role_prompt": Code2DocGroup_PROMPT, | ||||||
|  |             "role_type": "assistant", | ||||||
|  |             "role_name": "code2DocsGrouper", | ||||||
|  |             "role_desc": "", | ||||||
|  |             "agent_type": "SelectorAgent" | ||||||
|  |         }, | ||||||
|  |         "prompt_config": CODE2DOC_GROUP_PROMPT_CONFIGS, | ||||||
|  |         "group_agents": ["class2Docer", "func2Docer"], | ||||||
|  |         "chat_turn": 1, | ||||||
|  |     }, | ||||||
|  |     "Code2TestJudger": { | ||||||
|  |         "role": { | ||||||
|  |             "role_prompt": judgeCode2Tests_PROMPT, | ||||||
|  |             "role_type": "assistant", | ||||||
|  |             "role_name": "Code2TestJudger", | ||||||
|  |             "role_desc": "", | ||||||
|  |             "agent_type": "CodeRetrieval" | ||||||
|  |         }, | ||||||
|  |         "prompt_config": CODE2TESTS_PROMPT_CONFIGS, | ||||||
|  |         "prompt_manager_type": "CodeRetrievalPM", | ||||||
|  |         "chat_turn": 1, | ||||||
|  |     }, | ||||||
|  |     "code2Tests": { | ||||||
|  |         "role": { | ||||||
|  |             "role_prompt": code2Tests_PROMPT, | ||||||
|  |             "role_type": "assistant", | ||||||
|  |             "role_name": "code2Tests", | ||||||
|  |             "role_desc": "", | ||||||
|  |             "agent_type": "CodeRetrieval" | ||||||
|  |         }, | ||||||
|  |         "prompt_config": CODE2TESTS_PROMPT_CONFIGS, | ||||||
|  |         "prompt_manager_type": "CodeRetrievalPM", | ||||||
|  |         "chat_turn": 1, | ||||||
|  |     }, | ||||||
| } | } | ||||||
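The five new entries above (`class2Docer`, `func2Docer`, `code2DocsGrouper`, `Code2TestJudger`, `code2Tests`) follow the same shape as the existing agent configs: a `role` block plus a prompt config, an optional prompt manager type, and a `chat_turn` budget. A minimal sketch of how such an entry could be resolved into a role object; the `SimpleRole` dataclass and `build_role` helper below are hypothetical illustrations, not the coagent API:

```python
from dataclasses import dataclass, field
from typing import List

@dataclass
class SimpleRole:  # illustrative stand-in, not coagent's Role class
    role_name: str
    role_type: str
    role_prompt: str
    agent_type: str
    prompt_config: List[dict] = field(default_factory=list)
    chat_turn: int = 1

def build_role(agent_name: str, agent_configs: dict) -> SimpleRole:
    cfg = agent_configs[agent_name]
    return SimpleRole(
        role_name=cfg["role"]["role_name"],
        role_type=cfg["role"]["role_type"],
        role_prompt=cfg["role"]["role_prompt"],
        agent_type=cfg["role"]["agent_type"],
        prompt_config=cfg.get("prompt_config", []),
        chat_turn=cfg.get("chat_turn", 1),
    )
```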
| @ -123,5 +123,21 @@ CHAIN_CONFIGS = { | |||||||
|         "chat_turn": 1, |         "chat_turn": 1, | ||||||
|         "do_checker": False, |         "do_checker": False, | ||||||
|         "chain_prompt": "" |         "chain_prompt": "" | ||||||
|  |     }, | ||||||
|  |     "code2DocsGroupChain": { | ||||||
|  |         "chain_name": "code2DocsGroupChain", | ||||||
|  |         "chain_type": "BaseChain", | ||||||
|  |         "agents": ["code2DocsGrouper"], | ||||||
|  |         "chat_turn": 1, | ||||||
|  |         "do_checker": False, | ||||||
|  |         "chain_prompt": "" | ||||||
|  |     }, | ||||||
|  |     "code2TestsChain": { | ||||||
|  |         "chain_name": "code2TestsChain", | ||||||
|  |         "chain_type": "BaseChain", | ||||||
|  |         "agents": ["Code2TestJudger", "code2Tests"], | ||||||
|  |         "chat_turn": 1, | ||||||
|  |         "do_checker": False, | ||||||
|  |         "chain_prompt": "" | ||||||
|     } |     } | ||||||
| } | } | ||||||
|  | |||||||
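Both new chains only list their member agents; execution order is positional, so `code2TestsChain` always runs `Code2TestJudger` before `code2Tests`. A hedged sketch of how a loader could resolve that wiring (the helper name is illustrative, not part of this commit):

```python
def agents_for_chain(chain_name: str, chain_configs: dict, agent_configs: dict) -> list:
    """Resolve a chain's agent names to their config dicts, in execution order."""
    return [agent_configs[name] for name in chain_configs[chain_name]["agents"]]

# agents_for_chain("code2TestsChain", CHAIN_CONFIGS, AGETN_CONFIGS)
# yields the Code2TestJudger config first, then code2Tests.
```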
| @ -14,44 +14,24 @@ PHASE_CONFIGS = { | |||||||
|         "phase_name": "docChatPhase", |         "phase_name": "docChatPhase", | ||||||
|         "phase_type": "BasePhase", |         "phase_type": "BasePhase", | ||||||
|         "chains": ["docChatChain"], |         "chains": ["docChatChain"], | ||||||
|         "do_summary": False, |  | ||||||
|         "do_search": False, |  | ||||||
|         "do_doc_retrieval": True, |         "do_doc_retrieval": True, | ||||||
|         "do_code_retrieval": False, |  | ||||||
|         "do_tool_retrieval": False, |  | ||||||
|         "do_using_tool": False |  | ||||||
|     }, |     }, | ||||||
|     "searchChatPhase": { |     "searchChatPhase": { | ||||||
|         "phase_name": "searchChatPhase", |         "phase_name": "searchChatPhase", | ||||||
|         "phase_type": "BasePhase", |         "phase_type": "BasePhase", | ||||||
|         "chains": ["searchChatChain"], |         "chains": ["searchChatChain"], | ||||||
|         "do_summary": False, |  | ||||||
|         "do_search": True, |         "do_search": True, | ||||||
|         "do_doc_retrieval": False, |  | ||||||
|         "do_code_retrieval": False, |  | ||||||
|         "do_tool_retrieval": False, |  | ||||||
|         "do_using_tool": False |  | ||||||
|     }, |     }, | ||||||
|     "codeChatPhase": { |     "codeChatPhase": { | ||||||
|         "phase_name": "codeChatPhase", |         "phase_name": "codeChatPhase", | ||||||
|         "phase_type": "BasePhase", |         "phase_type": "BasePhase", | ||||||
|         "chains": ["codeChatChain"], |         "chains": ["codeChatChain"], | ||||||
|         "do_summary": False, |  | ||||||
|         "do_search": False, |  | ||||||
|         "do_doc_retrieval": False, |  | ||||||
|         "do_code_retrieval": True, |         "do_code_retrieval": True, | ||||||
|         "do_tool_retrieval": False, |  | ||||||
|         "do_using_tool": False |  | ||||||
|     }, |     }, | ||||||
|     "toolReactPhase": { |     "toolReactPhase": { | ||||||
|         "phase_name": "toolReactPhase", |         "phase_name": "toolReactPhase", | ||||||
|         "phase_type": "BasePhase", |         "phase_type": "BasePhase", | ||||||
|         "chains": ["toolReactChain"], |         "chains": ["toolReactChain"], | ||||||
|         "do_summary": False, |  | ||||||
|         "do_search": False, |  | ||||||
|         "do_doc_retrieval": False, |  | ||||||
|         "do_code_retrieval": False, |  | ||||||
|         "do_tool_retrieval": False, |  | ||||||
|         "do_using_tool": True |         "do_using_tool": True | ||||||
|     }, |     }, | ||||||
|     "codeReactPhase": { |     "codeReactPhase": { | ||||||
| @ -59,55 +39,36 @@ PHASE_CONFIGS = { | |||||||
|         "phase_type": "BasePhase", |         "phase_type": "BasePhase", | ||||||
|         # "chains": ["codePlannerChain", "codeReactChain"], |         # "chains": ["codePlannerChain", "codeReactChain"], | ||||||
|         "chains": ["planChain", "codeReactChain"], |         "chains": ["planChain", "codeReactChain"], | ||||||
|         "do_summary": False, |  | ||||||
|         "do_search": False, |  | ||||||
|         "do_doc_retrieval": False, |  | ||||||
|         "do_code_retrieval": False, |  | ||||||
|         "do_tool_retrieval": False, |  | ||||||
|         "do_using_tool": False |  | ||||||
|     }, |     }, | ||||||
|     "codeToolReactPhase": { |     "codeToolReactPhase": { | ||||||
|         "phase_name": "codeToolReactPhase", |         "phase_name": "codeToolReactPhase", | ||||||
|         "phase_type": "BasePhase", |         "phase_type": "BasePhase", | ||||||
|         "chains": ["codeToolPlanChain", "codeToolReactChain"], |         "chains": ["codeToolPlanChain", "codeToolReactChain"], | ||||||
|         "do_summary": False, |  | ||||||
|         "do_search": False, |  | ||||||
|         "do_doc_retrieval": False, |  | ||||||
|         "do_code_retrieval": False, |  | ||||||
|         "do_tool_retrieval": False, |  | ||||||
|         "do_using_tool": True |         "do_using_tool": True | ||||||
|     }, |     }, | ||||||
|     "baseTaskPhase": { |     "baseTaskPhase": { | ||||||
|         "phase_name": "baseTaskPhase", |         "phase_name": "baseTaskPhase", | ||||||
|         "phase_type": "BasePhase", |         "phase_type": "BasePhase", | ||||||
|         "chains": ["planChain", "executorChain"], |         "chains": ["planChain", "executorChain"], | ||||||
|         "do_summary": False, |  | ||||||
|         "do_search": False, |  | ||||||
|         "do_doc_retrieval": False, |  | ||||||
|         "do_code_retrieval": False, |  | ||||||
|         "do_tool_retrieval": False, |  | ||||||
|         "do_using_tool": False |  | ||||||
|     }, |     }, | ||||||
|     "metagpt_code_devlop": { |     "metagpt_code_devlop": { | ||||||
|         "phase_name": "metagpt_code_devlop", |         "phase_name": "metagpt_code_devlop", | ||||||
|         "phase_type": "BasePhase", |         "phase_type": "BasePhase", | ||||||
|         "chains": ["metagptChain",], |         "chains": ["metagptChain",], | ||||||
|         "do_summary": False, |  | ||||||
|         "do_search": False, |  | ||||||
|         "do_doc_retrieval": False, |  | ||||||
|         "do_code_retrieval": False, |  | ||||||
|         "do_tool_retrieval": False, |  | ||||||
|         "do_using_tool": False |  | ||||||
|     }, |     }, | ||||||
|     "baseGroupPhase": { |     "baseGroupPhase": { | ||||||
|         "phase_name": "baseGroupPhase", |         "phase_name": "baseGroupPhase", | ||||||
|         "phase_type": "BasePhase", |         "phase_type": "BasePhase", | ||||||
|         "chains": ["baseGroupChain"], |         "chains": ["baseGroupChain"], | ||||||
|         "do_summary": False, |  | ||||||
|         "do_search": False, |  | ||||||
|         "do_doc_retrieval": False, |  | ||||||
|         "do_code_retrieval": False, |  | ||||||
|         "do_tool_retrieval": False, |  | ||||||
|         "do_using_tool": False |  | ||||||
|     }, |     }, | ||||||
|  |     "code2DocsGroup": { | ||||||
|  |         "phase_name": "code2DocsGroup", | ||||||
|  |         "phase_type": "BasePhase", | ||||||
|  |         "chains": ["code2DocsGroupChain"], | ||||||
|  |     }, | ||||||
|  |     "code2Tests": { | ||||||
|  |         "phase_name": "code2Tests", | ||||||
|  |         "phase_type": "BasePhase", | ||||||
|  |         "chains": ["code2TestsChain"], | ||||||
|  |     } | ||||||
| } | } | ||||||
|  | |||||||
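This hunk deletes every `do_*` switch that merely restated a False default, and the two new phases declare nothing beyond their chains. A sketch of the assumed default-merging that keeps the trimmed configs equivalent; the defaults table is inferred from the flags removed above:

```python
# Assumption: omitted switches fall back to False when a phase entry is
# merged over this defaults table.
PHASE_DEFAULTS = {
    "do_summary": False, "do_search": False, "do_doc_retrieval": False,
    "do_code_retrieval": False, "do_tool_retrieval": False, "do_using_tool": False,
}

def resolve_phase(name: str, phase_configs: dict) -> dict:
    return {**PHASE_DEFAULTS, **phase_configs[name]}

# resolve_phase("code2Tests", PHASE_CONFIGS)["do_using_tool"] -> False
```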
| @ -41,3 +41,40 @@ SELECTOR_PROMPT_CONFIGS = [ | |||||||
|     {"field_name": 'output_format', "function_name": 'handle_output_format', 'title': 'Response Output Format', "is_context": False}, |     {"field_name": 'output_format', "function_name": 'handle_output_format', 'title': 'Response Output Format', "is_context": False}, | ||||||
|     {"field_name": 'begin!!!', "function_name": 'handle_response', "is_context": False, "omit_if_empty": False} |     {"field_name": 'begin!!!', "function_name": 'handle_response', "is_context": False, "omit_if_empty": False} | ||||||
|     ] |     ] | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | CODE2DOC_GROUP_PROMPT_CONFIGS = [ | ||||||
|  |     {"field_name": 'agent_profile', "function_name": 'handle_agent_profile', "is_context": False}, | ||||||
|  |     {"field_name": 'agent_infomation', "function_name": 'handle_agent_data', "is_context": False, "omit_if_empty": False}, | ||||||
|  |     # {"field_name": 'tool_information',"function_name": 'handle_tool_data', "is_context": False}, | ||||||
|  |     {"field_name": 'context_placeholder', "function_name": '', "is_context": True}, | ||||||
|  |     # {"field_name": 'reference_documents', "function_name": 'handle_doc_info'}, | ||||||
|  |     {"field_name": 'session_records', "function_name": 'handle_session_records'}, | ||||||
|  |     {"field_name": 'Specific Objective', "function_name": 'handle_specific_objective'}, | ||||||
|  |     {"field_name": 'Code Snippet', "function_name": 'handle_code_snippet'}, | ||||||
|  |     {"field_name": 'output_format', "function_name": 'handle_output_format', 'title': 'Response Output Format', "is_context": False}, | ||||||
|  |     {"field_name": 'begin!!!', "function_name": 'handle_response', "is_context": False, "omit_if_empty": False} | ||||||
|  | ] | ||||||
|  | 
 | ||||||
|  | CODE2DOC_PROMPT_CONFIGS = [ | ||||||
|  |     {"field_name": 'agent_profile', "function_name": 'handle_agent_profile', "is_context": False}, | ||||||
|  |     # {"field_name": 'tool_information',"function_name": 'handle_tool_data', "is_context": False}, | ||||||
|  |     {"field_name": 'context_placeholder', "function_name": '', "is_context": True}, | ||||||
|  |     # {"field_name": 'reference_documents', "function_name": 'handle_doc_info'}, | ||||||
|  |     {"field_name": 'session_records', "function_name": 'handle_session_records'}, | ||||||
|  |     {"field_name": 'Specific Objective', "function_name": 'handle_specific_objective'}, | ||||||
|  |     {"field_name": 'Code Snippet', "function_name": 'handle_code_snippet'}, | ||||||
|  |     {"field_name": 'output_format', "function_name": 'handle_output_format', 'title': 'Response Output Format', "is_context": False}, | ||||||
|  |     {"field_name": 'begin!!!', "function_name": 'handle_response', "is_context": False, "omit_if_empty": False} | ||||||
|  | ] | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | CODE2TESTS_PROMPT_CONFIGS = [ | ||||||
|  |     {"field_name": 'agent_profile', "function_name": 'handle_agent_profile', "is_context": False}, | ||||||
|  |     {"field_name": 'context_placeholder', "function_name": '', "is_context": True}, | ||||||
|  |     {"field_name": 'session_records', "function_name": 'handle_session_records'}, | ||||||
|  |     {"field_name": 'code_snippet', "function_name": 'handle_code_snippet'}, | ||||||
|  |     {"field_name": 'retrieval_codes', "function_name": 'handle_retrieval_codes', "description": ""}, | ||||||
|  |     {"field_name": 'output_format', "function_name": 'handle_output_format', 'title': 'Response Output Format', "is_context": False}, | ||||||
|  |     {"field_name": 'begin!!!', "function_name": 'handle_response', "is_context": False, "omit_if_empty": False} | ||||||
|  | ] | ||||||
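Each of the three new prompt configs is an ordered list of sections: `function_name` names a handler on the prompt manager, `is_context` marks where conversation context is spliced in, and `omit_if_empty` (True unless stated otherwise) drops sections whose handler produced nothing. A hypothetical assembler, purely to illustrate that contract:

```python
def render_prompt(config: list, handlers: dict, context: str = "") -> str:
    parts = []
    for item in config:
        if item.get("is_context"):
            parts.append(context)  # slot for the conversation context
            continue
        handler = handlers.get(item["function_name"])
        section = handler(item) if handler else ""
        if not section and item.get("omit_if_empty", True):
            continue  # skip optional sections that rendered empty
        title = item.get("title", item["field_name"])
        parts.append(f"#### {title}\n{section}")
    return "\n\n".join(parts)
```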
| @ -14,7 +14,8 @@ from .qa_template_prompt import QA_PROMPT, CODE_QA_PROMPT, QA_TEMPLATE_PROMPT, C | |||||||
| 
 | 
 | ||||||
| from .executor_template_prompt import EXECUTOR_TEMPLATE_PROMPT | from .executor_template_prompt import EXECUTOR_TEMPLATE_PROMPT | ||||||
| from .refine_template_prompt import REFINE_TEMPLATE_PROMPT | from .refine_template_prompt import REFINE_TEMPLATE_PROMPT | ||||||
| 
 | from .code2doc_template_prompt import Code2DocGroup_PROMPT, Class2Doc_PROMPT, Func2Doc_PROMPT | ||||||
|  | from .code2test_template_prompt import code2Tests_PROMPT, judgeCode2Tests_PROMPT | ||||||
| from .agent_selector_template_prompt import SELECTOR_AGENT_TEMPLATE_PROMPT | from .agent_selector_template_prompt import SELECTOR_AGENT_TEMPLATE_PROMPT | ||||||
| 
 | 
 | ||||||
| from .react_template_prompt import REACT_TEMPLATE_PROMPT | from .react_template_prompt import REACT_TEMPLATE_PROMPT | ||||||
| @ -37,5 +38,7 @@ __all__ = [ | |||||||
|     "SELECTOR_AGENT_TEMPLATE_PROMPT", |     "SELECTOR_AGENT_TEMPLATE_PROMPT", | ||||||
|     "PLANNER_TEMPLATE_PROMPT", "GENERAL_PLANNER_PROMPT", "DATA_PLANNER_PROMPT", "TOOL_PLANNER_PROMPT", |     "PLANNER_TEMPLATE_PROMPT", "GENERAL_PLANNER_PROMPT", "DATA_PLANNER_PROMPT", "TOOL_PLANNER_PROMPT", | ||||||
|     "REACT_TEMPLATE_PROMPT",  |     "REACT_TEMPLATE_PROMPT",  | ||||||
|     "REACT_CODE_PROMPT", "REACT_TOOL_PROMPT", "REACT_TOOL_AND_CODE_PROMPT", "REACT_TOOL_AND_CODE_PLANNER_PROMPT" |     "REACT_CODE_PROMPT", "REACT_TOOL_PROMPT", "REACT_TOOL_AND_CODE_PROMPT", "REACT_TOOL_AND_CODE_PLANNER_PROMPT", | ||||||
|  |     "Code2DocGroup_PROMPT", "Class2Doc_PROMPT", "Func2Doc_PROMPT", | ||||||
|  |     "code2Tests_PROMPT", "judgeCode2Tests_PROMPT" | ||||||
| ] | ] | ||||||
| @ -0,0 +1,95 @@ | |||||||
|  | Code2DocGroup_PROMPT = """#### Agent Profile | ||||||
|  | 
 | ||||||
|  | Your goal is to respond, based on the Context Data's information, with the role that will best facilitate a solution, taking into account all relevant context (Context) provided. | ||||||
|  | 
 | ||||||
|  | When you need to select the appropriate role for handling a user's query, carefully read the provided role names, role descriptions and tool list. | ||||||
|  | 
 | ||||||
|  | ATTENTION: respond carefully, following the "Response Output Format" below. | ||||||
|  | 
 | ||||||
|  | #### Input Format | ||||||
|  | 
 | ||||||
|  | #### Response Output Format | ||||||
|  | 
 | ||||||
|  | **Code Path:** Extract the paths of the class/method/function that need to be addressed from the context. | ||||||
|  | 
 | ||||||
|  | **Role:** Select the role from agent names | ||||||
|  | """ | ||||||
|  | 
 | ||||||
|  | Class2Doc_PROMPT = """#### Agent Profile | ||||||
|  | As an advanced code documentation generator, you are proficient in translating class definitions into comprehensive documentation with a focus on instantiation parameters.  | ||||||
|  | Your specific task is to parse the given code snippet of a class and extract information about its instantiation parameters. | ||||||
|  | 
 | ||||||
|  | ATTENTION: respond carefully, following the "Response Output Format". | ||||||
|  | 
 | ||||||
|  | #### Input Format | ||||||
|  | 
 | ||||||
|  | **Code Snippet:** Provide the full class definition, including the constructor and any parameters it may require for instantiation. | ||||||
|  | 
 | ||||||
|  | #### Response Output Format | ||||||
|  | **Class Base:** Specify the base class or interface from which the current class extends, if any. | ||||||
|  | 
 | ||||||
|  | **Class Description:** Offer a brief description of the class's purpose and functionality. | ||||||
|  | 
 | ||||||
|  | **Init Parameters:** List each parameter from the constructor. For each parameter, provide: | ||||||
|  |     - `param`: The parameter name | ||||||
|  |     - `param_description`: A concise explanation of the parameter's purpose. | ||||||
|  |     - `param_type`: The data type of the parameter, if explicitly defined. | ||||||
|  | 
 | ||||||
|  |     ```json | ||||||
|  |     [ | ||||||
|  |         { | ||||||
|  |             "param": "parameter_name", | ||||||
|  |             "param_description": "A brief description of what this parameter is used for.", | ||||||
|  |             "param_type": "The data type of the parameter" | ||||||
|  |         }, | ||||||
|  |         ... | ||||||
|  |     ] | ||||||
|  |     ``` | ||||||
|  | 
 | ||||||
|  |          | ||||||
|  |     If the constructor takes no parameters, return | ||||||
|  |     ```json | ||||||
|  |     [] | ||||||
|  |     ``` | ||||||
|  | """ | ||||||
|  | 
 | ||||||
|  | Func2Doc_PROMPT = """#### Agent Profile | ||||||
|  | You are a high-level code documentation assistant, skilled at distilling function/method code into detailed and well-structured documentation. | ||||||
|  | 
 | ||||||
|  | ATTENTION: respond carefully, following the "Response Output Format". | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | #### Input Format | ||||||
|  | **Code Path:** Provide the code path of the function or method you wish to document.  | ||||||
|  | This name will be used to identify and extract the relevant details from the code snippet provided. | ||||||
|  |      | ||||||
|  | **Code Snippet:** A segment of code that contains the function or method to be documented. | ||||||
|  | 
 | ||||||
|  | #### Response Output Format | ||||||
|  | 
 | ||||||
|  | **Class Description:** Offer a brief description of the method (function)'s purpose and functionality. | ||||||
|  | 
 | ||||||
|  | **Parameters:** Extract the parameters of the specific function/method from the Code Snippet. For each parameter, provide: | ||||||
|  |     - `param`: The parameter name | ||||||
|  |     - `param_description`: A concise explanation of the parameter's purpose. | ||||||
|  |     - `param_type`: The data type of the parameter, if explicitly defined. | ||||||
|  |     ```json | ||||||
|  |     [ | ||||||
|  |         { | ||||||
|  |             "param": "parameter_name", | ||||||
|  |             "param_description": "A brief description of what this parameter is used for.", | ||||||
|  |             "param_type": "The data type of the parameter" | ||||||
|  |         }, | ||||||
|  |         ... | ||||||
|  |     ] | ||||||
|  |     ``` | ||||||
|  | 
 | ||||||
|  |     If the function/method takes no parameters, return | ||||||
|  |     ```json | ||||||
|  |     [] | ||||||
|  |     ``` | ||||||
|  | 
 | ||||||
|  | **Return Value Description:** Describe what the function/method returns upon completion. | ||||||
|  | 
 | ||||||
|  | **Return Type:** Indicate the type of data the function/method returns (e.g., string, integer, object, void). | ||||||
|  | """ | ||||||
| @ -0,0 +1,65 @@ | |||||||
|  | judgeCode2Tests_PROMPT = """#### Agent Profile | ||||||
|  | When determining the necessity of writing test cases for a given code snippet,  | ||||||
|  | it's essential to evaluate its interactions with dependent classes and methods (retrieved code snippets),  | ||||||
|  | in addition to considering these critical factors: | ||||||
|  | 1. Functionality: If it implements a concrete function or logic, test cases are typically necessary to verify its correctness. | ||||||
|  | 2. Complexity: If the code is complex, especially if it contains multiple conditional statements, loops, exception handling, etc.,  | ||||||
|  | it's more likely to harbor bugs, and thus test cases should be written.  | ||||||
|  | If the code involves complex algorithms or logic, then writing test cases can help ensure the accuracy of the logic and prevent errors during future refactoring. | ||||||
|  | 3. Criticality: If it's part of the critical path or affects core functionalities, then it needs to be tested.  | ||||||
|  | Comprehensive test cases should be written for core business logic or key components of the system to ensure the correctness and stability of the functionality. | ||||||
|  | 4. Dependencies: If the code has external dependencies, integration testing may be necessary, or mocking these dependencies during unit testing might be required. | ||||||
|  | 5. User Input: If the code handles user input, especially from unregulated external sources, creating test cases to check input validation and handling is important. | ||||||
|  | 6. Frequent Changes: For code that requires regular updates or modifications, having the appropriate test cases ensures that changes do not break existing functionalities. | ||||||
|  | 
 | ||||||
|  | #### Input Format | ||||||
|  | 
 | ||||||
|  | **Code Snippet:** the initial code or objective that the user wants to achieve | ||||||
|  | 
 | ||||||
|  | **Retrieval Code Snippets:** These are the associated code segments that the main Code Snippet depends on.  | ||||||
|  | Examine these snippets to understand how they interact with the main snippet and to determine how they might affect the overall functionality. | ||||||
|  | 
 | ||||||
|  | #### Response Output Format | ||||||
|  | **Action Status:** Set to 'finished' or 'continued'.  | ||||||
|  | If set to 'finished', the code snippet does not warrant the generation of a test case. | ||||||
|  | If set to 'continued', the code snippet necessitates the creation of a test case. | ||||||
|  | 
 | ||||||
|  | **REASON:** Justify the selection of 'finished' or 'continued', contemplating the decision through a step-by-step rationale. | ||||||
|  | """ | ||||||
|  | 
 | ||||||
|  | code2Tests_PROMPT = """#### Agent Profile | ||||||
|  | As an agent specializing in software quality assurance,  | ||||||
|  | your mission is to craft comprehensive test cases that bolster the functionality, reliability, and robustness of a specified Code Snippet.  | ||||||
|  | This task is to be carried out with a keen understanding of the snippet's interactions with its dependent classes and methods—collectively referred to as Retrieval Code Snippets.  | ||||||
|  | Analyze the details given below to grasp the code's intended purpose, its inherent complexity, and the context within which it operates.  | ||||||
|  | Your constructed test cases must thoroughly examine the various factors influencing the code's quality and performance. | ||||||
|  | 
 | ||||||
|  | ATTENTION: respond carefully, following the "Response Output Format" below. | ||||||
|  | 
 | ||||||
|  | Each test case should include: | ||||||
|  | 1. A clear description of the test purpose. | ||||||
|  | 2. The input values or conditions for the test. | ||||||
|  | 3. The expected outcome or assertion for the test. | ||||||
|  | 4. Appropriate tags (e.g., 'functional', 'integration', 'regression') that classify the type of test case. | ||||||
|  | 5. Test code that includes the necessary package and import statements. | ||||||
|  | 
 | ||||||
|  | #### Input Format | ||||||
|  | 
 | ||||||
|  | **Code Snippet:** the initial code or objective that the user wants to achieve | ||||||
|  | 
 | ||||||
|  | **Retrieval Code Snippets:** These are the interrelated pieces of code sourced from the codebase, which support or influence the primary Code Snippet. | ||||||
|  | 
 | ||||||
|  | #### Response Output Format | ||||||
|  | **SaveFileName:** Construct a local file name based on the Question and Context, such as | ||||||
|  | 
 | ||||||
|  | ```java | ||||||
|  | package/class.java | ||||||
|  | ``` | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | **Test Code:** Generate the test code for the current Code Snippet. | ||||||
|  | ```java | ||||||
|  | ... | ||||||
|  | ``` | ||||||
|  | 
 | ||||||
|  | """ | ||||||
| @ -1,5 +1,5 @@ | |||||||
| from abc import abstractmethod, ABC | from abc import abstractmethod, ABC | ||||||
| from typing import List | from typing import List, Dict | ||||||
| import os, sys, copy, json | import os, sys, copy, json | ||||||
| from jieba.analyse import extract_tags | from jieba.analyse import extract_tags | ||||||
| from collections import Counter | from collections import Counter | ||||||
| @ -10,12 +10,13 @@ from langchain.docstore.document import Document | |||||||
| 
 | 
 | ||||||
| from .schema import Memory, Message | from .schema import Memory, Message | ||||||
| from coagent.service.service_factory import KBServiceFactory | from coagent.service.service_factory import KBServiceFactory | ||||||
| from coagent.llm_models import getChatModel, getChatModelFromConfig | from coagent.llm_models import getChatModelFromConfig | ||||||
| from coagent.llm_models.llm_config import EmbedConfig, LLMConfig | from coagent.llm_models.llm_config import EmbedConfig, LLMConfig | ||||||
| from coagent.embeddings.utils import load_embeddings_from_path | from coagent.embeddings.utils import load_embeddings_from_path | ||||||
| from coagent.utils.common_utils import save_to_json_file, read_json_file, addMinutesToTime | from coagent.utils.common_utils import save_to_json_file, read_json_file, addMinutesToTime | ||||||
| from coagent.connector.configs.prompts import CONV_SUMMARY_PROMPT_SPEC | from coagent.connector.configs.prompts import CONV_SUMMARY_PROMPT_SPEC | ||||||
| from coagent.orm import table_init | from coagent.orm import table_init | ||||||
|  | from coagent.base_configs.env_config import KB_ROOT_PATH | ||||||
| # from configs.model_config import KB_ROOT_PATH, EMBEDDING_MODEL, EMBEDDING_DEVICE, SCORE_THRESHOLD | # from configs.model_config import KB_ROOT_PATH, EMBEDDING_MODEL, EMBEDDING_DEVICE, SCORE_THRESHOLD | ||||||
| # from configs.model_config import embedding_model_dict | # from configs.model_config import embedding_model_dict | ||||||
| 
 | 
 | ||||||
| @ -70,16 +71,22 @@ class BaseMemoryManager(ABC): | |||||||
|         self.unique_name = unique_name |         self.unique_name = unique_name | ||||||
|         self.memory_type = memory_type |         self.memory_type = memory_type | ||||||
|         self.do_init = do_init |         self.do_init = do_init | ||||||
|         self.current_memory = Memory(messages=[]) |         # self.current_memory = Memory(messages=[]) | ||||||
|         self.recall_memory = Memory(messages=[]) |         # self.recall_memory = Memory(messages=[]) | ||||||
|         self.summary_memory = Memory(messages=[]) |         # self.summary_memory = Memory(messages=[]) | ||||||
|  |         self.current_memory_dict: Dict[str, Memory] = {} | ||||||
|  |         self.recall_memory_dict: Dict[str, Memory] = {} | ||||||
|  |         self.summary_memory_dict: Dict[str, Memory] = {} | ||||||
|         self.save_message_keys = [ |         self.save_message_keys = [ | ||||||
|             'chat_index', 'role_name', 'role_type', 'role_prompt', 'input_query', 'origin_query', |             'chat_index', 'role_name', 'role_type', 'role_prompt', 'input_query', 'origin_query', | ||||||
|             'datetime', 'role_content', 'step_content', 'parsed_output', 'spec_parsed_output', 'parsed_output_list',  |             'datetime', 'role_content', 'step_content', 'parsed_output', 'spec_parsed_output', 'parsed_output_list',  | ||||||
|             'task', 'db_docs', 'code_docs', 'search_docs', 'phase_name', 'chain_name', 'customed_kargs'] |             'task', 'db_docs', 'code_docs', 'search_docs', 'phase_name', 'chain_name', 'customed_kargs'] | ||||||
|         self.init_vb() |         self.init_vb() | ||||||
| 
 | 
 | ||||||
|     def init_vb(self): |     def re_init(self, do_init: bool=False): | ||||||
|  |         self.init_vb(do_init) | ||||||
|  | 
 | ||||||
|  |     def init_vb(self, do_init: bool=None): | ||||||
|         """ |         """ | ||||||
|         Initializes the vb. |         Initializes the vb. | ||||||
|         """ |         """ | ||||||
| @ -135,13 +142,15 @@ class BaseMemoryManager(ABC): | |||||||
|         """ |         """ | ||||||
|         pass |         pass | ||||||
| 
 | 
 | ||||||
|     def save_to_vs(self, embed_model="", embed_device=""): |     def save_to_vs(self, ): | ||||||
|         """ |         """ | ||||||
|         Saves the memory to the vector space. |         Saves the memory to the vector space. | ||||||
|  |         """ | ||||||
|  |         pass | ||||||
| 
 | 
 | ||||||
|         Args: |     def get_memory_pool(self, user_name: str): | ||||||
|         - embed_model: A string representing the embedding model. Default is EMBEDDING_MODEL. |         """ | ||||||
|         - embed_device: A string representing the embedding device. Default is EMBEDDING_DEVICE. |         Return the memory pool for the given user_name. | ||||||
|         """ |         """ | ||||||
|         pass |         pass | ||||||
| 
 | 
 | ||||||
| @ -230,7 +239,7 @@ class LocalMemoryManager(BaseMemoryManager): | |||||||
|             unique_name: str = "default", |             unique_name: str = "default", | ||||||
|             memory_type: str = "recall", |             memory_type: str = "recall", | ||||||
|             do_init: bool = False, |             do_init: bool = False, | ||||||
|             kb_root_path: str = "", |             kb_root_path: str = KB_ROOT_PATH, | ||||||
|         ): |         ): | ||||||
|         self.user_name = user_name |         self.user_name = user_name | ||||||
|         self.unique_name = unique_name |         self.unique_name = unique_name | ||||||
| @ -239,16 +248,22 @@ class LocalMemoryManager(BaseMemoryManager): | |||||||
|         self.kb_root_path = kb_root_path |         self.kb_root_path = kb_root_path | ||||||
|         self.embed_config: EmbedConfig = embed_config |         self.embed_config: EmbedConfig = embed_config | ||||||
|         self.llm_config: LLMConfig = llm_config |         self.llm_config: LLMConfig = llm_config | ||||||
|         self.current_memory = Memory(messages=[]) |         # self.current_memory = Memory(messages=[]) | ||||||
|         self.recall_memory = Memory(messages=[]) |         # self.recall_memory = Memory(messages=[]) | ||||||
|         self.summary_memory = Memory(messages=[]) |         # self.summary_memory = Memory(messages=[]) | ||||||
|  |         self.current_memory_dict: Dict[str, Memory] = {} | ||||||
|  |         self.recall_memory_dict: Dict[str, Memory] = {} | ||||||
|  |         self.summary_memory_dict: Dict[str, Memory] = {} | ||||||
|         self.save_message_keys = [ |         self.save_message_keys = [ | ||||||
|             'chat_index', 'role_name', 'role_type', 'role_prompt', 'input_query', 'origin_query', |             'chat_index', 'role_name', 'role_type', 'role_prompt', 'input_query', 'origin_query', | ||||||
|             'datetime', 'role_content', 'step_content', 'parsed_output', 'spec_parsed_output', 'parsed_output_list',  |             'datetime', 'role_content', 'step_content', 'parsed_output', 'spec_parsed_output', 'parsed_output_list',  | ||||||
|             'task', 'db_docs', 'code_docs', 'search_docs', 'phase_name', 'chain_name', 'customed_kargs'] |             'task', 'db_docs', 'code_docs', 'search_docs', 'phase_name', 'chain_name', 'customed_kargs'] | ||||||
|         self.init_vb() |         self.init_vb() | ||||||
| 
 | 
 | ||||||
|     def init_vb(self): |     def re_init(self, do_init: bool=False): | ||||||
|  |         self.init_vb(do_init) | ||||||
|  | 
 | ||||||
|  |     def init_vb(self, do_init: bool=None): | ||||||
|         vb_name = f"{self.user_name}/{self.unique_name}/{self.memory_type}" |         vb_name = f"{self.user_name}/{self.unique_name}/{self.memory_type}" | ||||||
|         # default to recreate a new vb |         # default to recreate a new vb | ||||||
|         table_init() |         table_init() | ||||||
| @ -256,31 +271,37 @@ class LocalMemoryManager(BaseMemoryManager): | |||||||
|         if vb: |         if vb: | ||||||
|             status = vb.clear_vs() |             status = vb.clear_vs() | ||||||
| 
 | 
 | ||||||
|         if not self.do_init: |         check_do_init = do_init if do_init is not None else self.do_init | ||||||
|  |         if not check_do_init: | ||||||
|             self.load(self.kb_root_path) |             self.load(self.kb_root_path) | ||||||
|             self.save_to_vs() |             self.save_to_vs() | ||||||
| 
 | 
 | ||||||
|     def append(self, message: Message): |     def append(self, message: Message): | ||||||
|         self.recall_memory.append(message) |         self.check_user_name(message.user_name) | ||||||
|  | 
 | ||||||
|  |         uuid_name = "_".join([self.user_name, self.unique_name, self.memory_type]) | ||||||
|  |         self.recall_memory_dict[uuid_name].append(message) | ||||||
|         #  |         #  | ||||||
|         if message.role_type == "summary": |         if message.role_type == "summary": | ||||||
|             self.summary_memory.append(message) |             self.summary_memory_dict[uuid_name].append(message) | ||||||
|         else: |         else: | ||||||
|             self.current_memory.append(message) |             self.current_memory_dict[uuid_name].append(message) | ||||||
| 
 | 
 | ||||||
|         self.save(self.kb_root_path) |         self.save(self.kb_root_path) | ||||||
|         self.save_new_to_vs([message]) |         self.save_new_to_vs([message]) | ||||||
| 
 | 
 | ||||||
|     def extend(self, memory: Memory): |     # def extend(self, memory: Memory): | ||||||
|         self.recall_memory.extend(memory) |     #     self.recall_memory.extend(memory) | ||||||
|         self.current_memory.extend(self.recall_memory.filter_by_role_type(["summary"])) |     #     self.current_memory.extend(self.recall_memory.filter_by_role_type(["summary"])) | ||||||
|         self.summary_memory.extend(self.recall_memory.select_by_role_type(["summary"])) |     #     self.summary_memory.extend(self.recall_memory.select_by_role_type(["summary"])) | ||||||
|         self.save(self.kb_root_path) |     #     self.save(self.kb_root_path) | ||||||
|         self.save_new_to_vs(memory.messages) |     #     self.save_new_to_vs(memory.messages) | ||||||
| 
 | 
 | ||||||
|     def save(self, save_dir: str = "./"): |     def save(self, save_dir: str = "./"): | ||||||
|         file_path = os.path.join(save_dir, f"{self.user_name}/{self.unique_name}/{self.memory_type}/converation.jsonl") |         file_path = os.path.join(save_dir, f"{self.user_name}/{self.unique_name}/{self.memory_type}/converation.jsonl") | ||||||
|         memory_messages = self.recall_memory.dict() |         uuid_name = "_".join([self.user_name, self.unique_name, self.memory_type]) | ||||||
|  | 
 | ||||||
|  |         memory_messages = self.recall_memory_dict[uuid_name].dict() | ||||||
|         memory_messages = {k: [ |         memory_messages = {k: [ | ||||||
|                 {kkk: vvv for kkk, vvv in vv.items() if kkk in self.save_message_keys} |                 {kkk: vvv for kkk, vvv in vv.items() if kkk in self.save_message_keys} | ||||||
|                 for vv in v ]  |                 for vv in v ]  | ||||||
| @ -291,18 +312,28 @@ class LocalMemoryManager(BaseMemoryManager): | |||||||
| 
 | 
 | ||||||
|     def load(self, load_dir: str = "./") -> Memory: |     def load(self, load_dir: str = "./") -> Memory: | ||||||
|         file_path = os.path.join(load_dir, f"{self.user_name}/{self.unique_name}/{self.memory_type}/converation.jsonl") |         file_path = os.path.join(load_dir, f"{self.user_name}/{self.unique_name}/{self.memory_type}/converation.jsonl") | ||||||
|  |         uuid_name = "_".join([self.user_name, self.unique_name, self.memory_type]) | ||||||
| 
 | 
 | ||||||
|         if os.path.exists(file_path): |         if os.path.exists(file_path): | ||||||
|             self.recall_memory = Memory(**read_json_file(file_path)) |             # self.recall_memory = Memory(**read_json_file(file_path)) | ||||||
|             self.current_memory = Memory(messages=self.recall_memory.filter_by_role_type(["summary"])) |             # self.current_memory = Memory(messages=self.recall_memory.filter_by_role_type(["summary"])) | ||||||
|             self.summary_memory = Memory(messages=self.recall_memory.select_by_role_type(["summary"])) |             # self.summary_memory = Memory(messages=self.recall_memory.select_by_role_type(["summary"])) | ||||||
|  | 
 | ||||||
|  |             recall_memory = Memory(**read_json_file(file_path)) | ||||||
|  |             self.recall_memory_dict[uuid_name] = recall_memory | ||||||
|  |             self.current_memory_dict[uuid_name] = Memory(messages=recall_memory.filter_by_role_type(["summary"])) | ||||||
|  |             self.summary_memory_dict[uuid_name] = Memory(messages=recall_memory.select_by_role_type(["summary"])) | ||||||
|  |         else: | ||||||
|  |             self.recall_memory_dict[uuid_name] = Memory(messages=[]) | ||||||
|  |             self.current_memory_dict[uuid_name] = Memory(messages=[]) | ||||||
|  |             self.summary_memory_dict[uuid_name] = Memory(messages=[]) | ||||||
| 
 | 
 | ||||||
|     def save_new_to_vs(self, messages: List[Message]): |     def save_new_to_vs(self, messages: List[Message]): | ||||||
|         if self.embed_config: |         if self.embed_config: | ||||||
|             vb_name = f"{self.user_name}/{self.unique_name}/{self.memory_type}" |             vb_name = f"{self.user_name}/{self.unique_name}/{self.memory_type}" | ||||||
|             # default to faiss, todo: add new vstype |             # default to faiss, todo: add new vstype | ||||||
|             vb = KBServiceFactory.get_service(vb_name, "faiss", self.embed_config, self.kb_root_path) |             vb = KBServiceFactory.get_service(vb_name, "faiss", self.embed_config, self.kb_root_path) | ||||||
|             embeddings = load_embeddings_from_path(self.embed_config.embed_model_path, self.embed_config.model_device,) |             embeddings = load_embeddings_from_path(self.embed_config.embed_model_path, self.embed_config.model_device, self.embed_config.langchain_embeddings) | ||||||
|             messages = [ |             messages = [ | ||||||
|                     {k: v for k, v in m.dict().items() if k in self.save_message_keys} |                     {k: v for k, v in m.dict().items() if k in self.save_message_keys} | ||||||
|                     for m in messages]  |                     for m in messages]  | ||||||
| @ -311,23 +342,26 @@ class LocalMemoryManager(BaseMemoryManager): | |||||||
|             vb.do_add_doc(docs, embeddings) |             vb.do_add_doc(docs, embeddings) | ||||||
| 
 | 
 | ||||||
|     def save_to_vs(self): |         '''Persist the recall memory to the vector store; call only after load().''' | ||||||
|         vb_name = f"{self.user_name}/{self.unique_name}/{self.memory_type}" |         '''only after load''' | ||||||
|         # default to recreate a new vb |         if self.embed_config: | ||||||
|         vb = KBServiceFactory.get_service_by_name(vb_name, self.embed_config, self.kb_root_path) |             vb_name = f"{self.user_name}/{self.unique_name}/{self.memory_type}" | ||||||
|         if vb: |             uuid_name = "_".join([self.user_name, self.unique_name, self.memory_type]) | ||||||
|             status = vb.clear_vs() |             # default to recreate a new vb | ||||||
|         # create_kb(vb_name, "faiss", embed_model) |             vb = KBServiceFactory.get_service_by_name(vb_name, self.embed_config, self.kb_root_path) | ||||||
|  |             if vb: | ||||||
|  |                 status = vb.clear_vs() | ||||||
|  |             # create_kb(vb_name, "faiss", embed_model) | ||||||
| 
 | 
 | ||||||
|         # default to faiss, todo: add new vstype |             # default to faiss, todo: add new vstype | ||||||
|         vb = KBServiceFactory.get_service(vb_name, "faiss", self.embed_config, self.kb_root_path) |             vb = KBServiceFactory.get_service(vb_name, "faiss", self.embed_config, self.kb_root_path) | ||||||
|         embeddings = load_embeddings_from_path(self.embed_config.embed_model_path, self.embed_config.model_device,) |             embeddings = load_embeddings_from_path(self.embed_config.embed_model_path, self.embed_config.model_device, self.embed_config.langchain_embeddings) | ||||||
|         messages = self.recall_memory.dict() |             messages = self.recall_memory_dict[uuid_name].dict() | ||||||
|         messages = [ |             messages = [ | ||||||
|                 {kkk: vvv for kkk, vvv in vv.items() if kkk in self.save_message_keys} |                     {kkk: vvv for kkk, vvv in vv.items() if kkk in self.save_message_keys} | ||||||
|                 for k, v in messages.items() for vv in v]  |                     for k, v in messages.items() for vv in v]  | ||||||
|         docs = [{"page_content": m["step_content"] or m["role_content"] or m["input_query"] or m["origin_query"], "metadata": m} for m in messages] |             docs = [{"page_content": m["step_content"] or m["role_content"] or m["input_query"] or m["origin_query"], "metadata": m} for m in messages] | ||||||
|         docs = [Document(**doc) for doc in docs] |             docs = [Document(**doc) for doc in docs] | ||||||
|         vb.do_add_doc(docs, embeddings) |             vb.do_add_doc(docs, embeddings) | ||||||
| 
 | 
 | ||||||
|     # def load_from_vs(self, embed_model=EMBEDDING_MODEL) -> Memory: |     # def load_from_vs(self, embed_model=EMBEDDING_MODEL) -> Memory: | ||||||
|     #     vb_name = f"{self.user_name}/{self.unique_name}/{self.memory_type}" |     #     vb_name = f"{self.user_name}/{self.unique_name}/{self.memory_type}" | ||||||
| @ -338,7 +372,12 @@ class LocalMemoryManager(BaseMemoryManager): | |||||||
|     #     docs =  vb.get_all_documents() |     #     docs =  vb.get_all_documents() | ||||||
|     #     print(docs) |     #     print(docs) | ||||||
| 
 | 
 | ||||||
|     def router_retrieval(self, text: str=None, datetime: str = None, n=5, top_k=5, retrieval_type: str = "embedding", **kwargs) -> List[Message]: |     def get_memory_pool(self, user_name: str): | ||||||
|  |         self.check_user_name(user_name) | ||||||
|  |         uuid_name = "_".join([self.user_name, self.unique_name, self.memory_type]) | ||||||
|  |         return self.recall_memory_dict[uuid_name] | ||||||
|  | 
 | ||||||
|  |     def router_retrieval(self, user_name: str = "default", text: str=None, datetime: str = None, n=5, top_k=5, retrieval_type: str = "embedding", **kwargs) -> List[Message]: | ||||||
|         retrieval_func_dict = { |         retrieval_func_dict = { | ||||||
|             "embedding": self.embedding_retrieval, "text": self.text_retrieval, "datetime": self.datetime_retrieval |             "embedding": self.embedding_retrieval, "text": self.text_retrieval, "datetime": self.datetime_retrieval | ||||||
|             } |             } | ||||||
| @ -356,20 +395,22 @@ class LocalMemoryManager(BaseMemoryManager): | |||||||
|         #  |         #  | ||||||
|         return retrieval_func(**params) |         return retrieval_func(**params) | ||||||
|          |          | ||||||
|     def embedding_retrieval(self, text: str, top_k=1, score_threshold=1.0, **kwargs) -> List[Message]: |     def embedding_retrieval(self, text: str, top_k=1, score_threshold=1.0, user_name: str = "default", **kwargs) -> List[Message]: | ||||||
|         if text is None: return [] |         if text is None: return [] | ||||||
|         vb_name = f"{self.user_name}/{self.unique_name}/{self.memory_type}" |         vb_name = f"{user_name}/{self.unique_name}/{self.memory_type}" | ||||||
|         vb = KBServiceFactory.get_service(vb_name, "faiss", self.embed_config, self.kb_root_path) |         vb = KBServiceFactory.get_service(vb_name, "faiss", self.embed_config, self.kb_root_path) | ||||||
|         docs = vb.search_docs(text, top_k=top_k, score_threshold=score_threshold) |         docs = vb.search_docs(text, top_k=top_k, score_threshold=score_threshold) | ||||||
|         return [Message(**doc.metadata) for doc, score in docs] |         return [Message(**doc.metadata) for doc, score in docs] | ||||||
|      |      | ||||||
|     def text_retrieval(self, text: str, **kwargs)  -> List[Message]: |     def text_retrieval(self, text: str, user_name: str = "default", **kwargs)  -> List[Message]: | ||||||
|         if text is None: return [] |         if text is None: return [] | ||||||
|         return self._text_retrieval_from_cache(self.recall_memory.messages, text, score_threshold=0.3, topK=5, **kwargs) |         uuid_name = "_".join([user_name, self.unique_name, self.memory_type]) | ||||||
|  |         return self._text_retrieval_from_cache(self.recall_memory_dict[uuid_name].messages, text, score_threshold=0.3, topK=5, **kwargs) | ||||||
| 
 | 
 | ||||||
|     def datetime_retrieval(self, datetime: str, text: str = None, n: int = 5, **kwargs) -> List[Message]: |     def datetime_retrieval(self, datetime: str, text: str = None, n: int = 5, user_name: str = "default", **kwargs) -> List[Message]: | ||||||
|         if datetime is None: return [] |         if datetime is None: return [] | ||||||
|         return self._datetime_retrieval_from_cache(self.recall_memory.messages, datetime, text, n, **kwargs) |         uuid_name = "_".join([user_name, self.unique_name, self.memory_type]) | ||||||
|  |         return self._datetime_retrieval_from_cache(self.recall_memory_dict[uuid_name].messages, datetime, text, n, **kwargs) | ||||||
|      |      | ||||||
|     def _text_retrieval_from_cache(self, messages: List[Message], text: str = None, score_threshold=0.3, topK=5, tag_topK=5, **kwargs) -> List[Message]: |     def _text_retrieval_from_cache(self, messages: List[Message], text: str = None, score_threshold=0.3, topK=5, tag_topK=5, **kwargs) -> List[Message]: | ||||||
|         keywords = extract_tags(text, topK=tag_topK) |         keywords = extract_tags(text, topK=tag_topK) | ||||||
| @ -428,3 +469,17 @@ class LocalMemoryManager(BaseMemoryManager): | |||||||
|         summary_message.parsed_output_list.append({"summary": content}) |         summary_message.parsed_output_list.append({"summary": content}) | ||||||
|         newest_messages.insert(0, summary_message) |         newest_messages.insert(0, summary_message) | ||||||
|         return newest_messages |         return newest_messages | ||||||
|  |      | ||||||
|  |     def check_user_name(self, user_name: str): | ||||||
|  |         # logger.debug(f"self.user_name is {self.user_name}") | ||||||
|  |         if user_name != self.user_name: | ||||||
|  |             self.user_name = user_name | ||||||
|  |             self.init_vb() | ||||||
|  | 
 | ||||||
|  |         uuid_name = "_".join([self.user_name, self.unique_name, self.memory_type]) | ||||||
|  |         if uuid_name not in self.recall_memory_dict: | ||||||
|  |             self.recall_memory_dict[uuid_name] = Memory(messages=[]) | ||||||
|  |             self.current_memory_dict[uuid_name] = Memory(messages=[]) | ||||||
|  |             self.summary_memory_dict[uuid_name] = Memory(messages=[]) | ||||||
|  | 
 | ||||||
|  |         # logger.debug(f"self.user_name is {self.user_name}") | ||||||
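The per-instance `Memory` objects become dicts keyed by `user_name_unique_name_memory_type`, and `check_user_name` lazily creates the three buckets (re-initializing the vector base) whenever a message arrives for a new user. A toy reconstruction of that keying, with a stub `Memory` standing in for coagent's class:

```python
class Memory:  # stub for coagent.connector.schema.Memory
    def __init__(self, messages=None):
        self.messages = messages or []

def memory_key(user_name: str, unique_name: str = "default", memory_type: str = "recall") -> str:
    return "_".join([user_name, unique_name, memory_type])

recall_memory_dict: dict = {}
for user in ["alice", "bob"]:  # hypothetical users
    recall_memory_dict.setdefault(memory_key(user), Memory())

assert len(recall_memory_dict) == 2  # each user gets an isolated memory pool
```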
| @ -1,16 +1,19 @@ | |||||||
| import re, traceback, uuid, copy, json, os | import re, traceback, uuid, copy, json, os | ||||||
|  | from typing import Union | ||||||
| from loguru import logger | from loguru import logger | ||||||
| 
 | 
 | ||||||
|  | from langchain.schema import BaseRetriever | ||||||
| 
 | 
 | ||||||
| # from configs.server_config import SANDBOX_SERVER |  | ||||||
| # from configs.model_config import JUPYTER_WORK_PATH |  | ||||||
| from coagent.connector.schema import ( | from coagent.connector.schema import ( | ||||||
|     Memory, Role, Message, ActionStatus, CodeDoc, Doc, LogVerboseEnum |     Memory, Role, Message, ActionStatus, CodeDoc, Doc, LogVerboseEnum | ||||||
| ) | ) | ||||||
|  | from coagent.retrieval.base_retrieval import IMRertrieval | ||||||
| from coagent.connector.memory_manager import BaseMemoryManager | from coagent.connector.memory_manager import BaseMemoryManager | ||||||
| from coagent.tools import DDGSTool, DocRetrieval, CodeRetrieval | from coagent.tools import DDGSTool, DocRetrieval, CodeRetrieval | ||||||
| from coagent.sandbox import PyCodeBox, CodeBoxResponse | from coagent.sandbox import PyCodeBox, CodeBoxResponse | ||||||
| from coagent.llm_models.llm_config import LLMConfig, EmbedConfig | from coagent.llm_models.llm_config import LLMConfig, EmbedConfig | ||||||
|  | from coagent.base_configs.env_config import JUPYTER_WORK_PATH | ||||||
|  | 
 | ||||||
| from .utils import parse_dict_to_dict, parse_text_to_dict | from .utils import parse_dict_to_dict, parse_text_to_dict | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @ -19,10 +22,13 @@ class MessageUtils: | |||||||
|             self,  |             self,  | ||||||
|             role: Role = None, |             role: Role = None, | ||||||
|             sandbox_server: dict = {}, |             sandbox_server: dict = {}, | ||||||
|             jupyter_work_path: str = "./", |             jupyter_work_path: str = JUPYTER_WORK_PATH, | ||||||
|             embed_config: EmbedConfig = None, |             embed_config: EmbedConfig = None, | ||||||
|             llm_config: LLMConfig = None, |             llm_config: LLMConfig = None, | ||||||
|             kb_root_path: str = "", |             kb_root_path: str = "", | ||||||
|  |             doc_retrieval: Union[BaseRetriever, IMRertrieval] = None, | ||||||
|  |             code_retrieval: IMRertrieval = None, | ||||||
|  |             search_retrieval: IMRertrieval = None, | ||||||
|             log_verbose: str = "0" |             log_verbose: str = "0" | ||||||
|         ) -> None: |         ) -> None: | ||||||
|         self.role = role |         self.role = role | ||||||
| @ -31,6 +37,9 @@ class MessageUtils: | |||||||
|         self.embed_config = embed_config |         self.embed_config = embed_config | ||||||
|         self.llm_config = llm_config |         self.llm_config = llm_config | ||||||
|         self.kb_root_path = kb_root_path |         self.kb_root_path = kb_root_path | ||||||
|  |         self.doc_retrieval = doc_retrieval | ||||||
|  |         self.code_retrieval = code_retrieval | ||||||
|  |         self.search_retrieval = search_retrieval | ||||||
|         self.codebox = PyCodeBox( |         self.codebox = PyCodeBox( | ||||||
|                     remote_url=self.sandbox_server.get("url", "http://127.0.0.1:5050"), |                     remote_url=self.sandbox_server.get("url", "http://127.0.0.1:5050"), | ||||||
|                     remote_ip=self.sandbox_server.get("host", "http://127.0.0.1"), |                     remote_ip=self.sandbox_server.get("host", "http://127.0.0.1"), | ||||||
| @ -44,6 +53,7 @@ class MessageUtils: | |||||||
|         self.log_verbose = os.environ.get("log_verbose", "0") or log_verbose |         self.log_verbose = os.environ.get("log_verbose", "0") or log_verbose | ||||||
|      |      | ||||||
|     def inherit_extrainfo(self, input_message: Message, output_message: Message): |     def inherit_extrainfo(self, input_message: Message, output_message: Message): | ||||||
|  |         output_message.user_name = input_message.user_name | ||||||
|         output_message.db_docs = input_message.db_docs |         output_message.db_docs = input_message.db_docs | ||||||
|         output_message.search_docs = input_message.search_docs |         output_message.search_docs = input_message.search_docs | ||||||
|         output_message.code_docs = input_message.code_docs |         output_message.code_docs = input_message.code_docs | ||||||
| @ -116,18 +126,45 @@ class MessageUtils: | |||||||
|         knowledge_basename = message.doc_engine_name |         knowledge_basename = message.doc_engine_name | ||||||
|         top_k = message.top_k |         top_k = message.top_k | ||||||
|         score_threshold = message.score_threshold |         score_threshold = message.score_threshold | ||||||
|         if knowledge_basename: |         if self.doc_retrieval: | ||||||
|             docs = DocRetrieval.run(query, knowledge_basename, top_k, score_threshold, self.embed_config, self.kb_root_path) |             if isinstance(self.doc_retrieval, BaseRetriever): | ||||||
|  |                 docs = self.doc_retrieval.get_relevant_documents(query) | ||||||
|  |             else: | ||||||
|  |                 # docs = self.doc_retrieval.run(query, search_top=message.top_k, score_threshold=message.score_threshold,) | ||||||
|  |                 docs = self.doc_retrieval.run(query) | ||||||
|  |             docs = [ | ||||||
|  |                 {"index": idx, "snippet": doc.page_content, "title": doc.metadata.get("title_prefix", ""), "link": doc.metadata.get("url", "")} | ||||||
|  |                 for idx, doc in enumerate(docs) | ||||||
|  |             ] | ||||||
|             message.db_docs = [Doc(**doc) for doc in docs] |             message.db_docs = [Doc(**doc) for doc in docs] | ||||||
|  |         else: | ||||||
|  |             if knowledge_basename: | ||||||
|  |                 docs = DocRetrieval.run(query, knowledge_basename, top_k, score_threshold, self.embed_config, self.kb_root_path) | ||||||
|  |                 message.db_docs = [Doc(**doc) for doc in docs] | ||||||
|         return message |         return message | ||||||
|      |      | ||||||
|     def get_code_retrieval(self, message: Message) -> Message: |     def get_code_retrieval(self, message: Message) -> Message: | ||||||
|         query = message.input_query |         query = message.role_content | ||||||
|         code_engine_name = message.code_engine_name |         code_engine_name = message.code_engine_name | ||||||
|         history_node_list = message.history_node_list |         history_node_list = message.history_node_list | ||||||
|         code_docs = CodeRetrieval.run(code_engine_name, query, code_limit=message.top_k, history_node_list=history_node_list, search_type=message.cb_search_type, | 
 | ||||||
|                                       llm_config=self.llm_config, embed_config=self.embed_config,) |         use_nh = message.use_nh | ||||||
|  |         local_graph_path = message.local_graph_path | ||||||
|  | 
 | ||||||
|  |         if self.code_retrieval: | ||||||
|  |             code_docs = self.code_retrieval.run( | ||||||
|  |                 query, history_node_list=history_node_list, search_type=message.cb_search_type, | ||||||
|  |                 code_limit=1 | ||||||
|  |             ) | ||||||
|  |         else: | ||||||
|  |             code_docs = CodeRetrieval.run(code_engine_name, query, code_limit=message.top_k, history_node_list=history_node_list, search_type=message.cb_search_type, | ||||||
|  |                                       llm_config=self.llm_config, embed_config=self.embed_config, | ||||||
|  |                                       use_nh=use_nh, local_graph_path=local_graph_path) | ||||||
|  |          | ||||||
|         message.code_docs = [CodeDoc(**doc) for doc in code_docs]  |         message.code_docs = [CodeDoc(**doc) for doc in code_docs]  | ||||||
|  | 
 | ||||||
|  |         # related_nodes = [doc.get_related_node() for idx, doc in enumerate(message.code_docs) if idx==0], | ||||||
|  |         # history_node_list.extend([node[0] for node in related_nodes])  | ||||||
|         return message |         return message | ||||||
|      |      | ||||||
|     def get_tool_retrieval(self, message: Message) -> Message: |     def get_tool_retrieval(self, message: Message) -> Message: | ||||||
| @ -160,6 +197,7 @@ class MessageUtils: | |||||||
|                     if code_answer.code_exe_type == "error" else f"The return information after executing the above code is {code_answer.code_exe_response}.\n" |                     if code_answer.code_exe_type == "error" else f"The return information after executing the above code is {code_answer.code_exe_response}.\n" | ||||||
|          |          | ||||||
|         observation_message = Message( |         observation_message = Message( | ||||||
|  |                 user_name=message.user_name, | ||||||
|                 role_name="observation", |                 role_name="observation", | ||||||
|                 role_type="function", #self.role.role_type, |                 role_type="function", #self.role.role_type, | ||||||
|                 role_content="", |                 role_content="", | ||||||
| @ -190,6 +228,7 @@ class MessageUtils: | |||||||
|     def tool_step(self, message: Message) -> Message: |     def tool_step(self, message: Message) -> Message: | ||||||
|         '''execute tool''' |         '''execute tool''' | ||||||
|         observation_message = Message( |         observation_message = Message( | ||||||
|  |                 user_name=message.user_name, | ||||||
|                 role_name="observation", |                 role_name="observation", | ||||||
|                 role_type="function", #self.role.role_type, |                 role_type="function", #self.role.role_type, | ||||||
|                 role_content="\n**Observation:** there is no tool can execute\n", |                 role_content="\n**Observation:** there is no tool can execute\n", | ||||||
| @ -226,7 +265,7 @@ class MessageUtils: | |||||||
|         return message, observation_message |         return message, observation_message | ||||||
| 
 | 
 | ||||||
|     def parser(self, message: Message) -> Message: |     def parser(self, message: Message) -> Message: | ||||||
|         '''''' |         '''parse llm output into dict''' | ||||||
|         content = message.role_content |         content = message.role_content | ||||||
|         # parse start |         # parse start | ||||||
|         parsed_dict = parse_text_to_dict(content) |         parsed_dict = parse_text_to_dict(content) | ||||||
|  | |||||||
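The doc-retrieval branch above accepts two shapes of retriever: a langchain `BaseRetriever`, queried via `get_relevant_documents`, or any object exposing `run(query)` that returns langchain `Document`s. A minimal sketch of the second shape (the `KeywordRetriever` class here is hypothetical, not part of this repo):

```python
from langchain.schema import Document

class KeywordRetriever:
    """Hypothetical stand-in: anything with run(query) -> List[Document] works."""
    def __init__(self, docs):
        self.docs = docs

    def run(self, query: str):
        # naive keyword match in place of a real vector search
        return [d for d in self.docs if query.lower() in d.page_content.lower()]

retriever = KeywordRetriever([
    Document(page_content="AntFlow integration notes",
             metadata={"title_prefix": "antflow", "url": ""}),
])
# get_doc_retrieval then flattens each hit into {"index", "snippet", "title", "link"}
hits = retriever.run("antflow")
print([{"index": i, "snippet": d.page_content,
        "title": d.metadata.get("title_prefix", ""), "link": d.metadata.get("url", "")}
       for i, d in enumerate(hits)])
```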
| @ -5,6 +5,8 @@ import importlib | |||||||
| import copy | import copy | ||||||
| from loguru import logger | from loguru import logger | ||||||
| 
 | 
 | ||||||
|  | from langchain.schema import BaseRetriever | ||||||
|  | 
 | ||||||
| from coagent.connector.agents import BaseAgent | from coagent.connector.agents import BaseAgent | ||||||
| from coagent.connector.chains import BaseChain | from coagent.connector.chains import BaseChain | ||||||
| from coagent.connector.schema import ( | from coagent.connector.schema import ( | ||||||
| @ -18,9 +20,6 @@ from coagent.connector.message_process import MessageUtils | |||||||
| from coagent.llm_models.llm_config import EmbedConfig, LLMConfig | from coagent.llm_models.llm_config import EmbedConfig, LLMConfig | ||||||
| from coagent.base_configs.env_config import JUPYTER_WORK_PATH, KB_ROOT_PATH | from coagent.base_configs.env_config import JUPYTER_WORK_PATH, KB_ROOT_PATH | ||||||
| 
 | 
 | ||||||
| # from configs.model_config import JUPYTER_WORK_PATH, KB_ROOT_PATH |  | ||||||
| # from configs.server_config import SANDBOX_SERVER |  | ||||||
| 
 |  | ||||||
| 
 | 
 | ||||||
| role_configs = load_role_configs(AGETN_CONFIGS) | role_configs = load_role_configs(AGETN_CONFIGS) | ||||||
| chain_configs = load_chain_configs(CHAIN_CONFIGS) | chain_configs = load_chain_configs(CHAIN_CONFIGS) | ||||||
| @ -39,20 +38,24 @@ class BasePhase: | |||||||
|             kb_root_path: str = KB_ROOT_PATH, |             kb_root_path: str = KB_ROOT_PATH, | ||||||
|             jupyter_work_path: str = JUPYTER_WORK_PATH, |             jupyter_work_path: str = JUPYTER_WORK_PATH, | ||||||
|             sandbox_server: dict = {}, |             sandbox_server: dict = {}, | ||||||
|             embed_config: EmbedConfig = EmbedConfig(), |             embed_config: EmbedConfig = None, | ||||||
|             llm_config: LLMConfig = LLMConfig(), |             llm_config: LLMConfig = None, | ||||||
|             task: Task = None, |             task: Task = None, | ||||||
|             base_phase_config: Union[dict, str] = PHASE_CONFIGS, |             base_phase_config: Union[dict, str] = PHASE_CONFIGS, | ||||||
|             base_chain_config: Union[dict, str] = CHAIN_CONFIGS, |             base_chain_config: Union[dict, str] = CHAIN_CONFIGS, | ||||||
|             base_role_config: Union[dict, str] = AGETN_CONFIGS, |             base_role_config: Union[dict, str] = AGETN_CONFIGS, | ||||||
|  |             chains: List[BaseChain] = [], | ||||||
|  |             doc_retrieval: Union[BaseRetriever] = None, | ||||||
|  |             code_retrieval = None, | ||||||
|  |             search_retrieval = None, | ||||||
|             log_verbose: str = "0" |             log_verbose: str = "0" | ||||||
|             ) -> None: |             ) -> None: | ||||||
|         #  |         #  | ||||||
|         self.phase_name = phase_name |         self.phase_name = phase_name | ||||||
|         self.do_summary = False |         self.do_summary = False | ||||||
|         self.do_search = False |         self.do_search = search_retrieval is not None | ||||||
|         self.do_code_retrieval = False |         self.do_code_retrieval = code_retrieval is not None | ||||||
|         self.do_doc_retrieval = False |         self.do_doc_retrieval = doc_retrieval is not None | ||||||
|         self.do_tool_retrieval = False |         self.do_tool_retrieval = False | ||||||
|         # memory_pool dont have specific order |         # memory_pool dont have specific order | ||||||
|         # self.memory_pool = Memory(messages=[]) |         # self.memory_pool = Memory(messages=[]) | ||||||
| @ -62,12 +65,15 @@ class BasePhase: | |||||||
|         self.jupyter_work_path = jupyter_work_path |         self.jupyter_work_path = jupyter_work_path | ||||||
|         self.kb_root_path = kb_root_path |         self.kb_root_path = kb_root_path | ||||||
|         self.log_verbose = max(os.environ.get("log_verbose", "0"), log_verbose) |         self.log_verbose = max(os.environ.get("log_verbose", "0"), log_verbose) | ||||||
|          |         # TODO: pass the retrievers through | ||||||
|         self.message_utils = MessageUtils(None, sandbox_server, jupyter_work_path, embed_config, llm_config, kb_root_path, log_verbose) |         self.doc_retrieval = doc_retrieval | ||||||
|  |         self.code_retrieval = code_retrieval | ||||||
|  |         self.search_retrieval = search_retrieval | ||||||
|  |         self.message_utils = MessageUtils(None, sandbox_server, jupyter_work_path, embed_config, llm_config, kb_root_path, doc_retrieval, code_retrieval, search_retrieval, log_verbose) | ||||||
|         self.global_memory = Memory(messages=[]) |         self.global_memory = Memory(messages=[]) | ||||||
|         self.phase_memory: List[Memory] = [] |         self.phase_memory: List[Memory] = [] | ||||||
|         # according phase name to init the phase contains |         # according phase name to init the phase contains | ||||||
|         self.chains: List[BaseChain] = self.init_chains( |         self.chains: List[BaseChain] = chains if chains else self.init_chains( | ||||||
|             phase_name, |             phase_name, | ||||||
|             phase_config, |             phase_config, | ||||||
|             task=task,  |             task=task,  | ||||||
| @ -90,7 +96,9 @@ class BasePhase: | |||||||
|             kb_root_path=kb_root_path |             kb_root_path=kb_root_path | ||||||
|             ) |             ) | ||||||
| 
 | 
 | ||||||
|     def astep(self, query: Message, history: Memory = None) -> Tuple[Message, Memory]: |     def astep(self, query: Message, history: Memory = None, reinit_memory=False) -> Tuple[Message, Memory]: | ||||||
|  |         if reinit_memory: | ||||||
|  |             self.memory_manager.re_init(reinit_memory) | ||||||
|         self.memory_manager.append(query) |         self.memory_manager.append(query) | ||||||
|         summary_message = None |         summary_message = None | ||||||
|         chain_message = Memory(messages=[]) |         chain_message = Memory(messages=[]) | ||||||
| @ -139,8 +147,8 @@ class BasePhase: | |||||||
|         message.role_name = self.phase_name |         message.role_name = self.phase_name | ||||||
|         yield message, local_phase_memory |         yield message, local_phase_memory | ||||||
| 
 | 
 | ||||||
|     def step(self, query: Message, history: Memory = None) -> Tuple[Message, Memory]: |     def step(self, query: Message, history: Memory = None, reinit_memory=False) -> Tuple[Message, Memory]: | ||||||
|         for message, local_phase_memory in self.astep(query, history=history): |         for message, local_phase_memory in self.astep(query, history=history, reinit_memory=reinit_memory): | ||||||
|             pass |             pass | ||||||
|         return message, local_phase_memory |         return message, local_phase_memory | ||||||
| 
 | 
 | ||||||
| @ -194,6 +202,9 @@ class BasePhase: | |||||||
|                     sandbox_server=self.sandbox_server, |                     sandbox_server=self.sandbox_server, | ||||||
|                     jupyter_work_path=self.jupyter_work_path, |                     jupyter_work_path=self.jupyter_work_path, | ||||||
|                     kb_root_path=self.kb_root_path, |                     kb_root_path=self.kb_root_path, | ||||||
|  |                     doc_retrieval=self.doc_retrieval, | ||||||
|  |                     code_retrieval=self.code_retrieval, | ||||||
|  |                     search_retrieval=self.search_retrieval, | ||||||
|                     log_verbose=self.log_verbose |                     log_verbose=self.log_verbose | ||||||
|                 )  |                 )  | ||||||
|                 if agent_config.role.agent_type == "SelectorAgent": |                 if agent_config.role.agent_type == "SelectorAgent": | ||||||
| @ -205,7 +216,7 @@ class BasePhase: | |||||||
|                         group_base_agent = baseAgent( |                         group_base_agent = baseAgent( | ||||||
|                             role=group_agent_config.role,  |                             role=group_agent_config.role,  | ||||||
|                             prompt_config = group_agent_config.prompt_config, |                             prompt_config = group_agent_config.prompt_config, | ||||||
|                             prompt_manager_type=agent_config.prompt_manager_type, |                             prompt_manager_type=group_agent_config.prompt_manager_type, | ||||||
|                             task = task, |                             task = task, | ||||||
|                             memory = memory, |                             memory = memory, | ||||||
|                             chat_turn=group_agent_config.chat_turn, |                             chat_turn=group_agent_config.chat_turn, | ||||||
| @ -216,6 +227,9 @@ class BasePhase: | |||||||
|                             sandbox_server=self.sandbox_server, |                             sandbox_server=self.sandbox_server, | ||||||
|                             jupyter_work_path=self.jupyter_work_path, |                             jupyter_work_path=self.jupyter_work_path, | ||||||
|                             kb_root_path=self.kb_root_path, |                             kb_root_path=self.kb_root_path, | ||||||
|  |                             doc_retrieval=self.doc_retrieval, | ||||||
|  |                             code_retrieval=self.code_retrieval, | ||||||
|  |                             search_retrieval=self.search_retrieval, | ||||||
|                             log_verbose=self.log_verbose |                             log_verbose=self.log_verbose | ||||||
|                         )  |                         )  | ||||||
|                         base_agent.group_agents.append(group_base_agent) |                         base_agent.group_agents.append(group_base_agent) | ||||||
| @ -223,13 +237,16 @@ class BasePhase: | |||||||
|                 agents.append(base_agent) |                 agents.append(base_agent) | ||||||
|              |              | ||||||
|             chain_instance = BaseChain( |             chain_instance = BaseChain( | ||||||
|                 agents, chain_config.chat_turn,  |                 chain_config, | ||||||
|                 do_checker=chain_configs[chain_name].do_checker,  |                 agents,  | ||||||
|                 jupyter_work_path=self.jupyter_work_path, |                 jupyter_work_path=self.jupyter_work_path, | ||||||
|                 sandbox_server=self.sandbox_server, |                 sandbox_server=self.sandbox_server, | ||||||
|                 embed_config=self.embed_config, |                 embed_config=self.embed_config, | ||||||
|                 llm_config=self.llm_config, |                 llm_config=self.llm_config, | ||||||
|                 kb_root_path=self.kb_root_path, |                 kb_root_path=self.kb_root_path, | ||||||
|  |                 doc_retrieval=self.doc_retrieval, | ||||||
|  |                 code_retrieval=self.code_retrieval, | ||||||
|  |                 search_retrieval=self.search_retrieval, | ||||||
|                 log_verbose=self.log_verbose |                 log_verbose=self.log_verbose | ||||||
|                 ) |                 ) | ||||||
|             chains.append(chain_instance) |             chains.append(chain_instance) | ||||||
|  | |||||||
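End to end, the new constructor arguments can be exercised roughly as below: passing a retriever flips the matching `do_*` flag, and `step(..., reinit_memory=True)` re-initializes the memory manager before the query is appended. The phase name, the configs, and the retriever instance are placeholders, not values verified against this repo:

```python
# Hypothetical wiring; llm_config and embed_config are built elsewhere.
phase = BasePhase(
    "searchChatPhase",                  # placeholder phase name
    llm_config=llm_config,
    embed_config=embed_config,
    doc_retrieval=my_doc_retriever,     # BaseRetriever, or an object with .run(query)
)
assert phase.do_doc_retrieval           # set because doc_retrieval is not None

query = Message(role_name="human", role_type="user",
                role_content="how do I enable the local graph?")
output_message, phase_memory = phase.step(query, reinit_memory=True)
```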
coagent/connector/prompt_manager/__init__.py (new file, +2 lines)
							| @ -0,0 +1,2 @@ | |||||||
|  | from .prompt_manager import PromptManager | ||||||
|  | from .extend_manager import * | ||||||
coagent/connector/prompt_manager/extend_manager.py (new file, +45 lines)
							| @ -0,0 +1,45 @@ | |||||||
|  | 
 | ||||||
|  | from coagent.connector.schema import Message | ||||||
|  | from .prompt_manager import PromptManager | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class Code2DocPM(PromptManager): | ||||||
|  |     def handle_code_snippet(self, **kwargs) -> str: | ||||||
|  |         if 'previous_agent_message' not in kwargs: | ||||||
|  |             return "" | ||||||
|  |         previous_agent_message: Message = kwargs['previous_agent_message'] | ||||||
|  |         code_snippet = previous_agent_message.customed_kargs.get("Code Snippet", "") | ||||||
|  |         current_vertex = previous_agent_message.customed_kargs.get("Current_Vertex", "") | ||||||
|  |         instruction = "A segment of code that contains the function or method to be documented.\n" | ||||||
|  |         return instruction + "\n" + f"name: {current_vertex}\n{code_snippet}" | ||||||
|  | 
 | ||||||
|  |     def handle_specific_objective(self, **kwargs) -> str: | ||||||
|  |         if 'previous_agent_message' not in kwargs: | ||||||
|  |             return "" | ||||||
|  |         previous_agent_message: Message = kwargs['previous_agent_message'] | ||||||
|  |         specific_objective = previous_agent_message.parsed_output.get("Code Path") | ||||||
|  | 
 | ||||||
|  |         instruction = "Provide the code path of the function or method you wish to document.\n" | ||||||
|  |         s = instruction + f"\n{specific_objective}" | ||||||
|  |         return s | ||||||
|  |      | ||||||
|  | 
 | ||||||
|  | class CodeRetrievalPM(PromptManager): | ||||||
|  |     def handle_code_snippet(self, **kwargs) -> str: | ||||||
|  |         if 'previous_agent_message' not in kwargs: | ||||||
|  |             return "" | ||||||
|  |         previous_agent_message: Message = kwargs['previous_agent_message'] | ||||||
|  |         code_snippet = previous_agent_message.customed_kargs.get("Code Snippet", "") | ||||||
|  |         current_vertex = previous_agent_message.customed_kargs.get("Current_Vertex", "") | ||||||
|  |         instruction = "the initial Code or objective that the user wanted to achieve" | ||||||
|  |         return instruction + "\n" + f"name: {current_vertex}\n{code_snippet}" | ||||||
|  | 
 | ||||||
|  |     def handle_retrieval_codes(self, **kwargs) -> str: | ||||||
|  |         if 'previous_agent_message' not in kwargs: | ||||||
|  |             return "" | ||||||
|  |         previous_agent_message: Message = kwargs['previous_agent_message'] | ||||||
|  |         Retrieval_Codes = previous_agent_message.customed_kargs["Retrieval_Codes"] | ||||||
|  |         Relative_vertex = previous_agent_message.customed_kargs["Relative_vertex"] | ||||||
|  |         instruction = "the initial Code or objective that the user wanted to achieve" | ||||||
|  |         s = instruction + "\n" + "\n".join([f"name: {vertex}\n{code}" for vertex, code in zip(Relative_vertex, Retrieval_Codes)]) | ||||||
|  |         return s | ||||||
coagent/connector/prompt_manager/prompt_manager.py (new file, +353 lines)
							| @ -0,0 +1,353 @@ | |||||||
|  | import random | ||||||
|  | from textwrap import dedent | ||||||
|  | import copy | ||||||
|  | from loguru import logger | ||||||
|  | 
 | ||||||
|  | from langchain.agents.tools import Tool | ||||||
|  | 
 | ||||||
|  | from coagent.connector.schema import Memory, Message | ||||||
|  | from coagent.connector.utils import extract_section, parse_section | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class PromptManager: | ||||||
|  |     def __init__(self, role_prompt="", prompt_config=None, monitored_agents=[], monitored_fields=[]): | ||||||
|  |         self.role_prompt = role_prompt | ||||||
|  |         self.monitored_agents = monitored_agents | ||||||
|  |         self.monitored_fields = monitored_fields | ||||||
|  |         self.field_handlers = {} | ||||||
|  |         self.context_handlers = {} | ||||||
|  |         self.field_order = []  # ordering of the regular (non-context) fields | ||||||
|  |         self.context_order = []  # ordering of the context fields, kept separately | ||||||
|  |         self.field_descriptions = {} | ||||||
|  |         self.omit_if_empty_flags = {} | ||||||
|  |         self.context_title = "### Context Data\n\n"   | ||||||
|  |          | ||||||
|  |         self.prompt_config = prompt_config | ||||||
|  |         if self.prompt_config: | ||||||
|  |             self.register_fields_from_config() | ||||||
|  |      | ||||||
|  |     def register_field(self, field_name, function=None, title=None, description=None, is_context=True, omit_if_empty=True): | ||||||
|  |         """ | ||||||
|  |         Register a new field and its handler function. | ||||||
|  |         Args: | ||||||
|  |             field_name (str): name of the field. | ||||||
|  |             function (callable): function that renders the field's data. | ||||||
|  |             title (str, optional): custom title for the field. | ||||||
|  |             description (str, optional): description of the field (may be a few sentences). | ||||||
|  |             is_context (bool, optional): whether this field is a context field. | ||||||
|  |             omit_if_empty (bool, optional): whether to omit the field when its data is empty. | ||||||
|  |         """ | ||||||
|  |         if not function: | ||||||
|  |             function = self.handle_custom_data | ||||||
|  |          | ||||||
|  |         # Store the custom title if provided and adjust the title prefix based on context | ||||||
|  |         title_prefix = "####" if is_context else "###" | ||||||
|  |         if title is not None: | ||||||
|  |             self.field_descriptions[field_name] = f"{title_prefix} {title}\n\n" | ||||||
|  |         elif description is not None: | ||||||
|  |             # If title is not provided but description is, use description as title | ||||||
|  |             self.field_descriptions[field_name] = f"{title_prefix} {field_name.replace('_', ' ').title()}\n\n{description}\n\n" | ||||||
|  |         else: | ||||||
|  |             # If neither title nor description is provided, use the field name as title | ||||||
|  |             self.field_descriptions[field_name] = f"{title_prefix} {field_name.replace('_', ' ').title()}\n\n" | ||||||
|  |              | ||||||
|  |         # Store the omit_if_empty flag for this field | ||||||
|  |         self.omit_if_empty_flags[field_name] = omit_if_empty | ||||||
|  |          | ||||||
|  |         if is_context and field_name != 'context_placeholder': | ||||||
|  |             self.context_handlers[field_name] = function | ||||||
|  |             self.context_order.append(field_name) | ||||||
|  |         else: | ||||||
|  |             self.field_handlers[field_name] = function | ||||||
|  |             self.field_order.append(field_name) | ||||||
|  |      | ||||||
|  |     def generate_full_prompt(self, **kwargs): | ||||||
|  |         full_prompt = [] | ||||||
|  |         context_prompts = []  # collects the rendered context sections | ||||||
|  |         is_pre_print = kwargs.get("is_pre_print", False) # force-print every prompt field, even the empty ones | ||||||
|  | 
 | ||||||
|  |         # render the context fields first | ||||||
|  |         for field_name in self.context_order: | ||||||
|  |             handler = self.context_handlers[field_name] | ||||||
|  |             processed_prompt = handler(**kwargs) | ||||||
|  |             # Check if the field should be omitted when empty | ||||||
|  |             if self.omit_if_empty_flags.get(field_name, False) and not processed_prompt and not is_pre_print: | ||||||
|  |                 continue  # Skip this field | ||||||
|  |             title_or_description = self.field_descriptions.get(field_name, f"#### {field_name.replace('_', ' ').title()}\n\n") | ||||||
|  |             context_prompts.append(title_or_description + processed_prompt + '\n\n') | ||||||
|  | 
 | ||||||
|  |         # render the regular fields, locating context_placeholder along the way | ||||||
|  |         for field_name in self.field_order: | ||||||
|  |             if field_name == 'context_placeholder': | ||||||
|  |                 # splice the context data in at the placeholder's position | ||||||
|  |                 full_prompt.append(self.context_title)  # heading for the context section | ||||||
|  |                 full_prompt.extend(context_prompts)  # the collected context content | ||||||
|  |             else: | ||||||
|  |                 handler = self.field_handlers[field_name] | ||||||
|  |                 processed_prompt = handler(**kwargs) | ||||||
|  |                 # Check if the field should be omitted when empty | ||||||
|  |                 if self.omit_if_empty_flags.get(field_name, False) and not processed_prompt and not is_pre_print: | ||||||
|  |                     continue  # Skip this field | ||||||
|  |                 title_or_description = self.field_descriptions.get(field_name, f"### {field_name.replace('_', ' ').title()}\n\n") | ||||||
|  |                 full_prompt.append(title_or_description + processed_prompt + '\n\n') | ||||||
|  | 
 | ||||||
|  |         # return the full prompt, stripping trailing blank lines | ||||||
|  |         return ''.join(full_prompt).rstrip('\n') | ||||||
|  |          | ||||||
|  |     def pre_print(self, **kwargs): | ||||||
|  |         kwargs.update({"is_pre_print": True}) | ||||||
|  |         prompt = self.generate_full_prompt(**kwargs) | ||||||
|  | 
 | ||||||
|  |         input_keys = parse_section(self.role_prompt, 'Response Output Format') | ||||||
|  |         llm_predict = "\n".join([f"**{k}:**" for k in input_keys]) | ||||||
|  |         return prompt + "\n\n" + "#"*19 + "\n<<<<LLM PREDICT>>>>\n" + "#"*19  + f"\n\n{llm_predict}\n" | ||||||
|  | 
 | ||||||
|  |     def handle_custom_data(self, **kwargs): | ||||||
|  |         return "" | ||||||
|  |      | ||||||
|  |     def handle_tool_data(self, **kwargs): | ||||||
|  |         if 'previous_agent_message' not in kwargs: | ||||||
|  |             return "" | ||||||
|  | 
 | ||||||
|  |         previous_agent_message = kwargs.get('previous_agent_message') | ||||||
|  |         tools: list[Tool] = previous_agent_message.tools | ||||||
|  |          | ||||||
|  |         if not tools: | ||||||
|  |             return "" | ||||||
|  |          | ||||||
|  |         tool_strings = [] | ||||||
|  |         for tool in tools: | ||||||
|  |             args_str = f'args: {str(tool.args)}' if tool.args_schema else "" | ||||||
|  |             tool_strings.append(f"{tool.name}: {tool.description}, {args_str}") | ||||||
|  |         formatted_tools = "\n".join(tool_strings) | ||||||
|  |          | ||||||
|  |         tool_names = ", ".join([tool.name for tool in tools]) | ||||||
|  |          | ||||||
|  |         tool_prompt = dedent(f""" | ||||||
|  | Below is a list of tools that are available for your use: | ||||||
|  | {formatted_tools} | ||||||
|  | 
 | ||||||
|  | valid "tool_name" value is: | ||||||
|  | {tool_names} | ||||||
|  | """) | ||||||
|  |          | ||||||
|  |         return tool_prompt | ||||||
|  | 
 | ||||||
|  |     def handle_agent_data(self, **kwargs): | ||||||
|  |         if 'agents' not in kwargs: | ||||||
|  |             return "" | ||||||
|  |          | ||||||
|  |         agents = kwargs.get('agents') | ||||||
|  |         random.shuffle(agents) | ||||||
|  |         agent_names = ", ".join([f'{agent.role.role_name}' for agent in agents]) | ||||||
|  |         agent_descs = [] | ||||||
|  |         for agent in agents: | ||||||
|  |             role_desc = agent.role.role_prompt.split("####")[1] | ||||||
|  |             while "\n\n" in role_desc: | ||||||
|  |                 role_desc = role_desc.replace("\n\n", "\n") | ||||||
|  |             role_desc = role_desc.replace("\n", ",") | ||||||
|  | 
 | ||||||
|  |             agent_descs.append(f'"role name: {agent.role.role_name}\nrole description: {role_desc}"') | ||||||
|  | 
 | ||||||
|  |         agents =  "\n".join(agent_descs) | ||||||
|  |         agent_prompt = f''' | ||||||
|  |         Please ensure your selection is one of the listed roles. Available roles for selection: | ||||||
|  |         {agents} | ||||||
|  |         Please ensure the Role you select is one of the agent names above, such as {agent_names}''' | ||||||
|  | 
 | ||||||
|  |         return dedent(agent_prompt) | ||||||
|  |      | ||||||
|  |     def handle_doc_info(self, **kwargs) -> str: | ||||||
|  |         if 'previous_agent_message' not in kwargs: | ||||||
|  |             return "" | ||||||
|  |         previous_agent_message: Message = kwargs.get('previous_agent_message') | ||||||
|  |         db_docs = previous_agent_message.db_docs | ||||||
|  |         search_docs = previous_agent_message.search_docs | ||||||
|  |         code_docs = previous_agent_message.code_docs | ||||||
|  |         doc_infos = "\n".join([doc.get_snippet() for doc in db_docs] + [doc.get_snippet() for doc in search_docs] +  | ||||||
|  |                               [doc.get_code() for doc in code_docs]) | ||||||
|  |         return doc_infos | ||||||
|  |      | ||||||
|  |     def handle_session_records(self, **kwargs) -> str: | ||||||
|  | 
 | ||||||
|  |         memory_pool: Memory = kwargs.get('memory_pool', Memory(messages=[])) | ||||||
|  |         memory_pool = self.select_memory_by_agent_name(memory_pool) | ||||||
|  |         memory_pool = self.select_memory_by_parsed_key(memory_pool) | ||||||
|  | 
 | ||||||
|  |         return memory_pool.to_str_messages(content_key="parsed_output_list", with_tag=True) | ||||||
|  |      | ||||||
|  |     def handle_current_plan(self, **kwargs) -> str: | ||||||
|  |         if 'previous_agent_message' not in kwargs: | ||||||
|  |             return "" | ||||||
|  |         previous_agent_message = kwargs['previous_agent_message'] | ||||||
|  |         return previous_agent_message.parsed_output.get("CURRENT_STEP", "") | ||||||
|  |      | ||||||
|  |     def handle_agent_profile(self, **kwargs) -> str: | ||||||
|  |         return extract_section(self.role_prompt, 'Agent Profile') | ||||||
|  |      | ||||||
|  |     def handle_output_format(self, **kwargs) -> str: | ||||||
|  |         return extract_section(self.role_prompt, 'Response Output Format') | ||||||
|  |      | ||||||
|  |     def handle_response(self, **kwargs) -> str: | ||||||
|  |         if 'react_memory' not in kwargs: | ||||||
|  |             return "" | ||||||
|  | 
 | ||||||
|  |         react_memory = kwargs.get('react_memory', Memory(messages=[])) | ||||||
|  |         if react_memory is None: | ||||||
|  |             return "" | ||||||
|  |          | ||||||
|  |         return "\n".join(["\n".join([f"**{k}:**\n{v}" for k,v in _dict.items()]) for _dict in react_memory.get_parserd_output()]) | ||||||
|  | 
 | ||||||
|  |     def handle_task_records(self, **kwargs) -> str: | ||||||
|  |         if 'task_memory' not in kwargs: | ||||||
|  |             return "" | ||||||
|  | 
 | ||||||
|  |         task_memory: Memory = kwargs.get('task_memory', Memory(messages=[])) | ||||||
|  |         if task_memory is None: | ||||||
|  |             return "" | ||||||
|  |          | ||||||
|  |         return "\n".join(["\n".join([f"**{k}:**\n{v}" for k,v in _dict.items() if k not in ["CURRENT_STEP"]]) for _dict in task_memory.get_parserd_output()]) | ||||||
|  |      | ||||||
|  |     def handle_previous_message(self, message: Message) -> str:  | ||||||
|  |         pass | ||||||
|  |      | ||||||
|  |     def handle_message_by_role_name(self, message: Message) -> str:  | ||||||
|  |         pass | ||||||
|  |      | ||||||
|  |     def handle_message_by_role_type(self, message: Message) -> str: | ||||||
|  |         pass | ||||||
|  |      | ||||||
|  |     def handle_current_agent_react_message(self, message: Message) -> str: | ||||||
|  |         pass | ||||||
|  |      | ||||||
|  |     def extract_codedoc_info_for_prompt(self, message: Message) -> str:  | ||||||
|  |         code_docs = message.code_docs | ||||||
|  |         doc_infos = "\n".join([doc.get_code() for doc in code_docs]) | ||||||
|  |         return doc_infos | ||||||
|  |      | ||||||
|  |     def select_memory_by_parsed_key(self, memory: Memory) -> Memory: | ||||||
|  |         return Memory( | ||||||
|  |             messages=[self.select_message_by_parsed_key(message) for message in memory.messages  | ||||||
|  |                       if self.select_message_by_parsed_key(message) is not None] | ||||||
|  |                       ) | ||||||
|  | 
 | ||||||
|  |     def select_memory_by_agent_name(self, memory: Memory) -> Memory: | ||||||
|  |         return Memory( | ||||||
|  |             messages=[self.select_message_by_agent_name(message) for message in memory.messages  | ||||||
|  |                       if self.select_message_by_agent_name(message) is not None] | ||||||
|  |                       ) | ||||||
|  | 
 | ||||||
|  |     def select_message_by_agent_name(self, message: Message) -> Message: | ||||||
|  |         # assume we focus all agents | ||||||
|  |         if self.monitored_agents == []: | ||||||
|  |             return message | ||||||
|  |         return None if message is None or message.role_name not in self.monitored_agents else self.select_message_by_parsed_key(message) | ||||||
|  |      | ||||||
|  |     def select_message_by_parsed_key(self, message: Message) -> Message: | ||||||
|  |         # assume we focus all key contents | ||||||
|  |         if message is None: | ||||||
|  |             return message | ||||||
|  |          | ||||||
|  |         if self.monitored_fields == []: | ||||||
|  |             return message | ||||||
|  |          | ||||||
|  |         message_c = copy.deepcopy(message) | ||||||
|  |         message_c.parsed_output = {k: v for k,v in message_c.parsed_output.items() if k in self.monitored_fields} | ||||||
|  |         message_c.parsed_output_list = [{k: v for k,v in parsed_output.items() if k in self.monitored_fields} for parsed_output in message_c.parsed_output_list] | ||||||
|  |         return message_c | ||||||
|  |      | ||||||
|  |     def get_memory(self, content_key="role_content"): | ||||||
|  |         return self.memory.to_tuple_messages(content_key=content_key) | ||||||
|  |      | ||||||
|  |     def get_memory_str(self, content_key="role_content"): | ||||||
|  |         return "\n".join([": ".join(i) for i in self.memory.to_tuple_messages(content_key=content_key)]) | ||||||
|  |      | ||||||
|  |     def register_fields_from_config(self): | ||||||
|  |          | ||||||
|  |         for prompt_field in self.prompt_config: | ||||||
|  |              | ||||||
|  |             function_name = prompt_field.function_name | ||||||
|  |             # check whether function_name names a method on self | ||||||
|  |             if function_name and hasattr(self, function_name): | ||||||
|  |                 function = getattr(self, function_name) | ||||||
|  |             else: | ||||||
|  |                 function = self.handle_custom_data | ||||||
|  |                  | ||||||
|  |             self.register_field(prompt_field.field_name,  | ||||||
|  |                                 function=function,  | ||||||
|  |                                 title=prompt_field.title, | ||||||
|  |                                 description=prompt_field.description, | ||||||
|  |                                 is_context=prompt_field.is_context, | ||||||
|  |                                 omit_if_empty=prompt_field.omit_if_empty) | ||||||
|  |      | ||||||
|  |     def register_standard_fields(self): | ||||||
|  |         self.register_field('agent_profile', function=self.handle_agent_profile, is_context=False) | ||||||
|  |         self.register_field('tool_information', function=self.handle_tool_data, is_context=False) | ||||||
|  |         self.register_field('context_placeholder', is_context=True)  # marks where the context section is inserted | ||||||
|  |         self.register_field('reference_documents', function=self.handle_doc_info, is_context=True) | ||||||
|  |         self.register_field('session_records', function=self.handle_session_records, is_context=True) | ||||||
|  |         self.register_field('output_format', function=self.handle_output_format, title='Response Output Format', is_context=False) | ||||||
|  |         self.register_field('response', function=self.handle_response, is_context=False, omit_if_empty=False) | ||||||
|  |          | ||||||
|  |     def register_executor_fields(self): | ||||||
|  |         self.register_field('agent_profile', function=self.handle_agent_profile, is_context=False) | ||||||
|  |         self.register_field('tool_information', function=self.handle_tool_data, is_context=False) | ||||||
|  |         self.register_field('context_placeholder', is_context=True)  # marks where the context section is inserted | ||||||
|  |         self.register_field('reference_documents', function=self.handle_doc_info, is_context=True) | ||||||
|  |         self.register_field('session_records', function=self.handle_session_records, is_context=True) | ||||||
|  |         self.register_field('current_plan', function=self.handle_current_plan, is_context=True) | ||||||
|  |         self.register_field('output_format', function=self.handle_output_format, title='Response Output Format', is_context=False) | ||||||
|  |         self.register_field('response', function=self.handle_response, is_context=False, omit_if_empty=False) | ||||||
|  |          | ||||||
|  |     def register_fields_from_dict(self, fields_dict): | ||||||
|  |         # register fields from a plain dict of per-field configs | ||||||
|  |         for field_name, field_config in fields_dict.items(): | ||||||
|  |             function_name = field_config.get('function', None) | ||||||
|  |             title = field_config.get('title', None) | ||||||
|  |             description = field_config.get('description', None) | ||||||
|  |             is_context = field_config.get('is_context', True) | ||||||
|  |             omit_if_empty = field_config.get('omit_if_empty', True) | ||||||
|  |              | ||||||
|  |             # check whether function_name names a method on self | ||||||
|  |             if function_name and hasattr(self, function_name): | ||||||
|  |                 function = getattr(self, function_name) | ||||||
|  |             else: | ||||||
|  |                 function = self.handle_custom_data | ||||||
|  |              | ||||||
|  |             # register the field via the existing register_field method | ||||||
|  |             self.register_field(field_name, function=function, title=title, description=description, is_context=is_context, omit_if_empty=omit_if_empty) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def main(): | ||||||
|  |     manager = PromptManager() | ||||||
|  |     manager.register_standard_fields() | ||||||
|  | 
 | ||||||
|  |     manager.register_field('agents_work_progress', title="Agents' Work Progress", is_context=True) | ||||||
|  | 
 | ||||||
|  |     # build the data dict | ||||||
|  |     data_dict = { | ||||||
|  |         "agent_profile": "This is the agent profile...", | ||||||
|  |         # "tool_list": "This is the tool list...", | ||||||
|  |         "reference_documents": "These are the reference documents...", | ||||||
|  |         "session_records": "These are the session records...", | ||||||
|  |         "agents_work_progress": "This is the agents' work progress...", | ||||||
|  |         "output_format": "This is the expected output format...", | ||||||
|  |         # "response": "This is the instruction to generate or continue a response...", | ||||||
|  |         "response": "", | ||||||
|  |         "test": 'xxxxx' | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |     # assemble the full prompt (generate_full_prompt takes keyword arguments) | ||||||
|  |     full_prompt = manager.generate_full_prompt(**data_dict) | ||||||
|  |     print(full_prompt) | ||||||
|  | 
 | ||||||
|  | if __name__ == "__main__": | ||||||
|  |     main() | ||||||
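`register_fields_from_config` resolves each `function_name` with `hasattr`/`getattr` on `self`, which is what wires the `Code2DocPM`/`CodeRetrievalPM` handlers from extend_manager.py to their fields. A sketch with an illustrative config shape (the real entries come from the repo's prompt configs, not shown here):

```python
from types import SimpleNamespace

# Illustrative prompt_config entries, not the repo's actual config objects.
prompt_config = [
    SimpleNamespace(field_name="code_snippet", function_name="handle_code_snippet",
                    title=None, description=None, is_context=True, omit_if_empty=True),
]

pm = Code2DocPM(role_prompt="", prompt_config=prompt_config)
# the handler was resolved via getattr(pm, "handle_code_snippet")
print(pm.context_handlers["code_snippet"].__name__)  # -> handle_code_snippet
```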
| @ -215,15 +215,15 @@ class Env(BaseModel): | |||||||
| class Role(BaseModel): | class Role(BaseModel): | ||||||
|     role_type: str |     role_type: str | ||||||
|     role_name: str |     role_name: str | ||||||
|     role_desc: str |     role_desc: str = "" | ||||||
|     agent_type: str = "" |     agent_type: str = "BaseAgent" | ||||||
|     role_prompt: str = "" |     role_prompt: str = "" | ||||||
|     template_prompt: str = "" |     template_prompt: str = "" | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class ChainConfig(BaseModel): | class ChainConfig(BaseModel): | ||||||
|     chain_name: str |     chain_name: str | ||||||
|     chain_type: str |     chain_type: str = "BaseChain" | ||||||
|     agents: List[str] |     agents: List[str] | ||||||
|     do_checker: bool = False |     do_checker: bool = False | ||||||
|     chat_turn: int = 1 |     chat_turn: int = 1 | ||||||
|  | |||||||
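With the new defaults, a minimal `Role` or `ChainConfig` needs only its required fields; a small sketch (field values are illustrative):

```python
role = Role(role_type="assistant", role_name="coder")          # agent_type defaults to "BaseAgent"
chain = ChainConfig(chain_name="codeChain", agents=["coder"])  # chain_type defaults to "BaseChain"
print(role.agent_type, chain.chain_type, chain.chat_turn)      # -> BaseAgent BaseChain 1
```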
| @ -132,6 +132,9 @@ class Memory(BaseModel): | |||||||
|         # return [parsed_output for message in self.messages for parsed_output in message.parsed_output_list[1:]] |         # return [parsed_output for message in self.messages for parsed_output in message.parsed_output_list[1:]] | ||||||
|         return [parsed_output for message in self.messages for parsed_output in message.parsed_output_list] |         return [parsed_output for message in self.messages for parsed_output in message.parsed_output_list] | ||||||
| 
 | 
 | ||||||
|  |     def get_spec_parserd_output(self, ): | ||||||
|  |         return [message.spec_parsed_output for message in self.messages] | ||||||
|  |      | ||||||
|     def get_rolenames(self, ): |     def get_rolenames(self, ): | ||||||
|         '''''' |         '''''' | ||||||
|         return [message.role_name for message in self.messages] |         return [message.role_name for message in self.messages] | ||||||
|  | |||||||
| @ -7,6 +7,7 @@ from .general_schema import * | |||||||
| 
 | 
 | ||||||
| class Message(BaseModel): | class Message(BaseModel): | ||||||
|     chat_index: str = None |     chat_index: str = None | ||||||
|  |     user_name: str = "default" | ||||||
|     role_name: str |     role_name: str | ||||||
|     role_type: str |     role_type: str | ||||||
|     role_prompt: str = None |     role_prompt: str = None | ||||||
| @ -53,6 +54,8 @@ class Message(BaseModel): | |||||||
|     cb_search_type: str = None |     cb_search_type: str = None | ||||||
|     search_engine_name: str = None  |     search_engine_name: str = None  | ||||||
|     top_k: int = 3 |     top_k: int = 3 | ||||||
|  |     use_nh: bool = True | ||||||
|  |     local_graph_path: str = '' | ||||||
|     score_threshold: float = 1.0 |     score_threshold: float = 1.0 | ||||||
|     do_doc_retrieval: bool = False |     do_doc_retrieval: bool = False | ||||||
|     do_code_retrieval: bool = False |     do_code_retrieval: bool = False | ||||||
|  | |||||||
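The new `Message` fields in use, with illustrative values: `user_name` now rides along into observation messages, and `use_nh`/`local_graph_path` let code retrieval fall back from NebulaGraph to a local graph:

```python
msg = Message(
    role_name="human",
    role_type="user",
    role_content="explain MessageUtils.get_code_retrieval",
    user_name="alice",                    # new, defaults to "default"
    use_nh=False,                         # new, skip NebulaGraph
    local_graph_path="/tmp/code_graph",   # new, used when use_nh is False
)
print(msg.user_name, msg.use_nh, msg.local_graph_path)
```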
| @ -72,20 +72,25 @@ def parse_text_to_dict(text): | |||||||
| def parse_dict_to_dict(parsed_dict) -> dict: | def parse_dict_to_dict(parsed_dict) -> dict: | ||||||
|     code_pattern = r'```python\n(.*?)```' |     code_pattern = r'```python\n(.*?)```' | ||||||
|     tool_pattern = r'```json\n(.*?)```' |     tool_pattern = r'```json\n(.*?)```' | ||||||
|  |     java_pattern = r'```java\n(.*?)```' | ||||||
|      |      | ||||||
|     pattern_dict = {"code": code_pattern, "json": tool_pattern} |     pattern_dict = {"code": code_pattern, "json": tool_pattern, "java": java_pattern} | ||||||
|     spec_parsed_dict = copy.deepcopy(parsed_dict) |     spec_parsed_dict = copy.deepcopy(parsed_dict) | ||||||
|     for key, pattern in pattern_dict.items(): |     for key, pattern in pattern_dict.items(): | ||||||
|         for k, text in parsed_dict.items(): |         for k, text in parsed_dict.items(): | ||||||
|             # Search for the code block |             # Search for the code block | ||||||
|             if not isinstance(text, str): continue |             if not isinstance(text, str):  | ||||||
|  |                 spec_parsed_dict[k] = text | ||||||
|  |                 continue | ||||||
|             _match = re.search(pattern, text, re.DOTALL) |             _match = re.search(pattern, text, re.DOTALL) | ||||||
|             if _match: |             if _match: | ||||||
|                 # Add the code block to the dictionary |                 # Add the code block to the dictionary | ||||||
|                 try: |                 try: | ||||||
|                     spec_parsed_dict[key] = json.loads(_match.group(1).strip()) |                     spec_parsed_dict[key] = json.loads(_match.group(1).strip()) | ||||||
|  |                     spec_parsed_dict[k] = json.loads(_match.group(1).strip()) | ||||||
|                 except: |                 except: | ||||||
|                     spec_parsed_dict[key] = _match.group(1).strip() |                     spec_parsed_dict[key] = _match.group(1).strip() | ||||||
|  |                     spec_parsed_dict[k] = _match.group(1).strip() | ||||||
|                 break |                 break | ||||||
|     return spec_parsed_dict |     return spec_parsed_dict | ||||||
| 
 | 
 | ||||||
|  | |||||||
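A quick check of the new `java` branch; the triple-backtick fence is assembled at runtime only so the snippet can sit inside this page without ambiguity:

```python
import re

fence = "`" * 3
java_pattern = fence + r"java\n(.*?)" + fence   # same pattern as the diff
text = f'**Code:**\n{fence}java\nSystem.out.println("hi");\n{fence}'
m = re.search(java_pattern, text, re.DOTALL)
print(m.group(1).strip())  # -> System.out.println("hi");
```

Since Java source is not valid JSON, `json.loads` raises and `parse_dict_to_dict` stores the raw snippet string under both the original key and "java".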
| @ -43,7 +43,7 @@ class NebulaHandler: | |||||||
|                 elif self.space_name: |                 elif self.space_name: | ||||||
|                     cypher = f'USE {self.space_name};{cypher}' |                     cypher = f'USE {self.space_name};{cypher}' | ||||||
| 
 | 
 | ||||||
|             logger.debug(cypher) |             # logger.debug(cypher) | ||||||
|             resp = session.execute(cypher) |             resp = session.execute(cypher) | ||||||
| 
 | 
 | ||||||
|             if format_res: |             if format_res: | ||||||
| @ -247,6 +247,24 @@ class NebulaHandler: | |||||||
|         res = self.execute_cypher(cypher, self.space_name) |         res = self.execute_cypher(cypher, self.space_name) | ||||||
|         return self.result_to_dict(res) |         return self.result_to_dict(res) | ||||||
| 
 | 
 | ||||||
|  |     def get_all_vertices(self,): | ||||||
|  |         ''' | ||||||
|  |         get all vertices | ||||||
|  |         @return: | ||||||
|  |         ''' | ||||||
|  |         cypher = "MATCH (v) RETURN v;" | ||||||
|  |         res = self.execute_cypher(cypher, self.space_name) | ||||||
|  |         return self.result_to_dict(res) | ||||||
|  |      | ||||||
|  |     def get_relative_vertices(self, vertice): | ||||||
|  |         ''' | ||||||
|  |         get vertices adjacent to the given vertex | ||||||
|  |         @return: | ||||||
|  |         ''' | ||||||
|  |         cypher = f'''MATCH (v1)--(v2) WHERE id(v1) == '{vertice}' RETURN id(v2) as id;''' | ||||||
|  |         res = self.execute_cypher(cypher, self.space_name) | ||||||
|  |         return self.result_to_dict(res) | ||||||
|  | 
 | ||||||
|     def result_to_dict(self, result) -> dict: |     def result_to_dict(self, result) -> dict: | ||||||
|         """ |         """ | ||||||
|         build list for each column, and transform to dataframe |         build list for each column, and transform to dataframe | ||||||
|  | |||||||
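In nGQL terms the two new helpers run `MATCH (v) RETURN v;` and a one-hop neighbor lookup. A hypothetical call site, assuming `nh` is a `NebulaHandler` already bound to a space (its construction is not shown in this diff):

```python
all_vertices = nh.get_all_vertices()
# one-hop neighbors of a vertex id; result_to_dict keys rows by column name
neighbors = nh.get_relative_vertices("pkg.module#ClassName")
print(neighbors.get("id", []))
```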
| @ -6,6 +6,7 @@ import os | |||||||
| import pickle | import pickle | ||||||
| import uuid | import uuid | ||||||
| import warnings | import warnings | ||||||
|  | from enum import Enum | ||||||
| from pathlib import Path | from pathlib import Path | ||||||
| from typing import ( | from typing import ( | ||||||
|     Any, |     Any, | ||||||
| @ -22,10 +23,22 @@ import numpy as np | |||||||
| 
 | 
 | ||||||
| from langchain.docstore.base import AddableMixin, Docstore | from langchain.docstore.base import AddableMixin, Docstore | ||||||
| from langchain.docstore.document import Document | from langchain.docstore.document import Document | ||||||
| from langchain.docstore.in_memory import InMemoryDocstore | # from langchain.docstore.in_memory import InMemoryDocstore | ||||||
|  | from .in_memory import InMemoryDocstore | ||||||
| from langchain.embeddings.base import Embeddings | from langchain.embeddings.base import Embeddings | ||||||
| from langchain.vectorstores.base import VectorStore | from langchain.vectorstores.base import VectorStore | ||||||
| from langchain.vectorstores.utils import DistanceStrategy, maximal_marginal_relevance | from langchain.vectorstores.utils import maximal_marginal_relevance | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class DistanceStrategy(str, Enum): | ||||||
|  |     """Enumerator of the Distance strategies for calculating distances | ||||||
|  |     between vectors.""" | ||||||
|  | 
 | ||||||
|  |     EUCLIDEAN_DISTANCE = "EUCLIDEAN_DISTANCE" | ||||||
|  |     MAX_INNER_PRODUCT = "MAX_INNER_PRODUCT" | ||||||
|  |     DOT_PRODUCT = "DOT_PRODUCT" | ||||||
|  |     JACCARD = "JACCARD" | ||||||
|  |     COSINE = "COSINE" | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def dependable_faiss_import(no_avx2: Optional[bool] = None) -> Any: | def dependable_faiss_import(no_avx2: Optional[bool] = None) -> Any: | ||||||
| @ -219,6 +232,9 @@ class FAISS(VectorStore): | |||||||
|         if self._normalize_L2: |         if self._normalize_L2: | ||||||
|             faiss.normalize_L2(vector) |             faiss.normalize_L2(vector) | ||||||
|         scores, indices = self.index.search(vector, k if filter is None else fetch_k) |         scores, indices = self.index.search(vector, k if filter is None else fetch_k) | ||||||
|  |         # after L2 normalization, returned score rows can still exceed 1; rescale those rows | ||||||
|  |         if self._normalize_L2: | ||||||
|  |             scores = np.array([row / np.linalg.norm(row) if np.max(row) > 1 else row for row in scores]) | ||||||
|         docs = [] |         docs = [] | ||||||
|         for j, i in enumerate(indices[0]): |         for j, i in enumerate(indices[0]): | ||||||
|             if i == -1: |             if i == -1: | ||||||
| @ -565,7 +581,7 @@ class FAISS(VectorStore): | |||||||
|         vecstore = cls( |         vecstore = cls( | ||||||
|             embedding.embed_query, |             embedding.embed_query, | ||||||
|             index, |             index, | ||||||
|             InMemoryDocstore(), |             InMemoryDocstore({}), | ||||||
|             {}, |             {}, | ||||||
|             normalize_L2=normalize_L2, |             normalize_L2=normalize_L2, | ||||||
|             distance_strategy=distance_strategy, |             distance_strategy=distance_strategy, | ||||||
|  | |||||||
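The normalization guard added to `similarity_search_with_score_by_vector` is, in isolation, just a per-row rescale; a self-contained check:

```python
import numpy as np

scores = np.array([[0.2, 0.9], [1.3, 0.4]])
# rows whose max exceeds 1 are divided by their own L2 norm, exactly as in the diff
scores = np.array([row / np.linalg.norm(row) if np.max(row) > 1 else row for row in scores])
print(scores)  # first row unchanged; second row rescaled to unit norm
```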
| @ -10,13 +10,14 @@ from loguru import logger | |||||||
| # from configs.model_config import EMBEDDING_MODEL | # from configs.model_config import EMBEDDING_MODEL | ||||||
| from coagent.embeddings.openai_embedding import OpenAIEmbedding | from coagent.embeddings.openai_embedding import OpenAIEmbedding | ||||||
| from coagent.embeddings.huggingface_embedding import HFEmbedding | from coagent.embeddings.huggingface_embedding import HFEmbedding | ||||||
| 
 | from coagent.llm_models.llm_config import EmbedConfig | ||||||
| 
 | 
 | ||||||
| def get_embedding( | def get_embedding( | ||||||
|         engine: str,  |         engine: str,  | ||||||
|         text_list: list,  |         text_list: list,  | ||||||
|         model_path: str = "text2vec-base-chinese", |         model_path: str = "text2vec-base-chinese", | ||||||
|         embedding_device: str = "cpu", |         embedding_device: str = "cpu", | ||||||
|  |         embed_config: EmbedConfig = None, | ||||||
|         ): |         ): | ||||||
|     ''' |     ''' | ||||||
|     get embedding |     get embedding | ||||||
| @ -25,8 +26,12 @@ def get_embedding( | |||||||
|     @return: |     @return: | ||||||
|     ''' |     ''' | ||||||
|     emb_res = {} |     emb_res = {} | ||||||
| 
 |     if embed_config and embed_config.langchain_embeddings: | ||||||
|     if engine == 'openai': |         emb_res = embed_config.langchain_embeddings.embed_documents(text_list) | ||||||
|  |         emb_res = { | ||||||
|  |             text_list[idx]: emb_res[idx] for idx in range(len(text_list)) | ||||||
|  |         } | ||||||
|  |     elif engine == 'openai': | ||||||
|         oae = OpenAIEmbedding() |         oae = OpenAIEmbedding() | ||||||
|         emb_res = oae.get_emb(text_list) |         emb_res = oae.get_emb(text_list) | ||||||
|     elif engine == 'model': |     elif engine == 'model': | ||||||
|  | |||||||
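A hedged sketch of the new passthrough: when `embed_config.langchain_embeddings` is set, `get_embedding` ignores `engine` and zips each text with its vector from `embed_documents`. `FakeEmbeddings` keeps the example self-contained; the import path of `get_embedding` and the keyword arguments that `EmbedConfig.check_config` accepts are assumptions:

```python
from langchain.embeddings import FakeEmbeddings

from coagent.embeddings.get_embedding import get_embedding  # module path assumed
from coagent.llm_models.llm_config import EmbedConfig

embed_config = EmbedConfig(embed_model="fake", embed_model_path="fake",
                           langchain_embeddings=FakeEmbeddings(size=4))
emb_res = get_embedding(engine="model", text_list=["hello", "world"],
                        embed_config=embed_config)
print(list(emb_res.keys()))  # -> ['hello', 'world']
```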
coagent/embeddings/in_memory.py (new file, +49 lines)
							| @ -0,0 +1,49 @@ | |||||||
|  | """Simple in memory docstore in the form of a dict.""" | ||||||
|  | from typing import Dict, List, Optional, Union | ||||||
|  | 
 | ||||||
|  | from langchain.docstore.base import AddableMixin, Docstore | ||||||
|  | from langchain.docstore.document import Document | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class InMemoryDocstore(Docstore, AddableMixin): | ||||||
|  |     """Simple in memory docstore in the form of a dict.""" | ||||||
|  | 
 | ||||||
|  |     def __init__(self, _dict: Optional[Dict[str, Document]] = None): | ||||||
|  |         """Initialize with dict.""" | ||||||
|  |         self._dict = _dict if _dict is not None else {} | ||||||
|  | 
 | ||||||
|  |     def add(self, texts: Dict[str, Document]) -> None: | ||||||
|  |         """Add texts to in memory dictionary. | ||||||
|  | 
 | ||||||
|  |         Args: | ||||||
|  |             texts: dictionary of id -> document. | ||||||
|  | 
 | ||||||
|  |         Returns: | ||||||
|  |             None | ||||||
|  |         """ | ||||||
|  |         overlapping = set(texts).intersection(self._dict) | ||||||
|  |         if overlapping: | ||||||
|  |             raise ValueError(f"Tried to add ids that already exist: {overlapping}") | ||||||
|  |         self._dict = {**self._dict, **texts} | ||||||
|  | 
 | ||||||
|  |     def delete(self, ids: List) -> None: | ||||||
|  |         """Deleting IDs from in memory dictionary.""" | ||||||
|  |         overlapping = set(ids).intersection(self._dict) | ||||||
|  |         if not overlapping: | ||||||
|  |             raise ValueError(f"Tried to delete ids that does not  exist: {ids}") | ||||||
|  |         for _id in ids: | ||||||
|  |             self._dict.pop(_id) | ||||||
|  | 
 | ||||||
|  |     def search(self, search: str) -> Union[str, Document]: | ||||||
|  |         """Search via direct lookup. | ||||||
|  | 
 | ||||||
|  |         Args: | ||||||
|  |             search: id of a document to search for. | ||||||
|  | 
 | ||||||
|  |         Returns: | ||||||
|  |             Document if found, else error message. | ||||||
|  |         """ | ||||||
|  |         if search not in self._dict: | ||||||
|  |             return f"ID {search} not found." | ||||||
|  |         else: | ||||||
|  |             return self._dict[search] | ||||||
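Behavior of the vendored docstore is unchanged from langchain's upstream class; the only caller-visible difference in this commit is that the FAISS factory now passes `{}` explicitly:

```python
from langchain.docstore.document import Document

store = InMemoryDocstore({})
store.add({"doc-1": Document(page_content="hello")})
print(store.search("doc-1").page_content)  # -> hello
store.delete(["doc-1"])
print(store.search("doc-1"))               # -> ID doc-1 not found.
```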
| @ -1,6 +1,8 @@ | |||||||
| import os | import os | ||||||
| from functools import lru_cache | from functools import lru_cache | ||||||
| from langchain.embeddings.huggingface import HuggingFaceEmbeddings | from langchain.embeddings.huggingface import HuggingFaceEmbeddings | ||||||
|  | from langchain.embeddings.base import Embeddings | ||||||
|  | 
 | ||||||
| # from configs.model_config import embedding_model_dict | # from configs.model_config import embedding_model_dict | ||||||
| from loguru import logger | from loguru import logger | ||||||
| 
 | 
 | ||||||
| @ -12,8 +14,11 @@ def load_embeddings(model: str, device: str, embedding_model_dict: dict): | |||||||
|     return embeddings |     return embeddings | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @lru_cache(1) | # @lru_cache(1) | ||||||
| def load_embeddings_from_path(model_path: str, device: str): | def load_embeddings_from_path(model_path: str, device: str, langchain_embeddings: Embeddings = None): | ||||||
|  |     if langchain_embeddings: | ||||||
|  |         return langchain_embeddings | ||||||
|  |      | ||||||
|     embeddings = HuggingFaceEmbeddings(model_name=model_path, |     embeddings = HuggingFaceEmbeddings(model_name=model_path, | ||||||
|                                        model_kwargs={'device': device}) |                                        model_kwargs={'device': device}) | ||||||
|     return embeddings |     return embeddings | ||||||
|  | |||||||
| @ -1,8 +1,8 @@ | |||||||
| from .openai_model import getChatModel, getExtraModel, getChatModelFromConfig | from .openai_model import getExtraModel, getChatModelFromConfig | ||||||
| from .llm_config import LLMConfig, EmbedConfig | from .llm_config import LLMConfig, EmbedConfig | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| __all__ = [ | __all__ = [ | ||||||
|     "getChatModel", "getExtraModel", "getChatModelFromConfig", |     "getExtraModel", "getChatModelFromConfig", | ||||||
|     "LLMConfig", "EmbedConfig" |     "LLMConfig", "EmbedConfig" | ||||||
| ] | ] | ||||||
| @ -1,6 +1,9 @@ | |||||||
| from dataclasses import dataclass | from dataclasses import dataclass | ||||||
| from typing import List, Union | from typing import List, Union | ||||||
| 
 | 
 | ||||||
|  | from langchain.embeddings.base import Embeddings | ||||||
|  | from langchain.llms.base import LLM, BaseLLM | ||||||
|  | 
 | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @dataclass | @dataclass | ||||||
| @ -12,7 +15,8 @@ class LLMConfig: | |||||||
|             stop: Union[List[str], str] = None, |             stop: Union[List[str], str] = None, | ||||||
|             api_key: str = "", |             api_key: str = "", | ||||||
|             api_base_url: str = "", |             api_base_url: str = "", | ||||||
|             model_device: str = "cpu", |             model_device: str = "cpu", # unuse,will delete it | ||||||
|  |             llm: LLM = None, | ||||||
|             **kwargs |             **kwargs | ||||||
|         ): |         ): | ||||||
| 
 | 
 | ||||||
| @ -21,7 +25,7 @@ class LLMConfig: | |||||||
|         self.stop: Union[List[str], str] = stop |         self.stop: Union[List[str], str] = stop | ||||||
|         self.api_key: str = api_key |         self.api_key: str = api_key | ||||||
|         self.api_base_url: str = api_base_url |         self.api_base_url: str = api_base_url | ||||||
|         self.model_device: str = model_device |         self.llm: LLM = llm | ||||||
|         #  |         #  | ||||||
|         self.check_config() |         self.check_config() | ||||||
| 
 | 
 | ||||||
| @ -42,6 +46,7 @@ class EmbedConfig: | |||||||
|             embed_model_path: str = "", |             embed_model_path: str = "", | ||||||
|             embed_engine: str = "", |             embed_engine: str = "", | ||||||
|             model_device: str = "cpu", |             model_device: str = "cpu", | ||||||
|  |             langchain_embeddings: Embeddings = None, | ||||||
|             **kwargs |             **kwargs | ||||||
|         ): |         ): | ||||||
|         self.embed_model: str = embed_model |         self.embed_model: str = embed_model | ||||||
| @ -51,6 +56,8 @@ class EmbedConfig: | |||||||
|         self.api_key: str = api_key |         self.api_key: str = api_key | ||||||
|         self.api_base_url: str = api_base_url |         self.api_base_url: str = api_base_url | ||||||
|         #  |         #  | ||||||
|  |         self.langchain_embeddings = langchain_embeddings | ||||||
|  |         #  | ||||||
|         self.check_config() |         self.check_config() | ||||||
| 
 | 
 | ||||||
|     def check_config(self, ): |     def check_config(self, ): | ||||||
|  | |||||||
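A hedged construction sketch for the two extended configs. FakeListLLM and FakeEmbeddings stand in for a real antflow-compatible model, the model names are illustrative, and this assumes check_config tolerates these values:

    from langchain.llms.fake import FakeListLLM
    from langchain.embeddings.fake import FakeEmbeddings
    from coagent.llm_models.llm_config import LLMConfig, EmbedConfig

    llm_config = LLMConfig(
        model_name="my_local_model",        # illustrative
        llm=FakeListLLM(responses=["hi"]),  # any langchain.llms.base.LLM subclass
    )
    embed_config = EmbedConfig(
        embed_model="my_embed_model",       # illustrative
        langchain_embeddings=FakeEmbeddings(size=128),
    )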
| @ -1,38 +1,54 @@ | |||||||
| import os | import os | ||||||
|  | from typing import Union, Optional, List | ||||||
|  | from loguru import logger | ||||||
| 
 | 
 | ||||||
| from langchain.callbacks import AsyncIteratorCallbackHandler | from langchain.callbacks import AsyncIteratorCallbackHandler | ||||||
| from langchain.chat_models import ChatOpenAI | from langchain.chat_models import ChatOpenAI | ||||||
|  | from langchain.llms.base import LLM | ||||||
| 
 | 
 | ||||||
| from .llm_config import LLMConfig | from .llm_config import LLMConfig | ||||||
| # from configs.model_config import (llm_model_dict, LLM_MODEL) | # from configs.model_config import (llm_model_dict, LLM_MODEL) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def getChatModel(callBack: AsyncIteratorCallbackHandler = None, temperature=0.3, stop=None): | class CustomLLMModel: | ||||||
|     if callBack is None: |      | ||||||
|  |     def __init__(self, llm: LLM): | ||||||
|  |         self.llm: LLM = llm | ||||||
|  | 
 | ||||||
|  |     def __call__(self, prompt: str, | ||||||
|  |                   stop: Optional[List[str]] = None): | ||||||
|  |         return self.llm(prompt, stop) | ||||||
|  | 
 | ||||||
|  |     def _call(self, prompt: str, | ||||||
|  |                   stop: Optional[List[str]] = None): | ||||||
|  |         return self.llm(prompt, stop) | ||||||
|  |      | ||||||
|  |     def predict(self, prompt: str, | ||||||
|  |                   stop: Optional[List[str]] = None): | ||||||
|  |         return self.llm(prompt, stop) | ||||||
|  | 
 | ||||||
|  |     def batch(self, prompts: List[str], | ||||||
|  |                   stop: Optional[List[str]] = None): | ||||||
|  |         return [self.llm(prompt, stop) for prompt in prompts] | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def getChatModelFromConfig(llm_config: LLMConfig, callBack: AsyncIteratorCallbackHandler = None, ) -> Union[ChatOpenAI, LLM]: | ||||||
|  |     # logger.debug(f"llm type is {type(llm_config.llm)}") | ||||||
|  |     if llm_config is None: | ||||||
|         model = ChatOpenAI( |         model = ChatOpenAI( | ||||||
|             streaming=True, |             streaming=True, | ||||||
|             verbose=True, |             verbose=True, | ||||||
|             openai_api_key=llm_model_dict[LLM_MODEL]["api_key"], |             openai_api_key=os.environ.get("api_key"), | ||||||
|             openai_api_base=llm_model_dict[LLM_MODEL]["api_base_url"], |             openai_api_base=os.environ.get("api_base_url"), | ||||||
|             model_name=LLM_MODEL, |             model_name=os.environ.get("LLM_MODEL", "gpt-3.5-turbo"), | ||||||
|             temperature=temperature, |             temperature=os.environ.get("temperature", 0.5), | ||||||
|             stop=stop |             stop=os.environ.get("stop", ""), | ||||||
|         ) |         ) | ||||||
|     else: |         return model | ||||||
|         model = ChatOpenAI( |  | ||||||
|             streaming=True, |  | ||||||
|             verbose=True, |  | ||||||
|             callBack=[callBack], |  | ||||||
|             openai_api_key=llm_model_dict[LLM_MODEL]["api_key"], |  | ||||||
|             openai_api_base=llm_model_dict[LLM_MODEL]["api_base_url"], |  | ||||||
|             model_name=LLM_MODEL, |  | ||||||
|             temperature=temperature, |  | ||||||
|             stop=stop |  | ||||||
|         ) |  | ||||||
|     return model |  | ||||||
| 
 | 
 | ||||||
|  |     if llm_config and llm_config.llm and isinstance(llm_config.llm, LLM): | ||||||
|  |         return CustomLLMModel(llm=llm_config.llm) | ||||||
|      |      | ||||||
| def getChatModelFromConfig(llm_config: LLMConfig, callBack: AsyncIteratorCallbackHandler = None, ): |  | ||||||
|     if callBack is None: |     if callBack is None: | ||||||
|         model = ChatOpenAI( |         model = ChatOpenAI( | ||||||
|                 streaming=True, |                 streaming=True, | ||||||
|  | |||||||
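With the rewrite above, dispatch is: llm_config is None yields an environment-driven ChatOpenAI; a user-supplied LLM yields the CustomLLMModel wrapper; otherwise the config-driven ChatOpenAI below is built. A minimal sketch of the wrapper path, using FakeListLLM as a stand-in:

    from langchain.llms.fake import FakeListLLM
    from coagent.llm_models import getChatModelFromConfig
    from coagent.llm_models.llm_config import LLMConfig

    llm_config = LLMConfig(model_name="demo",  # illustrative
                           llm=FakeListLLM(responses=["pong"]))
    model = getChatModelFromConfig(llm_config)  # -> CustomLLMModel
    print(model.predict("ping"))                # delegates to the wrapped LLM -> "pong"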
coagent/retrieval/__init__.py (new file, 5 lines)
| @ -0,0 +1,5 @@ | |||||||
|  | # from .base_retrieval import * | ||||||
|  | 
 | ||||||
|  | # __all__ = [ | ||||||
|  | #     "IMRertrieval", "BaseDocRetrieval", "BaseCodeRetrieval", "BaseSearchRetrieval" | ||||||
|  | # ] | ||||||
coagent/retrieval/base_retrieval.py (new file, 75 lines)
| @ -0,0 +1,75 @@ | |||||||
|  | 
 | ||||||
|  | from coagent.llm_models.llm_config import EmbedConfig, LLMConfig | ||||||
|  | from coagent.base_configs.env_config import KB_ROOT_PATH | ||||||
|  | from coagent.tools import DocRetrieval, CodeRetrieval | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class IMRertrieval: | ||||||
|  | 
 | ||||||
|  |     def __init__(self,): | ||||||
|  |         ''' | ||||||
|  |         initialize your custom attributes here | ||||||
|  |         ''' | ||||||
|  |         pass | ||||||
|  | 
 | ||||||
|  |     def run(self, ): | ||||||
|  |         ''' | ||||||
|  |         execution interface; may use the attributes set in __init__ | ||||||
|  |         ''' | ||||||
|  |         pass | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class BaseDocRetrieval(IMRertrieval): | ||||||
|  | 
 | ||||||
|  |     def __init__(self, knowledge_base_name: str, search_top=5, score_threshold=1.0, embed_config: EmbedConfig=EmbedConfig(), kb_root_path: str=KB_ROOT_PATH): | ||||||
|  |         self.knowledge_base_name = knowledge_base_name | ||||||
|  |         self.search_top = search_top | ||||||
|  |         self.score_threshold = score_threshold | ||||||
|  |         self.embed_config = embed_config | ||||||
|  |         self.kb_root_path = kb_root_path | ||||||
|  | 
 | ||||||
|  |     def run(self, query: str, search_top=None, score_threshold=None, ): | ||||||
|  |         docs = DocRetrieval.run( | ||||||
|  |             query=query, knowledge_base_name=self.knowledge_base_name, | ||||||
|  |             search_top=search_top or self.search_top, | ||||||
|  |             score_threshold=score_threshold or self.score_threshold, | ||||||
|  |             embed_config=self.embed_config, | ||||||
|  |             kb_root_path=self.kb_root_path | ||||||
|  |         ) | ||||||
|  |         return docs | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class BaseCodeRetrieval(IMRertrieval): | ||||||
|  | 
 | ||||||
|  |     def __init__(self, code_base_name, embed_config: EmbedConfig, llm_config: LLMConfig, search_type = 'tag', code_limit = 1, local_graph_path: str=""): | ||||||
|  |         self.code_base_name = code_base_name | ||||||
|  |         self.embed_config = embed_config | ||||||
|  |         self.llm_config = llm_config | ||||||
|  |         self.search_type = search_type | ||||||
|  |         self.code_limit = code_limit | ||||||
|  |         self.use_nh: bool = False | ||||||
|  |         self.local_graph_path: str = local_graph_path | ||||||
|  | 
 | ||||||
|  |     def run(self, query, history_node_list=[], search_type=None, code_limit=None): | ||||||
|  |         code_docs = CodeRetrieval.run( | ||||||
|  |             code_base_name=self.code_base_name, | ||||||
|  |             query=query,  | ||||||
|  |             history_node_list=history_node_list, | ||||||
|  |             code_limit=code_limit or self.code_limit, | ||||||
|  |             search_type=search_type or self.search_type, | ||||||
|  |             llm_config=self.llm_config, | ||||||
|  |             embed_config=self.embed_config, | ||||||
|  |             use_nh=self.use_nh, | ||||||
|  |             local_graph_path=self.local_graph_path | ||||||
|  |             ) | ||||||
|  |         return code_docs | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class BaseSearchRetrieval(IMRertrieval): | ||||||
|  | 
 | ||||||
|  |     def __init__(self, ): | ||||||
|  |         pass | ||||||
|  | 
 | ||||||
|  |     def run(self, ): | ||||||
|  |         pass | ||||||
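A hedged usage sketch for the wrappers above; "samples" is an illustrative knowledge-base name that must already exist under KB_ROOT_PATH, and the EmbedConfig values are placeholders:

    from coagent.retrieval.base_retrieval import BaseDocRetrieval
    from coagent.llm_models.llm_config import EmbedConfig

    retrieval = BaseDocRetrieval(
        knowledge_base_name="samples",  # illustrative
        search_top=3,
        score_threshold=1.0,
        embed_config=EmbedConfig(embed_model="text2vec",
                                 embed_model_path="/models/text2vec"),  # placeholders
    )
    docs = retrieval.run("how do I create a codebase?")  # delegates to DocRetrieval.run

The per-call search_top and score_threshold arguments fall back to the constructor values via `or`, so passing None keeps the defaults.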
coagent/retrieval/document_loaders/__init__.py (new file, 6 lines)
| @ -0,0 +1,6 @@ | |||||||
|  | from .json_loader import JSONLoader | ||||||
|  | from .jsonl_loader import JSONLLoader | ||||||
|  | 
 | ||||||
|  | __all__ = [ | ||||||
|  |     "JSONLoader", "JSONLLoader" | ||||||
|  | ] | ||||||
coagent/retrieval/document_loaders/json_loader.py (new file, 61 lines)
| @ -0,0 +1,61 @@ | |||||||
|  | import json | ||||||
|  | from pathlib import Path | ||||||
|  | from typing import AnyStr, Callable, Dict, List, Optional, Union | ||||||
|  | 
 | ||||||
|  | from langchain.docstore.document import Document | ||||||
|  | from langchain.document_loaders.base import BaseLoader | ||||||
|  | from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter | ||||||
|  | 
 | ||||||
|  | from coagent.utils.common_utils import read_json_file | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class JSONLoader(BaseLoader): | ||||||
|  | 
 | ||||||
|  |     def __init__( | ||||||
|  |             self, | ||||||
|  |             file_path: Union[str, Path], | ||||||
|  |             schema_key: str = "all_text", | ||||||
|  |             content_key: Optional[str] = None, | ||||||
|  |             metadata_func: Optional[Callable[[Dict, Dict], Dict]] = None, | ||||||
|  |             text_content: bool = True, | ||||||
|  |     ): | ||||||
|  |         self.file_path = Path(file_path).resolve() | ||||||
|  |         self.schema_key = schema_key | ||||||
|  |         self._content_key = content_key | ||||||
|  |         self._metadata_func = metadata_func | ||||||
|  |         self._text_content = text_content | ||||||
|  | 
 | ||||||
|  |     def load(self, ) -> List[Document]: | ||||||
|  |         """Load and return documents from the JSON file.""" | ||||||
|  |         docs: List[Document] = [] | ||||||
|  |         datas = read_json_file(self.file_path) | ||||||
|  |         self._parse(datas, docs) | ||||||
|  |         return docs | ||||||
|  |      | ||||||
|  |     def _parse(self, datas: List, docs: List[Document]) -> None: | ||||||
|  |         for idx, sample in enumerate(datas): | ||||||
|  |             metadata = dict( | ||||||
|  |                 source=str(self.file_path), | ||||||
|  |                 seq_num=idx, | ||||||
|  |             ) | ||||||
|  |             text = sample.get(self.schema_key, "") | ||||||
|  |             docs.append(Document(page_content=text, metadata=metadata)) | ||||||
|  | 
 | ||||||
|  |     def load_and_split( | ||||||
|  |         self, text_splitter: Optional[TextSplitter] = None | ||||||
|  |     ) -> List[Document]: | ||||||
|  |         """Load Documents and split into chunks. Chunks are returned as Documents. | ||||||
|  | 
 | ||||||
|  |         Args: | ||||||
|  |             text_splitter: TextSplitter instance to use for splitting documents. | ||||||
|  |               Defaults to RecursiveCharacterTextSplitter. | ||||||
|  | 
 | ||||||
|  |         Returns: | ||||||
|  |             List of Documents. | ||||||
|  |         """ | ||||||
|  |         if text_splitter is None: | ||||||
|  |             _text_splitter: TextSplitter = RecursiveCharacterTextSplitter() | ||||||
|  |         else: | ||||||
|  |             _text_splitter = text_splitter | ||||||
|  |         docs = self.load() | ||||||
|  |         return _text_splitter.split_documents(docs) | ||||||
coagent/retrieval/document_loaders/jsonl_loader.py (new file, 62 lines)
| @ -0,0 +1,62 @@ | |||||||
|  | import json | ||||||
|  | from pathlib import Path | ||||||
|  | from typing import AnyStr, Callable, Dict, List, Optional, Union | ||||||
|  | 
 | ||||||
|  | from langchain.docstore.document import Document | ||||||
|  | from langchain.document_loaders.base import BaseLoader | ||||||
|  | from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter | ||||||
|  | 
 | ||||||
|  | from coagent.utils.common_utils import read_jsonl_file | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class JSONLLoader(BaseLoader): | ||||||
|  | 
 | ||||||
|  |     def __init__( | ||||||
|  |             self, | ||||||
|  |             file_path: Union[str, Path], | ||||||
|  |             schema_key: str = "all_text", | ||||||
|  |             content_key: Optional[str] = None, | ||||||
|  |             metadata_func: Optional[Callable[[Dict, Dict], Dict]] = None, | ||||||
|  |             text_content: bool = True, | ||||||
|  |     ): | ||||||
|  |         self.file_path = Path(file_path).resolve() | ||||||
|  |         self.schema_key = schema_key | ||||||
|  |         self._content_key = content_key | ||||||
|  |         self._metadata_func = metadata_func | ||||||
|  |         self._text_content = text_content | ||||||
|  | 
 | ||||||
|  |     def load(self, ) -> List[Document]: | ||||||
|  |         """Load and return documents from the JSON file.""" | ||||||
|  |         docs: List[Document] = [] | ||||||
|  |         datas = read_jsonl_file(self.file_path) | ||||||
|  |         self._parse(datas, docs) | ||||||
|  |         return docs | ||||||
|  |      | ||||||
|  |     def _parse(self, datas: List, docs: List[Document]) -> None: | ||||||
|  |         for idx, sample in enumerate(datas): | ||||||
|  |             metadata = dict( | ||||||
|  |                 source=str(self.file_path), | ||||||
|  |                 seq_num=idx, | ||||||
|  |             ) | ||||||
|  |             text = sample.get(self.schema_key, "") | ||||||
|  |             docs.append(Document(page_content=text, metadata=metadata)) | ||||||
|  | 
 | ||||||
|  |     def load_and_split( | ||||||
|  |         self, text_splitter: Optional[TextSplitter] = None | ||||||
|  |     ) -> List[Document]: | ||||||
|  |         """Load Documents and split into chunks. Chunks are returned as Documents. | ||||||
|  | 
 | ||||||
|  |         Args: | ||||||
|  |             text_splitter: TextSplitter instance to use for splitting documents. | ||||||
|  |               Defaults to RecursiveCharacterTextSplitter. | ||||||
|  | 
 | ||||||
|  |         Returns: | ||||||
|  |             List of Documents. | ||||||
|  |         """ | ||||||
|  |         if text_splitter is None: | ||||||
|  |             _text_splitter: TextSplitter = RecursiveCharacterTextSplitter() | ||||||
|  |         else: | ||||||
|  |             _text_splitter = text_splitter | ||||||
|  | 
 | ||||||
|  |         docs = self.load() | ||||||
|  |         return _text_splitter.split_documents(docs) | ||||||
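Hedged usage for both loaders: each record contributes one Document whose page_content is the value under schema_key ("all_text" by default). The file paths below are made up:

    from coagent.retrieval.document_loaders import JSONLoader, JSONLLoader
    from langchain.text_splitter import RecursiveCharacterTextSplitter

    docs = JSONLoader("data/sample.json", schema_key="all_text").load()  # illustrative path
    chunks = JSONLoader("data/sample.json").load_and_split(
        RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50))

    jsonl_docs = JSONLLoader("data/sample.jsonl").load()                 # illustrative path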
coagent/retrieval/text_splitter/__init__.py (new file, 3 lines)
| @ -0,0 +1,3 @@ | |||||||
|  | from .langchain_splitter import LCTextSplitter | ||||||
|  | 
 | ||||||
|  | __all__ = ["LCTextSplitter"] | ||||||
coagent/retrieval/text_splitter/langchain_splitter.py (new file, 77 lines)
| @ -0,0 +1,77 @@ | |||||||
|  | import os | ||||||
|  | import importlib | ||||||
|  | from loguru import logger | ||||||
|  | 
 | ||||||
|  | from langchain.document_loaders.base import BaseLoader | ||||||
|  | from langchain.text_splitter import ( | ||||||
|  |     SpacyTextSplitter, RecursiveCharacterTextSplitter | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | # from configs.model_config import ( | ||||||
|  | #     CHUNK_SIZE, | ||||||
|  | #     OVERLAP_SIZE, | ||||||
|  | #     ZH_TITLE_ENHANCE | ||||||
|  | # ) | ||||||
|  | from coagent.utils.path_utils import * | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class LCTextSplitter: | ||||||
|  |     '''LangChain text splitter: implements file2text''' | ||||||
|  |     def __init__( | ||||||
|  |             self, filepath: str, text_splitter_name: str = None, | ||||||
|  |             chunk_size: int = 500,  | ||||||
|  |             overlap_size: int = 50 | ||||||
|  |     ): | ||||||
|  |         self.filepath = filepath | ||||||
|  |         self.ext = os.path.splitext(filepath)[-1].lower() | ||||||
|  |         self.text_splitter_name = text_splitter_name | ||||||
|  |         self.chunk_size = chunk_size | ||||||
|  |         self.overlap_size = overlap_size | ||||||
|  |         if self.ext not in SUPPORTED_EXTS: | ||||||
|  |             raise ValueError(f"暂未支持的文件格式 {self.ext}") | ||||||
|  |         self.document_loader_name = get_LoaderClass(self.ext) | ||||||
|  | 
 | ||||||
|  |     def file2text(self, ): | ||||||
|  |         loader = self._load_document() | ||||||
|  |         text_splitter = self._load_text_splitter() | ||||||
|  |         if self.document_loader_name in ["JSONLoader", "JSONLLoader"]: | ||||||
|  |             # docs = loader.load() | ||||||
|  |             docs = loader.load_and_split(text_splitter) | ||||||
|  |             # logger.debug(f"please check your file can be loaded, docs.lens {len(docs)}") | ||||||
|  |         else: | ||||||
|  |             docs = loader.load_and_split(text_splitter) | ||||||
|  | 
 | ||||||
|  |         return docs | ||||||
|  | 
 | ||||||
|  |     def _load_document(self, ) -> BaseLoader: | ||||||
|  |         DocumentLoader = EXT2LOADER_DICT[self.ext] | ||||||
|  |         if self.document_loader_name == "UnstructuredFileLoader": | ||||||
|  |             loader = DocumentLoader(self.filepath, autodetect_encoding=True) | ||||||
|  |         else: | ||||||
|  |             loader = DocumentLoader(self.filepath) | ||||||
|  |         return loader | ||||||
|  |      | ||||||
|  |     def _load_text_splitter(self, ): | ||||||
|  |         try: | ||||||
|  |             if self.text_splitter_name is None: | ||||||
|  |                 text_splitter = SpacyTextSplitter( | ||||||
|  |                     pipeline="zh_core_web_sm", | ||||||
|  |                     chunk_size=self.chunk_size, | ||||||
|  |                     chunk_overlap=self.overlap_size, | ||||||
|  |                 ) | ||||||
|  |                 self.text_splitter_name = "SpacyTextSplitter" | ||||||
|  |             # elif self.document_loader_name in ["JSONLoader", "JSONLLoader"]: | ||||||
|  |             #     text_splitter = None | ||||||
|  |             else: | ||||||
|  |                 text_splitter_module = importlib.import_module('langchain.text_splitter') | ||||||
|  |                 TextSplitter = getattr(text_splitter_module, self.text_splitter_name) | ||||||
|  |                 text_splitter = TextSplitter( | ||||||
|  |                     chunk_size=self.chunk_size, | ||||||
|  |                     chunk_overlap=self.overlap_size) | ||||||
|  |         except Exception as e: | ||||||
|  |             text_splitter = RecursiveCharacterTextSplitter( | ||||||
|  |                 chunk_size=self.chunk_size, | ||||||
|  |                 chunk_overlap=self.overlap_size, | ||||||
|  |             ) | ||||||
|  |         return text_splitter | ||||||
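A hedged usage sketch: LCTextSplitter picks the loader from the file extension and the splitter by class name from langchain.text_splitter, falling back to SpacyTextSplitter (which needs the zh_core_web_sm pipeline installed) or, on any error, RecursiveCharacterTextSplitter. The path is illustrative:

    from coagent.retrieval.text_splitter import LCTextSplitter

    lc_splitter = LCTextSplitter(
        filepath="docs/readme.md",  # illustrative; extension must be in SUPPORTED_EXTS
        text_splitter_name="RecursiveCharacterTextSplitter",
        chunk_size=500,
        overlap_size=50,
    )
    docs = lc_splitter.file2text()  # List[Document], already chunked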
coagent/retrieval/text_splitter/utils.py (new file, empty)
| @ -32,8 +32,8 @@ class PyCodeBox(BaseBox): | |||||||
|         self.do_check_net = do_check_net |         self.do_check_net = do_check_net | ||||||
|         self.use_stop = use_stop |         self.use_stop = use_stop | ||||||
|         self.jupyter_work_path = jupyter_work_path |         self.jupyter_work_path = jupyter_work_path | ||||||
|         asyncio.run(self.astart()) |         # asyncio.run(self.astart()) | ||||||
|         # self.start() |         self.start() | ||||||
| 
 | 
 | ||||||
|         # logger.info(f"""remote_url: {self.remote_url}, |         # logger.info(f"""remote_url: {self.remote_url}, | ||||||
|         #             remote_ip: {self.remote_ip}, |         #             remote_ip: {self.remote_ip}, | ||||||
| @ -199,13 +199,13 @@ class PyCodeBox(BaseBox): | |||||||
|          |          | ||||||
|     async def _aget_kernelid(self, ) -> None: |     async def _aget_kernelid(self, ) -> None: | ||||||
|         headers = {"Authorization": f'Token {self.token}', 'token': self.token} |         headers = {"Authorization": f'Token {self.token}', 'token': self.token} | ||||||
|         response = requests.get(f"{self.kernel_url}?token={self.token}", headers=headers) |         # response = requests.get(f"{self.kernel_url}?token={self.token}", headers=headers) | ||||||
|         async with aiohttp.ClientSession() as session: |         async with aiohttp.ClientSession() as session: | ||||||
|             async with session.get(f"{self.kernel_url}?token={self.token}", headers=headers) as resp: |             async with session.get(f"{self.kernel_url}?token={self.token}", headers=headers, timeout=270) as resp: | ||||||
|                 if len(await resp.json()) > 0: |                 if len(await resp.json()) > 0: | ||||||
|                     self.kernel_id = (await resp.json())[0]["id"] |                     self.kernel_id = (await resp.json())[0]["id"] | ||||||
|                 else: |                 else: | ||||||
|                     async with session.post(f"{self.kernel_url}?token={self.token}", headers=headers) as response: |                     async with session.post(f"{self.kernel_url}?token={self.token}", headers=headers, timeout=270) as response: | ||||||
|                         self.kernel_id = (await response.json())["id"] |                         self.kernel_id = (await response.json())["id"] | ||||||
| 
 | 
 | ||||||
|         # if len(response.json()) > 0: |         # if len(response.json()) > 0: | ||||||
| @ -220,41 +220,45 @@ class PyCodeBox(BaseBox): | |||||||
|             return False |             return False | ||||||
| 
 | 
 | ||||||
|         try: |         try: | ||||||
|              response = requests.get(f"{self.kernel_url}?token={self.token}", timeout=270) |              response = requests.get(f"{self.kernel_url}?token={self.token}", timeout=10) | ||||||
|              return response.status_code == 200 |              return response.status_code == 200 | ||||||
|         except requests.exceptions.ConnectionError: |         except requests.exceptions.ConnectionError: | ||||||
|             return False |             return False | ||||||
|  |         except requests.exceptions.ReadTimeout: | ||||||
|  |             return False | ||||||
| 
 | 
 | ||||||
|     async def _acheck_connect(self, ) -> bool: |     async def _acheck_connect(self, ) -> bool: | ||||||
|         if self.kernel_url == "": |         if self.kernel_url == "": | ||||||
|             return False |             return False | ||||||
|         try: |         try: | ||||||
|             async with aiohttp.ClientSession() as session: |             async with aiohttp.ClientSession() as session: | ||||||
|                 async with session.get(f"{self.kernel_url}?token={self.token}", timeout=270) as resp: |                 async with session.get(f"{self.kernel_url}?token={self.token}", timeout=10) as resp: | ||||||
|                     return resp.status == 200 |                     return resp.status == 200 | ||||||
|         except aiohttp.ClientConnectorError: |         except aiohttp.ClientConnectorError: | ||||||
|             pass |             return False | ||||||
|         except aiohttp.ServerDisconnectedError: |         except aiohttp.ServerDisconnectedError: | ||||||
|             pass |             return False | ||||||
| 
 | 
 | ||||||
|     def _check_port(self, ) -> bool: |     def _check_port(self, ) -> bool: | ||||||
|         try: |         try: | ||||||
|              response = requests.get(f"{self.remote_ip}:{self.remote_port}", timeout=270) |              response = requests.get(f"{self.remote_ip}:{self.remote_port}", timeout=10) | ||||||
|              logger.warning(f"Port is conflict, please check your codebox's port {self.remote_port}") |              logger.warning(f"Port is conflict, please check your codebox's port {self.remote_port}") | ||||||
|              return response.status_code == 200 |              return response.status_code == 200 | ||||||
|         except requests.exceptions.ConnectionError: |         except requests.exceptions.ConnectionError: | ||||||
|             return False |             return False | ||||||
|  |         except requests.exceptions.ReadTimeout: | ||||||
|  |             return False | ||||||
|          |          | ||||||
|     async def _acheck_port(self, ) -> bool: |     async def _acheck_port(self, ) -> bool: | ||||||
|         try: |         try: | ||||||
|             async with aiohttp.ClientSession() as session: |             async with aiohttp.ClientSession() as session: | ||||||
|                 async with session.get(f"{self.remote_ip}:{self.remote_port}", timeout=270) as resp: |                 async with session.get(f"{self.remote_ip}:{self.remote_port}", timeout=10) as resp: | ||||||
|                     # logger.warning(f"Port is conflict, please check your codebox's port {self.remote_port}") |                     # logger.warning(f"Port is conflict, please check your codebox's port {self.remote_port}") | ||||||
|                     return resp.status == 200 |                     return resp.status == 200 | ||||||
|         except aiohttp.ClientConnectorError: |         except aiohttp.ClientConnectorError: | ||||||
|             pass |             return False | ||||||
|         except aiohttp.ServerDisconnectedError: |         except aiohttp.ServerDisconnectedError: | ||||||
|             pass |             return False | ||||||
| 
 | 
 | ||||||
|     def _check_connect_success(self, retry_nums: int = 2) -> bool: |     def _check_connect_success(self, retry_nums: int = 2) -> bool: | ||||||
|         if not self.do_check_net: return True |         if not self.do_check_net: return True | ||||||
| @ -263,7 +267,7 @@ class PyCodeBox(BaseBox): | |||||||
|             try: |             try: | ||||||
|                 connect_status = self._check_connect() |                 connect_status = self._check_connect() | ||||||
|                 if connect_status: |                 if connect_status: | ||||||
|                     logger.info(f"{self.remote_url} connection success") |                     # logger.info(f"{self.remote_url} connection success") | ||||||
|                     return True |                     return True | ||||||
|             except requests.exceptions.ConnectionError: |             except requests.exceptions.ConnectionError: | ||||||
|                 logger.info(f"{self.remote_url} connection fail") |                 logger.info(f"{self.remote_url} connection fail") | ||||||
| @ -301,10 +305,12 @@ class PyCodeBox(BaseBox): | |||||||
|         else: |         else: | ||||||
|             # TODO: auto-detect the local endpoint |             # TODO: auto-detect the local endpoint | ||||||
|             port_status = self._check_port() |             port_status = self._check_port() | ||||||
|  |             self.kernel_url = self.remote_url + "/api/kernels" | ||||||
|             connect_status = self._check_connect() |             connect_status = self._check_connect() | ||||||
|             logger.info(f"port_status: {port_status}, connect_status: {connect_status}") |             if os.environ.get("log_verbose", "0") >= "2": | ||||||
|  |                 logger.info(f"port_status: {port_status}, connect_status: {connect_status}") | ||||||
|             if port_status and not connect_status: |             if port_status and not connect_status: | ||||||
|                 raise BaseException(f"Port is conflict, please check your codebox's port {self.remote_port}") |                 logger.error("Port is conflict, please check your codebox's port {self.remote_port}") | ||||||
|              |              | ||||||
|             if not connect_status: |             if not connect_status: | ||||||
|                 self.jupyter = subprocess.Popen( |                 self.jupyter = subprocess.Popen( | ||||||
| @ -321,14 +327,32 @@ class PyCodeBox(BaseBox): | |||||||
|                     stdout=subprocess.PIPE, |                     stdout=subprocess.PIPE, | ||||||
|                 ) |                 ) | ||||||
| 
 | 
 | ||||||
|  |             record = [] | ||||||
|  |             while self.jupyter and len(record) < 100: | ||||||
|  |                 line = self.jupyter.stderr.readline() | ||||||
|  |                 try: | ||||||
|  |                     content = line.decode("utf-8") | ||||||
|  |                 except UnicodeDecodeError: | ||||||
|  |                     content = line.decode("gbk") | ||||||
|  |                 # logger.debug(content) | ||||||
|  |                 record.append(content) | ||||||
|  |                 if "control-c" in content.lower(): | ||||||
|  |                     break | ||||||
|  | 
 | ||||||
|             self.kernel_url = self.remote_url + "/api/kernels" |             self.kernel_url = self.remote_url + "/api/kernels" | ||||||
|             self.do_check_net = True |             self.do_check_net = True | ||||||
|             self._check_connect_success() |             self._check_connect_success() | ||||||
|             self._get_kernelid() |             self._get_kernelid() | ||||||
|             # logger.debug(self.kernel_url.replace("http", "ws") + f"/{self.kernel_id}/channels?token={self.token}") |  | ||||||
|             self.wc_url = self.kernel_url.replace("http", "ws") + f"/{self.kernel_id}/channels?token={self.token}" |             self.wc_url = self.kernel_url.replace("http", "ws") + f"/{self.kernel_id}/channels?token={self.token}" | ||||||
|             headers = {"Authorization": f'Token {self.token}', 'token': self.token} |             headers = {"Authorization": f'Token {self.token}', 'token': self.token} | ||||||
|             self.ws = create_connection(self.wc_url, headers=headers) |             retry_nums = 3 | ||||||
|  |             while retry_nums>=0: | ||||||
|  |                 try: | ||||||
|  |                     self.ws = create_connection(self.wc_url, headers=headers, timeout=5) | ||||||
|  |                     break | ||||||
|  |                 except Exception as e: | ||||||
|  |                     logger.error(f"create ws connection timeout {e}") | ||||||
|  |                     retry_nums -= 1 | ||||||
| 
 | 
 | ||||||
|     async def astart(self, ): |     async def astart(self, ): | ||||||
|         '''Decide whether to run via an external service or start the notebook internally''' |         '''Decide whether to run via an external service or start the notebook internally''' | ||||||
| @ -369,10 +393,16 @@ class PyCodeBox(BaseBox): | |||||||
|                     cwd=self.jupyter_work_path |                     cwd=self.jupyter_work_path | ||||||
|                 ) |                 ) | ||||||
| 
 | 
 | ||||||
|             while True and self.jupyter: |             record = [] | ||||||
|  |             while self.jupyter and len(record) < 100: | ||||||
|                 line = self.jupyter.stderr.readline() |                 line = self.jupyter.stderr.readline() | ||||||
|                 # logger.debug(line.decode("gbk")) |                 try: | ||||||
|                 if "Control-C" in line.decode("gbk"): |                     content = line.decode("utf-8") | ||||||
|  |                 except: | ||||||
|  |                     content = line.decode("gbk") | ||||||
|  |                 # logger.debug(content) | ||||||
|  |                 record.append(content) | ||||||
|  |                 if "control-c" in content.lower(): | ||||||
|                     break |                     break | ||||||
|             self.kernel_url = self.remote_url + "/api/kernels" |             self.kernel_url = self.remote_url + "/api/kernels" | ||||||
|             self.do_check_net = True |             self.do_check_net = True | ||||||
| @ -380,7 +410,15 @@ class PyCodeBox(BaseBox): | |||||||
|             await self._aget_kernelid() |             await self._aget_kernelid() | ||||||
|             self.wc_url = self.kernel_url.replace("http", "ws") + f"/{self.kernel_id}/channels?token={self.token}" |             self.wc_url = self.kernel_url.replace("http", "ws") + f"/{self.kernel_id}/channels?token={self.token}" | ||||||
|             headers = {"Authorization": f'Token {self.token}', 'token': self.token} |             headers = {"Authorization": f'Token {self.token}', 'token': self.token} | ||||||
|             self.ws = create_connection(self.wc_url, headers=headers) | 
 | ||||||
|  |             retry_nums = 3 | ||||||
|  |             while retry_nums>=0: | ||||||
|  |                 try: | ||||||
|  |                     self.ws = create_connection(self.wc_url, headers=headers, timeout=5) | ||||||
|  |                     break | ||||||
|  |                 except Exception as e: | ||||||
|  |                     logger.error(f"create ws connection timeout {e}") | ||||||
|  |                     retry_nums -= 1 | ||||||
| 
 | 
 | ||||||
|     def status(self,) -> CodeBoxStatus: |     def status(self,) -> CodeBoxStatus: | ||||||
|         if not self.kernel_id: |         if not self.kernel_id: | ||||||
|  | |||||||
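The websocket change above swaps a single blocking call for a bounded retry with a short timeout. The same pattern as a stand-alone helper, mirroring the diff's create_connection call; the URL and header values in the comment are illustrative:

    from loguru import logger
    from websocket import create_connection  # websocket-client package

    def connect_with_retry(url: str, headers: dict, retries: int = 3, timeout: int = 5):
        # Try (retries + 1) times, logging each failure, then give up with None.
        for attempt in range(retries + 1):
            try:
                return create_connection(url, headers=headers, timeout=timeout)
            except Exception as e:
                logger.error(f"create ws connection timeout ({attempt + 1}/{retries + 1}): {e}")
        return None

    # ws = connect_with_retry("ws://127.0.0.1:5050/api/kernels/<id>/channels?token=...",
    #                         {"Authorization": "Token ..."})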
| @ -17,7 +17,7 @@ from coagent.orm.commands import * | |||||||
| from coagent.utils.path_utils import * | from coagent.utils.path_utils import * | ||||||
| from coagent.orm.utils import DocumentFile | from coagent.orm.utils import DocumentFile | ||||||
| from coagent.embeddings.utils import load_embeddings, load_embeddings_from_path | from coagent.embeddings.utils import load_embeddings, load_embeddings_from_path | ||||||
| from coagent.text_splitter import LCTextSplitter | from coagent.retrieval.text_splitter import LCTextSplitter | ||||||
| from coagent.llm_models.llm_config import EmbedConfig | from coagent.llm_models.llm_config import EmbedConfig | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @ -46,7 +46,7 @@ class KBService(ABC): | |||||||
| 
 | 
 | ||||||
|     def _load_embeddings(self) -> Embeddings: |     def _load_embeddings(self) -> Embeddings: | ||||||
|         # return load_embeddings(self.embed_model, embed_device, embedding_model_dict) |         # return load_embeddings(self.embed_model, embed_device, embedding_model_dict) | ||||||
|         return load_embeddings_from_path(self.embed_config.embed_model_path, self.embed_config.model_device) |         return load_embeddings_from_path(self.embed_config.embed_model_path, self.embed_config.model_device, self.embed_config.langchain_embeddings) | ||||||
| 
 | 
 | ||||||
|     def create_kb(self): |     def create_kb(self): | ||||||
|         """ |         """ | ||||||
|  | |||||||
| @ -20,9 +20,6 @@ from coagent.utils.path_utils import * | |||||||
| from coagent.orm.commands import * | from coagent.orm.commands import * | ||||||
| from coagent.db_handler.graph_db_handler.nebula_handler import NebulaHandler | from coagent.db_handler.graph_db_handler.nebula_handler import NebulaHandler | ||||||
| from coagent.db_handler.vector_db_handler.chroma_handler import ChromaHandler | from coagent.db_handler.vector_db_handler.chroma_handler import ChromaHandler | ||||||
| # from configs.server_config import NEBULA_HOST, NEBULA_PORT, NEBULA_USER, NEBULA_PASSWORD, NEBULA_STORAGED_PORT |  | ||||||
| # from configs.server_config import CHROMA_PERSISTENT_PATH |  | ||||||
| 
 |  | ||||||
| from coagent.base_configs.env_config import ( | from coagent.base_configs.env_config import ( | ||||||
|     CB_ROOT_PATH,  |     CB_ROOT_PATH,  | ||||||
|     NEBULA_HOST, NEBULA_PORT, NEBULA_USER, NEBULA_PASSWORD, NEBULA_STORAGED_PORT, |     NEBULA_HOST, NEBULA_PORT, NEBULA_USER, NEBULA_PASSWORD, NEBULA_STORAGED_PORT, | ||||||
| @ -58,10 +55,11 @@ async def create_cb(zip_file, | |||||||
|                     model_name: bool = Body(..., examples=["samples"]), |                     model_name: bool = Body(..., examples=["samples"]), | ||||||
|                     temperature: bool = Body(..., examples=["samples"]), |                     temperature: bool = Body(..., examples=["samples"]), | ||||||
|                     model_device: bool = Body(..., examples=["samples"]), |                     model_device: bool = Body(..., examples=["samples"]), | ||||||
|  |                     embed_config: EmbedConfig = None, | ||||||
|                     ) -> BaseResponse: |                     ) -> BaseResponse: | ||||||
|     logger.info('cb_name={}, zip_path={}, do_interpret={}'.format(cb_name, code_path, do_interpret)) |     logger.info('cb_name={}, zip_path={}, do_interpret={}'.format(cb_name, code_path, do_interpret)) | ||||||
| 
 | 
 | ||||||
|     embed_config: EmbedConfig = EmbedConfig(**locals()) |     embed_config: EmbedConfig = EmbedConfig(**locals()) if embed_config is None else embed_config | ||||||
|     llm_config: LLMConfig = LLMConfig(**locals()) |     llm_config: LLMConfig = LLMConfig(**locals()) | ||||||
| 
 | 
 | ||||||
|     # Create selected knowledge base |     # Create selected knowledge base | ||||||
| @ -101,9 +99,10 @@ async def delete_cb( | |||||||
|         model_name: bool = Body(..., examples=["samples"]), |         model_name: bool = Body(..., examples=["samples"]), | ||||||
|         temperature: bool = Body(..., examples=["samples"]), |         temperature: bool = Body(..., examples=["samples"]), | ||||||
|         model_device: bool = Body(..., examples=["samples"]), |         model_device: bool = Body(..., examples=["samples"]), | ||||||
|  |         embed_config: EmbedConfig = None, | ||||||
|         ) -> BaseResponse: |         ) -> BaseResponse: | ||||||
|     logger.info('cb_name={}'.format(cb_name)) |     logger.info('cb_name={}'.format(cb_name)) | ||||||
|     embed_config: EmbedConfig = EmbedConfig(**locals()) |     embed_config: EmbedConfig = EmbedConfig(**locals()) if embed_config is None else embed_config | ||||||
|     llm_config: LLMConfig = LLMConfig(**locals()) |     llm_config: LLMConfig = LLMConfig(**locals()) | ||||||
|     # Create selected knowledge base |     # Create selected knowledge base | ||||||
|     if not validate_kb_name(cb_name): |     if not validate_kb_name(cb_name): | ||||||
| @ -143,18 +142,24 @@ def search_code(cb_name: str = Body(..., examples=["sofaboot"]), | |||||||
|                 model_name: bool = Body(..., examples=["samples"]), |                 model_name: bool = Body(..., examples=["samples"]), | ||||||
|                 temperature: bool = Body(..., examples=["samples"]), |                 temperature: bool = Body(..., examples=["samples"]), | ||||||
|                 model_device: bool = Body(..., examples=["samples"]), |                 model_device: bool = Body(..., examples=["samples"]), | ||||||
|  |                 use_nh: bool = True, | ||||||
|  |                 local_graph_path: str = '', | ||||||
|  |                 embed_config: EmbedConfig = None, | ||||||
|                 ) -> dict: |                 ) -> dict: | ||||||
|      |      | ||||||
|     logger.info('cb_name={}'.format(cb_name)) |     if os.environ.get("log_verbose", "0") >= "2": | ||||||
|     logger.info('query={}'.format(query)) |         logger.info(f'local_graph_path={local_graph_path}') | ||||||
|     logger.info('code_limit={}'.format(code_limit)) |         logger.info('cb_name={}'.format(cb_name)) | ||||||
|     logger.info('search_type={}'.format(search_type)) |         logger.info('query={}'.format(query)) | ||||||
|     logger.info('history_node_list={}'.format(history_node_list)) |         logger.info('code_limit={}'.format(code_limit)) | ||||||
|     embed_config: EmbedConfig = EmbedConfig(**locals()) |         logger.info('search_type={}'.format(search_type)) | ||||||
|  |         logger.info('history_node_list={}'.format(history_node_list)) | ||||||
|  |     embed_config: EmbedConfig = EmbedConfig(**locals()) if embed_config is None else embed_config | ||||||
|     llm_config: LLMConfig = LLMConfig(**locals()) |     llm_config: LLMConfig = LLMConfig(**locals()) | ||||||
|     try: |     try: | ||||||
|         # load codebase |         # load codebase | ||||||
|         cbh = CodeBaseHandler(codebase_name=cb_name, embed_config=embed_config, llm_config=llm_config) |         cbh = CodeBaseHandler(codebase_name=cb_name, embed_config=embed_config, llm_config=llm_config, | ||||||
|  |                               use_nh=use_nh, local_graph_path=local_graph_path) | ||||||
| 
 | 
 | ||||||
|         # search code |         # search code | ||||||
|         context, related_vertices = cbh.search_code(query, search_type=search_type, limit=code_limit) |         context, related_vertices = cbh.search_code(query, search_type=search_type, limit=code_limit) | ||||||
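Hedged sketch of a direct (non-HTTP) call into the updated search_code: passing a prebuilt EmbedConfig skips the EmbedConfig(**locals()) rebuild, and use_nh=False with a local_graph_path routes around NebulaGraph. All values are illustrative, and the response shape ({'context': ..., 'related_vertices': ...}) is inferred from the callers in this commit:

    from coagent.service.cb_api import search_code
    from coagent.llm_models.llm_config import EmbedConfig

    res = search_code(
        cb_name="sofaboot", query="where is the HTTP client?",
        code_limit=1, search_type="tag", history_node_list=[],
        embed_model="", embed_model_path="", embed_engine="",
        model_name="gpt-3.5-turbo", temperature=0.2, model_device="cpu",
        api_base_url="", api_key="",
        use_nh=False, local_graph_path="/tmp/graph",   # illustrative
        embed_config=EmbedConfig(),                    # prebuilt config wins
    )
    context, vertices = res["context"], res["related_vertices"]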
| @ -180,10 +185,12 @@ def search_related_vertices(cb_name: str = Body(..., examples=["sofaboot"]), | |||||||
|         nh = NebulaHandler(host=NEBULA_HOST, port=NEBULA_PORT, username=NEBULA_USER, |         nh = NebulaHandler(host=NEBULA_HOST, port=NEBULA_PORT, username=NEBULA_USER, | ||||||
|                            password=NEBULA_PASSWORD, space_name=cb_name) |                            password=NEBULA_PASSWORD, space_name=cb_name) | ||||||
|          |          | ||||||
|         cypher = f'''MATCH (v1)--(v2) WHERE id(v1) == '{vertex}' RETURN id(v2) as id;''' |         if vertex.endswith(".java"): | ||||||
| 
 |             cypher = f'''MATCH (v1)--(v2:package) WHERE id(v1) == '{vertex}' RETURN id(v2) as id;''' | ||||||
|  |         else: | ||||||
|  |             cypher = f'''MATCH (v1)--(v2) WHERE id(v1) == '{vertex}' RETURN id(v2) as id;''' | ||||||
|  |         # cypher = f'''MATCH (v1)--(v2) WHERE id(v1) == '{vertex}' RETURN v2;''' | ||||||
|         cypher_res = nh.execute_cypher(cypher=cypher, format_res=True) |         cypher_res = nh.execute_cypher(cypher=cypher, format_res=True) | ||||||
| 
 |  | ||||||
|         related_vertices = cypher_res.get('id', []) |         related_vertices = cypher_res.get('id', []) | ||||||
|         related_vertices = [i.as_string() for i in related_vertices] |         related_vertices = [i.as_string() for i in related_vertices] | ||||||
| 
 | 
 | ||||||
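The branch above special-cases file vertices: for a *.java id, only package neighbors are returned. A hedged stand-alone sketch of the same query against a populated space; connection values and the vertex id are illustrative:

    from coagent.db_handler.graph_db_handler.nebula_handler import NebulaHandler

    nh = NebulaHandler(host="127.0.0.1", port=9669, username="root",
                       password="nebula", space_name="sofaboot")  # illustrative
    vertex = "com.example.DemoService#run"                        # illustrative id
    if vertex.endswith(".java"):
        cypher = f"MATCH (v1)--(v2:package) WHERE id(v1) == '{vertex}' RETURN id(v2) as id;"
    else:
        cypher = f"MATCH (v1)--(v2) WHERE id(v1) == '{vertex}' RETURN id(v2) as id;"
    res = nh.execute_cypher(cypher=cypher, format_res=True)
    neighbor_ids = [i.as_string() for i in res.get('id', [])]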
| @ -200,8 +207,8 @@ def search_related_vertices(cb_name: str = Body(..., examples=["sofaboot"]), | |||||||
| def search_code_by_vertex(cb_name: str = Body(..., examples=["sofaboot"]), | def search_code_by_vertex(cb_name: str = Body(..., examples=["sofaboot"]), | ||||||
|                             vertex: str = Body(..., examples=['***'])) -> dict: |                             vertex: str = Body(..., examples=['***'])) -> dict: | ||||||
| 
 | 
 | ||||||
|     logger.info('cb_name={}'.format(cb_name)) |     # logger.info('cb_name={}'.format(cb_name)) | ||||||
|     logger.info('vertex={}'.format(vertex)) |     # logger.info('vertex={}'.format(vertex)) | ||||||
| 
 | 
 | ||||||
|     try: |     try: | ||||||
|         nh = NebulaHandler(host=NEBULA_HOST, port=NEBULA_PORT, username=NEBULA_USER, |         nh = NebulaHandler(host=NEBULA_HOST, port=NEBULA_PORT, username=NEBULA_USER, | ||||||
| @ -233,7 +240,7 @@ def search_code_by_vertex(cb_name: str = Body(..., examples=["sofaboot"]), | |||||||
|         return res |         return res | ||||||
|     except Exception as e: |     except Exception as e: | ||||||
|         logger.exception(e) |         logger.exception(e) | ||||||
|         return {} |         return {'code': ""} | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def cb_exists_api(cb_name: str = Body(..., examples=["sofaboot"])) -> bool: | def cb_exists_api(cb_name: str = Body(..., examples=["sofaboot"])) -> bool: | ||||||
|  | |||||||
| @ -8,17 +8,6 @@ from loguru import logger | |||||||
| from langchain.embeddings.base import Embeddings | from langchain.embeddings.base import Embeddings | ||||||
| from langchain.docstore.document import Document | from langchain.docstore.document import Document | ||||||
| from langchain.embeddings.huggingface import HuggingFaceEmbeddings | from langchain.embeddings.huggingface import HuggingFaceEmbeddings | ||||||
| from langchain.vectorstores.utils import DistanceStrategy |  | ||||||
| 
 |  | ||||||
| # from configs.model_config import ( |  | ||||||
| #     KB_ROOT_PATH, |  | ||||||
| #     CACHED_VS_NUM, |  | ||||||
| #     EMBEDDING_MODEL, |  | ||||||
| #     EMBEDDING_DEVICE, |  | ||||||
| #     SCORE_THRESHOLD, |  | ||||||
| #     FAISS_NORMALIZE_L2 |  | ||||||
| # ) |  | ||||||
| # from configs.model_config import embedding_model_dict |  | ||||||
| 
 | 
 | ||||||
| from coagent.base_configs.env_config import ( | from coagent.base_configs.env_config import ( | ||||||
|     KB_ROOT_PATH, |     KB_ROOT_PATH, | ||||||
| @ -52,15 +41,15 @@ def load_vector_store( | |||||||
|         tick: int = 0,  # tick will be changed by upload_doc etc. and make cache refreshed. |         tick: int = 0,  # tick will be changed by upload_doc etc. and make cache refreshed. | ||||||
|         kb_root_path: str = KB_ROOT_PATH, |         kb_root_path: str = KB_ROOT_PATH, | ||||||
| ): | ): | ||||||
|     print(f"loading vector store in '{knowledge_base_name}'.") |     # print(f"loading vector store in '{knowledge_base_name}'.") | ||||||
|     vs_path = get_vs_path(knowledge_base_name, kb_root_path) |     vs_path = get_vs_path(knowledge_base_name, kb_root_path) | ||||||
|     if embeddings is None: |     if embeddings is None: | ||||||
|         embeddings = load_embeddings_from_path(embed_config.embed_model_path, embed_config.model_device) |         embeddings = load_embeddings_from_path(embed_config.embed_model_path, embed_config.model_device, embed_config.langchain_embeddings) | ||||||
| 
 | 
 | ||||||
|     if not os.path.exists(vs_path): |     if not os.path.exists(vs_path): | ||||||
|         os.makedirs(vs_path) |         os.makedirs(vs_path) | ||||||
|      |      | ||||||
|     distance_strategy = DistanceStrategy.EUCLIDEAN_DISTANCE |     distance_strategy = "EUCLIDEAN_DISTANCE" | ||||||
|     if "index.faiss" in os.listdir(vs_path): |     if "index.faiss" in os.listdir(vs_path): | ||||||
|         search_index = FAISS.load_local(vs_path, embeddings, normalize_L2=FAISS_NORMALIZE_L2, distance_strategy=distance_strategy) |         search_index = FAISS.load_local(vs_path, embeddings, normalize_L2=FAISS_NORMALIZE_L2, distance_strategy=distance_strategy) | ||||||
|     else: |     else: | ||||||
|  | |||||||
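Replacing the DistanceStrategy import with the bare string should be behavior-preserving: in the langchain versions this code targets, DistanceStrategy is a str-based Enum, so the FAISS constructor accepts either spelling. A one-line check, assuming a langchain version where the enum still exists:

    from langchain.vectorstores.utils import DistanceStrategy
    assert DistanceStrategy.EUCLIDEAN_DISTANCE == "EUCLIDEAN_DISTANCE"

Dropping the import simply removes a hard dependency on that module path.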
| @ -9,9 +9,7 @@ from pydantic import BaseModel, Field | |||||||
| from loguru import logger | from loguru import logger | ||||||
| 
 | 
 | ||||||
| from coagent.llm_models import LLMConfig, EmbedConfig | from coagent.llm_models import LLMConfig, EmbedConfig | ||||||
| 
 |  | ||||||
| from .base_tool import BaseToolModel | from .base_tool import BaseToolModel | ||||||
| 
 |  | ||||||
| from coagent.service.cb_api import search_code | from coagent.service.cb_api import search_code | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @ -29,7 +27,17 @@ class CodeRetrieval(BaseToolModel): | |||||||
|         code: str  = Field(..., description="检索代码") |         code: str  = Field(..., description="检索代码") | ||||||
| 
 | 
 | ||||||
|     @classmethod |     @classmethod | ||||||
|     def run(cls, code_base_name, query, code_limit=1, history_node_list=[], search_type="tag", llm_config: LLMConfig=None, embed_config: EmbedConfig=None): |     def run(cls, | ||||||
|  |             code_base_name, | ||||||
|  |             query, | ||||||
|  |             code_limit=1, | ||||||
|  |             history_node_list=[], | ||||||
|  |             search_type="tag", | ||||||
|  |             llm_config: LLMConfig=None, | ||||||
|  |             embed_config: EmbedConfig=None, | ||||||
|  |             use_nh: bool=True, | ||||||
|  |             local_graph_path: str='' | ||||||
|  |             ): | ||||||
|         """excute your tool!""" |         """excute your tool!""" | ||||||
|          |          | ||||||
|         search_type = { |         search_type = { | ||||||
| @ -45,7 +53,8 @@ class CodeRetrieval(BaseToolModel): | |||||||
|         codes = search_code(code_base_name, query, code_limit, search_type=search_type, history_node_list=history_node_list, |         codes = search_code(code_base_name, query, code_limit, search_type=search_type, history_node_list=history_node_list, | ||||||
|                             embed_engine=embed_config.embed_engine, embed_model=embed_config.embed_model, embed_model_path=embed_config.embed_model_path, |                             embed_engine=embed_config.embed_engine, embed_model=embed_config.embed_model, embed_model_path=embed_config.embed_model_path, | ||||||
|                             model_device=embed_config.model_device, model_name=llm_config.model_name, temperature=llm_config.temperature, |                             model_device=embed_config.model_device, model_name=llm_config.model_name, temperature=llm_config.temperature, | ||||||
|                             api_base_url=llm_config.api_base_url, api_key=llm_config.api_key |                             api_base_url=llm_config.api_base_url, api_key=llm_config.api_key, use_nh=use_nh, | ||||||
|  |                             local_graph_path=local_graph_path, embed_config=embed_config | ||||||
|                             ) |                             ) | ||||||
|         return_codes = [] |         return_codes = [] | ||||||
|         context = codes['context'] |         context = codes['context'] | ||||||
|  | |||||||
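Hedged invocation sketch for the extended tool signature; configs and paths are illustrative:

    from coagent.tools import CodeRetrieval
    from coagent.llm_models.llm_config import LLMConfig, EmbedConfig

    codes = CodeRetrieval.run(
        "sofaboot", "where is the HTTP client?",
        code_limit=1, search_type="tag",
        llm_config=LLMConfig(model_name="gpt-3.5-turbo"),               # illustrative
        embed_config=EmbedConfig(embed_model="text2vec",
                                 embed_model_path="/models/text2vec"),  # illustrative
        use_nh=False, local_graph_path="/tmp/graph",                    # illustrative
    )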
| @ -5,6 +5,7 @@ | |||||||
| @time: 2023/12/14 10:24 AM | @time: 2023/12/14 10:24 AM | ||||||
| @desc: | @desc: | ||||||
| ''' | ''' | ||||||
|  | import os | ||||||
| from pydantic import BaseModel, Field | from pydantic import BaseModel, Field | ||||||
| from loguru import logger | from loguru import logger | ||||||
| 
 | 
 | ||||||
| @ -40,10 +41,9 @@ class CodeRetrievalSingle(BaseToolModel): | |||||||
|         vertex: str = Field(..., description="代码对应 id") |         vertex: str = Field(..., description="代码对应 id") | ||||||
| 
 | 
 | ||||||
|     @classmethod |     @classmethod | ||||||
|     def run(cls, code_base_name, query, embed_config: EmbedConfig, llm_config: LLMConfig, **kargs): |     def run(cls, code_base_name, query, embed_config: EmbedConfig, llm_config: LLMConfig, search_type="description", **kargs): | ||||||
|         """excute your tool!""" |         """excute your tool!""" | ||||||
| 
 | 
 | ||||||
|         search_type = 'description' |  | ||||||
|         code_limit = 1 |         code_limit = 1 | ||||||
| 
 | 
 | ||||||
|         # default |         # default | ||||||
| @ -51,10 +51,11 @@ class CodeRetrievalSingle(BaseToolModel): | |||||||
|                             history_node_list=[], |                             history_node_list=[], | ||||||
|                             embed_engine=embed_config.embed_engine, embed_model=embed_config.embed_model, embed_model_path=embed_config.embed_model_path, |                             embed_engine=embed_config.embed_engine, embed_model=embed_config.embed_model, embed_model_path=embed_config.embed_model_path, | ||||||
|                             model_device=embed_config.model_device, model_name=llm_config.model_name, temperature=llm_config.temperature, |                             model_device=embed_config.model_device, model_name=llm_config.model_name, temperature=llm_config.temperature, | ||||||
|                             api_base_url=llm_config.api_base_url, api_key=llm_config.api_key |                             api_base_url=llm_config.api_base_url, api_key=llm_config.api_key, embed_config=embed_config, use_nh=kargs.get("use_nh", True), | ||||||
|  |                             local_graph_path=kargs.get("local_graph_path", "") | ||||||
|                             ) |                             ) | ||||||
| 
 |         if os.environ.get("log_verbose", "0") >= "3": | ||||||
|         logger.debug(search_result) |             logger.debug(search_result) | ||||||
|         code = search_result['context'] |         code = search_result['context'] | ||||||
|         vertex = search_result['related_vertices'][0] |         vertex = search_result['related_vertices'][0] | ||||||
|         # logger.debug(f"code: {code}, vertex: {vertex}") |         # logger.debug(f"code: {code}, vertex: {vertex}") | ||||||
| @ -83,7 +84,7 @@ class RelatedVerticesRetrival(BaseToolModel): | |||||||
|     def run(cls, code_base_name: str, vertex: str, **kargs): |     def run(cls, code_base_name: str, vertex: str, **kargs): | ||||||
|         """execute your tool!""" |         """execute your tool!""" | ||||||
|         related_vertices = search_related_vertices(cb_name=code_base_name, vertex=vertex) |         related_vertices = search_related_vertices(cb_name=code_base_name, vertex=vertex) | ||||||
|         logger.debug(f"related_vertices: {related_vertices}") |         # logger.debug(f"related_vertices: {related_vertices}") | ||||||
| 
 | 
 | ||||||
|         return related_vertices |         return related_vertices | ||||||
| 
 | 
 | ||||||
| @ -110,6 +111,6 @@ class Vertex2Code(BaseToolModel): | |||||||
|         else: |         else: | ||||||
|             vertex = vertex.strip(' "') |             vertex = vertex.strip(' "') | ||||||
| 
 | 
 | ||||||
|         logger.info(f'vertex={vertex}') |         # logger.info(f'vertex={vertex}') | ||||||
|         res = search_code_by_vertex(cb_name=code_base_name, vertex=vertex) |         res = search_code_by_vertex(cb_name=code_base_name, vertex=vertex) | ||||||
|         return res |         return res | ||||||
| @ -2,11 +2,7 @@ from pydantic import BaseModel, Field | |||||||
| from loguru import logger | from loguru import logger | ||||||
| 
 | 
 | ||||||
| from coagent.llm_models.llm_config import EmbedConfig | from coagent.llm_models.llm_config import EmbedConfig | ||||||
| 
 |  | ||||||
| from .base_tool import BaseToolModel | from .base_tool import BaseToolModel | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| from coagent.service.kb_api import search_docs | from coagent.service.kb_api import search_docs | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -9,8 +9,10 @@ import numpy as np | |||||||
| from loguru import logger | from loguru import logger | ||||||
| 
 | 
 | ||||||
| from .base_tool import BaseToolModel | from .base_tool import BaseToolModel | ||||||
| 
 | try: | ||||||
| from duckduckgo_search import DDGS |     from duckduckgo_search import DDGS | ||||||
|  | except ImportError: | ||||||
|  |     logger.warning("duckduckgo_search not found; if you need it, please `pip install duckduckgo_search`") | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class DDGSTool(BaseToolModel): | class DDGSTool(BaseToolModel): | ||||||
|  | |||||||
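One caveat with the guard above: when the import fails, DDGS stays unbound, so any later use raises NameError rather than a clear error. A sketch of a stricter variant (not part of this commit) that binds a sentinel and checks it at call time:

```python
# Sketch only: sentinel-based optional import, checked at call time.
from loguru import logger

try:
    from duckduckgo_search import DDGS
except ImportError:
    DDGS = None
    logger.warning("duckduckgo_search not found; `pip install duckduckgo_search` to enable DDGSTool")

def get_ddgs():
    """Return a DDGS client, or fail with a clear message if the package is missing."""
    if DDGS is None:
        raise ImportError("duckduckgo_search is required for DDGSTool")
    return DDGS()
```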
89  coagent/utils/code2doc_util.py  Normal file
							| @ -0,0 +1,89 @@ | |||||||
|  | import json | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def class_info_decode(data): | ||||||
|  |     '''Parse class-related information''' | ||||||
|  |     params_dict = {} | ||||||
|  | 
 | ||||||
|  |     for i in data: | ||||||
|  |         _params_dict = {} | ||||||
|  |         for ii in i: | ||||||
|  |             for k, v in ii.items(): | ||||||
|  |                 if k=="origin_query": continue | ||||||
|  |                  | ||||||
|  |                 if k == "Code Path": | ||||||
|  |                     _params_dict["code_path"] = v.split("#")[0] | ||||||
|  |                     _params_dict["function_name"] = ".".join(v.split("#")[1:]) | ||||||
|  |                      | ||||||
|  |                 if k == "Class Description": | ||||||
|  |                     _params_dict["ClassDescription"] = v | ||||||
|  | 
 | ||||||
|  |                 if k == "Class Base": | ||||||
|  |                     _params_dict["ClassBase"] = v | ||||||
|  |                      | ||||||
|  |                 if k=="Init Parameters": | ||||||
|  |                     _params_dict["Parameters"] = v | ||||||
|  |                      | ||||||
|  |                      | ||||||
|  |         code_path = _params_dict["code_path"] | ||||||
|  |         params_dict.setdefault(code_path, []).append(_params_dict) | ||||||
|  |          | ||||||
|  |     return params_dict | ||||||
|  | 
 | ||||||
|  | def method_info_decode(data): | ||||||
|  |     params_dict = {} | ||||||
|  | 
 | ||||||
|  |     for i in data: | ||||||
|  |         _params_dict = {} | ||||||
|  |         for ii in i: | ||||||
|  |             for k, v in ii.items(): | ||||||
|  |                 if k=="origin_query": continue | ||||||
|  |                  | ||||||
|  |                 if k == "Code Path": | ||||||
|  |                     _params_dict["code_path"] = v.split("#")[0] | ||||||
|  |                     _params_dict["function_name"] = ".".join(v.split("#")[1:]) | ||||||
|  |                      | ||||||
|  |                 if k == "Return Value Description": | ||||||
|  |                     _params_dict["Returns"] = v | ||||||
|  |                  | ||||||
|  |                 if k == "Return Type": | ||||||
|  |                     _params_dict["ReturnType"] = v | ||||||
|  |                      | ||||||
|  |                 if k=="Parameters": | ||||||
|  |                     _params_dict["Parameters"] = v | ||||||
|  |                      | ||||||
|  |                      | ||||||
|  |         code_path = _params_dict["code_path"] | ||||||
|  |         params_dict.setdefault(code_path, []).append(_params_dict) | ||||||
|  |          | ||||||
|  |     return params_dict | ||||||
|  | 
 | ||||||
|  | def encode2md(data, md_format): | ||||||
|  |     md_dict = {} | ||||||
|  |     for code_path, params_list in data.items(): | ||||||
|  |         for params in params_list: | ||||||
|  |             params["Parameters_text"] = "\n".join([f"{param['param']}({param['param_type']})-{param['param_description']}"   | ||||||
|  |                                     for param in params["Parameters"]]) | ||||||
|  |     #         params.delete("Parameters") | ||||||
|  |             text=md_format.format(**params) | ||||||
|  |             md_dict.setdefault(code_path, []).append(text) | ||||||
|  |     return md_dict | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | method_text_md = '''> {function_name} | ||||||
|  | 
 | ||||||
|  | | Column Name | Content | | ||||||
|  | |-----------------|-----------------| | ||||||
|  | | Parameters   | {Parameters_text} | | ||||||
|  | | Returns   | {Returns} | | ||||||
|  | | Return type   | {ReturnType} | | ||||||
|  | ''' | ||||||
|  | 
 | ||||||
|  | class_text_md = '''> {code_path} | ||||||
|  | 
 | ||||||
|  | Bases: {ClassBase} | ||||||
|  | 
 | ||||||
|  | {ClassDescription} | ||||||
|  | 
 | ||||||
|  | {Parameters_text} | ||||||
|  | ''' | ||||||
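A minimal usage sketch for these helpers; the record below is hypothetical but shaped like the parsed agent output they consume (a list of message lists, each message a dict of parsed fields):

```python
# Hypothetical parsed output for one method; the keys match what
# method_info_decode reads ("Code Path", "Parameters", ...).
from coagent.utils.code2doc_util import (
    method_info_decode, encode2md, method_text_md,
)

method_raw_data = [[
    {"Code Path": "client/StringUtil.java#remove"},
    {"Parameters": [{"param": "s", "param_type": "String",
                     "param_description": "the input string"}]},
    {"Return Value Description": "the string with the target substring removed"},
    {"Return Type": "String"},
]]

method_data = method_info_decode(method_raw_data)    # keyed by code path
method_mds = encode2md(method_data, method_text_md)  # markdown per code path
for code_path, mds in method_mds.items():
    print(code_path, mds[0], sep="\n")
```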
| @ -7,7 +7,7 @@ from pathlib import Path | |||||||
| from io import BytesIO | from io import BytesIO | ||||||
| from fastapi import Body, File, Form, Body, Query, UploadFile | from fastapi import Body, File, Form, Body, Query, UploadFile | ||||||
| from tempfile import SpooledTemporaryFile | from tempfile import SpooledTemporaryFile | ||||||
| 
 | import json | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| DATE_FORMAT = "%Y-%m-%d %H:%M:%S" | DATE_FORMAT = "%Y-%m-%d %H:%M:%S" | ||||||
| @ -110,3 +110,5 @@ def get_uploadfile(file: Union[str, Path, bytes], filename=None) -> UploadFile: | |||||||
|     temp_file.write(file.read()) |     temp_file.write(file.read()) | ||||||
|     temp_file.seek(0) |     temp_file.seek(0) | ||||||
|     return UploadFile(file=temp_file, filename=filename) |     return UploadFile(file=temp_file, filename=filename) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | |||||||
| @ -1,7 +1,7 @@ | |||||||
| import os | import os | ||||||
| from langchain.document_loaders import CSVLoader, PyPDFLoader, UnstructuredFileLoader, TextLoader, PythonLoader | from langchain.document_loaders import CSVLoader, PyPDFLoader, UnstructuredFileLoader, TextLoader, PythonLoader | ||||||
| 
 | 
 | ||||||
| from coagent.document_loaders import JSONLLoader, JSONLoader | from coagent.retrieval.document_loaders import JSONLLoader, JSONLoader | ||||||
| # from configs.model_config import ( | # from configs.model_config import ( | ||||||
| #     embedding_model_dict, | #     embedding_model_dict, | ||||||
| #     KB_ROOT_PATH, | #     KB_ROOT_PATH, | ||||||
|  | |||||||
| @ -21,17 +21,20 @@ JUPYTER_WORK_PATH = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath | |||||||
| # WEB_CRAWL storage path | # WEB_CRAWL storage path | ||||||
| WEB_CRAWL_PATH = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "knowledge_base") | WEB_CRAWL_PATH = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "knowledge_base") | ||||||
| # NEBULA_DATA storage path | # NEBULA_DATA storage path | ||||||
| NELUBA_PATH = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "data/neluba_data") | NEBULA_PATH = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "data/nebula_data") | ||||||
| 
 | 
 | ||||||
| for _path in [LOG_PATH, SOURCE_PATH, KB_ROOT_PATH, NLTK_DATA_PATH, JUPYTER_WORK_PATH, WEB_CRAWL_PATH, NELUBA_PATH]: | # CHROMA storage path | ||||||
|  | CHROMA_PERSISTENT_PATH = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "data/chroma_data") | ||||||
|  | 
 | ||||||
|  | for _path in [LOG_PATH, SOURCE_PATH, KB_ROOT_PATH, CB_ROOT_PATH, NLTK_DATA_PATH, JUPYTER_WORK_PATH, WEB_CRAWL_PATH, NEBULA_PATH, CHROMA_PERSISTENT_PATH]: | ||||||
|     if not os.path.exists(_path): |     if not os.path.exists(_path): | ||||||
|         os.makedirs(_path, exist_ok=True) |         os.makedirs(_path, exist_ok=True) | ||||||
|          |          | ||||||
| #  |  | ||||||
| path_envt_dict = { | path_envt_dict = { | ||||||
|     "LOG_PATH": LOG_PATH, "SOURCE_PATH": SOURCE_PATH, "KB_ROOT_PATH": KB_ROOT_PATH, |     "LOG_PATH": LOG_PATH, "SOURCE_PATH": SOURCE_PATH, "KB_ROOT_PATH": KB_ROOT_PATH, | ||||||
|     "NLTK_DATA_PATH":NLTK_DATA_PATH, "JUPYTER_WORK_PATH": JUPYTER_WORK_PATH, |     "NLTK_DATA_PATH":NLTK_DATA_PATH, "JUPYTER_WORK_PATH": JUPYTER_WORK_PATH, | ||||||
|     "WEB_CRAWL_PATH": WEB_CRAWL_PATH, "NELUBA_PATH": NELUBA_PATH |     "WEB_CRAWL_PATH": WEB_CRAWL_PATH, "NEBULA_PATH": NEBULA_PATH, | ||||||
|  |     "CHROMA_PERSISTENT_PATH": CHROMA_PERSISTENT_PATH | ||||||
|     }         |     }         | ||||||
| for path_name, _path in path_envt_dict.items(): | for path_name, _path in path_envt_dict.items(): | ||||||
|     os.environ[path_name] = _path |     os.environ[path_name] = _path | ||||||
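Because the loop above also exports every path as an environment variable, downstream code can resolve them without importing this config module. A small sketch (the fallback default is an assumption, not part of the commit):

```python
import os

# Read the exported path; the fallback is illustrative, for standalone runs.
chroma_path = os.environ.get("CHROMA_PERSISTENT_PATH", "data/chroma_data")
os.makedirs(chroma_path, exist_ok=True)
```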
|  | |||||||
| @ -33,7 +33,7 @@ except: | |||||||
|     pass |     pass | ||||||
| 
 | 
 | ||||||
| # add your openai key | # add your openai key | ||||||
| OPENAI_API_BASE = "http://openai.com/v1/chat/completions" | OPENAI_API_BASE = "https://api.openai.com/v1" | ||||||
| os.environ["API_BASE_URL"] = OPENAI_API_BASE | os.environ["API_BASE_URL"] = OPENAI_API_BASE | ||||||
| os.environ["OPENAI_API_KEY"] = "sk-xx" | os.environ["OPENAI_API_KEY"] = "sk-xx" | ||||||
| openai.api_key = "sk-xx" | openai.api_key = "sk-xx" | ||||||
|  | |||||||
| @ -58,9 +58,6 @@ NEBULA_GRAPH_SERVER = { | |||||||
|     "docker_port": NEBULA_PORT |     "docker_port": NEBULA_PORT | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| # chroma conf |  | ||||||
| CHROMA_PERSISTENT_PATH = '/home/user/chatbot/data/chroma_data' |  | ||||||
| 
 |  | ||||||
| # sandbox api server | # sandbox api server | ||||||
| SANDBOX_CONTRAINER_NAME = "devopsgpt_sandbox" | SANDBOX_CONTRAINER_NAME = "devopsgpt_sandbox" | ||||||
| SANDBOX_IMAGE_NAME = "devopsgpt:py39" | SANDBOX_IMAGE_NAME = "devopsgpt:py39" | ||||||
|  | |||||||
| @ -15,11 +15,11 @@ from coagent.connector.schema import Message | |||||||
| # | # | ||||||
| tools = toLangchainTools([TOOL_DICT[i] for i in TOOL_SETS if i in TOOL_DICT]) | tools = toLangchainTools([TOOL_DICT[i] for i in TOOL_SETS if i in TOOL_DICT]) | ||||||
| # log-level: print prompt and llm predict | # log-level: print prompt and llm predict | ||||||
| os.environ["log_verbose"] = "0" | os.environ["log_verbose"] = "2" | ||||||
| 
 | 
 | ||||||
| phase_name = "baseGroupPhase" | phase_name = "baseGroupPhase" | ||||||
| llm_config = LLMConfig( | llm_config = LLMConfig( | ||||||
|     model_name=LLM_MODEL, model_device="cpu",api_key=os.environ["OPENAI_API_KEY"],  |     model_name=LLM_MODEL, api_key=os.environ["OPENAI_API_KEY"],  | ||||||
|     api_base_url=os.environ["API_BASE_URL"], temperature=0.3 |     api_base_url=os.environ["API_BASE_URL"], temperature=0.3 | ||||||
|     ) |     ) | ||||||
| embed_config = EmbedConfig( | embed_config = EmbedConfig( | ||||||
|  | |||||||
| @ -17,7 +17,7 @@ os.environ["log_verbose"] = "2" | |||||||
| 
 | 
 | ||||||
| phase_name = "baseTaskPhase" | phase_name = "baseTaskPhase" | ||||||
| llm_config = LLMConfig( | llm_config = LLMConfig( | ||||||
|     model_name="gpt-3.5-turbo", model_device="cpu",api_key=os.environ["OPENAI_API_KEY"],  |     model_name="gpt-3.5-turbo", api_key=os.environ["OPENAI_API_KEY"],  | ||||||
|     api_base_url=os.environ["API_BASE_URL"], temperature=0.3 |     api_base_url=os.environ["API_BASE_URL"], temperature=0.3 | ||||||
|     ) |     ) | ||||||
| embed_config = EmbedConfig( | embed_config = EmbedConfig( | ||||||
|  | |||||||
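The same cleanup recurs across the example scripts: model_device is an embedding-side concern, so LLMConfig drops it while EmbedConfig keeps it. A sketch of the two configs after the change (passing model_device to EmbedConfig's constructor is an assumption inferred from the attribute reads elsewhere in this diff):

```python
import os
from coagent.llm_models.llm_config import LLMConfig, EmbedConfig

llm_config = LLMConfig(
    model_name="gpt-3.5-turbo", api_key=os.environ["OPENAI_API_KEY"],
    api_base_url=os.environ["API_BASE_URL"], temperature=0.3,
)  # note: no model_device here any more
embed_config = EmbedConfig(
    embed_engine="model", embed_model="text2vec-base-chinese",
    embed_model_path="embedding_models/text2vec-base-chinese",
    model_device="cpu",  # assumption: the device setting now lives only here
)
```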
135  examples/agent_examples/codeChatPhaseLocal_example.py  Normal file
							| @ -0,0 +1,135 @@ | |||||||
|  | # encoding: utf-8 | ||||||
|  | ''' | ||||||
|  | @author: 温进 | ||||||
|  | @file: codeChatPhaseLocal_example.py | ||||||
|  | @time: 2024/1/31 4:32 PM | ||||||
|  | @desc: | ||||||
|  | ''' | ||||||
|  | import os, sys, requests | ||||||
|  | from concurrent.futures import ThreadPoolExecutor | ||||||
|  | from tqdm import tqdm | ||||||
|  | 
 | ||||||
|  | import requests | ||||||
|  | from typing import List | ||||||
|  | 
 | ||||||
|  | src_dir = os.path.join( | ||||||
|  |     os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) | ||||||
|  | ) | ||||||
|  | sys.path.append(src_dir) | ||||||
|  | 
 | ||||||
|  | from configs.model_config import KB_ROOT_PATH, JUPYTER_WORK_PATH, CB_ROOT_PATH | ||||||
|  | from configs.server_config import SANDBOX_SERVER | ||||||
|  | from coagent.tools import toLangchainTools, TOOL_DICT, TOOL_SETS | ||||||
|  | from coagent.llm_models.llm_config import EmbedConfig, LLMConfig | ||||||
|  | from coagent.connector.phase import BasePhase | ||||||
|  | from coagent.connector.schema import Message, Memory | ||||||
|  | from coagent.codechat.codebase_handler.codebase_handler import CodeBaseHandler | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | # log-level: print prompt and llm predict | ||||||
|  | os.environ["log_verbose"] = "1" | ||||||
|  | 
 | ||||||
|  | llm_config = LLMConfig( | ||||||
|  |     model_name="gpt-3.5-turbo", model_device="cpu",api_key=os.environ["OPENAI_API_KEY"], | ||||||
|  |     api_base_url=os.environ["API_BASE_URL"], temperature=0.3 | ||||||
|  |     ) | ||||||
|  | embed_config = EmbedConfig( | ||||||
|  |     embed_engine="model", embed_model="text2vec-base-chinese", | ||||||
|  |     embed_model_path=os.path.join(src_dir, "embedding_models/text2vec-base-chinese") | ||||||
|  |     ) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | # delete codebase | ||||||
|  | codebase_name = 'client_local' | ||||||
|  | code_path = '/Users/bingxu/Desktop/工作/大模型/chatbot/test_code_repo/client' | ||||||
|  | code_path = "D://chromeDownloads/devopschat-bot/client_v2/client" | ||||||
|  | use_nh = True | ||||||
|  | # cbh = CodeBaseHandler(codebase_name, code_path, crawl_type='dir', use_nh=use_nh, local_graph_path=CB_ROOT_PATH, | ||||||
|  | #                       llm_config=llm_config, embed_config=embed_config) | ||||||
|  | cbh = CodeBaseHandler(codebase_name, code_path, crawl_type='dir', use_nh=use_nh, local_graph_path=CB_ROOT_PATH, | ||||||
|  |                       llm_config=llm_config, embed_config=embed_config) | ||||||
|  | cbh.delete_codebase(codebase_name=codebase_name) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | # initialize codebase | ||||||
|  | codebase_name = 'client_local' | ||||||
|  | code_path = '/Users/bingxu/Desktop/工作/大模型/chatbot/test_code_repo/client' | ||||||
|  | code_path = "D://chromeDownloads/devopschat-bot/client_v2/client" | ||||||
|  | code_path = "/home/user/client" | ||||||
|  | use_nh = True | ||||||
|  | do_interpret = True | ||||||
|  | cbh = CodeBaseHandler(codebase_name, code_path, crawl_type='dir', use_nh=use_nh, local_graph_path=CB_ROOT_PATH, | ||||||
|  |                       llm_config=llm_config, embed_config=embed_config) | ||||||
|  | cbh.import_code(do_interpret=do_interpret) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | # chat with codebase | ||||||
|  | phase_name = "codeChatPhase" | ||||||
|  | phase = BasePhase( | ||||||
|  |     phase_name, sandbox_server=SANDBOX_SERVER, jupyter_work_path=JUPYTER_WORK_PATH, | ||||||
|  |     embed_config=embed_config, llm_config=llm_config, kb_root_path=KB_ROOT_PATH, | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | # remove 这个函数是做什么的 (what does the remove function do?)  => tag-based search | ||||||
|  | # 有没有函数已经实现了从字符串删除指定字符串的功能,使用的话可以怎么使用,写个java代码 (is there an existing function that strips a given substring, and how would I use it in Java?)  => description-based search | ||||||
|  | # 有根据我以下的需求用 java 开发一个方法:输入为字符串,将输入中的 .java 字符串给删除掉,然后返回新的字符串 (write a Java method that removes ".java" from the input string and returns the result) => description-based search | ||||||
|  | 
 | ||||||
|  | ## Requires NebulaGraph running in the container: build the codebase with use_nh=True to enable cypher queries | ||||||
|  | # round-1 | ||||||
|  | # query_content = "代码一共有多少类" | ||||||
|  | # query = Message( | ||||||
|  | #     role_name="human", role_type="user", | ||||||
|  | #     role_content=query_content, input_query=query_content, origin_query=query_content, | ||||||
|  | #     code_engine_name="client_1", score_threshold=1.0, top_k=3, cb_search_type="cypher" | ||||||
|  | #     ) | ||||||
|  | # | ||||||
|  | # output_message1, _ = phase.step(query) | ||||||
|  | # print(output_message1) | ||||||
|  | 
 | ||||||
|  | # round-2 | ||||||
|  | # query_content = "代码库里有哪些函数,返回5个就行" | ||||||
|  | # query = Message( | ||||||
|  | #     role_name="human", role_type="user", | ||||||
|  | #     role_content=query_content, input_query=query_content, origin_query=query_content, | ||||||
|  | #     code_engine_name="client_1", score_threshold=1.0, top_k=3, cb_search_type="cypher" | ||||||
|  | #     ) | ||||||
|  | # output_message2, _ = phase.step(query) | ||||||
|  | # print(output_message2) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | # round-3 | ||||||
|  | query_content = "remove 这个函数是做什么的" | ||||||
|  | query = Message( | ||||||
|  |     role_name="user", role_type="human", | ||||||
|  |     role_content=query_content, input_query=query_content, origin_query=query_content, | ||||||
|  |     code_engine_name=codebase_name, score_threshold=1.0, top_k=3, cb_search_type="tag", | ||||||
|  |     use_nh=False, local_graph_path=CB_ROOT_PATH | ||||||
|  |     ) | ||||||
|  | output_message3, output_memory3 = phase.step(query) | ||||||
|  | print(output_memory3.to_str_messages(return_all=True, content_key="parsed_output_list")) | ||||||
|  | 
 | ||||||
|  | # | ||||||
|  | # # round-4 | ||||||
|  | query_content = "有没有函数已经实现了从字符串删除指定字符串的功能,使用的话可以怎么使用,写个java代码" | ||||||
|  | query = Message( | ||||||
|  |     role_name="human", role_type="user", | ||||||
|  |     role_content=query_content, input_query=query_content, origin_query=query_content, | ||||||
|  |     code_engine_name=codebase_name, score_threshold=1.0, top_k=3, cb_search_type="description", | ||||||
|  |     use_nh=False, local_graph_path=CB_ROOT_PATH | ||||||
|  |     ) | ||||||
|  | output_message4, output_memory4 = phase.step(query) | ||||||
|  | print(output_memory4.to_str_messages(return_all=True, content_key="parsed_output_list")) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | # # round-5 | ||||||
|  | query_content = "有根据我以下的需求用 java 开发一个方法:输入为字符串,将输入中的 .java 字符串给删除掉,然后返回新的字符串" | ||||||
|  | query = Message( | ||||||
|  |     role_name="human", role_type="user", | ||||||
|  |     role_content=query_content, input_query=query_content, origin_query=query_content, | ||||||
|  |     code_engine_name=codebase_name, score_threshold=1.0, top_k=3, cb_search_type="description", | ||||||
|  |     use_nh=False, local_graph_path=CB_ROOT_PATH | ||||||
|  |     ) | ||||||
|  | output_message5, output_memory5 = phase.step(query) | ||||||
|  | print(output_memory5.to_str_messages(return_all=True, content_key="parsed_output_list")) | ||||||
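The rounds above repeat the same Message boilerplate; a tiny helper (a sketch, not part of the commit, assuming the script's Message, codebase_name and CB_ROOT_PATH) would keep such demos terse:

```python
def ask(phase, query_content, cb_search_type="description"):
    """Sketch: wrap one question into a Message and run a single phase step."""
    query = Message(
        role_name="human", role_type="user",
        role_content=query_content, input_query=query_content,
        origin_query=query_content,
        code_engine_name=codebase_name, score_threshold=1.0, top_k=3,
        cb_search_type=cb_search_type,
        use_nh=False, local_graph_path=CB_ROOT_PATH,
    )
    return phase.step(query)

# e.g. the tag-based round-3 becomes:
# output_message3, output_memory3 = ask(phase, "remove 这个函数是做什么的", "tag")
```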
| @ -17,13 +17,14 @@ os.environ["log_verbose"] = "2" | |||||||
| 
 | 
 | ||||||
| phase_name = "codeChatPhase" | phase_name = "codeChatPhase" | ||||||
| llm_config = LLMConfig( | llm_config = LLMConfig( | ||||||
|     model_name="gpt-3.5-turbo", model_device="cpu",api_key=os.environ["OPENAI_API_KEY"],  |     model_name="gpt-3.5-turbo", api_key=os.environ["OPENAI_API_KEY"],  | ||||||
|     api_base_url=os.environ["API_BASE_URL"], temperature=0.3 |     api_base_url=os.environ["API_BASE_URL"], temperature=0.3 | ||||||
|     ) |     ) | ||||||
| embed_config = EmbedConfig( | embed_config = EmbedConfig( | ||||||
|     embed_engine="model", embed_model="text2vec-base-chinese", |     embed_engine="model", embed_model="text2vec-base-chinese", | ||||||
|     embed_model_path=os.path.join(src_dir, "embedding_models/text2vec-base-chinese") |     embed_model_path=os.path.join(src_dir, "embedding_models/text2vec-base-chinese") | ||||||
|     ) |     ) | ||||||
|  | 
 | ||||||
| phase = BasePhase( | phase = BasePhase( | ||||||
|     phase_name, sandbox_server=SANDBOX_SERVER, jupyter_work_path=JUPYTER_WORK_PATH, |     phase_name, sandbox_server=SANDBOX_SERVER, jupyter_work_path=JUPYTER_WORK_PATH, | ||||||
|     embed_config=embed_config, llm_config=llm_config, kb_root_path=KB_ROOT_PATH, |     embed_config=embed_config, llm_config=llm_config, kb_root_path=KB_ROOT_PATH, | ||||||
| @ -35,25 +36,28 @@ phase = BasePhase( | |||||||
| # 有根据我以下的需求用 java 开发一个方法:输入为字符串,将输入中的 .java 字符串给删除掉,然后返回新的字符串 => 基于描述 | # 有根据我以下的需求用 java 开发一个方法:输入为字符串,将输入中的 .java 字符串给删除掉,然后返回新的字符串 => 基于描述 | ||||||
| 
 | 
 | ||||||
| # round-1 | # round-1 | ||||||
| query_content = "代码一共有多少类" | # query_content = "代码一共有多少类" | ||||||
| query = Message( | # query = Message( | ||||||
|     role_name="human", role_type="user",  | #     role_name="human", role_type="user", | ||||||
|     role_content=query_content, input_query=query_content, origin_query=query_content, | #     role_content=query_content, input_query=query_content, origin_query=query_content, | ||||||
|     code_engine_name="client", score_threshold=1.0, top_k=3, cb_search_type="cypher" | #     code_engine_name="client_1", score_threshold=1.0, top_k=3, cb_search_type="cypher" | ||||||
|     ) | #     ) | ||||||
| 
 | # | ||||||
| output_message1, _ = phase.step(query) | # output_message1, _ = phase.step(query) | ||||||
|  | # print(output_message1) | ||||||
| 
 | 
 | ||||||
| # round-2 | # round-2 | ||||||
| query_content = "代码库里有哪些函数,返回5个就行" | # query_content = "代码库里有哪些函数,返回5个就行" | ||||||
| query = Message( | # query = Message( | ||||||
|     role_name="human", role_type="user",  | #     role_name="human", role_type="user", | ||||||
|     role_content=query_content, input_query=query_content, origin_query=query_content, | #     role_content=query_content, input_query=query_content, origin_query=query_content, | ||||||
|     code_engine_name="client", score_threshold=1.0, top_k=3, cb_search_type="cypher" | #     code_engine_name="client_1", score_threshold=1.0, top_k=3, cb_search_type="cypher" | ||||||
|     ) | #     ) | ||||||
| output_message2, _ = phase.step(query) | # output_message2, _ = phase.step(query) | ||||||
|  | # print(output_message2) | ||||||
| 
 | 
 | ||||||
| # round-3 | # | ||||||
|  | # # round-3 | ||||||
| query_content = "remove 这个函数是做什么的" | query_content = "remove 这个函数是做什么的" | ||||||
| query = Message( | query = Message( | ||||||
|     role_name="user", role_type="human", |     role_name="user", role_type="human", | ||||||
| @ -61,24 +65,27 @@ query = Message( | |||||||
|     code_engine_name="client", score_threshold=1.0, top_k=3, cb_search_type="tag" |     code_engine_name="client", score_threshold=1.0, top_k=3, cb_search_type="tag" | ||||||
|     ) |     ) | ||||||
| output_message3, _ = phase.step(query) | output_message3, _ = phase.step(query) | ||||||
|  | print(output_message3) | ||||||
| 
 | 
 | ||||||
| # round-4 | # | ||||||
| query_content = "有没有函数已经实现了从字符串删除指定字符串的功能,使用的话可以怎么使用,写个java代码" | # # round-4 | ||||||
| query = Message( | # query_content = "有没有函数已经实现了从字符串删除指定字符串的功能,使用的话可以怎么使用,写个java代码" | ||||||
|     role_name="human", role_type="user",  | # query = Message( | ||||||
|     role_content=query_content, input_query=query_content, origin_query=query_content, | #     role_name="human", role_type="user", | ||||||
|     code_engine_name="client", score_threshold=1.0, top_k=3, cb_search_type="description" | #     role_content=query_content, input_query=query_content, origin_query=query_content, | ||||||
|     ) | #     code_engine_name="client_1", score_threshold=1.0, top_k=3, cb_search_type="description" | ||||||
| output_message4, _ = phase.step(query) | #     ) | ||||||
| 
 | # output_message4, _ = phase.step(query) | ||||||
| 
 | # print(output_message4) | ||||||
| # round-5 | # | ||||||
| query_content = "有根据我以下的需求用 java 开发一个方法:输入为字符串,将输入中的 .java 字符串给删除掉,然后返回新的字符串" | # # round-5 | ||||||
| query = Message( | # query_content = "有根据我以下的需求用 java 开发一个方法:输入为字符串,将输入中的 .java 字符串给删除掉,然后返回新的字符串" | ||||||
|     role_name="human", role_type="user",  | # query = Message( | ||||||
|     role_content=query_content, input_query=query_content, origin_query=query_content, | #     role_name="human", role_type="user", | ||||||
|     code_engine_name="client", score_threshold=1.0, top_k=3, cb_search_type="description" | #     role_content=query_content, input_query=query_content, origin_query=query_content, | ||||||
|     ) | #     code_engine_name="client_1", score_threshold=1.0, top_k=3, cb_search_type="description" | ||||||
| output_message5, output_memory5 = phase.step(query) | #     ) | ||||||
| 
 | # output_message5, output_memory5 = phase.step(query) | ||||||
| print(output_memory5.to_str_messages(return_all=True, content_key="parsed_output_list")) | # print(output_message5) | ||||||
|  | # | ||||||
|  | # print(output_memory5.to_str_messages(return_all=True, content_key="parsed_output_list")) | ||||||
507  examples/agent_examples/codeGenDoc_example.py  Normal file
							| @ -0,0 +1,507 @@ | |||||||
|  | import os, sys, json | ||||||
|  | from loguru import logger | ||||||
|  | src_dir = os.path.join( | ||||||
|  |     os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) | ||||||
|  | ) | ||||||
|  | sys.path.append(src_dir) | ||||||
|  | 
 | ||||||
|  | from configs.model_config import KB_ROOT_PATH, JUPYTER_WORK_PATH, CB_ROOT_PATH | ||||||
|  | from configs.server_config import SANDBOX_SERVER | ||||||
|  | from coagent.llm_models.llm_config import EmbedConfig, LLMConfig | ||||||
|  | 
 | ||||||
|  | from coagent.connector.phase import BasePhase | ||||||
|  | from coagent.connector.agents import BaseAgent | ||||||
|  | from coagent.connector.schema import Message | ||||||
|  | from coagent.tools import CodeRetrievalSingle | ||||||
|  | from coagent.codechat.codebase_handler.codebase_handler import CodeBaseHandler | ||||||
|  | import importlib | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | # Define a new agent class | ||||||
|  | class CodeGenDocer(BaseAgent): | ||||||
|  | 
 | ||||||
|  |     def start_action_step(self, message: Message) -> Message: | ||||||
|  |         '''do action before agent predict ''' | ||||||
|  |         # 根据问题获取代码片段和节点信息 | ||||||
|  |         action_json = CodeRetrievalSingle.run(message.code_engine_name, message.origin_query,  | ||||||
|  |                                               llm_config=self.llm_config, embed_config=self.embed_config, local_graph_path=message.local_graph_path, use_nh=message.use_nh,search_type="tag") | ||||||
|  |         current_vertex = action_json['vertex'] | ||||||
|  |         message.customed_kargs["Code Snippet"] = action_json["code"] | ||||||
|  |         message.customed_kargs['Current_Vertex'] = current_vertex | ||||||
|  |         return message | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | # add agent or prompt_manager class | ||||||
|  | agent_module = importlib.import_module("coagent.connector.agents") | ||||||
|  | setattr(agent_module, 'CodeGenDocer', CodeGenDocer) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | # log-level: print prompt and llm predict | ||||||
|  | os.environ["log_verbose"] = "1" | ||||||
|  | 
 | ||||||
|  | phase_name = "code2DocsGroup" | ||||||
|  | llm_config = LLMConfig( | ||||||
|  |     model_name="gpt-4", api_key=os.environ["OPENAI_API_KEY"],  | ||||||
|  |     api_base_url=os.environ["API_BASE_URL"], temperature=0.3 | ||||||
|  | ) | ||||||
|  | embed_config = EmbedConfig( | ||||||
|  |     embed_engine="model", embed_model="text2vec-base-chinese",  | ||||||
|  |     embed_model_path=os.path.join(src_dir, "embedding_models/text2vec-base-chinese") | ||||||
|  |     ) | ||||||
|  | 
 | ||||||
|  | # initialize codebase | ||||||
|  | # delete codebase | ||||||
|  | codebase_name = 'client_local' | ||||||
|  | code_path = "D://chromeDownloads/devopschat-bot/client_v2/client" | ||||||
|  | use_nh = False | ||||||
|  | cbh = CodeBaseHandler(codebase_name, code_path, crawl_type='dir', use_nh=use_nh, local_graph_path=CB_ROOT_PATH, | ||||||
|  |                       llm_config=llm_config, embed_config=embed_config) | ||||||
|  | cbh.delete_codebase(codebase_name=codebase_name) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | # load codebase | ||||||
|  | codebase_name = 'client_local' | ||||||
|  | code_path = "D://chromeDownloads/devopschat-bot/client_v2/client" | ||||||
|  | use_nh = True | ||||||
|  | do_interpret = True | ||||||
|  | cbh = CodeBaseHandler(codebase_name, code_path, crawl_type='dir', use_nh=use_nh, local_graph_path=CB_ROOT_PATH, | ||||||
|  |                       llm_config=llm_config, embed_config=embed_config) | ||||||
|  | cbh.import_code(do_interpret=do_interpret) | ||||||
|  | 
 | ||||||
|  | # Re-initialize from the load step above | ||||||
|  | cbh = CodeBaseHandler(codebase_name, code_path, crawl_type='dir', use_nh=use_nh, local_graph_path=CB_ROOT_PATH, | ||||||
|  |                       llm_config=llm_config, embed_config=embed_config) | ||||||
|  | phase = BasePhase( | ||||||
|  |     phase_name, sandbox_server=SANDBOX_SERVER, jupyter_work_path=JUPYTER_WORK_PATH, | ||||||
|  |     embed_config=embed_config, llm_config=llm_config, kb_root_path=KB_ROOT_PATH, | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | for vertex_type in ["class", "method"]: | ||||||
|  |     vertexes = cbh.search_vertices(vertex_type=vertex_type) | ||||||
|  |     logger.info(f"vertexes={vertexes}") | ||||||
|  | 
 | ||||||
|  |     # round-1 | ||||||
|  |     docs = [] | ||||||
|  |     for vertex in vertexes: | ||||||
|  |         vertex = vertex.split("-")[0] # -为method的参数 | ||||||
|  |         query_content = f"为{vertex_type}节点 {vertex}生成文档" | ||||||
|  |         query = Message( | ||||||
|  |             role_name="human", role_type="user",  | ||||||
|  |             role_content=query_content, input_query=query_content, origin_query=query_content, | ||||||
|  |             code_engine_name="client_local", score_threshold=1.0, top_k=3, cb_search_type="tag", use_nh=use_nh, | ||||||
|  |             local_graph_path=CB_ROOT_PATH, | ||||||
|  |             ) | ||||||
|  |         output_message, output_memory = phase.step(query, reinit_memory=True) | ||||||
|  |         # print(output_memory.to_str_messages(return_all=True, content_key="parsed_output_list")) | ||||||
|  |         docs.append(output_memory.get_spec_parserd_output()) | ||||||
|  | 
 | ||||||
|  |         os.makedirs(f"{CB_ROOT_PATH}/docs", exist_ok=True) | ||||||
|  |         with open(f"{CB_ROOT_PATH}/docs/raw_{vertex_type}.json", "w") as f: | ||||||
|  |             json.dump(docs, f) | ||||||
|  |      | ||||||
|  | 
 | ||||||
|  | # Convert the generated doc info into markdown text below | ||||||
|  | from coagent.utils.code2doc_util import * | ||||||
|  | import json | ||||||
|  | with open(f"{CB_ROOT_PATH}/docs/raw_method.json", "r") as f: | ||||||
|  |     method_raw_data = json.load(f) | ||||||
|  | 
 | ||||||
|  | with open(f"{CB_ROOT_PATH}/docs/raw_class.json", "r") as f: | ||||||
|  |     class_raw_data = json.load(f) | ||||||
|  |      | ||||||
|  | 
 | ||||||
|  | method_data = method_info_decode(method_raw_data) | ||||||
|  | class_data = class_info_decode(class_raw_data) | ||||||
|  | method_mds = encode2md(method_data, method_text_md) | ||||||
|  | class_mds = encode2md(class_data, class_text_md) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | docs_dict = {} | ||||||
|  | for k,v in class_mds.items(): | ||||||
|  |     method_textmds = method_mds.get(k, []) | ||||||
|  |     for vv in v: | ||||||
|  |         # in theory there is only one | ||||||
|  |         text_md = vv | ||||||
|  | 
 | ||||||
|  |     for method_textmd in method_textmds: | ||||||
|  |         text_md += "\n<br>" + method_textmd | ||||||
|  | 
 | ||||||
|  |     docs_dict.setdefault(k, []).append(text_md) | ||||||
|  |      | ||||||
|  |     with open(f"{CB_ROOT_PATH}//docs/{k}.md", "w") as f: | ||||||
|  |         f.write(text_md) | ||||||
|  |      | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | #################################### | ||||||
|  | ######## Full reproduction walkthrough below ######## | ||||||
|  | #################################### | ||||||
|  | 
 | ||||||
|  | # import os, sys, requests | ||||||
|  | # from loguru import logger | ||||||
|  | # src_dir = os.path.join( | ||||||
|  | #     os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) | ||||||
|  | # ) | ||||||
|  | # sys.path.append(src_dir) | ||||||
|  | 
 | ||||||
|  | # from configs.model_config import KB_ROOT_PATH, JUPYTER_WORK_PATH, CB_ROOT_PATH | ||||||
|  | # from configs.server_config import SANDBOX_SERVER | ||||||
|  | # from coagent.tools import toLangchainTools, TOOL_DICT, TOOL_SETS | ||||||
|  | # from coagent.llm_models.llm_config import EmbedConfig, LLMConfig | ||||||
|  | 
 | ||||||
|  | # from coagent.connector.phase import BasePhase | ||||||
|  | # from coagent.connector.agents import BaseAgent, SelectorAgent | ||||||
|  | # from coagent.connector.chains import BaseChain | ||||||
|  | # from coagent.connector.schema import ( | ||||||
|  | #     Message, Memory, load_role_configs, load_phase_configs, load_chain_configs, ActionStatus | ||||||
|  | #     ) | ||||||
|  | # from coagent.connector.memory_manager import BaseMemoryManager | ||||||
|  | # from coagent.connector.configs import AGETN_CONFIGS, CHAIN_CONFIGS, PHASE_CONFIGS, BASE_PROMPT_CONFIGS | ||||||
|  | # from coagent.connector.prompt_manager.prompt_manager import PromptManager | ||||||
|  | # from coagent.codechat.codebase_handler.codebase_handler import CodeBaseHandler | ||||||
|  | 
 | ||||||
|  | # import importlib | ||||||
|  | # from loguru import logger | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | # from coagent.tools import CodeRetrievalSingle, RelatedVerticesRetrival, Vertex2Code | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | # # update new agent configs | ||||||
|  | # codeGenDocGroup_PROMPT = """#### Agent Profile | ||||||
|  | 
 | ||||||
|  | # Your goal is to respond, according to the Context Data's information, with the role that will best facilitate a solution, taking into account all relevant context (Context) provided. | ||||||
|  | 
 | ||||||
|  | # When you need to select the appropriate role for handling a user's query, carefully read the provided role names, role descriptions and tool list. | ||||||
|  | 
 | ||||||
|  | # ATTENTION: respond carefully, strictly following the "Response Output Format". | ||||||
|  | 
 | ||||||
|  | # #### Input Format | ||||||
|  | 
 | ||||||
|  | # #### Response Output Format | ||||||
|  | 
 | ||||||
|  | # **Code Path:** Extract the paths for the class/method/function that need to be addressed from the context | ||||||
|  | 
 | ||||||
|  | # **Role:** Select the role from agent names | ||||||
|  | # """ | ||||||
|  | 
 | ||||||
|  | # classGenDoc_PROMPT = """#### Agent Profile | ||||||
|  | # As an advanced code documentation generator, you are proficient in translating class definitions into comprehensive documentation with a focus on instantiation parameters.  | ||||||
|  | # Your specific task is to parse the given code snippet of a class and extract information regarding its instantiation parameters. | ||||||
|  | 
 | ||||||
|  | # ATTENTION: respond carefully in the "Response Output Format". | ||||||
|  | 
 | ||||||
|  | # #### Input Format | ||||||
|  | 
 | ||||||
|  | # **Code Snippet:** Provide the full class definition, including the constructor and any parameters it may require for instantiation. | ||||||
|  | 
 | ||||||
|  | # #### Response Output Format | ||||||
|  | # **Class Base:** Specify the base class or interface from which the current class extends, if any. | ||||||
|  | 
 | ||||||
|  | # **Class Description:** Offer a brief description of the class's purpose and functionality. | ||||||
|  | 
 | ||||||
|  | # **Init Parameters:** List each parameter of the constructor. For each parameter, provide: | ||||||
|  | #     - `param`: The parameter name | ||||||
|  | #     - `param_description`: A concise explanation of the parameter's purpose. | ||||||
|  | #     - `param_type`: The data type of the parameter, if explicitly defined. | ||||||
|  | 
 | ||||||
|  | #     ```json | ||||||
|  | #     [ | ||||||
|  | #         { | ||||||
|  | #             "param": "parameter_name", | ||||||
|  | #             "param_description": "A brief description of what this parameter is used for.", | ||||||
|  | #             "param_type": "The data type of the parameter" | ||||||
|  | #         }, | ||||||
|  | #         ... | ||||||
|  | #     ] | ||||||
|  | #     ``` | ||||||
|  | 
 | ||||||
|  |          | ||||||
|  | #     If the constructor has no parameters, return  | ||||||
|  | #     ```json | ||||||
|  | #     [] | ||||||
|  | #     ``` | ||||||
|  | # """ | ||||||
|  | 
 | ||||||
|  | # funcGenDoc_PROMPT = """#### Agent Profile | ||||||
|  | # You are a high-level code documentation assistant, skilled at extracting information from function/method code into detailed and well-structured documentation. | ||||||
|  | 
 | ||||||
|  | # ATTENTION: respond carefully in the "Response Output Format". | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | # #### Input Format | ||||||
|  | # **Code Path:** Provide the code path of the function or method you wish to document.  | ||||||
|  | # This name will be used to identify and extract the relevant details from the code snippet provided. | ||||||
|  |      | ||||||
|  | # **Code Snippet:** A segment of code that contains the function or method to be documented. | ||||||
|  | 
 | ||||||
|  | # #### Response Output Format | ||||||
|  | 
 | ||||||
|  | # **Method Description:** Offer a brief description of the method (function)'s purpose and functionality. | ||||||
|  | 
 | ||||||
|  | # **Parameters:** Extract the parameters of the specific function/method from the Code Snippet. For each parameter, provide: | ||||||
|  | #     - `param`: The parameter name | ||||||
|  | #     - `param_description`: A concise explanation of the parameter's purpose. | ||||||
|  | #     - `param_type`: The data type of the parameter, if explicitly defined. | ||||||
|  | #     ```json | ||||||
|  | #     [ | ||||||
|  | #         { | ||||||
|  | #             "param": "parameter_name", | ||||||
|  | #             "param_description": "A brief description of what this parameter is used for.", | ||||||
|  | #             "param_type": "The data type of the parameter" | ||||||
|  | #         }, | ||||||
|  | #         ... | ||||||
|  | #     ] | ||||||
|  | #     ``` | ||||||
|  | 
 | ||||||
|  | #     If the function/method has no parameters, return  | ||||||
|  | #     ```json | ||||||
|  | #     [] | ||||||
|  | #     ``` | ||||||
|  | 
 | ||||||
|  | # **Return Value Description:** Describe what the function/method returns upon completion. | ||||||
|  | 
 | ||||||
|  | # **Return Type:** Indicate the type of data the function/method returns (e.g., string, integer, object, void). | ||||||
|  | # """ | ||||||
|  | 
 | ||||||
|  | # CODE_GENERATE_GROUP_PROMPT_CONFIGS = [ | ||||||
|  | #     {"field_name": 'agent_profile', "function_name": 'handle_agent_profile', "is_context": False}, | ||||||
|  | #     {"field_name": 'agent_infomation', "function_name": 'handle_agent_data', "is_context": False, "omit_if_empty": False}, | ||||||
|  | #     # {"field_name": 'tool_information',"function_name": 'handle_tool_data', "is_context": False}, | ||||||
|  | #     {"field_name": 'context_placeholder', "function_name": '', "is_context": True}, | ||||||
|  | #     # {"field_name": 'reference_documents', "function_name": 'handle_doc_info'}, | ||||||
|  | #     {"field_name": 'session_records', "function_name": 'handle_session_records'}, | ||||||
|  | #     {"field_name": 'Specific Objective', "function_name": 'handle_specific_objective'}, | ||||||
|  | #     {"field_name": 'Code Snippet', "function_name": 'handle_code_snippet'}, | ||||||
|  | #     {"field_name": 'output_format', "function_name": 'handle_output_format', 'title': 'Response Output Format', "is_context": False}, | ||||||
|  | #     {"field_name": 'begin!!!', "function_name": 'handle_response', "is_context": False, "omit_if_empty": False} | ||||||
|  | # ] | ||||||
|  | 
 | ||||||
|  | # CODE_GENERATE_DOC_PROMPT_CONFIGS = [ | ||||||
|  | #     {"field_name": 'agent_profile', "function_name": 'handle_agent_profile', "is_context": False}, | ||||||
|  | #     # {"field_name": 'tool_information',"function_name": 'handle_tool_data', "is_context": False}, | ||||||
|  | #     {"field_name": 'context_placeholder', "function_name": '', "is_context": True}, | ||||||
|  | #     # {"field_name": 'reference_documents', "function_name": 'handle_doc_info'}, | ||||||
|  | #     {"field_name": 'session_records', "function_name": 'handle_session_records'}, | ||||||
|  | #     {"field_name": 'Specific Objective', "function_name": 'handle_specific_objective'}, | ||||||
|  | #     {"field_name": 'Code Snippet', "function_name": 'handle_code_snippet'}, | ||||||
|  | #     {"field_name": 'output_format', "function_name": 'handle_output_format', 'title': 'Response Output Format', "is_context": False}, | ||||||
|  | #     {"field_name": 'begin!!!', "function_name": 'handle_response', "is_context": False, "omit_if_empty": False} | ||||||
|  | # ] | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | # class CodeGenDocPM(PromptManager): | ||||||
|  | #     def handle_code_snippet(self, **kwargs) -> str: | ||||||
|  | #         if 'previous_agent_message' not in kwargs: | ||||||
|  | #             return "" | ||||||
|  | #         previous_agent_message: Message = kwargs['previous_agent_message'] | ||||||
|  | #         code_snippet = previous_agent_message.customed_kargs.get("Code Snippet", "") | ||||||
|  | #         current_vertex = previous_agent_message.customed_kargs.get("Current_Vertex", "") | ||||||
|  | #         instruction = "A segment of code that contains the function or method to be documented.\n" | ||||||
|  | #         return instruction + "\n" + f"name: {current_vertex}\n{code_snippet}" | ||||||
|  | 
 | ||||||
|  | #     def handle_specific_objective(self, **kwargs) -> str: | ||||||
|  | #         if 'previous_agent_message' not in kwargs: | ||||||
|  | #             return "" | ||||||
|  | #         previous_agent_message: Message = kwargs['previous_agent_message'] | ||||||
|  | #         specific_objective = previous_agent_message.parsed_output.get("Code Path") | ||||||
|  | 
 | ||||||
|  | #         instruction = "Provide the code path of the function or method you wish to document.\n" | ||||||
|  | #         s = instruction + f"\n{specific_objective}" | ||||||
|  | #         return s | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | # from coagent.tools import CodeRetrievalSingle | ||||||
|  | 
 | ||||||
|  | # # Define a new agent class | ||||||
|  | # class CodeGenDocer(BaseAgent): | ||||||
|  | 
 | ||||||
|  | #     def start_action_step(self, message: Message) -> Message: | ||||||
|  | #         '''act before the agent predicts''' | ||||||
|  | #         # Retrieve the code snippet and vertex info for the query | ||||||
|  | #         action_json = CodeRetrievalSingle.run(message.code_engine_name, message.origin_query,  | ||||||
|  | #                                               llm_config=self.llm_config, embed_config=self.embed_config, local_graph_path=message.local_graph_path, use_nh=message.use_nh,search_type="tag") | ||||||
|  | #         current_vertex = action_json['vertex'] | ||||||
|  | #         message.customed_kargs["Code Snippet"] = action_json["code"] | ||||||
|  | #         message.customed_kargs['Current_Vertex'] = current_vertex | ||||||
|  | #         return message | ||||||
|  | 
 | ||||||
|  | # # add agent or prompt_manager class | ||||||
|  | # agent_module = importlib.import_module("coagent.connector.agents") | ||||||
|  | # prompt_manager_module = importlib.import_module("coagent.connector.prompt_manager") | ||||||
|  | 
 | ||||||
|  | # setattr(agent_module, 'CodeGenDocer', CodeGenDocer) | ||||||
|  | # setattr(prompt_manager_module, 'CodeGenDocPM', CodeGenDocPM) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | # AGETN_CONFIGS.update({ | ||||||
|  | #     "classGenDoc": { | ||||||
|  | #         "role": { | ||||||
|  | #             "role_prompt": classGenDoc_PROMPT, | ||||||
|  | #             "role_type": "assistant", | ||||||
|  | #             "role_name": "classGenDoc", | ||||||
|  | #             "role_desc": "", | ||||||
|  | #             "agent_type": "CodeGenDocer" | ||||||
|  | #         }, | ||||||
|  | #         "prompt_config": CODE_GENERATE_DOC_PROMPT_CONFIGS, | ||||||
|  | #         "prompt_manager_type": "CodeGenDocPM", | ||||||
|  | #         "chat_turn": 1, | ||||||
|  | #         "focus_agents": [], | ||||||
|  | #         "focus_message_keys": [], | ||||||
|  | #     }, | ||||||
|  | #     "funcGenDoc": { | ||||||
|  | #         "role": { | ||||||
|  | #             "role_prompt": funcGenDoc_PROMPT, | ||||||
|  | #             "role_type": "assistant", | ||||||
|  | #             "role_name": "funcGenDoc", | ||||||
|  | #             "role_desc": "", | ||||||
|  | #             "agent_type": "CodeGenDocer" | ||||||
|  | #         }, | ||||||
|  | #         "prompt_config": CODE_GENERATE_DOC_PROMPT_CONFIGS, | ||||||
|  | #         "prompt_manager_type": "CodeGenDocPM", | ||||||
|  | #         "chat_turn": 1, | ||||||
|  | #         "focus_agents": [], | ||||||
|  | #         "focus_message_keys": [], | ||||||
|  | #     }, | ||||||
|  | #     "codeGenDocsGrouper": { | ||||||
|  | #         "role": { | ||||||
|  | #             "role_prompt": codeGenDocGroup_PROMPT, | ||||||
|  | #             "role_type": "assistant", | ||||||
|  | #             "role_name": "codeGenDocsGrouper", | ||||||
|  | #             "role_desc": "", | ||||||
|  | #             "agent_type": "SelectorAgent" | ||||||
|  | #         }, | ||||||
|  | #         "prompt_config": CODE_GENERATE_GROUP_PROMPT_CONFIGS, | ||||||
|  | #         "group_agents": ["classGenDoc", "funcGenDoc"], | ||||||
|  | #         "chat_turn": 1, | ||||||
|  | #     }, | ||||||
|  | # }) | ||||||
|  | # # update new chain configs | ||||||
|  | # CHAIN_CONFIGS.update({ | ||||||
|  | #     "codeGenDocsGroupChain": { | ||||||
|  | #         "chain_name": "codeGenDocsGroupChain", | ||||||
|  | #         "chain_type": "BaseChain", | ||||||
|  | #         "agents": ["codeGenDocsGrouper"], | ||||||
|  | #         "chat_turn": 1, | ||||||
|  | #         "do_checker": False, | ||||||
|  | #         "chain_prompt": "" | ||||||
|  | #     } | ||||||
|  | # }) | ||||||
|  | 
 | ||||||
|  | # # update phase configs | ||||||
|  | # PHASE_CONFIGS.update({ | ||||||
|  | #     "codeGenDocsGroup": { | ||||||
|  | #         "phase_name": "codeGenDocsGroup", | ||||||
|  | #         "phase_type": "BasePhase", | ||||||
|  | #         "chains": ["codeGenDocsGroupChain"], | ||||||
|  | #         "do_summary": False, | ||||||
|  | #         "do_search": False, | ||||||
|  | #         "do_doc_retrieval": False, | ||||||
|  | #         "do_code_retrieval": False, | ||||||
|  | #         "do_tool_retrieval": False, | ||||||
|  | #     }, | ||||||
|  | # }) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | # role_configs = load_role_configs(AGETN_CONFIGS) | ||||||
|  | # chain_configs = load_chain_configs(CHAIN_CONFIGS) | ||||||
|  | # phase_configs = load_phase_configs(PHASE_CONFIGS) | ||||||
|  | 
 | ||||||
|  | # # log-level: print prompt and llm predict | ||||||
|  | # os.environ["log_verbose"] = "1" | ||||||
|  | 
 | ||||||
|  | # phase_name = "codeGenDocsGroup" | ||||||
|  | # llm_config = LLMConfig( | ||||||
|  | #     model_name="gpt-4", api_key=os.environ["OPENAI_API_KEY"],  | ||||||
|  | #     api_base_url=os.environ["API_BASE_URL"], temperature=0.3 | ||||||
|  | # ) | ||||||
|  | # embed_config = EmbedConfig( | ||||||
|  | #     embed_engine="model", embed_model="text2vec-base-chinese",  | ||||||
|  | #     embed_model_path=os.path.join(src_dir, "embedding_models/text2vec-base-chinese") | ||||||
|  | #     ) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | # # initialize codebase | ||||||
|  | # # delete codebase | ||||||
|  | # codebase_name = 'client_local' | ||||||
|  | # code_path = "D://chromeDownloads/devopschat-bot/client_v2/client" | ||||||
|  | # use_nh = False | ||||||
|  | # cbh = CodeBaseHandler(codebase_name, code_path, crawl_type='dir', use_nh=use_nh, local_graph_path=CB_ROOT_PATH, | ||||||
|  | #                       llm_config=llm_config, embed_config=embed_config) | ||||||
|  | # cbh.delete_codebase(codebase_name=codebase_name) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | # # load codebase | ||||||
|  | # codebase_name = 'client_local' | ||||||
|  | # code_path = "D://chromeDownloads/devopschat-bot/client_v2/client" | ||||||
|  | # use_nh = False | ||||||
|  | # do_interpret = True | ||||||
|  | # cbh = CodeBaseHandler(codebase_name, code_path, crawl_type='dir', use_nh=use_nh, local_graph_path=CB_ROOT_PATH, | ||||||
|  | #                       llm_config=llm_config, embed_config=embed_config) | ||||||
|  | # cbh.import_code(do_interpret=do_interpret) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | # phase = BasePhase( | ||||||
|  | #     phase_name, sandbox_server=SANDBOX_SERVER, jupyter_work_path=JUPYTER_WORK_PATH, | ||||||
|  | #     embed_config=embed_config, llm_config=llm_config, kb_root_path=KB_ROOT_PATH, | ||||||
|  | # ) | ||||||
|  | 
 | ||||||
|  | # for vertex_type in ["class", "method"]: | ||||||
|  | #     vertexes = cbh.search_vertices(vertex_type=vertex_type) | ||||||
|  | #     logger.info(f"vertexes={vertexes}") | ||||||
|  | 
 | ||||||
|  | #     # round-1 | ||||||
|  | #     docs = [] | ||||||
|  | #     for vertex in vertexes: | ||||||
|  | #         vertex = vertex.split("-")[0] # the part after "-" is the method's parameters | ||||||
|  | #         query_content = f"为{vertex_type}节点 {vertex}生成文档" | ||||||
|  | #         query = Message( | ||||||
|  | #             role_name="human", role_type="user",  | ||||||
|  | #             role_content=query_content, input_query=query_content, origin_query=query_content, | ||||||
|  | #             code_engine_name="client_local", score_threshold=1.0, top_k=3, cb_search_type="tag", use_nh=use_nh, | ||||||
|  | #             local_graph_path=CB_ROOT_PATH, | ||||||
|  | #             ) | ||||||
|  | #         output_message, output_memory = phase.step(query, reinit_memory=True) | ||||||
|  | #         # print(output_memory.to_str_messages(return_all=True, content_key="parsed_output_list")) | ||||||
|  | #         docs.append(output_memory.get_spec_parserd_output()) | ||||||
|  | 
 | ||||||
|  | #         import json | ||||||
|  | #         os.makedirs("/home/user/code_base/docs", exist_ok=True) | ||||||
|  | #         with open(f"/home/user/code_base/docs/raw_{vertex_type}.json", "w") as f: | ||||||
|  | #             json.dump(docs, f) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | # # Convert the generated doc info into markdown text below | ||||||
|  | # from coagent.utils.code2doc_util import * | ||||||
|  | 
 | ||||||
|  | # import json | ||||||
|  | # with open(f"/home/user/code_base/docs/raw_method.json", "r") as f: | ||||||
|  | #     method_raw_data = json.load(f) | ||||||
|  | 
 | ||||||
|  | # with open(f"/home/user/code_base/docs/raw_class.json", "r") as f: | ||||||
|  | #     class_raw_data = json.load(f) | ||||||
|  |      | ||||||
|  | 
 | ||||||
|  | # method_data = method_info_decode(method_raw_data) | ||||||
|  | # class_data = class_info_decode(class_raw_data) | ||||||
|  | # method_mds = encode2md(method_data, method_text_md) | ||||||
|  | # class_mds = encode2md(class_data, class_text_md) | ||||||
|  | 
 | ||||||
|  | # docs_dict = {} | ||||||
|  | # for k,v in class_mds.items(): | ||||||
|  | #     method_textmds = method_mds.get(k, []) | ||||||
|  | #     for vv in v: | ||||||
|  | #         # in theory there is only one | ||||||
|  | #         text_md = vv | ||||||
|  | 
 | ||||||
|  | #     for method_textmd in method_textmds: | ||||||
|  | #         text_md += "\n<br>" + method_textmd | ||||||
|  | 
 | ||||||
|  | #     docs_dict.setdefault(k, []).append(text_md) | ||||||
|  |      | ||||||
|  | #     with open(f"/home/user/code_base/docs/{k}.md", "w") as f: | ||||||
|  | #         f.write(text_md) | ||||||
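A caveat in the markdown write-out loops above: k is a code path such as client/StringUtil.java, so using it directly in the file name can create unexpected directories or fail outright. A sketch of a safer write, reusing the script's docs_dict and CB_ROOT_PATH (the sanitization scheme is an assumption):

```python
import os

for k, texts in docs_dict.items():
    # Flatten path separators so each doc lands directly under docs/ (assumed scheme).
    safe_name = k.replace(os.sep, "_").replace("/", "_")
    with open(f"{CB_ROOT_PATH}/docs/{safe_name}.md", "w") as f:
        f.write("\n".join(texts))
```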
444  examples/agent_examples/codeGenTestCases_example.py  Normal file
							| @ -0,0 +1,444 @@ | |||||||
|  | import os, sys, json | ||||||
|  | from loguru import logger | ||||||
|  | src_dir = os.path.join( | ||||||
|  |     os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) | ||||||
|  | ) | ||||||
|  | sys.path.append(src_dir) | ||||||
|  | 
 | ||||||
|  | from configs.model_config import KB_ROOT_PATH, JUPYTER_WORK_PATH, CB_ROOT_PATH | ||||||
|  | from configs.server_config import SANDBOX_SERVER | ||||||
|  | from coagent.llm_models.llm_config import EmbedConfig, LLMConfig | ||||||
|  | 
 | ||||||
|  | from coagent.connector.phase import BasePhase | ||||||
|  | from coagent.connector.agents import BaseAgent | ||||||
|  | from coagent.connector.schema import Message | ||||||
|  | from coagent.codechat.codebase_handler.codebase_handler import CodeBaseHandler | ||||||
|  | 
 | ||||||
|  | import importlib | ||||||
|  | from loguru import logger | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | from coagent.tools import CodeRetrievalSingle, RelatedVerticesRetrival, Vertex2Code | ||||||
|  | 
 | ||||||
|  | # Define a new agent class | ||||||
|  | class CodeRetrieval(BaseAgent): | ||||||
|  | 
 | ||||||
|  |     def start_action_step(self, message: Message) -> Message: | ||||||
|  |         '''act before the agent predicts''' | ||||||
|  |         # Retrieve the code snippet and vertex info for the query | ||||||
|  |         action_json = CodeRetrievalSingle.run(message.code_engine_name, message.origin_query, llm_config=self.llm_config, embed_config=self.embed_config, search_type="tag",  | ||||||
|  |                                               local_graph_path=message.local_graph_path, use_nh=message.use_nh) | ||||||
|  |         current_vertex = action_json['vertex'] | ||||||
|  |         message.customed_kargs["Code Snippet"] = action_json["code"] | ||||||
|  |         message.customed_kargs['Current_Vertex'] = current_vertex | ||||||
|  | 
 | ||||||
|  |         # Fetch neighboring vertices | ||||||
|  |         action_json = RelatedVerticesRetrival.run(message.code_engine_name, message.customed_kargs['Current_Vertex']) | ||||||
|  |         # Fetch the code for all neighboring vertices | ||||||
|  |         relative_vertex = [] | ||||||
|  |         retrieval_Codes = [] | ||||||
|  |         for vertex in action_json["vertices"]: | ||||||
|  |             # Code is stored at file level, so skip vertices from the same file | ||||||
|  |             # logger.debug(f"{current_vertex}, {vertex}") | ||||||
|  |             current_vertex_name = current_vertex.replace("#", "").replace(".java", "" ) if current_vertex.endswith(".java") else current_vertex | ||||||
|  |             if current_vertex_name.split("#")[0] == vertex.split("#")[0]: continue | ||||||
|  | 
 | ||||||
|  |             action_json = Vertex2Code.run(message.code_engine_name, vertex) | ||||||
|  |             if action_json["code"]: | ||||||
|  |                 retrieval_Codes.append(action_json["code"]) | ||||||
|  |                 relative_vertex.append(vertex) | ||||||
|  |         #  | ||||||
|  |         message.customed_kargs["Retrieval_Codes"] = retrieval_Codes | ||||||
|  |         message.customed_kargs["Relative_vertex"] = relative_vertex | ||||||
|  |         return message | ||||||
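|  |     # Note: the customed_kargs keys set above ("Code Snippet", "Current_Vertex", | ||||||
|  |     # "Retrieval_Codes", "Relative_vertex") are presumably what the prompt | ||||||
|  |     # manager reads when building the agent prompt; the commented-out | ||||||
|  |     # CodeRetrievalPM near the end of this file reads exactly these keys. | ||||||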
|  | 
 | ||||||
|  | 
 | ||||||
|  | # add agent or prompt_manager class | ||||||
|  | agent_module = importlib.import_module("coagent.connector.agents") | ||||||
|  | setattr(agent_module, 'CodeRetrieval', CodeRetrieval) | ||||||
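|  | # Registering the class on coagent.connector.agents lets role configs refer to | ||||||
|  | # it by name; the commented-out AGETN_CONFIGS entries below use | ||||||
|  | # agent_type="CodeRetrieval" in exactly this way (a hedged reading inferred | ||||||
|  | # from this file, not from coagent's documented API). | ||||||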
|  | 
 | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | llm_config = LLMConfig( | ||||||
|  |     model_name="gpt-4", api_key=os.environ["OPENAI_API_KEY"],  | ||||||
|  |     api_base_url=os.environ["API_BASE_URL"], temperature=0.3 | ||||||
|  | ) | ||||||
|  | embed_config = EmbedConfig( | ||||||
|  |     embed_engine="model", embed_model="text2vec-base-chinese",  | ||||||
|  |     embed_model_path=os.path.join(src_dir, "embedding_models/text2vec-base-chinese") | ||||||
|  |     ) | ||||||
|  | 
 | ||||||
|  | ## initialize codebase | ||||||
|  | # delete codebase | ||||||
|  | codebase_name = 'client_local' | ||||||
|  | code_path = "D://chromeDownloads/devopschat-bot/client_v2/client" | ||||||
|  | use_nh = False | ||||||
|  | cbh = CodeBaseHandler(codebase_name, code_path, crawl_type='dir', use_nh=use_nh, local_graph_path=CB_ROOT_PATH, | ||||||
|  |                       llm_config=llm_config, embed_config=embed_config) | ||||||
|  | cbh.delete_codebase(codebase_name=codebase_name) | ||||||
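|  | # Delete-then-import gives a clean rebuild of the codebase; delete_codebase is | ||||||
|  | # assumed here to be a no-op when the codebase does not exist yet. | ||||||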
|  | 
 | ||||||
|  | 
 | ||||||
|  | # load codebase | ||||||
|  | codebase_name = 'client_local' | ||||||
|  | code_path = "D://chromeDownloads/devopschat-bot/client_v2/client" | ||||||
|  | use_nh = True | ||||||
|  | do_interpret = True | ||||||
|  | cbh = CodeBaseHandler(codebase_name, code_path, crawl_type='dir', use_nh=use_nh, local_graph_path=CB_ROOT_PATH, | ||||||
|  |                       llm_config=llm_config, embed_config=embed_config) | ||||||
|  | cbh.import_code(do_interpret=do_interpret) | ||||||
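|  | # do_interpret=True presumably has the LLM produce natural-language summaries | ||||||
|  | # of the imported code (at extra token cost); False would presumably skip that. | ||||||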
|  | 
 | ||||||
|  | 
 | ||||||
|  | cbh = CodeBaseHandler(codebase_name, code_path, crawl_type='dir', use_nh=use_nh, local_graph_path=CB_ROOT_PATH, | ||||||
|  |                       llm_config=llm_config, embed_config=embed_config) | ||||||
|  | vertexes = cbh.search_vertices(vertex_type="class") | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | # log level: whether to print prompts and LLM predictions | ||||||
|  | os.environ["log_verbose"] = "0" | ||||||
|  | 
 | ||||||
|  | phase_name = "code2Tests" | ||||||
|  | phase = BasePhase( | ||||||
|  |     phase_name, sandbox_server=SANDBOX_SERVER, jupyter_work_path=JUPYTER_WORK_PATH, | ||||||
|  |     embed_config=embed_config, llm_config=llm_config, kb_root_path=KB_ROOT_PATH, | ||||||
|  | ) | ||||||
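|  | # "code2Tests" must already exist in coagent's PHASE_CONFIGS; the commented-out | ||||||
|  | # "codeGenerateTests" config at the bottom of this file hints at the likely | ||||||
|  | # shape: a judger agent chained with a test-generator agent. | ||||||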
|  | 
 | ||||||
|  | # round-1 | ||||||
|  | logger.debug(vertexes) | ||||||
|  | test_cases = [] | ||||||
|  | for vertex in vertexes: | ||||||
|  |     query_content = f"为{vertex}生成可执行的测例 "  # "generate executable test cases for {vertex}" | ||||||
|  |     query = Message( | ||||||
|  |         role_name="human", role_type="user",  | ||||||
|  |         role_content=query_content, input_query=query_content, origin_query=query_content, | ||||||
|  |         code_engine_name=codebase_name, score_threshold=1.0, top_k=3, cb_search_type="tag", | ||||||
|  |         use_nh=use_nh, local_graph_path=CB_ROOT_PATH, | ||||||
|  |         ) | ||||||
|  |     output_message, output_memory = phase.step(query, reinit_memory=True) | ||||||
|  |     # print(output_memory.to_str_messages(return_all=True, content_key="parsed_output_list")) | ||||||
|  |     values = output_memory.get_spec_parserd_output() | ||||||
|  |     print(values) | ||||||
|  |     test_code = {k:v for i in values for k,v in i.items() if k in ["SaveFileName", "Test Code"]} | ||||||
|  |     test_cases.append(test_code) | ||||||
|  | 
 | ||||||
|  |     os.makedirs(f"{CB_ROOT_PATH}/tests", exist_ok=True) | ||||||
|  | 
 | ||||||
|  |     with open(f"{CB_ROOT_PATH}/tests/{test_code['SaveFileName']}", "w") as f: | ||||||
|  |         f.write(test_code["Test Code"]) | ||||||
|  |     break | ||||||
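|  | # A minimal sketch of the shapes assumed above (hypothetical values, not taken | ||||||
|  | # from a real run): get_spec_parserd_output() is expected to return a list of | ||||||
|  | # per-step dicts, from which the comprehension keeps only the two keys named in | ||||||
|  | # the prompt's "Response Output Format" section: | ||||||
|  | #   values = [{"Action Status": "continued", "REASON": "..."}, | ||||||
|  | #             {"SaveFileName": "pkg/FooTest.java", "Test Code": "package pkg; ..."}] | ||||||
|  | #   test_code == {"SaveFileName": "pkg/FooTest.java", "Test Code": "package pkg; ..."} | ||||||
|  | # The `break` above limits this demo to the first class vertex. | ||||||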
|  | 
 | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | # import os, sys, json | ||||||
|  | # from loguru import logger | ||||||
|  | # src_dir = os.path.join( | ||||||
|  | #     os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) | ||||||
|  | # ) | ||||||
|  | # sys.path.append(src_dir) | ||||||
|  | 
 | ||||||
|  | # from configs.model_config import KB_ROOT_PATH, JUPYTER_WORK_PATH, CB_ROOT_PATH | ||||||
|  | # from configs.server_config import SANDBOX_SERVER | ||||||
|  | # from coagent.tools import toLangchainTools, TOOL_DICT, TOOL_SETS | ||||||
|  | # from coagent.llm_models.llm_config import EmbedConfig, LLMConfig | ||||||
|  | 
 | ||||||
|  | # from coagent.connector.phase import BasePhase | ||||||
|  | # from coagent.connector.agents import BaseAgent | ||||||
|  | # from coagent.connector.chains import BaseChain | ||||||
|  | # from coagent.connector.schema import ( | ||||||
|  | #     Message, Memory, load_role_configs, load_phase_configs, load_chain_configs, ActionStatus | ||||||
|  | #     ) | ||||||
|  | # from coagent.connector.memory_manager import BaseMemoryManager | ||||||
|  | # from coagent.connector.configs import AGETN_CONFIGS, CHAIN_CONFIGS, PHASE_CONFIGS, BASE_PROMPT_CONFIGS | ||||||
|  | # from coagent.connector.prompt_manager.prompt_manager import PromptManager | ||||||
|  | # from coagent.codechat.codebase_handler.codebase_handler import CodeBaseHandler | ||||||
|  | 
 | ||||||
|  | # import importlib | ||||||
|  | # from loguru import logger | ||||||
|  | 
 | ||||||
|  | # # Below is a code snippet, together with code snippets of the classes and methods it depends on; you must decide whether to generate test cases for the given snippet. | ||||||
|  | # # In theory all code deserves test cases, but limited human effort makes full coverage impossible. | ||||||
|  | # # Trim the scope by weighing the following factors: | ||||||
|  | # # Functionality: if it implements a concrete function or piece of logic, test cases are usually needed to verify its correctness. | ||||||
|  | # # Complexity: if the code is relatively complex, especially when it contains multiple conditionals, loops, or exception handling, it is more likely to hide bugs, so test cases should be written. When the code involves complex algorithms or logic, tests help ensure correctness and guard against errors in future refactoring. | ||||||
|  | # # Criticality: if it is part of the critical path or affects core functionality, it needs to be tested. Core business logic and key system components deserve comprehensive tests for correctness and stability. | ||||||
|  | # # Dependencies: if the code has external dependencies, integration tests may be needed, or those dependencies may have to be mocked in unit tests. | ||||||
|  | # # User input: if the code handles user input, especially uncontrolled external input, test cases that check input validation and handling are important. | ||||||
|  | # # Frequent changes: for code that is often updated or modified, matching test cases ensure that changes do not break existing functionality. | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | # # Public or reused code: if the code will be published or used in other projects, writing test cases improves its credibility and usability. | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | # # update new agent configs | ||||||
|  | # judgeGenerateTests_PROMPT = """#### Agent Profile | ||||||
|  | # When determining the necessity of writing test cases for a given code snippet,  | ||||||
|  | # it's essential to evaluate its interactions with dependent classes and methods (retrieved code snippets),  | ||||||
|  | # in addition to considering these critical factors: | ||||||
|  | # 1. Functionality: If it implements a concrete function or logic, test cases are typically necessary to verify its correctness. | ||||||
|  | # 2. Complexity: If the code is complex, especially if it contains multiple conditional statements, loops, exception handling, etc.,  | ||||||
|  | # it's more likely to harbor bugs, and thus test cases should be written.  | ||||||
|  | # If the code involves complex algorithms or logic, then writing test cases can help ensure the accuracy of the logic and prevent errors during future refactoring. | ||||||
|  | # 3. Criticality: If it's part of the critical path or affects core functionalities, then it needs to be tested.  | ||||||
|  | # Comprehensive test cases should be written for core business logic or key components of the system to ensure the correctness and stability of the functionality. | ||||||
|  | # 4. Dependencies: If the code has external dependencies, integration testing may be necessary, or mocking these dependencies during unit testing might be required. | ||||||
|  | # 5. User Input: If the code handles user input, especially from unregulated external sources, creating test cases to check input validation and handling is important. | ||||||
|  | # 6. Frequent Changes: For code that requires regular updates or modifications, having the appropriate test cases ensures that changes do not break existing functionalities. | ||||||
|  | 
 | ||||||
|  | # #### Input Format | ||||||
|  | 
 | ||||||
|  | # **Code Snippet:** the initial Code or objective that the user wanted to achieve | ||||||
|  | 
 | ||||||
|  | # **Retrieval Code Snippets:** These are the associated code segments that the main Code Snippet depends on.  | ||||||
|  | # Examine these snippets to understand how they interact with the main snippet and to determine how they might affect the overall functionality. | ||||||
|  | 
 | ||||||
|  | # #### Response Output Format | ||||||
|  | # **Action Status:** Set to 'finished' or 'continued'.  | ||||||
|  | # If set to 'finished', the code snippet does not warrant the generation of a test case. | ||||||
|  | # If set to 'continued', the code snippet necessitates the creation of a test case. | ||||||
|  | 
 | ||||||
|  | # **REASON:** Justify the selection of 'finished' or 'continued', contemplating the decision through a step-by-step rationale. | ||||||
|  | # """ | ||||||
|  | 
 | ||||||
|  | # generateTests_PROMPT = """#### Agent Profile | ||||||
|  | # As an agent specializing in software quality assurance,  | ||||||
|  | # your mission is to craft comprehensive test cases that bolster the functionality, reliability, and robustness of a specified Code Snippet.  | ||||||
|  | # This task is to be carried out with a keen understanding of the snippet's interactions with its dependent classes and methods—collectively referred to as Retrieval Code Snippets.  | ||||||
|  | # Analyze the details given below to grasp the code's intended purpose, its inherent complexity, and the context within which it operates.  | ||||||
|  | # Your constructed test cases must thoroughly examine the various factors influencing the code's quality and performance. | ||||||
|  | 
 | ||||||
|  | # ATTENTION: format your response by carefully following the "Response Output Format" section. | ||||||
|  | 
 | ||||||
|  | # Each test case should include: | ||||||
|  | # 1. A clear description of the test purpose. | ||||||
|  | # 2. The input values or conditions for the test. | ||||||
|  | # 3. The expected outcome or assertion for the test. | ||||||
|  | # 4. Appropriate tags (e.g., 'functional', 'integration', 'regression') that classify the type of test case. | ||||||
|  | # 5. The test code should include package and import statements. | ||||||
|  | 
 | ||||||
|  | # #### Input Format | ||||||
|  | 
 | ||||||
|  | # **Code Snippet:** the initial Code or objective that the user wanted to achieve | ||||||
|  | 
 | ||||||
|  | # **Retrieval Code Snippets:** These are the interrelated pieces of code sourced from the codebase, which support or influence the primary Code Snippet. | ||||||
|  | 
 | ||||||
|  | # #### Response Output Format | ||||||
|  | # **SaveFileName:** construct a local file name based on Question and Context, such as | ||||||
|  | 
 | ||||||
|  | # ```java | ||||||
|  | # package/class.java | ||||||
|  | # ``` | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | # **Test Code:** generate the test code for the current Code Snippet. | ||||||
|  | # ```java | ||||||
|  | # ... | ||||||
|  | # ``` | ||||||
|  | 
 | ||||||
|  | # """ | ||||||
|  | 
 | ||||||
|  | # CODE_GENERATE_TESTS_PROMPT_CONFIGS = [ | ||||||
|  | #     {"field_name": 'agent_profile', "function_name": 'handle_agent_profile', "is_context": False}, | ||||||
|  | #     # {"field_name": 'tool_information',"function_name": 'handle_tool_data', "is_context": False}, | ||||||
|  | #     {"field_name": 'context_placeholder', "function_name": '', "is_context": True}, | ||||||
|  | #     # {"field_name": 'reference_documents', "function_name": 'handle_doc_info'}, | ||||||
|  | #     {"field_name": 'session_records', "function_name": 'handle_session_records'}, | ||||||
|  | #     {"field_name": 'code_snippet', "function_name": 'handle_code_snippet'}, | ||||||
|  | #     {"field_name": 'retrieval_codes', "function_name": 'handle_retrieval_codes', "description": ""}, | ||||||
|  | #     {"field_name": 'output_format', "function_name": 'handle_output_format', 'title': 'Response Output Format', "is_context": False}, | ||||||
|  | #     {"field_name": 'begin!!!', "function_name": 'handle_response', "is_context": False, "omit_if_empty": False} | ||||||
|  | # ] | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | # class CodeRetrievalPM(PromptManager): | ||||||
|  | #     def handle_code_snippet(self, **kwargs) -> str: | ||||||
|  | #         if 'previous_agent_message' not in kwargs: | ||||||
|  | #             return "" | ||||||
|  | #         previous_agent_message: Message = kwargs['previous_agent_message'] | ||||||
|  | #         code_snippet = previous_agent_message.customed_kargs.get("Code Snippet", "") | ||||||
|  | #         current_vertex = previous_agent_message.customed_kargs.get("Current_Vertex", "") | ||||||
|  | #         instruction = "the initial Code or objective that the user wanted to achieve" | ||||||
|  | #         return instruction + "\n" + f"name: {current_vertex}\n{code_snippet}" | ||||||
|  | 
 | ||||||
|  | #     def handle_retrieval_codes(self, **kwargs) -> str: | ||||||
|  | #         if 'previous_agent_message' not in kwargs: | ||||||
|  | #             return "" | ||||||
|  | #         previous_agent_message: Message = kwargs['previous_agent_message'] | ||||||
|  | #         Retrieval_Codes = previous_agent_message.customed_kargs["Retrieval_Codes"] | ||||||
|  | #         Relative_vertex = previous_agent_message.customed_kargs["Relative_vertex"] | ||||||
|  | #         instruction = "the initial Code or objective that the user wanted to achieve" | ||||||
|  | #         s = instruction + "\n" + "\n".join([f"name: {vertext}\n{code}" for vertext, code in zip(Relative_vertex, Retrieval_Codes)]) | ||||||
|  | #         return s | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | # AGETN_CONFIGS.update({ | ||||||
|  | #     "CodeJudger": { | ||||||
|  | #         "role": { | ||||||
|  | #             "role_prompt": judgeGenerateTests_PROMPT, | ||||||
|  | #             "role_type": "assistant", | ||||||
|  | #             "role_name": "CodeJudger", | ||||||
|  | #             "role_desc": "", | ||||||
|  | #             "agent_type": "CodeRetrieval" | ||||||
|  | #             # "agent_type": "BaseAgent" | ||||||
|  | #         }, | ||||||
|  | #         "prompt_config": CODE_GENERATE_TESTS_PROMPT_CONFIGS, | ||||||
|  | #         "prompt_manager_type": "CodeRetrievalPM", | ||||||
|  | #         "chat_turn": 1, | ||||||
|  | #         "focus_agents": [], | ||||||
|  | #         "focus_message_keys": [], | ||||||
|  | #     }, | ||||||
|  | #     "generateTests": { | ||||||
|  | #         "role": { | ||||||
|  | #             "role_prompt": generateTests_PROMPT, | ||||||
|  | #             "role_type": "assistant", | ||||||
|  | #             "role_name": "generateTests", | ||||||
|  | #             "role_desc": "", | ||||||
|  | #             "agent_type": "CodeRetrieval" | ||||||
|  | #             # "agent_type": "BaseAgent" | ||||||
|  | #         }, | ||||||
|  | #         "prompt_config": CODE_GENERATE_TESTS_PROMPT_CONFIGS, | ||||||
|  | #         "prompt_manager_type": "CodeRetrievalPM", | ||||||
|  | #         "chat_turn": 1, | ||||||
|  | #         "focus_agents": [], | ||||||
|  | #         "focus_message_keys": [], | ||||||
|  | #     }, | ||||||
|  | # }) | ||||||
|  | # # update new chain configs | ||||||
|  | # CHAIN_CONFIGS.update({ | ||||||
|  | #     "codeRetrievalChain": { | ||||||
|  | #         "chain_name": "codeRetrievalChain", | ||||||
|  | #         "chain_type": "BaseChain", | ||||||
|  | #         "agents": ["CodeJudger", "generateTests"], | ||||||
|  | #         "chat_turn": 1, | ||||||
|  | #         "do_checker": False, | ||||||
|  | #         "chain_prompt": "" | ||||||
|  | #     } | ||||||
|  | # }) | ||||||
|  | 
 | ||||||
|  | # # update phase configs | ||||||
|  | # PHASE_CONFIGS.update({ | ||||||
|  | #     "codeGenerateTests": { | ||||||
|  | #         "phase_name": "codeGenerateTests", | ||||||
|  | #         "phase_type": "BasePhase", | ||||||
|  | #         "chains": ["codeRetrievalChain"], | ||||||
|  | #         "do_summary": False, | ||||||
|  | #         "do_search": False, | ||||||
|  | #         "do_doc_retrieval": False, | ||||||
|  | #         "do_code_retrieval": False, | ||||||
|  | #         "do_tool_retrieval": False, | ||||||
|  | #     }, | ||||||
|  | # }) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | # role_configs = load_role_configs(AGETN_CONFIGS) | ||||||
|  | # chain_configs = load_chain_configs(CHAIN_CONFIGS) | ||||||
|  | # phase_configs = load_phase_configs(PHASE_CONFIGS) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | # from coagent.tools import CodeRetrievalSingle, RelatedVerticesRetrival, Vertex2Code | ||||||
|  | 
 | ||||||
|  | # # Define a new agent class | ||||||
|  | # class CodeRetrieval(BaseAgent): | ||||||
|  | 
 | ||||||
|  | #     def start_action_step(self, message: Message) -> Message: | ||||||
|  | #         '''Runs before the agent predicts.''' | ||||||
|  | #         # Retrieve the code snippet and vertex info for the user query | ||||||
|  | #         action_json = CodeRetrievalSingle.run(message.code_engine_name, message.origin_query, llm_config=self.llm_config, embed_config=self.embed_config, search_type="tag",  | ||||||
|  | #                                               local_graph_path=message.local_graph_path, use_nh=message.use_nh) | ||||||
|  | #         current_vertex = action_json['vertex'] | ||||||
|  | #         message.customed_kargs["Code Snippet"] = action_json["code"] | ||||||
|  | #         message.customed_kargs['Current_Vertex'] = current_vertex | ||||||
|  | 
 | ||||||
|  | #         # Fetch the neighboring vertices | ||||||
|  | #         action_json = RelatedVerticesRetrival.run(message.code_engine_name, message.customed_kargs['Current_Vertex']) | ||||||
|  | #         # Fetch the code of every neighboring vertex | ||||||
|  | #         relative_vertex = [] | ||||||
|  | #         retrieval_Codes = [] | ||||||
|  | #         for vertex in action_json["vertices"]: | ||||||
|  | #             # Code is indexed at file level, so skip vertices from the same file | ||||||
|  | #             # logger.debug(f"{current_vertex}, {vertex}") | ||||||
|  | #             current_vertex_name = current_vertex.replace("#", "").replace(".java", "" ) if current_vertex.endswith(".java") else current_vertex | ||||||
|  | #             if current_vertex_name.split("#")[0] == vertex.split("#")[0]: continue | ||||||
|  | 
 | ||||||
|  | #             action_json = Vertex2Code.run(message.code_engine_name, vertex) | ||||||
|  | #             if action_json["code"]: | ||||||
|  | #                 retrieval_Codes.append(action_json["code"]) | ||||||
|  | #                 relative_vertex.append(vertex) | ||||||
|  | #         #  | ||||||
|  | #         message.customed_kargs["Retrieval_Codes"] = retrieval_Codes | ||||||
|  | #         message.customed_kargs["Relative_vertex"] = relative_vertex | ||||||
|  | #         return message | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | # # add agent or prompt_manager class | ||||||
|  | # agent_module = importlib.import_module("coagent.connector.agents") | ||||||
|  | # prompt_manager_module = importlib.import_module("coagent.connector.prompt_manager") | ||||||
|  | 
 | ||||||
|  | # setattr(agent_module, 'CodeRetrieval', CodeRetrieval) | ||||||
|  | # setattr(prompt_manager_module, 'CodeRetrievalPM', CodeRetrievalPM) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | # # log level: whether to print prompts and LLM predictions | ||||||
|  | # os.environ["log_verbose"] = "0" | ||||||
|  | 
 | ||||||
|  | # phase_name = "codeGenerateTests" | ||||||
|  | # llm_config = LLMConfig( | ||||||
|  | #     model_name="gpt-4", api_key=os.environ["OPENAI_API_KEY"],  | ||||||
|  | #     api_base_url=os.environ["API_BASE_URL"], temperature=0.3 | ||||||
|  | # ) | ||||||
|  | # embed_config = EmbedConfig( | ||||||
|  | #     embed_engine="model", embed_model="text2vec-base-chinese",  | ||||||
|  | #     embed_model_path=os.path.join(src_dir, "embedding_models/text2vec-base-chinese") | ||||||
|  | #     ) | ||||||
|  | 
 | ||||||
|  | # ## initialize codebase | ||||||
|  | # # delete codebase | ||||||
|  | # codebase_name = 'client_local' | ||||||
|  | # code_path = "D://chromeDownloads/devopschat-bot/client_v2/client" | ||||||
|  | # use_nh = False | ||||||
|  | # cbh = CodeBaseHandler(codebase_name, code_path, crawl_type='dir', use_nh=use_nh, local_graph_path=CB_ROOT_PATH, | ||||||
|  | #                       llm_config=llm_config, embed_config=embed_config) | ||||||
|  | # cbh.delete_codebase(codebase_name=codebase_name) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | # # load codebase | ||||||
|  | # codebase_name = 'client_local' | ||||||
|  | # code_path = "D://chromeDownloads/devopschat-bot/client_v2/client" | ||||||
|  | # use_nh = True | ||||||
|  | # do_interpret = True | ||||||
|  | # cbh = CodeBaseHandler(codebase_name, code_path, crawl_type='dir', use_nh=use_nh, local_graph_path=CB_ROOT_PATH, | ||||||
|  | #                       llm_config=llm_config, embed_config=embed_config) | ||||||
|  | # cbh.import_code(do_interpret=do_interpret) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | # cbh = CodeBaseHandler(codebase_name, code_path, crawl_type='dir', use_nh=use_nh, local_graph_path=CB_ROOT_PATH, | ||||||
|  | #                       llm_config=llm_config, embed_config=embed_config) | ||||||
|  | # vertexes = cbh.search_vertices(vertex_type="class") | ||||||
|  | 
 | ||||||
|  | # phase = BasePhase( | ||||||
|  | #     phase_name, sandbox_server=SANDBOX_SERVER, jupyter_work_path=JUPYTER_WORK_PATH, | ||||||
|  | #     embed_config=embed_config, llm_config=llm_config, kb_root_path=KB_ROOT_PATH, | ||||||
|  | # ) | ||||||
|  | 
 | ||||||
|  | # # round-1 | ||||||
|  | # logger.debug(vertexes) | ||||||
|  | # test_cases = [] | ||||||
|  | # for vertex in vertexes: | ||||||
|  | #     query_content = f"为{vertex}生成可执行的测例 "  # "generate executable test cases for {vertex}" | ||||||
|  | #     query = Message( | ||||||
|  | #         role_name="human", role_type="user",  | ||||||
|  | #         role_content=query_content, input_query=query_content, origin_query=query_content, | ||||||
|  | #         code_engine_name=codebase_name, score_threshold=1.0, top_k=3, cb_search_type="tag", | ||||||
|  | #         use_nh=use_nh, local_graph_path=CB_ROOT_PATH, | ||||||
|  | #         ) | ||||||
|  | #     output_message, output_memory = phase.step(query, reinit_memory=True) | ||||||
|  | #     # print(output_memory.to_str_messages(return_all=True, content_key="parsed_output_list")) | ||||||
|  | #     print(output_memory.get_spec_parserd_output()) | ||||||
|  | #     values = output_memory.get_spec_parserd_output() | ||||||
|  | #     test_code = {k:v for i in values for k,v in i.items() if k in ["SaveFileName", "Test Code"]} | ||||||
|  | #     test_cases.append(test_code) | ||||||
|  | 
 | ||||||
|  | #     os.makedirs(f"{CB_ROOT_PATH}/tests", exist_ok=True) | ||||||
|  | 
 | ||||||
|  | #     with open(f"{CB_ROOT_PATH}/tests/{test_code['SaveFileName']}", "w") as f: | ||||||
|  | #         f.write(test_code["Test Code"]) | ||||||
| @ -17,7 +17,7 @@ os.environ["log_verbose"] = "2" | |||||||
| 
 | 
 | ||||||
| phase_name = "codeReactPhase" | phase_name = "codeReactPhase" | ||||||
| llm_config = LLMConfig( | llm_config = LLMConfig( | ||||||
|     model_name="gpt-3.5-turbo", model_device="cpu",api_key=os.environ["OPENAI_API_KEY"],  |     model_name="gpt-3.5-turbo",api_key=os.environ["OPENAI_API_KEY"],  | ||||||
|     api_base_url=os.environ["API_BASE_URL"], temperature=0.3 |     api_base_url=os.environ["API_BASE_URL"], temperature=0.3 | ||||||
|     ) |     ) | ||||||
| embed_config = EmbedConfig( | embed_config = EmbedConfig( | ||||||
|  | |||||||
| @ -18,8 +18,7 @@ from coagent.connector.schema import ( | |||||||
|     ) |     ) | ||||||
| from coagent.connector.memory_manager import BaseMemoryManager | from coagent.connector.memory_manager import BaseMemoryManager | ||||||
| from coagent.connector.configs import AGETN_CONFIGS, CHAIN_CONFIGS, PHASE_CONFIGS, BASE_PROMPT_CONFIGS | from coagent.connector.configs import AGETN_CONFIGS, CHAIN_CONFIGS, PHASE_CONFIGS, BASE_PROMPT_CONFIGS | ||||||
| from coagent.connector.utils import parse_section | from coagent.connector.prompt_manager.prompt_manager import PromptManager | ||||||
| from coagent.connector.prompt_manager import PromptManager |  | ||||||
| import importlib | import importlib | ||||||
| 
 | 
 | ||||||
| from loguru import logger | from loguru import logger | ||||||
| @ -230,7 +229,7 @@ os.environ["log_verbose"] = "2" | |||||||
| 
 | 
 | ||||||
| phase_name = "codeRetrievalPhase" | phase_name = "codeRetrievalPhase" | ||||||
| llm_config = LLMConfig( | llm_config = LLMConfig( | ||||||
|     model_name="gpt-3.5-turbo", model_device="cpu",api_key=os.environ["OPENAI_API_KEY"],  |     model_name="gpt-3.5-turbo", api_key=os.environ["OPENAI_API_KEY"],  | ||||||
|     api_base_url=os.environ["API_BASE_URL"], temperature=0.3 |     api_base_url=os.environ["API_BASE_URL"], temperature=0.3 | ||||||
|     ) |     ) | ||||||
| embed_config = EmbedConfig( | embed_config = EmbedConfig( | ||||||
| @ -246,7 +245,7 @@ query_content = "UtilsTest 这个类中测试了哪些函数,测试的函数代 | |||||||
| query = Message( | query = Message( | ||||||
|     role_name="human", role_type="user",  |     role_name="human", role_type="user",  | ||||||
|     role_content=query_content, input_query=query_content, origin_query=query_content, |     role_content=query_content, input_query=query_content, origin_query=query_content, | ||||||
|     code_engine_name="client", score_threshold=1.0, top_k=3, cb_search_type="tag" |     code_engine_name="client_1", score_threshold=1.0, top_k=3, cb_search_type="tag" | ||||||
|     ) |     ) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -24,7 +24,7 @@ os.environ["log_verbose"] = "2" | |||||||
| 
 | 
 | ||||||
| phase_name = "codeToolReactPhase" | phase_name = "codeToolReactPhase" | ||||||
| llm_config = LLMConfig( | llm_config = LLMConfig( | ||||||
|     model_name="gpt-3.5-turbo-0613", model_device="cpu",api_key=os.environ["OPENAI_API_KEY"],  |     model_name="gpt-3.5-turbo-0613", api_key=os.environ["OPENAI_API_KEY"],  | ||||||
|     api_base_url=os.environ["API_BASE_URL"], temperature=0.7 |     api_base_url=os.environ["API_BASE_URL"], temperature=0.7 | ||||||
|     ) |     ) | ||||||
| embed_config = EmbedConfig( | embed_config = EmbedConfig( | ||||||
|  | |||||||
| @ -17,7 +17,7 @@ from coagent.connector.schema import Message, Memory | |||||||
| 
 | 
 | ||||||
| tools = toLangchainTools([TOOL_DICT[i] for i in TOOL_SETS if i in TOOL_DICT]) | tools = toLangchainTools([TOOL_DICT[i] for i in TOOL_SETS if i in TOOL_DICT]) | ||||||
| llm_config = LLMConfig( | llm_config = LLMConfig( | ||||||
|     model_name="gpt-3.5-turbo", model_device="cpu",api_key=os.environ["OPENAI_API_KEY"],  |     model_name="gpt-3.5-turbo",api_key=os.environ["OPENAI_API_KEY"],  | ||||||
|     api_base_url=os.environ["API_BASE_URL"], temperature=0.3 |     api_base_url=os.environ["API_BASE_URL"], temperature=0.3 | ||||||
|     ) |     ) | ||||||
| embed_config = EmbedConfig( | embed_config = EmbedConfig( | ||||||
|  | |||||||
| @ -18,7 +18,7 @@ os.environ["log_verbose"] = "0" | |||||||
| 
 | 
 | ||||||
| phase_name = "metagpt_code_devlop" | phase_name = "metagpt_code_devlop" | ||||||
| llm_config = LLMConfig( | llm_config = LLMConfig( | ||||||
|     model_name="gpt-3.5-turbo", model_device="cpu",api_key=os.environ["OPENAI_API_KEY"],  |     model_name="gpt-3.5-turbo", api_key=os.environ["OPENAI_API_KEY"],  | ||||||
|     api_base_url=os.environ["API_BASE_URL"], temperature=0.3 |     api_base_url=os.environ["API_BASE_URL"], temperature=0.3 | ||||||
|     ) |     ) | ||||||
| embed_config = EmbedConfig( | embed_config = EmbedConfig( | ||||||
|  | |||||||
| @ -20,7 +20,7 @@ os.environ["log_verbose"] = "2" | |||||||
| 
 | 
 | ||||||
| phase_name = "searchChatPhase" | phase_name = "searchChatPhase" | ||||||
| llm_config = LLMConfig( | llm_config = LLMConfig( | ||||||
|     model_name="gpt-3.5-turbo", model_device="cpu",api_key=os.environ["OPENAI_API_KEY"],  |     model_name="gpt-3.5-turbo", api_key=os.environ["OPENAI_API_KEY"],  | ||||||
|     api_base_url=os.environ["API_BASE_URL"], temperature=0.3 |     api_base_url=os.environ["API_BASE_URL"], temperature=0.3 | ||||||
|     ) |     ) | ||||||
| embed_config = EmbedConfig( | embed_config = EmbedConfig( | ||||||
|  | |||||||
| @ -18,7 +18,7 @@ os.environ["log_verbose"] = "2" | |||||||
| 
 | 
 | ||||||
| phase_name = "toolReactPhase" | phase_name = "toolReactPhase" | ||||||
| llm_config = LLMConfig( | llm_config = LLMConfig( | ||||||
|     model_name="gpt-3.5-turbo", model_device="cpu",api_key=os.environ["OPENAI_API_KEY"],  |     model_name="gpt-3.5-turbo",api_key=os.environ["OPENAI_API_KEY"],  | ||||||
|     api_base_url=os.environ["API_BASE_URL"], temperature=0.3 |     api_base_url=os.environ["API_BASE_URL"], temperature=0.3 | ||||||
|     ) |     ) | ||||||
| embed_config = EmbedConfig( | embed_config = EmbedConfig( | ||||||
|  | |||||||
| @ -151,9 +151,9 @@ def create_app(): | |||||||
|              )(delete_cb) |              )(delete_cb) | ||||||
| 
 | 
 | ||||||
|     app.post("/code_base/code_base_chat", |     app.post("/code_base/code_base_chat", | ||||||
|              tags=["Code Base Management"], |             tags=["Code Base Management"], | ||||||
|              summary="删除 code_base" |             summary="code_base 对话" | ||||||
|              )(delete_cb) |             )(search_code) | ||||||
| 
 | 
 | ||||||
|     app.get("/code_base/list_code_bases", |     app.get("/code_base/list_code_bases", | ||||||
|             tags=["Code Base Management"], |             tags=["Code Base Management"], | ||||||
|  | |||||||
| @ -117,7 +117,7 @@ PHASE_CONFIGS.update({ | |||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| llm_config = LLMConfig( | llm_config = LLMConfig( | ||||||
|     model_name="gpt-3.5-turbo", model_device="cpu",api_key=os.environ["OPENAI_API_KEY"],  |     model_name="gpt-3.5-turbo", api_key=os.environ["OPENAI_API_KEY"],  | ||||||
|     api_base_url=os.environ["API_BASE_URL"], temperature=0.3 |     api_base_url=os.environ["API_BASE_URL"], temperature=0.3 | ||||||
|     ) |     ) | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -98,12 +98,6 @@ def start_docker(client, script_shs, ports, image_name, container_name, mounts=N | |||||||
| network_name ='my_network' | network_name ='my_network' | ||||||
| 
 | 
 | ||||||
| def start_sandbox_service(network_name ='my_network'): | def start_sandbox_service(network_name ='my_network'): | ||||||
|     # networks = client.networks.list() |  | ||||||
|     # if any([network_name==i.attrs["Name"] for i in networks]): |  | ||||||
|     #     network = client.networks.get(network_name) |  | ||||||
|     # else: |  | ||||||
|     #     network = client.networks.create('my_network', driver='bridge') |  | ||||||
| 
 |  | ||||||
|     mount = Mount( |     mount = Mount( | ||||||
|             type='bind', |             type='bind', | ||||||
|             source=os.path.join(src_dir, "jupyter_work"), |             source=os.path.join(src_dir, "jupyter_work"), | ||||||
| @ -114,6 +108,12 @@ def start_sandbox_service(network_name ='my_network'): | |||||||
|     # Sandbox startup is independent of the service startup |     # Sandbox startup is independent of the service startup | ||||||
|     if SANDBOX_SERVER["do_remote"]: |     if SANDBOX_SERVER["do_remote"]: | ||||||
|         client = docker.from_env() |         client = docker.from_env() | ||||||
|  |         networks = client.networks.list() | ||||||
|  |         if any([network_name==i.attrs["Name"] for i in networks]): | ||||||
|  |             network = client.networks.get(network_name) | ||||||
|  |         else: | ||||||
|  |             network = client.networks.create(network_name, driver='bridge') | ||||||
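|  |         # The get-or-create above keeps repeated launches idempotent: the | ||||||
|  |         # bridge network is reused if present and created only when missing. | ||||||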
|  | 
 | ||||||
|         # 启动容器 |         # 启动容器 | ||||||
|         logger.info("start container sandbox service") |         logger.info("start container sandbox service") | ||||||
|         JUPYTER_WORK_PATH = "/home/user/chatbot/jupyter_work" |         JUPYTER_WORK_PATH = "/home/user/chatbot/jupyter_work" | ||||||
| @ -150,7 +150,7 @@ def start_api_service(sandbox_host=DEFAULT_BIND_HOST): | |||||||
|         client = docker.from_env() |         client = docker.from_env() | ||||||
|         logger.info("start container service") |         logger.info("start container service") | ||||||
|         check_process("api.py", do_stop=True) |         check_process("api.py", do_stop=True) | ||||||
|         check_process("sdfile_api.py", do_stop=True) |         check_process("llm_api.py", do_stop=True) | ||||||
|         check_process("sdfile_api.py", do_stop=True) |         check_process("sdfile_api.py", do_stop=True) | ||||||
|         check_process("webui.py", do_stop=True) |         check_process("webui.py", do_stop=True) | ||||||
|         mount = Mount( |         mount = Mount( | ||||||
| @ -159,27 +159,28 @@ def start_api_service(sandbox_host=DEFAULT_BIND_HOST): | |||||||
|             target='/home/user/chatbot/', |             target='/home/user/chatbot/', | ||||||
|             read_only=False  # set this to True if read-only access is needed |             read_only=False  # set this to True if read-only access is needed | ||||||
|         ) |         ) | ||||||
|         mount_database = Mount( |         # mount_database = Mount( | ||||||
|             type='bind', |         #     type='bind', | ||||||
|             source=os.path.join(src_dir, "knowledge_base"), |         #     source=os.path.join(src_dir, "knowledge_base"), | ||||||
|             target='/home/user/knowledge_base/', |         #     target='/home/user/knowledge_base/', | ||||||
|             read_only=False  # set this to True if read-only access is needed |         #     read_only=False  # set this to True if read-only access is needed | ||||||
|         ) |         # ) | ||||||
|         mount_code_database = Mount( |         # mount_code_database = Mount( | ||||||
|             type='bind', |         #     type='bind', | ||||||
|             source=os.path.join(src_dir, "code_base"), |         #     source=os.path.join(src_dir, "code_base"), | ||||||
|             target='/home/user/code_base/', |         #     target='/home/user/code_base/', | ||||||
|             read_only=False  # set this to True if read-only access is needed |         #     read_only=False  # set this to True if read-only access is needed | ||||||
|         ) |         # ) | ||||||
|         ports={ |         ports={ | ||||||
|                 f"{API_SERVER['docker_port']}/tcp": f"{API_SERVER['port']}/tcp",  |                 f"{API_SERVER['docker_port']}/tcp": f"{API_SERVER['port']}/tcp",  | ||||||
|                 f"{WEBUI_SERVER['docker_port']}/tcp": f"{WEBUI_SERVER['port']}/tcp", |                 f"{WEBUI_SERVER['docker_port']}/tcp": f"{WEBUI_SERVER['port']}/tcp", | ||||||
|                 f"{SDFILE_API_SERVER['docker_port']}/tcp": f"{SDFILE_API_SERVER['port']}/tcp", |                 f"{SDFILE_API_SERVER['docker_port']}/tcp": f"{SDFILE_API_SERVER['port']}/tcp", | ||||||
|                 f"{NEBULA_GRAPH_SERVER['docker_port']}/tcp": f"{NEBULA_GRAPH_SERVER['port']}/tcp" |                 f"{NEBULA_GRAPH_SERVER['docker_port']}/tcp": f"{NEBULA_GRAPH_SERVER['port']}/tcp" | ||||||
|                 } |                 } | ||||||
|         mounts = [mount, mount_database, mount_code_database] |         # mounts = [mount, mount_database, mount_code_database] | ||||||
|  |         mounts = [mount] | ||||||
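|  |         # Only the project root is mounted now; knowledge_base and code_base | ||||||
|  |         # presumably live under it, which is why the dedicated mounts above | ||||||
|  |         # were commented out. | ||||||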
|         script_shs = [ |         script_shs = [ | ||||||
|             "mkdir -p /home/user/logs", |             "mkdir -p /home/user/chatbot/logs", | ||||||
|             ''' |             ''' | ||||||
|             if [ -d "/home/user/chatbot/data/nebula_data/data/meta" ]; then |             if [ -d "/home/user/chatbot/data/nebula_data/data/meta" ]; then | ||||||
|                 cp -r /home/user/chatbot/data/nebula_data/data /usr/local/nebula/ |                 cp -r /home/user/chatbot/data/nebula_data/data /usr/local/nebula/ | ||||||
| @ -197,12 +198,12 @@ def start_api_service(sandbox_host=DEFAULT_BIND_HOST): | |||||||
|             "pip install jieba", |             "pip install jieba", | ||||||
|             "pip install duckduckgo-search", |             "pip install duckduckgo-search", | ||||||
| 
 | 
 | ||||||
|             "nohup python chatbot/examples/sdfile_api.py > /home/user/logs/sdfile_api.log 2>&1 &", |             "nohup python chatbot/examples/sdfile_api.py > /home/user/chatbot/logs/sdfile_api.log 2>&1 &", | ||||||
|             f"export DUCKDUCKGO_PROXY=socks5://host.docker.internal:13659 && export SANDBOX_HOST={sandbox_host} &&\ |             f"export DUCKDUCKGO_PROXY=socks5://host.docker.internal:13659 && export SANDBOX_HOST={sandbox_host} &&\ | ||||||
|                 nohup python chatbot/examples/api.py > /home/user/logs/api.log 2>&1 &", |                 nohup python chatbot/examples/api.py > /home/user/chatbot/logs/api.log 2>&1 &", | ||||||
|             "nohup python chatbot/examples/llm_api.py > /home/user/llm.log  2>&1 &", |             "nohup python chatbot/examples/llm_api.py > /home/user/llm.log  2>&1 &", | ||||||
|             f"export DUCKDUCKGO_PROXY=socks5://host.docker.internal:13659 && export SANDBOX_HOST={sandbox_host} &&\ |             f"export DUCKDUCKGO_PROXY=socks5://host.docker.internal:13659 && export SANDBOX_HOST={sandbox_host} &&\ | ||||||
|                 cd chatbot/examples && nohup streamlit run webui.py > /home/user/logs/start_webui.log 2>&1 &" |                 cd chatbot/examples && nohup streamlit run webui.py > /home/user/chatbot/logs/start_webui.log 2>&1 &" | ||||||
|             ] |             ] | ||||||
|         if check_docker(client, CONTRAINER_NAME, do_stop=True): |         if check_docker(client, CONTRAINER_NAME, do_stop=True): | ||||||
|             container = start_docker(client, script_shs, ports, IMAGE_NAME, CONTRAINER_NAME, mounts, network=network_name) |             container = start_docker(client, script_shs, ports, IMAGE_NAME, CONTRAINER_NAME, mounts, network=network_name) | ||||||
| @ -212,12 +213,9 @@ def start_api_service(sandbox_host=DEFAULT_BIND_HOST): | |||||||
|         # Stop any previously started docker service |         # Stop any previously started docker service | ||||||
|         # check_docker(client, CONTRAINER_NAME, do_stop=True, ) |         # check_docker(client, CONTRAINER_NAME, do_stop=True, ) | ||||||
| 
 | 
 | ||||||
|         # api_sh = "nohup python ../coagent/service/api.py > ../logs/api.log 2>&1 &" |  | ||||||
|         api_sh = "nohup python api.py > ../logs/api.log 2>&1 &" |         api_sh = "nohup python api.py > ../logs/api.log 2>&1 &" | ||||||
|         # sdfile_sh = "nohup python ../coagent/service/sdfile_api.py > ../logs/sdfile_api.log 2>&1 &" |  | ||||||
|         sdfile_sh = "nohup python sdfile_api.py > ../logs/sdfile_api.log 2>&1 &" |         sdfile_sh = "nohup python sdfile_api.py > ../logs/sdfile_api.log 2>&1 &" | ||||||
|         notebook_sh = f"nohup jupyter-notebook --NotebookApp.token=mytoken --port={SANDBOX_SERVER['port']} --allow-root --ip=0.0.0.0 --notebook-dir={JUPYTER_WORK_PATH} --no-browser --ServerApp.disable_check_xsrf=True > ../logs/sandbox.log 2>&1  &" |         notebook_sh = f"nohup jupyter-notebook --NotebookApp.token=mytoken --port={SANDBOX_SERVER['port']} --allow-root --ip=0.0.0.0 --notebook-dir={JUPYTER_WORK_PATH} --no-browser --ServerApp.disable_check_xsrf=True > ../logs/sandbox.log 2>&1  &" | ||||||
|         # llm_sh = "nohup python ../coagent/service/llm_api.py > ../logs/llm_api.log 2>&1 &" |  | ||||||
|         llm_sh = "nohup python llm_api.py > ../logs/llm_api.log 2>&1 &" |         llm_sh = "nohup python llm_api.py > ../logs/llm_api.log 2>&1 &" | ||||||
|         webui_sh = "streamlit run webui.py" if USE_TTY else "streamlit run webui.py" |         webui_sh = "streamlit run webui.py" if USE_TTY else "streamlit run webui.py" | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -22,7 +22,7 @@ from coagent.service.service_factory import get_cb_details, get_cb_details_by_cb | |||||||
| from coagent.orm import table_init | from coagent.orm import table_init | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| from configs.model_config import EMBEDDING_DEVICE, EMBEDDING_ENGINE, EMBEDDING_MODEL, embedding_model_dict | from configs.model_config import EMBEDDING_DEVICE, EMBEDDING_ENGINE, EMBEDDING_MODEL, embedding_model_dict, llm_model_dict | ||||||
| # SENTENCE_SIZE = 100 | # SENTENCE_SIZE = 100 | ||||||
| 
 | 
 | ||||||
| cell_renderer = JsCode("""function(params) {if(params.value==true){return '✓'}else{return '×'}}""") | cell_renderer = JsCode("""function(params) {if(params.value==true){return '✓'}else{return '×'}}""") | ||||||
| @ -117,6 +117,8 @@ def code_page(api: ApiRequest): | |||||||
|                     embed_model_path=embedding_model_dict[EMBEDDING_MODEL], |                     embed_model_path=embedding_model_dict[EMBEDDING_MODEL], | ||||||
|                     embedding_device=EMBEDDING_DEVICE, |                     embedding_device=EMBEDDING_DEVICE, | ||||||
|                     llm_model=LLM_MODEL, |                     llm_model=LLM_MODEL, | ||||||
|  |                     api_key=llm_model_dict[LLM_MODEL]["api_key"], | ||||||
|  |                     api_base_url=llm_model_dict[LLM_MODEL]["api_base_url"], | ||||||
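|  |                     # credentials come from the selected model's entry in | ||||||
|  |                     # llm_model_dict, assumed to hold "api_key"/"api_base_url" keys | ||||||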
|                 ) |                 ) | ||||||
|                 st.toast(ret.get("msg", " ")) |                 st.toast(ret.get("msg", " ")) | ||||||
|                 st.session_state["selected_cb_name"] = cb_name |                 st.session_state["selected_cb_name"] = cb_name | ||||||
| @ -153,6 +155,8 @@ def code_page(api: ApiRequest): | |||||||
|                     embed_model_path=embedding_model_dict[EMBEDDING_MODEL], |                     embed_model_path=embedding_model_dict[EMBEDDING_MODEL], | ||||||
|                     embedding_device=EMBEDDING_DEVICE, |                     embedding_device=EMBEDDING_DEVICE, | ||||||
|                     llm_model=LLM_MODEL, |                     llm_model=LLM_MODEL, | ||||||
|  |                     api_key=llm_model_dict[LLM_MODEL]["api_key"], | ||||||
|  |                     api_base_url=llm_model_dict[LLM_MODEL]["api_base_url"], | ||||||
|                     ) |                     ) | ||||||
|             st.toast(ret.get("msg", "删除成功")) |             st.toast(ret.get("msg", "删除成功")) | ||||||
|             time.sleep(0.05) |             time.sleep(0.05) | ||||||
|  | |||||||
| @ -11,7 +11,7 @@ from coagent.chat.search_chat import SEARCH_ENGINES | |||||||
| from coagent.connector import PHASE_LIST, PHASE_CONFIGS | from coagent.connector import PHASE_LIST, PHASE_CONFIGS | ||||||
| from coagent.service.service_factory import get_cb_details_by_cb_name | from coagent.service.service_factory import get_cb_details_by_cb_name | ||||||
| 
 | 
 | ||||||
| from configs.model_config import EMBEDDING_DEVICE, EMBEDDING_MODEL, embedding_model_dict, EMBEDDING_ENGINE, KB_ROOT_PATH | from configs.model_config import EMBEDDING_DEVICE, EMBEDDING_MODEL, embedding_model_dict, EMBEDDING_ENGINE, KB_ROOT_PATH, llm_model_dict | ||||||
| chat_box = ChatBox( | chat_box = ChatBox( | ||||||
|     assistant_avatar="../sources/imgs/devops-chatbot2.png" |     assistant_avatar="../sources/imgs/devops-chatbot2.png" | ||||||
| ) | ) | ||||||
| @ -174,7 +174,7 @@ def dialogue_page(api: ApiRequest): | |||||||
|             is_detailed = st.toggle(webui_configs["dialogue"]["phase_toggle_detailed_name"], False) |             is_detailed = st.toggle(webui_configs["dialogue"]["phase_toggle_detailed_name"], False) | ||||||
|             tool_using_on = st.toggle( |             tool_using_on = st.toggle( | ||||||
|                 webui_configs["dialogue"]["phase_toggle_doToolUsing"],  |                 webui_configs["dialogue"]["phase_toggle_doToolUsing"],  | ||||||
|                 PHASE_CONFIGS[choose_phase]["do_using_tool"]) |                 PHASE_CONFIGS[choose_phase].get("do_using_tool", False)) | ||||||
|             tool_selects = [] |             tool_selects = [] | ||||||
|             if tool_using_on: |             if tool_using_on: | ||||||
|                 with st.expander("工具军火库", True): |                 with st.expander("工具军火库", True): | ||||||
| @ -183,7 +183,7 @@ def dialogue_page(api: ApiRequest): | |||||||
|                         TOOL_SETS, ["WeatherInfo"]) |                         TOOL_SETS, ["WeatherInfo"]) | ||||||
|              |              | ||||||
|             search_on = st.toggle(webui_configs["dialogue"]["phase_toggle_doSearch"],  |             search_on = st.toggle(webui_configs["dialogue"]["phase_toggle_doSearch"],  | ||||||
|                                   PHASE_CONFIGS[choose_phase]["do_search"]) |                                   PHASE_CONFIGS[choose_phase].get("do_search", False)) | ||||||
|             search_engine, top_k = None, 3 |             search_engine, top_k = None, 3 | ||||||
|             if search_on: |             if search_on: | ||||||
|                 with st.expander(webui_configs["dialogue"]["expander_search_name"], True): |                 with st.expander(webui_configs["dialogue"]["expander_search_name"], True): | ||||||
| @ -195,7 +195,8 @@ def dialogue_page(api: ApiRequest): | |||||||
| 
 | 
 | ||||||
|             doc_retrieval_on = st.toggle( |             doc_retrieval_on = st.toggle( | ||||||
|                 webui_configs["dialogue"]["phase_toggle_doDocRetrieval"],  |                 webui_configs["dialogue"]["phase_toggle_doDocRetrieval"],  | ||||||
|                   PHASE_CONFIGS[choose_phase]["do_doc_retrieval"]) |                   PHASE_CONFIGS[choose_phase].get("do_doc_retrieval", False) | ||||||
|  |             ) | ||||||
|             selected_kb, top_k, score_threshold = None, 3, 1.0 |             selected_kb, top_k, score_threshold = None, 3, 1.0 | ||||||
|             if doc_retrieval_on: |             if doc_retrieval_on: | ||||||
|                 with st.expander(webui_configs["dialogue"]["kbase_expander_name"], True): |                 with st.expander(webui_configs["dialogue"]["kbase_expander_name"], True): | ||||||
| @ -215,7 +216,7 @@ def dialogue_page(api: ApiRequest): | |||||||
|                      |                      | ||||||
|             code_retrieval_on = st.toggle( |             code_retrieval_on = st.toggle( | ||||||
|                 webui_configs["dialogue"]["phase_toggle_doCodeRetrieval"],  |                 webui_configs["dialogue"]["phase_toggle_doCodeRetrieval"],  | ||||||
|                   PHASE_CONFIGS[choose_phase]["do_code_retrieval"]) |                   PHASE_CONFIGS[choose_phase].get("do_code_retrieval", False)) | ||||||
|             selected_cb, top_k = None, 1 |             selected_cb, top_k = None, 1 | ||||||
|             cb_search_type = "tag" |             cb_search_type = "tag" | ||||||
|             if code_retrieval_on: |             if code_retrieval_on: | ||||||
| @ -296,7 +297,8 @@ def dialogue_page(api: ApiRequest): | |||||||
|             r = api.chat_chat( |             r = api.chat_chat( | ||||||
|                 prompt, history, no_remote_api=True,  |                 prompt, history, no_remote_api=True,  | ||||||
|                 embed_model=EMBEDDING_MODEL, embed_model_path=embedding_model_dict[EMBEDDING_MODEL], |                 embed_model=EMBEDDING_MODEL, embed_model_path=embedding_model_dict[EMBEDDING_MODEL], | ||||||
|                 model_device=EMBEDDING_DEVICE, embed_engine=EMBEDDING_ENGINE, |                 model_device=EMBEDDING_DEVICE, embed_engine=EMBEDDING_ENGINE, api_key=llm_model_dict[LLM_MODEL]["api_key"], | ||||||
|  |                 api_base_url=llm_model_dict[LLM_MODEL]["api_base_url"], | ||||||
|                 llm_model=LLM_MODEL) |                 llm_model=LLM_MODEL) | ||||||
|             for t in r: |             for t in r: | ||||||
|                 if error_msg := check_error_msg(t):  # check whether error occured |                 if error_msg := check_error_msg(t):  # check whether error occured | ||||||
| @ -362,6 +364,8 @@ def dialogue_page(api: ApiRequest): | |||||||
|                 "embed_engine": EMBEDDING_ENGINE, |                 "embed_engine": EMBEDDING_ENGINE, | ||||||
|                 "kb_root_path": KB_ROOT_PATH, |                 "kb_root_path": KB_ROOT_PATH, | ||||||
|                 "model_name": LLM_MODEL, |                 "model_name": LLM_MODEL, | ||||||
|  |                 "api_key": llm_model_dict[LLM_MODEL]["api_key"], | ||||||
|  |                 "api_base_url": llm_model_dict[LLM_MODEL]["api_base_url"], | ||||||
|             } |             } | ||||||
|             text = "" |             text = "" | ||||||
|             d = {"docs": []} |             d = {"docs": []} | ||||||
| @ -405,7 +409,10 @@ def dialogue_page(api: ApiRequest): | |||||||
|                 api.knowledge_base_chat( |                 api.knowledge_base_chat( | ||||||
|                     prompt, selected_kb, kb_top_k, score_threshold, history, |                     prompt, selected_kb, kb_top_k, score_threshold, history, | ||||||
|                     embed_model=EMBEDDING_MODEL, embed_model_path=embedding_model_dict[EMBEDDING_MODEL], |                     embed_model=EMBEDDING_MODEL, embed_model_path=embedding_model_dict[EMBEDDING_MODEL], | ||||||
|                     model_device=EMBEDDING_DEVICE, embed_engine=EMBEDDING_ENGINE, llm_model=LLM_MODEL) |                     model_device=EMBEDDING_DEVICE, embed_engine=EMBEDDING_ENGINE, llm_model=LLM_MODEL, | ||||||
|  |                     api_key=llm_model_dict[LLM_MODEL]["api_key"], | ||||||
|  |                     api_base_url=llm_model_dict[LLM_MODEL]["api_base_url"], | ||||||
|  |                     ) | ||||||
|                     ): |                     ): | ||||||
|                 if error_msg := check_error_msg(d): # check whether error occured |                 if error_msg := check_error_msg(d): # check whether error occured | ||||||
|                     st.error(error_msg) |                     st.error(error_msg) | ||||||
| @ -415,11 +422,7 @@ def dialogue_page(api: ApiRequest): | |||||||
|                 # chat_box.update_msg("知识库匹配结果: \n\n".join(d["docs"]), element_index=1, streaming=False, state="complete") |                 # chat_box.update_msg("知识库匹配结果: \n\n".join(d["docs"]), element_index=1, streaming=False, state="complete") | ||||||
|             chat_box.update_msg(text, element_index=0, streaming=False)  # update the final string and remove the cursor |             chat_box.update_msg(text, element_index=0, streaming=False)  # update the final string and remove the cursor | ||||||
|             chat_box.update_msg("{webui_configs['chat']['chatbox_doc_result']}: \n\n".join(d["docs"]), element_index=1, streaming=False, state="complete") |             chat_box.update_msg("{webui_configs['chat']['chatbox_doc_result']}: \n\n".join(d["docs"]), element_index=1, streaming=False, state="complete") | ||||||
|             # # Check whether code exists, and provide code editing and execution features | 
 | ||||||
|             # code_text = api.codebox.decode_code_from_text(text) |  | ||||||
|             # GLOBAL_EXE_CODE_TEXT = code_text |  | ||||||
|             # if code_text and code_exec_on: |  | ||||||
|             #     codebox_res = api.codebox_chat("```"+code_text+"```", do_code_exe=True) |  | ||||||
|         elif dialogue_mode == webui_configs["dialogue"]["mode"][2]: |         elif dialogue_mode == webui_configs["dialogue"]["mode"][2]: | ||||||
|             logger.info('prompt={}'.format(prompt)) |             logger.info('prompt={}'.format(prompt)) | ||||||
|             logger.info('history={}'.format(history)) |             logger.info('history={}'.format(history)) | ||||||
| @ -438,7 +441,9 @@ def dialogue_page(api: ApiRequest): | |||||||
|                                                              cb_search_type=cb_search_type, |                                                              cb_search_type=cb_search_type, | ||||||
|                                                              no_remote_api=True, embed_model=EMBEDDING_MODEL,  |                                                              no_remote_api=True, embed_model=EMBEDDING_MODEL,  | ||||||
|                                                              embed_model_path=embedding_model_dict[EMBEDDING_MODEL], |                                                              embed_model_path=embedding_model_dict[EMBEDDING_MODEL], | ||||||
|                                                              embed_engine=EMBEDDING_ENGINE, llm_model=LLM_MODEL |                                                              embed_engine=EMBEDDING_ENGINE, llm_model=LLM_MODEL, | ||||||
|  |                                                              api_key=llm_model_dict[LLM_MODEL]["api_key"], | ||||||
|  |                                                              api_base_url=llm_model_dict[LLM_MODEL]["api_base_url"], | ||||||
|                                                              )): |                                                              )): | ||||||
|                 if error_msg := check_error_msg(d): |                 if error_msg := check_error_msg(d): | ||||||
|                     st.error(error_msg) |                     st.error(error_msg) | ||||||
| @ -448,6 +453,7 @@ def dialogue_page(api: ApiRequest): | |||||||
|                     chat_box.update_msg(text, element_index=0) |                     chat_box.update_msg(text, element_index=0) | ||||||
| 
 | 
 | ||||||
|             # postprocess |             # postprocess | ||||||
|  |             logger.debug(f"d={d}") | ||||||
|             text = replace_lt_gt(text) |             text = replace_lt_gt(text) | ||||||
|             chat_box.update_msg(text, element_index=0, streaming=False)  # update the final string and remove the cursor |             chat_box.update_msg(text, element_index=0, streaming=False)  # update the final string and remove the cursor | ||||||
|             logger.debug('text={}'.format(text)) |             logger.debug('text={}'.format(text)) | ||||||
| @ -467,7 +473,9 @@ def dialogue_page(api: ApiRequest): | |||||||
|                 api.search_engine_chat( |                 api.search_engine_chat( | ||||||
|                     prompt, search_engine, se_top_k, history, embed_model=EMBEDDING_MODEL,  |                     prompt, search_engine, se_top_k, history, embed_model=EMBEDDING_MODEL,  | ||||||
|                     embed_model_path=embedding_model_dict[EMBEDDING_MODEL], |                     embed_model_path=embedding_model_dict[EMBEDDING_MODEL], | ||||||
|                     model_device=EMBEDDING_DEVICE, embed_engine=EMBEDDING_ENGINE, llm_model=LLM_MODEL) |                     model_device=EMBEDDING_DEVICE, embed_engine=EMBEDDING_ENGINE, llm_model=LLM_MODEL, | ||||||
|  |             api_key=llm_model_dict[LLM_MODEL]["api_key"], | ||||||
|  |                     api_base_url=llm_model_dict[LLM_MODEL]["api_base_url"],) | ||||||
|                     ): |                     ): | ||||||
|                 if error_msg := check_error_msg(d): # check whether error occured |                 if error_msg := check_error_msg(d): # check whether error occured | ||||||
|                     st.error(error_msg) |                     st.error(error_msg) | ||||||
| @ -477,56 +485,11 @@ def dialogue_page(api: ApiRequest): | |||||||
|                 # chat_box.update_msg("搜索匹配结果: \n\n".join(d["docs"]), element_index=1, streaming=False) |                 # chat_box.update_msg("搜索匹配结果: \n\n".join(d["docs"]), element_index=1, streaming=False) | ||||||
|             chat_box.update_msg(text, element_index=0, streaming=False)  # update the final string and remove the cursor |             chat_box.update_msg(text, element_index=0, streaming=False)  # update the final string and remove the cursor | ||||||
|             chat_box.update_msg(f"{webui_configs['chat']['chatbox_search_result']}: \n\n".join(d["docs"]), element_index=1, streaming=False, state="complete") |             chat_box.update_msg(f"{webui_configs['chat']['chatbox_search_result']}: \n\n".join(d["docs"]), element_index=1, streaming=False, state="complete") | ||||||
|             # # Check whether code exists, and provide code editing and execution features |  | ||||||
|             # code_text = api.codebox.decode_code_from_text(text) |  | ||||||
|             # GLOBAL_EXE_CODE_TEXT = code_text |  | ||||||
|             # if code_text and code_exec_on: |  | ||||||
|             #     codebox_res = api.codebox_chat("```"+code_text+"```", do_code_exe=True) |  | ||||||
| 
 | 
 | ||||||
|         # clear the uploaded files |         # clear the uploaded files | ||||||
|         st.session_state["interpreter_file_key"] += 1 |         st.session_state["interpreter_file_key"] += 1 | ||||||
|         st.experimental_rerun() |         st.experimental_rerun() | ||||||
| 
 | 
 | ||||||
|     # if code_interpreter_on: |  | ||||||
|     #     with st.expander(webui_configs['sandbox']['expander_code_name'], False): |  | ||||||
|     #         code_part = st.text_area( |  | ||||||
|     #             webui_configs['sandbox']['textArea_code_name'], code_text, key="code_text") |  | ||||||
|     #         cols = st.columns(2) |  | ||||||
|     #         if cols[0].button( |  | ||||||
|     #             webui_configs['sandbox']['button_modify_code_name'], |  | ||||||
|     #             use_container_width=True, |  | ||||||
|     #         ): |  | ||||||
|     #             code_text = code_part |  | ||||||
|     #             GLOBAL_EXE_CODE_TEXT = code_text |  | ||||||
|     #             st.toast(webui_configs['sandbox']['text_modify_code']) |  | ||||||
| 
 |  | ||||||
|     #         if cols[1].button( |  | ||||||
|     #             webui_configs['sandbox']['button_exec_code_name'], |  | ||||||
|     #             use_container_width=True |  | ||||||
|     #         ): |  | ||||||
|     #             if code_text: |  | ||||||
|     #                 codebox_res = api.codebox_chat("```"+code_text+"```", do_code_exe=True) |  | ||||||
|     #                 st.toast(webui_configs['sandbox']['text_execing_code'],) |  | ||||||
|     #             else: |  | ||||||
|     #                 st.toast(webui_configs['sandbox']['text_error_exec_code'],) |  | ||||||
| 
 |  | ||||||
|     # #TODO this message will be recorded into history |  | ||||||
|     # if codebox_res is not None and codebox_res.code_exe_status != 200: |  | ||||||
|     #     st.toast(f"{codebox_res.code_exe_response}") |  | ||||||
| 
 |  | ||||||
|     # if codebox_res is not None and codebox_res.code_exe_status == 200: |  | ||||||
|     #     st.toast(f"codebox_chat {codebox_res}") |  | ||||||
|     #     chat_box.ai_say(Markdown(code_text, in_expander=True, title="code interpreter", unsafe_allow_html=True), ) |  | ||||||
|     #     if codebox_res.code_exe_type == "image/png": |  | ||||||
|     #         base_text = f"```\n{code_text}\n```\n\n" |  | ||||||
|     #         img_html = "<img src='data:image/png;base64,{}' class='img-fluid'>".format( |  | ||||||
|     #             codebox_res.code_exe_response |  | ||||||
|     #         ) |  | ||||||
|     #         chat_box.update_msg(img_html, streaming=False, state="complete") |  | ||||||
|     #     else: |  | ||||||
|     #         chat_box.update_msg('```\n'+code_text+'\n```'+"\n\n"+'```\n'+codebox_res.code_exe_response+'\n```',  |  | ||||||
|     #                             streaming=False, state="complete") |  | ||||||
| 
 |  | ||||||
|     now = datetime.now() |     now = datetime.now() | ||||||
|     with st.sidebar: |     with st.sidebar: | ||||||
| 
 | 
 | ||||||
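Every chat entry point in this page now repeats the same two keyword arguments drawn from `llm_model_dict[LLM_MODEL]`. A small helper along these lines (hypothetical, not part of this commit) would keep the call sites in sync; it assumes only the `llm_model_dict` and `LLM_MODEL` names from `configs.model_config` that this page already uses:

```python
from configs.model_config import llm_model_dict, LLM_MODEL

def llm_api_kwargs(model: str = LLM_MODEL) -> dict:
    """Collect the per-model credentials this diff threads into every chat call."""
    cfg = llm_model_dict[model]
    return {"api_key": cfg["api_key"], "api_base_url": cfg["api_base_url"]}

# e.g. api.search_engine_chat(prompt, search_engine, se_top_k, history,
#                             llm_model=LLM_MODEL, **llm_api_kwargs())
```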
|  | |||||||
| @ -14,7 +14,8 @@ from coagent.orm import table_init | |||||||
| 
 | 
 | ||||||
| from configs.model_config import ( | from configs.model_config import ( | ||||||
|     KB_ROOT_PATH, kbs_config, DEFAULT_VS_TYPE, WEB_CRAWL_PATH, |     KB_ROOT_PATH, kbs_config, DEFAULT_VS_TYPE, WEB_CRAWL_PATH, | ||||||
|     EMBEDDING_DEVICE, EMBEDDING_ENGINE, EMBEDDING_MODEL, embedding_model_dict |     EMBEDDING_DEVICE, EMBEDDING_ENGINE, EMBEDDING_MODEL, embedding_model_dict, | ||||||
|  |     llm_model_dict, LLM_MODEL | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| # SENTENCE_SIZE = 100 | # SENTENCE_SIZE = 100 | ||||||
| @ -136,6 +137,8 @@ def knowledge_page( | |||||||
|                     embed_engine=EMBEDDING_ENGINE, |                     embed_engine=EMBEDDING_ENGINE, | ||||||
|                     embedding_device= EMBEDDING_DEVICE, |                     embedding_device= EMBEDDING_DEVICE, | ||||||
|                     embed_model_path=embedding_model_dict[embed_model], |                     embed_model_path=embedding_model_dict[embed_model], | ||||||
|  |                     api_key=llm_model_dict[LLM_MODEL]["api_key"], | ||||||
|  |                     api_base_url=llm_model_dict[LLM_MODEL]["api_base_url"], | ||||||
|                 ) |                 ) | ||||||
|                 st.toast(ret.get("msg", " ")) |                 st.toast(ret.get("msg", " ")) | ||||||
|                 st.session_state["selected_kb_name"] = kb_name |                 st.session_state["selected_kb_name"] = kb_name | ||||||
| @ -160,7 +163,10 @@ def knowledge_page( | |||||||
|             data = [{"file": f, "knowledge_base_name": kb, "not_refresh_vs_cache": True, "embed_model": EMBEDDING_MODEL, |             data = [{"file": f, "knowledge_base_name": kb, "not_refresh_vs_cache": True, "embed_model": EMBEDDING_MODEL, | ||||||
|             "embed_model_path": embedding_model_dict[EMBEDDING_MODEL], |             "embed_model_path": embedding_model_dict[EMBEDDING_MODEL], | ||||||
|             "model_device": EMBEDDING_DEVICE, |             "model_device": EMBEDDING_DEVICE, | ||||||
|             "embed_engine": EMBEDDING_ENGINE} |             "embed_engine": EMBEDDING_ENGINE, | ||||||
|  |             "api_key": llm_model_dict[LLM_MODEL]["api_key"], | ||||||
|  |             "api_base_url": llm_model_dict[LLM_MODEL]["api_base_url"], | ||||||
|  |             } | ||||||
|                     for f in files] |                     for f in files] | ||||||
|             data[-1]["not_refresh_vs_cache"]=False |             data[-1]["not_refresh_vs_cache"]=False | ||||||
|             for k in data: |             for k in data: | ||||||
| @ -210,7 +216,9 @@ def knowledge_page( | |||||||
|                          "embed_model": EMBEDDING_MODEL, |                          "embed_model": EMBEDDING_MODEL, | ||||||
|                         "embed_model_path": embedding_model_dict[EMBEDDING_MODEL], |                         "embed_model_path": embedding_model_dict[EMBEDDING_MODEL], | ||||||
|                         "model_device": EMBEDDING_DEVICE, |                         "model_device": EMBEDDING_DEVICE, | ||||||
|                         "embed_engine": EMBEDDING_ENGINE}] |                         "embed_engine": EMBEDDING_ENGINE, | ||||||
|  |                         "api_key": llm_model_dict[LLM_MODEL]["api_key"], | ||||||
|  |                         "api_base_url": llm_model_dict[LLM_MODEL]["api_base_url"],}] | ||||||
|                 for k in data: |                 for k in data: | ||||||
|                     ret = api.upload_kb_doc(**k) |                     ret = api.upload_kb_doc(**k) | ||||||
|                     logger.info(ret) |                     logger.info(ret) | ||||||
| @ -297,7 +305,9 @@ def knowledge_page( | |||||||
|                     api.update_kb_doc(kb, row["file_name"],  |                     api.update_kb_doc(kb, row["file_name"],  | ||||||
|                                       embed_engine=EMBEDDING_ENGINE,embed_model=EMBEDDING_MODEL, |                                       embed_engine=EMBEDDING_ENGINE,embed_model=EMBEDDING_MODEL, | ||||||
|                                       embed_model_path=embedding_model_dict[EMBEDDING_MODEL], |                                       embed_model_path=embedding_model_dict[EMBEDDING_MODEL], | ||||||
|                                       model_device=EMBEDDING_DEVICE |                                       model_device=EMBEDDING_DEVICE, | ||||||
|  |                                       api_key=llm_model_dict[LLM_MODEL]["api_key"], | ||||||
|  |                                       api_base_url=llm_model_dict[LLM_MODEL]["api_base_url"], | ||||||
|                                       ) |                                       ) | ||||||
|                 st.experimental_rerun() |                 st.experimental_rerun() | ||||||
| 
 | 
 | ||||||
| @ -311,7 +321,9 @@ def knowledge_page( | |||||||
|                     api.delete_kb_doc(kb, row["file_name"], |                     api.delete_kb_doc(kb, row["file_name"], | ||||||
|                                       embed_engine=EMBEDDING_ENGINE,embed_model=EMBEDDING_MODEL, |                                       embed_engine=EMBEDDING_ENGINE,embed_model=EMBEDDING_MODEL, | ||||||
|                                       embed_model_path=embedding_model_dict[EMBEDDING_MODEL], |                                       embed_model_path=embedding_model_dict[EMBEDDING_MODEL], | ||||||
|                                       model_device=EMBEDDING_DEVICE) |                                       model_device=EMBEDDING_DEVICE, | ||||||
|  |                                       api_key=llm_model_dict[LLM_MODEL]["api_key"], | ||||||
|  |                                       api_base_url=llm_model_dict[LLM_MODEL]["api_base_url"],) | ||||||
|                 st.experimental_rerun() |                 st.experimental_rerun() | ||||||
| 
 | 
 | ||||||
|             if cols[3].button( |             if cols[3].button( | ||||||
| @ -323,7 +335,9 @@ def knowledge_page( | |||||||
|                     ret = api.delete_kb_doc(kb, row["file_name"], True, |                     ret = api.delete_kb_doc(kb, row["file_name"], True, | ||||||
|                                       embed_engine=EMBEDDING_ENGINE,embed_model=EMBEDDING_MODEL, |                                       embed_engine=EMBEDDING_ENGINE,embed_model=EMBEDDING_MODEL, | ||||||
|                                       embed_model_path=embedding_model_dict[EMBEDDING_MODEL], |                                       embed_model_path=embedding_model_dict[EMBEDDING_MODEL], | ||||||
|                                       model_device=EMBEDDING_DEVICE) |                                       model_device=EMBEDDING_DEVICE, | ||||||
|  |                                       api_key=llm_model_dict[LLM_MODEL]["api_key"], | ||||||
|  |                                       api_base_url=llm_model_dict[LLM_MODEL]["api_base_url"],) | ||||||
|                     st.toast(ret.get("msg", " ")) |                     st.toast(ret.get("msg", " ")) | ||||||
|                 st.experimental_rerun() |                 st.experimental_rerun() | ||||||
| 
 | 
 | ||||||
| @ -344,6 +358,8 @@ def knowledge_page( | |||||||
|                 for d in api.recreate_vector_store( |                 for d in api.recreate_vector_store( | ||||||
|                     kb, vs_type=default_vs_type, embed_model=embedding_model, embedding_device=EMBEDDING_DEVICE, |                     kb, vs_type=default_vs_type, embed_model=embedding_model, embedding_device=EMBEDDING_DEVICE, | ||||||
|                       embed_model_path=embedding_model_dict[embedding_model], embed_engine=EMBEDDING_ENGINE, |                       embed_model_path=embedding_model_dict[embedding_model], embed_engine=EMBEDDING_ENGINE, | ||||||
|  |                       api_key=llm_model_dict[LLM_MODEL]["api_key"], | ||||||
|  |                       api_base_url=llm_model_dict[LLM_MODEL]["api_base_url"], | ||||||
|                     ): |                     ): | ||||||
|                     if msg := check_error_msg(d): |                     if msg := check_error_msg(d): | ||||||
|                         st.toast(msg) |                         st.toast(msg) | ||||||
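The knowledge-base page makes the same change to each request payload. A cleaned-up sketch of the per-file upload loop above, with identifiers exactly as in this diff, shows the shape of the payload and why only the last entry refreshes the vector-store cache:

```python
# Sketch of the upload loop above; llm_model_dict / LLM_MODEL come from
# configs.model_config as imported at the top of this file.
data = [{
    "file": f,
    "knowledge_base_name": kb,
    "not_refresh_vs_cache": True,
    "embed_model": EMBEDDING_MODEL,
    "embed_model_path": embedding_model_dict[EMBEDDING_MODEL],
    "model_device": EMBEDDING_DEVICE,
    "embed_engine": EMBEDDING_ENGINE,
    "api_key": llm_model_dict[LLM_MODEL]["api_key"],
    "api_base_url": llm_model_dict[LLM_MODEL]["api_base_url"],
} for f in files]
data[-1]["not_refresh_vs_cache"] = False  # rebuild the cache once, after the last file
for k in data:
    ret = api.upload_kb_doc(**k)
    logger.info(ret)
```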
|  | |||||||
| @ -299,7 +299,9 @@ class ApiRequest: | |||||||
|         stream: bool = True, |         stream: bool = True, | ||||||
|         no_remote_api: bool = None, |         no_remote_api: bool = None, | ||||||
|         embed_model: str="", embed_model_path: str="", model_device: str="", embed_engine: str="", |         embed_model: str="", embed_model_path: str="", model_device: str="", embed_engine: str="", | ||||||
|         llm_model: str ="", temperature: float= 0.2 |         llm_model: str ="", temperature: float= 0.2, | ||||||
|  |         api_key: str=os.environ["OPENAI_API_KEY"], | ||||||
|  |         api_base_url: str = os.environ["API_BASE_URL"], | ||||||
|     ): |     ): | ||||||
|         ''' |         ''' | ||||||
|         corresponds to the /chat/chat endpoint in api.py |         corresponds to the /chat/chat endpoint in api.py | ||||||
| @ -311,8 +313,8 @@ class ApiRequest: | |||||||
|             "query": query, |             "query": query, | ||||||
|             "history": history, |             "history": history, | ||||||
|             "stream": stream, |             "stream": stream, | ||||||
|             "api_key": os.environ["OPENAI_API_KEY"], |             "api_key": api_key, | ||||||
|             "api_base_url": os.environ["API_BASE_URL"], |             "api_base_url": api_base_url, | ||||||
|             "embed_model": embed_model, |             "embed_model": embed_model, | ||||||
|             "embed_model_path": embed_model_path, |             "embed_model_path": embed_model_path, | ||||||
|             "embed_engine": embed_engine, |             "embed_engine": embed_engine, | ||||||
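One caveat with these new signatures: a default such as `api_key: str = os.environ["OPENAI_API_KEY"]` is evaluated once, at import time, and raises `KeyError` when the variable is unset. A more forgiving variant (a sketch, not what this commit does) resolves the environment lazily per call:

```python
import os

def resolve_credentials(api_key: str = "", api_base_url: str = "") -> tuple:
    """Fall back to the environment only when the caller passed nothing."""
    return (api_key or os.environ.get("OPENAI_API_KEY", ""),
            api_base_url or os.environ.get("API_BASE_URL", ""))
```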
| @ -339,7 +341,9 @@ class ApiRequest: | |||||||
|         stream: bool = True, |         stream: bool = True, | ||||||
|         no_remote_api: bool = None, |         no_remote_api: bool = None, | ||||||
|         embed_model: str="", embed_model_path: str="", model_device: str="", embed_engine: str="", |         embed_model: str="", embed_model_path: str="", model_device: str="", embed_engine: str="", | ||||||
|         llm_model: str ="", temperature: float= 0.2 |         llm_model: str ="", temperature: float= 0.2, | ||||||
|  |         api_key: str=os.environ["OPENAI_API_KEY"], | ||||||
|  |         api_base_url: str = os.environ["API_BASE_URL"], | ||||||
|     ): |     ): | ||||||
|         ''' |         ''' | ||||||
|         corresponds to the /chat/knowledge_base_chat endpoint in api.py |         corresponds to the /chat/knowledge_base_chat endpoint in api.py | ||||||
| @ -355,8 +359,8 @@ class ApiRequest: | |||||||
|             "history": history, |             "history": history, | ||||||
|             "stream": stream, |             "stream": stream, | ||||||
|             "local_doc_url": no_remote_api, |             "local_doc_url": no_remote_api, | ||||||
|             "api_key": os.environ["OPENAI_API_KEY"], |             "api_key": api_key, | ||||||
|             "api_base_url": os.environ["API_BASE_URL"], |             "api_base_url": api_base_url, | ||||||
|             "embed_model": embed_model, |             "embed_model": embed_model, | ||||||
|             "embed_model_path": embed_model_path, |             "embed_model_path": embed_model_path, | ||||||
|             "embed_engine": embed_engine, |             "embed_engine": embed_engine, | ||||||
| @ -386,7 +390,10 @@ class ApiRequest: | |||||||
|         stream: bool = True, |         stream: bool = True, | ||||||
|         no_remote_api: bool = None, |         no_remote_api: bool = None, | ||||||
|         embed_model: str="", embed_model_path: str="", model_device: str="", embed_engine: str="", |         embed_model: str="", embed_model_path: str="", model_device: str="", embed_engine: str="", | ||||||
|         llm_model: str ="", temperature: float= 0.2 |         llm_model: str ="", temperature: float= 0.2, | ||||||
|  |         api_key: str=os.environ["OPENAI_API_KEY"], | ||||||
|  |         api_base_url: str = os.environ["API_BASE_URL"], | ||||||
|  | 
 | ||||||
|     ): |     ): | ||||||
|         ''' |         ''' | ||||||
|         corresponds to the /chat/search_engine_chat endpoint in api.py |         corresponds to the /chat/search_engine_chat endpoint in api.py | ||||||
| @ -400,8 +407,8 @@ class ApiRequest: | |||||||
|             "top_k": top_k, |             "top_k": top_k, | ||||||
|             "history": history, |             "history": history, | ||||||
|             "stream": stream, |             "stream": stream, | ||||||
|             "api_key": os.environ["OPENAI_API_KEY"], |             "api_key": api_key, | ||||||
|             "api_base_url": os.environ["API_BASE_URL"], |             "api_base_url": api_base_url, | ||||||
|             "embed_model": embed_model, |             "embed_model": embed_model, | ||||||
|             "embed_model_path": embed_model_path, |             "embed_model_path": embed_model_path, | ||||||
|             "embed_engine": embed_engine, |             "embed_engine": embed_engine, | ||||||
| @ -432,7 +439,9 @@ class ApiRequest: | |||||||
|         stream: bool = True, |         stream: bool = True, | ||||||
|         no_remote_api: bool = None, |         no_remote_api: bool = None, | ||||||
|         embed_model: str="", embed_model_path: str="", model_device: str="", embed_engine: str="", |         embed_model: str="", embed_model_path: str="", model_device: str="", embed_engine: str="", | ||||||
|         llm_model: str ="", temperature: float= 0.2 |         llm_model: str ="", temperature: float= 0.2, | ||||||
|  |         api_key: str=os.environ["OPENAI_API_KEY"], | ||||||
|  |         api_base_url: str = os.environ["API_BASE_URL"], | ||||||
|     ): |     ): | ||||||
|         ''' |         ''' | ||||||
|         corresponds to the /chat/knowledge_base_chat endpoint in api.py |         corresponds to the /chat/knowledge_base_chat endpoint in api.py | ||||||
| @ -458,8 +467,8 @@ class ApiRequest: | |||||||
|             "cb_search_type": cb_search_type, |             "cb_search_type": cb_search_type, | ||||||
|             "stream": stream, |             "stream": stream, | ||||||
|             "local_doc_url": no_remote_api, |             "local_doc_url": no_remote_api, | ||||||
|             "api_key": os.environ["OPENAI_API_KEY"], |             "api_key": api_key, | ||||||
|             "api_base_url": os.environ["API_BASE_URL"], |             "api_base_url": api_base_url, | ||||||
|             "embed_model": embed_model, |             "embed_model": embed_model, | ||||||
|             "embed_model_path": embed_model_path, |             "embed_model_path": embed_model_path, | ||||||
|             "embed_engine": embed_engine, |             "embed_engine": embed_engine, | ||||||
| @ -510,6 +519,8 @@ class ApiRequest: | |||||||
|         embed_model: str="", embed_model_path: str="",  |         embed_model: str="", embed_model_path: str="",  | ||||||
|         model_device: str="", embed_engine: str="", |         model_device: str="", embed_engine: str="", | ||||||
|         temperature: float=0.2, model_name:str ="", |         temperature: float=0.2, model_name:str ="", | ||||||
|  |         api_key: str=os.environ["OPENAI_API_KEY"], | ||||||
|  |         api_base_url: str = os.environ["API_BASE_URL"], | ||||||
|     ): |     ): | ||||||
|         ''' |         ''' | ||||||
|         corresponds to the /chat/chat endpoint in api.py |         corresponds to the /chat/chat endpoint in api.py | ||||||
| @ -541,8 +552,8 @@ class ApiRequest: | |||||||
|             "isDetailed": isDetailed, |             "isDetailed": isDetailed, | ||||||
|             "upload_file": upload_file, |             "upload_file": upload_file, | ||||||
|             "kb_root_path": kb_root_path, |             "kb_root_path": kb_root_path, | ||||||
|             "api_key": os.environ["OPENAI_API_KEY"], |             "api_key": api_key, | ||||||
|             "api_base_url": os.environ["API_BASE_URL"], |             "api_base_url": api_base_url, | ||||||
|             "embed_model": embed_model, |             "embed_model": embed_model, | ||||||
|             "embed_model_path": embed_model_path, |             "embed_model_path": embed_model_path, | ||||||
|             "embed_engine": embed_engine, |             "embed_engine": embed_engine, | ||||||
| @ -588,6 +599,8 @@ class ApiRequest: | |||||||
|         embed_model: str="", embed_model_path: str="",  |         embed_model: str="", embed_model_path: str="",  | ||||||
|         model_device: str="", embed_engine: str="", |         model_device: str="", embed_engine: str="", | ||||||
|         temperature: float=0.2, model_name: str="", |         temperature: float=0.2, model_name: str="", | ||||||
|  |         api_key: str=os.environ["OPENAI_API_KEY"], | ||||||
|  |         api_base_url: str = os.environ["API_BASE_URL"], | ||||||
|     ): |     ): | ||||||
|         ''' |         ''' | ||||||
|         corresponds to the /chat/chat endpoint in api.py |         corresponds to the /chat/chat endpoint in api.py | ||||||
| @ -620,8 +633,8 @@ class ApiRequest: | |||||||
|             "isDetailed": isDetailed, |             "isDetailed": isDetailed, | ||||||
|             "upload_file": upload_file, |             "upload_file": upload_file, | ||||||
|             "kb_root_path": kb_root_path, |             "kb_root_path": kb_root_path, | ||||||
|             "api_key": os.environ["OPENAI_API_KEY"], |             "api_key": api_key, | ||||||
|             "api_base_url": os.environ["API_BASE_URL"], |             "api_base_url": api_base_url, | ||||||
|             "embed_model": embed_model, |             "embed_model": embed_model, | ||||||
|             "embed_model_path": embed_model_path, |             "embed_model_path": embed_model_path, | ||||||
|             "embed_engine": embed_engine, |             "embed_engine": embed_engine, | ||||||
| @ -694,7 +707,9 @@ class ApiRequest: | |||||||
|         no_remote_api: bool = None, |         no_remote_api: bool = None, | ||||||
|         kb_root_path: str =KB_ROOT_PATH, |         kb_root_path: str =KB_ROOT_PATH, | ||||||
|         embed_model: str="", embed_model_path: str="",  |         embed_model: str="", embed_model_path: str="",  | ||||||
|         embedding_device: str="", embed_engine: str="" |         embedding_device: str="", embed_engine: str="", | ||||||
|  |         api_key: str=os.environ["OPENAI_API_KEY"], | ||||||
|  |         api_base_url: str = os.environ["API_BASE_URL"], | ||||||
|     ): |     ): | ||||||
|         ''' |         ''' | ||||||
|         corresponds to the /knowledge_base/create_knowledge_base endpoint in api.py |         corresponds to the /knowledge_base/create_knowledge_base endpoint in api.py | ||||||
| @ -706,8 +721,8 @@ class ApiRequest: | |||||||
|             "knowledge_base_name": knowledge_base_name, |             "knowledge_base_name": knowledge_base_name, | ||||||
|             "vector_store_type": vector_store_type, |             "vector_store_type": vector_store_type, | ||||||
|             "kb_root_path": kb_root_path, |             "kb_root_path": kb_root_path, | ||||||
|             "api_key": os.environ["OPENAI_API_KEY"], |             "api_key": api_key, | ||||||
|             "api_base_url": os.environ["API_BASE_URL"], |             "api_base_url": api_base_url, | ||||||
|             "embed_model": embed_model, |             "embed_model": embed_model, | ||||||
|             "embed_model_path": embed_model_path, |             "embed_model_path": embed_model_path, | ||||||
|             "model_device": embedding_device, |             "model_device": embedding_device, | ||||||
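Every `ApiRequest` method in this diff assembles the same credential-plus-embedding block by hand. A hypothetical refactor (not in this commit) could centralize that block so future fields are added in one place:

```python
def _common_payload(api_key: str = "", api_base_url: str = "",
                    embed_model: str = "", embed_model_path: str = "",
                    model_device: str = "", embed_engine: str = "") -> dict:
    """The request fields every ApiRequest method in this diff repeats."""
    return {
        "api_key": api_key,
        "api_base_url": api_base_url,
        "embed_model": embed_model,
        "embed_model_path": embed_model_path,
        "model_device": model_device,
        "embed_engine": embed_engine,
    }

# usage: data = {"knowledge_base_name": name, **_common_payload(api_key=key)}
```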
| @ -781,7 +796,9 @@ class ApiRequest: | |||||||
|         no_remote_api: bool = None, |         no_remote_api: bool = None, | ||||||
|         kb_root_path: str = KB_ROOT_PATH, |         kb_root_path: str = KB_ROOT_PATH, | ||||||
|         embed_model: str="", embed_model_path: str="",  |         embed_model: str="", embed_model_path: str="",  | ||||||
|         model_device: str="", embed_engine: str="" |         model_device: str="", embed_engine: str="", | ||||||
|  |         api_key: str=os.environ["OPENAI_API_KEY"], | ||||||
|  |         api_base_url: str = os.environ["API_BASE_URL"], | ||||||
|     ): |     ): | ||||||
|         ''' |         ''' | ||||||
|         corresponds to the /knowledge_base/upload_docs endpoint in api.py |         corresponds to the /knowledge_base/upload_docs endpoint in api.py | ||||||
| @ -810,8 +827,8 @@ class ApiRequest: | |||||||
|                 override, |                 override, | ||||||
|                 not_refresh_vs_cache, |                 not_refresh_vs_cache, | ||||||
|                 kb_root_path=kb_root_path, |                 kb_root_path=kb_root_path, | ||||||
|                 api_key=os.environ["OPENAI_API_KEY"], |                 api_key=api_key, | ||||||
|                 api_base_url=os.environ["API_BASE_URL"], |                 api_base_url=api_base_url, | ||||||
|                 embed_model=embed_model, |                 embed_model=embed_model, | ||||||
|                 embed_model_path=embed_model_path, |                 embed_model_path=embed_model_path, | ||||||
|                 model_device=model_device, |                 model_device=model_device, | ||||||
| @ -839,7 +856,9 @@ class ApiRequest: | |||||||
|         no_remote_api: bool = None, |         no_remote_api: bool = None, | ||||||
|         kb_root_path: str = KB_ROOT_PATH, |         kb_root_path: str = KB_ROOT_PATH, | ||||||
|         embed_model: str="", embed_model_path: str="",  |         embed_model: str="", embed_model_path: str="",  | ||||||
|         model_device: str="", embed_engine: str="" |         model_device: str="", embed_engine: str="", | ||||||
|  |         api_key: str=os.environ["OPENAI_API_KEY"], | ||||||
|  |         api_base_url: str = os.environ["API_BASE_URL"], | ||||||
|     ): |     ): | ||||||
|         ''' |         ''' | ||||||
|         corresponds to the /knowledge_base/delete_doc endpoint in api.py |         corresponds to the /knowledge_base/delete_doc endpoint in api.py | ||||||
| @ -853,8 +872,8 @@ class ApiRequest: | |||||||
|             "delete_content": delete_content, |             "delete_content": delete_content, | ||||||
|             "not_refresh_vs_cache": not_refresh_vs_cache, |             "not_refresh_vs_cache": not_refresh_vs_cache, | ||||||
|             "kb_root_path": kb_root_path, |             "kb_root_path": kb_root_path, | ||||||
|             "api_key": os.environ["OPENAI_API_KEY"], |             "api_key": api_key, | ||||||
|             "api_base_url": os.environ["API_BASE_URL"], |             "api_base_url": api_base_url, | ||||||
|             "embed_model": embed_model, |             "embed_model": embed_model, | ||||||
|             "embed_model_path": embed_model_path, |             "embed_model_path": embed_model_path, | ||||||
|             "model_device": model_device, |             "model_device": model_device, | ||||||
| @ -878,7 +897,9 @@ class ApiRequest: | |||||||
|         not_refresh_vs_cache: bool = False, |         not_refresh_vs_cache: bool = False, | ||||||
|         no_remote_api: bool = None, |         no_remote_api: bool = None, | ||||||
|         embed_model: str="", embed_model_path: str="",  |         embed_model: str="", embed_model_path: str="",  | ||||||
|         model_device: str="", embed_engine: str="" |         model_device: str="", embed_engine: str="", | ||||||
|  |         api_key: str=os.environ["OPENAI_API_KEY"], | ||||||
|  |         api_base_url: str = os.environ["API_BASE_URL"], | ||||||
|     ): |     ): | ||||||
|         ''' |         ''' | ||||||
|         corresponds to the /knowledge_base/update_doc endpoint in api.py |         corresponds to the /knowledge_base/update_doc endpoint in api.py | ||||||
| @ -889,8 +910,8 @@ class ApiRequest: | |||||||
|         if no_remote_api: |         if no_remote_api: | ||||||
|             response = run_async(update_doc( |             response = run_async(update_doc( | ||||||
|                 knowledge_base_name, file_name, not_refresh_vs_cache, kb_root_path=KB_ROOT_PATH, |                 knowledge_base_name, file_name, not_refresh_vs_cache, kb_root_path=KB_ROOT_PATH, | ||||||
|                                 api_key=os.environ["OPENAI_API_KEY"], |                                 api_key=api_key, | ||||||
|                 api_base_url=os.environ["API_BASE_URL"], |                 api_base_url=api_base_url, | ||||||
|                 embed_model=embed_model, |                 embed_model=embed_model, | ||||||
|                 embed_model_path=embed_model_path, |                 embed_model_path=embed_model_path, | ||||||
|                 model_device=model_device, |                 model_device=model_device, | ||||||
| @ -915,7 +936,9 @@ class ApiRequest: | |||||||
|         no_remote_api: bool = None, |         no_remote_api: bool = None, | ||||||
|         kb_root_path: str =KB_ROOT_PATH, |         kb_root_path: str =KB_ROOT_PATH, | ||||||
|         embed_model: str="", embed_model_path: str="",  |         embed_model: str="", embed_model_path: str="",  | ||||||
|         embedding_device: str="", embed_engine: str="" |         embedding_device: str="", embed_engine: str="", | ||||||
|  |         api_key: str=os.environ["OPENAI_API_KEY"], | ||||||
|  |         api_base_url: str = os.environ["API_BASE_URL"], | ||||||
|     ): |     ): | ||||||
|         ''' |         ''' | ||||||
|         corresponds to the /knowledge_base/recreate_vector_store endpoint in api.py |         corresponds to the /knowledge_base/recreate_vector_store endpoint in api.py | ||||||
| @ -928,8 +951,8 @@ class ApiRequest: | |||||||
|             "allow_empty_kb": allow_empty_kb, |             "allow_empty_kb": allow_empty_kb, | ||||||
|             "vs_type": vs_type, |             "vs_type": vs_type, | ||||||
|             "kb_root_path": kb_root_path, |             "kb_root_path": kb_root_path, | ||||||
|             "api_key": os.environ["OPENAI_API_KEY"], |             "api_key": api_key, | ||||||
|             "api_base_url": os.environ["API_BASE_URL"], |             "api_base_url": api_base_url, | ||||||
|             "embed_model": embed_model, |             "embed_model": embed_model, | ||||||
|             "embed_model_path": embed_model_path, |             "embed_model_path": embed_model_path, | ||||||
|             "model_device": embedding_device, |             "model_device": embedding_device, | ||||||
| @ -1041,7 +1064,9 @@ class ApiRequest: | |||||||
|     # code base operations |     # code base operations | ||||||
|     def create_code_base(self, cb_name, zip_file, do_interpret: bool, no_remote_api: bool = None, |     def create_code_base(self, cb_name, zip_file, do_interpret: bool, no_remote_api: bool = None, | ||||||
|                          embed_model: str="", embed_model_path: str="", embedding_device: str="", embed_engine: str="", |                          embed_model: str="", embed_model_path: str="", embedding_device: str="", embed_engine: str="", | ||||||
|                          llm_model: str ="", temperature: float= 0.2 |                          llm_model: str ="", temperature: float= 0.2, | ||||||
|  |                          api_key: str=os.environ["OPENAI_API_KEY"], | ||||||
|  |                          api_base_url: str = os.environ["API_BASE_URL"], | ||||||
|                          ): |                          ): | ||||||
|         ''' |         ''' | ||||||
|         create a code_base |         create a code_base | ||||||
| @ -1067,8 +1092,8 @@ class ApiRequest: | |||||||
|             "cb_name": cb_name, |             "cb_name": cb_name, | ||||||
|             "code_path": raw_code_path, |             "code_path": raw_code_path, | ||||||
|             "do_interpret": do_interpret, |             "do_interpret": do_interpret, | ||||||
|             "api_key": os.environ["OPENAI_API_KEY"], |             "api_key": api_key, | ||||||
|             "api_base_url": os.environ["API_BASE_URL"], |             "api_base_url": api_base_url, | ||||||
|             "embed_model": embed_model, |             "embed_model": embed_model, | ||||||
|             "embed_model_path": embed_model_path, |             "embed_model_path": embed_model_path, | ||||||
|             "embed_engine": embed_engine, |             "embed_engine": embed_engine, | ||||||
| @ -1091,7 +1116,9 @@ class ApiRequest: | |||||||
| 
 | 
 | ||||||
|     def delete_code_base(self, cb_name: str, no_remote_api: bool = None, |     def delete_code_base(self, cb_name: str, no_remote_api: bool = None, | ||||||
|                          embed_model: str="", embed_model_path: str="", embedding_device: str="", embed_engine: str="", |                          embed_model: str="", embed_model_path: str="", embedding_device: str="", embed_engine: str="", | ||||||
|                          llm_model: str ="", temperature: float= 0.2 |                          llm_model: str ="", temperature: float= 0.2, | ||||||
|  |                          api_key: str=os.environ["OPENAI_API_KEY"], | ||||||
|  |                          api_base_url: str = os.environ["API_BASE_URL"], | ||||||
|                          ): |                          ): | ||||||
|         ''' |         ''' | ||||||
|         delete a code_base |         delete a code_base | ||||||
| @ -1102,8 +1129,8 @@ class ApiRequest: | |||||||
|             no_remote_api = self.no_remote_api |             no_remote_api = self.no_remote_api | ||||||
|         data = { |         data = { | ||||||
|             "cb_name": cb_name, |             "cb_name": cb_name, | ||||||
|             "api_key": os.environ["OPENAI_API_KEY"], |             "api_key": api_key, | ||||||
|             "api_base_url": os.environ["API_BASE_URL"], |             "api_base_url": api_base_url, | ||||||
|             "embed_model": embed_model, |             "embed_model": embed_model, | ||||||
|             "embed_model_path": embed_model_path, |             "embed_model_path": embed_model_path, | ||||||
|             "embed_engine": embed_engine, |             "embed_engine": embed_engine, | ||||||
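Taken together, the signature changes let a caller override the gateway per request instead of patching `os.environ`, which is what the antflow compatibility in this commit relies on. A self-contained toy (all values below are placeholders) demonstrates the override semantics:

```python
import os

os.environ.setdefault("OPENAI_API_KEY", "sk-placeholder")
os.environ.setdefault("API_BASE_URL", "http://localhost:8888/v1")

def chat(query: str,
         api_key: str = os.environ["OPENAI_API_KEY"],
         api_base_url: str = os.environ["API_BASE_URL"]) -> dict:
    """Toy stand-in for ApiRequest.chat: defaults freeze at definition time."""
    return {"query": query, "api_key": api_key, "api_base_url": api_base_url}

print(chat("hi"))  # falls back to the environment defaults captured above
print(chat("hi", api_key="sk-other",
           api_base_url="http://antflow-gateway:8080/v1"))  # explicit override wins
```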
|  | |||||||