update codebase & multi-agent framework
This commit is contained in:
parent
6bc0ca45ce
commit
d5e2bb7acc
|
@ -1,6 +1,10 @@
|
|||
**/__pycache__
|
||||
knowledge_base
|
||||
logs
|
||||
embedding_models
|
||||
jupyter_work
|
||||
model_config.py
|
||||
server_config.py
|
||||
code_base
|
||||
.DS_Store
|
||||
.idea
|
||||
|
|
12
Dockerfile
12
Dockerfile
|
@ -1,10 +1,18 @@
|
|||
From python:3.9-bookworm
|
||||
From python:3.9.18-bookworm
|
||||
|
||||
WORKDIR /home/user
|
||||
|
||||
COPY ./docker_requirements.txt /home/user/docker_requirements.txt
|
||||
COPY ./requirements.txt /home/user/docker_requirements.txt
|
||||
COPY ./jupyter_start.sh /home/user/jupyter_start.sh
|
||||
|
||||
|
||||
RUN apt-get update
|
||||
RUN apt-get install -y iputils-ping telnetd net-tools vim tcpdump
|
||||
# RUN echo telnet stream tcp nowait telnetd /usr/sbin/tcpd /usr/sbin/in.telnetd /etc/inetd.conf
|
||||
# RUN service inetutils-inetd start
|
||||
# service inetutils-inetd status
|
||||
|
||||
|
||||
RUN pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple
|
||||
RUN pip install -r /home/user/docker_requirements.txt
|
||||
|
||||
|
|
|
@ -0,0 +1,201 @@
|
|||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
|
||||
1. Definitions.
|
||||
|
||||
"License" shall mean the terms and conditions for use, reproduction,
|
||||
and distribution as defined by Sections 1 through 9 of this document.
|
||||
|
||||
"Licensor" shall mean the copyright owner or entity authorized by
|
||||
the copyright owner that is granting the License.
|
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all
|
||||
other entities that control, are controlled by, or are under common
|
||||
control with that entity. For the purposes of this definition,
|
||||
"control" means (i) the power, direct or indirect, to cause the
|
||||
direction or management of such entity, whether by contract or
|
||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
|
||||
"You" (or "Your") shall mean an individual or Legal Entity
|
||||
exercising permissions granted by this License.
|
||||
|
||||
"Source" form shall mean the preferred form for making modifications,
|
||||
including but not limited to software source code, documentation
|
||||
source, and configuration files.
|
||||
|
||||
"Object" form shall mean any form resulting from mechanical
|
||||
transformation or translation of a Source form, including but
|
||||
not limited to compiled object code, generated documentation,
|
||||
and conversions to other media types.
|
||||
|
||||
"Work" shall mean the work of authorship, whether in Source or
|
||||
Object form, made available under the License, as indicated by a
|
||||
copyright notice that is included in or attached to the work
|
||||
(an example is provided in the Appendix below).
|
||||
|
||||
"Derivative Works" shall mean any work, whether in Source or Object
|
||||
form, that is based on (or derived from) the Work and for which the
|
||||
editorial revisions, annotations, elaborations, or other modifications
|
||||
represent, as a whole, an original work of authorship. For the purposes
|
||||
of this License, Derivative Works shall not include works that remain
|
||||
separable from, or merely link (or bind by name) to the interfaces of,
|
||||
the Work and Derivative Works thereof.
|
||||
|
||||
"Contribution" shall mean any work of authorship, including
|
||||
the original version of the Work and any modifications or additions
|
||||
to that Work or Derivative Works thereof, that is intentionally
|
||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||
or by an individual or Legal Entity authorized to submit on behalf of
|
||||
the copyright owner. For the purposes of this definition, "submitted"
|
||||
means any form of electronic, verbal, or written communication sent
|
||||
to the Licensor or its representatives, including but not limited to
|
||||
communication on electronic mailing lists, source code control systems,
|
||||
and issue tracking systems that are managed by, or on behalf of, the
|
||||
Licensor for the purpose of discussing and improving the Work, but
|
||||
excluding communication that is conspicuously marked or otherwise
|
||||
designated in writing by the copyright owner as "Not a Contribution."
|
||||
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||
on behalf of whom a Contribution has been received by Licensor and
|
||||
subsequently incorporated within the Work.
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
copyright license to reproduce, prepare Derivative Works of,
|
||||
publicly display, publicly perform, sublicense, and distribute the
|
||||
Work and such Derivative Works in Source or Object form.
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
(except as stated in this section) patent license to make, have made,
|
||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||
where such license applies only to those patent claims licensable
|
||||
by such Contributor that are necessarily infringed by their
|
||||
Contribution(s) alone or by combination of their Contribution(s)
|
||||
with the Work to which such Contribution(s) was submitted. If You
|
||||
institute patent litigation against any entity (including a
|
||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||
or a Contribution incorporated within the Work constitutes direct
|
||||
or contributory patent infringement, then any patent licenses
|
||||
granted to You under this License for that Work shall terminate
|
||||
as of the date such litigation is filed.
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the
|
||||
Work or Derivative Works thereof in any medium, with or without
|
||||
modifications, and in Source or Object form, provided that You
|
||||
meet the following conditions:
|
||||
|
||||
(a) You must give any other recipients of the Work or
|
||||
Derivative Works a copy of this License; and
|
||||
|
||||
(b) You must cause any modified files to carry prominent notices
|
||||
stating that You changed the files; and
|
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works
|
||||
that You distribute, all copyright, patent, trademark, and
|
||||
attribution notices from the Source form of the Work,
|
||||
excluding those notices that do not pertain to any part of
|
||||
the Derivative Works; and
|
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its
|
||||
distribution, then any Derivative Works that You distribute must
|
||||
include a readable copy of the attribution notices contained
|
||||
within such NOTICE file, excluding those notices that do not
|
||||
pertain to any part of the Derivative Works, in at least one
|
||||
of the following places: within a NOTICE text file distributed
|
||||
as part of the Derivative Works; within the Source form or
|
||||
documentation, if provided along with the Derivative Works; or,
|
||||
within a display generated by the Derivative Works, if and
|
||||
wherever such third-party notices normally appear. The contents
|
||||
of the NOTICE file are for informational purposes only and
|
||||
do not modify the License. You may add Your own attribution
|
||||
notices within Derivative Works that You distribute, alongside
|
||||
or as an addendum to the NOTICE text from the Work, provided
|
||||
that such additional attribution notices cannot be construed
|
||||
as modifying the License.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and
|
||||
may provide additional or different license terms and conditions
|
||||
for use, reproduction, or distribution of Your modifications, or
|
||||
for any such Derivative Works as a whole, provided Your use,
|
||||
reproduction, and distribution of the Work otherwise complies with
|
||||
the conditions stated in this License.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||
any Contribution intentionally submitted for inclusion in the Work
|
||||
by You to the Licensor shall be under the terms and conditions of
|
||||
this License, without any additional terms or conditions.
|
||||
Notwithstanding the above, nothing herein shall supersede or modify
|
||||
the terms of any separate license agreement you may have executed
|
||||
with Licensor regarding such Contributions.
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade
|
||||
names, trademarks, service marks, or product names of the Licensor,
|
||||
except as required for reasonable and customary use in describing the
|
||||
origin of the Work and reproducing the content of the NOTICE file.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||
agreed to in writing, Licensor provides the Work (and each
|
||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
implied, including, without limitation, any warranties or conditions
|
||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any
|
||||
risks associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory,
|
||||
whether in tort (including negligence), contract, or otherwise,
|
||||
unless required by applicable law (such as deliberate and grossly
|
||||
negligent acts) or agreed to in writing, shall any Contributor be
|
||||
liable to You for damages, including any direct, indirect, special,
|
||||
incidental, or consequential damages of any character arising as a
|
||||
result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill,
|
||||
work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses), even if such Contributor
|
||||
has been advised of the possibility of such damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing
|
||||
the Work or Derivative Works thereof, You may choose to offer,
|
||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||
or other liability obligations and/or rights consistent with this
|
||||
License. However, in accepting such obligations, You may act only
|
||||
on Your own behalf and on Your sole responsibility, not on behalf
|
||||
of any other Contributor, and only if You agree to indemnify,
|
||||
defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason
|
||||
of your accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
APPENDIX: How to apply the Apache License to your work.
|
||||
|
||||
To apply the Apache License to your work, attach the following
|
||||
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||
replaced with your own identifying information. (Don't include
|
||||
the brackets!) The text should be enclosed in the appropriate
|
||||
comment syntax for the file format. We also recommend that a
|
||||
file or class name and description of purpose be included on the
|
||||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright [yyyy] [name of copyright owner]
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
|
@ -33,6 +33,10 @@ embedding_model_dict = {
|
|||
"bge-large-zh": "BAAI/bge-large-zh"
|
||||
}
|
||||
|
||||
|
||||
LOCAL_MODEL_DIR = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "embedding_models")
|
||||
embedding_model_dict = {k: f"/home/user/chatbot/embedding_models/{v}" if is_running_in_docker() else f"{LOCAL_MODEL_DIR}/{v}" for k, v in embedding_model_dict.items()}
|
||||
|
||||
# 选用的 Embedding 名称
|
||||
EMBEDDING_MODEL = "text2vec-base"
|
||||
|
||||
|
@ -97,6 +101,7 @@ llm_model_dict = {
|
|||
|
||||
# LLM 名称
|
||||
LLM_MODEL = "gpt-3.5-turbo"
|
||||
USE_FASTCHAT = "gpt" not in LLM_MODEL # 判断是否进行fastchat
|
||||
|
||||
# LLM 运行设备
|
||||
LLM_DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
|
||||
|
@ -112,6 +117,9 @@ SOURCE_PATH = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__fil
|
|||
# 知识库默认存储路径
|
||||
KB_ROOT_PATH = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "knowledge_base")
|
||||
|
||||
# 代码库默认存储路径
|
||||
CB_ROOT_PATH = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "code_base")
|
||||
|
||||
# nltk 模型存储路径
|
||||
NLTK_DATA_PATH = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "nltk_data")
|
||||
|
||||
|
@ -153,7 +161,7 @@ DEFAULT_VS_TYPE = "faiss"
|
|||
CACHED_VS_NUM = 1
|
||||
|
||||
# 知识库中单段文本长度
|
||||
CHUNK_SIZE = 250
|
||||
CHUNK_SIZE = 500
|
||||
|
||||
# 知识库中相邻文本重合长度
|
||||
OVERLAP_SIZE = 50
|
||||
|
@ -169,6 +177,9 @@ SCORE_THRESHOLD = 1 if system_name in ["Linux", "Windows"] else 1100
|
|||
# 搜索引擎匹配结题数量
|
||||
SEARCH_ENGINE_TOP_K = 5
|
||||
|
||||
# 代码引擎匹配结题数量
|
||||
CODE_SEARCH_TOP_K = 1
|
||||
|
||||
# 基于本地知识问答的提示词模版
|
||||
PROMPT_TEMPLATE = """【指令】根据已知信息,简洁和专业的来回答问题。如果无法从中得到答案,请说 “根据已知信息无法回答该问题”,不允许在答案中添加编造成分,答案请使用中文。
|
||||
|
||||
|
@ -176,6 +187,13 @@ PROMPT_TEMPLATE = """【指令】根据已知信息,简洁和专业的来回
|
|||
|
||||
【问题】{question}"""
|
||||
|
||||
# 基于本地代码知识问答的提示词模版
|
||||
CODE_PROMPT_TEMPLATE = """【指令】根据已知信息来回答问题。
|
||||
|
||||
【已知信息】{context}
|
||||
|
||||
【问题】{question}"""
|
||||
|
||||
# API 是否开启跨域,默认为False,如果需要开启,请设置为True
|
||||
# is open cross domain
|
||||
OPEN_CROSS_DOMAIN = False
|
||||
|
|
|
@ -1,41 +1,62 @@
|
|||
from .model_config import LLM_MODEL, LLM_DEVICE
|
||||
import os
|
||||
|
||||
# API 是否开启跨域,默认为False,如果需要开启,请设置为True
|
||||
# is open cross domain
|
||||
OPEN_CROSS_DOMAIN = False
|
||||
|
||||
# 是否用容器来启动服务
|
||||
DOCKER_SERVICE = True
|
||||
# 是否采用容器沙箱
|
||||
SANDBOX_DO_REMOTE = True
|
||||
# 是否采用api服务来进行
|
||||
NO_REMOTE_API = True
|
||||
# 各服务器默认绑定host
|
||||
DEFAULT_BIND_HOST = "127.0.0.1"
|
||||
|
||||
#
|
||||
CONTRAINER_NAME = "devopsgpt_webui"
|
||||
IMAGE_NAME = "devopsgpt:py39"
|
||||
|
||||
# webui.py server
|
||||
WEBUI_SERVER = {
|
||||
"host": DEFAULT_BIND_HOST,
|
||||
"port": 8501,
|
||||
"docker_port": 8501
|
||||
}
|
||||
|
||||
# api.py server
|
||||
API_SERVER = {
|
||||
"host": DEFAULT_BIND_HOST,
|
||||
"port": 7861,
|
||||
"docker_port": 7861
|
||||
}
|
||||
|
||||
# sdfile_api.py server
|
||||
SDFILE_API_SERVER = {
|
||||
"host": DEFAULT_BIND_HOST,
|
||||
"port": 7862,
|
||||
"docker_port": 7862
|
||||
}
|
||||
|
||||
# fastchat openai_api server
|
||||
FSCHAT_OPENAI_API = {
|
||||
"host": DEFAULT_BIND_HOST,
|
||||
"port": 8888, # model_config.llm_model_dict中模型配置的api_base_url需要与这里一致。
|
||||
"docker_port": 8888, # model_config.llm_model_dict中模型配置的api_base_url需要与这里一致。
|
||||
}
|
||||
|
||||
# sandbox api server
|
||||
CONTRAINER_NAME = "devopsgt_default"
|
||||
IMAGE_NAME = "devopsgpt:pypy38"
|
||||
SANDBOX_CONTRAINER_NAME = "devopsgpt_sandbox"
|
||||
SANDBOX_IMAGE_NAME = "devopsgpt:py39"
|
||||
SANDBOX_HOST = os.environ.get("SANDBOX_HOST") or DEFAULT_BIND_HOST # "172.25.0.3"
|
||||
SANDBOX_SERVER = {
|
||||
"host": DEFAULT_BIND_HOST,
|
||||
"host": f"http://{SANDBOX_HOST}",
|
||||
"port": 5050,
|
||||
"url": "http://localhost:5050",
|
||||
"do_remote": True,
|
||||
"docker_port": 5050,
|
||||
"url": f"http://{SANDBOX_HOST}:5050",
|
||||
"do_remote": SANDBOX_DO_REMOTE,
|
||||
}
|
||||
|
||||
|
||||
# fastchat model_worker server
|
||||
# 这些模型必须是在model_config.llm_model_dict中正确配置的。
|
||||
# 在启动startup.py时,可用通过`--model-worker --model-name xxxx`指定模型,不指定则为LLM_MODEL
|
||||
|
|
|
@ -0,0 +1,15 @@
|
|||
import os
|
||||
|
||||
def is_running_in_docker():
|
||||
"""
|
||||
检查当前代码是否在 Docker 容器中运行
|
||||
"""
|
||||
# 检查是否存在 /.dockerenv 文件
|
||||
if os.path.exists('/.dockerenv'):
|
||||
return True
|
||||
|
||||
# 检查 cgroup 文件系统是否为 /docker/ 开头
|
||||
if os.path.exists("/proc/1/cgroup"):
|
||||
with open('/proc/1/cgroup', 'rt') as f:
|
||||
return '/docker/' in f.read()
|
||||
return False
|
|
@ -2,7 +2,12 @@ from .base_chat import Chat
|
|||
from .knowledge_chat import KnowledgeChat
|
||||
from .llm_chat import LLMChat
|
||||
from .search_chat import SearchChat
|
||||
from .tool_chat import ToolChat
|
||||
from .data_chat import DataChat
|
||||
from .code_chat import CodeChat
|
||||
from .agent_chat import AgentChat
|
||||
|
||||
|
||||
__all__ = [
|
||||
"Chat", "KnowledgeChat", "LLMChat", "SearchChat"
|
||||
"Chat", "KnowledgeChat", "LLMChat", "SearchChat", "ToolChat", "DataChat", "CodeChat", "AgentChat"
|
||||
]
|
||||
|
|
|
@ -0,0 +1,169 @@
|
|||
from fastapi import Body, Request
|
||||
from fastapi.responses import StreamingResponse
|
||||
from typing import List
|
||||
from loguru import logger
|
||||
import importlib
|
||||
import copy
|
||||
import json
|
||||
|
||||
from configs.model_config import (
|
||||
llm_model_dict, LLM_MODEL, PROMPT_TEMPLATE,
|
||||
VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD)
|
||||
|
||||
from dev_opsgpt.tools import (
|
||||
toLangchainTools,
|
||||
TOOL_DICT, TOOL_SETS
|
||||
)
|
||||
|
||||
from dev_opsgpt.connector.phase import BasePhase
|
||||
from dev_opsgpt.connector.agents import BaseAgent, ReactAgent
|
||||
from dev_opsgpt.connector.chains import BaseChain
|
||||
from dev_opsgpt.connector.connector_schema import (
|
||||
Message,
|
||||
load_phase_configs, load_chain_configs, load_role_configs
|
||||
)
|
||||
from dev_opsgpt.connector.shcema import Memory
|
||||
|
||||
from dev_opsgpt.chat.utils import History, wrap_done
|
||||
from dev_opsgpt.connector.configs import PHASE_CONFIGS, AGETN_CONFIGS, CHAIN_CONFIGS
|
||||
|
||||
PHASE_MODULE = importlib.import_module("dev_opsgpt.connector.phase")
|
||||
|
||||
|
||||
|
||||
class AgentChat:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
engine_name: str = "",
|
||||
top_k: int = 1,
|
||||
stream: bool = False,
|
||||
) -> None:
|
||||
self.top_k = top_k
|
||||
self.stream = stream
|
||||
|
||||
def chat(
|
||||
self,
|
||||
query: str = Body(..., description="用户输入", examples=["hello"]),
|
||||
phase_name: str = Body(..., description="执行场景名称", examples=["chatPhase"]),
|
||||
chain_name: str = Body(..., description="执行链的名称", examples=["chatChain"]),
|
||||
history: List[History] = Body(
|
||||
[], description="历史对话",
|
||||
examples=[[{"role": "user", "content": "我们来玩成语接龙,我先来,生龙活虎"}]]
|
||||
),
|
||||
doc_engine_name: str = Body(..., description="知识库名称", examples=["samples"]),
|
||||
search_engine_name: str = Body(..., description="搜索引擎名称", examples=["duckduckgo"]),
|
||||
code_engine_name: str = Body(..., description="代码引擎名称", examples=["samples"]),
|
||||
top_k: int = Body(VECTOR_SEARCH_TOP_K, description="匹配向量数"),
|
||||
score_threshold: float = Body(SCORE_THRESHOLD, description="知识库匹配相关度阈值,取值范围在0-1之间,SCORE越小,相关度越高,取到1相当于不筛选,建议设置在0.5左右", ge=0, le=1),
|
||||
stream: bool = Body(False, description="流式输出"),
|
||||
local_doc_url: bool = Body(False, description="知识文件返回本地路径(true)或URL(false)"),
|
||||
choose_tools: List[str] = Body([], description="选择tool的集合"),
|
||||
do_search: bool = Body(False, description="是否进行搜索"),
|
||||
do_doc_retrieval: bool = Body(False, description="是否进行知识库检索"),
|
||||
do_code_retrieval: bool = Body(False, description="是否执行代码检索"),
|
||||
do_tool_retrieval: bool = Body(False, description="是否执行工具检索"),
|
||||
custom_phase_configs: dict = Body({}, description="自定义phase配置"),
|
||||
custom_chain_configs: dict = Body({}, description="自定义chain配置"),
|
||||
custom_role_configs: dict = Body({}, description="自定义role配置"),
|
||||
history_node_list: List = Body([], description="代码历史相关节点"),
|
||||
isDetaild: bool = Body([], description="是否输出完整的agent相关内容"),
|
||||
**kargs
|
||||
) -> Message:
|
||||
|
||||
# update configs
|
||||
phase_configs, chain_configs, agent_configs = self.update_configs(
|
||||
custom_phase_configs, custom_chain_configs, custom_role_configs)
|
||||
# choose tools
|
||||
tools = toLangchainTools([TOOL_DICT[i] for i in choose_tools if i in TOOL_DICT])
|
||||
input_message = Message(
|
||||
role_content=query,
|
||||
role_type="human",
|
||||
role_name="user",
|
||||
input_query=query,
|
||||
phase_name=phase_name,
|
||||
chain_name=chain_name,
|
||||
do_search=do_search,
|
||||
do_doc_retrieval=do_doc_retrieval,
|
||||
do_code_retrieval=do_code_retrieval,
|
||||
do_tool_retrieval=do_tool_retrieval,
|
||||
doc_engine_name=doc_engine_name, search_engine_name=search_engine_name,
|
||||
code_engine_name=code_engine_name,
|
||||
score_threshold=score_threshold, top_k=top_k,
|
||||
history_node_list=history_node_list,
|
||||
tools=tools
|
||||
)
|
||||
# history memory mangemant
|
||||
history = Memory([
|
||||
Message(role_name=i["role"], role_type=i["role"], role_content=i["content"])
|
||||
for i in history
|
||||
])
|
||||
# start to execute
|
||||
phase_class = getattr(PHASE_MODULE, phase_configs[input_message.phase_name]["phase_type"])
|
||||
phase = phase_class(input_message.phase_name,
|
||||
task = input_message.task,
|
||||
phase_config = phase_configs,
|
||||
chain_config = chain_configs,
|
||||
role_config = agent_configs,
|
||||
do_summary=phase_configs[input_message.phase_name]["do_summary"],
|
||||
do_code_retrieval=input_message.do_code_retrieval,
|
||||
do_doc_retrieval=input_message.do_doc_retrieval,
|
||||
do_search=input_message.do_search,
|
||||
)
|
||||
output_message, local_memory = phase.step(input_message, history)
|
||||
# logger.debug(f"local_memory: {local_memory.to_str_messages(content_key='step_content')}")
|
||||
|
||||
# return {
|
||||
# "answer": output_message.role_content,
|
||||
# "db_docs": output_message.db_docs,
|
||||
# "search_docs": output_message.search_docs,
|
||||
# "code_docs": output_message.code_docs,
|
||||
# "figures": output_message.figures
|
||||
# }
|
||||
|
||||
def chat_iterator(message: Message, local_memory: Memory, isDetaild=False):
|
||||
result = {
|
||||
"answer": "",
|
||||
"db_docs": [str(doc) for doc in message.db_docs],
|
||||
"search_docs": [str(doc) for doc in message.search_docs],
|
||||
"code_docs": [str(doc) for doc in message.code_docs],
|
||||
"related_nodes": [doc.get_related_node() for idx, doc in enumerate(message.code_docs) if idx==0],
|
||||
"figures": message.figures
|
||||
}
|
||||
|
||||
|
||||
related_nodes, has_nodes = [], [ ]
|
||||
for nodes in result["related_nodes"]:
|
||||
for node in nodes:
|
||||
if node not in has_nodes:
|
||||
related_nodes.append(node)
|
||||
result["related_nodes"] = related_nodes
|
||||
|
||||
# logger.debug(f"{result['figures'].keys()}")
|
||||
message_str = local_memory.to_str_messages(content_key='step_content') if isDetaild else message.role_content
|
||||
if self.stream:
|
||||
for token in message_str:
|
||||
result["answer"] = token
|
||||
yield json.dumps(result, ensure_ascii=False)
|
||||
else:
|
||||
for token in message_str:
|
||||
result["answer"] += token
|
||||
yield json.dumps(result, ensure_ascii=False)
|
||||
|
||||
return StreamingResponse(chat_iterator(output_message, local_memory, isDetaild), media_type="text/event-stream")
|
||||
|
||||
def _chat(self, ):
|
||||
pass
|
||||
|
||||
def update_configs(self, custom_phase_configs, custom_chain_configs, custom_role_configs):
|
||||
'''update phase/chain/agent configs'''
|
||||
phase_configs = copy.deepcopy(PHASE_CONFIGS)
|
||||
phase_configs.update(custom_phase_configs)
|
||||
chain_configs = copy.deepcopy(CHAIN_CONFIGS)
|
||||
chain_configs.update(custom_chain_configs)
|
||||
agent_configs = copy.deepcopy(AGETN_CONFIGS)
|
||||
agent_configs.update(custom_role_configs)
|
||||
# phase_configs = load_phase_configs(new_phase_configs)
|
||||
# chian_configs = load_chain_configs(new_chain_configs)
|
||||
# agent_configs = load_role_configs(new_agent_configs)
|
||||
return phase_configs, chain_configs, agent_configs
|
|
@ -3,12 +3,11 @@ from fastapi.responses import StreamingResponse
|
|||
import asyncio, json
|
||||
from typing import List, AsyncIterable
|
||||
|
||||
from langchain.chat_models import ChatOpenAI
|
||||
from langchain import LLMChain
|
||||
from langchain.callbacks import AsyncIteratorCallbackHandler
|
||||
from langchain.prompts.chat import ChatPromptTemplate
|
||||
|
||||
|
||||
from dev_opsgpt.llm_models import getChatModel
|
||||
from dev_opsgpt.chat.utils import History, wrap_done
|
||||
from configs.model_config import (llm_model_dict, LLM_MODEL, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD)
|
||||
from dev_opsgpt.utils import BaseResponse
|
||||
|
@ -16,30 +15,6 @@ from loguru import logger
|
|||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
def getChatModel(callBack: AsyncIteratorCallbackHandler = None):
|
||||
if callBack is None:
|
||||
model = ChatOpenAI(
|
||||
streaming=True,
|
||||
verbose=True,
|
||||
openai_api_key=llm_model_dict[LLM_MODEL]["api_key"],
|
||||
openai_api_base=llm_model_dict[LLM_MODEL]["api_base_url"],
|
||||
model_name=LLM_MODEL
|
||||
)
|
||||
else:
|
||||
model = ChatOpenAI(
|
||||
streaming=True,
|
||||
verbose=True,
|
||||
callBack=[callBack],
|
||||
openai_api_key=llm_model_dict[LLM_MODEL]["api_key"],
|
||||
openai_api_base=llm_model_dict[LLM_MODEL]["api_base_url"],
|
||||
model_name=LLM_MODEL
|
||||
)
|
||||
return model
|
||||
|
||||
|
||||
class Chat:
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -67,6 +42,7 @@ class Chat:
|
|||
stream: bool = Body(False, description="流式输出"),
|
||||
local_doc_url: bool = Body(False, description="知识文件返回本地路径(true)或URL(false)"),
|
||||
request: Request = None,
|
||||
**kargs
|
||||
):
|
||||
self.engine_name = engine_name if isinstance(engine_name, str) else engine_name.default
|
||||
self.top_k = top_k if isinstance(top_k, int) else top_k.default
|
||||
|
@ -74,18 +50,23 @@ class Chat:
|
|||
self.stream = stream if isinstance(stream, bool) else stream.default
|
||||
self.local_doc_url = local_doc_url if isinstance(local_doc_url, bool) else local_doc_url.default
|
||||
self.request = request
|
||||
return self._chat(query, history)
|
||||
return self._chat(query, history, **kargs)
|
||||
|
||||
def _chat(self, query: str, history: List[History]):
|
||||
def _chat(self, query: str, history: List[History], **kargs):
|
||||
history = [History(**h) if isinstance(h, dict) else h for h in history]
|
||||
|
||||
## check service dependcy is ok
|
||||
service_status = self.check_service_status()
|
||||
|
||||
if service_status.code!=200: return service_status
|
||||
|
||||
def chat_iterator(query: str, history: List[History]):
|
||||
model = getChatModel()
|
||||
|
||||
result ,content = self.create_task(query, history, model)
|
||||
result, content = self.create_task(query, history, model, **kargs)
|
||||
logger.info('result={}'.format(result))
|
||||
logger.info('content={}'.format(content))
|
||||
|
||||
if self.stream:
|
||||
for token in content["text"]:
|
||||
result["answer"] = token
|
||||
|
@ -144,7 +125,7 @@ class Chat:
|
|||
return StreamingResponse(chat_iterator(query, history),
|
||||
media_type="text/event-stream")
|
||||
|
||||
def create_task(self, query: str, history: List[History], model):
|
||||
def create_task(self, query: str, history: List[History], model, **kargs):
|
||||
'''构建 llm 生成任务'''
|
||||
chat_prompt = ChatPromptTemplate.from_messages(
|
||||
[i.to_msg_tuple() for i in history] + [("human", "{input}")]
|
||||
|
|
|
@ -0,0 +1,143 @@
|
|||
# encoding: utf-8
|
||||
'''
|
||||
@author: 温进
|
||||
@file: code_chat.py
|
||||
@time: 2023/10/24 下午4:04
|
||||
@desc:
|
||||
'''
|
||||
|
||||
from fastapi import Request, Body
|
||||
import os, asyncio
|
||||
from urllib.parse import urlencode
|
||||
from typing import List
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
from langchain import LLMChain
|
||||
from langchain.callbacks import AsyncIteratorCallbackHandler
|
||||
from langchain.prompts.chat import ChatPromptTemplate
|
||||
|
||||
from configs.model_config import (
|
||||
llm_model_dict, LLM_MODEL, PROMPT_TEMPLATE,
|
||||
VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD, CODE_PROMPT_TEMPLATE)
|
||||
from dev_opsgpt.chat.utils import History, wrap_done
|
||||
from dev_opsgpt.utils import BaseResponse
|
||||
from .base_chat import Chat
|
||||
from dev_opsgpt.llm_models import getChatModel
|
||||
|
||||
from dev_opsgpt.service.kb_api import search_docs, KBServiceFactory
|
||||
from dev_opsgpt.service.cb_api import search_code, cb_exists_api
|
||||
from loguru import logger
|
||||
import json
|
||||
|
||||
|
||||
class CodeChat(Chat):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
code_base_name: str = '',
|
||||
code_limit: int = 1,
|
||||
stream: bool = False,
|
||||
request: Request = None,
|
||||
) -> None:
|
||||
super().__init__(engine_name=code_base_name, stream=stream)
|
||||
self.engine_name = code_base_name
|
||||
self.code_limit = code_limit
|
||||
self.request = request
|
||||
self.history_node_list = []
|
||||
|
||||
def check_service_status(self) -> BaseResponse:
|
||||
cb = cb_exists_api(self.engine_name)
|
||||
if not cb:
|
||||
return BaseResponse(code=404, msg=f"未找到代码库 {self.engine_name}")
|
||||
return BaseResponse(code=200, msg=f"找到代码库 {self.engine_name}")
|
||||
|
||||
def _process(self, query: str, history: List[History], model):
|
||||
'''process'''
|
||||
codes_res = search_code(query=query, cb_name=self.engine_name, code_limit=self.code_limit,
|
||||
history_node_list=self.history_node_list)
|
||||
|
||||
codes = codes_res['related_code']
|
||||
nodes = codes_res['related_node']
|
||||
|
||||
# update node names
|
||||
node_names = [node[0] for node in nodes]
|
||||
self.history_node_list.extend(node_names)
|
||||
self.history_node_list = list(set(self.history_node_list))
|
||||
|
||||
context = "\n".join(codes)
|
||||
source_nodes = []
|
||||
|
||||
for inum, node_info in enumerate(nodes[0:5]):
|
||||
node_name, node_type, node_score = node_info[0], node_info[1], node_info[2]
|
||||
source_nodes.append(f'{inum + 1}. 节点名为 {node_name}, 节点类型为 `{node_type}`, 节点得分为 `{node_score}`')
|
||||
|
||||
chat_prompt = ChatPromptTemplate.from_messages(
|
||||
[i.to_msg_tuple() for i in history] + [("human", CODE_PROMPT_TEMPLATE)]
|
||||
)
|
||||
chain = LLMChain(prompt=chat_prompt, llm=model)
|
||||
result = {"answer": "", "codes": source_nodes}
|
||||
return chain, context, result
|
||||
|
||||
def chat(
|
||||
self,
|
||||
query: str = Body(..., description="用户输入", examples=["hello"]),
|
||||
history: List[History] = Body(
|
||||
[], description="历史对话",
|
||||
examples=[[{"role": "user", "content": "我们来玩成语接龙,我先来,生龙活虎"}]]
|
||||
),
|
||||
engine_name: str = Body(..., description="知识库名称", examples=["samples"]),
|
||||
code_limit: int = Body(1, examples=['1']),
|
||||
stream: bool = Body(False, description="流式输出"),
|
||||
local_doc_url: bool = Body(False, description="知识文件返回本地路径(true)或URL(false)"),
|
||||
request: Request = Body(None),
|
||||
**kargs
|
||||
):
|
||||
self.engine_name = engine_name if isinstance(engine_name, str) else engine_name.default
|
||||
self.code_limit = code_limit
|
||||
self.stream = stream if isinstance(stream, bool) else stream.default
|
||||
self.local_doc_url = local_doc_url if isinstance(local_doc_url, bool) else local_doc_url.default
|
||||
self.request = request
|
||||
return self._chat(query, history, **kargs)
|
||||
|
||||
def _chat(self, query: str, history: List[History], **kargs):
|
||||
history = [History(**h) if isinstance(h, dict) else h for h in history]
|
||||
|
||||
service_status = self.check_service_status()
|
||||
|
||||
if service_status.code != 200: return service_status
|
||||
|
||||
def chat_iterator(query: str, history: List[History]):
|
||||
model = getChatModel()
|
||||
|
||||
result, content = self.create_task(query, history, model, **kargs)
|
||||
# logger.info('result={}'.format(result))
|
||||
# logger.info('content={}'.format(content))
|
||||
|
||||
if self.stream:
|
||||
for token in content["text"]:
|
||||
result["answer"] = token
|
||||
yield json.dumps(result, ensure_ascii=False)
|
||||
else:
|
||||
for token in content["text"]:
|
||||
result["answer"] += token
|
||||
yield json.dumps(result, ensure_ascii=False)
|
||||
|
||||
return StreamingResponse(chat_iterator(query, history),
|
||||
media_type="text/event-stream")
|
||||
|
||||
def create_task(self, query: str, history: List[History], model):
|
||||
'''构建 llm 生成任务'''
|
||||
chain, context, result = self._process(query, history, model)
|
||||
logger.info('chain={}'.format(chain))
|
||||
try:
|
||||
content = chain({"context": context, "question": query})
|
||||
except Exception as e:
|
||||
content = {"text": str(e)}
|
||||
return result, content
|
||||
|
||||
def create_atask(self, query, history, model, callback: AsyncIteratorCallbackHandler):
|
||||
chain, context, result = self._process(query, history, model)
|
||||
task = asyncio.create_task(wrap_done(
|
||||
chain.acall({"context": context, "question": query}), callback.done
|
||||
))
|
||||
return task, result
|
|
@ -0,0 +1,229 @@
|
|||
import asyncio
|
||||
from typing import List
|
||||
|
||||
from langchain import LLMChain
|
||||
from langchain.callbacks import AsyncIteratorCallbackHandler
|
||||
from langchain.prompts.chat import ChatPromptTemplate
|
||||
from langchain.agents import AgentType, initialize_agent
|
||||
|
||||
from dev_opsgpt.tools import (
|
||||
WeatherInfo, WorldTimeGetTimezoneByArea, Multiplier,
|
||||
toLangchainTools, get_tool_schema
|
||||
)
|
||||
from .utils import History, wrap_done
|
||||
from .base_chat import Chat
|
||||
from loguru import logger
|
||||
import json, re
|
||||
|
||||
from dev_opsgpt.sandbox import PyCodeBox, CodeBoxResponse
|
||||
from configs.server_config import SANDBOX_SERVER
|
||||
|
||||
def get_tool_agent(tools, llm):
|
||||
return initialize_agent(
|
||||
tools,
|
||||
llm,
|
||||
agent=AgentType.STRUCTURED_CHAT_ZERO_SHOT_REACT_DESCRIPTION,
|
||||
verbose=True,
|
||||
)
|
||||
|
||||
PROMPT_TEMPLATE = """
|
||||
`角色`
|
||||
你是一个数据分析师,借鉴下述步骤,逐步完成数据分析任务的拆解和代码编写,尽可能帮助和准确地回答用户的问题。
|
||||
|
||||
数据文件的存放路径为 `./`
|
||||
|
||||
`数据分析流程`
|
||||
- 判断文件是否存在,并读取文件数据
|
||||
- 输出数据的基本信息,包括但不限于字段、文本、数据类型等
|
||||
- 输出数据的详细统计信息
|
||||
- 判断是否需要画图分析,选择合适的字段进行画图
|
||||
- 判断数据是否需要进行清洗
|
||||
- 判断数据或图片是否需要保存
|
||||
...
|
||||
- 结合数据统计分析结果和画图结果,进行总结和分析这份数据的价值
|
||||
|
||||
`要求`
|
||||
- 每轮选择一个数据分析流程,需要综合考虑上轮和后续的可能影响
|
||||
- 数据分析流程只提供参考,不要拘泥于它的具体流程,要有自己的思考
|
||||
- 使用JSON blob来指定一个计划,通过提供task_status关键字(任务状态)、plan关键字(数据分析计划)和code关键字(可执行代码)。
|
||||
|
||||
合法的 "task_status" 值: "finished" 表明当前用户问题已被准确回答 或者 "continued" 表明用户问题仍需要进一步分析
|
||||
|
||||
`$JSON_BLOB如下所示`
|
||||
```
|
||||
{{
|
||||
"task_status": $TASK_STATUS,
|
||||
"plan": $PLAN,
|
||||
"code": ```python\n$CODE```
|
||||
}}
|
||||
```
|
||||
|
||||
`跟随如下示例`
|
||||
问题: 输入待回答的问题
|
||||
行动:$JSON_BLOB
|
||||
|
||||
... (重复 行动 N 次,每次只生成一个行动)
|
||||
|
||||
行动:
|
||||
```
|
||||
{{
|
||||
"task_status": "finished",
|
||||
"plan": 我已经可以回答用户问题了,最后回答用户的内容
|
||||
}}
|
||||
|
||||
```
|
||||
|
||||
`数据分析,开始`
|
||||
|
||||
问题:{query}
|
||||
"""
|
||||
|
||||
|
||||
PROMPT_TEMPLATE_2 = """
|
||||
`角色`
|
||||
你是一个数据分析师,借鉴下述步骤,逐步完成数据分析任务的拆解和代码编写,尽可能帮助和准确地回答用户的问题。
|
||||
|
||||
数据文件的存放路径为 `./`
|
||||
|
||||
`数据分析流程`
|
||||
- 判断文件是否存在,并读取文件数据
|
||||
- 输出数据的基本信息,包括但不限于字段、文本、数据类型等
|
||||
- 输出数据的详细统计信息
|
||||
- 判断数据是否需要进行清洗
|
||||
- 判断是否需要画图分析,选择合适的字段进行画图
|
||||
- 判断清洗后数据或图片是否需要保存
|
||||
...
|
||||
- 结合数据统计分析结果和画图结果,进行总结和分析这份数据的价值
|
||||
|
||||
`要求`
|
||||
- 每轮选择一个数据分析流程,需要综合考虑上轮和后续的可能影响
|
||||
- 数据分析流程只提供参考,不要拘泥于它的具体流程,要有自己的思考
|
||||
- 使用JSON blob来指定一个计划,通过提供task_status关键字(任务状态)、plan关键字(数据分析计划)和code关键字(可执行代码)。
|
||||
|
||||
合法的 "task_status" 值: "finished" 表明当前用户问题已被准确回答 或者 "continued" 表明用户问题仍需要进一步分析
|
||||
|
||||
`$JSON_BLOB如下所示`
|
||||
```
|
||||
{{
|
||||
"task_status": $TASK_STATUS,
|
||||
"plan": $PLAN,
|
||||
"code": ```python\n$CODE```
|
||||
}}
|
||||
```
|
||||
|
||||
`跟随如下示例`
|
||||
问题: 输入待回答的问题
|
||||
行动:$JSON_BLOB
|
||||
|
||||
... (重复 行动 N 次,每次只生成一个行动)
|
||||
|
||||
行动:
|
||||
```
|
||||
{{
|
||||
"task_status": "finished",
|
||||
"plan": 我已经可以回答用户问题了,最后回答用户的内容
|
||||
}}
|
||||
|
||||
`数据分析,开始`
|
||||
|
||||
问题:上传了一份employee_data.csv文件,请对它进行数据分析
|
||||
|
||||
问题:{query}
|
||||
{history}
|
||||
|
||||
"""
|
||||
|
||||
class DataChat(Chat):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
engine_name: str = "",
|
||||
top_k: int = 1,
|
||||
stream: bool = False,
|
||||
) -> None:
|
||||
super().__init__(engine_name, top_k, stream)
|
||||
self.tool_prompt = """结合上下文信息,{tools} {input}"""
|
||||
self.codebox = PyCodeBox(
|
||||
remote_url=SANDBOX_SERVER["url"],
|
||||
remote_ip=SANDBOX_SERVER["host"], # "http://localhost",
|
||||
remote_port=SANDBOX_SERVER["port"],
|
||||
token="mytoken",
|
||||
do_code_exe=True,
|
||||
do_remote=SANDBOX_SERVER["do_remote"]
|
||||
)
|
||||
|
||||
def create_task(self, query: str, history: List[History], model):
|
||||
'''构建 llm 生成任务'''
|
||||
logger.debug("content:{}".format([i.to_msg_tuple() for i in history] + [("human", PROMPT_TEMPLATE)]))
|
||||
chat_prompt = ChatPromptTemplate.from_messages(
|
||||
[i.to_msg_tuple() for i in history] + [("human", PROMPT_TEMPLATE)]
|
||||
)
|
||||
pattern = re.compile(r"```(?:json)?\n(.*?)\n", re.DOTALL)
|
||||
internal_history = []
|
||||
retry_nums = 2
|
||||
while retry_nums >= 0:
|
||||
if len(internal_history) == 0:
|
||||
chat_prompt = ChatPromptTemplate.from_messages(
|
||||
[i.to_msg_tuple() for i in history] + [("human", PROMPT_TEMPLATE)]
|
||||
)
|
||||
else:
|
||||
chat_prompt = ChatPromptTemplate.from_messages(
|
||||
[i.to_msg_tuple() for i in history] + [("human", PROMPT_TEMPLATE_2)]
|
||||
)
|
||||
|
||||
chain = LLMChain(prompt=chat_prompt, llm=model)
|
||||
content = chain({"query": query, "history": "\n".join(internal_history)})["text"]
|
||||
|
||||
# content = pattern.search(content)
|
||||
# logger.info(f"content: {content}")
|
||||
# content = json.loads(content.group(1).strip(), strict=False)
|
||||
|
||||
internal_history.append(f"{content}")
|
||||
refer_info = "\n".join(internal_history)
|
||||
logger.info(f"refer_info: {refer_info}")
|
||||
try:
|
||||
content = content.split("行动:")[-1].split("行动:")[-1]
|
||||
content = json.loads(content)
|
||||
except:
|
||||
content = content.split("行动:")[-1].split("行动:")[-1]
|
||||
content = eval(content)
|
||||
|
||||
if "finished" == content["task_status"]:
|
||||
break
|
||||
elif "code" in content:
|
||||
# elif "```code" in content or "```python" in content:
|
||||
# code_text = self.codebox.decode_code_from_text(content)
|
||||
code_text = content["code"]
|
||||
codebox_res = self.codebox.chat("```"+code_text+"```", do_code_exe=True)
|
||||
|
||||
if codebox_res is not None and codebox_res.code_exe_status != 200:
|
||||
logger.warning(f"{codebox_res.code_exe_response}")
|
||||
internal_history.append(f"观察: 根据这个报错信息 {codebox_res.code_exe_response},进行代码修复")
|
||||
|
||||
if codebox_res is not None and codebox_res.code_exe_status == 200:
|
||||
if codebox_res.code_exe_type == "image/png":
|
||||
base_text = f"```\n{code_text}\n```\n\n"
|
||||
img_html = "<img src='data:image/png;base64,{}' class='img-fluid'>".format(
|
||||
codebox_res.code_exe_response
|
||||
)
|
||||
internal_history.append(f"观察: {img_html}")
|
||||
# logger.info('```\n'+code_text+'\n```'+"\n\n"+'```\n'+codebox_res.code_exe_response+'\n```')
|
||||
else:
|
||||
internal_history.append(f"观察: {codebox_res.code_exe_response}")
|
||||
# logger.info('```\n'+code_text+'\n```'+"\n\n"+'```\n'+codebox_res.code_exe_response+'\n```')
|
||||
else:
|
||||
internal_history.append(f"观察:下一步应该怎么做?")
|
||||
retry_nums -= 1
|
||||
|
||||
|
||||
return {"answer": "", "docs": ""}, {"text": "\n".join(internal_history)}
|
||||
|
||||
def create_atask(self, query, history, model, callback: AsyncIteratorCallbackHandler):
|
||||
chat_prompt = ChatPromptTemplate.from_messages(
|
||||
[i.to_msg_tuple() for i in history] + [("human", PROMPT_TEMPLATE)]
|
||||
)
|
||||
chain = LLMChain(prompt=chat_prompt, llm=model)
|
||||
task = asyncio.create_task(wrap_done(
|
||||
chain.acall({"input": query}), callback.done
|
||||
))
|
||||
return task, {"answer": "", "docs": ""}
|
|
@ -1,6 +1,5 @@
|
|||
from fastapi import Request
|
||||
import os, asyncio
|
||||
from urllib.parse import urlencode
|
||||
from typing import List, Optional, Dict
|
||||
|
||||
from langchain import LLMChain
|
||||
|
|
|
@ -0,0 +1,84 @@
|
|||
import asyncio
|
||||
from typing import List
|
||||
|
||||
from langchain import LLMChain
|
||||
from langchain.callbacks import AsyncIteratorCallbackHandler
|
||||
from langchain.prompts.chat import ChatPromptTemplate
|
||||
from langchain.agents import AgentType, initialize_agent
|
||||
import langchain
|
||||
from langchain.schema import (
|
||||
AgentAction
|
||||
)
|
||||
|
||||
|
||||
# langchain.debug = True
|
||||
|
||||
from dev_opsgpt.tools import (
|
||||
TOOL_SETS, TOOL_DICT,
|
||||
toLangchainTools, get_tool_schema
|
||||
)
|
||||
from .utils import History, wrap_done
|
||||
from .base_chat import Chat
|
||||
from loguru import logger
|
||||
|
||||
|
||||
def get_tool_agent(tools, llm):
|
||||
return initialize_agent(
|
||||
tools,
|
||||
llm,
|
||||
agent=AgentType.STRUCTURED_CHAT_ZERO_SHOT_REACT_DESCRIPTION,
|
||||
verbose=True,
|
||||
return_intermediate_steps=True
|
||||
)
|
||||
|
||||
|
||||
class ToolChat(Chat):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
engine_name: str = "",
|
||||
top_k: int = 1,
|
||||
stream: bool = False,
|
||||
) -> None:
|
||||
super().__init__(engine_name, top_k, stream)
|
||||
self.tool_prompt = """结合上下文信息,{tools} {input}"""
|
||||
self.tools = toLangchainTools([TOOL_DICT[i] for i in TOOL_SETS if i in TOOL_DICT])
|
||||
|
||||
def create_task(self, query: str, history: List[History], model, **kargs):
|
||||
'''构建 llm 生成任务'''
|
||||
logger.debug("content:{}".format([i.to_msg_tuple() for i in history] + [("human", "{query}")]))
|
||||
# chat_prompt = ChatPromptTemplate.from_messages(
|
||||
# [i.to_msg_tuple() for i in history] + [("human", "{query}")]
|
||||
# )
|
||||
tools = kargs.get("tool_sets", [])
|
||||
tools = toLangchainTools([TOOL_DICT[i] for i in tools if i in TOOL_DICT])
|
||||
agent = get_tool_agent(tools if tools else self.tools, model)
|
||||
content = agent(query)
|
||||
|
||||
logger.debug(f"content: {content}")
|
||||
|
||||
s = ""
|
||||
if isinstance(content, str):
|
||||
s = content
|
||||
else:
|
||||
for i in content["intermediate_steps"]:
|
||||
for j in i:
|
||||
if isinstance(j, AgentAction):
|
||||
s += j.log + "\n"
|
||||
else:
|
||||
s += "Observation: " + str(j) + "\n"
|
||||
|
||||
s += "final answer:" + content["output"]
|
||||
# chain = LLMChain(prompt=chat_prompt, llm=model)
|
||||
# content = chain({"tools": tools, "input": query})
|
||||
return {"answer": "", "docs": ""}, {"text": s}
|
||||
|
||||
def create_atask(self, query, history, model, callback: AsyncIteratorCallbackHandler):
|
||||
chat_prompt = ChatPromptTemplate.from_messages(
|
||||
[i.to_msg_tuple() for i in history] + [("human", self.tool_prompt)]
|
||||
)
|
||||
chain = LLMChain(prompt=chat_prompt, llm=model)
|
||||
task = asyncio.create_task(wrap_done(
|
||||
chain.acall({"input": query}), callback.done
|
||||
))
|
||||
return task, {"answer": "", "docs": ""}
|
|
@ -0,0 +1,7 @@
|
|||
# encoding: utf-8
|
||||
'''
|
||||
@author: 温进
|
||||
@file: __init__.py.py
|
||||
@time: 2023/10/23 下午4:57
|
||||
@desc:
|
||||
'''
|
|
@ -0,0 +1,139 @@
|
|||
# encoding: utf-8
|
||||
'''
|
||||
@author: 温进
|
||||
@file: codebase_handler.py
|
||||
@time: 2023/10/23 下午5:05
|
||||
@desc:
|
||||
'''
|
||||
|
||||
from loguru import logger
|
||||
import time
|
||||
import os
|
||||
|
||||
from dev_opsgpt.codebase_handler.parser.java_paraser.java_crawler import JavaCrawler
|
||||
from dev_opsgpt.codebase_handler.parser.java_paraser.java_preprocess import JavaPreprocessor
|
||||
from dev_opsgpt.codebase_handler.parser.java_paraser.java_dedup import JavaDedup
|
||||
from dev_opsgpt.codebase_handler.parser.java_paraser.java_parser import JavaParser
|
||||
from dev_opsgpt.codebase_handler.tagger.tagger import Tagger
|
||||
from dev_opsgpt.codebase_handler.tagger.tuple_generation import node_edge_update
|
||||
|
||||
from dev_opsgpt.codebase_handler.networkx_handler.networkx_handler import NetworkxHandler
|
||||
from dev_opsgpt.codebase_handler.codedb_handler.local_codedb_handler import LocalCodeDBHandler
|
||||
|
||||
|
||||
class CodeBaseHandler():
|
||||
def __init__(self, code_name: str, code_path: str = '', cb_root_path: str = '', history_node_list: list = []):
|
||||
self.nh = None
|
||||
self.lcdh = None
|
||||
self.code_name = code_name
|
||||
self.code_path = code_path
|
||||
|
||||
self.codebase_path = cb_root_path + os.sep + code_name
|
||||
self.graph_path = self.codebase_path + os.sep + 'graph.pk'
|
||||
self.codedb_path = self.codebase_path + os.sep + 'codedb.pk'
|
||||
|
||||
self.tagger = Tagger()
|
||||
self.history_node_list = history_node_list
|
||||
|
||||
def import_code(self, do_save: bool=False, do_load: bool=False) -> bool:
|
||||
'''
|
||||
import code to codeBase
|
||||
@param code_path:
|
||||
@param do_save:
|
||||
@param do_load:
|
||||
@return: True as success; False as failure
|
||||
'''
|
||||
if do_load:
|
||||
logger.info('start load from codebase_path')
|
||||
load_graph_path = self.graph_path
|
||||
load_codedb_path = self.codedb_path
|
||||
|
||||
st = time.time()
|
||||
self.nh = NetworkxHandler(graph_path=load_graph_path)
|
||||
logger.info('generate graph success, rt={}'.format(time.time() - st))
|
||||
|
||||
st = time.time()
|
||||
self.lcdh = LocalCodeDBHandler(db_path=load_codedb_path)
|
||||
logger.info('generate codedb success, rt={}'.format(time.time() - st))
|
||||
else:
|
||||
logger.info('start load from code_path')
|
||||
st = time.time()
|
||||
java_code_dict = JavaCrawler.local_java_file_crawler(self.code_path)
|
||||
logger.info('crawl success, rt={}'.format(time.time() - st))
|
||||
|
||||
jp = JavaPreprocessor()
|
||||
java_code_dict = jp.preprocess(java_code_dict)
|
||||
|
||||
jd = JavaDedup()
|
||||
java_code_dict = jd.dedup(java_code_dict)
|
||||
|
||||
st = time.time()
|
||||
j_parser = JavaParser()
|
||||
parse_res = j_parser.parse(java_code_dict)
|
||||
logger.info('parse success, rt={}'.format(time.time() - st))
|
||||
|
||||
st = time.time()
|
||||
tagged_code = self.tagger.generate_tag(parse_res)
|
||||
node_list, edge_list = node_edge_update(parse_res.values())
|
||||
logger.info('get node and edge success, rt={}'.format(time.time() - st))
|
||||
|
||||
st = time.time()
|
||||
self.nh = NetworkxHandler(node_list=node_list, edge_list=edge_list)
|
||||
logger.info('generate graph success, rt={}'.format(time.time() - st))
|
||||
|
||||
st = time.time()
|
||||
self.lcdh = LocalCodeDBHandler(tagged_code)
|
||||
logger.info('CodeDB load success, rt={}'.format(time.time() - st))
|
||||
|
||||
if do_save:
|
||||
save_graph_path = self.graph_path
|
||||
save_codedb_path = self.codedb_path
|
||||
self.nh.save_graph(save_graph_path)
|
||||
self.lcdh.save_db(save_codedb_path)
|
||||
|
||||
def search_code(self, query: str, code_limit: int, history_node_list: list = []):
|
||||
'''
|
||||
search code related to query
|
||||
@param self:
|
||||
@param query:
|
||||
@return:
|
||||
'''
|
||||
# get query tag
|
||||
query_tag_list = self.tagger.generate_tag_query(query)
|
||||
|
||||
related_node_score_list = self.nh.search_node_with_score(query_tag_list=query_tag_list,
|
||||
history_node_list=history_node_list)
|
||||
|
||||
score_dict = {
|
||||
i[0]: i[1]
|
||||
for i in related_node_score_list
|
||||
}
|
||||
related_node = [i[0] for i in related_node_score_list]
|
||||
related_score = [i[1] for i in related_node_score_list]
|
||||
|
||||
related_code, code_related_node = self.lcdh.search_by_multi_tag(related_node, lim=code_limit)
|
||||
|
||||
related_node = [
|
||||
(node, self.nh.get_node_type(node), score_dict[node])
|
||||
for node in code_related_node
|
||||
]
|
||||
|
||||
related_node.sort(key=lambda x: x[2], reverse=True)
|
||||
|
||||
logger.info('related_node={}'.format(related_node))
|
||||
logger.info('related_code={}'.format(related_code))
|
||||
logger.info('num of code={}'.format(len(related_code)))
|
||||
return related_code, related_node
|
||||
|
||||
def refresh_history(self):
|
||||
self.history_node_list = []
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
|
@ -0,0 +1,7 @@
|
|||
# encoding: utf-8
|
||||
'''
|
||||
@author: 温进
|
||||
@file: __init__.py.py
|
||||
@time: 2023/10/23 下午5:04
|
||||
@desc:
|
||||
'''
|
|
@ -0,0 +1,55 @@
|
|||
# encoding: utf-8
|
||||
'''
|
||||
@author: 温进
|
||||
@file: local_codedb_handler.py
|
||||
@time: 2023/10/23 下午5:05
|
||||
@desc:
|
||||
'''
|
||||
import pickle
|
||||
|
||||
|
||||
class LocalCodeDBHandler:
|
||||
def __init__(self, tagged_code: dict = {}, db_path: str = ''):
|
||||
if db_path:
|
||||
with open(db_path, 'rb') as f:
|
||||
self.data = pickle.load(f)
|
||||
else:
|
||||
self.data = {}
|
||||
for code, tag in tagged_code.items():
|
||||
self.data[code] = str(tag)
|
||||
|
||||
def search_by_single_tag(self, tag, lim):
|
||||
res = list()
|
||||
for k, v in self.data.items():
|
||||
if tag in v and k not in res:
|
||||
res.append(k)
|
||||
|
||||
if len(res) > lim:
|
||||
break
|
||||
return res
|
||||
|
||||
def search_by_multi_tag(self, tag_list, lim=3):
|
||||
res = list()
|
||||
res_related_node = []
|
||||
for tag in tag_list:
|
||||
single_tag_res = self.search_by_single_tag(tag, lim)
|
||||
for code in single_tag_res:
|
||||
if code not in res:
|
||||
res.append(code)
|
||||
res_related_node.append(tag)
|
||||
if len(res) >= lim:
|
||||
break
|
||||
|
||||
# reverse order so that most relevant one is close to the query
|
||||
res = res[0:lim]
|
||||
res.reverse()
|
||||
|
||||
return res, res_related_node
|
||||
|
||||
def save_db(self, save_path):
|
||||
with open(save_path, 'wb') as f:
|
||||
pickle.dump(self.data, f)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.data)
|
||||
|
|
@ -0,0 +1,7 @@
|
|||
# encoding: utf-8
|
||||
'''
|
||||
@author: 温进
|
||||
@file: __init__.py.py
|
||||
@time: 2023/10/23 下午5:00
|
||||
@desc:
|
||||
'''
|
|
@ -0,0 +1,129 @@
|
|||
# encoding: utf-8
|
||||
'''
|
||||
@author: 温进
|
||||
@file: networkx_handler.py
|
||||
@time: 2023/10/23 下午5:02
|
||||
@desc:
|
||||
'''
|
||||
|
||||
import networkx as nx
|
||||
from loguru import logger
|
||||
import matplotlib.pyplot as plt
|
||||
import pickle
|
||||
from collections import defaultdict
|
||||
import json
|
||||
|
||||
QUERY_SCORE = 10
|
||||
HISTORY_SCORE = 5
|
||||
RATIO = 0.5
|
||||
|
||||
|
||||
class NetworkxHandler:
|
||||
def __init__(self, graph_path: str = '', node_list: list = [], edge_list: list = []):
|
||||
if graph_path:
|
||||
self.graph_path = graph_path
|
||||
with open(graph_path, 'r') as f:
|
||||
self.G = nx.node_link_graph(json.load(f))
|
||||
else:
|
||||
self.G = nx.DiGraph()
|
||||
self.populate_graph(node_list, edge_list)
|
||||
logger.debug(
|
||||
'number of nodes={}, number of edges={}'.format(self.G.number_of_nodes(), self.G.number_of_edges()))
|
||||
|
||||
self.query_score = QUERY_SCORE
|
||||
self.history_score = HISTORY_SCORE
|
||||
self.ratio = RATIO
|
||||
|
||||
def populate_graph(self, node_list, edge_list):
|
||||
'''
|
||||
populate graph with node_list and edge_list
|
||||
'''
|
||||
self.G.add_nodes_from(node_list)
|
||||
for edge in edge_list:
|
||||
self.G.add_edge(edge[0], edge[-1], relation=edge[1])
|
||||
|
||||
def draw_graph(self, save_path: str):
|
||||
'''
|
||||
draw and save to save_path
|
||||
'''
|
||||
sub = plt.subplot(111)
|
||||
nx.draw(self.G, with_labels=True)
|
||||
|
||||
plt.savefig(save_path)
|
||||
|
||||
def search_node(self, query_tag_list: list, history_node_list: list = []):
|
||||
'''
|
||||
search node by tag_list, search from history_tag neighbors first
|
||||
> query_tag_list: tag from query
|
||||
> history_node_list
|
||||
'''
|
||||
node_list = set()
|
||||
|
||||
# search from history_tag_list first, then all nodes
|
||||
for tag in query_tag_list:
|
||||
add = False
|
||||
for history_node in history_node_list:
|
||||
connect_node_list: list = self.G.adj[history_node]
|
||||
connect_node_list.insert(0, history_node)
|
||||
for connect_node in connect_node_list:
|
||||
node_name_lim = len(connect_node) if '_' not in connect_node else connect_node.index('_')
|
||||
node_name = connect_node[0:node_name_lim]
|
||||
if tag.lower() in node_name.lower():
|
||||
node_list.add(connect_node)
|
||||
add = True
|
||||
if not add:
|
||||
for node in self.G.nodes():
|
||||
if tag.lower() in node.lower():
|
||||
node_list.add(node)
|
||||
return node_list
|
||||
|
||||
def search_node_with_score(self, query_tag_list: list, history_node_list: list = []):
|
||||
'''
|
||||
search node by tag_list, search from history_tag neighbors first
|
||||
> query_tag_list: tag from query
|
||||
> history_node_list
|
||||
'''
|
||||
logger.info('query_tag_list={}, history_node_list={}'.format(query_tag_list, history_node_list))
|
||||
node_dict = defaultdict(lambda: 0)
|
||||
|
||||
# loop over query_tag_list and add node:
|
||||
for tag in query_tag_list:
|
||||
for node in self.G.nodes:
|
||||
if tag.lower() in node.lower():
|
||||
node_dict[node] += self.query_score
|
||||
|
||||
# loop over history_node and add node score
|
||||
for node in history_node_list:
|
||||
node_dict[node] += self.history_score
|
||||
|
||||
logger.info('temp_res={}'.format(node_dict))
|
||||
|
||||
# adj score broadcast
|
||||
for node in node_dict:
|
||||
adj_node_list = self.G.adj[node]
|
||||
for adj_node in adj_node_list:
|
||||
node_dict[node] += node_dict.get(adj_node, 0) * self.ratio
|
||||
|
||||
# sort
|
||||
node_list = [(node, node_score) for node, node_score in node_dict.items()]
|
||||
node_list.sort(key=lambda x: x[1], reverse=True)
|
||||
return node_list
|
||||
|
||||
def save_graph(self, save_path: str):
|
||||
to_save = nx.node_link_data(self.G)
|
||||
with open(save_path, 'w') as f:
|
||||
json.dump(to_save, f)
|
||||
|
||||
def __len__(self):
|
||||
return self.G.number_of_nodes()
|
||||
|
||||
def get_node_type(self, node_name):
|
||||
node_type = self.G.nodes[node_name]['type']
|
||||
return node_type
|
||||
|
||||
def refresh_graph(self, ):
|
||||
with open(self.graph_path, 'r') as f:
|
||||
self.G = nx.node_link_graph(json.load(f))
|
||||
|
||||
|
||||
|
|
@ -0,0 +1,7 @@
|
|||
# encoding: utf-8
|
||||
'''
|
||||
@author: 温进
|
||||
@file: __init__.py.py
|
||||
@time: 2023/10/23 下午5:00
|
||||
@desc:
|
||||
'''
|
|
@ -0,0 +1,7 @@
|
|||
# encoding: utf-8
|
||||
'''
|
||||
@author: 温进
|
||||
@file: __init__.py.py
|
||||
@time: 2023/10/23 下午5:01
|
||||
@desc:
|
||||
'''
|
|
@ -0,0 +1,32 @@
|
|||
# encoding: utf-8
|
||||
'''
|
||||
@author: 温进
|
||||
@file: java_crawler.py
|
||||
@time: 2023/10/23 下午5:02
|
||||
@desc:
|
||||
'''
|
||||
|
||||
import os
|
||||
import glob
|
||||
from loguru import logger
|
||||
|
||||
|
||||
class JavaCrawler:
|
||||
@staticmethod
|
||||
def local_java_file_crawler(path: str):
|
||||
'''
|
||||
read local java file in path
|
||||
> path: path to crawl, must be absolute path like A/B/C
|
||||
< dict of java code string
|
||||
'''
|
||||
java_file_list = glob.glob('{path}{sep}**{sep}*.java'.format(path=path, sep=os.path.sep), recursive=True)
|
||||
java_code_dict = {}
|
||||
|
||||
logger.debug('number of file={}'.format(len(java_file_list)))
|
||||
# logger.debug(java_file_list)
|
||||
|
||||
for java_file in java_file_list:
|
||||
with open(java_file) as f:
|
||||
java_code = ''.join(f.readlines())
|
||||
java_code_dict[java_file] = java_code
|
||||
return java_code_dict
|
|
@ -0,0 +1,15 @@
|
|||
# encoding: utf-8
|
||||
'''
|
||||
@author: 温进
|
||||
@file: java_dedup.py
|
||||
@time: 2023/10/23 下午5:02
|
||||
@desc:
|
||||
'''
|
||||
|
||||
|
||||
class JavaDedup:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def dedup(self, java_code_dict):
|
||||
return java_code_dict
|
|
@ -0,0 +1,107 @@
|
|||
# encoding: utf-8
|
||||
'''
|
||||
@author: 温进
|
||||
@file: java_parser.py
|
||||
@time: 2023/10/23 下午5:03
|
||||
@desc:
|
||||
'''
|
||||
import json
|
||||
import javalang
|
||||
import glob
|
||||
import os
|
||||
from loguru import logger
|
||||
|
||||
|
||||
class JavaParser:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def parse(self, java_code_list):
|
||||
'''
|
||||
parse java code and extract entity
|
||||
'''
|
||||
tree_dict = self.preparse(java_code_list)
|
||||
res = self.multi_java_code_parse(tree_dict)
|
||||
|
||||
return res
|
||||
|
||||
def preparse(self, java_code_dict):
|
||||
'''
|
||||
preparse by javalang
|
||||
< dict of java_code and tree
|
||||
'''
|
||||
tree_dict = {}
|
||||
for fp, java_code in java_code_dict.items():
|
||||
try:
|
||||
tree = javalang.parse.parse(java_code)
|
||||
except Exception as e:
|
||||
continue
|
||||
|
||||
if tree.package is not None:
|
||||
tree_dict[java_code] = tree
|
||||
logger.info('success parse {} files'.format(len(tree_dict)))
|
||||
return tree_dict
|
||||
|
||||
def single_java_code_parse(self, tree):
|
||||
'''
|
||||
parse single code file
|
||||
> tree: javalang parse result
|
||||
< {pac_name: '', class_name_list: [], func_name_list: [], import_pac_name_list: []]}
|
||||
'''
|
||||
import_pac_name_list = []
|
||||
|
||||
# get imports
|
||||
import_list = tree.imports
|
||||
|
||||
for import_pac in import_list:
|
||||
import_pac_name = import_pac.path
|
||||
import_pac_name_list.append(import_pac_name)
|
||||
|
||||
pac_name = tree.package.name
|
||||
class_name_list = []
|
||||
func_name_dict = {}
|
||||
|
||||
for node in tree.types:
|
||||
if type(node) in (javalang.tree.ClassDeclaration, javalang.tree.InterfaceDeclaration):
|
||||
class_name = pac_name + '.' + node.name
|
||||
class_name_list.append(class_name)
|
||||
|
||||
for node_inner in node.body:
|
||||
if type(node_inner) is javalang.tree.MethodDeclaration:
|
||||
func_name = class_name + '.' + node_inner.name
|
||||
|
||||
# add params name to func_name
|
||||
params_list = node_inner.parameters
|
||||
|
||||
for params in params_list:
|
||||
params_name = params.type.name
|
||||
func_name = func_name + '_' + params_name
|
||||
|
||||
if class_name not in func_name_dict:
|
||||
func_name_dict[class_name] = []
|
||||
|
||||
func_name_dict[class_name].append(func_name)
|
||||
|
||||
res = {
|
||||
'pac_name': pac_name,
|
||||
'class_name_list': class_name_list,
|
||||
'func_name_dict': func_name_dict,
|
||||
'import_pac_name_list': import_pac_name_list
|
||||
}
|
||||
return res
|
||||
|
||||
def multi_java_code_parse(self, tree_dict):
|
||||
'''
|
||||
parse multiple java code
|
||||
> tree_list
|
||||
< parse_result_dict
|
||||
'''
|
||||
res_dict = {}
|
||||
for java_code, tree in tree_dict.items():
|
||||
try:
|
||||
res_dict[java_code] = self.single_java_code_parse(tree)
|
||||
except Exception as e:
|
||||
logger.debug(java_code)
|
||||
raise ImportError
|
||||
|
||||
return res_dict
|
|
@ -0,0 +1,14 @@
|
|||
# encoding: utf-8
|
||||
'''
|
||||
@author: 温进
|
||||
@file: java_preprocess.py
|
||||
@time: 2023/10/23 下午5:04
|
||||
@desc:
|
||||
'''
|
||||
|
||||
class JavaPreprocessor:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def preprocess(self, java_code_dict):
|
||||
return java_code_dict
|
|
@ -0,0 +1,7 @@
|
|||
# encoding: utf-8
|
||||
'''
|
||||
@author: 温进
|
||||
@file: __init__.py.py
|
||||
@time: 2023/10/23 下午5:00
|
||||
@desc:
|
||||
'''
|
|
@ -0,0 +1,48 @@
|
|||
# encoding: utf-8
|
||||
'''
|
||||
@author: 温进
|
||||
@file: tagger.py
|
||||
@time: 2023/10/23 下午5:01
|
||||
@desc:
|
||||
'''
|
||||
import re
|
||||
from loguru import logger
|
||||
|
||||
|
||||
class Tagger:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def generate_tag(self, parse_res_dict: dict):
|
||||
'''
|
||||
generate tag from parse_res
|
||||
'''
|
||||
res = {}
|
||||
for java_code, parse_res in parse_res_dict.items():
|
||||
tag = {}
|
||||
tag['pac_name'] = parse_res.get('pac_name')
|
||||
tag['class_name'] = set(parse_res.get('class_name_list'))
|
||||
tag['func_name'] = set()
|
||||
|
||||
for _, func_name_list in parse_res.get('func_name_dict', {}).items():
|
||||
tag['func_name'].update(func_name_list)
|
||||
|
||||
res[java_code] = tag
|
||||
return res
|
||||
|
||||
def generate_tag_query(self, query):
|
||||
'''
|
||||
generate tag from query
|
||||
'''
|
||||
# simple extract english
|
||||
tag_list = re.findall(r'[a-zA-Z\_\.]+', query)
|
||||
tag_list = list(set(tag_list))
|
||||
return tag_list
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
tagger = Tagger()
|
||||
logger.debug(tagger.generate_tag_query('com.CheckHolder 有哪些函数'))
|
||||
|
||||
|
||||
|
|
@ -0,0 +1,51 @@
|
|||
# encoding: utf-8
|
||||
'''
|
||||
@author: 温进
|
||||
@file: tuple_generation.py
|
||||
@time: 2023/10/23 下午5:01
|
||||
@desc:
|
||||
'''
|
||||
|
||||
|
||||
def node_edge_update(parse_res_list: list, node_list: list = list(), edge_list: list = list()):
|
||||
'''
|
||||
generate node and edge by parse_res
|
||||
< node: list of string node
|
||||
< edge: (node_st, relation, node_ed)
|
||||
'''
|
||||
node_dict = {i: j for i, j in node_list}
|
||||
|
||||
for single_parse_res in parse_res_list:
|
||||
pac_name = single_parse_res['pac_name']
|
||||
|
||||
node_dict[pac_name] = {'type': 'package'}
|
||||
|
||||
# class_name
|
||||
for class_name in single_parse_res['class_name_list']:
|
||||
node_dict[class_name] = {'type': 'class'}
|
||||
edge_list.append((pac_name, 'contain', class_name))
|
||||
edge_list.append((class_name, 'inside', pac_name))
|
||||
|
||||
# func_name
|
||||
for class_name, func_name_list in single_parse_res['func_name_dict'].items():
|
||||
node_list.append(class_name)
|
||||
for func_name in func_name_list:
|
||||
node_dict[func_name] = {'type': 'func'}
|
||||
edge_list.append((class_name, 'contain', func_name))
|
||||
edge_list.append((func_name, 'inside', class_name))
|
||||
|
||||
# depend
|
||||
for depend_pac_name in single_parse_res['import_pac_name_list']:
|
||||
if depend_pac_name.endswith('*'):
|
||||
depend_pac_name = depend_pac_name[0:-2]
|
||||
|
||||
if depend_pac_name in node_dict:
|
||||
continue
|
||||
else:
|
||||
node_dict[depend_pac_name] = {'type': 'unknown'}
|
||||
edge_list.append((pac_name, 'depend', depend_pac_name))
|
||||
edge_list.append((depend_pac_name, 'beDepended', pac_name))
|
||||
|
||||
node_list = [(i, j) for i, j in node_dict.items()]
|
||||
|
||||
return node_list, edge_list
|
|
@ -0,0 +1,9 @@
|
|||
from .configs import PHASE_CONFIGS
|
||||
|
||||
|
||||
|
||||
PHASE_LIST = list(PHASE_CONFIGS.keys())
|
||||
|
||||
__all__ = [
|
||||
"PHASE_CONFIGS"
|
||||
]
|
|
@ -0,0 +1,6 @@
|
|||
from .base_agent import BaseAgent
|
||||
from .react_agent import ReactAgent
|
||||
|
||||
__all__ = [
|
||||
"BaseAgent", "ReactAgent"
|
||||
]
|
|
@ -0,0 +1,427 @@
|
|||
from pydantic import BaseModel
|
||||
from typing import List, Union
|
||||
import re
|
||||
import copy
|
||||
import json
|
||||
import traceback
|
||||
import uuid
|
||||
from loguru import logger
|
||||
|
||||
from dev_opsgpt.connector.shcema.memory import Memory
|
||||
from dev_opsgpt.connector.connector_schema import (
|
||||
Task, Role, Message, ActionStatus, Doc, CodeDoc
|
||||
)
|
||||
from configs.server_config import SANDBOX_SERVER
|
||||
from dev_opsgpt.sandbox import PyCodeBox, CodeBoxResponse
|
||||
from dev_opsgpt.tools import DDGSTool, DocRetrieval, CodeRetrieval
|
||||
from dev_opsgpt.connector.configs.agent_config import REACT_PROMPT_INPUT
|
||||
|
||||
from dev_opsgpt.llm_models import getChatModel
|
||||
|
||||
|
||||
class BaseAgent:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
role: Role,
|
||||
task: Task = None,
|
||||
memory: Memory = None,
|
||||
chat_turn: int = 1,
|
||||
do_search: bool = False,
|
||||
do_doc_retrieval: bool = False,
|
||||
do_tool_retrieval: bool = False,
|
||||
temperature: float = 0.2,
|
||||
stop: Union[List[str], str] = None,
|
||||
do_filter: bool = True,
|
||||
do_use_self_memory: bool = True,
|
||||
# docs_prompt: str,
|
||||
# prompt_mamnger: PromptManager
|
||||
):
|
||||
|
||||
self.task = task
|
||||
self.role = role
|
||||
self.llm = self.create_llm_engine(temperature, stop)
|
||||
self.memory = self.init_history(memory)
|
||||
self.chat_turn = chat_turn
|
||||
self.do_search = do_search
|
||||
self.do_doc_retrieval = do_doc_retrieval
|
||||
self.do_tool_retrieval = do_tool_retrieval
|
||||
self.codebox = PyCodeBox(
|
||||
remote_url=SANDBOX_SERVER["url"],
|
||||
remote_ip=SANDBOX_SERVER["host"],
|
||||
remote_port=SANDBOX_SERVER["port"],
|
||||
token="mytoken",
|
||||
do_code_exe=True,
|
||||
do_remote=SANDBOX_SERVER["do_remote"],
|
||||
do_check_net=False
|
||||
)
|
||||
self.do_filter = do_filter
|
||||
self.do_use_self_memory = do_use_self_memory
|
||||
# self.docs_prompt = docs_prompt
|
||||
# self.prompt_manager = None
|
||||
|
||||
def run(self, query: Message, history: Memory = None, background: Memory = None) -> Message:
|
||||
'''llm inference'''
|
||||
# insert query into memory
|
||||
query_c = copy.deepcopy(query)
|
||||
|
||||
self_memory = self.memory if self.do_use_self_memory else None
|
||||
prompt = self.create_prompt(query_c, self_memory, history, background)
|
||||
content = self.llm.predict(prompt)
|
||||
logger.debug(f"{self.role.role_name} prompt: {prompt}")
|
||||
# logger.debug(f"{self.role.role_name} content: {content}")
|
||||
|
||||
output_message = Message(
|
||||
role_name=self.role.role_name,
|
||||
role_type="ai", #self.role.role_type,
|
||||
role_content=content,
|
||||
role_contents=[content],
|
||||
input_query=query_c.input_query,
|
||||
tools=query_c.tools
|
||||
)
|
||||
|
||||
output_message = self.parser(output_message)
|
||||
if self.do_filter:
|
||||
output_message = self.filter(output_message)
|
||||
|
||||
|
||||
# 更新自身的回答
|
||||
self.append_history(query_c)
|
||||
self.append_history(output_message)
|
||||
logger.info(f"{self.role.role_name} step_run: {output_message.role_content}")
|
||||
return output_message
|
||||
|
||||
def create_prompt(
|
||||
self, query: Message, memory: Memory =None, history: Memory = None, background: Memory = None, prompt_mamnger=None) -> str:
|
||||
'''
|
||||
role\task\tools\docs\memory
|
||||
'''
|
||||
#
|
||||
doc_infos = self.create_doc_prompt(query)
|
||||
code_infos = self.create_codedoc_prompt(query)
|
||||
#
|
||||
formatted_tools, tool_names = self.create_tools_prompt(query)
|
||||
task_prompt = self.create_task_prompt(query)
|
||||
background_prompt = self.create_background_prompt(background, control_key="step_content")
|
||||
history_prompt = self.create_history_prompt(history)
|
||||
selfmemory_prompt = self.create_selfmemory_prompt(memory, control_key="step_content")
|
||||
#
|
||||
# extra_system_prompt = self.role.role_prompt
|
||||
|
||||
prompt = self.role.role_prompt.format(**{"formatted_tools": formatted_tools, "tool_names": tool_names})
|
||||
|
||||
task = query.task or self.task
|
||||
if task_prompt is not None:
|
||||
prompt += "\n" + task.task_prompt
|
||||
|
||||
if doc_infos is not None and doc_infos!="" and doc_infos!="不存在知识库辅助信息":
|
||||
prompt += f"\n知识库信息: {doc_infos}"
|
||||
|
||||
if code_infos is not None and code_infos!="" and code_infos!="不存在代码库辅助信息":
|
||||
prompt += f"\n代码库信息: {code_infos}"
|
||||
|
||||
if background_prompt:
|
||||
prompt += "\n" + background_prompt
|
||||
|
||||
if history_prompt:
|
||||
prompt += "\n" + history_prompt
|
||||
|
||||
if selfmemory_prompt:
|
||||
prompt += "\n" + selfmemory_prompt
|
||||
|
||||
# input_query = memory.to_tuple_messages(content_key="step_content")
|
||||
# input_query = "\n".join([f"{k}: {v}" for k, v in input_query if v])
|
||||
|
||||
input_query = query.role_content
|
||||
|
||||
# logger.debug(f"{self.role.role_name} extra_system_prompt: {self.role.role_prompt}")
|
||||
logger.debug(f"{self.role.role_name} input_query: {input_query}")
|
||||
# logger.debug(f"{self.role.role_name} doc_infos: {doc_infos}")
|
||||
logger.debug(f"{self.role.role_name} tool_names: {tool_names}")
|
||||
prompt += "\n" + REACT_PROMPT_INPUT.format(**{"query": input_query})
|
||||
# prompt = extra_system_prompt.format(**{"query": input_query, "doc_infos": doc_infos, "formatted_tools": formatted_tools, "tool_names": tool_names})
|
||||
while "{{" in prompt or "}}" in prompt:
|
||||
prompt = prompt.replace("{{", "{")
|
||||
prompt = prompt.replace("}}", "}")
|
||||
return prompt
|
||||
|
||||
# prompt_comp = [("system", extra_system_prompt)] + memory.to_tuple_messages()
|
||||
# prompt = ChatPromptTemplate.from_messages(prompt_comp)
|
||||
# prompt = prompt.format(**{"query": query.role_content, "doc_infos": doc_infos, "formatted_tools": formatted_tools, "tool_names": tool_names})
|
||||
# return prompt
|
||||
|
||||
def create_doc_prompt(self, message: Message) -> str:
|
||||
''''''
|
||||
db_docs = message.db_docs
|
||||
search_docs = message.search_docs
|
||||
doc_infos = "\n".join([doc.get_snippet() for doc in db_docs] + [doc.get_snippet() for doc in search_docs])
|
||||
return doc_infos or "不存在知识库辅助信息"
|
||||
|
||||
def create_codedoc_prompt(self, message: Message) -> str:
|
||||
''''''
|
||||
code_docs = message.code_docs
|
||||
doc_infos = "\n".join([doc.get_code() for doc in code_docs])
|
||||
return doc_infos or "不存在代码库辅助信息"
|
||||
|
||||
def create_tools_prompt(self, message: Message) -> str:
|
||||
tools = message.tools
|
||||
tool_strings = []
|
||||
for tool in tools:
|
||||
args_schema = re.sub("}", "}}}}", re.sub("{", "{{{{", str(tool.args)))
|
||||
tool_strings.append(f"{tool.name}: {tool.description}, args: {args_schema}")
|
||||
formatted_tools = "\n".join(tool_strings)
|
||||
tool_names = ", ".join([tool.name for tool in tools])
|
||||
return formatted_tools, tool_names
|
||||
|
||||
def create_task_prompt(self, message: Message) -> str:
|
||||
task = message.task or self.task
|
||||
return "\n任务目标: " + task.task_prompt if task is not None else None
|
||||
|
||||
def create_background_prompt(self, background: Memory, control_key="role_content") -> str:
|
||||
background_message = None if background is None else background.to_str_messages(content_key=control_key)
|
||||
# logger.debug(f"background_message: {background_message}")
|
||||
if background_message:
|
||||
background_message = re.sub("}", "}}", re.sub("{", "{{", background_message))
|
||||
return "\n背景信息: " + background_message if background_message else None
|
||||
|
||||
def create_history_prompt(self, history: Memory, control_key="role_content") -> str:
|
||||
history_message = None if history is None else history.to_str_messages(content_key=control_key)
|
||||
if history_message:
|
||||
history_message = re.sub("}", "}}", re.sub("{", "{{", history_message))
|
||||
return "\n补充对话信息: " + history_message if history_message else None
|
||||
|
||||
def create_selfmemory_prompt(self, selfmemory: Memory, control_key="role_content") -> str:
|
||||
selfmemory_message = None if selfmemory is None else selfmemory.to_str_messages(content_key=control_key)
|
||||
if selfmemory_message:
|
||||
selfmemory_message = re.sub("}", "}}", re.sub("{", "{{", selfmemory_message))
|
||||
return "\n补充自身对话信息: " + selfmemory_message if selfmemory_message else None
|
||||
|
||||
def init_history(self, memory: Memory = None) -> Memory:
|
||||
return Memory([])
|
||||
|
||||
def update_history(self, message: Message):
|
||||
self.memory.append(message)
|
||||
|
||||
def append_history(self, message: Message):
|
||||
self.memory.append(message)
|
||||
|
||||
def clear_history(self, ):
|
||||
self.memory.clear()
|
||||
self.memory = self.init_history()
|
||||
|
||||
def create_llm_engine(self, temperature=0.2, stop=None):
|
||||
return getChatModel(temperature=temperature, stop=stop)
|
||||
|
||||
def filter(self, message: Message, stop=None) -> Message:
|
||||
|
||||
tool_params = self.parser_spec_key(message.role_content, "tool_params")
|
||||
code_content = self.parser_spec_key(message.role_content, "code_content")
|
||||
plan = self.parser_spec_key(message.role_content, "plan")
|
||||
plans = self.parser_spec_key(message.role_content, "plans", do_search=False)
|
||||
content = self.parser_spec_key(message.role_content, "content", do_search=False)
|
||||
|
||||
# logger.debug(f"tool_params: {tool_params}, code_content: {code_content}, plan: {plan}, plans: {plans}, content: {content}")
|
||||
role_content = tool_params or code_content or plan or plans or content
|
||||
message.role_content = role_content or message.role_content
|
||||
return message
|
||||
|
||||
def token_usage(self, ):
|
||||
pass
|
||||
|
||||
def get_extra_infos(self, message: Message) -> Message:
|
||||
''''''
|
||||
if self.do_search:
|
||||
message = self.get_search_retrieval(message)
|
||||
|
||||
if self.do_doc_retrieval:
|
||||
message = self.get_doc_retrieval(message)
|
||||
|
||||
if self.do_tool_retrieval:
|
||||
message = self.get_tool_retrieval(message)
|
||||
|
||||
return message
|
||||
|
||||
def get_search_retrieval(self, message: Message,) -> Message:
|
||||
SEARCH_ENGINES = {"duckduckgo": DDGSTool}
|
||||
search_docs = []
|
||||
for idx, doc in enumerate(SEARCH_ENGINES["duckduckgo"].run(message.role_content, 3)):
|
||||
doc.update({"index": idx})
|
||||
search_docs.append(Doc(**doc))
|
||||
message.search_docs = search_docs
|
||||
return message
|
||||
|
||||
def get_doc_retrieval(self, message: Message) -> Message:
|
||||
query = message.role_content
|
||||
knowledge_basename = message.doc_engine_name
|
||||
top_k = message.top_k
|
||||
score_threshold = message.score_threshold
|
||||
if knowledge_basename:
|
||||
docs = DocRetrieval.run(query, knowledge_basename, top_k, score_threshold)
|
||||
message.db_docs = [Doc(**doc) for doc in docs]
|
||||
return message
|
||||
|
||||
def get_code_retrieval(self, message: Message) -> Message:
|
||||
# DocRetrieval.run("langchain是什么", "DSADSAD")
|
||||
query = message.input_query
|
||||
code_engine_name = message.code_engine_name
|
||||
history_node_list = message.history_node_list
|
||||
code_docs = CodeRetrieval.run(code_engine_name, query, code_limit=message.top_k, history_node_list=history_node_list)
|
||||
message.code_docs = [CodeDoc(**doc) for doc in code_docs]
|
||||
return message
|
||||
|
||||
def get_tool_retrieval(self, message: Message) -> Message:
|
||||
return message
|
||||
|
||||
def step_router(self, message: Message) -> Message:
|
||||
''''''
|
||||
# message = self.parser(message)
|
||||
# logger.debug(f"message.action_status: {message.action_status}")
|
||||
if message.action_status == ActionStatus.CODING:
|
||||
message = self.code_step(message)
|
||||
elif message.action_status == ActionStatus.TOOL_USING:
|
||||
message = self.tool_step(message)
|
||||
|
||||
return message
|
||||
|
||||
def code_step(self, message: Message) -> Message:
|
||||
'''execute code'''
|
||||
# logger.debug(f"message.role_content: {message.role_content}, message.code_content: {message.code_content}")
|
||||
code_answer = self.codebox.chat('```python\n{}```'.format(message.code_content))
|
||||
code_prompt = f"执行上述代码后存在报错信息为 {code_answer.code_exe_response},需要进行修复" \
|
||||
if code_answer.code_exe_type == "error" else f"执行上述代码后返回信息为 {code_answer.code_exe_response}"
|
||||
uid = str(uuid.uuid1())
|
||||
if code_answer.code_exe_type == "image/png":
|
||||
message.figures[uid] = code_answer.code_exe_response
|
||||
message.code_answer = f"\n观察: 执行上述代码后生成一张图片, 图片名为{uid}\n"
|
||||
message.observation = f"\n观察: 执行上述代码后生成一张图片, 图片名为{uid}\n"
|
||||
message.step_content += f"\n观察: 执行上述代码后生成一张图片, 图片名为{uid}\n"
|
||||
message.step_contents += [f"\n观察: 执行上述代码后生成一张图片, 图片名为{uid}\n"]
|
||||
message.role_content += f"\n观察:执行上述代码后生成一张图片, 图片名为{uid}\n"
|
||||
else:
|
||||
message.code_answer = code_answer.code_exe_response
|
||||
message.observation = code_answer.code_exe_response
|
||||
message.step_content += f"\n观察: {code_prompt}\n"
|
||||
message.step_contents += [f"\n观察: {code_prompt}\n"]
|
||||
message.role_content += f"\n观察: {code_prompt}\n"
|
||||
# logger.info(f"观察: {message.action_status}, {message.observation}")
|
||||
return message
|
||||
|
||||
def tool_step(self, message: Message) -> Message:
|
||||
'''execute tool'''
|
||||
# logger.debug(f"message: {message.action_status}, {message.tool_name}, {message.tool_params}")
|
||||
tool_names = [tool.name for tool in message.tools]
|
||||
if message.tool_name not in tool_names:
|
||||
message.tool_answer = "不存在可以执行的tool"
|
||||
message.observation = "不存在可以执行的tool"
|
||||
message.role_content += f"\n观察: 不存在可以执行的tool\n"
|
||||
message.step_content += f"\n观察: 不存在可以执行的tool\n"
|
||||
message.step_contents += [f"\n观察: 不存在可以执行的tool\n"]
|
||||
for tool in message.tools:
|
||||
if tool.name == message.tool_name:
|
||||
tool_res = tool.func(**message.tool_params)
|
||||
message.tool_answer = tool_res
|
||||
message.observation = tool_res
|
||||
message.role_content += f"\n观察: {tool_res}\n"
|
||||
message.step_content += f"\n观察: {tool_res}\n"
|
||||
message.step_contents += [f"\n观察: {tool_res}\n"]
|
||||
|
||||
# logger.info(f"观察: {message.action_status}, {message.observation}")
|
||||
return message
|
||||
|
||||
def parser(self, message: Message) -> Message:
|
||||
''''''
|
||||
content = message.role_content
|
||||
parser_keys = ["action", "code_content", "code_filename", "tool_params", "plans"]
|
||||
try:
|
||||
s_json = self._parse_json(content)
|
||||
message.action_status = s_json.get("action")
|
||||
message.code_content = s_json.get("code_content")
|
||||
message.tool_params = s_json.get("tool_params")
|
||||
message.tool_name = s_json.get("tool_name")
|
||||
message.code_filename = s_json.get("code_filename")
|
||||
message.plans = s_json.get("plans")
|
||||
# for parser_key in parser_keys:
|
||||
# message.action_status = content.get(parser_key)
|
||||
except Exception as e:
|
||||
# logger.warning(f"{traceback.format_exc()}")
|
||||
action_value = self._match(r"'action':\s*'([^']*)'", content) if "'action'" in content else self._match(r'"action":\s*"([^"]*)"', content)
|
||||
code_content_value = self._match(r"'code_content':\s*'([^']*)'", content) if "'code_content'" in content else self._match(r'"code_content":\s*"([^"]*)"', content)
|
||||
filename_value = self._match(r"'code_filename':\s*'([^']*)'", content) if "'code_filename'" in content else self._match(r'"code_filename":\s*"([^"]*)"', content)
|
||||
tool_params_value = self._match(r"'tool_params':\s*(\{[^{}]*\})", content, do_json=True) if "'tool_params'" in content \
|
||||
else self._match(r'"tool_params":\s*(\{[^{}]*\})', content, do_json=True)
|
||||
tool_name_value = self._match(r"'tool_name':\s*'([^']*)'", content) if "'tool_name'" in content else self._match(r'"tool_name":\s*"([^"]*)"', content)
|
||||
plans_value = self._match(r"'plans':\s*(\[.*?\])", content, do_search=False) if "'plans'" in content else self._match(r'"plans":\s*(\[.*?\])', content, do_search=False, )
|
||||
# re解析
|
||||
message.action_status = action_value or "default"
|
||||
message.code_content = code_content_value
|
||||
message.code_filename = filename_value
|
||||
message.tool_params = tool_params_value
|
||||
message.tool_name = tool_name_value
|
||||
message.plans = plans_value
|
||||
|
||||
# logger.debug(f"确认当前的action: {message.action_status}")
|
||||
|
||||
return message
|
||||
|
||||
def parser_spec_key(self, content, key, do_search=True, do_json=False) -> str:
|
||||
''''''
|
||||
key2pattern = {
|
||||
"'action'": r"'action':\s*'([^']*)'", '"action"': r'"action":\s*"([^"]*)"',
|
||||
"'code_content'": r"'code_content':\s*'([^']*)'", '"code_content"': r'"code_content":\s*"([^"]*)"',
|
||||
"'code_filename'": r"'code_filename':\s*'([^']*)'", '"code_filename"': r'"code_filename":\s*"([^"]*)"',
|
||||
"'tool_params'": r"'tool_params':\s*(\{[^{}]*\})", '"tool_params"': r'"tool_params":\s*(\{[^{}]*\})',
|
||||
"'tool_name'": r"'tool_name':\s*'([^']*)'", '"tool_name"': r'"tool_name":\s*"([^"]*)"',
|
||||
"'plans'": r"'plans':\s*(\[.*?\])", '"plans"': r'"plans":\s*(\[.*?\])',
|
||||
"'content'": r"'content':\s*'([^']*)'", '"content"': r'"content":\s*"([^"]*)"',
|
||||
}
|
||||
|
||||
s_json = self._parse_json(content)
|
||||
try:
|
||||
if s_json and key in s_json:
|
||||
return str(s_json[key])
|
||||
except:
|
||||
pass
|
||||
|
||||
keystr = f"'{key}'" if f"'{key}'" in content else f'"{key}"'
|
||||
return self._match(key2pattern.get(keystr, fr"'{key}':\s*'([^']*)'"), content, do_search=do_search, do_json=do_json)
|
||||
|
||||
def _match(self, pattern, s, do_search=True, do_json=False):
|
||||
try:
|
||||
if do_search:
|
||||
match = re.search(pattern, s)
|
||||
if match:
|
||||
value = match.group(1).replace("\\n", "\n")
|
||||
if do_json:
|
||||
value = json.loads(value)
|
||||
else:
|
||||
value = None
|
||||
else:
|
||||
match = re.findall(pattern, s, re.DOTALL)
|
||||
if match:
|
||||
value = match[0]
|
||||
if do_json:
|
||||
value = json.loads(value)
|
||||
else:
|
||||
value = None
|
||||
except Exception as e:
|
||||
logger.warning(f"{traceback.format_exc()}")
|
||||
|
||||
# logger.debug(f"pattern: {pattern}, s: {s}, match: {match}")
|
||||
return value
|
||||
|
||||
def _parse_json(self, s):
|
||||
try:
|
||||
pattern = r"```([^`]+)```"
|
||||
match = re.findall(pattern, s)
|
||||
if match:
|
||||
return eval(match[0])
|
||||
except:
|
||||
pass
|
||||
return None
|
||||
|
||||
|
||||
def get_memory(self, ):
|
||||
return self.memory.to_tuple_messages(content_key="step_content")
|
||||
|
||||
def get_memory_str(self, ):
|
||||
return "\n".join([": ".join(i) for i in self.memory.to_tuple_messages(content_key="step_content")])
|
|
@ -0,0 +1,138 @@
|
|||
from pydantic import BaseModel
|
||||
from typing import List, Union
|
||||
import re
|
||||
import traceback
|
||||
import copy
|
||||
from loguru import logger
|
||||
|
||||
from langchain.prompts.chat import ChatPromptTemplate
|
||||
|
||||
from dev_opsgpt.connector.connector_schema import Message
|
||||
from dev_opsgpt.connector.shcema.memory import Memory
|
||||
from dev_opsgpt.connector.connector_schema import Task, Env, Role, Message, ActionStatus
|
||||
from dev_opsgpt.llm_models import getChatModel
|
||||
from dev_opsgpt.connector.configs.agent_config import REACT_PROMPT_INPUT
|
||||
|
||||
from .base_agent import BaseAgent
|
||||
|
||||
|
||||
class ReactAgent(BaseAgent):
|
||||
def __init__(
|
||||
self,
|
||||
role: Role,
|
||||
task: Task = None,
|
||||
memory: Memory = None,
|
||||
chat_turn: int = 1,
|
||||
do_search: bool = False,
|
||||
do_doc_retrieval: bool = False,
|
||||
do_tool_retrieval: bool = False,
|
||||
temperature: float = 0.2,
|
||||
stop: Union[List[str], str] = "观察",
|
||||
do_filter: bool = True,
|
||||
do_use_self_memory: bool = True,
|
||||
# docs_prompt: str,
|
||||
# prompt_mamnger: PromptManager
|
||||
):
|
||||
super().__init__(role, task, memory, chat_turn, do_search, do_doc_retrieval,
|
||||
do_tool_retrieval, temperature, stop, do_filter,do_use_self_memory
|
||||
)
|
||||
|
||||
def run(self, query: Message, history: Memory = None, background: Memory = None) -> Message:
|
||||
step_nums = copy.deepcopy(self.chat_turn)
|
||||
react_memory = Memory([])
|
||||
# 问题插入
|
||||
output_message = Message(
|
||||
role_name=self.role.role_name,
|
||||
role_type="ai", #self.role.role_type,
|
||||
role_content=query.input_query,
|
||||
step_content=query.input_query,
|
||||
input_query=query.input_query,
|
||||
tools=query.tools
|
||||
)
|
||||
react_memory.append(output_message)
|
||||
idx = 0
|
||||
while step_nums > 0:
|
||||
output_message.role_content = output_message.step_content
|
||||
self_memory = self.memory if self.do_use_self_memory else None
|
||||
prompt = self.create_prompt(query, self_memory, history, background, react_memory)
|
||||
try:
|
||||
content = self.llm.predict(prompt)
|
||||
except Exception as e:
|
||||
logger.warning(f"error prompt: {prompt}")
|
||||
raise Exception(traceback.format_exc())
|
||||
|
||||
output_message.role_content = content
|
||||
output_message.role_contents += [content]
|
||||
output_message.step_content += output_message.role_content
|
||||
output_message.step_contents + [output_message.role_content]
|
||||
|
||||
# logger.debug(f"{self.role.role_name}, {idx} iteration prompt: {prompt}")
|
||||
# logger.info(f"{self.role.role_name}, {idx} iteration step_run: {output_message.role_content}")
|
||||
|
||||
output_message = self.parser(output_message)
|
||||
# when get finished signal can stop early
|
||||
if output_message.action_status == ActionStatus.FINISHED: break
|
||||
# according the output to choose one action for code_content or tool_content
|
||||
output_message = self.step_router(output_message)
|
||||
logger.info(f"{self.role.role_name} react_run: {output_message.role_content}")
|
||||
|
||||
idx += 1
|
||||
step_nums -= 1
|
||||
# react' self_memory saved at last
|
||||
self.append_history(output_message)
|
||||
return output_message
|
||||
|
||||
def create_prompt(
|
||||
self, query: Message, memory: Memory =None, history: Memory = None, background: Memory = None, react_memory: Memory = None, prompt_mamnger=None) -> str:
|
||||
'''
|
||||
role\task\tools\docs\memory
|
||||
'''
|
||||
#
|
||||
doc_infos = self.create_doc_prompt(query)
|
||||
code_infos = self.create_codedoc_prompt(query)
|
||||
#
|
||||
formatted_tools, tool_names = self.create_tools_prompt(query)
|
||||
task_prompt = self.create_task_prompt(query)
|
||||
background_prompt = self.create_background_prompt(background)
|
||||
history_prompt = self.create_history_prompt(history)
|
||||
selfmemory_prompt = self.create_selfmemory_prompt(memory, control_key="step_content")
|
||||
#
|
||||
# extra_system_prompt = self.role.role_prompt
|
||||
prompt = self.role.role_prompt.format(**{"formatted_tools": formatted_tools, "tool_names": tool_names})
|
||||
|
||||
|
||||
task = query.task or self.task
|
||||
if task_prompt is not None:
|
||||
prompt += "\n" + task.task_prompt
|
||||
|
||||
if doc_infos is not None and doc_infos!="" and doc_infos!="不存在知识库辅助信息":
|
||||
prompt += f"\n知识库信息: {doc_infos}"
|
||||
|
||||
if code_infos is not None and code_infos!="" and code_infos!="不存在代码库辅助信息":
|
||||
prompt += f"\n代码库信息: {code_infos}"
|
||||
|
||||
if background_prompt:
|
||||
prompt += "\n" + background_prompt
|
||||
|
||||
if history_prompt:
|
||||
prompt += "\n" + history_prompt
|
||||
|
||||
if selfmemory_prompt:
|
||||
prompt += "\n" + selfmemory_prompt
|
||||
|
||||
# react 流程是自身迭代过程,另外二次触发的是需要作为历史对话信息
|
||||
input_query = react_memory.to_tuple_messages(content_key="step_content")
|
||||
input_query = "\n".join([f"{v}" for k, v in input_query if v])
|
||||
|
||||
# logger.debug(f"{self.role.role_name} extra_system_prompt: {self.role.role_prompt}")
|
||||
# logger.debug(f"{self.role.role_name} input_query: {input_query}")
|
||||
# logger.debug(f"{self.role.role_name} doc_infos: {doc_infos}")
|
||||
# logger.debug(f"{self.role.role_name} tool_names: {tool_names}")
|
||||
prompt += "\n" + REACT_PROMPT_INPUT.format(**{"query": input_query})
|
||||
|
||||
# prompt = extra_system_prompt.format(**{"query": input_query, "doc_infos": doc_infos, "formatted_tools": formatted_tools, "tool_names": tool_names})
|
||||
while "{{" in prompt or "}}" in prompt:
|
||||
prompt = prompt.replace("{{", "{")
|
||||
prompt = prompt.replace("}}", "}")
|
||||
return prompt
|
||||
|
|
@ -0,0 +1,5 @@
|
|||
from .base_chain import BaseChain
|
||||
|
||||
__all__ = [
|
||||
"BaseChain"
|
||||
]
|
|
@ -0,0 +1,281 @@
|
|||
from pydantic import BaseModel
|
||||
from typing import List
|
||||
import json
|
||||
import re
|
||||
from loguru import logger
|
||||
import traceback
|
||||
import uuid
|
||||
import copy
|
||||
|
||||
from dev_opsgpt.connector.agents import BaseAgent
|
||||
from dev_opsgpt.tools.base_tool import BaseTools, Tool
|
||||
from dev_opsgpt.connector.shcema.memory import Memory
|
||||
from dev_opsgpt.connector.connector_schema import (
|
||||
Role, Message, ActionStatus, ChainConfig,
|
||||
load_role_configs
|
||||
)
|
||||
from dev_opsgpt.sandbox import PyCodeBox, CodeBoxResponse
|
||||
|
||||
|
||||
from configs.server_config import SANDBOX_SERVER
|
||||
|
||||
from dev_opsgpt.connector.configs.agent_config import AGETN_CONFIGS
|
||||
role_configs = load_role_configs(AGETN_CONFIGS)
|
||||
|
||||
|
||||
class BaseChain:
|
||||
def __init__(
|
||||
self,
|
||||
chainConfig: ChainConfig,
|
||||
agents: List[BaseAgent],
|
||||
chat_turn: int = 1,
|
||||
do_checker: bool = False,
|
||||
do_code_exec: bool = False,
|
||||
# prompt_mamnger: PromptManager
|
||||
) -> None:
|
||||
self.chainConfig = chainConfig
|
||||
self.agents = agents
|
||||
self.chat_turn = chat_turn
|
||||
self.do_checker = do_checker
|
||||
self.checker = BaseAgent(role=role_configs["checker"].role,
|
||||
task = None,
|
||||
memory = None,
|
||||
do_search = role_configs["checker"].do_search,
|
||||
do_doc_retrieval = role_configs["checker"].do_doc_retrieval,
|
||||
do_tool_retrieval = role_configs["checker"].do_tool_retrieval,
|
||||
do_filter=False, do_use_self_memory=False)
|
||||
|
||||
self.global_memory = Memory([])
|
||||
self.local_memory = Memory([])
|
||||
self.do_code_exec = do_code_exec
|
||||
self.codebox = PyCodeBox(
|
||||
remote_url=SANDBOX_SERVER["url"],
|
||||
remote_ip=SANDBOX_SERVER["host"],
|
||||
remote_port=SANDBOX_SERVER["port"],
|
||||
token="mytoken",
|
||||
do_code_exe=True,
|
||||
do_remote=SANDBOX_SERVER["do_remote"],
|
||||
do_check_net=False
|
||||
)
|
||||
|
||||
def step(self, query: Message, history: Memory = None, background: Memory = None) -> Message:
|
||||
'''execute chain'''
|
||||
local_memory = Memory([])
|
||||
input_message = copy.deepcopy(query)
|
||||
step_nums = copy.deepcopy(self.chat_turn)
|
||||
check_message = None
|
||||
|
||||
self.global_memory.append(input_message)
|
||||
local_memory.append(input_message)
|
||||
while step_nums > 0:
|
||||
|
||||
for agent in self.agents:
|
||||
output_message = agent.run(input_message, history, background=background)
|
||||
output_message = self.inherit_extrainfo(input_message, output_message)
|
||||
# according the output to choose one action for code_content or tool_content
|
||||
logger.info(f"{agent.role.role_name} message: {output_message.role_content}")
|
||||
output_message = self.parser(output_message)
|
||||
# output_message = self.step_router(output_message)
|
||||
|
||||
input_message = output_message
|
||||
self.global_memory.append(output_message)
|
||||
|
||||
local_memory.append(output_message)
|
||||
# when get finished signal can stop early
|
||||
if output_message.action_status == ActionStatus.FINISHED:
|
||||
break
|
||||
|
||||
if self.do_checker:
|
||||
logger.debug(f"{self.checker.role.role_name} input global memory: {self.global_memory.to_str_messages(content_key='step_content')}")
|
||||
check_message = self.checker.run(query, background=self.global_memory)
|
||||
check_message = self.parser(check_message)
|
||||
check_message = self.filter(check_message)
|
||||
check_message = self.inherit_extrainfo(output_message, check_message)
|
||||
logger.debug(f"{self.checker.role.role_name}: {check_message.role_content}")
|
||||
|
||||
if check_message.action_status == ActionStatus.FINISHED:
|
||||
self.global_memory.append(check_message)
|
||||
break
|
||||
|
||||
step_nums -= 1
|
||||
|
||||
return check_message or output_message, local_memory
|
||||
|
||||
def step_router(self, message: Message) -> Message:
|
||||
''''''
|
||||
# message = self.parser(message)
|
||||
# logger.debug(f"message.action_status: {message.action_status}")
|
||||
if message.action_status == ActionStatus.CODING:
|
||||
message = self.code_step(message)
|
||||
elif message.action_status == ActionStatus.TOOL_USING:
|
||||
message = self.tool_step(message)
|
||||
|
||||
return message
|
||||
|
||||
def code_step(self, message: Message) -> Message:
|
||||
'''execute code'''
|
||||
# logger.debug(f"message.role_content: {message.role_content}, message.code_content: {message.code_content}")
|
||||
code_answer = self.codebox.chat('```python\n{}```'.format(message.code_content))
|
||||
uid = str(uuid.uuid1())
|
||||
if code_answer.code_exe_type == "image/png":
|
||||
message.figures[uid] = code_answer.code_exe_response
|
||||
message.code_answer = f"\n观察: 执行代码后获得输出一张图片, 文件名为{uid}\n"
|
||||
message.observation = f"\n观察: 执行代码后获得输出一张图片, 文件名为{uid}\n"
|
||||
message.step_content += f"\n观察: 执行代码后获得输出一张图片, 文件名为{uid}\n"
|
||||
message.step_contents += [f"\n观察: 执行代码后获得输出一张图片, 文件名为{uid}\n"]
|
||||
message.role_content += f"\n执行代码后获得输出一张图片, 文件名为{uid}\n"
|
||||
else:
|
||||
message.code_answer = code_answer.code_exe_response
|
||||
message.observation = code_answer.code_exe_response
|
||||
message.step_content += f"\n观察: 执行代码后获得输出是 {code_answer.code_exe_response}\n"
|
||||
message.step_contents += [f"\n观察: 执行代码后获得输出是 {code_answer.code_exe_response}\n"]
|
||||
message.role_content += f"\n观察: 执行代码后获得输出是 {code_answer.code_exe_response}\n"
|
||||
logger.info(f"观察: {message.action_status}, {message.observation}")
|
||||
return message
|
||||
|
||||
def tool_step(self, message: Message) -> Message:
|
||||
'''execute tool'''
|
||||
# logger.debug(f"message: {message.action_status}, {message.tool_name}, {message.tool_params}")
|
||||
tool_names = [tool.name for tool in message.tools]
|
||||
if message.tool_name not in tool_names:
|
||||
message.tool_answer = "不存在可以执行的tool"
|
||||
message.observation = "不存在可以执行的tool"
|
||||
message.role_content += f"\n观察: 不存在可以执行的tool\n"
|
||||
message.step_content += f"\n观察: 不存在可以执行的tool\n"
|
||||
message.step_contents += [f"\n观察: 不存在可以执行的tool\n"]
|
||||
for tool in message.tools:
|
||||
if tool.name == message.tool_name:
|
||||
tool_res = tool.func(**message.tool_params)
|
||||
message.tool_answer = tool_res
|
||||
message.observation = tool_res
|
||||
message.role_content += f"\n观察: {tool_res}\n"
|
||||
message.step_content += f"\n观察: {tool_res}\n"
|
||||
message.step_contents += [f"\n观察: {tool_res}\n"]
|
||||
return message
|
||||
|
||||
def filter(self, message: Message, stop=None) -> Message:
|
||||
|
||||
tool_params = self.parser_spec_key(message.role_content, "tool_params")
|
||||
code_content = self.parser_spec_key(message.role_content, "code_content")
|
||||
plan = self.parser_spec_key(message.role_content, "plan")
|
||||
plans = self.parser_spec_key(message.role_content, "plans", do_search=False)
|
||||
content = self.parser_spec_key(message.role_content, "content", do_search=False)
|
||||
|
||||
# logger.debug(f"tool_params: {tool_params}, code_content: {code_content}, plan: {plan}, plans: {plans}, content: {content}")
|
||||
role_content = tool_params or code_content or plan or plans or content
|
||||
message.role_content = role_content or message.role_content
|
||||
return message
|
||||
|
||||
def parser(self, message: Message) -> Message:
|
||||
''''''
|
||||
content = message.role_content
|
||||
parser_keys = ["action", "code_content", "code_filename", "tool_params", "plans"]
|
||||
try:
|
||||
s_json = self._parse_json(content)
|
||||
message.action_status = s_json.get("action")
|
||||
message.code_content = s_json.get("code_content")
|
||||
message.tool_params = s_json.get("tool_params")
|
||||
message.tool_name = s_json.get("tool_name")
|
||||
message.code_filename = s_json.get("code_filename")
|
||||
message.plans = s_json.get("plans")
|
||||
# for parser_key in parser_keys:
|
||||
# message.action_status = content.get(parser_key)
|
||||
except Exception as e:
|
||||
# logger.warning(f"{traceback.format_exc()}")
|
||||
action_value = self._match(r"'action':\s*'([^']*)'", content) if "'action'" in content else self._match(r'"action":\s*"([^"]*)"', content)
|
||||
code_content_value = self._match(r"'code_content':\s*'([^']*)'", content) if "'code_content'" in content else self._match(r'"code_content":\s*"([^"]*)"', content)
|
||||
filename_value = self._match(r"'code_filename':\s*'([^']*)'", content) if "'code_filename'" in content else self._match(r'"code_filename":\s*"([^"]*)"', content)
|
||||
tool_params_value = self._match(r"'tool_params':\s*(\{[^{}]*\})", content, do_json=True) if "'tool_params'" in content \
|
||||
else self._match(r'"tool_params":\s*(\{[^{}]*\})', content, do_json=True)
|
||||
tool_name_value = self._match(r"'tool_name':\s*'([^']*)'", content) if "'tool_name'" in content else self._match(r'"tool_name":\s*"([^"]*)"', content)
|
||||
plans_value = self._match(r"'plans':\s*(\[.*?\])", content, do_search=False) if "'plans'" in content else self._match(r'"plans":\s*(\[.*?\])', content, do_search=False, )
|
||||
# re解析
|
||||
message.action_status = action_value or "default"
|
||||
message.code_content = code_content_value
|
||||
message.code_filename = filename_value
|
||||
message.tool_params = tool_params_value
|
||||
message.tool_name = tool_name_value
|
||||
message.plans = plans_value
|
||||
|
||||
logger.debug(f"确认当前的action: {message.action_status}")
|
||||
|
||||
return message
|
||||
|
||||
def parser_spec_key(self, content, key, do_search=True, do_json=False) -> str:
|
||||
''''''
|
||||
key2pattern = {
|
||||
"'action'": r"'action':\s*'([^']*)'", '"action"': r'"action":\s*"([^"]*)"',
|
||||
"'code_content'": r"'code_content':\s*'([^']*)'", '"code_content"': r'"code_content":\s*"([^"]*)"',
|
||||
"'code_filename'": r"'code_filename':\s*'([^']*)'", '"code_filename"': r'"code_filename":\s*"([^"]*)"',
|
||||
"'tool_params'": r"'tool_params':\s*(\{[^{}]*\})", '"tool_params"': r'"tool_params":\s*(\{[^{}]*\})',
|
||||
"'tool_name'": r"'tool_name':\s*'([^']*)'", '"tool_name"': r'"tool_name":\s*"([^"]*)"',
|
||||
"'plans'": r"'plans':\s*(\[.*?\])", '"plans"': r'"plans":\s*(\[.*?\])',
|
||||
"'content'": r"'content':\s*'([^']*)'", '"content"': r'"content":\s*"([^"]*)"',
|
||||
}
|
||||
|
||||
s_json = self._parse_json(content)
|
||||
try:
|
||||
if s_json and key in s_json:
|
||||
return str(s_json[key])
|
||||
except:
|
||||
pass
|
||||
|
||||
keystr = f"'{key}'" if f"'{key}'" in content else f'"{key}"'
|
||||
return self._match(key2pattern.get(keystr, fr"'{key}':\s*'([^']*)'"), content, do_search=do_search, do_json=do_json)
|
||||
|
||||
def _match(self, pattern, s, do_search=True, do_json=False):
|
||||
try:
|
||||
if do_search:
|
||||
match = re.search(pattern, s)
|
||||
if match:
|
||||
value = match.group(1).replace("\\n", "\n")
|
||||
if do_json:
|
||||
value = json.loads(value)
|
||||
else:
|
||||
value = None
|
||||
else:
|
||||
match = re.findall(pattern, s, re.DOTALL)
|
||||
if match:
|
||||
value = match[0]
|
||||
if do_json:
|
||||
value = json.loads(value)
|
||||
else:
|
||||
value = None
|
||||
except Exception as e:
|
||||
logger.warning(f"{traceback.format_exc()}")
|
||||
|
||||
# logger.debug(f"pattern: {pattern}, s: {s}, match: {match}")
|
||||
return value
|
||||
|
||||
def _parse_json(self, s):
|
||||
try:
|
||||
pattern = r"```([^`]+)```"
|
||||
match = re.findall(pattern, s)
|
||||
if match:
|
||||
return eval(match[0])
|
||||
except:
|
||||
pass
|
||||
return None
|
||||
|
||||
def inherit_extrainfo(self, input_message: Message, output_message: Message):
|
||||
output_message.db_docs = input_message.db_docs
|
||||
output_message.search_docs = input_message.search_docs
|
||||
output_message.code_docs = input_message.code_docs
|
||||
output_message.figures.update(input_message.figures)
|
||||
return output_message
|
||||
|
||||
def get_memory(self, do_all_memory=True, content_key="role_content") -> Memory:
|
||||
memory = self.global_memory if do_all_memory else self.local_memory
|
||||
return memory.to_tuple_messages(content_key=content_key)
|
||||
|
||||
def get_memory_str(self, do_all_memory=True, content_key="role_content") -> Memory:
|
||||
memory = self.global_memory if do_all_memory else self.local_memory
|
||||
# for i in memory.to_tuple_messages(content_key=content_key):
|
||||
# logger.debug(f"{i}")
|
||||
return "\n".join([": ".join(i) for i in memory.to_tuple_messages(content_key=content_key)])
|
||||
|
||||
def get_agents_memory(self, content_key="role_content"):
|
||||
return [agent.get_memory(content_key=content_key) for agent in self.agents]
|
||||
|
||||
def get_agents_memory_str(self, content_key="role_content"):
|
||||
return "************".join([f"{agent.role.role_name}\n" + agent.get_memory_str(content_key=content_key) for agent in self.agents])
|
|
@ -0,0 +1,28 @@
|
|||
from typing import List
|
||||
from dev_opsgpt.connector.agents import BaseAgent
|
||||
from .base_chain import BaseChain
|
||||
|
||||
|
||||
|
||||
class simpleChatChain(BaseChain):
|
||||
|
||||
def __init__(self, agents: List[BaseAgent], do_code_exec: bool = False) -> None:
|
||||
super().__init__(agents, do_code_exec)
|
||||
|
||||
|
||||
class toolChatChain(BaseChain):
|
||||
|
||||
def __init__(self, agents: List[BaseAgent], do_code_exec: bool = False) -> None:
|
||||
super().__init__(agents, do_code_exec)
|
||||
|
||||
|
||||
class dataAnalystChain(BaseChain):
|
||||
|
||||
def __init__(self, agents: List[BaseAgent], do_code_exec: bool = False) -> None:
|
||||
super().__init__(agents, do_code_exec)
|
||||
|
||||
|
||||
class plannerChain(BaseChain):
|
||||
|
||||
def __init__(self, agents: List[BaseAgent], do_code_exec: bool = False) -> None:
|
||||
super().__init__(agents, do_code_exec)
|
|
@ -0,0 +1,7 @@
|
|||
from .agent_config import AGETN_CONFIGS
|
||||
from .chain_config import CHAIN_CONFIGS
|
||||
from .phase_config import PHASE_CONFIGS
|
||||
|
||||
__all__ = [
|
||||
"AGETN_CONFIGS", "CHAIN_CONFIGS", "PHASE_CONFIGS"
|
||||
]
|
|
@ -0,0 +1,410 @@
|
|||
from enum import Enum
|
||||
|
||||
|
||||
class AgentType:
|
||||
REACT = "ReactAgent"
|
||||
ONE_STEP = "BaseAgent"
|
||||
DEFAULT = "BaseAgent"
|
||||
|
||||
|
||||
REACT_TOOL_PROMPT = """尽可能地以有帮助和准确的方式回应人类。您可以使用以下工具:
|
||||
{formatted_tools}
|
||||
使用json blob来指定一个工具,提供一个action关键字(工具名称)和一个tool_params关键字(工具输入)。
|
||||
有效的"action"值为:"finished" 或 "tool_using" (使用工具来回答问题)
|
||||
有效的"tool_name"值为:{tool_names}
|
||||
请仅在每个$JSON_BLOB中提供一个action,如下所示:
|
||||
```
|
||||
{{{{
|
||||
"action": $ACTION,
|
||||
"tool_name": $TOOL_NAME
|
||||
"tool_params": $INPUT
|
||||
}}}}
|
||||
```
|
||||
|
||||
按照以下格式进行回应:
|
||||
问题:输入问题以回答
|
||||
思考:考虑之前和之后的步骤
|
||||
行动:
|
||||
```
|
||||
$JSON_BLOB
|
||||
```
|
||||
观察:行动结果
|
||||
...(重复思考/行动/观察N次)
|
||||
思考:我知道该如何回应
|
||||
行动:
|
||||
```
|
||||
{{{{
|
||||
"action": "finished",
|
||||
"tool_name": "notool"
|
||||
"tool_params": "最终返回答案给到用户"
|
||||
}}}}
|
||||
```
|
||||
"""
|
||||
|
||||
REACT_PROMPT_INPUT = '''下面开始!记住根据问题进行返回需要生成的答案
|
||||
问题: {query}'''
|
||||
|
||||
|
||||
REACT_CODE_PROMPT = """尽可能地以有帮助和准确的方式回应人类,能够逐步编写可执行并打印变量的代码来解决问题
|
||||
使用 JSON Blob 来指定一个返回的内容,提供一个 action(行动)和一个 code (生成代码)。
|
||||
有效的 'action' 值为:'coding'(结合总结下述思维链过程编写下一步的可执行代码) or 'finished' (总结下述思维链过程可回答问题)。
|
||||
在每个 $JSON_BLOB 中仅提供一个 action,如下所示:
|
||||
```
|
||||
{{{{'action': $ACTION,'code_content': $CODE}}}}
|
||||
```
|
||||
|
||||
按照以下思维链格式进行回应:
|
||||
问题:输入问题以回答
|
||||
思考:考虑之前和之后的步骤
|
||||
行动:
|
||||
```
|
||||
$JSON_BLOB
|
||||
```
|
||||
观察:行动结果
|
||||
...(重复思考/行动/观察N次)
|
||||
思考:我知道该如何回应
|
||||
行动:
|
||||
```
|
||||
{{{{
|
||||
"action": "finished",
|
||||
"code_content": "总结上述思维链过程回答问题"
|
||||
}}}}
|
||||
```
|
||||
"""
|
||||
|
||||
GENERAL_PLANNER_PROMPT = """你是一个通用计划拆解助手,将问题拆解问题成各个详细明确的步骤计划或直接回答问题,尽可能地以有帮助和准确的方式回应人类,
|
||||
使用 JSON Blob 来指定一个返回的内容,提供一个 action(行动)和一个 plans (生成的计划)。
|
||||
有效的 'action' 值为:'planning'(拆解计划) or 'only_answer' (不需要拆解问题即可直接回答问题)。
|
||||
有效的 'plans' 值为: 一个任务列表,按顺序写出需要执行的计划
|
||||
在每个 $JSON_BLOB 中仅提供一个 action,如下所示:
|
||||
```
|
||||
{{'action': 'planning', 'plans': [$PLAN1, $PLAN2, $PLAN3, ..., $PLANN], }}
|
||||
或者
|
||||
{{'action': 'only_answer', 'plans': "直接回答问题", }}
|
||||
```
|
||||
|
||||
按照以下格式进行回应:
|
||||
问题:输入问题以回答
|
||||
行动:
|
||||
```
|
||||
$JSON_BLOB
|
||||
```
|
||||
"""
|
||||
|
||||
DATA_PLANNER_PROMPT = """你是一个数据分析助手,能够根据问题来制定一个详细明确的数据分析计划,尽可能地以有帮助和准确的方式回应人类,
|
||||
使用 JSON Blob 来指定一个返回的内容,提供一个 action(行动)和一个 plans (生成的计划)。
|
||||
有效的 'action' 值为:'planning'(拆解计划) or 'only_answer' (不需要拆解问题即可直接回答问题)。
|
||||
有效的 'plans' 值为: 一份数据分析计划清单,按顺序排列,用文本表示
|
||||
在每个 $JSON_BLOB 中仅提供一个 action,如下所示:
|
||||
```
|
||||
{{'action': 'planning', 'plans': '$PLAN1, $PLAN2, ..., $PLAN3' }}
|
||||
```
|
||||
|
||||
按照以下格式进行回应:
|
||||
问题:输入问题以回答
|
||||
行动:
|
||||
```
|
||||
$JSON_BLOB
|
||||
```
|
||||
"""
|
||||
|
||||
TOOL_PLANNER_PROMPT = """你是一个工具使用过程的计划拆解助手,将问题拆解为一系列的工具使用计划,若没有可用工具则直接回答问题,尽可能地以有帮助和准确的方式回应人类,你可以使用以下工具:
|
||||
{formatted_tools}
|
||||
使用 JSON Blob 来指定一个返回的内容,提供一个 action(行动)和一个 plans (生成的计划)。
|
||||
有效的 'action' 值为:'planning'(拆解计划) or 'only_answer' (不需要拆解问题即可直接回答问题)。
|
||||
有效的 'plans' 值为: 一个任务列表,按顺序写出需要使用的工具和使用该工具的理由
|
||||
在每个 $JSON_BLOB 中仅提供一个 action,如下两个示例所示:
|
||||
```
|
||||
{{'action': 'planning', 'plans': [$PLAN1, $PLAN2, $PLAN3, ..., $PLANN], }}
|
||||
```
|
||||
或者 若无法通过以上工具解决问题,则直接回答问题
|
||||
```
|
||||
{{'action': 'only_answer', 'plans': "直接回答问题", }}
|
||||
```
|
||||
|
||||
按照以下格式进行回应:
|
||||
问题:输入问题以回答
|
||||
行动:
|
||||
```
|
||||
$JSON_BLOB
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
RECOGNIZE_INTENTION_PROMPT = """你是一个任务决策助手,能够将理解用户意图并决策采取最合适的行动,尽可能地以有帮助和准确的方式回应人类,
|
||||
使用 JSON Blob 来指定一个返回的内容,提供一个 action(行动)。
|
||||
有效的 'action' 值为:'planning'(需要先进行拆解计划) or 'only_answer' (不需要拆解问题即可直接回答问题)or "tool_using" (使用工具来回答问题) or 'coding'(生成可执行的代码)。
|
||||
在每个 $JSON_BLOB 中仅提供一个 action,如下所示:
|
||||
```
|
||||
{{'action': $ACTION}}
|
||||
```
|
||||
按照以下格式进行回应:
|
||||
问题:输入问题以回答
|
||||
行动:$ACTION
|
||||
```
|
||||
$JSON_BLOB
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
CHECKER_PROMPT = """尽可能地以有帮助和准确的方式回应人类,判断问题是否得到解答,同时展现解答的过程和内容
|
||||
使用 JSON Blob 来指定一个返回的内容,提供一个 action(行动)。
|
||||
有效的 'action' 值为:'finished'(任务已经可以通过“背景信息”和“对话信息”回答问题) or 'continue' (“背景信息”和“对话信息”不足以回答问题)。
|
||||
在每个 $JSON_BLOB 中仅提供一个 action,如下所示:
|
||||
```
|
||||
{{'action': $ACTION, 'content': '提取“背景信息”和“对话信息”中信息来回答问题'}}
|
||||
```
|
||||
按照以下格式进行回应:
|
||||
问题:输入问题以回答
|
||||
行动:$ACTION
|
||||
```
|
||||
$JSON_BLOB
|
||||
```
|
||||
"""
|
||||
|
||||
CONV_SUMMARY_PROMPT = """尽可能地以有帮助和准确的方式回应人类,根据“背景信息”中的有效信息回答问题,
|
||||
使用 JSON Blob 来指定一个返回的内容,提供一个 action(行动)。
|
||||
有效的 'action' 值为:'finished'(任务已经可以通过上下文信息可以回答) or 'continue' (根据背景信息回答问题)。
|
||||
在每个 $JSON_BLOB 中仅提供一个 action,如下所示:
|
||||
```
|
||||
{{'action': $ACTION, 'content': '根据背景信息回答问题'}}
|
||||
```
|
||||
按照以下格式进行回应:
|
||||
问题:输入问题以回答
|
||||
行动:
|
||||
```
|
||||
$JSON_BLOB
|
||||
```
|
||||
"""
|
||||
|
||||
CONV_SUMMARY_PROMPT = """尽可能地以有帮助和准确的方式回应人类
|
||||
根据“背景信息”中的有效信息回答问题,同时展现解答的过程和内容
|
||||
若能根“背景信息”回答问题,则直接回答
|
||||
否则,总结“背景信息”的内容
|
||||
"""
|
||||
|
||||
|
||||
|
||||
QA_PROMPT = """根据已知信息,简洁和专业的来回答问题。如果无法从中得到答案,请说 “根据已知信息无法回答该问题”,不允许在答案中添加编造成分,答案请使用中文。
|
||||
使用 JSON Blob 来指定一个返回的内容,提供一个 action(行动)。
|
||||
有效的 'action' 值为:'finished'(任务已经可以通过上下文信息可以回答) or 'continue' (上下文信息不足以回答问题)。
|
||||
在每个 $JSON_BLOB 中仅提供一个 action,如下所示:
|
||||
```
|
||||
{{'action': $ACTION, 'content': '总结对话内容'}}
|
||||
```
|
||||
按照以下格式进行回应:
|
||||
问题:输入问题以回答
|
||||
行动:$ACTION
|
||||
```
|
||||
$JSON_BLOB
|
||||
```
|
||||
"""
|
||||
|
||||
CODE_QA_PROMPT = """【指令】根据已知信息来回答问"""
|
||||
|
||||
|
||||
AGETN_CONFIGS = {
|
||||
"checker": {
|
||||
"role": {
|
||||
"role_prompt": CHECKER_PROMPT,
|
||||
"role_type": "ai",
|
||||
"role_name": "checker",
|
||||
"role_desc": "",
|
||||
"agent_type": "BaseAgent"
|
||||
},
|
||||
"chat_turn": 1,
|
||||
"do_search": False,
|
||||
"do_doc_retrieval": False,
|
||||
"do_tool_retrieval": False
|
||||
},
|
||||
"conv_summary": {
|
||||
"role": {
|
||||
"role_prompt": CONV_SUMMARY_PROMPT,
|
||||
"role_type": "ai",
|
||||
"role_name": "conv_summary",
|
||||
"role_desc": "",
|
||||
"agent_type": "BaseAgent"
|
||||
},
|
||||
"chat_turn": 1,
|
||||
"do_search": False,
|
||||
"do_doc_retrieval": False,
|
||||
"do_tool_retrieval": False
|
||||
},
|
||||
"general_planner": {
|
||||
"role": {
|
||||
"role_prompt": GENERAL_PLANNER_PROMPT,
|
||||
"role_type": "ai",
|
||||
"role_name": "general_planner",
|
||||
"role_desc": "",
|
||||
"agent_type": "BaseAgent"
|
||||
},
|
||||
"chat_turn": 1,
|
||||
"do_search": False,
|
||||
"do_doc_retrieval": False,
|
||||
"do_tool_retrieval": False
|
||||
},
|
||||
"planner": {
|
||||
"role": {
|
||||
"role_prompt": DATA_PLANNER_PROMPT,
|
||||
"role_type": "ai",
|
||||
"role_name": "planner",
|
||||
"role_desc": "",
|
||||
"agent_type": "BaseAgent"
|
||||
},
|
||||
"chat_turn": 1,
|
||||
"do_search": False,
|
||||
"do_doc_retrieval": False,
|
||||
"do_tool_retrieval": False
|
||||
},
|
||||
"intention_recognizer": {
|
||||
"role": {
|
||||
"role_prompt": RECOGNIZE_INTENTION_PROMPT,
|
||||
"role_type": "ai",
|
||||
"role_name": "intention_recognizer",
|
||||
"role_desc": "",
|
||||
"agent_type": "BaseAgent"
|
||||
},
|
||||
"chat_turn": 1,
|
||||
"do_search": False,
|
||||
"do_doc_retrieval": False,
|
||||
"do_tool_retrieval": False
|
||||
},
|
||||
"tool_planner": {
|
||||
"role": {
|
||||
"role_prompt": TOOL_PLANNER_PROMPT,
|
||||
"role_type": "ai",
|
||||
"role_name": "tool_planner",
|
||||
"role_desc": "",
|
||||
"agent_type": "BaseAgent"
|
||||
},
|
||||
"chat_turn": 1,
|
||||
"do_search": False,
|
||||
"do_doc_retrieval": False,
|
||||
"do_tool_retrieval": False
|
||||
},
|
||||
"tool_react": {
|
||||
"role": {
|
||||
"role_prompt": REACT_TOOL_PROMPT,
|
||||
"role_type": "ai",
|
||||
"role_name": "tool_react",
|
||||
"role_desc": "",
|
||||
"agent_type": "ReactAgent"
|
||||
},
|
||||
"chat_turn": 5,
|
||||
"do_search": False,
|
||||
"do_doc_retrieval": False,
|
||||
"do_tool_retrieval": False,
|
||||
"stop": "观察"
|
||||
},
|
||||
"code_react": {
|
||||
"role": {
|
||||
"role_prompt": REACT_CODE_PROMPT,
|
||||
"role_type": "ai",
|
||||
"role_name": "code_react",
|
||||
"role_desc": "",
|
||||
"agent_type": "ReactAgent"
|
||||
},
|
||||
"chat_turn": 5,
|
||||
"do_search": False,
|
||||
"do_doc_retrieval": False,
|
||||
"do_tool_retrieval": False,
|
||||
"stop": "观察"
|
||||
},
|
||||
"qaer": {
|
||||
"role": {
|
||||
"role_prompt": QA_PROMPT,
|
||||
"role_type": "ai",
|
||||
"role_name": "qaer",
|
||||
"role_desc": "",
|
||||
"agent_type": "BaseAgent"
|
||||
},
|
||||
"chat_turn": 1,
|
||||
"do_search": False,
|
||||
"do_doc_retrieval": True,
|
||||
"do_tool_retrieval": False
|
||||
},
|
||||
"code_qaer": {
|
||||
"role": {
|
||||
"role_prompt": CODE_QA_PROMPT ,
|
||||
"role_type": "ai",
|
||||
"role_name": "code_qaer",
|
||||
"role_desc": "",
|
||||
"agent_type": "BaseAgent"
|
||||
},
|
||||
"chat_turn": 1,
|
||||
"do_search": False,
|
||||
"do_doc_retrieval": True,
|
||||
"do_tool_retrieval": False
|
||||
},
|
||||
"searcher": {
|
||||
"role": {
|
||||
"role_prompt": QA_PROMPT,
|
||||
"role_type": "ai",
|
||||
"role_name": "searcher",
|
||||
"role_desc": "",
|
||||
"agent_type": "BaseAgent"
|
||||
},
|
||||
"chat_turn": 1,
|
||||
"do_search": True,
|
||||
"do_doc_retrieval": False,
|
||||
"do_tool_retrieval": False
|
||||
},
|
||||
"answer": {
|
||||
"role": {
|
||||
"role_prompt": "",
|
||||
"role_type": "ai",
|
||||
"role_name": "answer",
|
||||
"role_desc": "",
|
||||
"agent_type": "BaseAgent"
|
||||
},
|
||||
"chat_turn": 1,
|
||||
"do_search": False,
|
||||
"do_doc_retrieval": False,
|
||||
"do_tool_retrieval": False
|
||||
},
|
||||
"data_analyst": {
|
||||
"role": {
|
||||
"role_prompt": """你是一个数据分析的代码开发助手,能够编写可执行的代码来完成相关的数据分析问题,使用 JSON Blob 来指定一个返回的内容,通过提供一个 action(行动)和一个 code (生成代码)和 一个 file_name (指定保存文件)。\
|
||||
有效的 'action' 值为:'coding'(生成可执行的代码) or 'finished' (不生成代码并直接返回答案)。在每个 $JSON_BLOB 中仅提供一个 action,如下所示:\
|
||||
```\n{{'action': $ACTION,'code_content': $CODE, 'code_filename': $FILE_NAME}}```\
|
||||
下面开始!记住根据问题进行返回需要生成的答案,格式为 ```JSON_BLOB```""",
|
||||
"role_type": "ai",
|
||||
"role_name": "data_analyst",
|
||||
"role_desc": "",
|
||||
"agent_type": "BaseAgent"
|
||||
},
|
||||
"chat_turn": 1,
|
||||
"do_search": False,
|
||||
"do_doc_retrieval": False,
|
||||
"do_tool_retrieval": False
|
||||
},
|
||||
"deveploer": {
|
||||
"role": {
|
||||
"role_prompt": """你是一个代码开发助手,能够编写可执行的代码来完成问题,使用 JSON Blob 来指定一个返回的内容,通过提供一个 action(行动)和一个 code (生成代码)和 一个 file_name (指定保存文件)。\
|
||||
有效的 'action' 值为:'coding'(生成可执行的代码) or 'finished' (不生成代码并直接返回答案)。在每个 $JSON_BLOB 中仅提供一个 action,如下所示:\
|
||||
```\n{{'action': $ACTION,'code_content': $CODE, 'code_filename': $FILE_NAME}}```\
|
||||
下面开始!记住根据问题进行返回需要生成的答案,格式为 ```JSON_BLOB```""",
|
||||
"role_type": "ai",
|
||||
"role_name": "deveploer",
|
||||
"role_desc": "",
|
||||
"agent_type": "BaseAgent"
|
||||
},
|
||||
"chat_turn": 1,
|
||||
"do_search": False,
|
||||
"do_doc_retrieval": False,
|
||||
"do_tool_retrieval": False
|
||||
},
|
||||
"tester": {
|
||||
"role": {
|
||||
"role_prompt": "你是一个QA问答的助手,能够尽可能准确地回答问题,下面请逐步思考问题并回答",
|
||||
"role_type": "ai",
|
||||
"role_name": "tester",
|
||||
"role_desc": "",
|
||||
"agent_type": "BaseAgent"
|
||||
},
|
||||
"chat_turn": 1,
|
||||
"do_search": False,
|
||||
"do_doc_retrieval": False,
|
||||
"do_tool_retrieval": False
|
||||
}
|
||||
}
|
|
@ -0,0 +1,88 @@
|
|||
|
||||
|
||||
CHAIN_CONFIGS = {
|
||||
"chatChain": {
|
||||
"chain_name": "chatChain",
|
||||
"chain_type": "BaseChain",
|
||||
"agents": ["answer"],
|
||||
"chat_turn": 1,
|
||||
"do_checker": False,
|
||||
"clear_structure": "True",
|
||||
"brainstorming": "False",
|
||||
"gui_design": "True",
|
||||
"git_management": "False",
|
||||
"self_improve": "False"
|
||||
},
|
||||
"docChatChain": {
|
||||
"chain_name": "docChatChain",
|
||||
"chain_type": "BaseChain",
|
||||
"agents": ["qaer"],
|
||||
"chat_turn": 1,
|
||||
"do_checker": False,
|
||||
"clear_structure": "True",
|
||||
"brainstorming": "False",
|
||||
"gui_design": "True",
|
||||
"git_management": "False",
|
||||
"self_improve": "False"
|
||||
},
|
||||
"searchChatChain": {
|
||||
"chain_name": "searchChatChain",
|
||||
"chain_type": "BaseChain",
|
||||
"agents": ["searcher"],
|
||||
"chat_turn": 1,
|
||||
"do_checker": False,
|
||||
"clear_structure": "True",
|
||||
"brainstorming": "False",
|
||||
"gui_design": "True",
|
||||
"git_management": "False",
|
||||
"self_improve": "False"
|
||||
},
|
||||
"codeChatChain": {
|
||||
"chain_name": "codehChatChain",
|
||||
"chain_type": "BaseChain",
|
||||
"agents": ["code_qaer"],
|
||||
"chat_turn": 1,
|
||||
"do_checker": False,
|
||||
"clear_structure": "True",
|
||||
"brainstorming": "False",
|
||||
"gui_design": "True",
|
||||
"git_management": "False",
|
||||
"self_improve": "False"
|
||||
},
|
||||
"toolReactChain": {
|
||||
"chain_name": "toolReactChain",
|
||||
"chain_type": "BaseChain",
|
||||
"agents": ["tool_planner", "tool_react"],
|
||||
"chat_turn": 2,
|
||||
"do_checker": True,
|
||||
"clear_structure": "True",
|
||||
"brainstorming": "False",
|
||||
"gui_design": "True",
|
||||
"git_management": "False",
|
||||
"self_improve": "False"
|
||||
},
|
||||
"codeReactChain": {
|
||||
"chain_name": "codeReactChain",
|
||||
"chain_type": "BaseChain",
|
||||
"agents": ["planner", "code_react"],
|
||||
"chat_turn": 2,
|
||||
"do_checker": True,
|
||||
"clear_structure": "True",
|
||||
"brainstorming": "False",
|
||||
"gui_design": "True",
|
||||
"git_management": "False",
|
||||
"self_improve": "False"
|
||||
},
|
||||
"dataAnalystChain": {
|
||||
"chain_name": "dataAnalystChain",
|
||||
"chain_type": "BaseChain",
|
||||
"agents": ["planner", "code_react"],
|
||||
"chat_turn": 2,
|
||||
"do_checker": True,
|
||||
"clear_structure": "True",
|
||||
"brainstorming": "False",
|
||||
"gui_design": "True",
|
||||
"git_management": "False",
|
||||
"self_improve": "False"
|
||||
},
|
||||
}
|
|
@ -0,0 +1,79 @@
|
|||
PHASE_CONFIGS = {
|
||||
"chatPhase": {
|
||||
"phase_name": "chatPhase",
|
||||
"phase_type": "BasePhase",
|
||||
"chains": ["chatChain"],
|
||||
"do_summary": False,
|
||||
"do_search": False,
|
||||
"do_doc_retrieval": False,
|
||||
"do_code_retrieval": False,
|
||||
"do_tool_retrieval": False,
|
||||
"do_using_tool": False
|
||||
},
|
||||
"docChatPhase": {
|
||||
"phase_name": "docChatPhase",
|
||||
"phase_type": "BasePhase",
|
||||
"chains": ["docChatChain"],
|
||||
"do_summary": False,
|
||||
"do_search": False,
|
||||
"do_doc_retrieval": True,
|
||||
"do_code_retrieval": False,
|
||||
"do_tool_retrieval": False,
|
||||
"do_using_tool": False
|
||||
},
|
||||
"searchChatPhase": {
|
||||
"phase_name": "searchChatPhase",
|
||||
"phase_type": "BasePhase",
|
||||
"chains": ["searchChatChain"],
|
||||
"do_summary": False,
|
||||
"do_search": True,
|
||||
"do_doc_retrieval": False,
|
||||
"do_code_retrieval": False,
|
||||
"do_tool_retrieval": False,
|
||||
"do_using_tool": False
|
||||
},
|
||||
"codeChatPhase": {
|
||||
"phase_name": "codeChatPhase",
|
||||
"phase_type": "BasePhase",
|
||||
"chains": ["codeChatChain"],
|
||||
"do_summary": False,
|
||||
"do_search": False,
|
||||
"do_doc_retrieval": False,
|
||||
"do_code_retrieval": True,
|
||||
"do_tool_retrieval": False,
|
||||
"do_using_tool": False
|
||||
},
|
||||
"toolReactPhase": {
|
||||
"phase_name": "toolReactPhase",
|
||||
"phase_type": "BasePhase",
|
||||
"chains": ["toolReactChain"],
|
||||
"do_summary": False,
|
||||
"do_search": False,
|
||||
"do_doc_retrieval": False,
|
||||
"do_code_retrieval": False,
|
||||
"do_tool_retrieval": False,
|
||||
"do_using_tool": True
|
||||
},
|
||||
"codeReactPhase": {
|
||||
"phase_name": "codeReacttPhase",
|
||||
"phase_type": "BasePhase",
|
||||
"chains": ["codeReactChain"],
|
||||
"do_summary": False,
|
||||
"do_search": False,
|
||||
"do_doc_retrieval": False,
|
||||
"do_code_retrieval": False,
|
||||
"do_tool_retrieval": False,
|
||||
"do_using_tool": False
|
||||
},
|
||||
"dataReactPhase": {
|
||||
"phase_name": "dataReactPhase",
|
||||
"phase_type": "BasePhase",
|
||||
"chains": ["dataAnalystChain"],
|
||||
"do_summary": True,
|
||||
"do_search": False,
|
||||
"do_doc_retrieval": False,
|
||||
"do_code_retrieval": False,
|
||||
"do_tool_retrieval": False,
|
||||
"do_using_tool": False
|
||||
}
|
||||
}
|
|
@ -0,0 +1,248 @@
|
|||
from pydantic import BaseModel
|
||||
from typing import List, Dict
|
||||
from enum import Enum
|
||||
import re
|
||||
import json
|
||||
from loguru import logger
|
||||
from langchain.tools import BaseTool
|
||||
|
||||
|
||||
class ActionStatus(Enum):
|
||||
FINISHED = "finished"
|
||||
CODING = "coding"
|
||||
TOOL_USING = "tool_using"
|
||||
REASONING = "reasoning"
|
||||
PLANNING = "planning"
|
||||
EXECUTING_CODE = "executing_code"
|
||||
EXECUTING_TOOL = "executing_tool"
|
||||
DEFAUILT = "default"
|
||||
|
||||
def __eq__(self, other):
|
||||
if isinstance(other, str):
|
||||
return self.value == other
|
||||
return super().__eq__(other)
|
||||
|
||||
class Doc(BaseModel):
|
||||
title: str
|
||||
snippet: str
|
||||
link: str
|
||||
index: int
|
||||
|
||||
def get_title(self):
|
||||
return self.title
|
||||
|
||||
def get_snippet(self, ):
|
||||
return self.snippet
|
||||
|
||||
def get_link(self, ):
|
||||
return self.link
|
||||
|
||||
def get_index(self, ):
|
||||
return self.index
|
||||
|
||||
def to_json(self):
|
||||
return vars(self)
|
||||
|
||||
def __str__(self,):
|
||||
return f"""出处 [{self.index + 1}] 标题 [{self.title}]\n\n来源 ({self.link}) \n\n内容 {self.snippet}\n\n"""
|
||||
|
||||
|
||||
class CodeDoc(BaseModel):
|
||||
code: str
|
||||
related_nodes: list
|
||||
index: int
|
||||
|
||||
def get_code(self, ):
|
||||
return self.code
|
||||
|
||||
def get_related_node(self, ):
|
||||
return self.related_nodes
|
||||
|
||||
def get_index(self, ):
|
||||
return self.index
|
||||
|
||||
def to_json(self):
|
||||
return vars(self)
|
||||
|
||||
def __str__(self,):
|
||||
return f"""出处 [{self.index + 1}] \n\n来源 ({self.related_nodes}) \n\n内容 {self.code}\n\n"""
|
||||
|
||||
|
||||
class Docs:
|
||||
|
||||
def __init__(self, docs: List[Doc]):
|
||||
self.titles: List[str] = [doc.get_title() for doc in docs]
|
||||
self.snippets: List[str] = [doc.get_snippet() for doc in docs]
|
||||
self.links: List[str] = [doc.get_link() for doc in docs]
|
||||
self.indexs: List[int] = [doc.get_index() for doc in docs]
|
||||
|
||||
class Task(BaseModel):
|
||||
task_type: str
|
||||
task_name: str
|
||||
task_desc: str
|
||||
task_prompt: str
|
||||
# def __init__(self, task_type, task_name, task_desc) -> None:
|
||||
# self.task_type = task_type
|
||||
# self.task_name = task_name
|
||||
# self.task_desc = task_desc
|
||||
|
||||
class Env(BaseModel):
|
||||
env_type: str
|
||||
env_name: str
|
||||
env_desc:str
|
||||
|
||||
|
||||
class Role(BaseModel):
|
||||
role_type: str
|
||||
role_name: str
|
||||
role_desc: str
|
||||
agent_type: str = ""
|
||||
role_prompt: str = ""
|
||||
template_prompt: str = ""
|
||||
|
||||
|
||||
|
||||
class ChainConfig(BaseModel):
|
||||
chain_name: str
|
||||
chain_type: str
|
||||
agents: List[str]
|
||||
do_checker: bool = False
|
||||
chat_turn: int = 1
|
||||
clear_structure: bool = False
|
||||
brainstorming: bool = False
|
||||
gui_design: bool = True
|
||||
git_management: bool = False
|
||||
self_improve: bool = False
|
||||
|
||||
|
||||
class AgentConfig(BaseModel):
|
||||
role: Role
|
||||
chat_turn: int = 1
|
||||
do_search: bool = False
|
||||
do_doc_retrieval: bool = False
|
||||
do_tool_retrieval: bool = False
|
||||
|
||||
|
||||
class PhaseConfig(BaseModel):
|
||||
phase_name: str
|
||||
phase_type: str
|
||||
chains: List[str]
|
||||
do_summary: bool = False
|
||||
do_search: bool = False
|
||||
do_doc_retrieval: bool = False
|
||||
do_code_retrieval: bool = False
|
||||
do_tool_retrieval: bool = False
|
||||
|
||||
class Message(BaseModel):
|
||||
role_name: str
|
||||
role_type: str
|
||||
role_prompt: str = None
|
||||
input_query: str = None
|
||||
|
||||
# 模型最终返回
|
||||
role_content: str = None
|
||||
role_contents: List[str] = []
|
||||
step_content: str = None
|
||||
step_contents: List[str] = []
|
||||
chain_content: str = None
|
||||
chain_contents: List[str] = []
|
||||
|
||||
# 模型结果解析
|
||||
plans: List[str] = None
|
||||
code_content: str = None
|
||||
code_filename: str = None
|
||||
tool_params: str = None
|
||||
tool_name: str = None
|
||||
|
||||
# 执行结果
|
||||
action_status: str = ActionStatus.DEFAUILT
|
||||
code_answer: str = None
|
||||
tool_answer: str = None
|
||||
observation: str = None
|
||||
figures: Dict[str, str] = {}
|
||||
|
||||
# 辅助信息
|
||||
tools: List[BaseTool] = []
|
||||
task: Task = None
|
||||
db_docs: List['Doc'] = []
|
||||
code_docs: List['CodeDoc'] = []
|
||||
search_docs: List['Doc'] = []
|
||||
|
||||
# 执行输入
|
||||
phase_name: str = None
|
||||
chain_name: str = None
|
||||
do_search: bool = False
|
||||
doc_engine_name: str = None
|
||||
code_engine_name: str = None
|
||||
search_engine_name: str = None
|
||||
top_k: int = 3
|
||||
score_threshold: float = 1.0
|
||||
do_doc_retrieval: bool = False
|
||||
do_code_retrieval: bool = False
|
||||
do_tool_retrieval: bool = False
|
||||
history_node_list: List[str] = []
|
||||
|
||||
|
||||
def to_tuple_message(self, return_all: bool = False, content_key="role_content"):
|
||||
if content_key == "role_content":
|
||||
role_content = self.role_content
|
||||
elif content_key == "step_content":
|
||||
role_content = self.step_content or self.role_content
|
||||
else:
|
||||
role_content =self.role_content
|
||||
|
||||
if return_all:
|
||||
return (self.role_name, self.role_type, role_content)
|
||||
else:
|
||||
return (self.role_name, role_content)
|
||||
return (self.role_type, re.sub("}", "}}", re.sub("{", "{{", str(self.role_content))))
|
||||
|
||||
def to_dict_message(self, return_all: bool = False, content_key="role_content"):
|
||||
if content_key == "role_content":
|
||||
role_content =self.role_content
|
||||
elif content_key == "step_content":
|
||||
role_content = self.step_content or self.role_content
|
||||
else:
|
||||
role_content =self.role_content
|
||||
|
||||
if return_all:
|
||||
return vars(self)
|
||||
else:
|
||||
return {"role": self.role_name, "content": role_content}
|
||||
|
||||
def is_system_role(self,):
|
||||
return self.role_type == "system"
|
||||
|
||||
def __str__(self) -> str:
|
||||
# key_str = '\n'.join([k for k, v in vars(self).items()])
|
||||
# logger.debug(f"{key_str}")
|
||||
return "\n".join([": ".join([k, str(v)]) for k, v in vars(self).items()])
|
||||
|
||||
|
||||
|
||||
def load_role_configs(config) -> Dict[str, AgentConfig]:
|
||||
if isinstance(config, str):
|
||||
with open(config, 'r', encoding="utf8") as file:
|
||||
configs = json.load(file)
|
||||
else:
|
||||
configs = config
|
||||
|
||||
return {name: AgentConfig(**v) for name, v in configs.items()}
|
||||
|
||||
|
||||
def load_chain_configs(config) -> Dict[str, ChainConfig]:
|
||||
if isinstance(config, str):
|
||||
with open(config, 'r', encoding="utf8") as file:
|
||||
configs = json.load(file)
|
||||
else:
|
||||
configs = config
|
||||
return {name: ChainConfig(**v) for name, v in configs.items()}
|
||||
|
||||
|
||||
def load_phase_configs(config) -> Dict[str, PhaseConfig]:
|
||||
if isinstance(config, str):
|
||||
with open(config, 'r', encoding="utf8") as file:
|
||||
configs = json.load(file)
|
||||
else:
|
||||
configs = config
|
||||
return {name: PhaseConfig(**v) for name, v in configs.items()}
|
|
@ -0,0 +1,3 @@
|
|||
from .base_phase import BasePhase
|
||||
|
||||
__all__ = ["BasePhase"]
|
|
@ -0,0 +1,215 @@
|
|||
from typing import List, Union, Dict, Tuple
|
||||
import os
|
||||
import json
|
||||
import importlib
|
||||
import copy
|
||||
from loguru import logger
|
||||
|
||||
from dev_opsgpt.connector.agents import BaseAgent
|
||||
from dev_opsgpt.connector.chains import BaseChain
|
||||
from dev_opsgpt.tools.base_tool import BaseTools, Tool
|
||||
from dev_opsgpt.connector.shcema.memory import Memory
|
||||
from dev_opsgpt.connector.connector_schema import (
|
||||
Task, Env, Role, Message, Doc, Docs, AgentConfig, ChainConfig, PhaseConfig, CodeDoc,
|
||||
load_chain_configs, load_phase_configs, load_role_configs
|
||||
)
|
||||
from dev_opsgpt.connector.configs import AGETN_CONFIGS, CHAIN_CONFIGS, PHASE_CONFIGS
|
||||
from dev_opsgpt.tools import DDGSTool, DocRetrieval, CodeRetrieval
|
||||
|
||||
|
||||
role_configs = load_role_configs(AGETN_CONFIGS)
|
||||
chain_configs = load_chain_configs(CHAIN_CONFIGS)
|
||||
phase_configs = load_phase_configs(PHASE_CONFIGS)
|
||||
|
||||
|
||||
CUR_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||
|
||||
|
||||
class BasePhase:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
phase_name: str,
|
||||
task: Task = None,
|
||||
do_summary: bool = False,
|
||||
do_search: bool = False,
|
||||
do_doc_retrieval: bool = False,
|
||||
do_code_retrieval: bool = False,
|
||||
do_tool_retrieval: bool = False,
|
||||
phase_config: Union[dict, str] = PHASE_CONFIGS,
|
||||
chain_config: Union[dict, str] = CHAIN_CONFIGS,
|
||||
role_config: Union[dict, str] = AGETN_CONFIGS,
|
||||
) -> None:
|
||||
self.conv_summary_agent = BaseAgent(role=role_configs["conv_summary"].role,
|
||||
task = None,
|
||||
memory = None,
|
||||
do_search = role_configs["conv_summary"].do_search,
|
||||
do_doc_retrieval = role_configs["conv_summary"].do_doc_retrieval,
|
||||
do_tool_retrieval = role_configs["conv_summary"].do_tool_retrieval,
|
||||
do_filter=False, do_use_self_memory=False)
|
||||
|
||||
self.chains: List[BaseChain] = self.init_chains(
|
||||
phase_name,
|
||||
task=task,
|
||||
memory=None,
|
||||
phase_config = phase_config,
|
||||
chain_config = chain_config,
|
||||
role_config = role_config,
|
||||
)
|
||||
self.phase_name = phase_name
|
||||
self.do_summary = do_summary
|
||||
self.do_search = do_search
|
||||
self.do_code_retrieval = do_code_retrieval
|
||||
self.do_doc_retrieval = do_doc_retrieval
|
||||
self.do_tool_retrieval = do_tool_retrieval
|
||||
|
||||
self.global_message = Memory([])
|
||||
# self.chain_message = Memory([])
|
||||
self.phase_memory: List[Memory] = []
|
||||
|
||||
def step(self, query: Message, history: Memory = None) -> Tuple[Message, Memory]:
|
||||
summary_message = None
|
||||
chain_message = Memory([])
|
||||
local_memory = Memory([])
|
||||
# do_search、do_doc_search、do_code_search
|
||||
query = self.get_extrainfo_step(query)
|
||||
input_message = copy.deepcopy(query)
|
||||
|
||||
self.global_message.append(input_message)
|
||||
for chain in self.chains:
|
||||
# chain can supply background and query to next chain
|
||||
output_message, chain_memory = chain.step(input_message, history, background=chain_message)
|
||||
output_message = self.inherit_extrainfo(input_message, output_message)
|
||||
input_message = output_message
|
||||
logger.info(f"{chain.chainConfig.chain_name} phase_step: {output_message.role_content}")
|
||||
|
||||
self.global_message.append(output_message)
|
||||
local_memory.extend(chain_memory)
|
||||
|
||||
# whether use summary_llm
|
||||
if self.do_summary:
|
||||
logger.info(f"{self.conv_summary_agent.role.role_name} input global memory: {self.global_message.to_str_messages(content_key='step_content')}")
|
||||
logger.info(f"{self.conv_summary_agent.role.role_name} input global memory: {self.global_message.to_str_messages(content_key='role_content')}")
|
||||
summary_message = self.conv_summary_agent.run(query, background=self.global_message)
|
||||
summary_message.role_name = chain.chainConfig.chain_name
|
||||
summary_message = self.conv_summary_agent.parser(summary_message)
|
||||
summary_message = self.conv_summary_agent.filter(summary_message)
|
||||
summary_message = self.inherit_extrainfo(output_message, summary_message)
|
||||
chain_message.append(summary_message)
|
||||
|
||||
# 由于不会存在多轮chain执行,所以直接保留memory即可
|
||||
for chain in self.chains:
|
||||
self.phase_memory.append(chain.global_memory)
|
||||
|
||||
message = summary_message or output_message
|
||||
message.role_name = self.phase_name
|
||||
# message.db_docs = query.db_docs
|
||||
# message.code_docs = query.code_docs
|
||||
# message.search_docs = query.search_docs
|
||||
return summary_message or output_message, local_memory
|
||||
|
||||
def init_chains(self, phase_name, phase_config, chain_config,
|
||||
role_config, task=None, memory=None) -> List[BaseChain]:
|
||||
# load config
|
||||
role_configs = load_role_configs(role_config)
|
||||
chain_configs = load_chain_configs(chain_config)
|
||||
phase_configs = load_phase_configs(phase_config)
|
||||
|
||||
chains = []
|
||||
self.chain_module = importlib.import_module("dev_opsgpt.connector.chains")
|
||||
self.agent_module = importlib.import_module("dev_opsgpt.connector.agents")
|
||||
phase = phase_configs.get(phase_name)
|
||||
for chain_name in phase.chains:
|
||||
logger.info(f"chain_name: {chain_name}")
|
||||
# chain_class = getattr(self.chain_module, chain_name)
|
||||
logger.debug(f"{chain_configs.keys()}")
|
||||
chain_config = chain_configs[chain_name]
|
||||
|
||||
agents = [
|
||||
getattr(self.agent_module, role_configs[agent_name].role.agent_type)(
|
||||
role_configs[agent_name].role,
|
||||
task = task,
|
||||
memory = memory,
|
||||
chat_turn=role_configs[agent_name].chat_turn,
|
||||
do_search = role_configs[agent_name].do_search,
|
||||
do_doc_retrieval = role_configs[agent_name].do_doc_retrieval,
|
||||
do_tool_retrieval = role_configs[agent_name].do_tool_retrieval,
|
||||
)
|
||||
for agent_name in chain_config.agents
|
||||
]
|
||||
chain_instance = BaseChain(
|
||||
chain_config, agents, chain_config.chat_turn,
|
||||
do_checker=chain_configs[chain_name].do_checker,
|
||||
do_code_exec=False,)
|
||||
chains.append(chain_instance)
|
||||
|
||||
return chains
|
||||
|
||||
def get_extrainfo_step(self, input_message):
|
||||
if self.do_doc_retrieval:
|
||||
input_message = self.get_doc_retrieval(input_message)
|
||||
|
||||
logger.debug(F"self.do_code_retrieval: {self.do_code_retrieval}")
|
||||
if self.do_code_retrieval:
|
||||
input_message = self.get_code_retrieval(input_message)
|
||||
|
||||
if self.do_search:
|
||||
input_message = self.get_search_retrieval(input_message)
|
||||
|
||||
return input_message
|
||||
|
||||
def inherit_extrainfo(self, input_message: Message, output_message: Message):
|
||||
output_message.db_docs = input_message.db_docs
|
||||
output_message.search_docs = input_message.search_docs
|
||||
output_message.code_docs = input_message.code_docs
|
||||
output_message.figures.update(input_message.figures)
|
||||
return output_message
|
||||
|
||||
def get_search_retrieval(self, message: Message,) -> Message:
|
||||
SEARCH_ENGINES = {"duckduckgo": DDGSTool}
|
||||
search_docs = []
|
||||
for idx, doc in enumerate(SEARCH_ENGINES["duckduckgo"].run(message.role_content, 3)):
|
||||
doc.update({"index": idx})
|
||||
search_docs.append(Doc(**doc))
|
||||
message.search_docs = search_docs
|
||||
return message
|
||||
|
||||
def get_doc_retrieval(self, message: Message) -> Message:
|
||||
query = message.role_content
|
||||
knowledge_basename = message.doc_engine_name
|
||||
top_k = message.top_k
|
||||
score_threshold = message.score_threshold
|
||||
if knowledge_basename:
|
||||
docs = DocRetrieval.run(query, knowledge_basename, top_k, score_threshold)
|
||||
message.db_docs = [Doc(**doc) for doc in docs]
|
||||
return message
|
||||
|
||||
def get_code_retrieval(self, message: Message) -> Message:
|
||||
# DocRetrieval.run("langchain是什么", "DSADSAD")
|
||||
query = message.input_query
|
||||
code_engine_name = message.code_engine_name
|
||||
history_node_list = message.history_node_list
|
||||
code_docs = CodeRetrieval.run(code_engine_name, query, code_limit=message.top_k, history_node_list=history_node_list)
|
||||
message.code_docs = [CodeDoc(**doc) for doc in code_docs]
|
||||
return message
|
||||
|
||||
def get_tool_retrieval(self, message: Message) -> Message:
|
||||
return message
|
||||
|
||||
def update(self) -> Memory:
|
||||
pass
|
||||
|
||||
def get_memory(self, ) -> Memory:
|
||||
return Memory.from_memory_list(
|
||||
[chain.get_memory() for chain in self.chains]
|
||||
)
|
||||
|
||||
def get_memory_str(self, do_all_memory=True, content_key="role_content") -> str:
|
||||
memory = self.global_message if do_all_memory else self.phase_memory
|
||||
return "\n".join([": ".join(i) for i in memory.to_tuple_messages(content_key=content_key)])
|
||||
|
||||
def get_chains_memory(self, content_key="role_content") -> List[Tuple]:
|
||||
return [memory.to_tuple_messages(content_key=content_key) for memory in self.phase_memory]
|
||||
|
||||
def get_chains_memory_str(self, content_key="role_content") -> str:
|
||||
return "************".join([f"{chain.chainConfig.chain_name}\n" + chain.get_memory_str(content_key=content_key) for chain in self.chains])
|
|
@ -0,0 +1,6 @@
|
|||
from .memory import Memory
|
||||
|
||||
|
||||
__all__ = [
|
||||
"Memory"
|
||||
]
|
|
@ -0,0 +1,88 @@
|
|||
from pydantic import BaseModel
|
||||
from typing import List
|
||||
from loguru import logger
|
||||
|
||||
from dev_opsgpt.connector.connector_schema import Message
|
||||
from dev_opsgpt.utils.common_utils import (
|
||||
save_to_jsonl_file, save_to_json_file, read_json_file, read_jsonl_file
|
||||
)
|
||||
|
||||
|
||||
class Memory:
|
||||
|
||||
def __init__(self, messages: List[Message] = []):
|
||||
self.messages = messages
|
||||
|
||||
def append(self, message: Message):
|
||||
self.messages.append(message)
|
||||
|
||||
def extend(self, memory: 'Memory'):
|
||||
self.messages.extend(memory.messages)
|
||||
|
||||
def update(self, role_name: str, role_type: str, role_content: str):
|
||||
self.messages.append(Message(role_name, role_type, role_content, role_content))
|
||||
|
||||
def clear(self, ):
|
||||
self.messages = []
|
||||
|
||||
def delete(self, ):
|
||||
pass
|
||||
|
||||
def get_messages(self, ) -> List[Message]:
|
||||
return self.messages
|
||||
|
||||
def save(self, file_type="jsonl", return_all=True):
|
||||
try:
|
||||
if file_type == "jsonl":
|
||||
save_to_jsonl_file(self.to_dict_messages(return_all=return_all), "role_name_history"+f".{file_type}")
|
||||
return True
|
||||
elif file_type in ["json", "txt"]:
|
||||
save_to_json_file(self.to_dict_messages(return_all=return_all), "role_name_history"+f".{file_type}")
|
||||
return True
|
||||
except:
|
||||
return False
|
||||
return False
|
||||
|
||||
def load(self, filepath):
|
||||
file_type = filepath
|
||||
try:
|
||||
if file_type == "jsonl":
|
||||
self.messages = [Message(**message) for message in read_jsonl_file(filepath)]
|
||||
return True
|
||||
elif file_type in ["json", "txt"]:
|
||||
self.messages = [Message(**message) for message in read_jsonl_file(filepath)]
|
||||
return True
|
||||
except:
|
||||
return False
|
||||
|
||||
return False
|
||||
|
||||
def to_tuple_messages(self, return_system: bool = False, return_all: bool = False, content_key="role_content"):
|
||||
# logger.debug(f"{[message.to_tuple_message(return_all, content_key) for message in self.messages ]}")
|
||||
return [
|
||||
message.to_tuple_message(return_all, content_key) for message in self.messages
|
||||
if not message.is_system_role() | return_system
|
||||
]
|
||||
|
||||
def to_dict_messages(self, return_system: bool = False, return_all: bool = False, content_key="role_content"):
|
||||
return [
|
||||
message.to_dict_message(return_all, content_key) for message in self.messages
|
||||
if not message.is_system_role() | return_system
|
||||
]
|
||||
|
||||
def to_str_messages(self, return_system: bool = False, return_all: bool = False, content_key="role_content"):
|
||||
# logger.debug(f"{[message.to_tuple_message(return_all, content_key) for message in self.messages ]}")
|
||||
return "\n".join([
|
||||
": ".join(message.to_tuple_message(return_all, content_key)) for message in self.messages
|
||||
if not message.is_system_role() | return_system
|
||||
])
|
||||
|
||||
@classmethod
|
||||
def from_memory_list(cls, memorys: List['Memory']) -> 'Memory':
|
||||
return cls([message for memory in memorys for message in memory.get_messages()])
|
||||
|
||||
def __len__(self, ):
|
||||
return len(self.messages)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return "\n".join([":".join(i) for i in self.to_tuple_messages()])
|
|
@ -0,0 +1,27 @@
|
|||
|
||||
|
||||
def prompt_cost(model_type: str, num_prompt_tokens: float, num_completion_tokens: float):
|
||||
input_cost_map = {
|
||||
"gpt-3.5-turbo": 0.0015,
|
||||
"gpt-3.5-turbo-16k": 0.003,
|
||||
"gpt-3.5-turbo-0613": 0.0015,
|
||||
"gpt-3.5-turbo-16k-0613": 0.003,
|
||||
"gpt-4": 0.03,
|
||||
"gpt-4-0613": 0.03,
|
||||
"gpt-4-32k": 0.06,
|
||||
}
|
||||
|
||||
output_cost_map = {
|
||||
"gpt-3.5-turbo": 0.002,
|
||||
"gpt-3.5-turbo-16k": 0.004,
|
||||
"gpt-3.5-turbo-0613": 0.002,
|
||||
"gpt-3.5-turbo-16k-0613": 0.004,
|
||||
"gpt-4": 0.06,
|
||||
"gpt-4-0613": 0.06,
|
||||
"gpt-4-32k": 0.12,
|
||||
}
|
||||
|
||||
if model_type not in input_cost_map or model_type not in output_cost_map:
|
||||
return -1
|
||||
|
||||
return num_prompt_tokens * input_cost_map[model_type] / 1000.0 + num_completion_tokens * output_cost_map[model_type] / 1000.0
|
|
@ -4,6 +4,7 @@ from typing import AnyStr, Callable, Dict, List, Optional, Union
|
|||
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.document_loaders.base import BaseLoader
|
||||
from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter
|
||||
|
||||
from dev_opsgpt.utils.common_utils import read_json_file
|
||||
|
||||
|
@ -39,3 +40,22 @@ class JSONLoader(BaseLoader):
|
|||
)
|
||||
text = sample.get(self.schema_key, "")
|
||||
docs.append(Document(page_content=text, metadata=metadata))
|
||||
|
||||
def load_and_split(
|
||||
self, text_splitter: Optional[TextSplitter] = None
|
||||
) -> List[Document]:
|
||||
"""Load Documents and split into chunks. Chunks are returned as Documents.
|
||||
|
||||
Args:
|
||||
text_splitter: TextSplitter instance to use for splitting documents.
|
||||
Defaults to RecursiveCharacterTextSplitter.
|
||||
|
||||
Returns:
|
||||
List of Documents.
|
||||
"""
|
||||
if text_splitter is None:
|
||||
_text_splitter: TextSplitter = RecursiveCharacterTextSplitter()
|
||||
else:
|
||||
_text_splitter = text_splitter
|
||||
docs = self.load()
|
||||
return _text_splitter.split_documents(docs)
|
|
@ -4,6 +4,7 @@ from typing import AnyStr, Callable, Dict, List, Optional, Union
|
|||
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.document_loaders.base import BaseLoader
|
||||
from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter
|
||||
|
||||
from dev_opsgpt.utils.common_utils import read_jsonl_file
|
||||
|
||||
|
@ -39,3 +40,23 @@ class JSONLLoader(BaseLoader):
|
|||
)
|
||||
text = sample.get(self.schema_key, "")
|
||||
docs.append(Document(page_content=text, metadata=metadata))
|
||||
|
||||
def load_and_split(
|
||||
self, text_splitter: Optional[TextSplitter] = None
|
||||
) -> List[Document]:
|
||||
"""Load Documents and split into chunks. Chunks are returned as Documents.
|
||||
|
||||
Args:
|
||||
text_splitter: TextSplitter instance to use for splitting documents.
|
||||
Defaults to RecursiveCharacterTextSplitter.
|
||||
|
||||
Returns:
|
||||
List of Documents.
|
||||
"""
|
||||
if text_splitter is None:
|
||||
_text_splitter: TextSplitter = RecursiveCharacterTextSplitter()
|
||||
else:
|
||||
_text_splitter = text_splitter
|
||||
|
||||
docs = self.load()
|
||||
return _text_splitter.split_documents(docs)
|
|
@ -0,0 +1,776 @@
|
|||
"""Wrapper around FAISS vector database."""
|
||||
from __future__ import annotations
|
||||
|
||||
import operator
|
||||
import os
|
||||
import pickle
|
||||
import uuid
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
Iterable,
|
||||
List,
|
||||
Optional,
|
||||
Sized,
|
||||
Tuple,
|
||||
)
|
||||
|
||||
import numpy as np
|
||||
|
||||
from langchain.docstore.base import AddableMixin, Docstore
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.docstore.in_memory import InMemoryDocstore
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.vectorstores.base import VectorStore
|
||||
from langchain.vectorstores.utils import DistanceStrategy, maximal_marginal_relevance
|
||||
|
||||
|
||||
def dependable_faiss_import(no_avx2: Optional[bool] = None) -> Any:
|
||||
"""
|
||||
Import faiss if available, otherwise raise error.
|
||||
If FAISS_NO_AVX2 environment variable is set, it will be considered
|
||||
to load FAISS with no AVX2 optimization.
|
||||
|
||||
Args:
|
||||
no_avx2: Load FAISS strictly with no AVX2 optimization
|
||||
so that the vectorstore is portable and compatible with other devices.
|
||||
"""
|
||||
if no_avx2 is None and "FAISS_NO_AVX2" in os.environ:
|
||||
no_avx2 = bool(os.getenv("FAISS_NO_AVX2"))
|
||||
|
||||
try:
|
||||
if no_avx2:
|
||||
from faiss import swigfaiss as faiss
|
||||
else:
|
||||
import faiss
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Could not import faiss python package. "
|
||||
"Please install it with `pip install faiss-gpu` (for CUDA supported GPU) "
|
||||
"or `pip install faiss-cpu` (depending on Python version)."
|
||||
)
|
||||
return faiss
|
||||
|
||||
|
||||
def _len_check_if_sized(x: Any, y: Any, x_name: str, y_name: str) -> None:
|
||||
if isinstance(x, Sized) and isinstance(y, Sized) and len(x) != len(y):
|
||||
raise ValueError(
|
||||
f"{x_name} and {y_name} expected to be equal length but "
|
||||
f"len({x_name})={len(x)} and len({y_name})={len(y)}"
|
||||
)
|
||||
return
|
||||
|
||||
|
||||
class FAISS(VectorStore):
|
||||
"""Wrapper around FAISS vector database.
|
||||
|
||||
To use, you must have the ``faiss`` python package installed.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain.embeddings.openai import OpenAIEmbeddings
|
||||
from langchain.vectorstores import FAISS
|
||||
|
||||
embeddings = OpenAIEmbeddings()
|
||||
texts = ["FAISS is an important library", "LangChain supports FAISS"]
|
||||
faiss = FAISS.from_texts(texts, embeddings)
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embedding_function: Callable,
|
||||
index: Any,
|
||||
docstore: Docstore,
|
||||
index_to_docstore_id: Dict[int, str],
|
||||
relevance_score_fn: Optional[Callable[[float], float]] = None,
|
||||
normalize_L2: bool = False,
|
||||
distance_strategy: DistanceStrategy = DistanceStrategy.EUCLIDEAN_DISTANCE,
|
||||
):
|
||||
"""Initialize with necessary components."""
|
||||
self.embedding_function = embedding_function
|
||||
self.index = index
|
||||
self.docstore = docstore
|
||||
self.index_to_docstore_id = index_to_docstore_id
|
||||
self.distance_strategy = distance_strategy
|
||||
self.override_relevance_score_fn = relevance_score_fn
|
||||
self._normalize_L2 = normalize_L2
|
||||
if (
|
||||
self.distance_strategy != DistanceStrategy.EUCLIDEAN_DISTANCE
|
||||
and self._normalize_L2
|
||||
):
|
||||
warnings.warn(
|
||||
"Normalizing L2 is not applicable for metric type: {strategy}".format(
|
||||
strategy=self.distance_strategy
|
||||
)
|
||||
)
|
||||
|
||||
def __add(
|
||||
self,
|
||||
texts: Iterable[str],
|
||||
embeddings: Iterable[List[float]],
|
||||
metadatas: Optional[Iterable[dict]] = None,
|
||||
ids: Optional[List[str]] = None,
|
||||
) -> List[str]:
|
||||
faiss = dependable_faiss_import()
|
||||
|
||||
if not isinstance(self.docstore, AddableMixin):
|
||||
raise ValueError(
|
||||
"If trying to add texts, the underlying docstore should support "
|
||||
f"adding items, which {self.docstore} does not"
|
||||
)
|
||||
|
||||
_len_check_if_sized(texts, metadatas, "texts", "metadatas")
|
||||
_metadatas = metadatas or ({} for _ in texts)
|
||||
documents = [
|
||||
Document(page_content=t, metadata=m) for t, m in zip(texts, _metadatas)
|
||||
]
|
||||
|
||||
_len_check_if_sized(documents, embeddings, "documents", "embeddings")
|
||||
_len_check_if_sized(documents, ids, "documents", "ids")
|
||||
|
||||
# Add to the index.
|
||||
vector = np.array(embeddings, dtype=np.float32)
|
||||
if self._normalize_L2:
|
||||
faiss.normalize_L2(vector)
|
||||
self.index.add(vector)
|
||||
|
||||
# Add information to docstore and index.
|
||||
ids = ids or [str(uuid.uuid4()) for _ in texts]
|
||||
self.docstore.add({id_: doc for id_, doc in zip(ids, documents)})
|
||||
starting_len = len(self.index_to_docstore_id)
|
||||
index_to_id = {starting_len + j: id_ for j, id_ in enumerate(ids)}
|
||||
self.index_to_docstore_id.update(index_to_id)
|
||||
return ids
|
||||
|
||||
def add_texts(
|
||||
self,
|
||||
texts: Iterable[str],
|
||||
metadatas: Optional[List[dict]] = None,
|
||||
ids: Optional[List[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> List[str]:
|
||||
"""Run more texts through the embeddings and add to the vectorstore.
|
||||
|
||||
Args:
|
||||
texts: Iterable of strings to add to the vectorstore.
|
||||
metadatas: Optional list of metadatas associated with the texts.
|
||||
ids: Optional list of unique IDs.
|
||||
|
||||
Returns:
|
||||
List of ids from adding the texts into the vectorstore.
|
||||
"""
|
||||
# embeddings = [self.embedding_function(text) for text in texts]
|
||||
embeddings = self.embedding_function(texts)
|
||||
return self.__add(texts, embeddings, metadatas=metadatas, ids=ids)
|
||||
|
||||
def add_embeddings(
|
||||
self,
|
||||
text_embeddings: Iterable[Tuple[str, List[float]]],
|
||||
metadatas: Optional[List[dict]] = None,
|
||||
ids: Optional[List[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> List[str]:
|
||||
"""Run more texts through the embeddings and add to the vectorstore.
|
||||
|
||||
Args:
|
||||
text_embeddings: Iterable pairs of string and embedding to
|
||||
add to the vectorstore.
|
||||
metadatas: Optional list of metadatas associated with the texts.
|
||||
ids: Optional list of unique IDs.
|
||||
|
||||
Returns:
|
||||
List of ids from adding the texts into the vectorstore.
|
||||
"""
|
||||
# Embed and create the documents.
|
||||
texts, embeddings = zip(*text_embeddings)
|
||||
return self.__add(texts, embeddings, metadatas=metadatas, ids=ids)
|
||||
|
||||
def similarity_search_with_score_by_vector(
|
||||
self,
|
||||
embedding: List[float],
|
||||
k: int = 4,
|
||||
filter: Optional[Dict[str, Any]] = None,
|
||||
fetch_k: int = 20,
|
||||
**kwargs: Any,
|
||||
) -> List[Tuple[Document, float]]:
|
||||
"""Return docs most similar to query.
|
||||
|
||||
Args:
|
||||
embedding: Embedding vector to look up documents similar to.
|
||||
k: Number of Documents to return. Defaults to 4.
|
||||
filter (Optional[Dict[str, Any]]): Filter by metadata. Defaults to None.
|
||||
fetch_k: (Optional[int]) Number of Documents to fetch before filtering.
|
||||
Defaults to 20.
|
||||
**kwargs: kwargs to be passed to similarity search. Can include:
|
||||
score_threshold: Optional, a floating point value between 0 to 1 to
|
||||
filter the resulting set of retrieved docs
|
||||
|
||||
Returns:
|
||||
List of documents most similar to the query text and L2 distance
|
||||
in float for each. Lower score represents more similarity.
|
||||
"""
|
||||
faiss = dependable_faiss_import()
|
||||
vector = np.array([embedding], dtype=np.float32)
|
||||
if self._normalize_L2:
|
||||
faiss.normalize_L2(vector)
|
||||
scores, indices = self.index.search(vector, k if filter is None else fetch_k)
|
||||
docs = []
|
||||
for j, i in enumerate(indices[0]):
|
||||
if i == -1:
|
||||
# This happens when not enough docs are returned.
|
||||
continue
|
||||
_id = self.index_to_docstore_id[i]
|
||||
doc = self.docstore.search(_id)
|
||||
if not isinstance(doc, Document):
|
||||
raise ValueError(f"Could not find document for id {_id}, got {doc}")
|
||||
if filter is not None:
|
||||
filter = {
|
||||
key: [value] if not isinstance(value, list) else value
|
||||
for key, value in filter.items()
|
||||
}
|
||||
if all(doc.metadata.get(key) in value for key, value in filter.items()):
|
||||
docs.append((doc, scores[0][j]))
|
||||
else:
|
||||
docs.append((doc, scores[0][j]))
|
||||
|
||||
score_threshold = kwargs.get("score_threshold")
|
||||
if score_threshold is not None:
|
||||
cmp = (
|
||||
operator.ge
|
||||
if self.distance_strategy
|
||||
in (DistanceStrategy.MAX_INNER_PRODUCT, DistanceStrategy.JACCARD)
|
||||
else operator.le
|
||||
)
|
||||
docs = [
|
||||
(doc, similarity)
|
||||
for doc, similarity in docs
|
||||
if cmp(similarity, score_threshold)
|
||||
]
|
||||
return docs[:k]
|
||||
|
||||
def similarity_search_with_score(
|
||||
self,
|
||||
query: str,
|
||||
k: int = 4,
|
||||
filter: Optional[Dict[str, Any]] = None,
|
||||
fetch_k: int = 20,
|
||||
**kwargs: Any,
|
||||
) -> List[Tuple[Document, float]]:
|
||||
"""Return docs most similar to query.
|
||||
|
||||
Args:
|
||||
query: Text to look up documents similar to.
|
||||
k: Number of Documents to return. Defaults to 4.
|
||||
filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.
|
||||
fetch_k: (Optional[int]) Number of Documents to fetch before filtering.
|
||||
Defaults to 20.
|
||||
|
||||
Returns:
|
||||
List of documents most similar to the query text with
|
||||
L2 distance in float. Lower score represents more similarity.
|
||||
"""
|
||||
embedding = self.embedding_function(query)
|
||||
docs = self.similarity_search_with_score_by_vector(
|
||||
embedding,
|
||||
k,
|
||||
filter=filter,
|
||||
fetch_k=fetch_k,
|
||||
**kwargs,
|
||||
)
|
||||
return docs
|
||||
|
||||
def similarity_search_by_vector(
|
||||
self,
|
||||
embedding: List[float],
|
||||
k: int = 4,
|
||||
filter: Optional[Dict[str, Any]] = None,
|
||||
fetch_k: int = 20,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
"""Return docs most similar to embedding vector.
|
||||
|
||||
Args:
|
||||
embedding: Embedding to look up documents similar to.
|
||||
k: Number of Documents to return. Defaults to 4.
|
||||
filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.
|
||||
fetch_k: (Optional[int]) Number of Documents to fetch before filtering.
|
||||
Defaults to 20.
|
||||
|
||||
Returns:
|
||||
List of Documents most similar to the embedding.
|
||||
"""
|
||||
docs_and_scores = self.similarity_search_with_score_by_vector(
|
||||
embedding,
|
||||
k,
|
||||
filter=filter,
|
||||
fetch_k=fetch_k,
|
||||
**kwargs,
|
||||
)
|
||||
return [doc for doc, _ in docs_and_scores]
|
||||
|
||||
def similarity_search(
|
||||
self,
|
||||
query: str,
|
||||
k: int = 4,
|
||||
filter: Optional[Dict[str, Any]] = None,
|
||||
fetch_k: int = 20,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
"""Return docs most similar to query.
|
||||
|
||||
Args:
|
||||
query: Text to look up documents similar to.
|
||||
k: Number of Documents to return. Defaults to 4.
|
||||
filter: (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.
|
||||
fetch_k: (Optional[int]) Number of Documents to fetch before filtering.
|
||||
Defaults to 20.
|
||||
|
||||
Returns:
|
||||
List of Documents most similar to the query.
|
||||
"""
|
||||
docs_and_scores = self.similarity_search_with_score(
|
||||
query, k, filter=filter, fetch_k=fetch_k, **kwargs
|
||||
)
|
||||
return [doc for doc, _ in docs_and_scores]
|
||||
|
||||
def max_marginal_relevance_search_with_score_by_vector(
|
||||
self,
|
||||
embedding: List[float],
|
||||
*,
|
||||
k: int = 4,
|
||||
fetch_k: int = 20,
|
||||
lambda_mult: float = 0.5,
|
||||
filter: Optional[Dict[str, Any]] = None,
|
||||
) -> List[Tuple[Document, float]]:
|
||||
"""Return docs and their similarity scores selected using the maximal marginal
|
||||
relevance.
|
||||
|
||||
Maximal marginal relevance optimizes for similarity to query AND diversity
|
||||
among selected documents.
|
||||
|
||||
Args:
|
||||
embedding: Embedding to look up documents similar to.
|
||||
k: Number of Documents to return. Defaults to 4.
|
||||
fetch_k: Number of Documents to fetch before filtering to
|
||||
pass to MMR algorithm.
|
||||
lambda_mult: Number between 0 and 1 that determines the degree
|
||||
of diversity among the results with 0 corresponding
|
||||
to maximum diversity and 1 to minimum diversity.
|
||||
Defaults to 0.5.
|
||||
Returns:
|
||||
List of Documents and similarity scores selected by maximal marginal
|
||||
relevance and score for each.
|
||||
"""
|
||||
scores, indices = self.index.search(
|
||||
np.array([embedding], dtype=np.float32),
|
||||
fetch_k if filter is None else fetch_k * 2,
|
||||
)
|
||||
if filter is not None:
|
||||
filtered_indices = []
|
||||
for i in indices[0]:
|
||||
if i == -1:
|
||||
# This happens when not enough docs are returned.
|
||||
continue
|
||||
_id = self.index_to_docstore_id[i]
|
||||
doc = self.docstore.search(_id)
|
||||
if not isinstance(doc, Document):
|
||||
raise ValueError(f"Could not find document for id {_id}, got {doc}")
|
||||
if all(
|
||||
doc.metadata.get(key) in value
|
||||
if isinstance(value, list)
|
||||
else doc.metadata.get(key) == value
|
||||
for key, value in filter.items()
|
||||
):
|
||||
filtered_indices.append(i)
|
||||
indices = np.array([filtered_indices])
|
||||
# -1 happens when not enough docs are returned.
|
||||
embeddings = [self.index.reconstruct(int(i)) for i in indices[0] if i != -1]
|
||||
mmr_selected = maximal_marginal_relevance(
|
||||
np.array([embedding], dtype=np.float32),
|
||||
embeddings,
|
||||
k=k,
|
||||
lambda_mult=lambda_mult,
|
||||
)
|
||||
selected_indices = [indices[0][i] for i in mmr_selected]
|
||||
selected_scores = [scores[0][i] for i in mmr_selected]
|
||||
docs_and_scores = []
|
||||
for i, score in zip(selected_indices, selected_scores):
|
||||
if i == -1:
|
||||
# This happens when not enough docs are returned.
|
||||
continue
|
||||
_id = self.index_to_docstore_id[i]
|
||||
doc = self.docstore.search(_id)
|
||||
if not isinstance(doc, Document):
|
||||
raise ValueError(f"Could not find document for id {_id}, got {doc}")
|
||||
docs_and_scores.append((doc, score))
|
||||
return docs_and_scores
|
||||
|
||||
def max_marginal_relevance_search_by_vector(
|
||||
self,
|
||||
embedding: List[float],
|
||||
k: int = 4,
|
||||
fetch_k: int = 20,
|
||||
lambda_mult: float = 0.5,
|
||||
filter: Optional[Dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
"""Return docs selected using the maximal marginal relevance.
|
||||
|
||||
Maximal marginal relevance optimizes for similarity to query AND diversity
|
||||
among selected documents.
|
||||
|
||||
Args:
|
||||
embedding: Embedding to look up documents similar to.
|
||||
k: Number of Documents to return. Defaults to 4.
|
||||
fetch_k: Number of Documents to fetch before filtering to
|
||||
pass to MMR algorithm.
|
||||
lambda_mult: Number between 0 and 1 that determines the degree
|
||||
of diversity among the results with 0 corresponding
|
||||
to maximum diversity and 1 to minimum diversity.
|
||||
Defaults to 0.5.
|
||||
Returns:
|
||||
List of Documents selected by maximal marginal relevance.
|
||||
"""
|
||||
docs_and_scores = self.max_marginal_relevance_search_with_score_by_vector(
|
||||
embedding, k=k, fetch_k=fetch_k, lambda_mult=lambda_mult, filter=filter
|
||||
)
|
||||
return [doc for doc, _ in docs_and_scores]
|
||||
|
||||
def max_marginal_relevance_search(
|
||||
self,
|
||||
query: str,
|
||||
k: int = 4,
|
||||
fetch_k: int = 20,
|
||||
lambda_mult: float = 0.5,
|
||||
filter: Optional[Dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
"""Return docs selected using the maximal marginal relevance.
|
||||
|
||||
Maximal marginal relevance optimizes for similarity to query AND diversity
|
||||
among selected documents.
|
||||
|
||||
Args:
|
||||
query: Text to look up documents similar to.
|
||||
k: Number of Documents to return. Defaults to 4.
|
||||
fetch_k: Number of Documents to fetch before filtering (if needed) to
|
||||
pass to MMR algorithm.
|
||||
lambda_mult: Number between 0 and 1 that determines the degree
|
||||
of diversity among the results with 0 corresponding
|
||||
to maximum diversity and 1 to minimum diversity.
|
||||
Defaults to 0.5.
|
||||
Returns:
|
||||
List of Documents selected by maximal marginal relevance.
|
||||
"""
|
||||
embedding = self.embedding_function(query)
|
||||
docs = self.max_marginal_relevance_search_by_vector(
|
||||
embedding,
|
||||
k=k,
|
||||
fetch_k=fetch_k,
|
||||
lambda_mult=lambda_mult,
|
||||
filter=filter,
|
||||
**kwargs,
|
||||
)
|
||||
return docs
|
||||
|
||||
def delete(self, ids: Optional[List[str]] = None, **kwargs: Any) -> Optional[bool]:
|
||||
"""Delete by ID. These are the IDs in the vectorstore.
|
||||
|
||||
Args:
|
||||
ids: List of ids to delete.
|
||||
|
||||
Returns:
|
||||
Optional[bool]: True if deletion is successful,
|
||||
False otherwise, None if not implemented.
|
||||
"""
|
||||
if ids is None:
|
||||
raise ValueError("No ids provided to delete.")
|
||||
missing_ids = set(ids).difference(self.index_to_docstore_id.values())
|
||||
if missing_ids:
|
||||
raise ValueError(
|
||||
f"Some specified ids do not exist in the current store. Ids not found: "
|
||||
f"{missing_ids}"
|
||||
)
|
||||
|
||||
reversed_index = {id_: idx for idx, id_ in self.index_to_docstore_id.items()}
|
||||
index_to_delete = [reversed_index[id_] for id_ in ids]
|
||||
|
||||
self.index.remove_ids(np.array(index_to_delete, dtype=np.int64))
|
||||
self.docstore.delete(ids)
|
||||
|
||||
remaining_ids = [
|
||||
id_
|
||||
for i, id_ in sorted(self.index_to_docstore_id.items())
|
||||
if i not in index_to_delete
|
||||
]
|
||||
self.index_to_docstore_id = {i: id_ for i, id_ in enumerate(remaining_ids)}
|
||||
|
||||
return True
|
||||
|
||||
def merge_from(self, target: FAISS) -> None:
|
||||
"""Merge another FAISS object with the current one.
|
||||
|
||||
Add the target FAISS to the current one.
|
||||
|
||||
Args:
|
||||
target: FAISS object you wish to merge into the current one
|
||||
|
||||
Returns:
|
||||
None.
|
||||
"""
|
||||
if not isinstance(self.docstore, AddableMixin):
|
||||
raise ValueError("Cannot merge with this type of docstore")
|
||||
# Numerical index for target docs are incremental on existing ones
|
||||
starting_len = len(self.index_to_docstore_id)
|
||||
|
||||
# Merge two IndexFlatL2
|
||||
self.index.merge_from(target.index)
|
||||
|
||||
# Get id and docs from target FAISS object
|
||||
full_info = []
|
||||
for i, target_id in target.index_to_docstore_id.items():
|
||||
doc = target.docstore.search(target_id)
|
||||
if not isinstance(doc, Document):
|
||||
raise ValueError("Document should be returned")
|
||||
full_info.append((starting_len + i, target_id, doc))
|
||||
|
||||
# Add information to docstore and index_to_docstore_id.
|
||||
self.docstore.add({_id: doc for _, _id, doc in full_info})
|
||||
index_to_id = {index: _id for index, _id, _ in full_info}
|
||||
self.index_to_docstore_id.update(index_to_id)
|
||||
|
||||
@classmethod
|
||||
def __from(
|
||||
cls,
|
||||
texts: Iterable[str],
|
||||
embeddings: List[List[float]],
|
||||
embedding: Embeddings,
|
||||
metadatas: Optional[Iterable[dict]] = None,
|
||||
ids: Optional[List[str]] = None,
|
||||
normalize_L2: bool = False,
|
||||
distance_strategy: DistanceStrategy = DistanceStrategy.EUCLIDEAN_DISTANCE,
|
||||
**kwargs: Any,
|
||||
) -> FAISS:
|
||||
faiss = dependable_faiss_import()
|
||||
if distance_strategy == DistanceStrategy.MAX_INNER_PRODUCT:
|
||||
index = faiss.IndexFlatIP(len(embeddings[0]))
|
||||
else:
|
||||
# Default to L2, currently other metric types not initialized.
|
||||
index = faiss.IndexFlatL2(len(embeddings[0]))
|
||||
vecstore = cls(
|
||||
embedding.embed_query,
|
||||
index,
|
||||
InMemoryDocstore(),
|
||||
{},
|
||||
normalize_L2=normalize_L2,
|
||||
distance_strategy=distance_strategy,
|
||||
**kwargs,
|
||||
)
|
||||
vecstore.__add(texts, embeddings, metadatas=metadatas, ids=ids)
|
||||
return vecstore
|
||||
|
||||
@classmethod
|
||||
def from_texts(
|
||||
cls,
|
||||
texts: List[str],
|
||||
embedding: Embeddings,
|
||||
metadatas: Optional[List[dict]] = None,
|
||||
ids: Optional[List[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> FAISS:
|
||||
"""Construct FAISS wrapper from raw documents.
|
||||
|
||||
This is a user friendly interface that:
|
||||
1. Embeds documents.
|
||||
2. Creates an in memory docstore
|
||||
3. Initializes the FAISS database
|
||||
|
||||
This is intended to be a quick way to get started.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain import FAISS
|
||||
from langchain.embeddings import OpenAIEmbeddings
|
||||
|
||||
embeddings = OpenAIEmbeddings()
|
||||
faiss = FAISS.from_texts(texts, embeddings)
|
||||
"""
|
||||
from loguru import logger
|
||||
logger.debug(f"texts: {len(texts)}")
|
||||
embeddings = embedding.embed_documents(texts)
|
||||
return cls.__from(
|
||||
texts,
|
||||
embeddings,
|
||||
embedding,
|
||||
metadatas=metadatas,
|
||||
ids=ids,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_embeddings(
|
||||
cls,
|
||||
text_embeddings: Iterable[Tuple[str, List[float]]],
|
||||
embedding: Embeddings,
|
||||
metadatas: Optional[Iterable[dict]] = None,
|
||||
ids: Optional[List[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> FAISS:
|
||||
"""Construct FAISS wrapper from raw documents.
|
||||
|
||||
This is a user friendly interface that:
|
||||
1. Embeds documents.
|
||||
2. Creates an in memory docstore
|
||||
3. Initializes the FAISS database
|
||||
|
||||
This is intended to be a quick way to get started.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain import FAISS
|
||||
from langchain.embeddings import OpenAIEmbeddings
|
||||
|
||||
embeddings = OpenAIEmbeddings()
|
||||
text_embeddings = embeddings.embed_documents(texts)
|
||||
text_embedding_pairs = zip(texts, text_embeddings)
|
||||
faiss = FAISS.from_embeddings(text_embedding_pairs, embeddings)
|
||||
"""
|
||||
texts = [t[0] for t in text_embeddings]
|
||||
embeddings = [t[1] for t in text_embeddings]
|
||||
return cls.__from(
|
||||
texts,
|
||||
embeddings,
|
||||
embedding,
|
||||
metadatas=metadatas,
|
||||
ids=ids,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def save_local(self, folder_path: str, index_name: str = "index") -> None:
|
||||
"""Save FAISS index, docstore, and index_to_docstore_id to disk.
|
||||
|
||||
Args:
|
||||
folder_path: folder path to save index, docstore,
|
||||
and index_to_docstore_id to.
|
||||
index_name: for saving with a specific index file name
|
||||
"""
|
||||
path = Path(folder_path)
|
||||
path.mkdir(exist_ok=True, parents=True)
|
||||
|
||||
# save index separately since it is not picklable
|
||||
faiss = dependable_faiss_import()
|
||||
faiss.write_index(
|
||||
self.index, str(path / "{index_name}.faiss".format(index_name=index_name))
|
||||
)
|
||||
|
||||
# save docstore and index_to_docstore_id
|
||||
with open(path / "{index_name}.pkl".format(index_name=index_name), "wb") as f:
|
||||
pickle.dump((self.docstore, self.index_to_docstore_id), f)
|
||||
|
||||
@classmethod
|
||||
def load_local(
|
||||
cls,
|
||||
folder_path: str,
|
||||
embeddings: Embeddings,
|
||||
index_name: str = "index",
|
||||
**kwargs: Any,
|
||||
) -> FAISS:
|
||||
"""Load FAISS index, docstore, and index_to_docstore_id from disk.
|
||||
|
||||
Args:
|
||||
folder_path: folder path to load index, docstore,
|
||||
and index_to_docstore_id from.
|
||||
embeddings: Embeddings to use when generating queries
|
||||
index_name: for saving with a specific index file name
|
||||
"""
|
||||
path = Path(folder_path)
|
||||
# load index separately since it is not picklable
|
||||
faiss = dependable_faiss_import()
|
||||
index = faiss.read_index(
|
||||
str(path / "{index_name}.faiss".format(index_name=index_name))
|
||||
)
|
||||
|
||||
# load docstore and index_to_docstore_id
|
||||
with open(path / "{index_name}.pkl".format(index_name=index_name), "rb") as f:
|
||||
docstore, index_to_docstore_id = pickle.load(f)
|
||||
return cls(
|
||||
embeddings.embed_query, index, docstore, index_to_docstore_id, **kwargs
|
||||
)
|
||||
|
||||
def serialize_to_bytes(self) -> bytes:
|
||||
"""Serialize FAISS index, docstore, and index_to_docstore_id to bytes."""
|
||||
return pickle.dumps((self.index, self.docstore, self.index_to_docstore_id))
|
||||
|
||||
@classmethod
|
||||
def deserialize_from_bytes(
|
||||
cls,
|
||||
serialized: bytes,
|
||||
embeddings: Embeddings,
|
||||
**kwargs: Any,
|
||||
) -> FAISS:
|
||||
"""Deserialize FAISS index, docstore, and index_to_docstore_id from bytes."""
|
||||
index, docstore, index_to_docstore_id = pickle.loads(serialized)
|
||||
return cls(
|
||||
embeddings.embed_query, index, docstore, index_to_docstore_id, **kwargs
|
||||
)
|
||||
|
||||
def _select_relevance_score_fn(self) -> Callable[[float], float]:
|
||||
"""
|
||||
The 'correct' relevance function
|
||||
may differ depending on a few things, including:
|
||||
- the distance / similarity metric used by the VectorStore
|
||||
- the scale of your embeddings (OpenAI's are unit normed. Many others are not!)
|
||||
- embedding dimensionality
|
||||
- etc.
|
||||
"""
|
||||
if self.override_relevance_score_fn is not None:
|
||||
return self.override_relevance_score_fn
|
||||
|
||||
# Default strategy is to rely on distance strategy provided in
|
||||
# vectorstore constructor
|
||||
if self.distance_strategy == DistanceStrategy.MAX_INNER_PRODUCT:
|
||||
return self._max_inner_product_relevance_score_fn
|
||||
elif self.distance_strategy == DistanceStrategy.EUCLIDEAN_DISTANCE:
|
||||
# Default behavior is to use euclidean distance relevancy
|
||||
return self._euclidean_relevance_score_fn
|
||||
else:
|
||||
raise ValueError(
|
||||
"Unknown distance strategy, must be cosine, max_inner_product,"
|
||||
" or euclidean"
|
||||
)
|
||||
|
||||
def _similarity_search_with_relevance_scores(
|
||||
self,
|
||||
query: str,
|
||||
k: int = 4,
|
||||
filter: Optional[Dict[str, Any]] = None,
|
||||
fetch_k: int = 20,
|
||||
**kwargs: Any,
|
||||
) -> List[Tuple[Document, float]]:
|
||||
"""Return docs and their similarity scores on a scale from 0 to 1."""
|
||||
# Pop score threshold so that only relevancy scores, not raw scores, are
|
||||
# filtered.
|
||||
relevance_score_fn = self._select_relevance_score_fn()
|
||||
if relevance_score_fn is None:
|
||||
raise ValueError(
|
||||
"normalize_score_fn must be provided to"
|
||||
" FAISS constructor to normalize scores"
|
||||
)
|
||||
docs_and_scores = self.similarity_search_with_score(
|
||||
query,
|
||||
k=k,
|
||||
filter=filter,
|
||||
fetch_k=fetch_k,
|
||||
**kwargs,
|
||||
)
|
||||
docs_and_rel_scores = [
|
||||
(doc, relevance_score_fn(score)) for doc, score in docs_and_scores
|
||||
]
|
||||
return docs_and_rel_scores
|
|
@ -0,0 +1,6 @@
|
|||
from .openai_model import getChatModel
|
||||
|
||||
|
||||
__all__ = [
|
||||
"getChatModel"
|
||||
]
|
|
@ -0,0 +1,29 @@
|
|||
from langchain.callbacks import AsyncIteratorCallbackHandler
|
||||
from langchain.chat_models import ChatOpenAI
|
||||
|
||||
from configs.model_config import (llm_model_dict, LLM_MODEL)
|
||||
|
||||
|
||||
def getChatModel(callBack: AsyncIteratorCallbackHandler = None, temperature=0.3, stop=None):
|
||||
if callBack is None:
|
||||
model = ChatOpenAI(
|
||||
streaming=True,
|
||||
verbose=True,
|
||||
openai_api_key=llm_model_dict[LLM_MODEL]["api_key"],
|
||||
openai_api_base=llm_model_dict[LLM_MODEL]["api_base_url"],
|
||||
model_name=LLM_MODEL,
|
||||
temperature=temperature,
|
||||
stop=stop
|
||||
)
|
||||
else:
|
||||
model = ChatOpenAI(
|
||||
streaming=True,
|
||||
verbose=True,
|
||||
callBack=[callBack],
|
||||
openai_api_key=llm_model_dict[LLM_MODEL]["api_key"],
|
||||
openai_api_base=llm_model_dict[LLM_MODEL]["api_base_url"],
|
||||
model_name=LLM_MODEL,
|
||||
temperature=temperature,
|
||||
stop=stop
|
||||
)
|
||||
return model
|
|
@ -18,5 +18,6 @@ def check_tables_exist(table_name) -> bool:
|
|||
return table_exist
|
||||
|
||||
def table_init():
|
||||
if (not check_tables_exist("knowledge_base")) or (not check_tables_exist ("knowledge_file")):
|
||||
if (not check_tables_exist("knowledge_base")) or (not check_tables_exist ("knowledge_file")) or \
|
||||
(not check_tables_exist ("code_base")):
|
||||
create_tables()
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
from .document_file_cds import *
|
||||
from .document_base_cds import *
|
||||
from .code_base_cds import *
|
||||
|
||||
__all__ = [
|
||||
"add_kb_to_db", "list_kbs_from_db", "kb_exists",
|
||||
|
@ -7,4 +8,7 @@ __all__ = [
|
|||
|
||||
"list_docs_from_db", "add_doc_to_db", "delete_file_from_db",
|
||||
"delete_files_from_db", "doc_exists", "get_file_detail",
|
||||
|
||||
"list_cbs_from_db", "add_cb_to_db", "delete_cb_from_db",
|
||||
"cb_exists", "get_cb_detail",
|
||||
]
|
|
@ -0,0 +1,79 @@
|
|||
# encoding: utf-8
|
||||
'''
|
||||
@author: 温进
|
||||
@file: code_base_cds.py.py
|
||||
@time: 2023/10/23 下午4:34
|
||||
@desc:
|
||||
'''
|
||||
from loguru import logger
|
||||
from dev_opsgpt.orm.db import with_session, _engine
|
||||
from dev_opsgpt.orm.schemas.base_schema import CodeBaseSchema
|
||||
|
||||
|
||||
@with_session
|
||||
def add_cb_to_db(session, code_name, code_path, code_graph_node_num, code_file_num):
|
||||
# 增:创建知识库实例
|
||||
cb = session.query(CodeBaseSchema).filter_by(code_name=code_name).first()
|
||||
if not cb:
|
||||
cb = CodeBaseSchema(code_name=code_name, code_path=code_path, code_graph_node_num=code_graph_node_num,
|
||||
code_file_num=code_file_num)
|
||||
session.add(cb)
|
||||
else:
|
||||
cb.code_path = code_path
|
||||
cb.code_graph_node_num = code_graph_node_num
|
||||
return True
|
||||
|
||||
|
||||
@with_session
|
||||
def list_cbs_from_db(session):
|
||||
'''
|
||||
查:查询实例
|
||||
'''
|
||||
cbs = session.query(CodeBaseSchema.code_name).all()
|
||||
cbs = [cb[0] for cb in cbs]
|
||||
return cbs
|
||||
|
||||
|
||||
@with_session
|
||||
def cb_exists(session, code_name):
|
||||
'''
|
||||
判断是否存在
|
||||
'''
|
||||
cb = session.query(CodeBaseSchema).filter_by(code_name=code_name).first()
|
||||
status = True if cb else False
|
||||
return status
|
||||
|
||||
@with_session
|
||||
def load_cb_from_db(session, code_name):
|
||||
cb = session.query(CodeBaseSchema).filter_by(code_name=code_name).first()
|
||||
if cb:
|
||||
code_name, code_path, code_graph_node_num = cb.code_name, cb.code_path, cb.code_graph_node_num
|
||||
else:
|
||||
code_name, code_path, code_graph_node_num = None, None, None
|
||||
return code_name, code_path, code_graph_node_num
|
||||
|
||||
|
||||
@with_session
|
||||
def delete_cb_from_db(session, code_name):
|
||||
cb = session.query(CodeBaseSchema).filter_by(code_name=code_name).first()
|
||||
if cb:
|
||||
session.delete(cb)
|
||||
return True
|
||||
|
||||
|
||||
@with_session
|
||||
def get_cb_detail(session, code_name: str) -> dict:
|
||||
cb: CodeBaseSchema = session.query(CodeBaseSchema).filter_by(code_name=code_name).first()
|
||||
logger.info(cb)
|
||||
logger.info('code_name={}'.format(cb.code_name))
|
||||
if cb:
|
||||
return {
|
||||
"code_name": cb.code_name,
|
||||
"code_path": cb.code_path,
|
||||
"code_graph_node_num": cb.code_graph_node_num,
|
||||
'code_file_num': cb.code_file_num
|
||||
}
|
||||
else:
|
||||
return {
|
||||
}
|
||||
|
|
@ -46,3 +46,24 @@ class KnowledgeFileSchema(Base):
|
|||
text_splitter_name='{self.text_splitter_name}',
|
||||
file_version='{self.file_version}',
|
||||
create_time='{self.create_time}')>"""
|
||||
|
||||
|
||||
class CodeBaseSchema(Base):
|
||||
'''
|
||||
代码数据库模型
|
||||
'''
|
||||
__tablename__ = 'code_base'
|
||||
id = Column(Integer, primary_key=True, autoincrement=True, comment='代码库 ID')
|
||||
code_name = Column(String, comment='代码库名称')
|
||||
code_path = Column(String, comment='代码本地路径')
|
||||
code_graph_node_num = Column(String, comment='代码图谱节点数')
|
||||
code_file_num = Column(String, comment='代码解析文件数')
|
||||
create_time = Column(DateTime, default=func.now(), comment='创建时间')
|
||||
|
||||
def __repr__(self):
|
||||
return f"""<CodeBase(id='{self.id}',
|
||||
code_name='{self.code_name}',
|
||||
code_path='{self.code_path}',
|
||||
code_graph_node_num='{self.code_graph_node_num}',
|
||||
code_file_num='{self.code_file_num}'
|
||||
create_time='{self.create_time}')>"""
|
||||
|
|
|
@ -3,6 +3,7 @@ from typing import Optional
|
|||
from pathlib import Path
|
||||
import sys
|
||||
from abc import ABC, abstractclassmethod
|
||||
from loguru import logger
|
||||
|
||||
from configs.server_config import SANDBOX_SERVER
|
||||
|
||||
|
@ -22,21 +23,6 @@ class CodeBoxStatus(BaseModel):
|
|||
status: str
|
||||
|
||||
|
||||
class CodeBoxFile(BaseModel):
|
||||
"""
|
||||
Represents a file returned from a CodeBox instance.
|
||||
"""
|
||||
|
||||
name: str
|
||||
content: Optional[bytes] = None
|
||||
|
||||
def __str__(self):
|
||||
return self.name
|
||||
|
||||
def __repr__(self):
|
||||
return f"File({self.name})"
|
||||
|
||||
|
||||
class BaseBox(ABC):
|
||||
|
||||
enter_status = False
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
import time, os, docker, requests, json, uuid, subprocess, time, asyncio, aiohttp, re, traceback
|
||||
import psutil
|
||||
from typing import List, Optional, Union
|
||||
from pathlib import Path
|
||||
from loguru import logger
|
||||
|
||||
from websockets.sync.client import connect as ws_connect_sync
|
||||
|
@ -11,7 +10,8 @@ from websockets.client import WebSocketClientProtocol, ClientConnection
|
|||
from websockets.exceptions import ConnectionClosedError
|
||||
|
||||
from configs.server_config import SANDBOX_SERVER
|
||||
from .basebox import BaseBox, CodeBoxResponse, CodeBoxStatus, CodeBoxFile
|
||||
from configs.model_config import JUPYTER_WORK_PATH
|
||||
from .basebox import BaseBox, CodeBoxResponse, CodeBoxStatus
|
||||
|
||||
|
||||
class PyCodeBox(BaseBox):
|
||||
|
@ -25,12 +25,18 @@ class PyCodeBox(BaseBox):
|
|||
remote_port: str = SANDBOX_SERVER["port"],
|
||||
token: str = "mytoken",
|
||||
do_code_exe: bool = False,
|
||||
do_remote: bool = False
|
||||
do_remote: bool = False,
|
||||
do_check_net: bool = True,
|
||||
):
|
||||
super().__init__(remote_url, remote_ip, remote_port, token, do_code_exe, do_remote)
|
||||
self.enter_status = True
|
||||
self.do_check_net = do_check_net
|
||||
asyncio.run(self.astart())
|
||||
|
||||
# logger.info(f"""remote_url: {self.remote_url},
|
||||
# remote_ip: {self.remote_ip},
|
||||
# remote_port: {self.remote_port}""")
|
||||
|
||||
def decode_code_from_text(self, text: str) -> str:
|
||||
pattern = r'```.*?```'
|
||||
code_blocks = re.findall(pattern, text, re.DOTALL)
|
||||
|
@ -73,7 +79,8 @@ class PyCodeBox(BaseBox):
|
|||
if not self.ws:
|
||||
raise RuntimeError("Jupyter not running. Make sure to start it first")
|
||||
|
||||
logger.debug(f"code_text: {json.dumps(code_text, ensure_ascii=False)}")
|
||||
# logger.debug(f"code_text: {len(code_text)}, {code_text}")
|
||||
|
||||
self.ws.send(
|
||||
json.dumps(
|
||||
{
|
||||
|
@ -103,7 +110,7 @@ class PyCodeBox(BaseBox):
|
|||
raise RuntimeError("Mixing asyncio and sync code is not supported")
|
||||
received_msg = json.loads(self.ws.recv())
|
||||
except ConnectionClosedError:
|
||||
logger.debug("box start, ConnectionClosedError!!!")
|
||||
# logger.debug("box start, ConnectionClosedError!!!")
|
||||
self.start()
|
||||
return self.run(code_text, file_path, retry - 1)
|
||||
|
||||
|
@ -156,7 +163,7 @@ class PyCodeBox(BaseBox):
|
|||
return CodeBoxResponse(
|
||||
code_exe_type="text",
|
||||
code_text=code_text,
|
||||
code_exe_response=result or "Code run successfully (no output)",
|
||||
code_exe_response=result or "Code run successfully (no output),可能没有打印需要确认的变量",
|
||||
code_exe_status=200,
|
||||
do_code_exe=self.do_code_exe
|
||||
)
|
||||
|
@ -219,7 +226,6 @@ class PyCodeBox(BaseBox):
|
|||
async def _acheck_connect(self, ) -> bool:
|
||||
if self.kernel_url == "":
|
||||
return False
|
||||
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(f"{self.kernel_url}?token={self.token}", timeout=270) as resp:
|
||||
|
@ -231,7 +237,7 @@ class PyCodeBox(BaseBox):
|
|||
|
||||
def _check_port(self, ) -> bool:
|
||||
try:
|
||||
response = requests.get(f"http://localhost:{self.remote_port}", timeout=270)
|
||||
response = requests.get(f"{self.remote_ip}:{self.remote_port}", timeout=270)
|
||||
logger.warning(f"Port is conflict, please check your codebox's port {self.remote_port}")
|
||||
return response.status_code == 200
|
||||
except requests.exceptions.ConnectionError:
|
||||
|
@ -240,7 +246,7 @@ class PyCodeBox(BaseBox):
|
|||
async def _acheck_port(self, ) -> bool:
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(f"http://localhost:{self.remote_port}", timeout=270) as resp:
|
||||
async with session.get(f"{self.remote_ip}:{self.remote_port}", timeout=270) as resp:
|
||||
logger.warning(f"Port is conflict, please check your codebox's port {self.remote_port}")
|
||||
return resp.status == 200
|
||||
except aiohttp.ClientConnectorError:
|
||||
|
@ -249,6 +255,8 @@ class PyCodeBox(BaseBox):
|
|||
pass
|
||||
|
||||
def _check_connect_success(self, retry_nums: int = 5) -> bool:
|
||||
if not self.do_check_net: return True
|
||||
|
||||
while retry_nums > 0:
|
||||
try:
|
||||
connect_status = self._check_connect()
|
||||
|
@ -262,6 +270,7 @@ class PyCodeBox(BaseBox):
|
|||
raise BaseException(f"can't connect to {self.remote_url}")
|
||||
|
||||
async def _acheck_connect_success(self, retry_nums: int = 5) -> bool:
|
||||
if not self.do_check_net: return True
|
||||
while retry_nums > 0:
|
||||
try:
|
||||
connect_status = await self._acheck_connect()
|
||||
|
@ -283,7 +292,7 @@ class PyCodeBox(BaseBox):
|
|||
self._check_connect_success()
|
||||
|
||||
self._get_kernelid()
|
||||
logger.debug(self.kernel_url.replace("http", "ws") + f"/{self.kernel_id}/channels?token={self.token}")
|
||||
# logger.debug(self.kernel_url.replace("http", "ws") + f"/{self.kernel_id}/channels?token={self.token}")
|
||||
self.wc_url = self.kernel_url.replace("http", "ws") + f"/{self.kernel_id}/channels?token={self.token}"
|
||||
headers = {"Authorization": f'Token {self.token}', 'token': self.token}
|
||||
self.ws = create_connection(self.wc_url, headers=headers)
|
||||
|
@ -291,27 +300,30 @@ class PyCodeBox(BaseBox):
|
|||
# TODO 自动检测本地接口
|
||||
port_status = self._check_port()
|
||||
connect_status = self._check_connect()
|
||||
logger.debug(f"port_status: {port_status}, connect_status: {connect_status}")
|
||||
logger.info(f"port_status: {port_status}, connect_status: {connect_status}")
|
||||
if port_status and not connect_status:
|
||||
raise BaseException(f"Port is conflict, please check your codebox's port {self.remote_port}")
|
||||
|
||||
if not connect_status:
|
||||
self.jupyter = subprocess.Popen(
|
||||
self.jupyter = subprocess.run(
|
||||
[
|
||||
"jupyer", "notebnook",
|
||||
f"--NotebookApp.token={self.token}",
|
||||
f"--port={self.remote_port}",
|
||||
"--no-browser",
|
||||
"--ServerApp.disable_check_xsrf=True",
|
||||
"--notebook-dir={JUPYTER_WORK_PATH}"
|
||||
],
|
||||
stderr=subprocess.PIPE,
|
||||
stdin=subprocess.PIPE,
|
||||
stdout=subprocess.PIPE,
|
||||
)
|
||||
|
||||
self.kernel_url = self.remote_url + "/api/kernels"
|
||||
self.do_check_net = True
|
||||
self._check_connect_success()
|
||||
self._get_kernelid()
|
||||
logger.debug(self.kernel_url.replace("http", "ws") + f"/{self.kernel_id}/channels?token={self.token}")
|
||||
# logger.debug(self.kernel_url.replace("http", "ws") + f"/{self.kernel_id}/channels?token={self.token}")
|
||||
self.wc_url = self.kernel_url.replace("http", "ws") + f"/{self.kernel_id}/channels?token={self.token}"
|
||||
headers = {"Authorization": f'Token {self.token}', 'token': self.token}
|
||||
self.ws = create_connection(self.wc_url, headers=headers)
|
||||
|
@ -333,10 +345,10 @@ class PyCodeBox(BaseBox):
|
|||
port_status = await self._acheck_port()
|
||||
self.kernel_url = self.remote_url + "/api/kernels"
|
||||
connect_status = await self._acheck_connect()
|
||||
logger.debug(f"port_status: {port_status}, connect_status: {connect_status}")
|
||||
logger.info(f"port_status: {port_status}, connect_status: {connect_status}")
|
||||
if port_status and not connect_status:
|
||||
raise BaseException(f"Port is conflict, please check your codebox's port {self.remote_port}")
|
||||
|
||||
|
||||
if not connect_status:
|
||||
self.jupyter = subprocess.Popen(
|
||||
[
|
||||
|
@ -344,13 +356,15 @@ class PyCodeBox(BaseBox):
|
|||
f"--NotebookApp.token={self.token}",
|
||||
f"--port={self.remote_port}",
|
||||
"--no-browser",
|
||||
"--ServerApp.disable_check_xsrf=True",
|
||||
"--ServerApp.disable_check_xsrf=True"
|
||||
],
|
||||
stderr=subprocess.PIPE,
|
||||
stdin=subprocess.PIPE,
|
||||
stdout=subprocess.PIPE,
|
||||
)
|
||||
|
||||
self.kernel_url = self.remote_url + "/api/kernels"
|
||||
self.do_check_net = True
|
||||
await self._acheck_connect_success()
|
||||
await self._aget_kernelid()
|
||||
self.wc_url = self.kernel_url.replace("http", "ws") + f"/{self.kernel_id}/channels?token={self.token}"
|
||||
|
@ -405,7 +419,8 @@ class PyCodeBox(BaseBox):
|
|||
except Exception as e:
|
||||
logger.error(traceback.format_exc())
|
||||
self.ws = None
|
||||
return CodeBoxStatus(status="stopped")
|
||||
|
||||
# return CodeBoxStatus(status="stopped")
|
||||
|
||||
def __del__(self):
|
||||
self.stop()
|
||||
|
|
|
@ -18,15 +18,19 @@ from configs.server_config import OPEN_CROSS_DOMAIN
|
|||
|
||||
from dev_opsgpt.chat import LLMChat, SearchChat, KnowledgeChat
|
||||
from dev_opsgpt.service.kb_api import *
|
||||
from dev_opsgpt.service.cb_api import *
|
||||
from dev_opsgpt.utils.server_utils import BaseResponse, ListResponse, FastAPI, MakeFastAPIOffline
|
||||
|
||||
nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
|
||||
|
||||
from dev_opsgpt.chat import LLMChat, SearchChat, KnowledgeChat
|
||||
from dev_opsgpt.chat import LLMChat, SearchChat, KnowledgeChat, ToolChat, DataChat, CodeChat
|
||||
|
||||
llmChat = LLMChat()
|
||||
searchChat = SearchChat()
|
||||
knowledgeChat = KnowledgeChat()
|
||||
toolChat = ToolChat()
|
||||
dataChat = DataChat()
|
||||
codeChat = CodeChat()
|
||||
|
||||
|
||||
async def document():
|
||||
|
@ -71,6 +75,18 @@ def create_app():
|
|||
app.post("/chat/search_engine_chat",
|
||||
tags=["Chat"],
|
||||
summary="与搜索引擎对话")(searchChat.chat)
|
||||
app.post("/chat/tool_chat",
|
||||
tags=["Chat"],
|
||||
summary="与搜索引擎对话")(toolChat.chat)
|
||||
|
||||
app.post("/chat/data_chat",
|
||||
tags=["Chat"],
|
||||
summary="与搜索引擎对话")(dataChat.chat)
|
||||
|
||||
app.post("/chat/code_chat",
|
||||
tags=["Chat"],
|
||||
summary="与代码库对话")(codeChat.chat)
|
||||
|
||||
|
||||
# Tag: Knowledge Base Management
|
||||
app.get("/knowledge_base/list_knowledge_bases",
|
||||
|
@ -129,6 +145,27 @@ def create_app():
|
|||
summary="根据content中文档重建向量库,流式输出处理进度。"
|
||||
)(recreate_vector_store)
|
||||
|
||||
app.post("/code_base/create_code_base",
|
||||
tags=["Code Base Management"],
|
||||
summary="新建 code_base"
|
||||
)(create_cb)
|
||||
|
||||
app.post("/code_base/delete_code_base",
|
||||
tags=["Code Base Management"],
|
||||
summary="删除 code_base"
|
||||
)(delete_cb)
|
||||
|
||||
app.post("/code_base/code_base_chat",
|
||||
tags=["Code Base Management"],
|
||||
summary="删除 code_base"
|
||||
)(delete_cb)
|
||||
|
||||
app.get("/code_base/list_code_bases",
|
||||
tags=["Code Base Management"],
|
||||
summary="列举 code_base",
|
||||
response_model=ListResponse
|
||||
)(list_cbs)
|
||||
|
||||
# # LLM模型相关接口
|
||||
# app.post("/llm_model/list_models",
|
||||
# tags=["LLM Model Management"],
|
||||
|
|
|
@ -0,0 +1,128 @@
|
|||
# encoding: utf-8
|
||||
'''
|
||||
@author: 温进
|
||||
@file: cb_api.py
|
||||
@time: 2023/10/23 下午7:08
|
||||
@desc:
|
||||
'''
|
||||
|
||||
import urllib, os, json, traceback
|
||||
from typing import List, Dict
|
||||
import shutil
|
||||
|
||||
from fastapi.responses import StreamingResponse, FileResponse
|
||||
from fastapi import File, Form, Body, Query, UploadFile
|
||||
from langchain.docstore.document import Document
|
||||
|
||||
from .service_factory import KBServiceFactory
|
||||
from dev_opsgpt.utils.server_utils import BaseResponse, ListResponse
|
||||
from dev_opsgpt.utils.path_utils import *
|
||||
from dev_opsgpt.orm.commands import *
|
||||
|
||||
from configs.model_config import (
|
||||
CB_ROOT_PATH
|
||||
)
|
||||
|
||||
from dev_opsgpt.codebase_handler.codebase_handler import CodeBaseHandler
|
||||
|
||||
from loguru import logger
|
||||
|
||||
|
||||
async def list_cbs():
|
||||
# Get List of Knowledge Base
|
||||
return ListResponse(data=list_cbs_from_db())
|
||||
|
||||
|
||||
async def create_cb(cb_name: str = Body(..., examples=["samples"]),
|
||||
code_path: str = Body(..., examples=["samples"])
|
||||
) -> BaseResponse:
|
||||
logger.info('cb_name={}, zip_path={}'.format(cb_name, code_path))
|
||||
|
||||
# Create selected knowledge base
|
||||
if not validate_kb_name(cb_name):
|
||||
return BaseResponse(code=403, msg="Don't attack me")
|
||||
if cb_name is None or cb_name.strip() == "":
|
||||
return BaseResponse(code=404, msg="知识库名称不能为空,请重新填写知识库名称")
|
||||
|
||||
cb = cb_exists(cb_name)
|
||||
if cb:
|
||||
return BaseResponse(code=404, msg=f"已存在同名代码知识库 {cb_name}")
|
||||
|
||||
try:
|
||||
logger.info('start build code base')
|
||||
cbh = CodeBaseHandler(cb_name, code_path, cb_root_path=CB_ROOT_PATH)
|
||||
cbh.import_code(do_save=True)
|
||||
code_graph_node_num = len(cbh.nh)
|
||||
code_file_num = len(cbh.lcdh)
|
||||
logger.info('build code base done')
|
||||
|
||||
# create cb to table
|
||||
add_cb_to_db(cb_name, cbh.code_path, code_graph_node_num, code_file_num)
|
||||
logger.info('add cb to mysql table success')
|
||||
except Exception as e:
|
||||
print(e)
|
||||
return BaseResponse(code=500, msg=f"创建代码知识库出错: {e}")
|
||||
|
||||
return BaseResponse(code=200, msg=f"已新增代码知识库 {cb_name}")
|
||||
|
||||
|
||||
async def delete_cb(cb_name: str = Body(..., examples=["samples"])) -> BaseResponse:
|
||||
logger.info('cb_name={}'.format(cb_name))
|
||||
# Create selected knowledge base
|
||||
if not validate_kb_name(cb_name):
|
||||
return BaseResponse(code=403, msg="Don't attack me")
|
||||
if cb_name is None or cb_name.strip() == "":
|
||||
return BaseResponse(code=404, msg="知识库名称不能为空,请重新填写知识库名称")
|
||||
|
||||
cb = cb_exists(cb_name)
|
||||
if cb:
|
||||
try:
|
||||
delete_cb_from_db(cb_name)
|
||||
|
||||
# delete local file
|
||||
shutil.rmtree(CB_ROOT_PATH + os.sep + cb_name)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
return BaseResponse(code=500, msg=f"删除代码知识库出错: {e}")
|
||||
|
||||
return BaseResponse(code=200, msg=f"已删除代码知识库 {cb_name}")
|
||||
|
||||
|
||||
def search_code(cb_name: str = Body(..., examples=["sofaboot"]),
|
||||
query: str = Body(..., examples=['你好']),
|
||||
code_limit: int = Body(..., examples=['1']),
|
||||
history_node_list: list = Body(...)) -> dict:
|
||||
|
||||
logger.info('cb_name={}'.format(cb_name))
|
||||
logger.info('query={}'.format(query))
|
||||
logger.info('code_limit={}'.format(code_limit))
|
||||
logger.info('history_node_list={}'.format(history_node_list))
|
||||
|
||||
try:
|
||||
# load codebase
|
||||
cbh = CodeBaseHandler(code_name=cb_name, cb_root_path=CB_ROOT_PATH)
|
||||
cbh.import_code(do_load=True)
|
||||
|
||||
# search code
|
||||
related_code, related_node = cbh.search_code(query, code_limit=code_limit, history_node_list=history_node_list)
|
||||
|
||||
res = {
|
||||
'related_code': related_code,
|
||||
'related_node': related_node
|
||||
}
|
||||
return res
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
return {}
|
||||
|
||||
|
||||
def cb_exists_api(cb_name: str = Body(..., examples=["sofaboot"])) -> bool:
|
||||
try:
|
||||
res = cb_exists(cb_name)
|
||||
return res
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
return False
|
||||
|
||||
|
||||
|
|
@ -4,7 +4,7 @@ from typing import List
|
|||
from functools import lru_cache
|
||||
from loguru import logger
|
||||
|
||||
from langchain.vectorstores import FAISS
|
||||
# from langchain.vectorstores import FAISS
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
|
||||
|
@ -22,6 +22,7 @@ from dev_opsgpt.utils.path_utils import *
|
|||
from dev_opsgpt.orm.utils import DocumentFile
|
||||
from dev_opsgpt.utils.server_utils import torch_gc
|
||||
from dev_opsgpt.embeddings.utils import load_embeddings
|
||||
from dev_opsgpt.embeddings.faiss_m import FAISS
|
||||
|
||||
|
||||
# make HuggingFaceEmbeddings hashable
|
||||
|
@ -124,6 +125,7 @@ class FaissKBService(KBService):
|
|||
vector_store = load_vector_store(self.kb_name,
|
||||
embeddings=embeddings,
|
||||
tick=_VECTOR_STORE_TICKS.get(self.kb_name, 0))
|
||||
vector_store.embedding_function = embeddings.embed_documents
|
||||
logger.info("docs.lens: {}".format(len(docs)))
|
||||
vector_store.add_documents(docs)
|
||||
torch_gc()
|
||||
|
|
|
@ -16,7 +16,6 @@ from configs.model_config import (
|
|||
)
|
||||
|
||||
|
||||
|
||||
async def list_kbs():
|
||||
# Get List of Knowledge Base
|
||||
return ListResponse(data=list_kbs_from_db())
|
||||
|
|
|
@ -6,7 +6,6 @@ import os
|
|||
src_dir = os.path.join(
|
||||
os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
)
|
||||
print(src_dir)
|
||||
sys.path.append(src_dir)
|
||||
|
||||
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
|
||||
|
@ -20,7 +19,7 @@ model_worker_port = 20002
|
|||
openai_api_port = 8888
|
||||
base_url = "http://127.0.0.1:{}"
|
||||
|
||||
os.environ['PATH'] = os.environ.get("PATH", "") + os.pathsep + r'/d/env_utils/miniconda3/envs/devopsgpt/Lib/site-packages/torch/lib'
|
||||
os.environ['PATH'] = os.environ.get("PATH", "") + os.pathsep
|
||||
|
||||
def set_httpx_timeout(timeout=60.0):
|
||||
import httpx
|
||||
|
|
|
@ -0,0 +1,129 @@
|
|||
import sys, os, json, traceback, uvicorn, argparse
|
||||
|
||||
src_dir = os.path.join(
|
||||
os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
)
|
||||
sys.path.append(src_dir)
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import StreamingResponse, FileResponse
|
||||
from fastapi import File, UploadFile
|
||||
|
||||
from dev_opsgpt.utils.server_utils import BaseResponse, ListResponse
|
||||
from configs.server_config import OPEN_CROSS_DOMAIN, SDFILE_API_SERVER
|
||||
from configs.model_config import (
|
||||
JUPYTER_WORK_PATH
|
||||
)
|
||||
from configs import VERSION
|
||||
|
||||
|
||||
|
||||
async def sd_upload_file(file: UploadFile = File(...), work_dir: str = JUPYTER_WORK_PATH):
|
||||
# 保存上传的文件到服务器
|
||||
try:
|
||||
content = await file.read()
|
||||
with open(os.path.join(work_dir, file.filename), "wb") as f:
|
||||
f.write(content)
|
||||
return {"data": True}
|
||||
except:
|
||||
return {"data": False}
|
||||
|
||||
|
||||
async def sd_download_file(filename: str, save_filename: str = "filename_to_download.ext", work_dir: str = JUPYTER_WORK_PATH):
|
||||
# 从服务器下载文件
|
||||
logger.debug(f"{os.path.join(work_dir, filename)}")
|
||||
return {"data": FileResponse(os.path.join(work_dir, filename), filename=save_filename)}
|
||||
|
||||
|
||||
async def sd_list_files(work_dir: str = JUPYTER_WORK_PATH):
|
||||
# 去除目录
|
||||
return {"data": os.listdir(work_dir)}
|
||||
|
||||
|
||||
async def sd_delete_file(filename: str, work_dir: str = JUPYTER_WORK_PATH):
|
||||
# 去除目录
|
||||
try:
|
||||
os.remove(os.path.join(work_dir, filename))
|
||||
return {"data": True}
|
||||
except:
|
||||
return {"data": False}
|
||||
|
||||
|
||||
def create_app():
|
||||
app = FastAPI(
|
||||
title="DevOps-ChatBot API Server",
|
||||
version=VERSION
|
||||
)
|
||||
# MakeFastAPIOffline(app)
|
||||
# Add CORS middleware to allow all origins
|
||||
# 在config.py中设置OPEN_DOMAIN=True,允许跨域
|
||||
# set OPEN_DOMAIN=True in config.py to allow cross-domain
|
||||
if OPEN_CROSS_DOMAIN:
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
app.post("/sdfiles/upload",
|
||||
tags=["files upload and download"],
|
||||
response_model=BaseResponse,
|
||||
summary="上传文件到沙盒"
|
||||
)(sd_upload_file)
|
||||
|
||||
app.get("/sdfiles/download",
|
||||
tags=["files upload and download"],
|
||||
response_model=BaseResponse,
|
||||
summary="从沙盒下载文件"
|
||||
)(sd_download_file)
|
||||
|
||||
app.get("/sdfiles/list",
|
||||
tags=["files upload and download"],
|
||||
response_model=ListResponse,
|
||||
summary="从沙盒工作目录展示文件"
|
||||
)(sd_list_files)
|
||||
|
||||
app.get("/sdfiles/delete",
|
||||
tags=["files upload and download"],
|
||||
response_model=BaseResponse,
|
||||
summary="从沙盒工作目录中删除文件"
|
||||
)(sd_delete_file)
|
||||
return app
|
||||
|
||||
|
||||
|
||||
app = create_app()
|
||||
|
||||
def run_api(host, port, **kwargs):
|
||||
if kwargs.get("ssl_keyfile") and kwargs.get("ssl_certfile"):
|
||||
uvicorn.run(app,
|
||||
host=host,
|
||||
port=port,
|
||||
ssl_keyfile=kwargs.get("ssl_keyfile"),
|
||||
ssl_certfile=kwargs.get("ssl_certfile"),
|
||||
)
|
||||
else:
|
||||
uvicorn.run(app, host=host, port=port)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(prog='DevOps-ChatBot',
|
||||
description='About DevOps-ChatBot, local knowledge based LLM with langchain'
|
||||
' | 基于本地知识库的 LLM 问答')
|
||||
parser.add_argument("--host", type=str, default="0.0.0.0")
|
||||
parser.add_argument("--port", type=int, default=SDFILE_API_SERVER["port"])
|
||||
parser.add_argument("--ssl_keyfile", type=str)
|
||||
parser.add_argument("--ssl_certfile", type=str)
|
||||
# 初始化消息
|
||||
args = parser.parse_args()
|
||||
args_dict = vars(args)
|
||||
run_api(host=args.host,
|
||||
port=args.port,
|
||||
ssl_keyfile=args.ssl_keyfile,
|
||||
ssl_certfile=args.ssl_certfile,
|
||||
)
|
|
@ -77,6 +77,33 @@ def get_kb_details() -> List[Dict]:
|
|||
|
||||
return data
|
||||
|
||||
def get_cb_details() -> List[Dict]:
|
||||
'''
|
||||
get codebase details
|
||||
@return: list of data
|
||||
'''
|
||||
res = {}
|
||||
cbs_in_db = list_cbs_from_db()
|
||||
for cb in cbs_in_db:
|
||||
cb_detail = get_cb_detail(cb)
|
||||
res[cb] = cb_detail
|
||||
|
||||
data = []
|
||||
for i, v in enumerate(res.values()):
|
||||
v['No'] = i + 1
|
||||
data.append(v)
|
||||
return data
|
||||
|
||||
def get_cb_details_by_cb_name(cb_name) -> dict:
|
||||
'''
|
||||
get codebase details by cb_name
|
||||
@return: list of data
|
||||
'''
|
||||
cb_detail = get_cb_detail(cb_name)
|
||||
return cb_detail
|
||||
|
||||
|
||||
|
||||
|
||||
def get_kb_doc_details(kb_name: str) -> List[Dict]:
|
||||
kb = KBServiceFactory.get_service_by_name(kb_name)
|
||||
|
|
|
@ -32,10 +32,12 @@ class LCTextSplitter:
|
|||
loader = self._load_document()
|
||||
text_splitter = self._load_text_splitter()
|
||||
if self.document_loader_name in ["JSONLoader", "JSONLLoader"]:
|
||||
docs = loader.load()
|
||||
# docs = loader.load()
|
||||
docs = loader.load_and_split(text_splitter)
|
||||
logger.debug(f"please check your file can be loaded, docs.lens {len(docs)}")
|
||||
else:
|
||||
docs = loader.load_and_split(text_splitter)
|
||||
logger.info(docs[0])
|
||||
|
||||
return docs
|
||||
|
||||
def _load_document(self, ) -> BaseLoader:
|
||||
|
@ -55,8 +57,8 @@ class LCTextSplitter:
|
|||
chunk_overlap=OVERLAP_SIZE,
|
||||
)
|
||||
self.text_splitter_name = "SpacyTextSplitter"
|
||||
elif self.document_loader_name in ["JSONLoader", "JSONLLoader"]:
|
||||
text_splitter = None
|
||||
# elif self.document_loader_name in ["JSONLoader", "JSONLLoader"]:
|
||||
# text_splitter = None
|
||||
else:
|
||||
text_splitter_module = importlib.import_module('langchain.text_splitter')
|
||||
TextSplitter = getattr(text_splitter_module, self.text_splitter_name)
|
||||
|
|
|
@ -0,0 +1,33 @@
|
|||
from .base_tool import toLangchainTools, get_tool_schema, BaseToolModel
|
||||
from .weather import WeatherInfo, DistrictInfo
|
||||
from .multiplier import Multiplier
|
||||
from .world_time import WorldTimeGetTimezoneByArea
|
||||
from .abnormal_detection import KSigmaDetector
|
||||
from .metrics_query import MetricsQuery
|
||||
from .duckduckgo_search import DDGSTool
|
||||
from .docs_retrieval import DocRetrieval
|
||||
from .cb_query_tool import CodeRetrieval
|
||||
|
||||
TOOL_SETS = [
|
||||
"WeatherInfo", "WorldTimeGetTimezoneByArea", "Multiplier", "DistrictInfo", "KSigmaDetector", "MetricsQuery", "DDGSTool",
|
||||
"DocRetrieval", "CodeRetrieval"
|
||||
]
|
||||
|
||||
TOOL_DICT = {
|
||||
"WeatherInfo": WeatherInfo,
|
||||
"WorldTimeGetTimezoneByArea": WorldTimeGetTimezoneByArea,
|
||||
"Multiplier": Multiplier,
|
||||
"DistrictInfo": DistrictInfo,
|
||||
"KSigmaDetector": KSigmaDetector,
|
||||
"MetricsQuery": MetricsQuery,
|
||||
"DDGSTool": DDGSTool,
|
||||
"DocRetrieval": DocRetrieval,
|
||||
"CodeRetrieval": CodeRetrieval
|
||||
}
|
||||
|
||||
__all__ = [
|
||||
"WeatherInfo", "WorldTimeGetTimezoneByArea", "Multiplier", "DistrictInfo", "KSigmaDetector", "MetricsQuery", "DDGSTool",
|
||||
"DocRetrieval", "CodeRetrieval",
|
||||
"toLangchainTools", "get_tool_schema", "tool_sets", "BaseToolModel"
|
||||
]
|
||||
|
|
@ -0,0 +1,45 @@
|
|||
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import List, Dict
|
||||
import requests
|
||||
import numpy as np
|
||||
from loguru import logger
|
||||
|
||||
from .base_tool import BaseToolModel
|
||||
|
||||
|
||||
|
||||
class KSigmaDetector(BaseToolModel):
|
||||
"""
|
||||
Tips:
|
||||
default control Required, e.g. key1 is not Required/key2 is Required
|
||||
"""
|
||||
|
||||
name: str = "KSigmaDetector"
|
||||
description: str = "Anomaly detection using K-Sigma method"
|
||||
|
||||
class ToolInputArgs(BaseModel):
|
||||
"""Input for KSigmaDetector."""
|
||||
|
||||
data: List[float] = Field(..., description="List of data points")
|
||||
detect_window: int = Field(default=5, description="The size of the detect window for detecting anomalies")
|
||||
abnormal_window: int = Field(default=3, description="The threshold for the number of abnormal points required to classify the data as abnormal")
|
||||
k: float = Field(default=3.0, description="the coef of k-sigma")
|
||||
|
||||
class ToolOutputArgs(BaseModel):
|
||||
"""Output for KSigmaDetector."""
|
||||
|
||||
is_abnormal: bool = Field(..., description="Indicates whether the input data is abnormal or not")
|
||||
|
||||
@staticmethod
|
||||
def run(data, detect_window=5, abnormal_window=3, k=3.0):
|
||||
refer_data = np.array(data[-detect_window:])
|
||||
detect_data = np.array(data[:-detect_window])
|
||||
mean = np.mean(refer_data)
|
||||
std = np.std(refer_data)
|
||||
|
||||
is_abnormal = np.sum(np.abs(detect_data - mean) > k * std) >= abnormal_window
|
||||
return {"is_abnormal": is_abnormal}
|
|
@ -0,0 +1,79 @@
|
|||
from langchain.agents import Tool
|
||||
from langchain.tools import StructuredTool
|
||||
from langchain.tools.base import ToolException
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import List, Dict
|
||||
# import jsonref
|
||||
import json
|
||||
|
||||
|
||||
class BaseToolModel:
|
||||
name = "BaseToolModel"
|
||||
description = "Tool Description"
|
||||
|
||||
class ToolInputArgs(BaseModel):
|
||||
"""
|
||||
Input for MoveFileTool.
|
||||
Tips:
|
||||
default control Required, e.g. key1 is not Required/key2 is Required
|
||||
"""
|
||||
|
||||
key1: str = Field(default=None, description="hello world!")
|
||||
key2: str = Field(..., description="hello world!!")
|
||||
|
||||
class ToolOutputArgs(BaseModel):
|
||||
"""
|
||||
Input for MoveFileTool.
|
||||
Tips:
|
||||
default control Required, e.g. key1 is not Required/key2 is Required
|
||||
"""
|
||||
|
||||
key1: str = Field(default=None, description="hello world!")
|
||||
key2: str = Field(..., description="hello world!!")
|
||||
|
||||
@classmethod
|
||||
def run(cls, tool_input_args: ToolInputArgs) -> ToolOutputArgs:
|
||||
"""excute your tool!"""
|
||||
pass
|
||||
|
||||
|
||||
class BaseTools:
|
||||
tools: List[BaseToolModel]
|
||||
|
||||
|
||||
def get_tool_schema(tool: BaseToolModel) -> Dict:
|
||||
'''转json schema结构'''
|
||||
data = jsonref.loads(tool.schema_json())
|
||||
_ = json.dumps(data, indent=4)
|
||||
del data["definitions"]
|
||||
return data
|
||||
|
||||
|
||||
def _handle_error(error: ToolException) -> str:
|
||||
return (
|
||||
"The following errors occurred during tool execution:"
|
||||
+ error.args[0]
|
||||
+ "Please try again."
|
||||
)
|
||||
|
||||
import requests
|
||||
from loguru import logger
|
||||
def fff(city, extensions):
|
||||
url = "https://restapi.amap.com/v3/weather/weatherInfo"
|
||||
json_data = {"key": "4ceb2ef6257a627b72e3be6beab5b059", "city": city, "extensions": extensions}
|
||||
logger.debug(f"json_data: {json_data}")
|
||||
res = requests.get(url, params={"key": "4ceb2ef6257a627b72e3be6beab5b059", "city": city, "extensions": extensions})
|
||||
return res.json()
|
||||
|
||||
|
||||
def toLangchainTools(tools: BaseTools) -> List:
|
||||
''''''
|
||||
return [
|
||||
StructuredTool(
|
||||
name=tool.name,
|
||||
func=tool.run,
|
||||
description=tool.description,
|
||||
args_schema=tool.ToolInputArgs,
|
||||
handle_tool_error=_handle_error,
|
||||
) for tool in tools
|
||||
]
|
|
@ -0,0 +1,47 @@
|
|||
# encoding: utf-8
|
||||
'''
|
||||
@author: 温进
|
||||
@file: cb_query_tool.py
|
||||
@time: 2023/11/2 下午4:41
|
||||
@desc:
|
||||
'''
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import List, Dict
|
||||
import requests
|
||||
import numpy as np
|
||||
from loguru import logger
|
||||
|
||||
from configs.model_config import (
|
||||
CODE_SEARCH_TOP_K)
|
||||
from .base_tool import BaseToolModel
|
||||
|
||||
from dev_opsgpt.service.cb_api import search_code
|
||||
|
||||
|
||||
class CodeRetrieval(BaseToolModel):
|
||||
name = "CodeRetrieval"
|
||||
description = "采用知识图谱从本地代码知识库获取相关代码"
|
||||
|
||||
class ToolInputArgs(BaseModel):
|
||||
query: str = Field(..., description="检索的关键字或问题")
|
||||
code_base_name: str = Field(..., description="知识库名称", examples=["samples"])
|
||||
code_limit: int = Field(CODE_SEARCH_TOP_K, description="检索返回的数量")
|
||||
|
||||
class ToolOutputArgs(BaseModel):
|
||||
"""Output for MetricsQuery."""
|
||||
code: str = Field(..., description="检索代码")
|
||||
|
||||
@classmethod
|
||||
def run(cls, code_base_name, query, code_limit=CODE_SEARCH_TOP_K, history_node_list=[]):
|
||||
"""excute your tool!"""
|
||||
codes = search_code(code_base_name, query, code_limit, history_node_list=history_node_list)
|
||||
return_codes = []
|
||||
related_code = codes['related_code']
|
||||
related_nodes = codes['related_node']
|
||||
|
||||
for idx, code in enumerate(related_code):
|
||||
return_codes.append({'index': idx, 'code': code, "related_nodes": related_nodes})
|
||||
return return_codes
|
|
@ -0,0 +1,42 @@
|
|||
import json
|
||||
import os
|
||||
import re
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import List, Dict
|
||||
import requests
|
||||
import numpy as np
|
||||
from loguru import logger
|
||||
|
||||
from configs.model_config import (
|
||||
VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD)
|
||||
from .base_tool import BaseToolModel
|
||||
|
||||
|
||||
|
||||
from dev_opsgpt.service.kb_api import search_docs
|
||||
|
||||
|
||||
class DocRetrieval(BaseToolModel):
|
||||
name = "DocRetrieval"
|
||||
description = "采用向量化对本地知识库进行检索"
|
||||
|
||||
class ToolInputArgs(BaseModel):
|
||||
query: str = Field(..., description="检索的关键字或问题")
|
||||
knowledge_base_name: str = Field(..., description="知识库名称", examples=["samples"])
|
||||
search_top: int = Field(VECTOR_SEARCH_TOP_K, description="检索返回的数量")
|
||||
score_threshold: float = Field(SCORE_THRESHOLD, description="知识库匹配相关度阈值,取值范围在0-1之间,SCORE越小,相关度越高,取到1相当于不筛选,建议设置在0.5左右", ge=0, le=1)
|
||||
|
||||
class ToolOutputArgs(BaseModel):
|
||||
"""Output for MetricsQuery."""
|
||||
title: str = Field(..., description="检索网页标题")
|
||||
snippet: str = Field(..., description="检索内容的判断")
|
||||
link: str = Field(..., description="检索网页地址")
|
||||
|
||||
@classmethod
|
||||
def run(cls, query, knowledge_base_name, search_top=VECTOR_SEARCH_TOP_K, score_threshold=SCORE_THRESHOLD):
|
||||
"""excute your tool!"""
|
||||
docs = search_docs(query, knowledge_base_name, search_top, score_threshold)
|
||||
return_docs = []
|
||||
for idx, doc in enumerate(docs):
|
||||
return_docs.append({"index": idx, "snippet": doc.page_content, "title": doc.metadata.get("source"), "link": doc.metadata.get("source")})
|
||||
return return_docs
|
|
@ -0,0 +1,72 @@
|
|||
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import List, Dict
|
||||
import requests
|
||||
import numpy as np
|
||||
from loguru import logger
|
||||
|
||||
from .base_tool import BaseToolModel
|
||||
from configs.model_config import (
|
||||
PROMPT_TEMPLATE, SEARCH_ENGINE_TOP_K, BING_SUBSCRIPTION_KEY, BING_SEARCH_URL,
|
||||
VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD)
|
||||
|
||||
from duckduckgo_search import DDGS
|
||||
|
||||
|
||||
class DDGSTool(BaseToolModel):
|
||||
name = "DDGSTool"
|
||||
description = "通过duckduckgo进行资料搜索"
|
||||
|
||||
class ToolInputArgs(BaseModel):
|
||||
query: str = Field(..., description="检索的关键字或问题")
|
||||
search_top: int = Field(..., description="检索返回的数量")
|
||||
region: str = Field("wt-wt", enum=["wt-wt", "us-en", "uk-en", "ru-ru"], description="搜索的区域")
|
||||
safesearch: str = Field("moderate", enum=["on", "moderate", "off"], description="")
|
||||
timelimit: str = Field(None, enum=[None, "d", "w", "m", "y"], description="查询时间方式")
|
||||
backend: str = Field("api", description="搜索的资料来源")
|
||||
|
||||
class ToolOutputArgs(BaseModel):
|
||||
"""Output for MetricsQuery."""
|
||||
title: str = Field(..., description="检索网页标题")
|
||||
snippet: str = Field(..., description="检索内容的判断")
|
||||
link: str = Field(..., description="检索网页地址")
|
||||
|
||||
@classmethod
|
||||
def run(cls, query, search_top, region="wt-wt", safesearch="moderate", timelimit=None, backend="api"):
|
||||
"""excute your tool!"""
|
||||
with DDGS(proxies=os.environ.get("DUCKDUCKGO_PROXY")) as ddgs:
|
||||
results = ddgs.text(
|
||||
query,
|
||||
region=region,
|
||||
safesearch=safesearch,
|
||||
timelimit=timelimit,
|
||||
backend=backend,
|
||||
)
|
||||
if results is None:
|
||||
return [{"Result": "No good DuckDuckGo Search Result was found"}]
|
||||
|
||||
def to_metadata(result: Dict) -> Dict[str, str]:
|
||||
if backend == "news":
|
||||
return {
|
||||
"date": result["date"],
|
||||
"title": result["title"],
|
||||
"snippet": result["body"],
|
||||
"source": result["source"],
|
||||
"link": result["url"],
|
||||
}
|
||||
return {
|
||||
"snippet": result["body"],
|
||||
"title": result["title"],
|
||||
"link": result["href"],
|
||||
}
|
||||
|
||||
formatted_results = []
|
||||
for i, res in enumerate(results, 1):
|
||||
if res is not None:
|
||||
formatted_results.append(to_metadata(res))
|
||||
if len(formatted_results) == search_top:
|
||||
break
|
||||
return formatted_results
|
|
@ -0,0 +1,33 @@
|
|||
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import List, Dict
|
||||
import requests
|
||||
import numpy as np
|
||||
from loguru import logger
|
||||
|
||||
from .base_tool import BaseToolModel
|
||||
|
||||
|
||||
|
||||
class MetricsQuery(BaseToolModel):
|
||||
name = "MetricsQuery"
|
||||
description = "查询机器的监控数据"
|
||||
|
||||
class ToolInputArgs(BaseModel):
|
||||
machine_ip: str = Field(..., description="machine_ip")
|
||||
time: int = Field(..., description="time period")
|
||||
|
||||
class ToolOutputArgs(BaseModel):
|
||||
"""Output for MetricsQuery."""
|
||||
|
||||
datas: List[float] = Field(..., description="监控时序数组")
|
||||
|
||||
def run(machine_ip, time):
|
||||
"""excute your tool!"""
|
||||
data = [0.857, 2.345, 1.234, 4.567, 3.456, 9.876, 5.678, 7.890, 6.789, 8.901, 10.987, 12.345, 11.234, 14.567, 13.456, 19.876, 15.678, 17.890,
|
||||
16.789, 18.901, 20.987, 22.345, 21.234, 24.567, 23.456, 29.876, 25.678, 27.890, 26.789, 28.901, 30.987, 32.345, 31.234, 34.567,
|
||||
33.456, 39.876, 35.678, 37.890, 36.789, 38.901, 40.987]
|
||||
return data[:30]
|
|
@ -0,0 +1,38 @@
|
|||
from pydantic import BaseModel, Field
|
||||
from typing import List, Dict
|
||||
import requests
|
||||
from loguru import logger
|
||||
|
||||
from .base_tool import BaseToolModel
|
||||
|
||||
|
||||
|
||||
class Multiplier(BaseToolModel):
|
||||
"""
|
||||
Tips:
|
||||
default control Required, e.g. key1 is not Required/key2 is Required
|
||||
"""
|
||||
|
||||
name: str = "Multiplier"
|
||||
description: str = """useful for when you need to multiply two numbers together. \
|
||||
The input to this tool should be a comma separated list of numbers of length two, representing the two numbers you want to multiply together. \
|
||||
For example, `1,2` would be the input if you wanted to multiply 1 by 2."""
|
||||
|
||||
class ToolInputArgs(BaseModel):
|
||||
"""Input for Multiplier."""
|
||||
|
||||
# key: str = Field(..., description="用户在高德地图官网申请web服务API类型KEY")
|
||||
a: int = Field(..., description="num a")
|
||||
b: int = Field(..., description="num b")
|
||||
|
||||
class ToolOutputArgs(BaseModel):
|
||||
"""Output for Multiplier."""
|
||||
|
||||
res: int = Field(..., description="the result of two nums")
|
||||
|
||||
@staticmethod
|
||||
def run(a, b):
|
||||
return a * b
|
||||
|
||||
def multi_run(a, b):
|
||||
return a * b
|
|
@ -0,0 +1,109 @@
|
|||
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import List, Dict
|
||||
import requests
|
||||
from loguru import logger
|
||||
|
||||
from .base_tool import BaseToolModel
|
||||
|
||||
|
||||
|
||||
class WeatherInfo(BaseToolModel):
|
||||
"""
|
||||
Tips:
|
||||
default control Required, e.g. key1 is not Required/key2 is Required
|
||||
"""
|
||||
|
||||
name: str = "WeatherInfo"
|
||||
description: str = "According to the user's input adcode, it can query the current/future weather conditions of the target area."
|
||||
|
||||
class ToolInputArgs(BaseModel):
|
||||
"""Input for Weather."""
|
||||
|
||||
# key: str = Field(..., description="用户在高德地图官网申请web服务API类型KEY")
|
||||
city: str = Field(..., description="城市编码,输入城市的adcode,adcode信息可参考城市编码表")
|
||||
extensions: str = Field(default=None, enum=["base", "all"], description="气象类型,输入城市的adcode,adcode信息可参考城市编码表")
|
||||
|
||||
class ToolOutputArgs(BaseModel):
|
||||
"""Output for Weather."""
|
||||
|
||||
lives: str = Field(default=None, description="实况天气数据")
|
||||
|
||||
# @classmethod
|
||||
# def run(cls, tool_input_args: ToolInputArgs) -> ToolOutputArgs:
|
||||
# """excute your tool!"""
|
||||
# url = "https://restapi.amap.com/v3/weather/weatherInfo"
|
||||
# try:
|
||||
# json_data = tool_input_args.dict()
|
||||
# json_data["key"] = "4ceb2ef6257a627b72e3be6beab5b059"
|
||||
# res = requests.get(url, json_data)
|
||||
# return res.json()
|
||||
# except Exception as e:
|
||||
# return e
|
||||
|
||||
@staticmethod
|
||||
def run(city, extensions) -> ToolOutputArgs:
|
||||
"""excute your tool!"""
|
||||
url = "https://restapi.amap.com/v3/weather/weatherInfo"
|
||||
try:
|
||||
json_data = {}
|
||||
json_data["city"] = city
|
||||
json_data["key"] = "4ceb2ef6257a627b72e3be6beab5b059"
|
||||
json_data["extensions"] = extensions
|
||||
logger.debug(f"json_data: {json_data}")
|
||||
res = requests.get(url, params=json_data)
|
||||
return res.json()
|
||||
except Exception as e:
|
||||
return e
|
||||
|
||||
|
||||
class DistrictInfo(BaseToolModel):
|
||||
"""
|
||||
Tips:
|
||||
default control Required, e.g. key1 is not Required/key2 is Required
|
||||
"""
|
||||
|
||||
name: str = "DistrictInfo"
|
||||
description: str = "用户希望通过得到行政区域信息,进行开发工作。"
|
||||
|
||||
class ToolInputArgs(BaseModel):
|
||||
"""Input for district."""
|
||||
keywords: str = Field(default=None, description="规则:只支持单个关键词语搜索关键词支持:行政区名称、citycode、adcode例如,在subdistrict=2,搜索省份(例如山东),能够显示市(例如济南),区(例如历下区)")
|
||||
subdistrict: str = Field(default=None, enums=[1,2,3], description="""规则:设置显示下级行政区级数(行政区级别包括:国家、省/直辖市、市、区/县、乡镇/街道多级数据)
|
||||
|
||||
可选值:0、1、2、3等数字,并以此类推
|
||||
|
||||
0:不返回下级行政区;
|
||||
|
||||
1:返回下一级行政区;
|
||||
|
||||
2:返回下两级行政区;
|
||||
|
||||
3:返回下三级行政区;""")
|
||||
page: int = Field(default=1, examples=["page=2", "page=3"], description="最外层的districts最多会返回20个数据,若超过限制,请用page请求下一页数据。")
|
||||
extensions: str = Field(default=None, enum=["base", "all"], description="气象类型,输入城市的adcode,adcode信息可参考城市编码表")
|
||||
|
||||
class ToolOutputArgs(BaseModel):
|
||||
"""Output for district."""
|
||||
|
||||
districts: str = Field(default=None, description="行政区列表")
|
||||
|
||||
@staticmethod
|
||||
def run(keywords=None, subdistrict=None, page=1, extensions=None) -> ToolOutputArgs:
|
||||
"""excute your tool!"""
|
||||
url = "https://restapi.amap.com/v3/config/district"
|
||||
try:
|
||||
json_data = {}
|
||||
json_data["keywords"] = keywords
|
||||
json_data["key"] = "4ceb2ef6257a627b72e3be6beab5b059"
|
||||
json_data["subdistrict"] = subdistrict
|
||||
json_data["page"] = page
|
||||
json_data["extensions"] = extensions
|
||||
logger.debug(f"json_data: {json_data}")
|
||||
res = requests.get(url, params=json_data)
|
||||
return res.json()
|
||||
except Exception as e:
|
||||
return e
|
|
@ -0,0 +1,255 @@
|
|||
import json
|
||||
import os
|
||||
import re
|
||||
import requests
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import List
|
||||
|
||||
from .base_tool import BaseToolModel
|
||||
|
||||
|
||||
class WorldTimeGetTimezoneByArea(BaseToolModel):
|
||||
"""
|
||||
World Time API
|
||||
Tips:
|
||||
default control Required, e.g. key1 is not Required/key2 is Required
|
||||
"""
|
||||
|
||||
name = "WorldTime.getTimezoneByArea"
|
||||
description = "a listing of all timezones available for that area."
|
||||
|
||||
class ToolInputArgs(BaseModel):
|
||||
"""Input for WorldTimeGetTimezoneByArea."""
|
||||
area: str = Field(..., description="area")
|
||||
|
||||
class ToolOutputArgs(BaseModel):
|
||||
"""Output for WorldTimeGetTimezoneByArea."""
|
||||
DateTimeJsonResponse: str = Field(..., description="a list of available timezones")
|
||||
|
||||
@classmethod
|
||||
def run(area: str) -> ToolOutputArgs:
|
||||
"""excute your tool!"""
|
||||
url = "http://worldtimeapi.org/api/timezone"
|
||||
try:
|
||||
res = requests.get(url, json={"area": area})
|
||||
return res.text
|
||||
except Exception as e:
|
||||
return e
|
||||
|
||||
|
||||
def worldtime_run(area):
|
||||
url = "http://worldtimeapi.org/api/timezone"
|
||||
res = requests.get(url, json={"area": area})
|
||||
return res.text
|
||||
|
||||
# class WorldTime(BaseTool):
|
||||
# api_spec: str = '''
|
||||
# description: >-
|
||||
# A simple API to get the current time based on
|
||||
# a request with a timezone.
|
||||
|
||||
# servers:
|
||||
# - url: http://worldtimeapi.org/api/
|
||||
|
||||
# paths:
|
||||
# /timezone:
|
||||
# get:
|
||||
# description: a listing of all timezones.
|
||||
# operationId: getTimezone
|
||||
# responses:
|
||||
# default:
|
||||
# $ref: "#/components/responses/SuccessfulListJsonResponse"
|
||||
|
||||
# /timezone/{area}:
|
||||
# get:
|
||||
# description: a listing of all timezones available for that area.
|
||||
# operationId: getTimezoneByArea
|
||||
# parameters:
|
||||
# - name: area
|
||||
# in: path
|
||||
# required: true
|
||||
# schema:
|
||||
# type: string
|
||||
# responses:
|
||||
# '200':
|
||||
# $ref: "#/components/responses/SuccessfulListJsonResponse"
|
||||
# default:
|
||||
# $ref: "#/components/responses/ErrorJsonResponse"
|
||||
|
||||
# /timezone/{area}/{location}:
|
||||
# get:
|
||||
# description: request the current time for a timezone.
|
||||
# operationId: getTimeByTimezone
|
||||
# parameters:
|
||||
# - name: area
|
||||
# in: path
|
||||
# required: true
|
||||
# schema:
|
||||
# type: string
|
||||
# - name: location
|
||||
# in: path
|
||||
# required: true
|
||||
# schema:
|
||||
# type: string
|
||||
# responses:
|
||||
# '200':
|
||||
# $ref: "#/components/responses/SuccessfulDateTimeJsonResponse"
|
||||
# default:
|
||||
# $ref: "#/components/responses/ErrorJsonResponse"
|
||||
|
||||
# /ip:
|
||||
# get:
|
||||
# description: >-
|
||||
# request the current time based on the ip of the request.
|
||||
# note: this is a "best guess" obtained from open-source data.
|
||||
# operationId: getTimeByIP
|
||||
# responses:
|
||||
# '200':
|
||||
# $ref: "#/components/responses/SuccessfulDateTimeJsonResponse"
|
||||
# default:
|
||||
# $ref: "#/components/responses/ErrorJsonResponse"
|
||||
|
||||
# components:
|
||||
# responses:
|
||||
# SuccessfulListJsonResponse:
|
||||
# description: >-
|
||||
# the list of available timezones in JSON format
|
||||
# content:
|
||||
# application/json:
|
||||
# schema:
|
||||
# $ref: "#/components/schemas/ListJsonResponse"
|
||||
|
||||
# SuccessfulDateTimeJsonResponse:
|
||||
# description: >-
|
||||
# the current time for the timezone requested in JSON format
|
||||
# content:
|
||||
# application/json:
|
||||
# schema:
|
||||
# $ref: "#/components/schemas/DateTimeJsonResponse"
|
||||
|
||||
# ErrorJsonResponse:
|
||||
# description: >-
|
||||
# an error response in JSON format
|
||||
# content:
|
||||
# application/json:
|
||||
# schema:
|
||||
# $ref: "#/components/schemas/ErrorJsonResponse"
|
||||
|
||||
# schemas:
|
||||
# ListJsonResponse:
|
||||
# type: array
|
||||
# description: >-
|
||||
# a list of available timezones
|
||||
# items:
|
||||
# type: string
|
||||
|
||||
# DateTimeJsonResponse:
|
||||
# required:
|
||||
# - abbreviation
|
||||
# - client_ip
|
||||
# - datetime
|
||||
# - day_of_week
|
||||
# - day_of_year
|
||||
# - dst
|
||||
# - dst_offset
|
||||
# - timezone
|
||||
# - unixtime
|
||||
# - utc_datetime
|
||||
# - utc_offset
|
||||
# - week_number
|
||||
# properties:
|
||||
# abbreviation:
|
||||
# type: string
|
||||
# description: >-
|
||||
# the abbreviated name of the timezone
|
||||
# client_ip:
|
||||
# type: string
|
||||
# description: >-
|
||||
# the IP of the client making the request
|
||||
# datetime:
|
||||
# type: string
|
||||
# description: >-
|
||||
# an ISO8601-valid string representing
|
||||
# the current, local date/time
|
||||
# day_of_week:
|
||||
# type: integer
|
||||
# description: >-
|
||||
# current day number of the week, where sunday is 0
|
||||
# day_of_year:
|
||||
# type: integer
|
||||
# description: >-
|
||||
# ordinal date of the current year
|
||||
# dst:
|
||||
# type: boolean
|
||||
# description: >-
|
||||
# flag indicating whether the local
|
||||
# time is in daylight savings
|
||||
# dst_from:
|
||||
# type: string
|
||||
# description: >-
|
||||
# an ISO8601-valid string representing
|
||||
# the datetime when daylight savings
|
||||
# started for this timezone
|
||||
# dst_offset:
|
||||
# type: integer
|
||||
# description: >-
|
||||
# the difference in seconds between the current local
|
||||
# time and daylight saving time for the location
|
||||
# dst_until:
|
||||
# type: string
|
||||
# description: >-
|
||||
# an ISO8601-valid string representing
|
||||
# the datetime when daylight savings
|
||||
# will end for this timezone
|
||||
# raw_offset:
|
||||
# type: integer
|
||||
# description: >-
|
||||
# the difference in seconds between the current local time
|
||||
# and the time in UTC, excluding any daylight saving difference
|
||||
# (see dst_offset)
|
||||
# timezone:
|
||||
# type: string
|
||||
# description: >-
|
||||
# timezone in `Area/Location` or
|
||||
# `Area/Location/Region` format
|
||||
# unixtime:
|
||||
# type: integer
|
||||
# description: >-
|
||||
# number of seconds since the Epoch
|
||||
# utc_datetime:
|
||||
# type: string
|
||||
# description: >-
|
||||
# an ISO8601-valid string representing
|
||||
# the current date/time in UTC
|
||||
# utc_offset:
|
||||
# type: string
|
||||
# description: >-
|
||||
# an ISO8601-valid string representing
|
||||
# the offset from UTC
|
||||
# week_number:
|
||||
# type: integer
|
||||
# description: >-
|
||||
# the current week number
|
||||
|
||||
# ErrorJsonResponse:
|
||||
# required:
|
||||
# - error
|
||||
# properties:
|
||||
# error:
|
||||
# type: string
|
||||
# description: >-
|
||||
# details about the error encountered
|
||||
# '''
|
||||
|
||||
# def exec_tool(self, message: UserMessage) -> UserMessage:
|
||||
# match = re.search(r'{[\s\S]*}', message.content)
|
||||
# if match:
|
||||
# params = json.loads(match.group())
|
||||
# url = params["url"]
|
||||
# if "params" in params:
|
||||
# url = url.format(**params["params"])
|
||||
# res = requests.get(url)
|
||||
# response_msg = UserMessage(content=f"API response: {res.text}")
|
||||
# else:
|
||||
# raise "ERROR"
|
||||
# return response_msg
|
|
@ -2,6 +2,11 @@ import textwrap, time, copy, random, hashlib, json, os
|
|||
from datetime import datetime, timedelta
|
||||
from functools import wraps
|
||||
from loguru import logger
|
||||
from typing import *
|
||||
from pathlib import Path
|
||||
from io import BytesIO
|
||||
from fastapi import Body, File, Form, Body, Query, UploadFile
|
||||
from tempfile import SpooledTemporaryFile
|
||||
|
||||
|
||||
|
||||
|
@ -65,3 +70,23 @@ def save_to_json_file(data, filename):
|
|||
|
||||
with open(filename, "w", encoding="utf-8") as f:
|
||||
json.dump(data, f, indent=2, ensure_ascii=False)
|
||||
|
||||
|
||||
def file_normalize(file: Union[str, Path, bytes], filename=None):
|
||||
logger.debug(f"{file}")
|
||||
if isinstance(file, bytes): # raw bytes
|
||||
file = BytesIO(file)
|
||||
elif hasattr(file, "read"): # a file io like object
|
||||
filename = filename or file.name
|
||||
else: # a local path
|
||||
file = Path(file).absolute().open("rb")
|
||||
logger.debug(file)
|
||||
filename = filename or file.name
|
||||
return file, filename
|
||||
|
||||
|
||||
def get_uploadfile(file: Union[str, Path, bytes], filename=None) -> UploadFile:
|
||||
temp_file = SpooledTemporaryFile(max_size=10 * 1024 * 1024)
|
||||
temp_file.write(file.read())
|
||||
temp_file.seek(0)
|
||||
return UploadFile(file=temp_file, filename=filename)
|
|
@ -29,7 +29,7 @@ LOADER2EXT_DICT = {"UnstructuredFileLoader": ['.eml', '.html', '.md', '.msg', '.
|
|||
"TextLoader": ['.txt'],
|
||||
"PythonLoader": ['.py'],
|
||||
"JSONLoader": ['.json'],
|
||||
"JSONLLoader": ['.jsonl'],
|
||||
"JSONLLoader": ['.jsonl']
|
||||
}
|
||||
|
||||
EXT2LOADER_DICT = {ext: LOADERNAME2LOADER_DICT[k] for k, exts in LOADER2EXT_DICT.items() for ext in exts}
|
||||
|
@ -61,8 +61,10 @@ def list_kbs_from_folder():
|
|||
|
||||
def list_docs_from_folder(kb_name: str):
|
||||
doc_path = get_doc_path(kb_name)
|
||||
return [file for file in os.listdir(doc_path)
|
||||
if os.path.isfile(os.path.join(doc_path, file))]
|
||||
if os.path.exists(doc_path):
|
||||
return [file for file in os.listdir(doc_path)
|
||||
if os.path.isfile(os.path.join(doc_path, file))]
|
||||
return []
|
||||
|
||||
def get_LoaderClass(file_extension):
|
||||
for LoaderClass, extensions in LOADER2EXT_DICT.items():
|
||||
|
|
|
@ -1,9 +1,10 @@
|
|||
from .dialogue import dialogue_page, chat_box
|
||||
from .document import knowledge_page
|
||||
from .code import code_page
|
||||
from .prompt import prompt_page
|
||||
from .utils import ApiRequest
|
||||
|
||||
__all__ = [
|
||||
"dialogue_page", "chat_box", "prompt_page", "knowledge_page",
|
||||
"ApiRequest"
|
||||
"ApiRequest", "code_page"
|
||||
]
|
|
@ -0,0 +1,140 @@
|
|||
# encoding: utf-8
|
||||
'''
|
||||
@author: 温进
|
||||
@file: code.py.py
|
||||
@time: 2023/10/23 下午5:31
|
||||
@desc:
|
||||
'''
|
||||
|
||||
import streamlit as st
|
||||
import os
|
||||
import time
|
||||
import traceback
|
||||
from typing import Literal, Dict, Tuple
|
||||
from st_aggrid import AgGrid, JsCode
|
||||
from st_aggrid.grid_options_builder import GridOptionsBuilder
|
||||
import pandas as pd
|
||||
|
||||
from configs.model_config import embedding_model_dict, kbs_config, EMBEDDING_MODEL, DEFAULT_VS_TYPE, WEB_CRAWL_PATH
|
||||
from .utils import *
|
||||
from dev_opsgpt.utils.path_utils import *
|
||||
from dev_opsgpt.service.service_factory import get_cb_details, get_cb_details_by_cb_name
|
||||
from dev_opsgpt.orm import table_init
|
||||
|
||||
# SENTENCE_SIZE = 100
|
||||
|
||||
cell_renderer = JsCode("""function(params) {if(params.value==true){return '✓'}else{return '×'}}""")
|
||||
|
||||
|
||||
def file_exists(cb: str, selected_rows: List) -> Tuple[str, str]:
|
||||
'''
|
||||
check whether the dir exist in local file
|
||||
return the dir's name and path if it exists.
|
||||
'''
|
||||
if selected_rows:
|
||||
file_name = selected_rows[0]["code_name"]
|
||||
file_path = get_file_path(cb, file_name)
|
||||
if os.path.isfile(file_path):
|
||||
return file_name, file_path
|
||||
return "", ""
|
||||
|
||||
|
||||
def code_page(api: ApiRequest):
|
||||
# 判断表是否存在并进行初始化
|
||||
table_init()
|
||||
|
||||
try:
|
||||
logger.info(get_cb_details())
|
||||
cb_list = {x["code_name"]: x for x in get_cb_details()}
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
st.error("获取知识库信息错误,请检查是否已按照 `README.md` 中 `4 知识库初始化与迁移` 步骤完成初始化或迁移,或是否为数据库连接错误。")
|
||||
st.stop()
|
||||
cb_names = list(cb_list.keys())
|
||||
|
||||
if "selected_cb_name" in st.session_state and st.session_state["selected_cb_name"] in cb_names:
|
||||
selected_cb_index = cb_names.index(st.session_state["selected_cb_name"])
|
||||
else:
|
||||
selected_cb_index = 0
|
||||
|
||||
def format_selected_cb(cb_name: str) -> str:
|
||||
if cb := cb_list.get(cb_name):
|
||||
return f"{cb_name} ({cb['code_path']})"
|
||||
else:
|
||||
return cb_name
|
||||
|
||||
selected_cb = st.selectbox(
|
||||
"请选择或新建代码知识库:",
|
||||
cb_names + ["新建代码知识库"],
|
||||
format_func=format_selected_cb,
|
||||
index=selected_cb_index
|
||||
)
|
||||
|
||||
if selected_cb == "新建代码知识库":
|
||||
with st.form("新建代码知识库"):
|
||||
|
||||
cb_name = st.text_input(
|
||||
"新建代码知识库名称",
|
||||
placeholder="新代码知识库名称,不支持中文命名",
|
||||
key="cb_name",
|
||||
)
|
||||
|
||||
file = st.file_uploader("上传代码库 zip 文件",
|
||||
['.zip'],
|
||||
accept_multiple_files=False,
|
||||
)
|
||||
|
||||
submit_create_kb = st.form_submit_button(
|
||||
"新建",
|
||||
use_container_width=True,
|
||||
)
|
||||
|
||||
if submit_create_kb:
|
||||
# unzip file
|
||||
logger.info('files={}'.format(file))
|
||||
|
||||
if not cb_name or not cb_name.strip():
|
||||
st.error(f"知识库名称不能为空!")
|
||||
elif cb_name in cb_list:
|
||||
st.error(f"名为 {cb_name} 的知识库已经存在!")
|
||||
elif file.type not in ['application/zip', 'application/x-zip-compressed']:
|
||||
logger.error(f"{file.type}")
|
||||
st.error('请先上传 zip 文件,再新建代码知识库')
|
||||
else:
|
||||
ret = api.create_code_base(
|
||||
cb_name,
|
||||
file,
|
||||
no_remote_api=True
|
||||
)
|
||||
st.toast(ret.get("msg", " "))
|
||||
st.session_state["selected_cb_name"] = cb_name
|
||||
st.experimental_rerun()
|
||||
elif selected_cb:
|
||||
cb = selected_cb
|
||||
|
||||
# 知识库详情
|
||||
cb_details = get_cb_details_by_cb_name(cb)
|
||||
if not len(cb_details):
|
||||
st.info(f"代码知识库 `{cb}` 中暂无信息")
|
||||
else:
|
||||
logger.info(cb_details)
|
||||
st.write(f"代码知识库 `{cb}` 加载成功,中含有以下信息:")
|
||||
|
||||
st.write('代码知识库 `{}` 代码文件数=`{}`'.format(cb_details['code_name'],
|
||||
cb_details.get('code_file_num', 'unknown')))
|
||||
|
||||
st.write('代码知识库 `{}` 知识图谱节点数=`{}`'.format(cb_details['code_name'], cb_details['code_graph_node_num']))
|
||||
|
||||
st.divider()
|
||||
|
||||
cols = st.columns(3)
|
||||
|
||||
if cols[2].button(
|
||||
"删除知识库",
|
||||
use_container_width=True,
|
||||
):
|
||||
ret = api.delete_code_base(cb,
|
||||
no_remote_api=True)
|
||||
st.toast(ret.get("msg", " "))
|
||||
time.sleep(1)
|
||||
st.experimental_rerun()
|
|
@ -2,10 +2,13 @@ import streamlit as st
|
|||
from streamlit_chatbox import *
|
||||
from typing import List, Dict
|
||||
from datetime import datetime
|
||||
|
||||
from random import randint
|
||||
from .utils import *
|
||||
|
||||
from dev_opsgpt.utils import *
|
||||
from dev_opsgpt.tools import TOOL_SETS
|
||||
from dev_opsgpt.chat.search_chat import SEARCH_ENGINES
|
||||
from dev_opsgpt.connector import PHASE_LIST, PHASE_CONFIGS
|
||||
|
||||
chat_box = ChatBox(
|
||||
assistant_avatar="../sources/imgs/devops-chatbot2.png"
|
||||
|
@ -55,7 +58,11 @@ def dialogue_page(api: ApiRequest):
|
|||
dialogue_mode = st.selectbox("请选择对话模式",
|
||||
["LLM 对话",
|
||||
"知识库问答",
|
||||
"代码知识库问答",
|
||||
"搜索引擎问答",
|
||||
"工具问答",
|
||||
"数据分析",
|
||||
"Agent问答"
|
||||
],
|
||||
on_change=on_mode_change,
|
||||
key="dialogue_mode",
|
||||
|
@ -67,6 +74,10 @@ def dialogue_page(api: ApiRequest):
|
|||
def on_kb_change():
|
||||
st.toast(f"已加载知识库: {st.session_state.selected_kb}")
|
||||
|
||||
def on_cb_change():
|
||||
st.toast(f"已加载代码知识库: {st.session_state.selected_cb}")
|
||||
|
||||
not_agent_qa = True
|
||||
if dialogue_mode == "知识库问答":
|
||||
with st.expander("知识库配置", True):
|
||||
kb_list = api.list_knowledge_bases(no_remote_api=True)
|
||||
|
@ -80,13 +91,142 @@ def dialogue_page(api: ApiRequest):
|
|||
score_threshold = st.number_input("知识匹配分数阈值:", 0.0, float(SCORE_THRESHOLD), float(SCORE_THRESHOLD), float(SCORE_THRESHOLD//100))
|
||||
# chunk_content = st.checkbox("关联上下文", False, disabled=True)
|
||||
# chunk_size = st.slider("关联长度:", 0, 500, 250, disabled=True)
|
||||
elif dialogue_mode == '代码知识库问答':
|
||||
with st.expander('代码知识库配置', True):
|
||||
cb_list = api.list_cb(no_remote_api=True)
|
||||
logger.debug('codebase_list={}'.format(cb_list))
|
||||
selected_cb = st.selectbox(
|
||||
"请选择代码知识库:",
|
||||
cb_list,
|
||||
on_change=on_cb_change,
|
||||
key="selected_cb",
|
||||
)
|
||||
st.toast(f"已加载代码知识库: {st.session_state.selected_cb}")
|
||||
cb_code_limit = st.number_input("匹配代码条数:", 1, 20, 1)
|
||||
elif dialogue_mode == "搜索引擎问答":
|
||||
with st.expander("搜索引擎配置", True):
|
||||
search_engine = st.selectbox("请选择搜索引擎", SEARCH_ENGINES.keys(), 0)
|
||||
se_top_k = st.number_input("匹配搜索结果条数:", 1, 20, 3)
|
||||
elif dialogue_mode == "工具问答":
|
||||
with st.expander("工具军火库", True):
|
||||
tool_selects = st.multiselect(
|
||||
'请选择待使用的工具', TOOL_SETS, ["WeatherInfo"])
|
||||
|
||||
elif dialogue_mode == "数据分析":
|
||||
with st.expander("沙盒文件管理", False):
|
||||
def _upload(upload_file):
|
||||
res = api.web_sd_upload(upload_file)
|
||||
logger.debug(res)
|
||||
if res["msg"]:
|
||||
st.success("上文件传成功")
|
||||
else:
|
||||
st.toast("文件上传失败")
|
||||
|
||||
code_interpreter_on = st.toggle("开启代码解释器")
|
||||
code_exec_on = st.toggle("自动执行代码")
|
||||
interpreter_file = st.file_uploader(
|
||||
"上传沙盒文件",
|
||||
[i for ls in LOADER2EXT_DICT.values() for i in ls],
|
||||
accept_multiple_files=False,
|
||||
key="interpreter_file",
|
||||
)
|
||||
|
||||
if interpreter_file:
|
||||
_upload(interpreter_file)
|
||||
interpreter_file = None
|
||||
#
|
||||
files = api.web_sd_list_files()
|
||||
files = files["data"]
|
||||
download_file = st.selectbox("选择要处理文件", files,
|
||||
key="download_file",)
|
||||
|
||||
cols = st.columns(2)
|
||||
file_url, file_name = api.web_sd_download(download_file)
|
||||
cols[0].download_button("点击下载", file_url, file_name)
|
||||
if cols[1].button("点击删除", ):
|
||||
api.web_sd_delete(download_file)
|
||||
|
||||
elif dialogue_mode == "Agent问答":
|
||||
not_agent_qa = False
|
||||
with st.expander("Phase管理", True):
|
||||
choose_phase = st.selectbox(
|
||||
'请选择待使用的执行链路', PHASE_LIST, 0)
|
||||
|
||||
is_detailed = st.toggle("返回明细的Agent交互", False)
|
||||
tool_using_on = st.toggle("开启工具使用", PHASE_CONFIGS[choose_phase]["do_using_tool"])
|
||||
tool_selects = []
|
||||
if tool_using_on:
|
||||
with st.expander("工具军火库", True):
|
||||
tool_selects = st.multiselect(
|
||||
'请选择待使用的工具', TOOL_SETS, ["WeatherInfo"])
|
||||
|
||||
search_on = st.toggle("开启搜索增强", PHASE_CONFIGS[choose_phase]["do_search"])
|
||||
search_engine, top_k = None, 3
|
||||
if search_on:
|
||||
with st.expander("搜索引擎配置", True):
|
||||
search_engine = st.selectbox("请选择搜索引擎", SEARCH_ENGINES.keys(), 0)
|
||||
top_k = st.number_input("匹配搜索结果条数:", 1, 20, 3)
|
||||
|
||||
doc_retrieval_on = st.toggle("开启知识库检索增强", PHASE_CONFIGS[choose_phase]["do_doc_retrieval"])
|
||||
selected_kb, top_k, score_threshold = None, 3, 1.0
|
||||
if doc_retrieval_on:
|
||||
with st.expander("知识库配置", True):
|
||||
kb_list = api.list_knowledge_bases(no_remote_api=True)
|
||||
selected_kb = st.selectbox(
|
||||
"请选择知识库:",
|
||||
kb_list,
|
||||
on_change=on_kb_change,
|
||||
key="selected_kb",
|
||||
)
|
||||
top_k = st.number_input("匹配知识条数:", 1, 20, 3)
|
||||
score_threshold = st.number_input("知识匹配分数阈值:", 0.0, float(SCORE_THRESHOLD), float(SCORE_THRESHOLD), float(SCORE_THRESHOLD//100))
|
||||
|
||||
code_retrieval_on = st.toggle("开启代码检索增强", PHASE_CONFIGS[choose_phase]["do_code_retrieval"])
|
||||
selected_cb, top_k = None, 1
|
||||
if code_retrieval_on:
|
||||
with st.expander('代码知识库配置', True):
|
||||
cb_list = api.list_cb(no_remote_api=True)
|
||||
logger.debug('codebase_list={}'.format(cb_list))
|
||||
selected_cb = st.selectbox(
|
||||
"请选择代码知识库:",
|
||||
cb_list,
|
||||
on_change=on_cb_change,
|
||||
key="selected_cb",
|
||||
)
|
||||
st.toast(f"已加载代码知识库: {st.session_state.selected_cb}")
|
||||
top_k = st.number_input("匹配代码条数:", 1, 20, 1)
|
||||
|
||||
with st.expander("沙盒文件管理", False):
|
||||
def _upload(upload_file):
|
||||
res = api.web_sd_upload(upload_file)
|
||||
logger.debug(res)
|
||||
if res["msg"]:
|
||||
st.success("上文件传成功")
|
||||
else:
|
||||
st.toast("文件上传失败")
|
||||
|
||||
interpreter_file = st.file_uploader(
|
||||
"上传沙盒文件",
|
||||
[i for ls in LOADER2EXT_DICT.values() for i in ls],
|
||||
accept_multiple_files=False,
|
||||
key="interpreter_file",
|
||||
)
|
||||
|
||||
if interpreter_file:
|
||||
_upload(interpreter_file)
|
||||
interpreter_file = None
|
||||
#
|
||||
files = api.web_sd_list_files()
|
||||
files = files["data"]
|
||||
download_file = st.selectbox("选择要处理文件", files,
|
||||
key="download_file",)
|
||||
|
||||
cols = st.columns(2)
|
||||
file_url, file_name = api.web_sd_download(download_file)
|
||||
cols[0].download_button("点击下载", file_url, file_name)
|
||||
if cols[1].button("点击删除", ):
|
||||
api.web_sd_delete(download_file)
|
||||
|
||||
code_interpreter_on = st.toggle("开启代码解释器") and not_agent_qa
|
||||
code_exec_on = st.toggle("自动执行代码") and not_agent_qa
|
||||
|
||||
# Display chat messages from history on app rerun
|
||||
|
||||
|
@ -102,7 +242,97 @@ def dialogue_page(api: ApiRequest):
|
|||
if dialogue_mode == "LLM 对话":
|
||||
chat_box.ai_say("正在思考...")
|
||||
text = ""
|
||||
r = api.chat_chat(prompt, history)
|
||||
r = api.chat_chat(prompt, history, no_remote_api=True)
|
||||
for t in r:
|
||||
if error_msg := check_error_msg(t): # check whether error occured
|
||||
st.error(error_msg)
|
||||
break
|
||||
text += t["answer"]
|
||||
chat_box.update_msg(text)
|
||||
logger.debug(f"text: {text}")
|
||||
chat_box.update_msg(text, streaming=False) # 更新最终的字符串,去除光标
|
||||
# 判断是否存在代码, 并提高编辑功能,执行功能
|
||||
code_text = api.codebox.decode_code_from_text(text)
|
||||
GLOBAL_EXE_CODE_TEXT = code_text
|
||||
if code_text and code_exec_on:
|
||||
codebox_res = api.codebox_chat("```"+code_text+"```", do_code_exe=True)
|
||||
elif dialogue_mode == "Agent问答":
|
||||
display_infos = [f"正在思考..."]
|
||||
if search_on:
|
||||
display_infos.append(Markdown("...", in_expander=True, title="网络搜索结果"))
|
||||
if doc_retrieval_on:
|
||||
display_infos.append(Markdown("...", in_expander=True, title="知识库匹配结果"))
|
||||
chat_box.ai_say(display_infos)
|
||||
|
||||
if 'history_node_list' in st.session_state:
|
||||
history_node_list: List[str] = st.session_state['history_node_list']
|
||||
else:
|
||||
history_node_list: List[str] = []
|
||||
|
||||
input_kargs = {"query": prompt,
|
||||
"phase_name": choose_phase,
|
||||
"history": history,
|
||||
"doc_engine_name": selected_kb,
|
||||
"search_engine_name": search_engine,
|
||||
"code_engine_name": selected_cb,
|
||||
"top_k": top_k,
|
||||
"score_threshold": score_threshold,
|
||||
"do_search": search_on,
|
||||
"do_doc_retrieval": doc_retrieval_on,
|
||||
"do_code_retrieval": code_retrieval_on,
|
||||
"do_tool_retrieval": False,
|
||||
"custom_phase_configs": {},
|
||||
"custom_chain_configs": {},
|
||||
"custom_role_configs": {},
|
||||
"choose_tools": tool_selects,
|
||||
"history_node_list": history_node_list,
|
||||
"isDetailed": is_detailed,
|
||||
}
|
||||
text = ""
|
||||
d = {"docs": []}
|
||||
for idx_count, d in enumerate(api.agent_chat(**input_kargs)):
|
||||
if error_msg := check_error_msg(d): # check whether error occured
|
||||
st.error(error_msg)
|
||||
text += d["answer"]
|
||||
if idx_count%20 == 0:
|
||||
chat_box.update_msg(text, element_index=0)
|
||||
|
||||
for k, v in d["figures"].items():
|
||||
logger.debug(f"figure: {k}")
|
||||
if k in text:
|
||||
img_html = "\n<img src='data:image/png;base64,{}' class='img-fluid'>\n".format(v)
|
||||
text = text.replace(k, img_html).replace(".png", "")
|
||||
chat_box.update_msg(text, element_index=0, streaming=False, state="complete") # 更新最终的字符串,去除光标
|
||||
if search_on:
|
||||
chat_box.update_msg("搜索匹配结果:\n\n" + "\n\n".join(d["search_docs"]), element_index=search_on, streaming=False, state="complete")
|
||||
if doc_retrieval_on:
|
||||
chat_box.update_msg("知识库匹配结果:\n\n" + "\n\n".join(d["db_docs"]), element_index=search_on+doc_retrieval_on, streaming=False, state="complete")
|
||||
|
||||
history_node_list.extend([node[0] for node in d.get("related_nodes", [])])
|
||||
history_node_list = list(set(history_node_list))
|
||||
st.session_state['history_node_list'] = history_node_list
|
||||
|
||||
elif dialogue_mode == "工具问答":
|
||||
chat_box.ai_say("正在思考...")
|
||||
text = ""
|
||||
r = api.tool_chat(prompt, history, tool_sets=tool_selects)
|
||||
for t in r:
|
||||
if error_msg := check_error_msg(t): # check whether error occured
|
||||
st.error(error_msg)
|
||||
break
|
||||
text += t["answer"]
|
||||
chat_box.update_msg(text)
|
||||
logger.debug(f"text: {text}")
|
||||
chat_box.update_msg(text, streaming=False) # 更新最终的字符串,去除光标
|
||||
# 判断是否存在代码, 并提高编辑功能,执行功能
|
||||
code_text = api.codebox.decode_code_from_text(text)
|
||||
GLOBAL_EXE_CODE_TEXT = code_text
|
||||
if code_text and code_exec_on:
|
||||
codebox_res = api.codebox_chat("```"+code_text+"```", do_code_exe=True)
|
||||
elif dialogue_mode == "数据分析":
|
||||
chat_box.ai_say("正在思考...")
|
||||
text = ""
|
||||
r = api.data_chat(prompt, history)
|
||||
for t in r:
|
||||
if error_msg := check_error_msg(t): # check whether error occured
|
||||
st.error(error_msg)
|
||||
|
@ -116,7 +346,6 @@ def dialogue_page(api: ApiRequest):
|
|||
GLOBAL_EXE_CODE_TEXT = code_text
|
||||
if code_text and code_exec_on:
|
||||
codebox_res = api.codebox_chat("```"+code_text+"```", do_code_exe=True)
|
||||
|
||||
elif dialogue_mode == "知识库问答":
|
||||
history = get_messages_history(history_len)
|
||||
chat_box.ai_say([
|
||||
|
@ -124,6 +353,7 @@ def dialogue_page(api: ApiRequest):
|
|||
Markdown("...", in_expander=True, title="知识库匹配结果"),
|
||||
])
|
||||
text = ""
|
||||
d = {"docs": []}
|
||||
for idx_count, d in enumerate(api.knowledge_base_chat(prompt, selected_kb, kb_top_k, score_threshold, history)):
|
||||
if error_msg := check_error_msg(d): # check whether error occured
|
||||
st.error(error_msg)
|
||||
|
@ -138,6 +368,33 @@ def dialogue_page(api: ApiRequest):
|
|||
GLOBAL_EXE_CODE_TEXT = code_text
|
||||
if code_text and code_exec_on:
|
||||
codebox_res = api.codebox_chat("```"+code_text+"```", do_code_exe=True)
|
||||
elif dialogue_mode == '代码知识库问答':
|
||||
logger.info('prompt={}'.format(prompt))
|
||||
logger.info('history={}'.format(history))
|
||||
if 'history_node_list' in st.session_state:
|
||||
api.codeChat.history_node_list = st.session_state['history_node_list']
|
||||
|
||||
chat_box.ai_say([
|
||||
f"正在查询代码知识库 `{selected_cb}` ...",
|
||||
Markdown("...", in_expander=True, title="代码库匹配结果"),
|
||||
])
|
||||
text = ""
|
||||
d = {"codes": []}
|
||||
|
||||
for idx_count, d in enumerate(api.code_base_chat(query=prompt, code_base_name=selected_cb,
|
||||
code_limit=cb_code_limit, history=history,
|
||||
no_remote_api=True)):
|
||||
if error_msg := check_error_msg(d):
|
||||
st.error(error_msg)
|
||||
text += d["answer"]
|
||||
if idx_count % 10 == 0:
|
||||
chat_box.update_msg(text, element_index=0)
|
||||
chat_box.update_msg(text, element_index=0, streaming=False) # 更新最终的字符串,去除光标
|
||||
chat_box.update_msg("\n".join(d["codes"]), element_index=1, streaming=False, state="complete")
|
||||
|
||||
# session state update
|
||||
st.session_state['history_node_list'] = api.codeChat.history_node_list
|
||||
|
||||
elif dialogue_mode == "搜索引擎问答":
|
||||
chat_box.ai_say([
|
||||
f"正在执行 `{search_engine}` 搜索...",
|
||||
|
@ -145,7 +402,7 @@ def dialogue_page(api: ApiRequest):
|
|||
])
|
||||
text = ""
|
||||
d = {"docs": []}
|
||||
for d in api.search_engine_chat(prompt, search_engine, se_top_k):
|
||||
for idx_count, d in enumerate(api.search_engine_chat(prompt, search_engine, se_top_k)):
|
||||
if error_msg := check_error_msg(d): # check whether error occured
|
||||
st.error(error_msg)
|
||||
text += d["answer"]
|
||||
|
@ -194,7 +451,7 @@ def dialogue_page(api: ApiRequest):
|
|||
img_html = "<img src='data:image/png;base64,{}' class='img-fluid'>".format(
|
||||
codebox_res.code_exe_response
|
||||
)
|
||||
chat_box.update_msg(base_text + img_html, streaming=False, state="complete")
|
||||
chat_box.update_msg(img_html, streaming=False, state="complete")
|
||||
else:
|
||||
chat_box.update_msg('```\n'+code_text+'\n```'+"\n\n"+'```\n'+codebox_res.code_exe_response+'\n```',
|
||||
streaming=False, state="complete")
|
||||
|
|
|
@ -137,7 +137,24 @@ def knowledge_page(api: ApiRequest):
|
|||
[i for ls in LOADER2EXT_DICT.values() for i in ls],
|
||||
accept_multiple_files=True,
|
||||
)
|
||||
|
||||
|
||||
if st.button(
|
||||
"添加文件到知识库",
|
||||
# help="请先上传文件,再点击添加",
|
||||
# use_container_width=True,
|
||||
disabled=len(files) == 0,
|
||||
):
|
||||
data = [{"file": f, "knowledge_base_name": kb, "not_refresh_vs_cache": True} for f in files]
|
||||
data[-1]["not_refresh_vs_cache"]=False
|
||||
for k in data:
|
||||
pass
|
||||
ret = api.upload_kb_doc(**k)
|
||||
if msg := check_success_msg(ret):
|
||||
st.toast(msg, icon="✔")
|
||||
elif msg := check_error_msg(ret):
|
||||
st.toast(msg, icon="✖")
|
||||
st.session_state.files = []
|
||||
|
||||
base_url = st.text_input(
|
||||
"待获取内容的URL地址",
|
||||
placeholder="请填写正确可打开的URL地址",
|
||||
|
@ -187,22 +204,6 @@ def knowledge_page(api: ApiRequest):
|
|||
if os.path.exists(html_path):
|
||||
os.remove(html_path)
|
||||
|
||||
if st.button(
|
||||
"添加文件到知识库",
|
||||
# help="请先上传文件,再点击添加",
|
||||
# use_container_width=True,
|
||||
disabled=len(files) == 0,
|
||||
):
|
||||
data = [{"file": f, "knowledge_base_name": kb, "not_refresh_vs_cache": True} for f in files]
|
||||
data[-1]["not_refresh_vs_cache"]=False
|
||||
for k in data:
|
||||
ret = api.upload_kb_doc(**k)
|
||||
if msg := check_success_msg(ret):
|
||||
st.toast(msg, icon="✔")
|
||||
elif msg := check_error_msg(ret):
|
||||
st.toast(msg, icon="✖")
|
||||
st.session_state.files = []
|
||||
|
||||
st.divider()
|
||||
|
||||
# 知识库详情
|
||||
|
|
|
@ -10,11 +10,13 @@ import json
|
|||
import nltk
|
||||
import traceback
|
||||
from loguru import logger
|
||||
import zipfile
|
||||
|
||||
from configs.model_config import (
|
||||
EMBEDDING_MODEL,
|
||||
DEFAULT_VS_TYPE,
|
||||
KB_ROOT_PATH,
|
||||
CB_ROOT_PATH,
|
||||
LLM_MODEL,
|
||||
SCORE_THRESHOLD,
|
||||
VECTOR_SEARCH_TOP_K,
|
||||
|
@ -27,8 +29,10 @@ from configs.server_config import SANDBOX_SERVER
|
|||
|
||||
from dev_opsgpt.utils.server_utils import run_async, iter_over_async
|
||||
from dev_opsgpt.service.kb_api import *
|
||||
from dev_opsgpt.chat import LLMChat, SearchChat, KnowledgeChat
|
||||
from dev_opsgpt.service.cb_api import *
|
||||
from dev_opsgpt.chat import LLMChat, SearchChat, KnowledgeChat, ToolChat, DataChat, CodeChat, AgentChat
|
||||
from dev_opsgpt.sandbox import PyCodeBox, CodeBoxResponse
|
||||
from dev_opsgpt.utils.common_utils import file_normalize, get_uploadfile
|
||||
|
||||
from web_crawler.utils.WebCrawler import WebCrawler
|
||||
|
||||
|
@ -58,15 +62,22 @@ class ApiRequest:
|
|||
def __init__(
|
||||
self,
|
||||
base_url: str = "http://127.0.0.1:7861",
|
||||
sandbox_file_url: str = "http://127.0.0.1:7862",
|
||||
timeout: float = 60.0,
|
||||
no_remote_api: bool = False, # call api view function directly
|
||||
):
|
||||
self.base_url = base_url
|
||||
self.sandbox_file_url = sandbox_file_url
|
||||
self.timeout = timeout
|
||||
self.no_remote_api = no_remote_api
|
||||
self.llmChat = LLMChat()
|
||||
self.searchChat = SearchChat()
|
||||
self.knowledgeChat = KnowledgeChat()
|
||||
self.toolChat = ToolChat()
|
||||
self.dataChat = DataChat()
|
||||
self.codeChat = CodeChat()
|
||||
|
||||
self.agentChat = AgentChat()
|
||||
self.codebox = PyCodeBox(
|
||||
remote_url=SANDBOX_SERVER["url"],
|
||||
remote_ip=SANDBOX_SERVER["host"], # "http://localhost",
|
||||
|
@ -83,7 +94,8 @@ class ApiRequest:
|
|||
if (not url.startswith("http")
|
||||
and self.base_url
|
||||
):
|
||||
part1 = self.base_url.strip(" /")
|
||||
part1 = self.sandbox_file_url.strip(" /") \
|
||||
if "sdfiles" in url else self.base_url.strip(" /")
|
||||
part2 = url.strip(" /")
|
||||
return f"{part1}/{part2}"
|
||||
else:
|
||||
|
@ -331,7 +343,7 @@ class ApiRequest:
|
|||
self,
|
||||
query: str,
|
||||
search_engine_name: str,
|
||||
top_k: int = SEARCH_ENGINE_TOP_K,
|
||||
code_limit: int,
|
||||
stream: bool = True,
|
||||
no_remote_api: bool = None,
|
||||
):
|
||||
|
@ -344,7 +356,7 @@ class ApiRequest:
|
|||
data = {
|
||||
"query": query,
|
||||
"engine_name": search_engine_name,
|
||||
"top_k": top_k,
|
||||
"code_limit": code_limit,
|
||||
"history": [],
|
||||
"stream": stream,
|
||||
}
|
||||
|
@ -360,7 +372,157 @@ class ApiRequest:
|
|||
)
|
||||
return self._httpx_stream2generator(response, as_json=True)
|
||||
|
||||
# 知识库相关操作
|
||||
def tool_chat(
|
||||
self,
|
||||
query: str,
|
||||
history: List[Dict] = [],
|
||||
tool_sets: List[str] = [],
|
||||
stream: bool = True,
|
||||
no_remote_api: bool = None,
|
||||
):
|
||||
'''
|
||||
对应api.py/chat/chat接口
|
||||
'''
|
||||
if no_remote_api is None:
|
||||
no_remote_api = self.no_remote_api
|
||||
|
||||
data = {
|
||||
"query": query,
|
||||
"history": history,
|
||||
"tool_sets": tool_sets,
|
||||
"stream": stream,
|
||||
}
|
||||
|
||||
if no_remote_api:
|
||||
response = self.toolChat.chat(**data)
|
||||
return self._fastapi_stream2generator(response, as_json=True)
|
||||
else:
|
||||
response = self.post("/chat/tool_chat", json=data, stream=True)
|
||||
return self._httpx_stream2generator(response)
|
||||
|
||||
def data_chat(
|
||||
self,
|
||||
query: str,
|
||||
history: List[Dict] = [],
|
||||
stream: bool = True,
|
||||
no_remote_api: bool = None,
|
||||
):
|
||||
'''
|
||||
对应api.py/chat/chat接口
|
||||
'''
|
||||
if no_remote_api is None:
|
||||
no_remote_api = self.no_remote_api
|
||||
|
||||
data = {
|
||||
"query": query,
|
||||
"history": history,
|
||||
"stream": stream,
|
||||
}
|
||||
|
||||
if no_remote_api:
|
||||
response = self.dataChat.chat(**data)
|
||||
return self._fastapi_stream2generator(response, as_json=True)
|
||||
else:
|
||||
response = self.post("/chat/data_chat", json=data, stream=True)
|
||||
return self._httpx_stream2generator(response)
|
||||
|
||||
def code_base_chat(
|
||||
self,
|
||||
query: str,
|
||||
code_base_name: str,
|
||||
code_limit: int = 1,
|
||||
history: List[Dict] = [],
|
||||
stream: bool = True,
|
||||
no_remote_api: bool = None,
|
||||
):
|
||||
'''
|
||||
对应api.py/chat/knowledge_base_chat接口
|
||||
'''
|
||||
if no_remote_api is None:
|
||||
no_remote_api = self.no_remote_api
|
||||
|
||||
data = {
|
||||
"query": query,
|
||||
"history": history,
|
||||
"engine_name": code_base_name,
|
||||
"code_limit": code_limit,
|
||||
"stream": stream,
|
||||
"local_doc_url": no_remote_api,
|
||||
}
|
||||
logger.info('data={}'.format(data))
|
||||
|
||||
if no_remote_api:
|
||||
logger.info('history_node_list before={}'.format(self.codeChat.history_node_list))
|
||||
response = self.codeChat.chat(**data)
|
||||
logger.info('history_node_list after={}'.format(self.codeChat.history_node_list))
|
||||
return self._fastapi_stream2generator(response, as_json=True)
|
||||
else:
|
||||
response = self.post(
|
||||
"/chat/code_chat",
|
||||
json=data,
|
||||
stream=True,
|
||||
)
|
||||
return self._httpx_stream2generator(response, as_json=True)
|
||||
|
||||
def agent_chat(
|
||||
self,
|
||||
query: str,
|
||||
phase_name: str,
|
||||
doc_engine_name: str,
|
||||
code_engine_name: str,
|
||||
search_engine_name: str,
|
||||
top_k: int = 3,
|
||||
score_threshold: float = 1.0,
|
||||
history: List[Dict] = [],
|
||||
stream: bool = True,
|
||||
local_doc_url: bool = False,
|
||||
do_search: bool = False,
|
||||
do_doc_retrieval: bool = False,
|
||||
do_code_retrieval: bool = False,
|
||||
do_tool_retrieval: bool = False,
|
||||
choose_tools: List[str] = [],
|
||||
custom_phase_configs = {},
|
||||
custom_chain_configs = {},
|
||||
custom_role_configs = {},
|
||||
no_remote_api: bool = None,
|
||||
history_node_list: List[str] = [],
|
||||
isDetailed: bool = False
|
||||
):
|
||||
'''
|
||||
对应api.py/chat/chat接口
|
||||
'''
|
||||
if no_remote_api is None:
|
||||
no_remote_api = self.no_remote_api
|
||||
|
||||
data = {
|
||||
"query": query,
|
||||
"phase_name": phase_name,
|
||||
"chain_name": "",
|
||||
"history": history,
|
||||
"doc_engine_name": doc_engine_name,
|
||||
"code_engine_name": code_engine_name,
|
||||
"search_engine_name": search_engine_name,
|
||||
"top_k": top_k,
|
||||
"score_threshold": score_threshold,
|
||||
"stream": stream,
|
||||
"local_doc_url": local_doc_url,
|
||||
"do_search": do_search,
|
||||
"do_doc_retrieval": do_doc_retrieval,
|
||||
"do_code_retrieval": do_code_retrieval,
|
||||
"do_tool_retrieval": do_tool_retrieval,
|
||||
"custom_phase_configs": custom_phase_configs,
|
||||
"custom_chain_configs": custom_phase_configs,
|
||||
"custom_role_configs": custom_role_configs,
|
||||
"choose_tools": choose_tools,
|
||||
"history_node_list": history_node_list,
|
||||
"isDetailed": isDetailed
|
||||
}
|
||||
if no_remote_api:
|
||||
response = self.agentChat.chat(**data)
|
||||
return self._fastapi_stream2generator(response, as_json=True)
|
||||
else:
|
||||
response = self.post("/chat/data_chat", json=data, stream=True)
|
||||
return self._httpx_stream2generator(response)
|
||||
|
||||
def _check_httpx_json_response(
|
||||
self,
|
||||
|
@ -377,6 +539,21 @@ class ApiRequest:
|
|||
logger.error(e)
|
||||
return {"code": 500, "msg": errorMsg or str(e)}
|
||||
|
||||
def _check_httpx_file_response(
|
||||
self,
|
||||
response: httpx.Response,
|
||||
errorMsg: str = f"无法连接API服务器,请确认已执行python server\\api.py",
|
||||
) -> Dict:
|
||||
'''
|
||||
check whether httpx returns correct data with normal Response.
|
||||
error in api with streaming support was checked in _httpx_stream2enerator
|
||||
'''
|
||||
try:
|
||||
return response.content
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
return {"code": 500, "msg": errorMsg or str(e)}
|
||||
|
||||
def list_knowledge_bases(
|
||||
self,
|
||||
no_remote_api: bool = None,
|
||||
|
@ -662,6 +839,122 @@ class ApiRequest:
|
|||
else:
|
||||
raise Exception("not impletenion")
|
||||
|
||||
def web_sd_upload(self, file: str = None, filename: str = None):
|
||||
'''对应file_service/sd_upload_file'''
|
||||
file, filename = file_normalize(file, filename)
|
||||
response = self.post(
|
||||
"/sdfiles/upload",
|
||||
files={"file": (filename, file)},
|
||||
)
|
||||
return self._check_httpx_json_response(response)
|
||||
|
||||
def web_sd_download(self, filename: str, save_filename: str = None):
|
||||
'''对应file_service/sd_download_file'''
|
||||
save_filename = save_filename or filename
|
||||
# response = self.get(
|
||||
# f"/sdfiles/download",
|
||||
# params={"filename": filename, "save_filename": save_filename}
|
||||
# )
|
||||
key_value_str = f"filename={filename}&save_filename={save_filename}"
|
||||
return self._parse_url(f"/sdfiles/download?{key_value_str}"), save_filename
|
||||
|
||||
def web_sd_delete(self, filename: str):
|
||||
'''对应file_service/sd_delete_file'''
|
||||
response = self.get(
|
||||
f"/sdfiles/delete",
|
||||
params={"filename": filename}
|
||||
)
|
||||
return self._check_httpx_json_response(response)
|
||||
|
||||
def web_sd_list_files(self, ):
|
||||
'''对应对应file_service/sd_list_files接口'''
|
||||
response = self.get("/sdfiles/list",)
|
||||
return self._check_httpx_json_response(response)
|
||||
|
||||
# code base 相关操作
|
||||
def create_code_base(self, cb_name, zip_file, no_remote_api: bool = None,):
|
||||
'''
|
||||
创建 code_base
|
||||
@param cb_name:
|
||||
@param zip_path:
|
||||
@return:
|
||||
'''
|
||||
if no_remote_api is None:
|
||||
no_remote_api = self.no_remote_api
|
||||
|
||||
# mkdir
|
||||
cb_root_path = CB_ROOT_PATH
|
||||
mkdir_dir = [
|
||||
cb_root_path,
|
||||
cb_root_path + os.sep + cb_name,
|
||||
raw_code_path := cb_root_path + os.sep + cb_name + os.sep + 'raw_code'
|
||||
]
|
||||
for dir in mkdir_dir:
|
||||
if not os.path.exists(dir):
|
||||
os.makedirs(dir)
|
||||
|
||||
# unzip
|
||||
with zipfile.ZipFile(zip_file, 'r') as z:
|
||||
z.extractall(raw_code_path)
|
||||
|
||||
data = {
|
||||
"cb_name": cb_name,
|
||||
"code_path": raw_code_path
|
||||
}
|
||||
logger.info('create cb data={}'.format(data))
|
||||
|
||||
if no_remote_api:
|
||||
response = run_async(create_cb(**data))
|
||||
return response.dict()
|
||||
else:
|
||||
response = self.post(
|
||||
"/code_base/create_code_base",
|
||||
json=data,
|
||||
)
|
||||
logger.info('response={}'.format(response.json()))
|
||||
return self._check_httpx_json_response(response)
|
||||
|
||||
def delete_code_base(self, cb_name: str, no_remote_api: bool = None,):
|
||||
'''
|
||||
删除 code_base
|
||||
@param cb_name:
|
||||
@return:
|
||||
'''
|
||||
if no_remote_api is None:
|
||||
no_remote_api = self.no_remote_api
|
||||
|
||||
data = {
|
||||
"cb_name": cb_name
|
||||
}
|
||||
|
||||
if no_remote_api:
|
||||
response = run_async(delete_cb(**data))
|
||||
return response.dict()
|
||||
else:
|
||||
response = self.post(
|
||||
"/code_base/delete_code_base",
|
||||
json=cb_name
|
||||
)
|
||||
logger.info(response.json())
|
||||
return self._check_httpx_json_response(response)
|
||||
|
||||
def list_cb(self, no_remote_api: bool = None):
|
||||
'''
|
||||
列举 code_base
|
||||
@return:
|
||||
'''
|
||||
if no_remote_api is None:
|
||||
no_remote_api = self.no_remote_api
|
||||
|
||||
if no_remote_api:
|
||||
response = run_async(list_cbs())
|
||||
return response.data
|
||||
else:
|
||||
response = self.get("/code_base/list_code_bases")
|
||||
data = self._check_httpx_json_response(response)
|
||||
return data.get("data", [])
|
||||
|
||||
|
||||
|
||||
def check_error_msg(data: Union[str, dict, list], key: str = "errorMsg") -> str:
|
||||
'''
|
||||
|
|
|
@ -1,3 +1,3 @@
|
|||
#!/bin/bash
|
||||
|
||||
docker build -t devopsgpt:pypy38 .
|
||||
docker build -t devopsgpt:py39 .
|
|
@ -0,0 +1,208 @@
|
|||
import docker, sys, os, time, requests, psutil
|
||||
import subprocess
|
||||
from docker.types import Mount, DeviceRequest
|
||||
from loguru import logger
|
||||
|
||||
src_dir = os.path.join(
|
||||
os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
)
|
||||
sys.path.append(src_dir)
|
||||
|
||||
from configs.model_config import USE_FASTCHAT
|
||||
from configs.server_config import (
|
||||
NO_REMOTE_API, SANDBOX_SERVER, SANDBOX_IMAGE_NAME, SANDBOX_CONTRAINER_NAME,
|
||||
WEBUI_SERVER, API_SERVER, SDFILE_API_SERVER, CONTRAINER_NAME, IMAGE_NAME, DOCKER_SERVICE,
|
||||
DEFAULT_BIND_HOST,
|
||||
)
|
||||
|
||||
|
||||
import platform
|
||||
system_name = platform.system()
|
||||
USE_TTY = system_name in ["Windows"]
|
||||
|
||||
|
||||
def check_process(content: str, lang: str = None, do_stop=False):
|
||||
'''process-not-exist is true, process-exist is false'''
|
||||
for process in psutil.process_iter(["pid", "name", "cmdline"]):
|
||||
# check process name contains "jupyter" and port=xx
|
||||
|
||||
# if f"port={SANDBOX_SERVER['port']}" in str(process.info["cmdline"]).lower() and \
|
||||
# "jupyter" in process.info['name'].lower():
|
||||
if content in str(process.info["cmdline"]).lower():
|
||||
logger.info(f"content, {process.info}")
|
||||
# 关闭进程
|
||||
if do_stop:
|
||||
process.terminate()
|
||||
return True
|
||||
return False
|
||||
return True
|
||||
|
||||
def check_docker(client, container_name, do_stop=False):
|
||||
'''container-not-exist is true, container-exist is false'''
|
||||
for i in client.containers.list(all=True):
|
||||
if i.name == container_name:
|
||||
if do_stop:
|
||||
container = i
|
||||
container.stop()
|
||||
container.remove()
|
||||
return True
|
||||
return False
|
||||
return True
|
||||
|
||||
def start_docker(client, script_shs, ports, image_name, container_name, mounts=None, network=None):
|
||||
container = client.containers.run(
|
||||
image=image_name,
|
||||
command="bash",
|
||||
mounts=mounts,
|
||||
name=container_name,
|
||||
# device_requests=[DeviceRequest(count=-1, capabilities=[['gpu']])],
|
||||
# network_mode="host",
|
||||
ports=ports,
|
||||
stdin_open=True,
|
||||
detach=True,
|
||||
tty=USE_TTY,
|
||||
network=network,
|
||||
)
|
||||
|
||||
logger.info(f"docker id: {container.id[:10]}")
|
||||
|
||||
# 启动notebook
|
||||
for script_sh in script_shs:
|
||||
if USE_FASTCHAT and "llm_api" in script_sh:
|
||||
logger.debug(script_sh)
|
||||
response = container.exec_run(["sh", "-c", script_sh])
|
||||
logger.debug(response)
|
||||
elif "llm_api" not in script_sh:
|
||||
logger.debug(script_sh)
|
||||
response = container.exec_run(["sh", "-c", script_sh])
|
||||
logger.debug(response)
|
||||
return container
|
||||
|
||||
#########################################
|
||||
############# 开始启动服务 ###############
|
||||
#########################################
|
||||
|
||||
client = docker.from_env()
|
||||
client.containers.run
|
||||
network_name ='my_network'
|
||||
|
||||
def start_sandbox_service():
|
||||
networks = client.networks.list()
|
||||
if any([network_name==i.attrs["Name"] for i in networks]):
|
||||
network = client.networks.get(network_name)
|
||||
else:
|
||||
network = client.networks.create('my_network', driver='bridge')
|
||||
|
||||
mount = Mount(
|
||||
type='bind',
|
||||
source=os.path.join(src_dir, "jupyter_work"),
|
||||
target='/home/user/chatbot/jupyter_work',
|
||||
read_only=False # 如果需要只读访问,将此选项设置为True
|
||||
)
|
||||
mounts = [mount]
|
||||
# 沙盒的启动与服务的启动是独立的
|
||||
if SANDBOX_SERVER["do_remote"]:
|
||||
# 启动容器
|
||||
logger.info("start container sandbox service")
|
||||
script_shs = ["bash jupyter_start.sh"]
|
||||
JUPYTER_WORK_PATH = "/home/user/chatbot/jupyter_work"
|
||||
script_shs = [f"cd /home/user/chatbot/jupyter_work && nohup jupyter-notebook --NotebookApp.token=mytoken --port=5050 --allow-root --ip=0.0.0.0 --notebook-dir={JUPYTER_WORK_PATH} --no-browser --ServerApp.disable_check_xsrf=True &"]
|
||||
ports = {f"{SANDBOX_SERVER['docker_port']}/tcp": f"{SANDBOX_SERVER['port']}/tcp"}
|
||||
if check_docker(client, SANDBOX_CONTRAINER_NAME, do_stop=True, ):
|
||||
container = start_docker(client, script_shs, ports, SANDBOX_IMAGE_NAME, SANDBOX_CONTRAINER_NAME, mounts=mounts, network=network_name)
|
||||
# 判断notebook是否启动
|
||||
retry_nums = 3
|
||||
while retry_nums>0:
|
||||
response = requests.get(f"http://localhost:{SANDBOX_SERVER['port']}", timeout=270)
|
||||
if response.status_code == 200:
|
||||
logger.info("container & notebook init success")
|
||||
break
|
||||
else:
|
||||
retry_nums -= 1
|
||||
logger.info(client.containers.list())
|
||||
logger.info("wait container running ...")
|
||||
time.sleep(5)
|
||||
|
||||
else:
|
||||
check_docker(client, SANDBOX_CONTRAINER_NAME, do_stop=True, )
|
||||
logger.info("start local sandbox service")
|
||||
|
||||
def start_api_service(sandbox_host=DEFAULT_BIND_HOST):
|
||||
# 启动service的容器
|
||||
if DOCKER_SERVICE:
|
||||
logger.info("start container service")
|
||||
check_process("service/api.py", do_stop=True)
|
||||
check_process("service/sdfile_api.py", do_stop=True)
|
||||
check_process("service/sdfile_api.py", do_stop=True)
|
||||
check_process("webui.py", do_stop=True)
|
||||
mount = Mount(
|
||||
type='bind',
|
||||
source=src_dir,
|
||||
target='/home/user/chatbot/',
|
||||
read_only=False # 如果需要只读访问,将此选项设置为True
|
||||
)
|
||||
mount_database = Mount(
|
||||
type='bind',
|
||||
source=os.path.join(src_dir, "knowledge_base"),
|
||||
target='/home/user/knowledge_base/',
|
||||
read_only=False # 如果需要只读访问,将此选项设置为True
|
||||
)
|
||||
|
||||
ports={
|
||||
f"{API_SERVER['docker_port']}/tcp": f"{API_SERVER['port']}/tcp",
|
||||
f"{WEBUI_SERVER['docker_port']}/tcp": f"{WEBUI_SERVER['port']}/tcp",
|
||||
f"{SDFILE_API_SERVER['docker_port']}/tcp": f"{SDFILE_API_SERVER['port']}/tcp",
|
||||
}
|
||||
mounts = [mount, mount_database]
|
||||
script_shs = [
|
||||
"mkdir -p /home/user/logs",
|
||||
"pip install zdatafront-sdk-python -i https://artifacts.antgroup-inc.cn/simple",
|
||||
"pip install jsonref",
|
||||
"pip install javalang",
|
||||
"nohup python chatbot/dev_opsgpt/service/sdfile_api.py > /home/user/logs/sdfile_api.log 2>&1 &",
|
||||
f"export DUCKDUCKGO_PROXY=socks5://host.docker.internal:13659 && export SANDBOX_HOST={sandbox_host} &&\
|
||||
nohup python chatbot/dev_opsgpt/service/api.py > /home/user/logs/api.log 2>&1 &",
|
||||
"nohup python chatbot/dev_opsgpt/service/llm_api.py > /home/user/ 2>&1 &",
|
||||
f"export DUCKDUCKGO_PROXY=socks5://host.docker.internal:13659 && export SANDBOX_HOST={sandbox_host} &&\
|
||||
cd chatbot/examples && nohup streamlit run webui.py > /home/user/logs/start_webui.log 2>&1 &"
|
||||
]
|
||||
if check_docker(client, CONTRAINER_NAME, do_stop=True):
|
||||
container = start_docker(client, script_shs, ports, IMAGE_NAME, CONTRAINER_NAME, mounts, network=network_name)
|
||||
|
||||
else:
|
||||
logger.info("start local service")
|
||||
# 关闭之前启动的docker 服务
|
||||
# check_docker(client, CONTRAINER_NAME, do_stop=True, )
|
||||
|
||||
api_sh = "nohup python ../dev_opsgpt/service/api.py > ../logs/api.log 2>&1 &"
|
||||
sdfile_sh = "nohup python ../dev_opsgpt/service/sdfile_api.py > ../logs/sdfile_api.log 2>&1 &"
|
||||
llm_sh = "nohup python ../dev_opsgpt/service/llm_api.py > ../logs/llm_api.log 2>&1 &"
|
||||
webui_sh = "streamlit run webui.py" if USE_TTY else "streamlit run webui.py"
|
||||
#
|
||||
if not NO_REMOTE_API and check_process("service/api.py"):
|
||||
logger.info('check 1')
|
||||
subprocess.Popen(api_sh, shell=True)
|
||||
#
|
||||
if USE_FASTCHAT and check_process("service/llm_api.py"):
|
||||
subprocess.Popen(llm_sh, shell=True)
|
||||
#
|
||||
if check_process("service/sdfile_api.py"):
|
||||
subprocess.Popen(sdfile_sh, shell=True)
|
||||
|
||||
subprocess.Popen(webui_sh, shell=True)
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
start_sandbox_service()
|
||||
client = docker.from_env()
|
||||
containers = client.containers.list(all=True)
|
||||
|
||||
sandbox_host = DEFAULT_BIND_HOST
|
||||
for container in containers:
|
||||
container_a_info = client.containers.get(container.id)
|
||||
if container_a_info.name == SANDBOX_CONTRAINER_NAME:
|
||||
container1_networks = container.attrs['NetworkSettings']['Networks']
|
||||
sandbox_host = container1_networks.get(network_name)["IPAddress"]
|
||||
break
|
||||
start_api_service(sandbox_host)
|
|
@ -7,13 +7,13 @@ src_dir = os.path.join(
|
|||
)
|
||||
sys.path.append(src_dir)
|
||||
|
||||
from configs.server_config import CONTRAINER_NAME, SANDBOX_SERVER, IMAGE_NAME
|
||||
from configs.server_config import SANDBOX_SERVER, SANDBOX_IMAGE_NAME, SANDBOX_CONTRAINER_NAME
|
||||
|
||||
|
||||
if SANDBOX_SERVER["do_remote"]:
|
||||
client = docker.from_env()
|
||||
for i in client.containers.list(all=True):
|
||||
if i.name == CONTRAINER_NAME:
|
||||
if i.name == SANDBOX_CONTRAINER_NAME:
|
||||
container = i
|
||||
container.stop()
|
||||
container.remove()
|
||||
|
@ -21,10 +21,10 @@ if SANDBOX_SERVER["do_remote"]:
|
|||
# 启动容器
|
||||
logger.info("start ot init container & notebook")
|
||||
container = client.containers.run(
|
||||
image=IMAGE_NAME,
|
||||
image=SANDBOX_IMAGE_NAME,
|
||||
command="bash",
|
||||
name=CONTRAINER_NAME,
|
||||
ports={"5050/tcp": SANDBOX_SERVER["port"]},
|
||||
name=SANDBOX_CONTRAINER_NAME,
|
||||
ports={f"{SANDBOX_SERVER['docker_port']}/tcp": SANDBOX_SERVER["port"]},
|
||||
stdin_open=True,
|
||||
detach=True,
|
||||
tty=True,
|
||||
|
|
|
@ -0,0 +1,68 @@
|
|||
import docker, sys, os, time, requests
|
||||
from docker.types import Mount
|
||||
|
||||
from loguru import logger
|
||||
|
||||
src_dir = os.path.join(
|
||||
os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
)
|
||||
sys.path.append(src_dir)
|
||||
|
||||
from configs.server_config import WEBUI_SERVER, API_SERVER, SDFILE_API_SERVER, CONTRAINER_NAME, IMAGE_NAME
|
||||
from configs.model_config import USE_FASTCHAT
|
||||
|
||||
|
||||
|
||||
logger.info(f"IMAGE_NAME: {IMAGE_NAME}, CONTRAINER_NAME: {CONTRAINER_NAME}, ")
|
||||
|
||||
|
||||
client = docker.from_env()
|
||||
for i in client.containers.list(all=True):
|
||||
if i.name == CONTRAINER_NAME:
|
||||
container = i
|
||||
container.stop()
|
||||
container.remove()
|
||||
break
|
||||
|
||||
|
||||
|
||||
# 启动容器
|
||||
logger.info("start service")
|
||||
|
||||
mount = Mount(
|
||||
type='bind',
|
||||
source=src_dir,
|
||||
target='/home/user/chatbot/',
|
||||
read_only=True # 如果需要只读访问,将此选项设置为True
|
||||
)
|
||||
|
||||
container = client.containers.run(
|
||||
image=IMAGE_NAME,
|
||||
command="bash",
|
||||
mounts=[mount],
|
||||
name=CONTRAINER_NAME,
|
||||
ports={
|
||||
f"{WEBUI_SERVER['docker_port']}/tcp": API_SERVER['port'],
|
||||
f"{API_SERVER['docker_port']}/tcp": WEBUI_SERVER['port'],
|
||||
f"{SDFILE_API_SERVER['docker_port']}/tcp": SDFILE_API_SERVER['port'],
|
||||
},
|
||||
stdin_open=True,
|
||||
detach=True,
|
||||
tty=True,
|
||||
)
|
||||
|
||||
# 启动notebook
|
||||
exec_command = container.exec_run("bash jupyter_start.sh")
|
||||
#
|
||||
exec_command = container.exec_run("cd /homse/user/chatbot && nohup python devops_gpt/service/sdfile_api.py > /homse/user/logs/sdfile_api.log &")
|
||||
#
|
||||
exec_command = container.exec_run("cd /homse/user/chatbot && nohup python devops_gpt/service/api.py > /homse/user/logs/api.log &")
|
||||
|
||||
if USE_FASTCHAT:
|
||||
# 启动fastchat的服务
|
||||
exec_command = container.exec_run("cd /homse/user/chatbot && nohup python devops_gpt/service/llm_api.py > /homse/user/logs/llm_api.log &")
|
||||
#
|
||||
exec_command = container.exec_run("cd /homse/user/chatbot/examples && nohup bash start_webui.sh > /homse/user/logs/start_webui.log &")
|
||||
|
||||
|
||||
|
|
@ -0,0 +1,28 @@
|
|||
import docker, sys, os
|
||||
from loguru import logger
|
||||
|
||||
src_dir = os.path.join(
|
||||
os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
)
|
||||
sys.path.append(src_dir)
|
||||
|
||||
from configs.server_config import (
|
||||
SANDBOX_CONTRAINER_NAME, CONTRAINER_NAME, SANDBOX_SERVER, DOCKER_SERVICE
|
||||
)
|
||||
|
||||
|
||||
from start import check_docker, check_process
|
||||
|
||||
client = docker.from_env()
|
||||
|
||||
#
|
||||
check_docker(client, SANDBOX_CONTRAINER_NAME, do_stop=True, )
|
||||
check_process(f"port={SANDBOX_SERVER['port']}", do_stop=True)
|
||||
check_process(f"port=5050", do_stop=True)
|
||||
|
||||
#
|
||||
check_docker(client, CONTRAINER_NAME, do_stop=True, )
|
||||
check_process("service/api.py", do_stop=True)
|
||||
check_process("service/sdfile_api.py", do_stop=True)
|
||||
check_process("service/llm_api.py", do_stop=True)
|
||||
check_process("webui.py", do_stop=True)
|
|
@ -4,10 +4,13 @@
|
|||
# 3. 运行API服务器:python server/api.py。如果使用api = ApiRequest(no_remote_api=True),该步可以跳过。
|
||||
# 4. 运行WEB UI:streamlit run webui.py --server.port 7860
|
||||
|
||||
import os, sys
|
||||
import os
|
||||
import sys
|
||||
import streamlit as st
|
||||
from streamlit_option_menu import option_menu
|
||||
|
||||
import multiprocessing
|
||||
|
||||
src_dir = os.path.join(
|
||||
os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
)
|
||||
|
@ -15,10 +18,11 @@ sys.path.append(src_dir)
|
|||
|
||||
from dev_opsgpt.webui import *
|
||||
from configs import VERSION, LLM_MODEL
|
||||
from configs.server_config import NO_REMOTE_API
|
||||
|
||||
|
||||
|
||||
api = ApiRequest(base_url="http://127.0.0.1:7861", no_remote_api=True)
|
||||
api = ApiRequest(base_url="http://127.0.0.1:7861", no_remote_api=NO_REMOTE_API)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -48,6 +52,10 @@ if __name__ == "__main__":
|
|||
"icon": "hdd-stack",
|
||||
"func": knowledge_page,
|
||||
},
|
||||
"代码知识库管理": {
|
||||
"icon": "hdd-stack",
|
||||
"func": code_page,
|
||||
},
|
||||
# "Prompt管理": {
|
||||
# "icon": "hdd-stack",
|
||||
# "func": prompt_page,
|
||||
|
|
|
@ -25,8 +25,7 @@ notebook
|
|||
websockets
|
||||
fake_useragent
|
||||
selenium
|
||||
auto-gptq==0.4.2
|
||||
|
||||
jsonref
|
||||
|
||||
# uncomment libs if you want to use corresponding vector store
|
||||
# pymilvus==2.1.3 # requires milvus==2.1.3
|
||||
|
@ -41,3 +40,5 @@ streamlit-antd-components>=0.1.11
|
|||
streamlit-chatbox>=1.1.6
|
||||
streamlit-aggrid>=0.3.4.post3
|
||||
httpx~=0.24.1
|
||||
|
||||
javalang==0.13.0
|
||||
|
|
|
@ -0,0 +1 @@
|
|||
{"url": "https://python.langchain.com/docs/get_started/introduction", "host_url": "https://python.langchain.com", "title": "Introduction | 🦜️🔗 Langchain", "all_text": "\n\nIntroduction | 🦜️🔗 Langchain\n\nSkip to main content🦜️🔗 LangChainDocsUse casesIntegrationsAPICommunityChat our docsLangSmithJS/TS DocsSearchCTRLKGet startedIntroductionInstallationQuickstartLangChain Expression LanguageInterfaceHow toCookbookLangChain Expression Language (LCEL)ModulesModel I/ORetrievalChainsMemoryAgentsCallbacksModulesGuidesMoreGet startedIntroductionOn this pageIntroductionLangChain is a framework for developing applications powered by language models. It enables applications that:Are context-aware: connect a language model to sources of context (prompt instructions, few shot examples, content to ground its response in, etc.)Reason: rely on a language model to reason (about how to answer based on provided context, what actions to take, etc.)The main value props of LangChain are:Components: abstractions for working with language models, along with a collection of implementations for each abstraction. Components are modular and easy-to-use, whether you are using the rest of the LangChain framework or notOff-the-shelf chains: a structured assembly of components for accomplishing specific higher-level tasksOff-the-shelf chains make it easy to get started. For complex applications, components make it easy to customize existing chains and build new ones.Get startedHere’s how to install LangChain, set up your environment, and start building.We recommend following our Quickstart guide to familiarize yourself with the framework by building your first LangChain application.Note: These docs are for the LangChain Python package. For documentation on LangChain.js, the JS/TS version, head here.ModulesLangChain provides standard, extendable interfaces and external integrations for the following modules, listed from least to most complex:Model I/OInterface with language modelsRetrievalInterface with application-specific dataChainsConstruct sequences of callsAgentsLet chains choose which tools to use given high-level directivesMemoryPersist application state between runs of a chainCallbacksLog and stream intermediate steps of any chainExamples, ecosystem, and resourcesUse casesWalkthroughs and best-practices for common end-to-end use cases, like:Document question answeringChatbotsAnalyzing structured dataand much more...GuidesLearn best practices for developing with LangChain.EcosystemLangChain is part of a rich ecosystem of tools that integrate with our framework and build on top of it. Check out our growing list of integrations and dependent repos.Additional resourcesOur community is full of prolific developers, creative builders, and fantastic teachers. Check out YouTube tutorials for great tutorials from folks in the community, and Gallery for a list of awesome LangChain projects, compiled by the folks at KyroLabs.CommunityHead to the Community navigator to find places to ask questions, share feedback, meet other developers, and dream about the future of LLM’s.API referenceHead to the reference section for full documentation of all classes and methods in the LangChain Python package.PreviousGet startedNextInstallationGet startedModulesExamples, ecosystem, and resourcesUse casesGuidesEcosystemAdditional resourcesCommunityAPI referenceCommunityDiscordTwitterGitHubPythonJS/TSMoreHomepageBlogCopyright © 2023 LangChain, Inc.\n\n"}
|
|
@ -0,0 +1 @@
|
|||
{"url": "https://zhuanlan.zhihu.com/p/80963305", "host_url": "https://zhuanlan.zhihu.com", "title": "【工具类】PyCharm+Anaconda+jupyter notebook +pip环境配置 - 知乎", "all_text": "\n【工具类】PyCharm+Anaconda+jupyter notebook +pip环境配置 - 知乎切换模式写文章登录/注册【工具类】PyCharm+Anaconda+jupyter notebook +pip环境配置Joe.Zhao14 人赞同了该文章Pycharm是一个很好的python的IDE,Anaconda是一个环境管理工具,可以针对不同工作配置不同的环境,如何在Pycharm中调用Anaconda中创建的环境Anaconda环境配置Anaconda 解决了官方 Python 的两大痛点第一:提供了包管理功能,解决安装第三方包经常失败第二:提供环境管理的功能,功能类似 Virtualenv,解决了多版本Python并存、切换的问题。查看Anaconda中所有的Python环境,Window环境下Anaconda Prompt中输入以下命令,其中前面有个‘*’的代表当前环境\n```code\nconda info --env\n\n# conda environments:\n#\nbase * D:\\Anaconda3\ntf D:\\Anaconda3\\envs\\tf\n```\n创建新的Python环境\n```code\nconda create --name python35 python=3.5 #代表创建一个python3.5的环境,我们把它命名为python35\n```\n激活进入创建的环境\n```code\nconda activate python35\n```\n在当前环境中安装package,可以使用pip,还可以用conda\n```code\npip install numpy\nconda install numpy\n```\n退出当前环境,回到base环境\n```code\nconda deactivate\n```\n删除创建的环境,conda创建的环境会在安装目录Anaconda3\\envs\\下面,每一个环境对应一个文件夹,当删除环境的时候,响应的文件夹也会被删除掉\n```code\nconda env remove --name python35 --all\nconda remove --name myenv --all\n```\nconda源头\n```code\nconda config --show channels\nchannels:\n- https://pypi.doubanio.com/simple/\n- defaults\n```\n添加新源\n```code\nconda config --add channels https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free/\n```\n删除源\n```code\nconda config --remove channels https://pypi.doubanio.com/simple/\n```\n查看安装package\n```code\nconda list\n\n```\nPycharm 使用Anaconda创建的环境pycharm工程目录中打开/file/settings/Project Interpreter在Project Interpreter中打开Add,左侧边栏目选择Conda Environment,右侧选择Existing environment在文件路径中选择Anaconda安装目录下面的envs目录,下面是该系统安装的所有anaconda环境,进入文件夹,选择python解释器这就就把Pycharm下使用Anconda环境的配置完成了。pip 环境配置conda 环境下也可以用pip来安装包pip安装\n```code\npip install 安装包名\n[...]\nSuccessfully installed SomePackage #安装成功\n```\npip 安装指定源\n```code\npip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple\n```\npip查看是否已安装\n```code\npip show --files 安装包名\n\nName:SomePackage # 包名\nVersion:1.0 # 版本号\nLocation:/my/env/lib/pythonx.x/site-packages # 安装位置\nFiles: # 包含文件等等\n../somepackage/__init__.py\n```\npip检查哪些包需要更新\n```code\npip list --outdated\n```\npip升级包\n```code\npip install --upgrade 要升级的包名\n```\npip卸载包\n```code\npip uninstall 要卸载的包名\n```\npip参数解释\n```code\npip --help\n```\nJupyter notebook使用在利用anaconda创建了tensorflow,pytorch等儿女environment后,想利用jupyter notebook,发现jupyer notebook中只有系统的python3环境,如何把conda创建的环境添加进jupyter notebook中呢,终于解决了这个问题了1. 安装ipykernel\n```code\nconda install ipykernel\n```\n2. 将环境写入notebook的kernel中\n```code\npython -m ipykernel install --user --name your_env_name --display-name your_env_name\n\n//把conda environment pytorch_0.4 add to jupyter notebook kernel display as pytorch_0.4\npython -m ipykernel install --user --name pytorch_0.4 --display-name pytorch_0.4\n```\n3. 打开notebook\n```code\njupyter notebook\n```\n4. magic commands\n```code\n!git clone https://github.com/ultralytics/yolov5\n%ls\n%cd yolov5\n%pip install -qr requirements.txt\n```\n还有一些实用的魔术命令\n```code\n%magic——用来显示所有魔术命令的详细文档\n%time和%timeit——用来测试代码执行时间\n```\n参考文档编辑于 2023-05-21 20:41・IP 属地浙江PyCharmAnacondapip3赞同 142 条评论分享喜欢收藏申请转载"}
|
|
@ -47,13 +47,13 @@
|
|||
- [x] Web Crawl 通用能力:技术文档: 知乎、csdn、阿里云开发者论坛、腾讯云开发者论坛等
|
||||
<br>
|
||||
- v0.1
|
||||
- [ ] Sandbox 环境: 上传、下载文件
|
||||
- [x] Sandbox 环境: 上传、下载文件
|
||||
- [ ] Vector Database & Retrieval
|
||||
- [ ] task retrieval
|
||||
- [ ] tool retrieval
|
||||
- [ ] Connector
|
||||
- [ ] 基于langchain的react模式
|
||||
- [ ] 基于sentencebert接入Text Embedding: 向量加载速度提升
|
||||
- [x] 基于sentencebert接入Text Embedding: 向量加载速度提升
|
||||
<br>
|
||||
|
||||
- v0.2
|
||||
|
|
|
@ -0,0 +1,168 @@
|
|||
import os, sys, requests
|
||||
|
||||
src_dir = os.path.join(
|
||||
os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
)
|
||||
sys.path.append(src_dir)
|
||||
|
||||
from dev_opsgpt.tools import (
|
||||
toLangchainTools, get_tool_schema, DDGSTool, DocRetrieval,
|
||||
TOOL_DICT, TOOL_SETS
|
||||
)
|
||||
|
||||
from configs.model_config import *
|
||||
from dev_opsgpt.connector.phase import BasePhase
|
||||
from dev_opsgpt.connector.agents import BaseAgent
|
||||
from dev_opsgpt.connector.chains import BaseChain
|
||||
from dev_opsgpt.connector.connector_schema import (
|
||||
Message, load_role_configs, load_phase_configs, load_chain_configs
|
||||
)
|
||||
from dev_opsgpt.connector.configs import AGETN_CONFIGS, CHAIN_CONFIGS, PHASE_CONFIGS
|
||||
import importlib
|
||||
|
||||
print(src_dir)
|
||||
|
||||
tools = toLangchainTools([TOOL_DICT[i] for i in TOOL_SETS if i in TOOL_DICT])
|
||||
|
||||
role_configs = load_role_configs(AGETN_CONFIGS)
|
||||
chain_configs = load_chain_configs(CHAIN_CONFIGS)
|
||||
phase_configs = load_phase_configs(PHASE_CONFIGS)
|
||||
|
||||
agent_module = importlib.import_module("dev_opsgpt.connector.agents")
|
||||
|
||||
|
||||
# agent的测试
|
||||
query = Message(role_name="tool_react", role_type="human",
|
||||
role_content="我有一份时序数据,[0.857, 2.345, 1.234, 4.567, 3.456, 9.876, 5.678, 7.890, 6.789, 8.901, 10.987, 12.345, 11.234, 14.567, 13.456, 19.876, 15.678, 17.890, 16.789, \
|
||||
18.901, 20.987, 22.345, 21.234, 24.567, 23.456, 29.876, 25.678, 27.890, 26.789, 28.901, 30.987, 32.345, 31.234, 34.567, 33.456, 39.876, 35.678, 37.890, 36.789, 38.901, 40.987],\
|
||||
我不知道这份数据是否存在问题,请帮我判断一下", tools=tools)
|
||||
|
||||
query = Message(role_name="tool_react", role_type="human",
|
||||
role_content="帮我确认下127.0.0.1这个服务器的在10点是否存在异常,请帮我判断一下", tools=tools)
|
||||
|
||||
query = Message(role_name="code_react", role_type="human",
|
||||
role_content="帮我确认当前目录下有哪些文件", tools=tools)
|
||||
|
||||
# "给我一份冒泡排序的代码"
|
||||
query = Message(role_name="intention_recognizer", role_type="human",
|
||||
role_content="对employee_data.csv进行数据分析", tools=tools)
|
||||
|
||||
# role = role_configs["general_planner"]
|
||||
# agent_class = getattr(agent_module, role.role.agent_type)
|
||||
# agent = agent_class(role.role,
|
||||
# task = None,
|
||||
# memory = None,
|
||||
# chat_turn=role.chat_turn,
|
||||
# do_search = role.do_search,
|
||||
# do_doc_retrieval = role.do_doc_retrieval,
|
||||
# do_tool_retrieval = role.do_tool_retrieval,)
|
||||
|
||||
# message = agent.run(query)
|
||||
# print(message.role_content)
|
||||
|
||||
|
||||
# chain的测试
|
||||
|
||||
# query = Message(role_name="deveploer", role_type="human", role_content="编写冒泡排序,并生成测例")
|
||||
# query = Message(role_name="general_planner", role_type="human", role_content="对employee_data.csv进行数据分析")
|
||||
# query = Message(role_name="tool_react", role_type="human", role_content="我有一份时序数据,[0.857, 2.345, 1.234, 4.567, 3.456, 9.876, 5.678, 7.890, 6.789, 8.901, 10.987, 12.345, 11.234, 14.567, 13.456, 19.876, 15.678, 17.890, 16.789, 18.901, 20.987, 22.345, 21.234, 24.567, 23.456, 29.876, 25.678, 27.890, 26.789, 28.901, 30.987, 32.345, 31.234, 34.567, 33.456, 39.876, 35.678, 37.890, 36.789, 38.901, 40.987],\我不知道这份数据是否存在问题,请帮我判断一下", tools=tools)
|
||||
|
||||
# role = role_configs[query.role_name]
|
||||
role1 = role_configs["planner"]
|
||||
role2 = role_configs["code_react"]
|
||||
|
||||
agents = [
|
||||
getattr(agent_module, role1.role.agent_type)(role1.role,
|
||||
task = None,
|
||||
memory = None,
|
||||
do_search = role1.do_search,
|
||||
do_doc_retrieval = role1.do_doc_retrieval,
|
||||
do_tool_retrieval = role1.do_tool_retrieval,),
|
||||
getattr(agent_module, role2.role.agent_type)(role2.role,
|
||||
task = None,
|
||||
memory = None,
|
||||
do_search = role2.do_search,
|
||||
do_doc_retrieval = role2.do_doc_retrieval,
|
||||
do_tool_retrieval = role2.do_tool_retrieval,),
|
||||
]
|
||||
|
||||
query = Message(role_name="user", role_type="human",
|
||||
role_content="确认本地是否存在employee_data.csv,并查看它有哪些列和数据类型,分析这份数据的内容,根据这个数据预测未来走势", tools=tools)
|
||||
query = Message(role_name="user", role_type="human",
|
||||
role_content="确认本地是否存在employee_data.csv,并查看它有哪些列和数据类型", tools=tools)
|
||||
chain = BaseChain(chain_configs["dataAnalystChain"], agents, do_code_exec=False)
|
||||
|
||||
# message = chain.step(query)
|
||||
# print(message.role_content)
|
||||
|
||||
# print("\n".join("\n".join([": ".join(j) for j in i]) for i in chain.get_agents_memory()))
|
||||
# print("\n".join(": ".join(i) for i in chain.get_memory()))
|
||||
# print( chain.get_agents_memory_str())
|
||||
# print( chain.get_memory_str())
|
||||
|
||||
|
||||
|
||||
|
||||
# 测试 phase
|
||||
phase_name = "toolReactPhase"
|
||||
# phase_name = "codeReactPhase"
|
||||
# phase_name = "chatPhase"
|
||||
|
||||
phase = BasePhase(phase_name,
|
||||
task = None,
|
||||
phase_config = PHASE_CONFIGS,
|
||||
chain_config = CHAIN_CONFIGS,
|
||||
role_config = AGETN_CONFIGS,
|
||||
do_summary=False,
|
||||
do_code_retrieval=False,
|
||||
do_doc_retrieval=True,
|
||||
do_search=False,
|
||||
)
|
||||
|
||||
query = Message(role_name="user", role_type="human",
|
||||
role_content="确认本地是否存在employee_data.csv,并查看它有哪些列和数据类型,并选择合适的数值列画出折线图")
|
||||
|
||||
query = Message(role_name="user", role_type="human",
|
||||
role_content="判断下127.0.0.1这个服务器的在10点的监控数据,是否存在异常", tools=tools)
|
||||
|
||||
# 根据其他类似的类,新开发个 ExceptionComponent2,继承 AbstractTrafficComponent
|
||||
# query = Message(role_name="human", role_type="human", role_content="langchain有什么用")
|
||||
|
||||
# output_message = phase.step(query)
|
||||
|
||||
# print(phase.get_chains_memory(content_key="step_content"))
|
||||
# print(phase.get_chains_memory_str(content_key="step_content"))
|
||||
# print(output_message.to_tuple_message(return_all=True))
|
||||
|
||||
|
||||
from dev_opsgpt.tools import DDGSTool, CodeRetrieval
|
||||
# print(DDGSTool.run("langchain是什么", 3))
|
||||
# print(CodeRetrieval.run("dsadsadsa", query.role_content, code_limit=3, history_node_list=[]))
|
||||
|
||||
|
||||
# from dev_opsgpt.chat.agent_chat import AgentChat
|
||||
|
||||
# agentChat = AgentChat()
|
||||
# value = {
|
||||
# "query": "帮我确认下127.0.0.1这个服务器的在10点是否存在异常,请帮我判断一下",
|
||||
# "phase_name": "toolReactPhase",
|
||||
# "chain_name": "",
|
||||
# "history": [],
|
||||
# "doc_engine_name": "DSADSAD",
|
||||
# "search_engine_name": "duckduckgo",
|
||||
# "top_k": 3,
|
||||
# "score_threshold": 1.0,
|
||||
# "stream": False,
|
||||
# "local_doc_url": False,
|
||||
# "do_search": False,
|
||||
# "do_doc_retrieval": False,
|
||||
# "do_code_retrieval": False,
|
||||
# "do_tool_retrieval": False,
|
||||
# "custom_phase_configs": {},
|
||||
# "custom_chain_configs": {},
|
||||
# "custom_role_configs": {},
|
||||
# "choose_tools": list(TOOL_SETS)
|
||||
# }
|
||||
|
||||
# answer = agentChat.chat(**value)
|
||||
# print(answer)
|
|
@ -8,4 +8,37 @@ print(time.time()-st)
|
|||
|
||||
st = time.time()
|
||||
client.containers.run("ubuntu:latest", "echo hello world")
|
||||
print(time.time()-st)
|
||||
print(time.time()-st)
|
||||
|
||||
|
||||
import socket
|
||||
|
||||
|
||||
def get_ip_address():
|
||||
hostname = socket.gethostname()
|
||||
ip_address = socket.gethostbyname(hostname)
|
||||
return ip_address
|
||||
|
||||
def get_ipv4_address():
|
||||
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
||||
try:
|
||||
# 使用一个临时套接字连接到公共的 DNS 服务器
|
||||
s.connect(("8.8.8.8", 80))
|
||||
ip_address = s.getsockname()[0]
|
||||
finally:
|
||||
s.close()
|
||||
return ip_address
|
||||
|
||||
# print(get_ipv4_address())
|
||||
# import docker
|
||||
# client = docker.from_env()
|
||||
|
||||
# containers = client.containers.list(all=True)
|
||||
# for container in containers:
|
||||
# container_a_info = client.containers.get(container.id)
|
||||
# container1_networks = container.attrs['NetworkSettings']['Networks']
|
||||
# container_a_ip = container_a_info.attrs['NetworkSettings']['IPAddress']
|
||||
|
||||
# print(container_a_info.name, container_a_ip, [[k, v["IPAddress"]] for k,v in container1_networks.items() ])
|
||||
|
||||
|
||||
|
|
|
@ -0,0 +1,42 @@
|
|||
import requests, os, sys
|
||||
# src_dir = os.path.join(
|
||||
# os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
# )
|
||||
# sys.path.append(src_dir)
|
||||
|
||||
# from dev_opsgpt.utils.common_utils import st_load_file
|
||||
# from dev_opsgpt.sandbox.pycodebox import PyCodeBox
|
||||
# from examples.file_fastapi import upload_file, download_file
|
||||
# from pathlib import Path
|
||||
# import httpx
|
||||
# from loguru import logger
|
||||
# from io import BytesIO
|
||||
|
||||
|
||||
# def _parse_url(url: str, base_url: str) -> str:
|
||||
# if (not url.startswith("http")
|
||||
# and base_url
|
||||
# ):
|
||||
# part1 = base_url.strip(" /")
|
||||
# part2 = url.strip(" /")
|
||||
# return f"{part1}/{part2}"
|
||||
# else:
|
||||
# return url
|
||||
|
||||
# base_url: str = "http://127.0.0.1:7861"
|
||||
# timeout: float = 60.0,
|
||||
# url = "/files/upload"
|
||||
# url = _parse_url(url, base_url)
|
||||
# logger.debug(url)
|
||||
# kwargs = {}
|
||||
# kwargs.setdefault("timeout", timeout)
|
||||
|
||||
# import asyncio
|
||||
# file = "./torch_test.py"
|
||||
# upload_filename = st_load_file(file, filename="torch_test.py")
|
||||
# asyncio.run(upload_file(upload_filename))
|
||||
|
||||
import requests
|
||||
url = "http://127.0.0.1:7862/sdfiles/download?filename=torch_test.py&save_filename=torch_test.py"
|
||||
r = requests.get(url)
|
||||
print(type(r.text))
|
|
@ -9,7 +9,7 @@ from configs import llm_model_dict, LLM_MODEL
|
|||
import openai
|
||||
# os.environ["OPENAI_PROXY"] = "socks5h://127.0.0.1:7890"
|
||||
# os.environ["OPENAI_PROXY"] = "http://127.0.0.1:7890"
|
||||
os.environ["OPENAI_API_KEY"] = ""
|
||||
# os.environ["OPENAI_API_KEY"] = ""
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
@ -13,7 +13,7 @@ src_dir = os.path.join(
|
|||
)
|
||||
sys.path.append(src_dir)
|
||||
|
||||
|
||||
from dev_opsgpt.service.sdfile_api import sd_upload_file
|
||||
from dev_opsgpt.sandbox.pycodebox import PyCodeBox
|
||||
from pathlib import Path
|
||||
|
||||
|
@ -24,20 +24,20 @@ from pathlib import Path
|
|||
# print(sys.executable)
|
||||
|
||||
|
||||
# import requests
|
||||
import requests
|
||||
|
||||
# # 设置Jupyter Notebook服务器的URL
|
||||
# url = 'http://localhost:5050' # 或者是你自己的Jupyter服务器的URL
|
||||
# 设置Jupyter Notebook服务器的URL
|
||||
url = 'http://172.25.0.3:5050' # 或者是你自己的Jupyter服务器的URL
|
||||
|
||||
# # 发送GET请求来获取Jupyter Notebook的登录页面
|
||||
# response = requests.get(url)
|
||||
# 发送GET请求来获取Jupyter Notebook的登录页面
|
||||
response = requests.get(url)
|
||||
|
||||
# # 检查响应状态码
|
||||
# if response.status_code == 200:
|
||||
# # 打印响应内容
|
||||
# print('connect success')
|
||||
# else:
|
||||
# print('connect fail')
|
||||
# 检查响应状态码
|
||||
if response.status_code == 200:
|
||||
# 打印响应内容
|
||||
print('connect success')
|
||||
else:
|
||||
print('connect fail')
|
||||
|
||||
# import subprocess
|
||||
# jupyter = subprocess.Popen(
|
||||
|
@ -53,31 +53,42 @@ from pathlib import Path
|
|||
# stdout=subprocess.PIPE,
|
||||
# )
|
||||
|
||||
# 测试1
|
||||
import time, psutil
|
||||
from loguru import logger
|
||||
pycodebox = PyCodeBox(remote_url="http://localhost:5050",
|
||||
remote_ip="http://localhost",
|
||||
remote_port="5050",
|
||||
token="mytoken",
|
||||
do_code_exe=True,
|
||||
do_remote=False)
|
||||
# # 测试1
|
||||
# import time, psutil
|
||||
# from loguru import logger
|
||||
# import asyncio
|
||||
# pycodebox = PyCodeBox(remote_url="http://localhost:5050",
|
||||
# remote_ip="http://localhost",
|
||||
# remote_port="5050",
|
||||
# token="mytoken",
|
||||
# do_code_exe=True,
|
||||
# do_remote=False)
|
||||
|
||||
# pycodebox.list_files()
|
||||
# file = "./torch_test.py"
|
||||
# upload_file = st_load_file(file, filename="torch_test.py")
|
||||
|
||||
# file_content = upload_file.read() # 读取上传文件的内容
|
||||
# print(upload_file, file_content)
|
||||
# pycodebox.upload("torch_test.py", upload_file)
|
||||
|
||||
# asyncio.run(pycodebox.alist_files())
|
||||
|
||||
|
||||
reuslt = pycodebox.chat("```print('hello world!')```", do_code_exe=True)
|
||||
print(reuslt)
|
||||
# reuslt = pycodebox.chat("```print('hello world!')```", do_code_exe=True)
|
||||
# print(reuslt)
|
||||
|
||||
reuslt = pycodebox.chat("print('hello world!')", do_code_exe=False)
|
||||
print(reuslt)
|
||||
# reuslt = pycodebox.chat("print('hello world!')", do_code_exe=False)
|
||||
# print(reuslt)
|
||||
|
||||
for process in psutil.process_iter(["pid", "name", "cmdline"]):
|
||||
# 检查进程名是否包含"jupyter"
|
||||
if 'port=5050' in str(process.info["cmdline"]).lower() and \
|
||||
"jupyter" in process.info['name'].lower():
|
||||
# for process in psutil.process_iter(["pid", "name", "cmdline"]):
|
||||
# # 检查进程名是否包含"jupyter"
|
||||
# if 'port=5050' in str(process.info["cmdline"]).lower() and \
|
||||
# "jupyter" in process.info['name'].lower():
|
||||
|
||||
logger.warning(f'port=5050, {process.info}')
|
||||
# 关闭进程
|
||||
process.terminate()
|
||||
# logger.warning(f'port=5050, {process.info}')
|
||||
# # 关闭进程
|
||||
# process.terminate()
|
||||
|
||||
|
||||
# 测试2
|
||||
|
@ -103,61 +114,3 @@ for process in psutil.process_iter(["pid", "name", "cmdline"]):
|
|||
|
||||
# result = codebox.run("print('hello world!')")
|
||||
# print(result)
|
||||
|
||||
|
||||
|
||||
|
||||
# headers = {'Authorization': 'Token mytoken', 'token': 'mytoken'}
|
||||
|
||||
# kernel_url = "http://localhost:5050/api/kernels"
|
||||
|
||||
# response = requests.get(kernel_url, headers=headers)
|
||||
# if len(response.json())>0:
|
||||
# kernel_id = response.json()[0]["id"]
|
||||
# else:
|
||||
# response = requests.post(kernel_url, headers=headers)
|
||||
# kernel_id = response.json()["id"]
|
||||
|
||||
|
||||
# print(f"ws://localhost:5050/api/kernels/{kernel_id}/channels?token=mytoken")
|
||||
# ws = create_connection(f"ws://localhost:5050/api/kernels/{kernel_id}/channels?token=mytoken", headers=headers)
|
||||
|
||||
# code_text = "print('hello world!')"
|
||||
# # code_text = "import matplotlib.pyplot as plt\n\nplt.figure(figsize=(4,2))\nplt.plot([1,2,3,4,5])\nplt.show()"
|
||||
|
||||
# ws.send(
|
||||
# json.dumps(
|
||||
# {
|
||||
# "header": {
|
||||
# "msg_id": (msg_id := uuid4().hex),
|
||||
# "msg_type": "execute_request",
|
||||
# },
|
||||
# "parent_header": {},
|
||||
# "metadata": {},
|
||||
# "content": {
|
||||
# "code": code_text,
|
||||
# "silent": True,
|
||||
# "store_history": True,
|
||||
# "user_expressions": {},
|
||||
# "allow_stdin": False,
|
||||
# "stop_on_error": True,
|
||||
# },
|
||||
# "channel": "shell",
|
||||
# "buffers": [],
|
||||
# }
|
||||
# )
|
||||
# )
|
||||
|
||||
# while True:
|
||||
# received_msg = json.loads(ws.recv())
|
||||
# if received_msg["msg_type"] == "stream":
|
||||
# result_msg = received_msg # 找到结果消息
|
||||
# break
|
||||
# elif received_msg["header"]["msg_type"] == "execute_result":
|
||||
# result_msg = received_msg # 找到结果消息
|
||||
# break
|
||||
# elif received_msg["header"]["msg_type"] == "display_data":
|
||||
# result_msg = received_msg # 找到结果消息
|
||||
# break
|
||||
|
||||
# print(received_msg)
|
||||
|
|
|
@ -0,0 +1,14 @@
|
|||
import os, sys
|
||||
|
||||
src_dir = os.path.join(
|
||||
os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
)
|
||||
sys.path.append(src_dir)
|
||||
|
||||
from dev_opsgpt.text_splitter import LCTextSplitter
|
||||
|
||||
filepath = ""
|
||||
lc_textSplitter = LCTextSplitter(filepath)
|
||||
docs = lc_textSplitter.file2text()
|
||||
|
||||
print(docs[0])
|
|
@ -0,0 +1,114 @@
|
|||
|
||||
|
||||
from langchain.agents import initialize_agent, Tool
|
||||
from langchain.tools import format_tool_to_openai_function, MoveFileTool, StructuredTool
|
||||
from pydantic import BaseModel, Field, create_model
|
||||
from pydantic.schema import model_schema, get_flat_models_from_fields
|
||||
from typing import List, Set
|
||||
import jsonref
|
||||
import json
|
||||
|
||||
import os, sys, requests
|
||||
|
||||
src_dir = os.path.join(
|
||||
os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
)
|
||||
sys.path.append(src_dir)
|
||||
|
||||
from dev_opsgpt.tools import (
|
||||
WeatherInfo, WorldTimeGetTimezoneByArea, Multiplier, KSigmaDetector,
|
||||
toLangchainTools, get_tool_schema,
|
||||
TOOL_DICT, TOOL_SETS
|
||||
)
|
||||
from configs.model_config import (llm_model_dict, LLM_MODEL, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD)
|
||||
|
||||
from langchain.chat_models import ChatOpenAI
|
||||
from langchain.agents import AgentType, initialize_agent
|
||||
import langchain
|
||||
|
||||
# langchain.debug = True
|
||||
|
||||
tools = toLangchainTools([WeatherInfo, Multiplier, KSigmaDetector])
|
||||
|
||||
llm = ChatOpenAI(
|
||||
streaming=True,
|
||||
verbose=True,
|
||||
openai_api_key=llm_model_dict[LLM_MODEL]["api_key"],
|
||||
openai_api_base=llm_model_dict[LLM_MODEL]["api_base_url"],
|
||||
model_name=LLM_MODEL
|
||||
)
|
||||
|
||||
chat_prompt = '''if you can
|
||||
tools: {tools}
|
||||
query: {query}
|
||||
|
||||
if you choose llm-tool, you can direct
|
||||
'''
|
||||
# chain = LLMChain(prompt=chat_prompt, llm=llm)
|
||||
# content = chain({"tools": tools, "input": query})
|
||||
|
||||
# tool的检索
|
||||
|
||||
# tool参数的填充
|
||||
|
||||
# 函数执行
|
||||
|
||||
# from langchain.tools import StructuredTool
|
||||
|
||||
tools = [
|
||||
StructuredTool(
|
||||
name=Multiplier.name,
|
||||
func=Multiplier.run,
|
||||
description=Multiplier.description,
|
||||
args_schema=Multiplier.ToolInputArgs,
|
||||
),
|
||||
StructuredTool(
|
||||
name=WeatherInfo.name,
|
||||
func=WeatherInfo.run,
|
||||
description=WeatherInfo.description,
|
||||
args_schema=WeatherInfo.ToolInputArgs,
|
||||
)
|
||||
]
|
||||
|
||||
print(tools[0].func(1,2))
|
||||
|
||||
|
||||
tools = toLangchainTools([TOOL_DICT[i] for i in TOOL_SETS if i in TOOL_DICT])
|
||||
|
||||
agent = initialize_agent(
|
||||
tools,
|
||||
llm,
|
||||
agent=AgentType.STRUCTURED_CHAT_ZERO_SHOT_REACT_DESCRIPTION,
|
||||
verbose=True,
|
||||
return_intermediate_steps=True
|
||||
)
|
||||
|
||||
# agent.return_intermediate_steps = True
|
||||
# content = agent.run("查询北京的行政编码,同时返回北京的天气情况")
|
||||
# print(content)
|
||||
|
||||
# content = agent.run("判断这份数据是否存在异常,[0.857, 2.345, 1.234, 4.567, 3.456, 9.876, 5.678, 7.890, 6.789, 8.901, 10.987, 12.345, 11.234, 14.567, 13.456, 19.876, 15.678, 17.890, 16.789, 18.901, 20.987, 22.345, 21.234, 24.567, 23.456, 29.876, 25.678, 27.890, 26.789, 28.901, 30.987, 32.345, 31.234, 34.567, 33.456, 39.876, 35.678, 37.890, 36.789, 38.901, 40.987]")
|
||||
# content = agent("我有一份时序数据,[0.857, 2.345, 1.234, 4.567, 3.456, 9.876, 5.678, 7.890, 6.789, 8.901, 10.987, 12.345, 11.234, 14.567, 13.456, 19.876, 15.678, 17.890, 16.789, 18.901, 20.987, 22.345, 21.234, 24.567, 23.456, 29.876, 25.678, 27.890, 26.789, 28.901, 30.987, 32.345, 31.234, 34.567, 33.456, 39.876, 35.678, 37.890, 36.789, 38.901, 40.987],\我不知道这份数据是否存在问题,请帮我判断一下")
|
||||
# # print(content)
|
||||
# from langchain.schema import (
|
||||
# AgentAction
|
||||
# )
|
||||
|
||||
# s = ""
|
||||
# if isinstance(content, str):
|
||||
# s = content
|
||||
# else:
|
||||
# for i in content["intermediate_steps"]:
|
||||
# for j in i:
|
||||
# if isinstance(j, AgentAction):
|
||||
# s += j.log + "\n"
|
||||
# else:
|
||||
# s += "Observation: " + str(j) + "\n"
|
||||
|
||||
# s += "final answer:" + content["output"]
|
||||
# print(s)
|
||||
|
||||
# print(content["intermediate_steps"][0][0].log)
|
||||
# print( content["intermediate_steps"][0][0].log, content[""] + "\n" + content["i"] + "\n" + )
|
||||
# content = agent.run("i want to know the timezone of asia/shanghai, list all timezones available for that area.")
|
||||
# print(content)
|
Loading…
Reference in New Issue