Merge pull request #16 from codefuse-ai/coagent_branch
rename dev_opsgpt to coagent, and add memory&prompt manager
This commit is contained in:
commit
b7fdf50da7
|
@ -10,4 +10,8 @@ code_base
|
|||
.DS_Store
|
||||
.idea
|
||||
data
|
||||
.pyc
|
||||
tests
|
||||
*egg-info
|
||||
build
|
||||
dist
|
||||
|
|
|
@ -3,7 +3,6 @@ From python:3.9.18-bookworm
|
|||
WORKDIR /home/user
|
||||
|
||||
COPY ./requirements.txt /home/user/docker_requirements.txt
|
||||
COPY ./jupyter_start.sh /home/user/jupyter_start.sh
|
||||
|
||||
|
||||
RUN apt-get update
|
||||
|
|
201
LICENSE
201
LICENSE
|
@ -1,201 +0,0 @@
|
|||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
|
||||
1. Definitions.
|
||||
|
||||
"License" shall mean the terms and conditions for use, reproduction,
|
||||
and distribution as defined by Sections 1 through 9 of this document.
|
||||
|
||||
"Licensor" shall mean the copyright owner or entity authorized by
|
||||
the copyright owner that is granting the License.
|
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all
|
||||
other entities that control, are controlled by, or are under common
|
||||
control with that entity. For the purposes of this definition,
|
||||
"control" means (i) the power, direct or indirect, to cause the
|
||||
direction or management of such entity, whether by contract or
|
||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
|
||||
"You" (or "Your") shall mean an individual or Legal Entity
|
||||
exercising permissions granted by this License.
|
||||
|
||||
"Source" form shall mean the preferred form for making modifications,
|
||||
including but not limited to software source code, documentation
|
||||
source, and configuration files.
|
||||
|
||||
"Object" form shall mean any form resulting from mechanical
|
||||
transformation or translation of a Source form, including but
|
||||
not limited to compiled object code, generated documentation,
|
||||
and conversions to other media types.
|
||||
|
||||
"Work" shall mean the work of authorship, whether in Source or
|
||||
Object form, made available under the License, as indicated by a
|
||||
copyright notice that is included in or attached to the work
|
||||
(an example is provided in the Appendix below).
|
||||
|
||||
"Derivative Works" shall mean any work, whether in Source or Object
|
||||
form, that is based on (or derived from) the Work and for which the
|
||||
editorial revisions, annotations, elaborations, or other modifications
|
||||
represent, as a whole, an original work of authorship. For the purposes
|
||||
of this License, Derivative Works shall not include works that remain
|
||||
separable from, or merely link (or bind by name) to the interfaces of,
|
||||
the Work and Derivative Works thereof.
|
||||
|
||||
"Contribution" shall mean any work of authorship, including
|
||||
the original version of the Work and any modifications or additions
|
||||
to that Work or Derivative Works thereof, that is intentionally
|
||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||
or by an individual or Legal Entity authorized to submit on behalf of
|
||||
the copyright owner. For the purposes of this definition, "submitted"
|
||||
means any form of electronic, verbal, or written communication sent
|
||||
to the Licensor or its representatives, including but not limited to
|
||||
communication on electronic mailing lists, source code control systems,
|
||||
and issue tracking systems that are managed by, or on behalf of, the
|
||||
Licensor for the purpose of discussing and improving the Work, but
|
||||
excluding communication that is conspicuously marked or otherwise
|
||||
designated in writing by the copyright owner as "Not a Contribution."
|
||||
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||
on behalf of whom a Contribution has been received by Licensor and
|
||||
subsequently incorporated within the Work.
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
copyright license to reproduce, prepare Derivative Works of,
|
||||
publicly display, publicly perform, sublicense, and distribute the
|
||||
Work and such Derivative Works in Source or Object form.
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
(except as stated in this section) patent license to make, have made,
|
||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||
where such license applies only to those patent claims licensable
|
||||
by such Contributor that are necessarily infringed by their
|
||||
Contribution(s) alone or by combination of their Contribution(s)
|
||||
with the Work to which such Contribution(s) was submitted. If You
|
||||
institute patent litigation against any entity (including a
|
||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||
or a Contribution incorporated within the Work constitutes direct
|
||||
or contributory patent infringement, then any patent licenses
|
||||
granted to You under this License for that Work shall terminate
|
||||
as of the date such litigation is filed.
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the
|
||||
Work or Derivative Works thereof in any medium, with or without
|
||||
modifications, and in Source or Object form, provided that You
|
||||
meet the following conditions:
|
||||
|
||||
(a) You must give any other recipients of the Work or
|
||||
Derivative Works a copy of this License; and
|
||||
|
||||
(b) You must cause any modified files to carry prominent notices
|
||||
stating that You changed the files; and
|
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works
|
||||
that You distribute, all copyright, patent, trademark, and
|
||||
attribution notices from the Source form of the Work,
|
||||
excluding those notices that do not pertain to any part of
|
||||
the Derivative Works; and
|
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its
|
||||
distribution, then any Derivative Works that You distribute must
|
||||
include a readable copy of the attribution notices contained
|
||||
within such NOTICE file, excluding those notices that do not
|
||||
pertain to any part of the Derivative Works, in at least one
|
||||
of the following places: within a NOTICE text file distributed
|
||||
as part of the Derivative Works; within the Source form or
|
||||
documentation, if provided along with the Derivative Works; or,
|
||||
within a display generated by the Derivative Works, if and
|
||||
wherever such third-party notices normally appear. The contents
|
||||
of the NOTICE file are for informational purposes only and
|
||||
do not modify the License. You may add Your own attribution
|
||||
notices within Derivative Works that You distribute, alongside
|
||||
or as an addendum to the NOTICE text from the Work, provided
|
||||
that such additional attribution notices cannot be construed
|
||||
as modifying the License.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and
|
||||
may provide additional or different license terms and conditions
|
||||
for use, reproduction, or distribution of Your modifications, or
|
||||
for any such Derivative Works as a whole, provided Your use,
|
||||
reproduction, and distribution of the Work otherwise complies with
|
||||
the conditions stated in this License.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||
any Contribution intentionally submitted for inclusion in the Work
|
||||
by You to the Licensor shall be under the terms and conditions of
|
||||
this License, without any additional terms or conditions.
|
||||
Notwithstanding the above, nothing herein shall supersede or modify
|
||||
the terms of any separate license agreement you may have executed
|
||||
with Licensor regarding such Contributions.
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade
|
||||
names, trademarks, service marks, or product names of the Licensor,
|
||||
except as required for reasonable and customary use in describing the
|
||||
origin of the Work and reproducing the content of the NOTICE file.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||
agreed to in writing, Licensor provides the Work (and each
|
||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
implied, including, without limitation, any warranties or conditions
|
||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any
|
||||
risks associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory,
|
||||
whether in tort (including negligence), contract, or otherwise,
|
||||
unless required by applicable law (such as deliberate and grossly
|
||||
negligent acts) or agreed to in writing, shall any Contributor be
|
||||
liable to You for damages, including any direct, indirect, special,
|
||||
incidental, or consequential damages of any character arising as a
|
||||
result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill,
|
||||
work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses), even if such Contributor
|
||||
has been advised of the possibility of such damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing
|
||||
the Work or Derivative Works thereof, You may choose to offer,
|
||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||
or other liability obligations and/or rights consistent with this
|
||||
License. However, in accepting such obligations, You may act only
|
||||
on Your own behalf and on Your sole responsibility, not on behalf
|
||||
of any other Contributor, and only if You agree to indemnify,
|
||||
defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason
|
||||
of your accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
APPENDIX: How to apply the Apache License to your work.
|
||||
|
||||
To apply the Apache License to your work, attach the following
|
||||
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||
replaced with your own identifying information. (Don't include
|
||||
the brackets!) The text should be enclosed in the appropriate
|
||||
comment syntax for the file format. We also recommend that a
|
||||
file or class name and description of purpose be included on the
|
||||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright [yyyy] [name of copyright owner]
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
26
README.md
26
README.md
|
@ -1,10 +1,8 @@
|
|||
<p align="left">
|
||||
<a>中文</a>  |  <a href="README_en.md">English  </a>
|
||||
</p>
|
||||
|
||||
# <p align="center">CodeFuse-ChatBot: Development by Private Knowledge Augmentation</p>
|
||||
|
||||
<p align="center">
|
||||
<a href="README.md"><img src="https://img.shields.io/badge/文档-中文版-yellow.svg" alt="ZH doc"></a>
|
||||
<a href="README_en.md"><img src="https://img.shields.io/badge/document-English-yellow.svg" alt="EN doc"></a>
|
||||
<img src="https://img.shields.io/github/license/codefuse-ai/codefuse-chatbot" alt="License">
|
||||
<a href="https://github.com/codefuse-ai/codefuse-chatbot/issues">
|
||||
<img alt="Open Issues" src="https://img.shields.io/github/issues-raw/codefuse-ai/codefuse-chatbot" />
|
||||
|
@ -38,7 +36,7 @@ DevOps-ChatBot是由蚂蚁CodeFuse团队开发的开源AI智能助手,致力
|
|||
💡 本项目旨在通过检索增强生成(Retrieval Augmented Generation,RAG)、工具学习(Tool Learning)和沙盒环境来构建软件开发全生命周期的AI智能助手,涵盖设计、编码、测试、部署和运维等阶段。 逐渐从各处资料查询、独立分散平台操作的传统开发运维模式转变到大模型问答的智能化开发运维模式,改变人们的开发运维习惯。
|
||||
|
||||
本项目核心差异技术、功能点:
|
||||
- **🧠 智能调度核心:** 构建了体系链路完善的调度核心,支持多模式一键配置,简化操作流程。 [使用说明](sources/readme_docs/multi-agent.md)
|
||||
- **🧠 智能调度核心:** 构建了体系链路完善的调度核心,支持多模式一键配置,简化操作流程。 [使用说明](sources/readme_docs/coagent/coagent.md)
|
||||
- **💻 代码整库分析:** 实现了仓库级的代码深入理解,以及项目文件级的代码编写与生成,提升了开发效率。
|
||||
- **📄 文档分析增强:** 融合了文档知识库与知识图谱,通过检索和推理增强,为文档分析提供了更深层次的支持。
|
||||
- **🔧 垂类专属知识:** 为DevOps领域定制的专属知识库,支持垂类知识库的自助一键构建,便捷实用。
|
||||
|
@ -93,7 +91,13 @@ DevOps-ChatBot是由蚂蚁CodeFuse团队开发的开源AI智能助手,致力
|
|||
|
||||
|
||||
## 🚀 快速使用
|
||||
### coagent-py
|
||||
完整文档见:[coagent](sources/readme_docs/coagent/coagent.md)
|
||||
```
|
||||
pip install coagent
|
||||
```
|
||||
|
||||
### 使用ChatBot
|
||||
请自行安装 nvidia 驱动程序,本项目已在 Python 3.9.18,CUDA 11.7 环境下,Windows、X86 架构的 macOS 系统中完成测试。
|
||||
|
||||
Docker安装、私有化LLM接入及相关启动问题见:[快速使用明细](sources/readme_docs/start.md)
|
||||
|
@ -155,12 +159,12 @@ NO_REMOTE_API = True
|
|||
```bash
|
||||
# 若需要支撑codellama-34b-int4模型,需要给fastchat打一个补丁
|
||||
# cp examples/gptq.py ~/site-packages/fastchat/modules/gptq.py
|
||||
# dev_opsgpt/service/llm_api.py#258 修改为 kwargs={"gptq_wbits": 4},
|
||||
# examples/llm_api.py#258 修改为 kwargs={"gptq_wbits": 4},
|
||||
|
||||
# start llm-service(可选)
|
||||
python dev_opsgpt/service/llm_api.py
|
||||
python examples/llm_api.py
|
||||
```
|
||||
更多LLM接入方法见[详情...](sources/readme_docs/fastchat.md)
|
||||
更多LLM接入方法见[更多细节...](sources/readme_docs/fastchat.md)
|
||||
<br>
|
||||
|
||||
```bash
|
||||
|
@ -168,6 +172,12 @@ python dev_opsgpt/service/llm_api.py
|
|||
cd examples
|
||||
python start.py
|
||||
```
|
||||
## 贡献指南
|
||||
非常感谢您对 Codefuse 项目感兴趣,我们非常欢迎您对 Codefuse 项目的各种建议、意见(包括批评)、评论和贡献。
|
||||
|
||||
您对 Codefuse 的各种建议、意见、评论可以直接通过 GitHub 的 Issues 提出。
|
||||
|
||||
参与 Codefuse 项目并为其作出贡献的方法有很多:代码实现、测试编写、流程工具改进、文档完善等等。任何贡献我们都会非常欢迎,并将您加入贡献者列表。详见[Contribution Guide...](sources/readme_docs/contribution/contribute_guide.md)
|
||||
|
||||
## 🤗 致谢
|
||||
|
||||
|
|
34
README_en.md
34
README_en.md
|
@ -1,10 +1,8 @@
|
|||
<p align="left">
|
||||
<a href="README.md">中文</a>  |  <a>English  </a>
|
||||
</p>
|
||||
|
||||
# <p align="center">Codefuse-ChatBot: Development by Private Knowledge Augmentation</p>
|
||||
|
||||
<p align="center">
|
||||
<a href="README.md"><img src="https://img.shields.io/badge/文档-中文版-yellow.svg" alt="ZH doc"></a>
|
||||
<a href="README_EN.md"><img src="https://img.shields.io/badge/document-英文版-yellow.svg" alt="EN doc"></a>
|
||||
<img src="https://img.shields.io/github/license/codefuse-ai/codefuse-chatbot" alt="License">
|
||||
<a href="https://github.com/codefuse-ai/codefuse-chatbot/issues">
|
||||
<img alt="Open Issues" src="https://img.shields.io/github/issues-raw/codefuse-ai/codefuse-chatbot" />
|
||||
|
@ -15,6 +13,7 @@ This project is an open-source AI intelligent assistant, specifically designed f
|
|||
|
||||
|
||||
## 🔔 Updates
|
||||
- [2023.12.26] Opening the capability to integrate with open-source private large models and large model interfaces based on FastChat
|
||||
- [2023.12.01] Release of Multi-Agent and codebase retrieval functionalities.
|
||||
- [2023.11.15] Addition of Q&A enhancement mode based on the local codebase.
|
||||
- [2023.09.15] Launch of sandbox functionality for local/isolated environments, enabling knowledge retrieval from specified URLs using web crawlers.
|
||||
|
@ -30,13 +29,13 @@ This project is an open-source AI intelligent assistant, specifically designed f
|
|||
|
||||
💡 The aim of this project is to construct an AI intelligent assistant for the entire lifecycle of software development, covering design, coding, testing, deployment, and operations, through Retrieval Augmented Generation (RAG), Tool Learning, and sandbox environments. It transitions gradually from the traditional development and operations mode of querying information from various sources and operating on standalone, disparate platforms to an intelligent development and operations mode based on large-model Q&A, changing people's development and operations habits.
|
||||
|
||||
- **🧠 Intelligent Scheduling Core:** Constructed a well-integrated scheduling core system that supports multi-mode one-click configuration, simplifying the operational process.
|
||||
- **🧠 Intelligent Scheduling Core:** Constructed a well-integrated scheduling core system that supports multi-mode one-click configuration, simplifying the operational process. [coagent](sources/readme_docs/coagent/coagent-en.md)
|
||||
- **💻 Comprehensive Code Repository Analysis:** Achieved in-depth understanding at the repository level and coding and generation at the project file level, enhancing development efficiency.
|
||||
- **📄 Enhanced Document Analysis:** Integrated document knowledge bases with knowledge graphs, providing deeper support for document analysis through enhanced retrieval and reasoning.
|
||||
- **🔧 Industry-Specific Knowledge:** Tailored a specialized knowledge base for the DevOps domain, supporting the self-service one-click construction of industry-specific knowledge bases for convenience and practicality.
|
||||
- **🤖 Compatible Models for Specific Verticals:** Designed small models specifically for the DevOps field, ensuring compatibility with related DevOps platforms and promoting the integration of the technological ecosystem.
|
||||
|
||||
🌍 Relying on open-source LLM and Embedding models, this project can achieve offline private deployments based on open-source models. Additionally, this project also supports the use of the OpenAI API.
|
||||
🌍 Relying on open-source LLM and Embedding models, this project can achieve offline private deployments based on open-source models. Additionally, this project also supports the use of the OpenAI API.[Access Demo](sources/readme_docs/fastchat-en.md)
|
||||
|
||||
👥 The core development team has been long-term focused on research in the AIOps + NLP domain. We initiated the CodefuseGPT project, hoping that everyone could contribute high-quality development and operations documents widely, jointly perfecting this solution to achieve the goal of "Making Development Seamless for Everyone."
|
||||
|
||||
|
@ -64,7 +63,7 @@ This project is an open-source AI intelligent assistant, specifically designed f
|
|||
- 💬 **LLM:**:Supports various open-source models and LLM interfaces.
|
||||
- 🛠️ **API Management::** Enables rapid integration of open-source components and operational platforms.
|
||||
|
||||
For implementation details, see: [Technical Route Details](sources/readme_docs/roadmap.md)
|
||||
For implementation details, see: [Technical Route Details](sources/readme_docs/roadmap-en.md)
|
||||
|
||||
|
||||
## 🌐 Model Integration
|
||||
|
@ -79,7 +78,13 @@ If you need to integrate a specific model, please inform us of your requirements
|
|||
|
||||
|
||||
## 🚀 Quick Start
|
||||
### coagent-py
|
||||
More Detail see:[coagent](sources/readme_docs/coagent/coagent-en.md)
|
||||
```
|
||||
pip install coagent
|
||||
```
|
||||
|
||||
### ChatBot-UI
|
||||
Please install the Nvidia driver yourself; this project has been tested on Python 3.9.18, CUDA 11.7, Windows, and X86 architecture macOS systems.
|
||||
|
||||
1. Preparation of Python environment
|
||||
|
@ -172,11 +177,13 @@ By default, only webui related services are started, and fastchat is not started
|
|||
```bash
|
||||
# if use codellama-34b-int4, you should replace fastchat's gptq.py
|
||||
# cp examples/gptq.py ~/site-packages/fastchat/modules/gptq.py
|
||||
# dev_opsgpt/service/llm_api.py#258 => kwargs={"gptq_wbits": 4},
|
||||
# examples/llm_api.py#258 => kwargs={"gptq_wbits": 4},
|
||||
|
||||
# start llm-service(可选)
|
||||
python dev_opsgpt/service/llm_api.py
|
||||
python examples/llm_api.py
|
||||
```
|
||||
More details about accessing LLM Moldes[More Details...](sources/readme_docs/fastchat.md)
|
||||
<br>
|
||||
|
||||
```bash
|
||||
# After configuring server_config.py, you can start with just one click.
|
||||
|
@ -184,6 +191,13 @@ cd examples
|
|||
bash start_webui.sh
|
||||
```
|
||||
|
||||
## 贡献指南
|
||||
Thank you for your interest in the Codefuse project. We warmly welcome any suggestions, opinions (including criticisms), comments, and contributions to the Codefuse project.
|
||||
|
||||
Your suggestions, opinions, and comments on Codefuse can be directly submitted through GitHub Issues.
|
||||
|
||||
There are many ways to participate in the Codefuse project and contribute to it: code implementation, test writing, process tool improvement, documentation enhancement, and more. We welcome any contributions and will add you to our list of contributors. See [contribution guide](sources/readme_docs/contribution/contribute_guide_en.md)
|
||||
|
||||
## 🤗 Acknowledgements
|
||||
|
||||
This project is based on [langchain-chatchat](https://github.com/chatchat-space/Langchain-Chatchat) and [codebox-api](https://github.com/shroominic/codebox-api). We deeply appreciate their contributions to open source!
|
||||
This project is based on [langchain-chatchat](https://github.com/chatchat-space/Langchain-Chatchat) and [codebox-api](https://github.com/shroominic/codebox-api). We deeply appreciate their contributions to open source!
|
|
@ -0,0 +1,88 @@
|
|||
import os
|
||||
import platform
|
||||
|
||||
system_name = platform.system()
|
||||
executable_path = os.getcwd()
|
||||
|
||||
# 日志存储路径
|
||||
LOG_PATH = os.environ.get("LOG_PATH", None) or os.path.join(executable_path, "logs")
|
||||
|
||||
# 知识库默认存储路径
|
||||
SOURCE_PATH = os.environ.get("SOURCE_PATH", None) or os.path.join(executable_path, "sources")
|
||||
|
||||
# 知识库默认存储路径
|
||||
KB_ROOT_PATH = os.environ.get("KB_ROOT_PATH", None) or os.path.join(executable_path, "knowledge_base")
|
||||
|
||||
# 代码库默认存储路径
|
||||
CB_ROOT_PATH = os.environ.get("CB_ROOT_PATH", None) or os.path.join(executable_path, "code_base")
|
||||
|
||||
# nltk 模型存储路径
|
||||
NLTK_DATA_PATH = os.environ.get("NLTK_DATA_PATH", None) or os.path.join(executable_path, "nltk_data")
|
||||
|
||||
# 代码存储路径
|
||||
JUPYTER_WORK_PATH = os.environ.get("JUPYTER_WORK_PATH", None) or os.path.join(executable_path, "jupyter_work")
|
||||
|
||||
# WEB_CRAWL存储路径
|
||||
WEB_CRAWL_PATH = os.environ.get("WEB_CRAWL_PATH", None) or os.path.join(executable_path, "knowledge_base")
|
||||
|
||||
# NEBULA_DATA存储路径
|
||||
NELUBA_PATH = os.environ.get("NELUBA_PATH", None) or os.path.join(executable_path, "data/neluba_data")
|
||||
|
||||
for _path in [LOG_PATH, SOURCE_PATH, KB_ROOT_PATH, NLTK_DATA_PATH, JUPYTER_WORK_PATH, WEB_CRAWL_PATH, NELUBA_PATH]:
|
||||
if not os.path.exists(_path):
|
||||
os.makedirs(_path, exist_ok=True)
|
||||
|
||||
# 数据库默认存储路径。
|
||||
# 如果使用sqlite,可以直接修改DB_ROOT_PATH;如果使用其它数据库,请直接修改SQLALCHEMY_DATABASE_URI。
|
||||
DB_ROOT_PATH = os.path.join(KB_ROOT_PATH, "info.db")
|
||||
SQLALCHEMY_DATABASE_URI = f"sqlite:///{DB_ROOT_PATH}"
|
||||
|
||||
kbs_config = {
|
||||
"faiss": {
|
||||
},}
|
||||
|
||||
|
||||
# GENERAL SERVER CONFIG
|
||||
DEFAULT_BIND_HOST = os.environ.get("DEFAULT_BIND_HOST", None) or "127.0.0.1"
|
||||
|
||||
# NEBULA SERVER CONFIG
|
||||
NEBULA_HOST = DEFAULT_BIND_HOST
|
||||
NEBULA_PORT = 9669
|
||||
NEBULA_STORAGED_PORT = 9779
|
||||
NEBULA_USER = 'root'
|
||||
NEBULA_PASSWORD = ''
|
||||
NEBULA_GRAPH_SERVER = {
|
||||
"host": DEFAULT_BIND_HOST,
|
||||
"port": NEBULA_PORT,
|
||||
"docker_port": NEBULA_PORT
|
||||
}
|
||||
|
||||
# CHROMA CONFIG
|
||||
CHROMA_PERSISTENT_PATH = '/home/user/chatbot/data/chroma_data'
|
||||
|
||||
|
||||
# 默认向量库类型。可选:faiss, milvus, pg.
|
||||
DEFAULT_VS_TYPE = os.environ.get("DEFAULT_VS_TYPE") or "faiss"
|
||||
|
||||
# 缓存向量库数量
|
||||
CACHED_VS_NUM = os.environ.get("CACHED_VS_NUM") or 1
|
||||
|
||||
# 知识库中单段文本长度
|
||||
CHUNK_SIZE = os.environ.get("CHUNK_SIZE") or 500
|
||||
|
||||
# 知识库中相邻文本重合长度
|
||||
OVERLAP_SIZE = os.environ.get("OVERLAP_SIZE") or 50
|
||||
|
||||
# 知识库匹配向量数量
|
||||
VECTOR_SEARCH_TOP_K = os.environ.get("VECTOR_SEARCH_TOP_K") or 5
|
||||
|
||||
# 知识库匹配相关度阈值,取值范围在0-1之间,SCORE越小,相关度越高,取到1相当于不筛选,建议设置在0.5左右
|
||||
# Mac 可能存在无法使用normalized_L2的问题,因此调整SCORE_THRESHOLD至 0~1100
|
||||
FAISS_NORMALIZE_L2 = True if system_name in ["Linux", "Windows"] else False
|
||||
SCORE_THRESHOLD = 1 if system_name in ["Linux", "Windows"] else 1100
|
||||
|
||||
# 搜索引擎匹配结题数量
|
||||
SEARCH_ENGINE_TOP_K = os.environ.get("SEARCH_ENGINE_TOP_K") or 5
|
||||
|
||||
# 代码引擎匹配结题数量
|
||||
CODE_SEARCH_TOP_K = os.environ.get("CODE_SEARCH_TOP_K") or 1
|
|
@ -5,30 +5,26 @@ from loguru import logger
|
|||
import importlib
|
||||
import copy
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
from configs.model_config import (
|
||||
llm_model_dict, LLM_MODEL, PROMPT_TEMPLATE,
|
||||
VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD)
|
||||
# from configs.model_config import (
|
||||
# llm_model_dict, LLM_MODEL, PROMPT_TEMPLATE,
|
||||
# VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD)
|
||||
|
||||
from dev_opsgpt.tools import (
|
||||
from coagent.tools import (
|
||||
toLangchainTools,
|
||||
TOOL_DICT, TOOL_SETS
|
||||
)
|
||||
|
||||
from dev_opsgpt.connector.phase import BasePhase
|
||||
from dev_opsgpt.connector.agents import BaseAgent, ReactAgent
|
||||
from dev_opsgpt.connector.chains import BaseChain
|
||||
from dev_opsgpt.connector.schema import (
|
||||
Message,
|
||||
load_phase_configs, load_chain_configs, load_role_configs
|
||||
)
|
||||
from dev_opsgpt.connector.schema import Memory
|
||||
from dev_opsgpt.utils.common_utils import file_normalize
|
||||
from dev_opsgpt.chat.utils import History, wrap_done
|
||||
from dev_opsgpt.connector.configs import PHASE_CONFIGS, AGETN_CONFIGS, CHAIN_CONFIGS
|
||||
from coagent.connector.phase import BasePhase
|
||||
from coagent.connector.schema import Message
|
||||
from coagent.connector.schema import Memory
|
||||
from coagent.chat.utils import History, wrap_done
|
||||
from coagent.llm_models.llm_config import LLMConfig, EmbedConfig
|
||||
from coagent.connector.configs import PHASE_CONFIGS, AGETN_CONFIGS, CHAIN_CONFIGS
|
||||
|
||||
PHASE_MODULE = importlib.import_module("dev_opsgpt.connector.phase")
|
||||
PHASE_MODULE = importlib.import_module("coagent.connector.phase")
|
||||
|
||||
|
||||
|
||||
|
@ -56,8 +52,8 @@ class AgentChat:
|
|||
doc_engine_name: str = Body(..., description="知识库名称", examples=["samples"]),
|
||||
search_engine_name: str = Body(..., description="搜索引擎名称", examples=["duckduckgo"]),
|
||||
code_engine_name: str = Body(..., description="代码引擎名称", examples=["samples"]),
|
||||
top_k: int = Body(VECTOR_SEARCH_TOP_K, description="匹配向量数"),
|
||||
score_threshold: float = Body(SCORE_THRESHOLD, description="知识库匹配相关度阈值,取值范围在0-1之间,SCORE越小,相关度越高,取到1相当于不筛选,建议设置在0.5左右", ge=0, le=1),
|
||||
top_k: int = Body(5, description="匹配向量数"),
|
||||
score_threshold: float = Body(1, description="知识库匹配相关度阈值,取值范围在0-1之间,SCORE越小,相关度越高,取到1相当于不筛选,建议设置在0.5左右", ge=0, le=1),
|
||||
stream: bool = Body(False, description="流式输出"),
|
||||
local_doc_url: bool = Body(False, description="知识文件返回本地路径(true)或URL(false)"),
|
||||
choose_tools: List[str] = Body([], description="选择tool的集合"),
|
||||
|
@ -71,12 +67,27 @@ class AgentChat:
|
|||
history_node_list: List = Body([], description="代码历史相关节点"),
|
||||
isDetailed: bool = Body(False, description="是否输出完整的agent相关内容"),
|
||||
upload_file: Union[str, Path, bytes] = "",
|
||||
kb_root_path: str = Body("", description="知识库存储路径"),
|
||||
jupyter_work_path: str = Body("", description="sandbox执行环境"),
|
||||
sandbox_server: str = Body({}, description="代码历史相关节点"),
|
||||
api_key: str = Body(os.environ.get("OPENAI_API_KEY"), description=""),
|
||||
api_base_url: str = Body(os.environ.get("API_BASE_URL"),),
|
||||
embed_model: str = Body("", description="向量模型"),
|
||||
embed_model_path: str = Body("", description="向量模型路径"),
|
||||
model_device: str = Body("", description="模型加载设备"),
|
||||
embed_engine: str = Body("", description="向量模型类型"),
|
||||
model_name: str = Body("", description="llm模型名称"),
|
||||
temperature: float = Body(0.2, description=""),
|
||||
**kargs
|
||||
) -> Message:
|
||||
|
||||
# update configs
|
||||
phase_configs, chain_configs, agent_configs = self.update_configs(
|
||||
custom_phase_configs, custom_chain_configs, custom_role_configs)
|
||||
params = locals()
|
||||
params.pop("self")
|
||||
embed_config: EmbedConfig = EmbedConfig(**params)
|
||||
llm_config: LLMConfig = LLMConfig(**params)
|
||||
|
||||
logger.info('phase_configs={}'.format(phase_configs))
|
||||
logger.info('chain_configs={}'.format(chain_configs))
|
||||
|
@ -86,7 +97,6 @@ class AgentChat:
|
|||
|
||||
# choose tools
|
||||
tools = toLangchainTools([TOOL_DICT[i] for i in choose_tools if i in TOOL_DICT])
|
||||
logger.debug(f"upload_file: {upload_file}")
|
||||
|
||||
if upload_file:
|
||||
upload_file_name = upload_file if upload_file and isinstance(upload_file, str) else upload_file.name
|
||||
|
@ -97,8 +107,8 @@ class AgentChat:
|
|||
|
||||
input_message = Message(
|
||||
role_content=query,
|
||||
role_type="human",
|
||||
role_name="user",
|
||||
role_type="user",
|
||||
role_name="human",
|
||||
input_query=query,
|
||||
origin_query=query,
|
||||
phase_name=phase_name,
|
||||
|
@ -120,30 +130,25 @@ class AgentChat:
|
|||
])
|
||||
# start to execute
|
||||
phase_class = getattr(PHASE_MODULE, phase_configs[input_message.phase_name]["phase_type"])
|
||||
# TODO 需要把相关信息补充上去
|
||||
phase = phase_class(input_message.phase_name,
|
||||
task = input_message.task,
|
||||
phase_config = phase_configs,
|
||||
chain_config = chain_configs,
|
||||
role_config = agent_configs,
|
||||
do_summary=phase_configs[input_message.phase_name]["do_summary"],
|
||||
do_code_retrieval=input_message.do_code_retrieval,
|
||||
do_doc_retrieval=input_message.do_doc_retrieval,
|
||||
do_search=input_message.do_search,
|
||||
base_phase_config = phase_configs,
|
||||
base_chain_config = chain_configs,
|
||||
base_role_config = agent_configs,
|
||||
phase_config = None,
|
||||
kb_root_path = kb_root_path,
|
||||
jupyter_work_path = jupyter_work_path,
|
||||
sandbox_server = sandbox_server,
|
||||
embed_config = embed_config,
|
||||
llm_config = llm_config,
|
||||
)
|
||||
output_message, local_memory = phase.step(input_message, history)
|
||||
# logger.debug(f"local_memory: {local_memory.to_str_messages(content_key='step_content')}")
|
||||
|
||||
# return {
|
||||
# "answer": output_message.role_content,
|
||||
# "db_docs": output_message.db_docs,
|
||||
# "search_docs": output_message.search_docs,
|
||||
# "code_docs": output_message.code_docs,
|
||||
# "figures": output_message.figures
|
||||
# }
|
||||
|
||||
def chat_iterator(message: Message, local_memory: Memory, isDetailed=False):
|
||||
step_content = local_memory.to_str_messages(content_key='step_content', filter_roles=["user"])
|
||||
final_content = message.role_content
|
||||
logger.debug(f"{step_content}")
|
||||
result = {
|
||||
"answer": "",
|
||||
"db_docs": [str(doc) for doc in message.db_docs],
|
||||
|
@ -190,8 +195,8 @@ class AgentChat:
|
|||
search_engine_name: str = Body(..., description="搜索引擎名称", examples=["duckduckgo"]),
|
||||
code_engine_name: str = Body(..., description="代码引擎名称", examples=["samples"]),
|
||||
cb_search_type: str = Body(..., description="代码查询模式", examples=["tag"]),
|
||||
top_k: int = Body(VECTOR_SEARCH_TOP_K, description="匹配向量数"),
|
||||
score_threshold: float = Body(SCORE_THRESHOLD, description="知识库匹配相关度阈值,取值范围在0-1之间,SCORE越小,相关度越高,取到1相当于不筛选,建议设置在0.5左右", ge=0, le=1),
|
||||
top_k: int = Body(5, description="匹配向量数"),
|
||||
score_threshold: float = Body(1, description="知识库匹配相关度阈值,取值范围在0-1之间,SCORE越小,相关度越高,取到1相当于不筛选,建议设置在0.5左右", ge=0, le=1),
|
||||
stream: bool = Body(False, description="流式输出"),
|
||||
local_doc_url: bool = Body(False, description="知识文件返回本地路径(true)或URL(false)"),
|
||||
choose_tools: List[str] = Body([], description="选择tool的集合"),
|
||||
|
@ -205,15 +210,32 @@ class AgentChat:
|
|||
history_node_list: List = Body([], description="代码历史相关节点"),
|
||||
isDetailed: bool = Body(False, description="是否输出完整的agent相关内容"),
|
||||
upload_file: Union[str, Path, bytes] = "",
|
||||
kb_root_path: str = Body("", description="知识库存储路径"),
|
||||
jupyter_work_path: str = Body("", description="sandbox执行环境"),
|
||||
sandbox_server: str = Body({}, description="代码历史相关节点"),
|
||||
api_key: str = Body(os.environ["OPENAI_API_KEY"], description=""),
|
||||
api_base_url: str = Body(os.environ.get("API_BASE_URL"),),
|
||||
embed_model: str = Body("", description="向量模型"),
|
||||
embed_model_path: str = Body("", description="向量模型路径"),
|
||||
model_device: str = Body("", description="模型加载设备"),
|
||||
embed_engine: str = Body("", description="向量模型类型"),
|
||||
model_name: str = Body("", description="llm模型名称"),
|
||||
temperature: float = Body(0.2, description=""),
|
||||
**kargs
|
||||
) -> Message:
|
||||
|
||||
# update configs
|
||||
phase_configs, chain_configs, agent_configs = self.update_configs(
|
||||
custom_phase_configs, custom_chain_configs, custom_role_configs)
|
||||
|
||||
#
|
||||
params = locals()
|
||||
params.pop("self")
|
||||
embed_config: EmbedConfig = EmbedConfig(**params)
|
||||
llm_config: LLMConfig = LLMConfig(**params)
|
||||
|
||||
# choose tools
|
||||
tools = toLangchainTools([TOOL_DICT[i] for i in choose_tools if i in TOOL_DICT])
|
||||
logger.debug(f"upload_file: {upload_file}")
|
||||
|
||||
if upload_file:
|
||||
upload_file_name = upload_file if upload_file and isinstance(upload_file, str) else upload_file.name
|
||||
|
@ -224,8 +246,8 @@ class AgentChat:
|
|||
|
||||
input_message = Message(
|
||||
role_content=query,
|
||||
role_type="human",
|
||||
role_name="user",
|
||||
role_type="user",
|
||||
role_name="human",
|
||||
input_query=query,
|
||||
origin_query=query,
|
||||
phase_name=phase_name,
|
||||
|
@ -252,21 +274,23 @@ class AgentChat:
|
|||
phase_class = getattr(PHASE_MODULE, phase_configs[input_message.phase_name]["phase_type"])
|
||||
phase = phase_class(input_message.phase_name,
|
||||
task = input_message.task,
|
||||
phase_config = phase_configs,
|
||||
chain_config = chain_configs,
|
||||
role_config = agent_configs,
|
||||
do_summary=phase_configs[input_message.phase_name]["do_summary"],
|
||||
do_code_retrieval=input_message.do_code_retrieval,
|
||||
do_doc_retrieval=input_message.do_doc_retrieval,
|
||||
do_search=input_message.do_search,
|
||||
base_phase_config = phase_configs,
|
||||
base_chain_config = chain_configs,
|
||||
base_role_config = agent_configs,
|
||||
phase_config = None,
|
||||
kb_root_path = kb_root_path,
|
||||
jupyter_work_path = jupyter_work_path,
|
||||
sandbox_server = sandbox_server,
|
||||
embed_config = embed_config,
|
||||
llm_config = llm_config,
|
||||
)
|
||||
self.chatPhase_dict[phase_configs[input_message.phase_name]["phase_type"]] = phase
|
||||
else:
|
||||
phase = self.chatPhase_dict[phase_configs[input_message.phase_name]["phase_type"]]
|
||||
|
||||
def chat_iterator(message: Message, local_memory: Memory, isDetailed=False):
|
||||
step_content = local_memory.to_str_messages(content_key='step_content', filter_roles=["user"])
|
||||
step_content = "\n\n".join([f"{v}" for parsed_output in local_memory.get_parserd_output_list() for k, v in parsed_output.items() if k not in ["Action Status"]])
|
||||
step_content = local_memory.to_str_messages(content_key='step_content', filter_roles=["human"])
|
||||
step_content = "\n\n".join([f"{v}" for parsed_output in local_memory.get_parserd_output_list()[1:] for k, v in parsed_output.items() if k not in ["Action Status"]])
|
||||
final_content = message.role_content
|
||||
result = {
|
||||
"answer": "",
|
||||
|
@ -279,7 +303,6 @@ class AgentChat:
|
|||
"final_content": final_content,
|
||||
}
|
||||
|
||||
|
||||
related_nodes, has_nodes = [], [ ]
|
||||
for nodes in result["related_nodes"]:
|
||||
for node in nodes:
|
||||
|
@ -301,7 +324,7 @@ class AgentChat:
|
|||
|
||||
for output_message, local_memory in phase.astep(input_message, history):
|
||||
|
||||
# logger.debug(f"output_message: {output_message.role_content}")
|
||||
# logger.debug(f"output_message: {output_message}")
|
||||
# output_message = Message(**output_message)
|
||||
# local_memory = Memory(**local_memory)
|
||||
for result in chat_iterator(output_message, local_memory, isDetailed):
|
|
@ -1,16 +1,17 @@
|
|||
from fastapi import Body, Request
|
||||
from fastapi.responses import StreamingResponse
|
||||
import asyncio, json
|
||||
import asyncio, json, os
|
||||
from typing import List, AsyncIterable
|
||||
|
||||
from langchain import LLMChain
|
||||
from langchain.callbacks import AsyncIteratorCallbackHandler
|
||||
from langchain.prompts.chat import ChatPromptTemplate
|
||||
|
||||
from dev_opsgpt.llm_models import getChatModel
|
||||
from dev_opsgpt.chat.utils import History, wrap_done
|
||||
from configs.model_config import (llm_model_dict, LLM_MODEL, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD)
|
||||
from dev_opsgpt.utils import BaseResponse
|
||||
from coagent.llm_models import getChatModel, getChatModelFromConfig
|
||||
from coagent.chat.utils import History, wrap_done
|
||||
from coagent.llm_models.llm_config import LLMConfig, EmbedConfig
|
||||
# from configs.model_config import (llm_model_dict, LLM_MODEL, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD)
|
||||
from coagent.utils import BaseResponse
|
||||
from loguru import logger
|
||||
|
||||
|
||||
|
@ -37,22 +38,34 @@ class Chat:
|
|||
examples=[[{"role": "user", "content": "我们来玩成语接龙,我先来,生龙活虎"}]]
|
||||
),
|
||||
engine_name: str = Body(..., description="知识库名称", examples=["samples"]),
|
||||
top_k: int = Body(VECTOR_SEARCH_TOP_K, description="匹配向量数"),
|
||||
score_threshold: float = Body(SCORE_THRESHOLD, description="知识库匹配相关度阈值,取值范围在0-1之间,SCORE越小,相关度越高,取到1相当于不筛选,建议设置在0.5左右", ge=0, le=1),
|
||||
top_k: int = Body(5, description="匹配向量数"),
|
||||
score_threshold: float = Body(1, description="知识库匹配相关度阈值,取值范围在0-1之间,SCORE越小,相关度越高,取到1相当于不筛选,建议设置在0.5左右", ge=0, le=1),
|
||||
stream: bool = Body(False, description="流式输出"),
|
||||
local_doc_url: bool = Body(False, description="知识文件返回本地路径(true)或URL(false)"),
|
||||
request: Request = None,
|
||||
api_key: str = Body(os.environ.get("OPENAI_API_KEY")),
|
||||
api_base_url: str = Body(os.environ.get("API_BASE_URL")),
|
||||
embed_model: str = Body("", ),
|
||||
embed_model_path: str = Body("", ),
|
||||
embed_engine: str = Body("", ),
|
||||
model_name: str = Body("", ),
|
||||
temperature: float = Body(0.5, ),
|
||||
model_device: str = Body("", ),
|
||||
**kargs
|
||||
):
|
||||
params = locals()
|
||||
params.pop("self", None)
|
||||
llm_config: LLMConfig = LLMConfig(**params)
|
||||
embed_config: EmbedConfig = EmbedConfig(**params)
|
||||
self.engine_name = engine_name if isinstance(engine_name, str) else engine_name.default
|
||||
self.top_k = top_k if isinstance(top_k, int) else top_k.default
|
||||
self.score_threshold = score_threshold if isinstance(score_threshold, float) else score_threshold.default
|
||||
self.stream = stream if isinstance(stream, bool) else stream.default
|
||||
self.local_doc_url = local_doc_url if isinstance(local_doc_url, bool) else local_doc_url.default
|
||||
self.request = request
|
||||
return self._chat(query, history, **kargs)
|
||||
return self._chat(query, history, llm_config, embed_config, **kargs)
|
||||
|
||||
def _chat(self, query: str, history: List[History], **kargs):
|
||||
def _chat(self, query: str, history: List[History], llm_config: LLMConfig, embed_config: EmbedConfig, **kargs):
|
||||
history = [History(**h) if isinstance(h, dict) else h for h in history]
|
||||
|
||||
## check service dependcy is ok
|
||||
|
@ -61,9 +74,10 @@ class Chat:
|
|||
if service_status.code!=200: return service_status
|
||||
|
||||
def chat_iterator(query: str, history: List[History]):
|
||||
model = getChatModel()
|
||||
# model = getChatModel()
|
||||
model = getChatModelFromConfig(llm_config)
|
||||
|
||||
result, content = self.create_task(query, history, model, **kargs)
|
||||
result, content = self.create_task(query, history, model, llm_config, embed_config, **kargs)
|
||||
logger.info('result={}'.format(result))
|
||||
logger.info('content={}'.format(content))
|
||||
|
||||
|
@ -87,21 +101,34 @@ class Chat:
|
|||
examples=[[{"role": "user", "content": "我们来玩成语接龙,我先来,生龙活虎"}]]
|
||||
),
|
||||
engine_name: str = Body(..., description="知识库名称", examples=["samples"]),
|
||||
top_k: int = Body(VECTOR_SEARCH_TOP_K, description="匹配向量数"),
|
||||
score_threshold: float = Body(SCORE_THRESHOLD, description="知识库匹配相关度阈值,取值范围在0-1之间,SCORE越小,相关度越高,取到1相当于不筛选,建议设置在0.5左右", ge=0, le=1),
|
||||
top_k: int = Body(5, description="匹配向量数"),
|
||||
score_threshold: float = Body(1, description="知识库匹配相关度阈值,取值范围在0-1之间,SCORE越小,相关度越高,取到1相当于不筛选,建议设置在0.5左右", ge=0, le=1),
|
||||
stream: bool = Body(False, description="流式输出"),
|
||||
local_doc_url: bool = Body(False, description="知识文件返回本地路径(true)或URL(false)"),
|
||||
request: Request = None,
|
||||
api_key: str = Body(os.environ.get("OPENAI_API_KEY")),
|
||||
api_base_url: str = Body(os.environ.get("API_BASE_URL")),
|
||||
embed_model: str = Body("", ),
|
||||
embed_model_path: str = Body("", ),
|
||||
embed_engine: str = Body("", ),
|
||||
model_name: str = Body("", ),
|
||||
temperature: float = Body(0.5, ),
|
||||
model_device: str = Body("", ),
|
||||
):
|
||||
#
|
||||
params = locals()
|
||||
params.pop("self", None)
|
||||
llm_config: LLMConfig = LLMConfig(**params)
|
||||
embed_config: EmbedConfig = EmbedConfig(**params)
|
||||
self.engine_name = engine_name if isinstance(engine_name, str) else engine_name.default
|
||||
self.top_k = top_k if isinstance(top_k, int) else top_k.default
|
||||
self.score_threshold = score_threshold if isinstance(score_threshold, float) else score_threshold.default
|
||||
self.stream = stream if isinstance(stream, bool) else stream.default
|
||||
self.local_doc_url = local_doc_url if isinstance(local_doc_url, bool) else local_doc_url.default
|
||||
self.request = request
|
||||
return self._achat(query, history)
|
||||
return self._achat(query, history, llm_config, embed_config)
|
||||
|
||||
def _achat(self, query: str, history: List[History]):
|
||||
def _achat(self, query: str, history: List[History], llm_config: LLMConfig, embed_config: EmbedConfig):
|
||||
history = [History(**h) if isinstance(h, dict) else h for h in history]
|
||||
## check service dependcy is ok
|
||||
service_status = self.check_service_status()
|
||||
|
@ -109,9 +136,10 @@ class Chat:
|
|||
|
||||
async def chat_iterator(query, history):
|
||||
callback = AsyncIteratorCallbackHandler()
|
||||
model = getChatModel()
|
||||
# model = getChatModel()
|
||||
model = getChatModelFromConfig(llm_config)
|
||||
|
||||
task, result = self.create_atask(query, history, model, callback)
|
||||
task, result = self.create_atask(query, history, model, llm_config, embed_config, callback)
|
||||
if self.stream:
|
||||
for token in callback["text"]:
|
||||
result["answer"] = token
|
||||
|
@ -125,7 +153,7 @@ class Chat:
|
|||
return StreamingResponse(chat_iterator(query, history),
|
||||
media_type="text/event-stream")
|
||||
|
||||
def create_task(self, query: str, history: List[History], model, **kargs):
|
||||
def create_task(self, query: str, history: List[History], model, llm_config: LLMConfig, embed_config: EmbedConfig, **kargs):
|
||||
'''构建 llm 生成任务'''
|
||||
chat_prompt = ChatPromptTemplate.from_messages(
|
||||
[i.to_msg_tuple() for i in history] + [("human", "{input}")]
|
||||
|
@ -134,7 +162,7 @@ class Chat:
|
|||
content = chain({"input": query})
|
||||
return {"answer": "", "docs": ""}, content
|
||||
|
||||
def create_atask(self, query, history, model, callback: AsyncIteratorCallbackHandler):
|
||||
def create_atask(self, query, history, model, llm_config: LLMConfig, embed_config: EmbedConfig, callback: AsyncIteratorCallbackHandler):
|
||||
chat_prompt = ChatPromptTemplate.from_messages(
|
||||
[i.to_msg_tuple() for i in history] + [("human", "{input}")]
|
||||
)
|
|
@ -8,7 +8,6 @@
|
|||
|
||||
from fastapi import Request, Body
|
||||
import os, asyncio
|
||||
from urllib.parse import urlencode
|
||||
from typing import List
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
|
@ -16,16 +15,19 @@ from langchain import LLMChain
|
|||
from langchain.callbacks import AsyncIteratorCallbackHandler
|
||||
from langchain.prompts.chat import ChatPromptTemplate
|
||||
|
||||
from configs.model_config import (
|
||||
llm_model_dict, LLM_MODEL, PROMPT_TEMPLATE,
|
||||
VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD, CODE_PROMPT_TEMPLATE)
|
||||
from dev_opsgpt.chat.utils import History, wrap_done
|
||||
from dev_opsgpt.utils import BaseResponse
|
||||
# from configs.model_config import (
|
||||
# llm_model_dict, LLM_MODEL, PROMPT_TEMPLATE,
|
||||
# VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD, CODE_PROMPT_TEMPLATE)
|
||||
from coagent.connector.configs.prompts import CODE_PROMPT_TEMPLATE
|
||||
from coagent.chat.utils import History, wrap_done
|
||||
from coagent.utils import BaseResponse
|
||||
from .base_chat import Chat
|
||||
from dev_opsgpt.llm_models import getChatModel
|
||||
from coagent.llm_models import getChatModel, getChatModelFromConfig
|
||||
|
||||
from dev_opsgpt.service.kb_api import search_docs, KBServiceFactory
|
||||
from dev_opsgpt.service.cb_api import search_code, cb_exists_api
|
||||
from coagent.llm_models.llm_config import LLMConfig, EmbedConfig
|
||||
|
||||
|
||||
from coagent.service.cb_api import search_code, cb_exists_api
|
||||
from loguru import logger
|
||||
import json
|
||||
|
||||
|
@ -51,12 +53,21 @@ class CodeChat(Chat):
|
|||
return BaseResponse(code=404, msg=f"未找到代码库 {self.engine_name}")
|
||||
return BaseResponse(code=200, msg=f"找到代码库 {self.engine_name}")
|
||||
|
||||
def _process(self, query: str, history: List[History], model):
|
||||
def _process(self, query: str, history: List[History], model, llm_config: LLMConfig, embed_config: EmbedConfig):
|
||||
'''process'''
|
||||
|
||||
codes_res = search_code(query=query, cb_name=self.engine_name, code_limit=self.code_limit,
|
||||
search_type=self.cb_search_type,
|
||||
history_node_list=self.history_node_list)
|
||||
history_node_list=self.history_node_list,
|
||||
api_key=llm_config.api_key,
|
||||
api_base_url=llm_config.api_base_url,
|
||||
model_name=llm_config.model_name,
|
||||
temperature=llm_config.temperature,
|
||||
embed_model=embed_config.embed_model,
|
||||
embed_model_path=embed_config.embed_model_path,
|
||||
embed_engine=embed_config.embed_engine,
|
||||
model_device=embed_config.model_device,
|
||||
)
|
||||
|
||||
context = codes_res['context']
|
||||
related_vertices = codes_res['related_vertices']
|
||||
|
@ -94,17 +105,30 @@ class CodeChat(Chat):
|
|||
stream: bool = Body(False, description="流式输出"),
|
||||
local_doc_url: bool = Body(False, description="知识文件返回本地路径(true)或URL(false)"),
|
||||
request: Request = None,
|
||||
|
||||
api_key: str = Body(os.environ.get("OPENAI_API_KEY")),
|
||||
api_base_url: str = Body(os.environ.get("API_BASE_URL")),
|
||||
embed_model: str = Body("", ),
|
||||
embed_model_path: str = Body("", ),
|
||||
embed_engine: str = Body("", ),
|
||||
model_name: str = Body("", ),
|
||||
temperature: float = Body(0.5, ),
|
||||
model_device: str = Body("", ),
|
||||
**kargs
|
||||
):
|
||||
params = locals()
|
||||
params.pop("self")
|
||||
llm_config: LLMConfig = LLMConfig(**params)
|
||||
embed_config: EmbedConfig = EmbedConfig(**params)
|
||||
self.engine_name = engine_name if isinstance(engine_name, str) else engine_name.default
|
||||
self.code_limit = code_limit
|
||||
self.stream = stream if isinstance(stream, bool) else stream.default
|
||||
self.local_doc_url = local_doc_url if isinstance(local_doc_url, bool) else local_doc_url.default
|
||||
self.request = request
|
||||
self.cb_search_type = cb_search_type
|
||||
return self._chat(query, history, **kargs)
|
||||
return self._chat(query, history, llm_config, embed_config, **kargs)
|
||||
|
||||
def _chat(self, query: str, history: List[History], **kargs):
|
||||
def _chat(self, query: str, history: List[History], llm_config: LLMConfig, embed_config: EmbedConfig, **kargs):
|
||||
history = [History(**h) if isinstance(h, dict) else h for h in history]
|
||||
|
||||
service_status = self.check_service_status()
|
||||
|
@ -112,9 +136,10 @@ class CodeChat(Chat):
|
|||
if service_status.code != 200: return service_status
|
||||
|
||||
def chat_iterator(query: str, history: List[History]):
|
||||
model = getChatModel()
|
||||
# model = getChatModel()
|
||||
model = getChatModelFromConfig(llm_config)
|
||||
|
||||
result, content = self.create_task(query, history, model, **kargs)
|
||||
result, content = self.create_task(query, history, model, llm_config, embed_config, **kargs)
|
||||
# logger.info('result={}'.format(result))
|
||||
# logger.info('content={}'.format(content))
|
||||
|
||||
|
@ -130,9 +155,9 @@ class CodeChat(Chat):
|
|||
return StreamingResponse(chat_iterator(query, history),
|
||||
media_type="text/event-stream")
|
||||
|
||||
def create_task(self, query: str, history: List[History], model):
|
||||
def create_task(self, query: str, history: List[History], model, llm_config: LLMConfig, embed_config: EmbedConfig):
|
||||
'''构建 llm 生成任务'''
|
||||
chain, context, result = self._process(query, history, model)
|
||||
chain, context, result = self._process(query, history, model, llm_config, embed_config)
|
||||
logger.info('chain={}'.format(chain))
|
||||
try:
|
||||
content = chain({"context": context, "question": query})
|
||||
|
@ -140,8 +165,8 @@ class CodeChat(Chat):
|
|||
content = {"text": str(e)}
|
||||
return result, content
|
||||
|
||||
def create_atask(self, query, history, model, callback: AsyncIteratorCallbackHandler):
|
||||
chain, context, result = self._process(query, history, model)
|
||||
def create_atask(self, query, history, model, llm_config: LLMConfig, embed_config: EmbedConfig, callback: AsyncIteratorCallbackHandler):
|
||||
chain, context, result = self._process(query, history, model, llm_config, embed_config)
|
||||
task = asyncio.create_task(wrap_done(
|
||||
chain.acall({"context": context, "question": query}), callback.done
|
||||
))
|
|
@ -8,13 +8,16 @@ from langchain.callbacks import AsyncIteratorCallbackHandler
|
|||
from langchain.prompts.chat import ChatPromptTemplate
|
||||
|
||||
|
||||
from configs.model_config import (
|
||||
llm_model_dict, LLM_MODEL, PROMPT_TEMPLATE,
|
||||
VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD)
|
||||
from dev_opsgpt.chat.utils import History, wrap_done
|
||||
from dev_opsgpt.utils import BaseResponse
|
||||
# from configs.model_config import (
|
||||
# llm_model_dict, LLM_MODEL, PROMPT_TEMPLATE,
|
||||
# VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD)
|
||||
from coagent.base_configs.env_config import KB_ROOT_PATH
|
||||
from coagent.connector.configs.prompts import ORIGIN_TEMPLATE_PROMPT
|
||||
from coagent.chat.utils import History, wrap_done
|
||||
from coagent.utils import BaseResponse
|
||||
from coagent.llm_models.llm_config import LLMConfig, EmbedConfig
|
||||
from .base_chat import Chat
|
||||
from dev_opsgpt.service.kb_api import search_docs, KBServiceFactory
|
||||
from coagent.service.kb_api import search_docs, KBServiceFactory
|
||||
from loguru import logger
|
||||
|
||||
|
||||
|
@ -23,26 +26,33 @@ class KnowledgeChat(Chat):
|
|||
def __init__(
|
||||
self,
|
||||
engine_name: str = "",
|
||||
top_k: int = VECTOR_SEARCH_TOP_K,
|
||||
top_k: int = 5,
|
||||
stream: bool = False,
|
||||
score_thresold: float = SCORE_THRESHOLD,
|
||||
score_thresold: float = 1.0,
|
||||
local_doc_url: bool = False,
|
||||
request: Request = None,
|
||||
kb_root_path: str = KB_ROOT_PATH,
|
||||
) -> None:
|
||||
super().__init__(engine_name, top_k, stream)
|
||||
self.score_thresold = score_thresold
|
||||
self.local_doc_url = local_doc_url
|
||||
self.request = request
|
||||
self.kb_root_path = kb_root_path
|
||||
|
||||
def check_service_status(self) -> BaseResponse:
|
||||
kb = KBServiceFactory.get_service_by_name(self.engine_name)
|
||||
kb = KBServiceFactory.get_service_by_name(self.engine_name, self.kb_root_path)
|
||||
if kb is None:
|
||||
return BaseResponse(code=404, msg=f"未找到知识库 {self.engine_name}")
|
||||
return BaseResponse(code=200, msg=f"找到知识库 {self.engine_name}")
|
||||
|
||||
def _process(self, query: str, history: List[History], model):
|
||||
def _process(self, query: str, history: List[History], model, llm_config: LLMConfig, embed_config: EmbedConfig, ):
|
||||
'''process'''
|
||||
docs = search_docs(query, self.engine_name, self.top_k, self.score_threshold)
|
||||
docs = search_docs(
|
||||
query, self.engine_name, self.top_k, self.score_threshold, self.kb_root_path,
|
||||
api_key=embed_config.api_key, api_base_url=embed_config.api_base_url, embed_model=embed_config.embed_model,
|
||||
embed_model_path=embed_config.embed_model_path, embed_engine=embed_config.embed_engine,
|
||||
model_device=embed_config.model_device,
|
||||
)
|
||||
context = "\n".join([doc.page_content for doc in docs])
|
||||
source_documents = []
|
||||
for inum, doc in enumerate(docs):
|
||||
|
@ -55,24 +65,24 @@ class KnowledgeChat(Chat):
|
|||
text = f"""出处 [{inum + 1}] [{filename}]({url}) \n\n{doc.page_content}\n\n"""
|
||||
source_documents.append(text)
|
||||
chat_prompt = ChatPromptTemplate.from_messages(
|
||||
[i.to_msg_tuple() for i in history] + [("human", PROMPT_TEMPLATE)]
|
||||
[i.to_msg_tuple() for i in history] + [("human", ORIGIN_TEMPLATE_PROMPT)]
|
||||
)
|
||||
chain = LLMChain(prompt=chat_prompt, llm=model)
|
||||
result = {"answer": "", "docs": source_documents}
|
||||
return chain, context, result
|
||||
|
||||
def create_task(self, query: str, history: List[History], model):
|
||||
def create_task(self, query: str, history: List[History], model, llm_config: LLMConfig, embed_config: EmbedConfig, ):
|
||||
'''构建 llm 生成任务'''
|
||||
logger.debug(f"query: {query}, history: {history}")
|
||||
chain, context, result = self._process(query, history, model)
|
||||
chain, context, result = self._process(query, history, model, llm_config, embed_config)
|
||||
try:
|
||||
content = chain({"context": context, "question": query})
|
||||
except Exception as e:
|
||||
content = {"text": str(e)}
|
||||
return result, content
|
||||
|
||||
def create_atask(self, query, history, model, callback: AsyncIteratorCallbackHandler):
|
||||
chain, context, result = self._process(query, history, model)
|
||||
def create_atask(self, query, history, model, llm_config: LLMConfig, embed_config: EmbedConfig, callback: AsyncIteratorCallbackHandler):
|
||||
chain, context, result = self._process(query, history, model, llm_config, embed_config)
|
||||
task = asyncio.create_task(wrap_done(
|
||||
chain.acall({"context": context, "question": query}), callback.done
|
||||
))
|
|
@ -6,7 +6,8 @@ from langchain.callbacks import AsyncIteratorCallbackHandler
|
|||
from langchain.prompts.chat import ChatPromptTemplate
|
||||
|
||||
|
||||
from dev_opsgpt.chat.utils import History, wrap_done
|
||||
from coagent.chat.utils import History, wrap_done
|
||||
from coagent.llm_models.llm_config import LLMConfig, EmbedConfig
|
||||
from .base_chat import Chat
|
||||
from loguru import logger
|
||||
|
||||
|
@ -21,7 +22,7 @@ class LLMChat(Chat):
|
|||
) -> None:
|
||||
super().__init__(engine_name, top_k, stream)
|
||||
|
||||
def create_task(self, query: str, history: List[History], model):
|
||||
def create_task(self, query: str, history: List[History], model, llm_config: LLMConfig, embed_config: EmbedConfig, **kargs):
|
||||
'''构建 llm 生成任务'''
|
||||
chat_prompt = ChatPromptTemplate.from_messages(
|
||||
[i.to_msg_tuple() for i in history] + [("human", "{input}")]
|
||||
|
@ -30,7 +31,7 @@ class LLMChat(Chat):
|
|||
content = chain({"input": query})
|
||||
return {"answer": "", "docs": ""}, content
|
||||
|
||||
def create_atask(self, query, history, model, callback: AsyncIteratorCallbackHandler):
|
||||
def create_atask(self, query, history, model, llm_config: LLMConfig, embed_config: EmbedConfig, callback: AsyncIteratorCallbackHandler):
|
||||
chat_prompt = ChatPromptTemplate.from_messages(
|
||||
[i.to_msg_tuple() for i in history] + [("human", "{input}")]
|
||||
)
|
|
@ -1,4 +1,3 @@
|
|||
from fastapi import Request
|
||||
import os, asyncio
|
||||
from typing import List, Optional, Dict
|
||||
|
||||
|
@ -8,11 +7,13 @@ from langchain.utilities import BingSearchAPIWrapper, DuckDuckGoSearchAPIWrapper
|
|||
from langchain.prompts.chat import ChatPromptTemplate
|
||||
from langchain.docstore.document import Document
|
||||
|
||||
from configs.model_config import (
|
||||
PROMPT_TEMPLATE, SEARCH_ENGINE_TOP_K, BING_SUBSCRIPTION_KEY, BING_SEARCH_URL,
|
||||
VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD)
|
||||
from dev_opsgpt.chat.utils import History, wrap_done
|
||||
from dev_opsgpt.utils import BaseResponse
|
||||
# from configs.model_config import (
|
||||
# PROMPT_TEMPLATE, SEARCH_ENGINE_TOP_K, BING_SUBSCRIPTION_KEY, BING_SEARCH_URL,
|
||||
# VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD)
|
||||
from coagent.connector.configs.prompts import ORIGIN_TEMPLATE_PROMPT
|
||||
from coagent.chat.utils import History, wrap_done
|
||||
from coagent.utils import BaseResponse
|
||||
from coagent.llm_models.llm_config import LLMConfig, EmbedConfig
|
||||
from .base_chat import Chat
|
||||
|
||||
from loguru import logger
|
||||
|
@ -20,19 +21,19 @@ from loguru import logger
|
|||
from duckduckgo_search import DDGS
|
||||
|
||||
|
||||
def bing_search(text, result_len=SEARCH_ENGINE_TOP_K):
|
||||
if not (BING_SEARCH_URL and BING_SUBSCRIPTION_KEY):
|
||||
return [{"snippet": "please set BING_SUBSCRIPTION_KEY and BING_SEARCH_URL in os ENV",
|
||||
"title": "env info is not found",
|
||||
"link": "https://python.langchain.com/en/latest/modules/agents/tools/examples/bing_search.html"}]
|
||||
search = BingSearchAPIWrapper(bing_subscription_key=BING_SUBSCRIPTION_KEY,
|
||||
bing_search_url=BING_SEARCH_URL)
|
||||
return search.results(text, result_len)
|
||||
# def bing_search(text, result_len=5):
|
||||
# if not (BING_SEARCH_URL and BING_SUBSCRIPTION_KEY):
|
||||
# return [{"snippet": "please set BING_SUBSCRIPTION_KEY and BING_SEARCH_URL in os ENV",
|
||||
# "title": "env info is not found",
|
||||
# "link": "https://python.langchain.com/en/latest/modules/agents/tools/examples/bing_search.html"}]
|
||||
# search = BingSearchAPIWrapper(bing_subscription_key=BING_SUBSCRIPTION_KEY,
|
||||
# bing_search_url=BING_SEARCH_URL)
|
||||
# return search.results(text, result_len)
|
||||
|
||||
|
||||
def duckduckgo_search(
|
||||
query: str,
|
||||
result_len: int = SEARCH_ENGINE_TOP_K,
|
||||
result_len: int = 5,
|
||||
region: Optional[str] = "wt-wt",
|
||||
safesearch: str = "moderate",
|
||||
time: Optional[str] = "y",
|
||||
|
@ -79,7 +80,7 @@ def duckduckgo_search(
|
|||
|
||||
|
||||
SEARCH_ENGINES = {"duckduckgo": duckduckgo_search,
|
||||
"bing": bing_search,
|
||||
# "bing": bing_search,
|
||||
}
|
||||
|
||||
|
||||
|
@ -96,7 +97,7 @@ def search_result2docs(search_results):
|
|||
def lookup_search_engine(
|
||||
query: str,
|
||||
search_engine_name: str,
|
||||
top_k: int = SEARCH_ENGINE_TOP_K,
|
||||
top_k: int = 5,
|
||||
):
|
||||
results = SEARCH_ENGINES[search_engine_name](query, result_len=top_k)
|
||||
docs = search_result2docs(results)
|
||||
|
@ -109,7 +110,7 @@ class SearchChat(Chat):
|
|||
def __init__(
|
||||
self,
|
||||
engine_name: str = "",
|
||||
top_k: int = VECTOR_SEARCH_TOP_K,
|
||||
top_k: int = 5,
|
||||
stream: bool = False,
|
||||
) -> None:
|
||||
super().__init__(engine_name, top_k, stream)
|
||||
|
@ -130,19 +131,19 @@ class SearchChat(Chat):
|
|||
]
|
||||
|
||||
chat_prompt = ChatPromptTemplate.from_messages(
|
||||
[i.to_msg_tuple() for i in history] + [("human", PROMPT_TEMPLATE)]
|
||||
[i.to_msg_tuple() for i in history] + [("human", ORIGIN_TEMPLATE_PROMPT)]
|
||||
)
|
||||
chain = LLMChain(prompt=chat_prompt, llm=model)
|
||||
result = {"answer": "", "docs": source_documents}
|
||||
return chain, context, result
|
||||
|
||||
def create_task(self, query: str, history: List[History], model):
|
||||
def create_task(self, query: str, history: List[History], model, llm_config: LLMConfig, embed_config: EmbedConfig, ):
|
||||
'''构建 llm 生成任务'''
|
||||
chain, context, result = self._process(query, history, model)
|
||||
content = chain({"context": context, "question": query})
|
||||
return result, content
|
||||
|
||||
def create_atask(self, query, history, model, callback: AsyncIteratorCallbackHandler):
|
||||
def create_atask(self, query, history, model, llm_config: LLMConfig, embed_config: EmbedConfig, callback: AsyncIteratorCallbackHandler):
|
||||
chain, context, result = self._process(query, history, model)
|
||||
task = asyncio.create_task(wrap_done(
|
||||
chain.acall({"context": context, "question": query}), callback.done
|
|
@ -8,17 +8,20 @@
|
|||
import time
|
||||
from loguru import logger
|
||||
|
||||
from dev_opsgpt.codechat.code_analyzer.code_static_analysis import CodeStaticAnalysis
|
||||
from dev_opsgpt.codechat.code_analyzer.code_intepreter import CodeIntepreter
|
||||
from dev_opsgpt.codechat.code_analyzer.code_preprocess import CodePreprocessor
|
||||
from dev_opsgpt.codechat.code_analyzer.code_dedup import CodeDedup
|
||||
from coagent.codechat.code_analyzer.code_static_analysis import CodeStaticAnalysis
|
||||
from coagent.codechat.code_analyzer.code_intepreter import CodeIntepreter
|
||||
from coagent.codechat.code_analyzer.code_preprocess import CodePreprocessor
|
||||
from coagent.codechat.code_analyzer.code_dedup import CodeDedup
|
||||
from coagent.llm_models.llm_config import LLMConfig
|
||||
|
||||
|
||||
|
||||
class CodeAnalyzer:
|
||||
def __init__(self, language: str):
|
||||
def __init__(self, language: str, llm_config: LLMConfig):
|
||||
self.llm_config = llm_config
|
||||
self.code_preprocessor = CodePreprocessor()
|
||||
self.code_debup = CodeDedup()
|
||||
self.code_interperter = CodeIntepreter()
|
||||
self.code_interperter = CodeIntepreter(self.llm_config)
|
||||
self.code_static_analyzer = CodeStaticAnalysis(language=language)
|
||||
|
||||
def analyze(self, code_dict: dict, do_interpret: bool = True):
|
|
@ -10,14 +10,15 @@ from langchain.schema import (
|
|||
HumanMessage,
|
||||
)
|
||||
|
||||
from configs.model_config import CODE_INTERPERT_TEMPLATE
|
||||
|
||||
from dev_opsgpt.llm_models.openai_model import getChatModel
|
||||
# from configs.model_config import CODE_INTERPERT_TEMPLATE
|
||||
from coagent.connector.configs.prompts import CODE_INTERPERT_TEMPLATE
|
||||
from coagent.llm_models.openai_model import getChatModel, getChatModelFromConfig
|
||||
from coagent.llm_models.llm_config import LLMConfig
|
||||
|
||||
|
||||
class CodeIntepreter:
|
||||
def __init__(self):
|
||||
pass
|
||||
def __init__(self, llm_config: LLMConfig):
|
||||
self.llm_config = llm_config
|
||||
|
||||
def get_intepretation(self, code_list):
|
||||
'''
|
||||
|
@ -25,7 +26,8 @@ class CodeIntepreter:
|
|||
@param code_list:
|
||||
@return:
|
||||
'''
|
||||
chat_model = getChatModel()
|
||||
# chat_model = getChatModel()
|
||||
chat_model = getChatModelFromConfig(self.llm_config)
|
||||
|
||||
res = {}
|
||||
for code in code_list:
|
||||
|
@ -42,7 +44,8 @@ class CodeIntepreter:
|
|||
@param code_list:
|
||||
@return:
|
||||
'''
|
||||
chat_model = getChatModel()
|
||||
# chat_model = getChatModel()
|
||||
chat_model = getChatModelFromConfig(self.llm_config)
|
||||
|
||||
res = {}
|
||||
messages = []
|
|
@ -5,7 +5,7 @@
|
|||
@time: 2023/11/21 下午2:28
|
||||
@desc:
|
||||
'''
|
||||
from dev_opsgpt.codechat.code_analyzer.language_static_analysis import *
|
||||
from coagent.codechat.code_analyzer.language_static_analysis import *
|
||||
|
||||
class CodeStaticAnalysis:
|
||||
def __init__(self, language):
|
|
@ -62,7 +62,7 @@ class JavaStaticAnalysis:
|
|||
|
||||
for node in tree.types:
|
||||
if type(node) in (javalang.tree.ClassDeclaration, javalang.tree.InterfaceDeclaration):
|
||||
class_name = pac_name + '#' + node.name
|
||||
class_name = tree.package.name + '.' + node.name
|
||||
class_name_list.append(class_name)
|
||||
|
||||
for node_inner in node.body:
|
||||
|
@ -108,6 +108,28 @@ class JavaStaticAnalysis:
|
|||
return res_dict
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
java_code_dict = {
|
||||
'test': '''package com.theokanning.openai;
|
||||
|
||||
import com.theokanning.openai.client.Utils;
|
||||
|
||||
|
||||
public class UtilsTest {
|
||||
public void testRemoveChar() {
|
||||
String input = "hello";
|
||||
char ch = 'l';
|
||||
String expected = "heo";
|
||||
String res = Utils.remove(input, ch);
|
||||
System.out.println(res.equals(expected));
|
||||
}
|
||||
}
|
||||
'''
|
||||
}
|
||||
|
||||
jsa = JavaStaticAnalysis()
|
||||
res = jsa.analyze(java_code_dict)
|
||||
logger.info(res)
|
||||
|
||||
|
||||
|
|
@ -8,7 +8,7 @@
|
|||
from loguru import logger
|
||||
|
||||
import zipfile
|
||||
from dev_opsgpt.codechat.code_crawler.dir_crawler import DirCrawler
|
||||
from coagent.codechat.code_crawler.dir_crawler import DirCrawler
|
||||
|
||||
|
||||
class ZipCrawler:
|
|
@ -9,13 +9,16 @@ import time
|
|||
from loguru import logger
|
||||
from collections import defaultdict
|
||||
|
||||
from dev_opsgpt.db_handler.graph_db_handler.nebula_handler import NebulaHandler
|
||||
from dev_opsgpt.db_handler.vector_db_handler.chroma_handler import ChromaHandler
|
||||
from coagent.db_handler.graph_db_handler.nebula_handler import NebulaHandler
|
||||
from coagent.db_handler.vector_db_handler.chroma_handler import ChromaHandler
|
||||
|
||||
from dev_opsgpt.codechat.code_search.cypher_generator import CypherGenerator
|
||||
from dev_opsgpt.codechat.code_search.tagger import Tagger
|
||||
from dev_opsgpt.embeddings.get_embedding import get_embedding
|
||||
from coagent.codechat.code_search.cypher_generator import CypherGenerator
|
||||
from coagent.codechat.code_search.tagger import Tagger
|
||||
from coagent.embeddings.get_embedding import get_embedding
|
||||
from coagent.llm_models.llm_config import LLMConfig
|
||||
|
||||
|
||||
# from configs.model_config import EMBEDDING_DEVICE, EMBEDDING_MODEL
|
||||
# search_by_tag
|
||||
VERTEX_SCORE = 10
|
||||
HISTORY_VERTEX_SCORE = 5
|
||||
|
@ -26,13 +29,14 @@ MAX_DISTANCE = 1000
|
|||
|
||||
|
||||
class CodeSearch:
|
||||
def __init__(self, nh: NebulaHandler, ch: ChromaHandler, limit: int = 3):
|
||||
def __init__(self, llm_config: LLMConfig, nh: NebulaHandler, ch: ChromaHandler, limit: int = 3):
|
||||
'''
|
||||
init
|
||||
@param nh: NebulaHandler
|
||||
@param ch: ChromaHandler
|
||||
@param limit: limit of result
|
||||
'''
|
||||
self.llm_config = llm_config
|
||||
self.nh = nh
|
||||
self.ch = ch
|
||||
self.limit = limit
|
||||
|
@ -50,7 +54,6 @@ class CodeSearch:
|
|||
# get all verticex
|
||||
vertex_list = self.nh.get_vertices().get('v', [])
|
||||
vertex_vid_list = [i.as_node().get_id().as_string() for i in vertex_list]
|
||||
logger.debug(vertex_vid_list)
|
||||
|
||||
# update score
|
||||
vertex_score_dict = defaultdict(lambda: 0)
|
||||
|
@ -77,8 +80,26 @@ class CodeSearch:
|
|||
|
||||
# get most prominent package tag
|
||||
package_score_dict = defaultdict(lambda: 0)
|
||||
|
||||
for vertex, score in vertex_score_dict.items():
|
||||
package = '#'.join(vertex.split('#')[0:2])
|
||||
if '#' in vertex:
|
||||
# get class name first
|
||||
cypher = f'''MATCH (v1:class)-[e:contain]->(v2) WHERE id(v2) == '{vertex}' RETURN id(v1) as id;'''
|
||||
cypher_res = self.nh.execute_cypher(cypher=cypher, format_res=True)
|
||||
class_vertices = cypher_res.get('id', [])
|
||||
if not class_vertices:
|
||||
continue
|
||||
|
||||
vertex = class_vertices[0].as_string()
|
||||
|
||||
# get package name
|
||||
cypher = f'''MATCH (v1:package)-[e:contain]->(v2) WHERE id(v2) == '{vertex}' RETURN id(v1) as id;'''
|
||||
cypher_res = self.nh.execute_cypher(cypher=cypher, format_res=True)
|
||||
pac_vertices = cypher_res.get('id', [])
|
||||
if not pac_vertices:
|
||||
continue
|
||||
|
||||
package = pac_vertices[0].as_string()
|
||||
package_score_dict[package] += score
|
||||
|
||||
# get respective code
|
||||
|
@ -87,7 +108,10 @@ class CodeSearch:
|
|||
package_score_tuple.sort(key=lambda x: x[1], reverse=True)
|
||||
|
||||
ids = [i[0] for i in package_score_tuple]
|
||||
logger.info(f'ids={ids}')
|
||||
chroma_res = self.ch.get(ids=ids, include=['metadatas'])
|
||||
|
||||
# logger.info(chroma_res)
|
||||
for vertex, score in package_score_tuple:
|
||||
index = chroma_res['result']['ids'].index(vertex)
|
||||
code_text = chroma_res['result']['metadatas'][index]['code_text']
|
||||
|
@ -97,16 +121,17 @@ class CodeSearch:
|
|||
)
|
||||
if len(res) >= self.limit:
|
||||
break
|
||||
logger.info(f'retrival code={res}')
|
||||
return res
|
||||
|
||||
def search_by_desciption(self, query: str, engine: str):
|
||||
def search_by_desciption(self, query: str, engine: str, model_path: str = "text2vec-base-chinese", embedding_device: str = "cpu"):
|
||||
'''
|
||||
search by perform sim search
|
||||
@param query:
|
||||
@return:
|
||||
'''
|
||||
query = query.replace(',', ',')
|
||||
query_emb = get_embedding(engine=engine, text_list=[query])
|
||||
query_emb = get_embedding(engine=engine, text_list=[query], model_path=model_path, embedding_device= embedding_device,)
|
||||
query_emb = query_emb[query]
|
||||
|
||||
query_embeddings = [query_emb]
|
||||
|
@ -133,7 +158,7 @@ class CodeSearch:
|
|||
@param engine:
|
||||
@return:
|
||||
'''
|
||||
cg = CypherGenerator()
|
||||
cg = CypherGenerator(self.llm_config)
|
||||
cypher = cg.get_cypher(query)
|
||||
|
||||
if not cypher:
|
||||
|
@ -156,9 +181,12 @@ class CodeSearch:
|
|||
|
||||
|
||||
if __name__ == '__main__':
|
||||
from configs.server_config import NEBULA_HOST, NEBULA_PORT, NEBULA_USER, NEBULA_PASSWORD, NEBULA_STORAGED_PORT
|
||||
from configs.server_config import CHROMA_PERSISTENT_PATH
|
||||
|
||||
# from configs.server_config import NEBULA_HOST, NEBULA_PORT, NEBULA_USER, NEBULA_PASSWORD, NEBULA_STORAGED_PORT
|
||||
# from configs.server_config import CHROMA_PERSISTENT_PATH
|
||||
from coagent.base_configs.env_config import (
|
||||
NEBULA_HOST, NEBULA_PORT, NEBULA_USER, NEBULA_PASSWORD, NEBULA_STORAGED_PORT,
|
||||
CHROMA_PERSISTENT_PATH
|
||||
)
|
||||
codebase_name = 'testing'
|
||||
|
||||
nh = NebulaHandler(host=NEBULA_HOST, port=NEBULA_PORT, username=NEBULA_USER,
|
|
@ -0,0 +1,82 @@
|
|||
# encoding: utf-8
|
||||
'''
|
||||
@author: 温进
|
||||
@file: cypher_generator.py
|
||||
@time: 2023/11/24 上午10:17
|
||||
@desc:
|
||||
'''
|
||||
from langchain import PromptTemplate
|
||||
from loguru import logger
|
||||
|
||||
from coagent.llm_models.openai_model import getChatModel, getChatModelFromConfig
|
||||
from coagent.llm_models.llm_config import LLMConfig
|
||||
from coagent.utils.postprocess import replace_lt_gt
|
||||
from langchain.schema import (
|
||||
HumanMessage,
|
||||
)
|
||||
from langchain.chains.graph_qa.prompts import NGQL_GENERATION_PROMPT, CYPHER_GENERATION_TEMPLATE
|
||||
|
||||
schema = '''
|
||||
Node properties: [{'tag': 'package', 'properties': []}, {'tag': 'class', 'properties': []}, {'tag': 'method', 'properties': []}]
|
||||
Edge properties: [{'edge': 'contain', 'properties': []}, {'edge': 'depend', 'properties': []}]
|
||||
Relationships: ['(:package)-[:contain]->(:class)', '(:class)-[:contain]->(:method)', '(:package)-[:contain]->(:package)']
|
||||
'''
|
||||
|
||||
|
||||
class CypherGenerator:
|
||||
def __init__(self, llm_config: LLMConfig):
|
||||
self.model = getChatModelFromConfig(llm_config)
|
||||
NEBULAGRAPH_EXTRA_INSTRUCTIONS = """
|
||||
Instructions:
|
||||
|
||||
First, generate cypher then convert it to NebulaGraph Cypher dialect(rather than standard):
|
||||
1. it requires explicit label specification only when referring to node properties: v.`Foo`.name
|
||||
2. note explicit label specification is not needed for edge properties, so it's e.name instead of e.`Bar`.name
|
||||
3. it uses double equals sign for comparison: `==` rather than `=`
|
||||
4. only use id(Foo) to get the name of node or edge
|
||||
```\n"""
|
||||
|
||||
NGQL_GENERATION_TEMPLATE = CYPHER_GENERATION_TEMPLATE.replace(
|
||||
"Generate Cypher", "Generate NebulaGraph Cypher"
|
||||
).replace("Instructions:", NEBULAGRAPH_EXTRA_INSTRUCTIONS)
|
||||
|
||||
self.NGQL_GENERATION_PROMPT = PromptTemplate(
|
||||
input_variables=["schema", "question"], template=NGQL_GENERATION_TEMPLATE
|
||||
)
|
||||
|
||||
def get_cypher(self, query: str):
|
||||
'''
|
||||
get cypher from query
|
||||
@param query:
|
||||
@return:
|
||||
'''
|
||||
content = self.NGQL_GENERATION_PROMPT.format(schema=schema, question=query)
|
||||
logger.info(content)
|
||||
ans = ''
|
||||
message = [HumanMessage(content=content)]
|
||||
chat_res = self.model.predict_messages(message)
|
||||
ans = chat_res.content
|
||||
|
||||
ans = replace_lt_gt(ans)
|
||||
|
||||
ans = self.post_process(ans)
|
||||
return ans
|
||||
|
||||
def post_process(self, cypher_res: str):
|
||||
'''
|
||||
判断是否为正确的 cypher
|
||||
@param cypher_res:
|
||||
@return:
|
||||
'''
|
||||
if '(' not in cypher_res or ')' not in cypher_res:
|
||||
return ''
|
||||
|
||||
return cypher_res
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
query = '代码库里有哪些函数,返回5个就可以'
|
||||
cg = CypherGenerator()
|
||||
|
||||
ans = cg.get_cypher(query)
|
||||
logger.debug(f'ans=\n{ans}')
|
|
@ -20,4 +20,20 @@ class Tagger:
|
|||
# simple extract english
|
||||
tag_list = re.findall(r'[a-zA-Z\_\.]+', query)
|
||||
tag_list = list(set(tag_list))
|
||||
tag_list = self.filter_tag_list(tag_list)
|
||||
return tag_list
|
||||
|
||||
def filter_tag_list(self, tag_list):
|
||||
'''
|
||||
filter out tag
|
||||
@param tag_list:
|
||||
@return:
|
||||
'''
|
||||
res = []
|
||||
for tag in tag_list:
|
||||
if tag in ['java', 'python']:
|
||||
continue
|
||||
res.append(tag)
|
||||
return res
|
||||
|
||||
|
|
@ -8,17 +8,20 @@
|
|||
import time
|
||||
from loguru import logger
|
||||
|
||||
from configs.server_config import NEBULA_HOST, NEBULA_PORT, NEBULA_USER, NEBULA_PASSWORD, NEBULA_STORAGED_PORT
|
||||
from configs.server_config import CHROMA_PERSISTENT_PATH
|
||||
from dev_opsgpt.db_handler.graph_db_handler.nebula_handler import NebulaHandler
|
||||
from dev_opsgpt.db_handler.vector_db_handler.chroma_handler import ChromaHandler
|
||||
from dev_opsgpt.embeddings.get_embedding import get_embedding
|
||||
# from configs.server_config import NEBULA_HOST, NEBULA_PORT, NEBULA_USER, NEBULA_PASSWORD, NEBULA_STORAGED_PORT
|
||||
# from configs.server_config import CHROMA_PERSISTENT_PATH
|
||||
# from configs.model_config import EMBEDDING_DEVICE, EMBEDDING_MODEL
|
||||
from coagent.db_handler.graph_db_handler.nebula_handler import NebulaHandler
|
||||
from coagent.db_handler.vector_db_handler.chroma_handler import ChromaHandler
|
||||
from coagent.embeddings.get_embedding import get_embedding
|
||||
from coagent.llm_models.llm_config import EmbedConfig
|
||||
|
||||
|
||||
class CodeImporter:
|
||||
def __init__(self, codebase_name: str, engine: str, nh: NebulaHandler, ch: ChromaHandler):
|
||||
def __init__(self, codebase_name: str, embed_config: EmbedConfig, nh: NebulaHandler, ch: ChromaHandler):
|
||||
self.codebase_name = codebase_name
|
||||
self.engine = engine
|
||||
# self.engine = engine
|
||||
self.embed_config: EmbedConfig= embed_config
|
||||
self.nh = nh
|
||||
self.ch = ch
|
||||
|
||||
|
@ -27,9 +30,36 @@ class CodeImporter:
|
|||
import code to nebula and chroma
|
||||
@return:
|
||||
'''
|
||||
static_analysis_res = self.filter_out_vertex(static_analysis_res, interpretation)
|
||||
logger.info(f'static_analysis_res={static_analysis_res}')
|
||||
|
||||
self.analysis_res_to_graph(static_analysis_res)
|
||||
self.interpretation_to_db(static_analysis_res, interpretation, do_interpret)
|
||||
|
||||
def filter_out_vertex(self, static_analysis_res, interpretation):
|
||||
'''
|
||||
filter out nonexist vertices
|
||||
@param static_analysis_res:
|
||||
@param interpretation:
|
||||
@return:
|
||||
'''
|
||||
save_pac_name = set()
|
||||
for i, j in static_analysis_res.items():
|
||||
save_pac_name.add(j['pac_name'])
|
||||
|
||||
for class_name in j['class_name_list']:
|
||||
save_pac_name.add(class_name)
|
||||
save_pac_name.update(j['func_name_dict'].get(class_name, []))
|
||||
|
||||
for _, structure in static_analysis_res.items():
|
||||
new_pac_name_list = []
|
||||
for i in structure['import_pac_name_list']:
|
||||
if i in save_pac_name:
|
||||
new_pac_name_list.append(i)
|
||||
|
||||
structure['import_pac_name_list'] = new_pac_name_list
|
||||
return static_analysis_res
|
||||
|
||||
def analysis_res_to_graph(self, static_analysis_res):
|
||||
'''
|
||||
transform static_analysis_res to tuple
|
||||
|
@ -93,7 +123,7 @@ class CodeImporter:
|
|||
|
||||
return
|
||||
|
||||
def interpretation_to_db(self, static_analysis_res, interpretation, do_interpret):
|
||||
def interpretation_to_db(self, static_analysis_res, interpretation, do_interpret, ):
|
||||
'''
|
||||
vectorize interpretation and save to db
|
||||
@return:
|
||||
|
@ -102,7 +132,7 @@ class CodeImporter:
|
|||
if do_interpret:
|
||||
logger.info('start get embedding for interpretion')
|
||||
interp_list = list(interpretation.values())
|
||||
emb = get_embedding(engine=self.engine, text_list=interp_list)
|
||||
emb = get_embedding(engine=self.embed_config.embed_engine, text_list=interp_list, model_path=self.embed_config.embed_model_path, embedding_device= self.embed_config.model_device)
|
||||
logger.info('get embedding done')
|
||||
else:
|
||||
emb = {i: [0] for i in list(interpretation.values())}
|
||||
|
@ -113,6 +143,9 @@ class CodeImporter:
|
|||
metadatas = []
|
||||
|
||||
for code_text, interp in interpretation.items():
|
||||
if code_text not in static_analysis_res:
|
||||
continue
|
||||
|
||||
pac_name = static_analysis_res[code_text]['pac_name']
|
||||
if pac_name in ids:
|
||||
continue
|
|
@ -8,26 +8,41 @@
|
|||
import time
|
||||
from loguru import logger
|
||||
|
||||
from configs.server_config import NEBULA_HOST, NEBULA_PORT, NEBULA_USER, NEBULA_PASSWORD, NEBULA_STORAGED_PORT
|
||||
from configs.server_config import CHROMA_PERSISTENT_PATH
|
||||
# from configs.server_config import NEBULA_HOST, NEBULA_PORT, NEBULA_USER, NEBULA_PASSWORD, NEBULA_STORAGED_PORT
|
||||
# from configs.server_config import CHROMA_PERSISTENT_PATH
|
||||
# from configs.model_config import EMBEDDING_ENGINE
|
||||
|
||||
from configs.model_config import EMBEDDING_ENGINE
|
||||
from coagent.base_configs.env_config import (
|
||||
NEBULA_HOST, NEBULA_PORT, NEBULA_USER, NEBULA_PASSWORD, NEBULA_STORAGED_PORT,
|
||||
CHROMA_PERSISTENT_PATH
|
||||
)
|
||||
|
||||
from dev_opsgpt.db_handler.graph_db_handler.nebula_handler import NebulaHandler
|
||||
from dev_opsgpt.db_handler.vector_db_handler.chroma_handler import ChromaHandler
|
||||
from dev_opsgpt.codechat.code_crawler.zip_crawler import *
|
||||
from dev_opsgpt.codechat.code_analyzer.code_analyzer import CodeAnalyzer
|
||||
from dev_opsgpt.codechat.codebase_handler.code_importer import CodeImporter
|
||||
from dev_opsgpt.codechat.code_search.code_search import CodeSearch
|
||||
|
||||
from coagent.db_handler.graph_db_handler.nebula_handler import NebulaHandler
|
||||
from coagent.db_handler.vector_db_handler.chroma_handler import ChromaHandler
|
||||
from coagent.codechat.code_crawler.zip_crawler import *
|
||||
from coagent.codechat.code_analyzer.code_analyzer import CodeAnalyzer
|
||||
from coagent.codechat.codebase_handler.code_importer import CodeImporter
|
||||
from coagent.codechat.code_search.code_search import CodeSearch
|
||||
from coagent.llm_models.llm_config import EmbedConfig, LLMConfig
|
||||
|
||||
|
||||
class CodeBaseHandler:
|
||||
def __init__(self, codebase_name: str, code_path: str = '',
|
||||
language: str = 'java', crawl_type: str = 'ZIP'):
|
||||
def __init__(
|
||||
self,
|
||||
codebase_name: str,
|
||||
code_path: str = '',
|
||||
language: str = 'java',
|
||||
crawl_type: str = 'ZIP',
|
||||
embed_config: EmbedConfig = EmbedConfig(),
|
||||
llm_config: LLMConfig = LLMConfig()
|
||||
):
|
||||
self.codebase_name = codebase_name
|
||||
self.code_path = code_path
|
||||
self.language = language
|
||||
self.crawl_type = crawl_type
|
||||
self.embed_config = embed_config
|
||||
self.llm_config = llm_config
|
||||
|
||||
self.nh = NebulaHandler(host=NEBULA_HOST, port=NEBULA_PORT, username=NEBULA_USER,
|
||||
password=NEBULA_PASSWORD, space_name=codebase_name)
|
||||
|
@ -42,7 +57,7 @@ class CodeBaseHandler:
|
|||
@return:
|
||||
'''
|
||||
# init graph to init tag and edge
|
||||
code_importer = CodeImporter(engine=EMBEDDING_ENGINE, codebase_name=self.codebase_name,
|
||||
code_importer = CodeImporter(embed_config=self.embed_config, codebase_name=self.codebase_name,
|
||||
nh=self.nh, ch=self.ch)
|
||||
code_importer.init_graph()
|
||||
time.sleep(5)
|
||||
|
@ -56,7 +71,7 @@ class CodeBaseHandler:
|
|||
# analyze code
|
||||
logger.info('start analyze')
|
||||
st1 = time.time()
|
||||
code_analyzer = CodeAnalyzer(language=self.language)
|
||||
code_analyzer = CodeAnalyzer(language=self.language, llm_config = self.llm_config)
|
||||
static_analysis_res, interpretation = code_analyzer.analyze(code_dict, do_interpret=do_interpret)
|
||||
logger.debug('analyze done, rt={}'.format(time.time() - st1))
|
||||
|
||||
|
@ -111,14 +126,15 @@ class CodeBaseHandler:
|
|||
'''
|
||||
assert search_type in ['cypher', 'tag', 'description']
|
||||
|
||||
code_search = CodeSearch(nh=self.nh, ch=self.ch, limit=limit)
|
||||
code_search = CodeSearch(llm_config=self.llm_config, nh=self.nh, ch=self.ch, limit=limit)
|
||||
|
||||
if search_type == 'cypher':
|
||||
search_res = code_search.search_by_cypher(query=query)
|
||||
elif search_type == 'tag':
|
||||
search_res = code_search.search_by_tag(query=query)
|
||||
elif search_type == 'description':
|
||||
search_res = code_search.search_by_desciption(query=query, engine=EMBEDDING_ENGINE)
|
||||
search_res = code_search.search_by_desciption(
|
||||
query=query, engine=self.embed_config.embed_engine, model_path=self.embed_config.embed_model_path, embedding_device=self.embed_config.model_device)
|
||||
|
||||
context, related_vertice = self.format_search_res(search_res, search_type)
|
||||
return context, related_vertice
|
|
@ -1,9 +1,8 @@
|
|||
from .base_agent import BaseAgent
|
||||
from .react_agent import ReactAgent
|
||||
from .check_agent import CheckAgent
|
||||
from .executor_agent import ExecutorAgent
|
||||
from .selector_agent import SelectorAgent
|
||||
|
||||
__all__ = [
|
||||
"BaseAgent", "ReactAgent", "CheckAgent", "ExecutorAgent", "SelectorAgent"
|
||||
"BaseAgent", "ReactAgent", "ExecutorAgent", "SelectorAgent"
|
||||
]
|
|
@ -1,113 +1,218 @@
|
|||
from pydantic import BaseModel
|
||||
from typing import List, Union
|
||||
import re
|
||||
import importlib
|
||||
import re, os
|
||||
import copy
|
||||
import json
|
||||
import traceback
|
||||
import uuid
|
||||
from loguru import logger
|
||||
|
||||
from dev_opsgpt.connector.schema import (
|
||||
Memory, Task, Env, Role, Message, ActionStatus, CodeDoc, Doc
|
||||
from coagent.connector.schema import (
|
||||
Memory, Task, Role, Message, PromptField, LogVerboseEnum
|
||||
)
|
||||
from configs.server_config import SANDBOX_SERVER
|
||||
from dev_opsgpt.sandbox import PyCodeBox, CodeBoxResponse
|
||||
from dev_opsgpt.tools import DDGSTool, DocRetrieval, CodeRetrieval
|
||||
from dev_opsgpt.connector.configs.prompts import BASE_PROMPT_INPUT, QUERY_CONTEXT_DOC_PROMPT_INPUT, BEGIN_PROMPT_INPUT
|
||||
from dev_opsgpt.connector.message_process import MessageUtils
|
||||
from dev_opsgpt.connector.configs.agent_config import REACT_PROMPT_INPUT, QUERY_CONTEXT_PROMPT_INPUT, PLAN_PROMPT_INPUT
|
||||
|
||||
from dev_opsgpt.llm_models import getChatModel, getExtraModel
|
||||
from dev_opsgpt.connector.utils import parse_section
|
||||
|
||||
|
||||
from coagent.connector.memory_manager import BaseMemoryManager
|
||||
from coagent.connector.configs.prompts import BEGIN_PROMPT_INPUT
|
||||
from coagent.connector.message_process import MessageUtils
|
||||
from coagent.llm_models import getChatModel, getExtraModel, LLMConfig, getChatModelFromConfig, EmbedConfig
|
||||
from coagent.connector.prompt_manager import PromptManager
|
||||
from coagent.connector.memory_manager import LocalMemoryManager
|
||||
from coagent.connector.utils import parse_section
|
||||
# from configs.model_config import JUPYTER_WORK_PATH
|
||||
# from configs.server_config import SANDBOX_SERVER
|
||||
|
||||
class BaseAgent:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
role: Role,
|
||||
prompt_config: [PromptField],
|
||||
prompt_manager_type: str = "PromptManager",
|
||||
task: Task = None,
|
||||
memory: Memory = None,
|
||||
chat_turn: int = 1,
|
||||
do_search: bool = False,
|
||||
do_doc_retrieval: bool = False,
|
||||
do_tool_retrieval: bool = False,
|
||||
temperature: float = 0.2,
|
||||
stop: Union[List[str], str] = None,
|
||||
do_filter: bool = True,
|
||||
do_use_self_memory: bool = True,
|
||||
focus_agents: List[str] = [],
|
||||
focus_message_keys: List[str] = [],
|
||||
# prompt_mamnger: PromptManager
|
||||
#
|
||||
llm_config: LLMConfig = None,
|
||||
embed_config: EmbedConfig = None,
|
||||
sandbox_server: dict = {},
|
||||
jupyter_work_path: str = "",
|
||||
kb_root_path: str = "",
|
||||
log_verbose: str = "0"
|
||||
):
|
||||
|
||||
self.task = task
|
||||
self.role = role
|
||||
self.message_utils = MessageUtils(role)
|
||||
self.llm = self.create_llm_engine(temperature, stop)
|
||||
self.sandbox_server = sandbox_server
|
||||
self.jupyter_work_path = jupyter_work_path
|
||||
self.kb_root_path = kb_root_path
|
||||
self.message_utils = MessageUtils(role, sandbox_server, jupyter_work_path, embed_config, llm_config, kb_root_path, log_verbose)
|
||||
self.memory = self.init_history(memory)
|
||||
self.llm_config: LLMConfig = llm_config
|
||||
self.embed_config: EmbedConfig = embed_config
|
||||
self.llm = self.create_llm_engine(llm_config=self.llm_config)
|
||||
self.chat_turn = chat_turn
|
||||
self.do_search = do_search
|
||||
self.do_doc_retrieval = do_doc_retrieval
|
||||
self.do_tool_retrieval = do_tool_retrieval
|
||||
#
|
||||
self.focus_agents = focus_agents
|
||||
self.focus_message_keys = focus_message_keys
|
||||
self.do_filter = do_filter
|
||||
self.do_use_self_memory = do_use_self_memory
|
||||
# self.prompt_manager = None
|
||||
#
|
||||
prompt_manager_module = importlib.import_module("coagent.connector.prompt_manager")
|
||||
prompt_manager = getattr(prompt_manager_module, prompt_manager_type)
|
||||
self.prompt_manager: PromptManager = prompt_manager(role_prompt=role.role_prompt, prompt_config=prompt_config)
|
||||
self.log_verbose = max(os.environ.get("log_verbose", "0"), log_verbose)
|
||||
|
||||
def run(self, query: Message, history: Memory = None, background: Memory = None, memory_pool: Memory=None) -> Message:
|
||||
def step(self, query: Message, history: Memory = None, background: Memory = None, memory_manager: BaseMemoryManager=None) -> Message:
|
||||
'''agent reponse from multi-message'''
|
||||
message = None
|
||||
for message in self.arun(query, history, background, memory_pool):
|
||||
for message in self.astep(query, history, background, memory_manager):
|
||||
pass
|
||||
return message
|
||||
|
||||
def arun(self, query: Message, history: Memory = None, background: Memory = None, memory_pool: Memory=None) -> Message:
|
||||
def astep(self, query: Message, history: Memory = None, background: Memory = None, memory_manager: BaseMemoryManager=None) -> Message:
|
||||
'''agent reponse from multi-message'''
|
||||
# insert query into memory
|
||||
query_c = copy.deepcopy(query)
|
||||
query_c = self.start_action_step(query_c)
|
||||
|
||||
self_memory = self.memory if self.do_use_self_memory else None
|
||||
# create your llm prompt
|
||||
prompt = self.create_prompt(query_c, self_memory, history, background, memory_pool=memory_pool)
|
||||
|
||||
# llm predict
|
||||
# prompt = self.create_prompt(query_c, self.memory, history, background, memory_pool=memory_manager.current_memory)
|
||||
if memory_manager is None:
|
||||
memory_manager = LocalMemoryManager(
|
||||
unique_name=self.role.role_name,
|
||||
do_init=True,
|
||||
kb_root_path = self.kb_root_path,
|
||||
embed_config=self.embed_config,
|
||||
llm_config=self.embed_config
|
||||
)
|
||||
memory_manager.append(query)
|
||||
memory_pool = memory_manager.current_memory
|
||||
else:
|
||||
memory_pool = memory_manager.current_memory
|
||||
|
||||
|
||||
logger.debug(f"memory_pool: {memory_pool}")
|
||||
prompt = self.prompt_manager.generate_full_prompt(
|
||||
previous_agent_message=query_c, agent_long_term_memory=self.memory, ui_history=history, chain_summary_messages=background, memory_pool=memory_pool)
|
||||
content = self.llm.predict(prompt)
|
||||
logger.debug(f"{self.role.role_name} prompt: {prompt}")
|
||||
logger.debug(f"{self.role.role_name} content: {content}")
|
||||
|
||||
if LogVerboseEnum.ge(LogVerboseEnum.Log2Level, self.log_verbose):
|
||||
logger.debug(f"{self.role.role_name} prompt: {prompt}")
|
||||
|
||||
if LogVerboseEnum.ge(LogVerboseEnum.Log1Level, self.log_verbose):
|
||||
logger.info(f"{self.role.role_name} content: {content}")
|
||||
|
||||
output_message = Message(
|
||||
role_name=self.role.role_name,
|
||||
role_type="ai", #self.role.role_type,
|
||||
role_type="assistant", #self.role.role_type,
|
||||
role_content=content,
|
||||
step_content=content,
|
||||
input_query=query_c.input_query,
|
||||
tools=query_c.tools,
|
||||
parsed_output_list=[query.parsed_output],
|
||||
# parsed_output_list=[query.parsed_output],
|
||||
customed_kargs=query_c.customed_kargs
|
||||
)
|
||||
|
||||
# common parse llm' content to message
|
||||
output_message = self.message_utils.parser(output_message)
|
||||
if self.do_filter:
|
||||
output_message = self.message_utils.filter(output_message)
|
||||
|
||||
# action step
|
||||
output_message, observation_message = self.message_utils.step_router(output_message, history, background, memory_pool=memory_pool)
|
||||
output_message, observation_message = self.message_utils.step_router(output_message, history, background, memory_manager=memory_manager)
|
||||
output_message.parsed_output_list.append(output_message.parsed_output)
|
||||
if observation_message:
|
||||
output_message.parsed_output_list.append(observation_message.parsed_output)
|
||||
|
||||
# update self_memory
|
||||
self.append_history(query_c)
|
||||
self.append_history(output_message)
|
||||
# logger.info(f"{self.role.role_name} currenct question: {output_message.input_query}\nllm_step_run: {output_message.role_content}")
|
||||
|
||||
output_message.input_query = output_message.role_content
|
||||
# output_message.parsed_output_list.append(output_message.parsed_output) # 与上述重复?
|
||||
# end
|
||||
output_message = self.message_utils.inherit_extrainfo(query, output_message)
|
||||
output_message = self.end_action_step(output_message)
|
||||
|
||||
# update memory pool
|
||||
memory_pool.append(output_message)
|
||||
memory_manager.append(output_message)
|
||||
yield output_message
|
||||
|
||||
def pre_print(self, query: Message, history: Memory = None, background: Memory = None, memory_manager: BaseMemoryManager=None):
|
||||
prompt = self.prompt_manager.pre_print(
|
||||
previous_agent_message=query, agent_long_term_memory=self.memory, ui_history=history, chain_summary_messages=background, memory_pool=memory_manager.current_memory)
|
||||
title = f"<<<<{self.role.role_name}'s prompt>>>>"
|
||||
print("#"*len(title) + f"\n{title}\n"+ "#"*len(title)+ f"\n\n{prompt}\n\n")
|
||||
|
||||
def init_history(self, memory: Memory = None) -> Memory:
|
||||
return Memory(messages=[])
|
||||
|
||||
def update_history(self, message: Message):
|
||||
self.memory.append(message)
|
||||
|
||||
def append_history(self, message: Message):
|
||||
self.memory.append(message)
|
||||
|
||||
def clear_history(self, ):
|
||||
self.memory.clear()
|
||||
self.memory = self.init_history()
|
||||
|
||||
def create_llm_engine(self, llm_config: LLMConfig = None, temperature=0.2, stop=None):
|
||||
if llm_config is None:
|
||||
return getChatModel(temperature=temperature, stop=stop)
|
||||
else:
|
||||
return getChatModelFromConfig(llm_config=llm_config)
|
||||
|
||||
def registry_actions(self, actions):
|
||||
'''registry llm's actions'''
|
||||
self.action_list = actions
|
||||
|
||||
def start_action_step(self, message: Message) -> Message:
|
||||
'''do action before agent predict '''
|
||||
# action_json = self.start_action()
|
||||
# message["customed_kargs"]["xx"] = action_json
|
||||
return message
|
||||
|
||||
def end_action_step(self, message: Message) -> Message:
|
||||
'''do action after agent predict '''
|
||||
# action_json = self.end_action()
|
||||
# message["customed_kargs"]["xx"] = action_json
|
||||
return message
|
||||
|
||||
def token_usage(self, ):
|
||||
'''calculate the usage of token'''
|
||||
pass
|
||||
|
||||
def select_memory_by_key(self, memory: Memory) -> Memory:
|
||||
return Memory(
|
||||
messages=[self.select_message_by_key(message) for message in memory.messages
|
||||
if self.select_message_by_key(message) is not None]
|
||||
)
|
||||
|
||||
def select_memory_by_agent_key(self, memory: Memory) -> Memory:
|
||||
return Memory(
|
||||
messages=[self.select_message_by_agent_key(message) for message in memory.messages
|
||||
if self.select_message_by_agent_key(message) is not None]
|
||||
)
|
||||
|
||||
def select_message_by_agent_key(self, message: Message) -> Message:
|
||||
# assume we focus all agents
|
||||
if self.focus_agents == []:
|
||||
return message
|
||||
return None if message is None or message.role_name not in self.focus_agents else self.select_message_by_key(message)
|
||||
|
||||
def select_message_by_key(self, message: Message) -> Message:
|
||||
# assume we focus all key contents
|
||||
if message is None:
|
||||
return message
|
||||
|
||||
if self.focus_message_keys == []:
|
||||
return message
|
||||
|
||||
message_c = copy.deepcopy(message)
|
||||
message_c.parsed_output = {k: v for k,v in message_c.parsed_output.items() if k in self.focus_message_keys}
|
||||
message_c.parsed_output_list = [{k: v for k,v in parsed_output.items() if k in self.focus_message_keys} for parsed_output in message_c.parsed_output_list]
|
||||
return message_c
|
||||
|
||||
def get_memory(self, content_key="role_content"):
|
||||
return self.memory.to_tuple_messages(content_key="step_content")
|
||||
|
||||
def get_memory_str(self, content_key="role_content"):
|
||||
return "\n".join([": ".join(i) for i in self.memory.to_tuple_messages(content_key="step_content")])
|
||||
|
||||
|
||||
def create_prompt(
|
||||
self, query: Message, memory: Memory =None, history: Memory = None, background: Memory = None, memory_pool: Memory=None, prompt_mamnger=None) -> str:
|
||||
|
@ -225,7 +330,7 @@ class BaseAgent:
|
|||
|
||||
# logger.debug(f"{self.role.role_name} prompt: {prompt}")
|
||||
return prompt
|
||||
|
||||
|
||||
def create_doc_prompt(self, message: Message) -> str:
|
||||
''''''
|
||||
db_docs = message.db_docs
|
||||
|
@ -274,76 +379,4 @@ class BaseAgent:
|
|||
if selfmemory_message:
|
||||
selfmemory_message = re.sub("}", "}}", re.sub("{", "{{", selfmemory_message))
|
||||
return "\n补充自身对话信息: " + selfmemory_message if selfmemory_message else None
|
||||
|
||||
def init_history(self, memory: Memory = None) -> Memory:
|
||||
return Memory(messages=[])
|
||||
|
||||
def update_history(self, message: Message):
|
||||
self.memory.append(message)
|
||||
|
||||
def append_history(self, message: Message):
|
||||
self.memory.append(message)
|
||||
|
||||
def clear_history(self, ):
|
||||
self.memory.clear()
|
||||
self.memory = self.init_history()
|
||||
|
||||
def create_llm_engine(self, temperature=0.2, stop=None):
|
||||
return getChatModel(temperature=temperature, stop=stop)
|
||||
|
||||
def registry_actions(self, actions):
|
||||
'''registry llm's actions'''
|
||||
self.action_list = actions
|
||||
|
||||
def start_action_step(self, message: Message) -> Message:
|
||||
'''do action before agent predict '''
|
||||
# action_json = self.start_action()
|
||||
# message["customed_kargs"]["xx"] = action_json
|
||||
return message
|
||||
|
||||
def end_action_step(self, message: Message) -> Message:
|
||||
'''do action after agent predict '''
|
||||
# action_json = self.end_action()
|
||||
# message["customed_kargs"]["xx"] = action_json
|
||||
return message
|
||||
|
||||
def token_usage(self, ):
|
||||
'''calculate the usage of token'''
|
||||
pass
|
||||
|
||||
def select_memory_by_key(self, memory: Memory) -> Memory:
|
||||
return Memory(
|
||||
messages=[self.select_message_by_key(message) for message in memory.messages
|
||||
if self.select_message_by_key(message) is not None]
|
||||
)
|
||||
|
||||
def select_memory_by_agent_key(self, memory: Memory) -> Memory:
|
||||
return Memory(
|
||||
messages=[self.select_message_by_agent_key(message) for message in memory.messages
|
||||
if self.select_message_by_agent_key(message) is not None]
|
||||
)
|
||||
|
||||
def select_message_by_agent_key(self, message: Message) -> Message:
|
||||
# assume we focus all agents
|
||||
if self.focus_agents == []:
|
||||
return message
|
||||
return None if message is None or message.role_name not in self.focus_agents else self.select_message_by_key(message)
|
||||
|
||||
def select_message_by_key(self, message: Message) -> Message:
|
||||
# assume we focus all key contents
|
||||
if message is None:
|
||||
return message
|
||||
|
||||
if self.focus_message_keys == []:
|
||||
return message
|
||||
|
||||
message_c = copy.deepcopy(message)
|
||||
message_c.parsed_output = {k: v for k,v in message_c.parsed_output.items() if k in self.focus_message_keys}
|
||||
message_c.parsed_output_list = [{k: v for k,v in parsed_output.items() if k in self.focus_message_keys} for parsed_output in message_c.parsed_output_list]
|
||||
return message_c
|
||||
|
||||
def get_memory(self, content_key="role_content"):
|
||||
return self.memory.to_tuple_messages(content_key="step_content")
|
||||
|
||||
def get_memory_str(self, content_key="role_content"):
|
||||
return "\n".join([": ".join(i) for i in self.memory.to_tuple_messages(content_key="step_content")])
|
||||
|
|
@ -0,0 +1,152 @@
|
|||
from typing import List, Union
|
||||
import copy
|
||||
from loguru import logger
|
||||
|
||||
from coagent.connector.schema import (
|
||||
Memory, Task, Env, Role, Message, ActionStatus, PromptField, LogVerboseEnum
|
||||
)
|
||||
from coagent.connector.memory_manager import BaseMemoryManager
|
||||
from coagent.connector.configs.prompts import BEGIN_PROMPT_INPUT
|
||||
from coagent.llm_models import LLMConfig, EmbedConfig
|
||||
from coagent.connector.memory_manager import LocalMemoryManager
|
||||
|
||||
from .base_agent import BaseAgent
|
||||
|
||||
|
||||
class ExecutorAgent(BaseAgent):
|
||||
def __init__(
|
||||
self,
|
||||
role: Role,
|
||||
prompt_config: [PromptField],
|
||||
prompt_manager_type: str= "PromptManager",
|
||||
task: Task = None,
|
||||
memory: Memory = None,
|
||||
chat_turn: int = 1,
|
||||
focus_agents: List[str] = [],
|
||||
focus_message_keys: List[str] = [],
|
||||
#
|
||||
llm_config: LLMConfig = None,
|
||||
embed_config: EmbedConfig = None,
|
||||
sandbox_server: dict = {},
|
||||
jupyter_work_path: str = "",
|
||||
kb_root_path: str = "",
|
||||
log_verbose: str = "0"
|
||||
):
|
||||
|
||||
super().__init__(role, prompt_config, prompt_manager_type, task, memory, chat_turn,
|
||||
focus_agents, focus_message_keys, llm_config, embed_config, sandbox_server,
|
||||
jupyter_work_path, kb_root_path, log_verbose
|
||||
)
|
||||
self.do_all_task = True # run all tasks
|
||||
|
||||
def astep(self, query: Message, history: Memory = None, background: Memory = None, memory_manager: BaseMemoryManager=None) -> Message:
|
||||
'''agent reponse from multi-message'''
|
||||
# insert query into memory
|
||||
task_executor_memory = Memory(messages=[])
|
||||
# insert query
|
||||
output_message = Message(
|
||||
role_name=self.role.role_name,
|
||||
role_type="assistant", #self.role.role_type,
|
||||
role_content=query.input_query,
|
||||
step_content="",
|
||||
input_query=query.input_query,
|
||||
tools=query.tools,
|
||||
# parsed_output_list=[query.parsed_output],
|
||||
customed_kargs=query.customed_kargs
|
||||
)
|
||||
|
||||
if memory_manager is None:
|
||||
memory_manager = LocalMemoryManager(
|
||||
unique_name=self.role.role_name,
|
||||
do_init=True,
|
||||
kb_root_path = self.kb_root_path,
|
||||
embed_config=self.embed_config,
|
||||
llm_config=self.embed_config
|
||||
)
|
||||
memory_manager.append(query)
|
||||
|
||||
# self_memory = self.memory if self.do_use_self_memory else None
|
||||
|
||||
plan_step = int(query.parsed_output.get("PLAN_STEP", 0))
|
||||
# 如果存在plan字段且plan字段为str的时候
|
||||
if "PLAN" not in query.parsed_output or isinstance(query.parsed_output.get("PLAN", []), str) or plan_step >= len(query.parsed_output.get("PLAN", [])):
|
||||
query_c = copy.deepcopy(query)
|
||||
query_c = self.start_action_step(query_c)
|
||||
query_c.parsed_output = {"CURRENT_STEP": query_c.input_query}
|
||||
task_executor_memory.append(query_c)
|
||||
for output_message, task_executor_memory in self._arun_step(output_message, query_c, self.memory, history, background, memory_manager, task_executor_memory):
|
||||
pass
|
||||
# task_executor_memory.append(query_c)
|
||||
# content = "the execution step of the plan is exceed the planned scope."
|
||||
# output_message.parsed_dict = {"Thought": content, "Action Status": "finished", "Action": content}
|
||||
# task_executor_memory.append(output_message)
|
||||
|
||||
elif "PLAN" in query.parsed_output:
|
||||
if self.do_all_task:
|
||||
# run all tasks step by step
|
||||
for task_content in query.parsed_output["PLAN"][plan_step:]:
|
||||
# create your llm prompt
|
||||
query_c = copy.deepcopy(query)
|
||||
query_c.parsed_output = {"CURRENT_STEP": task_content}
|
||||
task_executor_memory.append(query_c)
|
||||
for output_message, task_executor_memory in self._arun_step(output_message, query_c, self.memory, history, background, memory_manager, task_executor_memory):
|
||||
pass
|
||||
yield output_message
|
||||
else:
|
||||
query_c = copy.deepcopy(query)
|
||||
query_c = self.start_action_step(query_c)
|
||||
task_content = query_c.parsed_output["PLAN"][plan_step]
|
||||
query_c.parsed_output = {"CURRENT_STEP": task_content}
|
||||
task_executor_memory.append(query_c)
|
||||
for output_message, task_executor_memory in self._arun_step(output_message, query_c, self.memory, history, background, memory_manager, task_executor_memory):
|
||||
pass
|
||||
output_message.parsed_output.update({"CURRENT_STEP": plan_step})
|
||||
# update self_memory
|
||||
self.append_history(query)
|
||||
self.append_history(output_message)
|
||||
output_message.input_query = output_message.role_content
|
||||
# end_action_step
|
||||
output_message = self.end_action_step(output_message)
|
||||
# update memory pool
|
||||
memory_manager.append(output_message)
|
||||
yield output_message
|
||||
|
||||
def _arun_step(self, output_message: Message, query: Message, self_memory: Memory,
|
||||
history: Memory, background: Memory, memory_manager: BaseMemoryManager,
|
||||
task_memory: Memory) -> Union[Message, Memory]:
|
||||
'''execute the llm predict by created prompt'''
|
||||
memory_pool = memory_manager.current_memory
|
||||
prompt = self.prompt_manager.generate_full_prompt(
|
||||
previous_agent_message=query, agent_long_term_memory=self_memory, ui_history=history, chain_summary_messages=background, memory_pool=memory_pool,
|
||||
task_memory=task_memory)
|
||||
content = self.llm.predict(prompt)
|
||||
|
||||
if LogVerboseEnum.ge(LogVerboseEnum.Log2Level, self.log_verbose):
|
||||
logger.debug(f"{self.role.role_name} prompt: {prompt}")
|
||||
|
||||
if LogVerboseEnum.ge(LogVerboseEnum.Log1Level, self.log_verbose):
|
||||
logger.info(f"{self.role.role_name} content: {content}")
|
||||
|
||||
output_message.role_content = content
|
||||
output_message.step_content += "\n"+output_message.role_content
|
||||
output_message = self.message_utils.parser(output_message)
|
||||
# according the output to choose one action for code_content or tool_content
|
||||
output_message, observation_message = self.message_utils.step_router(output_message)
|
||||
# update parserd_output_list
|
||||
output_message.parsed_output_list.append(output_message.parsed_output)
|
||||
|
||||
react_message = copy.deepcopy(output_message)
|
||||
task_memory.append(react_message)
|
||||
if observation_message:
|
||||
task_memory.append(observation_message)
|
||||
output_message.parsed_output_list.append(observation_message.parsed_output)
|
||||
# logger.debug(f"{observation_message.role_name} content: {observation_message.role_content}")
|
||||
yield output_message, task_memory
|
||||
|
||||
def pre_print(self, query: Message, history: Memory = None, background: Memory = None, memory_manager: BaseMemoryManager = None):
|
||||
task_memory = Memory(messages=[])
|
||||
prompt = self.prompt_manager.pre_print(
|
||||
previous_agent_message=query, agent_long_term_memory=self.memory, ui_history=history, chain_summary_messages=background, react_memory=None,
|
||||
memory_pool=memory_manager.current_memory, task_memory=task_memory)
|
||||
title = f"<<<<{self.role.role_name}'s prompt>>>>"
|
||||
print("#"*len(title) + f"\n{title}\n"+ "#"*len(title)+ f"\n\n{prompt}\n\n")
|
|
@ -1,96 +1,117 @@
|
|||
from pydantic import BaseModel
|
||||
from typing import List, Union
|
||||
import re
|
||||
import json
|
||||
import traceback
|
||||
import copy
|
||||
from loguru import logger
|
||||
|
||||
from langchain.prompts.chat import ChatPromptTemplate
|
||||
|
||||
from dev_opsgpt.connector.schema import (
|
||||
Memory, Task, Env, Role, Message, ActionStatus
|
||||
from coagent.connector.schema import (
|
||||
Memory, Task, Env, Role, Message, ActionStatus, PromptField, LogVerboseEnum
|
||||
)
|
||||
from dev_opsgpt.llm_models import getChatModel
|
||||
from dev_opsgpt.connector.configs.agent_config import REACT_PROMPT_INPUT
|
||||
|
||||
from coagent.connector.memory_manager import BaseMemoryManager
|
||||
from coagent.connector.configs.agent_config import REACT_PROMPT_INPUT
|
||||
from coagent.llm_models import LLMConfig, EmbedConfig
|
||||
from .base_agent import BaseAgent
|
||||
from coagent.connector.memory_manager import LocalMemoryManager
|
||||
|
||||
from coagent.connector.prompt_manager import PromptManager
|
||||
|
||||
|
||||
class ReactAgent(BaseAgent):
|
||||
def __init__(
|
||||
self,
|
||||
role: Role,
|
||||
prompt_config: [PromptField],
|
||||
prompt_manager_type: str = "PromptManager",
|
||||
task: Task = None,
|
||||
memory: Memory = None,
|
||||
chat_turn: int = 1,
|
||||
do_search: bool = False,
|
||||
do_doc_retrieval: bool = False,
|
||||
do_tool_retrieval: bool = False,
|
||||
temperature: float = 0.2,
|
||||
stop: Union[List[str], str] = None,
|
||||
do_filter: bool = True,
|
||||
do_use_self_memory: bool = True,
|
||||
focus_agents: List[str] = [],
|
||||
focus_message_keys: List[str] = [],
|
||||
# prompt_mamnger: PromptManager
|
||||
#
|
||||
llm_config: LLMConfig = None,
|
||||
embed_config: EmbedConfig = None,
|
||||
sandbox_server: dict = {},
|
||||
jupyter_work_path: str = "",
|
||||
kb_root_path: str = "",
|
||||
log_verbose: str = "0"
|
||||
):
|
||||
|
||||
super().__init__(role, task, memory, chat_turn, do_search, do_doc_retrieval,
|
||||
do_tool_retrieval, temperature, stop, do_filter,do_use_self_memory,
|
||||
focus_agents, focus_message_keys
|
||||
super().__init__(role, prompt_config, prompt_manager_type, task, memory, chat_turn,
|
||||
focus_agents, focus_message_keys, llm_config, embed_config, sandbox_server,
|
||||
jupyter_work_path, kb_root_path, log_verbose
|
||||
)
|
||||
|
||||
def run(self, query: Message, history: Memory = None, background: Memory = None, memory_pool: Memory = None) -> Message:
|
||||
def step(self, query: Message, history: Memory = None, background: Memory = None, memory_manager: BaseMemoryManager = None) -> Message:
|
||||
'''agent reponse from multi-message'''
|
||||
for message in self.arun(query, history, background, memory_pool):
|
||||
for message in self.astep(query, history, background, memory_manager):
|
||||
pass
|
||||
return message
|
||||
|
||||
def arun(self, query: Message, history: Memory = None, background: Memory = None, memory_pool: Memory = None) -> Message:
|
||||
def astep(self, query: Message, history: Memory = None, background: Memory = None, memory_manager: BaseMemoryManager = None) -> Message:
|
||||
'''agent reponse from multi-message'''
|
||||
step_nums = copy.deepcopy(self.chat_turn)
|
||||
react_memory = Memory(messages=[])
|
||||
# insert query
|
||||
output_message = Message(
|
||||
role_name=self.role.role_name,
|
||||
role_type="ai", #self.role.role_type,
|
||||
role_type="assistant", #self.role.role_type,
|
||||
role_content=query.input_query,
|
||||
step_content="",
|
||||
input_query=query.input_query,
|
||||
tools=query.tools,
|
||||
parsed_output_list=[query.parsed_output],
|
||||
# parsed_output_list=[query.parsed_output],
|
||||
customed_kargs=query.customed_kargs
|
||||
)
|
||||
query_c = copy.deepcopy(query)
|
||||
query_c = self.start_action_step(query_c)
|
||||
if query.parsed_output:
|
||||
query_c.parsed_output = {"Question": "\n".join([f"{v}" for k, v in query.parsed_output.items() if k not in ["Action Status"]])}
|
||||
else:
|
||||
query_c.parsed_output = {"Question": query.input_query}
|
||||
react_memory.append(query_c)
|
||||
self_memory = self.memory if self.do_use_self_memory else None
|
||||
# if query.parsed_output:
|
||||
# query_c.parsed_output = {"Question": "\n".join([f"{v}" for k, v in query.parsed_output.items() if k not in ["Action Status"]])}
|
||||
# else:
|
||||
# query_c.parsed_output = {"Question": query.input_query}
|
||||
# react_memory.append(query_c)
|
||||
# self_memory = self.memory if self.do_use_self_memory else None
|
||||
idx = 0
|
||||
# start to react
|
||||
while step_nums > 0:
|
||||
output_message.role_content = output_message.step_content
|
||||
prompt = self.create_prompt(query, self_memory, history, background, react_memory, memory_pool)
|
||||
# prompt = self.create_prompt(query, self.memory, history, background, react_memory, memory_manager.current_memory)
|
||||
|
||||
if memory_manager is None:
|
||||
memory_manager = LocalMemoryManager(
|
||||
unique_name=self.role.role_name,
|
||||
do_init=True,
|
||||
kb_root_path = self.kb_root_path,
|
||||
embed_config=self.embed_config,
|
||||
llm_config=self.embed_config
|
||||
)
|
||||
memory_manager.append(query)
|
||||
memory_pool = memory_manager.current_memory
|
||||
else:
|
||||
memory_pool = memory_manager.current_memory
|
||||
|
||||
prompt = self.prompt_manager.generate_full_prompt(
|
||||
previous_agent_message=query_c, agent_long_term_memory=self.memory, ui_history=history, chain_summary_messages=background, react_memory=react_memory,
|
||||
memory_pool=memory_pool)
|
||||
try:
|
||||
content = self.llm.predict(prompt)
|
||||
except Exception as e:
|
||||
logger.warning(f"error prompt: {prompt}")
|
||||
logger.error(f"error prompt: {prompt}")
|
||||
raise Exception(traceback.format_exc())
|
||||
|
||||
output_message.role_content = "\n"+content
|
||||
output_message.step_content += "\n"+output_message.role_content
|
||||
yield output_message
|
||||
|
||||
# logger.debug(f"{self.role.role_name}, {idx} iteration prompt: {prompt}")
|
||||
logger.info(f"{self.role.role_name}, {idx} iteration step_run: {output_message.role_content}")
|
||||
if LogVerboseEnum.ge(LogVerboseEnum.Log2Level, self.log_verbose):
|
||||
logger.debug(f"{self.role.role_name}, {idx} iteration prompt: {prompt}")
|
||||
|
||||
if LogVerboseEnum.ge(LogVerboseEnum.Log1Level, self.log_verbose):
|
||||
logger.info(f"{self.role.role_name}, {idx} iteration step_run: {output_message.role_content}")
|
||||
|
||||
output_message = self.message_utils.parser(output_message)
|
||||
# when get finished signal can stop early
|
||||
if output_message.action_status == ActionStatus.FINISHED or output_message.action_status == ActionStatus.STOPED: break
|
||||
if output_message.action_status == ActionStatus.FINISHED or output_message.action_status == ActionStatus.STOPPED:
|
||||
output_message.parsed_output_list.append(output_message.parsed_output)
|
||||
break
|
||||
# according the output to choose one action for code_content or tool_content
|
||||
output_message, observation_message = self.message_utils.step_router(output_message)
|
||||
output_message.parsed_output_list.append(output_message.parsed_output)
|
||||
|
@ -101,21 +122,45 @@ class ReactAgent(BaseAgent):
|
|||
react_memory.append(observation_message)
|
||||
output_message.parsed_output_list.append(observation_message.parsed_output)
|
||||
# logger.debug(f"{observation_message.role_name} content: {observation_message.role_content}")
|
||||
# logger.info(f"{self.role.role_name} currenct question: {output_message.input_query}\nllm_react_run: {output_message.role_content}")
|
||||
|
||||
idx += 1
|
||||
step_nums -= 1
|
||||
yield output_message
|
||||
# react' self_memory saved at last
|
||||
self.append_history(output_message)
|
||||
# update memory pool
|
||||
# memory_pool.append(output_message)
|
||||
output_message.input_query = query.input_query
|
||||
# end_action_step
|
||||
# end_action_step, BUG:it may cause slack some information
|
||||
output_message = self.end_action_step(output_message)
|
||||
# update memory pool
|
||||
memory_pool.append(output_message)
|
||||
memory_manager.append(output_message)
|
||||
yield output_message
|
||||
|
||||
def pre_print(self, query: Message, history: Memory = None, background: Memory = None, memory_manager: BaseMemoryManager=None):
|
||||
react_memory = Memory(messages=[])
|
||||
prompt = self.prompt_manager.pre_print(
|
||||
previous_agent_message=query, agent_long_term_memory=self.memory, ui_history=history, chain_summary_messages=background, react_memory=react_memory,
|
||||
memory_pool=memory_manager.current_memory)
|
||||
title = f"<<<<{self.role.role_name}'s prompt>>>>"
|
||||
print("#"*len(title) + f"\n{title}\n"+ "#"*len(title)+ f"\n\n{prompt}\n\n")
|
||||
|
||||
# def create_prompt(
|
||||
# self, query: Message, memory: Memory =None, history: Memory = None, background: Memory = None, react_memory: Memory = None, memory_manager: BaseMemoryManager= None,
|
||||
# prompt_mamnger=None) -> str:
|
||||
# prompt_mamnger = PromptManager()
|
||||
# prompt_mamnger.register_standard_fields()
|
||||
|
||||
# # input_keys = parse_section(self.role.role_prompt, 'Agent Profile')
|
||||
|
||||
# data_dict = {
|
||||
# "agent_profile": extract_section(self.role.role_prompt, 'Agent Profile'),
|
||||
# "tool_information": query.tools,
|
||||
# "session_records": memory_manager,
|
||||
# "reference_documents": query,
|
||||
# "output_format": extract_section(self.role.role_prompt, 'Response Output Format'),
|
||||
# "response": "\n".join(["\n".join([f"**{k}:**\n{v}" for k,v in _dict.items()]) for _dict in react_memory.get_parserd_output()]),
|
||||
# }
|
||||
# # logger.debug(memory_pool)
|
||||
|
||||
# return prompt_mamnger.generate_full_prompt(data_dict)
|
||||
|
||||
def create_prompt(
|
||||
self, query: Message, memory: Memory =None, history: Memory = None, background: Memory = None, react_memory: Memory = None, memory_pool: Memory= None,
|
|
@ -0,0 +1,190 @@
|
|||
from typing import List, Union
|
||||
import copy
|
||||
import random
|
||||
from loguru import logger
|
||||
|
||||
from coagent.connector.schema import (
|
||||
Memory, Task, Role, Message, PromptField, LogVerboseEnum
|
||||
)
|
||||
from coagent.connector.memory_manager import BaseMemoryManager
|
||||
from coagent.connector.configs.prompts import BEGIN_PROMPT_INPUT
|
||||
from coagent.connector.memory_manager import LocalMemoryManager
|
||||
from coagent.llm_models import LLMConfig, EmbedConfig
|
||||
from .base_agent import BaseAgent
|
||||
|
||||
|
||||
class SelectorAgent(BaseAgent):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
role: Role,
|
||||
prompt_config: List[PromptField] = None,
|
||||
prompt_manager_type: str = "PromptManager",
|
||||
task: Task = None,
|
||||
memory: Memory = None,
|
||||
chat_turn: int = 1,
|
||||
focus_agents: List[str] = [],
|
||||
focus_message_keys: List[str] = [],
|
||||
group_agents: List[BaseAgent] = [],
|
||||
#
|
||||
llm_config: LLMConfig = None,
|
||||
embed_config: EmbedConfig = None,
|
||||
sandbox_server: dict = {},
|
||||
jupyter_work_path: str = "",
|
||||
kb_root_path: str = "",
|
||||
log_verbose: str = "0"
|
||||
):
|
||||
|
||||
super().__init__(role, prompt_config, prompt_manager_type, task, memory, chat_turn,
|
||||
focus_agents, focus_message_keys, llm_config, embed_config, sandbox_server,
|
||||
jupyter_work_path, kb_root_path, log_verbose
|
||||
)
|
||||
self.group_agents = group_agents
|
||||
|
||||
def astep(self, query: Message, history: Memory = None, background: Memory = None, memory_manager: BaseMemoryManager=None) -> Message:
|
||||
'''agent reponse from multi-message'''
|
||||
# insert query into memory
|
||||
query_c = copy.deepcopy(query)
|
||||
query_c = self.start_action_step(query_c)
|
||||
# create your llm prompt
|
||||
if memory_manager is None:
|
||||
memory_manager = LocalMemoryManager(
|
||||
unique_name=self.role.role_name,
|
||||
do_init=True,
|
||||
kb_root_path = self.kb_root_path,
|
||||
embed_config=self.embed_config,
|
||||
llm_config=self.embed_config
|
||||
)
|
||||
memory_manager.append(query)
|
||||
memory_pool = memory_manager.current_memory
|
||||
else:
|
||||
memory_pool = memory_manager.current_memory
|
||||
prompt = self.prompt_manager.generate_full_prompt(
|
||||
previous_agent_message=query_c, agent_long_term_memory=self.memory, ui_history=history, chain_summary_messages=background, react_memory=None,
|
||||
memory_pool=memory_pool, agents=self.group_agents)
|
||||
content = self.llm.predict(prompt)
|
||||
|
||||
if LogVerboseEnum.ge(LogVerboseEnum.Log2Level, self.log_verbose):
|
||||
logger.debug(f"{self.role.role_name} prompt: {prompt}")
|
||||
|
||||
if LogVerboseEnum.ge(LogVerboseEnum.Log1Level, self.log_verbose):
|
||||
logger.info(f"{self.role.role_name} content: {content}")
|
||||
|
||||
# select agent
|
||||
select_message = Message(
|
||||
role_name=self.role.role_name,
|
||||
role_type="assistant", #self.role.role_type,
|
||||
role_content=content,
|
||||
step_content=content,
|
||||
input_query=query_c.input_query,
|
||||
tools=query_c.tools,
|
||||
# parsed_output_list=[query_c.parsed_output]
|
||||
customed_kargs=query.customed_kargs
|
||||
)
|
||||
# common parse llm' content to message
|
||||
select_message = self.message_utils.parser(select_message)
|
||||
select_message.parsed_output_list.append(select_message.parsed_output)
|
||||
|
||||
output_message = None
|
||||
if select_message.parsed_output.get("Role", "") in [agent.role.role_name for agent in self.group_agents]:
|
||||
for agent in self.group_agents:
|
||||
if agent.role.role_name == select_message.parsed_output.get("Role", ""):
|
||||
break
|
||||
for output_message in agent.astep(query_c, history, background=background, memory_manager=memory_manager):
|
||||
yield output_message or select_message
|
||||
# update self_memory
|
||||
self.append_history(query_c)
|
||||
self.append_history(output_message)
|
||||
output_message.input_query = output_message.role_content
|
||||
# output_message.parsed_output_list.append(output_message.parsed_output)
|
||||
#
|
||||
output_message = self.end_action_step(output_message)
|
||||
# update memory pool
|
||||
memory_manager.append(output_message)
|
||||
|
||||
select_message.parsed_output = output_message.parsed_output
|
||||
select_message.parsed_output_list.extend(output_message.parsed_output_list)
|
||||
yield select_message
|
||||
|
||||
def pre_print(self, query: Message, history: Memory = None, background: Memory = None, memory_manager: BaseMemoryManager=None):
|
||||
prompt = self.prompt_manager.pre_print(
|
||||
previous_agent_message=query, agent_long_term_memory=self.memory, ui_history=history, chain_summary_messages=background, react_memory=None,
|
||||
memory_pool=memory_manager.current_memory, agents=self.group_agents)
|
||||
title = f"<<<<{self.role.role_name}'s prompt>>>>"
|
||||
print("#"*len(title) + f"\n{title}\n"+ "#"*len(title)+ f"\n\n{prompt}\n\n")
|
||||
|
||||
for agent in self.group_agents:
|
||||
agent.pre_print(query=query, history=history, background=background, memory_manager=memory_manager)
|
||||
|
||||
# def create_prompt(
|
||||
# self, query: Message, memory: Memory =None, history: Memory = None, background: Memory = None, memory_manager: BaseMemoryManager=None, prompt_mamnger=None) -> str:
|
||||
# '''
|
||||
# role\task\tools\docs\memory
|
||||
# '''
|
||||
# #
|
||||
# doc_infos = self.create_doc_prompt(query)
|
||||
# code_infos = self.create_codedoc_prompt(query)
|
||||
# #
|
||||
# formatted_tools, tool_names, tools_descs = self.create_tools_prompt(query)
|
||||
# agent_names, agents = self.create_agent_names()
|
||||
# task_prompt = self.create_task_prompt(query)
|
||||
# background_prompt = self.create_background_prompt(background)
|
||||
# history_prompt = self.create_history_prompt(history)
|
||||
# selfmemory_prompt = self.create_selfmemory_prompt(memory, control_key="step_content")
|
||||
|
||||
|
||||
# DocInfos = ""
|
||||
# if doc_infos is not None and doc_infos!="" and doc_infos!="不存在知识库辅助信息":
|
||||
# DocInfos += f"\nDocument Information: {doc_infos}"
|
||||
|
||||
# if code_infos is not None and code_infos!="" and code_infos!="不存在代码库辅助信息":
|
||||
# DocInfos += f"\nCodeBase Infomation: {code_infos}"
|
||||
|
||||
# input_query = query.input_query
|
||||
# logger.debug(f"{self.role.role_name} input_query: {input_query}")
|
||||
# prompt = self.role.role_prompt.format(**{"agent_names": agent_names, "agents": agents, "formatted_tools": tools_descs, "tool_names": tool_names})
|
||||
# #
|
||||
# memory_pool_select_by_agent_key = self.select_memory_by_agent_key(memory_manager.current_memory)
|
||||
# memory_pool_select_by_agent_key_context = '\n\n'.join([f"*{k}*\n{v}" for parsed_output in memory_pool_select_by_agent_key.get_parserd_output_list() for k, v in parsed_output.items() if k not in ['Action Status']])
|
||||
|
||||
# input_keys = parse_section(self.role.role_prompt, 'Input Format')
|
||||
# #
|
||||
# prompt += "\n" + BEGIN_PROMPT_INPUT
|
||||
# for input_key in input_keys:
|
||||
# if input_key == "Origin Query":
|
||||
# prompt += "\n**Origin Query:**\n" + query.origin_query
|
||||
# elif input_key == "Context":
|
||||
# context = "\n".join([f"*{k}*\n{v}" for i in query.parsed_output_list for k,v in i.items() if "Action Status" !=k])
|
||||
# if history:
|
||||
# context = history_prompt + "\n" + context
|
||||
# if not context:
|
||||
# context = "there is no context"
|
||||
|
||||
# if self.focus_agents and memory_pool_select_by_agent_key_context:
|
||||
# context = memory_pool_select_by_agent_key_context
|
||||
# prompt += "\n**Context:**\n" + context + "\n" + input_query
|
||||
# elif input_key == "DocInfos":
|
||||
# prompt += "\n**DocInfos:**\n" + DocInfos
|
||||
# elif input_key == "Question":
|
||||
# prompt += "\n**Question:**\n" + input_query
|
||||
|
||||
# while "{{" in prompt or "}}" in prompt:
|
||||
# prompt = prompt.replace("{{", "{")
|
||||
# prompt = prompt.replace("}}", "}")
|
||||
|
||||
# # logger.debug(f"{self.role.role_name} prompt: {prompt}")
|
||||
# return prompt
|
||||
|
||||
# def create_agent_names(self):
|
||||
# random.shuffle(self.group_agents)
|
||||
# agent_names = ", ".join([f'{agent.role.role_name}' for agent in self.group_agents])
|
||||
# agent_descs = []
|
||||
# for agent in self.group_agents:
|
||||
# role_desc = agent.role.role_prompt.split("####")[1]
|
||||
# while "\n\n" in role_desc:
|
||||
# role_desc = role_desc.replace("\n\n", "\n")
|
||||
# role_desc = role_desc.replace("\n", ",")
|
||||
|
||||
# agent_descs.append(f'"role name: {agent.role.role_name}\nrole description: {role_desc}"')
|
||||
|
||||
# return agent_names, "\n".join(agent_descs)
|
|
@ -1,56 +1,68 @@
|
|||
from pydantic import BaseModel
|
||||
from typing import List
|
||||
import json
|
||||
import re
|
||||
from loguru import logger
|
||||
import traceback
|
||||
import uuid
|
||||
import copy
|
||||
import copy, os
|
||||
|
||||
from dev_opsgpt.connector.agents import BaseAgent, CheckAgent
|
||||
from dev_opsgpt.tools.base_tool import BaseTools, Tool
|
||||
from coagent.connector.agents import BaseAgent
|
||||
|
||||
from dev_opsgpt.connector.schema import (
|
||||
from coagent.connector.schema import (
|
||||
Memory, Role, Message, ActionStatus, ChainConfig,
|
||||
load_role_configs
|
||||
)
|
||||
from dev_opsgpt.connector.message_process import MessageUtils
|
||||
|
||||
from dev_opsgpt.connector.configs.agent_config import AGETN_CONFIGS
|
||||
from coagent.connector.memory_manager import BaseMemoryManager
|
||||
from coagent.connector.message_process import MessageUtils
|
||||
from coagent.llm_models.llm_config import LLMConfig, EmbedConfig
|
||||
from coagent.connector.configs.agent_config import AGETN_CONFIGS
|
||||
role_configs = load_role_configs(AGETN_CONFIGS)
|
||||
|
||||
# from configs.model_config import JUPYTER_WORK_PATH
|
||||
# from configs.server_config import SANDBOX_SERVER
|
||||
|
||||
|
||||
class BaseChain:
|
||||
def __init__(
|
||||
self,
|
||||
chainConfig: ChainConfig,
|
||||
# chainConfig: ChainConfig,
|
||||
agents: List[BaseAgent],
|
||||
chat_turn: int = 1,
|
||||
do_checker: bool = False,
|
||||
# prompt_mamnger: PromptManager
|
||||
sandbox_server: dict = {},
|
||||
jupyter_work_path: str = "",
|
||||
kb_root_path: str = "",
|
||||
llm_config: LLMConfig = LLMConfig(),
|
||||
embed_config: EmbedConfig = None,
|
||||
log_verbose: str = "0"
|
||||
) -> None:
|
||||
self.chainConfig = chainConfig
|
||||
self.agents = agents
|
||||
# self.chainConfig = chainConfig
|
||||
self.agents: List[BaseAgent] = agents
|
||||
self.chat_turn = chat_turn
|
||||
self.do_checker = do_checker
|
||||
self.checker = CheckAgent(role=role_configs["checker"].role,
|
||||
task = None,
|
||||
memory = None,
|
||||
do_search = role_configs["checker"].do_search,
|
||||
do_doc_retrieval = role_configs["checker"].do_doc_retrieval,
|
||||
do_tool_retrieval = role_configs["checker"].do_tool_retrieval,
|
||||
do_filter=False, do_use_self_memory=False)
|
||||
self.messageUtils = MessageUtils()
|
||||
self.sandbox_server = sandbox_server
|
||||
self.jupyter_work_path = jupyter_work_path
|
||||
self.llm_config = llm_config
|
||||
self.log_verbose = max(os.environ.get("log_verbose", "0"), log_verbose)
|
||||
self.checker = BaseAgent(role=role_configs["checker"].role,
|
||||
prompt_config=role_configs["checker"].prompt_config,
|
||||
task = None, memory = None,
|
||||
llm_config=llm_config, embed_config=embed_config,
|
||||
sandbox_server=sandbox_server, jupyter_work_path=jupyter_work_path,
|
||||
kb_root_path=kb_root_path
|
||||
)
|
||||
self.messageUtils = MessageUtils(None, sandbox_server, self.jupyter_work_path, embed_config, llm_config, kb_root_path, log_verbose)
|
||||
# all memory created by agent until instance deleted
|
||||
self.global_memory = Memory(messages=[])
|
||||
|
||||
def step(self, query: Message, history: Memory = None, background: Memory = None, memory_pool: Memory = None) -> Message:
|
||||
def step(self, query: Message, history: Memory = None, background: Memory = None, memory_manager: BaseMemoryManager = None) -> Message:
|
||||
'''execute chain'''
|
||||
for output_message, local_memory in self.astep(query, history, background, memory_pool):
|
||||
for output_message, local_memory in self.astep(query, history, background, memory_manager):
|
||||
pass
|
||||
return output_message, local_memory
|
||||
|
||||
def pre_print(self, query: Message, history: Memory = None, background: Memory = None, memory_manager: BaseMemoryManager = None) -> Message:
|
||||
'''execute chain'''
|
||||
for agent in self.agents:
|
||||
agent.pre_print(query, history, background=background, memory_manager=memory_manager)
|
||||
|
||||
def astep(self, query: Message, history: Memory = None, background: Memory = None, memory_pool: Memory = None) -> Message:
|
||||
def astep(self, query: Message, history: Memory = None, background: Memory = None, memory_manager: BaseMemoryManager = None) -> Message:
|
||||
'''execute chain'''
|
||||
local_memory = Memory(messages=[])
|
||||
input_message = copy.deepcopy(query)
|
||||
|
@ -61,42 +73,35 @@ class BaseChain:
|
|||
# local_memory.append(input_message)
|
||||
while step_nums > 0:
|
||||
for agent in self.agents:
|
||||
for output_message in agent.arun(input_message, history, background=background, memory_pool=memory_pool):
|
||||
for output_message in agent.astep(input_message, history, background=background, memory_manager=memory_manager):
|
||||
# logger.debug(f"local_memory {local_memory + output_message}")
|
||||
yield output_message, local_memory + output_message
|
||||
|
||||
output_message = self.messageUtils.inherit_extrainfo(input_message, output_message)
|
||||
# according the output to choose one action for code_content or tool_content
|
||||
# logger.info(f"{agent.role.role_name} currenct message: {output_message.step_content}\n next llm question: {output_message.input_query}")
|
||||
output_message = self.messageUtils.parser(output_message)
|
||||
yield output_message, local_memory + output_message
|
||||
# output_message = self.step_router(output_message)
|
||||
|
||||
input_message = output_message
|
||||
self.global_memory.append(output_message)
|
||||
|
||||
local_memory.append(output_message)
|
||||
# when get finished signal can stop early
|
||||
if output_message.action_status == ActionStatus.FINISHED or output_message.action_status == ActionStatus.STOPED:
|
||||
if output_message.action_status == ActionStatus.FINISHED or output_message.action_status == ActionStatus.STOPPED:
|
||||
action_status = False
|
||||
break
|
||||
|
||||
if output_message.action_status == ActionStatus.FINISHED:
|
||||
break
|
||||
|
||||
if self.do_checker and self.chat_turn > 1:
|
||||
# logger.debug(f"{self.checker.role.role_name} input global memory: {self.global_memory.to_str_messages(content_key='step_content', return_all=False)}")
|
||||
for check_message in self.checker.arun(query, background=local_memory, memory_pool=memory_pool):
|
||||
for check_message in self.checker.astep(query, background=local_memory, memory_manager=memory_manager):
|
||||
pass
|
||||
check_message = self.messageUtils.parser(check_message)
|
||||
check_message = self.messageUtils.filter(check_message)
|
||||
check_message = self.messageUtils.inherit_extrainfo(output_message, check_message)
|
||||
logger.debug(f"{self.checker.role.role_name}: {check_message.role_content}")
|
||||
# logger.debug(f"{self.checker.role.role_name}: {check_message.role_content}")
|
||||
|
||||
if check_message.action_status == ActionStatus.FINISHED:
|
||||
self.global_memory.append(check_message)
|
||||
break
|
||||
|
||||
step_nums -= 1
|
||||
#
|
||||
output_message = check_message or output_message # 返回chain和checker的结果
|
||||
|
@ -109,8 +114,6 @@ class BaseChain:
|
|||
|
||||
def get_memory_str(self, content_key="role_content") -> Memory:
|
||||
memory = self.global_memory
|
||||
# for i in memory.to_tuple_messages(content_key=content_key):
|
||||
# logger.debug(f"{i}")
|
||||
return "\n".join([": ".join(i) for i in memory.to_tuple_messages(content_key=content_key)])
|
||||
|
||||
def get_agents_memory(self, content_key="role_content"):
|
|
@ -0,0 +1,12 @@
|
|||
from typing import List
|
||||
from loguru import logger
|
||||
from coagent.connector.agents import BaseAgent
|
||||
from .base_chain import BaseChain
|
||||
|
||||
|
||||
|
||||
|
||||
class ExecutorRefineChain(BaseChain):
|
||||
|
||||
def __init__(self, agents: List[BaseAgent], do_code_exec: bool = False) -> None:
|
||||
super().__init__(agents, do_code_exec)
|
|
@ -0,0 +1,9 @@
|
|||
from .agent_config import AGETN_CONFIGS
|
||||
from .chain_config import CHAIN_CONFIGS
|
||||
from .phase_config import PHASE_CONFIGS
|
||||
from .prompt_config import BASE_PROMPT_CONFIGS, EXECUTOR_PROMPT_CONFIGS, SELECTOR_PROMPT_CONFIGS, BASE_NOTOOLPROMPT_CONFIGS
|
||||
|
||||
__all__ = [
|
||||
"AGETN_CONFIGS", "CHAIN_CONFIGS", "PHASE_CONFIGS",
|
||||
"BASE_PROMPT_CONFIGS", "EXECUTOR_PROMPT_CONFIGS", "SELECTOR_PROMPT_CONFIGS", "BASE_NOTOOLPROMPT_CONFIGS"
|
||||
]
|
|
@ -13,6 +13,7 @@ from .prompts import (
|
|||
REACT_TEMPLATE_PROMPT,
|
||||
REACT_TOOL_PROMPT, REACT_CODE_PROMPT, REACT_TOOL_AND_CODE_PLANNER_PROMPT, REACT_TOOL_AND_CODE_PROMPT
|
||||
)
|
||||
from .prompt_config import BASE_PROMPT_CONFIGS, EXECUTOR_PROMPT_CONFIGS, SELECTOR_PROMPT_CONFIGS, BASE_NOTOOLPROMPT_CONFIGS
|
||||
|
||||
|
||||
|
||||
|
@ -34,11 +35,9 @@ AGETN_CONFIGS = {
|
|||
"role_desc": "",
|
||||
"agent_type": "SelectorAgent"
|
||||
},
|
||||
"prompt_config": SELECTOR_PROMPT_CONFIGS,
|
||||
"group_agents": ["tool_react", "code_react"],
|
||||
"chat_turn": 1,
|
||||
"do_search": False,
|
||||
"do_doc_retrieval": False,
|
||||
"do_tool_retrieval": False
|
||||
},
|
||||
"checker": {
|
||||
"role": {
|
||||
|
@ -48,10 +47,8 @@ AGETN_CONFIGS = {
|
|||
"role_desc": "",
|
||||
"agent_type": "BaseAgent"
|
||||
},
|
||||
"prompt_config": BASE_PROMPT_CONFIGS,
|
||||
"chat_turn": 1,
|
||||
"do_search": False,
|
||||
"do_doc_retrieval": False,
|
||||
"do_tool_retrieval": False
|
||||
},
|
||||
"conv_summary": {
|
||||
"role": {
|
||||
|
@ -61,10 +58,8 @@ AGETN_CONFIGS = {
|
|||
"role_desc": "",
|
||||
"agent_type": "BaseAgent"
|
||||
},
|
||||
"prompt_config": BASE_PROMPT_CONFIGS,
|
||||
"chat_turn": 1,
|
||||
"do_search": False,
|
||||
"do_doc_retrieval": False,
|
||||
"do_tool_retrieval": False
|
||||
},
|
||||
"general_planner": {
|
||||
"role": {
|
||||
|
@ -74,10 +69,8 @@ AGETN_CONFIGS = {
|
|||
"role_desc": "",
|
||||
"agent_type": "BaseAgent"
|
||||
},
|
||||
"prompt_config": BASE_PROMPT_CONFIGS,
|
||||
"chat_turn": 1,
|
||||
"do_search": False,
|
||||
"do_doc_retrieval": False,
|
||||
"do_tool_retrieval": False
|
||||
},
|
||||
"executor": {
|
||||
"role": {
|
||||
|
@ -87,11 +80,9 @@ AGETN_CONFIGS = {
|
|||
"role_desc": "",
|
||||
"agent_type": "ExecutorAgent",
|
||||
},
|
||||
"prompt_config": EXECUTOR_PROMPT_CONFIGS,
|
||||
"stop": "\n**Observation:**",
|
||||
"chat_turn": 1,
|
||||
"do_search": False,
|
||||
"do_doc_retrieval": False,
|
||||
"do_tool_retrieval": False
|
||||
},
|
||||
"base_refiner": {
|
||||
"role": {
|
||||
|
@ -101,10 +92,8 @@ AGETN_CONFIGS = {
|
|||
"role_desc": "",
|
||||
"agent_type": "BaseAgent"
|
||||
},
|
||||
"prompt_config": BASE_PROMPT_CONFIGS,
|
||||
"chat_turn": 1,
|
||||
"do_search": False,
|
||||
"do_doc_retrieval": False,
|
||||
"do_tool_retrieval": False
|
||||
},
|
||||
"planner": {
|
||||
"role": {
|
||||
|
@ -114,10 +103,8 @@ AGETN_CONFIGS = {
|
|||
"role_desc": "",
|
||||
"agent_type": "BaseAgent"
|
||||
},
|
||||
"prompt_config": BASE_PROMPT_CONFIGS,
|
||||
"chat_turn": 1,
|
||||
"do_search": False,
|
||||
"do_doc_retrieval": False,
|
||||
"do_tool_retrieval": False
|
||||
},
|
||||
"intention_recognizer": {
|
||||
"role": {
|
||||
|
@ -127,10 +114,8 @@ AGETN_CONFIGS = {
|
|||
"role_desc": "",
|
||||
"agent_type": "BaseAgent"
|
||||
},
|
||||
"prompt_config": BASE_PROMPT_CONFIGS,
|
||||
"chat_turn": 1,
|
||||
"do_search": False,
|
||||
"do_doc_retrieval": False,
|
||||
"do_tool_retrieval": False
|
||||
},
|
||||
"tool_planner": {
|
||||
"role": {
|
||||
|
@ -140,10 +125,8 @@ AGETN_CONFIGS = {
|
|||
"role_desc": "",
|
||||
"agent_type": "BaseAgent"
|
||||
},
|
||||
"prompt_config": BASE_PROMPT_CONFIGS,
|
||||
"chat_turn": 1,
|
||||
"do_search": False,
|
||||
"do_doc_retrieval": False,
|
||||
"do_tool_retrieval": False
|
||||
},
|
||||
"tool_and_code_react": {
|
||||
"role": {
|
||||
|
@ -153,11 +136,9 @@ AGETN_CONFIGS = {
|
|||
"role_desc": "",
|
||||
"agent_type": "ReactAgent",
|
||||
},
|
||||
"prompt_config": BASE_PROMPT_CONFIGS,
|
||||
"stop": "\n**Observation:**",
|
||||
"chat_turn": 7,
|
||||
"do_search": False,
|
||||
"do_doc_retrieval": False,
|
||||
"do_tool_retrieval": False
|
||||
},
|
||||
"tool_and_code_planner": {
|
||||
"role": {
|
||||
|
@ -167,10 +148,8 @@ AGETN_CONFIGS = {
|
|||
"role_desc": "",
|
||||
"agent_type": "BaseAgent"
|
||||
},
|
||||
"prompt_config": BASE_PROMPT_CONFIGS,
|
||||
"chat_turn": 1,
|
||||
"do_search": False,
|
||||
"do_doc_retrieval": False,
|
||||
"do_tool_retrieval": False
|
||||
},
|
||||
"tool_react": {
|
||||
"role": {
|
||||
|
@ -180,10 +159,8 @@ AGETN_CONFIGS = {
|
|||
"role_desc": "",
|
||||
"agent_type": "ReactAgent"
|
||||
},
|
||||
"prompt_config": BASE_PROMPT_CONFIGS,
|
||||
"chat_turn": 5,
|
||||
"do_search": False,
|
||||
"do_doc_retrieval": False,
|
||||
"do_tool_retrieval": False,
|
||||
"stop": "\n**Observation:**"
|
||||
},
|
||||
"code_react": {
|
||||
|
@ -194,10 +171,8 @@ AGETN_CONFIGS = {
|
|||
"role_desc": "",
|
||||
"agent_type": "ReactAgent"
|
||||
},
|
||||
"prompt_config": BASE_NOTOOLPROMPT_CONFIGS,
|
||||
"chat_turn": 5,
|
||||
"do_search": False,
|
||||
"do_doc_retrieval": False,
|
||||
"do_tool_retrieval": False,
|
||||
"stop": "\n**Observation:**"
|
||||
},
|
||||
"qaer": {
|
||||
|
@ -208,23 +183,19 @@ AGETN_CONFIGS = {
|
|||
"role_desc": "",
|
||||
"agent_type": "BaseAgent"
|
||||
},
|
||||
"prompt_config": BASE_PROMPT_CONFIGS,
|
||||
"chat_turn": 1,
|
||||
"do_search": False,
|
||||
"do_doc_retrieval": False,
|
||||
"do_tool_retrieval": False
|
||||
},
|
||||
"code_qaer": {
|
||||
"role": {
|
||||
"role_prompt": CODE_QA_PROMPT ,
|
||||
"role_prompt": CODE_QA_PROMPT,
|
||||
"role_type": "assistant",
|
||||
"role_name": "code_qaer",
|
||||
"role_desc": "",
|
||||
"agent_type": "BaseAgent"
|
||||
},
|
||||
"prompt_config": BASE_PROMPT_CONFIGS,
|
||||
"chat_turn": 1,
|
||||
"do_search": False,
|
||||
"do_doc_retrieval": True,
|
||||
"do_tool_retrieval": False
|
||||
},
|
||||
"searcher": {
|
||||
"role": {
|
||||
|
@ -234,10 +205,8 @@ AGETN_CONFIGS = {
|
|||
"role_desc": "",
|
||||
"agent_type": "BaseAgent"
|
||||
},
|
||||
"prompt_config": BASE_PROMPT_CONFIGS,
|
||||
"chat_turn": 1,
|
||||
"do_search": True,
|
||||
"do_doc_retrieval": False,
|
||||
"do_tool_retrieval": False
|
||||
},
|
||||
"metaGPT_PRD": {
|
||||
"role": {
|
||||
|
@ -247,10 +216,8 @@ AGETN_CONFIGS = {
|
|||
"role_desc": "",
|
||||
"agent_type": "BaseAgent"
|
||||
},
|
||||
"prompt_config": BASE_PROMPT_CONFIGS,
|
||||
"chat_turn": 1,
|
||||
"do_search": False,
|
||||
"do_doc_retrieval": False,
|
||||
"do_tool_retrieval": False,
|
||||
"focus_agents": [],
|
||||
"focus_message_keys": [],
|
||||
},
|
||||
|
@ -263,10 +230,8 @@ AGETN_CONFIGS = {
|
|||
"role_desc": "",
|
||||
"agent_type": "BaseAgent"
|
||||
},
|
||||
"prompt_config": BASE_PROMPT_CONFIGS,
|
||||
"chat_turn": 1,
|
||||
"do_search": False,
|
||||
"do_doc_retrieval": False,
|
||||
"do_tool_retrieval": False,
|
||||
"focus_agents": ["metaGPT_PRD"],
|
||||
"focus_message_keys": [],
|
||||
},
|
||||
|
@ -278,10 +243,8 @@ AGETN_CONFIGS = {
|
|||
"role_desc": "",
|
||||
"agent_type": "BaseAgent"
|
||||
},
|
||||
"prompt_config": BASE_PROMPT_CONFIGS,
|
||||
"chat_turn": 1,
|
||||
"do_search": False,
|
||||
"do_doc_retrieval": False,
|
||||
"do_tool_retrieval": False,
|
||||
"focus_agents": ["metaGPT_DESIGN"],
|
||||
"focus_message_keys": [],
|
||||
},
|
||||
|
@ -293,10 +256,8 @@ AGETN_CONFIGS = {
|
|||
"role_desc": "",
|
||||
"agent_type": "ExecutorAgent"
|
||||
},
|
||||
"prompt_config": EXECUTOR_PROMPT_CONFIGS,
|
||||
"chat_turn": 1,
|
||||
"do_search": False,
|
||||
"do_doc_retrieval": False,
|
||||
"do_tool_retrieval": False,
|
||||
"focus_agents": ["metaGPT_DESIGN", "metaGPT_TASK"],
|
||||
"focus_message_keys": [],
|
||||
},
|
|
@ -88,26 +88,26 @@ PHASE_CONFIGS = {
|
|||
"do_tool_retrieval": False,
|
||||
"do_using_tool": False
|
||||
},
|
||||
# "metagpt_code_devlop": {
|
||||
# "phase_name": "metagpt_code_devlop",
|
||||
# "phase_type": "BasePhase",
|
||||
# "chains": ["metagptChain",],
|
||||
# "do_summary": False,
|
||||
# "do_search": False,
|
||||
# "do_doc_retrieval": False,
|
||||
# "do_code_retrieval": False,
|
||||
# "do_tool_retrieval": False,
|
||||
# "do_using_tool": False
|
||||
# },
|
||||
# "baseGroupPhase": {
|
||||
# "phase_name": "baseGroupPhase",
|
||||
# "phase_type": "BasePhase",
|
||||
# "chains": ["baseGroupChain"],
|
||||
# "do_summary": False,
|
||||
# "do_search": False,
|
||||
# "do_doc_retrieval": False,
|
||||
# "do_code_retrieval": False,
|
||||
# "do_tool_retrieval": False,
|
||||
# "do_using_tool": False
|
||||
# },
|
||||
"metagpt_code_devlop": {
|
||||
"phase_name": "metagpt_code_devlop",
|
||||
"phase_type": "BasePhase",
|
||||
"chains": ["metagptChain",],
|
||||
"do_summary": False,
|
||||
"do_search": False,
|
||||
"do_doc_retrieval": False,
|
||||
"do_code_retrieval": False,
|
||||
"do_tool_retrieval": False,
|
||||
"do_using_tool": False
|
||||
},
|
||||
"baseGroupPhase": {
|
||||
"phase_name": "baseGroupPhase",
|
||||
"phase_type": "BasePhase",
|
||||
"chains": ["baseGroupChain"],
|
||||
"do_summary": False,
|
||||
"do_search": False,
|
||||
"do_doc_retrieval": False,
|
||||
"do_code_retrieval": False,
|
||||
"do_tool_retrieval": False,
|
||||
"do_using_tool": False
|
||||
},
|
||||
}
|
|
@ -0,0 +1,43 @@
|
|||
BASE_PROMPT_CONFIGS = [
|
||||
{"field_name": 'agent_profile', "function_name": 'handle_agent_profile', "is_context": False},
|
||||
{"field_name": 'tool_information',"function_name": 'handle_tool_data', "is_context": False},
|
||||
{"field_name": 'context_placeholder', "function_name": '', "is_context": True},
|
||||
{"field_name": 'reference_documents', "function_name": 'handle_doc_info'},
|
||||
{"field_name": 'session_records', "function_name": 'handle_session_records'},
|
||||
{"field_name": 'task_records', "function_name": 'handle_task_records'},
|
||||
{"field_name": 'output_format', "function_name": 'handle_output_format', 'title': 'Response Output Format', "is_context": False},
|
||||
{"field_name": 'begin!!!', "function_name": 'handle_response', "is_context": False, "omit_if_empty": False}
|
||||
]
|
||||
|
||||
BASE_NOTOOLPROMPT_CONFIGS = [
|
||||
{"field_name": 'agent_profile', "function_name": 'handle_agent_profile', "is_context": False},
|
||||
{"field_name": 'context_placeholder', "function_name": '', "is_context": True},
|
||||
{"field_name": 'reference_documents', "function_name": 'handle_doc_info'},
|
||||
{"field_name": 'session_records', "function_name": 'handle_session_records'},
|
||||
{"field_name": 'output_format', "function_name": 'handle_output_format', 'title': 'Response Output Format', "is_context": False},
|
||||
{"field_name": 'begin!!!', "function_name": 'handle_response', "is_context": False, "omit_if_empty": False}
|
||||
]
|
||||
|
||||
EXECUTOR_PROMPT_CONFIGS = [
|
||||
{"field_name": 'agent_profile', "function_name": 'handle_agent_profile', "is_context": False},
|
||||
{"field_name": 'tool_information',"function_name": 'handle_tool_data', "is_context": False},
|
||||
{"field_name": 'context_placeholder', "function_name": '', "is_context": True},
|
||||
{"field_name": 'reference_documents', "function_name": 'handle_doc_info'},
|
||||
{"field_name": 'session_records', "function_name": 'handle_session_records'},
|
||||
{"field_name": 'task_records', "function_name": 'handle_task_records'},
|
||||
{"field_name": 'current_plan', "function_name": 'handle_current_plan'},
|
||||
{"field_name": 'output_format', "function_name": 'handle_output_format', 'title': 'Response Output Format', "is_context": False},
|
||||
{"field_name": 'begin!!!', "function_name": 'handle_response', "is_context": False, "omit_if_empty": False}
|
||||
]
|
||||
|
||||
SELECTOR_PROMPT_CONFIGS = [
|
||||
{"field_name": 'agent_profile', "function_name": 'handle_agent_profile', "is_context": False},
|
||||
{"field_name": 'tool_information',"function_name": 'handle_tool_data', "is_context": False},
|
||||
{"field_name": 'agent_infomation', "function_name": 'handle_agent_data', "is_context": False, "omit_if_empty": False},
|
||||
{"field_name": 'context_placeholder', "function_name": '', "is_context": True},
|
||||
{"field_name": 'reference_documents', "function_name": 'handle_doc_info'},
|
||||
{"field_name": 'session_records', "function_name": 'handle_session_records'},
|
||||
{"field_name": 'current_plan', "function_name": 'handle_current_plan'},
|
||||
{"field_name": 'output_format', "function_name": 'handle_output_format', 'title': 'Response Output Format', "is_context": False},
|
||||
{"field_name": 'begin!!!', "function_name": 'handle_response', "is_context": False, "omit_if_empty": False}
|
||||
]
|
|
@ -8,9 +8,9 @@ from .intention_template_prompt import RECOGNIZE_INTENTION_PROMPT
|
|||
|
||||
from .checker_template_prompt import CHECKER_PROMPT, CHECKER_TEMPLATE_PROMPT
|
||||
|
||||
from .summary_template_prompt import CONV_SUMMARY_PROMPT
|
||||
from .summary_template_prompt import CONV_SUMMARY_PROMPT, CONV_SUMMARY_PROMPT_SPEC
|
||||
|
||||
from .qa_template_prompt import QA_PROMPT, CODE_QA_PROMPT, QA_TEMPLATE_PROMPT
|
||||
from .qa_template_prompt import QA_PROMPT, CODE_QA_PROMPT, QA_TEMPLATE_PROMPT, CODE_PROMPT_TEMPLATE, CODE_INTERPERT_TEMPLATE, ORIGIN_TEMPLATE_PROMPT
|
||||
|
||||
from .executor_template_prompt import EXECUTOR_TEMPLATE_PROMPT
|
||||
from .refine_template_prompt import REFINE_TEMPLATE_PROMPT
|
||||
|
@ -30,8 +30,8 @@ __all__ = [
|
|||
"RECOGNIZE_INTENTION_PROMPT",
|
||||
"PRD_WRITER_METAGPT_PROMPT", "DESIGN_WRITER_METAGPT_PROMPT", "TASK_WRITER_METAGPT_PROMPT", "CODE_WRITER_METAGPT_PROMPT",
|
||||
"CHECKER_PROMPT", "CHECKER_TEMPLATE_PROMPT",
|
||||
"CONV_SUMMARY_PROMPT",
|
||||
"QA_PROMPT", "CODE_QA_PROMPT", "QA_TEMPLATE_PROMPT",
|
||||
"CONV_SUMMARY_PROMPT", "CONV_SUMMARY_PROMPT_SPEC",
|
||||
"QA_PROMPT", "CODE_QA_PROMPT", "QA_TEMPLATE_PROMPT", "CODE_PROMPT_TEMPLATE", "CODE_INTERPERT_TEMPLATE", "ORIGIN_TEMPLATE_PROMPT",
|
||||
"EXECUTOR_TEMPLATE_PROMPT",
|
||||
"REFINE_TEMPLATE_PROMPT",
|
||||
"SELECTOR_AGENT_TEMPLATE_PROMPT",
|
|
@ -0,0 +1,21 @@
|
|||
SELECTOR_AGENT_TEMPLATE_PROMPT = """#### Agent Profile
|
||||
|
||||
Your goal is to response according the Context Data's information with the role that will best facilitate a solution, taking into account all relevant context (Context) provided.
|
||||
|
||||
When you need to select the appropriate role for handling a user's query, carefully read the provided role names, role descriptions and tool list.
|
||||
|
||||
ATTENTION: response carefully referenced "Response Output Format" in format.
|
||||
|
||||
#### Input Format
|
||||
|
||||
**Origin Query:** the initial question or objective that the user wanted to achieve
|
||||
|
||||
**Context:** the context history to determine if Origin Query has been achieved.
|
||||
|
||||
#### Response Output Format
|
||||
|
||||
**Thoughts:** think the reason step by step about why you selecte one role
|
||||
|
||||
**Role:** Select the role from agent names.
|
||||
|
||||
"""
|
|
@ -1,5 +1,5 @@
|
|||
|
||||
CHECKER_TEMPLATE_PROMPT = """#### Checker Assistance Guidance
|
||||
CHECKER_TEMPLATE_PROMPT = """#### Agent Profile
|
||||
|
||||
When users have completed a sequence of tasks or if there is clear evidence that no further actions are required, your role is to confirm the completion.
|
||||
Your task is to assess the current situation based on the context and determine whether all objectives have been met.
|
||||
|
@ -12,7 +12,7 @@ Each decision should be justified based on the context provided, specifying if t
|
|||
**Context:** the current status and history of the tasks to determine if Origin Query has been achieved.
|
||||
|
||||
#### Response Output Format
|
||||
**Action Status:** Set to 'finished' or 'continued'.
|
||||
**Action Status:** finished or continued
|
||||
If it's 'finished', the context can answer the origin query.
|
||||
If it's 'continued', the context cant answer the origin query.
|
||||
|
|
@ -0,0 +1,31 @@
|
|||
EXECUTOR_TEMPLATE_PROMPT = """#### Agent Profile
|
||||
|
||||
When users need help with coding or using tools, your role is to provide precise and effective guidance.
|
||||
Use the tools provided if they can solve the problem, otherwise, write the code step by step, showing only the part necessary to solve the current problem.
|
||||
Each reply should contain only the guidance required for the current step either by tool usage or code.
|
||||
ATTENTION: The Action Status field ensures that the tools or code mentioned in the Action can be parsed smoothly. Please make sure not to omit the Action Status field when replying.
|
||||
|
||||
#### Response Output Format
|
||||
|
||||
**Thoughts:** Considering the session records and executed steps, decide whether the current step requires the use of a tool or code_executing.
|
||||
Solve the problem step by step, only displaying the thought process necessary for the current step of solving the problem.
|
||||
If code_executing is required, outline the plan for executing this step.
|
||||
|
||||
**Action Status:** Set to 'stopped' or 'code_executing'. If it's 'stopped', the next action is to provide the final answer to the original question. If it's 'code_executing', the next step is to write the code.
|
||||
|
||||
**Action:** Code according to your thoughts. Use this format for code:
|
||||
|
||||
```python
|
||||
# Write your code here
|
||||
```
|
||||
"""
|
||||
|
||||
# **Observation:** Check the results and effects of the executed code.
|
||||
|
||||
# ... (Repeat this Question/Thoughts/Action/Observation cycle as needed)
|
||||
|
||||
# **Thoughts:** I now know the final answer
|
||||
|
||||
# **Action Status:** Set to 'stopped'
|
||||
|
||||
# **Action:** The final answer to the original input question
|
|
@ -1,4 +1,4 @@
|
|||
PRD_WRITER_METAGPT_PROMPT = """#### PRD Writer Assistance Guidance
|
||||
PRD_WRITER_METAGPT_PROMPT = """#### Agent Profile
|
||||
|
||||
You are a professional Product Manager, your goal is to design a concise, usable, efficient product.
|
||||
According to the context, fill in the following missing information, note that each sections are returned in Python code triple quote form seperatedly.
|
||||
|
@ -56,12 +56,12 @@ There are no unclear points.'''
|
|||
|
||||
|
||||
|
||||
DESIGN_WRITER_METAGPT_PROMPT = """#### PRD Writer Assistance Guidance
|
||||
DESIGN_WRITER_METAGPT_PROMPT = """#### Agent Profile
|
||||
|
||||
You are an architect; the goal is to design a SOTA PEP8-compliant python system; make the best use of good open source tools.
|
||||
Fill in the following missing information based on the context, note that all sections are response with code form separately.
|
||||
8192 chars or 2048 tokens. Try to use them up.
|
||||
ATTENTION: response carefully referenced "Response Format" in format.
|
||||
ATTENTION: response carefully referenced "Response Output Format" in format.
|
||||
|
||||
#### Input Format
|
||||
|
||||
|
@ -69,7 +69,7 @@ ATTENTION: response carefully referenced "Response Format" in format.
|
|||
|
||||
**Context:** the current status and history of the tasks to determine if Origin Query has been achieved.
|
||||
|
||||
#### Response Format
|
||||
#### Response Output Format
|
||||
**Implementation approach:**
|
||||
Provide as Plain text. Analyze the difficult points of the requirements, select the appropriate open-source framework.
|
||||
|
||||
|
@ -117,7 +117,7 @@ Provide as Plain text. Make clear here.
|
|||
|
||||
|
||||
|
||||
TASK_WRITER_METAGPT_PROMPT = """#### Task Plan Assistance Guidance
|
||||
TASK_WRITER_METAGPT_PROMPT = """#### Agent Profile
|
||||
|
||||
You are a project manager, the goal is to break down tasks according to PRD/technical design, give a task list, and analyze task dependencies to start with the prerequisite modules
|
||||
Based on the context, fill in the following missing information, note that all sections are returned in Python code triple quote form seperatedly.
|
||||
|
@ -176,7 +176,7 @@ Provide as Plain text. Make clear here. For example, don't forget a main entry.
|
|||
"""
|
||||
|
||||
|
||||
CODE_WRITER_METAGPT_PROMPT = """#### Code Writer Assistance Guidance
|
||||
CODE_WRITER_METAGPT_PROMPT = """#### Agent Profile
|
||||
|
||||
You are a professional engineer; the main goal is to write PEP8 compliant, elegant, modular, easy to read and maintain Python 3.9 code (but you can also use other programming language)
|
||||
|
||||
|
@ -204,7 +204,7 @@ ATTENTION: response carefully referenced "Response Output Format" in format **$k
|
|||
#### Response Output Format
|
||||
**Action Status:** Coding2File
|
||||
|
||||
**SaveFileName** construct a local file name based on Question and Context, such as
|
||||
**SaveFileName:** construct a local file name based on Question and Context, such as
|
||||
|
||||
```python
|
||||
$projectname/$filename.py
|
|
@ -1,6 +1,6 @@
|
|||
|
||||
|
||||
PLANNER_TEMPLATE_PROMPT = """#### Planner Assistance Guidance
|
||||
PLANNER_TEMPLATE_PROMPT = """#### Agent Profile
|
||||
|
||||
When users need assistance with generating a sequence of achievable tasks, your role is to provide a coherent and continuous plan.
|
||||
Design the plan step by step, ensuring each task builds on the completion of the previous one.
|
||||
|
@ -27,12 +27,11 @@ If it's 'planning', the PLAN is to provide a Python list[str] of achievable task
|
|||
"""
|
||||
|
||||
|
||||
TOOL_PLANNER_PROMPT = """#### Tool Planner Assistance Guidance
|
||||
TOOL_PLANNER_PROMPT = """#### Agent Profile
|
||||
|
||||
Helps user to break down a process of tool usage into a series of plans.
|
||||
If there are no available tools, can directly answer the question.
|
||||
Rrespond to humans in the most helpful and accurate way possible.
|
||||
You can use the following tool: {formatted_tools}
|
||||
|
||||
#### Input Format
|
||||
|
|
@ -1,4 +1,6 @@
|
|||
QA_TEMPLATE_PROMPT = """#### Question Answer Assistance Guidance
|
||||
# Question Answer Assistance Guidance
|
||||
|
||||
QA_TEMPLATE_PROMPT = """#### Agent Profile
|
||||
|
||||
Based on the information provided, please answer the origin query concisely and professionally.
|
||||
Attention: Follow the input format and response output format
|
||||
|
@ -18,7 +20,7 @@ If the answer cannot be derived from the given Context and DocInfos, please say
|
|||
"""
|
||||
|
||||
|
||||
CODE_QA_PROMPT = """#### Code Answer Assistance Guidance
|
||||
CODE_QA_PROMPT = """#### Agent Profile
|
||||
|
||||
Based on the information provided, please answer the origin query concisely and professionally.
|
||||
Attention: Follow the input format and response output format
|
||||
|
@ -51,4 +53,22 @@ $JSON_BLOB
|
|||
```
|
||||
"""
|
||||
|
||||
# CODE_QA_PROMPT = """【指令】根据已知信息来回答问"""
|
||||
# 基于本地代码知识问答的提示词模版
|
||||
CODE_PROMPT_TEMPLATE = """【指令】根据已知信息来回答问题。
|
||||
|
||||
【已知信息】{context}
|
||||
|
||||
【问题】{question}"""
|
||||
|
||||
# 代码解释模版
|
||||
CODE_INTERPERT_TEMPLATE = '''{code}
|
||||
|
||||
解释一下这段代码'''
|
||||
# CODE_QA_PROMPT = """【指令】根据已知信息来回答问"""
|
||||
|
||||
# 基于本地知识问答的提示词模版
|
||||
ORIGIN_TEMPLATE_PROMPT = """【指令】根据已知信息,简洁和专业的来回答问题。如果无法从中得到答案,请说 “根据已知信息无法回答该问题”,不允许在答案中添加编造成分,答案请使用中文。
|
||||
|
||||
【已知信息】{context}
|
||||
|
||||
【问题】{question}"""
|
|
@ -0,0 +1,103 @@
|
|||
|
||||
|
||||
# REACT_CODE_PROMPT = """#### Agent Profile
|
||||
|
||||
# 1. When users need help with coding, your role is to provide precise and effective guidance.
|
||||
# 2. Reply follows the format of Thoughts/Action Status/Action/Observation cycle.
|
||||
# 3. Provide the final answer if they can solve the problem, otherwise, write the code step by step, showing only the part necessary to solve the current problem.
|
||||
# Each reply should contain only the guidance required for the current step either by tool usage or code.
|
||||
# 4. If the Response already contains content, continue writing following the format of the Response Output Format.
|
||||
|
||||
# #### Response Output Format
|
||||
|
||||
# **Thoughts:** Considering the session records and executed steps, solve the problem step by step, only displaying the thought process necessary for the current step of solving the problem,
|
||||
# outline the plan for executing this step.
|
||||
|
||||
# **Action Status:** Set to 'stopped' or 'code_executing'.
|
||||
# If it's 'stopped', the action is to provide the final answer to the session records and executed steps.
|
||||
# If it's 'code_executing', the action is to write the code.
|
||||
|
||||
# **Action:**
|
||||
# ```python
|
||||
# # Write your code here
|
||||
# ...
|
||||
# ```
|
||||
|
||||
# **Observation:** Check the results and effects of the executed code.
|
||||
|
||||
# ... (Repeat this "Thoughts/Action Status/Action/Observation" cycle format as needed)
|
||||
|
||||
# **Thoughts:** Considering the session records and executed steps, give the final answer
|
||||
# .
|
||||
# **Action Status:** stopped
|
||||
|
||||
# **Action:** Response the final answer to the session records.
|
||||
|
||||
# """
|
||||
|
||||
REACT_CODE_PROMPT = """#### Agent Profile
|
||||
|
||||
When users need help with coding, your role is to provide precise and effective guidance.
|
||||
|
||||
Write the code step by step, showing only the part necessary to solve the current problem. Each reply should contain only the code required for the current step.
|
||||
|
||||
#### Response Output Format
|
||||
|
||||
**Thoughts:** According the previous context, solve the problem step by step, only displaying the thought process necessary for the current step of solving the problem,
|
||||
outline the plan for executing this step.
|
||||
|
||||
**Action Status:** Set to 'stopped' or 'code_executing'.
|
||||
If it's 'stopped', the action is to provide the final answer to the session records and executed steps.
|
||||
If it's 'code_executing', the action is to write the code.
|
||||
|
||||
**Action:**
|
||||
```python
|
||||
# Write your code here
|
||||
...
|
||||
```
|
||||
|
||||
**Observation:** Check the results and effects of the executed code.
|
||||
|
||||
... (Repeat this "Thoughts/Action Status/Action/Observation" cycle format as needed)
|
||||
|
||||
**Thoughts:** Considering the session records and executed steps, give the final answer
|
||||
.
|
||||
**Action Status:** stopped
|
||||
|
||||
**Action:** Response the final answer to the session records.
|
||||
|
||||
"""
|
||||
|
||||
|
||||
# REACT_CODE_PROMPT = """#### Writing Code Assistance Guidance
|
||||
|
||||
# When users need help with coding, your role is to provide precise and effective guidance.
|
||||
|
||||
# Write the code step by step, showing only the part necessary to solve the current problem. Each reply should contain only the code required for the current step.
|
||||
|
||||
# #### Response Process
|
||||
|
||||
# **Question:** First, clarify the problem to be solved.
|
||||
|
||||
# **Thoughts:** Based on the question and observations above, provide the plan for executing this step.
|
||||
|
||||
# **Action Status:** Set to 'stoped' or 'code_executing'. If it's 'stoped', the action is to provide the final answer to the original question. If it's 'code_executing', the action is to write the code.
|
||||
|
||||
# **Action:**
|
||||
# ```python
|
||||
# # Write your code here
|
||||
# import os
|
||||
# ...
|
||||
# ```
|
||||
|
||||
# **Observation:** Check the results and effects of the executed code.
|
||||
|
||||
# ... (Repeat this Thoughts/Action/Observation cycle as needed)
|
||||
|
||||
# **Thoughts:** I now know the final answer
|
||||
|
||||
# **Action Status:** Set to 'stoped'
|
||||
|
||||
# **Action:** The final answer to the original input question
|
||||
|
||||
# """
|
|
@ -0,0 +1,37 @@
|
|||
|
||||
|
||||
REACT_TEMPLATE_PROMPT = """#### Agent Profile
|
||||
|
||||
1. When users need help with coding, your role is to provide precise and effective guidance.
|
||||
2. Reply follows the format of Thoughts/Action Status/Action/Observation cycle.
|
||||
3. Provide the final answer if they can solve the problem, otherwise, write the code step by step, showing only the part necessary to solve the current problem.
|
||||
Each reply should contain only the guidance required for the current step either by tool usage or code.
|
||||
4. If the Response already contains content, continue writing following the format of the Response Output Format.
|
||||
|
||||
ATTENTION: Under the "Response" heading, the output format strictly adheres to the content specified in the "Response Output Format."
|
||||
|
||||
#### Response Output Format
|
||||
|
||||
**Question:** First, clarify the problem to be solved.
|
||||
|
||||
**Thoughts:** Based on the Session Records or observations above, provide the plan for executing this step.
|
||||
|
||||
**Action Status:** Set to either 'stopped' or 'code_executing'. If it's 'stopped', the next action is to provide the final answer to the original question. If it's 'code_executing', the next step is to write the code.
|
||||
|
||||
**Action:** Code according to your thoughts. Use this format for code:
|
||||
|
||||
```python
|
||||
# Write your code here
|
||||
```
|
||||
|
||||
**Observation:** Check the results and effects of the executed code.
|
||||
|
||||
... (Repeat this "Thoughts/Action Status/Action/Observation" cycle format as needed)
|
||||
|
||||
**Thoughts:** Considering the session records and executed steps, give the final answer.
|
||||
|
||||
**Action Status:** stopped
|
||||
|
||||
**Action:** Response the final answer to the session records.
|
||||
|
||||
"""
|
|
@ -1,13 +1,9 @@
|
|||
REACT_TOOL_AND_CODE_PLANNER_PROMPT = """#### Planner Assistance Guidance
|
||||
REACT_TOOL_AND_CODE_PLANNER_PROMPT = """#### Agent Profile
|
||||
When users seek assistance in breaking down complex issues into manageable and actionable steps,
|
||||
your responsibility is to deliver a well-organized strategy or resolution through the use of tools or coding.
|
||||
|
||||
ATTENTION: response carefully referenced "Response Output Format" in format.
|
||||
|
||||
You may use the following tools:
|
||||
{formatted_tools}
|
||||
Depending on the user's query, the response will either be a plan detailing the use of tools and reasoning, or a direct answer if the problem does not require breaking down.
|
||||
|
||||
#### Input Format
|
||||
|
||||
**Question:** First, clarify the problem to be solved.
|
|
@ -0,0 +1,197 @@
|
|||
REACT_TOOL_AND_CODE_PROMPT = """#### Agent Profile
|
||||
|
||||
When users need help with coding or using tools, your role is to provide precise and effective guidance.
|
||||
Use the tools provided if they can solve the problem, otherwise, write the code step by step, showing only the part necessary to solve the current problem.
|
||||
Each reply should contain only the guidance required for the current step either by tool usage or code.
|
||||
ATTENTION: The Action Status field ensures that the tools or code mentioned in the Action can be parsed smoothly. Please make sure not to omit the Action Status field when replying.
|
||||
|
||||
#### Tool Infomation
|
||||
|
||||
You can use these tools:\n{formatted_tools}
|
||||
|
||||
Valid "tool_name" value:\n{tool_names}
|
||||
|
||||
#### Response Output Format
|
||||
|
||||
**Thoughts:** Considering the session records and executed steps, decide whether the current step requires the use of a tool or code_executing. Solve the problem step by step, only displaying the thought process necessary for the current step of solving the problem. If code_executing is required, outline the plan for executing this step.
|
||||
|
||||
**Action Status:** stoped, tool_using or code_executing
|
||||
Use 'stopped' when the task has been completed, and no further use of tools or execution of code is necessary.
|
||||
Use 'tool_using' when the current step in the process involves utilizing a tool to proceed.
|
||||
Use 'code_executing' when the current step requires writing and executing code.
|
||||
|
||||
**Action:**
|
||||
|
||||
If Action Status is 'tool_using', format the tool action in JSON from Question and Observation, enclosed in a code block, like this:
|
||||
```json
|
||||
{
|
||||
"tool_name": "$TOOL_NAME",
|
||||
"tool_params": "$INPUT"
|
||||
}
|
||||
```
|
||||
|
||||
If Action Status is 'code_executing', write the necessary code to solve the issue, enclosed in a code block, like this:
|
||||
```python
|
||||
Write your running code here
|
||||
```
|
||||
|
||||
If Action Status is 'stopped', provide the final response or instructions in written form, enclosed in a code block, like this:
|
||||
```text
|
||||
The final response or instructions to the user question.
|
||||
```
|
||||
|
||||
**Observation:** Check the results and effects of the executed action.
|
||||
|
||||
... (Repeat this Thoughts/Action Status/Action/Observation cycle as needed)
|
||||
|
||||
**Thoughts:** Conclude the final response to the user question.
|
||||
|
||||
**Action Status:** stoped
|
||||
|
||||
**Action:** The final answer or guidance to the user question.
|
||||
"""
|
||||
|
||||
# REACT_TOOL_AND_CODE_PROMPT = """#### Agent Profile
|
||||
|
||||
# 1. When users need help with coding or using tools, your role is to provide precise and effective guidance.
|
||||
# 2. Reply follows the format of Thoughts/Action Status/Action/Observation cycle.
|
||||
# 3. Use the tools provided if they can solve the problem, otherwise, write the code step by step, showing only the part necessary to solve the current problem.
|
||||
# Each reply should contain only the guidance required for the current step either by tool usage or code.
|
||||
# 4. If the Response already contains content, continue writing following the format of the Response Output Format.
|
||||
|
||||
# ATTENTION: The "Action Status" field ensures that the tools or code mentioned in the "Action" can be parsed smoothly. Please make sure not to omit the "Action Status" field when replying.
|
||||
|
||||
# #### Tool Infomation
|
||||
|
||||
# You can use these tools:\n{formatted_tools}
|
||||
|
||||
# Valid "tool_name" value:\n{tool_names}
|
||||
|
||||
# #### Response Output Format
|
||||
|
||||
# **Thoughts:** Considering the user's question, previously executed steps, and the plan, decide whether the current step requires the use of a tool or code_executing.
|
||||
# Solve the problem step by step, only displaying the thought process necessary for the current step of solving the problem.
|
||||
# If a tool can be used, provide its name and parameters. If code_executing is required, outline the plan for executing this step.
|
||||
|
||||
# **Action Status:** stoped, tool_using, or code_executing. (Choose one from these three statuses.)
|
||||
# # If the task is done, set it to 'stoped'.
|
||||
# # If using a tool, set it to 'tool_using'.
|
||||
# # If writing code, set it to 'code_executing'.
|
||||
|
||||
# **Action:**
|
||||
|
||||
# If Action Status is 'tool_using', format the tool action in JSON from Question and Observation, enclosed in a code block, like this:
|
||||
# ```json
|
||||
# {
|
||||
# "tool_name": "$TOOL_NAME",
|
||||
# "tool_params": "$INPUT"
|
||||
# }
|
||||
# ```
|
||||
|
||||
# If Action Status is 'code_executing', write the necessary code to solve the issue, enclosed in a code block, like this:
|
||||
# ```python
|
||||
# Write your running code here
|
||||
# ...
|
||||
# ```
|
||||
|
||||
# If Action Status is 'stopped', provide the final response or instructions in written form, enclosed in a code block, like this:
|
||||
# ```text
|
||||
# The final response or instructions to the original input question.
|
||||
# ```
|
||||
|
||||
# **Observation:** Check the results and effects of the executed action.
|
||||
|
||||
# ... (Repeat this Thoughts/Action Status/Action/Observation cycle as needed)
|
||||
|
||||
# **Thoughts:** Considering the user's question, previously executed steps, give the final answer.
|
||||
|
||||
# **Action Status:** stopped
|
||||
|
||||
# **Action:** Response the final answer to the session records.
|
||||
# """
|
||||
|
||||
|
||||
# REACT_TOOL_AND_CODE_PROMPT = """#### Code and Tool Agent Assistance Guidance
|
||||
|
||||
# When users need help with coding or using tools, your role is to provide precise and effective guidance. Use the tools provided if they can solve the problem, otherwise, write the code step by step, showing only the part necessary to solve the current problem. Each reply should contain only the guidance required for the current step either by tool usage or code.
|
||||
|
||||
# #### Tool Infomation
|
||||
|
||||
# You can use these tools:\n{formatted_tools}
|
||||
|
||||
# Valid "tool_name" value:\n{tool_names}
|
||||
|
||||
# #### Response Process
|
||||
|
||||
# **Question:** Start by understanding the input question to be answered.
|
||||
|
||||
# **Thoughts:** Considering the user's question, previously executed steps, and the plan, decide whether the current step requires the use of a tool or code_executing. Solve the problem step by step, only displaying the thought process necessary for the current step of solving the problem. If a tool can be used, provide its name and parameters. If code_executing is required, outline the plan for executing this step.
|
||||
|
||||
# **Action Status:** stoped, tool_using, or code_executing. (Choose one from these three statuses.)
|
||||
# If the task is done, set it to 'stoped'.
|
||||
# If using a tool, set it to 'tool_using'.
|
||||
# If writing code, set it to 'code_executing'.
|
||||
|
||||
# **Action:**
|
||||
|
||||
# If using a tool, use the tools by formatting the tool action in JSON from Question and Observation:. The format should be:
|
||||
# ```json
|
||||
# {{
|
||||
# "tool_name": "$TOOL_NAME",
|
||||
# "tool_params": "$INPUT"
|
||||
# }}
|
||||
# ```
|
||||
|
||||
# If the problem cannot be solved with a tool at the moment, then proceed to solve the issue using code. Output the following format to execute the code:
|
||||
|
||||
# ```python
|
||||
# Write your code here
|
||||
# ```
|
||||
|
||||
# **Observation:** Check the results and effects of the executed action.
|
||||
|
||||
# ... (Repeat this Thoughts/Action/Observation cycle as needed)
|
||||
|
||||
# **Thoughts:** Conclude the final response to the input question.
|
||||
|
||||
# **Action Status:** stoped
|
||||
|
||||
# **Action:** The final answer or guidance to the original input question.
|
||||
# """
|
||||
|
||||
|
||||
# REACT_TOOL_AND_CODE_PROMPT = """你是一个使用工具与代码的助手。
|
||||
# 如果现有工具不足以完成整个任务,请不要添加不存在的工具,只使用现有工具完成可能的部分。
|
||||
# 如果当前步骤不能使用工具完成,将由代码来完成。
|
||||
# 有效的"action"值为:"stopped"(已经完成用户的任务) 、 "tool_using" (使用工具来回答问题) 或 'code_executing'(结合总结下述思维链过程编写下一步的可执行代码)。
|
||||
# 尽可能地以有帮助和准确的方式回应人类,你可以使用以下工具:
|
||||
# {formatted_tools}
|
||||
# 如果现在的步骤可以用工具解决问题,请仅在每个$JSON_BLOB中提供一个action,如下所示:
|
||||
# ```
|
||||
# {{{{
|
||||
# "action": $ACTION,
|
||||
# "tool_name": $TOOL_NAME
|
||||
# "tool_params": $INPUT
|
||||
# }}}}
|
||||
# ```
|
||||
# 若当前无法通过工具解决问题,则使用代码解决问题
|
||||
# 请仅在每个$JSON_BLOB中提供一个action,如下所示:
|
||||
# ```
|
||||
# {{{{'action': $ACTION,'code_content': $CODE}}}}
|
||||
# ```
|
||||
|
||||
# 按照以下思维链格式进行回应($JSON_BLOB要求符合上述规定):
|
||||
# 问题:输入问题以回答
|
||||
# 思考:考虑之前和之后的步骤
|
||||
# 行动:
|
||||
# ```
|
||||
# $JSON_BLOB
|
||||
# ```
|
||||
# 观察:行动结果
|
||||
# ...(重复思考/行动/观察N次)
|
||||
# 思考:我知道该如何回应
|
||||
# 行动:
|
||||
# ```
|
||||
# $JSON_BLOB
|
||||
# ```
|
||||
# """
|
|
@ -1,47 +1,43 @@
|
|||
REACT_TOOL_PROMPT = """#### Tool Agent Assistance Guidance
|
||||
REACT_TOOL_PROMPT = """#### Agent Profile
|
||||
|
||||
When interacting with users, your role is to respond in a helpful and accurate manner using the tools available. Follow the steps below to ensure efficient and effective use of the tools.
|
||||
|
||||
Please note that all the tools you can use are listed below. You can only choose from these tools for use. If there are no suitable tools, please do not invent any tools. Just let the user know that you do not have suitable tools to use.
|
||||
Please note that all the tools you can use are listed below. You can only choose from these tools for use.
|
||||
|
||||
#### Tool List
|
||||
If there are no suitable tools, please do not invent any tools. Just let the user know that you do not have suitable tools to use.
|
||||
|
||||
you can use these tools:\n{formatted_tools}
|
||||
ATTENTION: The Action Status field ensures that the tools or code mentioned in the Action can be parsed smoothly. Please make sure not to omit the Action Status field when replying.
|
||||
|
||||
valid "tool_name" value is:\n{tool_names}
|
||||
#### Response Output Format
|
||||
|
||||
#### Response Process
|
||||
**Thoughts:** According the previous observations, plan the approach for using the tool effectively.
|
||||
|
||||
**Question:** Start by understanding the input question to be answered.
|
||||
|
||||
**Thoughts:** Based on the question and previous observations, plan the approach for using the tool effectively.
|
||||
|
||||
**Action Status:** Set to either 'stoped' or 'tool_using'. If 'stoped', provide the final response to the original question. If 'tool_using', proceed with using the specified tool.
|
||||
**Action Status:** Set to either 'stopped' or 'tool_using'. If 'stopped', provide the final response to the original question. If 'tool_using', proceed with using the specified tool.
|
||||
|
||||
**Action:** Use the tools by formatting the tool action in JSON. The format should be:
|
||||
|
||||
```json
|
||||
{{
|
||||
{
|
||||
"tool_name": "$TOOL_NAME",
|
||||
"tool_params": "$INPUT"
|
||||
}}
|
||||
}
|
||||
```
|
||||
|
||||
**Observation:** Evaluate the outcome of the tool's usage.
|
||||
|
||||
... (Repeat this Thoughts/Action/Observation cycle as needed)
|
||||
... (Repeat this Thoughts/Action Status/Action/Observation cycle as needed)
|
||||
|
||||
**Thoughts:** Determine the final response based on the results.
|
||||
|
||||
**Action Status:** Set to 'stoped'
|
||||
**Action Status:** Set to 'stopped'
|
||||
|
||||
**Action:** Conclude with the final response to the original question in this format:
|
||||
|
||||
```json
|
||||
{{
|
||||
{
|
||||
"tool_params": "Final response to be provided to the user",
|
||||
"tool_name": "notool",
|
||||
}}
|
||||
}
|
||||
```
|
||||
"""
|
||||
|
||||
|
@ -49,7 +45,7 @@ valid "tool_name" value is:\n{tool_names}
|
|||
# REACT_TOOL_PROMPT = """尽可能地以有帮助和准确的方式回应人类。您可以使用以下工具:
|
||||
# {formatted_tools}
|
||||
# 使用json blob来指定一个工具,提供一个action关键字(工具名称)和一个tool_params关键字(工具输入)。
|
||||
# 有效的"action"值为:"stoped" 或 "tool_using" (使用工具来回答问题)
|
||||
# 有效的"action"值为:"stopped" 或 "tool_using" (使用工具来回答问题)
|
||||
# 有效的"tool_name"值为:{tool_names}
|
||||
# 请仅在每个$JSON_BLOB中提供一个action,如下所示:
|
||||
# ```
|
||||
|
@ -73,7 +69,7 @@ valid "tool_name" value is:\n{tool_names}
|
|||
# 行动:
|
||||
# ```
|
||||
# {{{{
|
||||
# "action": "stoped",
|
||||
# "action": "stopped",
|
||||
# "tool_name": "notool",
|
||||
# "tool_params": "最终返回答案给到用户"
|
||||
# }}}}
|
|
@ -1,4 +1,4 @@
|
|||
REFINE_TEMPLATE_PROMPT = """#### Refiner Assistance Guidance
|
||||
REFINE_TEMPLATE_PROMPT = """#### Agent Profile
|
||||
|
||||
When users have a sequence of tasks that require optimization or adjustment based on feedback from the context, your role is to refine the existing plan.
|
||||
Your task is to identify where improvements can be made and provide a revised plan that is more efficient or effective.
|
|
@ -0,0 +1,40 @@
|
|||
CONV_SUMMARY_PROMPT = """尽可能地以有帮助和准确的方式回应人类,根据“背景信息”中的有效信息回答问题,
|
||||
使用 JSON Blob 来指定一个返回的内容,提供一个 action(行动)。
|
||||
有效的 'action' 值为:'finished'(任务已经可以通过上下文信息可以回答) or 'continue' (根据背景信息回答问题)。
|
||||
在每个 $JSON_BLOB 中仅提供一个 action,如下所示:
|
||||
```
|
||||
{{'action': $ACTION, 'content': '根据背景信息回答问题'}}
|
||||
```
|
||||
按照以下格式进行回应:
|
||||
问题:输入问题以回答
|
||||
行动:
|
||||
```
|
||||
$JSON_BLOB
|
||||
```
|
||||
"""
|
||||
|
||||
CONV_SUMMARY_PROMPT = """尽可能地以有帮助和准确的方式回应人类
|
||||
根据“背景信息”中的有效信息回答问题,同时展现解答的过程和内容
|
||||
若能根“背景信息”回答问题,则直接回答
|
||||
否则,总结“背景信息”的内容
|
||||
"""
|
||||
|
||||
|
||||
CONV_SUMMARY_PROMPT_SPEC = """
|
||||
Your job is to summarize a history of previous messages in a conversation between an AI persona and a human.
|
||||
The conversation you are given is a fixed context window and may not be complete.
|
||||
Messages sent by the AI are marked with the 'assistant' role.
|
||||
The AI 'assistant' can also make calls to functions, whose outputs can be seen in messages with the 'function' role.
|
||||
Things the AI says in the message content are considered inner monologue and are not seen by the user.
|
||||
The only AI messages seen by the user are from when the AI uses 'send_message'.
|
||||
Messages the user sends are in the 'user' role.
|
||||
The 'user' role is also used for important system events, such as login events and heartbeat events (heartbeats run the AI's program without user action, allowing the AI to act without prompting from the user sending them a message).
|
||||
Summarize what happened in the conversation from the perspective of the AI (use the first person).
|
||||
Keep your summary less than 100 words, do NOT exceed this word limit.
|
||||
Only output the summary, do NOT include anything else in your output.
|
||||
|
||||
--- conversation
|
||||
{conversation}
|
||||
---
|
||||
|
||||
"""
|
|
@ -0,0 +1,430 @@
|
|||
from abc import abstractmethod, ABC
|
||||
from typing import List
|
||||
import os, sys, copy, json
|
||||
from jieba.analyse import extract_tags
|
||||
from collections import Counter
|
||||
from loguru import logger
|
||||
|
||||
from langchain.docstore.document import Document
|
||||
|
||||
|
||||
from .schema import Memory, Message
|
||||
from coagent.service.service_factory import KBServiceFactory
|
||||
from coagent.llm_models import getChatModel, getChatModelFromConfig
|
||||
from coagent.llm_models.llm_config import EmbedConfig, LLMConfig
|
||||
from coagent.embeddings.utils import load_embeddings_from_path
|
||||
from coagent.utils.common_utils import save_to_json_file, read_json_file, addMinutesToTime
|
||||
from coagent.connector.configs.prompts import CONV_SUMMARY_PROMPT_SPEC
|
||||
from coagent.orm import table_init
|
||||
# from configs.model_config import KB_ROOT_PATH, EMBEDDING_MODEL, EMBEDDING_DEVICE, SCORE_THRESHOLD
|
||||
# from configs.model_config import embedding_model_dict
|
||||
|
||||
|
||||
class BaseMemoryManager(ABC):
|
||||
"""
|
||||
This class represents a local memory manager that inherits from BaseMemoryManager.
|
||||
|
||||
Attributes:
|
||||
- user_name: A string representing the user name. Default is "default".
|
||||
- unique_name: A string representing the unique name. Default is "default".
|
||||
- memory_type: A string representing the memory type. Default is "recall".
|
||||
- do_init: A boolean indicating whether to initialize. Default is False.
|
||||
- current_memory: An instance of Memory class representing the current memory.
|
||||
- recall_memory: An instance of Memory class representing the recall memory.
|
||||
- summary_memory: An instance of Memory class representing the summary memory.
|
||||
- save_message_keys: A list of strings representing the keys for saving messages.
|
||||
|
||||
Methods:
|
||||
- __init__: Initializes the LocalMemoryManager with the given user_name, unique_name, memory_type, and do_init.
|
||||
- init_vb: Initializes the vb.
|
||||
- append: Appends a message to the recall memory, current memory, and summary memory.
|
||||
- extend: Extends the recall memory, current memory, and summary memory.
|
||||
- save: Saves the memory to the specified directory.
|
||||
- load: Loads the memory from the specified directory and returns a Memory instance.
|
||||
- save_new_to_vs: Saves new messages to the vector space.
|
||||
- save_to_vs: Saves the memory to the vector space.
|
||||
- router_retrieval: Routes the retrieval based on the retrieval type.
|
||||
- embedding_retrieval: Retrieves messages based on embedding.
|
||||
- text_retrieval: Retrieves messages based on text.
|
||||
- datetime_retrieval: Retrieves messages based on datetime.
|
||||
- recursive_summary: Performs recursive summarization of messages.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
user_name: str = "default",
|
||||
unique_name: str = "default",
|
||||
memory_type: str = "recall",
|
||||
do_init: bool = False,
|
||||
):
|
||||
"""
|
||||
Initializes the LocalMemoryManager with the given parameters.
|
||||
|
||||
Args:
|
||||
- user_name: A string representing the user name. Default is "default".
|
||||
- unique_name: A string representing the unique name. Default is "default".
|
||||
- memory_type: A string representing the memory type. Default is "recall".
|
||||
- do_init: A boolean indicating whether to initialize. Default is False.
|
||||
"""
|
||||
self.user_name = user_name
|
||||
self.unique_name = unique_name
|
||||
self.memory_type = memory_type
|
||||
self.do_init = do_init
|
||||
self.current_memory = Memory(messages=[])
|
||||
self.recall_memory = Memory(messages=[])
|
||||
self.summary_memory = Memory(messages=[])
|
||||
self.save_message_keys = [
|
||||
'chat_index', 'role_name', 'role_type', 'role_prompt', 'input_query', 'origin_query',
|
||||
'datetime', 'role_content', 'step_content', 'parsed_output', 'spec_parsed_output', 'parsed_output_list',
|
||||
'task', 'db_docs', 'code_docs', 'search_docs', 'phase_name', 'chain_name', 'customed_kargs']
|
||||
self.init_vb()
|
||||
|
||||
def init_vb(self):
|
||||
"""
|
||||
Initializes the vb.
|
||||
"""
|
||||
pass
|
||||
|
||||
def append(self, message: Message):
|
||||
"""
|
||||
Appends a message to the recall memory, current memory, and summary memory.
|
||||
|
||||
Args:
|
||||
- message: An instance of Message class representing the message to be appended.
|
||||
"""
|
||||
pass
|
||||
|
||||
def extend(self, memory: Memory):
|
||||
"""
|
||||
Extends the recall memory, current memory, and summary memory.
|
||||
|
||||
Args:
|
||||
- memory: An instance of Memory class representing the memory to be extended.
|
||||
"""
|
||||
pass
|
||||
|
||||
def save(self, save_dir: str = ""):
|
||||
"""
|
||||
Saves the memory to the specified directory.
|
||||
|
||||
Args:
|
||||
- save_dir: A string representing the directory to save the memory. Default is KB_ROOT_PATH.
|
||||
"""
|
||||
pass
|
||||
|
||||
def load(self, load_dir: str = "") -> Memory:
|
||||
"""
|
||||
Loads the memory from the specified directory and returns a Memory instance.
|
||||
|
||||
Args:
|
||||
- load_dir: A string representing the directory to load the memory from. Default is KB_ROOT_PATH.
|
||||
|
||||
Returns:
|
||||
- An instance of Memory class representing the loaded memory.
|
||||
"""
|
||||
pass
|
||||
|
||||
def save_new_to_vs(self, messages: List[Message]):
|
||||
"""
|
||||
Saves new messages to the vector space.
|
||||
|
||||
Args:
|
||||
- messages: A list of Message instances representing the messages to be saved.
|
||||
- embed_model: A string representing the embedding model. Default is EMBEDDING_MODEL.
|
||||
- embed_device: A string representing the embedding device. Default is EMBEDDING_DEVICE.
|
||||
"""
|
||||
pass
|
||||
|
||||
def save_to_vs(self, embed_model="", embed_device=""):
|
||||
"""
|
||||
Saves the memory to the vector space.
|
||||
|
||||
Args:
|
||||
- embed_model: A string representing the embedding model. Default is EMBEDDING_MODEL.
|
||||
- embed_device: A string representing the embedding device. Default is EMBEDDING_DEVICE.
|
||||
"""
|
||||
pass
|
||||
|
||||
def router_retrieval(self, text: str=None, datetime: str = None, n=5, top_k=5, retrieval_type: str = "embedding", **kwargs) -> List[Message]:
|
||||
"""
|
||||
Routes the retrieval based on the retrieval type.
|
||||
|
||||
Args:
|
||||
- text: A string representing the text for retrieval. Default is None.
|
||||
- datetime: A string representing the datetime for retrieval. Default is None.
|
||||
- n: An integer representing the number of messages. Default is 5.
|
||||
- top_k: An integer representing the top k messages. Default is 5.
|
||||
- retrieval_type: A string representing the retrieval type. Default is "embedding".
|
||||
- **kwargs: Additional keyword arguments for retrieval.
|
||||
|
||||
Returns:
|
||||
- A list of Message instances representing the retrieved messages.
|
||||
"""
|
||||
pass
|
||||
|
||||
def embedding_retrieval(self, text: str, embed_model="", top_k=1, score_threshold=1.0, **kwargs) -> List[Message]:
|
||||
"""
|
||||
Retrieves messages based on embedding.
|
||||
|
||||
Args:
|
||||
- text: A string representing the text for retrieval.
|
||||
- embed_model: A string representing the embedding model. Default is EMBEDDING_MODEL.
|
||||
- top_k: An integer representing the top k messages. Default is 1.
|
||||
- score_threshold: A float representing the score threshold. Default is SCORE_THRESHOLD.
|
||||
- **kwargs: Additional keyword arguments for retrieval.
|
||||
|
||||
Returns:
|
||||
- A list of Message instances representing the retrieved messages.
|
||||
"""
|
||||
pass
|
||||
|
||||
def text_retrieval(self, text: str, **kwargs) -> List[Message]:
|
||||
"""
|
||||
Retrieves messages based on text.
|
||||
|
||||
Args:
|
||||
- text: A string representing the text for retrieval.
|
||||
- **kwargs: Additional keyword arguments for retrieval.
|
||||
|
||||
Returns:
|
||||
- A list of Message instances representing the retrieved messages.
|
||||
"""
|
||||
pass
|
||||
|
||||
def datetime_retrieval(self, datetime: str, text: str = None, n: int = 5, **kwargs) -> List[Message]:
|
||||
"""
|
||||
Retrieves messages based on datetime.
|
||||
|
||||
Args:
|
||||
- datetime: A string representing the datetime for retrieval.
|
||||
- text: A string representing the text for retrieval. Default is None.
|
||||
- n: An integer representing the number of messages. Default is 5.
|
||||
- **kwargs: Additional keyword arguments for retrieval.
|
||||
|
||||
Returns:
|
||||
- A list of Message instances representing the retrieved messages.
|
||||
"""
|
||||
pass
|
||||
|
||||
def recursive_summary(self, messages: List[Message], split_n: int = 20) -> List[Message]:
|
||||
"""
|
||||
Performs recursive summarization of messages.
|
||||
|
||||
Args:
|
||||
- messages: A list of Message instances representing the messages to be summarized.
|
||||
- split_n: An integer representing the split n. Default is 20.
|
||||
|
||||
Returns:
|
||||
- A list of Message instances representing the summarized messages.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class LocalMemoryManager(BaseMemoryManager):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embed_config: EmbedConfig,
|
||||
llm_config: LLMConfig,
|
||||
user_name: str = "default",
|
||||
unique_name: str = "default",
|
||||
memory_type: str = "recall",
|
||||
do_init: bool = False,
|
||||
kb_root_path: str = "",
|
||||
):
|
||||
self.user_name = user_name
|
||||
self.unique_name = unique_name
|
||||
self.memory_type = memory_type
|
||||
self.do_init = do_init
|
||||
self.kb_root_path = kb_root_path
|
||||
self.embed_config: EmbedConfig = embed_config
|
||||
self.llm_config: LLMConfig = llm_config
|
||||
self.current_memory = Memory(messages=[])
|
||||
self.recall_memory = Memory(messages=[])
|
||||
self.summary_memory = Memory(messages=[])
|
||||
self.save_message_keys = [
|
||||
'chat_index', 'role_name', 'role_type', 'role_prompt', 'input_query', 'origin_query',
|
||||
'datetime', 'role_content', 'step_content', 'parsed_output', 'spec_parsed_output', 'parsed_output_list',
|
||||
'task', 'db_docs', 'code_docs', 'search_docs', 'phase_name', 'chain_name', 'customed_kargs']
|
||||
self.init_vb()
|
||||
|
||||
def init_vb(self):
|
||||
vb_name = f"{self.user_name}/{self.unique_name}/{self.memory_type}"
|
||||
# default to recreate a new vb
|
||||
table_init()
|
||||
vb = KBServiceFactory.get_service_by_name(vb_name, self.embed_config, self.kb_root_path)
|
||||
if vb:
|
||||
status = vb.clear_vs()
|
||||
|
||||
if not self.do_init:
|
||||
self.load(self.kb_root_path)
|
||||
self.save_to_vs()
|
||||
|
||||
def append(self, message: Message):
|
||||
self.recall_memory.append(message)
|
||||
#
|
||||
if message.role_type == "summary":
|
||||
self.summary_memory.append(message)
|
||||
else:
|
||||
self.current_memory.append(message)
|
||||
|
||||
self.save(self.kb_root_path)
|
||||
self.save_new_to_vs([message])
|
||||
|
||||
def extend(self, memory: Memory):
|
||||
self.recall_memory.extend(memory)
|
||||
self.current_memory.extend(self.recall_memory.filter_by_role_type(["summary"]))
|
||||
self.summary_memory.extend(self.recall_memory.select_by_role_type(["summary"]))
|
||||
self.save(self.kb_root_path)
|
||||
self.save_new_to_vs(memory.messages)
|
||||
|
||||
def save(self, save_dir: str = "./"):
|
||||
file_path = os.path.join(save_dir, f"{self.user_name}/{self.unique_name}/{self.memory_type}/converation.jsonl")
|
||||
memory_messages = self.recall_memory.dict()
|
||||
memory_messages = {k: [
|
||||
{kkk: vvv for kkk, vvv in vv.items() if kkk in self.save_message_keys}
|
||||
for vv in v ]
|
||||
for k, v in memory_messages.items()
|
||||
}
|
||||
#
|
||||
save_to_json_file(memory_messages, file_path)
|
||||
|
||||
def load(self, load_dir: str = "./") -> Memory:
|
||||
file_path = os.path.join(load_dir, f"{self.user_name}/{self.unique_name}/{self.memory_type}/converation.jsonl")
|
||||
|
||||
if os.path.exists(file_path):
|
||||
self.recall_memory = Memory(**read_json_file(file_path))
|
||||
self.current_memory = Memory(messages=self.recall_memory.filter_by_role_type(["summary"]))
|
||||
self.summary_memory = Memory(messages=self.recall_memory.select_by_role_type(["summary"]))
|
||||
|
||||
def save_new_to_vs(self, messages: List[Message]):
|
||||
if self.embed_config:
|
||||
vb_name = f"{self.user_name}/{self.unique_name}/{self.memory_type}"
|
||||
# default to faiss, todo: add new vstype
|
||||
vb = KBServiceFactory.get_service(vb_name, "faiss", self.embed_config, self.kb_root_path)
|
||||
embeddings = load_embeddings_from_path(self.embed_config.embed_model_path, self.embed_config.model_device,)
|
||||
messages = [
|
||||
{k: v for k, v in m.dict().items() if k in self.save_message_keys}
|
||||
for m in messages]
|
||||
docs = [{"page_content": m["step_content"] or m["role_content"] or m["input_query"] or m["origin_query"], "metadata": m} for m in messages]
|
||||
docs = [Document(**doc) for doc in docs]
|
||||
vb.do_add_doc(docs, embeddings)
|
||||
|
||||
def save_to_vs(self):
|
||||
vb_name = f"{self.user_name}/{self.unique_name}/{self.memory_type}"
|
||||
# default to recreate a new vb
|
||||
vb = KBServiceFactory.get_service_by_name(vb_name, self.embed_config, self.kb_root_path)
|
||||
if vb:
|
||||
status = vb.clear_vs()
|
||||
# create_kb(vb_name, "faiss", embed_model)
|
||||
|
||||
# default to faiss, todo: add new vstype
|
||||
vb = KBServiceFactory.get_service(vb_name, "faiss", self.embed_config, self.kb_root_path)
|
||||
embeddings = load_embeddings_from_path(self.embed_config.embed_model_path, self.embed_config.model_device,)
|
||||
messages = self.recall_memory.dict()
|
||||
messages = [
|
||||
{kkk: vvv for kkk, vvv in vv.items() if kkk in self.save_message_keys}
|
||||
for k, v in messages.items() for vv in v]
|
||||
docs = [{"page_content": m["step_content"] or m["role_content"] or m["input_query"] or m["origin_query"], "metadata": m} for m in messages]
|
||||
docs = [Document(**doc) for doc in docs]
|
||||
vb.do_add_doc(docs, embeddings)
|
||||
|
||||
# def load_from_vs(self, embed_model=EMBEDDING_MODEL) -> Memory:
|
||||
# vb_name = f"{self.user_name}/{self.unique_name}/{self.memory_type}"
|
||||
|
||||
# create_kb(vb_name, "faiss", embed_model)
|
||||
# # default to faiss, todo: add new vstype
|
||||
# vb = KBServiceFactory.get_service(vb_name, "faiss", embed_model)
|
||||
# docs = vb.get_all_documents()
|
||||
# print(docs)
|
||||
|
||||
def router_retrieval(self, text: str=None, datetime: str = None, n=5, top_k=5, retrieval_type: str = "embedding", **kwargs) -> List[Message]:
|
||||
retrieval_func_dict = {
|
||||
"embedding": self.embedding_retrieval, "text": self.text_retrieval, "datetime": self.datetime_retrieval
|
||||
}
|
||||
|
||||
# 确保提供了合法的检索类型
|
||||
if retrieval_type not in retrieval_func_dict:
|
||||
raise ValueError(f"Invalid retrieval_type: '{retrieval_type}'. Available types: {list(retrieval_func_dict.keys())}")
|
||||
|
||||
retrieval_func = retrieval_func_dict[retrieval_type]
|
||||
#
|
||||
params = locals()
|
||||
params.pop("self")
|
||||
params.pop("retrieval_type")
|
||||
params.update(params.pop('kwargs', {}))
|
||||
#
|
||||
return retrieval_func(**params)
|
||||
|
||||
def embedding_retrieval(self, text: str, top_k=1, score_threshold=1.0, **kwargs) -> List[Message]:
|
||||
if text is None: return []
|
||||
vb_name = f"{self.user_name}/{self.unique_name}/{self.memory_type}"
|
||||
vb = KBServiceFactory.get_service(vb_name, "faiss", self.embed_config, self.kb_root_path)
|
||||
docs = vb.search_docs(text, top_k=top_k, score_threshold=score_threshold)
|
||||
return [Message(**doc.metadata) for doc, score in docs]
|
||||
|
||||
def text_retrieval(self, text: str, **kwargs) -> List[Message]:
|
||||
if text is None: return []
|
||||
return self._text_retrieval_from_cache(self.recall_memory.messages, text, score_threshold=0.3, topK=5, **kwargs)
|
||||
|
||||
def datetime_retrieval(self, datetime: str, text: str = None, n: int = 5, **kwargs) -> List[Message]:
|
||||
if datetime is None: return []
|
||||
return self._datetime_retrieval_from_cache(self.recall_memory.messages, datetime, text, n, **kwargs)
|
||||
|
||||
def _text_retrieval_from_cache(self, messages: List[Message], text: str = None, score_threshold=0.3, topK=5, tag_topK=5, **kwargs) -> List[Message]:
|
||||
keywords = extract_tags(text, topK=tag_topK)
|
||||
|
||||
matched_messages = []
|
||||
for message in messages:
|
||||
message_keywords = extract_tags(message.step_content or message.role_content or message.input_query, topK=tag_topK)
|
||||
# calculate jaccard similarity
|
||||
intersection = Counter(keywords) & Counter(message_keywords)
|
||||
union = Counter(keywords) | Counter(message_keywords)
|
||||
similarity = sum(intersection.values()) / sum(union.values())
|
||||
if similarity >= score_threshold:
|
||||
matched_messages.append((message, similarity))
|
||||
matched_messages = sorted(matched_messages, key=lambda x:x[1])
|
||||
return [m for m, s in matched_messages][:topK]
|
||||
|
||||
def _datetime_retrieval_from_cache(self, messages: List[Message], datetime: str, text: str = None, n: int = 5, **kwargs) -> List[Message]:
|
||||
# select message by datetime
|
||||
datetime_before, datetime_after = addMinutesToTime(datetime, n)
|
||||
select_messages = [
|
||||
message for message in messages
|
||||
if datetime_before<=message.datetime<=datetime_after
|
||||
]
|
||||
return self._text_retrieval_from_cache(select_messages, text)
|
||||
|
||||
def recursive_summary(self, messages: List[Message], split_n: int = 20) -> List[Message]:
|
||||
|
||||
if len(messages) == 0:
|
||||
return messages
|
||||
|
||||
newest_messages = messages[-split_n:]
|
||||
summary_messages = messages[:len(messages)-split_n]
|
||||
|
||||
while (len(newest_messages) != 0) and (newest_messages[0].role_type != "user"):
|
||||
message = newest_messages.pop(0)
|
||||
summary_messages.append(message)
|
||||
|
||||
# summary
|
||||
# model = getChatModel(temperature=0.2)
|
||||
model = getChatModelFromConfig(self.llm_config)
|
||||
summary_content = '\n\n'.join([
|
||||
m.role_type + "\n" + "\n".join(([f"*{k}* {v}" for parsed_output in m.parsed_output_list for k, v in parsed_output.items() if k not in ['Action Status']]))
|
||||
for m in summary_messages if m.role_type not in ["summary"]
|
||||
])
|
||||
|
||||
summary_prompt = CONV_SUMMARY_PROMPT_SPEC.format(conversation=summary_content)
|
||||
content = model.predict(summary_prompt)
|
||||
summary_message = Message(
|
||||
role_name="summaryer",
|
||||
role_type="summary",
|
||||
role_content=content,
|
||||
step_content=content,
|
||||
parsed_output_list=[],
|
||||
customed_kargs={}
|
||||
)
|
||||
summary_message.parsed_output_list.append({"summary": content})
|
||||
newest_messages.insert(0, summary_message)
|
||||
return newest_messages
|
|
@ -0,0 +1,267 @@
|
|||
import re, traceback, uuid, copy, json, os
|
||||
from loguru import logger
|
||||
|
||||
|
||||
# from configs.server_config import SANDBOX_SERVER
|
||||
# from configs.model_config import JUPYTER_WORK_PATH
|
||||
from coagent.connector.schema import (
|
||||
Memory, Role, Message, ActionStatus, CodeDoc, Doc, LogVerboseEnum
|
||||
)
|
||||
from coagent.connector.memory_manager import BaseMemoryManager
|
||||
from coagent.tools import DDGSTool, DocRetrieval, CodeRetrieval
|
||||
from coagent.sandbox import PyCodeBox, CodeBoxResponse
|
||||
from coagent.llm_models.llm_config import LLMConfig, EmbedConfig
|
||||
from .utils import parse_dict_to_dict, parse_text_to_dict
|
||||
|
||||
|
||||
class MessageUtils:
|
||||
def __init__(
|
||||
self,
|
||||
role: Role = None,
|
||||
sandbox_server: dict = {},
|
||||
jupyter_work_path: str = "./",
|
||||
embed_config: EmbedConfig = None,
|
||||
llm_config: LLMConfig = None,
|
||||
kb_root_path: str = "",
|
||||
log_verbose: str = "0"
|
||||
) -> None:
|
||||
self.role = role
|
||||
self.sandbox_server = sandbox_server
|
||||
self.jupyter_work_path = jupyter_work_path
|
||||
self.embed_config = embed_config
|
||||
self.llm_config = llm_config
|
||||
self.kb_root_path = kb_root_path
|
||||
self.codebox = PyCodeBox(
|
||||
remote_url=self.sandbox_server.get("url", "http://127.0.0.1:5050"),
|
||||
remote_ip=self.sandbox_server.get("host", "http://127.0.0.1"),
|
||||
remote_port=self.sandbox_server.get("port", "5050"),
|
||||
jupyter_work_path=jupyter_work_path,
|
||||
token="mytoken",
|
||||
do_code_exe=True,
|
||||
do_remote=self.sandbox_server.get("do_remote", False),
|
||||
do_check_net=False
|
||||
)
|
||||
self.log_verbose = os.environ.get("log_verbose", "0") or log_verbose
|
||||
|
||||
def inherit_extrainfo(self, input_message: Message, output_message: Message):
|
||||
output_message.db_docs = input_message.db_docs
|
||||
output_message.search_docs = input_message.search_docs
|
||||
output_message.code_docs = input_message.code_docs
|
||||
output_message.figures.update(input_message.figures)
|
||||
output_message.origin_query = input_message.origin_query
|
||||
output_message.code_engine_name = input_message.code_engine_name
|
||||
|
||||
output_message.doc_engine_name = input_message.doc_engine_name
|
||||
output_message.search_engine_name = input_message.search_engine_name
|
||||
output_message.top_k = input_message.top_k
|
||||
output_message.score_threshold = input_message.score_threshold
|
||||
output_message.cb_search_type = input_message.cb_search_type
|
||||
output_message.do_doc_retrieval = input_message.do_doc_retrieval
|
||||
output_message.do_code_retrieval = input_message.do_code_retrieval
|
||||
output_message.do_tool_retrieval = input_message.do_tool_retrieval
|
||||
#
|
||||
output_message.tools = input_message.tools
|
||||
output_message.agents = input_message.agents
|
||||
|
||||
# update customed_kargs, if exist, keep; else add
|
||||
customed_kargs = copy.deepcopy(input_message.customed_kargs)
|
||||
customed_kargs.update(output_message.customed_kargs)
|
||||
output_message.customed_kargs = customed_kargs
|
||||
return output_message
|
||||
|
||||
def inherit_baseparam(self, input_message: Message, output_message: Message):
|
||||
# 只更新参数
|
||||
output_message.doc_engine_name = input_message.doc_engine_name
|
||||
output_message.search_engine_name = input_message.search_engine_name
|
||||
output_message.top_k = input_message.top_k
|
||||
output_message.score_threshold = input_message.score_threshold
|
||||
output_message.cb_search_type = input_message.cb_search_type
|
||||
output_message.do_doc_retrieval = input_message.do_doc_retrieval
|
||||
output_message.do_code_retrieval = input_message.do_code_retrieval
|
||||
output_message.do_tool_retrieval = input_message.do_tool_retrieval
|
||||
#
|
||||
output_message.tools = input_message.tools
|
||||
output_message.agents = input_message.agents
|
||||
# 存在bug导致相同key被覆盖
|
||||
output_message.customed_kargs.update(input_message.customed_kargs)
|
||||
return output_message
|
||||
|
||||
def get_extrainfo_step(self, message: Message, do_search, do_doc_retrieval, do_code_retrieval, do_tool_retrieval) -> Message:
|
||||
''''''
|
||||
if do_search:
|
||||
message = self.get_search_retrieval(message)
|
||||
|
||||
if do_doc_retrieval:
|
||||
message = self.get_doc_retrieval(message)
|
||||
|
||||
if do_code_retrieval:
|
||||
message = self.get_code_retrieval(message)
|
||||
|
||||
if do_tool_retrieval:
|
||||
message = self.get_tool_retrieval(message)
|
||||
|
||||
return message
|
||||
|
||||
def get_search_retrieval(self, message: Message,) -> Message:
|
||||
SEARCH_ENGINES = {"duckduckgo": DDGSTool}
|
||||
search_docs = []
|
||||
for idx, doc in enumerate(SEARCH_ENGINES["duckduckgo"].run(message.role_content, 3)):
|
||||
doc.update({"index": idx})
|
||||
search_docs.append(Doc(**doc))
|
||||
message.search_docs = search_docs
|
||||
return message
|
||||
|
||||
def get_doc_retrieval(self, message: Message) -> Message:
|
||||
query = message.role_content
|
||||
knowledge_basename = message.doc_engine_name
|
||||
top_k = message.top_k
|
||||
score_threshold = message.score_threshold
|
||||
if knowledge_basename:
|
||||
docs = DocRetrieval.run(query, knowledge_basename, top_k, score_threshold, self.embed_config, self.kb_root_path)
|
||||
message.db_docs = [Doc(**doc) for doc in docs]
|
||||
return message
|
||||
|
||||
def get_code_retrieval(self, message: Message) -> Message:
|
||||
query = message.input_query
|
||||
code_engine_name = message.code_engine_name
|
||||
history_node_list = message.history_node_list
|
||||
code_docs = CodeRetrieval.run(code_engine_name, query, code_limit=message.top_k, history_node_list=history_node_list, search_type=message.cb_search_type,
|
||||
llm_config=self.llm_config, embed_config=self.embed_config,)
|
||||
message.code_docs = [CodeDoc(**doc) for doc in code_docs]
|
||||
return message
|
||||
|
||||
def get_tool_retrieval(self, message: Message) -> Message:
|
||||
return message
|
||||
|
||||
def step_router(self, message: Message, history: Memory = None, background: Memory = None, memory_manager: BaseMemoryManager=None) -> tuple[Message, ...]:
|
||||
''''''
|
||||
if LogVerboseEnum.ge(LogVerboseEnum.Log1Level, self.log_verbose):
|
||||
logger.info(f"message.action_status: {message.action_status}")
|
||||
|
||||
observation_message = None
|
||||
if message.action_status == ActionStatus.CODE_EXECUTING:
|
||||
message, observation_message = self.code_step(message)
|
||||
elif message.action_status == ActionStatus.TOOL_USING:
|
||||
message, observation_message = self.tool_step(message)
|
||||
elif message.action_status == ActionStatus.CODING2FILE:
|
||||
self.save_code2file(message, self.jupyter_work_path)
|
||||
elif message.action_status == ActionStatus.CODE_RETRIEVAL:
|
||||
pass
|
||||
elif message.action_status == ActionStatus.CODING:
|
||||
pass
|
||||
|
||||
return message, observation_message
|
||||
|
||||
def code_step(self, message: Message) -> Message:
|
||||
'''execute code'''
|
||||
# logger.debug(f"message.role_content: {message.role_content}, message.code_content: {message.code_content}")
|
||||
code_answer = self.codebox.chat('```python\n{}```'.format(message.code_content))
|
||||
code_prompt = f"The return error after executing the above code is {code_answer.code_exe_response},need to recover.\n" \
|
||||
if code_answer.code_exe_type == "error" else f"The return information after executing the above code is {code_answer.code_exe_response}.\n"
|
||||
|
||||
observation_message = Message(
|
||||
role_name="observation",
|
||||
role_type="function", #self.role.role_type,
|
||||
role_content="",
|
||||
step_content="",
|
||||
input_query=message.code_content,
|
||||
)
|
||||
uid = str(uuid.uuid1())
|
||||
if code_answer.code_exe_type == "image/png":
|
||||
message.figures[uid] = code_answer.code_exe_response
|
||||
message.code_answer = f"\n**Observation:**: The return figure name is {uid} after executing the above code.\n"
|
||||
message.observation = f"\n**Observation:**: The return figure name is {uid} after executing the above code.\n"
|
||||
message.step_content += f"\n**Observation:**: The return figure name is {uid} after executing the above code.\n"
|
||||
# message.role_content += f"\n**Observation:**:执行上述代码后生成一张图片, 图片名为{uid}\n"
|
||||
observation_message.role_content = f"\n**Observation:**: The return figure name is {uid} after executing the above code.\n"
|
||||
observation_message.parsed_output = {"Observation": f"The return figure name is {uid} after executing the above code.\n"}
|
||||
else:
|
||||
message.code_answer = code_answer.code_exe_response
|
||||
message.observation = code_answer.code_exe_response
|
||||
message.step_content += f"\n**Observation:**: {code_prompt}\n"
|
||||
# message.role_content += f"\n**Observation:**: {code_prompt}\n"
|
||||
observation_message.role_content = f"\n**Observation:**: {code_prompt}\n"
|
||||
observation_message.parsed_output = {"Observation": f"{code_prompt}\n"}
|
||||
|
||||
if LogVerboseEnum.ge(LogVerboseEnum.Log1Level, self.log_verbose):
|
||||
logger.info(f"**Observation:** {message.action_status}, {message.observation}")
|
||||
return message, observation_message
|
||||
|
||||
def tool_step(self, message: Message) -> Message:
|
||||
'''execute tool'''
|
||||
observation_message = Message(
|
||||
role_name="observation",
|
||||
role_type="function", #self.role.role_type,
|
||||
role_content="\n**Observation:** there is no tool can execute\n",
|
||||
step_content="",
|
||||
input_query=str(message.tool_params),
|
||||
tools=message.tools,
|
||||
)
|
||||
if LogVerboseEnum.ge(LogVerboseEnum.Log1Level, self.log_verbose):
|
||||
logger.info(f"message: {message.action_status}, {message.tool_params}")
|
||||
|
||||
tool_names = [tool.name for tool in message.tools]
|
||||
if message.tool_name not in tool_names:
|
||||
message.tool_answer = "\n**Observation:** there is no tool can execute.\n"
|
||||
message.observation = "\n**Observation:** there is no tool can execute.\n"
|
||||
# message.role_content += f"\n**Observation:**: 不存在可以执行的tool\n"
|
||||
message.step_content += f"\n**Observation:** there is no tool can execute.\n"
|
||||
observation_message.role_content = f"\n**Observation:** there is no tool can execute.\n"
|
||||
observation_message.parsed_output = {"Observation": "there is no tool can execute.\n"}
|
||||
|
||||
# logger.debug(message.tool_params)
|
||||
for tool in message.tools:
|
||||
if tool.name == message.tool_params.get("tool_name", ""):
|
||||
tool_res = tool.func(**message.tool_params.get("tool_params", {}))
|
||||
message.tool_answer = tool_res
|
||||
message.observation = tool_res
|
||||
# message.role_content += f"\n**Observation:**: {tool_res}\n"
|
||||
message.step_content += f"\n**Observation:** {tool_res}.\n"
|
||||
observation_message.role_content = f"\n**Observation:** {tool_res}.\n"
|
||||
observation_message.parsed_output = {"Observation": f"{tool_res}.\n"}
|
||||
break
|
||||
|
||||
if LogVerboseEnum.ge(LogVerboseEnum.Log1Level, self.log_verbose):
|
||||
logger.info(f"**Observation:** {message.action_status}, {message.observation}")
|
||||
return message, observation_message
|
||||
|
||||
def parser(self, message: Message) -> Message:
|
||||
''''''
|
||||
content = message.role_content
|
||||
# parse start
|
||||
parsed_dict = parse_text_to_dict(content)
|
||||
spec_parsed_dict = parse_dict_to_dict(parsed_dict)
|
||||
# select parse value
|
||||
action_value = parsed_dict.get('Action Status')
|
||||
if action_value:
|
||||
action_value = action_value.lower()
|
||||
|
||||
code_content_value = spec_parsed_dict.get('code')
|
||||
if action_value == 'tool_using':
|
||||
tool_params_value = spec_parsed_dict.get('json')
|
||||
else:
|
||||
tool_params_value = None
|
||||
|
||||
# add parse value to message
|
||||
message.action_status = action_value or "default"
|
||||
message.code_content = code_content_value
|
||||
message.tool_params = tool_params_value
|
||||
message.parsed_output = parsed_dict
|
||||
message.spec_parsed_output = spec_parsed_dict
|
||||
return message
|
||||
|
||||
def save_code2file(self, message: Message, project_dir="./"):
|
||||
filename = message.parsed_output.get("SaveFileName")
|
||||
code = message.spec_parsed_output.get("code")
|
||||
|
||||
for k, v in {">": ">", "≥": ">=", "<": "<", "≤": "<="}.items():
|
||||
code = code.replace(k, v)
|
||||
|
||||
file_path = os.path.join(project_dir, filename)
|
||||
|
||||
if not os.path.exists(file_path):
|
||||
os.makedirs(os.path.dirname(file_path), exist_ok=True)
|
||||
|
||||
with open(file_path, "w") as f:
|
||||
f.write(code)
|
||||
|
|
@ -0,0 +1,255 @@
|
|||
from typing import List, Union, Dict, Tuple
|
||||
import os
|
||||
import json
|
||||
import importlib
|
||||
import copy
|
||||
from loguru import logger
|
||||
|
||||
from coagent.connector.agents import BaseAgent
|
||||
from coagent.connector.chains import BaseChain
|
||||
from coagent.connector.schema import (
|
||||
Memory, Task, Message, AgentConfig, ChainConfig, PhaseConfig, LogVerboseEnum,
|
||||
CompletePhaseConfig,
|
||||
load_chain_configs, load_phase_configs, load_role_configs
|
||||
)
|
||||
from coagent.connector.memory_manager import BaseMemoryManager, LocalMemoryManager
|
||||
from coagent.connector.configs import AGETN_CONFIGS, CHAIN_CONFIGS, PHASE_CONFIGS
|
||||
from coagent.connector.message_process import MessageUtils
|
||||
from coagent.llm_models.llm_config import EmbedConfig, LLMConfig
|
||||
from coagent.base_configs.env_config import JUPYTER_WORK_PATH, KB_ROOT_PATH
|
||||
|
||||
# from configs.model_config import JUPYTER_WORK_PATH, KB_ROOT_PATH
|
||||
# from configs.server_config import SANDBOX_SERVER
|
||||
|
||||
|
||||
role_configs = load_role_configs(AGETN_CONFIGS)
|
||||
chain_configs = load_chain_configs(CHAIN_CONFIGS)
|
||||
phase_configs = load_phase_configs(PHASE_CONFIGS)
|
||||
|
||||
|
||||
CUR_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||
|
||||
|
||||
class BasePhase:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
phase_name: str,
|
||||
phase_config: CompletePhaseConfig = None,
|
||||
kb_root_path: str = KB_ROOT_PATH,
|
||||
jupyter_work_path: str = JUPYTER_WORK_PATH,
|
||||
sandbox_server: dict = {},
|
||||
embed_config: EmbedConfig = EmbedConfig(),
|
||||
llm_config: LLMConfig = LLMConfig(),
|
||||
task: Task = None,
|
||||
base_phase_config: Union[dict, str] = PHASE_CONFIGS,
|
||||
base_chain_config: Union[dict, str] = CHAIN_CONFIGS,
|
||||
base_role_config: Union[dict, str] = AGETN_CONFIGS,
|
||||
log_verbose: str = "0"
|
||||
) -> None:
|
||||
#
|
||||
self.phase_name = phase_name
|
||||
self.do_summary = False
|
||||
self.do_search = False
|
||||
self.do_code_retrieval = False
|
||||
self.do_doc_retrieval = False
|
||||
self.do_tool_retrieval = False
|
||||
# memory_pool dont have specific order
|
||||
# self.memory_pool = Memory(messages=[])
|
||||
self.embed_config = embed_config
|
||||
self.llm_config = llm_config
|
||||
self.sandbox_server = sandbox_server
|
||||
self.jupyter_work_path = jupyter_work_path
|
||||
self.kb_root_path = kb_root_path
|
||||
self.log_verbose = max(os.environ.get("log_verbose", "0"), log_verbose)
|
||||
|
||||
self.message_utils = MessageUtils(None, sandbox_server, jupyter_work_path, embed_config, llm_config, kb_root_path, log_verbose)
|
||||
self.global_memory = Memory(messages=[])
|
||||
self.phase_memory: List[Memory] = []
|
||||
# according phase name to init the phase contains
|
||||
self.chains: List[BaseChain] = self.init_chains(
|
||||
phase_name,
|
||||
phase_config,
|
||||
task=task,
|
||||
memory=None,
|
||||
base_phase_config = base_phase_config,
|
||||
base_chain_config = base_chain_config,
|
||||
base_role_config = base_role_config,
|
||||
)
|
||||
self.memory_manager: BaseMemoryManager = LocalMemoryManager(
|
||||
unique_name=phase_name, do_init=True, kb_root_path = kb_root_path, embed_config=embed_config, llm_config=llm_config
|
||||
)
|
||||
self.conv_summary_agent = BaseAgent(
|
||||
role=role_configs["conv_summary"].role,
|
||||
prompt_config=role_configs["conv_summary"].prompt_config,
|
||||
task = None, memory = None,
|
||||
llm_config=self.llm_config,
|
||||
embed_config=self.embed_config,
|
||||
sandbox_server=sandbox_server,
|
||||
jupyter_work_path=jupyter_work_path,
|
||||
kb_root_path=kb_root_path
|
||||
)
|
||||
|
||||
def astep(self, query: Message, history: Memory = None) -> Tuple[Message, Memory]:
|
||||
self.memory_manager.append(query)
|
||||
summary_message = None
|
||||
chain_message = Memory(messages=[])
|
||||
local_phase_memory = Memory(messages=[])
|
||||
# do_search、do_doc_search、do_code_search
|
||||
query = self.message_utils.get_extrainfo_step(query, self.do_search, self.do_doc_retrieval, self.do_code_retrieval, self.do_tool_retrieval)
|
||||
query.parsed_output = query.parsed_output if query.parsed_output else {"origin_query": query.input_query}
|
||||
query.parsed_output_list = query.parsed_output_list if query.parsed_output_list else [{"origin_query": query.input_query}]
|
||||
input_message = copy.deepcopy(query)
|
||||
|
||||
self.global_memory.append(input_message)
|
||||
local_phase_memory.append(input_message)
|
||||
for chain in self.chains:
|
||||
# chain can supply background and query to next chain
|
||||
for output_message, local_chain_memory in chain.astep(input_message, history, background=chain_message, memory_manager=self.memory_manager):
|
||||
# logger.debug(f"local_memory: {local_phase_memory + local_chain_memory}")
|
||||
yield output_message, local_phase_memory + local_chain_memory
|
||||
|
||||
output_message = self.message_utils.inherit_extrainfo(input_message, output_message)
|
||||
input_message = output_message
|
||||
# logger.info(f"{chain.chainConfig.chain_name} phase_step: {output_message.role_content}")
|
||||
# 这一段也有问题
|
||||
self.global_memory.extend(local_chain_memory)
|
||||
local_phase_memory.extend(local_chain_memory)
|
||||
|
||||
# whether to use summary_llm
|
||||
if self.do_summary:
|
||||
if LogVerboseEnum.ge(LogVerboseEnum.Log1Level, self.log_verbose):
|
||||
logger.info(f"{self.conv_summary_agent.role.role_name} input global memory: {local_phase_memory.to_str_messages(content_key='step_content')}")
|
||||
for summary_message in self.conv_summary_agent.astep(query, background=local_phase_memory, memory_manager=self.memory_manager):
|
||||
pass
|
||||
# summary_message = Message(**summary_message)
|
||||
summary_message.role_name = chain.chainConfig.chain_name
|
||||
summary_message = self.conv_summary_agent.message_utils.parser(summary_message)
|
||||
summary_message = self.message_utils.inherit_extrainfo(output_message, summary_message)
|
||||
chain_message.append(summary_message)
|
||||
|
||||
message = summary_message or output_message
|
||||
yield message, local_phase_memory
|
||||
|
||||
# 由于不会存在多轮chain执行,所以直接保留memory即可
|
||||
for chain in self.chains:
|
||||
self.phase_memory.append(chain.global_memory)
|
||||
# TODO:local_memory缺少添加summary的过程
|
||||
message = summary_message or output_message
|
||||
message.role_name = self.phase_name
|
||||
yield message, local_phase_memory
|
||||
|
||||
def step(self, query: Message, history: Memory = None) -> Tuple[Message, Memory]:
|
||||
for message, local_phase_memory in self.astep(query, history=history):
|
||||
pass
|
||||
return message, local_phase_memory
|
||||
|
||||
def pre_print(self, query, history: Memory = None) -> List[str]:
|
||||
chain_message = Memory(messages=[])
|
||||
for chain in self.chains:
|
||||
chain.pre_print(query, history, background=chain_message, memory_manager=self.memory_manager)
|
||||
|
||||
def init_chains(self, phase_name: str, phase_config: CompletePhaseConfig, base_phase_config, base_chain_config,
|
||||
base_role_config, task=None, memory=None) -> List[BaseChain]:
|
||||
# load config
|
||||
role_configs = load_role_configs(base_role_config)
|
||||
chain_configs = load_chain_configs(base_chain_config)
|
||||
phase_configs = load_phase_configs(base_phase_config)
|
||||
|
||||
chains = []
|
||||
self.chain_module = importlib.import_module("coagent.connector.chains")
|
||||
self.agent_module = importlib.import_module("coagent.connector.agents")
|
||||
|
||||
phase: PhaseConfig = phase_configs.get(phase_name)
|
||||
# set phase
|
||||
self.do_summary = phase.do_summary
|
||||
self.do_search = phase.do_search
|
||||
self.do_code_retrieval = phase.do_code_retrieval
|
||||
self.do_doc_retrieval = phase.do_doc_retrieval
|
||||
self.do_tool_retrieval = phase.do_tool_retrieval
|
||||
logger.info(f"start to init the phase, the phase_name is {phase_name}, it contains these chains such as {phase.chains}")
|
||||
|
||||
for chain_name in phase.chains:
|
||||
# logger.debug(f"{chain_configs.keys()}")
|
||||
chain_config: ChainConfig = chain_configs[chain_name]
|
||||
logger.info(f"start to init the chain, the chain_name is {chain_name}, it contains these agents such as {chain_config.agents}")
|
||||
|
||||
agents = []
|
||||
for agent_name in chain_config.agents:
|
||||
agent_config: AgentConfig = role_configs[agent_name]
|
||||
llm_config = copy.deepcopy(self.llm_config)
|
||||
llm_config.stop = agent_config.stop
|
||||
baseAgent: BaseAgent = getattr(self.agent_module, agent_config.role.agent_type)
|
||||
base_agent = baseAgent(
|
||||
role=agent_config.role,
|
||||
prompt_config = agent_config.prompt_config,
|
||||
prompt_manager_type=agent_config.prompt_manager_type,
|
||||
task = task,
|
||||
memory = memory,
|
||||
chat_turn=agent_config.chat_turn,
|
||||
focus_agents=agent_config.focus_agents,
|
||||
focus_message_keys=agent_config.focus_message_keys,
|
||||
llm_config=llm_config,
|
||||
embed_config=self.embed_config,
|
||||
sandbox_server=self.sandbox_server,
|
||||
jupyter_work_path=self.jupyter_work_path,
|
||||
kb_root_path=self.kb_root_path,
|
||||
log_verbose=self.log_verbose
|
||||
)
|
||||
if agent_config.role.agent_type == "SelectorAgent":
|
||||
for group_agent_name in agent_config.group_agents:
|
||||
group_agent_config = role_configs[group_agent_name]
|
||||
llm_config = copy.deepcopy(self.llm_config)
|
||||
llm_config.stop = group_agent_config.stop
|
||||
baseAgent: BaseAgent = getattr(self.agent_module, group_agent_config.role.agent_type)
|
||||
group_base_agent = baseAgent(
|
||||
role=group_agent_config.role,
|
||||
prompt_config = group_agent_config.prompt_config,
|
||||
prompt_manager_type=agent_config.prompt_manager_type,
|
||||
task = task,
|
||||
memory = memory,
|
||||
chat_turn=group_agent_config.chat_turn,
|
||||
focus_agents=group_agent_config.focus_agents,
|
||||
focus_message_keys=group_agent_config.focus_message_keys,
|
||||
llm_config=llm_config,
|
||||
embed_config=self.embed_config,
|
||||
sandbox_server=self.sandbox_server,
|
||||
jupyter_work_path=self.jupyter_work_path,
|
||||
kb_root_path=self.kb_root_path,
|
||||
log_verbose=self.log_verbose
|
||||
)
|
||||
base_agent.group_agents.append(group_base_agent)
|
||||
|
||||
agents.append(base_agent)
|
||||
|
||||
chain_instance = BaseChain(
|
||||
agents, chain_config.chat_turn,
|
||||
do_checker=chain_configs[chain_name].do_checker,
|
||||
jupyter_work_path=self.jupyter_work_path,
|
||||
sandbox_server=self.sandbox_server,
|
||||
embed_config=self.embed_config,
|
||||
llm_config=self.llm_config,
|
||||
kb_root_path=self.kb_root_path,
|
||||
log_verbose=self.log_verbose
|
||||
)
|
||||
chains.append(chain_instance)
|
||||
|
||||
return chains
|
||||
|
||||
def update(self) -> Memory:
|
||||
pass
|
||||
|
||||
def get_memory(self, ) -> Memory:
|
||||
return Memory.from_memory_list(
|
||||
[chain.get_memory() for chain in self.chains]
|
||||
)
|
||||
|
||||
def get_memory_str(self, do_all_memory=True, content_key="role_content") -> str:
|
||||
memory = self.global_memory if do_all_memory else self.phase_memory
|
||||
return "\n".join([": ".join(i) for i in memory.to_tuple_messages(content_key=content_key)])
|
||||
|
||||
def get_chains_memory(self, content_key="role_content") -> List[Tuple]:
|
||||
return [memory.to_tuple_messages(content_key=content_key) for memory in self.phase_memory]
|
||||
|
||||
def get_chains_memory_str(self, content_key="role_content") -> str:
|
||||
return "************".join([f"{chain.chainConfig.chain_name}\n" + chain.get_memory_str(content_key=content_key) for chain in self.chains])
|
|
@ -0,0 +1,350 @@
|
|||
from coagent.connector.schema import Memory, Message
|
||||
import random
|
||||
from textwrap import dedent
|
||||
import copy
|
||||
from loguru import logger
|
||||
|
||||
from coagent.connector.utils import extract_section, parse_section
|
||||
|
||||
|
||||
class PromptManager:
|
||||
def __init__(self, role_prompt="", prompt_config=None, monitored_agents=[], monitored_fields=[]):
|
||||
self.role_prompt = role_prompt
|
||||
self.monitored_agents = monitored_agents
|
||||
self.monitored_fields = monitored_fields
|
||||
self.field_handlers = {}
|
||||
self.context_handlers = {}
|
||||
self.field_order = [] # 用于普通字段的顺序
|
||||
self.context_order = [] # 单独维护上下文字段的顺序
|
||||
self.field_descriptions = {}
|
||||
self.omit_if_empty_flags = {}
|
||||
self.context_title = "### Context Data\n\n"
|
||||
|
||||
self.prompt_config = prompt_config
|
||||
if self.prompt_config:
|
||||
self.register_fields_from_config()
|
||||
|
||||
def register_field(self, field_name, function=None, title=None, description=None, is_context=True, omit_if_empty=True):
|
||||
"""
|
||||
注册一个新的字段及其处理函数。
|
||||
Args:
|
||||
field_name (str): 字段名称。
|
||||
function (callable): 处理字段数据的函数。
|
||||
title (str, optional): 字段的自定义标题(可选)。
|
||||
description (str, optional): 字段的描述(可选,可以是几句话)。
|
||||
is_context (bool, optional): 指示该字段是否为上下文字段。
|
||||
omit_if_empty (bool, optional): 如果数据为空,是否省略该字段。
|
||||
"""
|
||||
if not function:
|
||||
function = self.handle_custom_data
|
||||
|
||||
# Register the handler function based on context flag
|
||||
if is_context:
|
||||
self.context_handlers[field_name] = function
|
||||
else:
|
||||
self.field_handlers[field_name] = function
|
||||
|
||||
# Store the custom title if provided and adjust the title prefix based on context
|
||||
title_prefix = "####" if is_context else "###"
|
||||
if title is not None:
|
||||
self.field_descriptions[field_name] = f"{title_prefix} {title}\n\n"
|
||||
elif description is not None:
|
||||
# If title is not provided but description is, use description as title
|
||||
self.field_descriptions[field_name] = f"{title_prefix} {field_name.replace('_', ' ').title()}\n\n{description}\n\n"
|
||||
else:
|
||||
# If neither title nor description is provided, use the field name as title
|
||||
self.field_descriptions[field_name] = f"{title_prefix} {field_name.replace('_', ' ').title()}\n\n"
|
||||
|
||||
# Store the omit_if_empty flag for this field
|
||||
self.omit_if_empty_flags[field_name] = omit_if_empty
|
||||
|
||||
if is_context and field_name != 'context_placeholder':
|
||||
self.context_handlers[field_name] = function
|
||||
self.context_order.append(field_name)
|
||||
else:
|
||||
self.field_handlers[field_name] = function
|
||||
self.field_order.append(field_name)
|
||||
|
||||
def generate_full_prompt(self, **kwargs):
|
||||
full_prompt = []
|
||||
context_prompts = [] # 用于收集上下文内容
|
||||
is_pre_print = kwargs.get("is_pre_print", False) # 用于强制打印所有prompt 字段信息,不管有没有空
|
||||
|
||||
# 先处理上下文字段
|
||||
for field_name in self.context_order:
|
||||
handler = self.context_handlers[field_name]
|
||||
processed_prompt = handler(**kwargs)
|
||||
# Check if the field should be omitted when empty
|
||||
if self.omit_if_empty_flags.get(field_name, False) and not processed_prompt and not is_pre_print:
|
||||
continue # Skip this field
|
||||
title_or_description = self.field_descriptions.get(field_name, f"#### {field_name.replace('_', ' ').title()}\n\n")
|
||||
context_prompts.append(title_or_description + processed_prompt + '\n\n')
|
||||
|
||||
# 处理普通字段,同时查找 context_placeholder 的位置
|
||||
for field_name in self.field_order:
|
||||
if field_name == 'context_placeholder':
|
||||
# 在 context_placeholder 的位置插入上下文数据
|
||||
full_prompt.append(self.context_title) # 添加上下文部分的大标题
|
||||
full_prompt.extend(context_prompts) # 添加收集的上下文内容
|
||||
else:
|
||||
handler = self.field_handlers[field_name]
|
||||
processed_prompt = handler(**kwargs)
|
||||
# Check if the field should be omitted when empty
|
||||
if self.omit_if_empty_flags.get(field_name, False) and not processed_prompt and not is_pre_print:
|
||||
continue # Skip this field
|
||||
title_or_description = self.field_descriptions.get(field_name, f"### {field_name.replace('_', ' ').title()}\n\n")
|
||||
full_prompt.append(title_or_description + processed_prompt + '\n\n')
|
||||
|
||||
# 返回完整的提示,移除尾部的空行
|
||||
return ''.join(full_prompt).rstrip('\n')
|
||||
|
||||
def pre_print(self, **kwargs):
|
||||
kwargs.update({"is_pre_print": True})
|
||||
prompt = self.generate_full_prompt(**kwargs)
|
||||
|
||||
input_keys = parse_section(self.role_prompt, 'Response Output Format')
|
||||
llm_predict = "\n".join([f"**{k}:**" for k in input_keys])
|
||||
return prompt + "\n\n" + "#"*19 + "\n<<<<LLM PREDICT>>>>\n" + "#"*19 + f"\n\n{llm_predict}\n"
|
||||
|
||||
def handle_custom_data(self, **kwargs):
|
||||
return ""
|
||||
|
||||
def handle_tool_data(self, **kwargs):
|
||||
if 'previous_agent_message' not in kwargs:
|
||||
return ""
|
||||
|
||||
previous_agent_message = kwargs.get('previous_agent_message')
|
||||
tools = previous_agent_message.tools
|
||||
|
||||
if not tools:
|
||||
return ""
|
||||
|
||||
tool_strings = []
|
||||
for tool in tools:
|
||||
args_schema = str(tool.args)
|
||||
tool_strings.append(f"{tool.name}: {tool.description}, args: {args_schema}")
|
||||
formatted_tools = "\n".join(tool_strings)
|
||||
|
||||
tool_names = ", ".join([tool.name for tool in tools])
|
||||
|
||||
tool_prompt = dedent(f"""
|
||||
Below is a list of tools that are available for your use:
|
||||
{formatted_tools}
|
||||
|
||||
valid "tool_name" value is:
|
||||
{tool_names}
|
||||
""")
|
||||
|
||||
return tool_prompt
|
||||
|
||||
def handle_agent_data(self, **kwargs):
|
||||
if 'agents' not in kwargs:
|
||||
return ""
|
||||
|
||||
agents = kwargs.get('agents')
|
||||
random.shuffle(agents)
|
||||
agent_names = ", ".join([f'{agent.role.role_name}' for agent in agents])
|
||||
agent_descs = []
|
||||
for agent in agents:
|
||||
role_desc = agent.role.role_prompt.split("####")[1]
|
||||
while "\n\n" in role_desc:
|
||||
role_desc = role_desc.replace("\n\n", "\n")
|
||||
role_desc = role_desc.replace("\n", ",")
|
||||
|
||||
agent_descs.append(f'"role name: {agent.role.role_name}\nrole description: {role_desc}"')
|
||||
|
||||
agents = "\n".join(agent_descs)
|
||||
agent_prompt = f'''
|
||||
Please ensure your selection is one of the listed roles. Available roles for selection:
|
||||
{agents}
|
||||
Please ensure select the Role from agent names, such as {agent_names}'''
|
||||
|
||||
return dedent(agent_prompt)
|
||||
|
||||
def handle_doc_info(self, **kwargs) -> str:
|
||||
if 'previous_agent_message' not in kwargs:
|
||||
return ""
|
||||
previous_agent_message: Message = kwargs.get('previous_agent_message')
|
||||
db_docs = previous_agent_message.db_docs
|
||||
search_docs = previous_agent_message.search_docs
|
||||
code_cocs = previous_agent_message.code_docs
|
||||
doc_infos = "\n".join([doc.get_snippet() for doc in db_docs] + [doc.get_snippet() for doc in search_docs] +
|
||||
[doc.get_code() for doc in code_cocs])
|
||||
return doc_infos
|
||||
|
||||
def handle_session_records(self, **kwargs) -> str:
|
||||
|
||||
memory_pool: Memory = kwargs.get('memory_pool', Memory(messages=[]))
|
||||
memory_pool = self.select_memory_by_agent_name(memory_pool)
|
||||
memory_pool = self.select_memory_by_parsed_key(memory_pool)
|
||||
|
||||
return memory_pool.to_str_messages(content_key="parsed_output_list", with_tag=True)
|
||||
|
||||
def handle_current_plan(self, **kwargs) -> str:
|
||||
if 'previous_agent_message' not in kwargs:
|
||||
return ""
|
||||
previous_agent_message = kwargs['previous_agent_message']
|
||||
return previous_agent_message.parsed_output.get("CURRENT_STEP", "")
|
||||
|
||||
def handle_agent_profile(self, **kwargs) -> str:
|
||||
return extract_section(self.role_prompt, 'Agent Profile')
|
||||
|
||||
def handle_output_format(self, **kwargs) -> str:
|
||||
return extract_section(self.role_prompt, 'Response Output Format')
|
||||
|
||||
def handle_response(self, **kwargs) -> str:
|
||||
if 'react_memory' not in kwargs:
|
||||
return ""
|
||||
|
||||
react_memory = kwargs.get('react_memory', Memory(messages=[]))
|
||||
if react_memory is None:
|
||||
return ""
|
||||
|
||||
return "\n".join(["\n".join([f"**{k}:**\n{v}" for k,v in _dict.items()]) for _dict in react_memory.get_parserd_output()])
|
||||
|
||||
def handle_task_records(self, **kwargs) -> str:
|
||||
if 'task_memory' not in kwargs:
|
||||
return ""
|
||||
|
||||
task_memory: Memory = kwargs.get('task_memory', Memory(messages=[]))
|
||||
if task_memory is None:
|
||||
return ""
|
||||
|
||||
return "\n".join(["\n".join([f"**{k}:**\n{v}" for k,v in _dict.items() if k not in ["CURRENT_STEP"]]) for _dict in task_memory.get_parserd_output()])
|
||||
|
||||
def handle_previous_message(self, message: Message) -> str:
|
||||
pass
|
||||
|
||||
def handle_message_by_role_name(self, message: Message) -> str:
|
||||
pass
|
||||
|
||||
def handle_message_by_role_type(self, message: Message) -> str:
|
||||
pass
|
||||
|
||||
def handle_current_agent_react_message(self, message: Message) -> str:
|
||||
pass
|
||||
|
||||
def extract_codedoc_info_for_prompt(self, message: Message) -> str:
|
||||
code_docs = message.code_docs
|
||||
doc_infos = "\n".join([doc.get_code() for doc in code_docs])
|
||||
return doc_infos
|
||||
|
||||
def select_memory_by_parsed_key(self, memory: Memory) -> Memory:
|
||||
return Memory(
|
||||
messages=[self.select_message_by_parsed_key(message) for message in memory.messages
|
||||
if self.select_message_by_parsed_key(message) is not None]
|
||||
)
|
||||
|
||||
def select_memory_by_agent_name(self, memory: Memory) -> Memory:
|
||||
return Memory(
|
||||
messages=[self.select_message_by_agent_name(message) for message in memory.messages
|
||||
if self.select_message_by_agent_name(message) is not None]
|
||||
)
|
||||
|
||||
def select_message_by_agent_name(self, message: Message) -> Message:
|
||||
# assume we focus all agents
|
||||
if self.monitored_agents == []:
|
||||
return message
|
||||
return None if message is None or message.role_name not in self.monitored_agents else self.select_message_by_parsed_key(message)
|
||||
|
||||
def select_message_by_parsed_key(self, message: Message) -> Message:
|
||||
# assume we focus all key contents
|
||||
if message is None:
|
||||
return message
|
||||
|
||||
if self.monitored_fields == []:
|
||||
return message
|
||||
|
||||
message_c = copy.deepcopy(message)
|
||||
message_c.parsed_output = {k: v for k,v in message_c.parsed_output.items() if k in self.monitored_fields}
|
||||
message_c.parsed_output_list = [{k: v for k,v in parsed_output.items() if k in self.monitored_fields} for parsed_output in message_c.parsed_output_list]
|
||||
return message_c
|
||||
|
||||
def get_memory(self, content_key="role_content"):
|
||||
return self.memory.to_tuple_messages(content_key="step_content")
|
||||
|
||||
def get_memory_str(self, content_key="role_content"):
|
||||
return "\n".join([": ".join(i) for i in self.memory.to_tuple_messages(content_key="step_content")])
|
||||
|
||||
def register_fields_from_config(self):
|
||||
|
||||
for prompt_field in self.prompt_config:
|
||||
|
||||
function_name = prompt_field.function_name
|
||||
# 检查function_name是否是self的一个方法
|
||||
if function_name and hasattr(self, function_name):
|
||||
function = getattr(self, function_name)
|
||||
else:
|
||||
function = self.handle_custom_data
|
||||
|
||||
self.register_field(prompt_field.field_name,
|
||||
function=function,
|
||||
title=prompt_field.title,
|
||||
description=prompt_field.description,
|
||||
is_context=prompt_field.is_context,
|
||||
omit_if_empty=prompt_field.omit_if_empty)
|
||||
|
||||
def register_standard_fields(self):
|
||||
self.register_field('agent_profile', function=self.handle_agent_profile, is_context=False)
|
||||
self.register_field('tool_information', function=self.handle_tool_data, is_context=False)
|
||||
self.register_field('context_placeholder', is_context=True) # 用于标记上下文数据部分的位置
|
||||
self.register_field('reference_documents', function=self.handle_doc_info, is_context=True)
|
||||
self.register_field('session_records', function=self.handle_session_records, is_context=True)
|
||||
self.register_field('output_format', function=self.handle_output_format, title='Response Output Format', is_context=False)
|
||||
self.register_field('response', function=self.handle_response, is_context=False, omit_if_empty=False)
|
||||
|
||||
def register_executor_fields(self):
|
||||
self.register_field('agent_profile', function=self.handle_agent_profile, is_context=False)
|
||||
self.register_field('tool_information', function=self.handle_tool_data, is_context=False)
|
||||
self.register_field('context_placeholder', is_context=True) # 用于标记上下文数据部分的位置
|
||||
self.register_field('reference_documents', function=self.handle_doc_info, is_context=True)
|
||||
self.register_field('session_records', function=self.handle_session_records, is_context=True)
|
||||
self.register_field('current_plan', function=self.handle_current_plan, is_context=True)
|
||||
self.register_field('output_format', function=self.handle_output_format, title='Response Output Format', is_context=False)
|
||||
self.register_field('response', function=self.handle_response, is_context=False, omit_if_empty=False)
|
||||
|
||||
def register_fields_from_dict(self, fields_dict):
|
||||
# 使用字典注册字段的函数
|
||||
for field_name, field_config in fields_dict.items():
|
||||
function_name = field_config.get('function', None)
|
||||
title = field_config.get('title', None)
|
||||
description = field_config.get('description', None)
|
||||
is_context = field_config.get('is_context', True)
|
||||
omit_if_empty = field_config.get('omit_if_empty', True)
|
||||
|
||||
# 检查function_name是否是self的一个方法
|
||||
if function_name and hasattr(self, function_name):
|
||||
function = getattr(self, function_name)
|
||||
else:
|
||||
function = self.handle_custom_data
|
||||
|
||||
# 调用已存在的register_field方法注册字段
|
||||
self.register_field(field_name, function=function, title=title, description=description, is_context=is_context, omit_if_empty=omit_if_empty)
|
||||
|
||||
|
||||
|
||||
def main():
|
||||
manager = PromptManager()
|
||||
manager.register_standard_fields()
|
||||
|
||||
manager.register_field('agents_work_progress', title=f"Agents' Work Progress", is_context=True)
|
||||
|
||||
# 创建数据字典
|
||||
data_dict = {
|
||||
"agent_profile": "这是代理配置文件...",
|
||||
# "tool_list": "这是工具列表...",
|
||||
"reference_documents": "这是参考文档...",
|
||||
"session_records": "这是会话记录...",
|
||||
"agents_work_progress": "这是代理工作进展...",
|
||||
"output_format": "这是预期的输出格式...",
|
||||
# "response": "这是生成或继续回应的指令...",
|
||||
"response": "",
|
||||
"test": 'xxxxx'
|
||||
}
|
||||
|
||||
# 组合完整的提示
|
||||
full_prompt = manager.generate_full_prompt(data_dict)
|
||||
print(full_prompt)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
|
@ -3,7 +3,7 @@ from .general_schema import *
|
|||
from .message import Message
|
||||
|
||||
__all__ = [
|
||||
"Memory", "ActionStatus", "Doc", "CodeDoc", "Task",
|
||||
"Memory", "ActionStatus", "Doc", "CodeDoc", "Task", "LogVerboseEnum",
|
||||
"Env", "Role", "ChainConfig", "AgentConfig", "PhaseConfig", "Message",
|
||||
"load_role_configs", "load_chain_configs", "load_phase_configs"
|
||||
]
|
|
@ -1,5 +1,5 @@
|
|||
from pydantic import BaseModel
|
||||
from typing import List, Dict
|
||||
from typing import List, Dict, Optional, Union
|
||||
from enum import Enum
|
||||
import re
|
||||
import json
|
||||
|
@ -11,7 +11,7 @@ class ActionStatus(Enum):
|
|||
DEFAUILT = "default"
|
||||
|
||||
FINISHED = "finished"
|
||||
STOPED = "stoped"
|
||||
STOPPED = "stopped"
|
||||
CONTINUED = "continued"
|
||||
|
||||
TOOL_USING = "tool_using"
|
||||
|
@ -38,8 +38,8 @@ class FinishedAction(Action):
|
|||
action_name: str = ActionStatus.FINISHED
|
||||
description: str = "provide the final answer to the original query to break the chain answer"
|
||||
|
||||
class StopedAction(Action):
|
||||
action_name: str = ActionStatus.STOPED
|
||||
class StoppedAction(Action):
|
||||
action_name: str = ActionStatus.STOPPED
|
||||
description: str = "provide the final answer to the original query to break the agent answer"
|
||||
|
||||
class ContinuedAction(Action):
|
||||
|
@ -86,6 +86,7 @@ class RoleTypeEnums(Enum):
|
|||
ASSISTANT = "assistant"
|
||||
FUNCTION = "function"
|
||||
OBSERVATION = "observation"
|
||||
SUMMARY = "summary"
|
||||
|
||||
def __eq__(self, other):
|
||||
if isinstance(other, str):
|
||||
|
@ -118,7 +119,7 @@ class PromptKeyEnums(Enum):
|
|||
def __eq__(self, other):
|
||||
if isinstance(other, str):
|
||||
return self.value == other
|
||||
return super().__eq__(other)
|
||||
return super().__eq__(other)
|
||||
|
||||
|
||||
class Doc(BaseModel):
|
||||
|
@ -167,15 +168,43 @@ class CodeDoc(BaseModel):
|
|||
return f"""出处 [{self.index + 1}] \n\n来源 ({self.related_nodes}) \n\n内容 {self.code}\n\n"""
|
||||
|
||||
|
||||
class LogVerboseEnum(Enum):
|
||||
Log0Level = "0" # don't print log
|
||||
Log1Level = "1" # print level-1 log
|
||||
Log2Level = "2" # print level-2 log
|
||||
Log3Level = "3" # print level-3 log
|
||||
|
||||
def __eq__(self, other):
|
||||
if isinstance(other, str):
|
||||
return self.value.lower() == other.lower()
|
||||
if isinstance(other, LogVerboseEnum):
|
||||
return self.value == other.value
|
||||
return False
|
||||
|
||||
def __ge__(self, other):
|
||||
if isinstance(other, LogVerboseEnum):
|
||||
return int(self.value) >= int(other.value)
|
||||
if isinstance(other, str):
|
||||
return int(self.value) >= int(other)
|
||||
return NotImplemented
|
||||
|
||||
def __le__(self, other):
|
||||
if isinstance(other, LogVerboseEnum):
|
||||
return int(self.value) <= int(other.value)
|
||||
if isinstance(other, str):
|
||||
return int(self.value) <= int(other)
|
||||
return NotImplemented
|
||||
|
||||
@classmethod
|
||||
def ge(self, enum_value: 'LogVerboseEnum', other: Union[str, 'LogVerboseEnum']):
|
||||
return enum_value <= other
|
||||
|
||||
|
||||
class Task(BaseModel):
|
||||
task_type: str
|
||||
task_name: str
|
||||
task_desc: str
|
||||
task_prompt: str
|
||||
# def __init__(self, task_type, task_name, task_desc) -> None:
|
||||
# self.task_type = task_type
|
||||
# self.task_name = task_name
|
||||
# self.task_desc = task_desc
|
||||
|
||||
class Env(BaseModel):
|
||||
env_type: str
|
||||
|
@ -192,30 +221,32 @@ class Role(BaseModel):
|
|||
template_prompt: str = ""
|
||||
|
||||
|
||||
|
||||
class ChainConfig(BaseModel):
|
||||
chain_name: str
|
||||
chain_type: str
|
||||
agents: List[str]
|
||||
do_checker: bool = False
|
||||
chat_turn: int = 1
|
||||
clear_structure: bool = False
|
||||
brainstorming: bool = False
|
||||
gui_design: bool = True
|
||||
git_management: bool = False
|
||||
self_improve: bool = False
|
||||
|
||||
|
||||
class PromptField(BaseModel):
|
||||
field_name: str # 假设这是一个函数类型,您可以根据需要更改
|
||||
function_name: str
|
||||
title: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
is_context: Optional[bool] = True
|
||||
omit_if_empty: Optional[bool] = True
|
||||
|
||||
|
||||
class AgentConfig(BaseModel):
|
||||
role: Role
|
||||
stop: str = None
|
||||
prompt_config: List[PromptField]
|
||||
prompt_manager_type: str = "PromptManager"
|
||||
chat_turn: int = 1
|
||||
do_search: bool = False
|
||||
do_doc_retrieval: bool = False
|
||||
do_tool_retrieval: bool = False
|
||||
focus_agents: List = []
|
||||
focus_message_keys: List = []
|
||||
group_agents: List = []
|
||||
stop: str = ""
|
||||
|
||||
|
||||
class PhaseConfig(BaseModel):
|
||||
|
@ -229,13 +260,32 @@ class PhaseConfig(BaseModel):
|
|||
do_tool_retrieval: bool = False
|
||||
|
||||
|
||||
class CompleteChainConfig(BaseModel):
|
||||
chain_name: str
|
||||
chain_type: str
|
||||
agents: Dict[str, AgentConfig]
|
||||
do_checker: bool = False
|
||||
chat_turn: int = 1
|
||||
|
||||
|
||||
class CompletePhaseConfig(BaseModel):
|
||||
phase_name: str
|
||||
phase_type: str
|
||||
chains: Dict[str, CompleteChainConfig]
|
||||
do_summary: bool = False
|
||||
do_search: bool = False
|
||||
do_doc_retrieval: bool = False
|
||||
do_code_retrieval: bool = False
|
||||
do_tool_retrieval: bool = False
|
||||
|
||||
|
||||
def load_role_configs(config) -> Dict[str, AgentConfig]:
|
||||
if isinstance(config, str):
|
||||
with open(config, 'r', encoding="utf8") as file:
|
||||
configs = json.load(file)
|
||||
else:
|
||||
configs = config
|
||||
|
||||
# logger.debug(configs)
|
||||
return {name: AgentConfig(**v) for name, v in configs.items()}
|
||||
|
||||
|
||||
|
@ -254,4 +304,6 @@ def load_phase_configs(config) -> Dict[str, PhaseConfig]:
|
|||
configs = json.load(file)
|
||||
else:
|
||||
configs = config
|
||||
return {name: PhaseConfig(**v) for name, v in configs.items()}
|
||||
return {name: PhaseConfig(**v) for name, v in configs.items()}
|
||||
|
||||
# AgentConfig.update_forward_refs()
|
|
@ -0,0 +1,158 @@
|
|||
from pydantic import BaseModel
|
||||
from typing import List, Union, Dict
|
||||
from loguru import logger
|
||||
|
||||
from .message import Message
|
||||
from coagent.utils.common_utils import (
|
||||
save_to_jsonl_file, save_to_json_file, read_json_file, read_jsonl_file
|
||||
)
|
||||
|
||||
|
||||
class Memory(BaseModel):
|
||||
messages: List[Message] = []
|
||||
|
||||
# def __init__(self, messages: List[Message] = []):
|
||||
# self.messages = messages
|
||||
|
||||
def append(self, message: Message):
|
||||
self.messages.append(message)
|
||||
|
||||
def extend(self, memory: 'Memory'):
|
||||
self.messages.extend(memory.messages)
|
||||
|
||||
def update(self, role_name: str, role_type: str, role_content: str):
|
||||
self.messages.append(Message(role_name, role_type, role_content, role_content))
|
||||
|
||||
def clear(self, ):
|
||||
self.messages = []
|
||||
|
||||
def delete(self, ):
|
||||
pass
|
||||
|
||||
def get_messages(self, k=0) -> List[Message]:
|
||||
"""Return the most recent k memories, return all when k=0"""
|
||||
return self.messages[-k:]
|
||||
|
||||
def split_by_role_type(self) -> List[Dict[str, 'Memory']]:
|
||||
"""
|
||||
Split messages into rounds of conversation based on role_type.
|
||||
Each round consists of consecutive messages of the same role_type.
|
||||
User messages form a single round, while assistant and function messages are combined into a single round.
|
||||
Each round is represented by a dict with 'role' and 'memory' keys, with assistant and function messages
|
||||
labeled as 'assistant'.
|
||||
"""
|
||||
rounds = []
|
||||
current_memory = Memory()
|
||||
current_role = None
|
||||
|
||||
logger.debug(len(self.messages))
|
||||
|
||||
for msg in self.messages:
|
||||
# Determine the message's role, considering 'function' as 'assistant'
|
||||
message_role = 'assistant' if msg.role_type in ['assistant', 'function'] else 'user'
|
||||
|
||||
# If the current memory is empty or the current message is of the same role_type as current_role, add to current memory
|
||||
if not current_memory.messages or current_role == message_role:
|
||||
current_memory.append(msg)
|
||||
else:
|
||||
# Finish the current memory and start a new one
|
||||
rounds.append({'role': current_role, 'memory': current_memory})
|
||||
current_memory = Memory()
|
||||
current_memory.append(msg)
|
||||
|
||||
# Update the current_role, considering 'function' as 'assistant'
|
||||
current_role = message_role
|
||||
|
||||
# Don't forget to add the last memory if it exists
|
||||
if current_memory.messages:
|
||||
rounds.append({'role': current_role, 'memory': current_memory})
|
||||
|
||||
logger.debug(rounds)
|
||||
|
||||
return rounds
|
||||
|
||||
def format_rounds_to_html(self) -> str:
|
||||
formatted_html_str = ""
|
||||
rounds = self.split_by_role_type()
|
||||
|
||||
for round in rounds:
|
||||
role = round['role']
|
||||
memory = round['memory']
|
||||
|
||||
# 转换当前round的Memory为字符串
|
||||
messages_str = memory.to_str_messages()
|
||||
|
||||
# 根据角色类型添加相应的HTML标签
|
||||
if role == 'user':
|
||||
formatted_html_str += f"<user-message>\n{messages_str}\n</user-message>\n"
|
||||
else: # 对于'assistant'和'function'角色,我们将其视为'assistant'
|
||||
formatted_html_str += f"<assistant-message>\n{messages_str}\n</assistant-message>\n"
|
||||
|
||||
return formatted_html_str
|
||||
|
||||
|
||||
def filter_by_role_type(self, role_types: List[str]) -> List[Message]:
|
||||
# Filter messages based on role types
|
||||
return [message for message in self.messages if message.role_type not in role_types]
|
||||
|
||||
def select_by_role_type(self, role_types: List[str]) -> List[Message]:
|
||||
# Select messages based on role types
|
||||
return [message for message in self.messages if message.role_type in role_types]
|
||||
|
||||
def to_tuple_messages(self, return_all: bool = True, content_key="role_content", filter_roles=[]):
|
||||
# Convert messages to tuples based on parameters
|
||||
# logger.debug(f"{[message.to_tuple_message(return_all, content_key) for message in self.messages ]}")
|
||||
return [
|
||||
message.to_tuple_message(return_all, content_key) for message in self.messages
|
||||
if message.role_name not in filter_roles
|
||||
]
|
||||
|
||||
def to_dict_messages(self, filter_roles=[]):
|
||||
# Convert messages to dictionaries based on filter roles
|
||||
return [
|
||||
message.to_dict_message() for message in self.messages
|
||||
if message.role_name not in filter_roles
|
||||
]
|
||||
|
||||
def to_str_messages(self, return_all: bool = True, content_key="role_content", filter_roles=[], with_tag=False):
|
||||
# Convert messages to strings based on parameters
|
||||
# for message in self.messages:
|
||||
# logger.debug(f"{message.role_name}: {message.to_str_content(return_all, content_key, with_tag=with_tag)}")
|
||||
# logger.debug(f"{[message.to_tuple_message(return_all, content_key) for message in self.messages ]}")
|
||||
return "\n\n".join([message.to_str_content(return_all, content_key, with_tag=with_tag) for message in self.messages
|
||||
if message.role_name not in filter_roles
|
||||
])
|
||||
|
||||
def get_parserd_output(self, ):
|
||||
return [message.parsed_output for message in self.messages]
|
||||
|
||||
def get_parserd_output_list(self, ):
|
||||
# for message in self.messages:
|
||||
# logger.debug(f"{message.role_name}: {message.parsed_output_list}")
|
||||
# return [parsed_output for message in self.messages for parsed_output in message.parsed_output_list[1:]]
|
||||
return [parsed_output for message in self.messages for parsed_output in message.parsed_output_list]
|
||||
|
||||
def get_rolenames(self, ):
|
||||
''''''
|
||||
return [message.role_name for message in self.messages]
|
||||
|
||||
@classmethod
|
||||
def from_memory_list(cls, memorys: List['Memory']) -> 'Memory':
|
||||
return cls(messages=[message for memory in memorys for message in memory.get_messages()])
|
||||
|
||||
def __len__(self, ):
|
||||
return len(self.messages)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return "\n".join([":".join(i) for i in self.to_tuple_messages()])
|
||||
|
||||
def __add__(self, other: Union[Message, 'Memory']) -> 'Memory':
|
||||
if isinstance(other, Message):
|
||||
return Memory(messages=self.messages + [other])
|
||||
elif isinstance(other, Memory):
|
||||
return Memory(messages=self.messages + other.messages)
|
||||
else:
|
||||
raise ValueError(f"cant add unspecified type like as {type(other)}")
|
||||
|
||||
|
||||
|
|
@ -1,6 +1,7 @@
|
|||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, root_validator
|
||||
from loguru import logger
|
||||
|
||||
from coagent.utils.common_utils import getCurrentDatetime
|
||||
from .general_schema import *
|
||||
|
||||
|
||||
|
@ -11,6 +12,7 @@ class Message(BaseModel):
|
|||
role_prompt: str = None
|
||||
input_query: str = None
|
||||
origin_query: str = None
|
||||
datetime: str = getCurrentDatetime()
|
||||
|
||||
# llm output
|
||||
role_content: str = None
|
||||
|
@ -27,7 +29,7 @@ class Message(BaseModel):
|
|||
parsed_output_list: List[Dict] = []
|
||||
|
||||
# llm\tool\code executre information
|
||||
action_status: str = ActionStatus.DEFAUILT
|
||||
action_status: str = "default"
|
||||
agent_index: int = None
|
||||
code_answer: str = None
|
||||
tool_answer: str = None
|
||||
|
@ -59,6 +61,27 @@ class Message(BaseModel):
|
|||
# user's customed kargs for init or end action
|
||||
customed_kargs: dict = {}
|
||||
|
||||
|
||||
@root_validator(pre=True)
|
||||
def check_card_number_omitted(cls, values):
|
||||
input_query = values.get("input_query")
|
||||
origin_query = values.get("origin_query")
|
||||
role_content = values.get("role_content")
|
||||
if input_query is None:
|
||||
values["input_query"] = origin_query or role_content
|
||||
if role_content is None:
|
||||
values["role_content"] = origin_query
|
||||
return values
|
||||
|
||||
# pydantic>=2.0
|
||||
# @model_validator(mode='after')
|
||||
# def check_passwords_match(self) -> 'Message':
|
||||
# if self.input_query is None:
|
||||
# self.input_query = self.origin_query or self.role_content
|
||||
# if self.role_content is None:
|
||||
# self.role_content = self.origin_query
|
||||
# return self
|
||||
|
||||
def to_tuple_message(self, return_all: bool = True, content_key="role_content"):
|
||||
role_content = self.to_str_content(False, content_key)
|
||||
if return_all:
|
||||
|
@ -66,29 +89,28 @@ class Message(BaseModel):
|
|||
else:
|
||||
return (role_content)
|
||||
|
||||
def to_dict_message(self, return_all: bool = True, content_key="role_content"):
|
||||
role_content = self.to_str_content(False, content_key)
|
||||
if return_all:
|
||||
return {"role": self.role_name, "content": role_content}
|
||||
else:
|
||||
return vars(self)
|
||||
def to_dict_message(self, ):
|
||||
return vars(self)
|
||||
|
||||
def to_str_content(self, return_all: bool = True, content_key="role_content"):
|
||||
def to_str_content(self, return_all: bool = True, content_key="role_content", with_tag=False):
|
||||
if content_key == "role_content":
|
||||
role_content = self.role_content or self.input_query
|
||||
elif content_key == "step_content":
|
||||
role_content = self.step_content or self.role_content or self.input_query
|
||||
elif content_key == "parsed_output":
|
||||
role_content = "\n".join([f"**{k}:** {v}" for k, v in self.parsed_output.items()])
|
||||
elif content_key == "parsed_output_list":
|
||||
role_content = "\n".join([f"**{k}:** {v}" for po in self.parsed_output_list for k,v in po.items()])
|
||||
else:
|
||||
role_content = self.role_content or self.input_query
|
||||
|
||||
if return_all:
|
||||
return f"{self.role_name}: {role_content}"
|
||||
if with_tag:
|
||||
start_tag = f"<{self.role_type}-{self.role_name}-message>"
|
||||
end_tag = f"</{self.role_type}-{self.role_name}-message>"
|
||||
return f"{start_tag}\n{role_content}\n{end_tag}"
|
||||
else:
|
||||
return role_content
|
||||
|
||||
def is_system_role(self,):
|
||||
return self.role_type == "system"
|
||||
|
||||
def __str__(self) -> str:
|
||||
# key_str = '\n'.join([k for k, v in vars(self).items()])
|
||||
# logger.debug(f"{key_str}")
|
|
@ -0,0 +1,117 @@
|
|||
import re, copy, json
|
||||
from loguru import logger
|
||||
|
||||
|
||||
def extract_section(text, section_name):
|
||||
# Define a pattern to extract the named section along with its content
|
||||
section_pattern = rf'#### {section_name}\n(.*?)(?=####|$)'
|
||||
|
||||
# Find the specific section content
|
||||
section_content = re.search(section_pattern, text, re.DOTALL)
|
||||
|
||||
if section_content:
|
||||
# If the section is found, extract the content and strip the leading/trailing whitespace
|
||||
# This will also remove leading/trailing newlines
|
||||
content = section_content.group(1).strip()
|
||||
|
||||
# Return the cleaned content
|
||||
return content
|
||||
else:
|
||||
# If the section is not found, return an empty string
|
||||
return ""
|
||||
|
||||
|
||||
def parse_section(text, section_name):
|
||||
# Define a pattern to extract the named section along with its content
|
||||
section_pattern = rf'#### {section_name}\n(.*?)(?=####|$)'
|
||||
|
||||
# Find the specific section content
|
||||
section_content = re.search(section_pattern, text, re.DOTALL)
|
||||
|
||||
if section_content:
|
||||
# If the section is found, extract the content
|
||||
content = section_content.group(1)
|
||||
|
||||
# Define a pattern to find segments that follow the format **xx:**
|
||||
segments_pattern = r'\*\*([^*]+):\*\*'
|
||||
|
||||
# Use findall method to extract all matches in the section content
|
||||
segments = re.findall(segments_pattern, content)
|
||||
|
||||
return segments
|
||||
else:
|
||||
# If the section is not found, return an empty list
|
||||
return []
|
||||
|
||||
|
||||
def parse_text_to_dict(text):
|
||||
# Define a regular expression pattern to capture the key and value
|
||||
main_pattern = r"\*\*(.+?):\*\*\s*(.*?)\s*(?=\*\*|$)"
|
||||
list_pattern = r'```python\n(.*?)```'
|
||||
plan_pattern = r'\[\s*.*?\s*\]'
|
||||
|
||||
# Use re.findall to find all main matches in the text
|
||||
main_matches = re.findall(main_pattern, text, re.DOTALL)
|
||||
|
||||
# Convert main matches to a dictionary
|
||||
parsed_dict = {key.strip(): value.strip() for key, value in main_matches}
|
||||
|
||||
for k, v in parsed_dict.items():
|
||||
for pattern in [list_pattern, plan_pattern]:
|
||||
if "PLAN" != k: continue
|
||||
v = v.replace("```list", "```python")
|
||||
match_value = re.search(pattern, v, re.DOTALL)
|
||||
if match_value:
|
||||
# Add the code block to the dictionary
|
||||
parsed_dict[k] = eval(match_value.group(1).strip())
|
||||
break
|
||||
|
||||
return parsed_dict
|
||||
|
||||
|
||||
def parse_dict_to_dict(parsed_dict) -> dict:
|
||||
code_pattern = r'```python\n(.*?)```'
|
||||
tool_pattern = r'```json\n(.*?)```'
|
||||
|
||||
pattern_dict = {"code": code_pattern, "json": tool_pattern}
|
||||
spec_parsed_dict = copy.deepcopy(parsed_dict)
|
||||
for key, pattern in pattern_dict.items():
|
||||
for k, text in parsed_dict.items():
|
||||
# Search for the code block
|
||||
if not isinstance(text, str): continue
|
||||
_match = re.search(pattern, text, re.DOTALL)
|
||||
if _match:
|
||||
# Add the code block to the dictionary
|
||||
try:
|
||||
spec_parsed_dict[key] = json.loads(_match.group(1).strip())
|
||||
except:
|
||||
spec_parsed_dict[key] = _match.group(1).strip()
|
||||
break
|
||||
return spec_parsed_dict
|
||||
|
||||
|
||||
def prompt_cost(model_type: str, num_prompt_tokens: float, num_completion_tokens: float):
|
||||
input_cost_map = {
|
||||
"gpt-3.5-turbo": 0.0015,
|
||||
"gpt-3.5-turbo-16k": 0.003,
|
||||
"gpt-3.5-turbo-0613": 0.0015,
|
||||
"gpt-3.5-turbo-16k-0613": 0.003,
|
||||
"gpt-4": 0.03,
|
||||
"gpt-4-0613": 0.03,
|
||||
"gpt-4-32k": 0.06,
|
||||
}
|
||||
|
||||
output_cost_map = {
|
||||
"gpt-3.5-turbo": 0.002,
|
||||
"gpt-3.5-turbo-16k": 0.004,
|
||||
"gpt-3.5-turbo-0613": 0.002,
|
||||
"gpt-3.5-turbo-16k-0613": 0.004,
|
||||
"gpt-4": 0.06,
|
||||
"gpt-4-0613": 0.06,
|
||||
"gpt-4-32k": 0.12,
|
||||
}
|
||||
|
||||
if model_type not in input_cost_map or model_type not in output_cost_map:
|
||||
return -1
|
||||
|
||||
return num_prompt_tokens * input_cost_map[model_type] / 1000.0 + num_completion_tokens * output_cost_map[model_type] / 1000.0
|
|
@ -224,6 +224,7 @@ class NebulaHandler:
|
|||
res = {'vertices': -1, 'edges': -1}
|
||||
|
||||
stats_res_dict = self.result_to_dict(stats_res)
|
||||
logger.info(stats_res_dict)
|
||||
for idx in range(len(stats_res_dict['Type'])):
|
||||
t = stats_res_dict['Type'][idx].as_string()
|
||||
name = stats_res_dict['Name'][idx].as_string()
|
||||
|
@ -264,7 +265,3 @@ class NebulaHandler:
|
|||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
|
@ -6,7 +6,7 @@ from langchain.docstore.document import Document
|
|||
from langchain.document_loaders.base import BaseLoader
|
||||
from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter
|
||||
|
||||
from dev_opsgpt.utils.common_utils import read_json_file
|
||||
from coagent.utils.common_utils import read_json_file
|
||||
|
||||
|
||||
class JSONLoader(BaseLoader):
|
|
@ -6,7 +6,7 @@ from langchain.docstore.document import Document
|
|||
from langchain.document_loaders.base import BaseLoader
|
||||
from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter
|
||||
|
||||
from dev_opsgpt.utils.common_utils import read_jsonl_file
|
||||
from coagent.utils.common_utils import read_jsonl_file
|
||||
|
||||
|
||||
class JSONLLoader(BaseLoader):
|
|
@ -602,7 +602,6 @@ class FAISS(VectorStore):
|
|||
faiss = FAISS.from_texts(texts, embeddings)
|
||||
"""
|
||||
from loguru import logger
|
||||
logger.debug(f"texts: {len(texts)}")
|
||||
embeddings = embedding.embed_documents(texts)
|
||||
return cls.__from(
|
||||
texts,
|
|
@ -7,12 +7,17 @@
|
|||
'''
|
||||
from loguru import logger
|
||||
|
||||
from configs.model_config import EMBEDDING_MODEL
|
||||
from dev_opsgpt.embeddings.openai_embedding import OpenAIEmbedding
|
||||
from dev_opsgpt.embeddings.huggingface_embedding import HFEmbedding
|
||||
# from configs.model_config import EMBEDDING_MODEL
|
||||
from coagent.embeddings.openai_embedding import OpenAIEmbedding
|
||||
from coagent.embeddings.huggingface_embedding import HFEmbedding
|
||||
|
||||
|
||||
def get_embedding(engine: str, text_list: list):
|
||||
def get_embedding(
|
||||
engine: str,
|
||||
text_list: list,
|
||||
model_path: str = "text2vec-base-chinese",
|
||||
embedding_device: str = "cpu",
|
||||
):
|
||||
'''
|
||||
get embedding
|
||||
@param engine: openai / hf
|
||||
|
@ -25,7 +30,7 @@ def get_embedding(engine: str, text_list: list):
|
|||
oae = OpenAIEmbedding()
|
||||
emb_res = oae.get_emb(text_list)
|
||||
elif engine == 'model':
|
||||
hfe = HFEmbedding(EMBEDDING_MODEL)
|
||||
hfe = HFEmbedding(model_path, embedding_device)
|
||||
emb_res = hfe.get_emb(text_list)
|
||||
|
||||
return emb_res
|
|
@ -6,8 +6,9 @@
|
|||
@desc:
|
||||
'''
|
||||
from loguru import logger
|
||||
from configs.model_config import EMBEDDING_DEVICE
|
||||
from dev_opsgpt.embeddings.utils import load_embeddings
|
||||
# from configs.model_config import EMBEDDING_DEVICE
|
||||
# from configs.model_config import embedding_model_dict
|
||||
from coagent.embeddings.utils import load_embeddings, load_embeddings_from_path
|
||||
|
||||
|
||||
class HFEmbedding:
|
||||
|
@ -22,8 +23,8 @@ class HFEmbedding:
|
|||
cls._instance[instance_key] = super().__new__(cls)
|
||||
return cls._instance[instance_key]
|
||||
|
||||
def __init__(self, model_name):
|
||||
self.model = load_embeddings(model=model_name, device=EMBEDDING_DEVICE)
|
||||
def __init__(self, model_name, embedding_device):
|
||||
self.model = load_embeddings_from_path(model_path=model_name, device=embedding_device)
|
||||
logger.debug('load success')
|
||||
|
||||
def get_emb(self, text_list):
|
||||
|
@ -32,9 +33,7 @@ class HFEmbedding:
|
|||
@param text_list:
|
||||
@return:
|
||||
'''
|
||||
logger.info('st')
|
||||
emb_res = self.model.embed_documents(text_list)
|
||||
logger.info('ed')
|
||||
res = {
|
||||
text_list[idx]: emb_res[idx] for idx in range(len(text_list))
|
||||
}
|
|
@ -0,0 +1,20 @@
|
|||
import os
|
||||
from functools import lru_cache
|
||||
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
|
||||
# from configs.model_config import embedding_model_dict
|
||||
from loguru import logger
|
||||
|
||||
|
||||
@lru_cache(1)
|
||||
def load_embeddings(model: str, device: str, embedding_model_dict: dict):
|
||||
embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[model],
|
||||
model_kwargs={'device': device})
|
||||
return embeddings
|
||||
|
||||
|
||||
@lru_cache(1)
|
||||
def load_embeddings_from_path(model_path: str, device: str):
|
||||
embeddings = HuggingFaceEmbeddings(model_name=model_path,
|
||||
model_kwargs={'device': device})
|
||||
return embeddings
|
||||
|
|
@ -0,0 +1,8 @@
|
|||
from .openai_model import getChatModel, getExtraModel, getChatModelFromConfig
|
||||
from .llm_config import LLMConfig, EmbedConfig
|
||||
|
||||
|
||||
__all__ = [
|
||||
"getChatModel", "getExtraModel", "getChatModelFromConfig",
|
||||
"LLMConfig", "EmbedConfig"
|
||||
]
|
|
@ -0,0 +1,61 @@
|
|||
from dataclasses import dataclass
|
||||
from typing import List, Union
|
||||
|
||||
|
||||
|
||||
@dataclass
|
||||
class LLMConfig:
|
||||
def __init__(
|
||||
self,
|
||||
model_name: str = "gpt-3.5-turbo",
|
||||
temperature: float = 0.25,
|
||||
stop: Union[List[str], str] = None,
|
||||
api_key: str = "",
|
||||
api_base_url: str = "",
|
||||
model_device: str = "cpu",
|
||||
**kwargs
|
||||
):
|
||||
|
||||
self.model_name: str = model_name
|
||||
self.temperature: float = temperature
|
||||
self.stop: Union[List[str], str] = stop
|
||||
self.api_key: str = api_key
|
||||
self.api_base_url: str = api_base_url
|
||||
self.model_device: str = model_device
|
||||
#
|
||||
self.check_config()
|
||||
|
||||
def check_config(self, ):
|
||||
pass
|
||||
|
||||
def __str__(self):
|
||||
return ', '.join(f"{k}: {v}" for k,v in vars(self).items())
|
||||
|
||||
|
||||
@dataclass
|
||||
class EmbedConfig:
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str = "",
|
||||
api_base_url: str = "",
|
||||
embed_model: str = "",
|
||||
embed_model_path: str = "",
|
||||
embed_engine: str = "",
|
||||
model_device: str = "cpu",
|
||||
**kwargs
|
||||
):
|
||||
self.embed_model: str = embed_model
|
||||
self.embed_model_path: str = embed_model_path
|
||||
self.embed_engine: str = embed_engine
|
||||
self.model_device: str = model_device
|
||||
self.api_key: str = api_key
|
||||
self.api_base_url: str = api_base_url
|
||||
#
|
||||
self.check_config()
|
||||
|
||||
def check_config(self, ):
|
||||
pass
|
||||
|
||||
def __str__(self):
|
||||
return ', '.join(f"{k}: {v}" for k,v in vars(self).items())
|
||||
|
|
@ -1,7 +1,10 @@
|
|||
import os
|
||||
|
||||
from langchain.callbacks import AsyncIteratorCallbackHandler
|
||||
from langchain.chat_models import ChatOpenAI
|
||||
|
||||
from configs.model_config import (llm_model_dict, LLM_MODEL)
|
||||
from .llm_config import LLMConfig
|
||||
# from configs.model_config import (llm_model_dict, LLM_MODEL)
|
||||
|
||||
|
||||
def getChatModel(callBack: AsyncIteratorCallbackHandler = None, temperature=0.3, stop=None):
|
||||
|
@ -28,6 +31,33 @@ def getChatModel(callBack: AsyncIteratorCallbackHandler = None, temperature=0.3,
|
|||
)
|
||||
return model
|
||||
|
||||
|
||||
def getChatModelFromConfig(llm_config: LLMConfig, callBack: AsyncIteratorCallbackHandler = None, ):
|
||||
if callBack is None:
|
||||
model = ChatOpenAI(
|
||||
streaming=True,
|
||||
verbose=True,
|
||||
openai_api_key=llm_config.api_key,
|
||||
openai_api_base=llm_config.api_base_url,
|
||||
model_name=llm_config.model_name,
|
||||
temperature=llm_config.temperature,
|
||||
stop=llm_config.stop
|
||||
)
|
||||
else:
|
||||
model = ChatOpenAI(
|
||||
streaming=True,
|
||||
verbose=True,
|
||||
callBack=[callBack],
|
||||
openai_api_key=llm_config.api_key,
|
||||
openai_api_base=llm_config.api_base_url,
|
||||
model_name=llm_config.model_name,
|
||||
temperature=llm_config.temperature,
|
||||
stop=llm_config.stop
|
||||
)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
import json, requests
|
||||
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue