From 4d9b268a985aab2a2777ff5173254aca4b5932f3 Mon Sep 17 00:00:00 2001
From: shanshi
Date: Tue, 12 Mar 2024 15:31:06 +0800
Subject: [PATCH] [feature](coagent) add antflow compatibility and a coagent
 demo
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 coagent/base_configs/env_config.py            |  10 +-
 coagent/chat/base_chat.py                     |   2 +-
 coagent/chat/code_chat.py                     |   3 +-
 .../codechat/code_analyzer/code_intepreter.py |  12 +-
 coagent/codechat/code_crawler/dir_crawler.py  |   2 +-
 coagent/codechat/code_search/code_search.py   |  70 ++-
 .../codechat/code_search/cypher_generator.py  |   2 +-
 .../codebase_handler/code_importer.py         |  40 +-
 .../codebase_handler/codebase_handler.py      | 137 ++++-
 coagent/connector/actions/__init__.py         |   6 +
 coagent/connector/actions/base_action.py      |  16 +
 coagent/connector/agents/base_agent.py        | 203 +------
 coagent/connector/agents/executor_agent.py    |  19 +-
 coagent/connector/agents/react_agent.py       | 100 +---
 coagent/connector/agents/selector_agent.py    |  97 +---
 coagent/connector/antflow/__init__.py         |   7 +
 coagent/connector/antflow/flow.py             | 255 +++++++++
 coagent/connector/chains/base_chain.py        |  41 +-
 coagent/connector/configs/__init__.py         |   5 +-
 coagent/connector/configs/agent_config.py     |  96 +++-
 coagent/connector/configs/chain_config.py     |  16 +
 coagent/connector/configs/phase_config.py     |  59 +-
 coagent/connector/configs/prompt_config.py    |  39 +-
 coagent/connector/configs/prompts/__init__.py |   7 +-
 .../prompts/code2doc_template_prompt.py       |  95 ++++
 .../prompts/code2test_template_prompt.py      |  65 +++
 coagent/connector/memory_manager.py           | 163 ++++--
 coagent/connector/message_process.py          |  59 +-
 coagent/connector/phase/base_phase.py         |  51 +-
 coagent/connector/prompt_manager/__init__.py  |   2 +
 .../prompt_manager/extend_manager.py          |  45 ++
 .../prompt_manager/prompt_manager.py          | 353 ++++++++++++
 coagent/connector/schema/general_schema.py    |   6 +-
 coagent/connector/schema/memory.py            |   3 +
 coagent/connector/schema/message.py           |   3 +
 coagent/connector/utils.py                    |   9 +-
 .../graph_db_handler/nebula_handler.py        |  20 +-
 coagent/embeddings/faiss_m.py                 |  22 +-
 coagent/embeddings/get_embedding.py           |  11 +-
 coagent/embeddings/in_memory.py               |  49 ++
 coagent/embeddings/utils.py                   |   9 +-
 coagent/llm_models/__init__.py                |   4 +-
 coagent/llm_models/llm_config.py              |  11 +-
 coagent/llm_models/openai_model.py            |  58 +-
 coagent/retrieval/__init__.py                 |   5 +
 coagent/retrieval/base_retrieval.py           |  75 +++
 .../retrieval/document_loaders/__init__.py    |   6 +
 .../retrieval/document_loaders/json_loader.py |  61 +++
 .../document_loaders/jsonl_loader.py          |  62 +++
 coagent/retrieval/text_splitter/__init__.py   |   3 +
 .../text_splitter/langchain_splitter.py       |  77 +++
 coagent/retrieval/text_splitter/utils.py      |   0
 coagent/sandbox/pycodebox.py                  |  82 ++-
 coagent/service/base_service.py               |   4 +-
 coagent/service/cb_api.py                     |  47 +-
 coagent/service/faiss_db_service.py           |  17 +-
 coagent/tools/cb_query_tool.py                |  17 +-
 coagent/tools/codechat_tools.py               |  15 +-
 coagent/tools/docs_retrieval.py               |   4 -
 coagent/tools/duckduckgo_search.py            |   6 +-
 coagent/utils/code2doc_util.py                |  89 +++
 coagent/utils/common_utils.py                 |   6 +-
 coagent/utils/path_utils.py                   |   2 +-
 configs/default_config.py                     |  13 +-
 configs/model_config.py.example               |   2 +-
 configs/server_config.py.example              |   3 -
 .../agent_examples/baseGroupPhase_example.py  |   4 +-
 .../agent_examples/baseTaskPhase_example.py   |   2 +-
 .../codeChatPhaseLocal_example.py             | 135 +++++
 .../agent_examples/codeChatPhase_example.py   |  85 +--
 examples/agent_examples/codeGenDoc_example.py | 507 ++++++++++++++++++
 .../codeGenTestCases_example.py               | 444 +++++++++++++++
 .../agent_examples/codeReactPhase_example.py  |   2 +-
 .../agent_examples/codeRetrieval_example.py   |   7 +-
 .../codeToolReactPhase_example.py             |   2 +-
 .../agent_examples/docChatPhase_example.py    |   2 +-
 .../agent_examples/metagpt_phase_example.py   |   2 +-
 .../agent_examples/searchChatPhase_example.py |   2 +-
 .../agent_examples/toolReactPhase_example.py  |   2 +-
 examples/api.py                               |   6 +-
 .../auto_feedback_from_code_execution.py      |   2 +-
 examples/start.py                             |  52 +-
 examples/webui/code.py                        |   6 +-
 examples/webui/dialogue.py                    |  81 +--
 examples/webui/document.py                    |  28 +-
 examples/webui/utils.py                       | 101 ++--
 86 files changed, 3449 insertions(+), 901 deletions(-)
 create mode 100644 coagent/connector/actions/__init__.py
 create mode 100644 coagent/connector/actions/base_action.py
 create mode 100644 coagent/connector/antflow/__init__.py
 create mode 100644 coagent/connector/antflow/flow.py
 create mode 100644 coagent/connector/configs/prompts/code2doc_template_prompt.py
 create mode 100644 coagent/connector/configs/prompts/code2test_template_prompt.py
 create mode 100644 coagent/connector/prompt_manager/__init__.py
 create mode 100644 coagent/connector/prompt_manager/extend_manager.py
 create mode 100644 coagent/connector/prompt_manager/prompt_manager.py
 create mode 100644 coagent/embeddings/in_memory.py
 create mode 100644 coagent/retrieval/__init__.py
 create mode 100644 coagent/retrieval/base_retrieval.py
 create mode 100644 coagent/retrieval/document_loaders/__init__.py
 create mode 100644 coagent/retrieval/document_loaders/json_loader.py
 create mode 100644 coagent/retrieval/document_loaders/jsonl_loader.py
 create mode 100644 coagent/retrieval/text_splitter/__init__.py
 create mode 100644 coagent/retrieval/text_splitter/langchain_splitter.py
 create mode 100644 coagent/retrieval/text_splitter/utils.py
 create mode 100644 coagent/utils/code2doc_util.py
 create mode 100644 examples/agent_examples/codeChatPhaseLocal_example.py
 create mode 100644 examples/agent_examples/codeGenDoc_example.py
 create mode 100644 examples/agent_examples/codeGenTestCases_example.py
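For orientation before the file-by-file diff: the headline of this patch is the new
antflow wrapper layer (coagent/connector/antflow/flow.py). A minimal usage sketch,
assuming a LangChain-compatible LLM and embedding object and a placeholder role
prompt; none of this is part of the diff itself:

    from langchain.llms import OpenAI
    from langchain.embeddings import HuggingFaceEmbeddings
    from coagent.connector.antflow import AgentFlow, ChainFlow, PhaseFlow

    llm = OpenAI()                                                   # any BaseLLM
    embeddings = HuggingFaceEmbeddings(model_name="text2vec-base-chinese")

    agent = AgentFlow(
        role_name="assistant",
        agent_type="BaseAgent",
        agent_index=0,
        role_prompt="#### Agent Profile\n\nYou are a helpful assistant.",  # placeholder
    )
    chain = ChainFlow(chain_name="demo_chain", chain_index=0, agent_flows=[agent])
    phase = PhaseFlow(phase_name="demo_phase", chain_flows=[chain],
                      embeddings=embeddings, llm=llm)
    print(phase({"query": "hello"}))  # PhaseFlow.__call__ takes a params dict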
diff --git a/coagent/base_configs/env_config.py b/coagent/base_configs/env_config.py
index 45a20e5..e931131 100644
--- a/coagent/base_configs/env_config.py
+++ b/coagent/base_configs/env_config.py
@@ -26,9 +26,12 @@ JUPYTER_WORK_PATH = os.environ.get("JUPYTER_WORK_PATH", None) or os.path.join(ex
 WEB_CRAWL_PATH = os.environ.get("WEB_CRAWL_PATH", None) or os.path.join(executable_path, "knowledge_base")
 
 # NEBULA_DATA存储路径
-NELUBA_PATH = os.environ.get("NELUBA_PATH", None) or os.path.join(executable_path, "data/neluba_data")
+NEBULA_PATH = os.environ.get("NEBULA_PATH", None) or os.path.join(executable_path, "data/nebula_data")
 
-for _path in [LOG_PATH, SOURCE_PATH, KB_ROOT_PATH, NLTK_DATA_PATH, JUPYTER_WORK_PATH, WEB_CRAWL_PATH, NELUBA_PATH]:
+# CHROMA 存储路径
+CHROMA_PERSISTENT_PATH = os.environ.get("CHROMA_PERSISTENT_PATH", None) or os.path.join(executable_path, "data/chroma_data")
+
+for _path in [LOG_PATH, SOURCE_PATH, KB_ROOT_PATH, CB_ROOT_PATH, NLTK_DATA_PATH, JUPYTER_WORK_PATH, WEB_CRAWL_PATH, NEBULA_PATH, CHROMA_PERSISTENT_PATH]:
     if not os.path.exists(_path):
         os.makedirs(_path, exist_ok=True)
@@ -58,7 +61,8 @@ NEBULA_GRAPH_SERVER = {
 }
 
 # CHROMA CONFIG
-CHROMA_PERSISTENT_PATH = '/home/user/chatbot/data/chroma_data'
+# CHROMA_PERSISTENT_PATH = '/home/user/chatbot/data/chroma_data'
+# CHROMA_PERSISTENT_PATH = '/Users/bingxu/Desktop/工作/大模型/chatbot/codefuse-chatbot-antcode/data/chroma_data'
 
 
 # 默认向量库类型。可选:faiss, milvus, pg.
diff --git a/coagent/chat/base_chat.py b/coagent/chat/base_chat.py
index 490f20d..2fe0943 100644
--- a/coagent/chat/base_chat.py
+++ b/coagent/chat/base_chat.py
@@ -7,7 +7,7 @@ from langchain import LLMChain
 from langchain.callbacks import AsyncIteratorCallbackHandler
 from langchain.prompts.chat import ChatPromptTemplate
 
-from coagent.llm_models import getChatModel, getChatModelFromConfig
+from coagent.llm_models import getChatModelFromConfig
 from coagent.chat.utils import History, wrap_done
 from coagent.llm_models.llm_config import LLMConfig, EmbedConfig
 # from configs.model_config import (llm_model_dict, LLM_MODEL, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD)
diff --git a/coagent/chat/code_chat.py b/coagent/chat/code_chat.py
index a0bb7a3..4e19c60 100644
--- a/coagent/chat/code_chat.py
+++ b/coagent/chat/code_chat.py
@@ -22,7 +22,7 @@ from coagent.connector.configs.prompts import CODE_PROMPT_TEMPLATE
 from coagent.chat.utils import History, wrap_done
 from coagent.utils import BaseResponse
 from .base_chat import Chat
-from coagent.llm_models import getChatModel, getChatModelFromConfig
+from coagent.llm_models import getChatModelFromConfig
 from coagent.llm_models.llm_config import LLMConfig, EmbedConfig
 
 
@@ -67,6 +67,7 @@ class CodeChat(Chat):
             embed_model_path=embed_config.embed_model_path,
             embed_engine=embed_config.embed_engine,
             model_device=embed_config.model_device,
+            embed_config=embed_config
         )
         context = codes_res['context']
diff --git a/coagent/codechat/code_analyzer/code_intepreter.py b/coagent/codechat/code_analyzer/code_intepreter.py
index a96cda4..7e321b6 100644
--- a/coagent/codechat/code_analyzer/code_intepreter.py
+++ b/coagent/codechat/code_analyzer/code_intepreter.py
@@ -12,7 +12,7 @@ from langchain.schema import (
 
 # from configs.model_config import CODE_INTERPERT_TEMPLATE
 from coagent.connector.configs.prompts import CODE_INTERPERT_TEMPLATE
-from coagent.llm_models.openai_model import getChatModel, getChatModelFromConfig
+from coagent.llm_models.openai_model import getChatModelFromConfig
 from coagent.llm_models.llm_config import LLMConfig
 
 
@@ -53,9 +53,15 @@ class CodeIntepreter:
             message = CODE_INTERPERT_TEMPLATE.format(code=code)
             messages.append(message)
 
-        chat_ress = chat_model.batch(messages)
+        try:
+            chat_ress = [chat_model(messages) for message in messages]
+        except:
+            chat_ress = chat_model.batch(messages)
         for chat_res, code in zip(chat_ress, code_list):
-            res[code] = chat_res.content
+            try:
+                res[code] = chat_res.content
+            except:
+                res[code] = chat_res
         return res
diff --git a/coagent/codechat/code_crawler/dir_crawler.py b/coagent/codechat/code_crawler/dir_crawler.py
index 2c27b84..96dea0c 100644
--- a/coagent/codechat/code_crawler/dir_crawler.py
+++ b/coagent/codechat/code_crawler/dir_crawler.py
@@ -27,7 +27,7 @@ class DirCrawler:
         logger.info(java_file_list)
 
         for java_file in java_file_list:
-            with open(java_file) as f:
+            with open(java_file, encoding="utf-8") as f:
                 java_code = ''.join(f.readlines())
             java_code_dict[java_file] = java_code
         return java_code_dict
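Both storage roots above now come from the environment first (os.environ.get with a
default under executable_path), so deployments can relocate them without editing the
file. A minimal sketch with placeholder paths; the variables must be set before
coagent.base_configs.env_config is first imported:

    import os

    os.environ["NEBULA_PATH"] = "/data/nebula_data"             # placeholder path
    os.environ["CHROMA_PERSISTENT_PATH"] = "/data/chroma_data"  # placeholder path

    from coagent.base_configs.env_config import NEBULA_PATH, CHROMA_PERSISTENT_PATH
    print(NEBULA_PATH, CHROMA_PERSISTENT_PATH)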
diff --git a/coagent/codechat/code_search/code_search.py b/coagent/codechat/code_search/code_search.py
index 7fc720d..7e4c505 100644
--- a/coagent/codechat/code_search/code_search.py
+++ b/coagent/codechat/code_search/code_search.py
@@ -5,6 +5,7 @@
 @time: 2023/11/21 下午2:35
 @desc:
 '''
+import json
 import time
 from loguru import logger
 from collections import defaultdict
@@ -15,7 +16,7 @@ from coagent.db_handler.vector_db_handler.chroma_handler import ChromaHandler
 from coagent.codechat.code_search.cypher_generator import CypherGenerator
 from coagent.codechat.code_search.tagger import Tagger
 from coagent.embeddings.get_embedding import get_embedding
-from coagent.llm_models.llm_config import LLMConfig
+from coagent.llm_models.llm_config import LLMConfig, EmbedConfig
 
 # from configs.model_config import EMBEDDING_DEVICE, EMBEDDING_MODEL
 
@@ -29,7 +30,8 @@ MAX_DISTANCE = 1000
 
 
 class CodeSearch:
-    def __init__(self, llm_config: LLMConfig, nh: NebulaHandler, ch: ChromaHandler, limit: int = 3):
+    def __init__(self, llm_config: LLMConfig, nh: NebulaHandler, ch: ChromaHandler, limit: int = 3,
+                 local_graph_file_path: str = ''):
         '''
         init
         @param nh: NebulaHandler
@@ -37,7 +39,13 @@ class CodeSearch:
         @param limit: limit of result
         '''
         self.llm_config = llm_config
+
         self.nh = nh
+
+        if not self.nh:
+            with open(local_graph_file_path, 'r') as f:
+                self.graph = json.load(f)
+
         self.ch = ch
         self.limit = limit
 
@@ -51,7 +59,7 @@ class CodeSearch:
         tag_list = tagger.generate_tag_query(query)
         logger.info(f'query tag={tag_list}')
 
-        # get all verticex
+        # get all vertices
         vertex_list = self.nh.get_vertices().get('v', [])
         vertex_vid_list = [i.as_node().get_id().as_string() for i in vertex_list]
 
@@ -81,7 +89,7 @@ class CodeSearch:
         # get most prominent package tag
         package_score_dict = defaultdict(lambda: 0)
 
-        for vertex, score in vertex_score_dict.items():
+        for vertex, score in vertex_score_dict_final.items():
             if '#' in vertex:
                 # get class name first
                 cypher = f'''MATCH (v1:class)-[e:contain]->(v2) WHERE id(v2) == '{vertex}' RETURN id(v1) as id;'''
@@ -111,6 +119,53 @@ class CodeSearch:
         logger.info(f'ids={ids}')
         chroma_res = self.ch.get(ids=ids, include=['metadatas'])
 
+        for vertex, score in package_score_tuple:
+            index = chroma_res['result']['ids'].index(vertex)
+            code_text = chroma_res['result']['metadatas'][index]['code_text']
+            res.append({
+                "vertex": vertex,
+                "code_text": code_text}
+            )
+            if len(res) >= self.limit:
+                break
+        # logger.info(f'retrival code={res}')
+        return res
+
+    def search_by_tag_by_graph(self, query: str):
+        '''
+        search code by tag with graph
+        @param query:
+        @return:
+        '''
+        tagger = Tagger()
+        tag_list = tagger.generate_tag_query(query)
+        logger.info(f'query tag={tag_list}')
+
+        # loop to get package node
+        package_score_dict = {}
+        for code, structure in self.graph.items():
+            score = 0
+            for class_name in structure['class_name_list']:
+                for tag in tag_list:
+                    if tag.lower() in class_name.lower():
+                        score += 1
+
+            for func_name_list in structure['func_name_dict'].values():
+                for func_name in func_name_list:
+                    for tag in tag_list:
+                        if tag.lower() in func_name.lower():
+                            score += 1
+            package_score_dict[structure['pac_name']] = score
+
+        # get respective code
+        res = []
+        package_score_tuple = list(package_score_dict.items())
+        package_score_tuple.sort(key=lambda x: x[1], reverse=True)
+
+        ids = [i[0] for i in package_score_tuple]
+        logger.info(f'ids={ids}')
+        chroma_res = self.ch.get(ids=ids, include=['metadatas'])
+
         # logger.info(chroma_res)
         for vertex, score in package_score_tuple:
             index = chroma_res['result']['ids'].index(vertex)
@@ -121,23 +176,22 @@ class CodeSearch:
             )
             if len(res) >= self.limit:
                 break
-        logger.info(f'retrival code={res}')
+        # logger.info(f'retrival code={res}')
 
         return res
 
-    def search_by_desciption(self, query: str, engine: str, model_path: str = "text2vec-base-chinese", embedding_device: str = "cpu"):
+    def search_by_desciption(self, query: str, engine: str, model_path: str = "text2vec-base-chinese", embedding_device: str = "cpu", embed_config: EmbedConfig=None):
         '''
         search by perform sim search
         @param query:
         @return:
         '''
         query = query.replace(',', ',')
-        query_emb = get_embedding(engine=engine, text_list=[query], model_path=model_path, embedding_device= embedding_device,)
+        query_emb = get_embedding(engine=engine, text_list=[query], model_path=model_path, embedding_device= embedding_device, embed_config=embed_config)
         query_emb = query_emb[query]
 
         query_embeddings = [query_emb]
         query_result = self.ch.query(query_embeddings=query_embeddings, n_results=self.limit,
                                      include=['metadatas', 'distances'])
-        logger.debug(query_result)
 
         res = []
         for idx, distance in enumerate(query_result['result']['distances'][0]):
diff --git a/coagent/codechat/code_search/cypher_generator.py b/coagent/codechat/code_search/cypher_generator.py
index 6b23f88..814839a 100644
--- a/coagent/codechat/code_search/cypher_generator.py
+++ b/coagent/codechat/code_search/cypher_generator.py
@@ -8,7 +8,7 @@ from langchain import PromptTemplate
 from loguru import logger
 
-from coagent.llm_models.openai_model import getChatModel, getChatModelFromConfig
+from coagent.llm_models.openai_model import getChatModelFromConfig
 from coagent.llm_models.llm_config import LLMConfig
 from coagent.utils.postprocess import replace_lt_gt
 from langchain.schema import (
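The local-graph search added above scores each package by counting tag hits in its
class and function names; the graph is the JSON that CodeImporter persists, mapping
each file's code to {pac_name, class_name_list, func_name_dict}. A toy walk-through
of the same loop (all data invented for illustration):

    # toy version of the scoring loop in search_by_tag_by_graph
    graph = {
        "HttpClient.java": {
            "pac_name": "com.example.client",
            "class_name_list": ["HttpClient"],
            "func_name_dict": {"HttpClient": ["sendRequest", "removeHeader"]},
        }
    }
    tag_list = ["remove", "header"]

    package_score_dict = {}
    for code, structure in graph.items():
        names = list(structure["class_name_list"])
        for func_names in structure["func_name_dict"].values():
            names.extend(func_names)
        score = sum(tag.lower() in name.lower() for tag in tag_list for name in names)
        package_score_dict[structure["pac_name"]] = score

    print(package_score_dict)  # {'com.example.client': 2} -- both tags hit 'removeHeader'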
diff --git a/coagent/codechat/codebase_handler/code_importer.py b/coagent/codechat/codebase_handler/code_importer.py
index 801b19a..c374f39 100644
--- a/coagent/codechat/codebase_handler/code_importer.py
+++ b/coagent/codechat/codebase_handler/code_importer.py
@@ -6,11 +6,10 @@
 @desc:
 '''
 import time
+import json
+import os
 from loguru import logger
 
-# from configs.server_config import NEBULA_HOST, NEBULA_PORT, NEBULA_USER, NEBULA_PASSWORD, NEBULA_STORAGED_PORT
-# from configs.server_config import CHROMA_PERSISTENT_PATH
-# from configs.model_config import EMBEDDING_DEVICE, EMBEDDING_MODEL
 from coagent.db_handler.graph_db_handler.nebula_handler import NebulaHandler
 from coagent.db_handler.vector_db_handler.chroma_handler import ChromaHandler
 from coagent.embeddings.get_embedding import get_embedding
@@ -18,12 +17,14 @@ from coagent.llm_models.llm_config import EmbedConfig
 
 
 class CodeImporter:
-    def __init__(self, codebase_name: str, embed_config: EmbedConfig, nh: NebulaHandler, ch: ChromaHandler):
+    def __init__(self, codebase_name: str, embed_config: EmbedConfig, nh: NebulaHandler, ch: ChromaHandler,
+                 local_graph_file_path: str):
         self.codebase_name = codebase_name
         # self.engine = engine
-        self.embed_config: EmbedConfig= embed_config
+        self.embed_config: EmbedConfig = embed_config
         self.nh = nh
         self.ch = ch
+        self.local_graph_file_path = local_graph_file_path
 
     def import_code(self, static_analysis_res: dict, interpretation: dict, do_interpret: bool = True):
         '''
@@ -31,9 +32,14 @@ class CodeImporter:
         @return:
         '''
         static_analysis_res = self.filter_out_vertex(static_analysis_res, interpretation)
-        logger.info(f'static_analysis_res={static_analysis_res}')
 
-        self.analysis_res_to_graph(static_analysis_res)
+        if self.nh:
+            self.analysis_res_to_graph(static_analysis_res)
+        else:
+            # persist to local dir
+            with open(self.local_graph_file_path, 'w') as f:
+                json.dump(static_analysis_res, f)
+
         self.interpretation_to_db(static_analysis_res, interpretation, do_interpret)
 
     def filter_out_vertex(self, static_analysis_res, interpretation):
@@ -114,12 +120,12 @@ class CodeImporter:
         # create vertex
         for tag_name, value_dict in vertex_value_dict.items():
             res = self.nh.insert_vertex(tag_name, value_dict)
-            logger.debug(res.error_msg())
+            # logger.debug(res.error_msg())
 
         # create edge
         for tag_name, value_dict in edge_value_dict.items():
             res = self.nh.insert_edge(tag_name, value_dict)
-            logger.debug(res.error_msg())
+            # logger.debug(res.error_msg())
 
         return
 
@@ -132,7 +138,7 @@ class CodeImporter:
         if do_interpret:
             logger.info('start get embedding for interpretion')
             interp_list = list(interpretation.values())
-            emb = get_embedding(engine=self.embed_config.embed_engine, text_list=interp_list, model_path=self.embed_config.embed_model_path, embedding_device= self.embed_config.model_device)
+            emb = get_embedding(engine=self.embed_config.embed_engine, text_list=interp_list, model_path=self.embed_config.embed_model_path, embedding_device= self.embed_config.model_device, embed_config=self.embed_config)
             logger.info('get embedding done')
         else:
             emb = {i: [0] for i in list(interpretation.values())}
@@ -161,7 +167,7 @@ class CodeImporter:
 
         # add documents to chroma
         res = self.ch.add_data(ids=ids, embeddings=embeddings, documents=documents, metadatas=metadatas)
-        logger.debug(res)
+        # logger.debug(res)
 
     def init_graph(self):
         '''
@@ -169,7 +175,7 @@ class CodeImporter:
         @return:
         '''
         res = self.nh.create_space(space_name=self.codebase_name, vid_type='FIXED_STRING(1024)')
-        logger.debug(res.error_msg())
+        # logger.debug(res.error_msg())
         time.sleep(5)
 
         self.nh.set_space_name(self.codebase_name)
@@ -179,29 +185,29 @@ class CodeImporter:
         tag_name = 'package'
         prop_dict = {}
         res = self.nh.create_tag(tag_name, prop_dict)
-        logger.debug(res.error_msg())
+        # logger.debug(res.error_msg())
 
         tag_name = 'class'
         prop_dict = {}
         res = self.nh.create_tag(tag_name, prop_dict)
-        logger.debug(res.error_msg())
+        # logger.debug(res.error_msg())
 
         tag_name = 'method'
         prop_dict = {}
         res = self.nh.create_tag(tag_name, prop_dict)
-        logger.debug(res.error_msg())
+        # logger.debug(res.error_msg())
 
         # create edge type
         edge_type_name = 'contain'
         prop_dict = {}
         res = self.nh.create_edge_type(edge_type_name, prop_dict)
-        logger.debug(res.error_msg())
+        # logger.debug(res.error_msg())
 
         # create edge type
         edge_type_name = 'depend'
         prop_dict = {}
         res = self.nh.create_edge_type(edge_type_name, prop_dict)
-        logger.debug(res.error_msg())
+        # logger.debug(res.error_msg())
 
 
 if __name__ == '__main__':
diff --git a/coagent/codechat/codebase_handler/codebase_handler.py b/coagent/codechat/codebase_handler/codebase_handler.py
index 136e737..2af52c6 100644
--- a/coagent/codechat/codebase_handler/codebase_handler.py
+++ b/coagent/codechat/codebase_handler/codebase_handler.py
@@ -5,16 +5,15 @@
 @time: 2023/11/21 下午2:25
 @desc:
 '''
+import os
 import time
+import json
+from typing import List
 from loguru import logger
 
-# from configs.server_config import NEBULA_HOST, NEBULA_PORT, NEBULA_USER, NEBULA_PASSWORD, NEBULA_STORAGED_PORT
-# from configs.server_config import CHROMA_PERSISTENT_PATH
-# from configs.model_config import EMBEDDING_ENGINE
-
 from coagent.base_configs.env_config import (
     NEBULA_HOST, NEBULA_PORT, NEBULA_USER, NEBULA_PASSWORD, NEBULA_STORAGED_PORT,
-    CHROMA_PERSISTENT_PATH
+    CHROMA_PERSISTENT_PATH, CB_ROOT_PATH
 )
@@ -35,7 +34,9 @@ class CodeBaseHandler:
                  language: str = 'java',
                  crawl_type: str = 'ZIP',
                  embed_config: EmbedConfig = EmbedConfig(),
-                 llm_config: LLMConfig = LLMConfig()
+                 llm_config: LLMConfig = LLMConfig(),
+                 use_nh: bool = True,
+                 local_graph_path: str = CB_ROOT_PATH
                  ):
         self.codebase_name = codebase_name
         self.code_path = code_path
@@ -43,11 +44,28 @@ class CodeBaseHandler:
         self.crawl_type = crawl_type
         self.embed_config = embed_config
         self.llm_config = llm_config
+        self.local_graph_file_path = local_graph_path + os.sep + f'{self.codebase_name}_graph.json'
 
-        self.nh = NebulaHandler(host=NEBULA_HOST, port=NEBULA_PORT, username=NEBULA_USER,
-                                password=NEBULA_PASSWORD, space_name=codebase_name)
-        self.nh.add_host(NEBULA_HOST, NEBULA_STORAGED_PORT)
-        time.sleep(1)
+        if use_nh:
+            try:
+                self.nh = NebulaHandler(host=NEBULA_HOST, port=NEBULA_PORT, username=NEBULA_USER,
+                                        password=NEBULA_PASSWORD, space_name=codebase_name)
+                self.nh.add_host(NEBULA_HOST, NEBULA_STORAGED_PORT)
+                time.sleep(1)
+            except:
+                self.nh = None
+                try:
+                    with open(self.local_graph_file_path, 'r') as f:
+                        self.graph = json.load(f)
+                except:
+                    pass
+        elif local_graph_path:
+            self.nh = None
+            try:
+                with open(self.local_graph_file_path, 'r') as f:
+                    self.graph = json.load(f)
+            except:
+                pass
 
         self.ch = ChromaHandler(path=CHROMA_PERSISTENT_PATH, collection_name=codebase_name)
 
@@ -58,9 +76,10 @@ class CodeBaseHandler:
         '''
         # init graph to init tag and edge
         code_importer = CodeImporter(embed_config=self.embed_config, codebase_name=self.codebase_name,
-                                     nh=self.nh, ch=self.ch)
-        code_importer.init_graph()
-        time.sleep(5)
+                                     nh=self.nh, ch=self.ch, local_graph_file_path=self.local_graph_file_path)
+        if self.nh:
+            code_importer.init_graph()
+            time.sleep(5)
 
         # crawl code
         st0 = time.time()
@@ -71,7 +90,7 @@ class CodeBaseHandler:
         # analyze code
         logger.info('start analyze')
         st1 = time.time()
-        code_analyzer = CodeAnalyzer(language=self.language, llm_config = self.llm_config)
+        code_analyzer = CodeAnalyzer(language=self.language, llm_config=self.llm_config)
         static_analysis_res, interpretation = code_analyzer.analyze(code_dict, do_interpret=do_interpret)
         logger.debug('analyze done, rt={}'.format(time.time() - st1))
 
@@ -81,8 +100,12 @@ class CodeBaseHandler:
         logger.debug('update codebase done, rt={}'.format(time.time() - st2))
 
         # get KG info
-        stat = self.nh.get_stat()
-        vertices_num, edges_num = stat['vertices'], stat['edges']
+        if self.nh:
+            stat = self.nh.get_stat()
+            vertices_num, edges_num = stat['vertices'], stat['edges']
+        else:
+            vertices_num = 0
+            edges_num = 0
 
         # get chroma info
         file_num = self.ch.count()['result']
@@ -95,7 +118,11 @@ class CodeBaseHandler:
         @param codebase_name: name of codebase
         @return:
         '''
-        self.nh.drop_space(space_name=codebase_name)
+        if self.nh:
+            self.nh.drop_space(space_name=codebase_name)
+        elif self.local_graph_file_path and os.path.isfile(self.local_graph_file_path):
+            os.remove(self.local_graph_file_path)
+
         self.ch.delete_collection(collection_name=codebase_name)
 
     def crawl_code(self, zip_file=''):
@@ -124,9 +151,15 @@ class CodeBaseHandler:
         @param search_type: ['cypher', 'graph', 'vector']
         @return:
         '''
-        assert search_type in ['cypher', 'tag', 'description']
+        if self.nh:
+            assert search_type in ['cypher', 'tag', 'description']
+        else:
+            if search_type == 'tag':
+                search_type = 'tag_by_local_graph'
+            assert search_type in ['tag_by_local_graph', 'description']
 
-        code_search = CodeSearch(llm_config=self.llm_config, nh=self.nh, ch=self.ch, limit=limit)
+        code_search = CodeSearch(llm_config=self.llm_config, nh=self.nh, ch=self.ch, limit=limit,
+                                 local_graph_file_path=self.local_graph_file_path)
 
         if search_type == 'cypher':
             search_res = code_search.search_by_cypher(query=query)
@@ -134,7 +167,11 @@ class CodeBaseHandler:
             search_res = code_search.search_by_tag(query=query)
         elif search_type == 'description':
             search_res = code_search.search_by_desciption(
-                query=query, engine=self.embed_config.embed_engine, model_path=self.embed_config.embed_model_path, embedding_device=self.embed_config.model_device)
+                query=query, engine=self.embed_config.embed_engine, model_path=self.embed_config.embed_model_path,
+                embedding_device=self.embed_config.model_device, embed_config=self.embed_config)
+        elif search_type == 'tag_by_local_graph':
+            search_res = code_search.search_by_tag_by_graph(query=query)
+
         context, related_vertice = self.format_search_res(search_res, search_type)
         return context, related_vertice
@@ -160,6 +197,12 @@ class CodeBaseHandler:
             for code in search_res:
                 context = context + code['code_text'] + '\n'
                 related_vertice.append(code['vertex'])
+        elif search_type == 'tag_by_local_graph':
+            context = ''
+            related_vertice = []
+            for code in search_res:
+                context = context + code['code_text'] + '\n'
+                related_vertice.append(code['vertex'])
         elif search_type == 'description':
             context = ''
             related_vertice = []
@@ -169,17 +212,63 @@ class CodeBaseHandler:
 
         return context, related_vertice
 
+    def search_vertices(self, vertex_type="class") -> List[str]:
+        '''
+        通过 method/class 来搜索所有的节点
+        '''
+        vertices = []
+        if self.nh:
+            vertices = self.nh.get_all_vertices()
+            vertices = [str(v.as_node().get_id()) for v in vertices["v"] if vertex_type in v.as_node().tags()]
+            # for v in vertices["v"]:
+            #     logger.debug(f"{v.as_node().get_id()}, {v.as_node().tags()}")
+        else:
+            if vertex_type == "class":
+                vertices = [str(class_name) for code, structure in self.graph.items() for class_name in structure['class_name_list']]
+            elif vertex_type == "method":
+                vertices = [
+                    str(methods_name)
+                    for code, structure in self.graph.items()
+                    for methods_names in structure['func_name_dict'].values()
+                    for methods_name in methods_names
+                ]
+        # logger.debug(vertices)
+        return vertices
+
 
 if __name__ == '__main__':
-    codebase_name = 'testing'
+    from configs.model_config import KB_ROOT_PATH, JUPYTER_WORK_PATH
+    from configs.server_config import SANDBOX_SERVER
+
+    LLM_MODEL = "gpt-3.5-turbo"
+    llm_config = LLMConfig(
+        model_name=LLM_MODEL, model_device="cpu", api_key=os.environ["OPENAI_API_KEY"],
+        api_base_url=os.environ["API_BASE_URL"], temperature=0.3
+    )
+    src_dir = '/Users/bingxu/Desktop/工作/大模型/chatbot/Codefuse-chatbot-antcode'
+    embed_config = EmbedConfig(
+        embed_engine="model", embed_model="text2vec-base-chinese",
+        embed_model_path=os.path.join(src_dir, "embedding_models/text2vec-base-chinese")
+    )
+
+    codebase_name = 'client_local'
     code_path = '/Users/bingxu/Desktop/工作/大模型/chatbot/test_code_repo/client'
-    cbh = CodeBaseHandler(codebase_name, code_path, crawl_type='dir')
+    use_nh = False
+    local_graph_path = '/Users/bingxu/Desktop/工作/大模型/chatbot/Codefuse-chatbot-antcode/code_base'
+    CHROMA_PERSISTENT_PATH = '/Users/bingxu/Desktop/工作/大模型/chatbot/Codefuse-chatbot-antcode/data/chroma_data'
+
+    cbh = CodeBaseHandler(codebase_name, code_path, crawl_type='dir', use_nh=use_nh, local_graph_path=local_graph_path,
+                          llm_config=llm_config, embed_config=embed_config)
+
+    # test import code
+    # cbh.import_code(do_interpret=True)
 
     # query = '使用不同的HTTP请求类型(GET、POST、DELETE等)来执行不同的操作'
     # query = '代码中一共有多少个类'
+    # query = 'remove 这个函数是用来做什么的'
+    query = '有没有函数是从字符串中删除指定字符串的功能'
 
-    query = 'intercept 函数作用是什么'
-    search_type = 'graph'
+    search_type = 'description'
     limit = 2
     res = cbh.search_code(query, search_type, limit)
     logger.debug(res)
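With use_nh=False (or when the Nebula connection fails) the handler runs entirely off
the local JSON graph plus Chroma: search_code rewrites 'tag' into
'tag_by_local_graph' and only tag- and description-style searches remain valid, while
'cypher' still requires the graph database. A short continuation of the __main__ demo
above (same cbh object; the query string is a placeholder):

    # runs against the local graph because cbh was built with use_nh=False
    context, related_vertices = cbh.search_code("remove substring from string", "tag", 2)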
diff --git a/coagent/connector/actions/__init__.py b/coagent/connector/actions/__init__.py
new file mode 100644
index 0000000..13bbc3c
--- /dev/null
+++ b/coagent/connector/actions/__init__.py
@@ -0,0 +1,6 @@
+from .base_action import BaseAction
+
+
+__all__ = [
+    "BaseAction"
+]
\ No newline at end of file
diff --git a/coagent/connector/actions/base_action.py b/coagent/connector/actions/base_action.py
new file mode 100644
index 0000000..d64f97b
--- /dev/null
+++ b/coagent/connector/actions/base_action.py
@@ -0,0 +1,16 @@
+
+from langchain.schema import BaseRetriever, Document
+
+class BaseAction:
+
+
+    def __init__(self, ):
+        pass
+
+    def step(self, ):
+        pass
+
+    def astep(self, ):
+        pass
+
+    
\ No newline at end of file
diff --git a/coagent/connector/agents/base_agent.py b/coagent/connector/agents/base_agent.py
index a1ae699..8afeeab 100644
--- a/coagent/connector/agents/base_agent.py
+++ b/coagent/connector/agents/base_agent.py
@@ -4,25 +4,25 @@ import re, os
 import copy
 from loguru import logger
 
+from langchain.schema import BaseRetriever
+
 from coagent.connector.schema import (
     Memory, Task, Role, Message, PromptField, LogVerboseEnum
 )
 from coagent.connector.memory_manager import BaseMemoryManager
-from coagent.connector.configs.prompts import BEGIN_PROMPT_INPUT
 from coagent.connector.message_process import MessageUtils
-from coagent.llm_models import getChatModel, getExtraModel, LLMConfig, getChatModelFromConfig, EmbedConfig
-from coagent.connector.prompt_manager import PromptManager
+from coagent.llm_models import getExtraModel, LLMConfig, getChatModelFromConfig, EmbedConfig
+from coagent.connector.prompt_manager.prompt_manager import PromptManager
 from coagent.connector.memory_manager import LocalMemoryManager
-from coagent.connector.utils import parse_section
-# from configs.model_config import JUPYTER_WORK_PATH
-# from configs.server_config import SANDBOX_SERVER
+from coagent.base_configs.env_config import JUPYTER_WORK_PATH, KB_ROOT_PATH
+
 
 class BaseAgent:
     def __init__(
             self, 
             role: Role,
-            prompt_config: [PromptField],
+            prompt_config: List[PromptField],
             prompt_manager_type: str = "PromptManager",
             task: Task = None,
             memory: Memory = None,
@@ -33,8 +33,11 @@ class BaseAgent:
             llm_config: LLMConfig = None,
             embed_config: EmbedConfig = None,
             sandbox_server: dict = {},
-            jupyter_work_path: str = "",
-            kb_root_path: str = "",
+            jupyter_work_path: str = JUPYTER_WORK_PATH,
+            kb_root_path: str = KB_ROOT_PATH,
+            doc_retrieval: Union[BaseRetriever] = None,
+            code_retrieval = None,
+            search_retrieval = None,
             log_verbose: str = "0"
             ):
 
@@ -43,7 +46,7 @@ class BaseAgent:
         self.sandbox_server = sandbox_server
         self.jupyter_work_path = jupyter_work_path
         self.kb_root_path = kb_root_path
-        self.message_utils = MessageUtils(role, sandbox_server, jupyter_work_path, embed_config, llm_config, kb_root_path, log_verbose)
+        self.message_utils = MessageUtils(role, sandbox_server, jupyter_work_path, embed_config, llm_config, kb_root_path, doc_retrieval, code_retrieval, search_retrieval, log_verbose)
         self.memory = self.init_history(memory)
         self.llm_config: LLMConfig = llm_config
         self.embed_config: EmbedConfig = embed_config
@@ -82,12 +85,8 @@ class BaseAgent:
                 llm_config=self.embed_config
             )
             memory_manager.append(query)
-            memory_pool = memory_manager.current_memory
-        else:
-            memory_pool = memory_manager.current_memory
+        memory_pool = memory_manager.get_memory_pool(query.user_name)
 
-        
-        logger.debug(f"memory_pool: {memory_pool}")
         prompt = self.prompt_manager.generate_full_prompt(
             previous_agent_message=query_c, agent_long_term_memory=self.memory, ui_history=history, chain_summary_messages=background, memory_pool=memory_pool)
         content = self.llm.predict(prompt)
@@ -99,6 +98,7 @@ class BaseAgent:
             logger.info(f"{self.role.role_name} content: {content}")
 
         output_message = Message(
+            user_name=query.user_name,
             role_name=self.role.role_name,
             role_type="assistant", #self.role.role_type,
             role_content=content,
@@ -151,10 +151,7 @@ class BaseAgent:
         self.memory = self.init_history()
 
     def create_llm_engine(self, llm_config: LLMConfig = None, temperature=0.2, stop=None):
-        if llm_config is None:
-            return getChatModel(temperature=temperature, stop=stop)
-        else:
-            return getChatModelFromConfig(llm_config=llm_config)
+        return getChatModelFromConfig(llm_config=llm_config)
 
     def registry_actions(self, actions):
         '''registry llm's actions'''
@@ -212,171 +209,3 @@ class BaseAgent:
 
     def get_memory_str(self, content_key="role_content"):
         return "\n".join([": ".join(i) for i in self.memory.to_tuple_messages(content_key="step_content")])
-    
-
-    def create_prompt(
-            self, query: Message, memory: Memory =None, history: Memory = None, background: Memory = None, memory_pool: Memory=None, prompt_mamnger=None) -> str:
-        '''
-        prompt engineer, contains role\task\tools\docs\memory
-        '''
-        #
-        doc_infos = self.create_doc_prompt(query)
-        code_infos = self.create_codedoc_prompt(query)
-        #
-        formatted_tools, tool_names, _ = self.create_tools_prompt(query)
-        task_prompt = self.create_task_prompt(query)
-        background_prompt = self.create_background_prompt(background, control_key="step_content")
-        history_prompt = self.create_history_prompt(history)
-        selfmemory_prompt = self.create_selfmemory_prompt(memory, control_key="step_content")
-
-        # extra_system_prompt = self.role.role_prompt
-
-
-        prompt = self.role.role_prompt.format(**{"formatted_tools": formatted_tools, "tool_names": tool_names})
-        #
-        memory_pool_select_by_agent_key = self.select_memory_by_agent_key(memory_pool)
-        memory_pool_select_by_agent_key_context = '\n\n'.join([f"*{k}*\n{v}" for parsed_output in memory_pool_select_by_agent_key.get_parserd_output_list() for k, v in parsed_output.items() if k not in ['Action Status']])
-
-        # input_query = query.input_query
-
-        # # logger.debug(f"{self.role.role_name} extra_system_prompt: {self.role.role_prompt}")
-        # # logger.debug(f"{self.role.role_name} input_query: {input_query}")
-        # # logger.debug(f"{self.role.role_name} doc_infos: {doc_infos}")
-        # # logger.debug(f"{self.role.role_name} tool_names: {tool_names}")
-        # if "**Context:**" in self.role.role_prompt:
-        #     # logger.debug(f"parsed_output_list: {query.parsed_output_list}")
-        #     # input_query = "'''" + "\n".join([f"###{k}###\n{v}" for i in query.parsed_output_list for k,v in i.items() if "Action Status" !=k]) + "'''"
-        #     context = "\n".join([f"*{k}*\n{v}" for i in query.parsed_output_list for k,v in i.items() if "Action Status" !=k])
-        #     # context = history_prompt or '""'
-        #     # logger.debug(f"parsed_output_list: {t}")
-        #     prompt += "\n" + QUERY_CONTEXT_PROMPT_INPUT.format(**{"context": context, "query": query.origin_query})
-        # else:
-        #     prompt += "\n" + PLAN_PROMPT_INPUT.format(**{"query": input_query})
-
-        task = query.task or self.task
-        if task_prompt is not None:
-            prompt += "\n" + task.task_prompt
-
-        DocInfos = ""
-        if doc_infos is not None and doc_infos!="" and doc_infos!="不存在知识库辅助信息":
-            DocInfos += f"\nDocument Information: {doc_infos}"
-
-        if code_infos is not None and code_infos!="" and code_infos!="不存在代码库辅助信息":
-            DocInfos += f"\nCodeBase Infomation: {code_infos}"
-
-        # if selfmemory_prompt:
-        #     prompt += "\n" + selfmemory_prompt
-
-        # if background_prompt:
-        #     prompt += "\n" + background_prompt
-
-        # if history_prompt:
-        #     prompt += "\n" + history_prompt
-
-        input_query = query.input_query
-
-        # logger.debug(f"{self.role.role_name} extra_system_prompt: {self.role.role_prompt}")
-        # logger.debug(f"{self.role.role_name} input_query: {input_query}")
-        # logger.debug(f"{self.role.role_name} doc_infos: {doc_infos}")
-        # logger.debug(f"{self.role.role_name} tool_names: {tool_names}")
-
-        # extra_system_prompt = self.role.role_prompt
-        input_keys = parse_section(self.role.role_prompt, 'Input Format')
-        prompt = self.role.role_prompt.format(**{"formatted_tools": formatted_tools, "tool_names": tool_names})
-        prompt += "\n" + BEGIN_PROMPT_INPUT
-        for input_key in input_keys:
-            if input_key == "Origin Query":
-                prompt += "\n**Origin Query:**\n" + query.origin_query
-            elif input_key == "Context":
-                context = "\n".join([f"*{k}*\n{v}" for i in query.parsed_output_list for k,v in i.items() if "Action Status" !=k])
-                if history:
-                    context = history_prompt + "\n" + context
-                if not context:
-                    context = "there is no context"
-
-                if self.focus_agents and memory_pool_select_by_agent_key_context:
-                    context = memory_pool_select_by_agent_key_context
-                prompt += "\n**Context:**\n" + context + "\n" + input_query
-            elif input_key == "DocInfos":
-                if DocInfos:
-                    prompt += "\n**DocInfos:**\n" + DocInfos
-                else:
-                    prompt += "\n**DocInfos:**\n" + "Empty"
-            elif input_key == "Question":
-                prompt += "\n**Question:**\n" + input_query
-
-        # if "**Context:**" in self.role.role_prompt:
-        #     # logger.debug(f"parsed_output_list: {query.parsed_output_list}")
-        #     # input_query = "'''" + "\n".join([f"###{k}###\n{v}" for i in query.parsed_output_list for k,v in i.items() if "Action Status" !=k]) + "'''"
-        #     context = "\n".join([f"*{k}*\n{v}" for i in query.parsed_output_list for k,v in i.items() if "Action Status" !=k])
-        #     if history:
-        #         context = history_prompt + "\n" + context
-
-        #     if not context:
-        #         context = "there is no context"
-
-        #     # logger.debug(f"parsed_output_list: {t}")
-        #     if "DocInfos" in prompt:
-        #         prompt += "\n" + QUERY_CONTEXT_DOC_PROMPT_INPUT.format(**{"context": context, "query": query.origin_query, "DocInfos": DocInfos})
-        #     else:
-        #         prompt += "\n" + QUERY_CONTEXT_PROMPT_INPUT.format(**{"context": context, "query": query.origin_query, "DocInfos": DocInfos})
-        # else:
-        #     prompt += "\n" + BASE_PROMPT_INPUT.format(**{"query": input_query})
-
-        # prompt = extra_system_prompt.format(**{"query": input_query, "doc_infos": doc_infos, "formatted_tools": formatted_tools, "tool_names": tool_names})
-        while "{{" in prompt or "}}" in prompt:
-            prompt = prompt.replace("{{", "{")
-            prompt = prompt.replace("}}", "}")
-
-        # logger.debug(f"{self.role.role_name} prompt: {prompt}")
-        return prompt
-
-    def create_doc_prompt(self, message: Message) -> str:
-        ''''''
-        db_docs = message.db_docs
-        search_docs = message.search_docs
-        doc_infos = "\n".join([doc.get_snippet() for doc in db_docs] + [doc.get_snippet() for doc in search_docs])
-        return doc_infos or "不存在知识库辅助信息"
-
-    def create_codedoc_prompt(self, message: Message) -> str:
-        ''''''
-        code_docs = message.code_docs
-        doc_infos = "\n".join([doc.get_code() for doc in code_docs])
-        return doc_infos or "不存在代码库辅助信息"
-
-    def create_tools_prompt(self, message: Message) -> str:
-        tools = message.tools
-        tool_strings = []
-        tools_descs = []
-        for tool in tools:
-            args_schema = re.sub("}", "}}}}", re.sub("{", "{{{{", str(tool.args)))
-            tool_strings.append(f"{tool.name}: {tool.description}, args: {args_schema}")
-            tools_descs.append(f"{tool.name}: {tool.description}")
"\n".join(tool_strings) - tools_desc_str = "\n".join(tools_descs) - tool_names = ", ".join([tool.name for tool in tools]) - return formatted_tools, tool_names, tools_desc_str - - def create_task_prompt(self, message: Message) -> str: - task = message.task or self.task - return "\n任务目标: " + task.task_prompt if task is not None else None - - def create_background_prompt(self, background: Memory, control_key="role_content") -> str: - background_message = None if background is None else background.to_str_messages(content_key=control_key) - # logger.debug(f"background_message: {background_message}") - if background_message: - background_message = re.sub("}", "}}", re.sub("{", "{{", background_message)) - return "\n背景信息: " + background_message if background_message else None - - def create_history_prompt(self, history: Memory, control_key="role_content") -> str: - history_message = None if history is None else history.to_str_messages(content_key=control_key) - if history_message: - history_message = re.sub("}", "}}", re.sub("{", "{{", history_message)) - return "\n补充对话信息: " + history_message if history_message else None - - def create_selfmemory_prompt(self, selfmemory: Memory, control_key="role_content") -> str: - selfmemory_message = None if selfmemory is None else selfmemory.to_str_messages(content_key=control_key) - if selfmemory_message: - selfmemory_message = re.sub("}", "}}", re.sub("{", "{{", selfmemory_message)) - return "\n补充自身对话信息: " + selfmemory_message if selfmemory_message else None - \ No newline at end of file diff --git a/coagent/connector/agents/executor_agent.py b/coagent/connector/agents/executor_agent.py index 4a41aef..7c714fe 100644 --- a/coagent/connector/agents/executor_agent.py +++ b/coagent/connector/agents/executor_agent.py @@ -2,14 +2,15 @@ from typing import List, Union import copy from loguru import logger +from langchain.schema import BaseRetriever + from coagent.connector.schema import ( Memory, Task, Env, Role, Message, ActionStatus, PromptField, LogVerboseEnum ) from coagent.connector.memory_manager import BaseMemoryManager -from coagent.connector.configs.prompts import BEGIN_PROMPT_INPUT from coagent.llm_models import LLMConfig, EmbedConfig from coagent.connector.memory_manager import LocalMemoryManager - +from coagent.base_configs.env_config import JUPYTER_WORK_PATH, KB_ROOT_PATH from .base_agent import BaseAgent @@ -17,7 +18,7 @@ class ExecutorAgent(BaseAgent): def __init__( self, role: Role, - prompt_config: [PromptField], + prompt_config: List[PromptField], prompt_manager_type: str= "PromptManager", task: Task = None, memory: Memory = None, @@ -28,14 +29,17 @@ class ExecutorAgent(BaseAgent): llm_config: LLMConfig = None, embed_config: EmbedConfig = None, sandbox_server: dict = {}, - jupyter_work_path: str = "", - kb_root_path: str = "", + jupyter_work_path: str = JUPYTER_WORK_PATH, + kb_root_path: str = KB_ROOT_PATH, + doc_retrieval: Union[BaseRetriever] = None, + code_retrieval = None, + search_retrieval = None, log_verbose: str = "0" ): super().__init__(role, prompt_config, prompt_manager_type, task, memory, chat_turn, focus_agents, focus_message_keys, llm_config, embed_config, sandbox_server, - jupyter_work_path, kb_root_path, log_verbose + jupyter_work_path, kb_root_path, doc_retrieval, code_retrieval, search_retrieval, log_verbose ) self.do_all_task = True # run all tasks @@ -45,6 +49,7 @@ class ExecutorAgent(BaseAgent): task_executor_memory = Memory(messages=[]) # insert query output_message = Message( + user_name=query.user_name, 
@@ -45,6 +49,7 @@ class ExecutorAgent(BaseAgent):
         task_executor_memory = Memory(messages=[])
         # insert query
         output_message = Message(
+            user_name=query.user_name,
             role_name=self.role.role_name,
             role_type="assistant", #self.role.role_type,
             role_content=query.input_query,
@@ -115,7 +120,7 @@ class ExecutorAgent(BaseAgent):
             history: Memory, background: Memory, memory_manager: BaseMemoryManager,
             task_memory: Memory) -> Union[Message, Memory]:
         '''execute the llm predict by created prompt'''
-        memory_pool = memory_manager.current_memory
+        memory_pool = memory_manager.get_memory_pool(query.user_name)
         prompt = self.prompt_manager.generate_full_prompt(
             previous_agent_message=query, agent_long_term_memory=self_memory, ui_history=history, chain_summary_messages=background, memory_pool=memory_pool,
             task_memory=task_memory)
diff --git a/coagent/connector/agents/react_agent.py b/coagent/connector/agents/react_agent.py
index 64e23f1..1ade305 100644
--- a/coagent/connector/agents/react_agent.py
+++ b/coagent/connector/agents/react_agent.py
@@ -3,23 +3,23 @@ import traceback
 import copy
 from loguru import logger
 
+from langchain.schema import BaseRetriever
+
 from coagent.connector.schema import (
     Memory, Task, Env, Role, Message, ActionStatus, PromptField, LogVerboseEnum
 )
 from coagent.connector.memory_manager import BaseMemoryManager
-from coagent.connector.configs.agent_config import REACT_PROMPT_INPUT
 from coagent.llm_models import LLMConfig, EmbedConfig
 from .base_agent import BaseAgent
 from coagent.connector.memory_manager import LocalMemoryManager
-
-from coagent.connector.prompt_manager import PromptManager
+from coagent.base_configs.env_config import JUPYTER_WORK_PATH, KB_ROOT_PATH
 
 
 class ReactAgent(BaseAgent):
     def __init__(
             self, 
             role: Role,
-            prompt_config: [PromptField],
+            prompt_config: List[PromptField],
             prompt_manager_type: str = "PromptManager",
             task: Task = None,
             memory: Memory = None,
@@ -30,14 +30,17 @@ class ReactAgent(BaseAgent):
             llm_config: LLMConfig = None,
             embed_config: EmbedConfig = None,
             sandbox_server: dict = {},
-            jupyter_work_path: str = "",
-            kb_root_path: str = "",
+            jupyter_work_path: str = JUPYTER_WORK_PATH,
+            kb_root_path: str = KB_ROOT_PATH,
+            doc_retrieval: Union[BaseRetriever] = None,
+            code_retrieval = None,
+            search_retrieval = None,
             log_verbose: str = "0"
             ):
 
         super().__init__(role, prompt_config, prompt_manager_type, task, memory, chat_turn,
                          focus_agents, focus_message_keys, llm_config, embed_config, sandbox_server,
-                         jupyter_work_path, kb_root_path, log_verbose
+                         jupyter_work_path, kb_root_path, doc_retrieval, code_retrieval, search_retrieval, log_verbose
                          )
 
     def step(self, query: Message, history: Memory = None, background: Memory = None, memory_manager: BaseMemoryManager = None) -> Message:
@@ -52,6 +55,7 @@ class ReactAgent(BaseAgent):
         react_memory = Memory(messages=[])
         # insert query
         output_message = Message(
+            user_name=query.user_name,
             role_name=self.role.role_name,
             role_type="assistant", #self.role.role_type,
             role_content=query.input_query,
@@ -84,9 +88,7 @@ class ReactAgent(BaseAgent):
                     llm_config=self.embed_config
                 )
                 memory_manager.append(query)
-                memory_pool = memory_manager.current_memory
-            else:
-                memory_pool = memory_manager.current_memory
+            memory_pool = memory_manager.get_memory_pool(query_c.user_name)
 
             prompt = self.prompt_manager.generate_full_prompt(
                 previous_agent_message=query_c, agent_long_term_memory=self.memory, ui_history=history, chain_summary_messages=background, react_memory=react_memory,
@@ -142,82 +144,4 @@ class ReactAgent(BaseAgent):
             title = f"<<<<{self.role.role_name}'s prompt>>>>"
             print("#"*len(title) + f"\n{title}\n"+ "#"*len(title)+ f"\n\n{prompt}\n\n")
 
-    # def create_prompt(
-    #         self, query: Message, memory: Memory =None, history: Memory = None, background: Memory = None, react_memory: Memory = None, memory_manager: BaseMemoryManager= None,
-    #         prompt_mamnger=None) -> str:
-    #     prompt_mamnger = PromptManager()
-    #     prompt_mamnger.register_standard_fields()
-
-    #     # input_keys = parse_section(self.role.role_prompt, 'Agent Profile')
-
-    #     data_dict = {
-    #         "agent_profile": extract_section(self.role.role_prompt, 'Agent Profile'),
-    #         "tool_information": query.tools,
-    #         "session_records": memory_manager,
-    #         "reference_documents": query,
-    #         "output_format": extract_section(self.role.role_prompt, 'Response Output Format'),
-    #         "response": "\n".join(["\n".join([f"**{k}:**\n{v}" for k,v in _dict.items()]) for _dict in react_memory.get_parserd_output()]),
-    #     }
-    #     # logger.debug(memory_pool)
-
-    #     return prompt_mamnger.generate_full_prompt(data_dict)
-
-    def create_prompt(
-            self, query: Message, memory: Memory =None, history: Memory = None, background: Memory = None, react_memory: Memory = None, memory_pool: Memory= None,
-            prompt_mamnger=None) -> str:
-        '''
-        role\task\tools\docs\memory
-        '''
-        #
-        doc_infos = self.create_doc_prompt(query)
-        code_infos = self.create_codedoc_prompt(query)
-        #
-        formatted_tools, tool_names, _ = self.create_tools_prompt(query)
-        task_prompt = self.create_task_prompt(query)
-        background_prompt = self.create_background_prompt(background)
-        history_prompt = self.create_history_prompt(history)
-        selfmemory_prompt = self.create_selfmemory_prompt(memory, control_key="step_content")
-        #
-        # extra_system_prompt = self.role.role_prompt
-        prompt = self.role.role_prompt.format(**{"formatted_tools": formatted_tools, "tool_names": tool_names})
-
-        # react 流程是自身迭代过程,另外二次触发的是需要作为历史对话信息
-        # input_query = react_memory.to_tuple_messages(content_key="step_content")
-        # # input_query = query.input_query + "\n" + "\n".join([f"{v}" for k, v in input_query if v])
-        # input_query = "\n".join([f"{v}" for k, v in input_query if v])
-        input_query = "\n".join(["\n".join([f"**{k}:**\n{v}" for k,v in _dict.items()]) for _dict in react_memory.get_parserd_output()])
-        # logger.debug(f"input_query: {input_query}")
-
-        prompt += "\n" + REACT_PROMPT_INPUT.format(**{"query": input_query})
-
-        task = query.task or self.task
-        # if task_prompt is not None:
-        #     prompt += "\n" + task.task_prompt
-
-        # if doc_infos is not None and doc_infos!="" and doc_infos!="不存在知识库辅助信息":
-        #     prompt += f"\n知识库信息: {doc_infos}"
-
-        # if code_infos is not None and code_infos!="" and code_infos!="不存在代码库辅助信息":
-        #     prompt += f"\n代码库信息: {code_infos}"
-
-        # if background_prompt:
-        #     prompt += "\n" + background_prompt
-
-        # if history_prompt:
-        #     prompt += "\n" + history_prompt
-
-        # if selfmemory_prompt:
-        #     prompt += "\n" + selfmemory_prompt
-
-        # logger.debug(f"{self.role.role_name} extra_system_prompt: {self.role.role_prompt}")
-        # logger.debug(f"{self.role.role_name} input_query: {input_query}")
-        # logger.debug(f"{self.role.role_name} doc_infos: {doc_infos}")
-        # logger.debug(f"{self.role.role_name} tool_names: {tool_names}")
-        # prompt += "\n" + REACT_PROMPT_INPUT.format(**{"query": input_query})
-
-        # prompt = extra_system_prompt.format(**{"query": input_query, "doc_infos": doc_infos, "formatted_tools": formatted_tools, "tool_names": tool_names})
-        while "{{" in prompt or "}}" in prompt:
-            prompt = prompt.replace("{{", "{")
-            prompt = prompt.replace("}}", "}")
-        return prompt
diff --git a/coagent/connector/agents/selector_agent.py b/coagent/connector/agents/selector_agent.py
index 76a2012..17e6ce5 100644
--- a/coagent/connector/agents/selector_agent.py
+++ b/coagent/connector/agents/selector_agent.py
@@ -3,13 +3,15 @@ import copy
 import random
 from loguru import logger
 
+from langchain.schema import BaseRetriever
+
 from coagent.connector.schema import (
     Memory, Task, Role, Message, PromptField, LogVerboseEnum
 )
 from coagent.connector.memory_manager import BaseMemoryManager
-from coagent.connector.configs.prompts import BEGIN_PROMPT_INPUT
 from coagent.connector.memory_manager import LocalMemoryManager
 from coagent.llm_models import LLMConfig, EmbedConfig
+from coagent.base_configs.env_config import JUPYTER_WORK_PATH, KB_ROOT_PATH
 
 from .base_agent import BaseAgent
 
@@ -30,14 +32,17 @@ class SelectorAgent(BaseAgent):
             llm_config: LLMConfig = None,
             embed_config: EmbedConfig = None,
             sandbox_server: dict = {},
-            jupyter_work_path: str = "",
-            kb_root_path: str = "",
+            jupyter_work_path: str = JUPYTER_WORK_PATH,
+            kb_root_path: str = KB_ROOT_PATH,
+            doc_retrieval: Union[BaseRetriever] = None,
+            code_retrieval = None,
+            search_retrieval = None,
             log_verbose: str = "0"
             ):
 
         super().__init__(role, prompt_config, prompt_manager_type, task, memory, chat_turn,
                          focus_agents, focus_message_keys, llm_config, embed_config, sandbox_server,
-                         jupyter_work_path, kb_root_path, log_verbose
+                         jupyter_work_path, kb_root_path, doc_retrieval, code_retrieval, search_retrieval, log_verbose
                          )
 
         self.group_agents = group_agents
@@ -56,9 +61,8 @@ class SelectorAgent(BaseAgent):
                     llm_config=self.embed_config
                 )
                 memory_manager.append(query)
-                memory_pool = memory_manager.current_memory
-            else:
-                memory_pool = memory_manager.current_memory
+            memory_pool = memory_manager.get_memory_pool(query_c.user_name)
+
             prompt = self.prompt_manager.generate_full_prompt(
                 previous_agent_message=query_c, agent_long_term_memory=self.memory, ui_history=history, chain_summary_messages=background, react_memory=None,
                 memory_pool=memory_pool, agents=self.group_agents)
@@ -90,6 +94,9 @@ class SelectorAgent(BaseAgent):
             for agent in self.group_agents:
                 if agent.role.role_name == select_message.parsed_output.get("Role", ""):
                     break
+
+            # 把除了role以外的信息传给下一个agent
+            query_c.parsed_output.update({k:v for k,v in select_message.parsed_output.items() if k!="Role"})
             for output_message in agent.astep(query_c, history, background=background, memory_manager=memory_manager):
                 yield output_message or select_message
             # update self_memory
@@ -103,6 +110,7 @@ class SelectorAgent(BaseAgent):
                 memory_manager.append(output_message)
 
             select_message.parsed_output = output_message.parsed_output
+            select_message.spec_parsed_output.update(output_message.spec_parsed_output)
             select_message.parsed_output_list.extend(output_message.parsed_output_list)
         yield select_message
@@ -114,77 +122,4 @@ class SelectorAgent(BaseAgent):
             print("#"*len(title) + f"\n{title}\n"+ "#"*len(title)+ f"\n\n{prompt}\n\n")
 
         for agent in self.group_agents:
-            agent.pre_print(query=query, history=history, background=background, memory_manager=memory_manager)
-
-    # def create_prompt(
-    #     self, query: Message, memory: Memory =None, history: Memory = None, background: Memory = None, memory_manager: BaseMemoryManager=None, prompt_mamnger=None) -> str:
-    #     '''
-    #     role\task\tools\docs\memory
-    #     '''
-    #     #
-    #     doc_infos = self.create_doc_prompt(query)
-    #     code_infos = self.create_codedoc_prompt(query)
-    #     #
-    #     formatted_tools, tool_names, tools_descs = self.create_tools_prompt(query)
-    #     agent_names, agents = self.create_agent_names()
-    #     task_prompt = self.create_task_prompt(query)
-    #     background_prompt = self.create_background_prompt(background)
-    #     history_prompt = self.create_history_prompt(history)
-    #     selfmemory_prompt = self.create_selfmemory_prompt(memory, control_key="step_content")
-
-
-    #     DocInfos = ""
-    #     if doc_infos is not None and doc_infos!="" and doc_infos!="不存在知识库辅助信息":
-    #         DocInfos += f"\nDocument Information: {doc_infos}"
-
-    #     if code_infos is not None and code_infos!="" and code_infos!="不存在代码库辅助信息":
-    #         DocInfos += f"\nCodeBase Infomation: {code_infos}"
-
-    #     input_query = query.input_query
-    #     logger.debug(f"{self.role.role_name} input_query: {input_query}")
-    #     prompt = self.role.role_prompt.format(**{"agent_names": agent_names, "agents": agents, "formatted_tools": tools_descs, "tool_names": tool_names})
-    #     #
-    #     memory_pool_select_by_agent_key = self.select_memory_by_agent_key(memory_manager.current_memory)
-    #     memory_pool_select_by_agent_key_context = '\n\n'.join([f"*{k}*\n{v}" for parsed_output in memory_pool_select_by_agent_key.get_parserd_output_list() for k, v in parsed_output.items() if k not in ['Action Status']])
-
-    #     input_keys = parse_section(self.role.role_prompt, 'Input Format')
-    #     #
-    #     prompt += "\n" + BEGIN_PROMPT_INPUT
-    #     for input_key in input_keys:
-    #         if input_key == "Origin Query":
-    #             prompt += "\n**Origin Query:**\n" + query.origin_query
-    #         elif input_key == "Context":
-    #             context = "\n".join([f"*{k}*\n{v}" for i in query.parsed_output_list for k,v in i.items() if "Action Status" !=k])
-    #             if history:
-    #                 context = history_prompt + "\n" + context
-    #             if not context:
-    #                 context = "there is no context"
-
-    #             if self.focus_agents and memory_pool_select_by_agent_key_context:
-    #                 context = memory_pool_select_by_agent_key_context
-    #             prompt += "\n**Context:**\n" + context + "\n" + input_query
-    #         elif input_key == "DocInfos":
-    #             prompt += "\n**DocInfos:**\n" + DocInfos
-    #         elif input_key == "Question":
-    #             prompt += "\n**Question:**\n" + input_query
-
-    #     while "{{" in prompt or "}}" in prompt:
-    #         prompt = prompt.replace("{{", "{")
-    #         prompt = prompt.replace("}}", "}")
-
-    #     # logger.debug(f"{self.role.role_name} prompt: {prompt}")
-    #     return prompt
-
-    # def create_agent_names(self):
-    #     random.shuffle(self.group_agents)
-    #     agent_names = ", ".join([f'{agent.role.role_name}' for agent in self.group_agents])
-    #     agent_descs = []
-    #     for agent in self.group_agents:
-    #         role_desc = agent.role.role_prompt.split("####")[1]
-    #         while "\n\n" in role_desc:
-    #             role_desc = role_desc.replace("\n\n", "\n")
-    #         role_desc = role_desc.replace("\n", ",")
-
-    #         agent_descs.append(f'"role name: {agent.role.role_name}\nrole description: {role_desc}"')
-
-    #     return agent_names, "\n".join(agent_descs)
\ No newline at end of file
+            agent.pre_print(query=query, history=history, background=background, memory_manager=memory_manager)
\ No newline at end of file
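Across base_agent, executor_agent, react_agent and selector_agent above, the shared
memory pool is now fetched per user rather than through
memory_manager.current_memory. A rough sketch of the new contract, assuming an
already-built LocalMemoryManager and a coagent Message named message:

    # append() files the message under message.user_name;
    # get_memory_pool() hands back that user's Memory object
    memory_manager.append(message)
    memory_pool = memory_manager.get_memory_pool(message.user_name)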
diff --git a/coagent/connector/antflow/__init__.py b/coagent/connector/antflow/__init__.py
new file mode 100644
index 0000000..633d975
--- /dev/null
+++ b/coagent/connector/antflow/__init__.py
@@ -0,0 +1,7 @@
+from .flow import AgentFlow, PhaseFlow, ChainFlow
+
+
+
+__all__ = [
+    "AgentFlow", "PhaseFlow", "ChainFlow"
+]
\ No newline at end of file
diff --git a/coagent/connector/antflow/flow.py b/coagent/connector/antflow/flow.py
new file mode 100644
index 0000000..0b131f6
--- /dev/null
+++ b/coagent/connector/antflow/flow.py
@@ -0,0 +1,255 @@
+import importlib
+from typing import List, Union, Dict, Any
+from loguru import logger
+import os
+from langchain.embeddings.base import Embeddings
+from langchain.agents import Tool
+from langchain.llms.base import BaseLLM, LLM
+
+from coagent.retrieval.base_retrieval import IMRertrieval
+from coagent.llm_models.llm_config import EmbedConfig, LLMConfig
+from coagent.connector.phase import BasePhase
+from coagent.connector.agents import BaseAgent
+from coagent.connector.chains import BaseChain
+from coagent.connector.schema import Message, Role, PromptField, ChainConfig
+from coagent.tools import toLangchainTools, TOOL_DICT, TOOL_SETS
+
+
+class AgentFlow:
+    def __init__(
+        self, 
+        role_name: str,
+        agent_type: str,
+        role_type: str = "assistant",
+        agent_index: int = 0,
+        role_prompt: str = "",
+        prompt_config: List[Dict[str, Any]] = [],
+        prompt_manager_type: str = "PromptManager",
+        chat_turn: int = 3,
+        focus_agents: List[str] = [],
+        focus_messages: List[str] = [],
+        embeddings: Embeddings = None,
+        llm: BaseLLM = None,
+        doc_retrieval: IMRertrieval = None,
+        code_retrieval: IMRertrieval = None,
+        search_retrieval: IMRertrieval = None,
+        **kwargs
+    ):
+        self.role_type = role_type
+        self.role_name = role_name
+        self.agent_type = agent_type
+        self.role_prompt = role_prompt
+        self.agent_index = agent_index
+
+        self.prompt_config = prompt_config
+        self.prompt_manager_type = prompt_manager_type
+
+        self.chat_turn = chat_turn
+        self.focus_agents = focus_agents
+        self.focus_messages = focus_messages
+
+        self.embeddings = embeddings
+        self.llm = llm
+        self.doc_retrieval = doc_retrieval
+        self.code_retrieval = code_retrieval
+        self.search_retrieval = search_retrieval
+        # self.build_config()
+        # self.build_agent()
+
+    def build_config(self, embeddings: Embeddings = None, llm: BaseLLM = None):
+        self.llm_config = LLMConfig(model_name="test", llm=self.llm or llm)
+        self.embed_config = EmbedConfig(embed_model="test", langchain_embeddings=self.embeddings or embeddings)
+
+    def build_agent(self, 
+                    embeddings: Embeddings = None, llm: BaseLLM = None,
+                    doc_retrieval: IMRertrieval = None,
+                    code_retrieval: IMRertrieval = None,
+                    search_retrieval: IMRertrieval = None,
+                    ):
+        # 可注册个性化的agent,仅通过start_action和end_action来注册
+        # class ExtraAgent(BaseAgent):
+        #     def start_action_step(self, message: Message) -> Message:
+        #         pass
+
+        #     def end_action_step(self, message: Message) -> Message:
+        #         pass
+        # agent_module = importlib.import_module("coagent.connector.agents")
+        # setattr(agent_module, 'extraAgent', ExtraAgent)
+
+        # 可注册个性化的prompt组装方式,
+        # class CodeRetrievalPM(PromptManager):
+        #     def handle_code_packages(self, **kwargs) -> str:
+        #         if 'previous_agent_message' not in kwargs:
+        #             return ""
+        #         previous_agent_message: Message = kwargs['previous_agent_message']
+        #         # 由于两个agent共用了同一个manager,所以临时性处理
+        #         vertices = previous_agent_message.customed_kargs.get("RelatedVerticesRetrivalRes", {}).get("vertices", [])
+        #         return ", ".join([str(v) for v in vertices])
+
+        # prompt_manager_module = importlib.import_module("coagent.connector.prompt_manager")
+        # setattr(prompt_manager_module, 'CodeRetrievalPM', CodeRetrievalPM)
+
+        # agent实例化
+        agent_module = importlib.import_module("coagent.connector.agents")
+        baseAgent: BaseAgent = getattr(agent_module, self.agent_type)
+        role = Role(
+            role_type=self.agent_type, role_name=self.role_name, 
+            agent_type=self.agent_type, role_prompt=self.role_prompt,
+        )
+
+        self.build_config(embeddings, llm)
+        self.agent = baseAgent(
+            role=role,
+            prompt_config = [PromptField(**config) for config in self.prompt_config],
+            prompt_manager_type=self.prompt_manager_type,
+            chat_turn=self.chat_turn,
+            focus_agents=self.focus_agents,
+            focus_message_keys=self.focus_messages,
+            llm_config=self.llm_config,
+            embed_config=self.embed_config,
+            doc_retrieval=doc_retrieval or self.doc_retrieval,
+            code_retrieval=code_retrieval or self.code_retrieval,
+            search_retrieval=search_retrieval or self.search_retrieval,
+        )
+
+
+            focus_message_keys=self.focus_messages,
+            llm_config=self.llm_config,
+            embed_config=self.embed_config,
+            doc_retrieval=doc_retrieval or self.doc_retrieval,
+            code_retrieval=code_retrieval or self.code_retrieval,
+            search_retrieval=search_retrieval or self.search_retrieval,
+        )
+
+
+
+class ChainFlow:
+    def __init__(
+        self,
+        chain_name: str,
+        chain_index: int = 0,
+        agent_flows: List[AgentFlow] = [],
+        chat_turn: int = 5,
+        do_checker: bool = False,
+        embeddings: Embeddings = None,
+        llm: BaseLLM = None,
+        doc_retrieval: IMRertrieval = None,
+        code_retrieval: IMRertrieval = None,
+        search_retrieval: IMRertrieval = None,
+        # chain_type: str = "BaseChain",
+        **kwargs
+    ):
+        self.agent_flows = sorted(agent_flows, key=lambda x:x.agent_index)
+        self.chat_turn = chat_turn
+        self.do_checker = do_checker
+        self.chain_name = chain_name
+        self.chain_index = chain_index
+        self.chain_type = "BaseChain"
+
+        self.embeddings = embeddings
+        self.llm = llm
+
+        self.doc_retrieval = doc_retrieval
+        self.code_retrieval = code_retrieval
+        self.search_retrieval = search_retrieval
+        # self.build_config()
+        # self.build_chain()
+
+    def build_config(self, embeddings: Embeddings = None, llm: BaseLLM = None):
+        self.llm_config = LLMConfig(model_name="test", llm=self.llm or llm)
+        self.embed_config = EmbedConfig(embed_model="test", langchain_embeddings=self.embeddings or embeddings)
+
+    def build_chain(self,
+                    embeddings: Embeddings = None, llm: BaseLLM = None,
+                    doc_retrieval: IMRertrieval = None,
+                    code_retrieval: IMRertrieval = None,
+                    search_retrieval: IMRertrieval = None,
+                    ):
+        # instantiate the chain
+        chain_module = importlib.import_module("coagent.connector.chains")
+        baseChain: BaseChain = getattr(chain_module, self.chain_type)
+
+        agent_names = [agent_flow.role_name for agent_flow in self.agent_flows]
+        chain_config = ChainConfig(chain_name=self.chain_name, agents=agent_names, do_checker=self.do_checker, chat_turn=self.chat_turn)
+
+        # instantiate the agents
+        self.build_config(embeddings, llm)
+        for agent_flow in self.agent_flows:
+            agent_flow.build_agent(embeddings, llm)
+
+        self.chain = baseChain(
+            chain_config,
+            [agent_flow.agent for agent_flow in self.agent_flows],
+            embed_config=self.embed_config,
+            llm_config=self.llm_config,
+            doc_retrieval=doc_retrieval or self.doc_retrieval,
+            code_retrieval=code_retrieval or self.code_retrieval,
+            search_retrieval=search_retrieval or self.search_retrieval,
+        )
+
+class PhaseFlow:
+    def __init__(
+        self,
+        phase_name: str,
+        chain_flows: List[ChainFlow],
+        embeddings: Embeddings = None,
+        llm: BaseLLM = None,
+        tools: List[Tool] = [],
+        doc_retrieval: IMRertrieval = None,
+        code_retrieval: IMRertrieval = None,
+        search_retrieval: IMRertrieval = None,
+        **kwargs
+    ):
+        self.phase_name = phase_name
+        self.chain_flows = sorted(chain_flows, key=lambda x:x.chain_index)
+        self.phase_type = "BasePhase"
+        self.tools = tools
+
+        self.embeddings = embeddings
+        self.llm = llm
+
+        self.doc_retrieval = doc_retrieval
+        self.code_retrieval = code_retrieval
+        self.search_retrieval = search_retrieval
+        # self.build_config()
+        self.build_phase()
+
+    def __call__(self, params: dict) -> str:
+
+        # tools = toLangchainTools([TOOL_DICT[i] for i in TOOL_SETS if i in TOOL_DICT])
+        # query_content = "Help me check whether the server 127.0.0.1 showed any anomaly at 10 o'clock; please make a judgement"
+        try:
+            logger.info(f"params: {params}")
+            query_content = params.get("query") or params.get("input")
+            search_type = params.get("search_type")
+            query = Message(
+                role_name="human", role_type="user", tools=self.tools,
+                role_content=query_content,
+                input_query=query_content, origin_query=query_content,
+                cb_search_type=search_type,
+            )
+            # phase.pre_print(query)
+            output_message, output_memory = self.phase.step(query)
+            output_content = "\n\n".join((output_memory.to_str_messages(return_all=True, content_key="parsed_output_list").split("\n\n")[1:])) or output_message.role_content
+            return output_content
+        except Exception as e:
+            logger.exception(e)
+            return f"Error {e}"
+
+    def build_config(self, embeddings: Embeddings = None, llm: BaseLLM = None):
+        self.llm_config = LLMConfig(model_name="test", llm=self.llm or llm)
+        self.embed_config = EmbedConfig(embed_model="test", langchain_embeddings=self.embeddings or embeddings)
+
+    def build_phase(self, embeddings: Embeddings = None, llm: BaseLLM = None):
+        # instantiate the phase
+        phase_module = importlib.import_module("coagent.connector.phase")
+        basePhase: BasePhase = getattr(phase_module, self.phase_type)
+
+        # instantiate the chains
+        self.build_config(self.embeddings or embeddings, self.llm or llm)
+        os.environ["log_verbose"] = "2"
+        for chain_flow in self.chain_flows:
+            chain_flow.build_chain(
+                self.embeddings or embeddings, self.llm or llm,
+                self.doc_retrieval, self.code_retrieval, self.search_retrieval
+            )
+
+        self.phase: BasePhase = basePhase(
+            phase_name=self.phase_name,
+            chains=[chain_flow.chain for chain_flow in self.chain_flows],
+            embed_config=self.embed_config,
+            llm_config=self.llm_config,
+            doc_retrieval=self.doc_retrieval,
+            code_retrieval=self.code_retrieval,
+            search_retrieval=self.search_retrieval
+        )
diff --git a/coagent/connector/chains/base_chain.py b/coagent/connector/chains/base_chain.py
index 7840ffb..7dc986c 100644
--- a/coagent/connector/chains/base_chain.py
+++ b/coagent/connector/chains/base_chain.py
@@ -1,9 +1,10 @@
-from typing import List
+from typing import List, Tuple, Union
 from loguru import logger
 import copy, os
-from coagent.connector.agents import BaseAgent
+from langchain.schema import BaseRetriever
+from coagent.connector.agents import BaseAgent
 from coagent.connector.schema import (
     Memory, Role, Message, ActionStatus, ChainConfig,
     load_role_configs
@@ -11,31 +12,32 @@ from coagent.connector.schema import (
 from coagent.connector.memory_manager import BaseMemoryManager
 from coagent.connector.message_process import MessageUtils
 from coagent.llm_models.llm_config import LLMConfig, EmbedConfig
+from coagent.base_configs.env_config import JUPYTER_WORK_PATH, KB_ROOT_PATH
 from coagent.connector.configs.agent_config import AGETN_CONFIGS
 role_configs = load_role_configs(AGETN_CONFIGS)
 
-# from configs.model_config import JUPYTER_WORK_PATH
-# from configs.server_config import SANDBOX_SERVER
-
 class BaseChain:
     def __init__(
         self,
-        # chainConfig: ChainConfig,
+        chainConfig: ChainConfig,
         agents: List[BaseAgent],
-        chat_turn: int = 1,
-        do_checker: bool = False,
+        # chat_turn: int = 1,
+        # do_checker: bool = False,
         sandbox_server: dict = {},
-        jupyter_work_path: str = "",
-        kb_root_path: str = "",
+        jupyter_work_path: str = JUPYTER_WORK_PATH,
+        kb_root_path: str = KB_ROOT_PATH,
         llm_config: LLMConfig = LLMConfig(),
         embed_config: EmbedConfig = None,
+        doc_retrieval: Union[BaseRetriever] = None,
+        code_retrieval = None,
+        search_retrieval = None,
         log_verbose: str = "0"
     ) -> None:
-        # self.chainConfig = chainConfig
+        self.chainConfig = chainConfig
         self.agents: List[BaseAgent] = agents
-        self.chat_turn = chat_turn
-        self.do_checker = do_checker
+        self.chat_turn = chainConfig.chat_turn
+        self.do_checker = chainConfig.do_checker
         self.sandbox_server = sandbox_server
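The new antflow wrappers compose bottom-up: an AgentFlow describes a single role, a ChainFlow orders its agents by `agent_index`, and a PhaseFlow sorts chains by `chain_index`, wires in the LLM and embeddings, and becomes callable. A minimal usage sketch, assuming `my_llm` is a LangChain `BaseLLM` and `my_embeddings` a LangChain `Embeddings` instance (both illustrative names, not defined in this patch):

```python
from coagent.connector.antflow import AgentFlow, ChainFlow, PhaseFlow

# One assistant role; "BaseAgent" must name a class in coagent.connector.agents.
qa_agent = AgentFlow(
    role_name="codeAnswer",
    agent_type="BaseAgent",
    role_prompt="#### Agent Profile\nAnswer the user's question about the codebase.",
    agent_index=0,
)

# Chains sort their agents by agent_index; phases sort chains by chain_index.
qa_chain = ChainFlow(chain_name="codeAnswerChain", chain_index=0, agent_flows=[qa_agent])

phase = PhaseFlow(
    phase_name="codeAnswerPhase",
    chain_flows=[qa_chain],
    llm=my_llm,                # hypothetical LangChain BaseLLM
    embeddings=my_embeddings,  # hypothetical LangChain Embeddings
)

# PhaseFlow.__call__ wraps the dict into a Message and runs phase.step(...).
print(phase({"query": "Where is PhaseFlow defined?"}))
```

Note that `PhaseFlow.__call__` returns either the concatenated parsed outputs or the raw `role_content`, so it can be dropped into pipelines that expect a plain string.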
self.jupyter_work_path = jupyter_work_path self.llm_config = llm_config @@ -45,9 +47,11 @@ class BaseChain: task = None, memory = None, llm_config=llm_config, embed_config=embed_config, sandbox_server=sandbox_server, jupyter_work_path=jupyter_work_path, - kb_root_path=kb_root_path + kb_root_path=kb_root_path, + doc_retrieval=doc_retrieval, code_retrieval=code_retrieval, + search_retrieval=search_retrieval ) - self.messageUtils = MessageUtils(None, sandbox_server, self.jupyter_work_path, embed_config, llm_config, kb_root_path, log_verbose) + self.messageUtils = MessageUtils(None, sandbox_server, self.jupyter_work_path, embed_config, llm_config, kb_root_path, doc_retrieval, code_retrieval, search_retrieval, log_verbose) # all memory created by agent until instance deleted self.global_memory = Memory(messages=[]) @@ -62,13 +66,16 @@ class BaseChain: for agent in self.agents: agent.pre_print(query, history, background=background, memory_manager=memory_manager) - def astep(self, query: Message, history: Memory = None, background: Memory = None, memory_manager: BaseMemoryManager = None) -> Message: + def astep(self, query: Message, history: Memory = None, background: Memory = None, memory_manager: BaseMemoryManager = None) -> Tuple[Message, Memory]: '''execute chain''' local_memory = Memory(messages=[]) input_message = copy.deepcopy(query) step_nums = copy.deepcopy(self.chat_turn) check_message = None + # if input_message not in memory_manager: + # memory_manager.append(input_message) + self.global_memory.append(input_message) # local_memory.append(input_message) while step_nums > 0: @@ -78,7 +85,7 @@ class BaseChain: yield output_message, local_memory + output_message output_message = self.messageUtils.inherit_extrainfo(input_message, output_message) # according the output to choose one action for code_content or tool_content - output_message = self.messageUtils.parser(output_message) + # output_message = self.messageUtils.parser(output_message) yield output_message, local_memory + output_message # output_message = self.step_router(output_message) input_message = output_message diff --git a/coagent/connector/configs/__init__.py b/coagent/connector/configs/__init__.py index fd1c6bd..2cb4d69 100644 --- a/coagent/connector/configs/__init__.py +++ b/coagent/connector/configs/__init__.py @@ -1,9 +1,10 @@ from .agent_config import AGETN_CONFIGS from .chain_config import CHAIN_CONFIGS from .phase_config import PHASE_CONFIGS -from .prompt_config import BASE_PROMPT_CONFIGS, EXECUTOR_PROMPT_CONFIGS, SELECTOR_PROMPT_CONFIGS, BASE_NOTOOLPROMPT_CONFIGS +from .prompt_config import * __all__ = [ "AGETN_CONFIGS", "CHAIN_CONFIGS", "PHASE_CONFIGS", - "BASE_PROMPT_CONFIGS", "EXECUTOR_PROMPT_CONFIGS", "SELECTOR_PROMPT_CONFIGS", "BASE_NOTOOLPROMPT_CONFIGS" + "BASE_PROMPT_CONFIGS", "EXECUTOR_PROMPT_CONFIGS", "SELECTOR_PROMPT_CONFIGS", "BASE_NOTOOLPROMPT_CONFIGS", + "CODE2DOC_GROUP_PROMPT_CONFIGS", "CODE2DOC_PROMPT_CONFIGS", "CODE2TESTS_PROMPT_CONFIGS" ] \ No newline at end of file diff --git a/coagent/connector/configs/agent_config.py b/coagent/connector/configs/agent_config.py index 5073795..927e0a9 100644 --- a/coagent/connector/configs/agent_config.py +++ b/coagent/connector/configs/agent_config.py @@ -1,19 +1,21 @@ from enum import Enum -from .prompts import ( - REACT_PROMPT_INPUT, CHECK_PROMPT_INPUT, EXECUTOR_PROMPT_INPUT, CONTEXT_PROMPT_INPUT, QUERY_CONTEXT_PROMPT_INPUT,PLAN_PROMPT_INPUT, - RECOGNIZE_INTENTION_PROMPT, - CHECKER_TEMPLATE_PROMPT, - CONV_SUMMARY_PROMPT, - QA_PROMPT, CODE_QA_PROMPT, 
QA_TEMPLATE_PROMPT, - EXECUTOR_TEMPLATE_PROMPT, - REFINE_TEMPLATE_PROMPT, - SELECTOR_AGENT_TEMPLATE_PROMPT, - PLANNER_TEMPLATE_PROMPT, GENERAL_PLANNER_PROMPT, DATA_PLANNER_PROMPT, TOOL_PLANNER_PROMPT, - PRD_WRITER_METAGPT_PROMPT, DESIGN_WRITER_METAGPT_PROMPT, TASK_WRITER_METAGPT_PROMPT, CODE_WRITER_METAGPT_PROMPT, - REACT_TEMPLATE_PROMPT, - REACT_TOOL_PROMPT, REACT_CODE_PROMPT, REACT_TOOL_AND_CODE_PLANNER_PROMPT, REACT_TOOL_AND_CODE_PROMPT -) -from .prompt_config import BASE_PROMPT_CONFIGS, EXECUTOR_PROMPT_CONFIGS, SELECTOR_PROMPT_CONFIGS, BASE_NOTOOLPROMPT_CONFIGS +from .prompts import * +# from .prompts import ( +# REACT_PROMPT_INPUT, CHECK_PROMPT_INPUT, EXECUTOR_PROMPT_INPUT, CONTEXT_PROMPT_INPUT, QUERY_CONTEXT_PROMPT_INPUT,PLAN_PROMPT_INPUT, +# RECOGNIZE_INTENTION_PROMPT, +# CHECKER_TEMPLATE_PROMPT, +# CONV_SUMMARY_PROMPT, +# QA_PROMPT, CODE_QA_PROMPT, QA_TEMPLATE_PROMPT, +# EXECUTOR_TEMPLATE_PROMPT, +# REFINE_TEMPLATE_PROMPT, +# SELECTOR_AGENT_TEMPLATE_PROMPT, +# PLANNER_TEMPLATE_PROMPT, GENERAL_PLANNER_PROMPT, DATA_PLANNER_PROMPT, TOOL_PLANNER_PROMPT, +# PRD_WRITER_METAGPT_PROMPT, DESIGN_WRITER_METAGPT_PROMPT, TASK_WRITER_METAGPT_PROMPT, CODE_WRITER_METAGPT_PROMPT, +# REACT_TEMPLATE_PROMPT, +# REACT_TOOL_PROMPT, REACT_CODE_PROMPT, REACT_TOOL_AND_CODE_PLANNER_PROMPT, REACT_TOOL_AND_CODE_PROMPT +# ) +from .prompt_config import * +# BASE_PROMPT_CONFIGS, EXECUTOR_PROMPT_CONFIGS, SELECTOR_PROMPT_CONFIGS, BASE_NOTOOLPROMPT_CONFIGS @@ -261,4 +263,68 @@ AGETN_CONFIGS = { "focus_agents": ["metaGPT_DESIGN", "metaGPT_TASK"], "focus_message_keys": [], }, + "class2Docer": { + "role": { + "role_prompt": Class2Doc_PROMPT, + "role_type": "assistant", + "role_name": "class2Docer", + "role_desc": "", + "agent_type": "CodeGenDocer" + }, + "prompt_config": CODE2DOC_PROMPT_CONFIGS, + "prompt_manager_type": "Code2DocPM", + "chat_turn": 1, + "focus_agents": [], + "focus_message_keys": [], + }, + "func2Docer": { + "role": { + "role_prompt": Func2Doc_PROMPT, + "role_type": "assistant", + "role_name": "func2Docer", + "role_desc": "", + "agent_type": "CodeGenDocer" + }, + "prompt_config": CODE2DOC_PROMPT_CONFIGS, + "prompt_manager_type": "Code2DocPM", + "chat_turn": 1, + "focus_agents": [], + "focus_message_keys": [], + }, + "code2DocsGrouper": { + "role": { + "role_prompt": Code2DocGroup_PROMPT, + "role_type": "assistant", + "role_name": "code2DocsGrouper", + "role_desc": "", + "agent_type": "SelectorAgent" + }, + "prompt_config": CODE2DOC_GROUP_PROMPT_CONFIGS, + "group_agents": ["class2Docer", "func2Docer"], + "chat_turn": 1, + }, + "Code2TestJudger": { + "role": { + "role_prompt": judgeCode2Tests_PROMPT, + "role_type": "assistant", + "role_name": "Code2TestJudger", + "role_desc": "", + "agent_type": "CodeRetrieval" + }, + "prompt_config": CODE2TESTS_PROMPT_CONFIGS, + "prompt_manager_type": "CodeRetrievalPM", + "chat_turn": 1, + }, + "code2Tests": { + "role": { + "role_prompt": code2Tests_PROMPT, + "role_type": "assistant", + "role_name": "code2Tests", + "role_desc": "", + "agent_type": "CodeRetrieval" + }, + "prompt_config": CODE2TESTS_PROMPT_CONFIGS, + "prompt_manager_type": "CodeRetrievalPM", + "chat_turn": 1, + }, } \ No newline at end of file diff --git a/coagent/connector/configs/chain_config.py b/coagent/connector/configs/chain_config.py index ba49bc4..587ef95 100644 --- a/coagent/connector/configs/chain_config.py +++ b/coagent/connector/configs/chain_config.py @@ -123,5 +123,21 @@ CHAIN_CONFIGS = { "chat_turn": 1, "do_checker": False, "chain_prompt": "" + }, + "code2DocsGroupChain": { + "chain_name": 
"code2DocsGroupChain", + "chain_type": "BaseChain", + "agents": ["code2DocsGrouper"], + "chat_turn": 1, + "do_checker": False, + "chain_prompt": "" + }, + "code2TestsChain": { + "chain_name": "code2TestsChain", + "chain_type": "BaseChain", + "agents": ["Code2TestJudger", "code2Tests"], + "chat_turn": 1, + "do_checker": False, + "chain_prompt": "" } } diff --git a/coagent/connector/configs/phase_config.py b/coagent/connector/configs/phase_config.py index 48be270..0e88611 100644 --- a/coagent/connector/configs/phase_config.py +++ b/coagent/connector/configs/phase_config.py @@ -14,44 +14,24 @@ PHASE_CONFIGS = { "phase_name": "docChatPhase", "phase_type": "BasePhase", "chains": ["docChatChain"], - "do_summary": False, - "do_search": False, "do_doc_retrieval": True, - "do_code_retrieval": False, - "do_tool_retrieval": False, - "do_using_tool": False }, "searchChatPhase": { "phase_name": "searchChatPhase", "phase_type": "BasePhase", "chains": ["searchChatChain"], - "do_summary": False, "do_search": True, - "do_doc_retrieval": False, - "do_code_retrieval": False, - "do_tool_retrieval": False, - "do_using_tool": False }, "codeChatPhase": { "phase_name": "codeChatPhase", "phase_type": "BasePhase", "chains": ["codeChatChain"], - "do_summary": False, - "do_search": False, - "do_doc_retrieval": False, "do_code_retrieval": True, - "do_tool_retrieval": False, - "do_using_tool": False }, "toolReactPhase": { "phase_name": "toolReactPhase", "phase_type": "BasePhase", "chains": ["toolReactChain"], - "do_summary": False, - "do_search": False, - "do_doc_retrieval": False, - "do_code_retrieval": False, - "do_tool_retrieval": False, "do_using_tool": True }, "codeReactPhase": { @@ -59,55 +39,36 @@ PHASE_CONFIGS = { "phase_type": "BasePhase", # "chains": ["codePlannerChain", "codeReactChain"], "chains": ["planChain", "codeReactChain"], - "do_summary": False, - "do_search": False, - "do_doc_retrieval": False, - "do_code_retrieval": False, - "do_tool_retrieval": False, - "do_using_tool": False }, "codeToolReactPhase": { "phase_name": "codeToolReactPhase", "phase_type": "BasePhase", "chains": ["codeToolPlanChain", "codeToolReactChain"], - "do_summary": False, - "do_search": False, - "do_doc_retrieval": False, - "do_code_retrieval": False, - "do_tool_retrieval": False, "do_using_tool": True }, "baseTaskPhase": { "phase_name": "baseTaskPhase", "phase_type": "BasePhase", "chains": ["planChain", "executorChain"], - "do_summary": False, - "do_search": False, - "do_doc_retrieval": False, - "do_code_retrieval": False, - "do_tool_retrieval": False, - "do_using_tool": False }, "metagpt_code_devlop": { "phase_name": "metagpt_code_devlop", "phase_type": "BasePhase", "chains": ["metagptChain",], - "do_summary": False, - "do_search": False, - "do_doc_retrieval": False, - "do_code_retrieval": False, - "do_tool_retrieval": False, - "do_using_tool": False }, "baseGroupPhase": { "phase_name": "baseGroupPhase", "phase_type": "BasePhase", "chains": ["baseGroupChain"], - "do_summary": False, - "do_search": False, - "do_doc_retrieval": False, - "do_code_retrieval": False, - "do_tool_retrieval": False, - "do_using_tool": False }, + "code2DocsGroup": { + "phase_name": "code2DocsGroup", + "phase_type": "BasePhase", + "chains": ["code2DocsGroupChain"], + }, + "code2Tests": { + "phase_name": "code2Tests", + "phase_type": "BasePhase", + "chains": ["code2TestsChain"], + } } diff --git a/coagent/connector/configs/prompt_config.py b/coagent/connector/configs/prompt_config.py index 45c0f3d..4fd049e 100644 --- 
a/coagent/connector/configs/prompt_config.py +++ b/coagent/connector/configs/prompt_config.py @@ -40,4 +40,41 @@ SELECTOR_PROMPT_CONFIGS = [ {"field_name": 'current_plan', "function_name": 'handle_current_plan'}, {"field_name": 'output_format', "function_name": 'handle_output_format', 'title': 'Response Output Format', "is_context": False}, {"field_name": 'begin!!!', "function_name": 'handle_response', "is_context": False, "omit_if_empty": False} - ] \ No newline at end of file + ] + + +CODE2DOC_GROUP_PROMPT_CONFIGS = [ + {"field_name": 'agent_profile', "function_name": 'handle_agent_profile', "is_context": False}, + {"field_name": 'agent_infomation', "function_name": 'handle_agent_data', "is_context": False, "omit_if_empty": False}, + # {"field_name": 'tool_information',"function_name": 'handle_tool_data', "is_context": False}, + {"field_name": 'context_placeholder', "function_name": '', "is_context": True}, + # {"field_name": 'reference_documents', "function_name": 'handle_doc_info'}, + {"field_name": 'session_records', "function_name": 'handle_session_records'}, + {"field_name": 'Specific Objective', "function_name": 'handle_specific_objective'}, + {"field_name": 'Code Snippet', "function_name": 'handle_code_snippet'}, + {"field_name": 'output_format', "function_name": 'handle_output_format', 'title': 'Response Output Format', "is_context": False}, + {"field_name": 'begin!!!', "function_name": 'handle_response', "is_context": False, "omit_if_empty": False} +] + +CODE2DOC_PROMPT_CONFIGS = [ + {"field_name": 'agent_profile', "function_name": 'handle_agent_profile', "is_context": False}, + # {"field_name": 'tool_information',"function_name": 'handle_tool_data', "is_context": False}, + {"field_name": 'context_placeholder', "function_name": '', "is_context": True}, + # {"field_name": 'reference_documents', "function_name": 'handle_doc_info'}, + {"field_name": 'session_records', "function_name": 'handle_session_records'}, + {"field_name": 'Specific Objective', "function_name": 'handle_specific_objective'}, + {"field_name": 'Code Snippet', "function_name": 'handle_code_snippet'}, + {"field_name": 'output_format', "function_name": 'handle_output_format', 'title': 'Response Output Format', "is_context": False}, + {"field_name": 'begin!!!', "function_name": 'handle_response', "is_context": False, "omit_if_empty": False} +] + + +CODE2TESTS_PROMPT_CONFIGS = [ + {"field_name": 'agent_profile', "function_name": 'handle_agent_profile', "is_context": False}, + {"field_name": 'context_placeholder', "function_name": '', "is_context": True}, + {"field_name": 'session_records', "function_name": 'handle_session_records'}, + {"field_name": 'code_snippet', "function_name": 'handle_code_snippet'}, + {"field_name": 'retrieval_codes', "function_name": 'handle_retrieval_codes', "description": ""}, + {"field_name": 'output_format', "function_name": 'handle_output_format', 'title': 'Response Output Format', "is_context": False}, + {"field_name": 'begin!!!', "function_name": 'handle_response', "is_context": False, "omit_if_empty": False} +] \ No newline at end of file diff --git a/coagent/connector/configs/prompts/__init__.py b/coagent/connector/configs/prompts/__init__.py index e939e07..5e550b5 100644 --- a/coagent/connector/configs/prompts/__init__.py +++ b/coagent/connector/configs/prompts/__init__.py @@ -14,7 +14,8 @@ from .qa_template_prompt import QA_PROMPT, CODE_QA_PROMPT, QA_TEMPLATE_PROMPT, C from .executor_template_prompt import EXECUTOR_TEMPLATE_PROMPT from .refine_template_prompt import 
REFINE_TEMPLATE_PROMPT
-
+from .code2doc_template_prompt import Code2DocGroup_PROMPT, Class2Doc_PROMPT, Func2Doc_PROMPT
+from .code2test_template_prompt import code2Tests_PROMPT, judgeCode2Tests_PROMPT
 from .agent_selector_template_prompt import SELECTOR_AGENT_TEMPLATE_PROMPT
 
 from .react_template_prompt import REACT_TEMPLATE_PROMPT
@@ -37,5 +38,7 @@ __all__ = [
     "SELECTOR_AGENT_TEMPLATE_PROMPT",
     "PLANNER_TEMPLATE_PROMPT", "GENERAL_PLANNER_PROMPT", "DATA_PLANNER_PROMPT", "TOOL_PLANNER_PROMPT",
     "REACT_TEMPLATE_PROMPT",
-    "REACT_CODE_PROMPT", "REACT_TOOL_PROMPT", "REACT_TOOL_AND_CODE_PROMPT", "REACT_TOOL_AND_CODE_PLANNER_PROMPT"
+    "REACT_CODE_PROMPT", "REACT_TOOL_PROMPT", "REACT_TOOL_AND_CODE_PROMPT", "REACT_TOOL_AND_CODE_PLANNER_PROMPT",
+    "Code2DocGroup_PROMPT", "Class2Doc_PROMPT", "Func2Doc_PROMPT",
+    "code2Tests_PROMPT", "judgeCode2Tests_PROMPT"
 ]
\ No newline at end of file
diff --git a/coagent/connector/configs/prompts/code2doc_template_prompt.py b/coagent/connector/configs/prompts/code2doc_template_prompt.py
new file mode 100644
index 0000000..5a2e915
--- /dev/null
+++ b/coagent/connector/configs/prompts/code2doc_template_prompt.py
@@ -0,0 +1,95 @@
+Code2DocGroup_PROMPT = """#### Agent Profile
+
+Your goal is to respond, according to the information in the Context Data, with the role that will best facilitate a solution, taking into account all relevant context (Context) provided.
+
+When you need to select the appropriate role for handling a user's query, carefully read the provided role names, role descriptions and tool list.
+
+ATTENTION: respond carefully, in the format specified under "Response Output Format".
+
+#### Input Format
+
+#### Response Output Format
+
+**Code Path:** Extract the paths for the class/method/function that need to be addressed from the context
+
+**Role:** Select the role from agent names
+"""
+
+Class2Doc_PROMPT = """#### Agent Profile
+As an advanced code documentation generator, you are proficient in translating class definitions into comprehensive documentation with a focus on instantiation parameters.
+Your specific task is to parse the given code snippet of a class and extract information regarding its instantiation parameters.
+
+ATTENTION: respond carefully, following the "Response Output Format".
+
+#### Input Format
+
+**Code Snippet:** Provide the full class definition, including the constructor and any parameters it may require for instantiation.
+
+#### Response Output Format
+**Class Base:** Specify the base class or interface from which the current class extends, if any.
+
+**Class Description:** Offer a brief description of the class's purpose and functionality.
+
+**Init Parameters:** List each parameter of the constructor. For each parameter, provide:
+    - `param`: The parameter name
+    - `param_description`: A concise explanation of the parameter's purpose.
+    - `param_type`: The data type of the parameter, if explicitly defined.
+
+    ```json
+    [
+        {
+            "param": "parameter_name",
+            "param_description": "A brief description of what this parameter is used for.",
+            "param_type": "The data type of the parameter"
+        },
+        ...
+    ]
+    ```
+
+
+    If the constructor takes no parameters, return
+    ```json
+    []
+    ```
+"""
+
+Func2Doc_PROMPT = """#### Agent Profile
+You are a high-level code documentation assistant, skilled at distilling function/method code into detailed and well-structured documentation.
+
+ATTENTION: respond carefully, following the "Response Output Format".
+
+
+#### Input Format
+**Code Path:** Provide the code path of the function or method you wish to document.
+This name will be used to identify and extract the relevant details from the code snippet provided.
+
+**Code Snippet:** A segment of code that contains the function or method to be documented.
+
+#### Response Output Format
+
+**Class Description:** Offer a brief description of the method's (function's) purpose and functionality.
+
+**Parameters:** Extract the parameters of the specific function/method from the Code Snippet. For each parameter, provide:
+    - `param`: The parameter name
+    - `param_description`: A concise explanation of the parameter's purpose.
+    - `param_type`: The data type of the parameter, if explicitly defined.
+    ```json
+    [
+        {
+            "param": "parameter_name",
+            "param_description": "A brief description of what this parameter is used for.",
+            "param_type": "The data type of the parameter"
+        },
+        ...
+    ]
+    ```
+
+    If the function/method takes no parameters, return
+    ```json
+    []
+    ```
+
+**Return Value Description:** Describe what the function/method returns upon completion.
+
+**Return Type:** Indicate the type of data the function/method returns (e.g., string, integer, object, void).
+"""
\ No newline at end of file
diff --git a/coagent/connector/configs/prompts/code2test_template_prompt.py b/coagent/connector/configs/prompts/code2test_template_prompt.py
new file mode 100644
index 0000000..e0e1c5f
--- /dev/null
+++ b/coagent/connector/configs/prompts/code2test_template_prompt.py
@@ -0,0 +1,65 @@
+judgeCode2Tests_PROMPT = """#### Agent Profile
+When determining the necessity of writing test cases for a given code snippet,
+it's essential to evaluate its interactions with dependent classes and methods (retrieved code snippets),
+in addition to considering these critical factors:
+1. Functionality: If it implements a concrete function or logic, test cases are typically necessary to verify its correctness.
+2. Complexity: If the code is complex, especially if it contains multiple conditional statements, loops, exception handling, etc.,
+it's more likely to harbor bugs, and thus test cases should be written.
+If the code involves complex algorithms or logic, then writing test cases can help ensure the accuracy of the logic and prevent errors during future refactoring.
+3. Criticality: If it's part of the critical path or affects core functionalities, then it needs to be tested.
+Comprehensive test cases should be written for core business logic or key components of the system to ensure the correctness and stability of the functionality.
+4. Dependencies: If the code has external dependencies, integration testing may be necessary, or mocking these dependencies during unit testing might be required.
+5. User Input: If the code handles user input, especially from unregulated external sources, creating test cases to check input validation and handling is important.
+6. Frequent Changes: For code that requires regular updates or modifications, having the appropriate test cases ensures that changes do not break existing functionalities.
+
+#### Input Format
+
+**Code Snippet:** the initial code or objective that the user wanted to achieve
+
+**Retrieval Code Snippets:** These are the associated code segments that the main Code Snippet depends on.
+Examine these snippets to understand how they interact with the main snippet and to determine how they might affect the overall functionality.
+
+#### Response Output Format
+**Action Status:** Set to 'finished' or 'continued'.
+If set to 'finished', the code snippet does not warrant the generation of a test case.
+If set to 'continued', the code snippet necessitates the creation of a test case.
+
+**REASON:** Justify the selection of 'finished' or 'continued', contemplating the decision through a step-by-step rationale.
+"""
+
+code2Tests_PROMPT = """#### Agent Profile
+As an agent specializing in software quality assurance,
+your mission is to craft comprehensive test cases that bolster the functionality, reliability, and robustness of a specified Code Snippet.
+This task is to be carried out with a keen understanding of the snippet's interactions with its dependent classes and methods, collectively referred to as Retrieval Code Snippets.
+Analyze the details given below to grasp the code's intended purpose, its inherent complexity, and the context within which it operates.
+Your constructed test cases must thoroughly examine the various factors influencing the code's quality and performance.
+
+ATTENTION: respond carefully, following the "Response Output Format".
+
+Each test case should include:
+1. A clear description of the test purpose.
+2. The input values or conditions for the test.
+3. The expected outcome or assertion for the test.
+4. Appropriate tags (e.g., 'functional', 'integration', 'regression') that classify the type of test case.
+5. Package and import statements, so the test code compiles on its own.
+
+#### Input Format
+
+**Code Snippet:** the initial code or objective that the user wanted to achieve
+
+**Retrieval Code Snippets:** These are the interrelated pieces of code sourced from the codebase, which support or influence the primary Code Snippet.
+
+#### Response Output Format
+**SaveFileName:** construct a local file name based on Question and Context, such as
+
+```java
+package/class.java
+```
+
+
+**Test Code:** generate the test code for the current Code Snippet.
+```java
+...
+``` + +""" \ No newline at end of file diff --git a/coagent/connector/memory_manager.py b/coagent/connector/memory_manager.py index bbfceaf..a7d8436 100644 --- a/coagent/connector/memory_manager.py +++ b/coagent/connector/memory_manager.py @@ -1,5 +1,5 @@ from abc import abstractmethod, ABC -from typing import List +from typing import List, Dict import os, sys, copy, json from jieba.analyse import extract_tags from collections import Counter @@ -10,12 +10,13 @@ from langchain.docstore.document import Document from .schema import Memory, Message from coagent.service.service_factory import KBServiceFactory -from coagent.llm_models import getChatModel, getChatModelFromConfig +from coagent.llm_models import getChatModelFromConfig from coagent.llm_models.llm_config import EmbedConfig, LLMConfig from coagent.embeddings.utils import load_embeddings_from_path from coagent.utils.common_utils import save_to_json_file, read_json_file, addMinutesToTime from coagent.connector.configs.prompts import CONV_SUMMARY_PROMPT_SPEC from coagent.orm import table_init +from coagent.base_configs.env_config import KB_ROOT_PATH # from configs.model_config import KB_ROOT_PATH, EMBEDDING_MODEL, EMBEDDING_DEVICE, SCORE_THRESHOLD # from configs.model_config import embedding_model_dict @@ -70,16 +71,22 @@ class BaseMemoryManager(ABC): self.unique_name = unique_name self.memory_type = memory_type self.do_init = do_init - self.current_memory = Memory(messages=[]) - self.recall_memory = Memory(messages=[]) - self.summary_memory = Memory(messages=[]) + # self.current_memory = Memory(messages=[]) + # self.recall_memory = Memory(messages=[]) + # self.summary_memory = Memory(messages=[]) + self.current_memory_dict: Dict[str, Memory] = {} + self.recall_memory_dict: Dict[str, Memory] = {} + self.summary_memory_dict: Dict[str, Memory] = {} self.save_message_keys = [ 'chat_index', 'role_name', 'role_type', 'role_prompt', 'input_query', 'origin_query', 'datetime', 'role_content', 'step_content', 'parsed_output', 'spec_parsed_output', 'parsed_output_list', 'task', 'db_docs', 'code_docs', 'search_docs', 'phase_name', 'chain_name', 'customed_kargs'] self.init_vb() - def init_vb(self): + def re_init(self, do_init: bool=False): + self.init_vb() + + def init_vb(self, do_init: bool=None): """ Initializes the vb. """ @@ -135,13 +142,15 @@ class BaseMemoryManager(ABC): """ pass - def save_to_vs(self, embed_model="", embed_device=""): + def save_to_vs(self, ): """ Saves the memory to the vector space. + """ + pass - Args: - - embed_model: A string representing the embedding model. Default is EMBEDDING_MODEL. - - embed_device: A string representing the embedding device. Default is EMBEDDING_DEVICE. 
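The memory-manager rework above replaces the single `recall`/`current`/`summary` Memory objects with dictionaries keyed by a `user_name`_`unique_name`_`memory_type` string, so one manager instance can serve several users. A rough sketch of the intended call pattern (illustrative values; assumes an embedding-free setup, so the `if self.embed_config` guards skip all vector-store writes):

```python
from coagent.connector.memory_manager import LocalMemoryManager
from coagent.connector.schema import Message

mm = LocalMemoryManager(
    embed_config=None,      # no embeddings: faiss persistence is skipped
    llm_config=None,
    user_name="alice",
    unique_name="default",
    memory_type="recall",
    do_init=True,
)
msg = Message(role_name="human", role_type="user",
              role_content="hello", user_name="alice")
mm.append(msg)                       # lands in recall_memory_dict["alice_default_recall"]
pool = mm.get_memory_pool("alice")   # the per-user recall Memory
```

Switching `user_name` on a later `append` goes through `check_user_name`, which re-initializes the vector store and lazily creates the new user's pools.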
+ def get_memory_pool(self, user_name: str, ): + """ + return memory_pool """ pass @@ -230,7 +239,7 @@ class LocalMemoryManager(BaseMemoryManager): unique_name: str = "default", memory_type: str = "recall", do_init: bool = False, - kb_root_path: str = "", + kb_root_path: str = KB_ROOT_PATH, ): self.user_name = user_name self.unique_name = unique_name @@ -239,16 +248,22 @@ class LocalMemoryManager(BaseMemoryManager): self.kb_root_path = kb_root_path self.embed_config: EmbedConfig = embed_config self.llm_config: LLMConfig = llm_config - self.current_memory = Memory(messages=[]) - self.recall_memory = Memory(messages=[]) - self.summary_memory = Memory(messages=[]) + # self.current_memory = Memory(messages=[]) + # self.recall_memory = Memory(messages=[]) + # self.summary_memory = Memory(messages=[]) + self.current_memory_dict: Dict[str, Memory] = {} + self.recall_memory_dict: Dict[str, Memory] = {} + self.summary_memory_dict: Dict[str, Memory] = {} self.save_message_keys = [ 'chat_index', 'role_name', 'role_type', 'role_prompt', 'input_query', 'origin_query', 'datetime', 'role_content', 'step_content', 'parsed_output', 'spec_parsed_output', 'parsed_output_list', 'task', 'db_docs', 'code_docs', 'search_docs', 'phase_name', 'chain_name', 'customed_kargs'] self.init_vb() - def init_vb(self): + def re_init(self, do_init: bool=False): + self.init_vb(do_init) + + def init_vb(self, do_init: bool=None): vb_name = f"{self.user_name}/{self.unique_name}/{self.memory_type}" # default to recreate a new vb table_init() @@ -256,31 +271,37 @@ class LocalMemoryManager(BaseMemoryManager): if vb: status = vb.clear_vs() - if not self.do_init: + check_do_init = do_init if do_init else self.do_init + if not check_do_init: self.load(self.kb_root_path) self.save_to_vs() def append(self, message: Message): - self.recall_memory.append(message) + self.check_user_name(message.user_name) + + uuid_name = "_".join([self.user_name, self.unique_name, self.memory_type]) + self.recall_memory_dict[uuid_name].append(message) # if message.role_type == "summary": - self.summary_memory.append(message) + self.summary_memory_dict[uuid_name].append(message) else: - self.current_memory.append(message) + self.current_memory_dict[uuid_name].append(message) self.save(self.kb_root_path) self.save_new_to_vs([message]) - def extend(self, memory: Memory): - self.recall_memory.extend(memory) - self.current_memory.extend(self.recall_memory.filter_by_role_type(["summary"])) - self.summary_memory.extend(self.recall_memory.select_by_role_type(["summary"])) - self.save(self.kb_root_path) - self.save_new_to_vs(memory.messages) + # def extend(self, memory: Memory): + # self.recall_memory.extend(memory) + # self.current_memory.extend(self.recall_memory.filter_by_role_type(["summary"])) + # self.summary_memory.extend(self.recall_memory.select_by_role_type(["summary"])) + # self.save(self.kb_root_path) + # self.save_new_to_vs(memory.messages) def save(self, save_dir: str = "./"): file_path = os.path.join(save_dir, f"{self.user_name}/{self.unique_name}/{self.memory_type}/converation.jsonl") - memory_messages = self.recall_memory.dict() + uuid_name = "_".join([self.user_name, self.unique_name, self.memory_type]) + + memory_messages = self.recall_memory_dict[uuid_name].dict() memory_messages = {k: [ {kkk: vvv for kkk, vvv in vv.items() if kkk in self.save_message_keys} for vv in v ] @@ -291,18 +312,28 @@ class LocalMemoryManager(BaseMemoryManager): def load(self, load_dir: str = "./") -> Memory: file_path = os.path.join(load_dir, 
f"{self.user_name}/{self.unique_name}/{self.memory_type}/converation.jsonl") + uuid_name = "_".join([self.user_name, self.unique_name, self.memory_type]) if os.path.exists(file_path): - self.recall_memory = Memory(**read_json_file(file_path)) - self.current_memory = Memory(messages=self.recall_memory.filter_by_role_type(["summary"])) - self.summary_memory = Memory(messages=self.recall_memory.select_by_role_type(["summary"])) + # self.recall_memory = Memory(**read_json_file(file_path)) + # self.current_memory = Memory(messages=self.recall_memory.filter_by_role_type(["summary"])) + # self.summary_memory = Memory(messages=self.recall_memory.select_by_role_type(["summary"])) + + recall_memory = Memory(**read_json_file(file_path)) + self.recall_memory_dict[uuid_name] = recall_memory + self.current_memory_dict[uuid_name] = Memory(messages=recall_memory.filter_by_role_type(["summary"])) + self.summary_memory_dict[uuid_name] = Memory(messages=recall_memory.select_by_role_type(["summary"])) + else: + self.recall_memory_dict[uuid_name] = Memory(messages=[]) + self.current_memory_dict[uuid_name] = Memory(messages=[]) + self.summary_memory_dict[uuid_name] = Memory(messages=[]) def save_new_to_vs(self, messages: List[Message]): if self.embed_config: vb_name = f"{self.user_name}/{self.unique_name}/{self.memory_type}" # default to faiss, todo: add new vstype vb = KBServiceFactory.get_service(vb_name, "faiss", self.embed_config, self.kb_root_path) - embeddings = load_embeddings_from_path(self.embed_config.embed_model_path, self.embed_config.model_device,) + embeddings = load_embeddings_from_path(self.embed_config.embed_model_path, self.embed_config.model_device, self.embed_config.langchain_embeddings) messages = [ {k: v for k, v in m.dict().items() if k in self.save_message_keys} for m in messages] @@ -311,23 +342,26 @@ class LocalMemoryManager(BaseMemoryManager): vb.do_add_doc(docs, embeddings) def save_to_vs(self): - vb_name = f"{self.user_name}/{self.unique_name}/{self.memory_type}" - # default to recreate a new vb - vb = KBServiceFactory.get_service_by_name(vb_name, self.embed_config, self.kb_root_path) - if vb: - status = vb.clear_vs() - # create_kb(vb_name, "faiss", embed_model) + '''only after load''' + if self.embed_config: + vb_name = f"{self.user_name}/{self.unique_name}/{self.memory_type}" + uuid_name = "_".join([self.user_name, self.unique_name, self.memory_type]) + # default to recreate a new vb + vb = KBServiceFactory.get_service_by_name(vb_name, self.embed_config, self.kb_root_path) + if vb: + status = vb.clear_vs() + # create_kb(vb_name, "faiss", embed_model) - # default to faiss, todo: add new vstype - vb = KBServiceFactory.get_service(vb_name, "faiss", self.embed_config, self.kb_root_path) - embeddings = load_embeddings_from_path(self.embed_config.embed_model_path, self.embed_config.model_device,) - messages = self.recall_memory.dict() - messages = [ - {kkk: vvv for kkk, vvv in vv.items() if kkk in self.save_message_keys} - for k, v in messages.items() for vv in v] - docs = [{"page_content": m["step_content"] or m["role_content"] or m["input_query"] or m["origin_query"], "metadata": m} for m in messages] - docs = [Document(**doc) for doc in docs] - vb.do_add_doc(docs, embeddings) + # default to faiss, todo: add new vstype + vb = KBServiceFactory.get_service(vb_name, "faiss", self.embed_config, self.kb_root_path) + embeddings = load_embeddings_from_path(self.embed_config.embed_model_path, self.embed_config.model_device, self.embed_config.langchain_embeddings) + messages = 
self.recall_memory_dict[uuid_name].dict() + messages = [ + {kkk: vvv for kkk, vvv in vv.items() if kkk in self.save_message_keys} + for k, v in messages.items() for vv in v] + docs = [{"page_content": m["step_content"] or m["role_content"] or m["input_query"] or m["origin_query"], "metadata": m} for m in messages] + docs = [Document(**doc) for doc in docs] + vb.do_add_doc(docs, embeddings) # def load_from_vs(self, embed_model=EMBEDDING_MODEL) -> Memory: # vb_name = f"{self.user_name}/{self.unique_name}/{self.memory_type}" @@ -338,7 +372,12 @@ class LocalMemoryManager(BaseMemoryManager): # docs = vb.get_all_documents() # print(docs) - def router_retrieval(self, text: str=None, datetime: str = None, n=5, top_k=5, retrieval_type: str = "embedding", **kwargs) -> List[Message]: + def get_memory_pool(self, user_name: str, ): + self.check_user_name(user_name) + uuid_name = "_".join([self.user_name, self.unique_name, self.memory_type]) + return self.recall_memory_dict[uuid_name] + + def router_retrieval(self, user_name: str = "default", text: str=None, datetime: str = None, n=5, top_k=5, retrieval_type: str = "embedding", **kwargs) -> List[Message]: retrieval_func_dict = { "embedding": self.embedding_retrieval, "text": self.text_retrieval, "datetime": self.datetime_retrieval } @@ -356,20 +395,22 @@ class LocalMemoryManager(BaseMemoryManager): # return retrieval_func(**params) - def embedding_retrieval(self, text: str, top_k=1, score_threshold=1.0, **kwargs) -> List[Message]: + def embedding_retrieval(self, text: str, top_k=1, score_threshold=1.0, user_name: str = "default", **kwargs) -> List[Message]: if text is None: return [] - vb_name = f"{self.user_name}/{self.unique_name}/{self.memory_type}" + vb_name = f"{user_name}/{self.unique_name}/{self.memory_type}" vb = KBServiceFactory.get_service(vb_name, "faiss", self.embed_config, self.kb_root_path) docs = vb.search_docs(text, top_k=top_k, score_threshold=score_threshold) return [Message(**doc.metadata) for doc, score in docs] - def text_retrieval(self, text: str, **kwargs) -> List[Message]: + def text_retrieval(self, text: str, user_name: str = "default", **kwargs) -> List[Message]: if text is None: return [] - return self._text_retrieval_from_cache(self.recall_memory.messages, text, score_threshold=0.3, topK=5, **kwargs) + uuid_name = "_".join([user_name, self.unique_name, self.memory_type]) + return self._text_retrieval_from_cache(self.recall_memory_dict[uuid_name].messages, text, score_threshold=0.3, topK=5, **kwargs) - def datetime_retrieval(self, datetime: str, text: str = None, n: int = 5, **kwargs) -> List[Message]: + def datetime_retrieval(self, datetime: str, text: str = None, n: int = 5, user_name: str = "default", **kwargs) -> List[Message]: if datetime is None: return [] - return self._datetime_retrieval_from_cache(self.recall_memory.messages, datetime, text, n, **kwargs) + uuid_name = "_".join([user_name, self.unique_name, self.memory_type]) + return self._datetime_retrieval_from_cache(self.recall_memory_dict[uuid_name].messages, datetime, text, n, **kwargs) def _text_retrieval_from_cache(self, messages: List[Message], text: str = None, score_threshold=0.3, topK=5, tag_topK=5, **kwargs) -> List[Message]: keywords = extract_tags(text, topK=tag_topK) @@ -427,4 +468,18 @@ class LocalMemoryManager(BaseMemoryManager): ) summary_message.parsed_output_list.append({"summary": content}) newest_messages.insert(0, summary_message) - return newest_messages \ No newline at end of file + return newest_messages + + def check_user_name(self, 
user_name: str): + # logger.debug(f"self.user_name is {self.user_name}") + if user_name != self.user_name: + self.user_name = user_name + self.init_vb() + + uuid_name = "_".join([self.user_name, self.unique_name, self.memory_type]) + if uuid_name not in self.recall_memory_dict: + self.recall_memory_dict[uuid_name] = Memory(messages=[]) + self.current_memory_dict[uuid_name] = Memory(messages=[]) + self.summary_memory_dict[uuid_name] = Memory(messages=[]) + + # logger.debug(f"self.user_name is {self.user_name}") \ No newline at end of file diff --git a/coagent/connector/message_process.py b/coagent/connector/message_process.py index a97b2cb..b225fe5 100644 --- a/coagent/connector/message_process.py +++ b/coagent/connector/message_process.py @@ -1,16 +1,19 @@ import re, traceback, uuid, copy, json, os +from typing import Union from loguru import logger +from langchain.schema import BaseRetriever -# from configs.server_config import SANDBOX_SERVER -# from configs.model_config import JUPYTER_WORK_PATH from coagent.connector.schema import ( Memory, Role, Message, ActionStatus, CodeDoc, Doc, LogVerboseEnum ) +from coagent.retrieval.base_retrieval import IMRertrieval from coagent.connector.memory_manager import BaseMemoryManager from coagent.tools import DDGSTool, DocRetrieval, CodeRetrieval from coagent.sandbox import PyCodeBox, CodeBoxResponse from coagent.llm_models.llm_config import LLMConfig, EmbedConfig +from coagent.base_configs.env_config import JUPYTER_WORK_PATH + from .utils import parse_dict_to_dict, parse_text_to_dict @@ -19,10 +22,13 @@ class MessageUtils: self, role: Role = None, sandbox_server: dict = {}, - jupyter_work_path: str = "./", + jupyter_work_path: str = JUPYTER_WORK_PATH, embed_config: EmbedConfig = None, llm_config: LLMConfig = None, kb_root_path: str = "", + doc_retrieval: Union[BaseRetriever, IMRertrieval] = None, + code_retrieval: IMRertrieval = None, + search_retrieval: IMRertrieval = None, log_verbose: str = "0" ) -> None: self.role = role @@ -31,6 +37,9 @@ class MessageUtils: self.embed_config = embed_config self.llm_config = llm_config self.kb_root_path = kb_root_path + self.doc_retrieval = doc_retrieval + self.code_retrieval = code_retrieval + self.search_retrieval = search_retrieval self.codebox = PyCodeBox( remote_url=self.sandbox_server.get("url", "http://127.0.0.1:5050"), remote_ip=self.sandbox_server.get("host", "http://127.0.0.1"), @@ -44,6 +53,7 @@ class MessageUtils: self.log_verbose = os.environ.get("log_verbose", "0") or log_verbose def inherit_extrainfo(self, input_message: Message, output_message: Message): + output_message.user_name = input_message.user_name output_message.db_docs = input_message.db_docs output_message.search_docs = input_message.search_docs output_message.code_docs = input_message.code_docs @@ -116,18 +126,45 @@ class MessageUtils: knowledge_basename = message.doc_engine_name top_k = message.top_k score_threshold = message.score_threshold - if knowledge_basename: - docs = DocRetrieval.run(query, knowledge_basename, top_k, score_threshold, self.embed_config, self.kb_root_path) + if self.doc_retrieval: + if isinstance(self.doc_retrieval, BaseRetriever): + docs = self.doc_retrieval.get_relevant_documents(query) + else: + # docs = self.doc_retrieval.run(query, search_top=message.top_k, score_threshold=message.score_threshold,) + docs = self.doc_retrieval.run(query) + docs = [ + {"index": idx, "snippet": doc.page_content, "title": doc.metadata.get("title_prefix", ""), "link": doc.metadata.get("url", "")} + for idx, doc in enumerate(docs) 
+ ] message.db_docs = [Doc(**doc) for doc in docs] + else: + if knowledge_basename: + docs = DocRetrieval.run(query, knowledge_basename, top_k, score_threshold, self.embed_config, self.kb_root_path) + message.db_docs = [Doc(**doc) for doc in docs] return message def get_code_retrieval(self, message: Message) -> Message: - query = message.input_query + query = message.role_content code_engine_name = message.code_engine_name history_node_list = message.history_node_list - code_docs = CodeRetrieval.run(code_engine_name, query, code_limit=message.top_k, history_node_list=history_node_list, search_type=message.cb_search_type, - llm_config=self.llm_config, embed_config=self.embed_config,) - message.code_docs = [CodeDoc(**doc) for doc in code_docs] + + use_nh = message.use_nh + local_graph_path = message.local_graph_path + + if self.code_retrieval: + code_docs = self.code_retrieval.run( + query, history_node_list=history_node_list, search_type=message.cb_search_type, + code_limit=1 + ) + else: + code_docs = CodeRetrieval.run(code_engine_name, query, code_limit=message.top_k, history_node_list=history_node_list, search_type=message.cb_search_type, + llm_config=self.llm_config, embed_config=self.embed_config, + use_nh=use_nh, local_graph_path=local_graph_path) + + message.code_docs = [CodeDoc(**doc) for doc in code_docs] + + # related_nodes = [doc.get_related_node() for idx, doc in enumerate(message.code_docs) if idx==0], + # history_node_list.extend([node[0] for node in related_nodes]) return message def get_tool_retrieval(self, message: Message) -> Message: @@ -160,6 +197,7 @@ class MessageUtils: if code_answer.code_exe_type == "error" else f"The return information after executing the above code is {code_answer.code_exe_response}.\n" observation_message = Message( + user_name=message.user_name, role_name="observation", role_type="function", #self.role.role_type, role_content="", @@ -190,6 +228,7 @@ class MessageUtils: def tool_step(self, message: Message) -> Message: '''execute tool''' observation_message = Message( + user_name=message.user_name, role_name="observation", role_type="function", #self.role.role_type, role_content="\n**Observation:** there is no tool can execute\n", @@ -226,7 +265,7 @@ class MessageUtils: return message, observation_message def parser(self, message: Message) -> Message: - '''''' + '''parse llm output into dict''' content = message.role_content # parse start parsed_dict = parse_text_to_dict(content) diff --git a/coagent/connector/phase/base_phase.py b/coagent/connector/phase/base_phase.py index 14bc66b..0b71e98 100644 --- a/coagent/connector/phase/base_phase.py +++ b/coagent/connector/phase/base_phase.py @@ -5,6 +5,8 @@ import importlib import copy from loguru import logger +from langchain.schema import BaseRetriever + from coagent.connector.agents import BaseAgent from coagent.connector.chains import BaseChain from coagent.connector.schema import ( @@ -18,9 +20,6 @@ from coagent.connector.message_process import MessageUtils from coagent.llm_models.llm_config import EmbedConfig, LLMConfig from coagent.base_configs.env_config import JUPYTER_WORK_PATH, KB_ROOT_PATH -# from configs.model_config import JUPYTER_WORK_PATH, KB_ROOT_PATH -# from configs.server_config import SANDBOX_SERVER - role_configs = load_role_configs(AGETN_CONFIGS) chain_configs = load_chain_configs(CHAIN_CONFIGS) @@ -39,20 +38,24 @@ class BasePhase: kb_root_path: str = KB_ROOT_PATH, jupyter_work_path: str = JUPYTER_WORK_PATH, sandbox_server: dict = {}, - embed_config: EmbedConfig = EmbedConfig(), - 
llm_config: LLMConfig = LLMConfig(),
+        embed_config: EmbedConfig = None,
+        llm_config: LLMConfig = None,
         task: Task = None,
         base_phase_config: Union[dict, str] = PHASE_CONFIGS,
         base_chain_config: Union[dict, str] = CHAIN_CONFIGS,
         base_role_config: Union[dict, str] = AGETN_CONFIGS,
+        chains: List[BaseChain] = [],
+        doc_retrieval: Union[BaseRetriever] = None,
+        code_retrieval = None,
+        search_retrieval = None,
         log_verbose: str = "0"
     ) -> None:
         #
         self.phase_name = phase_name
         self.do_summary = False
-        self.do_search = False
-        self.do_code_retrieval = False
-        self.do_doc_retrieval = False
+        self.do_search = search_retrieval is not None
+        self.do_code_retrieval = code_retrieval is not None
+        self.do_doc_retrieval = doc_retrieval is not None
         self.do_tool_retrieval = False
         # memory_pool dont have specific order
         # self.memory_pool = Memory(messages=[])
@@ -62,12 +65,15 @@ class BasePhase:
         self.jupyter_work_path = jupyter_work_path
         self.kb_root_path = kb_root_path
         self.log_verbose = max(os.environ.get("log_verbose", "0"), log_verbose)
-
-        self.message_utils = MessageUtils(None, sandbox_server, jupyter_work_path, embed_config, llm_config, kb_root_path, log_verbose)
+        # TODO: pass these through properly
+        self.doc_retrieval = doc_retrieval
+        self.code_retrieval = code_retrieval
+        self.search_retrieval = search_retrieval
+        self.message_utils = MessageUtils(None, sandbox_server, jupyter_work_path, embed_config, llm_config, kb_root_path, doc_retrieval, code_retrieval, search_retrieval, log_verbose)
         self.global_memory = Memory(messages=[])
         self.phase_memory: List[Memory] = []
         # according phase name to init the phase contains
-        self.chains: List[BaseChain] = self.init_chains(
+        self.chains: List[BaseChain] = chains if chains else self.init_chains(
             phase_name,
             phase_config,
             task=task,
@@ -90,7 +96,9 @@
             kb_root_path=kb_root_path
         )
 
-    def astep(self, query: Message, history: Memory = None) -> Tuple[Message, Memory]:
+    def astep(self, query: Message, history: Memory = None, reinit_memory=False) -> Tuple[Message, Memory]:
+        if reinit_memory:
+            self.memory_manager.re_init(reinit_memory)
         self.memory_manager.append(query)
         summary_message = None
         chain_message = Memory(messages=[])
@@ -139,8 +147,8 @@
             message.role_name = self.phase_name
         yield message, local_phase_memory
 
-    def step(self, query: Message, history: Memory = None) -> Tuple[Message, Memory]:
-        for message, local_phase_memory in self.astep(query, history=history):
+    def step(self, query: Message, history: Memory = None, reinit_memory=False) -> Tuple[Message, Memory]:
+        for message, local_phase_memory in self.astep(query, history=history, reinit_memory=reinit_memory):
             pass
         return message, local_phase_memory
 
@@ -194,6 +202,9 @@
             sandbox_server=self.sandbox_server,
             jupyter_work_path=self.jupyter_work_path,
             kb_root_path=self.kb_root_path,
+            doc_retrieval=self.doc_retrieval,
+            code_retrieval=self.code_retrieval,
+            search_retrieval=self.search_retrieval,
             log_verbose=self.log_verbose
         )
         if agent_config.role.agent_type == "SelectorAgent":
@@ -205,7 +216,7 @@
                 group_base_agent = baseAgent(
                     role=group_agent_config.role,
                     prompt_config = group_agent_config.prompt_config,
-                    prompt_manager_type=agent_config.prompt_manager_type,
+                    prompt_manager_type=group_agent_config.prompt_manager_type,
                     task = task,
                     memory = memory,
                     chat_turn=group_agent_config.chat_turn,
@@ -216,6 +227,9 @@
                     sandbox_server=self.sandbox_server,
                     jupyter_work_path=self.jupyter_work_path,
                     kb_root_path=self.kb_root_path,
+                    doc_retrieval=self.doc_retrieval,
+                    code_retrieval=self.code_retrieval,
+                    search_retrieval=self.search_retrieval,
                     log_verbose=self.log_verbose
                 )
                 base_agent.group_agents.append(group_base_agent)
@@ -223,13 +237,16 @@
             agents.append(base_agent)
 
         chain_instance = BaseChain(
-            agents, chain_config.chat_turn,
-            do_checker=chain_configs[chain_name].do_checker,
+            chain_config,
+            agents,
             jupyter_work_path=self.jupyter_work_path,
             sandbox_server=self.sandbox_server,
             embed_config=self.embed_config,
             llm_config=self.llm_config,
            kb_root_path=self.kb_root_path,
+            doc_retrieval=self.doc_retrieval,
+            code_retrieval=self.code_retrieval,
+            search_retrieval=self.search_retrieval,
             log_verbose=self.log_verbose
         )
         chains.append(chain_instance)
diff --git a/coagent/connector/prompt_manager/__init__.py b/coagent/connector/prompt_manager/__init__.py
new file mode 100644
index 0000000..8957e25
--- /dev/null
+++ b/coagent/connector/prompt_manager/__init__.py
@@ -0,0 +1,2 @@
+from .prompt_manager import PromptManager
+from .extend_manager import *
\ No newline at end of file
diff --git a/coagent/connector/prompt_manager/extend_manager.py b/coagent/connector/prompt_manager/extend_manager.py
new file mode 100644
index 0000000..69cc449
--- /dev/null
+++ b/coagent/connector/prompt_manager/extend_manager.py
@@ -0,0 +1,45 @@
+
+from coagent.connector.schema import Message
+from .prompt_manager import PromptManager
+
+
+class Code2DocPM(PromptManager):
+    def handle_code_snippet(self, **kwargs) -> str:
+        if 'previous_agent_message' not in kwargs:
+            return ""
+        previous_agent_message: Message = kwargs['previous_agent_message']
+        code_snippet = previous_agent_message.customed_kargs.get("Code Snippet", "")
+        current_vertex = previous_agent_message.customed_kargs.get("Current_Vertex", "")
+        instruction = "A segment of code that contains the function or method to be documented.\n"
+        return instruction + "\n" + f"name: {current_vertex}\n{code_snippet}"
+
+    def handle_specific_objective(self, **kwargs) -> str:
+        if 'previous_agent_message' not in kwargs:
+            return ""
+        previous_agent_message: Message = kwargs['previous_agent_message']
+        specific_objective = previous_agent_message.parsed_output.get("Code Path")
+
+        instruction = "Provide the code path of the function or method you wish to document.\n"
+        s = instruction + f"\n{specific_objective}"
+        return s
+
+
+class CodeRetrievalPM(PromptManager):
+    def handle_code_snippet(self, **kwargs) -> str:
+        if 'previous_agent_message' not in kwargs:
+            return ""
+        previous_agent_message: Message = kwargs['previous_agent_message']
+        code_snippet = previous_agent_message.customed_kargs.get("Code Snippet", "")
+        current_vertex = previous_agent_message.customed_kargs.get("Current_Vertex", "")
+        instruction = "the initial Code or objective that the user wanted to achieve"
+        return instruction + "\n" + f"name: {current_vertex}\n{code_snippet}"
+
+    def handle_retrieval_codes(self, **kwargs) -> str:
+        if 'previous_agent_message' not in kwargs:
+            return ""
+        previous_agent_message: Message = kwargs['previous_agent_message']
+        Retrieval_Codes = previous_agent_message.customed_kargs["Retrieval_Codes"]
+        Relative_vertex = previous_agent_message.customed_kargs["Relative_vertex"]
+        instruction = "the interrelated pieces of code sourced from the codebase, which support or influence the primary Code Snippet"
+        s = instruction + "\n" + "\n".join([f"name: {vertext}\n{code}" for vertext, code in zip(Relative_vertex, Retrieval_Codes)])
+        return s
diff --git a/coagent/connector/prompt_manager/prompt_manager.py
b/coagent/connector/prompt_manager/prompt_manager.py
new file mode 100644
index 0000000..57d454a
--- /dev/null
+++ b/coagent/connector/prompt_manager/prompt_manager.py
@@ -0,0 +1,353 @@
+import random
+from textwrap import dedent
+import copy
+from loguru import logger
+
+from langchain.agents.tools import Tool
+
+from coagent.connector.schema import Memory, Message
+from coagent.connector.utils import extract_section, parse_section
+
+
+class PromptManager:
+    def __init__(self, role_prompt="", prompt_config=None, monitored_agents=[], monitored_fields=[]):
+        self.role_prompt = role_prompt
+        self.monitored_agents = monitored_agents
+        self.monitored_fields = monitored_fields
+        self.field_handlers = {}
+        self.context_handlers = {}
+        self.field_order = []  # ordering of the regular fields
+        self.context_order = []  # ordering of the context fields, tracked separately
+        self.field_descriptions = {}
+        self.omit_if_empty_flags = {}
+        self.context_title = "### Context Data\n\n"
+
+        self.prompt_config = prompt_config
+        if self.prompt_config:
+            self.register_fields_from_config()
+
+    def register_field(self, field_name, function=None, title=None, description=None, is_context=True, omit_if_empty=True):
+        """
+        Register a new field and its handler function.
+        Args:
+            field_name (str): name of the field.
+            function (callable): function that renders the field's data.
+            title (str, optional): custom title for the field.
+            description (str, optional): description of the field (may be a few sentences).
+            is_context (bool, optional): whether this field is a context field.
+            omit_if_empty (bool, optional): whether to omit the field when its data is empty.
+        """
+        if not function:
+            function = self.handle_custom_data
+
+        # Register the handler function based on context flag (re-registered with ordering below)
+        if is_context:
+            self.context_handlers[field_name] = function
+        else:
+            self.field_handlers[field_name] = function
+
+        # Store the custom title if provided and adjust the title prefix based on context
+        title_prefix = "####" if is_context else "###"
+        if title is not None:
+            self.field_descriptions[field_name] = f"{title_prefix} {title}\n\n"
+        elif description is not None:
+            # If title is not provided but description is, use description as title
+            self.field_descriptions[field_name] = f"{title_prefix} {field_name.replace('_', ' ').title()}\n\n{description}\n\n"
+        else:
+            # If neither title nor description is provided, use the field name as title
+            self.field_descriptions[field_name] = f"{title_prefix} {field_name.replace('_', ' ').title()}\n\n"
+
+        # Store the omit_if_empty flag for this field
+        self.omit_if_empty_flags[field_name] = omit_if_empty
+
+        if is_context and field_name != 'context_placeholder':
+            self.context_handlers[field_name] = function
+            self.context_order.append(field_name)
+        else:
+            self.field_handlers[field_name] = function
+            self.field_order.append(field_name)
+
+    def generate_full_prompt(self, **kwargs):
+        full_prompt = []
+        context_prompts = []  # collects the rendered context sections
+        is_pre_print = kwargs.get("is_pre_print", False)  # force-print every prompt field, empty or not
+
+        # render the context fields first
+        for field_name in self.context_order:
+            handler = self.context_handlers[field_name]
+            processed_prompt = handler(**kwargs)
+            # Check if the field should be omitted when empty
+            if self.omit_if_empty_flags.get(field_name, False) and not processed_prompt and not is_pre_print:
+                continue  # Skip this field
+            title_or_description = self.field_descriptions.get(field_name, f"#### {field_name.replace('_', ' ').title()}\n\n")
+            context_prompts.append(title_or_description + processed_prompt + '\n\n')
+
+        # render the regular fields, watching for the context_placeholder position
+        for field_name in self.field_order:
+            if field_name == 'context_placeholder':
+                # insert the context data at the placeholder position
+                full_prompt.append(self.context_title)  # add the heading of the context section
+                full_prompt.extend(context_prompts)  # splice in the collected context sections
+            else:
+                handler = self.field_handlers[field_name]
+                processed_prompt = handler(**kwargs)
+                # Check if the field should be omitted when empty
+                if self.omit_if_empty_flags.get(field_name, False) and not processed_prompt and not is_pre_print:
+                    continue  # Skip this field
+                title_or_description = self.field_descriptions.get(field_name, f"### {field_name.replace('_', ' ').title()}\n\n")
+                full_prompt.append(title_or_description + processed_prompt + '\n\n')
+
+        # return the full prompt, stripping trailing blank lines
+        return ''.join(full_prompt).rstrip('\n')
+
+    def pre_print(self, **kwargs):
+        kwargs.update({"is_pre_print": True})
+        prompt = self.generate_full_prompt(**kwargs)
+
+        input_keys = parse_section(self.role_prompt, 'Response Output Format')
+        llm_predict = "\n".join([f"**{k}:**" for k in input_keys])
+        return prompt + "\n\n" + "#"*19 + "\n<<<>>>\n" + "#"*19 + f"\n\n{llm_predict}\n"
+
+    def handle_custom_data(self, **kwargs):
+        return ""
+
+    def handle_tool_data(self, **kwargs):
+        if 'previous_agent_message' not in kwargs:
+            return ""
+
+        previous_agent_message = kwargs.get('previous_agent_message')
+        tools: list[Tool] = previous_agent_message.tools
+
+        if not tools:
+            return ""
+
+        tool_strings = []
+        for tool in tools:
+            args_str = f'args: {str(tool.args)}' if tool.args_schema else ""
+            tool_strings.append(f"{tool.name}: {tool.description}, {args_str}")
+        formatted_tools = "\n".join(tool_strings)
+
+        tool_names = ", ".join([tool.name for tool in tools])
+
+        tool_prompt = dedent(f"""
+Below is a list of tools that are available for your use:
+{formatted_tools}
+
+Valid "tool_name" values are:
+{tool_names}
+""")
+
+        return tool_prompt
+
+    def handle_agent_data(self, **kwargs):
+        if 'agents' not in kwargs:
+            return ""
+
+        agents = kwargs.get('agents')
+        random.shuffle(agents)
+        agent_names = ", ".join([f'{agent.role.role_name}' for agent in agents])
+        agent_descs = []
+        for agent in agents:
+            role_desc = agent.role.role_prompt.split("####")[1]
+            while "\n\n" in role_desc:
+                role_desc = role_desc.replace("\n\n", "\n")
+            role_desc = role_desc.replace("\n", ",")
+
+            agent_descs.append(f'"role name: {agent.role.role_name}\nrole description: {role_desc}"')
+
+        agents = "\n".join(agent_descs)
+        agent_prompt = f'''
+        Please ensure your selection is one of the listed roles.
+        Available roles for selection:
+        {agents}
+        Please ensure you select one role from the agent names above, such as {agent_names}'''
+
+        return dedent(agent_prompt)
+
+    def handle_doc_info(self, **kwargs) -> str:
+        if 'previous_agent_message' not in kwargs:
+            return ""
+        previous_agent_message: Message = kwargs.get('previous_agent_message')
+        db_docs = previous_agent_message.db_docs
+        search_docs = previous_agent_message.search_docs
+        code_docs = previous_agent_message.code_docs
+        doc_infos = "\n".join([doc.get_snippet() for doc in db_docs] + [doc.get_snippet() for doc in search_docs]
+                              + [doc.get_code() for doc in code_docs])
+        return doc_infos
+
+    def handle_session_records(self, **kwargs) -> str:
+
+        memory_pool: Memory = kwargs.get('memory_pool', Memory(messages=[]))
+        memory_pool = self.select_memory_by_agent_name(memory_pool)
+        memory_pool = self.select_memory_by_parsed_key(memory_pool)
+
+        return memory_pool.to_str_messages(content_key="parsed_output_list", with_tag=True)
+
+    def handle_current_plan(self, **kwargs) -> str:
+        if 'previous_agent_message' not in kwargs:
+            return ""
+        previous_agent_message = kwargs['previous_agent_message']
+        return previous_agent_message.parsed_output.get("CURRENT_STEP", "")
+
+    def handle_agent_profile(self, **kwargs) -> str:
+        return extract_section(self.role_prompt, 'Agent Profile')
+
+    def handle_output_format(self, **kwargs) -> str:
+        return extract_section(self.role_prompt, 'Response Output Format')
+
+    def handle_response(self, **kwargs) -> str:
+        if 'react_memory' not in kwargs:
+            return ""
+
+        react_memory = kwargs.get('react_memory', Memory(messages=[]))
+        if react_memory is None:
+            return ""
+
+        return "\n".join(["\n".join([f"**{k}:**\n{v}" for k,v in _dict.items()]) for _dict in react_memory.get_parserd_output()])
+
+    def handle_task_records(self, **kwargs) -> str:
+        if 'task_memory' not in kwargs:
+            return ""
+
+        task_memory: Memory = kwargs.get('task_memory', Memory(messages=[]))
+        if task_memory is None:
+            return ""
+
+        return "\n".join(["\n".join([f"**{k}:**\n{v}" for k,v in _dict.items() if k not in ["CURRENT_STEP"]]) for _dict in task_memory.get_parserd_output()])
+
+    def handle_previous_message(self, message: Message) -> str:
+        pass
+
+    def handle_message_by_role_name(self, message: Message) -> str:
+        pass
+
+    def handle_message_by_role_type(self, message: Message) -> str:
+        pass
+
+    def handle_current_agent_react_message(self, message: Message) -> str:
+        pass
+
+    def extract_codedoc_info_for_prompt(self, message: Message) -> str:
+        code_docs = message.code_docs
+        doc_infos = "\n".join([doc.get_code() for doc in code_docs])
+        return doc_infos
+
+    def select_memory_by_parsed_key(self, memory: Memory) -> Memory:
+        return Memory(
+            messages=[self.select_message_by_parsed_key(message) for message in memory.messages
+                      if self.select_message_by_parsed_key(message) is not None]
+        )
+
+    def select_memory_by_agent_name(self, memory: Memory) -> Memory:
+        return Memory(
+            messages=[self.select_message_by_agent_name(message) for message in memory.messages
+                      if self.select_message_by_agent_name(message) is not None]
+        )
+
+    def select_message_by_agent_name(self, message: Message) -> Message:
+        # an empty monitored_agents list means every agent is in scope
+        if self.monitored_agents == []:
+            return message
+        return None if message is None or message.role_name not in self.monitored_agents else self.select_message_by_parsed_key(message)
+
+    def select_message_by_parsed_key(self, message: Message) -> Message:
+        # an empty monitored_fields list means every parsed-output key is kept
+        if message is None:
+            return message
+
+        if self.monitored_fields == []:
+            return message
+
+        message_c = copy.deepcopy(message)
+        message_c.parsed_output = {k: v for k,v in message_c.parsed_output.items() if k in self.monitored_fields}
+        message_c.parsed_output_list = [{k: v for k,v in parsed_output.items() if k in self.monitored_fields} for parsed_output in message_c.parsed_output_list]
+        return message_c
+
+    def get_memory(self, content_key="step_content"):
+        return self.memory.to_tuple_messages(content_key=content_key)  # NOTE: relies on a `memory` attribute attached by the caller
+
+    def get_memory_str(self, content_key="step_content"):
+        return "\n".join([": ".join(i) for i in self.memory.to_tuple_messages(content_key=content_key)])
+
+    def register_fields_from_config(self):
+
+        for prompt_field in self.prompt_config:
+
+            function_name = prompt_field.function_name
+            # check whether function_name is a method on self
+            if function_name and hasattr(self, function_name):
+                function = getattr(self, function_name)
+            else:
+                function = self.handle_custom_data
+
+            self.register_field(prompt_field.field_name,
+                                function=function,
+                                title=prompt_field.title,
+                                description=prompt_field.description,
+                                is_context=prompt_field.is_context,
+                                omit_if_empty=prompt_field.omit_if_empty)
+
+    def register_standard_fields(self):
+        self.register_field('agent_profile', function=self.handle_agent_profile, is_context=False)
+        self.register_field('tool_information', function=self.handle_tool_data, is_context=False)
+        self.register_field('context_placeholder', is_context=True)  # marks where the context data section is inserted
+        self.register_field('reference_documents', function=self.handle_doc_info, is_context=True)
+        self.register_field('session_records', function=self.handle_session_records, is_context=True)
+        self.register_field('output_format', function=self.handle_output_format, title='Response Output Format', is_context=False)
+        self.register_field('response', function=self.handle_response, is_context=False, omit_if_empty=False)
+
+    def register_executor_fields(self):
+        self.register_field('agent_profile', function=self.handle_agent_profile, is_context=False)
+        self.register_field('tool_information', function=self.handle_tool_data, is_context=False)
+        self.register_field('context_placeholder', is_context=True)  # marks where the context data section is inserted
+        self.register_field('reference_documents', function=self.handle_doc_info, is_context=True)
+        self.register_field('session_records', function=self.handle_session_records, is_context=True)
+        self.register_field('current_plan', function=self.handle_current_plan, is_context=True)
+        self.register_field('output_format', function=self.handle_output_format, title='Response Output Format', is_context=False)
+        self.register_field('response', function=self.handle_response, is_context=False, omit_if_empty=False)
+
+    def register_fields_from_dict(self, fields_dict):
+        # register fields from a dict config
+        for field_name, field_config in fields_dict.items():
+            function_name = field_config.get('function', None)
+            title = field_config.get('title', None)
+            description = field_config.get('description', None)
+            is_context = field_config.get('is_context', True)
+            omit_if_empty = field_config.get('omit_if_empty', True)
+
+            # check whether function_name is a method on self
+            if function_name and hasattr(self, function_name):
+                function = getattr(self, function_name)
+            else:
+                function = self.handle_custom_data
+
+            # delegate to the existing register_field method
+            self.register_field(field_name, function=function, title=title, description=description, is_context=is_context, omit_if_empty=omit_if_empty)
+
+
+def main():
+    manager = PromptManager()
+    manager.register_standard_fields()
+
+    manager.register_field('agents_work_progress', title=f"Agents' Work Progress", is_context=True)
+
+    # build the demo data dict
+    data_dict = {
+        "agent_profile": "This is the agent profile...",
+        # "tool_list": "This is the tool list...",
+        "reference_documents": "These are the reference documents...",
+        "session_records": "These are the session records...",
+        "agents_work_progress": "This is the agents' work progress...",
+        "output_format": "This is the expected output format...",
+        # "response": "This is the instruction to generate or continue a response...",
+        "response": "",
+        "test": 'xxxxx'
+    }
+
+    # assemble the full prompt
+    full_prompt = manager.generate_full_prompt(**data_dict)  # expand the dict: generate_full_prompt takes keyword arguments
+    print(full_prompt)
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
diff --git a/coagent/connector/schema/general_schema.py b/coagent/connector/schema/general_schema.py
index 9da753c..c45aa5e 100644
--- a/coagent/connector/schema/general_schema.py
+++ b/coagent/connector/schema/general_schema.py
@@ -215,15 +215,15 @@ class Env(BaseModel):
 class Role(BaseModel):
     role_type: str
     role_name: str
-    role_desc: str
-    agent_type: str = ""
+    role_desc: str = ""
+    agent_type: str = "BaseAgent"
     role_prompt: str = ""
     template_prompt: str = ""
 
 class ChainConfig(BaseModel):
     chain_name: str
-    chain_type: str
+    chain_type: str = "BaseChain"
     agents: List[str]
     do_checker: bool = False
     chat_turn: int = 1
diff --git a/coagent/connector/schema/memory.py b/coagent/connector/schema/memory.py
index bba4c43..07b92fd 100644
--- a/coagent/connector/schema/memory.py
+++ b/coagent/connector/schema/memory.py
@@ -131,6 +131,9 @@ class Memory(BaseModel):
         # logger.debug(f"{message.role_name}: {message.parsed_output_list}")
         # return [parsed_output for message in self.messages for parsed_output in message.parsed_output_list[1:]]
         return [parsed_output for message in self.messages for parsed_output in message.parsed_output_list]
+
+    def get_spec_parserd_output(self, ):
+        return [message.spec_parsed_output for message in self.messages]
 
     def get_rolenames(self, ):
         ''''''
diff --git a/coagent/connector/schema/message.py b/coagent/connector/schema/message.py
index c306ba5..92f9658 100644
--- a/coagent/connector/schema/message.py
+++ b/coagent/connector/schema/message.py
@@ -7,6 +7,7 @@ from .general_schema import *
 
 class Message(BaseModel):
     chat_index: str = None
+    user_name: str = "default"
     role_name: str
     role_type: str
     role_prompt: str = None
@@ -53,6 +54,8 @@
     cb_search_type: str = None
     search_engine_name: str = None
     top_k: int = 3
+    use_nh: bool = True
+    local_graph_path: str = ''
     score_threshold: float = 1.0
     do_doc_retrieval: bool = False
     do_code_retrieval: bool = False
diff --git a/coagent/connector/utils.py b/coagent/connector/utils.py
index d3dfb61..f3f8bfb 100644
--- a/coagent/connector/utils.py
+++ b/coagent/connector/utils.py
@@ -72,20 +72,25 @@ def parse_text_to_dict(text):
 
 def parse_dict_to_dict(parsed_dict) -> dict:
     code_pattern = r'```python\n(.*?)```'
     tool_pattern = r'```json\n(.*?)```'
+    java_pattern = r'```java\n(.*?)```'
 
-    pattern_dict = {"code": code_pattern, "json": tool_pattern}
+    pattern_dict = {"code": code_pattern, "json": tool_pattern, "java": java_pattern}
     spec_parsed_dict = copy.deepcopy(parsed_dict)
     for key, pattern in pattern_dict.items():
         for k, text in parsed_dict.items():
             # Search for the code block
-            if not isinstance(text, str): continue
+            if not isinstance(text, str):
+                spec_parsed_dict[k] = text
+                continue
             _match = re.search(pattern, text, re.DOTALL)
             if _match:
                 # Add the code block to the dictionary
                 try:
                     spec_parsed_dict[key] = json.loads(_match.group(1).strip())
+                    spec_parsed_dict[k] = json.loads(_match.group(1).strip())
                 except:
                     spec_parsed_dict[key] = _match.group(1).strip()
+                    spec_parsed_dict[k] = _match.group(1).strip()
                 break
     return spec_parsed_dict
diff --git a/coagent/db_handler/graph_db_handler/nebula_handler.py b/coagent/db_handler/graph_db_handler/nebula_handler.py
index e5f4cde..062b4e7 100644
--- a/coagent/db_handler/graph_db_handler/nebula_handler.py
+++ b/coagent/db_handler/graph_db_handler/nebula_handler.py
@@ -43,7 +43,7 @@ class NebulaHandler:
         elif self.space_name:
             cypher = f'USE {self.space_name};{cypher}'
 
-        logger.debug(cypher)
+        # logger.debug(cypher)
         resp = session.execute(cypher)
 
         if format_res:
@@ -247,6 +247,24 @@ class NebulaHandler:
         res = self.execute_cypher(cypher, self.space_name)
         return self.result_to_dict(res)
 
+    def get_all_vertices(self,):
+        '''
+        get all vertices
+        @return:
+        '''
+        cypher = "MATCH (v) RETURN v;"
+        res = self.execute_cypher(cypher, self.space_name)
+        return self.result_to_dict(res)
+
+    def get_relative_vertices(self, vertice):
+        '''
+        get vertices adjacent to the given vertex
+        @return:
+        '''
+        cypher = f'''MATCH (v1)--(v2) WHERE id(v1) == '{vertice}' RETURN id(v2) as id;'''
+        res = self.execute_cypher(cypher, self.space_name)
+        return self.result_to_dict(res)
+
     def result_to_dict(self, result) -> dict:
         """
         build list for each column, and transform to dataframe
diff --git a/coagent/embeddings/faiss_m.py b/coagent/embeddings/faiss_m.py
index 17a910d..31d124c 100644
--- a/coagent/embeddings/faiss_m.py
+++ b/coagent/embeddings/faiss_m.py
@@ -6,6 +6,7 @@ import os
 import pickle
 import uuid
 import warnings
+from enum import Enum
 from pathlib import Path
 from typing import (
     Any,
@@ -22,10 +23,22 @@ import numpy as np
 from langchain.docstore.base import AddableMixin, Docstore
 from langchain.docstore.document import Document
-from langchain.docstore.in_memory import InMemoryDocstore
+# from langchain.docstore.in_memory import InMemoryDocstore
+from .in_memory import InMemoryDocstore
 from langchain.embeddings.base import Embeddings
 from langchain.vectorstores.base import VectorStore
-from langchain.vectorstores.utils import DistanceStrategy, maximal_marginal_relevance
+from langchain.vectorstores.utils import maximal_marginal_relevance
+
+
+class DistanceStrategy(str, Enum):
+    """Enumerator of the Distance strategies for calculating distances
+    between vectors."""
+
+    EUCLIDEAN_DISTANCE = "EUCLIDEAN_DISTANCE"
+    MAX_INNER_PRODUCT = "MAX_INNER_PRODUCT"
+    DOT_PRODUCT = "DOT_PRODUCT"
+    JACCARD = "JACCARD"
+    COSINE = "COSINE"
 
 
 def dependable_faiss_import(no_avx2: Optional[bool] = None) -> Any:
@@ -219,6 +232,9 @@ class FAISS(VectorStore):
         if self._normalize_L2:
             faiss.normalize_L2(vector)
         scores, indices = self.index.search(vector, k if filter is None else fetch_k)
+        # scores can exceed 1 after L2 normalization, so renormalize each row
+        if self._normalize_L2:
+            scores = np.array([row / np.linalg.norm(row) if np.max(row) > 1 else row for row in scores])
         docs = []
         for j, i in enumerate(indices[0]):
             if i == -1:
@@ -565,7 +581,7 @@ class FAISS(VectorStore):
         vecstore = cls(
             embedding.embed_query,
             index,
-            InMemoryDocstore(),
+            InMemoryDocstore({}),
             {},
             normalize_L2=normalize_L2,
             distance_strategy=distance_strategy,
diff --git a/coagent/embeddings/get_embedding.py b/coagent/embeddings/get_embedding.py
index 3ea45ab..34382b8 100644
--- a/coagent/embeddings/get_embedding.py
+++ b/coagent/embeddings/get_embedding.py
@@ -10,13 +10,14 @@ from loguru import logger
 # from configs.model_config import EMBEDDING_MODEL
 from coagent.embeddings.openai_embedding import OpenAIEmbedding
 from coagent.embeddings.huggingface_embedding import HFEmbedding
-
+from coagent.llm_models.llm_config import EmbedConfig
 
 def get_embedding(
engine: str, text_list: list, model_path: str = "text2vec-base-chinese", embedding_device: str = "cpu", + embed_config: EmbedConfig = None, ): ''' get embedding @@ -25,8 +26,12 @@ def get_embedding( @return: ''' emb_res = {} - - if engine == 'openai': + if embed_config and embed_config.langchain_embeddings: + emb_res = embed_config.langchain_embeddings.embed_documents(text_list) + emb_res = { + text_list[idx]: emb_res[idx] for idx in range(len(text_list)) + } + elif engine == 'openai': oae = OpenAIEmbedding() emb_res = oae.get_emb(text_list) elif engine == 'model': diff --git a/coagent/embeddings/in_memory.py b/coagent/embeddings/in_memory.py new file mode 100644 index 0000000..f92484d --- /dev/null +++ b/coagent/embeddings/in_memory.py @@ -0,0 +1,49 @@ +"""Simple in memory docstore in the form of a dict.""" +from typing import Dict, List, Optional, Union + +from langchain.docstore.base import AddableMixin, Docstore +from langchain.docstore.document import Document + + +class InMemoryDocstore(Docstore, AddableMixin): + """Simple in memory docstore in the form of a dict.""" + + def __init__(self, _dict: Optional[Dict[str, Document]] = None): + """Initialize with dict.""" + self._dict = _dict if _dict is not None else {} + + def add(self, texts: Dict[str, Document]) -> None: + """Add texts to in memory dictionary. + + Args: + texts: dictionary of id -> document. + + Returns: + None + """ + overlapping = set(texts).intersection(self._dict) + if overlapping: + raise ValueError(f"Tried to add ids that already exist: {overlapping}") + self._dict = {**self._dict, **texts} + + def delete(self, ids: List) -> None: + """Deleting IDs from in memory dictionary.""" + overlapping = set(ids).intersection(self._dict) + if not overlapping: + raise ValueError(f"Tried to delete ids that does not exist: {ids}") + for _id in ids: + self._dict.pop(_id) + + def search(self, search: str) -> Union[str, Document]: + """Search via direct lookup. + + Args: + search: id of a document to search for. + + Returns: + Document if found, else error message. + """ + if search not in self._dict: + return f"ID {search} not found." 
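+            # note: a miss returns an error string rather than raising, mirroring langchain's InMemoryDocstore behavior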
+ else: + return self._dict[search] diff --git a/coagent/embeddings/utils.py b/coagent/embeddings/utils.py index 8b98f27..25088b1 100644 --- a/coagent/embeddings/utils.py +++ b/coagent/embeddings/utils.py @@ -1,6 +1,8 @@ import os from functools import lru_cache from langchain.embeddings.huggingface import HuggingFaceEmbeddings +from langchain.embeddings.base import Embeddings + # from configs.model_config import embedding_model_dict from loguru import logger @@ -12,8 +14,11 @@ def load_embeddings(model: str, device: str, embedding_model_dict: dict): return embeddings -@lru_cache(1) -def load_embeddings_from_path(model_path: str, device: str): +# @lru_cache(1) +def load_embeddings_from_path(model_path: str, device: str, langchain_embeddings: Embeddings = None): + if langchain_embeddings: + return langchain_embeddings + embeddings = HuggingFaceEmbeddings(model_name=model_path, model_kwargs={'device': device}) return embeddings diff --git a/coagent/llm_models/__init__.py b/coagent/llm_models/__init__.py index 606e061..11bbe7a 100644 --- a/coagent/llm_models/__init__.py +++ b/coagent/llm_models/__init__.py @@ -1,8 +1,8 @@ -from .openai_model import getChatModel, getExtraModel, getChatModelFromConfig +from .openai_model import getExtraModel, getChatModelFromConfig from .llm_config import LLMConfig, EmbedConfig __all__ = [ - "getChatModel", "getExtraModel", "getChatModelFromConfig", + "getExtraModel", "getChatModelFromConfig", "LLMConfig", "EmbedConfig" ] \ No newline at end of file diff --git a/coagent/llm_models/llm_config.py b/coagent/llm_models/llm_config.py index 389290f..9dac682 100644 --- a/coagent/llm_models/llm_config.py +++ b/coagent/llm_models/llm_config.py @@ -1,6 +1,9 @@ from dataclasses import dataclass from typing import List, Union +from langchain.embeddings.base import Embeddings +from langchain.llms.base import LLM, BaseLLM + @dataclass @@ -12,7 +15,8 @@ class LLMConfig: stop: Union[List[str], str] = None, api_key: str = "", api_base_url: str = "", - model_device: str = "cpu", + model_device: str = "cpu", # unuse,will delete it + llm: LLM = None, **kwargs ): @@ -21,7 +25,7 @@ class LLMConfig: self.stop: Union[List[str], str] = stop self.api_key: str = api_key self.api_base_url: str = api_base_url - self.model_device: str = model_device + self.llm: LLM = llm # self.check_config() @@ -42,6 +46,7 @@ class EmbedConfig: embed_model_path: str = "", embed_engine: str = "", model_device: str = "cpu", + langchain_embeddings: Embeddings = None, **kwargs ): self.embed_model: str = embed_model @@ -51,6 +56,8 @@ class EmbedConfig: self.api_key: str = api_key self.api_base_url: str = api_base_url # + self.langchain_embeddings = langchain_embeddings + # self.check_config() def check_config(self, ): diff --git a/coagent/llm_models/openai_model.py b/coagent/llm_models/openai_model.py index 381e512..50453c9 100644 --- a/coagent/llm_models/openai_model.py +++ b/coagent/llm_models/openai_model.py @@ -1,38 +1,54 @@ import os +from typing import Union, Optional, List +from loguru import logger from langchain.callbacks import AsyncIteratorCallbackHandler from langchain.chat_models import ChatOpenAI +from langchain.llms.base import LLM from .llm_config import LLMConfig # from configs.model_config import (llm_model_dict, LLM_MODEL) -def getChatModel(callBack: AsyncIteratorCallbackHandler = None, temperature=0.3, stop=None): - if callBack is None: +class CustomLLMModel: + + def __init__(self, llm: LLM): + self.llm: LLM = llm + + def __call__(self, prompt: str, + stop: Optional[List[str]] = None): + 
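# delegate to the wrapped LangChain LLM; `stop` is forwarded positionally to LLM.__call__
+        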
return self.llm(prompt, stop) + + def _call(self, prompt: str, + stop: Optional[List[str]] = None): + return self.llm(prompt, stop) + + def predict(self, prompt: str, + stop: Optional[List[str]] = None): + return self.llm(prompt, stop) + + def batch(self, prompts: str, + stop: Optional[List[str]] = None): + return [self.llm(prompt, stop) for prompt in prompts] + + +def getChatModelFromConfig(llm_config: LLMConfig, callBack: AsyncIteratorCallbackHandler = None, ) -> Union[ChatOpenAI, LLM]: + # logger.debug(f"llm type is {type(llm_config.llm)}") + if llm_config is None: model = ChatOpenAI( streaming=True, verbose=True, - openai_api_key=llm_model_dict[LLM_MODEL]["api_key"], - openai_api_base=llm_model_dict[LLM_MODEL]["api_base_url"], - model_name=LLM_MODEL, - temperature=temperature, - stop=stop + openai_api_key=os.environ.get("api_key"), + openai_api_base=os.environ.get("api_base_url"), + model_name=os.environ.get("LLM_MODEL", "gpt-3.5-turbo"), + temperature=os.environ.get("temperature", 0.5), + stop=os.environ.get("stop", ""), ) - else: - model = ChatOpenAI( - streaming=True, - verbose=True, - callBack=[callBack], - openai_api_key=llm_model_dict[LLM_MODEL]["api_key"], - openai_api_base=llm_model_dict[LLM_MODEL]["api_base_url"], - model_name=LLM_MODEL, - temperature=temperature, - stop=stop - ) - return model + return model - -def getChatModelFromConfig(llm_config: LLMConfig, callBack: AsyncIteratorCallbackHandler = None, ): + if llm_config and llm_config.llm and isinstance(llm_config.llm, LLM): + return CustomLLMModel(llm=llm_config.llm) + if callBack is None: model = ChatOpenAI( streaming=True, diff --git a/coagent/retrieval/__init__.py b/coagent/retrieval/__init__.py new file mode 100644 index 0000000..193bb2e --- /dev/null +++ b/coagent/retrieval/__init__.py @@ -0,0 +1,5 @@ +# from .base_retrieval import * + +# __all__ = [ +# "IMRertrieval", "BaseDocRetrieval", "BaseCodeRetrieval", "BaseSearchRetrieval" +# ] \ No newline at end of file diff --git a/coagent/retrieval/base_retrieval.py b/coagent/retrieval/base_retrieval.py new file mode 100644 index 0000000..79279c6 --- /dev/null +++ b/coagent/retrieval/base_retrieval.py @@ -0,0 +1,75 @@ + +from coagent.llm_models.llm_config import EmbedConfig, LLMConfig +from coagent.base_configs.env_config import KB_ROOT_PATH +from coagent.tools import DocRetrieval, CodeRetrieval + + +class IMRertrieval: + + def __init__(self,): + ''' + init your personal attributes + ''' + pass + + def run(self, ): + ''' + execute interface, and can use init' attributes + ''' + pass + + +class BaseDocRetrieval(IMRertrieval): + + def __init__(self, knowledge_base_name: str, search_top=5, score_threshold=1.0, embed_config: EmbedConfig=EmbedConfig(), kb_root_path: str=KB_ROOT_PATH): + self.knowledge_base_name = knowledge_base_name + self.search_top = search_top + self.score_threshold = score_threshold + self.embed_config = embed_config + self.kb_root_path = kb_root_path + + def run(self, query: str, search_top=None, score_threshold=None, ): + docs = DocRetrieval.run( + query=query, knowledge_base_name=self.knowledge_base_name, + search_top=search_top or self.search_top, + score_threshold=score_threshold or self.score_threshold, + embed_config=self.embed_config, + kb_root_path=self.kb_root_path + ) + return docs + + +class BaseCodeRetrieval(IMRertrieval): + + def __init__(self, code_base_name, embed_config: EmbedConfig, llm_config: LLMConfig, search_type = 'tag', code_limit = 1, local_graph_path: str=""): + self.code_base_name = code_base_name + self.embed_config = 
embed_config + self.llm_config = llm_config + self.search_type = search_type + self.code_limit = code_limit + self.use_nh: bool = False + self.local_graph_path: str = local_graph_path + + def run(self, query, history_node_list=[], search_type = None, code_limit=None): + code_docs = CodeRetrieval.run( + code_base_name=self.code_base_name, + query=query, + history_node_list=history_node_list, + code_limit=code_limit or self.code_limit, + search_type=search_type or self.search_type, + llm_config=self.llm_config, + embed_config=self.embed_config, + use_nh=self.use_nh, + local_graph_path=self.local_graph_path + ) + return code_docs + + + +class BaseSearchRetrieval(IMRertrieval): + + def __init__(self, ): + pass + + def run(self, ): + pass diff --git a/coagent/retrieval/document_loaders/__init__.py b/coagent/retrieval/document_loaders/__init__.py new file mode 100644 index 0000000..2343aeb --- /dev/null +++ b/coagent/retrieval/document_loaders/__init__.py @@ -0,0 +1,6 @@ +from .json_loader import JSONLoader +from .jsonl_loader import JSONLLoader + +__all__ = [ + "JSONLoader", "JSONLLoader" +] \ No newline at end of file diff --git a/coagent/retrieval/document_loaders/json_loader.py b/coagent/retrieval/document_loaders/json_loader.py new file mode 100644 index 0000000..4e5ecd3 --- /dev/null +++ b/coagent/retrieval/document_loaders/json_loader.py @@ -0,0 +1,61 @@ +import json +from pathlib import Path +from typing import AnyStr, Callable, Dict, List, Optional, Union + +from langchain.docstore.document import Document +from langchain.document_loaders.base import BaseLoader +from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter + +from coagent.utils.common_utils import read_json_file + + +class JSONLoader(BaseLoader): + + def __init__( + self, + file_path: Union[str, Path], + schema_key: str = "all_text", + content_key: Optional[str] = None, + metadata_func: Optional[Callable[[Dict, Dict], Dict]] = None, + text_content: bool = True, + ): + self.file_path = Path(file_path).resolve() + self.schema_key = schema_key + self._content_key = content_key + self._metadata_func = metadata_func + self._text_content = text_content + + def load(self, ) -> List[Document]: + """Load and return documents from the JSON file.""" + docs: List[Document] = [] + datas = read_json_file(self.file_path) + self._parse(datas, docs) + return docs + + def _parse(self, datas: List, docs: List[Document]) -> None: + for idx, sample in enumerate(datas): + metadata = dict( + source=str(self.file_path), + seq_num=idx, + ) + text = sample.get(self.schema_key, "") + docs.append(Document(page_content=text, metadata=metadata)) + + def load_and_split( + self, text_splitter: Optional[TextSplitter] = None + ) -> List[Document]: + """Load Documents and split into chunks. Chunks are returned as Documents. + + Args: + text_splitter: TextSplitter instance to use for splitting documents. + Defaults to RecursiveCharacterTextSplitter. + + Returns: + List of Documents. 
+ """ + if text_splitter is None: + _text_splitter: TextSplitter = RecursiveCharacterTextSplitter() + else: + _text_splitter = text_splitter + docs = self.load() + return _text_splitter.split_documents(docs) \ No newline at end of file diff --git a/coagent/retrieval/document_loaders/jsonl_loader.py b/coagent/retrieval/document_loaders/jsonl_loader.py new file mode 100644 index 0000000..a56e6eb --- /dev/null +++ b/coagent/retrieval/document_loaders/jsonl_loader.py @@ -0,0 +1,62 @@ +import json +from pathlib import Path +from typing import AnyStr, Callable, Dict, List, Optional, Union + +from langchain.docstore.document import Document +from langchain.document_loaders.base import BaseLoader +from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter + +from coagent.utils.common_utils import read_jsonl_file + + +class JSONLLoader(BaseLoader): + + def __init__( + self, + file_path: Union[str, Path], + schema_key: str = "all_text", + content_key: Optional[str] = None, + metadata_func: Optional[Callable[[Dict, Dict], Dict]] = None, + text_content: bool = True, + ): + self.file_path = Path(file_path).resolve() + self.schema_key = schema_key + self._content_key = content_key + self._metadata_func = metadata_func + self._text_content = text_content + + def load(self, ) -> List[Document]: + """Load and return documents from the JSON file.""" + docs: List[Document] = [] + datas = read_jsonl_file(self.file_path) + self._parse(datas, docs) + return docs + + def _parse(self, datas: List, docs: List[Document]) -> None: + for idx, sample in enumerate(datas): + metadata = dict( + source=str(self.file_path), + seq_num=idx, + ) + text = sample.get(self.schema_key, "") + docs.append(Document(page_content=text, metadata=metadata)) + + def load_and_split( + self, text_splitter: Optional[TextSplitter] = None + ) -> List[Document]: + """Load Documents and split into chunks. Chunks are returned as Documents. + + Args: + text_splitter: TextSplitter instance to use for splitting documents. + Defaults to RecursiveCharacterTextSplitter. + + Returns: + List of Documents. 
+ """ + if text_splitter is None: + _text_splitter: TextSplitter = RecursiveCharacterTextSplitter() + else: + _text_splitter = text_splitter + + docs = self.load() + return _text_splitter.split_documents(docs) \ No newline at end of file diff --git a/coagent/retrieval/text_splitter/__init__.py b/coagent/retrieval/text_splitter/__init__.py new file mode 100644 index 0000000..8867d74 --- /dev/null +++ b/coagent/retrieval/text_splitter/__init__.py @@ -0,0 +1,3 @@ +from .langchain_splitter import LCTextSplitter + +__all__ = ["LCTextSplitter"] \ No newline at end of file diff --git a/coagent/retrieval/text_splitter/langchain_splitter.py b/coagent/retrieval/text_splitter/langchain_splitter.py new file mode 100644 index 0000000..4b53025 --- /dev/null +++ b/coagent/retrieval/text_splitter/langchain_splitter.py @@ -0,0 +1,77 @@ +import os +import importlib +from loguru import logger + +from langchain.document_loaders.base import BaseLoader +from langchain.text_splitter import ( + SpacyTextSplitter, RecursiveCharacterTextSplitter +) + +# from configs.model_config import ( +# CHUNK_SIZE, +# OVERLAP_SIZE, +# ZH_TITLE_ENHANCE +# ) +from coagent.utils.path_utils import * + + + +class LCTextSplitter: + '''langchain textsplitter 执行file2text''' + def __init__( + self, filepath: str, text_splitter_name: str = None, + chunk_size: int = 500, + overlap_size: int = 50 + ): + self.filepath = filepath + self.ext = os.path.splitext(filepath)[-1].lower() + self.text_splitter_name = text_splitter_name + self.chunk_size = chunk_size + self.overlap_size = overlap_size + if self.ext not in SUPPORTED_EXTS: + raise ValueError(f"暂未支持的文件格式 {self.ext}") + self.document_loader_name = get_LoaderClass(self.ext) + + def file2text(self, ): + loader = self._load_document() + text_splitter = self._load_text_splitter() + if self.document_loader_name in ["JSONLoader", "JSONLLoader"]: + # docs = loader.load() + docs = loader.load_and_split(text_splitter) + # logger.debug(f"please check your file can be loaded, docs.lens {len(docs)}") + else: + docs = loader.load_and_split(text_splitter) + + return docs + + def _load_document(self, ) -> BaseLoader: + DocumentLoader = EXT2LOADER_DICT[self.ext] + if self.document_loader_name == "UnstructuredFileLoader": + loader = DocumentLoader(self.filepath, autodetect_encoding=True) + else: + loader = DocumentLoader(self.filepath) + return loader + + def _load_text_splitter(self, ): + try: + if self.text_splitter_name is None: + text_splitter = SpacyTextSplitter( + pipeline="zh_core_web_sm", + chunk_size=self.chunk_size, + chunk_overlap=self.overlap_size, + ) + self.text_splitter_name = "SpacyTextSplitter" + # elif self.document_loader_name in ["JSONLoader", "JSONLLoader"]: + # text_splitter = None + else: + text_splitter_module = importlib.import_module('langchain.text_splitter') + TextSplitter = getattr(text_splitter_module, self.text_splitter_name) + text_splitter = TextSplitter( + chunk_size=self.chunk_size, + chunk_overlap=self.overlap_size) + except Exception as e: + text_splitter = RecursiveCharacterTextSplitter( + chunk_size=self.chunk_size, + chunk_overlap=self.overlap_size, + ) + return text_splitter diff --git a/coagent/retrieval/text_splitter/utils.py b/coagent/retrieval/text_splitter/utils.py new file mode 100644 index 0000000..e69de29 diff --git a/coagent/sandbox/pycodebox.py b/coagent/sandbox/pycodebox.py index d74a868..39a4ab9 100644 --- a/coagent/sandbox/pycodebox.py +++ b/coagent/sandbox/pycodebox.py @@ -32,8 +32,8 @@ class PyCodeBox(BaseBox): self.do_check_net = do_check_net 
self.use_stop = use_stop self.jupyter_work_path = jupyter_work_path - asyncio.run(self.astart()) - # self.start() + # asyncio.run(self.astart()) + self.start() # logger.info(f"""remote_url: {self.remote_url}, # remote_ip: {self.remote_ip}, @@ -199,13 +199,13 @@ class PyCodeBox(BaseBox): async def _aget_kernelid(self, ) -> None: headers = {"Authorization": f'Token {self.token}', 'token': self.token} - response = requests.get(f"{self.kernel_url}?token={self.token}", headers=headers) + # response = requests.get(f"{self.kernel_url}?token={self.token}", headers=headers) async with aiohttp.ClientSession() as session: - async with session.get(f"{self.kernel_url}?token={self.token}", headers=headers) as resp: + async with session.get(f"{self.kernel_url}?token={self.token}", headers=headers, timeout=270) as resp: if len(await resp.json()) > 0: self.kernel_id = (await resp.json())[0]["id"] else: - async with session.post(f"{self.kernel_url}?token={self.token}", headers=headers) as response: + async with session.post(f"{self.kernel_url}?token={self.token}", headers=headers, timeout=270) as response: self.kernel_id = (await response.json())["id"] # if len(response.json()) > 0: @@ -220,41 +220,45 @@ class PyCodeBox(BaseBox): return False try: - response = requests.get(f"{self.kernel_url}?token={self.token}", timeout=270) + response = requests.get(f"{self.kernel_url}?token={self.token}", timeout=10) return response.status_code == 200 except requests.exceptions.ConnectionError: return False + except requests.exceptions.ReadTimeout: + return False async def _acheck_connect(self, ) -> bool: if self.kernel_url == "": return False try: async with aiohttp.ClientSession() as session: - async with session.get(f"{self.kernel_url}?token={self.token}", timeout=270) as resp: + async with session.get(f"{self.kernel_url}?token={self.token}", timeout=10) as resp: return resp.status == 200 except aiohttp.ClientConnectorError: - pass + return False except aiohttp.ServerDisconnectedError: - pass + return False def _check_port(self, ) -> bool: try: - response = requests.get(f"{self.remote_ip}:{self.remote_port}", timeout=270) + response = requests.get(f"{self.remote_ip}:{self.remote_port}", timeout=10) logger.warning(f"Port is conflict, please check your codebox's port {self.remote_port}") return response.status_code == 200 except requests.exceptions.ConnectionError: return False + except requests.exceptions.ReadTimeout: + return False async def _acheck_port(self, ) -> bool: try: async with aiohttp.ClientSession() as session: - async with session.get(f"{self.remote_ip}:{self.remote_port}", timeout=270) as resp: + async with session.get(f"{self.remote_ip}:{self.remote_port}", timeout=10) as resp: # logger.warning(f"Port is conflict, please check your codebox's port {self.remote_port}") return resp.status == 200 except aiohttp.ClientConnectorError: - pass + return False except aiohttp.ServerDisconnectedError: - pass + return False def _check_connect_success(self, retry_nums: int = 2) -> bool: if not self.do_check_net: return True @@ -263,7 +267,7 @@ class PyCodeBox(BaseBox): try: connect_status = self._check_connect() if connect_status: - logger.info(f"{self.remote_url} connection success") + # logger.info(f"{self.remote_url} connection success") return True except requests.exceptions.ConnectionError: logger.info(f"{self.remote_url} connection fail") @@ -301,10 +305,12 @@ class PyCodeBox(BaseBox): else: # TODO 自动检测本地接口 port_status = self._check_port() + self.kernel_url = self.remote_url + "/api/kernels" connect_status = 
self._check_connect()
-            logger.info(f"port_status: {port_status}, connect_status: {connect_status}")
+            if os.environ.get("log_verbose", "0") >= "2":
+                logger.info(f"port_status: {port_status}, connect_status: {connect_status}")
             if port_status and not connect_status:
-                raise BaseException(f"Port is conflict, please check your codebox's port {self.remote_port}")
+                logger.error(f"Port conflict, please check your codebox's port {self.remote_port}")
 
             if not connect_status:
                 self.jupyter = subprocess.Popen(
@@ -321,14 +327,32 @@
                     stdout=subprocess.PIPE,
                 )
 
+                record = []
+                while self.jupyter and len(record) < 100:
+                    line = self.jupyter.stderr.readline()
+                    try:
+                        content = line.decode("utf-8")
+                    except UnicodeDecodeError:
+                        content = line.decode("gbk")
+                    # logger.debug(content)
+                    record.append(content)
+                    if "control-c" in content.lower():
+                        break
+
             self.kernel_url = self.remote_url + "/api/kernels"
             self.do_check_net = True
             self._check_connect_success()
             self._get_kernelid()
-        # logger.debug(self.kernel_url.replace("http", "ws") + f"/{self.kernel_id}/channels?token={self.token}")
         self.wc_url = self.kernel_url.replace("http", "ws") + f"/{self.kernel_id}/channels?token={self.token}"
         headers = {"Authorization": f'Token {self.token}', 'token': self.token}
-        self.ws = create_connection(self.wc_url, headers=headers)
+        retry_nums = 3
+        while retry_nums >= 0:
+            try:
+                self.ws = create_connection(self.wc_url, headers=headers, timeout=5)
+                break
+            except Exception as e:
+                logger.error(f"create ws connection failed: {e}")
+                retry_nums -= 1
 
     async def astart(self, ):
         '''判断是从外部service执行还是内部启动notebook执行'''
@@ -369,10 +393,16 @@
                 cwd=self.jupyter_work_path
             )
 
-            while True and self.jupyter:
+            record = []
+            while self.jupyter and len(record) < 100:
                 line = self.jupyter.stderr.readline()
-                # logger.debug(line.decode("gbk"))
-                if "Control-C" in line.decode("gbk"):
+                try:
+                    content = line.decode("utf-8")
+                except UnicodeDecodeError:
+                    content = line.decode("gbk")
+                # logger.debug(content)
+                record.append(content)
+                if "control-c" in content.lower():
                     break
             self.kernel_url = self.remote_url + "/api/kernels"
             self.do_check_net = True
@@ -380,7 +410,15 @@
             await self._aget_kernelid()
         self.wc_url = self.kernel_url.replace("http", "ws") + f"/{self.kernel_id}/channels?token={self.token}"
         headers = {"Authorization": f'Token {self.token}', 'token': self.token}
-        self.ws = create_connection(self.wc_url, headers=headers)
+
+        retry_nums = 3
+        while retry_nums >= 0:
+            try:
+                self.ws = create_connection(self.wc_url, headers=headers, timeout=5)
+                break
+            except Exception as e:
+                logger.error(f"create ws connection failed: {e}")
+                retry_nums -= 1
 
     def status(self,) -> CodeBoxStatus:
         if not self.kernel_id:
diff --git a/coagent/service/base_service.py b/coagent/service/base_service.py
index 739e715..f5000a5 100644
--- a/coagent/service/base_service.py
+++ b/coagent/service/base_service.py
@@ -17,7 +17,7 @@
 from coagent.orm.commands import *
 from coagent.utils.path_utils import *
 from coagent.orm.utils import DocumentFile
 from coagent.embeddings.utils import load_embeddings, load_embeddings_from_path
-from coagent.text_splitter import LCTextSplitter
+from coagent.retrieval.text_splitter import LCTextSplitter
 from coagent.llm_models.llm_config import EmbedConfig
 
@@ -46,7 +46,7 @@ class KBService(ABC):
 
     def _load_embeddings(self) -> Embeddings:
         # return load_embeddings(self.embed_model, embed_device, embedding_model_dict)
-        return load_embeddings_from_path(self.embed_config.embed_model_path, 
self.embed_config.model_device) + return load_embeddings_from_path(self.embed_config.embed_model_path, self.embed_config.model_device, self.embed_config.langchain_embeddings) def create_kb(self): """ diff --git a/coagent/service/cb_api.py b/coagent/service/cb_api.py index 6de3c42..5705644 100644 --- a/coagent/service/cb_api.py +++ b/coagent/service/cb_api.py @@ -20,9 +20,6 @@ from coagent.utils.path_utils import * from coagent.orm.commands import * from coagent.db_handler.graph_db_handler.nebula_handler import NebulaHandler from coagent.db_handler.vector_db_handler.chroma_handler import ChromaHandler -# from configs.server_config import NEBULA_HOST, NEBULA_PORT, NEBULA_USER, NEBULA_PASSWORD, NEBULA_STORAGED_PORT -# from configs.server_config import CHROMA_PERSISTENT_PATH - from coagent.base_configs.env_config import ( CB_ROOT_PATH, NEBULA_HOST, NEBULA_PORT, NEBULA_USER, NEBULA_PASSWORD, NEBULA_STORAGED_PORT, @@ -58,10 +55,11 @@ async def create_cb(zip_file, model_name: bool = Body(..., examples=["samples"]), temperature: bool = Body(..., examples=["samples"]), model_device: bool = Body(..., examples=["samples"]), + embed_config: EmbedConfig = None, ) -> BaseResponse: logger.info('cb_name={}, zip_path={}, do_interpret={}'.format(cb_name, code_path, do_interpret)) - embed_config: EmbedConfig = EmbedConfig(**locals()) + embed_config: EmbedConfig = EmbedConfig(**locals()) if embed_config is None else embed_config llm_config: LLMConfig = LLMConfig(**locals()) # Create selected knowledge base @@ -101,9 +99,10 @@ async def delete_cb( model_name: bool = Body(..., examples=["samples"]), temperature: bool = Body(..., examples=["samples"]), model_device: bool = Body(..., examples=["samples"]), + embed_config: EmbedConfig = None, ) -> BaseResponse: logger.info('cb_name={}'.format(cb_name)) - embed_config: EmbedConfig = EmbedConfig(**locals()) + embed_config: EmbedConfig = EmbedConfig(**locals()) if embed_config is None else embed_config llm_config: LLMConfig = LLMConfig(**locals()) # Create selected knowledge base if not validate_kb_name(cb_name): @@ -143,18 +142,24 @@ def search_code(cb_name: str = Body(..., examples=["sofaboot"]), model_name: bool = Body(..., examples=["samples"]), temperature: bool = Body(..., examples=["samples"]), model_device: bool = Body(..., examples=["samples"]), + use_nh: bool = True, + local_graph_path: str = '', + embed_config: EmbedConfig = None, ) -> dict: - - logger.info('cb_name={}'.format(cb_name)) - logger.info('query={}'.format(query)) - logger.info('code_limit={}'.format(code_limit)) - logger.info('search_type={}'.format(search_type)) - logger.info('history_node_list={}'.format(history_node_list)) - embed_config: EmbedConfig = EmbedConfig(**locals()) + + if os.environ.get("log_verbose", "0") >= "2": + logger.info(f'local_graph_path={local_graph_path}') + logger.info('cb_name={}'.format(cb_name)) + logger.info('query={}'.format(query)) + logger.info('code_limit={}'.format(code_limit)) + logger.info('search_type={}'.format(search_type)) + logger.info('history_node_list={}'.format(history_node_list)) + embed_config: EmbedConfig = EmbedConfig(**locals()) if embed_config is None else embed_config llm_config: LLMConfig = LLMConfig(**locals()) try: # load codebase - cbh = CodeBaseHandler(codebase_name=cb_name, embed_config=embed_config, llm_config=llm_config) + cbh = CodeBaseHandler(codebase_name=cb_name, embed_config=embed_config, llm_config=llm_config, + use_nh=use_nh, local_graph_path=local_graph_path) # search code context, related_vertices = cbh.search_code(query, 
search_type=search_type, limit=code_limit) @@ -179,11 +184,13 @@ def search_related_vertices(cb_name: str = Body(..., examples=["sofaboot"]), # load codebase nh = NebulaHandler(host=NEBULA_HOST, port=NEBULA_PORT, username=NEBULA_USER, password=NEBULA_PASSWORD, space_name=cb_name) - - cypher = f'''MATCH (v1)--(v2) WHERE id(v1) == '{vertex}' RETURN id(v2) as id;''' - + + if vertex.endswith(".java"): + cypher = f'''MATCH (v1)--(v2:package) WHERE id(v1) == '{vertex}' RETURN id(v2) as id;''' + else: + cypher = f'''MATCH (v1)--(v2) WHERE id(v1) == '{vertex}' RETURN id(v2) as id;''' + # cypher = f'''MATCH (v1)--(v2) WHERE id(v1) == '{vertex}' RETURN v2;''' cypher_res = nh.execute_cypher(cypher=cypher, format_res=True) - related_vertices = cypher_res.get('id', []) related_vertices = [i.as_string() for i in related_vertices] @@ -200,8 +207,8 @@ def search_related_vertices(cb_name: str = Body(..., examples=["sofaboot"]), def search_code_by_vertex(cb_name: str = Body(..., examples=["sofaboot"]), vertex: str = Body(..., examples=['***'])) -> dict: - logger.info('cb_name={}'.format(cb_name)) - logger.info('vertex={}'.format(vertex)) + # logger.info('cb_name={}'.format(cb_name)) + # logger.info('vertex={}'.format(vertex)) try: nh = NebulaHandler(host=NEBULA_HOST, port=NEBULA_PORT, username=NEBULA_USER, @@ -233,7 +240,7 @@ def search_code_by_vertex(cb_name: str = Body(..., examples=["sofaboot"]), return res except Exception as e: logger.exception(e) - return {} + return {'code': ""} def cb_exists_api(cb_name: str = Body(..., examples=["sofaboot"])) -> bool: diff --git a/coagent/service/faiss_db_service.py b/coagent/service/faiss_db_service.py index 4a6e43d..c7eac21 100644 --- a/coagent/service/faiss_db_service.py +++ b/coagent/service/faiss_db_service.py @@ -8,17 +8,6 @@ from loguru import logger from langchain.embeddings.base import Embeddings from langchain.docstore.document import Document from langchain.embeddings.huggingface import HuggingFaceEmbeddings -from langchain.vectorstores.utils import DistanceStrategy - -# from configs.model_config import ( -# KB_ROOT_PATH, -# CACHED_VS_NUM, -# EMBEDDING_MODEL, -# EMBEDDING_DEVICE, -# SCORE_THRESHOLD, -# FAISS_NORMALIZE_L2 -# ) -# from configs.model_config import embedding_model_dict from coagent.base_configs.env_config import ( KB_ROOT_PATH, @@ -52,15 +41,15 @@ def load_vector_store( tick: int = 0, # tick will be changed by upload_doc etc. and make cache refreshed. 
kb_root_path: str = KB_ROOT_PATH, ): - print(f"loading vector store in '{knowledge_base_name}'.") + # print(f"loading vector store in '{knowledge_base_name}'.") vs_path = get_vs_path(knowledge_base_name, kb_root_path) if embeddings is None: - embeddings = load_embeddings_from_path(embed_config.embed_model_path, embed_config.model_device) + embeddings = load_embeddings_from_path(embed_config.embed_model_path, embed_config.model_device, embed_config.langchain_embeddings) if not os.path.exists(vs_path): os.makedirs(vs_path) - distance_strategy = DistanceStrategy.EUCLIDEAN_DISTANCE + distance_strategy = "EUCLIDEAN_DISTANCE" if "index.faiss" in os.listdir(vs_path): search_index = FAISS.load_local(vs_path, embeddings, normalize_L2=FAISS_NORMALIZE_L2, distance_strategy=distance_strategy) else: diff --git a/coagent/tools/cb_query_tool.py b/coagent/tools/cb_query_tool.py index 8025e92..3d11b4e 100644 --- a/coagent/tools/cb_query_tool.py +++ b/coagent/tools/cb_query_tool.py @@ -9,9 +9,7 @@ from pydantic import BaseModel, Field from loguru import logger from coagent.llm_models import LLMConfig, EmbedConfig - from .base_tool import BaseToolModel - from coagent.service.cb_api import search_code @@ -29,7 +27,17 @@ class CodeRetrieval(BaseToolModel): code: str = Field(..., description="检索代码") @classmethod - def run(cls, code_base_name, query, code_limit=1, history_node_list=[], search_type="tag", llm_config: LLMConfig=None, embed_config: EmbedConfig=None): + def run(cls, + code_base_name, + query, + code_limit=1, + history_node_list=[], + search_type="tag", + llm_config: LLMConfig=None, + embed_config: EmbedConfig=None, + use_nh: str=True, + local_graph_path: str='' + ): """excute your tool!""" search_type = { @@ -45,7 +53,8 @@ class CodeRetrieval(BaseToolModel): codes = search_code(code_base_name, query, code_limit, search_type=search_type, history_node_list=history_node_list, embed_engine=embed_config.embed_engine, embed_model=embed_config.embed_model, embed_model_path=embed_config.embed_model_path, model_device=embed_config.model_device, model_name=llm_config.model_name, temperature=llm_config.temperature, - api_base_url=llm_config.api_base_url, api_key=llm_config.api_key + api_base_url=llm_config.api_base_url, api_key=llm_config.api_key, use_nh=use_nh, + local_graph_path=local_graph_path, embed_config=embed_config ) return_codes = [] context = codes['context'] diff --git a/coagent/tools/codechat_tools.py b/coagent/tools/codechat_tools.py index 22c3df7..b99c4dc 100644 --- a/coagent/tools/codechat_tools.py +++ b/coagent/tools/codechat_tools.py @@ -5,6 +5,7 @@ @time: 2023/12/14 上午10:24 @desc: ''' +import os from pydantic import BaseModel, Field from loguru import logger @@ -40,10 +41,9 @@ class CodeRetrievalSingle(BaseToolModel): vertex: str = Field(..., description="代码对应 id") @classmethod - def run(cls, code_base_name, query, embed_config: EmbedConfig, llm_config: LLMConfig, **kargs): + def run(cls, code_base_name, query, embed_config: EmbedConfig, llm_config: LLMConfig, search_type="description", **kargs): """excute your tool!""" - search_type = 'description' code_limit = 1 # default @@ -51,10 +51,11 @@ class CodeRetrievalSingle(BaseToolModel): history_node_list=[], embed_engine=embed_config.embed_engine, embed_model=embed_config.embed_model, embed_model_path=embed_config.embed_model_path, model_device=embed_config.model_device, model_name=llm_config.model_name, temperature=llm_config.temperature, - api_base_url=llm_config.api_base_url, api_key=llm_config.api_key + 
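# forward use_nh / local_graph_path and the shared embed_config through to cb_api.search_code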
+            api_base_url=llm_config.api_base_url, api_key=llm_config.api_key, embed_config=embed_config, use_nh=kargs.get("use_nh", True),
+            local_graph_path=kargs.get("local_graph_path", "")
         )
-
-        logger.debug(search_result)
+        if os.environ.get("log_verbose", "0") >= "3":
+            logger.debug(search_result)
 
         code = search_result['context']
         vertex = search_result['related_vertices'][0]
         # logger.debug(f"code: {code}, vertex: {vertex}")
@@ -83,7 +84,7 @@
     def run(cls, code_base_name: str, vertex: str, **kargs):
         """execute your tool!"""
         related_vertices = search_related_vertices(cb_name=code_base_name, vertex=vertex)
-        logger.debug(f"related_vertices: {related_vertices}")
+        # logger.debug(f"related_vertices: {related_vertices}")
         return related_vertices
 
@@ -110,6 +111,6 @@
         else:
             vertex = vertex.strip(' "')
 
-        logger.info(f'vertex={vertex}')
+        # logger.info(f'vertex={vertex}')
         res = search_code_by_vertex(cb_name=code_base_name, vertex=vertex)
         return res
\ No newline at end of file
diff --git a/coagent/tools/docs_retrieval.py b/coagent/tools/docs_retrieval.py
index 23aef35..fcb7163 100644
--- a/coagent/tools/docs_retrieval.py
+++ b/coagent/tools/docs_retrieval.py
@@ -2,11 +2,7 @@
 from pydantic import BaseModel, Field
 from loguru import logger
 
 from coagent.llm_models.llm_config import EmbedConfig
-
 from .base_tool import BaseToolModel
-
-
-
 from coagent.service.kb_api import search_docs
diff --git a/coagent/tools/duckduckgo_search.py b/coagent/tools/duckduckgo_search.py
index badfa3c..8e079e4 100644
--- a/coagent/tools/duckduckgo_search.py
+++ b/coagent/tools/duckduckgo_search.py
@@ -9,8 +9,10 @@
 import numpy as np
 from loguru import logger
 
 from .base_tool import BaseToolModel
-
-from duckduckgo_search import DDGS
+try:
+    from duckduckgo_search import DDGS
+except ImportError:
+    logger.warning("duckduckgo_search is not installed; run `pip install duckduckgo_search` if you need it")
 
 
 class DDGSTool(BaseToolModel):
diff --git a/coagent/utils/code2doc_util.py b/coagent/utils/code2doc_util.py
new file mode 100644
index 0000000..11cffd4
--- /dev/null
+++ b/coagent/utils/code2doc_util.py
@@ -0,0 +1,89 @@
+import json
+
+
+def class_info_decode(data):
+    '''parse class-related info out of the agents' parsed outputs'''
+    params_dict = {}
+
+    for i in data:
+        _params_dict = {}
+        for ii in i:
+            for k, v in ii.items():
+                if k == "origin_query": continue
+
+                if k == "Code Path":
+                    _params_dict["code_path"] = v.split("#")[0]
+                    _params_dict["function_name"] = ".".join(v.split("#")[1:])
+
+                if k == "Class Description":
+                    _params_dict["ClassDescription"] = v
+
+                if k == "Class Base":
+                    _params_dict["ClassBase"] = v
+
+                if k == "Init Parameters":
+                    _params_dict["Parameters"] = v
+
+        code_path = _params_dict["code_path"]
+        params_dict.setdefault(code_path, []).append(_params_dict)
+
+    return params_dict
+
+def method_info_decode(data):
+    params_dict = {}
+
+    for i in data:
+        _params_dict = {}
+        for ii in i:
+            for k, v in ii.items():
+                if k == "origin_query": continue
+
+                if k == "Code Path":
+                    _params_dict["code_path"] = v.split("#")[0]
+                    _params_dict["function_name"] = ".".join(v.split("#")[1:])
+
+                if k == "Return Value Description":
+                    _params_dict["Returns"] = v
+
+                if k == "Return Type":
+                    _params_dict["ReturnType"] = v
+
+                if k == "Parameters":
+                    _params_dict["Parameters"] = v
+
+        code_path = _params_dict["code_path"]
+        params_dict.setdefault(code_path, []).append(_params_dict)
+
+    return params_dict
+
+def encode2md(data, md_format):
+    md_dict = {}
+    for code_path, params_list in data.items():
+        for params in 
params_list: + params["Parameters_text"] = "\n".join([f"{param['param']}({param['param_type']})-{param['param_description']}" + for param in params["Parameters"]]) + # params.delete("Parameters") + text=md_format.format(**params) + md_dict.setdefault(code_path, []).append(text) + return md_dict + + +method_text_md = '''> {function_name} + +| Column Name | Content | +|-----------------|-----------------| +| Parameters | {Parameters_text} | +| Returns | {Returns} | +| Return type | {ReturnType} | +''' + +class_text_md = '''> {code_path} + +Bases: {ClassBase} + +{ClassDescription} + +{Parameters_text} +''' \ No newline at end of file diff --git a/coagent/utils/common_utils.py b/coagent/utils/common_utils.py index 19c0cca..ed5c459 100644 --- a/coagent/utils/common_utils.py +++ b/coagent/utils/common_utils.py @@ -7,7 +7,7 @@ from pathlib import Path from io import BytesIO from fastapi import Body, File, Form, Body, Query, UploadFile from tempfile import SpooledTemporaryFile - +import json DATE_FORMAT = "%Y-%m-%d %H:%M:%S" @@ -109,4 +109,6 @@ def get_uploadfile(file: Union[str, Path, bytes], filename=None) -> UploadFile: temp_file = SpooledTemporaryFile(max_size=10 * 1024 * 1024) temp_file.write(file.read()) temp_file.seek(0) - return UploadFile(file=temp_file, filename=filename) \ No newline at end of file + return UploadFile(file=temp_file, filename=filename) + + diff --git a/coagent/utils/path_utils.py b/coagent/utils/path_utils.py index 4810496..5aec2fd 100644 --- a/coagent/utils/path_utils.py +++ b/coagent/utils/path_utils.py @@ -1,7 +1,7 @@ import os from langchain.document_loaders import CSVLoader, PyPDFLoader, UnstructuredFileLoader, TextLoader, PythonLoader -from coagent.document_loaders import JSONLLoader, JSONLoader +from coagent.retrieval.document_loaders import JSONLLoader, JSONLoader # from configs.model_config import ( # embedding_model_dict, # KB_ROOT_PATH, diff --git a/configs/default_config.py b/configs/default_config.py index 7af6752..3d71d89 100644 --- a/configs/default_config.py +++ b/configs/default_config.py @@ -21,17 +21,20 @@ JUPYTER_WORK_PATH = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath # WEB_CRAWL存储路径 WEB_CRAWL_PATH = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "knowledge_base") # NEBULA_DATA存储路径 -NELUBA_PATH = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "data/neluba_data") +NEBULA_PATH = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "data/nebula_data") -for _path in [LOG_PATH, SOURCE_PATH, KB_ROOT_PATH, NLTK_DATA_PATH, JUPYTER_WORK_PATH, WEB_CRAWL_PATH, NELUBA_PATH]: +# CHROMA 存储路径 +CHROMA_PERSISTENT_PATH = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "data/chroma_data") + +for _path in [LOG_PATH, SOURCE_PATH, KB_ROOT_PATH, CB_ROOT_PATH, NLTK_DATA_PATH, JUPYTER_WORK_PATH, WEB_CRAWL_PATH, NEBULA_PATH, CHROMA_PERSISTENT_PATH]: if not os.path.exists(_path): os.makedirs(_path, exist_ok=True) - -# + path_envt_dict = { "LOG_PATH": LOG_PATH, "SOURCE_PATH": SOURCE_PATH, "KB_ROOT_PATH": KB_ROOT_PATH, "NLTK_DATA_PATH":NLTK_DATA_PATH, "JUPYTER_WORK_PATH": JUPYTER_WORK_PATH, - "WEB_CRAWL_PATH": WEB_CRAWL_PATH, "NELUBA_PATH": NELUBA_PATH + "WEB_CRAWL_PATH": WEB_CRAWL_PATH, "NEBULA_PATH": NEBULA_PATH, + "CHROMA_PERSISTENT_PATH": CHROMA_PERSISTENT_PATH } for path_name, _path in path_envt_dict.items(): os.environ[path_name] = _path diff --git a/configs/model_config.py.example b/configs/model_config.py.example index ca95bb8..d4a0325 100644 --- 
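
To make the new `code2doc_util` helpers concrete, here is a small self-contained sketch (the sample record is invented for illustration) of turning one parsed method output into markdown via `method_info_decode` and `encode2md`:

```python
from coagent.utils.code2doc_util import (
    method_info_decode, encode2md, method_text_md,
)

# Shape mirrors what the code2doc agents emit: a list of parsed outputs,
# each itself a list of key/value dicts.
raw = [[{
    "Code Path": "com/example/Utils.java#remove",
    "Parameters": [{"param": "s", "param_type": "String",
                    "param_description": "input string"}],
    "Return Value Description": "the string with the suffix removed",
    "Return Type": "String",
}]]

md_dict = encode2md(method_info_decode(raw), method_text_md)
for code_path, mds in md_dict.items():
    print(code_path, mds[0], sep="\n")  # one markdown table per method
```
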
a/configs/model_config.py.example +++ b/configs/model_config.py.example @@ -33,7 +33,7 @@ except: pass # add your openai key -OPENAI_API_BASE = "http://openai.com/v1/chat/completions" +OPENAI_API_BASE = "https://api.openai.com/v1" os.environ["API_BASE_URL"] = OPENAI_API_BASE os.environ["OPENAI_API_KEY"] = "sk-xx" openai.api_key = "sk-xx" diff --git a/configs/server_config.py.example b/configs/server_config.py.example index 45d4b61..c313cf0 100644 --- a/configs/server_config.py.example +++ b/configs/server_config.py.example @@ -58,9 +58,6 @@ NEBULA_GRAPH_SERVER = { "docker_port": NEBULA_PORT } -# chroma conf -CHROMA_PERSISTENT_PATH = '/home/user/chatbot/data/chroma_data' - # sandbox api server SANDBOX_CONTRAINER_NAME = "devopsgpt_sandbox" SANDBOX_IMAGE_NAME = "devopsgpt:py39" diff --git a/examples/agent_examples/baseGroupPhase_example.py b/examples/agent_examples/baseGroupPhase_example.py index 98bc4bd..3d375e5 100644 --- a/examples/agent_examples/baseGroupPhase_example.py +++ b/examples/agent_examples/baseGroupPhase_example.py @@ -15,11 +15,11 @@ from coagent.connector.schema import Message # tools = toLangchainTools([TOOL_DICT[i] for i in TOOL_SETS if i in TOOL_DICT]) # log-level,print prompt和llm predict -os.environ["log_verbose"] = "0" +os.environ["log_verbose"] = "2" phase_name = "baseGroupPhase" llm_config = LLMConfig( - model_name=LLM_MODEL, model_device="cpu",api_key=os.environ["OPENAI_API_KEY"], + model_name=LLM_MODEL, api_key=os.environ["OPENAI_API_KEY"], api_base_url=os.environ["API_BASE_URL"], temperature=0.3 ) embed_config = EmbedConfig( diff --git a/examples/agent_examples/baseTaskPhase_example.py b/examples/agent_examples/baseTaskPhase_example.py index 2b01f3d..0dd6769 100644 --- a/examples/agent_examples/baseTaskPhase_example.py +++ b/examples/agent_examples/baseTaskPhase_example.py @@ -17,7 +17,7 @@ os.environ["log_verbose"] = "2" phase_name = "baseTaskPhase" llm_config = LLMConfig( - model_name="gpt-3.5-turbo", model_device="cpu",api_key=os.environ["OPENAI_API_KEY"], + model_name="gpt-3.5-turbo", api_key=os.environ["OPENAI_API_KEY"], api_base_url=os.environ["API_BASE_URL"], temperature=0.3 ) embed_config = EmbedConfig( diff --git a/examples/agent_examples/codeChatPhaseLocal_example.py b/examples/agent_examples/codeChatPhaseLocal_example.py new file mode 100644 index 0000000..ff5ee24 --- /dev/null +++ b/examples/agent_examples/codeChatPhaseLocal_example.py @@ -0,0 +1,135 @@ +# encoding: utf-8 +''' +@author: 温进 +@file: codeChatPhaseLocal_example.py +@time: 2024/1/31 下午4:32 +@desc: +''' +import os, sys, requests +from concurrent.futures import ThreadPoolExecutor +from tqdm import tqdm + +import requests +from typing import List + +src_dir = os.path.join( + os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +) +sys.path.append(src_dir) + +from configs.model_config import KB_ROOT_PATH, JUPYTER_WORK_PATH, CB_ROOT_PATH +from configs.server_config import SANDBOX_SERVER +from coagent.tools import toLangchainTools, TOOL_DICT, TOOL_SETS +from coagent.llm_models.llm_config import EmbedConfig, LLMConfig +from coagent.connector.phase import BasePhase +from coagent.connector.schema import Message, Memory +from coagent.codechat.codebase_handler.codebase_handler import CodeBaseHandler + + + +# log-level,print prompt和llm predict +os.environ["log_verbose"] = "1" + +llm_config = LLMConfig( + model_name="gpt-3.5-turbo", model_device="cpu",api_key=os.environ["OPENAI_API_KEY"], + api_base_url=os.environ["API_BASE_URL"], temperature=0.3 + ) +embed_config = EmbedConfig( + 
embed_engine="model", embed_model="text2vec-base-chinese", + embed_model_path=os.path.join(src_dir, "embedding_models/text2vec-base-chinese") + ) + + +# delete codebase +codebase_name = 'client_local' +code_path = '/Users/bingxu/Desktop/工作/大模型/chatbot/test_code_repo/client' +code_path = "D://chromeDownloads/devopschat-bot/client_v2/client" +use_nh = True +# cbh = CodeBaseHandler(codebase_name, code_path, crawl_type='dir', use_nh=use_nh, local_graph_path=CB_ROOT_PATH, +# llm_config=llm_config, embed_config=embed_config) +cbh = CodeBaseHandler(codebase_name, code_path, crawl_type='dir', use_nh=use_nh, local_graph_path=CB_ROOT_PATH, + llm_config=llm_config, embed_config=embed_config) +cbh.delete_codebase(codebase_name=codebase_name) + + +# initialize codebase +codebase_name = 'client_local' +code_path = '/Users/bingxu/Desktop/工作/大模型/chatbot/test_code_repo/client' +code_path = "D://chromeDownloads/devopschat-bot/client_v2/client" +code_path = "/home/user/client" +use_nh = True +do_interpret = True +cbh = CodeBaseHandler(codebase_name, code_path, crawl_type='dir', use_nh=use_nh, local_graph_path=CB_ROOT_PATH, + llm_config=llm_config, embed_config=embed_config) +cbh.import_code(do_interpret=do_interpret) + + + +# chat with codebase +phase_name = "codeChatPhase" +phase = BasePhase( + phase_name, sandbox_server=SANDBOX_SERVER, jupyter_work_path=JUPYTER_WORK_PATH, + embed_config=embed_config, llm_config=llm_config, kb_root_path=KB_ROOT_PATH, +) + +# remove 这个函数是做什么的 => 基于标签 +# 有没有函数已经实现了从字符串删除指定字符串的功能,使用的话可以怎么使用,写个java代码 => 基于描述 +# 有根据我以下的需求用 java 开发一个方法:输入为字符串,将输入中的 .java 字符串给删除掉,然后返回新的字符串 => 基于描述 + +## 需要启动容器中的nebula,采用use_nh=True来构建代码库,是可以通过cypher来查询 +# round-1 +# query_content = "代码一共有多少类" +# query = Message( +# role_name="human", role_type="user", +# role_content=query_content, input_query=query_content, origin_query=query_content, +# code_engine_name="client_1", score_threshold=1.0, top_k=3, cb_search_type="cypher" +# ) +# +# output_message1, _ = phase.step(query) +# print(output_message1) + +# round-2 +# query_content = "代码库里有哪些函数,返回5个就行" +# query = Message( +# role_name="human", role_type="user", +# role_content=query_content, input_query=query_content, origin_query=query_content, +# code_engine_name="client_1", score_threshold=1.0, top_k=3, cb_search_type="cypher" +# ) +# output_message2, _ = phase.step(query) +# print(output_message2) + + +# round-3 +query_content = "remove 这个函数是做什么的" +query = Message( + role_name="user", role_type="human", + role_content=query_content, input_query=query_content, origin_query=query_content, + code_engine_name=codebase_name, score_threshold=1.0, top_k=3, cb_search_type="tag", + use_nh=False, local_graph_path=CB_ROOT_PATH + ) +output_message3, output_memory3 = phase.step(query) +print(output_memory3.to_str_messages(return_all=True, content_key="parsed_output_list")) + +# +# # round-4 +query_content = "有没有函数已经实现了从字符串删除指定字符串的功能,使用的话可以怎么使用,写个java代码" +query = Message( + role_name="human", role_type="user", + role_content=query_content, input_query=query_content, origin_query=query_content, + code_engine_name=codebase_name, score_threshold=1.0, top_k=3, cb_search_type="description", + use_nh=False, local_graph_path=CB_ROOT_PATH + ) +output_message4, output_memory4 = phase.step(query) +print(output_memory4.to_str_messages(return_all=True, content_key="parsed_output_list")) + + +# # round-5 +query_content = "有根据我以下的需求用 java 开发一个方法:输入为字符串,将输入中的 .java 字符串给删除掉,然后返回新的字符串" +query = Message( + role_name="human", role_type="user", + role_content=query_content, 
input_query=query_content, origin_query=query_content, + code_engine_name=codebase_name, score_threshold=1.0, top_k=3, cb_search_type="description", + use_nh=False, local_graph_path=CB_ROOT_PATH + ) +output_message5, output_memory5 = phase.step(query) +print(output_memory5.to_str_messages(return_all=True, content_key="parsed_output_list")) diff --git a/examples/agent_examples/codeChatPhase_example.py b/examples/agent_examples/codeChatPhase_example.py index 04bb7c9..1b5a13a 100644 --- a/examples/agent_examples/codeChatPhase_example.py +++ b/examples/agent_examples/codeChatPhase_example.py @@ -17,13 +17,14 @@ os.environ["log_verbose"] = "2" phase_name = "codeChatPhase" llm_config = LLMConfig( - model_name="gpt-3.5-turbo", model_device="cpu",api_key=os.environ["OPENAI_API_KEY"], + model_name="gpt-3.5-turbo", api_key=os.environ["OPENAI_API_KEY"], api_base_url=os.environ["API_BASE_URL"], temperature=0.3 ) embed_config = EmbedConfig( - embed_engine="model", embed_model="text2vec-base-chinese", + embed_engine="model", embed_model="text2vec-base-chinese", embed_model_path=os.path.join(src_dir, "embedding_models/text2vec-base-chinese") ) + phase = BasePhase( phase_name, sandbox_server=SANDBOX_SERVER, jupyter_work_path=JUPYTER_WORK_PATH, embed_config=embed_config, llm_config=llm_config, kb_root_path=KB_ROOT_PATH, @@ -35,50 +36,56 @@ phase = BasePhase( # 有根据我以下的需求用 java 开发一个方法:输入为字符串,将输入中的 .java 字符串给删除掉,然后返回新的字符串 => 基于描述 # round-1 -query_content = "代码一共有多少类" -query = Message( - role_name="human", role_type="user", - role_content=query_content, input_query=query_content, origin_query=query_content, - code_engine_name="client", score_threshold=1.0, top_k=3, cb_search_type="cypher" - ) - -output_message1, _ = phase.step(query) +# query_content = "代码一共有多少类" +# query = Message( +# role_name="human", role_type="user", +# role_content=query_content, input_query=query_content, origin_query=query_content, +# code_engine_name="client_1", score_threshold=1.0, top_k=3, cb_search_type="cypher" +# ) +# +# output_message1, _ = phase.step(query) +# print(output_message1) # round-2 -query_content = "代码库里有哪些函数,返回5个就行" -query = Message( - role_name="human", role_type="user", - role_content=query_content, input_query=query_content, origin_query=query_content, - code_engine_name="client", score_threshold=1.0, top_k=3, cb_search_type="cypher" - ) -output_message2, _ = phase.step(query) +# query_content = "代码库里有哪些函数,返回5个就行" +# query = Message( +# role_name="human", role_type="user", +# role_content=query_content, input_query=query_content, origin_query=query_content, +# code_engine_name="client_1", score_threshold=1.0, top_k=3, cb_search_type="cypher" +# ) +# output_message2, _ = phase.step(query) +# print(output_message2) -# round-3 +# +# # round-3 query_content = "remove 这个函数是做什么的" query = Message( - role_name="user", role_type="human", + role_name="user", role_type="human", role_content=query_content, input_query=query_content, origin_query=query_content, code_engine_name="client", score_threshold=1.0, top_k=3, cb_search_type="tag" ) output_message3, _ = phase.step(query) +print(output_message3) -# round-4 -query_content = "有没有函数已经实现了从字符串删除指定字符串的功能,使用的话可以怎么使用,写个java代码" -query = Message( - role_name="human", role_type="user", - role_content=query_content, input_query=query_content, origin_query=query_content, - code_engine_name="client", score_threshold=1.0, top_k=3, cb_search_type="description" - ) -output_message4, _ = phase.step(query) - - -# round-5 -query_content = "有根据我以下的需求用 java 开发一个方法:输入为字符串,将输入中的 .java 
字符串给删除掉,然后返回新的字符串" -query = Message( - role_name="human", role_type="user", - role_content=query_content, input_query=query_content, origin_query=query_content, - code_engine_name="client", score_threshold=1.0, top_k=3, cb_search_type="description" - ) -output_message5, output_memory5 = phase.step(query) - -print(output_memory5.to_str_messages(return_all=True, content_key="parsed_output_list")) \ No newline at end of file +# +# # round-4 +# query_content = "有没有函数已经实现了从字符串删除指定字符串的功能,使用的话可以怎么使用,写个java代码" +# query = Message( +# role_name="human", role_type="user", +# role_content=query_content, input_query=query_content, origin_query=query_content, +# code_engine_name="client_1", score_threshold=1.0, top_k=3, cb_search_type="description" +# ) +# output_message4, _ = phase.step(query) +# print(output_message4) +# +# # round-5 +# query_content = "有根据我以下的需求用 java 开发一个方法:输入为字符串,将输入中的 .java 字符串给删除掉,然后返回新的字符串" +# query = Message( +# role_name="human", role_type="user", +# role_content=query_content, input_query=query_content, origin_query=query_content, +# code_engine_name="client_1", score_threshold=1.0, top_k=3, cb_search_type="description" +# ) +# output_message5, output_memory5 = phase.step(query) +# print(output_message5) +# +# print(output_memory5.to_str_messages(return_all=True, content_key="parsed_output_list")) \ No newline at end of file diff --git a/examples/agent_examples/codeGenDoc_example.py b/examples/agent_examples/codeGenDoc_example.py new file mode 100644 index 0000000..61221f2 --- /dev/null +++ b/examples/agent_examples/codeGenDoc_example.py @@ -0,0 +1,507 @@ +import os, sys, json +from loguru import logger +src_dir = os.path.join( + os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +) +sys.path.append(src_dir) + +from configs.model_config import KB_ROOT_PATH, JUPYTER_WORK_PATH, CB_ROOT_PATH +from configs.server_config import SANDBOX_SERVER +from coagent.llm_models.llm_config import EmbedConfig, LLMConfig + +from coagent.connector.phase import BasePhase +from coagent.connector.agents import BaseAgent +from coagent.connector.schema import Message +from coagent.tools import CodeRetrievalSingle +from coagent.codechat.codebase_handler.codebase_handler import CodeBaseHandler +import importlib + + +# 定义一个新的agent类 +class CodeGenDocer(BaseAgent): + + def start_action_step(self, message: Message) -> Message: + '''do action before agent predict ''' + # 根据问题获取代码片段和节点信息 + action_json = CodeRetrievalSingle.run(message.code_engine_name, message.origin_query, + llm_config=self.llm_config, embed_config=self.embed_config, local_graph_path=message.local_graph_path, use_nh=message.use_nh,search_type="tag") + current_vertex = action_json['vertex'] + message.customed_kargs["Code Snippet"] = action_json["code"] + message.customed_kargs['Current_Vertex'] = current_vertex + return message + + +# add agent or prompt_manager class +agent_module = importlib.import_module("coagent.connector.agents") +setattr(agent_module, 'CodeGenDocer', CodeGenDocer) + + +# log-level,print prompt和llm predict +os.environ["log_verbose"] = "1" + +phase_name = "code2DocsGroup" +llm_config = LLMConfig( + model_name="gpt-4", api_key=os.environ["OPENAI_API_KEY"], + api_base_url=os.environ["API_BASE_URL"], temperature=0.3 +) +embed_config = EmbedConfig( + embed_engine="model", embed_model="text2vec-base-chinese", + embed_model_path=os.path.join(src_dir, "embedding_models/text2vec-base-chinese") + ) + +# initialize codebase +# delete codebase +codebase_name = 'client_local' +code_path = 
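
The customization pattern used by `CodeGenDocer` deserves a note: `start_action_step` runs before the agent's LLM call, so a subclass can perform retrieval and stash the results on `message.customed_kargs`, where a prompt manager handler can later render them. A minimal sketch of the hook (the class name and stored value are placeholders, not from the patch):

```python
from coagent.connector.agents import BaseAgent
from coagent.connector.schema import Message

class MyEnrichingAgent(BaseAgent):
    """Hypothetical subclass: enrich the message before the LLM predicts."""

    def start_action_step(self, message: Message) -> Message:
        # Values stored here can be surfaced by a PromptManager handler,
        # e.g. handle_code_snippet in the commented reference further below.
        message.customed_kargs["Code Snippet"] = "...retrieved code..."
        return message
```
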
"D://chromeDownloads/devopschat-bot/client_v2/client" +use_nh = False +cbh = CodeBaseHandler(codebase_name, code_path, crawl_type='dir', use_nh=use_nh, local_graph_path=CB_ROOT_PATH, + llm_config=llm_config, embed_config=embed_config) +cbh.delete_codebase(codebase_name=codebase_name) + + +# load codebase +codebase_name = 'client_local' +code_path = "D://chromeDownloads/devopschat-bot/client_v2/client" +use_nh = True +do_interpret = True +cbh = CodeBaseHandler(codebase_name, code_path, crawl_type='dir', use_nh=use_nh, local_graph_path=CB_ROOT_PATH, + llm_config=llm_config, embed_config=embed_config) +cbh.import_code(do_interpret=do_interpret) + +# 根据前面的load过程进行初始化 +cbh = CodeBaseHandler(codebase_name, code_path, crawl_type='dir', use_nh=use_nh, local_graph_path=CB_ROOT_PATH, + llm_config=llm_config, embed_config=embed_config) +phase = BasePhase( + phase_name, sandbox_server=SANDBOX_SERVER, jupyter_work_path=JUPYTER_WORK_PATH, + embed_config=embed_config, llm_config=llm_config, kb_root_path=KB_ROOT_PATH, +) + +for vertex_type in ["class", "method"]: + vertexes = cbh.search_vertices(vertex_type=vertex_type) + logger.info(f"vertexes={vertexes}") + + # round-1 + docs = [] + for vertex in vertexes: + vertex = vertex.split("-")[0] # -为method的参数 + query_content = f"为{vertex_type}节点 {vertex}生成文档" + query = Message( + role_name="human", role_type="user", + role_content=query_content, input_query=query_content, origin_query=query_content, + code_engine_name="client_local", score_threshold=1.0, top_k=3, cb_search_type="tag", use_nh=use_nh, + local_graph_path=CB_ROOT_PATH, + ) + output_message, output_memory = phase.step(query, reinit_memory=True) + # print(output_memory.to_str_messages(return_all=True, content_key="parsed_output_list")) + docs.append(output_memory.get_spec_parserd_output()) + + os.makedirs(f"{CB_ROOT_PATH}/docs", exist_ok=True) + with open(f"{CB_ROOT_PATH}/docs/raw_{vertex_type}.json", "w") as f: + json.dump(docs, f) + + +# 下面把生成的文档信息转换成markdown文本 +from coagent.utils.code2doc_util import * +import json +with open(f"{CB_ROOT_PATH}/docs/raw_method.json", "r") as f: + method_raw_data = json.load(f) + +with open(f"{CB_ROOT_PATH}/docs/raw_class.json", "r") as f: + class_raw_data = json.load(f) + + +method_data = method_info_decode(method_raw_data) +class_data = class_info_decode(class_raw_data) +method_mds = encode2md(method_data, method_text_md) +class_mds = encode2md(class_data, class_text_md) + + +docs_dict = {} +for k,v in class_mds.items(): + method_textmds = method_mds.get(k, []) + for vv in v: + # 理论上只有一个 + text_md = vv + + for method_textmd in method_textmds: + text_md += "\n
" + method_textmd + + docs_dict.setdefault(k, []).append(text_md) + + with open(f"{CB_ROOT_PATH}//docs/{k}.md", "w") as f: + f.write(text_md) + + + + + +#################################### +######## 下面是完整的复现过程 ######## +#################################### + +# import os, sys, requests +# from loguru import logger +# src_dir = os.path.join( +# os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +# ) +# sys.path.append(src_dir) + +# from configs.model_config import KB_ROOT_PATH, JUPYTER_WORK_PATH, CB_ROOT_PATH +# from configs.server_config import SANDBOX_SERVER +# from coagent.tools import toLangchainTools, TOOL_DICT, TOOL_SETS +# from coagent.llm_models.llm_config import EmbedConfig, LLMConfig + +# from coagent.connector.phase import BasePhase +# from coagent.connector.agents import BaseAgent, SelectorAgent +# from coagent.connector.chains import BaseChain +# from coagent.connector.schema import ( +# Message, Memory, load_role_configs, load_phase_configs, load_chain_configs, ActionStatus +# ) +# from coagent.connector.memory_manager import BaseMemoryManager +# from coagent.connector.configs import AGETN_CONFIGS, CHAIN_CONFIGS, PHASE_CONFIGS, BASE_PROMPT_CONFIGS +# from coagent.connector.prompt_manager.prompt_manager import PromptManager +# from coagent.codechat.codebase_handler.codebase_handler import CodeBaseHandler + +# import importlib +# from loguru import logger + + +# from coagent.tools import CodeRetrievalSingle, RelatedVerticesRetrival, Vertex2Code + + +# # update new agent configs +# codeGenDocGroup_PROMPT = """#### Agent Profile + +# Your goal is to response according the Context Data's information with the role that will best facilitate a solution, taking into account all relevant context (Context) provided. + +# When you need to select the appropriate role for handling a user's query, carefully read the provided role names, role descriptions and tool list. + +# ATTENTION: response carefully referenced "Response Output Format" in format. + +# #### Input Format + +# #### Response Output Format + +# **Code Path:** Extract the paths for the class/method/function that need to be addressed from the context + +# **Role:** Select the role from agent names +# """ + +# classGenDoc_PROMPT = """#### Agent Profile +# As an advanced code documentation generator, you are proficient in translating class definitions into comprehensive documentation with a focus on instantiation parameters. +# Your specific task is to parse the given code snippet of a class, extract information regarding its instantiation parameters. + +# ATTENTION: response carefully in "Response Output Format". + +# #### Input Format + +# **Code Snippet:** Provide the full class definition, including the constructor and any parameters it may require for instantiation. + +# #### Response Output Format +# **Class Base:** Specify the base class or interface from which the current class extends, if any. + +# **Class Description:** Offer a brief description of the class's purpose and functionality. + +# **Init Parameters:** List each parameter from construct. For each parameter, provide: +# - `param`: The parameter name +# - `param_description`: A concise explanation of the parameter's purpose. +# - `param_type`: The data type of the parameter, if explicitly defined. + +# ```json +# [ +# { +# "param": "parameter_name", +# "param_description": "A brief description of what this parameter is used for.", +# "param_type": "The data type of the parameter" +# }, +# ... 
+# ] +# ``` + + +# If no parameter for construct, return +# ```json +# [] +# ``` +# """ + +# funcGenDoc_PROMPT = """#### Agent Profile +# You are a high-level code documentation assistant, skilled at extracting information from function/method code into detailed and well-structured documentation. + +# ATTENTION: response carefully in "Response Output Format". + + +# #### Input Format +# **Code Path:** Provide the code path of the function or method you wish to document. +# This name will be used to identify and extract the relevant details from the code snippet provided. + +# **Code Snippet:** A segment of code that contains the function or method to be documented. + +# #### Response Output Format + +# **Class Description:** Offer a brief description of the method(function)'s purpose and functionality. + +# **Parameters:** Extract parameter for the specific function/method Code from Code Snippet. For parameter, provide: +# - `param`: The parameter name +# - `param_description`: A concise explanation of the parameter's purpose. +# - `param_type`: The data type of the parameter, if explicitly defined. +# ```json +# [ +# { +# "param": "parameter_name", +# "param_description": "A brief description of what this parameter is used for.", +# "param_type": "The data type of the parameter" +# }, +# ... +# ] +# ``` + +# If no parameter for function/method, return +# ```json +# [] +# ``` + +# **Return Value Description:** Describe what the function/method returns upon completion. + +# **Return Type:** Indicate the type of data the function/method returns (e.g., string, integer, object, void). +# """ + +# CODE_GENERATE_GROUP_PROMPT_CONFIGS = [ +# {"field_name": 'agent_profile', "function_name": 'handle_agent_profile', "is_context": False}, +# {"field_name": 'agent_infomation', "function_name": 'handle_agent_data', "is_context": False, "omit_if_empty": False}, +# # {"field_name": 'tool_information',"function_name": 'handle_tool_data', "is_context": False}, +# {"field_name": 'context_placeholder', "function_name": '', "is_context": True}, +# # {"field_name": 'reference_documents', "function_name": 'handle_doc_info'}, +# {"field_name": 'session_records', "function_name": 'handle_session_records'}, +# {"field_name": 'Specific Objective', "function_name": 'handle_specific_objective'}, +# {"field_name": 'Code Snippet', "function_name": 'handle_code_snippet'}, +# {"field_name": 'output_format', "function_name": 'handle_output_format', 'title': 'Response Output Format', "is_context": False}, +# {"field_name": 'begin!!!', "function_name": 'handle_response', "is_context": False, "omit_if_empty": False} +# ] + +# CODE_GENERATE_DOC_PROMPT_CONFIGS = [ +# {"field_name": 'agent_profile', "function_name": 'handle_agent_profile', "is_context": False}, +# # {"field_name": 'tool_information',"function_name": 'handle_tool_data', "is_context": False}, +# {"field_name": 'context_placeholder', "function_name": '', "is_context": True}, +# # {"field_name": 'reference_documents', "function_name": 'handle_doc_info'}, +# {"field_name": 'session_records', "function_name": 'handle_session_records'}, +# {"field_name": 'Specific Objective', "function_name": 'handle_specific_objective'}, +# {"field_name": 'Code Snippet', "function_name": 'handle_code_snippet'}, +# {"field_name": 'output_format', "function_name": 'handle_output_format', 'title': 'Response Output Format', "is_context": False}, +# {"field_name": 'begin!!!', "function_name": 'handle_response', "is_context": False, "omit_if_empty": False} +# ] + + +# class 
CodeGenDocPM(PromptManager): +# def handle_code_snippet(self, **kwargs) -> str: +# if 'previous_agent_message' not in kwargs: +# return "" +# previous_agent_message: Message = kwargs['previous_agent_message'] +# code_snippet = previous_agent_message.customed_kargs.get("Code Snippet", "") +# current_vertex = previous_agent_message.customed_kargs.get("Current_Vertex", "") +# instruction = "A segment of code that contains the function or method to be documented.\n" +# return instruction + "\n" + f"name: {current_vertex}\n{code_snippet}" + +# def handle_specific_objective(self, **kwargs) -> str: +# if 'previous_agent_message' not in kwargs: +# return "" +# previous_agent_message: Message = kwargs['previous_agent_message'] +# specific_objective = previous_agent_message.parsed_output.get("Code Path") + +# instruction = "Provide the code path of the function or method you wish to document.\n" +# s = instruction + f"\n{specific_objective}" +# return s + + +# from coagent.tools import CodeRetrievalSingle + +# # 定义一个新的agent类 +# class CodeGenDocer(BaseAgent): + +# def start_action_step(self, message: Message) -> Message: +# '''do action before agent predict ''' +# # 根据问题获取代码片段和节点信息 +# action_json = CodeRetrievalSingle.run(message.code_engine_name, message.origin_query, +# llm_config=self.llm_config, embed_config=self.embed_config, local_graph_path=message.local_graph_path, use_nh=message.use_nh,search_type="tag") +# current_vertex = action_json['vertex'] +# message.customed_kargs["Code Snippet"] = action_json["code"] +# message.customed_kargs['Current_Vertex'] = current_vertex +# return message + +# # add agent or prompt_manager class +# agent_module = importlib.import_module("coagent.connector.agents") +# prompt_manager_module = importlib.import_module("coagent.connector.prompt_manager") + +# setattr(agent_module, 'CodeGenDocer', CodeGenDocer) +# setattr(prompt_manager_module, 'CodeGenDocPM', CodeGenDocPM) + + + + +# AGETN_CONFIGS.update({ +# "classGenDoc": { +# "role": { +# "role_prompt": classGenDoc_PROMPT, +# "role_type": "assistant", +# "role_name": "classGenDoc", +# "role_desc": "", +# "agent_type": "CodeGenDocer" +# }, +# "prompt_config": CODE_GENERATE_DOC_PROMPT_CONFIGS, +# "prompt_manager_type": "CodeGenDocPM", +# "chat_turn": 1, +# "focus_agents": [], +# "focus_message_keys": [], +# }, +# "funcGenDoc": { +# "role": { +# "role_prompt": funcGenDoc_PROMPT, +# "role_type": "assistant", +# "role_name": "funcGenDoc", +# "role_desc": "", +# "agent_type": "CodeGenDocer" +# }, +# "prompt_config": CODE_GENERATE_DOC_PROMPT_CONFIGS, +# "prompt_manager_type": "CodeGenDocPM", +# "chat_turn": 1, +# "focus_agents": [], +# "focus_message_keys": [], +# }, +# "codeGenDocsGrouper": { +# "role": { +# "role_prompt": codeGenDocGroup_PROMPT, +# "role_type": "assistant", +# "role_name": "codeGenDocsGrouper", +# "role_desc": "", +# "agent_type": "SelectorAgent" +# }, +# "prompt_config": CODE_GENERATE_GROUP_PROMPT_CONFIGS, +# "group_agents": ["classGenDoc", "funcGenDoc"], +# "chat_turn": 1, +# }, +# }) +# # update new chain configs +# CHAIN_CONFIGS.update({ +# "codeGenDocsGroupChain": { +# "chain_name": "codeGenDocsGroupChain", +# "chain_type": "BaseChain", +# "agents": ["codeGenDocsGrouper"], +# "chat_turn": 1, +# "do_checker": False, +# "chain_prompt": "" +# } +# }) + +# # update phase configs +# PHASE_CONFIGS.update({ +# "codeGenDocsGroup": { +# "phase_name": "codeGenDocsGroup", +# "phase_type": "BasePhase", +# "chains": ["codeGenDocsGroupChain"], +# "do_summary": False, +# "do_search": False, +# 
"do_doc_retrieval": False, +# "do_code_retrieval": False, +# "do_tool_retrieval": False, +# }, +# }) + + +# role_configs = load_role_configs(AGETN_CONFIGS) +# chain_configs = load_chain_configs(CHAIN_CONFIGS) +# phase_configs = load_phase_configs(PHASE_CONFIGS) + +# # log-level,print prompt和llm predict +# os.environ["log_verbose"] = "1" + +# phase_name = "codeGenDocsGroup" +# llm_config = LLMConfig( +# model_name="gpt-4", api_key=os.environ["OPENAI_API_KEY"], +# api_base_url=os.environ["API_BASE_URL"], temperature=0.3 +# ) +# embed_config = EmbedConfig( +# embed_engine="model", embed_model="text2vec-base-chinese", +# embed_model_path=os.path.join(src_dir, "embedding_models/text2vec-base-chinese") +# ) + + +# # initialize codebase +# # delete codebase +# codebase_name = 'client_local' +# code_path = "D://chromeDownloads/devopschat-bot/client_v2/client" +# use_nh = False +# cbh = CodeBaseHandler(codebase_name, code_path, crawl_type='dir', use_nh=use_nh, local_graph_path=CB_ROOT_PATH, +# llm_config=llm_config, embed_config=embed_config) +# cbh.delete_codebase(codebase_name=codebase_name) + + +# # load codebase +# codebase_name = 'client_local' +# code_path = "D://chromeDownloads/devopschat-bot/client_v2/client" +# use_nh = False +# do_interpret = True +# cbh = CodeBaseHandler(codebase_name, code_path, crawl_type='dir', use_nh=use_nh, local_graph_path=CB_ROOT_PATH, +# llm_config=llm_config, embed_config=embed_config) +# cbh.import_code(do_interpret=do_interpret) + + +# phase = BasePhase( +# phase_name, sandbox_server=SANDBOX_SERVER, jupyter_work_path=JUPYTER_WORK_PATH, +# embed_config=embed_config, llm_config=llm_config, kb_root_path=KB_ROOT_PATH, +# ) + +# for vertex_type in ["class", "method"]: +# vertexes = cbh.search_vertices(vertex_type=vertex_type) +# logger.info(f"vertexes={vertexes}") + +# # round-1 +# docs = [] +# for vertex in vertexes: +# vertex = vertex.split("-")[0] # -为method的参数 +# query_content = f"为{vertex_type}节点 {vertex}生成文档" +# query = Message( +# role_name="human", role_type="user", +# role_content=query_content, input_query=query_content, origin_query=query_content, +# code_engine_name="client_local", score_threshold=1.0, top_k=3, cb_search_type="tag", use_nh=use_nh, +# local_graph_path=CB_ROOT_PATH, +# ) +# output_message, output_memory = phase.step(query, reinit_memory=True) +# # print(output_memory.to_str_messages(return_all=True, content_key="parsed_output_list")) +# docs.append(output_memory.get_spec_parserd_output()) + +# import json +# os.makedirs("/home/user/code_base/docs", exist_ok=True) +# with open(f"/home/user/code_base/docs/raw_{vertex_type}.json", "w") as f: +# json.dump(docs, f) + + +# # 下面把生成的文档信息转换成markdown文本 +# from coagent.utils.code2doc_util import * + +# import json +# with open(f"/home/user/code_base/docs/raw_method.json", "r") as f: +# method_raw_data = json.load(f) + +# with open(f"/home/user/code_base/docs/raw_class.json", "r") as f: +# class_raw_data = json.load(f) + + +# method_data = method_info_decode(method_raw_data) +# class_data = class_info_decode(class_raw_data) +# method_mds = encode2md(method_data, method_text_md) +# class_mds = encode2md(class_data, class_text_md) + +# docs_dict = {} +# for k,v in class_mds.items(): +# method_textmds = method_mds.get(k, []) +# for vv in v: +# # 理论上只有一个 +# text_md = vv + +# for method_textmd in method_textmds: +# text_md += "\n
" + method_textmd + +# docs_dict.setdefault(k, []).append(text_md) + +# with open(f"/home/user/code_base/docs/{k}.md", "w") as f: +# f.write(text_md) \ No newline at end of file diff --git a/examples/agent_examples/codeGenTestCases_example.py b/examples/agent_examples/codeGenTestCases_example.py new file mode 100644 index 0000000..e335ed3 --- /dev/null +++ b/examples/agent_examples/codeGenTestCases_example.py @@ -0,0 +1,444 @@ +import os, sys, json +from loguru import logger +src_dir = os.path.join( + os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +) +sys.path.append(src_dir) + +from configs.model_config import KB_ROOT_PATH, JUPYTER_WORK_PATH, CB_ROOT_PATH +from configs.server_config import SANDBOX_SERVER +from coagent.llm_models.llm_config import EmbedConfig, LLMConfig + +from coagent.connector.phase import BasePhase +from coagent.connector.agents import BaseAgent +from coagent.connector.schema import Message +from coagent.codechat.codebase_handler.codebase_handler import CodeBaseHandler + +import importlib +from loguru import logger + + +from coagent.tools import CodeRetrievalSingle, RelatedVerticesRetrival, Vertex2Code + +# 定义一个新的agent类 +class CodeRetrieval(BaseAgent): + + def start_action_step(self, message: Message) -> Message: + '''do action before agent predict ''' + # 根据问题获取代码片段和节点信息 + action_json = CodeRetrievalSingle.run(message.code_engine_name, message.origin_query, llm_config=self.llm_config, embed_config=self.embed_config, search_type="tag", + local_graph_path=message.local_graph_path, use_nh=message.use_nh) + current_vertex = action_json['vertex'] + message.customed_kargs["Code Snippet"] = action_json["code"] + message.customed_kargs['Current_Vertex'] = current_vertex + + # 获取邻近节点 + action_json = RelatedVerticesRetrival.run(message.code_engine_name, message.customed_kargs['Current_Vertex']) + # 获取邻近节点所有代码 + relative_vertex = [] + retrieval_Codes = [] + for vertex in action_json["vertices"]: + # 由于代码是文件级别,所以相同文件代码不再获取 + # logger.debug(f"{current_vertex}, {vertex}") + current_vertex_name = current_vertex.replace("#", "").replace(".java", "" ) if current_vertex.endswith(".java") else current_vertex + if current_vertex_name.split("#")[0] == vertex.split("#")[0]: continue + + action_json = Vertex2Code.run(message.code_engine_name, vertex) + if action_json["code"]: + retrieval_Codes.append(action_json["code"]) + relative_vertex.append(vertex) + # + message.customed_kargs["Retrieval_Codes"] = retrieval_Codes + message.customed_kargs["Relative_vertex"] = relative_vertex + return message + + +# add agent or prompt_manager class +agent_module = importlib.import_module("coagent.connector.agents") +setattr(agent_module, 'CodeRetrieval', CodeRetrieval) + + + +llm_config = LLMConfig( + model_name="gpt-4", api_key=os.environ["OPENAI_API_KEY"], + api_base_url=os.environ["API_BASE_URL"], temperature=0.3 +) +embed_config = EmbedConfig( + embed_engine="model", embed_model="text2vec-base-chinese", + embed_model_path=os.path.join(src_dir, "embedding_models/text2vec-base-chinese") + ) + +## initialize codebase +# delete codebase +codebase_name = 'client_local' +code_path = "D://chromeDownloads/devopschat-bot/client_v2/client" +use_nh = False +cbh = CodeBaseHandler(codebase_name, code_path, crawl_type='dir', use_nh=use_nh, local_graph_path=CB_ROOT_PATH, + llm_config=llm_config, embed_config=embed_config) +cbh.delete_codebase(codebase_name=codebase_name) + + +# load codebase +codebase_name = 'client_local' +code_path = 
"D://chromeDownloads/devopschat-bot/client_v2/client" +use_nh = True +do_interpret = True +cbh = CodeBaseHandler(codebase_name, code_path, crawl_type='dir', use_nh=use_nh, local_graph_path=CB_ROOT_PATH, + llm_config=llm_config, embed_config=embed_config) +cbh.import_code(do_interpret=do_interpret) + + +cbh = CodeBaseHandler(codebase_name, code_path, crawl_type='dir', use_nh=use_nh, local_graph_path=CB_ROOT_PATH, + llm_config=llm_config, embed_config=embed_config) +vertexes = cbh.search_vertices(vertex_type="class") + + +# log-level,print prompt和llm predict +os.environ["log_verbose"] = "0" + +phase_name = "code2Tests" +phase = BasePhase( + phase_name, sandbox_server=SANDBOX_SERVER, jupyter_work_path=JUPYTER_WORK_PATH, + embed_config=embed_config, llm_config=llm_config, kb_root_path=KB_ROOT_PATH, +) + +# round-1 +logger.debug(vertexes) +test_cases = [] +for vertex in vertexes: + query_content = f"为{vertex}生成可执行的测例 " + query = Message( + role_name="human", role_type="user", + role_content=query_content, input_query=query_content, origin_query=query_content, + code_engine_name=codebase_name, score_threshold=1.0, top_k=3, cb_search_type="tag", + use_nh=use_nh, local_graph_path=CB_ROOT_PATH, + ) + output_message, output_memory = phase.step(query, reinit_memory=True) + # print(output_memory.to_str_messages(return_all=True, content_key="parsed_output_list")) + print(output_memory.get_spec_parserd_output()) + values = output_memory.get_spec_parserd_output() + test_code = {k:v for i in values for k,v in i.items() if k in ["SaveFileName", "Test Code"]} + test_cases.append(test_code) + + os.makedirs(f"{CB_ROOT_PATH}/tests", exist_ok=True) + + with open(f"{CB_ROOT_PATH}/tests/{test_code['SaveFileName']}", "w") as f: + f.write(test_code["Test Code"]) + break + + + + +# import os, sys, json +# from loguru import logger +# src_dir = os.path.join( +# os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +# ) +# sys.path.append(src_dir) + +# from configs.model_config import KB_ROOT_PATH, JUPYTER_WORK_PATH, CB_ROOT_PATH +# from configs.server_config import SANDBOX_SERVER +# from coagent.tools import toLangchainTools, TOOL_DICT, TOOL_SETS +# from coagent.llm_models.llm_config import EmbedConfig, LLMConfig + +# from coagent.connector.phase import BasePhase +# from coagent.connector.agents import BaseAgent +# from coagent.connector.chains import BaseChain +# from coagent.connector.schema import ( +# Message, Memory, load_role_configs, load_phase_configs, load_chain_configs, ActionStatus +# ) +# from coagent.connector.memory_manager import BaseMemoryManager +# from coagent.connector.configs import AGETN_CONFIGS, CHAIN_CONFIGS, PHASE_CONFIGS, BASE_PROMPT_CONFIGS +# from coagent.connector.prompt_manager.prompt_manager import PromptManager +# from coagent.codechat.codebase_handler.codebase_handler import CodeBaseHandler + +# import importlib +# from loguru import logger + +# # 下面给定了一份代码片段,以及来自于它的依赖类、依赖方法相关的代码片段,你需要判断是否为这段指定代码片段生成测例。 +# # 理论上所有代码都需要写测例,但是受限于人的精力不可能覆盖所有代码 +# # 考虑以下因素进行裁剪: +# # 功能性: 如果它实现的是一个具体的功能或逻辑,则通常需要编写测试用例以验证其正确性。 +# # 复杂性: 如果代码较为,尤其是包含多个条件判断、循环、异常处理等的代码,更可能隐藏bug,因此应该编写测试用例。如果代码涉及复杂的算法或者逻辑,那么编写测试用例可以帮助确保逻辑的正确性,并在未来的重构中防止引入错误。 +# # 关键性: 如果它是关键路径的一部分或影响到核心功能,那么它就需要被测试。对于核心业务逻辑或者系统的关键组件,应当编写全面的测试用例来确保功能的正确性和稳定性。 +# # 依赖性: 如果代码有外部依赖,可能需要编写集成测试或模拟这些依赖进行单元测试。 +# # 用户输入: 如果代码处理用户输入,尤其是来自外部的、非受控的输入,那么创建测试用例来检查输入验证和处理是很重要的。 +# # 频繁更改:对于经常需要更新或修改的代码,有相应的测试用例可以确保更改不会破坏现有功能。 + + +# # 代码公开或重用:如果代码将被公开或用于其他项目,编写测试用例可以提高代码的可信度和易用性。 + + +# # update new agent configs +# 
judgeGenerateTests_PROMPT = """#### Agent Profile +# When determining the necessity of writing test cases for a given code snippet, +# it's essential to evaluate its interactions with dependent classes and methods (retrieved code snippets), +# in addition to considering these critical factors: +# 1. Functionality: If it implements a concrete function or logic, test cases are typically necessary to verify its correctness. +# 2. Complexity: If the code is complex, especially if it contains multiple conditional statements, loops, exceptions handling, etc., +# it's more likely to harbor bugs, and thus test cases should be written. +# If the code involves complex algorithms or logic, then writing test cases can help ensure the accuracy of the logic and prevent errors during future refactoring. +# 3. Criticality: If it's part of the critical path or affects core functionalities, then it needs to be tested. +# Comprehensive test cases should be written for core business logic or key components of the system to ensure the correctness and stability of the functionality. +# 4. Dependencies: If the code has external dependencies, integration testing may be necessary, or mocking these dependencies during unit testing might be required. +# 5. User Input: If the code handles user input, especially from unregulated external sources, creating test cases to check input validation and handling is important. +# 6. Frequent Changes: For code that requires regular updates or modifications, having the appropriate test cases ensures that changes do not break existing functionalities. + +# #### Input Format + +# **Code Snippet:** the initial Code or objective that the user wanted to achieve + +# **Retrieval Code Snippets:** These are the associated code segments that the main Code Snippet depends on. +# Examine these snippets to understand how they interact with the main snippet and to determine how they might affect the overall functionality. + +# #### Response Output Format +# **Action Status:** Set to 'finished' or 'continued'. +# If set to 'finished', the code snippet does not warrant the generation of a test case. +# If set to 'continued', the code snippet necessitates the creation of a test case. + +# **REASON:** Justify the selection of 'finished' or 'continued', contemplating the decision through a step-by-step rationale. +# """ + +# generateTests_PROMPT = """#### Agent Profile +# As an agent specializing in software quality assurance, +# your mission is to craft comprehensive test cases that bolster the functionality, reliability, and robustness of a specified Code Snippet. +# This task is to be carried out with a keen understanding of the snippet's interactions with its dependent classes and methods—collectively referred to as Retrieval Code Snippets. +# Analyze the details given below to grasp the code's intended purpose, its inherent complexity, and the context within which it operates. +# Your constructed test cases must thoroughly examine the various factors influencing the code's quality and performance. + +# ATTENTION: response carefully referenced "Response Output Format" in format. + +# Each test case should include: +# 1. clear description of the test purpose. +# 2. The input values or conditions for the test. +# 3. The expected outcome or assertion for the test. +# 4. Appropriate tags (e.g., 'functional', 'integration', 'regression') that classify the type of test case. +# 5. 
these test code should have package and import + +# #### Input Format + +# **Code Snippet:** the initial Code or objective that the user wanted to achieve + +# **Retrieval Code Snippets:** These are the interrelated pieces of code sourced from the codebase, which support or influence the primary Code Snippet. + +# #### Response Output Format +# **SaveFileName:** construct a local file name based on Question and Context, such as + +# ```java +# package/class.java +# ``` + + +# **Test Code:** generate the test code for the current Code Snippet. +# ```java +# ... +# ``` + +# """ + +# CODE_GENERATE_TESTS_PROMPT_CONFIGS = [ +# {"field_name": 'agent_profile', "function_name": 'handle_agent_profile', "is_context": False}, +# # {"field_name": 'tool_information',"function_name": 'handle_tool_data', "is_context": False}, +# {"field_name": 'context_placeholder', "function_name": '', "is_context": True}, +# # {"field_name": 'reference_documents', "function_name": 'handle_doc_info'}, +# {"field_name": 'session_records', "function_name": 'handle_session_records'}, +# {"field_name": 'code_snippet', "function_name": 'handle_code_snippet'}, +# {"field_name": 'retrieval_codes', "function_name": 'handle_retrieval_codes', "description": ""}, +# {"field_name": 'output_format', "function_name": 'handle_output_format', 'title': 'Response Output Format', "is_context": False}, +# {"field_name": 'begin!!!', "function_name": 'handle_response', "is_context": False, "omit_if_empty": False} +# ] + + +# class CodeRetrievalPM(PromptManager): +# def handle_code_snippet(self, **kwargs) -> str: +# if 'previous_agent_message' not in kwargs: +# return "" +# previous_agent_message: Message = kwargs['previous_agent_message'] +# code_snippet = previous_agent_message.customed_kargs.get("Code Snippet", "") +# current_vertex = previous_agent_message.customed_kargs.get("Current_Vertex", "") +# instruction = "the initial Code or objective that the user wanted to achieve" +# return instruction + "\n" + f"name: {current_vertex}\n{code_snippet}" + +# def handle_retrieval_codes(self, **kwargs) -> str: +# if 'previous_agent_message' not in kwargs: +# return "" +# previous_agent_message: Message = kwargs['previous_agent_message'] +# Retrieval_Codes = previous_agent_message.customed_kargs["Retrieval_Codes"] +# Relative_vertex = previous_agent_message.customed_kargs["Relative_vertex"] +# instruction = "the initial Code or objective that the user wanted to achieve" +# s = instruction + "\n" + "\n".join([f"name: {vertext}\n{code}" for vertext, code in zip(Relative_vertex, Retrieval_Codes)]) +# return s + + + +# AGETN_CONFIGS.update({ +# "CodeJudger": { +# "role": { +# "role_prompt": judgeGenerateTests_PROMPT, +# "role_type": "assistant", +# "role_name": "CodeJudger", +# "role_desc": "", +# "agent_type": "CodeRetrieval" +# # "agent_type": "BaseAgent" +# }, +# "prompt_config": CODE_GENERATE_TESTS_PROMPT_CONFIGS, +# "prompt_manager_type": "CodeRetrievalPM", +# "chat_turn": 1, +# "focus_agents": [], +# "focus_message_keys": [], +# }, +# "generateTests": { +# "role": { +# "role_prompt": generateTests_PROMPT, +# "role_type": "assistant", +# "role_name": "generateTests", +# "role_desc": "", +# "agent_type": "CodeRetrieval" +# # "agent_type": "BaseAgent" +# }, +# "prompt_config": CODE_GENERATE_TESTS_PROMPT_CONFIGS, +# "prompt_manager_type": "CodeRetrievalPM", +# "chat_turn": 1, +# "focus_agents": [], +# "focus_message_keys": [], +# }, +# }) +# # update new chain configs +# CHAIN_CONFIGS.update({ +# "codeRetrievalChain": { +# "chain_name": 
"codeRetrievalChain", +# "chain_type": "BaseChain", +# "agents": ["CodeJudger", "generateTests"], +# "chat_turn": 1, +# "do_checker": False, +# "chain_prompt": "" +# } +# }) + +# # update phase configs +# PHASE_CONFIGS.update({ +# "codeGenerateTests": { +# "phase_name": "codeGenerateTests", +# "phase_type": "BasePhase", +# "chains": ["codeRetrievalChain"], +# "do_summary": False, +# "do_search": False, +# "do_doc_retrieval": False, +# "do_code_retrieval": False, +# "do_tool_retrieval": False, +# }, +# }) + + +# role_configs = load_role_configs(AGETN_CONFIGS) +# chain_configs = load_chain_configs(CHAIN_CONFIGS) +# phase_configs = load_phase_configs(PHASE_CONFIGS) + + +# from coagent.tools import CodeRetrievalSingle, RelatedVerticesRetrival, Vertex2Code + +# # 定义一个新的agent类 +# class CodeRetrieval(BaseAgent): + +# def start_action_step(self, message: Message) -> Message: +# '''do action before agent predict ''' +# # 根据问题获取代码片段和节点信息 +# action_json = CodeRetrievalSingle.run(message.code_engine_name, message.origin_query, llm_config=self.llm_config, embed_config=self.embed_config, search_type="tag", +# local_graph_path=message.local_graph_path, use_nh=message.use_nh) +# current_vertex = action_json['vertex'] +# message.customed_kargs["Code Snippet"] = action_json["code"] +# message.customed_kargs['Current_Vertex'] = current_vertex + +# # 获取邻近节点 +# action_json = RelatedVerticesRetrival.run(message.code_engine_name, message.customed_kargs['Current_Vertex']) +# # 获取邻近节点所有代码 +# relative_vertex = [] +# retrieval_Codes = [] +# for vertex in action_json["vertices"]: +# # 由于代码是文件级别,所以相同文件代码不再获取 +# # logger.debug(f"{current_vertex}, {vertex}") +# current_vertex_name = current_vertex.replace("#", "").replace(".java", "" ) if current_vertex.endswith(".java") else current_vertex +# if current_vertex_name.split("#")[0] == vertex.split("#")[0]: continue + +# action_json = Vertex2Code.run(message.code_engine_name, vertex) +# if action_json["code"]: +# retrieval_Codes.append(action_json["code"]) +# relative_vertex.append(vertex) +# # +# message.customed_kargs["Retrieval_Codes"] = retrieval_Codes +# message.customed_kargs["Relative_vertex"] = relative_vertex +# return message + + +# # add agent or prompt_manager class +# agent_module = importlib.import_module("coagent.connector.agents") +# prompt_manager_module = importlib.import_module("coagent.connector.prompt_manager") + +# setattr(agent_module, 'CodeRetrieval', CodeRetrieval) +# setattr(prompt_manager_module, 'CodeRetrievalPM', CodeRetrievalPM) + + +# # log-level,print prompt和llm predict +# os.environ["log_verbose"] = "0" + +# phase_name = "codeGenerateTests" +# llm_config = LLMConfig( +# model_name="gpt-4", api_key=os.environ["OPENAI_API_KEY"], +# api_base_url=os.environ["API_BASE_URL"], temperature=0.3 +# ) +# embed_config = EmbedConfig( +# embed_engine="model", embed_model="text2vec-base-chinese", +# embed_model_path=os.path.join(src_dir, "embedding_models/text2vec-base-chinese") +# ) + +# ## initialize codebase +# # delete codebase +# codebase_name = 'client_local' +# code_path = "D://chromeDownloads/devopschat-bot/client_v2/client" +# use_nh = False +# cbh = CodeBaseHandler(codebase_name, code_path, crawl_type='dir', use_nh=use_nh, local_graph_path=CB_ROOT_PATH, +# llm_config=llm_config, embed_config=embed_config) +# cbh.delete_codebase(codebase_name=codebase_name) + + +# # load codebase +# codebase_name = 'client_local' +# code_path = "D://chromeDownloads/devopschat-bot/client_v2/client" +# use_nh = True +# do_interpret = True +# cbh = 
CodeBaseHandler(codebase_name, code_path, crawl_type='dir', use_nh=use_nh, local_graph_path=CB_ROOT_PATH, +# llm_config=llm_config, embed_config=embed_config) +# cbh.import_code(do_interpret=do_interpret) + + +# cbh = CodeBaseHandler(codebase_name, code_path, crawl_type='dir', use_nh=use_nh, local_graph_path=CB_ROOT_PATH, +# llm_config=llm_config, embed_config=embed_config) +# vertexes = cbh.search_vertices(vertex_type="class") + +# phase = BasePhase( +# phase_name, sandbox_server=SANDBOX_SERVER, jupyter_work_path=JUPYTER_WORK_PATH, +# embed_config=embed_config, llm_config=llm_config, kb_root_path=KB_ROOT_PATH, +# ) + +# # round-1 +# logger.debug(vertexes) +# test_cases = [] +# for vertex in vertexes: +# query_content = f"为{vertex}生成可执行的测例 " +# query = Message( +# role_name="human", role_type="user", +# role_content=query_content, input_query=query_content, origin_query=query_content, +# code_engine_name=codebase_name, score_threshold=1.0, top_k=3, cb_search_type="tag", +# use_nh=use_nh, local_graph_path=CB_ROOT_PATH, +# ) +# output_message, output_memory = phase.step(query, reinit_memory=True) +# # print(output_memory.to_str_messages(return_all=True, content_key="parsed_output_list")) +# print(output_memory.get_spec_parserd_output()) +# values = output_memory.get_spec_parserd_output() +# test_code = {k:v for i in values for k,v in i.items() if k in ["SaveFileName", "Test Code"]} +# test_cases.append(test_code) + +# os.makedirs(f"{CB_ROOT_PATH}/tests", exist_ok=True) + +# with open(f"{CB_ROOT_PATH}/tests/{test_code['SaveFileName']}", "w") as f: +# f.write(test_code["Test Code"]) \ No newline at end of file diff --git a/examples/agent_examples/codeReactPhase_example.py b/examples/agent_examples/codeReactPhase_example.py index f4c35ee..ee2e7bd 100644 --- a/examples/agent_examples/codeReactPhase_example.py +++ b/examples/agent_examples/codeReactPhase_example.py @@ -17,7 +17,7 @@ os.environ["log_verbose"] = "2" phase_name = "codeReactPhase" llm_config = LLMConfig( - model_name="gpt-3.5-turbo", model_device="cpu",api_key=os.environ["OPENAI_API_KEY"], + model_name="gpt-3.5-turbo",api_key=os.environ["OPENAI_API_KEY"], api_base_url=os.environ["API_BASE_URL"], temperature=0.3 ) embed_config = EmbedConfig( diff --git a/examples/agent_examples/codeRetrieval_example.py b/examples/agent_examples/codeRetrieval_example.py index e86647e..fbf3e1e 100644 --- a/examples/agent_examples/codeRetrieval_example.py +++ b/examples/agent_examples/codeRetrieval_example.py @@ -18,8 +18,7 @@ from coagent.connector.schema import ( ) from coagent.connector.memory_manager import BaseMemoryManager from coagent.connector.configs import AGETN_CONFIGS, CHAIN_CONFIGS, PHASE_CONFIGS, BASE_PROMPT_CONFIGS -from coagent.connector.utils import parse_section -from coagent.connector.prompt_manager import PromptManager +from coagent.connector.prompt_manager.prompt_manager import PromptManager import importlib from loguru import logger @@ -230,7 +229,7 @@ os.environ["log_verbose"] = "2" phase_name = "codeRetrievalPhase" llm_config = LLMConfig( - model_name="gpt-3.5-turbo", model_device="cpu",api_key=os.environ["OPENAI_API_KEY"], + model_name="gpt-3.5-turbo", api_key=os.environ["OPENAI_API_KEY"], api_base_url=os.environ["API_BASE_URL"], temperature=0.3 ) embed_config = EmbedConfig( @@ -246,7 +245,7 @@ query_content = "UtilsTest 这个类中测试了哪些函数,测试的函数代 query = Message( role_name="human", role_type="user", role_content=query_content, input_query=query_content, origin_query=query_content, - code_engine_name="client", score_threshold=1.0, 
top_k=3, cb_search_type="tag" + code_engine_name="client_1", score_threshold=1.0, top_k=3, cb_search_type="tag" ) diff --git a/examples/agent_examples/codeToolReactPhase_example.py b/examples/agent_examples/codeToolReactPhase_example.py index feb3be1..88cbfcf 100644 --- a/examples/agent_examples/codeToolReactPhase_example.py +++ b/examples/agent_examples/codeToolReactPhase_example.py @@ -24,7 +24,7 @@ os.environ["log_verbose"] = "2" phase_name = "codeToolReactPhase" llm_config = LLMConfig( - model_name="gpt-3.5-turbo-0613", model_device="cpu",api_key=os.environ["OPENAI_API_KEY"], + model_name="gpt-3.5-turbo-0613", api_key=os.environ["OPENAI_API_KEY"], api_base_url=os.environ["API_BASE_URL"], temperature=0.7 ) embed_config = EmbedConfig( diff --git a/examples/agent_examples/docChatPhase_example.py b/examples/agent_examples/docChatPhase_example.py index fceca34..f2b4b7e 100644 --- a/examples/agent_examples/docChatPhase_example.py +++ b/examples/agent_examples/docChatPhase_example.py @@ -17,7 +17,7 @@ from coagent.connector.schema import Message, Memory tools = toLangchainTools([TOOL_DICT[i] for i in TOOL_SETS if i in TOOL_DICT]) llm_config = LLMConfig( - model_name="gpt-3.5-turbo", model_device="cpu",api_key=os.environ["OPENAI_API_KEY"], + model_name="gpt-3.5-turbo",api_key=os.environ["OPENAI_API_KEY"], api_base_url=os.environ["API_BASE_URL"], temperature=0.3 ) embed_config = EmbedConfig( diff --git a/examples/agent_examples/metagpt_phase_example.py b/examples/agent_examples/metagpt_phase_example.py index dafb06a..4a9729b 100644 --- a/examples/agent_examples/metagpt_phase_example.py +++ b/examples/agent_examples/metagpt_phase_example.py @@ -18,7 +18,7 @@ os.environ["log_verbose"] = "0" phase_name = "metagpt_code_devlop" llm_config = LLMConfig( - model_name="gpt-3.5-turbo", model_device="cpu",api_key=os.environ["OPENAI_API_KEY"], + model_name="gpt-3.5-turbo", api_key=os.environ["OPENAI_API_KEY"], api_base_url=os.environ["API_BASE_URL"], temperature=0.3 ) embed_config = EmbedConfig( diff --git a/examples/agent_examples/searchChatPhase_example.py b/examples/agent_examples/searchChatPhase_example.py index e319ac7..cecafb7 100644 --- a/examples/agent_examples/searchChatPhase_example.py +++ b/examples/agent_examples/searchChatPhase_example.py @@ -20,7 +20,7 @@ os.environ["log_verbose"] = "2" phase_name = "searchChatPhase" llm_config = LLMConfig( - model_name="gpt-3.5-turbo", model_device="cpu",api_key=os.environ["OPENAI_API_KEY"], + model_name="gpt-3.5-turbo", api_key=os.environ["OPENAI_API_KEY"], api_base_url=os.environ["API_BASE_URL"], temperature=0.3 ) embed_config = EmbedConfig( diff --git a/examples/agent_examples/toolReactPhase_example.py b/examples/agent_examples/toolReactPhase_example.py index 93d0a15..7e86be9 100644 --- a/examples/agent_examples/toolReactPhase_example.py +++ b/examples/agent_examples/toolReactPhase_example.py @@ -18,7 +18,7 @@ os.environ["log_verbose"] = "2" phase_name = "toolReactPhase" llm_config = LLMConfig( - model_name="gpt-3.5-turbo", model_device="cpu",api_key=os.environ["OPENAI_API_KEY"], + model_name="gpt-3.5-turbo",api_key=os.environ["OPENAI_API_KEY"], api_base_url=os.environ["API_BASE_URL"], temperature=0.3 ) embed_config = EmbedConfig( diff --git a/examples/api.py b/examples/api.py index 2689169..7c48552 100644 --- a/examples/api.py +++ b/examples/api.py @@ -151,9 +151,9 @@ def create_app(): )(delete_cb) app.post("/code_base/code_base_chat", - tags=["Code Base Management"], - summary="删除 code_base" - )(delete_cb) + tags=["Code Base Management"], + 
summary="code_base 对话" + )(search_code) app.get("/code_base/list_code_bases", tags=["Code Base Management"], diff --git a/examples/auto_examples/auto_feedback_from_code_execution.py b/examples/auto_examples/auto_feedback_from_code_execution.py index 2c03ea4..c9a3754 100644 --- a/examples/auto_examples/auto_feedback_from_code_execution.py +++ b/examples/auto_examples/auto_feedback_from_code_execution.py @@ -117,7 +117,7 @@ PHASE_CONFIGS.update({ llm_config = LLMConfig( - model_name="gpt-3.5-turbo", model_device="cpu",api_key=os.environ["OPENAI_API_KEY"], + model_name="gpt-3.5-turbo", api_key=os.environ["OPENAI_API_KEY"], api_base_url=os.environ["API_BASE_URL"], temperature=0.3 ) diff --git a/examples/start.py b/examples/start.py index a37022d..2275cfb 100644 --- a/examples/start.py +++ b/examples/start.py @@ -98,12 +98,6 @@ def start_docker(client, script_shs, ports, image_name, container_name, mounts=N network_name ='my_network' def start_sandbox_service(network_name ='my_network'): - # networks = client.networks.list() - # if any([network_name==i.attrs["Name"] for i in networks]): - # network = client.networks.get(network_name) - # else: - # network = client.networks.create('my_network', driver='bridge') - mount = Mount( type='bind', source=os.path.join(src_dir, "jupyter_work"), @@ -114,6 +108,12 @@ def start_sandbox_service(network_name ='my_network'): # 沙盒的启动与服务的启动是独立的 if SANDBOX_SERVER["do_remote"]: client = docker.from_env() + networks = client.networks.list() + if any([network_name==i.attrs["Name"] for i in networks]): + network = client.networks.get(network_name) + else: + network = client.networks.create('my_network', driver='bridge') + # 启动容器 logger.info("start container sandbox service") JUPYTER_WORK_PATH = "/home/user/chatbot/jupyter_work" @@ -150,7 +150,7 @@ def start_api_service(sandbox_host=DEFAULT_BIND_HOST): client = docker.from_env() logger.info("start container service") check_process("api.py", do_stop=True) - check_process("sdfile_api.py", do_stop=True) + check_process("llm_api.py", do_stop=True) check_process("sdfile_api.py", do_stop=True) check_process("webui.py", do_stop=True) mount = Mount( @@ -159,27 +159,28 @@ def start_api_service(sandbox_host=DEFAULT_BIND_HOST): target='/home/user/chatbot/', read_only=False # 如果需要只读访问,将此选项设置为True ) - mount_database = Mount( - type='bind', - source=os.path.join(src_dir, "knowledge_base"), - target='/home/user/knowledge_base/', - read_only=False # 如果需要只读访问,将此选项设置为True - ) - mount_code_database = Mount( - type='bind', - source=os.path.join(src_dir, "code_base"), - target='/home/user/code_base/', - read_only=False # 如果需要只读访问,将此选项设置为True - ) + # mount_database = Mount( + # type='bind', + # source=os.path.join(src_dir, "knowledge_base"), + # target='/home/user/knowledge_base/', + # read_only=False # 如果需要只读访问,将此选项设置为True + # ) + # mount_code_database = Mount( + # type='bind', + # source=os.path.join(src_dir, "code_base"), + # target='/home/user/code_base/', + # read_only=False # 如果需要只读访问,将此选项设置为True + # ) ports={ f"{API_SERVER['docker_port']}/tcp": f"{API_SERVER['port']}/tcp", f"{WEBUI_SERVER['docker_port']}/tcp": f"{WEBUI_SERVER['port']}/tcp", f"{SDFILE_API_SERVER['docker_port']}/tcp": f"{SDFILE_API_SERVER['port']}/tcp", f"{NEBULA_GRAPH_SERVER['docker_port']}/tcp": f"{NEBULA_GRAPH_SERVER['port']}/tcp" } - mounts = [mount, mount_database, mount_code_database] + # mounts = [mount, mount_database, mount_code_database] + mounts = [mount] script_shs = [ - "mkdir -p /home/user/logs", + "mkdir -p /home/user/chatbot/logs", ''' if [ -d 
"/home/user/chatbot/data/nebula_data/data/meta" ]; then cp -r /home/user/chatbot/data/nebula_data/data /usr/local/nebula/ @@ -197,12 +198,12 @@ def start_api_service(sandbox_host=DEFAULT_BIND_HOST): "pip install jieba", "pip install duckduckgo-search", - "nohup python chatbot/examples/sdfile_api.py > /home/user/logs/sdfile_api.log 2>&1 &", + "nohup python chatbot/examples/sdfile_api.py > /home/user/chatbot/logs/sdfile_api.log 2>&1 &", f"export DUCKDUCKGO_PROXY=socks5://host.docker.internal:13659 && export SANDBOX_HOST={sandbox_host} &&\ - nohup python chatbot/examples/api.py > /home/user/logs/api.log 2>&1 &", + nohup python chatbot/examples/api.py > /home/user/chatbot/logs/api.log 2>&1 &", "nohup python chatbot/examples/llm_api.py > /home/user/llm.log 2>&1 &", f"export DUCKDUCKGO_PROXY=socks5://host.docker.internal:13659 && export SANDBOX_HOST={sandbox_host} &&\ - cd chatbot/examples && nohup streamlit run webui.py > /home/user/logs/start_webui.log 2>&1 &" + cd chatbot/examples && nohup streamlit run webui.py > /home/user/chatbot/logs/start_webui.log 2>&1 &" ] if check_docker(client, CONTRAINER_NAME, do_stop=True): container = start_docker(client, script_shs, ports, IMAGE_NAME, CONTRAINER_NAME, mounts, network=network_name) @@ -212,12 +213,9 @@ def start_api_service(sandbox_host=DEFAULT_BIND_HOST): # 关闭之前启动的docker 服务 # check_docker(client, CONTRAINER_NAME, do_stop=True, ) - # api_sh = "nohup python ../coagent/service/api.py > ../logs/api.log 2>&1 &" api_sh = "nohup python api.py > ../logs/api.log 2>&1 &" - # sdfile_sh = "nohup python ../coagent/service/sdfile_api.py > ../logs/sdfile_api.log 2>&1 &" sdfile_sh = "nohup python sdfile_api.py > ../logs/sdfile_api.log 2>&1 &" notebook_sh = f"nohup jupyter-notebook --NotebookApp.token=mytoken --port={SANDBOX_SERVER['port']} --allow-root --ip=0.0.0.0 --notebook-dir={JUPYTER_WORK_PATH} --no-browser --ServerApp.disable_check_xsrf=True > ../logs/sandbox.log 2>&1 &" - # llm_sh = "nohup python ../coagent/service/llm_api.py > ../logs/llm_api.log 2>&1 &" llm_sh = "nohup python llm_api.py > ../logs/llm_api.log 2>&1 &" webui_sh = "streamlit run webui.py" if USE_TTY else "streamlit run webui.py" diff --git a/examples/webui/code.py b/examples/webui/code.py index 997edca..bf9bf54 100644 --- a/examples/webui/code.py +++ b/examples/webui/code.py @@ -22,7 +22,7 @@ from coagent.service.service_factory import get_cb_details, get_cb_details_by_cb from coagent.orm import table_init -from configs.model_config import EMBEDDING_DEVICE, EMBEDDING_ENGINE, EMBEDDING_MODEL, embedding_model_dict +from configs.model_config import EMBEDDING_DEVICE, EMBEDDING_ENGINE, EMBEDDING_MODEL, embedding_model_dict,llm_model_dict # SENTENCE_SIZE = 100 cell_renderer = JsCode("""function(params) {if(params.value==true){return '✓'}else{return '×'}}""") @@ -117,6 +117,8 @@ def code_page(api: ApiRequest): embed_model_path=embedding_model_dict[EMBEDDING_MODEL], embedding_device=EMBEDDING_DEVICE, llm_model=LLM_MODEL, + api_key=llm_model_dict[LLM_MODEL]["api_key"], + api_base_url=llm_model_dict[LLM_MODEL]["api_base_url"], ) st.toast(ret.get("msg", " ")) st.session_state["selected_cb_name"] = cb_name @@ -153,6 +155,8 @@ def code_page(api: ApiRequest): embed_model_path=embedding_model_dict[EMBEDDING_MODEL], embedding_device=EMBEDDING_DEVICE, llm_model=LLM_MODEL, + api_key=llm_model_dict[LLM_MODEL]["api_key"], + api_base_url=llm_model_dict[LLM_MODEL]["api_base_url"], ) st.toast(ret.get("msg", "删除成功")) time.sleep(0.05) diff --git a/examples/webui/dialogue.py b/examples/webui/dialogue.py index 
3beae3f..36a4e18 100644 --- a/examples/webui/dialogue.py +++ b/examples/webui/dialogue.py @@ -11,7 +11,7 @@ from coagent.chat.search_chat import SEARCH_ENGINES from coagent.connector import PHASE_LIST, PHASE_CONFIGS from coagent.service.service_factory import get_cb_details_by_cb_name -from configs.model_config import EMBEDDING_DEVICE, EMBEDDING_MODEL, embedding_model_dict, EMBEDDING_ENGINE, KB_ROOT_PATH +from configs.model_config import EMBEDDING_DEVICE, EMBEDDING_MODEL, embedding_model_dict, EMBEDDING_ENGINE, KB_ROOT_PATH, llm_model_dict chat_box = ChatBox( assistant_avatar="../sources/imgs/devops-chatbot2.png" ) @@ -174,7 +174,7 @@ def dialogue_page(api: ApiRequest): is_detailed = st.toggle(webui_configs["dialogue"]["phase_toggle_detailed_name"], False) tool_using_on = st.toggle( webui_configs["dialogue"]["phase_toggle_doToolUsing"], - PHASE_CONFIGS[choose_phase]["do_using_tool"]) + PHASE_CONFIGS[choose_phase].get("do_using_tool", False)) tool_selects = [] if tool_using_on: with st.expander("工具军火库", True): @@ -183,7 +183,7 @@ def dialogue_page(api: ApiRequest): TOOL_SETS, ["WeatherInfo"]) search_on = st.toggle(webui_configs["dialogue"]["phase_toggle_doSearch"], - PHASE_CONFIGS[choose_phase]["do_search"]) + PHASE_CONFIGS[choose_phase].get("do_search", False)) search_engine, top_k = None, 3 if search_on: with st.expander(webui_configs["dialogue"]["expander_search_name"], True): @@ -195,7 +195,8 @@ def dialogue_page(api: ApiRequest): doc_retrieval_on = st.toggle( webui_configs["dialogue"]["phase_toggle_doDocRetrieval"], - PHASE_CONFIGS[choose_phase]["do_doc_retrieval"]) + PHASE_CONFIGS[choose_phase].get("do_doc_retrieval", False) + ) selected_kb, top_k, score_threshold = None, 3, 1.0 if doc_retrieval_on: with st.expander(webui_configs["dialogue"]["kbase_expander_name"], True): @@ -215,7 +216,7 @@ def dialogue_page(api: ApiRequest): code_retrieval_on = st.toggle( webui_configs["dialogue"]["phase_toggle_doCodeRetrieval"], - PHASE_CONFIGS[choose_phase]["do_code_retrieval"]) + PHASE_CONFIGS[choose_phase].get("do_code_retrieval", False)) selected_cb, top_k = None, 1 cb_search_type = "tag" if code_retrieval_on: @@ -296,7 +297,8 @@ def dialogue_page(api: ApiRequest): r = api.chat_chat( prompt, history, no_remote_api=True, embed_model=EMBEDDING_MODEL, embed_model_path=embedding_model_dict[EMBEDDING_MODEL], - model_device=EMBEDDING_DEVICE, embed_engine=EMBEDDING_ENGINE, + model_device=EMBEDDING_DEVICE, embed_engine=EMBEDDING_ENGINE,api_key=llm_model_dict[LLM_MODEL]["api_key"], + api_base_url=llm_model_dict[LLM_MODEL]["api_base_url"], llm_model=LLM_MODEL) for t in r: if error_msg := check_error_msg(t): # check whether error occured @@ -362,6 +364,8 @@ def dialogue_page(api: ApiRequest): "embed_engine": EMBEDDING_ENGINE, "kb_root_path": KB_ROOT_PATH, "model_name": LLM_MODEL, + "api_key": llm_model_dict[LLM_MODEL]["api_key"], + "api_base_url": llm_model_dict[LLM_MODEL]["api_base_url"], } text = "" d = {"docs": []} @@ -405,7 +409,10 @@ def dialogue_page(api: ApiRequest): api.knowledge_base_chat( prompt, selected_kb, kb_top_k, score_threshold, history, embed_model=EMBEDDING_MODEL, embed_model_path=embedding_model_dict[EMBEDDING_MODEL], - model_device=EMBEDDING_DEVICE, embed_engine=EMBEDDING_ENGINE, llm_model=LLM_MODEL) + model_device=EMBEDDING_DEVICE, embed_engine=EMBEDDING_ENGINE, llm_model=LLM_MODEL, + api_key=llm_model_dict[LLM_MODEL]["api_key"], + api_base_url=llm_model_dict[LLM_MODEL]["api_base_url"], + ) ): if error_msg := check_error_msg(d): # check whether error occured st.error(error_msg) @@ 
-415,11 +422,7 @@ def dialogue_page(api: ApiRequest): # chat_box.update_msg("知识库匹配结果: \n\n".join(d["docs"]), element_index=1, streaming=False, state="complete") chat_box.update_msg(text, element_index=0, streaming=False) # 更新最终的字符串,去除光标 chat_box.update_msg(f"{webui_configs['chat']['chatbox_doc_result']}: \n\n".join(d["docs"]), element_index=1, streaming=False, state="complete") - # # 判断是否存在代码, 并提高编辑功能,执行功能 - # code_text = api.codebox.decode_code_from_text(text) - # GLOBAL_EXE_CODE_TEXT = code_text - # if code_text and code_exec_on: - # codebox_res = api.codebox_chat("```"+code_text+"```", do_code_exe=True) + elif dialogue_mode == webui_configs["dialogue"]["mode"][2]: logger.info('prompt={}'.format(prompt)) logger.info('history={}'.format(history)) @@ -438,7 +441,9 @@ def dialogue_page(api: ApiRequest): cb_search_type=cb_search_type, no_remote_api=True, embed_model=EMBEDDING_MODEL, embed_model_path=embedding_model_dict[EMBEDDING_MODEL], - embed_engine=EMBEDDING_ENGINE, llm_model=LLM_MODEL + embed_engine=EMBEDDING_ENGINE, llm_model=LLM_MODEL, + api_key=llm_model_dict[LLM_MODEL]["api_key"], + api_base_url=llm_model_dict[LLM_MODEL]["api_base_url"], )): if error_msg := check_error_msg(d): st.error(error_msg) @@ -448,6 +453,7 @@ chat_box.update_msg(text, element_index=0) # postprocess + logger.debug(f"d={d}") text = replace_lt_gt(text) chat_box.update_msg(text, element_index=0, streaming=False) # 更新最终的字符串,去除光标 logger.debug('text={}'.format(text)) @@ -467,7 +473,9 @@ api.search_engine_chat( prompt, search_engine, se_top_k, history, embed_model=EMBEDDING_MODEL, embed_model_path=embedding_model_dict[EMBEDDING_MODEL], - model_device=EMBEDDING_DEVICE, embed_engine=EMBEDDING_ENGINE, llm_model=LLM_MODEL) + model_device=EMBEDDING_DEVICE, embed_engine=EMBEDDING_ENGINE, llm_model=LLM_MODEL, + api_key=llm_model_dict[LLM_MODEL]["api_key"], + api_base_url=llm_model_dict[LLM_MODEL]["api_base_url"],) ): if error_msg := check_error_msg(d): # check whether error occured st.error(error_msg) @@ -477,56 +485,11 @@ # chat_box.update_msg("搜索匹配结果: \n\n".join(d["docs"]), element_index=1, streaming=False) chat_box.update_msg(text, element_index=0, streaming=False) # 更新最终的字符串,去除光标 chat_box.update_msg(f"{webui_configs['chat']['chatbox_search_result']}: \n\n".join(d["docs"]), element_index=1, streaming=False, state="complete") - # # 判断是否存在代码, 并提高编辑功能,执行功能 - # code_text = api.codebox.decode_code_from_text(text) - # GLOBAL_EXE_CODE_TEXT = code_text - # if code_text and code_exec_on: - # codebox_res = api.codebox_chat("```"+code_text+"```", do_code_exe=True) # 将上传文件清空 st.session_state["interpreter_file_key"] += 1 st.experimental_rerun() - # if code_interpreter_on: - # with st.expander(webui_configs['sandbox']['expander_code_name'], False): - # code_part = st.text_area( - # webui_configs['sandbox']['textArea_code_name'], code_text, key="code_text") - # cols = st.columns(2) - # if cols[0].button( - # webui_configs['sandbox']['button_modify_code_name'], - # use_container_width=True, - # ): - # code_text = code_part - # GLOBAL_EXE_CODE_TEXT = code_text - # st.toast(webui_configs['sandbox']['text_modify_code']) - - # if cols[1].button( - # webui_configs['sandbox']['button_exec_code_name'], - # use_container_width=True - # ): - # if code_text: - # codebox_res = api.codebox_chat("```"+code_text+"```", do_code_exe=True) - # st.toast(webui_configs['sandbox']['text_execing_code'],) - # else: - # 
st.toast(webui_configs['sandbox']['text_error_exec_code'],) - - # #TODO 这段信息会被记录到history里 - # if codebox_res is not None and codebox_res.code_exe_status != 200: - # st.toast(f"{codebox_res.code_exe_response}") - - # if codebox_res is not None and codebox_res.code_exe_status == 200: - # st.toast(f"codebox_chat {codebox_res}") - # chat_box.ai_say(Markdown(code_text, in_expander=True, title="code interpreter", unsafe_allow_html=True), ) - # if codebox_res.code_exe_type == "image/png": - # base_text = f"```\n{code_text}\n```\n\n" - # img_html = "".format( - # codebox_res.code_exe_response - # ) - # chat_box.update_msg(img_html, streaming=False, state="complete") - # else: - # chat_box.update_msg('```\n'+code_text+'\n```'+"\n\n"+'```\n'+codebox_res.code_exe_response+'\n```', - # streaming=False, state="complete") - now = datetime.now() with st.sidebar: diff --git a/examples/webui/document.py b/examples/webui/document.py index 16b2f72..4f41c97 100644 --- a/examples/webui/document.py +++ b/examples/webui/document.py @@ -14,7 +14,8 @@ from coagent.orm import table_init from configs.model_config import ( KB_ROOT_PATH, kbs_config, DEFAULT_VS_TYPE, WEB_CRAWL_PATH, - EMBEDDING_DEVICE, EMBEDDING_ENGINE, EMBEDDING_MODEL, embedding_model_dict + EMBEDDING_DEVICE, EMBEDDING_ENGINE, EMBEDDING_MODEL, embedding_model_dict, + llm_model_dict ) # SENTENCE_SIZE = 100 @@ -136,6 +137,8 @@ def knowledge_page( embed_engine=EMBEDDING_ENGINE, embedding_device= EMBEDDING_DEVICE, embed_model_path=embedding_model_dict[embed_model], + api_key=llm_model_dict[LLM_MODEL]["api_key"], + api_base_url=llm_model_dict[LLM_MODEL]["api_base_url"], ) st.toast(ret.get("msg", " ")) st.session_state["selected_kb_name"] = kb_name @@ -160,7 +163,10 @@ def knowledge_page( data = [{"file": f, "knowledge_base_name": kb, "not_refresh_vs_cache": True, "embed_model": EMBEDDING_MODEL, "embed_model_path": embedding_model_dict[EMBEDDING_MODEL], "model_device": EMBEDDING_DEVICE, - "embed_engine": EMBEDDING_ENGINE} + "embed_engine": EMBEDDING_ENGINE, + "api_key": llm_model_dict[LLM_MODEL]["api_key"], + "api_base_url": llm_model_dict[LLM_MODEL]["api_base_url"], + } for f in files] data[-1]["not_refresh_vs_cache"]=False for k in data: @@ -210,7 +216,9 @@ def knowledge_page( "embed_model": EMBEDDING_MODEL, "embed_model_path": embedding_model_dict[EMBEDDING_MODEL], "model_device": EMBEDDING_DEVICE, - "embed_engine": EMBEDDING_ENGINE}] + "embed_engine": EMBEDDING_ENGINE, + "api_key": llm_model_dict[LLM_MODEL]["api_key"], + "api_base_url": llm_model_dict[LLM_MODEL]["api_base_url"],}] for k in data: ret = api.upload_kb_doc(**k) logger.info(ret) @@ -297,7 +305,9 @@ def knowledge_page( api.update_kb_doc(kb, row["file_name"], embed_engine=EMBEDDING_ENGINE,embed_model=EMBEDDING_MODEL, embed_model_path=embedding_model_dict[EMBEDDING_MODEL], - model_device=EMBEDDING_DEVICE + model_device=EMBEDDING_DEVICE, + api_key=llm_model_dict[LLM_MODEL]["api_key"], + api_base_url=llm_model_dict[LLM_MODEL]["api_base_url"], ) st.experimental_rerun() @@ -311,7 +321,9 @@ def knowledge_page( api.delete_kb_doc(kb, row["file_name"], embed_engine=EMBEDDING_ENGINE,embed_model=EMBEDDING_MODEL, embed_model_path=embedding_model_dict[EMBEDDING_MODEL], - model_device=EMBEDDING_DEVICE) + model_device=EMBEDDING_DEVICE, + api_key=llm_model_dict[LLM_MODEL]["api_key"], + api_base_url=llm_model_dict[LLM_MODEL]["api_base_url"],) st.experimental_rerun() if cols[3].button( @@ -323,7 +335,9 @@ def knowledge_page( ret = api.delete_kb_doc(kb, row["file_name"], True, 
embed_engine=EMBEDDING_ENGINE,embed_model=EMBEDDING_MODEL, embed_model_path=embedding_model_dict[EMBEDDING_MODEL], - model_device=EMBEDDING_DEVICE) + model_device=EMBEDDING_DEVICE, + api_key=llm_model_dict[LLM_MODEL]["api_key"], + api_base_url=llm_model_dict[LLM_MODEL]["api_base_url"],) st.toast(ret.get("msg", " ")) st.experimental_rerun() @@ -344,6 +358,8 @@ def knowledge_page( for d in api.recreate_vector_store( kb, vs_type=default_vs_type, embed_model=embedding_model, embedding_device=EMBEDDING_DEVICE, embed_model_path=embedding_model_dict["embedding_model"], embed_engine=EMBEDDING_ENGINE, + api_key=llm_model_dict[LLM_MODEL]["api_key"], + api_base_url=llm_model_dict[LLM_MODEL]["api_base_url"], ): if msg := check_error_msg(d): st.toast(msg) diff --git a/examples/webui/utils.py b/examples/webui/utils.py index 35f3892..57607e4 100644 --- a/examples/webui/utils.py +++ b/examples/webui/utils.py @@ -299,7 +299,9 @@ class ApiRequest: stream: bool = True, no_remote_api: bool = None, embed_model: str="", embed_model_path: str="", model_device: str="", embed_engine: str="", - llm_model: str ="", temperature: float= 0.2 + llm_model: str ="", temperature: float= 0.2, + api_key: str=os.environ["OPENAI_API_KEY"], + api_base_url: str = os.environ["API_BASE_URL"], ): ''' 对应api.py/chat/chat接口 @@ -311,8 +313,8 @@ class ApiRequest: "query": query, "history": history, "stream": stream, - "api_key": os.environ["OPENAI_API_KEY"], - "api_base_url": os.environ["API_BASE_URL"], + "api_key": api_key, + "api_base_url": api_base_url, "embed_model": embed_model, "embed_model_path": embed_model_path, "embed_engine": embed_engine, @@ -339,7 +341,9 @@ class ApiRequest: stream: bool = True, no_remote_api: bool = None, embed_model: str="", embed_model_path: str="", model_device: str="", embed_engine: str="", - llm_model: str ="", temperature: float= 0.2 + llm_model: str ="", temperature: float= 0.2, + api_key: str=os.environ["OPENAI_API_KEY"], + api_base_url: str = os.environ["API_BASE_URL"], ): ''' 对应api.py/chat/knowledge_base_chat接口 @@ -355,8 +359,8 @@ class ApiRequest: "history": history, "stream": stream, "local_doc_url": no_remote_api, - "api_key": os.environ["OPENAI_API_KEY"], - "api_base_url": os.environ["API_BASE_URL"], + "api_key": api_key, + "api_base_url": api_base_url, "embed_model": embed_model, "embed_model_path": embed_model_path, "embed_engine": embed_engine, @@ -386,7 +390,10 @@ class ApiRequest: stream: bool = True, no_remote_api: bool = None, embed_model: str="", embed_model_path: str="", model_device: str="", embed_engine: str="", - llm_model: str ="", temperature: float= 0.2 + llm_model: str ="", temperature: float= 0.2, + api_key: str=os.environ["OPENAI_API_KEY"], + api_base_url: str = os.environ["API_BASE_URL"], + ): ''' 对应api.py/chat/search_engine_chat接口 @@ -400,8 +407,8 @@ class ApiRequest: "top_k": top_k, "history": history, "stream": stream, - "api_key": os.environ["OPENAI_API_KEY"], - "api_base_url": os.environ["API_BASE_URL"], + "api_key": api_key, + "api_base_url": api_base_url, "embed_model": embed_model, "embed_model_path": embed_model_path, "embed_engine": embed_engine, @@ -432,7 +439,9 @@ class ApiRequest: stream: bool = True, no_remote_api: bool = None, embed_model: str="", embed_model_path: str="", model_device: str="", embed_engine: str="", - llm_model: str ="", temperature: float= 0.2 + llm_model: str ="", temperature: float= 0.2, + api_key: str=os.environ["OPENAI_API_KEY"], + api_base_url: str = os.environ["API_BASE_URL"], ): ''' 对应api.py/chat/knowledge_base_chat接口 @@ -458,8 +467,8 
@@ class ApiRequest: "cb_search_type": cb_search_type, "stream": stream, "local_doc_url": no_remote_api, - "api_key": os.environ["OPENAI_API_KEY"], - "api_base_url": os.environ["API_BASE_URL"], + "api_key": api_key, + "api_base_url": api_base_url, "embed_model": embed_model, "embed_model_path": embed_model_path, "embed_engine": embed_engine, @@ -510,6 +519,8 @@ class ApiRequest: embed_model: str="", embed_model_path: str="", model_device: str="", embed_engine: str="", temperature: float=0.2, model_name:str ="", + api_key: str=os.environ["OPENAI_API_KEY"], + api_base_url: str = os.environ["API_BASE_URL"], ): ''' 对应api.py/chat/chat接口 @@ -541,8 +552,8 @@ class ApiRequest: "isDetailed": isDetailed, "upload_file": upload_file, "kb_root_path": kb_root_path, - "api_key": os.environ["OPENAI_API_KEY"], - "api_base_url": os.environ["API_BASE_URL"], + "api_key": api_key, + "api_base_url": api_base_url, "embed_model": embed_model, "embed_model_path": embed_model_path, "embed_engine": embed_engine, @@ -588,6 +599,8 @@ class ApiRequest: embed_model: str="", embed_model_path: str="", model_device: str="", embed_engine: str="", temperature: float=0.2, model_name: str="", + api_key: str=os.environ["OPENAI_API_KEY"], + api_base_url: str = os.environ["API_BASE_URL"], ): ''' 对应api.py/chat/chat接口 @@ -620,8 +633,8 @@ class ApiRequest: "isDetailed": isDetailed, "upload_file": upload_file, "kb_root_path": kb_root_path, - "api_key": os.environ["OPENAI_API_KEY"], - "api_base_url": os.environ["API_BASE_URL"], + "api_key": api_key, + "api_base_url": api_base_url, "embed_model": embed_model, "embed_model_path": embed_model_path, "embed_engine": embed_engine, @@ -694,7 +707,9 @@ class ApiRequest: no_remote_api: bool = None, kb_root_path: str =KB_ROOT_PATH, embed_model: str="", embed_model_path: str="", - embedding_device: str="", embed_engine: str="" + embedding_device: str="", embed_engine: str="", + api_key: str=os.environ["OPENAI_API_KEY"], + api_base_url: str = os.environ["API_BASE_URL"], ): ''' 对应api.py/knowledge_base/create_knowledge_base接口 @@ -706,8 +721,8 @@ class ApiRequest: "knowledge_base_name": knowledge_base_name, "vector_store_type": vector_store_type, "kb_root_path": kb_root_path, - "api_key": os.environ["OPENAI_API_KEY"], - "api_base_url": os.environ["API_BASE_URL"], + "api_key": api_key, + "api_base_url": api_base_url, "embed_model": embed_model, "embed_model_path": embed_model_path, "model_device": embedding_device, @@ -781,7 +796,9 @@ class ApiRequest: no_remote_api: bool = None, kb_root_path: str = KB_ROOT_PATH, embed_model: str="", embed_model_path: str="", - model_device: str="", embed_engine: str="" + model_device: str="", embed_engine: str="", + api_key: str=os.environ["OPENAI_API_KEY"], + api_base_url: str = os.environ["API_BASE_URL"], ): ''' 对应api.py/knowledge_base/upload_docs接口 @@ -810,8 +827,8 @@ class ApiRequest: override, not_refresh_vs_cache, kb_root_path=kb_root_path, - api_key=os.environ["OPENAI_API_KEY"], - api_base_url=os.environ["API_BASE_URL"], + api_key=api_key, + api_base_url=api_base_url, embed_model=embed_model, embed_model_path=embed_model_path, model_device=model_device, @@ -839,7 +856,9 @@ class ApiRequest: no_remote_api: bool = None, kb_root_path: str = KB_ROOT_PATH, embed_model: str="", embed_model_path: str="", - model_device: str="", embed_engine: str="" + model_device: str="", embed_engine: str="", + api_key: str=os.environ["OPENAI_API_KEY"], + api_base_url: str = os.environ["API_BASE_URL"], ): ''' 对应api.py/knowledge_base/delete_doc接口 @@ -853,8 +872,8 @@ class 
ApiRequest: "delete_content": delete_content, "not_refresh_vs_cache": not_refresh_vs_cache, "kb_root_path": kb_root_path, - "api_key": os.environ["OPENAI_API_KEY"], - "api_base_url": os.environ["API_BASE_URL"], + "api_key": api_key, + "api_base_url": api_base_url, "embed_model": embed_model, "embed_model_path": embed_model_path, "model_device": model_device, @@ -878,7 +897,9 @@ class ApiRequest: not_refresh_vs_cache: bool = False, no_remote_api: bool = None, embed_model: str="", embed_model_path: str="", - model_device: str="", embed_engine: str="" + model_device: str="", embed_engine: str="", + api_key: str=os.environ["OPENAI_API_KEY"], + api_base_url: str = os.environ["API_BASE_URL"], ): ''' 对应api.py/knowledge_base/update_doc接口 @@ -889,8 +910,8 @@ class ApiRequest: if no_remote_api: response = run_async(update_doc( knowledge_base_name, file_name, not_refresh_vs_cache, kb_root_path=KB_ROOT_PATH, - api_key=os.environ["OPENAI_API_KEY"], - api_base_url=os.environ["API_BASE_URL"], + api_key=api_key, + api_base_url=api_base_url, embed_model=embed_model, embed_model_path=embed_model_path, model_device=model_device, @@ -915,7 +936,9 @@ class ApiRequest: no_remote_api: bool = None, kb_root_path: str =KB_ROOT_PATH, embed_model: str="", embed_model_path: str="", - embedding_device: str="", embed_engine: str="" + embedding_device: str="", embed_engine: str="", + api_key: str=os.environ["OPENAI_API_KEY"], + api_base_url: str = os.environ["API_BASE_URL"], ): ''' 对应api.py/knowledge_base/recreate_vector_store接口 @@ -928,8 +951,8 @@ class ApiRequest: "allow_empty_kb": allow_empty_kb, "vs_type": vs_type, "kb_root_path": kb_root_path, - "api_key": os.environ["OPENAI_API_KEY"], - "api_base_url": os.environ["API_BASE_URL"], + "api_key": api_key, + "api_base_url": api_base_url, "embed_model": embed_model, "embed_model_path": embed_model_path, "model_device": embedding_device, @@ -1041,7 +1064,9 @@ class ApiRequest: # code base 相关操作 def create_code_base(self, cb_name, zip_file, do_interpret: bool, no_remote_api: bool = None, embed_model: str="", embed_model_path: str="", embedding_device: str="", embed_engine: str="", - llm_model: str ="", temperature: float= 0.2 + llm_model: str ="", temperature: float= 0.2, + api_key: str=os.environ["OPENAI_API_KEY"], + api_base_url: str = os.environ["API_BASE_URL"], ): ''' 创建 code_base @@ -1067,8 +1092,8 @@ class ApiRequest: "cb_name": cb_name, "code_path": raw_code_path, "do_interpret": do_interpret, - "api_key": os.environ["OPENAI_API_KEY"], - "api_base_url": os.environ["API_BASE_URL"], + "api_key": api_key, + "api_base_url": api_base_url, "embed_model": embed_model, "embed_model_path": embed_model_path, "embed_engine": embed_engine, @@ -1091,7 +1116,9 @@ class ApiRequest: def delete_code_base(self, cb_name: str, no_remote_api: bool = None, embed_model: str="", embed_model_path: str="", embedding_device: str="", embed_engine: str="", - llm_model: str ="", temperature: float= 0.2 + llm_model: str ="", temperature: float= 0.2, + api_key: str=os.environ["OPENAI_API_KEY"], + api_base_url: str = os.environ["API_BASE_URL"], ): ''' 删除 code_base @@ -1102,8 +1129,8 @@ class ApiRequest: no_remote_api = self.no_remote_api data = { "cb_name": cb_name, - "api_key": os.environ["OPENAI_API_KEY"], - "api_base_url": os.environ["API_BASE_URL"], + "api_key": api_key, + "api_base_url": api_base_url, "embed_model": embed_model, "embed_model_path": embed_model_path, "embed_engine": embed_engine,
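
Note on the recurring example change above: LLMConfig is now constructed without the model_device argument, with the OpenAI credentials passed explicitly. A minimal sketch of the resulting pattern, assuming the import path coagent.llm_models.llm_config and that OPENAI_API_KEY / API_BASE_URL are exported:

    import os
    from coagent.llm_models.llm_config import LLMConfig  # import path assumed

    # model_device is no longer an LLMConfig argument; credentials are explicit.
    llm_config = LLMConfig(
        model_name="gpt-3.5-turbo",
        api_key=os.environ["OPENAI_API_KEY"],
        api_base_url=os.environ["API_BASE_URL"],
        temperature=0.3,
    )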
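start_sandbox_service now performs a get-or-create on the docker bridge network before launching the sandbox container. One way to express the same check against docker-py (a sketch; the network name follows start.py):

    import docker

    client = docker.from_env()
    # Reuse 'my_network' if it already exists, otherwise create a bridge network.
    names = [n.name for n in client.networks.list()]
    network = (client.networks.get("my_network") if "my_network" in names
               else client.networks.create("my_network", driver="bridge"))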
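The ApiRequest methods in examples/webui/utils.py now accept api_key / api_base_url keyword arguments that default to the environment variables, letting the webui pages forward per-model credentials from llm_model_dict. A hedged usage sketch, where api is the ApiRequest instance passed into dialogue_page and prompt / history are the current turn's inputs:

    from configs.model_config import LLM_MODEL, llm_model_dict

    r = api.chat_chat(  # api: the ApiRequest instance handed to dialogue_page()
        prompt, history, no_remote_api=True,
        llm_model=LLM_MODEL,
        # per-model credentials override the os.environ defaults
        api_key=llm_model_dict[LLM_MODEL]["api_key"],
        api_base_url=llm_model_dict[LLM_MODEL]["api_base_url"],
    )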