update codebase & multi-agent framework
This commit is contained in:
parent
6bc0ca45ce
commit
d5e2bb7acc
4
.gitignore
vendored
4
.gitignore
vendored
@ -1,6 +1,10 @@
|
||||
**/__pycache__
|
||||
knowledge_base
|
||||
logs
|
||||
embedding_models
|
||||
jupyter_work
|
||||
model_config.py
|
||||
server_config.py
|
||||
code_base
|
||||
.DS_Store
|
||||
.idea
|
||||
|
||||
12
Dockerfile
12
Dockerfile
@ -1,10 +1,18 @@
|
||||
From python:3.9-bookworm
|
||||
From python:3.9.18-bookworm
|
||||
|
||||
WORKDIR /home/user
|
||||
|
||||
COPY ./docker_requirements.txt /home/user/docker_requirements.txt
|
||||
COPY ./requirements.txt /home/user/docker_requirements.txt
|
||||
COPY ./jupyter_start.sh /home/user/jupyter_start.sh
|
||||
|
||||
|
||||
RUN apt-get update
|
||||
RUN apt-get install -y iputils-ping telnetd net-tools vim tcpdump
|
||||
# RUN echo telnet stream tcp nowait telnetd /usr/sbin/tcpd /usr/sbin/in.telnetd /etc/inetd.conf
|
||||
# RUN service inetutils-inetd start
|
||||
# service inetutils-inetd status
|
||||
|
||||
|
||||
RUN pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple
|
||||
RUN pip install -r /home/user/docker_requirements.txt
|
||||
|
||||
|
||||
201
LICENSE
Normal file
201
LICENSE
Normal file
@ -0,0 +1,201 @@
|
||||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
|
||||
1. Definitions.
|
||||
|
||||
"License" shall mean the terms and conditions for use, reproduction,
|
||||
and distribution as defined by Sections 1 through 9 of this document.
|
||||
|
||||
"Licensor" shall mean the copyright owner or entity authorized by
|
||||
the copyright owner that is granting the License.
|
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all
|
||||
other entities that control, are controlled by, or are under common
|
||||
control with that entity. For the purposes of this definition,
|
||||
"control" means (i) the power, direct or indirect, to cause the
|
||||
direction or management of such entity, whether by contract or
|
||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
|
||||
"You" (or "Your") shall mean an individual or Legal Entity
|
||||
exercising permissions granted by this License.
|
||||
|
||||
"Source" form shall mean the preferred form for making modifications,
|
||||
including but not limited to software source code, documentation
|
||||
source, and configuration files.
|
||||
|
||||
"Object" form shall mean any form resulting from mechanical
|
||||
transformation or translation of a Source form, including but
|
||||
not limited to compiled object code, generated documentation,
|
||||
and conversions to other media types.
|
||||
|
||||
"Work" shall mean the work of authorship, whether in Source or
|
||||
Object form, made available under the License, as indicated by a
|
||||
copyright notice that is included in or attached to the work
|
||||
(an example is provided in the Appendix below).
|
||||
|
||||
"Derivative Works" shall mean any work, whether in Source or Object
|
||||
form, that is based on (or derived from) the Work and for which the
|
||||
editorial revisions, annotations, elaborations, or other modifications
|
||||
represent, as a whole, an original work of authorship. For the purposes
|
||||
of this License, Derivative Works shall not include works that remain
|
||||
separable from, or merely link (or bind by name) to the interfaces of,
|
||||
the Work and Derivative Works thereof.
|
||||
|
||||
"Contribution" shall mean any work of authorship, including
|
||||
the original version of the Work and any modifications or additions
|
||||
to that Work or Derivative Works thereof, that is intentionally
|
||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||
or by an individual or Legal Entity authorized to submit on behalf of
|
||||
the copyright owner. For the purposes of this definition, "submitted"
|
||||
means any form of electronic, verbal, or written communication sent
|
||||
to the Licensor or its representatives, including but not limited to
|
||||
communication on electronic mailing lists, source code control systems,
|
||||
and issue tracking systems that are managed by, or on behalf of, the
|
||||
Licensor for the purpose of discussing and improving the Work, but
|
||||
excluding communication that is conspicuously marked or otherwise
|
||||
designated in writing by the copyright owner as "Not a Contribution."
|
||||
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||
on behalf of whom a Contribution has been received by Licensor and
|
||||
subsequently incorporated within the Work.
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
copyright license to reproduce, prepare Derivative Works of,
|
||||
publicly display, publicly perform, sublicense, and distribute the
|
||||
Work and such Derivative Works in Source or Object form.
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
(except as stated in this section) patent license to make, have made,
|
||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||
where such license applies only to those patent claims licensable
|
||||
by such Contributor that are necessarily infringed by their
|
||||
Contribution(s) alone or by combination of their Contribution(s)
|
||||
with the Work to which such Contribution(s) was submitted. If You
|
||||
institute patent litigation against any entity (including a
|
||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||
or a Contribution incorporated within the Work constitutes direct
|
||||
or contributory patent infringement, then any patent licenses
|
||||
granted to You under this License for that Work shall terminate
|
||||
as of the date such litigation is filed.
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the
|
||||
Work or Derivative Works thereof in any medium, with or without
|
||||
modifications, and in Source or Object form, provided that You
|
||||
meet the following conditions:
|
||||
|
||||
(a) You must give any other recipients of the Work or
|
||||
Derivative Works a copy of this License; and
|
||||
|
||||
(b) You must cause any modified files to carry prominent notices
|
||||
stating that You changed the files; and
|
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works
|
||||
that You distribute, all copyright, patent, trademark, and
|
||||
attribution notices from the Source form of the Work,
|
||||
excluding those notices that do not pertain to any part of
|
||||
the Derivative Works; and
|
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its
|
||||
distribution, then any Derivative Works that You distribute must
|
||||
include a readable copy of the attribution notices contained
|
||||
within such NOTICE file, excluding those notices that do not
|
||||
pertain to any part of the Derivative Works, in at least one
|
||||
of the following places: within a NOTICE text file distributed
|
||||
as part of the Derivative Works; within the Source form or
|
||||
documentation, if provided along with the Derivative Works; or,
|
||||
within a display generated by the Derivative Works, if and
|
||||
wherever such third-party notices normally appear. The contents
|
||||
of the NOTICE file are for informational purposes only and
|
||||
do not modify the License. You may add Your own attribution
|
||||
notices within Derivative Works that You distribute, alongside
|
||||
or as an addendum to the NOTICE text from the Work, provided
|
||||
that such additional attribution notices cannot be construed
|
||||
as modifying the License.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and
|
||||
may provide additional or different license terms and conditions
|
||||
for use, reproduction, or distribution of Your modifications, or
|
||||
for any such Derivative Works as a whole, provided Your use,
|
||||
reproduction, and distribution of the Work otherwise complies with
|
||||
the conditions stated in this License.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||
any Contribution intentionally submitted for inclusion in the Work
|
||||
by You to the Licensor shall be under the terms and conditions of
|
||||
this License, without any additional terms or conditions.
|
||||
Notwithstanding the above, nothing herein shall supersede or modify
|
||||
the terms of any separate license agreement you may have executed
|
||||
with Licensor regarding such Contributions.
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade
|
||||
names, trademarks, service marks, or product names of the Licensor,
|
||||
except as required for reasonable and customary use in describing the
|
||||
origin of the Work and reproducing the content of the NOTICE file.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||
agreed to in writing, Licensor provides the Work (and each
|
||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
implied, including, without limitation, any warranties or conditions
|
||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any
|
||||
risks associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory,
|
||||
whether in tort (including negligence), contract, or otherwise,
|
||||
unless required by applicable law (such as deliberate and grossly
|
||||
negligent acts) or agreed to in writing, shall any Contributor be
|
||||
liable to You for damages, including any direct, indirect, special,
|
||||
incidental, or consequential damages of any character arising as a
|
||||
result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill,
|
||||
work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses), even if such Contributor
|
||||
has been advised of the possibility of such damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing
|
||||
the Work or Derivative Works thereof, You may choose to offer,
|
||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||
or other liability obligations and/or rights consistent with this
|
||||
License. However, in accepting such obligations, You may act only
|
||||
on Your own behalf and on Your sole responsibility, not on behalf
|
||||
of any other Contributor, and only if You agree to indemnify,
|
||||
defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason
|
||||
of your accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
APPENDIX: How to apply the Apache License to your work.
|
||||
|
||||
To apply the Apache License to your work, attach the following
|
||||
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||
replaced with your own identifying information. (Don't include
|
||||
the brackets!) The text should be enclosed in the appropriate
|
||||
comment syntax for the file format. We also recommend that a
|
||||
file or class name and description of purpose be included on the
|
||||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright [yyyy] [name of copyright owner]
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
@ -33,6 +33,10 @@ embedding_model_dict = {
|
||||
"bge-large-zh": "BAAI/bge-large-zh"
|
||||
}
|
||||
|
||||
|
||||
LOCAL_MODEL_DIR = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "embedding_models")
|
||||
embedding_model_dict = {k: f"/home/user/chatbot/embedding_models/{v}" if is_running_in_docker() else f"{LOCAL_MODEL_DIR}/{v}" for k, v in embedding_model_dict.items()}
|
||||
|
||||
# 选用的 Embedding 名称
|
||||
EMBEDDING_MODEL = "text2vec-base"
|
||||
|
||||
@ -97,6 +101,7 @@ llm_model_dict = {
|
||||
|
||||
# LLM 名称
|
||||
LLM_MODEL = "gpt-3.5-turbo"
|
||||
USE_FASTCHAT = "gpt" not in LLM_MODEL # 判断是否进行fastchat
|
||||
|
||||
# LLM 运行设备
|
||||
LLM_DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
|
||||
@ -112,6 +117,9 @@ SOURCE_PATH = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__fil
|
||||
# 知识库默认存储路径
|
||||
KB_ROOT_PATH = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "knowledge_base")
|
||||
|
||||
# 代码库默认存储路径
|
||||
CB_ROOT_PATH = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "code_base")
|
||||
|
||||
# nltk 模型存储路径
|
||||
NLTK_DATA_PATH = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "nltk_data")
|
||||
|
||||
@ -153,7 +161,7 @@ DEFAULT_VS_TYPE = "faiss"
|
||||
CACHED_VS_NUM = 1
|
||||
|
||||
# 知识库中单段文本长度
|
||||
CHUNK_SIZE = 250
|
||||
CHUNK_SIZE = 500
|
||||
|
||||
# 知识库中相邻文本重合长度
|
||||
OVERLAP_SIZE = 50
|
||||
@ -169,6 +177,9 @@ SCORE_THRESHOLD = 1 if system_name in ["Linux", "Windows"] else 1100
|
||||
# Number of results to take from the search engine
|
||||
SEARCH_ENGINE_TOP_K = 5
|
||||
|
||||
# Number of results to take from the code search engine
|
||||
CODE_SEARCH_TOP_K = 1
|
||||
|
||||
# 基于本地知识问答的提示词模版
|
||||
PROMPT_TEMPLATE = """【指令】根据已知信息,简洁和专业的来回答问题。如果无法从中得到答案,请说 “根据已知信息无法回答该问题”,不允许在答案中添加编造成分,答案请使用中文。
|
||||
|
||||
@ -176,6 +187,13 @@ PROMPT_TEMPLATE = """【指令】根据已知信息,简洁和专业的来回
|
||||
|
||||
【问题】{question}"""
|
||||
|
||||
# 基于本地代码知识问答的提示词模版
|
||||
CODE_PROMPT_TEMPLATE = """【指令】根据已知信息来回答问题。
|
||||
|
||||
【已知信息】{context}
|
||||
|
||||
【问题】{question}"""
|
||||
|
||||
# API 是否开启跨域,默认为False,如果需要开启,请设置为True
|
||||
# is open cross domain
|
||||
OPEN_CROSS_DOMAIN = False
|
||||
|
||||
@ -1,41 +1,62 @@
|
||||
from .model_config import LLM_MODEL, LLM_DEVICE
|
||||
import os
|
||||
|
||||
# API 是否开启跨域,默认为False,如果需要开启,请设置为True
|
||||
# is open cross domain
|
||||
OPEN_CROSS_DOMAIN = False
|
||||
|
||||
# 是否用容器来启动服务
|
||||
DOCKER_SERVICE = True
|
||||
# 是否采用容器沙箱
|
||||
SANDBOX_DO_REMOTE = True
|
||||
# 是否采用api服务来进行
|
||||
NO_REMOTE_API = True
|
||||
# 各服务器默认绑定host
|
||||
DEFAULT_BIND_HOST = "127.0.0.1"
|
||||
|
||||
#
|
||||
CONTRAINER_NAME = "devopsgpt_webui"
|
||||
IMAGE_NAME = "devopsgpt:py39"
|
||||
|
||||
# webui.py server
|
||||
WEBUI_SERVER = {
|
||||
"host": DEFAULT_BIND_HOST,
|
||||
"port": 8501,
|
||||
"docker_port": 8501
|
||||
}
|
||||
|
||||
# api.py server
|
||||
API_SERVER = {
|
||||
"host": DEFAULT_BIND_HOST,
|
||||
"port": 7861,
|
||||
"docker_port": 7861
|
||||
}
|
||||
|
||||
# sdfile_api.py server
|
||||
SDFILE_API_SERVER = {
|
||||
"host": DEFAULT_BIND_HOST,
|
||||
"port": 7862,
|
||||
"docker_port": 7862
|
||||
}
|
||||
|
||||
# fastchat openai_api server
|
||||
FSCHAT_OPENAI_API = {
|
||||
"host": DEFAULT_BIND_HOST,
|
||||
"port": 8888, # model_config.llm_model_dict中模型配置的api_base_url需要与这里一致。
|
||||
"docker_port": 8888, # model_config.llm_model_dict中模型配置的api_base_url需要与这里一致。
|
||||
}
|
||||
|
||||
# sandbox api server
|
||||
CONTRAINER_NAME = "devopsgt_default"
|
||||
IMAGE_NAME = "devopsgpt:pypy38"
|
||||
SANDBOX_CONTRAINER_NAME = "devopsgpt_sandbox"
|
||||
SANDBOX_IMAGE_NAME = "devopsgpt:py39"
|
||||
SANDBOX_HOST = os.environ.get("SANDBOX_HOST") or DEFAULT_BIND_HOST # "172.25.0.3"
|
||||
SANDBOX_SERVER = {
|
||||
"host": DEFAULT_BIND_HOST,
|
||||
"host": f"http://{SANDBOX_HOST}",
|
||||
"port": 5050,
|
||||
"url": "http://localhost:5050",
|
||||
"do_remote": True,
|
||||
"docker_port": 5050,
|
||||
"url": f"http://{SANDBOX_HOST}:5050",
|
||||
"do_remote": SANDBOX_DO_REMOTE,
|
||||
}
|
||||
|
||||
|
||||
# fastchat model_worker server
|
||||
# 这些模型必须是在model_config.llm_model_dict中正确配置的。
|
||||
# 在启动startup.py时,可用通过`--model-worker --model-name xxxx`指定模型,不指定则为LLM_MODEL
|
||||
|
||||
15
configs/utils.py
Normal file
15
configs/utils.py
Normal file
@ -0,0 +1,15 @@
|
||||
import os
|
||||
|
||||
def is_running_in_docker():
    """Return True when the current process appears to run inside Docker.

    Two heuristics are tried in order:

    1. the marker file ``/.dockerenv`` exists;
    2. the text ``/docker/`` occurs anywhere in ``/proc/1/cgroup``.

    Returns:
        bool: True if either heuristic matches. False otherwise, including
        when ``/proc/1/cgroup`` is missing or cannot be read.
    """
    if os.path.exists('/.dockerenv'):
        return True

    # Open directly instead of checking existence first: this avoids the
    # check-then-open race and lets an unreadable file (e.g. restricted
    # /proc) count as "not in Docker" instead of raising at import time.
    try:
        with open('/proc/1/cgroup', 'rt') as f:
            return '/docker/' in f.read()
    except OSError:
        return False
|
||||
@ -2,7 +2,12 @@ from .base_chat import Chat
|
||||
from .knowledge_chat import KnowledgeChat
|
||||
from .llm_chat import LLMChat
|
||||
from .search_chat import SearchChat
|
||||
from .tool_chat import ToolChat
|
||||
from .data_chat import DataChat
|
||||
from .code_chat import CodeChat
|
||||
from .agent_chat import AgentChat
|
||||
|
||||
|
||||
__all__ = [
|
||||
"Chat", "KnowledgeChat", "LLMChat", "SearchChat"
|
||||
"Chat", "KnowledgeChat", "LLMChat", "SearchChat", "ToolChat", "DataChat", "CodeChat", "AgentChat"
|
||||
]
|
||||
|
||||
169
dev_opsgpt/chat/agent_chat.py
Normal file
169
dev_opsgpt/chat/agent_chat.py
Normal file
@ -0,0 +1,169 @@
|
||||
from fastapi import Body, Request
|
||||
from fastapi.responses import StreamingResponse
|
||||
from typing import List
|
||||
from loguru import logger
|
||||
import importlib
|
||||
import copy
|
||||
import json
|
||||
|
||||
from configs.model_config import (
|
||||
llm_model_dict, LLM_MODEL, PROMPT_TEMPLATE,
|
||||
VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD)
|
||||
|
||||
from dev_opsgpt.tools import (
|
||||
toLangchainTools,
|
||||
TOOL_DICT, TOOL_SETS
|
||||
)
|
||||
|
||||
from dev_opsgpt.connector.phase import BasePhase
|
||||
from dev_opsgpt.connector.agents import BaseAgent, ReactAgent
|
||||
from dev_opsgpt.connector.chains import BaseChain
|
||||
from dev_opsgpt.connector.connector_schema import (
|
||||
Message,
|
||||
load_phase_configs, load_chain_configs, load_role_configs
|
||||
)
|
||||
from dev_opsgpt.connector.shcema import Memory
|
||||
|
||||
from dev_opsgpt.chat.utils import History, wrap_done
|
||||
from dev_opsgpt.connector.configs import PHASE_CONFIGS, AGETN_CONFIGS, CHAIN_CONFIGS
|
||||
|
||||
PHASE_MODULE = importlib.import_module("dev_opsgpt.connector.phase")
|
||||
|
||||
|
||||
|
||||
class AgentChat:
    """Chat entry point that runs a user query through a configurable phase.

    ``chat`` builds a ``Message`` from the request, instantiates the phase
    class named in the phase configuration, executes one ``step`` and returns
    the outcome as a ``StreamingResponse`` of JSON documents.
    """

    def __init__(
            self,
            engine_name: str = "",
            top_k: int = 1,
            stream: bool = False,
            ) -> None:
        """Store instance-level defaults.

        Args:
            engine_name: Accepted but currently unused by this class.
            top_k: Stored on the instance as ``self.top_k``.
            stream: Default output mode, used by ``chat`` when it receives no
                explicit boolean ``stream`` value.
        """
        self.top_k = top_k
        self.stream = stream

    def chat(
        self,
        query: str = Body(..., description="用户输入", examples=["hello"]),
        phase_name: str = Body(..., description="执行场景名称", examples=["chatPhase"]),
        chain_name: str = Body(..., description="执行链的名称", examples=["chatChain"]),
        history: List[History] = Body(
            [], description="历史对话",
            examples=[[{"role": "user", "content": "我们来玩成语接龙,我先来,生龙活虎"}]]
            ),
        doc_engine_name: str = Body(..., description="知识库名称", examples=["samples"]),
        search_engine_name: str = Body(..., description="搜索引擎名称", examples=["duckduckgo"]),
        code_engine_name: str = Body(..., description="代码引擎名称", examples=["samples"]),
        top_k: int = Body(VECTOR_SEARCH_TOP_K, description="匹配向量数"),
        score_threshold: float = Body(SCORE_THRESHOLD, description="知识库匹配相关度阈值,取值范围在0-1之间,SCORE越小,相关度越高,取到1相当于不筛选,建议设置在0.5左右", ge=0, le=1),
        stream: bool = Body(False, description="流式输出"),
        local_doc_url: bool = Body(False, description="知识文件返回本地路径(true)或URL(false)"),
        choose_tools: List[str] = Body([], description="选择tool的集合"),
        do_search: bool = Body(False, description="是否进行搜索"),
        do_doc_retrieval: bool = Body(False, description="是否进行知识库检索"),
        do_code_retrieval: bool = Body(False, description="是否执行代码检索"),
        do_tool_retrieval: bool = Body(False, description="是否执行工具检索"),
        custom_phase_configs: dict = Body({}, description="自定义phase配置"),
        custom_chain_configs: dict = Body({}, description="自定义chain配置"),
        custom_role_configs: dict = Body({}, description="自定义role配置"),
        history_node_list: List = Body([], description="代码历史相关节点"),
        isDetaild: bool = Body(False, description="是否输出完整的agent相关内容"),
        **kargs
        ) -> Message:
        """Execute one phase step for ``query`` and stream the result.

        The annotated return type is kept for interface compatibility; the
        value actually returned is a ``StreamingResponse`` whose chunks are
        JSON objects with the keys ``answer``, ``db_docs``, ``search_docs``,
        ``code_docs``, ``related_nodes`` and ``figures``.

        Names in ``choose_tools`` that are not keys of ``TOOL_DICT`` are
        ignored. ``custom_*_configs`` are merged over the built-in
        configurations by ``update_configs``.
        """
        # Honour the per-request flag. When the method is called directly and
        # `stream` is still the Body() placeholder, fall back to the instance
        # default. A local is used so concurrent requests sharing this
        # instance do not overwrite each other's setting.
        use_stream = stream if isinstance(stream, bool) else self.stream

        # update configs
        phase_configs, chain_configs, agent_configs = self.update_configs(
            custom_phase_configs, custom_chain_configs, custom_role_configs)
        # choose tools
        tools = toLangchainTools([TOOL_DICT[i] for i in choose_tools if i in TOOL_DICT])
        input_message = Message(
            role_content=query,
            role_type="human",
            role_name="user",
            input_query=query,
            phase_name=phase_name,
            chain_name=chain_name,
            do_search=do_search,
            do_doc_retrieval=do_doc_retrieval,
            do_code_retrieval=do_code_retrieval,
            do_tool_retrieval=do_tool_retrieval,
            doc_engine_name=doc_engine_name, search_engine_name=search_engine_name,
            code_engine_name=code_engine_name,
            score_threshold=score_threshold, top_k=top_k,
            history_node_list=history_node_list,
            tools=tools
        )
        # history memory management
        # NOTE(review): entries are read with item access (i["role"]); this
        # presumably requires dict-like history items - confirm that History
        # instances support it.
        history = Memory([
            Message(role_name=i["role"], role_type=i["role"], role_content=i["content"])
            for i in history
        ])
        # start to execute: the phase class is looked up by name in PHASE_MODULE
        phase_settings = phase_configs[input_message.phase_name]
        phase_class = getattr(PHASE_MODULE, phase_settings["phase_type"])
        phase = phase_class(
            input_message.phase_name,
            task=input_message.task,
            phase_config=phase_configs,
            chain_config=chain_configs,
            role_config=agent_configs,
            do_summary=phase_settings["do_summary"],
            do_code_retrieval=input_message.do_code_retrieval,
            do_doc_retrieval=input_message.do_doc_retrieval,
            do_search=input_message.do_search,
        )
        output_message, local_memory = phase.step(input_message, history)

        def chat_iterator(message: Message, local_memory: Memory, isDetaild=False):
            """Yield the JSON payload(s) for ``message``."""
            result = {
                "answer": "",
                "db_docs": [str(doc) for doc in message.db_docs],
                "search_docs": [str(doc) for doc in message.search_docs],
                "code_docs": [str(doc) for doc in message.code_docs],
                # only the first code document contributes related nodes
                "related_nodes": [doc.get_related_node() for idx, doc in enumerate(message.code_docs) if idx == 0],
                "figures": message.figures
            }

            # Flatten and de-duplicate while keeping first-seen order. The
            # membership test runs against the output list itself; the
            # previous helper list was never filled, so duplicates got through.
            related_nodes = []
            for nodes in result["related_nodes"]:
                for node in nodes:
                    if node not in related_nodes:
                        related_nodes.append(node)
            result["related_nodes"] = related_nodes

            message_str = local_memory.to_str_messages(content_key='step_content') if isDetaild else message.role_content
            if use_stream:
                # one chunk per element of message_str
                for token in message_str:
                    result["answer"] = token
                    yield json.dumps(result, ensure_ascii=False)
            else:
                # single chunk carrying the whole answer
                result["answer"] += "".join(message_str)
                yield json.dumps(result, ensure_ascii=False)

        return StreamingResponse(chat_iterator(output_message, local_memory, isDetaild), media_type="text/event-stream")

    def _chat(self):
        """Placeholder; intentionally does nothing."""
        pass

    def update_configs(self, custom_phase_configs, custom_chain_configs, custom_role_configs):
        '''Return phase/chain/agent configs with the custom entries merged in.

        Each result is a deep copy of the corresponding built-in mapping
        (PHASE_CONFIGS, CHAIN_CONFIGS, AGETN_CONFIGS) updated with the custom
        mapping, so the module-level defaults are never mutated. The merge is
        a plain ``dict.update``: a custom entry replaces the whole built-in
        entry stored under the same key.
        '''
        phase_configs = copy.deepcopy(PHASE_CONFIGS)
        phase_configs.update(custom_phase_configs)
        chain_configs = copy.deepcopy(CHAIN_CONFIGS)
        chain_configs.update(custom_chain_configs)
        agent_configs = copy.deepcopy(AGETN_CONFIGS)
        agent_configs.update(custom_role_configs)
        return phase_configs, chain_configs, agent_configs
|
||||
@ -3,12 +3,11 @@ from fastapi.responses import StreamingResponse
|
||||
import asyncio, json
|
||||
from typing import List, AsyncIterable
|
||||
|
||||
from langchain.chat_models import ChatOpenAI
|
||||
from langchain import LLMChain
|
||||
from langchain.callbacks import AsyncIteratorCallbackHandler
|
||||
from langchain.prompts.chat import ChatPromptTemplate
|
||||
|
||||
|
||||
from dev_opsgpt.llm_models import getChatModel
|
||||
from dev_opsgpt.chat.utils import History, wrap_done
|
||||
from configs.model_config import (llm_model_dict, LLM_MODEL, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD)
|
||||
from dev_opsgpt.utils import BaseResponse
|
||||
@ -16,30 +15,6 @@ from loguru import logger
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
def getChatModel(callBack: AsyncIteratorCallbackHandler = None):
    """Build the streaming ChatOpenAI client for the configured ``LLM_MODEL``.

    The API key and base URL are read from ``llm_model_dict[LLM_MODEL]``.

    Args:
        callBack: Optional callback handler. When given it is attached to
            the model; when None the model is created without callbacks.

    Returns:
        ChatOpenAI: a model created with ``streaming=True`` and
        ``verbose=True``.
    """
    model_config = llm_model_dict[LLM_MODEL]
    params = dict(
        streaming=True,
        verbose=True,
        openai_api_key=model_config["api_key"],
        openai_api_base=model_config["api_base_url"],
        model_name=LLM_MODEL,
    )
    if callBack is not None:
        # LangChain models take handlers through the `callbacks` keyword;
        # the previous `callBack=[...]` keyword is not a ChatOpenAI field.
        params["callbacks"] = [callBack]
    return ChatOpenAI(**params)
|
||||
|
||||
|
||||
class Chat:
|
||||
def __init__(
|
||||
self,
|
||||
@ -67,6 +42,7 @@ class Chat:
|
||||
stream: bool = Body(False, description="流式输出"),
|
||||
local_doc_url: bool = Body(False, description="知识文件返回本地路径(true)或URL(false)"),
|
||||
request: Request = None,
|
||||
**kargs
|
||||
):
|
||||
self.engine_name = engine_name if isinstance(engine_name, str) else engine_name.default
|
||||
self.top_k = top_k if isinstance(top_k, int) else top_k.default
|
||||
@ -74,18 +50,23 @@ class Chat:
|
||||
self.stream = stream if isinstance(stream, bool) else stream.default
|
||||
self.local_doc_url = local_doc_url if isinstance(local_doc_url, bool) else local_doc_url.default
|
||||
self.request = request
|
||||
return self._chat(query, history)
|
||||
return self._chat(query, history, **kargs)
|
||||
|
||||
def _chat(self, query: str, history: List[History]):
|
||||
def _chat(self, query: str, history: List[History], **kargs):
|
||||
history = [History(**h) if isinstance(h, dict) else h for h in history]
|
||||
|
||||
## check service dependcy is ok
|
||||
service_status = self.check_service_status()
|
||||
|
||||
if service_status.code!=200: return service_status
|
||||
|
||||
def chat_iterator(query: str, history: List[History]):
|
||||
model = getChatModel()
|
||||
|
||||
result ,content = self.create_task(query, history, model)
|
||||
result, content = self.create_task(query, history, model, **kargs)
|
||||
logger.info('result={}'.format(result))
|
||||
logger.info('content={}'.format(content))
|
||||
|
||||
if self.stream:
|
||||
for token in content["text"]:
|
||||
result["answer"] = token
|
||||
@ -144,7 +125,7 @@ class Chat:
|
||||
return StreamingResponse(chat_iterator(query, history),
|
||||
media_type="text/event-stream")
|
||||
|
||||
def create_task(self, query: str, history: List[History], model):
|
||||
def create_task(self, query: str, history: List[History], model, **kargs):
|
||||
'''构建 llm 生成任务'''
|
||||
chat_prompt = ChatPromptTemplate.from_messages(
|
||||
[i.to_msg_tuple() for i in history] + [("human", "{input}")]
|
||||
|
||||
143
dev_opsgpt/chat/code_chat.py
Normal file
143
dev_opsgpt/chat/code_chat.py
Normal file
@ -0,0 +1,143 @@
|
||||
# encoding: utf-8
|
||||
'''
|
||||
@author: 温进
|
||||
@file: code_chat.py
|
||||
@time: 2023/10/24 下午4:04
|
||||
@desc:
|
||||
'''
|
||||
|
||||
from fastapi import Request, Body
|
||||
import os, asyncio
|
||||
from urllib.parse import urlencode
|
||||
from typing import List
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
from langchain import LLMChain
|
||||
from langchain.callbacks import AsyncIteratorCallbackHandler
|
||||
from langchain.prompts.chat import ChatPromptTemplate
|
||||
|
||||
from configs.model_config import (
|
||||
llm_model_dict, LLM_MODEL, PROMPT_TEMPLATE,
|
||||
VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD, CODE_PROMPT_TEMPLATE)
|
||||
from dev_opsgpt.chat.utils import History, wrap_done
|
||||
from dev_opsgpt.utils import BaseResponse
|
||||
from .base_chat import Chat
|
||||
from dev_opsgpt.llm_models import getChatModel
|
||||
|
||||
from dev_opsgpt.service.kb_api import search_docs, KBServiceFactory
|
||||
from dev_opsgpt.service.cb_api import search_code, cb_exists_api
|
||||
from loguru import logger
|
||||
import json
|
||||
|
||||
|
||||
class CodeChat(Chat):
|
||||
|
||||
def __init__(
        self,
        code_base_name: str = '',
        code_limit: int = 1,
        stream: bool = False,
        request: Request = None,
        ) -> None:
    """Create a chat helper bound to one code base.

    Args:
        code_base_name: Name of the code base; forwarded to the parent
            initialiser as ``engine_name`` and stored as ``self.engine_name``.
        code_limit: Passed to ``search_code`` as ``code_limit``.
        stream: Forwarded to the parent initialiser.
        request: Stored as ``self.request``.
    """
    super().__init__(engine_name=code_base_name, stream=stream)
    # node names collected from earlier searches; starts empty
    self.history_node_list = []
    self.request = request
    self.code_limit = code_limit
    self.engine_name = code_base_name
|
||||
|
||||
def check_service_status(self) -> BaseResponse:
    """Report whether the code base named ``self.engine_name`` exists.

    Returns:
        BaseResponse: code 200 when ``cb_exists_api`` returns a truthy value
        for the name, code 404 otherwise.
    """
    if cb_exists_api(self.engine_name):
        return BaseResponse(code=200, msg=f"找到代码库 {self.engine_name}")
    return BaseResponse(code=404, msg=f"未找到代码库 {self.engine_name}")
|
||||
|
||||
def _process(self, query: str, history: List[History], model):
    '''Search the code base for ``query`` and prepare the LLM chain.

    Returns a tuple ``(chain, context, result)``:
      * ``chain``   - LLMChain over the history plus CODE_PROMPT_TEMPLATE;
      * ``context`` - the related code snippets joined with newlines;
      * ``result``  - dict with an empty "answer" and a "codes" list holding
        one description line for each of the first five related nodes.

    Side effect: the names of all related nodes are merged into
    ``self.history_node_list`` (duplicates removed through a set, so the
    resulting order is not defined).
    '''
    search_result = search_code(query=query, cb_name=self.engine_name, code_limit=self.code_limit,
                                history_node_list=self.history_node_list)
    related_code = search_result['related_code']
    related_nodes = search_result['related_node']

    # update node names
    seen_names = [entry[0] for entry in related_nodes]
    self.history_node_list.extend(seen_names)
    self.history_node_list = list(set(self.history_node_list))

    # one human-readable line per node, at most five
    source_nodes = []
    for inum, node_info in enumerate(related_nodes[0:5]):
        node_name = node_info[0]
        node_type = node_info[1]
        node_score = node_info[2]
        source_nodes.append(f'{inum + 1}. 节点名为 {node_name}, 节点类型为 `{node_type}`, 节点得分为 `{node_score}`')

    messages = [i.to_msg_tuple() for i in history]
    messages = messages + [("human", CODE_PROMPT_TEMPLATE)]
    chain = LLMChain(prompt=ChatPromptTemplate.from_messages(messages), llm=model)

    context = "\n".join(related_code)
    result = {"answer": "", "codes": source_nodes}
    return chain, context, result
|
||||
|
||||
def chat(
|
||||
self,
|
||||
query: str = Body(..., description="用户输入", examples=["hello"]),
|
||||
history: List[History] = Body(
|
||||
[], description="历史对话",
|
||||
examples=[[{"role": "user", "content": "我们来玩成语接龙,我先来,生龙活虎"}]]
|
||||
),
|
||||
engine_name: str = Body(..., description="知识库名称", examples=["samples"]),
|
||||
code_limit: int = Body(1, examples=['1']),
|
||||
stream: bool = Body(False, description="流式输出"),
|
||||
local_doc_url: bool = Body(False, description="知识文件返回本地路径(true)或URL(false)"),
|
||||
request: Request = Body(None),
|
||||
**kargs
|
||||
):
|
||||
self.engine_name = engine_name if isinstance(engine_name, str) else engine_name.default
|
||||
self.code_limit = code_limit
|
||||
self.stream = stream if isinstance(stream, bool) else stream.default
|
||||
self.local_doc_url = local_doc_url if isinstance(local_doc_url, bool) else local_doc_url.default
|
||||
self.request = request
|
||||
return self._chat(query, history, **kargs)
|
||||
|
||||
def _chat(self, query: str, history: List[History], **kargs):
|
||||
history = [History(**h) if isinstance(h, dict) else h for h in history]
|
||||
|
||||
service_status = self.check_service_status()
|
||||
|
||||
if service_status.code != 200: return service_status
|
||||
|
||||
def chat_iterator(query: str, history: List[History]):
|
||||
model = getChatModel()
|
||||
|
||||
result, content = self.create_task(query, history, model, **kargs)
|
||||
# logger.info('result={}'.format(result))
|
||||
# logger.info('content={}'.format(content))
|
||||
|
||||
if self.stream:
|
||||
for token in content["text"]:
|
||||
result["answer"] = token
|
||||
yield json.dumps(result, ensure_ascii=False)
|
||||
else:
|
||||
for token in content["text"]:
|
||||
result["answer"] += token
|
||||
yield json.dumps(result, ensure_ascii=False)
|
||||
|
||||
return StreamingResponse(chat_iterator(query, history),
|
||||
media_type="text/event-stream")
|
||||
|
||||
def create_task(self, query: str, history: List[History], model):
|
||||
'''构建 llm 生成任务'''
|
||||
chain, context, result = self._process(query, history, model)
|
||||
logger.info('chain={}'.format(chain))
|
||||
try:
|
||||
content = chain({"context": context, "question": query})
|
||||
except Exception as e:
|
||||
content = {"text": str(e)}
|
||||
return result, content
|
||||
|
||||
def create_atask(self, query, history, model, callback: AsyncIteratorCallbackHandler):
|
||||
chain, context, result = self._process(query, history, model)
|
||||
task = asyncio.create_task(wrap_done(
|
||||
chain.acall({"context": context, "question": query}), callback.done
|
||||
))
|
||||
return task, result
|
||||
229
dev_opsgpt/chat/data_chat.py
Normal file
229
dev_opsgpt/chat/data_chat.py
Normal file
@ -0,0 +1,229 @@
|
||||
import asyncio
|
||||
from typing import List
|
||||
|
||||
from langchain import LLMChain
|
||||
from langchain.callbacks import AsyncIteratorCallbackHandler
|
||||
from langchain.prompts.chat import ChatPromptTemplate
|
||||
from langchain.agents import AgentType, initialize_agent
|
||||
|
||||
from dev_opsgpt.tools import (
|
||||
WeatherInfo, WorldTimeGetTimezoneByArea, Multiplier,
|
||||
toLangchainTools, get_tool_schema
|
||||
)
|
||||
from .utils import History, wrap_done
|
||||
from .base_chat import Chat
|
||||
from loguru import logger
|
||||
import json, re
|
||||
|
||||
from dev_opsgpt.sandbox import PyCodeBox, CodeBoxResponse
|
||||
from configs.server_config import SANDBOX_SERVER
|
||||
|
||||
def get_tool_agent(tools, llm):
    '''
    create an agent for `tools` and `llm` through langchain's initialize_agent
    @return: agent of type STRUCTURED_CHAT_ZERO_SHOT_REACT_DESCRIPTION with verbose output
    '''
    agent_options = {
        "agent": AgentType.STRUCTURED_CHAT_ZERO_SHOT_REACT_DESCRIPTION,
        "verbose": True,
    }
    return initialize_agent(tools, llm, **agent_options)
|
||||
|
||||
# Prompt used by DataChat.create_task for the first round (no internal history yet)
# and by DataChat.create_atask. The only template variable is {query}; the doubled
# braces {{ }} are literal braces of the example JSON blob.
PROMPT_TEMPLATE = """
|
||||
`角色`
|
||||
你是一个数据分析师,借鉴下述步骤,逐步完成数据分析任务的拆解和代码编写,尽可能帮助和准确地回答用户的问题。
|
||||
|
||||
数据文件的存放路径为 `./`
|
||||
|
||||
`数据分析流程`
|
||||
- 判断文件是否存在,并读取文件数据
|
||||
- 输出数据的基本信息,包括但不限于字段、文本、数据类型等
|
||||
- 输出数据的详细统计信息
|
||||
- 判断是否需要画图分析,选择合适的字段进行画图
|
||||
- 判断数据是否需要进行清洗
|
||||
- 判断数据或图片是否需要保存
|
||||
...
|
||||
- 结合数据统计分析结果和画图结果,进行总结和分析这份数据的价值
|
||||
|
||||
`要求`
|
||||
- 每轮选择一个数据分析流程,需要综合考虑上轮和后续的可能影响
|
||||
- 数据分析流程只提供参考,不要拘泥于它的具体流程,要有自己的思考
|
||||
- 使用JSON blob来指定一个计划,通过提供task_status关键字(任务状态)、plan关键字(数据分析计划)和code关键字(可执行代码)。
|
||||
|
||||
合法的 "task_status" 值: "finished" 表明当前用户问题已被准确回答 或者 "continued" 表明用户问题仍需要进一步分析
|
||||
|
||||
`$JSON_BLOB如下所示`
|
||||
```
|
||||
{{
|
||||
"task_status": $TASK_STATUS,
|
||||
"plan": $PLAN,
|
||||
"code": ```python\n$CODE```
|
||||
}}
|
||||
```
|
||||
|
||||
`跟随如下示例`
|
||||
问题: 输入待回答的问题
|
||||
行动:$JSON_BLOB
|
||||
|
||||
... (重复 行动 N 次,每次只生成一个行动)
|
||||
|
||||
行动:
|
||||
```
|
||||
{{
|
||||
"task_status": "finished",
|
||||
"plan": 我已经可以回答用户问题了,最后回答用户的内容
|
||||
}}
|
||||
|
||||
```
|
||||
|
||||
`数据分析,开始`
|
||||
|
||||
问题:{query}
|
||||
"""
|
||||
|
||||
|
||||
# Prompt used by DataChat.create_task for every round after the first one.
# Template variables: {query} and {history} (the joined internal history of earlier rounds).
# NOTE(review): unlike PROMPT_TEMPLATE, the last example block below is not closed with ```
# and a fixed employee_data.csv question is embedded before {query} — confirm both are intended.
PROMPT_TEMPLATE_2 = """
|
||||
`角色`
|
||||
你是一个数据分析师,借鉴下述步骤,逐步完成数据分析任务的拆解和代码编写,尽可能帮助和准确地回答用户的问题。
|
||||
|
||||
数据文件的存放路径为 `./`
|
||||
|
||||
`数据分析流程`
|
||||
- 判断文件是否存在,并读取文件数据
|
||||
- 输出数据的基本信息,包括但不限于字段、文本、数据类型等
|
||||
- 输出数据的详细统计信息
|
||||
- 判断数据是否需要进行清洗
|
||||
- 判断是否需要画图分析,选择合适的字段进行画图
|
||||
- 判断清洗后数据或图片是否需要保存
|
||||
...
|
||||
- 结合数据统计分析结果和画图结果,进行总结和分析这份数据的价值
|
||||
|
||||
`要求`
|
||||
- 每轮选择一个数据分析流程,需要综合考虑上轮和后续的可能影响
|
||||
- 数据分析流程只提供参考,不要拘泥于它的具体流程,要有自己的思考
|
||||
- 使用JSON blob来指定一个计划,通过提供task_status关键字(任务状态)、plan关键字(数据分析计划)和code关键字(可执行代码)。
|
||||
|
||||
合法的 "task_status" 值: "finished" 表明当前用户问题已被准确回答 或者 "continued" 表明用户问题仍需要进一步分析
|
||||
|
||||
`$JSON_BLOB如下所示`
|
||||
```
|
||||
{{
|
||||
"task_status": $TASK_STATUS,
|
||||
"plan": $PLAN,
|
||||
"code": ```python\n$CODE```
|
||||
}}
|
||||
```
|
||||
|
||||
`跟随如下示例`
|
||||
问题: 输入待回答的问题
|
||||
行动:$JSON_BLOB
|
||||
|
||||
... (重复 行动 N 次,每次只生成一个行动)
|
||||
|
||||
行动:
|
||||
```
|
||||
{{
|
||||
"task_status": "finished",
|
||||
"plan": 我已经可以回答用户问题了,最后回答用户的内容
|
||||
}}
|
||||
|
||||
`数据分析,开始`
|
||||
|
||||
问题:上传了一份employee_data.csv文件,请对它进行数据分析
|
||||
|
||||
问题:{query}
|
||||
{history}
|
||||
|
||||
"""
|
||||
|
||||
class DataChat(Chat):
    '''
    data analysis chat: each round the llm answers with a JSON blob (task_status, plan, code)
    and the code, if any, is executed through a PyCodeBox
    '''

    def __init__(
            self,
            engine_name: str = "",
            top_k: int = 1,
            stream: bool = False,
    ) -> None:
        '''
        @param engine_name: forwarded to Chat.__init__
        @param top_k: forwarded to Chat.__init__
        @param stream: forwarded to Chat.__init__
        '''
        super().__init__(engine_name, top_k, stream)
        self.tool_prompt = """结合上下文信息,{tools} {input}"""
        # code box configured from SANDBOX_SERVER; used to run the code produced by the llm
        self.codebox = PyCodeBox(
            remote_url=SANDBOX_SERVER["url"],
            remote_ip=SANDBOX_SERVER["host"],  # "http://localhost",
            remote_port=SANDBOX_SERVER["port"],
            token="mytoken",
            do_code_exe=True,
            do_remote=SANDBOX_SERVER["do_remote"]
        )

    def create_task(self, query: str, history: List[History], model, **kargs):
        '''
        build and run the llm generation task
        runs at most three plan/execute rounds and stops early once the llm reports "finished"
        @param kargs: accepted for parity with the other chat classes; unused here
        @return: ({"answer": "", "docs": ""}, {"text": <all rounds and observations joined by newlines>})
        '''
        logger.debug("content:{}".format([i.to_msg_tuple() for i in history] + [("human", PROMPT_TEMPLATE)]))

        internal_history = []
        retry_nums = 2
        while retry_nums >= 0:
            # the first round uses the plain prompt, later rounds the one that embeds {history}
            template = PROMPT_TEMPLATE_2 if internal_history else PROMPT_TEMPLATE
            chat_prompt = ChatPromptTemplate.from_messages(
                [i.to_msg_tuple() for i in history] + [("human", template)]
            )

            chain = LLMChain(prompt=chat_prompt, llm=model)
            content = chain({"query": query, "history": "\n".join(internal_history)})["text"]

            internal_history.append(f"{content}")
            refer_info = "\n".join(internal_history)
            logger.info(f"refer_info: {refer_info}")

            # keep only the text after the last action marker (half-width or full-width colon)
            action_text = content.split("行动:")[-1].split("行动:")[-1]
            try:
                content = json.loads(action_text)
            except ValueError:
                # NOTE(review): eval() runs arbitrary expressions produced by the llm inside this
                # process; kept because it is the existing fallback for non-JSON blobs, but it
                # should be replaced by a safe parser.
                content = eval(action_text)

            task_status = content.get("task_status") if isinstance(content, dict) else None
            if "finished" == task_status:
                break
            elif "code" in content:
                code_text = content["code"]
                codebox_res = self.codebox.chat("```"+code_text+"```", do_code_exe=True)

                if codebox_res is not None and codebox_res.code_exe_status != 200:
                    # execution failed: hand the error back to the llm for the next round
                    logger.warning(f"{codebox_res.code_exe_response}")
                    internal_history.append(f"观察: 根据这个报错信息 {codebox_res.code_exe_response},进行代码修复")

                if codebox_res is not None and codebox_res.code_exe_status == 200:
                    if codebox_res.code_exe_type == "image/png":
                        # the response is embedded as a base64 data url
                        img_html = "<img src='data:image/png;base64,{}' class='img-fluid'>".format(
                            codebox_res.code_exe_response
                        )
                        internal_history.append(f"观察: {img_html}")
                    else:
                        internal_history.append(f"观察: {codebox_res.code_exe_response}")
            else:
                internal_history.append("观察:下一步应该怎么做?")
            retry_nums -= 1

        return {"answer": "", "docs": ""}, {"text": "\n".join(internal_history)}

    def create_atask(self, query, history, model, callback: AsyncIteratorCallbackHandler):
        '''
        schedule a single llm call on PROMPT_TEMPLATE as an asyncio task
        @return: (task, {"answer": "", "docs": ""})
        '''
        chat_prompt = ChatPromptTemplate.from_messages(
            [i.to_msg_tuple() for i in history] + [("human", PROMPT_TEMPLATE)]
        )
        chain = LLMChain(prompt=chat_prompt, llm=model)
        # PROMPT_TEMPLATE declares {query}; the key used to be "input", which the prompt does not know
        task = asyncio.create_task(wrap_done(
            chain.acall({"query": query}), callback.done
        ))
        return task, {"answer": "", "docs": ""}
|
||||
@ -1,6 +1,5 @@
|
||||
from fastapi import Request
|
||||
import os, asyncio
|
||||
from urllib.parse import urlencode
|
||||
from typing import List, Optional, Dict
|
||||
|
||||
from langchain import LLMChain
|
||||
|
||||
84
dev_opsgpt/chat/tool_chat.py
Normal file
84
dev_opsgpt/chat/tool_chat.py
Normal file
@ -0,0 +1,84 @@
|
||||
import asyncio
|
||||
from typing import List
|
||||
|
||||
from langchain import LLMChain
|
||||
from langchain.callbacks import AsyncIteratorCallbackHandler
|
||||
from langchain.prompts.chat import ChatPromptTemplate
|
||||
from langchain.agents import AgentType, initialize_agent
|
||||
import langchain
|
||||
from langchain.schema import (
|
||||
AgentAction
|
||||
)
|
||||
|
||||
|
||||
# langchain.debug = True
|
||||
|
||||
from dev_opsgpt.tools import (
|
||||
TOOL_SETS, TOOL_DICT,
|
||||
toLangchainTools, get_tool_schema
|
||||
)
|
||||
from .utils import History, wrap_done
|
||||
from .base_chat import Chat
|
||||
from loguru import logger
|
||||
|
||||
|
||||
def get_tool_agent(tools, llm):
    '''
    create an agent for `tools` and `llm` through langchain's initialize_agent
    @return: agent of type STRUCTURED_CHAT_ZERO_SHOT_REACT_DESCRIPTION, verbose, configured
             to return its intermediate steps
    '''
    agent_options = {
        "agent": AgentType.STRUCTURED_CHAT_ZERO_SHOT_REACT_DESCRIPTION,
        "verbose": True,
        "return_intermediate_steps": True,
    }
    return initialize_agent(tools, llm, **agent_options)
|
||||
|
||||
|
||||
class ToolChat(Chat):
    '''
    chat mode that lets a langchain agent answer the query with the registered tools
    '''

    def __init__(
            self,
            engine_name: str = "",
            top_k: int = 1,
            stream: bool = False,
    ) -> None:
        '''
        @param engine_name: forwarded to Chat.__init__
        @param top_k: forwarded to Chat.__init__
        @param stream: forwarded to Chat.__init__
        '''
        super().__init__(engine_name, top_k, stream)
        self.tool_prompt = """结合上下文信息,{tools} {input}"""
        # default tools: every name of TOOL_SETS that is registered in TOOL_DICT
        self.tools = toLangchainTools([TOOL_DICT[i] for i in TOOL_SETS if i in TOOL_DICT])

    def create_task(self, query: str, history: List[History], model, **kargs):
        '''
        build and run the llm generation task
        @param kargs: optional "tool_sets", a list of tool names; unknown names are ignored and
                      the default tools are used when none remain
        @return: ({"answer": "", "docs": ""}, {"text": <agent trace followed by the final answer>})
        '''
        logger.debug("content:{}".format([i.to_msg_tuple() for i in history] + [("human", "{query}")]))

        tool_names = kargs.get("tool_sets") or []
        tools = toLangchainTools([TOOL_DICT[i] for i in tool_names if i in TOOL_DICT])
        agent = get_tool_agent(tools if tools else self.tools, model)
        content = agent(query)

        logger.debug(f"content: {content}")

        if isinstance(content, str):
            s = content
        else:
            # flatten the intermediate steps into one text: action logs and their observations
            parts = []
            for step in content["intermediate_steps"]:
                for item in step:
                    if isinstance(item, AgentAction):
                        parts.append(item.log + "\n")
                    else:
                        parts.append("Observation: " + str(item) + "\n")
            parts.append("final answer:" + content["output"])
            s = "".join(parts)

        return {"answer": "", "docs": ""}, {"text": s}

    def create_atask(self, query, history, model, callback: AsyncIteratorCallbackHandler):
        '''
        schedule a single llm call on tool_prompt as an asyncio task
        @return: (task, {"answer": "", "docs": ""})
        '''
        chat_prompt = ChatPromptTemplate.from_messages(
            [i.to_msg_tuple() for i in history] + [("human", self.tool_prompt)]
        )
        chain = LLMChain(prompt=chat_prompt, llm=model)
        # tool_prompt declares both {tools} and {input}; only "input" used to be supplied.
        # NOTE(review): the tool list is rendered with its default string form — confirm the
        # wording the prompt should show for tools.
        task = asyncio.create_task(wrap_done(
            chain.acall({"tools": self.tools, "input": query}), callback.done
        ))
        return task, {"answer": "", "docs": ""}
|
||||
7
dev_opsgpt/codebase_handler/__init__.py
Normal file
7
dev_opsgpt/codebase_handler/__init__.py
Normal file
@ -0,0 +1,7 @@
|
||||
# encoding: utf-8
|
||||
'''
|
||||
@author: 温进
|
||||
@file: __init__.py
|
||||
@time: 2023/10/23 下午4:57
|
||||
@desc:
|
||||
'''
|
||||
139
dev_opsgpt/codebase_handler/codebase_handler.py
Normal file
139
dev_opsgpt/codebase_handler/codebase_handler.py
Normal file
@ -0,0 +1,139 @@
|
||||
# encoding: utf-8
|
||||
'''
|
||||
@author: 温进
|
||||
@file: codebase_handler.py
|
||||
@time: 2023/10/23 下午5:05
|
||||
@desc:
|
||||
'''
|
||||
|
||||
from loguru import logger
|
||||
import time
|
||||
import os
|
||||
|
||||
from dev_opsgpt.codebase_handler.parser.java_paraser.java_crawler import JavaCrawler
|
||||
from dev_opsgpt.codebase_handler.parser.java_paraser.java_preprocess import JavaPreprocessor
|
||||
from dev_opsgpt.codebase_handler.parser.java_paraser.java_dedup import JavaDedup
|
||||
from dev_opsgpt.codebase_handler.parser.java_paraser.java_parser import JavaParser
|
||||
from dev_opsgpt.codebase_handler.tagger.tagger import Tagger
|
||||
from dev_opsgpt.codebase_handler.tagger.tuple_generation import node_edge_update
|
||||
|
||||
from dev_opsgpt.codebase_handler.networkx_handler.networkx_handler import NetworkxHandler
|
||||
from dev_opsgpt.codebase_handler.codedb_handler.local_codedb_handler import LocalCodeDBHandler
|
||||
|
||||
|
||||
class CodeBaseHandler():
    '''
    facade over one code base: builds or loads its graph (NetworkxHandler) and its
    code store (LocalCodeDBHandler) and searches code related to a query
    '''

    def __init__(self, code_name: str, code_path: str = '', cb_root_path: str = '', history_node_list: list = None):
        '''
        @param code_name: name of the code base; also the directory name below cb_root_path
        @param code_path: directory with the source files to import
        @param cb_root_path: root directory under which code bases are stored
        @param history_node_list: node names of earlier searches; a new list is used when omitted
        '''
        self.nh = None
        self.lcdh = None
        self.code_name = code_name
        self.code_path = code_path

        self.codebase_path = cb_root_path + os.sep + code_name
        self.graph_path = self.codebase_path + os.sep + 'graph.pk'
        self.codedb_path = self.codebase_path + os.sep + 'codedb.pk'

        self.tagger = Tagger()
        # a fresh list per instance: a shared default list would leak history between handlers
        self.history_node_list = [] if history_node_list is None else history_node_list

    def import_code(self, do_save: bool=False, do_load: bool=False) -> bool:
        '''
        import code to codeBase
        @param do_save: save graph and code db below codebase_path after they are built
        @param do_load: load graph and code db from codebase_path instead of parsing code_path
        @return: True as success; errors are not caught here and propagate as exceptions
        '''
        if do_load:
            logger.info('start load from codebase_path')

            st = time.time()
            self.nh = NetworkxHandler(graph_path=self.graph_path)
            logger.info('generate graph success, rt={}'.format(time.time() - st))

            st = time.time()
            self.lcdh = LocalCodeDBHandler(db_path=self.codedb_path)
            logger.info('generate codedb success, rt={}'.format(time.time() - st))
        else:
            logger.info('start load from code_path')
            st = time.time()
            java_code_dict = JavaCrawler.local_java_file_crawler(self.code_path)
            logger.info('crawl success, rt={}'.format(time.time() - st))

            java_code_dict = JavaPreprocessor().preprocess(java_code_dict)
            java_code_dict = JavaDedup().dedup(java_code_dict)

            st = time.time()
            parse_res = JavaParser().parse(java_code_dict)
            logger.info('parse success, rt={}'.format(time.time() - st))

            st = time.time()
            tagged_code = self.tagger.generate_tag(parse_res)
            # explicit empty lists: nodes and edges of one import must not be shared with another
            node_list, edge_list = node_edge_update(parse_res.values(), node_list=[], edge_list=[])
            logger.info('get node and edge success, rt={}'.format(time.time() - st))

            st = time.time()
            self.nh = NetworkxHandler(node_list=node_list, edge_list=edge_list)
            logger.info('generate graph success, rt={}'.format(time.time() - st))

            st = time.time()
            self.lcdh = LocalCodeDBHandler(tagged_code)
            logger.info('CodeDB load success, rt={}'.format(time.time() - st))

        if do_save:
            # both files live in codebase_path; make sure it exists before writing
            os.makedirs(self.codebase_path, exist_ok=True)
            self.nh.save_graph(self.graph_path)
            self.lcdh.save_db(self.codedb_path)

        return True

    def search_code(self, query: str, code_limit: int, history_node_list: list = None):
        '''
        search code related to query
        @param query: user question; tags are extracted from it by the tagger
        @param code_limit: maximum number of code snippets, passed to search_by_multi_tag as lim
        @param history_node_list: node names of earlier searches, passed to the graph search
        @return: (related_code, related_node); related_node is a list of
                 (node name, node type, score) sorted by descending score
        '''
        if history_node_list is None:
            history_node_list = []

        # get query tag
        query_tag_list = self.tagger.generate_tag_query(query)

        related_node_score_list = self.nh.search_node_with_score(query_tag_list=query_tag_list,
                                                                 history_node_list=history_node_list)

        score_dict = {i[0]: i[1] for i in related_node_score_list}
        candidate_nodes = [i[0] for i in related_node_score_list]

        related_code, code_related_node = self.lcdh.search_by_multi_tag(candidate_nodes, lim=code_limit)

        related_node = [
            (node, self.nh.get_node_type(node), score_dict[node])
            for node in code_related_node
        ]
        related_node.sort(key=lambda x: x[2], reverse=True)

        logger.info('related_node={}'.format(related_node))
        logger.info('related_code={}'.format(related_code))
        logger.info('num of code={}'.format(len(related_code)))
        return related_code, related_node

    def refresh_history(self):
        '''
        forget the node names collected from earlier searches
        '''
        self.history_node_list = []
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
7
dev_opsgpt/codebase_handler/codedb_handler/__init__.py
Normal file
7
dev_opsgpt/codebase_handler/codedb_handler/__init__.py
Normal file
@ -0,0 +1,7 @@
|
||||
# encoding: utf-8
|
||||
'''
|
||||
@author: 温进
|
||||
@file: __init__.py
|
||||
@time: 2023/10/23 下午5:04
|
||||
@desc:
|
||||
'''
|
||||
@ -0,0 +1,55 @@
|
||||
# encoding: utf-8
|
||||
'''
|
||||
@author: 温进
|
||||
@file: local_codedb_handler.py
|
||||
@time: 2023/10/23 下午5:05
|
||||
@desc:
|
||||
'''
|
||||
import pickle
|
||||
|
||||
|
||||
class LocalCodeDBHandler:
    '''
    in-memory code store: maps each code string to the string form of its tag
    and looks code up by substring match on that string
    '''

    def __init__(self, tagged_code: dict = None, db_path: str = ''):
        '''
        @param tagged_code: dict of code -> tag; each tag is stored as str(tag)
        @param db_path: pickle file written by save_db; when given, tagged_code is ignored
        '''
        if db_path:
            # NOTE: pickle.load can execute arbitrary code; only load files written by save_db
            with open(db_path, 'rb') as f:
                self.data = pickle.load(f)
        else:
            self.data = {code: str(tag) for code, tag in (tagged_code or {}).items()}

    def search_by_single_tag(self, tag, lim):
        '''
        @param tag: text searched in the stored tag strings
        @param lim: maximum number of results
        @return: at most lim code strings whose tag string contains tag, in storage order
        '''
        res = []
        for code, tag_text in self.data.items():
            # checked before appending so that exactly lim (not lim + 1) items can be returned
            if len(res) >= lim:
                break
            if tag in tag_text:
                res.append(code)
        return res

    def search_by_multi_tag(self, tag_list, lim=3):
        '''
        @param tag_list: tags to search, most relevant first
        @param lim: maximum number of code strings
        @return: (res, res_related_node); res holds at most lim distinct code strings in
                 reversed discovery order, res_related_node the tag that found each of
                 them, in discovery order
        '''
        res = []
        res_related_node = []
        for tag in tag_list:
            for code in self.search_by_single_tag(tag, lim):
                if code not in res:
                    res.append(code)
                    res_related_node.append(tag)
            if len(res) >= lim:
                break

        # keep the tags of the returned code only
        res_related_node = res_related_node[0:lim]

        # reverse order so that most relevant one is close to the query
        res = res[0:lim]
        res.reverse()

        return res, res_related_node

    def save_db(self, save_path):
        '''
        pickle the code store to save_path
        '''
        with open(save_path, 'wb') as f:
            pickle.dump(self.data, f)

    def __len__(self):
        return len(self.data)
|
||||
|
||||
7
dev_opsgpt/codebase_handler/networkx_handler/__init__.py
Normal file
7
dev_opsgpt/codebase_handler/networkx_handler/__init__.py
Normal file
@ -0,0 +1,7 @@
|
||||
# encoding: utf-8
|
||||
'''
|
||||
@author: 温进
|
||||
@file: __init__.py
|
||||
@time: 2023/10/23 下午5:00
|
||||
@desc:
|
||||
'''
|
||||
129
dev_opsgpt/codebase_handler/networkx_handler/networkx_handler.py
Normal file
129
dev_opsgpt/codebase_handler/networkx_handler/networkx_handler.py
Normal file
@ -0,0 +1,129 @@
|
||||
# encoding: utf-8
|
||||
'''
|
||||
@author: 温进
|
||||
@file: networkx_handler.py
|
||||
@time: 2023/10/23 下午5:02
|
||||
@desc:
|
||||
'''
|
||||
|
||||
import networkx as nx
|
||||
from loguru import logger
|
||||
import matplotlib.pyplot as plt
|
||||
import pickle
|
||||
from collections import defaultdict
|
||||
import json
|
||||
|
||||
# score added to a node for every query tag contained in its name (search_node_with_score)
QUERY_SCORE = 10
|
||||
# score added to every node that is listed in the search history
HISTORY_SCORE = 5
|
||||
# factor applied to the score of an adjacent node before it is added to a node's own score
RATIO = 0.5
|
||||
|
||||
|
||||
class NetworkxHandler:
    '''
    directed graph of code entities kept in networkx, with tag based node search
    '''

    def __init__(self, graph_path: str = '', node_list: list = None, edge_list: list = None):
        '''
        @param graph_path: json file written by save_graph; when given, the graph is loaded from it
        @param node_list: nodes of a new graph, handed to DiGraph.add_nodes_from
        @param edge_list: edges of a new graph; first item is the start node, last item the
                          end node, item 1 the relation
        '''
        # always remembered so that refresh_graph can tell whether there is a file to reload
        self.graph_path = graph_path
        if graph_path:
            with open(graph_path, 'r') as f:
                self.G = nx.node_link_graph(json.load(f))
        else:
            self.G = nx.DiGraph()
            self.populate_graph(node_list or [], edge_list or [])
        logger.debug(
            'number of nodes={}, number of edges={}'.format(self.G.number_of_nodes(), self.G.number_of_edges()))

        self.query_score = QUERY_SCORE
        self.history_score = HISTORY_SCORE
        self.ratio = RATIO

    def populate_graph(self, node_list, edge_list):
        '''
        populate graph with node_list and edge_list
        '''
        self.G.add_nodes_from(node_list)
        for edge in edge_list:
            self.G.add_edge(edge[0], edge[-1], relation=edge[1])

    def draw_graph(self, save_path: str):
        '''
        draw and save to save_path
        '''
        plt.subplot(111)
        nx.draw(self.G, with_labels=True)

        plt.savefig(save_path)
        # release the figure so that repeated calls do not draw on top of each other
        plt.close()

    def search_node(self, query_tag_list: list, history_node_list: list = None):
        '''
        search node by tag_list, search from history_tag neighbors first
        > query_tag_list: tag from query
        > history_node_list: node names of earlier searches; names unknown to the graph are skipped
        < set of matching node names
        '''
        node_list = set()
        history_node_list = history_node_list or []

        # search from history_tag_list first, then all nodes
        for tag in query_tag_list:
            tag_lower = tag.lower()
            add = False
            for history_node in history_node_list:
                if history_node not in self.G:
                    continue
                # G.adj[...] is a read-only view, so build a real list: the history node
                # itself first, then its successors
                connect_node_list = [history_node, *self.G.adj[history_node]]
                for connect_node in connect_node_list:
                    # only the part before the first '_' is compared
                    node_name_lim = len(connect_node) if '_' not in connect_node else connect_node.index('_')
                    node_name = connect_node[0:node_name_lim]
                    if tag_lower in node_name.lower():
                        node_list.add(connect_node)
                        add = True
            if not add:
                for node in self.G.nodes():
                    if tag_lower in node.lower():
                        node_list.add(node)
        return node_list

    def search_node_with_score(self, query_tag_list: list, history_node_list: list = None):
        '''
        score nodes by query tags and search history
        > query_tag_list: tag from query
        > history_node_list: node names of earlier searches; names unknown to the graph are skipped
        < list of (node, score) sorted by descending score
        '''
        history_node_list = history_node_list or []
        logger.info('query_tag_list={}, history_node_list={}'.format(query_tag_list, history_node_list))
        node_dict = defaultdict(int)

        # loop over query_tag_list and add node:
        for tag in query_tag_list:
            tag_lower = tag.lower()
            for node in self.G.nodes:
                if tag_lower in node.lower():
                    node_dict[node] += self.query_score

        # loop over history_node and add node score
        for node in history_node_list:
            # a name that is not in the graph has no adjacency and no type, so it cannot be scored
            if node in self.G:
                node_dict[node] += self.history_score

        logger.info('temp_res={}'.format(node_dict))

        # adj score broadcast: add a share of the current score of every scored neighbor;
        # .get() is used so that unscored neighbors are not inserted while iterating
        for node in node_dict:
            for adj_node in self.G.adj[node]:
                node_dict[node] += node_dict.get(adj_node, 0) * self.ratio

        # sort
        return sorted(node_dict.items(), key=lambda x: x[1], reverse=True)

    def save_graph(self, save_path: str):
        '''
        write the graph to save_path as node-link json
        '''
        to_save = nx.node_link_data(self.G)
        with open(save_path, 'w') as f:
            json.dump(to_save, f)

    def __len__(self):
        return self.G.number_of_nodes()

    def get_node_type(self, node_name):
        '''
        < value of the 'type' attribute of node_name
        '''
        node_type = self.G.nodes[node_name]['type']
        return node_type

    def refresh_graph(self, ):
        '''
        reload the graph from graph_path
        raises ValueError when the handler was built without a graph_path
        '''
        if not self.graph_path:
            raise ValueError('refresh_graph needs a graph_path, but the graph was built from lists')
        with open(self.graph_path, 'r') as f:
            self.G = nx.node_link_graph(json.load(f))
|
||||
|
||||
|
||||
|
||||
7
dev_opsgpt/codebase_handler/parser/__init__.py
Normal file
7
dev_opsgpt/codebase_handler/parser/__init__.py
Normal file
@ -0,0 +1,7 @@
|
||||
# encoding: utf-8
|
||||
'''
|
||||
@author: 温进
|
||||
@file: __init__.py
|
||||
@time: 2023/10/23 下午5:00
|
||||
@desc:
|
||||
'''
|
||||
@ -0,0 +1,7 @@
|
||||
# encoding: utf-8
|
||||
'''
|
||||
@author: 温进
|
||||
@file: __init__.py
|
||||
@time: 2023/10/23 下午5:01
|
||||
@desc:
|
||||
'''
|
||||
@ -0,0 +1,32 @@
|
||||
# encoding: utf-8
|
||||
'''
|
||||
@author: 温进
|
||||
@file: java_crawler.py
|
||||
@time: 2023/10/23 下午5:02
|
||||
@desc:
|
||||
'''
|
||||
|
||||
import os
|
||||
import glob
|
||||
from loguru import logger
|
||||
|
||||
|
||||
class JavaCrawler:
    '''
    collects java source files from the local file system
    '''

    @staticmethod
    def local_java_file_crawler(path: str):
        '''
        read local java file in path
        > path: path to crawl, must be absolute path like A/B/C; searched recursively
        < dict of java file path -> java code string
        '''
        pattern = '{path}{sep}**{sep}*.java'.format(path=path, sep=os.path.sep)
        java_file_list = glob.glob(pattern, recursive=True)

        logger.debug('number of file={}'.format(len(java_file_list)))

        java_code_dict = {}
        for java_file in java_file_list:
            # read as utf-8 instead of the platform default; undecodable bytes are replaced
            # so that a single badly encoded file cannot abort the whole crawl
            with open(java_file, encoding='utf-8', errors='replace') as f:
                java_code_dict[java_file] = f.read()
        return java_code_dict
|
||||
@ -0,0 +1,15 @@
|
||||
# encoding: utf-8
|
||||
'''
|
||||
@author: 温进
|
||||
@file: java_dedup.py
|
||||
@time: 2023/10/23 下午5:02
|
||||
@desc:
|
||||
'''
|
||||
|
||||
|
||||
class JavaDedup:
    '''
    de-duplication step for java code; currently a pass-through
    '''

    def dedup(self, java_code_dict):
        '''
        > java_code_dict: dict of java code
        < the very same dict object, unchanged
        '''
        deduped = java_code_dict
        return deduped
|
||||
107
dev_opsgpt/codebase_handler/parser/java_paraser/java_parser.py
Normal file
107
dev_opsgpt/codebase_handler/parser/java_paraser/java_parser.py
Normal file
@ -0,0 +1,107 @@
|
||||
# encoding: utf-8
|
||||
'''
|
||||
@author: 温进
|
||||
@file: java_parser.py
|
||||
@time: 2023/10/23 下午5:03
|
||||
@desc:
|
||||
'''
|
||||
import json
|
||||
import javalang
|
||||
import glob
|
||||
import os
|
||||
from loguru import logger
|
||||
|
||||
|
||||
class JavaParser:
    '''
    extracts package, class, method and import names from java code with javalang
    '''

    def __init__(self):
        pass

    def parse(self, java_code_list):
        '''
        parse java code and extract entity
        > java_code_list: despite the name, a dict of file path -> java code (its items are iterated)
        < dict of java code -> result of single_java_code_parse
        '''
        tree_dict = self.preparse(java_code_list)
        return self.multi_java_code_parse(tree_dict)

    def preparse(self, java_code_dict):
        '''
        preparse by javalang
        > java_code_dict: dict of file path -> java code
        < dict of java_code and tree; code that cannot be parsed or declares no package is left out
        '''
        tree_dict = {}
        for fp, java_code in java_code_dict.items():
            try:
                tree = javalang.parse.parse(java_code)
            except Exception as e:
                # unparsable files are skipped on purpose, but leave a trace of which ones
                logger.debug('skip {}: javalang could not parse it ({})'.format(fp, e))
                continue

            if tree.package is not None:
                tree_dict[java_code] = tree
        logger.info('success parse {} files'.format(len(tree_dict)))
        return tree_dict

    def single_java_code_parse(self, tree):
        '''
        parse single code file
        > tree: javalang parse result
        < {pac_name: '', class_name_list: [], func_name_dict: {}, import_pac_name_list: []}
          class names are prefixed with the package name; method names are prefixed with the
          class name and suffixed with '_' + type name for every parameter
        '''
        # get imports
        import_pac_name_list = [import_pac.path for import_pac in tree.imports]

        pac_name = tree.package.name
        class_name_list = []
        func_name_dict = {}

        for node in tree.types:
            if not isinstance(node, (javalang.tree.ClassDeclaration, javalang.tree.InterfaceDeclaration)):
                continue
            class_name = pac_name + '.' + node.name
            class_name_list.append(class_name)

            for node_inner in node.body:
                if not isinstance(node_inner, javalang.tree.MethodDeclaration):
                    continue
                func_name = class_name + '.' + node_inner.name

                # add params name to func_name
                for params in node_inner.parameters:
                    func_name = func_name + '_' + params.type.name

                func_name_dict.setdefault(class_name, []).append(func_name)

        return {
            'pac_name': pac_name,
            'class_name_list': class_name_list,
            'func_name_dict': func_name_dict,
            'import_pac_name_list': import_pac_name_list
        }

    def multi_java_code_parse(self, tree_dict):
        '''
        parse multiple java code
        > tree_dict: dict of java code -> javalang tree
        < parse_result_dict: dict of java code -> result of single_java_code_parse
        raises ImportError (chained to the original error) when one tree cannot be processed
        '''
        res_dict = {}
        for java_code, tree in tree_dict.items():
            try:
                res_dict[java_code] = self.single_java_code_parse(tree)
            except Exception as e:
                logger.debug(java_code)
                # same exception type as before, but keep the cause for debugging
                raise ImportError('failed to extract entities from java code') from e

        return res_dict
|
||||
@ -0,0 +1,14 @@
|
||||
# encoding: utf-8
|
||||
'''
|
||||
@author: 温进
|
||||
@file: java_preprocess.py
|
||||
@time: 2023/10/23 下午5:04
|
||||
@desc:
|
||||
'''
|
||||
|
||||
class JavaPreprocessor:
    '''
    preprocessing step for java code; currently a pass-through
    '''

    def preprocess(self, java_code_dict):
        '''
        > java_code_dict: dict of java code
        < the very same dict object, unchanged
        '''
        preprocessed = java_code_dict
        return preprocessed
|
||||
7
dev_opsgpt/codebase_handler/tagger/__init__.py
Normal file
7
dev_opsgpt/codebase_handler/tagger/__init__.py
Normal file
@ -0,0 +1,7 @@
|
||||
# encoding: utf-8
|
||||
'''
|
||||
@author: 温进
|
||||
@file: __init__.py
|
||||
@time: 2023/10/23 下午5:00
|
||||
@desc:
|
||||
'''
|
||||
48
dev_opsgpt/codebase_handler/tagger/tagger.py
Normal file
48
dev_opsgpt/codebase_handler/tagger/tagger.py
Normal file
@ -0,0 +1,48 @@
|
||||
# encoding: utf-8
|
||||
'''
|
||||
@author: 温进
|
||||
@file: tagger.py
|
||||
@time: 2023/10/23 下午5:01
|
||||
@desc:
|
||||
'''
|
||||
import re
|
||||
from loguru import logger
|
||||
|
||||
|
||||
class Tagger:
    '''
    builds search tags from parse results and from user queries
    '''

    def __init__(self):
        pass

    def generate_tag(self, parse_res_dict: dict):
        '''
        generate tag from parse_res
        > parse_res_dict: dict of java code -> parse result with the keys pac_name,
          class_name_list and func_name_dict
        < dict of java code -> {'pac_name': str, 'class_name': set, 'func_name': set}
        '''
        res = {}
        for java_code, parse_res in parse_res_dict.items():
            func_names = set()
            for func_name_list in parse_res.get('func_name_dict', {}).values():
                func_names.update(func_name_list)

            # a missing class_name_list is treated like a missing func_name_dict: as empty
            res[java_code] = {
                'pac_name': parse_res.get('pac_name'),
                'class_name': set(parse_res.get('class_name_list', [])),
                'func_name': func_names,
            }
        return res

    def generate_tag_query(self, query):
        '''
        generate tag from query
        > query: user question
        < list of distinct runs of letters, '_' and '.', in order of first appearance;
          runs without any letter (e.g. '.' or '...') are dropped because they would
          match almost every node name
        '''
        # simple extract english
        candidates = re.findall(r'[a-zA-Z\_\.]+', query)
        # dict.fromkeys removes duplicates and, unlike set(), keeps a stable order
        return [tag for tag in dict.fromkeys(candidates) if any(ch.isalpha() for ch in tag)]
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # manual smoke check: log the tags extracted from a sample query
    sample_query = 'com.CheckHolder 有哪些函数'
    logger.debug(Tagger().generate_tag_query(sample_query))
|
||||
|
||||
|
||||
|
||||
51
dev_opsgpt/codebase_handler/tagger/tuple_generation.py
Normal file
51
dev_opsgpt/codebase_handler/tagger/tuple_generation.py
Normal file
@ -0,0 +1,51 @@
|
||||
# encoding: utf-8
|
||||
'''
|
||||
@author: 温进
|
||||
@file: tuple_generation.py
|
||||
@time: 2023/10/23 下午5:01
|
||||
@desc:
|
||||
'''
|
||||
|
||||
|
||||
def node_edge_update(parse_res_list: list, node_list: list = None, edge_list: list = None):
    '''
    generate node and edge by parse_res
    > parse_res_list: list of parse results; each is read through the keys
      'pac_name', 'class_name_list', 'func_name_dict', 'import_pac_name_list'
    > node_list: optional existing nodes as (name, attrs) pairs; not modified
    > edge_list: optional existing edge list; new edges are appended to it in
      place. A fresh list is created when omitted.
    < node: list of (node_name, {'type': ...}) pairs
    < edge: (node_st, relation, node_ed)
    '''
    # None defaults: the former `list()` defaults were shared between calls,
    # so edges accumulated from one call to the next.
    if edge_list is None:
        edge_list = []
    node_dict = {i: j for i, j in node_list} if node_list else {}

    for single_parse_res in parse_res_list:
        pac_name = single_parse_res['pac_name']
        node_dict[pac_name] = {'type': 'package'}

        # class_name
        for class_name in single_parse_res['class_name_list']:
            node_dict[class_name] = {'type': 'class'}
            edge_list.append((pac_name, 'contain', class_name))
            edge_list.append((class_name, 'inside', pac_name))

        # func_name
        # (the former `node_list.append(class_name)` is gone: it pushed a bare
        # string into a list of pairs and never affected the returned nodes)
        for class_name, func_name_list in single_parse_res['func_name_dict'].items():
            for func_name in func_name_list:
                node_dict[func_name] = {'type': 'func'}
                edge_list.append((class_name, 'contain', func_name))
                edge_list.append((func_name, 'inside', class_name))

        # depend
        for depend_pac_name in single_parse_res['import_pac_name_list']:
            if depend_pac_name.endswith('*'):
                # drop the trailing '.*' of a wildcard import
                depend_pac_name = depend_pac_name[0:-2]

            # NOTE(review): a dependency that is already a known node gets no
            # 'depend' edge at all; kept as-is, confirm this is intended.
            if depend_pac_name in node_dict:
                continue
            node_dict[depend_pac_name] = {'type': 'unknown'}
            edge_list.append((pac_name, 'depend', depend_pac_name))
            edge_list.append((depend_pac_name, 'beDepended', pac_name))

    return list(node_dict.items()), edge_list
|
||||
9
dev_opsgpt/connector/__init__.py
Normal file
9
dev_opsgpt/connector/__init__.py
Normal file
@ -0,0 +1,9 @@
|
||||
from .configs import PHASE_CONFIGS
|
||||
|
||||
|
||||
|
||||
# names of all configured phases, in the order PHASE_CONFIGS yields its keys
PHASE_LIST = list(PHASE_CONFIGS.keys())

# PHASE_LIST is a public constant of this package, so it is exported as well
__all__ = [
    "PHASE_CONFIGS", "PHASE_LIST"
]
|
||||
6
dev_opsgpt/connector/agents/__init__.py
Normal file
6
dev_opsgpt/connector/agents/__init__.py
Normal file
@ -0,0 +1,6 @@
|
||||
from .base_agent import BaseAgent
|
||||
from .react_agent import ReactAgent
|
||||
|
||||
__all__ = [
|
||||
"BaseAgent", "ReactAgent"
|
||||
]
|
||||
427
dev_opsgpt/connector/agents/base_agent.py
Normal file
427
dev_opsgpt/connector/agents/base_agent.py
Normal file
@ -0,0 +1,427 @@
|
||||
from pydantic import BaseModel
|
||||
from typing import List, Union
|
||||
import re
|
||||
import copy
|
||||
import json
|
||||
import traceback
|
||||
import uuid
|
||||
from loguru import logger
|
||||
|
||||
from dev_opsgpt.connector.shcema.memory import Memory
|
||||
from dev_opsgpt.connector.connector_schema import (
|
||||
Task, Role, Message, ActionStatus, Doc, CodeDoc
|
||||
)
|
||||
from configs.server_config import SANDBOX_SERVER
|
||||
from dev_opsgpt.sandbox import PyCodeBox, CodeBoxResponse
|
||||
from dev_opsgpt.tools import DDGSTool, DocRetrieval, CodeRetrieval
|
||||
from dev_opsgpt.connector.configs.agent_config import REACT_PROMPT_INPUT
|
||||
|
||||
from dev_opsgpt.llm_models import getChatModel
|
||||
|
||||
|
||||
class BaseAgent:
|
||||
|
||||
def __init__(
        self,
        role: Role,
        task: Task = None,
        memory: Memory = None,
        chat_turn: int = 1,
        do_search: bool = False,
        do_doc_retrieval: bool = False,
        do_tool_retrieval: bool = False,
        temperature: float = 0.2,
        stop: Union[List[str], str] = None,
        do_filter: bool = True,
        do_use_self_memory: bool = True,
        ):
    '''
    Store the agent settings and build its LLM client, memory and code sandbox.
    > role / task: role and optional default task used when building prompts
    > memory: forwarded to init_history
    > chat_turn, do_search, do_doc_retrieval, do_tool_retrieval: stored as-is
    > temperature, stop: forwarded to create_llm_engine
    > do_filter: whether run() post-processes the answer with filter()
    > do_use_self_memory: whether run() feeds the agent's own memory to the prompt
    '''
    self.task = task
    self.role = role
    self.llm = self.create_llm_engine(temperature, stop)
    self.memory = self.init_history(memory)

    self.chat_turn, self.do_search = chat_turn, do_search
    self.do_doc_retrieval, self.do_tool_retrieval = do_doc_retrieval, do_tool_retrieval

    # sandbox used by code_step to execute generated code
    sandbox_options = dict(
        remote_url=SANDBOX_SERVER["url"],
        remote_ip=SANDBOX_SERVER["host"],
        remote_port=SANDBOX_SERVER["port"],
        token="mytoken",
        do_code_exe=True,
        do_remote=SANDBOX_SERVER["do_remote"],
        do_check_net=False,
    )
    self.codebox = PyCodeBox(**sandbox_options)

    self.do_filter, self.do_use_self_memory = do_filter, do_use_self_memory
|
||||
|
||||
def run(self, query: Message, history: Memory = None, background: Memory = None) -> Message:
    '''
    llm inference: build the prompt, call the LLM once, parse (and optionally
    filter) the answer, then record both query and answer in the agent memory.
    > query: incoming message; a deep copy is used internally
    > history / background: optional memories forwarded to create_prompt
    < the parsed output message
    '''
    query_c = copy.deepcopy(query)

    self_memory = None
    if self.do_use_self_memory:
        self_memory = self.memory

    prompt = self.create_prompt(query_c, self_memory, history, background)
    content = self.llm.predict(prompt)
    logger.debug(f"{self.role.role_name} prompt: {prompt}")

    reply = Message(
        role_name=self.role.role_name,
        role_type="ai",
        role_content=content,
        role_contents=[content],
        input_query=query_c.input_query,
        tools=query_c.tools,
    )
    reply = self.parser(reply)
    if self.do_filter:
        reply = self.filter(reply)

    # keep the exchange (query first, then answer) in the agent's own memory
    for entry in (query_c, reply):
        self.append_history(entry)

    logger.info(f"{self.role.role_name} step_run: {reply.role_content}")
    return reply
|
||||
|
||||
def create_prompt(
        self, query: Message, memory: Memory =None, history: Memory = None, background: Memory = None, prompt_mamnger=None) -> str:
    '''
    Assemble the prompt from role, task, tools, docs and memories.
    > query: message supplying docs, tools, task and the input text
    > memory / history / background: optional memories rendered into the prompt
    > prompt_mamnger: unused
    < the final prompt, with every doubled brace collapsed to a single one
    '''
    doc_infos = self.create_doc_prompt(query)
    code_infos = self.create_codedoc_prompt(query)
    formatted_tools, tool_names = self.create_tools_prompt(query)
    task_prompt = self.create_task_prompt(query)
    background_prompt = self.create_background_prompt(background, control_key="step_content")
    history_prompt = self.create_history_prompt(history)
    selfmemory_prompt = self.create_selfmemory_prompt(memory, control_key="step_content")

    pieces = [self.role.role_prompt.format(formatted_tools=formatted_tools, tool_names=tool_names)]

    task = query.task or self.task
    if task_prompt is not None:
        # NOTE(review): the raw task text is appended, not the labelled
        # task_prompt computed above; confirm which one is intended.
        pieces.append("\n" + task.task_prompt)

    if doc_infos not in (None, "", "不存在知识库辅助信息"):
        pieces.append(f"\n知识库信息: {doc_infos}")

    if code_infos not in (None, "", "不存在代码库辅助信息"):
        pieces.append(f"\n代码库信息: {code_infos}")

    for optional_part in (background_prompt, history_prompt, selfmemory_prompt):
        if optional_part:
            pieces.append("\n" + optional_part)

    input_query = query.role_content
    logger.debug(f"{self.role.role_name} input_query: {input_query}")
    logger.debug(f"{self.role.role_name} tool_names: {tool_names}")
    pieces.append("\n" + REACT_PROMPT_INPUT.format(query=input_query))

    prompt = "".join(pieces)
    # escaped braces from the templates and memories are unescaped here
    while "{{" in prompt or "}}" in prompt:
        prompt = prompt.replace("{{", "{").replace("}}", "}")
    return prompt
|
||||
|
||||
def create_doc_prompt(self, message: Message) -> str:
|
||||
''''''
|
||||
db_docs = message.db_docs
|
||||
search_docs = message.search_docs
|
||||
doc_infos = "\n".join([doc.get_snippet() for doc in db_docs] + [doc.get_snippet() for doc in search_docs])
|
||||
return doc_infos or "不存在知识库辅助信息"
|
||||
|
||||
def create_codedoc_prompt(self, message: Message) -> str:
|
||||
''''''
|
||||
code_docs = message.code_docs
|
||||
doc_infos = "\n".join([doc.get_code() for doc in code_docs])
|
||||
return doc_infos or "不存在代码库辅助信息"
|
||||
|
||||
def create_tools_prompt(self, message: Message) -> str:
|
||||
tools = message.tools
|
||||
tool_strings = []
|
||||
for tool in tools:
|
||||
args_schema = re.sub("}", "}}}}", re.sub("{", "{{{{", str(tool.args)))
|
||||
tool_strings.append(f"{tool.name}: {tool.description}, args: {args_schema}")
|
||||
formatted_tools = "\n".join(tool_strings)
|
||||
tool_names = ", ".join([tool.name for tool in tools])
|
||||
return formatted_tools, tool_names
|
||||
|
||||
def create_task_prompt(self, message: Message) -> str:
|
||||
task = message.task or self.task
|
||||
return "\n任务目标: " + task.task_prompt if task is not None else None
|
||||
|
||||
def create_background_prompt(self, background: Memory, control_key="role_content") -> str:
|
||||
background_message = None if background is None else background.to_str_messages(content_key=control_key)
|
||||
# logger.debug(f"background_message: {background_message}")
|
||||
if background_message:
|
||||
background_message = re.sub("}", "}}", re.sub("{", "{{", background_message))
|
||||
return "\n背景信息: " + background_message if background_message else None
|
||||
|
||||
def create_history_prompt(self, history: Memory, control_key="role_content") -> str:
|
||||
history_message = None if history is None else history.to_str_messages(content_key=control_key)
|
||||
if history_message:
|
||||
history_message = re.sub("}", "}}", re.sub("{", "{{", history_message))
|
||||
return "\n补充对话信息: " + history_message if history_message else None
|
||||
|
||||
def create_selfmemory_prompt(self, selfmemory: Memory, control_key="role_content") -> str:
|
||||
selfmemory_message = None if selfmemory is None else selfmemory.to_str_messages(content_key=control_key)
|
||||
if selfmemory_message:
|
||||
selfmemory_message = re.sub("}", "}}", re.sub("{", "{{", selfmemory_message))
|
||||
return "\n补充自身对话信息: " + selfmemory_message if selfmemory_message else None
|
||||
|
||||
def init_history(self, memory: Memory = None) -> Memory:
    '''
    Return a new, empty Memory.
    NOTE(review): the memory argument is accepted but not used; confirm
    whether a caller-supplied memory is meant to be adopted.
    '''
    empty_memory = Memory([])
    return empty_memory
|
||||
|
||||
def update_history(self, message: Message):
|
||||
self.memory.append(message)
|
||||
|
||||
def append_history(self, message: Message):
|
||||
self.memory.append(message)
|
||||
|
||||
def clear_history(self, ):
    '''Clear the current memory, then replace it with what init_history() returns.'''
    previous_memory = self.memory
    previous_memory.clear()
    self.memory = self.init_history()
|
||||
|
||||
def create_llm_engine(self, temperature=0.2, stop=None):
    '''Build the chat model via getChatModel with the given temperature and stop setting.'''
    engine_options = {"temperature": temperature, "stop": stop}
    return getChatModel(**engine_options)
|
||||
|
||||
def filter(self, message: Message, stop=None) -> Message:
|
||||
|
||||
tool_params = self.parser_spec_key(message.role_content, "tool_params")
|
||||
code_content = self.parser_spec_key(message.role_content, "code_content")
|
||||
plan = self.parser_spec_key(message.role_content, "plan")
|
||||
plans = self.parser_spec_key(message.role_content, "plans", do_search=False)
|
||||
content = self.parser_spec_key(message.role_content, "content", do_search=False)
|
||||
|
||||
# logger.debug(f"tool_params: {tool_params}, code_content: {code_content}, plan: {plan}, plans: {plans}, content: {content}")
|
||||
role_content = tool_params or code_content or plan or plans or content
|
||||
message.role_content = role_content or message.role_content
|
||||
return message
|
||||
|
||||
def token_usage(self, ):
    '''Placeholder; does nothing and returns None.'''
    return None
|
||||
|
||||
def get_extra_infos(self, message: Message) -> Message:
|
||||
''''''
|
||||
if self.do_search:
|
||||
message = self.get_search_retrieval(message)
|
||||
|
||||
if self.do_doc_retrieval:
|
||||
message = self.get_doc_retrieval(message)
|
||||
|
||||
if self.do_tool_retrieval:
|
||||
message = self.get_tool_retrieval(message)
|
||||
|
||||
return message
|
||||
|
||||
def get_search_retrieval(self, message: Message,) -> Message:
    '''
    Query the duckduckgo tool with the message's role_content (second
    argument 3) and store the indexed results as Doc objects in
    message.search_docs.
    '''
    SEARCH_ENGINES = {"duckduckgo": DDGSTool}
    raw_results = SEARCH_ENGINES["duckduckgo"].run(message.role_content, 3)

    collected = []
    for position, raw_doc in enumerate(raw_results):
        # each result records its position before being wrapped
        raw_doc.update({"index": position})
        collected.append(Doc(**raw_doc))
    message.search_docs = collected
    return message
|
||||
|
||||
def get_doc_retrieval(self, message: Message) -> Message:
    '''
    When the message names a doc engine, run DocRetrieval with the message's
    role_content, top_k and score_threshold and store the results as Doc
    objects in message.db_docs. Otherwise the message is returned untouched.
    '''
    query = message.role_content
    knowledge_basename = message.doc_engine_name
    top_k = message.top_k
    score_threshold = message.score_threshold

    if not knowledge_basename:
        return message

    retrieved = DocRetrieval.run(query, knowledge_basename, top_k, score_threshold)
    message.db_docs = [Doc(**item) for item in retrieved]
    return message
|
||||
|
||||
def get_code_retrieval(self, message: Message) -> Message:
    '''
    Run CodeRetrieval for the message's input_query against its code engine
    and store the results as CodeDoc objects in message.code_docs.
    '''
    retrieved = CodeRetrieval.run(
        message.code_engine_name,
        message.input_query,
        code_limit=message.top_k,
        history_node_list=message.history_node_list,
    )
    message.code_docs = [CodeDoc(**item) for item in retrieved]
    return message
|
||||
|
||||
def get_tool_retrieval(self, message: Message) -> Message:
|
||||
return message
|
||||
|
||||
def step_router(self, message: Message) -> Message:
    '''
    Dispatch on the message's action_status: CODING goes to code_step,
    TOOL_USING to tool_step; any other status returns the message unchanged.
    '''
    status = message.action_status
    if status == ActionStatus.CODING:
        return self.code_step(message)
    if status == ActionStatus.TOOL_USING:
        return self.tool_step(message)
    return message
|
||||
|
||||
def code_step(self, message: Message) -> Message:
|
||||
'''execute code'''
|
||||
# logger.debug(f"message.role_content: {message.role_content}, message.code_content: {message.code_content}")
|
||||
code_answer = self.codebox.chat('```python\n{}```'.format(message.code_content))
|
||||
code_prompt = f"执行上述代码后存在报错信息为 {code_answer.code_exe_response},需要进行修复" \
|
||||
if code_answer.code_exe_type == "error" else f"执行上述代码后返回信息为 {code_answer.code_exe_response}"
|
||||
uid = str(uuid.uuid1())
|
||||
if code_answer.code_exe_type == "image/png":
|
||||
message.figures[uid] = code_answer.code_exe_response
|
||||
message.code_answer = f"\n观察: 执行上述代码后生成一张图片, 图片名为{uid}\n"
|
||||
message.observation = f"\n观察: 执行上述代码后生成一张图片, 图片名为{uid}\n"
|
||||
message.step_content += f"\n观察: 执行上述代码后生成一张图片, 图片名为{uid}\n"
|
||||
message.step_contents += [f"\n观察: 执行上述代码后生成一张图片, 图片名为{uid}\n"]
|
||||
message.role_content += f"\n观察:执行上述代码后生成一张图片, 图片名为{uid}\n"
|
||||
else:
|
||||
message.code_answer = code_answer.code_exe_response
|
||||
message.observation = code_answer.code_exe_response
|
||||
message.step_content += f"\n观察: {code_prompt}\n"
|
||||
message.step_contents += [f"\n观察: {code_prompt}\n"]
|
||||
message.role_content += f"\n观察: {code_prompt}\n"
|
||||
# logger.info(f"观察: {message.action_status}, {message.observation}")
|
||||
return message
|
||||
|
||||
def tool_step(self, message: Message) -> Message:
|
||||
'''execute tool'''
|
||||
# logger.debug(f"message: {message.action_status}, {message.tool_name}, {message.tool_params}")
|
||||
tool_names = [tool.name for tool in message.tools]
|
||||
if message.tool_name not in tool_names:
|
||||
message.tool_answer = "不存在可以执行的tool"
|
||||
message.observation = "不存在可以执行的tool"
|
||||
message.role_content += f"\n观察: 不存在可以执行的tool\n"
|
||||
message.step_content += f"\n观察: 不存在可以执行的tool\n"
|
||||
message.step_contents += [f"\n观察: 不存在可以执行的tool\n"]
|
||||
for tool in message.tools:
|
||||
if tool.name == message.tool_name:
|
||||
tool_res = tool.func(**message.tool_params)
|
||||
message.tool_answer = tool_res
|
||||
message.observation = tool_res
|
||||
message.role_content += f"\n观察: {tool_res}\n"
|
||||
message.step_content += f"\n观察: {tool_res}\n"
|
||||
message.step_contents += [f"\n观察: {tool_res}\n"]
|
||||
|
||||
# logger.info(f"观察: {message.action_status}, {message.observation}")
|
||||
return message
|
||||
|
||||
def parser(self, message: Message) -> Message:
|
||||
''''''
|
||||
content = message.role_content
|
||||
parser_keys = ["action", "code_content", "code_filename", "tool_params", "plans"]
|
||||
try:
|
||||
s_json = self._parse_json(content)
|
||||
message.action_status = s_json.get("action")
|
||||
message.code_content = s_json.get("code_content")
|
||||
message.tool_params = s_json.get("tool_params")
|
||||
message.tool_name = s_json.get("tool_name")
|
||||
message.code_filename = s_json.get("code_filename")
|
||||
message.plans = s_json.get("plans")
|
||||
# for parser_key in parser_keys:
|
||||
# message.action_status = content.get(parser_key)
|
||||
except Exception as e:
|
||||
# logger.warning(f"{traceback.format_exc()}")
|
||||
action_value = self._match(r"'action':\s*'([^']*)'", content) if "'action'" in content else self._match(r'"action":\s*"([^"]*)"', content)
|
||||
code_content_value = self._match(r"'code_content':\s*'([^']*)'", content) if "'code_content'" in content else self._match(r'"code_content":\s*"([^"]*)"', content)
|
||||
filename_value = self._match(r"'code_filename':\s*'([^']*)'", content) if "'code_filename'" in content else self._match(r'"code_filename":\s*"([^"]*)"', content)
|
||||
tool_params_value = self._match(r"'tool_params':\s*(\{[^{}]*\})", content, do_json=True) if "'tool_params'" in content \
|
||||
else self._match(r'"tool_params":\s*(\{[^{}]*\})', content, do_json=True)
|
||||
tool_name_value = self._match(r"'tool_name':\s*'([^']*)'", content) if "'tool_name'" in content else self._match(r'"tool_name":\s*"([^"]*)"', content)
|
||||
plans_value = self._match(r"'plans':\s*(\[.*?\])", content, do_search=False) if "'plans'" in content else self._match(r'"plans":\s*(\[.*?\])', content, do_search=False, )
|
||||
# re解析
|
||||
message.action_status = action_value or "default"
|
||||
message.code_content = code_content_value
|
||||
message.code_filename = filename_value
|
||||
message.tool_params = tool_params_value
|
||||
message.tool_name = tool_name_value
|
||||
message.plans = plans_value
|
||||
|
||||
# logger.debug(f"确认当前的action: {message.action_status}")
|
||||
|
||||
return message
|
||||
|
||||
def parser_spec_key(self, content, key, do_search=True, do_json=False) -> str:
    '''
    Extract the value of one key from the content.
    The structured result of _parse_json is consulted first and its value is
    returned as a string. Otherwise a regular expression for the key is applied
    through _match, preferring the single-quoted key form when it occurs in the
    content. Keys without a predefined pattern use a single-quoted string pattern.
    '''
    key2pattern = {
        "'action'": r"'action':\s*'([^']*)'",
        '"action"': r'"action":\s*"([^"]*)"',
        "'code_content'": r"'code_content':\s*'([^']*)'",
        '"code_content"': r'"code_content":\s*"([^"]*)"',
        "'code_filename'": r"'code_filename':\s*'([^']*)'",
        '"code_filename"': r'"code_filename":\s*"([^"]*)"',
        "'tool_params'": r"'tool_params':\s*(\{[^{}]*\})",
        '"tool_params"': r'"tool_params":\s*(\{[^{}]*\})',
        "'tool_name'": r"'tool_name':\s*'([^']*)'",
        '"tool_name"': r'"tool_name":\s*"([^"]*)"',
        "'plans'": r"'plans':\s*(\[.*?\])",
        '"plans"': r'"plans":\s*(\[.*?\])',
        "'content'": r"'content':\s*'([^']*)'",
        '"content"': r'"content":\s*"([^"]*)"',
    }

    s_json = self._parse_json(content)
    try:
        if s_json and key in s_json:
            return str(s_json[key])
    except BaseException:
        # any failure while probing the parsed object falls through to regex
        pass

    single_quoted = f"'{key}'"
    keystr = single_quoted if single_quoted in content else f'"{key}"'
    fallback_pattern = fr"'{key}':\s*'([^']*)'"
    pattern = key2pattern.get(keystr, fallback_pattern)
    return self._match(pattern, content, do_search=do_search, do_json=do_json)
|
||||
|
||||
def _match(self, pattern, s, do_search=True, do_json=False):
|
||||
try:
|
||||
if do_search:
|
||||
match = re.search(pattern, s)
|
||||
if match:
|
||||
value = match.group(1).replace("\\n", "\n")
|
||||
if do_json:
|
||||
value = json.loads(value)
|
||||
else:
|
||||
value = None
|
||||
else:
|
||||
match = re.findall(pattern, s, re.DOTALL)
|
||||
if match:
|
||||
value = match[0]
|
||||
if do_json:
|
||||
value = json.loads(value)
|
||||
else:
|
||||
value = None
|
||||
except Exception as e:
|
||||
logger.warning(f"{traceback.format_exc()}")
|
||||
|
||||
# logger.debug(f"pattern: {pattern}, s: {s}, match: {match}")
|
||||
return value
|
||||
|
||||
def _parse_json(self, s):
|
||||
try:
|
||||
pattern = r"```([^`]+)```"
|
||||
match = re.findall(pattern, s)
|
||||
if match:
|
||||
return eval(match[0])
|
||||
except:
|
||||
pass
|
||||
return None
|
||||
|
||||
|
||||
def get_memory(self, ):
    '''Return the agent memory as tuple messages built from each message's step_content.'''
    agent_memory = self.memory
    return agent_memory.to_tuple_messages(content_key="step_content")
|
||||
|
||||
def get_memory_str(self, ):
    '''
    Return the agent memory as text: one line per tuple message (built from
    step_content), with the tuple's parts joined by ": ".
    '''
    lines = []
    for parts in self.memory.to_tuple_messages(content_key="step_content"):
        lines.append(": ".join(parts))
    return "\n".join(lines)
|
||||
138
dev_opsgpt/connector/agents/react_agent.py
Normal file
138
dev_opsgpt/connector/agents/react_agent.py
Normal file
@ -0,0 +1,138 @@
|
||||
from pydantic import BaseModel
|
||||
from typing import List, Union
|
||||
import re
|
||||
import traceback
|
||||
import copy
|
||||
from loguru import logger
|
||||
|
||||
from langchain.prompts.chat import ChatPromptTemplate
|
||||
|
||||
from dev_opsgpt.connector.connector_schema import Message
|
||||
from dev_opsgpt.connector.shcema.memory import Memory
|
||||
from dev_opsgpt.connector.connector_schema import Task, Env, Role, Message, ActionStatus
|
||||
from dev_opsgpt.llm_models import getChatModel
|
||||
from dev_opsgpt.connector.configs.agent_config import REACT_PROMPT_INPUT
|
||||
|
||||
from .base_agent import BaseAgent
|
||||
|
||||
|
||||
class ReactAgent(BaseAgent):
|
||||
def __init__(
        self,
        role: Role,
        task: Task = None,
        memory: Memory = None,
        chat_turn: int = 1,
        do_search: bool = False,
        do_doc_retrieval: bool = False,
        do_tool_retrieval: bool = False,
        temperature: float = 0.2,
        stop: Union[List[str], str] = "观察",
        do_filter: bool = True,
        do_use_self_memory: bool = True,
        ):
    '''
    Forward every setting unchanged to BaseAgent.__init__.
    The only difference in the signature is the default stop, "观察".
    '''
    super().__init__(
        role=role,
        task=task,
        memory=memory,
        chat_turn=chat_turn,
        do_search=do_search,
        do_doc_retrieval=do_doc_retrieval,
        do_tool_retrieval=do_tool_retrieval,
        temperature=temperature,
        stop=stop,
        do_filter=do_filter,
        do_use_self_memory=do_use_self_memory,
    )
|
||||
|
||||
def run(self, query: Message, history: Memory = None, background: Memory = None) -> Message:
    '''
    Run the react loop for at most chat_turn iterations: prompt the LLM, parse
    its answer, then execute the requested code or tool so that its
    observation feeds the next iteration.
    > query: message supplying the input query and the tools
    > history / background: optional memories forwarded to create_prompt
    < the output message of the loop; it is appended to the agent memory once,
      after the loop
    Raises Exception (carrying the formatted traceback) when the LLM call fails.
    '''
    step_nums = copy.deepcopy(self.chat_turn)
    react_memory = Memory([])
    # the question is the first entry of the react memory
    output_message = Message(
        role_name=self.role.role_name,
        role_type="ai", #self.role.role_type,
        role_content=query.input_query,
        step_content=query.input_query,
        input_query=query.input_query,
        tools=query.tools
        )
    react_memory.append(output_message)

    while step_nums > 0:
        output_message.role_content = output_message.step_content
        self_memory = self.memory if self.do_use_self_memory else None
        prompt = self.create_prompt(query, self_memory, history, background, react_memory)
        try:
            content = self.llm.predict(prompt)
        except Exception as e:
            logger.warning(f"error prompt: {prompt}")
            # same exception type as before, now chained to its cause
            raise Exception(traceback.format_exc()) from e

        output_message.role_content = content
        output_message.role_contents += [content]
        output_message.step_content += output_message.role_content
        # this used to be `step_contents + [...]`, which built a list and
        # discarded it, so the LLM output never reached step_contents
        output_message.step_contents += [output_message.role_content]

        output_message = self.parser(output_message)
        # when get finished signal can stop early
        if output_message.action_status == ActionStatus.FINISHED:
            break
        # according the output to choose one action for code_content or tool_content
        output_message = self.step_router(output_message)
        logger.info(f"{self.role.role_name} react_run: {output_message.role_content}")

        step_nums -= 1

    # react' self_memory saved at last
    self.append_history(output_message)
    return output_message
|
||||
|
||||
def create_prompt(
        self, query: Message, memory: Memory =None, history: Memory = None, background: Memory = None, react_memory: Memory = None, prompt_mamnger=None) -> str:
    '''
    Assemble the prompt from role, task, tools, docs and memories; the input
    section is built from the step_content of the react memory.
    > query: message supplying docs, tools and task
    > memory / history / background: optional memories rendered into the prompt
    > react_memory: memory of the current react loop; it is dereferenced
      unconditionally, so it must not be None
    > prompt_mamnger: unused
    < the final prompt, with every doubled brace collapsed to a single one
    '''
    doc_infos = self.create_doc_prompt(query)
    code_infos = self.create_codedoc_prompt(query)
    formatted_tools, tool_names = self.create_tools_prompt(query)
    task_prompt = self.create_task_prompt(query)
    background_prompt = self.create_background_prompt(background)
    history_prompt = self.create_history_prompt(history)
    selfmemory_prompt = self.create_selfmemory_prompt(memory, control_key="step_content")

    pieces = [self.role.role_prompt.format(formatted_tools=formatted_tools, tool_names=tool_names)]

    task = query.task or self.task
    if task_prompt is not None:
        pieces.append("\n" + task.task_prompt)

    if doc_infos not in (None, "", "不存在知识库辅助信息"):
        pieces.append(f"\n知识库信息: {doc_infos}")

    if code_infos not in (None, "", "不存在代码库辅助信息"):
        pieces.append(f"\n代码库信息: {code_infos}")

    for optional_part in (background_prompt, history_prompt, selfmemory_prompt):
        if optional_part:
            pieces.append("\n" + optional_part)

    # the react loop iterates on itself, so its own steps form the input
    react_steps = react_memory.to_tuple_messages(content_key="step_content")
    input_query = "\n".join([f"{v}" for k, v in react_steps if v])
    pieces.append("\n" + REACT_PROMPT_INPUT.format(query=input_query))

    prompt = "".join(pieces)
    while "{{" in prompt or "}}" in prompt:
        prompt = prompt.replace("{{", "{").replace("}}", "}")
    return prompt
|
||||
|
||||
5
dev_opsgpt/connector/chains/__init__.py
Normal file
5
dev_opsgpt/connector/chains/__init__.py
Normal file
@ -0,0 +1,5 @@
|
||||
from .base_chain import BaseChain
|
||||
|
||||
__all__ = [
|
||||
"BaseChain"
|
||||
]
|
||||
281
dev_opsgpt/connector/chains/base_chain.py
Normal file
281
dev_opsgpt/connector/chains/base_chain.py
Normal file
@ -0,0 +1,281 @@
|
||||
from pydantic import BaseModel
|
||||
from typing import List
|
||||
import json
|
||||
import re
|
||||
from loguru import logger
|
||||
import traceback
|
||||
import uuid
|
||||
import copy
|
||||
|
||||
from dev_opsgpt.connector.agents import BaseAgent
|
||||
from dev_opsgpt.tools.base_tool import BaseTools, Tool
|
||||
from dev_opsgpt.connector.shcema.memory import Memory
|
||||
from dev_opsgpt.connector.connector_schema import (
|
||||
Role, Message, ActionStatus, ChainConfig,
|
||||
load_role_configs
|
||||
)
|
||||
from dev_opsgpt.sandbox import PyCodeBox, CodeBoxResponse
|
||||
|
||||
|
||||
from configs.server_config import SANDBOX_SERVER
|
||||
|
||||
from dev_opsgpt.connector.configs.agent_config import AGETN_CONFIGS
|
||||
role_configs = load_role_configs(AGETN_CONFIGS)
|
||||
|
||||
|
||||
class BaseChain:
|
||||
def __init__(
        self,
        chainConfig: ChainConfig,
        agents: List[BaseAgent],
        chat_turn: int = 1,
        do_checker: bool = False,
        do_code_exec: bool = False,
        ) -> None:
    '''
    Store the chain settings and build the checker agent, the chain memories
    and the code sandbox.
    > chainConfig, agents, chat_turn, do_checker, do_code_exec: stored as-is
    The checker is a BaseAgent configured from role_configs["checker"], with
    filtering and self memory disabled.
    '''
    self.chainConfig = chainConfig
    self.agents = agents
    self.chat_turn = chat_turn
    self.do_checker = do_checker

    checker_config = role_configs["checker"]
    self.checker = BaseAgent(
        role=checker_config.role,
        task=None,
        memory=None,
        do_search=checker_config.do_search,
        do_doc_retrieval=checker_config.do_doc_retrieval,
        do_tool_retrieval=checker_config.do_tool_retrieval,
        do_filter=False,
        do_use_self_memory=False,
    )

    self.global_memory = Memory([])
    self.local_memory = Memory([])
    self.do_code_exec = do_code_exec

    sandbox_options = dict(
        remote_url=SANDBOX_SERVER["url"],
        remote_ip=SANDBOX_SERVER["host"],
        remote_port=SANDBOX_SERVER["port"],
        token="mytoken",
        do_code_exe=True,
        do_remote=SANDBOX_SERVER["do_remote"],
        do_check_net=False,
    )
    self.codebox = PyCodeBox(**sandbox_options)
|
||||
|
||||
def step(self, query: Message, history: Memory = None, background: Memory = None) -> Message:
    '''
    execute chain: run every agent in order, each one receiving the previous
    agent's output, for at most chat_turn rounds.
    > query: initial message; a deep copy is fed to the first agent
    > history / background: optional memories forwarded to each agent's run
    < a pair (message, local_memory): the checker's message when one was
      produced, otherwise the last agent output, plus the memory of this call.
      NOTE: the return annotation says Message but a tuple is returned.
    NOTE(review): output_message is unbound at the return when chat_turn <= 0
    or there are no agents and no checker message; confirm callers never do that.
    '''
    # memory of this call only; self.global_memory accumulates across calls
    local_memory = Memory([])
    input_message = copy.deepcopy(query)
    step_nums = copy.deepcopy(self.chat_turn)
    check_message = None

    self.global_memory.append(input_message)
    local_memory.append(input_message)
    while step_nums > 0:

        for agent in self.agents:
            output_message = agent.run(input_message, history, background=background)
            output_message = self.inherit_extrainfo(input_message, output_message)
            # according the output to choose one action for code_content or tool_content
            logger.info(f"{agent.role.role_name} message: {output_message.role_content}")
            output_message = self.parser(output_message)
            # output_message = self.step_router(output_message)

            # the next agent continues from this agent's output
            input_message = output_message
            self.global_memory.append(output_message)

            local_memory.append(output_message)
            # when get finished signal can stop early
            # NOTE(review): this break leaves only the agent loop, not the
            # surrounding round loop (nesting reconstructed from a source
            # with stripped indentation)
            if output_message.action_status == ActionStatus.FINISHED:
                break

        if self.do_checker:
            logger.debug(f"{self.checker.role.role_name} input global memory: {self.global_memory.to_str_messages(content_key='step_content')}")
            # the checker judges the original query against everything said so far
            check_message = self.checker.run(query, background=self.global_memory)
            check_message = self.parser(check_message)
            check_message = self.filter(check_message)
            check_message = self.inherit_extrainfo(output_message, check_message)
            logger.debug(f"{self.checker.role.role_name}: {check_message.role_content}")

            if check_message.action_status == ActionStatus.FINISHED:
                self.global_memory.append(check_message)
                break

        step_nums -= 1

    return check_message or output_message, local_memory
|
||||
|
||||
def step_router(self, message: Message) -> Message:
    '''
    Dispatch on the message's action_status: CODING goes to code_step,
    TOOL_USING to tool_step; any other status returns the message unchanged.
    '''
    status = message.action_status
    if status == ActionStatus.CODING:
        return self.code_step(message)
    if status == ActionStatus.TOOL_USING:
        return self.tool_step(message)
    return message
|
||||
|
||||
def code_step(self, message: Message) -> Message:
    """Run ``message.code_content`` through the codebox and record the result.

    The code is wrapped in a ```python fence and passed to
    ``self.codebox.chat``. When the reply's ``code_exe_type`` is
    ``"image/png"`` the response is stored in ``message.figures`` under a
    fresh uuid1 key and the observation names that key; otherwise the
    response itself becomes the observation. In both cases the observation
    text is appended to ``step_content``, ``step_contents`` and
    ``role_content``.
    """
    fenced_code = '```python\n{}```'.format(message.code_content)
    result = self.codebox.chat(fenced_code)
    uid = str(uuid.uuid1())

    if result.code_exe_type == "image/png":
        figure_note = f"\n观察: 执行代码后获得输出一张图片, 文件名为{uid}\n"
        message.figures[uid] = result.code_exe_response
        message.code_answer = figure_note
        message.observation = figure_note
        message.step_content += figure_note
        message.step_contents += [figure_note]
        # The role_content variant has no "观察: " prefix.
        message.role_content += f"\n执行代码后获得输出一张图片, 文件名为{uid}\n"
    else:
        output = result.code_exe_response
        output_note = f"\n观察: 执行代码后获得输出是 {output}\n"
        message.code_answer = output
        message.observation = output
        message.step_content += output_note
        message.step_contents += [output_note]
        message.role_content += output_note

    logger.info(f"观察: {message.action_status}, {message.observation}")
    return message
|
||||
|
||||
def tool_step(self, message: Message) -> Message:
    """Call the tool named by ``message.tool_name`` and record its result.

    If no entry of ``message.tools`` carries that name, a fixed
    "no executable tool" observation is recorded. Otherwise every tool
    whose ``name`` matches is invoked as ``tool.func(**message.tool_params)``
    and its result is stored as ``tool_answer`` / ``observation`` and
    appended to ``role_content``, ``step_content`` and ``step_contents``.
    """
    available = tuple(candidate.name for candidate in message.tools)
    if message.tool_name not in available:
        missing_note = "\n观察: 不存在可以执行的tool\n"
        message.tool_answer = "不存在可以执行的tool"
        message.observation = "不存在可以执行的tool"
        message.role_content += missing_note
        message.step_content += missing_note
        message.step_contents += [missing_note]
        # No tool can match below, so there is nothing left to run.
        return message

    for candidate in message.tools:
        if candidate.name != message.tool_name:
            continue
        tool_res = candidate.func(**message.tool_params)
        result_note = f"\n观察: {tool_res}\n"
        message.tool_answer = tool_res
        message.observation = tool_res
        message.role_content += result_note
        message.step_content += result_note
        message.step_contents += [result_note]
    return message
|
||||
|
||||
def filter(self, message: Message, stop=None) -> Message:
    """Replace ``message.role_content`` with the first extractable payload.

    ``parser_spec_key`` is queried for ``tool_params``, ``code_content``,
    ``plan``, ``plans`` and ``content`` (the last two with
    ``do_search=False``). The first truthy result, in that order, becomes
    the new ``role_content``; when none is truthy the content is left as is.
    ``stop`` is accepted but not used.
    """
    lookups = (
        ("tool_params", {}),
        ("code_content", {}),
        ("plan", {}),
        ("plans", {"do_search": False}),
        ("content", {"do_search": False}),
    )
    # All five lookups always run, in this order.
    extracted = [
        self.parser_spec_key(message.role_content, key, **options)
        for key, options in lookups
    ]
    role_content = next((item for item in extracted if item), extracted[-1])
    message.role_content = role_content or message.role_content
    return message
|
||||
|
||||
def parser(self, message: Message) -> Message:
    """Populate the action fields of *message* from its ``role_content``.

    The fenced blob returned by ``_parse_json`` is used when it is a dict.
    Otherwise each field is extracted with a regular expression, choosing
    the single-quoted pattern when the single-quoted key occurs in the
    content and the double-quoted pattern if not.

    ``action_status`` falls back to ``"default"`` on both paths when no
    action is found. Fields set: ``action_status``, ``code_content``,
    ``code_filename``, ``tool_params``, ``tool_name`` and ``plans``.
    """
    content = message.role_content
    parsed = self._parse_json(content)

    if isinstance(parsed, dict):
        # Same default as the regex path, so the status is never None.
        message.action_status = parsed.get("action") or "default"
        message.code_content = parsed.get("code_content")
        message.tool_params = parsed.get("tool_params")
        message.tool_name = parsed.get("tool_name")
        message.code_filename = parsed.get("code_filename")
        message.plans = parsed.get("plans")
    else:
        def find(key, single_quoted, double_quoted, **match_kwargs):
            # Pick the pattern whose quoting style appears in the content.
            pattern = single_quoted if f"'{key}'" in content else double_quoted
            return self._match(pattern, content, **match_kwargs)

        def find_string(key):
            return find(key, rf"'{key}':\s*'([^']*)'", rf'"{key}":\s*"([^"]*)"')

        action_value = find_string("action")
        code_content_value = find_string("code_content")
        filename_value = find_string("code_filename")
        tool_params_value = find(
            "tool_params",
            r"'tool_params':\s*(\{[^{}]*\})",
            r'"tool_params":\s*(\{[^{}]*\})',
            do_json=True,
        )
        tool_name_value = find_string("tool_name")
        plans_value = find(
            "plans",
            r"'plans':\s*(\[.*?\])",
            r'"plans":\s*(\[.*?\])',
            do_search=False,
        )

        message.action_status = action_value or "default"
        message.code_content = code_content_value
        message.code_filename = filename_value
        message.tool_params = tool_params_value
        message.tool_name = tool_name_value
        message.plans = plans_value

    logger.debug(f"确认当前的action: {message.action_status}")
    return message
|
||||
|
||||
def parser_spec_key(self, content, key, do_search=True, do_json=False) -> str:
    """Extract the value stored under *key* in *content*.

    If ``_parse_json`` yields a dict containing *key*, that value is
    returned as ``str``. Otherwise the value is searched for with a regular
    expression via ``_match``, to which ``do_search`` and ``do_json`` are
    forwarded; the result of ``_match`` is returned as is.

    Keys missing from the pattern table get a quoted-string pattern that
    follows the quoting style found in *content*.
    """
    key2pattern = {
        "'action'": r"'action':\s*'([^']*)'", '"action"': r'"action":\s*"([^"]*)"',
        "'code_content'": r"'code_content':\s*'([^']*)'", '"code_content"': r'"code_content":\s*"([^"]*)"',
        "'code_filename'": r"'code_filename':\s*'([^']*)'", '"code_filename"': r'"code_filename":\s*"([^"]*)"',
        "'tool_params'": r"'tool_params':\s*(\{[^{}]*\})", '"tool_params"': r'"tool_params":\s*(\{[^{}]*\})',
        "'tool_name'": r"'tool_name':\s*'([^']*)'", '"tool_name"': r'"tool_name":\s*"([^"]*)"',
        "'plans'": r"'plans':\s*(\[.*?\])", '"plans"': r'"plans":\s*(\[.*?\])',
        "'content'": r"'content':\s*'([^']*)'", '"content"': r'"content":\s*"([^"]*)"',
    }

    s_json = self._parse_json(content)
    # Only a dict can be indexed by key; anything else goes to the regex path.
    if isinstance(s_json, dict) and key in s_json:
        return str(s_json[key])

    if f"'{key}'" in content:
        keystr = f"'{key}'"
        fallback = rf"'{re.escape(key)}':\s*'([^']*)'"
    else:
        keystr = f'"{key}"'
        # Double-quoted fallback, so unknown keys match double-quoted content.
        fallback = rf'"{re.escape(key)}":\s*"([^"]*)"'
    return self._match(key2pattern.get(keystr, fallback), content, do_search=do_search, do_json=do_json)
|
||||
|
||||
def _match(self, pattern, s, do_search=True, do_json=False):
|
||||
try:
|
||||
if do_search:
|
||||
match = re.search(pattern, s)
|
||||
if match:
|
||||
value = match.group(1).replace("\\n", "\n")
|
||||
if do_json:
|
||||
value = json.loads(value)
|
||||
else:
|
||||
value = None
|
||||
else:
|
||||
match = re.findall(pattern, s, re.DOTALL)
|
||||
if match:
|
||||
value = match[0]
|
||||
if do_json:
|
||||
value = json.loads(value)
|
||||
else:
|
||||
value = None
|
||||
except Exception as e:
|
||||
logger.warning(f"{traceback.format_exc()}")
|
||||
|
||||
# logger.debug(f"pattern: {pattern}, s: {s}, match: {match}")
|
||||
return value
|
||||
|
||||
def _parse_json(self, s):
|
||||
try:
|
||||
pattern = r"```([^`]+)```"
|
||||
match = re.findall(pattern, s)
|
||||
if match:
|
||||
return eval(match[0])
|
||||
except:
|
||||
pass
|
||||
return None
|
||||
|
||||
def inherit_extrainfo(self, input_message: Message, output_message: Message):
    """Carry retrieval results and figures from *input_message* to *output_message*.

    ``db_docs``, ``search_docs`` and ``code_docs`` are overwritten on the
    output message; ``figures`` of the input are merged into the output's
    ``figures``. The output message is returned.
    """
    for field in ("db_docs", "search_docs", "code_docs"):
        setattr(output_message, field, getattr(input_message, field))
    output_message.figures.update(input_message.figures)
    return output_message
|
||||
|
||||
def get_memory(self, do_all_memory=True, content_key="role_content") -> Memory:
    """Return ``to_tuple_messages(content_key=...)`` of the selected memory.

    ``self.global_memory`` is used when ``do_all_memory`` is true,
    ``self.local_memory`` otherwise.
    """
    # NOTE(review): the value returned is whatever to_tuple_messages yields,
    # which may not be a Memory instance -- confirm the annotation.
    if do_all_memory:
        selected = self.global_memory
    else:
        selected = self.local_memory
    return selected.to_tuple_messages(content_key=content_key)
|
||||
|
||||
def get_memory_str(self, do_all_memory=True, content_key="role_content") -> str:
    """Render the selected memory as text, one message per line.

    ``self.global_memory`` is used when ``do_all_memory`` is true,
    ``self.local_memory`` otherwise. Each tuple from
    ``to_tuple_messages(content_key=...)`` is joined with ``": "`` and the
    resulting lines are joined with newlines.
    """
    memory = self.global_memory if do_all_memory else self.local_memory
    return "\n".join(
        ": ".join(entry)
        for entry in memory.to_tuple_messages(content_key=content_key)
    )
|
||||
|
||||
def get_agents_memory(self, content_key="role_content"):
    """Collect ``get_memory(content_key=...)`` from every agent, in order."""
    memories = []
    for agent in self.agents:
        memories.append(agent.get_memory(content_key=content_key))
    return memories
|
||||
|
||||
def get_agents_memory_str(self, content_key="role_content"):
    """Join every agent's memory text into one string.

    Each section is the agent's role name, a newline, then its
    ``get_memory_str(content_key=...)``; sections are separated by twelve
    asterisks.
    """
    sections = []
    for agent in self.agents:
        header = f"{agent.role.role_name}\n"
        sections.append(header + agent.get_memory_str(content_key=content_key))
    return "************".join(sections)
|
||||
28
dev_opsgpt/connector/chains/chains.py
Normal file
28
dev_opsgpt/connector/chains/chains.py
Normal file
@ -0,0 +1,28 @@
|
||||
from typing import List
|
||||
from dev_opsgpt.connector.agents import BaseAgent
|
||||
from .base_chain import BaseChain
|
||||
|
||||
|
||||
|
||||
class simpleChatChain(BaseChain):
    """``BaseChain`` subclass that adds no behaviour of its own."""

    def __init__(self, agents: List[BaseAgent], do_code_exec: bool = False) -> None:
        """Forward *agents* and *do_code_exec* to ``BaseChain.__init__``."""
        super(simpleChatChain, self).__init__(agents, do_code_exec)
|
||||
|
||||
|
||||
class toolChatChain(BaseChain):
    """``BaseChain`` subclass that adds no behaviour of its own."""

    def __init__(self, agents: List[BaseAgent], do_code_exec: bool = False) -> None:
        """Forward *agents* and *do_code_exec* to ``BaseChain.__init__``."""
        super(toolChatChain, self).__init__(agents, do_code_exec)
|
||||
|
||||
|
||||
class dataAnalystChain(BaseChain):
    """``BaseChain`` subclass that adds no behaviour of its own."""

    def __init__(self, agents: List[BaseAgent], do_code_exec: bool = False) -> None:
        """Forward *agents* and *do_code_exec* to ``BaseChain.__init__``."""
        super(dataAnalystChain, self).__init__(agents, do_code_exec)
|
||||
|
||||
|
||||
class plannerChain(BaseChain):
    """``BaseChain`` subclass that adds no behaviour of its own."""

    def __init__(self, agents: List[BaseAgent], do_code_exec: bool = False) -> None:
        """Forward *agents* and *do_code_exec* to ``BaseChain.__init__``."""
        super(plannerChain, self).__init__(agents, do_code_exec)
|
||||
7
dev_opsgpt/connector/configs/__init__.py
Normal file
7
dev_opsgpt/connector/configs/__init__.py
Normal file
@ -0,0 +1,7 @@
|
||||
from .agent_config import AGETN_CONFIGS
|
||||
from .chain_config import CHAIN_CONFIGS
|
||||
from .phase_config import PHASE_CONFIGS
|
||||
|
||||
__all__ = [
|
||||
"AGETN_CONFIGS", "CHAIN_CONFIGS", "PHASE_CONFIGS"
|
||||
]
|
||||
410
dev_opsgpt/connector/configs/agent_config.py
Normal file
410
dev_opsgpt/connector/configs/agent_config.py
Normal file
@ -0,0 +1,410 @@
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class AgentType:
    """String names of the agent classes referenced by agent configs."""

    REACT = "ReactAgent"
    ONE_STEP = "BaseAgent"
    # The default agent is the one-step agent.
    DEFAULT = ONE_STEP
|
||||
|
||||
|
||||
REACT_TOOL_PROMPT = """尽可能地以有帮助和准确的方式回应人类。您可以使用以下工具:
|
||||
{formatted_tools}
|
||||
使用json blob来指定一个工具,提供一个action关键字(工具名称)和一个tool_params关键字(工具输入)。
|
||||
有效的"action"值为:"finished" 或 "tool_using" (使用工具来回答问题)
|
||||
有效的"tool_name"值为:{tool_names}
|
||||
请仅在每个$JSON_BLOB中提供一个action,如下所示:
|
||||
```
|
||||
{{{{
|
||||
"action": $ACTION,
|
||||
"tool_name": $TOOL_NAME
|
||||
"tool_params": $INPUT
|
||||
}}}}
|
||||
```
|
||||
|
||||
按照以下格式进行回应:
|
||||
问题:输入问题以回答
|
||||
思考:考虑之前和之后的步骤
|
||||
行动:
|
||||
```
|
||||
$JSON_BLOB
|
||||
```
|
||||
观察:行动结果
|
||||
...(重复思考/行动/观察N次)
|
||||
思考:我知道该如何回应
|
||||
行动:
|
||||
```
|
||||
{{{{
|
||||
"action": "finished",
|
||||
"tool_name": "notool"
|
||||
"tool_params": "最终返回答案给到用户"
|
||||
}}}}
|
||||
```
|
||||
"""
|
||||
|
||||
REACT_PROMPT_INPUT = '''下面开始!记住根据问题进行返回需要生成的答案
|
||||
问题: {query}'''
|
||||
|
||||
|
||||
REACT_CODE_PROMPT = """尽可能地以有帮助和准确的方式回应人类,能够逐步编写可执行并打印变量的代码来解决问题
|
||||
使用 JSON Blob 来指定一个返回的内容,提供一个 action(行动)和一个 code (生成代码)。
|
||||
有效的 'action' 值为:'coding'(结合总结下述思维链过程编写下一步的可执行代码) or 'finished' (总结下述思维链过程可回答问题)。
|
||||
在每个 $JSON_BLOB 中仅提供一个 action,如下所示:
|
||||
```
|
||||
{{{{'action': $ACTION,'code_content': $CODE}}}}
|
||||
```
|
||||
|
||||
按照以下思维链格式进行回应:
|
||||
问题:输入问题以回答
|
||||
思考:考虑之前和之后的步骤
|
||||
行动:
|
||||
```
|
||||
$JSON_BLOB
|
||||
```
|
||||
观察:行动结果
|
||||
...(重复思考/行动/观察N次)
|
||||
思考:我知道该如何回应
|
||||
行动:
|
||||
```
|
||||
{{{{
|
||||
"action": "finished",
|
||||
"code_content": "总结上述思维链过程回答问题"
|
||||
}}}}
|
||||
```
|
||||
"""
|
||||
|
||||
GENERAL_PLANNER_PROMPT = """你是一个通用计划拆解助手,将问题拆解问题成各个详细明确的步骤计划或直接回答问题,尽可能地以有帮助和准确的方式回应人类,
|
||||
使用 JSON Blob 来指定一个返回的内容,提供一个 action(行动)和一个 plans (生成的计划)。
|
||||
有效的 'action' 值为:'planning'(拆解计划) or 'only_answer' (不需要拆解问题即可直接回答问题)。
|
||||
有效的 'plans' 值为: 一个任务列表,按顺序写出需要执行的计划
|
||||
在每个 $JSON_BLOB 中仅提供一个 action,如下所示:
|
||||
```
|
||||
{{'action': 'planning', 'plans': [$PLAN1, $PLAN2, $PLAN3, ..., $PLANN], }}
|
||||
或者
|
||||
{{'action': 'only_answer', 'plans': "直接回答问题", }}
|
||||
```
|
||||
|
||||
按照以下格式进行回应:
|
||||
问题:输入问题以回答
|
||||
行动:
|
||||
```
|
||||
$JSON_BLOB
|
||||
```
|
||||
"""
|
||||
|
||||
DATA_PLANNER_PROMPT = """你是一个数据分析助手,能够根据问题来制定一个详细明确的数据分析计划,尽可能地以有帮助和准确的方式回应人类,
|
||||
使用 JSON Blob 来指定一个返回的内容,提供一个 action(行动)和一个 plans (生成的计划)。
|
||||
有效的 'action' 值为:'planning'(拆解计划) or 'only_answer' (不需要拆解问题即可直接回答问题)。
|
||||
有效的 'plans' 值为: 一份数据分析计划清单,按顺序排列,用文本表示
|
||||
在每个 $JSON_BLOB 中仅提供一个 action,如下所示:
|
||||
```
|
||||
{{'action': 'planning', 'plans': '$PLAN1, $PLAN2, ..., $PLAN3' }}
|
||||
```
|
||||
|
||||
按照以下格式进行回应:
|
||||
问题:输入问题以回答
|
||||
行动:
|
||||
```
|
||||
$JSON_BLOB
|
||||
```
|
||||
"""
|
||||
|
||||
TOOL_PLANNER_PROMPT = """你是一个工具使用过程的计划拆解助手,将问题拆解为一系列的工具使用计划,若没有可用工具则直接回答问题,尽可能地以有帮助和准确的方式回应人类,你可以使用以下工具:
|
||||
{formatted_tools}
|
||||
使用 JSON Blob 来指定一个返回的内容,提供一个 action(行动)和一个 plans (生成的计划)。
|
||||
有效的 'action' 值为:'planning'(拆解计划) or 'only_answer' (不需要拆解问题即可直接回答问题)。
|
||||
有效的 'plans' 值为: 一个任务列表,按顺序写出需要使用的工具和使用该工具的理由
|
||||
在每个 $JSON_BLOB 中仅提供一个 action,如下两个示例所示:
|
||||
```
|
||||
{{'action': 'planning', 'plans': [$PLAN1, $PLAN2, $PLAN3, ..., $PLANN], }}
|
||||
```
|
||||
或者 若无法通过以上工具解决问题,则直接回答问题
|
||||
```
|
||||
{{'action': 'only_answer', 'plans': "直接回答问题", }}
|
||||
```
|
||||
|
||||
按照以下格式进行回应:
|
||||
问题:输入问题以回答
|
||||
行动:
|
||||
```
|
||||
$JSON_BLOB
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
RECOGNIZE_INTENTION_PROMPT = """你是一个任务决策助手,能够将理解用户意图并决策采取最合适的行动,尽可能地以有帮助和准确的方式回应人类,
|
||||
使用 JSON Blob 来指定一个返回的内容,提供一个 action(行动)。
|
||||
有效的 'action' 值为:'planning'(需要先进行拆解计划) or 'only_answer' (不需要拆解问题即可直接回答问题)or "tool_using" (使用工具来回答问题) or 'coding'(生成可执行的代码)。
|
||||
在每个 $JSON_BLOB 中仅提供一个 action,如下所示:
|
||||
```
|
||||
{{'action': $ACTION}}
|
||||
```
|
||||
按照以下格式进行回应:
|
||||
问题:输入问题以回答
|
||||
行动:$ACTION
|
||||
```
|
||||
$JSON_BLOB
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
CHECKER_PROMPT = """尽可能地以有帮助和准确的方式回应人类,判断问题是否得到解答,同时展现解答的过程和内容
|
||||
使用 JSON Blob 来指定一个返回的内容,提供一个 action(行动)。
|
||||
有效的 'action' 值为:'finished'(任务已经可以通过“背景信息”和“对话信息”回答问题) or 'continue' (“背景信息”和“对话信息”不足以回答问题)。
|
||||
在每个 $JSON_BLOB 中仅提供一个 action,如下所示:
|
||||
```
|
||||
{{'action': $ACTION, 'content': '提取“背景信息”和“对话信息”中信息来回答问题'}}
|
||||
```
|
||||
按照以下格式进行回应:
|
||||
问题:输入问题以回答
|
||||
行动:$ACTION
|
||||
```
|
||||
$JSON_BLOB
|
||||
```
|
||||
"""
|
||||
|
||||
CONV_SUMMARY_PROMPT = """尽可能地以有帮助和准确的方式回应人类,根据“背景信息”中的有效信息回答问题,
|
||||
使用 JSON Blob 来指定一个返回的内容,提供一个 action(行动)。
|
||||
有效的 'action' 值为:'finished'(任务已经可以通过上下文信息可以回答) or 'continue' (根据背景信息回答问题)。
|
||||
在每个 $JSON_BLOB 中仅提供一个 action,如下所示:
|
||||
```
|
||||
{{'action': $ACTION, 'content': '根据背景信息回答问题'}}
|
||||
```
|
||||
按照以下格式进行回应:
|
||||
问题:输入问题以回答
|
||||
行动:
|
||||
```
|
||||
$JSON_BLOB
|
||||
```
|
||||
"""
|
||||
|
||||
CONV_SUMMARY_PROMPT = """尽可能地以有帮助和准确的方式回应人类
|
||||
根据“背景信息”中的有效信息回答问题,同时展现解答的过程和内容
|
||||
若能根“背景信息”回答问题,则直接回答
|
||||
否则,总结“背景信息”的内容
|
||||
"""
|
||||
|
||||
|
||||
|
||||
QA_PROMPT = """根据已知信息,简洁和专业的来回答问题。如果无法从中得到答案,请说 “根据已知信息无法回答该问题”,不允许在答案中添加编造成分,答案请使用中文。
|
||||
使用 JSON Blob 来指定一个返回的内容,提供一个 action(行动)。
|
||||
有效的 'action' 值为:'finished'(任务已经可以通过上下文信息可以回答) or 'continue' (上下文信息不足以回答问题)。
|
||||
在每个 $JSON_BLOB 中仅提供一个 action,如下所示:
|
||||
```
|
||||
{{'action': $ACTION, 'content': '总结对话内容'}}
|
||||
```
|
||||
按照以下格式进行回应:
|
||||
问题:输入问题以回答
|
||||
行动:$ACTION
|
||||
```
|
||||
$JSON_BLOB
|
||||
```
|
||||
"""
|
||||
|
||||
CODE_QA_PROMPT = """【指令】根据已知信息来回答问"""
|
||||
|
||||
|
||||
AGETN_CONFIGS = {
|
||||
"checker": {
|
||||
"role": {
|
||||
"role_prompt": CHECKER_PROMPT,
|
||||
"role_type": "ai",
|
||||
"role_name": "checker",
|
||||
"role_desc": "",
|
||||
"agent_type": "BaseAgent"
|
||||
},
|
||||
"chat_turn": 1,
|
||||
"do_search": False,
|
||||
"do_doc_retrieval": False,
|
||||
"do_tool_retrieval": False
|
||||
},
|
||||
"conv_summary": {
|
||||
"role": {
|
||||
"role_prompt": CONV_SUMMARY_PROMPT,
|
||||
"role_type": "ai",
|
||||
"role_name": "conv_summary",
|
||||
"role_desc": "",
|
||||
"agent_type": "BaseAgent"
|
||||
},
|
||||
"chat_turn": 1,
|
||||
"do_search": False,
|
||||
"do_doc_retrieval": False,
|
||||
"do_tool_retrieval": False
|
||||
},
|
||||
"general_planner": {
|
||||
"role": {
|
||||
"role_prompt": GENERAL_PLANNER_PROMPT,
|
||||
"role_type": "ai",
|
||||
"role_name": "general_planner",
|
||||
"role_desc": "",
|
||||
"agent_type": "BaseAgent"
|
||||
},
|
||||
"chat_turn": 1,
|
||||
"do_search": False,
|
||||
"do_doc_retrieval": False,
|
||||
"do_tool_retrieval": False
|
||||
},
|
||||
"planner": {
|
||||
"role": {
|
||||
"role_prompt": DATA_PLANNER_PROMPT,
|
||||
"role_type": "ai",
|
||||
"role_name": "planner",
|
||||
"role_desc": "",
|
||||
"agent_type": "BaseAgent"
|
||||
},
|
||||
"chat_turn": 1,
|
||||
"do_search": False,
|
||||
"do_doc_retrieval": False,
|
||||
"do_tool_retrieval": False
|
||||
},
|
||||
"intention_recognizer": {
|
||||
"role": {
|
||||
"role_prompt": RECOGNIZE_INTENTION_PROMPT,
|
||||
"role_type": "ai",
|
||||
"role_name": "intention_recognizer",
|
||||
"role_desc": "",
|
||||
"agent_type": "BaseAgent"
|
||||
},
|
||||
"chat_turn": 1,
|
||||
"do_search": False,
|
||||
"do_doc_retrieval": False,
|
||||
"do_tool_retrieval": False
|
||||
},
|
||||
"tool_planner": {
|
||||
"role": {
|
||||
"role_prompt": TOOL_PLANNER_PROMPT,
|
||||
"role_type": "ai",
|
||||
"role_name": "tool_planner",
|
||||
"role_desc": "",
|
||||
"agent_type": "BaseAgent"
|
||||
},
|
||||
"chat_turn": 1,
|
||||
"do_search": False,
|
||||
"do_doc_retrieval": False,
|
||||
"do_tool_retrieval": False
|
||||
},
|
||||
"tool_react": {
|
||||
"role": {
|
||||
"role_prompt": REACT_TOOL_PROMPT,
|
||||
"role_type": "ai",
|
||||
"role_name": "tool_react",
|
||||
"role_desc": "",
|
||||
"agent_type": "ReactAgent"
|
||||
},
|
||||
"chat_turn": 5,
|
||||
"do_search": False,
|
||||
"do_doc_retrieval": False,
|
||||
"do_tool_retrieval": False,
|
||||
"stop": "观察"
|
||||
},
|
||||
"code_react": {
|
||||
"role": {
|
||||
"role_prompt": REACT_CODE_PROMPT,
|
||||
"role_type": "ai",
|
||||
"role_name": "code_react",
|
||||
"role_desc": "",
|
||||
"agent_type": "ReactAgent"
|
||||
},
|
||||
"chat_turn": 5,
|
||||
"do_search": False,
|
||||
"do_doc_retrieval": False,
|
||||
"do_tool_retrieval": False,
|
||||
"stop": "观察"
|
||||
},
|
||||
"qaer": {
|
||||
"role": {
|
||||
"role_prompt": QA_PROMPT,
|
||||
"role_type": "ai",
|
||||
"role_name": "qaer",
|
||||
"role_desc": "",
|
||||
"agent_type": "BaseAgent"
|
||||
},
|
||||
"chat_turn": 1,
|
||||
"do_search": False,
|
||||
"do_doc_retrieval": True,
|
||||
"do_tool_retrieval": False
|
||||
},
|
||||
"code_qaer": {
|
||||
"role": {
|
||||
"role_prompt": CODE_QA_PROMPT ,
|
||||
"role_type": "ai",
|
||||
"role_name": "code_qaer",
|
||||
"role_desc": "",
|
||||
"agent_type": "BaseAgent"
|
||||
},
|
||||
"chat_turn": 1,
|
||||
"do_search": False,
|
||||
"do_doc_retrieval": True,
|
||||
"do_tool_retrieval": False
|
||||
},
|
||||
"searcher": {
|
||||
"role": {
|
||||
"role_prompt": QA_PROMPT,
|
||||
"role_type": "ai",
|
||||
"role_name": "searcher",
|
||||
"role_desc": "",
|
||||
"agent_type": "BaseAgent"
|
||||
},
|
||||
"chat_turn": 1,
|
||||
"do_search": True,
|
||||
"do_doc_retrieval": False,
|
||||
"do_tool_retrieval": False
|
||||
},
|
||||
"answer": {
|
||||
"role": {
|
||||
"role_prompt": "",
|
||||
"role_type": "ai",
|
||||
"role_name": "answer",
|
||||
"role_desc": "",
|
||||
"agent_type": "BaseAgent"
|
||||
},
|
||||
"chat_turn": 1,
|
||||
"do_search": False,
|
||||
"do_doc_retrieval": False,
|
||||
"do_tool_retrieval": False
|
||||
},
|
||||
"data_analyst": {
|
||||
"role": {
|
||||
"role_prompt": """你是一个数据分析的代码开发助手,能够编写可执行的代码来完成相关的数据分析问题,使用 JSON Blob 来指定一个返回的内容,通过提供一个 action(行动)和一个 code (生成代码)和 一个 file_name (指定保存文件)。\
|
||||
有效的 'action' 值为:'coding'(生成可执行的代码) or 'finished' (不生成代码并直接返回答案)。在每个 $JSON_BLOB 中仅提供一个 action,如下所示:\
|
||||
```\n{{'action': $ACTION,'code_content': $CODE, 'code_filename': $FILE_NAME}}```\
|
||||
下面开始!记住根据问题进行返回需要生成的答案,格式为 ```JSON_BLOB```""",
|
||||
"role_type": "ai",
|
||||
"role_name": "data_analyst",
|
||||
"role_desc": "",
|
||||
"agent_type": "BaseAgent"
|
||||
},
|
||||
"chat_turn": 1,
|
||||
"do_search": False,
|
||||
"do_doc_retrieval": False,
|
||||
"do_tool_retrieval": False
|
||||
},
|
||||
"deveploer": {
|
||||
"role": {
|
||||
"role_prompt": """你是一个代码开发助手,能够编写可执行的代码来完成问题,使用 JSON Blob 来指定一个返回的内容,通过提供一个 action(行动)和一个 code (生成代码)和 一个 file_name (指定保存文件)。\
|
||||
有效的 'action' 值为:'coding'(生成可执行的代码) or 'finished' (不生成代码并直接返回答案)。在每个 $JSON_BLOB 中仅提供一个 action,如下所示:\
|
||||
```\n{{'action': $ACTION,'code_content': $CODE, 'code_filename': $FILE_NAME}}```\
|
||||
下面开始!记住根据问题进行返回需要生成的答案,格式为 ```JSON_BLOB```""",
|
||||
"role_type": "ai",
|
||||
"role_name": "deveploer",
|
||||
"role_desc": "",
|
||||
"agent_type": "BaseAgent"
|
||||
},
|
||||
"chat_turn": 1,
|
||||
"do_search": False,
|
||||
"do_doc_retrieval": False,
|
||||
"do_tool_retrieval": False
|
||||
},
|
||||
"tester": {
|
||||
"role": {
|
||||
"role_prompt": "你是一个QA问答的助手,能够尽可能准确地回答问题,下面请逐步思考问题并回答",
|
||||
"role_type": "ai",
|
||||
"role_name": "tester",
|
||||
"role_desc": "",
|
||||
"agent_type": "BaseAgent"
|
||||
},
|
||||
"chat_turn": 1,
|
||||
"do_search": False,
|
||||
"do_doc_retrieval": False,
|
||||
"do_tool_retrieval": False
|
||||
}
|
||||
}
|
||||
88
dev_opsgpt/connector/configs/chain_config.py
Normal file
88
dev_opsgpt/connector/configs/chain_config.py
Normal file
@ -0,0 +1,88 @@
|
||||
|
||||
|
||||
CHAIN_CONFIGS = {
|
||||
"chatChain": {
|
||||
"chain_name": "chatChain",
|
||||
"chain_type": "BaseChain",
|
||||
"agents": ["answer"],
|
||||
"chat_turn": 1,
|
||||
"do_checker": False,
|
||||
"clear_structure": "True",
|
||||
"brainstorming": "False",
|
||||
"gui_design": "True",
|
||||
"git_management": "False",
|
||||
"self_improve": "False"
|
||||
},
|
||||
"docChatChain": {
|
||||
"chain_name": "docChatChain",
|
||||
"chain_type": "BaseChain",
|
||||
"agents": ["qaer"],
|
||||
"chat_turn": 1,
|
||||
"do_checker": False,
|
||||
"clear_structure": "True",
|
||||
"brainstorming": "False",
|
||||
"gui_design": "True",
|
||||
"git_management": "False",
|
||||
"self_improve": "False"
|
||||
},
|
||||
"searchChatChain": {
|
||||
"chain_name": "searchChatChain",
|
||||
"chain_type": "BaseChain",
|
||||
"agents": ["searcher"],
|
||||
"chat_turn": 1,
|
||||
"do_checker": False,
|
||||
"clear_structure": "True",
|
||||
"brainstorming": "False",
|
||||
"gui_design": "True",
|
||||
"git_management": "False",
|
||||
"self_improve": "False"
|
||||
},
|
||||
"codeChatChain": {
|
||||
"chain_name": "codehChatChain",
|
||||
"chain_type": "BaseChain",
|
||||
"agents": ["code_qaer"],
|
||||
"chat_turn": 1,
|
||||
"do_checker": False,
|
||||
"clear_structure": "True",
|
||||
"brainstorming": "False",
|
||||
"gui_design": "True",
|
||||
"git_management": "False",
|
||||
"self_improve": "False"
|
||||
},
|
||||
"toolReactChain": {
|
||||
"chain_name": "toolReactChain",
|
||||
"chain_type": "BaseChain",
|
||||
"agents": ["tool_planner", "tool_react"],
|
||||
"chat_turn": 2,
|
||||
"do_checker": True,
|
||||
"clear_structure": "True",
|
||||
"brainstorming": "False",
|
||||
"gui_design": "True",
|
||||
"git_management": "False",
|
||||
"self_improve": "False"
|
||||
},
|
||||
"codeReactChain": {
|
||||
"chain_name": "codeReactChain",
|
||||
"chain_type": "BaseChain",
|
||||
"agents": ["planner", "code_react"],
|
||||
"chat_turn": 2,
|
||||
"do_checker": True,
|
||||
"clear_structure": "True",
|
||||
"brainstorming": "False",
|
||||
"gui_design": "True",
|
||||
"git_management": "False",
|
||||
"self_improve": "False"
|
||||
},
|
||||
"dataAnalystChain": {
|
||||
"chain_name": "dataAnalystChain",
|
||||
"chain_type": "BaseChain",
|
||||
"agents": ["planner", "code_react"],
|
||||
"chat_turn": 2,
|
||||
"do_checker": True,
|
||||
"clear_structure": "True",
|
||||
"brainstorming": "False",
|
||||
"gui_design": "True",
|
||||
"git_management": "False",
|
||||
"self_improve": "False"
|
||||
},
|
||||
}
|
||||
79
dev_opsgpt/connector/configs/phase_config.py
Normal file
79
dev_opsgpt/connector/configs/phase_config.py
Normal file
@ -0,0 +1,79 @@
|
||||
PHASE_CONFIGS = {
|
||||
"chatPhase": {
|
||||
"phase_name": "chatPhase",
|
||||
"phase_type": "BasePhase",
|
||||
"chains": ["chatChain"],
|
||||
"do_summary": False,
|
||||
"do_search": False,
|
||||
"do_doc_retrieval": False,
|
||||
"do_code_retrieval": False,
|
||||
"do_tool_retrieval": False,
|
||||
"do_using_tool": False
|
||||
},
|
||||
"docChatPhase": {
|
||||
"phase_name": "docChatPhase",
|
||||
"phase_type": "BasePhase",
|
||||
"chains": ["docChatChain"],
|
||||
"do_summary": False,
|
||||
"do_search": False,
|
||||
"do_doc_retrieval": True,
|
||||
"do_code_retrieval": False,
|
||||
"do_tool_retrieval": False,
|
||||
"do_using_tool": False
|
||||
},
|
||||
"searchChatPhase": {
|
||||
"phase_name": "searchChatPhase",
|
||||
"phase_type": "BasePhase",
|
||||
"chains": ["searchChatChain"],
|
||||
"do_summary": False,
|
||||
"do_search": True,
|
||||
"do_doc_retrieval": False,
|
||||
"do_code_retrieval": False,
|
||||
"do_tool_retrieval": False,
|
||||
"do_using_tool": False
|
||||
},
|
||||
"codeChatPhase": {
|
||||
"phase_name": "codeChatPhase",
|
||||
"phase_type": "BasePhase",
|
||||
"chains": ["codeChatChain"],
|
||||
"do_summary": False,
|
||||
"do_search": False,
|
||||
"do_doc_retrieval": False,
|
||||
"do_code_retrieval": True,
|
||||
"do_tool_retrieval": False,
|
||||
"do_using_tool": False
|
||||
},
|
||||
"toolReactPhase": {
|
||||
"phase_name": "toolReactPhase",
|
||||
"phase_type": "BasePhase",
|
||||
"chains": ["toolReactChain"],
|
||||
"do_summary": False,
|
||||
"do_search": False,
|
||||
"do_doc_retrieval": False,
|
||||
"do_code_retrieval": False,
|
||||
"do_tool_retrieval": False,
|
||||
"do_using_tool": True
|
||||
},
|
||||
"codeReactPhase": {
|
||||
"phase_name": "codeReacttPhase",
|
||||
"phase_type": "BasePhase",
|
||||
"chains": ["codeReactChain"],
|
||||
"do_summary": False,
|
||||
"do_search": False,
|
||||
"do_doc_retrieval": False,
|
||||
"do_code_retrieval": False,
|
||||
"do_tool_retrieval": False,
|
||||
"do_using_tool": False
|
||||
},
|
||||
"dataReactPhase": {
|
||||
"phase_name": "dataReactPhase",
|
||||
"phase_type": "BasePhase",
|
||||
"chains": ["dataAnalystChain"],
|
||||
"do_summary": True,
|
||||
"do_search": False,
|
||||
"do_doc_retrieval": False,
|
||||
"do_code_retrieval": False,
|
||||
"do_tool_retrieval": False,
|
||||
"do_using_tool": False
|
||||
}
|
||||
}
|
||||
248
dev_opsgpt/connector/connector_schema.py
Normal file
248
dev_opsgpt/connector/connector_schema.py
Normal file
@ -0,0 +1,248 @@
|
||||
from pydantic import BaseModel
|
||||
from typing import List, Dict
|
||||
from enum import Enum
|
||||
import re
|
||||
import json
|
||||
from loguru import logger
|
||||
from langchain.tools import BaseTool
|
||||
|
||||
|
||||
class ActionStatus(Enum):
    """Action states a message can carry.

    Members compare equal to their plain string value, so
    ``ActionStatus.FINISHED == "finished"`` is true, and hash like that
    string so members and strings are interchangeable in sets and as dict
    keys.
    """

    FINISHED = "finished"
    CODING = "coding"
    TOOL_USING = "tool_using"
    REASONING = "reasoning"
    PLANNING = "planning"
    EXECUTING_CODE = "executing_code"
    EXECUTING_TOOL = "executing_tool"
    DEFAUILT = "default"
    # Correctly spelled alias: members sharing a value are the same member,
    # so existing uses of DEFAUILT keep working.
    DEFAULT = "default"

    def __eq__(self, other):
        """Compare by value against strings, by identity otherwise."""
        if isinstance(other, str):
            return self.value == other
        return super().__eq__(other)

    def __hash__(self):
        # Overriding __eq__ drops the inherited __hash__; restore it and
        # keep it consistent with string equality.
        return hash(self.value)
|
||||
|
||||
class Doc(BaseModel):
    """A retrieved document snippet with its title, link and position."""

    title: str
    snippet: str
    link: str
    index: int

    def get_title(self):
        """Return the title."""
        return self.title

    def get_snippet(self):
        """Return the snippet text."""
        return self.snippet

    def get_link(self):
        """Return the source link."""
        return self.link

    def get_index(self):
        """Return the index."""
        return self.index

    def to_json(self):
        """Return the instance attribute dictionary."""
        return self.__dict__

    def __str__(self):
        # The displayed position is index + 1.
        position = self.index + 1
        return f"""出处 [{position}] 标题 [{self.title}]\n\n来源 ({self.link}) \n\n内容 {self.snippet}\n\n"""
|
||||
|
||||
|
||||
class CodeDoc(BaseModel):
    """A retrieved code snippet with its related nodes and position."""

    code: str
    related_nodes: list
    index: int

    def get_code(self):
        """Return the code text."""
        return self.code

    def get_related_node(self):
        """Return the list of related nodes."""
        return self.related_nodes

    def get_index(self):
        """Return the index."""
        return self.index

    def to_json(self):
        """Return the instance attribute dictionary."""
        return self.__dict__

    def __str__(self):
        # The displayed position is index + 1.
        position = self.index + 1
        return f"""出处 [{position}] \n\n来源 ({self.related_nodes}) \n\n内容 {self.code}\n\n"""
|
||||
|
||||
|
||||
class Docs:
    """Column-wise view over a list of ``Doc`` objects."""

    def __init__(self, docs: List[Doc]):
        """Split *docs* into parallel lists of titles, snippets, links and indexes."""
        self.titles: List[str] = [item.get_title() for item in docs]
        self.snippets: List[str] = [item.get_snippet() for item in docs]
        self.links: List[str] = [item.get_link() for item in docs]
        self.indexs: List[int] = [item.get_index() for item in docs]
|
||||
|
||||
class Task(BaseModel):
    """Description of a task: its type, name, description and prompt."""

    task_type: str
    task_name: str
    task_desc: str
    task_prompt: str
|
||||
|
||||
class Env(BaseModel):
    """Description of an environment: its type, name and description."""

    env_type: str
    env_name: str
    env_desc: str
|
||||
|
||||
|
||||
class Role(BaseModel):
    """Identity and prompts of an agent role."""

    role_type: str
    role_name: str
    role_desc: str
    # The three settings below default to the empty string.
    agent_type: str = ""
    role_prompt: str = ""
    template_prompt: str = ""
|
||||
|
||||
|
||||
|
||||
class ChainConfig(BaseModel):
    """Settings of one chain: its name, type, agent names and flags."""

    chain_name: str
    chain_type: str
    # Names of the agents that make up the chain.
    agents: List[str]
    do_checker: bool = False
    chat_turn: int = 1
    clear_structure: bool = False
    brainstorming: bool = False
    gui_design: bool = True
    git_management: bool = False
    self_improve: bool = False
|
||||
|
||||
|
||||
class AgentConfig(BaseModel):
    """Settings of one agent: its role plus turn count and retrieval flags."""

    # NOTE(review): the react agent configs also carry a "stop" key that is
    # not declared here; presumably it is dropped on load -- confirm.
    role: Role
    chat_turn: int = 1
    do_search: bool = False
    do_doc_retrieval: bool = False
    do_tool_retrieval: bool = False
|
||||
|
||||
|
||||
class PhaseConfig(BaseModel):
    """Settings of one phase: its name, type, chain names and flags."""

    # NOTE(review): the phase configs also carry a "do_using_tool" key that
    # is not declared here; presumably it is dropped on load -- confirm.
    phase_name: str
    phase_type: str
    # Names of the chains that make up the phase.
    chains: List[str]
    do_summary: bool = False
    do_search: bool = False
    do_doc_retrieval: bool = False
    do_code_retrieval: bool = False
    do_tool_retrieval: bool = False
|
||||
|
||||
class Message(BaseModel):
    """A message passed between agents, with parsed action data and settings."""

    role_name: str
    role_type: str
    role_prompt: str = None
    input_query: str = None

    # Final content returned by the model
    role_content: str = None
    role_contents: List[str] = []
    step_content: str = None
    step_contents: List[str] = []
    chain_content: str = None
    chain_contents: List[str] = []

    # Fields parsed from the model output
    plans: List[str] = None
    code_content: str = None
    code_filename: str = None
    tool_params: str = None
    tool_name: str = None

    # Execution results
    action_status: str = ActionStatus.DEFAUILT
    code_answer: str = None
    tool_answer: str = None
    observation: str = None
    figures: Dict[str, str] = {}

    # Auxiliary information
    tools: List[BaseTool] = []
    task: Task = None
    db_docs: List['Doc'] = []
    code_docs: List['CodeDoc'] = []
    search_docs: List['Doc'] = []

    # Execution inputs
    phase_name: str = None
    chain_name: str = None
    do_search: bool = False
    doc_engine_name: str = None
    code_engine_name: str = None
    search_engine_name: str = None
    top_k: int = 3
    score_threshold: float = 1.0
    do_doc_retrieval: bool = False
    do_code_retrieval: bool = False
    do_tool_retrieval: bool = False
    history_node_list: List[str] = []

    def to_tuple_message(self, return_all: bool = False, content_key="role_content"):
        """Return the message as a tuple.

        With ``content_key="step_content"`` the content is ``step_content``,
        falling back to ``role_content`` when that is empty; any other key
        selects ``role_content``. The tuple is
        ``(role_name, role_type, content)`` when ``return_all`` is true and
        ``(role_name, content)`` otherwise.
        """
        if content_key == "step_content":
            role_content = self.step_content or self.role_content
        else:
            role_content = self.role_content

        if return_all:
            return (self.role_name, self.role_type, role_content)
        return (self.role_name, role_content)

    def to_dict_message(self, return_all: bool = False, content_key="role_content"):
        """Return the message as a dict.

        With ``return_all`` the instance attribute dictionary is returned.
        Otherwise the result is ``{"role": role_name, "content": content}``,
        with the content chosen as in ``to_tuple_message``.
        """
        if return_all:
            return vars(self)

        if content_key == "step_content":
            role_content = self.step_content or self.role_content
        else:
            role_content = self.role_content
        return {"role": self.role_name, "content": role_content}

    def is_system_role(self,):
        """Return whether ``role_type`` equals ``"system"``."""
        return self.role_type == "system"

    def __str__(self) -> str:
        # One "name: value" line per instance attribute.
        return "\n".join(f"{k}: {v!s}" for k, v in vars(self).items())
|
||||
|
||||
|
||||
|
||||
def load_role_configs(config) -> Dict[str, AgentConfig]:
    """Build AgentConfig objects from a mapping or from a JSON file.

    Args:
        config: either a path (str) to a UTF-8 JSON file, or an already
            loaded mapping of name -> keyword arguments for AgentConfig.

    Returns:
        Dict mapping each name to its AgentConfig.
    """
    raw = config
    if isinstance(config, str):
        # a string is treated as a file path
        with open(config, 'r', encoding="utf8") as fh:
            raw = json.load(fh)

    parsed = {}
    for role_key, fields in raw.items():
        parsed[role_key] = AgentConfig(**fields)
    return parsed
|
||||
|
||||
|
||||
def load_chain_configs(config) -> Dict[str, ChainConfig]:
    """Build ChainConfig objects from a mapping or from a JSON file.

    Args:
        config: either a path (str) to a UTF-8 JSON file, or an already
            loaded mapping of name -> keyword arguments for ChainConfig.

    Returns:
        Dict mapping each name to its ChainConfig.
    """
    raw = config
    if isinstance(config, str):
        # a string is treated as a file path
        with open(config, 'r', encoding="utf8") as fh:
            raw = json.load(fh)

    parsed = {}
    for chain_key, fields in raw.items():
        parsed[chain_key] = ChainConfig(**fields)
    return parsed
|
||||
|
||||
|
||||
def load_phase_configs(config) -> Dict[str, PhaseConfig]:
    """Build PhaseConfig objects from a mapping or from a JSON file.

    Args:
        config: either a path (str) to a UTF-8 JSON file, or an already
            loaded mapping of name -> keyword arguments for PhaseConfig.

    Returns:
        Dict mapping each name to its PhaseConfig.
    """
    raw = config
    if isinstance(config, str):
        # a string is treated as a file path
        with open(config, 'r', encoding="utf8") as fh:
            raw = json.load(fh)

    parsed = {}
    for phase_key, fields in raw.items():
        parsed[phase_key] = PhaseConfig(**fields)
    return parsed
|
||||
3
dev_opsgpt/connector/phase/__init__.py
Normal file
3
dev_opsgpt/connector/phase/__init__.py
Normal file
@ -0,0 +1,3 @@
|
||||
from .base_phase import BasePhase
|
||||
|
||||
__all__ = ["BasePhase"]
|
||||
215
dev_opsgpt/connector/phase/base_phase.py
Normal file
215
dev_opsgpt/connector/phase/base_phase.py
Normal file
@ -0,0 +1,215 @@
|
||||
from typing import List, Union, Dict, Tuple
|
||||
import os
|
||||
import json
|
||||
import importlib
|
||||
import copy
|
||||
from loguru import logger
|
||||
|
||||
from dev_opsgpt.connector.agents import BaseAgent
|
||||
from dev_opsgpt.connector.chains import BaseChain
|
||||
from dev_opsgpt.tools.base_tool import BaseTools, Tool
|
||||
from dev_opsgpt.connector.shcema.memory import Memory
|
||||
from dev_opsgpt.connector.connector_schema import (
|
||||
Task, Env, Role, Message, Doc, Docs, AgentConfig, ChainConfig, PhaseConfig, CodeDoc,
|
||||
load_chain_configs, load_phase_configs, load_role_configs
|
||||
)
|
||||
from dev_opsgpt.connector.configs import AGETN_CONFIGS, CHAIN_CONFIGS, PHASE_CONFIGS
|
||||
from dev_opsgpt.tools import DDGSTool, DocRetrieval, CodeRetrieval
|
||||
|
||||
|
||||
# Default role / chain / phase configurations, parsed once when the module is imported.
role_configs = load_role_configs(AGETN_CONFIGS)
chain_configs = load_chain_configs(CHAIN_CONFIGS)
phase_configs = load_phase_configs(PHASE_CONFIGS)


# Absolute path of the directory containing this module.
CUR_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||
|
||||
|
||||
class BasePhase:
|
||||
|
||||
def __init__(
        self,
        phase_name: str,
        task: Task = None,
        do_summary: bool = False,
        do_search: bool = False,
        do_doc_retrieval: bool = False,
        do_code_retrieval: bool = False,
        do_tool_retrieval: bool = False,
        phase_config: Union[dict, str] = PHASE_CONFIGS,
        chain_config: Union[dict, str] = CHAIN_CONFIGS,
        role_config: Union[dict, str] = AGETN_CONFIGS,
        ) -> None:
    """Create the summary agent and the chains of the named phase.

    Args:
        phase_name: key of the phase inside `phase_config`.
        task: task object handed to every agent of the phase.
        do_summary: run the conversation-summary agent in `step`.
        do_search / do_doc_retrieval / do_code_retrieval / do_tool_retrieval:
            feature switches stored on the instance.
        phase_config, chain_config, role_config: mapping or JSON path,
            forwarded to `init_chains`.
    """
    # The summary agent always comes from the module-level default
    # `role_configs`, not from the `role_config` argument.
    summary_cfg = role_configs["conv_summary"]
    self.conv_summary_agent = BaseAgent(
        role=summary_cfg.role,
        task=None,
        memory=None,
        do_search=summary_cfg.do_search,
        do_doc_retrieval=summary_cfg.do_doc_retrieval,
        do_tool_retrieval=summary_cfg.do_tool_retrieval,
        do_filter=False,
        do_use_self_memory=False,
    )

    self.chains: List[BaseChain] = self.init_chains(
        phase_name,
        task=task,
        memory=None,
        phase_config=phase_config,
        chain_config=chain_config,
        role_config=role_config,
    )

    self.phase_name = phase_name
    self.do_summary = do_summary
    self.do_search = do_search
    self.do_code_retrieval = do_code_retrieval
    self.do_doc_retrieval = do_doc_retrieval
    self.do_tool_retrieval = do_tool_retrieval

    # every input/output message seen by `step`
    self.global_message = Memory([])
    # one Memory per chain, appended by `step`
    self.phase_memory: List[Memory] = []
|
||||
|
||||
def step(self, query: Message, history: Memory = None) -> Tuple[Message, Memory]:
    """Run every chain of the phase once and return the final message.

    Args:
        query: incoming message; it is first enriched by `get_extrainfo_step`.
        history: passed unchanged to each chain's `step`.

    Returns:
        (message, local_memory): the summary message when one was produced,
        otherwise the last chain output (its `role_name` is overwritten with
        the phase name), plus the memory collected from the chains in this call.
    """
    # NOTE(review): indentation was lost in the extracted source; the summary
    # block below is assumed to sit AFTER the chain loop - confirm against the
    # original file.
    summary_message = None
    chain_message = Memory([])
    local_memory = Memory([])
    # do_search, do_doc_search, do_code_search
    query = self.get_extrainfo_step(query)
    input_message = copy.deepcopy(query)

    self.global_message.append(input_message)
    for chain in self.chains:
        # chain can supply background and query to next chain
        output_message, chain_memory = chain.step(input_message, history, background=chain_message)
        output_message = self.inherit_extrainfo(input_message, output_message)
        input_message = output_message
        logger.info(f"{chain.chainConfig.chain_name} phase_step: {output_message.role_content}")

        self.global_message.append(output_message)
        local_memory.extend(chain_memory)

    # whether use summary_llm
    if self.do_summary:
        logger.info(f"{self.conv_summary_agent.role.role_name} input global memory: {self.global_message.to_str_messages(content_key='step_content')}")
        logger.info(f"{self.conv_summary_agent.role.role_name} input global memory: {self.global_message.to_str_messages(content_key='role_content')}")
        summary_message = self.conv_summary_agent.run(query, background=self.global_message)
        # `chain` is whatever the loop above last bound
        summary_message.role_name = chain.chainConfig.chain_name
        summary_message = self.conv_summary_agent.parser(summary_message)
        summary_message = self.conv_summary_agent.filter(summary_message)
        summary_message = self.inherit_extrainfo(output_message, summary_message)
        chain_message.append(summary_message)

    # chains are never executed for multiple rounds, so their memory is kept as-is
    for chain in self.chains:
        self.phase_memory.append(chain.global_memory)

    message = summary_message or output_message
    message.role_name = self.phase_name
    # message.db_docs = query.db_docs
    # message.code_docs = query.code_docs
    # message.search_docs = query.search_docs
    return summary_message or output_message, local_memory
|
||||
|
||||
def init_chains(self, phase_name, phase_config, chain_config,
        role_config, task=None, memory=None) -> List[BaseChain]:
    """Instantiate the chains (and their agents) configured for a phase.

    Args:
        phase_name: key of the phase inside `phase_config`.
        phase_config, chain_config, role_config: mapping or JSON file path,
            parsed with the matching `load_*_configs` helper.
        task: forwarded to every agent.
        memory: forwarded to every agent.

    Returns:
        One BaseChain per chain name listed in the phase, in order.

    Raises:
        ValueError: if `phase_name` is not present in the phase configs
            (previously an AttributeError on None).
        KeyError: if a chain or agent name is missing from its config.
    """
    # load config
    role_configs = load_role_configs(role_config)
    chain_configs = load_chain_configs(chain_config)
    phase_configs = load_phase_configs(phase_config)

    chains = []
    self.chain_module = importlib.import_module("dev_opsgpt.connector.chains")
    self.agent_module = importlib.import_module("dev_opsgpt.connector.agents")

    phase = phase_configs.get(phase_name)
    if phase is None:
        raise ValueError(
            f"unknown phase_name {phase_name!r}; "
            f"available phases: {sorted(phase_configs)}"
        )

    for chain_name in phase.chains:
        logger.info(f"chain_name: {chain_name}")
        logger.debug(f"{chain_configs.keys()}")
        # separate local name so the `chain_config` parameter is not shadowed
        chain_cfg = chain_configs[chain_name]

        agents = []
        for agent_name in chain_cfg.agents:
            agent_cfg = role_configs[agent_name]
            # the agent class is looked up by name in the agents package
            agent_cls = getattr(self.agent_module, agent_cfg.role.agent_type)
            agents.append(agent_cls(
                agent_cfg.role,
                task=task,
                memory=memory,
                chat_turn=agent_cfg.chat_turn,
                do_search=agent_cfg.do_search,
                do_doc_retrieval=agent_cfg.do_doc_retrieval,
                do_tool_retrieval=agent_cfg.do_tool_retrieval,
            ))

        chains.append(BaseChain(
            chain_cfg, agents, chain_cfg.chat_turn,
            do_checker=chain_cfg.do_checker,
            do_code_exec=False,
        ))

    return chains
|
||||
|
||||
def get_extrainfo_step(self, input_message):
    """Attach retrieval results to the message according to the phase flags.

    Document retrieval, code retrieval and web search run in that order,
    each only when its `do_*` flag is set. Returns the enriched message.
    """
    enriched = input_message
    if self.do_doc_retrieval:
        enriched = self.get_doc_retrieval(enriched)

    logger.debug(F"self.do_code_retrieval: {self.do_code_retrieval}")
    if self.do_code_retrieval:
        enriched = self.get_code_retrieval(enriched)

    if self.do_search:
        enriched = self.get_search_retrieval(enriched)

    return enriched
|
||||
|
||||
def inherit_extrainfo(self, input_message: Message, output_message: Message):
    """Copy retrieved documents and figures from one message to another.

    `db_docs`, `search_docs` and `code_docs` of `output_message` are replaced
    by those of `input_message`; `figures` is merged with `dict.update`.
    Returns `output_message`.
    """
    for doc_attr in ("db_docs", "search_docs", "code_docs"):
        setattr(output_message, doc_attr, getattr(input_message, doc_attr))
    output_message.figures.update(input_message.figures)
    return output_message
|
||||
|
||||
def get_search_retrieval(self, message: Message,) -> Message:
    """Run a DuckDuckGo search on `role_content` and store the hits.

    Each hit gets an "index" key with its position, is converted to a Doc,
    and the list is written to `message.search_docs`.
    """
    SEARCH_ENGINES = {"duckduckgo": DDGSTool}
    engine = SEARCH_ENGINES["duckduckgo"]

    found = []
    for position, hit in enumerate(engine.run(message.role_content, 3)):
        hit.update({"index": position})
        found.append(Doc(**hit))

    message.search_docs = found
    return message
|
||||
|
||||
def get_doc_retrieval(self, message: Message) -> Message:
    """Query the knowledge base named by `doc_engine_name`, if any.

    Uses `role_content` as the query together with the message's `top_k`
    and `score_threshold`; results are stored as Doc objects in
    `message.db_docs`. The message is returned untouched when no knowledge
    base name is set.
    """
    kb_name = message.doc_engine_name
    if not kb_name:
        return message

    hits = DocRetrieval.run(
        message.role_content, kb_name, message.top_k, message.score_threshold
    )
    message.db_docs = [Doc(**hit) for hit in hits]
    return message
|
||||
|
||||
def get_code_retrieval(self, message: Message) -> Message:
    """Query the code engine with `input_query` and store the hits.

    `top_k` is passed as `code_limit`, and `history_node_list` is forwarded;
    results are stored as CodeDoc objects in `message.code_docs`.
    """
    hits = CodeRetrieval.run(
        message.code_engine_name,
        message.input_query,
        code_limit=message.top_k,
        history_node_list=message.history_node_list,
    )
    message.code_docs = [CodeDoc(**hit) for hit in hits]
    return message
|
||||
|
||||
def get_tool_retrieval(self, message: Message) -> Message:
    """Placeholder: returns the message unchanged (no tool retrieval is done)."""
    return message
|
||||
|
||||
def update(self) -> Memory:
    """Not implemented: does nothing and returns None despite the annotation."""
    pass
|
||||
|
||||
def get_memory(self, ) -> Memory:
    """Return one Memory holding the messages of every chain, in chain order."""
    per_chain = []
    for chain in self.chains:
        per_chain.append(chain.get_memory())
    return Memory.from_memory_list(per_chain)
|
||||
|
||||
def get_memory_str(self, do_all_memory=True, content_key="role_content") -> str:
    """Render stored messages as "name: content" lines.

    Args:
        do_all_memory: True renders `global_message`; False renders the
            per-chain memories collected in `phase_memory`.
        content_key: forwarded to `to_tuple_messages`.
    """
    if do_all_memory:
        memory = self.global_message
    else:
        # `phase_memory` is a list of Memory objects, which has no
        # `to_tuple_messages`; flatten it into a single Memory first.
        memory = Memory.from_memory_list(self.phase_memory)
    return "\n".join([": ".join(i) for i in memory.to_tuple_messages(content_key=content_key)])
|
||||
|
||||
def get_chains_memory(self, content_key="role_content") -> List[Tuple]:
    """Return, for each Memory in `phase_memory`, its messages as tuples."""
    rendered = []
    for chain_memory in self.phase_memory:
        rendered.append(chain_memory.to_tuple_messages(content_key=content_key))
    return rendered
|
||||
|
||||
def get_chains_memory_str(self, content_key="role_content") -> str:
    """Render each chain as its name followed by its memory string.

    The per-chain sections are joined with a line of twelve asterisks.
    """
    sections = []
    for chain in self.chains:
        header = f"{chain.chainConfig.chain_name}\n"
        sections.append(header + chain.get_memory_str(content_key=content_key))
    return "************".join(sections)
|
||||
6
dev_opsgpt/connector/shcema/__init__.py
Normal file
6
dev_opsgpt/connector/shcema/__init__.py
Normal file
@ -0,0 +1,6 @@
|
||||
from .memory import Memory
|
||||
|
||||
|
||||
__all__ = [
|
||||
"Memory"
|
||||
]
|
||||
88
dev_opsgpt/connector/shcema/memory.py
Normal file
88
dev_opsgpt/connector/shcema/memory.py
Normal file
@ -0,0 +1,88 @@
|
||||
from pydantic import BaseModel
|
||||
from typing import List
|
||||
from loguru import logger
|
||||
|
||||
from dev_opsgpt.connector.connector_schema import Message
|
||||
from dev_opsgpt.utils.common_utils import (
|
||||
save_to_jsonl_file, save_to_json_file, read_json_file, read_jsonl_file
|
||||
)
|
||||
|
||||
|
||||
class Memory:
    """Ordered list of Message objects with save/load and export helpers."""

    def __init__(self, messages: List['Message'] = None):
        # A `[]` default would be shared by every Memory() created without
        # arguments. A caller-supplied list is stored as-is, not copied.
        self.messages = [] if messages is None else messages

    def append(self, message: 'Message'):
        """Add one message at the end."""
        self.messages.append(message)

    def extend(self, memory: 'Memory'):
        """Add all messages of another Memory at the end."""
        self.messages.extend(memory.messages)

    def update(self, role_name: str, role_type: str, role_content: str):
        """Create a Message from the given fields and append it."""
        # Message is a pydantic model and only accepts keyword arguments.
        # NOTE(review): the original passed `role_content` twice positionally;
        # the intended target of the second value is unclear - confirm.
        self.messages.append(Message(
            role_name=role_name, role_type=role_type, role_content=role_content,
        ))

    def clear(self, ):
        """Drop all messages (rebinds `messages` to a new empty list)."""
        self.messages = []

    def delete(self, ):
        """Not implemented."""
        pass

    def get_messages(self, ) -> List['Message']:
        """Return the underlying message list (not a copy)."""
        return self.messages

    def save(self, file_type="jsonl", return_all=True):
        """Write the messages to "role_name_history.<file_type>".

        Args:
            file_type: "jsonl", "json" or "txt"; anything else saves nothing.
            return_all: forwarded to `to_dict_messages`.

        Returns:
            True on success, False on failure or unsupported type.
        """
        target = "role_name_history"+f".{file_type}"
        try:
            if file_type == "jsonl":
                save_to_jsonl_file(self.to_dict_messages(return_all=return_all), target)
                return True
            elif file_type in ["json", "txt"]:
                save_to_json_file(self.to_dict_messages(return_all=return_all), target)
                return True
        except Exception:
            # still best-effort, but no longer swallows KeyboardInterrupt /
            # SystemExit, and the failure is logged
            logger.exception(f"failed to save memory to {target}")
            return False
        return False

    def load(self, filepath):
        """Replace the messages with those read from `filepath`.

        The reader is chosen from the file extension: "jsonl" uses the
        jsonl reader, "json" and "txt" use the json reader.

        Returns:
            True on success, False on failure or unsupported extension.
        """
        # The original compared the whole path with the type names, so no
        # real path ever matched. A bare "jsonl"/"json"/"txt" still works.
        file_type = str(filepath).rsplit(".", 1)[-1].lower()
        try:
            if file_type == "jsonl":
                records = read_jsonl_file(filepath)
            elif file_type in ["json", "txt"]:
                records = read_json_file(filepath)
            else:
                return False
            self.messages = [Message(**record) for record in records]
            return True
        except Exception:
            logger.exception(f"failed to load memory from {filepath}")
            return False

    def to_tuple_messages(self, return_system: bool = False, return_all: bool = False, content_key="role_content"):
        """Return the messages as tuples; system messages only if `return_system`."""
        # `not a | b` parses as `not (a | b)`, which returned nothing when
        # return_system was True.
        return [
            message.to_tuple_message(return_all, content_key) for message in self.messages
            if return_system or not message.is_system_role()
        ]

    def to_dict_messages(self, return_system: bool = False, return_all: bool = False, content_key="role_content"):
        """Return the messages as dicts; system messages only if `return_system`."""
        return [
            message.to_dict_message(return_all, content_key) for message in self.messages
            if return_system or not message.is_system_role()
        ]

    def to_str_messages(self, return_system: bool = False, return_all: bool = False, content_key="role_content"):
        """Return the messages as newline-separated ": "-joined tuple fields."""
        return "\n".join([
            ": ".join(message.to_tuple_message(return_all, content_key)) for message in self.messages
            if return_system or not message.is_system_role()
        ])

    @classmethod
    def from_memory_list(cls, memorys: List['Memory']) -> 'Memory':
        """Build one Memory from the concatenated messages of several."""
        return cls([message for memory in memorys for message in memory.get_messages()])

    def __len__(self, ):
        return len(self.messages)

    def __str__(self) -> str:
        return "\n".join([":".join(i) for i in self.to_tuple_messages()])
|
||||
27
dev_opsgpt/connector/utils.py
Normal file
27
dev_opsgpt/connector/utils.py
Normal file
@ -0,0 +1,27 @@
|
||||
|
||||
|
||||
def prompt_cost(model_type: str, num_prompt_tokens: float, num_completion_tokens: float):
    """Return the cost of a request for a known model.

    Args:
        model_type: model name; must be one of the keys of the price table.
        num_prompt_tokens: number of prompt (input) tokens.
        num_completion_tokens: number of completion (output) tokens.

    Returns:
        prompt cost plus completion cost, each computed as
        tokens * price / 1000, or -1 when the model is not in the table.
    """
    # model -> (price per 1000 prompt tokens, price per 1000 completion tokens)
    price_table = {
        "gpt-3.5-turbo": (0.0015, 0.002),
        "gpt-3.5-turbo-16k": (0.003, 0.004),
        "gpt-3.5-turbo-0613": (0.0015, 0.002),
        "gpt-3.5-turbo-16k-0613": (0.003, 0.004),
        "gpt-4": (0.03, 0.06),
        "gpt-4-0613": (0.03, 0.06),
        "gpt-4-32k": (0.06, 0.12),
    }

    if model_type not in price_table:
        return -1

    prompt_rate, completion_rate = price_table[model_type]
    prompt_part = num_prompt_tokens * prompt_rate / 1000.0
    completion_part = num_completion_tokens * completion_rate / 1000.0
    return prompt_part + completion_part
|
||||
@ -4,6 +4,7 @@ from typing import AnyStr, Callable, Dict, List, Optional, Union
|
||||
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.document_loaders.base import BaseLoader
|
||||
from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter
|
||||
|
||||
from dev_opsgpt.utils.common_utils import read_json_file
|
||||
|
||||
@ -39,3 +40,22 @@ class JSONLoader(BaseLoader):
|
||||
)
|
||||
text = sample.get(self.schema_key, "")
|
||||
docs.append(Document(page_content=text, metadata=metadata))
|
||||
|
||||
def load_and_split(
    self, text_splitter: Optional[TextSplitter] = None
) -> List[Document]:
    """Load the documents and return them split into chunks.

    Args:
        text_splitter: splitter applied to the loaded documents; a
            default-constructed RecursiveCharacterTextSplitter is used
            when None.

    Returns:
        The chunks produced by the splitter, as a list of Documents.
    """
    splitter: TextSplitter = (
        RecursiveCharacterTextSplitter() if text_splitter is None else text_splitter
    )
    return splitter.split_documents(self.load())
|
||||
@ -4,6 +4,7 @@ from typing import AnyStr, Callable, Dict, List, Optional, Union
|
||||
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.document_loaders.base import BaseLoader
|
||||
from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter
|
||||
|
||||
from dev_opsgpt.utils.common_utils import read_jsonl_file
|
||||
|
||||
@ -39,3 +40,23 @@ class JSONLLoader(BaseLoader):
|
||||
)
|
||||
text = sample.get(self.schema_key, "")
|
||||
docs.append(Document(page_content=text, metadata=metadata))
|
||||
|
||||
def load_and_split(
    self, text_splitter: Optional[TextSplitter] = None
) -> List[Document]:
    """Load the documents and return them split into chunks.

    Args:
        text_splitter: splitter applied to the loaded documents; a
            default-constructed RecursiveCharacterTextSplitter is used
            when None.

    Returns:
        The chunks produced by the splitter, as a list of Documents.
    """
    splitter: TextSplitter = (
        RecursiveCharacterTextSplitter() if text_splitter is None else text_splitter
    )
    return splitter.split_documents(self.load())
|
||||
776
dev_opsgpt/embeddings/faiss_m.py
Normal file
776
dev_opsgpt/embeddings/faiss_m.py
Normal file
@ -0,0 +1,776 @@
|
||||
"""Wrapper around FAISS vector database."""
|
||||
from __future__ import annotations
|
||||
|
||||
import operator
|
||||
import os
|
||||
import pickle
|
||||
import uuid
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
Iterable,
|
||||
List,
|
||||
Optional,
|
||||
Sized,
|
||||
Tuple,
|
||||
)
|
||||
|
||||
import numpy as np
|
||||
|
||||
from langchain.docstore.base import AddableMixin, Docstore
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.docstore.in_memory import InMemoryDocstore
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.vectorstores.base import VectorStore
|
||||
from langchain.vectorstores.utils import DistanceStrategy, maximal_marginal_relevance
|
||||
|
||||
|
||||
def dependable_faiss_import(no_avx2: Optional[bool] = None) -> Any:
    """Import and return the faiss module, or raise a helpful ImportError.

    Args:
        no_avx2: when truthy, import ``faiss.swigfaiss`` instead of ``faiss``.
            When None and the FAISS_NO_AVX2 environment variable is set, the
            truthiness of that variable's string value is used (any non-empty
            string counts as true).

    Raises:
        ImportError: if the selected faiss module cannot be imported.
    """
    plain_build = no_avx2
    if plain_build is None and "FAISS_NO_AVX2" in os.environ:
        plain_build = bool(os.getenv("FAISS_NO_AVX2"))

    try:
        if not plain_build:
            import faiss
        else:
            from faiss import swigfaiss as faiss
    except ImportError:
        raise ImportError(
            "Could not import faiss python package. "
            "Please install it with `pip install faiss-gpu` (for CUDA supported GPU) "
            "or `pip install faiss-cpu` (depending on Python version)."
        )
    return faiss
|
||||
|
||||
|
||||
def _len_check_if_sized(x: Any, y: Any, x_name: str, y_name: str) -> None:
|
||||
if isinstance(x, Sized) and isinstance(y, Sized) and len(x) != len(y):
|
||||
raise ValueError(
|
||||
f"{x_name} and {y_name} expected to be equal length but "
|
||||
f"len({x_name})={len(x)} and len({y_name})={len(y)}"
|
||||
)
|
||||
return
|
||||
|
||||
|
||||
class FAISS(VectorStore):
|
||||
"""Wrapper around FAISS vector database.
|
||||
|
||||
To use, you must have the ``faiss`` python package installed.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain.embeddings.openai import OpenAIEmbeddings
|
||||
from langchain.vectorstores import FAISS
|
||||
|
||||
embeddings = OpenAIEmbeddings()
|
||||
texts = ["FAISS is an important library", "LangChain supports FAISS"]
|
||||
faiss = FAISS.from_texts(texts, embeddings)
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
    self,
    embedding_function: Callable,
    index: Any,
    docstore: Docstore,
    index_to_docstore_id: Dict[int, str],
    relevance_score_fn: Optional[Callable[[float], float]] = None,
    normalize_L2: bool = False,
    distance_strategy: DistanceStrategy = DistanceStrategy.EUCLIDEAN_DISTANCE,
):
    """Store the components of the vector store.

    A warning is emitted when L2 normalisation is requested together with
    a distance strategy other than Euclidean distance.
    """
    self.embedding_function = embedding_function
    self.index = index
    self.docstore = docstore
    self.index_to_docstore_id = index_to_docstore_id
    self.distance_strategy = distance_strategy
    self.override_relevance_score_fn = relevance_score_fn
    self._normalize_L2 = normalize_L2

    non_euclidean = self.distance_strategy != DistanceStrategy.EUCLIDEAN_DISTANCE
    if non_euclidean and self._normalize_L2:
        message = "Normalizing L2 is not applicable for metric type: {strategy}".format(
            strategy=self.distance_strategy
        )
        warnings.warn(message)
|
||||
|
||||
def __add(
    self,
    texts: Iterable[str],
    embeddings: Iterable[List[float]],
    metadatas: Optional[Iterable[dict]] = None,
    ids: Optional[List[str]] = None,
) -> List[str]:
    """Add texts with precomputed embeddings to the index and the docstore.

    Args:
        texts: document contents.
        embeddings: one vector per text.
        metadatas: optional metadata per text; an empty dict is used for
            each text when omitted.
        ids: optional docstore ids; random UUID4 strings are generated
            when omitted.

    Returns:
        The docstore ids of the added documents.

    Raises:
        ValueError: if the docstore is not an AddableMixin, or if sized
            arguments have mismatching lengths.
    """
    faiss = dependable_faiss_import()

    if not isinstance(self.docstore, AddableMixin):
        raise ValueError(
            "If trying to add texts, the underlying docstore should support "
            f"adding items, which {self.docstore} does not"
        )

    # NOTE(review): `texts` is iterated up to three times below (twice when
    # building the documents, once more for generated ids), so a one-shot
    # iterator would be exhausted early - presumably callers pass a sequence.
    _len_check_if_sized(texts, metadatas, "texts", "metadatas")
    _metadatas = metadatas or ({} for _ in texts)
    documents = [
        Document(page_content=t, metadata=m) for t, m in zip(texts, _metadatas)
    ]

    _len_check_if_sized(documents, embeddings, "documents", "embeddings")
    _len_check_if_sized(documents, ids, "documents", "ids")

    # Add to the index.
    vector = np.array(embeddings, dtype=np.float32)
    if self._normalize_L2:
        faiss.normalize_L2(vector)
    self.index.add(vector)

    # Add information to docstore and index.
    ids = ids or [str(uuid.uuid4()) for _ in texts]
    self.docstore.add({id_: doc for id_, doc in zip(ids, documents)})
    # new index positions continue after the existing entries
    starting_len = len(self.index_to_docstore_id)
    index_to_id = {starting_len + j: id_ for j, id_ in enumerate(ids)}
    self.index_to_docstore_id.update(index_to_id)
    return ids
|
||||
|
||||
def add_texts(
    self,
    texts: Iterable[str],
    metadatas: Optional[List[dict]] = None,
    ids: Optional[List[str]] = None,
    **kwargs: Any,
) -> List[str]:
    """Embed the texts and add them to the vectorstore.

    The embedding function is called once with the whole `texts` iterable
    (not once per text).

    Args:
        texts: Iterable of strings to add to the vectorstore.
        metadatas: Optional list of metadatas associated with the texts.
        ids: Optional list of unique IDs.

    Returns:
        List of ids from adding the texts into the vectorstore.
    """
    vectors = self.embedding_function(texts)
    added_ids = self.__add(texts, vectors, metadatas=metadatas, ids=ids)
    return added_ids
|
||||
|
||||
def add_embeddings(
    self,
    text_embeddings: Iterable[Tuple[str, List[float]]],
    metadatas: Optional[List[dict]] = None,
    ids: Optional[List[str]] = None,
    **kwargs: Any,
) -> List[str]:
    """Add already-embedded texts to the vectorstore.

    Args:
        text_embeddings: Iterable of (text, embedding) pairs.
        metadatas: Optional list of metadatas associated with the texts.
        ids: Optional list of unique IDs.

    Returns:
        List of ids from adding the texts into the vectorstore.
    """
    # unzip the pairs into parallel tuples of texts and vectors
    raw_texts, vectors = zip(*text_embeddings)
    return self.__add(raw_texts, vectors, metadatas=metadatas, ids=ids)
|
||||
|
||||
def similarity_search_with_score_by_vector(
    self,
    embedding: List[float],
    k: int = 4,
    filter: Optional[Dict[str, Any]] = None,
    fetch_k: int = 20,
    **kwargs: Any,
) -> List[Tuple[Document, float]]:
    """Return docs most similar to an embedding, with their scores.

    Args:
        embedding: Embedding vector to look up documents similar to.
        k: Number of Documents to return. Defaults to 4.
        filter: Optional metadata filter. Each value is either a single
            accepted value or a list of accepted values; a document must
            match every key.
        fetch_k: Number of Documents to fetch from the index when a filter
            is given (before filtering). Defaults to 20.
        **kwargs: may include ``score_threshold``; results are kept when
            score >= threshold for MAX_INNER_PRODUCT / JACCARD strategies
            and when score <= threshold otherwise.

    Returns:
        At most ``k`` (Document, score) pairs in index result order.

    Raises:
        ValueError: if an id returned by the index has no Document in the
            docstore.
    """
    faiss = dependable_faiss_import()
    vector = np.array([embedding], dtype=np.float32)
    if self._normalize_L2:
        faiss.normalize_L2(vector)
    scores, indices = self.index.search(vector, k if filter is None else fetch_k)

    # Normalise the filter once, outside the loop, under a new name: every
    # value becomes a list of accepted values.
    accepted = None
    if filter is not None:
        accepted = {
            key: value if isinstance(value, list) else [value]
            for key, value in filter.items()
        }

    docs = []
    for j, i in enumerate(indices[0]):
        if i == -1:
            # This happens when not enough docs are returned.
            continue
        _id = self.index_to_docstore_id[i]
        doc = self.docstore.search(_id)
        if not isinstance(doc, Document):
            raise ValueError(f"Could not find document for id {_id}, got {doc}")
        if accepted is None or all(
            doc.metadata.get(key) in values for key, values in accepted.items()
        ):
            docs.append((doc, scores[0][j]))

    score_threshold = kwargs.get("score_threshold")
    if score_threshold is not None:
        cmp = (
            operator.ge
            if self.distance_strategy
            in (DistanceStrategy.MAX_INNER_PRODUCT, DistanceStrategy.JACCARD)
            else operator.le
        )
        docs = [
            (doc, similarity)
            for doc, similarity in docs
            if cmp(similarity, score_threshold)
        ]
    return docs[:k]
|
||||
|
||||
def similarity_search_with_score(
    self,
    query: str,
    k: int = 4,
    filter: Optional[Dict[str, Any]] = None,
    fetch_k: int = 20,
    **kwargs: Any,
) -> List[Tuple[Document, float]]:
    """Embed the query and return the most similar docs with their scores.

    Args:
        query: Text to look up documents similar to.
        k: Number of Documents to return. Defaults to 4.
        filter: Optional metadata filter. Defaults to None.
        fetch_k: Number of Documents to fetch before filtering. Defaults to 20.

    Returns:
        (Document, score) pairs as produced by
        `similarity_search_with_score_by_vector`.
    """
    query_vector = self.embedding_function(query)
    return self.similarity_search_with_score_by_vector(
        query_vector,
        k,
        filter=filter,
        fetch_k=fetch_k,
        **kwargs,
    )
|
||||
|
||||
def similarity_search_by_vector(
    self,
    embedding: List[float],
    k: int = 4,
    filter: Optional[Dict[str, Any]] = None,
    fetch_k: int = 20,
    **kwargs: Any,
) -> List[Document]:
    """Return the docs most similar to an embedding vector, without scores.

    Args:
        embedding: Embedding to look up documents similar to.
        k: Number of Documents to return. Defaults to 4.
        filter: Optional metadata filter. Defaults to None.
        fetch_k: Number of Documents to fetch before filtering. Defaults to 20.

    Returns:
        List of Documents most similar to the embedding.
    """
    scored = self.similarity_search_with_score_by_vector(
        embedding,
        k,
        filter=filter,
        fetch_k=fetch_k,
        **kwargs,
    )
    found = []
    for document, _score in scored:
        found.append(document)
    return found
|
||||
|
||||
def similarity_search(
    self,
    query: str,
    k: int = 4,
    filter: Optional[Dict[str, Any]] = None,
    fetch_k: int = 20,
    **kwargs: Any,
) -> List[Document]:
    """Return the docs most similar to a text query, without scores.

    Args:
        query: Text to look up documents similar to.
        k: Number of Documents to return. Defaults to 4.
        filter: Optional metadata filter. Defaults to None.
        fetch_k: Number of Documents to fetch before filtering. Defaults to 20.

    Returns:
        List of Documents most similar to the query.
    """
    scored = self.similarity_search_with_score(
        query, k, filter=filter, fetch_k=fetch_k, **kwargs
    )
    found = []
    for document, _score in scored:
        found.append(document)
    return found
|
||||
|
||||
def max_marginal_relevance_search_with_score_by_vector(
    self,
    embedding: List[float],
    *,
    k: int = 4,
    fetch_k: int = 20,
    lambda_mult: float = 0.5,
    filter: Optional[Dict[str, Any]] = None,
) -> List[Tuple[Document, float]]:
    """Return docs and their similarity scores selected using the maximal marginal
        relevance.

    Maximal marginal relevance optimizes for similarity to query AND diversity
    among selected documents.

    Args:
        embedding: Embedding to look up documents similar to.
        k: Number of Documents to return. Defaults to 4.
        fetch_k: Number of Documents to fetch before filtering to
                 pass to MMR algorithm. When a filter is given, twice this
                 many candidates are requested from the index.
        lambda_mult: Number between 0 and 1 that determines the degree
                    of diversity among the results with 0 corresponding
                    to maximum diversity and 1 to minimum diversity.
                    Defaults to 0.5.
        filter: Optional metadata filter. A list value matches when the
                document's metadata value is contained in the list; any
                other value must compare equal.

    Returns:
        List of Documents and similarity scores selected by maximal marginal
            relevance and score for each.

    Raises:
        ValueError: If the docstore does not return a Document for an id.
    """
    scores, indices = self.index.search(
        np.array([embedding], dtype=np.float32),
        fetch_k if filter is None else fetch_k * 2,
    )
    if filter is not None:
        filtered_indices = []
        # Scores must be filtered in lockstep with the indices; otherwise the
        # positions returned by the MMR step would look up scores that belong
        # to different (possibly filtered-out) documents.
        filtered_scores = []
        for score, i in zip(scores[0], indices[0]):
            if i == -1:
                # This happens when not enough docs are returned.
                continue
            _id = self.index_to_docstore_id[i]
            doc = self.docstore.search(_id)
            if not isinstance(doc, Document):
                raise ValueError(f"Could not find document for id {_id}, got {doc}")
            if all(
                doc.metadata.get(key) in value
                if isinstance(value, list)
                else doc.metadata.get(key) == value
                for key, value in filter.items()
            ):
                filtered_indices.append(i)
                filtered_scores.append(score)
        indices = np.array([filtered_indices])
        scores = np.array([filtered_scores])
    # -1 happens when not enough docs are returned.
    embeddings = [self.index.reconstruct(int(i)) for i in indices[0] if i != -1]
    mmr_selected = maximal_marginal_relevance(
        np.array([embedding], dtype=np.float32),
        embeddings,
        k=k,
        lambda_mult=lambda_mult,
    )
    selected_indices = [indices[0][i] for i in mmr_selected]
    selected_scores = [scores[0][i] for i in mmr_selected]
    docs_and_scores = []
    for i, score in zip(selected_indices, selected_scores):
        if i == -1:
            # This happens when not enough docs are returned.
            continue
        _id = self.index_to_docstore_id[i]
        doc = self.docstore.search(_id)
        if not isinstance(doc, Document):
            raise ValueError(f"Could not find document for id {_id}, got {doc}")
        docs_and_scores.append((doc, score))
    return docs_and_scores
|
||||
|
||||
def max_marginal_relevance_search_by_vector(
    self,
    embedding: List[float],
    k: int = 4,
    fetch_k: int = 20,
    lambda_mult: float = 0.5,
    filter: Optional[Dict[str, Any]] = None,
    **kwargs: Any,
) -> List[Document]:
    """Return documents chosen by maximal marginal relevance for a vector.

    Wraps ``max_marginal_relevance_search_with_score_by_vector`` and drops
    the scores. ``**kwargs`` is accepted but not forwarded.

    Args:
        embedding: Query vector to compare the stored documents against.
        k: Number of documents to return. Defaults to 4.
        fetch_k: Number of candidates handed to the MMR algorithm.
        lambda_mult: Value between 0 and 1; 0 means maximum diversity and
            1 means minimum diversity. Defaults to 0.5.
        filter: Optional metadata filter. Defaults to None.

    Returns:
        The documents selected by maximal marginal relevance.
    """
    scored_hits = self.max_marginal_relevance_search_with_score_by_vector(
        embedding,
        k=k,
        fetch_k=fetch_k,
        lambda_mult=lambda_mult,
        filter=filter,
    )
    documents = []
    for document, _score in scored_hits:
        documents.append(document)
    return documents
|
||||
|
||||
def max_marginal_relevance_search(
    self,
    query: str,
    k: int = 4,
    fetch_k: int = 20,
    lambda_mult: float = 0.5,
    filter: Optional[Dict[str, Any]] = None,
    **kwargs: Any,
) -> List[Document]:
    """Return documents chosen by maximal marginal relevance for a text query.

    The query is embedded with ``self.embedding_function`` and the resulting
    vector is passed to ``max_marginal_relevance_search_by_vector``.

    Args:
        query: Text to look up similar documents for.
        k: Number of documents to return. Defaults to 4.
        fetch_k: Number of candidates fetched (before any filtering) and
            handed to the MMR algorithm.
        lambda_mult: Value between 0 and 1; 0 means maximum diversity and
            1 means minimum diversity. Defaults to 0.5.
        filter: Optional metadata filter. Defaults to None.
        **kwargs: Forwarded unchanged to the vector-based search.

    Returns:
        The documents selected by maximal marginal relevance.
    """
    query_vector = self.embedding_function(query)
    return self.max_marginal_relevance_search_by_vector(
        query_vector,
        k=k,
        fetch_k=fetch_k,
        lambda_mult=lambda_mult,
        filter=filter,
        **kwargs,
    )
|
||||
|
||||
def delete(self, ids: Optional[List[str]] = None, **kwargs: Any) -> Optional[bool]:
    """Delete by ID. These are the IDs in the vectorstore.

    Removes the vectors from the index, removes the documents from the
    docstore, and renumbers ``index_to_docstore_id`` so the remaining
    entries are consecutive from 0 in their previous relative order.

    Args:
        ids: List of ids to delete.

    Returns:
        Optional[bool]: True if deletion is successful.

    Raises:
        ValueError: If ``ids`` is None, or if any id is not present in the
            current store.
    """
    if ids is None:
        raise ValueError("No ids provided to delete.")
    missing_ids = set(ids).difference(self.index_to_docstore_id.values())
    if missing_ids:
        raise ValueError(
            f"Some specified ids do not exist in the current store. Ids not found: "
            f"{missing_ids}"
        )

    reversed_index = {id_: idx for idx, id_ in self.index_to_docstore_id.items()}
    index_to_delete = [reversed_index[id_] for id_ in ids]

    self.index.remove_ids(np.array(index_to_delete, dtype=np.int64))
    self.docstore.delete(ids)

    # Build the lookup set once so the survivor scan is linear rather than
    # re-scanning the deletion list for every stored entry.
    deleted_positions = set(index_to_delete)
    remaining_ids = [
        id_
        for i, id_ in sorted(self.index_to_docstore_id.items())
        if i not in deleted_positions
    ]
    self.index_to_docstore_id = {i: id_ for i, id_ in enumerate(remaining_ids)}

    return True
|
||||
|
||||
def merge_from(self, target: FAISS) -> None:
    """Fold another FAISS store into this one.

    The target's vectors are merged into this index, its documents are added
    to this docstore, and its positions are appended to
    ``index_to_docstore_id`` shifted by the current number of entries.

    Args:
        target: FAISS object whose contents are merged into the current one.

    Returns:
        None.

    Raises:
        ValueError: If this docstore is not an ``AddableMixin``, or if the
            target docstore does not return a ``Document`` for one of its ids.
    """
    if not isinstance(self.docstore, AddableMixin):
        raise ValueError("Cannot merge with this type of docstore")

    # Target positions are numbered after the entries already present here.
    offset = len(self.index_to_docstore_id)

    # Merge the underlying indexes first.
    self.index.merge_from(target.index)

    # Collect every target document together with its shifted position.
    merged_docs = {}
    merged_positions = {}
    for position, doc_id in target.index_to_docstore_id.items():
        found = target.docstore.search(doc_id)
        if not isinstance(found, Document):
            raise ValueError("Document should be returned")
        merged_docs[doc_id] = found
        merged_positions[offset + position] = doc_id

    # Publish the collected documents and position mapping.
    self.docstore.add(merged_docs)
    self.index_to_docstore_id.update(merged_positions)
|
||||
|
||||
@classmethod
def __from(
    cls,
    texts: Iterable[str],
    embeddings: List[List[float]],
    embedding: Embeddings,
    metadatas: Optional[Iterable[dict]] = None,
    ids: Optional[List[str]] = None,
    normalize_L2: bool = False,
    distance_strategy: DistanceStrategy = DistanceStrategy.EUCLIDEAN_DISTANCE,
    **kwargs: Any,
) -> FAISS:
    """Build a store from texts and their precomputed embeddings.

    Creates a flat FAISS index sized from the first embedding, wraps it with
    an empty in-memory docstore, then adds the texts through ``__add``.

    Args:
        texts: Texts to store.
        embeddings: One vector per text; the first vector fixes the index
            dimensionality.
        embedding: Embeddings object whose ``embed_query`` is used for
            later queries.
        metadatas: Optional metadata, passed through to ``__add``.
        ids: Optional document ids, passed through to ``__add``.
        normalize_L2: Passed through to the constructor.
        distance_strategy: ``MAX_INNER_PRODUCT`` selects an inner-product
            index; every other value falls back to an L2 index.
        **kwargs: Extra constructor arguments.

    Returns:
        The populated store.
    """
    faiss = dependable_faiss_import()
    # Only inner product gets a dedicated index type; other metric types
    # are currently not initialized and default to L2.
    use_inner_product = distance_strategy == DistanceStrategy.MAX_INNER_PRODUCT
    index_factory = faiss.IndexFlatIP if use_inner_product else faiss.IndexFlatL2
    index = index_factory(len(embeddings[0]))
    store = cls(
        embedding.embed_query,
        index,
        InMemoryDocstore(),
        {},
        normalize_L2=normalize_L2,
        distance_strategy=distance_strategy,
        **kwargs,
    )
    store.__add(texts, embeddings, metadatas=metadatas, ids=ids)
    return store
|
||||
|
||||
@classmethod
def from_texts(
    cls,
    texts: List[str],
    embedding: Embeddings,
    metadatas: Optional[List[dict]] = None,
    ids: Optional[List[str]] = None,
    **kwargs: Any,
) -> FAISS:
    """Construct a FAISS wrapper from raw texts.

    Convenience entry point that:
        1. Embeds the texts with ``embedding.embed_documents``.
        2. Creates an in-memory docstore.
        3. Initializes the FAISS index.

    The number of texts is logged at debug level before embedding.

    Example:
        .. code-block:: python

            from langchain import FAISS
            from langchain.embeddings import OpenAIEmbeddings

            embeddings = OpenAIEmbeddings()
            faiss = FAISS.from_texts(texts, embeddings)

    Args:
        texts: Texts to embed and store.
        embedding: Embeddings object used for documents and later queries.
        metadatas: Optional metadata, one dict per text.
        ids: Optional document ids.
        **kwargs: Forwarded to ``__from``.

    Returns:
        The populated store.
    """
    from loguru import logger

    logger.debug(f"texts: {len(texts)}")
    vectors = embedding.embed_documents(texts)
    store = cls.__from(
        texts,
        vectors,
        embedding,
        metadatas=metadatas,
        ids=ids,
        **kwargs,
    )
    return store
|
||||
|
||||
@classmethod
def from_embeddings(
    cls,
    text_embeddings: Iterable[Tuple[str, List[float]]],
    embedding: Embeddings,
    metadatas: Optional[Iterable[dict]] = None,
    ids: Optional[List[str]] = None,
    **kwargs: Any,
) -> FAISS:
    """Construct FAISS wrapper from precomputed (text, embedding) pairs.

    This is a user friendly interface that:
        1. Takes texts together with their already-computed embeddings.
        2. Creates an in memory docstore
        3. Initializes the FAISS database

    This is intended to be a quick way to get started.

    Example:
        .. code-block:: python

            from langchain import FAISS
            from langchain.embeddings import OpenAIEmbeddings

            embeddings = OpenAIEmbeddings()
            text_embeddings = embeddings.embed_documents(texts)
            text_embedding_pairs = zip(texts, text_embeddings)
            faiss = FAISS.from_embeddings(text_embedding_pairs, embeddings)

    Args:
        text_embeddings: Iterable of ``(text, vector)`` pairs. One-shot
            iterators such as ``zip(...)`` are supported.
        embedding: Embeddings object whose ``embed_query`` is used for
            later queries.
        metadatas: Optional metadata, passed through to ``__from``.
        ids: Optional document ids, passed through to ``__from``.
        **kwargs: Forwarded to ``__from``.

    Returns:
        The populated store.
    """
    # Materialise once: the input may be a single-pass iterator (the example
    # above passes a zip object), and iterating it twice would leave the
    # second pass empty.
    pairs = list(text_embeddings)
    texts = [pair[0] for pair in pairs]
    embeddings = [pair[1] for pair in pairs]
    return cls.__from(
        texts,
        embeddings,
        embedding,
        metadatas=metadatas,
        ids=ids,
        **kwargs,
    )
|
||||
|
||||
def save_local(self, folder_path: str, index_name: str = "index") -> None:
    """Persist the FAISS index, docstore and index_to_docstore_id to disk.

    Two files are written into ``folder_path`` (created if missing):
    ``<index_name>.faiss`` for the index and ``<index_name>.pkl`` for the
    pickled ``(docstore, index_to_docstore_id)`` tuple.

    Args:
        folder_path: Directory to write the files into.
        index_name: Base file name used for both files.
    """
    target_dir = Path(folder_path)
    target_dir.mkdir(exist_ok=True, parents=True)

    # The index is written through faiss because it is not picklable.
    faiss = dependable_faiss_import()
    index_file = target_dir / f"{index_name}.faiss"
    faiss.write_index(self.index, str(index_file))

    # The docstore and the position mapping are pickled together.
    pickle_file = target_dir / f"{index_name}.pkl"
    with open(pickle_file, "wb") as handle:
        pickle.dump((self.docstore, self.index_to_docstore_id), handle)
|
||||
|
||||
@classmethod
def load_local(
    cls,
    folder_path: str,
    embeddings: Embeddings,
    index_name: str = "index",
    **kwargs: Any,
) -> FAISS:
    """Load the FAISS index, docstore and index_to_docstore_id from disk.

    Reads ``<index_name>.faiss`` and ``<index_name>.pkl`` from
    ``folder_path``.

    Args:
        folder_path: Directory containing the two saved files.
        embeddings: Embeddings object whose ``embed_query`` is used when
            generating queries.
        index_name: Base file name the store was saved under.
        **kwargs: Extra constructor arguments.

    Returns:
        The restored store.
    """
    source_dir = Path(folder_path)

    # The index is read through faiss because it is not picklable.
    faiss = dependable_faiss_import()
    index_file = source_dir / f"{index_name}.faiss"
    index = faiss.read_index(str(index_file))

    # NOTE(review): unpickling executes arbitrary code embedded in the file;
    # only load stores from locations you trust.
    pickle_file = source_dir / f"{index_name}.pkl"
    with open(pickle_file, "rb") as handle:
        docstore, index_to_docstore_id = pickle.load(handle)

    return cls(
        embeddings.embed_query,
        index,
        docstore,
        index_to_docstore_id,
        **kwargs,
    )
|
||||
|
||||
def serialize_to_bytes(self) -> bytes:
    """Pickle the index, docstore and index_to_docstore_id into bytes.

    Returns:
        The pickled ``(index, docstore, index_to_docstore_id)`` tuple.
    """
    state = (self.index, self.docstore, self.index_to_docstore_id)
    return pickle.dumps(state)
|
||||
|
||||
@classmethod
def deserialize_from_bytes(
    cls,
    serialized: bytes,
    embeddings: Embeddings,
    **kwargs: Any,
) -> FAISS:
    """Rebuild a store from bytes holding a pickled state tuple.

    Args:
        serialized: Pickled ``(index, docstore, index_to_docstore_id)``
            tuple.
        embeddings: Embeddings object whose ``embed_query`` is used when
            generating queries.
        **kwargs: Extra constructor arguments.

    Returns:
        The restored store.
    """
    # NOTE(review): unpickling executes arbitrary code embedded in the
    # payload; only pass bytes that come from a trusted source.
    state = pickle.loads(serialized)
    index, docstore, index_to_docstore_id = state
    return cls(
        embeddings.embed_query,
        index,
        docstore,
        index_to_docstore_id,
        **kwargs,
    )
|
||||
|
||||
def _select_relevance_score_fn(self) -> Callable[[float], float]:
    """Pick the function that turns a raw score into a relevance score.

    The 'correct' relevance function may differ depending on a few things,
    including:
    - the distance / similarity metric used by the VectorStore
    - the scale of your embeddings (OpenAI's are unit normed. Many others
      are not!)
    - embedding dimensionality
    - etc.

    An explicit ``override_relevance_score_fn`` always wins; otherwise the
    choice follows ``self.distance_strategy``.

    Returns:
        The selected scoring function.

    Raises:
        ValueError: If the distance strategy is neither max inner product
            nor euclidean distance.
    """
    override = self.override_relevance_score_fn
    if override is not None:
        return override

    # Without an override, rely on the distance strategy given to the
    # vectorstore constructor.
    strategy = self.distance_strategy
    if strategy == DistanceStrategy.MAX_INNER_PRODUCT:
        return self._max_inner_product_relevance_score_fn
    if strategy == DistanceStrategy.EUCLIDEAN_DISTANCE:
        # Euclidean distance relevancy is the default behavior.
        return self._euclidean_relevance_score_fn
    raise ValueError(
        "Unknown distance strategy, must be cosine, max_inner_product,"
        " or euclidean"
    )
|
||||
|
||||
def _similarity_search_with_relevance_scores(
    self,
    query: str,
    k: int = 4,
    filter: Optional[Dict[str, Any]] = None,
    fetch_k: int = 20,
    **kwargs: Any,
) -> List[Tuple[Document, float]]:
    """Return documents with scores converted by the relevance function.

    Runs ``similarity_search_with_score`` and maps every raw score through
    the function chosen by ``_select_relevance_score_fn``.

    Args:
        query: Text to look up similar documents for.
        k: Number of documents to return. Defaults to 4.
        filter: Optional metadata filter. Defaults to None.
        fetch_k: Number of candidates to fetch before filtering.
            Defaults to 20.
        **kwargs: Forwarded unchanged to the scored search.

    Returns:
        ``(document, relevance_score)`` pairs in search order.

    Raises:
        ValueError: If no relevance score function is available.
    """
    to_relevance = self._select_relevance_score_fn()
    if to_relevance is None:
        raise ValueError(
            "normalize_score_fn must be provided to"
            " FAISS constructor to normalize scores"
        )
    raw_hits = self.similarity_search_with_score(
        query,
        k=k,
        filter=filter,
        fetch_k=fetch_k,
        **kwargs,
    )
    rescored = []
    for document, raw_score in raw_hits:
        rescored.append((document, to_relevance(raw_score)))
    return rescored
|
||||
6
dev_opsgpt/llm_models/__init__.py
Normal file
6
dev_opsgpt/llm_models/__init__.py
Normal file
@ -0,0 +1,6 @@
|
||||
from .openai_model import getChatModel
|
||||
|
||||
|
||||
__all__ = [
|
||||
"getChatModel"
|
||||
]
|
||||
29
dev_opsgpt/llm_models/openai_model.py
Normal file
29
dev_opsgpt/llm_models/openai_model.py
Normal file
@ -0,0 +1,29 @@
|
||||
from langchain.callbacks import AsyncIteratorCallbackHandler
|
||||
from langchain.chat_models import ChatOpenAI
|
||||
|
||||
from configs.model_config import (llm_model_dict, LLM_MODEL)
|
||||
|
||||
|
||||
def getChatModel(callBack: AsyncIteratorCallbackHandler = None, temperature=0.3, stop=None):
    """Build a streaming ``ChatOpenAI`` client for the configured LLM.

    The API key and base URL are read from ``llm_model_dict[LLM_MODEL]``.

    Args:
        callBack: Optional callback handler. When given, it is registered
            through the ``callbacks`` keyword of ``ChatOpenAI``.
        temperature: Sampling temperature passed to the model.
        stop: Stop value passed through to the model unchanged.

    Returns:
        The configured ``ChatOpenAI`` instance.
    """
    model_config = llm_model_dict[LLM_MODEL]
    params = dict(
        streaming=True,
        verbose=True,
        openai_api_key=model_config["api_key"],
        openai_api_base=model_config["api_base_url"],
        model_name=LLM_MODEL,
        temperature=temperature,
        stop=stop,
    )
    if callBack is not None:
        # LangChain expects handlers under ``callbacks``; the previous
        # ``callBack=[...]`` keyword is not a ChatOpenAI field, so the
        # handler was never registered.
        params["callbacks"] = [callBack]
    return ChatOpenAI(**params)
|
||||
@ -18,5 +18,6 @@ def check_tables_exist(table_name) -> bool:
|
||||
return table_exist
|
||||
|
||||
def table_init():
|
||||
if (not check_tables_exist("knowledge_base")) or (not check_tables_exist ("knowledge_file")):
|
||||
if (not check_tables_exist("knowledge_base")) or (not check_tables_exist ("knowledge_file")) or \
|
||||
(not check_tables_exist ("code_base")):
|
||||
create_tables()
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
from .document_file_cds import *
|
||||
from .document_base_cds import *
|
||||
from .code_base_cds import *
|
||||
|
||||
__all__ = [
|
||||
"add_kb_to_db", "list_kbs_from_db", "kb_exists",
|
||||
@ -7,4 +8,7 @@ __all__ = [
|
||||
|
||||
"list_docs_from_db", "add_doc_to_db", "delete_file_from_db",
|
||||
"delete_files_from_db", "doc_exists", "get_file_detail",
|
||||
|
||||
"list_cbs_from_db", "add_cb_to_db", "delete_cb_from_db",
|
||||
"cb_exists", "get_cb_detail",
|
||||
]
|
||||
79
dev_opsgpt/orm/commands/code_base_cds.py
Normal file
79
dev_opsgpt/orm/commands/code_base_cds.py
Normal file
@ -0,0 +1,79 @@
|
||||
# encoding: utf-8
|
||||
'''
|
||||
@author: 温进
|
||||
@file: code_base_cds.py.py
|
||||
@time: 2023/10/23 下午4:34
|
||||
@desc:
|
||||
'''
|
||||
from loguru import logger
|
||||
from dev_opsgpt.orm.db import with_session, _engine
|
||||
from dev_opsgpt.orm.schemas.base_schema import CodeBaseSchema
|
||||
|
||||
|
||||
@with_session
def add_cb_to_db(session, code_name, code_path, code_graph_node_num, code_file_num):
    """Create the code base record, or refresh it if it already exists.

    Args:
        session: Database session; presumably supplied by ``with_session``
            (confirm against its definition).
        code_name: Name identifying the code base.
        code_path: Local path of the code.
        code_graph_node_num: Number of nodes in the code graph.
        code_file_num: Number of parsed code files.

    Returns:
        Always True.
    """
    cb = session.query(CodeBaseSchema).filter_by(code_name=code_name).first()
    if not cb:
        # No record yet: create one.
        cb = CodeBaseSchema(code_name=code_name, code_path=code_path, code_graph_node_num=code_graph_node_num,
                            code_file_num=code_file_num)
        session.add(cb)
    else:
        # Existing record: refresh every mutable field. code_file_num was
        # previously left stale on update.
        cb.code_path = code_path
        cb.code_graph_node_num = code_graph_node_num
        cb.code_file_num = code_file_num
    return True
|
||||
|
||||
|
||||
@with_session
def list_cbs_from_db(session):
    """Return the names of all stored code bases.

    Args:
        session: Database session; presumably supplied by ``with_session``
            (confirm against its definition).

    Returns:
        A list holding the ``code_name`` of every row.
    """
    rows = session.query(CodeBaseSchema.code_name).all()
    # Each row carries a single column; keep just that value.
    names = []
    for row in rows:
        names.append(row[0])
    return names
|
||||
|
||||
|
||||
@with_session
def cb_exists(session, code_name):
    """Tell whether a code base with the given name is stored.

    Args:
        session: Database session; presumably supplied by ``with_session``
            (confirm against its definition).
        code_name: Name of the code base to look for.

    Returns:
        True when a matching record is found, otherwise False.
    """
    record = session.query(CodeBaseSchema).filter_by(code_name=code_name).first()
    return bool(record)
|
||||
|
||||
@with_session
def load_cb_from_db(session, code_name):
    """Fetch the name, path and graph node count of a code base.

    Args:
        session: Database session; presumably supplied by ``with_session``
            (confirm against its definition).
        code_name: Name of the code base to load.

    Returns:
        A ``(code_name, code_path, code_graph_node_num)`` tuple, or
        ``(None, None, None)`` when no matching record exists.
    """
    record = session.query(CodeBaseSchema).filter_by(code_name=code_name).first()
    if not record:
        return None, None, None
    return record.code_name, record.code_path, record.code_graph_node_num
|
||||
|
||||
|
||||
@with_session
def delete_cb_from_db(session, code_name):
    """Delete the code base record with the given name, if there is one.

    Args:
        session: Database session; presumably supplied by ``with_session``
            (confirm against its definition).
        code_name: Name of the code base to remove.

    Returns:
        Always True, whether or not a record was found.
    """
    record = session.query(CodeBaseSchema).filter_by(code_name=code_name).first()
    if not record:
        return True
    session.delete(record)
    return True
|
||||
|
||||
|
||||
@with_session
def get_cb_detail(session, code_name: str) -> dict:
    """Return the stored details of a code base as a plain dict.

    Args:
        session: Database session; presumably supplied by ``with_session``
            (confirm against its definition).
        code_name: Name of the code base to look up.

    Returns:
        A dict with ``code_name``, ``code_path``, ``code_graph_node_num`` and
        ``code_file_num``; an empty dict when no matching record exists.
    """
    cb: CodeBaseSchema = session.query(CodeBaseSchema).filter_by(code_name=code_name).first()
    logger.info(cb)
    if not cb:
        # Nothing stored under this name. Return before touching cb's
        # attributes, which would raise AttributeError on None.
        return {}
    logger.info('code_name={}'.format(cb.code_name))
    return {
        "code_name": cb.code_name,
        "code_path": cb.code_path,
        "code_graph_node_num": cb.code_graph_node_num,
        'code_file_num': cb.code_file_num
    }
|
||||
|
||||
@ -46,3 +46,24 @@ class KnowledgeFileSchema(Base):
|
||||
text_splitter_name='{self.text_splitter_name}',
|
||||
file_version='{self.file_version}',
|
||||
create_time='{self.create_time}')>"""
|
||||
|
||||
|
||||
class CodeBaseSchema(Base):
    '''
    ORM model for a code base record (table ``code_base``).
    '''
    __tablename__ = 'code_base'
    # Auto-incrementing primary key.
    id = Column(Integer, primary_key=True, autoincrement=True, comment='代码库 ID')
    # Name of the code base.
    code_name = Column(String, comment='代码库名称')
    # Local path of the code.
    code_path = Column(String, comment='代码本地路径')
    # Number of nodes in the code graph (stored as a string column).
    code_graph_node_num = Column(String, comment='代码图谱节点数')
    # Number of parsed code files (stored as a string column).
    code_file_num = Column(String, comment='代码解析文件数')
    # Creation time; defaults to the database's current time.
    create_time = Column(DateTime, default=func.now(), comment='创建时间')

    def __repr__(self):
        # A comma now follows code_file_num so all fields are separated
        # consistently in the rendered text.
        return f"""<CodeBase(id='{self.id}',
code_name='{self.code_name}',
code_path='{self.code_path}',
code_graph_node_num='{self.code_graph_node_num}',
code_file_num='{self.code_file_num}',
create_time='{self.create_time}')>"""
|
||||
|
||||
@ -3,6 +3,7 @@ from typing import Optional
|
||||
from pathlib import Path
|
||||
import sys
|
||||
from abc import ABC, abstractclassmethod
|
||||
from loguru import logger
|
||||
|
||||
from configs.server_config import SANDBOX_SERVER
|
||||
|
||||
@ -22,21 +23,6 @@ class CodeBoxStatus(BaseModel):
|
||||
status: str
|
||||
|
||||
|
||||
class CodeBoxFile(BaseModel):
|
||||
"""
|
||||
Represents a file returned from a CodeBox instance.
|
||||
"""
|
||||
|
||||
name: str
|
||||
content: Optional[bytes] = None
|
||||
|
||||
def __str__(self):
|
||||
return self.name
|
||||
|
||||
def __repr__(self):
|
||||
return f"File({self.name})"
|
||||
|
||||
|
||||
class BaseBox(ABC):
|
||||
|
||||
enter_status = False
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
import time, os, docker, requests, json, uuid, subprocess, time, asyncio, aiohttp, re, traceback
|
||||
import psutil
|
||||
from typing import List, Optional, Union
|
||||
from pathlib import Path
|
||||
from loguru import logger
|
||||
|
||||
from websockets.sync.client import connect as ws_connect_sync
|
||||
@ -11,7 +10,8 @@ from websockets.client import WebSocketClientProtocol, ClientConnection
|
||||
from websockets.exceptions import ConnectionClosedError
|
||||
|
||||
from configs.server_config import SANDBOX_SERVER
|
||||
from .basebox import BaseBox, CodeBoxResponse, CodeBoxStatus, CodeBoxFile
|
||||
from configs.model_config import JUPYTER_WORK_PATH
|
||||
from .basebox import BaseBox, CodeBoxResponse, CodeBoxStatus
|
||||
|
||||
|
||||
class PyCodeBox(BaseBox):
|
||||
@ -25,12 +25,18 @@ class PyCodeBox(BaseBox):
|
||||
remote_port: str = SANDBOX_SERVER["port"],
|
||||
token: str = "mytoken",
|
||||
do_code_exe: bool = False,
|
||||
do_remote: bool = False
|
||||
do_remote: bool = False,
|
||||
do_check_net: bool = True,
|
||||
):
|
||||
super().__init__(remote_url, remote_ip, remote_port, token, do_code_exe, do_remote)
|
||||
self.enter_status = True
|
||||
self.do_check_net = do_check_net
|
||||
asyncio.run(self.astart())
|
||||
|
||||
# logger.info(f"""remote_url: {self.remote_url},
|
||||
# remote_ip: {self.remote_ip},
|
||||
# remote_port: {self.remote_port}""")
|
||||
|
||||
def decode_code_from_text(self, text: str) -> str:
|
||||
pattern = r'```.*?```'
|
||||
code_blocks = re.findall(pattern, text, re.DOTALL)
|
||||
@ -73,7 +79,8 @@ class PyCodeBox(BaseBox):
|
||||
if not self.ws:
|
||||
raise RuntimeError("Jupyter not running. Make sure to start it first")
|
||||
|
||||
logger.debug(f"code_text: {json.dumps(code_text, ensure_ascii=False)}")
|
||||
# logger.debug(f"code_text: {len(code_text)}, {code_text}")
|
||||
|
||||
self.ws.send(
|
||||
json.dumps(
|
||||
{
|
||||
@ -103,7 +110,7 @@ class PyCodeBox(BaseBox):
|
||||
raise RuntimeError("Mixing asyncio and sync code is not supported")
|
||||
received_msg = json.loads(self.ws.recv())
|
||||
except ConnectionClosedError:
|
||||
logger.debug("box start, ConnectionClosedError!!!")
|
||||
# logger.debug("box start, ConnectionClosedError!!!")
|
||||
self.start()
|
||||
return self.run(code_text, file_path, retry - 1)
|
||||
|
||||
@ -156,7 +163,7 @@ class PyCodeBox(BaseBox):
|
||||
return CodeBoxResponse(
|
||||
code_exe_type="text",
|
||||
code_text=code_text,
|
||||
code_exe_response=result or "Code run successfully (no output)",
|
||||
code_exe_response=result or "Code run successfully (no output),可能没有打印需要确认的变量",
|
||||
code_exe_status=200,
|
||||
do_code_exe=self.do_code_exe
|
||||
)
|
||||
@ -219,7 +226,6 @@ class PyCodeBox(BaseBox):
|
||||
async def _acheck_connect(self, ) -> bool:
|
||||
if self.kernel_url == "":
|
||||
return False
|
||||
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(f"{self.kernel_url}?token={self.token}", timeout=270) as resp:
|
||||
@ -231,7 +237,7 @@ class PyCodeBox(BaseBox):
|
||||
|
||||
def _check_port(self, ) -> bool:
|
||||
try:
|
||||
response = requests.get(f"http://localhost:{self.remote_port}", timeout=270)
|
||||
response = requests.get(f"{self.remote_ip}:{self.remote_port}", timeout=270)
|
||||
logger.warning(f"Port is conflict, please check your codebox's port {self.remote_port}")
|
||||
return response.status_code == 200
|
||||
except requests.exceptions.ConnectionError:
|
||||
@ -240,7 +246,7 @@ class PyCodeBox(BaseBox):
|
||||
async def _acheck_port(self, ) -> bool:
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(f"http://localhost:{self.remote_port}", timeout=270) as resp:
|
||||
async with session.get(f"{self.remote_ip}:{self.remote_port}", timeout=270) as resp:
|
||||
logger.warning(f"Port is conflict, please check your codebox's port {self.remote_port}")
|
||||
return resp.status == 200
|
||||
except aiohttp.ClientConnectorError:
|
||||
@ -249,6 +255,8 @@ class PyCodeBox(BaseBox):
|
||||
pass
|
||||
|
||||
def _check_connect_success(self, retry_nums: int = 5) -> bool:
|
||||
if not self.do_check_net: return True
|
||||
|
||||
while retry_nums > 0:
|
||||
try:
|
||||
connect_status = self._check_connect()
|
||||
@ -262,6 +270,7 @@ class PyCodeBox(BaseBox):
|
||||
raise BaseException(f"can't connect to {self.remote_url}")
|
||||
|
||||
async def _acheck_connect_success(self, retry_nums: int = 5) -> bool:
|
||||
if not self.do_check_net: return True
|
||||
while retry_nums > 0:
|
||||
try:
|
||||
connect_status = await self._acheck_connect()
|
||||
@ -283,7 +292,7 @@ class PyCodeBox(BaseBox):
|
||||
self._check_connect_success()
|
||||
|
||||
self._get_kernelid()
|
||||
logger.debug(self.kernel_url.replace("http", "ws") + f"/{self.kernel_id}/channels?token={self.token}")
|
||||
# logger.debug(self.kernel_url.replace("http", "ws") + f"/{self.kernel_id}/channels?token={self.token}")
|
||||
self.wc_url = self.kernel_url.replace("http", "ws") + f"/{self.kernel_id}/channels?token={self.token}"
|
||||
headers = {"Authorization": f'Token {self.token}', 'token': self.token}
|
||||
self.ws = create_connection(self.wc_url, headers=headers)
|
||||
@ -291,27 +300,30 @@ class PyCodeBox(BaseBox):
|
||||
# TODO 自动检测本地接口
|
||||
port_status = self._check_port()
|
||||
connect_status = self._check_connect()
|
||||
logger.debug(f"port_status: {port_status}, connect_status: {connect_status}")
|
||||
logger.info(f"port_status: {port_status}, connect_status: {connect_status}")
|
||||
if port_status and not connect_status:
|
||||
raise BaseException(f"Port is conflict, please check your codebox's port {self.remote_port}")
|
||||
|
||||
if not connect_status:
|
||||
self.jupyter = subprocess.Popen(
|
||||
self.jupyter = subprocess.run(
|
||||
[
|
||||
"jupyer", "notebnook",
|
||||
f"--NotebookApp.token={self.token}",
|
||||
f"--port={self.remote_port}",
|
||||
"--no-browser",
|
||||
"--ServerApp.disable_check_xsrf=True",
|
||||
"--notebook-dir={JUPYTER_WORK_PATH}"
|
||||
],
|
||||
stderr=subprocess.PIPE,
|
||||
stdin=subprocess.PIPE,
|
||||
stdout=subprocess.PIPE,
|
||||
)
|
||||
|
||||
self.kernel_url = self.remote_url + "/api/kernels"
|
||||
self.do_check_net = True
|
||||
self._check_connect_success()
|
||||
self._get_kernelid()
|
||||
logger.debug(self.kernel_url.replace("http", "ws") + f"/{self.kernel_id}/channels?token={self.token}")
|
||||
# logger.debug(self.kernel_url.replace("http", "ws") + f"/{self.kernel_id}/channels?token={self.token}")
|
||||
self.wc_url = self.kernel_url.replace("http", "ws") + f"/{self.kernel_id}/channels?token={self.token}"
|
||||
headers = {"Authorization": f'Token {self.token}', 'token': self.token}
|
||||
self.ws = create_connection(self.wc_url, headers=headers)
|
||||
@ -333,10 +345,10 @@ class PyCodeBox(BaseBox):
|
||||
port_status = await self._acheck_port()
|
||||
self.kernel_url = self.remote_url + "/api/kernels"
|
||||
connect_status = await self._acheck_connect()
|
||||
logger.debug(f"port_status: {port_status}, connect_status: {connect_status}")
|
||||
logger.info(f"port_status: {port_status}, connect_status: {connect_status}")
|
||||
if port_status and not connect_status:
|
||||
raise BaseException(f"Port is conflict, please check your codebox's port {self.remote_port}")
|
||||
|
||||
|
||||
if not connect_status:
|
||||
self.jupyter = subprocess.Popen(
|
||||
[
|
||||
@ -344,13 +356,15 @@ class PyCodeBox(BaseBox):
|
||||
f"--NotebookApp.token={self.token}",
|
||||
f"--port={self.remote_port}",
|
||||
"--no-browser",
|
||||
"--ServerApp.disable_check_xsrf=True",
|
||||
"--ServerApp.disable_check_xsrf=True"
|
||||
],
|
||||
stderr=subprocess.PIPE,
|
||||
stdin=subprocess.PIPE,
|
||||
stdout=subprocess.PIPE,
|
||||
)
|
||||
|
||||
self.kernel_url = self.remote_url + "/api/kernels"
|
||||
self.do_check_net = True
|
||||
await self._acheck_connect_success()
|
||||
await self._aget_kernelid()
|
||||
self.wc_url = self.kernel_url.replace("http", "ws") + f"/{self.kernel_id}/channels?token={self.token}"
|
||||
@ -405,7 +419,8 @@ class PyCodeBox(BaseBox):
|
||||
except Exception as e:
|
||||
logger.error(traceback.format_exc())
|
||||
self.ws = None
|
||||
return CodeBoxStatus(status="stopped")
|
||||
|
||||
# return CodeBoxStatus(status="stopped")
|
||||
|
||||
def __del__(self):
|
||||
self.stop()
|
||||
|
||||
@ -18,15 +18,19 @@ from configs.server_config import OPEN_CROSS_DOMAIN
|
||||
|
||||
from dev_opsgpt.chat import LLMChat, SearchChat, KnowledgeChat
|
||||
from dev_opsgpt.service.kb_api import *
|
||||
from dev_opsgpt.service.cb_api import *
|
||||
from dev_opsgpt.utils.server_utils import BaseResponse, ListResponse, FastAPI, MakeFastAPIOffline
|
||||
|
||||
nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
|
||||
|
||||
from dev_opsgpt.chat import LLMChat, SearchChat, KnowledgeChat
|
||||
from dev_opsgpt.chat import LLMChat, SearchChat, KnowledgeChat, ToolChat, DataChat, CodeChat
|
||||
|
||||
llmChat = LLMChat()
|
||||
searchChat = SearchChat()
|
||||
knowledgeChat = KnowledgeChat()
|
||||
toolChat = ToolChat()
|
||||
dataChat = DataChat()
|
||||
codeChat = CodeChat()
|
||||
|
||||
|
||||
async def document():
|
||||
@ -71,6 +75,18 @@ def create_app():
|
||||
app.post("/chat/search_engine_chat",
|
||||
tags=["Chat"],
|
||||
summary="与搜索引擎对话")(searchChat.chat)
|
||||
app.post("/chat/tool_chat",
|
||||
tags=["Chat"],
|
||||
summary="与搜索引擎对话")(toolChat.chat)
|
||||
|
||||
app.post("/chat/data_chat",
|
||||
tags=["Chat"],
|
||||
summary="与搜索引擎对话")(dataChat.chat)
|
||||
|
||||
app.post("/chat/code_chat",
|
||||
tags=["Chat"],
|
||||
summary="与代码库对话")(codeChat.chat)
|
||||
|
||||
|
||||
# Tag: Knowledge Base Management
|
||||
app.get("/knowledge_base/list_knowledge_bases",
|
||||
@ -129,6 +145,27 @@ def create_app():
|
||||
summary="根据content中文档重建向量库,流式输出处理进度。"
|
||||
)(recreate_vector_store)
|
||||
|
||||
app.post("/code_base/create_code_base",
|
||||
tags=["Code Base Management"],
|
||||
summary="新建 code_base"
|
||||
)(create_cb)
|
||||
|
||||
app.post("/code_base/delete_code_base",
|
||||
tags=["Code Base Management"],
|
||||
summary="删除 code_base"
|
||||
)(delete_cb)
|
||||
|
||||
app.post("/code_base/code_base_chat",
|
||||
tags=["Code Base Management"],
|
||||
summary="删除 code_base"
|
||||
)(delete_cb)
|
||||
|
||||
app.get("/code_base/list_code_bases",
|
||||
tags=["Code Base Management"],
|
||||
summary="列举 code_base",
|
||||
response_model=ListResponse
|
||||
)(list_cbs)
|
||||
|
||||
# # LLM模型相关接口
|
||||
# app.post("/llm_model/list_models",
|
||||
# tags=["LLM Model Management"],
|
||||
|
||||
128
dev_opsgpt/service/cb_api.py
Normal file
128
dev_opsgpt/service/cb_api.py
Normal file
@ -0,0 +1,128 @@
|
||||
# encoding: utf-8
|
||||
'''
|
||||
@author: 温进
|
||||
@file: cb_api.py
|
||||
@time: 2023/10/23 下午7:08
|
||||
@desc:
|
||||
'''
|
||||
|
||||
import urllib, os, json, traceback
|
||||
from typing import List, Dict
|
||||
import shutil
|
||||
|
||||
from fastapi.responses import StreamingResponse, FileResponse
|
||||
from fastapi import File, Form, Body, Query, UploadFile
|
||||
from langchain.docstore.document import Document
|
||||
|
||||
from .service_factory import KBServiceFactory
|
||||
from dev_opsgpt.utils.server_utils import BaseResponse, ListResponse
|
||||
from dev_opsgpt.utils.path_utils import *
|
||||
from dev_opsgpt.orm.commands import *
|
||||
|
||||
from configs.model_config import (
|
||||
CB_ROOT_PATH
|
||||
)
|
||||
|
||||
from dev_opsgpt.codebase_handler.codebase_handler import CodeBaseHandler
|
||||
|
||||
from loguru import logger
|
||||
|
||||
|
||||
async def list_cbs():
    """Return the result of ``list_cbs_from_db()`` wrapped in a ``ListResponse``."""
    code_bases = list_cbs_from_db()
    return ListResponse(data=code_bases)
|
||||
|
||||
|
||||
async def create_cb(cb_name: str = Body(..., examples=["samples"]),
                    code_path: str = Body(..., examples=["samples"])
                    ) -> BaseResponse:
    """Build a code base from ``code_path`` and register it as ``cb_name``.

    Returns a ``BaseResponse`` whose ``code`` is:
      * 403 when ``validate_kb_name`` rejects the name,
      * 404 when the name is blank or ``cb_exists`` reports it is taken,
      * 500 when building or registering the code base raises,
      * 200 on success.
    """
    logger.info('cb_name={}, zip_path={}'.format(cb_name, code_path))

    # Validate the requested name before doing any work.
    if not validate_kb_name(cb_name):
        return BaseResponse(code=403, msg="Don't attack me")
    if cb_name is None or cb_name.strip() == "":
        return BaseResponse(code=404, msg="知识库名称不能为空,请重新填写知识库名称")

    cb = cb_exists(cb_name)
    if cb:
        return BaseResponse(code=404, msg=f"已存在同名代码知识库 {cb_name}")

    try:
        logger.info('start build code base')
        cbh = CodeBaseHandler(cb_name, code_path, cb_root_path=CB_ROOT_PATH)
        cbh.import_code(do_save=True)
        code_graph_node_num = len(cbh.nh)
        code_file_num = len(cbh.lcdh)
        logger.info('build code base done')

        # Record the new code base together with its size statistics.
        add_cb_to_db(cb_name, cbh.code_path, code_graph_node_num, code_file_num)
        logger.info('add cb to mysql table success')
    except Exception as e:
        # Log with traceback (as search_code does) instead of printing to stdout.
        logger.exception(e)
        return BaseResponse(code=500, msg=f"创建代码知识库出错: {e}")

    return BaseResponse(code=200, msg=f"已新增代码知识库 {cb_name}")
|
||||
|
||||
|
||||
async def delete_cb(cb_name: str = Body(..., examples=["samples"])) -> BaseResponse:
    """Delete the code base ``cb_name`` from the database and from disk.

    Returns a ``BaseResponse`` whose ``code`` is 403 when ``validate_kb_name``
    rejects the name, 404 when the name is blank, 500 when deletion raises,
    and 200 otherwise.

    NOTE(review): a name that ``cb_exists`` does not know is also answered with
    200 ("deleted"); kept as-is because callers may rely on the idempotence.
    """
    logger.info('cb_name={}'.format(cb_name))

    # Validate the requested name before touching storage.
    if not validate_kb_name(cb_name):
        return BaseResponse(code=403, msg="Don't attack me")
    if cb_name is None or cb_name.strip() == "":
        return BaseResponse(code=404, msg="知识库名称不能为空,请重新填写知识库名称")

    cb = cb_exists(cb_name)
    if cb:
        try:
            delete_cb_from_db(cb_name)

            # Remove the on-disk data. Plain concatenation is kept on purpose:
            # os.path.join would discard CB_ROOT_PATH for an absolute cb_name.
            shutil.rmtree(CB_ROOT_PATH + os.sep + cb_name)
        except Exception as e:
            # Log with traceback (as search_code does) instead of printing to stdout.
            logger.exception(e)
            return BaseResponse(code=500, msg=f"删除代码知识库出错: {e}")

    return BaseResponse(code=200, msg=f"已删除代码知识库 {cb_name}")
|
||||
|
||||
|
||||
def search_code(cb_name: str = Body(..., examples=["sofaboot"]),
                query: str = Body(..., examples=['你好']),
                code_limit: int = Body(..., examples=['1']),
                history_node_list: list = Body(...)) -> dict:
    """Load the stored code base ``cb_name`` and search it for ``query``.

    Returns a dict with the keys ``related_code`` and ``related_node`` taken
    from ``CodeBaseHandler.search_code``. Any exception is logged and an empty
    dict is returned instead.
    """
    logger.info('cb_name={}'.format(cb_name))
    logger.info('query={}'.format(query))
    logger.info('code_limit={}'.format(code_limit))
    logger.info('history_node_list={}'.format(history_node_list))

    try:
        # Load the previously saved code base, then run the search on it.
        handler = CodeBaseHandler(code_name=cb_name, cb_root_path=CB_ROOT_PATH)
        handler.import_code(do_load=True)
        related_code, related_node = handler.search_code(
            query, code_limit=code_limit, history_node_list=history_node_list)
    except Exception as e:
        logger.exception(e)
        return {}

    return {
        'related_code': related_code,
        'related_node': related_node
    }
|
||||
|
||||
|
||||
def cb_exists_api(cb_name: str = Body(..., examples=["sofaboot"])) -> bool:
    """Return what ``cb_exists(cb_name)`` returns, or ``False`` if it raises."""
    try:
        exists = cb_exists(cb_name)
    except Exception as e:
        logger.exception(e)
        return False
    return exists
|
||||
|
||||
|
||||
|
||||
@ -4,7 +4,7 @@ from typing import List
|
||||
from functools import lru_cache
|
||||
from loguru import logger
|
||||
|
||||
from langchain.vectorstores import FAISS
|
||||
# from langchain.vectorstores import FAISS
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
|
||||
@ -22,6 +22,7 @@ from dev_opsgpt.utils.path_utils import *
|
||||
from dev_opsgpt.orm.utils import DocumentFile
|
||||
from dev_opsgpt.utils.server_utils import torch_gc
|
||||
from dev_opsgpt.embeddings.utils import load_embeddings
|
||||
from dev_opsgpt.embeddings.faiss_m import FAISS
|
||||
|
||||
|
||||
# make HuggingFaceEmbeddings hashable
|
||||
@ -124,6 +125,7 @@ class FaissKBService(KBService):
|
||||
vector_store = load_vector_store(self.kb_name,
|
||||
embeddings=embeddings,
|
||||
tick=_VECTOR_STORE_TICKS.get(self.kb_name, 0))
|
||||
vector_store.embedding_function = embeddings.embed_documents
|
||||
logger.info("docs.lens: {}".format(len(docs)))
|
||||
vector_store.add_documents(docs)
|
||||
torch_gc()
|
||||
|
||||
@ -16,7 +16,6 @@ from configs.model_config import (
|
||||
)
|
||||
|
||||
|
||||
|
||||
async def list_kbs():
|
||||
# Get List of Knowledge Base
|
||||
return ListResponse(data=list_kbs_from_db())
|
||||
|
||||
@ -6,7 +6,6 @@ import os
|
||||
src_dir = os.path.join(
|
||||
os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
)
|
||||
print(src_dir)
|
||||
sys.path.append(src_dir)
|
||||
|
||||
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
|
||||
@ -20,7 +19,7 @@ model_worker_port = 20002
|
||||
openai_api_port = 8888
|
||||
base_url = "http://127.0.0.1:{}"
|
||||
|
||||
os.environ['PATH'] = os.environ.get("PATH", "") + os.pathsep + r'/d/env_utils/miniconda3/envs/devopsgpt/Lib/site-packages/torch/lib'
|
||||
os.environ['PATH'] = os.environ.get("PATH", "") + os.pathsep
|
||||
|
||||
def set_httpx_timeout(timeout=60.0):
|
||||
import httpx
|
||||
|
||||
129
dev_opsgpt/service/sdfile_api.py
Normal file
129
dev_opsgpt/service/sdfile_api.py
Normal file
@ -0,0 +1,129 @@
|
||||
import sys, os, json, traceback, uvicorn, argparse
|
||||
|
||||
src_dir = os.path.join(
|
||||
os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
)
|
||||
sys.path.append(src_dir)
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import StreamingResponse, FileResponse
|
||||
from fastapi import File, UploadFile
|
||||
|
||||
from dev_opsgpt.utils.server_utils import BaseResponse, ListResponse
|
||||
from configs.server_config import OPEN_CROSS_DOMAIN, SDFILE_API_SERVER
|
||||
from configs.model_config import (
|
||||
JUPYTER_WORK_PATH
|
||||
)
|
||||
from configs import VERSION
|
||||
|
||||
|
||||
|
||||
async def sd_upload_file(file: UploadFile = File(...), work_dir: str = JUPYTER_WORK_PATH):
    """Save an uploaded file under ``work_dir``.

    Returns ``{"data": True}`` on success and ``{"data": False}`` when the
    file cannot be written or its name resolves outside ``work_dir``.
    """
    try:
        # The file name comes from the client; refuse names such as
        # "../../x" that would land outside the work directory.
        root = os.path.abspath(work_dir)
        target = os.path.abspath(os.path.join(root, file.filename))
        if os.path.commonpath([root, target]) != root:
            raise ValueError("upload target escapes work_dir")

        content = await file.read()
        with open(target, "wb") as f:
            f.write(content)
        return {"data": True}
    except Exception:
        # `except Exception` (not a bare except) lets task cancellation and
        # interpreter shutdown propagate; the failure is logged, not hidden.
        logger.exception("failed to save uploaded file")
        return {"data": False}
|
||||
|
||||
|
||||
async def sd_download_file(filename: str, save_filename: str = "filename_to_download.ext", work_dir: str = JUPYTER_WORK_PATH):
    """Build a ``FileResponse`` for ``work_dir/filename`` and return it under the ``data`` key.

    NOTE(review): the ``FileResponse`` is nested inside a dict rather than
    returned directly, so the client may not receive the file bytes - confirm.
    NOTE(review): ``filename`` is joined to ``work_dir`` without any
    containment check.
    """
    file_path = os.path.join(work_dir, filename)
    logger.debug(f"{file_path}")
    response = FileResponse(file_path, filename=save_filename)
    return {"data": response}
|
||||
|
||||
|
||||
async def sd_list_files(work_dir: str = JUPYTER_WORK_PATH):
    """Return the ``os.listdir`` entries of ``work_dir`` under the ``data`` key.

    NOTE(review): the original comment said directories are stripped, but no
    filtering is performed here.
    """
    entries = os.listdir(work_dir)
    return {"data": entries}
|
||||
|
||||
|
||||
async def sd_delete_file(filename: str, work_dir: str = JUPYTER_WORK_PATH):
    """Delete ``filename`` from ``work_dir``.

    Returns ``{"data": True}`` on success and ``{"data": False}`` when the
    file cannot be removed or its name resolves outside ``work_dir``.
    """
    try:
        # The file name comes from the client; refuse names such as
        # "../../x" that would point outside the work directory.
        root = os.path.abspath(work_dir)
        target = os.path.abspath(os.path.join(root, filename))
        if os.path.commonpath([root, target]) != root:
            raise ValueError("delete target escapes work_dir")

        os.remove(target)
        return {"data": True}
    except Exception:
        # Narrowed from a bare except and logged so failures are diagnosable.
        logger.exception("failed to delete sandbox file")
        return {"data": False}
|
||||
|
||||
|
||||
def create_app():
    """Create the FastAPI application that serves the ``/sdfiles/*`` routes.

    CORS middleware allowing every origin, method and header is added only
    when ``OPEN_CROSS_DOMAIN`` is truthy.
    """
    app = FastAPI(
        title="DevOps-ChatBot API Server",
        version=VERSION
    )
    # MakeFastAPIOffline(app)

    # Cross-domain access is opt-in through the OPEN_CROSS_DOMAIN setting.
    if OPEN_CROSS_DOMAIN:
        app.add_middleware(
            CORSMiddleware,
            allow_origins=["*"],
            allow_credentials=True,
            allow_methods=["*"],
            allow_headers=["*"],
        )

    # (registrar, path, response model, summary, handler), in registration order.
    routes = (
        (app.post, "/sdfiles/upload", BaseResponse, "上传文件到沙盒", sd_upload_file),
        (app.get, "/sdfiles/download", BaseResponse, "从沙盒下载文件", sd_download_file),
        (app.get, "/sdfiles/list", ListResponse, "从沙盒工作目录展示文件", sd_list_files),
        (app.get, "/sdfiles/delete", BaseResponse, "从沙盒工作目录中删除文件", sd_delete_file),
    )
    for register, path, model, summary, handler in routes:
        register(path,
                 tags=["files upload and download"],
                 response_model=model,
                 summary=summary
                 )(handler)
    return app
|
||||
|
||||
|
||||
|
||||
app = create_app()
|
||||
|
||||
def run_api(host, port, **kwargs):
    """Serve ``app`` with uvicorn, enabling TLS only when both ``ssl_keyfile``
    and ``ssl_certfile`` are supplied (truthy) in ``kwargs``."""
    keyfile = kwargs.get("ssl_keyfile")
    certfile = kwargs.get("ssl_certfile")

    if not (keyfile and certfile):
        uvicorn.run(app, host=host, port=port)
        return

    uvicorn.run(app,
                host=host,
                port=port,
                ssl_keyfile=keyfile,
                ssl_certfile=certfile,
                )
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Command-line entry point: parse host/port/TLS options and start the server.
    parser = argparse.ArgumentParser(prog='DevOps-ChatBot',
                                     description='About DevOps-ChatBot, local knowledge based LLM with langchain'
                                                 ' | 基于本地知识库的 LLM 问答')
    parser.add_argument("--host", type=str, default="0.0.0.0")
    parser.add_argument("--port", type=int, default=SDFILE_API_SERVER["port"])
    parser.add_argument("--ssl_keyfile", type=str)
    parser.add_argument("--ssl_certfile", type=str)

    args = parser.parse_args()
    run_api(host=args.host,
            port=args.port,
            ssl_keyfile=args.ssl_keyfile,
            ssl_certfile=args.ssl_certfile,
            )
|
||||
@ -77,6 +77,33 @@ def get_kb_details() -> List[Dict]:
|
||||
|
||||
return data
|
||||
|
||||
def get_cb_details() -> List[Dict]:
    """Collect the detail record of every code base listed by ``list_cbs_from_db()``.

    Each record (as returned by ``get_cb_detail``) gets a 1-based ``'No'``
    entry added in place, following the listing order.

    @return: list of detail records
    """
    details_by_name = {name: get_cb_detail(name) for name in list_cbs_from_db()}

    rows = []
    for number, detail in enumerate(details_by_name.values(), start=1):
        detail['No'] = number
        rows.append(detail)
    return rows
|
||||
|
||||
def get_cb_details_by_cb_name(cb_name) -> dict:
    """Return the detail record that ``get_cb_detail`` yields for ``cb_name``.

    @return: the detail record of the single code base
    """
    return get_cb_detail(cb_name)
|
||||
|
||||
|
||||
|
||||
|
||||
def get_kb_doc_details(kb_name: str) -> List[Dict]:
|
||||
kb = KBServiceFactory.get_service_by_name(kb_name)
|
||||
|
||||
@ -32,10 +32,12 @@ class LCTextSplitter:
|
||||
loader = self._load_document()
|
||||
text_splitter = self._load_text_splitter()
|
||||
if self.document_loader_name in ["JSONLoader", "JSONLLoader"]:
|
||||
docs = loader.load()
|
||||
# docs = loader.load()
|
||||
docs = loader.load_and_split(text_splitter)
|
||||
logger.debug(f"please check your file can be loaded, docs.lens {len(docs)}")
|
||||
else:
|
||||
docs = loader.load_and_split(text_splitter)
|
||||
logger.info(docs[0])
|
||||
|
||||
return docs
|
||||
|
||||
def _load_document(self, ) -> BaseLoader:
|
||||
@ -55,8 +57,8 @@ class LCTextSplitter:
|
||||
chunk_overlap=OVERLAP_SIZE,
|
||||
)
|
||||
self.text_splitter_name = "SpacyTextSplitter"
|
||||
elif self.document_loader_name in ["JSONLoader", "JSONLLoader"]:
|
||||
text_splitter = None
|
||||
# elif self.document_loader_name in ["JSONLoader", "JSONLLoader"]:
|
||||
# text_splitter = None
|
||||
else:
|
||||
text_splitter_module = importlib.import_module('langchain.text_splitter')
|
||||
TextSplitter = getattr(text_splitter_module, self.text_splitter_name)
|
||||
|
||||
33
dev_opsgpt/tools/__init__.py
Normal file
33
dev_opsgpt/tools/__init__.py
Normal file
@ -0,0 +1,33 @@
|
||||
from .base_tool import toLangchainTools, get_tool_schema, BaseToolModel
|
||||
from .weather import WeatherInfo, DistrictInfo
|
||||
from .multiplier import Multiplier
|
||||
from .world_time import WorldTimeGetTimezoneByArea
|
||||
from .abnormal_detection import KSigmaDetector
|
||||
from .metrics_query import MetricsQuery
|
||||
from .duckduckgo_search import DDGSTool
|
||||
from .docs_retrieval import DocRetrieval
|
||||
from .cb_query_tool import CodeRetrieval
|
||||
|
||||
# Names of the tool classes bundled with this package.
TOOL_SETS = [
    "WeatherInfo", "WorldTimeGetTimezoneByArea", "Multiplier", "DistrictInfo", "KSigmaDetector", "MetricsQuery", "DDGSTool",
    "DocRetrieval", "CodeRetrieval"
]

# Lookup from tool class name to the tool class itself.
TOOL_DICT = {
    "WeatherInfo": WeatherInfo,
    "WorldTimeGetTimezoneByArea": WorldTimeGetTimezoneByArea,
    "Multiplier": Multiplier,
    "DistrictInfo": DistrictInfo,
    "KSigmaDetector": KSigmaDetector,
    "MetricsQuery": MetricsQuery,
    "DDGSTool": DDGSTool,
    "DocRetrieval": DocRetrieval,
    "CodeRetrieval": CodeRetrieval
}

# Every entry must name an attribute of this module; the previous "tool_sets"
# entry did not exist and made `from dev_opsgpt.tools import *` fail.
__all__ = [
    "WeatherInfo", "WorldTimeGetTimezoneByArea", "Multiplier", "DistrictInfo", "KSigmaDetector", "MetricsQuery", "DDGSTool",
    "DocRetrieval", "CodeRetrieval",
    "toLangchainTools", "get_tool_schema", "TOOL_SETS", "TOOL_DICT", "BaseToolModel"
]
|
||||
|
||||
45
dev_opsgpt/tools/abnormal_detection.py
Normal file
45
dev_opsgpt/tools/abnormal_detection.py
Normal file
@ -0,0 +1,45 @@
|
||||
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import List, Dict
|
||||
import requests
|
||||
import numpy as np
|
||||
from loguru import logger
|
||||
|
||||
from .base_tool import BaseToolModel
|
||||
|
||||
|
||||
|
||||
class KSigmaDetector(BaseToolModel):
    """Tool that flags a series as abnormal with a k-sigma rule.

    In ``ToolInputArgs`` a field with a default is optional, a field declared
    with ``...`` is required.
    """

    name: str = "KSigmaDetector"
    description: str = "Anomaly detection using K-Sigma method"

    class ToolInputArgs(BaseModel):
        """Input for KSigmaDetector."""

        data: List[float] = Field(..., description="List of data points")
        detect_window: int = Field(default=5, description="The size of the detect window for detecting anomalies")
        abnormal_window: int = Field(default=3, description="The threshold for the number of abnormal points required to classify the data as abnormal")
        k: float = Field(default=3.0, description="the coef of k-sigma")

    class ToolOutputArgs(BaseModel):
        """Output for KSigmaDetector."""

        is_abnormal: bool = Field(..., description="Indicates whether the input data is abnormal or not")

    @staticmethod
    def run(data, detect_window=5, abnormal_window=3, k=3.0):
        """Return ``{"is_abnormal": bool}`` for ``data``.

        The mean and standard deviation are taken from the last
        ``detect_window`` points; every earlier point further than ``k``
        standard deviations from that mean counts as abnormal, and the series
        is abnormal when at least ``abnormal_window`` such points exist.

        NOTE(review): the statistics come from the trailing window while the
        points being tested are the older ones, which looks inverted relative
        to the ``detect_window`` description - confirm the intended roles.
        """
        refer_data = np.array(data[-detect_window:])
        detect_data = np.array(data[:-detect_window])
        mean = np.mean(refer_data)
        std = np.std(refer_data)

        is_abnormal = np.sum(np.abs(detect_data - mean) > k * std) >= abnormal_window
        # Convert numpy.bool_ to a plain bool so the result is JSON-serializable
        # and matches the declared ToolOutputArgs type.
        return {"is_abnormal": bool(is_abnormal)}
|
||||
79
dev_opsgpt/tools/base_tool.py
Normal file
79
dev_opsgpt/tools/base_tool.py
Normal file
@ -0,0 +1,79 @@
|
||||
from langchain.agents import Tool
|
||||
from langchain.tools import StructuredTool
|
||||
from langchain.tools.base import ToolException
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import List, Dict
|
||||
# import jsonref
|
||||
import json
|
||||
|
||||
|
||||
class BaseToolModel:
    """Template for tool classes.

    ``toLangchainTools`` reads ``name``, ``description``, ``ToolInputArgs``
    and ``run`` from each tool, so subclasses are expected to override them.
    """
    # Identifier and human-readable description handed to the LangChain tool.
    name = "BaseToolModel"
    description = "Tool Description"

    # Placeholder argument schema: key1 is optional (has a default),
    # key2 is required (declared with `...`).
    class ToolInputArgs(BaseModel):
        """
        Input for MoveFileTool.
        Tips:
        default control Required, e.g. key1 is not Required/key2 is Required
        """

        key1: str = Field(default=None, description="hello world!")
        key2: str = Field(..., description="hello world!!")

    # Placeholder result schema with the same optional/required convention.
    class ToolOutputArgs(BaseModel):
        """
        Input for MoveFileTool.
        Tips:
        default control Required, e.g. key1 is not Required/key2 is Required
        """

        key1: str = Field(default=None, description="hello world!")
        key2: str = Field(..., description="hello world!!")

    @classmethod
    def run(cls, tool_input_args: ToolInputArgs) -> ToolOutputArgs:
        """excute your tool!"""
        # No-op in the base class; concrete tools supply the implementation.
        pass
|
||||
|
||||
|
||||
class BaseTools:
    """Type-only container: declares ``tools`` as a list of tool classes."""
    tools: List[BaseToolModel]
|
||||
|
||||
|
||||
def get_tool_schema(tool: "BaseToolModel") -> Dict:
    """Return the JSON schema of ``tool`` with every ``$ref`` inlined.

    ``tool`` must provide ``schema_json()`` returning a JSON string. Each
    ``{"$ref": "#/definitions/<Name>"}`` node is replaced by the definition it
    points to, and the top-level ``"definitions"`` section is then dropped.

    The references are resolved with the standard library because the
    ``jsonref`` import this function used to rely on is commented out in this
    module, which made every call fail with ``NameError``.

    Raises:
        KeyError: a ``$ref`` names a definition that is not present.
        RecursionError: the schema contains a self-referencing definition.
    """
    raw_schema = json.loads(tool.schema_json())
    definitions = raw_schema.get("definitions", {})
    ref_prefix = "#/definitions/"

    def _inline_refs(node):
        # Replace reference nodes by their (recursively resolved) target.
        if isinstance(node, dict):
            ref = node.get("$ref")
            if isinstance(ref, str) and ref.startswith(ref_prefix):
                return _inline_refs(definitions[ref[len(ref_prefix):]])
            return {key: _inline_refs(value) for key, value in node.items()}
        if isinstance(node, list):
            return [_inline_refs(item) for item in node]
        return node

    data = _inline_refs(raw_schema)
    # Schemas without nested models carry no "definitions" key at all.
    data.pop("definitions", None)
    return data
|
||||
|
||||
|
||||
def _handle_error(error: ToolException) -> str:
    """Format a tool failure as text, embedding the exception's first argument."""
    detail = error.args[0]
    prefix = "The following errors occurred during tool execution:"
    suffix = "Please try again."
    return prefix + detail + suffix
|
||||
|
||||
import requests
|
||||
from loguru import logger
|
||||
def fff(city, extensions):
    """Query the AMap weather endpoint for ``city`` and return the decoded JSON body.

    NOTE(review): the API key is hard-coded in source; consider moving it to
    configuration.
    """
    url = "https://restapi.amap.com/v3/weather/weatherInfo"
    query_params = {"key": "4ceb2ef6257a627b72e3be6beab5b059", "city": city, "extensions": extensions}
    logger.debug(f"json_data: {query_params}")
    response = requests.get(url, params=query_params)
    return response.json()
|
||||
|
||||
|
||||
def toLangchainTools(tools: BaseTools) -> List:
    """Wrap every tool in ``tools`` in a LangChain ``StructuredTool``.

    Each wrapper takes its name, callable, description and argument schema
    from the tool and uses ``_handle_error`` for tool errors.
    """
    structured_tools = []
    for tool in tools:
        wrapped = StructuredTool(
            name=tool.name,
            func=tool.run,
            description=tool.description,
            args_schema=tool.ToolInputArgs,
            handle_tool_error=_handle_error,
        )
        structured_tools.append(wrapped)
    return structured_tools
|
||||
47
dev_opsgpt/tools/cb_query_tool.py
Normal file
47
dev_opsgpt/tools/cb_query_tool.py
Normal file
@ -0,0 +1,47 @@
|
||||
# encoding: utf-8
|
||||
'''
|
||||
@author: 温进
|
||||
@file: cb_query_tool.py
|
||||
@time: 2023/11/2 下午4:41
|
||||
@desc:
|
||||
'''
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import List, Dict
|
||||
import requests
|
||||
import numpy as np
|
||||
from loguru import logger
|
||||
|
||||
from configs.model_config import (
|
||||
CODE_SEARCH_TOP_K)
|
||||
from .base_tool import BaseToolModel
|
||||
|
||||
from dev_opsgpt.service.cb_api import search_code
|
||||
|
||||
|
||||
class CodeRetrieval(BaseToolModel):
    """Tool that looks up related code in a local code base via ``search_code``."""

    name = "CodeRetrieval"
    description = "采用知识图谱从本地代码知识库获取相关代码"

    class ToolInputArgs(BaseModel):
        query: str = Field(..., description="检索的关键字或问题")
        code_base_name: str = Field(..., description="知识库名称", examples=["samples"])
        code_limit: int = Field(CODE_SEARCH_TOP_K, description="检索返回的数量")

    class ToolOutputArgs(BaseModel):
        """Output for MetricsQuery."""
        code: str = Field(..., description="检索代码")

    @classmethod
    def run(cls, code_base_name, query, code_limit=CODE_SEARCH_TOP_K, history_node_list=None):
        """Search ``code_base_name`` for ``query``.

        Returns one dict per related code snippet with the keys ``index``,
        ``code`` and ``related_nodes`` (the full node list is attached to
        every entry). An empty list is returned when the search yields nothing.
        """
        # Fresh list per call; a `[]` default would be shared between calls.
        if history_node_list is None:
            history_node_list = []

        codes = search_code(code_base_name, query, code_limit, history_node_list=history_node_list)

        # search_code returns {} when it fails, so the keys may be missing.
        related_code = codes.get('related_code', [])
        related_nodes = codes.get('related_node', [])

        return [
            {'index': idx, 'code': code, "related_nodes": related_nodes}
            for idx, code in enumerate(related_code)
        ]
|
||||
42
dev_opsgpt/tools/docs_retrieval.py
Normal file
42
dev_opsgpt/tools/docs_retrieval.py
Normal file
@ -0,0 +1,42 @@
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import List, Dict
|
||||
import requests
|
||||
import numpy as np
|
||||
from loguru import logger
|
||||
|
||||
from configs.model_config import (
|
||||
VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD)
|
||||
from .base_tool import BaseToolModel
|
||||
|
||||
|
||||
|
||||
from dev_opsgpt.service.kb_api import search_docs
|
||||
|
||||
|
||||
class DocRetrieval(BaseToolModel):
    """Tool that retrieves documents from a local knowledge base via ``search_docs``."""

    name = "DocRetrieval"
    description = "采用向量化对本地知识库进行检索"

    class ToolInputArgs(BaseModel):
        query: str = Field(..., description="检索的关键字或问题")
        knowledge_base_name: str = Field(..., description="知识库名称", examples=["samples"])
        search_top: int = Field(VECTOR_SEARCH_TOP_K, description="检索返回的数量")
        score_threshold: float = Field(SCORE_THRESHOLD, description="知识库匹配相关度阈值,取值范围在0-1之间,SCORE越小,相关度越高,取到1相当于不筛选,建议设置在0.5左右", ge=0, le=1)

    class ToolOutputArgs(BaseModel):
        """Output for MetricsQuery."""
        title: str = Field(..., description="检索网页标题")
        snippet: str = Field(..., description="检索内容的判断")
        link: str = Field(..., description="检索网页地址")

    @classmethod
    def run(cls, query, knowledge_base_name, search_top=VECTOR_SEARCH_TOP_K, score_threshold=SCORE_THRESHOLD):
        """Return one dict per document found by ``search_docs``.

        Each dict carries ``index``, ``snippet`` (the page content) and
        ``title`` / ``link``, both read from the document's ``source`` metadata.
        """
        found = search_docs(query, knowledge_base_name, search_top, score_threshold)
        return [
            {
                "index": position,
                "snippet": document.page_content,
                "title": document.metadata.get("source"),
                "link": document.metadata.get("source"),
            }
            for position, document in enumerate(found)
        ]
|
||||
72
dev_opsgpt/tools/duckduckgo_search.py
Normal file
72
dev_opsgpt/tools/duckduckgo_search.py
Normal file
@ -0,0 +1,72 @@
|
||||
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import List, Dict
|
||||
import requests
|
||||
import numpy as np
|
||||
from loguru import logger
|
||||
|
||||
from .base_tool import BaseToolModel
|
||||
from configs.model_config import (
|
||||
PROMPT_TEMPLATE, SEARCH_ENGINE_TOP_K, BING_SUBSCRIPTION_KEY, BING_SEARCH_URL,
|
||||
VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD)
|
||||
|
||||
from duckduckgo_search import DDGS
|
||||
|
||||
|
||||
class DDGSTool(BaseToolModel):
    """Tool that runs a DuckDuckGo text search through ``DDGS``."""

    name = "DDGSTool"
    description = "通过duckduckgo进行资料搜索"

    class ToolInputArgs(BaseModel):
        query: str = Field(..., description="检索的关键字或问题")
        search_top: int = Field(..., description="检索返回的数量")
        region: str = Field("wt-wt", enum=["wt-wt", "us-en", "uk-en", "ru-ru"], description="搜索的区域")
        safesearch: str = Field("moderate", enum=["on", "moderate", "off"], description="")
        timelimit: str = Field(None, enum=[None, "d", "w", "m", "y"], description="查询时间方式")
        backend: str = Field("api", description="搜索的资料来源")

    class ToolOutputArgs(BaseModel):
        """Output for MetricsQuery."""
        title: str = Field(..., description="检索网页标题")
        snippet: str = Field(..., description="检索内容的判断")
        link: str = Field(..., description="检索网页地址")

    @classmethod
    def run(cls, query, search_top, region="wt-wt", safesearch="moderate", timelimit=None, backend="api"):
        """Search DuckDuckGo and return formatted hits.

        Collection stops once the number of formatted hits equals
        ``search_top``. A placeholder entry is returned when the search
        yields ``None``. The proxy is read from the ``DUCKDUCKGO_PROXY``
        environment variable.
        """

        def to_metadata(result: Dict) -> Dict[str, str]:
            # News hits carry extra fields and expose their address as "url".
            if backend != "news":
                return {
                    "snippet": result["body"],
                    "title": result["title"],
                    "link": result["href"],
                }
            return {
                "date": result["date"],
                "title": result["title"],
                "snippet": result["body"],
                "source": result["source"],
                "link": result["url"],
            }

        with DDGS(proxies=os.environ.get("DUCKDUCKGO_PROXY")) as ddgs:
            results = ddgs.text(
                query,
                region=region,
                safesearch=safesearch,
                timelimit=timelimit,
                backend=backend,
            )
            if results is None:
                return [{"Result": "No good DuckDuckGo Search Result was found"}]

            # Consume the results while the DDGS session is still open.
            hits = []
            for entry in results:
                if entry is not None:
                    hits.append(to_metadata(entry))
                if len(hits) == search_top:
                    break
        return hits
|
||||
33
dev_opsgpt/tools/metrics_query.py
Normal file
33
dev_opsgpt/tools/metrics_query.py
Normal file
@ -0,0 +1,33 @@
|
||||
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import List, Dict
|
||||
import requests
|
||||
import numpy as np
|
||||
from loguru import logger
|
||||
|
||||
from .base_tool import BaseToolModel
|
||||
|
||||
|
||||
|
||||
class MetricsQuery(BaseToolModel):
    """Tool that returns monitoring data for a machine (currently a fixed sample series)."""

    name = "MetricsQuery"
    description = "查询机器的监控数据"

    class ToolInputArgs(BaseModel):
        machine_ip: str = Field(..., description="machine_ip")
        time: int = Field(..., description="time period")

    class ToolOutputArgs(BaseModel):
        """Output for MetricsQuery."""

        datas: List[float] = Field(..., description="监控时序数组")

    # Declared static like the other tools' `run`, so it also works when
    # called on an instance rather than only on the class.
    @staticmethod
    def run(machine_ip, time):
        """Return the first 30 points of a hard-coded sample series.

        ``machine_ip`` and ``time`` are accepted but not used.
        """
        data = [0.857, 2.345, 1.234, 4.567, 3.456, 9.876, 5.678, 7.890, 6.789, 8.901, 10.987, 12.345, 11.234, 14.567, 13.456, 19.876, 15.678, 17.890,
                16.789, 18.901, 20.987, 22.345, 21.234, 24.567, 23.456, 29.876, 25.678, 27.890, 26.789, 28.901, 30.987, 32.345, 31.234, 34.567,
                33.456, 39.876, 35.678, 37.890, 36.789, 38.901, 40.987]
        return data[:30]
|
||||
38
dev_opsgpt/tools/multiplier.py
Normal file
38
dev_opsgpt/tools/multiplier.py
Normal file
@ -0,0 +1,38 @@
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import List, Dict
|
||||
import requests
|
||||
from loguru import logger
|
||||
|
||||
from .base_tool import BaseToolModel
|
||||
|
||||
|
||||
|
||||
class Multiplier(BaseToolModel):
    """
    Tool that multiplies two numbers.

    Tips:
    default control Required, e.g. key1 is not Required/key2 is Required
    """

    name: str = "Multiplier"
    # NOTE(review): the description speaks of a comma separated input, while
    # ToolInputArgs and run() take two separate arguments `a` and `b`.
    description: str = """useful for when you need to multiply two numbers together. \
The input to this tool should be a comma separated list of numbers of length two, representing the two numbers you want to multiply together. \
For example, `1,2` would be the input if you wanted to multiply 1 by 2."""

    class ToolInputArgs(BaseModel):
        """Input for Multiplier."""

        # key: str = Field(..., description="用户在高德地图官网申请web服务API类型KEY")
        a: int = Field(..., description="num a")
        b: int = Field(..., description="num b")

    class ToolOutputArgs(BaseModel):
        """Output for Multiplier."""

        res: int = Field(..., description="the result of two nums")

    @staticmethod
    def run(a, b):
        # Returns the product of the two arguments.
        return a * b
|
||||
|
||||
def multi_run(a, b):
    """Return ``a * b``."""
    product = a * b
    return product
|
||||
0
dev_opsgpt/tools/sandbox.py
Normal file
0
dev_opsgpt/tools/sandbox.py
Normal file
109
dev_opsgpt/tools/weather.py
Normal file
109
dev_opsgpt/tools/weather.py
Normal file
@ -0,0 +1,109 @@
|
||||
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import List, Dict
|
||||
import requests
|
||||
from loguru import logger
|
||||
|
||||
from .base_tool import BaseToolModel
|
||||
|
||||
|
||||
|
||||
class WeatherInfo(BaseToolModel):
    """Tool that queries the AMap weather endpoint for a city adcode.

    In ``ToolInputArgs`` a field with a default is optional, a field declared
    with ``...`` is required.
    """

    name: str = "WeatherInfo"
    description: str = "According to the user's input adcode, it can query the current/future weather conditions of the target area."

    class ToolInputArgs(BaseModel):
        """Input for Weather."""

        city: str = Field(..., description="城市编码,输入城市的adcode,adcode信息可参考城市编码表")
        extensions: str = Field(default=None, enum=["base", "all"], description="气象类型,输入城市的adcode,adcode信息可参考城市编码表")

    class ToolOutputArgs(BaseModel):
        """Output for Weather."""

        lives: str = Field(default=None, description="实况天气数据")

    @staticmethod
    def run(city, extensions=None) -> ToolOutputArgs:
        """Request the weather for ``city`` and return the decoded JSON body.

        ``extensions`` defaults to ``None`` to match ``ToolInputArgs``, where
        the field is optional. If anything raises, the exception object itself
        is returned instead of being raised.

        NOTE(review): the API key is hard-coded in source; consider moving it
        to configuration.
        """
        url = "https://restapi.amap.com/v3/weather/weatherInfo"
        try:
            json_data = {}
            json_data["city"] = city
            json_data["key"] = "4ceb2ef6257a627b72e3be6beab5b059"
            json_data["extensions"] = extensions
            logger.debug(f"json_data: {json_data}")
            res = requests.get(url, params=json_data)
            return res.json()
        except Exception as e:
            return e
|
||||
|
||||
|
||||
class DistrictInfo(BaseToolModel):
    """Tool that queries the AMap administrative-district endpoint.

    In ``ToolInputArgs`` a field with a default is optional, a field declared
    with ``...`` is required.
    """

    name: str = "DistrictInfo"
    description: str = "用户希望通过得到行政区域信息,进行开发工作。"

    class ToolInputArgs(BaseModel):
        """Input for district."""
        keywords: str = Field(default=None, description="规则:只支持单个关键词语搜索关键词支持:行政区名称、citycode、adcode例如,在subdistrict=2,搜索省份(例如山东),能够显示市(例如济南),区(例如历下区)")
        subdistrict: str = Field(default=None, enums=[1,2,3], description="""规则:设置显示下级行政区级数(行政区级别包括:国家、省/直辖市、市、区/县、乡镇/街道多级数据)

可选值:0、1、2、3等数字,并以此类推

0:不返回下级行政区;

1:返回下一级行政区;

2:返回下两级行政区;

3:返回下三级行政区;""")
        page: int = Field(default=1, examples=["page=2", "page=3"], description="最外层的districts最多会返回20个数据,若超过限制,请用page请求下一页数据。")
        extensions: str = Field(default=None, enum=["base", "all"], description="气象类型,输入城市的adcode,adcode信息可参考城市编码表")

    class ToolOutputArgs(BaseModel):
        """Output for district."""

        districts: str = Field(default=None, description="行政区列表")

    @staticmethod
    def run(keywords=None, subdistrict=None, page=1, extensions=None) -> ToolOutputArgs:
        """Request district information and return the decoded JSON body.

        If anything raises, the exception object itself is returned instead
        of being raised.

        NOTE(review): the API key is hard-coded in source; consider moving it
        to configuration.
        """
        url = "https://restapi.amap.com/v3/config/district"
        try:
            # Key order is kept as before so the debug log line is unchanged.
            query_params = {
                "keywords": keywords,
                "key": "4ceb2ef6257a627b72e3be6beab5b059",
                "subdistrict": subdistrict,
                "page": page,
                "extensions": extensions,
            }
            logger.debug(f"json_data: {query_params}")
            response = requests.get(url, params=query_params)
            return response.json()
        except Exception as e:
            return e
|
||||
255
dev_opsgpt/tools/world_time.py
Normal file
255
dev_opsgpt/tools/world_time.py
Normal file
@ -0,0 +1,255 @@
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import requests
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import List
|
||||
|
||||
from .base_tool import BaseToolModel
|
||||
|
||||
|
||||
class WorldTimeGetTimezoneByArea(BaseToolModel):
    """
    World Time API
    Tips:
        default control Required, e.g.  key1 is not Required/key2 is Required
    """

    name = "WorldTime.getTimezoneByArea"
    description = "a listing of all timezones available for that area."

    class ToolInputArgs(BaseModel):
        """Input for WorldTimeGetTimezoneByArea."""
        area: str = Field(..., description="area")

    class ToolOutputArgs(BaseModel):
        """Output for WorldTimeGetTimezoneByArea."""
        DateTimeJsonResponse: str = Field(..., description="a list of available timezones")

    # `run` takes no `cls` parameter, so it must be a staticmethod; decorated
    # as a classmethod the class itself would be bound to `area` and every
    # call with an argument would raise TypeError.
    @staticmethod
    def run(area: str) -> ToolOutputArgs:
        """Execute the tool: list the timezones available for ``area``.

        Args:
            area: area name used as the path segment, i.e. ``/timezone/{area}``.

        Returns:
            The raw response text on success. On any failure the caught
            exception object is returned (not raised).
        """
        url = "http://worldtimeapi.org/api/timezone"
        try:
            # The area is a path parameter of the endpoint; a JSON body on a
            # GET request does not select the area.
            res = requests.get(f"{url}/{area}")
            return res.text
        except Exception as e:
            return e
|
||||
|
||||
|
||||
def worldtime_run(area):
    """Return the raw World Time API listing of timezones for ``area``.

    Args:
        area: area name placed in the request path (``/timezone/{area}``).

    Returns:
        The response body as text. Network errors propagate to the caller.
    """
    url = "http://worldtimeapi.org/api/timezone"
    # The area is a path parameter; sending it as a JSON body on a GET
    # request does not select the area.
    res = requests.get(f"{url}/{area}")
    return res.text
|
||||
|
||||
# class WorldTime(BaseTool):
|
||||
# api_spec: str = '''
|
||||
# description: >-
|
||||
# A simple API to get the current time based on
|
||||
# a request with a timezone.
|
||||
|
||||
# servers:
|
||||
# - url: http://worldtimeapi.org/api/
|
||||
|
||||
# paths:
|
||||
# /timezone:
|
||||
# get:
|
||||
# description: a listing of all timezones.
|
||||
# operationId: getTimezone
|
||||
# responses:
|
||||
# default:
|
||||
# $ref: "#/components/responses/SuccessfulListJsonResponse"
|
||||
|
||||
# /timezone/{area}:
|
||||
# get:
|
||||
# description: a listing of all timezones available for that area.
|
||||
# operationId: getTimezoneByArea
|
||||
# parameters:
|
||||
# - name: area
|
||||
# in: path
|
||||
# required: true
|
||||
# schema:
|
||||
# type: string
|
||||
# responses:
|
||||
# '200':
|
||||
# $ref: "#/components/responses/SuccessfulListJsonResponse"
|
||||
# default:
|
||||
# $ref: "#/components/responses/ErrorJsonResponse"
|
||||
|
||||
# /timezone/{area}/{location}:
|
||||
# get:
|
||||
# description: request the current time for a timezone.
|
||||
# operationId: getTimeByTimezone
|
||||
# parameters:
|
||||
# - name: area
|
||||
# in: path
|
||||
# required: true
|
||||
# schema:
|
||||
# type: string
|
||||
# - name: location
|
||||
# in: path
|
||||
# required: true
|
||||
# schema:
|
||||
# type: string
|
||||
# responses:
|
||||
# '200':
|
||||
# $ref: "#/components/responses/SuccessfulDateTimeJsonResponse"
|
||||
# default:
|
||||
# $ref: "#/components/responses/ErrorJsonResponse"
|
||||
|
||||
# /ip:
|
||||
# get:
|
||||
# description: >-
|
||||
# request the current time based on the ip of the request.
|
||||
# note: this is a "best guess" obtained from open-source data.
|
||||
# operationId: getTimeByIP
|
||||
# responses:
|
||||
# '200':
|
||||
# $ref: "#/components/responses/SuccessfulDateTimeJsonResponse"
|
||||
# default:
|
||||
# $ref: "#/components/responses/ErrorJsonResponse"
|
||||
|
||||
# components:
|
||||
# responses:
|
||||
# SuccessfulListJsonResponse:
|
||||
# description: >-
|
||||
# the list of available timezones in JSON format
|
||||
# content:
|
||||
# application/json:
|
||||
# schema:
|
||||
# $ref: "#/components/schemas/ListJsonResponse"
|
||||
|
||||
# SuccessfulDateTimeJsonResponse:
|
||||
# description: >-
|
||||
# the current time for the timezone requested in JSON format
|
||||
# content:
|
||||
# application/json:
|
||||
# schema:
|
||||
# $ref: "#/components/schemas/DateTimeJsonResponse"
|
||||
|
||||
# ErrorJsonResponse:
|
||||
# description: >-
|
||||
# an error response in JSON format
|
||||
# content:
|
||||
# application/json:
|
||||
# schema:
|
||||
# $ref: "#/components/schemas/ErrorJsonResponse"
|
||||
|
||||
# schemas:
|
||||
# ListJsonResponse:
|
||||
# type: array
|
||||
# description: >-
|
||||
# a list of available timezones
|
||||
# items:
|
||||
# type: string
|
||||
|
||||
# DateTimeJsonResponse:
|
||||
# required:
|
||||
# - abbreviation
|
||||
# - client_ip
|
||||
# - datetime
|
||||
# - day_of_week
|
||||
# - day_of_year
|
||||
# - dst
|
||||
# - dst_offset
|
||||
# - timezone
|
||||
# - unixtime
|
||||
# - utc_datetime
|
||||
# - utc_offset
|
||||
# - week_number
|
||||
# properties:
|
||||
# abbreviation:
|
||||
# type: string
|
||||
# description: >-
|
||||
# the abbreviated name of the timezone
|
||||
# client_ip:
|
||||
# type: string
|
||||
# description: >-
|
||||
# the IP of the client making the request
|
||||
# datetime:
|
||||
# type: string
|
||||
# description: >-
|
||||
# an ISO8601-valid string representing
|
||||
# the current, local date/time
|
||||
# day_of_week:
|
||||
# type: integer
|
||||
# description: >-
|
||||
# current day number of the week, where sunday is 0
|
||||
# day_of_year:
|
||||
# type: integer
|
||||
# description: >-
|
||||
# ordinal date of the current year
|
||||
# dst:
|
||||
# type: boolean
|
||||
# description: >-
|
||||
# flag indicating whether the local
|
||||
# time is in daylight savings
|
||||
# dst_from:
|
||||
# type: string
|
||||
# description: >-
|
||||
# an ISO8601-valid string representing
|
||||
# the datetime when daylight savings
|
||||
# started for this timezone
|
||||
# dst_offset:
|
||||
# type: integer
|
||||
# description: >-
|
||||
# the difference in seconds between the current local
|
||||
# time and daylight saving time for the location
|
||||
# dst_until:
|
||||
# type: string
|
||||
# description: >-
|
||||
# an ISO8601-valid string representing
|
||||
# the datetime when daylight savings
|
||||
# will end for this timezone
|
||||
# raw_offset:
|
||||
# type: integer
|
||||
# description: >-
|
||||
# the difference in seconds between the current local time
|
||||
# and the time in UTC, excluding any daylight saving difference
|
||||
# (see dst_offset)
|
||||
# timezone:
|
||||
# type: string
|
||||
# description: >-
|
||||
# timezone in `Area/Location` or
|
||||
# `Area/Location/Region` format
|
||||
# unixtime:
|
||||
# type: integer
|
||||
# description: >-
|
||||
# number of seconds since the Epoch
|
||||
# utc_datetime:
|
||||
# type: string
|
||||
# description: >-
|
||||
# an ISO8601-valid string representing
|
||||
# the current date/time in UTC
|
||||
# utc_offset:
|
||||
# type: string
|
||||
# description: >-
|
||||
# an ISO8601-valid string representing
|
||||
# the offset from UTC
|
||||
# week_number:
|
||||
# type: integer
|
||||
# description: >-
|
||||
# the current week number
|
||||
|
||||
# ErrorJsonResponse:
|
||||
# required:
|
||||
# - error
|
||||
# properties:
|
||||
# error:
|
||||
# type: string
|
||||
# description: >-
|
||||
# details about the error encountered
|
||||
# '''
|
||||
|
||||
# def exec_tool(self, message: UserMessage) -> UserMessage:
|
||||
# match = re.search(r'{[\s\S]*}', message.content)
|
||||
# if match:
|
||||
# params = json.loads(match.group())
|
||||
# url = params["url"]
|
||||
# if "params" in params:
|
||||
# url = url.format(**params["params"])
|
||||
# res = requests.get(url)
|
||||
# response_msg = UserMessage(content=f"API response: {res.text}")
|
||||
# else:
|
||||
# raise "ERROR"
|
||||
# return response_msg
|
||||
@ -2,6 +2,11 @@ import textwrap, time, copy, random, hashlib, json, os
|
||||
from datetime import datetime, timedelta
|
||||
from functools import wraps
|
||||
from loguru import logger
|
||||
from typing import *
|
||||
from pathlib import Path
|
||||
from io import BytesIO
|
||||
from fastapi import Body, File, Form, Body, Query, UploadFile
|
||||
from tempfile import SpooledTemporaryFile
|
||||
|
||||
|
||||
|
||||
@ -65,3 +70,23 @@ def save_to_json_file(data, filename):
|
||||
|
||||
with open(filename, "w", encoding="utf-8") as f:
|
||||
json.dump(data, f, indent=2, ensure_ascii=False)
|
||||
|
||||
|
||||
def file_normalize(file: Union[str, Path, bytes], filename=None):
    """Normalize raw bytes, a file-like object or a local path into a readable file.

    Args:
        file: raw ``bytes``, an object exposing ``read``, or a local path.
        filename: optional explicit name; when omitted it falls back to the
            ``name`` of the file object (if it has one).

    Returns:
        A ``(file, filename)`` tuple where ``file`` is a readable binary
        file object. ``filename`` may be ``None`` when neither the caller nor
        the file object provides one.
    """
    logger.debug(f"{file}")
    if isinstance(file, bytes):  # raw bytes
        file = BytesIO(file)
    elif hasattr(file, "read"):  # a file io like object
        # In-memory streams such as BytesIO have no `name`; fall back to None
        # instead of raising AttributeError.
        filename = filename or getattr(file, "name", None)
    else:  # a local path
        file = Path(file).absolute().open("rb")
        logger.debug(file)
        filename = filename or file.name
    return file, filename
|
||||
|
||||
|
||||
def get_uploadfile(file: Union[str, Path, bytes], filename=None) -> UploadFile:
    """Wrap ``file`` in a FastAPI ``UploadFile`` backed by a spooled temp file.

    Args:
        file: raw ``bytes``, an object exposing ``read``, or a local path.
        filename: name passed through unchanged to the ``UploadFile``.

    Returns:
        An ``UploadFile`` whose underlying file is positioned at offset 0.
    """
    # Resolve the payload for every input kind the signature advertises;
    # only file-like objects have a `read` method.
    if isinstance(file, bytes):
        payload = file
    elif hasattr(file, "read"):
        payload = file.read()
    else:
        payload = Path(file).read_bytes()

    temp_file = SpooledTemporaryFile(max_size=10 * 1024 * 1024)
    temp_file.write(payload)
    temp_file.seek(0)  # rewind so consumers read from the start
    return UploadFile(file=temp_file, filename=filename)
|
||||
@ -29,7 +29,7 @@ LOADER2EXT_DICT = {"UnstructuredFileLoader": ['.eml', '.html', '.md', '.msg', '.
|
||||
"TextLoader": ['.txt'],
|
||||
"PythonLoader": ['.py'],
|
||||
"JSONLoader": ['.json'],
|
||||
"JSONLLoader": ['.jsonl'],
|
||||
"JSONLLoader": ['.jsonl']
|
||||
}
|
||||
|
||||
EXT2LOADER_DICT = {ext: LOADERNAME2LOADER_DICT[k] for k, exts in LOADER2EXT_DICT.items() for ext in exts}
|
||||
@ -61,8 +61,10 @@ def list_kbs_from_folder():
|
||||
|
||||
def list_docs_from_folder(kb_name: str):
    """List the regular files stored in a knowledge base's document folder.

    Args:
        kb_name: knowledge-base name, resolved to a folder via ``get_doc_path``.

    Returns:
        File names (not paths) directly inside the folder, or an empty list
        when the folder does not exist.
    """
    doc_path = get_doc_path(kb_name)
    # Guard first: listing a missing folder would raise, and an
    # unconditional return ahead of this check would make it unreachable.
    if not os.path.isdir(doc_path):
        return []
    return [file for file in os.listdir(doc_path)
            if os.path.isfile(os.path.join(doc_path, file))]
|
||||
|
||||
def get_LoaderClass(file_extension):
|
||||
for LoaderClass, extensions in LOADER2EXT_DICT.items():
|
||||
|
||||
@ -1,9 +1,10 @@
|
||||
from .dialogue import dialogue_page, chat_box
|
||||
from .document import knowledge_page
|
||||
from .code import code_page
|
||||
from .prompt import prompt_page
|
||||
from .utils import ApiRequest
|
||||
|
||||
__all__ = [
|
||||
"dialogue_page", "chat_box", "prompt_page", "knowledge_page",
|
||||
"ApiRequest"
|
||||
"ApiRequest", "code_page"
|
||||
]
|
||||
140
dev_opsgpt/webui/code.py
Normal file
140
dev_opsgpt/webui/code.py
Normal file
@ -0,0 +1,140 @@
|
||||
# encoding: utf-8
|
||||
'''
|
||||
@author: 温进
|
||||
@file: code.py.py
|
||||
@time: 2023/10/23 下午5:31
|
||||
@desc:
|
||||
'''
|
||||
|
||||
import streamlit as st
|
||||
import os
|
||||
import time
|
||||
import traceback
|
||||
from typing import Literal, Dict, Tuple
|
||||
from st_aggrid import AgGrid, JsCode
|
||||
from st_aggrid.grid_options_builder import GridOptionsBuilder
|
||||
import pandas as pd
|
||||
|
||||
from configs.model_config import embedding_model_dict, kbs_config, EMBEDDING_MODEL, DEFAULT_VS_TYPE, WEB_CRAWL_PATH
|
||||
from .utils import *
|
||||
from dev_opsgpt.utils.path_utils import *
|
||||
from dev_opsgpt.service.service_factory import get_cb_details, get_cb_details_by_cb_name
|
||||
from dev_opsgpt.orm import table_init
|
||||
|
||||
# SENTENCE_SIZE = 100
|
||||
|
||||
cell_renderer = JsCode("""function(params) {if(params.value==true){return '✓'}else{return '×'}}""")
|
||||
|
||||
|
||||
def file_exists(cb: str, selected_rows: List) -> Tuple[str, str]:
    '''
    Look up the first selected row's `code_name` under code base `cb`.
    Returns (name, path) when that path is an existing file on disk,
    otherwise a pair of empty strings.
    '''
    if not selected_rows:
        return "", ""

    chosen_name = selected_rows[0]["code_name"]
    candidate_path = get_file_path(cb, chosen_name)
    return (chosen_name, candidate_path) if os.path.isfile(candidate_path) else ("", "")
|
||||
|
||||
|
||||
def code_page(api: ApiRequest):
    """Render the Streamlit page for creating, inspecting and deleting code bases.

    Args:
        api: request helper used to create and delete code bases.

    NOTE(review): the source view had lost its indentation; the nesting below
    (submit handling outside the form, delete button shown for any selected
    code base) is reconstructed and should be confirmed against the original.
    """
    # Make sure the backing tables exist before querying them.
    table_init()

    try:
        logger.info(get_cb_details())
        cb_list = {x["code_name"]: x for x in get_cb_details()}
    except Exception as e:
        logger.exception(e)
        st.error("获取知识库信息错误,请检查是否已按照 `README.md` 中 `4 知识库初始化与迁移` 步骤完成初始化或迁移,或是否为数据库连接错误。")
        st.stop()
    cb_names = list(cb_list.keys())

    # Keep the previously selected code base selected across reruns.
    if "selected_cb_name" in st.session_state and st.session_state["selected_cb_name"] in cb_names:
        selected_cb_index = cb_names.index(st.session_state["selected_cb_name"])
    else:
        selected_cb_index = 0

    def format_selected_cb(cb_name: str) -> str:
        """Label an option with its code path when the code base is known."""
        if cb := cb_list.get(cb_name):
            return f"{cb_name} ({cb['code_path']})"
        else:
            return cb_name

    selected_cb = st.selectbox(
        "请选择或新建代码知识库:",
        cb_names + ["新建代码知识库"],
        format_func=format_selected_cb,
        index=selected_cb_index
    )

    if selected_cb == "新建代码知识库":
        with st.form("新建代码知识库"):

            cb_name = st.text_input(
                "新建代码知识库名称",
                placeholder="新代码知识库名称,不支持中文命名",
                key="cb_name",
            )

            file = st.file_uploader("上传代码库 zip 文件",
                                    ['.zip'],
                                    accept_multiple_files=False,
                                    )

            submit_create_kb = st.form_submit_button(
                "新建",
                use_container_width=True,
            )

        if submit_create_kb:
            # unzip file
            logger.info('files={}'.format(file))

            if not cb_name or not cb_name.strip():
                st.error(f"知识库名称不能为空!")
            elif cb_name in cb_list:
                st.error(f"名为 {cb_name} 的知识库已经存在!")
            elif file is None or file.type not in ['application/zip', 'application/x-zip-compressed']:
                # `file` is None when nothing was uploaded; report that to the
                # user instead of failing on `file.type`.
                logger.error(f"{getattr(file, 'type', None)}")
                st.error('请先上传 zip 文件,再新建代码知识库')
            else:
                ret = api.create_code_base(
                    cb_name,
                    file,
                    no_remote_api=True
                )
                st.toast(ret.get("msg", " "))
                st.session_state["selected_cb_name"] = cb_name
                st.experimental_rerun()
    elif selected_cb:
        cb = selected_cb

        # Code base details
        cb_details = get_cb_details_by_cb_name(cb)
        if not len(cb_details):
            st.info(f"代码知识库 `{cb}` 中暂无信息")
        else:
            logger.info(cb_details)
            st.write(f"代码知识库 `{cb}` 加载成功,中含有以下信息:")

            st.write('代码知识库 `{}` 代码文件数=`{}`'.format(cb_details['code_name'],
                                                       cb_details.get('code_file_num', 'unknown')))

            st.write('代码知识库 `{}` 知识图谱节点数=`{}`'.format(cb_details['code_name'], cb_details['code_graph_node_num']))

        st.divider()

        cols = st.columns(3)

        if cols[2].button(
                "删除知识库",
                use_container_width=True,
        ):
            ret = api.delete_code_base(cb,
                                       no_remote_api=True)
            st.toast(ret.get("msg", " "))
            # Give the toast a moment to be seen before the page reruns.
            time.sleep(1)
            st.experimental_rerun()
|
||||
@ -2,10 +2,13 @@ import streamlit as st
|
||||
from streamlit_chatbox import *
|
||||
from typing import List, Dict
|
||||
from datetime import datetime
|
||||
|
||||
from random import randint
|
||||
from .utils import *
|
||||
|
||||
from dev_opsgpt.utils import *
|
||||
from dev_opsgpt.tools import TOOL_SETS
|
||||
from dev_opsgpt.chat.search_chat import SEARCH_ENGINES
|
||||
from dev_opsgpt.connector import PHASE_LIST, PHASE_CONFIGS
|
||||
|
||||
chat_box = ChatBox(
|
||||
assistant_avatar="../sources/imgs/devops-chatbot2.png"
|
||||
@ -55,7 +58,11 @@ def dialogue_page(api: ApiRequest):
|
||||
dialogue_mode = st.selectbox("请选择对话模式",
|
||||
["LLM 对话",
|
||||
"知识库问答",
|
||||
"代码知识库问答",
|
||||
"搜索引擎问答",
|
||||
"工具问答",
|
||||
"数据分析",
|
||||
"Agent问答"
|
||||
],
|
||||
on_change=on_mode_change,
|
||||
key="dialogue_mode",
|
||||
@ -67,6 +74,10 @@ def dialogue_page(api: ApiRequest):
|
||||
def on_kb_change():
|
||||
st.toast(f"已加载知识库: {st.session_state.selected_kb}")
|
||||
|
||||
def on_cb_change():
|
||||
st.toast(f"已加载代码知识库: {st.session_state.selected_cb}")
|
||||
|
||||
not_agent_qa = True
|
||||
if dialogue_mode == "知识库问答":
|
||||
with st.expander("知识库配置", True):
|
||||
kb_list = api.list_knowledge_bases(no_remote_api=True)
|
||||
@ -80,13 +91,142 @@ def dialogue_page(api: ApiRequest):
|
||||
score_threshold = st.number_input("知识匹配分数阈值:", 0.0, float(SCORE_THRESHOLD), float(SCORE_THRESHOLD), float(SCORE_THRESHOLD//100))
|
||||
# chunk_content = st.checkbox("关联上下文", False, disabled=True)
|
||||
# chunk_size = st.slider("关联长度:", 0, 500, 250, disabled=True)
|
||||
elif dialogue_mode == '代码知识库问答':
|
||||
with st.expander('代码知识库配置', True):
|
||||
cb_list = api.list_cb(no_remote_api=True)
|
||||
logger.debug('codebase_list={}'.format(cb_list))
|
||||
selected_cb = st.selectbox(
|
||||
"请选择代码知识库:",
|
||||
cb_list,
|
||||
on_change=on_cb_change,
|
||||
key="selected_cb",
|
||||
)
|
||||
st.toast(f"已加载代码知识库: {st.session_state.selected_cb}")
|
||||
cb_code_limit = st.number_input("匹配代码条数:", 1, 20, 1)
|
||||
elif dialogue_mode == "搜索引擎问答":
|
||||
with st.expander("搜索引擎配置", True):
|
||||
search_engine = st.selectbox("请选择搜索引擎", SEARCH_ENGINES.keys(), 0)
|
||||
se_top_k = st.number_input("匹配搜索结果条数:", 1, 20, 3)
|
||||
elif dialogue_mode == "工具问答":
|
||||
with st.expander("工具军火库", True):
|
||||
tool_selects = st.multiselect(
|
||||
'请选择待使用的工具', TOOL_SETS, ["WeatherInfo"])
|
||||
|
||||
elif dialogue_mode == "数据分析":
|
||||
with st.expander("沙盒文件管理", False):
|
||||
def _upload(upload_file):
|
||||
res = api.web_sd_upload(upload_file)
|
||||
logger.debug(res)
|
||||
if res["msg"]:
|
||||
st.success("上文件传成功")
|
||||
else:
|
||||
st.toast("文件上传失败")
|
||||
|
||||
code_interpreter_on = st.toggle("开启代码解释器")
|
||||
code_exec_on = st.toggle("自动执行代码")
|
||||
interpreter_file = st.file_uploader(
|
||||
"上传沙盒文件",
|
||||
[i for ls in LOADER2EXT_DICT.values() for i in ls],
|
||||
accept_multiple_files=False,
|
||||
key="interpreter_file",
|
||||
)
|
||||
|
||||
if interpreter_file:
|
||||
_upload(interpreter_file)
|
||||
interpreter_file = None
|
||||
#
|
||||
files = api.web_sd_list_files()
|
||||
files = files["data"]
|
||||
download_file = st.selectbox("选择要处理文件", files,
|
||||
key="download_file",)
|
||||
|
||||
cols = st.columns(2)
|
||||
file_url, file_name = api.web_sd_download(download_file)
|
||||
cols[0].download_button("点击下载", file_url, file_name)
|
||||
if cols[1].button("点击删除", ):
|
||||
api.web_sd_delete(download_file)
|
||||
|
||||
elif dialogue_mode == "Agent问答":
|
||||
not_agent_qa = False
|
||||
with st.expander("Phase管理", True):
|
||||
choose_phase = st.selectbox(
|
||||
'请选择待使用的执行链路', PHASE_LIST, 0)
|
||||
|
||||
is_detailed = st.toggle("返回明细的Agent交互", False)
|
||||
tool_using_on = st.toggle("开启工具使用", PHASE_CONFIGS[choose_phase]["do_using_tool"])
|
||||
tool_selects = []
|
||||
if tool_using_on:
|
||||
with st.expander("工具军火库", True):
|
||||
tool_selects = st.multiselect(
|
||||
'请选择待使用的工具', TOOL_SETS, ["WeatherInfo"])
|
||||
|
||||
search_on = st.toggle("开启搜索增强", PHASE_CONFIGS[choose_phase]["do_search"])
|
||||
search_engine, top_k = None, 3
|
||||
if search_on:
|
||||
with st.expander("搜索引擎配置", True):
|
||||
search_engine = st.selectbox("请选择搜索引擎", SEARCH_ENGINES.keys(), 0)
|
||||
top_k = st.number_input("匹配搜索结果条数:", 1, 20, 3)
|
||||
|
||||
doc_retrieval_on = st.toggle("开启知识库检索增强", PHASE_CONFIGS[choose_phase]["do_doc_retrieval"])
|
||||
selected_kb, top_k, score_threshold = None, 3, 1.0
|
||||
if doc_retrieval_on:
|
||||
with st.expander("知识库配置", True):
|
||||
kb_list = api.list_knowledge_bases(no_remote_api=True)
|
||||
selected_kb = st.selectbox(
|
||||
"请选择知识库:",
|
||||
kb_list,
|
||||
on_change=on_kb_change,
|
||||
key="selected_kb",
|
||||
)
|
||||
top_k = st.number_input("匹配知识条数:", 1, 20, 3)
|
||||
score_threshold = st.number_input("知识匹配分数阈值:", 0.0, float(SCORE_THRESHOLD), float(SCORE_THRESHOLD), float(SCORE_THRESHOLD//100))
|
||||
|
||||
code_retrieval_on = st.toggle("开启代码检索增强", PHASE_CONFIGS[choose_phase]["do_code_retrieval"])
|
||||
selected_cb, top_k = None, 1
|
||||
if code_retrieval_on:
|
||||
with st.expander('代码知识库配置', True):
|
||||
cb_list = api.list_cb(no_remote_api=True)
|
||||
logger.debug('codebase_list={}'.format(cb_list))
|
||||
selected_cb = st.selectbox(
|
||||
"请选择代码知识库:",
|
||||
cb_list,
|
||||
on_change=on_cb_change,
|
||||
key="selected_cb",
|
||||
)
|
||||
st.toast(f"已加载代码知识库: {st.session_state.selected_cb}")
|
||||
top_k = st.number_input("匹配代码条数:", 1, 20, 1)
|
||||
|
||||
with st.expander("沙盒文件管理", False):
|
||||
def _upload(upload_file):
|
||||
res = api.web_sd_upload(upload_file)
|
||||
logger.debug(res)
|
||||
if res["msg"]:
|
||||
st.success("上文件传成功")
|
||||
else:
|
||||
st.toast("文件上传失败")
|
||||
|
||||
interpreter_file = st.file_uploader(
|
||||
"上传沙盒文件",
|
||||
[i for ls in LOADER2EXT_DICT.values() for i in ls],
|
||||
accept_multiple_files=False,
|
||||
key="interpreter_file",
|
||||
)
|
||||
|
||||
if interpreter_file:
|
||||
_upload(interpreter_file)
|
||||
interpreter_file = None
|
||||
#
|
||||
files = api.web_sd_list_files()
|
||||
files = files["data"]
|
||||
download_file = st.selectbox("选择要处理文件", files,
|
||||
key="download_file",)
|
||||
|
||||
cols = st.columns(2)
|
||||
file_url, file_name = api.web_sd_download(download_file)
|
||||
cols[0].download_button("点击下载", file_url, file_name)
|
||||
if cols[1].button("点击删除", ):
|
||||
api.web_sd_delete(download_file)
|
||||
|
||||
code_interpreter_on = st.toggle("开启代码解释器") and not_agent_qa
|
||||
code_exec_on = st.toggle("自动执行代码") and not_agent_qa
|
||||
|
||||
# Display chat messages from history on app rerun
|
||||
|
||||
@ -102,7 +242,97 @@ def dialogue_page(api: ApiRequest):
|
||||
if dialogue_mode == "LLM 对话":
|
||||
chat_box.ai_say("正在思考...")
|
||||
text = ""
|
||||
r = api.chat_chat(prompt, history)
|
||||
r = api.chat_chat(prompt, history, no_remote_api=True)
|
||||
for t in r:
|
||||
if error_msg := check_error_msg(t): # check whether error occured
|
||||
st.error(error_msg)
|
||||
break
|
||||
text += t["answer"]
|
||||
chat_box.update_msg(text)
|
||||
logger.debug(f"text: {text}")
|
||||
chat_box.update_msg(text, streaming=False) # 更新最终的字符串,去除光标
|
||||
# 判断是否存在代码, 并提高编辑功能,执行功能
|
||||
code_text = api.codebox.decode_code_from_text(text)
|
||||
GLOBAL_EXE_CODE_TEXT = code_text
|
||||
if code_text and code_exec_on:
|
||||
codebox_res = api.codebox_chat("```"+code_text+"```", do_code_exe=True)
|
||||
elif dialogue_mode == "Agent问答":
|
||||
display_infos = [f"正在思考..."]
|
||||
if search_on:
|
||||
display_infos.append(Markdown("...", in_expander=True, title="网络搜索结果"))
|
||||
if doc_retrieval_on:
|
||||
display_infos.append(Markdown("...", in_expander=True, title="知识库匹配结果"))
|
||||
chat_box.ai_say(display_infos)
|
||||
|
||||
if 'history_node_list' in st.session_state:
|
||||
history_node_list: List[str] = st.session_state['history_node_list']
|
||||
else:
|
||||
history_node_list: List[str] = []
|
||||
|
||||
input_kargs = {"query": prompt,
|
||||
"phase_name": choose_phase,
|
||||
"history": history,
|
||||
"doc_engine_name": selected_kb,
|
||||
"search_engine_name": search_engine,
|
||||
"code_engine_name": selected_cb,
|
||||
"top_k": top_k,
|
||||
"score_threshold": score_threshold,
|
||||
"do_search": search_on,
|
||||
"do_doc_retrieval": doc_retrieval_on,
|
||||
"do_code_retrieval": code_retrieval_on,
|
||||
"do_tool_retrieval": False,
|
||||
"custom_phase_configs": {},
|
||||
"custom_chain_configs": {},
|
||||
"custom_role_configs": {},
|
||||
"choose_tools": tool_selects,
|
||||
"history_node_list": history_node_list,
|
||||
"isDetailed": is_detailed,
|
||||
}
|
||||
text = ""
|
||||
d = {"docs": []}
|
||||
for idx_count, d in enumerate(api.agent_chat(**input_kargs)):
|
||||
if error_msg := check_error_msg(d): # check whether error occured
|
||||
st.error(error_msg)
|
||||
text += d["answer"]
|
||||
if idx_count%20 == 0:
|
||||
chat_box.update_msg(text, element_index=0)
|
||||
|
||||
for k, v in d["figures"].items():
|
||||
logger.debug(f"figure: {k}")
|
||||
if k in text:
|
||||
img_html = "\n<img src='data:image/png;base64,{}' class='img-fluid'>\n".format(v)
|
||||
text = text.replace(k, img_html).replace(".png", "")
|
||||
chat_box.update_msg(text, element_index=0, streaming=False, state="complete") # 更新最终的字符串,去除光标
|
||||
if search_on:
|
||||
chat_box.update_msg("搜索匹配结果:\n\n" + "\n\n".join(d["search_docs"]), element_index=search_on, streaming=False, state="complete")
|
||||
if doc_retrieval_on:
|
||||
chat_box.update_msg("知识库匹配结果:\n\n" + "\n\n".join(d["db_docs"]), element_index=search_on+doc_retrieval_on, streaming=False, state="complete")
|
||||
|
||||
history_node_list.extend([node[0] for node in d.get("related_nodes", [])])
|
||||
history_node_list = list(set(history_node_list))
|
||||
st.session_state['history_node_list'] = history_node_list
|
||||
|
||||
elif dialogue_mode == "工具问答":
|
||||
chat_box.ai_say("正在思考...")
|
||||
text = ""
|
||||
r = api.tool_chat(prompt, history, tool_sets=tool_selects)
|
||||
for t in r:
|
||||
if error_msg := check_error_msg(t): # check whether error occured
|
||||
st.error(error_msg)
|
||||
break
|
||||
text += t["answer"]
|
||||
chat_box.update_msg(text)
|
||||
logger.debug(f"text: {text}")
|
||||
chat_box.update_msg(text, streaming=False) # 更新最终的字符串,去除光标
|
||||
# 判断是否存在代码, 并提高编辑功能,执行功能
|
||||
code_text = api.codebox.decode_code_from_text(text)
|
||||
GLOBAL_EXE_CODE_TEXT = code_text
|
||||
if code_text and code_exec_on:
|
||||
codebox_res = api.codebox_chat("```"+code_text+"```", do_code_exe=True)
|
||||
elif dialogue_mode == "数据分析":
|
||||
chat_box.ai_say("正在思考...")
|
||||
text = ""
|
||||
r = api.data_chat(prompt, history)
|
||||
for t in r:
|
||||
if error_msg := check_error_msg(t): # check whether error occured
|
||||
st.error(error_msg)
|
||||
@ -116,7 +346,6 @@ def dialogue_page(api: ApiRequest):
|
||||
GLOBAL_EXE_CODE_TEXT = code_text
|
||||
if code_text and code_exec_on:
|
||||
codebox_res = api.codebox_chat("```"+code_text+"```", do_code_exe=True)
|
||||
|
||||
elif dialogue_mode == "知识库问答":
|
||||
history = get_messages_history(history_len)
|
||||
chat_box.ai_say([
|
||||
@ -124,6 +353,7 @@ def dialogue_page(api: ApiRequest):
|
||||
Markdown("...", in_expander=True, title="知识库匹配结果"),
|
||||
])
|
||||
text = ""
|
||||
d = {"docs": []}
|
||||
for idx_count, d in enumerate(api.knowledge_base_chat(prompt, selected_kb, kb_top_k, score_threshold, history)):
|
||||
if error_msg := check_error_msg(d): # check whether error occured
|
||||
st.error(error_msg)
|
||||
@ -138,6 +368,33 @@ def dialogue_page(api: ApiRequest):
|
||||
GLOBAL_EXE_CODE_TEXT = code_text
|
||||
if code_text and code_exec_on:
|
||||
codebox_res = api.codebox_chat("```"+code_text+"```", do_code_exe=True)
|
||||
elif dialogue_mode == '代码知识库问答':
|
||||
logger.info('prompt={}'.format(prompt))
|
||||
logger.info('history={}'.format(history))
|
||||
if 'history_node_list' in st.session_state:
|
||||
api.codeChat.history_node_list = st.session_state['history_node_list']
|
||||
|
||||
chat_box.ai_say([
|
||||
f"正在查询代码知识库 `{selected_cb}` ...",
|
||||
Markdown("...", in_expander=True, title="代码库匹配结果"),
|
||||
])
|
||||
text = ""
|
||||
d = {"codes": []}
|
||||
|
||||
for idx_count, d in enumerate(api.code_base_chat(query=prompt, code_base_name=selected_cb,
|
||||
code_limit=cb_code_limit, history=history,
|
||||
no_remote_api=True)):
|
||||
if error_msg := check_error_msg(d):
|
||||
st.error(error_msg)
|
||||
text += d["answer"]
|
||||
if idx_count % 10 == 0:
|
||||
chat_box.update_msg(text, element_index=0)
|
||||
chat_box.update_msg(text, element_index=0, streaming=False) # 更新最终的字符串,去除光标
|
||||
chat_box.update_msg("\n".join(d["codes"]), element_index=1, streaming=False, state="complete")
|
||||
|
||||
# session state update
|
||||
st.session_state['history_node_list'] = api.codeChat.history_node_list
|
||||
|
||||
elif dialogue_mode == "搜索引擎问答":
|
||||
chat_box.ai_say([
|
||||
f"正在执行 `{search_engine}` 搜索...",
|
||||
@ -145,7 +402,7 @@ def dialogue_page(api: ApiRequest):
|
||||
])
|
||||
text = ""
|
||||
d = {"docs": []}
|
||||
for d in api.search_engine_chat(prompt, search_engine, se_top_k):
|
||||
for idx_count, d in enumerate(api.search_engine_chat(prompt, search_engine, se_top_k)):
|
||||
if error_msg := check_error_msg(d): # check whether error occured
|
||||
st.error(error_msg)
|
||||
text += d["answer"]
|
||||
@ -194,7 +451,7 @@ def dialogue_page(api: ApiRequest):
|
||||
img_html = "<img src='data:image/png;base64,{}' class='img-fluid'>".format(
|
||||
codebox_res.code_exe_response
|
||||
)
|
||||
chat_box.update_msg(base_text + img_html, streaming=False, state="complete")
|
||||
chat_box.update_msg(img_html, streaming=False, state="complete")
|
||||
else:
|
||||
chat_box.update_msg('```\n'+code_text+'\n```'+"\n\n"+'```\n'+codebox_res.code_exe_response+'\n```',
|
||||
streaming=False, state="complete")
|
||||
|
||||