diff --git a/.gitignore b/.gitignore
index 4ff6fc8..0152705 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,6 +1,10 @@
 **/__pycache__
 knowledge_base
 logs
+embedding_models
 jupyter_work
 model_config.py
 server_config.py
+code_base
+.DS_Store
+.idea
diff --git a/Dockerfile b/Dockerfile
index 3940132..0ef02c7 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -1,10 +1,18 @@
-From python:3.9-bookworm
+FROM python:3.9.18-bookworm
 
 WORKDIR /home/user
 
-COPY ./docker_requirements.txt /home/user/docker_requirements.txt
+COPY ./requirements.txt /home/user/docker_requirements.txt
 COPY ./jupyter_start.sh /home/user/jupyter_start.sh
+
+RUN apt-get update
+RUN apt-get install -y iputils-ping telnetd net-tools vim tcpdump
+# RUN echo telnet stream tcp nowait telnetd /usr/sbin/tcpd /usr/sbin/in.telnetd /etc/inetd.conf
+# RUN service inetutils-inetd start
+# service inetutils-inetd status
+
+
 RUN pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple
 RUN pip install -r /home/user/docker_requirements.txt
+ + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. 
You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. 
In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
diff --git a/configs/model_config.py.example b/configs/model_config.py.example
index 4f1aef9..753bbc9 100644
--- a/configs/model_config.py.example
+++ b/configs/model_config.py.example
@@ -33,6 +33,10 @@ embedding_model_dict = {
     "bge-large-zh": "BAAI/bge-large-zh"
 }
 
+
+LOCAL_MODEL_DIR = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "embedding_models")
+embedding_model_dict = {k: f"/home/user/chatbot/embedding_models/{v}" if is_running_in_docker() else f"{LOCAL_MODEL_DIR}/{v}" for k, v in embedding_model_dict.items()}
+
 # Name of the selected embedding model
 EMBEDDING_MODEL = "text2vec-base"
 
@@ -97,6 +101,7 @@ llm_model_dict = {
 
 # LLM name
 LLM_MODEL = "gpt-3.5-turbo"
+USE_FASTCHAT = "gpt" not in LLM_MODEL  # whether to go through FastChat
 
 # LLM runtime device
 LLM_DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
@@ -112,6 +117,9 @@ SOURCE_PATH = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__fil
 # Default storage path for the knowledge base
 KB_ROOT_PATH = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "knowledge_base")
 
+# Default storage path for the code base
+CB_ROOT_PATH = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "code_base")
+
 # Storage path for nltk models
 NLTK_DATA_PATH = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "nltk_data")
 
@@ -153,7 +161,7 @@ DEFAULT_VS_TYPE = "faiss"
 
 CACHED_VS_NUM = 1
 
 # Length of a single text segment in the knowledge base
-CHUNK_SIZE = 250
+CHUNK_SIZE = 500
 
 # Overlap length between adjacent text segments in the knowledge base
 OVERLAP_SIZE = 50
 
@@ -169,6 +177,9 @@ SCORE_THRESHOLD = 1 if system_name in ["Linux", "Windows"] else 1100
 
 # Number of matched results for the search engine
 SEARCH_ENGINE_TOP_K = 5
 
+# Number of matched results for the code engine
+CODE_SEARCH_TOP_K = 1
+
 # Prompt template for QA over the local knowledge base
 PROMPT_TEMPLATE = """【指令】根据已知信息,简洁和专业的来回答问题。如果无法从中得到答案,请说 “根据已知信息无法回答该问题”,不允许在答案中添加编造成分,答案请使用中文。
@@ -176,6 +187,13 @@ PROMPT_TEMPLATE = """【指令】根据已知信息,简洁和专业的来回
 
 【问题】{question}"""
 
+# Prompt template for QA over the local code base
+CODE_PROMPT_TEMPLATE = """【指令】根据已知信息来回答问题。
+
+【已知信息】{context}
+
+【问题】{question}"""
+
 # Whether the API enables CORS; defaults to False, set to True to enable
 # is open cross domain
 OPEN_CROSS_DOMAIN = False
diff --git a/configs/server_config.py.example b/configs/server_config.py.example
index b7f539e..b355dcd 100644
--- a/configs/server_config.py.example
+++ b/configs/server_config.py.example
@@ -1,41 +1,62 @@
 from .model_config import LLM_MODEL, LLM_DEVICE
+import os
 
 # Whether the API enables CORS; defaults to False, set to True to enable
 # is open cross domain
 OPEN_CROSS_DOMAIN = False
-
+# Whether to start the services with Docker containers
+DOCKER_SERVICE = True
+# Whether to use the container sandbox
+SANDBOX_DO_REMOTE = True
+# Whether to go through the API service
+NO_REMOTE_API = True
 # Default bind host for each server
 DEFAULT_BIND_HOST = "127.0.0.1"
+#
+CONTRAINER_NAME = "devopsgpt_webui"
+IMAGE_NAME = "devopsgpt:py39"
+
 # webui.py server
 WEBUI_SERVER = {
     "host": DEFAULT_BIND_HOST,
     "port": 8501,
+    "docker_port": 8501
 }
 
 # api.py server
 API_SERVER = {
     "host": DEFAULT_BIND_HOST,
     "port": 7861,
+    "docker_port": 7861
+}
+
+# sdfile_api.py server
+SDFILE_API_SERVER = {
+    "host": DEFAULT_BIND_HOST,
+    "port": 7862,
+    "docker_port": 7862
 }
 
 # fastchat openai_api server
 FSCHAT_OPENAI_API = {
     "host": DEFAULT_BIND_HOST,
     "port": 8888,  # The api_base_url configured in model_config.llm_model_dict must match this.
+    "docker_port": 8888,  # The api_base_url configured in model_config.llm_model_dict must match this.
 }
 
 # sandbox api server
-CONTRAINER_NAME = "devopsgt_default"
-IMAGE_NAME = "devopsgpt:pypy38"
+SANDBOX_CONTRAINER_NAME = "devopsgpt_sandbox"
+SANDBOX_IMAGE_NAME = "devopsgpt:py39"
+SANDBOX_HOST = os.environ.get("SANDBOX_HOST") or DEFAULT_BIND_HOST  # "172.25.0.3"
 SANDBOX_SERVER = {
-    "host": DEFAULT_BIND_HOST,
+    "host": f"http://{SANDBOX_HOST}",
     "port": 5050,
-    "url": "http://localhost:5050",
-    "do_remote": True,
+    "docker_port": 5050,
+    "url": f"http://{SANDBOX_HOST}:5050",
+    "do_remote": SANDBOX_DO_REMOTE,
 }
-
 # fastchat model_worker server
 # These models must be correctly configured in model_config.llm_model_dict.
 # When launching startup.py, a model can be specified via `--model-worker --model-name xxxx`; if not specified, LLM_MODEL is used
diff --git a/configs/utils.py b/configs/utils.py
new file mode 100644
index 0000000..797591f
--- /dev/null
+++ b/configs/utils.py
@@ -0,0 +1,15 @@
+import os
+
+def is_running_in_docker():
+    """
+    Check whether the current code is running inside a Docker container
+    """
+    # check for the /.dockerenv file
+    if os.path.exists('/.dockerenv'):
+        return True
+
+    # check whether the cgroup of PID 1 contains a /docker/ path
+    if os.path.exists("/proc/1/cgroup"):
+        with open('/proc/1/cgroup', 'rt') as f:
+            return '/docker/' in f.read()
+    return False
\ No newline at end of file
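The comprehension added to `model_config.py.example` rewrites every embedding model id into a filesystem path, choosing the in-container location whenever `is_running_in_docker()` from the new `configs/utils.py` reports a Docker runtime. A minimal standalone sketch of that resolution; the checkout directory and the single model entry are illustrative, not part of the patch:

```python
import os

def is_running_in_docker() -> bool:
    # same probes as configs/utils.py: /.dockerenv, then /proc/1/cgroup
    if os.path.exists("/.dockerenv"):
        return True
    if os.path.exists("/proc/1/cgroup"):
        with open("/proc/1/cgroup", "rt") as f:
            return "/docker/" in f.read()
    return False

# one illustrative entry; the real dict maps several model names
embedding_model_dict = {"text2vec-base": "some-org/text2vec-base"}
LOCAL_MODEL_DIR = "/path/to/checkout/embedding_models"  # hypothetical local checkout

embedding_model_dict = {
    k: f"/home/user/chatbot/embedding_models/{v}" if is_running_in_docker()
    else f"{LOCAL_MODEL_DIR}/{v}"
    for k, v in embedding_model_dict.items()
}
# outside Docker this yields
# {'text2vec-base': '/path/to/checkout/embedding_models/some-org/text2vec-base'}
print(embedding_model_dict)
```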
diff --git a/dev_opsgpt/chat/__init__.py b/dev_opsgpt/chat/__init__.py
index c350a2a..737d802 100644
--- a/dev_opsgpt/chat/__init__.py
+++ b/dev_opsgpt/chat/__init__.py
@@ -2,7 +2,12 @@ from .base_chat import Chat
 from .knowledge_chat import KnowledgeChat
 from .llm_chat import LLMChat
 from .search_chat import SearchChat
+from .tool_chat import ToolChat
+from .data_chat import DataChat
+from .code_chat import CodeChat
+from .agent_chat import AgentChat
+
 
 __all__ = [
-    "Chat", "KnowledgeChat", "LLMChat", "SearchChat"
+    "Chat", "KnowledgeChat", "LLMChat", "SearchChat", "ToolChat", "DataChat", "CodeChat", "AgentChat"
 ]
diff --git a/dev_opsgpt/chat/agent_chat.py b/dev_opsgpt/chat/agent_chat.py
new file mode 100644
index 0000000..3fe68e2
--- /dev/null
+++ b/dev_opsgpt/chat/agent_chat.py
@@ -0,0 +1,169 @@
+from fastapi import Body, Request
+from fastapi.responses import StreamingResponse
+from typing import List
+from loguru import logger
+import importlib
+import copy
+import json
+
+from configs.model_config import (
+    llm_model_dict, LLM_MODEL, PROMPT_TEMPLATE,
+    VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD)
+
+from dev_opsgpt.tools import (
+    toLangchainTools,
+    TOOL_DICT, TOOL_SETS
+)
+
+from dev_opsgpt.connector.phase import BasePhase
+from dev_opsgpt.connector.agents import BaseAgent, ReactAgent
+from dev_opsgpt.connector.chains import BaseChain
+from dev_opsgpt.connector.connector_schema import (
+    Message,
+    load_phase_configs, load_chain_configs, load_role_configs
+    )
+from dev_opsgpt.connector.shcema import Memory
+
+from dev_opsgpt.chat.utils import History, wrap_done
+from dev_opsgpt.connector.configs import PHASE_CONFIGS, AGETN_CONFIGS, CHAIN_CONFIGS
+
+PHASE_MODULE = importlib.import_module("dev_opsgpt.connector.phase")
+
+
+
+class AgentChat:
+
+    def __init__(
+            self,
+            engine_name: str = "",
+            top_k: int = 1,
+            stream: bool = False,
+    ) -> None:
+        self.top_k = top_k
+        self.stream = stream
+
+    def chat(
+            self,
+            query: str = Body(..., description="用户输入", examples=["hello"]),
+            phase_name: str = Body(..., description="执行场景名称", examples=["chatPhase"]),
+            chain_name: str = Body(..., description="执行链的名称", examples=["chatChain"]),
+            history: List[History] = Body(
+                [], description="历史对话",
+                examples=[[{"role": "user", "content": "我们来玩成语接龙,我先来,生龙活虎"}]]
+            ),
+            doc_engine_name: str = Body(..., description="知识库名称", examples=["samples"]),
+            search_engine_name: str = Body(..., description="搜索引擎名称", examples=["duckduckgo"]),
+            code_engine_name: str = Body(..., description="代码引擎名称", examples=["samples"]),
+            top_k: int = Body(VECTOR_SEARCH_TOP_K, description="匹配向量数"),
+            score_threshold: float = Body(SCORE_THRESHOLD, description="知识库匹配相关度阈值,取值范围在0-1之间,SCORE越小,相关度越高,取到1相当于不筛选,建议设置在0.5左右", ge=0, le=1),
+            stream: bool = Body(False, description="流式输出"),
+            local_doc_url: bool = Body(False, description="知识文件返回本地路径(true)或URL(false)"),
+            choose_tools: List[str] = Body([], description="选择tool的集合"),
+            do_search: bool = Body(False, description="是否进行搜索"),
+            do_doc_retrieval: bool = Body(False, description="是否进行知识库检索"),
+            do_code_retrieval: bool = Body(False, description="是否执行代码检索"),
+            do_tool_retrieval: bool = Body(False, description="是否执行工具检索"),
+            custom_phase_configs: dict = Body({}, description="自定义phase配置"),
+            custom_chain_configs: dict = Body({}, description="自定义chain配置"),
+            custom_role_configs: dict = Body({}, description="自定义role配置"),
+            history_node_list: List = Body([], description="代码历史相关节点"),
+            isDetaild: bool = Body(False, description="是否输出完整的agent相关内容"),
+            **kargs
+            ) -> Message:
+
+        # update configs
+        phase_configs, chain_configs, agent_configs = self.update_configs(
+            custom_phase_configs, custom_chain_configs, custom_role_configs)
+        # choose tools
+        tools = toLangchainTools([TOOL_DICT[i] for i in choose_tools if i in TOOL_DICT])
+        input_message = Message(
+            role_content=query,
+            role_type="human",
+            role_name="user",
+            input_query=query,
+            phase_name=phase_name,
+            chain_name=chain_name,
+            do_search=do_search,
+            do_doc_retrieval=do_doc_retrieval,
+            do_code_retrieval=do_code_retrieval,
+            do_tool_retrieval=do_tool_retrieval,
+            doc_engine_name=doc_engine_name, search_engine_name=search_engine_name,
+            code_engine_name=code_engine_name,
+            score_threshold=score_threshold, top_k=top_k,
+            history_node_list=history_node_list,
+            tools=tools
+            )
+        # history memory management
+        history = Memory([
+            Message(role_name=i["role"], role_type=i["role"], role_content=i["content"])
+            for i in history
+            ])
+        # start to execute
+        phase_class = getattr(PHASE_MODULE, phase_configs[input_message.phase_name]["phase_type"])
+        phase = phase_class(input_message.phase_name,
+                            task=input_message.task,
+                            phase_config=phase_configs,
+                            chain_config=chain_configs,
+                            role_config=agent_configs,
+                            do_summary=phase_configs[input_message.phase_name]["do_summary"],
+                            do_code_retrieval=input_message.do_code_retrieval,
+                            do_doc_retrieval=input_message.do_doc_retrieval,
+                            do_search=input_message.do_search,
+                            )
+        output_message, local_memory = phase.step(input_message, history)
+        # logger.debug(f"local_memory: {local_memory.to_str_messages(content_key='step_content')}")
+
+        # return {
+        #     "answer": output_message.role_content,
+        #     "db_docs": output_message.db_docs,
+        #     "search_docs": output_message.search_docs,
+        #     "code_docs": output_message.code_docs,
+        #     "figures": output_message.figures
+        # }
+
+        def chat_iterator(message: Message, local_memory: Memory, isDetaild=False):
+            result = {
+                "answer": "",
+                "db_docs": [str(doc) for doc in message.db_docs],
+                "search_docs": [str(doc) for doc in message.search_docs],
+                "code_docs": [str(doc) for doc in message.code_docs],
+                "related_nodes": [doc.get_related_node() for idx, doc in enumerate(message.code_docs) if idx == 0],
+                "figures": message.figures
+            }
+
+            related_nodes, has_nodes = [], []
+            for nodes in result["related_nodes"]:
+                for node in nodes:
+                    if node not in has_nodes:
+                        related_nodes.append(node)
+            result["related_nodes"] = related_nodes
+
+            # logger.debug(f"{result['figures'].keys()}")
+            message_str = local_memory.to_str_messages(content_key='step_content') if isDetaild else message.role_content
+            if self.stream:
+                for token in message_str:
+                    result["answer"] = token
+                    yield json.dumps(result, ensure_ascii=False)
+            else:
+                for token in message_str:
+                    result["answer"] += token
+                yield json.dumps(result, ensure_ascii=False)
+
+        return StreamingResponse(chat_iterator(output_message, local_memory, isDetaild), media_type="text/event-stream")
+
+    def _chat(self, ):
+        pass
+
+    def update_configs(self, custom_phase_configs, custom_chain_configs, custom_role_configs):
+        '''update phase/chain/agent configs'''
+        phase_configs = copy.deepcopy(PHASE_CONFIGS)
+        phase_configs.update(custom_phase_configs)
+        chain_configs = copy.deepcopy(CHAIN_CONFIGS)
+        chain_configs.update(custom_chain_configs)
+        agent_configs = copy.deepcopy(AGETN_CONFIGS)
+        agent_configs.update(custom_role_configs)
+        # phase_configs = load_phase_configs(new_phase_configs)
+        # chian_configs = load_chain_configs(new_chain_configs)
+        # agent_configs = load_role_configs(new_agent_configs)
+        return phase_configs, chain_configs, agent_configs
\ No newline at end of file
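For orientation, a hypothetical client call against the `AgentChat.chat` signature above. The route path and port are assumptions (they depend on how the API server mounts this handler); the payload keys simply mirror the `Body(...)` parameters:

```python
import requests

# payload keys mirror AgentChat.chat's Body(...) parameters above
payload = {
    "query": "hello",
    "phase_name": "chatPhase",
    "chain_name": "chatChain",
    "history": [],
    "doc_engine_name": "samples",
    "search_engine_name": "duckduckgo",
    "code_engine_name": "samples",
    "top_k": 3,
    "score_threshold": 1.0,
    "stream": False,
    "choose_tools": [],
    "do_search": False,
    "do_doc_retrieval": False,
    "do_code_retrieval": False,
    "do_tool_retrieval": False,
    "history_node_list": [],
    "isDetaild": False,
}
# hypothetical route; the actual mount point is defined by the API server
resp = requests.post("http://127.0.0.1:7861/chat/agent_chat", json=payload, stream=True)
for line in resp.iter_lines():
    if line:
        print(line.decode("utf-8"))  # each chunk is a JSON blob with answer/docs/nodes
```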
diff --git a/dev_opsgpt/chat/base_chat.py b/dev_opsgpt/chat/base_chat.py
index 20bbad5..6764772 100644
--- a/dev_opsgpt/chat/base_chat.py
+++ b/dev_opsgpt/chat/base_chat.py
@@ -3,12 +3,11 @@ from fastapi.responses import StreamingResponse
 import asyncio, json
 from typing import List, AsyncIterable
 
-from langchain.chat_models import ChatOpenAI
 from langchain import LLMChain
 from langchain.callbacks import AsyncIteratorCallbackHandler
 from langchain.prompts.chat import ChatPromptTemplate
 
-
+from dev_opsgpt.llm_models import getChatModel
 from dev_opsgpt.chat.utils import History, wrap_done
 from configs.model_config import (llm_model_dict, LLM_MODEL, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD)
 from dev_opsgpt.utils import BaseResponse
@@ -16,30 +15,6 @@ from loguru import logger
 
 
-
-
-def getChatModel(callBack: AsyncIteratorCallbackHandler = None):
-    if callBack is None:
-        model = ChatOpenAI(
-            streaming=True,
-            verbose=True,
-            openai_api_key=llm_model_dict[LLM_MODEL]["api_key"],
-            openai_api_base=llm_model_dict[LLM_MODEL]["api_base_url"],
-            model_name=LLM_MODEL
-        )
-    else:
-        model = ChatOpenAI(
-            streaming=True,
-            verbose=True,
-            callBack=[callBack],
-            openai_api_key=llm_model_dict[LLM_MODEL]["api_key"],
-            openai_api_base=llm_model_dict[LLM_MODEL]["api_base_url"],
-            model_name=LLM_MODEL
-        )
-    return model
-
-
 class Chat:
     def __init__(
             self,
@@ -67,6 +42,7 @@ class Chat:
             stream: bool = Body(False, description="流式输出"),
             local_doc_url: bool = Body(False, description="知识文件返回本地路径(true)或URL(false)"),
             request: Request = None,
+            **kargs
             ):
         self.engine_name = engine_name if isinstance(engine_name, str) else engine_name.default
         self.top_k = top_k if isinstance(top_k, int) else top_k.default
@@ -74,18 +50,23 @@ class Chat:
         self.stream = stream if isinstance(stream, bool) else stream.default
         self.local_doc_url = local_doc_url if isinstance(local_doc_url, bool) else local_doc_url.default
         self.request = request
-        return self._chat(query, history)
+        return self._chat(query, history, **kargs)
 
-    def _chat(self, query: str, history: List[History]):
+    def _chat(self, query: str, history: List[History], **kargs):
         history = [History(**h) if isinstance(h, dict) else h for h in history]
+        ## check that the service dependency is ok
         service_status = self.check_service_status()
+        if service_status.code!=200: return service_status
 
         def chat_iterator(query: str, history: List[History]):
             model = getChatModel()
 
-            result ,content = self.create_task(query, history, model)
+            result, content = self.create_task(query, history, model, **kargs)
+            logger.info('result={}'.format(result))
+            logger.info('content={}'.format(content))
+
             if self.stream:
                 for token in content["text"]:
                     result["answer"] = token
@@ -144,7 +125,7 @@ class Chat:
         return StreamingResponse(chat_iterator(query, history),
                                  media_type="text/event-stream")
 
-    def create_task(self, query: str, history: List[History], model):
+    def create_task(self, query: str, history: List[History], model, **kargs):
         '''Build the LLM generation task'''
         chat_prompt = ChatPromptTemplate.from_messages(
             [i.to_msg_tuple() for i in history] + [("human", "{input}")]
         )
diff --git a/dev_opsgpt/chat/code_chat.py b/dev_opsgpt/chat/code_chat.py
new file mode 100644
index 0000000..fea8e7a
--- /dev/null
+++ b/dev_opsgpt/chat/code_chat.py
@@ -0,0 +1,143 @@
+# encoding: utf-8
+'''
+@author: 温进
+@file: code_chat.py
+@time: 2023/10/24 4:04 PM
+@desc:
+'''
+
+from fastapi import Request, Body
+import os, asyncio
+from urllib.parse import urlencode
+from typing import List
+from fastapi.responses import StreamingResponse
+
+from langchain import LLMChain
+from langchain.callbacks import AsyncIteratorCallbackHandler
+from langchain.prompts.chat import ChatPromptTemplate
+
+from configs.model_config import (
+    llm_model_dict, LLM_MODEL, PROMPT_TEMPLATE,
+    VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD, CODE_PROMPT_TEMPLATE)
+from dev_opsgpt.chat.utils import History, wrap_done
+from dev_opsgpt.utils import BaseResponse
+from .base_chat import Chat
+from dev_opsgpt.llm_models import getChatModel
+
+from dev_opsgpt.service.kb_api import search_docs, KBServiceFactory
+from dev_opsgpt.service.cb_api import search_code, cb_exists_api
+from loguru import logger
+import json
+
+
+class CodeChat(Chat):
+
+    def __init__(
+            self,
+            code_base_name: str = '',
+            code_limit: int = 1,
+            stream: bool = False,
+            request: Request = None,
+    ) -> None:
+        super().__init__(engine_name=code_base_name, stream=stream)
+        self.engine_name = code_base_name
+        self.code_limit = code_limit
+        self.request = request
+        self.history_node_list = []
+
+    def check_service_status(self) -> BaseResponse:
+        cb = cb_exists_api(self.engine_name)
+        if not cb:
+            return BaseResponse(code=404, msg=f"未找到代码库 {self.engine_name}")
+        return BaseResponse(code=200, msg=f"找到代码库 {self.engine_name}")
+
+    def _process(self, query: str, history: List[History], model):
+        '''process'''
+        codes_res = search_code(query=query, cb_name=self.engine_name, code_limit=self.code_limit,
+                                history_node_list=self.history_node_list)
+
+        codes = codes_res['related_code']
+        nodes = codes_res['related_node']
+
+        # update node names
+        node_names = [node[0] for node in nodes]
+        self.history_node_list.extend(node_names)
+        self.history_node_list = list(set(self.history_node_list))
+
+        context = "\n".join(codes)
+        source_nodes = []
+
+        for inum, node_info in enumerate(nodes[0:5]):
+            node_name, node_type, node_score = node_info[0], node_info[1], node_info[2]
+            source_nodes.append(f'{inum + 1}. 节点名为 {node_name}, 节点类型为 `{node_type}`, 节点得分为 `{node_score}`')
+
+        chat_prompt = ChatPromptTemplate.from_messages(
+            [i.to_msg_tuple() for i in history] + [("human", CODE_PROMPT_TEMPLATE)]
+        )
+        chain = LLMChain(prompt=chat_prompt, llm=model)
+        result = {"answer": "", "codes": source_nodes}
+        return chain, context, result
+
+    def chat(
+            self,
+            query: str = Body(..., description="用户输入", examples=["hello"]),
+            history: List[History] = Body(
+                [], description="历史对话",
+                examples=[[{"role": "user", "content": "我们来玩成语接龙,我先来,生龙活虎"}]]
+            ),
+            engine_name: str = Body(..., description="知识库名称", examples=["samples"]),
+            code_limit: int = Body(1, examples=['1']),
+            stream: bool = Body(False, description="流式输出"),
+            local_doc_url: bool = Body(False, description="知识文件返回本地路径(true)或URL(false)"),
+            request: Request = Body(None),
+            **kargs
+            ):
+        self.engine_name = engine_name if isinstance(engine_name, str) else engine_name.default
+        self.code_limit = code_limit
+        self.stream = stream if isinstance(stream, bool) else stream.default
+        self.local_doc_url = local_doc_url if isinstance(local_doc_url, bool) else local_doc_url.default
+        self.request = request
+        return self._chat(query, history, **kargs)
+
+    def _chat(self, query: str, history: List[History], **kargs):
+        history = [History(**h) if isinstance(h, dict) else h for h in history]
+
+        service_status = self.check_service_status()
+
+        if service_status.code != 200: return service_status
+
+        def chat_iterator(query: str, history: List[History]):
+            model = getChatModel()
+
+            result, content = self.create_task(query, history, model, **kargs)
+            # logger.info('result={}'.format(result))
+            # logger.info('content={}'.format(content))
+
+            if self.stream:
+                for token in content["text"]:
+                    result["answer"] = token
+                    yield json.dumps(result, ensure_ascii=False)
+            else:
+                for token in content["text"]:
+                    result["answer"] += token
+                yield json.dumps(result, ensure_ascii=False)
+
+        return StreamingResponse(chat_iterator(query, history),
+                                 media_type="text/event-stream")
+
+    def create_task(self, query: str, history: List[History], model, **kargs):
+        '''Build the LLM generation task'''
+        chain, context, result = self._process(query, history, model)
+        logger.info('chain={}'.format(chain))
+        try:
+            content = chain({"context": context, "question": query})
+        except Exception as e:
+            content = {"text": str(e)}
+        return result, content
+
+    def create_atask(self, query, history, model, callback: AsyncIteratorCallbackHandler):
+        chain, context, result = self._process(query, history, model)
+        task = asyncio.create_task(wrap_done(
+            chain.acall({"context": context, "question": query}), callback.done
+        ))
+        return task, result
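A toy illustration of the stream/non-stream token loops shared by `Chat._chat` and `CodeChat._chat` above: with `stream=True` every yielded JSON blob carries only the newest token, while the non-stream branch accumulates the full answer before yielding. Self-contained sketch with a stand-in `content` payload:

```python
import json

def chat_iterator(content: dict, stream: bool):
    result = {"answer": ""}
    if stream:
        for token in content["text"]:
            result["answer"] = token           # latest token only
            yield json.dumps(result, ensure_ascii=False)
    else:
        for token in content["text"]:
            result["answer"] += token          # accumulate the full answer
        yield json.dumps(result, ensure_ascii=False)

print(list(chat_iterator({"text": "abc"}, stream=False)))   # ['{"answer": "abc"}']
print(list(chat_iterator({"text": "abc"}, stream=True)))    # one blob per character
```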
diff --git a/dev_opsgpt/chat/data_chat.py b/dev_opsgpt/chat/data_chat.py
new file mode 100644
index 0000000..448aaf1
--- /dev/null
+++ b/dev_opsgpt/chat/data_chat.py
@@ -0,0 +1,229 @@
+import asyncio
+from typing import List
+
+from langchain import LLMChain
+from langchain.callbacks import AsyncIteratorCallbackHandler
+from langchain.prompts.chat import ChatPromptTemplate
+from langchain.agents import AgentType, initialize_agent
+
+from dev_opsgpt.tools import (
+    WeatherInfo, WorldTimeGetTimezoneByArea, Multiplier,
+    toLangchainTools, get_tool_schema
+    )
+from .utils import History, wrap_done
+from .base_chat import Chat
+from loguru import logger
+import json, re
+
+from dev_opsgpt.sandbox import PyCodeBox, CodeBoxResponse
+from configs.server_config import SANDBOX_SERVER
+
+def get_tool_agent(tools, llm):
+    return initialize_agent(
+        tools,
+        llm,
+        agent=AgentType.STRUCTURED_CHAT_ZERO_SHOT_REACT_DESCRIPTION,
+        verbose=True,
+    )
+
+PROMPT_TEMPLATE = """
+`角色`
+你是一个数据分析师,借鉴下述步骤,逐步完成数据分析任务的拆解和代码编写,尽可能帮助和准确地回答用户的问题。
+
+数据文件的存放路径为 `./`
+
+`数据分析流程`
+- 判断文件是否存在,并读取文件数据
+- 输出数据的基本信息,包括但不限于字段、文本、数据类型等
+- 输出数据的详细统计信息
+- 判断是否需要画图分析,选择合适的字段进行画图
+- 判断数据是否需要进行清洗
+- 判断数据或图片是否需要保存
+...
+- 结合数据统计分析结果和画图结果,进行总结和分析这份数据的价值
+
+`要求`
+- 每轮选择一个数据分析流程,需要综合考虑上轮和后续的可能影响
+- 数据分析流程只提供参考,不要拘泥于它的具体流程,要有自己的思考
+- 使用JSON blob来指定一个计划,通过提供task_status关键字(任务状态)、plan关键字(数据分析计划)和code关键字(可执行代码)。
+
+合法的 "task_status" 值: "finished" 表明当前用户问题已被准确回答 或者 "continued" 表明用户问题仍需要进一步分析
+
+`$JSON_BLOB如下所示`
+```
+{{
+    "task_status": $TASK_STATUS,
+    "plan": $PLAN,
+    "code": ```python\n$CODE```
+}}
+```
+
+`跟随如下示例`
+问题: 输入待回答的问题
+行动:$JSON_BLOB
+
+... (重复 行动 N 次,每次只生成一个行动)
+
+行动:
+```
+{{
+    "task_status": "finished",
+    "plan": 我已经可以回答用户问题了,最后回答用户的内容
+}}
+
+```
+
+`数据分析,开始`
+
+问题:{query}
+"""
+
+
+PROMPT_TEMPLATE_2 = """
+`角色`
+你是一个数据分析师,借鉴下述步骤,逐步完成数据分析任务的拆解和代码编写,尽可能帮助和准确地回答用户的问题。
+
+数据文件的存放路径为 `./`
+
+`数据分析流程`
+- 判断文件是否存在,并读取文件数据
+- 输出数据的基本信息,包括但不限于字段、文本、数据类型等
+- 输出数据的详细统计信息
+- 判断数据是否需要进行清洗
+- 判断是否需要画图分析,选择合适的字段进行画图
+- 判断清洗后数据或图片是否需要保存
+...
+- 结合数据统计分析结果和画图结果,进行总结和分析这份数据的价值
+
+`要求`
+- 每轮选择一个数据分析流程,需要综合考虑上轮和后续的可能影响
+- 数据分析流程只提供参考,不要拘泥于它的具体流程,要有自己的思考
+- 使用JSON blob来指定一个计划,通过提供task_status关键字(任务状态)、plan关键字(数据分析计划)和code关键字(可执行代码)。
+
+合法的 "task_status" 值: "finished" 表明当前用户问题已被准确回答 或者 "continued" 表明用户问题仍需要进一步分析
+
+`$JSON_BLOB如下所示`
+```
+{{
+    "task_status": $TASK_STATUS,
+    "plan": $PLAN,
+    "code": ```python\n$CODE```
+}}
+```
+
+`跟随如下示例`
+问题: 输入待回答的问题
+行动:$JSON_BLOB
+
+... (重复 行动 N 次,每次只生成一个行动)
+
+行动:
+```
+{{
+    "task_status": "finished",
+    "plan": 我已经可以回答用户问题了,最后回答用户的内容
+}}
+
+`数据分析,开始`
+
+问题:上传了一份employee_data.csv文件,请对它进行数据分析
+
+问题:{query}
+{history}
+
+"""
+
+class DataChat(Chat):
+
+    def __init__(
+            self,
+            engine_name: str = "",
+            top_k: int = 1,
+            stream: bool = False,
+    ) -> None:
+        super().__init__(engine_name, top_k, stream)
+        self.tool_prompt = """结合上下文信息,{tools} {input}"""
+        self.codebox = PyCodeBox(
+            remote_url=SANDBOX_SERVER["url"],
+            remote_ip=SANDBOX_SERVER["host"],  # "http://localhost",
+            remote_port=SANDBOX_SERVER["port"],
+            token="mytoken",
+            do_code_exe=True,
+            do_remote=SANDBOX_SERVER["do_remote"]
+        )
+
+    def create_task(self, query: str, history: List[History], model):
+        '''Build the LLM generation task'''
+        logger.debug("content:{}".format([i.to_msg_tuple() for i in history] + [("human", PROMPT_TEMPLATE)]))
+        chat_prompt = ChatPromptTemplate.from_messages(
+            [i.to_msg_tuple() for i in history] + [("human", PROMPT_TEMPLATE)]
+        )
+        pattern = re.compile(r"```(?:json)?\n(.*?)\n```", re.DOTALL)
+        internal_history = []
+        retry_nums = 2
+        while retry_nums >= 0:
+            if len(internal_history) == 0:
+                chat_prompt = ChatPromptTemplate.from_messages(
+                    [i.to_msg_tuple() for i in history] + [("human", PROMPT_TEMPLATE)]
+                )
+            else:
+                chat_prompt = ChatPromptTemplate.from_messages(
+                    [i.to_msg_tuple() for i in history] + [("human", PROMPT_TEMPLATE_2)]
+                )
+
+            chain = LLMChain(prompt=chat_prompt, llm=model)
+            content = chain({"query": query, "history": "\n".join(internal_history)})["text"]
+
+            # content = pattern.search(content)
+            # logger.info(f"content: {content}")
+            # content = json.loads(content.group(1).strip(), strict=False)
+
+            internal_history.append(f"{content}")
+            refer_info = "\n".join(internal_history)
+            logger.info(f"refer_info: {refer_info}")
+            try:
+                content = content.split("行动:")[-1].split("行动:")[-1]
+                content = json.loads(content)
+            except:
+                content = content.split("行动:")[-1].split("行动:")[-1]
+                content = eval(content)
+
+            if "finished" == content["task_status"]:
+                break
+            elif "code" in content:
+                # elif "```code" in content or "```python" in content:
+                # code_text = self.codebox.decode_code_from_text(content)
+                code_text = content["code"]
+                codebox_res = self.codebox.chat("```" + code_text + "```", do_code_exe=True)
+
+                if codebox_res is not None and codebox_res.code_exe_status != 200:
+                    logger.warning(f"{codebox_res.code_exe_response}")
+                    internal_history.append(f"观察: 根据这个报错信息 {codebox_res.code_exe_response},进行代码修复")
+
+                if codebox_res is not None and codebox_res.code_exe_status == 200:
+                    if codebox_res.code_exe_type == "image/png":
+                        base_text = f"```\n{code_text}\n```\n\n"
+                        img_html = "<img src='data:image/png;base64,{}'>".format(
+                            codebox_res.code_exe_response
+                        )
+                        internal_history.append(f"观察: {img_html}")
+                        # logger.info('```\n'+code_text+'\n```'+"\n\n"+'```\n'+codebox_res.code_exe_response+'\n```')
+                    else:
+                        internal_history.append(f"观察: {codebox_res.code_exe_response}")
+                        # logger.info('```\n'+code_text+'\n```'+"\n\n"+'```\n'+codebox_res.code_exe_response+'\n```')
+            else:
+                internal_history.append(f"观察:下一步应该怎么做?")
+            retry_nums -= 1
+
+        return {"answer": "", "docs": ""}, {"text": "\n".join(internal_history)}
+
+    def create_atask(self, query, history, model, callback: AsyncIteratorCallbackHandler):
+        chat_prompt = ChatPromptTemplate.from_messages(
+            [i.to_msg_tuple() for i in history] + [("human", PROMPT_TEMPLATE)]
+        )
+        chain = LLMChain(prompt=chat_prompt, llm=model)
+        task = asyncio.create_task(wrap_done(
+            chain.acall({"input": query}), callback.done
+        ))
+        return task, {"answer": "", "docs": ""}
\ No newline at end of file
diff --git a/dev_opsgpt/chat/search_chat.py b/dev_opsgpt/chat/search_chat.py
index badad4e..8346469 100644
--- a/dev_opsgpt/chat/search_chat.py
+++ b/dev_opsgpt/chat/search_chat.py
@@ -1,6 +1,5 @@
 from fastapi import Request
 import os, asyncio
-from urllib.parse import urlencode
 from typing import List, Optional, Dict
 
 from langchain import LLMChain
diff --git a/dev_opsgpt/chat/tool_chat.py b/dev_opsgpt/chat/tool_chat.py
new file mode 100644
index 0000000..ceb6a86
--- /dev/null
+++ b/dev_opsgpt/chat/tool_chat.py
@@ -0,0 +1,84 @@
+import asyncio
+from typing import List
+
+from langchain import LLMChain
+from langchain.callbacks import AsyncIteratorCallbackHandler
+from langchain.prompts.chat import ChatPromptTemplate
+from langchain.agents import AgentType, initialize_agent
+import langchain
+from langchain.schema import (
+    AgentAction
+    )
+
+
+# langchain.debug = True
+
+from dev_opsgpt.tools import (
+    TOOL_SETS, TOOL_DICT,
+    toLangchainTools, get_tool_schema
+    )
+from .utils import History, wrap_done
+from .base_chat import Chat
+from loguru import logger
+
+
+def get_tool_agent(tools, llm):
+    return initialize_agent(
+        tools,
+        llm,
+        agent=AgentType.STRUCTURED_CHAT_ZERO_SHOT_REACT_DESCRIPTION,
+        verbose=True,
+        return_intermediate_steps=True
+    )
+
+
+class ToolChat(Chat):
+
+    def __init__(
+            self,
+            engine_name: str = "",
+            top_k: int = 1,
+            stream: bool = False,
+    ) -> None:
+        super().__init__(engine_name, top_k, stream)
+        self.tool_prompt = """结合上下文信息,{tools} {input}"""
+        self.tools = toLangchainTools([TOOL_DICT[i] for i in TOOL_SETS if i in TOOL_DICT])
+
+    def create_task(self, query: str, history: List[History], model, **kargs):
+        '''Build the LLM generation task'''
+        logger.debug("content:{}".format([i.to_msg_tuple() for i in history] + [("human", "{query}")]))
+        # chat_prompt = ChatPromptTemplate.from_messages(
+        #     [i.to_msg_tuple() for i in history] + [("human", "{query}")]
+        # )
+        tools = kargs.get("tool_sets", [])
+        tools = toLangchainTools([TOOL_DICT[i] for i in tools if i in TOOL_DICT])
+        agent = get_tool_agent(tools if tools else self.tools, model)
+        content = agent(query)
+
+        logger.debug(f"content: {content}")
+
+        s = ""
+        if isinstance(content, str):
+            s = content
+        else:
+            for i in content["intermediate_steps"]:
+                for j in i:
+                    if isinstance(j, AgentAction):
+                        s += j.log + "\n"
+                    else:
+                        s += "Observation: " + str(j) + "\n"
+
+            s += "final answer:" + content["output"]
+        # chain = LLMChain(prompt=chat_prompt, llm=model)
+        # content = chain({"tools": tools, "input": query})
+        return {"answer": "", "docs": ""}, {"text": s}
+
+    def create_atask(self, query, history, model, callback: AsyncIteratorCallbackHandler):
+        chat_prompt = ChatPromptTemplate.from_messages(
+            [i.to_msg_tuple() for i in history] + [("human", self.tool_prompt)]
+        )
+        chain = LLMChain(prompt=chat_prompt, llm=model)
+        task = asyncio.create_task(wrap_done(
+            chain.acall({"input": query}), callback.done
+        ))
+        return task, {"answer": "", "docs": ""}
\ No newline at end of file
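The retry loop in `DataChat.create_task` above recovers the model's plan by splitting on the last `行动:` marker and parsing the JSON blob. A minimal sketch of that extraction step; the model reply below is fabricated:

```python
import json

# fabricated model reply in the format requested by PROMPT_TEMPLATE
reply = '思考过程...\n行动:\n{"task_status": "finished", "plan": "已经可以回答用户问题了"}'

# same two-step split as DataChat: full-width colon first, then ASCII colon
blob = reply.split("行动:")[-1].split("行动:")[-1]
content = json.loads(blob)

if content["task_status"] == "finished":
    print(content["plan"])        # terminal answer
elif "code" in content:
    print(content["code"])        # would be sent to the sandbox for execution
```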
kargs.get("tool_sets", []) + tools = toLangchainTools([TOOL_DICT[i] for i in tools if i in TOOL_DICT]) + agent = get_tool_agent(tools if tools else self.tools, model) + content = agent(query) + + logger.debug(f"content: {content}") + + s = "" + if isinstance(content, str): + s = content + else: + for i in content["intermediate_steps"]: + for j in i: + if isinstance(j, AgentAction): + s += j.log + "\n" + else: + s += "Observation: " + str(j) + "\n" + + s += "final answer:" + content["output"] + # chain = LLMChain(prompt=chat_prompt, llm=model) + # content = chain({"tools": tools, "input": query}) + return {"answer": "", "docs": ""}, {"text": s} + + def create_atask(self, query, history, model, callback: AsyncIteratorCallbackHandler): + chat_prompt = ChatPromptTemplate.from_messages( + [i.to_msg_tuple() for i in history] + [("human", self.tool_prompt)] + ) + chain = LLMChain(prompt=chat_prompt, llm=model) + task = asyncio.create_task(wrap_done( + chain.acall({"input": query}), callback.done + )) + return task, {"answer": "", "docs": ""} \ No newline at end of file diff --git a/dev_opsgpt/codebase_handler/__init__.py b/dev_opsgpt/codebase_handler/__init__.py new file mode 100644 index 0000000..ab840cf --- /dev/null +++ b/dev_opsgpt/codebase_handler/__init__.py @@ -0,0 +1,7 @@ +# encoding: utf-8 +''' +@author: 温进 +@file: __init__.py.py +@time: 2023/10/23 下午4:57 +@desc: +''' \ No newline at end of file diff --git a/dev_opsgpt/codebase_handler/codebase_handler.py b/dev_opsgpt/codebase_handler/codebase_handler.py new file mode 100644 index 0000000..82e9879 --- /dev/null +++ b/dev_opsgpt/codebase_handler/codebase_handler.py @@ -0,0 +1,139 @@ +# encoding: utf-8 +''' +@author: 温进 +@file: codebase_handler.py +@time: 2023/10/23 下午5:05 +@desc: +''' + +from loguru import logger +import time +import os + +from dev_opsgpt.codebase_handler.parser.java_paraser.java_crawler import JavaCrawler +from dev_opsgpt.codebase_handler.parser.java_paraser.java_preprocess import JavaPreprocessor +from dev_opsgpt.codebase_handler.parser.java_paraser.java_dedup import JavaDedup +from dev_opsgpt.codebase_handler.parser.java_paraser.java_parser import JavaParser +from dev_opsgpt.codebase_handler.tagger.tagger import Tagger +from dev_opsgpt.codebase_handler.tagger.tuple_generation import node_edge_update + +from dev_opsgpt.codebase_handler.networkx_handler.networkx_handler import NetworkxHandler +from dev_opsgpt.codebase_handler.codedb_handler.local_codedb_handler import LocalCodeDBHandler + + +class CodeBaseHandler(): + def __init__(self, code_name: str, code_path: str = '', cb_root_path: str = '', history_node_list: list = []): + self.nh = None + self.lcdh = None + self.code_name = code_name + self.code_path = code_path + + self.codebase_path = cb_root_path + os.sep + code_name + self.graph_path = self.codebase_path + os.sep + 'graph.pk' + self.codedb_path = self.codebase_path + os.sep + 'codedb.pk' + + self.tagger = Tagger() + self.history_node_list = history_node_list + + def import_code(self, do_save: bool=False, do_load: bool=False) -> bool: + ''' + import code to codeBase + @param code_path: + @param do_save: + @param do_load: + @return: True as success; False as failure + ''' + if do_load: + logger.info('start load from codebase_path') + load_graph_path = self.graph_path + load_codedb_path = self.codedb_path + + st = time.time() + self.nh = NetworkxHandler(graph_path=load_graph_path) + logger.info('generate graph success, rt={}'.format(time.time() - st)) + + st = time.time() + self.lcdh = 
LocalCodeDBHandler(db_path=load_codedb_path) + logger.info('generate codedb success, rt={}'.format(time.time() - st)) + else: + logger.info('start load from code_path') + st = time.time() + java_code_dict = JavaCrawler.local_java_file_crawler(self.code_path) + logger.info('crawl success, rt={}'.format(time.time() - st)) + + jp = JavaPreprocessor() + java_code_dict = jp.preprocess(java_code_dict) + + jd = JavaDedup() + java_code_dict = jd.dedup(java_code_dict) + + st = time.time() + j_parser = JavaParser() + parse_res = j_parser.parse(java_code_dict) + logger.info('parse success, rt={}'.format(time.time() - st)) + + st = time.time() + tagged_code = self.tagger.generate_tag(parse_res) + node_list, edge_list = node_edge_update(parse_res.values()) + logger.info('get node and edge success, rt={}'.format(time.time() - st)) + + st = time.time() + self.nh = NetworkxHandler(node_list=node_list, edge_list=edge_list) + logger.info('generate graph success, rt={}'.format(time.time() - st)) + + st = time.time() + self.lcdh = LocalCodeDBHandler(tagged_code) + logger.info('CodeDB load success, rt={}'.format(time.time() - st)) + + if do_save: + save_graph_path = self.graph_path + save_codedb_path = self.codedb_path + self.nh.save_graph(save_graph_path) + self.lcdh.save_db(save_codedb_path) + + def search_code(self, query: str, code_limit: int, history_node_list: list = []): + ''' + search code related to query + @param self: + @param query: + @return: + ''' + # get query tag + query_tag_list = self.tagger.generate_tag_query(query) + + related_node_score_list = self.nh.search_node_with_score(query_tag_list=query_tag_list, + history_node_list=history_node_list) + + score_dict = { + i[0]: i[1] + for i in related_node_score_list + } + related_node = [i[0] for i in related_node_score_list] + related_score = [i[1] for i in related_node_score_list] + + related_code, code_related_node = self.lcdh.search_by_multi_tag(related_node, lim=code_limit) + + related_node = [ + (node, self.nh.get_node_type(node), score_dict[node]) + for node in code_related_node + ] + + related_node.sort(key=lambda x: x[2], reverse=True) + + logger.info('related_node={}'.format(related_node)) + logger.info('related_code={}'.format(related_code)) + logger.info('num of code={}'.format(len(related_code))) + return related_code, related_node + + def refresh_history(self): + self.history_node_list = [] + + + + + + + + + + diff --git a/dev_opsgpt/codebase_handler/codedb_handler/__init__.py b/dev_opsgpt/codebase_handler/codedb_handler/__init__.py new file mode 100644 index 0000000..24ed759 --- /dev/null +++ b/dev_opsgpt/codebase_handler/codedb_handler/__init__.py @@ -0,0 +1,7 @@ +# encoding: utf-8 +''' +@author: 温进 +@file: __init__.py.py +@time: 2023/10/23 下午5:04 +@desc: +''' \ No newline at end of file diff --git a/dev_opsgpt/codebase_handler/codedb_handler/local_codedb_handler.py b/dev_opsgpt/codebase_handler/codedb_handler/local_codedb_handler.py new file mode 100644 index 0000000..e8fb54a --- /dev/null +++ b/dev_opsgpt/codebase_handler/codedb_handler/local_codedb_handler.py @@ -0,0 +1,55 @@ +# encoding: utf-8 +''' +@author: 温进 +@file: local_codedb_handler.py +@time: 2023/10/23 下午5:05 +@desc: +''' +import pickle + + +class LocalCodeDBHandler: + def __init__(self, tagged_code: dict = {}, db_path: str = ''): + if db_path: + with open(db_path, 'rb') as f: + self.data = pickle.load(f) + else: + self.data = {} + for code, tag in tagged_code.items(): + self.data[code] = str(tag) + + def search_by_single_tag(self, tag, lim): + res = list() + for k, v 
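A hypothetical end-to-end use of `CodeBaseHandler` as defined above; the repository path, codebase root, and query are illustrative placeholders:

```python
from dev_opsgpt.codebase_handler.codebase_handler import CodeBaseHandler

cbh = CodeBaseHandler(
    code_name="samples",
    code_path="/path/to/java/project",   # crawled recursively for *.java files
    cb_root_path="/path/to/code_base",   # e.g. CB_ROOT_PATH from model_config
)
cbh.import_code(do_save=True)            # crawl, parse, tag, then persist graph + codedb
related_code, related_node = cbh.search_code("com.CheckHolder 有哪些函数", code_limit=1)
```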
diff --git a/dev_opsgpt/codebase_handler/networkx_handler/__init__.py b/dev_opsgpt/codebase_handler/networkx_handler/__init__.py
new file mode 100644
index 0000000..05f385f
--- /dev/null
+++ b/dev_opsgpt/codebase_handler/networkx_handler/__init__.py
@@ -0,0 +1,7 @@
+# encoding: utf-8
+'''
+@author: 温进
+@file: __init__.py
+@time: 2023/10/23 5:00 PM
+@desc:
+'''
\ No newline at end of file
diff --git a/dev_opsgpt/codebase_handler/networkx_handler/networkx_handler.py b/dev_opsgpt/codebase_handler/networkx_handler/networkx_handler.py
new file mode 100644
index 0000000..ffc16be
--- /dev/null
+++ b/dev_opsgpt/codebase_handler/networkx_handler/networkx_handler.py
@@ -0,0 +1,129 @@
+# encoding: utf-8
+'''
+@author: 温进
+@file: networkx_handler.py
+@time: 2023/10/23 5:02 PM
+@desc:
+'''
+
+import networkx as nx
+from loguru import logger
+import matplotlib.pyplot as plt
+import pickle
+from collections import defaultdict
+import json
+
+QUERY_SCORE = 10
+HISTORY_SCORE = 5
+RATIO = 0.5
+
+
+class NetworkxHandler:
+    def __init__(self, graph_path: str = '', node_list: list = [], edge_list: list = []):
+        if graph_path:
+            self.graph_path = graph_path
+            with open(graph_path, 'r') as f:
+                self.G = nx.node_link_graph(json.load(f))
+        else:
+            self.G = nx.DiGraph()
+            self.populate_graph(node_list, edge_list)
+        logger.debug(
+            'number of nodes={}, number of edges={}'.format(self.G.number_of_nodes(), self.G.number_of_edges()))
+
+        self.query_score = QUERY_SCORE
+        self.history_score = HISTORY_SCORE
+        self.ratio = RATIO
+
+    def populate_graph(self, node_list, edge_list):
+        '''
+        populate graph with node_list and edge_list
+        '''
+        self.G.add_nodes_from(node_list)
+        for edge in edge_list:
+            self.G.add_edge(edge[0], edge[-1], relation=edge[1])
+
+    def draw_graph(self, save_path: str):
+        '''
+        draw and save to save_path
+        '''
+        sub = plt.subplot(111)
+        nx.draw(self.G, with_labels=True)
+
+        plt.savefig(save_path)
+
+    def search_node(self, query_tag_list: list, history_node_list: list = []):
+        '''
+        search nodes by tag_list, searching the neighbors of history nodes first
+        > query_tag_list: tags from the query
+        > history_node_list
+        '''
+        node_list = set()
+
+        # search from history_tag_list first, then all nodes
+        for tag in query_tag_list:
+            add = False
+            for history_node in history_node_list:
+                connect_node_list = list(self.G.adj[history_node])
+                connect_node_list.insert(0, history_node)
+                for connect_node in connect_node_list:
+                    node_name_lim = len(connect_node) if '_' not in connect_node else connect_node.index('_')
+                    node_name = connect_node[0:node_name_lim]
+                    if tag.lower() in node_name.lower():
+                        node_list.add(connect_node)
+                        add = True
+            if not add:
+                for node in self.G.nodes():
+                    if tag.lower() in node.lower():
+                        node_list.add(node)
+        return node_list
+
+    def search_node_with_score(self, query_tag_list: list, history_node_list: list = []):
+        '''
+        search nodes by tag_list, searching the neighbors of history nodes first
+        > query_tag_list: tags from the query
+        > history_node_list
+        '''
+        logger.info('query_tag_list={}, history_node_list={}'.format(query_tag_list, history_node_list))
+        node_dict = defaultdict(lambda: 0)
+
+        # loop over query_tag_list and add node scores
+        for tag in query_tag_list:
+            for node in self.G.nodes:
+                if tag.lower() in node.lower():
+                    node_dict[node] += self.query_score
+
+        # loop over history_node and add node scores
+        for node in history_node_list:
+            node_dict[node] += self.history_score
+
+        logger.info('temp_res={}'.format(node_dict))
+
+        # adjacency score broadcast
+        for node in node_dict:
+            adj_node_list = self.G.adj[node]
+            for adj_node in adj_node_list:
+                node_dict[node] += node_dict.get(adj_node, 0) * self.ratio
+
+        # sort
+        node_list = [(node, node_score) for node, node_score in node_dict.items()]
+        node_list.sort(key=lambda x: x[1], reverse=True)
+        return node_list
+
+    def save_graph(self, save_path: str):
+        to_save = nx.node_link_data(self.G)
+        with open(save_path, 'w') as f:
+            json.dump(to_save, f)
+
+    def __len__(self):
+        return self.G.number_of_nodes()
+
+    def get_node_type(self, node_name):
+        node_type = self.G.nodes[node_name]['type']
+        return node_type
+
+    def refresh_graph(self, ):
+        with open(self.graph_path, 'r') as f:
+            self.G = nx.node_link_graph(json.load(f))
+
+
+
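A standalone toy of the scoring scheme in `search_node_with_score` above: query-tag hits add `QUERY_SCORE`, history nodes add `HISTORY_SCORE`, and each node then absorbs `RATIO` of its neighbors' scores before sorting. The tiny graph is fabricated:

```python
from collections import defaultdict

QUERY_SCORE, HISTORY_SCORE, RATIO = 10, 5, 0.5
nodes = ["com.A", "com.A.foo", "com.B"]
adj = {"com.A": ["com.A.foo"], "com.A.foo": ["com.A"], "com.B": []}

node_dict = defaultdict(float)
for tag in ["a"]:                           # tags extracted from the query
    for node in nodes:
        if tag.lower() in node.lower():
            node_dict[node] += QUERY_SCORE
for node in ["com.B"]:                      # nodes from earlier turns
    node_dict[node] += HISTORY_SCORE
for node in list(node_dict):                # adjacency score broadcast
    for neighbor in adj[node]:
        node_dict[node] += node_dict.get(neighbor, 0) * RATIO

print(sorted(node_dict.items(), key=lambda x: x[1], reverse=True))
# e.g. [('com.A.foo', 17.5), ('com.A', 15.0), ('com.B', 5.0)]
```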
diff --git a/dev_opsgpt/codebase_handler/parser/__init__.py b/dev_opsgpt/codebase_handler/parser/__init__.py
new file mode 100644
index 0000000..05f385f
--- /dev/null
+++ b/dev_opsgpt/codebase_handler/parser/__init__.py
@@ -0,0 +1,7 @@
+# encoding: utf-8
+'''
+@author: 温进
+@file: __init__.py
+@time: 2023/10/23 5:00 PM
+@desc:
+'''
\ No newline at end of file
diff --git a/dev_opsgpt/codebase_handler/parser/java_paraser/__init__.py b/dev_opsgpt/codebase_handler/parser/java_paraser/__init__.py
new file mode 100644
index 0000000..ee1c9a5
--- /dev/null
+++ b/dev_opsgpt/codebase_handler/parser/java_paraser/__init__.py
@@ -0,0 +1,7 @@
+# encoding: utf-8
+'''
+@author: 温进
+@file: __init__.py
+@time: 2023/10/23 5:01 PM
+@desc:
+'''
\ No newline at end of file
diff --git a/dev_opsgpt/codebase_handler/parser/java_paraser/java_crawler.py b/dev_opsgpt/codebase_handler/parser/java_paraser/java_crawler.py
new file mode 100644
index 0000000..0681033
--- /dev/null
+++ b/dev_opsgpt/codebase_handler/parser/java_paraser/java_crawler.py
@@ -0,0 +1,32 @@
+# encoding: utf-8
+'''
+@author: 温进
+@file: java_crawler.py
+@time: 2023/10/23 5:02 PM
+@desc:
+'''
+
+import os
+import glob
+from loguru import logger
+
+
+class JavaCrawler:
+    @staticmethod
+    def local_java_file_crawler(path: str):
+        '''
+        read local java files in path
+        > path: path to crawl, must be an absolute path like A/B/C
+        < dict of java code strings
+        '''
+        java_file_list = glob.glob('{path}{sep}**{sep}*.java'.format(path=path, sep=os.path.sep), recursive=True)
+        java_code_dict = {}
+
+        logger.debug('number of files={}'.format(len(java_file_list)))
+        # logger.debug(java_file_list)
+
+        for java_file in java_file_list:
+            with open(java_file) as f:
+                java_code = ''.join(f.readlines())
+                java_code_dict[java_file] = java_code
+        return java_code_dict
\ No newline at end of file
diff --git a/dev_opsgpt/codebase_handler/parser/java_paraser/java_dedup.py b/dev_opsgpt/codebase_handler/parser/java_paraser/java_dedup.py
new file mode 100644
index 0000000..c8e88b2
--- /dev/null
+++ b/dev_opsgpt/codebase_handler/parser/java_paraser/java_dedup.py
@@ -0,0 +1,15 @@
+# encoding: utf-8
+'''
+@author: 温进
+@file: java_dedup.py
+@time: 2023/10/23 5:02 PM
+@desc:
+'''
+
+
+class JavaDedup:
+    def __init__(self):
+        pass
+
+    def dedup(self, java_code_dict):
+        return java_code_dict
\ No newline at end of file
diff --git a/dev_opsgpt/codebase_handler/parser/java_paraser/java_parser.py b/dev_opsgpt/codebase_handler/parser/java_paraser/java_parser.py
new file mode 100644
index 0000000..23826d5
--- /dev/null
+++ b/dev_opsgpt/codebase_handler/parser/java_paraser/java_parser.py
@@ -0,0 +1,107 @@
+# encoding: utf-8
+'''
+@author: 温进
+@file: java_parser.py
+@time: 2023/10/23 5:03 PM
+@desc:
+'''
+import json
+import javalang
+import glob
+import os
+from loguru import logger
+
+
+class JavaParser:
+    def __init__(self):
+        pass
+
+    def parse(self, java_code_list):
+        '''
+        parse java code and extract entities
+        '''
+        tree_dict = self.preparse(java_code_list)
+        res = self.multi_java_code_parse(tree_dict)
+
+        return res
+
+    def preparse(self, java_code_dict):
+        '''
+        preparse with javalang
+        < dict of java_code and tree
+        '''
+        tree_dict = {}
+        for fp, java_code in java_code_dict.items():
+            try:
+                tree = javalang.parse.parse(java_code)
+            except Exception as e:
+                continue
+
+            if tree.package is not None:
+                tree_dict[java_code] = tree
+        logger.info('successfully parsed {} files'.format(len(tree_dict)))
+        return tree_dict
+
+    def single_java_code_parse(self, tree):
+        '''
+        parse a single code file
+        > tree: javalang parse result
+        < {pac_name: '', class_name_list: [], func_name_dict: {}, import_pac_name_list: []}
+        '''
+        import_pac_name_list = []
+
+        # get imports
+        import_list = tree.imports
+
+        for import_pac in import_list:
+            import_pac_name = import_pac.path
+            import_pac_name_list.append(import_pac_name)
+
+        pac_name = tree.package.name
+        class_name_list = []
+        func_name_dict = {}
+
+        for node in tree.types:
+            if type(node) in (javalang.tree.ClassDeclaration, javalang.tree.InterfaceDeclaration):
+                class_name = pac_name + '.' + node.name
+                class_name_list.append(class_name)
+
+                for node_inner in node.body:
+                    if type(node_inner) is javalang.tree.MethodDeclaration:
+                        func_name = class_name + '.' + node_inner.name
+
+                        # append parameter names to func_name
+                        params_list = node_inner.parameters
+
+                        for params in params_list:
+                            params_name = params.type.name
+                            func_name = func_name + '_' + params_name
+
+                        if class_name not in func_name_dict:
+                            func_name_dict[class_name] = []
+
+                        func_name_dict[class_name].append(func_name)
+
+        res = {
+            'pac_name': pac_name,
+            'class_name_list': class_name_list,
+            'func_name_dict': func_name_dict,
+            'import_pac_name_list': import_pac_name_list
+        }
+        return res
+
+    def multi_java_code_parse(self, tree_dict):
+        '''
+        parse multiple java code files
+        > tree_list
+        < parse_result_dict
+        '''
+        res_dict = {}
+        for java_code, tree in tree_dict.items():
+            try:
+                res_dict[java_code] = self.single_java_code_parse(tree)
+            except Exception as e:
+                logger.debug(java_code)
+                raise ImportError
+
+        return res_dict
\ No newline at end of file
diff --git a/dev_opsgpt/codebase_handler/parser/java_paraser/java_preprocess.py b/dev_opsgpt/codebase_handler/parser/java_paraser/java_preprocess.py
new file mode 100644
index 0000000..c71729f
--- /dev/null
+++ b/dev_opsgpt/codebase_handler/parser/java_paraser/java_preprocess.py
@@ -0,0 +1,14 @@
+# encoding: utf-8
+'''
+@author: 温进
+@file: java_preprocess.py
+@time: 2023/10/23 5:04 PM
+@desc:
+'''
+
+class JavaPreprocessor:
+    def __init__(self):
+        pass
+
+    def preprocess(self, java_code_dict):
+        return java_code_dict
\ No newline at end of file
diff --git a/dev_opsgpt/codebase_handler/tagger/__init__.py b/dev_opsgpt/codebase_handler/tagger/__init__.py
new file mode 100644
index 0000000..05f385f
--- /dev/null
+++ b/dev_opsgpt/codebase_handler/tagger/__init__.py
@@ -0,0 +1,7 @@
+# encoding: utf-8
+'''
+@author: 温进
+@file: __init__.py
+@time: 2023/10/23 5:00 PM
+@desc:
+'''
\ No newline at end of file
diff --git a/dev_opsgpt/codebase_handler/tagger/tagger.py b/dev_opsgpt/codebase_handler/tagger/tagger.py
new file mode 100644
index 0000000..943c013
--- /dev/null
+++ b/dev_opsgpt/codebase_handler/tagger/tagger.py
@@ -0,0 +1,48 @@
+# encoding: utf-8
+'''
+@author: 温进
+@file: tagger.py
+@time: 2023/10/23 5:01 PM
+@desc:
+'''
+import re
+from loguru import logger
+
+
+class Tagger:
+    def __init__(self):
+        pass
+
+    def generate_tag(self, parse_res_dict: dict):
+        '''
+        generate tags from parse_res
+        '''
+        res = {}
+        for java_code, parse_res in parse_res_dict.items():
+            tag = {}
+            tag['pac_name'] = parse_res.get('pac_name')
+            tag['class_name'] = set(parse_res.get('class_name_list'))
+            tag['func_name'] = set()
+
+            for _, func_name_list in parse_res.get('func_name_dict', {}).items():
+                tag['func_name'].update(func_name_list)
+
+            res[java_code] = tag
+        return res
+
+    def generate_tag_query(self, query):
+        '''
+        generate tags from query
+        '''
+        # simple extraction of English tokens
+        tag_list = re.findall(r'[a-zA-Z\_\.]+', query)
+        tag_list = list(set(tag_list))
+        return tag_list
+
+
+if __name__ == '__main__':
+    tagger = Tagger()
+    logger.debug(tagger.generate_tag_query('com.CheckHolder 有哪些函数'))
+
+
+
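A small check of what `JavaParser.single_java_code_parse` walks over, using the same `javalang` calls as the parser above; the Java snippet is illustrative:

```python
import javalang

java_code = """
package com.example;

import java.util.List;

public class CheckHolder {
    public int check(List<Integer> items) { return items.size(); }
}
"""

tree = javalang.parse.parse(java_code)
print(tree.package.name)                       # com.example
for node in tree.types:
    if isinstance(node, javalang.tree.ClassDeclaration):
        print(node.name)                       # CheckHolder
        for member in node.body:
            if isinstance(member, javalang.tree.MethodDeclaration):
                # parameter type names get appended to the function tag
                print(member.name, [p.type.name for p in member.parameters])
```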
+    '''
+    # guard the mutable defaults so repeated calls do not share state
+    node_list = node_list if node_list is not None else []
+    edge_list = edge_list if edge_list is not None else []
+
+    node_dict = {i: j for i, j in node_list}
+
+    for single_parse_res in parse_res_list:
+        pac_name = single_parse_res['pac_name']
+
+        node_dict[pac_name] = {'type': 'package'}
+
+        # class_name
+        for class_name in single_parse_res['class_name_list']:
+            node_dict[class_name] = {'type': 'class'}
+            edge_list.append((pac_name, 'contain', class_name))
+            edge_list.append((class_name, 'inside', pac_name))
+
+        # func_name
+        for class_name, func_name_list in single_parse_res['func_name_dict'].items():
+            node_list.append(class_name)
+            for func_name in func_name_list:
+                node_dict[func_name] = {'type': 'func'}
+                edge_list.append((class_name, 'contain', func_name))
+                edge_list.append((func_name, 'inside', class_name))
+
+        # depend
+        for depend_pac_name in single_parse_res['import_pac_name_list']:
+            if depend_pac_name.endswith('*'):
+                depend_pac_name = depend_pac_name[0:-2]
+
+            if depend_pac_name in node_dict:
+                continue
+            else:
+                node_dict[depend_pac_name] = {'type': 'unknown'}
+                edge_list.append((pac_name, 'depend', depend_pac_name))
+                edge_list.append((depend_pac_name, 'beDepended', pac_name))
+
+    node_list = [(i, j) for i, j in node_dict.items()]
+
+    return node_list, edge_list
\ No newline at end of file
diff --git a/dev_opsgpt/connector/__init__.py b/dev_opsgpt/connector/__init__.py
new file mode 100644
index 0000000..70a0b72
--- /dev/null
+++ b/dev_opsgpt/connector/__init__.py
@@ -0,0 +1,9 @@
+from .configs import PHASE_CONFIGS
+
+
+
+PHASE_LIST = list(PHASE_CONFIGS.keys())
+
+__all__ = [
+    "PHASE_CONFIGS"
+]
\ No newline at end of file
diff --git a/dev_opsgpt/connector/agents/__init__.py b/dev_opsgpt/connector/agents/__init__.py
new file mode 100644
index 0000000..b4aed3f
--- /dev/null
+++ b/dev_opsgpt/connector/agents/__init__.py
@@ -0,0 +1,6 @@
+from .base_agent import BaseAgent
+from .react_agent import ReactAgent
+
+__all__ = [
+    "BaseAgent", "ReactAgent"
+]
\ No newline at end of file
diff --git a/dev_opsgpt/connector/agents/base_agent.py b/dev_opsgpt/connector/agents/base_agent.py
new file mode 100644
index 0000000..1def69c
--- /dev/null
+++ b/dev_opsgpt/connector/agents/base_agent.py
@@ -0,0 +1,427 @@
+from pydantic import BaseModel
+from typing import List, Union
+import re
+import copy
+import json
+import traceback
+import uuid
+from loguru import logger
+
+from dev_opsgpt.connector.shcema.memory import Memory
+from dev_opsgpt.connector.connector_schema import (
+    Task, Role, Message, ActionStatus, Doc, CodeDoc
+    )
+from configs.server_config import SANDBOX_SERVER
+from dev_opsgpt.sandbox import PyCodeBox, CodeBoxResponse
+from dev_opsgpt.tools import DDGSTool, DocRetrieval, CodeRetrieval
+from dev_opsgpt.connector.configs.agent_config import REACT_PROMPT_INPUT
+
+from dev_opsgpt.llm_models import getChatModel
+
+
+class BaseAgent:
+
+    def __init__(
+            self,
+            role: Role,
+            task: Task = None,
+            memory: Memory = None,
+            chat_turn: int = 1,
+            do_search: bool = False,
+            do_doc_retrieval: bool = False,
+            do_tool_retrieval: bool = False,
+            temperature: float = 0.2,
+            stop: Union[List[str], str] = None,
+            do_filter: bool = True,
+            do_use_self_memory: bool = True,
+            # docs_prompt: str,
+            # prompt_manager: PromptManager
+            ):
+
+        self.task = task
+        self.role = role
+        self.llm = self.create_llm_engine(temperature, stop)
+        self.memory = self.init_history(memory)
+        self.chat_turn = chat_turn
+        self.do_search = do_search
+        self.do_doc_retrieval = do_doc_retrieval
+        self.do_tool_retrieval = do_tool_retrieval
+        self.codebox = PyCodeBox(
+            remote_url=SANDBOX_SERVER["url"],
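+            # the sandbox endpoint is read from configs.server_config.SANDBOX_SERVER,
+            # which (as used here) carries "url", "host", "port" and "do_remote" keys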
+            remote_ip=SANDBOX_SERVER["host"],
+            remote_port=SANDBOX_SERVER["port"],
+            token="mytoken",
+            do_code_exe=True,
+            do_remote=SANDBOX_SERVER["do_remote"],
+            do_check_net=False
+            )
+        self.do_filter = do_filter
+        self.do_use_self_memory = do_use_self_memory
+        # self.docs_prompt = docs_prompt
+        # self.prompt_manager = None
+
+    def run(self, query: Message, history: Memory = None, background: Memory = None) -> Message:
+        '''llm inference'''
+        # insert query into memory
+        query_c = copy.deepcopy(query)
+
+        self_memory = self.memory if self.do_use_self_memory else None
+        prompt = self.create_prompt(query_c, self_memory, history, background)
+        content = self.llm.predict(prompt)
+        logger.debug(f"{self.role.role_name} prompt: {prompt}")
+        # logger.debug(f"{self.role.role_name} content: {content}")
+
+        output_message = Message(
+            role_name=self.role.role_name,
+            role_type="ai",  # self.role.role_type,
+            role_content=content,
+            role_contents=[content],
+            input_query=query_c.input_query,
+            tools=query_c.tools
+            )
+
+        output_message = self.parser(output_message)
+        if self.do_filter:
+            output_message = self.filter(output_message)
+
+        # record the query and the agent's own answer
+        self.append_history(query_c)
+        self.append_history(output_message)
+        logger.info(f"{self.role.role_name} step_run: {output_message.role_content}")
+        return output_message
+
+    def create_prompt(
+            self, query: Message, memory: Memory = None, history: Memory = None, background: Memory = None, prompt_manager=None) -> str:
+        '''
+        assemble the prompt from role / task / tools / docs / memory
+        '''
+        doc_infos = self.create_doc_prompt(query)
+        code_infos = self.create_codedoc_prompt(query)
+        #
+        formatted_tools, tool_names = self.create_tools_prompt(query)
+        task_prompt = self.create_task_prompt(query)
+        background_prompt = self.create_background_prompt(background, control_key="step_content")
+        history_prompt = self.create_history_prompt(history)
+        selfmemory_prompt = self.create_selfmemory_prompt(memory, control_key="step_content")
+        #
+        # extra_system_prompt = self.role.role_prompt
+
+        prompt = self.role.role_prompt.format(**{"formatted_tools": formatted_tools, "tool_names": tool_names})
+
+        task = query.task or self.task
+        if task_prompt is not None:
+            prompt += "\n" + task.task_prompt
+
+        if doc_infos is not None and doc_infos != "" and doc_infos != "不存在知识库辅助信息":
+            prompt += f"\n知识库信息: {doc_infos}"
+
+        if code_infos is not None and code_infos != "" and code_infos != "不存在代码库辅助信息":
+            prompt += f"\n代码库信息: {code_infos}"
+
+        if background_prompt:
+            prompt += "\n" + background_prompt
+
+        if history_prompt:
+            prompt += "\n" + history_prompt
+
+        if selfmemory_prompt:
+            prompt += "\n" + selfmemory_prompt
+
+        # input_query = memory.to_tuple_messages(content_key="step_content")
+        # input_query = "\n".join([f"{k}: {v}" for k, v in input_query if v])
+
+        input_query = query.role_content
+
+        # logger.debug(f"{self.role.role_name} extra_system_prompt: {self.role.role_prompt}")
+        logger.debug(f"{self.role.role_name} input_query: {input_query}")
+        # logger.debug(f"{self.role.role_name} doc_infos: {doc_infos}")
+        logger.debug(f"{self.role.role_name} tool_names: {tool_names}")
+        prompt += "\n" + REACT_PROMPT_INPUT.format(**{"query": input_query})
+        # prompt = extra_system_prompt.format(**{"query": input_query, "doc_infos": doc_infos, "formatted_tools": formatted_tools, "tool_names": tool_names})
+        while "{{" in prompt or "}}" in prompt:
+            prompt = prompt.replace("{{", "{")
+            prompt = prompt.replace("}}", "}")
+        return prompt
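+        # note: the prompt templates escape literal braces as "{{" / "}}" so that
+        # str.format only fills the placeholders above; the while-loop then collapses
+        # the escapes back to single braces before the prompt is sent to the LLM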
+        # prompt_comp = [("system", extra_system_prompt)] + memory.to_tuple_messages()
+        # prompt = ChatPromptTemplate.from_messages(prompt_comp)
+        # prompt = prompt.format(**{"query": query.role_content, "doc_infos": doc_infos, "formatted_tools": formatted_tools, "tool_names": tool_names})
+        # return prompt
+
+    def create_doc_prompt(self, message: Message) -> str:
+        ''''''
+        db_docs = message.db_docs
+        search_docs = message.search_docs
+        doc_infos = "\n".join([doc.get_snippet() for doc in db_docs] + [doc.get_snippet() for doc in search_docs])
+        return doc_infos or "不存在知识库辅助信息"
+
+    def create_codedoc_prompt(self, message: Message) -> str:
+        ''''''
+        code_docs = message.code_docs
+        doc_infos = "\n".join([doc.get_code() for doc in code_docs])
+        return doc_infos or "不存在代码库辅助信息"
+
+    def create_tools_prompt(self, message: Message) -> str:
+        tools = message.tools
+        tool_strings = []
+        for tool in tools:
+            args_schema = re.sub("}", "}}}}", re.sub("{", "{{{{", str(tool.args)))
+            tool_strings.append(f"{tool.name}: {tool.description}, args: {args_schema}")
+        formatted_tools = "\n".join(tool_strings)
+        tool_names = ", ".join([tool.name for tool in tools])
+        return formatted_tools, tool_names
+
+    def create_task_prompt(self, message: Message) -> str:
+        task = message.task or self.task
+        return "\n任务目标: " + task.task_prompt if task is not None else None
+
+    def create_background_prompt(self, background: Memory, control_key="role_content") -> str:
+        background_message = None if background is None else background.to_str_messages(content_key=control_key)
+        # logger.debug(f"background_message: {background_message}")
+        if background_message:
+            background_message = re.sub("}", "}}", re.sub("{", "{{", background_message))
+        return "\n背景信息: " + background_message if background_message else None
+
+    def create_history_prompt(self, history: Memory, control_key="role_content") -> str:
+        history_message = None if history is None else history.to_str_messages(content_key=control_key)
+        if history_message:
+            history_message = re.sub("}", "}}", re.sub("{", "{{", history_message))
+        return "\n补充对话信息: " + history_message if history_message else None
+
+    def create_selfmemory_prompt(self, selfmemory: Memory, control_key="role_content") -> str:
+        selfmemory_message = None if selfmemory is None else selfmemory.to_str_messages(content_key=control_key)
+        if selfmemory_message:
+            selfmemory_message = re.sub("}", "}}", re.sub("{", "{{", selfmemory_message))
+        return "\n补充自身对话信息: " + selfmemory_message if selfmemory_message else None
+
+    def init_history(self, memory: Memory = None) -> Memory:
+        return Memory([])
+
+    def update_history(self, message: Message):
+        self.memory.append(message)
+
+    def append_history(self, message: Message):
+        self.memory.append(message)
+
+    def clear_history(self, ):
+        self.memory.clear()
+        self.memory = self.init_history()
+
+    def create_llm_engine(self, temperature=0.2, stop=None):
+        return getChatModel(temperature=temperature, stop=stop)
+    def filter(self, message: Message, stop=None) -> Message:
+
+        tool_params = self.parser_spec_key(message.role_content, "tool_params")
+        code_content = self.parser_spec_key(message.role_content, "code_content")
+        plan = self.parser_spec_key(message.role_content, "plan")
+        plans = self.parser_spec_key(message.role_content, "plans", do_search=False)
+        content = self.parser_spec_key(message.role_content, "content", do_search=False)
+
+        # logger.debug(f"tool_params: {tool_params}, code_content: {code_content}, plan: {plan}, plans: {plans}, content: {content}")
+        role_content = tool_params or code_content or plan or plans or content
+        message.role_content = role_content or message.role_content
+        return message
+
+    def token_usage(self, ):
+        pass
+
+    def get_extra_infos(self, message: Message) -> Message:
+        ''''''
+        if self.do_search:
+            message = self.get_search_retrieval(message)
+
+        if self.do_doc_retrieval:
+            message = self.get_doc_retrieval(message)
+
+        if self.do_tool_retrieval:
+            message = self.get_tool_retrieval(message)
+
+        return message
+
+    def get_search_retrieval(self, message: Message, ) -> Message:
+        SEARCH_ENGINES = {"duckduckgo": DDGSTool}
+        search_docs = []
+        for idx, doc in enumerate(SEARCH_ENGINES["duckduckgo"].run(message.role_content, 3)):
+            doc.update({"index": idx})
+            search_docs.append(Doc(**doc))
+        message.search_docs = search_docs
+        return message
+
+    def get_doc_retrieval(self, message: Message) -> Message:
+        query = message.role_content
+        knowledge_basename = message.doc_engine_name
+        top_k = message.top_k
+        score_threshold = message.score_threshold
+        if knowledge_basename:
+            docs = DocRetrieval.run(query, knowledge_basename, top_k, score_threshold)
+            message.db_docs = [Doc(**doc) for doc in docs]
+        return message
+
+    def get_code_retrieval(self, message: Message) -> Message:
+        # DocRetrieval.run("langchain是什么", "DSADSAD")
+        query = message.input_query
+        code_engine_name = message.code_engine_name
+        history_node_list = message.history_node_list
+        code_docs = CodeRetrieval.run(code_engine_name, query, code_limit=message.top_k, history_node_list=history_node_list)
+        message.code_docs = [CodeDoc(**doc) for doc in code_docs]
+        return message
+
+    def get_tool_retrieval(self, message: Message) -> Message:
+        return message
+
+    def step_router(self, message: Message) -> Message:
+        ''''''
+        # message = self.parser(message)
+        # logger.debug(f"message.action_status: {message.action_status}")
+        if message.action_status == ActionStatus.CODING:
+            message = self.code_step(message)
+        elif message.action_status == ActionStatus.TOOL_USING:
+            message = self.tool_step(message)
+
+        return message
+
+    def code_step(self, message: Message) -> Message:
+        '''execute code'''
+        # logger.debug(f"message.role_content: {message.role_content}, message.code_content: {message.code_content}")
+        code_answer = self.codebox.chat('```python\n{}```'.format(message.code_content))
+        code_prompt = f"执行上述代码后存在报错信息为 {code_answer.code_exe_response},需要进行修复" \
+            if code_answer.code_exe_type == "error" else f"执行上述代码后返回信息为 {code_answer.code_exe_response}"
+        uid = str(uuid.uuid1())
+        if code_answer.code_exe_type == "image/png":
+            message.figures[uid] = code_answer.code_exe_response
+            message.code_answer = f"\n观察: 执行上述代码后生成一张图片, 图片名为{uid}\n"
+            message.observation = f"\n观察: 执行上述代码后生成一张图片, 图片名为{uid}\n"
+            message.step_content += f"\n观察: 执行上述代码后生成一张图片, 图片名为{uid}\n"
+            message.step_contents += [f"\n观察: 执行上述代码后生成一张图片, 图片名为{uid}\n"]
+            message.role_content += f"\n观察:执行上述代码后生成一张图片, 图片名为{uid}\n"
+        else:
+            message.code_answer = code_answer.code_exe_response
+            message.observation = code_answer.code_exe_response
+            message.step_content += f"\n观察: {code_prompt}\n"
+            message.step_contents += [f"\n观察: {code_prompt}\n"]
+            message.role_content += f"\n观察: {code_prompt}\n"
+        # logger.info(f"观察: {message.action_status}, {message.observation}")
+        return message
"不存在可以执行的tool" + message.role_content += f"\n观察: 不存在可以执行的tool\n" + message.step_content += f"\n观察: 不存在可以执行的tool\n" + message.step_contents += [f"\n观察: 不存在可以执行的tool\n"] + for tool in message.tools: + if tool.name == message.tool_name: + tool_res = tool.func(**message.tool_params) + message.tool_answer = tool_res + message.observation = tool_res + message.role_content += f"\n观察: {tool_res}\n" + message.step_content += f"\n观察: {tool_res}\n" + message.step_contents += [f"\n观察: {tool_res}\n"] + + # logger.info(f"观察: {message.action_status}, {message.observation}") + return message + + def parser(self, message: Message) -> Message: + '''''' + content = message.role_content + parser_keys = ["action", "code_content", "code_filename", "tool_params", "plans"] + try: + s_json = self._parse_json(content) + message.action_status = s_json.get("action") + message.code_content = s_json.get("code_content") + message.tool_params = s_json.get("tool_params") + message.tool_name = s_json.get("tool_name") + message.code_filename = s_json.get("code_filename") + message.plans = s_json.get("plans") + # for parser_key in parser_keys: + # message.action_status = content.get(parser_key) + except Exception as e: + # logger.warning(f"{traceback.format_exc()}") + action_value = self._match(r"'action':\s*'([^']*)'", content) if "'action'" in content else self._match(r'"action":\s*"([^"]*)"', content) + code_content_value = self._match(r"'code_content':\s*'([^']*)'", content) if "'code_content'" in content else self._match(r'"code_content":\s*"([^"]*)"', content) + filename_value = self._match(r"'code_filename':\s*'([^']*)'", content) if "'code_filename'" in content else self._match(r'"code_filename":\s*"([^"]*)"', content) + tool_params_value = self._match(r"'tool_params':\s*(\{[^{}]*\})", content, do_json=True) if "'tool_params'" in content \ + else self._match(r'"tool_params":\s*(\{[^{}]*\})', content, do_json=True) + tool_name_value = self._match(r"'tool_name':\s*'([^']*)'", content) if "'tool_name'" in content else self._match(r'"tool_name":\s*"([^"]*)"', content) + plans_value = self._match(r"'plans':\s*(\[.*?\])", content, do_search=False) if "'plans'" in content else self._match(r'"plans":\s*(\[.*?\])', content, do_search=False, ) + # re解析 + message.action_status = action_value or "default" + message.code_content = code_content_value + message.code_filename = filename_value + message.tool_params = tool_params_value + message.tool_name = tool_name_value + message.plans = plans_value + + # logger.debug(f"确认当前的action: {message.action_status}") + + return message + + def parser_spec_key(self, content, key, do_search=True, do_json=False) -> str: + '''''' + key2pattern = { + "'action'": r"'action':\s*'([^']*)'", '"action"': r'"action":\s*"([^"]*)"', + "'code_content'": r"'code_content':\s*'([^']*)'", '"code_content"': r'"code_content":\s*"([^"]*)"', + "'code_filename'": r"'code_filename':\s*'([^']*)'", '"code_filename"': r'"code_filename":\s*"([^"]*)"', + "'tool_params'": r"'tool_params':\s*(\{[^{}]*\})", '"tool_params"': r'"tool_params":\s*(\{[^{}]*\})', + "'tool_name'": r"'tool_name':\s*'([^']*)'", '"tool_name"': r'"tool_name":\s*"([^"]*)"', + "'plans'": r"'plans':\s*(\[.*?\])", '"plans"': r'"plans":\s*(\[.*?\])', + "'content'": r"'content':\s*'([^']*)'", '"content"': r'"content":\s*"([^"]*)"', + } + + s_json = self._parse_json(content) + try: + if s_json and key in s_json: + return str(s_json[key]) + except: + pass + + keystr = f"'{key}'" if f"'{key}'" in content else f'"{key}"' + return 
+    def parser_spec_key(self, content, key, do_search=True, do_json=False) -> str:
+        ''''''
+        key2pattern = {
+            "'action'": r"'action':\s*'([^']*)'", '"action"': r'"action":\s*"([^"]*)"',
+            "'code_content'": r"'code_content':\s*'([^']*)'", '"code_content"': r'"code_content":\s*"([^"]*)"',
+            "'code_filename'": r"'code_filename':\s*'([^']*)'", '"code_filename"': r'"code_filename":\s*"([^"]*)"',
+            "'tool_params'": r"'tool_params':\s*(\{[^{}]*\})", '"tool_params"': r'"tool_params":\s*(\{[^{}]*\})',
+            "'tool_name'": r"'tool_name':\s*'([^']*)'", '"tool_name"': r'"tool_name":\s*"([^"]*)"',
+            "'plans'": r"'plans':\s*(\[.*?\])", '"plans"': r'"plans":\s*(\[.*?\])',
+            "'content'": r"'content':\s*'([^']*)'", '"content"': r'"content":\s*"([^"]*)"',
+        }
+
+        s_json = self._parse_json(content)
+        try:
+            if s_json and key in s_json:
+                return str(s_json[key])
+        except:
+            pass
+
+        keystr = f"'{key}'" if f"'{key}'" in content else f'"{key}"'
+        return self._match(key2pattern.get(keystr, fr"'{key}':\s*'([^']*)'"), content, do_search=do_search, do_json=do_json)
+
+    def _match(self, pattern, s, do_search=True, do_json=False):
+        try:
+            if do_search:
+                match = re.search(pattern, s)
+                if match:
+                    value = match.group(1).replace("\\n", "\n")
+                    if do_json:
+                        value = json.loads(value)
+                else:
+                    value = None
+            else:
+                match = re.findall(pattern, s, re.DOTALL)
+                if match:
+                    value = match[0]
+                    if do_json:
+                        value = json.loads(value)
+                else:
+                    value = None
+        except Exception as e:
+            logger.warning(f"{traceback.format_exc()}")
+
+        # logger.debug(f"pattern: {pattern}, s: {s}, match: {match}")
+        return value
+
+    def _parse_json(self, s):
+        try:
+            pattern = r"```([^`]+)```"
+            match = re.findall(pattern, s)
+            if match:
+                # note: eval() trusts the model output here; it accepts Python-style dict literals
+                return eval(match[0])
+        except:
+            pass
+        return None
+
+    def get_memory(self, ):
+        return self.memory.to_tuple_messages(content_key="step_content")
+
+    def get_memory_str(self, ):
+        return "\n".join([": ".join(i) for i in self.memory.to_tuple_messages(content_key="step_content")])
\ No newline at end of file
diff --git a/dev_opsgpt/connector/agents/react_agent.py b/dev_opsgpt/connector/agents/react_agent.py
new file mode 100644
index 0000000..4625f74
--- /dev/null
+++ b/dev_opsgpt/connector/agents/react_agent.py
@@ -0,0 +1,138 @@
+from pydantic import BaseModel
+from typing import List, Union
+import re
+import traceback
+import copy
+from loguru import logger
+
+from langchain.prompts.chat import ChatPromptTemplate
+
+from dev_opsgpt.connector.shcema.memory import Memory
+from dev_opsgpt.connector.connector_schema import Task, Env, Role, Message, ActionStatus
+from dev_opsgpt.llm_models import getChatModel
+from dev_opsgpt.connector.configs.agent_config import REACT_PROMPT_INPUT
+
+from .base_agent import BaseAgent
+
+
+class ReactAgent(BaseAgent):
+    def __init__(
+            self,
+            role: Role,
+            task: Task = None,
+            memory: Memory = None,
+            chat_turn: int = 1,
+            do_search: bool = False,
+            do_doc_retrieval: bool = False,
+            do_tool_retrieval: bool = False,
+            temperature: float = 0.2,
+            stop: Union[List[str], str] = "观察",
+            do_filter: bool = True,
+            do_use_self_memory: bool = True,
+            # docs_prompt: str,
+            # prompt_manager: PromptManager
+            ):
+        super().__init__(role, task, memory, chat_turn, do_search, do_doc_retrieval,
+                         do_tool_retrieval, temperature, stop, do_filter, do_use_self_memory
+                         )
+
+    def run(self, query: Message, history: Memory = None, background: Memory = None) -> Message:
+        step_nums = copy.deepcopy(self.chat_turn)
+        react_memory = Memory([])
+        # seed the react memory with the question
+        output_message = Message(
+            role_name=self.role.role_name,
+            role_type="ai",  # self.role.role_type,
+            role_content=query.input_query,
+            step_content=query.input_query,
+            input_query=query.input_query,
+            tools=query.tools
+            )
+        react_memory.append(output_message)
+        idx = 0
+        while step_nums > 0:
+            output_message.role_content = output_message.step_content
+            self_memory = self.memory if self.do_use_self_memory else None
+            prompt = self.create_prompt(query, self_memory, history, background, react_memory)
+            try:
+                content = self.llm.predict(prompt)
+            except Exception as e:
+                logger.warning(f"error prompt: {prompt}")
+                raise Exception(traceback.format_exc())
+
+            output_message.role_content = content
+            output_message.role_contents += [content]
+            output_message.step_content += output_message.role_content
+            # fixed: accumulate into step_contents (the original statement dropped the "+=")
+            output_message.step_contents += [output_message.role_content]
+
+            # logger.debug(f"{self.role.role_name}, {idx} iteration prompt: {prompt}")
+            # logger.info(f"{self.role.role_name}, {idx} iteration step_run: {output_message.role_content}")
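+            # each iteration: parse the model's action blob, execute it (code or tool)
+            # via step_router, and stop as soon as the model signals "finished"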
{prompt}") + # logger.info(f"{self.role.role_name}, {idx} iteration step_run: {output_message.role_content}") + + output_message = self.parser(output_message) + # when get finished signal can stop early + if output_message.action_status == ActionStatus.FINISHED: break + # according the output to choose one action for code_content or tool_content + output_message = self.step_router(output_message) + logger.info(f"{self.role.role_name} react_run: {output_message.role_content}") + + idx += 1 + step_nums -= 1 + # react' self_memory saved at last + self.append_history(output_message) + return output_message + + def create_prompt( + self, query: Message, memory: Memory =None, history: Memory = None, background: Memory = None, react_memory: Memory = None, prompt_mamnger=None) -> str: + ''' + role\task\tools\docs\memory + ''' + # + doc_infos = self.create_doc_prompt(query) + code_infos = self.create_codedoc_prompt(query) + # + formatted_tools, tool_names = self.create_tools_prompt(query) + task_prompt = self.create_task_prompt(query) + background_prompt = self.create_background_prompt(background) + history_prompt = self.create_history_prompt(history) + selfmemory_prompt = self.create_selfmemory_prompt(memory, control_key="step_content") + # + # extra_system_prompt = self.role.role_prompt + prompt = self.role.role_prompt.format(**{"formatted_tools": formatted_tools, "tool_names": tool_names}) + + + task = query.task or self.task + if task_prompt is not None: + prompt += "\n" + task.task_prompt + + if doc_infos is not None and doc_infos!="" and doc_infos!="不存在知识库辅助信息": + prompt += f"\n知识库信息: {doc_infos}" + + if code_infos is not None and code_infos!="" and code_infos!="不存在代码库辅助信息": + prompt += f"\n代码库信息: {code_infos}" + + if background_prompt: + prompt += "\n" + background_prompt + + if history_prompt: + prompt += "\n" + history_prompt + + if selfmemory_prompt: + prompt += "\n" + selfmemory_prompt + + # react 流程是自身迭代过程,另外二次触发的是需要作为历史对话信息 + input_query = react_memory.to_tuple_messages(content_key="step_content") + input_query = "\n".join([f"{v}" for k, v in input_query if v]) + + # logger.debug(f"{self.role.role_name} extra_system_prompt: {self.role.role_prompt}") + # logger.debug(f"{self.role.role_name} input_query: {input_query}") + # logger.debug(f"{self.role.role_name} doc_infos: {doc_infos}") + # logger.debug(f"{self.role.role_name} tool_names: {tool_names}") + prompt += "\n" + REACT_PROMPT_INPUT.format(**{"query": input_query}) + + # prompt = extra_system_prompt.format(**{"query": input_query, "doc_infos": doc_infos, "formatted_tools": formatted_tools, "tool_names": tool_names}) + while "{{" in prompt or "}}" in prompt: + prompt = prompt.replace("{{", "{") + prompt = prompt.replace("}}", "}") + return prompt + diff --git a/dev_opsgpt/connector/chains/__init__.py b/dev_opsgpt/connector/chains/__init__.py new file mode 100644 index 0000000..8d3676c --- /dev/null +++ b/dev_opsgpt/connector/chains/__init__.py @@ -0,0 +1,5 @@ +from .base_chain import BaseChain + +__all__ = [ + "BaseChain" +] \ No newline at end of file diff --git a/dev_opsgpt/connector/chains/base_chain.py b/dev_opsgpt/connector/chains/base_chain.py new file mode 100644 index 0000000..dc49755 --- /dev/null +++ b/dev_opsgpt/connector/chains/base_chain.py @@ -0,0 +1,281 @@ +from pydantic import BaseModel +from typing import List +import json +import re +from loguru import logger +import traceback +import uuid +import copy + +from dev_opsgpt.connector.agents import BaseAgent +from dev_opsgpt.tools.base_tool import BaseTools, Tool +from 
+from dev_opsgpt.connector.shcema.memory import Memory
+from dev_opsgpt.connector.connector_schema import (
+    Role, Message, ActionStatus, ChainConfig,
+    load_role_configs
+)
+from dev_opsgpt.sandbox import PyCodeBox, CodeBoxResponse
+
+
+from configs.server_config import SANDBOX_SERVER
+
+from dev_opsgpt.connector.configs.agent_config import AGETN_CONFIGS
+role_configs = load_role_configs(AGETN_CONFIGS)
+
+
+class BaseChain:
+    def __init__(
+            self,
+            chainConfig: ChainConfig,
+            agents: List[BaseAgent],
+            chat_turn: int = 1,
+            do_checker: bool = False,
+            do_code_exec: bool = False,
+            # prompt_manager: PromptManager
+            ) -> None:
+        self.chainConfig = chainConfig
+        self.agents = agents
+        self.chat_turn = chat_turn
+        self.do_checker = do_checker
+        self.checker = BaseAgent(role=role_configs["checker"].role,
+            task=None,
+            memory=None,
+            do_search=role_configs["checker"].do_search,
+            do_doc_retrieval=role_configs["checker"].do_doc_retrieval,
+            do_tool_retrieval=role_configs["checker"].do_tool_retrieval,
+            do_filter=False, do_use_self_memory=False)
+
+        self.global_memory = Memory([])
+        self.local_memory = Memory([])
+        self.do_code_exec = do_code_exec
+        self.codebox = PyCodeBox(
+            remote_url=SANDBOX_SERVER["url"],
+            remote_ip=SANDBOX_SERVER["host"],
+            remote_port=SANDBOX_SERVER["port"],
+            token="mytoken",
+            do_code_exe=True,
+            do_remote=SANDBOX_SERVER["do_remote"],
+            do_check_net=False
+        )
+
+    def step(self, query: Message, history: Memory = None, background: Memory = None) -> Message:
+        '''execute chain'''
+        local_memory = Memory([])
+        input_message = copy.deepcopy(query)
+        step_nums = copy.deepcopy(self.chat_turn)
+        check_message = None
+
+        self.global_memory.append(input_message)
+        local_memory.append(input_message)
+        while step_nums > 0:
+
+            for agent in self.agents:
+                output_message = agent.run(input_message, history, background=background)
+                output_message = self.inherit_extrainfo(input_message, output_message)
+                # choose an action (code execution or tool use) according to the output
+                logger.info(f"{agent.role.role_name} message: {output_message.role_content}")
+                output_message = self.parser(output_message)
+                # output_message = self.step_router(output_message)
+
+                input_message = output_message
+                self.global_memory.append(output_message)
+
+                local_memory.append(output_message)
+                # when the finished signal arrives we can stop early
+                if output_message.action_status == ActionStatus.FINISHED:
+                    break
+
+            if self.do_checker:
+                logger.debug(f"{self.checker.role.role_name} input global memory: {self.global_memory.to_str_messages(content_key='step_content')}")
+                check_message = self.checker.run(query, background=self.global_memory)
+                check_message = self.parser(check_message)
+                check_message = self.filter(check_message)
+                check_message = self.inherit_extrainfo(output_message, check_message)
+                logger.debug(f"{self.checker.role.role_name}: {check_message.role_content}")
+
+                if check_message.action_status == ActionStatus.FINISHED:
+                    self.global_memory.append(check_message)
+                    break
+
+            step_nums -= 1
+
+        return check_message or output_message, local_memory
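+    # step() drives one chain turn: every agent runs in sequence on the rolling
+    # message, and when do_checker is set the checker agent can end the loop early
+    # by returning a "finished" action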
logger.debug(f"message.role_content: {message.role_content}, message.code_content: {message.code_content}") + code_answer = self.codebox.chat('```python\n{}```'.format(message.code_content)) + uid = str(uuid.uuid1()) + if code_answer.code_exe_type == "image/png": + message.figures[uid] = code_answer.code_exe_response + message.code_answer = f"\n观察: 执行代码后获得输出一张图片, 文件名为{uid}\n" + message.observation = f"\n观察: 执行代码后获得输出一张图片, 文件名为{uid}\n" + message.step_content += f"\n观察: 执行代码后获得输出一张图片, 文件名为{uid}\n" + message.step_contents += [f"\n观察: 执行代码后获得输出一张图片, 文件名为{uid}\n"] + message.role_content += f"\n执行代码后获得输出一张图片, 文件名为{uid}\n" + else: + message.code_answer = code_answer.code_exe_response + message.observation = code_answer.code_exe_response + message.step_content += f"\n观察: 执行代码后获得输出是 {code_answer.code_exe_response}\n" + message.step_contents += [f"\n观察: 执行代码后获得输出是 {code_answer.code_exe_response}\n"] + message.role_content += f"\n观察: 执行代码后获得输出是 {code_answer.code_exe_response}\n" + logger.info(f"观察: {message.action_status}, {message.observation}") + return message + + def tool_step(self, message: Message) -> Message: + '''execute tool''' + # logger.debug(f"message: {message.action_status}, {message.tool_name}, {message.tool_params}") + tool_names = [tool.name for tool in message.tools] + if message.tool_name not in tool_names: + message.tool_answer = "不存在可以执行的tool" + message.observation = "不存在可以执行的tool" + message.role_content += f"\n观察: 不存在可以执行的tool\n" + message.step_content += f"\n观察: 不存在可以执行的tool\n" + message.step_contents += [f"\n观察: 不存在可以执行的tool\n"] + for tool in message.tools: + if tool.name == message.tool_name: + tool_res = tool.func(**message.tool_params) + message.tool_answer = tool_res + message.observation = tool_res + message.role_content += f"\n观察: {tool_res}\n" + message.step_content += f"\n观察: {tool_res}\n" + message.step_contents += [f"\n观察: {tool_res}\n"] + return message + + def filter(self, message: Message, stop=None) -> Message: + + tool_params = self.parser_spec_key(message.role_content, "tool_params") + code_content = self.parser_spec_key(message.role_content, "code_content") + plan = self.parser_spec_key(message.role_content, "plan") + plans = self.parser_spec_key(message.role_content, "plans", do_search=False) + content = self.parser_spec_key(message.role_content, "content", do_search=False) + + # logger.debug(f"tool_params: {tool_params}, code_content: {code_content}, plan: {plan}, plans: {plans}, content: {content}") + role_content = tool_params or code_content or plan or plans or content + message.role_content = role_content or message.role_content + return message + + def parser(self, message: Message) -> Message: + '''''' + content = message.role_content + parser_keys = ["action", "code_content", "code_filename", "tool_params", "plans"] + try: + s_json = self._parse_json(content) + message.action_status = s_json.get("action") + message.code_content = s_json.get("code_content") + message.tool_params = s_json.get("tool_params") + message.tool_name = s_json.get("tool_name") + message.code_filename = s_json.get("code_filename") + message.plans = s_json.get("plans") + # for parser_key in parser_keys: + # message.action_status = content.get(parser_key) + except Exception as e: + # logger.warning(f"{traceback.format_exc()}") + action_value = self._match(r"'action':\s*'([^']*)'", content) if "'action'" in content else self._match(r'"action":\s*"([^"]*)"', content) + code_content_value = self._match(r"'code_content':\s*'([^']*)'", content) if "'code_content'" in content else 
self._match(r'"code_content":\s*"([^"]*)"', content) + filename_value = self._match(r"'code_filename':\s*'([^']*)'", content) if "'code_filename'" in content else self._match(r'"code_filename":\s*"([^"]*)"', content) + tool_params_value = self._match(r"'tool_params':\s*(\{[^{}]*\})", content, do_json=True) if "'tool_params'" in content \ + else self._match(r'"tool_params":\s*(\{[^{}]*\})', content, do_json=True) + tool_name_value = self._match(r"'tool_name':\s*'([^']*)'", content) if "'tool_name'" in content else self._match(r'"tool_name":\s*"([^"]*)"', content) + plans_value = self._match(r"'plans':\s*(\[.*?\])", content, do_search=False) if "'plans'" in content else self._match(r'"plans":\s*(\[.*?\])', content, do_search=False, ) + # re解析 + message.action_status = action_value or "default" + message.code_content = code_content_value + message.code_filename = filename_value + message.tool_params = tool_params_value + message.tool_name = tool_name_value + message.plans = plans_value + + logger.debug(f"确认当前的action: {message.action_status}") + + return message + + def parser_spec_key(self, content, key, do_search=True, do_json=False) -> str: + '''''' + key2pattern = { + "'action'": r"'action':\s*'([^']*)'", '"action"': r'"action":\s*"([^"]*)"', + "'code_content'": r"'code_content':\s*'([^']*)'", '"code_content"': r'"code_content":\s*"([^"]*)"', + "'code_filename'": r"'code_filename':\s*'([^']*)'", '"code_filename"': r'"code_filename":\s*"([^"]*)"', + "'tool_params'": r"'tool_params':\s*(\{[^{}]*\})", '"tool_params"': r'"tool_params":\s*(\{[^{}]*\})', + "'tool_name'": r"'tool_name':\s*'([^']*)'", '"tool_name"': r'"tool_name":\s*"([^"]*)"', + "'plans'": r"'plans':\s*(\[.*?\])", '"plans"': r'"plans":\s*(\[.*?\])', + "'content'": r"'content':\s*'([^']*)'", '"content"': r'"content":\s*"([^"]*)"', + } + + s_json = self._parse_json(content) + try: + if s_json and key in s_json: + return str(s_json[key]) + except: + pass + + keystr = f"'{key}'" if f"'{key}'" in content else f'"{key}"' + return self._match(key2pattern.get(keystr, fr"'{key}':\s*'([^']*)'"), content, do_search=do_search, do_json=do_json) + + def _match(self, pattern, s, do_search=True, do_json=False): + try: + if do_search: + match = re.search(pattern, s) + if match: + value = match.group(1).replace("\\n", "\n") + if do_json: + value = json.loads(value) + else: + value = None + else: + match = re.findall(pattern, s, re.DOTALL) + if match: + value = match[0] + if do_json: + value = json.loads(value) + else: + value = None + except Exception as e: + logger.warning(f"{traceback.format_exc()}") + + # logger.debug(f"pattern: {pattern}, s: {s}, match: {match}") + return value + + def _parse_json(self, s): + try: + pattern = r"```([^`]+)```" + match = re.findall(pattern, s) + if match: + return eval(match[0]) + except: + pass + return None + + def inherit_extrainfo(self, input_message: Message, output_message: Message): + output_message.db_docs = input_message.db_docs + output_message.search_docs = input_message.search_docs + output_message.code_docs = input_message.code_docs + output_message.figures.update(input_message.figures) + return output_message + + def get_memory(self, do_all_memory=True, content_key="role_content") -> Memory: + memory = self.global_memory if do_all_memory else self.local_memory + return memory.to_tuple_messages(content_key=content_key) + + def get_memory_str(self, do_all_memory=True, content_key="role_content") -> Memory: + memory = self.global_memory if do_all_memory else self.local_memory + # for i in 
+
+    def get_memory_str(self, do_all_memory=True, content_key="role_content") -> str:
+        memory = self.global_memory if do_all_memory else self.local_memory
+        # for i in memory.to_tuple_messages(content_key=content_key):
+        #     logger.debug(f"{i}")
+        return "\n".join([": ".join(i) for i in memory.to_tuple_messages(content_key=content_key)])
+
+    def get_agents_memory(self, content_key="role_content"):
+        return [agent.get_memory(content_key=content_key) for agent in self.agents]
+
+    def get_agents_memory_str(self, content_key="role_content"):
+        return "************".join([f"{agent.role.role_name}\n" + agent.get_memory_str(content_key=content_key) for agent in self.agents])
\ No newline at end of file
diff --git a/dev_opsgpt/connector/chains/chains.py b/dev_opsgpt/connector/chains/chains.py
new file mode 100644
index 0000000..fd490b0
--- /dev/null
+++ b/dev_opsgpt/connector/chains/chains.py
@@ -0,0 +1,28 @@
+from typing import List
+from dev_opsgpt.connector.agents import BaseAgent
+# ChainConfig is imported so the subclasses can forward it to BaseChain
+from dev_opsgpt.connector.connector_schema import ChainConfig
+from .base_chain import BaseChain
+
+
+
+class simpleChatChain(BaseChain):
+
+    def __init__(self, chainConfig: ChainConfig, agents: List[BaseAgent], do_code_exec: bool = False) -> None:
+        # forward the full argument list expected by BaseChain
+        super().__init__(chainConfig, agents, do_code_exec=do_code_exec)
+
+
+class toolChatChain(BaseChain):
+
+    def __init__(self, chainConfig: ChainConfig, agents: List[BaseAgent], do_code_exec: bool = False) -> None:
+        super().__init__(chainConfig, agents, do_code_exec=do_code_exec)
+
+
+class dataAnalystChain(BaseChain):
+
+    def __init__(self, chainConfig: ChainConfig, agents: List[BaseAgent], do_code_exec: bool = False) -> None:
+        super().__init__(chainConfig, agents, do_code_exec=do_code_exec)
+
+
+class plannerChain(BaseChain):
+
+    def __init__(self, chainConfig: ChainConfig, agents: List[BaseAgent], do_code_exec: bool = False) -> None:
+        super().__init__(chainConfig, agents, do_code_exec=do_code_exec)
diff --git a/dev_opsgpt/connector/configs/__init__.py b/dev_opsgpt/connector/configs/__init__.py
new file mode 100644
index 0000000..873ba85
--- /dev/null
+++ b/dev_opsgpt/connector/configs/__init__.py
@@ -0,0 +1,7 @@
+from .agent_config import AGETN_CONFIGS
+from .chain_config import CHAIN_CONFIGS
+from .phase_config import PHASE_CONFIGS
+
+__all__ = [
+    "AGETN_CONFIGS", "CHAIN_CONFIGS", "PHASE_CONFIGS"
+    ]
\ No newline at end of file
diff --git a/dev_opsgpt/connector/configs/agent_config.py b/dev_opsgpt/connector/configs/agent_config.py
new file mode 100644
index 0000000..42432a1
--- /dev/null
+++ b/dev_opsgpt/connector/configs/agent_config.py
@@ -0,0 +1,410 @@
+from enum import Enum
+
+
+class AgentType:
+    REACT = "ReactAgent"
+    ONE_STEP = "BaseAgent"
+    DEFAULT = "BaseAgent"
+
+
+REACT_TOOL_PROMPT = """尽可能地以有帮助和准确的方式回应人类。您可以使用以下工具:
+{formatted_tools}
+使用json blob来指定一个工具,提供一个action关键字(工具名称)和一个tool_params关键字(工具输入)。
+有效的"action"值为:"finished" 或 "tool_using" (使用工具来回答问题)
+有效的"tool_name"值为:{tool_names}
+请仅在每个$JSON_BLOB中提供一个action,如下所示:
+```
+{{{{
+"action": $ACTION,
+"tool_name": $TOOL_NAME
+"tool_params": $INPUT
+}}}}
+```
+
+按照以下格式进行回应:
+问题:输入问题以回答
+思考:考虑之前和之后的步骤
+行动:
+```
+$JSON_BLOB
+```
+观察:行动结果
+...(重复思考/行动/观察N次)
+思考:我知道该如何回应
+行动:
+```
+{{{{
+"action": "finished",
+"tool_name": "notool"
+"tool_params": "最终返回答案给到用户"
+}}}}
+```
+"""
+
+REACT_PROMPT_INPUT = '''下面开始!记住根据问题进行返回需要生成的答案
+问题: {query}'''
+
+
+REACT_CODE_PROMPT = """尽可能地以有帮助和准确的方式回应人类,能够逐步编写可执行并打印变量的代码来解决问题
+使用 JSON Blob 来指定一个返回的内容,提供一个 action(行动)和一个 code (生成代码)。
+有效的 'action' 值为:'coding'(结合总结下述思维链过程编写下一步的可执行代码) or 'finished' (总结下述思维链过程可回答问题)。
+在每个 $JSON_BLOB 中仅提供一个 action,如下所示:
+```
+{{{{'action': $ACTION,'code_content': $CODE}}}}
+```
+
+按照以下思维链格式进行回应:
+问题:输入问题以回答
+思考:考虑之前和之后的步骤
+行动:
+```
+$JSON_BLOB
+```
+观察:行动结果
+...(重复思考/行动/观察N次)
+思考:我知道该如何回应
+行动:
+```
+{{{{
+"action": "finished",
+"code_content": "总结上述思维链过程回答问题"
+}}}}
+```
+"""
+
+GENERAL_PLANNER_PROMPT = """你是一个通用计划拆解助手,将问题拆解成各个详细明确的步骤计划或直接回答问题,尽可能地以有帮助和准确的方式回应人类,
+使用 JSON Blob 来指定一个返回的内容,提供一个 action(行动)和一个 plans (生成的计划)。
+有效的 'action' 值为:'planning'(拆解计划) or 'only_answer' (不需要拆解问题即可直接回答问题)。
+有效的 'plans' 值为: 一个任务列表,按顺序写出需要执行的计划
+在每个 $JSON_BLOB 中仅提供一个 action,如下所示:
+```
+{{'action': 'planning', 'plans': [$PLAN1, $PLAN2, $PLAN3, ..., $PLANN], }}
+或者
+{{'action': 'only_answer', 'plans': "直接回答问题", }}
+```
+
+按照以下格式进行回应:
+问题:输入问题以回答
+行动:
+```
+$JSON_BLOB
+```
+"""
+
+DATA_PLANNER_PROMPT = """你是一个数据分析助手,能够根据问题来制定一个详细明确的数据分析计划,尽可能地以有帮助和准确的方式回应人类,
+使用 JSON Blob 来指定一个返回的内容,提供一个 action(行动)和一个 plans (生成的计划)。
+有效的 'action' 值为:'planning'(拆解计划) or 'only_answer' (不需要拆解问题即可直接回答问题)。
+有效的 'plans' 值为: 一份数据分析计划清单,按顺序排列,用文本表示
+在每个 $JSON_BLOB 中仅提供一个 action,如下所示:
+```
+{{'action': 'planning', 'plans': '$PLAN1, $PLAN2, ..., $PLAN3' }}
+```
+
+按照以下格式进行回应:
+问题:输入问题以回答
+行动:
+```
+$JSON_BLOB
+```
+"""
+
+TOOL_PLANNER_PROMPT = """你是一个工具使用过程的计划拆解助手,将问题拆解为一系列的工具使用计划,若没有可用工具则直接回答问题,尽可能地以有帮助和准确的方式回应人类,你可以使用以下工具:
+{formatted_tools}
+使用 JSON Blob 来指定一个返回的内容,提供一个 action(行动)和一个 plans (生成的计划)。
+有效的 'action' 值为:'planning'(拆解计划) or 'only_answer' (不需要拆解问题即可直接回答问题)。
+有效的 'plans' 值为: 一个任务列表,按顺序写出需要使用的工具和使用该工具的理由
+在每个 $JSON_BLOB 中仅提供一个 action,如下两个示例所示:
+```
+{{'action': 'planning', 'plans': [$PLAN1, $PLAN2, $PLAN3, ..., $PLANN], }}
+```
+或者 若无法通过以上工具解决问题,则直接回答问题
+```
+{{'action': 'only_answer', 'plans': "直接回答问题", }}
+```
+
+按照以下格式进行回应:
+问题:输入问题以回答
+行动:
+```
+$JSON_BLOB
+```
+"""
+
+
+RECOGNIZE_INTENTION_PROMPT = """你是一个任务决策助手,能够理解用户意图并决策采取最合适的行动,尽可能地以有帮助和准确的方式回应人类,
+使用 JSON Blob 来指定一个返回的内容,提供一个 action(行动)。
+有效的 'action' 值为:'planning'(需要先进行拆解计划) or 'only_answer' (不需要拆解问题即可直接回答问题)or "tool_using" (使用工具来回答问题) or 'coding'(生成可执行的代码)。
+在每个 $JSON_BLOB 中仅提供一个 action,如下所示:
+```
+{{'action': $ACTION}}
+```
+按照以下格式进行回应:
+问题:输入问题以回答
+行动:$ACTION
+```
+$JSON_BLOB
+```
+"""
+
+
+CHECKER_PROMPT = """尽可能地以有帮助和准确的方式回应人类,判断问题是否得到解答,同时展现解答的过程和内容
+使用 JSON Blob 来指定一个返回的内容,提供一个 action(行动)。
+有效的 'action' 值为:'finished'(任务已经可以通过“背景信息”和“对话信息”回答问题) or 'continue' (“背景信息”和“对话信息”不足以回答问题)。
+在每个 $JSON_BLOB 中仅提供一个 action,如下所示:
+```
+{{'action': $ACTION, 'content': '提取“背景信息”和“对话信息”中信息来回答问题'}}
+```
+按照以下格式进行回应:
+问题:输入问题以回答
+行动:$ACTION
+```
+$JSON_BLOB
+```
+"""
+
+CONV_SUMMARY_PROMPT = """尽可能地以有帮助和准确的方式回应人类,根据“背景信息”中的有效信息回答问题,
+使用 JSON Blob 来指定一个返回的内容,提供一个 action(行动)。
+有效的 'action' 值为:'finished'(任务已经可以通过上下文信息可以回答) or 'continue' (根据背景信息回答问题)。
+在每个 $JSON_BLOB 中仅提供一个 action,如下所示:
+```
+{{'action': $ACTION, 'content': '根据背景信息回答问题'}}
+```
+按照以下格式进行回应:
+问题:输入问题以回答
+行动:
+```
+$JSON_BLOB
+```
+"""
+
+# (this re-definition overrides the JSON-blob variant above)
+CONV_SUMMARY_PROMPT = """尽可能地以有帮助和准确的方式回应人类
+根据“背景信息”中的有效信息回答问题,同时展现解答的过程和内容
+若能根据“背景信息”回答问题,则直接回答
+否则,总结“背景信息”的内容
+"""
+
+
+
+QA_PROMPT = """根据已知信息,简洁和专业的来回答问题。如果无法从中得到答案,请说 “根据已知信息无法回答该问题”,不允许在答案中添加编造成分,答案请使用中文。
+使用 JSON Blob 来指定一个返回的内容,提供一个 action(行动)。
+有效的 'action' 值为:'finished'(任务已经可以通过上下文信息可以回答) or 'continue' (上下文信息不足以回答问题)。
+在每个 $JSON_BLOB 中仅提供一个 action,如下所示:
+```
+{{'action': $ACTION, 'content': '总结对话内容'}}
+```
+按照以下格式进行回应:
+问题:输入问题以回答
+行动:$ACTION
+```
+$JSON_BLOB
+```
+"""
+
+CODE_QA_PROMPT = """【指令】根据已知信息来回答问题"""
+
+
+AGETN_CONFIGS = {
+    "checker": {
+        "role": {
+            "role_prompt": CHECKER_PROMPT,
+            "role_type": "ai",
+            "role_name": "checker",
+            "role_desc": "",
+            "agent_type": "BaseAgent"
+        },
+        "chat_turn": 1,
+        "do_search": False,
+        "do_doc_retrieval": False,
+        "do_tool_retrieval": False
+    },
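+    # each entry below mirrors the AgentConfig schema in connector_schema.py:
+    # a Role plus chat_turn / do_search / do_doc_retrieval / do_tool_retrieval flags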
"do_tool_retrieval": False + }, + "general_planner": { + "role": { + "role_prompt": GENERAL_PLANNER_PROMPT, + "role_type": "ai", + "role_name": "general_planner", + "role_desc": "", + "agent_type": "BaseAgent" + }, + "chat_turn": 1, + "do_search": False, + "do_doc_retrieval": False, + "do_tool_retrieval": False + }, + "planner": { + "role": { + "role_prompt": DATA_PLANNER_PROMPT, + "role_type": "ai", + "role_name": "planner", + "role_desc": "", + "agent_type": "BaseAgent" + }, + "chat_turn": 1, + "do_search": False, + "do_doc_retrieval": False, + "do_tool_retrieval": False + }, + "intention_recognizer": { + "role": { + "role_prompt": RECOGNIZE_INTENTION_PROMPT, + "role_type": "ai", + "role_name": "intention_recognizer", + "role_desc": "", + "agent_type": "BaseAgent" + }, + "chat_turn": 1, + "do_search": False, + "do_doc_retrieval": False, + "do_tool_retrieval": False + }, + "tool_planner": { + "role": { + "role_prompt": TOOL_PLANNER_PROMPT, + "role_type": "ai", + "role_name": "tool_planner", + "role_desc": "", + "agent_type": "BaseAgent" + }, + "chat_turn": 1, + "do_search": False, + "do_doc_retrieval": False, + "do_tool_retrieval": False + }, + "tool_react": { + "role": { + "role_prompt": REACT_TOOL_PROMPT, + "role_type": "ai", + "role_name": "tool_react", + "role_desc": "", + "agent_type": "ReactAgent" + }, + "chat_turn": 5, + "do_search": False, + "do_doc_retrieval": False, + "do_tool_retrieval": False, + "stop": "观察" + }, + "code_react": { + "role": { + "role_prompt": REACT_CODE_PROMPT, + "role_type": "ai", + "role_name": "code_react", + "role_desc": "", + "agent_type": "ReactAgent" + }, + "chat_turn": 5, + "do_search": False, + "do_doc_retrieval": False, + "do_tool_retrieval": False, + "stop": "观察" + }, + "qaer": { + "role": { + "role_prompt": QA_PROMPT, + "role_type": "ai", + "role_name": "qaer", + "role_desc": "", + "agent_type": "BaseAgent" + }, + "chat_turn": 1, + "do_search": False, + "do_doc_retrieval": True, + "do_tool_retrieval": False + }, + "code_qaer": { + "role": { + "role_prompt": CODE_QA_PROMPT , + "role_type": "ai", + "role_name": "code_qaer", + "role_desc": "", + "agent_type": "BaseAgent" + }, + "chat_turn": 1, + "do_search": False, + "do_doc_retrieval": True, + "do_tool_retrieval": False + }, + "searcher": { + "role": { + "role_prompt": QA_PROMPT, + "role_type": "ai", + "role_name": "searcher", + "role_desc": "", + "agent_type": "BaseAgent" + }, + "chat_turn": 1, + "do_search": True, + "do_doc_retrieval": False, + "do_tool_retrieval": False + }, + "answer": { + "role": { + "role_prompt": "", + "role_type": "ai", + "role_name": "answer", + "role_desc": "", + "agent_type": "BaseAgent" + }, + "chat_turn": 1, + "do_search": False, + "do_doc_retrieval": False, + "do_tool_retrieval": False + }, + "data_analyst": { + "role": { + "role_prompt": """你是一个数据分析的代码开发助手,能够编写可执行的代码来完成相关的数据分析问题,使用 JSON Blob 来指定一个返回的内容,通过提供一个 action(行动)和一个 code (生成代码)和 一个 file_name (指定保存文件)。\ + 有效的 'action' 值为:'coding'(生成可执行的代码) or 'finished' (不生成代码并直接返回答案)。在每个 $JSON_BLOB 中仅提供一个 action,如下所示:\ + ```\n{{'action': $ACTION,'code_content': $CODE, 'code_filename': $FILE_NAME}}```\ + 下面开始!记住根据问题进行返回需要生成的答案,格式为 ```JSON_BLOB```""", + "role_type": "ai", + "role_name": "data_analyst", + "role_desc": "", + "agent_type": "BaseAgent" + }, + "chat_turn": 1, + "do_search": False, + "do_doc_retrieval": False, + "do_tool_retrieval": False + }, + "deveploer": { + "role": { + "role_prompt": """你是一个代码开发助手,能够编写可执行的代码来完成问题,使用 JSON Blob 来指定一个返回的内容,通过提供一个 action(行动)和一个 code (生成代码)和 一个 file_name (指定保存文件)。\ + 有效的 'action' 
+    "data_analyst": {
+        "role": {
+            "role_prompt": """你是一个数据分析的代码开发助手,能够编写可执行的代码来完成相关的数据分析问题,使用 JSON Blob 来指定一个返回的内容,通过提供一个 action(行动)和一个 code (生成代码)和 一个 file_name (指定保存文件)。\
+            有效的 'action' 值为:'coding'(生成可执行的代码) or 'finished' (不生成代码并直接返回答案)。在每个 $JSON_BLOB 中仅提供一个 action,如下所示:\
+            ```\n{{'action': $ACTION,'code_content': $CODE, 'code_filename': $FILE_NAME}}```\
+            下面开始!记住根据问题进行返回需要生成的答案,格式为 ```JSON_BLOB```""",
+            "role_type": "ai",
+            "role_name": "data_analyst",
+            "role_desc": "",
+            "agent_type": "BaseAgent"
+        },
+        "chat_turn": 1,
+        "do_search": False,
+        "do_doc_retrieval": False,
+        "do_tool_retrieval": False
+    },
+    "deveploer": {
+        "role": {
+            "role_prompt": """你是一个代码开发助手,能够编写可执行的代码来完成问题,使用 JSON Blob 来指定一个返回的内容,通过提供一个 action(行动)和一个 code (生成代码)和 一个 file_name (指定保存文件)。\
+            有效的 'action' 值为:'coding'(生成可执行的代码) or 'finished' (不生成代码并直接返回答案)。在每个 $JSON_BLOB 中仅提供一个 action,如下所示:\
+            ```\n{{'action': $ACTION,'code_content': $CODE, 'code_filename': $FILE_NAME}}```\
+            下面开始!记住根据问题进行返回需要生成的答案,格式为 ```JSON_BLOB```""",
+            "role_type": "ai",
+            "role_name": "deveploer",
+            "role_desc": "",
+            "agent_type": "BaseAgent"
+        },
+        "chat_turn": 1,
+        "do_search": False,
+        "do_doc_retrieval": False,
+        "do_tool_retrieval": False
+    },
+    "tester": {
+        "role": {
+            "role_prompt": "你是一个QA问答的助手,能够尽可能准确地回答问题,下面请逐步思考问题并回答",
+            "role_type": "ai",
+            "role_name": "tester",
+            "role_desc": "",
+            "agent_type": "BaseAgent"
+        },
+        "chat_turn": 1,
+        "do_search": False,
+        "do_doc_retrieval": False,
+        "do_tool_retrieval": False
+    }
+}
\ No newline at end of file
diff --git a/dev_opsgpt/connector/configs/chain_config.py b/dev_opsgpt/connector/configs/chain_config.py
new file mode 100644
index 0000000..5ed2605
--- /dev/null
+++ b/dev_opsgpt/connector/configs/chain_config.py
@@ -0,0 +1,88 @@
+
+
+CHAIN_CONFIGS = {
+    "chatChain": {
+        "chain_name": "chatChain",
+        "chain_type": "BaseChain",
+        "agents": ["answer"],
+        "chat_turn": 1,
+        "do_checker": False,
+        "clear_structure": "True",
+        "brainstorming": "False",
+        "gui_design": "True",
+        "git_management": "False",
+        "self_improve": "False"
+    },
+    "docChatChain": {
+        "chain_name": "docChatChain",
+        "chain_type": "BaseChain",
+        "agents": ["qaer"],
+        "chat_turn": 1,
+        "do_checker": False,
+        "clear_structure": "True",
+        "brainstorming": "False",
+        "gui_design": "True",
+        "git_management": "False",
+        "self_improve": "False"
+    },
+    "searchChatChain": {
+        "chain_name": "searchChatChain",
+        "chain_type": "BaseChain",
+        "agents": ["searcher"],
+        "chat_turn": 1,
+        "do_checker": False,
+        "clear_structure": "True",
+        "brainstorming": "False",
+        "gui_design": "True",
+        "git_management": "False",
+        "self_improve": "False"
+    },
+    "codeChatChain": {
+        "chain_name": "codeChatChain",
+        "chain_type": "BaseChain",
+        "agents": ["code_qaer"],
+        "chat_turn": 1,
+        "do_checker": False,
+        "clear_structure": "True",
+        "brainstorming": "False",
+        "gui_design": "True",
+        "git_management": "False",
+        "self_improve": "False"
+    },
+    "toolReactChain": {
+        "chain_name": "toolReactChain",
+        "chain_type": "BaseChain",
+        "agents": ["tool_planner", "tool_react"],
+        "chat_turn": 2,
+        "do_checker": True,
+        "clear_structure": "True",
+        "brainstorming": "False",
+        "gui_design": "True",
+        "git_management": "False",
+        "self_improve": "False"
+    },
+    "codeReactChain": {
+        "chain_name": "codeReactChain",
+        "chain_type": "BaseChain",
+        "agents": ["planner", "code_react"],
+        "chat_turn": 2,
+        "do_checker": True,
+        "clear_structure": "True",
+        "brainstorming": "False",
+        "gui_design": "True",
+        "git_management": "False",
+        "self_improve": "False"
+    },
+    "dataAnalystChain": {
+        "chain_name": "dataAnalystChain",
+        "chain_type": "BaseChain",
+        "agents": ["planner", "code_react"],
+        "chat_turn": 2,
+        "do_checker": True,
+        "clear_structure": "True",
+        "brainstorming": "False",
+        "gui_design": "True",
+        "git_management": "False",
+        "self_improve": "False"
+    },
+}
diff --git a/dev_opsgpt/connector/configs/phase_config.py b/dev_opsgpt/connector/configs/phase_config.py
new file mode 100644
index 0000000..30b9de3
--- /dev/null
+++ b/dev_opsgpt/connector/configs/phase_config.py
@@ -0,0 +1,79 @@
+PHASE_CONFIGS = {
+    "chatPhase": {
+        "phase_name": "chatPhase",
+        "phase_type": "BasePhase",
+        "chains": ["chatChain"],
+        "do_summary": False,
+        "do_search": False,
+        "do_doc_retrieval": False,
+        "do_code_retrieval": False,
+        "do_tool_retrieval": False,
+        "do_using_tool": False
+    },
+    "docChatPhase": {
+        "phase_name": "docChatPhase",
+        "phase_type": "BasePhase",
+        "chains": ["docChatChain"],
+        "do_summary": False,
+        "do_search": False,
+        "do_doc_retrieval": True,
+        "do_code_retrieval": False,
+        "do_tool_retrieval": False,
+        "do_using_tool": False
+    },
+    "searchChatPhase": {
+        "phase_name": "searchChatPhase",
+        "phase_type": "BasePhase",
+        "chains": ["searchChatChain"],
+        "do_summary": False,
+        "do_search": True,
+        "do_doc_retrieval": False,
+        "do_code_retrieval": False,
+        "do_tool_retrieval": False,
+        "do_using_tool": False
+    },
+    "codeChatPhase": {
+        "phase_name": "codeChatPhase",
+        "phase_type": "BasePhase",
+        "chains": ["codeChatChain"],
+        "do_summary": False,
+        "do_search": False,
+        "do_doc_retrieval": False,
+        "do_code_retrieval": True,
+        "do_tool_retrieval": False,
+        "do_using_tool": False
+    },
+    "toolReactPhase": {
+        "phase_name": "toolReactPhase",
+        "phase_type": "BasePhase",
+        "chains": ["toolReactChain"],
+        "do_summary": False,
+        "do_search": False,
+        "do_doc_retrieval": False,
+        "do_code_retrieval": False,
+        "do_tool_retrieval": False,
+        "do_using_tool": True
+    },
+    "codeReactPhase": {
+        "phase_name": "codeReactPhase",
+        "phase_type": "BasePhase",
+        "chains": ["codeReactChain"],
+        "do_summary": False,
+        "do_search": False,
+        "do_doc_retrieval": False,
+        "do_code_retrieval": False,
+        "do_tool_retrieval": False,
+        "do_using_tool": False
+    },
+    "dataReactPhase": {
+        "phase_name": "dataReactPhase",
+        "phase_type": "BasePhase",
+        "chains": ["dataAnalystChain"],
+        "do_summary": True,
+        "do_search": False,
+        "do_doc_retrieval": False,
+        "do_code_retrieval": False,
+        "do_tool_retrieval": False,
+        "do_using_tool": False
+    }
+}
diff --git a/dev_opsgpt/connector/connector_schema.py b/dev_opsgpt/connector/connector_schema.py
new file mode 100644
index 0000000..75c5387
--- /dev/null
+++ b/dev_opsgpt/connector/connector_schema.py
@@ -0,0 +1,248 @@
+from pydantic import BaseModel
+from typing import List, Dict
+from enum import Enum
+import re
+import json
+from loguru import logger
+from langchain.tools import BaseTool
+
+
+class ActionStatus(Enum):
+    FINISHED = "finished"
+    CODING = "coding"
+    TOOL_USING = "tool_using"
+    REASONING = "reasoning"
+    PLANNING = "planning"
+    EXECUTING_CODE = "executing_code"
+    EXECUTING_TOOL = "executing_tool"
+    DEFAUILT = "default"
+
+    def __eq__(self, other):
+        if isinstance(other, str):
+            return self.value == other
+        return super().__eq__(other)
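+# note: the __eq__ override above lets Message.action_status (a plain str field)
+# be compared directly against ActionStatus members, e.g.
+#   message.action_status == ActionStatus.FINISHED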
+
+class Doc(BaseModel):
+    title: str
+    snippet: str
+    link: str
+    index: int
+
+    def get_title(self):
+        return self.title
+
+    def get_snippet(self, ):
+        return self.snippet
+
+    def get_link(self, ):
+        return self.link
+
+    def get_index(self, ):
+        return self.index
+
+    def to_json(self):
+        return vars(self)
+
+    def __str__(self, ):
+        return f"""出处 [{self.index + 1}] 标题 [{self.title}]\n\n来源 ({self.link}) \n\n内容 {self.snippet}\n\n"""
+
+
+class CodeDoc(BaseModel):
+    code: str
+    related_nodes: list
+    index: int
+
+    def get_code(self, ):
+        return self.code
+
+    def get_related_node(self, ):
+        return self.related_nodes
+
+    def get_index(self, ):
+        return self.index
+
+    def to_json(self):
+        return vars(self)
+
+    def __str__(self, ):
+        return f"""出处 [{self.index + 1}] \n\n来源 ({self.related_nodes}) \n\n内容 {self.code}\n\n"""
+
+
+class Docs:
+
+    def __init__(self, docs: List[Doc]):
+        self.titles: List[str] = [doc.get_title() for doc in docs]
+        self.snippets: List[str] = [doc.get_snippet() for doc in docs]
+        self.links: List[str] = [doc.get_link() for doc in docs]
+        self.indexs: List[int] = [doc.get_index() for doc in docs]
+
+
+class Task(BaseModel):
+    task_type: str
+    task_name: str
+    task_desc: str
+    task_prompt: str
+    # def __init__(self, task_type, task_name, task_desc) -> None:
+    #     self.task_type = task_type
+    #     self.task_name = task_name
+    #     self.task_desc = task_desc
+
+
+class Env(BaseModel):
+    env_type: str
+    env_name: str
+    env_desc: str
+
+
+class Role(BaseModel):
+    role_type: str
+    role_name: str
+    role_desc: str
+    agent_type: str = ""
+    role_prompt: str = ""
+    template_prompt: str = ""
+
+
+class ChainConfig(BaseModel):
+    chain_name: str
+    chain_type: str
+    agents: List[str]
+    do_checker: bool = False
+    chat_turn: int = 1
+    clear_structure: bool = False
+    brainstorming: bool = False
+    gui_design: bool = True
+    git_management: bool = False
+    self_improve: bool = False
+
+
+class AgentConfig(BaseModel):
+    role: Role
+    chat_turn: int = 1
+    do_search: bool = False
+    do_doc_retrieval: bool = False
+    do_tool_retrieval: bool = False
+
+
+class PhaseConfig(BaseModel):
+    phase_name: str
+    phase_type: str
+    chains: List[str]
+    do_summary: bool = False
+    do_search: bool = False
+    do_doc_retrieval: bool = False
+    do_code_retrieval: bool = False
+    do_tool_retrieval: bool = False
+
+
+class Message(BaseModel):
+    role_name: str
+    role_type: str
+    role_prompt: str = None
+    input_query: str = None
+
+    # final model outputs
+    role_content: str = None
+    role_contents: List[str] = []
+    step_content: str = None
+    step_contents: List[str] = []
+    chain_content: str = None
+    chain_contents: List[str] = []
+
+    # parsed model outputs
+    plans: List[str] = None
+    code_content: str = None
+    code_filename: str = None
+    tool_params: str = None
+    tool_name: str = None
+
+    # execution results
+    action_status: str = ActionStatus.DEFAUILT
+    code_answer: str = None
+    tool_answer: str = None
+    observation: str = None
+    figures: Dict[str, str] = {}
+
+    # auxiliary information
+    tools: List[BaseTool] = []
+    task: Task = None
+    db_docs: List['Doc'] = []
+    code_docs: List['CodeDoc'] = []
+    search_docs: List['Doc'] = []
+
+    # execution inputs
+    phase_name: str = None
+    chain_name: str = None
+    do_search: bool = False
+    doc_engine_name: str = None
+    code_engine_name: str = None
+    search_engine_name: str = None
+    top_k: int = 3
+    score_threshold: float = 1.0
+    do_doc_retrieval: bool = False
+    do_code_retrieval: bool = False
+    do_tool_retrieval: bool = False
+    history_node_list: List[str] = []
+
+    def to_tuple_message(self, return_all: bool = False, content_key="role_content"):
+        if content_key == "role_content":
+            role_content = self.role_content
+        elif content_key == "step_content":
+            role_content = self.step_content or self.role_content
+        else:
+            role_content = self.role_content
+
+        if return_all:
+            return (self.role_name, self.role_type, role_content)
+        else:
+            return (self.role_name, role_content)
+
+    def to_dict_message(self, return_all: bool = False, content_key="role_content"):
+        if content_key == "role_content":
+            role_content = self.role_content
+        elif content_key == "step_content":
+            role_content = self.step_content or self.role_content
+        else:
+            role_content = self.role_content
+
+        if return_all:
+            return vars(self)
+        else:
+            return {"role": self.role_name, "content": role_content}
+
+    def is_system_role(self, ):
+        return self.role_type == "system"
+
+    def __str__(self) -> str:
+        # key_str = '\n'.join([k for k, v in vars(self).items()])
+        # logger.debug(f"{key_str}")
+        return "\n".join([": ".join([k, str(v)]) for k, v in vars(self).items()])
+
+
+def load_role_configs(config) -> Dict[str, AgentConfig]:
+    if isinstance(config, str):
+        with open(config, 'r', encoding="utf8") as file:
+            configs = json.load(file)
+    else:
+        configs = config
+
+    return {name: AgentConfig(**v) for name, v in configs.items()}
+
+
+def load_chain_configs(config) -> Dict[str, ChainConfig]:
+    if isinstance(config, str):
+        with open(config, 'r', encoding="utf8") as file:
+            configs = json.load(file)
+    else:
+        configs = config
+    return {name: ChainConfig(**v) for name, v in configs.items()}
+
+
+def load_phase_configs(config) -> Dict[str, PhaseConfig]:
+    if isinstance(config, str):
+        with open(config, 'r', encoding="utf8") as file:
+            configs = json.load(file)
+    else:
+        configs = config
+    return {name: PhaseConfig(**v) for name, v in configs.items()}
\ No newline at end of file
diff --git a/dev_opsgpt/connector/phase/__init__.py b/dev_opsgpt/connector/phase/__init__.py
new file mode 100644
index 0000000..73332dc
--- /dev/null
+++ b/dev_opsgpt/connector/phase/__init__.py
@@ -0,0 +1,3 @@
+from .base_phase import BasePhase
+
+__all__ = ["BasePhase"]
\ No newline at end of file
diff --git a/dev_opsgpt/connector/phase/base_phase.py b/dev_opsgpt/connector/phase/base_phase.py
new file mode 100644
index 0000000..cea23f4
--- /dev/null
+++ b/dev_opsgpt/connector/phase/base_phase.py
@@ -0,0 +1,215 @@
+from typing import List, Union, Dict, Tuple
+import os
+import json
+import importlib
+import copy
+from loguru import logger
+
+from dev_opsgpt.connector.agents import BaseAgent
+from dev_opsgpt.connector.chains import BaseChain
+from dev_opsgpt.tools.base_tool import BaseTools, Tool
+from dev_opsgpt.connector.shcema.memory import Memory
+from dev_opsgpt.connector.connector_schema import (
+    Task, Env, Role, Message, Doc, Docs, AgentConfig, ChainConfig, PhaseConfig, CodeDoc,
+    load_chain_configs, load_phase_configs, load_role_configs
+)
+from dev_opsgpt.connector.configs import AGETN_CONFIGS, CHAIN_CONFIGS, PHASE_CONFIGS
+from dev_opsgpt.tools import DDGSTool, DocRetrieval, CodeRetrieval
+
+
+role_configs = load_role_configs(AGETN_CONFIGS)
+chain_configs = load_chain_configs(CHAIN_CONFIGS)
+phase_configs = load_phase_configs(PHASE_CONFIGS)
+
+
+CUR_DIR = os.path.dirname(os.path.abspath(__file__))
+
+
+class BasePhase:
+
+    def __init__(
+            self,
+            phase_name: str,
+            task: Task = None,
+            do_summary: bool = False,
+            do_search: bool = False,
+            do_doc_retrieval: bool = False,
+            do_code_retrieval: bool = False,
+            do_tool_retrieval: bool = False,
+            phase_config: Union[dict, str] = PHASE_CONFIGS,
+            chain_config: Union[dict, str] = CHAIN_CONFIGS,
+            role_config: Union[dict, str] = AGETN_CONFIGS,
+            ) -> None:
+        self.conv_summary_agent = BaseAgent(role=role_configs["conv_summary"].role,
+            task=None,
+            memory=None,
+            do_search=role_configs["conv_summary"].do_search,
+            do_doc_retrieval=role_configs["conv_summary"].do_doc_retrieval,
+            do_tool_retrieval=role_configs["conv_summary"].do_tool_retrieval,
+            do_filter=False, do_use_self_memory=False)
+
+        self.chains: List[BaseChain] = self.init_chains(
+            phase_name,
+            task=task,
+            memory=None,
+            phase_config=phase_config,
+            chain_config=chain_config,
+            role_config=role_config,
+        )
+        self.phase_name = phase_name
+        self.do_summary = do_summary
+        self.do_search = do_search
+        self.do_code_retrieval = do_code_retrieval
+        self.do_doc_retrieval = do_doc_retrieval
+        self.do_tool_retrieval = do_tool_retrieval
+
+        self.global_message = Memory([])
+        # self.chain_message = Memory([])
+        self.phase_memory: List[Memory] = []
+
Message, history: Memory = None) -> Tuple[Message, Memory]:
+        summary_message = None
+        chain_message = Memory([])
+        local_memory = Memory([])
+        # do_search / do_doc_retrieval / do_code_retrieval
+        query = self.get_extrainfo_step(query)
+        input_message = copy.deepcopy(query)
+
+        self.global_message.append(input_message)
+        for chain in self.chains:
+            # a chain can supply background and query to the next chain
+            output_message, chain_memory = chain.step(input_message, history, background=chain_message)
+            output_message = self.inherit_extrainfo(input_message, output_message)
+            input_message = output_message
+            logger.info(f"{chain.chainConfig.chain_name} phase_step: {output_message.role_content}")
+
+            self.global_message.append(output_message)
+            local_memory.extend(chain_memory)
+
+            # whether to use the summary llm
+            if self.do_summary:
+                logger.info(f"{self.conv_summary_agent.role.role_name} input global memory: {self.global_message.to_str_messages(content_key='step_content')}")
+                logger.info(f"{self.conv_summary_agent.role.role_name} input global memory: {self.global_message.to_str_messages(content_key='role_content')}")
+                summary_message = self.conv_summary_agent.run(query, background=self.global_message)
+                summary_message.role_name = chain.chainConfig.chain_name
+                summary_message = self.conv_summary_agent.parser(summary_message)
+                summary_message = self.conv_summary_agent.filter(summary_message)
+                summary_message = self.inherit_extrainfo(output_message, summary_message)
+                chain_message.append(summary_message)
+
+        # chains only execute for a single round, so their memory can be kept as-is
+        for chain in self.chains:
+            self.phase_memory.append(chain.global_memory)
+
+        message = summary_message or output_message
+        message.role_name = self.phase_name
+        # message.db_docs = query.db_docs
+        # message.code_docs = query.code_docs
+        # message.search_docs = query.search_docs
+        return message, local_memory
+
+    def init_chains(self, phase_name, phase_config, chain_config,
+                    role_config, task=None, memory=None) -> List[BaseChain]:
+        # load the configs
+        role_configs = load_role_configs(role_config)
+        chain_configs = load_chain_configs(chain_config)
+        phase_configs = load_phase_configs(phase_config)
+
+        chains = []
+        self.chain_module = importlib.import_module("dev_opsgpt.connector.chains")
+        self.agent_module = importlib.import_module("dev_opsgpt.connector.agents")
+        phase = phase_configs.get(phase_name)
+        for chain_name in phase.chains:
+            logger.info(f"chain_name: {chain_name}")
+            # chain_class = getattr(self.chain_module, chain_name)
+            logger.debug(f"{chain_configs.keys()}")
+            chain_config = chain_configs[chain_name]
+
+            agents = [
+                getattr(self.agent_module, role_configs[agent_name].role.agent_type)(
+                    role_configs[agent_name].role,
+                    task=task,
+                    memory=memory,
+                    chat_turn=role_configs[agent_name].chat_turn,
+                    do_search=role_configs[agent_name].do_search,
+                    do_doc_retrieval=role_configs[agent_name].do_doc_retrieval,
+                    do_tool_retrieval=role_configs[agent_name].do_tool_retrieval,
+                    )
+                for agent_name in chain_config.agents
+            ]
+            chain_instance = BaseChain(
+                chain_config, agents, chain_config.chat_turn,
+                do_checker=chain_configs[chain_name].do_checker,
+                do_code_exec=False,)
+            chains.append(chain_instance)
+
+        return chains
+
+    def get_extrainfo_step(self, input_message):
+        if self.do_doc_retrieval:
+            input_message = self.get_doc_retrieval(input_message)
+
+        logger.debug(f"self.do_code_retrieval: {self.do_code_retrieval}")
+        if self.do_code_retrieval:
+            input_message = self.get_code_retrieval(input_message)
+
+        if self.do_search:
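+            # assumption: DDGSTool.run(query, top_k) yields dicts whose keys match
+            # the Doc schema fields (see get_search_retrieval below)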
input_message = self.get_search_retrieval(input_message)
+
+        return input_message
+
+    def inherit_extrainfo(self, input_message: Message, output_message: Message):
+        output_message.db_docs = input_message.db_docs
+        output_message.search_docs = input_message.search_docs
+        output_message.code_docs = input_message.code_docs
+        output_message.figures.update(input_message.figures)
+        return output_message
+
+    def get_search_retrieval(self, message: Message) -> Message:
+        SEARCH_ENGINES = {"duckduckgo": DDGSTool}
+        search_docs = []
+        for idx, doc in enumerate(SEARCH_ENGINES["duckduckgo"].run(message.role_content, 3)):
+            doc.update({"index": idx})
+            search_docs.append(Doc(**doc))
+        message.search_docs = search_docs
+        return message
+
+    def get_doc_retrieval(self, message: Message) -> Message:
+        query = message.role_content
+        knowledge_basename = message.doc_engine_name
+        top_k = message.top_k
+        score_threshold = message.score_threshold
+        if knowledge_basename:
+            docs = DocRetrieval.run(query, knowledge_basename, top_k, score_threshold)
+            message.db_docs = [Doc(**doc) for doc in docs]
+        return message
+
+    def get_code_retrieval(self, message: Message) -> Message:
+        query = message.input_query
+        code_engine_name = message.code_engine_name
+        history_node_list = message.history_node_list
+        code_docs = CodeRetrieval.run(code_engine_name, query, code_limit=message.top_k, history_node_list=history_node_list)
+        message.code_docs = [CodeDoc(**doc) for doc in code_docs]
+        return message
+
+    def get_tool_retrieval(self, message: Message) -> Message:
+        return message
+
+    def update(self) -> Memory:
+        pass
+
+    def get_memory(self, ) -> Memory:
+        return Memory.from_memory_list(
+            [chain.get_memory() for chain in self.chains]
+        )
+
+    def get_memory_str(self, do_all_memory=True, content_key="role_content") -> str:
+        # phase_memory is a List[Memory], so flatten it before reading messages
+        memory = self.global_message if do_all_memory else Memory.from_memory_list(self.phase_memory)
+        return "\n".join([": ".join(i) for i in memory.to_tuple_messages(content_key=content_key)])
+
+    def get_chains_memory(self, content_key="role_content") -> List[Tuple]:
+        return [memory.to_tuple_messages(content_key=content_key) for memory in self.phase_memory]
+
+    def get_chains_memory_str(self, content_key="role_content") -> str:
+        return "************".join([f"{chain.chainConfig.chain_name}\n" + chain.get_memory_str(content_key=content_key) for chain in self.chains])
\ No newline at end of file
diff --git a/dev_opsgpt/connector/shcema/__init__.py b/dev_opsgpt/connector/shcema/__init__.py
new file mode 100644
index 0000000..d523fc9
--- /dev/null
+++ b/dev_opsgpt/connector/shcema/__init__.py
@@ -0,0 +1,6 @@
+from .memory import Memory
+
+
+__all__ = [
+    "Memory"
+]
\ No newline at end of file
diff --git a/dev_opsgpt/connector/shcema/memory.py b/dev_opsgpt/connector/shcema/memory.py
new file mode 100644
index 0000000..f5e6a5f
--- /dev/null
+++ b/dev_opsgpt/connector/shcema/memory.py
@@ -0,0 +1,88 @@
+from pydantic import BaseModel
+from typing import List
+from loguru import logger
+
+from dev_opsgpt.connector.connector_schema import Message
+from dev_opsgpt.utils.common_utils import (
+    save_to_jsonl_file, save_to_json_file, read_json_file, read_jsonl_file
+)
+
+
+class Memory:
+
+    def __init__(self, messages: List[Message] = None):
+        # avoid the shared-mutable-default pitfall
+        self.messages = messages if messages is not None else []
+
+    def append(self, message: Message):
+        self.messages.append(message)
+
+    def extend(self, memory: 'Memory'):
+        self.messages.extend(memory.messages)
+
+    def update(self, role_name: str, role_type: str, role_content: str):
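+        # records a plain message whose step_content mirrors role_content; keyword
+        # arguments are used because Message is assumed to be a pydantic model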
self.messages.append(Message(role_name=role_name, role_type=role_type, role_content=role_content, step_content=role_content))
+
+    def clear(self, ):
+        self.messages = []
+
+    def delete(self, ):
+        pass
+
+    def get_messages(self, ) -> List[Message]:
+        return self.messages
+
+    def save(self, file_type="jsonl", return_all=True):
+        try:
+            if file_type == "jsonl":
+                save_to_jsonl_file(self.to_dict_messages(return_all=return_all), f"role_name_history.{file_type}")
+                return True
+            elif file_type in ["json", "txt"]:
+                save_to_json_file(self.to_dict_messages(return_all=return_all), f"role_name_history.{file_type}")
+                return True
+        except Exception:
+            return False
+        return False
+
+    def load(self, filepath):
+        # derive the format from the file extension instead of the whole path
+        file_type = filepath.split(".")[-1]
+        try:
+            if file_type == "jsonl":
+                self.messages = [Message(**message) for message in read_jsonl_file(filepath)]
+                return True
+            elif file_type in ["json", "txt"]:
+                self.messages = [Message(**message) for message in read_json_file(filepath)]
+                return True
+        except Exception:
+            return False
+
+        return False
+
+    def to_tuple_messages(self, return_system: bool = False, return_all: bool = False, content_key="role_content"):
+        # logger.debug(f"{[message.to_tuple_message(return_all, content_key) for message in self.messages ]}")
+        return [
+            message.to_tuple_message(return_all, content_key) for message in self.messages
+            if return_system or not message.is_system_role()
+        ]
+
+    def to_dict_messages(self, return_system: bool = False, return_all: bool = False, content_key="role_content"):
+        return [
+            message.to_dict_message(return_all, content_key) for message in self.messages
+            if return_system or not message.is_system_role()
+        ]
+
+    def to_str_messages(self, return_system: bool = False, return_all: bool = False, content_key="role_content"):
+        # logger.debug(f"{[message.to_tuple_message(return_all, content_key) for message in self.messages ]}")
+        return "\n".join([
+            ": ".join(message.to_tuple_message(return_all, content_key)) for message in self.messages
+            if return_system or not message.is_system_role()
+        ])
+
+    @classmethod
+    def from_memory_list(cls, memorys: List['Memory']) -> 'Memory':
+        return cls([message for memory in memorys for message in memory.get_messages()])
+
+    def __len__(self, ):
+        return len(self.messages)
+
+    def __str__(self) -> str:
+        return "\n".join([":".join(i) for i in self.to_tuple_messages()])
\ No newline at end of file
diff --git a/dev_opsgpt/connector/utils.py b/dev_opsgpt/connector/utils.py
new file mode 100644
index 0000000..9684505
--- /dev/null
+++ b/dev_opsgpt/connector/utils.py
@@ -0,0 +1,27 @@
+
+
+def prompt_cost(model_type: str, num_prompt_tokens: float, num_completion_tokens: float):
+    """Estimate the USD cost of one call from its token counts (prices are per 1K tokens);
+    returns -1 for unknown model types, e.g. prompt_cost("gpt-3.5-turbo", 1000, 500) -> 0.0025."""
+    input_cost_map = {
+        "gpt-3.5-turbo": 0.0015,
+        "gpt-3.5-turbo-16k": 0.003,
+        "gpt-3.5-turbo-0613": 0.0015,
+        "gpt-3.5-turbo-16k-0613": 0.003,
+        "gpt-4": 0.03,
+        "gpt-4-0613": 0.03,
+        "gpt-4-32k": 0.06,
+    }
+
+    output_cost_map = {
+        "gpt-3.5-turbo": 0.002,
+        "gpt-3.5-turbo-16k": 0.004,
+        "gpt-3.5-turbo-0613": 0.002,
+        "gpt-3.5-turbo-16k-0613": 0.004,
+        "gpt-4": 0.06,
+        "gpt-4-0613": 0.06,
+        "gpt-4-32k": 0.12,
+    }
+
+    if model_type not in input_cost_map or model_type not in output_cost_map:
+        return -1
+
+    return num_prompt_tokens * input_cost_map[model_type] / 1000.0 + num_completion_tokens * output_cost_map[model_type] / 1000.0
diff --git a/dev_opsgpt/document_loaders/json_loader.py b/dev_opsgpt/document_loaders/json_loader.py
index f5d9a6a..e931f60 100644
--- a/dev_opsgpt/document_loaders/json_loader.py
+++ b/dev_opsgpt/document_loaders/json_loader.py
@@ -4,6 +4,7 @@ from typing import AnyStr, Callable, Dict, List, 
Optional, Union from langchain.docstore.document import Document from langchain.document_loaders.base import BaseLoader +from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter from dev_opsgpt.utils.common_utils import read_json_file @@ -39,3 +40,22 @@ class JSONLoader(BaseLoader): ) text = sample.get(self.schema_key, "") docs.append(Document(page_content=text, metadata=metadata)) + + def load_and_split( + self, text_splitter: Optional[TextSplitter] = None + ) -> List[Document]: + """Load Documents and split into chunks. Chunks are returned as Documents. + + Args: + text_splitter: TextSplitter instance to use for splitting documents. + Defaults to RecursiveCharacterTextSplitter. + + Returns: + List of Documents. + """ + if text_splitter is None: + _text_splitter: TextSplitter = RecursiveCharacterTextSplitter() + else: + _text_splitter = text_splitter + docs = self.load() + return _text_splitter.split_documents(docs) \ No newline at end of file diff --git a/dev_opsgpt/document_loaders/jsonl_loader.py b/dev_opsgpt/document_loaders/jsonl_loader.py index 0f9de8e..5a50d36 100644 --- a/dev_opsgpt/document_loaders/jsonl_loader.py +++ b/dev_opsgpt/document_loaders/jsonl_loader.py @@ -4,6 +4,7 @@ from typing import AnyStr, Callable, Dict, List, Optional, Union from langchain.docstore.document import Document from langchain.document_loaders.base import BaseLoader +from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter from dev_opsgpt.utils.common_utils import read_jsonl_file @@ -39,3 +40,23 @@ class JSONLLoader(BaseLoader): ) text = sample.get(self.schema_key, "") docs.append(Document(page_content=text, metadata=metadata)) + + def load_and_split( + self, text_splitter: Optional[TextSplitter] = None + ) -> List[Document]: + """Load Documents and split into chunks. Chunks are returned as Documents. + + Args: + text_splitter: TextSplitter instance to use for splitting documents. + Defaults to RecursiveCharacterTextSplitter. + + Returns: + List of Documents. + """ + if text_splitter is None: + _text_splitter: TextSplitter = RecursiveCharacterTextSplitter() + else: + _text_splitter = text_splitter + + docs = self.load() + return _text_splitter.split_documents(docs) \ No newline at end of file diff --git a/dev_opsgpt/embeddings/faiss_m.py b/dev_opsgpt/embeddings/faiss_m.py new file mode 100644 index 0000000..505f845 --- /dev/null +++ b/dev_opsgpt/embeddings/faiss_m.py @@ -0,0 +1,776 @@ +"""Wrapper around FAISS vector database.""" +from __future__ import annotations + +import operator +import os +import pickle +import uuid +import warnings +from pathlib import Path +from typing import ( + Any, + Callable, + Dict, + Iterable, + List, + Optional, + Sized, + Tuple, +) + +import numpy as np + +from langchain.docstore.base import AddableMixin, Docstore +from langchain.docstore.document import Document +from langchain.docstore.in_memory import InMemoryDocstore +from langchain.embeddings.base import Embeddings +from langchain.vectorstores.base import VectorStore +from langchain.vectorstores.utils import DistanceStrategy, maximal_marginal_relevance + + +def dependable_faiss_import(no_avx2: Optional[bool] = None) -> Any: + """ + Import faiss if available, otherwise raise error. + If FAISS_NO_AVX2 environment variable is set, it will be considered + to load FAISS with no AVX2 optimization. + + Args: + no_avx2: Load FAISS strictly with no AVX2 optimization + so that the vectorstore is portable and compatible with other devices. 
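+            (the flag is parsed with bool() on the raw string, so any non-empty
+            FAISS_NO_AVX2 value, including "0", selects the no-AVX2 import path)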
+ """ + if no_avx2 is None and "FAISS_NO_AVX2" in os.environ: + no_avx2 = bool(os.getenv("FAISS_NO_AVX2")) + + try: + if no_avx2: + from faiss import swigfaiss as faiss + else: + import faiss + except ImportError: + raise ImportError( + "Could not import faiss python package. " + "Please install it with `pip install faiss-gpu` (for CUDA supported GPU) " + "or `pip install faiss-cpu` (depending on Python version)." + ) + return faiss + + +def _len_check_if_sized(x: Any, y: Any, x_name: str, y_name: str) -> None: + if isinstance(x, Sized) and isinstance(y, Sized) and len(x) != len(y): + raise ValueError( + f"{x_name} and {y_name} expected to be equal length but " + f"len({x_name})={len(x)} and len({y_name})={len(y)}" + ) + return + + +class FAISS(VectorStore): + """Wrapper around FAISS vector database. + + To use, you must have the ``faiss`` python package installed. + + Example: + .. code-block:: python + + from langchain.embeddings.openai import OpenAIEmbeddings + from langchain.vectorstores import FAISS + + embeddings = OpenAIEmbeddings() + texts = ["FAISS is an important library", "LangChain supports FAISS"] + faiss = FAISS.from_texts(texts, embeddings) + + """ + + def __init__( + self, + embedding_function: Callable, + index: Any, + docstore: Docstore, + index_to_docstore_id: Dict[int, str], + relevance_score_fn: Optional[Callable[[float], float]] = None, + normalize_L2: bool = False, + distance_strategy: DistanceStrategy = DistanceStrategy.EUCLIDEAN_DISTANCE, + ): + """Initialize with necessary components.""" + self.embedding_function = embedding_function + self.index = index + self.docstore = docstore + self.index_to_docstore_id = index_to_docstore_id + self.distance_strategy = distance_strategy + self.override_relevance_score_fn = relevance_score_fn + self._normalize_L2 = normalize_L2 + if ( + self.distance_strategy != DistanceStrategy.EUCLIDEAN_DISTANCE + and self._normalize_L2 + ): + warnings.warn( + "Normalizing L2 is not applicable for metric type: {strategy}".format( + strategy=self.distance_strategy + ) + ) + + def __add( + self, + texts: Iterable[str], + embeddings: Iterable[List[float]], + metadatas: Optional[Iterable[dict]] = None, + ids: Optional[List[str]] = None, + ) -> List[str]: + faiss = dependable_faiss_import() + + if not isinstance(self.docstore, AddableMixin): + raise ValueError( + "If trying to add texts, the underlying docstore should support " + f"adding items, which {self.docstore} does not" + ) + + _len_check_if_sized(texts, metadatas, "texts", "metadatas") + _metadatas = metadatas or ({} for _ in texts) + documents = [ + Document(page_content=t, metadata=m) for t, m in zip(texts, _metadatas) + ] + + _len_check_if_sized(documents, embeddings, "documents", "embeddings") + _len_check_if_sized(documents, ids, "documents", "ids") + + # Add to the index. + vector = np.array(embeddings, dtype=np.float32) + if self._normalize_L2: + faiss.normalize_L2(vector) + self.index.add(vector) + + # Add information to docstore and index. + ids = ids or [str(uuid.uuid4()) for _ in texts] + self.docstore.add({id_: doc for id_, doc in zip(ids, documents)}) + starting_len = len(self.index_to_docstore_id) + index_to_id = {starting_len + j: id_ for j, id_ in enumerate(ids)} + self.index_to_docstore_id.update(index_to_id) + return ids + + def add_texts( + self, + texts: Iterable[str], + metadatas: Optional[List[dict]] = None, + ids: Optional[List[str]] = None, + **kwargs: Any, + ) -> List[str]: + """Run more texts through the embeddings and add to the vectorstore. 
+ + Args: + texts: Iterable of strings to add to the vectorstore. + metadatas: Optional list of metadatas associated with the texts. + ids: Optional list of unique IDs. + + Returns: + List of ids from adding the texts into the vectorstore. + """ + # embeddings = [self.embedding_function(text) for text in texts] + embeddings = self.embedding_function(texts) + return self.__add(texts, embeddings, metadatas=metadatas, ids=ids) + + def add_embeddings( + self, + text_embeddings: Iterable[Tuple[str, List[float]]], + metadatas: Optional[List[dict]] = None, + ids: Optional[List[str]] = None, + **kwargs: Any, + ) -> List[str]: + """Run more texts through the embeddings and add to the vectorstore. + + Args: + text_embeddings: Iterable pairs of string and embedding to + add to the vectorstore. + metadatas: Optional list of metadatas associated with the texts. + ids: Optional list of unique IDs. + + Returns: + List of ids from adding the texts into the vectorstore. + """ + # Embed and create the documents. + texts, embeddings = zip(*text_embeddings) + return self.__add(texts, embeddings, metadatas=metadatas, ids=ids) + + def similarity_search_with_score_by_vector( + self, + embedding: List[float], + k: int = 4, + filter: Optional[Dict[str, Any]] = None, + fetch_k: int = 20, + **kwargs: Any, + ) -> List[Tuple[Document, float]]: + """Return docs most similar to query. + + Args: + embedding: Embedding vector to look up documents similar to. + k: Number of Documents to return. Defaults to 4. + filter (Optional[Dict[str, Any]]): Filter by metadata. Defaults to None. + fetch_k: (Optional[int]) Number of Documents to fetch before filtering. + Defaults to 20. + **kwargs: kwargs to be passed to similarity search. Can include: + score_threshold: Optional, a floating point value between 0 to 1 to + filter the resulting set of retrieved docs + + Returns: + List of documents most similar to the query text and L2 distance + in float for each. Lower score represents more similarity. + """ + faiss = dependable_faiss_import() + vector = np.array([embedding], dtype=np.float32) + if self._normalize_L2: + faiss.normalize_L2(vector) + scores, indices = self.index.search(vector, k if filter is None else fetch_k) + docs = [] + for j, i in enumerate(indices[0]): + if i == -1: + # This happens when not enough docs are returned. + continue + _id = self.index_to_docstore_id[i] + doc = self.docstore.search(_id) + if not isinstance(doc, Document): + raise ValueError(f"Could not find document for id {_id}, got {doc}") + if filter is not None: + filter = { + key: [value] if not isinstance(value, list) else value + for key, value in filter.items() + } + if all(doc.metadata.get(key) in value for key, value in filter.items()): + docs.append((doc, scores[0][j])) + else: + docs.append((doc, scores[0][j])) + + score_threshold = kwargs.get("score_threshold") + if score_threshold is not None: + cmp = ( + operator.ge + if self.distance_strategy + in (DistanceStrategy.MAX_INNER_PRODUCT, DistanceStrategy.JACCARD) + else operator.le + ) + docs = [ + (doc, similarity) + for doc, similarity in docs + if cmp(similarity, score_threshold) + ] + return docs[:k] + + def similarity_search_with_score( + self, + query: str, + k: int = 4, + filter: Optional[Dict[str, Any]] = None, + fetch_k: int = 20, + **kwargs: Any, + ) -> List[Tuple[Document, float]]: + """Return docs most similar to query. + + Args: + query: Text to look up documents similar to. + k: Number of Documents to return. Defaults to 4. 
+ filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None. + fetch_k: (Optional[int]) Number of Documents to fetch before filtering. + Defaults to 20. + + Returns: + List of documents most similar to the query text with + L2 distance in float. Lower score represents more similarity. + """ + embedding = self.embedding_function(query) + docs = self.similarity_search_with_score_by_vector( + embedding, + k, + filter=filter, + fetch_k=fetch_k, + **kwargs, + ) + return docs + + def similarity_search_by_vector( + self, + embedding: List[float], + k: int = 4, + filter: Optional[Dict[str, Any]] = None, + fetch_k: int = 20, + **kwargs: Any, + ) -> List[Document]: + """Return docs most similar to embedding vector. + + Args: + embedding: Embedding to look up documents similar to. + k: Number of Documents to return. Defaults to 4. + filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None. + fetch_k: (Optional[int]) Number of Documents to fetch before filtering. + Defaults to 20. + + Returns: + List of Documents most similar to the embedding. + """ + docs_and_scores = self.similarity_search_with_score_by_vector( + embedding, + k, + filter=filter, + fetch_k=fetch_k, + **kwargs, + ) + return [doc for doc, _ in docs_and_scores] + + def similarity_search( + self, + query: str, + k: int = 4, + filter: Optional[Dict[str, Any]] = None, + fetch_k: int = 20, + **kwargs: Any, + ) -> List[Document]: + """Return docs most similar to query. + + Args: + query: Text to look up documents similar to. + k: Number of Documents to return. Defaults to 4. + filter: (Optional[Dict[str, str]]): Filter by metadata. Defaults to None. + fetch_k: (Optional[int]) Number of Documents to fetch before filtering. + Defaults to 20. + + Returns: + List of Documents most similar to the query. + """ + docs_and_scores = self.similarity_search_with_score( + query, k, filter=filter, fetch_k=fetch_k, **kwargs + ) + return [doc for doc, _ in docs_and_scores] + + def max_marginal_relevance_search_with_score_by_vector( + self, + embedding: List[float], + *, + k: int = 4, + fetch_k: int = 20, + lambda_mult: float = 0.5, + filter: Optional[Dict[str, Any]] = None, + ) -> List[Tuple[Document, float]]: + """Return docs and their similarity scores selected using the maximal marginal + relevance. + + Maximal marginal relevance optimizes for similarity to query AND diversity + among selected documents. + + Args: + embedding: Embedding to look up documents similar to. + k: Number of Documents to return. Defaults to 4. + fetch_k: Number of Documents to fetch before filtering to + pass to MMR algorithm. + lambda_mult: Number between 0 and 1 that determines the degree + of diversity among the results with 0 corresponding + to maximum diversity and 1 to minimum diversity. + Defaults to 0.5. + Returns: + List of Documents and similarity scores selected by maximal marginal + relevance and score for each. + """ + scores, indices = self.index.search( + np.array([embedding], dtype=np.float32), + fetch_k if filter is None else fetch_k * 2, + ) + if filter is not None: + filtered_indices = [] + for i in indices[0]: + if i == -1: + # This happens when not enough docs are returned. 
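+                    # (FAISS pads `indices` with -1 when the index holds fewer
+                    # vectors than were requested)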
+ continue + _id = self.index_to_docstore_id[i] + doc = self.docstore.search(_id) + if not isinstance(doc, Document): + raise ValueError(f"Could not find document for id {_id}, got {doc}") + if all( + doc.metadata.get(key) in value + if isinstance(value, list) + else doc.metadata.get(key) == value + for key, value in filter.items() + ): + filtered_indices.append(i) + indices = np.array([filtered_indices]) + # -1 happens when not enough docs are returned. + embeddings = [self.index.reconstruct(int(i)) for i in indices[0] if i != -1] + mmr_selected = maximal_marginal_relevance( + np.array([embedding], dtype=np.float32), + embeddings, + k=k, + lambda_mult=lambda_mult, + ) + selected_indices = [indices[0][i] for i in mmr_selected] + selected_scores = [scores[0][i] for i in mmr_selected] + docs_and_scores = [] + for i, score in zip(selected_indices, selected_scores): + if i == -1: + # This happens when not enough docs are returned. + continue + _id = self.index_to_docstore_id[i] + doc = self.docstore.search(_id) + if not isinstance(doc, Document): + raise ValueError(f"Could not find document for id {_id}, got {doc}") + docs_and_scores.append((doc, score)) + return docs_and_scores + + def max_marginal_relevance_search_by_vector( + self, + embedding: List[float], + k: int = 4, + fetch_k: int = 20, + lambda_mult: float = 0.5, + filter: Optional[Dict[str, Any]] = None, + **kwargs: Any, + ) -> List[Document]: + """Return docs selected using the maximal marginal relevance. + + Maximal marginal relevance optimizes for similarity to query AND diversity + among selected documents. + + Args: + embedding: Embedding to look up documents similar to. + k: Number of Documents to return. Defaults to 4. + fetch_k: Number of Documents to fetch before filtering to + pass to MMR algorithm. + lambda_mult: Number between 0 and 1 that determines the degree + of diversity among the results with 0 corresponding + to maximum diversity and 1 to minimum diversity. + Defaults to 0.5. + Returns: + List of Documents selected by maximal marginal relevance. + """ + docs_and_scores = self.max_marginal_relevance_search_with_score_by_vector( + embedding, k=k, fetch_k=fetch_k, lambda_mult=lambda_mult, filter=filter + ) + return [doc for doc, _ in docs_and_scores] + + def max_marginal_relevance_search( + self, + query: str, + k: int = 4, + fetch_k: int = 20, + lambda_mult: float = 0.5, + filter: Optional[Dict[str, Any]] = None, + **kwargs: Any, + ) -> List[Document]: + """Return docs selected using the maximal marginal relevance. + + Maximal marginal relevance optimizes for similarity to query AND diversity + among selected documents. + + Args: + query: Text to look up documents similar to. + k: Number of Documents to return. Defaults to 4. + fetch_k: Number of Documents to fetch before filtering (if needed) to + pass to MMR algorithm. + lambda_mult: Number between 0 and 1 that determines the degree + of diversity among the results with 0 corresponding + to maximum diversity and 1 to minimum diversity. + Defaults to 0.5. + Returns: + List of Documents selected by maximal marginal relevance. + """ + embedding = self.embedding_function(query) + docs = self.max_marginal_relevance_search_by_vector( + embedding, + k=k, + fetch_k=fetch_k, + lambda_mult=lambda_mult, + filter=filter, + **kwargs, + ) + return docs + + def delete(self, ids: Optional[List[str]] = None, **kwargs: Any) -> Optional[bool]: + """Delete by ID. These are the IDs in the vectorstore. + + Args: + ids: List of ids to delete. 
+ + Returns: + Optional[bool]: True if deletion is successful, + False otherwise, None if not implemented. + """ + if ids is None: + raise ValueError("No ids provided to delete.") + missing_ids = set(ids).difference(self.index_to_docstore_id.values()) + if missing_ids: + raise ValueError( + f"Some specified ids do not exist in the current store. Ids not found: " + f"{missing_ids}" + ) + + reversed_index = {id_: idx for idx, id_ in self.index_to_docstore_id.items()} + index_to_delete = [reversed_index[id_] for id_ in ids] + + self.index.remove_ids(np.array(index_to_delete, dtype=np.int64)) + self.docstore.delete(ids) + + remaining_ids = [ + id_ + for i, id_ in sorted(self.index_to_docstore_id.items()) + if i not in index_to_delete + ] + self.index_to_docstore_id = {i: id_ for i, id_ in enumerate(remaining_ids)} + + return True + + def merge_from(self, target: FAISS) -> None: + """Merge another FAISS object with the current one. + + Add the target FAISS to the current one. + + Args: + target: FAISS object you wish to merge into the current one + + Returns: + None. + """ + if not isinstance(self.docstore, AddableMixin): + raise ValueError("Cannot merge with this type of docstore") + # Numerical index for target docs are incremental on existing ones + starting_len = len(self.index_to_docstore_id) + + # Merge two IndexFlatL2 + self.index.merge_from(target.index) + + # Get id and docs from target FAISS object + full_info = [] + for i, target_id in target.index_to_docstore_id.items(): + doc = target.docstore.search(target_id) + if not isinstance(doc, Document): + raise ValueError("Document should be returned") + full_info.append((starting_len + i, target_id, doc)) + + # Add information to docstore and index_to_docstore_id. + self.docstore.add({_id: doc for _, _id, doc in full_info}) + index_to_id = {index: _id for index, _id, _ in full_info} + self.index_to_docstore_id.update(index_to_id) + + @classmethod + def __from( + cls, + texts: Iterable[str], + embeddings: List[List[float]], + embedding: Embeddings, + metadatas: Optional[Iterable[dict]] = None, + ids: Optional[List[str]] = None, + normalize_L2: bool = False, + distance_strategy: DistanceStrategy = DistanceStrategy.EUCLIDEAN_DISTANCE, + **kwargs: Any, + ) -> FAISS: + faiss = dependable_faiss_import() + if distance_strategy == DistanceStrategy.MAX_INNER_PRODUCT: + index = faiss.IndexFlatIP(len(embeddings[0])) + else: + # Default to L2, currently other metric types not initialized. + index = faiss.IndexFlatL2(len(embeddings[0])) + vecstore = cls( + embedding.embed_query, + index, + InMemoryDocstore(), + {}, + normalize_L2=normalize_L2, + distance_strategy=distance_strategy, + **kwargs, + ) + vecstore.__add(texts, embeddings, metadatas=metadatas, ids=ids) + return vecstore + + @classmethod + def from_texts( + cls, + texts: List[str], + embedding: Embeddings, + metadatas: Optional[List[dict]] = None, + ids: Optional[List[str]] = None, + **kwargs: Any, + ) -> FAISS: + """Construct FAISS wrapper from raw documents. + + This is a user friendly interface that: + 1. Embeds documents. + 2. Creates an in memory docstore + 3. Initializes the FAISS database + + This is intended to be a quick way to get started. + + Example: + .. 
code-block:: python + + from langchain import FAISS + from langchain.embeddings import OpenAIEmbeddings + + embeddings = OpenAIEmbeddings() + faiss = FAISS.from_texts(texts, embeddings) + """ + from loguru import logger + logger.debug(f"texts: {len(texts)}") + embeddings = embedding.embed_documents(texts) + return cls.__from( + texts, + embeddings, + embedding, + metadatas=metadatas, + ids=ids, + **kwargs, + ) + + @classmethod + def from_embeddings( + cls, + text_embeddings: Iterable[Tuple[str, List[float]]], + embedding: Embeddings, + metadatas: Optional[Iterable[dict]] = None, + ids: Optional[List[str]] = None, + **kwargs: Any, + ) -> FAISS: + """Construct FAISS wrapper from raw documents. + + This is a user friendly interface that: + 1. Embeds documents. + 2. Creates an in memory docstore + 3. Initializes the FAISS database + + This is intended to be a quick way to get started. + + Example: + .. code-block:: python + + from langchain import FAISS + from langchain.embeddings import OpenAIEmbeddings + + embeddings = OpenAIEmbeddings() + text_embeddings = embeddings.embed_documents(texts) + text_embedding_pairs = zip(texts, text_embeddings) + faiss = FAISS.from_embeddings(text_embedding_pairs, embeddings) + """ + texts = [t[0] for t in text_embeddings] + embeddings = [t[1] for t in text_embeddings] + return cls.__from( + texts, + embeddings, + embedding, + metadatas=metadatas, + ids=ids, + **kwargs, + ) + + def save_local(self, folder_path: str, index_name: str = "index") -> None: + """Save FAISS index, docstore, and index_to_docstore_id to disk. + + Args: + folder_path: folder path to save index, docstore, + and index_to_docstore_id to. + index_name: for saving with a specific index file name + """ + path = Path(folder_path) + path.mkdir(exist_ok=True, parents=True) + + # save index separately since it is not picklable + faiss = dependable_faiss_import() + faiss.write_index( + self.index, str(path / "{index_name}.faiss".format(index_name=index_name)) + ) + + # save docstore and index_to_docstore_id + with open(path / "{index_name}.pkl".format(index_name=index_name), "wb") as f: + pickle.dump((self.docstore, self.index_to_docstore_id), f) + + @classmethod + def load_local( + cls, + folder_path: str, + embeddings: Embeddings, + index_name: str = "index", + **kwargs: Any, + ) -> FAISS: + """Load FAISS index, docstore, and index_to_docstore_id from disk. + + Args: + folder_path: folder path to load index, docstore, + and index_to_docstore_id from. 
+ embeddings: Embeddings to use when generating queries + index_name: for saving with a specific index file name + """ + path = Path(folder_path) + # load index separately since it is not picklable + faiss = dependable_faiss_import() + index = faiss.read_index( + str(path / "{index_name}.faiss".format(index_name=index_name)) + ) + + # load docstore and index_to_docstore_id + with open(path / "{index_name}.pkl".format(index_name=index_name), "rb") as f: + docstore, index_to_docstore_id = pickle.load(f) + return cls( + embeddings.embed_query, index, docstore, index_to_docstore_id, **kwargs + ) + + def serialize_to_bytes(self) -> bytes: + """Serialize FAISS index, docstore, and index_to_docstore_id to bytes.""" + return pickle.dumps((self.index, self.docstore, self.index_to_docstore_id)) + + @classmethod + def deserialize_from_bytes( + cls, + serialized: bytes, + embeddings: Embeddings, + **kwargs: Any, + ) -> FAISS: + """Deserialize FAISS index, docstore, and index_to_docstore_id from bytes.""" + index, docstore, index_to_docstore_id = pickle.loads(serialized) + return cls( + embeddings.embed_query, index, docstore, index_to_docstore_id, **kwargs + ) + + def _select_relevance_score_fn(self) -> Callable[[float], float]: + """ + The 'correct' relevance function + may differ depending on a few things, including: + - the distance / similarity metric used by the VectorStore + - the scale of your embeddings (OpenAI's are unit normed. Many others are not!) + - embedding dimensionality + - etc. + """ + if self.override_relevance_score_fn is not None: + return self.override_relevance_score_fn + + # Default strategy is to rely on distance strategy provided in + # vectorstore constructor + if self.distance_strategy == DistanceStrategy.MAX_INNER_PRODUCT: + return self._max_inner_product_relevance_score_fn + elif self.distance_strategy == DistanceStrategy.EUCLIDEAN_DISTANCE: + # Default behavior is to use euclidean distance relevancy + return self._euclidean_relevance_score_fn + else: + raise ValueError( + "Unknown distance strategy, must be cosine, max_inner_product," + " or euclidean" + ) + + def _similarity_search_with_relevance_scores( + self, + query: str, + k: int = 4, + filter: Optional[Dict[str, Any]] = None, + fetch_k: int = 20, + **kwargs: Any, + ) -> List[Tuple[Document, float]]: + """Return docs and their similarity scores on a scale from 0 to 1.""" + # Pop score threshold so that only relevancy scores, not raw scores, are + # filtered. 
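+        # note: in this vendored copy the kwargs (including any score_threshold)
+        # are forwarded unchanged to similarity_search_with_score, which applies
+        # the threshold to the raw scores before they are rescaled below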
+        relevance_score_fn = self._select_relevance_score_fn()
+        if relevance_score_fn is None:
+            raise ValueError(
+                "normalize_score_fn must be provided to"
+                " FAISS constructor to normalize scores"
+            )
+        docs_and_scores = self.similarity_search_with_score(
+            query,
+            k=k,
+            filter=filter,
+            fetch_k=fetch_k,
+            **kwargs,
+        )
+        docs_and_rel_scores = [
+            (doc, relevance_score_fn(score)) for doc, score in docs_and_scores
+        ]
+        return docs_and_rel_scores
diff --git a/dev_opsgpt/llm_models/__init__.py b/dev_opsgpt/llm_models/__init__.py
new file mode 100644
index 0000000..b398434
--- /dev/null
+++ b/dev_opsgpt/llm_models/__init__.py
@@ -0,0 +1,6 @@
+from .openai_model import getChatModel
+
+
+__all__ = [
+    "getChatModel"
+]
\ No newline at end of file
diff --git a/dev_opsgpt/llm_models/openai_model.py b/dev_opsgpt/llm_models/openai_model.py
new file mode 100644
index 0000000..ffd4e99
--- /dev/null
+++ b/dev_opsgpt/llm_models/openai_model.py
@@ -0,0 +1,29 @@
+from langchain.callbacks import AsyncIteratorCallbackHandler
+from langchain.chat_models import ChatOpenAI
+
+from configs.model_config import (llm_model_dict, LLM_MODEL)
+
+
+def getChatModel(callBack: AsyncIteratorCallbackHandler = None, temperature=0.3, stop=None):
+    if callBack is None:
+        model = ChatOpenAI(
+            streaming=True,
+            verbose=True,
+            openai_api_key=llm_model_dict[LLM_MODEL]["api_key"],
+            openai_api_base=llm_model_dict[LLM_MODEL]["api_base_url"],
+            model_name=LLM_MODEL,
+            temperature=temperature,
+            stop=stop
+        )
+    else:
+        model = ChatOpenAI(
+            streaming=True,
+            verbose=True,
+            callbacks=[callBack],
+            openai_api_key=llm_model_dict[LLM_MODEL]["api_key"],
+            openai_api_base=llm_model_dict[LLM_MODEL]["api_base_url"],
+            model_name=LLM_MODEL,
+            temperature=temperature,
+            stop=stop
+        )
+    return model
\ No newline at end of file
diff --git a/dev_opsgpt/orm/__init__.py b/dev_opsgpt/orm/__init__.py
index dde0ce6..2a2c21b 100644
--- a/dev_opsgpt/orm/__init__.py
+++ b/dev_opsgpt/orm/__init__.py
@@ -18,5 +18,6 @@ def check_tables_exist(table_name) -> bool:
     return table_exist
 
 def table_init():
-    if (not check_tables_exist("knowledge_base")) or (not check_tables_exist ("knowledge_file")):
+    if (not check_tables_exist("knowledge_base")) or (not check_tables_exist ("knowledge_file")) or \
+        (not check_tables_exist ("code_base")):
         create_tables()
diff --git a/dev_opsgpt/orm/commands/__init__.py b/dev_opsgpt/orm/commands/__init__.py
index a14dd90..dfda957 100644
--- a/dev_opsgpt/orm/commands/__init__.py
+++ b/dev_opsgpt/orm/commands/__init__.py
@@ -1,5 +1,6 @@
 from .document_file_cds import *
 from .document_base_cds import *
+from .code_base_cds import *
 
 __all__ = [
     "add_kb_to_db", "list_kbs_from_db", "kb_exists",
@@ -7,4 +8,7 @@ __all__ = [
     "list_docs_from_db", "add_doc_to_db", "delete_file_from_db",
     "delete_files_from_db", "doc_exists", "get_file_detail",
+
+    "list_cbs_from_db", "add_cb_to_db", "delete_cb_from_db",
+    "cb_exists", "get_cb_detail",
 ]
\ No newline at end of file
diff --git a/dev_opsgpt/orm/commands/code_base_cds.py b/dev_opsgpt/orm/commands/code_base_cds.py
new file mode 100644
index 0000000..939d1d5
--- /dev/null
+++ b/dev_opsgpt/orm/commands/code_base_cds.py
@@ -0,0 +1,79 @@
+# encoding: utf-8
+'''
+@author: 温进
+@file: code_base_cds.py
+@time: 2023/10/23 4:34 PM
+@desc:
+'''
+from loguru import logger
+from dev_opsgpt.orm.db import with_session, _engine
+from dev_opsgpt.orm.schemas.base_schema import CodeBaseSchema
+
+
+@with_session
+def add_cb_to_db(session, code_name, code_path, code_graph_node_num, code_file_num):
+    # add: create the code base record
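+    # upsert semantics: refresh code_path and the graph node count when a record
+    # with the same code_name already exists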
+    cb = session.query(CodeBaseSchema).filter_by(code_name=code_name).first()
+    if not cb:
+        cb = CodeBaseSchema(code_name=code_name, code_path=code_path, code_graph_node_num=code_graph_node_num,
+                            code_file_num=code_file_num)
+        session.add(cb)
+    else:
+        cb.code_path = code_path
+        cb.code_graph_node_num = code_graph_node_num
+    return True
+
+
+@with_session
+def list_cbs_from_db(session):
+    '''
+    query: list the names of all code bases
+    '''
+    cbs = session.query(CodeBaseSchema.code_name).all()
+    cbs = [cb[0] for cb in cbs]
+    return cbs
+
+
+@with_session
+def cb_exists(session, code_name):
+    '''
+    check whether a code base with this name exists
+    '''
+    cb = session.query(CodeBaseSchema).filter_by(code_name=code_name).first()
+    return cb is not None
+
+@with_session
+def load_cb_from_db(session, code_name):
+    cb = session.query(CodeBaseSchema).filter_by(code_name=code_name).first()
+    if cb:
+        code_name, code_path, code_graph_node_num = cb.code_name, cb.code_path, cb.code_graph_node_num
+    else:
+        code_name, code_path, code_graph_node_num = None, None, None
+    return code_name, code_path, code_graph_node_num
+
+
+@with_session
+def delete_cb_from_db(session, code_name):
+    cb = session.query(CodeBaseSchema).filter_by(code_name=code_name).first()
+    if cb:
+        session.delete(cb)
+    return True
+
+
+@with_session
+def get_cb_detail(session, code_name: str) -> dict:
+    cb: CodeBaseSchema = session.query(CodeBaseSchema).filter_by(code_name=code_name).first()
+    if cb:
+        logger.info('code_name={}'.format(cb.code_name))
+        return {
+            "code_name": cb.code_name,
+            "code_path": cb.code_path,
+            "code_graph_node_num": cb.code_graph_node_num,
+            "code_file_num": cb.code_file_num
+        }
+    else:
+        return {
+        }
+
diff --git a/dev_opsgpt/orm/schemas/base_schema.py b/dev_opsgpt/orm/schemas/base_schema.py
index 197812c..6c497d3 100644
--- a/dev_opsgpt/orm/schemas/base_schema.py
+++ b/dev_opsgpt/orm/schemas/base_schema.py
@@ -46,3 +46,24 @@ class KnowledgeFileSchema(Base):
                     text_splitter_name='{self.text_splitter_name}',
                     file_version='{self.file_version}',
                     create_time='{self.create_time}')>"""
+
+
+class CodeBaseSchema(Base):
+    '''
+    ORM model for the code base table
+    '''
+    __tablename__ = 'code_base'
+    id = Column(Integer, primary_key=True, autoincrement=True, comment='代码库 ID')
+    code_name = Column(String, comment='代码库名称')
+    code_path = Column(String, comment='代码本地路径')
+    code_graph_node_num = Column(String, comment='代码图谱节点数')
+    code_file_num = Column(String, comment='代码解析文件数')
+    create_time = Column(DateTime, default=func.now(), comment='创建时间')
+
+    def __repr__(self):
+        return f"""<CodeBase(id='{self.id}',
+                    code_name='{self.code_name}',
+                    code_path='{self.code_path}',
+                    code_graph_node_num='{self.code_graph_node_num}',
+                    code_file_num='{self.code_file_num}',
+                    create_time='{self.create_time}')>"""
diff --git a/dev_opsgpt/sandbox/basebox.py b/dev_opsgpt/sandbox/basebox.py
index 55e2b78..476614c 100644
--- a/dev_opsgpt/sandbox/basebox.py
+++ b/dev_opsgpt/sandbox/basebox.py
@@ -3,6 +3,7 @@ from typing import Optional
 from pathlib import Path
 import sys
 from abc import ABC, abstractclassmethod
+from loguru import logger
 
 from configs.server_config import SANDBOX_SERVER
 
@@ -22,21 +23,6 @@ class CodeBoxStatus(BaseModel):
     status: str
 
 
-class CodeBoxFile(BaseModel):
-    """
-    Represents a file returned from a CodeBox instance.
- """ - - name: str - content: Optional[bytes] = None - - def __str__(self): - return self.name - - def __repr__(self): - return f"File({self.name})" - - class BaseBox(ABC): enter_status = False diff --git a/dev_opsgpt/sandbox/pycodebox.py b/dev_opsgpt/sandbox/pycodebox.py index a22b04f..2012752 100644 --- a/dev_opsgpt/sandbox/pycodebox.py +++ b/dev_opsgpt/sandbox/pycodebox.py @@ -1,7 +1,6 @@ import time, os, docker, requests, json, uuid, subprocess, time, asyncio, aiohttp, re, traceback import psutil from typing import List, Optional, Union -from pathlib import Path from loguru import logger from websockets.sync.client import connect as ws_connect_sync @@ -11,7 +10,8 @@ from websockets.client import WebSocketClientProtocol, ClientConnection from websockets.exceptions import ConnectionClosedError from configs.server_config import SANDBOX_SERVER -from .basebox import BaseBox, CodeBoxResponse, CodeBoxStatus, CodeBoxFile +from configs.model_config import JUPYTER_WORK_PATH +from .basebox import BaseBox, CodeBoxResponse, CodeBoxStatus class PyCodeBox(BaseBox): @@ -25,12 +25,18 @@ class PyCodeBox(BaseBox): remote_port: str = SANDBOX_SERVER["port"], token: str = "mytoken", do_code_exe: bool = False, - do_remote: bool = False + do_remote: bool = False, + do_check_net: bool = True, ): super().__init__(remote_url, remote_ip, remote_port, token, do_code_exe, do_remote) self.enter_status = True + self.do_check_net = do_check_net asyncio.run(self.astart()) + # logger.info(f"""remote_url: {self.remote_url}, + # remote_ip: {self.remote_ip}, + # remote_port: {self.remote_port}""") + def decode_code_from_text(self, text: str) -> str: pattern = r'```.*?```' code_blocks = re.findall(pattern, text, re.DOTALL) @@ -73,7 +79,8 @@ class PyCodeBox(BaseBox): if not self.ws: raise RuntimeError("Jupyter not running. 
Make sure to start it first") - logger.debug(f"code_text: {json.dumps(code_text, ensure_ascii=False)}") + # logger.debug(f"code_text: {len(code_text)}, {code_text}") + self.ws.send( json.dumps( { @@ -103,7 +110,7 @@ class PyCodeBox(BaseBox): raise RuntimeError("Mixing asyncio and sync code is not supported") received_msg = json.loads(self.ws.recv()) except ConnectionClosedError: - logger.debug("box start, ConnectionClosedError!!!") + # logger.debug("box start, ConnectionClosedError!!!") self.start() return self.run(code_text, file_path, retry - 1) @@ -156,7 +163,7 @@ class PyCodeBox(BaseBox): return CodeBoxResponse( code_exe_type="text", code_text=code_text, - code_exe_response=result or "Code run successfully (no output)", + code_exe_response=result or "Code run successfully (no output),可能没有打印需要确认的变量", code_exe_status=200, do_code_exe=self.do_code_exe ) @@ -219,7 +226,6 @@ class PyCodeBox(BaseBox): async def _acheck_connect(self, ) -> bool: if self.kernel_url == "": return False - try: async with aiohttp.ClientSession() as session: async with session.get(f"{self.kernel_url}?token={self.token}", timeout=270) as resp: @@ -231,7 +237,7 @@ class PyCodeBox(BaseBox): def _check_port(self, ) -> bool: try: - response = requests.get(f"http://localhost:{self.remote_port}", timeout=270) + response = requests.get(f"{self.remote_ip}:{self.remote_port}", timeout=270) logger.warning(f"Port is conflict, please check your codebox's port {self.remote_port}") return response.status_code == 200 except requests.exceptions.ConnectionError: @@ -240,7 +246,7 @@ class PyCodeBox(BaseBox): async def _acheck_port(self, ) -> bool: try: async with aiohttp.ClientSession() as session: - async with session.get(f"http://localhost:{self.remote_port}", timeout=270) as resp: + async with session.get(f"{self.remote_ip}:{self.remote_port}", timeout=270) as resp: logger.warning(f"Port is conflict, please check your codebox's port {self.remote_port}") return resp.status == 200 except aiohttp.ClientConnectorError: @@ -249,6 +255,8 @@ class PyCodeBox(BaseBox): pass def _check_connect_success(self, retry_nums: int = 5) -> bool: + if not self.do_check_net: return True + while retry_nums > 0: try: connect_status = self._check_connect() @@ -262,6 +270,7 @@ class PyCodeBox(BaseBox): raise BaseException(f"can't connect to {self.remote_url}") async def _acheck_connect_success(self, retry_nums: int = 5) -> bool: + if not self.do_check_net: return True while retry_nums > 0: try: connect_status = await self._acheck_connect() @@ -283,7 +292,7 @@ class PyCodeBox(BaseBox): self._check_connect_success() self._get_kernelid() - logger.debug(self.kernel_url.replace("http", "ws") + f"/{self.kernel_id}/channels?token={self.token}") + # logger.debug(self.kernel_url.replace("http", "ws") + f"/{self.kernel_id}/channels?token={self.token}") self.wc_url = self.kernel_url.replace("http", "ws") + f"/{self.kernel_id}/channels?token={self.token}" headers = {"Authorization": f'Token {self.token}', 'token': self.token} self.ws = create_connection(self.wc_url, headers=headers) @@ -291,27 +300,30 @@ class PyCodeBox(BaseBox): # TODO 自动检测本地接口 port_status = self._check_port() connect_status = self._check_connect() - logger.debug(f"port_status: {port_status}, connect_status: {connect_status}") + logger.info(f"port_status: {port_status}, connect_status: {connect_status}") if port_status and not connect_status: raise BaseException(f"Port is conflict, please check your codebox's port {self.remote_port}") if not connect_status: - self.jupyter = subprocess.Popen( + 
                self.jupyter = subprocess.Popen(
                     [
                         "jupyter",
                         "notebook",
                         f"--NotebookApp.token={self.token}",
                         f"--port={self.remote_port}",
                         "--no-browser",
                         "--ServerApp.disable_check_xsrf=True",
+                        f"--notebook-dir={JUPYTER_WORK_PATH}"
                     ],
                     stderr=subprocess.PIPE,
                     stdin=subprocess.PIPE,
                     stdout=subprocess.PIPE,
                 )
+                self.kernel_url = self.remote_url + "/api/kernels"
+                self.do_check_net = True
             self._check_connect_success()
             self._get_kernelid()
-            logger.debug(self.kernel_url.replace("http", "ws") + f"/{self.kernel_id}/channels?token={self.token}")
+            # logger.debug(self.kernel_url.replace("http", "ws") + f"/{self.kernel_id}/channels?token={self.token}")
             self.wc_url = self.kernel_url.replace("http", "ws") + f"/{self.kernel_id}/channels?token={self.token}"
             headers = {"Authorization": f'Token {self.token}', 'token': self.token}
             self.ws = create_connection(self.wc_url, headers=headers)
@@ -333,10 +345,10 @@ class PyCodeBox(BaseBox):
             port_status = await self._acheck_port()
             self.kernel_url = self.remote_url + "/api/kernels"
             connect_status = await self._acheck_connect()
-            logger.debug(f"port_status: {port_status}, connect_status: {connect_status}")
+            logger.info(f"port_status: {port_status}, connect_status: {connect_status}")
             if port_status and not connect_status:
                 raise BaseException(f"Port is conflict, please check your codebox's port {self.remote_port}")
-            
+
             if not connect_status:
                 self.jupyter = subprocess.Popen(
                     [
                         "jupyter",
                         "notebook",
                         f"--NotebookApp.token={self.token}",
                         f"--port={self.remote_port}",
                         "--no-browser",
-                        "--ServerApp.disable_check_xsrf=True",
+                        "--ServerApp.disable_check_xsrf=True"
                     ],
                     stderr=subprocess.PIPE,
                     stdin=subprocess.PIPE,
                     stdout=subprocess.PIPE,
                 )
+
+                self.kernel_url = self.remote_url + "/api/kernels"
+                self.do_check_net = True
                 await self._acheck_connect_success()
                 await self._aget_kernelid()
                 self.wc_url = self.kernel_url.replace("http", "ws") + f"/{self.kernel_id}/channels?token={self.token}"
@@ -405,7 +419,8 @@ class PyCodeBox(BaseBox):
         except Exception as e:
             logger.error(traceback.format_exc())
             self.ws = None
-        return CodeBoxStatus(status="stopped")
+
+        # return CodeBoxStatus(status="stopped")
 
     def __del__(self):
         self.stop()
diff --git a/dev_opsgpt/service/api.py b/dev_opsgpt/service/api.py
index de3f5b6..f9e5b75 100644
--- a/dev_opsgpt/service/api.py
+++ b/dev_opsgpt/service/api.py
@@ -18,15 +18,19 @@ from configs.server_config import OPEN_CROSS_DOMAIN
 
 from dev_opsgpt.chat import LLMChat, SearchChat, KnowledgeChat
 from dev_opsgpt.service.kb_api import *
+from dev_opsgpt.service.cb_api import *
 from dev_opsgpt.utils.server_utils import BaseResponse, ListResponse, FastAPI, MakeFastAPIOffline
 
 nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
 
-from dev_opsgpt.chat import LLMChat, SearchChat, KnowledgeChat
+from dev_opsgpt.chat import LLMChat, SearchChat, KnowledgeChat, ToolChat, DataChat, CodeChat
 
 llmChat = LLMChat()
 searchChat = SearchChat()
 knowledgeChat = KnowledgeChat()
+toolChat = ToolChat()
+dataChat = DataChat()
+codeChat = CodeChat()
 
 
 async def document():
@@ -71,6 +75,18 @@ def create_app():
     app.post("/chat/search_engine_chat",
              tags=["Chat"],
              summary="与搜索引擎对话")(searchChat.chat)
 
+    app.post("/chat/tool_chat",
+             tags=["Chat"],
+             summary="与工具对话")(toolChat.chat)
+
+    app.post("/chat/data_chat",
+             tags=["Chat"],
+             summary="与数据对话")(dataChat.chat)
+
+    app.post("/chat/code_chat",
+             tags=["Chat"],
+             summary="与代码库对话")(codeChat.chat)
+
     # Tag: Knowledge Base Management
 
     app.get("/knowledge_base/list_knowledge_bases",
@@ -129,6 +145,22 @@ 
summary="根据content中文档重建向量库,流式输出处理进度。" )(recreate_vector_store) + app.post("/code_base/create_code_base", + tags=["Code Base Management"], + summary="新建 code_base" + )(create_cb) + + app.post("/code_base/delete_code_base", + tags=["Code Base Management"], + summary="删除 code_base" + )(delete_cb) + + app.post("/code_base/code_base_chat", + tags=["Code Base Management"], + summary="删除 code_base" + )(delete_cb) + + app.get("/code_base/list_code_bases", + tags=["Code Base Management"], + summary="列举 code_base", + response_model=ListResponse + )(list_cbs) + # # LLM模型相关接口 # app.post("/llm_model/list_models", # tags=["LLM Model Management"], diff --git a/dev_opsgpt/service/cb_api.py b/dev_opsgpt/service/cb_api.py new file mode 100644 index 0000000..937f29d --- /dev/null +++ b/dev_opsgpt/service/cb_api.py @@ -0,0 +1,128 @@ +# encoding: utf-8 +''' +@author: 温进 +@file: cb_api.py +@time: 2023/10/23 下午7:08 +@desc: +''' + +import urllib, os, json, traceback +from typing import List, Dict +import shutil + +from fastapi.responses import StreamingResponse, FileResponse +from fastapi import File, Form, Body, Query, UploadFile +from langchain.docstore.document import Document + +from .service_factory import KBServiceFactory +from dev_opsgpt.utils.server_utils import BaseResponse, ListResponse +from dev_opsgpt.utils.path_utils import * +from dev_opsgpt.orm.commands import * + +from configs.model_config import ( + CB_ROOT_PATH +) + +from dev_opsgpt.codebase_handler.codebase_handler import CodeBaseHandler + +from loguru import logger + + +async def list_cbs(): + # Get List of Knowledge Base + return ListResponse(data=list_cbs_from_db()) + + +async def create_cb(cb_name: str = Body(..., examples=["samples"]), + code_path: str = Body(..., examples=["samples"]) + ) -> BaseResponse: + logger.info('cb_name={}, zip_path={}'.format(cb_name, code_path)) + + # Create selected knowledge base + if not validate_kb_name(cb_name): + return BaseResponse(code=403, msg="Don't attack me") + if cb_name is None or cb_name.strip() == "": + return BaseResponse(code=404, msg="知识库名称不能为空,请重新填写知识库名称") + + cb = cb_exists(cb_name) + if cb: + return BaseResponse(code=404, msg=f"已存在同名代码知识库 {cb_name}") + + try: + logger.info('start build code base') + cbh = CodeBaseHandler(cb_name, code_path, cb_root_path=CB_ROOT_PATH) + cbh.import_code(do_save=True) + code_graph_node_num = len(cbh.nh) + code_file_num = len(cbh.lcdh) + logger.info('build code base done') + + # create cb to table + add_cb_to_db(cb_name, cbh.code_path, code_graph_node_num, code_file_num) + logger.info('add cb to mysql table success') + except Exception as e: + print(e) + return BaseResponse(code=500, msg=f"创建代码知识库出错: {e}") + + return BaseResponse(code=200, msg=f"已新增代码知识库 {cb_name}") + + +async def delete_cb(cb_name: str = Body(..., examples=["samples"])) -> BaseResponse: + logger.info('cb_name={}'.format(cb_name)) + # Create selected knowledge base + if not validate_kb_name(cb_name): + return BaseResponse(code=403, msg="Don't attack me") + if cb_name is None or cb_name.strip() == "": + return BaseResponse(code=404, msg="知识库名称不能为空,请重新填写知识库名称") + + cb = cb_exists(cb_name) + if cb: + try: + delete_cb_from_db(cb_name) + + # delete local file + shutil.rmtree(CB_ROOT_PATH + os.sep + cb_name) + except Exception as e: + print(e) + return BaseResponse(code=500, msg=f"删除代码知识库出错: {e}") + + return BaseResponse(code=200, msg=f"已删除代码知识库 {cb_name}") + + +def search_code(cb_name: str = Body(..., examples=["sofaboot"]), + query: str = Body(..., examples=['你好']), + code_limit: int = Body(..., 
examples=['1']), + history_node_list: list = Body(...)) -> dict: + + logger.info('cb_name={}'.format(cb_name)) + logger.info('query={}'.format(query)) + logger.info('code_limit={}'.format(code_limit)) + logger.info('history_node_list={}'.format(history_node_list)) + + try: + # load codebase + cbh = CodeBaseHandler(code_name=cb_name, cb_root_path=CB_ROOT_PATH) + cbh.import_code(do_load=True) + + # search code + related_code, related_node = cbh.search_code(query, code_limit=code_limit, history_node_list=history_node_list) + + res = { + 'related_code': related_code, + 'related_node': related_node + } + return res + except Exception as e: + logger.exception(e) + return {} + + +def cb_exists_api(cb_name: str = Body(..., examples=["sofaboot"])) -> bool: + try: + res = cb_exists(cb_name) + return res + except Exception as e: + logger.exception(e) + return False + + + diff --git a/dev_opsgpt/service/faiss_db_service.py b/dev_opsgpt/service/faiss_db_service.py index 314fa4a..ea336bc 100644 --- a/dev_opsgpt/service/faiss_db_service.py +++ b/dev_opsgpt/service/faiss_db_service.py @@ -4,7 +4,7 @@ from typing import List from functools import lru_cache from loguru import logger -from langchain.vectorstores import FAISS +# from langchain.vectorstores import FAISS from langchain.embeddings.base import Embeddings from langchain.docstore.document import Document from langchain.embeddings.huggingface import HuggingFaceEmbeddings @@ -22,6 +22,7 @@ from dev_opsgpt.utils.path_utils import * from dev_opsgpt.orm.utils import DocumentFile from dev_opsgpt.utils.server_utils import torch_gc from dev_opsgpt.embeddings.utils import load_embeddings +from dev_opsgpt.embeddings.faiss_m import FAISS # make HuggingFaceEmbeddings hashable @@ -124,6 +125,7 @@ class FaissKBService(KBService): vector_store = load_vector_store(self.kb_name, embeddings=embeddings, tick=_VECTOR_STORE_TICKS.get(self.kb_name, 0)) + vector_store.embedding_function = embeddings.embed_documents logger.info("docs.lens: {}".format(len(docs))) vector_store.add_documents(docs) torch_gc() diff --git a/dev_opsgpt/service/kb_api.py b/dev_opsgpt/service/kb_api.py index 1bb3caf..dd6363e 100644 --- a/dev_opsgpt/service/kb_api.py +++ b/dev_opsgpt/service/kb_api.py @@ -16,7 +16,6 @@ from configs.model_config import ( ) - async def list_kbs(): # Get List of Knowledge Base return ListResponse(data=list_kbs_from_db()) diff --git a/dev_opsgpt/service/llm_api.py b/dev_opsgpt/service/llm_api.py index da2be7a..00d67e6 100644 --- a/dev_opsgpt/service/llm_api.py +++ b/dev_opsgpt/service/llm_api.py @@ -6,7 +6,6 @@ import os src_dir = os.path.join( os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) ) -print(src_dir) sys.path.append(src_dir) sys.path.append(os.path.dirname(os.path.dirname(__file__))) @@ -20,7 +19,7 @@ model_worker_port = 20002 openai_api_port = 8888 base_url = "http://127.0.0.1:{}" -os.environ['PATH'] = os.environ.get("PATH", "") + os.pathsep + r'/d/env_utils/miniconda3/envs/devopsgpt/Lib/site-packages/torch/lib' +os.environ['PATH'] = os.environ.get("PATH", "") + os.pathsep def set_httpx_timeout(timeout=60.0): import httpx diff --git a/dev_opsgpt/service/sdfile_api.py b/dev_opsgpt/service/sdfile_api.py new file mode 100644 index 0000000..17336b5 --- /dev/null +++ b/dev_opsgpt/service/sdfile_api.py @@ -0,0 +1,129 @@ +import sys, os, json, traceback, uvicorn, argparse + +src_dir = os.path.join( + os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +) +sys.path.append(src_dir) + +from loguru import logger + 
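+# usage sketch (assuming the service listens on SDFILE_API_SERVER's port, e.g. 7862):
+#   curl -F "file=@demo.csv" "http://127.0.0.1:7862/sdfiles/upload"
+#   curl "http://127.0.0.1:7862/sdfiles/list"
+#   curl "http://127.0.0.1:7862/sdfiles/download?filename=demo.csv&save_filename=out.csv"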
diff --git a/dev_opsgpt/service/sdfile_api.py b/dev_opsgpt/service/sdfile_api.py
new file mode 100644
index 0000000..17336b5
--- /dev/null
+++ b/dev_opsgpt/service/sdfile_api.py
@@ -0,0 +1,129 @@
+import sys, os, json, traceback, uvicorn, argparse
+
+src_dir = os.path.join(
+    os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+)
+sys.path.append(src_dir)
+
+from loguru import logger
+
+from fastapi import FastAPI
+from fastapi.middleware.cors import CORSMiddleware
+from fastapi.responses import StreamingResponse, FileResponse
+from fastapi import File, UploadFile
+
+from dev_opsgpt.utils.server_utils import BaseResponse, ListResponse
+from configs.server_config import OPEN_CROSS_DOMAIN, SDFILE_API_SERVER
+from configs.model_config import (
+    JUPYTER_WORK_PATH
+)
+from configs import VERSION
+
+
+
+async def sd_upload_file(file: UploadFile = File(...), work_dir: str = JUPYTER_WORK_PATH):
+    # 保存上传的文件到服务器
+    try:
+        content = await file.read()
+        with open(os.path.join(work_dir, file.filename), "wb") as f:
+            f.write(content)
+        return {"data": True}
+    except Exception:
+        return {"data": False}
+
+
+async def sd_download_file(filename: str, save_filename: str = "filename_to_download.ext", work_dir: str = JUPYTER_WORK_PATH):
+    # 从服务器下载文件
+    logger.debug(f"{os.path.join(work_dir, filename)}")
+    return FileResponse(os.path.join(work_dir, filename), filename=save_filename)
+
+
+async def sd_list_files(work_dir: str = JUPYTER_WORK_PATH):
+    # 列出工作目录下的文件
+    return {"data": os.listdir(work_dir)}
+
+
+async def sd_delete_file(filename: str, work_dir: str = JUPYTER_WORK_PATH):
+    # 删除工作目录中的文件
+    try:
+        os.remove(os.path.join(work_dir, filename))
+        return {"data": True}
+    except Exception:
+        return {"data": False}
+
+
+def create_app():
+    app = FastAPI(
+        title="DevOps-ChatBot API Server",
+        version=VERSION
+    )
+    # MakeFastAPIOffline(app)
+    # Add CORS middleware to allow all origins
+    # 在config.py中设置OPEN_DOMAIN=True,允许跨域
+    # set OPEN_DOMAIN=True in config.py to allow cross-domain
+    if OPEN_CROSS_DOMAIN:
+        app.add_middleware(
+            CORSMiddleware,
+            allow_origins=["*"],
+            allow_credentials=True,
+            allow_methods=["*"],
+            allow_headers=["*"],
+        )
+
+    app.post("/sdfiles/upload",
+             tags=["files upload and download"],
+             response_model=BaseResponse,
+             summary="上传文件到沙盒"
+             )(sd_upload_file)
+
+    app.get("/sdfiles/download",
+            tags=["files upload and download"],
+            summary="从沙盒下载文件"
+            )(sd_download_file)
+
+    app.get("/sdfiles/list",
+            tags=["files upload and download"],
+            response_model=ListResponse,
+            summary="列举沙盒工作目录中的文件"
+            )(sd_list_files)
+
+    app.get("/sdfiles/delete",
+            tags=["files upload and download"],
+            response_model=BaseResponse,
+            summary="从沙盒工作目录中删除文件"
+            )(sd_delete_file)
+    return app
+
+
+
+app = create_app()
+
+def run_api(host, port, **kwargs):
+    if kwargs.get("ssl_keyfile") and kwargs.get("ssl_certfile"):
+        uvicorn.run(app,
+                    host=host,
+                    port=port,
+                    ssl_keyfile=kwargs.get("ssl_keyfile"),
+                    ssl_certfile=kwargs.get("ssl_certfile"),
+                    )
+    else:
+        uvicorn.run(app, host=host, port=port)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(prog='DevOps-ChatBot',
+                                     description='About DevOps-ChatBot, local knowledge based LLM with langchain'
+                                                 ' | 基于本地知识库的 LLM 问答')
+    parser.add_argument("--host", type=str, default="0.0.0.0")
+    parser.add_argument("--port", type=int, default=SDFILE_API_SERVER["port"])
+    parser.add_argument("--ssl_keyfile", type=str)
+    parser.add_argument("--ssl_certfile", type=str)
+    # 初始化消息
+    args = parser.parse_args()
+    args_dict = vars(args)
+    run_api(host=args.host,
+            port=args.port,
+            ssl_keyfile=args.ssl_keyfile,
+            ssl_certfile=args.ssl_certfile,
+            )
\ No newline at end of file
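The `sd_*` handlers above are consumed by the `web_sd_*` wrappers added to `ApiRequest` later in this diff. For reference, the equivalent raw HTTP calls, assuming the file server listens on 7862 (the `sandbox_file_url` default that `ApiRequest` uses):

```python
# Raw HTTP equivalents of web_sd_upload/list/download/delete; port 7862 is an assumption.
import requests

SD = "http://127.0.0.1:7862"

with open("report.csv", "rb") as f:
    print(requests.post(f"{SD}/sdfiles/upload", files={"file": ("report.csv", f)}).json())

print(requests.get(f"{SD}/sdfiles/list").json())          # {"data": ["report.csv", ...]}

resp = requests.get(f"{SD}/sdfiles/download",
                    params={"filename": "report.csv", "save_filename": "copy.csv"})
open("copy.csv", "wb").write(resp.content)                # FileResponse streams the bytes

print(requests.get(f"{SD}/sdfiles/delete", params={"filename": "report.csv"}).json())
```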
diff --git a/dev_opsgpt/service/service_factory.py b/dev_opsgpt/service/service_factory.py
index 5eb7ae2..59414de 100644
--- a/dev_opsgpt/service/service_factory.py
+++ b/dev_opsgpt/service/service_factory.py
@@ -77,6 +77,33 @@ def get_kb_details() -> List[Dict]:
     return data
 
 
+def get_cb_details() -> List[Dict]:
+    '''
+    get codebase details
+    @return: list of data
+    '''
+    res = {}
+    cbs_in_db = list_cbs_from_db()
+    for cb in cbs_in_db:
+        cb_detail = get_cb_detail(cb)
+        res[cb] = cb_detail
+
+    data = []
+    for i, v in enumerate(res.values()):
+        v['No'] = i + 1
+        data.append(v)
+    return data
+
+
+def get_cb_details_by_cb_name(cb_name) -> dict:
+    '''
+    get codebase details by cb_name
+    @return: dict of detail data
+    '''
+    cb_detail = get_cb_detail(cb_name)
+    return cb_detail
+
+
 def get_kb_doc_details(kb_name: str) -> List[Dict]:
     kb = KBServiceFactory.get_service_by_name(kb_name)
diff --git a/dev_opsgpt/text_splitter/langchain_splitter.py b/dev_opsgpt/text_splitter/langchain_splitter.py
index 9e4a65a..5095267 100644
--- a/dev_opsgpt/text_splitter/langchain_splitter.py
+++ b/dev_opsgpt/text_splitter/langchain_splitter.py
@@ -32,10 +32,12 @@ class LCTextSplitter:
         loader = self._load_document()
         text_splitter = self._load_text_splitter()
         if self.document_loader_name in ["JSONLoader", "JSONLLoader"]:
-            docs = loader.load()
+            # docs = loader.load()
+            docs = loader.load_and_split(text_splitter)
+            logger.debug(f"please check your file can be loaded, docs.lens {len(docs)}")
         else:
             docs = loader.load_and_split(text_splitter)
-        logger.info(docs[0])
+
         return docs
 
     def _load_document(self, ) -> BaseLoader:
@@ -55,8 +57,8 @@ class LCTextSplitter:
                 chunk_overlap=OVERLAP_SIZE,
             )
             self.text_splitter_name = "SpacyTextSplitter"
-        elif self.document_loader_name in ["JSONLoader", "JSONLLoader"]:
-            text_splitter = None
+        # elif self.document_loader_name in ["JSONLoader", "JSONLLoader"]:
+        #     text_splitter = None
         else:
             text_splitter_module = importlib.import_module('langchain.text_splitter')
             TextSplitter = getattr(text_splitter_module, self.text_splitter_name)
diff --git a/dev_opsgpt/tools/__init__.py b/dev_opsgpt/tools/__init__.py
new file mode 100644
index 0000000..130dc71
--- /dev/null
+++ b/dev_opsgpt/tools/__init__.py
@@ -0,0 +1,33 @@
+from .base_tool import toLangchainTools, get_tool_schema, BaseToolModel
+from .weather import WeatherInfo, DistrictInfo
+from .multiplier import Multiplier
+from .world_time import WorldTimeGetTimezoneByArea
+from .abnormal_detection import KSigmaDetector
+from .metrics_query import MetricsQuery
+from .duckduckgo_search import DDGSTool
+from .docs_retrieval import DocRetrieval
+from .cb_query_tool import CodeRetrieval
+
+TOOL_SETS = [
+    "WeatherInfo", "WorldTimeGetTimezoneByArea", "Multiplier", "DistrictInfo", "KSigmaDetector", "MetricsQuery", "DDGSTool",
+    "DocRetrieval", "CodeRetrieval"
+    ]
+
+TOOL_DICT = {
+    "WeatherInfo": WeatherInfo,
+    "WorldTimeGetTimezoneByArea": WorldTimeGetTimezoneByArea,
+    "Multiplier": Multiplier,
+    "DistrictInfo": DistrictInfo,
+    "KSigmaDetector": KSigmaDetector,
+    "MetricsQuery": MetricsQuery,
+    "DDGSTool": DDGSTool,
+    "DocRetrieval": DocRetrieval,
+    "CodeRetrieval": CodeRetrieval
+}
+
+__all__ = [
+    "WeatherInfo", "WorldTimeGetTimezoneByArea", "Multiplier", "DistrictInfo", "KSigmaDetector", "MetricsQuery", "DDGSTool",
+    "DocRetrieval", "CodeRetrieval",
+    "toLangchainTools", "get_tool_schema", "TOOL_SETS", "TOOL_DICT", "BaseToolModel"
+]
+
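`TOOL_SETS` drives the multiselect in the webui, while `TOOL_DICT` plus `toLangchainTools` (defined in base_tool.py below) is what actually hands the tools to an agent. A small sketch of that wiring, using only names exported above:

```python
# Turning the registry into LangChain StructuredTool instances.
from dev_opsgpt.tools import TOOL_DICT, toLangchainTools

chosen = ["WeatherInfo", "KSigmaDetector"]   # e.g. the webui multiselect result
tools = toLangchainTools([TOOL_DICT[name] for name in chosen])
for t in tools:
    print(t.name, "-", t.description[:40])
```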
diff --git a/dev_opsgpt/tools/abnormal_detection.py b/dev_opsgpt/tools/abnormal_detection.py
new file mode 100644
index 0000000..3b222f7
--- /dev/null
+++ b/dev_opsgpt/tools/abnormal_detection.py
@@ -0,0 +1,45 @@
+
+import json
+import os
+import re
+from pydantic import BaseModel, Field
+from typing import List, Dict
+import requests
+import numpy as np
+from loguru import logger
+
+from .base_tool import BaseToolModel
+
+
+
+class KSigmaDetector(BaseToolModel):
+    """
+    Tips:
+    default control Required, e.g. key1 is not Required/key2 is Required
+    """
+
+    name: str = "KSigmaDetector"
+    description: str = "Anomaly detection using K-Sigma method"
+
+    class ToolInputArgs(BaseModel):
+        """Input for KSigmaDetector."""
+
+        data: List[float] = Field(..., description="List of data points")
+        detect_window: int = Field(default=5, description="The size of the detect window for detecting anomalies")
+        abnormal_window: int = Field(default=3, description="The threshold for the number of abnormal points required to classify the data as abnormal")
+        k: float = Field(default=3.0, description="the coef of k-sigma")
+
+    class ToolOutputArgs(BaseModel):
+        """Output for KSigmaDetector."""
+
+        is_abnormal: bool = Field(..., description="Indicates whether the input data is abnormal or not")
+
+    @staticmethod
+    def run(data, detect_window=5, abnormal_window=3, k=3.0):
+        # the trailing `detect_window` points are tested against the preceding history
+        refer_data = np.array(data[:-detect_window])
+        detect_data = np.array(data[-detect_window:])
+        mean = np.mean(refer_data)
+        std = np.std(refer_data)
+
+        is_abnormal = np.sum(np.abs(detect_data - mean) > k * std) >= abnormal_window
+        return {"is_abnormal": bool(is_abnormal)}
\ No newline at end of file
diff --git a/dev_opsgpt/tools/base_tool.py b/dev_opsgpt/tools/base_tool.py
new file mode 100644
index 0000000..507822a
--- /dev/null
+++ b/dev_opsgpt/tools/base_tool.py
@@ -0,0 +1,79 @@
+from langchain.agents import Tool
+from langchain.tools import StructuredTool
+from langchain.tools.base import ToolException
+from pydantic import BaseModel, Field
+from typing import List, Dict
+import jsonref
+import json
+
+
+class BaseToolModel:
+    name = "BaseToolModel"
+    description = "Tool Description"
+
+    class ToolInputArgs(BaseModel):
+        """
+        Input for MoveFileTool.
+        Tips:
+        default control Required, e.g. key1 is not Required/key2 is Required
+        """
+
+        key1: str = Field(default=None, description="hello world!")
+        key2: str = Field(..., description="hello world!!")
+
+    class ToolOutputArgs(BaseModel):
+        """
+        Output for MoveFileTool.
+        Tips:
+        default control Required, e.g. key1 is not Required/key2 is Required
+        """
+
+        key1: str = Field(default=None, description="hello world!")
+        key2: str = Field(..., description="hello world!!")
+
+    @classmethod
+    def run(cls, tool_input_args: ToolInputArgs) -> ToolOutputArgs:
+        """execute your tool!"""
+        pass
+
+
+class BaseTools:
+    tools: List[BaseToolModel]
+
+
+def get_tool_schema(tool: BaseToolModel) -> Dict:
+    '''转 json schema 结构'''
+    data = jsonref.loads(tool.schema_json())
+    del data["definitions"]
+    return data
+
+
+def _handle_error(error: ToolException) -> str:
+    return (
+        "The following errors occurred during tool execution:"
+        + error.args[0]
+        + " Please try again."
+    )
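Everything a tool needs lives on the class itself: `name`, `description`, the two pydantic arg models, and a `run` callable. A new tool is therefore just a subclass; the `EchoTool` below is a hypothetical example, not part of this PR, and would still need an entry in `TOOL_SETS`/`TOOL_DICT` in tools/__init__.py to show up in the webui:

```python
# Hypothetical tool following the BaseToolModel contract above.
from pydantic import BaseModel, Field
from dev_opsgpt.tools.base_tool import BaseToolModel

class EchoTool(BaseToolModel):
    name = "EchoTool"
    description = "returns its input unchanged; handy for smoke-testing the tool plumbing"

    class ToolInputArgs(BaseModel):
        text: str = Field(..., description="text to echo")

    class ToolOutputArgs(BaseModel):
        text: str = Field(..., description="echoed text")

    @staticmethod
    def run(text):
        return {"text": text}
```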
+
+
+def toLangchainTools(tools: BaseTools) -> List:
+    '''convert BaseToolModel classes into LangChain StructuredTools'''
+    return [
+        StructuredTool(
+            name=tool.name,
+            func=tool.run,
+            description=tool.description,
+            args_schema=tool.ToolInputArgs,
+            handle_tool_error=_handle_error,
+        ) for tool in tools
+    ]
diff --git a/dev_opsgpt/tools/cb_query_tool.py b/dev_opsgpt/tools/cb_query_tool.py
new file mode 100644
index 0000000..c162252
--- /dev/null
+++ b/dev_opsgpt/tools/cb_query_tool.py
@@ -0,0 +1,47 @@
+# encoding: utf-8
+'''
+@author: 温进
+@file: cb_query_tool.py
+@time: 2023/11/2 下午4:41
+@desc:
+'''
+import json
+import os
+import re
+from pydantic import BaseModel, Field
+from typing import List, Dict
+import requests
+import numpy as np
+from loguru import logger
+
+from configs.model_config import (
+    CODE_SEARCH_TOP_K)
+from .base_tool import BaseToolModel
+
+from dev_opsgpt.service.cb_api import search_code
+
+
+class CodeRetrieval(BaseToolModel):
+    name = "CodeRetrieval"
+    description = "采用知识图谱从本地代码知识库获取相关代码"
+
+    class ToolInputArgs(BaseModel):
+        query: str = Field(..., description="检索的关键字或问题")
+        code_base_name: str = Field(..., description="知识库名称", examples=["samples"])
+        code_limit: int = Field(CODE_SEARCH_TOP_K, description="检索返回的数量")
+
+    class ToolOutputArgs(BaseModel):
+        """Output for CodeRetrieval."""
+        code: str = Field(..., description="检索代码")
+
+    @classmethod
+    def run(cls, code_base_name, query, code_limit=CODE_SEARCH_TOP_K, history_node_list=None):
+        """execute your tool!"""
+        history_node_list = history_node_list or []
+        codes = search_code(code_base_name, query, code_limit, history_node_list=history_node_list)
+        return_codes = []
+        related_code = codes['related_code']
+        related_nodes = codes['related_node']
+
+        for idx, code in enumerate(related_code):
+            return_codes.append({'index': idx, 'code': code, "related_nodes": related_nodes})
+        return return_codes
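`CodeRetrieval` goes through `cb_api.search_code`, so it works wherever the code base files are on disk. A direct invocation sketch, again assuming the `samples` code base exists and `CODE_SEARCH_TOP_K` is set in configs/model_config.py:

```python
# Direct run of the CodeRetrieval tool; "samples" is an assumed, pre-built code base.
from dev_opsgpt.tools.cb_query_tool import CodeRetrieval

hits = CodeRetrieval.run(code_base_name="samples",
                         query="where is the DB session created?", code_limit=2)
for hit in hits:
    print(hit["index"], hit["code"][:80])
```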
diff --git a/dev_opsgpt/tools/docs_retrieval.py b/dev_opsgpt/tools/docs_retrieval.py
new file mode 100644
index 0000000..b2cbee3
--- /dev/null
+++ b/dev_opsgpt/tools/docs_retrieval.py
@@ -0,0 +1,42 @@
+import json
+import os
+import re
+from pydantic import BaseModel, Field
+from typing import List, Dict
+import requests
+import numpy as np
+from loguru import logger
+
+from configs.model_config import (
+    VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD)
+from .base_tool import BaseToolModel
+
+
+
+from dev_opsgpt.service.kb_api import search_docs
+
+
+class DocRetrieval(BaseToolModel):
+    name = "DocRetrieval"
+    description = "采用向量化对本地知识库进行检索"
+
+    class ToolInputArgs(BaseModel):
+        query: str = Field(..., description="检索的关键字或问题")
+        knowledge_base_name: str = Field(..., description="知识库名称", examples=["samples"])
+        search_top: int = Field(VECTOR_SEARCH_TOP_K, description="检索返回的数量")
+        score_threshold: float = Field(SCORE_THRESHOLD, description="知识库匹配相关度阈值,取值范围在0-1之间,SCORE越小,相关度越高,取到1相当于不筛选,建议设置在0.5左右", ge=0, le=1)
+
+    class ToolOutputArgs(BaseModel):
+        """Output for DocRetrieval."""
+        title: str = Field(..., description="检索文档标题")
+        snippet: str = Field(..., description="检索内容片段")
+        link: str = Field(..., description="检索文档来源路径")
+
+    @classmethod
+    def run(cls, query, knowledge_base_name, search_top=VECTOR_SEARCH_TOP_K, score_threshold=SCORE_THRESHOLD):
+        """execute your tool!"""
+        docs = search_docs(query, knowledge_base_name, search_top, score_threshold)
+        return_docs = []
+        for idx, doc in enumerate(docs):
+            return_docs.append({"index": idx, "snippet": doc.page_content, "title": doc.metadata.get("source"), "link": doc.metadata.get("source")})
+        return return_docs
diff --git a/dev_opsgpt/tools/duckduckgo_search.py b/dev_opsgpt/tools/duckduckgo_search.py
new file mode 100644
index 0000000..b69e68e
--- /dev/null
+++ b/dev_opsgpt/tools/duckduckgo_search.py
@@ -0,0 +1,72 @@
+
+import json
+import os
+import re
+from pydantic import BaseModel, Field
+from typing import List, Dict
+import requests
+import numpy as np
+from loguru import logger
+
+from .base_tool import BaseToolModel
+from configs.model_config import (
+    PROMPT_TEMPLATE, SEARCH_ENGINE_TOP_K, BING_SUBSCRIPTION_KEY, BING_SEARCH_URL,
+    VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD)
+
+from duckduckgo_search import DDGS
+
+
+class DDGSTool(BaseToolModel):
+    name = "DDGSTool"
+    description = "通过duckduckgo进行资料搜索"
+
+    class ToolInputArgs(BaseModel):
+        query: str = Field(..., description="检索的关键字或问题")
+        search_top: int = Field(..., description="检索返回的数量")
+        region: str = Field("wt-wt", enum=["wt-wt", "us-en", "uk-en", "ru-ru"], description="搜索的区域")
+        safesearch: str = Field("moderate", enum=["on", "moderate", "off"], description="安全搜索级别")
+        timelimit: str = Field(None, enum=[None, "d", "w", "m", "y"], description="查询时间范围")
+        backend: str = Field("api", description="搜索的资料来源")
+
+    class ToolOutputArgs(BaseModel):
+        """Output for DDGSTool."""
+        title: str = Field(..., description="检索网页标题")
+        snippet: str = Field(..., description="检索内容片段")
+        link: str = Field(..., description="检索网页地址")
+
+    @classmethod
+    def run(cls, query, search_top, region="wt-wt", safesearch="moderate", timelimit=None, backend="api"):
+        """execute your tool!"""
+        with DDGS(proxies=os.environ.get("DUCKDUCKGO_PROXY")) as ddgs:
+            results = ddgs.text(
+                query,
+                region=region,
+                safesearch=safesearch,
+                timelimit=timelimit,
+                backend=backend,
+            )
+            if results is None:
+                return [{"Result": "No good DuckDuckGo Search Result was found"}]
+
+            def to_metadata(result: Dict) -> Dict[str, str]:
+                if backend == "news":
+                    return {
+                        "date": result["date"],
+                        "title": result["title"],
+                        "snippet": result["body"],
+                        "source": result["source"],
+                        "link": result["url"],
+                    }
+                return {
+                    "snippet": result["body"],
+                    "title": result["title"],
+                    "link": result["href"],
+                }
+
+            formatted_results = []
+            for i, res in enumerate(results, 1):
+                if res is not None:
+                    formatted_results.append(to_metadata(res))
+                if len(formatted_results) == search_top:
+                    break
+        return formatted_results
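`DDGSTool` honors the `DUCKDUCKGO_PROXY` environment variable (start.py exports it for the dockerized services), which matters behind a firewall. Usage sketch:

```python
# Web search through the new DDGSTool.
import os
from dev_opsgpt.tools.duckduckgo_search import DDGSTool

# os.environ["DUCKDUCKGO_PROXY"] = "socks5://127.0.0.1:13659"  # uncomment behind a firewall

for item in DDGSTool.run("fastapi streaming response", search_top=3):
    print(item["title"], "->", item["link"])
```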
diff --git a/dev_opsgpt/tools/metrics_query.py b/dev_opsgpt/tools/metrics_query.py
new file mode 100644
index 0000000..c3336c7
--- /dev/null
+++ b/dev_opsgpt/tools/metrics_query.py
@@ -0,0 +1,33 @@
+
+import json
+import os
+import re
+from pydantic import BaseModel, Field
+from typing import List, Dict
+import requests
+import numpy as np
+from loguru import logger
+
+from .base_tool import BaseToolModel
+
+
+
+class MetricsQuery(BaseToolModel):
+    name = "MetricsQuery"
+    description = "查询机器的监控数据"
+
+    class ToolInputArgs(BaseModel):
+        machine_ip: str = Field(..., description="machine_ip")
+        time: int = Field(..., description="time period")
+
+    class ToolOutputArgs(BaseModel):
+        """Output for MetricsQuery."""
+
+        datas: List[float] = Field(..., description="监控时序数组")
+
+    @staticmethod
+    def run(machine_ip, time):
+        """execute your tool!"""
+        # mock 数据,便于在无监控后端时联调
+        data = [0.857, 2.345, 1.234, 4.567, 3.456, 9.876, 5.678, 7.890, 6.789, 8.901, 10.987, 12.345, 11.234, 14.567, 13.456, 19.876, 15.678, 17.890,
+                16.789, 18.901, 20.987, 22.345, 21.234, 24.567, 23.456, 29.876, 25.678, 27.890, 26.789, 28.901, 30.987, 32.345, 31.234, 34.567,
+                33.456, 39.876, 35.678, 37.890, 36.789, 38.901, 40.987]
+        return data[:30]
\ No newline at end of file
diff --git a/dev_opsgpt/tools/multiplier.py b/dev_opsgpt/tools/multiplier.py
new file mode 100644
index 0000000..e40d382
--- /dev/null
+++ b/dev_opsgpt/tools/multiplier.py
@@ -0,0 +1,38 @@
+from pydantic import BaseModel, Field
+from typing import List, Dict
+import requests
+from loguru import logger
+
+from .base_tool import BaseToolModel
+
+
+
+class Multiplier(BaseToolModel):
+    """
+    Tips:
+    default control Required, e.g. key1 is not Required/key2 is Required
+    """
+
+    name: str = "Multiplier"
+    description: str = """useful for when you need to multiply two numbers together. \
+    The input to this tool should be a comma separated list of numbers of length two, representing the two numbers you want to multiply together. \
+    For example, `1,2` would be the input if you wanted to multiply 1 by 2."""
+
+    class ToolInputArgs(BaseModel):
+        """Input for Multiplier."""
+
+        # key: str = Field(..., description="用户在高德地图官网申请web服务API类型KEY")
+        a: int = Field(..., description="num a")
+        b: int = Field(..., description="num b")
+
+    class ToolOutputArgs(BaseModel):
+        """Output for Multiplier."""
+
+        res: int = Field(..., description="the result of two nums")
+
+    @staticmethod
+    def run(a, b):
+        return a * b
+
+def multi_run(a, b):
+    return a * b
\ No newline at end of file
diff --git a/dev_opsgpt/tools/sandbox.py b/dev_opsgpt/tools/sandbox.py
new file mode 100644
index 0000000..e69de29
diff --git a/dev_opsgpt/tools/weather.py b/dev_opsgpt/tools/weather.py
new file mode 100644
index 0000000..bfb5b02
--- /dev/null
+++ b/dev_opsgpt/tools/weather.py
@@ -0,0 +1,109 @@
+
+import json
+import os
+import re
+from pydantic import BaseModel, Field
+from typing import List, Dict
+import requests
+from loguru import logger
+
+from .base_tool import BaseToolModel
+
+
+
+class WeatherInfo(BaseToolModel):
+    """
+    Tips:
+    default control Required, e.g. key1 is not Required/key2 is Required
+    """
+
+    name: str = "WeatherInfo"
+    description: str = "According to the user's input adcode, it can query the current/future weather conditions of the target area."
+ + class ToolInputArgs(BaseModel): + """Input for Weather.""" + + # key: str = Field(..., description="用户在高德地图官网申请web服务API类型KEY") + city: str = Field(..., description="城市编码,输入城市的adcode,adcode信息可参考城市编码表") + extensions: str = Field(default=None, enum=["base", "all"], description="气象类型,输入城市的adcode,adcode信息可参考城市编码表") + + class ToolOutputArgs(BaseModel): + """Output for Weather.""" + + lives: str = Field(default=None, description="实况天气数据") + + # @classmethod + # def run(cls, tool_input_args: ToolInputArgs) -> ToolOutputArgs: + # """excute your tool!""" + # url = "https://restapi.amap.com/v3/weather/weatherInfo" + # try: + # json_data = tool_input_args.dict() + # json_data["key"] = "4ceb2ef6257a627b72e3be6beab5b059" + # res = requests.get(url, json_data) + # return res.json() + # except Exception as e: + # return e + + @staticmethod + def run(city, extensions) -> ToolOutputArgs: + """excute your tool!""" + url = "https://restapi.amap.com/v3/weather/weatherInfo" + try: + json_data = {} + json_data["city"] = city + json_data["key"] = "4ceb2ef6257a627b72e3be6beab5b059" + json_data["extensions"] = extensions + logger.debug(f"json_data: {json_data}") + res = requests.get(url, params=json_data) + return res.json() + except Exception as e: + return e + + +class DistrictInfo(BaseToolModel): + """ + Tips: + default control Required, e.g. key1 is not Required/key2 is Required + """ + + name: str = "DistrictInfo" + description: str = "用户希望通过得到行政区域信息,进行开发工作。" + + class ToolInputArgs(BaseModel): + """Input for district.""" + keywords: str = Field(default=None, description="规则:只支持单个关键词语搜索关键词支持:行政区名称、citycode、adcode例如,在subdistrict=2,搜索省份(例如山东),能够显示市(例如济南),区(例如历下区)") + subdistrict: str = Field(default=None, enums=[1,2,3], description="""规则:设置显示下级行政区级数(行政区级别包括:国家、省/直辖市、市、区/县、乡镇/街道多级数据) + +可选值:0、1、2、3等数字,并以此类推 + +0:不返回下级行政区; + +1:返回下一级行政区; + +2:返回下两级行政区; + +3:返回下三级行政区;""") + page: int = Field(default=1, examples=["page=2", "page=3"], description="最外层的districts最多会返回20个数据,若超过限制,请用page请求下一页数据。") + extensions: str = Field(default=None, enum=["base", "all"], description="气象类型,输入城市的adcode,adcode信息可参考城市编码表") + + class ToolOutputArgs(BaseModel): + """Output for district.""" + + districts: str = Field(default=None, description="行政区列表") + + @staticmethod + def run(keywords=None, subdistrict=None, page=1, extensions=None) -> ToolOutputArgs: + """excute your tool!""" + url = "https://restapi.amap.com/v3/config/district" + try: + json_data = {} + json_data["keywords"] = keywords + json_data["key"] = "4ceb2ef6257a627b72e3be6beab5b059" + json_data["subdistrict"] = subdistrict + json_data["page"] = page + json_data["extensions"] = extensions + logger.debug(f"json_data: {json_data}") + res = requests.get(url, params=json_data) + return res.json() + except Exception as e: + return e diff --git a/dev_opsgpt/tools/world_time.py b/dev_opsgpt/tools/world_time.py new file mode 100644 index 0000000..642a280 --- /dev/null +++ b/dev_opsgpt/tools/world_time.py @@ -0,0 +1,255 @@ +import json +import os +import re +import requests +from pydantic import BaseModel, Field +from typing import List + +from .base_tool import BaseToolModel + + +class WorldTimeGetTimezoneByArea(BaseToolModel): + """ + World Time API + Tips: + default control Required, e.g. key1 is not Required/key2 is Required + """ + + name = "WorldTime.getTimezoneByArea" + description = "a listing of all timezones available for that area." 
+ + class ToolInputArgs(BaseModel): + """Input for WorldTimeGetTimezoneByArea.""" + area: str = Field(..., description="area") + + class ToolOutputArgs(BaseModel): + """Output for WorldTimeGetTimezoneByArea.""" + DateTimeJsonResponse: str = Field(..., description="a list of available timezones") + + @classmethod + def run(area: str) -> ToolOutputArgs: + """excute your tool!""" + url = "http://worldtimeapi.org/api/timezone" + try: + res = requests.get(url, json={"area": area}) + return res.text + except Exception as e: + return e + + +def worldtime_run(area): + url = "http://worldtimeapi.org/api/timezone" + res = requests.get(url, json={"area": area}) + return res.text + +# class WorldTime(BaseTool): +# api_spec: str = ''' +# description: >- +# A simple API to get the current time based on +# a request with a timezone. + +# servers: +# - url: http://worldtimeapi.org/api/ + +# paths: +# /timezone: +# get: +# description: a listing of all timezones. +# operationId: getTimezone +# responses: +# default: +# $ref: "#/components/responses/SuccessfulListJsonResponse" + +# /timezone/{area}: +# get: +# description: a listing of all timezones available for that area. +# operationId: getTimezoneByArea +# parameters: +# - name: area +# in: path +# required: true +# schema: +# type: string +# responses: +# '200': +# $ref: "#/components/responses/SuccessfulListJsonResponse" +# default: +# $ref: "#/components/responses/ErrorJsonResponse" + +# /timezone/{area}/{location}: +# get: +# description: request the current time for a timezone. +# operationId: getTimeByTimezone +# parameters: +# - name: area +# in: path +# required: true +# schema: +# type: string +# - name: location +# in: path +# required: true +# schema: +# type: string +# responses: +# '200': +# $ref: "#/components/responses/SuccessfulDateTimeJsonResponse" +# default: +# $ref: "#/components/responses/ErrorJsonResponse" + +# /ip: +# get: +# description: >- +# request the current time based on the ip of the request. +# note: this is a "best guess" obtained from open-source data. 
+# operationId: getTimeByIP +# responses: +# '200': +# $ref: "#/components/responses/SuccessfulDateTimeJsonResponse" +# default: +# $ref: "#/components/responses/ErrorJsonResponse" + +# components: +# responses: +# SuccessfulListJsonResponse: +# description: >- +# the list of available timezones in JSON format +# content: +# application/json: +# schema: +# $ref: "#/components/schemas/ListJsonResponse" + +# SuccessfulDateTimeJsonResponse: +# description: >- +# the current time for the timezone requested in JSON format +# content: +# application/json: +# schema: +# $ref: "#/components/schemas/DateTimeJsonResponse" + +# ErrorJsonResponse: +# description: >- +# an error response in JSON format +# content: +# application/json: +# schema: +# $ref: "#/components/schemas/ErrorJsonResponse" + +# schemas: +# ListJsonResponse: +# type: array +# description: >- +# a list of available timezones +# items: +# type: string + +# DateTimeJsonResponse: +# required: +# - abbreviation +# - client_ip +# - datetime +# - day_of_week +# - day_of_year +# - dst +# - dst_offset +# - timezone +# - unixtime +# - utc_datetime +# - utc_offset +# - week_number +# properties: +# abbreviation: +# type: string +# description: >- +# the abbreviated name of the timezone +# client_ip: +# type: string +# description: >- +# the IP of the client making the request +# datetime: +# type: string +# description: >- +# an ISO8601-valid string representing +# the current, local date/time +# day_of_week: +# type: integer +# description: >- +# current day number of the week, where sunday is 0 +# day_of_year: +# type: integer +# description: >- +# ordinal date of the current year +# dst: +# type: boolean +# description: >- +# flag indicating whether the local +# time is in daylight savings +# dst_from: +# type: string +# description: >- +# an ISO8601-valid string representing +# the datetime when daylight savings +# started for this timezone +# dst_offset: +# type: integer +# description: >- +# the difference in seconds between the current local +# time and daylight saving time for the location +# dst_until: +# type: string +# description: >- +# an ISO8601-valid string representing +# the datetime when daylight savings +# will end for this timezone +# raw_offset: +# type: integer +# description: >- +# the difference in seconds between the current local time +# and the time in UTC, excluding any daylight saving difference +# (see dst_offset) +# timezone: +# type: string +# description: >- +# timezone in `Area/Location` or +# `Area/Location/Region` format +# unixtime: +# type: integer +# description: >- +# number of seconds since the Epoch +# utc_datetime: +# type: string +# description: >- +# an ISO8601-valid string representing +# the current date/time in UTC +# utc_offset: +# type: string +# description: >- +# an ISO8601-valid string representing +# the offset from UTC +# week_number: +# type: integer +# description: >- +# the current week number + +# ErrorJsonResponse: +# required: +# - error +# properties: +# error: +# type: string +# description: >- +# details about the error encountered +# ''' + +# def exec_tool(self, message: UserMessage) -> UserMessage: +# match = re.search(r'{[\s\S]*}', message.content) +# if match: +# params = json.loads(match.group()) +# url = params["url"] +# if "params" in params: +# url = url.format(**params["params"]) +# res = requests.get(url) +# response_msg = UserMessage(content=f"API response: {res.text}") +# else: +# raise "ERROR" +# return response_msg \ No newline at end of file diff --git 
a/dev_opsgpt/utils/common_utils.py b/dev_opsgpt/utils/common_utils.py index 01fc2f2..64d2fa6 100644 --- a/dev_opsgpt/utils/common_utils.py +++ b/dev_opsgpt/utils/common_utils.py @@ -2,6 +2,11 @@ import textwrap, time, copy, random, hashlib, json, os from datetime import datetime, timedelta from functools import wraps from loguru import logger +from typing import * +from pathlib import Path +from io import BytesIO +from fastapi import Body, File, Form, Body, Query, UploadFile +from tempfile import SpooledTemporaryFile @@ -65,3 +70,23 @@ def save_to_json_file(data, filename): with open(filename, "w", encoding="utf-8") as f: json.dump(data, f, indent=2, ensure_ascii=False) + + +def file_normalize(file: Union[str, Path, bytes], filename=None): + logger.debug(f"{file}") + if isinstance(file, bytes): # raw bytes + file = BytesIO(file) + elif hasattr(file, "read"): # a file io like object + filename = filename or file.name + else: # a local path + file = Path(file).absolute().open("rb") + logger.debug(file) + filename = filename or file.name + return file, filename + + +def get_uploadfile(file: Union[str, Path, bytes], filename=None) -> UploadFile: + temp_file = SpooledTemporaryFile(max_size=10 * 1024 * 1024) + temp_file.write(file.read()) + temp_file.seek(0) + return UploadFile(file=temp_file, filename=filename) \ No newline at end of file diff --git a/dev_opsgpt/utils/path_utils.py b/dev_opsgpt/utils/path_utils.py index 67daf6b..26e766a 100644 --- a/dev_opsgpt/utils/path_utils.py +++ b/dev_opsgpt/utils/path_utils.py @@ -29,7 +29,7 @@ LOADER2EXT_DICT = {"UnstructuredFileLoader": ['.eml', '.html', '.md', '.msg', '. "TextLoader": ['.txt'], "PythonLoader": ['.py'], "JSONLoader": ['.json'], - "JSONLLoader": ['.jsonl'], + "JSONLLoader": ['.jsonl'] } EXT2LOADER_DICT = {ext: LOADERNAME2LOADER_DICT[k] for k, exts in LOADER2EXT_DICT.items() for ext in exts} @@ -61,8 +61,10 @@ def list_kbs_from_folder(): def list_docs_from_folder(kb_name: str): doc_path = get_doc_path(kb_name) - return [file for file in os.listdir(doc_path) - if os.path.isfile(os.path.join(doc_path, file))] + if os.path.exists(doc_path): + return [file for file in os.listdir(doc_path) + if os.path.isfile(os.path.join(doc_path, file))] + return [] def get_LoaderClass(file_extension): for LoaderClass, extensions in LOADER2EXT_DICT.items(): diff --git a/dev_opsgpt/webui/__init__.py b/dev_opsgpt/webui/__init__.py index eaa52a5..d5bd734 100644 --- a/dev_opsgpt/webui/__init__.py +++ b/dev_opsgpt/webui/__init__.py @@ -1,9 +1,10 @@ from .dialogue import dialogue_page, chat_box from .document import knowledge_page +from .code import code_page from .prompt import prompt_page from .utils import ApiRequest __all__ = [ "dialogue_page", "chat_box", "prompt_page", "knowledge_page", - "ApiRequest" + "ApiRequest", "code_page" ] \ No newline at end of file diff --git a/dev_opsgpt/webui/code.py b/dev_opsgpt/webui/code.py new file mode 100644 index 0000000..93b2e04 --- /dev/null +++ b/dev_opsgpt/webui/code.py @@ -0,0 +1,140 @@ +# encoding: utf-8 +''' +@author: 温进 +@file: code.py.py +@time: 2023/10/23 下午5:31 +@desc: +''' + +import streamlit as st +import os +import time +import traceback +from typing import Literal, Dict, Tuple +from st_aggrid import AgGrid, JsCode +from st_aggrid.grid_options_builder import GridOptionsBuilder +import pandas as pd + +from configs.model_config import embedding_model_dict, kbs_config, EMBEDDING_MODEL, DEFAULT_VS_TYPE, WEB_CRAWL_PATH +from .utils import * +from dev_opsgpt.utils.path_utils import * +from 
dev_opsgpt.service.service_factory import get_cb_details, get_cb_details_by_cb_name +from dev_opsgpt.orm import table_init + +# SENTENCE_SIZE = 100 + +cell_renderer = JsCode("""function(params) {if(params.value==true){return '✓'}else{return '×'}}""") + + +def file_exists(cb: str, selected_rows: List) -> Tuple[str, str]: + ''' + check whether the dir exist in local file + return the dir's name and path if it exists. + ''' + if selected_rows: + file_name = selected_rows[0]["code_name"] + file_path = get_file_path(cb, file_name) + if os.path.isfile(file_path): + return file_name, file_path + return "", "" + + +def code_page(api: ApiRequest): + # 判断表是否存在并进行初始化 + table_init() + + try: + logger.info(get_cb_details()) + cb_list = {x["code_name"]: x for x in get_cb_details()} + except Exception as e: + logger.exception(e) + st.error("获取知识库信息错误,请检查是否已按照 `README.md` 中 `4 知识库初始化与迁移` 步骤完成初始化或迁移,或是否为数据库连接错误。") + st.stop() + cb_names = list(cb_list.keys()) + + if "selected_cb_name" in st.session_state and st.session_state["selected_cb_name"] in cb_names: + selected_cb_index = cb_names.index(st.session_state["selected_cb_name"]) + else: + selected_cb_index = 0 + + def format_selected_cb(cb_name: str) -> str: + if cb := cb_list.get(cb_name): + return f"{cb_name} ({cb['code_path']})" + else: + return cb_name + + selected_cb = st.selectbox( + "请选择或新建代码知识库:", + cb_names + ["新建代码知识库"], + format_func=format_selected_cb, + index=selected_cb_index + ) + + if selected_cb == "新建代码知识库": + with st.form("新建代码知识库"): + + cb_name = st.text_input( + "新建代码知识库名称", + placeholder="新代码知识库名称,不支持中文命名", + key="cb_name", + ) + + file = st.file_uploader("上传代码库 zip 文件", + ['.zip'], + accept_multiple_files=False, + ) + + submit_create_kb = st.form_submit_button( + "新建", + use_container_width=True, + ) + + if submit_create_kb: + # unzip file + logger.info('files={}'.format(file)) + + if not cb_name or not cb_name.strip(): + st.error(f"知识库名称不能为空!") + elif cb_name in cb_list: + st.error(f"名为 {cb_name} 的知识库已经存在!") + elif file.type not in ['application/zip', 'application/x-zip-compressed']: + logger.error(f"{file.type}") + st.error('请先上传 zip 文件,再新建代码知识库') + else: + ret = api.create_code_base( + cb_name, + file, + no_remote_api=True + ) + st.toast(ret.get("msg", " ")) + st.session_state["selected_cb_name"] = cb_name + st.experimental_rerun() + elif selected_cb: + cb = selected_cb + + # 知识库详情 + cb_details = get_cb_details_by_cb_name(cb) + if not len(cb_details): + st.info(f"代码知识库 `{cb}` 中暂无信息") + else: + logger.info(cb_details) + st.write(f"代码知识库 `{cb}` 加载成功,中含有以下信息:") + + st.write('代码知识库 `{}` 代码文件数=`{}`'.format(cb_details['code_name'], + cb_details.get('code_file_num', 'unknown'))) + + st.write('代码知识库 `{}` 知识图谱节点数=`{}`'.format(cb_details['code_name'], cb_details['code_graph_node_num'])) + + st.divider() + + cols = st.columns(3) + + if cols[2].button( + "删除知识库", + use_container_width=True, + ): + ret = api.delete_code_base(cb, + no_remote_api=True) + st.toast(ret.get("msg", " ")) + time.sleep(1) + st.experimental_rerun() diff --git a/dev_opsgpt/webui/dialogue.py b/dev_opsgpt/webui/dialogue.py index 07e1c20..bb45af7 100644 --- a/dev_opsgpt/webui/dialogue.py +++ b/dev_opsgpt/webui/dialogue.py @@ -2,10 +2,13 @@ import streamlit as st from streamlit_chatbox import * from typing import List, Dict from datetime import datetime - +from random import randint from .utils import * + from dev_opsgpt.utils import * +from dev_opsgpt.tools import TOOL_SETS from dev_opsgpt.chat.search_chat import SEARCH_ENGINES +from dev_opsgpt.connector import 
PHASE_LIST, PHASE_CONFIGS chat_box = ChatBox( assistant_avatar="../sources/imgs/devops-chatbot2.png" @@ -55,7 +58,11 @@ def dialogue_page(api: ApiRequest): dialogue_mode = st.selectbox("请选择对话模式", ["LLM 对话", "知识库问答", + "代码知识库问答", "搜索引擎问答", + "工具问答", + "数据分析", + "Agent问答" ], on_change=on_mode_change, key="dialogue_mode", @@ -67,6 +74,10 @@ def dialogue_page(api: ApiRequest): def on_kb_change(): st.toast(f"已加载知识库: {st.session_state.selected_kb}") + def on_cb_change(): + st.toast(f"已加载代码知识库: {st.session_state.selected_cb}") + + not_agent_qa = True if dialogue_mode == "知识库问答": with st.expander("知识库配置", True): kb_list = api.list_knowledge_bases(no_remote_api=True) @@ -80,13 +91,142 @@ def dialogue_page(api: ApiRequest): score_threshold = st.number_input("知识匹配分数阈值:", 0.0, float(SCORE_THRESHOLD), float(SCORE_THRESHOLD), float(SCORE_THRESHOLD//100)) # chunk_content = st.checkbox("关联上下文", False, disabled=True) # chunk_size = st.slider("关联长度:", 0, 500, 250, disabled=True) + elif dialogue_mode == '代码知识库问答': + with st.expander('代码知识库配置', True): + cb_list = api.list_cb(no_remote_api=True) + logger.debug('codebase_list={}'.format(cb_list)) + selected_cb = st.selectbox( + "请选择代码知识库:", + cb_list, + on_change=on_cb_change, + key="selected_cb", + ) + st.toast(f"已加载代码知识库: {st.session_state.selected_cb}") + cb_code_limit = st.number_input("匹配代码条数:", 1, 20, 1) elif dialogue_mode == "搜索引擎问答": with st.expander("搜索引擎配置", True): search_engine = st.selectbox("请选择搜索引擎", SEARCH_ENGINES.keys(), 0) se_top_k = st.number_input("匹配搜索结果条数:", 1, 20, 3) + elif dialogue_mode == "工具问答": + with st.expander("工具军火库", True): + tool_selects = st.multiselect( + '请选择待使用的工具', TOOL_SETS, ["WeatherInfo"]) + + elif dialogue_mode == "数据分析": + with st.expander("沙盒文件管理", False): + def _upload(upload_file): + res = api.web_sd_upload(upload_file) + logger.debug(res) + if res["msg"]: + st.success("上文件传成功") + else: + st.toast("文件上传失败") - code_interpreter_on = st.toggle("开启代码解释器") - code_exec_on = st.toggle("自动执行代码") + interpreter_file = st.file_uploader( + "上传沙盒文件", + [i for ls in LOADER2EXT_DICT.values() for i in ls], + accept_multiple_files=False, + key="interpreter_file", + ) + + if interpreter_file: + _upload(interpreter_file) + interpreter_file = None + # + files = api.web_sd_list_files() + files = files["data"] + download_file = st.selectbox("选择要处理文件", files, + key="download_file",) + + cols = st.columns(2) + file_url, file_name = api.web_sd_download(download_file) + cols[0].download_button("点击下载", file_url, file_name) + if cols[1].button("点击删除", ): + api.web_sd_delete(download_file) + + elif dialogue_mode == "Agent问答": + not_agent_qa = False + with st.expander("Phase管理", True): + choose_phase = st.selectbox( + '请选择待使用的执行链路', PHASE_LIST, 0) + + is_detailed = st.toggle("返回明细的Agent交互", False) + tool_using_on = st.toggle("开启工具使用", PHASE_CONFIGS[choose_phase]["do_using_tool"]) + tool_selects = [] + if tool_using_on: + with st.expander("工具军火库", True): + tool_selects = st.multiselect( + '请选择待使用的工具', TOOL_SETS, ["WeatherInfo"]) + + search_on = st.toggle("开启搜索增强", PHASE_CONFIGS[choose_phase]["do_search"]) + search_engine, top_k = None, 3 + if search_on: + with st.expander("搜索引擎配置", True): + search_engine = st.selectbox("请选择搜索引擎", SEARCH_ENGINES.keys(), 0) + top_k = st.number_input("匹配搜索结果条数:", 1, 20, 3) + + doc_retrieval_on = st.toggle("开启知识库检索增强", PHASE_CONFIGS[choose_phase]["do_doc_retrieval"]) + selected_kb, top_k, score_threshold = None, 3, 1.0 + if doc_retrieval_on: + with st.expander("知识库配置", True): + kb_list = 
api.list_knowledge_bases(no_remote_api=True) + selected_kb = st.selectbox( + "请选择知识库:", + kb_list, + on_change=on_kb_change, + key="selected_kb", + ) + top_k = st.number_input("匹配知识条数:", 1, 20, 3) + score_threshold = st.number_input("知识匹配分数阈值:", 0.0, float(SCORE_THRESHOLD), float(SCORE_THRESHOLD), float(SCORE_THRESHOLD//100)) + + code_retrieval_on = st.toggle("开启代码检索增强", PHASE_CONFIGS[choose_phase]["do_code_retrieval"]) + selected_cb, top_k = None, 1 + if code_retrieval_on: + with st.expander('代码知识库配置', True): + cb_list = api.list_cb(no_remote_api=True) + logger.debug('codebase_list={}'.format(cb_list)) + selected_cb = st.selectbox( + "请选择代码知识库:", + cb_list, + on_change=on_cb_change, + key="selected_cb", + ) + st.toast(f"已加载代码知识库: {st.session_state.selected_cb}") + top_k = st.number_input("匹配代码条数:", 1, 20, 1) + + with st.expander("沙盒文件管理", False): + def _upload(upload_file): + res = api.web_sd_upload(upload_file) + logger.debug(res) + if res["msg"]: + st.success("上文件传成功") + else: + st.toast("文件上传失败") + + interpreter_file = st.file_uploader( + "上传沙盒文件", + [i for ls in LOADER2EXT_DICT.values() for i in ls], + accept_multiple_files=False, + key="interpreter_file", + ) + + if interpreter_file: + _upload(interpreter_file) + interpreter_file = None + # + files = api.web_sd_list_files() + files = files["data"] + download_file = st.selectbox("选择要处理文件", files, + key="download_file",) + + cols = st.columns(2) + file_url, file_name = api.web_sd_download(download_file) + cols[0].download_button("点击下载", file_url, file_name) + if cols[1].button("点击删除", ): + api.web_sd_delete(download_file) + + code_interpreter_on = st.toggle("开启代码解释器") and not_agent_qa + code_exec_on = st.toggle("自动执行代码") and not_agent_qa # Display chat messages from history on app rerun @@ -102,7 +242,97 @@ def dialogue_page(api: ApiRequest): if dialogue_mode == "LLM 对话": chat_box.ai_say("正在思考...") text = "" - r = api.chat_chat(prompt, history) + r = api.chat_chat(prompt, history, no_remote_api=True) + for t in r: + if error_msg := check_error_msg(t): # check whether error occured + st.error(error_msg) + break + text += t["answer"] + chat_box.update_msg(text) + logger.debug(f"text: {text}") + chat_box.update_msg(text, streaming=False) # 更新最终的字符串,去除光标 + # 判断是否存在代码, 并提高编辑功能,执行功能 + code_text = api.codebox.decode_code_from_text(text) + GLOBAL_EXE_CODE_TEXT = code_text + if code_text and code_exec_on: + codebox_res = api.codebox_chat("```"+code_text+"```", do_code_exe=True) + elif dialogue_mode == "Agent问答": + display_infos = [f"正在思考..."] + if search_on: + display_infos.append(Markdown("...", in_expander=True, title="网络搜索结果")) + if doc_retrieval_on: + display_infos.append(Markdown("...", in_expander=True, title="知识库匹配结果")) + chat_box.ai_say(display_infos) + + if 'history_node_list' in st.session_state: + history_node_list: List[str] = st.session_state['history_node_list'] + else: + history_node_list: List[str] = [] + + input_kargs = {"query": prompt, + "phase_name": choose_phase, + "history": history, + "doc_engine_name": selected_kb, + "search_engine_name": search_engine, + "code_engine_name": selected_cb, + "top_k": top_k, + "score_threshold": score_threshold, + "do_search": search_on, + "do_doc_retrieval": doc_retrieval_on, + "do_code_retrieval": code_retrieval_on, + "do_tool_retrieval": False, + "custom_phase_configs": {}, + "custom_chain_configs": {}, + "custom_role_configs": {}, + "choose_tools": tool_selects, + "history_node_list": history_node_list, + "isDetailed": is_detailed, + } + text = "" + d = {"docs": []} + for idx_count, d 
in enumerate(api.agent_chat(**input_kargs)): + if error_msg := check_error_msg(d): # check whether error occured + st.error(error_msg) + text += d["answer"] + if idx_count%20 == 0: + chat_box.update_msg(text, element_index=0) + + for k, v in d["figures"].items(): + logger.debug(f"figure: {k}") + if k in text: + img_html = "\n\n".format(v) + text = text.replace(k, img_html).replace(".png", "") + chat_box.update_msg(text, element_index=0, streaming=False, state="complete") # 更新最终的字符串,去除光标 + if search_on: + chat_box.update_msg("搜索匹配结果:\n\n" + "\n\n".join(d["search_docs"]), element_index=search_on, streaming=False, state="complete") + if doc_retrieval_on: + chat_box.update_msg("知识库匹配结果:\n\n" + "\n\n".join(d["db_docs"]), element_index=search_on+doc_retrieval_on, streaming=False, state="complete") + + history_node_list.extend([node[0] for node in d.get("related_nodes", [])]) + history_node_list = list(set(history_node_list)) + st.session_state['history_node_list'] = history_node_list + + elif dialogue_mode == "工具问答": + chat_box.ai_say("正在思考...") + text = "" + r = api.tool_chat(prompt, history, tool_sets=tool_selects) + for t in r: + if error_msg := check_error_msg(t): # check whether error occured + st.error(error_msg) + break + text += t["answer"] + chat_box.update_msg(text) + logger.debug(f"text: {text}") + chat_box.update_msg(text, streaming=False) # 更新最终的字符串,去除光标 + # 判断是否存在代码, 并提高编辑功能,执行功能 + code_text = api.codebox.decode_code_from_text(text) + GLOBAL_EXE_CODE_TEXT = code_text + if code_text and code_exec_on: + codebox_res = api.codebox_chat("```"+code_text+"```", do_code_exe=True) + elif dialogue_mode == "数据分析": + chat_box.ai_say("正在思考...") + text = "" + r = api.data_chat(prompt, history) for t in r: if error_msg := check_error_msg(t): # check whether error occured st.error(error_msg) @@ -116,7 +346,6 @@ def dialogue_page(api: ApiRequest): GLOBAL_EXE_CODE_TEXT = code_text if code_text and code_exec_on: codebox_res = api.codebox_chat("```"+code_text+"```", do_code_exe=True) - elif dialogue_mode == "知识库问答": history = get_messages_history(history_len) chat_box.ai_say([ @@ -124,6 +353,7 @@ def dialogue_page(api: ApiRequest): Markdown("...", in_expander=True, title="知识库匹配结果"), ]) text = "" + d = {"docs": []} for idx_count, d in enumerate(api.knowledge_base_chat(prompt, selected_kb, kb_top_k, score_threshold, history)): if error_msg := check_error_msg(d): # check whether error occured st.error(error_msg) @@ -138,6 +368,33 @@ def dialogue_page(api: ApiRequest): GLOBAL_EXE_CODE_TEXT = code_text if code_text and code_exec_on: codebox_res = api.codebox_chat("```"+code_text+"```", do_code_exe=True) + elif dialogue_mode == '代码知识库问答': + logger.info('prompt={}'.format(prompt)) + logger.info('history={}'.format(history)) + if 'history_node_list' in st.session_state: + api.codeChat.history_node_list = st.session_state['history_node_list'] + + chat_box.ai_say([ + f"正在查询代码知识库 `{selected_cb}` ...", + Markdown("...", in_expander=True, title="代码库匹配结果"), + ]) + text = "" + d = {"codes": []} + + for idx_count, d in enumerate(api.code_base_chat(query=prompt, code_base_name=selected_cb, + code_limit=cb_code_limit, history=history, + no_remote_api=True)): + if error_msg := check_error_msg(d): + st.error(error_msg) + text += d["answer"] + if idx_count % 10 == 0: + chat_box.update_msg(text, element_index=0) + chat_box.update_msg(text, element_index=0, streaming=False) # 更新最终的字符串,去除光标 + chat_box.update_msg("\n".join(d["codes"]), element_index=1, streaming=False, state="complete") + + # session state update + 
st.session_state['history_node_list'] = api.codeChat.history_node_list + elif dialogue_mode == "搜索引擎问答": chat_box.ai_say([ f"正在执行 `{search_engine}` 搜索...", @@ -145,7 +402,7 @@ def dialogue_page(api: ApiRequest): ]) text = "" d = {"docs": []} - for d in api.search_engine_chat(prompt, search_engine, se_top_k): + for idx_count, d in enumerate(api.search_engine_chat(prompt, search_engine, se_top_k)): if error_msg := check_error_msg(d): # check whether error occured st.error(error_msg) text += d["answer"] @@ -194,7 +451,7 @@ def dialogue_page(api: ApiRequest): img_html = "".format( codebox_res.code_exe_response ) - chat_box.update_msg(base_text + img_html, streaming=False, state="complete") + chat_box.update_msg(img_html, streaming=False, state="complete") else: chat_box.update_msg('```\n'+code_text+'\n```'+"\n\n"+'```\n'+codebox_res.code_exe_response+'\n```', streaming=False, state="complete") diff --git a/dev_opsgpt/webui/document.py b/dev_opsgpt/webui/document.py index d8df1de..ac0f774 100644 --- a/dev_opsgpt/webui/document.py +++ b/dev_opsgpt/webui/document.py @@ -137,7 +137,24 @@ def knowledge_page(api: ApiRequest): [i for ls in LOADER2EXT_DICT.values() for i in ls], accept_multiple_files=True, ) - + + if st.button( + "添加文件到知识库", + # help="请先上传文件,再点击添加", + # use_container_width=True, + disabled=len(files) == 0, + ): + data = [{"file": f, "knowledge_base_name": kb, "not_refresh_vs_cache": True} for f in files] + data[-1]["not_refresh_vs_cache"]=False + for k in data: + pass + ret = api.upload_kb_doc(**k) + if msg := check_success_msg(ret): + st.toast(msg, icon="✔") + elif msg := check_error_msg(ret): + st.toast(msg, icon="✖") + st.session_state.files = [] + base_url = st.text_input( "待获取内容的URL地址", placeholder="请填写正确可打开的URL地址", @@ -187,22 +204,6 @@ def knowledge_page(api: ApiRequest): if os.path.exists(html_path): os.remove(html_path) - if st.button( - "添加文件到知识库", - # help="请先上传文件,再点击添加", - # use_container_width=True, - disabled=len(files) == 0, - ): - data = [{"file": f, "knowledge_base_name": kb, "not_refresh_vs_cache": True} for f in files] - data[-1]["not_refresh_vs_cache"]=False - for k in data: - ret = api.upload_kb_doc(**k) - if msg := check_success_msg(ret): - st.toast(msg, icon="✔") - elif msg := check_error_msg(ret): - st.toast(msg, icon="✖") - st.session_state.files = [] - st.divider() # 知识库详情 diff --git a/dev_opsgpt/webui/utils.py b/dev_opsgpt/webui/utils.py index 48b5a65..5868fd3 100644 --- a/dev_opsgpt/webui/utils.py +++ b/dev_opsgpt/webui/utils.py @@ -10,11 +10,13 @@ import json import nltk import traceback from loguru import logger +import zipfile from configs.model_config import ( EMBEDDING_MODEL, DEFAULT_VS_TYPE, KB_ROOT_PATH, + CB_ROOT_PATH, LLM_MODEL, SCORE_THRESHOLD, VECTOR_SEARCH_TOP_K, @@ -27,8 +29,10 @@ from configs.server_config import SANDBOX_SERVER from dev_opsgpt.utils.server_utils import run_async, iter_over_async from dev_opsgpt.service.kb_api import * -from dev_opsgpt.chat import LLMChat, SearchChat, KnowledgeChat +from dev_opsgpt.service.cb_api import * +from dev_opsgpt.chat import LLMChat, SearchChat, KnowledgeChat, ToolChat, DataChat, CodeChat, AgentChat from dev_opsgpt.sandbox import PyCodeBox, CodeBoxResponse +from dev_opsgpt.utils.common_utils import file_normalize, get_uploadfile from web_crawler.utils.WebCrawler import WebCrawler @@ -58,15 +62,22 @@ class ApiRequest: def __init__( self, base_url: str = "http://127.0.0.1:7861", + sandbox_file_url: str = "http://127.0.0.1:7862", timeout: float = 60.0, no_remote_api: bool = False, # call api view function 
directly ): self.base_url = base_url + self.sandbox_file_url = sandbox_file_url self.timeout = timeout self.no_remote_api = no_remote_api self.llmChat = LLMChat() self.searchChat = SearchChat() self.knowledgeChat = KnowledgeChat() + self.toolChat = ToolChat() + self.dataChat = DataChat() + self.codeChat = CodeChat() + + self.agentChat = AgentChat() self.codebox = PyCodeBox( remote_url=SANDBOX_SERVER["url"], remote_ip=SANDBOX_SERVER["host"], # "http://localhost", @@ -83,7 +94,8 @@ class ApiRequest: if (not url.startswith("http") and self.base_url ): - part1 = self.base_url.strip(" /") + part1 = self.sandbox_file_url.strip(" /") \ + if "sdfiles" in url else self.base_url.strip(" /") part2 = url.strip(" /") return f"{part1}/{part2}" else: @@ -331,7 +343,7 @@ class ApiRequest: self, query: str, search_engine_name: str, - top_k: int = SEARCH_ENGINE_TOP_K, + code_limit: int, stream: bool = True, no_remote_api: bool = None, ): @@ -344,7 +356,7 @@ class ApiRequest: data = { "query": query, "engine_name": search_engine_name, - "top_k": top_k, + "code_limit": code_limit, "history": [], "stream": stream, } @@ -360,7 +372,157 @@ class ApiRequest: ) return self._httpx_stream2generator(response, as_json=True) - # 知识库相关操作 + def tool_chat( + self, + query: str, + history: List[Dict] = [], + tool_sets: List[str] = [], + stream: bool = True, + no_remote_api: bool = None, + ): + ''' + 对应api.py/chat/chat接口 + ''' + if no_remote_api is None: + no_remote_api = self.no_remote_api + + data = { + "query": query, + "history": history, + "tool_sets": tool_sets, + "stream": stream, + } + + if no_remote_api: + response = self.toolChat.chat(**data) + return self._fastapi_stream2generator(response, as_json=True) + else: + response = self.post("/chat/tool_chat", json=data, stream=True) + return self._httpx_stream2generator(response) + + def data_chat( + self, + query: str, + history: List[Dict] = [], + stream: bool = True, + no_remote_api: bool = None, + ): + ''' + 对应api.py/chat/chat接口 + ''' + if no_remote_api is None: + no_remote_api = self.no_remote_api + + data = { + "query": query, + "history": history, + "stream": stream, + } + + if no_remote_api: + response = self.dataChat.chat(**data) + return self._fastapi_stream2generator(response, as_json=True) + else: + response = self.post("/chat/data_chat", json=data, stream=True) + return self._httpx_stream2generator(response) + + def code_base_chat( + self, + query: str, + code_base_name: str, + code_limit: int = 1, + history: List[Dict] = [], + stream: bool = True, + no_remote_api: bool = None, + ): + ''' + 对应api.py/chat/knowledge_base_chat接口 + ''' + if no_remote_api is None: + no_remote_api = self.no_remote_api + + data = { + "query": query, + "history": history, + "engine_name": code_base_name, + "code_limit": code_limit, + "stream": stream, + "local_doc_url": no_remote_api, + } + logger.info('data={}'.format(data)) + + if no_remote_api: + logger.info('history_node_list before={}'.format(self.codeChat.history_node_list)) + response = self.codeChat.chat(**data) + logger.info('history_node_list after={}'.format(self.codeChat.history_node_list)) + return self._fastapi_stream2generator(response, as_json=True) + else: + response = self.post( + "/chat/code_chat", + json=data, + stream=True, + ) + return self._httpx_stream2generator(response, as_json=True) + + def agent_chat( + self, + query: str, + phase_name: str, + doc_engine_name: str, + code_engine_name: str, + search_engine_name: str, + top_k: int = 3, + score_threshold: float = 1.0, + history: List[Dict] = [], + stream: 
bool = True, + local_doc_url: bool = False, + do_search: bool = False, + do_doc_retrieval: bool = False, + do_code_retrieval: bool = False, + do_tool_retrieval: bool = False, + choose_tools: List[str] = [], + custom_phase_configs = {}, + custom_chain_configs = {}, + custom_role_configs = {}, + no_remote_api: bool = None, + history_node_list: List[str] = [], + isDetailed: bool = False + ): + ''' + 对应api.py/chat/chat接口 + ''' + if no_remote_api is None: + no_remote_api = self.no_remote_api + + data = { + "query": query, + "phase_name": phase_name, + "chain_name": "", + "history": history, + "doc_engine_name": doc_engine_name, + "code_engine_name": code_engine_name, + "search_engine_name": search_engine_name, + "top_k": top_k, + "score_threshold": score_threshold, + "stream": stream, + "local_doc_url": local_doc_url, + "do_search": do_search, + "do_doc_retrieval": do_doc_retrieval, + "do_code_retrieval": do_code_retrieval, + "do_tool_retrieval": do_tool_retrieval, + "custom_phase_configs": custom_phase_configs, + "custom_chain_configs": custom_phase_configs, + "custom_role_configs": custom_role_configs, + "choose_tools": choose_tools, + "history_node_list": history_node_list, + "isDetailed": isDetailed + } + if no_remote_api: + response = self.agentChat.chat(**data) + return self._fastapi_stream2generator(response, as_json=True) + else: + response = self.post("/chat/data_chat", json=data, stream=True) + return self._httpx_stream2generator(response) def _check_httpx_json_response( self, @@ -377,6 +539,21 @@ class ApiRequest: logger.error(e) return {"code": 500, "msg": errorMsg or str(e)} + def _check_httpx_file_response( + self, + response: httpx.Response, + errorMsg: str = f"无法连接API服务器,请确认已执行python server\\api.py", + ) -> Dict: + ''' + check whether httpx returns correct data with normal Response. 
+ error in api with streaming support was checked in _httpx_stream2enerator + ''' + try: + return response.content + except Exception as e: + logger.error(e) + return {"code": 500, "msg": errorMsg or str(e)} + def list_knowledge_bases( self, no_remote_api: bool = None, @@ -662,6 +839,122 @@ class ApiRequest: else: raise Exception("not impletenion") + def web_sd_upload(self, file: str = None, filename: str = None): + '''对应file_service/sd_upload_file''' + file, filename = file_normalize(file, filename) + response = self.post( + "/sdfiles/upload", + files={"file": (filename, file)}, + ) + return self._check_httpx_json_response(response) + + def web_sd_download(self, filename: str, save_filename: str = None): + '''对应file_service/sd_download_file''' + save_filename = save_filename or filename + # response = self.get( + # f"/sdfiles/download", + # params={"filename": filename, "save_filename": save_filename} + # ) + key_value_str = f"filename={filename}&save_filename={save_filename}" + return self._parse_url(f"/sdfiles/download?{key_value_str}"), save_filename + + def web_sd_delete(self, filename: str): + '''对应file_service/sd_delete_file''' + response = self.get( + f"/sdfiles/delete", + params={"filename": filename} + ) + return self._check_httpx_json_response(response) + + def web_sd_list_files(self, ): + '''对应对应file_service/sd_list_files接口''' + response = self.get("/sdfiles/list",) + return self._check_httpx_json_response(response) + + # code base 相关操作 + def create_code_base(self, cb_name, zip_file, no_remote_api: bool = None,): + ''' + 创建 code_base + @param cb_name: + @param zip_path: + @return: + ''' + if no_remote_api is None: + no_remote_api = self.no_remote_api + + # mkdir + cb_root_path = CB_ROOT_PATH + mkdir_dir = [ + cb_root_path, + cb_root_path + os.sep + cb_name, + raw_code_path := cb_root_path + os.sep + cb_name + os.sep + 'raw_code' + ] + for dir in mkdir_dir: + if not os.path.exists(dir): + os.makedirs(dir) + + # unzip + with zipfile.ZipFile(zip_file, 'r') as z: + z.extractall(raw_code_path) + + data = { + "cb_name": cb_name, + "code_path": raw_code_path + } + logger.info('create cb data={}'.format(data)) + + if no_remote_api: + response = run_async(create_cb(**data)) + return response.dict() + else: + response = self.post( + "/code_base/create_code_base", + json=data, + ) + logger.info('response={}'.format(response.json())) + return self._check_httpx_json_response(response) + + def delete_code_base(self, cb_name: str, no_remote_api: bool = None,): + ''' + 删除 code_base + @param cb_name: + @return: + ''' + if no_remote_api is None: + no_remote_api = self.no_remote_api + + data = { + "cb_name": cb_name + } + + if no_remote_api: + response = run_async(delete_cb(**data)) + return response.dict() + else: + response = self.post( + "/code_base/delete_code_base", + json=cb_name + ) + logger.info(response.json()) + return self._check_httpx_json_response(response) + + def list_cb(self, no_remote_api: bool = None): + ''' + 列举 code_base + @return: + ''' + if no_remote_api is None: + no_remote_api = self.no_remote_api + + if no_remote_api: + response = run_async(list_cbs()) + return response.data + else: + response = self.get("/code_base/list_code_bases") + data = self._check_httpx_json_response(response) + return data.get("data", []) + + def check_error_msg(data: Union[str, dict, list], key: str = "errorMsg") -> str: ''' diff --git a/docker_build.sh b/docker_build.sh index 1d2a7cf..be951aa 100644 --- a/docker_build.sh +++ b/docker_build.sh @@ -1,3 +1,3 @@ #!/bin/bash -docker build -t 
devopsgpt:pypy38 . \ No newline at end of file +docker build -t devopsgpt:py39 . \ No newline at end of file diff --git a/examples/start.py b/examples/start.py new file mode 100644 index 0000000..8b3ab56 --- /dev/null +++ b/examples/start.py @@ -0,0 +1,208 @@ +import docker, sys, os, time, requests, psutil +import subprocess +from docker.types import Mount, DeviceRequest +from loguru import logger + +src_dir = os.path.join( + os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +) +sys.path.append(src_dir) + +from configs.model_config import USE_FASTCHAT +from configs.server_config import ( + NO_REMOTE_API, SANDBOX_SERVER, SANDBOX_IMAGE_NAME, SANDBOX_CONTRAINER_NAME, + WEBUI_SERVER, API_SERVER, SDFILE_API_SERVER, CONTRAINER_NAME, IMAGE_NAME, DOCKER_SERVICE, + DEFAULT_BIND_HOST, +) + + +import platform +system_name = platform.system() +USE_TTY = system_name in ["Windows"] + + +def check_process(content: str, lang: str = None, do_stop=False): + '''returns True when no matching process exists (or the match was just stopped); False when a match is found and left running''' + for process in psutil.process_iter(["pid", "name", "cmdline"]): + # check whether the process cmdline contains the given content, e.g. "jupyter" or port=xx + + # if f"port={SANDBOX_SERVER['port']}" in str(process.info["cmdline"]).lower() and \ + # "jupyter" in process.info['name'].lower(): + if content in str(process.info["cmdline"]).lower(): + logger.info(f"{content}, {process.info}") + # stop the process + if do_stop: + process.terminate() + return True + return False + return True + +def check_docker(client, container_name, do_stop=False): + '''returns True when no matching container exists (or the match was just removed); False when a match is found and left running''' + for i in client.containers.list(all=True): + if i.name == container_name: + if do_stop: + container = i + container.stop() + container.remove() + return True + return False + return True + +def start_docker(client, script_shs, ports, image_name, container_name, mounts=None, network=None): + container = client.containers.run( + image=image_name, + command="bash", + mounts=mounts, + name=container_name, + # device_requests=[DeviceRequest(count=-1, capabilities=[['gpu']])], + # network_mode="host", + ports=ports, + stdin_open=True, + detach=True, + tty=USE_TTY, + network=network, + ) + + logger.info(f"docker id: {container.id[:10]}") + + # start the notebook / service scripts + for script_sh in script_shs: + if USE_FASTCHAT and "llm_api" in script_sh: + logger.debug(script_sh) + response = container.exec_run(["sh", "-c", script_sh]) + logger.debug(response) + elif "llm_api" not in script_sh: + logger.debug(script_sh) + response = container.exec_run(["sh", "-c", script_sh]) + logger.debug(response) + return container
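The two helpers above share an easily misread convention: True means "nothing matching is running (anymore)". A short illustrative usage sketch follows; the container and process names here are placeholders, and the helpers are assumed to be imported from this start.py module.

```python
# Placeholder names; semantics follow check_process/check_docker above:
# True  -> nothing matching found (or it was just stopped),
# False -> a match exists and was left running (do_stop=False).
import docker

client = docker.from_env()
if check_docker(client, "devopsgpt_sandbox", do_stop=False):
    print("sandbox container absent: safe to create a fresh one")
if not check_process("jupyter-notebook", do_stop=False):
    print("a jupyter-notebook process is already running")
```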
+ +######################################### +############# start services ############# +######################################### + +client = docker.from_env() +network_name = 'my_network' + +def start_sandbox_service(): + networks = client.networks.list() + if any([network_name == i.attrs["Name"] for i in networks]): + network = client.networks.get(network_name) + else: + network = client.networks.create(network_name, driver='bridge') + + mount = Mount( + type='bind', + source=os.path.join(src_dir, "jupyter_work"), + target='/home/user/chatbot/jupyter_work', + read_only=False # set this to True if read-only access is needed + ) + mounts = [mount] + # the sandbox starts independently of the other services + if SANDBOX_SERVER["do_remote"]: + # start the container + logger.info("start container sandbox service") + JUPYTER_WORK_PATH = "/home/user/chatbot/jupyter_work" + script_shs = [f"cd /home/user/chatbot/jupyter_work && nohup jupyter-notebook --NotebookApp.token=mytoken --port=5050 --allow-root --ip=0.0.0.0 --notebook-dir={JUPYTER_WORK_PATH} --no-browser --ServerApp.disable_check_xsrf=True &"] + ports = {f"{SANDBOX_SERVER['docker_port']}/tcp": SANDBOX_SERVER['port']} + if check_docker(client, SANDBOX_CONTRAINER_NAME, do_stop=True): + container = start_docker(client, script_shs, ports, SANDBOX_IMAGE_NAME, SANDBOX_CONTRAINER_NAME, mounts=mounts, network=network_name) + # check whether the notebook is up + retry_nums = 3 + while retry_nums > 0: + try: + response = requests.get(f"http://localhost:{SANDBOX_SERVER['port']}", timeout=270) + if response.status_code == 200: + logger.info("container & notebook init success") + break + except requests.RequestException: + pass # the notebook may not be accepting connections yet + retry_nums -= 1 + logger.info(client.containers.list()) + logger.info("wait container running ...") + time.sleep(5) + + else: + check_docker(client, SANDBOX_CONTRAINER_NAME, do_stop=True) + logger.info("start local sandbox service") + +def start_api_service(sandbox_host=DEFAULT_BIND_HOST): + # start the service container + if DOCKER_SERVICE: + logger.info("start container service") + check_process("service/api.py", do_stop=True) + check_process("service/sdfile_api.py", do_stop=True) + check_process("service/llm_api.py", do_stop=True) + check_process("webui.py", do_stop=True) + mount = Mount( + type='bind', + source=src_dir, + target='/home/user/chatbot/', + read_only=False # set this to True if read-only access is needed + ) + mount_database = Mount( + type='bind', + source=os.path.join(src_dir, "knowledge_base"), + target='/home/user/knowledge_base/', + read_only=False # set this to True if read-only access is needed + ) + + ports = { + f"{API_SERVER['docker_port']}/tcp": API_SERVER['port'], + f"{WEBUI_SERVER['docker_port']}/tcp": WEBUI_SERVER['port'], + f"{SDFILE_API_SERVER['docker_port']}/tcp": SDFILE_API_SERVER['port'], + } + mounts = [mount, mount_database] + script_shs = [ + "mkdir -p /home/user/logs", + "pip install zdatafront-sdk-python -i https://artifacts.antgroup-inc.cn/simple", + "pip install jsonref", + "pip install javalang", + "nohup python chatbot/dev_opsgpt/service/sdfile_api.py > /home/user/logs/sdfile_api.log 2>&1 &", + f"export DUCKDUCKGO_PROXY=socks5://host.docker.internal:13659 && export SANDBOX_HOST={sandbox_host} &&\ + nohup python chatbot/dev_opsgpt/service/api.py > /home/user/logs/api.log 2>&1 &", + "nohup python chatbot/dev_opsgpt/service/llm_api.py > /home/user/logs/llm_api.log 2>&1 &", + f"export DUCKDUCKGO_PROXY=socks5://host.docker.internal:13659 && export SANDBOX_HOST={sandbox_host} &&\ + cd chatbot/examples && nohup streamlit run webui.py > /home/user/logs/start_webui.log 2>&1 &" + ] + if check_docker(client, CONTRAINER_NAME, do_stop=True): + container = start_docker(client, script_shs, ports, IMAGE_NAME, CONTRAINER_NAME, mounts, network=network_name) + + else: + logger.info("start local service") + # stop any docker services started earlier + # check_docker(client, CONTRAINER_NAME, do_stop=True, ) + + api_sh = "nohup python ../dev_opsgpt/service/api.py > ../logs/api.log 2>&1 &" + sdfile_sh = "nohup python ../dev_opsgpt/service/sdfile_api.py > ../logs/sdfile_api.log 2>&1 &" + llm_sh = "nohup python ../dev_opsgpt/service/llm_api.py > ../logs/llm_api.log 2>&1 &" + webui_sh = "streamlit run webui.py" + # + if not NO_REMOTE_API and check_process("service/api.py"): + logger.info("starting service/api.py") + subprocess.Popen(api_sh, shell=True) + # + if USE_FASTCHAT and check_process("service/llm_api.py"): + subprocess.Popen(llm_sh, shell=True) + # + if check_process("service/sdfile_api.py"): + subprocess.Popen(sdfile_sh, shell=True) + + subprocess.Popen(webui_sh, shell=True)
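The service commands above splice `export VAR=... &&` into each shell string. As a design note, docker-py's `exec_run` also accepts an `environment` mapping, which keeps the command strings cleaner; a hedged sketch follows, where `container` is assumed to be the object returned by `start_docker` and the IP and proxy values are placeholders.

```python
# Sketch: same effect as the inline `export ... &&` prefixes above,
# using exec_run's environment parameter instead.
response = container.exec_run(
    ["sh", "-c", "nohup python chatbot/dev_opsgpt/service/api.py > /home/user/logs/api.log 2>&1 &"],
    environment={
        "SANDBOX_HOST": "172.25.0.3",  # placeholder sandbox container IP
        "DUCKDUCKGO_PROXY": "socks5://host.docker.internal:13659",
    },
)
```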
+ + +if __name__ == "__main__": + start_sandbox_service() + client = docker.from_env() + containers = client.containers.list(all=True) + + sandbox_host = DEFAULT_BIND_HOST + for container in containers: + container_a_info = client.containers.get(container.id) + if container_a_info.name == SANDBOX_CONTRAINER_NAME: + container1_networks = container.attrs['NetworkSettings']['Networks'] + sandbox_host = container1_networks.get(network_name)["IPAddress"] + break + start_api_service(sandbox_host) diff --git a/examples/start_sandbox.py b/examples/start_sandbox.py index 43a46f2..4ba53ff 100644 --- a/examples/start_sandbox.py +++ b/examples/start_sandbox.py @@ -7,13 +7,13 @@ src_dir = os.path.join( ) sys.path.append(src_dir) -from configs.server_config import CONTRAINER_NAME, SANDBOX_SERVER, IMAGE_NAME +from configs.server_config import SANDBOX_SERVER, SANDBOX_IMAGE_NAME, SANDBOX_CONTRAINER_NAME if SANDBOX_SERVER["do_remote"]: client = docker.from_env() for i in client.containers.list(all=True): - if i.name == CONTRAINER_NAME: + if i.name == SANDBOX_CONTRAINER_NAME: container = i container.stop() container.remove() @@ -21,10 +21,10 @@ if SANDBOX_SERVER["do_remote"]: # start the container logger.info("start to init container & notebook") container = client.containers.run( - image=IMAGE_NAME, + image=SANDBOX_IMAGE_NAME, command="bash", - name=CONTRAINER_NAME, - ports={"5050/tcp": SANDBOX_SERVER["port"]}, + name=SANDBOX_CONTRAINER_NAME, + ports={f"{SANDBOX_SERVER['docker_port']}/tcp": SANDBOX_SERVER["port"]}, stdin_open=True, detach=True, tty=True, diff --git a/examples/start_service_docker.py b/examples/start_service_docker.py new file mode 100644 index 0000000..017ff85 --- /dev/null +++ b/examples/start_service_docker.py @@ -0,0 +1,68 @@ +import docker, sys, os, time, requests +from docker.types import Mount + +from loguru import logger + +src_dir = os.path.join( + os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +) +sys.path.append(src_dir) + +from configs.server_config import WEBUI_SERVER, API_SERVER, SDFILE_API_SERVER, CONTRAINER_NAME, IMAGE_NAME +from configs.model_config import USE_FASTCHAT + + + +logger.info(f"IMAGE_NAME: {IMAGE_NAME}, CONTRAINER_NAME: {CONTRAINER_NAME}") + + +client = docker.from_env() +for i in client.containers.list(all=True): + if i.name == CONTRAINER_NAME: + container = i + container.stop() + container.remove() + break + + + +# start the container +logger.info("start service") + +mount = Mount( + type='bind', + source=src_dir, + target='/home/user/chatbot/', + read_only=True # set this to False if write access is needed +) + +container = client.containers.run( + image=IMAGE_NAME, + command="bash", + mounts=[mount], + name=CONTRAINER_NAME, + ports={ + f"{WEBUI_SERVER['docker_port']}/tcp": WEBUI_SERVER['port'], + f"{API_SERVER['docker_port']}/tcp": API_SERVER['port'], + f"{SDFILE_API_SERVER['docker_port']}/tcp": SDFILE_API_SERVER['port'], + }, + stdin_open=True, + detach=True, + tty=True, +) + +# start the notebook +exec_command = container.exec_run("bash jupyter_start.sh") +# +exec_command = container.exec_run("cd /home/user/chatbot && nohup python dev_opsgpt/service/sdfile_api.py > /home/user/logs/sdfile_api.log &") +# +exec_command = container.exec_run("cd /home/user/chatbot && nohup python dev_opsgpt/service/api.py > /home/user/logs/api.log &") + +if USE_FASTCHAT: + # start the fastchat service + exec_command = container.exec_run("cd /home/user/chatbot && nohup python dev_opsgpt/service/llm_api.py > /home/user/logs/llm_api.log &") +# +exec_command = container.exec_run("cd /home/user/chatbot/examples && nohup bash start_webui.sh > /home/user/logs/start_webui.log &") + + +
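The `ports` mapping corrected above follows docker-py's convention of `{"<container_port>/<protocol>": <host_port>}`. A minimal self-contained sketch of that convention; the image name is borrowed from the Dockerfile earlier in this patch, and the port numbers are illustrative.

```python
# docker-py ports convention: keys are container ports, values are host ports.
import docker

client = docker.from_env()
container = client.containers.run(
    "python:3.9.18-bookworm",   # image from the Dockerfile above
    command="bash",
    ports={"8501/tcp": 18501},  # container 8501 -> host 18501 (illustrative)
    stdin_open=True,
    tty=True,
    detach=True,
)
print(container.ports)          # shows the realized port bindings
container.stop()
container.remove()
```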
diff --git a/examples/stop.py b/examples/stop.py new file mode 100644 index 0000000..536cd35 --- /dev/null +++ b/examples/stop.py @@ -0,0 +1,28 @@ +import docker, sys, os +from loguru import logger + +src_dir = os.path.join( + os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +) +sys.path.append(src_dir) + +from configs.server_config import ( + SANDBOX_CONTRAINER_NAME, CONTRAINER_NAME, SANDBOX_SERVER, DOCKER_SERVICE +) + + +from start import check_docker, check_process + +client = docker.from_env() + +# +check_docker(client, SANDBOX_CONTRAINER_NAME, do_stop=True, ) +check_process(f"port={SANDBOX_SERVER['port']}", do_stop=True) +check_process(f"port=5050", do_stop=True) + +# +check_docker(client, CONTRAINER_NAME, do_stop=True, ) +check_process("service/api.py", do_stop=True) +check_process("service/sdfile_api.py", do_stop=True) +check_process("service/llm_api.py", do_stop=True) +check_process("webui.py", do_stop=True) diff --git a/examples/webui.py b/examples/webui.py index 739453c..3ce6a9e 100644 --- a/examples/webui.py +++ b/examples/webui.py @@ -4,10 +4,13 @@ # 3. Run the API server: python server/api.py. This step can be skipped when using api = ApiRequest(no_remote_api=True). # 4. Run the web UI: streamlit run webui.py --server.port 7860 -import os, sys +import os +import sys import streamlit as st from streamlit_option_menu import option_menu +import multiprocessing + src_dir = os.path.join( os.path.dirname(os.path.dirname(os.path.abspath(__file__))) ) @@ -15,10 +18,11 @@ sys.path.append(src_dir) from dev_opsgpt.webui import * from configs import VERSION, LLM_MODEL +from configs.server_config import NO_REMOTE_API -api = ApiRequest(base_url="http://127.0.0.1:7861", no_remote_api=True) +api = ApiRequest(base_url="http://127.0.0.1:7861", no_remote_api=NO_REMOTE_API) if __name__ == "__main__": @@ -48,6 +52,10 @@ if __name__ == "__main__": "icon": "hdd-stack", "func": knowledge_page, }, + "代码知识库管理": { + "icon": "hdd-stack", + "func": code_page, + }, # "Prompt管理": { # "icon": "hdd-stack", # "func": prompt_page, diff --git a/requirements.txt b/requirements.txt index ed8e1ee..36bd661 100644 --- a/requirements.txt +++ b/requirements.txt @@ -25,8 +25,7 @@ notebook websockets fake_useragent selenium -auto-gptq==0.4.2 - +jsonref # uncomment libs if you want to use corresponding vector store # pymilvus==2.1.3 # requires milvus==2.1.3 @@ -41,3 +40,5 @@ streamlit-antd-components>=0.1.11 streamlit-chatbox>=1.1.6 streamlit-aggrid>=0.3.4.post3 httpx~=0.24.1 + +javalang==0.13.0 diff --git a/sources/docs/python_langchain_com_docs_get_started_introduction_text.jsonl b/sources/docs/python_langchain_com_docs_get_started_introduction_text.jsonl new file mode 100644 index 0000000..20da287 --- /dev/null +++ b/sources/docs/python_langchain_com_docs_get_started_introduction_text.jsonl @@ -0,0 +1 @@ +{"url": "https://python.langchain.com/docs/get_started/introduction", "host_url": "https://python.langchain.com", "title": "Introduction | 🦜️🔗 Langchain", "all_text": "\n\nIntroduction | 🦜️🔗 Langchain\n\nSkip to main content🦜️🔗 LangChainDocsUse casesIntegrationsAPICommunityChat our docsLangSmithJS/TS DocsSearchCTRLKGet startedIntroductionInstallationQuickstartLangChain Expression LanguageInterfaceHow toCookbookLangChain Expression Language (LCEL)ModulesModel I/​ORetrievalChainsMemoryAgentsCallbacksModulesGuidesMoreGet startedIntroductionOn this pageIntroductionLangChain is a framework for developing applications powered by language models.
It enables applications that:Are context-aware: connect a language model to sources of context (prompt instructions, few shot examples, content to ground its response in, etc.)Reason: rely on a language model to reason (about how to answer based on provided context, what actions to take, etc.)The main value props of LangChain are:Components: abstractions for working with language models, along with a collection of implementations for each abstraction. Components are modular and easy-to-use, whether you are using the rest of the LangChain framework or notOff-the-shelf chains: a structured assembly of components for accomplishing specific higher-level tasksOff-the-shelf chains make it easy to get started. For complex applications, components make it easy to customize existing chains and build new ones.Get started​Here’s how to install LangChain, set up your environment, and start building.We recommend following our Quickstart guide to familiarize yourself with the framework by building your first LangChain application.Note: These docs are for the LangChain Python package. For documentation on LangChain.js, the JS/TS version, head here.Modules​LangChain provides standard, extendable interfaces and external integrations for the following modules, listed from least to most complex:Model I/O​Interface with language modelsRetrieval​Interface with application-specific dataChains​Construct sequences of callsAgents​Let chains choose which tools to use given high-level directivesMemory​Persist application state between runs of a chainCallbacks​Log and stream intermediate steps of any chainExamples, ecosystem, and resources​Use cases​Walkthroughs and best-practices for common end-to-end use cases, like:Document question answeringChatbotsAnalyzing structured dataand much more...Guides​Learn best practices for developing with LangChain.Ecosystem​LangChain is part of a rich ecosystem of tools that integrate with our framework and build on top of it. Check out our growing list of integrations and dependent repos.Additional resources​Our community is full of prolific developers, creative builders, and fantastic teachers. 
Check out YouTube tutorials for great tutorials from folks in the community, and Gallery for a list of awesome LangChain projects, compiled by the folks at KyroLabs.Community​Head to the Community navigator to find places to ask questions, share feedback, meet other developers, and dream about the future of LLM’s.API reference​Head to the reference section for full documentation of all classes and methods in the LangChain Python package.PreviousGet startedNextInstallationGet startedModulesExamples, ecosystem, and resourcesUse casesGuidesEcosystemAdditional resourcesCommunityAPI referenceCommunityDiscordTwitterGitHubPythonJS/TSMoreHomepageBlogCopyright © 2023 LangChain, Inc.\n\n"} diff --git a/sources/docs/zhuanlan_zhihu_com_p_80963305_text.jsonl b/sources/docs/zhuanlan_zhihu_com_p_80963305_text.jsonl new file mode 100644 index 0000000..1d63956 --- /dev/null +++ b/sources/docs/zhuanlan_zhihu_com_p_80963305_text.jsonl @@ -0,0 +1 @@ +{"url": "https://zhuanlan.zhihu.com/p/80963305", "host_url": "https://zhuanlan.zhihu.com", "title": "【工具类】PyCharm+Anaconda+jupyter notebook +pip环境配置 - 知乎", "all_text": "\n【工具类】PyCharm+Anaconda+jupyter notebook +pip环境配置 - 知乎切换模式写文章登录/注册【工具类】PyCharm+Anaconda+jupyter notebook +pip环境配置Joe.Zhao14 人赞同了该文章Pycharm是一个很好的python的IDE,Anaconda是一个环境管理工具,可以针对不同工作配置不同的环境,如何在Pycharm中调用Anaconda中创建的环境Anaconda环境配置Anaconda 解决了官方 Python 的两大痛点第一:提供了包管理功能,解决安装第三方包经常失败第二:提供环境管理的功能,功能类似 Virtualenv,解决了多版本Python并存、切换的问题。查看Anaconda中所有的Python环境,Window环境下Anaconda Prompt中输入以下命令,其中前面有个‘*’的代表当前环境\n```code\nconda info --env\n\n# conda environments:\n#\nbase * D:\\Anaconda3\ntf D:\\Anaconda3\\envs\\tf\n```\n创建新的Python环境\n```code\nconda create --name python35 python=3.5 #代表创建一个python3.5的环境,我们把它命名为python35\n```\n激活进入创建的环境\n```code\nconda activate python35\n```\n在当前环境中安装package,可以使用pip,还可以用conda\n```code\npip install numpy\nconda install numpy\n```\n退出当前环境,回到base环境\n```code\nconda deactivate\n```\n删除创建的环境,conda创建的环境会在安装目录Anaconda3\\envs\\下面,每一个环境对应一个文件夹,当删除环境的时候,响应的文件夹也会被删除掉\n```code\nconda env remove --name python35 --all\nconda remove --name myenv --all\n```\nconda源头\n```code\nconda config --show channels\nchannels:\n- https://pypi.doubanio.com/simple/\n- defaults\n```\n添加新源\n```code\nconda config --add channels https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free/\n```\n删除源\n```code\nconda config --remove channels https://pypi.doubanio.com/simple/\n```\n查看安装package\n```code\nconda list\n\n```\nPycharm 使用Anaconda创建的环境pycharm工程目录中打开/file/settings/Project Interpreter在Project Interpreter中打开Add,左侧边栏目选择Conda Environment,右侧选择Existing environment在文件路径中选择Anaconda安装目录下面的envs目录,下面是该系统安装的所有anaconda环境,进入文件夹,选择python解释器这就就把Pycharm下使用Anconda环境的配置完成了。pip 环境配置conda 环境下也可以用pip来安装包pip安装\n```code\npip install 安装包名\n[...]\nSuccessfully installed SomePackage #安装成功\n```\npip 安装指定源\n```code\npip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple\n```\npip查看是否已安装\n```code\npip show --files 安装包名\n\nName:SomePackage # 包名\nVersion:1.0 # 版本号\nLocation:/my/env/lib/pythonx.x/site-packages # 安装位置\nFiles: # 包含文件等等\n../somepackage/__init__.py\n```\npip检查哪些包需要更新\n```code\npip list --outdated\n```\npip升级包\n```code\npip install --upgrade 要升级的包名\n```\npip卸载包\n```code\npip uninstall 要卸载的包名\n```\npip参数解释\n```code\npip --help\n```\nJupyter notebook使用在利用anaconda创建了tensorflow,pytorch等儿女environment后,想利用jupyter notebook,发现jupyer notebook中只有系统的python3环境,如何把conda创建的环境添加进jupyter notebook中呢,终于解决了这个问题了1. 安装ipykernel\n```code\nconda install ipykernel\n```\n2. 
将环境写入notebook的kernel中\n```code\npython -m ipykernel install --user --name your_env_name --display-name your_env_name\n\n//把conda environment pytorch_0.4 add to jupyter notebook kernel display as pytorch_0.4\npython -m ipykernel install --user --name pytorch_0.4 --display-name pytorch_0.4\n```\n3. 打开notebook\n```code\njupyter notebook\n```\n4. magic commands\n```code\n!git clone https://github.com/ultralytics/yolov5\n%ls\n%cd yolov5\n%pip install -qr requirements.txt\n```\n还有一些实用的魔术命令\n```code\n%magic——用来显示所有魔术命令的详细文档\n%time和%timeit——用来测试代码执行时间\n```\n参考文档编辑于 2023-05-21 20:41・IP 属地浙江PyCharmAnacondapip3​赞同 14​​2 条评论​分享​喜欢​收藏​申请转载​"} diff --git a/sources/readme_docs/roadmap.md b/sources/readme_docs/roadmap.md index 15f1806..6dc0839 100644 --- a/sources/readme_docs/roadmap.md +++ b/sources/readme_docs/roadmap.md @@ -47,13 +47,13 @@ - [x] Web Crawl 通用能力:技术文档: 知乎、csdn、阿里云开发者论坛、腾讯云开发者论坛等
- v0.1 -- [ ] Sandbox environment: file upload & download +- [x] Sandbox environment: file upload & download - [ ] Vector Database & Retrieval - [ ] task retrieval - [ ] tool retrieval - [ ] Connector - [ ] langchain-based ReAct mode -- [ ] Text Embedding via sentencebert: faster vector loading +- [x] Text Embedding via sentencebert: faster vector loading
- v0.2 diff --git a/tests/chains_test.py b/tests/chains_test.py new file mode 100644 index 0000000..af14804 --- /dev/null +++ b/tests/chains_test.py @@ -0,0 +1,168 @@ +import os, sys, requests + +src_dir = os.path.join( + os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +) +sys.path.append(src_dir) + +from dev_opsgpt.tools import ( + toLangchainTools, get_tool_schema, DDGSTool, DocRetrieval, + TOOL_DICT, TOOL_SETS + ) + +from configs.model_config import * +from dev_opsgpt.connector.phase import BasePhase +from dev_opsgpt.connector.agents import BaseAgent +from dev_opsgpt.connector.chains import BaseChain +from dev_opsgpt.connector.connector_schema import ( + Message, load_role_configs, load_phase_configs, load_chain_configs + ) +from dev_opsgpt.connector.configs import AGETN_CONFIGS, CHAIN_CONFIGS, PHASE_CONFIGS +import importlib + +print(src_dir) + +tools = toLangchainTools([TOOL_DICT[i] for i in TOOL_SETS if i in TOOL_DICT]) + +role_configs = load_role_configs(AGETN_CONFIGS) +chain_configs = load_chain_configs(CHAIN_CONFIGS) +phase_configs = load_phase_configs(PHASE_CONFIGS) + +agent_module = importlib.import_module("dev_opsgpt.connector.agents") + + +# agent的测试 +query = Message(role_name="tool_react", role_type="human", + role_content="我有一份时序数据,[0.857, 2.345, 1.234, 4.567, 3.456, 9.876, 5.678, 7.890, 6.789, 8.901, 10.987, 12.345, 11.234, 14.567, 13.456, 19.876, 15.678, 17.890, 16.789, \ + 18.901, 20.987, 22.345, 21.234, 24.567, 23.456, 29.876, 25.678, 27.890, 26.789, 28.901, 30.987, 32.345, 31.234, 34.567, 33.456, 39.876, 35.678, 37.890, 36.789, 38.901, 40.987],\ + 我不知道这份数据是否存在问题,请帮我判断一下", tools=tools) + +query = Message(role_name="tool_react", role_type="human", + role_content="帮我确认下127.0.0.1这个服务器的在10点是否存在异常,请帮我判断一下", tools=tools) + +query = Message(role_name="code_react", role_type="human", + role_content="帮我确认当前目录下有哪些文件", tools=tools) + +# "给我一份冒泡排序的代码" +query = Message(role_name="intention_recognizer", role_type="human", + role_content="对employee_data.csv进行数据分析", tools=tools) + +# role = role_configs["general_planner"] +# agent_class = getattr(agent_module, role.role.agent_type) +# agent = agent_class(role.role, +# task = None, +# memory = None, +# chat_turn=role.chat_turn, +# do_search = role.do_search, +# do_doc_retrieval = role.do_doc_retrieval, +# do_tool_retrieval = role.do_tool_retrieval,) + +# message = agent.run(query) +# print(message.role_content) + + +# chain的测试 + +# query = Message(role_name="deveploer", role_type="human", role_content="编写冒泡排序,并生成测例") +# query = Message(role_name="general_planner", role_type="human", role_content="对employee_data.csv进行数据分析") +# query = Message(role_name="tool_react", role_type="human", role_content="我有一份时序数据,[0.857, 2.345, 1.234, 4.567, 3.456, 9.876, 5.678, 7.890, 6.789, 8.901, 10.987, 12.345, 11.234, 14.567, 13.456, 19.876, 15.678, 17.890, 16.789, 18.901, 20.987, 22.345, 21.234, 24.567, 23.456, 29.876, 25.678, 27.890, 26.789, 28.901, 30.987, 32.345, 31.234, 34.567, 33.456, 39.876, 35.678, 37.890, 36.789, 38.901, 40.987],\我不知道这份数据是否存在问题,请帮我判断一下", tools=tools) + +# role = role_configs[query.role_name] +role1 = role_configs["planner"] +role2 = role_configs["code_react"] + +agents = [ + getattr(agent_module, role1.role.agent_type)(role1.role, + task = None, + memory = None, + do_search = role1.do_search, + do_doc_retrieval = role1.do_doc_retrieval, + do_tool_retrieval = role1.do_tool_retrieval,), + getattr(agent_module, role2.role.agent_type)(role2.role, + task = None, + memory = None, + do_search = role2.do_search, + 
do_doc_retrieval = role2.do_doc_retrieval, + do_tool_retrieval = role2.do_tool_retrieval,), + ] + +query = Message(role_name="user", role_type="human", + role_content="确认本地是否存在employee_data.csv,并查看它有哪些列和数据类型,分析这份数据的内容,根据这个数据预测未来走势", tools=tools) +query = Message(role_name="user", role_type="human", + role_content="确认本地是否存在employee_data.csv,并查看它有哪些列和数据类型", tools=tools) +chain = BaseChain(chain_configs["dataAnalystChain"], agents, do_code_exec=False) + +# message = chain.step(query) +# print(message.role_content) + +# print("\n".join("\n".join([": ".join(j) for j in i]) for i in chain.get_agents_memory())) +# print("\n".join(": ".join(i) for i in chain.get_memory())) +# print( chain.get_agents_memory_str()) +# print( chain.get_memory_str()) + + + + +# 测试 phase +phase_name = "toolReactPhase" +# phase_name = "codeReactPhase" +# phase_name = "chatPhase" + +phase = BasePhase(phase_name, + task = None, + phase_config = PHASE_CONFIGS, + chain_config = CHAIN_CONFIGS, + role_config = AGETN_CONFIGS, + do_summary=False, + do_code_retrieval=False, + do_doc_retrieval=True, + do_search=False, + ) + +query = Message(role_name="user", role_type="human", + role_content="确认本地是否存在employee_data.csv,并查看它有哪些列和数据类型,并选择合适的数值列画出折线图") + +query = Message(role_name="user", role_type="human", + role_content="判断下127.0.0.1这个服务器的在10点的监控数据,是否存在异常", tools=tools) + +# 根据其他类似的类,新开发个 ExceptionComponent2,继承 AbstractTrafficComponent +# query = Message(role_name="human", role_type="human", role_content="langchain有什么用") + +# output_message = phase.step(query) + +# print(phase.get_chains_memory(content_key="step_content")) +# print(phase.get_chains_memory_str(content_key="step_content")) +# print(output_message.to_tuple_message(return_all=True)) + + +from dev_opsgpt.tools import DDGSTool, CodeRetrieval +# print(DDGSTool.run("langchain是什么", 3)) +# print(CodeRetrieval.run("dsadsadsa", query.role_content, code_limit=3, history_node_list=[])) + + +# from dev_opsgpt.chat.agent_chat import AgentChat + +# agentChat = AgentChat() +# value = { +# "query": "帮我确认下127.0.0.1这个服务器的在10点是否存在异常,请帮我判断一下", +# "phase_name": "toolReactPhase", +# "chain_name": "", +# "history": [], +# "doc_engine_name": "DSADSAD", +# "search_engine_name": "duckduckgo", +# "top_k": 3, +# "score_threshold": 1.0, +# "stream": False, +# "local_doc_url": False, +# "do_search": False, +# "do_doc_retrieval": False, +# "do_code_retrieval": False, +# "do_tool_retrieval": False, +# "custom_phase_configs": {}, +# "custom_chain_configs": {}, +# "custom_role_configs": {}, +# "choose_tools": list(TOOL_SETS) +# } + +# answer = agentChat.chat(**value) +# print(answer) \ No newline at end of file diff --git a/tests/docker_test.py b/tests/docker_test.py index 2fe0816..c797dab 100644 --- a/tests/docker_test.py +++ b/tests/docker_test.py @@ -8,4 +8,37 @@ print(time.time()-st) st = time.time() client.containers.run("ubuntu:latest", "echo hello world") -print(time.time()-st) \ No newline at end of file +print(time.time()-st) + + +import socket + + +def get_ip_address(): + hostname = socket.gethostname() + ip_address = socket.gethostbyname(hostname) + return ip_address + +def get_ipv4_address(): + s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + try: + # 使用一个临时套接字连接到公共的 DNS 服务器 + s.connect(("8.8.8.8", 80)) + ip_address = s.getsockname()[0] + finally: + s.close() + return ip_address + +# print(get_ipv4_address()) +# import docker +# client = docker.from_env() + +# containers = client.containers.list(all=True) +# for container in containers: +# container_a_info = 
client.containers.get(container.id) +# container1_networks = container.attrs['NetworkSettings']['Networks'] +# container_a_ip = container_a_info.attrs['NetworkSettings']['IPAddress'] + +# print(container_a_info.name, container_a_ip, [[k, v["IPAddress"]] for k,v in container1_networks.items() ]) + + diff --git a/tests/file_test.py b/tests/file_test.py new file mode 100644 index 0000000..0ab2230 --- /dev/null +++ b/tests/file_test.py @@ -0,0 +1,42 @@ +import requests, os, sys +# src_dir = os.path.join( +# os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +# ) +# sys.path.append(src_dir) + +# from dev_opsgpt.utils.common_utils import st_load_file +# from dev_opsgpt.sandbox.pycodebox import PyCodeBox +# from examples.file_fastapi import upload_file, download_file +# from pathlib import Path +# import httpx +# from loguru import logger +# from io import BytesIO + + +# def _parse_url(url: str, base_url: str) -> str: +# if (not url.startswith("http") +# and base_url +# ): +# part1 = base_url.strip(" /") +# part2 = url.strip(" /") +# return f"{part1}/{part2}" +# else: +# return url + +# base_url: str = "http://127.0.0.1:7861" +# timeout: float = 60.0, +# url = "/files/upload" +# url = _parse_url(url, base_url) +# logger.debug(url) +# kwargs = {} +# kwargs.setdefault("timeout", timeout) + +# import asyncio +# file = "./torch_test.py" +# upload_filename = st_load_file(file, filename="torch_test.py") +# asyncio.run(upload_file(upload_filename)) + +import requests +url = "http://127.0.0.1:7862/sdfiles/download?filename=torch_test.py&save_filename=torch_test.py" +r = requests.get(url) +print(type(r.text)) \ No newline at end of file diff --git a/tests/openai_test.py b/tests/openai_test.py index 13ad17b..fb4928e 100644 --- a/tests/openai_test.py +++ b/tests/openai_test.py @@ -9,7 +9,7 @@ from configs import llm_model_dict, LLM_MODEL import openai # os.environ["OPENAI_PROXY"] = "socks5h://127.0.0.1:7890" # os.environ["OPENAI_PROXY"] = "http://127.0.0.1:7890" -os.environ["OPENAI_API_KEY"] = "" +# os.environ["OPENAI_API_KEY"] = "" if __name__ == "__main__": diff --git a/tests/sandbox_test.py b/tests/sandbox_test.py index 46b2aa8..951457b 100644 --- a/tests/sandbox_test.py +++ b/tests/sandbox_test.py @@ -13,7 +13,7 @@ src_dir = os.path.join( ) sys.path.append(src_dir) - +from dev_opsgpt.service.sdfile_api import sd_upload_file from dev_opsgpt.sandbox.pycodebox import PyCodeBox from pathlib import Path @@ -24,20 +24,20 @@ from pathlib import Path # print(sys.executable) -# import requests +import requests -# # 设置Jupyter Notebook服务器的URL -# url = 'http://localhost:5050' # 或者是你自己的Jupyter服务器的URL +# 设置Jupyter Notebook服务器的URL +url = 'http://172.25.0.3:5050' # 或者是你自己的Jupyter服务器的URL -# # 发送GET请求来获取Jupyter Notebook的登录页面 -# response = requests.get(url) +# 发送GET请求来获取Jupyter Notebook的登录页面 +response = requests.get(url) -# # 检查响应状态码 -# if response.status_code == 200: -# # 打印响应内容 -# print('connect success') -# else: -# print('connect fail') +# 检查响应状态码 +if response.status_code == 200: + # 打印响应内容 + print('connect success') +else: + print('connect fail') # import subprocess # jupyter = subprocess.Popen( @@ -53,31 +53,42 @@ from pathlib import Path # stdout=subprocess.PIPE, # ) -# 测试1 -import time, psutil -from loguru import logger -pycodebox = PyCodeBox(remote_url="http://localhost:5050", - remote_ip="http://localhost", - remote_port="5050", - token="mytoken", - do_code_exe=True, - do_remote=False) +# # 测试1 +# import time, psutil +# from loguru import logger +# import asyncio +# pycodebox = 
PyCodeBox(remote_url="http://localhost:5050", +# remote_ip="http://localhost", +# remote_port="5050", +# token="mytoken", +# do_code_exe=True, +# do_remote=False) + +# pycodebox.list_files() +# file = "./torch_test.py" +# upload_file = st_load_file(file, filename="torch_test.py") + +# file_content = upload_file.read() # 读取上传文件的内容 +# print(upload_file, file_content) +# pycodebox.upload("torch_test.py", upload_file) + +# asyncio.run(pycodebox.alist_files()) -reuslt = pycodebox.chat("```print('hello world!')```", do_code_exe=True) -print(reuslt) +# reuslt = pycodebox.chat("```print('hello world!')```", do_code_exe=True) +# print(reuslt) -reuslt = pycodebox.chat("print('hello world!')", do_code_exe=False) -print(reuslt) +# reuslt = pycodebox.chat("print('hello world!')", do_code_exe=False) +# print(reuslt) -for process in psutil.process_iter(["pid", "name", "cmdline"]): - # 检查进程名是否包含"jupyter" - if 'port=5050' in str(process.info["cmdline"]).lower() and \ - "jupyter" in process.info['name'].lower(): +# for process in psutil.process_iter(["pid", "name", "cmdline"]): +# # 检查进程名是否包含"jupyter" +# if 'port=5050' in str(process.info["cmdline"]).lower() and \ +# "jupyter" in process.info['name'].lower(): - logger.warning(f'port=5050, {process.info}') - # 关闭进程 - process.terminate() +# logger.warning(f'port=5050, {process.info}') +# # 关闭进程 +# process.terminate() # 测试2 @@ -103,61 +114,3 @@ for process in psutil.process_iter(["pid", "name", "cmdline"]): # result = codebox.run("print('hello world!')") # print(result) - - - - -# headers = {'Authorization': 'Token mytoken', 'token': 'mytoken'} - -# kernel_url = "http://localhost:5050/api/kernels" - -# response = requests.get(kernel_url, headers=headers) -# if len(response.json())>0: -# kernel_id = response.json()[0]["id"] -# else: -# response = requests.post(kernel_url, headers=headers) -# kernel_id = response.json()["id"] - - -# print(f"ws://localhost:5050/api/kernels/{kernel_id}/channels?token=mytoken") -# ws = create_connection(f"ws://localhost:5050/api/kernels/{kernel_id}/channels?token=mytoken", headers=headers) - -# code_text = "print('hello world!')" -# # code_text = "import matplotlib.pyplot as plt\n\nplt.figure(figsize=(4,2))\nplt.plot([1,2,3,4,5])\nplt.show()" - -# ws.send( -# json.dumps( -# { -# "header": { -# "msg_id": (msg_id := uuid4().hex), -# "msg_type": "execute_request", -# }, -# "parent_header": {}, -# "metadata": {}, -# "content": { -# "code": code_text, -# "silent": True, -# "store_history": True, -# "user_expressions": {}, -# "allow_stdin": False, -# "stop_on_error": True, -# }, -# "channel": "shell", -# "buffers": [], -# } -# ) -# ) - -# while True: -# received_msg = json.loads(ws.recv()) -# if received_msg["msg_type"] == "stream": -# result_msg = received_msg # 找到结果消息 -# break -# elif received_msg["header"]["msg_type"] == "execute_result": -# result_msg = received_msg # 找到结果消息 -# break -# elif received_msg["header"]["msg_type"] == "display_data": -# result_msg = received_msg # 找到结果消息 -# break - -# print(received_msg) diff --git a/tests/text_splitter_test.py b/tests/text_splitter_test.py new file mode 100644 index 0000000..e310834 --- /dev/null +++ b/tests/text_splitter_test.py @@ -0,0 +1,14 @@ +import os, sys + +src_dir = os.path.join( + os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +) +sys.path.append(src_dir) + +from dev_opsgpt.text_splitter import LCTextSplitter + +filepath = "" +lc_textSplitter = LCTextSplitter(filepath) +docs = lc_textSplitter.file2text() + +print(docs[0]) \ No newline at end of file diff --git 
a/tests/tool_test.py b/tests/tool_test.py new file mode 100644 index 0000000..74a5485 --- /dev/null +++ b/tests/tool_test.py @@ -0,0 +1,114 @@ + + +from langchain.agents import initialize_agent, Tool +from langchain.tools import format_tool_to_openai_function, MoveFileTool, StructuredTool +from pydantic import BaseModel, Field, create_model +from pydantic.schema import model_schema, get_flat_models_from_fields +from typing import List, Set +import jsonref +import json + +import os, sys, requests + +src_dir = os.path.join( + os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +) +sys.path.append(src_dir) + +from dev_opsgpt.tools import ( + WeatherInfo, WorldTimeGetTimezoneByArea, Multiplier, KSigmaDetector, + toLangchainTools, get_tool_schema, + TOOL_DICT, TOOL_SETS + ) +from configs.model_config import (llm_model_dict, LLM_MODEL, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD) + +from langchain.chat_models import ChatOpenAI +from langchain.agents import AgentType, initialize_agent +import langchain + +# langchain.debug = True + +tools = toLangchainTools([WeatherInfo, Multiplier, KSigmaDetector]) + +llm = ChatOpenAI( + streaming=True, + verbose=True, + openai_api_key=llm_model_dict[LLM_MODEL]["api_key"], + openai_api_base=llm_model_dict[LLM_MODEL]["api_base_url"], + model_name=LLM_MODEL + ) + +chat_prompt = '''if you can +tools: {tools} +query: {query} + +if you choose llm-tool, you can direct +''' +# chain = LLMChain(prompt=chat_prompt, llm=llm) +# content = chain({"tools": tools, "input": query}) + +# tool的检索 + +# tool参数的填充 + +# 函数执行 + +# from langchain.tools import StructuredTool + +tools = [ + StructuredTool( + name=Multiplier.name, + func=Multiplier.run, + description=Multiplier.description, + args_schema=Multiplier.ToolInputArgs, + ), + StructuredTool( + name=WeatherInfo.name, + func=WeatherInfo.run, + description=WeatherInfo.description, + args_schema=WeatherInfo.ToolInputArgs, + ) + ] + +print(tools[0].func(1,2)) + + +tools = toLangchainTools([TOOL_DICT[i] for i in TOOL_SETS if i in TOOL_DICT]) + +agent = initialize_agent( + tools, + llm, + agent=AgentType.STRUCTURED_CHAT_ZERO_SHOT_REACT_DESCRIPTION, + verbose=True, + return_intermediate_steps=True +) + +# agent.return_intermediate_steps = True +# content = agent.run("查询北京的行政编码,同时返回北京的天气情况") +# print(content) + +# content = agent.run("判断这份数据是否存在异常,[0.857, 2.345, 1.234, 4.567, 3.456, 9.876, 5.678, 7.890, 6.789, 8.901, 10.987, 12.345, 11.234, 14.567, 13.456, 19.876, 15.678, 17.890, 16.789, 18.901, 20.987, 22.345, 21.234, 24.567, 23.456, 29.876, 25.678, 27.890, 26.789, 28.901, 30.987, 32.345, 31.234, 34.567, 33.456, 39.876, 35.678, 37.890, 36.789, 38.901, 40.987]") +# content = agent("我有一份时序数据,[0.857, 2.345, 1.234, 4.567, 3.456, 9.876, 5.678, 7.890, 6.789, 8.901, 10.987, 12.345, 11.234, 14.567, 13.456, 19.876, 15.678, 17.890, 16.789, 18.901, 20.987, 22.345, 21.234, 24.567, 23.456, 29.876, 25.678, 27.890, 26.789, 28.901, 30.987, 32.345, 31.234, 34.567, 33.456, 39.876, 35.678, 37.890, 36.789, 38.901, 40.987],\我不知道这份数据是否存在问题,请帮我判断一下") +# # print(content) +# from langchain.schema import ( +# AgentAction +# ) + +# s = "" +# if isinstance(content, str): +# s = content +# else: +# for i in content["intermediate_steps"]: +# for j in i: +# if isinstance(j, AgentAction): +# s += j.log + "\n" +# else: +# s += "Observation: " + str(j) + "\n" + +# s += "final answer:" + content["output"] +# print(s) + +# print(content["intermediate_steps"][0][0].log) +# print( content["intermediate_steps"][0][0].log, content[""] + "\n" + content["i"] + "\n" + ) +# 
content = agent.run("i want to know the timezone of asia/shanghai, list all timezones available for that area.") +# print(content) \ No newline at end of file
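The hand-built `StructuredTool(...)` construction above can also be expressed with `StructuredTool.from_function`, which derives the args schema from type hints instead of a hand-written `ToolInputArgs`-style pydantic model. A minimal sketch independent of the project's tool classes:

```python
# Sketch: StructuredTool.from_function infers the schema from annotations.
from langchain.tools import StructuredTool

def multiply(a: int, b: int) -> int:
    """Multiply two integers."""
    return a * b

multiplier_tool = StructuredTool.from_function(
    func=multiply,
    name="Multiplier",
    description="Multiply two integers and return the product.",
)
print(multiplier_tool.run({"a": 6, "b": 7}))  # -> 42
```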