update feature: access more LLM models via fastchat (adapted from langchain-chatchat)

This commit is contained in:
shanshi 2023-12-26 11:41:53 +08:00
parent ac0000890e
commit b35e849e9b
63 changed files with 3273 additions and 1264 deletions

View File

@ -14,7 +14,7 @@ logging.basicConfig(format=LOG_FORMAT)
# os.environ["OPENAI_PROXY"] = "socks5h://127.0.0.1:13659"
os.environ["API_BASE_URL"] = "http://openai.com/v1/chat/completions"
os.environ["OPENAI_API_KEY"] = ""
os.environ["DUCKDUCKGO_PROXY"] = "socks5://127.0.0.1:13659"
os.environ["DUCKDUCKGO_PROXY"] = os.environ.get("DUCKDUCKGO_PROXY") or "socks5://127.0.0.1:13659"
os.environ["BAIDU_OCR_API_KEY"] = ""
os.environ["BAIDU_OCR_SECRET_KEY"] = ""
@ -60,46 +60,26 @@ ONLINE_LLM_MODEL = {
"api_key": "",
"openai_proxy": "",
},
"example": {
"version": "gpt-3.5", # 采用openai接口做示例
"api_base_url": "https://api.openai.com/v1",
"api_key": "",
"provider": "ExampleWorker",
},
}
# It is recommended to use chat models rather than base models, which cannot produce correct output
llm_model_dict = {
"chatglm-6b": {
"local_model_path": "THUDM/chatglm-6b",
"api_base_url": "http://localhost:8888/v1", # "name"修改为fastchat服务中的"api_base_url"
"api_key": "EMPTY"
},
"chatglm-6b-int4": {
"local_model_path": "THUDM/chatglm2-6b-int4/",
"api_base_url": "http://localhost:8888/v1", # "name"修改为fastchat服务中的"api_base_url"
"api_key": "EMPTY"
},
"chatglm2-6b": {
"local_model_path": "THUDM/chatglm2-6b",
"api_base_url": "http://localhost:8888/v1", # URL需要与运行fastchat服务端的server_config.FSCHAT_OPENAI_API一致
"api_key": "EMPTY"
},
"chatglm2-6b-int4": {
"local_model_path": "THUDM/chatglm2-6b-int4",
"api_base_url": "http://localhost:8888/v1", # URL需要与运行fastchat服务端的server_config.FSCHAT_OPENAI_API一致
"api_key": "EMPTY"
},
"chatglm2-6b-32k": {
"local_model_path": "THUDM/chatglm2-6b-32k", # "THUDM/chatglm2-6b-32k",
"api_base_url": "http://localhost:8888/v1", # "URL需要与运行fastchat服务端的server_config.FSCHAT_OPENAI_API一致
"api_key": "EMPTY"
},
"vicuna-13b-hf": {
"local_model_path": "",
"api_base_url": "http://localhost:8888/v1", # "name"修改为fastchat服务中的"api_base_url"
"api_key": "EMPTY"
},
# The following models have been tested and can be connected; configure them like the entries above
# 'codellama_34b', 'Baichuan2-13B-Base', 'Baichuan2-13B-Chat', 'baichuan2-7b-base', 'baichuan2-7b-chat',
# 'internlm-7b-base', 'internlm-chat-7b', 'chatglm2-6b', 'qwen-14b-base', 'qwen-14b-chat', 'qwen-1-8B-Chat',
# 'Qwen-7B', 'Qwen-7B-Chat', 'qwen-7b-base-v1.1', 'qwen-7b-chat-v1.1', 'chatglm3-6b', 'chatglm3-6b-32k',
# 'chatglm3-6b-base', 'Qwen-72B-Chat-Int4'
# When calling chatgpt, if urllib3.exceptions.MaxRetryError: HTTPSConnectionPool(host='api.openai.com', port=443):
# Max retries exceeded with url: /v1/chat/completions is raised,
# change the urllib3 version to 1.25.11
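A minimal guard sketch, not part of this commit, that surfaces the urllib3 pin mentioned in the comment above before any OpenAI call is attempted:

import urllib3

# the workaround above suggests pinning urllib3 to 1.25.11 when MaxRetryError occurs
if urllib3.__version__ != "1.25.11":
    raise RuntimeError(f"expected urllib3==1.25.11 for the chatgpt workaround, got {urllib3.__version__}")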
@ -115,8 +95,23 @@ llm_model_dict = {
"api_base_url": os.environ.get("API_BASE_URL"),
"api_key": os.environ.get("OPENAI_API_KEY")
},
"gpt-3.5-turbo-16k": {
"local_model_path": "gpt-3.5-turbo-16k",
"api_base_url": os.environ.get("API_BASE_URL"),
"api_key": os.environ.get("OPENAI_API_KEY")
},
}
# It is recommended to use chat models rather than base models, which cannot produce correct output
VLLM_MODEL_DICT = {
'chatglm2-6b': "THUDM/chatglm-6b",
}
# The following models have been tested and can be connected; configure them like the entries above (see the sketch below)
# 'codellama_34b', 'Baichuan2-13B-Base', 'Baichuan2-13B-Chat', 'baichuan2-7b-base', 'baichuan2-7b-chat',
# 'internlm-7b-base', 'internlm-chat-7b', 'chatglm2-6b', 'qwen-14b-base', 'qwen-14b-chat', 'qwen-1-8B-Chat',
# 'Qwen-7B', 'Qwen-7B-Chat', 'qwen-7b-base-v1.1', 'qwen-7b-chat-v1.1', 'chatglm3-6b', 'chatglm3-6b-32k',
# 'chatglm3-6b-base', 'Qwen-72B-Chat-Int4'
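An illustrative entry, in the style of the llm_model_dict entries earlier in this file, for connecting one of the tested models above; the model name and path are examples only and are not part of this commit:

"chatglm3-6b": {
    "local_model_path": "THUDM/chatglm3-6b",
    "api_base_url": "http://localhost:8888/v1",  # must match server_config.FSCHAT_OPENAI_API
    "api_key": "EMPTY"
},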
LOCAL_LLM_MODEL_DIR = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "llm_models")
llm_model_dict_c = {}
@ -133,8 +128,12 @@ llm_model_dict = llm_model_dict_c
# LLM name
LLM_MODEL = "gpt-3.5-turbo"
LLM_MODELs = ["chatglm2-6b"]
# EMBEDDING_ENGINE = 'openai'
EMBEDDING_ENGINE = 'model'
EMBEDDING_MODEL = "text2vec-base"
# LLM_MODEL = "gpt-4"
LLM_MODEL = "gpt-3.5-turbo-16k"
LLM_MODELs = ["gpt-3.5-turbo-16k"]
USE_FASTCHAT = "gpt" not in LLM_MODEL # whether to use fastchat
# LLM runtime device
@ -161,7 +160,7 @@ NLTK_DATA_PATH = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__
JUPYTER_WORK_PATH = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "jupyter_work")
# WEB_CRAWL storage path
WEB_CRAWL_PATH = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "sources/docs")
WEB_CRAWL_PATH = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "knowledge_base")
# NEBULA_DATA storage path
NELUBA_PATH = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "data/neluba_data")
@ -257,3 +256,5 @@ BING_SUBSCRIPTION_KEY = ""
# Use title detection to decide which text segments are titles and mark them in the metadata,
# then concatenate each text segment with its parent title to enrich the text information.
ZH_TITLE_ENHANCE = False
log_verbose = False

View File

@ -100,12 +100,26 @@ FSCHAT_MODEL_WORKERS = {
# "stream_interval": 2,
# "no_register": False,
},
"chatglm2-6b": {
"port": 20003
},
"baichuan2-7b-base": {
"port": 20004
}
'codellama_34b': {'host': DEFAULT_BIND_HOST, 'port': 20002},
'Baichuan2-13B-Base': {'host': DEFAULT_BIND_HOST, 'port': 20003},
'Baichuan2-13B-Chat': {'host': DEFAULT_BIND_HOST, 'port': 20004},
'baichuan2-7b-base': {'host': DEFAULT_BIND_HOST, 'port': 20005},
'baichuan2-7b-chat': {'host': DEFAULT_BIND_HOST, 'port': 20006},
'internlm-7b-base': {'host': DEFAULT_BIND_HOST, 'port': 20007},
'internlm-chat-7b': {'host': DEFAULT_BIND_HOST, 'port': 20008},
'chatglm2-6b': {'host': DEFAULT_BIND_HOST, 'port': 20009},
'qwen-14b-base': {'host': DEFAULT_BIND_HOST, 'port': 20010},
'qwen-14b-chat': {'host': DEFAULT_BIND_HOST, 'port': 20011},
'qwen-1-8B-Chat': {'host': DEFAULT_BIND_HOST, 'port': 20012},
'Qwen-7B': {'host': DEFAULT_BIND_HOST, 'port': 20013},
'Qwen-7B-Chat': {'host': DEFAULT_BIND_HOST, 'port': 20014},
'qwen-7b-base-v1.1': {'host': DEFAULT_BIND_HOST, 'port': 20015},
'qwen-7b-chat-v1.1': {'host': DEFAULT_BIND_HOST, 'port': 20016},
'chatglm3-6b': {'host': DEFAULT_BIND_HOST, 'port': 20017},
'chatglm3-6b-32k': {'host': DEFAULT_BIND_HOST, 'port': 20018},
'chatglm3-6b-base': {'host': DEFAULT_BIND_HOST, 'port': 20019},
'Qwen-72B-Chat-Int4': {'host': DEFAULT_BIND_HOST, 'port': 20020},
'gpt-3.5-turbo': {'host': DEFAULT_BIND_HOST, 'port': 20021}
}
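A small sketch, not part of this commit, of how an entry above could be resolved into a worker address; the helper name is hypothetical and assumes the host/port layout shown:

def fschat_worker_address(model_name: str) -> str:
    # look up the host and port configured above for the given model worker
    worker = FSCHAT_MODEL_WORKERS[model_name]
    return f"http://{worker['host']}:{worker['port']}"

# e.g. fschat_worker_address("chatglm2-6b") would yield "http://<DEFAULT_BIND_HOST>:20009"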
# fastchat multi model worker server
FSCHAT_MULTI_MODEL_WORKERS = {

View File

@ -275,7 +275,7 @@ class AgentChat:
"code_docs": [str(doc) for doc in message.code_docs],
"related_nodes": [doc.get_related_node() for idx, doc in enumerate(message.code_docs) if idx==0],
"figures": message.figures,
"step_content": step_content,
"step_content": step_content or final_content,
"final_content": final_content,
}

View File

@ -22,7 +22,7 @@ HISTORY_VERTEX_SCORE = 5
VERTEX_MERGE_RATIO = 0.5
# search_by_description
MAX_DISTANCE = 0.5
MAX_DISTANCE = 1000
class CodeSearch:

View File

@ -17,7 +17,7 @@ from dev_opsgpt.connector.configs.prompts import BASE_PROMPT_INPUT, QUERY_CONTEX
from dev_opsgpt.connector.message_process import MessageUtils
from dev_opsgpt.connector.configs.agent_config import REACT_PROMPT_INPUT, QUERY_CONTEXT_PROMPT_INPUT, PLAN_PROMPT_INPUT
from dev_opsgpt.llm_models import getChatModel
from dev_opsgpt.llm_models import getChatModel, getExtraModel
from dev_opsgpt.connector.utils import parse_section
@ -44,7 +44,7 @@ class BaseAgent:
self.task = task
self.role = role
self.message_utils = MessageUtils()
self.message_utils = MessageUtils(role)
self.llm = self.create_llm_engine(temperature, stop)
self.memory = self.init_history(memory)
self.chat_turn = chat_turn
@ -68,6 +68,7 @@ class BaseAgent:
'''agent response from multi-message'''
# insert query into memory
query_c = copy.deepcopy(query)
query_c = self.start_action_step(query_c)
self_memory = self.memory if self.do_use_self_memory else None
# create your llm prompt
@ -80,23 +81,30 @@ class BaseAgent:
role_name=self.role.role_name,
role_type="ai", #self.role.role_type,
role_content=content,
role_contents=[content],
step_content=content,
input_query=query_c.input_query,
tools=query_c.tools,
parsed_output_list=[query.parsed_output]
parsed_output_list=[query.parsed_output],
customed_kargs=query_c.customed_kargs
)
# common parse llm' content to message
output_message = self.message_utils.parser(output_message)
if self.do_filter:
output_message = self.message_utils.filter(output_message)
# action step
output_message, observation_message = self.message_utils.step_router(output_message, history, background, memory_pool=memory_pool)
output_message.parsed_output_list.append(output_message.parsed_output)
if observation_message:
output_message.parsed_output_list.append(observation_message.parsed_output)
# update self_memory
self.append_history(query_c)
self.append_history(output_message)
# logger.info(f"{self.role.role_name} currenct question: {output_message.input_query}\nllm_step_run: {output_message.role_content}")
output_message.input_query = output_message.role_content
output_message.parsed_output_list.append(output_message.parsed_output)
# output_message.parsed_output_list.append(output_message.parsed_output) # duplicates the line above?
# end
output_message = self.message_utils.inherit_extrainfo(query, output_message)
output_message = self.end_action_step(output_message)
# update memory pool
memory_pool.append(output_message)
yield output_message
@ -110,7 +118,7 @@ class BaseAgent:
doc_infos = self.create_doc_prompt(query)
code_infos = self.create_codedoc_prompt(query)
#
formatted_tools, tool_names = self.create_tools_prompt(query)
formatted_tools, tool_names, _ = self.create_tools_prompt(query)
task_prompt = self.create_task_prompt(query)
background_prompt = self.create_background_prompt(background, control_key="step_content")
history_prompt = self.create_history_prompt(history)
@ -185,7 +193,10 @@ class BaseAgent:
context = memory_pool_select_by_agent_key_context
prompt += "\n**Context:**\n" + context + "\n" + input_query
elif input_key == "DocInfos":
prompt += "\n**DocInfos:**\n" + DocInfos
if DocInfos:
prompt += "\n**DocInfos:**\n" + DocInfos
else:
prompt += "\n**DocInfos:**\n" + "Empty"
elif input_key == "Question":
prompt += "\n**Question:**\n" + input_query
@ -231,12 +242,15 @@ class BaseAgent:
def create_tools_prompt(self, message: Message) -> str:
tools = message.tools
tool_strings = []
tools_descs = []
for tool in tools:
args_schema = re.sub("}", "}}}}", re.sub("{", "{{{{", str(tool.args)))
tool_strings.append(f"{tool.name}: {tool.description}, args: {args_schema}")
tools_descs.append(f"{tool.name}: {tool.description}")
formatted_tools = "\n".join(tool_strings)
tools_desc_str = "\n".join(tools_descs)
tool_names = ", ".join([tool.name for tool in tools])
return formatted_tools, tool_names
return formatted_tools, tool_names, tools_desc_str
def create_task_prompt(self, message: Message) -> str:
task = message.task or self.task
@ -276,19 +290,22 @@ class BaseAgent:
def create_llm_engine(self, temperature=0.2, stop=None):
return getChatModel(temperature=temperature, stop=stop)
def registry_actions(self, actions):
'''registry llm's actions'''
self.action_list = actions
# def filter(self, message: Message, stop=None) -> Message:
def start_action_step(self, message: Message) -> Message:
'''do actions before the agent predicts'''
# action_json = self.start_action()
# message["customed_kargs"]["xx"] = action_json
return message
# tool_params = self.parser_spec_key(message.role_content, "tool_params")
# code_content = self.parser_spec_key(message.role_content, "code_content")
# plan = self.parser_spec_key(message.role_content, "plan")
# plans = self.parser_spec_key(message.role_content, "plans", do_search=False)
# content = self.parser_spec_key(message.role_content, "content", do_search=False)
# # logger.debug(f"tool_params: {tool_params}, code_content: {code_content}, plan: {plan}, plans: {plans}, content: {content}")
# role_content = tool_params or code_content or plan or plans or content
# message.role_content = role_content or message.role_content
# return message
def end_action_step(self, message: Message) -> Message:
'''do actions after the agent predicts'''
# action_json = self.end_action()
# message["customed_kargs"]["xx"] = action_json
return message
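A hedged sketch of how a subclass might use these two hooks; it assumes Message.customed_kargs is a dict field, and the key names below are illustrative, not part of this commit:

import time

class TimedAgent(BaseAgent):
    def start_action_step(self, message: Message) -> Message:
        # record when prediction for this step began (illustrative key)
        message.customed_kargs["step_started_at"] = time.time()
        return message

    def end_action_step(self, message: Message) -> Message:
        # compute how long the step took, if the start timestamp was recorded
        started = message.customed_kargs.get("step_started_at")
        if started is not None:
            message.customed_kargs["step_duration"] = time.time() - started
        return message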
def token_usage(self, ):
'''calculate the usage of token'''
@ -324,339 +341,6 @@ class BaseAgent:
message_c.parsed_output = {k: v for k,v in message_c.parsed_output.items() if k in self.focus_message_keys}
message_c.parsed_output_list = [{k: v for k,v in parsed_output.items() if k in self.focus_message_keys} for parsed_output in message_c.parsed_output_list]
return message_c
# def get_extra_infos(self, message: Message) -> Message:
# ''''''
# if self.do_search:
# message = self.get_search_retrieval(message)
# if self.do_doc_retrieval:
# message = self.get_doc_retrieval(message)
# if self.do_tool_retrieval:
# message = self.get_tool_retrieval(message)
# return message
# def get_search_retrieval(self, message: Message,) -> Message:
# SEARCH_ENGINES = {"duckduckgo": DDGSTool}
# search_docs = []
# for idx, doc in enumerate(SEARCH_ENGINES["duckduckgo"].run(message.role_content, 3)):
# doc.update({"index": idx})
# search_docs.append(Doc(**doc))
# message.search_docs = search_docs
# return message
# def get_doc_retrieval(self, message: Message) -> Message:
# query = message.role_content
# knowledge_basename = message.doc_engine_name
# top_k = message.top_k
# score_threshold = message.score_threshold
# if knowledge_basename:
# docs = DocRetrieval.run(query, knowledge_basename, top_k, score_threshold)
# message.db_docs = [Doc(**doc) for doc in docs]
# return message
# def get_code_retrieval(self, message: Message) -> Message:
# # DocRetrieval.run("langchain是什么", "DSADSAD")
# query = message.input_query
# code_engine_name = message.code_engine_name
# history_node_list = message.history_node_list
# code_docs = CodeRetrieval.run(code_engine_name, query, code_limit=message.top_k, history_node_list=history_node_list)
# message.code_docs = [CodeDoc(**doc) for doc in code_docs]
# return message
# def get_tool_retrieval(self, message: Message) -> Message:
# return message
# def step_router(self, message: Message) -> tuple[Message, ...]:
# ''''''
# # message = self.parser(message)
# # logger.debug(f"message.action_status: {message.action_status}")
# observation_message = None
# if message.action_status == ActionStatus.CODING:
# message, observation_message = self.code_step(message)
# elif message.action_status == ActionStatus.TOOL_USING:
# message, observation_message = self.tool_step(message)
# return message, observation_message
# def code_step(self, message: Message) -> Message:
# '''execute code'''
# # logger.debug(f"message.role_content: {message.role_content}, message.code_content: {message.code_content}")
# code_answer = self.codebox.chat('```python\n{}```'.format(message.code_content))
# code_prompt = f"执行上述代码后存在报错信息为 {code_answer.code_exe_response},需要进行修复" \
# if code_answer.code_exe_type == "error" else f"执行上述代码后返回信息为 {code_answer.code_exe_response}"
# observation_message = Message(
# role_name="observation",
# role_type="func", #self.role.role_type,
# role_content="",
# step_content="",
# input_query=message.code_content,
# )
# uid = str(uuid.uuid1())
# if code_answer.code_exe_type == "image/png":
# message.figures[uid] = code_answer.code_exe_response
# message.code_answer = f"\n**Observation:**: 执行上述代码后生成一张图片, 图片名为{uid}\n"
# message.observation = f"\n**Observation:**: 执行上述代码后生成一张图片, 图片名为{uid}\n"
# message.step_content += f"\n**Observation:**: 执行上述代码后生成一张图片, 图片名为{uid}\n"
# message.step_contents += [f"\n**Observation:**: 执行上述代码后生成一张图片, 图片名为{uid}\n"]
# # message.role_content += f"\n**Observation:**:执行上述代码后生成一张图片, 图片名为{uid}\n"
# observation_message.role_content = f"\n**Observation:**: 执行上述代码后生成一张图片, 图片名为{uid}\n"
# observation_message.parsed_output = {"Observation": f"执行上述代码后生成一张图片, 图片名为{uid}"}
# else:
# message.code_answer = code_answer.code_exe_response
# message.observation = code_answer.code_exe_response
# message.step_content += f"\n**Observation:**: {code_prompt}\n"
# message.step_contents += [f"\n**Observation:**: {code_prompt}\n"]
# # message.role_content += f"\n**Observation:**: {code_prompt}\n"
# observation_message.role_content = f"\n**Observation:**: {code_prompt}\n"
# observation_message.parsed_output = {"Observation": code_prompt}
# # logger.info(f"**Observation:** {message.action_status}, {message.observation}")
# return message, observation_message
# def tool_step(self, message: Message) -> Message:
# '''execute tool'''
# # logger.debug(f"{message}")
# observation_message = Message(
# role_name="observation",
# role_type="function", #self.role.role_type,
# role_content="\n**Observation:** there is no tool can execute\n" ,
# step_content="",
# input_query=str(message.tool_params),
# tools=message.tools,
# )
# # logger.debug(f"message: {message.action_status}, {message.tool_name}, {message.tool_params}")
# tool_names = [tool.name for tool in message.tools]
# if message.tool_name not in tool_names:
# message.tool_answer = "\n**Observation:** there is no tool can execute\n"
# message.observation = "\n**Observation:** there is no tool can execute\n"
# # message.role_content += f"\n**Observation:**: 不存在可以执行的tool\n"
# message.step_content += f"\n**Observation:** there is no tool can execute\n"
# message.step_contents += [f"\n**Observation:** there is no tool can execute\n"]
# observation_message.role_content = f"\n**Observation:** there is no tool can execute\n"
# observation_message.parsed_output = {"Observation": "there is no tool can execute\n"}
# for tool in message.tools:
# if tool.name == message.tool_name:
# tool_res = tool.func(**message.tool_params.get("tool_params", {}))
# logger.debug(f"tool_res {tool_res}")
# message.tool_answer = tool_res
# message.observation = tool_res
# # message.role_content += f"\n**Observation:**: {tool_res}\n"
# message.step_content += f"\n**Observation:** {tool_res}\n"
# message.step_contents += [f"\n**Observation:** {tool_res}\n"]
# observation_message.role_content = f"\n**Observation:** {tool_res}\n"
# observation_message.parsed_output = {"Observation": tool_res}
# break
# # logger.info(f"**Observation:** {message.action_status}, {message.observation}")
# return message, observation_message
# def parser(self, message: Message) -> Message:
# ''''''
# content = message.role_content
# parser_keys = ["action", "code_content", "code_filename", "tool_params", "plans"]
# try:
# s_json = self._parse_json(content)
# message.action_status = s_json.get("action")
# message.code_content = s_json.get("code_content")
# message.tool_params = s_json.get("tool_params")
# message.tool_name = s_json.get("tool_name")
# message.code_filename = s_json.get("code_filename")
# message.plans = s_json.get("plans")
# # for parser_key in parser_keys:
# # message.action_status = content.get(parser_key)
# except Exception as e:
# # logger.warning(f"{traceback.format_exc()}")
# def parse_text_to_dict(text):
# # Define a regular expression pattern to capture the key and value
# main_pattern = r"\*\*(.+?):\*\*\s*(.*?)\s*(?=\*\*|$)"
# list_pattern = r'```python\n(.*?)```'
# # Use re.findall to find all main matches in the text
# main_matches = re.findall(main_pattern, text, re.DOTALL)
# # Convert main matches to a dictionary
# parsed_dict = {key.strip(): value.strip() for key, value in main_matches}
# for k, v in parsed_dict.items():
# for pattern in [list_pattern]:
# if "PLAN" != k: continue
# match_value = re.search(pattern, v, re.DOTALL)
# if match_value:
# # Add the code block to the dictionary
# parsed_dict[k] = eval(match_value.group(1).strip())
# break
# return parsed_dict
# def extract_content_from_backticks(text):
# code_blocks = []
# lines = text.split('\n')
# is_code_block = False
# code_block = ''
# language = ''
# for line in lines:
# if line.startswith('```') and not is_code_block:
# is_code_block = True
# language = line[3:]
# code_block = ''
# elif line.startswith('```') and is_code_block:
# is_code_block = False
# code_blocks.append({language.strip(): code_block.strip()})
# elif is_code_block:
# code_block += line + '\n'
# return code_blocks
# def parse_dict_to_dict(parsed_dict):
# code_pattern = r'```python\n(.*?)```'
# tool_pattern = r'```tool_params\n(.*?)```'
# pattern_dict = {"code": code_pattern, "tool_params": tool_pattern}
# spec_parsed_dict = copy.deepcopy(parsed_dict)
# for key, pattern in pattern_dict.items():
# for k, text in parsed_dict.items():
# # Search for the code block
# if not isinstance(text, str): continue
# _match = re.search(pattern, text, re.DOTALL)
# if _match:
# # Add the code block to the dictionary
# try:
# spec_parsed_dict[key] = json.loads(_match.group(1).strip())
# except:
# spec_parsed_dict[key] = _match.group(1).strip()
# break
# return spec_parsed_dict
# def parse_dict_to_dict(parsed_dict):
# code_pattern = r'```python\n(.*?)```'
# tool_pattern = r'```json\n(.*?)```'
# pattern_dict = {"code": code_pattern, "json": tool_pattern}
# spec_parsed_dict = copy.deepcopy(parsed_dict)
# for key, pattern in pattern_dict.items():
# for k, text in parsed_dict.items():
# # Search for the code block
# if not isinstance(text, str): continue
# _match = re.search(pattern, text, re.DOTALL)
# if _match:
# # Add the code block to the dictionary
# logger.debug(f"dsadsa {text}")
# try:
# spec_parsed_dict[key] = json.loads(_match.group(1).strip())
# except:
# spec_parsed_dict[key] = _match.group(1).strip()
# break
# return spec_parsed_dict
# parsed_dict = parse_text_to_dict(content)
# spec_parsed_dict = parse_dict_to_dict(parsed_dict)
# action_value = parsed_dict.get('Action Status')
# if action_value:
# action_value = action_value.lower()
# logger.info(f'{self.role.role_name}: action_value: {action_value}')
# # action_value = self._match(r"'action':\s*'([^']*)'", content) if "'action'" in content else self._match(r'"action":\s*"([^"]*)"', content)
# code_content_value = spec_parsed_dict.get('code')
# # code_content_value = self._match(r"'code_content':\s*'([^']*)'", content) if "'code_content'" in content else self._match(r'"code_content":\s*"([^"]*)"', content)
# filename_value = self._match(r"'code_filename':\s*'([^']*)'", content) if "'code_filename'" in content else self._match(r'"code_filename":\s*"([^"]*)"', content)
# tool_params_value = spec_parsed_dict.get('tool_params')
# # tool_params_value = self._match(r"'tool_params':\s*(\{[^{}]*\})", content, do_json=True) if "'tool_params'" in content \
# # else self._match(r'"tool_params":\s*(\{[^{}]*\})', content, do_json=True)
# tool_name_value = self._match(r"'tool_name':\s*'([^']*)'", content) if "'tool_name'" in content else self._match(r'"tool_name":\s*"([^"]*)"', content)
# plans_value = self._match(r"'plans':\s*(\[.*?\])", content, do_search=False) if "'plans'" in content else self._match(r'"plans":\s*(\[.*?\])', content, do_search=False, )
# # re解析
# message.action_status = action_value or "default"
# message.code_content = code_content_value
# message.code_filename = filename_value
# message.tool_params = tool_params_value
# message.tool_name = tool_name_value
# message.plans = plans_value
# message.parsed_output = parsed_dict
# message.spec_parsed_output = spec_parsed_dict
# code_content_value = spec_parsed_dict.get('code')
# # code_content_value = self._match(r"'code_content':\s*'([^']*)'", content) if "'code_content'" in content else self._match(r'"code_content":\s*"([^"]*)"', content)
# filename_value = self._match(r"'code_filename':\s*'([^']*)'", content) if "'code_filename'" in content else self._match(r'"code_filename":\s*"([^"]*)"', content)
# logger.debug(spec_parsed_dict)
# if action_value == 'tool_using':
# tool_params_value = spec_parsed_dict.get('json')
# else:
# tool_params_value = None
# # tool_params_value = self._match(r"'tool_params':\s*(\{[^{}]*\})", content, do_json=True) if "'tool_params'" in content \
# # else self._match(r'"tool_params":\s*(\{[^{}]*\})', content, do_json=True)
# tool_name_value = self._match(r"'tool_name':\s*'([^']*)'", content) if "'tool_name'" in content else self._match(r'"tool_name":\s*"([^"]*)"', content)
# plans_value = self._match(r"'plans':\s*(\[.*?\])", content, do_search=False) if "'plans'" in content else self._match(r'"plans":\s*(\[.*?\])', content, do_search=False, )
# # re解析
# message.action_status = action_value or "default"
# message.code_content = code_content_value
# message.code_filename = filename_value
# message.tool_params = tool_params_value
# message.tool_name = tool_name_value
# message.plans = plans_value
# message.parsed_output = parsed_dict
# message.spec_parsed_output = spec_parsed_dict
# # logger.debug(f"确认当前的action: {message.action_status}")
# return message
# def parser_spec_key(self, content, key, do_search=True, do_json=False) -> str:
# ''''''
# key2pattern = {
# "'action'": r"'action':\s*'([^']*)'", '"action"': r'"action":\s*"([^"]*)"',
# "'code_content'": r"'code_content':\s*'([^']*)'", '"code_content"': r'"code_content":\s*"([^"]*)"',
# "'code_filename'": r"'code_filename':\s*'([^']*)'", '"code_filename"': r'"code_filename":\s*"([^"]*)"',
# "'tool_params'": r"'tool_params':\s*(\{[^{}]*\})", '"tool_params"': r'"tool_params":\s*(\{[^{}]*\})',
# "'tool_name'": r"'tool_name':\s*'([^']*)'", '"tool_name"': r'"tool_name":\s*"([^"]*)"',
# "'plans'": r"'plans':\s*(\[.*?\])", '"plans"': r'"plans":\s*(\[.*?\])',
# "'content'": r"'content':\s*'([^']*)'", '"content"': r'"content":\s*"([^"]*)"',
# }
# s_json = self._parse_json(content)
# try:
# if s_json and key in s_json:
# return str(s_json[key])
# except:
# pass
# keystr = f"'{key}'" if f"'{key}'" in content else f'"{key}"'
# return self._match(key2pattern.get(keystr, fr"'{key}':\s*'([^']*)'"), content, do_search=do_search, do_json=do_json)
# def _match(self, pattern, s, do_search=True, do_json=False):
# try:
# if do_search:
# match = re.search(pattern, s)
# if match:
# value = match.group(1).replace("\\n", "\n")
# if do_json:
# value = json.loads(value)
# else:
# value = None
# else:
# match = re.findall(pattern, s, re.DOTALL)
# if match:
# value = match[0]
# if do_json:
# value = json.loads(value)
# else:
# value = None
# except Exception as e:
# logger.warning(f"{traceback.format_exc()}")
# # logger.debug(f"pattern: {pattern}, s: {s}, match: {match}")
# return value
# def _parse_json(self, s):
# try:
# pattern = r"```([^`]+)```"
# match = re.findall(pattern, s)
# if match:
# return eval(match[0])
# except:
# pass
# return None
def get_memory(self, content_key="role_content"):
return self.memory.to_tuple_messages(content_key="step_content")

View File

@ -50,7 +50,7 @@ class CheckAgent(BaseAgent):
doc_infos = self.create_doc_prompt(query)
code_infos = self.create_codedoc_prompt(query)
#
formatted_tools, tool_names = self.create_tools_prompt(query)
formatted_tools, tool_names, _ = self.create_tools_prompt(query)
task_prompt = self.create_task_prompt(query)
background_prompt = self.create_background_prompt(background)
history_prompt = self.create_history_prompt(history)

View File

@ -55,7 +55,8 @@ class ExecutorAgent(BaseAgent):
step_content="",
input_query=query.input_query,
tools=query.tools,
parsed_output_list=[query.parsed_output]
parsed_output_list=[query.parsed_output],
customed_kargs=query.customed_kargs
)
self_memory = self.memory if self.do_use_self_memory else None
@ -64,6 +65,7 @@ class ExecutorAgent(BaseAgent):
# when a plan field exists and the plan field is a str
if "PLAN" not in query.parsed_output or isinstance(query.parsed_output.get("PLAN", []), str) or plan_step >= len(query.parsed_output.get("PLAN", [])):
query_c = copy.deepcopy(query)
query_c = self.start_action_step(query_c)
query_c.parsed_output = {"Question": query_c.input_query}
task_executor_memory.append(query_c)
for output_message, task_executor_memory in self._arun_step(output_message, query_c, self_memory, history, background, memory_pool, task_executor_memory):
@ -87,6 +89,7 @@ class ExecutorAgent(BaseAgent):
yield output_message
else:
query_c = copy.deepcopy(query)
query_c = self.start_action_step(query_c)
task_content = query_c.parsed_output["PLAN"][plan_step]
query_c.parsed_output = {"Question": task_content}
task_executor_memory.append(query_c)
@ -99,6 +102,8 @@ class ExecutorAgent(BaseAgent):
# logger.info(f"{self.role.role_name} currenct question: {output_message.input_query}\nllm_executor_run: {output_message.step_content}")
# logger.info(f"{self.role.role_name} currenct parserd_output_list: {output_message.parserd_output_list}")
output_message.input_query = output_message.role_content
# end_action_step
output_message = self.end_action_step(output_message)
# update memory pool
memory_pool.append(output_message)
yield output_message
@ -113,9 +118,7 @@ class ExecutorAgent(BaseAgent):
logger.debug(f"{self.role.role_name} content: {content}")
output_message.role_content = content
output_message.role_contents += [content]
output_message.step_content += "\n"+output_message.role_content
output_message.step_contents + [output_message.role_content]
output_message = self.message_utils.parser(output_message)
# according the output to choose one action for code_content or tool_content
@ -141,7 +144,7 @@ class ExecutorAgent(BaseAgent):
doc_infos = self.create_doc_prompt(query)
code_infos = self.create_codedoc_prompt(query)
#
formatted_tools, tool_names = self.create_tools_prompt(query)
formatted_tools, tool_names, _ = self.create_tools_prompt(query)
task_prompt = self.create_task_prompt(query)
background_prompt = self.create_background_prompt(background, control_key="step_content")
history_prompt = self.create_history_prompt(history)

View File

@ -59,10 +59,15 @@ class ReactAgent(BaseAgent):
step_content="",
input_query=query.input_query,
tools=query.tools,
parsed_output_list=[query.parsed_output]
parsed_output_list=[query.parsed_output],
customed_kargs=query.customed_kargs
)
query_c = copy.deepcopy(query)
query_c.parsed_output = {"Question": "\n".join([f"{v}" for k, v in query.parsed_output.items() if k not in ["Action Status"]])}
query_c = self.start_action_step(query_c)
if query.parsed_output:
query_c.parsed_output = {"Question": "\n".join([f"{v}" for k, v in query.parsed_output.items() if k not in ["Action Status"]])}
else:
query_c.parsed_output = {"Question": query.input_query}
react_memory.append(query_c)
self_memory = self.memory if self.do_use_self_memory else None
idx = 0
@ -77,9 +82,7 @@ class ReactAgent(BaseAgent):
raise Exception(traceback.format_exc())
output_message.role_content = "\n"+content
output_message.role_contents += [content]
output_message.step_content += "\n"+output_message.role_content
output_message.step_contents + [output_message.role_content]
yield output_message
# logger.debug(f"{self.role.role_name}, {idx} iteration prompt: {prompt}")
@ -87,7 +90,7 @@ class ReactAgent(BaseAgent):
output_message = self.message_utils.parser(output_message)
# when get finished signal can stop early
if output_message.action_status == ActionStatus.FINISHED: break
if output_message.action_status == ActionStatus.FINISHED or output_message.action_status == ActionStatus.STOPED: break
# according the output to choose one action for code_content or tool_content
output_message, observation_message = self.message_utils.step_router(output_message)
output_message.parsed_output_list.append(output_message.parsed_output)
@ -108,6 +111,8 @@ class ReactAgent(BaseAgent):
# update memory pool
# memory_pool.append(output_message)
output_message.input_query = query.input_query
# end_action_step
output_message = self.end_action_step(output_message)
# update memory pool
memory_pool.append(output_message)
yield output_message
@ -122,7 +127,7 @@ class ReactAgent(BaseAgent):
doc_infos = self.create_doc_prompt(query)
code_infos = self.create_codedoc_prompt(query)
#
formatted_tools, tool_names = self.create_tools_prompt(query)
formatted_tools, tool_names, _ = self.create_tools_prompt(query)
task_prompt = self.create_task_prompt(query)
background_prompt = self.create_background_prompt(background)
history_prompt = self.create_history_prompt(history)
@ -136,7 +141,7 @@ class ReactAgent(BaseAgent):
# # input_query = query.input_query + "\n" + "\n".join([f"{v}" for k, v in input_query if v])
# input_query = "\n".join([f"{v}" for k, v in input_query if v])
input_query = "\n".join(["\n".join([f"**{k}:**\n{v}" for k,v in _dict.items()]) for _dict in react_memory.get_parserd_output()])
logger.debug(f"input_query: {input_query}")
# logger.debug(f"input_query: {input_query}")
prompt += "\n" + REACT_PROMPT_INPUT.format(**{"query": input_query})

View File

@ -4,6 +4,7 @@ import re
import json
import traceback
import copy
import random
from loguru import logger
from langchain.prompts.chat import ChatPromptTemplate
@ -12,7 +13,8 @@ from dev_opsgpt.connector.schema import (
Memory, Task, Env, Role, Message, ActionStatus
)
from dev_opsgpt.llm_models import getChatModel
from dev_opsgpt.connector.configs.agent_config import REACT_PROMPT_INPUT, CONTEXT_PROMPT_INPUT, QUERY_CONTEXT_PROMPT_INPUT
from dev_opsgpt.connector.configs.prompts import BASE_PROMPT_INPUT, QUERY_CONTEXT_DOC_PROMPT_INPUT, BEGIN_PROMPT_INPUT
from dev_opsgpt.connector.utils import parse_section
from .base_agent import BaseAgent
@ -33,6 +35,7 @@ class SelectorAgent(BaseAgent):
do_use_self_memory: bool = True,
focus_agents: List[str] = [],
focus_message_keys: List[str] = [],
group_agents: List[BaseAgent] = [],
# prompt_mamnger: PromptManager
):
@ -40,6 +43,53 @@ class SelectorAgent(BaseAgent):
do_tool_retrieval, temperature, stop, do_filter,do_use_self_memory,
focus_agents, focus_message_keys
)
self.group_agents = group_agents
def arun(self, query: Message, history: Memory = None, background: Memory = None, memory_pool: Memory=None) -> Message:
'''agent response from multi-message'''
# insert query into memory
query_c = copy.deepcopy(query)
query = self.start_action_step(query)
self_memory = self.memory if self.do_use_self_memory else None
# create your llm prompt
prompt = self.create_prompt(query_c, self_memory, history, background, memory_pool=memory_pool)
content = self.llm.predict(prompt)
logger.debug(f"{self.role.role_name} prompt: {prompt}")
logger.debug(f"{self.role.role_name} content: {content}")
# select agent
select_message = Message(
role_name=self.role.role_name,
role_type="ai", #self.role.role_type,
role_content=content,
step_content=content,
input_query=query_c.input_query,
tools=query_c.tools,
parsed_output_list=[query.parsed_output]
)
# common parse llm' content to message
select_message = self.message_utils.parser(select_message)
if self.do_filter:
select_message = self.message_utils.filter(select_message)
output_message = None
if select_message.parsed_output.get("Role", "") in [agent.role.role_name for agent in self.group_agents]:
for agent in self.group_agents:
if agent.role.role_name == select_message.parsed_output.get("Role", ""):
break
for output_message in agent.arun(query, history, background=background, memory_pool=memory_pool):
pass
# update self_memory
self.append_history(query_c)
self.append_history(output_message)
logger.info(f"{agent.role.role_name} currenct question: {output_message.input_query}\nllm_step_run: {output_message.role_content}")
output_message.input_query = output_message.role_content
output_message.parsed_output_list.append(output_message.parsed_output)
#
output_message = self.end_action_step(output_message)
# update memory pool
memory_pool.append(output_message)
yield output_message or select_message
def create_prompt(
self, query: Message, memory: Memory =None, history: Memory = None, background: Memory = None, memory_pool: Memory=None, prompt_mamnger=None) -> str:
@ -50,56 +100,49 @@ class SelectorAgent(BaseAgent):
doc_infos = self.create_doc_prompt(query)
code_infos = self.create_codedoc_prompt(query)
#
formatted_tools, tool_names = self.create_tools_prompt(query)
formatted_tools, tool_names, tools_descs = self.create_tools_prompt(query)
agent_names, agents = self.create_agent_names()
task_prompt = self.create_task_prompt(query)
background_prompt = self.create_background_prompt(background)
history_prompt = self.create_history_prompt(history)
selfmemory_prompt = self.create_selfmemory_prompt(memory, control_key="step_content")
# the react flow is a self-iterating process; messages from secondary triggers need to be passed in as historical conversation information
# input_query = react_memory.to_tuple_messages(content_key="step_content")
DocInfos = ""
if doc_infos is not None and doc_infos!="" and doc_infos!="不存在知识库辅助信息":
DocInfos += f"\nDocument Information: {doc_infos}"
if code_infos is not None and code_infos!="" and code_infos!="不存在代码库辅助信息":
DocInfos += f"\nCodeBase Infomation: {code_infos}"
input_query = query.input_query
# logger.debug(f"{self.role.role_name} extra_system_prompt: {self.role.role_prompt}")
logger.debug(f"{self.role.role_name} input_query: {input_query}")
# logger.debug(f"{self.role.role_name} doc_infos: {doc_infos}")
# logger.debug(f"{self.role.role_name} tool_names: {tool_names}")
# prompt += "\n" + CHECK_PROMPT_INPUT.format(**{"query": input_query})
# prompt.format(**{"query": input_query})
# extra_system_prompt = self.role.role_prompt
prompt = self.role.role_prompt.format(**{"query": input_query, "formatted_tools": formatted_tools, "tool_names": tool_names})
prompt = self.role.role_prompt.format(**{"agent_names": agent_names, "agents": agents, "formatted_tools": tools_descs, "tool_names": tool_names})
#
memory_pool_select_by_agent_key = self.select_memory_by_agent_key(memory_pool)
memory_pool_select_by_agent_key_context = '\n\n'.join([f"*{k}*\n{v}" for parsed_output in memory_pool_select_by_agent_key.get_parserd_output_list() for k, v in parsed_output.items() if k not in ['Action Status']])
if "**Context:**" in self.role.role_prompt:
# logger.debug(f"parsed_output_list: {query.parsed_output_list}")
# input_query = "'''" + "\n".join([f"*{k}*\n{v}" for i in background.get_parserd_output_list() for k,v in i.items() if "Action Status" !=k]) + "'''"
context = "\n".join([f"*{k}*\n{v}" for i in background.get_parserd_output_list() for k,v in i.items() if "Action Status" !=k])
# logger.debug(f"parsed_output_list: {t}")
prompt += "\n" + QUERY_CONTEXT_PROMPT_INPUT.format(**{"query": query.origin_query, "context": context})
else:
prompt += "\n" + REACT_PROMPT_INPUT.format(**{"query": input_query})
input_keys = parse_section(self.role.role_prompt, 'Input Format')
#
prompt += "\n" + BEGIN_PROMPT_INPUT
for input_key in input_keys:
if input_key == "Origin Query":
prompt += "\n**Origin Query:**\n" + query.origin_query
elif input_key == "Context":
context = "\n".join([f"*{k}*\n{v}" for i in query.parsed_output_list for k,v in i.items() if "Action Status" !=k])
if history:
context = history_prompt + "\n" + context
if not context:
context = "there is no context"
if self.focus_agents and memory_pool_select_by_agent_key_context:
context = memory_pool_select_by_agent_key_context
prompt += "\n**Context:**\n" + context + "\n" + input_query
elif input_key == "DocInfos":
prompt += "\n**DocInfos:**\n" + DocInfos
elif input_key == "Question":
prompt += "\n**Question:**\n" + input_query
task = query.task or self.task
if task_prompt is not None:
prompt += "\n" + task.task_prompt
# if doc_infos is not None and doc_infos!="" and doc_infos!="不存在知识库辅助信息":
# prompt += f"\n知识库信息: {doc_infos}"
# if code_infos is not None and code_infos!="" and code_infos!="不存在代码库辅助信息":
# prompt += f"\n代码库信息: {code_infos}"
# if background_prompt:
# prompt += "\n" + background_prompt
# if history_prompt:
# prompt += "\n" + history_prompt
# if selfmemory_prompt:
# prompt += "\n" + selfmemory_prompt
# prompt = extra_system_prompt.format(**{"query": input_query, "doc_infos": doc_infos, "formatted_tools": formatted_tools, "tool_names": tool_names})
while "{{" in prompt or "}}" in prompt:
prompt = prompt.replace("{{", "{")
prompt = prompt.replace("}}", "}")
@ -107,3 +150,16 @@ class SelectorAgent(BaseAgent):
# logger.debug(f"{self.role.role_name} prompt: {prompt}")
return prompt
def create_agent_names(self):
random.shuffle(self.group_agents)
agent_names = ", ".join([f'{agent.role.role_name}' for agent in self.group_agents])
agent_descs = []
for agent in self.group_agents:
role_desc = agent.role.role_prompt.split("####")[1]
while "\n\n" in role_desc:
role_desc = role_desc.replace("\n\n", "\n")
role_desc = role_desc.replace("\n", ",")
agent_descs.append(f'"role name: {agent.role.role_name}\nrole description: {role_desc}"')
return agent_names, "\n".join(agent_descs)

View File

@ -15,10 +15,6 @@ from dev_opsgpt.connector.schema import (
load_role_configs
)
from dev_opsgpt.connector.message_process import MessageUtils
from dev_opsgpt.sandbox import PyCodeBox, CodeBoxResponse
from configs.server_config import SANDBOX_SERVER
from dev_opsgpt.connector.configs.agent_config import AGETN_CONFIGS
role_configs = load_role_configs(AGETN_CONFIGS)
@ -31,7 +27,6 @@ class BaseChain:
agents: List[BaseAgent],
chat_turn: int = 1,
do_checker: bool = False,
do_code_exec: bool = False,
# prompt_mamnger: PromptManager
) -> None:
self.chainConfig = chainConfig
@ -45,29 +40,9 @@ class BaseChain:
do_doc_retrieval = role_configs["checker"].do_doc_retrieval,
do_tool_retrieval = role_configs["checker"].do_tool_retrieval,
do_filter=False, do_use_self_memory=False)
self.do_agent_selector = False
self.agent_selector = CheckAgent(role=role_configs["checker"].role,
task = None,
memory = None,
do_search = role_configs["checker"].do_search,
do_doc_retrieval = role_configs["checker"].do_doc_retrieval,
do_tool_retrieval = role_configs["checker"].do_tool_retrieval,
do_filter=False, do_use_self_memory=False)
self.messageUtils = MessageUtils()
# all memory created by agent until instance deleted
self.global_memory = Memory(messages=[])
# self.do_code_exec = do_code_exec
# self.codebox = PyCodeBox(
# remote_url=SANDBOX_SERVER["url"],
# remote_ip=SANDBOX_SERVER["host"],
# remote_port=SANDBOX_SERVER["port"],
# token="mytoken",
# do_code_exe=True,
# do_remote=SANDBOX_SERVER["do_remote"],
# do_check_net=False
# )
def step(self, query: Message, history: Memory = None, background: Memory = None, memory_pool: Memory = None) -> Message:
'''execute chain'''
@ -85,263 +60,48 @@ class BaseChain:
self.global_memory.append(input_message)
# local_memory.append(input_message)
while step_nums > 0:
if self.do_agent_selector:
agent_message = copy.deepcopy(query)
agent_message.agents = self.agents
for selectory_message in self.agent_selector.arun(query, background=self.global_memory, memory_pool=memory_pool):
pass
selectory_message = self.messageUtils.parser(selectory_message)
selectory_message = self.messageUtils.filter(selectory_message)
agent = self.agents[selectory_message.agent_index]
# selector agent to execure next task
for agent in self.agents:
for output_message in agent.arun(input_message, history, background=background, memory_pool=memory_pool):
# logger.debug(f"local_memory {local_memory + output_message}")
yield output_message, local_memory + output_message
output_message = self.messageUtils.inherit_extrainfo(input_message, output_message)
# according the output to choose one action for code_content or tool_content
# logger.info(f"{agent.role.role_name}\nmessage: {output_message.step_content}\nquery: {output_message.input_query}")
# logger.info(f"{agent.role.role_name} currenct message: {output_message.step_content}\n next llm question: {output_message.input_query}")
output_message = self.messageUtils.parser(output_message)
yield output_message, local_memory + output_message
# output_message = self.step_router(output_message)
input_message = output_message
self.global_memory.append(output_message)
local_memory.append(output_message)
# when get finished signal can stop early
if output_message.action_status == ActionStatus.FINISHED:
if output_message.action_status == ActionStatus.FINISHED or output_message.action_status == ActionStatus.STOPED:
action_status = False
break
else:
for agent in self.agents:
for output_message in agent.arun(input_message, history, background=background, memory_pool=memory_pool):
# logger.debug(f"local_memory {local_memory + output_message}")
yield output_message, local_memory + output_message
output_message = self.messageUtils.inherit_extrainfo(input_message, output_message)
# according the output to choose one action for code_content or tool_content
# logger.info(f"{agent.role.role_name} currenct message: {output_message.step_content}\n next llm question: {output_message.input_query}")
output_message = self.messageUtils.parser(output_message)
yield output_message, local_memory + output_message
# output_message = self.step_router(output_message)
if output_message.action_status == ActionStatus.FINISHED:
break
input_message = output_message
self.global_memory.append(output_message)
if self.do_checker and self.chat_turn > 1:
# logger.debug(f"{self.checker.role.role_name} input global memory: {self.global_memory.to_str_messages(content_key='step_content', return_all=False)}")
for check_message in self.checker.arun(query, background=local_memory, memory_pool=memory_pool):
pass
check_message = self.messageUtils.parser(check_message)
check_message = self.messageUtils.filter(check_message)
check_message = self.messageUtils.inherit_extrainfo(output_message, check_message)
logger.debug(f"{self.checker.role.role_name}: {check_message.role_content}")
local_memory.append(output_message)
# when get finished signal can stop early
if output_message.action_status == ActionStatus.FINISHED:
action_status = False
break
if self.do_checker and self.chat_turn > 1:
# logger.debug(f"{self.checker.role.role_name} input global memory: {self.global_memory.to_str_messages(content_key='step_content', return_all=False)}")
for check_message in self.checker.arun(query, background=local_memory, memory_pool=memory_pool):
pass
check_message = self.messageUtils.parser(check_message)
check_message = self.messageUtils.filter(check_message)
check_message = self.messageUtils.inherit_extrainfo(output_message, check_message)
logger.debug(f"{self.checker.role.role_name}: {check_message.role_content}")
if check_message.action_status == ActionStatus.FINISHED:
self.global_memory.append(check_message)
break
if check_message.action_status == ActionStatus.FINISHED:
self.global_memory.append(check_message)
break
step_nums -= 1
#
output_message = check_message or output_message # return the chain's and the checker's results
output_message.input_query = query.input_query # message passing between chains does not change the question
yield output_message, local_memory
# def step_router(self, message: Message) -> Message:
# ''''''
# # message = self.parser(message)
# # logger.debug(f"message.action_status: {message.action_status}")
# if message.action_status == ActionStatus.CODING:
# message = self.code_step(message)
# elif message.action_status == ActionStatus.TOOL_USING:
# message = self.tool_step(message)
# return message
# def code_step(self, message: Message) -> Message:
# '''execute code'''
# # logger.debug(f"message.role_content: {message.role_content}, message.code_content: {message.code_content}")
# code_answer = self.codebox.chat('```python\n{}```'.format(message.code_content))
# uid = str(uuid.uuid1())
# if code_answer.code_exe_type == "image/png":
# message.figures[uid] = code_answer.code_exe_response
# message.code_answer = f"\n观察: 执行代码后获得输出一张图片, 文件名为{uid}\n"
# message.observation = f"\n观察: 执行代码后获得输出一张图片, 文件名为{uid}\n"
# message.step_content += f"\n观察: 执行代码后获得输出一张图片, 文件名为{uid}\n"
# message.step_contents += [f"\n观察: 执行代码后获得输出一张图片, 文件名为{uid}\n"]
# message.role_content += f"\n执行代码后获得输出一张图片, 文件名为{uid}\n"
# else:
# message.code_answer = code_answer.code_exe_response
# message.observation = code_answer.code_exe_response
# message.step_content += f"\n观察: 执行代码后获得输出是 {code_answer.code_exe_response}\n"
# message.step_contents += [f"\n观察: 执行代码后获得输出是 {code_answer.code_exe_response}\n"]
# message.role_content += f"\n观察: 执行代码后获得输出是 {code_answer.code_exe_response}\n"
# logger.info(f"观察: {message.action_status}, {message.observation}")
# return message
# def tool_step(self, message: Message) -> Message:
# '''execute tool'''
# # logger.debug(f"message: {message.action_status}, {message.tool_name}, {message.tool_params}")
# tool_names = [tool.name for tool in message.tools]
# if message.tool_name not in tool_names:
# message.tool_answer = "不存在可以执行的tool"
# message.observation = "不存在可以执行的tool"
# message.role_content += f"\n观察: 不存在可以执行的tool\n"
# message.step_content += f"\n观察: 不存在可以执行的tool\n"
# message.step_contents += [f"\n观察: 不存在可以执行的tool\n"]
# for tool in message.tools:
# if tool.name == message.tool_name:
# tool_res = tool.func(**message.tool_params)
# message.tool_answer = tool_res
# message.observation = tool_res
# message.role_content += f"\n观察: {tool_res}\n"
# message.step_content += f"\n观察: {tool_res}\n"
# message.step_contents += [f"\n观察: {tool_res}\n"]
# return message
# def filter(self, message: Message, stop=None) -> Message:
# tool_params = self.parser_spec_key(message.role_content, "tool_params")
# code_content = self.parser_spec_key(message.role_content, "code_content")
# plan = self.parser_spec_key(message.role_content, "plan")
# plans = self.parser_spec_key(message.role_content, "plans", do_search=False)
# content = self.parser_spec_key(message.role_content, "content", do_search=False)
# # logger.debug(f"tool_params: {tool_params}, code_content: {code_content}, plan: {plan}, plans: {plans}, content: {content}")
# role_content = tool_params or code_content or plan or plans or content
# message.role_content = role_content or message.role_content
# return message
# def parser(self, message: Message) -> Message:
# ''''''
# content = message.role_content
# parser_keys = ["action", "code_content", "code_filename", "tool_params", "plans"]
# try:
# s_json = self._parse_json(content)
# message.action_status = s_json.get("action")
# message.code_content = s_json.get("code_content")
# message.tool_params = s_json.get("tool_params")
# message.tool_name = s_json.get("tool_name")
# message.code_filename = s_json.get("code_filename")
# message.plans = s_json.get("plans")
# # for parser_key in parser_keys:
# # message.action_status = content.get(parser_key)
# except Exception as e:
# # logger.warning(f"{traceback.format_exc()}")
# def parse_text_to_dict(text):
# # Define a regular expression pattern to capture the key and value
# main_pattern = r"\*\*(.+?):\*\*\s*(.*?)\s*(?=\*\*|$)"
# code_pattern = r'```python\n(.*?)```'
# # Use re.findall to find all main matches in the text
# main_matches = re.findall(main_pattern, text, re.DOTALL)
# # Convert main matches to a dictionary
# parsed_dict = {key.strip(): value.strip() for key, value in main_matches}
# # Search for the code block
# code_match = re.search(code_pattern, text, re.DOTALL)
# if code_match:
# # Add the code block to the dictionary
# parsed_dict['code'] = code_match.group(1).strip()
# return parsed_dict
# parsed_dict = parse_text_to_dict(content)
# action_value = parsed_dict.get('Action Status')
# if action_value:
# action_value = action_value.lower()
# logger.debug(f'action_value: {action_value}')
# # action_value = self._match(r"'action':\s*'([^']*)'", content) if "'action'" in content else self._match(r'"action":\s*"([^"]*)"', content)
# code_content_value = parsed_dict.get('code')
# # code_content_value = self._match(r"'code_content':\s*'([^']*)'", content) if "'code_content'" in content else self._match(r'"code_content":\s*"([^"]*)"', content)
# filename_value = self._match(r"'code_filename':\s*'([^']*)'", content) if "'code_filename'" in content else self._match(r'"code_filename":\s*"([^"]*)"', content)
# tool_params_value = self._match(r"'tool_params':\s*(\{[^{}]*\})", content, do_json=True) if "'tool_params'" in content \
# else self._match(r'"tool_params":\s*(\{[^{}]*\})', content, do_json=True)
# tool_name_value = self._match(r"'tool_name':\s*'([^']*)'", content) if "'tool_name'" in content else self._match(r'"tool_name":\s*"([^"]*)"', content)
# plans_value = self._match(r"'plans':\s*(\[.*?\])", content, do_search=False) if "'plans'" in content else self._match(r'"plans":\s*(\[.*?\])', content, do_search=False, )
# # re解析
# message.action_status = action_value or "default"
# message.code_content = code_content_value
# message.code_filename = filename_value
# message.tool_params = tool_params_value
# message.tool_name = tool_name_value
# message.plans = plans_value
# # logger.debug(f"确认当前的action: {message.action_status}")
# return message
# def parser_spec_key(self, content, key, do_search=True, do_json=False) -> str:
# ''''''
# key2pattern = {
# "'action'": r"'action':\s*'([^']*)'", '"action"': r'"action":\s*"([^"]*)"',
# "'code_content'": r"'code_content':\s*'([^']*)'", '"code_content"': r'"code_content":\s*"([^"]*)"',
# "'code_filename'": r"'code_filename':\s*'([^']*)'", '"code_filename"': r'"code_filename":\s*"([^"]*)"',
# "'tool_params'": r"'tool_params':\s*(\{[^{}]*\})", '"tool_params"': r'"tool_params":\s*(\{[^{}]*\})',
# "'tool_name'": r"'tool_name':\s*'([^']*)'", '"tool_name"': r'"tool_name":\s*"([^"]*)"',
# "'plans'": r"'plans':\s*(\[.*?\])", '"plans"': r'"plans":\s*(\[.*?\])',
# "'content'": r"'content':\s*'([^']*)'", '"content"': r'"content":\s*"([^"]*)"',
# }
# s_json = self._parse_json(content)
# try:
# if s_json and key in s_json:
# return str(s_json[key])
# except:
# pass
# keystr = f"'{key}'" if f"'{key}'" in content else f'"{key}"'
# return self._match(key2pattern.get(keystr, fr"'{key}':\s*'([^']*)'"), content, do_search=do_search, do_json=do_json)
# def _match(self, pattern, s, do_search=True, do_json=False):
# try:
# if do_search:
# match = re.search(pattern, s)
# if match:
# value = match.group(1).replace("\\n", "\n")
# if do_json:
# value = json.loads(value)
# else:
# value = None
# else:
# match = re.findall(pattern, s, re.DOTALL)
# if match:
# value = match[0]
# if do_json:
# value = json.loads(value)
# else:
# value = None
# except Exception as e:
# logger.warning(f"{traceback.format_exc()}")
# # logger.debug(f"pattern: {pattern}, s: {s}, match: {match}")
# return value
# def _parse_json(self, s):
# try:
# pattern = r"```([^`]+)```"
# match = re.findall(pattern, s)
# if match:
# return eval(match[0])
# except:
# pass
# return None
# def inherit_extrainfo(self, input_message: Message, output_message: Message):
# output_message.db_docs = input_message.db_docs
# output_message.search_docs = input_message.search_docs
# output_message.code_docs = input_message.code_docs
# output_message.origin_query = input_message.origin_query
# output_message.figures.update(input_message.figures)
# return output_message
def get_memory(self, content_key="role_content") -> Memory:
memory = self.global_memory

View File

@ -7,6 +7,7 @@ from .prompts import (
QA_PROMPT, CODE_QA_PROMPT, QA_TEMPLATE_PROMPT,
EXECUTOR_TEMPLATE_PROMPT,
REFINE_TEMPLATE_PROMPT,
SELECTOR_AGENT_TEMPLATE_PROMPT,
PLANNER_TEMPLATE_PROMPT, GENERAL_PLANNER_PROMPT, DATA_PLANNER_PROMPT, TOOL_PLANNER_PROMPT,
PRD_WRITER_METAGPT_PROMPT, DESIGN_WRITER_METAGPT_PROMPT, TASK_WRITER_METAGPT_PROMPT, CODE_WRITER_METAGPT_PROMPT,
REACT_TEMPLATE_PROMPT,
@ -20,10 +21,25 @@ class AgentType:
EXECUTOR = "ExecutorAgent"
ONE_STEP = "BaseAgent"
DEFAULT = "BaseAgent"
SELECTOR = "SelectorAgent"
AGETN_CONFIGS = {
"baseGroup": {
"role": {
"role_prompt": SELECTOR_AGENT_TEMPLATE_PROMPT,
"role_type": "assistant",
"role_name": "baseGroup",
"role_desc": "",
"agent_type": "SelectorAgent"
},
"group_agents": ["tool_react", "code_react"],
"chat_turn": 1,
"do_search": False,
"do_doc_retrieval": False,
"do_tool_retrieval": False
},
"checker": {
"role": {
"role_prompt": CHECKER_TEMPLATE_PROMPT,

View File

@ -108,4 +108,20 @@ CHAIN_CONFIGS = {
"do_checker": False,
"chain_prompt": ""
},
"baseGroupChain": {
"chain_name": "baseGroupChain",
"chain_type": "BaseChain",
"agents": ["baseGroup"],
"chat_turn": 1,
"do_checker": False,
"chain_prompt": ""
},
"codeChatXXChain": {
"chain_name": "codeChatXXChain",
"chain_type": "BaseChain",
"agents": ["codeChat1", "codeChat2"],
"chat_turn": 1,
"do_checker": False,
"chain_prompt": ""
}
}

View File

@ -88,15 +88,26 @@ PHASE_CONFIGS = {
"do_tool_retrieval": False,
"do_using_tool": False
},
"metagpt_code_devlop": {
"phase_name": "metagpt_code_devlop",
"phase_type": "BasePhase",
"chains": ["metagptChain",],
"do_summary": False,
"do_search": False,
"do_doc_retrieval": False,
"do_code_retrieval": False,
"do_tool_retrieval": False,
"do_using_tool": False
},
# "metagpt_code_devlop": {
# "phase_name": "metagpt_code_devlop",
# "phase_type": "BasePhase",
# "chains": ["metagptChain",],
# "do_summary": False,
# "do_search": False,
# "do_doc_retrieval": False,
# "do_code_retrieval": False,
# "do_tool_retrieval": False,
# "do_using_tool": False
# },
# "baseGroupPhase": {
# "phase_name": "baseGroupPhase",
# "phase_type": "BasePhase",
# "chains": ["baseGroupChain"],
# "do_summary": False,
# "do_search": False,
# "do_doc_retrieval": False,
# "do_code_retrieval": False,
# "do_tool_retrieval": False,
# "do_using_tool": False
# },
}

View File

@ -15,6 +15,8 @@ from .qa_template_prompt import QA_PROMPT, CODE_QA_PROMPT, QA_TEMPLATE_PROMPT
from .executor_template_prompt import EXECUTOR_TEMPLATE_PROMPT
from .refine_template_prompt import REFINE_TEMPLATE_PROMPT
from .agent_selector_template_prompt import SELECTOR_AGENT_TEMPLATE_PROMPT
from .react_template_prompt import REACT_TEMPLATE_PROMPT
from .react_code_prompt import REACT_CODE_PROMPT
from .react_tool_prompt import REACT_TOOL_PROMPT
@ -32,6 +34,7 @@ __all__ = [
"QA_PROMPT", "CODE_QA_PROMPT", "QA_TEMPLATE_PROMPT",
"EXECUTOR_TEMPLATE_PROMPT",
"REFINE_TEMPLATE_PROMPT",
"SELECTOR_AGENT_TEMPLATE_PROMPT",
"PLANNER_TEMPLATE_PROMPT", "GENERAL_PLANNER_PROMPT", "DATA_PLANNER_PROMPT", "TOOL_PLANNER_PROMPT",
"REACT_TEMPLATE_PROMPT",
"REACT_CODE_PROMPT", "REACT_TOOL_PROMPT", "REACT_TOOL_AND_CODE_PROMPT", "REACT_TOOL_AND_CODE_PLANNER_PROMPT"

View File

@ -0,0 +1,24 @@
SELECTOR_AGENT_TEMPLATE_PROMPT = """#### Role Selector Assistance Guidance
Your goal is to match the user's initial question (Origin Query) with the role that will best facilitate a solution, taking into account all relevant context (Context) provided.
When you need to select the appropriate role for handling a user's query, carefully read the provided role names, role descriptions and tool list.
You can use these tools:\n{formatted_tools}
Please ensure your selection is one of the listed roles. Available roles for selection:
{agents}
#### Input Format
**Origin Query:** the initial question or objective that the user wanted to achieve
**Context:** the context history to determine if Origin Query has been achieved.
#### Response Output Format
**Thoughts:** think through the reason for selecting the role step by step
**Role:** Select the role name, such as {agent_names}
"""

View File

@ -12,12 +12,12 @@ Each decision should be justified based on the context provided, specifying if t
**Context:** the current status and history of the tasks to determine if Origin Query has been achieved.
#### Response Output Format
**REASON:** Justify the decision of choosing 'finished' and 'continued' by evaluating the progress step by step.
Consider all relevant information. If the tasks were aimed at an ongoing process, assess whether it has reached a satisfactory conclusion.
**Action Status:** Set to 'finished' or 'continued'.
If it's 'finished', the context can answer the origin query.
If it's 'continued', the context can't answer the origin query.
**REASON:** Justify the decision of choosing 'finished' and 'continued' by evaluating the progress step by step.
Consider all relevant information. If the tasks were aimed at an ongoing process, assess whether it has reached a satisfactory conclusion.
"""
CHECKER_PROMPT = """尽可能地以有帮助和准确的方式回应人类,判断问题是否得到解答,同时展现解答的过程和内容。

View File

@ -12,7 +12,7 @@ Each reply should contain only the code required for the current step.
**Thoughts:** Based on the question and observations above, provide the plan for executing this step.
**Action Status:** Set to 'finished' or 'coding'. If it's 'finished', the next action is to provide the final answer to the original question. If it's 'coding', the next step is to write the code.
**Action Status:** Set to 'stoped' or 'code_executing'. If it's 'stoped', the next action is to provide the final answer to the original question. If it's 'code_executing', the next step is to write the code.
**Action:** Code according to your thoughts. Use this format for code:
@ -26,7 +26,7 @@ Each reply should contain only the code required for the current step.
**Thoughts:** I now know the final answer
**Action Status:** Set to 'finished'
**Action Status:** Set to 'stoped'
**Action:** The final answer to the original input question

View File

@ -1,7 +1,6 @@
QA_TEMPLATE_PROMPT = """#### Question Answer Assistance Guidance
Based on the information provided, please answer the origin query concisely and professionally.
If the answer cannot be derived from the given Context and DocInfos, please say 'The question cannot be answered based on the information provided' and do not add any fabricated elements to the answer.
Attention: Follow the input format and response output format
#### Input Format
@ -13,14 +12,15 @@ Attention: Follow the input format and response output format
**DocInfos:**: the relevant doc information or code information, if this is empty, don't refer to this.
#### Response Output Format
**Action Status:** Set to 'Continued' or 'Stopped'.
**Answer:** Response to the user's origin query based on Context and DocInfos. If DocInfos is empty, you can ignore it.
If the answer cannot be derived from the given Context and DocInfos, please say 'The question cannot be answered based on the information provided' and do not add any fabricated elements to the answer.
"""
CODE_QA_PROMPT = """#### Code Answer Assistance Guidance
Based on the information provided, please answer the origin query concisely and professionally.
If the answer cannot be derived from the given Context and DocInfos, please say 'The question cannot be answered based on the information provided' and do not add any fabricated elements to the answer.
Attention: Follow the input format and response output format
#### Input Format
@ -30,7 +30,9 @@ Attention: Follow the input format and response output format
**DocInfos:**: the relevant doc information or code information, if this is empty, don't refer to this.
#### Response Output Format
**Answer:** Response to the user's origin query based on DocInfos. If DocInfos is empty, you can ignore it.
**Action Status:** Set to 'Continued' or 'Stopped'.
**Answer:** Response to the user's origin query based on Context and DocInfos. If DocInfos is empty, you can ignore it.
If the answer cannot be derived from the given Context and DocInfos, please say 'The question cannot be answered based on the information provided' and do not add any fabricated elements to the answer.
"""

View File

@ -2,7 +2,9 @@
REACT_CODE_PROMPT = """#### Writing Code Assistance Guidance
When users need help with coding, your role is to provide precise and effective guidance. Write the code step by step, showing only the part necessary to solve the current problem. Each reply should contain only the code required for the current step.
When users need help with coding, your role is to provide precise and effective guidance.
Write the code step by step, showing only the part necessary to solve the current problem. Each reply should contain only the code required for the current step.
#### Response Process
@ -10,12 +12,13 @@ When users need help with coding, your role is to provide precise and effective
**Thoughts:** Based on the question and observations above, provide the plan for executing this step.
**Action Status:** Set to 'finished' or 'coding'. If it's 'finished', the next action is to provide the final answer to the original question. If it's 'coding', the next step is to write the code.
**Action:** Code according to your thoughts. (Please note that only the content printed out by the executed code can be observed in the subsequent observation.) Use this format for code:
**Action Status:** Set to 'stoped' or 'code_executing'. If it's 'stoped', the action is to provide the final answer to the original question. If it's 'code_executing', the action is to write the code.
**Action:**
```python
# Write your code here
import os
...
```
**Observation:** Check the results and effects of the executed code.
@ -24,7 +27,7 @@ When users need help with coding, your role is to provide precise and effective
**Thoughts:** I now know the final answer
**Action Status:** Set to 'finished'
**Action Status:** Set to 'stoped'
**Action:** The final answer to the original input question

View File

@ -10,7 +10,7 @@ When users need help with coding, your role is to provide precise and effective
**Thoughts:** Based on the question and observations above, provide the plan for executing this step.
**Action Status:** Set to 'finished' or 'coding'. If it's 'finished', the next action is to provide the final answer to the original question. If it's 'coding', the next step is to write the code.
**Action Status:** Set to 'stoped' or 'code_executing'. If it's 'stoped', the next action is to provide the final answer to the original question. If it's 'code_executing', the next step is to write the code.
**Action:** Code according to your thoughts. Use this format for code:
@ -24,7 +24,7 @@ When users need help with coding, your role is to provide precise and effective
**Thoughts:** I now know the final answer
**Action Status:** Set to 'finished'
**Action Status:** Set to 'stoped'
**Action:** The final answer to the original input question

View File

@ -1,25 +1,24 @@
REACT_TOOL_AND_CODE_PLANNER_PROMPT = """#### Tool and Code Sequence Breakdown Assistant
When users need assistance with deconstructing problems into a series of actionable plans using tools or code, your role is to provide a structured plan or a direct solution.
REACT_TOOL_AND_CODE_PLANNER_PROMPT = """#### Planner Assistance Guidance
When users seek assistance in breaking down complex issues into manageable and actionable steps,
your responsibility is to deliver a well-organized strategy or resolution through the use of tools or coding.
ATTENTION: respond carefully, strictly following the "Response Output Format" below.
You may use the following tools:
{formatted_tools}
Depending on the user's query, the response will either be a plan detailing the use of tools and reasoning, or a direct answer if the problem does not require breaking down.
#### Input Format
**Origin Query:** user's query
**Question:** First, clarify the problem to be solved.
#### Follow this Response Format
#### Response Output Format
**Action Status:** Set to 'planning' to provide a sequence of tasks, or 'only_answer' to provide a direct response without a plan.
**Action:**
For planning:
**Action:**
```list
[
"First step of the plan using a specified tool or a outline plan for code...",
"Next step in the plan...",
// Continue with additional steps as necessary
"First, we should ...",
]
```

View File

@ -12,14 +12,16 @@ Valid "tool_name" value:\n{tool_names}
**Question:** Start by understanding the input question to be answered.
**Thoughts:** Considering the user's question, previously executed steps, and the plan, decide whether the current step requires the use of a tool or coding. Solve the problem step by step, only displaying the thought process necessary for the current step of solving the problem. If a tool can be used, provide its name and parameters. If coding is required, outline the plan for executing this step.
**Thoughts:** Considering the user's question, previously executed steps, and the plan, decide whether the current step requires the use of a tool or code_executing. Solve the problem step by step, only displaying the thought process necessary for the current step of solving the problem. If a tool can be used, provide its name and parameters. If code_executing is required, outline the plan for executing this step.
**Action Status:** finished, tool_using, or coding. (Choose one from these three statuses. If the task is done, set it to 'finished'. If using a tool, set it to 'tool_using'. If writing code, set it to 'coding'.)
**Action Status:** stoped, tool_using, or code_executing. (Choose one from these three statuses.)
If the task is done, set it to 'stoped'.
If using a tool, set it to 'tool_using'.
If writing code, set it to 'code_executing'.
**Action:**
If using a tool, output the following format to call the tool:
If using a tool, format the tool action in JSON based on the Question and Observation. The format should be:
```json
{{
"tool_name": "$TOOL_NAME",
@ -30,7 +32,7 @@ If using a tool, output the following format to call the tool:
If the problem cannot be solved with a tool at the moment, then proceed to solve the issue using code. Output the following format to execute the code:
```python
# Write your code here
Write your code here
```
**Observation:** Check the results and effects of the executed action.
@ -39,7 +41,7 @@ If the problem cannot be solved with a tool at the moment, then proceed to solve
**Thoughts:** Conclude the final response to the input question.
**Action Status:** finished
**Action Status:** stoped
**Action:** The final answer or guidance to the original input question.
"""
@ -47,7 +49,7 @@ If the problem cannot be solved with a tool at the moment, then proceed to solve
# REACT_TOOL_AND_CODE_PROMPT = """你是一个使用工具与代码的助手。
# 如果现有工具不足以完成整个任务,请不要添加不存在的工具,只使用现有工具完成可能的部分。
# 如果当前步骤不能使用工具完成,将由代码来完成。
# 有效的"action"值为:"finished"(已经完成用户的任务) 、 "tool_using" (使用工具来回答问题) 或 'coding'(结合总结下述思维链过程编写下一步的可执行代码)。
# 有效的"action"值为:"stoped"(已经完成用户的任务) 、 "tool_using" (使用工具来回答问题) 或 'code_executing'(结合总结下述思维链过程编写下一步的可执行代码)。
# 尽可能地以有帮助和准确的方式回应人类,你可以使用以下工具:
# {formatted_tools}
# 如果现在的步骤可以用工具解决问题,请仅在每个$JSON_BLOB中提供一个action如下所示

View File

@ -16,7 +16,7 @@ valid "tool_name" value is:\n{tool_names}
**Thoughts:** Based on the question and previous observations, plan the approach for using the tool effectively.
**Action Status:** Set to either 'finished' or 'tool_using'. If 'finished', provide the final response to the original question. If 'tool_using', proceed with using the specified tool.
**Action Status:** Set to either 'stoped' or 'tool_using'. If 'stoped', provide the final response to the original question. If 'tool_using', proceed with using the specified tool.
**Action:** Use the tools by formatting the tool action in JSON. The format should be:
@ -33,7 +33,7 @@ valid "tool_name" value is:\n{tool_names}
**Thoughts:** Determine the final response based on the results.
**Action Status:** Set to 'finished'
**Action Status:** Set to 'stoped'
**Action:** Conclude with the final response to the original question in this format:
@ -49,7 +49,7 @@ valid "tool_name" value is:\n{tool_names}
# REACT_TOOL_PROMPT = """尽可能地以有帮助和准确的方式回应人类。您可以使用以下工具:
# {formatted_tools}
# 使用json blob来指定一个工具提供一个action关键字工具名称和一个tool_params关键字工具输入
# 有效的"action"值为:"finished" 或 "tool_using" (使用工具来回答问题)
# 有效的"action"值为:"stoped" 或 "tool_using" (使用工具来回答问题)
# 有效的"tool_name"值为:{tool_names}
# 请仅在每个$JSON_BLOB中提供一个action如下所示
# ```
@ -73,7 +73,7 @@ valid "tool_name" value is:\n{tool_names}
# 行动:
# ```
# {{{{
# "action": "finished",
# "action": "stoped",
# "tool_name": "notool",
# "tool_params": "最终返回答案给到用户"
# }}}}

View File

@ -43,6 +43,38 @@ class MessageUtils:
output_message.code_docs = input_message.code_docs
output_message.figures.update(input_message.figures)
output_message.origin_query = input_message.origin_query
output_message.code_engine_name = input_message.code_engine_name
output_message.doc_engine_name = input_message.doc_engine_name
output_message.search_engine_name = input_message.search_engine_name
output_message.top_k = input_message.top_k
output_message.score_threshold = input_message.score_threshold
output_message.cb_search_type = input_message.cb_search_type
output_message.do_doc_retrieval = input_message.do_doc_retrieval
output_message.do_code_retrieval = input_message.do_code_retrieval
output_message.do_tool_retrieval = input_message.do_tool_retrieval
#
output_message.tools = input_message.tools
output_message.agents = input_message.agents
# bug: identical keys may be overwritten here
output_message.customed_kargs.update(input_message.customed_kargs)
return output_message
def inherit_baseparam(self, input_message: Message, output_message: Message):
# only update the parameters
output_message.doc_engine_name = input_message.doc_engine_name
output_message.search_engine_name = input_message.search_engine_name
output_message.top_k = input_message.top_k
output_message.score_threshold = input_message.score_threshold
output_message.cb_search_type = input_message.cb_search_type
output_message.do_doc_retrieval = input_message.do_doc_retrieval
output_message.do_code_retrieval = input_message.do_code_retrieval
output_message.do_tool_retrieval = input_message.do_tool_retrieval
#
output_message.tools = input_message.tools
output_message.agents = input_message.agents
# bug: identical keys may be overwritten here
output_message.customed_kargs.update(input_message.customed_kargs)
return output_message
def get_extrainfo_step(self, message: Message, do_search, do_doc_retrieval, do_code_retrieval, do_tool_retrieval) -> Message:
@ -92,18 +124,22 @@ class MessageUtils:
def get_tool_retrieval(self, message: Message) -> Message:
return message
def step_router(self, message: Message) -> tuple[Message, ...]:
def step_router(self, message: Message, history: Memory = None, background: Memory = None, memory_pool: Memory=None) -> tuple[Message, ...]:
''''''
# message = self.parser(message)
# logger.debug(f"message.action_status: {message.action_status}")
observation_message = None
if message.action_status == ActionStatus.CODING:
if message.action_status == ActionStatus.CODE_EXECUTING:
message, observation_message = self.code_step(message)
elif message.action_status == ActionStatus.TOOL_USING:
message, observation_message = self.tool_step(message)
elif message.action_status == ActionStatus.CODING2FILE:
self.save_code2file(message)
elif message.action_status == ActionStatus.CODE_RETRIEVAL:
pass
elif message.action_status == ActionStatus.CODING:
pass
return message, observation_message
def code_step(self, message: Message) -> Message:
@ -126,7 +162,6 @@ class MessageUtils:
message.code_answer = f"\n**Observation:**: The return figure name is {uid} after executing the above code.\n"
message.observation = f"\n**Observation:**: The return figure name is {uid} after executing the above code.\n"
message.step_content += f"\n**Observation:**: The return figure name is {uid} after executing the above code.\n"
message.step_contents += [f"\n**Observation:**:The return figure name is {uid} after executing the above code.\n"]
# message.role_content += f"\n**Observation:**:执行上述代码后生成一张图片, 图片名为{uid}\n"
observation_message.role_content = f"\n**Observation:**: The return figure name is {uid} after executing the above code.\n"
observation_message.parsed_output = {"Observation": f"The return figure name is {uid} after executing the above code."}
@ -134,7 +169,6 @@ class MessageUtils:
message.code_answer = code_answer.code_exe_response
message.observation = code_answer.code_exe_response
message.step_content += f"\n**Observation:**: {code_prompt}\n"
message.step_contents += [f"\n**Observation:**: {code_prompt}\n"]
# message.role_content += f"\n**Observation:**: {code_prompt}\n"
observation_message.role_content = f"\n**Observation:**: {code_prompt}\n"
observation_message.parsed_output = {"Observation": code_prompt}
@ -159,7 +193,6 @@ class MessageUtils:
message.observation = "\n**Observation:** there is no tool can execute\n"
# message.role_content += f"\n**Observation:**: 不存在可以执行的tool\n"
message.step_content += f"\n**Observation:** there is no tool can execute\n"
message.step_contents += [f"\n**Observation:** there is no tool can execute\n"]
observation_message.role_content = f"\n**Observation:** there is no tool can execute\n"
observation_message.parsed_output = {"Observation": "there is no tool can execute\n"}
for tool in message.tools:
@ -170,7 +203,6 @@ class MessageUtils:
message.observation = tool_res
# message.role_content += f"\n**Observation:**: {tool_res}\n"
message.step_content += f"\n**Observation:** {tool_res}\n"
message.step_contents += [f"\n**Observation:** {tool_res}\n"]
observation_message.role_content = f"\n**Observation:** {tool_res}\n"
observation_message.parsed_output = {"Observation": tool_res}
break

View File

@ -5,12 +5,12 @@ import importlib
import copy
from loguru import logger
from dev_opsgpt.connector.agents import BaseAgent
from dev_opsgpt.connector.agents import BaseAgent, SelectorAgent
from dev_opsgpt.connector.chains import BaseChain
from dev_opsgpt.tools.base_tool import BaseTools, Tool
from dev_opsgpt.connector.schema import (
Memory, Task, Env, Role, Message, Doc, Docs, AgentConfig, ChainConfig, PhaseConfig, CodeDoc,
Memory, Task, Env, Role, Message, Doc, AgentConfig, ChainConfig, PhaseConfig, CodeDoc,
load_chain_configs, load_phase_configs, load_role_configs
)
from dev_opsgpt.connector.configs import AGETN_CONFIGS, CHAIN_CONFIGS, PHASE_CONFIGS
@ -156,12 +156,31 @@ class BasePhase:
focus_agents=agent_config.focus_agents,
focus_message_keys=agent_config.focus_message_keys,
)
if agent_config.role.agent_type == "SelectorAgent":
for group_agent_name in agent_config.group_agents:
group_agent_config = role_configs[group_agent_name]
baseAgent: BaseAgent = getattr(self.agent_module, group_agent_config.role.agent_type)
group_base_agent = baseAgent(
group_agent_config.role,
task = task,
memory = memory,
chat_turn=group_agent_config.chat_turn,
do_search = group_agent_config.do_search,
do_doc_retrieval = group_agent_config.do_doc_retrieval,
do_tool_retrieval = group_agent_config.do_tool_retrieval,
stop= group_agent_config.stop,
focus_agents=group_agent_config.focus_agents,
focus_message_keys=group_agent_config.focus_message_keys,
)
base_agent.group_agents.append(group_base_agent)
agents.append(base_agent)
chain_instance = BaseChain(
chain_config, agents, chain_config.chat_turn,
do_checker=chain_configs[chain_name].do_checker,
do_code_exec=False,)
)
chains.append(chain_instance)
return chains

View File

@ -8,22 +8,78 @@ from langchain.tools import BaseTool
class ActionStatus(Enum):
FINISHED = "finished"
CODING = "coding"
TOOL_USING = "tool_using"
REASONING = "reasoning"
PLANNING = "planning"
EXECUTING_CODE = "executing_code"
EXECUTING_TOOL = "executing_tool"
DEFAUILT = "default"
FINISHED = "finished"
STOPED = "stoped"
CONTINUED = "continued"
TOOL_USING = "tool_using"
CODING = "coding"
CODE_EXECUTING = "code_executing"
CODING2FILE = "coding2file"
PLANNING = "planning"
UNCHANGED = "unchanged"
ADJUSTED = "adjusted"
CODE_RETRIEVAL = "code_retrieval"
def __eq__(self, other):
if isinstance(other, str):
return self.value == other
return self.value.lower() == other.lower()
return super().__eq__(other)
class Action(BaseModel):
action_name: str
description: str
class FinishedAction(Action):
action_name: str = ActionStatus.FINISHED
description: str = "provide the final answer to the original query to break the chain answer"
class StopedAction(Action):
action_name: str = ActionStatus.STOPED
description: str = "provide the final answer to the original query to break the agent answer"
class ContinuedAction(Action):
action_name: str = ActionStatus.CONTINUED
description: str = "cant't provide the final answer to the original query"
class ToolUsingAction(Action):
action_name: str = ActionStatus.TOOL_USING
description: str = "proceed with using the specified tool."
class CodingdAction(Action):
action_name: str = ActionStatus.CODING
description: str = "provide the answer by writing code"
class Coding2FileAction(Action):
action_name: str = ActionStatus.CODING2FILE
description: str = "provide the answer by writing code and filename"
class CodeExecutingAction(Action):
action_name: str = ActionStatus.CODE_EXECUTING
description: str = "provide the answer by writing executable code"
class PlanningAction(Action):
action_name: str = ActionStatus.PLANNING
description: str = "provide a sequence of tasks"
class UnchangedAction(Action):
action_name: str = ActionStatus.UNCHANGED
description: str = "this PLAN has no problem, just set PLAN_STEP to CURRENT_STEP+1."
class AdjustedAction(Action):
action_name: str = ActionStatus.ADJUSTED
description: str = "the PLAN is to provide an optimized version of the original plan."
# extended action example
class CodeRetrievalAction(Action):
action_name: str = ActionStatus.CODE_RETRIEVAL
description: str = "execute the code retrieval to acquire more code information"
class RoleTypeEnums(Enum):
SYSTEM = "system"
USER = "user"
@ -37,7 +93,11 @@ class RoleTypeEnums(Enum):
return super().__eq__(other)
class InputKeyEnums(Enum):
class PromptKey(BaseModel):
key_name: str
description: str
class PromptKeyEnums(Enum):
# Origin Query is the user's question from the UI
ORIGIN_QUERY = "origin_query"
# agent's input from last agent
@ -49,9 +109,9 @@ class InputKeyEnums(Enum):
# chain memory
CHAIN_MEMORY = "chain_memory"
# agent's memory
SELF_ONE_MEMORY = "self_one_memory"
SELF_LOCAL_MEMORY = "self_local_memory"
# chain memory
CHAIN_ONE_MEMORY = "chain_one_memory"
CHAIN_LOCAL_MEMORY = "chain_local_memory"
# Doc Information contains (Doc\Code\Search)
DOC_INFOS = "doc_infos"
@ -107,14 +167,6 @@ class CodeDoc(BaseModel):
return f"""出处 [{self.index + 1}] \n\n来源 ({self.related_nodes}) \n\n内容 {self.code}\n\n"""
class Docs:
def __init__(self, docs: List[Doc]):
self.titles: List[str] = [doc.get_title() for doc in docs]
self.snippets: List[str] = [doc.get_snippet() for doc in docs]
self.links: List[str] = [doc.get_link() for doc in docs]
self.indexs: List[int] = [doc.get_index() for doc in docs]
class Task(BaseModel):
task_type: str
task_name: str
@ -163,6 +215,7 @@ class AgentConfig(BaseModel):
do_tool_retrieval: bool = False
focus_agents: List = []
focus_message_keys: List = []
group_agents: List = []
class PhaseConfig(BaseModel):

View File

@ -12,15 +12,11 @@ class Message(BaseModel):
input_query: str = None
origin_query: str = None
# 模型最终返回
# llm output
role_content: str = None
role_contents: List[str] = []
step_content: str = None
step_contents: List[str] = []
chain_content: str = None
chain_contents: List[str] = []
# 模型结果解析
# llm parsed information
plans: List[str] = None
code_content: str = None
code_filename: str = None
@ -30,7 +26,7 @@ class Message(BaseModel):
spec_parsed_output: dict = {}
parsed_output_list: List[Dict] = []
# 执行结果
# llm\tool\code executre information
action_status: str = ActionStatus.DEFAUILT
agent_index: int = None
code_answer: str = None
@ -38,7 +34,7 @@ class Message(BaseModel):
observation: str = None
figures: Dict[str, str] = {}
# 辅助信息
# prompt support information
tools: List[BaseTool] = []
task: Task = None
db_docs: List['Doc'] = []
@ -46,7 +42,7 @@ class Message(BaseModel):
search_docs: List['Doc'] = []
agents: List = []
# 执行输入
# phase input
phase_name: str = None
chain_name: str = None
do_search: bool = False
@ -60,6 +56,8 @@ class Message(BaseModel):
do_code_retrieval: bool = False
do_tool_retrieval: bool = False
history_node_list: List[str] = []
# user's customed kargs for init or end action
customed_kargs: dict = {}
def to_tuple_message(self, return_all: bool = True, content_key="role_content"):
role_content = self.to_str_content(False, content_key)

View File

@ -29,7 +29,7 @@ class NebulaHandler:
self.password = password
self.space_name = space_name
def execute_cypher(self, cypher: str, space_name: str = ''):
def execute_cypher(self, cypher: str, space_name: str = '', format_res: bool = False, use_space_name: bool = True):
'''
@param space_name: if provided, `USE space_name` is executed first
@ -37,11 +37,17 @@ class NebulaHandler:
@return:
'''
with self.connection_pool.session_context(self.username, self.password) as session:
if space_name:
cypher = f'USE {space_name};{cypher}'
if use_space_name:
if space_name:
cypher = f'USE {space_name};{cypher}'
elif self.space_name:
cypher = f'USE {self.space_name};{cypher}'
logger.debug(cypher)
resp = session.execute(cypher)
if format_res:
resp = self.result_to_dict(resp)
return resp
def close_connection(self):
@ -54,7 +60,8 @@ class NebulaHandler:
@return:
'''
cypher = f'CREATE SPACE IF NOT EXISTS {space_name} (vid_type={vid_type}) comment="{comment}";'
resp = self.execute_cypher(cypher)
resp = self.execute_cypher(cypher, use_space_name=False)
return resp
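A hedged usage sketch (connection details are placeholders): DDL such as CREATE SPACE is issued without a USE prefix via use_space_name=False, while ordinary queries run inside the handler's space and can be returned as a column-keyed dict via format_res=True:
```python
# Placeholder connection details; illustration only.
nh = NebulaHandler(host="127.0.0.1", port=9669, username="root",
                   password="nebula", space_name="sofaboot")
# Runs with a `USE sofaboot;` prefix and converts the result set into
# {column_name: [values, ...]} because format_res=True.
res = nh.execute_cypher("MATCH (v) RETURN id(v) AS id LIMIT 3;", format_res=True)
print(res.get("id", []))
```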
def show_space(self):
@ -95,7 +102,7 @@ class NebulaHandler:
def insert_vertex(self, tag_name: str, value_dict: dict):
'''
insert vertex
insert vertex
@param tag_name:
@param value_dict: {'properties_name': [], values: {'vid':[]}} order should be the same in properties_name and values
@return:
@ -243,7 +250,7 @@ class NebulaHandler:
"""
build list for each column, and transform to dataframe
"""
logger.info(result.error_msg())
# logger.info(result.error_msg())
assert result.is_succeeded()
columns = result.keys()
d = {}

View File

@ -11,8 +11,6 @@ import json
import os
from loguru import logger
from configs.model_config import OPENAI_API_BASE
class OpenAIEmbedding:
def __init__(self):

View File

@ -1,6 +1,6 @@
from .openai_model import getChatModel
from .openai_model import getChatModel, getExtraModel
__all__ = [
"getChatModel"
"getChatModel", "getExtraModel"
]

View File

@ -26,4 +26,23 @@ def getChatModel(callBack: AsyncIteratorCallbackHandler = None, temperature=0.3,
temperature=temperature,
stop=stop
)
return model
return model
import json, requests
def getExtraModel():
return TestModel()
class TestModel:
def predict(self, request_body):
headers = {"Content-Type":"application/json;charset=UTF-8",
"codegpt_user":"",
"codegpt_token":""
}
xxx = requests.post(
'https://codegencore.alipay.com/api/chat/CODE_LLAMA_INT4/completion',
data=json.dumps(request_body,ensure_ascii=False).encode('utf-8'),
headers=headers)
return xxx.json()["data"]
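A usage sketch only: the request body expected by the completion endpoint is not documented in this diff, so the payload below is a hypothetical placeholder.
```python
# Illustration only; the real request schema is not shown in this commit.
extra_model = getExtraModel()
hypothetical_body = {"prompt": "write a quicksort function in python"}
answer = extra_model.predict(hypothetical_body)
print(answer)
```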

View File

@ -18,6 +18,10 @@ from .service_factory import KBServiceFactory
from dev_opsgpt.utils.server_utils import BaseResponse, ListResponse
from dev_opsgpt.utils.path_utils import *
from dev_opsgpt.orm.commands import *
from dev_opsgpt.db_handler.graph_db_handler.nebula_handler import NebulaHandler
from dev_opsgpt.db_handler.vector_db_handler.chroma_handler import ChromaHandler
from configs.server_config import NEBULA_HOST, NEBULA_PORT, NEBULA_USER, NEBULA_PASSWORD, NEBULA_STORAGED_PORT
from configs.server_config import CHROMA_PERSISTENT_PATH
from configs.model_config import (
CB_ROOT_PATH
@ -125,6 +129,60 @@ def search_code(cb_name: str = Body(..., examples=["sofaboot"]),
return {}
def search_related_vertices(cb_name: str = Body(..., examples=["sofaboot"]),
vertex: str = Body(..., examples=['***'])) -> dict:
logger.info('cb_name={}'.format(cb_name))
logger.info('vertex={}'.format(vertex))
try:
# load codebase
nh = NebulaHandler(host=NEBULA_HOST, port=NEBULA_PORT, username=NEBULA_USER,
password=NEBULA_PASSWORD, space_name=cb_name)
cypher = f'''MATCH (v1)--(v2) WHERE id(v1) == '{vertex}' RETURN id(v2) as id;'''
cypher_res = nh.execute_cypher(cypher=cypher, format_res=True)
related_vertices = cypher_res.get('id', [])
res = {
'vertices': related_vertices
}
return res
except Exception as e:
logger.exception(e)
return {}
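Because the parameters are plain Body defaults, the handler can also be exercised directly as a function; a sketch with placeholder values:
```python
# Placeholder codebase name / vertex id; illustration only.
res = search_related_vertices(cb_name="sofaboot",
                              vertex="com.example.demo.HelloController#hello#0")
print(res)  # {"vertices": [...]} on success, {} on failure
```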
def search_code_by_vertex(cb_name: str = Body(..., examples=["sofaboot"]),
vertex: str = Body(..., examples=['***'])) -> dict:
logger.info('cb_name={}'.format(cb_name))
logger.info('vertex={}'.format(vertex))
try:
ch = ChromaHandler(path=CHROMA_PERSISTENT_PATH, collection_name=cb_name)
# fix vertex
vertex_use = '#'.join(vertex.split('#')[0:2])
ids = [vertex_use]
chroma_res = ch.get(ids=ids)
code_text = chroma_res['result']['metadatas'][0]['code_text']
res = {
'code': code_text
}
return res
except Exception as e:
logger.exception(e)
return {}
def cb_exists_api(cb_name: str = Body(..., examples=["sofaboot"])) -> bool:
try:
res = cb_exists(cb_name)

View File

@ -1,4 +1,11 @@
# Attention: code copied from https://github.com/chatchat-space/Langchain-Chatchat/blob/master/server/llm_api.py
############################# Attention ########################
# Code copied from
# https://github.com/chatchat-space/Langchain-Chatchat/blob/master/server/llm_api.py
#################################################################
from multiprocessing import Process, Queue
import multiprocessing as mp
@ -16,10 +23,12 @@ src_dir = os.path.join(
sys.path.append(src_dir)
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
from configs.model_config import llm_model_dict, LLM_MODEL, LLM_DEVICE, LOG_PATH, logger
from configs.model_config import llm_model_dict, LLM_MODEL, LLM_DEVICE, LOG_PATH, logger, LLM_MODELs
from configs.server_config import (
FSCHAT_CONTROLLER, FSCHAT_MODEL_WORKERS, FSCHAT_OPENAI_API
)
from dev_opsgpt.service.utils import get_model_worker_config
from dev_opsgpt.utils.server_utils import (
MakeFastAPIOffline,
)
@ -43,38 +52,6 @@ def set_httpx_timeout(timeout=60.0):
httpx._config.DEFAULT_TIMEOUT_CONFIG.write = timeout
def get_model_worker_config(model_name: str = None) -> dict:
'''
Load the model worker's config items.
Priority: FSCHAT_MODEL_WORKERS[model_name] > ONLINE_LLM_MODEL[model_name] > FSCHAT_MODEL_WORKERS["default"]
'''
from configs.model_config import ONLINE_LLM_MODEL
from configs.server_config import FSCHAT_MODEL_WORKERS
# from server import model_workers
config = FSCHAT_MODEL_WORKERS.get("default", {}).copy()
# config.update(ONLINE_LLM_MODEL.get(model_name, {}).copy())
config.update(FSCHAT_MODEL_WORKERS.get(model_name, {}).copy())
# if model_name in ONLINE_LLM_MODEL:
# config["online_api"] = True
# if provider := config.get("provider"):
# try:
# config["worker_class"] = getattr(model_workers, provider)
# except Exception as e:
# msg = f"在线模型 {model_name} 的provider没有正确配置"
# logger.error(f'{e.__class__.__name__}: {msg}',
# exc_info=e if log_verbose else None)
# local model
if model_name in llm_model_dict:
path = llm_model_dict[model_name]["local_model_path"]
config["model_path"] = path
if path and os.path.isdir(path):
config["model_path_exists"] = True
config["device"] = LLM_DEVICE
return config
def get_all_model_worker_configs() -> dict:
result = {}
model_names = set(FSCHAT_MODEL_WORKERS.keys())
@ -281,6 +258,9 @@ def create_model_worker_app(log_level: str = "INFO", **kwargs) -> FastAPI:
for k, v in kwargs.items():
setattr(args, k, v)
logger.error(f"可用模型有哪些: {args.model_names}")
if worker_class := kwargs.get("langchain_model"): #Langchian支持的模型不用做操作
from fastchat.serve.base_model_worker import app
worker = ""
@ -296,7 +276,8 @@ def create_model_worker_app(log_level: str = "INFO", **kwargs) -> FastAPI:
# 本地模型
else:
from configs.model_config import VLLM_MODEL_DICT
if kwargs["model_names"][0] in VLLM_MODEL_DICT and args.infer_turbo == "vllm":
# if kwargs["model_names"][0] in VLLM_MODEL_DICT and args.infer_turbo == "vllm":
if kwargs["model_names"][0] in VLLM_MODEL_DICT:
import fastchat.serve.vllm_worker
from fastchat.serve.vllm_worker import VLLMWorker, app, worker_id
from vllm import AsyncLLMEngine
@ -321,7 +302,7 @@ def create_model_worker_app(log_level: str = "INFO", **kwargs) -> FastAPI:
args.conv_template = None
args.limit_worker_concurrency = 5
args.no_register = False
args.num_gpus = 4 # vllm workers shard via tensor parallelism; set this to the number of GPUs
args.num_gpus = 1 # vllm workers shard via tensor parallelism; set this to the number of GPUs
args.engine_use_ray = False
args.disable_log_requests = False
@ -575,7 +556,7 @@ def run_model_worker(
kwargs["worker_address"] = fschat_model_worker_address(model_name)
model_path = kwargs.get("model_path", "")
kwargs["model_path"] = model_path
# kwargs["gptq_wbits"] = 4
# kwargs["gptq_wbits"] = 4 # int4 模型试用这个参数
app = create_model_worker_app(log_level=log_level, **kwargs)
_set_app_event(app, started_event)
@ -660,7 +641,7 @@ def parse_args() -> argparse.ArgumentParser:
"--model-name",
type=str,
nargs="+",
default=[LLM_MODEL],
default=LLM_MODELs,
help="specify model name for model worker. "
"add addition names with space seperated to start multiple model workers.",
dest="model_name",
@ -722,7 +703,7 @@ def dump_server_info(after_start=False, args=None):
print(f"langchain版本{langchain.__version__}. fastchat版本{fastchat.__version__}")
print("\n")
models = [LLM_MODEL]
models = LLM_MODELs
if args and args.model_name:
models = args.model_name
@ -813,8 +794,10 @@ async def start_main_server():
)
processes["model_worker"][model_name] = process
for model_name in args.model_name:
config = get_model_worker_config(model_name)
logger.error(f"config: {config}, {model_name}, {FSCHAT_MODEL_WORKERS.keys()}")
if (config.get("online_api")
and config.get("worker_class")
and model_name in FSCHAT_MODEL_WORKERS):

View File

@ -0,0 +1,79 @@
import base64
import datetime
import hashlib
import hmac
from urllib.parse import urlparse
from datetime import datetime
from time import mktime
from urllib.parse import urlencode
from wsgiref.handlers import format_date_time
class Ws_Param(object):
# initialization
def __init__(self, APPID, APIKey, APISecret, Spark_url):
self.APPID = APPID
self.APIKey = APIKey
self.APISecret = APISecret
self.host = urlparse(Spark_url).netloc
self.path = urlparse(Spark_url).path
self.Spark_url = Spark_url
# generate the signed url
def create_url(self):
# generate an RFC 1123 formatted timestamp
now = datetime.now()
date = format_date_time(mktime(now.timetuple()))
# assemble the string to be signed
signature_origin = "host: " + self.host + "\n"
signature_origin += "date: " + date + "\n"
signature_origin += "GET " + self.path + " HTTP/1.1"
# sign it with HMAC-SHA256
signature_sha = hmac.new(self.APISecret.encode('utf-8'), signature_origin.encode('utf-8'),
digestmod=hashlib.sha256).digest()
signature_sha_base64 = base64.b64encode(signature_sha).decode(encoding='utf-8')
authorization_origin = f'api_key="{self.APIKey}", algorithm="hmac-sha256", headers="host date request-line", signature="{signature_sha_base64}"'
authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8')
# collect the request's auth parameters into a dict
v = {
"authorization": authorization,
"date": date,
"host": self.host
}
# append the auth parameters to build the url
url = self.Spark_url + '?' + urlencode(v)
# this is the url used when establishing the connection; when following this demo you can print it and compare whether the url generated with the same parameters matches the one produced by your own code
return url
def gen_params(appid, domain, question, temperature, max_token):
"""
Generate the request parameters from the appid and the user's question
"""
data = {
"header": {
"app_id": appid,
"uid": "1234"
},
"parameter": {
"chat": {
"domain": domain,
"random_threshold": 0.5,
"max_tokens": max_token,
"auditing": "default",
"temperature": temperature,
}
},
"payload": {
"message": {
"text": question
}
}
}
return data
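A hedged sketch of how these helpers might be combined (the credentials and endpoint below are placeholders, and the message layout passed as `question` is an assumption based on gen_params):
```python
# Placeholder credentials / endpoint; illustration only.
ws_param = Ws_Param(APPID="your_appid", APIKey="your_api_key",
                    APISecret="your_api_secret",
                    Spark_url="wss://spark-api.example.com/v3.1/chat")
signed_url = ws_param.create_url()          # websocket URL carrying the auth query string
payload = gen_params(appid="your_appid", domain="generalv3",
                     question=[{"role": "user", "content": "hello"}],
                     temperature=0.5, max_token=2048)
# `signed_url` is what the websocket client connects to; `payload` is sent as JSON.
print(signed_url)
```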

View File

@ -0,0 +1,18 @@
############################# Attention ########################
# The Code in model workers all copied from
# https://github.com/chatchat-space/Langchain-Chatchat/blob/master/server/model_workers
#################################################################
from .base import *
from .zhipu import ChatGLMWorker
from .minimax import MiniMaxWorker
from .xinghuo import XingHuoWorker
from .qianfan import QianFanWorker
from .fangzhou import FangZhouWorker
from .qwen import QwenWorker
from .baichuan import BaiChuanWorker
from .azure import AzureWorker
from .tiangong import TianGongWorker
from .openai import ExampleWorker

View File

@ -0,0 +1,94 @@
import sys
from fastchat.conversation import Conversation
from .base import *
# from server.utils import get_httpx_client
from fastchat import conversation as conv
import json
from typing import List, Dict
from configs import logger, log_verbose
class AzureWorker(ApiModelWorker):
def __init__(
self,
*,
controller_addr: str = None,
worker_addr: str = None,
model_names: List[str] = ["azure-api"],
version: str = "gpt-35-turbo",
**kwargs,
):
kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr)
kwargs.setdefault("context_len", 8000) #TODO 16K模型需要改成16384
super().__init__(**kwargs)
self.version = version
def do_chat(self, params: ApiChatParams) -> Dict:
params.load_config(self.model_names[0])
data = dict(
messages=params.messages,
temperature=params.temperature,
max_tokens=params.max_tokens,
stream=True,
)
url = ("https://{}.openai.azure.com/openai/deployments/{}/chat/completions?api-version={}"
.format(params.resource_name, params.deployment_name, params.api_version))
headers = {
'Content-Type': 'application/json',
'Accept': 'application/json',
'api-key': params.api_key,
}
text = ""
if log_verbose:
logger.info(f'{self.__class__.__name__}:url: {url}')
logger.info(f'{self.__class__.__name__}:headers: {headers}')
logger.info(f'{self.__class__.__name__}:data: {data}')
with get_httpx_client() as client:
with client.stream("POST", url, headers=headers, json=data) as response:
for line in response.iter_lines():
if not line.strip() or "[DONE]" in line:
continue
if line.startswith("data: "):
line = line[6:]
resp = json.loads(line)
if choices := resp["choices"]:
if chunk := choices[0].get("delta", {}).get("content"):
text += chunk
yield {
"error_code": 0,
"text": text
}
else:
self.logger.error(f"请求 Azure API 时发生错误:{resp}")
def get_embeddings(self, params):
# TODO: support embeddings
print("embedding")
print(params)
def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation:
# TODO: confirm whether the conversation template needs changes
return conv.Conversation(
name=self.model_names[0],
system_message="You are a helpful, respectful and honest assistant.",
messages=[],
roles=["user", "assistant"],
sep="\n### ",
stop_str="###",
)
if __name__ == "__main__":
import uvicorn
from server.utils import MakeFastAPIOffline
from fastchat.serve.base_model_worker import app
worker = AzureWorker(
controller_addr="http://127.0.0.1:20001",
worker_addr="http://127.0.0.1:21008",
)
sys.modules["fastchat.serve.model_worker"].worker = worker
MakeFastAPIOffline(app)
uvicorn.run(app, port=21008)

View File

@ -0,0 +1,119 @@
import json
import time
import hashlib
from fastchat.conversation import Conversation
from .base import *
# from server.utils import get_httpx_client
from fastchat import conversation as conv
import sys
import json
from typing import List, Literal, Dict
from configs import logger, log_verbose
def calculate_md5(input_string):
md5 = hashlib.md5()
md5.update(input_string.encode('utf-8'))
encrypted = md5.hexdigest()
return encrypted
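A small sketch of the signing scheme used in do_chat below: the MD5 digest is taken over secret_key + request body + timestamp (all values here are placeholders):
```python
import json
import time

secret_key = "your_secret_key"                                # placeholder
json_data = json.dumps({"model": "Baichuan2-53B", "messages": []})
time_stamp = int(time.time())
signature = calculate_md5(secret_key + json_data + str(time_stamp))
print(time_stamp, signature)
```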
class BaiChuanWorker(ApiModelWorker):
def __init__(
self,
*,
controller_addr: str = None,
worker_addr: str = None,
model_names: List[str] = ["baichuan-api"],
version: Literal["Baichuan2-53B"] = "Baichuan2-53B",
**kwargs,
):
kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr)
kwargs.setdefault("context_len", 32768)
super().__init__(**kwargs)
self.version = version
def do_chat(self, params: ApiChatParams) -> Dict:
params.load_config(self.model_names[0])
url = "https://api.baichuan-ai.com/v1/stream/chat"
data = {
"model": params.version,
"messages": params.messages,
"parameters": {"temperature": params.temperature}
}
json_data = json.dumps(data)
time_stamp = int(time.time())
signature = calculate_md5(params.secret_key + json_data + str(time_stamp))
headers = {
"Content-Type": "application/json",
"Authorization": "Bearer " + params.api_key,
"X-BC-Request-Id": "your requestId",
"X-BC-Timestamp": str(time_stamp),
"X-BC-Signature": signature,
"X-BC-Sign-Algo": "MD5",
}
text = ""
if log_verbose:
logger.info(f'{self.__class__.__name__}:json_data: {json_data}')
logger.info(f'{self.__class__.__name__}:url: {url}')
logger.info(f'{self.__class__.__name__}:headers: {headers}')
with get_httpx_client() as client:
with client.stream("POST", url, headers=headers, json=data) as response:
for line in response.iter_lines():
if not line.strip():
continue
resp = json.loads(line)
if resp["code"] == 0:
text += resp["data"]["messages"][-1]["content"]
yield {
"error_code": resp["code"],
"text": text
}
else:
data = {
"error_code": resp["code"],
"text": resp["msg"],
"error": {
"message": resp["msg"],
"type": "invalid_request_error",
"param": None,
"code": None,
}
}
self.logger.error(f"请求百川 API 时发生错误:{data}")
yield data
def get_embeddings(self, params):
# TODO: support embeddings
print("embedding")
print(params)
def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation:
# TODO: confirm whether the conversation template needs changes
return conv.Conversation(
name=self.model_names[0],
system_message="",
messages=[],
roles=["user", "assistant"],
sep="\n### ",
stop_str="###",
)
if __name__ == "__main__":
import uvicorn
from server.utils import MakeFastAPIOffline
from fastchat.serve.model_worker import app
worker = BaiChuanWorker(
controller_addr="http://127.0.0.1:20001",
worker_addr="http://127.0.0.1:21007",
)
sys.modules["fastchat.serve.model_worker"].worker = worker
MakeFastAPIOffline(app)
uvicorn.run(app, port=21007)
# do_request()

View File

@ -0,0 +1,249 @@
from fastchat.conversation import Conversation
from configs import LOG_PATH
import fastchat.constants
fastchat.constants.LOGDIR = LOG_PATH
from fastchat.serve.base_model_worker import BaseModelWorker
import uuid
import json
import sys
from pydantic import BaseModel, root_validator
import fastchat
import asyncio
from dev_opsgpt.service.utils import get_model_worker_config
from typing import Dict, List, Optional
__all__ = ["ApiModelWorker", "ApiChatParams", "ApiCompletionParams", "ApiEmbeddingsParams"]
class ApiConfigParams(BaseModel):
'''
Online API config parameters; values not provided are read automatically from model_config.ONLINE_LLM_MODEL
'''
api_base_url: Optional[str] = None
api_proxy: Optional[str] = None
api_key: Optional[str] = None
secret_key: Optional[str] = None
group_id: Optional[str] = None # for minimax
is_pro: bool = False # for minimax
APPID: Optional[str] = None # for xinghuo
APISecret: Optional[str] = None # for xinghuo
is_v2: bool = False # for xinghuo
worker_name: Optional[str] = None
class Config:
extra = "allow"
@root_validator(pre=True)
def validate_config(cls, v: Dict) -> Dict:
if config := get_model_worker_config(v.get("worker_name")):
for n in cls.__fields__:
if n in config:
v[n] = config[n]
return v
def load_config(self, worker_name: str):
self.worker_name = worker_name
if config := get_model_worker_config(worker_name):
for n in self.__fields__:
if n in config:
setattr(self, n, config[n])
return self
class ApiModelParams(ApiConfigParams):
'''
Model config parameters
'''
version: Optional[str] = None
version_url: Optional[str] = None
api_version: Optional[str] = None # for azure
deployment_name: Optional[str] = None # for azure
resource_name: Optional[str] = None # for azure
temperature: float = 0.7
max_tokens: Optional[int] = None
top_p: Optional[float] = 1.0
class ApiChatParams(ApiModelParams):
'''
Chat request parameters
'''
messages: List[Dict[str, str]]
system_message: Optional[str] = None # for minimax
role_meta: Dict = {} # for minimax
class ApiCompletionParams(ApiModelParams):
prompt: str
class ApiEmbeddingsParams(ApiConfigParams):
texts: List[str]
embed_model: Optional[str] = None
to_query: bool = False # for minimax
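A sketch of how a worker fills these parameter objects (worker name and message content are placeholders); values not passed explicitly are pulled from the configured entry by load_config:
```python
# Placeholder worker name / message; illustration only.
params = ApiChatParams(
    messages=[{"role": "user", "content": "hello"}],
    temperature=0.7,
)
params.load_config("example")     # worker name as configured for the online model
print(params.api_base_url, params.version)
```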
class ApiModelWorker(BaseModelWorker):
DEFAULT_EMBED_MODEL: str = None # None means not support embedding
def __init__(
self,
model_names: List[str],
controller_addr: str = None,
worker_addr: str = None,
context_len: int = 2048,
no_register: bool = False,
**kwargs,
):
kwargs.setdefault("worker_id", uuid.uuid4().hex[:8])
kwargs.setdefault("model_path", "")
kwargs.setdefault("limit_worker_concurrency", 5)
super().__init__(model_names=model_names,
controller_addr=controller_addr,
worker_addr=worker_addr,
**kwargs)
import fastchat.serve.base_model_worker
import sys
self.logger = fastchat.serve.base_model_worker.logger
# restore the standard output/error streams overridden by fastchat
sys.stdout = sys.__stdout__
sys.stderr = sys.__stderr__
self.context_len = context_len
self.semaphore = asyncio.Semaphore(self.limit_worker_concurrency)
self.version = None
if not no_register and self.controller_addr:
self.init_heart_beat()
def count_token(self, params):
# TODO: needs improvement
# print("count token")
prompt = params["prompt"]
return {"count": len(str(prompt)), "error_code": 0}
def generate_stream_gate(self, params: Dict):
self.call_ct += 1
try:
prompt = params["prompt"]
if self._is_chat(prompt):
messages = self.prompt_to_messages(prompt)
messages = self.validate_messages(messages)
else: # emulate completion via chat; history messages are not supported
messages = [{"role": self.user_role, "content": f"please continue writing from here: {prompt}"}]
p = ApiChatParams(
messages=messages,
temperature=params.get("temperature"),
top_p=params.get("top_p"),
max_tokens=params.get("max_new_tokens"),
version=self.version,
)
for resp in self.do_chat(p):
yield self._jsonify(resp)
except Exception as e:
yield self._jsonify({"error_code": 500, "text": f"{self.model_names[0]}请求API时发生错误{e}"})
def generate_gate(self, params):
try:
for x in self.generate_stream_gate(params):
...
return json.loads(x[:-1].decode())
except Exception as e:
return {"error_code": 500, "text": str(e)}
# methods to be overridden by subclasses
def do_chat(self, params: ApiChatParams) -> Dict:
'''
Run the chat request; by default the module's chat function is used.
The required return format is {"error_code": int, "text": str}
'''
return {"error_code": 500, "text": f"{self.model_names[0]}未实现chat功能"}
# def do_completion(self, p: ApiCompletionParams) -> Dict:
# '''
# Run the completion request; by default the module's completion function is used.
# The required return format is: {"error_code": int, "text": str}
# '''
# return {"error_code": 500, "text": f"{self.model_names[0]}未实现completion功能"}
def do_embeddings(self, params: ApiEmbeddingsParams) -> Dict:
'''
Run the embeddings request; by default the module's embed_documents function is used.
The required return format is {"code": int, "data": List[List[float]], "msg": str}
'''
return {"code": 500, "msg": f"{self.model_names[0]}未实现embeddings功能"}
def get_embeddings(self, params):
# fastchat is very restrictive about embeddings for LLMs; it seems only the openai one can be used.
# requests initiated from the frontend via OpenAIEmbeddings fail immediately and never get through.
print("get_embedding")
print(params)
def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation:
raise NotImplementedError
def validate_messages(self, messages: List[Dict]) -> List[Dict]:
'''
Some APIs require a special message format; override this function to replace the default messages.
It is kept separate from prompt_to_messages because the use cases and parameters differ.
'''
return messages
# helper methods
@property
def user_role(self):
return self.conv.roles[0]
@property
def ai_role(self):
return self.conv.roles[1]
def _jsonify(self, data: Dict) -> str:
'''
Return the chat result in the format expected by fastchat's openai-api-server
'''
return json.dumps(data, ensure_ascii=False).encode() + b"\0"
def _is_chat(self, prompt: str) -> bool:
'''
Check whether the prompt was concatenated from chat messages.
TODO: misjudgment is possible; passing the original messages straight from fastchat might be a better approach
'''
key = f"{self.conv.sep}{self.user_role}:"
return key in prompt
def prompt_to_messages(self, prompt: str) -> List[Dict]:
'''
Split the prompt string back into messages.
'''
result = []
user_role = self.user_role
ai_role = self.ai_role
user_start = user_role + ":"
ai_start = ai_role + ":"
for msg in prompt.split(self.conv.sep)[1:-1]:
if msg.startswith(user_start):
if content := msg[len(user_start):].strip():
result.append({"role": user_role, "content": content})
elif msg.startswith(ai_start):
if content := msg[len(ai_start):].strip():
result.append({"role": ai_role, "content": content})
else:
raise RuntimeError(f"unknown role in msg: {msg}")
return result
@classmethod
def can_embedding(cls):
return cls.DEFAULT_EMBED_MODEL is not None
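A sketch of the prompt format these helpers assume (sep and role names follow the conversation template; the dialogue content is made up). The loop below mirrors prompt_to_messages:
```python
# Illustration only: a fastchat-style prompt concatenated with sep="\n### "
# and roles ["user", "assistant"], as produced by the conversation template.
prompt = (
    "You are a helpful assistant."
    "\n### user: What is DevOps?"
    "\n### assistant: DevOps combines development and operations."
    "\n### user: Name one typical tool."
    "\n### assistant: "
)
sep, messages = "\n### ", []
for seg in prompt.split(sep)[1:-1]:           # drop the system prefix and the empty tail
    role, _, content = seg.partition(": ")
    if content.strip():
        messages.append({"role": role, "content": content.strip()})
print(messages)
# [{'role': 'user', 'content': 'What is DevOps?'},
#  {'role': 'assistant', 'content': 'DevOps combines development and operations.'},
#  {'role': 'user', 'content': 'Name one typical tool.'}]
```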

View File

@ -0,0 +1,106 @@
from fastchat.conversation import Conversation
from .base import *
from fastchat import conversation as conv
import sys
from typing import List, Literal, Dict
from configs import logger, log_verbose
class FangZhouWorker(ApiModelWorker):
"""
Volcano Ark (火山方舟) model worker.
"""
def __init__(
self,
*,
model_names: List[str] = ["fangzhou-api"],
controller_addr: str = None,
worker_addr: str = None,
version: Literal["chatglm-6b-model"] = "chatglm-6b-model",
**kwargs,
):
kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr)
kwargs.setdefault("context_len", 16384) # TODO: 不同的模型有不同的大小
super().__init__(**kwargs)
self.version = version
def do_chat(self, params: ApiChatParams) -> Dict:
from volcengine.maas import MaasService
params.load_config(self.model_names[0])
maas = MaasService('maas-api.ml-platform-cn-beijing.volces.com', 'cn-beijing')
maas.set_ak(params.api_key)
maas.set_sk(params.secret_key)
# document: "https://www.volcengine.com/docs/82379/1099475"
req = {
"model": {
"name": params.version,
},
"parameters": {
# the parameters here are only an example; see the specific model's API docs for the full list
"max_new_tokens": params.max_tokens,
"temperature": params.temperature,
},
"messages": params.messages,
}
text = ""
if log_verbose:
self.logger.info(f'{self.__class__.__name__}:maas: {maas}')
for resp in maas.stream_chat(req):
if error := resp.error:
if error.code_n > 0:
data = {
"error_code": error.code_n,
"text": error.message,
"error": {
"message": error.message,
"type": "invalid_request_error",
"param": None,
"code": None,
}
}
self.logger.error(f"请求方舟 API 时发生错误:{data}")
yield data
elif chunk := resp.choice.message.content:
text += chunk
yield {"error_code": 0, "text": text}
else:
data = {
"error_code": 500,
"text": f"请求方舟 API 时发生未知的错误: {resp}"
}
self.logger.error(data)
yield data
break
def get_embeddings(self, params):
# TODO: support embeddings
print("embedding")
print(params)
def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation:
return conv.Conversation(
name=self.model_names[0],
system_message="你是一个聪明、对人类有帮助的人工智能,你可以对人类提出的问题给出有用、详细、礼貌的回答。",
messages=[],
roles=["user", "assistant", "system"],
sep="\n### ",
stop_str="###",
)
if __name__ == "__main__":
import uvicorn
from server.utils import MakeFastAPIOffline
from fastchat.serve.model_worker import app
worker = FangZhouWorker(
controller_addr="http://127.0.0.1:20001",
worker_addr="http://127.0.0.1:21005",
)
sys.modules["fastchat.serve.model_worker"].worker = worker
MakeFastAPIOffline(app)
uvicorn.run(app, port=21005)

View File

@ -0,0 +1,172 @@
from fastchat.conversation import Conversation
from .base import *
from fastchat import conversation as conv
import sys
import json
# from server.utils import get_httpx_client
from typing import List, Dict
from configs import logger, log_verbose
class MiniMaxWorker(ApiModelWorker):
DEFAULT_EMBED_MODEL = "embo-01"
def __init__(
self,
*,
model_names: List[str] = ["minimax-api"],
controller_addr: str = None,
worker_addr: str = None,
version: str = "abab5.5-chat",
**kwargs,
):
kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr)
kwargs.setdefault("context_len", 16384)
super().__init__(**kwargs)
self.version = version
def validate_messages(self, messages: List[Dict]) -> List[Dict]:
role_maps = {
"user": self.user_role,
"assistant": self.ai_role,
"system": "system",
}
messages = [{"sender_type": role_maps[x["role"]], "text": x["content"]} for x in messages]
return messages
def do_chat(self, params: ApiChatParams) -> Dict:
# call the abab 5.5 model directly, as recommended on the official site
# TODO: support reply constraints, and support specifying the user name and AI name
params.load_config(self.model_names[0])
url = 'https://api.minimax.chat/v1/text/chatcompletion{pro}?GroupId={group_id}'
pro = "_pro" if params.is_pro else ""
headers = {
"Authorization": f"Bearer {params.api_key}",
"Content-Type": "application/json",
}
messages = self.validate_messages(params.messages)
data = {
"model": params.version,
"stream": True,
"mask_sensitive_info": True,
"messages": messages,
"temperature": params.temperature,
"top_p": params.top_p,
"tokens_to_generate": params.max_tokens or 1024,
# TODO: the following parameters are MiniMax-specific; passing empty values raises an error.
# "prompt": params.system_message or self.conv.system_message,
# "bot_setting": [],
# "role_meta": params.role_meta,
}
if log_verbose:
logger.info(f'{self.__class__.__name__}:data: {data}')
logger.info(f'{self.__class__.__name__}:url: {url.format(pro=pro, group_id=params.group_id)}')
logger.info(f'{self.__class__.__name__}:headers: {headers}')
with get_httpx_client() as client:
response = client.stream("POST",
url.format(pro=pro, group_id=params.group_id),
headers=headers,
json=data)
with response as r:
text = ""
for e in r.iter_text():
if not e.startswith("data: "): # what an "excellent" response format
data = {
"error_code": 500,
"text": f"minimax返回错误的结果{e}",
"error": {
"message": f"minimax返回错误的结果{e}",
"type": "invalid_request_error",
"param": None,
"code": None,
}
}
self.logger.error(f"请求 MiniMax API 时发生错误:{data}")
yield data
continue
data = json.loads(e[6:])
if data.get("usage"):
break
if choices := data.get("choices"):
if chunk := choices[0].get("delta", ""):
text += chunk
yield {"error_code": 0, "text": text}
def do_embeddings(self, params: ApiEmbeddingsParams) -> Dict:
params.load_config(self.model_names[0])
url = f"https://api.minimax.chat/v1/embeddings?GroupId={params.group_id}"
headers = {
"Authorization": f"Bearer {params.api_key}",
"Content-Type": "application/json",
}
data = {
"model": params.embed_model or self.DEFAULT_EMBED_MODEL,
"texts": [],
"type": "query" if params.to_query else "db",
}
if log_verbose:
logger.info(f'{self.__class__.__name__}:data: {data}')
logger.info(f'{self.__class__.__name__}:url: {url}')
logger.info(f'{self.__class__.__name__}:headers: {headers}')
with get_httpx_client() as client:
result = []
i = 0
batch_size = 10
while i < len(params.texts):
texts = params.texts[i:i+batch_size]
data["texts"] = texts
r = client.post(url, headers=headers, json=data).json()
if embeddings := r.get("vectors"):
result += embeddings
elif error := r.get("base_resp"):
data = {
"code": error["status_code"],
"msg": error["status_msg"],
"error": {
"message": error["status_msg"],
"type": "invalid_request_error",
"param": None,
"code": None,
}
}
self.logger.error(f"请求 MiniMax API 时发生错误:{data}")
return data
i += batch_size
return {"code": 200, "data": embeddings}
def get_embeddings(self, params):
# TODO: support embeddings
print("embedding")
print(params)
def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation:
# TODO: confirm whether the conversation template needs changes
return conv.Conversation(
name=self.model_names[0],
system_message="你是MiniMax自主研发的大型语言模型回答问题简洁有条理。",
messages=[],
roles=["USER", "BOT"],
sep="\n### ",
stop_str="###",
)
if __name__ == "__main__":
import uvicorn
from server.utils import MakeFastAPIOffline
from fastchat.serve.model_worker import app
worker = MiniMaxWorker(
controller_addr="http://127.0.0.1:20001",
worker_addr="http://127.0.0.1:21002",
)
sys.modules["fastchat.serve.model_worker"].worker = worker
MakeFastAPIOffline(app)
uvicorn.run(app, port=21002)

View File

@ -0,0 +1,92 @@
import sys
from fastchat.conversation import Conversation
from .base import *
from fastchat import conversation as conv
import json
from typing import List, Dict
from configs import logger
import openai
from langchain import PromptTemplate, LLMChain
from langchain.prompts.chat import ChatPromptTemplate
from langchain.chat_models import ChatOpenAI
from langchain.schema import HumanMessage
class ExampleWorker(ApiModelWorker):
def __init__(
self,
*,
controller_addr: str = None,
worker_addr: str = None,
model_names: List[str] = ["gpt-3.5-turbo"],
version: str = "gpt-3.5",
**kwargs,
):
kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr)
kwargs.setdefault("context_len", 16384) #TODO 16K模型需要改成16384
super().__init__(**kwargs)
self.version = version
def do_chat(self, params: ApiChatParams) -> Dict:
'''
yield output: {"error_code": 0, "text": ""}
'''
params.load_config(self.model_names[0])
openai.api_key = params.api_key
openai.api_base = params.api_base_url
logger.error(f"{params.api_key}, {params.api_base_url}, {params.messages} {params.max_tokens},")
# just for example
prompt = "\n".join([f"{m['role']}:{m['content']}" for m in params.messages])
logger.error(f"{prompt}, {params.temperature}, {params.max_tokens}")
try:
model = ChatOpenAI(
streaming=True,
verbose=True,
openai_api_key= params.api_key,
openai_api_base=params.api_base_url,
model_name=params.version
)
chat_prompt = ChatPromptTemplate.from_messages([("human", "{input}")])
chain = LLMChain(prompt=chat_prompt, llm=model)
content = chain({"input": prompt})
logger.info(content)
except Exception as e:
logger.error(f"{e}")
yield {"error_code": 500, "text": "request error"}
return
# return the text by yield for stream
try:
yield {"error_code": 0, "text": content["text"]}
except Exception:
yield {"error_code": 500, "text": "request error"}
def get_embeddings(self, params):
# TODO: support embeddings
print("embedding")
print(params)
def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation:
# TODO: confirm whether this template needs adjustment
return conv.Conversation(
name=self.model_names[0],
system_message="You are a helpful, respectful and honest assistant.",
messages=[],
roles=["user", "assistant", "system"],
sep="\n### ",
stop_str="###",
)
if __name__ == "__main__":
import uvicorn
from dev_opsgpt.utils.server_utils import MakeFastAPIOffline
from fastchat.serve.base_model_worker import app
worker = ExampleWorker(
controller_addr="http://127.0.0.1:20001",
worker_addr="http://127.0.0.1:21008",
)
sys.modules["fastchat.serve.model_worker"].worker = worker
uvicorn.run(app, port=21008)

View File

@ -0,0 +1,242 @@
import sys
from fastchat.conversation import Conversation
from .base import *
# from server.utils import get_httpx_client
from cachetools import cached, TTLCache
import json
from fastchat import conversation as conv
import sys
from typing import List, Literal, Dict
from configs import logger, log_verbose
MODEL_VERSIONS = {
"ernie-bot-4": "completions_pro",
"ernie-bot": "completions",
"ernie-bot-turbo": "eb-instant",
"bloomz-7b": "bloomz_7b1",
"qianfan-bloomz-7b-c": "qianfan_bloomz_7b_compressed",
"llama2-7b-chat": "llama_2_7b",
"llama2-13b-chat": "llama_2_13b",
"llama2-70b-chat": "llama_2_70b",
"qianfan-llama2-ch-7b": "qianfan_chinese_llama_2_7b",
"chatglm2-6b-32k": "chatglm2_6b_32k",
"aquilachat-7b": "aquilachat_7b",
# "linly-llama2-ch-7b": "", # 暂未发布
# "linly-llama2-ch-13b": "", # 暂未发布
# "chatglm2-6b": "", # 暂未发布
# "chatglm2-6b-int4": "", # 暂未发布
# "falcon-7b": "", # 暂未发布
# "falcon-180b-chat": "", # 暂未发布
# "falcon-40b": "", # 暂未发布
# "rwkv4-world": "", # 暂未发布
# "rwkv5-world": "", # 暂未发布
# "rwkv4-pile-14b": "", # 暂未发布
# "rwkv4-raven-14b": "", # 暂未发布
# "open-llama-7b": "", # 暂未发布
# "dolly-12b": "", # 暂未发布
# "mpt-7b-instruct": "", # 暂未发布
# "mpt-30b-instruct": "", # 暂未发布
# "OA-Pythia-12B-SFT-4": "", # 暂未发布
# "xverse-13b": "", # 暂未发布
# # 以下为企业测试,需要单独申请
# "flan-ul2": "",
# "Cerebras-GPT-6.7B": ""
# "Pythia-6.9B": ""
}
@cached(TTLCache(1, 1800))  # tested: the cached token stays valid; maxsize=1, so only the latest AK/SK pair is cached, refreshed every 30 minutes
def get_baidu_access_token(api_key: str, secret_key: str) -> str:
"""
Generate an authentication access token from the AK/SK pair.
:return: the access_token, or None on error
"""
url = "https://aip.baidubce.com/oauth/2.0/token"
params = {"grant_type": "client_credentials", "client_id": api_key, "client_secret": secret_key}
try:
with get_httpx_client() as client:
return client.get(url, params=params).json().get("access_token")
except Exception as e:
print(f"failed to get token from baidu: {e}")
class QianFanWorker(ApiModelWorker):
"""
Baidu Qianfan (百度千帆) API worker
"""
DEFAULT_EMBED_MODEL = "embedding-v1"
def __init__(
self,
*,
version: Literal["ernie-bot", "ernie-bot-turbo"] = "ernie-bot",
model_names: List[str] = ["qianfan-api"],
controller_addr: str = None,
worker_addr: str = None,
**kwargs,
):
kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr)
kwargs.setdefault("context_len", 16384)
super().__init__(**kwargs)
self.version = version
def do_chat(self, params: ApiChatParams) -> Dict:
params.load_config(self.model_names[0])
# import qianfan
# comp = qianfan.ChatCompletion(model=params.version,
# endpoint=params.version_url,
# ak=params.api_key,
# sk=params.secret_key,)
# text = ""
# for resp in comp.do(messages=params.messages,
# temperature=params.temperature,
# top_p=params.top_p,
# stream=True):
# if resp.code == 200:
# if chunk := resp.body.get("result"):
# text += chunk
# yield {
# "error_code": 0,
# "text": text
# }
# else:
# yield {
# "error_code": resp.code,
# "text": str(resp.body),
# }
BASE_URL = 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat' \
'/{model_version}?access_token={access_token}'
access_token = get_baidu_access_token(params.api_key, params.secret_key)
if not access_token:
yield {
"error_code": 403,
"text": "failed to get access token. have you set the correct api_key and secret_key?",
}
return
url = BASE_URL.format(
model_version=params.version_url or MODEL_VERSIONS[params.version.lower()],
access_token=access_token,
)
payload = {
"messages": params.messages,
"temperature": params.temperature,
"stream": True
}
headers = {
'Content-Type': 'application/json',
'Accept': 'application/json',
}
text = ""
if log_verbose:
logger.info(f'{self.__class__.__name__}:data: {payload}')
logger.info(f'{self.__class__.__name__}:url: {url}')
logger.info(f'{self.__class__.__name__}:headers: {headers}')
with get_httpx_client() as client:
with client.stream("POST", url, headers=headers, json=payload) as response:
for line in response.iter_lines():
if not line.strip():
continue
if line.startswith("data: "):
line = line[6:]
resp = json.loads(line)
if "result" in resp.keys():
text += resp["result"]
yield {
"error_code": 0,
"text": text
}
else:
data = {
"error_code": resp["error_code"],
"text": resp["error_msg"],
"error": {
"message": resp["error_msg"],
"type": "invalid_request_error",
"param": None,
"code": None,
}
}
self.logger.error(f"请求千帆 API 时发生错误:{data}")
yield data
def do_embeddings(self, params: ApiEmbeddingsParams) -> Dict:
params.load_config(self.model_names[0])
# import qianfan
# embed = qianfan.Embedding(ak=params.api_key, sk=params.secret_key)
# resp = embed.do(texts = params.texts, model=params.embed_model or self.DEFAULT_EMBED_MODEL)
# if resp.code == 200:
# embeddings = [x.embedding for x in resp.body.get("data", [])]
# return {"code": 200, "embeddings": embeddings}
# else:
# return {"code": resp.code, "msg": str(resp.body)}
embed_model = params.embed_model or self.DEFAULT_EMBED_MODEL
access_token = get_baidu_access_token(params.api_key, params.secret_key)
url = f"https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/{embed_model}?access_token={access_token}"
if log_verbose:
logger.info(f'{self.__class__.__name__}:url: {url}')
with get_httpx_client() as client:
result = []
i = 0
batch_size = 10
while i < len(params.texts):
texts = params.texts[i:i+batch_size]
resp = client.post(url, json={"input": texts}).json()
if "error_code" in resp:
data = {
"code": resp["error_code"],
"msg": resp["error_msg"],
"error": {
"message": resp["error_msg"],
"type": "invalid_request_error",
"param": None,
"code": None,
}
}
self.logger.error(f"请求千帆 API 时发生错误:{data}")
return data
else:
embeddings = [x["embedding"] for x in resp.get("data", [])]
result += embeddings
i += batch_size
return {"code": 200, "data": result}
# TODO: Qianfan also supports text-completion (continuation) models
def get_embeddings(self, params):
# TODO: support embeddings
print("embedding")
print(params)
def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation:
# TODO: confirm whether this template needs adjustment
return conv.Conversation(
name=self.model_names[0],
system_message="你是一个聪明的助手,请根据用户的提示来完成任务",
messages=[],
roles=["user", "assistant"],
sep="\n### ",
stop_str="###",
)
if __name__ == "__main__":
import uvicorn
from server.utils import MakeFastAPIOffline
from fastchat.serve.model_worker import app
worker = QianFanWorker(
controller_addr="http://127.0.0.1:20001",
worker_addr="http://127.0.0.1:21004"
)
sys.modules["fastchat.serve.model_worker"].worker = worker
MakeFastAPIOffline(app)
uvicorn.run(app, port=21004)

View File

@ -0,0 +1,128 @@
import json
import sys
from fastchat.conversation import Conversation
from http import HTTPStatus
from typing import List, Literal, Dict
from fastchat import conversation as conv
from .base import *
from configs import logger, log_verbose
class QwenWorker(ApiModelWorker):
DEFAULT_EMBED_MODEL = "text-embedding-v1"
def __init__(
self,
*,
version: Literal["qwen-turbo", "qwen-plus"] = "qwen-turbo",
model_names: List[str] = ["qwen-api"],
controller_addr: str = None,
worker_addr: str = None,
**kwargs,
):
kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr)
kwargs.setdefault("context_len", 16384)
super().__init__(**kwargs)
self.version = version
def do_chat(self, params: ApiChatParams) -> Dict:
import dashscope
params.load_config(self.model_names[0])
if log_verbose:
logger.info(f'{self.__class__.__name__}:params: {params}')
gen = dashscope.Generation()
responses = gen.call(
model=params.version,
temperature=params.temperature,
api_key=params.api_key,
messages=params.messages,
result_format='message', # set the result is message format.
stream=True,
)
for resp in responses:
if resp["status_code"] == 200:
if choices := resp["output"]["choices"]:
yield {
"error_code": 0,
"text": choices[0]["message"]["content"],
}
else:
data = {
"error_code": resp["status_code"],
"text": resp["message"],
"error": {
"message": resp["message"],
"type": "invalid_request_error",
"param": None,
"code": None,
}
}
self.logger.error(f"请求千问 API 时发生错误:{data}")
yield data
def do_embeddings(self, params: ApiEmbeddingsParams) -> Dict:
import dashscope
params.load_config(self.model_names[0])
if log_verbose:
logger.info(f'{self.__class__.__name__}:params: {params}')
result = []
i = 0
while i < len(params.texts):
texts = params.texts[i:i+25]
resp = dashscope.TextEmbedding.call(
model=params.embed_model or self.DEFAULT_EMBED_MODEL,
input=texts,  # at most 25 items per request
api_key=params.api_key,
)
if resp["status_code"] != 200:
data = {
"code": resp["status_code"],
"msg": resp.message,
"error": {
"message": resp["message"],
"type": "invalid_request_error",
"param": None,
"code": None,
}
}
self.logger.error(f"请求千问 API 时发生错误:{data}")
return data
else:
embeddings = [x["embedding"] for x in resp["output"]["embeddings"]]
result += embeddings
i += 25
return {"code": 200, "data": result}
def get_embeddings(self, params):
# TODO: support embeddings
print("embedding")
print(params)
def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation:
# TODO: confirm whether this template needs adjustment
return conv.Conversation(
name=self.model_names[0],
system_message="你是一个聪明、对人类有帮助的人工智能,你可以对人类提出的问题给出有用、详细、礼貌的回答。",
messages=[],
roles=["user", "assistant", "system"],
sep="\n### ",
stop_str="###",
)
if __name__ == "__main__":
import uvicorn
from server.utils import MakeFastAPIOffline
from fastchat.serve.model_worker import app
worker = QwenWorker(
controller_addr="http://127.0.0.1:20001",
worker_addr="http://127.0.0.1:20007",
)
sys.modules["fastchat.serve.model_worker"].worker = worker
MakeFastAPIOffline(app)
uvicorn.run(app, port=20007)

View File

@ -0,0 +1,88 @@
import json
import time
import hashlib
from fastchat.conversation import Conversation
from .base import *
from fastchat import conversation as conv
import json
from typing import List, Literal, Dict
import requests
class TianGongWorker(ApiModelWorker):
def __init__(
self,
*,
controller_addr: str = None,
worker_addr: str = None,
model_names: List[str] = ["tiangong-api"],
version: Literal["SkyChat-MegaVerse"] = "SkyChat-MegaVerse",
**kwargs,
):
kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr)
kwargs.setdefault("context_len", 32768)
super().__init__(**kwargs)
self.version = version
def do_chat(self, params: ApiChatParams) -> Dict:
params.load_config(self.model_names[0])
url = 'https://sky-api.singularity-ai.com/saas/api/v4/generate'
data = {
"messages": params.messages,
"model": "SkyChat-MegaVerse"
}
timestamp = str(int(time.time()))
sign_content = params.api_key + params.secret_key + timestamp
sign_result = hashlib.md5(sign_content.encode('utf-8')).hexdigest()
headers={
"app_key": params.api_key,
"timestamp": timestamp,
"sign": sign_result,
"Content-Type": "application/json",
"stream": "true" # or change to "false" 不处理流式返回内容
}
# send the request and read the response
response = requests.post(url, headers=headers, json=data, stream=True)
text = ""
# process the response stream
for line in response.iter_lines(chunk_size=None, decode_unicode=True):
if line:
# handle the received data
# print(line.decode('utf-8'))
resp = json.loads(line)
if resp["code"] == 200:
text += resp['resp_data']['reply']
yield {
"error_code": 0,
"text": text
}
else:
data = {
"error_code": resp["code"],
"text": resp["code_msg"]
}
self.logger.error(f"请求天工 API 时出错:{data}")
yield data
def get_embeddings(self, params):
# TODO: support embeddings
print("embedding")
print(params)
def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation:
# TODO: confirm whether this template needs adjustment
return conv.Conversation(
name=self.model_names[0],
system_message="",
messages=[],
roles=["user", "system"],
sep="\n### ",
stop_str="###",
)

View File

@ -0,0 +1,105 @@
from fastchat.conversation import Conversation
from .base import *
from fastchat import conversation as conv
import sys
import json
from model_workers import SparkApi
import websockets
from dev_opsgpt.utils.server_utils import run_async, iter_over_async
from typing import List, Dict
import asyncio
async def request(appid, api_key, api_secret, Spark_url, domain, question, temperature, max_token):
wsParam = SparkApi.Ws_Param(appid, api_key, api_secret, Spark_url)
wsUrl = wsParam.create_url()
data = SparkApi.gen_params(appid, domain, question, temperature, max_token)
print(data)
async with websockets.connect(wsUrl) as ws:
await ws.send(json.dumps(data, ensure_ascii=False))
finish = False
while not finish:
chunk = await ws.recv()
response = json.loads(chunk)
if response.get("header", {}).get("status") == 2:
finish = True
if text := response.get("payload", {}).get("choices", {}).get("text"):
yield text[0]["content"]
class XingHuoWorker(ApiModelWorker):
def __init__(
self,
*,
model_names: List[str] = ["xinghuo-api"],
controller_addr: str = None,
worker_addr: str = None,
version: str = None,
**kwargs,
):
kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr)
kwargs.setdefault("context_len", 8000) # TODO: V1模型的最大长度为4000需要自行修改
super().__init__(**kwargs)
self.version = version
def do_chat(self, params: ApiChatParams) -> Dict:
# TODO: currently a new websocket connection is opened for every chat; check whether the connection can be kept alive
params.load_config(self.model_names[0])
version_mapping = {
"v1.5": {"domain": "general", "url": "ws://spark-api.xf-yun.com/v1.1/chat","max_tokens": 4000},
"v2.0": {"domain": "generalv2", "url": "ws://spark-api.xf-yun.com/v2.1/chat","max_tokens": 8000},
"v3.0": {"domain": "generalv3", "url": "ws://spark-api.xf-yun.com/v3.1/chat","max_tokens": 8000},
}
def get_version_details(version_key):
return version_mapping.get(version_key, {"domain": None, "url": None})
details = get_version_details(params.version)
domain = details["domain"]
Spark_url = details["url"]
text = ""
try:
loop = asyncio.get_event_loop()
except:
loop = asyncio.new_event_loop()
params.max_tokens = min(details["max_tokens"], params.max_tokens or details["max_tokens"])
for chunk in iter_over_async(
request(params.APPID, params.api_key, params.APISecret, Spark_url, domain, params.messages,
params.temperature, params.max_tokens),
loop=loop,
):
if chunk:
text += chunk
yield {"error_code": 0, "text": text}
def get_embeddings(self, params):
# TODO: support embeddings
print("embedding")
print(params)
def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation:
# TODO: confirm whether this template needs adjustment
return conv.Conversation(
name=self.model_names[0],
system_message="你是一个聪明的助手,请根据用户的提示来完成任务",
messages=[],
roles=["user", "assistant"],
sep="\n### ",
stop_str="###",
)
if __name__ == "__main__":
import uvicorn
from server.utils import MakeFastAPIOffline
from fastchat.serve.model_worker import app
worker = XingHuoWorker(
controller_addr="http://127.0.0.1:20001",
worker_addr="http://127.0.0.1:21003",
)
sys.modules["fastchat.serve.model_worker"].worker = worker
MakeFastAPIOffline(app)
uvicorn.run(app, port=21003)
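# A minimal sketch of the `iter_over_async` helper used by do_chat above. The real implementation
# lives in dev_opsgpt.utils.server_utils and is not part of this diff, so the name
# `iter_over_async_sketch`, its signature, and its behaviour here are assumptions: it drains an
# async generator from synchronous code by stepping it on an event loop, which is what the
# streaming loop relies on.
import asyncio
from typing import AsyncIterator, Iterator, TypeVar

T = TypeVar("T")

def iter_over_async_sketch(ait: AsyncIterator[T], loop: asyncio.AbstractEventLoop) -> Iterator[T]:
    """Synchronously yield the items produced by an async iterator."""
    ait = ait.__aiter__()
    while True:
        try:
            # run a single step of the async generator on the provided event loop
            yield loop.run_until_complete(ait.__anext__())
        except StopAsyncIteration:
            break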

View File

@ -0,0 +1,110 @@
from fastchat.conversation import Conversation
from .base import *
from fastchat import conversation as conv
import sys
from typing import List, Dict, Iterator, Literal
from configs import logger, log_verbose
class ChatGLMWorker(ApiModelWorker):
DEFAULT_EMBED_MODEL = "text_embedding"
def __init__(
self,
*,
model_names: List[str] = ["zhipu-api"],
controller_addr: str = None,
worker_addr: str = None,
version: Literal["chatglm_turbo"] = "chatglm_turbo",
**kwargs,
):
kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr)
kwargs.setdefault("context_len", 32768)
super().__init__(**kwargs)
self.version = version
def do_chat(self, params: ApiChatParams) -> Iterator[Dict]:
# TODO: maintain a request_id
import zhipuai
params.load_config(self.model_names[0])
zhipuai.api_key = params.api_key
if log_verbose:
logger.info(f'{self.__class__.__name__}:params: {params}')
response = zhipuai.model_api.sse_invoke(
model=params.version,
prompt=params.messages,
temperature=params.temperature,
top_p=params.top_p,
incremental=False,
)
for e in response.events():
if e.event == "add":
yield {"error_code": 0, "text": e.data}
elif e.event in ["error", "interrupted"]:
data = {
"error_code": 500,
"text": str(e),
"error": {
"message": str(e),
"type": "invalid_request_error",
"param": None,
"code": None,
}
}
self.logger.error(f"请求智谱 API 时发生错误:{data}")
yield data
def do_embeddings(self, params: ApiEmbeddingsParams) -> Dict:
import zhipuai
params.load_config(self.model_names[0])
zhipuai.api_key = params.api_key
embeddings = []
try:
for t in params.texts:
response = zhipuai.model_api.invoke(model=params.embed_model or self.DEFAULT_EMBED_MODEL, prompt=t)
if response["code"] == 200:
embeddings.append(response["data"]["embedding"])
else:
self.logger.error(f"请求智谱 API 时发生错误:{response}")
return response # dict with code & msg
except Exception as e:
data = {"code": 500, "msg": f"对文本向量化时出错:{e}"}
self.logger.error(f"请求智谱 API 时发生错误:{data}")
return data
return {"code": 200, "data": embeddings}
def get_embeddings(self, params):
# TODO: support embeddings
print("embedding")
# print(params)
def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation:
# this is the template for the chatglm API; other APIs need their own customized conv_template
return conv.Conversation(
name=self.model_names[0],
system_message="你是一个聪明的助手,请根据用户的提示来完成任务",
messages=[],
roles=["Human", "Assistant", "System"],
sep="\n###",
stop_str="###",
)
if __name__ == "__main__":
import uvicorn
from server.utils import MakeFastAPIOffline
from fastchat.serve.model_worker import app
worker = ChatGLMWorker(
controller_addr="http://127.0.0.1:20001",
worker_addr="http://127.0.0.1:21001",
)
sys.modules["fastchat.serve.model_worker"].worker = worker
MakeFastAPIOffline(app)
uvicorn.run(app, port=21001)

View File

@ -0,0 +1,39 @@
import os
from configs.model_config import ONLINE_LLM_MODEL
from configs.server_config import FSCHAT_MODEL_WORKERS
from configs.model_config import llm_model_dict, LLM_DEVICE
from loguru import logger
def get_model_worker_config(model_name: str = None) -> dict:
'''
Load the configuration for a model worker.
Priority: FSCHAT_MODEL_WORKERS[model_name] > ONLINE_LLM_MODEL[model_name] > FSCHAT_MODEL_WORKERS["default"]
(a small sketch of this merge order follows the function)
'''
from dev_opsgpt.service import model_workers
config = FSCHAT_MODEL_WORKERS.get("default", {}).copy()
config.update(ONLINE_LLM_MODEL.get(model_name, {}).copy())
config.update(FSCHAT_MODEL_WORKERS.get(model_name, {}).copy())
if model_name in ONLINE_LLM_MODEL:
config["online_api"] = True
if provider := config.get("provider"):
try:
config["worker_class"] = getattr(model_workers, provider)
except Exception as e:
msg = f"在线模型 {model_name} 的provider没有正确配置"
logger.error(f'{e.__class__.__name__}: {msg}')
# local model
if model_name in llm_model_dict:
path = llm_model_dict[model_name]["local_model_path"]
config["model_path"] = path
if path and os.path.isdir(path):
config["model_path_exists"] = True
config["device"] = LLM_DEVICE
# logger.debug(f"config: {config}")
return config
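# A minimal sketch (with made-up values) of the merge order implemented above: later dict.update
# calls win, so a model-specific FSCHAT_MODEL_WORKERS entry overrides the ONLINE_LLM_MODEL entry,
# which in turn overrides FSCHAT_MODEL_WORKERS["default"]. The keys below are illustrative only.
def _merge_order_example() -> dict:
    default_cfg = {"host": "127.0.0.1", "port": 20002}   # stands in for FSCHAT_MODEL_WORKERS["default"]
    online_cfg = {"api_key": "sk-xxx", "port": 21001}    # stands in for ONLINE_LLM_MODEL["some-api"]
    worker_cfg = {"port": 21009}                         # stands in for FSCHAT_MODEL_WORKERS["some-api"]
    cfg = default_cfg.copy()
    cfg.update(online_cfg)
    cfg.update(worker_cfg)
    return cfg  # {'host': '127.0.0.1', 'port': 21009, 'api_key': 'sk-xxx'}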

View File

@ -11,12 +11,13 @@ from .docs_retrieval import DocRetrieval
from .cb_query_tool import CodeRetrieval
from .ocr_tool import BaiduOcrTool
from .stock_tool import StockInfo, StockName
from .codechat_tools import CodeRetrievalSingle, RelatedVerticesRetrival, Vertex2Code
IMPORT_TOOL = [
WeatherInfo, DistrictInfo, Multiplier, WorldTimeGetTimezoneByArea,
KSigmaDetector, MetricsQuery, DDGSTool, DocRetrieval, CodeRetrieval,
BaiduOcrTool, StockInfo, StockName
BaiduOcrTool, StockInfo, StockName, CodeRetrievalSingle, RelatedVerticesRetrival, Vertex2Code
]
TOOL_SETS = [tool.__name__ for tool in IMPORT_TOOL]

View File

@ -58,3 +58,5 @@ class CodeRetrieval(BaseToolModel):
return_codes.append({'index': 0, 'code': context, "related_nodes": related_nodes})
return return_codes

View File

@ -0,0 +1,110 @@
# encoding: utf-8
'''
@author: 温进
@file: codechat_tools.py
@time: 2023/12/14 上午10:24
@desc:
'''
import json
import os
import re
from pydantic import BaseModel, Field
from typing import List, Dict
import requests
import numpy as np
from loguru import logger
from configs.model_config import (
CODE_SEARCH_TOP_K)
from .base_tool import BaseToolModel
from dev_opsgpt.service.cb_api import search_code, search_related_vertices, search_code_by_vertex
# A question comes in.
# Call function 0: input the question; output code file name 1 and code file 1.
#
# agent 1
# 1. LLM (code 1 + question), output: whether the question can be answered.
#
# agent 2
# 1. Call function 1: input code file name 1; output a list of code file names.
# 2. LLM: input code file 1, the question, and the list of code file names; output code file name 2.
# 3. Call function 2: input code file name 2; output code file 2.
# (a usage sketch chaining these tools appears at the end of this file)
class CodeRetrievalSingle(BaseToolModel):
name = "CodeRetrievalOneCode"
description = "输入用户的问题,输出一个代码文件名和代码文件"
class ToolInputArgs(BaseModel):
query: str = Field(..., description="检索的问题")
code_base_name: str = Field(..., description="代码库名称", examples=["samples"])
class ToolOutputArgs(BaseModel):
"""Output for MetricsQuery."""
code: str = Field(..., description="检索代码")
vertex: str = Field(..., description="代码对应 id")
@classmethod
def run(cls, code_base_name, query):
"""excute your tool!"""
search_type = 'description'
code_limit = 1
# default
search_result = search_code(code_base_name, query, code_limit, search_type=search_type,
history_node_list=[])
logger.debug(search_result)
code = search_result['context']
vertex = search_result['related_vertices'][0]
# logger.debug(f"code: {code}, vertex: {vertex}")
res = {
'code': code,
'vertex': vertex
}
return res
class RelatedVerticesRetrival(BaseToolModel):
name = "RelatedVerticesRetrival"
description = "输入代码节点名,返回相连的节点名"
class ToolInputArgs(BaseModel):
code_base_name: str = Field(..., description="代码库名称", examples=["samples"])
vertex: str = Field(..., description="节点名", examples=["samples"])
class ToolOutputArgs(BaseModel):
"""Output for MetricsQuery."""
vertices: list = Field(..., description="相连节点名")
@classmethod
def run(cls, code_base_name: str, vertex: str):
"""execute your tool!"""
related_vertices = search_related_vertices(cb_name=code_base_name, vertex=vertex)
logger.debug(f"related_vertices: {related_vertices}")
return related_vertices
class Vertex2Code(BaseToolModel):
name = "Vertex2Code"
description = "输入代码节点名,返回对应的代码文件"
class ToolInputArgs(BaseModel):
code_base_name: str = Field(..., description="代码库名称", examples=["samples"])
vertex: str = Field(..., description="节点名", examples=["samples"])
class ToolOutputArgs(BaseModel):
"""Output for MetricsQuery."""
code: str = Field(..., description="代码名")
@classmethod
def run(cls, code_base_name: str, vertex: str):
"""execute your tool!"""
res = search_code_by_vertex(cb_name=code_base_name, vertex=vertex)
return res
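# A minimal usage sketch chaining the three tools above, mirroring the flow described at the top of
# this file. It assumes a code base named "samples" has already been imported into the running
# cb_api service, and that the responses carry the fields declared in the ToolOutputArgs models;
# both are assumptions for illustration, not guarantees about the service's actual payloads.
if __name__ == "__main__":
    first = CodeRetrievalSingle.run(code_base_name="samples", query="what is the remove function used for?")
    related = RelatedVerticesRetrival.run(code_base_name="samples", vertex=first["vertex"])
    vertices = related.get("vertices", []) if isinstance(related, dict) else related
    if vertices:
        # pull the code of one related vertex for further reading
        more_code = Vertex2Code.run(code_base_name="samples", vertex=vertices[0])
        logger.debug(more_code)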

View File

@ -15,9 +15,24 @@ chat_box = ChatBox(
assistant_avatar="../sources/imgs/devops-chatbot2.png"
)
cur_dir = os.path.dirname(os.path.abspath(__file__))
GLOBAL_EXE_CODE_TEXT = ""
GLOBAL_MESSAGE = {"figures": {}, "final_contents": {}}
import yaml
# load the webui text config (TODO: pick the language from a setting instead of a hard-coded flag)
webui_yaml_filename = "webui_zh.yaml" if True else "webui_en.yaml"
with open(os.path.join(cur_dir, f"yamls/{webui_yaml_filename}"), 'r') as f:
try:
webui_configs = yaml.safe_load(f)
except yaml.YAMLError as exc:
print(exc)
def get_messages_history(history_len: int, isDetailed=False) -> List[Dict]:
def filter(msg):
'''
@ -55,12 +70,6 @@ def upload2sandbox(upload_file, api: ApiRequest):
res = {"msg": False}
else:
res = api.web_sd_upload(upload_file)
# logger.debug(res)
# if res["msg"]:
# st.success("上文件传成功")
# else:
# st.toast("文件上传失败")
def dialogue_page(api: ApiRequest):
global GLOBAL_EXE_CODE_TEXT
@ -70,33 +79,31 @@ def dialogue_page(api: ApiRequest):
# TODO: 对话模型与会话绑定
def on_mode_change():
mode = st.session_state.dialogue_mode
text = f"已切换到 {mode} 模式。"
if mode == "知识库问答":
text = webui_configs["dialogue"]["text_mode_swtich"] + f"{mode}"
if mode == webui_configs["dialogue"]["mode"][1]:
cur_kb = st.session_state.get("selected_kb")
if cur_kb:
text = f"{text} 当前知识库: `{cur_kb}`。"
text = text + webui_configs["dialogue"]["text_knowledgeBase_swtich"] + f'`{cur_kb}`'
st.toast(text)
# sac.alert(text, description="descp", type="success", closable=True, banner=True)
dialogue_mode = st.selectbox("请选择对话模式",
["LLM 对话",
"知识库问答",
"代码知识库问答",
"搜索引擎问答",
"Agent问答"
],
dialogue_mode = st.selectbox(webui_configs["dialogue"]["mode_instruction"],
webui_configs["dialogue"]["mode"],
# ["LLM 对话",
# "知识库问答",
# "代码知识库问答",
# "搜索引擎问答",
# "Agent问答"
# ],
on_change=on_mode_change,
key="dialogue_mode",
)
history_len = st.number_input("历史对话轮数:", 0, 10, 3)
# todo: support history len
history_len = st.number_input(webui_configs["dialogue"]["history_length"], 0, 10, 3)
def on_kb_change():
st.toast(f"已加载知识库: {st.session_state.selected_kb}")
st.toast(f"{webui_configs['dialogue']['text_loaded_kbase']}: {st.session_state.selected_kb}")
def on_cb_change():
st.toast(f"已加载代码知识库: {st.session_state.selected_cb}")
st.toast(f"{webui_configs['dialogue']['text_loaded_cbase']}: {st.session_state.selected_cb}")
cb_details = get_cb_details_by_cb_name(st.session_state.selected_cb)
st.session_state['do_interpret'] = cb_details['do_interpret']
@ -107,114 +114,140 @@ def dialogue_page(api: ApiRequest):
not_agent_qa = True
interpreter_file = ""
is_detailed = False
if dialogue_mode == "知识库问答":
with st.expander("知识库配置", True):
if dialogue_mode == webui_configs["dialogue"]["mode"][1]:
with st.expander(webui_configs["dialogue"]["kbase_expander_name"], True):
kb_list = api.list_knowledge_bases(no_remote_api=True)
selected_kb = st.selectbox(
"请选择知识库:",
webui_configs["dialogue"]["kbase_selectbox_name"],
kb_list,
on_change=on_kb_change,
key="selected_kb",
)
kb_top_k = st.number_input("匹配知识条数:", 1, 20, 3)
score_threshold = st.number_input("知识匹配分数阈值:", 0.0, float(SCORE_THRESHOLD), float(SCORE_THRESHOLD), float(SCORE_THRESHOLD//100))
# chunk_content = st.checkbox("关联上下文", False, disabled=True)
# chunk_size = st.slider("关联长度:", 0, 500, 250, disabled=True)
elif dialogue_mode == '代码知识库问答':
with st.expander('代码知识库配置', True):
kb_top_k = st.number_input(
webui_configs["dialogue"]["kbase_ninput_topk_name"], 1, 20, 3)
score_threshold = st.number_input(
webui_configs["dialogue"]["kbase_ninput_score_threshold_name"],
0.0, float(SCORE_THRESHOLD), float(SCORE_THRESHOLD),
float(SCORE_THRESHOLD//100))
elif dialogue_mode == webui_configs["dialogue"]["mode"][2]:
with st.expander(webui_configs["dialogue"]["cbase_expander_name"], True):
cb_list = api.list_cb(no_remote_api=True)
logger.debug('codebase_list={}'.format(cb_list))
selected_cb = st.selectbox(
"请选择代码知识库:",
webui_configs["dialogue"]["cbase_selectbox_name"],
cb_list,
on_change=on_cb_change,
key="selected_cb",
)
# change do_interpret
st.toast(f"已加载代码知识库: {st.session_state.selected_cb}")
st.toast(f"{webui_configs['dialogue']['text_loaded_cbase']}: {st.session_state.selected_cb}")
cb_details = get_cb_details_by_cb_name(st.session_state.selected_cb)
st.session_state['do_interpret'] = cb_details['do_interpret']
cb_code_limit = st.number_input("匹配代码条数:", 1, 20, 1)
cb_code_limit = st.number_input(
webui_configs["dialogue"]["cbase_ninput_topk_name"], 1, 20, 1)
search_type_list = ['基于 cypher', '基于标签', '基于描述'] if st.session_state['do_interpret'] == 'YES' \
else ['基于 cypher', '基于标签']
search_type_list = webui_configs["dialogue"]["cbase_search_type_v1"] if st.session_state['do_interpret'] == 'YES' \
else webui_configs["dialogue"]["cbase_search_type_v2"]
cb_search_type = st.selectbox(
'请选择查询模式:',
webui_configs["dialogue"]["cbase_selectbox_type_name"],
search_type_list,
key='cb_search_type'
)
elif dialogue_mode == "搜索引擎问答":
with st.expander("搜索引擎配置", True):
search_engine = st.selectbox("请选择搜索引擎", SEARCH_ENGINES.keys(), 0)
se_top_k = st.number_input("匹配搜索结果条数:", 1, 20, 3)
elif dialogue_mode == "Agent问答":
elif dialogue_mode == webui_configs["dialogue"]["mode"][3]:
with st.expander(webui_configs["dialogue"]["expander_search_name"], True):
search_engine = st.selectbox(
webui_configs["dialogue"]["selectbox_search_name"],
SEARCH_ENGINES.keys(), 0)
se_top_k = st.number_input(
webui_configs["dialogue"]["ninput_search_topk_name"], 1, 20, 3)
elif dialogue_mode == webui_configs["dialogue"]["mode"][4]:
not_agent_qa = False
with st.expander("Phase管理", True):
with st.expander(webui_configs["dialogue"]["phase_expander_name"], True):
choose_phase = st.selectbox(
'请选择待使用的执行链路', PHASE_LIST, 0)
webui_configs["dialogue"]["phase_selectbox_name"], PHASE_LIST, 0)
is_detailed = st.toggle("是否使用明细信息进行agent交互", False)
tool_using_on = st.toggle("开启工具使用", PHASE_CONFIGS[choose_phase]["do_using_tool"])
is_detailed = st.toggle(webui_configs["dialogue"]["phase_toggle_detailed_name"], False)
tool_using_on = st.toggle(
webui_configs["dialogue"]["phase_toggle_doToolUsing"],
PHASE_CONFIGS[choose_phase]["do_using_tool"])
tool_selects = []
if tool_using_on:
with st.expander("工具军火库", True):
tool_selects = st.multiselect(
'请选择待使用的工具', TOOL_SETS, ["WeatherInfo"])
webui_configs["dialogue"]["phase_multiselect_tools"],
TOOL_SETS, ["WeatherInfo"])
search_on = st.toggle("开启搜索增强", PHASE_CONFIGS[choose_phase]["do_search"])
search_on = st.toggle(webui_configs["dialogue"]["phase_toggle_doSearch"],
PHASE_CONFIGS[choose_phase]["do_search"])
search_engine, top_k = None, 3
if search_on:
with st.expander("搜索引擎配置", True):
search_engine = st.selectbox("请选择搜索引擎", SEARCH_ENGINES.keys(), 0)
top_k = st.number_input("匹配搜索结果条数:", 1, 20, 3)
with st.expander(webui_configs["dialogue"]["expander_search_name"], True):
search_engine = st.selectbox(
webui_configs["dialogue"]["selectbox_search_name"],
SEARCH_ENGINES.keys(), 0)
se_top_k = st.number_input(
webui_configs["dialogue"]["ninput_search_topk_name"], 1, 20, 3)
doc_retrieval_on = st.toggle("开启知识库检索增强", PHASE_CONFIGS[choose_phase]["do_doc_retrieval"])
doc_retrieval_on = st.toggle(
webui_configs["dialogue"]["phase_toggle_doDocRetrieval"],
PHASE_CONFIGS[choose_phase]["do_doc_retrieval"])
selected_kb, top_k, score_threshold = None, 3, 1.0
if doc_retrieval_on:
with st.expander("知识库配置", True):
with st.expander(webui_configs["dialogue"]["kbase_expander_name"], True):
kb_list = api.list_knowledge_bases(no_remote_api=True)
selected_kb = st.selectbox(
"请选择知识库:",
webui_configs["dialogue"]["kbase_selectbox_name"],
kb_list,
on_change=on_kb_change,
key="selected_kb",
)
top_k = st.number_input("匹配知识条数:", 1, 20, 3)
score_threshold = st.number_input("知识匹配分数阈值:", 0.0, float(SCORE_THRESHOLD), float(SCORE_THRESHOLD), float(SCORE_THRESHOLD//100))
code_retrieval_on = st.toggle("开启代码检索增强", PHASE_CONFIGS[choose_phase]["do_code_retrieval"])
top_k = st.number_input(
webui_configs["dialogue"]["kbase_ninput_topk_name"], 1, 20, 3)
score_threshold = st.number_input(
webui_configs["dialogue"]["kbase_ninput_score_threshold_name"],
0.0, float(SCORE_THRESHOLD), float(SCORE_THRESHOLD),
float(SCORE_THRESHOLD//100))
code_retrieval_on = st.toggle(
webui_configs["dialogue"]["phase_toggle_doCodeRetrieval"],
PHASE_CONFIGS[choose_phase]["do_code_retrieval"])
selected_cb, top_k = None, 1
cb_search_type = "tag"
if code_retrieval_on:
with st.expander('代码知识库配置', True):
with st.expander(webui_configs["dialogue"]["cbase_expander_name"], True):
cb_list = api.list_cb(no_remote_api=True)
logger.debug('codebase_list={}'.format(cb_list))
selected_cb = st.selectbox(
"请选择代码知识库:",
webui_configs["dialogue"]["cbase_selectbox_name"],
cb_list,
on_change=on_cb_change,
key="selected_cb",
)
st.toast(f"已加载代码知识库: {st.session_state.selected_cb}")
top_k = st.number_input("匹配代码条数:", 1, 20, 1)
# change do_interpret
st.toast(f"{webui_configs['dialogue']['text_loaded_cbase']}: {st.session_state.selected_cb}")
cb_details = get_cb_details_by_cb_name(st.session_state.selected_cb)
st.session_state['do_interpret'] = cb_details['do_interpret']
search_type_list = ['基于 cypher', '基于标签', '基于描述'] if st.session_state['do_interpret'] == 'YES' \
else ['基于 cypher', '基于标签']
top_k = st.number_input(
webui_configs["dialogue"]["cbase_ninput_topk_name"], 1, 20, 1)
search_type_list = webui_configs["dialogue"]["cbase_search_type_v1"] if st.session_state['do_interpret'] == 'YES' \
else webui_configs["dialogue"]["cbase_search_type_v2"]
cb_search_type = st.selectbox(
'请选择查询模式:',
webui_configs["dialogue"]["cbase_selectbox_type_name"],
search_type_list,
key='cb_search_type'
)
with st.expander("沙盒文件管理", False):
with st.expander(webui_configs["sandbox"]["expander_name"], False):
interpreter_file = st.file_uploader(
"上传沙盒文件",
webui_configs["sandbox"]["file_upload_name"],
[i for ls in LOADER2EXT_DICT.values() for i in ls] + ["jpg", "png"],
accept_multiple_files=False,
key=st.session_state.interpreter_file_key,
@ -222,29 +255,31 @@ def dialogue_page(api: ApiRequest):
files = api.web_sd_list_files()
files = files["data"]
download_file = st.selectbox("选择要处理文件", files,
download_file = st.selectbox(webui_configs["sandbox"]["selectbox_name"], files,
key="download_file",)
cols = st.columns(3)
file_url, file_name = api.web_sd_download(download_file)
if cols[0].button("点击上传"):
if cols[0].button(webui_configs["sandbox"]["button_upload_name"],):
upload2sandbox(interpreter_file, api)
st.session_state["interpreter_file_key"] += 1
interpreter_file = ""
st.experimental_rerun()
cols[1].download_button("点击下载", file_url, file_name)
if cols[2].button("点击删除", ):
cols[1].download_button(webui_configs["sandbox"]["button_download_name"],
file_url, file_name)
if cols[2].button(webui_configs["sandbox"]["button_delete_name"],):
api.web_sd_delete(download_file)
code_interpreter_on = st.toggle("开启代码解释器") and not_agent_qa
code_exec_on = st.toggle("自动执行代码") and not_agent_qa
code_interpreter_on = st.toggle(
webui_configs["sandbox"]["toggle_doCodeInterpreter"]) and not_agent_qa
code_exec_on = st.toggle(webui_configs["sandbox"]["toggle_doAutoCodeExec"]) and not_agent_qa
# Display chat messages from history on app rerun
chat_box.output_messages()
chat_input_placeholder = "请输入对话内容换行请使用Ctrl+Enter "
chat_input_placeholder = webui_configs["chat"]["chat_placeholder"]
code_text = "" or GLOBAL_EXE_CODE_TEXT
codebox_res = None
@ -254,8 +289,8 @@ def dialogue_page(api: ApiRequest):
history = get_messages_history(history_len, is_detailed)
chat_box.user_say(prompt)
if dialogue_mode == "LLM 对话":
chat_box.ai_say("正在思考...")
if dialogue_mode == webui_configs["dialogue"]["mode"][0]:
chat_box.ai_say(webui_configs["chat"]["chatbox_saying"])
text = ""
r = api.chat_chat(prompt, history, no_remote_api=True)
for t in r:
@ -277,12 +312,18 @@ def dialogue_page(api: ApiRequest):
GLOBAL_EXE_CODE_TEXT = code_text
if code_text and code_exec_on:
codebox_res = api.codebox_chat("```"+code_text+"```", do_code_exe=True)
elif dialogue_mode == "Agent问答":
display_infos = [f"正在思考..."]
elif dialogue_mode == webui_configs["dialogue"]["mode"][4]:
display_infos = [webui_configs["chat"]["chatbox_saying"]]
if search_on:
display_infos.append(Markdown("...", in_expander=True, title="网络搜索结果"))
display_infos.append(Markdown("...", in_expander=True,
title=webui_configs["chat"]["chatbox_search_result"]))
if doc_retrieval_on:
display_infos.append(Markdown("...", in_expander=True, title="知识库匹配结果"))
display_infos.append(Markdown("...", in_expander=True,
title=webui_configs["chat"]["chatbox_doc_result"]))
if code_retrieval_on:
display_infos.append(Markdown("...", in_expander=True,
title=webui_configs["chat"]["chatbox_code_result"]))
chat_box.ai_say(display_infos)
if 'history_node_list' in st.session_state:
@ -331,18 +372,21 @@ def dialogue_page(api: ApiRequest):
chat_box.update_msg(text, element_index=0, streaming=False, state="complete") # 更新最终的字符串,去除光标
if search_on:
chat_box.update_msg("搜索匹配结果:\n\n" + "\n\n".join(d["search_docs"]), element_index=search_on, streaming=False, state="complete")
chat_box.update_msg(f"{webui_configs['chat']['chatbox_search_result']}:\n\n" + "\n\n".join(d["search_docs"]), element_index=search_on, streaming=False, state="complete")
if doc_retrieval_on:
chat_box.update_msg("知识库匹配结果:\n\n" + "\n\n".join(d["db_docs"]), element_index=search_on+doc_retrieval_on, streaming=False, state="complete")
chat_box.update_msg(f"{webui_configs['chat']['chatbox_doc_result']}:\n\n" + "\n\n".join(d["db_docs"]), element_index=search_on+doc_retrieval_on, streaming=False, state="complete")
if code_retrieval_on:
chat_box.update_msg(f"{webui_configs['chat']['chatbox_code_result']}:\n\n" + "\n\n".join(d["code_docs"]),
element_index=search_on+doc_retrieval_on+code_retrieval_on, streaming=False, state="complete")
history_node_list.extend([node[0] for node in d.get("related_nodes", [])])
history_node_list = list(set(history_node_list))
st.session_state['history_node_list'] = history_node_list
elif dialogue_mode == "知识库问答":
elif dialogue_mode == webui_configs["dialogue"]["mode"][1]:
history = get_messages_history(history_len)
chat_box.ai_say([
f"正在查询知识库 `{selected_kb}` ...",
Markdown("...", in_expander=True, title="知识库匹配结果"),
f"{webui_configs['chat']['chatbox_doc_querying']} `{selected_kb}` ...",
Markdown("...", in_expander=True, title=webui_configs['chat']['chatbox_doc_result']),
])
text = ""
d = {"docs": []}
@ -354,21 +398,21 @@ def dialogue_page(api: ApiRequest):
chat_box.update_msg(text, element_index=0)
# chat_box.update_msg("知识库匹配结果: \n\n".join(d["docs"]), element_index=1, streaming=False, state="complete")
chat_box.update_msg(text, element_index=0, streaming=False) # 更新最终的字符串,去除光标
chat_box.update_msg("知识库匹配结果: \n\n".join(d["docs"]), element_index=1, streaming=False, state="complete")
chat_box.update_msg("{webui_configs['chat']['chatbox_doc_result']}: \n\n".join(d["docs"]), element_index=1, streaming=False, state="complete")
# check whether the reply contains code, and provide editing and execution features
code_text = api.codebox.decode_code_from_text(text)
GLOBAL_EXE_CODE_TEXT = code_text
if code_text and code_exec_on:
codebox_res = api.codebox_chat("```"+code_text+"```", do_code_exe=True)
elif dialogue_mode == '代码知识库问答':
elif dialogue_mode == webui_configs["dialogue"]["mode"][2]:
logger.info('prompt={}'.format(prompt))
logger.info('history={}'.format(history))
if 'history_node_list' in st.session_state:
api.codeChat.history_node_list = st.session_state['history_node_list']
chat_box.ai_say([
f"正在查询代码知识库 `{selected_cb}` ...",
Markdown("...", in_expander=True, title="代码库匹配节点"),
f"{webui_configs['chat']['chatbox_code_querying']} `{selected_cb}` ...",
Markdown("...", in_expander=True, title=webui_configs['chat']['chatbox_code_result']),
])
text = ""
d = {"codes": []}
@ -393,14 +437,14 @@ def dialogue_page(api: ApiRequest):
# session state update
# st.session_state['history_node_list'] = api.codeChat.history_node_list
elif dialogue_mode == "搜索引擎问答":
elif dialogue_mode == webui_configs["dialogue"]["mode"][3]:
chat_box.ai_say([
f"正在执行 `{search_engine}` 搜索...",
Markdown("...", in_expander=True, title="网络搜索结果"),
webui_configs['chat']['chatbox_searching'],
Markdown("...", in_expander=True, title=webui_configs['chat']['chatbox_search_result']),
])
text = ""
d = {"docs": []}
for idx_count, d in enumerate(api.search_engine_chat(prompt, search_engine, se_top_k)):
for idx_count, d in enumerate(api.search_engine_chat(prompt, search_engine, se_top_k, history)):
if error_msg := check_error_msg(d): # check whether error occured
st.error(error_msg)
text += d["answer"]
@ -408,7 +452,7 @@ def dialogue_page(api: ApiRequest):
chat_box.update_msg(text, element_index=0)
# chat_box.update_msg("搜索匹配结果: \n\n".join(d["docs"]), element_index=1, streaming=False)
chat_box.update_msg(text, element_index=0, streaming=False) # 更新最终的字符串,去除光标
chat_box.update_msg("搜索匹配结果: \n\n".join(d["docs"]), element_index=1, streaming=False, state="complete")
chat_box.update_msg(f"{webui_configs['chat']['chatbox_search_result']}: \n\n".join(d["docs"]), element_index=1, streaming=False, state="complete")
# check whether the reply contains code, and provide editing and execution features
code_text = api.codebox.decode_code_from_text(text)
GLOBAL_EXE_CODE_TEXT = code_text
@ -420,33 +464,34 @@ def dialogue_page(api: ApiRequest):
st.experimental_rerun()
if code_interpreter_on:
with st.expander("代码编辑执行器", False):
code_part = st.text_area("代码片段", code_text, key="code_text")
with st.expander(webui_configs['sandbox']['expander_code_name'], False):
code_part = st.text_area(
webui_configs['sandbox']['textArea_code_name'], code_text, key="code_text")
cols = st.columns(2)
if cols[0].button(
"修改对话",
webui_configs['sandbox']['button_modify_code_name'],
use_container_width=True,
):
code_text = code_part
GLOBAL_EXE_CODE_TEXT = code_text
st.toast("修改对话成功")
st.toast(webui_configs['sandbox']['text_modify_code'])
if cols[1].button(
"执行代码",
webui_configs['sandbox']['button_exec_code_name'],
use_container_width=True
):
if code_text:
codebox_res = api.codebox_chat("```"+code_text+"```", do_code_exe=True)
st.toast("正在执行代码")
st.toast(webui_configs['sandbox']['text_execing_code'],)
else:
st.toast("code 不能为空")
st.toast(webui_configs['sandbox']['text_error_exec_code'],)
# TODO: this message gets recorded into the history
if codebox_res is not None and codebox_res.code_exe_status != 200:
st.toast(f"{codebox_res.code_exe_response}")
if codebox_res is not None and codebox_res.code_exe_status == 200:
st.toast(f"codebox_chajt {codebox_res}")
st.toast(f"codebox_chat {codebox_res}")
chat_box.ai_say(Markdown(code_text, in_expander=True, title="code interpreter", unsafe_allow_html=True), )
if codebox_res.code_exe_type == "image/png":
base_text = f"```\n{code_text}\n```\n\n"
@ -464,7 +509,7 @@ def dialogue_page(api: ApiRequest):
cols = st.columns(2)
export_btn = cols[0]
if cols[1].button(
"清空对话",
webui_configs['export']['button_clear_conversation_name'],
use_container_width=True,
):
chat_box.reset_history()
@ -474,9 +519,9 @@ def dialogue_page(api: ApiRequest):
st.experimental_rerun()
export_btn.download_button(
"导出记录",
webui_configs['export']['download_button_export_name'],
"".join(chat_box.export2md()),
file_name=f"{now:%Y-%m-%d %H.%M}_对话记录.md",
file_name=f"{now:%Y-%m-%d %H.%M}_conversations.md",
mime="text/markdown",
use_container_width=True,
)

View File

@ -355,7 +355,8 @@ class ApiRequest:
self,
query: str,
search_engine_name: str,
code_limit: int,
top_k: int,
history: List[Dict] = [],
stream: bool = True,
no_remote_api: bool = None,
):
@ -368,8 +369,8 @@ class ApiRequest:
data = {
"query": query,
"engine_name": search_engine_name,
"code_limit": code_limit,
"history": [],
"top_k": top_k,
"history": history,
"stream": stream,
}

View File

@ -0,0 +1,78 @@
# This is an example of webui
dialogue:
mode_instruction: Select a dialogue mode
mode:
- LLM Conversation
- Knowledge Base Q&A
- Code Knowledge Base Q&A
- Search Engine Q&A
- Agents Q&A
history_length: History of Dialogue Turns
text_mode_swtich: Switched to mode
text_knowledgeBase_swtich: Current Knowledge Base
text_loaded_kbase: Loaded Knowledge Base
text_loaded_cbase: Loaded Code Knowledge Base
# Knowledge Base Q&A
kbase_expander_name: 知识库配置
kbase_selectbox_name: 请选择知识库:
kbase_ninput_topk_name: 匹配知识条数:
kbase_ninput_score_threshold_name: 知识匹配分数阈值:
# Code Knowledge Base Q&A
cbase_expander_name: 代码知识库配置
cbase_selectbox_name: 请选择代码知识库:
cbase_ninput_topk_name: 匹配代码条数:
cbase_selectbox_type_name: 请选择查询模式:
cbase_search_type_v1:
- 基于 cypher
- 基于标签
- 基于描述
cbase_search_type_v2:
- 基于 cypher
- 基于标签
# Search Engine Q&A
expander_search_name: 搜索引擎配置
selectbox_search_name: 请选择搜索引擎
ninput_search_topk_name: 匹配搜索结果条数:
# Agents Q&A
phase_expander_name: Phase管理
phase_selectbox_name: 请选择待使用的执行链路
phase_toggle_detailed_name: 是否使用明细信息进行agent交互
phase_toggle_doToolUsing: 开启工具使用
phase_multiselect_tools: 请选择待使用的工具
phase_toggle_doSearch: 开启搜索增强
phase_toggle_doDocRetrieval: 开启知识库检索增强
phase_toggle_doCodeRetrieval: 开启代码检索增强
sandbox:
expander_name: 沙盒文件管理
file_upload_name: 上传沙盒文件
selectbox_name: 选择要处理文件
button_upload_name: 点击上传
button_download_name: 点击下载
button_delete_name: 点击删除
toggle_doCodeInterpreter: 开启代码解释器
toggle_doAutoCodeExec: 自动执行代码
expander_code_name: 代码编辑执行器
textArea_code_name: 代码片段
button_modify_code_name: 修改对话
text_modify_code: 修改对话成功
button_exec_code_name: 执行代码
text_execing_code: 正在执行代码
text_error_exec_code: code 不能为空
chat:
chat_placeholder: 请输入对话内容换行请使用Ctrl+Enter
chatbox_saying: 正在思考...
chatbox_doc_querying: 正在查询知识库
chatbox_code_querying: 正在查询代码知识库
chatbox_searching: 正在执行搜索
chatbox_search_result: 网络搜索结果
chatbox_doc_result: 知识库匹配结果
chatbox_code_result: 代码库匹配节点
export:
button_clear_conversation_name: 清空对话
download_button_export_name: 导出记录

View File

@ -0,0 +1,78 @@
# This is an example of webui
dialogue:
mode_instruction: 请选择对话模式
mode:
- LLM 对话
- 知识库问答
- 代码知识库问答
- 搜索引擎问答
- Agent问答
history_length: 历史对话轮数
text_mode_swtich: 已切换到模式
text_knowledgeBase_swtich: 当前知识库
text_loaded_kbase: 已加载知识库
text_loaded_cbase: 已加载代码知识库
# Knowledge Base Q&A
kbase_expander_name: 知识库配置
kbase_selectbox_name: 请选择知识库:
kbase_ninput_topk_name: 匹配知识条数:
kbase_ninput_score_threshold_name: 知识匹配分数阈值:
# Code Knowledge Base Q&A
cbase_expander_name: 代码知识库配置
cbase_selectbox_name: 请选择代码知识库:
cbase_ninput_topk_name: 匹配代码条数:
cbase_selectbox_type_name: 请选择查询模式:
cbase_search_type_v1:
- 基于 cypher
- 基于标签
- 基于描述
cbase_search_type_v2:
- 基于 cypher
- 基于标签
# Search Engine Q&A
expander_search_name: 搜索引擎配置
selectbox_search_name: 请选择搜索引擎
ninput_search_topk_name: 匹配搜索结果条数:
# Agents Q&A
phase_expander_name: Phase管理
phase_selectbox_name: 请选择待使用的执行链路
phase_toggle_detailed_name: 是否使用明细信息进行agent交互
phase_toggle_doToolUsing: 开启工具使用
phase_multiselect_tools: 请选择待使用的工具
phase_toggle_doSearch: 开启搜索增强
phase_toggle_doDocRetrieval: 开启知识库检索增强
phase_toggle_doCodeRetrieval: 开启代码检索增强
sandbox:
expander_name: 沙盒文件管理
file_upload_name: 上传沙盒文件
selectbox_name: 选择要处理文件
button_upload_name: 点击上传
button_download_name: 点击下载
button_delete_name: 点击删除
toggle_doCodeInterpreter: 开启代码解释器
toggle_doAutoCodeExec: 自动执行代码
expander_code_name: 代码编辑执行器
textArea_code_name: 代码片段
button_modify_code_name: 修改对话
text_modify_code: 修改对话成功
button_exec_code_name: 执行代码
text_execing_code: 正在执行代码
text_error_exec_code: code 不能为空
chat:
chat_placeholder: 请输入对话内容换行请使用Ctrl+Enter
chatbox_saying: 正在思考...
chatbox_doc_querying: 正在查询知识库
chatbox_code_querying: 正在查询代码知识库
chatbox_searching: 正在执行搜索
chatbox_search_result: 网络搜索结果
chatbox_doc_result: 知识库匹配结果
chatbox_code_result: 代码库匹配节点
export:
button_clear_conversation_name: 清空对话
download_button_export_name: 导出记录

View File

@ -0,0 +1,53 @@
import os, sys, requests
src_dir = os.path.join(
os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
)
sys.path.append(src_dir)
from dev_opsgpt.tools import (
toLangchainTools, get_tool_schema, DDGSTool, DocRetrieval,
TOOL_DICT, TOOL_SETS
)
from configs.model_config import *
from dev_opsgpt.connector.phase import BasePhase
from dev_opsgpt.connector.agents import BaseAgent
from dev_opsgpt.connector.chains import BaseChain
from dev_opsgpt.connector.schema import (
Message, Memory, load_role_configs, load_phase_configs, load_chain_configs
)
from dev_opsgpt.connector.configs import AGETN_CONFIGS, CHAIN_CONFIGS, PHASE_CONFIGS
import importlib
tools = toLangchainTools([TOOL_DICT[i] for i in TOOL_SETS if i in TOOL_DICT])
role_configs = load_role_configs(AGETN_CONFIGS)
chain_configs = load_chain_configs(CHAIN_CONFIGS)
phase_configs = load_phase_configs(PHASE_CONFIGS)
agent_module = importlib.import_module("dev_opsgpt.connector.agents")
phase_name = "baseGroupPhase"
phase = BasePhase(phase_name,
task = None,
phase_config = PHASE_CONFIGS,
chain_config = CHAIN_CONFIGS,
role_config = AGETN_CONFIGS,
do_summary=False,
do_code_retrieval=False,
do_doc_retrieval=True,
do_search=False,
)
# round-1
query_content = "确认本地是否存在employee_data.csv并查看它有哪些列和数据类型;然后画柱状图"
# query_content = "帮我确认下127.0.0.1这个服务器的在10点是否存在异常请帮我判断一下"
query = Message(
role_name="human", role_type="user", tools=tools,
role_content=query_content, input_query=query_content, origin_query=query_content,
)
output_message, _ = phase.step(query)

View File

@ -43,7 +43,7 @@ phase = BasePhase(phase_name,
)
# round-1
query_content = "确认本地是否存在employee_data.csv并查看它有哪些列和数据类型;然后画柱状图"
query_content = "确认本地是否存在book_data.csv并查看它有哪些列和数据类型;然后画柱状图"
query = Message(
role_name="human", role_type="user",
role_content=query_content, input_query=query_content, origin_query=query_content,

View File

@ -0,0 +1,236 @@
import os, sys, requests
src_dir = os.path.join(
os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
)
sys.path.append(src_dir)
from configs.model_config import *
from dev_opsgpt.connector.phase import BasePhase
from dev_opsgpt.connector.agents import BaseAgent
from dev_opsgpt.connector.chains import BaseChain
from dev_opsgpt.connector.schema import (
Message, Memory, load_role_configs, load_phase_configs, load_chain_configs
)
from dev_opsgpt.connector.configs import AGETN_CONFIGS, CHAIN_CONFIGS, PHASE_CONFIGS
from dev_opsgpt.connector.utils import parse_section
import importlib
# update new agent configs
codeRetrievalJudger_PROMPT = """#### CodeRetrievalJudger Assistance Guidance
Given the user's question and respective code, you need to decide whether the provided codes are enough to answer the question.
#### Input Format
**Origin Query:** the initial question or objective that the user wanted to achieve
**Retrieval Codes:** the Retrieval Codes from the code base
#### Response Output Format
**REASON:** Justify the decision of choosing 'finished' and 'continued' by evaluating the progress step by step.
"""
# TODO: move the lines below into the prompt above so the model decides whether to stop
# **Action Status:** Set to 'finished' or 'continued'.
# If it's 'finished', the provided codes can answer the origin query.
# If it's 'continued', the origin query cannot be answered well from the provided code.
codeRetrievalDivergent_PROMPT = """#### CodeRetrievalDivergen Assistance Guidance
You are a assistant that helps to determine which code package is needed to answer the question.
Given the user's question, Retrieval code, and the code Packages related to Retrieval code. you need to decide which code package we need to read to better answer the question.
#### Input Format
**Origin Query:** the initial question or objective that the user wanted to achieve
**Retrieval Codes:** the Retrieval Codes from the code base
**Code Packages:** the code packages related to Retrieval code
#### Response Output Format
**Code Package:** Identify another Code Package from the Code Packages that should be read to provide a better answer to the Origin Query.
**REASON:** Justify the decision of choosing 'finished' and 'continued' by evaluating the progress step by step.
"""
AGETN_CONFIGS.update({
"codeRetrievalJudger": {
"role": {
"role_prompt": codeRetrievalJudger_PROMPT,
"role_type": "assistant",
"role_name": "codeRetrievalJudger",
"role_desc": "",
"agent_type": "CodeRetrievalJudger"
# "agent_type": "BaseAgent"
},
"chat_turn": 1,
"focus_agents": [],
"focus_message_keys": [],
},
"codeRetrievalDivergent": {
"role": {
"role_prompt": codeRetrievalDivergent_PROMPT,
"role_type": "assistant",
"role_name": "codeRetrievalDivergent",
"role_desc": "",
"agent_type": "CodeRetrievalDivergent"
# "agent_type": "BaseAgent"
},
"chat_turn": 1,
"focus_agents": [],
"focus_message_keys": [],
},
})
# update new chain configs
CHAIN_CONFIGS.update({
"codeRetrievalChain": {
"chain_name": "codeRetrievalChain",
"chain_type": "BaseChain",
"agents": ["codeRetrievalJudger", "codeRetrievalDivergent"],
"chat_turn": 5,
"do_checker": False,
"chain_prompt": ""
}
})
# update phase configs
PHASE_CONFIGS.update({
"codeRetrievalPhase": {
"phase_name": "codeRetrievalPhase",
"phase_type": "BasePhase",
"chains": ["codeRetrievalChain"],
"do_summary": False,
"do_search": False,
"do_doc_retrieval": False,
"do_code_retrieval": True,
"do_tool_retrieval": False,
"do_using_tool": False
},
})
role_configs = load_role_configs(AGETN_CONFIGS)
chain_configs = load_chain_configs(CHAIN_CONFIGS)
phase_configs = load_phase_configs(PHASE_CONFIGS)
agent_module = importlib.import_module("dev_opsgpt.connector.agents")
from dev_opsgpt.tools import CodeRetrievalSingle, RelatedVerticesRetrival, Vertex2Code
# define a new agent class
class CodeRetrievalJudger(BaseAgent):
def start_action_step(self, message: Message) -> Message:
'''do action before agent predict '''
action_json = CodeRetrievalSingle.run(message.code_engine_name, message.origin_query)
message.customed_kargs["CodeRetrievalSingleRes"] = action_json
message.customed_kargs.setdefault("Retrieval_Codes", "")
message.customed_kargs["Retrieval_Codes"] += "\n" + action_json["code"]
return message
def create_prompt(
self, query: Message, memory: Memory =None, history: Memory = None, background: Memory = None, memory_pool: Memory=None, prompt_mamnger=None) -> str:
'''
prompt engineer, contains role\task\tools\docs\memory
'''
#
logger.debug(f"query: {query.customed_kargs}")
formatted_tools, tool_names, _ = self.create_tools_prompt(query)
prompt = self.role.role_prompt.format(**{"formatted_tools": formatted_tools, "tool_names": tool_names})
#
input_keys = parse_section(self.role.role_prompt, 'Input Format')
prompt += "\n#### Begin!!!\n"
#
for input_key in input_keys:
if input_key == "Origin Query":
prompt += f"\n**{input_key}:**\n" + query.origin_query
elif input_key == "Retrieval Codes":
prompt += f"\n**{input_key}:**\n" + query.customed_kargs["Retrieval_Codes"]
while "{{" in prompt or "}}" in prompt:
prompt = prompt.replace("{{", "{")
prompt = prompt.replace("}}", "}")
return prompt
# define a new agent class
class CodeRetrievalDivergent(BaseAgent):
def start_action_step(self, message: Message) -> Message:
'''do action before agent predict '''
action_json = RelatedVerticesRetrival.run(message.code_engine_name, message.customed_kargs["CodeRetrievalSingleRes"]["vertex"])
message.customed_kargs["RelatedVerticesRetrivalRes"] = action_json
return message
def end_action_step(self, message: Message) -> Message:
'''do action after agent predict'''
# logger.error(f"message: {message}")
# action_json = Vertex2Code.run(message.code_engine_name, "com.theokanning.openai.client#Utils.java") # message.parsed_output["Code_Filename"])
action_json = Vertex2Code.run(message.code_engine_name, message.parsed_output["Code Package"])
message.customed_kargs["Vertex2Code"] = action_json
message.customed_kargs.setdefault("Retrieval_Codes", "")
message.customed_kargs["Retrieval_Codes"] += "\n" + action_json["code"]
return message
def create_prompt(
self, query: Message, memory: Memory =None, history: Memory = None, background: Memory = None, memory_pool: Memory=None, prompt_mamnger=None) -> str:
'''
prompt engineer, contains role\task\tools\docs\memory
'''
formatted_tools, tool_names, _ = self.create_tools_prompt(query)
prompt = self.role.role_prompt.format(**{"formatted_tools": formatted_tools, "tool_names": tool_names})
#
input_query = query.input_query
input_keys = parse_section(self.role.role_prompt, 'Input Format')
prompt += "\n#### Begin!!!\n"
#
for input_key in input_keys:
if input_key == "Origin Query":
prompt += f"\n**{input_key}:**\n" + query.origin_query
elif input_key == "Retrieval Codes":
prompt += f"\n**{input_key}:**\n" + query.customed_kargs["Retrieval_Codes"]
elif input_key == "Code Packages":
vertices = query.customed_kargs["RelatedVerticesRetrivalRes"]["vertices"]
prompt += f"\n**{input_key}:**\n" + ", ".join([str(v) for v in vertices])
while "{{" in prompt or "}}" in prompt:
prompt = prompt.replace("{{", "{")
prompt = prompt.replace("}}", "}")
return prompt
setattr(agent_module, 'CodeRetrievalJudger', CodeRetrievalJudger)
setattr(agent_module, 'CodeRetrievalDivergent', CodeRetrievalDivergent)
#
phase_name = "codeRetrievalPhase"
phase = BasePhase(phase_name,
task = None,
phase_config = PHASE_CONFIGS,
chain_config = CHAIN_CONFIGS,
role_config = AGETN_CONFIGS,
do_summary=False,
do_code_retrieval=False,
do_doc_retrieval=False,
do_search=False,
)
# round-1
query_content = "remove 这个函数是用来做什么的"
query = Message(
role_name="user", role_type="human",
role_content=query_content, input_query=query_content, origin_query=query_content,
code_engine_name="client", score_threshold=1.0, top_k=3, cb_search_type="cypher"
)
output_message1, _ = phase.step(query)
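# Quick way to inspect the round-1 result (a sketch: the field name follows the
# Message usage above; depending on your Message version the answer may live in
# parsed_output instead of role_content):
print(output_message1.role_content)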

View File

@ -28,7 +28,7 @@ print(src_dir)
# tools = toLangchainTools([TOOL_DICT[i] for i in TOOL_SETS if i in TOOL_DICT])
TOOL_SETS = [
"StockInfo", "StockName"
"StockName", "StockInfo",
]
tools = toLangchainTools([TOOL_DICT[i] for i in TOOL_SETS if i in TOOL_DICT])
@ -52,7 +52,7 @@ phase = BasePhase(phase_name,
do_search=False,
)
query_content = "查询贵州茅台的股票代码,并查询截止到当前日期(2023年11月8日)的最近10天的每日时序数据然后对时序数据画出折线图并分析"
query_content = "查询贵州茅台的股票代码,并查询截止到当前日期(2023年12月24日)的最近10天的每日时序数据然后对时序数据画出折线图并分析"
query = Message(role_name="human", role_type="user", input_query=query_content, role_content=query_content, origin_query=query_content, tools=tools)

View File

@ -27,47 +27,25 @@ class GptqConfig:
def load_quant_by_autogptq(model):
# qwen-72b-int4 use these code
from modelscope import AutoTokenizer, AutoModelForCausalLM
# Note: The default behavior now has injection attack prevention off.
tokenizer = AutoTokenizer.from_pretrained(model, revision='master', trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
model, device_map="auto",
trust_remote_code=True
).eval()
return model, tokenizer
# codellama-34b-int4 use these code
# from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
# tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False, trust_remote_code=True)
# model = AutoGPTQForCausalLM.from_quantized(model, inject_fused_attention=False,trust_remote_code=True,
# inject_fused_mlp=False,use_cuda_fp16=True,disable_exllama=False,device_map='auto')
# return model, tokenizer
from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
model = AutoGPTQForCausalLM.from_quantized(model,
inject_fused_attention=False,
inject_fused_mlp=False,
use_cuda_fp16=True,
disable_exllama=False,
device_map='auto'
)
return model
def load_gptq_quantized(model_name, gptq_config: GptqConfig):
print("Loading GPTQ quantized model...")
if gptq_config.act_order:
try:
script_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
module_path = os.path.join(script_path, "../repositories/GPTQ-for-LLaMa")
sys.path.insert(0, module_path)
from llama import load_quant
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
# only `fastest-inference-4bit` branch cares about `act_order`
model = load_quant(
model_name,
find_gptq_ckpt(gptq_config),
gptq_config.wbits,
gptq_config.groupsize,
act_order=gptq_config.act_order,
)
except ImportError as e:
print(f"Error: Failed to load GPTQ-for-LLaMa. {e}")
print("See https://github.com/lm-sys/FastChat/blob/main/docs/gptq.md")
sys.exit(-1)
else:
# other branches
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
model = load_quant_by_autogptq(model_name)
model, tokenizer = load_quant_by_autogptq(model_name)
return model, tokenizer
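# Minimal usage sketch for the loader above. The field names (wbits / groupsize /
# act_order) follow the references in load_gptq_quantized; the config object and
# the model path below are assumptions for illustration only.
# from types import SimpleNamespace
# demo_cfg = SimpleNamespace(wbits=4, groupsize=128, act_order=False)
# # act_order=False takes the auto-gptq branch and skips the GPTQ-for-LLaMa repo:
# model, tokenizer = load_gptq_quantized("Qwen-72B-Chat-Int4", demo_cfg)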

View File

@ -1,248 +1,248 @@
import docker, sys, os, time, requests, psutil
import subprocess
from docker.types import Mount, DeviceRequest
from loguru import logger
src_dir = os.path.join(
os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
)
sys.path.append(src_dir)
from configs.model_config import USE_FASTCHAT, JUPYTER_WORK_PATH
from configs.server_config import (
NO_REMOTE_API, SANDBOX_SERVER, SANDBOX_IMAGE_NAME, SANDBOX_CONTRAINER_NAME,
WEBUI_SERVER, API_SERVER, SDFILE_API_SERVER, CONTRAINER_NAME, IMAGE_NAME, DOCKER_SERVICE,
DEFAULT_BIND_HOST, NEBULA_GRAPH_SERVER
)
import platform
system_name = platform.system()
USE_TTY = system_name in ["Windows"]
def check_process(content: str, lang: str = None, do_stop=False):
'''process-not-exist is true, process-exist is false'''
for process in psutil.process_iter(["pid", "name", "cmdline"]):
# check process name contains "jupyter" and port=xx
# if f"port={SANDBOX_SERVER['port']}" in str(process.info["cmdline"]).lower() and \
# "jupyter" in process.info['name'].lower():
if content in str(process.info["cmdline"]).lower():
logger.info(f"content, {process.info}")
# stop the process
if do_stop:
process.terminate()
return True
return False
return True
def check_docker(client, container_name, do_stop=False):
'''container-not-exist is true, container-exist is false'''
if client is None: return True
for i in client.containers.list(all=True):
if i.name == container_name:
if do_stop:
container = i
if container_name == CONTRAINER_NAME and i.status == 'running':
# wrap up db
logger.info(f'inside {container_name}')
# cp nebula data
res = container.exec_run('''sh chatbot/dev_opsgpt/utils/nebula_cp.sh''')
logger.info(f'cp res={res}')
# stop nebula service
res = container.exec_run('''/usr/local/nebula/scripts/nebula.service stop all''')
logger.info(f'stop res={res}')
container.stop()
container.remove()
return True
return False
return True
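# Note on the boolean convention above: check_process() and check_docker() both
# return True when the target does NOT exist (optionally stopping it first when
# do_stop=True), so callers treat them as "safe to start?" guards, e.g.:
# if check_process("service/api.py"):        # no api.py process running
#     subprocess.Popen(api_sh, shell=True)   # -> start it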
def start_docker(client, script_shs, ports, image_name, container_name, mounts=None, network=None):
container = client.containers.run(
image=image_name,
command="bash",
mounts=mounts,
name=container_name,
mem_limit="8g",
# device_requests=[DeviceRequest(count=-1, capabilities=[['gpu']])],
# network_mode="host",
ports=ports,
stdin_open=True,
detach=True,
tty=USE_TTY,
network=network,
)
logger.info(f"docker id: {container.id[:10]}")
# run the start-up scripts (e.g. notebook, llm_api) inside the container
for script_sh in script_shs:
if USE_FASTCHAT and "llm_api" in script_sh:
logger.debug(script_sh)
response = container.exec_run(["sh", "-c", script_sh])
logger.debug(response)
elif "llm_api" not in script_sh:
logger.debug(script_sh)
response = container.exec_run(["sh", "-c", script_sh])
logger.debug(response)
return container
#########################################
############# start services ###############
#########################################
network_name ='my_network'
def start_sandbox_service(network_name ='my_network'):
# networks = client.networks.list()
# if any([network_name==i.attrs["Name"] for i in networks]):
# network = client.networks.get(network_name)
# else:
# network = client.networks.create('my_network', driver='bridge')
mount = Mount(
type='bind',
source=os.path.join(src_dir, "jupyter_work"),
target='/home/user/chatbot/jupyter_work',
read_only=False # set to True if read-only access is needed
)
mounts = [mount]
# the sandbox is started independently of the other services
if SANDBOX_SERVER["do_remote"]:
client = docker.from_env()
# start the container
logger.info("start container sandbox service")
script_shs = ["bash jupyter_start.sh"]
JUPYTER_WORK_PATH = "/home/user/chatbot/jupyter_work"
script_shs = [f"cd /home/user/chatbot/jupyter_work && nohup jupyter-notebook --NotebookApp.token=mytoken --port=5050 --allow-root --ip=0.0.0.0 --notebook-dir={JUPYTER_WORK_PATH} --no-browser --ServerApp.disable_check_xsrf=True &"]
ports = {f"{SANDBOX_SERVER['docker_port']}/tcp": f"{SANDBOX_SERVER['port']}/tcp"}
if check_docker(client, SANDBOX_CONTRAINER_NAME, do_stop=True, ):
container = start_docker(client, script_shs, ports, SANDBOX_IMAGE_NAME, SANDBOX_CONTRAINER_NAME, mounts=mounts, network=network_name)
# check whether the notebook has started
time.sleep(5)
retry_nums = 3
while retry_nums>0:
logger.info(f"http://localhost:{SANDBOX_SERVER['port']}")
response = requests.get(f"http://localhost:{SANDBOX_SERVER['port']}", timeout=270)
if response.status_code == 200:
logger.info("container & notebook init success")
break
else:
retry_nums -= 1
logger.info(client.containers.list())
logger.info("wait container running ...")
time.sleep(5)
else:
try:
client = docker.from_env()
except:
client = None
check_docker(client, SANDBOX_CONTRAINER_NAME, do_stop=True, )
logger.info("start local sandbox service")
def start_api_service(sandbox_host=DEFAULT_BIND_HOST):
# start the service container
if DOCKER_SERVICE:
client = docker.from_env()
logger.info("start container service")
check_process("service/api.py", do_stop=True)
check_process("service/sdfile_api.py", do_stop=True)
check_process("service/sdfile_api.py", do_stop=True)
check_process("webui.py", do_stop=True)
mount = Mount(
type='bind',
source=src_dir,
target='/home/user/chatbot/',
read_only=False # set to True if read-only access is needed
)
mount_database = Mount(
type='bind',
source=os.path.join(src_dir, "knowledge_base"),
target='/home/user/knowledge_base/',
read_only=False # set to True if read-only access is needed
)
mount_code_database = Mount(
type='bind',
source=os.path.join(src_dir, "code_base"),
target='/home/user/code_base/',
read_only=False # set to True if read-only access is needed
)
ports={
f"{API_SERVER['docker_port']}/tcp": f"{API_SERVER['port']}/tcp",
f"{WEBUI_SERVER['docker_port']}/tcp": f"{WEBUI_SERVER['port']}/tcp",
f"{SDFILE_API_SERVER['docker_port']}/tcp": f"{SDFILE_API_SERVER['port']}/tcp",
f"{NEBULA_GRAPH_SERVER['docker_port']}/tcp": f"{NEBULA_GRAPH_SERVER['port']}/tcp"
}
mounts = [mount, mount_database, mount_code_database]
script_shs = [
"mkdir -p /home/user/logs",
'''
if [ -d "/home/user/chatbot/data/nebula_data/data/meta" ]; then
cp -r /home/user/chatbot/data/nebula_data/data /usr/local/nebula/
fi
''',
"/usr/local/nebula/scripts/nebula.service start all",
"/usr/local/nebula/scripts/nebula.service status all",
"sleep 2",
'''curl -X PUT -H "Content-Type: application/json" -d'{"heartbeat_interval_secs":"2"}' -s "http://127.0.0.1:19559/flags"''',
'''curl -X PUT -H "Content-Type: application/json" -d'{"heartbeat_interval_secs":"2"}' -s "http://127.0.0.1:19669/flags"''',
'''curl -X PUT -H "Content-Type: application/json" -d'{"heartbeat_interval_secs":"2"}' -s "http://127.0.0.1:19779/flags"''',
"nohup python chatbot/dev_opsgpt/service/sdfile_api.py > /home/user/logs/sdfile_api.log 2>&1 &",
f"export DUCKDUCKGO_PROXY=socks5://host.docker.internal:13659 && export SANDBOX_HOST={sandbox_host} &&\
nohup python chatbot/dev_opsgpt/service/api.py > /home/user/logs/api.log 2>&1 &",
"nohup python chatbot/dev_opsgpt/service/llm_api.py > /home/user/ 2>&1 &",
f"export DUCKDUCKGO_PROXY=socks5://host.docker.internal:13659 && export SANDBOX_HOST={sandbox_host} &&\
cd chatbot/examples && nohup streamlit run webui.py > /home/user/logs/start_webui.log 2>&1 &"
]
if check_docker(client, CONTRAINER_NAME, do_stop=True):
container = start_docker(client, script_shs, ports, IMAGE_NAME, CONTRAINER_NAME, mounts, network=network_name)
else:
logger.info("start local service")
# stop any previously started docker services
# check_docker(client, CONTRAINER_NAME, do_stop=True, )
api_sh = "nohup python ../dev_opsgpt/service/api.py > ../logs/api.log 2>&1 &"
sdfile_sh = "nohup python ../dev_opsgpt/service/sdfile_api.py > ../logs/sdfile_api.log 2>&1 &"
notebook_sh = f"nohup jupyter-notebook --NotebookApp.token=mytoken --port={SANDBOX_SERVER['port']} --allow-root --ip=0.0.0.0 --notebook-dir={JUPYTER_WORK_PATH} --no-browser --ServerApp.disable_check_xsrf=True > ../logs/sandbox.log 2>&1 &"
llm_sh = "nohup python ../dev_opsgpt/service/llm_api.py > ../logs/llm_api.log 2>&1 &"
webui_sh = "streamlit run webui.py" if USE_TTY else "streamlit run webui.py"
if check_process("jupyter-notebook --NotebookApp"):
logger.debug(f"{notebook_sh}")
subprocess.Popen(notebook_sh, shell=True)
#
if not NO_REMOTE_API and check_process("service/api.py"):
subprocess.Popen(api_sh, shell=True)
#
if USE_FASTCHAT and check_process("service/llm_api.py"):
subprocess.Popen(llm_sh, shell=True)
#
if check_process("service/sdfile_api.py"):
subprocess.Popen(sdfile_sh, shell=True)
subprocess.Popen(webui_sh, shell=True)
if __name__ == "__main__":
start_sandbox_service()
sandbox_host = DEFAULT_BIND_HOST
if SANDBOX_SERVER["do_remote"]:
client = docker.from_env()
containers = client.containers.list(all=True)
for container in containers:
container_a_info = client.containers.get(container.id)
if container_a_info.name == SANDBOX_CONTRAINER_NAME:
container1_networks = container.attrs['NetworkSettings']['Networks']
sandbox_host = container1_networks.get(network_name)["IPAddress"]
break
start_api_service(sandbox_host)

View File

@ -27,19 +27,19 @@ api = ApiRequest(base_url="http://127.0.0.1:7861", no_remote_api=NO_REMOTE_API)
if __name__ == "__main__":
st.set_page_config(
"DevOpsGPT-Chat WebUI",
"CodeFuse-ChatBot WebUI",
os.path.join("../sources/imgs", "devops-chatbot.png"),
initial_sidebar_state="expanded",
menu_items={
'Get Help': 'https://github.com/lightislost/devopsgpt',
'Report a bug': "https://github.com/lightislost/devopsgpt/issues",
'About': f"""欢迎使用 DevOpsGPT-Chat WebUI {VERSION}"""
'Get Help': 'https://github.com/codefuse-ai/codefuse-chatbot',
'Report a bug': "https://github.com/codefuse-ai/codefuse-chatbot/issues",
'About': f"""欢迎使用 CodeFuse-ChatBot WebUI {VERSION}"""
}
)
if not chat_box.chat_inited:
st.toast(
f"欢迎使用 [`DevOpsGPT-Chat`](https://github.com/lightislost/devopsgpt) ! \n\n"
f"欢迎使用 [`CodeFuse-ChatBot`](https://github.com/codefuse-ai/codefuse-chatbot) ! \n\n"
f"当前使用模型`{LLM_MODEL}`, 您可以开始提问了."
)
@ -71,7 +71,7 @@ if __name__ == "__main__":
use_column_width=True
)
st.caption(
f"""<p align="right">当前版本:{VERSION}</p>""",
f"""<p align="right"> CodeFuse-ChatBot 当前版本:{VERSION}</p>""",
unsafe_allow_html=True,
)
options = list(pages)

View File

@ -16,7 +16,7 @@ faiss-cpu
nltk
loguru
pypdf
duckduckgo-search
duckduckgo-search==3.9.11
pysocks
accelerate
docker
@ -39,7 +39,7 @@ streamlit-option-menu>=0.3.6
streamlit-antd-components>=0.1.11
streamlit-chatbox>=1.1.6
streamlit-aggrid>=0.3.4.post3
httpx~=0.24.1
httpx
javalang==0.13.0
jsonref==1.1.0
@ -51,4 +51,9 @@ nebula3-python==3.1.0
protobuf==3.20.*
transformers_stream_generator
einops
auto-gptq
optimum
modelscope
# vllm model
vllm==0.2.2; sys_platform == "linux"

View File

@ -22,15 +22,30 @@ if __name__ == "__main__":
# chat = ChatOpenAI(temperature=0.1, model_name="gpt-3.5-turbo")
# print(chat.predict("hi!"))
print(LLM_MODEL, llm_model_dict[LLM_MODEL]["api_key"], llm_model_dict[LLM_MODEL]["api_base_url"])
model = ChatOpenAI(
streaming=True,
verbose=True,
openai_api_key=llm_model_dict[LLM_MODEL]["api_key"],
openai_api_base=llm_model_dict[LLM_MODEL]["api_base_url"],
model_name=LLM_MODEL
# print(LLM_MODEL, llm_model_dict[LLM_MODEL]["api_key"], llm_model_dict[LLM_MODEL]["api_base_url"])
# model = ChatOpenAI(
# streaming=True,
# verbose=True,
# openai_api_key=llm_model_dict[LLM_MODEL]["api_key"],
# openai_api_base=llm_model_dict[LLM_MODEL]["api_base_url"],
# model_name=LLM_MODEL
# )
# chat_prompt = ChatPromptTemplate.from_messages([("human", "{input}")])
# chain = LLMChain(prompt=chat_prompt, llm=model)
# content = chain({"input": "hello"})
# print(content)
import openai
# openai.api_key = "EMPTY" # Not supported yet
openai.api_base = "http://127.0.0.1:8888/v1"
model = "example"
# create a chat completion
completion = openai.ChatCompletion.create(
model=model,
messages=[{"role": "user", "content": "Hello! What is your name? "}],
max_tokens=100,
)
chat_prompt = ChatPromptTemplate.from_messages([("human", "{input}")])
chain = LLMChain(prompt=chat_prompt, llm=model)
content = chain({"input": "hello"})
print(content)
# print the completion
print(completion.choices[0].message.content)
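# The same local fastchat endpoint can also be reached through langchain's
# ChatOpenAI wrapper, mirroring the commented-out block above ("EMPTY" is the
# usual fastchat api_key placeholder; adjust host/port to your deployment):
# from langchain.chat_models import ChatOpenAI
# chat = ChatOpenAI(
#     openai_api_key="EMPTY",
#     openai_api_base="http://127.0.0.1:8888/v1",
#     model_name="example",
# )
# print(chat.predict("hi!"))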