# 2023-12-26 11:41:53 +08:00
import os , sys , requests
# 2024-01-26 14:03:25 +08:00
from loguru import logger
# 2023-12-26 11:41:53 +08:00
src_dir = os . path . join (
os . path . dirname ( os . path . dirname ( os . path . dirname ( os . path . abspath ( __file__ ) ) ) )
)
sys . path . append ( src_dir )
# 2024-01-26 14:03:25 +08:00
from configs . model_config import KB_ROOT_PATH , JUPYTER_WORK_PATH
from configs . server_config import SANDBOX_SERVER
from coagent . tools import toLangchainTools , TOOL_DICT , TOOL_SETS
from coagent . llm_models . llm_config import EmbedConfig , LLMConfig
from coagent . connector . phase import BasePhase
from coagent . connector . agents import BaseAgent
from coagent . connector . chains import BaseChain
from coagent . connector . schema import (
Message , Memory , load_role_configs , load_phase_configs , load_chain_configs , ActionStatus
# 2023-12-26 11:41:53 +08:00
)
# 2024-01-26 14:03:25 +08:00
from coagent . connector . memory_manager import BaseMemoryManager
from coagent . connector . configs import AGETN_CONFIGS , CHAIN_CONFIGS , PHASE_CONFIGS , BASE_PROMPT_CONFIGS
from coagent . connector . utils import parse_section
from coagent . connector . prompt_manager import PromptManager
# 2023-12-26 11:41:53 +08:00
import importlib
# 2024-01-26 14:03:25 +08:00
from loguru import logger
# 2023-12-26 11:41:53 +08:00
# update new agent configs
# 2024-01-26 14:03:25 +08:00
# Prompt for the judger agent: given the query and the code retrieved so far,
# the LLM decides whether retrieval can stop ('finished') or must go on
# ('continued').
# Fix: stray VCS-timestamp residue lines had been merged INSIDE this string
# literal and would have been sent verbatim to the LLM; they are removed.
codeRetrievalJudger_PROMPT = """#### Agent Profile

Given the user's question and respective code, you need to decide whether the provided codes are enough to answer the question.

#### Input Format

**Origin Query:** the initial question or objective that the user wanted to achieve

**Retrieval Codes:** the Retrieval Codes from the code base

#### Response Output Format

**Action Status:** Set to 'finished' or 'continued'.
If it's 'finished', the provided codes can answer the origin query.
If it's 'continued', the origin query cannot be answered well from the provided code.

**REASON:** Justify the decision of choosing 'finished' and 'continued' by evaluating the progress step by step.
"""
# TODO: fold the lines below into the prompt above so the agent can
# judge whether to stop.
# 2024-01-26 14:03:25 +08:00
# **Action Status:** Set to 'finished' or 'continued'.
# If it's 'finished', the provided codes can answer the origin query.
# If it's 'continued', the origin query cannot be answered well from the provided code.
# 2024-01-26 14:03:25 +08:00
# Prompt for the divergent agent: picks the next code package to read from the
# packages related to the code retrieved so far.
# Fixes: stray VCS-timestamp residue lines removed from inside the string
# literal; "a assistant" -> "an assistant".
# NOTE(review): the final **REASON** line still talks about choosing
# 'finished'/'continued' — looks copy-pasted from the judger prompt; the
# choice here is a Code Package. Confirm intended wording before changing it.
codeRetrievalDivergent_PROMPT = """#### Agent Profile

You are an assistant that helps to determine which code package is needed to answer the question.

Given the user's question, Retrieval code, and the code Packages related to Retrieval code. you need to decide which code package we need to read to better answer the question.

#### Input Format

**Origin Query:** the initial question or objective that the user wanted to achieve

**Retrieval Codes:** the Retrieval Codes from the code base

**Code Packages:** the code packages related to Retrieval code

#### Response Output Format

**Code Package:** Identify another Code Package from the Code Packages that can most help to provide a better answer to the Origin Query, only put one name of the code package here.

**REASON:** Justify the decision of choosing 'finished' and 'continued' by evaluating the progress step by step.
"""
# 2024-01-26 14:03:25 +08:00
class CodeRetrievalPM(PromptManager):
    """Prompt manager shared by the judger and divergent agents.

    Supplies the ``code_packages`` and ``retrieval_codes`` prompt fields from
    the previous agent's message.  Fix: whitespace-mangled string literals
    (e.g. ``' previous_agent_message '``, ``" , "``) restored to their
    intended values.
    """

    def handle_code_packages(self, **kwargs) -> str:
        """Render the related code packages as a comma-separated string."""
        if 'previous_agent_message' not in kwargs:
            return ""
        previous_agent_message: Message = kwargs['previous_agent_message']
        # Both agents share this manager, so read the vertices defensively:
        # the judger's messages do not carry RelatedVerticesRetrivalRes.
        vertices = previous_agent_message.customed_kargs.get(
            "RelatedVerticesRetrivalRes", {}).get("vertices", [])
        return ", ".join(str(v) for v in vertices)

    def handle_retrieval_codes(self, **kwargs) -> str:
        """Render the retrieved code snippets, one snippet per line."""
        if 'previous_agent_message' not in kwargs:
            return ""
        previous_agent_message: Message = kwargs['previous_agent_message']
        return '\n'.join(previous_agent_message.customed_kargs["Retrieval_Codes"])
