update feature: access more LLM Models by fastchat from langchain-chatchat
This commit is contained in:
parent
ac0000890e
commit
b35e849e9b
|
@ -14,7 +14,7 @@ logging.basicConfig(format=LOG_FORMAT)
|
|||
# os.environ["OPENAI_PROXY"] = "socks5h://127.0.0.1:13659"
|
||||
os.environ["API_BASE_URL"] = "http://openai.com/v1/chat/completions"
|
||||
os.environ["OPENAI_API_KEY"] = ""
|
||||
os.environ["DUCKDUCKGO_PROXY"] = "socks5://127.0.0.1:13659"
|
||||
os.environ["DUCKDUCKGO_PROXY"] = os.environ.get("DUCKDUCKGO_PROXY") or "socks5://127.0.0.1:13659"
|
||||
os.environ["BAIDU_OCR_API_KEY"] = ""
|
||||
os.environ["BAIDU_OCR_SECRET_KEY"] = ""
|
||||
|
||||
|
@ -60,46 +60,26 @@ ONLINE_LLM_MODEL = {
|
|||
"api_key": "",
|
||||
"openai_proxy": "",
|
||||
},
|
||||
"example": {
|
||||
"version": "gpt-3.5", # 采用openai接口做示例
|
||||
"api_base_url": "https://api.openai.com/v1",
|
||||
"api_key": "",
|
||||
"provider": "ExampleWorker",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
# 建议使用chat模型,不要使用base,无法获取正确输出
|
||||
llm_model_dict = {
|
||||
"chatglm-6b": {
|
||||
"local_model_path": "THUDM/chatglm-6b",
|
||||
"api_base_url": "http://localhost:8888/v1", # "name"修改为fastchat服务中的"api_base_url"
|
||||
"api_key": "EMPTY"
|
||||
},
|
||||
|
||||
"chatglm-6b-int4": {
|
||||
"local_model_path": "THUDM/chatglm2-6b-int4/",
|
||||
"api_base_url": "http://localhost:8888/v1", # "name"修改为fastchat服务中的"api_base_url"
|
||||
"api_key": "EMPTY"
|
||||
},
|
||||
|
||||
"chatglm2-6b": {
|
||||
"local_model_path": "THUDM/chatglm2-6b",
|
||||
"api_base_url": "http://localhost:8888/v1", # URL需要与运行fastchat服务端的server_config.FSCHAT_OPENAI_API一致
|
||||
"api_key": "EMPTY"
|
||||
},
|
||||
|
||||
"chatglm2-6b-int4": {
|
||||
"local_model_path": "THUDM/chatglm2-6b-int4",
|
||||
"api_base_url": "http://localhost:8888/v1", # URL需要与运行fastchat服务端的server_config.FSCHAT_OPENAI_API一致
|
||||
"api_key": "EMPTY"
|
||||
},
|
||||
|
||||
"chatglm2-6b-32k": {
|
||||
"local_model_path": "THUDM/chatglm2-6b-32k", # "THUDM/chatglm2-6b-32k",
|
||||
"api_base_url": "http://localhost:8888/v1", # "URL需要与运行fastchat服务端的server_config.FSCHAT_OPENAI_API一致
|
||||
"api_key": "EMPTY"
|
||||
},
|
||||
|
||||
"vicuna-13b-hf": {
|
||||
"local_model_path": "",
|
||||
"api_base_url": "http://localhost:8888/v1", # "name"修改为fastchat服务中的"api_base_url"
|
||||
"api_key": "EMPTY"
|
||||
},
|
||||
|
||||
# 以下模型经过测试可接入,配置仿照上述即可
|
||||
# 'codellama_34b', 'Baichuan2-13B-Base', 'Baichuan2-13B-Chat', 'baichuan2-7b-base', 'baichuan2-7b-chat',
|
||||
# 'internlm-7b-base', 'internlm-chat-7b', 'chatglm2-6b', 'qwen-14b-base', 'qwen-14b-chat', 'qwen-1-8B-Chat',
|
||||
# 'Qwen-7B', 'Qwen-7B-Chat', 'qwen-7b-base-v1.1', 'qwen-7b-chat-v1.1', 'chatglm3-6b', 'chatglm3-6b-32k',
|
||||
# 'chatglm3-6b-base', 'Qwen-72B-Chat-Int4'
|
||||
# 调用chatgpt时如果报出: urllib3.exceptions.MaxRetryError: HTTPSConnectionPool(host='api.openai.com', port=443):
|
||||
# Max retries exceeded with url: /v1/chat/completions
|
||||
# 则需要将urllib3版本修改为1.25.11
|
||||
|
@ -115,8 +95,23 @@ llm_model_dict = {
|
|||
"api_base_url": os.environ.get("API_BASE_URL"),
|
||||
"api_key": os.environ.get("OPENAI_API_KEY")
|
||||
},
|
||||
"gpt-3.5-turbo-16k": {
|
||||
"local_model_path": "gpt-3.5-turbo-16k",
|
||||
"api_base_url": os.environ.get("API_BASE_URL"),
|
||||
"api_key": os.environ.get("OPENAI_API_KEY")
|
||||
},
|
||||
}
|
||||
|
||||
# 建议使用chat模型,不要使用base,无法获取正确输出
|
||||
VLLM_MODEL_DICT = {
|
||||
'chatglm2-6b': "THUDM/chatglm-6b",
|
||||
}
|
||||
# 以下模型经过测试可接入,配置仿照上述即可
|
||||
# 'codellama_34b', 'Baichuan2-13B-Base', 'Baichuan2-13B-Chat', 'baichuan2-7b-base', 'baichuan2-7b-chat',
|
||||
# 'internlm-7b-base', 'internlm-chat-7b', 'chatglm2-6b', 'qwen-14b-base', 'qwen-14b-chat', 'qwen-1-8B-Chat',
|
||||
# 'Qwen-7B', 'Qwen-7B-Chat', 'qwen-7b-base-v1.1', 'qwen-7b-chat-v1.1', 'chatglm3-6b', 'chatglm3-6b-32k',
|
||||
# 'chatglm3-6b-base', 'Qwen-72B-Chat-Int4'
|
||||
|
||||
|
||||
LOCAL_LLM_MODEL_DIR = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "llm_models")
|
||||
llm_model_dict_c = {}
|
||||
|
@ -133,8 +128,12 @@ llm_model_dict = llm_model_dict_c
|
|||
|
||||
|
||||
# LLM 名称
|
||||
LLM_MODEL = "gpt-3.5-turbo"
|
||||
LLM_MODELs = ["chatglm2-6b"]
|
||||
# EMBEDDING_ENGINE = 'openai'
|
||||
EMBEDDING_ENGINE = 'model'
|
||||
EMBEDDING_MODEL = "text2vec-base"
|
||||
# LLM_MODEL = "gpt-4"
|
||||
LLM_MODEL = "gpt-3.5-turbo-16k"
|
||||
LLM_MODELs = ["gpt-3.5-turbo-16k"]
|
||||
USE_FASTCHAT = "gpt" not in LLM_MODEL # 判断是否进行fastchat
|
||||
|
||||
# LLM 运行设备
|
||||
|
@ -161,7 +160,7 @@ NLTK_DATA_PATH = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__
|
|||
JUPYTER_WORK_PATH = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "jupyter_work")
|
||||
|
||||
# WEB_CRAWL存储路径
|
||||
WEB_CRAWL_PATH = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "sources/docs")
|
||||
WEB_CRAWL_PATH = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "knowledge_base")
|
||||
|
||||
# NEBULA_DATA存储路径
|
||||
NELUBA_PATH = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "data/neluba_data")
|
||||
|
@ -257,3 +256,5 @@ BING_SUBSCRIPTION_KEY = ""
|
|||
# 通过增加标题判断,判断哪些文本为标题,并在metadata中进行标记;
|
||||
# 然后将文本与往上一级的标题进行拼合,实现文本信息的增强。
|
||||
ZH_TITLE_ENHANCE = False
|
||||
|
||||
log_verbose = False
|
|
@ -100,12 +100,26 @@ FSCHAT_MODEL_WORKERS = {
|
|||
# "stream_interval": 2,
|
||||
# "no_register": False,
|
||||
},
|
||||
"chatglm2-6b": {
|
||||
"port": 20003
|
||||
},
|
||||
"baichuan2-7b-base": {
|
||||
"port": 20004
|
||||
}
|
||||
'codellama_34b': {'host': DEFAULT_BIND_HOST, 'port': 20002},
|
||||
'Baichuan2-13B-Base': {'host': DEFAULT_BIND_HOST, 'port': 20003},
|
||||
'Baichuan2-13B-Chat': {'host': DEFAULT_BIND_HOST, 'port': 20004},
|
||||
'baichuan2-7b-base': {'host': DEFAULT_BIND_HOST, 'port': 20005},
|
||||
'baichuan2-7b-chat': {'host': DEFAULT_BIND_HOST, 'port': 20006},
|
||||
'internlm-7b-base': {'host': DEFAULT_BIND_HOST, 'port': 20007},
|
||||
'internlm-chat-7b': {'host': DEFAULT_BIND_HOST, 'port': 20008},
|
||||
'chatglm2-6b': {'host': DEFAULT_BIND_HOST, 'port': 20009},
|
||||
'qwen-14b-base': {'host': DEFAULT_BIND_HOST, 'port': 20010},
|
||||
'qwen-14b-chat': {'host': DEFAULT_BIND_HOST, 'port': 20011},
|
||||
'qwen-1-8B-Chat': {'host': DEFAULT_BIND_HOST, 'port': 20012},
|
||||
'Qwen-7B': {'host': DEFAULT_BIND_HOST, 'port': 20013},
|
||||
'Qwen-7B-Chat': {'host': DEFAULT_BIND_HOST, 'port': 20014},
|
||||
'qwen-7b-base-v1.1': {'host': DEFAULT_BIND_HOST, 'port': 20015},
|
||||
'qwen-7b-chat-v1.1': {'host': DEFAULT_BIND_HOST, 'port': 20016},
|
||||
'chatglm3-6b': {'host': DEFAULT_BIND_HOST, 'port': 20017},
|
||||
'chatglm3-6b-32k': {'host': DEFAULT_BIND_HOST, 'port': 20018},
|
||||
'chatglm3-6b-base': {'host': DEFAULT_BIND_HOST, 'port': 20019},
|
||||
'Qwen-72B-Chat-Int4': {'host': DEFAULT_BIND_HOST, 'port': 20020},
|
||||
'gpt-3.5-turbo': {'host': DEFAULT_BIND_HOST, 'port': 20021}
|
||||
}
|
||||
# fastchat multi model worker server
|
||||
FSCHAT_MULTI_MODEL_WORKERS = {
|
||||
|
|
|
@ -275,7 +275,7 @@ class AgentChat:
|
|||
"code_docs": [str(doc) for doc in message.code_docs],
|
||||
"related_nodes": [doc.get_related_node() for idx, doc in enumerate(message.code_docs) if idx==0],
|
||||
"figures": message.figures,
|
||||
"step_content": step_content,
|
||||
"step_content": step_content or final_content,
|
||||
"final_content": final_content,
|
||||
}
|
||||
|
||||
|
|
|
@ -22,7 +22,7 @@ HISTORY_VERTEX_SCORE = 5
|
|||
VERTEX_MERGE_RATIO = 0.5
|
||||
|
||||
# search_by_description
|
||||
MAX_DISTANCE = 0.5
|
||||
MAX_DISTANCE = 1000
|
||||
|
||||
|
||||
class CodeSearch:
|
||||
|
|
|
@ -17,7 +17,7 @@ from dev_opsgpt.connector.configs.prompts import BASE_PROMPT_INPUT, QUERY_CONTEX
|
|||
from dev_opsgpt.connector.message_process import MessageUtils
|
||||
from dev_opsgpt.connector.configs.agent_config import REACT_PROMPT_INPUT, QUERY_CONTEXT_PROMPT_INPUT, PLAN_PROMPT_INPUT
|
||||
|
||||
from dev_opsgpt.llm_models import getChatModel
|
||||
from dev_opsgpt.llm_models import getChatModel, getExtraModel
|
||||
from dev_opsgpt.connector.utils import parse_section
|
||||
|
||||
|
||||
|
@ -44,7 +44,7 @@ class BaseAgent:
|
|||
|
||||
self.task = task
|
||||
self.role = role
|
||||
self.message_utils = MessageUtils()
|
||||
self.message_utils = MessageUtils(role)
|
||||
self.llm = self.create_llm_engine(temperature, stop)
|
||||
self.memory = self.init_history(memory)
|
||||
self.chat_turn = chat_turn
|
||||
|
@ -68,6 +68,7 @@ class BaseAgent:
|
|||
'''agent reponse from multi-message'''
|
||||
# insert query into memory
|
||||
query_c = copy.deepcopy(query)
|
||||
query_c = self.start_action_step(query_c)
|
||||
|
||||
self_memory = self.memory if self.do_use_self_memory else None
|
||||
# create your llm prompt
|
||||
|
@ -80,23 +81,30 @@ class BaseAgent:
|
|||
role_name=self.role.role_name,
|
||||
role_type="ai", #self.role.role_type,
|
||||
role_content=content,
|
||||
role_contents=[content],
|
||||
step_content=content,
|
||||
input_query=query_c.input_query,
|
||||
tools=query_c.tools,
|
||||
parsed_output_list=[query.parsed_output]
|
||||
parsed_output_list=[query.parsed_output],
|
||||
customed_kargs=query_c.customed_kargs
|
||||
)
|
||||
# common parse llm' content to message
|
||||
output_message = self.message_utils.parser(output_message)
|
||||
if self.do_filter:
|
||||
output_message = self.message_utils.filter(output_message)
|
||||
|
||||
# action step
|
||||
output_message, observation_message = self.message_utils.step_router(output_message, history, background, memory_pool=memory_pool)
|
||||
output_message.parsed_output_list.append(output_message.parsed_output)
|
||||
if observation_message:
|
||||
output_message.parsed_output_list.append(observation_message.parsed_output)
|
||||
# update self_memory
|
||||
self.append_history(query_c)
|
||||
self.append_history(output_message)
|
||||
# logger.info(f"{self.role.role_name} currenct question: {output_message.input_query}\nllm_step_run: {output_message.role_content}")
|
||||
output_message.input_query = output_message.role_content
|
||||
output_message.parsed_output_list.append(output_message.parsed_output)
|
||||
# output_message.parsed_output_list.append(output_message.parsed_output) # 与上述重复?
|
||||
# end
|
||||
output_message = self.message_utils.inherit_extrainfo(query, output_message)
|
||||
output_message = self.end_action_step(output_message)
|
||||
# update memory pool
|
||||
memory_pool.append(output_message)
|
||||
yield output_message
|
||||
|
@ -110,7 +118,7 @@ class BaseAgent:
|
|||
doc_infos = self.create_doc_prompt(query)
|
||||
code_infos = self.create_codedoc_prompt(query)
|
||||
#
|
||||
formatted_tools, tool_names = self.create_tools_prompt(query)
|
||||
formatted_tools, tool_names, _ = self.create_tools_prompt(query)
|
||||
task_prompt = self.create_task_prompt(query)
|
||||
background_prompt = self.create_background_prompt(background, control_key="step_content")
|
||||
history_prompt = self.create_history_prompt(history)
|
||||
|
@ -185,7 +193,10 @@ class BaseAgent:
|
|||
context = memory_pool_select_by_agent_key_context
|
||||
prompt += "\n**Context:**\n" + context + "\n" + input_query
|
||||
elif input_key == "DocInfos":
|
||||
prompt += "\n**DocInfos:**\n" + DocInfos
|
||||
if DocInfos:
|
||||
prompt += "\n**DocInfos:**\n" + DocInfos
|
||||
else:
|
||||
prompt += "\n**DocInfos:**\n" + "Empty"
|
||||
elif input_key == "Question":
|
||||
prompt += "\n**Question:**\n" + input_query
|
||||
|
||||
|
@ -231,12 +242,15 @@ class BaseAgent:
|
|||
def create_tools_prompt(self, message: Message) -> str:
|
||||
tools = message.tools
|
||||
tool_strings = []
|
||||
tools_descs = []
|
||||
for tool in tools:
|
||||
args_schema = re.sub("}", "}}}}", re.sub("{", "{{{{", str(tool.args)))
|
||||
tool_strings.append(f"{tool.name}: {tool.description}, args: {args_schema}")
|
||||
tools_descs.append(f"{tool.name}: {tool.description}")
|
||||
formatted_tools = "\n".join(tool_strings)
|
||||
tools_desc_str = "\n".join(tools_descs)
|
||||
tool_names = ", ".join([tool.name for tool in tools])
|
||||
return formatted_tools, tool_names
|
||||
return formatted_tools, tool_names, tools_desc_str
|
||||
|
||||
def create_task_prompt(self, message: Message) -> str:
|
||||
task = message.task or self.task
|
||||
|
@ -276,19 +290,22 @@ class BaseAgent:
|
|||
|
||||
def create_llm_engine(self, temperature=0.2, stop=None):
|
||||
return getChatModel(temperature=temperature, stop=stop)
|
||||
|
||||
def registry_actions(self, actions):
|
||||
'''registry llm's actions'''
|
||||
self.action_list = actions
|
||||
|
||||
# def filter(self, message: Message, stop=None) -> Message:
|
||||
def start_action_step(self, message: Message) -> Message:
|
||||
'''do action before agent predict '''
|
||||
# action_json = self.start_action()
|
||||
# message["customed_kargs"]["xx"] = action_json
|
||||
return message
|
||||
|
||||
# tool_params = self.parser_spec_key(message.role_content, "tool_params")
|
||||
# code_content = self.parser_spec_key(message.role_content, "code_content")
|
||||
# plan = self.parser_spec_key(message.role_content, "plan")
|
||||
# plans = self.parser_spec_key(message.role_content, "plans", do_search=False)
|
||||
# content = self.parser_spec_key(message.role_content, "content", do_search=False)
|
||||
|
||||
# # logger.debug(f"tool_params: {tool_params}, code_content: {code_content}, plan: {plan}, plans: {plans}, content: {content}")
|
||||
# role_content = tool_params or code_content or plan or plans or content
|
||||
# message.role_content = role_content or message.role_content
|
||||
# return message
|
||||
def end_action_step(self, message: Message) -> Message:
|
||||
'''do action after agent predict '''
|
||||
# action_json = self.end_action()
|
||||
# message["customed_kargs"]["xx"] = action_json
|
||||
return message
|
||||
|
||||
def token_usage(self, ):
|
||||
'''calculate the usage of token'''
|
||||
|
@ -324,339 +341,6 @@ class BaseAgent:
|
|||
message_c.parsed_output = {k: v for k,v in message_c.parsed_output.items() if k in self.focus_message_keys}
|
||||
message_c.parsed_output_list = [{k: v for k,v in parsed_output.items() if k in self.focus_message_keys} for parsed_output in message_c.parsed_output_list]
|
||||
return message_c
|
||||
|
||||
# def get_extra_infos(self, message: Message) -> Message:
|
||||
# ''''''
|
||||
# if self.do_search:
|
||||
# message = self.get_search_retrieval(message)
|
||||
|
||||
# if self.do_doc_retrieval:
|
||||
# message = self.get_doc_retrieval(message)
|
||||
|
||||
# if self.do_tool_retrieval:
|
||||
# message = self.get_tool_retrieval(message)
|
||||
|
||||
# return message
|
||||
|
||||
# def get_search_retrieval(self, message: Message,) -> Message:
|
||||
# SEARCH_ENGINES = {"duckduckgo": DDGSTool}
|
||||
# search_docs = []
|
||||
# for idx, doc in enumerate(SEARCH_ENGINES["duckduckgo"].run(message.role_content, 3)):
|
||||
# doc.update({"index": idx})
|
||||
# search_docs.append(Doc(**doc))
|
||||
# message.search_docs = search_docs
|
||||
# return message
|
||||
|
||||
# def get_doc_retrieval(self, message: Message) -> Message:
|
||||
# query = message.role_content
|
||||
# knowledge_basename = message.doc_engine_name
|
||||
# top_k = message.top_k
|
||||
# score_threshold = message.score_threshold
|
||||
# if knowledge_basename:
|
||||
# docs = DocRetrieval.run(query, knowledge_basename, top_k, score_threshold)
|
||||
# message.db_docs = [Doc(**doc) for doc in docs]
|
||||
# return message
|
||||
|
||||
# def get_code_retrieval(self, message: Message) -> Message:
|
||||
# # DocRetrieval.run("langchain是什么", "DSADSAD")
|
||||
# query = message.input_query
|
||||
# code_engine_name = message.code_engine_name
|
||||
# history_node_list = message.history_node_list
|
||||
# code_docs = CodeRetrieval.run(code_engine_name, query, code_limit=message.top_k, history_node_list=history_node_list)
|
||||
# message.code_docs = [CodeDoc(**doc) for doc in code_docs]
|
||||
# return message
|
||||
|
||||
# def get_tool_retrieval(self, message: Message) -> Message:
|
||||
# return message
|
||||
|
||||
# def step_router(self, message: Message) -> tuple[Message, ...]:
|
||||
# ''''''
|
||||
# # message = self.parser(message)
|
||||
# # logger.debug(f"message.action_status: {message.action_status}")
|
||||
# observation_message = None
|
||||
# if message.action_status == ActionStatus.CODING:
|
||||
# message, observation_message = self.code_step(message)
|
||||
# elif message.action_status == ActionStatus.TOOL_USING:
|
||||
# message, observation_message = self.tool_step(message)
|
||||
|
||||
# return message, observation_message
|
||||
|
||||
# def code_step(self, message: Message) -> Message:
|
||||
# '''execute code'''
|
||||
# # logger.debug(f"message.role_content: {message.role_content}, message.code_content: {message.code_content}")
|
||||
# code_answer = self.codebox.chat('```python\n{}```'.format(message.code_content))
|
||||
# code_prompt = f"执行上述代码后存在报错信息为 {code_answer.code_exe_response},需要进行修复" \
|
||||
# if code_answer.code_exe_type == "error" else f"执行上述代码后返回信息为 {code_answer.code_exe_response}"
|
||||
|
||||
# observation_message = Message(
|
||||
# role_name="observation",
|
||||
# role_type="func", #self.role.role_type,
|
||||
# role_content="",
|
||||
# step_content="",
|
||||
# input_query=message.code_content,
|
||||
# )
|
||||
# uid = str(uuid.uuid1())
|
||||
# if code_answer.code_exe_type == "image/png":
|
||||
# message.figures[uid] = code_answer.code_exe_response
|
||||
# message.code_answer = f"\n**Observation:**: 执行上述代码后生成一张图片, 图片名为{uid}\n"
|
||||
# message.observation = f"\n**Observation:**: 执行上述代码后生成一张图片, 图片名为{uid}\n"
|
||||
# message.step_content += f"\n**Observation:**: 执行上述代码后生成一张图片, 图片名为{uid}\n"
|
||||
# message.step_contents += [f"\n**Observation:**: 执行上述代码后生成一张图片, 图片名为{uid}\n"]
|
||||
# # message.role_content += f"\n**Observation:**:执行上述代码后生成一张图片, 图片名为{uid}\n"
|
||||
# observation_message.role_content = f"\n**Observation:**: 执行上述代码后生成一张图片, 图片名为{uid}\n"
|
||||
# observation_message.parsed_output = {"Observation": f"执行上述代码后生成一张图片, 图片名为{uid}"}
|
||||
# else:
|
||||
# message.code_answer = code_answer.code_exe_response
|
||||
# message.observation = code_answer.code_exe_response
|
||||
# message.step_content += f"\n**Observation:**: {code_prompt}\n"
|
||||
# message.step_contents += [f"\n**Observation:**: {code_prompt}\n"]
|
||||
# # message.role_content += f"\n**Observation:**: {code_prompt}\n"
|
||||
# observation_message.role_content = f"\n**Observation:**: {code_prompt}\n"
|
||||
# observation_message.parsed_output = {"Observation": code_prompt}
|
||||
# # logger.info(f"**Observation:** {message.action_status}, {message.observation}")
|
||||
# return message, observation_message
|
||||
|
||||
# def tool_step(self, message: Message) -> Message:
|
||||
# '''execute tool'''
|
||||
# # logger.debug(f"{message}")
|
||||
# observation_message = Message(
|
||||
# role_name="observation",
|
||||
# role_type="function", #self.role.role_type,
|
||||
# role_content="\n**Observation:** there is no tool can execute\n" ,
|
||||
# step_content="",
|
||||
# input_query=str(message.tool_params),
|
||||
# tools=message.tools,
|
||||
# )
|
||||
# # logger.debug(f"message: {message.action_status}, {message.tool_name}, {message.tool_params}")
|
||||
# tool_names = [tool.name for tool in message.tools]
|
||||
# if message.tool_name not in tool_names:
|
||||
# message.tool_answer = "\n**Observation:** there is no tool can execute\n"
|
||||
# message.observation = "\n**Observation:** there is no tool can execute\n"
|
||||
# # message.role_content += f"\n**Observation:**: 不存在可以执行的tool\n"
|
||||
# message.step_content += f"\n**Observation:** there is no tool can execute\n"
|
||||
# message.step_contents += [f"\n**Observation:** there is no tool can execute\n"]
|
||||
# observation_message.role_content = f"\n**Observation:** there is no tool can execute\n"
|
||||
# observation_message.parsed_output = {"Observation": "there is no tool can execute\n"}
|
||||
# for tool in message.tools:
|
||||
# if tool.name == message.tool_name:
|
||||
# tool_res = tool.func(**message.tool_params.get("tool_params", {}))
|
||||
# logger.debug(f"tool_res {tool_res}")
|
||||
# message.tool_answer = tool_res
|
||||
# message.observation = tool_res
|
||||
# # message.role_content += f"\n**Observation:**: {tool_res}\n"
|
||||
# message.step_content += f"\n**Observation:** {tool_res}\n"
|
||||
# message.step_contents += [f"\n**Observation:** {tool_res}\n"]
|
||||
# observation_message.role_content = f"\n**Observation:** {tool_res}\n"
|
||||
# observation_message.parsed_output = {"Observation": tool_res}
|
||||
# break
|
||||
|
||||
# # logger.info(f"**Observation:** {message.action_status}, {message.observation}")
|
||||
# return message, observation_message
|
||||
|
||||
# def parser(self, message: Message) -> Message:
|
||||
# ''''''
|
||||
# content = message.role_content
|
||||
# parser_keys = ["action", "code_content", "code_filename", "tool_params", "plans"]
|
||||
# try:
|
||||
# s_json = self._parse_json(content)
|
||||
# message.action_status = s_json.get("action")
|
||||
# message.code_content = s_json.get("code_content")
|
||||
# message.tool_params = s_json.get("tool_params")
|
||||
# message.tool_name = s_json.get("tool_name")
|
||||
# message.code_filename = s_json.get("code_filename")
|
||||
# message.plans = s_json.get("plans")
|
||||
# # for parser_key in parser_keys:
|
||||
# # message.action_status = content.get(parser_key)
|
||||
# except Exception as e:
|
||||
# # logger.warning(f"{traceback.format_exc()}")
|
||||
# def parse_text_to_dict(text):
|
||||
# # Define a regular expression pattern to capture the key and value
|
||||
# main_pattern = r"\*\*(.+?):\*\*\s*(.*?)\s*(?=\*\*|$)"
|
||||
# list_pattern = r'```python\n(.*?)```'
|
||||
|
||||
# # Use re.findall to find all main matches in the text
|
||||
# main_matches = re.findall(main_pattern, text, re.DOTALL)
|
||||
|
||||
# # Convert main matches to a dictionary
|
||||
# parsed_dict = {key.strip(): value.strip() for key, value in main_matches}
|
||||
|
||||
# for k, v in parsed_dict.items():
|
||||
# for pattern in [list_pattern]:
|
||||
# if "PLAN" != k: continue
|
||||
# match_value = re.search(pattern, v, re.DOTALL)
|
||||
# if match_value:
|
||||
# # Add the code block to the dictionary
|
||||
# parsed_dict[k] = eval(match_value.group(1).strip())
|
||||
# break
|
||||
|
||||
# return parsed_dict
|
||||
|
||||
# def extract_content_from_backticks(text):
|
||||
# code_blocks = []
|
||||
# lines = text.split('\n')
|
||||
# is_code_block = False
|
||||
# code_block = ''
|
||||
# language = ''
|
||||
# for line in lines:
|
||||
# if line.startswith('```') and not is_code_block:
|
||||
# is_code_block = True
|
||||
# language = line[3:]
|
||||
# code_block = ''
|
||||
# elif line.startswith('```') and is_code_block:
|
||||
# is_code_block = False
|
||||
# code_blocks.append({language.strip(): code_block.strip()})
|
||||
# elif is_code_block:
|
||||
# code_block += line + '\n'
|
||||
# return code_blocks
|
||||
|
||||
# def parse_dict_to_dict(parsed_dict):
|
||||
# code_pattern = r'```python\n(.*?)```'
|
||||
# tool_pattern = r'```tool_params\n(.*?)```'
|
||||
|
||||
# pattern_dict = {"code": code_pattern, "tool_params": tool_pattern}
|
||||
# spec_parsed_dict = copy.deepcopy(parsed_dict)
|
||||
# for key, pattern in pattern_dict.items():
|
||||
# for k, text in parsed_dict.items():
|
||||
# # Search for the code block
|
||||
# if not isinstance(text, str): continue
|
||||
# _match = re.search(pattern, text, re.DOTALL)
|
||||
# if _match:
|
||||
# # Add the code block to the dictionary
|
||||
# try:
|
||||
# spec_parsed_dict[key] = json.loads(_match.group(1).strip())
|
||||
# except:
|
||||
# spec_parsed_dict[key] = _match.group(1).strip()
|
||||
# break
|
||||
# return spec_parsed_dict
|
||||
# def parse_dict_to_dict(parsed_dict):
|
||||
# code_pattern = r'```python\n(.*?)```'
|
||||
# tool_pattern = r'```json\n(.*?)```'
|
||||
|
||||
# pattern_dict = {"code": code_pattern, "json": tool_pattern}
|
||||
# spec_parsed_dict = copy.deepcopy(parsed_dict)
|
||||
# for key, pattern in pattern_dict.items():
|
||||
# for k, text in parsed_dict.items():
|
||||
# # Search for the code block
|
||||
# if not isinstance(text, str): continue
|
||||
# _match = re.search(pattern, text, re.DOTALL)
|
||||
# if _match:
|
||||
# # Add the code block to the dictionary
|
||||
# logger.debug(f"dsadsa {text}")
|
||||
# try:
|
||||
# spec_parsed_dict[key] = json.loads(_match.group(1).strip())
|
||||
# except:
|
||||
# spec_parsed_dict[key] = _match.group(1).strip()
|
||||
# break
|
||||
# return spec_parsed_dict
|
||||
|
||||
# parsed_dict = parse_text_to_dict(content)
|
||||
# spec_parsed_dict = parse_dict_to_dict(parsed_dict)
|
||||
# action_value = parsed_dict.get('Action Status')
|
||||
# if action_value:
|
||||
# action_value = action_value.lower()
|
||||
# logger.info(f'{self.role.role_name}: action_value: {action_value}')
|
||||
# # action_value = self._match(r"'action':\s*'([^']*)'", content) if "'action'" in content else self._match(r'"action":\s*"([^"]*)"', content)
|
||||
|
||||
# code_content_value = spec_parsed_dict.get('code')
|
||||
# # code_content_value = self._match(r"'code_content':\s*'([^']*)'", content) if "'code_content'" in content else self._match(r'"code_content":\s*"([^"]*)"', content)
|
||||
# filename_value = self._match(r"'code_filename':\s*'([^']*)'", content) if "'code_filename'" in content else self._match(r'"code_filename":\s*"([^"]*)"', content)
|
||||
# tool_params_value = spec_parsed_dict.get('tool_params')
|
||||
# # tool_params_value = self._match(r"'tool_params':\s*(\{[^{}]*\})", content, do_json=True) if "'tool_params'" in content \
|
||||
# # else self._match(r'"tool_params":\s*(\{[^{}]*\})', content, do_json=True)
|
||||
# tool_name_value = self._match(r"'tool_name':\s*'([^']*)'", content) if "'tool_name'" in content else self._match(r'"tool_name":\s*"([^"]*)"', content)
|
||||
# plans_value = self._match(r"'plans':\s*(\[.*?\])", content, do_search=False) if "'plans'" in content else self._match(r'"plans":\s*(\[.*?\])', content, do_search=False, )
|
||||
# # re解析
|
||||
# message.action_status = action_value or "default"
|
||||
# message.code_content = code_content_value
|
||||
# message.code_filename = filename_value
|
||||
# message.tool_params = tool_params_value
|
||||
# message.tool_name = tool_name_value
|
||||
# message.plans = plans_value
|
||||
# message.parsed_output = parsed_dict
|
||||
# message.spec_parsed_output = spec_parsed_dict
|
||||
# code_content_value = spec_parsed_dict.get('code')
|
||||
# # code_content_value = self._match(r"'code_content':\s*'([^']*)'", content) if "'code_content'" in content else self._match(r'"code_content":\s*"([^"]*)"', content)
|
||||
# filename_value = self._match(r"'code_filename':\s*'([^']*)'", content) if "'code_filename'" in content else self._match(r'"code_filename":\s*"([^"]*)"', content)
|
||||
# logger.debug(spec_parsed_dict)
|
||||
|
||||
# if action_value == 'tool_using':
|
||||
# tool_params_value = spec_parsed_dict.get('json')
|
||||
# else:
|
||||
# tool_params_value = None
|
||||
# # tool_params_value = self._match(r"'tool_params':\s*(\{[^{}]*\})", content, do_json=True) if "'tool_params'" in content \
|
||||
# # else self._match(r'"tool_params":\s*(\{[^{}]*\})', content, do_json=True)
|
||||
# tool_name_value = self._match(r"'tool_name':\s*'([^']*)'", content) if "'tool_name'" in content else self._match(r'"tool_name":\s*"([^"]*)"', content)
|
||||
# plans_value = self._match(r"'plans':\s*(\[.*?\])", content, do_search=False) if "'plans'" in content else self._match(r'"plans":\s*(\[.*?\])', content, do_search=False, )
|
||||
# # re解析
|
||||
# message.action_status = action_value or "default"
|
||||
# message.code_content = code_content_value
|
||||
# message.code_filename = filename_value
|
||||
# message.tool_params = tool_params_value
|
||||
# message.tool_name = tool_name_value
|
||||
# message.plans = plans_value
|
||||
# message.parsed_output = parsed_dict
|
||||
# message.spec_parsed_output = spec_parsed_dict
|
||||
|
||||
# # logger.debug(f"确认当前的action: {message.action_status}")
|
||||
|
||||
# return message
|
||||
|
||||
# def parser_spec_key(self, content, key, do_search=True, do_json=False) -> str:
|
||||
# ''''''
|
||||
# key2pattern = {
|
||||
# "'action'": r"'action':\s*'([^']*)'", '"action"': r'"action":\s*"([^"]*)"',
|
||||
# "'code_content'": r"'code_content':\s*'([^']*)'", '"code_content"': r'"code_content":\s*"([^"]*)"',
|
||||
# "'code_filename'": r"'code_filename':\s*'([^']*)'", '"code_filename"': r'"code_filename":\s*"([^"]*)"',
|
||||
# "'tool_params'": r"'tool_params':\s*(\{[^{}]*\})", '"tool_params"': r'"tool_params":\s*(\{[^{}]*\})',
|
||||
# "'tool_name'": r"'tool_name':\s*'([^']*)'", '"tool_name"': r'"tool_name":\s*"([^"]*)"',
|
||||
# "'plans'": r"'plans':\s*(\[.*?\])", '"plans"': r'"plans":\s*(\[.*?\])',
|
||||
# "'content'": r"'content':\s*'([^']*)'", '"content"': r'"content":\s*"([^"]*)"',
|
||||
# }
|
||||
|
||||
# s_json = self._parse_json(content)
|
||||
# try:
|
||||
# if s_json and key in s_json:
|
||||
# return str(s_json[key])
|
||||
# except:
|
||||
# pass
|
||||
|
||||
# keystr = f"'{key}'" if f"'{key}'" in content else f'"{key}"'
|
||||
# return self._match(key2pattern.get(keystr, fr"'{key}':\s*'([^']*)'"), content, do_search=do_search, do_json=do_json)
|
||||
|
||||
# def _match(self, pattern, s, do_search=True, do_json=False):
|
||||
# try:
|
||||
# if do_search:
|
||||
# match = re.search(pattern, s)
|
||||
# if match:
|
||||
# value = match.group(1).replace("\\n", "\n")
|
||||
# if do_json:
|
||||
# value = json.loads(value)
|
||||
# else:
|
||||
# value = None
|
||||
# else:
|
||||
# match = re.findall(pattern, s, re.DOTALL)
|
||||
# if match:
|
||||
# value = match[0]
|
||||
# if do_json:
|
||||
# value = json.loads(value)
|
||||
# else:
|
||||
# value = None
|
||||
# except Exception as e:
|
||||
# logger.warning(f"{traceback.format_exc()}")
|
||||
|
||||
# # logger.debug(f"pattern: {pattern}, s: {s}, match: {match}")
|
||||
# return value
|
||||
|
||||
# def _parse_json(self, s):
|
||||
# try:
|
||||
# pattern = r"```([^`]+)```"
|
||||
# match = re.findall(pattern, s)
|
||||
# if match:
|
||||
# return eval(match[0])
|
||||
# except:
|
||||
# pass
|
||||
# return None
|
||||
|
||||
|
||||
def get_memory(self, content_key="role_content"):
|
||||
return self.memory.to_tuple_messages(content_key="step_content")
|
||||
|
|
|
@ -50,7 +50,7 @@ class CheckAgent(BaseAgent):
|
|||
doc_infos = self.create_doc_prompt(query)
|
||||
code_infos = self.create_codedoc_prompt(query)
|
||||
#
|
||||
formatted_tools, tool_names = self.create_tools_prompt(query)
|
||||
formatted_tools, tool_names, _ = self.create_tools_prompt(query)
|
||||
task_prompt = self.create_task_prompt(query)
|
||||
background_prompt = self.create_background_prompt(background)
|
||||
history_prompt = self.create_history_prompt(history)
|
||||
|
|
|
@ -55,7 +55,8 @@ class ExecutorAgent(BaseAgent):
|
|||
step_content="",
|
||||
input_query=query.input_query,
|
||||
tools=query.tools,
|
||||
parsed_output_list=[query.parsed_output]
|
||||
parsed_output_list=[query.parsed_output],
|
||||
customed_kargs=query.customed_kargs
|
||||
)
|
||||
|
||||
self_memory = self.memory if self.do_use_self_memory else None
|
||||
|
@ -64,6 +65,7 @@ class ExecutorAgent(BaseAgent):
|
|||
# 如果存在plan字段且plan字段为str的时候
|
||||
if "PLAN" not in query.parsed_output or isinstance(query.parsed_output.get("PLAN", []), str) or plan_step >= len(query.parsed_output.get("PLAN", [])):
|
||||
query_c = copy.deepcopy(query)
|
||||
query_c = self.start_action_step(query_c)
|
||||
query_c.parsed_output = {"Question": query_c.input_query}
|
||||
task_executor_memory.append(query_c)
|
||||
for output_message, task_executor_memory in self._arun_step(output_message, query_c, self_memory, history, background, memory_pool, task_executor_memory):
|
||||
|
@ -87,6 +89,7 @@ class ExecutorAgent(BaseAgent):
|
|||
yield output_message
|
||||
else:
|
||||
query_c = copy.deepcopy(query)
|
||||
query_c = self.start_action_step(query_c)
|
||||
task_content = query_c.parsed_output["PLAN"][plan_step]
|
||||
query_c.parsed_output = {"Question": task_content}
|
||||
task_executor_memory.append(query_c)
|
||||
|
@ -99,6 +102,8 @@ class ExecutorAgent(BaseAgent):
|
|||
# logger.info(f"{self.role.role_name} currenct question: {output_message.input_query}\nllm_executor_run: {output_message.step_content}")
|
||||
# logger.info(f"{self.role.role_name} currenct parserd_output_list: {output_message.parserd_output_list}")
|
||||
output_message.input_query = output_message.role_content
|
||||
# end_action_step
|
||||
output_message = self.end_action_step(output_message)
|
||||
# update memory pool
|
||||
memory_pool.append(output_message)
|
||||
yield output_message
|
||||
|
@ -113,9 +118,7 @@ class ExecutorAgent(BaseAgent):
|
|||
logger.debug(f"{self.role.role_name} content: {content}")
|
||||
|
||||
output_message.role_content = content
|
||||
output_message.role_contents += [content]
|
||||
output_message.step_content += "\n"+output_message.role_content
|
||||
output_message.step_contents + [output_message.role_content]
|
||||
|
||||
output_message = self.message_utils.parser(output_message)
|
||||
# according the output to choose one action for code_content or tool_content
|
||||
|
@ -141,7 +144,7 @@ class ExecutorAgent(BaseAgent):
|
|||
doc_infos = self.create_doc_prompt(query)
|
||||
code_infos = self.create_codedoc_prompt(query)
|
||||
#
|
||||
formatted_tools, tool_names = self.create_tools_prompt(query)
|
||||
formatted_tools, tool_names, _ = self.create_tools_prompt(query)
|
||||
task_prompt = self.create_task_prompt(query)
|
||||
background_prompt = self.create_background_prompt(background, control_key="step_content")
|
||||
history_prompt = self.create_history_prompt(history)
|
||||
|
|
|
@ -59,10 +59,15 @@ class ReactAgent(BaseAgent):
|
|||
step_content="",
|
||||
input_query=query.input_query,
|
||||
tools=query.tools,
|
||||
parsed_output_list=[query.parsed_output]
|
||||
parsed_output_list=[query.parsed_output],
|
||||
customed_kargs=query.customed_kargs
|
||||
)
|
||||
query_c = copy.deepcopy(query)
|
||||
query_c.parsed_output = {"Question": "\n".join([f"{v}" for k, v in query.parsed_output.items() if k not in ["Action Status"]])}
|
||||
query_c = self.start_action_step(query_c)
|
||||
if query.parsed_output:
|
||||
query_c.parsed_output = {"Question": "\n".join([f"{v}" for k, v in query.parsed_output.items() if k not in ["Action Status"]])}
|
||||
else:
|
||||
query_c.parsed_output = {"Question": query.input_query}
|
||||
react_memory.append(query_c)
|
||||
self_memory = self.memory if self.do_use_self_memory else None
|
||||
idx = 0
|
||||
|
@ -77,9 +82,7 @@ class ReactAgent(BaseAgent):
|
|||
raise Exception(traceback.format_exc())
|
||||
|
||||
output_message.role_content = "\n"+content
|
||||
output_message.role_contents += [content]
|
||||
output_message.step_content += "\n"+output_message.role_content
|
||||
output_message.step_contents + [output_message.role_content]
|
||||
yield output_message
|
||||
|
||||
# logger.debug(f"{self.role.role_name}, {idx} iteration prompt: {prompt}")
|
||||
|
@ -87,7 +90,7 @@ class ReactAgent(BaseAgent):
|
|||
|
||||
output_message = self.message_utils.parser(output_message)
|
||||
# when get finished signal can stop early
|
||||
if output_message.action_status == ActionStatus.FINISHED: break
|
||||
if output_message.action_status == ActionStatus.FINISHED or output_message.action_status == ActionStatus.STOPED: break
|
||||
# according the output to choose one action for code_content or tool_content
|
||||
output_message, observation_message = self.message_utils.step_router(output_message)
|
||||
output_message.parsed_output_list.append(output_message.parsed_output)
|
||||
|
@ -108,6 +111,8 @@ class ReactAgent(BaseAgent):
|
|||
# update memory pool
|
||||
# memory_pool.append(output_message)
|
||||
output_message.input_query = query.input_query
|
||||
# end_action_step
|
||||
output_message = self.end_action_step(output_message)
|
||||
# update memory pool
|
||||
memory_pool.append(output_message)
|
||||
yield output_message
|
||||
|
@ -122,7 +127,7 @@ class ReactAgent(BaseAgent):
|
|||
doc_infos = self.create_doc_prompt(query)
|
||||
code_infos = self.create_codedoc_prompt(query)
|
||||
#
|
||||
formatted_tools, tool_names = self.create_tools_prompt(query)
|
||||
formatted_tools, tool_names, _ = self.create_tools_prompt(query)
|
||||
task_prompt = self.create_task_prompt(query)
|
||||
background_prompt = self.create_background_prompt(background)
|
||||
history_prompt = self.create_history_prompt(history)
|
||||
|
@ -136,7 +141,7 @@ class ReactAgent(BaseAgent):
|
|||
# # input_query = query.input_query + "\n" + "\n".join([f"{v}" for k, v in input_query if v])
|
||||
# input_query = "\n".join([f"{v}" for k, v in input_query if v])
|
||||
input_query = "\n".join(["\n".join([f"**{k}:**\n{v}" for k,v in _dict.items()]) for _dict in react_memory.get_parserd_output()])
|
||||
logger.debug(f"input_query: {input_query}")
|
||||
# logger.debug(f"input_query: {input_query}")
|
||||
|
||||
prompt += "\n" + REACT_PROMPT_INPUT.format(**{"query": input_query})
|
||||
|
||||
|
|
|
@ -4,6 +4,7 @@ import re
|
|||
import json
|
||||
import traceback
|
||||
import copy
|
||||
import random
|
||||
from loguru import logger
|
||||
|
||||
from langchain.prompts.chat import ChatPromptTemplate
|
||||
|
@ -12,7 +13,8 @@ from dev_opsgpt.connector.schema import (
|
|||
Memory, Task, Env, Role, Message, ActionStatus
|
||||
)
|
||||
from dev_opsgpt.llm_models import getChatModel
|
||||
from dev_opsgpt.connector.configs.agent_config import REACT_PROMPT_INPUT, CONTEXT_PROMPT_INPUT, QUERY_CONTEXT_PROMPT_INPUT
|
||||
from dev_opsgpt.connector.configs.prompts import BASE_PROMPT_INPUT, QUERY_CONTEXT_DOC_PROMPT_INPUT, BEGIN_PROMPT_INPUT
|
||||
from dev_opsgpt.connector.utils import parse_section
|
||||
|
||||
from .base_agent import BaseAgent
|
||||
|
||||
|
@ -33,6 +35,7 @@ class SelectorAgent(BaseAgent):
|
|||
do_use_self_memory: bool = True,
|
||||
focus_agents: List[str] = [],
|
||||
focus_message_keys: List[str] = [],
|
||||
group_agents: List[BaseAgent] = [],
|
||||
# prompt_mamnger: PromptManager
|
||||
):
|
||||
|
||||
|
@ -40,6 +43,53 @@ class SelectorAgent(BaseAgent):
|
|||
do_tool_retrieval, temperature, stop, do_filter,do_use_self_memory,
|
||||
focus_agents, focus_message_keys
|
||||
)
|
||||
self.group_agents = group_agents
|
||||
|
||||
def arun(self, query: Message, history: Memory = None, background: Memory = None, memory_pool: Memory=None) -> Message:
|
||||
'''agent reponse from multi-message'''
|
||||
# insert query into memory
|
||||
query_c = copy.deepcopy(query)
|
||||
query = self.start_action_step(query)
|
||||
self_memory = self.memory if self.do_use_self_memory else None
|
||||
# create your llm prompt
|
||||
prompt = self.create_prompt(query_c, self_memory, history, background, memory_pool=memory_pool)
|
||||
content = self.llm.predict(prompt)
|
||||
logger.debug(f"{self.role.role_name} prompt: {prompt}")
|
||||
logger.debug(f"{self.role.role_name} content: {content}")
|
||||
|
||||
# select agent
|
||||
select_message = Message(
|
||||
role_name=self.role.role_name,
|
||||
role_type="ai", #self.role.role_type,
|
||||
role_content=content,
|
||||
step_content=content,
|
||||
input_query=query_c.input_query,
|
||||
tools=query_c.tools,
|
||||
parsed_output_list=[query.parsed_output]
|
||||
)
|
||||
# common parse llm' content to message
|
||||
select_message = self.message_utils.parser(select_message)
|
||||
if self.do_filter:
|
||||
select_message = self.message_utils.filter(select_message)
|
||||
|
||||
output_message = None
|
||||
if select_message.parsed_output.get("Role", "") in [agent.role.role_name for agent in self.group_agents]:
|
||||
for agent in self.group_agents:
|
||||
if agent.role.role_name == select_message.parsed_output.get("Role", ""):
|
||||
break
|
||||
for output_message in agent.arun(query, history, background=background, memory_pool=memory_pool):
|
||||
pass
|
||||
# update self_memory
|
||||
self.append_history(query_c)
|
||||
self.append_history(output_message)
|
||||
logger.info(f"{agent.role.role_name} currenct question: {output_message.input_query}\nllm_step_run: {output_message.role_content}")
|
||||
output_message.input_query = output_message.role_content
|
||||
output_message.parsed_output_list.append(output_message.parsed_output)
|
||||
#
|
||||
output_message = self.end_action_step(output_message)
|
||||
# update memory pool
|
||||
memory_pool.append(output_message)
|
||||
yield output_message or select_message
|
||||
|
||||
def create_prompt(
|
||||
self, query: Message, memory: Memory =None, history: Memory = None, background: Memory = None, memory_pool: Memory=None, prompt_mamnger=None) -> str:
|
||||
|
@ -50,56 +100,49 @@ class SelectorAgent(BaseAgent):
|
|||
doc_infos = self.create_doc_prompt(query)
|
||||
code_infos = self.create_codedoc_prompt(query)
|
||||
#
|
||||
formatted_tools, tool_names = self.create_tools_prompt(query)
|
||||
formatted_tools, tool_names, tools_descs = self.create_tools_prompt(query)
|
||||
agent_names, agents = self.create_agent_names()
|
||||
task_prompt = self.create_task_prompt(query)
|
||||
background_prompt = self.create_background_prompt(background)
|
||||
history_prompt = self.create_history_prompt(history)
|
||||
selfmemory_prompt = self.create_selfmemory_prompt(memory, control_key="step_content")
|
||||
|
||||
# react 流程是自身迭代过程,另外二次触发的是需要作为历史对话信息
|
||||
# input_query = react_memory.to_tuple_messages(content_key="step_content")
|
||||
|
||||
DocInfos = ""
|
||||
if doc_infos is not None and doc_infos!="" and doc_infos!="不存在知识库辅助信息":
|
||||
DocInfos += f"\nDocument Information: {doc_infos}"
|
||||
|
||||
if code_infos is not None and code_infos!="" and code_infos!="不存在代码库辅助信息":
|
||||
DocInfos += f"\nCodeBase Infomation: {code_infos}"
|
||||
|
||||
input_query = query.input_query
|
||||
|
||||
# logger.debug(f"{self.role.role_name} extra_system_prompt: {self.role.role_prompt}")
|
||||
logger.debug(f"{self.role.role_name} input_query: {input_query}")
|
||||
# logger.debug(f"{self.role.role_name} doc_infos: {doc_infos}")
|
||||
# logger.debug(f"{self.role.role_name} tool_names: {tool_names}")
|
||||
# prompt += "\n" + CHECK_PROMPT_INPUT.format(**{"query": input_query})
|
||||
# prompt.format(**{"query": input_query})
|
||||
|
||||
# extra_system_prompt = self.role.role_prompt
|
||||
prompt = self.role.role_prompt.format(**{"query": input_query, "formatted_tools": formatted_tools, "tool_names": tool_names})
|
||||
prompt = self.role.role_prompt.format(**{"agent_names": agent_names, "agents": agents, "formatted_tools": tools_descs, "tool_names": tool_names})
|
||||
#
|
||||
memory_pool_select_by_agent_key = self.select_memory_by_agent_key(memory_pool)
|
||||
memory_pool_select_by_agent_key_context = '\n\n'.join([f"*{k}*\n{v}" for parsed_output in memory_pool_select_by_agent_key.get_parserd_output_list() for k, v in parsed_output.items() if k not in ['Action Status']])
|
||||
|
||||
if "**Context:**" in self.role.role_prompt:
|
||||
# logger.debug(f"parsed_output_list: {query.parsed_output_list}")
|
||||
# input_query = "'''" + "\n".join([f"*{k}*\n{v}" for i in background.get_parserd_output_list() for k,v in i.items() if "Action Status" !=k]) + "'''"
|
||||
context = "\n".join([f"*{k}*\n{v}" for i in background.get_parserd_output_list() for k,v in i.items() if "Action Status" !=k])
|
||||
# logger.debug(f"parsed_output_list: {t}")
|
||||
prompt += "\n" + QUERY_CONTEXT_PROMPT_INPUT.format(**{"query": query.origin_query, "context": context})
|
||||
else:
|
||||
prompt += "\n" + REACT_PROMPT_INPUT.format(**{"query": input_query})
|
||||
input_keys = parse_section(self.role.role_prompt, 'Input Format')
|
||||
#
|
||||
prompt += "\n" + BEGIN_PROMPT_INPUT
|
||||
for input_key in input_keys:
|
||||
if input_key == "Origin Query":
|
||||
prompt += "\n**Origin Query:**\n" + query.origin_query
|
||||
elif input_key == "Context":
|
||||
context = "\n".join([f"*{k}*\n{v}" for i in query.parsed_output_list for k,v in i.items() if "Action Status" !=k])
|
||||
if history:
|
||||
context = history_prompt + "\n" + context
|
||||
if not context:
|
||||
context = "there is no context"
|
||||
|
||||
if self.focus_agents and memory_pool_select_by_agent_key_context:
|
||||
context = memory_pool_select_by_agent_key_context
|
||||
prompt += "\n**Context:**\n" + context + "\n" + input_query
|
||||
elif input_key == "DocInfos":
|
||||
prompt += "\n**DocInfos:**\n" + DocInfos
|
||||
elif input_key == "Question":
|
||||
prompt += "\n**Question:**\n" + input_query
|
||||
|
||||
task = query.task or self.task
|
||||
if task_prompt is not None:
|
||||
prompt += "\n" + task.task_prompt
|
||||
|
||||
# if doc_infos is not None and doc_infos!="" and doc_infos!="不存在知识库辅助信息":
|
||||
# prompt += f"\n知识库信息: {doc_infos}"
|
||||
|
||||
# if code_infos is not None and code_infos!="" and code_infos!="不存在代码库辅助信息":
|
||||
# prompt += f"\n代码库信息: {code_infos}"
|
||||
|
||||
# if background_prompt:
|
||||
# prompt += "\n" + background_prompt
|
||||
|
||||
# if history_prompt:
|
||||
# prompt += "\n" + history_prompt
|
||||
|
||||
# if selfmemory_prompt:
|
||||
# prompt += "\n" + selfmemory_prompt
|
||||
|
||||
# prompt = extra_system_prompt.format(**{"query": input_query, "doc_infos": doc_infos, "formatted_tools": formatted_tools, "tool_names": tool_names})
|
||||
while "{{" in prompt or "}}" in prompt:
|
||||
prompt = prompt.replace("{{", "{")
|
||||
prompt = prompt.replace("}}", "}")
|
||||
|
@ -107,3 +150,16 @@ class SelectorAgent(BaseAgent):
|
|||
# logger.debug(f"{self.role.role_name} prompt: {prompt}")
|
||||
return prompt
|
||||
|
||||
def create_agent_names(self):
|
||||
random.shuffle(self.group_agents)
|
||||
agent_names = ", ".join([f'{agent.role.role_name}' for agent in self.group_agents])
|
||||
agent_descs = []
|
||||
for agent in self.group_agents:
|
||||
role_desc = agent.role.role_prompt.split("####")[1]
|
||||
while "\n\n" in role_desc:
|
||||
role_desc = role_desc.replace("\n\n", "\n")
|
||||
role_desc = role_desc.replace("\n", ",")
|
||||
|
||||
agent_descs.append(f'"role name: {agent.role.role_name}\nrole description: {role_desc}"')
|
||||
|
||||
return agent_names, "\n".join(agent_descs)
|
|
@ -15,10 +15,6 @@ from dev_opsgpt.connector.schema import (
|
|||
load_role_configs
|
||||
)
|
||||
from dev_opsgpt.connector.message_process import MessageUtils
|
||||
from dev_opsgpt.sandbox import PyCodeBox, CodeBoxResponse
|
||||
|
||||
|
||||
from configs.server_config import SANDBOX_SERVER
|
||||
|
||||
from dev_opsgpt.connector.configs.agent_config import AGETN_CONFIGS
|
||||
role_configs = load_role_configs(AGETN_CONFIGS)
|
||||
|
@ -31,7 +27,6 @@ class BaseChain:
|
|||
agents: List[BaseAgent],
|
||||
chat_turn: int = 1,
|
||||
do_checker: bool = False,
|
||||
do_code_exec: bool = False,
|
||||
# prompt_mamnger: PromptManager
|
||||
) -> None:
|
||||
self.chainConfig = chainConfig
|
||||
|
@ -45,29 +40,9 @@ class BaseChain:
|
|||
do_doc_retrieval = role_configs["checker"].do_doc_retrieval,
|
||||
do_tool_retrieval = role_configs["checker"].do_tool_retrieval,
|
||||
do_filter=False, do_use_self_memory=False)
|
||||
|
||||
self.do_agent_selector = False
|
||||
self.agent_selector = CheckAgent(role=role_configs["checker"].role,
|
||||
task = None,
|
||||
memory = None,
|
||||
do_search = role_configs["checker"].do_search,
|
||||
do_doc_retrieval = role_configs["checker"].do_doc_retrieval,
|
||||
do_tool_retrieval = role_configs["checker"].do_tool_retrieval,
|
||||
do_filter=False, do_use_self_memory=False)
|
||||
|
||||
self.messageUtils = MessageUtils()
|
||||
# all memory created by agent until instance deleted
|
||||
self.global_memory = Memory(messages=[])
|
||||
# self.do_code_exec = do_code_exec
|
||||
# self.codebox = PyCodeBox(
|
||||
# remote_url=SANDBOX_SERVER["url"],
|
||||
# remote_ip=SANDBOX_SERVER["host"],
|
||||
# remote_port=SANDBOX_SERVER["port"],
|
||||
# token="mytoken",
|
||||
# do_code_exe=True,
|
||||
# do_remote=SANDBOX_SERVER["do_remote"],
|
||||
# do_check_net=False
|
||||
# )
|
||||
|
||||
def step(self, query: Message, history: Memory = None, background: Memory = None, memory_pool: Memory = None) -> Message:
|
||||
'''execute chain'''
|
||||
|
@ -85,263 +60,48 @@ class BaseChain:
|
|||
self.global_memory.append(input_message)
|
||||
# local_memory.append(input_message)
|
||||
while step_nums > 0:
|
||||
|
||||
if self.do_agent_selector:
|
||||
agent_message = copy.deepcopy(query)
|
||||
agent_message.agents = self.agents
|
||||
for selectory_message in self.agent_selector.arun(query, background=self.global_memory, memory_pool=memory_pool):
|
||||
pass
|
||||
selectory_message = self.messageUtils.parser(selectory_message)
|
||||
selectory_message = self.messageUtils.filter(selectory_message)
|
||||
agent = self.agents[selectory_message.agent_index]
|
||||
# selector agent to execure next task
|
||||
for agent in self.agents:
|
||||
for output_message in agent.arun(input_message, history, background=background, memory_pool=memory_pool):
|
||||
# logger.debug(f"local_memory {local_memory + output_message}")
|
||||
yield output_message, local_memory + output_message
|
||||
|
||||
|
||||
output_message = self.messageUtils.inherit_extrainfo(input_message, output_message)
|
||||
# according the output to choose one action for code_content or tool_content
|
||||
# logger.info(f"{agent.role.role_name}\nmessage: {output_message.step_content}\nquery: {output_message.input_query}")
|
||||
# logger.info(f"{agent.role.role_name} currenct message: {output_message.step_content}\n next llm question: {output_message.input_query}")
|
||||
output_message = self.messageUtils.parser(output_message)
|
||||
yield output_message, local_memory + output_message
|
||||
# output_message = self.step_router(output_message)
|
||||
|
||||
input_message = output_message
|
||||
self.global_memory.append(output_message)
|
||||
|
||||
local_memory.append(output_message)
|
||||
# when get finished signal can stop early
|
||||
if output_message.action_status == ActionStatus.FINISHED:
|
||||
if output_message.action_status == ActionStatus.FINISHED or output_message.action_status == ActionStatus.STOPED:
|
||||
action_status = False
|
||||
break
|
||||
else:
|
||||
for agent in self.agents:
|
||||
for output_message in agent.arun(input_message, history, background=background, memory_pool=memory_pool):
|
||||
# logger.debug(f"local_memory {local_memory + output_message}")
|
||||
yield output_message, local_memory + output_message
|
||||
|
||||
output_message = self.messageUtils.inherit_extrainfo(input_message, output_message)
|
||||
# according the output to choose one action for code_content or tool_content
|
||||
# logger.info(f"{agent.role.role_name} currenct message: {output_message.step_content}\n next llm question: {output_message.input_query}")
|
||||
output_message = self.messageUtils.parser(output_message)
|
||||
yield output_message, local_memory + output_message
|
||||
# output_message = self.step_router(output_message)
|
||||
|
||||
if output_message.action_status == ActionStatus.FINISHED:
|
||||
break
|
||||
|
||||
input_message = output_message
|
||||
self.global_memory.append(output_message)
|
||||
if self.do_checker and self.chat_turn > 1:
|
||||
# logger.debug(f"{self.checker.role.role_name} input global memory: {self.global_memory.to_str_messages(content_key='step_content', return_all=False)}")
|
||||
for check_message in self.checker.arun(query, background=local_memory, memory_pool=memory_pool):
|
||||
pass
|
||||
check_message = self.messageUtils.parser(check_message)
|
||||
check_message = self.messageUtils.filter(check_message)
|
||||
check_message = self.messageUtils.inherit_extrainfo(output_message, check_message)
|
||||
logger.debug(f"{self.checker.role.role_name}: {check_message.role_content}")
|
||||
|
||||
local_memory.append(output_message)
|
||||
# when get finished signal can stop early
|
||||
if output_message.action_status == ActionStatus.FINISHED:
|
||||
action_status = False
|
||||
break
|
||||
|
||||
if self.do_checker and self.chat_turn > 1:
|
||||
# logger.debug(f"{self.checker.role.role_name} input global memory: {self.global_memory.to_str_messages(content_key='step_content', return_all=False)}")
|
||||
for check_message in self.checker.arun(query, background=local_memory, memory_pool=memory_pool):
|
||||
pass
|
||||
check_message = self.messageUtils.parser(check_message)
|
||||
check_message = self.messageUtils.filter(check_message)
|
||||
check_message = self.messageUtils.inherit_extrainfo(output_message, check_message)
|
||||
logger.debug(f"{self.checker.role.role_name}: {check_message.role_content}")
|
||||
|
||||
if check_message.action_status == ActionStatus.FINISHED:
|
||||
self.global_memory.append(check_message)
|
||||
break
|
||||
if check_message.action_status == ActionStatus.FINISHED:
|
||||
self.global_memory.append(check_message)
|
||||
break
|
||||
|
||||
step_nums -= 1
|
||||
#
|
||||
output_message = check_message or output_message # 返回chain和checker的结果
|
||||
output_message.input_query = query.input_query # chain和chain之间消息通信不改变问题
|
||||
yield output_message, local_memory
|
||||
|
||||
# def step_router(self, message: Message) -> Message:
|
||||
# ''''''
|
||||
# # message = self.parser(message)
|
||||
# # logger.debug(f"message.action_status: {message.action_status}")
|
||||
# if message.action_status == ActionStatus.CODING:
|
||||
# message = self.code_step(message)
|
||||
# elif message.action_status == ActionStatus.TOOL_USING:
|
||||
# message = self.tool_step(message)
|
||||
|
||||
# return message
|
||||
|
||||
# def code_step(self, message: Message) -> Message:
|
||||
# '''execute code'''
|
||||
# # logger.debug(f"message.role_content: {message.role_content}, message.code_content: {message.code_content}")
|
||||
# code_answer = self.codebox.chat('```python\n{}```'.format(message.code_content))
|
||||
# uid = str(uuid.uuid1())
|
||||
# if code_answer.code_exe_type == "image/png":
|
||||
# message.figures[uid] = code_answer.code_exe_response
|
||||
# message.code_answer = f"\n观察: 执行代码后获得输出一张图片, 文件名为{uid}\n"
|
||||
# message.observation = f"\n观察: 执行代码后获得输出一张图片, 文件名为{uid}\n"
|
||||
# message.step_content += f"\n观察: 执行代码后获得输出一张图片, 文件名为{uid}\n"
|
||||
# message.step_contents += [f"\n观察: 执行代码后获得输出一张图片, 文件名为{uid}\n"]
|
||||
# message.role_content += f"\n执行代码后获得输出一张图片, 文件名为{uid}\n"
|
||||
# else:
|
||||
# message.code_answer = code_answer.code_exe_response
|
||||
# message.observation = code_answer.code_exe_response
|
||||
# message.step_content += f"\n观察: 执行代码后获得输出是 {code_answer.code_exe_response}\n"
|
||||
# message.step_contents += [f"\n观察: 执行代码后获得输出是 {code_answer.code_exe_response}\n"]
|
||||
# message.role_content += f"\n观察: 执行代码后获得输出是 {code_answer.code_exe_response}\n"
|
||||
# logger.info(f"观察: {message.action_status}, {message.observation}")
|
||||
# return message
|
||||
|
||||
# def tool_step(self, message: Message) -> Message:
|
||||
# '''execute tool'''
|
||||
# # logger.debug(f"message: {message.action_status}, {message.tool_name}, {message.tool_params}")
|
||||
# tool_names = [tool.name for tool in message.tools]
|
||||
# if message.tool_name not in tool_names:
|
||||
# message.tool_answer = "不存在可以执行的tool"
|
||||
# message.observation = "不存在可以执行的tool"
|
||||
# message.role_content += f"\n观察: 不存在可以执行的tool\n"
|
||||
# message.step_content += f"\n观察: 不存在可以执行的tool\n"
|
||||
# message.step_contents += [f"\n观察: 不存在可以执行的tool\n"]
|
||||
# for tool in message.tools:
|
||||
# if tool.name == message.tool_name:
|
||||
# tool_res = tool.func(**message.tool_params)
|
||||
# message.tool_answer = tool_res
|
||||
# message.observation = tool_res
|
||||
# message.role_content += f"\n观察: {tool_res}\n"
|
||||
# message.step_content += f"\n观察: {tool_res}\n"
|
||||
# message.step_contents += [f"\n观察: {tool_res}\n"]
|
||||
# return message
|
||||
|
||||
# def filter(self, message: Message, stop=None) -> Message:
|
||||
|
||||
# tool_params = self.parser_spec_key(message.role_content, "tool_params")
|
||||
# code_content = self.parser_spec_key(message.role_content, "code_content")
|
||||
# plan = self.parser_spec_key(message.role_content, "plan")
|
||||
# plans = self.parser_spec_key(message.role_content, "plans", do_search=False)
|
||||
# content = self.parser_spec_key(message.role_content, "content", do_search=False)
|
||||
|
||||
# # logger.debug(f"tool_params: {tool_params}, code_content: {code_content}, plan: {plan}, plans: {plans}, content: {content}")
|
||||
# role_content = tool_params or code_content or plan or plans or content
|
||||
# message.role_content = role_content or message.role_content
|
||||
# return message
|
||||
|
||||
# def parser(self, message: Message) -> Message:
|
||||
# ''''''
|
||||
# content = message.role_content
|
||||
# parser_keys = ["action", "code_content", "code_filename", "tool_params", "plans"]
|
||||
# try:
|
||||
# s_json = self._parse_json(content)
|
||||
# message.action_status = s_json.get("action")
|
||||
# message.code_content = s_json.get("code_content")
|
||||
# message.tool_params = s_json.get("tool_params")
|
||||
# message.tool_name = s_json.get("tool_name")
|
||||
# message.code_filename = s_json.get("code_filename")
|
||||
# message.plans = s_json.get("plans")
|
||||
# # for parser_key in parser_keys:
|
||||
# # message.action_status = content.get(parser_key)
|
||||
# except Exception as e:
|
||||
# # logger.warning(f"{traceback.format_exc()}")
|
||||
# def parse_text_to_dict(text):
|
||||
# # Define a regular expression pattern to capture the key and value
|
||||
# main_pattern = r"\*\*(.+?):\*\*\s*(.*?)\s*(?=\*\*|$)"
|
||||
# code_pattern = r'```python\n(.*?)```'
|
||||
|
||||
# # Use re.findall to find all main matches in the text
|
||||
# main_matches = re.findall(main_pattern, text, re.DOTALL)
|
||||
|
||||
# # Convert main matches to a dictionary
|
||||
# parsed_dict = {key.strip(): value.strip() for key, value in main_matches}
|
||||
|
||||
# # Search for the code block
|
||||
# code_match = re.search(code_pattern, text, re.DOTALL)
|
||||
# if code_match:
|
||||
# # Add the code block to the dictionary
|
||||
# parsed_dict['code'] = code_match.group(1).strip()
|
||||
|
||||
# return parsed_dict
|
||||
|
||||
# parsed_dict = parse_text_to_dict(content)
|
||||
# action_value = parsed_dict.get('Action Status')
|
||||
# if action_value:
|
||||
# action_value = action_value.lower()
|
||||
# logger.debug(f'action_value: {action_value}')
|
||||
# # action_value = self._match(r"'action':\s*'([^']*)'", content) if "'action'" in content else self._match(r'"action":\s*"([^"]*)"', content)
|
||||
|
||||
# code_content_value = parsed_dict.get('code')
|
||||
# # code_content_value = self._match(r"'code_content':\s*'([^']*)'", content) if "'code_content'" in content else self._match(r'"code_content":\s*"([^"]*)"', content)
|
||||
# filename_value = self._match(r"'code_filename':\s*'([^']*)'", content) if "'code_filename'" in content else self._match(r'"code_filename":\s*"([^"]*)"', content)
|
||||
# tool_params_value = self._match(r"'tool_params':\s*(\{[^{}]*\})", content, do_json=True) if "'tool_params'" in content \
|
||||
# else self._match(r'"tool_params":\s*(\{[^{}]*\})', content, do_json=True)
|
||||
# tool_name_value = self._match(r"'tool_name':\s*'([^']*)'", content) if "'tool_name'" in content else self._match(r'"tool_name":\s*"([^"]*)"', content)
|
||||
# plans_value = self._match(r"'plans':\s*(\[.*?\])", content, do_search=False) if "'plans'" in content else self._match(r'"plans":\s*(\[.*?\])', content, do_search=False, )
|
||||
# # re解析
|
||||
# message.action_status = action_value or "default"
|
||||
# message.code_content = code_content_value
|
||||
# message.code_filename = filename_value
|
||||
# message.tool_params = tool_params_value
|
||||
# message.tool_name = tool_name_value
|
||||
# message.plans = plans_value
|
||||
|
||||
# # logger.debug(f"确认当前的action: {message.action_status}")
|
||||
|
||||
# return message
|
||||
|
||||
# def parser_spec_key(self, content, key, do_search=True, do_json=False) -> str:
|
||||
# ''''''
|
||||
# key2pattern = {
|
||||
# "'action'": r"'action':\s*'([^']*)'", '"action"': r'"action":\s*"([^"]*)"',
|
||||
# "'code_content'": r"'code_content':\s*'([^']*)'", '"code_content"': r'"code_content":\s*"([^"]*)"',
|
||||
# "'code_filename'": r"'code_filename':\s*'([^']*)'", '"code_filename"': r'"code_filename":\s*"([^"]*)"',
|
||||
# "'tool_params'": r"'tool_params':\s*(\{[^{}]*\})", '"tool_params"': r'"tool_params":\s*(\{[^{}]*\})',
|
||||
# "'tool_name'": r"'tool_name':\s*'([^']*)'", '"tool_name"': r'"tool_name":\s*"([^"]*)"',
|
||||
# "'plans'": r"'plans':\s*(\[.*?\])", '"plans"': r'"plans":\s*(\[.*?\])',
|
||||
# "'content'": r"'content':\s*'([^']*)'", '"content"': r'"content":\s*"([^"]*)"',
|
||||
# }
|
||||
|
||||
# s_json = self._parse_json(content)
|
||||
# try:
|
||||
# if s_json and key in s_json:
|
||||
# return str(s_json[key])
|
||||
# except:
|
||||
# pass
|
||||
|
||||
# keystr = f"'{key}'" if f"'{key}'" in content else f'"{key}"'
|
||||
# return self._match(key2pattern.get(keystr, fr"'{key}':\s*'([^']*)'"), content, do_search=do_search, do_json=do_json)
|
||||
|
||||
# def _match(self, pattern, s, do_search=True, do_json=False):
|
||||
# try:
|
||||
# if do_search:
|
||||
# match = re.search(pattern, s)
|
||||
# if match:
|
||||
# value = match.group(1).replace("\\n", "\n")
|
||||
# if do_json:
|
||||
# value = json.loads(value)
|
||||
# else:
|
||||
# value = None
|
||||
# else:
|
||||
# match = re.findall(pattern, s, re.DOTALL)
|
||||
# if match:
|
||||
# value = match[0]
|
||||
# if do_json:
|
||||
# value = json.loads(value)
|
||||
# else:
|
||||
# value = None
|
||||
# except Exception as e:
|
||||
# logger.warning(f"{traceback.format_exc()}")
|
||||
|
||||
# # logger.debug(f"pattern: {pattern}, s: {s}, match: {match}")
|
||||
# return value
|
||||
|
||||
# def _parse_json(self, s):
|
||||
# try:
|
||||
# pattern = r"```([^`]+)```"
|
||||
# match = re.findall(pattern, s)
|
||||
# if match:
|
||||
# return eval(match[0])
|
||||
# except:
|
||||
# pass
|
||||
# return None
|
||||
|
||||
# def inherit_extrainfo(self, input_message: Message, output_message: Message):
|
||||
# output_message.db_docs = input_message.db_docs
|
||||
# output_message.search_docs = input_message.search_docs
|
||||
# output_message.code_docs = input_message.code_docs
|
||||
# output_message.origin_query = input_message.origin_query
|
||||
# output_message.figures.update(input_message.figures)
|
||||
# return output_message
|
||||
|
||||
def get_memory(self, content_key="role_content") -> Memory:
|
||||
memory = self.global_memory
|
||||
|
|
|
@ -7,6 +7,7 @@ from .prompts import (
|
|||
QA_PROMPT, CODE_QA_PROMPT, QA_TEMPLATE_PROMPT,
|
||||
EXECUTOR_TEMPLATE_PROMPT,
|
||||
REFINE_TEMPLATE_PROMPT,
|
||||
SELECTOR_AGENT_TEMPLATE_PROMPT,
|
||||
PLANNER_TEMPLATE_PROMPT, GENERAL_PLANNER_PROMPT, DATA_PLANNER_PROMPT, TOOL_PLANNER_PROMPT,
|
||||
PRD_WRITER_METAGPT_PROMPT, DESIGN_WRITER_METAGPT_PROMPT, TASK_WRITER_METAGPT_PROMPT, CODE_WRITER_METAGPT_PROMPT,
|
||||
REACT_TEMPLATE_PROMPT,
|
||||
|
@ -20,10 +21,25 @@ class AgentType:
|
|||
EXECUTOR = "ExecutorAgent"
|
||||
ONE_STEP = "BaseAgent"
|
||||
DEFAULT = "BaseAgent"
|
||||
SELECTOR = "SelectorAgent"
|
||||
|
||||
|
||||
|
||||
AGETN_CONFIGS = {
|
||||
"baseGroup": {
|
||||
"role": {
|
||||
"role_prompt": SELECTOR_AGENT_TEMPLATE_PROMPT,
|
||||
"role_type": "assistant",
|
||||
"role_name": "baseGroup",
|
||||
"role_desc": "",
|
||||
"agent_type": "SelectorAgent"
|
||||
},
|
||||
"group_agents": ["tool_react", "code_react"],
|
||||
"chat_turn": 1,
|
||||
"do_search": False,
|
||||
"do_doc_retrieval": False,
|
||||
"do_tool_retrieval": False
|
||||
},
|
||||
"checker": {
|
||||
"role": {
|
||||
"role_prompt": CHECKER_TEMPLATE_PROMPT,
|
||||
|
|
|
@ -108,4 +108,20 @@ CHAIN_CONFIGS = {
|
|||
"do_checker": False,
|
||||
"chain_prompt": ""
|
||||
},
|
||||
"baseGroupChain": {
|
||||
"chain_name": "baseGroupChain",
|
||||
"chain_type": "BaseChain",
|
||||
"agents": ["baseGroup"],
|
||||
"chat_turn": 1,
|
||||
"do_checker": False,
|
||||
"chain_prompt": ""
|
||||
},
|
||||
"codeChatXXChain": {
|
||||
"chain_name": "codeChatXXChain",
|
||||
"chain_type": "BaseChain",
|
||||
"agents": ["codeChat1", "codeChat2"],
|
||||
"chat_turn": 1,
|
||||
"do_checker": False,
|
||||
"chain_prompt": ""
|
||||
}
|
||||
}
|
||||
|
|
|
@ -88,15 +88,26 @@ PHASE_CONFIGS = {
|
|||
"do_tool_retrieval": False,
|
||||
"do_using_tool": False
|
||||
},
|
||||
"metagpt_code_devlop": {
|
||||
"phase_name": "metagpt_code_devlop",
|
||||
"phase_type": "BasePhase",
|
||||
"chains": ["metagptChain",],
|
||||
"do_summary": False,
|
||||
"do_search": False,
|
||||
"do_doc_retrieval": False,
|
||||
"do_code_retrieval": False,
|
||||
"do_tool_retrieval": False,
|
||||
"do_using_tool": False
|
||||
},
|
||||
# "metagpt_code_devlop": {
|
||||
# "phase_name": "metagpt_code_devlop",
|
||||
# "phase_type": "BasePhase",
|
||||
# "chains": ["metagptChain",],
|
||||
# "do_summary": False,
|
||||
# "do_search": False,
|
||||
# "do_doc_retrieval": False,
|
||||
# "do_code_retrieval": False,
|
||||
# "do_tool_retrieval": False,
|
||||
# "do_using_tool": False
|
||||
# },
|
||||
# "baseGroupPhase": {
|
||||
# "phase_name": "baseGroupPhase",
|
||||
# "phase_type": "BasePhase",
|
||||
# "chains": ["baseGroupChain"],
|
||||
# "do_summary": False,
|
||||
# "do_search": False,
|
||||
# "do_doc_retrieval": False,
|
||||
# "do_code_retrieval": False,
|
||||
# "do_tool_retrieval": False,
|
||||
# "do_using_tool": False
|
||||
# },
|
||||
}
|
||||
|
|
|
@ -15,6 +15,8 @@ from .qa_template_prompt import QA_PROMPT, CODE_QA_PROMPT, QA_TEMPLATE_PROMPT
|
|||
from .executor_template_prompt import EXECUTOR_TEMPLATE_PROMPT
|
||||
from .refine_template_prompt import REFINE_TEMPLATE_PROMPT
|
||||
|
||||
from .agent_selector_template_prompt import SELECTOR_AGENT_TEMPLATE_PROMPT
|
||||
|
||||
from .react_template_prompt import REACT_TEMPLATE_PROMPT
|
||||
from .react_code_prompt import REACT_CODE_PROMPT
|
||||
from .react_tool_prompt import REACT_TOOL_PROMPT
|
||||
|
@ -32,6 +34,7 @@ __all__ = [
|
|||
"QA_PROMPT", "CODE_QA_PROMPT", "QA_TEMPLATE_PROMPT",
|
||||
"EXECUTOR_TEMPLATE_PROMPT",
|
||||
"REFINE_TEMPLATE_PROMPT",
|
||||
"SELECTOR_AGENT_TEMPLATE_PROMPT",
|
||||
"PLANNER_TEMPLATE_PROMPT", "GENERAL_PLANNER_PROMPT", "DATA_PLANNER_PROMPT", "TOOL_PLANNER_PROMPT",
|
||||
"REACT_TEMPLATE_PROMPT",
|
||||
"REACT_CODE_PROMPT", "REACT_TOOL_PROMPT", "REACT_TOOL_AND_CODE_PROMPT", "REACT_TOOL_AND_CODE_PLANNER_PROMPT"
|
||||
|
|
|
@ -0,0 +1,24 @@
|
|||
SELECTOR_AGENT_TEMPLATE_PROMPT = """#### Role Selector Assistance Guidance
|
||||
|
||||
Your goal is to match the user's initial Origin Query) with the role that will best facilitate a solution, taking into account all relevant context (Context) provided.
|
||||
|
||||
When you need to select the appropriate role for handling a user's query, carefully read the provided role names, role descriptions and tool list.
|
||||
|
||||
You can use these tools:\n{formatted_tools}
|
||||
|
||||
Please ensure your selection is one of the listed roles. Available roles for selection:
|
||||
{agents}
|
||||
|
||||
#### Input Format
|
||||
|
||||
**Origin Query:** the initial question or objective that the user wanted to achieve
|
||||
|
||||
**Context:** the context history to determine if Origin Query has been achieved.
|
||||
|
||||
#### Response Output Format
|
||||
|
||||
**Thoughts:** think the reason of selecting the role step by step
|
||||
|
||||
**Role:** Select the role name. such as {agent_names}
|
||||
|
||||
"""
|
|
@ -12,12 +12,12 @@ Each decision should be justified based on the context provided, specifying if t
|
|||
**Context:** the current status and history of the tasks to determine if Origin Query has been achieved.
|
||||
|
||||
#### Response Output Format
|
||||
**REASON:** Justify the decision of choosing 'finished' and 'continued' by evaluating the progress step by step.
|
||||
Consider all relevant information. If the tasks were aimed at an ongoing process, assess whether it has reached a satisfactory conclusion.
|
||||
|
||||
**Action Status:** Set to 'finished' or 'continued'.
|
||||
If it's 'finished', the context can answer the origin query.
|
||||
If it's 'continued', the context cant answer the origin query.
|
||||
|
||||
**REASON:** Justify the decision of choosing 'finished' and 'continued' by evaluating the progress step by step.
|
||||
Consider all relevant information. If the tasks were aimed at an ongoing process, assess whether it has reached a satisfactory conclusion.
|
||||
"""
|
||||
|
||||
CHECKER_PROMPT = """尽可能地以有帮助和准确的方式回应人类,判断问题是否得到解答,同时展现解答的过程和内容。
|
||||
|
|
|
@ -12,7 +12,7 @@ Each reply should contain only the code required for the current step.
|
|||
|
||||
**Thoughts:** Based on the question and observations above, provide the plan for executing this step.
|
||||
|
||||
**Action Status:** Set to 'finished' or 'coding'. If it's 'finished', the next action is to provide the final answer to the original question. If it's 'coding', the next step is to write the code.
|
||||
**Action Status:** Set to 'stoped' or 'code_executing'. If it's 'stoped', the next action is to provide the final answer to the original question. If it's 'code_executing', the next step is to write the code.
|
||||
|
||||
**Action:** Code according to your thoughts. Use this format for code:
|
||||
|
||||
|
@ -26,7 +26,7 @@ Each reply should contain only the code required for the current step.
|
|||
|
||||
**Thoughts:** I now know the final answer
|
||||
|
||||
**Action Status:** Set to 'finished'
|
||||
**Action Status:** Set to 'stoped'
|
||||
|
||||
**Action:** The final answer to the original input question
|
||||
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
QA_TEMPLATE_PROMPT = """#### Question Answer Assistance Guidance
|
||||
|
||||
Based on the information provided, please answer the origin query concisely and professionally.
|
||||
If the answer cannot be derived from the given Context and DocInfos, please say 'The question cannot be answered based on the information provided' and do not add any fabricated elements to the answer.
|
||||
Attention: Follow the input format and response output format
|
||||
|
||||
#### Input Format
|
||||
|
@ -13,14 +12,15 @@ Attention: Follow the input format and response output format
|
|||
**DocInfos:**: the relevant doc information or code information, if this is empty, don't refer to this.
|
||||
|
||||
#### Response Output Format
|
||||
**Action Status:** Set to 'Continued' or 'Stopped'.
|
||||
**Answer:** Response to the user's origin query based on Context and DocInfos. If DocInfos is empty, you can ignore it.
|
||||
If the answer cannot be derived from the given Context and DocInfos, please say 'The question cannot be answered based on the information provided' and do not add any fabricated elements to the answer.
|
||||
"""
|
||||
|
||||
|
||||
CODE_QA_PROMPT = """#### Code Answer Assistance Guidance
|
||||
|
||||
Based on the information provided, please answer the origin query concisely and professionally.
|
||||
If the answer cannot be derived from the given Context and DocInfos, please say 'The question cannot be answered based on the information provided' and do not add any fabricated elements to the answer.
|
||||
Attention: Follow the input format and response output format
|
||||
|
||||
#### Input Format
|
||||
|
@ -30,7 +30,9 @@ Attention: Follow the input format and response output format
|
|||
**DocInfos:**: the relevant doc information or code information, if this is empty, don't refer to this.
|
||||
|
||||
#### Response Output Format
|
||||
**Answer:** Response to the user's origin query based on DocInfos. If DocInfos is empty, you can ignore it.
|
||||
**Action Status:** Set to 'Continued' or 'Stopped'.
|
||||
**Answer:** Response to the user's origin query based on Context and DocInfos. If DocInfos is empty, you can ignore it.
|
||||
If the answer cannot be derived from the given Context and DocInfos, please say 'The question cannot be answered based on the information provided' and do not add any fabricated elements to the answer.
|
||||
"""
|
||||
|
||||
|
||||
|
|
|
@ -2,7 +2,9 @@
|
|||
|
||||
REACT_CODE_PROMPT = """#### Writing Code Assistance Guidance
|
||||
|
||||
When users need help with coding, your role is to provide precise and effective guidance. Write the code step by step, showing only the part necessary to solve the current problem. Each reply should contain only the code required for the current step.
|
||||
When users need help with coding, your role is to provide precise and effective guidance.
|
||||
|
||||
Write the code step by step, showing only the part necessary to solve the current problem. Each reply should contain only the code required for the current step.
|
||||
|
||||
#### Response Process
|
||||
|
||||
|
@ -10,12 +12,13 @@ When users need help with coding, your role is to provide precise and effective
|
|||
|
||||
**Thoughts:** Based on the question and observations above, provide the plan for executing this step.
|
||||
|
||||
**Action Status:** Set to 'finished' or 'coding'. If it's 'finished', the next action is to provide the final answer to the original question. If it's 'coding', the next step is to write the code.
|
||||
|
||||
**Action:** Code according to your thoughts. (Please note that only the content printed out by the executed code can be observed in the subsequent observation.) Use this format for code:
|
||||
**Action Status:** Set to 'stoped' or 'code_executing'. If it's 'stoped', the action is to provide the final answer to the original question. If it's 'code_executing', the action is to write the code.
|
||||
|
||||
**Action:**
|
||||
```python
|
||||
# Write your code here
|
||||
import os
|
||||
...
|
||||
```
|
||||
|
||||
**Observation:** Check the results and effects of the executed code.
|
||||
|
@ -24,7 +27,7 @@ When users need help with coding, your role is to provide precise and effective
|
|||
|
||||
**Thoughts:** I now know the final answer
|
||||
|
||||
**Action Status:** Set to 'finished'
|
||||
**Action Status:** Set to 'stoped'
|
||||
|
||||
**Action:** The final answer to the original input question
|
||||
|
||||
|
|
|
@ -10,7 +10,7 @@ When users need help with coding, your role is to provide precise and effective
|
|||
|
||||
**Thoughts:** Based on the question and observations above, provide the plan for executing this step.
|
||||
|
||||
**Action Status:** Set to 'finished' or 'coding'. If it's 'finished', the next action is to provide the final answer to the original question. If it's 'coding', the next step is to write the code.
|
||||
**Action Status:** Set to 'stoped' or 'code_executing'. If it's 'stoped', the next action is to provide the final answer to the original question. If it's 'code_executing', the next step is to write the code.
|
||||
|
||||
**Action:** Code according to your thoughts. Use this format for code:
|
||||
|
||||
|
@ -24,7 +24,7 @@ When users need help with coding, your role is to provide precise and effective
|
|||
|
||||
**Thoughts:** I now know the final answer
|
||||
|
||||
**Action Status:** Set to 'finished'
|
||||
**Action Status:** Set to 'stoped'
|
||||
|
||||
**Action:** The final answer to the original input question
|
||||
|
||||
|
|
|
@ -1,25 +1,24 @@
|
|||
REACT_TOOL_AND_CODE_PLANNER_PROMPT = """#### Tool and Code Sequence Breakdown Assistant
|
||||
When users need assistance with deconstructing problems into a series of actionable plans using tools or code, your role is to provide a structured plan or a direct solution.
|
||||
REACT_TOOL_AND_CODE_PLANNER_PROMPT = """#### Planner Assistance Guidance
|
||||
When users seek assistance in breaking down complex issues into manageable and actionable steps,
|
||||
your responsibility is to deliver a well-organized strategy or resolution through the use of tools or coding.
|
||||
|
||||
ATTENTION: response carefully referenced "Response Output Format" in format.
|
||||
|
||||
You may use the following tools:
|
||||
{formatted_tools}
|
||||
Depending on the user's query, the response will either be a plan detailing the use of tools and reasoning, or a direct answer if the problem does not require breaking down.
|
||||
|
||||
#### Input Format
|
||||
|
||||
**Origin Query:** user's query
|
||||
**Question:** First, clarify the problem to be solved.
|
||||
|
||||
#### Follow this Response Format
|
||||
#### Response Output Format
|
||||
|
||||
**Action Status:** Set to 'planning' to provide a sequence of tasks, or 'only_answer' to provide a direct response without a plan.
|
||||
|
||||
**Action:**
|
||||
|
||||
For planning:
|
||||
**Action:**
|
||||
```list
|
||||
[
|
||||
"First step of the plan using a specified tool or a outline plan for code...",
|
||||
"Next step in the plan...",
|
||||
// Continue with additional steps as necessary
|
||||
"First, we should ...",
|
||||
]
|
||||
```
|
||||
|
||||
|
|
|
@ -12,14 +12,16 @@ Valid "tool_name" value:\n{tool_names}
|
|||
|
||||
**Question:** Start by understanding the input question to be answered.
|
||||
|
||||
**Thoughts:** Considering the user's question, previously executed steps, and the plan, decide whether the current step requires the use of a tool or coding. Solve the problem step by step, only displaying the thought process necessary for the current step of solving the problem. If a tool can be used, provide its name and parameters. If coding is required, outline the plan for executing this step.
|
||||
**Thoughts:** Considering the user's question, previously executed steps, and the plan, decide whether the current step requires the use of a tool or code_executing. Solve the problem step by step, only displaying the thought process necessary for the current step of solving the problem. If a tool can be used, provide its name and parameters. If code_executing is required, outline the plan for executing this step.
|
||||
|
||||
**Action Status:** finished, tool_using, or coding. (Choose one from these three statuses. If the task is done, set it to 'finished'. If using a tool, set it to 'tool_using'. If writing code, set it to 'coding'.)
|
||||
**Action Status:** stoped, tool_using, or code_executing. (Choose one from these three statuses.)
|
||||
If the task is done, set it to 'stoped'.
|
||||
If using a tool, set it to 'tool_using'.
|
||||
If writing code, set it to 'code_executing'.
|
||||
|
||||
**Action:**
|
||||
|
||||
If using a tool, output the following format to call the tool:
|
||||
|
||||
If using a tool, use the tools by formatting the tool action in JSON from Question and Observation:. The format should be:
|
||||
```json
|
||||
{{
|
||||
"tool_name": "$TOOL_NAME",
|
||||
|
@ -30,7 +32,7 @@ If using a tool, output the following format to call the tool:
|
|||
If the problem cannot be solved with a tool at the moment, then proceed to solve the issue using code. Output the following format to execute the code:
|
||||
|
||||
```python
|
||||
# Write your code here
|
||||
Write your code here
|
||||
```
|
||||
|
||||
**Observation:** Check the results and effects of the executed action.
|
||||
|
@ -39,7 +41,7 @@ If the problem cannot be solved with a tool at the moment, then proceed to solve
|
|||
|
||||
**Thoughts:** Conclude the final response to the input question.
|
||||
|
||||
**Action Status:** finished
|
||||
**Action Status:** stoped
|
||||
|
||||
**Action:** The final answer or guidance to the original input question.
|
||||
"""
|
||||
|
@ -47,7 +49,7 @@ If the problem cannot be solved with a tool at the moment, then proceed to solve
|
|||
# REACT_TOOL_AND_CODE_PROMPT = """你是一个使用工具与代码的助手。
|
||||
# 如果现有工具不足以完成整个任务,请不要添加不存在的工具,只使用现有工具完成可能的部分。
|
||||
# 如果当前步骤不能使用工具完成,将由代码来完成。
|
||||
# 有效的"action"值为:"finished"(已经完成用户的任务) 、 "tool_using" (使用工具来回答问题) 或 'coding'(结合总结下述思维链过程编写下一步的可执行代码)。
|
||||
# 有效的"action"值为:"stoped"(已经完成用户的任务) 、 "tool_using" (使用工具来回答问题) 或 'code_executing'(结合总结下述思维链过程编写下一步的可执行代码)。
|
||||
# 尽可能地以有帮助和准确的方式回应人类,你可以使用以下工具:
|
||||
# {formatted_tools}
|
||||
# 如果现在的步骤可以用工具解决问题,请仅在每个$JSON_BLOB中提供一个action,如下所示:
|
||||
|
|
|
@ -16,7 +16,7 @@ valid "tool_name" value is:\n{tool_names}
|
|||
|
||||
**Thoughts:** Based on the question and previous observations, plan the approach for using the tool effectively.
|
||||
|
||||
**Action Status:** Set to either 'finished' or 'tool_using'. If 'finished', provide the final response to the original question. If 'tool_using', proceed with using the specified tool.
|
||||
**Action Status:** Set to either 'stoped' or 'tool_using'. If 'stoped', provide the final response to the original question. If 'tool_using', proceed with using the specified tool.
|
||||
|
||||
**Action:** Use the tools by formatting the tool action in JSON. The format should be:
|
||||
|
||||
|
@ -33,7 +33,7 @@ valid "tool_name" value is:\n{tool_names}
|
|||
|
||||
**Thoughts:** Determine the final response based on the results.
|
||||
|
||||
**Action Status:** Set to 'finished'
|
||||
**Action Status:** Set to 'stoped'
|
||||
|
||||
**Action:** Conclude with the final response to the original question in this format:
|
||||
|
||||
|
@ -49,7 +49,7 @@ valid "tool_name" value is:\n{tool_names}
|
|||
# REACT_TOOL_PROMPT = """尽可能地以有帮助和准确的方式回应人类。您可以使用以下工具:
|
||||
# {formatted_tools}
|
||||
# 使用json blob来指定一个工具,提供一个action关键字(工具名称)和一个tool_params关键字(工具输入)。
|
||||
# 有效的"action"值为:"finished" 或 "tool_using" (使用工具来回答问题)
|
||||
# 有效的"action"值为:"stoped" 或 "tool_using" (使用工具来回答问题)
|
||||
# 有效的"tool_name"值为:{tool_names}
|
||||
# 请仅在每个$JSON_BLOB中提供一个action,如下所示:
|
||||
# ```
|
||||
|
@ -73,7 +73,7 @@ valid "tool_name" value is:\n{tool_names}
|
|||
# 行动:
|
||||
# ```
|
||||
# {{{{
|
||||
# "action": "finished",
|
||||
# "action": "stoped",
|
||||
# "tool_name": "notool",
|
||||
# "tool_params": "最终返回答案给到用户"
|
||||
# }}}}
|
||||
|
|
|
@ -43,6 +43,38 @@ class MessageUtils:
|
|||
output_message.code_docs = input_message.code_docs
|
||||
output_message.figures.update(input_message.figures)
|
||||
output_message.origin_query = input_message.origin_query
|
||||
output_message.code_engine_name = input_message.code_engine_name
|
||||
|
||||
output_message.doc_engine_name = input_message.doc_engine_name
|
||||
output_message.search_engine_name = input_message.search_engine_name
|
||||
output_message.top_k = input_message.top_k
|
||||
output_message.score_threshold = input_message.score_threshold
|
||||
output_message.cb_search_type = input_message.cb_search_type
|
||||
output_message.do_doc_retrieval = input_message.do_doc_retrieval
|
||||
output_message.do_code_retrieval = input_message.do_code_retrieval
|
||||
output_message.do_tool_retrieval = input_message.do_tool_retrieval
|
||||
#
|
||||
output_message.tools = input_message.tools
|
||||
output_message.agents = input_message.agents
|
||||
# 存在bug导致相同key被覆盖
|
||||
output_message.customed_kargs.update(input_message.customed_kargs)
|
||||
return output_message
|
||||
|
||||
def inherit_baseparam(self, input_message: Message, output_message: Message):
|
||||
# 只更新参数
|
||||
output_message.doc_engine_name = input_message.doc_engine_name
|
||||
output_message.search_engine_name = input_message.search_engine_name
|
||||
output_message.top_k = input_message.top_k
|
||||
output_message.score_threshold = input_message.score_threshold
|
||||
output_message.cb_search_type = input_message.cb_search_type
|
||||
output_message.do_doc_retrieval = input_message.do_doc_retrieval
|
||||
output_message.do_code_retrieval = input_message.do_code_retrieval
|
||||
output_message.do_tool_retrieval = input_message.do_tool_retrieval
|
||||
#
|
||||
output_message.tools = input_message.tools
|
||||
output_message.agents = input_message.agents
|
||||
# 存在bug导致相同key被覆盖
|
||||
output_message.customed_kargs.update(input_message.customed_kargs)
|
||||
return output_message
|
||||
|
||||
def get_extrainfo_step(self, message: Message, do_search, do_doc_retrieval, do_code_retrieval, do_tool_retrieval) -> Message:
|
||||
|
@ -92,18 +124,22 @@ class MessageUtils:
|
|||
def get_tool_retrieval(self, message: Message) -> Message:
|
||||
return message
|
||||
|
||||
def step_router(self, message: Message) -> tuple[Message, ...]:
|
||||
def step_router(self, message: Message, history: Memory = None, background: Memory = None, memory_pool: Memory=None) -> tuple[Message, ...]:
|
||||
''''''
|
||||
# message = self.parser(message)
|
||||
# logger.debug(f"message.action_status: {message.action_status}")
|
||||
observation_message = None
|
||||
if message.action_status == ActionStatus.CODING:
|
||||
if message.action_status == ActionStatus.CODE_EXECUTING:
|
||||
message, observation_message = self.code_step(message)
|
||||
elif message.action_status == ActionStatus.TOOL_USING:
|
||||
message, observation_message = self.tool_step(message)
|
||||
elif message.action_status == ActionStatus.CODING2FILE:
|
||||
self.save_code2file(message)
|
||||
|
||||
elif message.action_status == ActionStatus.CODE_RETRIEVAL:
|
||||
pass
|
||||
elif message.action_status == ActionStatus.CODING:
|
||||
pass
|
||||
|
||||
return message, observation_message
|
||||
|
||||
def code_step(self, message: Message) -> Message:
|
||||
|
@ -126,7 +162,6 @@ class MessageUtils:
|
|||
message.code_answer = f"\n**Observation:**: The return figure name is {uid} after executing the above code.\n"
|
||||
message.observation = f"\n**Observation:**: The return figure name is {uid} after executing the above code.\n"
|
||||
message.step_content += f"\n**Observation:**: The return figure name is {uid} after executing the above code.\n"
|
||||
message.step_contents += [f"\n**Observation:**:The return figure name is {uid} after executing the above code.\n"]
|
||||
# message.role_content += f"\n**Observation:**:执行上述代码后生成一张图片, 图片名为{uid}\n"
|
||||
observation_message.role_content = f"\n**Observation:**: The return figure name is {uid} after executing the above code.\n"
|
||||
observation_message.parsed_output = {"Observation": f"The return figure name is {uid} after executing the above code."}
|
||||
|
@ -134,7 +169,6 @@ class MessageUtils:
|
|||
message.code_answer = code_answer.code_exe_response
|
||||
message.observation = code_answer.code_exe_response
|
||||
message.step_content += f"\n**Observation:**: {code_prompt}\n"
|
||||
message.step_contents += [f"\n**Observation:**: {code_prompt}\n"]
|
||||
# message.role_content += f"\n**Observation:**: {code_prompt}\n"
|
||||
observation_message.role_content = f"\n**Observation:**: {code_prompt}\n"
|
||||
observation_message.parsed_output = {"Observation": code_prompt}
|
||||
|
@ -159,7 +193,6 @@ class MessageUtils:
|
|||
message.observation = "\n**Observation:** there is no tool can execute\n"
|
||||
# message.role_content += f"\n**Observation:**: 不存在可以执行的tool\n"
|
||||
message.step_content += f"\n**Observation:** there is no tool can execute\n"
|
||||
message.step_contents += [f"\n**Observation:** there is no tool can execute\n"]
|
||||
observation_message.role_content = f"\n**Observation:** there is no tool can execute\n"
|
||||
observation_message.parsed_output = {"Observation": "there is no tool can execute\n"}
|
||||
for tool in message.tools:
|
||||
|
@ -170,7 +203,6 @@ class MessageUtils:
|
|||
message.observation = tool_res
|
||||
# message.role_content += f"\n**Observation:**: {tool_res}\n"
|
||||
message.step_content += f"\n**Observation:** {tool_res}\n"
|
||||
message.step_contents += [f"\n**Observation:** {tool_res}\n"]
|
||||
observation_message.role_content = f"\n**Observation:** {tool_res}\n"
|
||||
observation_message.parsed_output = {"Observation": tool_res}
|
||||
break
|
||||
|
|
|
@ -5,12 +5,12 @@ import importlib
|
|||
import copy
|
||||
from loguru import logger
|
||||
|
||||
from dev_opsgpt.connector.agents import BaseAgent
|
||||
from dev_opsgpt.connector.agents import BaseAgent, SelectorAgent
|
||||
from dev_opsgpt.connector.chains import BaseChain
|
||||
from dev_opsgpt.tools.base_tool import BaseTools, Tool
|
||||
|
||||
from dev_opsgpt.connector.schema import (
|
||||
Memory, Task, Env, Role, Message, Doc, Docs, AgentConfig, ChainConfig, PhaseConfig, CodeDoc,
|
||||
Memory, Task, Env, Role, Message, Doc, AgentConfig, ChainConfig, PhaseConfig, CodeDoc,
|
||||
load_chain_configs, load_phase_configs, load_role_configs
|
||||
)
|
||||
from dev_opsgpt.connector.configs import AGETN_CONFIGS, CHAIN_CONFIGS, PHASE_CONFIGS
|
||||
|
@ -156,12 +156,31 @@ class BasePhase:
|
|||
focus_agents=agent_config.focus_agents,
|
||||
focus_message_keys=agent_config.focus_message_keys,
|
||||
)
|
||||
|
||||
if agent_config.role.agent_type == "SelectorAgent":
|
||||
for group_agent_name in agent_config.group_agents:
|
||||
group_agent_config = role_configs[group_agent_name]
|
||||
baseAgent: BaseAgent = getattr(self.agent_module, group_agent_config.role.agent_type)
|
||||
group_base_agent = baseAgent(
|
||||
group_agent_config.role,
|
||||
task = task,
|
||||
memory = memory,
|
||||
chat_turn=group_agent_config.chat_turn,
|
||||
do_search = group_agent_config.do_search,
|
||||
do_doc_retrieval = group_agent_config.do_doc_retrieval,
|
||||
do_tool_retrieval = group_agent_config.do_tool_retrieval,
|
||||
stop= group_agent_config.stop,
|
||||
focus_agents=group_agent_config.focus_agents,
|
||||
focus_message_keys=group_agent_config.focus_message_keys,
|
||||
)
|
||||
base_agent.group_agents.append(group_base_agent)
|
||||
|
||||
agents.append(base_agent)
|
||||
|
||||
chain_instance = BaseChain(
|
||||
chain_config, agents, chain_config.chat_turn,
|
||||
do_checker=chain_configs[chain_name].do_checker,
|
||||
do_code_exec=False,)
|
||||
)
|
||||
chains.append(chain_instance)
|
||||
|
||||
return chains
|
||||
|
|
|
@ -8,22 +8,78 @@ from langchain.tools import BaseTool
|
|||
|
||||
|
||||
class ActionStatus(Enum):
|
||||
FINISHED = "finished"
|
||||
CODING = "coding"
|
||||
TOOL_USING = "tool_using"
|
||||
REASONING = "reasoning"
|
||||
PLANNING = "planning"
|
||||
EXECUTING_CODE = "executing_code"
|
||||
EXECUTING_TOOL = "executing_tool"
|
||||
DEFAUILT = "default"
|
||||
|
||||
FINISHED = "finished"
|
||||
STOPED = "stoped"
|
||||
CONTINUED = "continued"
|
||||
|
||||
TOOL_USING = "tool_using"
|
||||
CODING = "coding"
|
||||
CODE_EXECUTING = "code_executing"
|
||||
CODING2FILE = "coding2file"
|
||||
|
||||
PLANNING = "planning"
|
||||
UNCHANGED = "unchanged"
|
||||
ADJUSTED = "adjusted"
|
||||
CODE_RETRIEVAL = "code_retrieval"
|
||||
|
||||
def __eq__(self, other):
|
||||
if isinstance(other, str):
|
||||
return self.value == other
|
||||
return self.value.lower() == other.lower()
|
||||
return super().__eq__(other)
|
||||
|
||||
|
||||
class Action(BaseModel):
|
||||
action_name: str
|
||||
description: str
|
||||
|
||||
class FinishedAction(Action):
|
||||
action_name: str = ActionStatus.FINISHED
|
||||
description: str = "provide the final answer to the original query to break the chain answer"
|
||||
|
||||
class StopedAction(Action):
|
||||
action_name: str = ActionStatus.STOPED
|
||||
description: str = "provide the final answer to the original query to break the agent answer"
|
||||
|
||||
class ContinuedAction(Action):
|
||||
action_name: str = ActionStatus.CONTINUED
|
||||
description: str = "cant't provide the final answer to the original query"
|
||||
|
||||
class ToolUsingAction(Action):
|
||||
action_name: str = ActionStatus.TOOL_USING
|
||||
description: str = "proceed with using the specified tool."
|
||||
|
||||
class CodingdAction(Action):
|
||||
action_name: str = ActionStatus.CODING
|
||||
description: str = "provide the answer by writing code"
|
||||
|
||||
class Coding2FileAction(Action):
|
||||
action_name: str = ActionStatus.CODING2FILE
|
||||
description: str = "provide the answer by writing code and filename"
|
||||
|
||||
class CodeExecutingAction(Action):
|
||||
action_name: str = ActionStatus.CODE_EXECUTING
|
||||
description: str = "provide the answer by writing executable code"
|
||||
|
||||
class PlanningAction(Action):
|
||||
action_name: str = ActionStatus.PLANNING
|
||||
description: str = "provide a sequence of tasks"
|
||||
|
||||
class UnchangedAction(Action):
|
||||
action_name: str = ActionStatus.UNCHANGED
|
||||
description: str = "this PLAN has no problem, just set PLAN_STEP to CURRENT_STEP+1."
|
||||
|
||||
class AdjustedAction(Action):
|
||||
action_name: str = ActionStatus.ADJUSTED
|
||||
description: str = "the PLAN is to provide an optimized version of the original plan."
|
||||
|
||||
# extended action exmaple
|
||||
class CodeRetrievalAction(Action):
|
||||
action_name: str = ActionStatus.CODE_RETRIEVAL
|
||||
description: str = "execute the code retrieval to acquire more code information"
|
||||
|
||||
|
||||
class RoleTypeEnums(Enum):
|
||||
SYSTEM = "system"
|
||||
USER = "user"
|
||||
|
@ -37,7 +93,11 @@ class RoleTypeEnums(Enum):
|
|||
return super().__eq__(other)
|
||||
|
||||
|
||||
class InputKeyEnums(Enum):
|
||||
class PromptKey(BaseModel):
|
||||
key_name: str
|
||||
description: str
|
||||
|
||||
class PromptKeyEnums(Enum):
|
||||
# Origin Query is ui's user question
|
||||
ORIGIN_QUERY = "origin_query"
|
||||
# agent's input from last agent
|
||||
|
@ -49,9 +109,9 @@ class InputKeyEnums(Enum):
|
|||
# chain memory
|
||||
CHAIN_MEMORY = "chain_memory"
|
||||
# agent's memory
|
||||
SELF_ONE_MEMORY = "self_one_memory"
|
||||
SELF_LOCAL_MEMORY = "self_local_memory"
|
||||
# chain memory
|
||||
CHAIN_ONE_MEMORY = "chain_one_memory"
|
||||
CHAIN_LOCAL_MEMORY = "chain_local_memory"
|
||||
# Doc Infomations contains (Doc\Code\Search)
|
||||
DOC_INFOS = "doc_infos"
|
||||
|
||||
|
@ -107,14 +167,6 @@ class CodeDoc(BaseModel):
|
|||
return f"""出处 [{self.index + 1}] \n\n来源 ({self.related_nodes}) \n\n内容 {self.code}\n\n"""
|
||||
|
||||
|
||||
class Docs:
|
||||
|
||||
def __init__(self, docs: List[Doc]):
|
||||
self.titles: List[str] = [doc.get_title() for doc in docs]
|
||||
self.snippets: List[str] = [doc.get_snippet() for doc in docs]
|
||||
self.links: List[str] = [doc.get_link() for doc in docs]
|
||||
self.indexs: List[int] = [doc.get_index() for doc in docs]
|
||||
|
||||
class Task(BaseModel):
|
||||
task_type: str
|
||||
task_name: str
|
||||
|
@ -163,6 +215,7 @@ class AgentConfig(BaseModel):
|
|||
do_tool_retrieval: bool = False
|
||||
focus_agents: List = []
|
||||
focus_message_keys: List = []
|
||||
group_agents: List = []
|
||||
|
||||
|
||||
class PhaseConfig(BaseModel):
|
||||
|
|
|
@ -12,15 +12,11 @@ class Message(BaseModel):
|
|||
input_query: str = None
|
||||
origin_query: str = None
|
||||
|
||||
# 模型最终返回
|
||||
# llm output
|
||||
role_content: str = None
|
||||
role_contents: List[str] = []
|
||||
step_content: str = None
|
||||
step_contents: List[str] = []
|
||||
chain_content: str = None
|
||||
chain_contents: List[str] = []
|
||||
|
||||
# 模型结果解析
|
||||
# llm parsed information
|
||||
plans: List[str] = None
|
||||
code_content: str = None
|
||||
code_filename: str = None
|
||||
|
@ -30,7 +26,7 @@ class Message(BaseModel):
|
|||
spec_parsed_output: dict = {}
|
||||
parsed_output_list: List[Dict] = []
|
||||
|
||||
# 执行结果
|
||||
# llm\tool\code executre information
|
||||
action_status: str = ActionStatus.DEFAUILT
|
||||
agent_index: int = None
|
||||
code_answer: str = None
|
||||
|
@ -38,7 +34,7 @@ class Message(BaseModel):
|
|||
observation: str = None
|
||||
figures: Dict[str, str] = {}
|
||||
|
||||
# 辅助信息
|
||||
# prompt support information
|
||||
tools: List[BaseTool] = []
|
||||
task: Task = None
|
||||
db_docs: List['Doc'] = []
|
||||
|
@ -46,7 +42,7 @@ class Message(BaseModel):
|
|||
search_docs: List['Doc'] = []
|
||||
agents: List = []
|
||||
|
||||
# 执行输入
|
||||
# phase input
|
||||
phase_name: str = None
|
||||
chain_name: str = None
|
||||
do_search: bool = False
|
||||
|
@ -60,6 +56,8 @@ class Message(BaseModel):
|
|||
do_code_retrieval: bool = False
|
||||
do_tool_retrieval: bool = False
|
||||
history_node_list: List[str] = []
|
||||
# user's customed kargs for init or end action
|
||||
customed_kargs: dict = {}
|
||||
|
||||
def to_tuple_message(self, return_all: bool = True, content_key="role_content"):
|
||||
role_content = self.to_str_content(False, content_key)
|
||||
|
|
|
@ -29,7 +29,7 @@ class NebulaHandler:
|
|||
self.password = password
|
||||
self.space_name = space_name
|
||||
|
||||
def execute_cypher(self, cypher: str, space_name: str = ''):
|
||||
def execute_cypher(self, cypher: str, space_name: str = '', format_res: bool = False, use_space_name: bool = True):
|
||||
'''
|
||||
|
||||
@param space_name: space_name, if provided, will execute use space_name first
|
||||
|
@ -37,11 +37,17 @@ class NebulaHandler:
|
|||
@return:
|
||||
'''
|
||||
with self.connection_pool.session_context(self.username, self.password) as session:
|
||||
if space_name:
|
||||
cypher = f'USE {space_name};{cypher}'
|
||||
if use_space_name:
|
||||
if space_name:
|
||||
cypher = f'USE {space_name};{cypher}'
|
||||
elif self.space_name:
|
||||
cypher = f'USE {self.space_name};{cypher}'
|
||||
|
||||
logger.debug(cypher)
|
||||
resp = session.execute(cypher)
|
||||
|
||||
if format_res:
|
||||
resp = self.result_to_dict(resp)
|
||||
return resp
|
||||
|
||||
def close_connection(self):
|
||||
|
@ -54,7 +60,8 @@ class NebulaHandler:
|
|||
@return:
|
||||
'''
|
||||
cypher = f'CREATE SPACE IF NOT EXISTS {space_name} (vid_type={vid_type}) comment="{comment}";'
|
||||
resp = self.execute_cypher(cypher)
|
||||
resp = self.execute_cypher(cypher, use_space_name=False)
|
||||
|
||||
return resp
|
||||
|
||||
def show_space(self):
|
||||
|
@ -95,7 +102,7 @@ class NebulaHandler:
|
|||
|
||||
def insert_vertex(self, tag_name: str, value_dict: dict):
|
||||
'''
|
||||
insert vertex
|
||||
insert vertex
|
||||
@param tag_name:
|
||||
@param value_dict: {'properties_name': [], values: {'vid':[]}} order should be the same in properties_name and values
|
||||
@return:
|
||||
|
@ -243,7 +250,7 @@ class NebulaHandler:
|
|||
"""
|
||||
build list for each column, and transform to dataframe
|
||||
"""
|
||||
logger.info(result.error_msg())
|
||||
# logger.info(result.error_msg())
|
||||
assert result.is_succeeded()
|
||||
columns = result.keys()
|
||||
d = {}
|
||||
|
|
|
@ -11,8 +11,6 @@ import json
|
|||
import os
|
||||
from loguru import logger
|
||||
|
||||
from configs.model_config import OPENAI_API_BASE
|
||||
|
||||
|
||||
class OpenAIEmbedding:
|
||||
def __init__(self):
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
from .openai_model import getChatModel
|
||||
from .openai_model import getChatModel, getExtraModel
|
||||
|
||||
|
||||
__all__ = [
|
||||
"getChatModel"
|
||||
"getChatModel", "getExtraModel"
|
||||
]
|
|
@ -26,4 +26,23 @@ def getChatModel(callBack: AsyncIteratorCallbackHandler = None, temperature=0.3,
|
|||
temperature=temperature,
|
||||
stop=stop
|
||||
)
|
||||
return model
|
||||
return model
|
||||
|
||||
import json, requests
|
||||
|
||||
|
||||
def getExtraModel():
|
||||
return TestModel()
|
||||
|
||||
class TestModel:
|
||||
|
||||
def predict(self, request_body):
|
||||
headers = {"Content-Type":"application/json;charset=UTF-8",
|
||||
"codegpt_user":"",
|
||||
"codegpt_token":""
|
||||
}
|
||||
xxx = requests.post(
|
||||
'https://codegencore.alipay.com/api/chat/CODE_LLAMA_INT4/completion',
|
||||
data=json.dumps(request_body,ensure_ascii=False).encode('utf-8'),
|
||||
headers=headers)
|
||||
return xxx.json()["data"]
|
|
@ -18,6 +18,10 @@ from .service_factory import KBServiceFactory
|
|||
from dev_opsgpt.utils.server_utils import BaseResponse, ListResponse
|
||||
from dev_opsgpt.utils.path_utils import *
|
||||
from dev_opsgpt.orm.commands import *
|
||||
from dev_opsgpt.db_handler.graph_db_handler.nebula_handler import NebulaHandler
|
||||
from dev_opsgpt.db_handler.vector_db_handler.chroma_handler import ChromaHandler
|
||||
from configs.server_config import NEBULA_HOST, NEBULA_PORT, NEBULA_USER, NEBULA_PASSWORD, NEBULA_STORAGED_PORT
|
||||
from configs.server_config import CHROMA_PERSISTENT_PATH
|
||||
|
||||
from configs.model_config import (
|
||||
CB_ROOT_PATH
|
||||
|
@ -125,6 +129,60 @@ def search_code(cb_name: str = Body(..., examples=["sofaboot"]),
|
|||
return {}
|
||||
|
||||
|
||||
def search_related_vertices(cb_name: str = Body(..., examples=["sofaboot"]),
|
||||
vertex: str = Body(..., examples=['***'])) -> dict:
|
||||
|
||||
logger.info('cb_name={}'.format(cb_name))
|
||||
logger.info('vertex={}'.format(vertex))
|
||||
|
||||
try:
|
||||
# load codebase
|
||||
nh = NebulaHandler(host=NEBULA_HOST, port=NEBULA_PORT, username=NEBULA_USER,
|
||||
password=NEBULA_PASSWORD, space_name=cb_name)
|
||||
|
||||
cypher = f'''MATCH (v1)--(v2) WHERE id(v1) == '{vertex}' RETURN id(v2) as id;'''
|
||||
|
||||
cypher_res = nh.execute_cypher(cypher=cypher, format_res=True)
|
||||
|
||||
related_vertices = cypher_res.get('id', [])
|
||||
|
||||
res = {
|
||||
'vertices': related_vertices
|
||||
}
|
||||
|
||||
return res
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
return {}
|
||||
|
||||
|
||||
def search_code_by_vertex(cb_name: str = Body(..., examples=["sofaboot"]),
|
||||
vertex: str = Body(..., examples=['***'])) -> dict:
|
||||
|
||||
logger.info('cb_name={}'.format(cb_name))
|
||||
logger.info('vertex={}'.format(vertex))
|
||||
|
||||
try:
|
||||
ch = ChromaHandler(path=CHROMA_PERSISTENT_PATH, collection_name=cb_name)
|
||||
|
||||
# fix vertex
|
||||
vertex_use = '#'.join(vertex.split('#')[0:2])
|
||||
ids = [vertex_use]
|
||||
|
||||
chroma_res = ch.get(ids=ids)
|
||||
|
||||
code_text = chroma_res['result']['metadatas'][0]['code_text']
|
||||
|
||||
res = {
|
||||
'code': code_text
|
||||
}
|
||||
|
||||
return res
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
return {}
|
||||
|
||||
|
||||
def cb_exists_api(cb_name: str = Body(..., examples=["sofaboot"])) -> bool:
|
||||
try:
|
||||
res = cb_exists(cb_name)
|
||||
|
|
|
@ -1,4 +1,11 @@
|
|||
# Attention: code copied from https://github.com/chatchat-space/Langchain-Chatchat/blob/master/server/llm_api.py
|
||||
############################# Attention ########################
|
||||
|
||||
# Code copied from
|
||||
# https://github.com/chatchat-space/Langchain-Chatchat/blob/master/server/llm_api.py
|
||||
|
||||
#################################################################
|
||||
|
||||
|
||||
|
||||
from multiprocessing import Process, Queue
|
||||
import multiprocessing as mp
|
||||
|
@ -16,10 +23,12 @@ src_dir = os.path.join(
|
|||
sys.path.append(src_dir)
|
||||
|
||||
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
|
||||
from configs.model_config import llm_model_dict, LLM_MODEL, LLM_DEVICE, LOG_PATH, logger
|
||||
from configs.model_config import llm_model_dict, LLM_MODEL, LLM_DEVICE, LOG_PATH, logger, LLM_MODELs
|
||||
from configs.server_config import (
|
||||
FSCHAT_CONTROLLER, FSCHAT_MODEL_WORKERS, FSCHAT_OPENAI_API
|
||||
)
|
||||
from dev_opsgpt.service.utils import get_model_worker_config
|
||||
|
||||
from dev_opsgpt.utils.server_utils import (
|
||||
MakeFastAPIOffline,
|
||||
)
|
||||
|
@ -43,38 +52,6 @@ def set_httpx_timeout(timeout=60.0):
|
|||
httpx._config.DEFAULT_TIMEOUT_CONFIG.write = timeout
|
||||
|
||||
|
||||
def get_model_worker_config(model_name: str = None) -> dict:
|
||||
'''
|
||||
加载model worker的配置项。
|
||||
优先级:FSCHAT_MODEL_WORKERS[model_name] > ONLINE_LLM_MODEL[model_name] > FSCHAT_MODEL_WORKERS["default"]
|
||||
'''
|
||||
from configs.model_config import ONLINE_LLM_MODEL
|
||||
from configs.server_config import FSCHAT_MODEL_WORKERS
|
||||
# from server import model_workers
|
||||
|
||||
config = FSCHAT_MODEL_WORKERS.get("default", {}).copy()
|
||||
# config.update(ONLINE_LLM_MODEL.get(model_name, {}).copy())
|
||||
config.update(FSCHAT_MODEL_WORKERS.get(model_name, {}).copy())
|
||||
|
||||
# if model_name in ONLINE_LLM_MODEL:
|
||||
# config["online_api"] = True
|
||||
# if provider := config.get("provider"):
|
||||
# try:
|
||||
# config["worker_class"] = getattr(model_workers, provider)
|
||||
# except Exception as e:
|
||||
# msg = f"在线模型 ‘{model_name}’ 的provider没有正确配置"
|
||||
# logger.error(f'{e.__class__.__name__}: {msg}',
|
||||
# exc_info=e if log_verbose else None)
|
||||
# 本地模型
|
||||
if model_name in llm_model_dict:
|
||||
path = llm_model_dict[model_name]["local_model_path"]
|
||||
config["model_path"] = path
|
||||
if path and os.path.isdir(path):
|
||||
config["model_path_exists"] = True
|
||||
config["device"] = LLM_DEVICE
|
||||
return config
|
||||
|
||||
|
||||
def get_all_model_worker_configs() -> dict:
|
||||
result = {}
|
||||
model_names = set(FSCHAT_MODEL_WORKERS.keys())
|
||||
|
@ -281,6 +258,9 @@ def create_model_worker_app(log_level: str = "INFO", **kwargs) -> FastAPI:
|
|||
|
||||
for k, v in kwargs.items():
|
||||
setattr(args, k, v)
|
||||
|
||||
logger.error(f"可用模型有哪些: {args.model_names}")
|
||||
|
||||
if worker_class := kwargs.get("langchain_model"): #Langchian支持的模型不用做操作
|
||||
from fastchat.serve.base_model_worker import app
|
||||
worker = ""
|
||||
|
@ -296,7 +276,8 @@ def create_model_worker_app(log_level: str = "INFO", **kwargs) -> FastAPI:
|
|||
# 本地模型
|
||||
else:
|
||||
from configs.model_config import VLLM_MODEL_DICT
|
||||
if kwargs["model_names"][0] in VLLM_MODEL_DICT and args.infer_turbo == "vllm":
|
||||
# if kwargs["model_names"][0] in VLLM_MODEL_DICT and args.infer_turbo == "vllm":
|
||||
if kwargs["model_names"][0] in VLLM_MODEL_DICT:
|
||||
import fastchat.serve.vllm_worker
|
||||
from fastchat.serve.vllm_worker import VLLMWorker, app, worker_id
|
||||
from vllm import AsyncLLMEngine
|
||||
|
@ -321,7 +302,7 @@ def create_model_worker_app(log_level: str = "INFO", **kwargs) -> FastAPI:
|
|||
args.conv_template = None
|
||||
args.limit_worker_concurrency = 5
|
||||
args.no_register = False
|
||||
args.num_gpus = 4 # vllm worker的切分是tensor并行,这里填写显卡的数量
|
||||
args.num_gpus = 1 # vllm worker的切分是tensor并行,这里填写显卡的数量
|
||||
args.engine_use_ray = False
|
||||
args.disable_log_requests = False
|
||||
|
||||
|
@ -575,7 +556,7 @@ def run_model_worker(
|
|||
kwargs["worker_address"] = fschat_model_worker_address(model_name)
|
||||
model_path = kwargs.get("model_path", "")
|
||||
kwargs["model_path"] = model_path
|
||||
# kwargs["gptq_wbits"] = 4
|
||||
# kwargs["gptq_wbits"] = 4 # int4 模型试用这个参数
|
||||
|
||||
app = create_model_worker_app(log_level=log_level, **kwargs)
|
||||
_set_app_event(app, started_event)
|
||||
|
@ -660,7 +641,7 @@ def parse_args() -> argparse.ArgumentParser:
|
|||
"--model-name",
|
||||
type=str,
|
||||
nargs="+",
|
||||
default=[LLM_MODEL],
|
||||
default=LLM_MODELs,
|
||||
help="specify model name for model worker. "
|
||||
"add addition names with space seperated to start multiple model workers.",
|
||||
dest="model_name",
|
||||
|
@ -722,7 +703,7 @@ def dump_server_info(after_start=False, args=None):
|
|||
print(f"langchain版本:{langchain.__version__}. fastchat版本:{fastchat.__version__}")
|
||||
print("\n")
|
||||
|
||||
models = [LLM_MODEL]
|
||||
models = LLM_MODELs
|
||||
if args and args.model_name:
|
||||
models = args.model_name
|
||||
|
||||
|
@ -813,8 +794,10 @@ async def start_main_server():
|
|||
)
|
||||
processes["model_worker"][model_name] = process
|
||||
|
||||
|
||||
for model_name in args.model_name:
|
||||
config = get_model_worker_config(model_name)
|
||||
logger.error(f"config: {config}, {model_name}, {FSCHAT_MODEL_WORKERS.keys()}")
|
||||
if (config.get("online_api")
|
||||
and config.get("worker_class")
|
||||
and model_name in FSCHAT_MODEL_WORKERS):
|
||||
|
|
|
@ -0,0 +1,79 @@
|
|||
import base64
|
||||
import datetime
|
||||
import hashlib
|
||||
import hmac
|
||||
from urllib.parse import urlparse
|
||||
from datetime import datetime
|
||||
from time import mktime
|
||||
from urllib.parse import urlencode
|
||||
from wsgiref.handlers import format_date_time
|
||||
|
||||
|
||||
class Ws_Param(object):
|
||||
# 初始化
|
||||
def __init__(self, APPID, APIKey, APISecret, Spark_url):
|
||||
self.APPID = APPID
|
||||
self.APIKey = APIKey
|
||||
self.APISecret = APISecret
|
||||
self.host = urlparse(Spark_url).netloc
|
||||
self.path = urlparse(Spark_url).path
|
||||
self.Spark_url = Spark_url
|
||||
|
||||
# 生成url
|
||||
def create_url(self):
|
||||
# 生成RFC1123格式的时间戳
|
||||
now = datetime.now()
|
||||
date = format_date_time(mktime(now.timetuple()))
|
||||
|
||||
# 拼接字符串
|
||||
signature_origin = "host: " + self.host + "\n"
|
||||
signature_origin += "date: " + date + "\n"
|
||||
signature_origin += "GET " + self.path + " HTTP/1.1"
|
||||
|
||||
# 进行hmac-sha256进行加密
|
||||
signature_sha = hmac.new(self.APISecret.encode('utf-8'), signature_origin.encode('utf-8'),
|
||||
digestmod=hashlib.sha256).digest()
|
||||
|
||||
signature_sha_base64 = base64.b64encode(signature_sha).decode(encoding='utf-8')
|
||||
|
||||
authorization_origin = f'api_key="{self.APIKey}", algorithm="hmac-sha256", headers="host date request-line", signature="{signature_sha_base64}"'
|
||||
|
||||
authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8')
|
||||
|
||||
# 将请求的鉴权参数组合为字典
|
||||
v = {
|
||||
"authorization": authorization,
|
||||
"date": date,
|
||||
"host": self.host
|
||||
}
|
||||
# 拼接鉴权参数,生成url
|
||||
url = self.Spark_url + '?' + urlencode(v)
|
||||
# 此处打印出建立连接时候的url,参考本demo的时候可取消上方打印的注释,比对相同参数时生成的url与自己代码生成的url是否一致
|
||||
return url
|
||||
|
||||
|
||||
def gen_params(appid, domain, question, temperature, max_token):
|
||||
"""
|
||||
通过appid和用户的提问来生成请参数
|
||||
"""
|
||||
data = {
|
||||
"header": {
|
||||
"app_id": appid,
|
||||
"uid": "1234"
|
||||
},
|
||||
"parameter": {
|
||||
"chat": {
|
||||
"domain": domain,
|
||||
"random_threshold": 0.5,
|
||||
"max_tokens": max_token,
|
||||
"auditing": "default",
|
||||
"temperature": temperature,
|
||||
}
|
||||
},
|
||||
"payload": {
|
||||
"message": {
|
||||
"text": question
|
||||
}
|
||||
}
|
||||
}
|
||||
return data
|
|
@ -0,0 +1,18 @@
|
|||
############################# Attention ########################
|
||||
|
||||
# The Code in model workers all copied from
|
||||
# https://github.com/chatchat-space/Langchain-Chatchat/blob/master/server/model_workers
|
||||
|
||||
#################################################################
|
||||
|
||||
from .base import *
|
||||
from .zhipu import ChatGLMWorker
|
||||
from .minimax import MiniMaxWorker
|
||||
from .xinghuo import XingHuoWorker
|
||||
from .qianfan import QianFanWorker
|
||||
from .fangzhou import FangZhouWorker
|
||||
from .qwen import QwenWorker
|
||||
from .baichuan import BaiChuanWorker
|
||||
from .azure import AzureWorker
|
||||
from .tiangong import TianGongWorker
|
||||
from .openai import ExampleWorker
|
|
@ -0,0 +1,94 @@
|
|||
import sys
|
||||
from fastchat.conversation import Conversation
|
||||
from .base import *
|
||||
# from server.utils import get_httpx_client
|
||||
from fastchat import conversation as conv
|
||||
import json
|
||||
from typing import List, Dict
|
||||
from configs import logger, log_verbose
|
||||
|
||||
|
||||
class AzureWorker(ApiModelWorker):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
controller_addr: str = None,
|
||||
worker_addr: str = None,
|
||||
model_names: List[str] = ["azure-api"],
|
||||
version: str = "gpt-35-turbo",
|
||||
**kwargs,
|
||||
):
|
||||
kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr)
|
||||
kwargs.setdefault("context_len", 8000) #TODO 16K模型需要改成16384
|
||||
super().__init__(**kwargs)
|
||||
self.version = version
|
||||
|
||||
def do_chat(self, params: ApiChatParams) -> Dict:
|
||||
params.load_config(self.model_names[0])
|
||||
data = dict(
|
||||
messages=params.messages,
|
||||
temperature=params.temperature,
|
||||
max_tokens=params.max_tokens,
|
||||
stream=True,
|
||||
)
|
||||
url = ("https://{}.openai.azure.com/openai/deployments/{}/chat/completions?api-version={}"
|
||||
.format(params.resource_name, params.deployment_name, params.api_version))
|
||||
headers = {
|
||||
'Content-Type': 'application/json',
|
||||
'Accept': 'application/json',
|
||||
'api-key': params.api_key,
|
||||
}
|
||||
|
||||
text = ""
|
||||
if log_verbose:
|
||||
logger.info(f'{self.__class__.__name__}:url: {url}')
|
||||
logger.info(f'{self.__class__.__name__}:headers: {headers}')
|
||||
logger.info(f'{self.__class__.__name__}:data: {data}')
|
||||
|
||||
with get_httpx_client() as client:
|
||||
with client.stream("POST", url, headers=headers, json=data) as response:
|
||||
for line in response.iter_lines():
|
||||
if not line.strip() or "[DONE]" in line:
|
||||
continue
|
||||
if line.startswith("data: "):
|
||||
line = line[6:]
|
||||
resp = json.loads(line)
|
||||
if choices := resp["choices"]:
|
||||
if chunk := choices[0].get("delta", {}).get("content"):
|
||||
text += chunk
|
||||
yield {
|
||||
"error_code": 0,
|
||||
"text": text
|
||||
}
|
||||
else:
|
||||
self.logger.error(f"请求 Azure API 时发生错误:{resp}")
|
||||
|
||||
def get_embeddings(self, params):
|
||||
# TODO: 支持embeddings
|
||||
print("embedding")
|
||||
print(params)
|
||||
|
||||
def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation:
|
||||
# TODO: 确认模板是否需要修改
|
||||
return conv.Conversation(
|
||||
name=self.model_names[0],
|
||||
system_message="You are a helpful, respectful and honest assistant.",
|
||||
messages=[],
|
||||
roles=["user", "assistant"],
|
||||
sep="\n### ",
|
||||
stop_str="###",
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
from server.utils import MakeFastAPIOffline
|
||||
from fastchat.serve.base_model_worker import app
|
||||
|
||||
worker = AzureWorker(
|
||||
controller_addr="http://127.0.0.1:20001",
|
||||
worker_addr="http://127.0.0.1:21008",
|
||||
)
|
||||
sys.modules["fastchat.serve.model_worker"].worker = worker
|
||||
MakeFastAPIOffline(app)
|
||||
uvicorn.run(app, port=21008)
|
|
@ -0,0 +1,119 @@
|
|||
import json
|
||||
import time
|
||||
import hashlib
|
||||
|
||||
from fastchat.conversation import Conversation
|
||||
from .base import *
|
||||
# from server.utils import get_httpx_client
|
||||
from fastchat import conversation as conv
|
||||
import sys
|
||||
import json
|
||||
from typing import List, Literal, Dict
|
||||
from configs import logger, log_verbose
|
||||
|
||||
def calculate_md5(input_string):
|
||||
md5 = hashlib.md5()
|
||||
md5.update(input_string.encode('utf-8'))
|
||||
encrypted = md5.hexdigest()
|
||||
return encrypted
|
||||
|
||||
|
||||
class BaiChuanWorker(ApiModelWorker):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
controller_addr: str = None,
|
||||
worker_addr: str = None,
|
||||
model_names: List[str] = ["baichuan-api"],
|
||||
version: Literal["Baichuan2-53B"] = "Baichuan2-53B",
|
||||
**kwargs,
|
||||
):
|
||||
kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr)
|
||||
kwargs.setdefault("context_len", 32768)
|
||||
super().__init__(**kwargs)
|
||||
self.version = version
|
||||
|
||||
def do_chat(self, params: ApiChatParams) -> Dict:
|
||||
params.load_config(self.model_names[0])
|
||||
|
||||
url = "https://api.baichuan-ai.com/v1/stream/chat"
|
||||
data = {
|
||||
"model": params.version,
|
||||
"messages": params.messages,
|
||||
"parameters": {"temperature": params.temperature}
|
||||
}
|
||||
|
||||
json_data = json.dumps(data)
|
||||
time_stamp = int(time.time())
|
||||
signature = calculate_md5(params.secret_key + json_data + str(time_stamp))
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": "Bearer " + params.api_key,
|
||||
"X-BC-Request-Id": "your requestId",
|
||||
"X-BC-Timestamp": str(time_stamp),
|
||||
"X-BC-Signature": signature,
|
||||
"X-BC-Sign-Algo": "MD5",
|
||||
}
|
||||
|
||||
text = ""
|
||||
if log_verbose:
|
||||
logger.info(f'{self.__class__.__name__}:json_data: {json_data}')
|
||||
logger.info(f'{self.__class__.__name__}:url: {url}')
|
||||
logger.info(f'{self.__class__.__name__}:headers: {headers}')
|
||||
|
||||
with get_httpx_client() as client:
|
||||
with client.stream("POST", url, headers=headers, json=data) as response:
|
||||
for line in response.iter_lines():
|
||||
if not line.strip():
|
||||
continue
|
||||
resp = json.loads(line)
|
||||
if resp["code"] == 0:
|
||||
text += resp["data"]["messages"][-1]["content"]
|
||||
yield {
|
||||
"error_code": resp["code"],
|
||||
"text": text
|
||||
}
|
||||
else:
|
||||
data = {
|
||||
"error_code": resp["code"],
|
||||
"text": resp["msg"],
|
||||
"error": {
|
||||
"message": resp["msg"],
|
||||
"type": "invalid_request_error",
|
||||
"param": None,
|
||||
"code": None,
|
||||
}
|
||||
}
|
||||
self.logger.error(f"请求百川 API 时发生错误:{data}")
|
||||
yield data
|
||||
|
||||
def get_embeddings(self, params):
|
||||
# TODO: 支持embeddings
|
||||
print("embedding")
|
||||
print(params)
|
||||
|
||||
def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation:
|
||||
# TODO: 确认模板是否需要修改
|
||||
return conv.Conversation(
|
||||
name=self.model_names[0],
|
||||
system_message="",
|
||||
messages=[],
|
||||
roles=["user", "assistant"],
|
||||
sep="\n### ",
|
||||
stop_str="###",
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
from server.utils import MakeFastAPIOffline
|
||||
from fastchat.serve.model_worker import app
|
||||
|
||||
worker = BaiChuanWorker(
|
||||
controller_addr="http://127.0.0.1:20001",
|
||||
worker_addr="http://127.0.0.1:21007",
|
||||
)
|
||||
sys.modules["fastchat.serve.model_worker"].worker = worker
|
||||
MakeFastAPIOffline(app)
|
||||
uvicorn.run(app, port=21007)
|
||||
# do_request()
|
|
@ -0,0 +1,249 @@
|
|||
from fastchat.conversation import Conversation
|
||||
from configs import LOG_PATH
|
||||
import fastchat.constants
|
||||
fastchat.constants.LOGDIR = LOG_PATH
|
||||
from fastchat.serve.base_model_worker import BaseModelWorker
|
||||
import uuid
|
||||
import json
|
||||
import sys
|
||||
from pydantic import BaseModel, root_validator
|
||||
import fastchat
|
||||
import asyncio
|
||||
from dev_opsgpt.service.utils import get_model_worker_config
|
||||
|
||||
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
|
||||
__all__ = ["ApiModelWorker", "ApiChatParams", "ApiCompletionParams", "ApiEmbeddingsParams"]
|
||||
|
||||
|
||||
class ApiConfigParams(BaseModel):
|
||||
'''
|
||||
在线API配置参数,未提供的值会自动从model_config.ONLINE_LLM_MODEL中读取
|
||||
'''
|
||||
api_base_url: Optional[str] = None
|
||||
api_proxy: Optional[str] = None
|
||||
api_key: Optional[str] = None
|
||||
secret_key: Optional[str] = None
|
||||
group_id: Optional[str] = None # for minimax
|
||||
is_pro: bool = False # for minimax
|
||||
|
||||
APPID: Optional[str] = None # for xinghuo
|
||||
APISecret: Optional[str] = None # for xinghuo
|
||||
is_v2: bool = False # for xinghuo
|
||||
|
||||
worker_name: Optional[str] = None
|
||||
|
||||
class Config:
|
||||
extra = "allow"
|
||||
|
||||
@root_validator(pre=True)
|
||||
def validate_config(cls, v: Dict) -> Dict:
|
||||
if config := get_model_worker_config(v.get("worker_name")):
|
||||
for n in cls.__fields__:
|
||||
if n in config:
|
||||
v[n] = config[n]
|
||||
return v
|
||||
|
||||
def load_config(self, worker_name: str):
|
||||
self.worker_name = worker_name
|
||||
if config := get_model_worker_config(worker_name):
|
||||
for n in self.__fields__:
|
||||
if n in config:
|
||||
setattr(self, n, config[n])
|
||||
return self
|
||||
|
||||
|
||||
class ApiModelParams(ApiConfigParams):
|
||||
'''
|
||||
模型配置参数
|
||||
'''
|
||||
version: Optional[str] = None
|
||||
version_url: Optional[str] = None
|
||||
api_version: Optional[str] = None # for azure
|
||||
deployment_name: Optional[str] = None # for azure
|
||||
resource_name: Optional[str] = None # for azure
|
||||
|
||||
temperature: float = 0.7
|
||||
max_tokens: Optional[int] = None
|
||||
top_p: Optional[float] = 1.0
|
||||
|
||||
|
||||
class ApiChatParams(ApiModelParams):
|
||||
'''
|
||||
chat请求参数
|
||||
'''
|
||||
messages: List[Dict[str, str]]
|
||||
system_message: Optional[str] = None # for minimax
|
||||
role_meta: Dict = {} # for minimax
|
||||
|
||||
|
||||
class ApiCompletionParams(ApiModelParams):
|
||||
prompt: str
|
||||
|
||||
|
||||
class ApiEmbeddingsParams(ApiConfigParams):
|
||||
texts: List[str]
|
||||
embed_model: Optional[str] = None
|
||||
to_query: bool = False # for minimax
|
||||
|
||||
|
||||
class ApiModelWorker(BaseModelWorker):
|
||||
DEFAULT_EMBED_MODEL: str = None # None means not support embedding
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_names: List[str],
|
||||
controller_addr: str = None,
|
||||
worker_addr: str = None,
|
||||
context_len: int = 2048,
|
||||
no_register: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
kwargs.setdefault("worker_id", uuid.uuid4().hex[:8])
|
||||
kwargs.setdefault("model_path", "")
|
||||
kwargs.setdefault("limit_worker_concurrency", 5)
|
||||
super().__init__(model_names=model_names,
|
||||
controller_addr=controller_addr,
|
||||
worker_addr=worker_addr,
|
||||
**kwargs)
|
||||
import fastchat.serve.base_model_worker
|
||||
import sys
|
||||
self.logger = fastchat.serve.base_model_worker.logger
|
||||
# 恢复被fastchat覆盖的标准输出
|
||||
sys.stdout = sys.__stdout__
|
||||
sys.stderr = sys.__stderr__
|
||||
|
||||
self.context_len = context_len
|
||||
self.semaphore = asyncio.Semaphore(self.limit_worker_concurrency)
|
||||
self.version = None
|
||||
|
||||
if not no_register and self.controller_addr:
|
||||
self.init_heart_beat()
|
||||
|
||||
|
||||
def count_token(self, params):
|
||||
# TODO:需要完善
|
||||
# print("count token")
|
||||
prompt = params["prompt"]
|
||||
return {"count": len(str(prompt)), "error_code": 0}
|
||||
|
||||
def generate_stream_gate(self, params: Dict):
|
||||
self.call_ct += 1
|
||||
|
||||
try:
|
||||
prompt = params["prompt"]
|
||||
if self._is_chat(prompt):
|
||||
messages = self.prompt_to_messages(prompt)
|
||||
messages = self.validate_messages(messages)
|
||||
else: # 使用chat模仿续写功能,不支持历史消息
|
||||
messages = [{"role": self.user_role, "content": f"please continue writing from here: {prompt}"}]
|
||||
|
||||
p = ApiChatParams(
|
||||
messages=messages,
|
||||
temperature=params.get("temperature"),
|
||||
top_p=params.get("top_p"),
|
||||
max_tokens=params.get("max_new_tokens"),
|
||||
version=self.version,
|
||||
)
|
||||
for resp in self.do_chat(p):
|
||||
yield self._jsonify(resp)
|
||||
except Exception as e:
|
||||
yield self._jsonify({"error_code": 500, "text": f"{self.model_names[0]}请求API时发生错误:{e}"})
|
||||
|
||||
def generate_gate(self, params):
|
||||
try:
|
||||
for x in self.generate_stream_gate(params):
|
||||
...
|
||||
return json.loads(x[:-1].decode())
|
||||
except Exception as e:
|
||||
return {"error_code": 500, "text": str(e)}
|
||||
|
||||
|
||||
# 需要用户自定义的方法
|
||||
|
||||
def do_chat(self, params: ApiChatParams) -> Dict:
|
||||
'''
|
||||
执行Chat的方法,默认使用模块里面的chat函数。
|
||||
要求返回形式:{"error_code": int, "text": str}
|
||||
'''
|
||||
return {"error_code": 500, "text": f"{self.model_names[0]}未实现chat功能"}
|
||||
|
||||
# def do_completion(self, p: ApiCompletionParams) -> Dict:
|
||||
# '''
|
||||
# 执行Completion的方法,默认使用模块里面的completion函数。
|
||||
# 要求返回形式:{"error_code": int, "text": str}
|
||||
# '''
|
||||
# return {"error_code": 500, "text": f"{self.model_names[0]}未实现completion功能"}
|
||||
|
||||
def do_embeddings(self, params: ApiEmbeddingsParams) -> Dict:
|
||||
'''
|
||||
执行Embeddings的方法,默认使用模块里面的embed_documents函数。
|
||||
要求返回形式:{"code": int, "data": List[List[float]], "msg": str}
|
||||
'''
|
||||
return {"code": 500, "msg": f"{self.model_names[0]}未实现embeddings功能"}
|
||||
|
||||
def get_embeddings(self, params):
|
||||
# fastchat对LLM做Embeddings限制很大,似乎只能使用openai的。
|
||||
# 在前端通过OpenAIEmbeddings发起的请求直接出错,无法请求过来。
|
||||
print("get_embedding")
|
||||
print(params)
|
||||
|
||||
def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation:
|
||||
raise NotImplementedError
|
||||
|
||||
def validate_messages(self, messages: List[Dict]) -> List[Dict]:
|
||||
'''
|
||||
有些API对mesages有特殊格式,可以重写该函数替换默认的messages。
|
||||
之所以跟prompt_to_messages分开,是因为他们应用场景不同、参数不同
|
||||
'''
|
||||
return messages
|
||||
|
||||
|
||||
# help methods
|
||||
@property
|
||||
def user_role(self):
|
||||
return self.conv.roles[0]
|
||||
|
||||
@property
|
||||
def ai_role(self):
|
||||
return self.conv.roles[1]
|
||||
|
||||
def _jsonify(self, data: Dict) -> str:
|
||||
'''
|
||||
将chat函数返回的结果按照fastchat openai-api-server的格式返回
|
||||
'''
|
||||
return json.dumps(data, ensure_ascii=False).encode() + b"\0"
|
||||
|
||||
def _is_chat(self, prompt: str) -> bool:
|
||||
'''
|
||||
检查prompt是否由chat messages拼接而来
|
||||
TODO: 存在误判的可能,也许从fastchat直接传入原始messages是更好的做法
|
||||
'''
|
||||
key = f"{self.conv.sep}{self.user_role}:"
|
||||
return key in prompt
|
||||
|
||||
def prompt_to_messages(self, prompt: str) -> List[Dict]:
|
||||
'''
|
||||
将prompt字符串拆分成messages.
|
||||
'''
|
||||
result = []
|
||||
user_role = self.user_role
|
||||
ai_role = self.ai_role
|
||||
user_start = user_role + ":"
|
||||
ai_start = ai_role + ":"
|
||||
for msg in prompt.split(self.conv.sep)[1:-1]:
|
||||
if msg.startswith(user_start):
|
||||
if content := msg[len(user_start):].strip():
|
||||
result.append({"role": user_role, "content": content})
|
||||
elif msg.startswith(ai_start):
|
||||
if content := msg[len(ai_start):].strip():
|
||||
result.append({"role": ai_role, "content": content})
|
||||
else:
|
||||
raise RuntimeError(f"unknown role in msg: {msg}")
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def can_embedding(cls):
|
||||
return cls.DEFAULT_EMBED_MODEL is not None
|
|
@ -0,0 +1,106 @@
|
|||
from fastchat.conversation import Conversation
|
||||
from .base import *
|
||||
from fastchat import conversation as conv
|
||||
import sys
|
||||
from typing import List, Literal, Dict
|
||||
from configs import logger, log_verbose
|
||||
|
||||
|
||||
class FangZhouWorker(ApiModelWorker):
|
||||
"""
|
||||
火山方舟
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
model_names: List[str] = ["fangzhou-api"],
|
||||
controller_addr: str = None,
|
||||
worker_addr: str = None,
|
||||
version: Literal["chatglm-6b-model"] = "chatglm-6b-model",
|
||||
**kwargs,
|
||||
):
|
||||
kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr)
|
||||
kwargs.setdefault("context_len", 16384) # TODO: 不同的模型有不同的大小
|
||||
super().__init__(**kwargs)
|
||||
self.version = version
|
||||
|
||||
def do_chat(self, params: ApiChatParams) -> Dict:
|
||||
from volcengine.maas import MaasService
|
||||
|
||||
params.load_config(self.model_names[0])
|
||||
maas = MaasService('maas-api.ml-platform-cn-beijing.volces.com', 'cn-beijing')
|
||||
maas.set_ak(params.api_key)
|
||||
maas.set_sk(params.secret_key)
|
||||
|
||||
# document: "https://www.volcengine.com/docs/82379/1099475"
|
||||
req = {
|
||||
"model": {
|
||||
"name": params.version,
|
||||
},
|
||||
"parameters": {
|
||||
# 这里的参数仅为示例,具体可用的参数请参考具体模型的 API 说明
|
||||
"max_new_tokens": params.max_tokens,
|
||||
"temperature": params.temperature,
|
||||
},
|
||||
"messages": params.messages,
|
||||
}
|
||||
|
||||
text = ""
|
||||
if log_verbose:
|
||||
self.logger.info(f'{self.__class__.__name__}:maas: {maas}')
|
||||
for resp in maas.stream_chat(req):
|
||||
if error := resp.error:
|
||||
if error.code_n > 0:
|
||||
data = {
|
||||
"error_code": error.code_n,
|
||||
"text": error.message,
|
||||
"error": {
|
||||
"message": error.message,
|
||||
"type": "invalid_request_error",
|
||||
"param": None,
|
||||
"code": None,
|
||||
}
|
||||
}
|
||||
self.logger.error(f"请求方舟 API 时发生错误:{data}")
|
||||
yield data
|
||||
elif chunk := resp.choice.message.content:
|
||||
text += chunk
|
||||
yield {"error_code": 0, "text": text}
|
||||
else:
|
||||
data = {
|
||||
"error_code": 500,
|
||||
"text": f"请求方舟 API 时发生未知的错误: {resp}"
|
||||
}
|
||||
self.logger.error(data)
|
||||
yield data
|
||||
break
|
||||
|
||||
def get_embeddings(self, params):
|
||||
# TODO: 支持embeddings
|
||||
print("embedding")
|
||||
print(params)
|
||||
|
||||
def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation:
|
||||
return conv.Conversation(
|
||||
name=self.model_names[0],
|
||||
system_message="你是一个聪明、对人类有帮助的人工智能,你可以对人类提出的问题给出有用、详细、礼貌的回答。",
|
||||
messages=[],
|
||||
roles=["user", "assistant", "system"],
|
||||
sep="\n### ",
|
||||
stop_str="###",
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
from server.utils import MakeFastAPIOffline
|
||||
from fastchat.serve.model_worker import app
|
||||
|
||||
worker = FangZhouWorker(
|
||||
controller_addr="http://127.0.0.1:20001",
|
||||
worker_addr="http://127.0.0.1:21005",
|
||||
)
|
||||
sys.modules["fastchat.serve.model_worker"].worker = worker
|
||||
MakeFastAPIOffline(app)
|
||||
uvicorn.run(app, port=21005)
|
|
@ -0,0 +1,172 @@
|
|||
from fastchat.conversation import Conversation
|
||||
from .base import *
|
||||
from fastchat import conversation as conv
|
||||
import sys
|
||||
import json
|
||||
# from server.utils import get_httpx_client
|
||||
from typing import List, Dict
|
||||
from configs import logger, log_verbose
|
||||
|
||||
|
||||
class MiniMaxWorker(ApiModelWorker):
|
||||
DEFAULT_EMBED_MODEL = "embo-01"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
model_names: List[str] = ["minimax-api"],
|
||||
controller_addr: str = None,
|
||||
worker_addr: str = None,
|
||||
version: str = "abab5.5-chat",
|
||||
**kwargs,
|
||||
):
|
||||
kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr)
|
||||
kwargs.setdefault("context_len", 16384)
|
||||
super().__init__(**kwargs)
|
||||
self.version = version
|
||||
|
||||
def validate_messages(self, messages: List[Dict]) -> List[Dict]:
|
||||
role_maps = {
|
||||
"user": self.user_role,
|
||||
"assistant": self.ai_role,
|
||||
"system": "system",
|
||||
}
|
||||
messages = [{"sender_type": role_maps[x["role"]], "text": x["content"]} for x in messages]
|
||||
return messages
|
||||
|
||||
def do_chat(self, params: ApiChatParams) -> Dict:
|
||||
# 按照官网推荐,直接调用abab 5.5模型
|
||||
# TODO: 支持指定回复要求,支持指定用户名称、AI名称
|
||||
params.load_config(self.model_names[0])
|
||||
|
||||
url = 'https://api.minimax.chat/v1/text/chatcompletion{pro}?GroupId={group_id}'
|
||||
pro = "_pro" if params.is_pro else ""
|
||||
headers = {
|
||||
"Authorization": f"Bearer {params.api_key}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
messages = self.validate_messages(params.messages)
|
||||
data = {
|
||||
"model": params.version,
|
||||
"stream": True,
|
||||
"mask_sensitive_info": True,
|
||||
"messages": messages,
|
||||
"temperature": params.temperature,
|
||||
"top_p": params.top_p,
|
||||
"tokens_to_generate": params.max_tokens or 1024,
|
||||
# TODO: 以下参数为minimax特有,传入空值会出错。
|
||||
# "prompt": params.system_message or self.conv.system_message,
|
||||
# "bot_setting": [],
|
||||
# "role_meta": params.role_meta,
|
||||
}
|
||||
if log_verbose:
|
||||
logger.info(f'{self.__class__.__name__}:data: {data}')
|
||||
logger.info(f'{self.__class__.__name__}:url: {url.format(pro=pro, group_id=params.group_id)}')
|
||||
logger.info(f'{self.__class__.__name__}:headers: {headers}')
|
||||
|
||||
with get_httpx_client() as client:
|
||||
response = client.stream("POST",
|
||||
url.format(pro=pro, group_id=params.group_id),
|
||||
headers=headers,
|
||||
json=data)
|
||||
with response as r:
|
||||
text = ""
|
||||
for e in r.iter_text():
|
||||
if not e.startswith("data: "): # 真是优秀的返回
|
||||
data = {
|
||||
"error_code": 500,
|
||||
"text": f"minimax返回错误的结果:{e}",
|
||||
"error": {
|
||||
"message": f"minimax返回错误的结果:{e}",
|
||||
"type": "invalid_request_error",
|
||||
"param": None,
|
||||
"code": None,
|
||||
}
|
||||
}
|
||||
self.logger.error(f"请求 MiniMax API 时发生错误:{data}")
|
||||
yield data
|
||||
continue
|
||||
|
||||
data = json.loads(e[6:])
|
||||
if data.get("usage"):
|
||||
break
|
||||
|
||||
if choices := data.get("choices"):
|
||||
if chunk := choices[0].get("delta", ""):
|
||||
text += chunk
|
||||
yield {"error_code": 0, "text": text}
|
||||
|
||||
def do_embeddings(self, params: ApiEmbeddingsParams) -> Dict:
|
||||
params.load_config(self.model_names[0])
|
||||
url = f"https://api.minimax.chat/v1/embeddings?GroupId={params.group_id}"
|
||||
|
||||
headers = {
|
||||
"Authorization": f"Bearer {params.api_key}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
data = {
|
||||
"model": params.embed_model or self.DEFAULT_EMBED_MODEL,
|
||||
"texts": [],
|
||||
"type": "query" if params.to_query else "db",
|
||||
}
|
||||
if log_verbose:
|
||||
logger.info(f'{self.__class__.__name__}:data: {data}')
|
||||
logger.info(f'{self.__class__.__name__}:url: {url}')
|
||||
logger.info(f'{self.__class__.__name__}:headers: {headers}')
|
||||
|
||||
with get_httpx_client() as client:
|
||||
result = []
|
||||
i = 0
|
||||
batch_size = 10
|
||||
while i < len(params.texts):
|
||||
texts = params.texts[i:i+batch_size]
|
||||
data["texts"] = texts
|
||||
r = client.post(url, headers=headers, json=data).json()
|
||||
if embeddings := r.get("vectors"):
|
||||
result += embeddings
|
||||
elif error := r.get("base_resp"):
|
||||
data = {
|
||||
"code": error["status_code"],
|
||||
"msg": error["status_msg"],
|
||||
"error": {
|
||||
"message": error["status_msg"],
|
||||
"type": "invalid_request_error",
|
||||
"param": None,
|
||||
"code": None,
|
||||
}
|
||||
}
|
||||
self.logger.error(f"请求 MiniMax API 时发生错误:{data}")
|
||||
return data
|
||||
i += batch_size
|
||||
return {"code": 200, "data": embeddings}
|
||||
|
||||
def get_embeddings(self, params):
|
||||
# TODO: 支持embeddings
|
||||
print("embedding")
|
||||
print(params)
|
||||
|
||||
def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation:
|
||||
# TODO: 确认模板是否需要修改
|
||||
return conv.Conversation(
|
||||
name=self.model_names[0],
|
||||
system_message="你是MiniMax自主研发的大型语言模型,回答问题简洁有条理。",
|
||||
messages=[],
|
||||
roles=["USER", "BOT"],
|
||||
sep="\n### ",
|
||||
stop_str="###",
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
from server.utils import MakeFastAPIOffline
|
||||
from fastchat.serve.model_worker import app
|
||||
|
||||
worker = MiniMaxWorker(
|
||||
controller_addr="http://127.0.0.1:20001",
|
||||
worker_addr="http://127.0.0.1:21002",
|
||||
)
|
||||
sys.modules["fastchat.serve.model_worker"].worker = worker
|
||||
MakeFastAPIOffline(app)
|
||||
uvicorn.run(app, port=21002)
|
|
@ -0,0 +1,92 @@
|
|||
import sys
|
||||
from fastchat.conversation import Conversation
|
||||
from .base import *
|
||||
from fastchat import conversation as conv
|
||||
import json
|
||||
from typing import List, Dict
|
||||
from configs import logger
|
||||
import openai
|
||||
|
||||
from langchain import PromptTemplate, LLMChain
|
||||
from langchain.prompts.chat import ChatPromptTemplate
|
||||
from langchain.chat_models import ChatOpenAI
|
||||
from langchain.schema import HumanMessage
|
||||
|
||||
|
||||
class ExampleWorker(ApiModelWorker):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
controller_addr: str = None,
|
||||
worker_addr: str = None,
|
||||
model_names: List[str] = ["gpt-3.5-turbo"],
|
||||
version: str = "gpt-3.5",
|
||||
**kwargs,
|
||||
):
|
||||
kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr)
|
||||
kwargs.setdefault("context_len", 16384) #TODO 16K模型需要改成16384
|
||||
super().__init__(**kwargs)
|
||||
self.version = version
|
||||
|
||||
def do_chat(self, params: ApiChatParams) -> Dict:
|
||||
'''
|
||||
yield output: {"error_code": 0, "text": ""}
|
||||
'''
|
||||
params.load_config(self.model_names[0])
|
||||
openai.api_key = params.api_key
|
||||
openai.api_base = params.api_base_url
|
||||
|
||||
logger.error(f"{params.api_key}, {params.api_base_url}, {params.messages} {params.max_tokens},")
|
||||
# just for example
|
||||
prompt = "\n".join([f"{m['role']}:{m['content']}" for m in params.messages])
|
||||
logger.error(f"{prompt}, {params.temperature}, {params.max_tokens}")
|
||||
try:
|
||||
model = ChatOpenAI(
|
||||
streaming=True,
|
||||
verbose=True,
|
||||
openai_api_key= params.api_key,
|
||||
openai_api_base=params.api_base_url,
|
||||
model_name=params.version
|
||||
)
|
||||
chat_prompt = ChatPromptTemplate.from_messages([("human", "{input}")])
|
||||
chain = LLMChain(prompt=chat_prompt, llm=model)
|
||||
content = chain({"input": prompt})
|
||||
logger.info(content)
|
||||
except Exception as e:
|
||||
logger.error(f"{e}")
|
||||
yield {"error_code": 500, "text": "request error"}
|
||||
|
||||
# return the text by yield for stream
|
||||
try:
|
||||
yield {"error_code": 0, "text": content["text"]}
|
||||
except:
|
||||
yield {"error_code": 500, "text": "request error"}
|
||||
|
||||
def get_embeddings(self, params):
|
||||
# TODO: 支持embeddings
|
||||
print("embedding")
|
||||
print(params)
|
||||
|
||||
def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation:
|
||||
# TODO: 确认模板是否需要修改
|
||||
return conv.Conversation(
|
||||
name=self.model_names[0],
|
||||
system_message="You are a helpful, respectful and honest assistant.",
|
||||
messages=[],
|
||||
roles=["user", "assistant", "system"],
|
||||
sep="\n### ",
|
||||
stop_str="###",
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
from dev_opsgpt.utils.server_utils import MakeFastAPIOffline
|
||||
from fastchat.serve.base_model_worker import app
|
||||
|
||||
worker = ExampleWorker(
|
||||
controller_addr="http://127.0.0.1:20001",
|
||||
worker_addr="http://127.0.0.1:21008",
|
||||
)
|
||||
sys.modules["fastchat.serve.model_worker"].worker = worker
|
||||
uvicorn.run(app, port=21008)
|
|
@ -0,0 +1,242 @@
|
|||
import sys
|
||||
from fastchat.conversation import Conversation
|
||||
from .base import *
|
||||
# from server.utils import get_httpx_client
|
||||
from cachetools import cached, TTLCache
|
||||
import json
|
||||
from fastchat import conversation as conv
|
||||
import sys
|
||||
from typing import List, Literal, Dict
|
||||
from configs import logger, log_verbose
|
||||
|
||||
MODEL_VERSIONS = {
|
||||
"ernie-bot-4": "completions_pro",
|
||||
"ernie-bot": "completions",
|
||||
"ernie-bot-turbo": "eb-instant",
|
||||
"bloomz-7b": "bloomz_7b1",
|
||||
"qianfan-bloomz-7b-c": "qianfan_bloomz_7b_compressed",
|
||||
"llama2-7b-chat": "llama_2_7b",
|
||||
"llama2-13b-chat": "llama_2_13b",
|
||||
"llama2-70b-chat": "llama_2_70b",
|
||||
"qianfan-llama2-ch-7b": "qianfan_chinese_llama_2_7b",
|
||||
"chatglm2-6b-32k": "chatglm2_6b_32k",
|
||||
"aquilachat-7b": "aquilachat_7b",
|
||||
# "linly-llama2-ch-7b": "", # 暂未发布
|
||||
# "linly-llama2-ch-13b": "", # 暂未发布
|
||||
# "chatglm2-6b": "", # 暂未发布
|
||||
# "chatglm2-6b-int4": "", # 暂未发布
|
||||
# "falcon-7b": "", # 暂未发布
|
||||
# "falcon-180b-chat": "", # 暂未发布
|
||||
# "falcon-40b": "", # 暂未发布
|
||||
# "rwkv4-world": "", # 暂未发布
|
||||
# "rwkv5-world": "", # 暂未发布
|
||||
# "rwkv4-pile-14b": "", # 暂未发布
|
||||
# "rwkv4-raven-14b": "", # 暂未发布
|
||||
# "open-llama-7b": "", # 暂未发布
|
||||
# "dolly-12b": "", # 暂未发布
|
||||
# "mpt-7b-instruct": "", # 暂未发布
|
||||
# "mpt-30b-instruct": "", # 暂未发布
|
||||
# "OA-Pythia-12B-SFT-4": "", # 暂未发布
|
||||
# "xverse-13b": "", # 暂未发布
|
||||
|
||||
# # 以下为企业测试,需要单独申请
|
||||
# "flan-ul2": "",
|
||||
# "Cerebras-GPT-6.7B": ""
|
||||
# "Pythia-6.9B": ""
|
||||
}
|
||||
|
||||
|
||||
@cached(TTLCache(1, 1800)) # 经过测试,缓存的token可以使用,目前每30分钟刷新一次
|
||||
def get_baidu_access_token(api_key: str, secret_key: str) -> str:
|
||||
"""
|
||||
使用 AK,SK 生成鉴权签名(Access Token)
|
||||
:return: access_token,或是None(如果错误)
|
||||
"""
|
||||
url = "https://aip.baidubce.com/oauth/2.0/token"
|
||||
params = {"grant_type": "client_credentials", "client_id": api_key, "client_secret": secret_key}
|
||||
try:
|
||||
with get_httpx_client() as client:
|
||||
return client.get(url, params=params).json().get("access_token")
|
||||
except Exception as e:
|
||||
print(f"failed to get token from baidu: {e}")
|
||||
|
||||
|
||||
class QianFanWorker(ApiModelWorker):
|
||||
"""
|
||||
百度千帆
|
||||
"""
|
||||
DEFAULT_EMBED_MODEL = "embedding-v1"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
version: Literal["ernie-bot", "ernie-bot-turbo"] = "ernie-bot",
|
||||
model_names: List[str] = ["qianfan-api"],
|
||||
controller_addr: str = None,
|
||||
worker_addr: str = None,
|
||||
**kwargs,
|
||||
):
|
||||
kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr)
|
||||
kwargs.setdefault("context_len", 16384)
|
||||
super().__init__(**kwargs)
|
||||
self.version = version
|
||||
|
||||
def do_chat(self, params: ApiChatParams) -> Dict:
|
||||
params.load_config(self.model_names[0])
|
||||
# import qianfan
|
||||
|
||||
# comp = qianfan.ChatCompletion(model=params.version,
|
||||
# endpoint=params.version_url,
|
||||
# ak=params.api_key,
|
||||
# sk=params.secret_key,)
|
||||
# text = ""
|
||||
# for resp in comp.do(messages=params.messages,
|
||||
# temperature=params.temperature,
|
||||
# top_p=params.top_p,
|
||||
# stream=True):
|
||||
# if resp.code == 200:
|
||||
# if chunk := resp.body.get("result"):
|
||||
# text += chunk
|
||||
# yield {
|
||||
# "error_code": 0,
|
||||
# "text": text
|
||||
# }
|
||||
# else:
|
||||
# yield {
|
||||
# "error_code": resp.code,
|
||||
# "text": str(resp.body),
|
||||
# }
|
||||
|
||||
BASE_URL = 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat' \
|
||||
'/{model_version}?access_token={access_token}'
|
||||
|
||||
access_token = get_baidu_access_token(params.api_key, params.secret_key)
|
||||
if not access_token:
|
||||
yield {
|
||||
"error_code": 403,
|
||||
"text": f"failed to get access token. have you set the correct api_key and secret key?",
|
||||
}
|
||||
|
||||
url = BASE_URL.format(
|
||||
model_version=params.version_url or MODEL_VERSIONS[params.version.lower()],
|
||||
access_token=access_token,
|
||||
)
|
||||
payload = {
|
||||
"messages": params.messages,
|
||||
"temperature": params.temperature,
|
||||
"stream": True
|
||||
}
|
||||
headers = {
|
||||
'Content-Type': 'application/json',
|
||||
'Accept': 'application/json',
|
||||
}
|
||||
|
||||
text = ""
|
||||
if log_verbose:
|
||||
logger.info(f'{self.__class__.__name__}:data: {payload}')
|
||||
logger.info(f'{self.__class__.__name__}:url: {url}')
|
||||
logger.info(f'{self.__class__.__name__}:headers: {headers}')
|
||||
|
||||
with get_httpx_client() as client:
|
||||
with client.stream("POST", url, headers=headers, json=payload) as response:
|
||||
for line in response.iter_lines():
|
||||
if not line.strip():
|
||||
continue
|
||||
if line.startswith("data: "):
|
||||
line = line[6:]
|
||||
resp = json.loads(line)
|
||||
|
||||
if "result" in resp.keys():
|
||||
text += resp["result"]
|
||||
yield {
|
||||
"error_code": 0,
|
||||
"text": text
|
||||
}
|
||||
else:
|
||||
data = {
|
||||
"error_code": resp["error_code"],
|
||||
"text": resp["error_msg"],
|
||||
"error": {
|
||||
"message": resp["error_msg"],
|
||||
"type": "invalid_request_error",
|
||||
"param": None,
|
||||
"code": None,
|
||||
}
|
||||
}
|
||||
self.logger.error(f"请求千帆 API 时发生错误:{data}")
|
||||
yield data
|
||||
|
||||
def do_embeddings(self, params: ApiEmbeddingsParams) -> Dict:
|
||||
params.load_config(self.model_names[0])
|
||||
# import qianfan
|
||||
|
||||
# embed = qianfan.Embedding(ak=params.api_key, sk=params.secret_key)
|
||||
# resp = embed.do(texts = params.texts, model=params.embed_model or self.DEFAULT_EMBED_MODEL)
|
||||
# if resp.code == 200:
|
||||
# embeddings = [x.embedding for x in resp.body.get("data", [])]
|
||||
# return {"code": 200, "embeddings": embeddings}
|
||||
# else:
|
||||
# return {"code": resp.code, "msg": str(resp.body)}
|
||||
|
||||
embed_model = params.embed_model or self.DEFAULT_EMBED_MODEL
|
||||
access_token = get_baidu_access_token(params.api_key, params.secret_key)
|
||||
url = f"https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/{embed_model}?access_token={access_token}"
|
||||
if log_verbose:
|
||||
logger.info(f'{self.__class__.__name__}:url: {url}')
|
||||
|
||||
with get_httpx_client() as client:
|
||||
result = []
|
||||
i = 0
|
||||
batch_size = 10
|
||||
while i < len(params.texts):
|
||||
texts = params.texts[i:i+batch_size]
|
||||
resp = client.post(url, json={"input": texts}).json()
|
||||
if "error_code" in resp:
|
||||
data = {
|
||||
"code": resp["error_code"],
|
||||
"msg": resp["error_msg"],
|
||||
"error": {
|
||||
"message": resp["error_msg"],
|
||||
"type": "invalid_request_error",
|
||||
"param": None,
|
||||
"code": None,
|
||||
}
|
||||
}
|
||||
self.logger.error(f"请求千帆 API 时发生错误:{data}")
|
||||
return data
|
||||
else:
|
||||
embeddings = [x["embedding"] for x in resp.get("data", [])]
|
||||
result += embeddings
|
||||
i += batch_size
|
||||
return {"code": 200, "data": result}
|
||||
|
||||
# TODO: qianfan支持续写模型
|
||||
def get_embeddings(self, params):
|
||||
# TODO: 支持embeddings
|
||||
print("embedding")
|
||||
print(params)
|
||||
|
||||
def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation:
|
||||
# TODO: 确认模板是否需要修改
|
||||
return conv.Conversation(
|
||||
name=self.model_names[0],
|
||||
system_message="你是一个聪明的助手,请根据用户的提示来完成任务",
|
||||
messages=[],
|
||||
roles=["user", "assistant"],
|
||||
sep="\n### ",
|
||||
stop_str="###",
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
from server.utils import MakeFastAPIOffline
|
||||
from fastchat.serve.model_worker import app
|
||||
|
||||
worker = QianFanWorker(
|
||||
controller_addr="http://127.0.0.1:20001",
|
||||
worker_addr="http://127.0.0.1:21004"
|
||||
)
|
||||
sys.modules["fastchat.serve.model_worker"].worker = worker
|
||||
MakeFastAPIOffline(app)
|
||||
uvicorn.run(app, port=21004)
|
|
@ -0,0 +1,128 @@
|
|||
import json
|
||||
import sys
|
||||
|
||||
from fastchat.conversation import Conversation
|
||||
from http import HTTPStatus
|
||||
from typing import List, Literal, Dict
|
||||
|
||||
from fastchat import conversation as conv
|
||||
from .base import *
|
||||
from configs import logger, log_verbose
|
||||
|
||||
|
||||
class QwenWorker(ApiModelWorker):
|
||||
DEFAULT_EMBED_MODEL = "text-embedding-v1"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
version: Literal["qwen-turbo", "qwen-plus"] = "qwen-turbo",
|
||||
model_names: List[str] = ["qwen-api"],
|
||||
controller_addr: str = None,
|
||||
worker_addr: str = None,
|
||||
**kwargs,
|
||||
):
|
||||
kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr)
|
||||
kwargs.setdefault("context_len", 16384)
|
||||
super().__init__(**kwargs)
|
||||
self.version = version
|
||||
|
||||
def do_chat(self, params: ApiChatParams) -> Dict:
|
||||
import dashscope
|
||||
params.load_config(self.model_names[0])
|
||||
if log_verbose:
|
||||
logger.info(f'{self.__class__.__name__}:params: {params}')
|
||||
|
||||
gen = dashscope.Generation()
|
||||
responses = gen.call(
|
||||
model=params.version,
|
||||
temperature=params.temperature,
|
||||
api_key=params.api_key,
|
||||
messages=params.messages,
|
||||
result_format='message', # set the result is message format.
|
||||
stream=True,
|
||||
)
|
||||
|
||||
for resp in responses:
|
||||
if resp["status_code"] == 200:
|
||||
if choices := resp["output"]["choices"]:
|
||||
yield {
|
||||
"error_code": 0,
|
||||
"text": choices[0]["message"]["content"],
|
||||
}
|
||||
else:
|
||||
data = {
|
||||
"error_code": resp["status_code"],
|
||||
"text": resp["message"],
|
||||
"error": {
|
||||
"message": resp["message"],
|
||||
"type": "invalid_request_error",
|
||||
"param": None,
|
||||
"code": None,
|
||||
}
|
||||
}
|
||||
self.logger.error(f"请求千问 API 时发生错误:{data}")
|
||||
yield data
|
||||
|
||||
def do_embeddings(self, params: ApiEmbeddingsParams) -> Dict:
|
||||
import dashscope
|
||||
params.load_config(self.model_names[0])
|
||||
if log_verbose:
|
||||
logger.info(f'{self.__class__.__name__}:params: {params}')
|
||||
result = []
|
||||
i = 0
|
||||
while i < len(params.texts):
|
||||
texts = params.texts[i:i+25]
|
||||
resp = dashscope.TextEmbedding.call(
|
||||
model=params.embed_model or self.DEFAULT_EMBED_MODEL,
|
||||
input=texts, # 最大25行
|
||||
api_key=params.api_key,
|
||||
)
|
||||
if resp["status_code"] != 200:
|
||||
data = {
|
||||
"code": resp["status_code"],
|
||||
"msg": resp.message,
|
||||
"error": {
|
||||
"message": resp["message"],
|
||||
"type": "invalid_request_error",
|
||||
"param": None,
|
||||
"code": None,
|
||||
}
|
||||
}
|
||||
self.logger.error(f"请求千问 API 时发生错误:{data}")
|
||||
return data
|
||||
else:
|
||||
embeddings = [x["embedding"] for x in resp["output"]["embeddings"]]
|
||||
result += embeddings
|
||||
i += 25
|
||||
return {"code": 200, "data": result}
|
||||
|
||||
def get_embeddings(self, params):
|
||||
# TODO: 支持embeddings
|
||||
print("embedding")
|
||||
print(params)
|
||||
|
||||
def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation:
|
||||
# TODO: 确认模板是否需要修改
|
||||
return conv.Conversation(
|
||||
name=self.model_names[0],
|
||||
system_message="你是一个聪明、对人类有帮助的人工智能,你可以对人类提出的问题给出有用、详细、礼貌的回答。",
|
||||
messages=[],
|
||||
roles=["user", "assistant", "system"],
|
||||
sep="\n### ",
|
||||
stop_str="###",
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
from server.utils import MakeFastAPIOffline
|
||||
from fastchat.serve.model_worker import app
|
||||
|
||||
worker = QwenWorker(
|
||||
controller_addr="http://127.0.0.1:20001",
|
||||
worker_addr="http://127.0.0.1:20007",
|
||||
)
|
||||
sys.modules["fastchat.serve.model_worker"].worker = worker
|
||||
MakeFastAPIOffline(app)
|
||||
uvicorn.run(app, port=20007)
|
|
@ -0,0 +1,88 @@
|
|||
import json
|
||||
import time
|
||||
import hashlib
|
||||
|
||||
from fastchat.conversation import Conversation
|
||||
from .base import *
|
||||
from fastchat import conversation as conv
|
||||
import json
|
||||
from typing import List, Literal, Dict
|
||||
import requests
|
||||
|
||||
|
||||
|
||||
class TianGongWorker(ApiModelWorker):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
controller_addr: str = None,
|
||||
worker_addr: str = None,
|
||||
model_names: List[str] = ["tiangong-api"],
|
||||
version: Literal["SkyChat-MegaVerse"] = "SkyChat-MegaVerse",
|
||||
**kwargs,
|
||||
):
|
||||
kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr)
|
||||
kwargs.setdefault("context_len", 32768)
|
||||
super().__init__(**kwargs)
|
||||
self.version = version
|
||||
|
||||
def do_chat(self, params: ApiChatParams) -> Dict:
|
||||
params.load_config(self.model_names[0])
|
||||
|
||||
url = 'https://sky-api.singularity-ai.com/saas/api/v4/generate'
|
||||
data = {
|
||||
"messages": params.messages,
|
||||
"model": "SkyChat-MegaVerse"
|
||||
}
|
||||
timestamp = str(int(time.time()))
|
||||
sign_content = params.api_key + params.secret_key + timestamp
|
||||
sign_result = hashlib.md5(sign_content.encode('utf-8')).hexdigest()
|
||||
headers={
|
||||
"app_key": params.api_key,
|
||||
"timestamp": timestamp,
|
||||
"sign": sign_result,
|
||||
"Content-Type": "application/json",
|
||||
"stream": "true" # or change to "false" 不处理流式返回内容
|
||||
}
|
||||
|
||||
# 发起请求并获取响应
|
||||
response = requests.post(url, headers=headers, json=data, stream=True)
|
||||
|
||||
text = ""
|
||||
# 处理响应流
|
||||
for line in response.iter_lines(chunk_size=None, decode_unicode=True):
|
||||
if line:
|
||||
# 处理接收到的数据
|
||||
# print(line.decode('utf-8'))
|
||||
resp = json.loads(line)
|
||||
if resp["code"] == 200:
|
||||
text += resp['resp_data']['reply']
|
||||
yield {
|
||||
"error_code": 0,
|
||||
"text": text
|
||||
}
|
||||
else:
|
||||
data = {
|
||||
"error_code": resp["code"],
|
||||
"text": resp["code_msg"]
|
||||
}
|
||||
self.logger.error(f"请求天工 API 时出错:{data}")
|
||||
yield data
|
||||
|
||||
def get_embeddings(self, params):
|
||||
# TODO: 支持embeddings
|
||||
print("embedding")
|
||||
print(params)
|
||||
|
||||
def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation:
|
||||
# TODO: 确认模板是否需要修改
|
||||
return conv.Conversation(
|
||||
name=self.model_names[0],
|
||||
system_message="",
|
||||
messages=[],
|
||||
roles=["user", "system"],
|
||||
sep="\n### ",
|
||||
stop_str="###",
|
||||
)
|
||||
|
||||
|
|
@ -0,0 +1,105 @@
|
|||
from fastchat.conversation import Conversation
|
||||
from .base import *
|
||||
from fastchat import conversation as conv
|
||||
import sys
|
||||
import json
|
||||
from model_workers import SparkApi
|
||||
import websockets
|
||||
from dev_opsgpt.utils.server_utils import run_async, iter_over_async
|
||||
from typing import List, Dict
|
||||
import asyncio
|
||||
|
||||
|
||||
|
||||
async def request(appid, api_key, api_secret, Spark_url, domain, question, temperature, max_token):
|
||||
wsParam = SparkApi.Ws_Param(appid, api_key, api_secret, Spark_url)
|
||||
wsUrl = wsParam.create_url()
|
||||
data = SparkApi.gen_params(appid, domain, question, temperature, max_token)
|
||||
print(data)
|
||||
async with websockets.connect(wsUrl) as ws:
|
||||
await ws.send(json.dumps(data, ensure_ascii=False))
|
||||
finish = False
|
||||
while not finish:
|
||||
chunk = await ws.recv()
|
||||
response = json.loads(chunk)
|
||||
if response.get("header", {}).get("status") == 2:
|
||||
finish = True
|
||||
if text := response.get("payload", {}).get("choices", {}).get("text"):
|
||||
yield text[0]["content"]
|
||||
|
||||
|
||||
class XingHuoWorker(ApiModelWorker):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
model_names: List[str] = ["xinghuo-api"],
|
||||
controller_addr: str = None,
|
||||
worker_addr: str = None,
|
||||
version: str = None,
|
||||
**kwargs,
|
||||
):
|
||||
kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr)
|
||||
kwargs.setdefault("context_len", 8000) # TODO: V1模型的最大长度为4000,需要自行修改
|
||||
super().__init__(**kwargs)
|
||||
self.version = version
|
||||
|
||||
def do_chat(self, params: ApiChatParams) -> Dict:
|
||||
# TODO: 当前每次对话都要重新连接websocket,确认是否可以保持连接
|
||||
params.load_config(self.model_names[0])
|
||||
|
||||
version_mapping = {
|
||||
"v1.5": {"domain": "general", "url": "ws://spark-api.xf-yun.com/v1.1/chat","max_tokens": 4000},
|
||||
"v2.0": {"domain": "generalv2", "url": "ws://spark-api.xf-yun.com/v2.1/chat","max_tokens": 8000},
|
||||
"v3.0": {"domain": "generalv3", "url": "ws://spark-api.xf-yun.com/v3.1/chat","max_tokens": 8000},
|
||||
}
|
||||
|
||||
def get_version_details(version_key):
|
||||
return version_mapping.get(version_key, {"domain": None, "url": None})
|
||||
|
||||
details = get_version_details(params.version)
|
||||
domain = details["domain"]
|
||||
Spark_url = details["url"]
|
||||
text = ""
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
except:
|
||||
loop = asyncio.new_event_loop()
|
||||
params.max_tokens = min(details["max_tokens"], params.max_tokens or 0)
|
||||
for chunk in iter_over_async(
|
||||
request(params.APPID, params.api_key, params.APISecret, Spark_url, domain, params.messages,
|
||||
params.temperature, params.max_tokens),
|
||||
loop=loop,
|
||||
):
|
||||
if chunk:
|
||||
text += chunk
|
||||
yield {"error_code": 0, "text": text}
|
||||
|
||||
def get_embeddings(self, params):
|
||||
# TODO: 支持embeddings
|
||||
print("embedding")
|
||||
print(params)
|
||||
|
||||
def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation:
|
||||
# TODO: 确认模板是否需要修改
|
||||
return conv.Conversation(
|
||||
name=self.model_names[0],
|
||||
system_message="你是一个聪明的助手,请根据用户的提示来完成任务",
|
||||
messages=[],
|
||||
roles=["user", "assistant"],
|
||||
sep="\n### ",
|
||||
stop_str="###",
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
from server.utils import MakeFastAPIOffline
|
||||
from fastchat.serve.model_worker import app
|
||||
|
||||
worker = XingHuoWorker(
|
||||
controller_addr="http://127.0.0.1:20001",
|
||||
worker_addr="http://127.0.0.1:21003",
|
||||
)
|
||||
sys.modules["fastchat.serve.model_worker"].worker = worker
|
||||
MakeFastAPIOffline(app)
|
||||
uvicorn.run(app, port=21003)
|
|
@ -0,0 +1,110 @@
|
|||
from fastchat.conversation import Conversation
|
||||
from .base import *
|
||||
from fastchat import conversation as conv
|
||||
import sys
|
||||
from typing import List, Dict, Iterator, Literal
|
||||
from configs import logger, log_verbose
|
||||
|
||||
|
||||
class ChatGLMWorker(ApiModelWorker):
|
||||
DEFAULT_EMBED_MODEL = "text_embedding"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
model_names: List[str] = ["zhipu-api"],
|
||||
controller_addr: str = None,
|
||||
worker_addr: str = None,
|
||||
version: Literal["chatglm_turbo"] = "chatglm_turbo",
|
||||
**kwargs,
|
||||
):
|
||||
kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr)
|
||||
kwargs.setdefault("context_len", 32768)
|
||||
super().__init__(**kwargs)
|
||||
self.version = version
|
||||
|
||||
def do_chat(self, params: ApiChatParams) -> Iterator[Dict]:
|
||||
# TODO: 维护request_id
|
||||
import zhipuai
|
||||
|
||||
params.load_config(self.model_names[0])
|
||||
zhipuai.api_key = params.api_key
|
||||
|
||||
if log_verbose:
|
||||
logger.info(f'{self.__class__.__name__}:params: {params}')
|
||||
|
||||
response = zhipuai.model_api.sse_invoke(
|
||||
model=params.version,
|
||||
prompt=params.messages,
|
||||
temperature=params.temperature,
|
||||
top_p=params.top_p,
|
||||
incremental=False,
|
||||
)
|
||||
for e in response.events():
|
||||
if e.event == "add":
|
||||
yield {"error_code": 0, "text": e.data}
|
||||
elif e.event in ["error", "interrupted"]:
|
||||
data = {
|
||||
"error_code": 500,
|
||||
"text": str(e),
|
||||
"error": {
|
||||
"message": str(e),
|
||||
"type": "invalid_request_error",
|
||||
"param": None,
|
||||
"code": None,
|
||||
}
|
||||
}
|
||||
self.logger.error(f"请求智谱 API 时发生错误:{data}")
|
||||
yield data
|
||||
|
||||
def do_embeddings(self, params: ApiEmbeddingsParams) -> Dict:
|
||||
import zhipuai
|
||||
|
||||
params.load_config(self.model_names[0])
|
||||
zhipuai.api_key = params.api_key
|
||||
|
||||
embeddings = []
|
||||
try:
|
||||
for t in params.texts:
|
||||
response = zhipuai.model_api.invoke(model=params.embed_model or self.DEFAULT_EMBED_MODEL, prompt=t)
|
||||
if response["code"] == 200:
|
||||
embeddings.append(response["data"]["embedding"])
|
||||
else:
|
||||
self.logger.error(f"请求智谱 API 时发生错误:{response}")
|
||||
return response # dict with code & msg
|
||||
except Exception as e:
|
||||
self.logger.error(f"请求智谱 API 时发生错误:{data}")
|
||||
data = {"code": 500, "msg": f"对文本向量化时出错:{e}"}
|
||||
return data
|
||||
|
||||
return {"code": 200, "data": embeddings}
|
||||
|
||||
def get_embeddings(self, params):
|
||||
# TODO: 支持embeddings
|
||||
print("embedding")
|
||||
# print(params)
|
||||
|
||||
def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation:
|
||||
# 这里的是chatglm api的模板,其它API的conv_template需要定制
|
||||
return conv.Conversation(
|
||||
name=self.model_names[0],
|
||||
system_message="你是一个聪明的助手,请根据用户的提示来完成任务",
|
||||
messages=[],
|
||||
roles=["Human", "Assistant", "System"],
|
||||
sep="\n###",
|
||||
stop_str="###",
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
from server.utils import MakeFastAPIOffline
|
||||
from fastchat.serve.model_worker import app
|
||||
|
||||
worker = ChatGLMWorker(
|
||||
controller_addr="http://127.0.0.1:20001",
|
||||
worker_addr="http://127.0.0.1:21001",
|
||||
)
|
||||
sys.modules["fastchat.serve.model_worker"].worker = worker
|
||||
MakeFastAPIOffline(app)
|
||||
uvicorn.run(app, port=21001)
|
|
@ -0,0 +1,39 @@
|
|||
import os
|
||||
|
||||
from configs.model_config import ONLINE_LLM_MODEL
|
||||
from configs.server_config import FSCHAT_MODEL_WORKERS
|
||||
from configs.model_config import llm_model_dict, LLM_DEVICE
|
||||
|
||||
from loguru import logger
|
||||
|
||||
|
||||
|
||||
def get_model_worker_config(model_name: str = None) -> dict:
|
||||
'''
|
||||
加载model worker的配置项。
|
||||
优先级:FSCHAT_MODEL_WORKERS[model_name] > ONLINE_LLM_MODEL[model_name] > FSCHAT_MODEL_WORKERS["default"]
|
||||
'''
|
||||
from dev_opsgpt.service import model_workers
|
||||
|
||||
config = FSCHAT_MODEL_WORKERS.get("default", {}).copy()
|
||||
config.update(ONLINE_LLM_MODEL.get(model_name, {}).copy())
|
||||
config.update(FSCHAT_MODEL_WORKERS.get(model_name, {}).copy())
|
||||
|
||||
if model_name in ONLINE_LLM_MODEL:
|
||||
config["online_api"] = True
|
||||
if provider := config.get("provider"):
|
||||
try:
|
||||
config["worker_class"] = getattr(model_workers, provider)
|
||||
except Exception as e:
|
||||
msg = f"在线模型 ‘{model_name}’ 的provider没有正确配置"
|
||||
logger.error(f'{e.__class__.__name__}: {msg}')
|
||||
# 本地模型
|
||||
if model_name in llm_model_dict:
|
||||
path = llm_model_dict[model_name]["local_model_path"]
|
||||
config["model_path"] = path
|
||||
if path and os.path.isdir(path):
|
||||
config["model_path_exists"] = True
|
||||
config["device"] = LLM_DEVICE
|
||||
|
||||
# logger.debug(f"config: {config}")
|
||||
return config
|
|
@ -11,12 +11,13 @@ from .docs_retrieval import DocRetrieval
|
|||
from .cb_query_tool import CodeRetrieval
|
||||
from .ocr_tool import BaiduOcrTool
|
||||
from .stock_tool import StockInfo, StockName
|
||||
from .codechat_tools import CodeRetrievalSingle, RelatedVerticesRetrival, Vertex2Code
|
||||
|
||||
|
||||
IMPORT_TOOL = [
|
||||
WeatherInfo, DistrictInfo, Multiplier, WorldTimeGetTimezoneByArea,
|
||||
KSigmaDetector, MetricsQuery, DDGSTool, DocRetrieval, CodeRetrieval,
|
||||
BaiduOcrTool, StockInfo, StockName
|
||||
BaiduOcrTool, StockInfo, StockName, CodeRetrievalSingle, RelatedVerticesRetrival, Vertex2Code
|
||||
]
|
||||
|
||||
TOOL_SETS = [tool.__name__ for tool in IMPORT_TOOL]
|
||||
|
|
|
@ -58,3 +58,5 @@ class CodeRetrieval(BaseToolModel):
|
|||
return_codes.append({'index': 0, 'code': context, "related_nodes": related_nodes})
|
||||
|
||||
return return_codes
|
||||
|
||||
|
||||
|
|
|
@ -0,0 +1,110 @@
|
|||
# encoding: utf-8
|
||||
'''
|
||||
@author: 温进
|
||||
@file: codechat_tools.py.py
|
||||
@time: 2023/12/14 上午10:24
|
||||
@desc:
|
||||
'''
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import List, Dict
|
||||
import requests
|
||||
import numpy as np
|
||||
from loguru import logger
|
||||
|
||||
from configs.model_config import (
|
||||
CODE_SEARCH_TOP_K)
|
||||
from .base_tool import BaseToolModel
|
||||
|
||||
from dev_opsgpt.service.cb_api import search_code, search_related_vertices, search_code_by_vertex
|
||||
|
||||
|
||||
# 问题进来
|
||||
# 调用函数 0:输入问题,输出代码文件名 1 和 代码文件 1
|
||||
#
|
||||
# agent 1
|
||||
# 1. LLM:代码+问题 输出:是否能解决
|
||||
#
|
||||
# agent 2
|
||||
# 1. 调用函数 1 :输入:代码文件名 1 输出:代码文件名列表
|
||||
# 2. LLM:输入代码文件 1, 问题,代码文件名列表,输出:代码文件名 2
|
||||
# 3. 调用函数 2: 输入 :代码文件名 2 输出:代码文件 2
|
||||
|
||||
|
||||
class CodeRetrievalSingle(BaseToolModel):
|
||||
name = "CodeRetrievalOneCode"
|
||||
description = "输入用户的问题,输出一个代码文件名和代码文件"
|
||||
|
||||
class ToolInputArgs(BaseModel):
|
||||
query: str = Field(..., description="检索的问题")
|
||||
code_base_name: str = Field(..., description="代码库名称", examples=["samples"])
|
||||
|
||||
class ToolOutputArgs(BaseModel):
|
||||
"""Output for MetricsQuery."""
|
||||
code: str = Field(..., description="检索代码")
|
||||
vertex: str = Field(..., description="代码对应 id")
|
||||
|
||||
@classmethod
|
||||
def run(cls, code_base_name, query):
|
||||
"""excute your tool!"""
|
||||
|
||||
search_type = 'description'
|
||||
code_limit = 1
|
||||
|
||||
# default
|
||||
search_result = search_code(code_base_name, query, code_limit, search_type=search_type,
|
||||
history_node_list=[])
|
||||
|
||||
logger.debug(search_result)
|
||||
code = search_result['context']
|
||||
vertex = search_result['related_vertices'][0]
|
||||
# logger.debug(f"code: {code}, vertex: {vertex}")
|
||||
|
||||
res = {
|
||||
'code': code,
|
||||
'vertex': vertex
|
||||
}
|
||||
|
||||
return res
|
||||
|
||||
|
||||
class RelatedVerticesRetrival(BaseToolModel):
|
||||
name = "RelatedVerticesRetrival"
|
||||
description = "输入代码节点名,返回相连的节点名"
|
||||
|
||||
class ToolInputArgs(BaseModel):
|
||||
code_base_name: str = Field(..., description="代码库名称", examples=["samples"])
|
||||
vertex: str = Field(..., description="节点名", examples=["samples"])
|
||||
|
||||
class ToolOutputArgs(BaseModel):
|
||||
"""Output for MetricsQuery."""
|
||||
vertices: list = Field(..., description="相连节点名")
|
||||
|
||||
@classmethod
|
||||
def run(cls, code_base_name: str, vertex: str):
|
||||
"""execute your tool!"""
|
||||
related_vertices = search_related_vertices(cb_name=code_base_name, vertex=vertex)
|
||||
logger.debug(f"related_vertices: {related_vertices}")
|
||||
|
||||
return related_vertices
|
||||
|
||||
|
||||
class Vertex2Code(BaseToolModel):
|
||||
name = "Vertex2Code"
|
||||
description = "输入代码节点名,返回对应的代码文件"
|
||||
|
||||
class ToolInputArgs(BaseModel):
|
||||
code_base_name: str = Field(..., description="代码库名称", examples=["samples"])
|
||||
vertex: str = Field(..., description="节点名", examples=["samples"])
|
||||
|
||||
class ToolOutputArgs(BaseModel):
|
||||
"""Output for MetricsQuery."""
|
||||
code: str = Field(..., description="代码名")
|
||||
|
||||
@classmethod
|
||||
def run(cls, code_base_name: str, vertex: str):
|
||||
"""execute your tool!"""
|
||||
res = search_code_by_vertex(cb_name=code_base_name, vertex=vertex)
|
||||
return res
|
|
@ -15,9 +15,24 @@ chat_box = ChatBox(
|
|||
assistant_avatar="../sources/imgs/devops-chatbot2.png"
|
||||
)
|
||||
|
||||
cur_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
|
||||
GLOBAL_EXE_CODE_TEXT = ""
|
||||
GLOBAL_MESSAGE = {"figures": {}, "final_contents": {}}
|
||||
|
||||
|
||||
|
||||
import yaml
|
||||
|
||||
# 加载YAML文件
|
||||
webui_yaml_filename = "webui_zh.yaml" if True else "webui_en.yaml"
|
||||
with open(os.path.join(cur_dir, f"yamls/{webui_yaml_filename}"), 'r') as f:
|
||||
try:
|
||||
webui_configs = yaml.safe_load(f)
|
||||
except yaml.YAMLError as exc:
|
||||
print(exc)
|
||||
|
||||
|
||||
def get_messages_history(history_len: int, isDetailed=False) -> List[Dict]:
|
||||
def filter(msg):
|
||||
'''
|
||||
|
@ -55,12 +70,6 @@ def upload2sandbox(upload_file, api: ApiRequest):
|
|||
res = {"msg": False}
|
||||
else:
|
||||
res = api.web_sd_upload(upload_file)
|
||||
# logger.debug(res)
|
||||
# if res["msg"]:
|
||||
# st.success("上文件传成功")
|
||||
# else:
|
||||
# st.toast("文件上传失败")
|
||||
|
||||
|
||||
def dialogue_page(api: ApiRequest):
|
||||
global GLOBAL_EXE_CODE_TEXT
|
||||
|
@ -70,33 +79,31 @@ def dialogue_page(api: ApiRequest):
|
|||
# TODO: 对话模型与会话绑定
|
||||
def on_mode_change():
|
||||
mode = st.session_state.dialogue_mode
|
||||
text = f"已切换到 {mode} 模式。"
|
||||
if mode == "知识库问答":
|
||||
text = webui_configs["dialogue"]["text_mode_swtich"] + f"{mode}"
|
||||
if mode == webui_configs["dialogue"]["mode"][1]:
|
||||
cur_kb = st.session_state.get("selected_kb")
|
||||
if cur_kb:
|
||||
text = f"{text} 当前知识库: `{cur_kb}`。"
|
||||
text = text + webui_configs["dialogue"]["text_knowledgeBase_swtich"] + f'`{cur_kb}`'
|
||||
st.toast(text)
|
||||
# sac.alert(text, description="descp", type="success", closable=True, banner=True)
|
||||
|
||||
dialogue_mode = st.selectbox("请选择对话模式",
|
||||
["LLM 对话",
|
||||
"知识库问答",
|
||||
"代码知识库问答",
|
||||
"搜索引擎问答",
|
||||
"Agent问答"
|
||||
],
|
||||
dialogue_mode = st.selectbox(webui_configs["dialogue"]["mode_instruction"],
|
||||
webui_configs["dialogue"]["mode"],
|
||||
# ["LLM 对话",
|
||||
# "知识库问答",
|
||||
# "代码知识库问答",
|
||||
# "搜索引擎问答",
|
||||
# "Agent问答"
|
||||
# ],
|
||||
on_change=on_mode_change,
|
||||
key="dialogue_mode",
|
||||
)
|
||||
history_len = st.number_input("历史对话轮数:", 0, 10, 3)
|
||||
|
||||
# todo: support history len
|
||||
history_len = st.number_input(webui_configs["dialogue"]["history_length"], 0, 10, 3)
|
||||
|
||||
def on_kb_change():
|
||||
st.toast(f"已加载知识库: {st.session_state.selected_kb}")
|
||||
st.toast(f"{webui_configs['dialogue']['text_loaded_kbase']}: {st.session_state.selected_kb}")
|
||||
|
||||
def on_cb_change():
|
||||
st.toast(f"已加载代码知识库: {st.session_state.selected_cb}")
|
||||
st.toast(f"{webui_configs['dialogue']['text_loaded_cbase']}: {st.session_state.selected_cb}")
|
||||
cb_details = get_cb_details_by_cb_name(st.session_state.selected_cb)
|
||||
st.session_state['do_interpret'] = cb_details['do_interpret']
|
||||
|
||||
|
@ -107,114 +114,140 @@ def dialogue_page(api: ApiRequest):
|
|||
not_agent_qa = True
|
||||
interpreter_file = ""
|
||||
is_detailed = False
|
||||
if dialogue_mode == "知识库问答":
|
||||
with st.expander("知识库配置", True):
|
||||
if dialogue_mode == webui_configs["dialogue"]["mode"][1]:
|
||||
with st.expander(webui_configs["dialogue"]["kbase_expander_name"], True):
|
||||
kb_list = api.list_knowledge_bases(no_remote_api=True)
|
||||
selected_kb = st.selectbox(
|
||||
"请选择知识库:",
|
||||
webui_configs["dialogue"]["kbase_selectbox_name"],
|
||||
kb_list,
|
||||
on_change=on_kb_change,
|
||||
key="selected_kb",
|
||||
)
|
||||
kb_top_k = st.number_input("匹配知识条数:", 1, 20, 3)
|
||||
score_threshold = st.number_input("知识匹配分数阈值:", 0.0, float(SCORE_THRESHOLD), float(SCORE_THRESHOLD), float(SCORE_THRESHOLD//100))
|
||||
# chunk_content = st.checkbox("关联上下文", False, disabled=True)
|
||||
# chunk_size = st.slider("关联长度:", 0, 500, 250, disabled=True)
|
||||
elif dialogue_mode == '代码知识库问答':
|
||||
with st.expander('代码知识库配置', True):
|
||||
kb_top_k = st.number_input(
|
||||
webui_configs["dialogue"]["kbase_ninput_topk_name"], 1, 20, 3)
|
||||
score_threshold = st.number_input(
|
||||
webui_configs["dialogue"]["kbase_ninput_score_threshold_name"],
|
||||
0.0, float(SCORE_THRESHOLD), float(SCORE_THRESHOLD),
|
||||
float(SCORE_THRESHOLD//100))
|
||||
|
||||
elif dialogue_mode == webui_configs["dialogue"]["mode"][2]:
|
||||
with st.expander(webui_configs["dialogue"]["cbase_expander_name"], True):
|
||||
cb_list = api.list_cb(no_remote_api=True)
|
||||
logger.debug('codebase_list={}'.format(cb_list))
|
||||
selected_cb = st.selectbox(
|
||||
"请选择代码知识库:",
|
||||
webui_configs["dialogue"]["cbase_selectbox_name"],
|
||||
cb_list,
|
||||
on_change=on_cb_change,
|
||||
key="selected_cb",
|
||||
)
|
||||
|
||||
# change do_interpret
|
||||
st.toast(f"已加载代码知识库: {st.session_state.selected_cb}")
|
||||
st.toast(f"{webui_configs['dialogue']['text_loaded_cbase']}: {st.session_state.selected_cb}")
|
||||
cb_details = get_cb_details_by_cb_name(st.session_state.selected_cb)
|
||||
st.session_state['do_interpret'] = cb_details['do_interpret']
|
||||
|
||||
cb_code_limit = st.number_input("匹配代码条数:", 1, 20, 1)
|
||||
cb_code_limit = st.number_input(
|
||||
webui_configs["dialogue"]["cbase_ninput_topk_name"], 1, 20, 1)
|
||||
|
||||
search_type_list = ['基于 cypher', '基于标签', '基于描述'] if st.session_state['do_interpret'] == 'YES' \
|
||||
else ['基于 cypher', '基于标签']
|
||||
search_type_list = webui_configs["dialogue"]["cbase_search_type_v1"] if st.session_state['do_interpret'] == 'YES' \
|
||||
else webui_configs["dialogue"]["cbase_search_type_v2"]
|
||||
|
||||
cb_search_type = st.selectbox(
|
||||
'请选择查询模式:',
|
||||
webui_configs["dialogue"]["cbase_selectbox_type_name"],
|
||||
search_type_list,
|
||||
key='cb_search_type'
|
||||
)
|
||||
elif dialogue_mode == "搜索引擎问答":
|
||||
with st.expander("搜索引擎配置", True):
|
||||
search_engine = st.selectbox("请选择搜索引擎", SEARCH_ENGINES.keys(), 0)
|
||||
se_top_k = st.number_input("匹配搜索结果条数:", 1, 20, 3)
|
||||
elif dialogue_mode == "Agent问答":
|
||||
elif dialogue_mode == webui_configs["dialogue"]["mode"][3]:
|
||||
with st.expander(webui_configs["dialogue"]["expander_search_name"], True):
|
||||
search_engine = st.selectbox(
|
||||
webui_configs["dialogue"]["selectbox_search_name"],
|
||||
SEARCH_ENGINES.keys(), 0)
|
||||
se_top_k = st.number_input(
|
||||
webui_configs["dialogue"]["ninput_search_topk_name"], 1, 20, 3)
|
||||
elif dialogue_mode == webui_configs["dialogue"]["mode"][4]:
|
||||
not_agent_qa = False
|
||||
with st.expander("Phase管理", True):
|
||||
with st.expander(webui_configs["dialogue"]["phase_expander_name"], True):
|
||||
choose_phase = st.selectbox(
|
||||
'请选择待使用的执行链路', PHASE_LIST, 0)
|
||||
webui_configs["dialogue"]["phase_selectbox_name"], PHASE_LIST, 0)
|
||||
|
||||
is_detailed = st.toggle("是否使用明细信息进行agent交互", False)
|
||||
tool_using_on = st.toggle("开启工具使用", PHASE_CONFIGS[choose_phase]["do_using_tool"])
|
||||
is_detailed = st.toggle(webui_configs["dialogue"]["phase_toggle_detailed_name"], False)
|
||||
tool_using_on = st.toggle(
|
||||
webui_configs["dialogue"]["phase_toggle_doToolUsing"],
|
||||
PHASE_CONFIGS[choose_phase]["do_using_tool"])
|
||||
tool_selects = []
|
||||
if tool_using_on:
|
||||
with st.expander("工具军火库", True):
|
||||
tool_selects = st.multiselect(
|
||||
'请选择待使用的工具', TOOL_SETS, ["WeatherInfo"])
|
||||
webui_configs["dialogue"]["phase_multiselect_tools"],
|
||||
TOOL_SETS, ["WeatherInfo"])
|
||||
|
||||
search_on = st.toggle("开启搜索增强", PHASE_CONFIGS[choose_phase]["do_search"])
|
||||
search_on = st.toggle(webui_configs["dialogue"]["phase_toggle_doSearch"],
|
||||
PHASE_CONFIGS[choose_phase]["do_search"])
|
||||
search_engine, top_k = None, 3
|
||||
if search_on:
|
||||
with st.expander("搜索引擎配置", True):
|
||||
search_engine = st.selectbox("请选择搜索引擎", SEARCH_ENGINES.keys(), 0)
|
||||
top_k = st.number_input("匹配搜索结果条数:", 1, 20, 3)
|
||||
with st.expander(webui_configs["dialogue"]["expander_search_name"], True):
|
||||
search_engine = st.selectbox(
|
||||
webui_configs["dialogue"]["selectbox_search_name"],
|
||||
SEARCH_ENGINES.keys(), 0)
|
||||
se_top_k = st.number_input(
|
||||
webui_configs["dialogue"]["ninput_search_topk_name"], 1, 20, 3)
|
||||
|
||||
doc_retrieval_on = st.toggle("开启知识库检索增强", PHASE_CONFIGS[choose_phase]["do_doc_retrieval"])
|
||||
doc_retrieval_on = st.toggle(
|
||||
webui_configs["dialogue"]["phase_toggle_doDocRetrieval"],
|
||||
PHASE_CONFIGS[choose_phase]["do_doc_retrieval"])
|
||||
selected_kb, top_k, score_threshold = None, 3, 1.0
|
||||
if doc_retrieval_on:
|
||||
with st.expander("知识库配置", True):
|
||||
with st.expander(webui_configs["dialogue"]["kbase_expander_name"], True):
|
||||
kb_list = api.list_knowledge_bases(no_remote_api=True)
|
||||
selected_kb = st.selectbox(
|
||||
"请选择知识库:",
|
||||
webui_configs["dialogue"]["kbase_selectbox_name"],
|
||||
kb_list,
|
||||
on_change=on_kb_change,
|
||||
key="selected_kb",
|
||||
)
|
||||
top_k = st.number_input("匹配知识条数:", 1, 20, 3)
|
||||
score_threshold = st.number_input("知识匹配分数阈值:", 0.0, float(SCORE_THRESHOLD), float(SCORE_THRESHOLD), float(SCORE_THRESHOLD//100))
|
||||
|
||||
code_retrieval_on = st.toggle("开启代码检索增强", PHASE_CONFIGS[choose_phase]["do_code_retrieval"])
|
||||
top_k = st.number_input(
|
||||
webui_configs["dialogue"]["kbase_ninput_topk_name"], 1, 20, 3)
|
||||
score_threshold = st.number_input(
|
||||
webui_configs["dialogue"]["kbase_ninput_score_threshold_name"],
|
||||
0.0, float(SCORE_THRESHOLD), float(SCORE_THRESHOLD),
|
||||
float(SCORE_THRESHOLD//100))
|
||||
|
||||
code_retrieval_on = st.toggle(
|
||||
webui_configs["dialogue"]["phase_toggle_doCodeRetrieval"],
|
||||
PHASE_CONFIGS[choose_phase]["do_code_retrieval"])
|
||||
selected_cb, top_k = None, 1
|
||||
cb_search_type = "tag"
|
||||
if code_retrieval_on:
|
||||
with st.expander('代码知识库配置', True):
|
||||
with st.expander(webui_configs["dialogue"]["cbase_expander_name"], True):
|
||||
cb_list = api.list_cb(no_remote_api=True)
|
||||
logger.debug('codebase_list={}'.format(cb_list))
|
||||
selected_cb = st.selectbox(
|
||||
"请选择代码知识库:",
|
||||
webui_configs["dialogue"]["cbase_selectbox_name"],
|
||||
cb_list,
|
||||
on_change=on_cb_change,
|
||||
key="selected_cb",
|
||||
)
|
||||
st.toast(f"已加载代码知识库: {st.session_state.selected_cb}")
|
||||
top_k = st.number_input("匹配代码条数:", 1, 20, 1)
|
||||
|
||||
# change do_interpret
|
||||
st.toast(f"{webui_configs['dialogue']['text_loaded_cbase']}: {st.session_state.selected_cb}")
|
||||
cb_details = get_cb_details_by_cb_name(st.session_state.selected_cb)
|
||||
st.session_state['do_interpret'] = cb_details['do_interpret']
|
||||
search_type_list = ['基于 cypher', '基于标签', '基于描述'] if st.session_state['do_interpret'] == 'YES' \
|
||||
else ['基于 cypher', '基于标签']
|
||||
|
||||
top_k = st.number_input(
|
||||
webui_configs["dialogue"]["cbase_ninput_topk_name"], 1, 20, 1)
|
||||
|
||||
search_type_list = webui_configs["dialogue"]["cbase_search_type_v1"] if st.session_state['do_interpret'] == 'YES' \
|
||||
else webui_configs["dialogue"]["cbase_search_type_v2"]
|
||||
|
||||
cb_search_type = st.selectbox(
|
||||
'请选择查询模式:',
|
||||
webui_configs["dialogue"]["cbase_selectbox_type_name"],
|
||||
search_type_list,
|
||||
key='cb_search_type'
|
||||
)
|
||||
|
||||
with st.expander("沙盒文件管理", False):
|
||||
with st.expander(webui_configs["sandbox"]["expander_name"], False):
|
||||
|
||||
interpreter_file = st.file_uploader(
|
||||
"上传沙盒文件",
|
||||
webui_configs["sandbox"]["file_upload_name"],
|
||||
[i for ls in LOADER2EXT_DICT.values() for i in ls] + ["jpg", "png"],
|
||||
accept_multiple_files=False,
|
||||
key=st.session_state.interpreter_file_key,
|
||||
|
@ -222,29 +255,31 @@ def dialogue_page(api: ApiRequest):
|
|||
|
||||
files = api.web_sd_list_files()
|
||||
files = files["data"]
|
||||
download_file = st.selectbox("选择要处理文件", files,
|
||||
download_file = st.selectbox(webui_configs["sandbox"]["selectbox_name"], files,
|
||||
key="download_file",)
|
||||
|
||||
cols = st.columns(3)
|
||||
file_url, file_name = api.web_sd_download(download_file)
|
||||
if cols[0].button("点击上传"):
|
||||
if cols[0].button(webui_configs["sandbox"]["button_upload_name"],):
|
||||
upload2sandbox(interpreter_file, api)
|
||||
st.session_state["interpreter_file_key"] += 1
|
||||
interpreter_file = ""
|
||||
st.experimental_rerun()
|
||||
|
||||
cols[1].download_button("点击下载", file_url, file_name)
|
||||
if cols[2].button("点击删除", ):
|
||||
cols[1].download_button(webui_configs["sandbox"]["button_download_name"],
|
||||
file_url, file_name)
|
||||
if cols[2].button(webui_configs["sandbox"]["button_delete_name"],):
|
||||
api.web_sd_delete(download_file)
|
||||
|
||||
code_interpreter_on = st.toggle("开启代码解释器") and not_agent_qa
|
||||
code_exec_on = st.toggle("自动执行代码") and not_agent_qa
|
||||
code_interpreter_on = st.toggle(
|
||||
webui_configs["sandbox"]["toggle_doCodeInterpreter"]) and not_agent_qa
|
||||
code_exec_on = st.toggle(webui_configs["sandbox"]["toggle_doAutoCodeExec"]) and not_agent_qa
|
||||
|
||||
# Display chat messages from history on app rerun
|
||||
|
||||
chat_box.output_messages()
|
||||
|
||||
chat_input_placeholder = "请输入对话内容,换行请使用Ctrl+Enter "
|
||||
chat_input_placeholder = webui_configs["chat"]["chat_placeholder"]
|
||||
code_text = "" or GLOBAL_EXE_CODE_TEXT
|
||||
codebox_res = None
|
||||
|
||||
|
@ -254,8 +289,8 @@ def dialogue_page(api: ApiRequest):
|
|||
|
||||
history = get_messages_history(history_len, is_detailed)
|
||||
chat_box.user_say(prompt)
|
||||
if dialogue_mode == "LLM 对话":
|
||||
chat_box.ai_say("正在思考...")
|
||||
if dialogue_mode == webui_configs["dialogue"]["mode"][0]:
|
||||
chat_box.ai_say(webui_configs["chat"]["chatbox_saying"])
|
||||
text = ""
|
||||
r = api.chat_chat(prompt, history, no_remote_api=True)
|
||||
for t in r:
|
||||
|
@ -277,12 +312,18 @@ def dialogue_page(api: ApiRequest):
|
|||
GLOBAL_EXE_CODE_TEXT = code_text
|
||||
if code_text and code_exec_on:
|
||||
codebox_res = api.codebox_chat("```"+code_text+"```", do_code_exe=True)
|
||||
elif dialogue_mode == "Agent问答":
|
||||
display_infos = [f"正在思考..."]
|
||||
elif dialogue_mode == webui_configs["dialogue"]["mode"][4]:
|
||||
display_infos = [webui_configs["chat"]["chatbox_saying"]]
|
||||
if search_on:
|
||||
display_infos.append(Markdown("...", in_expander=True, title="网络搜索结果"))
|
||||
display_infos.append(Markdown("...", in_expander=True,
|
||||
title=webui_configs["chat"]["chatbox_search_result"]))
|
||||
if doc_retrieval_on:
|
||||
display_infos.append(Markdown("...", in_expander=True, title="知识库匹配结果"))
|
||||
display_infos.append(Markdown("...", in_expander=True,
|
||||
title=webui_configs["chat"]["chatbox_doc_result"]))
|
||||
if code_retrieval_on:
|
||||
display_infos.append(Markdown("...", in_expander=True,
|
||||
title=webui_configs["chat"]["chatbox_code_result"]))
|
||||
|
||||
chat_box.ai_say(display_infos)
|
||||
|
||||
if 'history_node_list' in st.session_state:
|
||||
|
@ -331,18 +372,21 @@ def dialogue_page(api: ApiRequest):
|
|||
|
||||
chat_box.update_msg(text, element_index=0, streaming=False, state="complete") # 更新最终的字符串,去除光标
|
||||
if search_on:
|
||||
chat_box.update_msg("搜索匹配结果:\n\n" + "\n\n".join(d["search_docs"]), element_index=search_on, streaming=False, state="complete")
|
||||
chat_box.update_msg(f"{webui_configs['chat']['chatbox_search_result']}:\n\n" + "\n\n".join(d["search_docs"]), element_index=search_on, streaming=False, state="complete")
|
||||
if doc_retrieval_on:
|
||||
chat_box.update_msg("知识库匹配结果:\n\n" + "\n\n".join(d["db_docs"]), element_index=search_on+doc_retrieval_on, streaming=False, state="complete")
|
||||
chat_box.update_msg(f"{webui_configs['chat']['chatbox_doc_result']}:\n\n" + "\n\n".join(d["db_docs"]), element_index=search_on+doc_retrieval_on, streaming=False, state="complete")
|
||||
if code_retrieval_on:
|
||||
chat_box.update_msg(f"{webui_configs['chat']['chatbox_code_result']}:\n\n" + "\n\n".join(d["code_docs"]),
|
||||
element_index=search_on+doc_retrieval_on+code_retrieval_on, streaming=False, state="complete")
|
||||
|
||||
history_node_list.extend([node[0] for node in d.get("related_nodes", [])])
|
||||
history_node_list = list(set(history_node_list))
|
||||
st.session_state['history_node_list'] = history_node_list
|
||||
elif dialogue_mode == "知识库问答":
|
||||
elif dialogue_mode == webui_configs["dialogue"]["mode"][1]:
|
||||
history = get_messages_history(history_len)
|
||||
chat_box.ai_say([
|
||||
f"正在查询知识库 `{selected_kb}` ...",
|
||||
Markdown("...", in_expander=True, title="知识库匹配结果"),
|
||||
f"{webui_configs['chat']['chatbox_doc_querying']} `{selected_kb}` ...",
|
||||
Markdown("...", in_expander=True, title=webui_configs['chat']['chatbox_doc_result']),
|
||||
])
|
||||
text = ""
|
||||
d = {"docs": []}
|
||||
|
@ -354,21 +398,21 @@ def dialogue_page(api: ApiRequest):
|
|||
chat_box.update_msg(text, element_index=0)
|
||||
# chat_box.update_msg("知识库匹配结果: \n\n".join(d["docs"]), element_index=1, streaming=False, state="complete")
|
||||
chat_box.update_msg(text, element_index=0, streaming=False) # 更新最终的字符串,去除光标
|
||||
chat_box.update_msg("知识库匹配结果: \n\n".join(d["docs"]), element_index=1, streaming=False, state="complete")
|
||||
chat_box.update_msg("{webui_configs['chat']['chatbox_doc_result']}: \n\n".join(d["docs"]), element_index=1, streaming=False, state="complete")
|
||||
# 判断是否存在代码, 并提高编辑功能,执行功能
|
||||
code_text = api.codebox.decode_code_from_text(text)
|
||||
GLOBAL_EXE_CODE_TEXT = code_text
|
||||
if code_text and code_exec_on:
|
||||
codebox_res = api.codebox_chat("```"+code_text+"```", do_code_exe=True)
|
||||
elif dialogue_mode == '代码知识库问答':
|
||||
elif dialogue_mode == webui_configs["dialogue"]["mode"][2]:
|
||||
logger.info('prompt={}'.format(prompt))
|
||||
logger.info('history={}'.format(history))
|
||||
if 'history_node_list' in st.session_state:
|
||||
api.codeChat.history_node_list = st.session_state['history_node_list']
|
||||
|
||||
chat_box.ai_say([
|
||||
f"正在查询代码知识库 `{selected_cb}` ...",
|
||||
Markdown("...", in_expander=True, title="代码库匹配节点"),
|
||||
f"{webui_configs['chat']['chatbox_code_querying']} `{selected_cb}` ...",
|
||||
Markdown("...", in_expander=True, title=webui_configs['chat']['chatbox_code_result']),
|
||||
])
|
||||
text = ""
|
||||
d = {"codes": []}
|
||||
|
@ -393,14 +437,14 @@ def dialogue_page(api: ApiRequest):
|
|||
# session state update
|
||||
# st.session_state['history_node_list'] = api.codeChat.history_node_list
|
||||
|
||||
elif dialogue_mode == "搜索引擎问答":
|
||||
elif dialogue_mode == webui_configs["dialogue"]["mode"][3]:
|
||||
chat_box.ai_say([
|
||||
f"正在执行 `{search_engine}` 搜索...",
|
||||
Markdown("...", in_expander=True, title="网络搜索结果"),
|
||||
webui_configs['chat']['chatbox_searching'],
|
||||
Markdown("...", in_expander=True, title=webui_configs['chat']['chatbox_search_result']),
|
||||
])
|
||||
text = ""
|
||||
d = {"docs": []}
|
||||
for idx_count, d in enumerate(api.search_engine_chat(prompt, search_engine, se_top_k)):
|
||||
for idx_count, d in enumerate(api.search_engine_chat(prompt, search_engine, se_top_k, history)):
|
||||
if error_msg := check_error_msg(d): # check whether error occured
|
||||
st.error(error_msg)
|
||||
text += d["answer"]
|
||||
|
@ -408,7 +452,7 @@ def dialogue_page(api: ApiRequest):
|
|||
chat_box.update_msg(text, element_index=0)
|
||||
# chat_box.update_msg("搜索匹配结果: \n\n".join(d["docs"]), element_index=1, streaming=False)
|
||||
chat_box.update_msg(text, element_index=0, streaming=False) # 更新最终的字符串,去除光标
|
||||
chat_box.update_msg("搜索匹配结果: \n\n".join(d["docs"]), element_index=1, streaming=False, state="complete")
|
||||
chat_box.update_msg(f"{webui_configs['chat']['chatbox_search_result']}: \n\n".join(d["docs"]), element_index=1, streaming=False, state="complete")
|
||||
# 判断是否存在代码, 并提高编辑功能,执行功能
|
||||
code_text = api.codebox.decode_code_from_text(text)
|
||||
GLOBAL_EXE_CODE_TEXT = code_text
|
||||
|
@ -420,33 +464,34 @@ def dialogue_page(api: ApiRequest):
|
|||
st.experimental_rerun()
|
||||
|
||||
if code_interpreter_on:
|
||||
with st.expander("代码编辑执行器", False):
|
||||
code_part = st.text_area("代码片段", code_text, key="code_text")
|
||||
with st.expander(webui_configs['sandbox']['expander_code_name'], False):
|
||||
code_part = st.text_area(
|
||||
webui_configs['sandbox']['textArea_code_name'], code_text, key="code_text")
|
||||
cols = st.columns(2)
|
||||
if cols[0].button(
|
||||
"修改对话",
|
||||
webui_configs['sandbox']['button_modify_code_name'],
|
||||
use_container_width=True,
|
||||
):
|
||||
code_text = code_part
|
||||
GLOBAL_EXE_CODE_TEXT = code_text
|
||||
st.toast("修改对话成功")
|
||||
st.toast(webui_configs['sandbox']['text_modify_code'])
|
||||
|
||||
if cols[1].button(
|
||||
"执行代码",
|
||||
webui_configs['sandbox']['button_exec_code_name'],
|
||||
use_container_width=True
|
||||
):
|
||||
if code_text:
|
||||
codebox_res = api.codebox_chat("```"+code_text+"```", do_code_exe=True)
|
||||
st.toast("正在执行代码")
|
||||
st.toast(webui_configs['sandbox']['text_execing_code'],)
|
||||
else:
|
||||
st.toast("code 不能为空")
|
||||
st.toast(webui_configs['sandbox']['text_error_exec_code'],)
|
||||
|
||||
#TODO 这段信息会被记录到history里
|
||||
if codebox_res is not None and codebox_res.code_exe_status != 200:
|
||||
st.toast(f"{codebox_res.code_exe_response}")
|
||||
|
||||
if codebox_res is not None and codebox_res.code_exe_status == 200:
|
||||
st.toast(f"codebox_chajt {codebox_res}")
|
||||
st.toast(f"codebox_chat {codebox_res}")
|
||||
chat_box.ai_say(Markdown(code_text, in_expander=True, title="code interpreter", unsafe_allow_html=True), )
|
||||
if codebox_res.code_exe_type == "image/png":
|
||||
base_text = f"```\n{code_text}\n```\n\n"
|
||||
|
@ -464,7 +509,7 @@ def dialogue_page(api: ApiRequest):
|
|||
cols = st.columns(2)
|
||||
export_btn = cols[0]
|
||||
if cols[1].button(
|
||||
"清空对话",
|
||||
webui_configs['export']['button_clear_conversation_name'],
|
||||
use_container_width=True,
|
||||
):
|
||||
chat_box.reset_history()
|
||||
|
@ -474,9 +519,9 @@ def dialogue_page(api: ApiRequest):
|
|||
st.experimental_rerun()
|
||||
|
||||
export_btn.download_button(
|
||||
"导出记录",
|
||||
webui_configs['export']['download_button_export_name'],
|
||||
"".join(chat_box.export2md()),
|
||||
file_name=f"{now:%Y-%m-%d %H.%M}_对话记录.md",
|
||||
file_name=f"{now:%Y-%m-%d %H.%M}_conversations.md",
|
||||
mime="text/markdown",
|
||||
use_container_width=True,
|
||||
)
|
||||
|
|
|
@ -355,7 +355,8 @@ class ApiRequest:
|
|||
self,
|
||||
query: str,
|
||||
search_engine_name: str,
|
||||
code_limit: int,
|
||||
top_k: int,
|
||||
history: List[Dict] = [],
|
||||
stream: bool = True,
|
||||
no_remote_api: bool = None,
|
||||
):
|
||||
|
@ -368,8 +369,8 @@ class ApiRequest:
|
|||
data = {
|
||||
"query": query,
|
||||
"engine_name": search_engine_name,
|
||||
"code_limit": code_limit,
|
||||
"history": [],
|
||||
"top_k": top_k,
|
||||
"history": history,
|
||||
"stream": stream,
|
||||
}
|
||||
|
||||
|
|
|
@ -0,0 +1,78 @@
|
|||
# This is an example of webui
|
||||
dialogue:
|
||||
mode_instruction: 请选择对话模式
|
||||
mode:
|
||||
- LLM Conversation
|
||||
- Knowledge Base Q&A
|
||||
- Code Knowledge Base Q&A
|
||||
- Search Engine Q&A
|
||||
- Agents Q&A
|
||||
history_length: History of Dialogue Turns
|
||||
text_mode_swtich: Switched to mode
|
||||
text_knowledgeBase_swtich: Current Knowledge Base"
|
||||
text_loaded_kbase: Loaded Knowledge Base
|
||||
text_loaded_cbase: Loaded Code Knowledge Base
|
||||
# Knowledge Base Q&A
|
||||
kbase_expander_name: 知识库配置
|
||||
kbase_selectbox_name: 请选择知识库:
|
||||
kbase_ninput_topk_name: 匹配知识条数:
|
||||
kbase_ninput_score_threshold_name: 知识匹配分数阈值:
|
||||
# Code Knowledge Base Q&A
|
||||
cbase_expander_name: 代码知识库配置
|
||||
cbase_selectbox_name: 请选择代码知识库:
|
||||
cbase_ninput_topk_name: 匹配代码条数:
|
||||
cbase_selectbox_type_name: 请选择查询模式:
|
||||
cbase_search_type_v1:
|
||||
- 基于 cypher
|
||||
- 基于标签
|
||||
- 基于描述
|
||||
cbase_search_type_v2:
|
||||
- 基于 cypher
|
||||
- 基于标签
|
||||
|
||||
# Search Engine Q&A
|
||||
expander_search_name: 搜索引擎配置
|
||||
selectbox_search_name: 请选择搜索引擎
|
||||
ninput_search_topk_name: 匹配搜索结果条数:
|
||||
# Agents Q&A
|
||||
phase_expander_name: Phase管理
|
||||
phase_selectbox_name: 请选择待使用的执行链路
|
||||
phase_toggle_detailed_name: 是否使用明细信息进行agent交互
|
||||
phase_toggle_doToolUsing: 开启工具使用
|
||||
phase_multiselect_tools: 请选择待使用的工具
|
||||
phase_toggle_doSearch: 开启搜索增强
|
||||
phase_toggle_doDocRetrieval: 开启知识库检索增强
|
||||
phase_toggle_doCodeRetrieval: 开启代码检索增强
|
||||
|
||||
sandbox:
|
||||
expander_name: 沙盒文件管理
|
||||
file_upload_name: 上传沙盒文件
|
||||
selectbox_name: 选择要处理文件
|
||||
button_upload_name: 点击上传
|
||||
button_download_name: 点击下载
|
||||
button_delete_name: 点击删除
|
||||
toggle_doCodeInterpreter: 开启代码解释器
|
||||
toggle_doAutoCodeExec: 自动执行代码
|
||||
|
||||
expander_code_name: 代码编辑执行器
|
||||
textArea_code_name: 代码片段
|
||||
button_modify_code_name: 修改对话
|
||||
text_modify_code: 修改对话成功
|
||||
button_exec_code_name: 执行代码
|
||||
text_execing_code: 正在执行代码
|
||||
text_error_exec_code: code 不能为空
|
||||
|
||||
|
||||
chat:
|
||||
chat_placeholder: 请输入对话内容,换行请使用Ctrl+Enter
|
||||
chatbox_saying: 正在思考...
|
||||
chatbox_doc_querying: 正在查询知识库
|
||||
chatbox_code_querying: 正在查询代码知识库
|
||||
chatbox_searching: 正在执行搜索
|
||||
chatbox_search_result: 网络搜索结果
|
||||
chatbox_doc_result: 知识库匹配结果
|
||||
chatbox_code_result: 代码库匹配节点
|
||||
|
||||
export:
|
||||
button_clear_conversation_name: 清空对话
|
||||
download_button_export_name: 导出记录
|
|
@ -0,0 +1,78 @@
|
|||
# This is an example of webui
|
||||
dialogue:
|
||||
mode_instruction: 请选择对话模式
|
||||
mode:
|
||||
- LLM 对话
|
||||
- 知识库问答
|
||||
- 代码知识库问答
|
||||
- 搜索引擎问答
|
||||
- Agent问答
|
||||
history_length: 历史对话轮数
|
||||
text_mode_swtich: 已切换到模式
|
||||
text_knowledgeBase_swtich: 当前知识库
|
||||
text_loaded_kbase: 已加载知识库
|
||||
text_loaded_cbase: 已加载代码知识库
|
||||
# Knowledge Base Q&A
|
||||
kbase_expander_name: 知识库配置
|
||||
kbase_selectbox_name: 请选择知识库:
|
||||
kbase_ninput_topk_name: 匹配知识条数:
|
||||
kbase_ninput_score_threshold_name: 知识匹配分数阈值:
|
||||
# Code Knowledge Base Q&A
|
||||
cbase_expander_name: 代码知识库配置
|
||||
cbase_selectbox_name: 请选择代码知识库:
|
||||
cbase_ninput_topk_name: 匹配代码条数:
|
||||
cbase_selectbox_type_name: 请选择查询模式:
|
||||
cbase_search_type_v1:
|
||||
- 基于 cypher
|
||||
- 基于标签
|
||||
- 基于描述
|
||||
cbase_search_type_v2:
|
||||
- 基于 cypher
|
||||
- 基于标签
|
||||
|
||||
# Search Engine Q&A
|
||||
expander_search_name: 搜索引擎配置
|
||||
selectbox_search_name: 请选择搜索引擎
|
||||
ninput_search_topk_name: 匹配搜索结果条数:
|
||||
# Agents Q&A
|
||||
phase_expander_name: Phase管理
|
||||
phase_selectbox_name: 请选择待使用的执行链路
|
||||
phase_toggle_detailed_name: 是否使用明细信息进行agent交互
|
||||
phase_toggle_doToolUsing: 开启工具使用
|
||||
phase_multiselect_tools: 请选择待使用的工具
|
||||
phase_toggle_doSearch: 开启搜索增强
|
||||
phase_toggle_doDocRetrieval: 开启知识库检索增强
|
||||
phase_toggle_doCodeRetrieval: 开启代码检索增强
|
||||
|
||||
sandbox:
|
||||
expander_name: 沙盒文件管理
|
||||
file_upload_name: 上传沙盒文件
|
||||
selectbox_name: 选择要处理文件
|
||||
button_upload_name: 点击上传
|
||||
button_download_name: 点击下载
|
||||
button_delete_name: 点击删除
|
||||
toggle_doCodeInterpreter: 开启代码解释器
|
||||
toggle_doAutoCodeExec: 自动执行代码
|
||||
|
||||
expander_code_name: 代码编辑执行器
|
||||
textArea_code_name: 代码片段
|
||||
button_modify_code_name: 修改对话
|
||||
text_modify_code: 修改对话成功
|
||||
button_exec_code_name: 执行代码
|
||||
text_execing_code: 正在执行代码
|
||||
text_error_exec_code: code 不能为空
|
||||
|
||||
|
||||
chat:
|
||||
chat_placeholder: 请输入对话内容,换行请使用Ctrl+Enter
|
||||
chatbox_saying: 正在思考...
|
||||
chatbox_doc_querying: 正在查询知识库
|
||||
chatbox_code_querying: 正在查询代码知识库
|
||||
chatbox_searching: 正在执行搜索
|
||||
chatbox_search_result: 网络搜索结果
|
||||
chatbox_doc_result: 知识库匹配结果
|
||||
chatbox_code_result: 代码库匹配节点
|
||||
|
||||
export:
|
||||
button_clear_conversation_name: 清空对话
|
||||
download_button_export_name: 导出记录
|
|
@ -0,0 +1,53 @@
|
|||
import os, sys, requests
|
||||
|
||||
src_dir = os.path.join(
|
||||
os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
)
|
||||
sys.path.append(src_dir)
|
||||
|
||||
from dev_opsgpt.tools import (
|
||||
toLangchainTools, get_tool_schema, DDGSTool, DocRetrieval,
|
||||
TOOL_DICT, TOOL_SETS
|
||||
)
|
||||
|
||||
from configs.model_config import *
|
||||
from dev_opsgpt.connector.phase import BasePhase
|
||||
from dev_opsgpt.connector.agents import BaseAgent
|
||||
from dev_opsgpt.connector.chains import BaseChain
|
||||
from dev_opsgpt.connector.schema import (
|
||||
Message, Memory, load_role_configs, load_phase_configs, load_chain_configs
|
||||
)
|
||||
from dev_opsgpt.connector.configs import AGETN_CONFIGS, CHAIN_CONFIGS, PHASE_CONFIGS
|
||||
import importlib
|
||||
|
||||
tools = toLangchainTools([TOOL_DICT[i] for i in TOOL_SETS if i in TOOL_DICT])
|
||||
|
||||
|
||||
role_configs = load_role_configs(AGETN_CONFIGS)
|
||||
chain_configs = load_chain_configs(CHAIN_CONFIGS)
|
||||
phase_configs = load_phase_configs(PHASE_CONFIGS)
|
||||
|
||||
agent_module = importlib.import_module("dev_opsgpt.connector.agents")
|
||||
|
||||
|
||||
phase_name = "baseGroupPhase"
|
||||
phase = BasePhase(phase_name,
|
||||
task = None,
|
||||
phase_config = PHASE_CONFIGS,
|
||||
chain_config = CHAIN_CONFIGS,
|
||||
role_config = AGETN_CONFIGS,
|
||||
do_summary=False,
|
||||
do_code_retrieval=False,
|
||||
do_doc_retrieval=True,
|
||||
do_search=False,
|
||||
)
|
||||
|
||||
# round-1
|
||||
query_content = "确认本地是否存在employee_data.csv,并查看它有哪些列和数据类型;然后画柱状图"
|
||||
# query_content = "帮我确认下127.0.0.1这个服务器的在10点是否存在异常,请帮我判断一下"
|
||||
query = Message(
|
||||
role_name="human", role_type="user", tools=tools,
|
||||
role_content=query_content, input_query=query_content, origin_query=query_content,
|
||||
)
|
||||
|
||||
output_message, _ = phase.step(query)
|
|
@ -43,7 +43,7 @@ phase = BasePhase(phase_name,
|
|||
)
|
||||
|
||||
# round-1
|
||||
query_content = "确认本地是否存在employee_data.csv,并查看它有哪些列和数据类型;然后画柱状图"
|
||||
query_content = "确认本地是否存在book_data.csv,并查看它有哪些列和数据类型;然后画柱状图"
|
||||
query = Message(
|
||||
role_name="human", role_type="user",
|
||||
role_content=query_content, input_query=query_content, origin_query=query_content,
|
||||
|
|
|
@ -0,0 +1,236 @@
|
|||
import os, sys, requests
|
||||
|
||||
src_dir = os.path.join(
|
||||
os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
)
|
||||
sys.path.append(src_dir)
|
||||
|
||||
from configs.model_config import *
|
||||
from dev_opsgpt.connector.phase import BasePhase
|
||||
from dev_opsgpt.connector.agents import BaseAgent
|
||||
from dev_opsgpt.connector.chains import BaseChain
|
||||
from dev_opsgpt.connector.schema import (
|
||||
Message, Memory, load_role_configs, load_phase_configs, load_chain_configs
|
||||
)
|
||||
from dev_opsgpt.connector.configs import AGETN_CONFIGS, CHAIN_CONFIGS, PHASE_CONFIGS
|
||||
from dev_opsgpt.connector.utils import parse_section
|
||||
import importlib
|
||||
|
||||
|
||||
|
||||
# update new agent configs
|
||||
codeRetrievalJudger_PROMPT = """#### CodeRetrievalJudger Assistance Guidance
|
||||
|
||||
Given the user's question and respective code, you need to decide whether the provided codes are enough to answer the question.
|
||||
|
||||
#### Input Format
|
||||
|
||||
**Origin Query:** the initial question or objective that the user wanted to achieve
|
||||
|
||||
**Retrieval Codes:** the Retrieval Codes from the code base
|
||||
|
||||
#### Response Output Format
|
||||
|
||||
**REASON:** Justify the decision of choosing 'finished' and 'continued' by evaluating the progress step by step.
|
||||
"""
|
||||
|
||||
# 将下面的话放到上面的prompt里面去执行,让它判断是否停止
|
||||
# **Action Status:** Set to 'finished' or 'continued'.
|
||||
# If it's 'finished', the provided codes can answer the origin query.
|
||||
# If it's 'continued', the origin query cannot be answered well from the provided code.
|
||||
|
||||
codeRetrievalDivergent_PROMPT = """#### CodeRetrievalDivergen Assistance Guidance
|
||||
|
||||
You are a assistant that helps to determine which code package is needed to answer the question.
|
||||
|
||||
Given the user's question, Retrieval code, and the code Packages related to Retrieval code. you need to decide which code package we need to read to better answer the question.
|
||||
|
||||
#### Input Format
|
||||
|
||||
**Origin Query:** the initial question or objective that the user wanted to achieve
|
||||
|
||||
**Retrieval Codes:** the Retrieval Codes from the code base
|
||||
|
||||
**Code Packages:** the code packages related to Retrieval code
|
||||
|
||||
#### Response Output Format
|
||||
|
||||
**Code Package:** Identify another Code Package from the Code Packages that should be read to provide a better answer to the Origin Query.
|
||||
|
||||
**REASON:** Justify the decision of choosing 'finished' and 'continued' by evaluating the progress step by step.
|
||||
"""
|
||||
|
||||
AGETN_CONFIGS.update({
|
||||
"codeRetrievalJudger": {
|
||||
"role": {
|
||||
"role_prompt": codeRetrievalJudger_PROMPT,
|
||||
"role_type": "assistant",
|
||||
"role_name": "codeRetrievalJudger",
|
||||
"role_desc": "",
|
||||
"agent_type": "CodeRetrievalJudger"
|
||||
# "agent_type": "BaseAgent"
|
||||
},
|
||||
"chat_turn": 1,
|
||||
"focus_agents": [],
|
||||
"focus_message_keys": [],
|
||||
},
|
||||
"codeRetrievalDivergent": {
|
||||
"role": {
|
||||
"role_prompt": codeRetrievalDivergent_PROMPT,
|
||||
"role_type": "assistant",
|
||||
"role_name": "codeRetrievalDivergent",
|
||||
"role_desc": "",
|
||||
"agent_type": "CodeRetrievalDivergent"
|
||||
# "agent_type": "BaseAgent"
|
||||
},
|
||||
"chat_turn": 1,
|
||||
"focus_agents": [],
|
||||
"focus_message_keys": [],
|
||||
},
|
||||
})
|
||||
# update new chain configs
|
||||
CHAIN_CONFIGS.update({
|
||||
"codeRetrievalChain": {
|
||||
"chain_name": "codeRetrievalChain",
|
||||
"chain_type": "BaseChain",
|
||||
"agents": ["codeRetrievalJudger", "codeRetrievalDivergent"],
|
||||
"chat_turn": 5,
|
||||
"do_checker": False,
|
||||
"chain_prompt": ""
|
||||
}
|
||||
})
|
||||
|
||||
# update phase configs
|
||||
PHASE_CONFIGS.update({
|
||||
"codeRetrievalPhase": {
|
||||
"phase_name": "codeRetrievalPhase",
|
||||
"phase_type": "BasePhase",
|
||||
"chains": ["codeRetrievalChain"],
|
||||
"do_summary": False,
|
||||
"do_search": False,
|
||||
"do_doc_retrieval": False,
|
||||
"do_code_retrieval": True,
|
||||
"do_tool_retrieval": False,
|
||||
"do_using_tool": False
|
||||
},
|
||||
})
|
||||
|
||||
|
||||
|
||||
|
||||
role_configs = load_role_configs(AGETN_CONFIGS)
|
||||
chain_configs = load_chain_configs(CHAIN_CONFIGS)
|
||||
phase_configs = load_phase_configs(PHASE_CONFIGS)
|
||||
|
||||
agent_module = importlib.import_module("dev_opsgpt.connector.agents")
|
||||
|
||||
from dev_opsgpt.tools import CodeRetrievalSingle, RelatedVerticesRetrival, Vertex2Code
|
||||
|
||||
# 定义一个新的类
|
||||
class CodeRetrievalJudger(BaseAgent):
|
||||
|
||||
def start_action_step(self, message: Message) -> Message:
|
||||
'''do action before agent predict '''
|
||||
action_json = CodeRetrievalSingle.run(message.code_engine_name, message.origin_query)
|
||||
message.customed_kargs["CodeRetrievalSingleRes"] = action_json
|
||||
message.customed_kargs.setdefault("Retrieval_Codes", "")
|
||||
message.customed_kargs["Retrieval_Codes"] += "\n" + action_json["code"]
|
||||
return message
|
||||
|
||||
def create_prompt(
|
||||
self, query: Message, memory: Memory =None, history: Memory = None, background: Memory = None, memory_pool: Memory=None, prompt_mamnger=None) -> str:
|
||||
'''
|
||||
prompt engineer, contains role\task\tools\docs\memory
|
||||
'''
|
||||
#
|
||||
logger.debug(f"query: {query.customed_kargs}")
|
||||
formatted_tools, tool_names, _ = self.create_tools_prompt(query)
|
||||
prompt = self.role.role_prompt.format(**{"formatted_tools": formatted_tools, "tool_names": tool_names})
|
||||
#
|
||||
input_keys = parse_section(self.role.role_prompt, 'Input Format')
|
||||
prompt += "\n#### Begin!!!\n"
|
||||
#
|
||||
for input_key in input_keys:
|
||||
if input_key == "Origin Query":
|
||||
prompt += f"\n**{input_key}:**\n" + query.origin_query
|
||||
elif input_key == "Retrieval Codes":
|
||||
prompt += f"\n**{input_key}:**\n" + query.customed_kargs["Retrieval_Codes"]
|
||||
|
||||
while "{{" in prompt or "}}" in prompt:
|
||||
prompt = prompt.replace("{{", "{")
|
||||
prompt = prompt.replace("}}", "}")
|
||||
return prompt
|
||||
|
||||
# 定义一个新的类
|
||||
class CodeRetrievalDivergent(BaseAgent):
|
||||
|
||||
def start_action_step(self, message: Message) -> Message:
|
||||
'''do action before agent predict '''
|
||||
action_json = RelatedVerticesRetrival.run(message.code_engine_name, message.customed_kargs["CodeRetrievalSingleRes"]["vertex"])
|
||||
message.customed_kargs["RelatedVerticesRetrivalRes"] = action_json
|
||||
return message
|
||||
|
||||
def end_action_step(self, message: Message) -> Message:
|
||||
'''do action before agent predict '''
|
||||
# logger.error(f"message: {message}")
|
||||
# action_json = Vertex2Code.run(message.code_engine_name, "com.theokanning.openai.client#Utils.java") # message.parsed_output["Code_Filename"])
|
||||
action_json = Vertex2Code.run(message.code_engine_name, message.parsed_output["Code Package"])
|
||||
message.customed_kargs["Vertex2Code"] = action_json
|
||||
message.customed_kargs.setdefault("Retrieval_Codes", "")
|
||||
message.customed_kargs["Retrieval_Codes"] += "\n" + action_json["code"]
|
||||
return message
|
||||
|
||||
def create_prompt(
|
||||
self, query: Message, memory: Memory =None, history: Memory = None, background: Memory = None, memory_pool: Memory=None, prompt_mamnger=None) -> str:
|
||||
'''
|
||||
prompt engineer, contains role\task\tools\docs\memory
|
||||
'''
|
||||
formatted_tools, tool_names, _ = self.create_tools_prompt(query)
|
||||
prompt = self.role.role_prompt.format(**{"formatted_tools": formatted_tools, "tool_names": tool_names})
|
||||
#
|
||||
input_query = query.input_query
|
||||
input_keys = parse_section(self.role.role_prompt, 'Input Format')
|
||||
prompt += "\n#### Begin!!!\n"
|
||||
#
|
||||
for input_key in input_keys:
|
||||
if input_key == "Origin Query":
|
||||
prompt += f"\n**{input_key}:**\n" + query.origin_query
|
||||
elif input_key == "Retrieval Codes":
|
||||
prompt += f"\n**{input_key}:**\n" + query.customed_kargs["Retrieval_Codes"]
|
||||
elif input_key == "Code Packages":
|
||||
vertices = query.customed_kargs["RelatedVerticesRetrivalRes"]["vertices"]
|
||||
prompt += f"\n**{input_key}:**\n" + ", ".join([str(v) for v in vertices])
|
||||
|
||||
while "{{" in prompt or "}}" in prompt:
|
||||
prompt = prompt.replace("{{", "{")
|
||||
prompt = prompt.replace("}}", "}")
|
||||
return prompt
|
||||
|
||||
|
||||
setattr(agent_module, 'CodeRetrievalJudger', CodeRetrievalJudger)
|
||||
setattr(agent_module, 'CodeRetrievalDivergent', CodeRetrievalDivergent)
|
||||
|
||||
|
||||
#
|
||||
phase_name = "codeRetrievalPhase"
|
||||
phase = BasePhase(phase_name,
|
||||
task = None,
|
||||
phase_config = PHASE_CONFIGS,
|
||||
chain_config = CHAIN_CONFIGS,
|
||||
role_config = AGETN_CONFIGS,
|
||||
do_summary=False,
|
||||
do_code_retrieval=False,
|
||||
do_doc_retrieval=False,
|
||||
do_search=False,
|
||||
)
|
||||
|
||||
# round-1
|
||||
query_content = "remove 这个函数是用来做什么的"
|
||||
query = Message(
|
||||
role_name="user", role_type="human",
|
||||
role_content=query_content, input_query=query_content, origin_query=query_content,
|
||||
code_engine_name="client", score_threshold=1.0, top_k=3, cb_search_type="cypher"
|
||||
)
|
||||
|
||||
output_message1, _ = phase.step(query)
|
||||
|
|
@ -28,7 +28,7 @@ print(src_dir)
|
|||
# tools = toLangchainTools([TOOL_DICT[i] for i in TOOL_SETS if i in TOOL_DICT])
|
||||
|
||||
TOOL_SETS = [
|
||||
"StockInfo", "StockName"
|
||||
"StockName", "StockInfo",
|
||||
]
|
||||
|
||||
tools = toLangchainTools([TOOL_DICT[i] for i in TOOL_SETS if i in TOOL_DICT])
|
||||
|
@ -52,7 +52,7 @@ phase = BasePhase(phase_name,
|
|||
do_search=False,
|
||||
)
|
||||
|
||||
query_content = "查询贵州茅台的股票代码,并查询截止到当前日期(2023年11月8日)的最近10天的每日时序数据,然后对时序数据画出折线图并分析"
|
||||
query_content = "查询贵州茅台的股票代码,并查询截止到当前日期(2023年12月24日)的最近10天的每日时序数据,然后对时序数据画出折线图并分析"
|
||||
|
||||
query = Message(role_name="human", role_type="user", input_query=query_content, role_content=query_content, origin_query=query_content, tools=tools)
|
||||
|
||||
|
|
|
@ -27,47 +27,25 @@ class GptqConfig:
|
|||
|
||||
|
||||
def load_quant_by_autogptq(model):
|
||||
# qwen-72b-int4 use these code
|
||||
from modelscope import AutoTokenizer, AutoModelForCausalLM
|
||||
# Note: The default behavior now has injection attack prevention off.
|
||||
tokenizer = AutoTokenizer.from_pretrained(model, revision='master', trust_remote_code=True)
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model, device_map="auto",
|
||||
trust_remote_code=True
|
||||
).eval()
|
||||
return model, tokenizer
|
||||
# codellama-34b-int4 use these code
|
||||
# from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
|
||||
# tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False, trust_remote_code=True)
|
||||
# model = AutoGPTQForCausalLM.from_quantized(model, inject_fused_attention=False,trust_remote_code=True,
|
||||
# inject_fused_mlp=False,use_cuda_fp16=True,disable_exllama=False,device_map='auto')
|
||||
# return model, tokenizer
|
||||
|
||||
from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
|
||||
model = AutoGPTQForCausalLM.from_quantized(model,
|
||||
inject_fused_attention=False,
|
||||
inject_fused_mlp=False,
|
||||
use_cuda_fp16=True,
|
||||
disable_exllama=False,
|
||||
device_map='auto'
|
||||
)
|
||||
return model
|
||||
|
||||
def load_gptq_quantized(model_name, gptq_config: GptqConfig):
|
||||
print("Loading GPTQ quantized model...")
|
||||
|
||||
if gptq_config.act_order:
|
||||
try:
|
||||
script_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
module_path = os.path.join(script_path, "../repositories/GPTQ-for-LLaMa")
|
||||
|
||||
sys.path.insert(0, module_path)
|
||||
from llama import load_quant
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
|
||||
# only `fastest-inference-4bit` branch cares about `act_order`
|
||||
model = load_quant(
|
||||
model_name,
|
||||
find_gptq_ckpt(gptq_config),
|
||||
gptq_config.wbits,
|
||||
gptq_config.groupsize,
|
||||
act_order=gptq_config.act_order,
|
||||
)
|
||||
except ImportError as e:
|
||||
print(f"Error: Failed to load GPTQ-for-LLaMa. {e}")
|
||||
print("See https://github.com/lm-sys/FastChat/blob/main/docs/gptq.md")
|
||||
sys.exit(-1)
|
||||
|
||||
else:
|
||||
# other branches
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
|
||||
model = load_quant_by_autogptq(model_name)
|
||||
|
||||
model, tokenizer = load_quant_by_autogptq(model_name)
|
||||
return model, tokenizer
|
||||
|
||||
|
||||
|
|
|
@ -1,248 +1,248 @@
|
|||
import docker, sys, os, time, requests, psutil
|
||||
import subprocess
|
||||
from docker.types import Mount, DeviceRequest
|
||||
from loguru import logger
|
||||
|
||||
src_dir = os.path.join(
|
||||
os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
)
|
||||
sys.path.append(src_dir)
|
||||
|
||||
from configs.model_config import USE_FASTCHAT, JUPYTER_WORK_PATH
|
||||
from configs.server_config import (
|
||||
NO_REMOTE_API, SANDBOX_SERVER, SANDBOX_IMAGE_NAME, SANDBOX_CONTRAINER_NAME,
|
||||
WEBUI_SERVER, API_SERVER, SDFILE_API_SERVER, CONTRAINER_NAME, IMAGE_NAME, DOCKER_SERVICE,
|
||||
DEFAULT_BIND_HOST, NEBULA_GRAPH_SERVER
|
||||
)
|
||||
|
||||
|
||||
import platform
|
||||
system_name = platform.system()
|
||||
USE_TTY = system_name in ["Windows"]
|
||||
|
||||
|
||||
def check_process(content: str, lang: str = None, do_stop=False):
|
||||
'''process-not-exist is true, process-exist is false'''
|
||||
for process in psutil.process_iter(["pid", "name", "cmdline"]):
|
||||
# check process name contains "jupyter" and port=xx
|
||||
|
||||
# if f"port={SANDBOX_SERVER['port']}" in str(process.info["cmdline"]).lower() and \
|
||||
# "jupyter" in process.info['name'].lower():
|
||||
if content in str(process.info["cmdline"]).lower():
|
||||
logger.info(f"content, {process.info}")
|
||||
# 关闭进程
|
||||
if do_stop:
|
||||
process.terminate()
|
||||
return True
|
||||
return False
|
||||
return True
|
||||
|
||||
def check_docker(client, container_name, do_stop=False):
|
||||
'''container-not-exist is true, container-exist is false'''
|
||||
if client is None: return True
|
||||
for i in client.containers.list(all=True):
|
||||
if i.name == container_name:
|
||||
if do_stop:
|
||||
container = i
|
||||
|
||||
if container_name == CONTRAINER_NAME and i.status == 'running':
|
||||
# wrap up db
|
||||
logger.info(f'inside {container_name}')
|
||||
# cp nebula data
|
||||
res = container.exec_run('''sh chatbot/dev_opsgpt/utils/nebula_cp.sh''')
|
||||
logger.info(f'cp res={res}')
|
||||
|
||||
# stop nebula service
|
||||
res = container.exec_run('''/usr/local/nebula/scripts/nebula.service stop all''')
|
||||
logger.info(f'stop res={res}')
|
||||
|
||||
container.stop()
|
||||
container.remove()
|
||||
return True
|
||||
return False
|
||||
return True
|
||||
|
||||
def start_docker(client, script_shs, ports, image_name, container_name, mounts=None, network=None):
|
||||
container = client.containers.run(
|
||||
image=image_name,
|
||||
command="bash",
|
||||
mounts=mounts,
|
||||
name=container_name,
|
||||
mem_limit="8g",
|
||||
# device_requests=[DeviceRequest(count=-1, capabilities=[['gpu']])],
|
||||
# network_mode="host",
|
||||
ports=ports,
|
||||
stdin_open=True,
|
||||
detach=True,
|
||||
tty=USE_TTY,
|
||||
network=network,
|
||||
)
|
||||
|
||||
logger.info(f"docker id: {container.id[:10]}")
|
||||
|
||||
# 启动notebook
|
||||
for script_sh in script_shs:
|
||||
if USE_FASTCHAT and "llm_api" in script_sh:
|
||||
logger.debug(script_sh)
|
||||
response = container.exec_run(["sh", "-c", script_sh])
|
||||
logger.debug(response)
|
||||
elif "llm_api" not in script_sh:
|
||||
logger.debug(script_sh)
|
||||
response = container.exec_run(["sh", "-c", script_sh])
|
||||
logger.debug(response)
|
||||
return container
|
||||
|
||||
#########################################
|
||||
############# 开始启动服务 ###############
|
||||
#########################################
|
||||
network_name ='my_network'
|
||||
|
||||
def start_sandbox_service(network_name ='my_network'):
|
||||
# networks = client.networks.list()
|
||||
# if any([network_name==i.attrs["Name"] for i in networks]):
|
||||
# network = client.networks.get(network_name)
|
||||
# else:
|
||||
# network = client.networks.create('my_network', driver='bridge')
|
||||
|
||||
mount = Mount(
|
||||
type='bind',
|
||||
source=os.path.join(src_dir, "jupyter_work"),
|
||||
target='/home/user/chatbot/jupyter_work',
|
||||
read_only=False # 如果需要只读访问,将此选项设置为True
|
||||
)
|
||||
mounts = [mount]
|
||||
# 沙盒的启动与服务的启动是独立的
|
||||
if SANDBOX_SERVER["do_remote"]:
|
||||
client = docker.from_env()
|
||||
# 启动容器
|
||||
logger.info("start container sandbox service")
|
||||
script_shs = ["bash jupyter_start.sh"]
|
||||
JUPYTER_WORK_PATH = "/home/user/chatbot/jupyter_work"
|
||||
script_shs = [f"cd /home/user/chatbot/jupyter_work && nohup jupyter-notebook --NotebookApp.token=mytoken --port=5050 --allow-root --ip=0.0.0.0 --notebook-dir={JUPYTER_WORK_PATH} --no-browser --ServerApp.disable_check_xsrf=True &"]
|
||||
ports = {f"{SANDBOX_SERVER['docker_port']}/tcp": f"{SANDBOX_SERVER['port']}/tcp"}
|
||||
if check_docker(client, SANDBOX_CONTRAINER_NAME, do_stop=True, ):
|
||||
container = start_docker(client, script_shs, ports, SANDBOX_IMAGE_NAME, SANDBOX_CONTRAINER_NAME, mounts=mounts, network=network_name)
|
||||
# 判断notebook是否启动
|
||||
time.sleep(5)
|
||||
retry_nums = 3
|
||||
while retry_nums>0:
|
||||
logger.info(f"http://localhost:{SANDBOX_SERVER['port']}")
|
||||
response = requests.get(f"http://localhost:{SANDBOX_SERVER['port']}", timeout=270)
|
||||
if response.status_code == 200:
|
||||
logger.info("container & notebook init success")
|
||||
break
|
||||
else:
|
||||
retry_nums -= 1
|
||||
logger.info(client.containers.list())
|
||||
logger.info("wait container running ...")
|
||||
time.sleep(5)
|
||||
|
||||
else:
|
||||
try:
|
||||
client = docker.from_env()
|
||||
except:
|
||||
client = None
|
||||
check_docker(client, SANDBOX_CONTRAINER_NAME, do_stop=True, )
|
||||
logger.info("start local sandbox service")
|
||||
|
||||
def start_api_service(sandbox_host=DEFAULT_BIND_HOST):
|
||||
# 启动service的容器
|
||||
if DOCKER_SERVICE:
|
||||
client = docker.from_env()
|
||||
logger.info("start container service")
|
||||
check_process("service/api.py", do_stop=True)
|
||||
check_process("service/sdfile_api.py", do_stop=True)
|
||||
check_process("service/sdfile_api.py", do_stop=True)
|
||||
check_process("webui.py", do_stop=True)
|
||||
mount = Mount(
|
||||
type='bind',
|
||||
source=src_dir,
|
||||
target='/home/user/chatbot/',
|
||||
read_only=False # 如果需要只读访问,将此选项设置为True
|
||||
)
|
||||
mount_database = Mount(
|
||||
type='bind',
|
||||
source=os.path.join(src_dir, "knowledge_base"),
|
||||
target='/home/user/knowledge_base/',
|
||||
read_only=False # 如果需要只读访问,将此选项设置为True
|
||||
)
|
||||
mount_code_database = Mount(
|
||||
type='bind',
|
||||
source=os.path.join(src_dir, "code_base"),
|
||||
target='/home/user/code_base/',
|
||||
read_only=False # 如果需要只读访问,将此选项设置为True
|
||||
)
|
||||
ports={
|
||||
f"{API_SERVER['docker_port']}/tcp": f"{API_SERVER['port']}/tcp",
|
||||
f"{WEBUI_SERVER['docker_port']}/tcp": f"{WEBUI_SERVER['port']}/tcp",
|
||||
f"{SDFILE_API_SERVER['docker_port']}/tcp": f"{SDFILE_API_SERVER['port']}/tcp",
|
||||
f"{NEBULA_GRAPH_SERVER['docker_port']}/tcp": f"{NEBULA_GRAPH_SERVER['port']}/tcp"
|
||||
}
|
||||
mounts = [mount, mount_database, mount_code_database]
|
||||
script_shs = [
|
||||
"mkdir -p /home/user/logs",
|
||||
'''
|
||||
if [ -d "/home/user/chatbot/data/nebula_data/data/meta" ]; then
|
||||
cp -r /home/user/chatbot/data/nebula_data/data /usr/local/nebula/
|
||||
fi
|
||||
''',
|
||||
"/usr/local/nebula/scripts/nebula.service start all",
|
||||
"/usr/local/nebula/scripts/nebula.service status all",
|
||||
"sleep 2",
|
||||
'''curl -X PUT -H "Content-Type: application/json" -d'{"heartbeat_interval_secs":"2"}' -s "http://127.0.0.1:19559/flags"''',
|
||||
'''curl -X PUT -H "Content-Type: application/json" -d'{"heartbeat_interval_secs":"2"}' -s "http://127.0.0.1:19669/flags"''',
|
||||
'''curl -X PUT -H "Content-Type: application/json" -d'{"heartbeat_interval_secs":"2"}' -s "http://127.0.0.1:19779/flags"''',
|
||||
|
||||
"nohup python chatbot/dev_opsgpt/service/sdfile_api.py > /home/user/logs/sdfile_api.log 2>&1 &",
|
||||
f"export DUCKDUCKGO_PROXY=socks5://host.docker.internal:13659 && export SANDBOX_HOST={sandbox_host} &&\
|
||||
nohup python chatbot/dev_opsgpt/service/api.py > /home/user/logs/api.log 2>&1 &",
|
||||
"nohup python chatbot/dev_opsgpt/service/llm_api.py > /home/user/ 2>&1 &",
|
||||
f"export DUCKDUCKGO_PROXY=socks5://host.docker.internal:13659 && export SANDBOX_HOST={sandbox_host} &&\
|
||||
cd chatbot/examples && nohup streamlit run webui.py > /home/user/logs/start_webui.log 2>&1 &"
|
||||
]
|
||||
if check_docker(client, CONTRAINER_NAME, do_stop=True):
|
||||
container = start_docker(client, script_shs, ports, IMAGE_NAME, CONTRAINER_NAME, mounts, network=network_name)
|
||||
|
||||
else:
|
||||
logger.info("start local service")
|
||||
# 关闭之前启动的docker 服务
|
||||
# check_docker(client, CONTRAINER_NAME, do_stop=True, )
|
||||
|
||||
api_sh = "nohup python ../dev_opsgpt/service/api.py > ../logs/api.log 2>&1 &"
|
||||
sdfile_sh = "nohup python ../dev_opsgpt/service/sdfile_api.py > ../logs/sdfile_api.log 2>&1 &"
|
||||
notebook_sh = f"nohup jupyter-notebook --NotebookApp.token=mytoken --port={SANDBOX_SERVER['port']} --allow-root --ip=0.0.0.0 --notebook-dir={JUPYTER_WORK_PATH} --no-browser --ServerApp.disable_check_xsrf=True > ../logs/sandbox.log 2>&1 &"
|
||||
llm_sh = "nohup python ../dev_opsgpt/service/llm_api.py > ../logs/llm_api.log 2>&1 &"
|
||||
webui_sh = "streamlit run webui.py" if USE_TTY else "streamlit run webui.py"
|
||||
|
||||
if check_process("jupyter-notebook --NotebookApp"):
|
||||
logger.debug(f"{notebook_sh}")
|
||||
subprocess.Popen(notebook_sh, shell=True)
|
||||
#
|
||||
if not NO_REMOTE_API and check_process("service/api.py"):
|
||||
subprocess.Popen(api_sh, shell=True)
|
||||
#
|
||||
if USE_FASTCHAT and check_process("service/llm_api.py"):
|
||||
subprocess.Popen(llm_sh, shell=True)
|
||||
#
|
||||
if check_process("service/sdfile_api.py"):
|
||||
subprocess.Popen(sdfile_sh, shell=True)
|
||||
|
||||
subprocess.Popen(webui_sh, shell=True)
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
start_sandbox_service()
|
||||
sandbox_host = DEFAULT_BIND_HOST
|
||||
if SANDBOX_SERVER["do_remote"]:
|
||||
client = docker.from_env()
|
||||
containers = client.containers.list(all=True)
|
||||
|
||||
for container in containers:
|
||||
container_a_info = client.containers.get(container.id)
|
||||
if container_a_info.name == SANDBOX_CONTRAINER_NAME:
|
||||
container1_networks = container.attrs['NetworkSettings']['Networks']
|
||||
sandbox_host = container1_networks.get(network_name)["IPAddress"]
|
||||
break
|
||||
start_api_service(sandbox_host)
|
||||
|
||||
import docker, sys, os, time, requests, psutil
|
||||
import subprocess
|
||||
from docker.types import Mount, DeviceRequest
|
||||
from loguru import logger
|
||||
|
||||
src_dir = os.path.join(
|
||||
os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
)
|
||||
sys.path.append(src_dir)
|
||||
|
||||
from configs.model_config import USE_FASTCHAT, JUPYTER_WORK_PATH
|
||||
from configs.server_config import (
|
||||
NO_REMOTE_API, SANDBOX_SERVER, SANDBOX_IMAGE_NAME, SANDBOX_CONTRAINER_NAME,
|
||||
WEBUI_SERVER, API_SERVER, SDFILE_API_SERVER, CONTRAINER_NAME, IMAGE_NAME, DOCKER_SERVICE,
|
||||
DEFAULT_BIND_HOST, NEBULA_GRAPH_SERVER
|
||||
)
|
||||
|
||||
|
||||
import platform
|
||||
system_name = platform.system()
|
||||
USE_TTY = system_name in ["Windows"]
|
||||
|
||||
|
||||
def check_process(content: str, lang: str = None, do_stop=False):
|
||||
'''process-not-exist is true, process-exist is false'''
|
||||
for process in psutil.process_iter(["pid", "name", "cmdline"]):
|
||||
# check process name contains "jupyter" and port=xx
|
||||
|
||||
# if f"port={SANDBOX_SERVER['port']}" in str(process.info["cmdline"]).lower() and \
|
||||
# "jupyter" in process.info['name'].lower():
|
||||
if content in str(process.info["cmdline"]).lower():
|
||||
logger.info(f"content, {process.info}")
|
||||
# 关闭进程
|
||||
if do_stop:
|
||||
process.terminate()
|
||||
return True
|
||||
return False
|
||||
return True
|
||||
|
||||
def check_docker(client, container_name, do_stop=False):
|
||||
'''container-not-exist is true, container-exist is false'''
|
||||
if client is None: return True
|
||||
for i in client.containers.list(all=True):
|
||||
if i.name == container_name:
|
||||
if do_stop:
|
||||
container = i
|
||||
|
||||
if container_name == CONTRAINER_NAME and i.status == 'running':
|
||||
# wrap up db
|
||||
logger.info(f'inside {container_name}')
|
||||
# cp nebula data
|
||||
res = container.exec_run('''sh chatbot/dev_opsgpt/utils/nebula_cp.sh''')
|
||||
logger.info(f'cp res={res}')
|
||||
|
||||
# stop nebula service
|
||||
res = container.exec_run('''/usr/local/nebula/scripts/nebula.service stop all''')
|
||||
logger.info(f'stop res={res}')
|
||||
|
||||
container.stop()
|
||||
container.remove()
|
||||
return True
|
||||
return False
|
||||
return True
|
||||
|
||||
def start_docker(client, script_shs, ports, image_name, container_name, mounts=None, network=None):
|
||||
container = client.containers.run(
|
||||
image=image_name,
|
||||
command="bash",
|
||||
mounts=mounts,
|
||||
name=container_name,
|
||||
mem_limit="8g",
|
||||
# device_requests=[DeviceRequest(count=-1, capabilities=[['gpu']])],
|
||||
# network_mode="host",
|
||||
ports=ports,
|
||||
stdin_open=True,
|
||||
detach=True,
|
||||
tty=USE_TTY,
|
||||
network=network,
|
||||
)
|
||||
|
||||
logger.info(f"docker id: {container.id[:10]}")
|
||||
|
||||
# 启动notebook
|
||||
for script_sh in script_shs:
|
||||
if USE_FASTCHAT and "llm_api" in script_sh:
|
||||
logger.debug(script_sh)
|
||||
response = container.exec_run(["sh", "-c", script_sh])
|
||||
logger.debug(response)
|
||||
elif "llm_api" not in script_sh:
|
||||
logger.debug(script_sh)
|
||||
response = container.exec_run(["sh", "-c", script_sh])
|
||||
logger.debug(response)
|
||||
return container
|
||||
|
||||
#########################################
|
||||
############# 开始启动服务 ###############
|
||||
#########################################
|
||||
network_name ='my_network'
|
||||
|
||||
def start_sandbox_service(network_name ='my_network'):
|
||||
# networks = client.networks.list()
|
||||
# if any([network_name==i.attrs["Name"] for i in networks]):
|
||||
# network = client.networks.get(network_name)
|
||||
# else:
|
||||
# network = client.networks.create('my_network', driver='bridge')
|
||||
|
||||
mount = Mount(
|
||||
type='bind',
|
||||
source=os.path.join(src_dir, "jupyter_work"),
|
||||
target='/home/user/chatbot/jupyter_work',
|
||||
read_only=False # 如果需要只读访问,将此选项设置为True
|
||||
)
|
||||
mounts = [mount]
|
||||
# 沙盒的启动与服务的启动是独立的
|
||||
if SANDBOX_SERVER["do_remote"]:
|
||||
client = docker.from_env()
|
||||
# 启动容器
|
||||
logger.info("start container sandbox service")
|
||||
script_shs = ["bash jupyter_start.sh"]
|
||||
JUPYTER_WORK_PATH = "/home/user/chatbot/jupyter_work"
|
||||
script_shs = [f"cd /home/user/chatbot/jupyter_work && nohup jupyter-notebook --NotebookApp.token=mytoken --port=5050 --allow-root --ip=0.0.0.0 --notebook-dir={JUPYTER_WORK_PATH} --no-browser --ServerApp.disable_check_xsrf=True &"]
|
||||
ports = {f"{SANDBOX_SERVER['docker_port']}/tcp": f"{SANDBOX_SERVER['port']}/tcp"}
|
||||
if check_docker(client, SANDBOX_CONTRAINER_NAME, do_stop=True, ):
|
||||
container = start_docker(client, script_shs, ports, SANDBOX_IMAGE_NAME, SANDBOX_CONTRAINER_NAME, mounts=mounts, network=network_name)
|
||||
# 判断notebook是否启动
|
||||
time.sleep(5)
|
||||
retry_nums = 3
|
||||
while retry_nums>0:
|
||||
logger.info(f"http://localhost:{SANDBOX_SERVER['port']}")
|
||||
response = requests.get(f"http://localhost:{SANDBOX_SERVER['port']}", timeout=270)
|
||||
if response.status_code == 200:
|
||||
logger.info("container & notebook init success")
|
||||
break
|
||||
else:
|
||||
retry_nums -= 1
|
||||
logger.info(client.containers.list())
|
||||
logger.info("wait container running ...")
|
||||
time.sleep(5)
|
||||
|
||||
else:
|
||||
try:
|
||||
client = docker.from_env()
|
||||
except:
|
||||
client = None
|
||||
check_docker(client, SANDBOX_CONTRAINER_NAME, do_stop=True, )
|
||||
logger.info("start local sandbox service")
|
||||
|
||||
def start_api_service(sandbox_host=DEFAULT_BIND_HOST):
|
||||
# 启动service的容器
|
||||
if DOCKER_SERVICE:
|
||||
client = docker.from_env()
|
||||
logger.info("start container service")
|
||||
check_process("service/api.py", do_stop=True)
|
||||
check_process("service/sdfile_api.py", do_stop=True)
|
||||
check_process("service/sdfile_api.py", do_stop=True)
|
||||
check_process("webui.py", do_stop=True)
|
||||
mount = Mount(
|
||||
type='bind',
|
||||
source=src_dir,
|
||||
target='/home/user/chatbot/',
|
||||
read_only=False # 如果需要只读访问,将此选项设置为True
|
||||
)
|
||||
mount_database = Mount(
|
||||
type='bind',
|
||||
source=os.path.join(src_dir, "knowledge_base"),
|
||||
target='/home/user/knowledge_base/',
|
||||
read_only=False # 如果需要只读访问,将此选项设置为True
|
||||
)
|
||||
mount_code_database = Mount(
|
||||
type='bind',
|
||||
source=os.path.join(src_dir, "code_base"),
|
||||
target='/home/user/code_base/',
|
||||
read_only=False # 如果需要只读访问,将此选项设置为True
|
||||
)
|
||||
ports={
|
||||
f"{API_SERVER['docker_port']}/tcp": f"{API_SERVER['port']}/tcp",
|
||||
f"{WEBUI_SERVER['docker_port']}/tcp": f"{WEBUI_SERVER['port']}/tcp",
|
||||
f"{SDFILE_API_SERVER['docker_port']}/tcp": f"{SDFILE_API_SERVER['port']}/tcp",
|
||||
f"{NEBULA_GRAPH_SERVER['docker_port']}/tcp": f"{NEBULA_GRAPH_SERVER['port']}/tcp"
|
||||
}
|
||||
mounts = [mount, mount_database, mount_code_database]
|
||||
script_shs = [
|
||||
"mkdir -p /home/user/logs",
|
||||
'''
|
||||
if [ -d "/home/user/chatbot/data/nebula_data/data/meta" ]; then
|
||||
cp -r /home/user/chatbot/data/nebula_data/data /usr/local/nebula/
|
||||
fi
|
||||
''',
|
||||
"/usr/local/nebula/scripts/nebula.service start all",
|
||||
"/usr/local/nebula/scripts/nebula.service status all",
|
||||
"sleep 2",
|
||||
'''curl -X PUT -H "Content-Type: application/json" -d'{"heartbeat_interval_secs":"2"}' -s "http://127.0.0.1:19559/flags"''',
|
||||
'''curl -X PUT -H "Content-Type: application/json" -d'{"heartbeat_interval_secs":"2"}' -s "http://127.0.0.1:19669/flags"''',
|
||||
'''curl -X PUT -H "Content-Type: application/json" -d'{"heartbeat_interval_secs":"2"}' -s "http://127.0.0.1:19779/flags"''',
|
||||
|
||||
"nohup python chatbot/dev_opsgpt/service/sdfile_api.py > /home/user/logs/sdfile_api.log 2>&1 &",
|
||||
f"export DUCKDUCKGO_PROXY=socks5://host.docker.internal:13659 && export SANDBOX_HOST={sandbox_host} &&\
|
||||
nohup python chatbot/dev_opsgpt/service/api.py > /home/user/logs/api.log 2>&1 &",
|
||||
"nohup python chatbot/dev_opsgpt/service/llm_api.py > /home/user/ 2>&1 &",
|
||||
f"export DUCKDUCKGO_PROXY=socks5://host.docker.internal:13659 && export SANDBOX_HOST={sandbox_host} &&\
|
||||
cd chatbot/examples && nohup streamlit run webui.py > /home/user/logs/start_webui.log 2>&1 &"
|
||||
]
|
||||
if check_docker(client, CONTRAINER_NAME, do_stop=True):
|
||||
container = start_docker(client, script_shs, ports, IMAGE_NAME, CONTRAINER_NAME, mounts, network=network_name)
|
||||
|
||||
else:
|
||||
logger.info("start local service")
|
||||
# 关闭之前启动的docker 服务
|
||||
# check_docker(client, CONTRAINER_NAME, do_stop=True, )
|
||||
|
||||
api_sh = "nohup python ../dev_opsgpt/service/api.py > ../logs/api.log 2>&1 &"
|
||||
sdfile_sh = "nohup python ../dev_opsgpt/service/sdfile_api.py > ../logs/sdfile_api.log 2>&1 &"
|
||||
notebook_sh = f"nohup jupyter-notebook --NotebookApp.token=mytoken --port={SANDBOX_SERVER['port']} --allow-root --ip=0.0.0.0 --notebook-dir={JUPYTER_WORK_PATH} --no-browser --ServerApp.disable_check_xsrf=True > ../logs/sandbox.log 2>&1 &"
|
||||
llm_sh = "nohup python ../dev_opsgpt/service/llm_api.py > ../logs/llm_api.log 2>&1 &"
|
||||
webui_sh = "streamlit run webui.py" if USE_TTY else "streamlit run webui.py"
|
||||
|
||||
if check_process("jupyter-notebook --NotebookApp"):
|
||||
logger.debug(f"{notebook_sh}")
|
||||
subprocess.Popen(notebook_sh, shell=True)
|
||||
#
|
||||
if not NO_REMOTE_API and check_process("service/api.py"):
|
||||
subprocess.Popen(api_sh, shell=True)
|
||||
#
|
||||
if USE_FASTCHAT and check_process("service/llm_api.py"):
|
||||
subprocess.Popen(llm_sh, shell=True)
|
||||
#
|
||||
if check_process("service/sdfile_api.py"):
|
||||
subprocess.Popen(sdfile_sh, shell=True)
|
||||
|
||||
subprocess.Popen(webui_sh, shell=True)
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
start_sandbox_service()
|
||||
sandbox_host = DEFAULT_BIND_HOST
|
||||
if SANDBOX_SERVER["do_remote"]:
|
||||
client = docker.from_env()
|
||||
containers = client.containers.list(all=True)
|
||||
|
||||
for container in containers:
|
||||
container_a_info = client.containers.get(container.id)
|
||||
if container_a_info.name == SANDBOX_CONTRAINER_NAME:
|
||||
container1_networks = container.attrs['NetworkSettings']['Networks']
|
||||
sandbox_host = container1_networks.get(network_name)["IPAddress"]
|
||||
break
|
||||
start_api_service(sandbox_host)
|
||||
|
||||
|
|
|
@ -27,19 +27,19 @@ api = ApiRequest(base_url="http://127.0.0.1:7861", no_remote_api=NO_REMOTE_API)
|
|||
|
||||
if __name__ == "__main__":
|
||||
st.set_page_config(
|
||||
"DevOpsGPT-Chat WebUI",
|
||||
"CodeFuse-ChatBot WebUI",
|
||||
os.path.join("../sources/imgs", "devops-chatbot.png"),
|
||||
initial_sidebar_state="expanded",
|
||||
menu_items={
|
||||
'Get Help': 'https://github.com/lightislost/devopsgpt',
|
||||
'Report a bug': "https://github.com/lightislost/devopsgpt/issues",
|
||||
'About': f"""欢迎使用 DevOpsGPT-Chat WebUI {VERSION}!"""
|
||||
'Get Help': 'https://github.com/codefuse-ai/codefuse-chatbot',
|
||||
'Report a bug': "https://github.com/codefuse-ai/codefuse-chatbot/issues",
|
||||
'About': f"""欢迎使用 CodeFuse-ChatBot WebUI {VERSION}!"""
|
||||
}
|
||||
)
|
||||
|
||||
if not chat_box.chat_inited:
|
||||
st.toast(
|
||||
f"欢迎使用 [`DevOpsGPT-Chat`](https://github.com/lightislost/devopsgpt) ! \n\n"
|
||||
f"欢迎使用 [`CodeFuse-ChatBot`](https://github.com/codefuse-ai/codefuse-chatbot) ! \n\n"
|
||||
f"当前使用模型`{LLM_MODEL}`, 您可以开始提问了."
|
||||
)
|
||||
|
||||
|
@ -71,7 +71,7 @@ if __name__ == "__main__":
|
|||
use_column_width=True
|
||||
)
|
||||
st.caption(
|
||||
f"""<p align="right">当前版本:{VERSION}</p>""",
|
||||
f"""<p align="right"> CodeFuse-ChatBot 当前版本:{VERSION}</p>""",
|
||||
unsafe_allow_html=True,
|
||||
)
|
||||
options = list(pages)
|
||||
|
|
|
@ -16,7 +16,7 @@ faiss-cpu
|
|||
nltk
|
||||
loguru
|
||||
pypdf
|
||||
duckduckgo-search
|
||||
duckduckgo-search==3.9.11
|
||||
pysocks
|
||||
accelerate
|
||||
docker
|
||||
|
@ -39,7 +39,7 @@ streamlit-option-menu>=0.3.6
|
|||
streamlit-antd-components>=0.1.11
|
||||
streamlit-chatbox>=1.1.6
|
||||
streamlit-aggrid>=0.3.4.post3
|
||||
httpx~=0.24.1
|
||||
httpx
|
||||
|
||||
javalang==0.13.0
|
||||
jsonref==1.1.0
|
||||
|
@ -51,4 +51,9 @@ nebula3-python==3.1.0
|
|||
protobuf==3.20.*
|
||||
transformers_stream_generator
|
||||
einops
|
||||
auto-gptq
|
||||
optimum
|
||||
modelscope
|
||||
|
||||
# vllm model
|
||||
vllm==0.2.2; sys_platform == "linux"
|
||||
|
|
|
@ -22,15 +22,30 @@ if __name__ == "__main__":
|
|||
# chat = ChatOpenAI(temperature=0.1, model_name="gpt-3.5-turbo")
|
||||
# print(chat.predict("hi!"))
|
||||
|
||||
print(LLM_MODEL, llm_model_dict[LLM_MODEL]["api_key"], llm_model_dict[LLM_MODEL]["api_base_url"])
|
||||
model = ChatOpenAI(
|
||||
streaming=True,
|
||||
verbose=True,
|
||||
openai_api_key=llm_model_dict[LLM_MODEL]["api_key"],
|
||||
openai_api_base=llm_model_dict[LLM_MODEL]["api_base_url"],
|
||||
model_name=LLM_MODEL
|
||||
# print(LLM_MODEL, llm_model_dict[LLM_MODEL]["api_key"], llm_model_dict[LLM_MODEL]["api_base_url"])
|
||||
# model = ChatOpenAI(
|
||||
# streaming=True,
|
||||
# verbose=True,
|
||||
# openai_api_key=llm_model_dict[LLM_MODEL]["api_key"],
|
||||
# openai_api_base=llm_model_dict[LLM_MODEL]["api_base_url"],
|
||||
# model_name=LLM_MODEL
|
||||
# )
|
||||
# chat_prompt = ChatPromptTemplate.from_messages([("human", "{input}")])
|
||||
# chain = LLMChain(prompt=chat_prompt, llm=model)
|
||||
# content = chain({"input": "hello"})
|
||||
# print(content)
|
||||
|
||||
import openai
|
||||
# openai.api_key = "EMPTY" # Not support yet
|
||||
openai.api_base = "http://127.0.0.1:8888/v1"
|
||||
|
||||
model = "example"
|
||||
|
||||
# create a chat completion
|
||||
completion = openai.ChatCompletion.create(
|
||||
model=model,
|
||||
messages=[{"role": "user", "content": "Hello! What is your name? "}],
|
||||
max_tokens=100,
|
||||
)
|
||||
chat_prompt = ChatPromptTemplate.from_messages([("human", "{input}")])
|
||||
chain = LLMChain(prompt=chat_prompt, llm=model)
|
||||
content = chain({"input": "hello"})
|
||||
print(content)
|
||||
# print the completion
|
||||
print(completion.choices[0].message.content)
|
Loading…
Reference in New Issue