[feature](coagent)<mv coagent to ~/CodeFuse-muAgent project>
This commit is contained in:
parent eee0e09ee1
commit 9b419c2dde
@@ -1,6 +1,7 @@
**/__pycache__
knowledge_base
logs
llm_models
embedding_models
jupyter_work
model_config.py
README.md (12 changed lines)
@@ -14,11 +14,11 @@
<br><br>
</p>

DevOps-ChatBot是由蚂蚁CodeFuse团队开发的开源AI智能助手,致力于简化和优化软件开发生命周期中的各个环节。该项目结合了Multi-Agent的协同调度机制,并集成了丰富的工具库、代码库、知识库和沙盒环境,使得LLM模型能够在DevOps领域内有效执行和处理复杂任务。
CodeFuse-ChatBot是由蚂蚁CodeFuse团队开发的开源AI智能助手,致力于简化和优化软件开发生命周期中的各个环节。该项目结合了Multi-Agent的协同调度机制,并集成了丰富的工具库、代码库、知识库和沙盒环境,使得LLM模型能够在DevOps领域内有效执行和处理复杂任务。


## 🔔 更新
- [2024.01.29] 开放可配置化的multi-agent框架:coagent,详情见[使用说明](sources/readme_docs/coagent/coagent.md)
- [2024.01.29] 开放可配置化的multi-agent框架:codefuse-muAgent,详情见[使用说明](sources/readme_docs/coagent/coagent.md)
- [2023.12.26] 基于FastChat接入开源私有化大模型和大模型接口的能力开放
- [2023.12.14] 量子位公众号专题报道:[文章链接](https://mp.weixin.qq.com/s/MuPfayYTk9ZW6lcqgMpqKA)
- [2023.12.01] Multi-Agent和代码库检索功能开放

@@ -96,10 +96,10 @@ DevOps-ChatBot是由蚂蚁CodeFuse团队开发的开源AI智能助手,致力


## 🚀 快速使用
### coagent-py
完整文档见:[coagent](sources/readme_docs/coagent/coagent.md)
### muagent-py
完整文档见:[CodeFuse-muAgent](sources/readme_docs/coagent/coagent.md)
```
pip install coagent
pip install codefuse-muagent
```
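
Below is a minimal usage sketch assembled from the `coagent` service code that appears later in this commit (the `LLMConfig`/`EmbedConfig` fields, the `BasePhase` arguments and the `Message` fields are copied from that code). It is illustrative only: the imports still use the pre-rename `coagent` package name, and the published `codefuse-muagent` package may expose a different API.

```python
# Illustrative sketch; names are taken from the coagent code removed in this commit.
from coagent.llm_models.llm_config import LLMConfig, EmbedConfig
from coagent.connector.phase import BasePhase
from coagent.connector.schema import Message, Memory
from coagent.connector.configs import PHASE_CONFIGS, CHAIN_CONFIGS, AGETN_CONFIGS

llm_config = LLMConfig(model_name="gpt-3.5-turbo", api_key="...", api_base_url="...", temperature=0.2)
embed_config = EmbedConfig(embed_engine="model", embed_model="text2vec-base-chinese", embed_model_path="...")

# "chatPhase"/"chatChain" are the example phase/chain names used by the removed agent_chat.py
query = Message(role_name="human", role_type="user", role_content="hello",
                input_query="hello", origin_query="hello",
                phase_name="chatPhase", chain_name="chatChain")

phase = BasePhase("chatPhase",
                  base_phase_config=PHASE_CONFIGS,
                  base_chain_config=CHAIN_CONFIGS,
                  base_role_config=AGETN_CONFIGS,
                  embed_config=embed_config, llm_config=llm_config)
output_message, memory = phase.step(query, Memory(messages=[]))
```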

### 使用ChatBot

@@ -128,7 +128,7 @@ pip install -r requirements.txt
# 完成server_config.py配置后,可一键启动
cd examples
bash start.sh
# 开始在页面进行配置即可
# 开始在页面进行相关配置,然后打开`启动对话服务`即可
```
<div align=center>
<img src="sources/docs_imgs/webui_config.png" alt="图片">
README_en.md (53 changed lines)
@@ -17,7 +17,7 @@ This project is an open-source AI intelligent assistant, specifically designed f


## 🔔 Updates
- [2024.01.29] A configurational multi-agent framework, CoAgent, has been open-sourced. For more details, please refer to [coagent](sources/readme_docs/coagent/coagent-en.md)
- [2024.01.29] A configurational multi-agent framework, codefuse-muagent, has been open-sourced. For more details, please refer to [codefuse-muagent](sources/readme_docs/coagent/coagent-en.md)
- [2023.12.26] Opening the capability to integrate with open-source private large models and large model interfaces based on FastChat
- [2023.12.01] Release of Multi-Agent and codebase retrieval functionalities.
- [2023.11.15] Addition of Q&A enhancement mode based on the local codebase.

@@ -34,7 +34,7 @@ This project is an open-source AI intelligent assistant, specifically designed f

💡 The aim of this project is to construct an AI intelligent assistant for the entire lifecycle of software development, covering design, coding, testing, deployment, and operations, through Retrieval Augmented Generation (RAG), Tool Learning, and sandbox environments. It transitions gradually from the traditional development and operations mode of querying information from various sources and operating on standalone, disparate platforms to an intelligent development and operations mode based on large-model Q&A, changing people's development and operations habits.

- **🧠 Intelligent Scheduling Core:** Constructed a well-integrated scheduling core system that supports multi-mode one-click configuration, simplifying the operational process. [coagent](sources/readme_docs/coagent/coagent-en.md)
- **🧠 Intelligent Scheduling Core:** Constructed a well-integrated scheduling core system that supports multi-mode one-click configuration, simplifying the operational process. [codefuse-muagent](sources/readme_docs/coagent/coagent-en.md)
- **💻 Comprehensive Code Repository Analysis:** Achieved in-depth understanding at the repository level and coding and generation at the project file level, enhancing development efficiency.
- **📄 Enhanced Document Analysis:** Integrated document knowledge bases with knowledge graphs, providing deeper support for document analysis through enhanced retrieval and reasoning.
- **🔧 Industry-Specific Knowledge:** Tailored a specialized knowledge base for the DevOps domain, supporting the self-service one-click construction of industry-specific knowledge bases for convenience and practicality.

@@ -83,10 +83,10 @@ If you need to integrate a specific model, please inform us of your requirements


## 🚀 Quick Start
### coagent-py
More Detail see:[coagent](sources/readme_docs/coagent/coagent-en.md)
### muagent-py
More Detail see:[codefuse-muagent](sources/readme_docs/coagent/coagent-en.md)
```
pip install coagent
pip install codefuse-muagent
```
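
For reference, a small sketch of how tools are selected and attached to a message, following the pattern in the `coagent` service code removed later in this commit; `toLangchainTools`, `TOOL_DICT` and `TOOL_SETS` are the names used there and may change with the codefuse-muagent rename.

```python
# Sketch based on the removed coagent/chat/agent_chat.py in this commit;
# TOOL_SETS is assumed here to list the registered tool names.
from coagent.tools import toLangchainTools, TOOL_DICT, TOOL_SETS
from coagent.connector.schema import Message

# wrap the chosen registered tools as LangChain tools, as the service code does
tools = toLangchainTools([TOOL_DICT[i] for i in TOOL_SETS if i in TOOL_DICT])

msg = Message(role_name="human", role_type="user", role_content="hello",
              input_query="hello", origin_query="hello", tools=tools)
```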

### ChatBot-UI

@@ -108,51 +108,12 @@ cd Codefuse-ChatBot
pip install -r requirements.txt
```

2. Preparation of Sandbox Environment
- Windows Docker installation:
  [Docker Desktop for Windows](https://docs.docker.com/desktop/install/windows-install/) supports 64-bit versions of Windows 10 Pro, with Hyper-V enabled (not required for versions v1903 and above), or 64-bit versions of Windows 10 Home v1903 and above.

  - [Comprehensive Detailed Windows 10 Docker Installation Tutorial](https://zhuanlan.zhihu.com/p/441965046)
  - [Docker: From Beginner to Practitioner](https://yeasy.gitbook.io/docker_practice/install/windows)
  - [Handling Docker Desktop requires the Server service to be enabled](https://blog.csdn.net/sunhy_csdn/article/details/106526991)
  - [Install wsl or wait for error prompt](https://learn.microsoft.com/en-us/windows/wsl/install)

- Linux Docker Installation:
  Linux installation is relatively simple, please search Baidu/Google for installation instructions.

- Mac Docker Installation
  - [Docker: From Beginner to Practitioner](https://yeasy.gitbook.io/docker_practice/install/mac)

```bash
# Build images for the sandbox environment, see above for notebook version issues
bash docker_build.sh
```

3. Model Download (Optional)

If you need to use open-source LLM and Embedding models, you can download them from HuggingFace.
Here, we use THUDM/chatglm2-6b and text2vec-base-chinese as examples:

```
# install git-lfs
git lfs install

# install LLM-model
git lfs clone https://huggingface.co/THUDM/chatglm2-6b

# install Embedding-model
git lfs clone https://huggingface.co/shibing624/text2vec-base-chinese
```


4. Start the Service
2. Start the Service
```bash
# After configuring server_config.py, you can start with just one click.
cd examples
bash start.sh
# you can config your llm model and embedding model
# you can config your llm model and embedding model, then choose the "启动对话服务"
```
<div align=center>
<img src="sources/docs_imgs/webui_config.png" alt="图片">
@@ -1,7 +0,0 @@
# encoding: utf-8
'''
@author: 温进
@file: __init__.py.py
@time: 2023/11/9 下午4:01
@desc:
'''
@@ -1,93 +0,0 @@
import os
import platform
from loguru import logger

system_name = platform.system()
executable_path = os.getcwd()

# 日志存储路径
LOG_PATH = os.environ.get("LOG_PATH", None) or os.path.join(executable_path, "logs")

# # 知识库默认存储路径
# SOURCE_PATH = os.environ.get("SOURCE_PATH", None) or os.path.join(executable_path, "sources")

# 知识库默认存储路径
KB_ROOT_PATH = os.environ.get("KB_ROOT_PATH", None) or os.path.join(executable_path, "knowledge_base")

# 代码库默认存储路径
CB_ROOT_PATH = os.environ.get("CB_ROOT_PATH", None) or os.path.join(executable_path, "code_base")

# # nltk 模型存储路径
# NLTK_DATA_PATH = os.environ.get("NLTK_DATA_PATH", None) or os.path.join(executable_path, "nltk_data")

# 代码存储路径
JUPYTER_WORK_PATH = os.environ.get("JUPYTER_WORK_PATH", None) or os.path.join(executable_path, "jupyter_work")

# WEB_CRAWL存储路径
WEB_CRAWL_PATH = os.environ.get("WEB_CRAWL_PATH", None) or os.path.join(executable_path, "knowledge_base")

# NEBULA_DATA存储路径
NEBULA_PATH = os.environ.get("NEBULA_PATH", None) or os.path.join(executable_path, "data/nebula_data")

# CHROMA 存储路径
CHROMA_PERSISTENT_PATH = os.environ.get("CHROMA_PERSISTENT_PATH", None) or os.path.join(executable_path, "data/chroma_data")

for _path in [LOG_PATH, KB_ROOT_PATH, CB_ROOT_PATH, JUPYTER_WORK_PATH, WEB_CRAWL_PATH, NEBULA_PATH, CHROMA_PERSISTENT_PATH]:
    if not os.path.exists(_path) and int(os.environ.get("do_create_dir", True)):
        os.makedirs(_path, exist_ok=True)

# 数据库默认存储路径。
# 如果使用sqlite,可以直接修改DB_ROOT_PATH;如果使用其它数据库,请直接修改SQLALCHEMY_DATABASE_URI。
DB_ROOT_PATH = os.path.join(KB_ROOT_PATH, "info.db")
SQLALCHEMY_DATABASE_URI = f"sqlite:///{DB_ROOT_PATH}"

kbs_config = {
    "faiss": {
    },}


# GENERAL SERVER CONFIG
DEFAULT_BIND_HOST = os.environ.get("DEFAULT_BIND_HOST", None) or "127.0.0.1"

# NEBULA SERVER CONFIG
NEBULA_HOST = DEFAULT_BIND_HOST
NEBULA_PORT = 9669
NEBULA_STORAGED_PORT = 9779
NEBULA_USER = 'root'
NEBULA_PASSWORD = ''
NEBULA_GRAPH_SERVER = {
    "host": DEFAULT_BIND_HOST,
    "port": NEBULA_PORT,
    "docker_port": NEBULA_PORT
}

# CHROMA CONFIG
# CHROMA_PERSISTENT_PATH = '/home/user/chatbot/data/chroma_data'
# CHROMA_PERSISTENT_PATH = '/Users/bingxu/Desktop/工作/大模型/chatbot/codefuse-chatbot-antcode/data/chroma_data'


# 默认向量库类型。可选:faiss, milvus, pg.
DEFAULT_VS_TYPE = os.environ.get("DEFAULT_VS_TYPE") or "faiss"

# 缓存向量库数量
CACHED_VS_NUM = os.environ.get("CACHED_VS_NUM") or 1

# 知识库中单段文本长度
CHUNK_SIZE = os.environ.get("CHUNK_SIZE") or 500

# 知识库中相邻文本重合长度
OVERLAP_SIZE = os.environ.get("OVERLAP_SIZE") or 50

# 知识库匹配向量数量
VECTOR_SEARCH_TOP_K = os.environ.get("VECTOR_SEARCH_TOP_K") or 5

# 知识库匹配相关度阈值,取值范围在0-1之间,SCORE越小,相关度越高,取到1相当于不筛选,建议设置在0.5左右
# Mac 可能存在无法使用normalized_L2的问题,因此调整SCORE_THRESHOLD至 0~1100
FAISS_NORMALIZE_L2 = True if system_name in ["Linux", "Windows"] else False
SCORE_THRESHOLD = 1 if system_name in ["Linux", "Windows"] else 1100

# 搜索引擎匹配结题数量
SEARCH_ENGINE_TOP_K = os.environ.get("SEARCH_ENGINE_TOP_K") or 5

# 代码引擎匹配结题数量
CODE_SEARCH_TOP_K = os.environ.get("CODE_SEARCH_TOP_K") or 1
@@ -1,11 +0,0 @@
from .base_chat import Chat
from .knowledge_chat import KnowledgeChat
from .llm_chat import LLMChat
from .search_chat import SearchChat
from .code_chat import CodeChat
from .agent_chat import AgentChat


__all__ = [
    "Chat", "KnowledgeChat", "LLMChat", "SearchChat", "CodeChat", "AgentChat"
]
|
@ -1,348 +0,0 @@
|
|||
from fastapi import Body, Request
|
||||
from fastapi.responses import StreamingResponse
|
||||
from typing import List, Union, Dict
|
||||
from loguru import logger
|
||||
import importlib
|
||||
import copy
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
# from configs.model_config import (
|
||||
# llm_model_dict, LLM_MODEL, PROMPT_TEMPLATE,
|
||||
# VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD)
|
||||
|
||||
from coagent.tools import (
|
||||
toLangchainTools,
|
||||
TOOL_DICT, TOOL_SETS
|
||||
)
|
||||
|
||||
from coagent.connector.phase import BasePhase
|
||||
from coagent.connector.schema import Message
|
||||
from coagent.connector.schema import Memory
|
||||
from coagent.chat.utils import History, wrap_done
|
||||
from coagent.llm_models.llm_config import LLMConfig, EmbedConfig
|
||||
from coagent.connector.configs import PHASE_CONFIGS, AGETN_CONFIGS, CHAIN_CONFIGS
|
||||
|
||||
PHASE_MODULE = importlib.import_module("coagent.connector.phase")
|
||||
|
||||
|
||||
|
||||
class AgentChat:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
engine_name: str = "",
|
||||
top_k: int = 1,
|
||||
stream: bool = False,
|
||||
) -> None:
|
||||
self.top_k = top_k
|
||||
self.stream = stream
|
||||
self.chatPhase_dict: Dict[str, BasePhase] = {}
|
||||
|
||||
def chat(
|
||||
self,
|
||||
query: str = Body(..., description="用户输入", examples=["hello"]),
|
||||
phase_name: str = Body(..., description="执行场景名称", examples=["chatPhase"]),
|
||||
chain_name: str = Body(..., description="执行链的名称", examples=["chatChain"]),
|
||||
history: List[History] = Body(
|
||||
[], description="历史对话",
|
||||
examples=[[{"role": "user", "content": "我们来玩成语接龙,我先来,生龙活虎"}]]
|
||||
),
|
||||
doc_engine_name: str = Body(..., description="知识库名称", examples=["samples"]),
|
||||
search_engine_name: str = Body(..., description="搜索引擎名称", examples=["duckduckgo"]),
|
||||
code_engine_name: str = Body(..., description="代码引擎名称", examples=["samples"]),
|
||||
top_k: int = Body(5, description="匹配向量数"),
|
||||
score_threshold: float = Body(1, description="知识库匹配相关度阈值,取值范围在0-1之间,SCORE越小,相关度越高,取到1相当于不筛选,建议设置在0.5左右", ge=0, le=1),
|
||||
stream: bool = Body(False, description="流式输出"),
|
||||
local_doc_url: bool = Body(False, description="知识文件返回本地路径(true)或URL(false)"),
|
||||
choose_tools: List[str] = Body([], description="选择tool的集合"),
|
||||
do_search: bool = Body(False, description="是否进行搜索"),
|
||||
do_doc_retrieval: bool = Body(False, description="是否进行知识库检索"),
|
||||
do_code_retrieval: bool = Body(False, description="是否执行代码检索"),
|
||||
do_tool_retrieval: bool = Body(False, description="是否执行工具检索"),
|
||||
custom_phase_configs: dict = Body({}, description="自定义phase配置"),
|
||||
custom_chain_configs: dict = Body({}, description="自定义chain配置"),
|
||||
custom_role_configs: dict = Body({}, description="自定义role配置"),
|
||||
history_node_list: List = Body([], description="代码历史相关节点"),
|
||||
isDetailed: bool = Body(False, description="是否输出完整的agent相关内容"),
|
||||
upload_file: Union[str, Path, bytes] = "",
|
||||
kb_root_path: str = Body("", description="知识库存储路径"),
|
||||
jupyter_work_path: str = Body("", description="sandbox执行环境"),
|
||||
sandbox_server: str = Body({}, description="代码历史相关节点"),
|
||||
api_key: str = Body(os.environ.get("OPENAI_API_KEY"), description=""),
|
||||
api_base_url: str = Body(os.environ.get("API_BASE_URL"),),
|
||||
embed_model: str = Body("", description="向量模型"),
|
||||
embed_model_path: str = Body("", description="向量模型路径"),
|
||||
model_device: str = Body("", description="模型加载设备"),
|
||||
embed_engine: str = Body("", description="向量模型类型"),
|
||||
model_name: str = Body("", description="llm模型名称"),
|
||||
temperature: float = Body(0.2, description=""),
|
||||
**kargs
|
||||
) -> Message:
|
||||
|
||||
# update configs
|
||||
phase_configs, chain_configs, agent_configs = self.update_configs(
|
||||
custom_phase_configs, custom_chain_configs, custom_role_configs)
|
||||
params = locals()
|
||||
params.pop("self")
|
||||
embed_config: EmbedConfig = EmbedConfig(**params)
|
||||
llm_config: LLMConfig = LLMConfig(**params)
|
||||
|
||||
logger.info('phase_configs={}'.format(phase_configs))
|
||||
logger.info('chain_configs={}'.format(chain_configs))
|
||||
logger.info('agent_configs={}'.format(agent_configs))
|
||||
logger.info('phase_name')
|
||||
logger.info('chain_name')
|
||||
|
||||
# choose tools
|
||||
tools = toLangchainTools([TOOL_DICT[i] for i in choose_tools if i in TOOL_DICT])
|
||||
|
||||
if upload_file:
|
||||
upload_file_name = upload_file if upload_file and isinstance(upload_file, str) else upload_file.name
|
||||
for _filename_idx in range(len(upload_file_name), 0, -1):
|
||||
if upload_file_name[:_filename_idx] in query:
|
||||
query = query.replace(upload_file_name[:_filename_idx], upload_file_name)
|
||||
break
|
||||
|
||||
input_message = Message(
|
||||
role_content=query,
|
||||
role_type="user",
|
||||
role_name="human",
|
||||
input_query=query,
|
||||
origin_query=query,
|
||||
phase_name=phase_name,
|
||||
chain_name=chain_name,
|
||||
do_search=do_search,
|
||||
do_doc_retrieval=do_doc_retrieval,
|
||||
do_code_retrieval=do_code_retrieval,
|
||||
do_tool_retrieval=do_tool_retrieval,
|
||||
doc_engine_name=doc_engine_name, search_engine_name=search_engine_name,
|
||||
code_engine_name=code_engine_name,
|
||||
score_threshold=score_threshold, top_k=top_k,
|
||||
history_node_list=history_node_list,
|
||||
tools=tools
|
||||
)
|
||||
# history memory mangemant
|
||||
history = Memory(messages=[
|
||||
Message(role_name=i["role"], role_type=i["role"], role_content=i["content"])
|
||||
for i in history
|
||||
])
|
||||
# start to execute
|
||||
phase_class = getattr(PHASE_MODULE, phase_configs[input_message.phase_name]["phase_type"])
|
||||
# TODO 需要把相关信息补充上去
|
||||
phase = phase_class(input_message.phase_name,
|
||||
task = input_message.task,
|
||||
base_phase_config = phase_configs,
|
||||
base_chain_config = chain_configs,
|
||||
base_role_config = agent_configs,
|
||||
phase_config = None,
|
||||
kb_root_path = kb_root_path,
|
||||
jupyter_work_path = jupyter_work_path,
|
||||
sandbox_server = sandbox_server,
|
||||
embed_config = embed_config,
|
||||
llm_config = llm_config,
|
||||
)
|
||||
output_message, local_memory = phase.step(input_message, history)
|
||||
|
||||
def chat_iterator(message: Message, local_memory: Memory, isDetailed=False):
|
||||
step_content = local_memory.to_str_messages(content_key='step_content', filter_roles=["user"])
|
||||
final_content = message.role_content
|
||||
logger.debug(f"{step_content}")
|
||||
result = {
|
||||
"answer": "",
|
||||
"db_docs": [str(doc) for doc in message.db_docs],
|
||||
"search_docs": [str(doc) for doc in message.search_docs],
|
||||
"code_docs": [str(doc) for doc in message.code_docs],
|
||||
"related_nodes": [doc.get_related_node() for idx, doc in enumerate(message.code_docs) if idx==0],
|
||||
"figures": message.figures,
|
||||
"step_content": step_content,
|
||||
"final_content": final_content,
|
||||
}
|
||||
|
||||
|
||||
related_nodes, has_nodes = [], [ ]
|
||||
for nodes in result["related_nodes"]:
|
||||
for node in nodes:
|
||||
if node not in has_nodes:
|
||||
related_nodes.append(node)
|
||||
result["related_nodes"] = related_nodes
|
||||
|
||||
# logger.debug(f"{result['figures'].keys()}, isDetailed: {isDetailed}")
|
||||
message_str = step_content
|
||||
if self.stream:
|
||||
for token in message_str:
|
||||
result["answer"] = token
|
||||
yield json.dumps(result, ensure_ascii=False)
|
||||
else:
|
||||
for token in message_str:
|
||||
result["answer"] += token
|
||||
yield json.dumps(result, ensure_ascii=False)
|
||||
|
||||
return StreamingResponse(chat_iterator(output_message, local_memory, isDetailed), media_type="text/event-stream")
|
||||
|
||||
|
||||
def achat(
|
||||
self,
|
||||
query: str = Body(..., description="用户输入", examples=["hello"]),
|
||||
phase_name: str = Body(..., description="执行场景名称", examples=["chatPhase"]),
|
||||
chain_name: str = Body(..., description="执行链的名称", examples=["chatChain"]),
|
||||
history: List[History] = Body(
|
||||
[], description="历史对话",
|
||||
examples=[[{"role": "user", "content": "我们来玩成语接龙,我先来,生龙活虎"}]]
|
||||
),
|
||||
doc_engine_name: str = Body(..., description="知识库名称", examples=["samples"]),
|
||||
search_engine_name: str = Body(..., description="搜索引擎名称", examples=["duckduckgo"]),
|
||||
code_engine_name: str = Body(..., description="代码引擎名称", examples=["samples"]),
|
||||
cb_search_type: str = Body(..., description="代码查询模式", examples=["tag"]),
|
||||
top_k: int = Body(5, description="匹配向量数"),
|
||||
score_threshold: float = Body(1, description="知识库匹配相关度阈值,取值范围在0-1之间,SCORE越小,相关度越高,取到1相当于不筛选,建议设置在0.5左右", ge=0, le=1),
|
||||
stream: bool = Body(False, description="流式输出"),
|
||||
local_doc_url: bool = Body(False, description="知识文件返回本地路径(true)或URL(false)"),
|
||||
choose_tools: List[str] = Body([], description="选择tool的集合"),
|
||||
do_search: bool = Body(False, description="是否进行搜索"),
|
||||
do_doc_retrieval: bool = Body(False, description="是否进行知识库检索"),
|
||||
do_code_retrieval: bool = Body(False, description="是否执行代码检索"),
|
||||
do_tool_retrieval: bool = Body(False, description="是否执行工具检索"),
|
||||
custom_phase_configs: dict = Body({}, description="自定义phase配置"),
|
||||
custom_chain_configs: dict = Body({}, description="自定义chain配置"),
|
||||
custom_role_configs: dict = Body({}, description="自定义role配置"),
|
||||
history_node_list: List = Body([], description="代码历史相关节点"),
|
||||
isDetailed: bool = Body(False, description="是否输出完整的agent相关内容"),
|
||||
upload_file: Union[str, Path, bytes] = "",
|
||||
kb_root_path: str = Body("", description="知识库存储路径"),
|
||||
jupyter_work_path: str = Body("", description="sandbox执行环境"),
|
||||
sandbox_server: str = Body({}, description="代码历史相关节点"),
|
||||
api_key: str = Body(os.environ["OPENAI_API_KEY"], description=""),
|
||||
api_base_url: str = Body(os.environ.get("API_BASE_URL"),),
|
||||
embed_model: str = Body("", description="向量模型"),
|
||||
embed_model_path: str = Body("", description="向量模型路径"),
|
||||
model_device: str = Body("", description="模型加载设备"),
|
||||
embed_engine: str = Body("", description="向量模型类型"),
|
||||
model_name: str = Body("", description="llm模型名称"),
|
||||
temperature: float = Body(0.2, description=""),
|
||||
**kargs
|
||||
) -> Message:
|
||||
|
||||
# update configs
|
||||
phase_configs, chain_configs, agent_configs = self.update_configs(
|
||||
custom_phase_configs, custom_chain_configs, custom_role_configs)
|
||||
|
||||
#
|
||||
params = locals()
|
||||
params.pop("self")
|
||||
embed_config: EmbedConfig = EmbedConfig(**params)
|
||||
llm_config: LLMConfig = LLMConfig(**params)
|
||||
|
||||
# choose tools
|
||||
tools = toLangchainTools([TOOL_DICT[i] for i in choose_tools if i in TOOL_DICT])
|
||||
|
||||
if upload_file:
|
||||
upload_file_name = upload_file if upload_file and isinstance(upload_file, str) else upload_file.name
|
||||
for _filename_idx in range(len(upload_file_name), 0, -1):
|
||||
if upload_file_name[:_filename_idx] in query:
|
||||
query = query.replace(upload_file_name[:_filename_idx], upload_file_name)
|
||||
break
|
||||
|
||||
input_message = Message(
|
||||
role_content=query,
|
||||
role_type="user",
|
||||
role_name="human",
|
||||
input_query=query,
|
||||
origin_query=query,
|
||||
phase_name=phase_name,
|
||||
chain_name=chain_name,
|
||||
do_search=do_search,
|
||||
do_doc_retrieval=do_doc_retrieval,
|
||||
do_code_retrieval=do_code_retrieval,
|
||||
do_tool_retrieval=do_tool_retrieval,
|
||||
doc_engine_name=doc_engine_name,
|
||||
search_engine_name=search_engine_name,
|
||||
code_engine_name=code_engine_name,
|
||||
cb_search_type=cb_search_type,
|
||||
score_threshold=score_threshold, top_k=top_k,
|
||||
history_node_list=history_node_list,
|
||||
tools=tools
|
||||
)
|
||||
# history memory mangemant
|
||||
history = Memory(messages=[
|
||||
Message(role_name=i["role"], role_type=i["role"], role_content=i["content"])
|
||||
for i in history
|
||||
])
|
||||
# start to execute
|
||||
if phase_configs[input_message.phase_name]["phase_type"] not in self.chatPhase_dict:
|
||||
phase_class = getattr(PHASE_MODULE, phase_configs[input_message.phase_name]["phase_type"])
|
||||
phase = phase_class(input_message.phase_name,
|
||||
task = input_message.task,
|
||||
base_phase_config = phase_configs,
|
||||
base_chain_config = chain_configs,
|
||||
base_role_config = agent_configs,
|
||||
phase_config = None,
|
||||
kb_root_path = kb_root_path,
|
||||
jupyter_work_path = jupyter_work_path,
|
||||
sandbox_server = sandbox_server,
|
||||
embed_config = embed_config,
|
||||
llm_config = llm_config,
|
||||
)
|
||||
self.chatPhase_dict[phase_configs[input_message.phase_name]["phase_type"]] = phase
|
||||
else:
|
||||
phase = self.chatPhase_dict[phase_configs[input_message.phase_name]["phase_type"]]
|
||||
|
||||
def chat_iterator(message: Message, local_memory: Memory, isDetailed=False):
|
||||
step_content = local_memory.to_str_messages(content_key='step_content', filter_roles=["human"])
|
||||
step_content = "\n\n".join([f"{v}" for parsed_output in local_memory.get_parserd_output_list()[1:] for k, v in parsed_output.items() if k not in ["Action Status"]])
|
||||
final_content = message.role_content
|
||||
result = {
|
||||
"answer": "",
|
||||
"db_docs": [str(doc) for doc in message.db_docs],
|
||||
"search_docs": [str(doc) for doc in message.search_docs],
|
||||
"code_docs": [str(doc) for doc in message.code_docs],
|
||||
"related_nodes": [doc.get_related_node() for idx, doc in enumerate(message.code_docs) if idx==0],
|
||||
"figures": message.figures,
|
||||
"step_content": step_content or final_content,
|
||||
"final_content": final_content,
|
||||
}
|
||||
|
||||
related_nodes, has_nodes = [], [ ]
|
||||
for nodes in result["related_nodes"]:
|
||||
for node in nodes:
|
||||
if node not in has_nodes:
|
||||
related_nodes.append(node)
|
||||
result["related_nodes"] = related_nodes
|
||||
|
||||
# logger.debug(f"{result['figures'].keys()}, isDetailed: {isDetailed}")
|
||||
message_str = step_content
|
||||
if self.stream:
|
||||
for token in message_str:
|
||||
result["answer"] = token
|
||||
yield json.dumps(result, ensure_ascii=False)
|
||||
else:
|
||||
for token in message_str:
|
||||
result["answer"] += token
|
||||
yield json.dumps(result, ensure_ascii=False)
|
||||
|
||||
|
||||
for output_message, local_memory in phase.astep(input_message, history):
|
||||
|
||||
# logger.debug(f"output_message: {output_message}")
|
||||
# output_message = Message(**output_message)
|
||||
# local_memory = Memory(**local_memory)
|
||||
for result in chat_iterator(output_message, local_memory, isDetailed):
|
||||
yield result
|
||||
|
||||
|
||||
def _chat(self, ):
|
||||
pass
|
||||
|
||||
def update_configs(self, custom_phase_configs, custom_chain_configs, custom_role_configs):
|
||||
'''update phase/chain/agent configs'''
|
||||
phase_configs = copy.deepcopy(PHASE_CONFIGS)
|
||||
phase_configs.update(custom_phase_configs)
|
||||
chain_configs = copy.deepcopy(CHAIN_CONFIGS)
|
||||
chain_configs.update(custom_chain_configs)
|
||||
agent_configs = copy.deepcopy(AGETN_CONFIGS)
|
||||
agent_configs.update(custom_role_configs)
|
||||
# phase_configs = load_phase_configs(new_phase_configs)
|
||||
# chian_configs = load_chain_configs(new_chain_configs)
|
||||
# agent_configs = load_role_configs(new_agent_configs)
|
||||
return phase_configs, chain_configs, agent_configs
|
|
@ -1,173 +0,0 @@
|
|||
from fastapi import Body, Request
|
||||
from fastapi.responses import StreamingResponse
|
||||
import asyncio, json, os
|
||||
from typing import List, AsyncIterable
|
||||
|
||||
from langchain import LLMChain
|
||||
from langchain.callbacks import AsyncIteratorCallbackHandler
|
||||
from langchain.prompts.chat import ChatPromptTemplate
|
||||
|
||||
from coagent.llm_models import getChatModelFromConfig
|
||||
from coagent.chat.utils import History, wrap_done
|
||||
from coagent.llm_models.llm_config import LLMConfig, EmbedConfig
|
||||
# from configs.model_config import (llm_model_dict, LLM_MODEL, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD)
|
||||
from coagent.utils import BaseResponse
|
||||
from loguru import logger
|
||||
|
||||
|
||||
|
||||
class Chat:
|
||||
def __init__(
|
||||
self,
|
||||
engine_name: str = "",
|
||||
top_k: int = 1,
|
||||
stream: bool = False,
|
||||
) -> None:
|
||||
self.engine_name = engine_name
|
||||
self.top_k = top_k
|
||||
self.stream = stream
|
||||
|
||||
def check_service_status(self, ) -> BaseResponse:
|
||||
return BaseResponse(code=200, msg=f"okok")
|
||||
|
||||
def chat(
|
||||
self,
|
||||
query: str = Body(..., description="用户输入", examples=["hello"]),
|
||||
history: List[History] = Body(
|
||||
[], description="历史对话",
|
||||
examples=[[{"role": "user", "content": "我们来玩成语接龙,我先来,生龙活虎"}]]
|
||||
),
|
||||
engine_name: str = Body(..., description="知识库名称", examples=["samples"]),
|
||||
top_k: int = Body(5, description="匹配向量数"),
|
||||
score_threshold: float = Body(1, description="知识库匹配相关度阈值,取值范围在0-1之间,SCORE越小,相关度越高,取到1相当于不筛选,建议设置在0.5左右", ge=0, le=1),
|
||||
stream: bool = Body(False, description="流式输出"),
|
||||
local_doc_url: bool = Body(False, description="知识文件返回本地路径(true)或URL(false)"),
|
||||
request: Request = None,
|
||||
api_key: str = Body(os.environ.get("OPENAI_API_KEY")),
|
||||
api_base_url: str = Body(os.environ.get("API_BASE_URL")),
|
||||
embed_model: str = Body("", ),
|
||||
embed_model_path: str = Body("", ),
|
||||
embed_engine: str = Body("", ),
|
||||
model_name: str = Body("", ),
|
||||
temperature: float = Body(0.5, ),
|
||||
model_device: str = Body("", ),
|
||||
**kargs
|
||||
):
|
||||
params = locals()
|
||||
params.pop("self", None)
|
||||
llm_config: LLMConfig = LLMConfig(**params)
|
||||
embed_config: EmbedConfig = EmbedConfig(**params)
|
||||
self.engine_name = engine_name if isinstance(engine_name, str) else engine_name.default
|
||||
self.top_k = top_k if isinstance(top_k, int) else top_k.default
|
||||
self.score_threshold = score_threshold if isinstance(score_threshold, float) else score_threshold.default
|
||||
self.stream = stream if isinstance(stream, bool) else stream.default
|
||||
self.local_doc_url = local_doc_url if isinstance(local_doc_url, bool) else local_doc_url.default
|
||||
self.request = request
|
||||
return self._chat(query, history, llm_config, embed_config, **kargs)
|
||||
|
||||
def _chat(self, query: str, history: List[History], llm_config: LLMConfig, embed_config: EmbedConfig, **kargs):
|
||||
history = [History(**h) if isinstance(h, dict) else h for h in history]
|
||||
|
||||
## check service dependcy is ok
|
||||
service_status = self.check_service_status()
|
||||
|
||||
if service_status.code!=200: return service_status
|
||||
|
||||
def chat_iterator(query: str, history: List[History]):
|
||||
# model = getChatModel()
|
||||
model = getChatModelFromConfig(llm_config)
|
||||
|
||||
result, content = self.create_task(query, history, model, llm_config, embed_config, **kargs)
|
||||
logger.info('result={}'.format(result))
|
||||
logger.info('content={}'.format(content))
|
||||
|
||||
if self.stream:
|
||||
for token in content["text"]:
|
||||
result["answer"] = token
|
||||
yield json.dumps(result, ensure_ascii=False)
|
||||
else:
|
||||
for token in content["text"]:
|
||||
result["answer"] += token
|
||||
yield json.dumps(result, ensure_ascii=False)
|
||||
|
||||
return StreamingResponse(chat_iterator(query, history),
|
||||
media_type="text/event-stream")
|
||||
|
||||
def achat(
|
||||
self,
|
||||
query: str = Body(..., description="用户输入", examples=["hello"]),
|
||||
history: List[History] = Body(
|
||||
[], description="历史对话",
|
||||
examples=[[{"role": "user", "content": "我们来玩成语接龙,我先来,生龙活虎"}]]
|
||||
),
|
||||
engine_name: str = Body(..., description="知识库名称", examples=["samples"]),
|
||||
top_k: int = Body(5, description="匹配向量数"),
|
||||
score_threshold: float = Body(1, description="知识库匹配相关度阈值,取值范围在0-1之间,SCORE越小,相关度越高,取到1相当于不筛选,建议设置在0.5左右", ge=0, le=1),
|
||||
stream: bool = Body(False, description="流式输出"),
|
||||
local_doc_url: bool = Body(False, description="知识文件返回本地路径(true)或URL(false)"),
|
||||
request: Request = None,
|
||||
api_key: str = Body(os.environ.get("OPENAI_API_KEY")),
|
||||
api_base_url: str = Body(os.environ.get("API_BASE_URL")),
|
||||
embed_model: str = Body("", ),
|
||||
embed_model_path: str = Body("", ),
|
||||
embed_engine: str = Body("", ),
|
||||
model_name: str = Body("", ),
|
||||
temperature: float = Body(0.5, ),
|
||||
model_device: str = Body("", ),
|
||||
):
|
||||
#
|
||||
params = locals()
|
||||
params.pop("self", None)
|
||||
llm_config: LLMConfig = LLMConfig(**params)
|
||||
embed_config: EmbedConfig = EmbedConfig(**params)
|
||||
self.engine_name = engine_name if isinstance(engine_name, str) else engine_name.default
|
||||
self.top_k = top_k if isinstance(top_k, int) else top_k.default
|
||||
self.score_threshold = score_threshold if isinstance(score_threshold, float) else score_threshold.default
|
||||
self.stream = stream if isinstance(stream, bool) else stream.default
|
||||
self.local_doc_url = local_doc_url if isinstance(local_doc_url, bool) else local_doc_url.default
|
||||
self.request = request
|
||||
return self._achat(query, history, llm_config, embed_config)
|
||||
|
||||
def _achat(self, query: str, history: List[History], llm_config: LLMConfig, embed_config: EmbedConfig):
|
||||
history = [History(**h) if isinstance(h, dict) else h for h in history]
|
||||
## check service dependcy is ok
|
||||
service_status = self.check_service_status()
|
||||
if service_status.code!=200: return service_status
|
||||
|
||||
async def chat_iterator(query, history):
|
||||
callback = AsyncIteratorCallbackHandler()
|
||||
# model = getChatModel()
|
||||
model = getChatModelFromConfig(llm_config)
|
||||
|
||||
task, result = self.create_atask(query, history, model, llm_config, embed_config, callback)
|
||||
if self.stream:
|
||||
for token in callback["text"]:
|
||||
result["answer"] = token
|
||||
yield json.dumps(result, ensure_ascii=False)
|
||||
else:
|
||||
for token in callback["text"]:
|
||||
result["answer"] += token
|
||||
yield json.dumps(result, ensure_ascii=False)
|
||||
await task
|
||||
|
||||
return StreamingResponse(chat_iterator(query, history),
|
||||
media_type="text/event-stream")
|
||||
|
||||
def create_task(self, query: str, history: List[History], model, llm_config: LLMConfig, embed_config: EmbedConfig, **kargs):
|
||||
'''构建 llm 生成任务'''
|
||||
chat_prompt = ChatPromptTemplate.from_messages(
|
||||
[i.to_msg_tuple() for i in history] + [("human", "{input}")]
|
||||
)
|
||||
chain = LLMChain(prompt=chat_prompt, llm=model)
|
||||
content = chain({"input": query})
|
||||
return {"answer": "", "docs": ""}, content
|
||||
|
||||
def create_atask(self, query, history, model, llm_config: LLMConfig, embed_config: EmbedConfig, callback: AsyncIteratorCallbackHandler):
|
||||
chat_prompt = ChatPromptTemplate.from_messages(
|
||||
[i.to_msg_tuple() for i in history] + [("human", "{input}")]
|
||||
)
|
||||
chain = LLMChain(prompt=chat_prompt, llm=model)
|
||||
task = asyncio.create_task(wrap_done(
|
||||
chain.acall({"input": query}), callback.done
|
||||
))
|
||||
return task, {"answer": "", "docs": ""}
|
|
@ -1,174 +0,0 @@
|
|||
# encoding: utf-8
|
||||
'''
|
||||
@author: 温进
|
||||
@file: code_chat.py
|
||||
@time: 2023/10/24 下午4:04
|
||||
@desc:
|
||||
'''
|
||||
|
||||
from fastapi import Request, Body
|
||||
import os, asyncio
|
||||
from typing import List
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
from langchain import LLMChain
|
||||
from langchain.callbacks import AsyncIteratorCallbackHandler
|
||||
from langchain.prompts.chat import ChatPromptTemplate
|
||||
|
||||
# from configs.model_config import (
|
||||
# llm_model_dict, LLM_MODEL, PROMPT_TEMPLATE,
|
||||
# VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD, CODE_PROMPT_TEMPLATE)
|
||||
from coagent.connector.configs.prompts import CODE_PROMPT_TEMPLATE
|
||||
from coagent.chat.utils import History, wrap_done
|
||||
from coagent.utils import BaseResponse
|
||||
from .base_chat import Chat
|
||||
from coagent.llm_models import getChatModelFromConfig
|
||||
|
||||
from coagent.llm_models.llm_config import LLMConfig, EmbedConfig
|
||||
|
||||
|
||||
from coagent.service.cb_api import search_code, cb_exists_api
|
||||
from loguru import logger
|
||||
import json
|
||||
|
||||
|
||||
class CodeChat(Chat):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
code_base_name: str = '',
|
||||
code_limit: int = 1,
|
||||
stream: bool = False,
|
||||
request: Request = None,
|
||||
) -> None:
|
||||
super().__init__(engine_name=code_base_name, stream=stream)
|
||||
self.engine_name = code_base_name
|
||||
self.code_limit = code_limit
|
||||
self.request = request
|
||||
self.history_node_list = []
|
||||
|
||||
def check_service_status(self) -> BaseResponse:
|
||||
cb = cb_exists_api(self.engine_name)
|
||||
if not cb:
|
||||
return BaseResponse(code=404, msg=f"未找到代码库 {self.engine_name}")
|
||||
return BaseResponse(code=200, msg=f"找到代码库 {self.engine_name}")
|
||||
|
||||
def _process(self, query: str, history: List[History], model, llm_config: LLMConfig, embed_config: EmbedConfig):
|
||||
'''process'''
|
||||
|
||||
codes_res = search_code(query=query, cb_name=self.engine_name, code_limit=self.code_limit,
|
||||
search_type=self.cb_search_type,
|
||||
history_node_list=self.history_node_list,
|
||||
api_key=llm_config.api_key,
|
||||
api_base_url=llm_config.api_base_url,
|
||||
model_name=llm_config.model_name,
|
||||
temperature=llm_config.temperature,
|
||||
embed_model=embed_config.embed_model,
|
||||
embed_model_path=embed_config.embed_model_path,
|
||||
embed_engine=embed_config.embed_engine,
|
||||
model_device=embed_config.model_device,
|
||||
embed_config=embed_config
|
||||
)
|
||||
|
||||
context = codes_res['context']
|
||||
related_vertices = codes_res['related_vertices']
|
||||
|
||||
# update node names
|
||||
# node_names = [node[0] for node in nodes]
|
||||
# self.history_node_list.extend(node_names)
|
||||
# self.history_node_list = list(set(self.history_node_list))
|
||||
|
||||
source_nodes = []
|
||||
|
||||
for inum, node_name in enumerate(related_vertices[0:5]):
|
||||
source_nodes.append(f'{inum + 1}. 节点名: `{node_name}`')
|
||||
|
||||
logger.info('history={}'.format(history))
|
||||
logger.info('message={}'.format([i.to_msg_tuple() for i in history] + [("human", CODE_PROMPT_TEMPLATE)]))
|
||||
chat_prompt = ChatPromptTemplate.from_messages(
|
||||
[i.to_msg_tuple() for i in history] + [("human", CODE_PROMPT_TEMPLATE)]
|
||||
)
|
||||
logger.info('chat_prompt={}'.format(chat_prompt))
|
||||
chain = LLMChain(prompt=chat_prompt, llm=model)
|
||||
result = {"answer": "", "codes": source_nodes}
|
||||
return chain, context, result
|
||||
|
||||
def chat(
|
||||
self,
|
||||
query: str = Body(..., description="用户输入", examples=["hello"]),
|
||||
history: List[History] = Body(
|
||||
[], description="历史对话",
|
||||
examples=[[{"role": "user", "content": "我们来玩成语接龙,我先来,生龙活虎"}]]
|
||||
),
|
||||
engine_name: str = Body(..., description="知识库名称", examples=["samples"]),
|
||||
code_limit: int = Body(1, examples=['1']),
|
||||
cb_search_type: str = Body('', examples=['1']),
|
||||
stream: bool = Body(False, description="流式输出"),
|
||||
local_doc_url: bool = Body(False, description="知识文件返回本地路径(true)或URL(false)"),
|
||||
request: Request = None,
|
||||
|
||||
api_key: str = Body(os.environ.get("OPENAI_API_KEY")),
|
||||
api_base_url: str = Body(os.environ.get("API_BASE_URL")),
|
||||
embed_model: str = Body("", ),
|
||||
embed_model_path: str = Body("", ),
|
||||
embed_engine: str = Body("", ),
|
||||
model_name: str = Body("", ),
|
||||
temperature: float = Body(0.5, ),
|
||||
model_device: str = Body("", ),
|
||||
**kargs
|
||||
):
|
||||
params = locals()
|
||||
params.pop("self")
|
||||
llm_config: LLMConfig = LLMConfig(**params)
|
||||
embed_config: EmbedConfig = EmbedConfig(**params)
|
||||
self.engine_name = engine_name if isinstance(engine_name, str) else engine_name.default
|
||||
self.code_limit = code_limit
|
||||
self.stream = stream if isinstance(stream, bool) else stream.default
|
||||
self.local_doc_url = local_doc_url if isinstance(local_doc_url, bool) else local_doc_url.default
|
||||
self.request = request
|
||||
self.cb_search_type = cb_search_type
|
||||
return self._chat(query, history, llm_config, embed_config, **kargs)
|
||||
|
||||
def _chat(self, query: str, history: List[History], llm_config: LLMConfig, embed_config: EmbedConfig, **kargs):
|
||||
history = [History(**h) if isinstance(h, dict) else h for h in history]
|
||||
|
||||
service_status = self.check_service_status()
|
||||
|
||||
if service_status.code != 200: return service_status
|
||||
|
||||
def chat_iterator(query: str, history: List[History]):
|
||||
# model = getChatModel()
|
||||
model = getChatModelFromConfig(llm_config)
|
||||
|
||||
result, content = self.create_task(query, history, model, llm_config, embed_config, **kargs)
|
||||
# logger.info('result={}'.format(result))
|
||||
# logger.info('content={}'.format(content))
|
||||
|
||||
if self.stream:
|
||||
for token in content["text"]:
|
||||
result["answer"] = token
|
||||
yield json.dumps(result, ensure_ascii=False)
|
||||
else:
|
||||
for token in content["text"]:
|
||||
result["answer"] += token
|
||||
yield json.dumps(result, ensure_ascii=False)
|
||||
|
||||
return StreamingResponse(chat_iterator(query, history),
|
||||
media_type="text/event-stream")
|
||||
|
||||
def create_task(self, query: str, history: List[History], model, llm_config: LLMConfig, embed_config: EmbedConfig):
|
||||
'''构建 llm 生成任务'''
|
||||
chain, context, result = self._process(query, history, model, llm_config, embed_config)
|
||||
logger.info('chain={}'.format(chain))
|
||||
try:
|
||||
content = chain({"context": context, "question": query})
|
||||
except Exception as e:
|
||||
content = {"text": str(e)}
|
||||
return result, content
|
||||
|
||||
def create_atask(self, query, history, model, llm_config: LLMConfig, embed_config: EmbedConfig, callback: AsyncIteratorCallbackHandler):
|
||||
chain, context, result = self._process(query, history, model, llm_config, embed_config)
|
||||
task = asyncio.create_task(wrap_done(
|
||||
chain.acall({"context": context, "question": query}), callback.done
|
||||
))
|
||||
return task, result
|
@@ -1,89 +0,0 @@
from fastapi import Request
import os, asyncio
from urllib.parse import urlencode
from typing import List

from langchain import LLMChain
from langchain.callbacks import AsyncIteratorCallbackHandler
from langchain.prompts.chat import ChatPromptTemplate


# from configs.model_config import (
#     llm_model_dict, LLM_MODEL, PROMPT_TEMPLATE,
#     VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD)
from coagent.base_configs.env_config import KB_ROOT_PATH
from coagent.connector.configs.prompts import ORIGIN_TEMPLATE_PROMPT
from coagent.chat.utils import History, wrap_done
from coagent.utils import BaseResponse
from coagent.llm_models.llm_config import LLMConfig, EmbedConfig
from .base_chat import Chat
from coagent.service.kb_api import search_docs, KBServiceFactory
from loguru import logger


class KnowledgeChat(Chat):

    def __init__(
            self,
            engine_name: str = "",
            top_k: int = 5,
            stream: bool = False,
            score_thresold: float = 1.0,
            local_doc_url: bool = False,
            request: Request = None,
            kb_root_path: str = KB_ROOT_PATH,
    ) -> None:
        super().__init__(engine_name, top_k, stream)
        self.score_thresold = score_thresold
        self.local_doc_url = local_doc_url
        self.request = request
        self.kb_root_path = kb_root_path

    def check_service_status(self) -> BaseResponse:
        kb = KBServiceFactory.get_service_by_name(self.engine_name, self.kb_root_path)
        if kb is None:
            return BaseResponse(code=404, msg=f"未找到知识库 {self.engine_name}")
        return BaseResponse(code=200, msg=f"找到知识库 {self.engine_name}")

    def _process(self, query: str, history: List[History], model, llm_config: LLMConfig, embed_config: EmbedConfig, ):
        '''process'''
        docs = search_docs(
            query, self.engine_name, self.top_k, self.score_threshold, self.kb_root_path,
            api_key=embed_config.api_key, api_base_url=embed_config.api_base_url, embed_model=embed_config.embed_model,
            embed_model_path=embed_config.embed_model_path, embed_engine=embed_config.embed_engine,
            model_device=embed_config.model_device,
        )
        context = "\n".join([doc.page_content for doc in docs])
        source_documents = []
        for inum, doc in enumerate(docs):
            filename = os.path.split(doc.metadata["source"])[-1]
            if self.local_doc_url:
                url = "file://" + doc.metadata["source"]
            else:
                parameters = urlencode({"knowledge_base_name": self.engine_name, "file_name": filename})
                url = f"{self.request.base_url}knowledge_base/download_doc?" + parameters
            text = f"""出处 [{inum + 1}] [{filename}]({url}) \n\n{doc.page_content}\n\n"""
            source_documents.append(text)
        chat_prompt = ChatPromptTemplate.from_messages(
            [i.to_msg_tuple() for i in history] + [("human", ORIGIN_TEMPLATE_PROMPT)]
        )
        chain = LLMChain(prompt=chat_prompt, llm=model)
        result = {"answer": "", "docs": source_documents}
        return chain, context, result

    def create_task(self, query: str, history: List[History], model, llm_config: LLMConfig, embed_config: EmbedConfig, ):
        '''构建 llm 生成任务'''
        logger.debug(f"query: {query}, history: {history}")
        chain, context, result = self._process(query, history, model, llm_config, embed_config)
        try:
            content = chain({"context": context, "question": query})
        except Exception as e:
            content = {"text": str(e)}
        return result, content

    def create_atask(self, query, history, model, llm_config: LLMConfig, embed_config: EmbedConfig, callback: AsyncIteratorCallbackHandler):
        chain, context, result = self._process(query, history, model, llm_config, embed_config)
        task = asyncio.create_task(wrap_done(
            chain.acall({"context": context, "question": query}), callback.done
        ))
        return task, result
@@ -1,42 +0,0 @@
import asyncio
from typing import List

from langchain import LLMChain
from langchain.callbacks import AsyncIteratorCallbackHandler
from langchain.prompts.chat import ChatPromptTemplate


from coagent.chat.utils import History, wrap_done
from coagent.llm_models.llm_config import LLMConfig, EmbedConfig
from .base_chat import Chat
from loguru import logger


class LLMChat(Chat):

    def __init__(
            self,
            engine_name: str = "",
            top_k: int = 1,
            stream: bool = False,
    ) -> None:
        super().__init__(engine_name, top_k, stream)

    def create_task(self, query: str, history: List[History], model, llm_config: LLMConfig, embed_config: EmbedConfig, **kargs):
        '''构建 llm 生成任务'''
        chat_prompt = ChatPromptTemplate.from_messages(
            [i.to_msg_tuple() for i in history] + [("human", "{input}")]
        )
        chain = LLMChain(prompt=chat_prompt, llm=model)
        content = chain({"input": query})
        return {"answer": "", "docs": ""}, content

    def create_atask(self, query, history, model, llm_config: LLMConfig, embed_config: EmbedConfig, callback: AsyncIteratorCallbackHandler):
        chat_prompt = ChatPromptTemplate.from_messages(
            [i.to_msg_tuple() for i in history] + [("human", "{input}")]
        )
        chain = LLMChain(prompt=chat_prompt, llm=model)
        task = asyncio.create_task(wrap_done(
            chain.acall({"input": query}), callback.done
        ))
        return task, {"answer": "", "docs": ""}
|
@ -1,151 +0,0 @@
|
|||
import os, asyncio
|
||||
from typing import List, Optional, Dict
|
||||
|
||||
from langchain import LLMChain
|
||||
from langchain.callbacks import AsyncIteratorCallbackHandler
|
||||
from langchain.utilities import BingSearchAPIWrapper, DuckDuckGoSearchAPIWrapper
|
||||
from langchain.prompts.chat import ChatPromptTemplate
|
||||
from langchain.docstore.document import Document
|
||||
|
||||
# from configs.model_config import (
|
||||
# PROMPT_TEMPLATE, SEARCH_ENGINE_TOP_K, BING_SUBSCRIPTION_KEY, BING_SEARCH_URL,
|
||||
# VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD)
|
||||
from coagent.connector.configs.prompts import ORIGIN_TEMPLATE_PROMPT
|
||||
from coagent.chat.utils import History, wrap_done
|
||||
from coagent.utils import BaseResponse
|
||||
from coagent.llm_models.llm_config import LLMConfig, EmbedConfig
|
||||
from .base_chat import Chat
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from duckduckgo_search import DDGS
|
||||
|
||||
|
||||
# def bing_search(text, result_len=5):
|
||||
# if not (BING_SEARCH_URL and BING_SUBSCRIPTION_KEY):
|
||||
# return [{"snippet": "please set BING_SUBSCRIPTION_KEY and BING_SEARCH_URL in os ENV",
|
||||
# "title": "env info is not found",
|
||||
# "link": "https://python.langchain.com/en/latest/modules/agents/tools/examples/bing_search.html"}]
|
||||
# search = BingSearchAPIWrapper(bing_subscription_key=BING_SUBSCRIPTION_KEY,
|
||||
# bing_search_url=BING_SEARCH_URL)
|
||||
# return search.results(text, result_len)
|
||||
|
||||
|
||||
def duckduckgo_search(
|
||||
query: str,
|
||||
result_len: int = 5,
|
||||
region: Optional[str] = "wt-wt",
|
||||
safesearch: str = "moderate",
|
||||
time: Optional[str] = "y",
|
||||
backend: str = "api",
|
||||
):
|
||||
with DDGS(proxies=os.environ.get("DUCKDUCKGO_PROXY")) as ddgs:
|
||||
results = ddgs.text(
|
||||
query,
|
||||
region=region,
|
||||
safesearch=safesearch,
|
||||
timelimit=time,
|
||||
backend=backend,
|
||||
)
|
||||
if results is None:
|
||||
return [{"Result": "No good DuckDuckGo Search Result was found"}]
|
||||
|
||||
def to_metadata(result: Dict) -> Dict[str, str]:
|
||||
if backend == "news":
|
||||
return {
|
||||
"date": result["date"],
|
||||
"title": result["title"],
|
||||
"snippet": result["body"],
|
||||
"source": result["source"],
|
||||
"link": result["url"],
|
||||
}
|
||||
return {
|
||||
"snippet": result["body"],
|
||||
"title": result["title"],
|
||||
"link": result["href"],
|
||||
}
|
||||
|
||||
formatted_results = []
|
||||
for i, res in enumerate(results, 1):
|
||||
if res is not None:
|
||||
formatted_results.append(to_metadata(res))
|
||||
if len(formatted_results) == result_len:
|
||||
break
|
||||
return formatted_results
|
||||
|
||||
|
||||
# def duckduckgo_search(text, result_len=SEARCH_ENGINE_TOP_K):
|
||||
# search = DuckDuckGoSearchAPIWrapper()
|
||||
# return search.results(text, result_len)
|
||||
|
||||
|
||||
SEARCH_ENGINES = {"duckduckgo": duckduckgo_search,
|
||||
# "bing": bing_search,
|
||||
}
|
||||
|
||||
|
||||
def search_result2docs(search_results):
|
||||
docs = []
|
||||
for result in search_results:
|
||||
doc = Document(page_content=result["snippet"] if "snippet" in result.keys() else "",
|
||||
metadata={"source": result["link"] if "link" in result.keys() else "",
|
||||
"filename": result["title"] if "title" in result.keys() else ""})
|
||||
docs.append(doc)
|
||||
return docs
|
||||
|
||||
|
||||
def lookup_search_engine(
|
||||
query: str,
|
||||
search_engine_name: str,
|
||||
top_k: int = 5,
|
||||
):
|
||||
results = SEARCH_ENGINES[search_engine_name](query, result_len=top_k)
|
||||
docs = search_result2docs(results)
|
||||
return docs
|
||||
|
||||
|
||||
|
||||
class SearchChat(Chat):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
engine_name: str = "",
|
||||
top_k: int = 5,
|
||||
stream: bool = False,
|
||||
) -> None:
|
||||
super().__init__(engine_name, top_k, stream)
|
||||
|
||||
def check_service_status(self) -> BaseResponse:
|
||||
if self.engine_name not in SEARCH_ENGINES.keys():
|
||||
return BaseResponse(code=404, msg=f"未支持搜索引擎 {self.engine_name}")
|
||||
return BaseResponse(code=200, msg=f"支持搜索引擎 {self.engine_name}")
|
||||
|
||||
def _process(self, query: str, history: List[History], model):
|
||||
'''process'''
|
||||
docs = lookup_search_engine(query, self.engine_name, self.top_k)
|
||||
context = "\n".join([doc.page_content for doc in docs])
|
||||
|
||||
source_documents = [
|
||||
f"""出处 [{inum + 1}] [{doc.metadata["source"]}]({doc.metadata["source"]}) \n\n{doc.page_content}\n\n"""
|
||||
for inum, doc in enumerate(docs)
|
||||
]
|
||||
|
||||
chat_prompt = ChatPromptTemplate.from_messages(
|
||||
[i.to_msg_tuple() for i in history] + [("human", ORIGIN_TEMPLATE_PROMPT)]
|
||||
)
|
||||
chain = LLMChain(prompt=chat_prompt, llm=model)
|
||||
result = {"answer": "", "docs": source_documents}
|
||||
return chain, context, result
|
||||
|
||||
def create_task(self, query: str, history: List[History], model, llm_config: LLMConfig, embed_config: EmbedConfig, ):
|
||||
'''构建 llm 生成任务'''
|
||||
chain, context, result = self._process(query, history, model)
|
||||
content = chain({"context": context, "question": query})
|
||||
return result, content
|
||||
|
||||
def create_atask(self, query, history, model, llm_config: LLMConfig, embed_config: EmbedConfig, callback: AsyncIteratorCallbackHandler):
|
||||
chain, context, result = self._process(query, history, model)
|
||||
task = asyncio.create_task(wrap_done(
|
||||
chain.acall({"context": context, "question": query}), callback.done
|
||||
))
|
||||
return task, result
|
|
@@ -1,30 +0,0 @@
import asyncio
from typing import Awaitable
from pydantic import BaseModel, Field


async def wrap_done(fn: Awaitable, event: asyncio.Event):
    """Wrap an awaitable with a event to signal when it's done or an exception is raised."""
    try:
        await fn
    except Exception as e:
        # TODO: handle exception
        print(f"Caught exception: {e}")
    finally:
        # Signal the aiter to stop.
        event.set()


class History(BaseModel):
    """
    对话历史
    可从dict生成,如
    h = History(**{"role":"user","content":"你好"})
    也可转换为tuple,如
    h.to_msy_tuple = ("human", "你好")
    """
    role: str = Field(...)
    content: str = Field(...)

    def to_msg_tuple(self):
        return "ai" if self.role=="assistant" else "human", self.content
@@ -1,7 +0,0 @@
# encoding: utf-8
'''
@author: 温进
@file: __init__.py.py
@time: 2023/11/21 下午2:01
@desc:
'''
@@ -1,7 +0,0 @@
# encoding: utf-8
'''
@author: 温进
@file: __init__.py.py
@time: 2023/11/21 下午2:27
@desc:
'''
|
@ -1,222 +0,0 @@
|
|||
# encoding: utf-8
|
||||
'''
|
||||
@author: 温进
|
||||
@file: code_analyzer.py
|
||||
@time: 2023/11/21 下午2:27
|
||||
@desc:
|
||||
'''
|
||||
import time
|
||||
from loguru import logger
|
||||
|
||||
from coagent.codechat.code_analyzer.code_static_analysis import CodeStaticAnalysis
|
||||
from coagent.codechat.code_analyzer.code_intepreter import CodeIntepreter
|
||||
from coagent.codechat.code_analyzer.code_preprocess import CodePreprocessor
|
||||
from coagent.codechat.code_analyzer.code_dedup import CodeDedup
|
||||
from coagent.llm_models.llm_config import LLMConfig
|
||||
|
||||
|
||||
|
||||
class CodeAnalyzer:
|
||||
def __init__(self, language: str, llm_config: LLMConfig):
|
||||
self.llm_config = llm_config
|
||||
self.code_preprocessor = CodePreprocessor()
|
||||
self.code_debup = CodeDedup()
|
||||
self.code_interperter = CodeIntepreter(self.llm_config)
|
||||
self.code_static_analyzer = CodeStaticAnalysis(language=language)
|
||||
|
||||
def analyze(self, code_dict: dict, do_interpret: bool = True):
|
||||
'''
|
||||
analyze code
|
||||
@param code_dict: {fp: code_text}
|
||||
@param do_interpret: Whether to get analysis result
|
||||
@return:
|
||||
'''
|
||||
# preprocess and dedup
|
||||
st = time.time()
|
||||
code_dict = self.code_preprocessor.preprocess(code_dict)
|
||||
code_dict = self.code_debup.dedup(code_dict)
|
||||
logger.debug('preprocess and dedup rt={}'.format(time.time() - st))
|
||||
|
||||
# static analysis
|
||||
st = time.time()
|
||||
static_analysis_res = self.code_static_analyzer.analyze(code_dict)
|
||||
logger.debug('static analysis rt={}'.format(time.time() - st))
|
||||
|
||||
# interpretation
|
||||
if do_interpret:
|
||||
logger.info('start interpret code')
|
||||
st = time.time()
|
||||
code_list = list(code_dict.values())
|
||||
interpretation = self.code_interperter.get_intepretation_batch(code_list)
|
||||
logger.debug('interpret rt={}'.format(time.time() - st))
|
||||
else:
|
||||
interpretation = {i: '' for i in code_dict.values()}
|
||||
|
||||
return static_analysis_res, interpretation
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
engine = 'openai'
|
||||
language = 'java'
|
||||
code_dict = {'1': '''package com.theokanning.openai.client;
|
||||
import com.theokanning.openai.DeleteResult;
|
||||
import com.theokanning.openai.OpenAiResponse;
|
||||
import com.theokanning.openai.audio.TranscriptionResult;
|
||||
import com.theokanning.openai.audio.TranslationResult;
|
||||
import com.theokanning.openai.billing.BillingUsage;
|
||||
import com.theokanning.openai.billing.Subscription;
|
||||
import com.theokanning.openai.completion.CompletionRequest;
|
||||
import com.theokanning.openai.completion.CompletionResult;
|
||||
import com.theokanning.openai.completion.chat.ChatCompletionRequest;
|
||||
import com.theokanning.openai.completion.chat.ChatCompletionResult;
|
||||
import com.theokanning.openai.edit.EditRequest;
|
||||
import com.theokanning.openai.edit.EditResult;
|
||||
import com.theokanning.openai.embedding.EmbeddingRequest;
|
||||
import com.theokanning.openai.embedding.EmbeddingResult;
|
||||
import com.theokanning.openai.engine.Engine;
|
||||
import com.theokanning.openai.file.File;
|
||||
import com.theokanning.openai.fine_tuning.FineTuningEvent;
|
||||
import com.theokanning.openai.fine_tuning.FineTuningJob;
|
||||
import com.theokanning.openai.fine_tuning.FineTuningJobRequest;
|
||||
import com.theokanning.openai.finetune.FineTuneEvent;
|
||||
import com.theokanning.openai.finetune.FineTuneRequest;
|
||||
import com.theokanning.openai.finetune.FineTuneResult;
|
||||
import com.theokanning.openai.image.CreateImageRequest;
|
||||
import com.theokanning.openai.image.ImageResult;
|
||||
import com.theokanning.openai.model.Model;
|
||||
import com.theokanning.openai.moderation.ModerationRequest;
|
||||
import com.theokanning.openai.moderation.ModerationResult;
|
||||
import io.reactivex.Single;
|
||||
import okhttp3.MultipartBody;
|
||||
import okhttp3.RequestBody;
|
||||
import okhttp3.ResponseBody;
|
||||
import retrofit2.Call;
|
||||
import retrofit2.http.*;
|
||||
import java.time.LocalDate;
|
||||
public interface OpenAiApi {
|
||||
@GET("v1/models")
|
||||
Single<OpenAiResponse<Model>> listModels();
|
||||
@GET("/v1/models/{model_id}")
|
||||
Single<Model> getModel(@Path("model_id") String modelId);
|
||||
@POST("/v1/completions")
|
||||
Single<CompletionResult> createCompletion(@Body CompletionRequest request);
|
||||
@Streaming
|
||||
@POST("/v1/completions")
|
||||
Call<ResponseBody> createCompletionStream(@Body CompletionRequest request);
|
||||
@POST("/v1/chat/completions")
|
||||
Single<ChatCompletionResult> createChatCompletion(@Body ChatCompletionRequest request);
|
||||
@Streaming
|
||||
@POST("/v1/chat/completions")
|
||||
Call<ResponseBody> createChatCompletionStream(@Body ChatCompletionRequest request);
|
||||
@Deprecated
|
||||
@POST("/v1/engines/{engine_id}/completions")
|
||||
Single<CompletionResult> createCompletion(@Path("engine_id") String engineId, @Body CompletionRequest request);
|
||||
@POST("/v1/edits")
|
||||
Single<EditResult> createEdit(@Body EditRequest request);
|
||||
@Deprecated
|
||||
@POST("/v1/engines/{engine_id}/edits")
|
||||
Single<EditResult> createEdit(@Path("engine_id") String engineId, @Body EditRequest request);
|
||||
@POST("/v1/embeddings")
|
||||
Single<EmbeddingResult> createEmbeddings(@Body EmbeddingRequest request);
|
||||
@Deprecated
|
||||
@POST("/v1/engines/{engine_id}/embeddings")
|
||||
Single<EmbeddingResult> createEmbeddings(@Path("engine_id") String engineId, @Body EmbeddingRequest request);
|
||||
@GET("/v1/files")
|
||||
Single<OpenAiResponse<File>> listFiles();
|
||||
@Multipart
|
||||
@POST("/v1/files")
|
||||
Single<File> uploadFile(@Part("purpose") RequestBody purpose, @Part MultipartBody.Part file);
|
||||
@DELETE("/v1/files/{file_id}")
|
||||
Single<DeleteResult> deleteFile(@Path("file_id") String fileId);
|
||||
@GET("/v1/files/{file_id}")
|
||||
Single<File> retrieveFile(@Path("file_id") String fileId);
|
||||
@Streaming
|
||||
@GET("/v1/files/{file_id}/content")
|
||||
Single<ResponseBody> retrieveFileContent(@Path("file_id") String fileId);
|
||||
@POST("/v1/fine_tuning/jobs")
|
||||
Single<FineTuningJob> createFineTuningJob(@Body FineTuningJobRequest request);
|
||||
@GET("/v1/fine_tuning/jobs")
|
||||
Single<OpenAiResponse<FineTuningJob>> listFineTuningJobs();
|
||||
@GET("/v1/fine_tuning/jobs/{fine_tuning_job_id}")
|
||||
Single<FineTuningJob> retrieveFineTuningJob(@Path("fine_tuning_job_id") String fineTuningJobId);
|
||||
@POST("/v1/fine_tuning/jobs/{fine_tuning_job_id}/cancel")
|
||||
Single<FineTuningJob> cancelFineTuningJob(@Path("fine_tuning_job_id") String fineTuningJobId);
|
||||
@GET("/v1/fine_tuning/jobs/{fine_tuning_job_id}/events")
|
||||
Single<OpenAiResponse<FineTuningEvent>> listFineTuningJobEvents(@Path("fine_tuning_job_id") String fineTuningJobId);
|
||||
@Deprecated
|
||||
@POST("/v1/fine-tunes")
|
||||
Single<FineTuneResult> createFineTune(@Body FineTuneRequest request);
|
||||
@POST("/v1/completions")
|
||||
Single<CompletionResult> createFineTuneCompletion(@Body CompletionRequest request);
|
||||
@Deprecated
|
||||
@GET("/v1/fine-tunes")
|
||||
Single<OpenAiResponse<FineTuneResult>> listFineTunes();
|
||||
@Deprecated
|
||||
@GET("/v1/fine-tunes/{fine_tune_id}")
|
||||
Single<FineTuneResult> retrieveFineTune(@Path("fine_tune_id") String fineTuneId);
|
||||
@Deprecated
|
||||
@POST("/v1/fine-tunes/{fine_tune_id}/cancel")
|
||||
Single<FineTuneResult> cancelFineTune(@Path("fine_tune_id") String fineTuneId);
|
||||
@Deprecated
|
||||
@GET("/v1/fine-tunes/{fine_tune_id}/events")
|
||||
Single<OpenAiResponse<FineTuneEvent>> listFineTuneEvents(@Path("fine_tune_id") String fineTuneId);
|
||||
@DELETE("/v1/models/{fine_tune_id}")
|
||||
Single<DeleteResult> deleteFineTune(@Path("fine_tune_id") String fineTuneId);
|
||||
@POST("/v1/images/generations")
|
||||
Single<ImageResult> createImage(@Body CreateImageRequest request);
|
||||
@POST("/v1/images/edits")
|
||||
Single<ImageResult> createImageEdit(@Body RequestBody requestBody);
|
||||
@POST("/v1/images/variations")
|
||||
Single<ImageResult> createImageVariation(@Body RequestBody requestBody);
|
||||
@POST("/v1/audio/transcriptions")
|
||||
Single<TranscriptionResult> createTranscription(@Body RequestBody requestBody);
|
||||
@POST("/v1/audio/translations")
|
||||
Single<TranslationResult> createTranslation(@Body RequestBody requestBody);
|
||||
@POST("/v1/moderations")
|
||||
Single<ModerationResult> createModeration(@Body ModerationRequest request);
|
||||
@Deprecated
|
||||
@GET("v1/engines")
|
||||
Single<OpenAiResponse<Engine>> getEngines();
|
||||
@Deprecated
|
||||
@GET("/v1/engines/{engine_id}")
|
||||
Single<Engine> getEngine(@Path("engine_id") String engineId);
|
||||
/**
|
||||
* Account information inquiry: It contains total amount (in US dollars) and other information.
|
||||
*
|
||||
* @return
|
||||
*/
|
||||
@Deprecated
|
||||
@GET("v1/dashboard/billing/subscription")
|
||||
Single<Subscription> subscription();
|
||||
/**
|
||||
* Account call interface consumption amount inquiry.
|
||||
* totalUsage = Total amount used by the account (in US cents).
|
||||
*
|
||||
* @param starDate
|
||||
* @param endDate
|
||||
* @return Consumption amount information.
|
||||
*/
|
||||
@Deprecated
|
||||
@GET("v1/dashboard/billing/usage")
|
||||
Single<BillingUsage> billingUsage(@Query("start_date") LocalDate starDate, @Query("end_date") LocalDate endDate);
|
||||
}''', '2': '''
|
||||
package com.theokanning.openai;
|
||||
|
||||
/**
|
||||
* OkHttp Interceptor that adds an authorization token header
|
||||
*
|
||||
* @deprecated Use {@link com.theokanning.openai.client.AuthenticationInterceptor}
|
||||
*/
|
||||
@Deprecated
|
||||
public class AuthenticationInterceptor extends com.theokanning.openai.client.AuthenticationInterceptor {
|
||||
|
||||
AuthenticationInterceptor(String token) {
|
||||
super(token);
|
||||
}
|
||||
|
||||
}
|
||||
'''}
|
||||
|
||||
ca = CodeAnalyzer(engine, language)
|
||||
res = ca.analyze(code_dict)
|
||||
logger.debug(res)
|
|
@@ -1,31 +0,0 @@
# encoding: utf-8
'''
@author: 温进
@file: code_dedup.py
@time: 2023/11/21 下午2:27
@desc:
'''
# encoding: utf-8
'''
@author: 温进
@file: java_dedup.py
@time: 2023/10/23 下午5:02
@desc:
'''


class CodeDedup:
    def __init__(self):
        pass

    def dedup(self, code_dict):
        code_dict = self.exact_dedup(code_dict)
        return code_dict

    def exact_dedup(self, code_dict):
        res = {}
        for fp, code_text in code_dict.items():
            if code_text not in res.values():
                res[fp] = code_text

        return res
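A small illustrative check of the exact-dedup behaviour above; the file names and contents are made up:
```
dedup = CodeDedup()
code_dict = {"A.java": "class A {}", "B.java": "class A {}", "C.java": "class C {}"}
# identical file contents collapse onto the first path seen
assert dedup.dedup(code_dict) == {"A.java": "class A {}", "C.java": "class C {}"}
```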
@@ -1,238 +0,0 @@
# encoding: utf-8
'''
@author: 温进
@file: code_intepreter.py
@time: 2023/11/22 上午11:57
@desc:
'''
from loguru import logger
from langchain.schema import (
    HumanMessage,
)

# from configs.model_config import CODE_INTERPERT_TEMPLATE
from coagent.connector.configs.prompts import CODE_INTERPERT_TEMPLATE
from coagent.llm_models.openai_model import getChatModelFromConfig
from coagent.llm_models.llm_config import LLMConfig


class CodeIntepreter:
    def __init__(self, llm_config: LLMConfig):
        self.llm_config = llm_config

    def get_intepretation(self, code_list):
        '''
        get interpretation of code
        @param code_list:
        @return:
        '''
        # chat_model = getChatModel()
        chat_model = getChatModelFromConfig(self.llm_config)

        res = {}
        for code in code_list:
            message = CODE_INTERPERT_TEMPLATE.format(code=code)
            message = [HumanMessage(content=message)]
            chat_res = chat_model.predict_messages(message)
            content = chat_res.content
            res[code] = content
        return res

    def get_intepretation_batch(self, code_list):
        '''
        get interpretation of code
        @param code_list:
        @return:
        '''
        # chat_model = getChatModel()
        chat_model = getChatModelFromConfig(self.llm_config)

        res = {}
        messages = []
        for code in code_list:
            message = CODE_INTERPERT_TEMPLATE.format(code=code)
            messages.append(message)

        try:
            # call the model once per prompt; fall back to the batch interface below
            chat_ress = [chat_model(message) for message in messages]
        except Exception:
            chat_ress = chat_model.batch(messages)
        for chat_res, code in zip(chat_ress, code_list):
            try:
                res[code] = chat_res.content
            except Exception:
                res[code] = chat_res
        return res
|
||||
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
engine = 'openai'
|
||||
code_list = ['''package com.theokanning.openai.client;
|
||||
import com.theokanning.openai.DeleteResult;
|
||||
import com.theokanning.openai.OpenAiResponse;
|
||||
import com.theokanning.openai.audio.TranscriptionResult;
|
||||
import com.theokanning.openai.audio.TranslationResult;
|
||||
import com.theokanning.openai.billing.BillingUsage;
|
||||
import com.theokanning.openai.billing.Subscription;
|
||||
import com.theokanning.openai.completion.CompletionRequest;
|
||||
import com.theokanning.openai.completion.CompletionResult;
|
||||
import com.theokanning.openai.completion.chat.ChatCompletionRequest;
|
||||
import com.theokanning.openai.completion.chat.ChatCompletionResult;
|
||||
import com.theokanning.openai.edit.EditRequest;
|
||||
import com.theokanning.openai.edit.EditResult;
|
||||
import com.theokanning.openai.embedding.EmbeddingRequest;
|
||||
import com.theokanning.openai.embedding.EmbeddingResult;
|
||||
import com.theokanning.openai.engine.Engine;
|
||||
import com.theokanning.openai.file.File;
|
||||
import com.theokanning.openai.fine_tuning.FineTuningEvent;
|
||||
import com.theokanning.openai.fine_tuning.FineTuningJob;
|
||||
import com.theokanning.openai.fine_tuning.FineTuningJobRequest;
|
||||
import com.theokanning.openai.finetune.FineTuneEvent;
|
||||
import com.theokanning.openai.finetune.FineTuneRequest;
|
||||
import com.theokanning.openai.finetune.FineTuneResult;
|
||||
import com.theokanning.openai.image.CreateImageRequest;
|
||||
import com.theokanning.openai.image.ImageResult;
|
||||
import com.theokanning.openai.model.Model;
|
||||
import com.theokanning.openai.moderation.ModerationRequest;
|
||||
import com.theokanning.openai.moderation.ModerationResult;
|
||||
import io.reactivex.Single;
|
||||
import okhttp3.MultipartBody;
|
||||
import okhttp3.RequestBody;
|
||||
import okhttp3.ResponseBody;
|
||||
import retrofit2.Call;
|
||||
import retrofit2.http.*;
|
||||
import java.time.LocalDate;
|
||||
public interface OpenAiApi {
|
||||
@GET("v1/models")
|
||||
Single<OpenAiResponse<Model>> listModels();
|
||||
@GET("/v1/models/{model_id}")
|
||||
Single<Model> getModel(@Path("model_id") String modelId);
|
||||
@POST("/v1/completions")
|
||||
Single<CompletionResult> createCompletion(@Body CompletionRequest request);
|
||||
@Streaming
|
||||
@POST("/v1/completions")
|
||||
Call<ResponseBody> createCompletionStream(@Body CompletionRequest request);
|
||||
@POST("/v1/chat/completions")
|
||||
Single<ChatCompletionResult> createChatCompletion(@Body ChatCompletionRequest request);
|
||||
@Streaming
|
||||
@POST("/v1/chat/completions")
|
||||
Call<ResponseBody> createChatCompletionStream(@Body ChatCompletionRequest request);
|
||||
@Deprecated
|
||||
@POST("/v1/engines/{engine_id}/completions")
|
||||
Single<CompletionResult> createCompletion(@Path("engine_id") String engineId, @Body CompletionRequest request);
|
||||
@POST("/v1/edits")
|
||||
Single<EditResult> createEdit(@Body EditRequest request);
|
||||
@Deprecated
|
||||
@POST("/v1/engines/{engine_id}/edits")
|
||||
Single<EditResult> createEdit(@Path("engine_id") String engineId, @Body EditRequest request);
|
||||
@POST("/v1/embeddings")
|
||||
Single<EmbeddingResult> createEmbeddings(@Body EmbeddingRequest request);
|
||||
@Deprecated
|
||||
@POST("/v1/engines/{engine_id}/embeddings")
|
||||
Single<EmbeddingResult> createEmbeddings(@Path("engine_id") String engineId, @Body EmbeddingRequest request);
|
||||
@GET("/v1/files")
|
||||
Single<OpenAiResponse<File>> listFiles();
|
||||
@Multipart
|
||||
@POST("/v1/files")
|
||||
Single<File> uploadFile(@Part("purpose") RequestBody purpose, @Part MultipartBody.Part file);
|
||||
@DELETE("/v1/files/{file_id}")
|
||||
Single<DeleteResult> deleteFile(@Path("file_id") String fileId);
|
||||
@GET("/v1/files/{file_id}")
|
||||
Single<File> retrieveFile(@Path("file_id") String fileId);
|
||||
@Streaming
|
||||
@GET("/v1/files/{file_id}/content")
|
||||
Single<ResponseBody> retrieveFileContent(@Path("file_id") String fileId);
|
||||
@POST("/v1/fine_tuning/jobs")
|
||||
Single<FineTuningJob> createFineTuningJob(@Body FineTuningJobRequest request);
|
||||
@GET("/v1/fine_tuning/jobs")
|
||||
Single<OpenAiResponse<FineTuningJob>> listFineTuningJobs();
|
||||
@GET("/v1/fine_tuning/jobs/{fine_tuning_job_id}")
|
||||
Single<FineTuningJob> retrieveFineTuningJob(@Path("fine_tuning_job_id") String fineTuningJobId);
|
||||
@POST("/v1/fine_tuning/jobs/{fine_tuning_job_id}/cancel")
|
||||
Single<FineTuningJob> cancelFineTuningJob(@Path("fine_tuning_job_id") String fineTuningJobId);
|
||||
@GET("/v1/fine_tuning/jobs/{fine_tuning_job_id}/events")
|
||||
Single<OpenAiResponse<FineTuningEvent>> listFineTuningJobEvents(@Path("fine_tuning_job_id") String fineTuningJobId);
|
||||
@Deprecated
|
||||
@POST("/v1/fine-tunes")
|
||||
Single<FineTuneResult> createFineTune(@Body FineTuneRequest request);
|
||||
@POST("/v1/completions")
|
||||
Single<CompletionResult> createFineTuneCompletion(@Body CompletionRequest request);
|
||||
@Deprecated
|
||||
@GET("/v1/fine-tunes")
|
||||
Single<OpenAiResponse<FineTuneResult>> listFineTunes();
|
||||
@Deprecated
|
||||
@GET("/v1/fine-tunes/{fine_tune_id}")
|
||||
Single<FineTuneResult> retrieveFineTune(@Path("fine_tune_id") String fineTuneId);
|
||||
@Deprecated
|
||||
@POST("/v1/fine-tunes/{fine_tune_id}/cancel")
|
||||
Single<FineTuneResult> cancelFineTune(@Path("fine_tune_id") String fineTuneId);
|
||||
@Deprecated
|
||||
@GET("/v1/fine-tunes/{fine_tune_id}/events")
|
||||
Single<OpenAiResponse<FineTuneEvent>> listFineTuneEvents(@Path("fine_tune_id") String fineTuneId);
|
||||
@DELETE("/v1/models/{fine_tune_id}")
|
||||
Single<DeleteResult> deleteFineTune(@Path("fine_tune_id") String fineTuneId);
|
||||
@POST("/v1/images/generations")
|
||||
Single<ImageResult> createImage(@Body CreateImageRequest request);
|
||||
@POST("/v1/images/edits")
|
||||
Single<ImageResult> createImageEdit(@Body RequestBody requestBody);
|
||||
@POST("/v1/images/variations")
|
||||
Single<ImageResult> createImageVariation(@Body RequestBody requestBody);
|
||||
@POST("/v1/audio/transcriptions")
|
||||
Single<TranscriptionResult> createTranscription(@Body RequestBody requestBody);
|
||||
@POST("/v1/audio/translations")
|
||||
Single<TranslationResult> createTranslation(@Body RequestBody requestBody);
|
||||
@POST("/v1/moderations")
|
||||
Single<ModerationResult> createModeration(@Body ModerationRequest request);
|
||||
@Deprecated
|
||||
@GET("v1/engines")
|
||||
Single<OpenAiResponse<Engine>> getEngines();
|
||||
@Deprecated
|
||||
@GET("/v1/engines/{engine_id}")
|
||||
Single<Engine> getEngine(@Path("engine_id") String engineId);
|
||||
/**
|
||||
* Account information inquiry: It contains total amount (in US dollars) and other information.
|
||||
*
|
||||
* @return
|
||||
*/
|
||||
@Deprecated
|
||||
@GET("v1/dashboard/billing/subscription")
|
||||
Single<Subscription> subscription();
|
||||
/**
|
||||
* Account call interface consumption amount inquiry.
|
||||
* totalUsage = Total amount used by the account (in US cents).
|
||||
*
|
||||
* @param starDate
|
||||
* @param endDate
|
||||
* @return Consumption amount information.
|
||||
*/
|
||||
@Deprecated
|
||||
@GET("v1/dashboard/billing/usage")
|
||||
Single<BillingUsage> billingUsage(@Query("start_date") LocalDate starDate, @Query("end_date") LocalDate endDate);
|
||||
}''', '''
|
||||
package com.theokanning.openai;
|
||||
|
||||
/**
|
||||
* OkHttp Interceptor that adds an authorization token header
|
||||
*
|
||||
* @deprecated Use {@link com.theokanning.openai.client.AuthenticationInterceptor}
|
||||
*/
|
||||
@Deprecated
|
||||
public class AuthenticationInterceptor extends com.theokanning.openai.client.AuthenticationInterceptor {
|
||||
|
||||
AuthenticationInterceptor(String token) {
|
||||
super(token);
|
||||
}
|
||||
|
||||
}
|
||||
''']
|
||||
|
||||
ci = CodeIntepreter(engine)
|
||||
res = ci.get_intepretation_batch(code_list)
|
||||
logger.debug(res)
|
||||
|
||||
|
||||
|
||||
|
||||
|
|
@@ -1,14 +0,0 @@
# encoding: utf-8
'''
@author: 温进
@file: code_preprocess.py
@time: 2023/11/21 下午2:28
@desc:
'''

class CodePreprocessor:
    def __init__(self):
        pass

    def preprocess(self, code_dict):
        return code_dict
@@ -1,26 +0,0 @@
# encoding: utf-8
'''
@author: 温进
@file: code_static_analysis.py
@time: 2023/11/21 下午2:28
@desc:
'''
from coagent.codechat.code_analyzer.language_static_analysis import *

class CodeStaticAnalysis:
    def __init__(self, language):
        self.language = language

    def analyze(self, code_dict):
        '''
        analyze code
        @param code_dict:
        @return:
        '''
        if self.language == 'java':
            analyzer = JavaStaticAnalysis()
        else:
            raise ValueError('language should be one of [java]')

        analyze_res = analyzer.analyze(code_dict)
        return analyze_res
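A brief sketch of how this dispatcher is meant to be called; the Java snippet is illustrative:
```
csa = CodeStaticAnalysis(language='java')
code_dict = {'Hello.java': 'package demo; public class Hello { public void hi() {} }'}
# returns a dict keyed by code text, with pac_name / class_name_list / func_name_dict / import_pac_name_list
static_res = csa.analyze(code_dict)
```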
@@ -1,14 +0,0 @@
# encoding: utf-8
'''
@author: 温进
@file: __init__.py.py
@time: 2023/11/21 下午4:24
@desc:
'''

from .java_static_analysis import JavaStaticAnalysis


__all__ = [
    'JavaStaticAnalysis'
]
@@ -1,138 +0,0 @@
# encoding: utf-8
'''
@author: 温进
@file: java_static_analysis.py
@time: 2023/11/21 下午4:25
@desc:
'''
import os
from loguru import logger
import javalang


class JavaStaticAnalysis:
    def __init__(self):
        pass

    def analyze(self, java_code_dict):
        '''
        parse java code and extract entity
        '''
        tree_dict = self.preparse(java_code_dict)
        res = self.multi_java_code_parse(tree_dict)

        return res

    def preparse(self, java_code_dict):
        '''
        preparse by javalang
        < dict of java_code and tree
        '''
        tree_dict = {}
        for fp, java_code in java_code_dict.items():
            try:
                tree = javalang.parse.parse(java_code)
            except Exception as e:
                continue

            if tree.package is not None:
                tree_dict[fp] = {'code': java_code, 'tree': tree}
        logger.info('success parse {} files'.format(len(tree_dict)))
        return tree_dict

    def single_java_code_parse(self, tree, fp):
        '''
        parse single code file
        > tree: javalang parse result
        < {pac_name: '', class_name_list: [], func_name_dict: {}, import_pac_name_list: []}
        '''
        import_pac_name_list = []

        # get imports
        import_list = tree.imports

        for import_pac in import_list:
            import_pac_name = import_pac.path
            import_pac_name_list.append(import_pac_name)

        fp_last = fp.split(os.path.sep)[-1]
        pac_name = tree.package.name + '#' + fp_last
        class_name_list = []
        func_name_dict = {}

        for node in tree.types:
            if type(node) in (javalang.tree.ClassDeclaration, javalang.tree.InterfaceDeclaration):
                class_name = tree.package.name + '.' + node.name
                class_name_list.append(class_name)

                for node_inner in node.body:
                    if type(node_inner) is javalang.tree.MethodDeclaration:
                        func_name = class_name + '#' + node_inner.name

                        # add params name to func_name
                        params_list = node_inner.parameters

                        for params in params_list:
                            params_name = params.type.name
                            func_name = func_name + '-' + params_name

                        if class_name not in func_name_dict:
                            func_name_dict[class_name] = []

                        func_name_dict[class_name].append(func_name)

        res = {
            'pac_name': pac_name,
            'class_name_list': class_name_list,
            'func_name_dict': func_name_dict,
            'import_pac_name_list': import_pac_name_list
        }
        return res

    def multi_java_code_parse(self, tree_dict):
        '''
        parse multiple java code
        > tree_list
        < parse_result_dict
        '''
        res_dict = {}
        for fp, value in tree_dict.items():
            java_code = value['code']
            tree = value['tree']
            try:
                res_dict[java_code] = self.single_java_code_parse(tree, fp)
            except Exception as e:
                logger.debug(java_code)
                raise ImportError

        return res_dict


if __name__ == '__main__':
    java_code_dict = {
        'test': '''package com.theokanning.openai;

import com.theokanning.openai.client.Utils;


public class UtilsTest {
    public void testRemoveChar() {
        String input = "hello";
        char ch = 'l';
        String expected = "heo";
        String res = Utils.remove(input, ch);
        System.out.println(res.equals(expected));
    }
}
'''
    }

    jsa = JavaStaticAnalysis()
    res = jsa.analyze(java_code_dict)
    logger.info(res)
@@ -1,15 +0,0 @@
# encoding: utf-8
'''
@author: 温进
@file: __init__.py.py
@time: 2023/11/21 下午2:02
@desc:
'''
from .zip_crawler import ZipCrawler
from .dir_crawler import DirCrawler


__all__ = [
    'ZipCrawler',
    'DirCrawler'
]
@@ -1,39 +0,0 @@
# encoding: utf-8
'''
@author: 温进
@file: dir_crawler.py
@time: 2023/11/22 下午2:54
@desc:
'''
from loguru import logger
import os
import glob


class DirCrawler:
    @staticmethod
    def crawl(path: str, suffix: str):
        '''
        read local java file in path
        > path: path to crawl, must be absolute path like A/B/C
        < dict of java code string
        '''
        java_file_list = glob.glob('{path}{sep}**{sep}*.{suffix}'.format(path=path, sep=os.path.sep, suffix=suffix),
                                   recursive=True)
        java_code_dict = {}

        logger.info(path)
        logger.info('number of file={}'.format(len(java_file_list)))
        logger.info(java_file_list)

        for java_file in java_file_list:
            with open(java_file, encoding="utf-8") as f:
                java_code = ''.join(f.readlines())
            java_code_dict[java_file] = java_code
        return java_code_dict


if __name__ == '__main__':
    path = '/Users/bingxu/Desktop/工作/大模型/chatbot/test_code_repo/middleware-alipay-starters-parent'
    suffix = 'java'
    DirCrawler.crawl(path, suffix)
@@ -1,31 +0,0 @@
# encoding: utf-8
'''
@author: 温进
@file: zip_crawler.py
@time: 2023/11/21 下午2:02
@desc:
'''
from loguru import logger

import zipfile
from coagent.codechat.code_crawler.dir_crawler import DirCrawler


class ZipCrawler:
    @staticmethod
    def crawl(zip_file, output_path, suffix):
        '''
        unzip to output_path
        @param zip_file:
        @param output_path:
        @param suffix:
        @return:
        '''
        logger.info(f'output_path={output_path}')
        print(f'output_path={output_path}')
        with zipfile.ZipFile(zip_file, 'r') as z:
            z.extractall(output_path)

        code_dict = DirCrawler.crawl(output_path, suffix)
        return code_dict
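For reference, a minimal sketch of driving the crawler end to end; the archive and output paths are placeholders:
```
# unzip an uploaded archive, then collect every *.java file it contains
code_dict = ZipCrawler.crawl('repo.zip', output_path='/tmp/repo', suffix='java')
for path, code in code_dict.items():
    print(path, len(code))
```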
@@ -1,7 +0,0 @@
# encoding: utf-8
'''
@author: 温进
@file: __init__.py.py
@time: 2023/11/21 下午2:35
@desc:
'''
@ -1,261 +0,0 @@
|
|||
# encoding: utf-8
|
||||
'''
|
||||
@author: 温进
|
||||
@file: code_search.py
|
||||
@time: 2023/11/21 下午2:35
|
||||
@desc:
|
||||
'''
|
||||
import json
|
||||
import time
|
||||
from loguru import logger
|
||||
from collections import defaultdict
|
||||
|
||||
from coagent.db_handler.graph_db_handler.nebula_handler import NebulaHandler
|
||||
from coagent.db_handler.vector_db_handler.chroma_handler import ChromaHandler
|
||||
|
||||
from coagent.codechat.code_search.cypher_generator import CypherGenerator
|
||||
from coagent.codechat.code_search.tagger import Tagger
|
||||
from coagent.embeddings.get_embedding import get_embedding
|
||||
from coagent.llm_models.llm_config import LLMConfig, EmbedConfig
|
||||
|
||||
|
||||
# from configs.model_config import EMBEDDING_DEVICE, EMBEDDING_MODEL
|
||||
# search_by_tag
|
||||
VERTEX_SCORE = 10
|
||||
HISTORY_VERTEX_SCORE = 5
|
||||
VERTEX_MERGE_RATIO = 0.5
|
||||
|
||||
# search_by_description
|
||||
MAX_DISTANCE = 1000
|
||||
|
||||
|
||||
class CodeSearch:
|
||||
def __init__(self, llm_config: LLMConfig, nh: NebulaHandler, ch: ChromaHandler, limit: int = 3,
|
||||
local_graph_file_path: str = ''):
|
||||
'''
|
||||
init
|
||||
@param nh: NebulaHandler
|
||||
@param ch: ChromaHandler
|
||||
@param limit: limit of result
|
||||
'''
|
||||
self.llm_config = llm_config
|
||||
|
||||
self.nh = nh
|
||||
|
||||
if not self.nh:
|
||||
with open(local_graph_file_path, 'r') as f:
|
||||
self.graph = json.load(f)
|
||||
|
||||
self.ch = ch
|
||||
self.limit = limit
|
||||
|
||||
def search_by_tag(self, query: str):
|
||||
'''
|
||||
search_code_res by tag
|
||||
@param query: str
|
||||
@return:
|
||||
'''
|
||||
tagger = Tagger()
|
||||
tag_list = tagger.generate_tag_query(query)
|
||||
logger.info(f'query tag={tag_list}')
|
||||
|
||||
# get all vertices
|
||||
vertex_list = self.nh.get_vertices().get('v', [])
|
||||
vertex_vid_list = [i.as_node().get_id().as_string() for i in vertex_list]
|
||||
|
||||
# update score
|
||||
vertex_score_dict = defaultdict(lambda: 0)
|
||||
for vid in vertex_vid_list:
|
||||
for tag in tag_list:
|
||||
if tag in vid:
|
||||
vertex_score_dict[vid] += VERTEX_SCORE
|
||||
|
||||
# merge depend adj score
|
||||
vertex_score_dict_final = {}
|
||||
for vertex in vertex_score_dict:
|
||||
cypher = f'''MATCH (v1)-[e]-(v2) where id(v1) == "{vertex}" RETURN v2'''
|
||||
cypher_res = self.nh.execute_cypher(cypher, self.nh.space_name)
|
||||
cypher_res_dict = self.nh.result_to_dict(cypher_res)
|
||||
|
||||
adj_vertex_list = [i.as_node().get_id().as_string() for i in cypher_res_dict.get('v2', [])]
|
||||
|
||||
score = vertex_score_dict.get(vertex, 0)
|
||||
for adj_vertex in adj_vertex_list:
|
||||
score += vertex_score_dict.get(adj_vertex, 0) * VERTEX_MERGE_RATIO
|
||||
|
||||
if score > 0:
|
||||
vertex_score_dict_final[vertex] = score
|
||||
|
||||
# get most prominent package tag
|
||||
package_score_dict = defaultdict(lambda: 0)
|
||||
|
||||
for vertex, score in vertex_score_dict_final.items():
|
||||
if '#' in vertex:
|
||||
# get class name first
|
||||
cypher = f'''MATCH (v1:class)-[e:contain]->(v2) WHERE id(v2) == '{vertex}' RETURN id(v1) as id;'''
|
||||
cypher_res = self.nh.execute_cypher(cypher=cypher, format_res=True)
|
||||
class_vertices = cypher_res.get('id', [])
|
||||
if not class_vertices:
|
||||
continue
|
||||
|
||||
vertex = class_vertices[0].as_string()
|
||||
|
||||
# get package name
|
||||
cypher = f'''MATCH (v1:package)-[e:contain]->(v2) WHERE id(v2) == '{vertex}' RETURN id(v1) as id;'''
|
||||
cypher_res = self.nh.execute_cypher(cypher=cypher, format_res=True)
|
||||
pac_vertices = cypher_res.get('id', [])
|
||||
if not pac_vertices:
|
||||
continue
|
||||
|
||||
package = pac_vertices[0].as_string()
|
||||
package_score_dict[package] += score
|
||||
|
||||
# get respective code
|
||||
res = []
|
||||
package_score_tuple = list(package_score_dict.items())
|
||||
package_score_tuple.sort(key=lambda x: x[1], reverse=True)
|
||||
|
||||
ids = [i[0] for i in package_score_tuple]
|
||||
logger.info(f'ids={ids}')
|
||||
chroma_res = self.ch.get(ids=ids, include=['metadatas'])
|
||||
|
||||
for vertex, score in package_score_tuple:
|
||||
index = chroma_res['result']['ids'].index(vertex)
|
||||
code_text = chroma_res['result']['metadatas'][index]['code_text']
|
||||
res.append({
|
||||
"vertex": vertex,
|
||||
"code_text": code_text}
|
||||
)
|
||||
if len(res) >= self.limit:
|
||||
break
|
||||
# logger.info(f'retrival code={res}')
|
||||
return res
|
||||
|
||||
def search_by_tag_by_graph(self, query: str):
|
||||
'''
|
||||
search code by tag with graph
|
||||
@param query:
|
||||
@return:
|
||||
'''
|
||||
tagger = Tagger()
|
||||
tag_list = tagger.generate_tag_query(query)
|
||||
logger.info(f'query tag={tag_list}')
|
||||
|
||||
# loop to get package node
|
||||
package_score_dict = {}
|
||||
for code, structure in self.graph.items():
|
||||
score = 0
|
||||
for class_name in structure['class_name_list']:
|
||||
for tag in tag_list:
|
||||
if tag.lower() in class_name.lower():
|
||||
score += 1
|
||||
|
||||
for func_name_list in structure['func_name_dict'].values():
|
||||
for func_name in func_name_list:
|
||||
for tag in tag_list:
|
||||
if tag.lower() in func_name.lower():
|
||||
score += 1
|
||||
package_score_dict[structure['pac_name']] = score
|
||||
|
||||
# get respective code
|
||||
res = []
|
||||
package_score_tuple = list(package_score_dict.items())
|
||||
package_score_tuple.sort(key=lambda x: x[1], reverse=True)
|
||||
|
||||
ids = [i[0] for i in package_score_tuple]
|
||||
logger.info(f'ids={ids}')
|
||||
chroma_res = self.ch.get(ids=ids, include=['metadatas'])
|
||||
|
||||
# logger.info(chroma_res)
|
||||
for vertex, score in package_score_tuple:
|
||||
index = chroma_res['result']['ids'].index(vertex)
|
||||
code_text = chroma_res['result']['metadatas'][index]['code_text']
|
||||
res.append({
|
||||
"vertex": vertex,
|
||||
"code_text": code_text}
|
||||
)
|
||||
if len(res) >= self.limit:
|
||||
break
|
||||
# logger.info(f'retrival code={res}')
|
||||
return res
|
||||
|
||||
def search_by_desciption(self, query: str, engine: str, model_path: str = "text2vec-base-chinese", embedding_device: str = "cpu", embed_config: EmbedConfig=None):
|
||||
'''
|
||||
search by perform sim search
|
||||
@param query:
|
||||
@return:
|
||||
'''
|
||||
query = query.replace(',', ',')
|
||||
query_emb = get_embedding(engine=engine, text_list=[query], model_path=model_path, embedding_device= embedding_device, embed_config=embed_config)
|
||||
query_emb = query_emb[query]
|
||||
|
||||
query_embeddings = [query_emb]
|
||||
query_result = self.ch.query(query_embeddings=query_embeddings, n_results=self.limit,
|
||||
include=['metadatas', 'distances'])
|
||||
|
||||
res = []
|
||||
for idx, distance in enumerate(query_result['result']['distances'][0]):
|
||||
if distance < MAX_DISTANCE:
|
||||
vertex = query_result['result']['ids'][0][idx]
|
||||
code_text = query_result['result']['metadatas'][0][idx]['code_text']
|
||||
res.append({
|
||||
"vertex": vertex,
|
||||
"code_text": code_text
|
||||
})
|
||||
|
||||
return res
|
||||
|
||||
def search_by_cypher(self, query: str):
|
||||
'''
|
||||
search by generating cypher
|
||||
@param query:
|
||||
@param engine:
|
||||
@return:
|
||||
'''
|
||||
cg = CypherGenerator(self.llm_config)
|
||||
cypher = cg.get_cypher(query)
|
||||
|
||||
if not cypher:
|
||||
return None
|
||||
|
||||
cypher_res = self.nh.execute_cypher(cypher, self.nh.space_name)
|
||||
logger.info(f'cypher execution result={cypher_res}')
|
||||
if not cypher_res.is_succeeded():
|
||||
return {
|
||||
'cypher': '',
|
||||
'cypher_res': ''
|
||||
}
|
||||
|
||||
res = {
|
||||
'cypher': cypher,
|
||||
'cypher_res': cypher_res
|
||||
}
|
||||
|
||||
return res
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# from configs.server_config import NEBULA_HOST, NEBULA_PORT, NEBULA_USER, NEBULA_PASSWORD, NEBULA_STORAGED_PORT
|
||||
# from configs.server_config import CHROMA_PERSISTENT_PATH
|
||||
from coagent.base_configs.env_config import (
|
||||
NEBULA_HOST, NEBULA_PORT, NEBULA_USER, NEBULA_PASSWORD, NEBULA_STORAGED_PORT,
|
||||
CHROMA_PERSISTENT_PATH
|
||||
)
|
||||
codebase_name = 'testing'
|
||||
|
||||
nh = NebulaHandler(host=NEBULA_HOST, port=NEBULA_PORT, username=NEBULA_USER,
|
||||
password=NEBULA_PASSWORD, space_name=codebase_name)
|
||||
nh.add_host(NEBULA_HOST, NEBULA_STORAGED_PORT)
|
||||
time.sleep(0.5)
|
||||
|
||||
ch = ChromaHandler(path=CHROMA_PERSISTENT_PATH, collection_name=codebase_name)
|
||||
|
||||
cs = CodeSearch(nh, ch)
|
||||
# res = cs.search_by_tag(tag_list=['createFineTuneCompletion', 'OpenAiApi'])
|
||||
# logger.debug(res)
|
||||
|
||||
# res = cs.search_by_cypher('代码中一共有多少个类', 'openai')
|
||||
# logger.debug(res)
|
||||
|
||||
res = cs.search_by_desciption('使用不同的HTTP请求类型(GET、POST、DELETE等)来执行不同的操作', 'openai')
|
||||
logger.debug(res)
|
|
@ -1,82 +0,0 @@
|
|||
# encoding: utf-8
|
||||
'''
|
||||
@author: 温进
|
||||
@file: cypher_generator.py
|
||||
@time: 2023/11/24 上午10:17
|
||||
@desc:
|
||||
'''
|
||||
from langchain import PromptTemplate
|
||||
from loguru import logger
|
||||
|
||||
from coagent.llm_models.openai_model import getChatModelFromConfig
|
||||
from coagent.llm_models.llm_config import LLMConfig
|
||||
from coagent.utils.postprocess import replace_lt_gt
|
||||
from langchain.schema import (
|
||||
HumanMessage,
|
||||
)
|
||||
from langchain.chains.graph_qa.prompts import NGQL_GENERATION_PROMPT, CYPHER_GENERATION_TEMPLATE
|
||||
|
||||
schema = '''
|
||||
Node properties: [{'tag': 'package', 'properties': []}, {'tag': 'class', 'properties': []}, {'tag': 'method', 'properties': []}]
|
||||
Edge properties: [{'edge': 'contain', 'properties': []}, {'edge': 'depend', 'properties': []}]
|
||||
Relationships: ['(:package)-[:contain]->(:class)', '(:class)-[:contain]->(:method)', '(:package)-[:contain]->(:package)']
|
||||
'''
|
||||
|
||||
|
||||
class CypherGenerator:
|
||||
def __init__(self, llm_config: LLMConfig):
|
||||
self.model = getChatModelFromConfig(llm_config)
|
||||
NEBULAGRAPH_EXTRA_INSTRUCTIONS = """
|
||||
Instructions:
|
||||
|
||||
First, generate cypher then convert it to NebulaGraph Cypher dialect(rather than standard):
|
||||
1. it requires explicit label specification only when referring to node properties: v.`Foo`.name
|
||||
2. note explicit label specification is not needed for edge properties, so it's e.name instead of e.`Bar`.name
|
||||
3. it uses double equals sign for comparison: `==` rather than `=`
|
||||
4. only use id(Foo) to get the name of node or edge
|
||||
```\n"""
|
||||
|
||||
NGQL_GENERATION_TEMPLATE = CYPHER_GENERATION_TEMPLATE.replace(
|
||||
"Generate Cypher", "Generate NebulaGraph Cypher"
|
||||
).replace("Instructions:", NEBULAGRAPH_EXTRA_INSTRUCTIONS)
|
||||
|
||||
self.NGQL_GENERATION_PROMPT = PromptTemplate(
|
||||
input_variables=["schema", "question"], template=NGQL_GENERATION_TEMPLATE
|
||||
)
|
||||
|
||||
def get_cypher(self, query: str):
|
||||
'''
|
||||
get cypher from query
|
||||
@param query:
|
||||
@return:
|
||||
'''
|
||||
content = self.NGQL_GENERATION_PROMPT.format(schema=schema, question=query)
|
||||
logger.info(content)
|
||||
ans = ''
|
||||
message = [HumanMessage(content=content)]
|
||||
chat_res = self.model.predict_messages(message)
|
||||
ans = chat_res.content
|
||||
|
||||
ans = replace_lt_gt(ans)
|
||||
|
||||
ans = self.post_process(ans)
|
||||
return ans
|
||||
|
||||
def post_process(self, cypher_res: str):
|
||||
'''
|
||||
判断是否为正确的 cypher
|
||||
@param cypher_res:
|
||||
@return:
|
||||
'''
|
||||
if '(' not in cypher_res or ')' not in cypher_res:
|
||||
return ''
|
||||
|
||||
return cypher_res
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
query = '代码库里有哪些函数,返回5个就可以'
|
||||
cg = CypherGenerator()
|
||||
|
||||
ans = cg.get_cypher(query)
|
||||
logger.debug(f'ans=\n{ans}')
|
|
@@ -1,39 +0,0 @@
# encoding: utf-8
'''
@author: 温进
@file: tagger.py
@time: 2023/11/24 下午1:32
@desc:
'''
import re
from loguru import logger


class Tagger:
    def __init__(self):
        pass

    def generate_tag_query(self, query):
        '''
        generate tag from query
        '''
        # simple extract english
        tag_list = re.findall(r'[a-zA-Z\_\.]+', query)
        tag_list = list(set(tag_list))
        tag_list = self.filter_tag_list(tag_list)
        return tag_list

    def filter_tag_list(self, tag_list):
        '''
        filter out tag
        @param tag_list:
        @return:
        '''
        res = []
        for tag in tag_list:
            if tag in ['java', 'python']:
                continue
            res.append(tag)
        return res
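A quick illustration of what the tag extraction above yields; the query string is an example:
```
tagger = Tagger()
# English identifiers are kept; bare language names like 'java' are filtered out
print(tagger.generate_tag_query('java 代码里 OpenAiApi.createCompletion 是做什么的'))
# expected: ['OpenAiApi.createCompletion']
```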
@@ -1,7 +0,0 @@
# encoding: utf-8
'''
@author: 温进
@file: __init__.py.py
@time: 2023/11/21 下午2:07
@desc:
'''
File diff suppressed because one or more lines are too long
|
@ -1,275 +0,0 @@
|
|||
# encoding: utf-8
|
||||
'''
|
||||
@author: 温进
|
||||
@file: codebase_handler.py
|
||||
@time: 2023/11/21 下午2:25
|
||||
@desc:
|
||||
'''
|
||||
import os
|
||||
import time
|
||||
import json
|
||||
from typing import List
|
||||
from loguru import logger
|
||||
|
||||
from coagent.base_configs.env_config import (
|
||||
NEBULA_HOST, NEBULA_PORT, NEBULA_USER, NEBULA_PASSWORD, NEBULA_STORAGED_PORT,
|
||||
CHROMA_PERSISTENT_PATH, CB_ROOT_PATH
|
||||
)
|
||||
|
||||
|
||||
from coagent.db_handler.graph_db_handler.nebula_handler import NebulaHandler
|
||||
from coagent.db_handler.vector_db_handler.chroma_handler import ChromaHandler
|
||||
from coagent.codechat.code_crawler.zip_crawler import *
|
||||
from coagent.codechat.code_analyzer.code_analyzer import CodeAnalyzer
|
||||
from coagent.codechat.codebase_handler.code_importer import CodeImporter
|
||||
from coagent.codechat.code_search.code_search import CodeSearch
|
||||
from coagent.llm_models.llm_config import EmbedConfig, LLMConfig
|
||||
|
||||
|
||||
class CodeBaseHandler:
|
||||
def __init__(
|
||||
self,
|
||||
codebase_name: str,
|
||||
code_path: str = '',
|
||||
language: str = 'java',
|
||||
crawl_type: str = 'ZIP',
|
||||
embed_config: EmbedConfig = EmbedConfig(),
|
||||
llm_config: LLMConfig = LLMConfig(),
|
||||
use_nh: bool = True,
|
||||
local_graph_path: str = CB_ROOT_PATH
|
||||
):
|
||||
self.codebase_name = codebase_name
|
||||
self.code_path = code_path
|
||||
self.language = language
|
||||
self.crawl_type = crawl_type
|
||||
self.embed_config = embed_config
|
||||
self.llm_config = llm_config
|
||||
self.local_graph_file_path = local_graph_path + os.sep + f'{self.codebase_name}_graph.json'
|
||||
|
||||
if use_nh:
|
||||
try:
|
||||
self.nh = NebulaHandler(host=NEBULA_HOST, port=NEBULA_PORT, username=NEBULA_USER,
|
||||
password=NEBULA_PASSWORD, space_name=codebase_name)
|
||||
self.nh.add_host(NEBULA_HOST, NEBULA_STORAGED_PORT)
|
||||
time.sleep(1)
|
||||
except:
|
||||
self.nh = None
|
||||
try:
|
||||
with open(self.local_graph_file_path, 'r') as f:
|
||||
self.graph = json.load(f)
|
||||
except:
|
||||
pass
|
||||
elif local_graph_path:
|
||||
self.nh = None
|
||||
try:
|
||||
with open(self.local_graph_file_path, 'r') as f:
|
||||
self.graph = json.load(f)
|
||||
except:
|
||||
pass
|
||||
|
||||
self.ch = ChromaHandler(path=CHROMA_PERSISTENT_PATH, collection_name=codebase_name)
|
||||
|
||||
def import_code(self, zip_file='', do_interpret=True):
|
||||
'''
|
||||
analyze code and save it to codekg and codedb
|
||||
@return:
|
||||
'''
|
||||
# init graph to init tag and edge
|
||||
code_importer = CodeImporter(embed_config=self.embed_config, codebase_name=self.codebase_name,
|
||||
nh=self.nh, ch=self.ch, local_graph_file_path=self.local_graph_file_path)
|
||||
if self.nh:
|
||||
code_importer.init_graph()
|
||||
time.sleep(5)
|
||||
|
||||
# crawl code
|
||||
st0 = time.time()
|
||||
logger.info('start crawl')
|
||||
code_dict = self.crawl_code(zip_file)
|
||||
logger.debug('crawl done, rt={}'.format(time.time() - st0))
|
||||
|
||||
# analyze code
|
||||
logger.info('start analyze')
|
||||
st1 = time.time()
|
||||
code_analyzer = CodeAnalyzer(language=self.language, llm_config=self.llm_config)
|
||||
static_analysis_res, interpretation = code_analyzer.analyze(code_dict, do_interpret=do_interpret)
|
||||
logger.debug('analyze done, rt={}'.format(time.time() - st1))
|
||||
|
||||
# add info to nebula and chroma
|
||||
st2 = time.time()
|
||||
code_importer.import_code(static_analysis_res, interpretation, do_interpret=do_interpret)
|
||||
logger.debug('update codebase done, rt={}'.format(time.time() - st2))
|
||||
|
||||
# get KG info
|
||||
if self.nh:
|
||||
time.sleep(10) # aviod nebula staus didn't complete
|
||||
stat = self.nh.get_stat()
|
||||
vertices_num, edges_num = stat['vertices'], stat['edges']
|
||||
else:
|
||||
vertices_num = 0
|
||||
edges_num = 0
|
||||
|
||||
# get chroma info
|
||||
file_num = self.ch.count()['result']
|
||||
|
||||
return vertices_num, edges_num, file_num
|
||||
|
||||
def delete_codebase(self, codebase_name: str):
|
||||
'''
|
||||
delete codebase
|
||||
@param codebase_name: name of codebase
|
||||
@return:
|
||||
'''
|
||||
if self.nh:
|
||||
self.nh.drop_space(space_name=codebase_name)
|
||||
elif self.local_graph_file_path and os.path.isfile(self.local_graph_file_path):
|
||||
os.remove(self.local_graph_file_path)
|
||||
|
||||
self.ch.delete_collection(collection_name=codebase_name)
|
||||
|
||||
def crawl_code(self, zip_file=''):
|
||||
'''
|
||||
@return:
|
||||
'''
|
||||
if self.language == 'java':
|
||||
suffix = 'java'
|
||||
|
||||
logger.info(f'crawl_type={self.crawl_type}')
|
||||
|
||||
code_dict = {}
|
||||
if self.crawl_type.lower() == 'zip':
|
||||
code_dict = ZipCrawler.crawl(zip_file, output_path=self.code_path, suffix=suffix)
|
||||
elif self.crawl_type.lower() == 'dir':
|
||||
code_dict = DirCrawler.crawl(self.code_path, suffix)
|
||||
|
||||
return code_dict
|
||||
|
||||
def search_code(self, query: str, search_type: str, limit: int = 3):
|
||||
'''
|
||||
search code from codebase
|
||||
@param limit:
|
||||
@param engine:
|
||||
@param query: query from user
|
||||
@param search_type: ['cypher', 'graph', 'vector']
|
||||
@return:
|
||||
'''
|
||||
if self.nh:
|
||||
assert search_type in ['cypher', 'tag', 'description']
|
||||
else:
|
||||
if search_type == 'tag':
|
||||
search_type = 'tag_by_local_graph'
|
||||
assert search_type in ['tag_by_local_graph', 'description']
|
||||
|
||||
code_search = CodeSearch(llm_config=self.llm_config, nh=self.nh, ch=self.ch, limit=limit,
|
||||
local_graph_file_path=self.local_graph_file_path)
|
||||
|
||||
if search_type == 'cypher':
|
||||
search_res = code_search.search_by_cypher(query=query)
|
||||
elif search_type == 'tag':
|
||||
search_res = code_search.search_by_tag(query=query)
|
||||
elif search_type == 'description':
|
||||
search_res = code_search.search_by_desciption(
|
||||
query=query, engine=self.embed_config.embed_engine, model_path=self.embed_config.embed_model_path,
|
||||
embedding_device=self.embed_config.model_device, embed_config=self.embed_config)
|
||||
elif search_type == 'tag_by_local_graph':
|
||||
search_res = code_search.search_by_tag_by_graph(query=query)
|
||||
|
||||
|
||||
context, related_vertice = self.format_search_res(search_res, search_type)
|
||||
return context, related_vertice
|
||||
|
||||
def format_search_res(self, search_res: str, search_type: str):
|
||||
'''
|
||||
format search_res
|
||||
@param search_res:
|
||||
@param search_type:
|
||||
@return:
|
||||
'''
|
||||
CYPHER_QA_PROMPT = '''
|
||||
执行的 Cypher 是: {cypher}
|
||||
Cypher 的结果是: {result}
|
||||
'''
|
||||
|
||||
if search_type == 'cypher':
|
||||
context = CYPHER_QA_PROMPT.format(cypher=search_res['cypher'], result=search_res['cypher_res'])
|
||||
related_vertice = []
|
||||
elif search_type == 'tag':
|
||||
context = ''
|
||||
related_vertice = []
|
||||
for code in search_res:
|
||||
context = context + code['code_text'] + '\n'
|
||||
related_vertice.append(code['vertex'])
|
||||
elif search_type == 'tag_by_local_graph':
|
||||
context = ''
|
||||
related_vertice = []
|
||||
for code in search_res:
|
||||
context = context + code['code_text'] + '\n'
|
||||
related_vertice.append(code['vertex'])
|
||||
elif search_type == 'description':
|
||||
context = ''
|
||||
related_vertice = []
|
||||
for code in search_res:
|
||||
context = context + code['code_text'] + '\n'
|
||||
related_vertice.append(code['vertex'])
|
||||
|
||||
return context, related_vertice
|
||||
|
||||
def search_vertices(self, vertex_type="class") -> List[str]:
|
||||
'''
|
||||
通过 method/class 来搜索所有的节点
|
||||
'''
|
||||
vertices = []
|
||||
if self.nh:
|
||||
vertices = self.nh.get_all_vertices()
|
||||
vertices = [str(v.as_node().get_id()) for v in vertices["v"] if vertex_type in v.as_node().tags()]
|
||||
# for v in vertices["v"]:
|
||||
# logger.debug(f"{v.as_node().get_id()}, {v.as_node().tags()}")
|
||||
else:
|
||||
if vertex_type == "class":
|
||||
vertices = [str(class_name) for code, structure in self.graph.items() for class_name in structure['class_name_list']]
|
||||
elif vertex_type == "method":
|
||||
vertices = [
|
||||
str(methods_name)
|
||||
for code, structure in self.graph.items()
|
||||
for methods_names in structure['func_name_dict'].values()
|
||||
for methods_name in methods_names
|
||||
]
|
||||
# logger.debug(vertices)
|
||||
return vertices
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
from configs.model_config import KB_ROOT_PATH, JUPYTER_WORK_PATH
|
||||
from configs.server_config import SANDBOX_SERVER
|
||||
|
||||
LLM_MODEL = "gpt-3.5-turbo"
|
||||
llm_config = LLMConfig(
|
||||
model_name=LLM_MODEL, model_device="cpu", api_key=os.environ["OPENAI_API_KEY"],
|
||||
api_base_url=os.environ["API_BASE_URL"], temperature=0.3
|
||||
)
|
||||
src_dir = '/Users/bingxu/Desktop/工作/大模型/chatbot/Codefuse-chatbot-antcode'
|
||||
embed_config = EmbedConfig(
|
||||
embed_engine="model", embed_model="text2vec-base-chinese",
|
||||
embed_model_path=os.path.join(src_dir, "embedding_models/text2vec-base-chinese")
|
||||
)
|
||||
|
||||
codebase_name = 'client_local'
|
||||
code_path = '/Users/bingxu/Desktop/工作/大模型/chatbot/test_code_repo/client'
|
||||
use_nh = False
|
||||
local_graph_path = '/Users/bingxu/Desktop/工作/大模型/chatbot/Codefuse-chatbot-antcode/code_base'
|
||||
CHROMA_PERSISTENT_PATH = '/Users/bingxu/Desktop/工作/大模型/chatbot/Codefuse-chatbot-antcode/data/chroma_data'
|
||||
|
||||
cbh = CodeBaseHandler(codebase_name, code_path, crawl_type='dir', use_nh=use_nh, local_graph_path=local_graph_path,
|
||||
llm_config=llm_config, embed_config=embed_config)
|
||||
|
||||
# test import code
|
||||
# cbh.import_code(do_interpret=True)
|
||||
|
||||
# query = '使用不同的HTTP请求类型(GET、POST、DELETE等)来执行不同的操作'
|
||||
# query = '代码中一共有多少个类'
|
||||
# query = 'remove 这个函数是用来做什么的'
|
||||
query = '有没有函数是从字符串中删除指定字符串的功能'
|
||||
|
||||
search_type = 'description'
|
||||
limit = 2
|
||||
res = cbh.search_code(query, search_type, limit)
|
||||
logger.debug(res)
|
|
@@ -1,9 +0,0 @@
from .configs import PHASE_CONFIGS


PHASE_LIST = list(PHASE_CONFIGS.keys())

__all__ = [
    "PHASE_CONFIGS"
]
@@ -1,6 +0,0 @@
from .base_action import BaseAction


__all__ = [
    "BaseAction"
]
@@ -1,16 +0,0 @@

from langchain.schema import BaseRetriever, Document

class BaseAction:

    def __init__(self):
        pass

    def step(self):
        pass

    def astep(self):
        pass
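A hedged sketch of how a concrete action could fill in the stub above; the ShellAction name and behaviour are purely illustrative and not part of the original code:
```
import subprocess

class ShellAction(BaseAction):
    """Illustrative action: run one shell command per step."""

    def step(self, command: str = "echo hello"):
        # synchronous execution; an astep() variant would await the same call
        return subprocess.run(command, shell=True, capture_output=True, text=True).stdout
```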
@@ -1,8 +0,0 @@
from .base_agent import BaseAgent
from .react_agent import ReactAgent
from .executor_agent import ExecutorAgent
from .selector_agent import SelectorAgent

__all__ = [
    "BaseAgent", "ReactAgent", "ExecutorAgent", "SelectorAgent"
]
@ -1,211 +0,0 @@
|
|||
from typing import List, Union
|
||||
import importlib
|
||||
import re, os
|
||||
import copy
|
||||
from loguru import logger
|
||||
|
||||
from langchain.schema import BaseRetriever
|
||||
|
||||
from coagent.connector.schema import (
|
||||
Memory, Task, Role, Message, PromptField, LogVerboseEnum
|
||||
)
|
||||
from coagent.connector.memory_manager import BaseMemoryManager
|
||||
from coagent.connector.message_process import MessageUtils
|
||||
from coagent.llm_models import getExtraModel, LLMConfig, getChatModelFromConfig, EmbedConfig
|
||||
from coagent.connector.prompt_manager.prompt_manager import PromptManager
|
||||
from coagent.connector.memory_manager import LocalMemoryManager
|
||||
from coagent.base_configs.env_config import JUPYTER_WORK_PATH, KB_ROOT_PATH
|
||||
|
||||
|
||||
class BaseAgent:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
role: Role,
|
||||
prompt_config: List[PromptField],
|
||||
prompt_manager_type: str = "PromptManager",
|
||||
task: Task = None,
|
||||
memory: Memory = None,
|
||||
chat_turn: int = 1,
|
||||
focus_agents: List[str] = [],
|
||||
focus_message_keys: List[str] = [],
|
||||
#
|
||||
llm_config: LLMConfig = None,
|
||||
embed_config: EmbedConfig = None,
|
||||
sandbox_server: dict = {},
|
||||
jupyter_work_path: str = JUPYTER_WORK_PATH,
|
||||
kb_root_path: str = KB_ROOT_PATH,
|
||||
doc_retrieval: Union[BaseRetriever] = None,
|
||||
code_retrieval = None,
|
||||
search_retrieval = None,
|
||||
log_verbose: str = "0"
|
||||
):
|
||||
|
||||
self.task = task
|
||||
self.role = role
|
||||
self.sandbox_server = sandbox_server
|
||||
self.jupyter_work_path = jupyter_work_path
|
||||
self.kb_root_path = kb_root_path
|
||||
self.message_utils = MessageUtils(role, sandbox_server, jupyter_work_path, embed_config, llm_config, kb_root_path, doc_retrieval, code_retrieval, search_retrieval, log_verbose)
|
||||
self.memory = self.init_history(memory)
|
||||
self.llm_config: LLMConfig = llm_config
|
||||
self.embed_config: EmbedConfig = embed_config
|
||||
self.llm = self.create_llm_engine(llm_config=self.llm_config)
|
||||
self.chat_turn = chat_turn
|
||||
#
|
||||
self.focus_agents = focus_agents
|
||||
self.focus_message_keys = focus_message_keys
|
||||
#
|
||||
prompt_manager_module = importlib.import_module("coagent.connector.prompt_manager")
|
||||
prompt_manager = getattr(prompt_manager_module, prompt_manager_type)
|
||||
self.prompt_manager: PromptManager = prompt_manager(role_prompt=role.role_prompt, prompt_config=prompt_config)
|
||||
self.log_verbose = max(os.environ.get("log_verbose", "0"), log_verbose)
|
||||
|
||||
def step(self, query: Message, history: Memory = None, background: Memory = None, memory_manager: BaseMemoryManager=None) -> Message:
|
||||
'''agent reponse from multi-message'''
|
||||
message = None
|
||||
for message in self.astep(query, history, background, memory_manager):
|
||||
pass
|
||||
return message
|
||||
|
||||
def astep(self, query: Message, history: Memory = None, background: Memory = None, memory_manager: BaseMemoryManager=None) -> Message:
|
||||
'''agent reponse from multi-message'''
|
||||
# insert query into memory
|
||||
query_c = copy.deepcopy(query)
|
||||
query_c = self.start_action_step(query_c)
|
||||
|
||||
# llm predict
|
||||
# prompt = self.create_prompt(query_c, self.memory, history, background, memory_pool=memory_manager.current_memory)
|
||||
if memory_manager is None:
|
||||
memory_manager = LocalMemoryManager(
|
||||
unique_name=self.role.role_name,
|
||||
do_init=True,
|
||||
kb_root_path = self.kb_root_path,
|
||||
embed_config=self.embed_config,
|
||||
llm_config=self.embed_config
|
||||
)
|
||||
memory_manager.append(query)
|
||||
memory_pool = memory_manager.get_memory_pool(query.user_name)
|
||||
|
||||
prompt = self.prompt_manager.generate_full_prompt(
|
||||
previous_agent_message=query_c, agent_long_term_memory=self.memory, ui_history=history, chain_summary_messages=background, memory_pool=memory_pool)
|
||||
content = self.llm.predict(prompt)
|
||||
|
||||
if LogVerboseEnum.ge(LogVerboseEnum.Log2Level, self.log_verbose):
|
||||
logger.debug(f"{self.role.role_name} prompt: {prompt}")
|
||||
|
||||
if LogVerboseEnum.ge(LogVerboseEnum.Log1Level, self.log_verbose):
|
||||
logger.info(f"{self.role.role_name} content: {content}")
|
||||
|
||||
output_message = Message(
|
||||
user_name=query.user_name,
|
||||
role_name=self.role.role_name,
|
||||
role_type="assistant", #self.role.role_type,
|
||||
role_content=content,
|
||||
step_content=content,
|
||||
input_query=query_c.input_query,
|
||||
tools=query_c.tools,
|
||||
# parsed_output_list=[query.parsed_output],
|
||||
customed_kargs=query_c.customed_kargs
|
||||
)
|
||||
|
||||
# common parse llm' content to message
|
||||
output_message = self.message_utils.parser(output_message)
|
||||
|
||||
# action step
|
||||
output_message, observation_message = self.message_utils.step_router(output_message, history, background, memory_manager=memory_manager)
|
||||
output_message.parsed_output_list.append(output_message.parsed_output)
|
||||
if observation_message:
|
||||
output_message.parsed_output_list.append(observation_message.parsed_output)
|
||||
|
||||
# update self_memory
|
||||
self.append_history(query_c)
|
||||
self.append_history(output_message)
|
||||
|
||||
output_message.input_query = output_message.role_content
|
||||
# end
|
||||
output_message = self.message_utils.inherit_extrainfo(query, output_message)
|
||||
output_message = self.end_action_step(output_message)
|
||||
|
||||
# update memory pool
|
||||
memory_manager.append(output_message)
|
||||
yield output_message
|
||||
|
||||
def pre_print(self, query: Message, history: Memory = None, background: Memory = None, memory_manager: BaseMemoryManager=None):
|
||||
prompt = self.prompt_manager.pre_print(
|
||||
previous_agent_message=query, agent_long_term_memory=self.memory, ui_history=history, chain_summary_messages=background, memory_pool=memory_manager.current_memory)
|
||||
title = f"<<<<{self.role.role_name}'s prompt>>>>"
|
||||
print("#"*len(title) + f"\n{title}\n"+ "#"*len(title)+ f"\n\n{prompt}\n\n")
|
||||
|
||||
def init_history(self, memory: Memory = None) -> Memory:
|
||||
return Memory(messages=[])
|
||||
|
||||
def update_history(self, message: Message):
|
||||
self.memory.append(message)
|
||||
|
||||
def append_history(self, message: Message):
|
||||
self.memory.append(message)
|
||||
|
||||
def clear_history(self, ):
|
||||
self.memory.clear()
|
||||
self.memory = self.init_history()
|
||||
|
||||
def create_llm_engine(self, llm_config: LLMConfig = None, temperature=0.2, stop=None):
|
||||
return getChatModelFromConfig(llm_config=llm_config)
|
||||
|
||||
def registry_actions(self, actions):
|
||||
'''registry llm's actions'''
|
||||
self.action_list = actions
|
||||
|
||||
def start_action_step(self, message: Message) -> Message:
|
||||
'''hook that runs before the agent predicts'''
|
||||
# action_json = self.start_action()
|
||||
# message["customed_kargs"]["xx"] = action_json
|
||||
return message
|
||||
|
||||
def end_action_step(self, message: Message) -> Message:
|
||||
'''hook that runs after the agent predicts'''
|
||||
# action_json = self.end_action()
|
||||
# message["customed_kargs"]["xx"] = action_json
|
||||
return message
|
||||
|
||||
def token_usage(self, ):
|
||||
'''calculate the usage of token'''
|
||||
pass
|
||||
|
||||
def select_memory_by_key(self, memory: Memory) -> Memory:
|
||||
return Memory(
|
||||
messages=[self.select_message_by_key(message) for message in memory.messages
|
||||
if self.select_message_by_key(message) is not None]
|
||||
)
|
||||
|
||||
def select_memory_by_agent_key(self, memory: Memory) -> Memory:
|
||||
return Memory(
|
||||
messages=[self.select_message_by_agent_key(message) for message in memory.messages
|
||||
if self.select_message_by_agent_key(message) is not None]
|
||||
)
|
||||
|
||||
def select_message_by_agent_key(self, message: Message) -> Message:
|
||||
# when no focus agents are configured, keep every message
|
||||
if self.focus_agents == []:
|
||||
return message
|
||||
return None if message is None or message.role_name not in self.focus_agents else self.select_message_by_key(message)
|
||||
|
||||
def select_message_by_key(self, message: Message) -> Message:
|
||||
# when no focus keys are configured, keep the full message
|
||||
if message is None:
|
||||
return message
|
||||
|
||||
if self.focus_message_keys == []:
|
||||
return message
|
||||
|
||||
message_c = copy.deepcopy(message)
|
||||
message_c.parsed_output = {k: v for k,v in message_c.parsed_output.items() if k in self.focus_message_keys}
|
||||
message_c.parsed_output_list = [{k: v for k,v in parsed_output.items() if k in self.focus_message_keys} for parsed_output in message_c.parsed_output_list]
|
||||
return message_c
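# A minimal sketch of the focus filter above, assuming an agent instance
# `agent` built with focus_message_keys=["Action Status"]; only the
# whitelisted keys survive in parsed_output and parsed_output_list.
#
# msg = Message(role_name="executor", role_type="assistant", role_content="...")
# msg.parsed_output = {"Action Status": "finished", "Action": "print('ok')"}
# agent.select_message_by_key(msg).parsed_output
# # -> {"Action Status": "finished"}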
|
||||
|
||||
def get_memory(self, content_key="role_content"):
|
||||
return self.memory.to_tuple_messages(content_key="step_content")
|
||||
|
||||
def get_memory_str(self, content_key="role_content"):
|
||||
return "\n".join([": ".join(i) for i in self.memory.to_tuple_messages(content_key="step_content")])
|
|
@ -1,157 +0,0 @@
|
|||
from typing import List, Union
|
||||
import copy
|
||||
from loguru import logger
|
||||
|
||||
from langchain.schema import BaseRetriever
|
||||
|
||||
from coagent.connector.schema import (
|
||||
Memory, Task, Env, Role, Message, ActionStatus, PromptField, LogVerboseEnum
|
||||
)
|
||||
from coagent.connector.memory_manager import BaseMemoryManager
|
||||
from coagent.llm_models import LLMConfig, EmbedConfig
|
||||
from coagent.connector.memory_manager import LocalMemoryManager
|
||||
from coagent.base_configs.env_config import JUPYTER_WORK_PATH, KB_ROOT_PATH
|
||||
from .base_agent import BaseAgent
|
||||
|
||||
|
||||
class ExecutorAgent(BaseAgent):
|
||||
def __init__(
|
||||
self,
|
||||
role: Role,
|
||||
prompt_config: List[PromptField],
|
||||
prompt_manager_type: str= "PromptManager",
|
||||
task: Task = None,
|
||||
memory: Memory = None,
|
||||
chat_turn: int = 1,
|
||||
focus_agents: List[str] = [],
|
||||
focus_message_keys: List[str] = [],
|
||||
#
|
||||
llm_config: LLMConfig = None,
|
||||
embed_config: EmbedConfig = None,
|
||||
sandbox_server: dict = {},
|
||||
jupyter_work_path: str = JUPYTER_WORK_PATH,
|
||||
kb_root_path: str = KB_ROOT_PATH,
|
||||
doc_retrieval: Union[BaseRetriever] = None,
|
||||
code_retrieval = None,
|
||||
search_retrieval = None,
|
||||
log_verbose: str = "0"
|
||||
):
|
||||
|
||||
super().__init__(role, prompt_config, prompt_manager_type, task, memory, chat_turn,
|
||||
focus_agents, focus_message_keys, llm_config, embed_config, sandbox_server,
|
||||
jupyter_work_path, kb_root_path, doc_retrieval, code_retrieval, search_retrieval, log_verbose
|
||||
)
|
||||
self.do_all_task = True # run all tasks
|
||||
|
||||
def astep(self, query: Message, history: Memory = None, background: Memory = None, memory_manager: BaseMemoryManager=None) -> Message:
|
||||
'''agent response from multi-message'''
|
||||
# insert query into memory
|
||||
task_executor_memory = Memory(messages=[])
|
||||
# insert query
|
||||
output_message = Message(
|
||||
user_name=query.user_name,
|
||||
role_name=self.role.role_name,
|
||||
role_type="assistant", #self.role.role_type,
|
||||
role_content=query.input_query,
|
||||
step_content="",
|
||||
input_query=query.input_query,
|
||||
tools=query.tools,
|
||||
# parsed_output_list=[query.parsed_output],
|
||||
customed_kargs=query.customed_kargs
|
||||
)
|
||||
|
||||
if memory_manager is None:
|
||||
memory_manager = LocalMemoryManager(
|
||||
unique_name=self.role.role_name,
|
||||
do_init=True,
|
||||
kb_root_path = self.kb_root_path,
|
||||
embed_config=self.embed_config,
|
||||
llm_config=self.llm_config
|
||||
)
|
||||
memory_manager.append(query)
|
||||
|
||||
# self_memory = self.memory if self.do_use_self_memory else None
|
||||
|
||||
plan_step = int(query.parsed_output.get("PLAN_STEP", 0))
|
||||
# fall back to single-step execution when PLAN is missing, is a plain string, or the plan step index is out of range
|
||||
if "PLAN" not in query.parsed_output or isinstance(query.parsed_output.get("PLAN", []), str) or plan_step >= len(query.parsed_output.get("PLAN", [])):
|
||||
query_c = copy.deepcopy(query)
|
||||
query_c = self.start_action_step(query_c)
|
||||
query_c.parsed_output = {"CURRENT_STEP": query_c.input_query}
|
||||
task_executor_memory.append(query_c)
|
||||
for output_message, task_executor_memory in self._arun_step(output_message, query_c, self.memory, history, background, memory_manager, task_executor_memory):
|
||||
pass
|
||||
# task_executor_memory.append(query_c)
|
||||
# content = "the execution step of the plan is exceed the planned scope."
|
||||
# output_message.parsed_dict = {"Thought": content, "Action Status": "finished", "Action": content}
|
||||
# task_executor_memory.append(output_message)
|
||||
|
||||
elif "PLAN" in query.parsed_output:
|
||||
if self.do_all_task:
|
||||
# run all tasks step by step
|
||||
for task_content in query.parsed_output["PLAN"][plan_step:]:
|
||||
# create your llm prompt
|
||||
query_c = copy.deepcopy(query)
|
||||
query_c.parsed_output = {"CURRENT_STEP": task_content}
|
||||
task_executor_memory.append(query_c)
|
||||
for output_message, task_executor_memory in self._arun_step(output_message, query_c, self.memory, history, background, memory_manager, task_executor_memory):
|
||||
pass
|
||||
yield output_message
|
||||
else:
|
||||
query_c = copy.deepcopy(query)
|
||||
query_c = self.start_action_step(query_c)
|
||||
task_content = query_c.parsed_output["PLAN"][plan_step]
|
||||
query_c.parsed_output = {"CURRENT_STEP": task_content}
|
||||
task_executor_memory.append(query_c)
|
||||
for output_message, task_executor_memory in self._arun_step(output_message, query_c, self.memory, history, background, memory_manager, task_executor_memory):
|
||||
pass
|
||||
output_message.parsed_output.update({"CURRENT_STEP": plan_step})
|
||||
# update self_memory
|
||||
self.append_history(query)
|
||||
self.append_history(output_message)
|
||||
output_message.input_query = output_message.role_content
|
||||
# end_action_step
|
||||
output_message = self.end_action_step(output_message)
|
||||
# update memory pool
|
||||
memory_manager.append(output_message)
|
||||
yield output_message
|
||||
|
||||
def _arun_step(self, output_message: Message, query: Message, self_memory: Memory,
|
||||
history: Memory, background: Memory, memory_manager: BaseMemoryManager,
|
||||
task_memory: Memory) -> Union[Message, Memory]:
|
||||
'''execute the llm predict by created prompt'''
|
||||
memory_pool = memory_manager.get_memory_pool(query.user_name)
|
||||
prompt = self.prompt_manager.generate_full_prompt(
|
||||
previous_agent_message=query, agent_long_term_memory=self_memory, ui_history=history, chain_summary_messages=background, memory_pool=memory_pool,
|
||||
task_memory=task_memory)
|
||||
content = self.llm.predict(prompt)
|
||||
|
||||
if LogVerboseEnum.ge(LogVerboseEnum.Log2Level, self.log_verbose):
|
||||
logger.debug(f"{self.role.role_name} prompt: {prompt}")
|
||||
|
||||
if LogVerboseEnum.ge(LogVerboseEnum.Log1Level, self.log_verbose):
|
||||
logger.info(f"{self.role.role_name} content: {content}")
|
||||
|
||||
output_message.role_content = content
|
||||
output_message.step_content += "\n"+output_message.role_content
|
||||
output_message = self.message_utils.parser(output_message)
|
||||
# route the output to the appropriate action (code execution or tool call)
|
||||
output_message, observation_message = self.message_utils.step_router(output_message)
|
||||
# update parsed_output_list
|
||||
output_message.parsed_output_list.append(output_message.parsed_output)
|
||||
|
||||
react_message = copy.deepcopy(output_message)
|
||||
task_memory.append(react_message)
|
||||
if observation_message:
|
||||
task_memory.append(observation_message)
|
||||
output_message.parsed_output_list.append(observation_message.parsed_output)
|
||||
# logger.debug(f"{observation_message.role_name} content: {observation_message.role_content}")
|
||||
yield output_message, task_memory
|
||||
|
||||
def pre_print(self, query: Message, history: Memory = None, background: Memory = None, memory_manager: BaseMemoryManager = None):
|
||||
task_memory = Memory(messages=[])
|
||||
prompt = self.prompt_manager.pre_print(
|
||||
previous_agent_message=query, agent_long_term_memory=self.memory, ui_history=history, chain_summary_messages=background, react_memory=None,
|
||||
memory_pool=memory_manager.current_memory, task_memory=task_memory)
|
||||
title = f"<<<<{self.role.role_name}'s prompt>>>>"
|
||||
print("#"*len(title) + f"\n{title}\n"+ "#"*len(title)+ f"\n\n{prompt}\n\n")
|
|
@ -1,147 +0,0 @@
|
|||
from typing import List, Union
|
||||
import traceback
|
||||
import copy
|
||||
from loguru import logger
|
||||
|
||||
from langchain.schema import BaseRetriever
|
||||
|
||||
from coagent.connector.schema import (
|
||||
Memory, Task, Env, Role, Message, ActionStatus, PromptField, LogVerboseEnum
|
||||
)
|
||||
from coagent.connector.memory_manager import BaseMemoryManager
|
||||
from coagent.llm_models import LLMConfig, EmbedConfig
|
||||
from .base_agent import BaseAgent
|
||||
from coagent.connector.memory_manager import LocalMemoryManager
|
||||
from coagent.base_configs.env_config import JUPYTER_WORK_PATH, KB_ROOT_PATH
|
||||
|
||||
|
||||
class ReactAgent(BaseAgent):
|
||||
def __init__(
|
||||
self,
|
||||
role: Role,
|
||||
prompt_config: List[PromptField],
|
||||
prompt_manager_type: str = "PromptManager",
|
||||
task: Task = None,
|
||||
memory: Memory = None,
|
||||
chat_turn: int = 1,
|
||||
focus_agents: List[str] = [],
|
||||
focus_message_keys: List[str] = [],
|
||||
#
|
||||
llm_config: LLMConfig = None,
|
||||
embed_config: EmbedConfig = None,
|
||||
sandbox_server: dict = {},
|
||||
jupyter_work_path: str = JUPYTER_WORK_PATH,
|
||||
kb_root_path: str = KB_ROOT_PATH,
|
||||
doc_retrieval: Union[BaseRetriever] = None,
|
||||
code_retrieval = None,
|
||||
search_retrieval = None,
|
||||
log_verbose: str = "0"
|
||||
):
|
||||
|
||||
super().__init__(role, prompt_config, prompt_manager_type, task, memory, chat_turn,
|
||||
focus_agents, focus_message_keys, llm_config, embed_config, sandbox_server,
|
||||
jupyter_work_path, kb_root_path, doc_retrieval, code_retrieval, search_retrieval, log_verbose
|
||||
)
|
||||
|
||||
def step(self, query: Message, history: Memory = None, background: Memory = None, memory_manager: BaseMemoryManager = None) -> Message:
|
||||
'''agent response from multi-message'''
|
||||
for message in self.astep(query, history, background, memory_manager):
|
||||
pass
|
||||
return message
|
||||
|
||||
def astep(self, query: Message, history: Memory = None, background: Memory = None, memory_manager: BaseMemoryManager = None) -> Message:
|
||||
'''agent response from multi-message'''
|
||||
step_nums = copy.deepcopy(self.chat_turn)
|
||||
react_memory = Memory(messages=[])
|
||||
# insert query
|
||||
output_message = Message(
|
||||
user_name=query.user_name,
|
||||
role_name=self.role.role_name,
|
||||
role_type="assistant", #self.role.role_type,
|
||||
role_content=query.input_query,
|
||||
step_content="",
|
||||
input_query=query.input_query,
|
||||
tools=query.tools,
|
||||
# parsed_output_list=[query.parsed_output],
|
||||
customed_kargs=query.customed_kargs
|
||||
)
|
||||
query_c = copy.deepcopy(query)
|
||||
query_c = self.start_action_step(query_c)
|
||||
# if query.parsed_output:
|
||||
# query_c.parsed_output = {"Question": "\n".join([f"{v}" for k, v in query.parsed_output.items() if k not in ["Action Status"]])}
|
||||
# else:
|
||||
# query_c.parsed_output = {"Question": query.input_query}
|
||||
# react_memory.append(query_c)
|
||||
# self_memory = self.memory if self.do_use_self_memory else None
|
||||
idx = 0
|
||||
# start to react
|
||||
while step_nums > 0:
|
||||
output_message.role_content = output_message.step_content
|
||||
# prompt = self.create_prompt(query, self.memory, history, background, react_memory, memory_manager.current_memory)
|
||||
|
||||
if memory_manager is None:
|
||||
memory_manager = LocalMemoryManager(
|
||||
unique_name=self.role.role_name,
|
||||
do_init=True,
|
||||
kb_root_path = self.kb_root_path,
|
||||
embed_config=self.embed_config,
|
||||
llm_config=self.llm_config
|
||||
)
|
||||
memory_manager.append(query)
|
||||
memory_pool = memory_manager.get_memory_pool(query_c.user_name)
|
||||
|
||||
prompt = self.prompt_manager.generate_full_prompt(
|
||||
previous_agent_message=query_c, agent_long_term_memory=self.memory, ui_history=history, chain_summary_messages=background, react_memory=react_memory,
|
||||
memory_pool=memory_pool)
|
||||
try:
|
||||
content = self.llm.predict(prompt)
|
||||
except Exception as e:
|
||||
logger.error(f"error prompt: {prompt}")
|
||||
raise Exception(traceback.format_exc())
|
||||
|
||||
output_message.role_content = "\n"+content
|
||||
output_message.step_content += "\n"+output_message.role_content
|
||||
yield output_message
|
||||
|
||||
if LogVerboseEnum.ge(LogVerboseEnum.Log2Level, self.log_verbose):
|
||||
logger.debug(f"{self.role.role_name}, {idx} iteration prompt: {prompt}")
|
||||
|
||||
if LogVerboseEnum.ge(LogVerboseEnum.Log1Level, self.log_verbose):
|
||||
logger.info(f"{self.role.role_name}, {idx} iteration step_run: {output_message.role_content}")
|
||||
|
||||
output_message = self.message_utils.parser(output_message)
|
||||
# stop early once a finished or stopped signal is received
|
||||
if output_message.action_status == ActionStatus.FINISHED or output_message.action_status == ActionStatus.STOPPED:
|
||||
output_message.parsed_output_list.append(output_message.parsed_output)
|
||||
break
|
||||
# route the output to the appropriate action (code execution or tool call)
|
||||
output_message, observation_message = self.message_utils.step_router(output_message)
|
||||
output_message.parsed_output_list.append(output_message.parsed_output)
|
||||
|
||||
react_message = copy.deepcopy(output_message)
|
||||
react_memory.append(react_message)
|
||||
if observation_message:
|
||||
react_memory.append(observation_message)
|
||||
output_message.parsed_output_list.append(observation_message.parsed_output)
|
||||
# logger.debug(f"{observation_message.role_name} content: {observation_message.role_content}")
|
||||
idx += 1
|
||||
step_nums -= 1
|
||||
yield output_message
|
||||
# ReAct: the agent's self memory is saved only at the end
|
||||
self.append_history(output_message)
|
||||
output_message.input_query = query.input_query
|
||||
# end_action_step; NOTE: this step may drop some information
|
||||
output_message = self.end_action_step(output_message)
|
||||
# update memory pool
|
||||
memory_manager.append(output_message)
|
||||
yield output_message
|
||||
|
||||
def pre_print(self, query: Message, history: Memory = None, background: Memory = None, memory_manager: BaseMemoryManager=None):
|
||||
react_memory = Memory(messages=[])
|
||||
prompt = self.prompt_manager.pre_print(
|
||||
previous_agent_message=query, agent_long_term_memory=self.memory, ui_history=history, chain_summary_messages=background, react_memory=react_memory,
|
||||
memory_pool=memory_manager.current_memory)
|
||||
title = f"<<<<{self.role.role_name}'s prompt>>>>"
|
||||
print("#"*len(title) + f"\n{title}\n"+ "#"*len(title)+ f"\n\n{prompt}\n\n")
|
||||
|
||||
|
|
@ -1,125 +0,0 @@
|
|||
from typing import List, Union
|
||||
import copy
|
||||
import random
|
||||
from loguru import logger
|
||||
|
||||
from langchain.schema import BaseRetriever
|
||||
|
||||
from coagent.connector.schema import (
|
||||
Memory, Task, Role, Message, PromptField, LogVerboseEnum
|
||||
)
|
||||
from coagent.connector.memory_manager import BaseMemoryManager
|
||||
from coagent.connector.memory_manager import LocalMemoryManager
|
||||
from coagent.llm_models import LLMConfig, EmbedConfig
|
||||
from coagent.base_configs.env_config import JUPYTER_WORK_PATH, KB_ROOT_PATH
|
||||
from .base_agent import BaseAgent
|
||||
|
||||
|
||||
class SelectorAgent(BaseAgent):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
role: Role,
|
||||
prompt_config: List[PromptField] = None,
|
||||
prompt_manager_type: str = "PromptManager",
|
||||
task: Task = None,
|
||||
memory: Memory = None,
|
||||
chat_turn: int = 1,
|
||||
focus_agents: List[str] = [],
|
||||
focus_message_keys: List[str] = [],
|
||||
group_agents: List[BaseAgent] = [],
|
||||
#
|
||||
llm_config: LLMConfig = None,
|
||||
embed_config: EmbedConfig = None,
|
||||
sandbox_server: dict = {},
|
||||
jupyter_work_path: str = JUPYTER_WORK_PATH,
|
||||
kb_root_path: str = KB_ROOT_PATH,
|
||||
doc_retrieval: Union[BaseRetriever] = None,
|
||||
code_retrieval = None,
|
||||
search_retrieval = None,
|
||||
log_verbose: str = "0"
|
||||
):
|
||||
|
||||
super().__init__(role, prompt_config, prompt_manager_type, task, memory, chat_turn,
|
||||
focus_agents, focus_message_keys, llm_config, embed_config, sandbox_server,
|
||||
jupyter_work_path, kb_root_path, doc_retrieval, code_retrieval, search_retrieval, log_verbose
|
||||
)
|
||||
self.group_agents = group_agents
|
||||
|
||||
def astep(self, query: Message, history: Memory = None, background: Memory = None, memory_manager: BaseMemoryManager=None) -> Message:
|
||||
'''agent response from multi-message'''
|
||||
# insert query into memory
|
||||
query_c = copy.deepcopy(query)
|
||||
query_c = self.start_action_step(query_c)
|
||||
# create your llm prompt
|
||||
if memory_manager is None:
|
||||
memory_manager = LocalMemoryManager(
|
||||
unique_name=self.role.role_name,
|
||||
do_init=True,
|
||||
kb_root_path = self.kb_root_path,
|
||||
embed_config=self.embed_config,
|
||||
llm_config=self.llm_config
|
||||
)
|
||||
memory_manager.append(query)
|
||||
memory_pool = memory_manager.get_memory_pool(query_c.user_name)
|
||||
|
||||
prompt = self.prompt_manager.generate_full_prompt(
|
||||
previous_agent_message=query_c, agent_long_term_memory=self.memory, ui_history=history, chain_summary_messages=background, react_memory=None,
|
||||
memory_pool=memory_pool, agents=self.group_agents)
|
||||
content = self.llm.predict(prompt)
|
||||
|
||||
if LogVerboseEnum.ge(LogVerboseEnum.Log2Level, self.log_verbose):
|
||||
logger.debug(f"{self.role.role_name} prompt: {prompt}")
|
||||
|
||||
if LogVerboseEnum.ge(LogVerboseEnum.Log1Level, self.log_verbose):
|
||||
logger.info(f"{self.role.role_name} content: {content}")
|
||||
|
||||
# select agent
|
||||
select_message = Message(
|
||||
role_name=self.role.role_name,
|
||||
role_type="assistant", #self.role.role_type,
|
||||
role_content=content,
|
||||
step_content=content,
|
||||
input_query=query_c.input_query,
|
||||
tools=query_c.tools,
|
||||
# parsed_output_list=[query_c.parsed_output]
|
||||
customed_kargs=query.customed_kargs
|
||||
)
|
||||
# common parse llm' content to message
|
||||
select_message = self.message_utils.parser(select_message)
|
||||
select_message.parsed_output_list.append(select_message.parsed_output)
|
||||
|
||||
output_message = None
|
||||
if select_message.parsed_output.get("Role", "") in [agent.role.role_name for agent in self.group_agents]:
|
||||
for agent in self.group_agents:
|
||||
if agent.role.role_name == select_message.parsed_output.get("Role", ""):
|
||||
break
|
||||
|
||||
# pass everything except the Role field on to the selected agent
|
||||
query_c.parsed_output.update({k:v for k,v in select_message.parsed_output.items() if k!="Role"})
|
||||
for output_message in agent.astep(query_c, history, background=background, memory_manager=memory_manager):
|
||||
yield output_message or select_message
|
||||
# update self_memory
|
||||
self.append_history(query_c)
|
||||
self.append_history(output_message)
|
||||
output_message.input_query = output_message.role_content
|
||||
# output_message.parsed_output_list.append(output_message.parsed_output)
|
||||
#
|
||||
output_message = self.end_action_step(output_message)
|
||||
# update memory pool
|
||||
memory_manager.append(output_message)
|
||||
|
||||
select_message.parsed_output = output_message.parsed_output
|
||||
select_message.spec_parsed_output.update(output_message.spec_parsed_output)
|
||||
select_message.parsed_output_list.extend(output_message.parsed_output_list)
|
||||
yield select_message
|
||||
|
||||
def pre_print(self, query: Message, history: Memory = None, background: Memory = None, memory_manager: BaseMemoryManager=None):
|
||||
prompt = self.prompt_manager.pre_print(
|
||||
previous_agent_message=query, agent_long_term_memory=self.memory, ui_history=history, chain_summary_messages=background, react_memory=None,
|
||||
memory_pool=memory_manager.current_memory, agents=self.group_agents)
|
||||
title = f"<<<<{self.role.role_name}'s prompt>>>>"
|
||||
print("#"*len(title) + f"\n{title}\n"+ "#"*len(title)+ f"\n\n{prompt}\n\n")
|
||||
|
||||
for agent in self.group_agents:
|
||||
agent.pre_print(query=query, history=history, background=background, memory_manager=memory_manager)
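# A minimal sketch of the selection output SelectorAgent.astep expects from
# the LLM, based on the "Role" key read above; the group mirrors the
# "baseGroup" config ("tool_react", "code_react") and the remaining keys are
# forwarded to the chosen agent. The extra key below is hypothetical.
#
# select_message.parsed_output = {
#     "Role": "code_react",            # must match one group agent's role_name
#     "Question": "plot the error rate of the service",
# }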
|
|
@ -1,7 +0,0 @@
|
|||
from .flow import AgentFlow, PhaseFlow, ChainFlow
|
||||
|
||||
|
||||
|
||||
__all__ = [
|
||||
"AgentFlow", "PhaseFlow", "ChainFlow"
|
||||
]
|
|
@ -1,255 +0,0 @@
|
|||
import importlib
|
||||
from typing import List, Union, Dict, Any
|
||||
from loguru import logger
|
||||
import os
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.agents import Tool
|
||||
from langchain.llms.base import BaseLLM, LLM
|
||||
|
||||
from coagent.retrieval.base_retrieval import IMRertrieval
|
||||
from coagent.llm_models.llm_config import EmbedConfig, LLMConfig
|
||||
from coagent.connector.phase import BasePhase
|
||||
from coagent.connector.agents import BaseAgent
|
||||
from coagent.connector.chains import BaseChain
|
||||
from coagent.connector.schema import Message, Role, PromptField, ChainConfig
|
||||
from coagent.tools import toLangchainTools, TOOL_DICT, TOOL_SETS
|
||||
|
||||
|
||||
class AgentFlow:
|
||||
def __init__(
|
||||
self,
|
||||
role_name: str,
|
||||
agent_type: str,
|
||||
role_type: str = "assistant",
|
||||
agent_index: int = 0,
|
||||
role_prompt: str = "",
|
||||
prompt_config: List[Dict[str, Any]] = [],
|
||||
prompt_manager_type: str = "PromptManager",
|
||||
chat_turn: int = 3,
|
||||
focus_agents: List[str] = [],
|
||||
focus_messages: List[str] = [],
|
||||
embeddings: Embeddings = None,
|
||||
llm: BaseLLM = None,
|
||||
doc_retrieval: IMRertrieval = None,
|
||||
code_retrieval: IMRertrieval = None,
|
||||
search_retrieval: IMRertrieval = None,
|
||||
**kwargs
|
||||
):
|
||||
self.role_type = role_type
|
||||
self.role_name = role_name
|
||||
self.agent_type = agent_type
|
||||
self.role_prompt = role_prompt
|
||||
self.agent_index = agent_index
|
||||
|
||||
self.prompt_config = prompt_config
|
||||
self.prompt_manager_type = prompt_manager_type
|
||||
|
||||
self.chat_turn = chat_turn
|
||||
self.focus_agents = focus_agents
|
||||
self.focus_messages = focus_messages
|
||||
|
||||
self.embeddings = embeddings
|
||||
self.llm = llm
|
||||
self.doc_retrieval = doc_retrieval
|
||||
self.code_retrieval = code_retrieval
|
||||
self.search_retrieval = search_retrieval
|
||||
# self.build_config()
|
||||
# self.build_agent()
|
||||
|
||||
def build_config(self, embeddings: Embeddings = None, llm: BaseLLM = None):
|
||||
self.llm_config = LLMConfig(model_name="test", llm=self.llm or llm)
|
||||
self.embed_config = EmbedConfig(embed_model="test", langchain_embeddings=self.embeddings or embeddings)
|
||||
|
||||
def build_agent(self,
|
||||
embeddings: Embeddings = None, llm: BaseLLM = None,
|
||||
doc_retrieval: IMRertrieval = None,
|
||||
code_retrieval: IMRertrieval = None,
|
||||
search_retrieval: IMRertrieval = None,
|
||||
):
|
||||
# a customized agent can be registered by overriding only start_action_step and end_action_step
|
||||
# class ExtraAgent(BaseAgent):
|
||||
# def start_action_step(self, message: Message) -> Message:
|
||||
# pass
|
||||
|
||||
# def end_action_step(self, message: Message) -> Message:
|
||||
# pass
|
||||
# agent_module = importlib.import_module("coagent.connector.agents")
|
||||
# setattr(agent_module, 'extraAgent', ExtraAgent)
|
||||
|
||||
# a customized prompt-assembly manager can also be registered
|
||||
# class CodeRetrievalPM(PromptManager):
|
||||
# def handle_code_packages(self, **kwargs) -> str:
|
||||
# if 'previous_agent_message' not in kwargs:
|
||||
# return ""
|
||||
# previous_agent_message: Message = kwargs['previous_agent_message']
|
||||
# # temporary handling because the two agents share the same manager
|
||||
# vertices = previous_agent_message.customed_kargs.get("RelatedVerticesRetrivalRes", {}).get("vertices", [])
|
||||
# return ", ".join([str(v) for v in vertices])
|
||||
|
||||
# prompt_manager_module = importlib.import_module("coagent.connector.prompt_manager")
|
||||
# setattr(prompt_manager_module, 'CodeRetrievalPM', CodeRetrievalPM)
|
||||
|
||||
# instantiate the agent
|
||||
agent_module = importlib.import_module("coagent.connector.agents")
|
||||
baseAgent: BaseAgent = getattr(agent_module, self.agent_type)
|
||||
role = Role(
|
||||
role_type=self.agent_type, role_name=self.role_name,
|
||||
agent_type=self.agent_type, role_prompt=self.role_prompt,
|
||||
)
|
||||
|
||||
self.build_config(embeddings, llm)
|
||||
self.agent = baseAgent(
|
||||
role=role,
|
||||
prompt_config = [PromptField(**config) for config in self.prompt_config],
|
||||
prompt_manager_type=self.prompt_manager_type,
|
||||
chat_turn=self.chat_turn,
|
||||
focus_agents=self.focus_agents,
|
||||
focus_message_keys=self.focus_messages,
|
||||
llm_config=self.llm_config,
|
||||
embed_config=self.embed_config,
|
||||
doc_retrieval=doc_retrieval or self.doc_retrieval,
|
||||
code_retrieval=code_retrieval or self.code_retrieval,
|
||||
search_retrieval=search_retrieval or self.search_retrieval,
|
||||
)
|
||||
|
||||
|
||||
|
||||
class ChainFlow:
|
||||
def __init__(
|
||||
self,
|
||||
chain_name: str,
|
||||
chain_index: int = 0,
|
||||
agent_flows: List[AgentFlow] = [],
|
||||
chat_turn: int = 5,
|
||||
do_checker: bool = False,
|
||||
embeddings: Embeddings = None,
|
||||
llm: BaseLLM = None,
|
||||
doc_retrieval: IMRertrieval = None,
|
||||
code_retrieval: IMRertrieval = None,
|
||||
search_retrieval: IMRertrieval = None,
|
||||
# chain_type: str = "BaseChain",
|
||||
**kwargs
|
||||
):
|
||||
self.agent_flows = sorted(agent_flows, key=lambda x:x.agent_index)
|
||||
self.chat_turn = chat_turn
|
||||
self.do_checker = do_checker
|
||||
self.chain_name = chain_name
|
||||
self.chain_index = chain_index
|
||||
self.chain_type = "BaseChain"
|
||||
|
||||
self.embeddings = embeddings
|
||||
self.llm = llm
|
||||
|
||||
self.doc_retrieval = doc_retrieval
|
||||
self.code_retrieval = code_retrieval
|
||||
self.search_retrieval = search_retrieval
|
||||
# self.build_config()
|
||||
# self.build_chain()
|
||||
|
||||
def build_config(self, embeddings: Embeddings = None, llm: BaseLLM = None):
|
||||
self.llm_config = LLMConfig(model_name="test", llm=self.llm or llm)
|
||||
self.embed_config = EmbedConfig(embed_model="test", langchain_embeddings=self.embeddings or embeddings)
|
||||
|
||||
def build_chain(self,
|
||||
embeddings: Embeddings = None, llm: BaseLLM = None,
|
||||
doc_retrieval: IMRertrieval = None,
|
||||
code_retrieval: IMRertrieval = None,
|
||||
search_retrieval: IMRertrieval = None,
|
||||
):
|
||||
# instantiate the chain
|
||||
chain_module = importlib.import_module("coagent.connector.chains")
|
||||
baseChain: BaseChain = getattr(chain_module, self.chain_type)
|
||||
|
||||
agent_names = [agent_flow.role_name for agent_flow in self.agent_flows]
|
||||
chain_config = ChainConfig(chain_name=self.chain_name, agents=agent_names, do_checker=self.do_checker, chat_turn=self.chat_turn)
|
||||
|
||||
# instantiate the agents
|
||||
self.build_config(embeddings, llm)
|
||||
for agent_flow in self.agent_flows:
|
||||
agent_flow.build_agent(embeddings, llm)
|
||||
|
||||
self.chain = baseChain(
|
||||
chain_config,
|
||||
[agent_flow.agent for agent_flow in self.agent_flows],
|
||||
embed_config=self.embed_config,
|
||||
llm_config=self.llm_config,
|
||||
doc_retrieval=doc_retrieval or self.doc_retrieval,
|
||||
code_retrieval=code_retrieval or self.code_retrieval,
|
||||
search_retrieval=search_retrieval or self.search_retrieval,
|
||||
)
|
||||
|
||||
class PhaseFlow:
|
||||
def __init__(
|
||||
self,
|
||||
phase_name: str,
|
||||
chain_flows: List[ChainFlow],
|
||||
embeddings: Embeddings = None,
|
||||
llm: BaseLLM = None,
|
||||
tools: List[Tool] = [],
|
||||
doc_retrieval: IMRertrieval = None,
|
||||
code_retrieval: IMRertrieval = None,
|
||||
search_retrieval: IMRertrieval = None,
|
||||
**kwargs
|
||||
):
|
||||
self.phase_name = phase_name
|
||||
self.chain_flows = sorted(chain_flows, key=lambda x:x.chain_index)
|
||||
self.phase_type = "BasePhase"
|
||||
self.tools = tools
|
||||
|
||||
self.embeddings = embeddings
|
||||
self.llm = llm
|
||||
|
||||
self.doc_retrieval = doc_retrieval
|
||||
self.code_retrieval = code_retrieval
|
||||
self.search_retrieval = search_retrieval
|
||||
# self.build_config()
|
||||
self.build_phase()
|
||||
|
||||
def __call__(self, params: dict) -> str:
|
||||
|
||||
# tools = toLangchainTools([TOOL_DICT[i] for i in TOOL_SETS if i in TOOL_DICT])
|
||||
# query_content = "帮我确认下127.0.0.1这个服务器的在10点是否存在异常,请帮我判断一下"
|
||||
try:
|
||||
logger.info(f"params: {params}")
|
||||
query_content = params.get("query") or params.get("input")
|
||||
search_type = params.get("search_type")
|
||||
query = Message(
|
||||
role_name="human", role_type="user", tools=self.tools,
|
||||
role_content=query_content, input_query=query_content, origin_query=query_content,
|
||||
cb_search_type=search_type,
|
||||
)
|
||||
# phase.pre_print(query)
|
||||
output_message, output_memory = self.phase.step(query)
|
||||
output_content = "\n\n".join((output_memory.to_str_messages(return_all=True, content_key="parsed_output_list").split("\n\n")[1:])) or output_message.role_content
|
||||
return output_content
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
return f"Error {e}"
|
||||
|
||||
def build_config(self, embeddings: Embeddings = None, llm: BaseLLM = None):
|
||||
self.llm_config = LLMConfig(model_name="test", llm=self.llm or llm)
|
||||
self.embed_config = EmbedConfig(embed_model="test", langchain_embeddings=self.embeddings or embeddings)
|
||||
|
||||
def build_phase(self, embeddings: Embeddings = None, llm: BaseLLM = None):
|
||||
# instantiate the phase
|
||||
phase_module = importlib.import_module("coagent.connector.phase")
|
||||
basePhase: BasePhase = getattr(phase_module, self.phase_type)
|
||||
|
||||
# instantiate the chains
|
||||
self.build_config(self.embeddings or embeddings, self.llm or llm)
|
||||
os.environ["log_verbose"] = "2"
|
||||
for chain_flow in self.chain_flows:
|
||||
chain_flow.build_chain(
|
||||
self.embeddings or embeddings, self.llm or llm,
|
||||
self.doc_retrieval, self.code_retrieval, self.search_retrieval
|
||||
)
|
||||
|
||||
self.phase: BasePhase = basePhase(
|
||||
phase_name=self.phase_name,
|
||||
chains=[chain_flow.chain for chain_flow in self.chain_flows],
|
||||
embed_config=self.embed_config,
|
||||
llm_config=self.llm_config,
|
||||
doc_retrieval=self.doc_retrieval,
|
||||
code_retrieval=self.code_retrieval,
|
||||
search_retrieval=self.search_retrieval
|
||||
)
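# A minimal usage sketch of the flow classes above, assuming an existing
# LangChain llm/embeddings pair (`my_llm`, `my_embeddings`); the role_prompt
# and prompt_config values are placeholders, not shipped defaults.
#
# tool_agent = AgentFlow(
#     role_name="tool_react", agent_type="ReactAgent",
#     role_prompt="<your role prompt>", prompt_config=[], agent_index=0,
# )
# chain = ChainFlow(chain_name="toolChain", chain_index=0, agent_flows=[tool_agent])
# phase = PhaseFlow(phase_name="toolPhase", chain_flows=[chain],
#                   llm=my_llm, embeddings=my_embeddings, tools=[])
# print(phase({"query": "check whether the server looks healthy"}))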
|
|
@ -1,5 +0,0 @@
|
|||
from .base_chain import BaseChain
|
||||
|
||||
__all__ = [
|
||||
"BaseChain"
|
||||
]
|
|
@ -1,130 +0,0 @@
|
|||
from typing import List, Tuple, Union
|
||||
from loguru import logger
|
||||
import copy, os
|
||||
|
||||
from langchain.schema import BaseRetriever
|
||||
|
||||
from coagent.connector.agents import BaseAgent
|
||||
from coagent.connector.schema import (
|
||||
Memory, Role, Message, ActionStatus, ChainConfig,
|
||||
load_role_configs
|
||||
)
|
||||
from coagent.connector.memory_manager import BaseMemoryManager
|
||||
from coagent.connector.message_process import MessageUtils
|
||||
from coagent.llm_models.llm_config import LLMConfig, EmbedConfig
|
||||
from coagent.base_configs.env_config import JUPYTER_WORK_PATH, KB_ROOT_PATH
|
||||
from coagent.connector.configs.agent_config import AGETN_CONFIGS
|
||||
role_configs = load_role_configs(AGETN_CONFIGS)
|
||||
|
||||
|
||||
class BaseChain:
|
||||
def __init__(
|
||||
self,
|
||||
chainConfig: ChainConfig,
|
||||
agents: List[BaseAgent],
|
||||
# chat_turn: int = 1,
|
||||
# do_checker: bool = False,
|
||||
sandbox_server: dict = {},
|
||||
jupyter_work_path: str = JUPYTER_WORK_PATH,
|
||||
kb_root_path: str = KB_ROOT_PATH,
|
||||
llm_config: LLMConfig = LLMConfig(),
|
||||
embed_config: EmbedConfig = None,
|
||||
doc_retrieval: Union[BaseRetriever] = None,
|
||||
code_retrieval = None,
|
||||
search_retrieval = None,
|
||||
log_verbose: str = "0"
|
||||
) -> None:
|
||||
self.chainConfig = chainConfig
|
||||
self.agents: List[BaseAgent] = agents
|
||||
self.chat_turn = chainConfig.chat_turn
|
||||
self.do_checker = chainConfig.do_checker
|
||||
self.sandbox_server = sandbox_server
|
||||
self.jupyter_work_path = jupyter_work_path
|
||||
self.llm_config = llm_config
|
||||
self.log_verbose = max(os.environ.get("log_verbose", "0"), log_verbose)
|
||||
self.checker = BaseAgent(role=role_configs["checker"].role,
|
||||
prompt_config=role_configs["checker"].prompt_config,
|
||||
task = None, memory = None,
|
||||
llm_config=llm_config, embed_config=embed_config,
|
||||
sandbox_server=sandbox_server, jupyter_work_path=jupyter_work_path,
|
||||
kb_root_path=kb_root_path,
|
||||
doc_retrieval=doc_retrieval, code_retrieval=code_retrieval,
|
||||
search_retrieval=search_retrieval
|
||||
)
|
||||
self.messageUtils = MessageUtils(None, sandbox_server, self.jupyter_work_path, embed_config, llm_config, kb_root_path, doc_retrieval, code_retrieval, search_retrieval, log_verbose)
|
||||
# all memory created by the agents, kept until this instance is deleted
|
||||
self.global_memory = Memory(messages=[])
|
||||
|
||||
def step(self, query: Message, history: Memory = None, background: Memory = None, memory_manager: BaseMemoryManager = None) -> Message:
|
||||
'''execute chain'''
|
||||
for output_message, local_memory in self.astep(query, history, background, memory_manager):
|
||||
pass
|
||||
return output_message, local_memory
|
||||
|
||||
def pre_print(self, query: Message, history: Memory = None, background: Memory = None, memory_manager: BaseMemoryManager = None) -> Message:
|
||||
'''execute chain'''
|
||||
for agent in self.agents:
|
||||
agent.pre_print(query, history, background=background, memory_manager=memory_manager)
|
||||
|
||||
def astep(self, query: Message, history: Memory = None, background: Memory = None, memory_manager: BaseMemoryManager = None) -> Tuple[Message, Memory]:
|
||||
'''execute chain'''
|
||||
local_memory = Memory(messages=[])
|
||||
input_message = copy.deepcopy(query)
|
||||
step_nums = copy.deepcopy(self.chat_turn)
|
||||
check_message = None
|
||||
|
||||
# if input_message not in memory_manager:
|
||||
# memory_manager.append(input_message)
|
||||
|
||||
self.global_memory.append(input_message)
|
||||
# local_memory.append(input_message)
|
||||
while step_nums > 0:
|
||||
for agent in self.agents:
|
||||
for output_message in agent.astep(input_message, history, background=background, memory_manager=memory_manager):
|
||||
# logger.debug(f"local_memory {local_memory + output_message}")
|
||||
yield output_message, local_memory + output_message
|
||||
output_message = self.messageUtils.inherit_extrainfo(input_message, output_message)
|
||||
# route the output to the appropriate action (code execution or tool call)
|
||||
# output_message = self.messageUtils.parser(output_message)
|
||||
yield output_message, local_memory + output_message
|
||||
# output_message = self.step_router(output_message)
|
||||
input_message = output_message
|
||||
self.global_memory.append(output_message)
|
||||
|
||||
local_memory.append(output_message)
|
||||
# stop early once a finished or stopped signal is received
|
||||
if output_message.action_status == ActionStatus.FINISHED or output_message.action_status == ActionStatus.STOPPED:
|
||||
action_status = False
|
||||
break
|
||||
if output_message.action_status == ActionStatus.FINISHED:
|
||||
break
|
||||
|
||||
if self.do_checker and self.chat_turn > 1:
|
||||
for check_message in self.checker.astep(query, background=local_memory, memory_manager=memory_manager):
|
||||
pass
|
||||
check_message = self.messageUtils.parser(check_message)
|
||||
check_message = self.messageUtils.inherit_extrainfo(output_message, check_message)
|
||||
# logger.debug(f"{self.checker.role.role_name}: {check_message.role_content}")
|
||||
|
||||
if check_message.action_status == ActionStatus.FINISHED:
|
||||
self.global_memory.append(check_message)
|
||||
break
|
||||
step_nums -= 1
|
||||
#
|
||||
output_message = check_message or output_message # return the checker's result if present, otherwise the chain's
|
||||
output_message.input_query = query.input_query # keep the original question unchanged when passing messages between chains
|
||||
yield output_message, local_memory
|
||||
|
||||
def get_memory(self, content_key="role_content") -> Memory:
|
||||
memory = self.global_memory
|
||||
return memory.to_tuple_messages(content_key=content_key)
|
||||
|
||||
def get_memory_str(self, content_key="role_content") -> Memory:
|
||||
memory = self.global_memory
|
||||
return "\n".join([": ".join(i) for i in memory.to_tuple_messages(content_key=content_key)])
|
||||
|
||||
def get_agents_memory(self, content_key="role_content"):
|
||||
return [agent.get_memory(content_key=content_key) for agent in self.agents]
|
||||
|
||||
def get_agents_memory_str(self, content_key="role_content"):
|
||||
return "************".join([f"{agent.role.role_name}\n" + agent.get_memory_str(content_key=content_key) for agent in self.agents])
|
|
@ -1,12 +0,0 @@
|
|||
from typing import List
|
||||
from loguru import logger
|
||||
from coagent.connector.agents import BaseAgent
|
||||
from .base_chain import BaseChain
|
||||
|
||||
|
||||
|
||||
|
||||
class ExecutorRefineChain(BaseChain):
|
||||
|
||||
def __init__(self, agents: List[BaseAgent], do_code_exec: bool = False) -> None:
|
||||
super().__init__(agents, do_code_exec)
|
|
@ -1,10 +0,0 @@
|
|||
from .agent_config import AGETN_CONFIGS
|
||||
from .chain_config import CHAIN_CONFIGS
|
||||
from .phase_config import PHASE_CONFIGS
|
||||
from .prompt_config import *
|
||||
|
||||
__all__ = [
|
||||
"AGETN_CONFIGS", "CHAIN_CONFIGS", "PHASE_CONFIGS",
|
||||
"BASE_PROMPT_CONFIGS", "EXECUTOR_PROMPT_CONFIGS", "SELECTOR_PROMPT_CONFIGS", "BASE_NOTOOLPROMPT_CONFIGS",
|
||||
"CODE2DOC_GROUP_PROMPT_CONFIGS", "CODE2DOC_PROMPT_CONFIGS", "CODE2TESTS_PROMPT_CONFIGS"
|
||||
]
|
|
@ -1,330 +0,0 @@
|
|||
from enum import Enum
|
||||
from .prompts import *
|
||||
# from .prompts import (
|
||||
# REACT_PROMPT_INPUT, CHECK_PROMPT_INPUT, EXECUTOR_PROMPT_INPUT, CONTEXT_PROMPT_INPUT, QUERY_CONTEXT_PROMPT_INPUT,PLAN_PROMPT_INPUT,
|
||||
# RECOGNIZE_INTENTION_PROMPT,
|
||||
# CHECKER_TEMPLATE_PROMPT,
|
||||
# CONV_SUMMARY_PROMPT,
|
||||
# QA_PROMPT, CODE_QA_PROMPT, QA_TEMPLATE_PROMPT,
|
||||
# EXECUTOR_TEMPLATE_PROMPT,
|
||||
# REFINE_TEMPLATE_PROMPT,
|
||||
# SELECTOR_AGENT_TEMPLATE_PROMPT,
|
||||
# PLANNER_TEMPLATE_PROMPT, GENERAL_PLANNER_PROMPT, DATA_PLANNER_PROMPT, TOOL_PLANNER_PROMPT,
|
||||
# PRD_WRITER_METAGPT_PROMPT, DESIGN_WRITER_METAGPT_PROMPT, TASK_WRITER_METAGPT_PROMPT, CODE_WRITER_METAGPT_PROMPT,
|
||||
# REACT_TEMPLATE_PROMPT,
|
||||
# REACT_TOOL_PROMPT, REACT_CODE_PROMPT, REACT_TOOL_AND_CODE_PLANNER_PROMPT, REACT_TOOL_AND_CODE_PROMPT
|
||||
# )
|
||||
from .prompt_config import *
|
||||
# BASE_PROMPT_CONFIGS, EXECUTOR_PROMPT_CONFIGS, SELECTOR_PROMPT_CONFIGS, BASE_NOTOOLPROMPT_CONFIGS
|
||||
|
||||
|
||||
|
||||
class AgentType:
|
||||
REACT = "ReactAgent"
|
||||
EXECUTOR = "ExecutorAgent"
|
||||
ONE_STEP = "BaseAgent"
|
||||
DEFAULT = "BaseAgent"
|
||||
SELECTOR = "SelectorAgent"
|
||||
|
||||
|
||||
|
||||
AGETN_CONFIGS = {
|
||||
"baseGroup": {
|
||||
"role": {
|
||||
"role_prompt": SELECTOR_AGENT_TEMPLATE_PROMPT,
|
||||
"role_type": "assistant",
|
||||
"role_name": "baseGroup",
|
||||
"role_desc": "",
|
||||
"agent_type": "SelectorAgent"
|
||||
},
|
||||
"prompt_config": SELECTOR_PROMPT_CONFIGS,
|
||||
"group_agents": ["tool_react", "code_react"],
|
||||
"chat_turn": 1,
|
||||
},
|
||||
"checker": {
|
||||
"role": {
|
||||
"role_prompt": CHECKER_TEMPLATE_PROMPT,
|
||||
"role_type": "assistant",
|
||||
"role_name": "checker",
|
||||
"role_desc": "",
|
||||
"agent_type": "BaseAgent"
|
||||
},
|
||||
"prompt_config": BASE_PROMPT_CONFIGS,
|
||||
"chat_turn": 1,
|
||||
},
|
||||
"conv_summary": {
|
||||
"role": {
|
||||
"role_prompt": CONV_SUMMARY_PROMPT,
|
||||
"role_type": "assistant",
|
||||
"role_name": "conv_summary",
|
||||
"role_desc": "",
|
||||
"agent_type": "BaseAgent"
|
||||
},
|
||||
"prompt_config": BASE_PROMPT_CONFIGS,
|
||||
"chat_turn": 1,
|
||||
},
|
||||
"general_planner": {
|
||||
"role": {
|
||||
"role_prompt": PLANNER_TEMPLATE_PROMPT,
|
||||
"role_type": "assistant",
|
||||
"role_name": "general_planner",
|
||||
"role_desc": "",
|
||||
"agent_type": "BaseAgent"
|
||||
},
|
||||
"prompt_config": BASE_PROMPT_CONFIGS,
|
||||
"chat_turn": 1,
|
||||
},
|
||||
"executor": {
|
||||
"role": {
|
||||
"role_prompt": EXECUTOR_TEMPLATE_PROMPT,
|
||||
"role_type": "assistant",
|
||||
"role_name": "executor",
|
||||
"role_desc": "",
|
||||
"agent_type": "ExecutorAgent",
|
||||
},
|
||||
"prompt_config": EXECUTOR_PROMPT_CONFIGS,
|
||||
"stop": "\n**Observation:**",
|
||||
"chat_turn": 1,
|
||||
},
|
||||
"base_refiner": {
|
||||
"role": {
|
||||
"role_prompt": REFINE_TEMPLATE_PROMPT,
|
||||
"role_type": "assistant",
|
||||
"role_name": "base_refiner",
|
||||
"role_desc": "",
|
||||
"agent_type": "BaseAgent"
|
||||
},
|
||||
"prompt_config": BASE_PROMPT_CONFIGS,
|
||||
"chat_turn": 1,
|
||||
},
|
||||
"planner": {
|
||||
"role": {
|
||||
"role_prompt": DATA_PLANNER_PROMPT,
|
||||
"role_type": "assistant",
|
||||
"role_name": "planner",
|
||||
"role_desc": "",
|
||||
"agent_type": "BaseAgent"
|
||||
},
|
||||
"prompt_config": BASE_PROMPT_CONFIGS,
|
||||
"chat_turn": 1,
|
||||
},
|
||||
"intention_recognizer": {
|
||||
"role": {
|
||||
"role_prompt": RECOGNIZE_INTENTION_PROMPT,
|
||||
"role_type": "assistant",
|
||||
"role_name": "intention_recognizer",
|
||||
"role_desc": "",
|
||||
"agent_type": "BaseAgent"
|
||||
},
|
||||
"prompt_config": BASE_PROMPT_CONFIGS,
|
||||
"chat_turn": 1,
|
||||
},
|
||||
"tool_planner": {
|
||||
"role": {
|
||||
"role_prompt": TOOL_PLANNER_PROMPT,
|
||||
"role_type": "assistant",
|
||||
"role_name": "tool_planner",
|
||||
"role_desc": "",
|
||||
"agent_type": "BaseAgent"
|
||||
},
|
||||
"prompt_config": BASE_PROMPT_CONFIGS,
|
||||
"chat_turn": 1,
|
||||
},
|
||||
"tool_and_code_react": {
|
||||
"role": {
|
||||
"role_prompt": REACT_TOOL_AND_CODE_PROMPT,
|
||||
"role_type": "assistant",
|
||||
"role_name": "tool_and_code_react",
|
||||
"role_desc": "",
|
||||
"agent_type": "ReactAgent",
|
||||
},
|
||||
"prompt_config": BASE_PROMPT_CONFIGS,
|
||||
"stop": "\n**Observation:**",
|
||||
"chat_turn": 7,
|
||||
},
|
||||
"tool_and_code_planner": {
|
||||
"role": {
|
||||
"role_prompt": REACT_TOOL_AND_CODE_PLANNER_PROMPT,
|
||||
"role_type": "assistant",
|
||||
"role_name": "tool_and_code_planner",
|
||||
"role_desc": "",
|
||||
"agent_type": "BaseAgent"
|
||||
},
|
||||
"prompt_config": BASE_PROMPT_CONFIGS,
|
||||
"chat_turn": 1,
|
||||
},
|
||||
"tool_react": {
|
||||
"role": {
|
||||
"role_prompt": REACT_TOOL_PROMPT,
|
||||
"role_type": "assistant",
|
||||
"role_name": "tool_react",
|
||||
"role_desc": "",
|
||||
"agent_type": "ReactAgent"
|
||||
},
|
||||
"prompt_config": BASE_PROMPT_CONFIGS,
|
||||
"chat_turn": 5,
|
||||
"stop": "\n**Observation:**"
|
||||
},
|
||||
"code_react": {
|
||||
"role": {
|
||||
"role_prompt": REACT_CODE_PROMPT,
|
||||
"role_type": "assistant",
|
||||
"role_name": "code_react",
|
||||
"role_desc": "",
|
||||
"agent_type": "ReactAgent"
|
||||
},
|
||||
"prompt_config": BASE_NOTOOLPROMPT_CONFIGS,
|
||||
"chat_turn": 5,
|
||||
"stop": "\n**Observation:**"
|
||||
},
|
||||
"qaer": {
|
||||
"role": {
|
||||
"role_prompt": QA_TEMPLATE_PROMPT,
|
||||
"role_type": "assistant",
|
||||
"role_name": "qaer",
|
||||
"role_desc": "",
|
||||
"agent_type": "BaseAgent"
|
||||
},
|
||||
"prompt_config": BASE_PROMPT_CONFIGS,
|
||||
"chat_turn": 1,
|
||||
},
|
||||
"code_qaer": {
|
||||
"role": {
|
||||
"role_prompt": CODE_QA_PROMPT,
|
||||
"role_type": "assistant",
|
||||
"role_name": "code_qaer",
|
||||
"role_desc": "",
|
||||
"agent_type": "BaseAgent"
|
||||
},
|
||||
"prompt_config": BASE_PROMPT_CONFIGS,
|
||||
"chat_turn": 1,
|
||||
},
|
||||
"searcher": {
|
||||
"role": {
|
||||
"role_prompt": QA_TEMPLATE_PROMPT,
|
||||
"role_type": "assistant",
|
||||
"role_name": "searcher",
|
||||
"role_desc": "",
|
||||
"agent_type": "BaseAgent"
|
||||
},
|
||||
"prompt_config": BASE_PROMPT_CONFIGS,
|
||||
"chat_turn": 1,
|
||||
},
|
||||
"metaGPT_PRD": {
|
||||
"role": {
|
||||
"role_prompt": PRD_WRITER_METAGPT_PROMPT,
|
||||
"role_type": "assistant",
|
||||
"role_name": "metaGPT_PRD",
|
||||
"role_desc": "",
|
||||
"agent_type": "BaseAgent"
|
||||
},
|
||||
"prompt_config": BASE_PROMPT_CONFIGS,
|
||||
"chat_turn": 1,
|
||||
"focus_agents": [],
|
||||
"focus_message_keys": [],
|
||||
},
|
||||
|
||||
"metaGPT_DESIGN": {
|
||||
"role": {
|
||||
"role_prompt": DESIGN_WRITER_METAGPT_PROMPT,
|
||||
"role_type": "assistant",
|
||||
"role_name": "metaGPT_DESIGN",
|
||||
"role_desc": "",
|
||||
"agent_type": "BaseAgent"
|
||||
},
|
||||
"prompt_config": BASE_PROMPT_CONFIGS,
|
||||
"chat_turn": 1,
|
||||
"focus_agents": ["metaGPT_PRD"],
|
||||
"focus_message_keys": [],
|
||||
},
|
||||
"metaGPT_TASK": {
|
||||
"role": {
|
||||
"role_prompt": TASK_WRITER_METAGPT_PROMPT,
|
||||
"role_type": "assistant",
|
||||
"role_name": "metaGPT_TASK",
|
||||
"role_desc": "",
|
||||
"agent_type": "BaseAgent"
|
||||
},
|
||||
"prompt_config": BASE_PROMPT_CONFIGS,
|
||||
"chat_turn": 1,
|
||||
"focus_agents": ["metaGPT_DESIGN"],
|
||||
"focus_message_keys": [],
|
||||
},
|
||||
"metaGPT_CODER": {
|
||||
"role": {
|
||||
"role_prompt": CODE_WRITER_METAGPT_PROMPT,
|
||||
"role_type": "assistant",
|
||||
"role_name": "metaGPT_CODER",
|
||||
"role_desc": "",
|
||||
"agent_type": "ExecutorAgent"
|
||||
},
|
||||
"prompt_config": EXECUTOR_PROMPT_CONFIGS,
|
||||
"chat_turn": 1,
|
||||
"focus_agents": ["metaGPT_DESIGN", "metaGPT_TASK"],
|
||||
"focus_message_keys": [],
|
||||
},
|
||||
"class2Docer": {
|
||||
"role": {
|
||||
"role_prompt": Class2Doc_PROMPT,
|
||||
"role_type": "assistant",
|
||||
"role_name": "class2Docer",
|
||||
"role_desc": "",
|
||||
"agent_type": "CodeGenDocer"
|
||||
},
|
||||
"prompt_config": CODE2DOC_PROMPT_CONFIGS,
|
||||
"prompt_manager_type": "Code2DocPM",
|
||||
"chat_turn": 1,
|
||||
"focus_agents": [],
|
||||
"focus_message_keys": [],
|
||||
},
|
||||
"func2Docer": {
|
||||
"role": {
|
||||
"role_prompt": Func2Doc_PROMPT,
|
||||
"role_type": "assistant",
|
||||
"role_name": "func2Docer",
|
||||
"role_desc": "",
|
||||
"agent_type": "CodeGenDocer"
|
||||
},
|
||||
"prompt_config": CODE2DOC_PROMPT_CONFIGS,
|
||||
"prompt_manager_type": "Code2DocPM",
|
||||
"chat_turn": 1,
|
||||
"focus_agents": [],
|
||||
"focus_message_keys": [],
|
||||
},
|
||||
"code2DocsGrouper": {
|
||||
"role": {
|
||||
"role_prompt": Code2DocGroup_PROMPT,
|
||||
"role_type": "assistant",
|
||||
"role_name": "code2DocsGrouper",
|
||||
"role_desc": "",
|
||||
"agent_type": "SelectorAgent"
|
||||
},
|
||||
"prompt_config": CODE2DOC_GROUP_PROMPT_CONFIGS,
|
||||
"group_agents": ["class2Docer", "func2Docer"],
|
||||
"chat_turn": 1,
|
||||
},
|
||||
"Code2TestJudger": {
|
||||
"role": {
|
||||
"role_prompt": judgeCode2Tests_PROMPT,
|
||||
"role_type": "assistant",
|
||||
"role_name": "Code2TestJudger",
|
||||
"role_desc": "",
|
||||
"agent_type": "CodeRetrieval"
|
||||
},
|
||||
"prompt_config": CODE2TESTS_PROMPT_CONFIGS,
|
||||
"prompt_manager_type": "CodeRetrievalPM",
|
||||
"chat_turn": 1,
|
||||
},
|
||||
"code2Tests": {
|
||||
"role": {
|
||||
"role_prompt": code2Tests_PROMPT,
|
||||
"role_type": "assistant",
|
||||
"role_name": "code2Tests",
|
||||
"role_desc": "",
|
||||
"agent_type": "CodeRetrieval"
|
||||
},
|
||||
"prompt_config": CODE2TESTS_PROMPT_CONFIGS,
|
||||
"prompt_manager_type": "CodeRetrievalPM",
|
||||
"chat_turn": 1,
|
||||
},
|
||||
}
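# A minimal sketch of how these configs are consumed elsewhere in the package
# (see base_chain.py above): load_role_configs turns the dict into role and
# prompt objects keyed by agent name.
#
# from coagent.connector.schema import load_role_configs
# role_configs = load_role_configs(AGETN_CONFIGS)
# checker_role = role_configs["checker"].role
# checker_prompt_config = role_configs["checker"].prompt_config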
|
|
@ -1,99 +0,0 @@
|
|||
You are an Architect, named Bob, your goal is Design a concise, usable, complete python system, and the constraint is Try to specify good open source tools as much as possible.
|
||||
|
||||
# Context
|
||||
## Original Requirements:
|
||||
Create a snake game.
|
||||
|
||||
## Product Goals:
|
||||
Develop a highly addictive and engaging snake game.
|
||||
Provide a user-friendly and intuitive user interface.
|
||||
Implement various levels and challenges to keep the players entertained.
|
||||
## User Stories:
|
||||
As a user, I want to be able to control the snake's movement using arrow keys or touch gestures.
|
||||
As a user, I want to see my score and progress displayed on the screen.
|
||||
As a user, I want to be able to pause and resume the game at any time.
|
||||
As a user, I want to be challenged with different obstacles and levels as I progress.
|
||||
As a user, I want to have the option to compete with other players and compare my scores.
|
||||
## Competitive Analysis:
|
||||
Python Snake Game: A simple snake game implemented in Python with basic features and limited levels.
|
||||
Snake.io: A multiplayer online snake game with competitive gameplay and high engagement.
|
||||
Slither.io: Another multiplayer online snake game with a larger player base and addictive gameplay.
|
||||
Snake Zone: A mobile snake game with various power-ups and challenges.
|
||||
Snake Mania: A classic snake game with modern graphics and smooth controls.
|
||||
Snake Rush: A fast-paced snake game with time-limited challenges.
|
||||
Snake Master: A snake game with unique themes and customizable snakes.
|
||||
|
||||
## Requirement Analysis:
|
||||
The product should be a highly addictive and engaging snake game with a user-friendly interface. It should provide various levels and challenges to keep the players entertained. The game should have smooth controls and allow the users to compete with each other.
|
||||
|
||||
## Requirement Pool:
|
||||
```
|
||||
[
|
||||
["Implement different levels with increasing difficulty", "P0"],
|
||||
["Allow users to control the snake using arrow keys or touch gestures", "P0"],
|
||||
["Display the score and progress on the screen", "P1"],
|
||||
["Provide an option to pause and resume the game", "P1"],
|
||||
["Integrate leaderboards to enable competition among players", "P2"]
|
||||
]
|
||||
```
|
||||
## UI Design draft:
|
||||
The game will have a simple and clean interface. The main screen will display the snake, obstacles, and the score. The snake's movement can be controlled using arrow keys or touch gestures. There will be buttons to pause and resume the game. The level and difficulty will be indicated on the screen. The design will have a modern and visually appealing style with smooth animations.
|
||||
|
||||
## Anything UNCLEAR:
|
||||
There are no unclear points.
|
||||
|
||||
## Format example
|
||||
---
|
||||
## Implementation approach
|
||||
We will ...
|
||||
|
||||
## Python package name
|
||||
```python
|
||||
"snake_game"
|
||||
```
|
||||
|
||||
## File list
|
||||
```python
|
||||
[
|
||||
"main.py",
|
||||
]
|
||||
```
|
||||
|
||||
## Data structures and interface definitions
|
||||
```mermaid
|
||||
classDiagram
|
||||
class Game{
|
||||
+int score
|
||||
}
|
||||
...
|
||||
Game "1" -- "1" Food: has
|
||||
```
|
||||
|
||||
## Program call flow
|
||||
```mermaid
|
||||
sequenceDiagram
|
||||
participant M as Main
|
||||
...
|
||||
G->>M: end game
|
||||
```
|
||||
|
||||
## Anything UNCLEAR
|
||||
The requirement is clear to me.
|
||||
---
|
||||
-----
|
||||
Role: You are an architect; the goal is to design a SOTA PEP8-compliant python system; make the best use of good open source tools
|
||||
Requirement: Fill in the following missing information based on the context; note that each section should be returned in code form separately
|
||||
Max Output: 8192 chars or 2048 tokens. Try to use them up.
|
||||
Attention: Use '##' to split sections, not '#', and '## <SECTION_NAME>' SHOULD WRITE BEFORE the code and triple quote.
|
||||
|
||||
## Implementation approach: Provide as Plain text. Analyze the difficult points of the requirements, select the appropriate open-source framework.
|
||||
|
||||
## Python package name: Provide as Python str with python triple quoto, concise and clear, characters only use a combination of all lowercase and underscores
|
||||
|
||||
## File list: Provided as Python list[str], the list of ONLY REQUIRED files needed to write the program(LESS IS MORE!). Only need relative paths, comply with PEP8 standards. ALWAYS write a main.py or app.py here
|
||||
|
||||
## Data structures and interface definitions: Use mermaid classDiagram code syntax, including classes (INCLUDING __init__ method) and functions (with type annotations), CLEARLY MARK the RELATIONSHIPS between classes, and comply with PEP8 standards. The data structures SHOULD BE VERY DETAILED and the API should be comprehensive with a complete design.
|
||||
|
||||
## Program call flow: Use sequenceDiagram code syntax, COMPLETE and VERY DETAILED, using CLASSES AND API DEFINED ABOVE accurately, covering the CRUD AND INIT of each object, SYNTAX MUST BE CORRECT.
|
||||
|
||||
## Anything UNCLEAR: Provide as Plain text. Make clear here.
|
|
@ -1,101 +0,0 @@
|
|||
You are a Product Manager, named Alice, your goal is Efficiently create a successful product, and the constraint is .
|
||||
|
||||
# Context
|
||||
## Original Requirements
|
||||
Create a snake game.
|
||||
|
||||
## Search Information
|
||||
### Search Results
|
||||
|
||||
### Search Summary
|
||||
|
||||
## mermaid quadrantChart code syntax example. DONT USE QUOTES IN CODE DUE TO INVALID SYNTAX. Replace the <Campaign X> with REAL COMPETITOR NAME
|
||||
```mermaid
|
||||
quadrantChart
|
||||
title Reach and engagement of campaigns
|
||||
x-axis Low Reach --> High Reach
|
||||
y-axis Low Engagement --> High Engagement
|
||||
quadrant-1 We should expand
|
||||
quadrant-2 Need to promote
|
||||
quadrant-3 Re-evaluate
|
||||
quadrant-4 May be improved
|
||||
"Campaign: A": [0.3, 0.6]
|
||||
"Campaign B": [0.45, 0.23]
|
||||
"Campaign C": [0.57, 0.69]
|
||||
"Campaign D": [0.78, 0.34]
|
||||
"Campaign E": [0.40, 0.34]
|
||||
"Campaign F": [0.35, 0.78]
|
||||
"Our Target Product": [0.5, 0.6]
|
||||
```
|
||||
|
||||
## Format example
|
||||
---
|
||||
## Original Requirements
|
||||
The boss ...
|
||||
|
||||
## Product Goals
|
||||
```python
|
||||
[
|
||||
"Create a ...",
|
||||
]
|
||||
```
|
||||
|
||||
## User Stories
|
||||
```python
|
||||
[
|
||||
"As a user, ...",
|
||||
]
|
||||
```
|
||||
|
||||
## Competitive Analysis
|
||||
```python
|
||||
[
|
||||
"Python Snake Game: ...",
|
||||
]
|
||||
```
|
||||
|
||||
## Competitive Quadrant Chart
|
||||
```mermaid
|
||||
quadrantChart
|
||||
title Reach and engagement of campaigns
|
||||
...
|
||||
"Our Target Product": [0.6, 0.7]
|
||||
```
|
||||
|
||||
## Requirement Analysis
|
||||
The product should be a ...
|
||||
|
||||
## Requirement Pool
|
||||
```python
|
||||
[
|
||||
["End game ...", "P0"]
|
||||
]
|
||||
```
|
||||
|
||||
## UI Design draft
|
||||
Give a basic function description, and a draft
|
||||
|
||||
## Anything UNCLEAR
|
||||
There are no unclear points.
|
||||
---
|
||||
-----
|
||||
Role: You are a professional product manager; the goal is to design a concise, usable, efficient product
|
||||
Requirements: According to the context, fill in the following missing information; note that each section is returned separately in Python code triple-quote form. If the requirements are unclear, ensure minimum viability and avoid excessive design
|
||||
ATTENTION: Use '##' to SPLIT SECTIONS, not '#'. AND '## <SECTION_NAME>' SHOULD WRITE BEFORE the code and triple quote. Output carefully referenced "Format example" in format.
|
||||
|
||||
## Original Requirements: Provide as Plain text, place the polished complete original requirements here
|
||||
|
||||
## Product Goals: Provided as Python list[str], up to 3 clear, orthogonal product goals. If the requirement itself is simple, the goal should also be simple
|
||||
|
||||
## User Stories: Provided as Python list[str], up to 5 scenario-based user stories, If the requirement itself is simple, the user stories should also be less
|
||||
|
||||
## Competitive Analysis: Provided as Python list[str], up to 7 competitive product analyses, consider as similar competitors as possible
|
||||
|
||||
## Competitive Quadrant Chart: Use mermaid quadrantChart code syntax. up to 14 competitive products. Translation: Distribute these competitor scores evenly between 0 and 1, trying to conform to a normal distribution centered around 0.5 as much as possible.
|
||||
|
||||
## Requirement Analysis: Provide as Plain text. Be simple. LESS IS MORE. Make your requirements less dumb. Delete the parts unnessasery.
|
||||
|
||||
## Requirement Pool: Provided as Python list[list[str], the parameters are requirement description, priority(P0/P1/P2), respectively, comply with PEP standards; no more than 5 requirements and consider to make its difficulty lower
|
||||
|
||||
## UI Design draft: Provide as Plain text. Be simple. Describe the elements and functions, also provide a simple style description and layout description.
|
||||
## Anything UNCLEAR: Provide as Plain text. Make clear here.
|
|
@@ -1,177 +0,0 @@
|
|||
|
||||
NOTICE
|
||||
Role: You are a professional software engineer, and your main task is to review the code. You need to ensure that the code conforms to the PEP8 standards, is elegantly designed and modularized, easy to read and maintain, and is written in Python 3.9 (or in another programming language).
|
||||
ATTENTION: Use '##' to SPLIT SECTIONS, not '#'. Output format carefully referenced "Format example".
|
||||
|
||||
## Code Review: Based on the following context and code, and following the check list, Provide key, clear, concise, and specific code modification suggestions, up to 5.
|
||||
```
|
||||
1. Check 0: Is the code implemented as per the requirements?
|
||||
2. Check 1: Are there any issues with the code logic?
|
||||
3. Check 2: Does the existing code follow the "Data structures and interface definitions"?
|
||||
4. Check 3: Is there a function in the code that is omitted or not fully implemented that needs to be implemented?
|
||||
5. Check 4: Does the code have unnecessary or lack dependencies?
|
||||
```
|
||||
|
||||
## Rewrite Code: point.py Based on "Code Review" and the source code, rewrite code with triple quotes. Do your utmost to optimize THIS SINGLE FILE.
|
||||
-----
|
||||
# Context
|
||||
## Implementation approach
|
||||
For the snake game, we can use the Pygame library, which is an open-source and easy-to-use library for game development in Python. Pygame provides a simple and efficient way to handle graphics, sound, and user input, making it suitable for developing a snake game.
|
||||
|
||||
## Python package name
|
||||
```
|
||||
"snake_game"
|
||||
```
|
||||
## File list
|
||||
```
|
||||
[
|
||||
"main.py",
|
||||
]
|
||||
```
|
||||
## Data structures and interface definitions
|
||||
```
|
||||
classDiagram
|
||||
class Game:
|
||||
-int score
|
||||
-bool paused
|
||||
+__init__()
|
||||
+start_game()
|
||||
+handle_input(key: int)
|
||||
+update_game()
|
||||
+draw_game()
|
||||
+game_over()
|
||||
|
||||
class Snake:
|
||||
-list[Point] body
|
||||
-Point dir
|
||||
-bool alive
|
||||
+__init__(start_pos: Point)
|
||||
+move()
|
||||
+change_direction(dir: Point)
|
||||
+grow()
|
||||
+get_head() -> Point
|
||||
+get_body() -> list[Point]
|
||||
+is_alive() -> bool
|
||||
|
||||
class Point:
|
||||
-int x
|
||||
-int y
|
||||
+__init__(x: int, y: int)
|
||||
+set_coordinate(x: int, y: int)
|
||||
+get_coordinate() -> tuple[int, int]
|
||||
|
||||
class Food:
|
||||
-Point pos
|
||||
-bool active
|
||||
+__init__()
|
||||
+generate_new_food()
|
||||
+get_position() -> Point
|
||||
+is_active() -> bool
|
||||
|
||||
Game "1" -- "1" Snake: contains
|
||||
Game "1" -- "1" Food: has
|
||||
```
|
||||
|
||||
## Program call flow
|
||||
```
|
||||
sequenceDiagram
|
||||
participant M as Main
|
||||
participant G as Game
|
||||
participant S as Snake
|
||||
participant F as Food
|
||||
|
||||
M->>G: Start game
|
||||
G->>G: Initialize game
|
||||
loop
|
||||
M->>G: Handle user input
|
||||
G->>S: Handle input
|
||||
G->>F: Check if snake eats food
|
||||
G->>S: Update snake movement
|
||||
G->>G: Check game over condition
|
||||
G->>G: Update score
|
||||
G->>G: Draw game
|
||||
M->>G: Update display
|
||||
end
|
||||
G->>G: Game over
|
||||
```
|
||||
## Required Python third-party packages
|
||||
```
|
||||
"""
|
||||
pygame==2.0.1
|
||||
"""
|
||||
```
|
||||
## Required Other language third-party packages
|
||||
```
|
||||
"""
|
||||
No third-party packages required for other languages.
|
||||
"""
|
||||
```
|
||||
|
||||
## Logic Analysis
|
||||
```
|
||||
[
|
||||
["main.py", "Main"],
|
||||
["game.py", "Game"],
|
||||
["snake.py", "Snake"],
|
||||
["point.py", "Point"],
|
||||
["food.py", "Food"]
|
||||
]
|
||||
```
|
||||
## Task list
|
||||
```
|
||||
[
|
||||
"point.py",
|
||||
"food.py",
|
||||
"snake.py",
|
||||
"game.py",
|
||||
"main.py"
|
||||
]
|
||||
```
|
||||
## Shared Knowledge
|
||||
```
|
||||
"""
|
||||
The 'point.py' module contains the implementation of the Point class, which represents a point in a 2D coordinate system.
|
||||
|
||||
The 'food.py' module contains the implementation of the Food class, which represents the food in the game.
|
||||
|
||||
The 'snake.py' module contains the implementation of the Snake class, which represents the snake in the game.
|
||||
|
||||
The 'game.py' module contains the implementation of the Game class, which manages the game logic.
|
||||
|
||||
The 'main.py' module is the entry point of the application and starts the game.
|
||||
"""
|
||||
```
|
||||
## Anything UNCLEAR
|
||||
We need to clarify the main entry point of the application and ensure that all required third-party libraries are properly initialized.
|
||||
|
||||
## Code: point.py
|
||||
```
|
||||
class Point:
|
||||
def __init__(self, x: int, y: int):
|
||||
self.x = x
|
||||
self.y = y
|
||||
|
||||
def set_coordinate(self, x: int, y: int):
|
||||
self.x = x
|
||||
self.y = y
|
||||
|
||||
def get_coordinate(self) -> tuple[int, int]:
|
||||
return self.x, self.y
|
||||
```
|
||||
-----
|
||||
|
||||
## Format example
|
||||
-----
|
||||
## Code Review
|
||||
1. The code ...
|
||||
2. ...
|
||||
3. ...
|
||||
4. ...
|
||||
5. ...
|
||||
|
||||
## Rewrite Code: point.py
|
||||
```python
|
||||
## point.py
|
||||
...
|
||||
```
|
||||
-----
|
|
@@ -1,148 +0,0 @@
|
|||
You are a Project Manager, named Eve, your goal is to improve team efficiency and deliver with quality and quantity, and the constraint is .
|
||||
|
||||
# Context
|
||||
## Implementation approach
|
||||
For the snake game, we can use the Pygame library, which is an open-source and easy-to-use library for game development in Python. Pygame provides a simple and efficient way to handle graphics, sound, and user input, making it suitable for developing a snake game.
|
||||
|
||||
## Python package name
|
||||
```
|
||||
"snake_game"
|
||||
```
|
||||
## File list
|
||||
```
|
||||
[
|
||||
"main.py",
|
||||
"game.py",
|
||||
"snake.py",
|
||||
"food.py"
|
||||
]
|
||||
```
|
||||
## Data structures and interface definitions
|
||||
```
|
||||
classDiagram
|
||||
class Game{
|
||||
-int score
|
||||
-bool game_over
|
||||
+start_game() : void
|
||||
+end_game() : void
|
||||
+update() : void
|
||||
+draw() : void
|
||||
+handle_events() : void
|
||||
}
|
||||
class Snake{
|
||||
-list[Tuple[int, int]] body
|
||||
-Tuple[int, int] direction
|
||||
+move() : void
|
||||
+change_direction(direction: Tuple[int, int]) : void
|
||||
+is_collision() : bool
|
||||
+grow() : void
|
||||
+draw() : void
|
||||
}
|
||||
class Food{
|
||||
-Tuple[int, int] position
|
||||
+generate() : void
|
||||
+draw() : void
|
||||
}
|
||||
class Main{
|
||||
-Game game
|
||||
+run() : void
|
||||
}
|
||||
Game "1" -- "1" Snake: contains
|
||||
Game "1" -- "1" Food: has
|
||||
Main "1" -- "1" Game: has
|
||||
```
|
||||
## Program call flow
|
||||
```
|
||||
sequenceDiagram
|
||||
participant M as Main
|
||||
participant G as Game
|
||||
participant S as Snake
|
||||
participant F as Food
|
||||
|
||||
M->G: run()
|
||||
G->G: start_game()
|
||||
G->G: handle_events()
|
||||
G->G: update()
|
||||
G->G: draw()
|
||||
G->G: end_game()
|
||||
|
||||
G->S: move()
|
||||
S->S: change_direction()
|
||||
S->S: is_collision()
|
||||
S->S: grow()
|
||||
S->S: draw()
|
||||
|
||||
G->F: generate()
|
||||
F->F: draw()
|
||||
```
|
||||
## Anything UNCLEAR
|
||||
The design and implementation of the snake game are clear based on the given requirements.
|
||||
|
||||
## Format example
|
||||
---
|
||||
## Required Python third-party packages
|
||||
```python
|
||||
"""
|
||||
flask==1.1.2
|
||||
bcrypt==3.2.0
|
||||
"""
|
||||
```
|
||||
|
||||
## Required Other language third-party packages
|
||||
```python
|
||||
"""
|
||||
No third-party ...
|
||||
"""
|
||||
```
|
||||
|
||||
## Full API spec
|
||||
```python
|
||||
"""
|
||||
openapi: 3.0.0
|
||||
...
|
||||
description: A JSON object ...
|
||||
"""
|
||||
```
|
||||
|
||||
## Logic Analysis
|
||||
```python
|
||||
[
|
||||
["game.py", "Contains ..."],
|
||||
]
|
||||
```
|
||||
|
||||
## Task list
|
||||
```python
|
||||
[
|
||||
"game.py",
|
||||
]
|
||||
```
|
||||
|
||||
## Shared Knowledge
|
||||
```python
|
||||
"""
|
||||
'game.py' contains ...
|
||||
"""
|
||||
```
|
||||
|
||||
## Anything UNCLEAR
|
||||
We need ... how to start.
|
||||
---
|
||||
-----
|
||||
Role: You are a project manager; the goal is to break down tasks according to PRD/technical design, give a task list, and analyze task dependencies to start with the prerequisite modules
|
||||
Requirements: Based on the context, fill in the following missing information; note that all sections are returned separately in Python code triple-quote form. Here the granularity of a task is a file; if there are any missing files, you can supplement them
|
||||
Attention: Use '##' to split sections, not '#', and '## <SECTION_NAME>' SHOULD WRITE BEFORE the code and triple quote.
|
||||
|
||||
## Required Python third-party packages: Provided in requirements.txt format
|
||||
|
||||
## Required Other language third-party packages: Provided in requirements.txt format
|
||||
|
||||
## Full API spec: Use OpenAPI 3.0. Describe all APIs that may be used by both frontend and backend.
|
||||
|
||||
## Logic Analysis: Provided as a Python list[list[str]]; the first element is the filename, and the second is the class/method/function that should be implemented in this file. Analyze the dependencies between the files and which work should be done first
|
||||
|
||||
## Task list: Provided as Python list[str]. Each str is a filename, the more at the beginning, the more it is a prerequisite dependency, should be done first
|
||||
|
||||
## Shared Knowledge: Anything that should be public like utils' functions, config's variables details that should make clear first.
|
||||
|
||||
## Anything UNCLEAR: Provide as Plain text. Make clear here. For example, don't forget a main entry. don't forget to init 3rd party libs.
|
|
@@ -1,147 +0,0 @@
|
|||
NOTICE
|
||||
Role: You are a professional engineer; the main goal is to write PEP8 compliant, elegant, modular, easy to read and maintain Python 3.9 code (but you can also use other programming language)
|
||||
ATTENTION: Use '##' to SPLIT SECTIONS, not '#'. Output format carefully referenced "Format example".
|
||||
|
||||
## Code: snake.py Write code with triple quotes, based on the following list and context.
|
||||
1. Do your best to implement THIS ONLY ONE FILE. ONLY USE EXISTING API. IF NO API, IMPLEMENT IT.
|
||||
2. Requirement: Based on the context, implement one following code file, note to return only in code form, your code will be part of the entire project, so please implement complete, reliable, reusable code snippets
|
||||
3. Attention1: If there is any setting, ALWAYS SET A DEFAULT VALUE, ALWAYS USE STRONG TYPE AND EXPLICIT VARIABLE.
|
||||
4. Attention2: YOU MUST FOLLOW "Data structures and interface definitions". DONT CHANGE ANY DESIGN.
|
||||
5. Think before writing: What should be implemented and provided in this document?
|
||||
6. CAREFULLY CHECK THAT YOU DONT MISS ANY NECESSARY CLASS/FUNCTION IN THIS FILE.
|
||||
7. Do not use public member functions that do not exist in your design.
|
||||
|
||||
-----
|
||||
# Context
|
||||
## Implementation approach
|
||||
For the snake game, we can use the Pygame library, which is an open-source and easy-to-use library for game development in Python. Pygame provides a simple and efficient way to handle graphics, sound, and user input, making it suitable for developing a snake game.
|
||||
|
||||
## Python package name
|
||||
```
|
||||
"snake_game"
|
||||
```
|
||||
## File list
|
||||
```
|
||||
[
|
||||
"main.py",
|
||||
"game.py",
|
||||
"snake.py",
|
||||
"food.py"
|
||||
]
|
||||
```
|
||||
## Data structures and interface definitions
|
||||
```
|
||||
classDiagram
|
||||
class Game{
|
||||
-int score
|
||||
-bool game_over
|
||||
+start_game() : void
|
||||
+end_game() : void
|
||||
+update() : void
|
||||
+draw() : void
|
||||
+handle_events() : void
|
||||
}
|
||||
class Snake{
|
||||
-list[Tuple[int, int]] body
|
||||
-Tuple[int, int] direction
|
||||
+move() : void
|
||||
+change_direction(direction: Tuple[int, int]) : void
|
||||
+is_collision() : bool
|
||||
+grow() : void
|
||||
+draw() : void
|
||||
}
|
||||
class Food{
|
||||
-Tuple[int, int] position
|
||||
+generate() : void
|
||||
+draw() : void
|
||||
}
|
||||
class Main{
|
||||
-Game game
|
||||
+run() : void
|
||||
}
|
||||
Game "1" -- "1" Snake: contains
|
||||
Game "1" -- "1" Food: has
|
||||
Main "1" -- "1" Game: has
|
||||
```
|
||||
## Program call flow
|
||||
```
|
||||
sequenceDiagram
|
||||
participant M as Main
|
||||
participant G as Game
|
||||
participant S as Snake
|
||||
participant F as Food
|
||||
|
||||
M->G: run()
|
||||
G->G: start_game()
|
||||
G->G: handle_events()
|
||||
G->G: update()
|
||||
G->G: draw()
|
||||
G->G: end_game()
|
||||
|
||||
G->S: move()
|
||||
S->S: change_direction()
|
||||
S->S: is_collision()
|
||||
S->S: grow()
|
||||
S->S: draw()
|
||||
|
||||
G->F: generate()
|
||||
F->F: draw()
|
||||
```
|
||||
## Anything UNCLEAR
|
||||
The design and implementation of the snake game are clear based on the given requirements.
|
||||
|
||||
## Required Python third-party packages
|
||||
```
|
||||
"""
|
||||
pygame==2.0.1
|
||||
"""
|
||||
```
|
||||
## Required Other language third-party packages
|
||||
```
|
||||
"""
|
||||
No third-party packages required for other languages.
|
||||
"""
|
||||
```
|
||||
|
||||
## Logic Analysis
|
||||
```
|
||||
[
|
||||
["main.py", "Main"],
|
||||
["game.py", "Game"],
|
||||
["snake.py", "Snake"],
|
||||
["food.py", "Food"]
|
||||
]
|
||||
```
|
||||
## Task list
|
||||
```
|
||||
[
|
||||
"snake.py",
|
||||
"food.py",
|
||||
"game.py",
|
||||
"main.py"
|
||||
]
|
||||
```
|
||||
## Shared Knowledge
|
||||
```
|
||||
"""
|
||||
'game.py' contains the main logic for the snake game, including starting the game, handling user input, updating the game state, and drawing the game state.
|
||||
|
||||
'snake.py' contains the logic for the snake, including moving the snake, changing its direction, checking for collisions, growing the snake, and drawing the snake.
|
||||
|
||||
'food.py' contains the logic for the food, including generating a new food position and drawing the food.
|
||||
|
||||
'main.py' initializes the game and runs the game loop.
|
||||
"""
|
||||
```
|
||||
## Anything UNCLEAR
|
||||
We need to clarify the main entry point of the application and ensure that all required third-party libraries are properly initialized.
|
||||
|
||||
-----
|
||||
## Format example
|
||||
-----
|
||||
## Code: snake.py
|
||||
```python
|
||||
## snake.py
|
||||
...
|
||||
```
|
||||
-----
|
|
@@ -1,143 +0,0 @@

from enum import Enum
# from .prompts import PLANNER_TEMPLATE_PROMPT


CHAIN_CONFIGS = {
    "chatChain": {
        "chain_name": "chatChain",
        "chain_type": "BaseChain",
        "agents": ["qaer"],
        "chat_turn": 1,
        "do_checker": False,
        "chain_prompt": ""
    },
    "docChatChain": {
        "chain_name": "docChatChain",
        "chain_type": "BaseChain",
        "agents": ["qaer"],
        "chat_turn": 1,
        "do_checker": False,
        "chain_prompt": ""
    },
    "searchChatChain": {
        "chain_name": "searchChatChain",
        "chain_type": "BaseChain",
        "agents": ["searcher"],
        "chat_turn": 1,
        "do_checker": False,
        "chain_prompt": ""
    },
    "codeChatChain": {
        "chain_name": "codeChatChain",
        "chain_type": "BaseChain",
        "agents": ["code_qaer"],
        "chat_turn": 1,
        "do_checker": False,
        "chain_prompt": ""
    },
    "toolReactChain": {
        "chain_name": "toolReactChain",
        "chain_type": "BaseChain",
        "agents": ["tool_planner", "tool_react"],
        "chat_turn": 2,
        "do_checker": True,
        "chain_prompt": ""
    },
    "codePlannerChain": {
        "chain_name": "codePlannerChain",
        "chain_type": "BaseChain",
        "agents": ["planner"],
        "chat_turn": 1,
        "do_checker": True,
        "chain_prompt": ""
    },
    "codeReactChain": {
        "chain_name": "codeReactChain",
        "chain_type": "BaseChain",
        "agents": ["code_react"],
        "chat_turn": 6,
        "do_checker": True,
        "chain_prompt": ""
    },
    "codeToolPlanChain": {
        "chain_name": "codeToolPlanChain",
        "chain_type": "BaseChain",
        "agents": ["tool_and_code_planner"],
        "chat_turn": 1,
        "do_checker": False,
        "chain_prompt": ""
    },
    "codeToolReactChain": {
        "chain_name": "codeToolReactChain",
        "chain_type": "BaseChain",
        "agents": ["tool_and_code_react"],
        "chat_turn": 3,
        "do_checker": True,
        "chain_prompt": ""
    },
    "planChain": {
        "chain_name": "planChain",
        "chain_type": "BaseChain",
        "agents": ["general_planner"],
        "chat_turn": 1,
        "do_checker": False,
        "chain_prompt": ""
    },
    "executorChain": {
        "chain_name": "executorChain",
        "chain_type": "BaseChain",
        "agents": ["executor"],
        "chat_turn": 1,
        "do_checker": True,
        "chain_prompt": ""
    },
    "executorRefineChain": {
        "chain_name": "executorRefineChain",
        "chain_type": "BaseChain",
        "agents": ["executor", "base_refiner"],
        "chat_turn": 3,
        "do_checker": True,
        "chain_prompt": ""
    },
    "metagptChain": {
        "chain_name": "metagptChain",
        "chain_type": "BaseChain",
        "agents": ["metaGPT_PRD", "metaGPT_DESIGN", "metaGPT_TASK", "metaGPT_CODER"],
        "chat_turn": 1,
        "do_checker": False,
        "chain_prompt": ""
    },
    "baseGroupChain": {
        "chain_name": "baseGroupChain",
        "chain_type": "BaseChain",
        "agents": ["baseGroup"],
        "chat_turn": 1,
        "do_checker": False,
        "chain_prompt": ""
    },
    "codeChatXXChain": {
        "chain_name": "codeChatXXChain",
        "chain_type": "BaseChain",
        "agents": ["codeChat1", "codeChat2"],
        "chat_turn": 1,
        "do_checker": False,
        "chain_prompt": ""
    },
    "code2DocsGroupChain": {
        "chain_name": "code2DocsGroupChain",
        "chain_type": "BaseChain",
        "agents": ["code2DocsGrouper"],
        "chat_turn": 1,
        "do_checker": False,
        "chain_prompt": ""
    },
    "code2TestsChain": {
        "chain_name": "code2TestsChain",
        "chain_type": "BaseChain",
        "agents": ["Code2TestJudger", "code2Tests"],
        "chat_turn": 1,
        "do_checker": False,
        "chain_prompt": ""
    }
}
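For reference, a chain entry like the ones above could be turned into a runnable chain object roughly as follows. `BaseChain` here is a minimal stand-in dataclass for illustration, not the project's actual class, so treat this as a sketch only.

```python
# Hypothetical helper: build a chain object from a CHAIN_CONFIGS entry.
# BaseChain is an assumed stand-in, not the real muagent API.
from dataclasses import dataclass
from typing import List


@dataclass
class BaseChain:
    chain_name: str
    agents: List[str]
    chat_turn: int = 1
    do_checker: bool = False
    chain_prompt: str = ""


def build_chain(name: str, configs: dict) -> BaseChain:
    cfg = configs[name]
    return BaseChain(
        chain_name=cfg["chain_name"],
        agents=cfg["agents"],
        chat_turn=cfg["chat_turn"],
        do_checker=cfg["do_checker"],
        chain_prompt=cfg["chain_prompt"],
    )

# e.g. chat_chain = build_chain("chatChain", CHAIN_CONFIGS)
```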
@@ -1,74 +0,0 @@

PHASE_CONFIGS = {
    "chatPhase": {
        "phase_name": "chatPhase",
        "phase_type": "BasePhase",
        "chains": ["chatChain"],
        "do_summary": False,
        "do_search": False,
        "do_doc_retrieval": False,
        "do_code_retrieval": False,
        "do_tool_retrieval": False,
        "do_using_tool": False
    },
    "docChatPhase": {
        "phase_name": "docChatPhase",
        "phase_type": "BasePhase",
        "chains": ["docChatChain"],
        "do_doc_retrieval": True,
    },
    "searchChatPhase": {
        "phase_name": "searchChatPhase",
        "phase_type": "BasePhase",
        "chains": ["searchChatChain"],
        "do_search": True,
    },
    "codeChatPhase": {
        "phase_name": "codeChatPhase",
        "phase_type": "BasePhase",
        "chains": ["codeChatChain"],
        "do_code_retrieval": True,
    },
    "toolReactPhase": {
        "phase_name": "toolReactPhase",
        "phase_type": "BasePhase",
        "chains": ["toolReactChain"],
        "do_using_tool": True
    },
    "codeReactPhase": {
        "phase_name": "codeReactPhase",
        "phase_type": "BasePhase",
        # "chains": ["codePlannerChain", "codeReactChain"],
        "chains": ["planChain", "codeReactChain"],
    },
    "codeToolReactPhase": {
        "phase_name": "codeToolReactPhase",
        "phase_type": "BasePhase",
        "chains": ["codeToolPlanChain", "codeToolReactChain"],
        "do_using_tool": True
    },
    "baseTaskPhase": {
        "phase_name": "baseTaskPhase",
        "phase_type": "BasePhase",
        "chains": ["planChain", "executorChain"],
    },
    "metagpt_code_devlop": {
        "phase_name": "metagpt_code_devlop",
        "phase_type": "BasePhase",
        "chains": ["metagptChain",],
    },
    "baseGroupPhase": {
        "phase_name": "baseGroupPhase",
        "phase_type": "BasePhase",
        "chains": ["baseGroupChain"],
    },
    "code2DocsGroup": {
        "phase_name": "code2DocsGroup",
        "phase_type": "BasePhase",
        "chains": ["code2DocsGroupChain"],
    },
    "code2Tests": {
        "phase_name": "code2Tests",
        "phase_type": "BasePhase",
        "chains": ["code2TestsChain"],
    }
}
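A rough sketch of how a phase entry might be resolved against CHAIN_CONFIGS. The flag names mirror the keys above, and missing flags are assumed to default to False; this is an illustration, not the project's actual resolution code.

```python
# Hypothetical resolution of a phase entry into its chain configs.
# Sparse entries (like docChatPhase above) fall back to False for missing flags.
def resolve_phase(phase_name: str, phase_configs: dict, chain_configs: dict) -> dict:
    cfg = phase_configs[phase_name]
    flags = ["do_summary", "do_search", "do_doc_retrieval",
             "do_code_retrieval", "do_tool_retrieval", "do_using_tool"]
    return {
        "phase_name": cfg["phase_name"],
        "chains": [chain_configs[c] for c in cfg["chains"]],
        **{flag: cfg.get(flag, False) for flag in flags},
    }

# e.g. doc_phase = resolve_phase("docChatPhase", PHASE_CONFIGS, CHAIN_CONFIGS)
```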
@@ -1,80 +0,0 @@

BASE_PROMPT_CONFIGS = [
    {"field_name": 'agent_profile', "function_name": 'handle_agent_profile', "is_context": False},
    {"field_name": 'tool_information', "function_name": 'handle_tool_data', "is_context": False},
    {"field_name": 'context_placeholder', "function_name": '', "is_context": True},
    {"field_name": 'reference_documents', "function_name": 'handle_doc_info'},
    {"field_name": 'session_records', "function_name": 'handle_session_records'},
    {"field_name": 'task_records', "function_name": 'handle_task_records'},
    {"field_name": 'output_format', "function_name": 'handle_output_format', 'title': 'Response Output Format', "is_context": False},
    {"field_name": 'begin!!!', "function_name": 'handle_response', "is_context": False, "omit_if_empty": False}
]


BASE_NOTOOLPROMPT_CONFIGS = [
    {"field_name": 'agent_profile', "function_name": 'handle_agent_profile', "is_context": False},
    {"field_name": 'context_placeholder', "function_name": '', "is_context": True},
    {"field_name": 'reference_documents', "function_name": 'handle_doc_info'},
    {"field_name": 'session_records', "function_name": 'handle_session_records'},
    {"field_name": 'output_format', "function_name": 'handle_output_format', 'title': 'Response Output Format', "is_context": False},
    {"field_name": 'begin!!!', "function_name": 'handle_response', "is_context": False, "omit_if_empty": False}
]


EXECUTOR_PROMPT_CONFIGS = [
    {"field_name": 'agent_profile', "function_name": 'handle_agent_profile', "is_context": False},
    {"field_name": 'tool_information', "function_name": 'handle_tool_data', "is_context": False},
    {"field_name": 'context_placeholder', "function_name": '', "is_context": True},
    {"field_name": 'reference_documents', "function_name": 'handle_doc_info'},
    {"field_name": 'session_records', "function_name": 'handle_session_records'},
    {"field_name": 'task_records', "function_name": 'handle_task_records'},
    {"field_name": 'current_plan', "function_name": 'handle_current_plan'},
    {"field_name": 'output_format', "function_name": 'handle_output_format', 'title': 'Response Output Format', "is_context": False},
    {"field_name": 'begin!!!', "function_name": 'handle_response', "is_context": False, "omit_if_empty": False}
]


SELECTOR_PROMPT_CONFIGS = [
    {"field_name": 'agent_profile', "function_name": 'handle_agent_profile', "is_context": False},
    {"field_name": 'tool_information', "function_name": 'handle_tool_data', "is_context": False},
    {"field_name": 'agent_infomation', "function_name": 'handle_agent_data', "is_context": False, "omit_if_empty": False},
    {"field_name": 'context_placeholder', "function_name": '', "is_context": True},
    {"field_name": 'reference_documents', "function_name": 'handle_doc_info'},
    {"field_name": 'session_records', "function_name": 'handle_session_records'},
    {"field_name": 'current_plan', "function_name": 'handle_current_plan'},
    {"field_name": 'output_format', "function_name": 'handle_output_format', 'title': 'Response Output Format', "is_context": False},
    {"field_name": 'begin!!!', "function_name": 'handle_response', "is_context": False, "omit_if_empty": False}
]


CODE2DOC_GROUP_PROMPT_CONFIGS = [
    {"field_name": 'agent_profile', "function_name": 'handle_agent_profile', "is_context": False},
    {"field_name": 'agent_infomation', "function_name": 'handle_agent_data', "is_context": False, "omit_if_empty": False},
    # {"field_name": 'tool_information', "function_name": 'handle_tool_data', "is_context": False},
    {"field_name": 'context_placeholder', "function_name": '', "is_context": True},
    # {"field_name": 'reference_documents', "function_name": 'handle_doc_info'},
    {"field_name": 'session_records', "function_name": 'handle_session_records'},
    {"field_name": 'Specific Objective', "function_name": 'handle_specific_objective'},
    {"field_name": 'Code Snippet', "function_name": 'handle_code_snippet'},
    {"field_name": 'output_format', "function_name": 'handle_output_format', 'title': 'Response Output Format', "is_context": False},
    {"field_name": 'begin!!!', "function_name": 'handle_response', "is_context": False, "omit_if_empty": False}
]


CODE2DOC_PROMPT_CONFIGS = [
    {"field_name": 'agent_profile', "function_name": 'handle_agent_profile', "is_context": False},
    # {"field_name": 'tool_information', "function_name": 'handle_tool_data', "is_context": False},
    {"field_name": 'context_placeholder', "function_name": '', "is_context": True},
    # {"field_name": 'reference_documents', "function_name": 'handle_doc_info'},
    {"field_name": 'session_records', "function_name": 'handle_session_records'},
    {"field_name": 'Specific Objective', "function_name": 'handle_specific_objective'},
    {"field_name": 'Code Snippet', "function_name": 'handle_code_snippet'},
    {"field_name": 'output_format', "function_name": 'handle_output_format', 'title': 'Response Output Format', "is_context": False},
    {"field_name": 'begin!!!', "function_name": 'handle_response', "is_context": False, "omit_if_empty": False}
]


CODE2TESTS_PROMPT_CONFIGS = [
    {"field_name": 'agent_profile', "function_name": 'handle_agent_profile', "is_context": False},
    {"field_name": 'context_placeholder', "function_name": '', "is_context": True},
    {"field_name": 'session_records', "function_name": 'handle_session_records'},
    {"field_name": 'code_snippet', "function_name": 'handle_code_snippet'},
    {"field_name": 'retrieval_codes', "function_name": 'handle_retrieval_codes', "description": ""},
    {"field_name": 'output_format', "function_name": 'handle_output_format', 'title': 'Response Output Format', "is_context": False},
    {"field_name": 'begin!!!', "function_name": 'handle_response', "is_context": False, "omit_if_empty": False}
]
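The *_PROMPT_CONFIGS lists read like an ordered recipe for assembling an agent prompt. A hypothetical assembler, with placeholder handler functions standing in for the real manager methods, might look like this:

```python
# Hypothetical prompt assembly from a *_PROMPT_CONFIGS list.
# The handler callables are placeholders for the project's real handle_* methods.
def assemble_prompt(prompt_configs: list, handlers: dict, context: str = "") -> str:
    sections = []
    for cfg in prompt_configs:
        if cfg.get("is_context"):
            sections.append(context)          # the context placeholder slot
            continue
        handler = handlers.get(cfg["function_name"])
        content = handler() if handler else ""
        if not content and cfg.get("omit_if_empty", True):
            continue                          # skip empty optional fields
        title = cfg.get("title", cfg["field_name"])
        sections.append(f"#### {title}\n{content}")
    return "\n\n".join(sections)
```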
@@ -1,44 +0,0 @@

from .planner_template_prompt import PLANNER_TEMPLATE_PROMPT, GENERAL_PLANNER_PROMPT, DATA_PLANNER_PROMPT, TOOL_PLANNER_PROMPT

from .input_template_prompt import REACT_PROMPT_INPUT, CHECK_PROMPT_INPUT, EXECUTOR_PROMPT_INPUT, CONTEXT_PROMPT_INPUT, QUERY_CONTEXT_PROMPT_INPUT, PLAN_PROMPT_INPUT, BASE_PROMPT_INPUT, QUERY_CONTEXT_DOC_PROMPT_INPUT, BEGIN_PROMPT_INPUT

from .metagpt_prompt import PRD_WRITER_METAGPT_PROMPT, DESIGN_WRITER_METAGPT_PROMPT, TASK_WRITER_METAGPT_PROMPT, CODE_WRITER_METAGPT_PROMPT

from .intention_template_prompt import RECOGNIZE_INTENTION_PROMPT

from .checker_template_prompt import CHECKER_PROMPT, CHECKER_TEMPLATE_PROMPT

from .summary_template_prompt import CONV_SUMMARY_PROMPT, CONV_SUMMARY_PROMPT_SPEC

from .qa_template_prompt import QA_PROMPT, CODE_QA_PROMPT, QA_TEMPLATE_PROMPT, CODE_PROMPT_TEMPLATE, CODE_INTERPERT_TEMPLATE, ORIGIN_TEMPLATE_PROMPT

from .executor_template_prompt import EXECUTOR_TEMPLATE_PROMPT
from .refine_template_prompt import REFINE_TEMPLATE_PROMPT
from .code2doc_template_prompt import Code2DocGroup_PROMPT, Class2Doc_PROMPT, Func2Doc_PROMPT
from .code2test_template_prompt import code2Tests_PROMPT, judgeCode2Tests_PROMPT
from .agent_selector_template_prompt import SELECTOR_AGENT_TEMPLATE_PROMPT

from .react_template_prompt import REACT_TEMPLATE_PROMPT
from .react_code_prompt import REACT_CODE_PROMPT
from .react_tool_prompt import REACT_TOOL_PROMPT
from .react_tool_code_prompt import REACT_TOOL_AND_CODE_PROMPT
from .react_tool_code_planner_prompt import REACT_TOOL_AND_CODE_PLANNER_PROMPT


__all__ = [
    "REACT_PROMPT_INPUT", "CHECK_PROMPT_INPUT", "EXECUTOR_PROMPT_INPUT", "CONTEXT_PROMPT_INPUT", "QUERY_CONTEXT_PROMPT_INPUT", "PLAN_PROMPT_INPUT", "BASE_PROMPT_INPUT", "QUERY_CONTEXT_DOC_PROMPT_INPUT", "BEGIN_PROMPT_INPUT",
    "RECOGNIZE_INTENTION_PROMPT",
    "PRD_WRITER_METAGPT_PROMPT", "DESIGN_WRITER_METAGPT_PROMPT", "TASK_WRITER_METAGPT_PROMPT", "CODE_WRITER_METAGPT_PROMPT",
    "CHECKER_PROMPT", "CHECKER_TEMPLATE_PROMPT",
    "CONV_SUMMARY_PROMPT", "CONV_SUMMARY_PROMPT_SPEC",
    "QA_PROMPT", "CODE_QA_PROMPT", "QA_TEMPLATE_PROMPT", "CODE_PROMPT_TEMPLATE", "CODE_INTERPERT_TEMPLATE", "ORIGIN_TEMPLATE_PROMPT",
    "EXECUTOR_TEMPLATE_PROMPT",
    "REFINE_TEMPLATE_PROMPT",
    "SELECTOR_AGENT_TEMPLATE_PROMPT",
    "PLANNER_TEMPLATE_PROMPT", "GENERAL_PLANNER_PROMPT", "DATA_PLANNER_PROMPT", "TOOL_PLANNER_PROMPT",
    "REACT_TEMPLATE_PROMPT",
    "REACT_CODE_PROMPT", "REACT_TOOL_PROMPT", "REACT_TOOL_AND_CODE_PROMPT", "REACT_TOOL_AND_CODE_PLANNER_PROMPT",
    "Code2DocGroup_PROMPT", "Class2Doc_PROMPT", "Func2Doc_PROMPT",
    "code2Tests_PROMPT", "judgeCode2Tests_PROMPT"
]
@@ -1,21 +0,0 @@

SELECTOR_AGENT_TEMPLATE_PROMPT = """#### Agent Profile

Your goal is to respond, according to the Context Data's information, with the role that will best facilitate a solution, taking into account all relevant context (Context) provided.

When you need to select the appropriate role for handling a user's query, carefully read the provided role names, role descriptions and tool list.

ATTENTION: response carefully referenced "Response Output Format" in format.

#### Input Format

**Origin Query:** the initial question or objective that the user wanted to achieve

**Context:** the context history to determine if Origin Query has been achieved.

#### Response Output Format

**Thoughts:** think through the reason step by step about why you select this role.

**Role:** Select the role from agent names.

"""
@@ -1,37 +0,0 @@

CHECKER_TEMPLATE_PROMPT = """#### Agent Profile

When users have completed a sequence of tasks or if there is clear evidence that no further actions are required, your role is to confirm the completion.
Your task is to assess the current situation based on the context and determine whether all objectives have been met.
Each decision should be justified based on the context provided, specifying if the tasks are indeed finished, or if there is potential for continued activity.

#### Input Format

**Origin Query:** the initial question or objective that the user wanted to achieve

**Context:** the current status and history of the tasks to determine if Origin Query has been achieved.

#### Response Output Format
**Action Status:** finished or continued
If it's 'finished', the context can answer the origin query.
If it's 'continued', the context can't answer the origin query.

**REASON:** Justify the decision of choosing 'finished' or 'continued' by evaluating the progress step by step.
Consider all relevant information. If the tasks were aimed at an ongoing process, assess whether it has reached a satisfactory conclusion.
"""


CHECKER_PROMPT = """尽可能地以有帮助和准确的方式回应人类,判断问题是否得到解答,同时展现解答的过程和内容。
用户的问题:{query}
使用 JSON Blob 来指定一个返回的内容,提供一个 action(行动)。
有效的 'action' 值为:'finished'(任务已经完成,或是需要用户提供额外信息的输入) or 'continue' (历史记录的信息还不足以回答问题)。
在每个 $JSON_BLOB 中仅提供一个 action,如下所示:
```
{{'content': '提取“背景信息”和“对话信息”中信息来回答问题', 'reason': '解释$ACTION的原因', 'action': $ACTION}}
```
按照以下格式进行回应:
问题:输入问题以回答
行动:
```
$JSON_BLOB
```
"""
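Downstream code has to pull the finished/continued decision back out of these replies. A minimal sketch of such a parser (an illustration under assumptions, not the project's actual implementation) could be:

```python
# Hypothetical parser for the checker replies described above.
import ast
import re


def parse_checker_output(text: str) -> dict:
    """Extract the Action Status line and the fenced JSON blob, if present."""
    status = re.search(r"\*\*Action Status:\*\*\s*(finished|continued)", text)
    blob = re.search(r"```\s*(\{.*?\})\s*```", text, re.S)
    return {
        "action_status": status.group(1) if status else None,
        # the blob in CHECKER_PROMPT is Python-dict flavoured, so literal_eval is safer than json.loads
        "json_blob": ast.literal_eval(blob.group(1)) if blob else None,
    }
```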
@@ -1,95 +0,0 @@

Code2DocGroup_PROMPT = """#### Agent Profile

Your goal is to respond, according to the Context Data's information, with the role that will best facilitate a solution, taking into account all relevant context (Context) provided.

When you need to select the appropriate role for handling a user's query, carefully read the provided role names, role descriptions and tool list.

ATTENTION: response carefully referenced "Response Output Format" in format.

#### Input Format

#### Response Output Format

**Code Path:** Extract the paths for the class/method/function that need to be addressed from the context

**Role:** Select the role from agent names
"""

Class2Doc_PROMPT = """#### Agent Profile
As an advanced code documentation generator, you are proficient in translating class definitions into comprehensive documentation with a focus on instantiation parameters.
Your specific task is to parse the given code snippet of a class and extract information regarding its instantiation parameters.

ATTENTION: response carefully in "Response Output Format".

#### Input Format

**Code Snippet:** Provide the full class definition, including the constructor and any parameters it may require for instantiation.

#### Response Output Format
**Class Base:** Specify the base class or interface from which the current class extends, if any.

**Class Description:** Offer a brief description of the class's purpose and functionality.

**Init Parameters:** List each parameter from the constructor. For each parameter, provide:
- `param`: The parameter name
- `param_description`: A concise explanation of the parameter's purpose.
- `param_type`: The data type of the parameter, if explicitly defined.

```json
[
    {
        "param": "parameter_name",
        "param_description": "A brief description of what this parameter is used for.",
        "param_type": "The data type of the parameter"
    },
    ...
]
```

If the constructor takes no parameters, return
```json
[]
```
"""

Func2Doc_PROMPT = """#### Agent Profile
You are a high-level code documentation assistant, skilled at extracting information from function/method code into detailed and well-structured documentation.

ATTENTION: response carefully in "Response Output Format".


#### Input Format
**Code Path:** Provide the code path of the function or method you wish to document.
This name will be used to identify and extract the relevant details from the code snippet provided.

**Code Snippet:** A segment of code that contains the function or method to be documented.

#### Response Output Format

**Class Description:** Offer a brief description of the method (function)'s purpose and functionality.

**Parameters:** Extract the parameters of the specific function/method from the Code Snippet. For each parameter, provide:
- `param`: The parameter name
- `param_description`: A concise explanation of the parameter's purpose.
- `param_type`: The data type of the parameter, if explicitly defined.
```json
[
    {
        "param": "parameter_name",
        "param_description": "A brief description of what this parameter is used for.",
        "param_type": "The data type of the parameter"
    },
    ...
]
```

If the function/method takes no parameters, return
```json
[]
```

**Return Value Description:** Describe what the function/method returns upon completion.

**Return Type:** Indicate the type of data the function/method returns (e.g., string, integer, object, void).
"""
@@ -1,65 +0,0 @@

judgeCode2Tests_PROMPT = """#### Agent Profile
When determining the necessity of writing test cases for a given code snippet,
it's essential to evaluate its interactions with dependent classes and methods (retrieved code snippets),
in addition to considering these critical factors:
1. Functionality: If it implements a concrete function or logic, test cases are typically necessary to verify its correctness.
2. Complexity: If the code is complex, especially if it contains multiple conditional statements, loops, exception handling, etc.,
it's more likely to harbor bugs, and thus test cases should be written.
If the code involves complex algorithms or logic, then writing test cases can help ensure the accuracy of the logic and prevent errors during future refactoring.
3. Criticality: If it's part of the critical path or affects core functionalities, then it needs to be tested.
Comprehensive test cases should be written for core business logic or key components of the system to ensure the correctness and stability of the functionality.
4. Dependencies: If the code has external dependencies, integration testing may be necessary, or mocking these dependencies during unit testing might be required.
5. User Input: If the code handles user input, especially from unregulated external sources, creating test cases to check input validation and handling is important.
6. Frequent Changes: For code that requires regular updates or modifications, having the appropriate test cases ensures that changes do not break existing functionalities.

#### Input Format

**Code Snippet:** the initial Code or objective that the user wanted to achieve

**Retrieval Code Snippets:** These are the associated code segments that the main Code Snippet depends on.
Examine these snippets to understand how they interact with the main snippet and to determine how they might affect the overall functionality.

#### Response Output Format
**Action Status:** Set to 'finished' or 'continued'.
If set to 'finished', the code snippet does not warrant the generation of a test case.
If set to 'continued', the code snippet necessitates the creation of a test case.

**REASON:** Justify the selection of 'finished' or 'continued', contemplating the decision through a step-by-step rationale.
"""


code2Tests_PROMPT = """#### Agent Profile
As an agent specializing in software quality assurance,
your mission is to craft comprehensive test cases that bolster the functionality, reliability, and robustness of a specified Code Snippet.
This task is to be carried out with a keen understanding of the snippet's interactions with its dependent classes and methods, collectively referred to as Retrieval Code Snippets.
Analyze the details given below to grasp the code's intended purpose, its inherent complexity, and the context within which it operates.
Your constructed test cases must thoroughly examine the various factors influencing the code's quality and performance.

ATTENTION: response carefully referenced "Response Output Format" in format.

Each test case should include:
1. A clear description of the test purpose.
2. The input values or conditions for the test.
3. The expected outcome or assertion for the test.
4. Appropriate tags (e.g., 'functional', 'integration', 'regression') that classify the type of test case.
5. The test code should include package and import statements.

#### Input Format

**Code Snippet:** the initial Code or objective that the user wanted to achieve

**Retrieval Code Snippets:** These are the interrelated pieces of code sourced from the codebase, which support or influence the primary Code Snippet.

#### Response Output Format
**SaveFileName:** construct a local file name based on Question and Context, such as

```java
package/class.java
```


**Test Code:** generate the test code for the current Code Snippet.
```java
...
```

"""
@@ -1,31 +0,0 @@

EXECUTOR_TEMPLATE_PROMPT = """#### Agent Profile

When users need help with coding or using tools, your role is to provide precise and effective guidance.
Use the tools provided if they can solve the problem; otherwise, write the code step by step, showing only the part necessary to solve the current problem.
Each reply should contain only the guidance required for the current step, either by tool usage or code.
ATTENTION: The Action Status field ensures that the tools or code mentioned in the Action can be parsed smoothly. Please make sure not to omit the Action Status field when replying.

#### Response Output Format

**Thoughts:** Considering the session records and executed steps, decide whether the current step requires the use of a tool or code_executing.
Solve the problem step by step, only displaying the thought process necessary for the current step of solving the problem.
If code_executing is required, outline the plan for executing this step.

**Action Status:** Set to 'stopped' or 'code_executing'. If it's 'stopped', the next action is to provide the final answer to the original question. If it's 'code_executing', the next step is to write the code.

**Action:** Code according to your thoughts. Use this format for code:

```python
# Write your code here
```
"""


# **Observation:** Check the results and effects of the executed code.

# ... (Repeat this Question/Thoughts/Action/Observation cycle as needed)

# **Thoughts:** I now know the final answer

# **Action Status:** Set to 'stopped'

# **Action:** The final answer to the original input question
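Since the executor's Action Status and code block are meant to be machine-parsed, a sketch of the extraction step (assumed for illustration, not the project's actual parser) might look like this:

```python
# Hypothetical extraction of the executor's Action Status and python code block.
import re


def parse_executor_reply(reply: str) -> tuple[str, str]:
    """Return (status, code) from a reply in the format described above."""
    status_match = re.search(r"\*\*Action Status:\*\*\s*(stopped|code_executing)", reply)
    code_match = re.search(r"```python\n(.*?)```", reply, re.S)
    status = status_match.group(1) if status_match else "stopped"
    code = code_match.group(1).strip() if code_match else ""
    return status, code
```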
@@ -1,40 +0,0 @@


BASE_PROMPT_INPUT = '''#### Begin!!!
'''

PLAN_PROMPT_INPUT = '''#### Begin!!!
**Question:** {query}
'''

REACT_PROMPT_INPUT = '''#### Begin!!!
{query}
'''


CONTEXT_PROMPT_INPUT = '''#### Begin!!!
**Context:** {context}
'''

QUERY_CONTEXT_DOC_PROMPT_INPUT = '''#### Begin!!!
**Origin Query:** {query}

**Context:** {context}

**DocInfos:** {DocInfos}
'''

QUERY_CONTEXT_PROMPT_INPUT = '''#### Begin!!!
**Origin Query:** {query}

**Context:** {context}
'''

EXECUTOR_PROMPT_INPUT = '''#### Begin!!!
{query}
'''

BEGIN_PROMPT_INPUT = '''#### Begin!!!
'''

CHECK_PROMPT_INPUT = '''下面是用户的原始问题:{query}'''
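These constants are plain `str.format` templates; an illustrative call site (the argument values here are made up) simply fills the named slots:

```python
# Illustrative only: fill the named slots of one of the templates above.
prompt = QUERY_CONTEXT_PROMPT_INPUT.format(
    query="How do I start the sandbox service?",
    context="(previous conversation turns)",
)
```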
@@ -1,14 +0,0 @@

RECOGNIZE_INTENTION_PROMPT = """你是一个任务决策助手,能够将理解用户意图并决策采取最合适的行动,尽可能地以有帮助和准确的方式回应人类,
使用 JSON Blob 来指定一个返回的内容,提供一个 action(行动)。
有效的 'action' 值为:'planning'(需要先进行拆解计划) or 'only_answer' (不需要拆解问题即可直接回答问题)or "tool_using" (使用工具来回答问题) or 'coding'(生成可执行的代码)。
在每个 $JSON_BLOB 中仅提供一个 action,如下所示:
```
{{'action': $ACTION}}
```
按照以下格式进行回应:
问题:输入问题以回答
行动:$ACTION
```
$JSON_BLOB
```
"""
@@ -1,218 +0,0 @@
|
|||
PRD_WRITER_METAGPT_PROMPT = """#### Agent Profile
|
||||
|
||||
You are a professional Product Manager, your goal is to design a concise, usable, efficient product.
|
||||
According to the context, fill in the following missing information; note that each section is returned separately in Python code triple-quote form.
|
||||
If the Origin Query are unclear, ensure minimum viability and avoid excessive design.
|
||||
ATTENTION: response carefully referenced "Response Output Format" in format.
|
||||
|
||||
|
||||
#### Input Format
|
||||
|
||||
**Origin Query:** the initial question or objective that the user wanted to achieve
|
||||
|
||||
**Context:** the current status and history of the tasks to determine if Origin Query has been achieved.
|
||||
|
||||
#### Response Output Format
|
||||
**Original Requirements:**
|
||||
The boss ...
|
||||
|
||||
**Product Goals:**
|
||||
```python
|
||||
[
|
||||
"Create a ...",
|
||||
]
|
||||
```
|
||||
|
||||
**User Stories:**
|
||||
```python
|
||||
[
|
||||
"As a user, ...",
|
||||
]
|
||||
```
|
||||
|
||||
**Competitive Analysis:**
|
||||
```python
|
||||
[
|
||||
"Python Snake Game: ...",
|
||||
]
|
||||
```
|
||||
|
||||
**Requirement Analysis:**
|
||||
The product should be a ...
|
||||
|
||||
**Requirement Pool:**
|
||||
```python
|
||||
[
|
||||
["End game ...", "P0"]
|
||||
]
|
||||
```
|
||||
|
||||
**UI Design draft:**
|
||||
Give a basic function description, and a draft
|
||||
|
||||
**Anything UNCLEAR:**
|
||||
There are no unclear points.
|
||||
"""
|
||||
|
||||
|
||||
|
||||
DESIGN_WRITER_METAGPT_PROMPT = """#### Agent Profile
|
||||
|
||||
You are an architect; the goal is to design a SOTA PEP8-compliant python system; make the best use of good open source tools.
|
||||
Fill in the following missing information based on the context; note that all sections are returned in code form separately.
|
||||
8192 chars or 2048 tokens. Try to use them up.
|
||||
ATTENTION: response carefully referenced "Response Output Format" in format.
|
||||
|
||||
#### Input Format
|
||||
|
||||
**Origin Query:** the initial question or objective that the user wanted to achieve
|
||||
|
||||
**Context:** the current status and history of the tasks to determine if Origin Query has been achieved.
|
||||
|
||||
#### Response Output Format
|
||||
**Implementation approach:**
|
||||
Provide as Plain text. Analyze the difficult points of the requirements, select the appropriate open-source framework.
|
||||
|
||||
**Python package name:**
|
||||
Provide as Python str with python triple quotes, concise and clear, characters only use a combination of all lowercase and underscores
|
||||
```python
|
||||
"snake_game"
|
||||
```
|
||||
|
||||
**File list:**
|
||||
Provided as Python list[str], the list of ONLY REQUIRED files needed to write the program(LESS IS MORE!). Only need relative paths, comply with PEP8 standards. ALWAYS write a main.py or app.py here
|
||||
|
||||
```python
|
||||
[
|
||||
"main.py",
|
||||
...
|
||||
]
|
||||
```
|
||||
|
||||
**Data structures and interface definitions:**
|
||||
Use mermaid classDiagram code syntax, including classes (INCLUDING __init__ method) and functions (with type annotations),
|
||||
CLEARLY MARK the RELATIONSHIPS between classes, and comply with PEP8 standards. The data structures SHOULD BE VERY DETAILED and the API should be comprehensive with a complete design.
|
||||
|
||||
```mermaid
|
||||
classDiagram
|
||||
class Game {{
|
||||
+int score
|
||||
}}
|
||||
...
|
||||
Game "1" -- "1" Food: has
|
||||
```
|
||||
|
||||
**Program call flow:**
|
||||
Use sequenceDiagram code syntax, COMPLETE and VERY DETAILED, using CLASSES AND API DEFINED ABOVE accurately, covering the CRUD AND INIT of each object, SYNTAX MUST BE CORRECT.
|
||||
```mermaid
|
||||
sequenceDiagram
|
||||
participant M as Main
|
||||
...
|
||||
G->>M: end game
|
||||
```
|
||||
|
||||
**Anything UNCLEAR:**
|
||||
Provide as Plain text. Make clear here.
|
||||
"""
|
||||
|
||||
|
||||
|
||||
TASK_WRITER_METAGPT_PROMPT = """#### Agent Profile
|
||||
|
||||
You are a project manager, the goal is to break down tasks according to PRD/technical design, give a task list, and analyze task dependencies to start with the prerequisite modules
|
||||
Based on the context, fill in the following missing information; note that all sections are returned separately in Python code triple-quote form.
|
||||
Here the granularity of the task is a file, if there are any missing files, you can supplement them
|
||||
8192 chars or 2048 tokens. Try to use them up.
|
||||
ATTENTION: response carefully referenced "Response Output Format" in format.
|
||||
|
||||
#### Input Format
|
||||
|
||||
**Origin Query:** the initial question or objective that the user wanted to achieve
|
||||
|
||||
**Context:** the current status and history of the tasks to determine if Origin Query has been achieved.
|
||||
|
||||
#### Response Output Format
|
||||
|
||||
**Required Python third-party packages:** Provided in requirements.txt format
|
||||
```python
|
||||
flask==1.1.2
|
||||
bcrypt==3.2.0
|
||||
...
|
||||
```
|
||||
|
||||
**Required Other language third-party packages:** Provided in requirements.txt format
|
||||
```python
|
||||
No third-party ...
|
||||
```
|
||||
|
||||
**Full API spec:** Use OpenAPI 3.0. Describe all APIs that may be used by both frontend and backend.
|
||||
```python
|
||||
openapi: 3.0.0
|
||||
...
|
||||
description: A JSON object ...
|
||||
```
|
||||
|
||||
**Logic Analysis:** Provided as a Python list[list[str]]; the first element is the filename, and the second is the class/method/function that should be implemented in this file. Analyze the dependencies between the files and which work should be done first
|
||||
```python
|
||||
[
|
||||
["game.py", "Contains ..."],
|
||||
]
|
||||
```
|
||||
|
||||
**PLAN:** Provided as Python list[str]. Each str is a filename, the more at the beginning, the more it is a prerequisite dependency, should be done first
|
||||
```python
|
||||
[
|
||||
"game.py",
|
||||
]
|
||||
```
|
||||
|
||||
**Shared Knowledge:** Anything that should be public like utils' functions, config's variables details that should make clear first.
|
||||
```python
|
||||
'game.py' contains ...
|
||||
```
|
||||
|
||||
**Anything UNCLEAR:**
|
||||
Provide as Plain text. Make clear here. For example, don't forget a main entry. don't forget to init 3rd party libs.
|
||||
"""
|
||||
|
||||
|
||||
CODE_WRITER_METAGPT_PROMPT = """#### Agent Profile
|
||||
|
||||
You are a professional engineer; the main goal is to write PEP8 compliant, elegant, modular, easy to read and maintain Python 3.9 code (but you can also use other programming language)
|
||||
|
||||
Code: Write code with triple quoto, based on the following list and context.
|
||||
1. Do your best to implement THIS ONLY ONE FILE. ONLY USE EXISTING API. IF NO API, IMPLEMENT IT.
|
||||
2. Requirement: Based on the context, implement one following code file, note to return only in code form, your code will be part of the entire project, so please implement complete, reliable, reusable code snippets
|
||||
3. Attention1: If there is any setting, ALWAYS SET A DEFAULT VALUE, ALWAYS USE STRONG TYPE AND EXPLICIT VARIABLE.
|
||||
4. Attention2: YOU MUST FOLLOW "Data structures and interface definitions". DONT CHANGE ANY DESIGN.
|
||||
5. Think before writing: What should be implemented and provided in this document?
|
||||
6. CAREFULLY CHECK THAT YOU DONT MISS ANY NECESSARY CLASS/FUNCTION IN THIS FILE.
|
||||
7. Do not use public member functions that do not exist in your design.
|
||||
8. **$key:** is Input format or Output format, *$key* is the context information; they are different.
|
||||
|
||||
8192 chars or 2048 tokens. Try to use them up.
|
||||
ATTENTION: response carefully referenced "Response Output Format" in format **$key:**.
|
||||
|
||||
|
||||
#### Input Format
|
||||
**Origin Query:** the user's origin query you should to be solved
|
||||
|
||||
**Context:** the current status and history of the tasks to determine if Origin Query has been achieved.
|
||||
|
||||
**Question:** clarify the current question to be solved
|
||||
|
||||
#### Response Output Format
|
||||
**Action Status:** Coding2File
|
||||
|
||||
**SaveFileName:** construct a local file name based on Question and Context, such as
|
||||
|
||||
```python
|
||||
$projectname/$filename.py
|
||||
```
|
||||
|
||||
**Code:** Write your code here
|
||||
```python
|
||||
# Write your code here
|
||||
```
|
||||
|
||||
"""
|
|
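A reply in the Coding2File format above has to be split back into a file name and a code body before anything can be written to disk. A hedged sketch of that step (an assumption for illustration, not the project's actual parser) might be:

```python
# Hypothetical parser for the Coding2File response format described above.
import re


def parse_code_writer_reply(reply: str) -> tuple[str, str]:
    """Pull SaveFileName and the final Code block out of a Coding2File reply."""
    blocks = re.findall(r"```python\n(.*?)```", reply, re.S)
    # By the format above, the first python block is the save path and the last one is the code.
    save_file = blocks[0].strip() if blocks else "main.py"
    code = blocks[-1].strip() if len(blocks) > 1 else ""
    return save_file, code
```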
@@ -1,113 +0,0 @@


PLANNER_TEMPLATE_PROMPT = """#### Agent Profile

When users need assistance with generating a sequence of achievable tasks, your role is to provide a coherent and continuous plan.
Design the plan step by step, ensuring each task builds on the completion of the previous one.
Each instruction should be actionable and directly follow from the outcome of the preceding step.
ATTENTION: response carefully referenced "Response Output Format" in format.

#### Input Format

**Question:** First, clarify the problem to be solved.

#### Response Output Format

**Action Status:** Set to 'finished' or 'planning'.
If it's 'finished', the PLAN is to provide the final answer to the original question.
If it's 'planning', the PLAN is to provide a Python list[str] of achievable tasks.

**PLAN:**
```list
[
    "First, we should ...",
]
```

"""


TOOL_PLANNER_PROMPT = """#### Agent Profile

Helps the user to break down a process of tool usage into a series of plans.
If there are no available tools, can directly answer the question.
Respond to humans in the most helpful and accurate way possible.

#### Input Format

**Origin Query:** the initial question or objective that the user wanted to achieve

**Context:** the current status and history of the tasks to determine if Origin Query has been achieved.

#### Response Output Format

**Action Status:** Set to 'finished' or 'planning'. If it's 'finished', the PLAN is to provide the final answer to the original question. If it's 'planning', the PLAN is to provide a sequence of achievable tasks.

**PLAN:**
```python
[
    "First, we should ...",
]
```
"""


GENERAL_PLANNER_PROMPT = """你是一个通用计划拆解助手,将问题拆解问题成各个详细明确的步骤计划或直接回答问题,尽可能地以有帮助和准确的方式回应人类,
使用 JSON Blob 来指定一个返回的内容,提供一个 action(行动)和一个 plans (生成的计划)。
有效的 'action' 值为:'planning'(拆解计划) or 'only_answer' (不需要拆解问题即可直接回答问题)。
有效的 'plans' 值为: 一个任务列表,按顺序写出需要执行的计划
在每个 $JSON_BLOB 中仅提供一个 action,如下所示:
```
{{'action': 'planning', 'plans': [$PLAN1, $PLAN2, $PLAN3, ..., $PLANN], }}
或者
{{'action': 'only_answer', 'plans': "直接回答问题", }}
```

按照以下格式进行回应:
问题:输入问题以回答
行动:
```
$JSON_BLOB
```
"""


DATA_PLANNER_PROMPT = """你是一个数据分析助手,能够根据问题来制定一个详细明确的数据分析计划,尽可能地以有帮助和准确的方式回应人类,
使用 JSON Blob 来指定一个返回的内容,提供一个 action(行动)和一个 plans (生成的计划)。
有效的 'action' 值为:'planning'(拆解计划) or 'only_answer' (不需要拆解问题即可直接回答问题)。
有效的 'plans' 值为: 一份数据分析计划清单,按顺序排列,用文本表示
在每个 $JSON_BLOB 中仅提供一个 action,如下所示:
```
{{'action': 'planning', 'plans': '$PLAN1, $PLAN2, ..., $PLAN3' }}
```

按照以下格式进行回应:
问题:输入问题以回答
行动:
```
$JSON_BLOB
```
"""


# TOOL_PLANNER_PROMPT = """你是一个工具使用过程的计划拆解助手,将问题拆解为一系列的工具使用计划,若没有可用工具则直接回答问题,尽可能地以有帮助和准确的方式回应人类,你可以使用以下工具:
# {formatted_tools}
# 使用 JSON Blob 来指定一个返回的内容,提供一个 action(行动)和一个 plans (生成的计划)。
# 有效的 'action' 值为:'planning'(拆解计划) or 'only_answer' (不需要拆解问题即可直接回答问题)。
# 有效的 'plans' 值为: 一个任务列表,按顺序写出需要使用的工具和使用该工具的理由
# 在每个 $JSON_BLOB 中仅提供一个 action,如下两个示例所示:
# ```
# {{'action': 'planning', 'plans': [$PLAN1, $PLAN2, $PLAN3, ..., $PLANN], }}
# ```
# 或者 若无法通过以上工具解决问题,则直接回答问题
# ```
# {{'action': 'only_answer', 'plans': "直接回答问题", }}
# ```

# 按照以下格式进行回应:
# 问题:输入问题以回答
# 行动:
# ```
# $JSON_BLOB
# ```
# """
@ -1,74 +0,0 @@
|
|||
# Question Answer Assistance Guidance
|
||||
|
||||
QA_TEMPLATE_PROMPT = """#### Agent Profile
|
||||
|
||||
Based on the information provided, please answer the origin query concisely and professionally.
|
||||
Attention: Follow the input format and response output format
|
||||
|
||||
#### Input Format
|
||||
|
||||
**Origin Query:** the initial question or objective that the user wanted to achieve
|
||||
|
||||
**Context:** the current status and history of the tasks to determine if Origin Query has been achieved.
|
||||
|
||||
**DocInfos:** the relevant doc or code information; if this is empty, don't refer to it.
|
||||
|
||||
#### Response Output Format
|
||||
**Action Status:** Set to 'Continued' or 'Stopped'.
|
||||
**Answer:** Response to the user's origin query based on Context and DocInfos. If DocInfos is empty, you can ignore it.
|
||||
If the answer cannot be derived from the given Context and DocInfos, please say 'The question cannot be answered based on the information provided' and do not add any fabricated elements to the answer.
|
||||
"""
|
||||
|
||||
|
||||
CODE_QA_PROMPT = """#### Agent Profile
|
||||
|
||||
Based on the information provided, please answer the origin query concisely and professionally.
|
||||
Attention: Follow the input format and response output format
|
||||
|
||||
#### Input Format
|
||||
|
||||
**Origin Query:** the initial question or objective that the user wanted to achieve
|
||||
|
||||
**DocInfos:**: the relevant doc information or code information, if this is empty, don't refer to this.
|
||||
|
||||
#### Response Output Format
|
||||
**Action Status:** Set to 'Continued' or 'Stopped'.
|
||||
**Answer:** Response to the user's origin query based on Context and DocInfos. If DocInfos is empty, you can ignore it.
|
||||
If the answer cannot be derived from the given Context and DocInfos, please say 'The question cannot be answered based on the information provided' and do not add any fabricated elements to the answer.
|
||||
"""
|
||||
|
||||
|
||||
QA_PROMPT = """根据已知信息,简洁和专业的来回答问题。如果无法从中得到答案,请说 “根据已知信息无法回答该问题”,不允许在答案中添加编造成分,答案请使用中文。
|
||||
使用 JSON Blob 来指定一个返回的内容,提供一个 action(行动)。
|
||||
有效的 'action' 值为:'finished'(任务已经可以通过上下文信息可以回答) or 'continue' (上下文信息不足以回答问题)。
|
||||
在每个 $JSON_BLOB 中仅提供一个 action,如下所示:
|
||||
```
|
||||
{{'action': $ACTION, 'content': '总结对话内容'}}
|
||||
```
|
||||
按照以下格式进行回应:
|
||||
问题:输入问题以回答
|
||||
行动:$ACTION
|
||||
```
|
||||
$JSON_BLOB
|
||||
```
|
||||
"""
|
||||
|
||||
# Prompt template for Q&A based on the local codebase
|
||||
CODE_PROMPT_TEMPLATE = """【指令】根据已知信息来回答问题。
|
||||
|
||||
【已知信息】{context}
|
||||
|
||||
【问题】{question}"""
|
||||
|
||||
# Code explanation template
|
||||
CODE_INTERPERT_TEMPLATE = '''{code}
|
||||
|
||||
解释一下这段代码'''
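
As a quick illustration of how this template is meant to be filled (a sketch; the snippet below is a made-up example, not project code):

```python
snippet = "def add(a, b):\n    return a + b"

# CODE_INTERPERT_TEMPLATE ends with the Chinese instruction "解释一下这段代码" ("explain this code"),
# so formatting it simply prepends the code to that request.
prompt = CODE_INTERPERT_TEMPLATE.format(code=snippet)
print(prompt)
```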
|
||||
# CODE_QA_PROMPT = """【指令】根据已知信息来回答问"""
|
||||
|
||||
# Prompt template for Q&A based on the local knowledge base
|
||||
ORIGIN_TEMPLATE_PROMPT = """【指令】根据已知信息,简洁和专业的来回答问题。如果无法从中得到答案,请说 “根据已知信息无法回答该问题”,不允许在答案中添加编造成分,答案请使用中文。
|
||||
|
||||
【已知信息】{context}
|
||||
|
||||
【问题】{question}"""
|
|
@ -1,103 +0,0 @@
|
|||
|
||||
|
||||
# REACT_CODE_PROMPT = """#### Agent Profile
|
||||
|
||||
# 1. When users need help with coding, your role is to provide precise and effective guidance.
|
||||
# 2. Reply follows the format of Thoughts/Action Status/Action/Observation cycle.
|
||||
# 3. Provide the final answer if they can solve the problem, otherwise, write the code step by step, showing only the part necessary to solve the current problem.
|
||||
# Each reply should contain only the guidance required for the current step either by tool usage or code.
|
||||
# 4. If the Response already contains content, continue writing following the format of the Response Output Format.
|
||||
|
||||
# #### Response Output Format
|
||||
|
||||
# **Thoughts:** Considering the session records and executed steps, solve the problem step by step, only displaying the thought process necessary for the current step of solving the problem,
|
||||
# outline the plan for executing this step.
|
||||
|
||||
# **Action Status:** Set to 'stopped' or 'code_executing'.
|
||||
# If it's 'stopped', the action is to provide the final answer to the session records and executed steps.
|
||||
# If it's 'code_executing', the action is to write the code.
|
||||
|
||||
# **Action:**
|
||||
# ```python
|
||||
# # Write your code here
|
||||
# ...
|
||||
# ```
|
||||
|
||||
# **Observation:** Check the results and effects of the executed code.
|
||||
|
||||
# ... (Repeat this "Thoughts/Action Status/Action/Observation" cycle format as needed)
|
||||
|
||||
# **Thoughts:** Considering the session records and executed steps, give the final answer
|
||||
# .
|
||||
# **Action Status:** stopped
|
||||
|
||||
# **Action:** Response the final answer to the session records.
|
||||
|
||||
# """
|
||||
|
||||
REACT_CODE_PROMPT = """#### Agent Profile
|
||||
|
||||
When users need help with coding, your role is to provide precise and effective guidance.
|
||||
|
||||
Write the code step by step, showing only the part necessary to solve the current problem. Each reply should contain only the code required for the current step.
|
||||
|
||||
#### Response Output Format
|
||||
|
||||
**Thoughts:** According to the previous context, solve the problem step by step, only displaying the thought process necessary for the current step of solving the problem,
|
||||
outline the plan for executing this step.
|
||||
|
||||
**Action Status:** Set to 'stopped' or 'code_executing'.
|
||||
If it's 'stopped', the action is to provide the final answer to the session records and executed steps.
|
||||
If it's 'code_executing', the action is to write the code.
|
||||
|
||||
**Action:**
|
||||
```python
|
||||
# Write your code here
|
||||
...
|
||||
```
|
||||
|
||||
**Observation:** Check the results and effects of the executed code.
|
||||
|
||||
... (Repeat this "Thoughts/Action Status/Action/Observation" cycle format as needed)
|
||||
|
||||
**Thoughts:** Considering the session records and executed steps, give the final answer.
|
||||
|
||||
**Action Status:** stopped
|
||||
|
||||
**Action:** Respond with the final answer to the session records.
|
||||
|
||||
"""
|
||||
|
||||
|
||||
# REACT_CODE_PROMPT = """#### Writing Code Assistance Guidance
|
||||
|
||||
# When users need help with coding, your role is to provide precise and effective guidance.
|
||||
|
||||
# Write the code step by step, showing only the part necessary to solve the current problem. Each reply should contain only the code required for the current step.
|
||||
|
||||
# #### Response Process
|
||||
|
||||
# **Question:** First, clarify the problem to be solved.
|
||||
|
||||
# **Thoughts:** Based on the question and observations above, provide the plan for executing this step.
|
||||
|
||||
# **Action Status:** Set to 'stoped' or 'code_executing'. If it's 'stoped', the action is to provide the final answer to the original question. If it's 'code_executing', the action is to write the code.
|
||||
|
||||
# **Action:**
|
||||
# ```python
|
||||
# # Write your code here
|
||||
# import os
|
||||
# ...
|
||||
# ```
|
||||
|
||||
# **Observation:** Check the results and effects of the executed code.
|
||||
|
||||
# ... (Repeat this Thoughts/Action/Observation cycle as needed)
|
||||
|
||||
# **Thoughts:** I now know the final answer
|
||||
|
||||
# **Action Status:** Set to 'stoped'
|
||||
|
||||
# **Action:** The final answer to the original input question
|
||||
|
||||
# """
|
|
@ -1,37 +0,0 @@
|
|||
|
||||
|
||||
REACT_TEMPLATE_PROMPT = """#### Agent Profile
|
||||
|
||||
1. When users need help with coding, your role is to provide precise and effective guidance.
|
||||
2. Reply follows the format of Thoughts/Action Status/Action/Observation cycle.
|
||||
3. Provide the final answer if the problem can be solved; otherwise, write the code step by step, showing only the part necessary to solve the current problem.
|
||||
Each reply should contain only the guidance required for the current step either by tool usage or code.
|
||||
4. If the Response already contains content, continue writing following the format of the Response Output Format.
|
||||
|
||||
ATTENTION: Under the "Response" heading, the output format strictly adheres to the content specified in the "Response Output Format."
|
||||
|
||||
#### Response Output Format
|
||||
|
||||
**Question:** First, clarify the problem to be solved.
|
||||
|
||||
**Thoughts:** Based on the Session Records or observations above, provide the plan for executing this step.
|
||||
|
||||
**Action Status:** Set to either 'stopped' or 'code_executing'. If it's 'stopped', the next action is to provide the final answer to the original question. If it's 'code_executing', the next step is to write the code.
|
||||
|
||||
**Action:** Code according to your thoughts. Use this format for code:
|
||||
|
||||
```python
|
||||
# Write your code here
|
||||
```
|
||||
|
||||
**Observation:** Check the results and effects of the executed code.
|
||||
|
||||
... (Repeat this "Thoughts/Action Status/Action/Observation" cycle format as needed)
|
||||
|
||||
**Thoughts:** Considering the session records and executed steps, give the final answer.
|
||||
|
||||
**Action Status:** stopped
|
||||
|
||||
**Action:** Respond with the final answer to the session records.
|
||||
|
||||
"""
|
|
@ -1,44 +0,0 @@
|
|||
REACT_TOOL_AND_CODE_PLANNER_PROMPT = """#### Agent Profile
|
||||
When users seek assistance in breaking down complex issues into manageable and actionable steps,
|
||||
your responsibility is to deliver a well-organized strategy or resolution through the use of tools or coding.
|
||||
|
||||
ATTENTION: respond carefully, following the "Response Output Format".
|
||||
|
||||
#### Input Format
|
||||
|
||||
**Question:** First, clarify the problem to be solved.
|
||||
|
||||
#### Response Output Format
|
||||
|
||||
**Action Status:** Set to 'planning' to provide a sequence of tasks, or 'only_answer' to provide a direct response without a plan.
|
||||
|
||||
**Action:**
|
||||
```list
|
||||
"First, we should ...",
|
||||
]
|
||||
```
|
||||
|
||||
Or, provide the direct answer.
|
||||
"""
|
||||
|
||||
# REACT_TOOL_AND_CODE_PLANNER_PROMPT = """你是一个工具和代码使用过程的计划拆解助手,将问题拆解为一系列的工具使用计划,若没有可用工具则使用代码,尽可能地以有帮助和准确的方式回应人类,你可以使用以下工具:
|
||||
# {formatted_tools}
|
||||
# 使用 JSON Blob 来指定一个返回的内容,提供一个 action(行动)和一个 plans (生成的计划)。
|
||||
# 有效的 'action' 值为:'planning'(拆解计划) or 'only_answer' (不需要拆解问题即可直接回答问题)。
|
||||
# 有效的 'plans' 值为: 一个任务列表,按顺序写出需要使用的工具和使用该工具的理由
|
||||
# 在每个 $JSON_BLOB 中仅提供一个 action,如下两个示例所示:
|
||||
# ```
|
||||
# {{'action': 'planning', 'plans': [$PLAN1, $PLAN2, $PLAN3, ..., $PLANN], }}
|
||||
# ```
|
||||
# 或者 若无法通过以上工具或者代码解决问题,则直接回答问题
|
||||
# ```
|
||||
# {{'action': 'only_answer', 'plans': "直接回答问题", }}
|
||||
# ```
|
||||
|
||||
# 按照以下格式进行回应($JSON_BLOB要求符合上述规定):
|
||||
# 问题:输入问题以回答
|
||||
# 行动:
|
||||
# ```
|
||||
# $JSON_BLOB
|
||||
# ```
|
||||
# """
|
|
@ -1,197 +0,0 @@
|
|||
REACT_TOOL_AND_CODE_PROMPT = """#### Agent Profile
|
||||
|
||||
When users need help with coding or using tools, your role is to provide precise and effective guidance.
|
||||
Use the tools provided if they can solve the problem, otherwise, write the code step by step, showing only the part necessary to solve the current problem.
|
||||
Each reply should contain only the guidance required for the current step either by tool usage or code.
|
||||
ATTENTION: The Action Status field ensures that the tools or code mentioned in the Action can be parsed smoothly. Please make sure not to omit the Action Status field when replying.
|
||||
|
||||
#### Tool Information
|
||||
|
||||
You can use these tools:\n{formatted_tools}
|
||||
|
||||
Valid "tool_name" value:\n{tool_names}
|
||||
|
||||
#### Response Output Format
|
||||
|
||||
**Thoughts:** Considering the session records and executed steps, decide whether the current step requires the use of a tool or code_executing. Solve the problem step by step, only displaying the thought process necessary for the current step of solving the problem. If code_executing is required, outline the plan for executing this step.
|
||||
|
||||
**Action Status:** stopped, tool_using or code_executing
|
||||
Use 'stopped' when the task has been completed, and no further use of tools or execution of code is necessary.
|
||||
Use 'tool_using' when the current step in the process involves utilizing a tool to proceed.
|
||||
Use 'code_executing' when the current step requires writing and executing code.
|
||||
|
||||
**Action:**
|
||||
|
||||
If Action Status is 'tool_using', format the tool action in JSON from Question and Observation, enclosed in a code block, like this:
|
||||
```json
|
||||
{
|
||||
"tool_name": "$TOOL_NAME",
|
||||
"tool_params": "$INPUT"
|
||||
}
|
||||
```
|
||||
|
||||
If Action Status is 'code_executing', write the necessary code to solve the issue, enclosed in a code block, like this:
|
||||
```python
|
||||
Write your running code here
|
||||
```
|
||||
|
||||
If Action Status is 'stopped', provide the final response or instructions in written form, enclosed in a code block, like this:
|
||||
```text
|
||||
The final response or instructions to the user question.
|
||||
```
|
||||
|
||||
**Observation:** Check the results and effects of the executed action.
|
||||
|
||||
... (Repeat this Thoughts/Action Status/Action/Observation cycle as needed)
|
||||
|
||||
**Thoughts:** Conclude the final response to the user question.
|
||||
|
||||
**Action Status:** stopped
|
||||
|
||||
**Action:** The final answer or guidance to the user question.
|
||||
"""
|
||||
|
||||
# REACT_TOOL_AND_CODE_PROMPT = """#### Agent Profile
|
||||
|
||||
# 1. When users need help with coding or using tools, your role is to provide precise and effective guidance.
|
||||
# 2. Reply follows the format of Thoughts/Action Status/Action/Observation cycle.
|
||||
# 3. Use the tools provided if they can solve the problem, otherwise, write the code step by step, showing only the part necessary to solve the current problem.
|
||||
# Each reply should contain only the guidance required for the current step either by tool usage or code.
|
||||
# 4. If the Response already contains content, continue writing following the format of the Response Output Format.
|
||||
|
||||
# ATTENTION: The "Action Status" field ensures that the tools or code mentioned in the "Action" can be parsed smoothly. Please make sure not to omit the "Action Status" field when replying.
|
||||
|
||||
# #### Tool Infomation
|
||||
|
||||
# You can use these tools:\n{formatted_tools}
|
||||
|
||||
# Valid "tool_name" value:\n{tool_names}
|
||||
|
||||
# #### Response Output Format
|
||||
|
||||
# **Thoughts:** Considering the user's question, previously executed steps, and the plan, decide whether the current step requires the use of a tool or code_executing.
|
||||
# Solve the problem step by step, only displaying the thought process necessary for the current step of solving the problem.
|
||||
# If a tool can be used, provide its name and parameters. If code_executing is required, outline the plan for executing this step.
|
||||
|
||||
# **Action Status:** stoped, tool_using, or code_executing. (Choose one from these three statuses.)
|
||||
# # If the task is done, set it to 'stoped'.
|
||||
# # If using a tool, set it to 'tool_using'.
|
||||
# # If writing code, set it to 'code_executing'.
|
||||
|
||||
# **Action:**
|
||||
|
||||
# If Action Status is 'tool_using', format the tool action in JSON from Question and Observation, enclosed in a code block, like this:
|
||||
# ```json
|
||||
# {
|
||||
# "tool_name": "$TOOL_NAME",
|
||||
# "tool_params": "$INPUT"
|
||||
# }
|
||||
# ```
|
||||
|
||||
# If Action Status is 'code_executing', write the necessary code to solve the issue, enclosed in a code block, like this:
|
||||
# ```python
|
||||
# Write your running code here
|
||||
# ...
|
||||
# ```
|
||||
|
||||
# If Action Status is 'stopped', provide the final response or instructions in written form, enclosed in a code block, like this:
|
||||
# ```text
|
||||
# The final response or instructions to the original input question.
|
||||
# ```
|
||||
|
||||
# **Observation:** Check the results and effects of the executed action.
|
||||
|
||||
# ... (Repeat this Thoughts/Action Status/Action/Observation cycle as needed)
|
||||
|
||||
# **Thoughts:** Considering the user's question, previously executed steps, give the final answer.
|
||||
|
||||
# **Action Status:** stopped
|
||||
|
||||
# **Action:** Response the final answer to the session records.
|
||||
# """
|
||||
|
||||
|
||||
# REACT_TOOL_AND_CODE_PROMPT = """#### Code and Tool Agent Assistance Guidance
|
||||
|
||||
# When users need help with coding or using tools, your role is to provide precise and effective guidance. Use the tools provided if they can solve the problem, otherwise, write the code step by step, showing only the part necessary to solve the current problem. Each reply should contain only the guidance required for the current step either by tool usage or code.
|
||||
|
||||
# #### Tool Infomation
|
||||
|
||||
# You can use these tools:\n{formatted_tools}
|
||||
|
||||
# Valid "tool_name" value:\n{tool_names}
|
||||
|
||||
# #### Response Process
|
||||
|
||||
# **Question:** Start by understanding the input question to be answered.
|
||||
|
||||
# **Thoughts:** Considering the user's question, previously executed steps, and the plan, decide whether the current step requires the use of a tool or code_executing. Solve the problem step by step, only displaying the thought process necessary for the current step of solving the problem. If a tool can be used, provide its name and parameters. If code_executing is required, outline the plan for executing this step.
|
||||
|
||||
# **Action Status:** stoped, tool_using, or code_executing. (Choose one from these three statuses.)
|
||||
# If the task is done, set it to 'stoped'.
|
||||
# If using a tool, set it to 'tool_using'.
|
||||
# If writing code, set it to 'code_executing'.
|
||||
|
||||
# **Action:**
|
||||
|
||||
# If using a tool, use the tools by formatting the tool action in JSON from Question and Observation:. The format should be:
|
||||
# ```json
|
||||
# {{
|
||||
# "tool_name": "$TOOL_NAME",
|
||||
# "tool_params": "$INPUT"
|
||||
# }}
|
||||
# ```
|
||||
|
||||
# If the problem cannot be solved with a tool at the moment, then proceed to solve the issue using code. Output the following format to execute the code:
|
||||
|
||||
# ```python
|
||||
# Write your code here
|
||||
# ```
|
||||
|
||||
# **Observation:** Check the results and effects of the executed action.
|
||||
|
||||
# ... (Repeat this Thoughts/Action/Observation cycle as needed)
|
||||
|
||||
# **Thoughts:** Conclude the final response to the input question.
|
||||
|
||||
# **Action Status:** stoped
|
||||
|
||||
# **Action:** The final answer or guidance to the original input question.
|
||||
# """
|
||||
|
||||
|
||||
# REACT_TOOL_AND_CODE_PROMPT = """你是一个使用工具与代码的助手。
|
||||
# 如果现有工具不足以完成整个任务,请不要添加不存在的工具,只使用现有工具完成可能的部分。
|
||||
# 如果当前步骤不能使用工具完成,将由代码来完成。
|
||||
# 有效的"action"值为:"stopped"(已经完成用户的任务) 、 "tool_using" (使用工具来回答问题) 或 'code_executing'(结合总结下述思维链过程编写下一步的可执行代码)。
|
||||
# 尽可能地以有帮助和准确的方式回应人类,你可以使用以下工具:
|
||||
# {formatted_tools}
|
||||
# 如果现在的步骤可以用工具解决问题,请仅在每个$JSON_BLOB中提供一个action,如下所示:
|
||||
# ```
|
||||
# {{{{
|
||||
# "action": $ACTION,
|
||||
# "tool_name": $TOOL_NAME
|
||||
# "tool_params": $INPUT
|
||||
# }}}}
|
||||
# ```
|
||||
# 若当前无法通过工具解决问题,则使用代码解决问题
|
||||
# 请仅在每个$JSON_BLOB中提供一个action,如下所示:
|
||||
# ```
|
||||
# {{{{'action': $ACTION,'code_content': $CODE}}}}
|
||||
# ```
|
||||
|
||||
# 按照以下思维链格式进行回应($JSON_BLOB要求符合上述规定):
|
||||
# 问题:输入问题以回答
|
||||
# 思考:考虑之前和之后的步骤
|
||||
# 行动:
|
||||
# ```
|
||||
# $JSON_BLOB
|
||||
# ```
|
||||
# 观察:行动结果
|
||||
# ...(重复思考/行动/观察N次)
|
||||
# 思考:我知道该如何回应
|
||||
# 行动:
|
||||
# ```
|
||||
# $JSON_BLOB
|
||||
# ```
|
||||
# """
|
|
@ -1,77 +0,0 @@
|
|||
REACT_TOOL_PROMPT = """#### Agent Profile
|
||||
|
||||
When interacting with users, your role is to respond in a helpful and accurate manner using the tools available. Follow the steps below to ensure efficient and effective use of the tools.
|
||||
|
||||
Please note that all the tools you can use are listed below. You can only choose from these tools for use.
|
||||
|
||||
If there are no suitable tools, please do not invent any tools. Just let the user know that you do not have suitable tools to use.
|
||||
|
||||
ATTENTION: The Action Status field ensures that the tools or code mentioned in the Action can be parsed smoothly. Please make sure not to omit the Action Status field when replying.
|
||||
|
||||
#### Response Output Format
|
||||
|
||||
**Thoughts:** According to the previous observations, plan the approach for using the tool effectively.
|
||||
|
||||
**Action Status:** Set to either 'stopped' or 'tool_using'. If 'stopped', provide the final response to the original question. If 'tool_using', proceed with using the specified tool.
|
||||
|
||||
**Action:** Use the tools by formatting the tool action in JSON. The format should be:
|
||||
|
||||
```json
|
||||
{
|
||||
"tool_name": "$TOOL_NAME",
|
||||
"tool_params": "$INPUT"
|
||||
}
|
||||
```
|
||||
|
||||
**Observation:** Evaluate the outcome of the tool's usage.
|
||||
|
||||
... (Repeat this Thoughts/Action Status/Action/Observation cycle as needed)
|
||||
|
||||
**Thoughts:** Determine the final response based on the results.
|
||||
|
||||
**Action Status:** Set to 'stopped'
|
||||
|
||||
**Action:** Conclude with the final response to the original question in this format:
|
||||
|
||||
```json
|
||||
{
|
||||
"tool_params": "Final response to be provided to the user",
|
||||
"tool_name": "notool",
|
||||
}
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
# REACT_TOOL_PROMPT = """尽可能地以有帮助和准确的方式回应人类。您可以使用以下工具:
|
||||
# {formatted_tools}
|
||||
# 使用json blob来指定一个工具,提供一个action关键字(工具名称)和一个tool_params关键字(工具输入)。
|
||||
# 有效的"action"值为:"stopped" 或 "tool_using" (使用工具来回答问题)
|
||||
# 有效的"tool_name"值为:{tool_names}
|
||||
# 请仅在每个$JSON_BLOB中提供一个action,如下所示:
|
||||
# ```
|
||||
# {{{{
|
||||
# "action": $ACTION,
|
||||
# "tool_name": $TOOL_NAME,
|
||||
# "tool_params": $INPUT
|
||||
# }}}}
|
||||
# ```
|
||||
|
||||
# 按照以下格式进行回应:
|
||||
# 问题:输入问题以回答
|
||||
# 思考:考虑之前和之后的步骤
|
||||
# 行动:
|
||||
# ```
|
||||
# $JSON_BLOB
|
||||
# ```
|
||||
# 观察:行动结果
|
||||
# ...(重复思考/行动/观察N次)
|
||||
# 思考:我知道该如何回应
|
||||
# 行动:
|
||||
# ```
|
||||
# {{{{
|
||||
# "action": "stopped",
|
||||
# "tool_name": "notool",
|
||||
# "tool_params": "最终返回答案给到用户"
|
||||
# }}}}
|
||||
# ```
|
||||
# """
|
|
@ -1,30 +0,0 @@
|
|||
REFINE_TEMPLATE_PROMPT = """#### Agent Profile
|
||||
|
||||
When users have a sequence of tasks that require optimization or adjustment based on feedback from the context, your role is to refine the existing plan.
|
||||
Your task is to identify where improvements can be made and provide a revised plan that is more efficient or effective.
|
||||
Each instruction should be an enhancement of the existing plan and should specify the step from which the changes should be implemented.
|
||||
|
||||
#### Input Format
|
||||
|
||||
**Context:** Review the history of the plan and feedback to identify areas for improvement.
|
||||
Take into consideration all feedback information from the current step. If there is no existing plan, generate a new one.
|
||||
|
||||
#### Response Output Format
|
||||
|
||||
**REASON:** think step by step about why you choose 'finished', 'unchanged' or 'adjusted'.
|
||||
|
||||
**Action Status:** Set to 'finished', 'unchanged' or 'adjusted'.
|
||||
If it's 'finished', all tasks are accomplished, and no adjustments are needed, so PLAN_STEP is set to -1.
|
||||
If it's 'unchanged', this PLAN has no problem, just set PLAN_STEP to CURRENT_STEP+1.
|
||||
If it's 'adjusted', the PLAN is to provide an optimized version of the original plan.
|
||||
|
||||
**PLAN:**
|
||||
```list
|
||||
[
|
||||
"First, we should ...",
|
||||
]
|
||||
```
|
||||
|
||||
**PLAN_STEP:** Set to the plan index from which the changes should start. The index ranges from 0 to n-1, or -1.
|
||||
If it's 'finished', the PLAN_STEP is -1. If it's 'adjusted', the PLAN_STEP is the index of the first revised task in the sequence.
|
||||
"""
|
|
@ -1,40 +0,0 @@
|
|||
CONV_SUMMARY_PROMPT = """尽可能地以有帮助和准确的方式回应人类,根据“背景信息”中的有效信息回答问题,
|
||||
使用 JSON Blob 来指定一个返回的内容,提供一个 action(行动)。
|
||||
有效的 'action' 值为:'finished'(任务已经可以通过上下文信息可以回答) or 'continue' (根据背景信息回答问题)。
|
||||
在每个 $JSON_BLOB 中仅提供一个 action,如下所示:
|
||||
```
|
||||
{{'action': $ACTION, 'content': '根据背景信息回答问题'}}
|
||||
```
|
||||
按照以下格式进行回应:
|
||||
问题:输入问题以回答
|
||||
行动:
|
||||
```
|
||||
$JSON_BLOB
|
||||
```
|
||||
"""
|
||||
|
||||
CONV_SUMMARY_PROMPT = """尽可能地以有帮助和准确的方式回应人类
|
||||
根据“背景信息”中的有效信息回答问题,同时展现解答的过程和内容
|
||||
若能根“背景信息”回答问题,则直接回答
|
||||
否则,总结“背景信息”的内容
|
||||
"""
|
||||
|
||||
|
||||
CONV_SUMMARY_PROMPT_SPEC = """
|
||||
Your job is to summarize a history of previous messages in a conversation between an AI persona and a human.
|
||||
The conversation you are given is a fixed context window and may not be complete.
|
||||
Messages sent by the AI are marked with the 'assistant' role.
|
||||
The AI 'assistant' can also make calls to functions, whose outputs can be seen in messages with the 'function' role.
|
||||
Things the AI says in the message content are considered inner monologue and are not seen by the user.
|
||||
The only AI messages seen by the user are from when the AI uses 'send_message'.
|
||||
Messages the user sends are in the 'user' role.
|
||||
The 'user' role is also used for important system events, such as login events and heartbeat events (heartbeats run the AI's program without user action, allowing the AI to act without prompting from the user sending them a message).
|
||||
Summarize what happened in the conversation from the perspective of the AI (use the first person).
|
||||
Keep your summary less than 100 words, do NOT exceed this word limit.
|
||||
Only output the summary, do NOT include anything else in your output.
|
||||
|
||||
--- conversation
|
||||
{conversation}
|
||||
---
|
||||
|
||||
"""
|
|
@ -1,489 +0,0 @@
|
|||
from abc import abstractmethod, ABC
|
||||
from typing import List, Dict
|
||||
import os, sys, copy, json
|
||||
from jieba.analyse import extract_tags
|
||||
from collections import Counter
|
||||
from loguru import logger
|
||||
|
||||
from langchain.docstore.document import Document
|
||||
|
||||
|
||||
from .schema import Memory, Message
|
||||
from coagent.service.service_factory import KBServiceFactory
|
||||
from coagent.llm_models import getChatModelFromConfig
|
||||
from coagent.llm_models.llm_config import EmbedConfig, LLMConfig
|
||||
from coagent.embeddings.utils import load_embeddings_from_path
|
||||
from coagent.utils.common_utils import save_to_json_file, read_json_file, addMinutesToTime
|
||||
from coagent.connector.configs.prompts import CONV_SUMMARY_PROMPT_SPEC
|
||||
from coagent.orm import table_init
|
||||
from coagent.base_configs.env_config import KB_ROOT_PATH
|
||||
# from configs.model_config import KB_ROOT_PATH, EMBEDDING_MODEL, EMBEDDING_DEVICE, SCORE_THRESHOLD
|
||||
# from configs.model_config import embedding_model_dict
|
||||
|
||||
|
||||
class BaseMemoryManager(ABC):
|
||||
"""
|
||||
This abstract class defines the memory manager interface; concrete managers such as LocalMemoryManager below inherit from it.
|
||||
|
||||
Attributes:
|
||||
- user_name: A string representing the user name. Default is "default".
|
||||
- unique_name: A string representing the unique name. Default is "default".
|
||||
- memory_type: A string representing the memory type. Default is "recall".
|
||||
- do_init: A boolean indicating whether to initialize. Default is False.
|
||||
- current_memory: An instance of Memory class representing the current memory.
|
||||
- recall_memory: An instance of Memory class representing the recall memory.
|
||||
- summary_memory: An instance of Memory class representing the summary memory.
|
||||
- save_message_keys: A list of strings representing the keys for saving messages.
|
||||
|
||||
Methods:
|
||||
- __init__: Initializes the memory manager with the given user_name, unique_name, memory_type, and do_init.
|
||||
- init_vb: Initializes the vector store (vb).
|
||||
- append: Appends a message to the recall memory, current memory, and summary memory.
|
||||
- extend: Extends the recall memory, current memory, and summary memory.
|
||||
- save: Saves the memory to the specified directory.
|
||||
- load: Loads the memory from the specified directory and returns a Memory instance.
|
||||
- save_new_to_vs: Saves new messages to the vector space.
|
||||
- save_to_vs: Saves the memory to the vector space.
|
||||
- router_retrieval: Routes the retrieval based on the retrieval type.
|
||||
- embedding_retrieval: Retrieves messages based on embedding.
|
||||
- text_retrieval: Retrieves messages based on text.
|
||||
- datetime_retrieval: Retrieves messages based on datetime.
|
||||
- recursive_summary: Performs recursive summarization of messages.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
user_name: str = "default",
|
||||
unique_name: str = "default",
|
||||
memory_type: str = "recall",
|
||||
do_init: bool = False,
|
||||
):
|
||||
"""
|
||||
Initializes the memory manager with the given parameters.
|
||||
|
||||
Args:
|
||||
- user_name: A string representing the user name. Default is "default".
|
||||
- unique_name: A string representing the unique name. Default is "default".
|
||||
- memory_type: A string representing the memory type. Default is "recall".
|
||||
- do_init: A boolean indicating whether to initialize. Default is False.
|
||||
"""
|
||||
self.user_name = user_name
|
||||
self.unique_name = unique_name
|
||||
self.memory_type = memory_type
|
||||
self.do_init = do_init
|
||||
# self.current_memory = Memory(messages=[])
|
||||
# self.recall_memory = Memory(messages=[])
|
||||
# self.summary_memory = Memory(messages=[])
|
||||
self.current_memory_dict: Dict[str, Memory] = {}
|
||||
self.recall_memory_dict: Dict[str, Memory] = {}
|
||||
self.summary_memory_dict: Dict[str, Memory] = {}
|
||||
self.save_message_keys = [
|
||||
'chat_index', 'role_name', 'role_type', 'role_prompt', 'input_query', 'origin_query',
|
||||
'datetime', 'role_content', 'step_content', 'parsed_output', 'spec_parsed_output', 'parsed_output_list',
|
||||
'task', 'db_docs', 'code_docs', 'search_docs', 'phase_name', 'chain_name', 'customed_kargs']
|
||||
self.init_vb()
|
||||
|
||||
def re_init(self, do_init: bool=False):
|
||||
self.init_vb()
|
||||
|
||||
def init_vb(self, do_init: bool=None):
|
||||
"""
|
||||
Initializes the vector store (vb).
|
||||
"""
|
||||
pass
|
||||
|
||||
def append(self, message: Message):
|
||||
"""
|
||||
Appends a message to the recall memory, current memory, and summary memory.
|
||||
|
||||
Args:
|
||||
- message: An instance of Message class representing the message to be appended.
|
||||
"""
|
||||
pass
|
||||
|
||||
def extend(self, memory: Memory):
|
||||
"""
|
||||
Extends the recall memory, current memory, and summary memory.
|
||||
|
||||
Args:
|
||||
- memory: An instance of Memory class representing the memory to be extended.
|
||||
"""
|
||||
pass
|
||||
|
||||
def save(self, save_dir: str = ""):
|
||||
"""
|
||||
Saves the memory to the specified directory.
|
||||
|
||||
Args:
|
||||
- save_dir: A string representing the directory to save the memory. Default is KB_ROOT_PATH.
|
||||
"""
|
||||
pass
|
||||
|
||||
def load(self, load_dir: str = "") -> Memory:
|
||||
"""
|
||||
Loads the memory from the specified directory and returns a Memory instance.
|
||||
|
||||
Args:
|
||||
- load_dir: A string representing the directory to load the memory from. Default is KB_ROOT_PATH.
|
||||
|
||||
Returns:
|
||||
- An instance of Memory class representing the loaded memory.
|
||||
"""
|
||||
pass
|
||||
|
||||
def save_new_to_vs(self, messages: List[Message]):
|
||||
"""
|
||||
Saves new messages to the vector space.
|
||||
|
||||
Args:
|
||||
- messages: A list of Message instances representing the messages to be saved.
|
||||
- embed_model: A string representing the embedding model. Default is EMBEDDING_MODEL.
|
||||
- embed_device: A string representing the embedding device. Default is EMBEDDING_DEVICE.
|
||||
"""
|
||||
pass
|
||||
|
||||
def save_to_vs(self, ):
|
||||
"""
|
||||
Saves the memory to the vector space.
|
||||
"""
|
||||
pass
|
||||
|
||||
def get_memory_pool(self, user_name: str, ):
|
||||
"""
|
||||
return memory_pool
|
||||
"""
|
||||
pass
|
||||
|
||||
def router_retrieval(self, text: str=None, datetime: str = None, n=5, top_k=5, retrieval_type: str = "embedding", **kwargs) -> List[Message]:
|
||||
"""
|
||||
Routes the retrieval based on the retrieval type.
|
||||
|
||||
Args:
|
||||
- text: A string representing the text for retrieval. Default is None.
|
||||
- datetime: A string representing the datetime for retrieval. Default is None.
|
||||
- n: An integer representing the number of messages. Default is 5.
|
||||
- top_k: An integer representing the top k messages. Default is 5.
|
||||
- retrieval_type: A string representing the retrieval type. Default is "embedding".
|
||||
- **kwargs: Additional keyword arguments for retrieval.
|
||||
|
||||
Returns:
|
||||
- A list of Message instances representing the retrieved messages.
|
||||
"""
|
||||
pass
|
||||
|
||||
def embedding_retrieval(self, text: str, embed_model="", top_k=1, score_threshold=1.0, **kwargs) -> List[Message]:
|
||||
"""
|
||||
Retrieves messages based on embedding.
|
||||
|
||||
Args:
|
||||
- text: A string representing the text for retrieval.
|
||||
- embed_model: A string representing the embedding model. Default is EMBEDDING_MODEL.
|
||||
- top_k: An integer representing the top k messages. Default is 1.
|
||||
- score_threshold: A float representing the score threshold. Default is SCORE_THRESHOLD.
|
||||
- **kwargs: Additional keyword arguments for retrieval.
|
||||
|
||||
Returns:
|
||||
- A list of Message instances representing the retrieved messages.
|
||||
"""
|
||||
pass
|
||||
|
||||
def text_retrieval(self, text: str, **kwargs) -> List[Message]:
|
||||
"""
|
||||
Retrieves messages based on text.
|
||||
|
||||
Args:
|
||||
- text: A string representing the text for retrieval.
|
||||
- **kwargs: Additional keyword arguments for retrieval.
|
||||
|
||||
Returns:
|
||||
- A list of Message instances representing the retrieved messages.
|
||||
"""
|
||||
pass
|
||||
|
||||
def datetime_retrieval(self, datetime: str, text: str = None, n: int = 5, **kwargs) -> List[Message]:
|
||||
"""
|
||||
Retrieves messages based on datetime.
|
||||
|
||||
Args:
|
||||
- datetime: A string representing the datetime for retrieval.
|
||||
- text: A string representing the text for retrieval. Default is None.
|
||||
- n: An integer representing the number of messages. Default is 5.
|
||||
- **kwargs: Additional keyword arguments for retrieval.
|
||||
|
||||
Returns:
|
||||
- A list of Message instances representing the retrieved messages.
|
||||
"""
|
||||
pass
|
||||
|
||||
def recursive_summary(self, messages: List[Message], split_n: int = 20) -> List[Message]:
|
||||
"""
|
||||
Performs recursive summarization of messages.
|
||||
|
||||
Args:
|
||||
- messages: A list of Message instances representing the messages to be summarized.
|
||||
- split_n: An integer representing the split n. Default is 20.
|
||||
|
||||
Returns:
|
||||
- A list of Message instances representing the summarized messages.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class LocalMemoryManager(BaseMemoryManager):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embed_config: EmbedConfig,
|
||||
llm_config: LLMConfig,
|
||||
user_name: str = "default",
|
||||
unique_name: str = "default",
|
||||
memory_type: str = "recall",
|
||||
do_init: bool = False,
|
||||
kb_root_path: str = KB_ROOT_PATH,
|
||||
):
|
||||
self.user_name = user_name
|
||||
self.unique_name = unique_name
|
||||
self.memory_type = memory_type
|
||||
self.do_init = do_init
|
||||
self.kb_root_path = kb_root_path
|
||||
self.embed_config: EmbedConfig = embed_config
|
||||
self.llm_config: LLMConfig = llm_config
|
||||
# self.current_memory = Memory(messages=[])
|
||||
# self.recall_memory = Memory(messages=[])
|
||||
# self.summary_memory = Memory(messages=[])
|
||||
self.current_memory_dict: Dict[str, Memory] = {}
|
||||
self.recall_memory_dict: Dict[str, Memory] = {}
|
||||
self.summary_memory_dict: Dict[str, Memory] = {}
|
||||
self.save_message_keys = [
|
||||
'chat_index', 'role_name', 'role_type', 'role_prompt', 'input_query', 'origin_query',
|
||||
'datetime', 'role_content', 'step_content', 'parsed_output', 'spec_parsed_output', 'parsed_output_list',
|
||||
'task', 'db_docs', 'code_docs', 'search_docs', 'phase_name', 'chain_name', 'customed_kargs']
|
||||
self.init_vb()
|
||||
|
||||
def re_init(self, do_init: bool=False):
|
||||
self.init_vb(do_init)
|
||||
|
||||
def init_vb(self, do_init: bool=None):
|
||||
vb_name = f"{self.user_name}/{self.unique_name}/{self.memory_type}"
|
||||
# default to recreate a new vb
|
||||
table_init()
|
||||
vb = KBServiceFactory.get_service_by_name(vb_name, self.embed_config, self.kb_root_path)
|
||||
if vb:
|
||||
status = vb.clear_vs()
|
||||
|
||||
check_do_init = do_init if do_init else self.do_init
|
||||
if not check_do_init:
|
||||
self.load(self.kb_root_path)
|
||||
self.save_to_vs()
|
||||
|
||||
def append(self, message: Message):
|
||||
self.check_user_name(message.user_name)
|
||||
|
||||
uuid_name = "_".join([self.user_name, self.unique_name, self.memory_type])
|
||||
self.recall_memory_dict[uuid_name].append(message)
|
||||
#
|
||||
if message.role_type == "summary":
|
||||
self.summary_memory_dict[uuid_name].append(message)
|
||||
else:
|
||||
self.current_memory_dict[uuid_name].append(message)
|
||||
|
||||
self.save(self.kb_root_path)
|
||||
self.save_new_to_vs([message])
|
||||
|
||||
# def extend(self, memory: Memory):
|
||||
# self.recall_memory.extend(memory)
|
||||
# self.current_memory.extend(self.recall_memory.filter_by_role_type(["summary"]))
|
||||
# self.summary_memory.extend(self.recall_memory.select_by_role_type(["summary"]))
|
||||
# self.save(self.kb_root_path)
|
||||
# self.save_new_to_vs(memory.messages)
|
||||
|
||||
def save(self, save_dir: str = "./"):
|
||||
file_path = os.path.join(save_dir, f"{self.user_name}/{self.unique_name}/{self.memory_type}/converation.jsonl")
|
||||
uuid_name = "_".join([self.user_name, self.unique_name, self.memory_type])
|
||||
|
||||
memory_messages = self.recall_memory_dict[uuid_name].dict()
|
||||
memory_messages = {k: [
|
||||
{kkk: vvv for kkk, vvv in vv.items() if kkk in self.save_message_keys}
|
||||
for vv in v ]
|
||||
for k, v in memory_messages.items()
|
||||
}
|
||||
#
|
||||
save_to_json_file(memory_messages, file_path)
|
||||
|
||||
def load(self, load_dir: str = None) -> Memory:
|
||||
load_dir = load_dir or self.kb_root_path
|
||||
file_path = os.path.join(load_dir, f"{self.user_name}/{self.unique_name}/{self.memory_type}/converation.jsonl")
|
||||
uuid_name = "_".join([self.user_name, self.unique_name, self.memory_type])
|
||||
|
||||
if os.path.exists(file_path):
|
||||
# self.recall_memory = Memory(**read_json_file(file_path))
|
||||
# self.current_memory = Memory(messages=self.recall_memory.filter_by_role_type(["summary"]))
|
||||
# self.summary_memory = Memory(messages=self.recall_memory.select_by_role_type(["summary"]))
|
||||
|
||||
recall_memory = Memory(**read_json_file(file_path))
|
||||
self.recall_memory_dict[uuid_name] = recall_memory
|
||||
self.current_memory_dict[uuid_name] = Memory(messages=recall_memory.filter_by_role_type(["summary"]))
|
||||
self.summary_memory_dict[uuid_name] = Memory(messages=recall_memory.select_by_role_type(["summary"]))
|
||||
else:
|
||||
self.recall_memory_dict[uuid_name] = Memory(messages=[])
|
||||
self.current_memory_dict[uuid_name] = Memory(messages=[])
|
||||
self.summary_memory_dict[uuid_name] = Memory(messages=[])
|
||||
|
||||
def save_new_to_vs(self, messages: List[Message]):
|
||||
if self.embed_config:
|
||||
vb_name = f"{self.user_name}/{self.unique_name}/{self.memory_type}"
|
||||
# default to faiss, todo: add new vstype
|
||||
vb = KBServiceFactory.get_service(vb_name, "faiss", self.embed_config, self.kb_root_path)
|
||||
embeddings = load_embeddings_from_path(self.embed_config.embed_model_path, self.embed_config.model_device, self.embed_config.langchain_embeddings)
|
||||
messages = [
|
||||
{k: v for k, v in m.dict().items() if k in self.save_message_keys}
|
||||
for m in messages]
|
||||
docs = [{"page_content": m["step_content"] or m["role_content"] or m["input_query"] or m["origin_query"], "metadata": m} for m in messages]
|
||||
docs = [Document(**doc) for doc in docs]
|
||||
vb.do_add_doc(docs, embeddings)
|
||||
|
||||
def save_to_vs(self):
|
||||
'''Persist all loaded messages to the vector store; call only after load().'''
|
||||
if self.embed_config:
|
||||
vb_name = f"{self.user_name}/{self.unique_name}/{self.memory_type}"
|
||||
uuid_name = "_".join([self.user_name, self.unique_name, self.memory_type])
|
||||
# default to recreate a new vb
|
||||
vb = KBServiceFactory.get_service_by_name(vb_name, self.embed_config, self.kb_root_path)
|
||||
if vb:
|
||||
status = vb.clear_vs()
|
||||
# create_kb(vb_name, "faiss", embed_model)
|
||||
|
||||
# default to faiss, todo: add new vstype
|
||||
vb = KBServiceFactory.get_service(vb_name, "faiss", self.embed_config, self.kb_root_path)
|
||||
embeddings = load_embeddings_from_path(self.embed_config.embed_model_path, self.embed_config.model_device, self.embed_config.langchain_embeddings)
|
||||
messages = self.recall_memory_dict[uuid_name].dict()
|
||||
messages = [
|
||||
{kkk: vvv for kkk, vvv in vv.items() if kkk in self.save_message_keys}
|
||||
for k, v in messages.items() for vv in v]
|
||||
docs = [{"page_content": m["step_content"] or m["role_content"] or m["input_query"] or m["origin_query"], "metadata": m} for m in messages]
|
||||
docs = [Document(**doc) for doc in docs]
|
||||
vb.do_add_doc(docs, embeddings)
|
||||
|
||||
# def load_from_vs(self, embed_model=EMBEDDING_MODEL) -> Memory:
|
||||
# vb_name = f"{self.user_name}/{self.unique_name}/{self.memory_type}"
|
||||
|
||||
# create_kb(vb_name, "faiss", embed_model)
|
||||
# # default to faiss, todo: add new vstype
|
||||
# vb = KBServiceFactory.get_service(vb_name, "faiss", embed_model)
|
||||
# docs = vb.get_all_documents()
|
||||
# print(docs)
|
||||
|
||||
def get_memory_pool(self, user_name: str, ):
|
||||
self.check_user_name(user_name)
|
||||
uuid_name = "_".join([self.user_name, self.unique_name, self.memory_type])
|
||||
return self.recall_memory_dict[uuid_name]
|
||||
|
||||
def router_retrieval(self, user_name: str = "default", text: str=None, datetime: str = None, n=5, top_k=5, retrieval_type: str = "embedding", **kwargs) -> List[Message]:
|
||||
retrieval_func_dict = {
|
||||
"embedding": self.embedding_retrieval, "text": self.text_retrieval, "datetime": self.datetime_retrieval
|
||||
}
|
||||
|
||||
# make sure a valid retrieval type was provided
|
||||
if retrieval_type not in retrieval_func_dict:
|
||||
raise ValueError(f"Invalid retrieval_type: '{retrieval_type}'. Available types: {list(retrieval_func_dict.keys())}")
|
||||
|
||||
retrieval_func = retrieval_func_dict[retrieval_type]
|
||||
#
|
||||
params = locals()
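# NOTE: locals() also captures retrieval_func and retrieval_func_dict defined above;
# after the pops below they are still passed along and end up in the **kwargs that the
# retrieval methods silently absorb.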
|
||||
params.pop("self")
|
||||
params.pop("retrieval_type")
|
||||
params.update(params.pop('kwargs', {}))
|
||||
#
|
||||
return retrieval_func(**params)
|
||||
|
||||
def embedding_retrieval(self, text: str, top_k=1, score_threshold=1.0, user_name: str = "default", **kwargs) -> List[Message]:
|
||||
if text is None: return []
|
||||
vb_name = f"{user_name}/{self.unique_name}/{self.memory_type}"
|
||||
# logger.debug(f"vb_name={vb_name}")
|
||||
vb = KBServiceFactory.get_service(vb_name, "faiss", self.embed_config, self.kb_root_path)
|
||||
docs = vb.search_docs(text, top_k=top_k, score_threshold=score_threshold)
|
||||
return [Message(**doc.metadata) for doc, score in docs]
|
||||
|
||||
def text_retrieval(self, text: str, user_name: str = "default", **kwargs) -> List[Message]:
|
||||
if text is None: return []
|
||||
uuid_name = "_".join([user_name, self.unique_name, self.memory_type])
|
||||
# logger.debug(f"uuid_name={uuid_name}")
|
||||
return self._text_retrieval_from_cache(self.recall_memory_dict[uuid_name].messages, text, score_threshold=0.3, topK=5, **kwargs)
|
||||
|
||||
def datetime_retrieval(self, datetime: str, text: str = None, n: int = 5, user_name: str = "default", **kwargs) -> List[Message]:
|
||||
if datetime is None: return []
|
||||
uuid_name = "_".join([user_name, self.unique_name, self.memory_type])
|
||||
# logger.debug(f"uuid_name={uuid_name}")
|
||||
return self._datetime_retrieval_from_cache(self.recall_memory_dict[uuid_name].messages, datetime, text, n, **kwargs)
|
||||
|
||||
def _text_retrieval_from_cache(self, messages: List[Message], text: str = None, score_threshold=0.3, topK=5, tag_topK=5, **kwargs) -> List[Message]:
|
||||
keywords = extract_tags(text, topK=tag_topK)
|
||||
|
||||
matched_messages = []
|
||||
for message in messages:
|
||||
message_keywords = extract_tags(message.step_content or message.role_content or message.input_query, topK=tag_topK)
|
||||
# calculate jaccard similarity
|
||||
intersection = Counter(keywords) & Counter(message_keywords)
|
||||
union = Counter(keywords) | Counter(message_keywords)
|
||||
similarity = sum(intersection.values()) / (sum(union.values()) or 1)  # guard against empty keyword sets
|
||||
if similarity >= score_threshold:
|
||||
matched_messages.append((message, similarity))
|
||||
matched_messages = sorted(matched_messages, key=lambda x: x[1], reverse=True)  # most similar first
|
||||
return [m for m, s in matched_messages][:topK]
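# Illustration (not part of the original module): the score above is a plain Jaccard
# similarity over jieba tags. With hypothetical keyword lists
#   query_keywords   = ["faiss", "vector", "retrieval"]
#   message_keywords = ["faiss", "index", "retrieval", "speed"]
# the intersection counts 2 tags and the union counts 5, giving 2 / 5 = 0.4, which
# clears the default score_threshold of 0.3.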
|
||||
|
||||
def _datetime_retrieval_from_cache(self, messages: List[Message], datetime: str, text: str = None, n: int = 5, **kwargs) -> List[Message]:
|
||||
# select message by datetime
|
||||
datetime_before, datetime_after = addMinutesToTime(datetime, n)
|
||||
select_messages = [
|
||||
message for message in messages
|
||||
if datetime_before<=message.datetime<=datetime_after
|
||||
]
|
||||
return self._text_retrieval_from_cache(select_messages, text)
|
||||
|
||||
def recursive_summary(self, messages: List[Message], split_n: int = 20) -> List[Message]:
|
||||
|
||||
if len(messages) == 0:
|
||||
return messages
|
||||
|
||||
newest_messages = messages[-split_n:]
|
||||
summary_messages = messages[:len(messages)-split_n]
|
||||
|
||||
while (len(newest_messages) != 0) and (newest_messages[0].role_type != "user"):
|
||||
message = newest_messages.pop(0)
|
||||
summary_messages.append(message)
|
||||
|
||||
# summary
|
||||
# model = getChatModel(temperature=0.2)
|
||||
model = getChatModelFromConfig(self.llm_config)
|
||||
summary_content = '\n\n'.join([
|
||||
m.role_type + "\n" + "\n".join(([f"*{k}* {v}" for parsed_output in m.parsed_output_list for k, v in parsed_output.items() if k not in ['Action Status']]))
|
||||
for m in summary_messages if m.role_type not in ["summary"]
|
||||
])
|
||||
|
||||
summary_prompt = CONV_SUMMARY_PROMPT_SPEC.format(conversation=summary_content)
|
||||
content = model.predict(summary_prompt)
|
||||
summary_message = Message(
|
||||
role_name="summaryer",
|
||||
role_type="summary",
|
||||
role_content=content,
|
||||
step_content=content,
|
||||
parsed_output_list=[],
|
||||
customed_kargs={}
|
||||
)
|
||||
summary_message.parsed_output_list.append({"summary": content})
|
||||
newest_messages.insert(0, summary_message)
|
||||
return newest_messages
|
||||
|
||||
def check_user_name(self, user_name: str):
|
||||
# logger.debug(f"self.user_name is {self.user_name}")
|
||||
if user_name != self.user_name:
|
||||
self.user_name = user_name
|
||||
self.init_vb()
|
||||
|
||||
uuid_name = "_".join([self.user_name, self.unique_name, self.memory_type])
|
||||
if uuid_name not in self.recall_memory_dict:
|
||||
self.recall_memory_dict[uuid_name] = Memory(messages=[])
|
||||
self.current_memory_dict[uuid_name] = Memory(messages=[])
|
||||
self.summary_memory_dict[uuid_name] = Memory(messages=[])
|
||||
|
||||
# logger.debug(f"self.user_name is {self.user_name}")
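
To round off the module, here is a hedged usage sketch of `LocalMemoryManager`. The class, method, and import names come from the code above, but the `EmbedConfig` / `LLMConfig` keyword arguments and the `Message` fields are illustrative guesses and may not match the real signatures.

```python
from coagent.connector.memory_manager import LocalMemoryManager
from coagent.connector.schema import Message
from coagent.llm_models.llm_config import EmbedConfig, LLMConfig

# The keyword arguments below are placeholders; check the real config classes for their fields.
embed_config = EmbedConfig(embed_model="text2vec-base", embed_model_path="/models/text2vec")
llm_config = LLMConfig(model_name="gpt-3.5-turbo", api_key="sk-...")

memory_manager = LocalMemoryManager(
    embed_config=embed_config,
    llm_config=llm_config,
    user_name="alice",
    unique_name="project_x",
    memory_type="recall",
    do_init=True,  # start from an empty store instead of loading the existing jsonl
)

# append() writes the message to <kb_root>/alice/project_x/recall/converation.jsonl
# and pushes it into the faiss vector store when an embed_config is provided.
msg = Message(user_name="alice", role_name="user", role_type="user",
              role_content="How do I register a new tool?")
memory_manager.append(msg)

# Retrieval is routed by type: "embedding", "text" (jieba keyword Jaccard) or "datetime".
hits = memory_manager.router_retrieval(user_name="alice", text="register a new tool",
                                       retrieval_type="text")
```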
|
|
@ -1,306 +0,0 @@
|
|||
import re, traceback, uuid, copy, json, os
|
||||
from typing import Union
|
||||
from loguru import logger
|
||||
|
||||
from langchain.schema import BaseRetriever
|
||||
|
||||
from coagent.connector.schema import (
|
||||
Memory, Role, Message, ActionStatus, CodeDoc, Doc, LogVerboseEnum
|
||||
)
|
||||
from coagent.retrieval.base_retrieval import IMRertrieval
|
||||
from coagent.connector.memory_manager import BaseMemoryManager
|
||||
from coagent.tools import DDGSTool, DocRetrieval, CodeRetrieval
|
||||
from coagent.sandbox import PyCodeBox, CodeBoxResponse
|
||||
from coagent.llm_models.llm_config import LLMConfig, EmbedConfig
|
||||
from coagent.base_configs.env_config import JUPYTER_WORK_PATH
|
||||
|
||||
from .utils import parse_dict_to_dict, parse_text_to_dict
|
||||
|
||||
|
||||
class MessageUtils:
|
||||
def __init__(
|
||||
self,
|
||||
role: Role = None,
|
||||
sandbox_server: dict = {},
|
||||
jupyter_work_path: str = JUPYTER_WORK_PATH,
|
||||
embed_config: EmbedConfig = None,
|
||||
llm_config: LLMConfig = None,
|
||||
kb_root_path: str = "",
|
||||
doc_retrieval: Union[BaseRetriever, IMRertrieval] = None,
|
||||
code_retrieval: IMRertrieval = None,
|
||||
search_retrieval: IMRertrieval = None,
|
||||
log_verbose: str = "0"
|
||||
) -> None:
|
||||
self.role = role
|
||||
self.sandbox_server = sandbox_server
|
||||
self.jupyter_work_path = jupyter_work_path
|
||||
self.embed_config = embed_config
|
||||
self.llm_config = llm_config
|
||||
self.kb_root_path = kb_root_path
|
||||
self.doc_retrieval = doc_retrieval
|
||||
self.code_retrieval = code_retrieval
|
||||
self.search_retrieval = search_retrieval
|
||||
self.codebox = PyCodeBox(
|
||||
remote_url=self.sandbox_server.get("url", "http://127.0.0.1:5050"),
|
||||
remote_ip=self.sandbox_server.get("host", "http://127.0.0.1"),
|
||||
remote_port=self.sandbox_server.get("port", "5050"),
|
||||
jupyter_work_path=jupyter_work_path,
|
||||
token="mytoken",
|
||||
do_code_exe=True,
|
||||
do_remote=self.sandbox_server.get("do_remote", False),
|
||||
do_check_net=False
|
||||
)
|
||||
self.log_verbose = os.environ.get("log_verbose", "0") or log_verbose
|
||||
|
||||
def inherit_extrainfo(self, input_message: Message, output_message: Message):
|
||||
output_message.user_name = input_message.user_name
|
||||
output_message.db_docs = input_message.db_docs
|
||||
output_message.search_docs = input_message.search_docs
|
||||
output_message.code_docs = input_message.code_docs
|
||||
output_message.figures.update(input_message.figures)
|
||||
output_message.origin_query = input_message.origin_query
|
||||
output_message.code_engine_name = input_message.code_engine_name
|
||||
|
||||
output_message.doc_engine_name = input_message.doc_engine_name
|
||||
output_message.search_engine_name = input_message.search_engine_name
|
||||
output_message.top_k = input_message.top_k
|
||||
output_message.score_threshold = input_message.score_threshold
|
||||
output_message.cb_search_type = input_message.cb_search_type
|
||||
output_message.do_doc_retrieval = input_message.do_doc_retrieval
|
||||
output_message.do_code_retrieval = input_message.do_code_retrieval
|
||||
output_message.do_tool_retrieval = input_message.do_tool_retrieval
|
||||
#
|
||||
output_message.tools = input_message.tools
|
||||
output_message.agents = input_message.agents
|
||||
|
||||
# update customed_kargs, if exist, keep; else add
|
||||
customed_kargs = copy.deepcopy(input_message.customed_kargs)
|
||||
customed_kargs.update(output_message.customed_kargs)
|
||||
output_message.customed_kargs = customed_kargs
|
||||
return output_message
|
||||
|
||||
def inherit_baseparam(self, input_message: Message, output_message: Message):
|
||||
# only update the base parameters
|
||||
output_message.doc_engine_name = input_message.doc_engine_name
|
||||
output_message.search_engine_name = input_message.search_engine_name
|
||||
output_message.top_k = input_message.top_k
|
||||
output_message.score_threshold = input_message.score_threshold
|
||||
output_message.cb_search_type = input_message.cb_search_type
|
||||
output_message.do_doc_retrieval = input_message.do_doc_retrieval
|
||||
output_message.do_code_retrieval = input_message.do_code_retrieval
|
||||
output_message.do_tool_retrieval = input_message.do_tool_retrieval
|
||||
#
|
||||
output_message.tools = input_message.tools
|
||||
output_message.agents = input_message.agents
|
||||
# known bug: identical keys get overwritten here
|
||||
output_message.customed_kargs.update(input_message.customed_kargs)
|
||||
return output_message
|
||||
|
||||
def get_extrainfo_step(self, message: Message, do_search, do_doc_retrieval, do_code_retrieval, do_tool_retrieval) -> Message:
|
||||
''''''
|
||||
if do_search:
|
||||
message = self.get_search_retrieval(message)
|
||||
|
||||
if do_doc_retrieval:
|
||||
message = self.get_doc_retrieval(message)
|
||||
|
||||
if do_code_retrieval:
|
||||
message = self.get_code_retrieval(message)
|
||||
|
||||
if do_tool_retrieval:
|
||||
message = self.get_tool_retrieval(message)
|
||||
|
||||
return message
|
||||
|
||||
def get_search_retrieval(self, message: Message,) -> Message:
|
||||
SEARCH_ENGINES = {"duckduckgo": DDGSTool}
|
||||
search_docs = []
|
||||
for idx, doc in enumerate(SEARCH_ENGINES["duckduckgo"].run(message.role_content, 3)):
|
||||
doc.update({"index": idx})
|
||||
search_docs.append(Doc(**doc))
|
||||
message.search_docs = search_docs
|
||||
return message
|
||||
|
||||
def get_doc_retrieval(self, message: Message) -> Message:
|
||||
query = message.role_content
|
||||
knowledge_basename = message.doc_engine_name
|
||||
top_k = message.top_k
|
||||
score_threshold = message.score_threshold
|
||||
if self.doc_retrieval:
|
||||
if isinstance(self.doc_retrieval, BaseRetriever):
|
||||
docs = self.doc_retrieval.get_relevant_documents(query)
|
||||
else:
|
||||
# docs = self.doc_retrieval.run(query, search_top=message.top_k, score_threshold=message.score_threshold,)
|
||||
docs = self.doc_retrieval.run(query)
|
||||
docs = [
|
||||
{"index": idx, "snippet": doc.page_content, "title": doc.metadata.get("title_prefix", ""), "link": doc.metadata.get("url", "")}
|
||||
for idx, doc in enumerate(docs)
|
||||
]
|
||||
message.db_docs = [Doc(**doc) for doc in docs]
|
||||
else:
|
||||
if knowledge_basename:
|
||||
docs = DocRetrieval.run(query, knowledge_basename, top_k, score_threshold, self.embed_config, self.kb_root_path)
|
||||
message.db_docs = [Doc(**doc) for doc in docs]
|
||||
return message
|
||||
|
||||
def get_code_retrieval(self, message: Message) -> Message:
|
||||
query = message.role_content
|
||||
code_engine_name = message.code_engine_name
|
||||
history_node_list = message.history_node_list
|
||||
|
||||
use_nh = message.use_nh
|
||||
local_graph_path = message.local_graph_path
|
||||
|
||||
if self.code_retrieval:
|
||||
code_docs = self.code_retrieval.run(
|
||||
query, history_node_list=history_node_list, search_type=message.cb_search_type,
|
||||
code_limit=1
|
||||
)
|
||||
else:
|
||||
code_docs = CodeRetrieval.run(code_engine_name, query, code_limit=message.top_k, history_node_list=history_node_list, search_type=message.cb_search_type,
|
||||
llm_config=self.llm_config, embed_config=self.embed_config,
|
||||
use_nh=use_nh, local_graph_path=local_graph_path)
|
||||
|
||||
message.code_docs = [CodeDoc(**doc) for doc in code_docs]
|
||||
|
||||
# related_nodes = [doc.get_related_node() for idx, doc in enumerate(message.code_docs) if idx==0],
|
||||
# history_node_list.extend([node[0] for node in related_nodes])
|
||||
return message
|
||||
|
||||
def get_tool_retrieval(self, message: Message) -> Message:
|
||||
return message
|
||||
|
||||
def step_router(self, message: Message, history: Memory = None, background: Memory = None, memory_manager: BaseMemoryManager=None) -> tuple[Message, ...]:
|
||||
'''dispatch the message to code execution, tool calling or file saving according to action_status'''
|
||||
if LogVerboseEnum.ge(LogVerboseEnum.Log1Level, self.log_verbose):
|
||||
logger.info(f"message.action_status: {message.action_status}")
|
||||
|
||||
observation_message = None
|
||||
if message.action_status == ActionStatus.CODE_EXECUTING:
|
||||
message, observation_message = self.code_step(message)
|
||||
elif message.action_status == ActionStatus.TOOL_USING:
|
||||
message, observation_message = self.tool_step(message)
|
||||
elif message.action_status == ActionStatus.CODING2FILE:
|
||||
self.save_code2file(message, self.jupyter_work_path)
|
||||
elif message.action_status == ActionStatus.CODE_RETRIEVAL:
|
||||
pass
|
||||
elif message.action_status == ActionStatus.CODING:
|
||||
pass
|
||||
|
||||
return message, observation_message
|
||||
|
||||
def code_step(self, message: Message) -> Message:
|
||||
'''execute code'''
|
||||
# logger.debug(f"message.role_content: {message.role_content}, message.code_content: {message.code_content}")
|
||||
code_answer = self.codebox.chat('```python\n{}```'.format(message.code_content))
|
||||
code_prompt = f"The return error after executing the above code is {code_answer.code_exe_response},need to recover.\n" \
|
||||
if code_answer.code_exe_type == "error" else f"The return information after executing the above code is {code_answer.code_exe_response}.\n"
|
||||
|
||||
observation_message = Message(
|
||||
user_name=message.user_name,
|
||||
role_name="observation",
|
||||
role_type="function", #self.role.role_type,
|
||||
role_content="",
|
||||
step_content="",
|
||||
input_query=message.code_content,
|
||||
)
|
||||
uid = str(uuid.uuid1())
|
||||
if code_answer.code_exe_type == "image/png":
|
||||
message.figures[uid] = code_answer.code_exe_response
|
||||
message.code_answer = f"\n**Observation:**: The return figure name is {uid} after executing the above code.\n"
|
||||
message.observation = f"\n**Observation:**: The return figure name is {uid} after executing the above code.\n"
|
||||
message.step_content += f"\n**Observation:**: The return figure name is {uid} after executing the above code.\n"
|
||||
# message.role_content += f"\n**Observation:**:执行上述代码后生成一张图片, 图片名为{uid}\n"
|
||||
observation_message.role_content = f"\n**Observation:**: The return figure name is {uid} after executing the above code.\n"
|
||||
observation_message.parsed_output = {"Observation": f"The return figure name is {uid} after executing the above code.\n"}
|
||||
else:
|
||||
message.code_answer = code_answer.code_exe_response
|
||||
message.observation = code_answer.code_exe_response
|
||||
message.step_content += f"\n**Observation:**: {code_prompt}\n"
|
||||
# message.role_content += f"\n**Observation:**: {code_prompt}\n"
|
||||
observation_message.role_content = f"\n**Observation:**: {code_prompt}\n"
|
||||
observation_message.parsed_output = {"Observation": f"{code_prompt}\n"}
|
||||
|
||||
if LogVerboseEnum.ge(LogVerboseEnum.Log1Level, self.log_verbose):
|
||||
logger.info(f"**Observation:** {message.action_status}, {message.observation}")
|
||||
return message, observation_message
|
||||
|
||||
def tool_step(self, message: Message) -> Message:
|
||||
'''execute tool'''
|
||||
observation_message = Message(
|
||||
user_name=message.user_name,
|
||||
role_name="observation",
|
||||
role_type="function", #self.role.role_type,
|
||||
role_content="\n**Observation:** there is no tool can execute\n",
|
||||
step_content="",
|
||||
input_query=str(message.tool_params),
|
||||
tools=message.tools,
|
||||
)
|
||||
if LogVerboseEnum.ge(LogVerboseEnum.Log1Level, self.log_verbose):
|
||||
logger.info(f"message: {message.action_status}, {message.tool_params}")
|
||||
|
||||
tool_names = [tool.name for tool in message.tools]
|
||||
if message.tool_name not in tool_names:
|
||||
message.tool_answer = "\n**Observation:** there is no tool can execute.\n"
|
||||
message.observation = "\n**Observation:** there is no tool can execute.\n"
|
||||
# message.role_content += f"\n**Observation:**: 不存在可以执行的tool\n"
|
||||
message.step_content += f"\n**Observation:** there is no tool can execute.\n"
|
||||
observation_message.role_content = f"\n**Observation:** there is no tool can execute.\n"
|
||||
observation_message.parsed_output = {"Observation": "there is no tool can execute.\n"}
|
||||
|
||||
# logger.debug(message.tool_params)
|
||||
for tool in message.tools:
|
||||
if tool.name == message.tool_params.get("tool_name", ""):
|
||||
tool_res = tool.func(**message.tool_params.get("tool_params", {}))
|
||||
message.tool_answer = tool_res
|
||||
message.observation = tool_res
|
||||
# message.role_content += f"\n**Observation:**: {tool_res}\n"
|
||||
message.step_content += f"\n**Observation:** {tool_res}.\n"
|
||||
observation_message.role_content = f"\n**Observation:** {tool_res}.\n"
|
||||
observation_message.parsed_output = {"Observation": f"{tool_res}.\n"}
|
||||
break
|
||||
|
||||
if LogVerboseEnum.ge(LogVerboseEnum.Log1Level, self.log_verbose):
|
||||
logger.info(f"**Observation:** {message.action_status}, {message.observation}")
|
||||
return message, observation_message
|
||||
|
||||
def parser(self, message: Message) -> Message:
|
||||
'''parse llm output into dict'''
|
||||
content = message.role_content
|
||||
# parse start
|
||||
parsed_dict = parse_text_to_dict(content)
|
||||
spec_parsed_dict = parse_dict_to_dict(parsed_dict)
|
||||
# select parse value
|
||||
action_value = parsed_dict.get('Action Status')
|
||||
if action_value:
|
||||
action_value = action_value.lower()
|
||||
|
||||
code_content_value = spec_parsed_dict.get('code')
|
||||
if action_value == 'tool_using':
|
||||
tool_params_value = spec_parsed_dict.get('json')
|
||||
else:
|
||||
tool_params_value = None
|
||||
|
||||
# add parse value to message
|
||||
message.action_status = action_value or "default"
|
||||
message.code_content = code_content_value
|
||||
message.tool_params = tool_params_value
|
||||
message.parsed_output = parsed_dict
|
||||
message.spec_parsed_output = spec_parsed_dict
|
||||
return message
|
||||
|
||||
def save_code2file(self, message: Message, project_dir="./"):
|
||||
filename = message.parsed_output.get("SaveFileName")
|
||||
code = message.spec_parsed_output.get("code")
|
||||
|
||||
for k, v in {">": ">", "≥": ">=", "<": "<", "≤": "<="}.items():
|
||||
code = code.replace(k, v)
|
||||
|
||||
file_path = os.path.join(project_dir, filename)
|
||||
|
||||
if not os.path.exists(file_path):
|
||||
os.makedirs(os.path.dirname(file_path), exist_ok=True)
|
||||
|
||||
with open(file_path, "w") as f:
|
||||
f.write(code)
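# Editorial sketch, not part of the original file: how parser() and step_router() above are
# typically chained once an LLM reply is available. The Message fields and the reply format are
# illustrative assumptions; `utils` is assumed to be built as in BasePhase, e.g.
# utils = MessageUtils(None, sandbox_server, jupyter_work_path, embed_config, llm_config,
#                      kb_root_path, doc_retrieval, code_retrieval, search_retrieval, log_verbose)
from coagent.connector.schema import Message

llm_reply = "**Action Status:** code_executing\n**Code:**\n```python\nprint(1 + 1)\n```"
msg = Message(role_name="executor", role_type="assistant", role_content=llm_reply)
msg = utils.parser(msg)                    # fills action_status / code_content / tool_params
msg, observation = utils.step_router(msg)  # CODE_EXECUTING -> run the code in the sandbox
if observation is not None:
    print(observation.role_content)        # "**Observation:** ..." text fed back to the LLM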
@ -1,3 +0,0 @@
|
|||
from .base_phase import BasePhase
|
||||
|
||||
__all__ = ["BasePhase"]
|
|
@ -1,272 +0,0 @@
|
|||
from typing import List, Union, Dict, Tuple
|
||||
import os
|
||||
import json
|
||||
import importlib
|
||||
import copy
|
||||
from loguru import logger
|
||||
|
||||
from langchain.schema import BaseRetriever
|
||||
|
||||
from coagent.connector.agents import BaseAgent
|
||||
from coagent.connector.chains import BaseChain
|
||||
from coagent.connector.schema import (
|
||||
Memory, Task, Message, AgentConfig, ChainConfig, PhaseConfig, LogVerboseEnum,
|
||||
CompletePhaseConfig,
|
||||
load_chain_configs, load_phase_configs, load_role_configs
|
||||
)
|
||||
from coagent.connector.memory_manager import BaseMemoryManager, LocalMemoryManager
|
||||
from coagent.connector.configs import AGETN_CONFIGS, CHAIN_CONFIGS, PHASE_CONFIGS
|
||||
from coagent.connector.message_process import MessageUtils
|
||||
from coagent.llm_models.llm_config import EmbedConfig, LLMConfig
|
||||
from coagent.base_configs.env_config import JUPYTER_WORK_PATH, KB_ROOT_PATH
|
||||
|
||||
|
||||
role_configs = load_role_configs(AGETN_CONFIGS)
|
||||
chain_configs = load_chain_configs(CHAIN_CONFIGS)
|
||||
phase_configs = load_phase_configs(PHASE_CONFIGS)
|
||||
|
||||
|
||||
CUR_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||
|
||||
|
||||
class BasePhase:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
phase_name: str,
|
||||
phase_config: CompletePhaseConfig = None,
|
||||
kb_root_path: str = KB_ROOT_PATH,
|
||||
jupyter_work_path: str = JUPYTER_WORK_PATH,
|
||||
sandbox_server: dict = {},
|
||||
embed_config: EmbedConfig = None,
|
||||
llm_config: LLMConfig = None,
|
||||
task: Task = None,
|
||||
base_phase_config: Union[dict, str] = PHASE_CONFIGS,
|
||||
base_chain_config: Union[dict, str] = CHAIN_CONFIGS,
|
||||
base_role_config: Union[dict, str] = AGETN_CONFIGS,
|
||||
chains: List[BaseChain] = [],
|
||||
doc_retrieval: Union[BaseRetriever] = None,
|
||||
code_retrieval = None,
|
||||
search_retrieval = None,
|
||||
log_verbose: str = "0"
|
||||
) -> None:
|
||||
#
|
||||
self.phase_name = phase_name
|
||||
self.do_summary = False
|
||||
self.do_search = search_retrieval is not None
|
||||
self.do_code_retrieval = code_retrieval is not None
|
||||
self.do_doc_retrieval = doc_retrieval is not None
|
||||
self.do_tool_retrieval = False
|
||||
# memory_pool dont have specific order
|
||||
# self.memory_pool = Memory(messages=[])
|
||||
self.embed_config = embed_config
|
||||
self.llm_config = llm_config
|
||||
self.sandbox_server = sandbox_server
|
||||
self.jupyter_work_path = jupyter_work_path
|
||||
self.kb_root_path = kb_root_path
|
||||
self.log_verbose = max(os.environ.get("log_verbose", "0"), log_verbose)
|
||||
# TODO: pass these retrievers through transparently
|
||||
self.doc_retrieval = doc_retrieval
|
||||
self.code_retrieval = code_retrieval
|
||||
self.search_retrieval = search_retrieval
|
||||
self.message_utils = MessageUtils(None, sandbox_server, jupyter_work_path, embed_config, llm_config, kb_root_path, doc_retrieval, code_retrieval, search_retrieval, log_verbose)
|
||||
self.global_memory = Memory(messages=[])
|
||||
self.phase_memory: List[Memory] = []
|
||||
# according phase name to init the phase contains
|
||||
self.chains: List[BaseChain] = chains if chains else self.init_chains(
|
||||
phase_name,
|
||||
phase_config,
|
||||
task=task,
|
||||
memory=None,
|
||||
base_phase_config = base_phase_config,
|
||||
base_chain_config = base_chain_config,
|
||||
base_role_config = base_role_config,
|
||||
)
|
||||
self.memory_manager: BaseMemoryManager = LocalMemoryManager(
|
||||
unique_name=phase_name, do_init=True, kb_root_path = kb_root_path, embed_config=embed_config, llm_config=llm_config
|
||||
)
|
||||
self.conv_summary_agent = BaseAgent(
|
||||
role=role_configs["conv_summary"].role,
|
||||
prompt_config=role_configs["conv_summary"].prompt_config,
|
||||
task = None, memory = None,
|
||||
llm_config=self.llm_config,
|
||||
embed_config=self.embed_config,
|
||||
sandbox_server=sandbox_server,
|
||||
jupyter_work_path=jupyter_work_path,
|
||||
kb_root_path=kb_root_path
|
||||
)
|
||||
|
||||
def astep(self, query: Message, history: Memory = None, reinit_memory=False) -> Tuple[Message, Memory]:
|
||||
if reinit_memory:
|
||||
self.memory_manager.re_init(reinit_memory)
|
||||
self.memory_manager.append(query)
|
||||
summary_message = None
|
||||
chain_message = Memory(messages=[])
|
||||
local_phase_memory = Memory(messages=[])
|
||||
# do_search / do_doc_retrieval / do_code_retrieval / do_tool_retrieval
|
||||
query = self.message_utils.get_extrainfo_step(query, self.do_search, self.do_doc_retrieval, self.do_code_retrieval, self.do_tool_retrieval)
|
||||
query.parsed_output = query.parsed_output if query.parsed_output else {"origin_query": query.input_query}
|
||||
query.parsed_output_list = query.parsed_output_list if query.parsed_output_list else [{"origin_query": query.input_query}]
|
||||
input_message = copy.deepcopy(query)
|
||||
|
||||
self.global_memory.append(input_message)
|
||||
local_phase_memory.append(input_message)
|
||||
for chain in self.chains:
|
||||
# chain can supply background and query to next chain
|
||||
for output_message, local_chain_memory in chain.astep(input_message, history, background=chain_message, memory_manager=self.memory_manager):
|
||||
# logger.debug(f"local_memory: {local_phase_memory + local_chain_memory}")
|
||||
yield output_message, local_phase_memory + local_chain_memory
|
||||
|
||||
output_message = self.message_utils.inherit_extrainfo(input_message, output_message)
|
||||
input_message = output_message
|
||||
# logger.info(f"{chain.chainConfig.chain_name} phase_step: {output_message.role_content}")
|
||||
# this block is also problematic
|
||||
self.global_memory.extend(local_chain_memory)
|
||||
local_phase_memory.extend(local_chain_memory)
|
||||
|
||||
# whether to use summary_llm
|
||||
if self.do_summary:
|
||||
if LogVerboseEnum.ge(LogVerboseEnum.Log1Level, self.log_verbose):
|
||||
logger.info(f"{self.conv_summary_agent.role.role_name} input global memory: {local_phase_memory.to_str_messages(content_key='step_content')}")
|
||||
for summary_message in self.conv_summary_agent.astep(query, background=local_phase_memory, memory_manager=self.memory_manager):
|
||||
pass
|
||||
# summary_message = Message(**summary_message)
|
||||
summary_message.role_name = chain.chainConfig.chain_name
|
||||
summary_message = self.conv_summary_agent.message_utils.parser(summary_message)
|
||||
summary_message = self.message_utils.inherit_extrainfo(output_message, summary_message)
|
||||
chain_message.append(summary_message)
|
||||
|
||||
message = summary_message or output_message
|
||||
yield message, local_phase_memory
|
||||
|
||||
# since chains are not executed for multiple rounds, simply keeping the memory is enough
|
||||
for chain in self.chains:
|
||||
self.phase_memory.append(chain.global_memory)
|
||||
# TODO: local_memory is missing the step that appends the summary
|
||||
message = summary_message or output_message
|
||||
message.role_name = self.phase_name
|
||||
yield message, local_phase_memory
|
||||
|
||||
def step(self, query: Message, history: Memory = None, reinit_memory=False) -> Tuple[Message, Memory]:
|
||||
for message, local_phase_memory in self.astep(query, history=history, reinit_memory=reinit_memory):
|
||||
pass
|
||||
return message, local_phase_memory
|
||||
|
||||
def pre_print(self, query, history: Memory = None) -> List[str]:
|
||||
chain_message = Memory(messages=[])
|
||||
for chain in self.chains:
|
||||
chain.pre_print(query, history, background=chain_message, memory_manager=self.memory_manager)
|
||||
|
||||
def init_chains(self, phase_name: str, phase_config: CompletePhaseConfig, base_phase_config, base_chain_config,
|
||||
base_role_config, task=None, memory=None) -> List[BaseChain]:
|
||||
# load config
|
||||
role_configs = load_role_configs(base_role_config)
|
||||
chain_configs = load_chain_configs(base_chain_config)
|
||||
phase_configs = load_phase_configs(base_phase_config)
|
||||
|
||||
chains = []
|
||||
self.chain_module = importlib.import_module("coagent.connector.chains")
|
||||
self.agent_module = importlib.import_module("coagent.connector.agents")
|
||||
|
||||
phase: PhaseConfig = phase_configs.get(phase_name)
|
||||
# set phase
|
||||
self.do_summary = phase.do_summary
|
||||
self.do_search = phase.do_search
|
||||
self.do_code_retrieval = phase.do_code_retrieval
|
||||
self.do_doc_retrieval = phase.do_doc_retrieval
|
||||
self.do_tool_retrieval = phase.do_tool_retrieval
|
||||
logger.info(f"start to init the phase, the phase_name is {phase_name}, it contains these chains such as {phase.chains}")
|
||||
|
||||
for chain_name in phase.chains:
|
||||
# logger.debug(f"{chain_configs.keys()}")
|
||||
chain_config: ChainConfig = chain_configs[chain_name]
|
||||
logger.info(f"start to init the chain, the chain_name is {chain_name}, it contains these agents such as {chain_config.agents}")
|
||||
|
||||
agents = []
|
||||
for agent_name in chain_config.agents:
|
||||
agent_config: AgentConfig = role_configs[agent_name]
|
||||
llm_config = copy.deepcopy(self.llm_config)
|
||||
llm_config.stop = agent_config.stop
|
||||
baseAgent: BaseAgent = getattr(self.agent_module, agent_config.role.agent_type)
|
||||
base_agent = baseAgent(
|
||||
role=agent_config.role,
|
||||
prompt_config = agent_config.prompt_config,
|
||||
prompt_manager_type=agent_config.prompt_manager_type,
|
||||
task = task,
|
||||
memory = memory,
|
||||
chat_turn=agent_config.chat_turn,
|
||||
focus_agents=agent_config.focus_agents,
|
||||
focus_message_keys=agent_config.focus_message_keys,
|
||||
llm_config=llm_config,
|
||||
embed_config=self.embed_config,
|
||||
sandbox_server=self.sandbox_server,
|
||||
jupyter_work_path=self.jupyter_work_path,
|
||||
kb_root_path=self.kb_root_path,
|
||||
doc_retrieval=self.doc_retrieval,
|
||||
code_retrieval=self.code_retrieval,
|
||||
search_retrieval=self.search_retrieval,
|
||||
log_verbose=self.log_verbose
|
||||
)
|
||||
if agent_config.role.agent_type == "SelectorAgent":
|
||||
for group_agent_name in agent_config.group_agents:
|
||||
group_agent_config = role_configs[group_agent_name]
|
||||
llm_config = copy.deepcopy(self.llm_config)
|
||||
llm_config.stop = group_agent_config.stop
|
||||
baseAgent: BaseAgent = getattr(self.agent_module, group_agent_config.role.agent_type)
|
||||
group_base_agent = baseAgent(
|
||||
role=group_agent_config.role,
|
||||
prompt_config = group_agent_config.prompt_config,
|
||||
prompt_manager_type=group_agent_config.prompt_manager_type,
|
||||
task = task,
|
||||
memory = memory,
|
||||
chat_turn=group_agent_config.chat_turn,
|
||||
focus_agents=group_agent_config.focus_agents,
|
||||
focus_message_keys=group_agent_config.focus_message_keys,
|
||||
llm_config=llm_config,
|
||||
embed_config=self.embed_config,
|
||||
sandbox_server=self.sandbox_server,
|
||||
jupyter_work_path=self.jupyter_work_path,
|
||||
kb_root_path=self.kb_root_path,
|
||||
doc_retrieval=self.doc_retrieval,
|
||||
code_retrieval=self.code_retrieval,
|
||||
search_retrieval=self.search_retrieval,
|
||||
log_verbose=self.log_verbose
|
||||
)
|
||||
base_agent.group_agents.append(group_base_agent)
|
||||
|
||||
agents.append(base_agent)
|
||||
|
||||
chain_instance = BaseChain(
|
||||
chain_config,
|
||||
agents,
|
||||
jupyter_work_path=self.jupyter_work_path,
|
||||
sandbox_server=self.sandbox_server,
|
||||
embed_config=self.embed_config,
|
||||
llm_config=self.llm_config,
|
||||
kb_root_path=self.kb_root_path,
|
||||
doc_retrieval=self.doc_retrieval,
|
||||
code_retrieval=self.code_retrieval,
|
||||
search_retrieval=self.search_retrieval,
|
||||
log_verbose=self.log_verbose
|
||||
)
|
||||
chains.append(chain_instance)
|
||||
|
||||
return chains
|
||||
|
||||
def update(self) -> Memory:
|
||||
pass
|
||||
|
||||
def get_memory(self, ) -> Memory:
|
||||
return Memory.from_memory_list(
|
||||
[chain.get_memory() for chain in self.chains]
|
||||
)
|
||||
|
||||
def get_memory_str(self, do_all_memory=True, content_key="role_content") -> str:
|
||||
memory = self.global_memory if do_all_memory else self.phase_memory
|
||||
return "\n".join([": ".join(i) for i in memory.to_tuple_messages(content_key=content_key)])
|
||||
|
||||
def get_chains_memory(self, content_key="role_content") -> List[Tuple]:
|
||||
return [memory.to_tuple_messages(content_key=content_key) for memory in self.phase_memory]
|
||||
|
||||
def get_chains_memory_str(self, content_key="role_content") -> str:
|
||||
return "************".join([f"{chain.chainConfig.chain_name}\n" + chain.get_memory_str(content_key=content_key) for chain in self.chains])
@ -1,350 +0,0 @@
|
|||
from coagent.connector.schema import Memory, Message
|
||||
import random
|
||||
from textwrap import dedent
|
||||
import copy
|
||||
from loguru import logger
|
||||
|
||||
from coagent.connector.utils import extract_section, parse_section
|
||||
|
||||
|
||||
class PromptManager:
|
||||
def __init__(self, role_prompt="", prompt_config=None, monitored_agents=[], monitored_fields=[]):
|
||||
self.role_prompt = role_prompt
|
||||
self.monitored_agents = monitored_agents
|
||||
self.monitored_fields = monitored_fields
|
||||
self.field_handlers = {}
|
||||
self.context_handlers = {}
|
||||
self.field_order = [] # order of the ordinary (non-context) fields
self.context_order = [] # order of the context fields, maintained separately
|
||||
self.field_descriptions = {}
|
||||
self.omit_if_empty_flags = {}
|
||||
self.context_title = "### Context Data\n\n"
|
||||
|
||||
self.prompt_config = prompt_config
|
||||
if self.prompt_config:
|
||||
self.register_fields_from_config()
|
||||
|
||||
def register_field(self, field_name, function=None, title=None, description=None, is_context=True, omit_if_empty=True):
|
||||
"""
|
||||
注册一个新的字段及其处理函数。
|
||||
Args:
|
||||
field_name (str): 字段名称。
|
||||
function (callable): 处理字段数据的函数。
|
||||
title (str, optional): 字段的自定义标题(可选)。
|
||||
description (str, optional): 字段的描述(可选,可以是几句话)。
|
||||
is_context (bool, optional): 指示该字段是否为上下文字段。
|
||||
omit_if_empty (bool, optional): 如果数据为空,是否省略该字段。
|
||||
"""
|
||||
if not function:
|
||||
function = self.handle_custom_data
|
||||
|
||||
# Register the handler function based on context flag
|
||||
if is_context:
|
||||
self.context_handlers[field_name] = function
|
||||
else:
|
||||
self.field_handlers[field_name] = function
|
||||
|
||||
# Store the custom title if provided and adjust the title prefix based on context
|
||||
title_prefix = "####" if is_context else "###"
|
||||
if title is not None:
|
||||
self.field_descriptions[field_name] = f"{title_prefix} {title}\n\n"
|
||||
elif description is not None:
|
||||
# If title is not provided but description is, use description as title
|
||||
self.field_descriptions[field_name] = f"{title_prefix} {field_name.replace('_', ' ').title()}\n\n{description}\n\n"
|
||||
else:
|
||||
# If neither title nor description is provided, use the field name as title
|
||||
self.field_descriptions[field_name] = f"{title_prefix} {field_name.replace('_', ' ').title()}\n\n"
|
||||
|
||||
# Store the omit_if_empty flag for this field
|
||||
self.omit_if_empty_flags[field_name] = omit_if_empty
|
||||
|
||||
if is_context and field_name != 'context_placeholder':
|
||||
self.context_handlers[field_name] = function
|
||||
self.context_order.append(field_name)
|
||||
else:
|
||||
self.field_handlers[field_name] = function
|
||||
self.field_order.append(field_name)
|
||||
|
||||
def generate_full_prompt(self, **kwargs):
|
||||
full_prompt = []
|
||||
context_prompts = [] # collects the context content
is_pre_print = kwargs.get("is_pre_print", False) # force-print every prompt field, even when empty

# handle the context fields first
|
||||
for field_name in self.context_order:
|
||||
handler = self.context_handlers[field_name]
|
||||
processed_prompt = handler(**kwargs)
|
||||
# Check if the field should be omitted when empty
|
||||
if self.omit_if_empty_flags.get(field_name, False) and not processed_prompt and not is_pre_print:
|
||||
continue # Skip this field
|
||||
title_or_description = self.field_descriptions.get(field_name, f"#### {field_name.replace('_', ' ').title()}\n\n")
|
||||
context_prompts.append(title_or_description + processed_prompt + '\n\n')
|
||||
|
||||
# handle the ordinary fields, locating the context_placeholder position along the way
|
||||
for field_name in self.field_order:
|
||||
if field_name == 'context_placeholder':
|
||||
# insert the context data at the position of context_placeholder
full_prompt.append(self.context_title) # add the heading of the context section
full_prompt.extend(context_prompts) # add the collected context content
|
||||
else:
|
||||
handler = self.field_handlers[field_name]
|
||||
processed_prompt = handler(**kwargs)
|
||||
# Check if the field should be omitted when empty
|
||||
if self.omit_if_empty_flags.get(field_name, False) and not processed_prompt and not is_pre_print:
|
||||
continue # Skip this field
|
||||
title_or_description = self.field_descriptions.get(field_name, f"### {field_name.replace('_', ' ').title()}\n\n")
|
||||
full_prompt.append(title_or_description + processed_prompt + '\n\n')
|
||||
|
||||
# return the full prompt, stripping trailing blank lines
|
||||
return ''.join(full_prompt).rstrip('\n')
|
||||
|
||||
def pre_print(self, **kwargs):
|
||||
kwargs.update({"is_pre_print": True})
|
||||
prompt = self.generate_full_prompt(**kwargs)
|
||||
|
||||
input_keys = parse_section(self.role_prompt, 'Response Output Format')
|
||||
llm_predict = "\n".join([f"**{k}:**" for k in input_keys])
|
||||
return prompt + "\n\n" + "#"*19 + "\n<<<<LLM PREDICT>>>>\n" + "#"*19 + f"\n\n{llm_predict}\n"
|
||||
|
||||
def handle_custom_data(self, **kwargs):
|
||||
return ""
|
||||
|
||||
def handle_tool_data(self, **kwargs):
|
||||
if 'previous_agent_message' not in kwargs:
|
||||
return ""
|
||||
|
||||
previous_agent_message = kwargs.get('previous_agent_message')
|
||||
tools = previous_agent_message.tools
|
||||
|
||||
if not tools:
|
||||
return ""
|
||||
|
||||
tool_strings = []
|
||||
for tool in tools:
|
||||
args_schema = str(tool.args)
|
||||
tool_strings.append(f"{tool.name}: {tool.description}, args: {args_schema}")
|
||||
formatted_tools = "\n".join(tool_strings)
|
||||
|
||||
tool_names = ", ".join([tool.name for tool in tools])
|
||||
|
||||
tool_prompt = dedent(f"""
|
||||
Below is a list of tools that are available for your use:
|
||||
{formatted_tools}
|
||||
|
||||
valid "tool_name" value is:
|
||||
{tool_names}
|
||||
""")
|
||||
|
||||
return tool_prompt
|
||||
|
||||
def handle_agent_data(self, **kwargs):
|
||||
if 'agents' not in kwargs:
|
||||
return ""
|
||||
|
||||
agents = kwargs.get('agents')
|
||||
random.shuffle(agents)
|
||||
agent_names = ", ".join([f'{agent.role.role_name}' for agent in agents])
|
||||
agent_descs = []
|
||||
for agent in agents:
|
||||
role_desc = agent.role.role_prompt.split("####")[1]
|
||||
while "\n\n" in role_desc:
|
||||
role_desc = role_desc.replace("\n\n", "\n")
|
||||
role_desc = role_desc.replace("\n", ",")
|
||||
|
||||
agent_descs.append(f'"role name: {agent.role.role_name}\nrole description: {role_desc}"')
|
||||
|
||||
agents = "\n".join(agent_descs)
|
||||
agent_prompt = f'''
|
||||
Please ensure your selection is one of the listed roles. Available roles for selection:
|
||||
{agents}
|
||||
Please ensure select the Role from agent names, such as {agent_names}'''
|
||||
|
||||
return dedent(agent_prompt)
|
||||
|
||||
def handle_doc_info(self, **kwargs) -> str:
|
||||
if 'previous_agent_message' not in kwargs:
|
||||
return ""
|
||||
previous_agent_message: Message = kwargs.get('previous_agent_message')
|
||||
db_docs = previous_agent_message.db_docs
|
||||
search_docs = previous_agent_message.search_docs
|
||||
code_cocs = previous_agent_message.code_docs
|
||||
doc_infos = "\n".join([doc.get_snippet() for doc in db_docs] + [doc.get_snippet() for doc in search_docs] +
|
||||
[doc.get_code() for doc in code_cocs])
|
||||
return doc_infos
|
||||
|
||||
def handle_session_records(self, **kwargs) -> str:
|
||||
|
||||
memory_pool: Memory = kwargs.get('memory_pool', Memory(messages=[]))
|
||||
memory_pool = self.select_memory_by_agent_name(memory_pool)
|
||||
memory_pool = self.select_memory_by_parsed_key(memory_pool)
|
||||
|
||||
return memory_pool.to_str_messages(content_key="parsed_output_list", with_tag=True)
|
||||
|
||||
def handle_current_plan(self, **kwargs) -> str:
|
||||
if 'previous_agent_message' not in kwargs:
|
||||
return ""
|
||||
previous_agent_message = kwargs['previous_agent_message']
|
||||
return previous_agent_message.parsed_output.get("CURRENT_STEP", "")
|
||||
|
||||
def handle_agent_profile(self, **kwargs) -> str:
|
||||
return extract_section(self.role_prompt, 'Agent Profile')
|
||||
|
||||
def handle_output_format(self, **kwargs) -> str:
|
||||
return extract_section(self.role_prompt, 'Response Output Format')
|
||||
|
||||
def handle_response(self, **kwargs) -> str:
|
||||
if 'react_memory' not in kwargs:
|
||||
return ""
|
||||
|
||||
react_memory = kwargs.get('react_memory', Memory(messages=[]))
|
||||
if react_memory is None:
|
||||
return ""
|
||||
|
||||
return "\n".join(["\n".join([f"**{k}:**\n{v}" for k,v in _dict.items()]) for _dict in react_memory.get_parserd_output()])
|
||||
|
||||
def handle_task_records(self, **kwargs) -> str:
|
||||
if 'task_memory' not in kwargs:
|
||||
return ""
|
||||
|
||||
task_memory: Memory = kwargs.get('task_memory', Memory(messages=[]))
|
||||
if task_memory is None:
|
||||
return ""
|
||||
|
||||
return "\n".join(["\n".join([f"**{k}:**\n{v}" for k,v in _dict.items() if k not in ["CURRENT_STEP"]]) for _dict in task_memory.get_parserd_output()])
|
||||
|
||||
def handle_previous_message(self, message: Message) -> str:
|
||||
pass
|
||||
|
||||
def handle_message_by_role_name(self, message: Message) -> str:
|
||||
pass
|
||||
|
||||
def handle_message_by_role_type(self, message: Message) -> str:
|
||||
pass
|
||||
|
||||
def handle_current_agent_react_message(self, message: Message) -> str:
|
||||
pass
|
||||
|
||||
def extract_codedoc_info_for_prompt(self, message: Message) -> str:
|
||||
code_docs = message.code_docs
|
||||
doc_infos = "\n".join([doc.get_code() for doc in code_docs])
|
||||
return doc_infos
|
||||
|
||||
def select_memory_by_parsed_key(self, memory: Memory) -> Memory:
|
||||
return Memory(
|
||||
messages=[self.select_message_by_parsed_key(message) for message in memory.messages
|
||||
if self.select_message_by_parsed_key(message) is not None]
|
||||
)
|
||||
|
||||
def select_memory_by_agent_name(self, memory: Memory) -> Memory:
|
||||
return Memory(
|
||||
messages=[self.select_message_by_agent_name(message) for message in memory.messages
|
||||
if self.select_message_by_agent_name(message) is not None]
|
||||
)
|
||||
|
||||
def select_message_by_agent_name(self, message: Message) -> Message:
|
||||
# assume we focus all agents
|
||||
if self.monitored_agents == []:
|
||||
return message
|
||||
return None if message is None or message.role_name not in self.monitored_agents else self.select_message_by_parsed_key(message)
|
||||
|
||||
def select_message_by_parsed_key(self, message: Message) -> Message:
|
||||
# assume we focus all key contents
|
||||
if message is None:
|
||||
return message
|
||||
|
||||
if self.monitored_fields == []:
|
||||
return message
|
||||
|
||||
message_c = copy.deepcopy(message)
|
||||
message_c.parsed_output = {k: v for k,v in message_c.parsed_output.items() if k in self.monitored_fields}
|
||||
message_c.parsed_output_list = [{k: v for k,v in parsed_output.items() if k in self.monitored_fields} for parsed_output in message_c.parsed_output_list]
|
||||
return message_c
|
||||
|
||||
def get_memory(self, content_key="role_content"):
|
||||
return self.memory.to_tuple_messages(content_key="step_content")
|
||||
|
||||
def get_memory_str(self, content_key="role_content"):
|
||||
return "\n".join([": ".join(i) for i in self.memory.to_tuple_messages(content_key="step_content")])
|
||||
|
||||
def register_fields_from_config(self):
|
||||
|
||||
for prompt_field in self.prompt_config:
|
||||
|
||||
function_name = prompt_field.function_name
|
||||
# check whether function_name is a method on self
|
||||
if function_name and hasattr(self, function_name):
|
||||
function = getattr(self, function_name)
|
||||
else:
|
||||
function = self.handle_custom_data
|
||||
|
||||
self.register_field(prompt_field.field_name,
|
||||
function=function,
|
||||
title=prompt_field.title,
|
||||
description=prompt_field.description,
|
||||
is_context=prompt_field.is_context,
|
||||
omit_if_empty=prompt_field.omit_if_empty)
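# Editorial sketch, not part of the original file: a prompt_config consumed by
# register_fields_from_config above. Each entry is a PromptField (declared in general_schema.py
# later in this commit); assuming PromptField is importable from the schema package.
from coagent.connector.schema import PromptField

example_prompt_config = [
    PromptField(field_name="agent_profile", function_name="handle_agent_profile", is_context=False),
    PromptField(field_name="context_placeholder", function_name="", is_context=True),
    PromptField(field_name="session_records", function_name="handle_session_records", is_context=True),
    PromptField(field_name="output_format", function_name="handle_output_format",
                title="Response Output Format", is_context=False),
    PromptField(field_name="response", function_name="handle_response",
                is_context=False, omit_if_empty=False),
]
manager = PromptManager(role_prompt="#### Agent Profile\n...", prompt_config=example_prompt_config)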
|
||||
|
||||
def register_standard_fields(self):
|
||||
self.register_field('agent_profile', function=self.handle_agent_profile, is_context=False)
|
||||
self.register_field('tool_information', function=self.handle_tool_data, is_context=False)
|
||||
self.register_field('context_placeholder', is_context=True) # marks where the context-data section goes
|
||||
self.register_field('reference_documents', function=self.handle_doc_info, is_context=True)
|
||||
self.register_field('session_records', function=self.handle_session_records, is_context=True)
|
||||
self.register_field('output_format', function=self.handle_output_format, title='Response Output Format', is_context=False)
|
||||
self.register_field('response', function=self.handle_response, is_context=False, omit_if_empty=False)
|
||||
|
||||
def register_executor_fields(self):
|
||||
self.register_field('agent_profile', function=self.handle_agent_profile, is_context=False)
|
||||
self.register_field('tool_information', function=self.handle_tool_data, is_context=False)
|
||||
self.register_field('context_placeholder', is_context=True) # marks where the context-data section goes
|
||||
self.register_field('reference_documents', function=self.handle_doc_info, is_context=True)
|
||||
self.register_field('session_records', function=self.handle_session_records, is_context=True)
|
||||
self.register_field('current_plan', function=self.handle_current_plan, is_context=True)
|
||||
self.register_field('output_format', function=self.handle_output_format, title='Response Output Format', is_context=False)
|
||||
self.register_field('response', function=self.handle_response, is_context=False, omit_if_empty=False)
|
||||
|
||||
def register_fields_from_dict(self, fields_dict):
|
||||
# register fields from a dict
|
||||
for field_name, field_config in fields_dict.items():
|
||||
function_name = field_config.get('function', None)
|
||||
title = field_config.get('title', None)
|
||||
description = field_config.get('description', None)
|
||||
is_context = field_config.get('is_context', True)
|
||||
omit_if_empty = field_config.get('omit_if_empty', True)
|
||||
|
||||
# check whether function_name is a method on self
|
||||
if function_name and hasattr(self, function_name):
|
||||
function = getattr(self, function_name)
|
||||
else:
|
||||
function = self.handle_custom_data
|
||||
|
||||
# delegate to the existing register_field method
|
||||
self.register_field(field_name, function=function, title=title, description=description, is_context=is_context, omit_if_empty=omit_if_empty)
|
||||
|
||||
|
||||
|
||||
def main():
|
||||
manager = PromptManager()
|
||||
manager.register_standard_fields()
|
||||
|
||||
manager.register_field('agents_work_progress', title=f"Agents' Work Progress", is_context=True)
|
||||
|
||||
# build the data dict
|
||||
data_dict = {
|
||||
"agent_profile": "这是代理配置文件...",
|
||||
# "tool_list": "这是工具列表...",
|
||||
"reference_documents": "这是参考文档...",
|
||||
"session_records": "这是会话记录...",
|
||||
"agents_work_progress": "这是代理工作进展...",
|
||||
"output_format": "这是预期的输出格式...",
|
||||
# "response": "这是生成或继续回应的指令...",
|
||||
"response": "",
|
||||
"test": 'xxxxx'
|
||||
}
|
||||
|
||||
# assemble the full prompt
|
||||
full_prompt = manager.generate_full_prompt(**data_dict)  # generate_full_prompt only accepts keyword arguments
|
||||
print(full_prompt)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
@ -1,2 +0,0 @@
|
|||
from .prompt_manager import PromptManager
|
||||
from .extend_manager import *
|
|
@ -1,45 +0,0 @@
|
|||
|
||||
from coagent.connector.schema import Message
|
||||
from .prompt_manager import PromptManager
|
||||
|
||||
|
||||
class Code2DocPM(PromptManager):
|
||||
def handle_code_snippet(self, **kwargs) -> str:
|
||||
if 'previous_agent_message' not in kwargs:
|
||||
return ""
|
||||
previous_agent_message: Message = kwargs['previous_agent_message']
|
||||
code_snippet = previous_agent_message.customed_kargs.get("Code Snippet", "")
|
||||
current_vertex = previous_agent_message.customed_kargs.get("Current_Vertex", "")
|
||||
instruction = "A segment of code that contains the function or method to be documented.\n"
|
||||
return instruction + "\n" + f"name: {current_vertex}\n{code_snippet}"
|
||||
|
||||
def handle_specific_objective(self, **kwargs) -> str:
|
||||
if 'previous_agent_message' not in kwargs:
|
||||
return ""
|
||||
previous_agent_message: Message = kwargs['previous_agent_message']
|
||||
specific_objective = previous_agent_message.parsed_output.get("Code Path")
|
||||
|
||||
instruction = "Provide the code path of the function or method you wish to document.\n"
|
||||
s = instruction + f"\n{specific_objective}"
|
||||
return s
|
||||
|
||||
|
||||
class CodeRetrievalPM(PromptManager):
|
||||
def handle_code_snippet(self, **kwargs) -> str:
|
||||
if 'previous_agent_message' not in kwargs:
|
||||
return ""
|
||||
previous_agent_message: Message = kwargs['previous_agent_message']
|
||||
code_snippet = previous_agent_message.customed_kargs.get("Code Snippet", "")
|
||||
current_vertex = previous_agent_message.customed_kargs.get("Current_Vertex", "")
|
||||
instruction = "the initial Code or objective that the user wanted to achieve"
|
||||
return instruction + "\n" + f"name: {current_vertex}\n{code_snippet}"
|
||||
|
||||
def handle_retrieval_codes(self, **kwargs) -> str:
|
||||
if 'previous_agent_message' not in kwargs:
|
||||
return ""
|
||||
previous_agent_message: Message = kwargs['previous_agent_message']
|
||||
Retrieval_Codes = previous_agent_message.customed_kargs["Retrieval_Codes"]
|
||||
Relative_vertex = previous_agent_message.customed_kargs["Relative_vertex"]
|
||||
instruction = "the initial Code or objective that the user wanted to achieve"
|
||||
s = instruction + "\n" + "\n".join([f"name: {vertext}\n{code}" for vertext, code in zip(Relative_vertex, Retrieval_Codes)])
|
||||
return s
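# Editorial sketch, not part of the original file: extending PromptManager as Code2DocPM does
# above. Handlers are looked up by function_name from the prompt_config, so a subclass only needs
# to define the matching handle_* method. The field and key names here are hypothetical.
class MyRetrievalPM(PromptManager):
    def handle_my_context(self, **kwargs) -> str:
        previous_agent_message = kwargs.get('previous_agent_message')
        if previous_agent_message is None:
            return ""
        return previous_agent_message.customed_kargs.get("My_Context", "")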
@ -1,353 +0,0 @@
|
|||
import random
|
||||
from textwrap import dedent
|
||||
import copy
|
||||
from loguru import logger
|
||||
|
||||
from langchain.agents.tools import Tool
|
||||
|
||||
from coagent.connector.schema import Memory, Message
|
||||
from coagent.connector.utils import extract_section, parse_section
|
||||
|
||||
|
||||
|
||||
class PromptManager:
|
||||
def __init__(self, role_prompt="", prompt_config=None, monitored_agents=[], monitored_fields=[]):
|
||||
self.role_prompt = role_prompt
|
||||
self.monitored_agents = monitored_agents
|
||||
self.monitored_fields = monitored_fields
|
||||
self.field_handlers = {}
|
||||
self.context_handlers = {}
|
||||
self.field_order = [] # order of the ordinary (non-context) fields
self.context_order = [] # order of the context fields, maintained separately
|
||||
self.field_descriptions = {}
|
||||
self.omit_if_empty_flags = {}
|
||||
self.context_title = "### Context Data\n\n"
|
||||
|
||||
self.prompt_config = prompt_config
|
||||
if self.prompt_config:
|
||||
self.register_fields_from_config()
|
||||
|
||||
def register_field(self, field_name, function=None, title=None, description=None, is_context=True, omit_if_empty=True):
|
||||
"""
|
||||
注册一个新的字段及其处理函数。
|
||||
Args:
|
||||
field_name (str): 字段名称。
|
||||
function (callable): 处理字段数据的函数。
|
||||
title (str, optional): 字段的自定义标题(可选)。
|
||||
description (str, optional): 字段的描述(可选,可以是几句话)。
|
||||
is_context (bool, optional): 指示该字段是否为上下文字段。
|
||||
omit_if_empty (bool, optional): 如果数据为空,是否省略该字段。
|
||||
"""
|
||||
if not function:
|
||||
function = self.handle_custom_data
|
||||
|
||||
# Register the handler function based on context flag
|
||||
if is_context:
|
||||
self.context_handlers[field_name] = function
|
||||
else:
|
||||
self.field_handlers[field_name] = function
|
||||
|
||||
# Store the custom title if provided and adjust the title prefix based on context
|
||||
title_prefix = "####" if is_context else "###"
|
||||
if title is not None:
|
||||
self.field_descriptions[field_name] = f"{title_prefix} {title}\n\n"
|
||||
elif description is not None:
|
||||
# If title is not provided but description is, use description as title
|
||||
self.field_descriptions[field_name] = f"{title_prefix} {field_name.replace('_', ' ').title()}\n\n{description}\n\n"
|
||||
else:
|
||||
# If neither title nor description is provided, use the field name as title
|
||||
self.field_descriptions[field_name] = f"{title_prefix} {field_name.replace('_', ' ').title()}\n\n"
|
||||
|
||||
# Store the omit_if_empty flag for this field
|
||||
self.omit_if_empty_flags[field_name] = omit_if_empty
|
||||
|
||||
if is_context and field_name != 'context_placeholder':
|
||||
self.context_handlers[field_name] = function
|
||||
self.context_order.append(field_name)
|
||||
else:
|
||||
self.field_handlers[field_name] = function
|
||||
self.field_order.append(field_name)
|
||||
|
||||
def generate_full_prompt(self, **kwargs):
|
||||
full_prompt = []
|
||||
context_prompts = [] # collects the context content
is_pre_print = kwargs.get("is_pre_print", False) # force-print every prompt field, even when empty

# handle the context fields first
|
||||
for field_name in self.context_order:
|
||||
handler = self.context_handlers[field_name]
|
||||
processed_prompt = handler(**kwargs)
|
||||
# Check if the field should be omitted when empty
|
||||
if self.omit_if_empty_flags.get(field_name, False) and not processed_prompt and not is_pre_print:
|
||||
continue # Skip this field
|
||||
title_or_description = self.field_descriptions.get(field_name, f"#### {field_name.replace('_', ' ').title()}\n\n")
|
||||
context_prompts.append(title_or_description + processed_prompt + '\n\n')
|
||||
|
||||
# handle the ordinary fields, locating the context_placeholder position along the way
|
||||
for field_name in self.field_order:
|
||||
if field_name == 'context_placeholder':
|
||||
# insert the context data at the position of context_placeholder
full_prompt.append(self.context_title) # add the heading of the context section
full_prompt.extend(context_prompts) # add the collected context content
|
||||
else:
|
||||
handler = self.field_handlers[field_name]
|
||||
processed_prompt = handler(**kwargs)
|
||||
# Check if the field should be omitted when empty
|
||||
if self.omit_if_empty_flags.get(field_name, False) and not processed_prompt and not is_pre_print:
|
||||
continue # Skip this field
|
||||
title_or_description = self.field_descriptions.get(field_name, f"### {field_name.replace('_', ' ').title()}\n\n")
|
||||
full_prompt.append(title_or_description + processed_prompt + '\n\n')
|
||||
|
||||
# return the full prompt, stripping trailing blank lines
|
||||
return ''.join(full_prompt).rstrip('\n')
|
||||
|
||||
def pre_print(self, **kwargs):
|
||||
kwargs.update({"is_pre_print": True})
|
||||
prompt = self.generate_full_prompt(**kwargs)
|
||||
|
||||
input_keys = parse_section(self.role_prompt, 'Response Output Format')
|
||||
llm_predict = "\n".join([f"**{k}:**" for k in input_keys])
|
||||
return prompt + "\n\n" + "#"*19 + "\n<<<<LLM PREDICT>>>>\n" + "#"*19 + f"\n\n{llm_predict}\n"
|
||||
|
||||
def handle_custom_data(self, **kwargs):
|
||||
return ""
|
||||
|
||||
def handle_tool_data(self, **kwargs):
|
||||
if 'previous_agent_message' not in kwargs:
|
||||
return ""
|
||||
|
||||
previous_agent_message = kwargs.get('previous_agent_message')
|
||||
tools: list[Tool] = previous_agent_message.tools
|
||||
|
||||
if not tools:
|
||||
return ""
|
||||
|
||||
tool_strings = []
|
||||
for tool in tools:
|
||||
args_str = f'args: {str(tool.args)}' if tool.args_schema else ""
|
||||
tool_strings.append(f"{tool.name}: {tool.description}, {args_str}")
|
||||
formatted_tools = "\n".join(tool_strings)
|
||||
|
||||
tool_names = ", ".join([tool.name for tool in tools])
|
||||
|
||||
tool_prompt = dedent(f"""
|
||||
Below is a list of tools that are available for your use:
|
||||
{formatted_tools}
|
||||
|
||||
valid "tool_name" value is:
|
||||
{tool_names}
|
||||
""")
|
||||
|
||||
return tool_prompt
|
||||
|
||||
def handle_agent_data(self, **kwargs):
|
||||
if 'agents' not in kwargs:
|
||||
return ""
|
||||
|
||||
agents = kwargs.get('agents')
|
||||
random.shuffle(agents)
|
||||
agent_names = ", ".join([f'{agent.role.role_name}' for agent in agents])
|
||||
agent_descs = []
|
||||
for agent in agents:
|
||||
role_desc = agent.role.role_prompt.split("####")[1]
|
||||
while "\n\n" in role_desc:
|
||||
role_desc = role_desc.replace("\n\n", "\n")
|
||||
role_desc = role_desc.replace("\n", ",")
|
||||
|
||||
agent_descs.append(f'"role name: {agent.role.role_name}\nrole description: {role_desc}"')
|
||||
|
||||
agents = "\n".join(agent_descs)
|
||||
agent_prompt = f'''
|
||||
Please ensure your selection is one of the listed roles. Available roles for selection:
|
||||
{agents}
|
||||
Please ensure select the Role from agent names, such as {agent_names}'''
|
||||
|
||||
return dedent(agent_prompt)
|
||||
|
||||
def handle_doc_info(self, **kwargs) -> str:
|
||||
if 'previous_agent_message' not in kwargs:
|
||||
return ""
|
||||
previous_agent_message: Message = kwargs.get('previous_agent_message')
|
||||
db_docs = previous_agent_message.db_docs
|
||||
search_docs = previous_agent_message.search_docs
|
||||
code_cocs = previous_agent_message.code_docs
|
||||
doc_infos = "\n".join([doc.get_snippet() for doc in db_docs] + [doc.get_snippet() for doc in search_docs] +
|
||||
[doc.get_code() for doc in code_cocs])
|
||||
return doc_infos
|
||||
|
||||
def handle_session_records(self, **kwargs) -> str:
|
||||
|
||||
memory_pool: Memory = kwargs.get('memory_pool', Memory(messages=[]))
|
||||
memory_pool = self.select_memory_by_agent_name(memory_pool)
|
||||
memory_pool = self.select_memory_by_parsed_key(memory_pool)
|
||||
|
||||
return memory_pool.to_str_messages(content_key="parsed_output_list", with_tag=True)
|
||||
|
||||
def handle_current_plan(self, **kwargs) -> str:
|
||||
if 'previous_agent_message' not in kwargs:
|
||||
return ""
|
||||
previous_agent_message = kwargs['previous_agent_message']
|
||||
return previous_agent_message.parsed_output.get("CURRENT_STEP", "")
|
||||
|
||||
def handle_agent_profile(self, **kwargs) -> str:
|
||||
return extract_section(self.role_prompt, 'Agent Profile')
|
||||
|
||||
def handle_output_format(self, **kwargs) -> str:
|
||||
return extract_section(self.role_prompt, 'Response Output Format')
|
||||
|
||||
def handle_response(self, **kwargs) -> str:
|
||||
if 'react_memory' not in kwargs:
|
||||
return ""
|
||||
|
||||
react_memory = kwargs.get('react_memory', Memory(messages=[]))
|
||||
if react_memory is None:
|
||||
return ""
|
||||
|
||||
return "\n".join(["\n".join([f"**{k}:**\n{v}" for k,v in _dict.items()]) for _dict in react_memory.get_parserd_output()])
|
||||
|
||||
def handle_task_records(self, **kwargs) -> str:
|
||||
if 'task_memory' not in kwargs:
|
||||
return ""
|
||||
|
||||
task_memory: Memory = kwargs.get('task_memory', Memory(messages=[]))
|
||||
if task_memory is None:
|
||||
return ""
|
||||
|
||||
return "\n".join(["\n".join([f"**{k}:**\n{v}" for k,v in _dict.items() if k not in ["CURRENT_STEP"]]) for _dict in task_memory.get_parserd_output()])
|
||||
|
||||
def handle_previous_message(self, message: Message) -> str:
|
||||
pass
|
||||
|
||||
def handle_message_by_role_name(self, message: Message) -> str:
|
||||
pass
|
||||
|
||||
def handle_message_by_role_type(self, message: Message) -> str:
|
||||
pass
|
||||
|
||||
def handle_current_agent_react_message(self, message: Message) -> str:
|
||||
pass
|
||||
|
||||
def extract_codedoc_info_for_prompt(self, message: Message) -> str:
|
||||
code_docs = message.code_docs
|
||||
doc_infos = "\n".join([doc.get_code() for doc in code_docs])
|
||||
return doc_infos
|
||||
|
||||
def select_memory_by_parsed_key(self, memory: Memory) -> Memory:
|
||||
return Memory(
|
||||
messages=[self.select_message_by_parsed_key(message) for message in memory.messages
|
||||
if self.select_message_by_parsed_key(message) is not None]
|
||||
)
|
||||
|
||||
def select_memory_by_agent_name(self, memory: Memory) -> Memory:
|
||||
return Memory(
|
||||
messages=[self.select_message_by_agent_name(message) for message in memory.messages
|
||||
if self.select_message_by_agent_name(message) is not None]
|
||||
)
|
||||
|
||||
def select_message_by_agent_name(self, message: Message) -> Message:
|
||||
# assume we focus all agents
|
||||
if self.monitored_agents == []:
|
||||
return message
|
||||
return None if message is None or message.role_name not in self.monitored_agents else self.select_message_by_parsed_key(message)
|
||||
|
||||
def select_message_by_parsed_key(self, message: Message) -> Message:
|
||||
# assume we focus all key contents
|
||||
if message is None:
|
||||
return message
|
||||
|
||||
if self.monitored_fields == []:
|
||||
return message
|
||||
|
||||
message_c = copy.deepcopy(message)
|
||||
message_c.parsed_output = {k: v for k,v in message_c.parsed_output.items() if k in self.monitored_fields}
|
||||
message_c.parsed_output_list = [{k: v for k,v in parsed_output.items() if k in self.monitored_fields} for parsed_output in message_c.parsed_output_list]
|
||||
return message_c
|
||||
|
||||
def get_memory(self, content_key="role_content"):
|
||||
return self.memory.to_tuple_messages(content_key="step_content")
|
||||
|
||||
def get_memory_str(self, content_key="role_content"):
|
||||
return "\n".join([": ".join(i) for i in self.memory.to_tuple_messages(content_key="step_content")])
|
||||
|
||||
def register_fields_from_config(self):
|
||||
|
||||
for prompt_field in self.prompt_config:
|
||||
|
||||
function_name = prompt_field.function_name
|
||||
# check whether function_name is a method on self
|
||||
if function_name and hasattr(self, function_name):
|
||||
function = getattr(self, function_name)
|
||||
else:
|
||||
function = self.handle_custom_data
|
||||
|
||||
self.register_field(prompt_field.field_name,
|
||||
function=function,
|
||||
title=prompt_field.title,
|
||||
description=prompt_field.description,
|
||||
is_context=prompt_field.is_context,
|
||||
omit_if_empty=prompt_field.omit_if_empty)
|
||||
|
||||
def register_standard_fields(self):
|
||||
self.register_field('agent_profile', function=self.handle_agent_profile, is_context=False)
|
||||
self.register_field('tool_information', function=self.handle_tool_data, is_context=False)
|
||||
self.register_field('context_placeholder', is_context=True) # marks where the context-data section goes
|
||||
self.register_field('reference_documents', function=self.handle_doc_info, is_context=True)
|
||||
self.register_field('session_records', function=self.handle_session_records, is_context=True)
|
||||
self.register_field('output_format', function=self.handle_output_format, title='Response Output Format', is_context=False)
|
||||
self.register_field('response', function=self.handle_response, is_context=False, omit_if_empty=False)
|
||||
|
||||
def register_executor_fields(self):
|
||||
self.register_field('agent_profile', function=self.handle_agent_profile, is_context=False)
|
||||
self.register_field('tool_information', function=self.handle_tool_data, is_context=False)
|
||||
self.register_field('context_placeholder', is_context=True) # marks where the context-data section goes
|
||||
self.register_field('reference_documents', function=self.handle_doc_info, is_context=True)
|
||||
self.register_field('session_records', function=self.handle_session_records, is_context=True)
|
||||
self.register_field('current_plan', function=self.handle_current_plan, is_context=True)
|
||||
self.register_field('output_format', function=self.handle_output_format, title='Response Output Format', is_context=False)
|
||||
self.register_field('response', function=self.handle_response, is_context=False, omit_if_empty=False)
|
||||
|
||||
def register_fields_from_dict(self, fields_dict):
|
||||
# register fields from a dict
|
||||
for field_name, field_config in fields_dict.items():
|
||||
function_name = field_config.get('function', None)
|
||||
title = field_config.get('title', None)
|
||||
description = field_config.get('description', None)
|
||||
is_context = field_config.get('is_context', True)
|
||||
omit_if_empty = field_config.get('omit_if_empty', True)
|
||||
|
||||
# check whether function_name is a method on self
|
||||
if function_name and hasattr(self, function_name):
|
||||
function = getattr(self, function_name)
|
||||
else:
|
||||
function = self.handle_custom_data
|
||||
|
||||
# delegate to the existing register_field method
|
||||
self.register_field(field_name, function=function, title=title, description=description, is_context=is_context, omit_if_empty=omit_if_empty)
|
||||
|
||||
|
||||
|
||||
def main():
|
||||
manager = PromptManager()
|
||||
manager.register_standard_fields()
|
||||
|
||||
manager.register_field('agents_work_progress', title=f"Agents' Work Progress", is_context=True)
|
||||
|
||||
# build the data dict
|
||||
data_dict = {
|
||||
"agent_profile": "这是代理配置文件...",
|
||||
# "tool_list": "这是工具列表...",
|
||||
"reference_documents": "这是参考文档...",
|
||||
"session_records": "这是会话记录...",
|
||||
"agents_work_progress": "这是代理工作进展...",
|
||||
"output_format": "这是预期的输出格式...",
|
||||
# "response": "这是生成或继续回应的指令...",
|
||||
"response": "",
|
||||
"test": 'xxxxx'
|
||||
}
|
||||
|
||||
# assemble the full prompt
|
||||
full_prompt = manager.generate_full_prompt(**data_dict)  # generate_full_prompt only accepts keyword arguments
|
||||
print(full_prompt)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
@ -1,9 +0,0 @@
|
|||
from .memory import Memory
|
||||
from .general_schema import *
|
||||
from .message import Message
|
||||
|
||||
__all__ = [
|
||||
"Memory", "ActionStatus", "Doc", "CodeDoc", "Task", "LogVerboseEnum",
|
||||
"Env", "Role", "ChainConfig", "AgentConfig", "PhaseConfig", "Message",
|
||||
"load_role_configs", "load_chain_configs", "load_phase_configs"
|
||||
]
|
|
@ -1,309 +0,0 @@
|
|||
from pydantic import BaseModel
|
||||
from typing import List, Dict, Optional, Union
|
||||
from enum import Enum
|
||||
import re
|
||||
import json
|
||||
from loguru import logger
|
||||
from langchain.tools import BaseTool
|
||||
|
||||
|
||||
class ActionStatus(Enum):
|
||||
DEFAUILT = "default"
|
||||
|
||||
FINISHED = "finished"
|
||||
STOPPED = "stopped"
|
||||
CONTINUED = "continued"
|
||||
|
||||
TOOL_USING = "tool_using"
|
||||
CODING = "coding"
|
||||
CODE_EXECUTING = "code_executing"
|
||||
CODING2FILE = "coding2file"
|
||||
|
||||
PLANNING = "planning"
|
||||
UNCHANGED = "unchanged"
|
||||
ADJUSTED = "adjusted"
|
||||
CODE_RETRIEVAL = "code_retrieval"
|
||||
|
||||
def __eq__(self, other):
|
||||
if isinstance(other, str):
|
||||
return self.value.lower() == other.lower()
|
||||
return super().__eq__(other)
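# Editorial note, not part of the original file: the __eq__ override lets the string parsed from
# the LLM's "Action Status" field be compared against the enum directly and case-insensitively.
assert ActionStatus.CODE_EXECUTING == "code_executing"
assert ActionStatus.TOOL_USING == "Tool_Using"   # both sides are lower-cased before comparing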
|
||||
|
||||
|
||||
class Action(BaseModel):
|
||||
action_name: str
|
||||
description: str
|
||||
|
||||
class FinishedAction(Action):
|
||||
action_name: str = ActionStatus.FINISHED
|
||||
description: str = "provide the final answer to the original query to break the chain answer"
|
||||
|
||||
class StoppedAction(Action):
|
||||
action_name: str = ActionStatus.STOPPED
|
||||
description: str = "provide the final answer to the original query to break the agent answer"
|
||||
|
||||
class ContinuedAction(Action):
|
||||
action_name: str = ActionStatus.CONTINUED
|
||||
description: str = "cant't provide the final answer to the original query"
|
||||
|
||||
class ToolUsingAction(Action):
|
||||
action_name: str = ActionStatus.TOOL_USING
|
||||
description: str = "proceed with using the specified tool."
|
||||
|
||||
class CodingdAction(Action):
|
||||
action_name: str = ActionStatus.CODING
|
||||
description: str = "provide the answer by writing code"
|
||||
|
||||
class Coding2FileAction(Action):
|
||||
action_name: str = ActionStatus.CODING2FILE
|
||||
description: str = "provide the answer by writing code and filename"
|
||||
|
||||
class CodeExecutingAction(Action):
|
||||
action_name: str = ActionStatus.CODE_EXECUTING
|
||||
description: str = "provide the answer by writing executable code"
|
||||
|
||||
class PlanningAction(Action):
|
||||
action_name: str = ActionStatus.PLANNING
|
||||
description: str = "provide a sequence of tasks"
|
||||
|
||||
class UnchangedAction(Action):
|
||||
action_name: str = ActionStatus.UNCHANGED
|
||||
description: str = "this PLAN has no problem, just set PLAN_STEP to CURRENT_STEP+1."
|
||||
|
||||
class AdjustedAction(Action):
|
||||
action_name: str = ActionStatus.ADJUSTED
|
||||
description: str = "the PLAN is to provide an optimized version of the original plan."
|
||||
|
||||
# extended action example
|
||||
class CodeRetrievalAction(Action):
|
||||
action_name: str = ActionStatus.CODE_RETRIEVAL
|
||||
description: str = "execute the code retrieval to acquire more code information"
|
||||
|
||||
|
||||
class RoleTypeEnums(Enum):
|
||||
SYSTEM = "system"
|
||||
USER = "user"
|
||||
ASSISTANT = "assistant"
|
||||
FUNCTION = "function"
|
||||
OBSERVATION = "observation"
|
||||
SUMMARY = "summary"
|
||||
|
||||
def __eq__(self, other):
|
||||
if isinstance(other, str):
|
||||
return self.value == other
|
||||
return super().__eq__(other)
|
||||
|
||||
|
||||
class PromptKey(BaseModel):
|
||||
key_name: str
|
||||
description: str
|
||||
|
||||
class PromptKeyEnums(Enum):
|
||||
# Origin Query is the user question coming from the UI
|
||||
ORIGIN_QUERY = "origin_query"
|
||||
# agent's input from last agent
|
||||
CURRENT_QUESTION = "current_question"
|
||||
# UI memory contains (user and assistant messages)
|
||||
UI_MEMORY = "ui_memory"
|
||||
# agent's memory
|
||||
SELF_MEMORY = "self_memory"
|
||||
# chain memory
|
||||
CHAIN_MEMORY = "chain_memory"
|
||||
# agent's memory
|
||||
SELF_LOCAL_MEMORY = "self_local_memory"
|
||||
# chain memory
|
||||
CHAIN_LOCAL_MEMORY = "chain_local_memory"
|
||||
# Doc Information contains (Doc / Code / Search)
|
||||
DOC_INFOS = "doc_infos"
|
||||
|
||||
def __eq__(self, other):
|
||||
if isinstance(other, str):
|
||||
return self.value == other
|
||||
return super().__eq__(other)
|
||||
|
||||
|
||||
class Doc(BaseModel):
|
||||
title: str
|
||||
snippet: str
|
||||
link: str
|
||||
index: int
|
||||
|
||||
def get_title(self):
|
||||
return self.title
|
||||
|
||||
def get_snippet(self, ):
|
||||
return self.snippet
|
||||
|
||||
def get_link(self, ):
|
||||
return self.link
|
||||
|
||||
def get_index(self, ):
|
||||
return self.index
|
||||
|
||||
def to_json(self):
|
||||
return vars(self)
|
||||
|
||||
def __str__(self,):
|
||||
return f"""出处 [{self.index + 1}] 标题 [{self.title}]\n\n来源 ({self.link}) \n\n内容 {self.snippet}\n\n"""
|
||||
|
||||
|
||||
class CodeDoc(BaseModel):
|
||||
code: str
|
||||
related_nodes: list
|
||||
index: int
|
||||
|
||||
def get_code(self, ):
|
||||
return self.code
|
||||
|
||||
def get_related_node(self, ):
|
||||
return self.related_nodes
|
||||
|
||||
def get_index(self, ):
|
||||
return self.index
|
||||
|
||||
def to_json(self):
|
||||
return vars(self)
|
||||
|
||||
def __str__(self,):
|
||||
return f"""出处 [{self.index + 1}] \n\n来源 ({self.related_nodes}) \n\n内容 {self.code}\n\n"""
|
||||
|
||||
|
||||
class LogVerboseEnum(Enum):
|
||||
Log0Level = "0" # don't print log
|
||||
Log1Level = "1" # print level-1 log
|
||||
Log2Level = "2" # print level-2 log
|
||||
Log3Level = "3" # print level-3 log
|
||||
|
||||
def __eq__(self, other):
|
||||
if isinstance(other, str):
|
||||
return self.value.lower() == other.lower()
|
||||
if isinstance(other, LogVerboseEnum):
|
||||
return self.value == other.value
|
||||
return False
|
||||
|
||||
def __ge__(self, other):
|
||||
if isinstance(other, LogVerboseEnum):
|
||||
return int(self.value) >= int(other.value)
|
||||
if isinstance(other, str):
|
||||
return int(self.value) >= int(other)
|
||||
return NotImplemented
|
||||
|
||||
def __le__(self, other):
|
||||
if isinstance(other, LogVerboseEnum):
|
||||
return int(self.value) <= int(other.value)
|
||||
if isinstance(other, str):
|
||||
return int(self.value) <= int(other)
|
||||
return NotImplemented
|
||||
|
||||
@classmethod
|
||||
def ge(self, enum_value: 'LogVerboseEnum', other: Union[str, 'LogVerboseEnum']):
|
||||
return enum_value <= other
|
||||
|
||||
|
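`LogVerboseEnum` compares by the integer meaning of its string values, so a verbosity string taken from configuration can gate log statements directly. A minimal sketch, assuming the hypothetical import path below:

```python
# Hypothetical import path; adjust to where general_schema lives in your checkout.
from coagent.connector.schema.general_schema import LogVerboseEnum

log_verbose = "2"  # e.g. read from an environment variable or config file

# __le__ converts both sides to int, so enum-vs-string comparisons work.
if LogVerboseEnum.Log1Level <= log_verbose:
    print("level-1 log: phase started")

# The ge() helper expresses the same check: is the configured level at least Log2Level?
if LogVerboseEnum.ge(LogVerboseEnum.Log2Level, log_verbose):
    print("level-2 log: detailed message payload")
```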
||||
class Task(BaseModel):
|
||||
task_type: str
|
||||
task_name: str
|
||||
task_desc: str
|
||||
task_prompt: str
|
||||
|
||||
class Env(BaseModel):
|
||||
env_type: str
|
||||
env_name: str
|
||||
env_desc:str
|
||||
|
||||
|
||||
class Role(BaseModel):
|
||||
role_type: str
|
||||
role_name: str
|
||||
role_desc: str = ""
|
||||
agent_type: str = "BaseAgent"
|
||||
role_prompt: str = ""
|
||||
template_prompt: str = ""
|
||||
|
||||
|
||||
class ChainConfig(BaseModel):
|
||||
chain_name: str
|
||||
chain_type: str = "BaseChain"
|
||||
agents: List[str]
|
||||
do_checker: bool = False
|
||||
chat_turn: int = 1
|
||||
|
||||
|
||||
class PromptField(BaseModel):
|
||||
field_name: str  # assumed to be a function type; change as needed
|
||||
function_name: str
|
||||
title: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
is_context: Optional[bool] = True
|
||||
omit_if_empty: Optional[bool] = True
|
||||
|
||||
|
||||
class AgentConfig(BaseModel):
|
||||
role: Role
|
||||
prompt_config: List[PromptField]
|
||||
prompt_manager_type: str = "PromptManager"
|
||||
chat_turn: int = 1
|
||||
focus_agents: List = []
|
||||
focus_message_keys: List = []
|
||||
group_agents: List = []
|
||||
stop: str = ""
|
||||
|
||||
|
||||
class PhaseConfig(BaseModel):
|
||||
phase_name: str
|
||||
phase_type: str
|
||||
chains: List[str]
|
||||
do_summary: bool = False
|
||||
do_search: bool = False
|
||||
do_doc_retrieval: bool = False
|
||||
do_code_retrieval: bool = False
|
||||
do_tool_retrieval: bool = False
|
||||
|
||||
|
||||
class CompleteChainConfig(BaseModel):
|
||||
chain_name: str
|
||||
chain_type: str
|
||||
agents: Dict[str, AgentConfig]
|
||||
do_checker: bool = False
|
||||
chat_turn: int = 1
|
||||
|
||||
|
||||
class CompletePhaseConfig(BaseModel):
|
||||
phase_name: str
|
||||
phase_type: str
|
||||
chains: Dict[str, CompleteChainConfig]
|
||||
do_summary: bool = False
|
||||
do_search: bool = False
|
||||
do_doc_retrieval: bool = False
|
||||
do_code_retrieval: bool = False
|
||||
do_tool_retrieval: bool = False
|
||||
|
||||
|
||||
def load_role_configs(config) -> Dict[str, AgentConfig]:
|
||||
if isinstance(config, str):
|
||||
with open(config, 'r', encoding="utf8") as file:
|
||||
configs = json.load(file)
|
||||
else:
|
||||
configs = config
|
||||
# logger.debug(configs)
|
||||
return {name: AgentConfig(**v) for name, v in configs.items()}
|
||||
|
||||
|
||||
def load_chain_configs(config) -> Dict[str, ChainConfig]:
|
||||
if isinstance(config, str):
|
||||
with open(config, 'r', encoding="utf8") as file:
|
||||
configs = json.load(file)
|
||||
else:
|
||||
configs = config
|
||||
return {name: ChainConfig(**v) for name, v in configs.items()}
|
||||
|
||||
|
||||
def load_phase_configs(config) -> Dict[str, PhaseConfig]:
|
||||
if isinstance(config, str):
|
||||
with open(config, 'r', encoding="utf8") as file:
|
||||
configs = json.load(file)
|
||||
else:
|
||||
configs = config
|
||||
return {name: PhaseConfig(**v) for name, v in configs.items()}
|
||||
|
||||
# AgentConfig.update_forward_refs()
|
|
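The `load_role_configs` / `load_chain_configs` / `load_phase_configs` helpers accept either a path to a JSON file or an already-parsed dict and validate each entry into the models above. A minimal sketch of building one agent config in code; the import path and all names are illustrative:

```python
# Hypothetical import path; adjust to where general_schema lives in your checkout.
from coagent.connector.schema.general_schema import (
    Role, PromptField, AgentConfig, load_role_configs,
)

role_configs_dict = {
    "searcher": {   # illustrative agent name
        "role": {
            "role_type": "assistant",
            "role_name": "searcher",
            "role_desc": "answers questions with retrieved documents",
            "agent_type": "BaseAgent",
        },
        "prompt_config": [
            {"field_name": "query", "function_name": "handle_origin_query"},
        ],
        "chat_turn": 1,
    }
}

# Accepts a dict directly, or a path to a JSON file with the same shape.
role_configs = load_role_configs(role_configs_dict)
assert isinstance(role_configs["searcher"], AgentConfig)
```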
@ -1,161 +0,0 @@
|
|||
from pydantic import BaseModel
|
||||
from typing import List, Union, Dict
|
||||
from loguru import logger
|
||||
|
||||
from .message import Message
|
||||
from coagent.utils.common_utils import (
|
||||
save_to_jsonl_file, save_to_json_file, read_json_file, read_jsonl_file
|
||||
)
|
||||
|
||||
|
||||
class Memory(BaseModel):
|
||||
messages: List[Message] = []
|
||||
|
||||
# def __init__(self, messages: List[Message] = []):
|
||||
# self.messages = messages
|
||||
|
||||
def append(self, message: Message):
|
||||
self.messages.append(message)
|
||||
|
||||
def extend(self, memory: 'Memory'):
|
||||
self.messages.extend(memory.messages)
|
||||
|
||||
def update(self, role_name: str, role_type: str, role_content: str):
|
||||
self.messages.append(Message(role_name=role_name, role_type=role_type, role_content=role_content, input_query=role_content))
|
||||
|
||||
def clear(self, ):
|
||||
self.messages = []
|
||||
|
||||
def delete(self, ):
|
||||
pass
|
||||
|
||||
def get_messages(self, k=0) -> List[Message]:
|
||||
"""Return the most recent k memories, return all when k=0"""
|
||||
return self.messages[-k:]
|
||||
|
||||
def split_by_role_type(self) -> List[Dict[str, 'Memory']]:
|
||||
"""
|
||||
Split messages into rounds of conversation based on role_type.
|
||||
Each round consists of consecutive messages of the same role_type.
|
||||
User messages form a single round, while assistant and function messages are combined into a single round.
|
||||
Each round is represented by a dict with 'role' and 'memory' keys, with assistant and function messages
|
||||
labeled as 'assistant'.
|
||||
"""
|
||||
rounds = []
|
||||
current_memory = Memory()
|
||||
current_role = None
|
||||
|
||||
logger.debug(len(self.messages))
|
||||
|
||||
for msg in self.messages:
|
||||
# Determine the message's role, considering 'function' as 'assistant'
|
||||
message_role = 'assistant' if msg.role_type in ['assistant', 'function'] else 'user'
|
||||
|
||||
# If the current memory is empty or the current message is of the same role_type as current_role, add to current memory
|
||||
if not current_memory.messages or current_role == message_role:
|
||||
current_memory.append(msg)
|
||||
else:
|
||||
# Finish the current memory and start a new one
|
||||
rounds.append({'role': current_role, 'memory': current_memory})
|
||||
current_memory = Memory()
|
||||
current_memory.append(msg)
|
||||
|
||||
# Update the current_role, considering 'function' as 'assistant'
|
||||
current_role = message_role
|
||||
|
||||
# Don't forget to add the last memory if it exists
|
||||
if current_memory.messages:
|
||||
rounds.append({'role': current_role, 'memory': current_memory})
|
||||
|
||||
logger.debug(rounds)
|
||||
|
||||
return rounds
|
||||
|
||||
def format_rounds_to_html(self) -> str:
|
||||
formatted_html_str = ""
|
||||
rounds = self.split_by_role_type()
|
||||
|
||||
for round in rounds:
|
||||
role = round['role']
|
||||
memory = round['memory']
|
||||
|
||||
# convert the current round's Memory into a string
|
||||
messages_str = memory.to_str_messages()
|
||||
|
||||
# wrap the content in an HTML-style tag matching the role type
|
||||
if role == 'user':
|
||||
formatted_html_str += f"<user-message>\n{messages_str}\n</user-message>\n"
|
||||
else:  # 'assistant' and 'function' roles are both treated as 'assistant'
|
||||
formatted_html_str += f"<assistant-message>\n{messages_str}\n</assistant-message>\n"
|
||||
|
||||
return formatted_html_str
|
||||
|
||||
|
||||
def filter_by_role_type(self, role_types: List[str]) -> List[Message]:
|
||||
# Keep messages whose role_type is NOT in the given role_types
|
||||
return [message for message in self.messages if message.role_type not in role_types]
|
||||
|
||||
def select_by_role_type(self, role_types: List[str]) -> List[Message]:
|
||||
# Select messages based on role types
|
||||
return [message for message in self.messages if message.role_type in role_types]
|
||||
|
||||
def to_tuple_messages(self, return_all: bool = True, content_key="role_content", filter_roles=[]):
|
||||
# Convert messages to tuples based on parameters
|
||||
# logger.debug(f"{[message.to_tuple_message(return_all, content_key) for message in self.messages ]}")
|
||||
return [
|
||||
message.to_tuple_message(return_all, content_key) for message in self.messages
|
||||
if message.role_name not in filter_roles
|
||||
]
|
||||
|
||||
def to_dict_messages(self, filter_roles=[]):
|
||||
# Convert messages to dictionaries based on filter roles
|
||||
return [
|
||||
message.to_dict_message() for message in self.messages
|
||||
if message.role_name not in filter_roles
|
||||
]
|
||||
|
||||
def to_str_messages(self, return_all: bool = True, content_key="role_content", filter_roles=[], with_tag=False):
|
||||
# Convert messages to strings based on parameters
|
||||
# for message in self.messages:
|
||||
# logger.debug(f"{message.role_name}: {message.to_str_content(return_all, content_key, with_tag=with_tag)}")
|
||||
# logger.debug(f"{[message.to_tuple_message(return_all, content_key) for message in self.messages ]}")
|
||||
return "\n\n".join([message.to_str_content(return_all, content_key, with_tag=with_tag) for message in self.messages
|
||||
if message.role_name not in filter_roles
|
||||
])
|
||||
|
||||
def get_parserd_output(self, ):
|
||||
return [message.parsed_output for message in self.messages]
|
||||
|
||||
def get_parserd_output_list(self, ):
|
||||
# for message in self.messages:
|
||||
# logger.debug(f"{message.role_name}: {message.parsed_output_list}")
|
||||
# return [parsed_output for message in self.messages for parsed_output in message.parsed_output_list[1:]]
|
||||
return [parsed_output for message in self.messages for parsed_output in message.parsed_output_list]
|
||||
|
||||
def get_spec_parserd_output(self, ):
|
||||
return [message.spec_parsed_output for message in self.messages]
|
||||
|
||||
def get_rolenames(self, ):
|
||||
''''''
|
||||
return [message.role_name for message in self.messages]
|
||||
|
||||
@classmethod
|
||||
def from_memory_list(cls, memorys: List['Memory']) -> 'Memory':
|
||||
return cls(messages=[message for memory in memorys for message in memory.get_messages()])
|
||||
|
||||
def __len__(self, ):
|
||||
return len(self.messages)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return "\n".join([":".join(i) for i in self.to_tuple_messages()])
|
||||
|
||||
def __add__(self, other: Union[Message, 'Memory']) -> 'Memory':
|
||||
if isinstance(other, Message):
|
||||
return Memory(messages=self.messages + [other])
|
||||
elif isinstance(other, Memory):
|
||||
return Memory(messages=self.messages + other.messages)
|
||||
else:
|
||||
raise ValueError(f"can't add unsupported type {type(other)}")
|
||||
|
||||
|
||||
|
|
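A minimal sketch of how `Memory` accumulates `Message` objects and renders them back out; the import paths are assumptions and the contents are illustrative:

```python
# Hypothetical import paths; adjust to where memory.py / message.py live in your checkout.
from coagent.connector.schema.memory import Memory
from coagent.connector.schema.message import Message

memory = Memory()
memory.append(Message(role_name="user", role_type="user",
                      role_content="How do I create a Nebula space?"))
memory.append(Message(role_name="searcher", role_type="assistant",
                      role_content="Use CREATE SPACE ... in nGQL."))

# (role_name, content) tuples for every message.
print(memory.to_tuple_messages())

# Rounds grouped by role; 'function' messages are folded into 'assistant'.
for r in memory.split_by_role_type():
    print(r["role"], len(r["memory"]))

# All contents joined into one string, skipping a given role if needed.
print(memory.to_str_messages(filter_roles=["searcher"]))
```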
@ -1,121 +0,0 @@
|
|||
from pydantic import BaseModel, root_validator
|
||||
from loguru import logger
|
||||
|
||||
from coagent.utils.common_utils import getCurrentDatetime
|
||||
from .general_schema import *
|
||||
|
||||
|
||||
class Message(BaseModel):
|
||||
chat_index: str = None
|
||||
user_name: str = "default"
|
||||
role_name: str
|
||||
role_type: str
|
||||
role_prompt: str = None
|
||||
input_query: str = None
|
||||
origin_query: str = None
|
||||
datetime: str = getCurrentDatetime()
|
||||
|
||||
# llm output
|
||||
role_content: str = None
|
||||
step_content: str = None
|
||||
|
||||
# llm parsed information
|
||||
plans: List[str] = None
|
||||
code_content: str = None
|
||||
code_filename: str = None
|
||||
tool_params: str = None
|
||||
tool_name: str = None
|
||||
parsed_output: dict = {}
|
||||
spec_parsed_output: dict = {}
|
||||
parsed_output_list: List[Dict] = []
|
||||
|
||||
# llm / tool / code execution information
|
||||
action_status: str = "default"
|
||||
agent_index: int = None
|
||||
code_answer: str = None
|
||||
tool_answer: str = None
|
||||
observation: str = None
|
||||
figures: Dict[str, str] = {}
|
||||
|
||||
# prompt support information
|
||||
tools: List[BaseTool] = []
|
||||
task: Task = None
|
||||
db_docs: List['Doc'] = []
|
||||
code_docs: List['CodeDoc'] = []
|
||||
search_docs: List['Doc'] = []
|
||||
agents: List = []
|
||||
|
||||
# phase input
|
||||
phase_name: str = None
|
||||
chain_name: str = None
|
||||
do_search: bool = False
|
||||
doc_engine_name: str = None
|
||||
code_engine_name: str = None
|
||||
cb_search_type: str = None
|
||||
search_engine_name: str = None
|
||||
top_k: int = 3
|
||||
use_nh: bool = True
|
||||
local_graph_path: str = ''
|
||||
score_threshold: float = 1.0
|
||||
do_doc_retrieval: bool = False
|
||||
do_code_retrieval: bool = False
|
||||
do_tool_retrieval: bool = False
|
||||
history_node_list: List[str] = []
|
||||
# user's custom kwargs for init or end actions
|
||||
customed_kargs: dict = {}
|
||||
|
||||
|
||||
@root_validator(pre=True)
|
||||
def check_card_number_omitted(cls, values):
|
||||
input_query = values.get("input_query")
|
||||
origin_query = values.get("origin_query")
|
||||
role_content = values.get("role_content")
|
||||
if input_query is None:
|
||||
values["input_query"] = origin_query or role_content
|
||||
if role_content is None:
|
||||
values["role_content"] = origin_query
|
||||
return values
|
||||
|
||||
# pydantic>=2.0
|
||||
# @model_validator(mode='after')
|
||||
# def check_passwords_match(self) -> 'Message':
|
||||
# if self.input_query is None:
|
||||
# self.input_query = self.origin_query or self.role_content
|
||||
# if self.role_content is None:
|
||||
# self.role_content = self.origin_query
|
||||
# return self
|
||||
|
||||
def to_tuple_message(self, return_all: bool = True, content_key="role_content"):
|
||||
role_content = self.to_str_content(False, content_key)
|
||||
if return_all:
|
||||
return (self.role_name, role_content)
|
||||
else:
|
||||
return (role_content)
|
||||
|
||||
def to_dict_message(self, ):
|
||||
return vars(self)
|
||||
|
||||
def to_str_content(self, return_all: bool = True, content_key="role_content", with_tag=False):
|
||||
if content_key == "role_content":
|
||||
role_content = self.role_content or self.input_query
|
||||
elif content_key == "step_content":
|
||||
role_content = self.step_content or self.role_content or self.input_query
|
||||
elif content_key == "parsed_output":
|
||||
role_content = "\n".join([f"**{k}:** {v}" for k, v in self.parsed_output.items()])
|
||||
elif content_key == "parsed_output_list":
|
||||
role_content = "\n".join([f"**{k}:** {v}" for po in self.parsed_output_list for k,v in po.items()])
|
||||
else:
|
||||
role_content = self.role_content or self.input_query
|
||||
|
||||
if with_tag:
|
||||
start_tag = f"<{self.role_type}-{self.role_name}-message>"
|
||||
end_tag = f"</{self.role_type}-{self.role_name}-message>"
|
||||
return f"{start_tag}\n{role_content}\n{end_tag}"
|
||||
else:
|
||||
return role_content
|
||||
|
||||
def __str__(self) -> str:
|
||||
# key_str = '\n'.join([k for k, v in vars(self).items()])
|
||||
# logger.debug(f"{key_str}")
|
||||
return "\n".join([": ".join([k, str(v)]) for k, v in vars(self).items()])
|
||||
|
|
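A minimal sketch of constructing a `Message` and rendering it with the different `content_key` options; the import path is an assumption:

```python
# Hypothetical import path; adjust to where message.py lives in your checkout.
from coagent.connector.schema.message import Message

msg = Message(
    role_name="coder",
    role_type="assistant",
    origin_query="plot a sine wave",          # the root_validator copies this into input_query/role_content
    parsed_output={"CODE": "import numpy as np", "ACTION": "coding"},
)

print(msg.to_tuple_message())                              # ('coder', 'plot a sine wave')
print(msg.to_str_content(content_key="parsed_output"))     # '**CODE:** ...' joined per key
print(msg.to_str_content(content_key="role_content", with_tag=True))  # wrapped in role tags
```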
@ -1,122 +0,0 @@
|
|||
import re, copy, json
|
||||
from loguru import logger
|
||||
|
||||
|
||||
def extract_section(text, section_name):
|
||||
# Define a pattern to extract the named section along with its content
|
||||
section_pattern = rf'#### {section_name}\n(.*?)(?=####|$)'
|
||||
|
||||
# Find the specific section content
|
||||
section_content = re.search(section_pattern, text, re.DOTALL)
|
||||
|
||||
if section_content:
|
||||
# If the section is found, extract the content and strip the leading/trailing whitespace
|
||||
# This will also remove leading/trailing newlines
|
||||
content = section_content.group(1).strip()
|
||||
|
||||
# Return the cleaned content
|
||||
return content
|
||||
else:
|
||||
# If the section is not found, return an empty string
|
||||
return ""
|
||||
|
||||
|
||||
def parse_section(text, section_name):
|
||||
# Define a pattern to extract the named section along with its content
|
||||
section_pattern = rf'#### {section_name}\n(.*?)(?=####|$)'
|
||||
|
||||
# Find the specific section content
|
||||
section_content = re.search(section_pattern, text, re.DOTALL)
|
||||
|
||||
if section_content:
|
||||
# If the section is found, extract the content
|
||||
content = section_content.group(1)
|
||||
|
||||
# Define a pattern to find segments that follow the format **xx:**
|
||||
segments_pattern = r'\*\*([^*]+):\*\*'
|
||||
|
||||
# Use findall method to extract all matches in the section content
|
||||
segments = re.findall(segments_pattern, content)
|
||||
|
||||
return segments
|
||||
else:
|
||||
# If the section is not found, return an empty list
|
||||
return []
|
||||
|
||||
|
||||
def parse_text_to_dict(text):
|
||||
# Define a regular expression pattern to capture the key and value
|
||||
main_pattern = r"\*\*(.+?):\*\*\s*(.*?)\s*(?=\*\*|$)"
|
||||
list_pattern = r'```python\n(.*?)```'
|
||||
plan_pattern = r'\[\s*.*?\s*\]'
|
||||
|
||||
# Use re.findall to find all main matches in the text
|
||||
main_matches = re.findall(main_pattern, text, re.DOTALL)
|
||||
|
||||
# Convert main matches to a dictionary
|
||||
parsed_dict = {key.strip(): value.strip() for key, value in main_matches}
|
||||
|
||||
for k, v in parsed_dict.items():
|
||||
for pattern in [list_pattern, plan_pattern]:
|
||||
if "PLAN" != k: continue
|
||||
v = v.replace("```list", "```python")
|
||||
match_value = re.search(pattern, v, re.DOTALL)
|
||||
if match_value:
|
||||
# Add the code block to the dictionary
|
||||
parsed_dict[k] = eval(match_value.group(1).strip())
|
||||
break
|
||||
|
||||
return parsed_dict
|
||||
|
||||
|
||||
def parse_dict_to_dict(parsed_dict) -> dict:
|
||||
code_pattern = r'```python\n(.*?)```'
|
||||
tool_pattern = r'```json\n(.*?)```'
|
||||
java_pattern = r'```java\n(.*?)```'
|
||||
|
||||
pattern_dict = {"code": code_pattern, "json": tool_pattern, "java": java_pattern}
|
||||
spec_parsed_dict = copy.deepcopy(parsed_dict)
|
||||
for key, pattern in pattern_dict.items():
|
||||
for k, text in parsed_dict.items():
|
||||
# Search for the code block
|
||||
if not isinstance(text, str):
|
||||
spec_parsed_dict[k] = text
|
||||
continue
|
||||
_match = re.search(pattern, text, re.DOTALL)
|
||||
if _match:
|
||||
# Add the code block to the dictionary
|
||||
try:
|
||||
spec_parsed_dict[key] = json.loads(_match.group(1).strip())
|
||||
spec_parsed_dict[k] = json.loads(_match.group(1).strip())
|
||||
except:
|
||||
spec_parsed_dict[key] = _match.group(1).strip()
|
||||
spec_parsed_dict[k] = _match.group(1).strip()
|
||||
break
|
||||
return spec_parsed_dict
|
||||
|
||||
|
||||
def prompt_cost(model_type: str, num_prompt_tokens: float, num_completion_tokens: float):
|
||||
input_cost_map = {
|
||||
"gpt-3.5-turbo": 0.0015,
|
||||
"gpt-3.5-turbo-16k": 0.003,
|
||||
"gpt-3.5-turbo-0613": 0.0015,
|
||||
"gpt-3.5-turbo-16k-0613": 0.003,
|
||||
"gpt-4": 0.03,
|
||||
"gpt-4-0613": 0.03,
|
||||
"gpt-4-32k": 0.06,
|
||||
}
|
||||
|
||||
output_cost_map = {
|
||||
"gpt-3.5-turbo": 0.002,
|
||||
"gpt-3.5-turbo-16k": 0.004,
|
||||
"gpt-3.5-turbo-0613": 0.002,
|
||||
"gpt-3.5-turbo-16k-0613": 0.004,
|
||||
"gpt-4": 0.06,
|
||||
"gpt-4-0613": 0.06,
|
||||
"gpt-4-32k": 0.12,
|
||||
}
|
||||
|
||||
if model_type not in input_cost_map or model_type not in output_cost_map:
|
||||
return -1
|
||||
|
||||
return num_prompt_tokens * input_cost_map[model_type] / 1000.0 + num_completion_tokens * output_cost_map[model_type] / 1000.0
|
|
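A minimal sketch of pushing an LLM reply through `parse_text_to_dict` / `parse_dict_to_dict` and estimating the call cost with `prompt_cost`; the import path and the reply text are illustrative:

```python
# Hypothetical import path; adjust to where these helpers live in your checkout.
from coagent.connector.utils import parse_text_to_dict, parse_dict_to_dict, prompt_cost

fence = "`" * 3  # build the literal code fence so this example stays readable
llm_reply = (
    "**Action Status:** coding\n"
    f"**Code:** {fence}python\n"
    'print("hello")\n'
    f"{fence}\n"
)

parsed = parse_text_to_dict(llm_reply)
# {'Action Status': 'coding', 'Code': '<fenced python block>'}

spec = parse_dict_to_dict(parsed)
# adds a 'code' key holding the code with the fence stripped: 'print("hello")'

# USD estimate: 1200 prompt tokens + 300 completion tokens on gpt-3.5-turbo.
print(prompt_cost("gpt-3.5-turbo", 1200, 300))   # 0.0024
```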
@ -1,7 +0,0 @@
|
|||
# encoding: utf-8
|
||||
'''
|
||||
@author: 温进
|
||||
@file: __init__.py
|
||||
@time: 2023/11/16 3:15 PM
|
||||
@desc:
|
||||
'''
|
|
@ -1,7 +0,0 @@
|
|||
# encoding: utf-8
|
||||
'''
|
||||
@author: 温进
|
||||
@file: __init__.py
|
||||
@time: 2023/11/20 3:07 PM
|
||||
@desc:
|
||||
'''
|
|
@ -1,285 +0,0 @@
|
|||
# encoding: utf-8
|
||||
'''
|
||||
@author: 温进
|
||||
@file: nebula_handler.py
|
||||
@time: 2023/11/16 3:15 PM
|
||||
@desc:
|
||||
'''
|
||||
import time
|
||||
from loguru import logger
|
||||
|
||||
from nebula3.gclient.net import ConnectionPool
|
||||
from nebula3.Config import Config
|
||||
|
||||
|
||||
class NebulaHandler:
|
||||
def __init__(self, host: str, port: int, username: str, password: str = '', space_name: str = ''):
|
||||
'''
|
||||
init nebula connection_pool
|
||||
@param host: host
|
||||
@param port: port
|
||||
@param username: username
|
||||
@param password: password
|
||||
'''
|
||||
config = Config()
|
||||
|
||||
self.connection_pool = ConnectionPool()
|
||||
self.connection_pool.init([(host, port)], config)
|
||||
self.username = username
|
||||
self.password = password
|
||||
self.space_name = space_name
|
||||
|
||||
def execute_cypher(self, cypher: str, space_name: str = '', format_res: bool = False, use_space_name: bool = True):
|
||||
'''
|
||||
|
||||
@param space_name: if provided, `USE space_name` is executed before the cypher
|
||||
@param cypher:
|
||||
@return:
|
||||
'''
|
||||
with self.connection_pool.session_context(self.username, self.password) as session:
|
||||
if use_space_name:
|
||||
if space_name:
|
||||
cypher = f'USE {space_name};{cypher}'
|
||||
elif self.space_name:
|
||||
cypher = f'USE {self.space_name};{cypher}'
|
||||
|
||||
# logger.debug(cypher)
|
||||
resp = session.execute(cypher)
|
||||
|
||||
if format_res:
|
||||
resp = self.result_to_dict(resp)
|
||||
return resp
|
||||
|
||||
def close_connection(self):
|
||||
self.connection_pool.close()
|
||||
|
||||
def create_space(self, space_name: str, vid_type: str, comment: str = ''):
|
||||
'''
|
||||
create space
|
||||
@param space_name: cannot start with a number
|
||||
@return:
|
||||
'''
|
||||
cypher = f'CREATE SPACE IF NOT EXISTS {space_name} (vid_type={vid_type}) comment="{comment}";'
|
||||
resp = self.execute_cypher(cypher, use_space_name=False)
|
||||
|
||||
return resp
|
||||
|
||||
def show_space(self):
|
||||
cypher = 'SHOW SPACES'
|
||||
resp = self.execute_cypher(cypher)
|
||||
return resp
|
||||
|
||||
def drop_space(self, space_name):
|
||||
cypher = f'DROP SPACE {space_name}'
|
||||
return self.execute_cypher(cypher)
|
||||
|
||||
def create_tag(self, tag_name: str, prop_dict: dict = {}):
|
||||
'''
|
||||
create a tag
|
||||
@param tag_name: tag name
|
||||
@param prop_dict: property dict {'prop name': 'prop type'}
|
||||
@return:
|
||||
'''
|
||||
cypher = f'CREATE TAG IF NOT EXISTS {tag_name}'
|
||||
cypher += '('
|
||||
for k, v in prop_dict.items():
|
||||
cypher += f'{k} {v},'
|
||||
cypher = cypher.rstrip(',')
|
||||
cypher += ')'
|
||||
cypher += ';'
|
||||
|
||||
res = self.execute_cypher(cypher, self.space_name)
|
||||
return res
|
||||
|
||||
def show_tags(self):
|
||||
'''
|
||||
show tags
|
||||
@return:
|
||||
'''
|
||||
cypher = 'SHOW TAGS'
|
||||
resp = self.execute_cypher(cypher, self.space_name)
|
||||
return resp
|
||||
|
||||
def insert_vertex(self, tag_name: str, value_dict: dict):
|
||||
'''
|
||||
insert vertex
|
||||
@param tag_name:
|
||||
@param value_dict: {'properties_name': [...], 'values': {'vid': [...]}}; the property order must match between properties_name and values
|
||||
@return:
|
||||
'''
|
||||
cypher = f'INSERT VERTEX {tag_name} ('
|
||||
|
||||
properties_name = value_dict['properties_name']
|
||||
|
||||
for property_name in properties_name:
|
||||
cypher += f'{property_name},'
|
||||
cypher = cypher.rstrip(',')
|
||||
|
||||
cypher += ') VALUES '
|
||||
|
||||
for vid, properties in value_dict['values'].items():
|
||||
cypher += f'"{vid}":('
|
||||
for property in properties:
|
||||
if type(property) == str:
|
||||
cypher += f'"{property}",'
|
||||
else:
|
||||
cypher += f'{property},'
|
||||
cypher = cypher.rstrip(',')
|
||||
cypher += '),'
|
||||
cypher = cypher.rstrip(',')
|
||||
cypher += ';'
|
||||
|
||||
res = self.execute_cypher(cypher, self.space_name)
|
||||
return res
|
||||
|
||||
def create_edge_type(self, edge_type_name: str, prop_dict: dict = {}):
|
||||
'''
|
||||
create an edge type
|
||||
@param edge_type_name: edge type name
|
||||
@param prop_dict: property dict {'prop name': 'prop type'}
|
||||
@return:
|
||||
'''
|
||||
cypher = f'CREATE EDGE IF NOT EXISTS {edge_type_name}'
|
||||
|
||||
cypher += '('
|
||||
for k, v in prop_dict.items():
|
||||
cypher += f'{k} {v},'
|
||||
cypher = cypher.rstrip(',')
|
||||
cypher += ')'
|
||||
cypher += ';'
|
||||
|
||||
res = self.execute_cypher(cypher, self.space_name)
|
||||
return res
|
||||
|
||||
def show_edge_type(self):
|
||||
'''
|
||||
show edge types
|
||||
@return:
|
||||
'''
|
||||
cypher = 'SHOW EDGES'
|
||||
resp = self.execute_cypher(cypher, self.space_name)
|
||||
return resp
|
||||
|
||||
def drop_edge_type(self, edge_type_name: str):
|
||||
cypher = f'DROP EDGE {edge_type_name}'
|
||||
return self.execute_cypher(cypher, self.space_name)
|
||||
|
||||
def insert_edge(self, edge_type_name: str, value_dict: dict):
|
||||
'''
|
||||
insert edge
|
||||
@param edge_type_name:
|
||||
@param value_dict: {'properties_name': [...], 'values': {(src_vid, dst_vid): [...]}}; the property order must
|
||||
match between properties_name and values
|
||||
@return:
|
||||
'''
|
||||
cypher = f'INSERT EDGE {edge_type_name} ('
|
||||
|
||||
properties_name = value_dict['properties_name']
|
||||
|
||||
for property_name in properties_name:
|
||||
cypher += f'{property_name},'
|
||||
cypher = cypher.rstrip(',')
|
||||
|
||||
cypher += ') VALUES '
|
||||
|
||||
for (src_vid, dst_vid), properties in value_dict['values'].items():
|
||||
cypher += f'"{src_vid}"->"{dst_vid}":('
|
||||
for property in properties:
|
||||
if type(property) == str:
|
||||
cypher += f'"{property}",'
|
||||
else:
|
||||
cypher += f'{property},'
|
||||
cypher = cypher.rstrip(',')
|
||||
cypher += '),'
|
||||
cypher = cypher.rstrip(',')
|
||||
cypher += ';'
|
||||
|
||||
res = self.execute_cypher(cypher, self.space_name)
|
||||
return res
|
||||
|
||||
def set_space_name(self, space_name):
|
||||
self.space_name = space_name
|
||||
|
||||
def add_host(self, host: str, port: str):
|
||||
'''
|
||||
add host
|
||||
@return:
|
||||
'''
|
||||
cypher = f'ADD HOSTS {host}:{port}'
|
||||
res = self.execute_cypher(cypher)
|
||||
return res
|
||||
|
||||
def get_stat(self):
|
||||
'''
|
||||
|
||||
@return:
|
||||
'''
|
||||
submit_cypher = 'SUBMIT JOB STATS;'
|
||||
self.execute_cypher(cypher=submit_cypher, space_name=self.space_name)
|
||||
time.sleep(2)
|
||||
|
||||
stats_cypher = 'SHOW STATS;'
|
||||
stats_res = self.execute_cypher(cypher=stats_cypher, space_name=self.space_name)
|
||||
|
||||
res = {'vertices': -1, 'edges': -1}
|
||||
|
||||
stats_res_dict = self.result_to_dict(stats_res)
|
||||
logger.info(stats_res_dict)
|
||||
for idx in range(len(stats_res_dict['Type'])):
|
||||
t = stats_res_dict['Type'][idx].as_string()
|
||||
name = stats_res_dict['Name'][idx].as_string()
|
||||
count = stats_res_dict['Count'][idx].as_int()
|
||||
|
||||
if t == 'Space' and name in res:
|
||||
res[name] = count
|
||||
return res
|
||||
|
||||
def get_vertices(self, tag_name: str = '', limit: int = 10000):
|
||||
'''
|
||||
get vertices, optionally filtered by tag_name and capped at limit
|
||||
@return:
|
||||
'''
|
||||
if tag_name:
|
||||
cypher = f'''MATCH (v:{tag_name}) RETURN v LIMIT {limit};'''
|
||||
else:
|
||||
cypher = f'MATCH (v) RETURN v LIMIT {limit};'
|
||||
|
||||
res = self.execute_cypher(cypher, self.space_name)
|
||||
return self.result_to_dict(res)
|
||||
|
||||
def get_all_vertices(self,):
|
||||
'''
|
||||
get all vertices
|
||||
@return:
|
||||
'''
|
||||
cypher = "MATCH (v) RETURN v;"
|
||||
res = self.execute_cypher(cypher, self.space_name)
|
||||
return self.result_to_dict(res)
|
||||
|
||||
def get_relative_vertices(self, vertice):
|
||||
'''
|
||||
get the vertices directly connected to the given vertex
|
||||
@return:
|
||||
'''
|
||||
cypher = f'''MATCH (v1)--(v2) WHERE id(v1) == '{vertice}' RETURN id(v2) as id;'''
|
||||
res = self.execute_cypher(cypher, self.space_name)
|
||||
return self.result_to_dict(res)
|
||||
|
||||
def result_to_dict(self, result) -> dict:
|
||||
"""
|
||||
build a list for each column and return them as a dict
|
||||
"""
|
||||
# logger.info(result.error_msg())
|
||||
assert result.is_succeeded()
|
||||
columns = result.keys()
|
||||
d = {}
|
||||
for col_num in range(result.col_size()):
|
||||
col_name = columns[col_num]
|
||||
col_list = result.column_values(col_name)
|
||||
d[col_name] = [x for x in col_list]
|
||||
return d
|
||||
|
||||
|
||||
|
||||
|
||||
|
|
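A minimal sketch of the `NebulaHandler` round trip: create a space and tag, insert a vertex, then read back the stats. The import path, host, credentials and names are assumptions, and NebulaGraph applies schema changes asynchronously, hence the sleeps:

```python
import time

# Hypothetical import path; adjust to where nebula_handler.py lives in your checkout.
from coagent.db_handler.graph_db_handler.nebula_handler import NebulaHandler

handler = NebulaHandler(host="127.0.0.1", port=9669,
                        username="root", password="nebula",
                        space_name="demo_space")

handler.create_space("demo_space", vid_type="FIXED_STRING(32)", comment="demo")
time.sleep(10)  # space creation takes effect on the next heartbeat

handler.create_tag("package", {"pac_name": "string"})
time.sleep(10)  # schema changes also need a moment to propagate

handler.insert_vertex("package", {
    "properties_name": ["pac_name"],
    "values": {"pkg_coagent": ["coagent"]},
})

print(handler.get_stat())   # e.g. {'vertices': 1, 'edges': 0}
```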
@ -1,7 +0,0 @@
|
|||
# encoding: utf-8
|
||||
'''
|
||||
@author: 温进
|
||||
@file: __init__.py
|
||||
@time: 2023/11/20 3:08 PM
|
||||
@desc:
|
||||
'''
|
|
@ -1,144 +0,0 @@
|
|||
# encoding: utf-8
|
||||
'''
|
||||
@author: 温进
|
||||
@file: chroma_handler.py
|
||||
@time: 2023/11/21 12:21 PM
|
||||
@desc:
|
||||
'''
|
||||
from loguru import logger
|
||||
import chromadb
|
||||
|
||||
|
||||
class ChromaHandler:
|
||||
def __init__(self, path: str, collection_name: str = ''):
|
||||
'''
|
||||
init client
|
||||
@param path: path of data
|
||||
@param collection_name: name of the collection
|
||||
'''
|
||||
settings = chromadb.get_settings()
|
||||
# disable the posthog telemetry mechanism, which may raise connection errors such as
|
||||
# "requests.exceptions.ConnectTimeout: HTTPSConnectionPool(host='us-api.i.posthog.com', port 443)"
|
||||
settings.anonymized_telemetry = False
|
||||
self.client = chromadb.PersistentClient(path, settings)
|
||||
self.client.heartbeat()
|
||||
|
||||
if collection_name:
|
||||
self.collection = self.client.get_or_create_collection(name=collection_name)
|
||||
|
||||
def create_collection(self, collection_name: str):
|
||||
'''
|
||||
create a collection; if it already exists, a failure result is returned
|
||||
@return:
|
||||
'''
|
||||
try:
|
||||
collection = self.client.create_collection(name=collection_name)
|
||||
except Exception as e:
|
||||
return {'result_code': -1, 'msg': f'fail, error={e}'}
|
||||
return {'result_code': 0, 'msg': 'success'}
|
||||
|
||||
def delete_collection(self, collection_name: str):
|
||||
'''
|
||||
|
||||
@param collection_name:
|
||||
@return:
|
||||
'''
|
||||
try:
|
||||
self.client.delete_collection(name=collection_name)
|
||||
except Exception as e:
|
||||
return {'result_code': -1, 'msg': f'fail, error={e}'}
|
||||
return {'result_code': 0, 'msg': 'success'}
|
||||
|
||||
def set_collection(self, collection_name: str):
|
||||
'''
|
||||
|
||||
@param collection_name:
|
||||
@return:
|
||||
'''
|
||||
try:
|
||||
self.collection = self.client.get_collection(collection_name)
|
||||
except Exception as e:
|
||||
return {'result_code': -1, 'msg': f'fail, error={e}'}
|
||||
return {'result_code': 0, 'msg': 'success'}
|
||||
|
||||
def add_data(self, ids: list, documents: list = None, embeddings: list = None, metadatas: list = None):
|
||||
'''
|
||||
add data to chroma
|
||||
@param documents: list of doc string
|
||||
@param embeddings: list of vector
|
||||
@param metadatas: list of metadata
|
||||
@param ids: list of id
|
||||
@return:
|
||||
'''
|
||||
try:
|
||||
self.collection.add(
|
||||
ids=ids,
|
||||
embeddings=embeddings,
|
||||
metadatas=metadatas,
|
||||
documents=documents
|
||||
)
|
||||
except Exception as e:
|
||||
return {'result_code': -1, 'msg': f'fail, error={e}'}
|
||||
return {'result_code': 0, 'msg': 'success'}
|
||||
|
||||
def query(self, query_embeddings=None, query_texts=None, n_results=10, where=None, where_document=None,
|
||||
include=["metadatas", "documents", "distances"]):
|
||||
'''
|
||||
|
||||
@param query_embeddings:
|
||||
@param query_texts:
|
||||
@param n_results:
|
||||
@param where:
|
||||
@param where_document:
|
||||
@param include:
|
||||
@return:
|
||||
'''
|
||||
try:
|
||||
query_result = self.collection.query(query_embeddings=query_embeddings, query_texts=query_texts,
|
||||
n_results=n_results, where=where, where_document=where_document,
|
||||
include=include)
|
||||
return {'result_code': 0, 'msg': 'success', 'result': query_result}
|
||||
except Exception as e:
|
||||
return {'result_code': -1, 'msg': f'fail, error={e}'}
|
||||
|
||||
def get(self, ids=None, where=None, limit=None, offset=None, where_document=None, include=["metadatas", "documents"]):
|
||||
'''
|
||||
get by condition
|
||||
@param ids:
|
||||
@param where:
|
||||
@param limit:
|
||||
@param offset:
|
||||
@param where_document:
|
||||
@param include:
|
||||
@return:
|
||||
'''
|
||||
try:
|
||||
query_result = self.collection.get(ids=ids, where=where, where_document=where_document,
|
||||
limit=limit,
|
||||
offset=offset, include=include)
|
||||
return {'result_code': 0, 'msg': 'success', 'result': query_result}
|
||||
except Exception as e:
|
||||
return {'result_code': -1, 'msg': f'fail, error={e}'}
|
||||
|
||||
def peek(self, limit: int=10):
|
||||
'''
|
||||
peek
|
||||
@param limit:
|
||||
@return:
|
||||
'''
|
||||
try:
|
||||
query_result = self.collection.peek(limit)
|
||||
return {'result_code': 0, 'msg': 'success', 'result': query_result}
|
||||
except Exception as e:
|
||||
return {'result_code': -1, 'msg': f'fail, error={e}'}
|
||||
|
||||
def count(self):
|
||||
'''
|
||||
count
|
||||
@return:
|
||||
'''
|
||||
try:
|
||||
query_result = self.collection.count()
|
||||
return {'result_code': 0, 'msg': 'success', 'result': query_result}
|
||||
except Exception as e:
|
||||
return {'result_code': -1, 'msg': f'fail, error={e}'}
|
|
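A minimal sketch of `ChromaHandler` usage. Every method returns a dict with a `result_code` instead of raising, so callers check that field; the import path and data are illustrative:

```python
# Hypothetical import path; adjust to where chroma_handler.py lives in your checkout.
from coagent.db_handler.vector_db_handler.chroma_handler import ChromaHandler

handler = ChromaHandler(path="./chroma_data", collection_name="code_docs")

res = handler.add_data(
    ids=["doc-1", "doc-2"],
    documents=["def add(a, b): return a + b", "def sub(a, b): return a - b"],
    metadatas=[{"file": "math_utils.py"}, {"file": "math_utils.py"}],
)
assert res["result_code"] == 0, res["msg"]

# Text query; chroma embeds it with the collection's default embedding function.
hits = handler.query(query_texts=["function that adds numbers"], n_results=1)
if hits["result_code"] == 0:
    print(hits["result"]["documents"])

print(handler.count())   # {'result_code': 0, 'msg': 'success', 'result': 2}
```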
@ -1,6 +0,0 @@
|
|||
from .json_loader import JSONLoader
|
||||
from .jsonl_loader import JSONLLoader
|
||||
|
||||
__all__ = [
|
||||
"JSONLoader", "JSONLLoader"
|
||||
]
|
|
@ -1,61 +0,0 @@
|
|||
import json
|
||||
from pathlib import Path
|
||||
from typing import AnyStr, Callable, Dict, List, Optional, Union
|
||||
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.document_loaders.base import BaseLoader
|
||||
from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter
|
||||
|
||||
from coagent.utils.common_utils import read_json_file
|
||||
|
||||
|
||||
class JSONLoader(BaseLoader):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
file_path: Union[str, Path],
|
||||
schema_key: str = "all_text",
|
||||
content_key: Optional[str] = None,
|
||||
metadata_func: Optional[Callable[[Dict, Dict], Dict]] = None,
|
||||
text_content: bool = True,
|
||||
):
|
||||
self.file_path = Path(file_path).resolve()
|
||||
self.schema_key = schema_key
|
||||
self._content_key = content_key
|
||||
self._metadata_func = metadata_func
|
||||
self._text_content = text_content
|
||||
|
||||
def load(self, ) -> List[Document]:
|
||||
"""Load and return documents from the JSON file."""
|
||||
docs: List[Document] = []
|
||||
datas = read_json_file(self.file_path)
|
||||
self._parse(datas, docs)
|
||||
return docs
|
||||
|
||||
def _parse(self, datas: List, docs: List[Document]) -> None:
|
||||
for idx, sample in enumerate(datas):
|
||||
metadata = dict(
|
||||
source=str(self.file_path),
|
||||
seq_num=idx,
|
||||
)
|
||||
text = sample.get(self.schema_key, "")
|
||||
docs.append(Document(page_content=text, metadata=metadata))
|
||||
|
||||
def load_and_split(
|
||||
self, text_splitter: Optional[TextSplitter] = None
|
||||
) -> List[Document]:
|
||||
"""Load Documents and split into chunks. Chunks are returned as Documents.
|
||||
|
||||
Args:
|
||||
text_splitter: TextSplitter instance to use for splitting documents.
|
||||
Defaults to RecursiveCharacterTextSplitter.
|
||||
|
||||
Returns:
|
||||
List of Documents.
|
||||
"""
|
||||
if text_splitter is None:
|
||||
_text_splitter: TextSplitter = RecursiveCharacterTextSplitter()
|
||||
else:
|
||||
_text_splitter = text_splitter
|
||||
docs = self.load()
|
||||
return _text_splitter.split_documents(docs)
|
|
@ -1,62 +0,0 @@
|
|||
import json
|
||||
from pathlib import Path
|
||||
from typing import AnyStr, Callable, Dict, List, Optional, Union
|
||||
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.document_loaders.base import BaseLoader
|
||||
from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter
|
||||
|
||||
from coagent.utils.common_utils import read_jsonl_file
|
||||
|
||||
|
||||
class JSONLLoader(BaseLoader):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
file_path: Union[str, Path],
|
||||
schema_key: str = "all_text",
|
||||
content_key: Optional[str] = None,
|
||||
metadata_func: Optional[Callable[[Dict, Dict], Dict]] = None,
|
||||
text_content: bool = True,
|
||||
):
|
||||
self.file_path = Path(file_path).resolve()
|
||||
self.schema_key = schema_key
|
||||
self._content_key = content_key
|
||||
self._metadata_func = metadata_func
|
||||
self._text_content = text_content
|
||||
|
||||
def load(self, ) -> List[Document]:
|
||||
"""Load and return documents from the JSON file."""
|
||||
docs: List[Document] = []
|
||||
datas = read_jsonl_file(self.file_path)
|
||||
self._parse(datas, docs)
|
||||
return docs
|
||||
|
||||
def _parse(self, datas: List, docs: List[Document]) -> None:
|
||||
for idx, sample in enumerate(datas):
|
||||
metadata = dict(
|
||||
source=str(self.file_path),
|
||||
seq_num=idx,
|
||||
)
|
||||
text = sample.get(self.schema_key, "")
|
||||
docs.append(Document(page_content=text, metadata=metadata))
|
||||
|
||||
def load_and_split(
|
||||
self, text_splitter: Optional[TextSplitter] = None
|
||||
) -> List[Document]:
|
||||
"""Load Documents and split into chunks. Chunks are returned as Documents.
|
||||
|
||||
Args:
|
||||
text_splitter: TextSplitter instance to use for splitting documents.
|
||||
Defaults to RecursiveCharacterTextSplitter.
|
||||
|
||||
Returns:
|
||||
List of Documents.
|
||||
"""
|
||||
if text_splitter is None:
|
||||
_text_splitter: TextSplitter = RecursiveCharacterTextSplitter()
|
||||
else:
|
||||
_text_splitter = text_splitter
|
||||
|
||||
docs = self.load()
|
||||
return _text_splitter.split_documents(docs)
|
|
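Both loaders follow the same pattern, so a sketch for `JSONLoader` covers `JSONLLoader` as well (the latter just reads `.jsonl` files). The import path and file contents are assumptions:

```python
import json

# Hypothetical import path; adjust to where the loaders live in your checkout.
from coagent.document_loaders import JSONLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter

# The default schema_key is "all_text": each record's text is read from that field.
with open("kb.json", "w", encoding="utf8") as f:
    json.dump([{"all_text": "CodeFuse-ChatBot supports multi-agent orchestration."},
               {"all_text": "The sandbox executes generated code in isolation."}], f)

loader = JSONLoader("kb.json")
docs = loader.load()
print(docs[0].metadata)     # {'source': '<absolute path>/kb.json', 'seq_num': 0}

chunks = loader.load_and_split(RecursiveCharacterTextSplitter(chunk_size=64, chunk_overlap=0))
print(len(chunks))
```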
@ -1,37 +0,0 @@
|
|||
from typing import List
|
||||
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.schema import Document
|
||||
|
||||
|
||||
|
||||
class BaseVSCService:
|
||||
def do_create_kb(self):
|
||||
pass
|
||||
|
||||
def do_drop_kb(self):
|
||||
pass
|
||||
|
||||
def do_add_doc(self, docs: List[Document], embeddings: Embeddings):
|
||||
pass
|
||||
|
||||
def do_clear_vs(self):
|
||||
pass
|
||||
|
||||
def vs_type(self) -> str:
|
||||
return "default"
|
||||
|
||||
def do_init(self):
|
||||
pass
|
||||
|
||||
def do_search(self):
|
||||
pass
|
||||
|
||||
def do_insert_multi_knowledge(self):
|
||||
pass
|
||||
|
||||
def do_insert_one_knowledge(self):
|
||||
pass
|
||||
|
||||
def do_delete_doc(self):
|
||||
pass
|
|
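`BaseVSCService` only defines the hooks a vector-store service is expected to implement; a toy subclass, purely illustrative (import path assumed):

```python
from typing import List
from langchain.embeddings.base import Embeddings
from langchain.schema import Document

# Hypothetical import path; adjust to where the base service lives in your checkout.
from coagent.service.base_service import BaseVSCService


class InMemoryVSService(BaseVSCService):
    """Toy service that keeps documents in a plain list; illustrative only."""

    def __init__(self):
        self.docs: List[Document] = []

    def vs_type(self) -> str:
        return "in_memory"

    def do_add_doc(self, docs: List[Document], embeddings: Embeddings):
        # A real service would embed and index here; this stub just stores the docs.
        self.docs.extend(docs)

    def do_clear_vs(self):
        self.docs = []
```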
@ -1,791 +0,0 @@
|
|||
"""Wrapper around FAISS vector database."""
|
||||
from __future__ import annotations
|
||||
|
||||
import operator
|
||||
import os
|
||||
import pickle
|
||||
import uuid
|
||||
import warnings
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
Iterable,
|
||||
List,
|
||||
Optional,
|
||||
Sized,
|
||||
Tuple,
|
||||
)
|
||||
|
||||
import numpy as np
|
||||
|
||||
from langchain.docstore.base import AddableMixin, Docstore
|
||||
from langchain.docstore.document import Document
|
||||
# from langchain.docstore.in_memory import InMemoryDocstore
|
||||
from .in_memory import InMemoryDocstore
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.vectorstores.base import VectorStore
|
||||
from langchain.vectorstores.utils import maximal_marginal_relevance
|
||||
|
||||
|
||||
class DistanceStrategy(str, Enum):
|
||||
"""Enumerator of the Distance strategies for calculating distances
|
||||
between vectors."""
|
||||
|
||||
EUCLIDEAN_DISTANCE = "EUCLIDEAN_DISTANCE"
|
||||
MAX_INNER_PRODUCT = "MAX_INNER_PRODUCT"
|
||||
DOT_PRODUCT = "DOT_PRODUCT"
|
||||
JACCARD = "JACCARD"
|
||||
COSINE = "COSINE"
|
||||
|
||||
|
||||
def dependable_faiss_import(no_avx2: Optional[bool] = None) -> Any:
|
||||
"""
|
||||
Import faiss if available, otherwise raise error.
|
||||
If FAISS_NO_AVX2 environment variable is set, it will be considered
|
||||
to load FAISS with no AVX2 optimization.
|
||||
|
||||
Args:
|
||||
no_avx2: Load FAISS strictly with no AVX2 optimization
|
||||
so that the vectorstore is portable and compatible with other devices.
|
||||
"""
|
||||
if no_avx2 is None and "FAISS_NO_AVX2" in os.environ:
|
||||
no_avx2 = bool(os.getenv("FAISS_NO_AVX2"))
|
||||
|
||||
try:
|
||||
if no_avx2:
|
||||
from faiss import swigfaiss as faiss
|
||||
else:
|
||||
import faiss
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Could not import faiss python package. "
|
||||
"Please install it with `pip install faiss-gpu` (for CUDA supported GPU) "
|
||||
"or `pip install faiss-cpu` (depending on Python version)."
|
||||
)
|
||||
return faiss
|
||||
|
||||
|
||||
def _len_check_if_sized(x: Any, y: Any, x_name: str, y_name: str) -> None:
|
||||
if isinstance(x, Sized) and isinstance(y, Sized) and len(x) != len(y):
|
||||
raise ValueError(
|
||||
f"{x_name} and {y_name} expected to be equal length but "
|
||||
f"len({x_name})={len(x)} and len({y_name})={len(y)}"
|
||||
)
|
||||
return
|
||||
|
||||
|
||||
class FAISS(VectorStore):
|
||||
"""Wrapper around FAISS vector database.
|
||||
|
||||
To use, you must have the ``faiss`` python package installed.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain.embeddings.openai import OpenAIEmbeddings
|
||||
from langchain.vectorstores import FAISS
|
||||
|
||||
embeddings = OpenAIEmbeddings()
|
||||
texts = ["FAISS is an important library", "LangChain supports FAISS"]
|
||||
faiss = FAISS.from_texts(texts, embeddings)
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embedding_function: Callable,
|
||||
index: Any,
|
||||
docstore: Docstore,
|
||||
index_to_docstore_id: Dict[int, str],
|
||||
relevance_score_fn: Optional[Callable[[float], float]] = None,
|
||||
normalize_L2: bool = False,
|
||||
distance_strategy: DistanceStrategy = DistanceStrategy.EUCLIDEAN_DISTANCE,
|
||||
):
|
||||
"""Initialize with necessary components."""
|
||||
self.embedding_function = embedding_function
|
||||
self.index = index
|
||||
self.docstore = docstore
|
||||
self.index_to_docstore_id = index_to_docstore_id
|
||||
self.distance_strategy = distance_strategy
|
||||
self.override_relevance_score_fn = relevance_score_fn
|
||||
self._normalize_L2 = normalize_L2
|
||||
if (
|
||||
self.distance_strategy != DistanceStrategy.EUCLIDEAN_DISTANCE
|
||||
and self._normalize_L2
|
||||
):
|
||||
warnings.warn(
|
||||
"Normalizing L2 is not applicable for metric type: {strategy}".format(
|
||||
strategy=self.distance_strategy
|
||||
)
|
||||
)
|
||||
|
||||
def __add(
|
||||
self,
|
||||
texts: Iterable[str],
|
||||
embeddings: Iterable[List[float]],
|
||||
metadatas: Optional[Iterable[dict]] = None,
|
||||
ids: Optional[List[str]] = None,
|
||||
) -> List[str]:
|
||||
faiss = dependable_faiss_import()
|
||||
|
||||
if not isinstance(self.docstore, AddableMixin):
|
||||
raise ValueError(
|
||||
"If trying to add texts, the underlying docstore should support "
|
||||
f"adding items, which {self.docstore} does not"
|
||||
)
|
||||
|
||||
_len_check_if_sized(texts, metadatas, "texts", "metadatas")
|
||||
_metadatas = metadatas or ({} for _ in texts)
|
||||
documents = [
|
||||
Document(page_content=t, metadata=m) for t, m in zip(texts, _metadatas)
|
||||
]
|
||||
|
||||
_len_check_if_sized(documents, embeddings, "documents", "embeddings")
|
||||
_len_check_if_sized(documents, ids, "documents", "ids")
|
||||
|
||||
# Add to the index.
|
||||
vector = np.array(embeddings, dtype=np.float32)
|
||||
if self._normalize_L2:
|
||||
faiss.normalize_L2(vector)
|
||||
self.index.add(vector)
|
||||
|
||||
# Add information to docstore and index.
|
||||
ids = ids or [str(uuid.uuid4()) for _ in texts]
|
||||
self.docstore.add({id_: doc for id_, doc in zip(ids, documents)})
|
||||
starting_len = len(self.index_to_docstore_id)
|
||||
index_to_id = {starting_len + j: id_ for j, id_ in enumerate(ids)}
|
||||
self.index_to_docstore_id.update(index_to_id)
|
||||
return ids
|
||||
|
||||
def add_texts(
|
||||
self,
|
||||
texts: Iterable[str],
|
||||
metadatas: Optional[List[dict]] = None,
|
||||
ids: Optional[List[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> List[str]:
|
||||
"""Run more texts through the embeddings and add to the vectorstore.
|
||||
|
||||
Args:
|
||||
texts: Iterable of strings to add to the vectorstore.
|
||||
metadatas: Optional list of metadatas associated with the texts.
|
||||
ids: Optional list of unique IDs.
|
||||
|
||||
Returns:
|
||||
List of ids from adding the texts into the vectorstore.
|
||||
"""
|
||||
# embeddings = [self.embedding_function(text) for text in texts]
|
||||
embeddings = self.embedding_function(texts)
|
||||
return self.__add(texts, embeddings, metadatas=metadatas, ids=ids)
|
||||
|
||||
def add_embeddings(
|
||||
self,
|
||||
text_embeddings: Iterable[Tuple[str, List[float]]],
|
||||
metadatas: Optional[List[dict]] = None,
|
||||
ids: Optional[List[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> List[str]:
|
||||
"""Run more texts through the embeddings and add to the vectorstore.
|
||||
|
||||
Args:
|
||||
text_embeddings: Iterable pairs of string and embedding to
|
||||
add to the vectorstore.
|
||||
metadatas: Optional list of metadatas associated with the texts.
|
||||
ids: Optional list of unique IDs.
|
||||
|
||||
Returns:
|
||||
List of ids from adding the texts into the vectorstore.
|
||||
"""
|
||||
# Embed and create the documents.
|
||||
texts, embeddings = zip(*text_embeddings)
|
||||
return self.__add(texts, embeddings, metadatas=metadatas, ids=ids)
|
||||
|
||||
def similarity_search_with_score_by_vector(
|
||||
self,
|
||||
embedding: List[float],
|
||||
k: int = 4,
|
||||
filter: Optional[Dict[str, Any]] = None,
|
||||
fetch_k: int = 20,
|
||||
**kwargs: Any,
|
||||
) -> List[Tuple[Document, float]]:
|
||||
"""Return docs most similar to query.
|
||||
|
||||
Args:
|
||||
embedding: Embedding vector to look up documents similar to.
|
||||
k: Number of Documents to return. Defaults to 4.
|
||||
filter (Optional[Dict[str, Any]]): Filter by metadata. Defaults to None.
|
||||
fetch_k: (Optional[int]) Number of Documents to fetch before filtering.
|
||||
Defaults to 20.
|
||||
**kwargs: kwargs to be passed to similarity search. Can include:
|
||||
score_threshold: Optional, a floating point value between 0 to 1 to
|
||||
filter the resulting set of retrieved docs
|
||||
|
||||
Returns:
|
||||
List of documents most similar to the query text and L2 distance
|
||||
in float for each. Lower score represents more similarity.
|
||||
"""
|
||||
faiss = dependable_faiss_import()
|
||||
vector = np.array([embedding], dtype=np.float32)
|
||||
if self._normalize_L2:
|
||||
faiss.normalize_L2(vector)
|
||||
scores, indices = self.index.search(vector, k if filter is None else fetch_k)
|
||||
# after L2 normalization the raw scores can exceed 1, so rescale those rows
|
||||
if self._normalize_L2:
|
||||
scores = np.array([row / np.linalg.norm(row) if np.max(row) > 1 else row for row in scores])
|
||||
docs = []
|
||||
for j, i in enumerate(indices[0]):
|
||||
if i == -1:
|
||||
# This happens when not enough docs are returned.
|
||||
continue
|
||||
_id = self.index_to_docstore_id[i]
|
||||
doc = self.docstore.search(_id)
|
||||
if not isinstance(doc, Document):
|
||||
raise ValueError(f"Could not find document for id {_id}, got {doc}")
|
||||
if filter is not None:
|
||||
filter = {
|
||||
key: [value] if not isinstance(value, list) else value
|
||||
for key, value in filter.items()
|
||||
}
|
||||
if all(doc.metadata.get(key) in value for key, value in filter.items()):
|
||||
docs.append((doc, scores[0][j]))
|
||||
else:
|
||||
docs.append((doc, scores[0][j]))
|
||||
|
||||
score_threshold = kwargs.get("score_threshold")
|
||||
if score_threshold is not None:
|
||||
cmp = (
|
||||
operator.ge
|
||||
if self.distance_strategy
|
||||
in (DistanceStrategy.MAX_INNER_PRODUCT, DistanceStrategy.JACCARD)
|
||||
else operator.le
|
||||
)
|
||||
docs = [
|
||||
(doc, similarity)
|
||||
for doc, similarity in docs
|
||||
if cmp(similarity, score_threshold)
|
||||
]
|
||||
return docs[:k]
|
||||
|
||||
def similarity_search_with_score(
|
||||
self,
|
||||
query: str,
|
||||
k: int = 4,
|
||||
filter: Optional[Dict[str, Any]] = None,
|
||||
fetch_k: int = 20,
|
||||
**kwargs: Any,
|
||||
) -> List[Tuple[Document, float]]:
|
||||
"""Return docs most similar to query.
|
||||
|
||||
Args:
|
||||
query: Text to look up documents similar to.
|
||||
k: Number of Documents to return. Defaults to 4.
|
||||
filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.
|
||||
fetch_k: (Optional[int]) Number of Documents to fetch before filtering.
|
||||
Defaults to 20.
|
||||
|
||||
Returns:
|
||||
List of documents most similar to the query text with
|
||||
L2 distance in float. Lower score represents more similarity.
|
||||
"""
|
||||
embedding = self.embedding_function(query)
|
||||
docs = self.similarity_search_with_score_by_vector(
|
||||
embedding,
|
||||
k,
|
||||
filter=filter,
|
||||
fetch_k=fetch_k,
|
||||
**kwargs,
|
||||
)
|
||||
return docs
|
||||
|
||||
def similarity_search_by_vector(
|
||||
self,
|
||||
embedding: List[float],
|
||||
k: int = 4,
|
||||
filter: Optional[Dict[str, Any]] = None,
|
||||
fetch_k: int = 20,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
"""Return docs most similar to embedding vector.
|
||||
|
||||
Args:
|
||||
embedding: Embedding to look up documents similar to.
|
||||
k: Number of Documents to return. Defaults to 4.
|
||||
filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.
|
||||
fetch_k: (Optional[int]) Number of Documents to fetch before filtering.
|
||||
Defaults to 20.
|
||||
|
||||
Returns:
|
||||
List of Documents most similar to the embedding.
|
||||
"""
|
||||
docs_and_scores = self.similarity_search_with_score_by_vector(
|
||||
embedding,
|
||||
k,
|
||||
filter=filter,
|
||||
fetch_k=fetch_k,
|
||||
**kwargs,
|
||||
)
|
||||
return [doc for doc, _ in docs_and_scores]
|
||||
|
||||
def similarity_search(
|
||||
self,
|
||||
query: str,
|
||||
k: int = 4,
|
||||
filter: Optional[Dict[str, Any]] = None,
|
||||
fetch_k: int = 20,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
"""Return docs most similar to query.
|
||||
|
||||
Args:
|
||||
query: Text to look up documents similar to.
|
||||
k: Number of Documents to return. Defaults to 4.
|
||||
filter: (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.
|
||||
fetch_k: (Optional[int]) Number of Documents to fetch before filtering.
|
||||
Defaults to 20.
|
||||
|
||||
Returns:
|
||||
List of Documents most similar to the query.
|
||||
"""
|
||||
docs_and_scores = self.similarity_search_with_score(
|
||||
query, k, filter=filter, fetch_k=fetch_k, **kwargs
|
||||
)
|
||||
return [doc for doc, _ in docs_and_scores]
|
||||
|
||||
def max_marginal_relevance_search_with_score_by_vector(
|
||||
self,
|
||||
embedding: List[float],
|
||||
*,
|
||||
k: int = 4,
|
||||
fetch_k: int = 20,
|
||||
lambda_mult: float = 0.5,
|
||||
filter: Optional[Dict[str, Any]] = None,
|
||||
) -> List[Tuple[Document, float]]:
|
||||
"""Return docs and their similarity scores selected using the maximal marginal
|
||||
relevance.
|
||||
|
||||
Maximal marginal relevance optimizes for similarity to query AND diversity
|
||||
among selected documents.
|
||||
|
||||
Args:
|
||||
embedding: Embedding to look up documents similar to.
|
||||
k: Number of Documents to return. Defaults to 4.
|
||||
fetch_k: Number of Documents to fetch before filtering to
|
||||
pass to MMR algorithm.
|
||||
lambda_mult: Number between 0 and 1 that determines the degree
|
||||
of diversity among the results with 0 corresponding
|
||||
to maximum diversity and 1 to minimum diversity.
|
||||
Defaults to 0.5.
|
||||
Returns:
|
||||
List of Documents and similarity scores selected by maximal marginal
|
||||
relevance and score for each.
|
||||
"""
|
||||
scores, indices = self.index.search(
|
||||
np.array([embedding], dtype=np.float32),
|
||||
fetch_k if filter is None else fetch_k * 2,
|
||||
)
|
||||
if filter is not None:
|
||||
filtered_indices = []
|
||||
for i in indices[0]:
|
||||
if i == -1:
|
||||
# This happens when not enough docs are returned.
|
||||
continue
|
||||
_id = self.index_to_docstore_id[i]
|
||||
doc = self.docstore.search(_id)
|
||||
if not isinstance(doc, Document):
|
||||
raise ValueError(f"Could not find document for id {_id}, got {doc}")
|
||||
if all(
|
||||
doc.metadata.get(key) in value
|
||||
if isinstance(value, list)
|
||||
else doc.metadata.get(key) == value
|
||||
for key, value in filter.items()
|
||||
):
|
||||
filtered_indices.append(i)
|
||||
indices = np.array([filtered_indices])
|
||||
# -1 happens when not enough docs are returned.
|
||||
embeddings = [self.index.reconstruct(int(i)) for i in indices[0] if i != -1]
|
||||
mmr_selected = maximal_marginal_relevance(
|
||||
np.array([embedding], dtype=np.float32),
|
||||
embeddings,
|
||||
k=k,
|
||||
lambda_mult=lambda_mult,
|
||||
)
|
||||
selected_indices = [indices[0][i] for i in mmr_selected]
|
||||
selected_scores = [scores[0][i] for i in mmr_selected]
|
||||
docs_and_scores = []
|
||||
for i, score in zip(selected_indices, selected_scores):
|
||||
if i == -1:
|
||||
# This happens when not enough docs are returned.
|
||||
continue
|
||||
_id = self.index_to_docstore_id[i]
|
||||
doc = self.docstore.search(_id)
|
||||
if not isinstance(doc, Document):
|
||||
raise ValueError(f"Could not find document for id {_id}, got {doc}")
|
||||
docs_and_scores.append((doc, score))
|
||||
return docs_and_scores
|
||||
|
||||
def max_marginal_relevance_search_by_vector(
|
||||
self,
|
||||
embedding: List[float],
|
||||
k: int = 4,
|
||||
fetch_k: int = 20,
|
||||
lambda_mult: float = 0.5,
|
||||
filter: Optional[Dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
"""Return docs selected using the maximal marginal relevance.
|
||||
|
||||
Maximal marginal relevance optimizes for similarity to query AND diversity
|
||||
among selected documents.
|
||||
|
||||
Args:
|
||||
embedding: Embedding to look up documents similar to.
|
||||
k: Number of Documents to return. Defaults to 4.
|
||||
fetch_k: Number of Documents to fetch before filtering to
|
||||
pass to MMR algorithm.
|
||||
lambda_mult: Number between 0 and 1 that determines the degree
|
||||
of diversity among the results with 0 corresponding
|
||||
to maximum diversity and 1 to minimum diversity.
|
||||
Defaults to 0.5.
|
||||
Returns:
|
||||
List of Documents selected by maximal marginal relevance.
|
||||
"""
|
||||
docs_and_scores = self.max_marginal_relevance_search_with_score_by_vector(
|
||||
embedding, k=k, fetch_k=fetch_k, lambda_mult=lambda_mult, filter=filter
|
||||
)
|
||||
return [doc for doc, _ in docs_and_scores]
|
||||
|
||||
    def max_marginal_relevance_search(
        self,
        query: str,
        k: int = 4,
        fetch_k: int = 20,
        lambda_mult: float = 0.5,
        filter: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> List[Document]:
        """Return docs selected using the maximal marginal relevance.

        Maximal marginal relevance optimizes for similarity to query AND diversity
        among selected documents.

        Args:
            query: Text to look up documents similar to.
            k: Number of Documents to return. Defaults to 4.
            fetch_k: Number of Documents to fetch before filtering (if needed) to
                pass to MMR algorithm.
            lambda_mult: Number between 0 and 1 that determines the degree
                of diversity among the results with 0 corresponding
                to maximum diversity and 1 to minimum diversity.
                Defaults to 0.5.
        Returns:
            List of Documents selected by maximal marginal relevance.
        """
        embedding = self.embedding_function(query)
        docs = self.max_marginal_relevance_search_by_vector(
            embedding,
            k=k,
            fetch_k=fetch_k,
            lambda_mult=lambda_mult,
            filter=filter,
            **kwargs,
        )
        return docs

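And the query-string variant, which simply embeds the text and delegates to the vector version; a short usage sketch under the same assumptions, with illustrative texts:

```
from langchain import FAISS
from langchain.embeddings import OpenAIEmbeddings

db = FAISS.from_texts(
    ["faiss stores vectors", "the docstore keeps documents", "mmr trades relevance for diversity"],
    OpenAIEmbeddings(),
)
# Fetch 3 candidates, keep the 2 that best balance relevance and diversity.
docs = db.max_marginal_relevance_search("how are vectors stored?", k=2, fetch_k=3)
```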
    def delete(self, ids: Optional[List[str]] = None, **kwargs: Any) -> Optional[bool]:
        """Delete by ID. These are the IDs in the vectorstore.

        Args:
            ids: List of ids to delete.

        Returns:
            Optional[bool]: True if deletion is successful,
            False otherwise, None if not implemented.
        """
        if ids is None:
            raise ValueError("No ids provided to delete.")
        missing_ids = set(ids).difference(self.index_to_docstore_id.values())
        if missing_ids:
            raise ValueError(
                f"Some specified ids do not exist in the current store. Ids not found: "
                f"{missing_ids}"
            )

        reversed_index = {id_: idx for idx, id_ in self.index_to_docstore_id.items()}
        index_to_delete = [reversed_index[id_] for id_ in ids]

        self.index.remove_ids(np.array(index_to_delete, dtype=np.int64))
        self.docstore.delete(ids)

        remaining_ids = [
            id_
            for i, id_ in sorted(self.index_to_docstore_id.items())
            if i not in index_to_delete
        ]
        self.index_to_docstore_id = {i: id_ for i, id_ in enumerate(remaining_ids)}

        return True

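A sketch of deleting by docstore id; passing explicit ids to from_texts keeps the ids predictable (the ids here are illustrative), and unknown ids raise ValueError rather than failing silently:

```
from langchain import FAISS
from langchain.embeddings import OpenAIEmbeddings

db = FAISS.from_texts(
    ["keep me", "drop me"], OpenAIEmbeddings(), ids=["doc-keep", "doc-drop"]
)
assert db.delete(["doc-drop"]) is True   # index, docstore and id map are all updated
```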
    def merge_from(self, target: FAISS) -> None:
        """Merge another FAISS object with the current one.

        Add the target FAISS to the current one.

        Args:
            target: FAISS object you wish to merge into the current one

        Returns:
            None.
        """
        if not isinstance(self.docstore, AddableMixin):
            raise ValueError("Cannot merge with this type of docstore")
        # Numerical index for target docs are incremental on existing ones
        starting_len = len(self.index_to_docstore_id)

        # Merge two IndexFlatL2
        self.index.merge_from(target.index)

        # Get id and docs from target FAISS object
        full_info = []
        for i, target_id in target.index_to_docstore_id.items():
            doc = target.docstore.search(target_id)
            if not isinstance(doc, Document):
                raise ValueError("Document should be returned")
            full_info.append((starting_len + i, target_id, doc))

        # Add information to docstore and index_to_docstore_id.
        self.docstore.add({_id: doc for _, _id, doc in full_info})
        index_to_id = {index: _id for index, _id, _ in full_info}
        self.index_to_docstore_id.update(index_to_id)

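A sketch of folding one store into another; both must have been built with the same embedding dimension, and the receiving docstore has to support adds (the in-memory docstore used by from_texts does):

```
from langchain import FAISS
from langchain.embeddings import OpenAIEmbeddings

emb = OpenAIEmbeddings()
db_a = FAISS.from_texts(["alpha", "beta"], emb)
db_b = FAISS.from_texts(["gamma"], emb)

db_a.merge_from(db_b)                     # db_b's vectors and documents now live in db_a
print(len(db_a.index_to_docstore_id))     # 3
```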
    @classmethod
    def __from(
        cls,
        texts: Iterable[str],
        embeddings: List[List[float]],
        embedding: Embeddings,
        metadatas: Optional[Iterable[dict]] = None,
        ids: Optional[List[str]] = None,
        normalize_L2: bool = False,
        distance_strategy: DistanceStrategy = DistanceStrategy.EUCLIDEAN_DISTANCE,
        **kwargs: Any,
    ) -> FAISS:
        faiss = dependable_faiss_import()
        if distance_strategy == DistanceStrategy.MAX_INNER_PRODUCT:
            index = faiss.IndexFlatIP(len(embeddings[0]))
        else:
            # Default to L2, currently other metric types not initialized.
            index = faiss.IndexFlatL2(len(embeddings[0]))
        vecstore = cls(
            embedding.embed_query,
            index,
            InMemoryDocstore({}),
            {},
            normalize_L2=normalize_L2,
            distance_strategy=distance_strategy,
            **kwargs,
        )
        vecstore.__add(texts, embeddings, metadatas=metadatas, ids=ids)
        return vecstore

    @classmethod
    def from_texts(
        cls,
        texts: List[str],
        embedding: Embeddings,
        metadatas: Optional[List[dict]] = None,
        ids: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> FAISS:
        """Construct FAISS wrapper from raw documents.

        This is a user friendly interface that:
            1. Embeds documents.
            2. Creates an in memory docstore
            3. Initializes the FAISS database

        This is intended to be a quick way to get started.

        Example:
            .. code-block:: python

                from langchain import FAISS
                from langchain.embeddings import OpenAIEmbeddings

                embeddings = OpenAIEmbeddings()
                faiss = FAISS.from_texts(texts, embeddings)
        """
        embeddings = embedding.embed_documents(texts)
        return cls.__from(
            texts,
            embeddings,
            embedding,
            metadatas=metadatas,
            ids=ids,
            **kwargs,
        )

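Beyond the docstring example, metadatas and ids ride along with each text and later drive the filter arguments shown earlier; a small sketch with illustrative fields:

```
from langchain import FAISS
from langchain.embeddings import OpenAIEmbeddings

db = FAISS.from_texts(
    ["def foo(): ...", "class Bar: ..."],
    OpenAIEmbeddings(),
    metadatas=[{"lang": "python", "kind": "function"}, {"lang": "python", "kind": "class"}],
    ids=["snippet-1", "snippet-2"],
)
docs = db.similarity_search("function definition", k=1, filter={"kind": "function"})
```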
    @classmethod
    def from_embeddings(
        cls,
        text_embeddings: Iterable[Tuple[str, List[float]]],
        embedding: Embeddings,
        metadatas: Optional[Iterable[dict]] = None,
        ids: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> FAISS:
        """Construct FAISS wrapper from raw documents.

        This is a user friendly interface that:
            1. Embeds documents.
            2. Creates an in memory docstore
            3. Initializes the FAISS database

        This is intended to be a quick way to get started.

        Example:
            .. code-block:: python

                from langchain import FAISS
                from langchain.embeddings import OpenAIEmbeddings

                embeddings = OpenAIEmbeddings()
                text_embeddings = embeddings.embed_documents(texts)
                text_embedding_pairs = zip(texts, text_embeddings)
                faiss = FAISS.from_embeddings(text_embedding_pairs, embeddings)
        """
        texts = [t[0] for t in text_embeddings]
        embeddings = [t[1] for t in text_embeddings]
        return cls.__from(
            texts,
            embeddings,
            embedding,
            metadatas=metadatas,
            ids=ids,
            **kwargs,
        )

    def save_local(self, folder_path: str, index_name: str = "index") -> None:
        """Save FAISS index, docstore, and index_to_docstore_id to disk.

        Args:
            folder_path: folder path to save index, docstore,
                and index_to_docstore_id to.
            index_name: for saving with a specific index file name
        """
        path = Path(folder_path)
        path.mkdir(exist_ok=True, parents=True)

        # save index separately since it is not picklable
        faiss = dependable_faiss_import()
        faiss.write_index(
            self.index, str(path / "{index_name}.faiss".format(index_name=index_name))
        )

        # save docstore and index_to_docstore_id
        with open(path / "{index_name}.pkl".format(index_name=index_name), "wb") as f:
            pickle.dump((self.docstore, self.index_to_docstore_id), f)

    @classmethod
    def load_local(
        cls,
        folder_path: str,
        embeddings: Embeddings,
        index_name: str = "index",
        **kwargs: Any,
    ) -> FAISS:
        """Load FAISS index, docstore, and index_to_docstore_id from disk.

        Args:
            folder_path: folder path to load index, docstore,
                and index_to_docstore_id from.
            embeddings: Embeddings to use when generating queries
            index_name: for saving with a specific index file name
        """
        path = Path(folder_path)
        # load index separately since it is not picklable
        faiss = dependable_faiss_import()
        index = faiss.read_index(
            str(path / "{index_name}.faiss".format(index_name=index_name))
        )

        # load docstore and index_to_docstore_id
        with open(path / "{index_name}.pkl".format(index_name=index_name), "rb") as f:
            docstore, index_to_docstore_id = pickle.load(f)
        return cls(
            embeddings.embed_query, index, docstore, index_to_docstore_id, **kwargs
        )

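A round-trip sketch for the two persistence methods above; the folder may be any writable path and the file names default to index.faiss / index.pkl:

```
from langchain import FAISS
from langchain.embeddings import OpenAIEmbeddings

emb = OpenAIEmbeddings()
db = FAISS.from_texts(["persist me"], emb)

db.save_local("faiss_store")                      # writes faiss_store/index.faiss and index.pkl
restored = FAISS.load_local("faiss_store", emb)   # embeddings are only needed to embed new queries
print(restored.similarity_search("persist", k=1))
```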
    def serialize_to_bytes(self) -> bytes:
        """Serialize FAISS index, docstore, and index_to_docstore_id to bytes."""
        return pickle.dumps((self.index, self.docstore, self.index_to_docstore_id))

    @classmethod
    def deserialize_from_bytes(
        cls,
        serialized: bytes,
        embeddings: Embeddings,
        **kwargs: Any,
    ) -> FAISS:
        """Deserialize FAISS index, docstore, and index_to_docstore_id from bytes."""
        index, docstore, index_to_docstore_id = pickle.loads(serialized)
        return cls(
            embeddings.embed_query, index, docstore, index_to_docstore_id, **kwargs
        )

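These byte-level twins of save_local/load_local are handy when the index should live in a cache or database blob rather than on disk; note that deserialization unpickles, so only feed it blobs you trust. A minimal sketch:

```
from langchain import FAISS
from langchain.embeddings import OpenAIEmbeddings

emb = OpenAIEmbeddings()
db = FAISS.from_texts(["cache me"], emb)

blob = db.serialize_to_bytes()                      # pickle of (index, docstore, id map)
restored = FAISS.deserialize_from_bytes(blob, emb)  # unpickles; trust the source of the blob
```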
    def _select_relevance_score_fn(self) -> Callable[[float], float]:
        """
        The 'correct' relevance function
        may differ depending on a few things, including:
        - the distance / similarity metric used by the VectorStore
        - the scale of your embeddings (OpenAI's are unit normed. Many others are not!)
        - embedding dimensionality
        - etc.
        """
        if self.override_relevance_score_fn is not None:
            return self.override_relevance_score_fn

        # Default strategy is to rely on distance strategy provided in
        # vectorstore constructor
        if self.distance_strategy == DistanceStrategy.MAX_INNER_PRODUCT:
            return self._max_inner_product_relevance_score_fn
        elif self.distance_strategy == DistanceStrategy.EUCLIDEAN_DISTANCE:
            # Default behavior is to use euclidean distance relevancy
            return self._euclidean_relevance_score_fn
        else:
            raise ValueError(
                "Unknown distance strategy, must be cosine, max_inner_product,"
                " or euclidean"
            )

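The selection above can be bypassed by supplying a custom scoring function at construction time (it ends up in override_relevance_score_fn); a sketch that maps raw L2 distances into (0, 1], assuming the constructor accepts relevance_score_fn as upstream langchain's FAISS does and that from_texts forwards extra kwargs to it, as the __from helper above suggests:

```
from langchain import FAISS
from langchain.embeddings import OpenAIEmbeddings

db = FAISS.from_texts(
    ["custom scoring"],
    OpenAIEmbeddings(),
    relevance_score_fn=lambda distance: 1.0 / (1.0 + distance),  # illustrative, not a recommended scale
)
print(db.similarity_search_with_relevance_scores("custom", k=1))
```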
    def _similarity_search_with_relevance_scores(
        self,
        query: str,
        k: int = 4,
        filter: Optional[Dict[str, Any]] = None,
        fetch_k: int = 20,
        **kwargs: Any,
    ) -> List[Tuple[Document, float]]:
        """Return docs and their similarity scores on a scale from 0 to 1."""
        # Pop score threshold so that only relevancy scores, not raw scores, are
        # filtered.
        relevance_score_fn = self._select_relevance_score_fn()
        if relevance_score_fn is None:
            raise ValueError(
                "normalize_score_fn must be provided to"
                " FAISS constructor to normalize scores"
            )
        docs_and_scores = self.similarity_search_with_score(
            query,
            k=k,
            filter=filter,
            fetch_k=fetch_k,
            **kwargs,
        )
        docs_and_rel_scores = [
            (doc, relevance_score_fn(score)) for doc, score in docs_and_scores
        ]
        return docs_and_rel_scores
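This helper backs the public similarity_search_with_relevance_scores wrapper inherited from the base VectorStore; its normalized 0-1 scores are what make thresholding across stores meaningful. A usage sketch under the same import assumptions as the earlier examples:

```
from langchain import FAISS
from langchain.embeddings import OpenAIEmbeddings

db = FAISS.from_texts(["thresholds need normalized scores"], OpenAIEmbeddings())
for doc, score in db.similarity_search_with_relevance_scores("normalized scores", k=2):
    print(round(score, 3), doc.page_content)
```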
@ -1,49 +0,0 @@
# encoding: utf-8
'''
@author: 温进
@file: get_embedding.py
@time: 2023/11/22 11:30 AM
@desc:
'''
from loguru import logger

# from configs.model_config import EMBEDDING_MODEL
from coagent.embeddings.openai_embedding import OpenAIEmbedding
from coagent.embeddings.huggingface_embedding import HFEmbedding
from coagent.llm_models.llm_config import EmbedConfig

def get_embedding(
    engine: str,
    text_list: list,
    model_path: str = "text2vec-base-chinese",
    embedding_device: str = "cpu",
    embed_config: EmbedConfig = None,
):
    '''
    get embedding
    @param engine: openai / model
    @param text_list: texts to embed
    @return: dict mapping each input text to its embedding vector
    '''
    emb_res = {}
    if embed_config and embed_config.langchain_embeddings:
        emb_res = embed_config.langchain_embeddings.embed_documents(text_list)
        emb_res = {
            text_list[idx]: emb_res[idx] for idx in range(len(text_list))
        }
    elif engine == 'openai':
        oae = OpenAIEmbedding()
        emb_res = oae.get_emb(text_list)
    elif engine == 'model':
        hfe = HFEmbedding(model_path, embedding_device)
        emb_res = hfe.get_emb(text_list)

    return emb_res


if __name__ == '__main__':
    engine = 'model'
text_list = ['这段代码是一个OkHttp拦截器,用于在请求头中添加授权令牌。它继承自`com.theokanning.openai.client.AuthenticationInterceptor`类,并且被标记为`@Deprecated`,意味着它已经过时了。\n\n这个拦截器的作用是在每个请求的头部添加一个名为"Authorization"的字段,值为传入的授权令牌。这样,当请求被发送到服务器时,服务器可以使用这个令牌来验证请求的合法性。\n\n这段代码的构造函数接受一个令牌作为参数,并将其传递给父类的构造函数。这个令牌应该是一个有效的授权令牌,用于访问受保护的资源。', '这段代码定义了一个接口`OpenAiApi`,并使用`@Deprecated`注解将其标记为已过时。它还扩展了`com.theokanning.openai.client.OpenAiApi`接口。\n\n`@Deprecated`注解表示该接口已经过时,不推荐使用。开发者应该使用`com.theokanning.openai.client.OpenAiApi`接口代替。\n\n注释中提到这个接口只是为了保持向后兼容性。这意味着它可能是为了与旧版本的代码兼容而保留的,但不推荐在新代码中使用。', '这段代码是一个OkHttp的拦截器,用于在请求头中添加授权令牌(authorization token)。\n\n在这个拦截器中,首先获取到传入的授权令牌(token),然后在每个请求的构建过程中,使用`newBuilder()`方法创建一个新的请求构建器,并在该构建器中添加一个名为"Authorization"的请求头,值为"Bearer " + token。最后,使用该构建器构建一个新的请求,并通过`chain.proceed(request)`方法继续处理该请求。\n\n这样,当使用OkHttp发送请求时,该拦截器会自动在请求头中添加授权令牌,以实现身份验证的功能。', '这段代码是一个Java接口,用于定义与OpenAI API进行通信的方法。它包含了各种不同类型的请求和响应方法,用于与OpenAI API的不同端点进行交互。\n\n接口中的方法包括:\n- `listModels()`:获取可用的模型列表。\n- `getModel(String modelId)`:获取指定模型的详细信息。\n- `createCompletion(CompletionRequest request)`:创建文本生成的请求。\n- `createChatCompletion(ChatCompletionRequest request)`:创建聊天式文本生成的请求。\n- `createEdit(EditRequest request)`:创建文本编辑的请求。\n- `createEmbeddings(EmbeddingRequest request)`:创建文本嵌入的请求。\n- `listFiles()`:获取已上传文件的列表。\n- `uploadFile(RequestBody purpose, MultipartBody.Part file)`:上传文件。\n- `deleteFile(String fileId)`:删除文件。\n- `retrieveFile(String fileId)`:获取文件的详细信息。\n- `retrieveFileContent(String fileId)`:获取文件的内容。\n- `createFineTuningJob(FineTuningJobRequest request)`:创建Fine-Tuning任务。\n- `listFineTuningJobs()`:获取Fine-Tuning任务的列表。\n- `retrieveFineTuningJob(String fineTuningJobId)`:获取指定Fine-Tuning任务的详细信息。\n- `cancelFineTuningJob(String fineTuningJobId)`:取消Fine-Tuning任务。\n- `listFineTuningJobEvents(String fineTuningJobId)`:获取Fine-Tuning任务的事件列表。\n- `createFineTuneCompletion(CompletionRequest request)`:创建Fine-Tuning模型的文本生成请求。\n- `createImage(CreateImageRequest request)`:创建图像生成的请求。\n- `createImageEdit(RequestBody requestBody)`:创建图像编辑的请求。\n- `createImageVariation(RequestBody requestBody)`:创建图像变体的请求。\n- `createTranscription(RequestBody requestBody)`:创建音频转录的请求。\n- `createTranslation(RequestBody requestBody)`:创建音频翻译的请求。\n- `createModeration(ModerationRequest request)`:创建内容审核的请求。\n- `getEngines()`:获取可用的引擎列表。\n- `getEngine(String engineId)`:获取指定引擎的详细信息。\n- `subscription()`:获取账户订阅信息。\n- `billingUsage(LocalDate starDate, LocalDate endDate)`:获取账户消费信息。\n\n这些方法使用不同的HTTP请求类型(GET、POST、DELETE)和路径来与OpenAI API进行交互,并返回相应的响应数据。']

    res = get_embedding(engine, text_list)
    logger.debug(res)