import os, sys, requests

src_dir = os.path.join(
    os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
)
sys.path.append(src_dir)

from configs.model_config import *
from dev_opsgpt.connector.phase import BasePhase
from dev_opsgpt.connector.agents import BaseAgent
from dev_opsgpt.connector.chains import BaseChain
from dev_opsgpt.connector.schema import (
    Message, Memory, load_role_configs, load_phase_configs, load_chain_configs
)
from dev_opsgpt.connector.configs import AGETN_CONFIGS, CHAIN_CONFIGS, PHASE_CONFIGS
from dev_opsgpt.connector.utils import parse_section
import importlib

# logger is used by the create_prompt overrides below; imported explicitly in case
# the star import from configs.model_config does not expose it
from loguru import logger

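
# Overview: this script defines prompts and configs for two custom agents
# (codeRetrievalJudger, codeRetrievalDivergent), registers their classes on
# dev_opsgpt.connector.agents, wires them into a chain and a phase, and then
# runs one round of code question answering against the "client" code base.
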
# update new agent configs
codeRetrievalJudger_PROMPT = """#### CodeRetrievalJudger Assistance Guidance

Given the user's question and the respective code, you need to decide whether the provided codes are enough to answer the question.

#### Input Format

**Origin Query:** the initial question or objective that the user wanted to achieve

**Retrieval Codes:** the Retrieval Codes from the code base

#### Response Output Format

**REASON:** Justify the decision of choosing 'finished' and 'continued' by evaluating the progress step by step.
"""

# To let the agent decide whether to stop, move the lines below back into the
# prompt above; they define the 'finished' / 'continued' action status:
# **Action Status:** Set to 'finished' or 'continued'.
# If it's 'finished', the provided codes can answer the origin query.
# If it's 'continued', the origin query cannot be answered well from the provided code.

codeRetrievalDivergent_PROMPT = """#### CodeRetrievalDivergent Assistance Guidance

You are an assistant that helps to determine which code package is needed to answer the question.

Given the user's question, the Retrieval Codes, and the Code Packages related to the Retrieval Codes, you need to decide which code package should be read to better answer the question.

#### Input Format

**Origin Query:** the initial question or objective that the user wanted to achieve

**Retrieval Codes:** the Retrieval Codes from the code base

**Code Packages:** the code packages related to the Retrieval Codes

#### Response Output Format

**Code Package:** Identify another Code Package from the Code Packages that should be read to provide a better answer to the Origin Query.

**REASON:** Justify the choice of the Code Package by evaluating the progress step by step.
"""

AGETN_CONFIGS.update({
    "codeRetrievalJudger": {
        "role": {
            "role_prompt": codeRetrievalJudger_PROMPT,
            "role_type": "assistant",
            "role_name": "codeRetrievalJudger",
            "role_desc": "",
            "agent_type": "CodeRetrievalJudger"
            # "agent_type": "BaseAgent"
        },
        "chat_turn": 1,
        "focus_agents": [],
        "focus_message_keys": [],
    },
    "codeRetrievalDivergent": {
        "role": {
            "role_prompt": codeRetrievalDivergent_PROMPT,
            "role_type": "assistant",
            "role_name": "codeRetrievalDivergent",
            "role_desc": "",
            "agent_type": "CodeRetrievalDivergent"
            # "agent_type": "BaseAgent"
        },
        "chat_turn": 1,
        "focus_agents": [],
        "focus_message_keys": [],
    },
})
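
# "agent_type" names the custom agent classes implemented and registered on
# dev_opsgpt.connector.agents later in this script (see the setattr calls below).
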
# update new chain configs
CHAIN_CONFIGS.update({
    "codeRetrievalChain": {
        "chain_name": "codeRetrievalChain",
        "chain_type": "BaseChain",
        "agents": ["codeRetrievalJudger", "codeRetrievalDivergent"],
        "chat_turn": 5,
        "do_checker": False,
        "chain_prompt": ""
    }
})
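
# The two agents run in sequence inside this chain; with "chat_turn": 5 the pair is
# presumably re-run for up to five rounds, each round appending newly retrieved code
# to the shared "Retrieval_Codes" context (see the agent classes below).
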
# update phase configs
PHASE_CONFIGS.update({
    "codeRetrievalPhase": {
        "phase_name": "codeRetrievalPhase",
        "phase_type": "BasePhase",
        "chains": ["codeRetrievalChain"],
        "do_summary": False,
        "do_search": False,
        "do_doc_retrieval": False,
        "do_code_retrieval": True,
        "do_tool_retrieval": False,
        "do_using_tool": False
    },
})

role_configs = load_role_configs(AGETN_CONFIGS)
chain_configs = load_chain_configs(CHAIN_CONFIGS)
phase_configs = load_phase_configs(PHASE_CONFIGS)

agent_module = importlib.import_module("dev_opsgpt.connector.agents")

from dev_opsgpt.tools import CodeRetrievalSingle, RelatedVerticesRetrival, Vertex2Code

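
# As used below: CodeRetrievalSingle fetches the code most relevant to a query,
# RelatedVerticesRetrival lists the vertices (packages) related to that code, and
# Vertex2Code resolves a chosen vertex back into source code.
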
# Define a new agent class
class CodeRetrievalJudger(BaseAgent):

    def start_action_step(self, message: Message) -> Message:
        '''do action before agent predict'''
        action_json = CodeRetrievalSingle.run(message.code_engine_name, message.origin_query)
        message.customed_kargs["CodeRetrievalSingleRes"] = action_json
        message.customed_kargs.setdefault("Retrieval_Codes", "")
        message.customed_kargs["Retrieval_Codes"] += "\n" + action_json["code"]
        return message

    def create_prompt(
            self, query: Message, memory: Memory = None, history: Memory = None, background: Memory = None, memory_pool: Memory = None, prompt_mamnger=None) -> str:
        '''
        prompt engineer, contains role/task/tools/docs/memory
        '''
        # log the retrieval context and build the base prompt from the role template
        logger.debug(f"query: {query.customed_kargs}")
        formatted_tools, tool_names, _ = self.create_tools_prompt(query)
        prompt = self.role.role_prompt.format(**{"formatted_tools": formatted_tools, "tool_names": tool_names})
        # collect the field names declared under '#### Input Format'
        input_keys = parse_section(self.role.role_prompt, 'Input Format')
        prompt += "\n#### Begin!!!\n"
        # fill in each declared input section from the query
        for input_key in input_keys:
            if input_key == "Origin Query":
                prompt += f"\n**{input_key}:**\n" + query.origin_query
            elif input_key == "Retrieval Codes":
                prompt += f"\n**{input_key}:**\n" + query.customed_kargs["Retrieval_Codes"]

        # collapse doubled braces left over from str.format back to single braces
        while "{{" in prompt or "}}" in prompt:
            prompt = prompt.replace("{{", "{")
            prompt = prompt.replace("}}", "}")
        return prompt

# Define a new agent class
class CodeRetrievalDivergent(BaseAgent):

    def start_action_step(self, message: Message) -> Message:
        '''do action before agent predict'''
        action_json = RelatedVerticesRetrival.run(message.code_engine_name, message.customed_kargs["CodeRetrievalSingleRes"]["vertex"])
        message.customed_kargs["RelatedVerticesRetrivalRes"] = action_json
        return message

    def end_action_step(self, message: Message) -> Message:
        '''do action after agent predict'''
        # logger.error(f"message: {message}")
        # action_json = Vertex2Code.run(message.code_engine_name, "com.theokanning.openai.client#Utils.java")  # message.parsed_output["Code_Filename"])
        action_json = Vertex2Code.run(message.code_engine_name, message.parsed_output["Code Package"])
        message.customed_kargs["Vertex2Code"] = action_json
        message.customed_kargs.setdefault("Retrieval_Codes", "")
        message.customed_kargs["Retrieval_Codes"] += "\n" + action_json["code"]
        return message

    def create_prompt(
            self, query: Message, memory: Memory = None, history: Memory = None, background: Memory = None, memory_pool: Memory = None, prompt_mamnger=None) -> str:
        '''
        prompt engineer, contains role/task/tools/docs/memory
        '''
        formatted_tools, tool_names, _ = self.create_tools_prompt(query)
        prompt = self.role.role_prompt.format(**{"formatted_tools": formatted_tools, "tool_names": tool_names})
        # collect the field names declared under '#### Input Format'
        input_query = query.input_query
        input_keys = parse_section(self.role.role_prompt, 'Input Format')
        prompt += "\n#### Begin!!!\n"
        # fill in each declared input section from the query
        for input_key in input_keys:
            if input_key == "Origin Query":
                prompt += f"\n**{input_key}:**\n" + query.origin_query
            elif input_key == "Retrieval Codes":
                prompt += f"\n**{input_key}:**\n" + query.customed_kargs["Retrieval_Codes"]
            elif input_key == "Code Packages":
                vertices = query.customed_kargs["RelatedVerticesRetrivalRes"]["vertices"]
                prompt += f"\n**{input_key}:**\n" + ", ".join([str(v) for v in vertices])

        # collapse doubled braces left over from str.format back to single braces
        while "{{" in prompt or "}}" in prompt:
            prompt = prompt.replace("{{", "{")
            prompt = prompt.replace("}}", "}")
        return prompt

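
# Across chain rounds the two agents keep appending newly retrieved code to
# message.customed_kargs["Retrieval_Codes"], so each new prompt sees all of the
# code gathered so far.
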
setattr(agent_module, 'CodeRetrievalJudger', CodeRetrievalJudger)
setattr(agent_module, 'CodeRetrievalDivergent', CodeRetrievalDivergent)

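
# Attaching the classes to dev_opsgpt.connector.agents makes the "agent_type"
# strings in AGETN_CONFIGS resolvable when the phase below instantiates its agents.
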
# build the phase from the updated configs
phase_name = "codeRetrievalPhase"
phase = BasePhase(phase_name,
                  task=None,
                  phase_config=PHASE_CONFIGS,
                  chain_config=CHAIN_CONFIGS,
                  role_config=AGETN_CONFIGS,
                  do_summary=False,
                  do_code_retrieval=False,
                  do_doc_retrieval=False,
                  do_search=False,
                  )

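
# BasePhase presumably looks up phase_name in phase_config and builds the
# codeRetrievalChain (and its two agents) from chain_config and role_config.
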
# round-1
query_content = "What is the remove function used for?"
query = Message(
    role_name="user", role_type="human",
    role_content=query_content, input_query=query_content, origin_query=query_content,
    code_engine_name="client", score_threshold=1.0, top_k=3, cb_search_type="cypher"
)

output_message1, _ = phase.step(query)
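
# Inspect the answer; role_content is assumed to carry the agent's reply, mirroring
# the field set on the query Message above.
print(output_message1.role_content)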