codefuse-chatbot/coagent/connector/prompt_manager/prompt_manager.py

353 lines
16 KiB
Python
Raw Normal View History

import random
from textwrap import dedent
import copy
from loguru import logger
from langchain.agents.tools import Tool
from coagent.connector.schema import Memory, Message
from coagent.connector.utils import extract_section, parse_section
class PromptManager:
def __init__(self, role_prompt="", prompt_config=None, monitored_agents=[], monitored_fields=[]):
self.role_prompt = role_prompt
self.monitored_agents = monitored_agents
self.monitored_fields = monitored_fields
self.field_handlers = {}
self.context_handlers = {}
self.field_order = [] # 用于普通字段的顺序
self.context_order = [] # 单独维护上下文字段的顺序
self.field_descriptions = {}
self.omit_if_empty_flags = {}
self.context_title = "### Context Data\n\n"
self.prompt_config = prompt_config
if self.prompt_config:
self.register_fields_from_config()
def register_field(self, field_name, function=None, title=None, description=None, is_context=True, omit_if_empty=True):
"""
注册一个新的字段及其处理函数
Args:
field_name (str): 字段名称
function (callable): 处理字段数据的函数
title (str, optional): 字段的自定义标题可选
description (str, optional): 字段的描述可选可以是几句话
is_context (bool, optional): 指示该字段是否为上下文字段
omit_if_empty (bool, optional): 如果数据为空是否省略该字段
"""
if not function:
function = self.handle_custom_data
# Register the handler function based on context flag
if is_context:
self.context_handlers[field_name] = function
else:
self.field_handlers[field_name] = function
# Store the custom title if provided and adjust the title prefix based on context
title_prefix = "####" if is_context else "###"
if title is not None:
self.field_descriptions[field_name] = f"{title_prefix} {title}\n\n"
elif description is not None:
# If title is not provided but description is, use description as title
self.field_descriptions[field_name] = f"{title_prefix} {field_name.replace('_', ' ').title()}\n\n{description}\n\n"
else:
# If neither title nor description is provided, use the field name as title
self.field_descriptions[field_name] = f"{title_prefix} {field_name.replace('_', ' ').title()}\n\n"
# Store the omit_if_empty flag for this field
self.omit_if_empty_flags[field_name] = omit_if_empty
if is_context and field_name != 'context_placeholder':
self.context_handlers[field_name] = function
self.context_order.append(field_name)
else:
self.field_handlers[field_name] = function
self.field_order.append(field_name)
def generate_full_prompt(self, **kwargs):
full_prompt = []
context_prompts = [] # 用于收集上下文内容
is_pre_print = kwargs.get("is_pre_print", False) # 用于强制打印所有prompt 字段信息,不管有没有空
# 先处理上下文字段
for field_name in self.context_order:
handler = self.context_handlers[field_name]
processed_prompt = handler(**kwargs)
# Check if the field should be omitted when empty
if self.omit_if_empty_flags.get(field_name, False) and not processed_prompt and not is_pre_print:
continue # Skip this field
title_or_description = self.field_descriptions.get(field_name, f"#### {field_name.replace('_', ' ').title()}\n\n")
context_prompts.append(title_or_description + processed_prompt + '\n\n')
# 处理普通字段,同时查找 context_placeholder 的位置
for field_name in self.field_order:
if field_name == 'context_placeholder':
# 在 context_placeholder 的位置插入上下文数据
full_prompt.append(self.context_title) # 添加上下文部分的大标题
full_prompt.extend(context_prompts) # 添加收集的上下文内容
else:
handler = self.field_handlers[field_name]
processed_prompt = handler(**kwargs)
# Check if the field should be omitted when empty
if self.omit_if_empty_flags.get(field_name, False) and not processed_prompt and not is_pre_print:
continue # Skip this field
title_or_description = self.field_descriptions.get(field_name, f"### {field_name.replace('_', ' ').title()}\n\n")
full_prompt.append(title_or_description + processed_prompt + '\n\n')
# 返回完整的提示,移除尾部的空行
return ''.join(full_prompt).rstrip('\n')
def pre_print(self, **kwargs):
kwargs.update({"is_pre_print": True})
prompt = self.generate_full_prompt(**kwargs)
input_keys = parse_section(self.role_prompt, 'Response Output Format')
llm_predict = "\n".join([f"**{k}:**" for k in input_keys])
return prompt + "\n\n" + "#"*19 + "\n<<<<LLM PREDICT>>>>\n" + "#"*19 + f"\n\n{llm_predict}\n"
def handle_custom_data(self, **kwargs):
return ""
def handle_tool_data(self, **kwargs):
if 'previous_agent_message' not in kwargs:
return ""
previous_agent_message = kwargs.get('previous_agent_message')
tools: list[Tool] = previous_agent_message.tools
if not tools:
return ""
tool_strings = []
for tool in tools:
args_str = f'args: {str(tool.args)}' if tool.args_schema else ""
tool_strings.append(f"{tool.name}: {tool.description}, {args_str}")
formatted_tools = "\n".join(tool_strings)
tool_names = ", ".join([tool.name for tool in tools])
tool_prompt = dedent(f"""
Below is a list of tools that are available for your use:
{formatted_tools}
valid "tool_name" value is:
{tool_names}
""")
return tool_prompt
def handle_agent_data(self, **kwargs):
if 'agents' not in kwargs:
return ""
agents = kwargs.get('agents')
random.shuffle(agents)
agent_names = ", ".join([f'{agent.role.role_name}' for agent in agents])
agent_descs = []
for agent in agents:
role_desc = agent.role.role_prompt.split("####")[1]
while "\n\n" in role_desc:
role_desc = role_desc.replace("\n\n", "\n")
role_desc = role_desc.replace("\n", ",")
agent_descs.append(f'"role name: {agent.role.role_name}\nrole description: {role_desc}"')
agents = "\n".join(agent_descs)
agent_prompt = f'''
Please ensure your selection is one of the listed roles. Available roles for selection:
{agents}
Please ensure select the Role from agent names, such as {agent_names}'''
return dedent(agent_prompt)
def handle_doc_info(self, **kwargs) -> str:
if 'previous_agent_message' not in kwargs:
return ""
previous_agent_message: Message = kwargs.get('previous_agent_message')
db_docs = previous_agent_message.db_docs
search_docs = previous_agent_message.search_docs
code_cocs = previous_agent_message.code_docs
doc_infos = "\n".join([doc.get_snippet() for doc in db_docs] + [doc.get_snippet() for doc in search_docs] +
[doc.get_code() for doc in code_cocs])
return doc_infos
def handle_session_records(self, **kwargs) -> str:
memory_pool: Memory = kwargs.get('memory_pool', Memory(messages=[]))
memory_pool = self.select_memory_by_agent_name(memory_pool)
memory_pool = self.select_memory_by_parsed_key(memory_pool)
return memory_pool.to_str_messages(content_key="parsed_output_list", with_tag=True)
def handle_current_plan(self, **kwargs) -> str:
if 'previous_agent_message' not in kwargs:
return ""
previous_agent_message = kwargs['previous_agent_message']
return previous_agent_message.parsed_output.get("CURRENT_STEP", "")
def handle_agent_profile(self, **kwargs) -> str:
return extract_section(self.role_prompt, 'Agent Profile')
def handle_output_format(self, **kwargs) -> str:
return extract_section(self.role_prompt, 'Response Output Format')
def handle_response(self, **kwargs) -> str:
if 'react_memory' not in kwargs:
return ""
react_memory = kwargs.get('react_memory', Memory(messages=[]))
if react_memory is None:
return ""
return "\n".join(["\n".join([f"**{k}:**\n{v}" for k,v in _dict.items()]) for _dict in react_memory.get_parserd_output()])
def handle_task_records(self, **kwargs) -> str:
if 'task_memory' not in kwargs:
return ""
task_memory: Memory = kwargs.get('task_memory', Memory(messages=[]))
if task_memory is None:
return ""
return "\n".join(["\n".join([f"**{k}:**\n{v}" for k,v in _dict.items() if k not in ["CURRENT_STEP"]]) for _dict in task_memory.get_parserd_output()])
def handle_previous_message(self, message: Message) -> str:
pass
def handle_message_by_role_name(self, message: Message) -> str:
pass
def handle_message_by_role_type(self, message: Message) -> str:
pass
def handle_current_agent_react_message(self, message: Message) -> str:
pass
def extract_codedoc_info_for_prompt(self, message: Message) -> str:
code_docs = message.code_docs
doc_infos = "\n".join([doc.get_code() for doc in code_docs])
return doc_infos
def select_memory_by_parsed_key(self, memory: Memory) -> Memory:
return Memory(
messages=[self.select_message_by_parsed_key(message) for message in memory.messages
if self.select_message_by_parsed_key(message) is not None]
)
def select_memory_by_agent_name(self, memory: Memory) -> Memory:
return Memory(
messages=[self.select_message_by_agent_name(message) for message in memory.messages
if self.select_message_by_agent_name(message) is not None]
)
def select_message_by_agent_name(self, message: Message) -> Message:
# assume we focus all agents
if self.monitored_agents == []:
return message
return None if message is None or message.role_name not in self.monitored_agents else self.select_message_by_parsed_key(message)
def select_message_by_parsed_key(self, message: Message) -> Message:
# assume we focus all key contents
if message is None:
return message
if self.monitored_fields == []:
return message
message_c = copy.deepcopy(message)
message_c.parsed_output = {k: v for k,v in message_c.parsed_output.items() if k in self.monitored_fields}
message_c.parsed_output_list = [{k: v for k,v in parsed_output.items() if k in self.monitored_fields} for parsed_output in message_c.parsed_output_list]
return message_c
def get_memory(self, content_key="role_content"):
return self.memory.to_tuple_messages(content_key="step_content")
def get_memory_str(self, content_key="role_content"):
return "\n".join([": ".join(i) for i in self.memory.to_tuple_messages(content_key="step_content")])
def register_fields_from_config(self):
for prompt_field in self.prompt_config:
function_name = prompt_field.function_name
# 检查function_name是否是self的一个方法
if function_name and hasattr(self, function_name):
function = getattr(self, function_name)
else:
function = self.handle_custom_data
self.register_field(prompt_field.field_name,
function=function,
title=prompt_field.title,
description=prompt_field.description,
is_context=prompt_field.is_context,
omit_if_empty=prompt_field.omit_if_empty)
def register_standard_fields(self):
self.register_field('agent_profile', function=self.handle_agent_profile, is_context=False)
self.register_field('tool_information', function=self.handle_tool_data, is_context=False)
self.register_field('context_placeholder', is_context=True) # 用于标记上下文数据部分的位置
self.register_field('reference_documents', function=self.handle_doc_info, is_context=True)
self.register_field('session_records', function=self.handle_session_records, is_context=True)
self.register_field('output_format', function=self.handle_output_format, title='Response Output Format', is_context=False)
self.register_field('response', function=self.handle_response, is_context=False, omit_if_empty=False)
def register_executor_fields(self):
self.register_field('agent_profile', function=self.handle_agent_profile, is_context=False)
self.register_field('tool_information', function=self.handle_tool_data, is_context=False)
self.register_field('context_placeholder', is_context=True) # 用于标记上下文数据部分的位置
self.register_field('reference_documents', function=self.handle_doc_info, is_context=True)
self.register_field('session_records', function=self.handle_session_records, is_context=True)
self.register_field('current_plan', function=self.handle_current_plan, is_context=True)
self.register_field('output_format', function=self.handle_output_format, title='Response Output Format', is_context=False)
self.register_field('response', function=self.handle_response, is_context=False, omit_if_empty=False)
def register_fields_from_dict(self, fields_dict):
# 使用字典注册字段的函数
for field_name, field_config in fields_dict.items():
function_name = field_config.get('function', None)
title = field_config.get('title', None)
description = field_config.get('description', None)
is_context = field_config.get('is_context', True)
omit_if_empty = field_config.get('omit_if_empty', True)
# 检查function_name是否是self的一个方法
if function_name and hasattr(self, function_name):
function = getattr(self, function_name)
else:
function = self.handle_custom_data
# 调用已存在的register_field方法注册字段
self.register_field(field_name, function=function, title=title, description=description, is_context=is_context, omit_if_empty=omit_if_empty)
def main():
manager = PromptManager()
manager.register_standard_fields()
manager.register_field('agents_work_progress', title=f"Agents' Work Progress", is_context=True)
# 创建数据字典
data_dict = {
"agent_profile": "这是代理配置文件...",
# "tool_list": "这是工具列表...",
"reference_documents": "这是参考文档...",
"session_records": "这是会话记录...",
"agents_work_progress": "这是代理工作进展...",
"output_format": "这是预期的输出格式...",
# "response": "这是生成或继续回应的指令...",
"response": "",
"test": 'xxxxx'
}
# 组合完整的提示
full_prompt = manager.generate_full_prompt(data_dict)
print(full_prompt)
if __name__ == "__main__":
main()