[feature](coagent)<add antflow compatibility and a coagent demo>

shanshi 2024-03-12 15:31:06 +08:00
parent c14b41ecec
commit 4d9b268a98
86 changed files with 3449 additions and 901 deletions

View File

@@ -26,9 +26,12 @@ JUPYTER_WORK_PATH = os.environ.get("JUPYTER_WORK_PATH", None) or os.path.join(ex
 WEB_CRAWL_PATH = os.environ.get("WEB_CRAWL_PATH", None) or os.path.join(executable_path, "knowledge_base")
 # NEBULA_DATA storage path
-NELUBA_PATH = os.environ.get("NELUBA_PATH", None) or os.path.join(executable_path, "data/neluba_data")
-for _path in [LOG_PATH, SOURCE_PATH, KB_ROOT_PATH, NLTK_DATA_PATH, JUPYTER_WORK_PATH, WEB_CRAWL_PATH, NELUBA_PATH]:
+NEBULA_PATH = os.environ.get("NEBULA_PATH", None) or os.path.join(executable_path, "data/nebula_data")
+# CHROMA storage path
+CHROMA_PERSISTENT_PATH = os.environ.get("CHROMA_PERSISTENT_PATH", None) or os.path.join(executable_path, "data/chroma_data")
+for _path in [LOG_PATH, SOURCE_PATH, KB_ROOT_PATH, CB_ROOT_PATH, NLTK_DATA_PATH, JUPYTER_WORK_PATH, WEB_CRAWL_PATH, NEBULA_PATH, CHROMA_PERSISTENT_PATH]:
     if not os.path.exists(_path):
         os.makedirs(_path, exist_ok=True)
@@ -58,7 +61,8 @@ NEBULA_GRAPH_SERVER = {
 }
 # CHROMA CONFIG
-CHROMA_PERSISTENT_PATH = '/home/user/chatbot/data/chroma_data'
+# CHROMA_PERSISTENT_PATH = '/home/user/chatbot/data/chroma_data'
+# CHROMA_PERSISTENT_PATH = '/Users/bingxu/Desktop/工作/大模型/chatbot/codefuse-chatbot-antcode/data/chroma_data'
 # Default vector store type. Options: faiss, milvus, pg.
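Since this hunk (presumably coagent/base_configs/env_config.py, the module later files in this commit import CHROMA_PERSISTENT_PATH from) now resolves every data directory from the environment first, a deployment can relocate the Chroma store without touching code. A minimal sketch, assuming the module is importable on its own; the target path is illustrative:

import os

# must run before coagent.base_configs.env_config is first imported,
# because the fallback is computed at module import time
os.environ["CHROMA_PERSISTENT_PATH"] = "/mnt/data/chroma_data"

from coagent.base_configs.env_config import CHROMA_PERSISTENT_PATH
print(CHROMA_PERSISTENT_PATH)  # -> /mnt/data/chroma_data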

View File

@@ -7,7 +7,7 @@ from langchain import LLMChain
 from langchain.callbacks import AsyncIteratorCallbackHandler
 from langchain.prompts.chat import ChatPromptTemplate
-from coagent.llm_models import getChatModel, getChatModelFromConfig
+from coagent.llm_models import getChatModelFromConfig
 from coagent.chat.utils import History, wrap_done
 from coagent.llm_models.llm_config import LLMConfig, EmbedConfig
 # from configs.model_config import (llm_model_dict, LLM_MODEL, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD)

View File

@@ -22,7 +22,7 @@ from coagent.connector.configs.prompts import CODE_PROMPT_TEMPLATE
 from coagent.chat.utils import History, wrap_done
 from coagent.utils import BaseResponse
 from .base_chat import Chat
-from coagent.llm_models import getChatModel, getChatModelFromConfig
+from coagent.llm_models import getChatModelFromConfig
 from coagent.llm_models.llm_config import LLMConfig, EmbedConfig
@@ -67,6 +67,7 @@ class CodeChat(Chat):
             embed_model_path=embed_config.embed_model_path,
             embed_engine=embed_config.embed_engine,
             model_device=embed_config.model_device,
+            embed_config=embed_config
         )
         context = codes_res['context']

View File

@@ -12,7 +12,7 @@ from langchain.schema import (
 # from configs.model_config import CODE_INTERPERT_TEMPLATE
 from coagent.connector.configs.prompts import CODE_INTERPERT_TEMPLATE
-from coagent.llm_models.openai_model import getChatModel, getChatModelFromConfig
+from coagent.llm_models.openai_model import getChatModelFromConfig
 from coagent.llm_models.llm_config import LLMConfig
@@ -53,9 +53,15 @@ class CodeIntepreter:
             message = CODE_INTERPERT_TEMPLATE.format(code=code)
             messages.append(message)
-        chat_ress = chat_model.batch(messages)
+        try:
+            chat_ress = [chat_model(message) for message in messages]
+        except:
+            chat_ress = chat_model.batch(messages)
         for chat_res, code in zip(chat_ress, code_list):
-            res[code] = chat_res.content
+            try:
+                res[code] = chat_res.content
+            except:
+                res[code] = chat_res
         return res
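The double try/except above covers the two calling conventions found across chat-model wrappers: batch-style calls return message objects carrying a .content attribute, while plain callables may return bare strings. A self-contained sketch of the normalization (FakeChatMessage is a stand-in, not part of this repo):

class FakeChatMessage:
    def __init__(self, content):
        self.content = content  # LangChain-style messages expose .content

def normalize(chat_res):
    # same shape as the inner try/except in CodeIntepreter:
    # prefer .content, fall back to the raw value for plain strings
    try:
        return chat_res.content
    except AttributeError:
        return chat_res

assert normalize(FakeChatMessage("hi")) == "hi"
assert normalize("hi") == "hi"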

View File

@@ -27,7 +27,7 @@ class DirCrawler:
         logger.info(java_file_list)
         for java_file in java_file_list:
-            with open(java_file) as f:
+            with open(java_file, encoding="utf-8") as f:
                 java_code = ''.join(f.readlines())
                 java_code_dict[java_file] = java_code
         return java_code_dict

View File

@@ -5,6 +5,7 @@
 @time: 2023/11/21 2:35 PM
 @desc:
 '''
+import json
 import time
 from loguru import logger
 from collections import defaultdict
@@ -15,7 +16,7 @@ from coagent.db_handler.vector_db_handler.chroma_handler import ChromaHandler
 from coagent.codechat.code_search.cypher_generator import CypherGenerator
 from coagent.codechat.code_search.tagger import Tagger
 from coagent.embeddings.get_embedding import get_embedding
-from coagent.llm_models.llm_config import LLMConfig
+from coagent.llm_models.llm_config import LLMConfig, EmbedConfig
 # from configs.model_config import EMBEDDING_DEVICE, EMBEDDING_MODEL
@@ -29,7 +30,8 @@ MAX_DISTANCE = 1000
 class CodeSearch:
-    def __init__(self, llm_config: LLMConfig, nh: NebulaHandler, ch: ChromaHandler, limit: int = 3):
+    def __init__(self, llm_config: LLMConfig, nh: NebulaHandler, ch: ChromaHandler, limit: int = 3,
+                 local_graph_file_path: str = ''):
         '''
         init
         @param nh: NebulaHandler
@@ -37,7 +39,13 @@ class CodeSearch:
         @param limit: limit of result
         '''
         self.llm_config = llm_config
         self.nh = nh
+        if not self.nh:
+            with open(local_graph_file_path, 'r') as f:
+                self.graph = json.load(f)
         self.ch = ch
         self.limit = limit
@@ -51,7 +59,7 @@ class CodeSearch:
         tag_list = tagger.generate_tag_query(query)
         logger.info(f'query tag={tag_list}')
-        # get all verticex
+        # get all vertices
         vertex_list = self.nh.get_vertices().get('v', [])
         vertex_vid_list = [i.as_node().get_id().as_string() for i in vertex_list]
@@ -81,7 +89,7 @@ class CodeSearch:
         # get most prominent package tag
         package_score_dict = defaultdict(lambda: 0)
-        for vertex, score in vertex_score_dict.items():
+        for vertex, score in vertex_score_dict_final.items():
             if '#' in vertex:
                 # get class name first
                 cypher = f'''MATCH (v1:class)-[e:contain]->(v2) WHERE id(v2) == '{vertex}' RETURN id(v1) as id;'''
@@ -111,6 +119,53 @@ class CodeSearch:
         logger.info(f'ids={ids}')
         chroma_res = self.ch.get(ids=ids, include=['metadatas'])
+        for vertex, score in package_score_tuple:
+            index = chroma_res['result']['ids'].index(vertex)
+            code_text = chroma_res['result']['metadatas'][index]['code_text']
+            res.append({
+                "vertex": vertex,
+                "code_text": code_text}
+            )
+            if len(res) >= self.limit:
+                break
+        # logger.info(f'retrival code={res}')
+        return res
+
+    def search_by_tag_by_graph(self, query: str):
+        '''
+        search code by tag with graph
+        @param query:
+        @return:
+        '''
+        tagger = Tagger()
+        tag_list = tagger.generate_tag_query(query)
+        logger.info(f'query tag={tag_list}')
+
+        # loop to get package node
+        package_score_dict = {}
+        for code, structure in self.graph.items():
+            score = 0
+            for class_name in structure['class_name_list']:
+                for tag in tag_list:
+                    if tag.lower() in class_name.lower():
+                        score += 1
+            for func_name_list in structure['func_name_dict'].values():
+                for func_name in func_name_list:
+                    for tag in tag_list:
+                        if tag.lower() in func_name.lower():
+                            score += 1
+            package_score_dict[structure['pac_name']] = score
+
+        # get respective code
+        res = []
+        package_score_tuple = list(package_score_dict.items())
+        package_score_tuple.sort(key=lambda x: x[1], reverse=True)
+        ids = [i[0] for i in package_score_tuple]
+        logger.info(f'ids={ids}')
+        chroma_res = self.ch.get(ids=ids, include=['metadatas'])
         # logger.info(chroma_res)
         for vertex, score in package_score_tuple:
             index = chroma_res['result']['ids'].index(vertex)
@@ -121,23 +176,22 @@ class CodeSearch:
             )
             if len(res) >= self.limit:
                 break
-        logger.info(f'retrival code={res}')
+        # logger.info(f'retrival code={res}')
         return res

-    def search_by_desciption(self, query: str, engine: str, model_path: str = "text2vec-base-chinese", embedding_device: str = "cpu"):
+    def search_by_desciption(self, query: str, engine: str, model_path: str = "text2vec-base-chinese", embedding_device: str = "cpu", embed_config: EmbedConfig = None):
         '''
         search by perform sim search
         @param query:
         @return:
         '''
         query = query.replace(',', '')
-        query_emb = get_embedding(engine=engine, text_list=[query], model_path=model_path, embedding_device=embedding_device,)
+        query_emb = get_embedding(engine=engine, text_list=[query], model_path=model_path, embedding_device=embedding_device, embed_config=embed_config)
         query_emb = query_emb[query]

         query_embeddings = [query_emb]
         query_result = self.ch.query(query_embeddings=query_embeddings, n_results=self.limit,
                                      include=['metadatas', 'distances'])
-        logger.debug(query_result)

         res = []
         for idx, distance in enumerate(query_result['result']['distances'][0]):
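The new search_by_tag_by_graph is a plain substring-count ranking over the locally persisted graph, with no Nebula round trip. A condensed, runnable version of the scoring loop; the toy graph below follows the pac_name / class_name_list / func_name_dict layout the method reads:

graph = {
    "A.java": {"pac_name": "com.demo.a",
               "class_name_list": ["HttpClient"],
               "func_name_dict": {"HttpClient": ["doGet", "doPost"]}},
    "B.java": {"pac_name": "com.demo.b",
               "class_name_list": ["StringUtil"],
               "func_name_dict": {"StringUtil": ["removeSubstring"]}},
}
tag_list = ["remove", "string"]

package_score = {}
for _, structure in graph.items():
    # one point per (tag, class-name) and (tag, function-name) substring hit
    score = sum(tag.lower() in name.lower()
                for name in structure["class_name_list"] for tag in tag_list)
    score += sum(tag.lower() in fn.lower()
                 for fns in structure["func_name_dict"].values()
                 for fn in fns for tag in tag_list)
    package_score[structure["pac_name"]] = score

print(sorted(package_score.items(), key=lambda x: x[1], reverse=True))
# [('com.demo.b', 3), ('com.demo.a', 0)]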

View File

@@ -8,7 +8,7 @@
 from langchain import PromptTemplate
 from loguru import logger
-from coagent.llm_models.openai_model import getChatModel, getChatModelFromConfig
+from coagent.llm_models.openai_model import getChatModelFromConfig
 from coagent.llm_models.llm_config import LLMConfig
 from coagent.utils.postprocess import replace_lt_gt
 from langchain.schema import (

View File

@@ -6,11 +6,10 @@
 @desc:
 '''
 import time
+import json
+import os
 from loguru import logger
-# from configs.server_config import NEBULA_HOST, NEBULA_PORT, NEBULA_USER, NEBULA_PASSWORD, NEBULA_STORAGED_PORT
-# from configs.server_config import CHROMA_PERSISTENT_PATH
-# from configs.model_config import EMBEDDING_DEVICE, EMBEDDING_MODEL
 from coagent.db_handler.graph_db_handler.nebula_handler import NebulaHandler
 from coagent.db_handler.vector_db_handler.chroma_handler import ChromaHandler
 from coagent.embeddings.get_embedding import get_embedding
@@ -18,12 +17,14 @@ from coagent.llm_models.llm_config import EmbedConfig

 class CodeImporter:
-    def __init__(self, codebase_name: str, embed_config: EmbedConfig, nh: NebulaHandler, ch: ChromaHandler):
+    def __init__(self, codebase_name: str, embed_config: EmbedConfig, nh: NebulaHandler, ch: ChromaHandler,
+                 local_graph_file_path: str):
         self.codebase_name = codebase_name
         # self.engine = engine
-        self.embed_config: EmbedConfig= embed_config
+        self.embed_config: EmbedConfig = embed_config
         self.nh = nh
         self.ch = ch
+        self.local_graph_file_path = local_graph_file_path

     def import_code(self, static_analysis_res: dict, interpretation: dict, do_interpret: bool = True):
         '''
@@ -31,9 +32,14 @@ class CodeImporter:
         @return:
         '''
         static_analysis_res = self.filter_out_vertex(static_analysis_res, interpretation)
+        logger.info(f'static_analysis_res={static_analysis_res}')

-        self.analysis_res_to_graph(static_analysis_res)
+        if self.nh:
+            self.analysis_res_to_graph(static_analysis_res)
+        else:
+            # persist to local dir
+            with open(self.local_graph_file_path, 'w') as f:
+                json.dump(static_analysis_res, f)
         self.interpretation_to_db(static_analysis_res, interpretation, do_interpret)

     def filter_out_vertex(self, static_analysis_res, interpretation):
@@ -114,12 +120,12 @@ class CodeImporter:
         # create vertex
         for tag_name, value_dict in vertex_value_dict.items():
             res = self.nh.insert_vertex(tag_name, value_dict)
-            logger.debug(res.error_msg())
+            # logger.debug(res.error_msg())

         # create edge
         for tag_name, value_dict in edge_value_dict.items():
             res = self.nh.insert_edge(tag_name, value_dict)
-            logger.debug(res.error_msg())
+            # logger.debug(res.error_msg())

         return
@@ -132,7 +138,7 @@ class CodeImporter:
         if do_interpret:
             logger.info('start get embedding for interpretion')
             interp_list = list(interpretation.values())
-            emb = get_embedding(engine=self.embed_config.embed_engine, text_list=interp_list, model_path=self.embed_config.embed_model_path, embedding_device=self.embed_config.model_device)
+            emb = get_embedding(engine=self.embed_config.embed_engine, text_list=interp_list, model_path=self.embed_config.embed_model_path, embedding_device=self.embed_config.model_device, embed_config=self.embed_config)
             logger.info('get embedding done')
         else:
             emb = {i: [0] for i in list(interpretation.values())}
@@ -161,7 +167,7 @@ class CodeImporter:
         # add documents to chroma
         res = self.ch.add_data(ids=ids, embeddings=embeddings, documents=documents, metadatas=metadatas)
-        logger.debug(res)
+        # logger.debug(res)

     def init_graph(self):
         '''
@@ -169,7 +175,7 @@ class CodeImporter:
         @return:
         '''
         res = self.nh.create_space(space_name=self.codebase_name, vid_type='FIXED_STRING(1024)')
-        logger.debug(res.error_msg())
+        # logger.debug(res.error_msg())
         time.sleep(5)

         self.nh.set_space_name(self.codebase_name)
@@ -179,29 +185,29 @@ class CodeImporter:
         tag_name = 'package'
         prop_dict = {}
         res = self.nh.create_tag(tag_name, prop_dict)
-        logger.debug(res.error_msg())
+        # logger.debug(res.error_msg())

         tag_name = 'class'
         prop_dict = {}
         res = self.nh.create_tag(tag_name, prop_dict)
-        logger.debug(res.error_msg())
+        # logger.debug(res.error_msg())

         tag_name = 'method'
         prop_dict = {}
         res = self.nh.create_tag(tag_name, prop_dict)
-        logger.debug(res.error_msg())
+        # logger.debug(res.error_msg())

         # create edge type
         edge_type_name = 'contain'
         prop_dict = {}
         res = self.nh.create_edge_type(edge_type_name, prop_dict)
-        logger.debug(res.error_msg())
+        # logger.debug(res.error_msg())

         # create edge type
         edge_type_name = 'depend'
         prop_dict = {}
         res = self.nh.create_edge_type(edge_type_name, prop_dict)
-        logger.debug(res.error_msg())
+        # logger.debug(res.error_msg())


 if __name__ == '__main__':
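When no Nebula connection is available, CodeImporter now just serializes the static-analysis result to <codebase_name>_graph.json and CodeSearch reloads it verbatim, so the whole fallback is an ordinary json round trip. A sketch; the file name and payload are illustrative, following the structure used elsewhere in this commit:

import json

static_analysis_res = {
    "Foo.java": {"pac_name": "com.demo",
                 "class_name_list": ["Foo"],
                 "func_name_dict": {"Foo": ["bar"]}},
}

# CodeImporter branch: persist to a local file instead of Nebula
with open("demo_graph.json", "w") as f:
    json.dump(static_analysis_res, f)

# CodeSearch / CodeBaseHandler branch: reload at startup
with open("demo_graph.json", "r") as f:
    graph = json.load(f)
assert graph == static_analysis_res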

View File

@@ -5,16 +5,15 @@
 @time: 2023/11/21 2:25 PM
 @desc:
 '''
+import os
 import time
+import json
+from typing import List
 from loguru import logger

-# from configs.server_config import NEBULA_HOST, NEBULA_PORT, NEBULA_USER, NEBULA_PASSWORD, NEBULA_STORAGED_PORT
-# from configs.server_config import CHROMA_PERSISTENT_PATH
-# from configs.model_config import EMBEDDING_ENGINE
 from coagent.base_configs.env_config import (
     NEBULA_HOST, NEBULA_PORT, NEBULA_USER, NEBULA_PASSWORD, NEBULA_STORAGED_PORT,
-    CHROMA_PERSISTENT_PATH
+    CHROMA_PERSISTENT_PATH, CB_ROOT_PATH
 )
@@ -35,7 +34,9 @@ class CodeBaseHandler:
                  language: str = 'java',
                  crawl_type: str = 'ZIP',
                  embed_config: EmbedConfig = EmbedConfig(),
-                 llm_config: LLMConfig = LLMConfig()
+                 llm_config: LLMConfig = LLMConfig(),
+                 use_nh: bool = True,
+                 local_graph_path: str = CB_ROOT_PATH
                  ):
         self.codebase_name = codebase_name
         self.code_path = code_path
@@ -43,11 +44,28 @@ class CodeBaseHandler:
         self.crawl_type = crawl_type
         self.embed_config = embed_config
         self.llm_config = llm_config
+        self.local_graph_file_path = local_graph_path + os.sep + f'{self.codebase_name}_graph.json'

-        self.nh = NebulaHandler(host=NEBULA_HOST, port=NEBULA_PORT, username=NEBULA_USER,
-                                password=NEBULA_PASSWORD, space_name=codebase_name)
-        self.nh.add_host(NEBULA_HOST, NEBULA_STORAGED_PORT)
-        time.sleep(1)
+        if use_nh:
+            try:
+                self.nh = NebulaHandler(host=NEBULA_HOST, port=NEBULA_PORT, username=NEBULA_USER,
+                                        password=NEBULA_PASSWORD, space_name=codebase_name)
+                self.nh.add_host(NEBULA_HOST, NEBULA_STORAGED_PORT)
+                time.sleep(1)
+            except:
+                self.nh = None
+                try:
+                    with open(self.local_graph_file_path, 'r') as f:
+                        self.graph = json.load(f)
+                except:
+                    pass
+        elif local_graph_path:
+            self.nh = None
+            try:
+                with open(self.local_graph_file_path, 'r') as f:
+                    self.graph = json.load(f)
+            except:
+                pass

         self.ch = ChromaHandler(path=CHROMA_PERSISTENT_PATH, collection_name=codebase_name)
@@ -58,9 +76,10 @@ class CodeBaseHandler:
         '''
         # init graph to init tag and edge
         code_importer = CodeImporter(embed_config=self.embed_config, codebase_name=self.codebase_name,
-                                     nh=self.nh, ch=self.ch)
-        code_importer.init_graph()
-        time.sleep(5)
+                                     nh=self.nh, ch=self.ch, local_graph_file_path=self.local_graph_file_path)
+        if self.nh:
+            code_importer.init_graph()
+            time.sleep(5)

         # crawl code
         st0 = time.time()
@@ -71,7 +90,7 @@ class CodeBaseHandler:
         # analyze code
         logger.info('start analyze')
         st1 = time.time()
-        code_analyzer = CodeAnalyzer(language=self.language, llm_config = self.llm_config)
+        code_analyzer = CodeAnalyzer(language=self.language, llm_config=self.llm_config)
         static_analysis_res, interpretation = code_analyzer.analyze(code_dict, do_interpret=do_interpret)
         logger.debug('analyze done, rt={}'.format(time.time() - st1))
@@ -81,8 +100,12 @@ class CodeBaseHandler:
         logger.debug('update codebase done, rt={}'.format(time.time() - st2))

         # get KG info
-        stat = self.nh.get_stat()
-        vertices_num, edges_num = stat['vertices'], stat['edges']
+        if self.nh:
+            stat = self.nh.get_stat()
+            vertices_num, edges_num = stat['vertices'], stat['edges']
+        else:
+            vertices_num = 0
+            edges_num = 0

         # get chroma info
         file_num = self.ch.count()['result']
@@ -95,7 +118,11 @@ class CodeBaseHandler:
         @param codebase_name: name of codebase
         @return:
         '''
-        self.nh.drop_space(space_name=codebase_name)
+        if self.nh:
+            self.nh.drop_space(space_name=codebase_name)
+        elif self.local_graph_file_path and os.path.isfile(self.local_graph_file_path):
+            os.remove(self.local_graph_file_path)
         self.ch.delete_collection(collection_name=codebase_name)

     def crawl_code(self, zip_file=''):
@@ -124,9 +151,15 @@ class CodeBaseHandler:
         @param search_type: ['cypher', 'graph', 'vector']
         @return:
         '''
-        assert search_type in ['cypher', 'tag', 'description']
+        if self.nh:
+            assert search_type in ['cypher', 'tag', 'description']
+        else:
+            if search_type == 'tag':
+                search_type = 'tag_by_local_graph'
+            assert search_type in ['tag_by_local_graph', 'description']

-        code_search = CodeSearch(llm_config=self.llm_config, nh=self.nh, ch=self.ch, limit=limit)
+        code_search = CodeSearch(llm_config=self.llm_config, nh=self.nh, ch=self.ch, limit=limit,
+                                 local_graph_file_path=self.local_graph_file_path)

         if search_type == 'cypher':
             search_res = code_search.search_by_cypher(query=query)
@@ -134,7 +167,11 @@ class CodeBaseHandler:
             search_res = code_search.search_by_tag(query=query)
         elif search_type == 'description':
             search_res = code_search.search_by_desciption(
-                query=query, engine=self.embed_config.embed_engine, model_path=self.embed_config.embed_model_path, embedding_device=self.embed_config.model_device)
+                query=query, engine=self.embed_config.embed_engine, model_path=self.embed_config.embed_model_path,
+                embedding_device=self.embed_config.model_device, embed_config=self.embed_config)
+        elif search_type == 'tag_by_local_graph':
+            search_res = code_search.search_by_tag_by_graph(query=query)

         context, related_vertice = self.format_search_res(search_res, search_type)
         return context, related_vertice
@@ -160,6 +197,12 @@ class CodeBaseHandler:
             for code in search_res:
                 context = context + code['code_text'] + '\n'
                 related_vertice.append(code['vertex'])
+        elif search_type == 'tag_by_local_graph':
+            context = ''
+            related_vertice = []
+            for code in search_res:
+                context = context + code['code_text'] + '\n'
+                related_vertice.append(code['vertex'])
         elif search_type == 'description':
             context = ''
             related_vertice = []
@@ -169,17 +212,63 @@ class CodeBaseHandler:
         return context, related_vertice

+    def search_vertices(self, vertex_type="class") -> List[str]:
+        '''
+        enumerate all vertices by method/class type
+        '''
+        vertices = []
+        if self.nh:
+            vertices = self.nh.get_all_vertices()
+            vertices = [str(v.as_node().get_id()) for v in vertices["v"] if vertex_type in v.as_node().tags()]
+            # for v in vertices["v"]:
+            #     logger.debug(f"{v.as_node().get_id()}, {v.as_node().tags()}")
+        else:
+            if vertex_type == "class":
+                vertices = [str(class_name) for code, structure in self.graph.items() for class_name in structure['class_name_list']]
+            elif vertex_type == "method":
+                vertices = [
+                    str(methods_name)
+                    for code, structure in self.graph.items()
+                    for methods_names in structure['func_name_dict'].values()
+                    for methods_name in methods_names
+                ]
+        # logger.debug(vertices)
+        return vertices
+

 if __name__ == '__main__':
-    codebase_name = 'testing'
+    from configs.model_config import KB_ROOT_PATH, JUPYTER_WORK_PATH
+    from configs.server_config import SANDBOX_SERVER
+
+    LLM_MODEL = "gpt-3.5-turbo"
+    llm_config = LLMConfig(
+        model_name=LLM_MODEL, model_device="cpu", api_key=os.environ["OPENAI_API_KEY"],
+        api_base_url=os.environ["API_BASE_URL"], temperature=0.3
+    )
+    src_dir = '/Users/bingxu/Desktop/工作/大模型/chatbot/Codefuse-chatbot-antcode'
+    embed_config = EmbedConfig(
+        embed_engine="model", embed_model="text2vec-base-chinese",
+        embed_model_path=os.path.join(src_dir, "embedding_models/text2vec-base-chinese")
+    )
+
+    codebase_name = 'client_local'
     code_path = '/Users/bingxu/Desktop/工作/大模型/chatbot/test_code_repo/client'
-    cbh = CodeBaseHandler(codebase_name, code_path, crawl_type='dir')
+    use_nh = False
+    local_graph_path = '/Users/bingxu/Desktop/工作/大模型/chatbot/Codefuse-chatbot-antcode/code_base'
+    CHROMA_PERSISTENT_PATH = '/Users/bingxu/Desktop/工作/大模型/chatbot/Codefuse-chatbot-antcode/data/chroma_data'
+    cbh = CodeBaseHandler(codebase_name, code_path, crawl_type='dir', use_nh=use_nh, local_graph_path=local_graph_path,
+                          llm_config=llm_config, embed_config=embed_config)
+
+    # test import code
+    # cbh.import_code(do_interpret=True)

     # query = '使用不同的HTTP请求类型GET、POST、DELETE等来执行不同的操作'
     # query = '代码中一共有多少个类'
-    query = 'intercept 函数作用是什么'
-    search_type = 'graph'
+    # query = 'remove 这个函数是用来做什么的'
+    query = '有没有函数是从字符串中删除指定字符串的功能'
+    search_type = 'description'

     limit = 2
     res = cbh.search_code(query, search_type, limit)
     logger.debug(res)
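search_vertices gives callers one way to enumerate class or method names whether the backing store is Nebula or the local JSON graph; the local branch reduces to two comprehensions over the same graph layout. A sketch with toy data:

graph = {
    "Foo.java": {"pac_name": "com.demo",
                 "class_name_list": ["Foo"],
                 "func_name_dict": {"Foo": ["bar", "baz"]}},
}

classes = [str(c) for _, s in graph.items() for c in s["class_name_list"]]
methods = [str(m) for _, s in graph.items()
           for ms in s["func_name_dict"].values() for m in ms]
assert classes == ["Foo"] and methods == ["bar", "baz"]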

View File

@@ -0,0 +1,6 @@
+from .base_action import BaseAction
+
+
+__all__ = [
+    "BaseAction"
+]

View File

@@ -0,0 +1,16 @@
+from langchain.schema import BaseRetriever, Document
+
+
+class BaseAction:
+
+    def __init__(self, ):
+        pass
+
+    def step(self, ):
+        pass
+
+    def astep(self, ):
+        pass
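base_action.py only pins down the interface for the new actions package; concrete actions are expected to override step/astep. A hypothetical subclass, purely to show the intended shape (EchoAction is not part of this commit):

class EchoAction(BaseAction):
    def step(self, text: str = ""):
        return f"echo: {text}"

print(EchoAction().step("hello"))  # echo: hello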

View File

@@ -4,25 +4,25 @@ import re, os
 import copy
 from loguru import logger

+from langchain.schema import BaseRetriever
+
 from coagent.connector.schema import (
     Memory, Task, Role, Message, PromptField, LogVerboseEnum
 )
 from coagent.connector.memory_manager import BaseMemoryManager
-from coagent.connector.configs.prompts import BEGIN_PROMPT_INPUT
 from coagent.connector.message_process import MessageUtils
-from coagent.llm_models import getChatModel, getExtraModel, LLMConfig, getChatModelFromConfig, EmbedConfig
-from coagent.connector.prompt_manager import PromptManager
+from coagent.llm_models import getExtraModel, LLMConfig, getChatModelFromConfig, EmbedConfig
+from coagent.connector.prompt_manager.prompt_manager import PromptManager
 from coagent.connector.memory_manager import LocalMemoryManager
-from coagent.connector.utils import parse_section
-# from configs.model_config import JUPYTER_WORK_PATH
-# from configs.server_config import SANDBOX_SERVER
+from coagent.base_configs.env_config import JUPYTER_WORK_PATH, KB_ROOT_PATH


 class BaseAgent:
     def __init__(
             self,
             role: Role,
-            prompt_config: [PromptField],
+            prompt_config: List[PromptField],
             prompt_manager_type: str = "PromptManager",
             task: Task = None,
             memory: Memory = None,
@@ -33,8 +33,11 @@ class BaseAgent:
             llm_config: LLMConfig = None,
             embed_config: EmbedConfig = None,
             sandbox_server: dict = {},
-            jupyter_work_path: str = "",
-            kb_root_path: str = "",
+            jupyter_work_path: str = JUPYTER_WORK_PATH,
+            kb_root_path: str = KB_ROOT_PATH,
+            doc_retrieval: Union[BaseRetriever] = None,
+            code_retrieval = None,
+            search_retrieval = None,
             log_verbose: str = "0"
     ):
@@ -43,7 +46,7 @@ class BaseAgent:
         self.sandbox_server = sandbox_server
         self.jupyter_work_path = jupyter_work_path
         self.kb_root_path = kb_root_path
-        self.message_utils = MessageUtils(role, sandbox_server, jupyter_work_path, embed_config, llm_config, kb_root_path, log_verbose)
+        self.message_utils = MessageUtils(role, sandbox_server, jupyter_work_path, embed_config, llm_config, kb_root_path, doc_retrieval, code_retrieval, search_retrieval, log_verbose)
         self.memory = self.init_history(memory)
         self.llm_config: LLMConfig = llm_config
         self.embed_config: EmbedConfig = embed_config
@@ -82,12 +85,8 @@ class BaseAgent:
                 llm_config=self.embed_config
             )
             memory_manager.append(query)
-            memory_pool = memory_manager.current_memory
-        else:
-            memory_pool = memory_manager.current_memory
-
-        logger.debug(f"memory_pool: {memory_pool}")
+        memory_pool = memory_manager.get_memory_pool(query.user_name)

         prompt = self.prompt_manager.generate_full_prompt(
             previous_agent_message=query_c, agent_long_term_memory=self.memory, ui_history=history, chain_summary_messages=background, memory_pool=memory_pool)
         content = self.llm.predict(prompt)
@@ -99,6 +98,7 @@ class BaseAgent:
             logger.info(f"{self.role.role_name} content: {content}")

         output_message = Message(
+            user_name=query.user_name,
             role_name=self.role.role_name,
             role_type="assistant",  # self.role.role_type,
             role_content=content,
@@ -151,10 +151,7 @@ class BaseAgent:
         self.memory = self.init_history()

     def create_llm_engine(self, llm_config: LLMConfig = None, temperature=0.2, stop=None):
-        if llm_config is None:
-            return getChatModel(temperature=temperature, stop=stop)
-        else:
-            return getChatModelFromConfig(llm_config=llm_config)
+        return getChatModelFromConfig(llm_config=llm_config)

     def registry_actions(self, actions):
         '''registry llm's actions'''
@@ -212,171 +209,3 @@ class BaseAgent:
     def get_memory_str(self, content_key="role_content"):
         return "\n".join([": ".join(i) for i in self.memory.to_tuple_messages(content_key="step_content")])
-
-    def create_prompt(
-            self, query: Message, memory: Memory = None, history: Memory = None, background: Memory = None, memory_pool: Memory = None, prompt_mamnger=None) -> str:
-        '''
-        prompt engineer, contains role\task\tools\docs\memory
-        '''
-        #
-        doc_infos = self.create_doc_prompt(query)
-        code_infos = self.create_codedoc_prompt(query)
-        #
-        formatted_tools, tool_names, _ = self.create_tools_prompt(query)
-        task_prompt = self.create_task_prompt(query)
-        background_prompt = self.create_background_prompt(background, control_key="step_content")
-        history_prompt = self.create_history_prompt(history)
-        selfmemory_prompt = self.create_selfmemory_prompt(memory, control_key="step_content")
-
-        # extra_system_prompt = self.role.role_prompt
-        prompt = self.role.role_prompt.format(**{"formatted_tools": formatted_tools, "tool_names": tool_names})
-
-        #
-        memory_pool_select_by_agent_key = self.select_memory_by_agent_key(memory_pool)
-        memory_pool_select_by_agent_key_context = '\n\n'.join([f"*{k}*\n{v}" for parsed_output in memory_pool_select_by_agent_key.get_parserd_output_list() for k, v in parsed_output.items() if k not in ['Action Status']])

-        # input_query = query.input_query
-        # # logger.debug(f"{self.role.role_name} extra_system_prompt: {self.role.role_prompt}")
-        # # logger.debug(f"{self.role.role_name} input_query: {input_query}")
-        # # logger.debug(f"{self.role.role_name} doc_infos: {doc_infos}")
-        # # logger.debug(f"{self.role.role_name} tool_names: {tool_names}")
-        # if "**Context:**" in self.role.role_prompt:
-        #     # logger.debug(f"parsed_output_list: {query.parsed_output_list}")
-        #     # input_query = "'''" + "\n".join([f"###{k}###\n{v}" for i in query.parsed_output_list for k, v in i.items() if "Action Status" != k]) + "'''"
-        #     context = "\n".join([f"*{k}*\n{v}" for i in query.parsed_output_list for k, v in i.items() if "Action Status" != k])
-        #     # context = history_prompt or '""'
-        #     # logger.debug(f"parsed_output_list: {t}")
-        #     prompt += "\n" + QUERY_CONTEXT_PROMPT_INPUT.format(**{"context": context, "query": query.origin_query})
-        # else:
-        #     prompt += "\n" + PLAN_PROMPT_INPUT.format(**{"query": input_query})
-
-        task = query.task or self.task
-        if task_prompt is not None:
-            prompt += "\n" + task.task_prompt
-
-        DocInfos = ""
-        if doc_infos is not None and doc_infos != "" and doc_infos != "不存在知识库辅助信息":
-            DocInfos += f"\nDocument Information: {doc_infos}"
-        if code_infos is not None and code_infos != "" and code_infos != "不存在代码库辅助信息":
-            DocInfos += f"\nCodeBase Infomation: {code_infos}"
-
-        # if selfmemory_prompt:
-        #     prompt += "\n" + selfmemory_prompt
-        # if background_prompt:
-        #     prompt += "\n" + background_prompt
-        # if history_prompt:
-        #     prompt += "\n" + history_prompt
-
-        input_query = query.input_query
-        # logger.debug(f"{self.role.role_name} extra_system_prompt: {self.role.role_prompt}")
-        # logger.debug(f"{self.role.role_name} input_query: {input_query}")
-        # logger.debug(f"{self.role.role_name} doc_infos: {doc_infos}")
-        # logger.debug(f"{self.role.role_name} tool_names: {tool_names}")
-        # extra_system_prompt = self.role.role_prompt
-
-        input_keys = parse_section(self.role.role_prompt, 'Input Format')
-        prompt = self.role.role_prompt.format(**{"formatted_tools": formatted_tools, "tool_names": tool_names})
-        prompt += "\n" + BEGIN_PROMPT_INPUT
-        for input_key in input_keys:
-            if input_key == "Origin Query":
-                prompt += "\n**Origin Query:**\n" + query.origin_query
-            elif input_key == "Context":
-                context = "\n".join([f"*{k}*\n{v}" for i in query.parsed_output_list for k, v in i.items() if "Action Status" != k])
-                if history:
-                    context = history_prompt + "\n" + context
-                if not context:
-                    context = "there is no context"
-                if self.focus_agents and memory_pool_select_by_agent_key_context:
-                    context = memory_pool_select_by_agent_key_context
-                prompt += "\n**Context:**\n" + context + "\n" + input_query
-            elif input_key == "DocInfos":
-                if DocInfos:
-                    prompt += "\n**DocInfos:**\n" + DocInfos
-                else:
-                    prompt += "\n**DocInfos:**\n" + "Empty"
-            elif input_key == "Question":
-                prompt += "\n**Question:**\n" + input_query
-
-        # if "**Context:**" in self.role.role_prompt:
-        #     # logger.debug(f"parsed_output_list: {query.parsed_output_list}")
-        #     # input_query = "'''" + "\n".join([f"###{k}###\n{v}" for i in query.parsed_output_list for k, v in i.items() if "Action Status" != k]) + "'''"
-        #     context = "\n".join([f"*{k}*\n{v}" for i in query.parsed_output_list for k, v in i.items() if "Action Status" != k])
-        #     if history:
-        #         context = history_prompt + "\n" + context
-        #     if not context:
-        #         context = "there is no context"
-        #     # logger.debug(f"parsed_output_list: {t}")
-        #     if "DocInfos" in prompt:
-        #         prompt += "\n" + QUERY_CONTEXT_DOC_PROMPT_INPUT.format(**{"context": context, "query": query.origin_query, "DocInfos": DocInfos})
-        #     else:
-        #         prompt += "\n" + QUERY_CONTEXT_PROMPT_INPUT.format(**{"context": context, "query": query.origin_query, "DocInfos": DocInfos})
-        # else:
-        #     prompt += "\n" + BASE_PROMPT_INPUT.format(**{"query": input_query})
-
-        # prompt = extra_system_prompt.format(**{"query": input_query, "doc_infos": doc_infos, "formatted_tools": formatted_tools, "tool_names": tool_names})
-        while "{{" in prompt or "}}" in prompt:
-            prompt = prompt.replace("{{", "{")
-            prompt = prompt.replace("}}", "}")
-
-        # logger.debug(f"{self.role.role_name} prompt: {prompt}")
-        return prompt
-
-    def create_doc_prompt(self, message: Message) -> str:
-        ''''''
-        db_docs = message.db_docs
-        search_docs = message.search_docs
-        doc_infos = "\n".join([doc.get_snippet() for doc in db_docs] + [doc.get_snippet() for doc in search_docs])
-        return doc_infos or "不存在知识库辅助信息"
-
-    def create_codedoc_prompt(self, message: Message) -> str:
-        ''''''
-        code_docs = message.code_docs
-        doc_infos = "\n".join([doc.get_code() for doc in code_docs])
-        return doc_infos or "不存在代码库辅助信息"
-
-    def create_tools_prompt(self, message: Message) -> str:
-        tools = message.tools
-        tool_strings = []
-        tools_descs = []
-        for tool in tools:
-            args_schema = re.sub("}", "}}}}", re.sub("{", "{{{{", str(tool.args)))
-            tool_strings.append(f"{tool.name}: {tool.description}, args: {args_schema}")
-            tools_descs.append(f"{tool.name}: {tool.description}")
-        formatted_tools = "\n".join(tool_strings)
-        tools_desc_str = "\n".join(tools_descs)
-        tool_names = ", ".join([tool.name for tool in tools])
-        return formatted_tools, tool_names, tools_desc_str
-
-    def create_task_prompt(self, message: Message) -> str:
-        task = message.task or self.task
-        return "\n任务目标: " + task.task_prompt if task is not None else None
-
-    def create_background_prompt(self, background: Memory, control_key="role_content") -> str:
-        background_message = None if background is None else background.to_str_messages(content_key=control_key)
-        # logger.debug(f"background_message: {background_message}")
-        if background_message:
-            background_message = re.sub("}", "}}", re.sub("{", "{{", background_message))
-        return "\n背景信息: " + background_message if background_message else None
-
-    def create_history_prompt(self, history: Memory, control_key="role_content") -> str:
-        history_message = None if history is None else history.to_str_messages(content_key=control_key)
-        if history_message:
-            history_message = re.sub("}", "}}", re.sub("{", "{{", history_message))
-        return "\n补充对话信息: " + history_message if history_message else None
-
-    def create_selfmemory_prompt(self, selfmemory: Memory, control_key="role_content") -> str:
-        selfmemory_message = None if selfmemory is None else selfmemory.to_str_messages(content_key=control_key)
-        if selfmemory_message:
-            selfmemory_message = re.sub("}", "}}", re.sub("{", "{{", selfmemory_message))
-        return "\n补充自身对话信息: " + selfmemory_message if selfmemory_message else None
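Throughout this commit, agents now resolve their memory pool via memory_manager.get_memory_pool(query.user_name) instead of one shared current_memory, which is what makes the manager multi-tenant. A toy illustration of that contract (this dict-backed manager is hypothetical; the real one is LocalMemoryManager):

class ToyMemoryManager:
    def __init__(self):
        self.pools = {}  # one memory pool per user_name

    def append(self, message):
        self.pools.setdefault(message["user_name"], []).append(message)

    def get_memory_pool(self, user_name):
        return self.pools.get(user_name, [])

mm = ToyMemoryManager()
mm.append({"user_name": "alice", "content": "hi"})
assert mm.get_memory_pool("alice") and not mm.get_memory_pool("bob")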

View File

@@ -2,14 +2,15 @@ from typing import List, Union
 import copy
 from loguru import logger

+from langchain.schema import BaseRetriever
+
 from coagent.connector.schema import (
     Memory, Task, Env, Role, Message, ActionStatus, PromptField, LogVerboseEnum
 )
 from coagent.connector.memory_manager import BaseMemoryManager
-from coagent.connector.configs.prompts import BEGIN_PROMPT_INPUT
 from coagent.llm_models import LLMConfig, EmbedConfig
 from coagent.connector.memory_manager import LocalMemoryManager
+from coagent.base_configs.env_config import JUPYTER_WORK_PATH, KB_ROOT_PATH

 from .base_agent import BaseAgent

@@ -17,7 +18,7 @@ class ExecutorAgent(BaseAgent):
     def __init__(
             self,
             role: Role,
-            prompt_config: [PromptField],
+            prompt_config: List[PromptField],
             prompt_manager_type: str = "PromptManager",
             task: Task = None,
             memory: Memory = None,
@@ -28,14 +29,17 @@ class ExecutorAgent(BaseAgent):
             llm_config: LLMConfig = None,
             embed_config: EmbedConfig = None,
             sandbox_server: dict = {},
-            jupyter_work_path: str = "",
-            kb_root_path: str = "",
+            jupyter_work_path: str = JUPYTER_WORK_PATH,
+            kb_root_path: str = KB_ROOT_PATH,
+            doc_retrieval: Union[BaseRetriever] = None,
+            code_retrieval = None,
+            search_retrieval = None,
             log_verbose: str = "0"
     ):
         super().__init__(role, prompt_config, prompt_manager_type, task, memory, chat_turn,
                          focus_agents, focus_message_keys, llm_config, embed_config, sandbox_server,
-                         jupyter_work_path, kb_root_path, log_verbose
+                         jupyter_work_path, kb_root_path, doc_retrieval, code_retrieval, search_retrieval, log_verbose
                          )
         self.do_all_task = True  # run all tasks
@@ -45,6 +49,7 @@ class ExecutorAgent(BaseAgent):
         task_executor_memory = Memory(messages=[])
         # insert query
         output_message = Message(
+            user_name=query.user_name,
             role_name=self.role.role_name,
             role_type="assistant",  # self.role.role_type,
             role_content=query.input_query,
@@ -115,7 +120,7 @@ class ExecutorAgent(BaseAgent):
                       history: Memory, background: Memory, memory_manager: BaseMemoryManager,
                       task_memory: Memory) -> Union[Message, Memory]:
         '''execute the llm predict by created prompt'''
-        memory_pool = memory_manager.current_memory
+        memory_pool = memory_manager.get_memory_pool(query.user_name)
         prompt = self.prompt_manager.generate_full_prompt(
             previous_agent_message=query, agent_long_term_memory=self_memory, ui_history=history, chain_summary_messages=background, memory_pool=memory_pool,
             task_memory=task_memory)

View File

@@ -3,23 +3,23 @@ import traceback
 import copy
 from loguru import logger

+from langchain.schema import BaseRetriever
+
 from coagent.connector.schema import (
     Memory, Task, Env, Role, Message, ActionStatus, PromptField, LogVerboseEnum
 )
 from coagent.connector.memory_manager import BaseMemoryManager
-from coagent.connector.configs.agent_config import REACT_PROMPT_INPUT
 from coagent.llm_models import LLMConfig, EmbedConfig
 from .base_agent import BaseAgent
 from coagent.connector.memory_manager import LocalMemoryManager
-from coagent.connector.prompt_manager import PromptManager
+from coagent.base_configs.env_config import JUPYTER_WORK_PATH, KB_ROOT_PATH


 class ReactAgent(BaseAgent):
     def __init__(
             self,
             role: Role,
-            prompt_config: [PromptField],
+            prompt_config: List[PromptField],
             prompt_manager_type: str = "PromptManager",
             task: Task = None,
             memory: Memory = None,
@@ -30,14 +30,17 @@ class ReactAgent(BaseAgent):
             llm_config: LLMConfig = None,
             embed_config: EmbedConfig = None,
             sandbox_server: dict = {},
-            jupyter_work_path: str = "",
-            kb_root_path: str = "",
+            jupyter_work_path: str = JUPYTER_WORK_PATH,
+            kb_root_path: str = KB_ROOT_PATH,
+            doc_retrieval: Union[BaseRetriever] = None,
+            code_retrieval = None,
+            search_retrieval = None,
             log_verbose: str = "0"
     ):
         super().__init__(role, prompt_config, prompt_manager_type, task, memory, chat_turn,
                          focus_agents, focus_message_keys, llm_config, embed_config, sandbox_server,
-                         jupyter_work_path, kb_root_path, log_verbose
+                         jupyter_work_path, kb_root_path, doc_retrieval, code_retrieval, search_retrieval, log_verbose
                          )

     def step(self, query: Message, history: Memory = None, background: Memory = None, memory_manager: BaseMemoryManager = None) -> Message:
@@ -52,6 +55,7 @@ class ReactAgent(BaseAgent):
         react_memory = Memory(messages=[])
         # insert query
         output_message = Message(
+            user_name=query.user_name,
             role_name=self.role.role_name,
             role_type="assistant",  # self.role.role_type,
             role_content=query.input_query,
@@ -84,9 +88,7 @@ class ReactAgent(BaseAgent):
                 llm_config=self.embed_config
             )
             memory_manager.append(query)
-            memory_pool = memory_manager.current_memory
-        else:
-            memory_pool = memory_manager.current_memory
+        memory_pool = memory_manager.get_memory_pool(query_c.user_name)

         prompt = self.prompt_manager.generate_full_prompt(
             previous_agent_message=query_c, agent_long_term_memory=self.memory, ui_history=history, chain_summary_messages=background, react_memory=react_memory,
@@ -142,82 +144,4 @@ class ReactAgent(BaseAgent):
         title = f"<<<<{self.role.role_name}'s prompt>>>>"
         print("#"*len(title) + f"\n{title}\n" + "#"*len(title) + f"\n\n{prompt}\n\n")
-
-    # def create_prompt(
-    #         self, query: Message, memory: Memory = None, history: Memory = None, background: Memory = None, react_memory: Memory = None, memory_manager: BaseMemoryManager = None,
-    #         prompt_mamnger=None) -> str:
-    #     prompt_mamnger = PromptManager()
-    #     prompt_mamnger.register_standard_fields()
-
-    #     # input_keys = parse_section(self.role.role_prompt, 'Agent Profile')
-
-    #     data_dict = {
-    #         "agent_profile": extract_section(self.role.role_prompt, 'Agent Profile'),
-    #         "tool_information": query.tools,
-    #         "session_records": memory_manager,
-    #         "reference_documents": query,
-    #         "output_format": extract_section(self.role.role_prompt, 'Response Output Format'),
-    #         "response": "\n".join(["\n".join([f"**{k}:**\n{v}" for k, v in _dict.items()]) for _dict in react_memory.get_parserd_output()]),
-    #     }
-    #     # logger.debug(memory_pool)
-    #     return prompt_mamnger.generate_full_prompt(data_dict)
-
-    def create_prompt(
-            self, query: Message, memory: Memory = None, history: Memory = None, background: Memory = None, react_memory: Memory = None, memory_pool: Memory = None,
-            prompt_mamnger=None) -> str:
-        '''
-        role\task\tools\docs\memory
-        '''
-        #
-        doc_infos = self.create_doc_prompt(query)
-        code_infos = self.create_codedoc_prompt(query)
-        #
-        formatted_tools, tool_names, _ = self.create_tools_prompt(query)
-        task_prompt = self.create_task_prompt(query)
-        background_prompt = self.create_background_prompt(background)
-        history_prompt = self.create_history_prompt(history)
-        selfmemory_prompt = self.create_selfmemory_prompt(memory, control_key="step_content")
-        #
-        # extra_system_prompt = self.role.role_prompt
-        prompt = self.role.role_prompt.format(**{"formatted_tools": formatted_tools, "tool_names": tool_names})

-        # the react flow iterates on itself; a second trigger should be treated as dialogue history
-        # input_query = react_memory.to_tuple_messages(content_key="step_content")
-        # # input_query = query.input_query + "\n" + "\n".join([f"{v}" for k, v in input_query if v])
-        # input_query = "\n".join([f"{v}" for k, v in input_query if v])
-        input_query = "\n".join(["\n".join([f"**{k}:**\n{v}" for k, v in _dict.items()]) for _dict in react_memory.get_parserd_output()])
-        # logger.debug(f"input_query: {input_query}")
-        prompt += "\n" + REACT_PROMPT_INPUT.format(**{"query": input_query})

-        task = query.task or self.task
-        # if task_prompt is not None:
-        #     prompt += "\n" + task.task_prompt

-        # if doc_infos is not None and doc_infos != "" and doc_infos != "不存在知识库辅助信息":
-        #     prompt += f"\n知识库信息: {doc_infos}"
-        # if code_infos is not None and code_infos != "" and code_infos != "不存在代码库辅助信息":
-        #     prompt += f"\n代码库信息: {code_infos}"

-        # if background_prompt:
-        #     prompt += "\n" + background_prompt
-        # if history_prompt:
-        #     prompt += "\n" + history_prompt
-        # if selfmemory_prompt:
-        #     prompt += "\n" + selfmemory_prompt

-        # logger.debug(f"{self.role.role_name} extra_system_prompt: {self.role.role_prompt}")
-        # logger.debug(f"{self.role.role_name} input_query: {input_query}")
-        # logger.debug(f"{self.role.role_name} doc_infos: {doc_infos}")
-        # logger.debug(f"{self.role.role_name} tool_names: {tool_names}")
-        # prompt += "\n" + REACT_PROMPT_INPUT.format(**{"query": input_query})
-        # prompt = extra_system_prompt.format(**{"query": input_query, "doc_infos": doc_infos, "formatted_tools": formatted_tools, "tool_names": tool_names})
-        while "{{" in prompt or "}}" in prompt:
-            prompt = prompt.replace("{{", "{")
-            prompt = prompt.replace("}}", "}")
-        return prompt

View File

@@ -3,13 +3,15 @@ import copy
 import random
 from loguru import logger

+from langchain.schema import BaseRetriever
+
 from coagent.connector.schema import (
     Memory, Task, Role, Message, PromptField, LogVerboseEnum
 )
 from coagent.connector.memory_manager import BaseMemoryManager
-from coagent.connector.configs.prompts import BEGIN_PROMPT_INPUT
 from coagent.connector.memory_manager import LocalMemoryManager
 from coagent.llm_models import LLMConfig, EmbedConfig
+from coagent.base_configs.env_config import JUPYTER_WORK_PATH, KB_ROOT_PATH

 from .base_agent import BaseAgent

@@ -30,14 +32,17 @@ class SelectorAgent(BaseAgent):
             llm_config: LLMConfig = None,
             embed_config: EmbedConfig = None,
             sandbox_server: dict = {},
-            jupyter_work_path: str = "",
-            kb_root_path: str = "",
+            jupyter_work_path: str = JUPYTER_WORK_PATH,
+            kb_root_path: str = KB_ROOT_PATH,
+            doc_retrieval: Union[BaseRetriever] = None,
+            code_retrieval = None,
+            search_retrieval = None,
             log_verbose: str = "0"
     ):
         super().__init__(role, prompt_config, prompt_manager_type, task, memory, chat_turn,
                          focus_agents, focus_message_keys, llm_config, embed_config, sandbox_server,
-                         jupyter_work_path, kb_root_path, log_verbose
+                         jupyter_work_path, kb_root_path, doc_retrieval, code_retrieval, search_retrieval, log_verbose
                          )
         self.group_agents = group_agents
@@ -56,9 +61,8 @@ class SelectorAgent(BaseAgent):
                 llm_config=self.embed_config
             )
             memory_manager.append(query)
-            memory_pool = memory_manager.current_memory
-        else:
-            memory_pool = memory_manager.current_memory
+        memory_pool = memory_manager.get_memory_pool(query_c.user_name)

         prompt = self.prompt_manager.generate_full_prompt(
             previous_agent_message=query_c, agent_long_term_memory=self.memory, ui_history=history, chain_summary_messages=background, react_memory=None,
             memory_pool=memory_pool, agents=self.group_agents)
@@ -90,6 +94,9 @@ class SelectorAgent(BaseAgent):
             for agent in self.group_agents:
                 if agent.role.role_name == select_message.parsed_output.get("Role", ""):
                     break
+
+            # pass everything except Role on to the next agent
+            query_c.parsed_output.update({k: v for k, v in select_message.parsed_output.items() if k != "Role"})
             for output_message in agent.astep(query_c, history, background=background, memory_manager=memory_manager):
                 yield output_message or select_message
             # update self_memory
@@ -103,6 +110,7 @@ class SelectorAgent(BaseAgent):
             memory_manager.append(output_message)

             select_message.parsed_output = output_message.parsed_output
+            select_message.spec_parsed_output.update(output_message.spec_parsed_output)
             select_message.parsed_output_list.extend(output_message.parsed_output_list)

             yield select_message
@@ -115,76 +123,3 @@ class SelectorAgent(BaseAgent):
         for agent in self.group_agents:
             agent.pre_print(query=query, history=history, background=background, memory_manager=memory_manager)
-
-    # def create_prompt(
-    #         self, query: Message, memory: Memory = None, history: Memory = None, background: Memory = None, memory_manager: BaseMemoryManager = None, prompt_mamnger=None) -> str:
-    #     '''
-    #     role\task\tools\docs\memory
-    #     '''
-    #     #
-    #     doc_infos = self.create_doc_prompt(query)
-    #     code_infos = self.create_codedoc_prompt(query)
-    #     #
-    #     formatted_tools, tool_names, tools_descs = self.create_tools_prompt(query)
-    #     agent_names, agents = self.create_agent_names()
-    #     task_prompt = self.create_task_prompt(query)
-    #     background_prompt = self.create_background_prompt(background)
-    #     history_prompt = self.create_history_prompt(history)
-    #     selfmemory_prompt = self.create_selfmemory_prompt(memory, control_key="step_content")

-    #     DocInfos = ""
-    #     if doc_infos is not None and doc_infos != "" and doc_infos != "不存在知识库辅助信息":
-    #         DocInfos += f"\nDocument Information: {doc_infos}"
-    #     if code_infos is not None and code_infos != "" and code_infos != "不存在代码库辅助信息":
-    #         DocInfos += f"\nCodeBase Infomation: {code_infos}"

-    #     input_query = query.input_query
-    #     logger.debug(f"{self.role.role_name} input_query: {input_query}")
-    #     prompt = self.role.role_prompt.format(**{"agent_names": agent_names, "agents": agents, "formatted_tools": tools_descs, "tool_names": tool_names})
-    #     #
-    #     memory_pool_select_by_agent_key = self.select_memory_by_agent_key(memory_manager.current_memory)
-    #     memory_pool_select_by_agent_key_context = '\n\n'.join([f"*{k}*\n{v}" for parsed_output in memory_pool_select_by_agent_key.get_parserd_output_list() for k, v in parsed_output.items() if k not in ['Action Status']])

-    #     input_keys = parse_section(self.role.role_prompt, 'Input Format')
-    #     #
-    #     prompt += "\n" + BEGIN_PROMPT_INPUT
-    #     for input_key in input_keys:
-    #         if input_key == "Origin Query":
-    #             prompt += "\n**Origin Query:**\n" + query.origin_query
-    #         elif input_key == "Context":
-    #             context = "\n".join([f"*{k}*\n{v}" for i in query.parsed_output_list for k, v in i.items() if "Action Status" != k])
-    #             if history:
-    #                 context = history_prompt + "\n" + context
-    #             if not context:
-    #                 context = "there is no context"
-    #             if self.focus_agents and memory_pool_select_by_agent_key_context:
-    #                 context = memory_pool_select_by_agent_key_context
-    #             prompt += "\n**Context:**\n" + context + "\n" + input_query
-    #         elif input_key == "DocInfos":
-    #             prompt += "\n**DocInfos:**\n" + DocInfos
-    #         elif input_key == "Question":
-    #             prompt += "\n**Question:**\n" + input_query

-    #     while "{{" in prompt or "}}" in prompt:
-    #         prompt = prompt.replace("{{", "{")
-    #         prompt = prompt.replace("}}", "}")

-    #     # logger.debug(f"{self.role.role_name} prompt: {prompt}")
-    #     return prompt

-    # def create_agent_names(self):
-    #     random.shuffle(self.group_agents)
-    #     agent_names = ", ".join([f'{agent.role.role_name}' for agent in self.group_agents])
-    #     agent_descs = []
-    #     for agent in self.group_agents:
-    #         role_desc = agent.role.role_prompt.split("####")[1]
-    #         while "\n\n" in role_desc:
-    #             role_desc = role_desc.replace("\n\n", "\n")
-    #         role_desc = role_desc.replace("\n", ",")
-    #         agent_descs.append(f'"role name: {agent.role.role_name}\nrole description: {role_desc}"')
-    #     return agent_names, "\n".join(agent_descs)

View File

@ -0,0 +1,7 @@
from .flow import AgentFlow, PhaseFlow, ChainFlow
__all__ = [
"AgentFlow", "PhaseFlow", "ChainFlow"
]

View File

@ -0,0 +1,255 @@
import importlib
from typing import List, Union, Dict, Any
from loguru import logger
import os
from langchain.embeddings.base import Embeddings
from langchain.agents import Tool
from langchain.llms.base import BaseLLM, LLM
from coagent.retrieval.base_retrieval import IMRertrieval
from coagent.llm_models.llm_config import EmbedConfig, LLMConfig
from coagent.connector.phase import BasePhase
from coagent.connector.agents import BaseAgent
from coagent.connector.chains import BaseChain
from coagent.connector.schema import Message, Role, PromptField, ChainConfig
from coagent.tools import toLangchainTools, TOOL_DICT, TOOL_SETS
class AgentFlow:
def __init__(
self,
role_name: str,
agent_type: str,
role_type: str = "assistant",
agent_index: int = 0,
role_prompt: str = "",
prompt_config: List[Dict[str, Any]] = [],
prompt_manager_type: str = "PromptManager",
chat_turn: int = 3,
focus_agents: List[str] = [],
focus_messages: List[str] = [],
embeddings: Embeddings = None,
llm: BaseLLM = None,
doc_retrieval: IMRertrieval = None,
code_retrieval: IMRertrieval = None,
search_retrieval: IMRertrieval = None,
**kwargs
):
self.role_type = role_type
self.role_name = role_name
self.agent_type = agent_type
self.role_prompt = role_prompt
self.agent_index = agent_index
self.prompt_config = prompt_config
self.prompt_manager_type = prompt_manager_type
self.chat_turn = chat_turn
self.focus_agents = focus_agents
self.focus_messages = focus_messages
self.embeddings = embeddings
self.llm = llm
self.doc_retrieval = doc_retrieval
self.code_retrieval = code_retrieval
self.search_retrieval = search_retrieval
# self.build_config()
# self.build_agent()
def build_config(self, embeddings: Embeddings = None, llm: BaseLLM = None):
self.llm_config = LLMConfig(model_name="test", llm=self.llm or llm)
self.embed_config = EmbedConfig(embed_model="test", langchain_embeddings=self.embeddings or embeddings)
def build_agent(self,
embeddings: Embeddings = None, llm: BaseLLM = None,
doc_retrieval: IMRertrieval = None,
code_retrieval: IMRertrieval = None,
search_retrieval: IMRertrieval = None,
):
# a custom agent can be registered here, customizing only start_action and end_action:
# class ExtraAgent(BaseAgent):
# def start_action_step(self, message: Message) -> Message:
# pass
# def end_action_step(self, message: Message) -> Message:
# pass
# agent_module = importlib.import_module("coagent.connector.agents")
# setattr(agent_module, 'extraAgent', ExtraAgent)
# a custom prompt-assembly method can be registered the same way:
# class CodeRetrievalPM(PromptManager):
# def handle_code_packages(self, **kwargs) -> str:
# if 'previous_agent_message' not in kwargs:
# return ""
# previous_agent_message: Message = kwargs['previous_agent_message']
# temporary handling: the two agents share the same prompt manager
# vertices = previous_agent_message.customed_kargs.get("RelatedVerticesRetrivalRes", {}).get("vertices", [])
# return ", ".join([str(v) for v in vertices])
# prompt_manager_module = importlib.import_module("coagent.connector.prompt_manager")
# setattr(prompt_manager_module, 'CodeRetrievalPM', CodeRetrievalPM)
# instantiate the agent
agent_module = importlib.import_module("coagent.connector.agents")
baseAgent: BaseAgent = getattr(agent_module, self.agent_type)
role = Role(
role_type=self.agent_type, role_name=self.role_name,
agent_type=self.agent_type, role_prompt=self.role_prompt,
)
self.build_config(embeddings, llm)
self.agent = baseAgent(
role=role,
prompt_config = [PromptField(**config) for config in self.prompt_config],
prompt_manager_type=self.prompt_manager_type,
chat_turn=self.chat_turn,
focus_agents=self.focus_agents,
focus_message_keys=self.focus_messages,
llm_config=self.llm_config,
embed_config=self.embed_config,
doc_retrieval=doc_retrieval or self.doc_retrieval,
code_retrieval=code_retrieval or self.code_retrieval,
search_retrieval=search_retrieval or self.search_retrieval,
)
class ChainFlow:
def __init__(
self,
chain_name: str,
chain_index: int = 0,
agent_flows: List[AgentFlow] = [],
chat_turn: int = 5,
do_checker: bool = False,
embeddings: Embeddings = None,
llm: BaseLLM = None,
doc_retrieval: IMRertrieval = None,
code_retrieval: IMRertrieval = None,
search_retrieval: IMRertrieval = None,
# chain_type: str = "BaseChain",
**kwargs
):
self.agent_flows = sorted(agent_flows, key=lambda x:x.agent_index)
self.chat_turn = chat_turn
self.do_checker = do_checker
self.chain_name = chain_name
self.chain_index = chain_index
self.chain_type = "BaseChain"
self.embeddings = embeddings
self.llm = llm
self.doc_retrieval = doc_retrieval
self.code_retrieval = code_retrieval
self.search_retrieval = search_retrieval
# self.build_config()
# self.build_chain()
def build_config(self, embeddings: Embeddings = None, llm: BaseLLM = None):
self.llm_config = LLMConfig(model_name="test", llm=self.llm or llm)
self.embed_config = EmbedConfig(embed_model="test", langchain_embeddings=self.embeddings or embeddings)
def build_chain(self,
embeddings: Embeddings = None, llm: BaseLLM = None,
doc_retrieval: IMRertrieval = None,
code_retrieval: IMRertrieval = None,
search_retrieval: IMRertrieval = None,
):
# instantiate the chain
chain_module = importlib.import_module("coagent.connector.chains")
baseChain: BaseChain = getattr(chain_module, self.chain_type)
agent_names = [agent_flow.role_name for agent_flow in self.agent_flows]
chain_config = ChainConfig(chain_name=self.chain_name, agents=agent_names, do_checker=self.do_checker, chat_turn=self.chat_turn)
# instantiate the agents
self.build_config(embeddings, llm)
for agent_flow in self.agent_flows:
agent_flow.build_agent(embeddings, llm)
self.chain = baseChain(
chain_config,
[agent_flow.agent for agent_flow in self.agent_flows],
embed_config=self.embed_config,
llm_config=self.llm_config,
doc_retrieval=doc_retrieval or self.doc_retrieval,
code_retrieval=code_retrieval or self.code_retrieval,
search_retrieval=search_retrieval or self.search_retrieval,
)
class PhaseFlow:
def __init__(
self,
phase_name: str,
chain_flows: List[ChainFlow],
embeddings: Embeddings = None,
llm: BaseLLM = None,
tools: List[Tool] = [],
doc_retrieval: IMRertrieval = None,
code_retrieval: IMRertrieval = None,
search_retrieval: IMRertrieval = None,
**kwargs
):
self.phase_name = phase_name
self.chain_flows = sorted(chain_flows, key=lambda x:x.chain_index)
self.phase_type = "BasePhase"
self.tools = tools
self.embeddings = embeddings
self.llm = llm
self.doc_retrieval = doc_retrieval
self.code_retrieval = code_retrieval
self.search_retrieval = search_retrieval
# self.build_config()
self.build_phase()
def __call__(self, params: dict) -> str:
# tools = toLangchainTools([TOOL_DICT[i] for i in TOOL_SETS if i in TOOL_DICT])
# query_content = "帮我确认下127.0.0.1这个服务器的在10点是否存在异常请帮我判断一下"
try:
logger.info(f"params: {params}")
query_content = params.get("query") or params.get("input")
search_type = params.get("search_type")
query = Message(
role_name="human", role_type="user", tools=self.tools,
role_content=query_content, input_query=query_content, origin_query=query_content,
cb_search_type=search_type,
)
# phase.pre_print(query)
output_message, output_memory = self.phase.step(query)
output_content = "\n\n".join((output_memory.to_str_messages(return_all=True, content_key="parsed_output_list").split("\n\n")[1:])) or output_message.role_content
return output_content
except Exception as e:
logger.exception(e)
return f"Error {e}"
def build_config(self, embeddings: Embeddings = None, llm: BaseLLM = None):
self.llm_config = LLMConfig(model_name="test", llm=self.llm or llm)
self.embed_config = EmbedConfig(embed_model="test", langchain_embeddings=self.embeddings or embeddings)
def build_phase(self, embeddings: Embeddings = None, llm: BaseLLM = None):
# instantiate the phase
phase_module = importlib.import_module("coagent.connector.phase")
basePhase: BasePhase = getattr(phase_module, self.phase_type)
# instantiate the chains
self.build_config(self.embeddings or embeddings, self.llm or llm)
os.environ["log_verbose"] = "2"
for chain_flow in self.chain_flows:
chain_flow.build_chain(
self.embeddings or embeddings, self.llm or llm,
self.doc_retrieval, self.code_retrieval, self.search_retrieval
)
self.phase: BasePhase = basePhase(
phase_name=self.phase_name,
chains=[chain_flow.chain for chain_flow in self.chain_flows],
embed_config=self.embed_config,
llm_config=self.llm_config,
doc_retrieval=self.doc_retrieval,
code_retrieval=self.code_retrieval,
search_retrieval=self.search_retrieval
)
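
Taken together, `AgentFlow`, `ChainFlow` and `PhaseFlow` let a caller assemble a runnable phase from plain Python objects instead of the config registries. A minimal usage sketch follows; the import path `coagent.connector.antflow`, the fake LangChain models, and the role prompt are illustrative assumptions, not part of this commit:

```python
from langchain.llms.fake import FakeListLLM
from langchain.embeddings.fake import FakeEmbeddings

# hypothetical import path; point it at wherever this flow module lives
from coagent.connector.antflow import AgentFlow, ChainFlow, PhaseFlow

llm = FakeListLLM(responses=["**Role:** class2Docer"])  # stand-in for any BaseLLM
embeddings = FakeEmbeddings(size=128)                   # stand-in for any Embeddings

doc_agent = AgentFlow(
    role_name="class2Docer",
    agent_type="CodeGenDocer",
    role_prompt="#### Agent Profile\n...",  # shortened; use a real prompt here
    prompt_config=[],                       # e.g. CODE2DOC_PROMPT_CONFIGS
    agent_index=0,
)
chain = ChainFlow(chain_name="code2DocChain", chain_index=0, agent_flows=[doc_agent])
phase = PhaseFlow(phase_name="code2DocPhase", chain_flows=[chain],
                  llm=llm, embeddings=embeddings)

print(phase({"query": "Generate documentation for class HelloWorld"}))
```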

View File

@ -1,9 +1,10 @@
from typing import List from typing import List, Tuple, Union
from loguru import logger from loguru import logger
import copy, os import copy, os
from coagent.connector.agents import BaseAgent from langchain.schema import BaseRetriever
from coagent.connector.agents import BaseAgent
from coagent.connector.schema import ( from coagent.connector.schema import (
Memory, Role, Message, ActionStatus, ChainConfig, Memory, Role, Message, ActionStatus, ChainConfig,
load_role_configs load_role_configs
@ -11,31 +12,32 @@ from coagent.connector.schema import (
from coagent.connector.memory_manager import BaseMemoryManager from coagent.connector.memory_manager import BaseMemoryManager
from coagent.connector.message_process import MessageUtils from coagent.connector.message_process import MessageUtils
from coagent.llm_models.llm_config import LLMConfig, EmbedConfig from coagent.llm_models.llm_config import LLMConfig, EmbedConfig
from coagent.base_configs.env_config import JUPYTER_WORK_PATH, KB_ROOT_PATH
from coagent.connector.configs.agent_config import AGETN_CONFIGS from coagent.connector.configs.agent_config import AGETN_CONFIGS
role_configs = load_role_configs(AGETN_CONFIGS) role_configs = load_role_configs(AGETN_CONFIGS)
# from configs.model_config import JUPYTER_WORK_PATH
# from configs.server_config import SANDBOX_SERVER
class BaseChain: class BaseChain:
def __init__( def __init__(
self, self,
# chainConfig: ChainConfig, chainConfig: ChainConfig,
agents: List[BaseAgent], agents: List[BaseAgent],
chat_turn: int = 1, # chat_turn: int = 1,
do_checker: bool = False, # do_checker: bool = False,
sandbox_server: dict = {}, sandbox_server: dict = {},
jupyter_work_path: str = "", jupyter_work_path: str = JUPYTER_WORK_PATH,
kb_root_path: str = "", kb_root_path: str = KB_ROOT_PATH,
llm_config: LLMConfig = LLMConfig(), llm_config: LLMConfig = LLMConfig(),
embed_config: EmbedConfig = None, embed_config: EmbedConfig = None,
doc_retrieval: Union[BaseRetriever] = None,
code_retrieval = None,
search_retrieval = None,
log_verbose: str = "0" log_verbose: str = "0"
) -> None: ) -> None:
# self.chainConfig = chainConfig self.chainConfig = chainConfig
self.agents: List[BaseAgent] = agents self.agents: List[BaseAgent] = agents
self.chat_turn = chat_turn self.chat_turn = chainConfig.chat_turn
self.do_checker = do_checker self.do_checker = chainConfig.do_checker
self.sandbox_server = sandbox_server self.sandbox_server = sandbox_server
self.jupyter_work_path = jupyter_work_path self.jupyter_work_path = jupyter_work_path
self.llm_config = llm_config self.llm_config = llm_config
@ -45,9 +47,11 @@ class BaseChain:
task = None, memory = None, task = None, memory = None,
llm_config=llm_config, embed_config=embed_config, llm_config=llm_config, embed_config=embed_config,
sandbox_server=sandbox_server, jupyter_work_path=jupyter_work_path, sandbox_server=sandbox_server, jupyter_work_path=jupyter_work_path,
kb_root_path=kb_root_path kb_root_path=kb_root_path,
doc_retrieval=doc_retrieval, code_retrieval=code_retrieval,
search_retrieval=search_retrieval
) )
self.messageUtils = MessageUtils(None, sandbox_server, self.jupyter_work_path, embed_config, llm_config, kb_root_path, log_verbose) self.messageUtils = MessageUtils(None, sandbox_server, self.jupyter_work_path, embed_config, llm_config, kb_root_path, doc_retrieval, code_retrieval, search_retrieval, log_verbose)
# all memory created by agent until instance deleted # all memory created by agent until instance deleted
self.global_memory = Memory(messages=[]) self.global_memory = Memory(messages=[])
@ -62,13 +66,16 @@ class BaseChain:
for agent in self.agents: for agent in self.agents:
agent.pre_print(query, history, background=background, memory_manager=memory_manager) agent.pre_print(query, history, background=background, memory_manager=memory_manager)
def astep(self, query: Message, history: Memory = None, background: Memory = None, memory_manager: BaseMemoryManager = None) -> Message: def astep(self, query: Message, history: Memory = None, background: Memory = None, memory_manager: BaseMemoryManager = None) -> Tuple[Message, Memory]:
'''execute chain''' '''execute chain'''
local_memory = Memory(messages=[]) local_memory = Memory(messages=[])
input_message = copy.deepcopy(query) input_message = copy.deepcopy(query)
step_nums = copy.deepcopy(self.chat_turn) step_nums = copy.deepcopy(self.chat_turn)
check_message = None check_message = None
# if input_message not in memory_manager:
# memory_manager.append(input_message)
self.global_memory.append(input_message) self.global_memory.append(input_message)
# local_memory.append(input_message) # local_memory.append(input_message)
while step_nums > 0: while step_nums > 0:
@ -78,7 +85,7 @@ class BaseChain:
yield output_message, local_memory + output_message yield output_message, local_memory + output_message
output_message = self.messageUtils.inherit_extrainfo(input_message, output_message) output_message = self.messageUtils.inherit_extrainfo(input_message, output_message)
# according the output to choose one action for code_content or tool_content # according the output to choose one action for code_content or tool_content
output_message = self.messageUtils.parser(output_message) # output_message = self.messageUtils.parser(output_message)
yield output_message, local_memory + output_message yield output_message, local_memory + output_message
# output_message = self.step_router(output_message) # output_message = self.step_router(output_message)
input_message = output_message input_message = output_message
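
The net effect of this hunk: `BaseChain` is now driven by a required `ChainConfig` rather than loose `chat_turn`/`do_checker` kwargs, and the optional retrievers are threaded through to `MessageUtils`. A sketch of the new call shape, assuming `qa_agent`, `llm_config`, `embed_config` and `doc_retriever` were built elsewhere:

```python
from coagent.connector.chains import BaseChain
from coagent.connector.schema import ChainConfig

chain_config = ChainConfig(
    chain_name="docChatChain",
    agents=["qaer"],   # agent role names, in execution order
    chat_turn=1,       # now read from the config, not a constructor kwarg
    do_checker=False,
)
chain = BaseChain(
    chain_config, [qa_agent],
    llm_config=llm_config, embed_config=embed_config,
    doc_retrieval=doc_retriever,  # optional LangChain BaseRetriever
)
```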

View File

@ -1,9 +1,10 @@
from .agent_config import AGETN_CONFIGS from .agent_config import AGETN_CONFIGS
from .chain_config import CHAIN_CONFIGS from .chain_config import CHAIN_CONFIGS
from .phase_config import PHASE_CONFIGS from .phase_config import PHASE_CONFIGS
from .prompt_config import BASE_PROMPT_CONFIGS, EXECUTOR_PROMPT_CONFIGS, SELECTOR_PROMPT_CONFIGS, BASE_NOTOOLPROMPT_CONFIGS from .prompt_config import *
__all__ = [ __all__ = [
"AGETN_CONFIGS", "CHAIN_CONFIGS", "PHASE_CONFIGS", "AGETN_CONFIGS", "CHAIN_CONFIGS", "PHASE_CONFIGS",
"BASE_PROMPT_CONFIGS", "EXECUTOR_PROMPT_CONFIGS", "SELECTOR_PROMPT_CONFIGS", "BASE_NOTOOLPROMPT_CONFIGS" "BASE_PROMPT_CONFIGS", "EXECUTOR_PROMPT_CONFIGS", "SELECTOR_PROMPT_CONFIGS", "BASE_NOTOOLPROMPT_CONFIGS",
"CODE2DOC_GROUP_PROMPT_CONFIGS", "CODE2DOC_PROMPT_CONFIGS", "CODE2TESTS_PROMPT_CONFIGS"
] ]

View File

@ -1,19 +1,21 @@
from enum import Enum from enum import Enum
from .prompts import ( from .prompts import *
REACT_PROMPT_INPUT, CHECK_PROMPT_INPUT, EXECUTOR_PROMPT_INPUT, CONTEXT_PROMPT_INPUT, QUERY_CONTEXT_PROMPT_INPUT,PLAN_PROMPT_INPUT, # from .prompts import (
RECOGNIZE_INTENTION_PROMPT, # REACT_PROMPT_INPUT, CHECK_PROMPT_INPUT, EXECUTOR_PROMPT_INPUT, CONTEXT_PROMPT_INPUT, QUERY_CONTEXT_PROMPT_INPUT,PLAN_PROMPT_INPUT,
CHECKER_TEMPLATE_PROMPT, # RECOGNIZE_INTENTION_PROMPT,
CONV_SUMMARY_PROMPT, # CHECKER_TEMPLATE_PROMPT,
QA_PROMPT, CODE_QA_PROMPT, QA_TEMPLATE_PROMPT, # CONV_SUMMARY_PROMPT,
EXECUTOR_TEMPLATE_PROMPT, # QA_PROMPT, CODE_QA_PROMPT, QA_TEMPLATE_PROMPT,
REFINE_TEMPLATE_PROMPT, # EXECUTOR_TEMPLATE_PROMPT,
SELECTOR_AGENT_TEMPLATE_PROMPT, # REFINE_TEMPLATE_PROMPT,
PLANNER_TEMPLATE_PROMPT, GENERAL_PLANNER_PROMPT, DATA_PLANNER_PROMPT, TOOL_PLANNER_PROMPT, # SELECTOR_AGENT_TEMPLATE_PROMPT,
PRD_WRITER_METAGPT_PROMPT, DESIGN_WRITER_METAGPT_PROMPT, TASK_WRITER_METAGPT_PROMPT, CODE_WRITER_METAGPT_PROMPT, # PLANNER_TEMPLATE_PROMPT, GENERAL_PLANNER_PROMPT, DATA_PLANNER_PROMPT, TOOL_PLANNER_PROMPT,
REACT_TEMPLATE_PROMPT, # PRD_WRITER_METAGPT_PROMPT, DESIGN_WRITER_METAGPT_PROMPT, TASK_WRITER_METAGPT_PROMPT, CODE_WRITER_METAGPT_PROMPT,
REACT_TOOL_PROMPT, REACT_CODE_PROMPT, REACT_TOOL_AND_CODE_PLANNER_PROMPT, REACT_TOOL_AND_CODE_PROMPT # REACT_TEMPLATE_PROMPT,
) # REACT_TOOL_PROMPT, REACT_CODE_PROMPT, REACT_TOOL_AND_CODE_PLANNER_PROMPT, REACT_TOOL_AND_CODE_PROMPT
from .prompt_config import BASE_PROMPT_CONFIGS, EXECUTOR_PROMPT_CONFIGS, SELECTOR_PROMPT_CONFIGS, BASE_NOTOOLPROMPT_CONFIGS # )
from .prompt_config import *
# BASE_PROMPT_CONFIGS, EXECUTOR_PROMPT_CONFIGS, SELECTOR_PROMPT_CONFIGS, BASE_NOTOOLPROMPT_CONFIGS
@ -261,4 +263,68 @@ AGETN_CONFIGS = {
"focus_agents": ["metaGPT_DESIGN", "metaGPT_TASK"], "focus_agents": ["metaGPT_DESIGN", "metaGPT_TASK"],
"focus_message_keys": [], "focus_message_keys": [],
}, },
"class2Docer": {
"role": {
"role_prompt": Class2Doc_PROMPT,
"role_type": "assistant",
"role_name": "class2Docer",
"role_desc": "",
"agent_type": "CodeGenDocer"
},
"prompt_config": CODE2DOC_PROMPT_CONFIGS,
"prompt_manager_type": "Code2DocPM",
"chat_turn": 1,
"focus_agents": [],
"focus_message_keys": [],
},
"func2Docer": {
"role": {
"role_prompt": Func2Doc_PROMPT,
"role_type": "assistant",
"role_name": "func2Docer",
"role_desc": "",
"agent_type": "CodeGenDocer"
},
"prompt_config": CODE2DOC_PROMPT_CONFIGS,
"prompt_manager_type": "Code2DocPM",
"chat_turn": 1,
"focus_agents": [],
"focus_message_keys": [],
},
"code2DocsGrouper": {
"role": {
"role_prompt": Code2DocGroup_PROMPT,
"role_type": "assistant",
"role_name": "code2DocsGrouper",
"role_desc": "",
"agent_type": "SelectorAgent"
},
"prompt_config": CODE2DOC_GROUP_PROMPT_CONFIGS,
"group_agents": ["class2Docer", "func2Docer"],
"chat_turn": 1,
},
"Code2TestJudger": {
"role": {
"role_prompt": judgeCode2Tests_PROMPT,
"role_type": "assistant",
"role_name": "Code2TestJudger",
"role_desc": "",
"agent_type": "CodeRetrieval"
},
"prompt_config": CODE2TESTS_PROMPT_CONFIGS,
"prompt_manager_type": "CodeRetrievalPM",
"chat_turn": 1,
},
"code2Tests": {
"role": {
"role_prompt": code2Tests_PROMPT,
"role_type": "assistant",
"role_name": "code2Tests",
"role_desc": "",
"agent_type": "CodeRetrieval"
},
"prompt_config": CODE2TESTS_PROMPT_CONFIGS,
"prompt_manager_type": "CodeRetrievalPM",
"chat_turn": 1,
},
} }
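
These entries are consumed through `load_role_configs`, as the chain module above does. A sketch of pulling one of the new roles; the attribute names on the parsed config are assumed from usage elsewhere in this diff:

```python
from coagent.connector.schema import load_role_configs
from coagent.connector.configs.agent_config import AGETN_CONFIGS

role_configs = load_role_configs(AGETN_CONFIGS)
doc_agent_config = role_configs["class2Docer"]   # parsed config for the new agent
print(doc_agent_config.role.role_name)           # "class2Docer" (field names assumed)
```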

View File

@ -123,5 +123,21 @@ CHAIN_CONFIGS = {
"chat_turn": 1, "chat_turn": 1,
"do_checker": False, "do_checker": False,
"chain_prompt": "" "chain_prompt": ""
},
"code2DocsGroupChain": {
"chain_name": "code2DocsGroupChain",
"chain_type": "BaseChain",
"agents": ["code2DocsGrouper"],
"chat_turn": 1,
"do_checker": False,
"chain_prompt": ""
},
"code2TestsChain": {
"chain_name": "code2TestsChain",
"chain_type": "BaseChain",
"agents": ["Code2TestJudger", "code2Tests"],
"chat_turn": 1,
"do_checker": False,
"chain_prompt": ""
} }
} }

View File

@ -14,44 +14,24 @@ PHASE_CONFIGS = {
"phase_name": "docChatPhase", "phase_name": "docChatPhase",
"phase_type": "BasePhase", "phase_type": "BasePhase",
"chains": ["docChatChain"], "chains": ["docChatChain"],
"do_summary": False,
"do_search": False,
"do_doc_retrieval": True, "do_doc_retrieval": True,
"do_code_retrieval": False,
"do_tool_retrieval": False,
"do_using_tool": False
}, },
"searchChatPhase": { "searchChatPhase": {
"phase_name": "searchChatPhase", "phase_name": "searchChatPhase",
"phase_type": "BasePhase", "phase_type": "BasePhase",
"chains": ["searchChatChain"], "chains": ["searchChatChain"],
"do_summary": False,
"do_search": True, "do_search": True,
"do_doc_retrieval": False,
"do_code_retrieval": False,
"do_tool_retrieval": False,
"do_using_tool": False
}, },
"codeChatPhase": { "codeChatPhase": {
"phase_name": "codeChatPhase", "phase_name": "codeChatPhase",
"phase_type": "BasePhase", "phase_type": "BasePhase",
"chains": ["codeChatChain"], "chains": ["codeChatChain"],
"do_summary": False,
"do_search": False,
"do_doc_retrieval": False,
"do_code_retrieval": True, "do_code_retrieval": True,
"do_tool_retrieval": False,
"do_using_tool": False
}, },
"toolReactPhase": { "toolReactPhase": {
"phase_name": "toolReactPhase", "phase_name": "toolReactPhase",
"phase_type": "BasePhase", "phase_type": "BasePhase",
"chains": ["toolReactChain"], "chains": ["toolReactChain"],
"do_summary": False,
"do_search": False,
"do_doc_retrieval": False,
"do_code_retrieval": False,
"do_tool_retrieval": False,
"do_using_tool": True "do_using_tool": True
}, },
"codeReactPhase": { "codeReactPhase": {
@ -59,55 +39,36 @@ PHASE_CONFIGS = {
"phase_type": "BasePhase", "phase_type": "BasePhase",
# "chains": ["codePlannerChain", "codeReactChain"], # "chains": ["codePlannerChain", "codeReactChain"],
"chains": ["planChain", "codeReactChain"], "chains": ["planChain", "codeReactChain"],
"do_summary": False,
"do_search": False,
"do_doc_retrieval": False,
"do_code_retrieval": False,
"do_tool_retrieval": False,
"do_using_tool": False
}, },
"codeToolReactPhase": { "codeToolReactPhase": {
"phase_name": "codeToolReactPhase", "phase_name": "codeToolReactPhase",
"phase_type": "BasePhase", "phase_type": "BasePhase",
"chains": ["codeToolPlanChain", "codeToolReactChain"], "chains": ["codeToolPlanChain", "codeToolReactChain"],
"do_summary": False,
"do_search": False,
"do_doc_retrieval": False,
"do_code_retrieval": False,
"do_tool_retrieval": False,
"do_using_tool": True "do_using_tool": True
}, },
"baseTaskPhase": { "baseTaskPhase": {
"phase_name": "baseTaskPhase", "phase_name": "baseTaskPhase",
"phase_type": "BasePhase", "phase_type": "BasePhase",
"chains": ["planChain", "executorChain"], "chains": ["planChain", "executorChain"],
"do_summary": False,
"do_search": False,
"do_doc_retrieval": False,
"do_code_retrieval": False,
"do_tool_retrieval": False,
"do_using_tool": False
}, },
"metagpt_code_devlop": { "metagpt_code_devlop": {
"phase_name": "metagpt_code_devlop", "phase_name": "metagpt_code_devlop",
"phase_type": "BasePhase", "phase_type": "BasePhase",
"chains": ["metagptChain",], "chains": ["metagptChain",],
"do_summary": False,
"do_search": False,
"do_doc_retrieval": False,
"do_code_retrieval": False,
"do_tool_retrieval": False,
"do_using_tool": False
}, },
"baseGroupPhase": { "baseGroupPhase": {
"phase_name": "baseGroupPhase", "phase_name": "baseGroupPhase",
"phase_type": "BasePhase", "phase_type": "BasePhase",
"chains": ["baseGroupChain"], "chains": ["baseGroupChain"],
"do_summary": False,
"do_search": False,
"do_doc_retrieval": False,
"do_code_retrieval": False,
"do_tool_retrieval": False,
"do_using_tool": False
}, },
"code2DocsGroup": {
"phase_name": "code2DocsGroup",
"phase_type": "BasePhase",
"chains": ["code2DocsGroupChain"],
},
"code2Tests": {
"phase_name": "code2Tests",
"phase_type": "BasePhase",
"chains": ["code2TestsChain"],
}
} }
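
The two phases added here can be driven like any other registered phase. A sketch, assuming `BasePhase` resolves the phase name against `PHASE_CONFIGS` and that `embed_config`/`llm_config` already exist (message fields abbreviated):

```python
from coagent.connector.phase import BasePhase
from coagent.connector.schema import Message

phase = BasePhase("code2DocsGroup",
                  embed_config=embed_config, llm_config=llm_config)
query = Message(role_name="human", role_type="user",
                role_content="Write docs for class HelloWorld",
                input_query="Write docs for class HelloWorld",
                origin_query="Write docs for class HelloWorld")
output_message, output_memory = phase.step(query)
```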

View File

@ -41,3 +41,40 @@ SELECTOR_PROMPT_CONFIGS = [
{"field_name": 'output_format', "function_name": 'handle_output_format', 'title': 'Response Output Format', "is_context": False}, {"field_name": 'output_format', "function_name": 'handle_output_format', 'title': 'Response Output Format', "is_context": False},
{"field_name": 'begin!!!', "function_name": 'handle_response', "is_context": False, "omit_if_empty": False} {"field_name": 'begin!!!', "function_name": 'handle_response', "is_context": False, "omit_if_empty": False}
] ]
CODE2DOC_GROUP_PROMPT_CONFIGS = [
{"field_name": 'agent_profile', "function_name": 'handle_agent_profile', "is_context": False},
{"field_name": 'agent_infomation', "function_name": 'handle_agent_data', "is_context": False, "omit_if_empty": False},
# {"field_name": 'tool_information',"function_name": 'handle_tool_data', "is_context": False},
{"field_name": 'context_placeholder', "function_name": '', "is_context": True},
# {"field_name": 'reference_documents', "function_name": 'handle_doc_info'},
{"field_name": 'session_records', "function_name": 'handle_session_records'},
{"field_name": 'Specific Objective', "function_name": 'handle_specific_objective'},
{"field_name": 'Code Snippet', "function_name": 'handle_code_snippet'},
{"field_name": 'output_format', "function_name": 'handle_output_format', 'title': 'Response Output Format', "is_context": False},
{"field_name": 'begin!!!', "function_name": 'handle_response', "is_context": False, "omit_if_empty": False}
]
CODE2DOC_PROMPT_CONFIGS = [
{"field_name": 'agent_profile', "function_name": 'handle_agent_profile', "is_context": False},
# {"field_name": 'tool_information',"function_name": 'handle_tool_data', "is_context": False},
{"field_name": 'context_placeholder', "function_name": '', "is_context": True},
# {"field_name": 'reference_documents', "function_name": 'handle_doc_info'},
{"field_name": 'session_records', "function_name": 'handle_session_records'},
{"field_name": 'Specific Objective', "function_name": 'handle_specific_objective'},
{"field_name": 'Code Snippet', "function_name": 'handle_code_snippet'},
{"field_name": 'output_format', "function_name": 'handle_output_format', 'title': 'Response Output Format', "is_context": False},
{"field_name": 'begin!!!', "function_name": 'handle_response', "is_context": False, "omit_if_empty": False}
]
CODE2TESTS_PROMPT_CONFIGS = [
{"field_name": 'agent_profile', "function_name": 'handle_agent_profile', "is_context": False},
{"field_name": 'context_placeholder', "function_name": '', "is_context": True},
{"field_name": 'session_records', "function_name": 'handle_session_records'},
{"field_name": 'code_snippet', "function_name": 'handle_code_snippet'},
{"field_name": 'retrieval_codes', "function_name": 'handle_retrieval_codes', "description": ""},
{"field_name": 'output_format', "function_name": 'handle_output_format', 'title': 'Response Output Format', "is_context": False},
{"field_name": 'begin!!!', "function_name": 'handle_response', "is_context": False, "omit_if_empty": False}
]

View File

@ -14,7 +14,8 @@ from .qa_template_prompt import QA_PROMPT, CODE_QA_PROMPT, QA_TEMPLATE_PROMPT, C
from .executor_template_prompt import EXECUTOR_TEMPLATE_PROMPT from .executor_template_prompt import EXECUTOR_TEMPLATE_PROMPT
from .refine_template_prompt import REFINE_TEMPLATE_PROMPT from .refine_template_prompt import REFINE_TEMPLATE_PROMPT
from .code2doc_template_prompt import Code2DocGroup_PROMPT, Class2Doc_PROMPT, Func2Doc_PROMPT
from .code2test_template_prompt import code2Tests_PROMPT, judgeCode2Tests_PROMPT
from .agent_selector_template_prompt import SELECTOR_AGENT_TEMPLATE_PROMPT from .agent_selector_template_prompt import SELECTOR_AGENT_TEMPLATE_PROMPT
from .react_template_prompt import REACT_TEMPLATE_PROMPT from .react_template_prompt import REACT_TEMPLATE_PROMPT
@ -37,5 +38,7 @@ __all__ = [
"SELECTOR_AGENT_TEMPLATE_PROMPT", "SELECTOR_AGENT_TEMPLATE_PROMPT",
"PLANNER_TEMPLATE_PROMPT", "GENERAL_PLANNER_PROMPT", "DATA_PLANNER_PROMPT", "TOOL_PLANNER_PROMPT", "PLANNER_TEMPLATE_PROMPT", "GENERAL_PLANNER_PROMPT", "DATA_PLANNER_PROMPT", "TOOL_PLANNER_PROMPT",
"REACT_TEMPLATE_PROMPT", "REACT_TEMPLATE_PROMPT",
"REACT_CODE_PROMPT", "REACT_TOOL_PROMPT", "REACT_TOOL_AND_CODE_PROMPT", "REACT_TOOL_AND_CODE_PLANNER_PROMPT" "REACT_CODE_PROMPT", "REACT_TOOL_PROMPT", "REACT_TOOL_AND_CODE_PROMPT", "REACT_TOOL_AND_CODE_PLANNER_PROMPT",
"Code2DocGroup_PROMPT", "Class2Doc_PROMPT", "Func2Doc_PROMPT",
"code2Tests_PROMPT", "judgeCode2Tests_PROMPT"
] ]

View File

@ -0,0 +1,95 @@
Code2DocGroup_PROMPT = """#### Agent Profile
Your goal is to respond with the role that will best facilitate a solution, according to the Context Data's information and taking into account all relevant context (Context) provided.
When you need to select the appropriate role for handling a user's query, carefully read the provided role names, role descriptions and tool list.
ATTENTION: respond carefully, following the "Response Output Format" exactly.
#### Input Format
#### Response Output Format
**Code Path:** Extract from the context the paths of the class/method/function that need to be addressed.
**Role:** Select the role from agent names
"""
Class2Doc_PROMPT = """#### Agent Profile
As an advanced code documentation generator, you are proficient in translating class definitions into comprehensive documentation with a focus on instantiation parameters.
Your specific task is to parse the given code snippet of a class and extract information about its instantiation parameters.
ATTENTION: respond carefully, following the "Response Output Format".
#### Input Format
**Code Snippet:** Provide the full class definition, including the constructor and any parameters it may require for instantiation.
#### Response Output Format
**Class Base:** Specify the base class or interface that the current class extends, if any.
**Class Description:** Offer a brief description of the class's purpose and functionality.
**Init Parameters:** List each parameter of the constructor. For each parameter, provide:
- `param`: The parameter name
- `param_description`: A concise explanation of the parameter's purpose.
- `param_type`: The data type of the parameter, if explicitly defined.
```json
[
{
"param": "parameter_name",
"param_description": "A brief description of what this parameter is used for.",
"param_type": "The data type of the parameter"
},
...
]
```
If the constructor takes no parameters, return
```json
[]
```
"""
Func2Doc_PROMPT = """#### Agent Profile
You are a high-level code documentation assistant, skilled at distilling function/method code into detailed and well-structured documentation.
ATTENTION: respond carefully, following the "Response Output Format".
#### Input Format
**Code Path:** Provide the code path of the function or method you wish to document.
This name will be used to identify and extract the relevant details from the code snippet provided.
**Code Snippet:** A segment of code that contains the function or method to be documented.
#### Response Output Format
**Class Description:** Offer a brief description of the function's or method's purpose and functionality.
**Parameters:** Extract the parameters of the specific function/method from the Code Snippet. For each parameter, provide:
- `param`: The parameter name
- `param_description`: A concise explanation of the parameter's purpose.
- `param_type`: The data type of the parameter, if explicitly defined.
```json
[
{
"param": "parameter_name",
"param_description": "A brief description of what this parameter is used for.",
"param_type": "The data type of the parameter"
},
...
]
```
If the function/method takes no parameters, return
```json
[]
```
**Return Value Description:** Describe what the function/method returns upon completion.
**Return Type:** Indicate the type of data the function/method returns (e.g., string, integer, object, void).
"""

View File

@ -0,0 +1,65 @@
judgeCode2Tests_PROMPT = """#### Agent Profile
When determining the necessity of writing test cases for a given code snippet,
it's essential to evaluate its interactions with dependent classes and methods (retrieved code snippets),
in addition to considering these critical factors:
1. Functionality: If it implements a concrete function or logic, test cases are typically necessary to verify its correctness.
2. Complexity: If the code is complex, especially if it contains multiple conditional statements, loops, exception handling, etc.,
it's more likely to harbor bugs, and thus test cases should be written.
If the code involves complex algorithms or logic, then writing test cases can help ensure the accuracy of the logic and prevent errors during future refactoring.
3. Criticality: If it's part of the critical path or affects core functionalities, then it needs to be tested.
Comprehensive test cases should be written for core business logic or key components of the system to ensure the correctness and stability of the functionality.
4. Dependencies: If the code has external dependencies, integration testing may be necessary, or mocking these dependencies during unit testing might be required.
5. User Input: If the code handles user input, especially from unregulated external sources, creating test cases to check input validation and handling is important.
6. Frequent Changes: For code that requires regular updates or modifications, having the appropriate test cases ensures that changes do not break existing functionalities.
#### Input Format
**Code Snippet:** the initial code or objective that the user wants to achieve
**Retrieval Code Snippets:** These are the associated code segments that the main Code Snippet depends on.
Examine these snippets to understand how they interact with the main snippet and to determine how they might affect the overall functionality.
#### Response Output Format
**Action Status:** Set to 'finished' or 'continued'.
If set to 'finished', the code snippet does not warrant the generation of a test case.
If set to 'continued', the code snippet necessitates the creation of a test case.
**REASON:** Justify the selection of 'finished' or 'continued', contemplating the decision through a step-by-step rationale.
"""
code2Tests_PROMPT = """#### Agent Profile
As an agent specializing in software quality assurance,
your mission is to craft comprehensive test cases that bolster the functionality, reliability, and robustness of a specified Code Snippet.
This task is to be carried out with a keen understanding of the snippet's interactions with its dependent classes and methods—collectively referred to as Retrieval Code Snippets.
Analyze the details given below to grasp the code's intended purpose, its inherent complexity, and the context within which it operates.
Your constructed test cases must thoroughly examine the various factors influencing the code's quality and performance.
ATTENTION: respond carefully, following the "Response Output Format" exactly.
Each test case should include:
1. A clear description of the test purpose.
2. The input values or conditions for the test.
3. The expected outcome or assertion for the test.
4. Appropriate tags (e.g., 'functional', 'integration', 'regression') that classify the type of test case.
5. The test code should include the necessary package declaration and imports.
#### Input Format
**Code Snippet:** the initial code or objective that the user wants to achieve
**Retrieval Code Snippets:** These are the interrelated pieces of code sourced from the codebase, which support or influence the primary Code Snippet.
#### Response Output Format
**SaveFileName:** Construct a local file name based on the Question and Context, such as
```java
package/class.java
```
**Test Code:** Generate the test code for the current Code Snippet.
```java
...
```
"""

View File

@ -1,5 +1,5 @@
from abc import abstractmethod, ABC from abc import abstractmethod, ABC
from typing import List from typing import List, Dict
import os, sys, copy, json import os, sys, copy, json
from jieba.analyse import extract_tags from jieba.analyse import extract_tags
from collections import Counter from collections import Counter
@ -10,12 +10,13 @@ from langchain.docstore.document import Document
from .schema import Memory, Message from .schema import Memory, Message
from coagent.service.service_factory import KBServiceFactory from coagent.service.service_factory import KBServiceFactory
from coagent.llm_models import getChatModel, getChatModelFromConfig from coagent.llm_models import getChatModelFromConfig
from coagent.llm_models.llm_config import EmbedConfig, LLMConfig from coagent.llm_models.llm_config import EmbedConfig, LLMConfig
from coagent.embeddings.utils import load_embeddings_from_path from coagent.embeddings.utils import load_embeddings_from_path
from coagent.utils.common_utils import save_to_json_file, read_json_file, addMinutesToTime from coagent.utils.common_utils import save_to_json_file, read_json_file, addMinutesToTime
from coagent.connector.configs.prompts import CONV_SUMMARY_PROMPT_SPEC from coagent.connector.configs.prompts import CONV_SUMMARY_PROMPT_SPEC
from coagent.orm import table_init from coagent.orm import table_init
from coagent.base_configs.env_config import KB_ROOT_PATH
# from configs.model_config import KB_ROOT_PATH, EMBEDDING_MODEL, EMBEDDING_DEVICE, SCORE_THRESHOLD # from configs.model_config import KB_ROOT_PATH, EMBEDDING_MODEL, EMBEDDING_DEVICE, SCORE_THRESHOLD
# from configs.model_config import embedding_model_dict # from configs.model_config import embedding_model_dict
@ -70,16 +71,22 @@ class BaseMemoryManager(ABC):
self.unique_name = unique_name self.unique_name = unique_name
self.memory_type = memory_type self.memory_type = memory_type
self.do_init = do_init self.do_init = do_init
self.current_memory = Memory(messages=[]) # self.current_memory = Memory(messages=[])
self.recall_memory = Memory(messages=[]) # self.recall_memory = Memory(messages=[])
self.summary_memory = Memory(messages=[]) # self.summary_memory = Memory(messages=[])
self.current_memory_dict: Dict[str, Memory] = {}
self.recall_memory_dict: Dict[str, Memory] = {}
self.summary_memory_dict: Dict[str, Memory] = {}
self.save_message_keys = [ self.save_message_keys = [
'chat_index', 'role_name', 'role_type', 'role_prompt', 'input_query', 'origin_query', 'chat_index', 'role_name', 'role_type', 'role_prompt', 'input_query', 'origin_query',
'datetime', 'role_content', 'step_content', 'parsed_output', 'spec_parsed_output', 'parsed_output_list', 'datetime', 'role_content', 'step_content', 'parsed_output', 'spec_parsed_output', 'parsed_output_list',
'task', 'db_docs', 'code_docs', 'search_docs', 'phase_name', 'chain_name', 'customed_kargs'] 'task', 'db_docs', 'code_docs', 'search_docs', 'phase_name', 'chain_name', 'customed_kargs']
self.init_vb() self.init_vb()
def init_vb(self): def re_init(self, do_init: bool=False):
self.init_vb()
def init_vb(self, do_init: bool=None):
""" """
Initializes the vb. Initializes the vb.
""" """
@ -135,13 +142,15 @@ class BaseMemoryManager(ABC):
""" """
pass pass
def save_to_vs(self, embed_model="", embed_device=""): def save_to_vs(self, ):
""" """
Saves the memory to the vector space. Saves the memory to the vector space.
"""
pass
Args: def get_memory_pool(self, user_name: str, ):
- embed_model: A string representing the embedding model. Default is EMBEDDING_MODEL. """
- embed_device: A string representing the embedding device. Default is EMBEDDING_DEVICE. return memory_pool
""" """
pass pass
@ -230,7 +239,7 @@ class LocalMemoryManager(BaseMemoryManager):
unique_name: str = "default", unique_name: str = "default",
memory_type: str = "recall", memory_type: str = "recall",
do_init: bool = False, do_init: bool = False,
kb_root_path: str = "", kb_root_path: str = KB_ROOT_PATH,
): ):
self.user_name = user_name self.user_name = user_name
self.unique_name = unique_name self.unique_name = unique_name
@ -239,16 +248,22 @@ class LocalMemoryManager(BaseMemoryManager):
self.kb_root_path = kb_root_path self.kb_root_path = kb_root_path
self.embed_config: EmbedConfig = embed_config self.embed_config: EmbedConfig = embed_config
self.llm_config: LLMConfig = llm_config self.llm_config: LLMConfig = llm_config
self.current_memory = Memory(messages=[]) # self.current_memory = Memory(messages=[])
self.recall_memory = Memory(messages=[]) # self.recall_memory = Memory(messages=[])
self.summary_memory = Memory(messages=[]) # self.summary_memory = Memory(messages=[])
self.current_memory_dict: Dict[str, Memory] = {}
self.recall_memory_dict: Dict[str, Memory] = {}
self.summary_memory_dict: Dict[str, Memory] = {}
self.save_message_keys = [ self.save_message_keys = [
'chat_index', 'role_name', 'role_type', 'role_prompt', 'input_query', 'origin_query', 'chat_index', 'role_name', 'role_type', 'role_prompt', 'input_query', 'origin_query',
'datetime', 'role_content', 'step_content', 'parsed_output', 'spec_parsed_output', 'parsed_output_list', 'datetime', 'role_content', 'step_content', 'parsed_output', 'spec_parsed_output', 'parsed_output_list',
'task', 'db_docs', 'code_docs', 'search_docs', 'phase_name', 'chain_name', 'customed_kargs'] 'task', 'db_docs', 'code_docs', 'search_docs', 'phase_name', 'chain_name', 'customed_kargs']
self.init_vb() self.init_vb()
def init_vb(self): def re_init(self, do_init: bool=False):
self.init_vb(do_init)
def init_vb(self, do_init: bool=None):
vb_name = f"{self.user_name}/{self.unique_name}/{self.memory_type}" vb_name = f"{self.user_name}/{self.unique_name}/{self.memory_type}"
# default to recreate a new vb # default to recreate a new vb
table_init() table_init()
@ -256,31 +271,37 @@ class LocalMemoryManager(BaseMemoryManager):
if vb: if vb:
status = vb.clear_vs() status = vb.clear_vs()
if not self.do_init: check_do_init = do_init if do_init else self.do_init
if not check_do_init:
self.load(self.kb_root_path) self.load(self.kb_root_path)
self.save_to_vs() self.save_to_vs()
def append(self, message: Message): def append(self, message: Message):
self.recall_memory.append(message) self.check_user_name(message.user_name)
uuid_name = "_".join([self.user_name, self.unique_name, self.memory_type])
self.recall_memory_dict[uuid_name].append(message)
# #
if message.role_type == "summary": if message.role_type == "summary":
self.summary_memory.append(message) self.summary_memory_dict[uuid_name].append(message)
else: else:
self.current_memory.append(message) self.current_memory_dict[uuid_name].append(message)
self.save(self.kb_root_path) self.save(self.kb_root_path)
self.save_new_to_vs([message]) self.save_new_to_vs([message])
def extend(self, memory: Memory): # def extend(self, memory: Memory):
self.recall_memory.extend(memory) # self.recall_memory.extend(memory)
self.current_memory.extend(self.recall_memory.filter_by_role_type(["summary"])) # self.current_memory.extend(self.recall_memory.filter_by_role_type(["summary"]))
self.summary_memory.extend(self.recall_memory.select_by_role_type(["summary"])) # self.summary_memory.extend(self.recall_memory.select_by_role_type(["summary"]))
self.save(self.kb_root_path) # self.save(self.kb_root_path)
self.save_new_to_vs(memory.messages) # self.save_new_to_vs(memory.messages)
def save(self, save_dir: str = "./"): def save(self, save_dir: str = "./"):
file_path = os.path.join(save_dir, f"{self.user_name}/{self.unique_name}/{self.memory_type}/converation.jsonl") file_path = os.path.join(save_dir, f"{self.user_name}/{self.unique_name}/{self.memory_type}/converation.jsonl")
memory_messages = self.recall_memory.dict() uuid_name = "_".join([self.user_name, self.unique_name, self.memory_type])
memory_messages = self.recall_memory_dict[uuid_name].dict()
memory_messages = {k: [ memory_messages = {k: [
{kkk: vvv for kkk, vvv in vv.items() if kkk in self.save_message_keys} {kkk: vvv for kkk, vvv in vv.items() if kkk in self.save_message_keys}
for vv in v ] for vv in v ]
@ -291,18 +312,28 @@ class LocalMemoryManager(BaseMemoryManager):
def load(self, load_dir: str = "./") -> Memory: def load(self, load_dir: str = "./") -> Memory:
file_path = os.path.join(load_dir, f"{self.user_name}/{self.unique_name}/{self.memory_type}/converation.jsonl") file_path = os.path.join(load_dir, f"{self.user_name}/{self.unique_name}/{self.memory_type}/converation.jsonl")
uuid_name = "_".join([self.user_name, self.unique_name, self.memory_type])
if os.path.exists(file_path): if os.path.exists(file_path):
self.recall_memory = Memory(**read_json_file(file_path)) # self.recall_memory = Memory(**read_json_file(file_path))
self.current_memory = Memory(messages=self.recall_memory.filter_by_role_type(["summary"])) # self.current_memory = Memory(messages=self.recall_memory.filter_by_role_type(["summary"]))
self.summary_memory = Memory(messages=self.recall_memory.select_by_role_type(["summary"])) # self.summary_memory = Memory(messages=self.recall_memory.select_by_role_type(["summary"]))
recall_memory = Memory(**read_json_file(file_path))
self.recall_memory_dict[uuid_name] = recall_memory
self.current_memory_dict[uuid_name] = Memory(messages=recall_memory.filter_by_role_type(["summary"]))
self.summary_memory_dict[uuid_name] = Memory(messages=recall_memory.select_by_role_type(["summary"]))
else:
self.recall_memory_dict[uuid_name] = Memory(messages=[])
self.current_memory_dict[uuid_name] = Memory(messages=[])
self.summary_memory_dict[uuid_name] = Memory(messages=[])
def save_new_to_vs(self, messages: List[Message]): def save_new_to_vs(self, messages: List[Message]):
if self.embed_config: if self.embed_config:
vb_name = f"{self.user_name}/{self.unique_name}/{self.memory_type}" vb_name = f"{self.user_name}/{self.unique_name}/{self.memory_type}"
# default to faiss, todo: add new vstype # default to faiss, todo: add new vstype
vb = KBServiceFactory.get_service(vb_name, "faiss", self.embed_config, self.kb_root_path) vb = KBServiceFactory.get_service(vb_name, "faiss", self.embed_config, self.kb_root_path)
embeddings = load_embeddings_from_path(self.embed_config.embed_model_path, self.embed_config.model_device,) embeddings = load_embeddings_from_path(self.embed_config.embed_model_path, self.embed_config.model_device, self.embed_config.langchain_embeddings)
messages = [ messages = [
{k: v for k, v in m.dict().items() if k in self.save_message_keys} {k: v for k, v in m.dict().items() if k in self.save_message_keys}
for m in messages] for m in messages]
@ -311,23 +342,26 @@ class LocalMemoryManager(BaseMemoryManager):
vb.do_add_doc(docs, embeddings) vb.do_add_doc(docs, embeddings)
def save_to_vs(self): def save_to_vs(self):
vb_name = f"{self.user_name}/{self.unique_name}/{self.memory_type}" '''only after load'''
# default to recreate a new vb if self.embed_config:
vb = KBServiceFactory.get_service_by_name(vb_name, self.embed_config, self.kb_root_path) vb_name = f"{self.user_name}/{self.unique_name}/{self.memory_type}"
if vb: uuid_name = "_".join([self.user_name, self.unique_name, self.memory_type])
status = vb.clear_vs() # default to recreate a new vb
# create_kb(vb_name, "faiss", embed_model) vb = KBServiceFactory.get_service_by_name(vb_name, self.embed_config, self.kb_root_path)
if vb:
status = vb.clear_vs()
# create_kb(vb_name, "faiss", embed_model)
# default to faiss, todo: add new vstype # default to faiss, todo: add new vstype
vb = KBServiceFactory.get_service(vb_name, "faiss", self.embed_config, self.kb_root_path) vb = KBServiceFactory.get_service(vb_name, "faiss", self.embed_config, self.kb_root_path)
embeddings = load_embeddings_from_path(self.embed_config.embed_model_path, self.embed_config.model_device,) embeddings = load_embeddings_from_path(self.embed_config.embed_model_path, self.embed_config.model_device, self.embed_config.langchain_embeddings)
messages = self.recall_memory.dict() messages = self.recall_memory_dict[uuid_name].dict()
messages = [ messages = [
{kkk: vvv for kkk, vvv in vv.items() if kkk in self.save_message_keys} {kkk: vvv for kkk, vvv in vv.items() if kkk in self.save_message_keys}
for k, v in messages.items() for vv in v] for k, v in messages.items() for vv in v]
docs = [{"page_content": m["step_content"] or m["role_content"] or m["input_query"] or m["origin_query"], "metadata": m} for m in messages] docs = [{"page_content": m["step_content"] or m["role_content"] or m["input_query"] or m["origin_query"], "metadata": m} for m in messages]
docs = [Document(**doc) for doc in docs] docs = [Document(**doc) for doc in docs]
vb.do_add_doc(docs, embeddings) vb.do_add_doc(docs, embeddings)
# def load_from_vs(self, embed_model=EMBEDDING_MODEL) -> Memory: # def load_from_vs(self, embed_model=EMBEDDING_MODEL) -> Memory:
# vb_name = f"{self.user_name}/{self.unique_name}/{self.memory_type}" # vb_name = f"{self.user_name}/{self.unique_name}/{self.memory_type}"
@ -338,7 +372,12 @@ class LocalMemoryManager(BaseMemoryManager):
# docs = vb.get_all_documents() # docs = vb.get_all_documents()
# print(docs) # print(docs)
def router_retrieval(self, text: str=None, datetime: str = None, n=5, top_k=5, retrieval_type: str = "embedding", **kwargs) -> List[Message]: def get_memory_pool(self, user_name: str, ):
self.check_user_name(user_name)
uuid_name = "_".join([self.user_name, self.unique_name, self.memory_type])
return self.recall_memory_dict[uuid_name]
def router_retrieval(self, user_name: str = "default", text: str=None, datetime: str = None, n=5, top_k=5, retrieval_type: str = "embedding", **kwargs) -> List[Message]:
retrieval_func_dict = { retrieval_func_dict = {
"embedding": self.embedding_retrieval, "text": self.text_retrieval, "datetime": self.datetime_retrieval "embedding": self.embedding_retrieval, "text": self.text_retrieval, "datetime": self.datetime_retrieval
} }
@ -356,20 +395,22 @@ class LocalMemoryManager(BaseMemoryManager):
# #
return retrieval_func(**params) return retrieval_func(**params)
def embedding_retrieval(self, text: str, top_k=1, score_threshold=1.0, **kwargs) -> List[Message]: def embedding_retrieval(self, text: str, top_k=1, score_threshold=1.0, user_name: str = "default", **kwargs) -> List[Message]:
if text is None: return [] if text is None: return []
vb_name = f"{self.user_name}/{self.unique_name}/{self.memory_type}" vb_name = f"{user_name}/{self.unique_name}/{self.memory_type}"
vb = KBServiceFactory.get_service(vb_name, "faiss", self.embed_config, self.kb_root_path) vb = KBServiceFactory.get_service(vb_name, "faiss", self.embed_config, self.kb_root_path)
docs = vb.search_docs(text, top_k=top_k, score_threshold=score_threshold) docs = vb.search_docs(text, top_k=top_k, score_threshold=score_threshold)
return [Message(**doc.metadata) for doc, score in docs] return [Message(**doc.metadata) for doc, score in docs]
def text_retrieval(self, text: str, **kwargs) -> List[Message]: def text_retrieval(self, text: str, user_name: str = "default", **kwargs) -> List[Message]:
if text is None: return [] if text is None: return []
return self._text_retrieval_from_cache(self.recall_memory.messages, text, score_threshold=0.3, topK=5, **kwargs) uuid_name = "_".join([user_name, self.unique_name, self.memory_type])
return self._text_retrieval_from_cache(self.recall_memory_dict[uuid_name].messages, text, score_threshold=0.3, topK=5, **kwargs)
def datetime_retrieval(self, datetime: str, text: str = None, n: int = 5, **kwargs) -> List[Message]: def datetime_retrieval(self, datetime: str, text: str = None, n: int = 5, user_name: str = "default", **kwargs) -> List[Message]:
if datetime is None: return [] if datetime is None: return []
return self._datetime_retrieval_from_cache(self.recall_memory.messages, datetime, text, n, **kwargs) uuid_name = "_".join([user_name, self.unique_name, self.memory_type])
return self._datetime_retrieval_from_cache(self.recall_memory_dict[uuid_name].messages, datetime, text, n, **kwargs)
def _text_retrieval_from_cache(self, messages: List[Message], text: str = None, score_threshold=0.3, topK=5, tag_topK=5, **kwargs) -> List[Message]: def _text_retrieval_from_cache(self, messages: List[Message], text: str = None, score_threshold=0.3, topK=5, tag_topK=5, **kwargs) -> List[Message]:
keywords = extract_tags(text, topK=tag_topK) keywords = extract_tags(text, topK=tag_topK)
@ -428,3 +469,17 @@ class LocalMemoryManager(BaseMemoryManager):
summary_message.parsed_output_list.append({"summary": content}) summary_message.parsed_output_list.append({"summary": content})
newest_messages.insert(0, summary_message) newest_messages.insert(0, summary_message)
return newest_messages return newest_messages
def check_user_name(self, user_name: str):
# logger.debug(f"self.user_name is {self.user_name}")
if user_name != self.user_name:
self.user_name = user_name
self.init_vb()
uuid_name = "_".join([self.user_name, self.unique_name, self.memory_type])
if uuid_name not in self.recall_memory_dict:
self.recall_memory_dict[uuid_name] = Memory(messages=[])
self.current_memory_dict[uuid_name] = Memory(messages=[])
self.summary_memory_dict[uuid_name] = Memory(messages=[])
# logger.debug(f"self.user_name is {self.user_name}")

View File

@ -1,16 +1,19 @@
import re, traceback, uuid, copy, json, os import re, traceback, uuid, copy, json, os
from typing import Union
from loguru import logger from loguru import logger
from langchain.schema import BaseRetriever
# from configs.server_config import SANDBOX_SERVER
# from configs.model_config import JUPYTER_WORK_PATH
from coagent.connector.schema import ( from coagent.connector.schema import (
Memory, Role, Message, ActionStatus, CodeDoc, Doc, LogVerboseEnum Memory, Role, Message, ActionStatus, CodeDoc, Doc, LogVerboseEnum
) )
from coagent.retrieval.base_retrieval import IMRertrieval
from coagent.connector.memory_manager import BaseMemoryManager from coagent.connector.memory_manager import BaseMemoryManager
from coagent.tools import DDGSTool, DocRetrieval, CodeRetrieval from coagent.tools import DDGSTool, DocRetrieval, CodeRetrieval
from coagent.sandbox import PyCodeBox, CodeBoxResponse from coagent.sandbox import PyCodeBox, CodeBoxResponse
from coagent.llm_models.llm_config import LLMConfig, EmbedConfig from coagent.llm_models.llm_config import LLMConfig, EmbedConfig
from coagent.base_configs.env_config import JUPYTER_WORK_PATH
from .utils import parse_dict_to_dict, parse_text_to_dict from .utils import parse_dict_to_dict, parse_text_to_dict
@ -19,10 +22,13 @@ class MessageUtils:
self, self,
role: Role = None, role: Role = None,
sandbox_server: dict = {}, sandbox_server: dict = {},
jupyter_work_path: str = "./", jupyter_work_path: str = JUPYTER_WORK_PATH,
embed_config: EmbedConfig = None, embed_config: EmbedConfig = None,
llm_config: LLMConfig = None, llm_config: LLMConfig = None,
kb_root_path: str = "", kb_root_path: str = "",
doc_retrieval: Union[BaseRetriever, IMRertrieval] = None,
code_retrieval: IMRertrieval = None,
search_retrieval: IMRertrieval = None,
log_verbose: str = "0" log_verbose: str = "0"
) -> None: ) -> None:
self.role = role self.role = role
@ -31,6 +37,9 @@ class MessageUtils:
self.embed_config = embed_config self.embed_config = embed_config
self.llm_config = llm_config self.llm_config = llm_config
self.kb_root_path = kb_root_path self.kb_root_path = kb_root_path
self.doc_retrieval = doc_retrieval
self.code_retrieval = code_retrieval
self.search_retrieval = search_retrieval
self.codebox = PyCodeBox( self.codebox = PyCodeBox(
remote_url=self.sandbox_server.get("url", "http://127.0.0.1:5050"), remote_url=self.sandbox_server.get("url", "http://127.0.0.1:5050"),
remote_ip=self.sandbox_server.get("host", "http://127.0.0.1"), remote_ip=self.sandbox_server.get("host", "http://127.0.0.1"),
@ -44,6 +53,7 @@ class MessageUtils:
self.log_verbose = os.environ.get("log_verbose", "0") or log_verbose self.log_verbose = os.environ.get("log_verbose", "0") or log_verbose
def inherit_extrainfo(self, input_message: Message, output_message: Message): def inherit_extrainfo(self, input_message: Message, output_message: Message):
output_message.user_name = input_message.user_name
output_message.db_docs = input_message.db_docs output_message.db_docs = input_message.db_docs
output_message.search_docs = input_message.search_docs output_message.search_docs = input_message.search_docs
output_message.code_docs = input_message.code_docs output_message.code_docs = input_message.code_docs
@ -116,18 +126,45 @@ class MessageUtils:
knowledge_basename = message.doc_engine_name knowledge_basename = message.doc_engine_name
top_k = message.top_k top_k = message.top_k
score_threshold = message.score_threshold score_threshold = message.score_threshold
if knowledge_basename: if self.doc_retrieval:
docs = DocRetrieval.run(query, knowledge_basename, top_k, score_threshold, self.embed_config, self.kb_root_path) if isinstance(self.doc_retrieval, BaseRetriever):
docs = self.doc_retrieval.get_relevant_documents(query)
else:
# docs = self.doc_retrieval.run(query, search_top=message.top_k, score_threshold=message.score_threshold,)
docs = self.doc_retrieval.run(query)
docs = [
{"index": idx, "snippet": doc.page_content, "title": doc.metadata.get("title_prefix", ""), "link": doc.metadata.get("url", "")}
for idx, doc in enumerate(docs)
]
message.db_docs = [Doc(**doc) for doc in docs] message.db_docs = [Doc(**doc) for doc in docs]
else:
if knowledge_basename:
docs = DocRetrieval.run(query, knowledge_basename, top_k, score_threshold, self.embed_config, self.kb_root_path)
message.db_docs = [Doc(**doc) for doc in docs]
return message return message
def get_code_retrieval(self, message: Message) -> Message: def get_code_retrieval(self, message: Message) -> Message:
query = message.input_query query = message.role_content
code_engine_name = message.code_engine_name code_engine_name = message.code_engine_name
history_node_list = message.history_node_list history_node_list = message.history_node_list
code_docs = CodeRetrieval.run(code_engine_name, query, code_limit=message.top_k, history_node_list=history_node_list, search_type=message.cb_search_type,
llm_config=self.llm_config, embed_config=self.embed_config,) use_nh = message.use_nh
local_graph_path = message.local_graph_path
if self.code_retrieval:
code_docs = self.code_retrieval.run(
query, history_node_list=history_node_list, search_type=message.cb_search_type,
code_limit=1
)
else:
code_docs = CodeRetrieval.run(code_engine_name, query, code_limit=message.top_k, history_node_list=history_node_list, search_type=message.cb_search_type,
llm_config=self.llm_config, embed_config=self.embed_config,
use_nh=use_nh, local_graph_path=local_graph_path)
message.code_docs = [CodeDoc(**doc) for doc in code_docs] message.code_docs = [CodeDoc(**doc) for doc in code_docs]
# related_nodes = [doc.get_related_node() for idx, doc in enumerate(message.code_docs) if idx==0],
# history_node_list.extend([node[0] for node in related_nodes])
return message return message
def get_tool_retrieval(self, message: Message) -> Message: def get_tool_retrieval(self, message: Message) -> Message:
@@ -160,6 +197,7 @@ class MessageUtils:
if code_answer.code_exe_type == "error" else f"The return information after executing the above code is {code_answer.code_exe_response}.\n" if code_answer.code_exe_type == "error" else f"The return information after executing the above code is {code_answer.code_exe_response}.\n"
observation_message = Message( observation_message = Message(
user_name=message.user_name,
role_name="observation", role_name="observation",
role_type="function", #self.role.role_type, role_type="function", #self.role.role_type,
role_content="", role_content="",
@@ -190,6 +228,7 @@ class MessageUtils:
def tool_step(self, message: Message) -> Message: def tool_step(self, message: Message) -> Message:
'''execute tool''' '''execute tool'''
observation_message = Message( observation_message = Message(
user_name=message.user_name,
role_name="observation", role_name="observation",
role_type="function", #self.role.role_type, role_type="function", #self.role.role_type,
role_content="\n**Observation:** there is no tool can execute\n", role_content="\n**Observation:** there is no tool can execute\n",
@@ -226,7 +265,7 @@ class MessageUtils:
return message, observation_message return message, observation_message
def parser(self, message: Message) -> Message: def parser(self, message: Message) -> Message:
'''''' '''parse llm output into dict'''
content = message.role_content content = message.role_content
# parse start # parse start
parsed_dict = parse_text_to_dict(content) parsed_dict = parse_text_to_dict(content)
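
The new doc_retrieval hook accepts any LangChain BaseRetriever. A minimal sketch (illustrative only; note that constructing MessageUtils also boots its sandbox codebox):

from langchain.schema import BaseRetriever, Document
from coagent.connector.schema import Message

class StaticRetriever(BaseRetriever):
    """Toy retriever: always returns one fixed document."""
    def _get_relevant_documents(self, query, *, run_manager=None):
        return [Document(page_content="coagent now supports antflow",
                         metadata={"title_prefix": "demo", "url": ""})]

utils = MessageUtils(doc_retrieval=StaticRetriever())
msg = Message(role_name="human", role_type="user", role_content="what is antflow?")
msg = utils.get_doc_retrieval(msg)  # msg.db_docs is filled from the retriever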

View File

@@ -5,6 +5,8 @@ import importlib
import copy import copy
from loguru import logger from loguru import logger
from langchain.schema import BaseRetriever
from coagent.connector.agents import BaseAgent from coagent.connector.agents import BaseAgent
from coagent.connector.chains import BaseChain from coagent.connector.chains import BaseChain
from coagent.connector.schema import ( from coagent.connector.schema import (
@@ -18,9 +20,6 @@ from coagent.connector.message_process import MessageUtils
from coagent.llm_models.llm_config import EmbedConfig, LLMConfig from coagent.llm_models.llm_config import EmbedConfig, LLMConfig
from coagent.base_configs.env_config import JUPYTER_WORK_PATH, KB_ROOT_PATH from coagent.base_configs.env_config import JUPYTER_WORK_PATH, KB_ROOT_PATH
# from configs.model_config import JUPYTER_WORK_PATH, KB_ROOT_PATH
# from configs.server_config import SANDBOX_SERVER
role_configs = load_role_configs(AGETN_CONFIGS) role_configs = load_role_configs(AGETN_CONFIGS)
chain_configs = load_chain_configs(CHAIN_CONFIGS) chain_configs = load_chain_configs(CHAIN_CONFIGS)
@@ -39,20 +38,24 @@ class BasePhase:
kb_root_path: str = KB_ROOT_PATH, kb_root_path: str = KB_ROOT_PATH,
jupyter_work_path: str = JUPYTER_WORK_PATH, jupyter_work_path: str = JUPYTER_WORK_PATH,
sandbox_server: dict = {}, sandbox_server: dict = {},
embed_config: EmbedConfig = EmbedConfig(), embed_config: EmbedConfig = None,
llm_config: LLMConfig = LLMConfig(), llm_config: LLMConfig = None,
task: Task = None, task: Task = None,
base_phase_config: Union[dict, str] = PHASE_CONFIGS, base_phase_config: Union[dict, str] = PHASE_CONFIGS,
base_chain_config: Union[dict, str] = CHAIN_CONFIGS, base_chain_config: Union[dict, str] = CHAIN_CONFIGS,
base_role_config: Union[dict, str] = AGETN_CONFIGS, base_role_config: Union[dict, str] = AGETN_CONFIGS,
chains: List[BaseChain] = [],
doc_retrieval: BaseRetriever = None,
code_retrieval = None,
search_retrieval = None,
log_verbose: str = "0" log_verbose: str = "0"
) -> None: ) -> None:
# #
self.phase_name = phase_name self.phase_name = phase_name
self.do_summary = False self.do_summary = False
self.do_search = False self.do_search = search_retrieval is not None
self.do_code_retrieval = False self.do_code_retrieval = code_retrieval is not None
self.do_doc_retrieval = False self.do_doc_retrieval = doc_retrieval is not None
self.do_tool_retrieval = False self.do_tool_retrieval = False
# memory_pool dont have specific order # memory_pool dont have specific order
# self.memory_pool = Memory(messages=[]) # self.memory_pool = Memory(messages=[])
@@ -62,12 +65,15 @@ class BasePhase:
self.jupyter_work_path = jupyter_work_path self.jupyter_work_path = jupyter_work_path
self.kb_root_path = kb_root_path self.kb_root_path = kb_root_path
self.log_verbose = max(os.environ.get("log_verbose", "0"), log_verbose) self.log_verbose = max(os.environ.get("log_verbose", "0"), log_verbose)
# TODO: pass these through
self.message_utils = MessageUtils(None, sandbox_server, jupyter_work_path, embed_config, llm_config, kb_root_path, log_verbose) self.doc_retrieval = doc_retrieval
self.code_retrieval = code_retrieval
self.search_retrieval = search_retrieval
self.message_utils = MessageUtils(None, sandbox_server, jupyter_work_path, embed_config, llm_config, kb_root_path, doc_retrieval, code_retrieval, search_retrieval, log_verbose)
self.global_memory = Memory(messages=[]) self.global_memory = Memory(messages=[])
self.phase_memory: List[Memory] = [] self.phase_memory: List[Memory] = []
# according phase name to init the phase contains # according phase name to init the phase contains
self.chains: List[BaseChain] = self.init_chains( self.chains: List[BaseChain] = chains if chains else self.init_chains(
phase_name, phase_name,
phase_config, phase_config,
task=task, task=task,
@@ -90,7 +96,9 @@ class BasePhase:
kb_root_path=kb_root_path kb_root_path=kb_root_path
) )
def astep(self, query: Message, history: Memory = None) -> Tuple[Message, Memory]: def astep(self, query: Message, history: Memory = None, reinit_memory=False) -> Tuple[Message, Memory]:
if reinit_memory:
self.memory_manager.re_init(reinit_memory)
self.memory_manager.append(query) self.memory_manager.append(query)
summary_message = None summary_message = None
chain_message = Memory(messages=[]) chain_message = Memory(messages=[])
@@ -139,8 +147,8 @@ class BasePhase:
message.role_name = self.phase_name message.role_name = self.phase_name
yield message, local_phase_memory yield message, local_phase_memory
def step(self, query: Message, history: Memory = None) -> Tuple[Message, Memory]: def step(self, query: Message, history: Memory = None, reinit_memory=False) -> Tuple[Message, Memory]:
for message, local_phase_memory in self.astep(query, history=history): for message, local_phase_memory in self.astep(query, history=history, reinit_memory=reinit_memory):
pass pass
return message, local_phase_memory return message, local_phase_memory
@@ -194,6 +202,9 @@ class BasePhase:
sandbox_server=self.sandbox_server, sandbox_server=self.sandbox_server,
jupyter_work_path=self.jupyter_work_path, jupyter_work_path=self.jupyter_work_path,
kb_root_path=self.kb_root_path, kb_root_path=self.kb_root_path,
doc_retrieval=self.doc_retrieval,
code_retrieval=self.code_retrieval,
search_retrieval=self.search_retrieval,
log_verbose=self.log_verbose log_verbose=self.log_verbose
) )
if agent_config.role.agent_type == "SelectorAgent": if agent_config.role.agent_type == "SelectorAgent":
@@ -205,7 +216,7 @@ class BasePhase:
group_base_agent = baseAgent( group_base_agent = baseAgent(
role=group_agent_config.role, role=group_agent_config.role,
prompt_config = group_agent_config.prompt_config, prompt_config = group_agent_config.prompt_config,
prompt_manager_type=agent_config.prompt_manager_type, prompt_manager_type=group_agent_config.prompt_manager_type,
task = task, task = task,
memory = memory, memory = memory,
chat_turn=group_agent_config.chat_turn, chat_turn=group_agent_config.chat_turn,
@@ -216,6 +227,9 @@ class BasePhase:
sandbox_server=self.sandbox_server, sandbox_server=self.sandbox_server,
jupyter_work_path=self.jupyter_work_path, jupyter_work_path=self.jupyter_work_path,
kb_root_path=self.kb_root_path, kb_root_path=self.kb_root_path,
doc_retrieval=self.doc_retrieval,
code_retrieval=self.code_retrieval,
search_retrieval=self.search_retrieval,
log_verbose=self.log_verbose log_verbose=self.log_verbose
) )
base_agent.group_agents.append(group_base_agent) base_agent.group_agents.append(group_base_agent)
@@ -223,13 +237,16 @@ class BasePhase:
agents.append(base_agent) agents.append(base_agent)
chain_instance = BaseChain( chain_instance = BaseChain(
agents, chain_config.chat_turn, chain_config,
do_checker=chain_configs[chain_name].do_checker, agents,
jupyter_work_path=self.jupyter_work_path, jupyter_work_path=self.jupyter_work_path,
sandbox_server=self.sandbox_server, sandbox_server=self.sandbox_server,
embed_config=self.embed_config, embed_config=self.embed_config,
llm_config=self.llm_config, llm_config=self.llm_config,
kb_root_path=self.kb_root_path, kb_root_path=self.kb_root_path,
doc_retrieval=self.doc_retrieval,
code_retrieval=self.code_retrieval,
search_retrieval=self.search_retrieval,
log_verbose=self.log_verbose log_verbose=self.log_verbose
) )
chains.append(chain_instance) chains.append(chain_instance)
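
A sketch of wiring retrieval objects into a phase; the phase, knowledge-base, and model names are illustrative, and the LLMConfig/EmbedConfig arguments beyond this diff are assumptions:

from coagent.connector.schema import Message
from coagent.retrieval.base_retrieval import BaseDocRetrieval

llm_config = LLMConfig(model_name="gpt-3.5-turbo", api_key="xx", api_base_url="https://api.openai.com/v1")
embed_config = EmbedConfig(embed_engine="model", embed_model="text2vec-base-chinese",
                           embed_model_path="/path/to/text2vec-base-chinese")
doc_retrieval = BaseDocRetrieval(knowledge_base_name="my_kb", embed_config=embed_config)

phase = BasePhase("baseGroupPhase", embed_config=embed_config, llm_config=llm_config,
                  doc_retrieval=doc_retrieval)
query = Message(role_name="human", role_type="user",
                role_content="how do I use antflow?", input_query="how do I use antflow?")
output_message, phase_memory = phase.step(query, reinit_memory=True)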

View File

@@ -0,0 +1,2 @@
from .prompt_manager import PromptManager
from .extend_manager import *

View File

@@ -0,0 +1,45 @@
from coagent.connector.schema import Message
from .prompt_manager import PromptManager
class Code2DocPM(PromptManager):
def handle_code_snippet(self, **kwargs) -> str:
if 'previous_agent_message' not in kwargs:
return ""
previous_agent_message: Message = kwargs['previous_agent_message']
code_snippet = previous_agent_message.customed_kargs.get("Code Snippet", "")
current_vertex = previous_agent_message.customed_kargs.get("Current_Vertex", "")
instruction = "A segment of code that contains the function or method to be documented.\n"
return instruction + "\n" + f"name: {current_vertex}\n{code_snippet}"
def handle_specific_objective(self, **kwargs) -> str:
if 'previous_agent_message' not in kwargs:
return ""
previous_agent_message: Message = kwargs['previous_agent_message']
specific_objective = previous_agent_message.parsed_output.get("Code Path")
instruction = "Provide the code path of the function or method you wish to document.\n"
s = instruction + f"\n{specific_objective}"
return s
class CodeRetrievalPM(PromptManager):
def handle_code_snippet(self, **kwargs) -> str:
if 'previous_agent_message' not in kwargs:
return ""
previous_agent_message: Message = kwargs['previous_agent_message']
code_snippet = previous_agent_message.customed_kargs.get("Code Snippet", "")
current_vertex = previous_agent_message.customed_kargs.get("Current_Vertex", "")
instruction = "the initial Code or objective that the user wanted to achieve"
return instruction + "\n" + f"name: {current_vertex}\n{code_snippet}"
def handle_retrieval_codes(self, **kwargs) -> str:
if 'previous_agent_message' not in kwargs:
return ""
previous_agent_message: Message = kwargs['previous_agent_message']
Retrieval_Codes = previous_agent_message.customed_kargs["Retrieval_Codes"]
Relative_vertex = previous_agent_message.customed_kargs["Relative_vertex"]
instruction = "the initial Code or objective that the user wanted to achieve"
s = instruction + "\n" + "\n".join([f"name: {vertext}\n{code}" for vertext, code in zip(Relative_vertex, Retrieval_Codes)])
return s
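
Illustrative use of the extended managers: the handlers read from Message.customed_kargs, so a message only needs those keys populated (values below are made up):

from coagent.connector.schema import Message

msg = Message(role_name="user", role_type="human", role_content="")
msg.customed_kargs["Code Snippet"] = "def add(a, b):\n    return a + b"
msg.customed_kargs["Current_Vertex"] = "demo#add"

pm = Code2DocPM()
print(pm.handle_code_snippet(previous_agent_message=msg))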

View File

@@ -0,0 +1,353 @@
import random
from textwrap import dedent
import copy
from loguru import logger
from langchain.agents.tools import Tool
from coagent.connector.schema import Memory, Message
from coagent.connector.utils import extract_section, parse_section
class PromptManager:
def __init__(self, role_prompt="", prompt_config=None, monitored_agents=[], monitored_fields=[]):
self.role_prompt = role_prompt
self.monitored_agents = monitored_agents
self.monitored_fields = monitored_fields
self.field_handlers = {}
self.context_handlers = {}
self.field_order = [] # ordering of the regular fields
self.context_order = [] # ordering of the context fields, maintained separately
self.field_descriptions = {}
self.omit_if_empty_flags = {}
self.context_title = "### Context Data\n\n"
self.prompt_config = prompt_config
if self.prompt_config:
self.register_fields_from_config()
def register_field(self, field_name, function=None, title=None, description=None, is_context=True, omit_if_empty=True):
"""
Register a new field and its handler function.
Args:
field_name (str): name of the field
function (callable): function that produces the field's content
title (str, optional): custom title for the field
description (str, optional): description of the field; may be a few sentences
is_context (bool, optional): whether the field belongs to the context section
omit_if_empty (bool, optional): whether to omit the field when its content is empty
"""
if not function:
function = self.handle_custom_data
# Register the handler function based on context flag
if is_context:
self.context_handlers[field_name] = function
else:
self.field_handlers[field_name] = function
# Store the custom title if provided and adjust the title prefix based on context
title_prefix = "####" if is_context else "###"
if title is not None:
self.field_descriptions[field_name] = f"{title_prefix} {title}\n\n"
elif description is not None:
# If title is not provided but description is, use the field name as title and append the description
self.field_descriptions[field_name] = f"{title_prefix} {field_name.replace('_', ' ').title()}\n\n{description}\n\n"
else:
# If neither title nor description is provided, use the field name as title
self.field_descriptions[field_name] = f"{title_prefix} {field_name.replace('_', ' ').title()}\n\n"
# Store the omit_if_empty flag for this field
self.omit_if_empty_flags[field_name] = omit_if_empty
if is_context and field_name != 'context_placeholder':
self.context_handlers[field_name] = function
self.context_order.append(field_name)
else:
self.field_handlers[field_name] = function
self.field_order.append(field_name)
def generate_full_prompt(self, **kwargs):
full_prompt = []
context_prompts = [] # collects the rendered context sections
is_pre_print = kwargs.get("is_pre_print", False) # force-print every prompt field, even empty ones
# handle the context fields first
for field_name in self.context_order:
handler = self.context_handlers[field_name]
processed_prompt = handler(**kwargs)
# Check if the field should be omitted when empty
if self.omit_if_empty_flags.get(field_name, False) and not processed_prompt and not is_pre_print:
continue # Skip this field
title_or_description = self.field_descriptions.get(field_name, f"#### {field_name.replace('_', ' ').title()}\n\n")
context_prompts.append(title_or_description + processed_prompt + '\n\n')
# handle the regular fields, watching for the context_placeholder position
for field_name in self.field_order:
if field_name == 'context_placeholder':
# insert the collected context data at the placeholder position
full_prompt.append(self.context_title) # add the headline of the context section
full_prompt.extend(context_prompts) # add the collected context content
else:
handler = self.field_handlers[field_name]
processed_prompt = handler(**kwargs)
# Check if the field should be omitted when empty
if self.omit_if_empty_flags.get(field_name, False) and not processed_prompt and not is_pre_print:
continue # Skip this field
title_or_description = self.field_descriptions.get(field_name, f"### {field_name.replace('_', ' ').title()}\n\n")
full_prompt.append(title_or_description + processed_prompt + '\n\n')
# return the full prompt, trimming trailing blank lines
return ''.join(full_prompt).rstrip('\n')
def pre_print(self, **kwargs):
kwargs.update({"is_pre_print": True})
prompt = self.generate_full_prompt(**kwargs)
input_keys = parse_section(self.role_prompt, 'Response Output Format')
llm_predict = "\n".join([f"**{k}:**" for k in input_keys])
return prompt + "\n\n" + "#"*19 + "\n<<<<LLM PREDICT>>>>\n" + "#"*19 + f"\n\n{llm_predict}\n"
def handle_custom_data(self, **kwargs):
return ""
def handle_tool_data(self, **kwargs):
if 'previous_agent_message' not in kwargs:
return ""
previous_agent_message = kwargs.get('previous_agent_message')
tools: list[Tool] = previous_agent_message.tools
if not tools:
return ""
tool_strings = []
for tool in tools:
args_str = f'args: {str(tool.args)}' if tool.args_schema else ""
tool_strings.append(f"{tool.name}: {tool.description}, {args_str}")
formatted_tools = "\n".join(tool_strings)
tool_names = ", ".join([tool.name for tool in tools])
tool_prompt = dedent(f"""
Below is a list of tools that are available for your use:
{formatted_tools}
valid "tool_name" value is:
{tool_names}
""")
return tool_prompt
def handle_agent_data(self, **kwargs):
if 'agents' not in kwargs:
return ""
agents = kwargs.get('agents')
random.shuffle(agents)
agent_names = ", ".join([f'{agent.role.role_name}' for agent in agents])
agent_descs = []
for agent in agents:
role_desc = agent.role.role_prompt.split("####")[1]
while "\n\n" in role_desc:
role_desc = role_desc.replace("\n\n", "\n")
role_desc = role_desc.replace("\n", ",")
agent_descs.append(f'"role name: {agent.role.role_name}\nrole description: {role_desc}"')
agents = "\n".join(agent_descs)
agent_prompt = f'''
Please ensure your selection is one of the listed roles. Available roles for selection:
{agents}
Please ensure you select the role from the agent names above, such as {agent_names}'''
return dedent(agent_prompt)
def handle_doc_info(self, **kwargs) -> str:
if 'previous_agent_message' not in kwargs:
return ""
previous_agent_message: Message = kwargs.get('previous_agent_message')
db_docs = previous_agent_message.db_docs
search_docs = previous_agent_message.search_docs
code_cocs = previous_agent_message.code_docs
doc_infos = "\n".join([doc.get_snippet() for doc in db_docs] + [doc.get_snippet() for doc in search_docs] +
[doc.get_code() for doc in code_cocs])
return doc_infos
def handle_session_records(self, **kwargs) -> str:
memory_pool: Memory = kwargs.get('memory_pool', Memory(messages=[]))
memory_pool = self.select_memory_by_agent_name(memory_pool)
memory_pool = self.select_memory_by_parsed_key(memory_pool)
return memory_pool.to_str_messages(content_key="parsed_output_list", with_tag=True)
def handle_current_plan(self, **kwargs) -> str:
if 'previous_agent_message' not in kwargs:
return ""
previous_agent_message = kwargs['previous_agent_message']
return previous_agent_message.parsed_output.get("CURRENT_STEP", "")
def handle_agent_profile(self, **kwargs) -> str:
return extract_section(self.role_prompt, 'Agent Profile')
def handle_output_format(self, **kwargs) -> str:
return extract_section(self.role_prompt, 'Response Output Format')
def handle_response(self, **kwargs) -> str:
if 'react_memory' not in kwargs:
return ""
react_memory = kwargs.get('react_memory', Memory(messages=[]))
if react_memory is None:
return ""
return "\n".join(["\n".join([f"**{k}:**\n{v}" for k,v in _dict.items()]) for _dict in react_memory.get_parserd_output()])
def handle_task_records(self, **kwargs) -> str:
if 'task_memory' not in kwargs:
return ""
task_memory: Memory = kwargs.get('task_memory', Memory(messages=[]))
if task_memory is None:
return ""
return "\n".join(["\n".join([f"**{k}:**\n{v}" for k,v in _dict.items() if k not in ["CURRENT_STEP"]]) for _dict in task_memory.get_parserd_output()])
def handle_previous_message(self, message: Message) -> str:
pass
def handle_message_by_role_name(self, message: Message) -> str:
pass
def handle_message_by_role_type(self, message: Message) -> str:
pass
def handle_current_agent_react_message(self, message: Message) -> str:
pass
def extract_codedoc_info_for_prompt(self, message: Message) -> str:
code_docs = message.code_docs
doc_infos = "\n".join([doc.get_code() for doc in code_docs])
return doc_infos
def select_memory_by_parsed_key(self, memory: Memory) -> Memory:
return Memory(
messages=[self.select_message_by_parsed_key(message) for message in memory.messages
if self.select_message_by_parsed_key(message) is not None]
)
def select_memory_by_agent_name(self, memory: Memory) -> Memory:
return Memory(
messages=[self.select_message_by_agent_name(message) for message in memory.messages
if self.select_message_by_agent_name(message) is not None]
)
def select_message_by_agent_name(self, message: Message) -> Message:
# an empty monitored_agents list means every agent is kept
if self.monitored_agents == []:
return message
return None if message is None or message.role_name not in self.monitored_agents else self.select_message_by_parsed_key(message)
def select_message_by_parsed_key(self, message: Message) -> Message:
# an empty monitored_fields list means every parsed key is kept
if message is None:
return message
if self.monitored_fields == []:
return message
message_c = copy.deepcopy(message)
message_c.parsed_output = {k: v for k,v in message_c.parsed_output.items() if k in self.monitored_fields}
message_c.parsed_output_list = [{k: v for k,v in parsed_output.items() if k in self.monitored_fields} for parsed_output in message_c.parsed_output_list]
return message_c
def get_memory(self, content_key="role_content"):
return self.memory.to_tuple_messages(content_key="step_content")
def get_memory_str(self, content_key="role_content"):
return "\n".join([": ".join(i) for i in self.memory.to_tuple_messages(content_key="step_content")])
def register_fields_from_config(self):
for prompt_field in self.prompt_config:
function_name = prompt_field.function_name
# check whether function_name is a method on self
if function_name and hasattr(self, function_name):
function = getattr(self, function_name)
else:
function = self.handle_custom_data
self.register_field(prompt_field.field_name,
function=function,
title=prompt_field.title,
description=prompt_field.description,
is_context=prompt_field.is_context,
omit_if_empty=prompt_field.omit_if_empty)
def register_standard_fields(self):
self.register_field('agent_profile', function=self.handle_agent_profile, is_context=False)
self.register_field('tool_information', function=self.handle_tool_data, is_context=False)
self.register_field('context_placeholder', is_context=True) # marks where the context data section is inserted
self.register_field('reference_documents', function=self.handle_doc_info, is_context=True)
self.register_field('session_records', function=self.handle_session_records, is_context=True)
self.register_field('output_format', function=self.handle_output_format, title='Response Output Format', is_context=False)
self.register_field('response', function=self.handle_response, is_context=False, omit_if_empty=False)
def register_executor_fields(self):
self.register_field('agent_profile', function=self.handle_agent_profile, is_context=False)
self.register_field('tool_information', function=self.handle_tool_data, is_context=False)
self.register_field('context_placeholder', is_context=True) # marks where the context data section is inserted
self.register_field('reference_documents', function=self.handle_doc_info, is_context=True)
self.register_field('session_records', function=self.handle_session_records, is_context=True)
self.register_field('current_plan', function=self.handle_current_plan, is_context=True)
self.register_field('output_format', function=self.handle_output_format, title='Response Output Format', is_context=False)
self.register_field('response', function=self.handle_response, is_context=False, omit_if_empty=False)
def register_fields_from_dict(self, fields_dict):
# register fields from a plain dict
for field_name, field_config in fields_dict.items():
function_name = field_config.get('function', None)
title = field_config.get('title', None)
description = field_config.get('description', None)
is_context = field_config.get('is_context', True)
omit_if_empty = field_config.get('omit_if_empty', True)
# check whether function_name is a method on self
if function_name and hasattr(self, function_name):
function = getattr(self, function_name)
else:
function = self.handle_custom_data
# delegate to the existing register_field method
self.register_field(field_name, function=function, title=title, description=description, is_context=is_context, omit_if_empty=omit_if_empty)
def main():
manager = PromptManager()
manager.register_standard_fields()
manager.register_field('agents_work_progress', title="Agents' Work Progress", is_context=True)
# build the data dict
data_dict = {
"agent_profile": "这是代理配置文件...",
# "tool_list": "这是工具列表...",
"reference_documents": "这是参考文档...",
"session_records": "这是会话记录...",
"agents_work_progress": "这是代理工作进展...",
"output_format": "这是预期的输出格式...",
# "response": "这是生成或继续回应的指令...",
"response": "",
"test": 'xxxxx'
}
# assemble the full prompt
full_prompt = manager.generate_full_prompt(**data_dict)
print(full_prompt)
if __name__ == "__main__":
main()
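
Beyond main() above, a field can also be registered with an ad-hoc handler; handlers receive the keyword arguments passed to generate_full_prompt (everything here is illustrative):

manager = PromptManager()
manager.register_field(
    'deploy_notes',
    function=lambda **kwargs: kwargs.get('deploy_notes', ''),
    title='Deploy Notes',
    is_context=False,
)
print(manager.generate_full_prompt(deploy_notes="roll out behind a feature flag"))
# ### Deploy Notes
#
# roll out behind a feature flag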

View File

@@ -215,15 +215,15 @@ class Env(BaseModel):
class Role(BaseModel): class Role(BaseModel):
role_type: str role_type: str
role_name: str role_name: str
role_desc: str role_desc: str = ""
agent_type: str = "" agent_type: str = "BaseAgent"
role_prompt: str = "" role_prompt: str = ""
template_prompt: str = "" template_prompt: str = ""
class ChainConfig(BaseModel): class ChainConfig(BaseModel):
chain_name: str chain_name: str
chain_type: str chain_type: str = "BaseChain"
agents: List[str] agents: List[str]
do_checker: bool = False do_checker: bool = False
chat_turn: int = 1 chat_turn: int = 1

View File

@@ -132,6 +132,9 @@ class Memory(BaseModel):
# return [parsed_output for message in self.messages for parsed_output in message.parsed_output_list[1:]] # return [parsed_output for message in self.messages for parsed_output in message.parsed_output_list[1:]]
return [parsed_output for message in self.messages for parsed_output in message.parsed_output_list] return [parsed_output for message in self.messages for parsed_output in message.parsed_output_list]
def get_spec_parserd_output(self, ):
return [message.spec_parsed_output for message in self.messages]
def get_rolenames(self, ): def get_rolenames(self, ):
'''''' ''''''
return [message.role_name for message in self.messages] return [message.role_name for message in self.messages]

View File

@@ -7,6 +7,7 @@ from .general_schema import *
class Message(BaseModel): class Message(BaseModel):
chat_index: str = None chat_index: str = None
user_name: str = "default"
role_name: str role_name: str
role_type: str role_type: str
role_prompt: str = None role_prompt: str = None
@@ -53,6 +54,8 @@ class Message(BaseModel):
cb_search_type: str = None cb_search_type: str = None
search_engine_name: str = None search_engine_name: str = None
top_k: int = 3 top_k: int = 3
use_nh: bool = True
local_graph_path: str = ''
score_threshold: float = 1.0 score_threshold: float = 1.0
do_doc_retrieval: bool = False do_doc_retrieval: bool = False
do_code_retrieval: bool = False do_code_retrieval: bool = False

View File

@@ -72,20 +72,25 @@ def parse_text_to_dict(text):
def parse_dict_to_dict(parsed_dict) -> dict: def parse_dict_to_dict(parsed_dict) -> dict:
code_pattern = r'```python\n(.*?)```' code_pattern = r'```python\n(.*?)```'
tool_pattern = r'```json\n(.*?)```' tool_pattern = r'```json\n(.*?)```'
java_pattern = r'```java\n(.*?)```'
pattern_dict = {"code": code_pattern, "json": tool_pattern} pattern_dict = {"code": code_pattern, "json": tool_pattern, "java": java_pattern}
spec_parsed_dict = copy.deepcopy(parsed_dict) spec_parsed_dict = copy.deepcopy(parsed_dict)
for key, pattern in pattern_dict.items(): for key, pattern in pattern_dict.items():
for k, text in parsed_dict.items(): for k, text in parsed_dict.items():
# Search for the code block # Search for the code block
if not isinstance(text, str): continue if not isinstance(text, str):
spec_parsed_dict[k] = text
continue
_match = re.search(pattern, text, re.DOTALL) _match = re.search(pattern, text, re.DOTALL)
if _match: if _match:
# Add the code block to the dictionary # Add the code block to the dictionary
try: try:
spec_parsed_dict[key] = json.loads(_match.group(1).strip()) spec_parsed_dict[key] = json.loads(_match.group(1).strip())
spec_parsed_dict[k] = json.loads(_match.group(1).strip())
except: except:
spec_parsed_dict[key] = _match.group(1).strip() spec_parsed_dict[key] = _match.group(1).strip()
spec_parsed_dict[k] = _match.group(1).strip()
break break
return spec_parsed_dict return spec_parsed_dict
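
Effect of the new java pattern, illustrated on a hand-built dict:

parsed_dict = {"Action": '```java\nSystem.out.println("hi");\n```'}
spec = parse_dict_to_dict(parsed_dict)
# spec["java"] and spec["Action"] both hold 'System.out.println("hi");'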

View File

@@ -43,7 +43,7 @@ class NebulaHandler:
elif self.space_name: elif self.space_name:
cypher = f'USE {self.space_name};{cypher}' cypher = f'USE {self.space_name};{cypher}'
logger.debug(cypher) # logger.debug(cypher)
resp = session.execute(cypher) resp = session.execute(cypher)
if format_res: if format_res:
@@ -247,6 +247,24 @@ class NebulaHandler:
res = self.execute_cypher(cypher, self.space_name) res = self.execute_cypher(cypher, self.space_name)
return self.result_to_dict(res) return self.result_to_dict(res)
def get_all_vertices(self,):
'''
get all vertices
@return:
'''
cypher = "MATCH (v) RETURN v;"
res = self.execute_cypher(cypher, self.space_name)
return self.result_to_dict(res)
def get_relative_vertices(self, vertice):
'''
get the vertices adjacent to the given vertex
@return:
'''
cypher = f'''MATCH (v1)--(v2) WHERE id(v1) == '{vertice}' RETURN id(v2) as id;'''
res = self.execute_cypher(cypher, self.space_name)
return self.result_to_dict(res)
def result_to_dict(self, result) -> dict: def result_to_dict(self, result) -> dict:
""" """
build list for each column, and transform to dataframe build list for each column, and transform to dataframe
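
A sketch of the two new queries; the connection settings and the vertex id are illustrative, and the constructor arguments are assumed from NebulaHandler's usual configuration:

handler = NebulaHandler(host="127.0.0.1", port=9669, username="root",
                        password="nebula", space_name="client")
all_vertices = handler.get_all_vertices()
neighbours = handler.get_relative_vertices("com.example#MyClass")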

View File

@@ -6,6 +6,7 @@ import os
import pickle import pickle
import uuid import uuid
import warnings import warnings
from enum import Enum
from pathlib import Path from pathlib import Path
from typing import ( from typing import (
Any, Any,
@@ -22,10 +23,22 @@ import numpy as np
from langchain.docstore.base import AddableMixin, Docstore from langchain.docstore.base import AddableMixin, Docstore
from langchain.docstore.document import Document from langchain.docstore.document import Document
from langchain.docstore.in_memory import InMemoryDocstore # from langchain.docstore.in_memory import InMemoryDocstore
from .in_memory import InMemoryDocstore
from langchain.embeddings.base import Embeddings from langchain.embeddings.base import Embeddings
from langchain.vectorstores.base import VectorStore from langchain.vectorstores.base import VectorStore
from langchain.vectorstores.utils import DistanceStrategy, maximal_marginal_relevance from langchain.vectorstores.utils import maximal_marginal_relevance
class DistanceStrategy(str, Enum):
"""Enumerator of the Distance strategies for calculating distances
between vectors."""
EUCLIDEAN_DISTANCE = "EUCLIDEAN_DISTANCE"
MAX_INNER_PRODUCT = "MAX_INNER_PRODUCT"
DOT_PRODUCT = "DOT_PRODUCT"
JACCARD = "JACCARD"
COSINE = "COSINE"
def dependable_faiss_import(no_avx2: Optional[bool] = None) -> Any: def dependable_faiss_import(no_avx2: Optional[bool] = None) -> Any:
@@ -219,6 +232,9 @@ class FAISS(VectorStore):
if self._normalize_L2: if self._normalize_L2:
faiss.normalize_L2(vector) faiss.normalize_L2(vector)
scores, indices = self.index.search(vector, k if filter is None else fetch_k) scores, indices = self.index.search(vector, k if filter is None else fetch_k)
# after L2 normalization, returned scores can exceed 1, so renormalize each row
if self._normalize_L2:
scores = np.array([row / np.linalg.norm(row) if np.max(row) > 1 else row for row in scores])
docs = [] docs = []
for j, i in enumerate(indices[0]): for j, i in enumerate(indices[0]):
if i == -1: if i == -1:
@@ -565,7 +581,7 @@ class FAISS(VectorStore):
vecstore = cls( vecstore = cls(
embedding.embed_query, embedding.embed_query,
index, index,
InMemoryDocstore(), InMemoryDocstore({}),
{}, {},
normalize_L2=normalize_L2, normalize_L2=normalize_L2,
distance_strategy=distance_strategy, distance_strategy=distance_strategy,
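
A sketch of the renormalization effect (assumes faiss is installed and my_embeddings is any LangChain Embeddings instance):

store = FAISS.from_texts(["alpha", "beta"], embedding=my_embeddings, normalize_L2=True)
docs_and_scores = store.similarity_search_with_score("alpha", k=2)
# with normalize_L2=True, the returned scores no longer exceed 1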

View File

@@ -10,13 +10,14 @@ from loguru import logger
# from configs.model_config import EMBEDDING_MODEL # from configs.model_config import EMBEDDING_MODEL
from coagent.embeddings.openai_embedding import OpenAIEmbedding from coagent.embeddings.openai_embedding import OpenAIEmbedding
from coagent.embeddings.huggingface_embedding import HFEmbedding from coagent.embeddings.huggingface_embedding import HFEmbedding
from coagent.llm_models.llm_config import EmbedConfig
def get_embedding( def get_embedding(
engine: str, engine: str,
text_list: list, text_list: list,
model_path: str = "text2vec-base-chinese", model_path: str = "text2vec-base-chinese",
embedding_device: str = "cpu", embedding_device: str = "cpu",
embed_config: EmbedConfig = None,
): ):
''' '''
get embedding get embedding
@@ -25,8 +26,12 @@ def get_embedding(
@return: @return:
''' '''
emb_res = {} emb_res = {}
if embed_config and embed_config.langchain_embeddings:
if engine == 'openai': emb_res = embed_config.langchain_embeddings.embed_documents(text_list)
emb_res = {
text_list[idx]: emb_res[idx] for idx in range(len(text_list))
}
elif engine == 'openai':
oae = OpenAIEmbedding() oae = OpenAIEmbedding()
emb_res = oae.get_emb(text_list) emb_res = oae.get_emb(text_list)
elif engine == 'model': elif engine == 'model':
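
A sketch of the new bypass: when the EmbedConfig carries a LangChain embedding, engine selection is skipped entirely (model path illustrative):

from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from coagent.llm_models.llm_config import EmbedConfig

embed_config = EmbedConfig(embed_model="text2vec", embed_model_path="/path/to/text2vec",
                           langchain_embeddings=HuggingFaceEmbeddings(model_name="/path/to/text2vec"))
emb_res = get_embedding(engine="", text_list=["hello", "world"], embed_config=embed_config)
# -> {"hello": [...], "world": [...]}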

View File

@@ -0,0 +1,49 @@
"""Simple in memory docstore in the form of a dict."""
from typing import Dict, List, Optional, Union
from langchain.docstore.base import AddableMixin, Docstore
from langchain.docstore.document import Document
class InMemoryDocstore(Docstore, AddableMixin):
"""Simple in memory docstore in the form of a dict."""
def __init__(self, _dict: Optional[Dict[str, Document]] = None):
"""Initialize with dict."""
self._dict = _dict if _dict is not None else {}
def add(self, texts: Dict[str, Document]) -> None:
"""Add texts to in memory dictionary.
Args:
texts: dictionary of id -> document.
Returns:
None
"""
overlapping = set(texts).intersection(self._dict)
if overlapping:
raise ValueError(f"Tried to add ids that already exist: {overlapping}")
self._dict = {**self._dict, **texts}
def delete(self, ids: List) -> None:
"""Deleting IDs from in memory dictionary."""
overlapping = set(ids).intersection(self._dict)
if not overlapping:
raise ValueError(f"Tried to delete ids that does not exist: {ids}")
for _id in ids:
self._dict.pop(_id)
def search(self, search: str) -> Union[str, Document]:
"""Search via direct lookup.
Args:
search: id of a document to search for.
Returns:
Document if found, else error message.
"""
if search not in self._dict:
return f"ID {search} not found."
else:
return self._dict[search]
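
The vendored docstore behaves like the upstream one; a quick round-trip:

from langchain.docstore.document import Document

store = InMemoryDocstore({})
store.add({"doc-1": Document(page_content="hello world")})
print(store.search("doc-1").page_content)  # "hello world"
store.delete(["doc-1"])
print(store.search("doc-1"))               # "ID doc-1 not found."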

View File

@@ -1,6 +1,8 @@
import os import os
from functools import lru_cache from functools import lru_cache
from langchain.embeddings.huggingface import HuggingFaceEmbeddings from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from langchain.embeddings.base import Embeddings
# from configs.model_config import embedding_model_dict # from configs.model_config import embedding_model_dict
from loguru import logger from loguru import logger
@@ -12,8 +14,11 @@ def load_embeddings(model: str, device: str, embedding_model_dict: dict):
return embeddings return embeddings
@lru_cache(1) # @lru_cache(1)
def load_embeddings_from_path(model_path: str, device: str): def load_embeddings_from_path(model_path: str, device: str, langchain_embeddings: Embeddings = None):
if langchain_embeddings:
return langchain_embeddings
embeddings = HuggingFaceEmbeddings(model_name=model_path, embeddings = HuggingFaceEmbeddings(model_name=model_path,
model_kwargs={'device': device}) model_kwargs={'device': device})
return embeddings return embeddings

View File

@@ -1,8 +1,8 @@
from .openai_model import getChatModel, getExtraModel, getChatModelFromConfig from .openai_model import getExtraModel, getChatModelFromConfig
from .llm_config import LLMConfig, EmbedConfig from .llm_config import LLMConfig, EmbedConfig
__all__ = [ __all__ = [
"getChatModel", "getExtraModel", "getChatModelFromConfig", "getExtraModel", "getChatModelFromConfig",
"LLMConfig", "EmbedConfig" "LLMConfig", "EmbedConfig"
] ]

View File

@@ -1,6 +1,9 @@
from dataclasses import dataclass from dataclasses import dataclass
from typing import List, Union from typing import List, Union
from langchain.embeddings.base import Embeddings
from langchain.llms.base import LLM, BaseLLM
@dataclass @dataclass
@@ -12,7 +15,8 @@ class LLMConfig:
stop: Union[List[str], str] = None, stop: Union[List[str], str] = None,
api_key: str = "", api_key: str = "",
api_base_url: str = "", api_base_url: str = "",
model_device: str = "cpu", model_device: str = "cpu", # unusewill delete it
llm: LLM = None,
**kwargs **kwargs
): ):
@@ -21,7 +25,7 @@ class LLMConfig:
self.stop: Union[List[str], str] = stop self.stop: Union[List[str], str] = stop
self.api_key: str = api_key self.api_key: str = api_key
self.api_base_url: str = api_base_url self.api_base_url: str = api_base_url
self.model_device: str = model_device self.llm: LLM = llm
# #
self.check_config() self.check_config()
@@ -42,6 +46,7 @@ class EmbedConfig:
embed_model_path: str = "", embed_model_path: str = "",
embed_engine: str = "", embed_engine: str = "",
model_device: str = "cpu", model_device: str = "cpu",
langchain_embeddings: Embeddings = None,
**kwargs **kwargs
): ):
self.embed_model: str = embed_model self.embed_model: str = embed_model
@@ -51,6 +56,8 @@ class EmbedConfig:
self.api_key: str = api_key self.api_key: str = api_key
self.api_base_url: str = api_base_url self.api_base_url: str = api_base_url
# #
self.langchain_embeddings = langchain_embeddings
#
self.check_config() self.check_config()
def check_config(self, ): def check_config(self, ):

View File

@@ -1,38 +1,54 @@
import os import os
from typing import Union, Optional, List
from loguru import logger
from langchain.callbacks import AsyncIteratorCallbackHandler from langchain.callbacks import AsyncIteratorCallbackHandler
from langchain.chat_models import ChatOpenAI from langchain.chat_models import ChatOpenAI
from langchain.llms.base import LLM
from .llm_config import LLMConfig from .llm_config import LLMConfig
# from configs.model_config import (llm_model_dict, LLM_MODEL) # from configs.model_config import (llm_model_dict, LLM_MODEL)
def getChatModel(callBack: AsyncIteratorCallbackHandler = None, temperature=0.3, stop=None): class CustomLLMModel:
if callBack is None:
def __init__(self, llm: LLM):
self.llm: LLM = llm
def __call__(self, prompt: str,
stop: Optional[List[str]] = None):
return self.llm(prompt, stop)
def _call(self, prompt: str,
stop: Optional[List[str]] = None):
return self.llm(prompt, stop)
def predict(self, prompt: str,
stop: Optional[List[str]] = None):
return self.llm(prompt, stop)
def batch(self, prompts: str,
stop: Optional[List[str]] = None):
return [self.llm(prompt, stop) for prompt in prompts]
def getChatModelFromConfig(llm_config: LLMConfig, callBack: AsyncIteratorCallbackHandler = None, ) -> Union[ChatOpenAI, LLM]:
# logger.debug(f"llm type is {type(llm_config.llm)}")
if llm_config is None:
model = ChatOpenAI( model = ChatOpenAI(
streaming=True, streaming=True,
verbose=True, verbose=True,
openai_api_key=llm_model_dict[LLM_MODEL]["api_key"], openai_api_key=os.environ.get("api_key"),
openai_api_base=llm_model_dict[LLM_MODEL]["api_base_url"], openai_api_base=os.environ.get("api_base_url"),
model_name=LLM_MODEL, model_name=os.environ.get("LLM_MODEL", "gpt-3.5-turbo"),
temperature=temperature, temperature=os.environ.get("temperature", 0.5),
stop=stop stop=os.environ.get("stop", ""),
) )
else: return model
model = ChatOpenAI(
streaming=True,
verbose=True,
callBack=[callBack],
openai_api_key=llm_model_dict[LLM_MODEL]["api_key"],
openai_api_base=llm_model_dict[LLM_MODEL]["api_base_url"],
model_name=LLM_MODEL,
temperature=temperature,
stop=stop
)
return model
if llm_config and llm_config.llm and isinstance(llm_config.llm, LLM):
return CustomLLMModel(llm=llm_config.llm)
def getChatModelFromConfig(llm_config: LLMConfig, callBack: AsyncIteratorCallbackHandler = None, ):
if callBack is None: if callBack is None:
model = ChatOpenAI( model = ChatOpenAI(
streaming=True, streaming=True,
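
A sketch of the new LLM passthrough; FakeListLLM stands in for any LangChain LLM, and the model_name argument is assumed from the constructor above this hunk:

from langchain.llms.fake import FakeListLLM

llm_config = LLMConfig(model_name="local-llm", llm=FakeListLLM(responses=["pong"]))
model = getChatModelFromConfig(llm_config)  # returns a CustomLLMModel wrapper
print(model.predict("ping"))                # -> "pong"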

View File

@@ -0,0 +1,5 @@
# from .base_retrieval import *
# __all__ = [
# "IMRertrieval", "BaseDocRetrieval", "BaseCodeRetrieval", "BaseSearchRetrieval"
# ]

View File

@@ -0,0 +1,75 @@
from coagent.llm_models.llm_config import EmbedConfig, LLMConfig
from coagent.base_configs.env_config import KB_ROOT_PATH
from coagent.tools import DocRetrieval, CodeRetrieval
class IMRertrieval:
def __init__(self,):
'''
initialize your own attributes here
'''
pass
def run(self, ):
'''
execution interface; may use the attributes set in __init__
'''
pass
class BaseDocRetrieval(IMRertrieval):
def __init__(self, knowledge_base_name: str, search_top=5, score_threshold=1.0, embed_config: EmbedConfig=EmbedConfig(), kb_root_path: str=KB_ROOT_PATH):
self.knowledge_base_name = knowledge_base_name
self.search_top = search_top
self.score_threshold = score_threshold
self.embed_config = embed_config
self.kb_root_path = kb_root_path
def run(self, query: str, search_top=None, score_threshold=None, ):
docs = DocRetrieval.run(
query=query, knowledge_base_name=self.knowledge_base_name,
search_top=search_top or self.search_top,
score_threshold=score_threshold or self.score_threshold,
embed_config=self.embed_config,
kb_root_path=self.kb_root_path
)
return docs
class BaseCodeRetrieval(IMRertrieval):
def __init__(self, code_base_name, embed_config: EmbedConfig, llm_config: LLMConfig, search_type = 'tag', code_limit = 1, local_graph_path: str=""):
self.code_base_name = code_base_name
self.embed_config = embed_config
self.llm_config = llm_config
self.search_type = search_type
self.code_limit = code_limit
self.use_nh: bool = False
self.local_graph_path: str = local_graph_path
def run(self, query, history_node_list=[], search_type = None, code_limit=None):
code_docs = CodeRetrieval.run(
code_base_name=self.code_base_name,
query=query,
history_node_list=history_node_list,
code_limit=code_limit or self.code_limit,
search_type=search_type or self.search_type,
llm_config=self.llm_config,
embed_config=self.embed_config,
use_nh=self.use_nh,
local_graph_path=self.local_graph_path
)
return code_docs
class BaseSearchRetrieval(IMRertrieval):
def __init__(self, ):
pass
def run(self, ):
pass
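
Because message_process wraps whatever run() returns as Document-like objects, a homemade IMRertrieval can stay very small (corpus illustrative):

from langchain.docstore.document import Document

class KeywordRetrieval(IMRertrieval):
    """Toy retrieval: keyword match over a static corpus."""
    def __init__(self, corpus):
        self.corpus = corpus
    def run(self, query, **kwargs):
        return [Document(page_content=text, metadata={"title_prefix": "demo", "url": ""})
                for text in self.corpus if query.lower() in text.lower()]

retrieval = KeywordRetrieval(["coagent now supports antflow", "codefuse chatbot demo"])
docs = retrieval.run("antflow")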

View File

@@ -0,0 +1,6 @@
from .json_loader import JSONLoader
from .jsonl_loader import JSONLLoader
__all__ = [
"JSONLoader", "JSONLLoader"
]

View File

@@ -0,0 +1,61 @@
import json
from pathlib import Path
from typing import AnyStr, Callable, Dict, List, Optional, Union
from langchain.docstore.document import Document
from langchain.document_loaders.base import BaseLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter
from coagent.utils.common_utils import read_json_file
class JSONLoader(BaseLoader):
def __init__(
self,
file_path: Union[str, Path],
schema_key: str = "all_text",
content_key: Optional[str] = None,
metadata_func: Optional[Callable[[Dict, Dict], Dict]] = None,
text_content: bool = True,
):
self.file_path = Path(file_path).resolve()
self.schema_key = schema_key
self._content_key = content_key
self._metadata_func = metadata_func
self._text_content = text_content
def load(self, ) -> List[Document]:
"""Load and return documents from the JSON file."""
docs: List[Document] = []
datas = read_json_file(self.file_path)
self._parse(datas, docs)
return docs
def _parse(self, datas: List, docs: List[Document]) -> None:
for idx, sample in enumerate(datas):
metadata = dict(
source=str(self.file_path),
seq_num=idx,
)
text = sample.get(self.schema_key, "")
docs.append(Document(page_content=text, metadata=metadata))
def load_and_split(
self, text_splitter: Optional[TextSplitter] = None
) -> List[Document]:
"""Load Documents and split into chunks. Chunks are returned as Documents.
Args:
text_splitter: TextSplitter instance to use for splitting documents.
Defaults to RecursiveCharacterTextSplitter.
Returns:
List of Documents.
"""
if text_splitter is None:
_text_splitter: TextSplitter = RecursiveCharacterTextSplitter()
else:
_text_splitter = text_splitter
docs = self.load()
return _text_splitter.split_documents(docs)
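
A sketch of loading a JSON array of records; the file name and shape are illustrative, and each record is expected to carry the schema_key field:

# data.json: [{"all_text": "first sample"}, {"all_text": "second sample"}]
loader = JSONLoader("data.json", schema_key="all_text")
docs = loader.load()              # one Document per record
chunks = loader.load_and_split()  # RecursiveCharacterTextSplitter by default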

View File

@@ -0,0 +1,62 @@
import json
from pathlib import Path
from typing import AnyStr, Callable, Dict, List, Optional, Union
from langchain.docstore.document import Document
from langchain.document_loaders.base import BaseLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter
from coagent.utils.common_utils import read_jsonl_file
class JSONLLoader(BaseLoader):
def __init__(
self,
file_path: Union[str, Path],
schema_key: str = "all_text",
content_key: Optional[str] = None,
metadata_func: Optional[Callable[[Dict, Dict], Dict]] = None,
text_content: bool = True,
):
self.file_path = Path(file_path).resolve()
self.schema_key = schema_key
self._content_key = content_key
self._metadata_func = metadata_func
self._text_content = text_content
def load(self, ) -> List[Document]:
"""Load and return documents from the JSON file."""
docs: List[Document] = []
datas = read_jsonl_file(self.file_path)
self._parse(datas, docs)
return docs
def _parse(self, datas: List, docs: List[Document]) -> None:
for idx, sample in enumerate(datas):
metadata = dict(
source=str(self.file_path),
seq_num=idx,
)
text = sample.get(self.schema_key, "")
docs.append(Document(page_content=text, metadata=metadata))
def load_and_split(
self, text_splitter: Optional[TextSplitter] = None
) -> List[Document]:
"""Load Documents and split into chunks. Chunks are returned as Documents.
Args:
text_splitter: TextSplitter instance to use for splitting documents.
Defaults to RecursiveCharacterTextSplitter.
Returns:
List of Documents.
"""
if text_splitter is None:
_text_splitter: TextSplitter = RecursiveCharacterTextSplitter()
else:
_text_splitter = text_splitter
docs = self.load()
return _text_splitter.split_documents(docs)

View File

@@ -0,0 +1,3 @@
from .langchain_splitter import LCTextSplitter
__all__ = ["LCTextSplitter"]

View File

@@ -0,0 +1,77 @@
import os
import importlib
from loguru import logger
from langchain.document_loaders.base import BaseLoader
from langchain.text_splitter import (
SpacyTextSplitter, RecursiveCharacterTextSplitter
)
# from configs.model_config import (
# CHUNK_SIZE,
# OVERLAP_SIZE,
# ZH_TITLE_ENHANCE
# )
from coagent.utils.path_utils import *
class LCTextSplitter:
'''langchain text splitter: turns a file into split Documents (file2text)'''
def __init__(
self, filepath: str, text_splitter_name: str = None,
chunk_size: int = 500,
overlap_size: int = 50
):
self.filepath = filepath
self.ext = os.path.splitext(filepath)[-1].lower()
self.text_splitter_name = text_splitter_name
self.chunk_size = chunk_size
self.overlap_size = overlap_size
if self.ext not in SUPPORTED_EXTS:
raise ValueError(f"暂未支持的文件格式 {self.ext}")
self.document_loader_name = get_LoaderClass(self.ext)
def file2text(self, ):
loader = self._load_document()
text_splitter = self._load_text_splitter()
if self.document_loader_name in ["JSONLoader", "JSONLLoader"]:
# docs = loader.load()
docs = loader.load_and_split(text_splitter)
# logger.debug(f"please check your file can be loaded, docs.lens {len(docs)}")
else:
docs = loader.load_and_split(text_splitter)
return docs
def _load_document(self, ) -> BaseLoader:
DocumentLoader = EXT2LOADER_DICT[self.ext]
if self.document_loader_name == "UnstructuredFileLoader":
loader = DocumentLoader(self.filepath, autodetect_encoding=True)
else:
loader = DocumentLoader(self.filepath)
return loader
def _load_text_splitter(self, ):
try:
if self.text_splitter_name is None:
text_splitter = SpacyTextSplitter(
pipeline="zh_core_web_sm",
chunk_size=self.chunk_size,
chunk_overlap=self.overlap_size,
)
self.text_splitter_name = "SpacyTextSplitter"
# elif self.document_loader_name in ["JSONLoader", "JSONLLoader"]:
# text_splitter = None
else:
text_splitter_module = importlib.import_module('langchain.text_splitter')
TextSplitter = getattr(text_splitter_module, self.text_splitter_name)
text_splitter = TextSplitter(
chunk_size=self.chunk_size,
chunk_overlap=self.overlap_size)
except Exception as e:
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=self.chunk_size,
chunk_overlap=self.overlap_size,
)
return text_splitter
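
Naming a splitter explicitly sidesteps the SpacyTextSplitter default, which needs the zh_core_web_sm pipeline installed; the file path is illustrative:

splitter = LCTextSplitter("README.md",
                          text_splitter_name="RecursiveCharacterTextSplitter",
                          chunk_size=500, overlap_size=50)
docs = splitter.file2text()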

View File

View File

@@ -32,8 +32,8 @@ class PyCodeBox(BaseBox):
self.do_check_net = do_check_net self.do_check_net = do_check_net
self.use_stop = use_stop self.use_stop = use_stop
self.jupyter_work_path = jupyter_work_path self.jupyter_work_path = jupyter_work_path
asyncio.run(self.astart()) # asyncio.run(self.astart())
# self.start() self.start()
# logger.info(f"""remote_url: {self.remote_url}, # logger.info(f"""remote_url: {self.remote_url},
# remote_ip: {self.remote_ip}, # remote_ip: {self.remote_ip},
@@ -199,13 +199,13 @@ class PyCodeBox(BaseBox):
async def _aget_kernelid(self, ) -> None: async def _aget_kernelid(self, ) -> None:
headers = {"Authorization": f'Token {self.token}', 'token': self.token} headers = {"Authorization": f'Token {self.token}', 'token': self.token}
response = requests.get(f"{self.kernel_url}?token={self.token}", headers=headers) # response = requests.get(f"{self.kernel_url}?token={self.token}", headers=headers)
async with aiohttp.ClientSession() as session: async with aiohttp.ClientSession() as session:
async with session.get(f"{self.kernel_url}?token={self.token}", headers=headers) as resp: async with session.get(f"{self.kernel_url}?token={self.token}", headers=headers, timeout=270) as resp:
if len(await resp.json()) > 0: if len(await resp.json()) > 0:
self.kernel_id = (await resp.json())[0]["id"] self.kernel_id = (await resp.json())[0]["id"]
else: else:
async with session.post(f"{self.kernel_url}?token={self.token}", headers=headers) as response: async with session.post(f"{self.kernel_url}?token={self.token}", headers=headers, timeout=270) as response:
self.kernel_id = (await response.json())["id"] self.kernel_id = (await response.json())["id"]
# if len(response.json()) > 0: # if len(response.json()) > 0:
@@ -220,41 +220,45 @@ class PyCodeBox(BaseBox):
return False return False
try: try:
response = requests.get(f"{self.kernel_url}?token={self.token}", timeout=270) response = requests.get(f"{self.kernel_url}?token={self.token}", timeout=10)
return response.status_code == 200 return response.status_code == 200
except requests.exceptions.ConnectionError: except requests.exceptions.ConnectionError:
return False return False
except requests.exceptions.ReadTimeout:
return False
async def _acheck_connect(self, ) -> bool: async def _acheck_connect(self, ) -> bool:
if self.kernel_url == "": if self.kernel_url == "":
return False return False
try: try:
async with aiohttp.ClientSession() as session: async with aiohttp.ClientSession() as session:
async with session.get(f"{self.kernel_url}?token={self.token}", timeout=270) as resp: async with session.get(f"{self.kernel_url}?token={self.token}", timeout=10) as resp:
return resp.status == 200 return resp.status == 200
except aiohttp.ClientConnectorError: except aiohttp.ClientConnectorError:
pass return False
except aiohttp.ServerDisconnectedError: except aiohttp.ServerDisconnectedError:
pass return False
def _check_port(self, ) -> bool: def _check_port(self, ) -> bool:
try: try:
response = requests.get(f"{self.remote_ip}:{self.remote_port}", timeout=270) response = requests.get(f"{self.remote_ip}:{self.remote_port}", timeout=10)
logger.warning(f"Port is conflict, please check your codebox's port {self.remote_port}") logger.warning(f"Port is conflict, please check your codebox's port {self.remote_port}")
return response.status_code == 200 return response.status_code == 200
except requests.exceptions.ConnectionError: except requests.exceptions.ConnectionError:
return False return False
except requests.exceptions.ReadTimeout:
return False
async def _acheck_port(self, ) -> bool: async def _acheck_port(self, ) -> bool:
try: try:
async with aiohttp.ClientSession() as session: async with aiohttp.ClientSession() as session:
async with session.get(f"{self.remote_ip}:{self.remote_port}", timeout=270) as resp: async with session.get(f"{self.remote_ip}:{self.remote_port}", timeout=10) as resp:
# logger.warning(f"Port is conflict, please check your codebox's port {self.remote_port}") # logger.warning(f"Port is conflict, please check your codebox's port {self.remote_port}")
return resp.status == 200 return resp.status == 200
except aiohttp.ClientConnectorError: except aiohttp.ClientConnectorError:
pass return False
except aiohttp.ServerDisconnectedError: except aiohttp.ServerDisconnectedError:
pass return False
def _check_connect_success(self, retry_nums: int = 2) -> bool: def _check_connect_success(self, retry_nums: int = 2) -> bool:
if not self.do_check_net: return True if not self.do_check_net: return True
@@ -263,7 +267,7 @@ class PyCodeBox(BaseBox):
try: try:
connect_status = self._check_connect() connect_status = self._check_connect()
if connect_status: if connect_status:
logger.info(f"{self.remote_url} connection success") # logger.info(f"{self.remote_url} connection success")
return True return True
except requests.exceptions.ConnectionError: except requests.exceptions.ConnectionError:
logger.info(f"{self.remote_url} connection fail") logger.info(f"{self.remote_url} connection fail")
@@ -301,10 +305,12 @@ class PyCodeBox(BaseBox):
else: else:
# TODO: auto-detect the local endpoint # TODO: auto-detect the local endpoint
port_status = self._check_port() port_status = self._check_port()
self.kernel_url = self.remote_url + "/api/kernels"
connect_status = self._check_connect() connect_status = self._check_connect()
logger.info(f"port_status: {port_status}, connect_status: {connect_status}") if os.environ.get("log_verbose", "0") >= "2":
logger.info(f"port_status: {port_status}, connect_status: {connect_status}")
if port_status and not connect_status: if port_status and not connect_status:
raise BaseException(f"Port is conflict, please check your codebox's port {self.remote_port}") logger.error("Port is conflict, please check your codebox's port {self.remote_port}")
if not connect_status: if not connect_status:
self.jupyter = subprocess.Popen( self.jupyter = subprocess.Popen(
@@ -321,14 +327,32 @@ class PyCodeBox(BaseBox):
stdout=subprocess.PIPE, stdout=subprocess.PIPE,
) )
record = []
while True and self.jupyter and len(record)<100:
line = self.jupyter.stderr.readline()
try:
content = line.decode("utf-8")
except:
content = line.decode("gbk")
# logger.debug(content)
record.append(content)
if "control-c" in content.lower():
break
self.kernel_url = self.remote_url + "/api/kernels" self.kernel_url = self.remote_url + "/api/kernels"
self.do_check_net = True self.do_check_net = True
self._check_connect_success() self._check_connect_success()
self._get_kernelid() self._get_kernelid()
# logger.debug(self.kernel_url.replace("http", "ws") + f"/{self.kernel_id}/channels?token={self.token}")
self.wc_url = self.kernel_url.replace("http", "ws") + f"/{self.kernel_id}/channels?token={self.token}" self.wc_url = self.kernel_url.replace("http", "ws") + f"/{self.kernel_id}/channels?token={self.token}"
headers = {"Authorization": f'Token {self.token}', 'token': self.token} headers = {"Authorization": f'Token {self.token}', 'token': self.token}
self.ws = create_connection(self.wc_url, headers=headers) retry_nums = 3
while retry_nums>=0:
try:
self.ws = create_connection(self.wc_url, headers=headers, timeout=5)
break
except Exception as e:
logger.error(f"create ws connection timeout {e}")
retry_nums -= 1
async def astart(self, ): async def astart(self, ):
'''判断是从外部service执行还是内部启动notebook执行''' '''判断是从外部service执行还是内部启动notebook执行'''
@@ -369,10 +393,16 @@ class PyCodeBox(BaseBox):
cwd=self.jupyter_work_path cwd=self.jupyter_work_path
) )
while True and self.jupyter: record = []
while True and self.jupyter and len(record)<100:
line = self.jupyter.stderr.readline() line = self.jupyter.stderr.readline()
# logger.debug(line.decode("gbk")) try:
if "Control-C" in line.decode("gbk"): content = line.decode("utf-8")
except UnicodeDecodeError:
content = line.decode("gbk")
# logger.debug(content)
record.append(content)
if "control-c" in content.lower():
break break
self.kernel_url = self.remote_url + "/api/kernels" self.kernel_url = self.remote_url + "/api/kernels"
self.do_check_net = True self.do_check_net = True
@ -380,7 +410,15 @@ class PyCodeBox(BaseBox):
await self._aget_kernelid() await self._aget_kernelid()
self.wc_url = self.kernel_url.replace("http", "ws") + f"/{self.kernel_id}/channels?token={self.token}" self.wc_url = self.kernel_url.replace("http", "ws") + f"/{self.kernel_id}/channels?token={self.token}"
headers = {"Authorization": f'Token {self.token}', 'token': self.token} headers = {"Authorization": f'Token {self.token}', 'token': self.token}
self.ws = create_connection(self.wc_url, headers=headers)
retry_nums = 3
while retry_nums>=0:
try:
self.ws = create_connection(self.wc_url, headers=headers, timeout=5)
break
except Exception as e:
logger.error(f"create ws connection timeout {e}")
retry_nums -= 1
def status(self,) -> CodeBoxStatus: def status(self,) -> CodeBoxStatus:
if not self.kernel_id: if not self.kernel_id:

View File

@ -17,7 +17,7 @@ from coagent.orm.commands import *
from coagent.utils.path_utils import * from coagent.utils.path_utils import *
from coagent.orm.utils import DocumentFile from coagent.orm.utils import DocumentFile
from coagent.embeddings.utils import load_embeddings, load_embeddings_from_path from coagent.embeddings.utils import load_embeddings, load_embeddings_from_path
from coagent.text_splitter import LCTextSplitter from coagent.retrieval.text_splitter import LCTextSplitter
from coagent.llm_models.llm_config import EmbedConfig from coagent.llm_models.llm_config import EmbedConfig
@ -46,7 +46,7 @@ class KBService(ABC):
def _load_embeddings(self) -> Embeddings: def _load_embeddings(self) -> Embeddings:
# return load_embeddings(self.embed_model, embed_device, embedding_model_dict) # return load_embeddings(self.embed_model, embed_device, embedding_model_dict)
return load_embeddings_from_path(self.embed_config.embed_model_path, self.embed_config.model_device) return load_embeddings_from_path(self.embed_config.embed_model_path, self.embed_config.model_device, self.embed_config.langchain_embeddings)
def create_kb(self): def create_kb(self):
""" """

View File

@ -20,9 +20,6 @@ from coagent.utils.path_utils import *
from coagent.orm.commands import * from coagent.orm.commands import *
from coagent.db_handler.graph_db_handler.nebula_handler import NebulaHandler from coagent.db_handler.graph_db_handler.nebula_handler import NebulaHandler
from coagent.db_handler.vector_db_handler.chroma_handler import ChromaHandler from coagent.db_handler.vector_db_handler.chroma_handler import ChromaHandler
# from configs.server_config import NEBULA_HOST, NEBULA_PORT, NEBULA_USER, NEBULA_PASSWORD, NEBULA_STORAGED_PORT
# from configs.server_config import CHROMA_PERSISTENT_PATH
from coagent.base_configs.env_config import ( from coagent.base_configs.env_config import (
CB_ROOT_PATH, CB_ROOT_PATH,
NEBULA_HOST, NEBULA_PORT, NEBULA_USER, NEBULA_PASSWORD, NEBULA_STORAGED_PORT, NEBULA_HOST, NEBULA_PORT, NEBULA_USER, NEBULA_PASSWORD, NEBULA_STORAGED_PORT,
@ -58,10 +55,11 @@ async def create_cb(zip_file,
model_name: bool = Body(..., examples=["samples"]), model_name: bool = Body(..., examples=["samples"]),
temperature: bool = Body(..., examples=["samples"]), temperature: bool = Body(..., examples=["samples"]),
model_device: bool = Body(..., examples=["samples"]), model_device: bool = Body(..., examples=["samples"]),
embed_config: EmbedConfig = None,
) -> BaseResponse: ) -> BaseResponse:
logger.info('cb_name={}, zip_path={}, do_interpret={}'.format(cb_name, code_path, do_interpret)) logger.info('cb_name={}, zip_path={}, do_interpret={}'.format(cb_name, code_path, do_interpret))
embed_config: EmbedConfig = EmbedConfig(**locals()) embed_config: EmbedConfig = EmbedConfig(**locals()) if embed_config is None else embed_config
llm_config: LLMConfig = LLMConfig(**locals()) llm_config: LLMConfig = LLMConfig(**locals())
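The embed_config: EmbedConfig = None parameter added to these endpoints lets in-process callers pass a ready-built config object, while HTTP callers keep the field-by-field Body signature; EmbedConfig(**locals()) now only reconstructs the config when none was supplied.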
# Create selected knowledge base # Create selected knowledge base
@ -101,9 +99,10 @@ async def delete_cb(
model_name: bool = Body(..., examples=["samples"]), model_name: bool = Body(..., examples=["samples"]),
temperature: bool = Body(..., examples=["samples"]), temperature: bool = Body(..., examples=["samples"]),
model_device: bool = Body(..., examples=["samples"]), model_device: bool = Body(..., examples=["samples"]),
embed_config: EmbedConfig = None,
) -> BaseResponse: ) -> BaseResponse:
logger.info('cb_name={}'.format(cb_name)) logger.info('cb_name={}'.format(cb_name))
embed_config: EmbedConfig = EmbedConfig(**locals()) embed_config: EmbedConfig = EmbedConfig(**locals()) if embed_config is None else embed_config
llm_config: LLMConfig = LLMConfig(**locals()) llm_config: LLMConfig = LLMConfig(**locals())
# Create selected knowledge base # Create selected knowledge base
if not validate_kb_name(cb_name): if not validate_kb_name(cb_name):
@ -143,18 +142,24 @@ def search_code(cb_name: str = Body(..., examples=["sofaboot"]),
model_name: bool = Body(..., examples=["samples"]), model_name: bool = Body(..., examples=["samples"]),
temperature: bool = Body(..., examples=["samples"]), temperature: bool = Body(..., examples=["samples"]),
model_device: bool = Body(..., examples=["samples"]), model_device: bool = Body(..., examples=["samples"]),
use_nh: bool = True,
local_graph_path: str = '',
embed_config: EmbedConfig = None,
) -> dict: ) -> dict:
logger.info('cb_name={}'.format(cb_name)) if os.environ.get("log_verbose", "0") >= "2":
logger.info('query={}'.format(query)) logger.info(f'local_graph_path={local_graph_path}')
logger.info('code_limit={}'.format(code_limit)) logger.info('cb_name={}'.format(cb_name))
logger.info('search_type={}'.format(search_type)) logger.info('query={}'.format(query))
logger.info('history_node_list={}'.format(history_node_list)) logger.info('code_limit={}'.format(code_limit))
embed_config: EmbedConfig = EmbedConfig(**locals()) logger.info('search_type={}'.format(search_type))
logger.info('history_node_list={}'.format(history_node_list))
embed_config: EmbedConfig = EmbedConfig(**locals()) if embed_config is None else embed_config
llm_config: LLMConfig = LLMConfig(**locals()) llm_config: LLMConfig = LLMConfig(**locals())
try: try:
# load codebase # load codebase
cbh = CodeBaseHandler(codebase_name=cb_name, embed_config=embed_config, llm_config=llm_config) cbh = CodeBaseHandler(codebase_name=cb_name, embed_config=embed_config, llm_config=llm_config,
use_nh=use_nh, local_graph_path=local_graph_path)
# search code # search code
context, related_vertices = cbh.search_code(query, search_type=search_type, limit=code_limit) context, related_vertices = cbh.search_code(query, search_type=search_type, limit=code_limit)
@ -180,10 +185,12 @@ def search_related_vertices(cb_name: str = Body(..., examples=["sofaboot"]),
nh = NebulaHandler(host=NEBULA_HOST, port=NEBULA_PORT, username=NEBULA_USER, nh = NebulaHandler(host=NEBULA_HOST, port=NEBULA_PORT, username=NEBULA_USER,
password=NEBULA_PASSWORD, space_name=cb_name) password=NEBULA_PASSWORD, space_name=cb_name)
cypher = f'''MATCH (v1)--(v2) WHERE id(v1) == '{vertex}' RETURN id(v2) as id;''' if vertex.endswith(".java"):
cypher = f'''MATCH (v1)--(v2:package) WHERE id(v1) == '{vertex}' RETURN id(v2) as id;'''
else:
cypher = f'''MATCH (v1)--(v2) WHERE id(v1) == '{vertex}' RETURN id(v2) as id;'''
# cypher = f'''MATCH (v1)--(v2) WHERE id(v1) == '{vertex}' RETURN v2;'''
cypher_res = nh.execute_cypher(cypher=cypher, format_res=True) cypher_res = nh.execute_cypher(cypher=cypher, format_res=True)
related_vertices = cypher_res.get('id', []) related_vertices = cypher_res.get('id', [])
related_vertices = [i.as_string() for i in related_vertices] related_vertices = [i.as_string() for i in related_vertices]
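The .java branch above narrows neighbor lookups for file-level vertices: a vertex id ending in .java now returns only its :package neighbors, which, as far as the diff shows, keeps file vertices from pulling in every contained class and method.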
@ -200,8 +207,8 @@ def search_related_vertices(cb_name: str = Body(..., examples=["sofaboot"]),
def search_code_by_vertex(cb_name: str = Body(..., examples=["sofaboot"]), def search_code_by_vertex(cb_name: str = Body(..., examples=["sofaboot"]),
vertex: str = Body(..., examples=['***'])) -> dict: vertex: str = Body(..., examples=['***'])) -> dict:
logger.info('cb_name={}'.format(cb_name)) # logger.info('cb_name={}'.format(cb_name))
logger.info('vertex={}'.format(vertex)) # logger.info('vertex={}'.format(vertex))
try: try:
nh = NebulaHandler(host=NEBULA_HOST, port=NEBULA_PORT, username=NEBULA_USER, nh = NebulaHandler(host=NEBULA_HOST, port=NEBULA_PORT, username=NEBULA_USER,
@ -233,7 +240,7 @@ def search_code_by_vertex(cb_name: str = Body(..., examples=["sofaboot"]),
return res return res
except Exception as e: except Exception as e:
logger.exception(e) logger.exception(e)
return {} return {'code': ""}
def cb_exists_api(cb_name: str = Body(..., examples=["sofaboot"])) -> bool: def cb_exists_api(cb_name: str = Body(..., examples=["sofaboot"])) -> bool:

View File

@ -8,17 +8,6 @@ from loguru import logger
from langchain.embeddings.base import Embeddings from langchain.embeddings.base import Embeddings
from langchain.docstore.document import Document from langchain.docstore.document import Document
from langchain.embeddings.huggingface import HuggingFaceEmbeddings from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from langchain.vectorstores.utils import DistanceStrategy
# from configs.model_config import (
# KB_ROOT_PATH,
# CACHED_VS_NUM,
# EMBEDDING_MODEL,
# EMBEDDING_DEVICE,
# SCORE_THRESHOLD,
# FAISS_NORMALIZE_L2
# )
# from configs.model_config import embedding_model_dict
from coagent.base_configs.env_config import ( from coagent.base_configs.env_config import (
KB_ROOT_PATH, KB_ROOT_PATH,
@ -52,15 +41,15 @@ def load_vector_store(
tick: int = 0, # tick will be changed by upload_doc etc. and make cache refreshed. tick: int = 0, # tick will be changed by upload_doc etc. and make cache refreshed.
kb_root_path: str = KB_ROOT_PATH, kb_root_path: str = KB_ROOT_PATH,
): ):
print(f"loading vector store in '{knowledge_base_name}'.") # print(f"loading vector store in '{knowledge_base_name}'.")
vs_path = get_vs_path(knowledge_base_name, kb_root_path) vs_path = get_vs_path(knowledge_base_name, kb_root_path)
if embeddings is None: if embeddings is None:
embeddings = load_embeddings_from_path(embed_config.embed_model_path, embed_config.model_device) embeddings = load_embeddings_from_path(embed_config.embed_model_path, embed_config.model_device, embed_config.langchain_embeddings)
if not os.path.exists(vs_path): if not os.path.exists(vs_path):
os.makedirs(vs_path) os.makedirs(vs_path)
distance_strategy = DistanceStrategy.EUCLIDEAN_DISTANCE distance_strategy = "EUCLIDEAN_DISTANCE"
if "index.faiss" in os.listdir(vs_path): if "index.faiss" in os.listdir(vs_path):
search_index = FAISS.load_local(vs_path, embeddings, normalize_L2=FAISS_NORMALIZE_L2, distance_strategy=distance_strategy) search_index = FAISS.load_local(vs_path, embeddings, normalize_L2=FAISS_NORMALIZE_L2, distance_strategy=distance_strategy)
else: else:

View File

@ -9,9 +9,7 @@ from pydantic import BaseModel, Field
from loguru import logger from loguru import logger
from coagent.llm_models import LLMConfig, EmbedConfig from coagent.llm_models import LLMConfig, EmbedConfig
from .base_tool import BaseToolModel from .base_tool import BaseToolModel
from coagent.service.cb_api import search_code from coagent.service.cb_api import search_code
@ -29,7 +27,17 @@ class CodeRetrieval(BaseToolModel):
code: str = Field(..., description="检索代码") code: str = Field(..., description="检索代码")
@classmethod @classmethod
def run(cls, code_base_name, query, code_limit=1, history_node_list=[], search_type="tag", llm_config: LLMConfig=None, embed_config: EmbedConfig=None): def run(cls,
code_base_name,
query,
code_limit=1,
history_node_list=[],
search_type="tag",
llm_config: LLMConfig=None,
embed_config: EmbedConfig=None,
use_nh: str=True,
local_graph_path: str=''
):
"""excute your tool!""" """excute your tool!"""
search_type = { search_type = {
@ -45,7 +53,8 @@ class CodeRetrieval(BaseToolModel):
codes = search_code(code_base_name, query, code_limit, search_type=search_type, history_node_list=history_node_list, codes = search_code(code_base_name, query, code_limit, search_type=search_type, history_node_list=history_node_list,
embed_engine=embed_config.embed_engine, embed_model=embed_config.embed_model, embed_model_path=embed_config.embed_model_path, embed_engine=embed_config.embed_engine, embed_model=embed_config.embed_model, embed_model_path=embed_config.embed_model_path,
model_device=embed_config.model_device, model_name=llm_config.model_name, temperature=llm_config.temperature, model_device=embed_config.model_device, model_name=llm_config.model_name, temperature=llm_config.temperature,
api_base_url=llm_config.api_base_url, api_key=llm_config.api_key api_base_url=llm_config.api_base_url, api_key=llm_config.api_key, use_nh=use_nh,
local_graph_path=local_graph_path, embed_config=embed_config
) )
return_codes = [] return_codes = []
context = codes['context'] context = codes['context']

View File

@ -5,6 +5,7 @@
@time: 2023/12/14 10:24 AM @time: 2023/12/14 10:24 AM
@desc: @desc:
''' '''
import os
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from loguru import logger from loguru import logger
@ -40,10 +41,9 @@ class CodeRetrievalSingle(BaseToolModel):
vertex: str = Field(..., description="代码对应 id") vertex: str = Field(..., description="代码对应 id")
@classmethod @classmethod
def run(cls, code_base_name, query, embed_config: EmbedConfig, llm_config: LLMConfig, **kargs): def run(cls, code_base_name, query, embed_config: EmbedConfig, llm_config: LLMConfig, search_type="description", **kargs):
"""excute your tool!""" """excute your tool!"""
search_type = 'description'
code_limit = 1 code_limit = 1
# default # default
@ -51,10 +51,11 @@ class CodeRetrievalSingle(BaseToolModel):
history_node_list=[], history_node_list=[],
embed_engine=embed_config.embed_engine, embed_model=embed_config.embed_model, embed_model_path=embed_config.embed_model_path, embed_engine=embed_config.embed_engine, embed_model=embed_config.embed_model, embed_model_path=embed_config.embed_model_path,
model_device=embed_config.model_device, model_name=llm_config.model_name, temperature=llm_config.temperature, model_device=embed_config.model_device, model_name=llm_config.model_name, temperature=llm_config.temperature,
api_base_url=llm_config.api_base_url, api_key=llm_config.api_key api_base_url=llm_config.api_base_url, api_key=llm_config.api_key, embed_config=embed_config, use_nh=kargs.get("use_nh", True),
local_graph_path=kargs.get("local_graph_path", "")
) )
if os.environ.get("log_verbose", "0") >= "3":
logger.debug(search_result) logger.debug(search_result)
code = search_result['context'] code = search_result['context']
vertex = search_result['related_vertices'][0] vertex = search_result['related_vertices'][0]
# logger.debug(f"code: {code}, vertex: {vertex}") # logger.debug(f"code: {code}, vertex: {vertex}")
@ -83,7 +84,7 @@ class RelatedVerticesRetrival(BaseToolModel):
def run(cls, code_base_name: str, vertex: str, **kargs): def run(cls, code_base_name: str, vertex: str, **kargs):
"""execute your tool!""" """execute your tool!"""
related_vertices = search_related_vertices(cb_name=code_base_name, vertex=vertex) related_vertices = search_related_vertices(cb_name=code_base_name, vertex=vertex)
logger.debug(f"related_vertices: {related_vertices}") # logger.debug(f"related_vertices: {related_vertices}")
return related_vertices return related_vertices
@ -110,6 +111,6 @@ class Vertex2Code(BaseToolModel):
else: else:
vertex = vertex.strip(' "') vertex = vertex.strip(' "')
logger.info(f'vertex={vertex}') # logger.info(f'vertex={vertex}')
res = search_code_by_vertex(cb_name=code_base_name, vertex=vertex) res = search_code_by_vertex(cb_name=code_base_name, vertex=vertex)
return res return res

View File

@ -2,11 +2,7 @@ from pydantic import BaseModel, Field
from loguru import logger from loguru import logger
from coagent.llm_models.llm_config import EmbedConfig from coagent.llm_models.llm_config import EmbedConfig
from .base_tool import BaseToolModel from .base_tool import BaseToolModel
from coagent.service.kb_api import search_docs from coagent.service.kb_api import search_docs

View File

@ -9,8 +9,10 @@ import numpy as np
from loguru import logger from loguru import logger
from .base_tool import BaseToolModel from .base_tool import BaseToolModel
try:
from duckduckgo_search import DDGS from duckduckgo_search import DDGS
except ImportError:
logger.warning("can't find duckduckgo_search; if you need it, please `pip install duckduckgo_search`")
class DDGSTool(BaseToolModel): class DDGSTool(BaseToolModel):

View File

@ -0,0 +1,89 @@
import json
def class_info_decode(data):
'''parse class-related information'''
params_dict = {}
for i in data:
_params_dict = {}
for ii in i:
for k, v in ii.items():
if k=="origin_query": continue
if k == "Code Path":
_params_dict["code_path"] = v.split("#")[0]
_params_dict["function_name"] = ".".join(v.split("#")[1:])
if k == "Class Description":
_params_dict["ClassDescription"] = v
if k == "Class Base":
_params_dict["ClassBase"] = v
if k=="Init Parameters":
_params_dict["Parameters"] = v
code_path = _params_dict["code_path"]
params_dict.setdefault(code_path, []).append(_params_dict)
return params_dict
def method_info_decode(data):
params_dict = {}
for i in data:
_params_dict = {}
for ii in i:
for k, v in ii.items():
if k=="origin_query": continue
if k == "Code Path":
_params_dict["code_path"] = v.split("#")[0]
_params_dict["function_name"] = ".".join(v.split("#")[1:])
if k == "Return Value Description":
_params_dict["Returns"] = v
if k == "Return Type":
_params_dict["ReturnType"] = v
if k=="Parameters":
_params_dict["Parameters"] = v
code_path = _params_dict["code_path"]
params_dict.setdefault(code_path, []).append(_params_dict)
return params_dict
def encode2md(data, md_format):
md_dict = {}
for code_path, params_list in data.items():
for params in params_list:
params["Parameters_text"] = "\n".join([f"{param['param']}({param['param_type']})-{param['param_description']}"
for param in params["Parameters"]])
# params.delete("Parameters")
text=md_format.format(**params)
md_dict.setdefault(code_path, []).append(text)
return md_dict
method_text_md = '''> {function_name}
| Column Name | Content |
|-----------------|-----------------|
| Parameters | {Parameters_text} |
| Returns | {Returns} |
| Return type | {ReturnType} |
'''
class_text_md = '''> {code_path}
Bases: {ClassBase}
{ClassDescription}
{Parameters_text}
'''
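A hedged usage sketch of this module end to end. The payload below is invented for illustration; it only mirrors the shape the decoders iterate over (a list of rounds, each a list of single-field dicts), which in practice comes from the doc-generation agents' parsed outputs:

# assumes class_info_decode / method_info_decode / encode2md / method_text_md above are in scope
raw_method = [[
    {"Code Path": "com/example/StringUtils.java#remove"},
    {"Parameters": [{"param": "s", "param_type": "String",
                     "param_description": "the input string"}]},
    {"Return Value Description": "the input with the target substring removed"},
    {"Return Type": "String"},
]]

method_data = method_info_decode(raw_method)          # group fields by code_path
method_mds = encode2md(method_data, method_text_md)   # render one markdown table per method
for code_path, mds in method_mds.items():
    print(code_path)
    print(mds[0])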

View File

@ -7,7 +7,7 @@ from pathlib import Path
from io import BytesIO from io import BytesIO
from fastapi import Body, File, Form, Body, Query, UploadFile from fastapi import Body, File, Form, Body, Query, UploadFile
from tempfile import SpooledTemporaryFile from tempfile import SpooledTemporaryFile
import json
DATE_FORMAT = "%Y-%m-%d %H:%M:%S" DATE_FORMAT = "%Y-%m-%d %H:%M:%S"
@ -110,3 +110,5 @@ def get_uploadfile(file: Union[str, Path, bytes], filename=None) -> UploadFile:
temp_file.write(file.read()) temp_file.write(file.read())
temp_file.seek(0) temp_file.seek(0)
return UploadFile(file=temp_file, filename=filename) return UploadFile(file=temp_file, filename=filename)

View File

@ -1,7 +1,7 @@
import os import os
from langchain.document_loaders import CSVLoader, PyPDFLoader, UnstructuredFileLoader, TextLoader, PythonLoader from langchain.document_loaders import CSVLoader, PyPDFLoader, UnstructuredFileLoader, TextLoader, PythonLoader
from coagent.document_loaders import JSONLLoader, JSONLoader from coagent.retrieval.document_loaders import JSONLLoader, JSONLoader
# from configs.model_config import ( # from configs.model_config import (
# embedding_model_dict, # embedding_model_dict,
# KB_ROOT_PATH, # KB_ROOT_PATH,

View File

@ -21,17 +21,20 @@ JUPYTER_WORK_PATH = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath
# WEB_CRAWL storage path # WEB_CRAWL storage path
WEB_CRAWL_PATH = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "knowledge_base") WEB_CRAWL_PATH = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "knowledge_base")
# NEBULA_DATA storage path # NEBULA_DATA storage path
NELUBA_PATH = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "data/neluba_data") NEBULA_PATH = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "data/nebula_data")
for _path in [LOG_PATH, SOURCE_PATH, KB_ROOT_PATH, NLTK_DATA_PATH, JUPYTER_WORK_PATH, WEB_CRAWL_PATH, NELUBA_PATH]: # CHROMA storage path
CHROMA_PERSISTENT_PATH = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "data/chroma_data")
for _path in [LOG_PATH, SOURCE_PATH, KB_ROOT_PATH, CB_ROOT_PATH, NLTK_DATA_PATH, JUPYTER_WORK_PATH, WEB_CRAWL_PATH, NEBULA_PATH, CHROMA_PERSISTENT_PATH]:
if not os.path.exists(_path): if not os.path.exists(_path):
os.makedirs(_path, exist_ok=True) os.makedirs(_path, exist_ok=True)
#
path_envt_dict = { path_envt_dict = {
"LOG_PATH": LOG_PATH, "SOURCE_PATH": SOURCE_PATH, "KB_ROOT_PATH": KB_ROOT_PATH, "LOG_PATH": LOG_PATH, "SOURCE_PATH": SOURCE_PATH, "KB_ROOT_PATH": KB_ROOT_PATH,
"NLTK_DATA_PATH":NLTK_DATA_PATH, "JUPYTER_WORK_PATH": JUPYTER_WORK_PATH, "NLTK_DATA_PATH":NLTK_DATA_PATH, "JUPYTER_WORK_PATH": JUPYTER_WORK_PATH,
"WEB_CRAWL_PATH": WEB_CRAWL_PATH, "NELUBA_PATH": NELUBA_PATH "WEB_CRAWL_PATH": WEB_CRAWL_PATH, "NEBULA_PATH": NEBULA_PATH,
"CHROMA_PERSISTENT_PATH": CHROMA_PERSISTENT_PATH
} }
for path_name, _path in path_envt_dict.items(): for path_name, _path in path_envt_dict.items():
os.environ[path_name] = _path os.environ[path_name] = _path

View File

@ -33,7 +33,7 @@ except:
pass pass
# add your openai key # add your openai key
OPENAI_API_BASE = "http://openai.com/v1/chat/completions" OPENAI_API_BASE = "https://api.openai.com/v1"
os.environ["API_BASE_URL"] = OPENAI_API_BASE os.environ["API_BASE_URL"] = OPENAI_API_BASE
os.environ["OPENAI_API_KEY"] = "sk-xx" os.environ["OPENAI_API_KEY"] = "sk-xx"
openai.api_key = "sk-xx" openai.api_key = "sk-xx"

View File

@ -58,9 +58,6 @@ NEBULA_GRAPH_SERVER = {
"docker_port": NEBULA_PORT "docker_port": NEBULA_PORT
} }
# chroma conf
CHROMA_PERSISTENT_PATH = '/home/user/chatbot/data/chroma_data'
# sandbox api server # sandbox api server
SANDBOX_CONTRAINER_NAME = "devopsgpt_sandbox" SANDBOX_CONTRAINER_NAME = "devopsgpt_sandbox"
SANDBOX_IMAGE_NAME = "devopsgpt:py39" SANDBOX_IMAGE_NAME = "devopsgpt:py39"

View File

@ -15,11 +15,11 @@ from coagent.connector.schema import Message
# #
tools = toLangchainTools([TOOL_DICT[i] for i in TOOL_SETS if i in TOOL_DICT]) tools = toLangchainTools([TOOL_DICT[i] for i in TOOL_SETS if i in TOOL_DICT])
# log-level: print prompt and llm predict # log-level: print prompt and llm predict
os.environ["log_verbose"] = "0" os.environ["log_verbose"] = "2"
phase_name = "baseGroupPhase" phase_name = "baseGroupPhase"
llm_config = LLMConfig( llm_config = LLMConfig(
model_name=LLM_MODEL, model_device="cpu",api_key=os.environ["OPENAI_API_KEY"], model_name=LLM_MODEL, api_key=os.environ["OPENAI_API_KEY"],
api_base_url=os.environ["API_BASE_URL"], temperature=0.3 api_base_url=os.environ["API_BASE_URL"], temperature=0.3
) )
embed_config = EmbedConfig( embed_config = EmbedConfig(

View File

@ -17,7 +17,7 @@ os.environ["log_verbose"] = "2"
phase_name = "baseTaskPhase" phase_name = "baseTaskPhase"
llm_config = LLMConfig( llm_config = LLMConfig(
model_name="gpt-3.5-turbo", model_device="cpu",api_key=os.environ["OPENAI_API_KEY"], model_name="gpt-3.5-turbo", api_key=os.environ["OPENAI_API_KEY"],
api_base_url=os.environ["API_BASE_URL"], temperature=0.3 api_base_url=os.environ["API_BASE_URL"], temperature=0.3
) )
embed_config = EmbedConfig( embed_config = EmbedConfig(

View File

@ -0,0 +1,135 @@
# encoding: utf-8
'''
@author: 温进
@file: codeChatPhaseLocal_example.py
@time: 2024/1/31 4:32 PM
@desc:
'''
import os, sys, requests
from concurrent.futures import ThreadPoolExecutor
from tqdm import tqdm
from typing import List
src_dir = os.path.join(
os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
)
sys.path.append(src_dir)
from configs.model_config import KB_ROOT_PATH, JUPYTER_WORK_PATH, CB_ROOT_PATH
from configs.server_config import SANDBOX_SERVER
from coagent.tools import toLangchainTools, TOOL_DICT, TOOL_SETS
from coagent.llm_models.llm_config import EmbedConfig, LLMConfig
from coagent.connector.phase import BasePhase
from coagent.connector.schema import Message, Memory
from coagent.codechat.codebase_handler.codebase_handler import CodeBaseHandler
# log-level: print prompt and llm predict
os.environ["log_verbose"] = "1"
llm_config = LLMConfig(
model_name="gpt-3.5-turbo", model_device="cpu",api_key=os.environ["OPENAI_API_KEY"],
api_base_url=os.environ["API_BASE_URL"], temperature=0.3
)
embed_config = EmbedConfig(
embed_engine="model", embed_model="text2vec-base-chinese",
embed_model_path=os.path.join(src_dir, "embedding_models/text2vec-base-chinese")
)
# delete codebase
codebase_name = 'client_local'
code_path = '/Users/bingxu/Desktop/工作/大模型/chatbot/test_code_repo/client'
code_path = "D://chromeDownloads/devopschat-bot/client_v2/client"
use_nh = True
# cbh = CodeBaseHandler(codebase_name, code_path, crawl_type='dir', use_nh=use_nh, local_graph_path=CB_ROOT_PATH,
# llm_config=llm_config, embed_config=embed_config)
cbh = CodeBaseHandler(codebase_name, code_path, crawl_type='dir', use_nh=use_nh, local_graph_path=CB_ROOT_PATH,
llm_config=llm_config, embed_config=embed_config)
cbh.delete_codebase(codebase_name=codebase_name)
# initialize codebase
codebase_name = 'client_local'
code_path = '/Users/bingxu/Desktop/工作/大模型/chatbot/test_code_repo/client'
code_path = "D://chromeDownloads/devopschat-bot/client_v2/client"
code_path = "/home/user/client"
use_nh = True
do_interpret = True
cbh = CodeBaseHandler(codebase_name, code_path, crawl_type='dir', use_nh=use_nh, local_graph_path=CB_ROOT_PATH,
llm_config=llm_config, embed_config=embed_config)
cbh.import_code(do_interpret=do_interpret)
# chat with codebase
phase_name = "codeChatPhase"
phase = BasePhase(
phase_name, sandbox_server=SANDBOX_SERVER, jupyter_work_path=JUPYTER_WORK_PATH,
embed_config=embed_config, llm_config=llm_config, kb_root_path=KB_ROOT_PATH,
)
# "What does the remove function do?" => tag-based search
# "Is there an existing function that removes a given substring from a string? How would I use it? Write some Java code." => description-based search
# "Per my requirement below, develop a Java method: take a string input, delete the '.java' substring from it, and return the new string." => description-based search
## Requires the nebula service started in the container; building the codebase with use_nh=True makes it queryable via cypher
# round-1
# query_content = "代码一共有多少类"
# query = Message(
# role_name="human", role_type="user",
# role_content=query_content, input_query=query_content, origin_query=query_content,
# code_engine_name="client_1", score_threshold=1.0, top_k=3, cb_search_type="cypher"
# )
#
# output_message1, _ = phase.step(query)
# print(output_message1)
# round-2
# query_content = "代码库里有哪些函数返回5个就行"
# query = Message(
# role_name="human", role_type="user",
# role_content=query_content, input_query=query_content, origin_query=query_content,
# code_engine_name="client_1", score_threshold=1.0, top_k=3, cb_search_type="cypher"
# )
# output_message2, _ = phase.step(query)
# print(output_message2)
# round-3
query_content = "remove 这个函数是做什么的"
query = Message(
role_name="user", role_type="human",
role_content=query_content, input_query=query_content, origin_query=query_content,
code_engine_name=codebase_name, score_threshold=1.0, top_k=3, cb_search_type="tag",
use_nh=False, local_graph_path=CB_ROOT_PATH
)
output_message3, output_memory3 = phase.step(query)
print(output_memory3.to_str_messages(return_all=True, content_key="parsed_output_list"))
#
# # round-4
query_content = "有没有函数已经实现了从字符串删除指定字符串的功能使用的话可以怎么使用写个java代码"
query = Message(
role_name="human", role_type="user",
role_content=query_content, input_query=query_content, origin_query=query_content,
code_engine_name=codebase_name, score_threshold=1.0, top_k=3, cb_search_type="description",
use_nh=False, local_graph_path=CB_ROOT_PATH
)
output_message4, output_memory4 = phase.step(query)
print(output_memory4.to_str_messages(return_all=True, content_key="parsed_output_list"))
# # round-5
query_content = "有根据我以下的需求用 java 开发一个方法:输入为字符串,将输入中的 .java 字符串给删除掉,然后返回新的字符串"
query = Message(
role_name="human", role_type="user",
role_content=query_content, input_query=query_content, origin_query=query_content,
code_engine_name=codebase_name, score_threshold=1.0, top_k=3, cb_search_type="description",
use_nh=False, local_graph_path=CB_ROOT_PATH
)
output_message5, output_memory5 = phase.step(query)
print(output_memory5.to_str_messages(return_all=True, content_key="parsed_output_list"))
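Note the split in this demo: the codebase is imported with use_nh=True, while every query Message sets use_nh=False and local_graph_path=CB_ROOT_PATH, so the three rounds above retrieve from the locally persisted graph rather than a live NebulaGraph service.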

View File

@ -17,13 +17,14 @@ os.environ["log_verbose"] = "2"
phase_name = "codeChatPhase" phase_name = "codeChatPhase"
llm_config = LLMConfig( llm_config = LLMConfig(
model_name="gpt-3.5-turbo", model_device="cpu",api_key=os.environ["OPENAI_API_KEY"], model_name="gpt-3.5-turbo", api_key=os.environ["OPENAI_API_KEY"],
api_base_url=os.environ["API_BASE_URL"], temperature=0.3 api_base_url=os.environ["API_BASE_URL"], temperature=0.3
) )
embed_config = EmbedConfig( embed_config = EmbedConfig(
embed_engine="model", embed_model="text2vec-base-chinese", embed_engine="model", embed_model="text2vec-base-chinese",
embed_model_path=os.path.join(src_dir, "embedding_models/text2vec-base-chinese") embed_model_path=os.path.join(src_dir, "embedding_models/text2vec-base-chinese")
) )
phase = BasePhase( phase = BasePhase(
phase_name, sandbox_server=SANDBOX_SERVER, jupyter_work_path=JUPYTER_WORK_PATH, phase_name, sandbox_server=SANDBOX_SERVER, jupyter_work_path=JUPYTER_WORK_PATH,
embed_config=embed_config, llm_config=llm_config, kb_root_path=KB_ROOT_PATH, embed_config=embed_config, llm_config=llm_config, kb_root_path=KB_ROOT_PATH,
@ -35,25 +36,28 @@ phase = BasePhase(
# "Per my requirement below, develop a Java method: take a string input, delete the '.java' substring from it, and return the new string." => description-based # "Per my requirement below, develop a Java method: take a string input, delete the '.java' substring from it, and return the new string." => description-based
# round-1 # round-1
query_content = "代码一共有多少类" # query_content = "代码一共有多少类"
query = Message( # query = Message(
role_name="human", role_type="user", # role_name="human", role_type="user",
role_content=query_content, input_query=query_content, origin_query=query_content, # role_content=query_content, input_query=query_content, origin_query=query_content,
code_engine_name="client", score_threshold=1.0, top_k=3, cb_search_type="cypher" # code_engine_name="client_1", score_threshold=1.0, top_k=3, cb_search_type="cypher"
) # )
#
output_message1, _ = phase.step(query) # output_message1, _ = phase.step(query)
# print(output_message1)
# round-2 # round-2
query_content = "代码库里有哪些函数返回5个就行" # query_content = "代码库里有哪些函数返回5个就行"
query = Message( # query = Message(
role_name="human", role_type="user", # role_name="human", role_type="user",
role_content=query_content, input_query=query_content, origin_query=query_content, # role_content=query_content, input_query=query_content, origin_query=query_content,
code_engine_name="client", score_threshold=1.0, top_k=3, cb_search_type="cypher" # code_engine_name="client_1", score_threshold=1.0, top_k=3, cb_search_type="cypher"
) # )
output_message2, _ = phase.step(query) # output_message2, _ = phase.step(query)
# print(output_message2)
# round-3 #
# # round-3
query_content = "remove 这个函数是做什么的" query_content = "remove 这个函数是做什么的"
query = Message( query = Message(
role_name="user", role_type="human", role_name="user", role_type="human",
@ -61,24 +65,27 @@ query = Message(
code_engine_name="client", score_threshold=1.0, top_k=3, cb_search_type="tag" code_engine_name="client", score_threshold=1.0, top_k=3, cb_search_type="tag"
) )
output_message3, _ = phase.step(query) output_message3, _ = phase.step(query)
print(output_message3)
# round-4 #
query_content = "有没有函数已经实现了从字符串删除指定字符串的功能使用的话可以怎么使用写个java代码" # # round-4
query = Message( # query_content = "有没有函数已经实现了从字符串删除指定字符串的功能使用的话可以怎么使用写个java代码"
role_name="human", role_type="user", # query = Message(
role_content=query_content, input_query=query_content, origin_query=query_content, # role_name="human", role_type="user",
code_engine_name="client", score_threshold=1.0, top_k=3, cb_search_type="description" # role_content=query_content, input_query=query_content, origin_query=query_content,
) # code_engine_name="client_1", score_threshold=1.0, top_k=3, cb_search_type="description"
output_message4, _ = phase.step(query) # )
# output_message4, _ = phase.step(query)
# print(output_message4)
# round-5 #
query_content = "有根据我以下的需求用 java 开发一个方法:输入为字符串,将输入中的 .java 字符串给删除掉,然后返回新的字符串" # # round-5
query = Message( # query_content = "有根据我以下的需求用 java 开发一个方法:输入为字符串,将输入中的 .java 字符串给删除掉,然后返回新的字符串"
role_name="human", role_type="user", # query = Message(
role_content=query_content, input_query=query_content, origin_query=query_content, # role_name="human", role_type="user",
code_engine_name="client", score_threshold=1.0, top_k=3, cb_search_type="description" # role_content=query_content, input_query=query_content, origin_query=query_content,
) # code_engine_name="client_1", score_threshold=1.0, top_k=3, cb_search_type="description"
output_message5, output_memory5 = phase.step(query) # )
# output_message5, output_memory5 = phase.step(query)
print(output_memory5.to_str_messages(return_all=True, content_key="parsed_output_list")) # print(output_message5)
#
# print(output_memory5.to_str_messages(return_all=True, content_key="parsed_output_list"))

View File

@ -0,0 +1,507 @@
import os, sys, json
from loguru import logger
src_dir = os.path.join(
os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
)
sys.path.append(src_dir)
from configs.model_config import KB_ROOT_PATH, JUPYTER_WORK_PATH, CB_ROOT_PATH
from configs.server_config import SANDBOX_SERVER
from coagent.llm_models.llm_config import EmbedConfig, LLMConfig
from coagent.connector.phase import BasePhase
from coagent.connector.agents import BaseAgent
from coagent.connector.schema import Message
from coagent.tools import CodeRetrievalSingle
from coagent.codechat.codebase_handler.codebase_handler import CodeBaseHandler
import importlib
# Define a new agent class
class CodeGenDocer(BaseAgent):
def start_action_step(self, message: Message) -> Message:
'''do action before agent predict'''
# retrieve the code snippet and vertex info based on the query
action_json = CodeRetrievalSingle.run(message.code_engine_name, message.origin_query,
llm_config=self.llm_config, embed_config=self.embed_config, local_graph_path=message.local_graph_path, use_nh=message.use_nh,search_type="tag")
current_vertex = action_json['vertex']
message.customed_kargs["Code Snippet"] = action_json["code"]
message.customed_kargs['Current_Vertex'] = current_vertex
return message
# add agent or prompt_manager class
agent_module = importlib.import_module("coagent.connector.agents")
setattr(agent_module, 'CodeGenDocer', CodeGenDocer)
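Registering the subclass on coagent.connector.agents via setattr makes it resolvable by name, presumably so the phase loader can turn a config string such as agent_type: "CodeGenDocer" (see the commented AGETN_CONFIGS below) back into a class.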
# log-level: print prompt and llm predict
os.environ["log_verbose"] = "1"
phase_name = "code2DocsGroup"
llm_config = LLMConfig(
model_name="gpt-4", api_key=os.environ["OPENAI_API_KEY"],
api_base_url=os.environ["API_BASE_URL"], temperature=0.3
)
embed_config = EmbedConfig(
embed_engine="model", embed_model="text2vec-base-chinese",
embed_model_path=os.path.join(src_dir, "embedding_models/text2vec-base-chinese")
)
# initialize codebase
# delete codebase
codebase_name = 'client_local'
code_path = "D://chromeDownloads/devopschat-bot/client_v2/client"
use_nh = False
cbh = CodeBaseHandler(codebase_name, code_path, crawl_type='dir', use_nh=use_nh, local_graph_path=CB_ROOT_PATH,
llm_config=llm_config, embed_config=embed_config)
cbh.delete_codebase(codebase_name=codebase_name)
# load codebase
codebase_name = 'client_local'
code_path = "D://chromeDownloads/devopschat-bot/client_v2/client"
use_nh = True
do_interpret = True
cbh = CodeBaseHandler(codebase_name, code_path, crawl_type='dir', use_nh=use_nh, local_graph_path=CB_ROOT_PATH,
llm_config=llm_config, embed_config=embed_config)
cbh.import_code(do_interpret=do_interpret)
# re-initialize following the load procedure above
cbh = CodeBaseHandler(codebase_name, code_path, crawl_type='dir', use_nh=use_nh, local_graph_path=CB_ROOT_PATH,
llm_config=llm_config, embed_config=embed_config)
phase = BasePhase(
phase_name, sandbox_server=SANDBOX_SERVER, jupyter_work_path=JUPYTER_WORK_PATH,
embed_config=embed_config, llm_config=llm_config, kb_root_path=KB_ROOT_PATH,
)
for vertex_type in ["class", "method"]:
vertexes = cbh.search_vertices(vertex_type=vertex_type)
logger.info(f"vertexes={vertexes}")
# round-1
docs = []
for vertex in vertexes:
vertex = vertex.split("-")[0]  # the suffix after "-" holds the method's parameters
query_content = f"{vertex_type}节点 {vertex}生成文档"
query = Message(
role_name="human", role_type="user",
role_content=query_content, input_query=query_content, origin_query=query_content,
code_engine_name="client_local", score_threshold=1.0, top_k=3, cb_search_type="tag", use_nh=use_nh,
local_graph_path=CB_ROOT_PATH,
)
output_message, output_memory = phase.step(query, reinit_memory=True)
# print(output_memory.to_str_messages(return_all=True, content_key="parsed_output_list"))
docs.append(output_memory.get_spec_parserd_output())
os.makedirs(f"{CB_ROOT_PATH}/docs", exist_ok=True)
with open(f"{CB_ROOT_PATH}/docs/raw_{vertex_type}.json", "w") as f:
json.dump(docs, f)
# convert the generated doc info into markdown text below
from coagent.utils.code2doc_util import *
import json
with open(f"{CB_ROOT_PATH}/docs/raw_method.json", "r") as f:
method_raw_data = json.load(f)
with open(f"{CB_ROOT_PATH}/docs/raw_class.json", "r") as f:
class_raw_data = json.load(f)
method_data = method_info_decode(method_raw_data)
class_data = class_info_decode(class_raw_data)
method_mds = encode2md(method_data, method_text_md)
class_mds = encode2md(class_data, class_text_md)
docs_dict = {}
for k,v in class_mds.items():
method_textmds = method_mds.get(k, [])
for vv in v:
# in theory there is only one
text_md = vv
for method_textmd in method_textmds:
text_md += "\n<br>" + method_textmd
docs_dict.setdefault(k, []).append(text_md)
with open(f"{CB_ROOT_PATH}//docs/{k}.md", "w") as f:
f.write(text_md)
####################################
######## Full reproduction walkthrough below ########
####################################
# import os, sys, requests
# from loguru import logger
# src_dir = os.path.join(
# os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
# )
# sys.path.append(src_dir)
# from configs.model_config import KB_ROOT_PATH, JUPYTER_WORK_PATH, CB_ROOT_PATH
# from configs.server_config import SANDBOX_SERVER
# from coagent.tools import toLangchainTools, TOOL_DICT, TOOL_SETS
# from coagent.llm_models.llm_config import EmbedConfig, LLMConfig
# from coagent.connector.phase import BasePhase
# from coagent.connector.agents import BaseAgent, SelectorAgent
# from coagent.connector.chains import BaseChain
# from coagent.connector.schema import (
# Message, Memory, load_role_configs, load_phase_configs, load_chain_configs, ActionStatus
# )
# from coagent.connector.memory_manager import BaseMemoryManager
# from coagent.connector.configs import AGETN_CONFIGS, CHAIN_CONFIGS, PHASE_CONFIGS, BASE_PROMPT_CONFIGS
# from coagent.connector.prompt_manager.prompt_manager import PromptManager
# from coagent.codechat.codebase_handler.codebase_handler import CodeBaseHandler
# import importlib
# from loguru import logger
# from coagent.tools import CodeRetrievalSingle, RelatedVerticesRetrival, Vertex2Code
# # update new agent configs
# codeGenDocGroup_PROMPT = """#### Agent Profile
# Your goal is to respond, according to the Context Data's information, with the role that will best facilitate a solution, taking into account all relevant context (Context) provided.
# When you need to select the appropriate role for handling a user's query, carefully read the provided role names, role descriptions and tool list.
# ATTENTION: response carefully referenced "Response Output Format" in format.
# #### Input Format
# #### Response Output Format
# **Code Path:** Extract the paths for the class/method/function that need to be addressed from the context
# **Role:** Select the role from agent names
# """
# classGenDoc_PROMPT = """#### Agent Profile
# As an advanced code documentation generator, you are proficient in translating class definitions into comprehensive documentation with a focus on instantiation parameters.
# Your specific task is to parse the given code snippet of a class, extract information regarding its instantiation parameters.
# ATTENTION: response carefully in "Response Output Format".
# #### Input Format
# **Code Snippet:** Provide the full class definition, including the constructor and any parameters it may require for instantiation.
# #### Response Output Format
# **Class Base:** Specify the base class or interface from which the current class extends, if any.
# **Class Description:** Offer a brief description of the class's purpose and functionality.
# **Init Parameters:** List each parameter of the constructor. For each parameter, provide:
# - `param`: The parameter name
# - `param_description`: A concise explanation of the parameter's purpose.
# - `param_type`: The data type of the parameter, if explicitly defined.
# ```json
# [
# {
# "param": "parameter_name",
# "param_description": "A brief description of what this parameter is used for.",
# "param_type": "The data type of the parameter"
# },
# ...
# ]
# ```
# If the constructor has no parameters, return
# ```json
# []
# ```
# """
# funcGenDoc_PROMPT = """#### Agent Profile
# You are a high-level code documentation assistant, skilled at extracting information from function/method code into detailed and well-structured documentation.
# ATTENTION: response carefully in "Response Output Format".
# #### Input Format
# **Code Path:** Provide the code path of the function or method you wish to document.
# This name will be used to identify and extract the relevant details from the code snippet provided.
# **Code Snippet:** A segment of code that contains the function or method to be documented.
# #### Response Output Format
# **Class Description:** Offer a brief description of the method(function)'s purpose and functionality.
# **Parameters:** Extract the parameters of the specific function/method from the Code Snippet. For each parameter, provide:
# - `param`: The parameter name
# - `param_description`: A concise explanation of the parameter's purpose.
# - `param_type`: The data type of the parameter, if explicitly defined.
# ```json
# [
# {
# "param": "parameter_name",
# "param_description": "A brief description of what this parameter is used for.",
# "param_type": "The data type of the parameter"
# },
# ...
# ]
# ```
# If the function/method has no parameters, return
# ```json
# []
# ```
# **Return Value Description:** Describe what the function/method returns upon completion.
# **Return Type:** Indicate the type of data the function/method returns (e.g., string, integer, object, void).
# """
# CODE_GENERATE_GROUP_PROMPT_CONFIGS = [
# {"field_name": 'agent_profile', "function_name": 'handle_agent_profile', "is_context": False},
# {"field_name": 'agent_infomation', "function_name": 'handle_agent_data', "is_context": False, "omit_if_empty": False},
# # {"field_name": 'tool_information',"function_name": 'handle_tool_data', "is_context": False},
# {"field_name": 'context_placeholder', "function_name": '', "is_context": True},
# # {"field_name": 'reference_documents', "function_name": 'handle_doc_info'},
# {"field_name": 'session_records', "function_name": 'handle_session_records'},
# {"field_name": 'Specific Objective', "function_name": 'handle_specific_objective'},
# {"field_name": 'Code Snippet', "function_name": 'handle_code_snippet'},
# {"field_name": 'output_format', "function_name": 'handle_output_format', 'title': 'Response Output Format', "is_context": False},
# {"field_name": 'begin!!!', "function_name": 'handle_response', "is_context": False, "omit_if_empty": False}
# ]
# CODE_GENERATE_DOC_PROMPT_CONFIGS = [
# {"field_name": 'agent_profile', "function_name": 'handle_agent_profile', "is_context": False},
# # {"field_name": 'tool_information',"function_name": 'handle_tool_data', "is_context": False},
# {"field_name": 'context_placeholder', "function_name": '', "is_context": True},
# # {"field_name": 'reference_documents', "function_name": 'handle_doc_info'},
# {"field_name": 'session_records', "function_name": 'handle_session_records'},
# {"field_name": 'Specific Objective', "function_name": 'handle_specific_objective'},
# {"field_name": 'Code Snippet', "function_name": 'handle_code_snippet'},
# {"field_name": 'output_format', "function_name": 'handle_output_format', 'title': 'Response Output Format', "is_context": False},
# {"field_name": 'begin!!!', "function_name": 'handle_response', "is_context": False, "omit_if_empty": False}
# ]
# class CodeGenDocPM(PromptManager):
# def handle_code_snippet(self, **kwargs) -> str:
# if 'previous_agent_message' not in kwargs:
# return ""
# previous_agent_message: Message = kwargs['previous_agent_message']
# code_snippet = previous_agent_message.customed_kargs.get("Code Snippet", "")
# current_vertex = previous_agent_message.customed_kargs.get("Current_Vertex", "")
# instruction = "A segment of code that contains the function or method to be documented.\n"
# return instruction + "\n" + f"name: {current_vertex}\n{code_snippet}"
# def handle_specific_objective(self, **kwargs) -> str:
# if 'previous_agent_message' not in kwargs:
# return ""
# previous_agent_message: Message = kwargs['previous_agent_message']
# specific_objective = previous_agent_message.parsed_output.get("Code Path")
# instruction = "Provide the code path of the function or method you wish to document.\n"
# s = instruction + f"\n{specific_objective}"
# return s
# from coagent.tools import CodeRetrievalSingle
# # Define a new agent class
# class CodeGenDocer(BaseAgent):
# def start_action_step(self, message: Message) -> Message:
# '''do action before agent predict'''
# # retrieve the code snippet and vertex info based on the query
# action_json = CodeRetrievalSingle.run(message.code_engine_name, message.origin_query,
# llm_config=self.llm_config, embed_config=self.embed_config, local_graph_path=message.local_graph_path, use_nh=message.use_nh,search_type="tag")
# current_vertex = action_json['vertex']
# message.customed_kargs["Code Snippet"] = action_json["code"]
# message.customed_kargs['Current_Vertex'] = current_vertex
# return message
# # add agent or prompt_manager class
# agent_module = importlib.import_module("coagent.connector.agents")
# prompt_manager_module = importlib.import_module("coagent.connector.prompt_manager")
# setattr(agent_module, 'CodeGenDocer', CodeGenDocer)
# setattr(prompt_manager_module, 'CodeGenDocPM', CodeGenDocPM)
# AGETN_CONFIGS.update({
# "classGenDoc": {
# "role": {
# "role_prompt": classGenDoc_PROMPT,
# "role_type": "assistant",
# "role_name": "classGenDoc",
# "role_desc": "",
# "agent_type": "CodeGenDocer"
# },
# "prompt_config": CODE_GENERATE_DOC_PROMPT_CONFIGS,
# "prompt_manager_type": "CodeGenDocPM",
# "chat_turn": 1,
# "focus_agents": [],
# "focus_message_keys": [],
# },
# "funcGenDoc": {
# "role": {
# "role_prompt": funcGenDoc_PROMPT,
# "role_type": "assistant",
# "role_name": "funcGenDoc",
# "role_desc": "",
# "agent_type": "CodeGenDocer"
# },
# "prompt_config": CODE_GENERATE_DOC_PROMPT_CONFIGS,
# "prompt_manager_type": "CodeGenDocPM",
# "chat_turn": 1,
# "focus_agents": [],
# "focus_message_keys": [],
# },
# "codeGenDocsGrouper": {
# "role": {
# "role_prompt": codeGenDocGroup_PROMPT,
# "role_type": "assistant",
# "role_name": "codeGenDocsGrouper",
# "role_desc": "",
# "agent_type": "SelectorAgent"
# },
# "prompt_config": CODE_GENERATE_GROUP_PROMPT_CONFIGS,
# "group_agents": ["classGenDoc", "funcGenDoc"],
# "chat_turn": 1,
# },
# })
# # update new chain configs
# CHAIN_CONFIGS.update({
# "codeGenDocsGroupChain": {
# "chain_name": "codeGenDocsGroupChain",
# "chain_type": "BaseChain",
# "agents": ["codeGenDocsGrouper"],
# "chat_turn": 1,
# "do_checker": False,
# "chain_prompt": ""
# }
# })
# # update phase configs
# PHASE_CONFIGS.update({
# "codeGenDocsGroup": {
# "phase_name": "codeGenDocsGroup",
# "phase_type": "BasePhase",
# "chains": ["codeGenDocsGroupChain"],
# "do_summary": False,
# "do_search": False,
# "do_doc_retrieval": False,
# "do_code_retrieval": False,
# "do_tool_retrieval": False,
# },
# })
# role_configs = load_role_configs(AGETN_CONFIGS)
# chain_configs = load_chain_configs(CHAIN_CONFIGS)
# phase_configs = load_phase_configs(PHASE_CONFIGS)
# # log-level: print prompt and llm predict
# os.environ["log_verbose"] = "1"
# phase_name = "codeGenDocsGroup"
# llm_config = LLMConfig(
# model_name="gpt-4", api_key=os.environ["OPENAI_API_KEY"],
# api_base_url=os.environ["API_BASE_URL"], temperature=0.3
# )
# embed_config = EmbedConfig(
# embed_engine="model", embed_model="text2vec-base-chinese",
# embed_model_path=os.path.join(src_dir, "embedding_models/text2vec-base-chinese")
# )
# # initialize codebase
# # delete codebase
# codebase_name = 'client_local'
# code_path = "D://chromeDownloads/devopschat-bot/client_v2/client"
# use_nh = False
# cbh = CodeBaseHandler(codebase_name, code_path, crawl_type='dir', use_nh=use_nh, local_graph_path=CB_ROOT_PATH,
# llm_config=llm_config, embed_config=embed_config)
# cbh.delete_codebase(codebase_name=codebase_name)
# # load codebase
# codebase_name = 'client_local'
# code_path = "D://chromeDownloads/devopschat-bot/client_v2/client"
# use_nh = False
# do_interpret = True
# cbh = CodeBaseHandler(codebase_name, code_path, crawl_type='dir', use_nh=use_nh, local_graph_path=CB_ROOT_PATH,
# llm_config=llm_config, embed_config=embed_config)
# cbh.import_code(do_interpret=do_interpret)
# phase = BasePhase(
# phase_name, sandbox_server=SANDBOX_SERVER, jupyter_work_path=JUPYTER_WORK_PATH,
# embed_config=embed_config, llm_config=llm_config, kb_root_path=KB_ROOT_PATH,
# )
# for vertex_type in ["class", "method"]:
# vertexes = cbh.search_vertices(vertex_type=vertex_type)
# logger.info(f"vertexes={vertexes}")
# # round-1
# docs = []
# for vertex in vertexes:
# vertex = vertex.split("-")[0] # the suffix after "-" holds the method's parameters
# query_content = f"为{vertex_type}节点 {vertex}生成文档"
# query = Message(
# role_name="human", role_type="user",
# role_content=query_content, input_query=query_content, origin_query=query_content,
# code_engine_name="client_local", score_threshold=1.0, top_k=3, cb_search_type="tag", use_nh=use_nh,
# local_graph_path=CB_ROOT_PATH,
# )
# output_message, output_memory = phase.step(query, reinit_memory=True)
# # print(output_memory.to_str_messages(return_all=True, content_key="parsed_output_list"))
# docs.append(output_memory.get_spec_parserd_output())
# import json
# os.makedirs("/home/user/code_base/docs", exist_ok=True)
# with open(f"/home/user/code_base/docs/raw_{vertex_type}.json", "w") as f:
# json.dump(docs, f)
# # convert the generated doc info into markdown text below
# from coagent.utils.code2doc_util import *
# import json
# with open(f"/home/user/code_base/docs/raw_method.json", "r") as f:
# method_raw_data = json.load(f)
# with open(f"/home/user/code_base/docs/raw_class.json", "r") as f:
# class_raw_data = json.load(f)
# method_data = method_info_decode(method_raw_data)
# class_data = class_info_decode(class_raw_data)
# method_mds = encode2md(method_data, method_text_md)
# class_mds = encode2md(class_data, class_text_md)
# docs_dict = {}
# for k,v in class_mds.items():
# method_textmds = method_mds.get(k, [])
# for vv in v:
# # in theory there is only one
# text_md = vv
# for method_textmd in method_textmds:
# text_md += "\n<br>" + method_textmd
# docs_dict.setdefault(k, []).append(text_md)
# with open(f"/home/user/code_base/docs/{k}.md", "w") as f:
# f.write(text_md)

View File

@ -0,0 +1,444 @@
import os, sys, json
from loguru import logger
src_dir = os.path.join(
os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
)
sys.path.append(src_dir)
from configs.model_config import KB_ROOT_PATH, JUPYTER_WORK_PATH, CB_ROOT_PATH
from configs.server_config import SANDBOX_SERVER
from coagent.llm_models.llm_config import EmbedConfig, LLMConfig
from coagent.connector.phase import BasePhase
from coagent.connector.agents import BaseAgent
from coagent.connector.schema import Message
from coagent.codechat.codebase_handler.codebase_handler import CodeBaseHandler
import importlib
from coagent.tools import CodeRetrievalSingle, RelatedVerticesRetrival, Vertex2Code
# Define a new agent class
class CodeRetrieval(BaseAgent):
def start_action_step(self, message: Message) -> Message:
'''do action before agent predict'''
# retrieve the code snippet and vertex info based on the query
action_json = CodeRetrievalSingle.run(message.code_engine_name, message.origin_query, llm_config=self.llm_config, embed_config=self.embed_config, search_type="tag",
local_graph_path=message.local_graph_path, use_nh=message.use_nh)
current_vertex = action_json['vertex']
message.customed_kargs["Code Snippet"] = action_json["code"]
message.customed_kargs['Current_Vertex'] = current_vertex
# fetch neighboring vertices
action_json = RelatedVerticesRetrival.run(message.code_engine_name, message.customed_kargs['Current_Vertex'])
# fetch the code for all neighboring vertices
relative_vertex = []
retrieval_Codes = []
for vertex in action_json["vertices"]:
# code is stored at file level, so skip code from the same file
# logger.debug(f"{current_vertex}, {vertex}")
current_vertex_name = current_vertex.replace("#", "").replace(".java", "" ) if current_vertex.endswith(".java") else current_vertex
if current_vertex_name.split("#")[0] == vertex.split("#")[0]: continue
action_json = Vertex2Code.run(message.code_engine_name, vertex)
if action_json["code"]:
retrieval_Codes.append(action_json["code"])
relative_vertex.append(vertex)
#
message.customed_kargs["Retrieval_Codes"] = retrieval_Codes
message.customed_kargs["Relative_vertex"] = relative_vertex
return message
# add agent or prompt_manager class
agent_module = importlib.import_module("coagent.connector.agents")
setattr(agent_module, 'CodeRetrieval', CodeRetrieval)
llm_config = LLMConfig(
model_name="gpt-4", api_key=os.environ["OPENAI_API_KEY"],
api_base_url=os.environ["API_BASE_URL"], temperature=0.3
)
embed_config = EmbedConfig(
embed_engine="model", embed_model="text2vec-base-chinese",
embed_model_path=os.path.join(src_dir, "embedding_models/text2vec-base-chinese")
)
## initialize codebase
# delete codebase
codebase_name = 'client_local'
code_path = "D://chromeDownloads/devopschat-bot/client_v2/client"
use_nh = False
cbh = CodeBaseHandler(codebase_name, code_path, crawl_type='dir', use_nh=use_nh, local_graph_path=CB_ROOT_PATH,
llm_config=llm_config, embed_config=embed_config)
cbh.delete_codebase(codebase_name=codebase_name)
# load codebase
codebase_name = 'client_local'
code_path = "D://chromeDownloads/devopschat-bot/client_v2/client"
use_nh = True
do_interpret = True
cbh = CodeBaseHandler(codebase_name, code_path, crawl_type='dir', use_nh=use_nh, local_graph_path=CB_ROOT_PATH,
llm_config=llm_config, embed_config=embed_config)
cbh.import_code(do_interpret=do_interpret)
cbh = CodeBaseHandler(codebase_name, code_path, crawl_type='dir', use_nh=use_nh, local_graph_path=CB_ROOT_PATH,
llm_config=llm_config, embed_config=embed_config)
vertexes = cbh.search_vertices(vertex_type="class")
# log-level: print prompt and llm predict
os.environ["log_verbose"] = "0"
phase_name = "code2Tests"
phase = BasePhase(
phase_name, sandbox_server=SANDBOX_SERVER, jupyter_work_path=JUPYTER_WORK_PATH,
embed_config=embed_config, llm_config=llm_config, kb_root_path=KB_ROOT_PATH,
)
# round-1
logger.debug(vertexes)
test_cases = []
for vertex in vertexes:
query_content = f"{vertex}生成可执行的测例 "
query = Message(
role_name="human", role_type="user",
role_content=query_content, input_query=query_content, origin_query=query_content,
code_engine_name=codebase_name, score_threshold=1.0, top_k=3, cb_search_type="tag",
use_nh=use_nh, local_graph_path=CB_ROOT_PATH,
)
output_message, output_memory = phase.step(query, reinit_memory=True)
# print(output_memory.to_str_messages(return_all=True, content_key="parsed_output_list"))
print(output_memory.get_spec_parserd_output())
values = output_memory.get_spec_parserd_output()
test_code = {k:v for i in values for k,v in i.items() if k in ["SaveFileName", "Test Code"]}
test_cases.append(test_code)
os.makedirs(f"{CB_ROOT_PATH}/tests", exist_ok=True)
with open(f"{CB_ROOT_PATH}/tests/{test_code['SaveFileName']}", "w") as f:
f.write(test_code["Test Code"])
break
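The trailing break keeps this walkthrough to the first class vertex only; remove it to generate and save test cases for every vertex returned by search_vertices.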
# import os, sys, json
# from loguru import logger
# src_dir = os.path.join(
# os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
# )
# sys.path.append(src_dir)
# from configs.model_config import KB_ROOT_PATH, JUPYTER_WORK_PATH, CB_ROOT_PATH
# from configs.server_config import SANDBOX_SERVER
# from coagent.tools import toLangchainTools, TOOL_DICT, TOOL_SETS
# from coagent.llm_models.llm_config import EmbedConfig, LLMConfig
# from coagent.connector.phase import BasePhase
# from coagent.connector.agents import BaseAgent
# from coagent.connector.chains import BaseChain
# from coagent.connector.schema import (
# Message, Memory, load_role_configs, load_phase_configs, load_chain_configs, ActionStatus
# )
# from coagent.connector.memory_manager import BaseMemoryManager
# from coagent.connector.configs import AGETN_CONFIGS, CHAIN_CONFIGS, PHASE_CONFIGS, BASE_PROMPT_CONFIGS
# from coagent.connector.prompt_manager.prompt_manager import PromptManager
# from coagent.codechat.codebase_handler.codebase_handler import CodeBaseHandler
# import importlib
# from loguru import logger
# # Below you are given a code snippet plus snippets of the classes and methods it
# # depends on; decide whether a test case should be generated for the given snippet.
# # In theory all code deserves tests, but human effort cannot cover everything,
# # so trim the scope by weighing the following factors:
# # Functionality: if it implements a concrete feature or piece of logic, tests are usually needed to verify its correctness.
# # Complexity: if the code is fairly complex, especially with multiple conditionals, loops, or exception handling, it is more likely to hide bugs and should be tested; for intricate algorithms or logic, tests help ensure correctness and guard future refactoring.
# # Criticality: if it sits on the critical path or affects core functionality, it must be tested; core business logic and key components deserve comprehensive tests for correctness and stability.
# # Dependencies: if the code has external dependencies, integration tests, or unit tests that mock those dependencies, may be needed.
# # User input: if the code handles user input, especially uncontrolled external input, tests for input validation and handling are important.
# # Frequent changes: code that is updated often needs tests so that changes do not break existing behavior.
# # Public or reused code: if the code will be published or reused in other projects, tests raise its credibility and ease of use.
# # (a heuristic sketch of these factors follows below)
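# # For illustration only, a hedged heuristic pre-filter over the factors above
# # (function name and thresholds are assumptions, not part of this repo):
# def needs_test(num_branches: int, on_critical_path: bool, handles_user_input: bool) -> bool:
#     score = num_branches + (2 if on_critical_path else 0) + (2 if handles_user_input else 0)
#     return score >= 2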
# # update new agent configs
# judgeGenerateTests_PROMPT = """#### Agent Profile
# When determining the necessity of writing test cases for a given code snippet,
# it's essential to evaluate its interactions with dependent classes and methods (retrieved code snippets),
# in addition to considering these critical factors:
# 1. Functionality: If it implements a concrete function or logic, test cases are typically necessary to verify its correctness.
# 2. Complexity: If the code is complex, especially if it contains multiple conditional statements, loops, exception handling, etc.,
# it's more likely to harbor bugs, and thus test cases should be written.
# If the code involves complex algorithms or logic, then writing test cases can help ensure the accuracy of the logic and prevent errors during future refactoring.
# 3. Criticality: If it's part of the critical path or affects core functionalities, then it needs to be tested.
# Comprehensive test cases should be written for core business logic or key components of the system to ensure the correctness and stability of the functionality.
# 4. Dependencies: If the code has external dependencies, integration testing may be necessary, or mocking these dependencies during unit testing might be required.
# 5. User Input: If the code handles user input, especially from unregulated external sources, creating test cases to check input validation and handling is important.
# 6. Frequent Changes: For code that requires regular updates or modifications, having the appropriate test cases ensures that changes do not break existing functionalities.
# #### Input Format
# **Code Snippet:** the initial Code or objective that the user wanted to achieve
# **Retrieval Code Snippets:** These are the associated code segments that the main Code Snippet depends on.
# Examine these snippets to understand how they interact with the main snippet and to determine how they might affect the overall functionality.
# #### Response Output Format
# **Action Status:** Set to 'finished' or 'continued'.
# If set to 'finished', the code snippet does not warrant the generation of a test case.
# If set to 'continued', the code snippet necessitates the creation of a test case.
# **REASON:** Justify the selection of 'finished' or 'continued', contemplating the decision through a step-by-step rationale.
# """
# generateTests_PROMPT = """#### Agent Profile
# As an agent specializing in software quality assurance,
# your mission is to craft comprehensive test cases that bolster the functionality, reliability, and robustness of a specified Code Snippet.
# This task is to be carried out with a keen understanding of the snippet's interactions with its dependent classes and methods—collectively referred to as Retrieval Code Snippets.
# Analyze the details given below to grasp the code's intended purpose, its inherent complexity, and the context within which it operates.
# Your constructed test cases must thoroughly examine the various factors influencing the code's quality and performance.
# ATTENTION: format your response by carefully following the "Response Output Format" below.
# Each test case should include:
# 1. A clear description of the test purpose.
# 2. The input values or conditions for the test.
# 3. The expected outcome or assertion for the test.
# 4. Appropriate tags (e.g., 'functional', 'integration', 'regression') that classify the type of test case.
# 5. The test code must include package declarations and imports.
# #### Input Format
# **Code Snippet:** the initial Code or objective that the user wanted to achieve
# **Retrieval Code Snippets:** These are the interrelated pieces of code sourced from the codebase, which support or influence the primary Code Snippet.
# #### Response Output Format
# **SaveFileName:** construct a local file name based on Question and Context, such as
# ```java
# package/class.java
# ```
# **Test Code:** generate the test code for the current Code Snippet.
# ```java
# ...
# ```
# """
# CODE_GENERATE_TESTS_PROMPT_CONFIGS = [
# {"field_name": 'agent_profile', "function_name": 'handle_agent_profile', "is_context": False},
# # {"field_name": 'tool_information',"function_name": 'handle_tool_data', "is_context": False},
# {"field_name": 'context_placeholder', "function_name": '', "is_context": True},
# # {"field_name": 'reference_documents', "function_name": 'handle_doc_info'},
# {"field_name": 'session_records', "function_name": 'handle_session_records'},
# {"field_name": 'code_snippet', "function_name": 'handle_code_snippet'},
# {"field_name": 'retrieval_codes', "function_name": 'handle_retrieval_codes', "description": ""},
# {"field_name": 'output_format', "function_name": 'handle_output_format', 'title': 'Response Output Format', "is_context": False},
# {"field_name": 'begin!!!', "function_name": 'handle_response', "is_context": False, "omit_if_empty": False}
# ]
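# # Reading the rows above: `field_name` names a prompt section, `function_name` the
# # PromptManager handler that fills it, and rows flagged "is_context": True appear
# # to be filled from conversation context rather than a handler.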
# class CodeRetrievalPM(PromptManager):
#     def handle_code_snippet(self, **kwargs) -> str:
#         if 'previous_agent_message' not in kwargs:
#             return ""
#         previous_agent_message: Message = kwargs['previous_agent_message']
#         code_snippet = previous_agent_message.customed_kargs.get("Code Snippet", "")
#         current_vertex = previous_agent_message.customed_kargs.get("Current_Vertex", "")
#         instruction = "the initial Code or objective that the user wanted to achieve"
#         return instruction + "\n" + f"name: {current_vertex}\n{code_snippet}"
#
#     def handle_retrieval_codes(self, **kwargs) -> str:
#         if 'previous_agent_message' not in kwargs:
#             return ""
#         previous_agent_message: Message = kwargs['previous_agent_message']
#         Retrieval_Codes = previous_agent_message.customed_kargs["Retrieval_Codes"]
#         Relative_vertex = previous_agent_message.customed_kargs["Relative_vertex"]
#         instruction = "the related code snippets retrieved from the codebase"
#         s = instruction + "\n" + "\n".join([f"name: {vertex}\n{code}" for vertex, code in zip(Relative_vertex, Retrieval_Codes)])
#         return s
# AGETN_CONFIGS.update({
#     "CodeJudger": {
#         "role": {
#             "role_prompt": judgeGenerateTests_PROMPT,
#             "role_type": "assistant",
#             "role_name": "CodeJudger",
#             "role_desc": "",
#             "agent_type": "CodeRetrieval"
#             # "agent_type": "BaseAgent"
#         },
#         "prompt_config": CODE_GENERATE_TESTS_PROMPT_CONFIGS,
#         "prompt_manager_type": "CodeRetrievalPM",
#         "chat_turn": 1,
#         "focus_agents": [],
#         "focus_message_keys": [],
#     },
#     "generateTests": {
#         "role": {
#             "role_prompt": generateTests_PROMPT,
#             "role_type": "assistant",
#             "role_name": "generateTests",
#             "role_desc": "",
#             "agent_type": "CodeRetrieval"
#             # "agent_type": "BaseAgent"
#         },
#         "prompt_config": CODE_GENERATE_TESTS_PROMPT_CONFIGS,
#         "prompt_manager_type": "CodeRetrievalPM",
#         "chat_turn": 1,
#         "focus_agents": [],
#         "focus_message_keys": [],
#     },
# })
#
# # update new chain configs
# CHAIN_CONFIGS.update({
#     "codeRetrievalChain": {
#         "chain_name": "codeRetrievalChain",
#         "chain_type": "BaseChain",
#         "agents": ["CodeJudger", "generateTests"],
#         "chat_turn": 1,
#         "do_checker": False,
#         "chain_prompt": ""
#     }
# })
#
# # update phase configs
# PHASE_CONFIGS.update({
#     "codeGenerateTests": {
#         "phase_name": "codeGenerateTests",
#         "phase_type": "BasePhase",
#         "chains": ["codeRetrievalChain"],
#         "do_summary": False,
#         "do_search": False,
#         "do_doc_retrieval": False,
#         "do_code_retrieval": False,
#         "do_tool_retrieval": False,
#     },
# })
#
# role_configs = load_role_configs(AGETN_CONFIGS)
# chain_configs = load_chain_configs(CHAIN_CONFIGS)
# phase_configs = load_phase_configs(PHASE_CONFIGS)
# from coagent.tools import CodeRetrievalSingle, RelatedVerticesRetrival, Vertex2Code
#
# # define a new agent class
# class CodeRetrieval(BaseAgent):
#
#     def start_action_step(self, message: Message) -> Message:
#         '''do action before agent predict '''
#         # fetch the code snippet and vertex info for the query
#         action_json = CodeRetrievalSingle.run(message.code_engine_name, message.origin_query, llm_config=self.llm_config, embed_config=self.embed_config, search_type="tag",
#                                               local_graph_path=message.local_graph_path, use_nh=message.use_nh)
#         current_vertex = action_json['vertex']
#         message.customed_kargs["Code Snippet"] = action_json["code"]
#         message.customed_kargs['Current_Vertex'] = current_vertex
#
#         # fetch the neighbouring vertices
#         action_json = RelatedVerticesRetrival.run(message.code_engine_name, message.customed_kargs['Current_Vertex'])
#
#         # fetch the code of every neighbouring vertex
#         relative_vertex = []
#         retrieval_Codes = []
#         for vertex in action_json["vertices"]:
#             # code is file-level, so snippets from the same file are skipped
#             # logger.debug(f"{current_vertex}, {vertex}")
#             current_vertex_name = current_vertex.replace("#", "").replace(".java", "") if current_vertex.endswith(".java") else current_vertex
#             if current_vertex_name.split("#")[0] == vertex.split("#")[0]: continue
#             action_json = Vertex2Code.run(message.code_engine_name, vertex)
#             if action_json["code"]:
#                 retrieval_Codes.append(action_json["code"])
#                 relative_vertex.append(vertex)
#         #
#         message.customed_kargs["Retrieval_Codes"] = retrieval_Codes
#         message.customed_kargs["Relative_vertex"] = relative_vertex
#         return message
# # register agent and prompt_manager classes
# agent_module = importlib.import_module("coagent.connector.agents")
# prompt_manager_module = importlib.import_module("coagent.connector.prompt_manager")
# setattr(agent_module, 'CodeRetrieval', CodeRetrieval)
# setattr(prompt_manager_module, 'CodeRetrievalPM', CodeRetrievalPM)
#
# # log level: controls printing of prompts and LLM predictions
# os.environ["log_verbose"] = "0"
# phase_name = "codeGenerateTests"
# llm_config = LLMConfig(
# model_name="gpt-4", api_key=os.environ["OPENAI_API_KEY"],
# api_base_url=os.environ["API_BASE_URL"], temperature=0.3
# )
# embed_config = EmbedConfig(
# embed_engine="model", embed_model="text2vec-base-chinese",
# embed_model_path=os.path.join(src_dir, "embedding_models/text2vec-base-chinese")
# )
# ## initialize codebase
# # delete codebase
# codebase_name = 'client_local'
# code_path = "D://chromeDownloads/devopschat-bot/client_v2/client"
# use_nh = False
# cbh = CodeBaseHandler(codebase_name, code_path, crawl_type='dir', use_nh=use_nh, local_graph_path=CB_ROOT_PATH,
# llm_config=llm_config, embed_config=embed_config)
# cbh.delete_codebase(codebase_name=codebase_name)
# # load codebase
# codebase_name = 'client_local'
# code_path = "D://chromeDownloads/devopschat-bot/client_v2/client"
# use_nh = True
# do_interpret = True
# cbh = CodeBaseHandler(codebase_name, code_path, crawl_type='dir', use_nh=use_nh, local_graph_path=CB_ROOT_PATH,
# llm_config=llm_config, embed_config=embed_config)
# cbh.import_code(do_interpret=do_interpret)
# cbh = CodeBaseHandler(codebase_name, code_path, crawl_type='dir', use_nh=use_nh, local_graph_path=CB_ROOT_PATH,
# llm_config=llm_config, embed_config=embed_config)
# vertexes = cbh.search_vertices(vertex_type="class")
# phase = BasePhase(
# phase_name, sandbox_server=SANDBOX_SERVER, jupyter_work_path=JUPYTER_WORK_PATH,
# embed_config=embed_config, llm_config=llm_config, kb_root_path=KB_ROOT_PATH,
# )
# # round-1
# logger.debug(vertexes)
# test_cases = []
# for vertex in vertexes:
#     query_content = f"为{vertex}生成可执行的测例 "
#     query = Message(
#         role_name="human", role_type="user",
#         role_content=query_content, input_query=query_content, origin_query=query_content,
#         code_engine_name=codebase_name, score_threshold=1.0, top_k=3, cb_search_type="tag",
#         use_nh=use_nh, local_graph_path=CB_ROOT_PATH,
#     )
#     output_message, output_memory = phase.step(query, reinit_memory=True)
#     # print(output_memory.to_str_messages(return_all=True, content_key="parsed_output_list"))
#     print(output_memory.get_spec_parserd_output())
#     values = output_memory.get_spec_parserd_output()
#     test_code = {k: v for i in values for k, v in i.items() if k in ["SaveFileName", "Test Code"]}
#     test_cases.append(test_code)
#     os.makedirs(f"{CB_ROOT_PATH}/tests", exist_ok=True)
#     with open(f"{CB_ROOT_PATH}/tests/{test_code['SaveFileName']}", "w") as f:
#         f.write(test_code["Test Code"])

View File

@@ -17,7 +17,7 @@ os.environ["log_verbose"] = "2"
 phase_name = "codeReactPhase"
 llm_config = LLMConfig(
-    model_name="gpt-3.5-turbo", model_device="cpu",api_key=os.environ["OPENAI_API_KEY"],
+    model_name="gpt-3.5-turbo",api_key=os.environ["OPENAI_API_KEY"],
     api_base_url=os.environ["API_BASE_URL"], temperature=0.3
 )
 embed_config = EmbedConfig(

View File

@@ -18,8 +18,7 @@ from coagent.connector.schema import (
 )
 from coagent.connector.memory_manager import BaseMemoryManager
 from coagent.connector.configs import AGETN_CONFIGS, CHAIN_CONFIGS, PHASE_CONFIGS, BASE_PROMPT_CONFIGS
-from coagent.connector.utils import parse_section
-from coagent.connector.prompt_manager import PromptManager
+from coagent.connector.prompt_manager.prompt_manager import PromptManager
 import importlib
 from loguru import logger
@@ -230,7 +229,7 @@ os.environ["log_verbose"] = "2"
 phase_name = "codeRetrievalPhase"
 llm_config = LLMConfig(
-    model_name="gpt-3.5-turbo", model_device="cpu",api_key=os.environ["OPENAI_API_KEY"],
+    model_name="gpt-3.5-turbo", api_key=os.environ["OPENAI_API_KEY"],
     api_base_url=os.environ["API_BASE_URL"], temperature=0.3
 )
 embed_config = EmbedConfig(
@@ -246,7 +245,7 @@ query_content = "UtilsTest 这个类中测试了哪些函数,测试的函数代
 query = Message(
     role_name="human", role_type="user",
     role_content=query_content, input_query=query_content, origin_query=query_content,
-    code_engine_name="client", score_threshold=1.0, top_k=3, cb_search_type="tag"
+    code_engine_name="client_1", score_threshold=1.0, top_k=3, cb_search_type="tag"
 )

View File

@@ -24,7 +24,7 @@ os.environ["log_verbose"] = "2"
 phase_name = "codeToolReactPhase"
 llm_config = LLMConfig(
-    model_name="gpt-3.5-turbo-0613", model_device="cpu",api_key=os.environ["OPENAI_API_KEY"],
+    model_name="gpt-3.5-turbo-0613", api_key=os.environ["OPENAI_API_KEY"],
     api_base_url=os.environ["API_BASE_URL"], temperature=0.7
 )
 embed_config = EmbedConfig(

View File

@@ -17,7 +17,7 @@ from coagent.connector.schema import Message, Memory
 tools = toLangchainTools([TOOL_DICT[i] for i in TOOL_SETS if i in TOOL_DICT])
 llm_config = LLMConfig(
-    model_name="gpt-3.5-turbo", model_device="cpu",api_key=os.environ["OPENAI_API_KEY"],
+    model_name="gpt-3.5-turbo",api_key=os.environ["OPENAI_API_KEY"],
     api_base_url=os.environ["API_BASE_URL"], temperature=0.3
 )
 embed_config = EmbedConfig(

View File

@@ -18,7 +18,7 @@ os.environ["log_verbose"] = "0"
 phase_name = "metagpt_code_devlop"
 llm_config = LLMConfig(
-    model_name="gpt-3.5-turbo", model_device="cpu",api_key=os.environ["OPENAI_API_KEY"],
+    model_name="gpt-3.5-turbo", api_key=os.environ["OPENAI_API_KEY"],
     api_base_url=os.environ["API_BASE_URL"], temperature=0.3
 )
 embed_config = EmbedConfig(

View File

@@ -20,7 +20,7 @@ os.environ["log_verbose"] = "2"
 phase_name = "searchChatPhase"
 llm_config = LLMConfig(
-    model_name="gpt-3.5-turbo", model_device="cpu",api_key=os.environ["OPENAI_API_KEY"],
+    model_name="gpt-3.5-turbo", api_key=os.environ["OPENAI_API_KEY"],
     api_base_url=os.environ["API_BASE_URL"], temperature=0.3
 )
 embed_config = EmbedConfig(

View File

@@ -18,7 +18,7 @@ os.environ["log_verbose"] = "2"
 phase_name = "toolReactPhase"
 llm_config = LLMConfig(
-    model_name="gpt-3.5-turbo", model_device="cpu",api_key=os.environ["OPENAI_API_KEY"],
+    model_name="gpt-3.5-turbo",api_key=os.environ["OPENAI_API_KEY"],
     api_base_url=os.environ["API_BASE_URL"], temperature=0.3
 )
 embed_config = EmbedConfig(

View File

@@ -151,9 +151,9 @@ def create_app():
              )(delete_cb)
     app.post("/code_base/code_base_chat",
              tags=["Code Base Management"],
-             summary="删除 code_base"
-             )(delete_cb)
+             summary="code_base 对话"
+             )(search_code)
     app.get("/code_base/list_code_bases",
             tags=["Code Base Management"],

View File

@@ -117,7 +117,7 @@ PHASE_CONFIGS.update({
 llm_config = LLMConfig(
-    model_name="gpt-3.5-turbo", model_device="cpu",api_key=os.environ["OPENAI_API_KEY"],
+    model_name="gpt-3.5-turbo", api_key=os.environ["OPENAI_API_KEY"],
     api_base_url=os.environ["API_BASE_URL"], temperature=0.3
 )

View File

@@ -98,12 +98,6 @@ def start_docker(client, script_shs, ports, image_name, container_name, mounts=N
 network_name ='my_network'
 def start_sandbox_service(network_name ='my_network'):
-    # networks = client.networks.list()
-    # if any([network_name==i.attrs["Name"] for i in networks]):
-    #     network = client.networks.get(network_name)
-    # else:
-    #     network = client.networks.create('my_network', driver='bridge')
     mount = Mount(
         type='bind',
         source=os.path.join(src_dir, "jupyter_work"),
@@ -114,6 +108,12 @@ def start_sandbox_service(network_name ='my_network'):
     # 沙盒的启动与服务的启动是独立的
     if SANDBOX_SERVER["do_remote"]:
         client = docker.from_env()
+        networks = client.networks.list()
+        if any([network_name==i.attrs["Name"] for i in networks]):
+            network = client.networks.get(network_name)
+        else:
+            network = client.networks.create('my_network', driver='bridge')
         # 启动容器
         logger.info("start container sandbox service")
         JUPYTER_WORK_PATH = "/home/user/chatbot/jupyter_work"
@@ -150,7 +150,7 @@ def start_api_service(sandbox_host=DEFAULT_BIND_HOST):
     client = docker.from_env()
     logger.info("start container service")
     check_process("api.py", do_stop=True)
-    check_process("sdfile_api.py", do_stop=True)
+    check_process("llm_api.py", do_stop=True)
     check_process("sdfile_api.py", do_stop=True)
     check_process("webui.py", do_stop=True)
     mount = Mount(
@@ -159,27 +159,28 @@ def start_api_service(sandbox_host=DEFAULT_BIND_HOST):
         target='/home/user/chatbot/',
         read_only=False  # 如果需要只读访问将此选项设置为True
     )
-    mount_database = Mount(
-        type='bind',
-        source=os.path.join(src_dir, "knowledge_base"),
-        target='/home/user/knowledge_base/',
-        read_only=False  # 如果需要只读访问将此选项设置为True
-    )
-    mount_code_database = Mount(
-        type='bind',
-        source=os.path.join(src_dir, "code_base"),
-        target='/home/user/code_base/',
-        read_only=False  # 如果需要只读访问将此选项设置为True
-    )
+    # mount_database = Mount(
+    #     type='bind',
+    #     source=os.path.join(src_dir, "knowledge_base"),
+    #     target='/home/user/knowledge_base/',
+    #     read_only=False  # 如果需要只读访问将此选项设置为True
+    # )
+    # mount_code_database = Mount(
+    #     type='bind',
+    #     source=os.path.join(src_dir, "code_base"),
+    #     target='/home/user/code_base/',
+    #     read_only=False  # 如果需要只读访问将此选项设置为True
+    # )
     ports={
         f"{API_SERVER['docker_port']}/tcp": f"{API_SERVER['port']}/tcp",
         f"{WEBUI_SERVER['docker_port']}/tcp": f"{WEBUI_SERVER['port']}/tcp",
         f"{SDFILE_API_SERVER['docker_port']}/tcp": f"{SDFILE_API_SERVER['port']}/tcp",
         f"{NEBULA_GRAPH_SERVER['docker_port']}/tcp": f"{NEBULA_GRAPH_SERVER['port']}/tcp"
     }
-    mounts = [mount, mount_database, mount_code_database]
+    # mounts = [mount, mount_database, mount_code_database]
+    mounts = [mount]
     script_shs = [
-        "mkdir -p /home/user/logs",
+        "mkdir -p /home/user/chatbot/logs",
         '''
         if [ -d "/home/user/chatbot/data/nebula_data/data/meta" ]; then
             cp -r /home/user/chatbot/data/nebula_data/data /usr/local/nebula/
@@ -197,12 +198,12 @@ def start_api_service(sandbox_host=DEFAULT_BIND_HOST):
         "pip install jieba",
         "pip install duckduckgo-search",
-        "nohup python chatbot/examples/sdfile_api.py > /home/user/logs/sdfile_api.log 2>&1 &",
+        "nohup python chatbot/examples/sdfile_api.py > /home/user/chatbot/logs/sdfile_api.log 2>&1 &",
         f"export DUCKDUCKGO_PROXY=socks5://host.docker.internal:13659 && export SANDBOX_HOST={sandbox_host} &&\
-            nohup python chatbot/examples/api.py > /home/user/logs/api.log 2>&1 &",
+            nohup python chatbot/examples/api.py > /home/user/chatbot/logs/api.log 2>&1 &",
         "nohup python chatbot/examples/llm_api.py > /home/user/llm.log 2>&1 &",
         f"export DUCKDUCKGO_PROXY=socks5://host.docker.internal:13659 && export SANDBOX_HOST={sandbox_host} &&\
-            cd chatbot/examples && nohup streamlit run webui.py > /home/user/logs/start_webui.log 2>&1 &"
+            cd chatbot/examples && nohup streamlit run webui.py > /home/user/chatbot/logs/start_webui.log 2>&1 &"
     ]
     if check_docker(client, CONTRAINER_NAME, do_stop=True):
         container = start_docker(client, script_shs, ports, IMAGE_NAME, CONTRAINER_NAME, mounts, network=network_name)
@@ -212,12 +213,9 @@ def start_api_service(sandbox_host=DEFAULT_BIND_HOST):
     # 关闭之前启动的docker 服务
     # check_docker(client, CONTRAINER_NAME, do_stop=True, )
-    # api_sh = "nohup python ../coagent/service/api.py > ../logs/api.log 2>&1 &"
     api_sh = "nohup python api.py > ../logs/api.log 2>&1 &"
-    # sdfile_sh = "nohup python ../coagent/service/sdfile_api.py > ../logs/sdfile_api.log 2>&1 &"
     sdfile_sh = "nohup python sdfile_api.py > ../logs/sdfile_api.log 2>&1 &"
     notebook_sh = f"nohup jupyter-notebook --NotebookApp.token=mytoken --port={SANDBOX_SERVER['port']} --allow-root --ip=0.0.0.0 --notebook-dir={JUPYTER_WORK_PATH} --no-browser --ServerApp.disable_check_xsrf=True > ../logs/sandbox.log 2>&1 &"
-    # llm_sh = "nohup python ../coagent/service/llm_api.py > ../logs/llm_api.log 2>&1 &"
     llm_sh = "nohup python llm_api.py > ../logs/llm_api.log 2>&1 &"
     webui_sh = "streamlit run webui.py" if USE_TTY else "streamlit run webui.py"

View File

@@ -22,7 +22,7 @@ from coagent.service.service_factory import get_cb_details, get_cb_details_by_cb
 from coagent.orm import table_init
-from configs.model_config import EMBEDDING_DEVICE, EMBEDDING_ENGINE, EMBEDDING_MODEL, embedding_model_dict
+from configs.model_config import EMBEDDING_DEVICE, EMBEDDING_ENGINE, EMBEDDING_MODEL, embedding_model_dict,llm_model_dict
 # SENTENCE_SIZE = 100
 cell_renderer = JsCode("""function(params) {if(params.value==true){return ''}else{return '×'}}""")
@@ -117,6 +117,8 @@ def code_page(api: ApiRequest):
             embed_model_path=embedding_model_dict[EMBEDDING_MODEL],
             embedding_device=EMBEDDING_DEVICE,
             llm_model=LLM_MODEL,
+            api_key=llm_model_dict[LLM_MODEL]["api_key"],
+            api_base_url=llm_model_dict[LLM_MODEL]["api_base_url"],
         )
         st.toast(ret.get("msg", " "))
         st.session_state["selected_cb_name"] = cb_name
@@ -153,6 +155,8 @@ def code_page(api: ApiRequest):
             embed_model_path=embedding_model_dict[EMBEDDING_MODEL],
             embedding_device=EMBEDDING_DEVICE,
             llm_model=LLM_MODEL,
+            api_key=llm_model_dict[LLM_MODEL]["api_key"],
+            api_base_url=llm_model_dict[LLM_MODEL]["api_base_url"],
         )
         st.toast(ret.get("msg", "删除成功"))
         time.sleep(0.05)

View File

@@ -11,7 +11,7 @@ from coagent.chat.search_chat import SEARCH_ENGINES
 from coagent.connector import PHASE_LIST, PHASE_CONFIGS
 from coagent.service.service_factory import get_cb_details_by_cb_name
-from configs.model_config import EMBEDDING_DEVICE, EMBEDDING_MODEL, embedding_model_dict, EMBEDDING_ENGINE, KB_ROOT_PATH
+from configs.model_config import EMBEDDING_DEVICE, EMBEDDING_MODEL, embedding_model_dict, EMBEDDING_ENGINE, KB_ROOT_PATH, llm_model_dict
 chat_box = ChatBox(
     assistant_avatar="../sources/imgs/devops-chatbot2.png"
 )
@@ -174,7 +174,7 @@ def dialogue_page(api: ApiRequest):
             is_detailed = st.toggle(webui_configs["dialogue"]["phase_toggle_detailed_name"], False)
             tool_using_on = st.toggle(
                 webui_configs["dialogue"]["phase_toggle_doToolUsing"],
-                PHASE_CONFIGS[choose_phase]["do_using_tool"])
+                PHASE_CONFIGS[choose_phase].get("do_using_tool", False))
             tool_selects = []
             if tool_using_on:
                 with st.expander("工具军火库", True):
@@ -183,7 +183,7 @@ def dialogue_page(api: ApiRequest):
                         TOOL_SETS, ["WeatherInfo"])
             search_on = st.toggle(webui_configs["dialogue"]["phase_toggle_doSearch"],
-                                  PHASE_CONFIGS[choose_phase]["do_search"])
+                                  PHASE_CONFIGS[choose_phase].get("do_search", False))
             search_engine, top_k = None, 3
             if search_on:
                 with st.expander(webui_configs["dialogue"]["expander_search_name"], True):
@@ -195,7 +195,8 @@ def dialogue_page(api: ApiRequest):
             doc_retrieval_on = st.toggle(
                 webui_configs["dialogue"]["phase_toggle_doDocRetrieval"],
-                PHASE_CONFIGS[choose_phase]["do_doc_retrieval"])
+                PHASE_CONFIGS[choose_phase].get("do_doc_retrieval", False)
+            )
             selected_kb, top_k, score_threshold = None, 3, 1.0
             if doc_retrieval_on:
                 with st.expander(webui_configs["dialogue"]["kbase_expander_name"], True):
@@ -215,7 +216,7 @@ def dialogue_page(api: ApiRequest):
             code_retrieval_on = st.toggle(
                 webui_configs["dialogue"]["phase_toggle_doCodeRetrieval"],
-                PHASE_CONFIGS[choose_phase]["do_code_retrieval"])
+                PHASE_CONFIGS[choose_phase].get("do_code_retrieval", False))
             selected_cb, top_k = None, 1
             cb_search_type = "tag"
             if code_retrieval_on:
@@ -296,7 +297,8 @@ def dialogue_page(api: ApiRequest):
             r = api.chat_chat(
                 prompt, history, no_remote_api=True,
                 embed_model=EMBEDDING_MODEL, embed_model_path=embedding_model_dict[EMBEDDING_MODEL],
-                model_device=EMBEDDING_DEVICE, embed_engine=EMBEDDING_ENGINE,
+                model_device=EMBEDDING_DEVICE, embed_engine=EMBEDDING_ENGINE,api_key=llm_model_dict[LLM_MODEL]["api_key"],
+                api_base_url=llm_model_dict[LLM_MODEL]["api_base_url"],
                 llm_model=LLM_MODEL)
             for t in r:
                 if error_msg := check_error_msg(t):  # check whether error occured
@@ -362,6 +364,8 @@ def dialogue_page(api: ApiRequest):
                 "embed_engine": EMBEDDING_ENGINE,
                 "kb_root_path": KB_ROOT_PATH,
                 "model_name": LLM_MODEL,
+                "api_key": llm_model_dict[LLM_MODEL]["api_key"],
+                "api_base_url": llm_model_dict[LLM_MODEL]["api_base_url"],
             }
             text = ""
             d = {"docs": []}
@@ -405,7 +409,10 @@ def dialogue_page(api: ApiRequest):
                 api.knowledge_base_chat(
                     prompt, selected_kb, kb_top_k, score_threshold, history,
                     embed_model=EMBEDDING_MODEL, embed_model_path=embedding_model_dict[EMBEDDING_MODEL],
-                    model_device=EMBEDDING_DEVICE, embed_engine=EMBEDDING_ENGINE, llm_model=LLM_MODEL)
+                    model_device=EMBEDDING_DEVICE, embed_engine=EMBEDDING_ENGINE, llm_model=LLM_MODEL,
+                    api_key=llm_model_dict[LLM_MODEL]["api_key"],
+                    api_base_url=llm_model_dict[LLM_MODEL]["api_base_url"],
+                )
             ):
                 if error_msg := check_error_msg(d):  # check whether error occured
                     st.error(error_msg)
@@ -415,11 +422,7 @@ def dialogue_page(api: ApiRequest):
             # chat_box.update_msg("知识库匹配结果: \n\n".join(d["docs"]), element_index=1, streaming=False, state="complete")
             chat_box.update_msg(text, element_index=0, streaming=False)  # 更新最终的字符串,去除光标
             chat_box.update_msg(f"{webui_configs['chat']['chatbox_doc_result']}: \n\n".join(d["docs"]), element_index=1, streaming=False, state="complete")
-            # # 判断是否存在代码, 并提高编辑功能,执行功能
-            # code_text = api.codebox.decode_code_from_text(text)
-            # GLOBAL_EXE_CODE_TEXT = code_text
-            # if code_text and code_exec_on:
-            #     codebox_res = api.codebox_chat("```"+code_text+"```", do_code_exe=True)
         elif dialogue_mode == webui_configs["dialogue"]["mode"][2]:
             logger.info('prompt={}'.format(prompt))
             logger.info('history={}'.format(history))
@@ -438,7 +441,9 @@ def dialogue_page(api: ApiRequest):
                     cb_search_type=cb_search_type,
                     no_remote_api=True, embed_model=EMBEDDING_MODEL,
                     embed_model_path=embedding_model_dict[EMBEDDING_MODEL],
-                    embed_engine=EMBEDDING_ENGINE, llm_model=LLM_MODEL
+                    embed_engine=EMBEDDING_ENGINE, llm_model=LLM_MODEL,
+                    api_key=llm_model_dict[LLM_MODEL]["api_key"],
+                    api_base_url=llm_model_dict[LLM_MODEL]["api_base_url"],
                 )):
                 if error_msg := check_error_msg(d):
                     st.error(error_msg)
@@ -448,6 +453,7 @@ def dialogue_page(api: ApiRequest):
                     chat_box.update_msg(text, element_index=0)
             # postprocess
+            logger.debug(f"d={d}")
             text = replace_lt_gt(text)
             chat_box.update_msg(text, element_index=0, streaming=False)  # 更新最终的字符串,去除光标
             logger.debug('text={}'.format(text))
@@ -467,7 +473,9 @@ def dialogue_page(api: ApiRequest):
                 api.search_engine_chat(
                     prompt, search_engine, se_top_k, history, embed_model=EMBEDDING_MODEL,
                     embed_model_path=embedding_model_dict[EMBEDDING_MODEL],
-                    model_device=EMBEDDING_DEVICE, embed_engine=EMBEDDING_ENGINE, llm_model=LLM_MODEL)
+                    model_device=EMBEDDING_DEVICE, embed_engine=EMBEDDING_ENGINE, llm_model=LLM_MODEL,
+                    api_key=llm_model_dict[LLM_MODEL]["api_key"],
+                    api_base_url=llm_model_dict[LLM_MODEL]["api_base_url"],)
             ):
                 if error_msg := check_error_msg(d):  # check whether error occured
                     st.error(error_msg)
@@ -477,56 +485,11 @@ def dialogue_page(api: ApiRequest):
             # chat_box.update_msg("搜索匹配结果: \n\n".join(d["docs"]), element_index=1, streaming=False)
             chat_box.update_msg(text, element_index=0, streaming=False)  # 更新最终的字符串,去除光标
             chat_box.update_msg(f"{webui_configs['chat']['chatbox_search_result']}: \n\n".join(d["docs"]), element_index=1, streaming=False, state="complete")
-            # # 判断是否存在代码, 并提高编辑功能,执行功能
-            # code_text = api.codebox.decode_code_from_text(text)
-            # GLOBAL_EXE_CODE_TEXT = code_text
-            # if code_text and code_exec_on:
-            #     codebox_res = api.codebox_chat("```"+code_text+"```", do_code_exe=True)
         # 将上传文件清空
         st.session_state["interpreter_file_key"] += 1
         st.experimental_rerun()
-    # if code_interpreter_on:
-    #     with st.expander(webui_configs['sandbox']['expander_code_name'], False):
-    #         code_part = st.text_area(
-    #             webui_configs['sandbox']['textArea_code_name'], code_text, key="code_text")
-    #         cols = st.columns(2)
-    #         if cols[0].button(
-    #             webui_configs['sandbox']['button_modify_code_name'],
-    #             use_container_width=True,
-    #         ):
-    #             code_text = code_part
-    #             GLOBAL_EXE_CODE_TEXT = code_text
-    #             st.toast(webui_configs['sandbox']['text_modify_code'])
-    #         if cols[1].button(
-    #             webui_configs['sandbox']['button_exec_code_name'],
-    #             use_container_width=True
-    #         ):
-    #             if code_text:
-    #                 codebox_res = api.codebox_chat("```"+code_text+"```", do_code_exe=True)
-    #                 st.toast(webui_configs['sandbox']['text_execing_code'],)
-    #             else:
-    #                 st.toast(webui_configs['sandbox']['text_error_exec_code'],)
-    #             #TODO 这段信息会被记录到history里
-    #             if codebox_res is not None and codebox_res.code_exe_status != 200:
-    #                 st.toast(f"{codebox_res.code_exe_response}")
-    #             if codebox_res is not None and codebox_res.code_exe_status == 200:
-    #                 st.toast(f"codebox_chat {codebox_res}")
-    #                 chat_box.ai_say(Markdown(code_text, in_expander=True, title="code interpreter", unsafe_allow_html=True), )
-    #                 if codebox_res.code_exe_type == "image/png":
-    #                     base_text = f"```\n{code_text}\n```\n\n"
-    #                     img_html = "<img src='data:image/png;base64,{}' class='img-fluid'>".format(
-    #                         codebox_res.code_exe_response
-    #                     )
-    #                     chat_box.update_msg(img_html, streaming=False, state="complete")
-    #                 else:
-    #                     chat_box.update_msg('```\n'+code_text+'\n```'+"\n\n"+'```\n'+codebox_res.code_exe_response+'\n```',
-    #                                         streaming=False, state="complete")
     now = datetime.now()
     with st.sidebar:

View File

@@ -14,7 +14,8 @@ from coagent.orm import table_init
 from configs.model_config import (
     KB_ROOT_PATH, kbs_config, DEFAULT_VS_TYPE, WEB_CRAWL_PATH,
-    EMBEDDING_DEVICE, EMBEDDING_ENGINE, EMBEDDING_MODEL, embedding_model_dict
+    EMBEDDING_DEVICE, EMBEDDING_ENGINE, EMBEDDING_MODEL, embedding_model_dict,
+    llm_model_dict
 )
 # SENTENCE_SIZE = 100
@@ -136,6 +137,8 @@ def knowledge_page(
             embed_engine=EMBEDDING_ENGINE,
             embedding_device= EMBEDDING_DEVICE,
             embed_model_path=embedding_model_dict[embed_model],
+            api_key=llm_model_dict[LLM_MODEL]["api_key"],
+            api_base_url=llm_model_dict[LLM_MODEL]["api_base_url"],
         )
         st.toast(ret.get("msg", " "))
         st.session_state["selected_kb_name"] = kb_name
@@ -160,7 +163,10 @@ def knowledge_page(
         data = [{"file": f, "knowledge_base_name": kb, "not_refresh_vs_cache": True, "embed_model": EMBEDDING_MODEL,
                  "embed_model_path": embedding_model_dict[EMBEDDING_MODEL],
                  "model_device": EMBEDDING_DEVICE,
-                 "embed_engine": EMBEDDING_ENGINE}
+                 "embed_engine": EMBEDDING_ENGINE,
+                 "api_key": llm_model_dict[LLM_MODEL]["api_key"],
+                 "api_base_url": llm_model_dict[LLM_MODEL]["api_base_url"],
+                 }
                 for f in files]
         data[-1]["not_refresh_vs_cache"]=False
         for k in data:
@@ -210,7 +216,9 @@ def knowledge_page(
                 "embed_model": EMBEDDING_MODEL,
                 "embed_model_path": embedding_model_dict[EMBEDDING_MODEL],
                 "model_device": EMBEDDING_DEVICE,
-                "embed_engine": EMBEDDING_ENGINE}]
+                "embed_engine": EMBEDDING_ENGINE,
+                "api_key": llm_model_dict[LLM_MODEL]["api_key"],
+                "api_base_url": llm_model_dict[LLM_MODEL]["api_base_url"],}]
             for k in data:
                 ret = api.upload_kb_doc(**k)
                 logger.info(ret)
@@ -297,7 +305,9 @@ def knowledge_page(
                 api.update_kb_doc(kb, row["file_name"],
                                   embed_engine=EMBEDDING_ENGINE,embed_model=EMBEDDING_MODEL,
                                   embed_model_path=embedding_model_dict[EMBEDDING_MODEL],
-                                  model_device=EMBEDDING_DEVICE
+                                  model_device=EMBEDDING_DEVICE,
+                                  api_key=llm_model_dict[LLM_MODEL]["api_key"],
+                                  api_base_url=llm_model_dict[LLM_MODEL]["api_base_url"],
                 )
                 st.experimental_rerun()
@@ -311,7 +321,9 @@ def knowledge_page(
                 api.delete_kb_doc(kb, row["file_name"],
                                   embed_engine=EMBEDDING_ENGINE,embed_model=EMBEDDING_MODEL,
                                   embed_model_path=embedding_model_dict[EMBEDDING_MODEL],
-                                  model_device=EMBEDDING_DEVICE)
+                                  model_device=EMBEDDING_DEVICE,
+                                  api_key=llm_model_dict[LLM_MODEL]["api_key"],
+                                  api_base_url=llm_model_dict[LLM_MODEL]["api_base_url"],)
                 st.experimental_rerun()
             if cols[3].button(
@@ -323,7 +335,9 @@ def knowledge_page(
                 ret = api.delete_kb_doc(kb, row["file_name"], True,
                                         embed_engine=EMBEDDING_ENGINE,embed_model=EMBEDDING_MODEL,
                                         embed_model_path=embedding_model_dict[EMBEDDING_MODEL],
-                                        model_device=EMBEDDING_DEVICE)
+                                        model_device=EMBEDDING_DEVICE,
+                                        api_key=llm_model_dict[LLM_MODEL]["api_key"],
+                                        api_base_url=llm_model_dict[LLM_MODEL]["api_base_url"],)
                 st.toast(ret.get("msg", " "))
                 st.experimental_rerun()
@@ -344,6 +358,8 @@ def knowledge_page(
             for d in api.recreate_vector_store(
                 kb, vs_type=default_vs_type, embed_model=embedding_model, embedding_device=EMBEDDING_DEVICE,
                 embed_model_path=embedding_model_dict["embedding_model"], embed_engine=EMBEDDING_ENGINE,
+                api_key=llm_model_dict[LLM_MODEL]["api_key"],
+                api_base_url=llm_model_dict[LLM_MODEL]["api_base_url"],
             ):
                 if msg := check_error_msg(d):
                     st.toast(msg)

View File

@@ -299,7 +299,9 @@ class ApiRequest:
         stream: bool = True,
         no_remote_api: bool = None,
         embed_model: str="", embed_model_path: str="", model_device: str="", embed_engine: str="",
-        llm_model: str ="", temperature: float= 0.2
+        llm_model: str ="", temperature: float= 0.2,
+        api_key: str=os.environ["OPENAI_API_KEY"],
+        api_base_url: str = os.environ["API_BASE_URL"],
     ):
         '''
         对应api.py/chat/chat接口
@@ -311,8 +313,8 @@ class ApiRequest:
             "query": query,
             "history": history,
             "stream": stream,
-            "api_key": os.environ["OPENAI_API_KEY"],
-            "api_base_url": os.environ["API_BASE_URL"],
+            "api_key": api_key,
+            "api_base_url": api_base_url,
             "embed_model": embed_model,
             "embed_model_path": embed_model_path,
             "embed_engine": embed_engine,
@@ -339,7 +341,9 @@ class ApiRequest:
         stream: bool = True,
         no_remote_api: bool = None,
         embed_model: str="", embed_model_path: str="", model_device: str="", embed_engine: str="",
-        llm_model: str ="", temperature: float= 0.2
+        llm_model: str ="", temperature: float= 0.2,
+        api_key: str=os.environ["OPENAI_API_KEY"],
+        api_base_url: str = os.environ["API_BASE_URL"],
     ):
         '''
         对应api.py/chat/knowledge_base_chat接口
@@ -355,8 +359,8 @@ class ApiRequest:
             "history": history,
             "stream": stream,
             "local_doc_url": no_remote_api,
-            "api_key": os.environ["OPENAI_API_KEY"],
-            "api_base_url": os.environ["API_BASE_URL"],
+            "api_key": api_key,
+            "api_base_url": api_base_url,
             "embed_model": embed_model,
             "embed_model_path": embed_model_path,
             "embed_engine": embed_engine,
@@ -386,7 +390,10 @@ class ApiRequest:
         stream: bool = True,
         no_remote_api: bool = None,
         embed_model: str="", embed_model_path: str="", model_device: str="", embed_engine: str="",
-        llm_model: str ="", temperature: float= 0.2
+        llm_model: str ="", temperature: float= 0.2,
+        api_key: str=os.environ["OPENAI_API_KEY"],
+        api_base_url: str = os.environ["API_BASE_URL"],
     ):
         '''
         对应api.py/chat/search_engine_chat接口
@@ -400,8 +407,8 @@ class ApiRequest:
             "top_k": top_k,
             "history": history,
             "stream": stream,
-            "api_key": os.environ["OPENAI_API_KEY"],
-            "api_base_url": os.environ["API_BASE_URL"],
+            "api_key": api_key,
+            "api_base_url": api_base_url,
             "embed_model": embed_model,
             "embed_model_path": embed_model_path,
             "embed_engine": embed_engine,
@@ -432,7 +439,9 @@ class ApiRequest:
         stream: bool = True,
         no_remote_api: bool = None,
         embed_model: str="", embed_model_path: str="", model_device: str="", embed_engine: str="",
-        llm_model: str ="", temperature: float= 0.2
+        llm_model: str ="", temperature: float= 0.2,
+        api_key: str=os.environ["OPENAI_API_KEY"],
+        api_base_url: str = os.environ["API_BASE_URL"],
     ):
         '''
         对应api.py/chat/knowledge_base_chat接口
@@ -458,8 +467,8 @@ class ApiRequest:
             "cb_search_type": cb_search_type,
             "stream": stream,
             "local_doc_url": no_remote_api,
-            "api_key": os.environ["OPENAI_API_KEY"],
-            "api_base_url": os.environ["API_BASE_URL"],
+            "api_key": api_key,
+            "api_base_url": api_base_url,
             "embed_model": embed_model,
             "embed_model_path": embed_model_path,
             "embed_engine": embed_engine,
@@ -510,6 +519,8 @@ class ApiRequest:
         embed_model: str="", embed_model_path: str="",
         model_device: str="", embed_engine: str="",
-        temperature: float=0.2, model_name:str ="",
+        temperature: float=0.2, model_name: str="",
+        api_key: str=os.environ["OPENAI_API_KEY"],
+        api_base_url: str = os.environ["API_BASE_URL"],
     ):
         '''
         对应api.py/chat/chat接口
@@ -541,8 +552,8 @@ class ApiRequest:
             "isDetailed": isDetailed,
             "upload_file": upload_file,
             "kb_root_path": kb_root_path,
-            "api_key": os.environ["OPENAI_API_KEY"],
-            "api_base_url": os.environ["API_BASE_URL"],
+            "api_key": api_key,
+            "api_base_url": api_base_url,
             "embed_model": embed_model,
             "embed_model_path": embed_model_path,
             "embed_engine": embed_engine,
@@ -588,6 +599,8 @@ class ApiRequest:
         embed_model: str="", embed_model_path: str="",
         model_device: str="", embed_engine: str="",
         temperature: float=0.2, model_name: str="",
+        api_key: str=os.environ["OPENAI_API_KEY"],
+        api_base_url: str = os.environ["API_BASE_URL"],
     ):
         '''
         对应api.py/chat/chat接口
@@ -620,8 +633,8 @@ class ApiRequest:
             "isDetailed": isDetailed,
             "upload_file": upload_file,
             "kb_root_path": kb_root_path,
-            "api_key": os.environ["OPENAI_API_KEY"],
-            "api_base_url": os.environ["API_BASE_URL"],
+            "api_key": api_key,
+            "api_base_url": api_base_url,
             "embed_model": embed_model,
             "embed_model_path": embed_model_path,
             "embed_engine": embed_engine,
@@ -694,7 +707,9 @@ class ApiRequest:
         no_remote_api: bool = None,
         kb_root_path: str =KB_ROOT_PATH,
         embed_model: str="", embed_model_path: str="",
-        embedding_device: str="", embed_engine: str=""
+        embedding_device: str="", embed_engine: str="",
+        api_key: str=os.environ["OPENAI_API_KEY"],
+        api_base_url: str = os.environ["API_BASE_URL"],
     ):
         '''
         对应api.py/knowledge_base/create_knowledge_base接口
@@ -706,8 +721,8 @@ class ApiRequest:
             "knowledge_base_name": knowledge_base_name,
             "vector_store_type": vector_store_type,
             "kb_root_path": kb_root_path,
-            "api_key": os.environ["OPENAI_API_KEY"],
-            "api_base_url": os.environ["API_BASE_URL"],
+            "api_key": api_key,
+            "api_base_url": api_base_url,
             "embed_model": embed_model,
             "embed_model_path": embed_model_path,
             "model_device": embedding_device,
@@ -781,7 +796,9 @@ class ApiRequest:
         no_remote_api: bool = None,
         kb_root_path: str = KB_ROOT_PATH,
         embed_model: str="", embed_model_path: str="",
-        model_device: str="", embed_engine: str=""
+        model_device: str="", embed_engine: str="",
+        api_key: str=os.environ["OPENAI_API_KEY"],
+        api_base_url: str = os.environ["API_BASE_URL"],
    ):
         '''
         对应api.py/knowledge_base/upload_docs接口
@@ -810,8 +827,8 @@ class ApiRequest:
                 override,
                 not_refresh_vs_cache,
                 kb_root_path=kb_root_path,
-                api_key=os.environ["OPENAI_API_KEY"],
-                api_base_url=os.environ["API_BASE_URL"],
+                api_key=api_key,
+                api_base_url=api_base_url,
                 embed_model=embed_model,
                 embed_model_path=embed_model_path,
                 model_device=model_device,
@@ -839,7 +856,9 @@ class ApiRequest:
         no_remote_api: bool = None,
         kb_root_path: str = KB_ROOT_PATH,
         embed_model: str="", embed_model_path: str="",
-        model_device: str="", embed_engine: str=""
+        model_device: str="", embed_engine: str="",
+        api_key: str=os.environ["OPENAI_API_KEY"],
+        api_base_url: str = os.environ["API_BASE_URL"],
     ):
         '''
         对应api.py/knowledge_base/delete_doc接口
@@ -853,8 +872,8 @@ class ApiRequest:
             "delete_content": delete_content,
             "not_refresh_vs_cache": not_refresh_vs_cache,
             "kb_root_path": kb_root_path,
-            "api_key": os.environ["OPENAI_API_KEY"],
-            "api_base_url": os.environ["API_BASE_URL"],
+            "api_key": api_key,
+            "api_base_url": api_base_url,
             "embed_model": embed_model,
             "embed_model_path": embed_model_path,
             "model_device": model_device,
@@ -878,7 +897,9 @@ class ApiRequest:
         not_refresh_vs_cache: bool = False,
         no_remote_api: bool = None,
         embed_model: str="", embed_model_path: str="",
-        model_device: str="", embed_engine: str=""
+        model_device: str="", embed_engine: str="",
+        api_key: str=os.environ["OPENAI_API_KEY"],
+        api_base_url: str = os.environ["API_BASE_URL"],
     ):
         '''
         对应api.py/knowledge_base/update_doc接口
@@ -889,8 +910,8 @@ class ApiRequest:
         if no_remote_api:
             response = run_async(update_doc(
                 knowledge_base_name, file_name, not_refresh_vs_cache, kb_root_path=KB_ROOT_PATH,
-                api_key=os.environ["OPENAI_API_KEY"],
-                api_base_url=os.environ["API_BASE_URL"],
+                api_key=api_key,
+                api_base_url=api_base_url,
                 embed_model=embed_model,
                 embed_model_path=embed_model_path,
                 model_device=model_device,
@@ -915,7 +936,9 @@ class ApiRequest:
         no_remote_api: bool = None,
         kb_root_path: str =KB_ROOT_PATH,
         embed_model: str="", embed_model_path: str="",
-        embedding_device: str="", embed_engine: str=""
+        embedding_device: str="", embed_engine: str="",
+        api_key: str=os.environ["OPENAI_API_KEY"],
+        api_base_url: str = os.environ["API_BASE_URL"],
     ):
         '''
         对应api.py/knowledge_base/recreate_vector_store接口
@@ -928,8 +951,8 @@ class ApiRequest:
             "allow_empty_kb": allow_empty_kb,
             "vs_type": vs_type,
             "kb_root_path": kb_root_path,
-            "api_key": os.environ["OPENAI_API_KEY"],
-            "api_base_url": os.environ["API_BASE_URL"],
+            "api_key": api_key,
+            "api_base_url": api_base_url,
            "embed_model": embed_model,
            "embed_model_path": embed_model_path,
            "model_device": embedding_device,
@@ -1041,7 +1064,9 @@ class ApiRequest:
     # code base 相关操作
     def create_code_base(self, cb_name, zip_file, do_interpret: bool, no_remote_api: bool = None,
                          embed_model: str="", embed_model_path: str="", embedding_device: str="", embed_engine: str="",
-                         llm_model: str ="", temperature: float= 0.2
+                         llm_model: str ="", temperature: float= 0.2,
+                         api_key: str=os.environ["OPENAI_API_KEY"],
+                         api_base_url: str = os.environ["API_BASE_URL"],
     ):
         '''
         创建 code_base
@@ -1067,8 +1092,8 @@ class ApiRequest:
             "cb_name": cb_name,
             "code_path": raw_code_path,
             "do_interpret": do_interpret,
-            "api_key": os.environ["OPENAI_API_KEY"],
-            "api_base_url": os.environ["API_BASE_URL"],
+            "api_key": api_key,
+            "api_base_url": api_base_url,
             "embed_model": embed_model,
             "embed_model_path": embed_model_path,
             "embed_engine": embed_engine,
@@ -1091,7 +1116,9 @@ class ApiRequest:
     def delete_code_base(self, cb_name: str, no_remote_api: bool = None,
                          embed_model: str="", embed_model_path: str="", embedding_device: str="", embed_engine: str="",
-                         llm_model: str ="", temperature: float= 0.2
+                         llm_model: str ="", temperature: float= 0.2,
+                         api_key: str=os.environ["OPENAI_API_KEY"],
+                         api_base_url: str = os.environ["API_BASE_URL"],
     ):
         '''
         删除 code_base
@@ -1102,8 +1129,8 @@ class ApiRequest:
             no_remote_api = self.no_remote_api
         data = {
            "cb_name": cb_name,
-            "api_key": os.environ["OPENAI_API_KEY"],
-            "api_base_url": os.environ["API_BASE_URL"],
+            "api_key": api_key,
+            "api_base_url": api_base_url,
            "embed_model": embed_model,
            "embed_model_path": embed_model_path,
            "embed_engine": embed_engine,