# Design your personal PROMPT INPUT FORMAT.
# Order matters: each entry becomes one section of the rendered prompt; the
# handler named by "function_name" fills in the section body.
# Fix: whitespace-mangled string keys/values (e.g. `" field_name "`) restored.
CODE_RETRIEVAL_PROMPT_CONFIGS = [
    {"field_name": 'agent_profile', "function_name": 'handle_agent_profile', "is_context": False},
    {"field_name": 'tool_information', "function_name": 'handle_tool_data', "is_context": False},
    {"field_name": 'context_placeholder', "function_name": '', "is_context": True},
    # {"field_name": 'reference_documents', "function_name": 'handle_doc_info'},
    {"field_name": 'session_records', "function_name": 'handle_session_records'},
    {"field_name": 'retrieval_codes', "function_name": 'handle_retrieval_codes'},
    {"field_name": 'code_packages', "function_name": 'handle_code_packages'},
    {"field_name": 'output_format', "function_name": 'handle_output_format', 'title': 'Response Output Format', "is_context": False},
    {"field_name": 'begin!!!', "function_name": 'handle_response', "is_context": False, "omit_if_empty": False},
]
# 2023-12-26 11:41:53 +08:00
# Register the two new agent roles so the framework can build them by name.
# Fix: VCS-timestamp residue lines that had been merged inside this dict
# literal (syntax errors) are removed.
AGETN_CONFIGS.update({
    "codeRetrievalJudger": {
        "role": {
            "role_prompt": codeRetrievalJudger_PROMPT,
            "role_type": "assistant",
            "role_name": "codeRetrievalJudger",
            "role_desc": "",
            "agent_type": "CodeRetrievalJudger"
            # "agent_type": "BaseAgent"
        },
        "prompt_config": CODE_RETRIEVAL_PROMPT_CONFIGS,
        "prompt_manager_type": "CodeRetrievalPM",
        "chat_turn": 1,
        "focus_agents": [],
        "focus_message_keys": [],
    },
    "codeRetrievalDivergent": {
        "role": {
            "role_prompt": codeRetrievalDivergent_PROMPT,
            "role_type": "assistant",
            "role_name": "codeRetrievalDivergent",
            "role_desc": "",
            "agent_type": "CodeRetrievalDivergent"
            # "agent_type": "BaseAgent"
        },
        "prompt_config": CODE_RETRIEVAL_PROMPT_CONFIGS,
        "prompt_manager_type": "CodeRetrievalPM",
        "chat_turn": 1,
        "focus_agents": [],
        "focus_message_keys": [],
    },
})
# update new chain configs
# Fix: VCS-timestamp residue lines removed from inside the dict literal.
CHAIN_CONFIGS.update({
    "codeRetrievalChain": {
        "chain_name": "codeRetrievalChain",
        "chain_type": "BaseChain",
        "agents": ["codeRetrievalJudger", "codeRetrievalDivergent"],
        # Up to 3 judger/divergent rounds before the chain stops on its own.
        "chat_turn": 3,
        "do_checker": False,
        "chain_prompt": ""
    }
})
# update phase configs
PHASE_CONFIGS.update({
    "codeRetrievalPhase": {
        "phase_name": "codeRetrievalPhase",
        "phase_type": "BasePhase",
        "chains": ["codeRetrievalChain"],
        "do_summary": False,
        "do_search": False,
        "do_doc_retrieval": False,
        # Code retrieval is the whole point of this phase.
        "do_code_retrieval": True,
        "do_tool_retrieval": False,
    },
})

