from loguru import logger
from pydantic import BaseModel, Field, root_validator

from coagent.utils.common_utils import getCurrentDatetime
from .general_schema import *


class Message(BaseModel):
    """Schema of a single message record.

    The fields are grouped as: identity and query text, raw LLM output,
    values parsed from the LLM output, tool/code execution results,
    prompt support material, phase input settings, and user-supplied
    keyword arguments.

    Only ``role_name`` and ``role_type`` are required; every other field
    has a default.
    """

    chat_index: str = None
    role_name: str
    role_type: str
    role_prompt: str = None
    input_query: str = None
    origin_query: str = None
    # A plain ``= getCurrentDatetime()`` default would be evaluated once,
    # when the class body runs, and then shared by every instance.  A
    # default factory calls the helper each time a Message is created
    # without an explicit ``datetime``.
    datetime: str = Field(default_factory=getCurrentDatetime)

    # llm output
    role_content: str = None
    step_content: str = None

    # llm parsed information
    plans: List[str] = None
    code_content: str = None
    code_filename: str = None
    tool_params: str = None
    tool_name: str = None
    parsed_output: dict = {}
    spec_parsed_output: dict = {}
    parsed_output_list: List[Dict] = []

    # llm / tool / code execution information
    action_status: str = "default"
    agent_index: int = None
    code_answer: str = None
    tool_answer: str = None
    observation: str = None
    figures: Dict[str, str] = {}

    # prompt support information
    tools: List[BaseTool] = []
    task: Task = None
    db_docs: List['Doc'] = []
    code_docs: List['CodeDoc'] = []
    search_docs: List['Doc'] = []
    agents: List = []

    # phase input
    phase_name: str = None
    chain_name: str = None
    do_search: bool = False
    doc_engine_name: str = None
    code_engine_name: str = None
    cb_search_type: str = None
    search_engine_name: str = None
    top_k: int = 3
    score_threshold: float = 1.0
    do_doc_retrieval: bool = False
    do_code_retrieval: bool = False
    do_tool_retrieval: bool = False
    history_node_list: List[str] = []

    # user's custom kwargs for init or end action
    customed_kargs: dict = {}

    @root_validator(pre=True)
    def check_card_number_omitted(cls, values):
        """Back-fill ``input_query`` and ``role_content`` before validation.

        NOTE(review): the method name does not describe what it does; it
        is kept unchanged so existing references keep working.

        Args:
            values: Raw input mapping passed to the model.

        Returns:
            The same mapping, with ``input_query`` set to
            ``origin_query or role_content`` when it was missing/None, and
            ``role_content`` set to ``origin_query`` when it was
            missing/None.
        """
        input_query = values.get("input_query")
        origin_query = values.get("origin_query")
        role_content = values.get("role_content")
        if input_query is None:
            values["input_query"] = origin_query or role_content
        if role_content is None:
            # Uses the role_content read above, i.e. before any back-fill.
            values["role_content"] = origin_query
        return values

    # pydantic>=2.0
    # @model_validator(mode='after')
    # def check_passwords_match(self) -> 'Message':
    #     if self.input_query is None:
    #         self.input_query = self.origin_query or self.role_content
    #     if self.role_content is None:
    #         self.role_content = self.origin_query
    #     return self

    def to_tuple_message(self, return_all: bool = True, content_key="role_content"):
        """Return the message content, optionally paired with the role name.

        Args:
            return_all: When true, return ``(role_name, content)``.
                When false, return only the content.
            content_key: Which content to render; forwarded to
                :meth:`to_str_content`.

        Returns:
            A 2-tuple ``(role_name, content)`` if ``return_all`` is true,
            otherwise the bare content value (NOT a 1-tuple).
        """
        role_content = self.to_str_content(False, content_key)
        if return_all:
            return (self.role_name, role_content)
        return role_content

    def to_dict_message(self):
        """Return ``vars(self)`` for this message.

        This is the object's own attribute dictionary, not a copy, so
        mutating the result mutates the message.
        """
        return vars(self)

    def to_str_content(self, return_all: bool = True, content_key="role_content", with_tag=False):
        """Render one view of the message content as a string.

        Args:
            return_all: Unused by this method; kept for interface
                compatibility.
            content_key: Selects what to render:
                ``"step_content"`` -> ``step_content``, falling back to
                ``role_content`` then ``input_query``;
                ``"parsed_output"`` -> one ``**key:** value`` line per
                item of ``parsed_output``;
                ``"parsed_output_list"`` -> the same line format over
                every dict in ``parsed_output_list``;
                anything else (including ``"role_content"``) ->
                ``role_content``, falling back to ``input_query``.
            with_tag: When true, wrap the content in
                ``<{role_type}-{role_name}-message>`` open/close tags,
                each on its own line.

        Returns:
            The selected content, tagged if requested.
        """
        if content_key == "step_content":
            role_content = self.step_content or self.role_content or self.input_query
        elif content_key == "parsed_output":
            role_content = "\n".join(
                [f"**{k}:** {v}" for k, v in self.parsed_output.items()]
            )
        elif content_key == "parsed_output_list":
            role_content = "\n".join(
                [
                    f"**{k}:** {v}"
                    for po in self.parsed_output_list
                    for k, v in po.items()
                ]
            )
        else:
            # "role_content" and any unrecognised key share this fallback.
            role_content = self.role_content or self.input_query

        if with_tag:
            start_tag = f"<{self.role_type}-{self.role_name}-message>"
            end_tag = f"</{self.role_type}-{self.role_name}-message>"
            return f"{start_tag}\n{role_content}\n{end_tag}"
        return role_content

    def __str__(self) -> str:
        """Return one ``name: value`` line per entry of ``vars(self)``."""
        return "\n".join([": ".join([k, str(v)]) for k, v in vars(self).items()])