# Materialize the framework config objects from the updated dicts.
role_configs = load_role_configs(AGETN_CONFIGS)
chain_configs = load_chain_configs(CHAIN_CONFIGS)
phase_configs = load_phase_configs(PHASE_CONFIGS)
# 2024-01-26 14:03:25 +08:00
from coagent . tools import CodeRetrievalSingle , RelatedVerticesRetrival , Vertex2Code
# 2023-12-26 11:41:53 +08:00
# Define a new agent class.
class CodeRetrievalJudger(BaseAgent):
    """Agent that judges whether the retrieved code answers the query."""

    def start_action_step(self, message: Message) -> Message:
        """Do action before agent predict: seed the first retrieved snippet.

        Fixes: VCS-timestamp residue lines removed from the method body; the
        ``if ...: pass / else:`` shape inverted into a single guard.
        """
        # Only retrieve when nothing has been fetched yet.
        if 'Retrieval_Codes' not in message.customed_kargs:
            action_json = CodeRetrievalSingle.run(
                message.code_engine_name, message.origin_query,
                llm_config=self.llm_config, embed_config=self.embed_config)
            message.customed_kargs["CodeRetrievalSingleRes"] = action_json
            message.customed_kargs['Current_Vertex'] = action_json['vertex']
            message.customed_kargs.setdefault("Retrieval_Codes", [])
            message.customed_kargs["Retrieval_Codes"].append(action_json["code"])
        return message
# Define a new agent class.
class CodeRetrievalDivergent(BaseAgent):
    """Agent that expands retrieval to code packages related to the current vertex."""

    def start_action_step(self, message: Message) -> Message:
        """Do action before agent predict: fetch vertices related to the current one."""
        action_json = RelatedVerticesRetrival.run(
            message.code_engine_name, message.customed_kargs['Current_Vertex'])
        # No related vertices left to explore: stop the chain.
        if not action_json['vertices']:
            message.action_status = ActionStatus.FINISHED
        message.customed_kargs["RelatedVerticesRetrivalRes"] = action_json
        return message

    def end_action_step(self, message: Message) -> Message:
        """Do action after agent predict: resolve the chosen package to its code.

        Fix: docstring previously said "before agent predict" — this hook runs
        after prediction, consuming ``message.parsed_output``.
        """
        # logger.error(f"message: {message}")
        action_json = Vertex2Code.run(
            message.code_engine_name, message.parsed_output["Code Package"])
        logger.debug(f'action_json={action_json}')
        # Chosen package resolves to no code: nothing more to read, stop.
        if not action_json['code']:
            message.action_status = ActionStatus.FINISHED
            return message
        message.customed_kargs["Vertex2Code"] = action_json
        message.customed_kargs['Current_Vertex'] = message.parsed_output["Code Package"]
        message.customed_kargs.setdefault("Retrieval_Codes", [])
        # Snippet already collected: we are looping, stop.
        if action_json['code'] in message.customed_kargs["Retrieval_Codes"]:
            message.action_status = ActionStatus.FINISHED
            return message
        message.customed_kargs["Retrieval_Codes"].append(action_json['code'])
        return message
# 2024-01-26 14:03:25 +08:00
# add agent or prompt_manager class
# Inject the new classes into the framework modules so that the string
# "agent_type"/"prompt_manager_type" config values can be resolved by name.
agent_module = importlib.import_module("coagent.connector.agents")
prompt_manager_module = importlib.import_module("coagent.connector.prompt_manager")

setattr(agent_module, 'CodeRetrievalJudger', CodeRetrievalJudger)
setattr(agent_module, 'CodeRetrievalDivergent', CodeRetrievalDivergent)
setattr(prompt_manager_module, 'CodeRetrievalPM', CodeRetrievalPM)
# 2023-12-26 11:41:53 +08:00
# 2024-01-26 14:03:25 +08:00
# log-level: print prompt and llm predictions
os.environ["log_verbose"] = "2"

phase_name = "codeRetrievalPhase"
# Fix: whitespace-mangled string literals (" gpt-3.5-turbo " etc.) restored;
# residue timestamp lines removed.
llm_config = LLMConfig(
    model_name="gpt-3.5-turbo", model_device="cpu", api_key=os.environ["OPENAI_API_KEY"],
    api_base_url=os.environ["API_BASE_URL"], temperature=0.3
)
embed_config = EmbedConfig(
    embed_engine="model", embed_model="text2vec-base-chinese",
    embed_model_path=os.path.join(src_dir, "embedding_models/text2vec-base-chinese")
)
phase = BasePhase(
    phase_name, sandbox_server=SANDBOX_SERVER, jupyter_work_path=JUPYTER_WORK_PATH,
    embed_config=embed_config, llm_config=llm_config, kb_root_path=KB_ROOT_PATH,
)
# 2023-12-26 11:41:53 +08:00
# round-1: ask which functions the UtilsTest class tests and for their code.
# Fix: residue timestamp lines removed from between the keyword arguments
# (they made the call a syntax error); padded string literals trimmed.
query_content = "UtilsTest 这个类中测试了哪些函数,测试的函数代码是什么"
query = Message(
    role_name="human", role_type="user",
    role_content=query_content, input_query=query_content, origin_query=query_content,
    code_engine_name="client", score_threshold=1.0, top_k=3, cb_search_type="tag"
)

output_message, output_memory = phase.step(query)
print(output_memory.to_str_messages(return_all=True, content_key="parsed_output_list"))