Merge pull request #28 from codefuse-ai/ma_demo_commit
[feature](coagent)<增加antflow兼容和增加coagent demo>
This commit is contained in:
commit
333f1e97c6
|
@ -26,9 +26,12 @@ JUPYTER_WORK_PATH = os.environ.get("JUPYTER_WORK_PATH", None) or os.path.join(ex
|
|||
WEB_CRAWL_PATH = os.environ.get("WEB_CRAWL_PATH", None) or os.path.join(executable_path, "knowledge_base")
|
||||
|
||||
# NEBULA_DATA存储路径
|
||||
NELUBA_PATH = os.environ.get("NELUBA_PATH", None) or os.path.join(executable_path, "data/neluba_data")
|
||||
NEBULA_PATH = os.environ.get("NEBULA_PATH", None) or os.path.join(executable_path, "data/nebula_data")
|
||||
|
||||
for _path in [LOG_PATH, SOURCE_PATH, KB_ROOT_PATH, NLTK_DATA_PATH, JUPYTER_WORK_PATH, WEB_CRAWL_PATH, NELUBA_PATH]:
|
||||
# CHROMA 存储路径
|
||||
CHROMA_PERSISTENT_PATH = os.environ.get("CHROMA_PERSISTENT_PATH", None) or os.path.join(executable_path, "data/chroma_data")
|
||||
|
||||
for _path in [LOG_PATH, SOURCE_PATH, KB_ROOT_PATH, CB_ROOT_PATH, NLTK_DATA_PATH, JUPYTER_WORK_PATH, WEB_CRAWL_PATH, NEBULA_PATH, CHROMA_PERSISTENT_PATH]:
|
||||
if not os.path.exists(_path):
|
||||
os.makedirs(_path, exist_ok=True)
|
||||
|
||||
|
@ -58,7 +61,8 @@ NEBULA_GRAPH_SERVER = {
|
|||
}
|
||||
|
||||
# CHROMA CONFIG
|
||||
CHROMA_PERSISTENT_PATH = '/home/user/chatbot/data/chroma_data'
|
||||
# CHROMA_PERSISTENT_PATH = '/home/user/chatbot/data/chroma_data'
|
||||
# CHROMA_PERSISTENT_PATH = '/Users/bingxu/Desktop/工作/大模型/chatbot/codefuse-chatbot-antcode/data/chroma_data'
|
||||
|
||||
|
||||
# 默认向量库类型。可选:faiss, milvus, pg.
|
||||
|
|
|
@ -7,7 +7,7 @@ from langchain import LLMChain
|
|||
from langchain.callbacks import AsyncIteratorCallbackHandler
|
||||
from langchain.prompts.chat import ChatPromptTemplate
|
||||
|
||||
from coagent.llm_models import getChatModel, getChatModelFromConfig
|
||||
from coagent.llm_models import getChatModelFromConfig
|
||||
from coagent.chat.utils import History, wrap_done
|
||||
from coagent.llm_models.llm_config import LLMConfig, EmbedConfig
|
||||
# from configs.model_config import (llm_model_dict, LLM_MODEL, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD)
|
||||
|
|
|
@ -22,7 +22,7 @@ from coagent.connector.configs.prompts import CODE_PROMPT_TEMPLATE
|
|||
from coagent.chat.utils import History, wrap_done
|
||||
from coagent.utils import BaseResponse
|
||||
from .base_chat import Chat
|
||||
from coagent.llm_models import getChatModel, getChatModelFromConfig
|
||||
from coagent.llm_models import getChatModelFromConfig
|
||||
|
||||
from coagent.llm_models.llm_config import LLMConfig, EmbedConfig
|
||||
|
||||
|
@ -67,6 +67,7 @@ class CodeChat(Chat):
|
|||
embed_model_path=embed_config.embed_model_path,
|
||||
embed_engine=embed_config.embed_engine,
|
||||
model_device=embed_config.model_device,
|
||||
embed_config=embed_config
|
||||
)
|
||||
|
||||
context = codes_res['context']
|
||||
|
|
|
@ -12,7 +12,7 @@ from langchain.schema import (
|
|||
|
||||
# from configs.model_config import CODE_INTERPERT_TEMPLATE
|
||||
from coagent.connector.configs.prompts import CODE_INTERPERT_TEMPLATE
|
||||
from coagent.llm_models.openai_model import getChatModel, getChatModelFromConfig
|
||||
from coagent.llm_models.openai_model import getChatModelFromConfig
|
||||
from coagent.llm_models.llm_config import LLMConfig
|
||||
|
||||
|
||||
|
@ -53,9 +53,15 @@ class CodeIntepreter:
|
|||
message = CODE_INTERPERT_TEMPLATE.format(code=code)
|
||||
messages.append(message)
|
||||
|
||||
chat_ress = chat_model.batch(messages)
|
||||
try:
|
||||
chat_ress = [chat_model(messages) for message in messages]
|
||||
except:
|
||||
chat_ress = chat_model.batch(messages)
|
||||
for chat_res, code in zip(chat_ress, code_list):
|
||||
res[code] = chat_res.content
|
||||
try:
|
||||
res[code] = chat_res.content
|
||||
except:
|
||||
res[code] = chat_res
|
||||
return res
|
||||
|
||||
|
||||
|
|
|
@ -27,7 +27,7 @@ class DirCrawler:
|
|||
logger.info(java_file_list)
|
||||
|
||||
for java_file in java_file_list:
|
||||
with open(java_file) as f:
|
||||
with open(java_file, encoding="utf-8") as f:
|
||||
java_code = ''.join(f.readlines())
|
||||
java_code_dict[java_file] = java_code
|
||||
return java_code_dict
|
||||
|
|
|
@ -5,6 +5,7 @@
|
|||
@time: 2023/11/21 下午2:35
|
||||
@desc:
|
||||
'''
|
||||
import json
|
||||
import time
|
||||
from loguru import logger
|
||||
from collections import defaultdict
|
||||
|
@ -15,7 +16,7 @@ from coagent.db_handler.vector_db_handler.chroma_handler import ChromaHandler
|
|||
from coagent.codechat.code_search.cypher_generator import CypherGenerator
|
||||
from coagent.codechat.code_search.tagger import Tagger
|
||||
from coagent.embeddings.get_embedding import get_embedding
|
||||
from coagent.llm_models.llm_config import LLMConfig
|
||||
from coagent.llm_models.llm_config import LLMConfig, EmbedConfig
|
||||
|
||||
|
||||
# from configs.model_config import EMBEDDING_DEVICE, EMBEDDING_MODEL
|
||||
|
@ -29,7 +30,8 @@ MAX_DISTANCE = 1000
|
|||
|
||||
|
||||
class CodeSearch:
|
||||
def __init__(self, llm_config: LLMConfig, nh: NebulaHandler, ch: ChromaHandler, limit: int = 3):
|
||||
def __init__(self, llm_config: LLMConfig, nh: NebulaHandler, ch: ChromaHandler, limit: int = 3,
|
||||
local_graph_file_path: str = ''):
|
||||
'''
|
||||
init
|
||||
@param nh: NebulaHandler
|
||||
|
@ -37,7 +39,13 @@ class CodeSearch:
|
|||
@param limit: limit of result
|
||||
'''
|
||||
self.llm_config = llm_config
|
||||
|
||||
self.nh = nh
|
||||
|
||||
if not self.nh:
|
||||
with open(local_graph_file_path, 'r') as f:
|
||||
self.graph = json.load(f)
|
||||
|
||||
self.ch = ch
|
||||
self.limit = limit
|
||||
|
||||
|
@ -51,7 +59,7 @@ class CodeSearch:
|
|||
tag_list = tagger.generate_tag_query(query)
|
||||
logger.info(f'query tag={tag_list}')
|
||||
|
||||
# get all verticex
|
||||
# get all vertices
|
||||
vertex_list = self.nh.get_vertices().get('v', [])
|
||||
vertex_vid_list = [i.as_node().get_id().as_string() for i in vertex_list]
|
||||
|
||||
|
@ -81,7 +89,7 @@ class CodeSearch:
|
|||
# get most prominent package tag
|
||||
package_score_dict = defaultdict(lambda: 0)
|
||||
|
||||
for vertex, score in vertex_score_dict.items():
|
||||
for vertex, score in vertex_score_dict_final.items():
|
||||
if '#' in vertex:
|
||||
# get class name first
|
||||
cypher = f'''MATCH (v1:class)-[e:contain]->(v2) WHERE id(v2) == '{vertex}' RETURN id(v1) as id;'''
|
||||
|
@ -111,6 +119,53 @@ class CodeSearch:
|
|||
logger.info(f'ids={ids}')
|
||||
chroma_res = self.ch.get(ids=ids, include=['metadatas'])
|
||||
|
||||
for vertex, score in package_score_tuple:
|
||||
index = chroma_res['result']['ids'].index(vertex)
|
||||
code_text = chroma_res['result']['metadatas'][index]['code_text']
|
||||
res.append({
|
||||
"vertex": vertex,
|
||||
"code_text": code_text}
|
||||
)
|
||||
if len(res) >= self.limit:
|
||||
break
|
||||
# logger.info(f'retrival code={res}')
|
||||
return res
|
||||
|
||||
def search_by_tag_by_graph(self, query: str):
|
||||
'''
|
||||
search code by tag with graph
|
||||
@param query:
|
||||
@return:
|
||||
'''
|
||||
tagger = Tagger()
|
||||
tag_list = tagger.generate_tag_query(query)
|
||||
logger.info(f'query tag={tag_list}')
|
||||
|
||||
# loop to get package node
|
||||
package_score_dict = {}
|
||||
for code, structure in self.graph.items():
|
||||
score = 0
|
||||
for class_name in structure['class_name_list']:
|
||||
for tag in tag_list:
|
||||
if tag.lower() in class_name.lower():
|
||||
score += 1
|
||||
|
||||
for func_name_list in structure['func_name_dict'].values():
|
||||
for func_name in func_name_list:
|
||||
for tag in tag_list:
|
||||
if tag.lower() in func_name.lower():
|
||||
score += 1
|
||||
package_score_dict[structure['pac_name']] = score
|
||||
|
||||
# get respective code
|
||||
res = []
|
||||
package_score_tuple = list(package_score_dict.items())
|
||||
package_score_tuple.sort(key=lambda x: x[1], reverse=True)
|
||||
|
||||
ids = [i[0] for i in package_score_tuple]
|
||||
logger.info(f'ids={ids}')
|
||||
chroma_res = self.ch.get(ids=ids, include=['metadatas'])
|
||||
|
||||
# logger.info(chroma_res)
|
||||
for vertex, score in package_score_tuple:
|
||||
index = chroma_res['result']['ids'].index(vertex)
|
||||
|
@ -121,23 +176,22 @@ class CodeSearch:
|
|||
)
|
||||
if len(res) >= self.limit:
|
||||
break
|
||||
logger.info(f'retrival code={res}')
|
||||
# logger.info(f'retrival code={res}')
|
||||
return res
|
||||
|
||||
def search_by_desciption(self, query: str, engine: str, model_path: str = "text2vec-base-chinese", embedding_device: str = "cpu"):
|
||||
def search_by_desciption(self, query: str, engine: str, model_path: str = "text2vec-base-chinese", embedding_device: str = "cpu", embed_config: EmbedConfig=None):
|
||||
'''
|
||||
search by perform sim search
|
||||
@param query:
|
||||
@return:
|
||||
'''
|
||||
query = query.replace(',', ',')
|
||||
query_emb = get_embedding(engine=engine, text_list=[query], model_path=model_path, embedding_device= embedding_device,)
|
||||
query_emb = get_embedding(engine=engine, text_list=[query], model_path=model_path, embedding_device= embedding_device, embed_config=embed_config)
|
||||
query_emb = query_emb[query]
|
||||
|
||||
query_embeddings = [query_emb]
|
||||
query_result = self.ch.query(query_embeddings=query_embeddings, n_results=self.limit,
|
||||
include=['metadatas', 'distances'])
|
||||
logger.debug(query_result)
|
||||
|
||||
res = []
|
||||
for idx, distance in enumerate(query_result['result']['distances'][0]):
|
||||
|
|
|
@ -8,7 +8,7 @@
|
|||
from langchain import PromptTemplate
|
||||
from loguru import logger
|
||||
|
||||
from coagent.llm_models.openai_model import getChatModel, getChatModelFromConfig
|
||||
from coagent.llm_models.openai_model import getChatModelFromConfig
|
||||
from coagent.llm_models.llm_config import LLMConfig
|
||||
from coagent.utils.postprocess import replace_lt_gt
|
||||
from langchain.schema import (
|
||||
|
|
|
@ -6,11 +6,10 @@
|
|||
@desc:
|
||||
'''
|
||||
import time
|
||||
import json
|
||||
import os
|
||||
from loguru import logger
|
||||
|
||||
# from configs.server_config import NEBULA_HOST, NEBULA_PORT, NEBULA_USER, NEBULA_PASSWORD, NEBULA_STORAGED_PORT
|
||||
# from configs.server_config import CHROMA_PERSISTENT_PATH
|
||||
# from configs.model_config import EMBEDDING_DEVICE, EMBEDDING_MODEL
|
||||
from coagent.db_handler.graph_db_handler.nebula_handler import NebulaHandler
|
||||
from coagent.db_handler.vector_db_handler.chroma_handler import ChromaHandler
|
||||
from coagent.embeddings.get_embedding import get_embedding
|
||||
|
@ -18,12 +17,14 @@ from coagent.llm_models.llm_config import EmbedConfig
|
|||
|
||||
|
||||
class CodeImporter:
|
||||
def __init__(self, codebase_name: str, embed_config: EmbedConfig, nh: NebulaHandler, ch: ChromaHandler):
|
||||
def __init__(self, codebase_name: str, embed_config: EmbedConfig, nh: NebulaHandler, ch: ChromaHandler,
|
||||
local_graph_file_path: str):
|
||||
self.codebase_name = codebase_name
|
||||
# self.engine = engine
|
||||
self.embed_config: EmbedConfig= embed_config
|
||||
self.embed_config: EmbedConfig = embed_config
|
||||
self.nh = nh
|
||||
self.ch = ch
|
||||
self.local_graph_file_path = local_graph_file_path
|
||||
|
||||
def import_code(self, static_analysis_res: dict, interpretation: dict, do_interpret: bool = True):
|
||||
'''
|
||||
|
@ -31,9 +32,14 @@ class CodeImporter:
|
|||
@return:
|
||||
'''
|
||||
static_analysis_res = self.filter_out_vertex(static_analysis_res, interpretation)
|
||||
logger.info(f'static_analysis_res={static_analysis_res}')
|
||||
|
||||
self.analysis_res_to_graph(static_analysis_res)
|
||||
if self.nh:
|
||||
self.analysis_res_to_graph(static_analysis_res)
|
||||
else:
|
||||
# persist to local dir
|
||||
with open(self.local_graph_file_path, 'w') as f:
|
||||
json.dump(static_analysis_res, f)
|
||||
|
||||
self.interpretation_to_db(static_analysis_res, interpretation, do_interpret)
|
||||
|
||||
def filter_out_vertex(self, static_analysis_res, interpretation):
|
||||
|
@ -114,12 +120,12 @@ class CodeImporter:
|
|||
# create vertex
|
||||
for tag_name, value_dict in vertex_value_dict.items():
|
||||
res = self.nh.insert_vertex(tag_name, value_dict)
|
||||
logger.debug(res.error_msg())
|
||||
# logger.debug(res.error_msg())
|
||||
|
||||
# create edge
|
||||
for tag_name, value_dict in edge_value_dict.items():
|
||||
res = self.nh.insert_edge(tag_name, value_dict)
|
||||
logger.debug(res.error_msg())
|
||||
# logger.debug(res.error_msg())
|
||||
|
||||
return
|
||||
|
||||
|
@ -132,7 +138,7 @@ class CodeImporter:
|
|||
if do_interpret:
|
||||
logger.info('start get embedding for interpretion')
|
||||
interp_list = list(interpretation.values())
|
||||
emb = get_embedding(engine=self.embed_config.embed_engine, text_list=interp_list, model_path=self.embed_config.embed_model_path, embedding_device= self.embed_config.model_device)
|
||||
emb = get_embedding(engine=self.embed_config.embed_engine, text_list=interp_list, model_path=self.embed_config.embed_model_path, embedding_device= self.embed_config.model_device, embed_config=self.embed_config)
|
||||
logger.info('get embedding done')
|
||||
else:
|
||||
emb = {i: [0] for i in list(interpretation.values())}
|
||||
|
@ -161,7 +167,7 @@ class CodeImporter:
|
|||
|
||||
# add documents to chroma
|
||||
res = self.ch.add_data(ids=ids, embeddings=embeddings, documents=documents, metadatas=metadatas)
|
||||
logger.debug(res)
|
||||
# logger.debug(res)
|
||||
|
||||
def init_graph(self):
|
||||
'''
|
||||
|
@ -169,7 +175,7 @@ class CodeImporter:
|
|||
@return:
|
||||
'''
|
||||
res = self.nh.create_space(space_name=self.codebase_name, vid_type='FIXED_STRING(1024)')
|
||||
logger.debug(res.error_msg())
|
||||
# logger.debug(res.error_msg())
|
||||
time.sleep(5)
|
||||
|
||||
self.nh.set_space_name(self.codebase_name)
|
||||
|
@ -179,29 +185,29 @@ class CodeImporter:
|
|||
tag_name = 'package'
|
||||
prop_dict = {}
|
||||
res = self.nh.create_tag(tag_name, prop_dict)
|
||||
logger.debug(res.error_msg())
|
||||
# logger.debug(res.error_msg())
|
||||
|
||||
tag_name = 'class'
|
||||
prop_dict = {}
|
||||
res = self.nh.create_tag(tag_name, prop_dict)
|
||||
logger.debug(res.error_msg())
|
||||
# logger.debug(res.error_msg())
|
||||
|
||||
tag_name = 'method'
|
||||
prop_dict = {}
|
||||
res = self.nh.create_tag(tag_name, prop_dict)
|
||||
logger.debug(res.error_msg())
|
||||
# logger.debug(res.error_msg())
|
||||
|
||||
# create edge type
|
||||
edge_type_name = 'contain'
|
||||
prop_dict = {}
|
||||
res = self.nh.create_edge_type(edge_type_name, prop_dict)
|
||||
logger.debug(res.error_msg())
|
||||
# logger.debug(res.error_msg())
|
||||
|
||||
# create edge type
|
||||
edge_type_name = 'depend'
|
||||
prop_dict = {}
|
||||
res = self.nh.create_edge_type(edge_type_name, prop_dict)
|
||||
logger.debug(res.error_msg())
|
||||
# logger.debug(res.error_msg())
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
@ -5,16 +5,15 @@
|
|||
@time: 2023/11/21 下午2:25
|
||||
@desc:
|
||||
'''
|
||||
import os
|
||||
import time
|
||||
import json
|
||||
from typing import List
|
||||
from loguru import logger
|
||||
|
||||
# from configs.server_config import NEBULA_HOST, NEBULA_PORT, NEBULA_USER, NEBULA_PASSWORD, NEBULA_STORAGED_PORT
|
||||
# from configs.server_config import CHROMA_PERSISTENT_PATH
|
||||
# from configs.model_config import EMBEDDING_ENGINE
|
||||
|
||||
from coagent.base_configs.env_config import (
|
||||
NEBULA_HOST, NEBULA_PORT, NEBULA_USER, NEBULA_PASSWORD, NEBULA_STORAGED_PORT,
|
||||
CHROMA_PERSISTENT_PATH
|
||||
CHROMA_PERSISTENT_PATH, CB_ROOT_PATH
|
||||
)
|
||||
|
||||
|
||||
|
@ -35,7 +34,9 @@ class CodeBaseHandler:
|
|||
language: str = 'java',
|
||||
crawl_type: str = 'ZIP',
|
||||
embed_config: EmbedConfig = EmbedConfig(),
|
||||
llm_config: LLMConfig = LLMConfig()
|
||||
llm_config: LLMConfig = LLMConfig(),
|
||||
use_nh: bool = True,
|
||||
local_graph_path: str = CB_ROOT_PATH
|
||||
):
|
||||
self.codebase_name = codebase_name
|
||||
self.code_path = code_path
|
||||
|
@ -43,11 +44,28 @@ class CodeBaseHandler:
|
|||
self.crawl_type = crawl_type
|
||||
self.embed_config = embed_config
|
||||
self.llm_config = llm_config
|
||||
self.local_graph_file_path = local_graph_path + os.sep + f'{self.codebase_name}_graph.json'
|
||||
|
||||
self.nh = NebulaHandler(host=NEBULA_HOST, port=NEBULA_PORT, username=NEBULA_USER,
|
||||
password=NEBULA_PASSWORD, space_name=codebase_name)
|
||||
self.nh.add_host(NEBULA_HOST, NEBULA_STORAGED_PORT)
|
||||
time.sleep(1)
|
||||
if use_nh:
|
||||
try:
|
||||
self.nh = NebulaHandler(host=NEBULA_HOST, port=NEBULA_PORT, username=NEBULA_USER,
|
||||
password=NEBULA_PASSWORD, space_name=codebase_name)
|
||||
self.nh.add_host(NEBULA_HOST, NEBULA_STORAGED_PORT)
|
||||
time.sleep(1)
|
||||
except:
|
||||
self.nh = None
|
||||
try:
|
||||
with open(self.local_graph_file_path, 'r') as f:
|
||||
self.graph = json.load(f)
|
||||
except:
|
||||
pass
|
||||
elif local_graph_path:
|
||||
self.nh = None
|
||||
try:
|
||||
with open(self.local_graph_file_path, 'r') as f:
|
||||
self.graph = json.load(f)
|
||||
except:
|
||||
pass
|
||||
|
||||
self.ch = ChromaHandler(path=CHROMA_PERSISTENT_PATH, collection_name=codebase_name)
|
||||
|
||||
|
@ -58,9 +76,10 @@ class CodeBaseHandler:
|
|||
'''
|
||||
# init graph to init tag and edge
|
||||
code_importer = CodeImporter(embed_config=self.embed_config, codebase_name=self.codebase_name,
|
||||
nh=self.nh, ch=self.ch)
|
||||
code_importer.init_graph()
|
||||
time.sleep(5)
|
||||
nh=self.nh, ch=self.ch, local_graph_file_path=self.local_graph_file_path)
|
||||
if self.nh:
|
||||
code_importer.init_graph()
|
||||
time.sleep(5)
|
||||
|
||||
# crawl code
|
||||
st0 = time.time()
|
||||
|
@ -71,7 +90,7 @@ class CodeBaseHandler:
|
|||
# analyze code
|
||||
logger.info('start analyze')
|
||||
st1 = time.time()
|
||||
code_analyzer = CodeAnalyzer(language=self.language, llm_config = self.llm_config)
|
||||
code_analyzer = CodeAnalyzer(language=self.language, llm_config=self.llm_config)
|
||||
static_analysis_res, interpretation = code_analyzer.analyze(code_dict, do_interpret=do_interpret)
|
||||
logger.debug('analyze done, rt={}'.format(time.time() - st1))
|
||||
|
||||
|
@ -81,8 +100,12 @@ class CodeBaseHandler:
|
|||
logger.debug('update codebase done, rt={}'.format(time.time() - st2))
|
||||
|
||||
# get KG info
|
||||
stat = self.nh.get_stat()
|
||||
vertices_num, edges_num = stat['vertices'], stat['edges']
|
||||
if self.nh:
|
||||
stat = self.nh.get_stat()
|
||||
vertices_num, edges_num = stat['vertices'], stat['edges']
|
||||
else:
|
||||
vertices_num = 0
|
||||
edges_num = 0
|
||||
|
||||
# get chroma info
|
||||
file_num = self.ch.count()['result']
|
||||
|
@ -95,7 +118,11 @@ class CodeBaseHandler:
|
|||
@param codebase_name: name of codebase
|
||||
@return:
|
||||
'''
|
||||
self.nh.drop_space(space_name=codebase_name)
|
||||
if self.nh:
|
||||
self.nh.drop_space(space_name=codebase_name)
|
||||
elif self.local_graph_file_path and os.path.isfile(self.local_graph_file_path):
|
||||
os.remove(self.local_graph_file_path)
|
||||
|
||||
self.ch.delete_collection(collection_name=codebase_name)
|
||||
|
||||
def crawl_code(self, zip_file=''):
|
||||
|
@ -124,9 +151,15 @@ class CodeBaseHandler:
|
|||
@param search_type: ['cypher', 'graph', 'vector']
|
||||
@return:
|
||||
'''
|
||||
assert search_type in ['cypher', 'tag', 'description']
|
||||
if self.nh:
|
||||
assert search_type in ['cypher', 'tag', 'description']
|
||||
else:
|
||||
if search_type == 'tag':
|
||||
search_type = 'tag_by_local_graph'
|
||||
assert search_type in ['tag_by_local_graph', 'description']
|
||||
|
||||
code_search = CodeSearch(llm_config=self.llm_config, nh=self.nh, ch=self.ch, limit=limit)
|
||||
code_search = CodeSearch(llm_config=self.llm_config, nh=self.nh, ch=self.ch, limit=limit,
|
||||
local_graph_file_path=self.local_graph_file_path)
|
||||
|
||||
if search_type == 'cypher':
|
||||
search_res = code_search.search_by_cypher(query=query)
|
||||
|
@ -134,7 +167,11 @@ class CodeBaseHandler:
|
|||
search_res = code_search.search_by_tag(query=query)
|
||||
elif search_type == 'description':
|
||||
search_res = code_search.search_by_desciption(
|
||||
query=query, engine=self.embed_config.embed_engine, model_path=self.embed_config.embed_model_path, embedding_device=self.embed_config.model_device)
|
||||
query=query, engine=self.embed_config.embed_engine, model_path=self.embed_config.embed_model_path,
|
||||
embedding_device=self.embed_config.model_device, embed_config=self.embed_config)
|
||||
elif search_type == 'tag_by_local_graph':
|
||||
search_res = code_search.search_by_tag_by_graph(query=query)
|
||||
|
||||
|
||||
context, related_vertice = self.format_search_res(search_res, search_type)
|
||||
return context, related_vertice
|
||||
|
@ -160,6 +197,12 @@ class CodeBaseHandler:
|
|||
for code in search_res:
|
||||
context = context + code['code_text'] + '\n'
|
||||
related_vertice.append(code['vertex'])
|
||||
elif search_type == 'tag_by_local_graph':
|
||||
context = ''
|
||||
related_vertice = []
|
||||
for code in search_res:
|
||||
context = context + code['code_text'] + '\n'
|
||||
related_vertice.append(code['vertex'])
|
||||
elif search_type == 'description':
|
||||
context = ''
|
||||
related_vertice = []
|
||||
|
@ -169,17 +212,63 @@ class CodeBaseHandler:
|
|||
|
||||
return context, related_vertice
|
||||
|
||||
def search_vertices(self, vertex_type="class") -> List[str]:
|
||||
'''
|
||||
通过 method/class 来搜索所有的节点
|
||||
'''
|
||||
vertices = []
|
||||
if self.nh:
|
||||
vertices = self.nh.get_all_vertices()
|
||||
vertices = [str(v.as_node().get_id()) for v in vertices["v"] if vertex_type in v.as_node().tags()]
|
||||
# for v in vertices["v"]:
|
||||
# logger.debug(f"{v.as_node().get_id()}, {v.as_node().tags()}")
|
||||
else:
|
||||
if vertex_type == "class":
|
||||
vertices = [str(class_name) for code, structure in self.graph.items() for class_name in structure['class_name_list']]
|
||||
elif vertex_type == "method":
|
||||
vertices = [
|
||||
str(methods_name)
|
||||
for code, structure in self.graph.items()
|
||||
for methods_names in structure['func_name_dict'].values()
|
||||
for methods_name in methods_names
|
||||
]
|
||||
# logger.debug(vertices)
|
||||
return vertices
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
codebase_name = 'testing'
|
||||
from configs.model_config import KB_ROOT_PATH, JUPYTER_WORK_PATH
|
||||
from configs.server_config import SANDBOX_SERVER
|
||||
|
||||
LLM_MODEL = "gpt-3.5-turbo"
|
||||
llm_config = LLMConfig(
|
||||
model_name=LLM_MODEL, model_device="cpu", api_key=os.environ["OPENAI_API_KEY"],
|
||||
api_base_url=os.environ["API_BASE_URL"], temperature=0.3
|
||||
)
|
||||
src_dir = '/Users/bingxu/Desktop/工作/大模型/chatbot/Codefuse-chatbot-antcode'
|
||||
embed_config = EmbedConfig(
|
||||
embed_engine="model", embed_model="text2vec-base-chinese",
|
||||
embed_model_path=os.path.join(src_dir, "embedding_models/text2vec-base-chinese")
|
||||
)
|
||||
|
||||
codebase_name = 'client_local'
|
||||
code_path = '/Users/bingxu/Desktop/工作/大模型/chatbot/test_code_repo/client'
|
||||
cbh = CodeBaseHandler(codebase_name, code_path, crawl_type='dir')
|
||||
use_nh = False
|
||||
local_graph_path = '/Users/bingxu/Desktop/工作/大模型/chatbot/Codefuse-chatbot-antcode/code_base'
|
||||
CHROMA_PERSISTENT_PATH = '/Users/bingxu/Desktop/工作/大模型/chatbot/Codefuse-chatbot-antcode/data/chroma_data'
|
||||
|
||||
cbh = CodeBaseHandler(codebase_name, code_path, crawl_type='dir', use_nh=use_nh, local_graph_path=local_graph_path,
|
||||
llm_config=llm_config, embed_config=embed_config)
|
||||
|
||||
# test import code
|
||||
# cbh.import_code(do_interpret=True)
|
||||
|
||||
# query = '使用不同的HTTP请求类型(GET、POST、DELETE等)来执行不同的操作'
|
||||
# query = '代码中一共有多少个类'
|
||||
# query = 'remove 这个函数是用来做什么的'
|
||||
query = '有没有函数是从字符串中删除指定字符串的功能'
|
||||
|
||||
query = 'intercept 函数作用是什么'
|
||||
search_type = 'graph'
|
||||
search_type = 'description'
|
||||
limit = 2
|
||||
res = cbh.search_code(query, search_type, limit)
|
||||
logger.debug(res)
|
||||
|
|
|
@ -0,0 +1,6 @@
|
|||
from .base_action import BaseAction
|
||||
|
||||
|
||||
__all__ = [
|
||||
"BaseAction"
|
||||
]
|
|
@ -0,0 +1,16 @@
|
|||
|
||||
from langchain.schema import BaseRetriever, Document
|
||||
|
||||
class BaseAction:
|
||||
|
||||
|
||||
def __init__(self, ):
|
||||
pass
|
||||
|
||||
def step(self, ):
|
||||
pass
|
||||
|
||||
def astep(self, ):
|
||||
pass
|
||||
|
||||
|
|
@ -4,25 +4,25 @@ import re, os
|
|||
import copy
|
||||
from loguru import logger
|
||||
|
||||
from langchain.schema import BaseRetriever
|
||||
|
||||
from coagent.connector.schema import (
|
||||
Memory, Task, Role, Message, PromptField, LogVerboseEnum
|
||||
)
|
||||
from coagent.connector.memory_manager import BaseMemoryManager
|
||||
from coagent.connector.configs.prompts import BEGIN_PROMPT_INPUT
|
||||
from coagent.connector.message_process import MessageUtils
|
||||
from coagent.llm_models import getChatModel, getExtraModel, LLMConfig, getChatModelFromConfig, EmbedConfig
|
||||
from coagent.connector.prompt_manager import PromptManager
|
||||
from coagent.llm_models import getExtraModel, LLMConfig, getChatModelFromConfig, EmbedConfig
|
||||
from coagent.connector.prompt_manager.prompt_manager import PromptManager
|
||||
from coagent.connector.memory_manager import LocalMemoryManager
|
||||
from coagent.connector.utils import parse_section
|
||||
# from configs.model_config import JUPYTER_WORK_PATH
|
||||
# from configs.server_config import SANDBOX_SERVER
|
||||
from coagent.base_configs.env_config import JUPYTER_WORK_PATH, KB_ROOT_PATH
|
||||
|
||||
|
||||
class BaseAgent:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
role: Role,
|
||||
prompt_config: [PromptField],
|
||||
prompt_config: List[PromptField],
|
||||
prompt_manager_type: str = "PromptManager",
|
||||
task: Task = None,
|
||||
memory: Memory = None,
|
||||
|
@ -33,8 +33,11 @@ class BaseAgent:
|
|||
llm_config: LLMConfig = None,
|
||||
embed_config: EmbedConfig = None,
|
||||
sandbox_server: dict = {},
|
||||
jupyter_work_path: str = "",
|
||||
kb_root_path: str = "",
|
||||
jupyter_work_path: str = JUPYTER_WORK_PATH,
|
||||
kb_root_path: str = KB_ROOT_PATH,
|
||||
doc_retrieval: Union[BaseRetriever] = None,
|
||||
code_retrieval = None,
|
||||
search_retrieval = None,
|
||||
log_verbose: str = "0"
|
||||
):
|
||||
|
||||
|
@ -43,7 +46,7 @@ class BaseAgent:
|
|||
self.sandbox_server = sandbox_server
|
||||
self.jupyter_work_path = jupyter_work_path
|
||||
self.kb_root_path = kb_root_path
|
||||
self.message_utils = MessageUtils(role, sandbox_server, jupyter_work_path, embed_config, llm_config, kb_root_path, log_verbose)
|
||||
self.message_utils = MessageUtils(role, sandbox_server, jupyter_work_path, embed_config, llm_config, kb_root_path, doc_retrieval, code_retrieval, search_retrieval, log_verbose)
|
||||
self.memory = self.init_history(memory)
|
||||
self.llm_config: LLMConfig = llm_config
|
||||
self.embed_config: EmbedConfig = embed_config
|
||||
|
@ -82,12 +85,8 @@ class BaseAgent:
|
|||
llm_config=self.embed_config
|
||||
)
|
||||
memory_manager.append(query)
|
||||
memory_pool = memory_manager.current_memory
|
||||
else:
|
||||
memory_pool = memory_manager.current_memory
|
||||
memory_pool = memory_manager.get_memory_pool(query.user_name)
|
||||
|
||||
|
||||
logger.debug(f"memory_pool: {memory_pool}")
|
||||
prompt = self.prompt_manager.generate_full_prompt(
|
||||
previous_agent_message=query_c, agent_long_term_memory=self.memory, ui_history=history, chain_summary_messages=background, memory_pool=memory_pool)
|
||||
content = self.llm.predict(prompt)
|
||||
|
@ -99,6 +98,7 @@ class BaseAgent:
|
|||
logger.info(f"{self.role.role_name} content: {content}")
|
||||
|
||||
output_message = Message(
|
||||
user_name=query.user_name,
|
||||
role_name=self.role.role_name,
|
||||
role_type="assistant", #self.role.role_type,
|
||||
role_content=content,
|
||||
|
@ -151,10 +151,7 @@ class BaseAgent:
|
|||
self.memory = self.init_history()
|
||||
|
||||
def create_llm_engine(self, llm_config: LLMConfig = None, temperature=0.2, stop=None):
|
||||
if llm_config is None:
|
||||
return getChatModel(temperature=temperature, stop=stop)
|
||||
else:
|
||||
return getChatModelFromConfig(llm_config=llm_config)
|
||||
return getChatModelFromConfig(llm_config=llm_config)
|
||||
|
||||
def registry_actions(self, actions):
|
||||
'''registry llm's actions'''
|
||||
|
@ -212,171 +209,3 @@ class BaseAgent:
|
|||
|
||||
def get_memory_str(self, content_key="role_content"):
|
||||
return "\n".join([": ".join(i) for i in self.memory.to_tuple_messages(content_key="step_content")])
|
||||
|
||||
|
||||
def create_prompt(
|
||||
self, query: Message, memory: Memory =None, history: Memory = None, background: Memory = None, memory_pool: Memory=None, prompt_mamnger=None) -> str:
|
||||
'''
|
||||
prompt engineer, contains role\task\tools\docs\memory
|
||||
'''
|
||||
#
|
||||
doc_infos = self.create_doc_prompt(query)
|
||||
code_infos = self.create_codedoc_prompt(query)
|
||||
#
|
||||
formatted_tools, tool_names, _ = self.create_tools_prompt(query)
|
||||
task_prompt = self.create_task_prompt(query)
|
||||
background_prompt = self.create_background_prompt(background, control_key="step_content")
|
||||
history_prompt = self.create_history_prompt(history)
|
||||
selfmemory_prompt = self.create_selfmemory_prompt(memory, control_key="step_content")
|
||||
|
||||
# extra_system_prompt = self.role.role_prompt
|
||||
|
||||
|
||||
prompt = self.role.role_prompt.format(**{"formatted_tools": formatted_tools, "tool_names": tool_names})
|
||||
#
|
||||
memory_pool_select_by_agent_key = self.select_memory_by_agent_key(memory_pool)
|
||||
memory_pool_select_by_agent_key_context = '\n\n'.join([f"*{k}*\n{v}" for parsed_output in memory_pool_select_by_agent_key.get_parserd_output_list() for k, v in parsed_output.items() if k not in ['Action Status']])
|
||||
|
||||
# input_query = query.input_query
|
||||
|
||||
# # logger.debug(f"{self.role.role_name} extra_system_prompt: {self.role.role_prompt}")
|
||||
# # logger.debug(f"{self.role.role_name} input_query: {input_query}")
|
||||
# # logger.debug(f"{self.role.role_name} doc_infos: {doc_infos}")
|
||||
# # logger.debug(f"{self.role.role_name} tool_names: {tool_names}")
|
||||
# if "**Context:**" in self.role.role_prompt:
|
||||
# # logger.debug(f"parsed_output_list: {query.parsed_output_list}")
|
||||
# # input_query = "'''" + "\n".join([f"###{k}###\n{v}" for i in query.parsed_output_list for k,v in i.items() if "Action Status" !=k]) + "'''"
|
||||
# context = "\n".join([f"*{k}*\n{v}" for i in query.parsed_output_list for k,v in i.items() if "Action Status" !=k])
|
||||
# # context = history_prompt or '""'
|
||||
# # logger.debug(f"parsed_output_list: {t}")
|
||||
# prompt += "\n" + QUERY_CONTEXT_PROMPT_INPUT.format(**{"context": context, "query": query.origin_query})
|
||||
# else:
|
||||
# prompt += "\n" + PLAN_PROMPT_INPUT.format(**{"query": input_query})
|
||||
|
||||
task = query.task or self.task
|
||||
if task_prompt is not None:
|
||||
prompt += "\n" + task.task_prompt
|
||||
|
||||
DocInfos = ""
|
||||
if doc_infos is not None and doc_infos!="" and doc_infos!="不存在知识库辅助信息":
|
||||
DocInfos += f"\nDocument Information: {doc_infos}"
|
||||
|
||||
if code_infos is not None and code_infos!="" and code_infos!="不存在代码库辅助信息":
|
||||
DocInfos += f"\nCodeBase Infomation: {code_infos}"
|
||||
|
||||
# if selfmemory_prompt:
|
||||
# prompt += "\n" + selfmemory_prompt
|
||||
|
||||
# if background_prompt:
|
||||
# prompt += "\n" + background_prompt
|
||||
|
||||
# if history_prompt:
|
||||
# prompt += "\n" + history_prompt
|
||||
|
||||
input_query = query.input_query
|
||||
|
||||
# logger.debug(f"{self.role.role_name} extra_system_prompt: {self.role.role_prompt}")
|
||||
# logger.debug(f"{self.role.role_name} input_query: {input_query}")
|
||||
# logger.debug(f"{self.role.role_name} doc_infos: {doc_infos}")
|
||||
# logger.debug(f"{self.role.role_name} tool_names: {tool_names}")
|
||||
|
||||
# extra_system_prompt = self.role.role_prompt
|
||||
input_keys = parse_section(self.role.role_prompt, 'Input Format')
|
||||
prompt = self.role.role_prompt.format(**{"formatted_tools": formatted_tools, "tool_names": tool_names})
|
||||
prompt += "\n" + BEGIN_PROMPT_INPUT
|
||||
for input_key in input_keys:
|
||||
if input_key == "Origin Query":
|
||||
prompt += "\n**Origin Query:**\n" + query.origin_query
|
||||
elif input_key == "Context":
|
||||
context = "\n".join([f"*{k}*\n{v}" for i in query.parsed_output_list for k,v in i.items() if "Action Status" !=k])
|
||||
if history:
|
||||
context = history_prompt + "\n" + context
|
||||
if not context:
|
||||
context = "there is no context"
|
||||
|
||||
if self.focus_agents and memory_pool_select_by_agent_key_context:
|
||||
context = memory_pool_select_by_agent_key_context
|
||||
prompt += "\n**Context:**\n" + context + "\n" + input_query
|
||||
elif input_key == "DocInfos":
|
||||
if DocInfos:
|
||||
prompt += "\n**DocInfos:**\n" + DocInfos
|
||||
else:
|
||||
prompt += "\n**DocInfos:**\n" + "Empty"
|
||||
elif input_key == "Question":
|
||||
prompt += "\n**Question:**\n" + input_query
|
||||
|
||||
# if "**Context:**" in self.role.role_prompt:
|
||||
# # logger.debug(f"parsed_output_list: {query.parsed_output_list}")
|
||||
# # input_query = "'''" + "\n".join([f"###{k}###\n{v}" for i in query.parsed_output_list for k,v in i.items() if "Action Status" !=k]) + "'''"
|
||||
# context = "\n".join([f"*{k}*\n{v}" for i in query.parsed_output_list for k,v in i.items() if "Action Status" !=k])
|
||||
# if history:
|
||||
# context = history_prompt + "\n" + context
|
||||
|
||||
# if not context:
|
||||
# context = "there is no context"
|
||||
|
||||
# # logger.debug(f"parsed_output_list: {t}")
|
||||
# if "DocInfos" in prompt:
|
||||
# prompt += "\n" + QUERY_CONTEXT_DOC_PROMPT_INPUT.format(**{"context": context, "query": query.origin_query, "DocInfos": DocInfos})
|
||||
# else:
|
||||
# prompt += "\n" + QUERY_CONTEXT_PROMPT_INPUT.format(**{"context": context, "query": query.origin_query, "DocInfos": DocInfos})
|
||||
# else:
|
||||
# prompt += "\n" + BASE_PROMPT_INPUT.format(**{"query": input_query})
|
||||
|
||||
# prompt = extra_system_prompt.format(**{"query": input_query, "doc_infos": doc_infos, "formatted_tools": formatted_tools, "tool_names": tool_names})
|
||||
while "{{" in prompt or "}}" in prompt:
|
||||
prompt = prompt.replace("{{", "{")
|
||||
prompt = prompt.replace("}}", "}")
|
||||
|
||||
# logger.debug(f"{self.role.role_name} prompt: {prompt}")
|
||||
return prompt
|
||||
|
||||
def create_doc_prompt(self, message: Message) -> str:
|
||||
''''''
|
||||
db_docs = message.db_docs
|
||||
search_docs = message.search_docs
|
||||
doc_infos = "\n".join([doc.get_snippet() for doc in db_docs] + [doc.get_snippet() for doc in search_docs])
|
||||
return doc_infos or "不存在知识库辅助信息"
|
||||
|
||||
def create_codedoc_prompt(self, message: Message) -> str:
|
||||
''''''
|
||||
code_docs = message.code_docs
|
||||
doc_infos = "\n".join([doc.get_code() for doc in code_docs])
|
||||
return doc_infos or "不存在代码库辅助信息"
|
||||
|
||||
def create_tools_prompt(self, message: Message) -> str:
|
||||
tools = message.tools
|
||||
tool_strings = []
|
||||
tools_descs = []
|
||||
for tool in tools:
|
||||
args_schema = re.sub("}", "}}}}", re.sub("{", "{{{{", str(tool.args)))
|
||||
tool_strings.append(f"{tool.name}: {tool.description}, args: {args_schema}")
|
||||
tools_descs.append(f"{tool.name}: {tool.description}")
|
||||
formatted_tools = "\n".join(tool_strings)
|
||||
tools_desc_str = "\n".join(tools_descs)
|
||||
tool_names = ", ".join([tool.name for tool in tools])
|
||||
return formatted_tools, tool_names, tools_desc_str
|
||||
|
||||
def create_task_prompt(self, message: Message) -> str:
|
||||
task = message.task or self.task
|
||||
return "\n任务目标: " + task.task_prompt if task is not None else None
|
||||
|
||||
def create_background_prompt(self, background: Memory, control_key="role_content") -> str:
|
||||
background_message = None if background is None else background.to_str_messages(content_key=control_key)
|
||||
# logger.debug(f"background_message: {background_message}")
|
||||
if background_message:
|
||||
background_message = re.sub("}", "}}", re.sub("{", "{{", background_message))
|
||||
return "\n背景信息: " + background_message if background_message else None
|
||||
|
||||
def create_history_prompt(self, history: Memory, control_key="role_content") -> str:
|
||||
history_message = None if history is None else history.to_str_messages(content_key=control_key)
|
||||
if history_message:
|
||||
history_message = re.sub("}", "}}", re.sub("{", "{{", history_message))
|
||||
return "\n补充对话信息: " + history_message if history_message else None
|
||||
|
||||
def create_selfmemory_prompt(self, selfmemory: Memory, control_key="role_content") -> str:
|
||||
selfmemory_message = None if selfmemory is None else selfmemory.to_str_messages(content_key=control_key)
|
||||
if selfmemory_message:
|
||||
selfmemory_message = re.sub("}", "}}", re.sub("{", "{{", selfmemory_message))
|
||||
return "\n补充自身对话信息: " + selfmemory_message if selfmemory_message else None
|
||||
|
|
@ -2,14 +2,15 @@ from typing import List, Union
|
|||
import copy
|
||||
from loguru import logger
|
||||
|
||||
from langchain.schema import BaseRetriever
|
||||
|
||||
from coagent.connector.schema import (
|
||||
Memory, Task, Env, Role, Message, ActionStatus, PromptField, LogVerboseEnum
|
||||
)
|
||||
from coagent.connector.memory_manager import BaseMemoryManager
|
||||
from coagent.connector.configs.prompts import BEGIN_PROMPT_INPUT
|
||||
from coagent.llm_models import LLMConfig, EmbedConfig
|
||||
from coagent.connector.memory_manager import LocalMemoryManager
|
||||
|
||||
from coagent.base_configs.env_config import JUPYTER_WORK_PATH, KB_ROOT_PATH
|
||||
from .base_agent import BaseAgent
|
||||
|
||||
|
||||
|
@ -17,7 +18,7 @@ class ExecutorAgent(BaseAgent):
|
|||
def __init__(
|
||||
self,
|
||||
role: Role,
|
||||
prompt_config: [PromptField],
|
||||
prompt_config: List[PromptField],
|
||||
prompt_manager_type: str= "PromptManager",
|
||||
task: Task = None,
|
||||
memory: Memory = None,
|
||||
|
@ -28,14 +29,17 @@ class ExecutorAgent(BaseAgent):
|
|||
llm_config: LLMConfig = None,
|
||||
embed_config: EmbedConfig = None,
|
||||
sandbox_server: dict = {},
|
||||
jupyter_work_path: str = "",
|
||||
kb_root_path: str = "",
|
||||
jupyter_work_path: str = JUPYTER_WORK_PATH,
|
||||
kb_root_path: str = KB_ROOT_PATH,
|
||||
doc_retrieval: Union[BaseRetriever] = None,
|
||||
code_retrieval = None,
|
||||
search_retrieval = None,
|
||||
log_verbose: str = "0"
|
||||
):
|
||||
|
||||
super().__init__(role, prompt_config, prompt_manager_type, task, memory, chat_turn,
|
||||
focus_agents, focus_message_keys, llm_config, embed_config, sandbox_server,
|
||||
jupyter_work_path, kb_root_path, log_verbose
|
||||
jupyter_work_path, kb_root_path, doc_retrieval, code_retrieval, search_retrieval, log_verbose
|
||||
)
|
||||
self.do_all_task = True # run all tasks
|
||||
|
||||
|
@ -45,6 +49,7 @@ class ExecutorAgent(BaseAgent):
|
|||
task_executor_memory = Memory(messages=[])
|
||||
# insert query
|
||||
output_message = Message(
|
||||
user_name=query.user_name,
|
||||
role_name=self.role.role_name,
|
||||
role_type="assistant", #self.role.role_type,
|
||||
role_content=query.input_query,
|
||||
|
@ -115,7 +120,7 @@ class ExecutorAgent(BaseAgent):
|
|||
history: Memory, background: Memory, memory_manager: BaseMemoryManager,
|
||||
task_memory: Memory) -> Union[Message, Memory]:
|
||||
'''execute the llm predict by created prompt'''
|
||||
memory_pool = memory_manager.current_memory
|
||||
memory_pool = memory_manager.get_memory_pool(query.user_name)
|
||||
prompt = self.prompt_manager.generate_full_prompt(
|
||||
previous_agent_message=query, agent_long_term_memory=self_memory, ui_history=history, chain_summary_messages=background, memory_pool=memory_pool,
|
||||
task_memory=task_memory)
|
||||
|
|
|
@ -3,23 +3,23 @@ import traceback
|
|||
import copy
|
||||
from loguru import logger
|
||||
|
||||
from langchain.schema import BaseRetriever
|
||||
|
||||
from coagent.connector.schema import (
|
||||
Memory, Task, Env, Role, Message, ActionStatus, PromptField, LogVerboseEnum
|
||||
)
|
||||
from coagent.connector.memory_manager import BaseMemoryManager
|
||||
from coagent.connector.configs.agent_config import REACT_PROMPT_INPUT
|
||||
from coagent.llm_models import LLMConfig, EmbedConfig
|
||||
from .base_agent import BaseAgent
|
||||
from coagent.connector.memory_manager import LocalMemoryManager
|
||||
|
||||
from coagent.connector.prompt_manager import PromptManager
|
||||
from coagent.base_configs.env_config import JUPYTER_WORK_PATH, KB_ROOT_PATH
|
||||
|
||||
|
||||
class ReactAgent(BaseAgent):
|
||||
def __init__(
|
||||
self,
|
||||
role: Role,
|
||||
prompt_config: [PromptField],
|
||||
prompt_config: List[PromptField],
|
||||
prompt_manager_type: str = "PromptManager",
|
||||
task: Task = None,
|
||||
memory: Memory = None,
|
||||
|
@ -30,14 +30,17 @@ class ReactAgent(BaseAgent):
|
|||
llm_config: LLMConfig = None,
|
||||
embed_config: EmbedConfig = None,
|
||||
sandbox_server: dict = {},
|
||||
jupyter_work_path: str = "",
|
||||
kb_root_path: str = "",
|
||||
jupyter_work_path: str = JUPYTER_WORK_PATH,
|
||||
kb_root_path: str = KB_ROOT_PATH,
|
||||
doc_retrieval: Union[BaseRetriever] = None,
|
||||
code_retrieval = None,
|
||||
search_retrieval = None,
|
||||
log_verbose: str = "0"
|
||||
):
|
||||
|
||||
super().__init__(role, prompt_config, prompt_manager_type, task, memory, chat_turn,
|
||||
focus_agents, focus_message_keys, llm_config, embed_config, sandbox_server,
|
||||
jupyter_work_path, kb_root_path, log_verbose
|
||||
jupyter_work_path, kb_root_path, doc_retrieval, code_retrieval, search_retrieval, log_verbose
|
||||
)
|
||||
|
||||
def step(self, query: Message, history: Memory = None, background: Memory = None, memory_manager: BaseMemoryManager = None) -> Message:
|
||||
|
@ -52,6 +55,7 @@ class ReactAgent(BaseAgent):
|
|||
react_memory = Memory(messages=[])
|
||||
# insert query
|
||||
output_message = Message(
|
||||
user_name=query.user_name,
|
||||
role_name=self.role.role_name,
|
||||
role_type="assistant", #self.role.role_type,
|
||||
role_content=query.input_query,
|
||||
|
@ -84,9 +88,7 @@ class ReactAgent(BaseAgent):
|
|||
llm_config=self.embed_config
|
||||
)
|
||||
memory_manager.append(query)
|
||||
memory_pool = memory_manager.current_memory
|
||||
else:
|
||||
memory_pool = memory_manager.current_memory
|
||||
memory_pool = memory_manager.get_memory_pool(query_c.user_name)
|
||||
|
||||
prompt = self.prompt_manager.generate_full_prompt(
|
||||
previous_agent_message=query_c, agent_long_term_memory=self.memory, ui_history=history, chain_summary_messages=background, react_memory=react_memory,
|
||||
|
@ -142,82 +144,4 @@ class ReactAgent(BaseAgent):
|
|||
title = f"<<<<{self.role.role_name}'s prompt>>>>"
|
||||
print("#"*len(title) + f"\n{title}\n"+ "#"*len(title)+ f"\n\n{prompt}\n\n")
|
||||
|
||||
# def create_prompt(
|
||||
# self, query: Message, memory: Memory =None, history: Memory = None, background: Memory = None, react_memory: Memory = None, memory_manager: BaseMemoryManager= None,
|
||||
# prompt_mamnger=None) -> str:
|
||||
# prompt_mamnger = PromptManager()
|
||||
# prompt_mamnger.register_standard_fields()
|
||||
|
||||
# # input_keys = parse_section(self.role.role_prompt, 'Agent Profile')
|
||||
|
||||
# data_dict = {
|
||||
# "agent_profile": extract_section(self.role.role_prompt, 'Agent Profile'),
|
||||
# "tool_information": query.tools,
|
||||
# "session_records": memory_manager,
|
||||
# "reference_documents": query,
|
||||
# "output_format": extract_section(self.role.role_prompt, 'Response Output Format'),
|
||||
# "response": "\n".join(["\n".join([f"**{k}:**\n{v}" for k,v in _dict.items()]) for _dict in react_memory.get_parserd_output()]),
|
||||
# }
|
||||
# # logger.debug(memory_pool)
|
||||
|
||||
# return prompt_mamnger.generate_full_prompt(data_dict)
|
||||
|
||||
def create_prompt(
|
||||
self, query: Message, memory: Memory =None, history: Memory = None, background: Memory = None, react_memory: Memory = None, memory_pool: Memory= None,
|
||||
prompt_mamnger=None) -> str:
|
||||
'''
|
||||
role\task\tools\docs\memory
|
||||
'''
|
||||
#
|
||||
doc_infos = self.create_doc_prompt(query)
|
||||
code_infos = self.create_codedoc_prompt(query)
|
||||
#
|
||||
formatted_tools, tool_names, _ = self.create_tools_prompt(query)
|
||||
task_prompt = self.create_task_prompt(query)
|
||||
background_prompt = self.create_background_prompt(background)
|
||||
history_prompt = self.create_history_prompt(history)
|
||||
selfmemory_prompt = self.create_selfmemory_prompt(memory, control_key="step_content")
|
||||
#
|
||||
# extra_system_prompt = self.role.role_prompt
|
||||
prompt = self.role.role_prompt.format(**{"formatted_tools": formatted_tools, "tool_names": tool_names})
|
||||
|
||||
# react 流程是自身迭代过程,另外二次触发的是需要作为历史对话信息
|
||||
# input_query = react_memory.to_tuple_messages(content_key="step_content")
|
||||
# # input_query = query.input_query + "\n" + "\n".join([f"{v}" for k, v in input_query if v])
|
||||
# input_query = "\n".join([f"{v}" for k, v in input_query if v])
|
||||
input_query = "\n".join(["\n".join([f"**{k}:**\n{v}" for k,v in _dict.items()]) for _dict in react_memory.get_parserd_output()])
|
||||
# logger.debug(f"input_query: {input_query}")
|
||||
|
||||
prompt += "\n" + REACT_PROMPT_INPUT.format(**{"query": input_query})
|
||||
|
||||
task = query.task or self.task
|
||||
# if task_prompt is not None:
|
||||
# prompt += "\n" + task.task_prompt
|
||||
|
||||
# if doc_infos is not None and doc_infos!="" and doc_infos!="不存在知识库辅助信息":
|
||||
# prompt += f"\n知识库信息: {doc_infos}"
|
||||
|
||||
# if code_infos is not None and code_infos!="" and code_infos!="不存在代码库辅助信息":
|
||||
# prompt += f"\n代码库信息: {code_infos}"
|
||||
|
||||
# if background_prompt:
|
||||
# prompt += "\n" + background_prompt
|
||||
|
||||
# if history_prompt:
|
||||
# prompt += "\n" + history_prompt
|
||||
|
||||
# if selfmemory_prompt:
|
||||
# prompt += "\n" + selfmemory_prompt
|
||||
|
||||
# logger.debug(f"{self.role.role_name} extra_system_prompt: {self.role.role_prompt}")
|
||||
# logger.debug(f"{self.role.role_name} input_query: {input_query}")
|
||||
# logger.debug(f"{self.role.role_name} doc_infos: {doc_infos}")
|
||||
# logger.debug(f"{self.role.role_name} tool_names: {tool_names}")
|
||||
# prompt += "\n" + REACT_PROMPT_INPUT.format(**{"query": input_query})
|
||||
|
||||
# prompt = extra_system_prompt.format(**{"query": input_query, "doc_infos": doc_infos, "formatted_tools": formatted_tools, "tool_names": tool_names})
|
||||
while "{{" in prompt or "}}" in prompt:
|
||||
prompt = prompt.replace("{{", "{")
|
||||
prompt = prompt.replace("}}", "}")
|
||||
return prompt
|
||||
|
||||
|
|
|
@ -3,13 +3,15 @@ import copy
|
|||
import random
|
||||
from loguru import logger
|
||||
|
||||
from langchain.schema import BaseRetriever
|
||||
|
||||
from coagent.connector.schema import (
|
||||
Memory, Task, Role, Message, PromptField, LogVerboseEnum
|
||||
)
|
||||
from coagent.connector.memory_manager import BaseMemoryManager
|
||||
from coagent.connector.configs.prompts import BEGIN_PROMPT_INPUT
|
||||
from coagent.connector.memory_manager import LocalMemoryManager
|
||||
from coagent.llm_models import LLMConfig, EmbedConfig
|
||||
from coagent.base_configs.env_config import JUPYTER_WORK_PATH, KB_ROOT_PATH
|
||||
from .base_agent import BaseAgent
|
||||
|
||||
|
||||
|
@ -30,14 +32,17 @@ class SelectorAgent(BaseAgent):
|
|||
llm_config: LLMConfig = None,
|
||||
embed_config: EmbedConfig = None,
|
||||
sandbox_server: dict = {},
|
||||
jupyter_work_path: str = "",
|
||||
kb_root_path: str = "",
|
||||
jupyter_work_path: str = JUPYTER_WORK_PATH,
|
||||
kb_root_path: str = KB_ROOT_PATH,
|
||||
doc_retrieval: Union[BaseRetriever] = None,
|
||||
code_retrieval = None,
|
||||
search_retrieval = None,
|
||||
log_verbose: str = "0"
|
||||
):
|
||||
|
||||
super().__init__(role, prompt_config, prompt_manager_type, task, memory, chat_turn,
|
||||
focus_agents, focus_message_keys, llm_config, embed_config, sandbox_server,
|
||||
jupyter_work_path, kb_root_path, log_verbose
|
||||
jupyter_work_path, kb_root_path, doc_retrieval, code_retrieval, search_retrieval, log_verbose
|
||||
)
|
||||
self.group_agents = group_agents
|
||||
|
||||
|
@ -56,9 +61,8 @@ class SelectorAgent(BaseAgent):
|
|||
llm_config=self.embed_config
|
||||
)
|
||||
memory_manager.append(query)
|
||||
memory_pool = memory_manager.current_memory
|
||||
else:
|
||||
memory_pool = memory_manager.current_memory
|
||||
memory_pool = memory_manager.get_memory_pool(query_c.user_name)
|
||||
|
||||
prompt = self.prompt_manager.generate_full_prompt(
|
||||
previous_agent_message=query_c, agent_long_term_memory=self.memory, ui_history=history, chain_summary_messages=background, react_memory=None,
|
||||
memory_pool=memory_pool, agents=self.group_agents)
|
||||
|
@ -90,6 +94,9 @@ class SelectorAgent(BaseAgent):
|
|||
for agent in self.group_agents:
|
||||
if agent.role.role_name == select_message.parsed_output.get("Role", ""):
|
||||
break
|
||||
|
||||
# 把除了role以外的信息传给下一个agent
|
||||
query_c.parsed_output.update({k:v for k,v in select_message.parsed_output.items() if k!="Role"})
|
||||
for output_message in agent.astep(query_c, history, background=background, memory_manager=memory_manager):
|
||||
yield output_message or select_message
|
||||
# update self_memory
|
||||
|
@ -103,6 +110,7 @@ class SelectorAgent(BaseAgent):
|
|||
memory_manager.append(output_message)
|
||||
|
||||
select_message.parsed_output = output_message.parsed_output
|
||||
select_message.spec_parsed_output.update(output_message.spec_parsed_output)
|
||||
select_message.parsed_output_list.extend(output_message.parsed_output_list)
|
||||
yield select_message
|
||||
|
||||
|
@ -114,77 +122,4 @@ class SelectorAgent(BaseAgent):
|
|||
print("#"*len(title) + f"\n{title}\n"+ "#"*len(title)+ f"\n\n{prompt}\n\n")
|
||||
|
||||
for agent in self.group_agents:
|
||||
agent.pre_print(query=query, history=history, background=background, memory_manager=memory_manager)
|
||||
|
||||
# def create_prompt(
|
||||
# self, query: Message, memory: Memory =None, history: Memory = None, background: Memory = None, memory_manager: BaseMemoryManager=None, prompt_mamnger=None) -> str:
|
||||
# '''
|
||||
# role\task\tools\docs\memory
|
||||
# '''
|
||||
# #
|
||||
# doc_infos = self.create_doc_prompt(query)
|
||||
# code_infos = self.create_codedoc_prompt(query)
|
||||
# #
|
||||
# formatted_tools, tool_names, tools_descs = self.create_tools_prompt(query)
|
||||
# agent_names, agents = self.create_agent_names()
|
||||
# task_prompt = self.create_task_prompt(query)
|
||||
# background_prompt = self.create_background_prompt(background)
|
||||
# history_prompt = self.create_history_prompt(history)
|
||||
# selfmemory_prompt = self.create_selfmemory_prompt(memory, control_key="step_content")
|
||||
|
||||
|
||||
# DocInfos = ""
|
||||
# if doc_infos is not None and doc_infos!="" and doc_infos!="不存在知识库辅助信息":
|
||||
# DocInfos += f"\nDocument Information: {doc_infos}"
|
||||
|
||||
# if code_infos is not None and code_infos!="" and code_infos!="不存在代码库辅助信息":
|
||||
# DocInfos += f"\nCodeBase Infomation: {code_infos}"
|
||||
|
||||
# input_query = query.input_query
|
||||
# logger.debug(f"{self.role.role_name} input_query: {input_query}")
|
||||
# prompt = self.role.role_prompt.format(**{"agent_names": agent_names, "agents": agents, "formatted_tools": tools_descs, "tool_names": tool_names})
|
||||
# #
|
||||
# memory_pool_select_by_agent_key = self.select_memory_by_agent_key(memory_manager.current_memory)
|
||||
# memory_pool_select_by_agent_key_context = '\n\n'.join([f"*{k}*\n{v}" for parsed_output in memory_pool_select_by_agent_key.get_parserd_output_list() for k, v in parsed_output.items() if k not in ['Action Status']])
|
||||
|
||||
# input_keys = parse_section(self.role.role_prompt, 'Input Format')
|
||||
# #
|
||||
# prompt += "\n" + BEGIN_PROMPT_INPUT
|
||||
# for input_key in input_keys:
|
||||
# if input_key == "Origin Query":
|
||||
# prompt += "\n**Origin Query:**\n" + query.origin_query
|
||||
# elif input_key == "Context":
|
||||
# context = "\n".join([f"*{k}*\n{v}" for i in query.parsed_output_list for k,v in i.items() if "Action Status" !=k])
|
||||
# if history:
|
||||
# context = history_prompt + "\n" + context
|
||||
# if not context:
|
||||
# context = "there is no context"
|
||||
|
||||
# if self.focus_agents and memory_pool_select_by_agent_key_context:
|
||||
# context = memory_pool_select_by_agent_key_context
|
||||
# prompt += "\n**Context:**\n" + context + "\n" + input_query
|
||||
# elif input_key == "DocInfos":
|
||||
# prompt += "\n**DocInfos:**\n" + DocInfos
|
||||
# elif input_key == "Question":
|
||||
# prompt += "\n**Question:**\n" + input_query
|
||||
|
||||
# while "{{" in prompt or "}}" in prompt:
|
||||
# prompt = prompt.replace("{{", "{")
|
||||
# prompt = prompt.replace("}}", "}")
|
||||
|
||||
# # logger.debug(f"{self.role.role_name} prompt: {prompt}")
|
||||
# return prompt
|
||||
|
||||
# def create_agent_names(self):
|
||||
# random.shuffle(self.group_agents)
|
||||
# agent_names = ", ".join([f'{agent.role.role_name}' for agent in self.group_agents])
|
||||
# agent_descs = []
|
||||
# for agent in self.group_agents:
|
||||
# role_desc = agent.role.role_prompt.split("####")[1]
|
||||
# while "\n\n" in role_desc:
|
||||
# role_desc = role_desc.replace("\n\n", "\n")
|
||||
# role_desc = role_desc.replace("\n", ",")
|
||||
|
||||
# agent_descs.append(f'"role name: {agent.role.role_name}\nrole description: {role_desc}"')
|
||||
|
||||
# return agent_names, "\n".join(agent_descs)
|
||||
agent.pre_print(query=query, history=history, background=background, memory_manager=memory_manager)
|
|
@ -0,0 +1,7 @@
|
|||
from .flow import AgentFlow, PhaseFlow, ChainFlow
|
||||
|
||||
|
||||
|
||||
__all__ = [
|
||||
"AgentFlow", "PhaseFlow", "ChainFlow"
|
||||
]
|
|
@ -0,0 +1,255 @@
|
|||
import importlib
|
||||
from typing import List, Union, Dict, Any
|
||||
from loguru import logger
|
||||
import os
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.agents import Tool
|
||||
from langchain.llms.base import BaseLLM, LLM
|
||||
|
||||
from coagent.retrieval.base_retrieval import IMRertrieval
|
||||
from coagent.llm_models.llm_config import EmbedConfig, LLMConfig
|
||||
from coagent.connector.phase import BasePhase
|
||||
from coagent.connector.agents import BaseAgent
|
||||
from coagent.connector.chains import BaseChain
|
||||
from coagent.connector.schema import Message, Role, PromptField, ChainConfig
|
||||
from coagent.tools import toLangchainTools, TOOL_DICT, TOOL_SETS
|
||||
|
||||
|
||||
class AgentFlow:
|
||||
def __init__(
|
||||
self,
|
||||
role_name: str,
|
||||
agent_type: str,
|
||||
role_type: str = "assistant",
|
||||
agent_index: int = 0,
|
||||
role_prompt: str = "",
|
||||
prompt_config: List[Dict[str, Any]] = [],
|
||||
prompt_manager_type: str = "PromptManager",
|
||||
chat_turn: int = 3,
|
||||
focus_agents: List[str] = [],
|
||||
focus_messages: List[str] = [],
|
||||
embeddings: Embeddings = None,
|
||||
llm: BaseLLM = None,
|
||||
doc_retrieval: IMRertrieval = None,
|
||||
code_retrieval: IMRertrieval = None,
|
||||
search_retrieval: IMRertrieval = None,
|
||||
**kwargs
|
||||
):
|
||||
self.role_type = role_type
|
||||
self.role_name = role_name
|
||||
self.agent_type = agent_type
|
||||
self.role_prompt = role_prompt
|
||||
self.agent_index = agent_index
|
||||
|
||||
self.prompt_config = prompt_config
|
||||
self.prompt_manager_type = prompt_manager_type
|
||||
|
||||
self.chat_turn = chat_turn
|
||||
self.focus_agents = focus_agents
|
||||
self.focus_messages = focus_messages
|
||||
|
||||
self.embeddings = embeddings
|
||||
self.llm = llm
|
||||
self.doc_retrieval = doc_retrieval
|
||||
self.code_retrieval = code_retrieval
|
||||
self.search_retrieval = search_retrieval
|
||||
# self.build_config()
|
||||
# self.build_agent()
|
||||
|
||||
def build_config(self, embeddings: Embeddings = None, llm: BaseLLM = None):
|
||||
self.llm_config = LLMConfig(model_name="test", llm=self.llm or llm)
|
||||
self.embed_config = EmbedConfig(embed_model="test", langchain_embeddings=self.embeddings or embeddings)
|
||||
|
||||
def build_agent(self,
|
||||
embeddings: Embeddings = None, llm: BaseLLM = None,
|
||||
doc_retrieval: IMRertrieval = None,
|
||||
code_retrieval: IMRertrieval = None,
|
||||
search_retrieval: IMRertrieval = None,
|
||||
):
|
||||
# 可注册个性化的agent,仅通过start_action和end_action来注册
|
||||
# class ExtraAgent(BaseAgent):
|
||||
# def start_action_step(self, message: Message) -> Message:
|
||||
# pass
|
||||
|
||||
# def end_action_step(self, message: Message) -> Message:
|
||||
# pass
|
||||
# agent_module = importlib.import_module("coagent.connector.agents")
|
||||
# setattr(agent_module, 'extraAgent', ExtraAgent)
|
||||
|
||||
# 可注册个性化的prompt组装方式,
|
||||
# class CodeRetrievalPM(PromptManager):
|
||||
# def handle_code_packages(self, **kwargs) -> str:
|
||||
# if 'previous_agent_message' not in kwargs:
|
||||
# return ""
|
||||
# previous_agent_message: Message = kwargs['previous_agent_message']
|
||||
# # 由于两个agent共用了同一个manager,所以临时性处理
|
||||
# vertices = previous_agent_message.customed_kargs.get("RelatedVerticesRetrivalRes", {}).get("vertices", [])
|
||||
# return ", ".join([str(v) for v in vertices])
|
||||
|
||||
# prompt_manager_module = importlib.import_module("coagent.connector.prompt_manager")
|
||||
# setattr(prompt_manager_module, 'CodeRetrievalPM', CodeRetrievalPM)
|
||||
|
||||
# agent实例化
|
||||
agent_module = importlib.import_module("coagent.connector.agents")
|
||||
baseAgent: BaseAgent = getattr(agent_module, self.agent_type)
|
||||
role = Role(
|
||||
role_type=self.agent_type, role_name=self.role_name,
|
||||
agent_type=self.agent_type, role_prompt=self.role_prompt,
|
||||
)
|
||||
|
||||
self.build_config(embeddings, llm)
|
||||
self.agent = baseAgent(
|
||||
role=role,
|
||||
prompt_config = [PromptField(**config) for config in self.prompt_config],
|
||||
prompt_manager_type=self.prompt_manager_type,
|
||||
chat_turn=self.chat_turn,
|
||||
focus_agents=self.focus_agents,
|
||||
focus_message_keys=self.focus_messages,
|
||||
llm_config=self.llm_config,
|
||||
embed_config=self.embed_config,
|
||||
doc_retrieval=doc_retrieval or self.doc_retrieval,
|
||||
code_retrieval=code_retrieval or self.code_retrieval,
|
||||
search_retrieval=search_retrieval or self.search_retrieval,
|
||||
)
|
||||
|
||||
|
||||
|
||||
class ChainFlow:
|
||||
def __init__(
|
||||
self,
|
||||
chain_name: str,
|
||||
chain_index: int = 0,
|
||||
agent_flows: List[AgentFlow] = [],
|
||||
chat_turn: int = 5,
|
||||
do_checker: bool = False,
|
||||
embeddings: Embeddings = None,
|
||||
llm: BaseLLM = None,
|
||||
doc_retrieval: IMRertrieval = None,
|
||||
code_retrieval: IMRertrieval = None,
|
||||
search_retrieval: IMRertrieval = None,
|
||||
# chain_type: str = "BaseChain",
|
||||
**kwargs
|
||||
):
|
||||
self.agent_flows = sorted(agent_flows, key=lambda x:x.agent_index)
|
||||
self.chat_turn = chat_turn
|
||||
self.do_checker = do_checker
|
||||
self.chain_name = chain_name
|
||||
self.chain_index = chain_index
|
||||
self.chain_type = "BaseChain"
|
||||
|
||||
self.embeddings = embeddings
|
||||
self.llm = llm
|
||||
|
||||
self.doc_retrieval = doc_retrieval
|
||||
self.code_retrieval = code_retrieval
|
||||
self.search_retrieval = search_retrieval
|
||||
# self.build_config()
|
||||
# self.build_chain()
|
||||
|
||||
def build_config(self, embeddings: Embeddings = None, llm: BaseLLM = None):
|
||||
self.llm_config = LLMConfig(model_name="test", llm=self.llm or llm)
|
||||
self.embed_config = EmbedConfig(embed_model="test", langchain_embeddings=self.embeddings or embeddings)
|
||||
|
||||
def build_chain(self,
|
||||
embeddings: Embeddings = None, llm: BaseLLM = None,
|
||||
doc_retrieval: IMRertrieval = None,
|
||||
code_retrieval: IMRertrieval = None,
|
||||
search_retrieval: IMRertrieval = None,
|
||||
):
|
||||
# chain 实例化
|
||||
chain_module = importlib.import_module("coagent.connector.chains")
|
||||
baseChain: BaseChain = getattr(chain_module, self.chain_type)
|
||||
|
||||
agent_names = [agent_flow.role_name for agent_flow in self.agent_flows]
|
||||
chain_config = ChainConfig(chain_name=self.chain_name, agents=agent_names, do_checker=self.do_checker, chat_turn=self.chat_turn)
|
||||
|
||||
# agent 实例化
|
||||
self.build_config(embeddings, llm)
|
||||
for agent_flow in self.agent_flows:
|
||||
agent_flow.build_agent(embeddings, llm)
|
||||
|
||||
self.chain = baseChain(
|
||||
chain_config,
|
||||
[agent_flow.agent for agent_flow in self.agent_flows],
|
||||
embed_config=self.embed_config,
|
||||
llm_config=self.llm_config,
|
||||
doc_retrieval=doc_retrieval or self.doc_retrieval,
|
||||
code_retrieval=code_retrieval or self.code_retrieval,
|
||||
search_retrieval=search_retrieval or self.search_retrieval,
|
||||
)
|
||||
|
||||
class PhaseFlow:
|
||||
def __init__(
|
||||
self,
|
||||
phase_name: str,
|
||||
chain_flows: List[ChainFlow],
|
||||
embeddings: Embeddings = None,
|
||||
llm: BaseLLM = None,
|
||||
tools: List[Tool] = [],
|
||||
doc_retrieval: IMRertrieval = None,
|
||||
code_retrieval: IMRertrieval = None,
|
||||
search_retrieval: IMRertrieval = None,
|
||||
**kwargs
|
||||
):
|
||||
self.phase_name = phase_name
|
||||
self.chain_flows = sorted(chain_flows, key=lambda x:x.chain_index)
|
||||
self.phase_type = "BasePhase"
|
||||
self.tools = tools
|
||||
|
||||
self.embeddings = embeddings
|
||||
self.llm = llm
|
||||
|
||||
self.doc_retrieval = doc_retrieval
|
||||
self.code_retrieval = code_retrieval
|
||||
self.search_retrieval = search_retrieval
|
||||
# self.build_config()
|
||||
self.build_phase()
|
||||
|
||||
def __call__(self, params: dict) -> str:
|
||||
|
||||
# tools = toLangchainTools([TOOL_DICT[i] for i in TOOL_SETS if i in TOOL_DICT])
|
||||
# query_content = "帮我确认下127.0.0.1这个服务器的在10点是否存在异常,请帮我判断一下"
|
||||
try:
|
||||
logger.info(f"params: {params}")
|
||||
query_content = params.get("query") or params.get("input")
|
||||
search_type = params.get("search_type")
|
||||
query = Message(
|
||||
role_name="human", role_type="user", tools=self.tools,
|
||||
role_content=query_content, input_query=query_content, origin_query=query_content,
|
||||
cb_search_type=search_type,
|
||||
)
|
||||
# phase.pre_print(query)
|
||||
output_message, output_memory = self.phase.step(query)
|
||||
output_content = "\n\n".join((output_memory.to_str_messages(return_all=True, content_key="parsed_output_list").split("\n\n")[1:])) or output_message.role_content
|
||||
return output_content
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
return f"Error {e}"
|
||||
|
||||
def build_config(self, embeddings: Embeddings = None, llm: BaseLLM = None):
|
||||
self.llm_config = LLMConfig(model_name="test", llm=self.llm or llm)
|
||||
self.embed_config = EmbedConfig(embed_model="test", langchain_embeddings=self.embeddings or embeddings)
|
||||
|
||||
def build_phase(self, embeddings: Embeddings = None, llm: BaseLLM = None):
|
||||
# phase 实例化
|
||||
phase_module = importlib.import_module("coagent.connector.phase")
|
||||
basePhase: BasePhase = getattr(phase_module, self.phase_type)
|
||||
|
||||
# chain 实例化
|
||||
self.build_config(self.embeddings or embeddings, self.llm or llm)
|
||||
os.environ["log_verbose"] = "2"
|
||||
for chain_flow in self.chain_flows:
|
||||
chain_flow.build_chain(
|
||||
self.embeddings or embeddings, self.llm or llm,
|
||||
self.doc_retrieval, self.code_retrieval, self.search_retrieval
|
||||
)
|
||||
|
||||
self.phase: BasePhase = basePhase(
|
||||
phase_name=self.phase_name,
|
||||
chains=[chain_flow.chain for chain_flow in self.chain_flows],
|
||||
embed_config=self.embed_config,
|
||||
llm_config=self.llm_config,
|
||||
doc_retrieval=self.doc_retrieval,
|
||||
code_retrieval=self.code_retrieval,
|
||||
search_retrieval=self.search_retrieval
|
||||
)
|
|
@ -1,9 +1,10 @@
|
|||
from typing import List
|
||||
from typing import List, Tuple, Union
|
||||
from loguru import logger
|
||||
import copy, os
|
||||
|
||||
from coagent.connector.agents import BaseAgent
|
||||
from langchain.schema import BaseRetriever
|
||||
|
||||
from coagent.connector.agents import BaseAgent
|
||||
from coagent.connector.schema import (
|
||||
Memory, Role, Message, ActionStatus, ChainConfig,
|
||||
load_role_configs
|
||||
|
@ -11,31 +12,32 @@ from coagent.connector.schema import (
|
|||
from coagent.connector.memory_manager import BaseMemoryManager
|
||||
from coagent.connector.message_process import MessageUtils
|
||||
from coagent.llm_models.llm_config import LLMConfig, EmbedConfig
|
||||
from coagent.base_configs.env_config import JUPYTER_WORK_PATH, KB_ROOT_PATH
|
||||
from coagent.connector.configs.agent_config import AGETN_CONFIGS
|
||||
role_configs = load_role_configs(AGETN_CONFIGS)
|
||||
|
||||
# from configs.model_config import JUPYTER_WORK_PATH
|
||||
# from configs.server_config import SANDBOX_SERVER
|
||||
|
||||
|
||||
class BaseChain:
|
||||
def __init__(
|
||||
self,
|
||||
# chainConfig: ChainConfig,
|
||||
chainConfig: ChainConfig,
|
||||
agents: List[BaseAgent],
|
||||
chat_turn: int = 1,
|
||||
do_checker: bool = False,
|
||||
# chat_turn: int = 1,
|
||||
# do_checker: bool = False,
|
||||
sandbox_server: dict = {},
|
||||
jupyter_work_path: str = "",
|
||||
kb_root_path: str = "",
|
||||
jupyter_work_path: str = JUPYTER_WORK_PATH,
|
||||
kb_root_path: str = KB_ROOT_PATH,
|
||||
llm_config: LLMConfig = LLMConfig(),
|
||||
embed_config: EmbedConfig = None,
|
||||
doc_retrieval: Union[BaseRetriever] = None,
|
||||
code_retrieval = None,
|
||||
search_retrieval = None,
|
||||
log_verbose: str = "0"
|
||||
) -> None:
|
||||
# self.chainConfig = chainConfig
|
||||
self.chainConfig = chainConfig
|
||||
self.agents: List[BaseAgent] = agents
|
||||
self.chat_turn = chat_turn
|
||||
self.do_checker = do_checker
|
||||
self.chat_turn = chainConfig.chat_turn
|
||||
self.do_checker = chainConfig.do_checker
|
||||
self.sandbox_server = sandbox_server
|
||||
self.jupyter_work_path = jupyter_work_path
|
||||
self.llm_config = llm_config
|
||||
|
@ -45,9 +47,11 @@ class BaseChain:
|
|||
task = None, memory = None,
|
||||
llm_config=llm_config, embed_config=embed_config,
|
||||
sandbox_server=sandbox_server, jupyter_work_path=jupyter_work_path,
|
||||
kb_root_path=kb_root_path
|
||||
kb_root_path=kb_root_path,
|
||||
doc_retrieval=doc_retrieval, code_retrieval=code_retrieval,
|
||||
search_retrieval=search_retrieval
|
||||
)
|
||||
self.messageUtils = MessageUtils(None, sandbox_server, self.jupyter_work_path, embed_config, llm_config, kb_root_path, log_verbose)
|
||||
self.messageUtils = MessageUtils(None, sandbox_server, self.jupyter_work_path, embed_config, llm_config, kb_root_path, doc_retrieval, code_retrieval, search_retrieval, log_verbose)
|
||||
# all memory created by agent until instance deleted
|
||||
self.global_memory = Memory(messages=[])
|
||||
|
||||
|
@ -62,13 +66,16 @@ class BaseChain:
|
|||
for agent in self.agents:
|
||||
agent.pre_print(query, history, background=background, memory_manager=memory_manager)
|
||||
|
||||
def astep(self, query: Message, history: Memory = None, background: Memory = None, memory_manager: BaseMemoryManager = None) -> Message:
|
||||
def astep(self, query: Message, history: Memory = None, background: Memory = None, memory_manager: BaseMemoryManager = None) -> Tuple[Message, Memory]:
|
||||
'''execute chain'''
|
||||
local_memory = Memory(messages=[])
|
||||
input_message = copy.deepcopy(query)
|
||||
step_nums = copy.deepcopy(self.chat_turn)
|
||||
check_message = None
|
||||
|
||||
# if input_message not in memory_manager:
|
||||
# memory_manager.append(input_message)
|
||||
|
||||
self.global_memory.append(input_message)
|
||||
# local_memory.append(input_message)
|
||||
while step_nums > 0:
|
||||
|
@ -78,7 +85,7 @@ class BaseChain:
|
|||
yield output_message, local_memory + output_message
|
||||
output_message = self.messageUtils.inherit_extrainfo(input_message, output_message)
|
||||
# according the output to choose one action for code_content or tool_content
|
||||
output_message = self.messageUtils.parser(output_message)
|
||||
# output_message = self.messageUtils.parser(output_message)
|
||||
yield output_message, local_memory + output_message
|
||||
# output_message = self.step_router(output_message)
|
||||
input_message = output_message
|
||||
|
|
|
@ -1,9 +1,10 @@
|
|||
from .agent_config import AGETN_CONFIGS
|
||||
from .chain_config import CHAIN_CONFIGS
|
||||
from .phase_config import PHASE_CONFIGS
|
||||
from .prompt_config import BASE_PROMPT_CONFIGS, EXECUTOR_PROMPT_CONFIGS, SELECTOR_PROMPT_CONFIGS, BASE_NOTOOLPROMPT_CONFIGS
|
||||
from .prompt_config import *
|
||||
|
||||
__all__ = [
|
||||
"AGETN_CONFIGS", "CHAIN_CONFIGS", "PHASE_CONFIGS",
|
||||
"BASE_PROMPT_CONFIGS", "EXECUTOR_PROMPT_CONFIGS", "SELECTOR_PROMPT_CONFIGS", "BASE_NOTOOLPROMPT_CONFIGS"
|
||||
"BASE_PROMPT_CONFIGS", "EXECUTOR_PROMPT_CONFIGS", "SELECTOR_PROMPT_CONFIGS", "BASE_NOTOOLPROMPT_CONFIGS",
|
||||
"CODE2DOC_GROUP_PROMPT_CONFIGS", "CODE2DOC_PROMPT_CONFIGS", "CODE2TESTS_PROMPT_CONFIGS"
|
||||
]
|
|
@ -1,19 +1,21 @@
|
|||
from enum import Enum
|
||||
from .prompts import (
|
||||
REACT_PROMPT_INPUT, CHECK_PROMPT_INPUT, EXECUTOR_PROMPT_INPUT, CONTEXT_PROMPT_INPUT, QUERY_CONTEXT_PROMPT_INPUT,PLAN_PROMPT_INPUT,
|
||||
RECOGNIZE_INTENTION_PROMPT,
|
||||
CHECKER_TEMPLATE_PROMPT,
|
||||
CONV_SUMMARY_PROMPT,
|
||||
QA_PROMPT, CODE_QA_PROMPT, QA_TEMPLATE_PROMPT,
|
||||
EXECUTOR_TEMPLATE_PROMPT,
|
||||
REFINE_TEMPLATE_PROMPT,
|
||||
SELECTOR_AGENT_TEMPLATE_PROMPT,
|
||||
PLANNER_TEMPLATE_PROMPT, GENERAL_PLANNER_PROMPT, DATA_PLANNER_PROMPT, TOOL_PLANNER_PROMPT,
|
||||
PRD_WRITER_METAGPT_PROMPT, DESIGN_WRITER_METAGPT_PROMPT, TASK_WRITER_METAGPT_PROMPT, CODE_WRITER_METAGPT_PROMPT,
|
||||
REACT_TEMPLATE_PROMPT,
|
||||
REACT_TOOL_PROMPT, REACT_CODE_PROMPT, REACT_TOOL_AND_CODE_PLANNER_PROMPT, REACT_TOOL_AND_CODE_PROMPT
|
||||
)
|
||||
from .prompt_config import BASE_PROMPT_CONFIGS, EXECUTOR_PROMPT_CONFIGS, SELECTOR_PROMPT_CONFIGS, BASE_NOTOOLPROMPT_CONFIGS
|
||||
from .prompts import *
|
||||
# from .prompts import (
|
||||
# REACT_PROMPT_INPUT, CHECK_PROMPT_INPUT, EXECUTOR_PROMPT_INPUT, CONTEXT_PROMPT_INPUT, QUERY_CONTEXT_PROMPT_INPUT,PLAN_PROMPT_INPUT,
|
||||
# RECOGNIZE_INTENTION_PROMPT,
|
||||
# CHECKER_TEMPLATE_PROMPT,
|
||||
# CONV_SUMMARY_PROMPT,
|
||||
# QA_PROMPT, CODE_QA_PROMPT, QA_TEMPLATE_PROMPT,
|
||||
# EXECUTOR_TEMPLATE_PROMPT,
|
||||
# REFINE_TEMPLATE_PROMPT,
|
||||
# SELECTOR_AGENT_TEMPLATE_PROMPT,
|
||||
# PLANNER_TEMPLATE_PROMPT, GENERAL_PLANNER_PROMPT, DATA_PLANNER_PROMPT, TOOL_PLANNER_PROMPT,
|
||||
# PRD_WRITER_METAGPT_PROMPT, DESIGN_WRITER_METAGPT_PROMPT, TASK_WRITER_METAGPT_PROMPT, CODE_WRITER_METAGPT_PROMPT,
|
||||
# REACT_TEMPLATE_PROMPT,
|
||||
# REACT_TOOL_PROMPT, REACT_CODE_PROMPT, REACT_TOOL_AND_CODE_PLANNER_PROMPT, REACT_TOOL_AND_CODE_PROMPT
|
||||
# )
|
||||
from .prompt_config import *
|
||||
# BASE_PROMPT_CONFIGS, EXECUTOR_PROMPT_CONFIGS, SELECTOR_PROMPT_CONFIGS, BASE_NOTOOLPROMPT_CONFIGS
|
||||
|
||||
|
||||
|
||||
|
@ -261,4 +263,68 @@ AGETN_CONFIGS = {
|
|||
"focus_agents": ["metaGPT_DESIGN", "metaGPT_TASK"],
|
||||
"focus_message_keys": [],
|
||||
},
|
||||
"class2Docer": {
|
||||
"role": {
|
||||
"role_prompt": Class2Doc_PROMPT,
|
||||
"role_type": "assistant",
|
||||
"role_name": "class2Docer",
|
||||
"role_desc": "",
|
||||
"agent_type": "CodeGenDocer"
|
||||
},
|
||||
"prompt_config": CODE2DOC_PROMPT_CONFIGS,
|
||||
"prompt_manager_type": "Code2DocPM",
|
||||
"chat_turn": 1,
|
||||
"focus_agents": [],
|
||||
"focus_message_keys": [],
|
||||
},
|
||||
"func2Docer": {
|
||||
"role": {
|
||||
"role_prompt": Func2Doc_PROMPT,
|
||||
"role_type": "assistant",
|
||||
"role_name": "func2Docer",
|
||||
"role_desc": "",
|
||||
"agent_type": "CodeGenDocer"
|
||||
},
|
||||
"prompt_config": CODE2DOC_PROMPT_CONFIGS,
|
||||
"prompt_manager_type": "Code2DocPM",
|
||||
"chat_turn": 1,
|
||||
"focus_agents": [],
|
||||
"focus_message_keys": [],
|
||||
},
|
||||
"code2DocsGrouper": {
|
||||
"role": {
|
||||
"role_prompt": Code2DocGroup_PROMPT,
|
||||
"role_type": "assistant",
|
||||
"role_name": "code2DocsGrouper",
|
||||
"role_desc": "",
|
||||
"agent_type": "SelectorAgent"
|
||||
},
|
||||
"prompt_config": CODE2DOC_GROUP_PROMPT_CONFIGS,
|
||||
"group_agents": ["class2Docer", "func2Docer"],
|
||||
"chat_turn": 1,
|
||||
},
|
||||
"Code2TestJudger": {
|
||||
"role": {
|
||||
"role_prompt": judgeCode2Tests_PROMPT,
|
||||
"role_type": "assistant",
|
||||
"role_name": "Code2TestJudger",
|
||||
"role_desc": "",
|
||||
"agent_type": "CodeRetrieval"
|
||||
},
|
||||
"prompt_config": CODE2TESTS_PROMPT_CONFIGS,
|
||||
"prompt_manager_type": "CodeRetrievalPM",
|
||||
"chat_turn": 1,
|
||||
},
|
||||
"code2Tests": {
|
||||
"role": {
|
||||
"role_prompt": code2Tests_PROMPT,
|
||||
"role_type": "assistant",
|
||||
"role_name": "code2Tests",
|
||||
"role_desc": "",
|
||||
"agent_type": "CodeRetrieval"
|
||||
},
|
||||
"prompt_config": CODE2TESTS_PROMPT_CONFIGS,
|
||||
"prompt_manager_type": "CodeRetrievalPM",
|
||||
"chat_turn": 1,
|
||||
},
|
||||
}
|
|
@ -123,5 +123,21 @@ CHAIN_CONFIGS = {
|
|||
"chat_turn": 1,
|
||||
"do_checker": False,
|
||||
"chain_prompt": ""
|
||||
},
|
||||
"code2DocsGroupChain": {
|
||||
"chain_name": "code2DocsGroupChain",
|
||||
"chain_type": "BaseChain",
|
||||
"agents": ["code2DocsGrouper"],
|
||||
"chat_turn": 1,
|
||||
"do_checker": False,
|
||||
"chain_prompt": ""
|
||||
},
|
||||
"code2TestsChain": {
|
||||
"chain_name": "code2TestsChain",
|
||||
"chain_type": "BaseChain",
|
||||
"agents": ["Code2TestJudger", "code2Tests"],
|
||||
"chat_turn": 1,
|
||||
"do_checker": False,
|
||||
"chain_prompt": ""
|
||||
}
|
||||
}
|
||||
|
|
|
@ -14,44 +14,24 @@ PHASE_CONFIGS = {
|
|||
"phase_name": "docChatPhase",
|
||||
"phase_type": "BasePhase",
|
||||
"chains": ["docChatChain"],
|
||||
"do_summary": False,
|
||||
"do_search": False,
|
||||
"do_doc_retrieval": True,
|
||||
"do_code_retrieval": False,
|
||||
"do_tool_retrieval": False,
|
||||
"do_using_tool": False
|
||||
},
|
||||
"searchChatPhase": {
|
||||
"phase_name": "searchChatPhase",
|
||||
"phase_type": "BasePhase",
|
||||
"chains": ["searchChatChain"],
|
||||
"do_summary": False,
|
||||
"do_search": True,
|
||||
"do_doc_retrieval": False,
|
||||
"do_code_retrieval": False,
|
||||
"do_tool_retrieval": False,
|
||||
"do_using_tool": False
|
||||
},
|
||||
"codeChatPhase": {
|
||||
"phase_name": "codeChatPhase",
|
||||
"phase_type": "BasePhase",
|
||||
"chains": ["codeChatChain"],
|
||||
"do_summary": False,
|
||||
"do_search": False,
|
||||
"do_doc_retrieval": False,
|
||||
"do_code_retrieval": True,
|
||||
"do_tool_retrieval": False,
|
||||
"do_using_tool": False
|
||||
},
|
||||
"toolReactPhase": {
|
||||
"phase_name": "toolReactPhase",
|
||||
"phase_type": "BasePhase",
|
||||
"chains": ["toolReactChain"],
|
||||
"do_summary": False,
|
||||
"do_search": False,
|
||||
"do_doc_retrieval": False,
|
||||
"do_code_retrieval": False,
|
||||
"do_tool_retrieval": False,
|
||||
"do_using_tool": True
|
||||
},
|
||||
"codeReactPhase": {
|
||||
|
@ -59,55 +39,36 @@ PHASE_CONFIGS = {
|
|||
"phase_type": "BasePhase",
|
||||
# "chains": ["codePlannerChain", "codeReactChain"],
|
||||
"chains": ["planChain", "codeReactChain"],
|
||||
"do_summary": False,
|
||||
"do_search": False,
|
||||
"do_doc_retrieval": False,
|
||||
"do_code_retrieval": False,
|
||||
"do_tool_retrieval": False,
|
||||
"do_using_tool": False
|
||||
},
|
||||
"codeToolReactPhase": {
|
||||
"phase_name": "codeToolReactPhase",
|
||||
"phase_type": "BasePhase",
|
||||
"chains": ["codeToolPlanChain", "codeToolReactChain"],
|
||||
"do_summary": False,
|
||||
"do_search": False,
|
||||
"do_doc_retrieval": False,
|
||||
"do_code_retrieval": False,
|
||||
"do_tool_retrieval": False,
|
||||
"do_using_tool": True
|
||||
},
|
||||
"baseTaskPhase": {
|
||||
"phase_name": "baseTaskPhase",
|
||||
"phase_type": "BasePhase",
|
||||
"chains": ["planChain", "executorChain"],
|
||||
"do_summary": False,
|
||||
"do_search": False,
|
||||
"do_doc_retrieval": False,
|
||||
"do_code_retrieval": False,
|
||||
"do_tool_retrieval": False,
|
||||
"do_using_tool": False
|
||||
},
|
||||
"metagpt_code_devlop": {
|
||||
"phase_name": "metagpt_code_devlop",
|
||||
"phase_type": "BasePhase",
|
||||
"chains": ["metagptChain",],
|
||||
"do_summary": False,
|
||||
"do_search": False,
|
||||
"do_doc_retrieval": False,
|
||||
"do_code_retrieval": False,
|
||||
"do_tool_retrieval": False,
|
||||
"do_using_tool": False
|
||||
},
|
||||
"baseGroupPhase": {
|
||||
"phase_name": "baseGroupPhase",
|
||||
"phase_type": "BasePhase",
|
||||
"chains": ["baseGroupChain"],
|
||||
"do_summary": False,
|
||||
"do_search": False,
|
||||
"do_doc_retrieval": False,
|
||||
"do_code_retrieval": False,
|
||||
"do_tool_retrieval": False,
|
||||
"do_using_tool": False
|
||||
},
|
||||
"code2DocsGroup": {
|
||||
"phase_name": "code2DocsGroup",
|
||||
"phase_type": "BasePhase",
|
||||
"chains": ["code2DocsGroupChain"],
|
||||
},
|
||||
"code2Tests": {
|
||||
"phase_name": "code2Tests",
|
||||
"phase_type": "BasePhase",
|
||||
"chains": ["code2TestsChain"],
|
||||
}
|
||||
}
|
||||
|
|
|
@ -40,4 +40,41 @@ SELECTOR_PROMPT_CONFIGS = [
|
|||
{"field_name": 'current_plan', "function_name": 'handle_current_plan'},
|
||||
{"field_name": 'output_format', "function_name": 'handle_output_format', 'title': 'Response Output Format', "is_context": False},
|
||||
{"field_name": 'begin!!!', "function_name": 'handle_response', "is_context": False, "omit_if_empty": False}
|
||||
]
|
||||
]
|
||||
|
||||
|
||||
CODE2DOC_GROUP_PROMPT_CONFIGS = [
|
||||
{"field_name": 'agent_profile', "function_name": 'handle_agent_profile', "is_context": False},
|
||||
{"field_name": 'agent_infomation', "function_name": 'handle_agent_data', "is_context": False, "omit_if_empty": False},
|
||||
# {"field_name": 'tool_information',"function_name": 'handle_tool_data', "is_context": False},
|
||||
{"field_name": 'context_placeholder', "function_name": '', "is_context": True},
|
||||
# {"field_name": 'reference_documents', "function_name": 'handle_doc_info'},
|
||||
{"field_name": 'session_records', "function_name": 'handle_session_records'},
|
||||
{"field_name": 'Specific Objective', "function_name": 'handle_specific_objective'},
|
||||
{"field_name": 'Code Snippet', "function_name": 'handle_code_snippet'},
|
||||
{"field_name": 'output_format', "function_name": 'handle_output_format', 'title': 'Response Output Format', "is_context": False},
|
||||
{"field_name": 'begin!!!', "function_name": 'handle_response', "is_context": False, "omit_if_empty": False}
|
||||
]
|
||||
|
||||
CODE2DOC_PROMPT_CONFIGS = [
|
||||
{"field_name": 'agent_profile', "function_name": 'handle_agent_profile', "is_context": False},
|
||||
# {"field_name": 'tool_information',"function_name": 'handle_tool_data', "is_context": False},
|
||||
{"field_name": 'context_placeholder', "function_name": '', "is_context": True},
|
||||
# {"field_name": 'reference_documents', "function_name": 'handle_doc_info'},
|
||||
{"field_name": 'session_records', "function_name": 'handle_session_records'},
|
||||
{"field_name": 'Specific Objective', "function_name": 'handle_specific_objective'},
|
||||
{"field_name": 'Code Snippet', "function_name": 'handle_code_snippet'},
|
||||
{"field_name": 'output_format', "function_name": 'handle_output_format', 'title': 'Response Output Format', "is_context": False},
|
||||
{"field_name": 'begin!!!', "function_name": 'handle_response', "is_context": False, "omit_if_empty": False}
|
||||
]
|
||||
|
||||
|
||||
CODE2TESTS_PROMPT_CONFIGS = [
|
||||
{"field_name": 'agent_profile', "function_name": 'handle_agent_profile', "is_context": False},
|
||||
{"field_name": 'context_placeholder', "function_name": '', "is_context": True},
|
||||
{"field_name": 'session_records', "function_name": 'handle_session_records'},
|
||||
{"field_name": 'code_snippet', "function_name": 'handle_code_snippet'},
|
||||
{"field_name": 'retrieval_codes', "function_name": 'handle_retrieval_codes', "description": ""},
|
||||
{"field_name": 'output_format', "function_name": 'handle_output_format', 'title': 'Response Output Format', "is_context": False},
|
||||
{"field_name": 'begin!!!', "function_name": 'handle_response', "is_context": False, "omit_if_empty": False}
|
||||
]
|
|
@ -14,7 +14,8 @@ from .qa_template_prompt import QA_PROMPT, CODE_QA_PROMPT, QA_TEMPLATE_PROMPT, C
|
|||
|
||||
from .executor_template_prompt import EXECUTOR_TEMPLATE_PROMPT
|
||||
from .refine_template_prompt import REFINE_TEMPLATE_PROMPT
|
||||
|
||||
from .code2doc_template_prompt import Code2DocGroup_PROMPT, Class2Doc_PROMPT, Func2Doc_PROMPT
|
||||
from .code2test_template_prompt import code2Tests_PROMPT, judgeCode2Tests_PROMPT
|
||||
from .agent_selector_template_prompt import SELECTOR_AGENT_TEMPLATE_PROMPT
|
||||
|
||||
from .react_template_prompt import REACT_TEMPLATE_PROMPT
|
||||
|
@ -37,5 +38,7 @@ __all__ = [
|
|||
"SELECTOR_AGENT_TEMPLATE_PROMPT",
|
||||
"PLANNER_TEMPLATE_PROMPT", "GENERAL_PLANNER_PROMPT", "DATA_PLANNER_PROMPT", "TOOL_PLANNER_PROMPT",
|
||||
"REACT_TEMPLATE_PROMPT",
|
||||
"REACT_CODE_PROMPT", "REACT_TOOL_PROMPT", "REACT_TOOL_AND_CODE_PROMPT", "REACT_TOOL_AND_CODE_PLANNER_PROMPT"
|
||||
"REACT_CODE_PROMPT", "REACT_TOOL_PROMPT", "REACT_TOOL_AND_CODE_PROMPT", "REACT_TOOL_AND_CODE_PLANNER_PROMPT",
|
||||
"Code2DocGroup_PROMPT", "Class2Doc_PROMPT", "Func2Doc_PROMPT",
|
||||
"code2Tests_PROMPT", "judgeCode2Tests_PROMPT"
|
||||
]
|
|
@ -0,0 +1,95 @@
|
|||
Code2DocGroup_PROMPT = """#### Agent Profile
|
||||
|
||||
Your goal is to response according the Context Data's information with the role that will best facilitate a solution, taking into account all relevant context (Context) provided.
|
||||
|
||||
When you need to select the appropriate role for handling a user's query, carefully read the provided role names, role descriptions and tool list.
|
||||
|
||||
ATTENTION: response carefully referenced "Response Output Format" in format.
|
||||
|
||||
#### Input Format
|
||||
|
||||
#### Response Output Format
|
||||
|
||||
**Code Path:** Extract the paths for the class/method/function that need to be addressed from the context
|
||||
|
||||
**Role:** Select the role from agent names
|
||||
"""
|
||||
|
||||
Class2Doc_PROMPT = """#### Agent Profile
|
||||
As an advanced code documentation generator, you are proficient in translating class definitions into comprehensive documentation with a focus on instantiation parameters.
|
||||
Your specific task is to parse the given code snippet of a class, extract information regarding its instantiation parameters.
|
||||
|
||||
ATTENTION: response carefully in "Response Output Format".
|
||||
|
||||
#### Input Format
|
||||
|
||||
**Code Snippet:** Provide the full class definition, including the constructor and any parameters it may require for instantiation.
|
||||
|
||||
#### Response Output Format
|
||||
**Class Base:** Specify the base class or interface from which the current class extends, if any.
|
||||
|
||||
**Class Description:** Offer a brief description of the class's purpose and functionality.
|
||||
|
||||
**Init Parameters:** List each parameter from construct. For each parameter, provide:
|
||||
- `param`: The parameter name
|
||||
- `param_description`: A concise explanation of the parameter's purpose.
|
||||
- `param_type`: The data type of the parameter, if explicitly defined.
|
||||
|
||||
```json
|
||||
[
|
||||
{
|
||||
"param": "parameter_name",
|
||||
"param_description": "A brief description of what this parameter is used for.",
|
||||
"param_type": "The data type of the parameter"
|
||||
},
|
||||
...
|
||||
]
|
||||
```
|
||||
|
||||
|
||||
If no parameter for construct, return
|
||||
```json
|
||||
[]
|
||||
```
|
||||
"""
|
||||
|
||||
Func2Doc_PROMPT = """#### Agent Profile
|
||||
You are a high-level code documentation assistant, skilled at extracting information from function/method code into detailed and well-structured documentation.
|
||||
|
||||
ATTENTION: response carefully in "Response Output Format".
|
||||
|
||||
|
||||
#### Input Format
|
||||
**Code Path:** Provide the code path of the function or method you wish to document.
|
||||
This name will be used to identify and extract the relevant details from the code snippet provided.
|
||||
|
||||
**Code Snippet:** A segment of code that contains the function or method to be documented.
|
||||
|
||||
#### Response Output Format
|
||||
|
||||
**Class Description:** Offer a brief description of the method(function)'s purpose and functionality.
|
||||
|
||||
**Parameters:** Extract parameter for the specific function/method Code from Code Snippet. For parameter, provide:
|
||||
- `param`: The parameter name
|
||||
- `param_description`: A concise explanation of the parameter's purpose.
|
||||
- `param_type`: The data type of the parameter, if explicitly defined.
|
||||
```json
|
||||
[
|
||||
{
|
||||
"param": "parameter_name",
|
||||
"param_description": "A brief description of what this parameter is used for.",
|
||||
"param_type": "The data type of the parameter"
|
||||
},
|
||||
...
|
||||
]
|
||||
```
|
||||
|
||||
If no parameter for function/method, return
|
||||
```json
|
||||
[]
|
||||
```
|
||||
|
||||
**Return Value Description:** Describe what the function/method returns upon completion.
|
||||
|
||||
**Return Type:** Indicate the type of data the function/method returns (e.g., string, integer, object, void).
|
||||
"""
|
|
@ -0,0 +1,65 @@
|
|||
judgeCode2Tests_PROMPT = """#### Agent Profile
|
||||
When determining the necessity of writing test cases for a given code snippet,
|
||||
it's essential to evaluate its interactions with dependent classes and methods (retrieved code snippets),
|
||||
in addition to considering these critical factors:
|
||||
1. Functionality: If it implements a concrete function or logic, test cases are typically necessary to verify its correctness.
|
||||
2. Complexity: If the code is complex, especially if it contains multiple conditional statements, loops, exceptions handling, etc.,
|
||||
it's more likely to harbor bugs, and thus test cases should be written.
|
||||
If the code involves complex algorithms or logic, then writing test cases can help ensure the accuracy of the logic and prevent errors during future refactoring.
|
||||
3. Criticality: If it's part of the critical path or affects core functionalities, then it needs to be tested.
|
||||
Comprehensive test cases should be written for core business logic or key components of the system to ensure the correctness and stability of the functionality.
|
||||
4. Dependencies: If the code has external dependencies, integration testing may be necessary, or mocking these dependencies during unit testing might be required.
|
||||
5. User Input: If the code handles user input, especially from unregulated external sources, creating test cases to check input validation and handling is important.
|
||||
6. Frequent Changes: For code that requires regular updates or modifications, having the appropriate test cases ensures that changes do not break existing functionalities.
|
||||
|
||||
#### Input Format
|
||||
|
||||
**Code Snippet:** the initial Code or objective that the user wanted to achieve
|
||||
|
||||
**Retrieval Code Snippets:** These are the associated code segments that the main Code Snippet depends on.
|
||||
Examine these snippets to understand how they interact with the main snippet and to determine how they might affect the overall functionality.
|
||||
|
||||
#### Response Output Format
|
||||
**Action Status:** Set to 'finished' or 'continued'.
|
||||
If set to 'finished', the code snippet does not warrant the generation of a test case.
|
||||
If set to 'continued', the code snippet necessitates the creation of a test case.
|
||||
|
||||
**REASON:** Justify the selection of 'finished' or 'continued', contemplating the decision through a step-by-step rationale.
|
||||
"""
|
||||
|
||||
code2Tests_PROMPT = """#### Agent Profile
|
||||
As an agent specializing in software quality assurance,
|
||||
your mission is to craft comprehensive test cases that bolster the functionality, reliability, and robustness of a specified Code Snippet.
|
||||
This task is to be carried out with a keen understanding of the snippet's interactions with its dependent classes and methods—collectively referred to as Retrieval Code Snippets.
|
||||
Analyze the details given below to grasp the code's intended purpose, its inherent complexity, and the context within which it operates.
|
||||
Your constructed test cases must thoroughly examine the various factors influencing the code's quality and performance.
|
||||
|
||||
ATTENTION: response carefully referenced "Response Output Format" in format.
|
||||
|
||||
Each test case should include:
|
||||
1. clear description of the test purpose.
|
||||
2. The input values or conditions for the test.
|
||||
3. The expected outcome or assertion for the test.
|
||||
4. Appropriate tags (e.g., 'functional', 'integration', 'regression') that classify the type of test case.
|
||||
5. these test code should have package and import
|
||||
|
||||
#### Input Format
|
||||
|
||||
**Code Snippet:** the initial Code or objective that the user wanted to achieve
|
||||
|
||||
**Retrieval Code Snippets:** These are the interrelated pieces of code sourced from the codebase, which support or influence the primary Code Snippet.
|
||||
|
||||
#### Response Output Format
|
||||
**SaveFileName:** construct a local file name based on Question and Context, such as
|
||||
|
||||
```java
|
||||
package/class.java
|
||||
```
|
||||
|
||||
|
||||
**Test Code:** generate the test code for the current Code Snippet.
|
||||
```java
|
||||
...
|
||||
```
|
||||
|
||||
"""
|
|
@ -1,5 +1,5 @@
|
|||
from abc import abstractmethod, ABC
|
||||
from typing import List
|
||||
from typing import List, Dict
|
||||
import os, sys, copy, json
|
||||
from jieba.analyse import extract_tags
|
||||
from collections import Counter
|
||||
|
@ -10,12 +10,13 @@ from langchain.docstore.document import Document
|
|||
|
||||
from .schema import Memory, Message
|
||||
from coagent.service.service_factory import KBServiceFactory
|
||||
from coagent.llm_models import getChatModel, getChatModelFromConfig
|
||||
from coagent.llm_models import getChatModelFromConfig
|
||||
from coagent.llm_models.llm_config import EmbedConfig, LLMConfig
|
||||
from coagent.embeddings.utils import load_embeddings_from_path
|
||||
from coagent.utils.common_utils import save_to_json_file, read_json_file, addMinutesToTime
|
||||
from coagent.connector.configs.prompts import CONV_SUMMARY_PROMPT_SPEC
|
||||
from coagent.orm import table_init
|
||||
from coagent.base_configs.env_config import KB_ROOT_PATH
|
||||
# from configs.model_config import KB_ROOT_PATH, EMBEDDING_MODEL, EMBEDDING_DEVICE, SCORE_THRESHOLD
|
||||
# from configs.model_config import embedding_model_dict
|
||||
|
||||
|
@ -70,16 +71,22 @@ class BaseMemoryManager(ABC):
|
|||
self.unique_name = unique_name
|
||||
self.memory_type = memory_type
|
||||
self.do_init = do_init
|
||||
self.current_memory = Memory(messages=[])
|
||||
self.recall_memory = Memory(messages=[])
|
||||
self.summary_memory = Memory(messages=[])
|
||||
# self.current_memory = Memory(messages=[])
|
||||
# self.recall_memory = Memory(messages=[])
|
||||
# self.summary_memory = Memory(messages=[])
|
||||
self.current_memory_dict: Dict[str, Memory] = {}
|
||||
self.recall_memory_dict: Dict[str, Memory] = {}
|
||||
self.summary_memory_dict: Dict[str, Memory] = {}
|
||||
self.save_message_keys = [
|
||||
'chat_index', 'role_name', 'role_type', 'role_prompt', 'input_query', 'origin_query',
|
||||
'datetime', 'role_content', 'step_content', 'parsed_output', 'spec_parsed_output', 'parsed_output_list',
|
||||
'task', 'db_docs', 'code_docs', 'search_docs', 'phase_name', 'chain_name', 'customed_kargs']
|
||||
self.init_vb()
|
||||
|
||||
def init_vb(self):
|
||||
def re_init(self, do_init: bool=False):
|
||||
self.init_vb()
|
||||
|
||||
def init_vb(self, do_init: bool=None):
|
||||
"""
|
||||
Initializes the vb.
|
||||
"""
|
||||
|
@ -135,13 +142,15 @@ class BaseMemoryManager(ABC):
|
|||
"""
|
||||
pass
|
||||
|
||||
def save_to_vs(self, embed_model="", embed_device=""):
|
||||
def save_to_vs(self, ):
|
||||
"""
|
||||
Saves the memory to the vector space.
|
||||
"""
|
||||
pass
|
||||
|
||||
Args:
|
||||
- embed_model: A string representing the embedding model. Default is EMBEDDING_MODEL.
|
||||
- embed_device: A string representing the embedding device. Default is EMBEDDING_DEVICE.
|
||||
def get_memory_pool(self, user_name: str, ):
|
||||
"""
|
||||
return memory_pool
|
||||
"""
|
||||
pass
|
||||
|
||||
|
@ -230,7 +239,7 @@ class LocalMemoryManager(BaseMemoryManager):
|
|||
unique_name: str = "default",
|
||||
memory_type: str = "recall",
|
||||
do_init: bool = False,
|
||||
kb_root_path: str = "",
|
||||
kb_root_path: str = KB_ROOT_PATH,
|
||||
):
|
||||
self.user_name = user_name
|
||||
self.unique_name = unique_name
|
||||
|
@ -239,16 +248,22 @@ class LocalMemoryManager(BaseMemoryManager):
|
|||
self.kb_root_path = kb_root_path
|
||||
self.embed_config: EmbedConfig = embed_config
|
||||
self.llm_config: LLMConfig = llm_config
|
||||
self.current_memory = Memory(messages=[])
|
||||
self.recall_memory = Memory(messages=[])
|
||||
self.summary_memory = Memory(messages=[])
|
||||
# self.current_memory = Memory(messages=[])
|
||||
# self.recall_memory = Memory(messages=[])
|
||||
# self.summary_memory = Memory(messages=[])
|
||||
self.current_memory_dict: Dict[str, Memory] = {}
|
||||
self.recall_memory_dict: Dict[str, Memory] = {}
|
||||
self.summary_memory_dict: Dict[str, Memory] = {}
|
||||
self.save_message_keys = [
|
||||
'chat_index', 'role_name', 'role_type', 'role_prompt', 'input_query', 'origin_query',
|
||||
'datetime', 'role_content', 'step_content', 'parsed_output', 'spec_parsed_output', 'parsed_output_list',
|
||||
'task', 'db_docs', 'code_docs', 'search_docs', 'phase_name', 'chain_name', 'customed_kargs']
|
||||
self.init_vb()
|
||||
|
||||
def init_vb(self):
|
||||
def re_init(self, do_init: bool=False):
|
||||
self.init_vb(do_init)
|
||||
|
||||
def init_vb(self, do_init: bool=None):
|
||||
vb_name = f"{self.user_name}/{self.unique_name}/{self.memory_type}"
|
||||
# default to recreate a new vb
|
||||
table_init()
|
||||
|
@ -256,31 +271,37 @@ class LocalMemoryManager(BaseMemoryManager):
|
|||
if vb:
|
||||
status = vb.clear_vs()
|
||||
|
||||
if not self.do_init:
|
||||
check_do_init = do_init if do_init else self.do_init
|
||||
if not check_do_init:
|
||||
self.load(self.kb_root_path)
|
||||
self.save_to_vs()
|
||||
|
||||
def append(self, message: Message):
|
||||
self.recall_memory.append(message)
|
||||
self.check_user_name(message.user_name)
|
||||
|
||||
uuid_name = "_".join([self.user_name, self.unique_name, self.memory_type])
|
||||
self.recall_memory_dict[uuid_name].append(message)
|
||||
#
|
||||
if message.role_type == "summary":
|
||||
self.summary_memory.append(message)
|
||||
self.summary_memory_dict[uuid_name].append(message)
|
||||
else:
|
||||
self.current_memory.append(message)
|
||||
self.current_memory_dict[uuid_name].append(message)
|
||||
|
||||
self.save(self.kb_root_path)
|
||||
self.save_new_to_vs([message])
|
||||
|
||||
def extend(self, memory: Memory):
|
||||
self.recall_memory.extend(memory)
|
||||
self.current_memory.extend(self.recall_memory.filter_by_role_type(["summary"]))
|
||||
self.summary_memory.extend(self.recall_memory.select_by_role_type(["summary"]))
|
||||
self.save(self.kb_root_path)
|
||||
self.save_new_to_vs(memory.messages)
|
||||
# def extend(self, memory: Memory):
|
||||
# self.recall_memory.extend(memory)
|
||||
# self.current_memory.extend(self.recall_memory.filter_by_role_type(["summary"]))
|
||||
# self.summary_memory.extend(self.recall_memory.select_by_role_type(["summary"]))
|
||||
# self.save(self.kb_root_path)
|
||||
# self.save_new_to_vs(memory.messages)
|
||||
|
||||
def save(self, save_dir: str = "./"):
|
||||
file_path = os.path.join(save_dir, f"{self.user_name}/{self.unique_name}/{self.memory_type}/converation.jsonl")
|
||||
memory_messages = self.recall_memory.dict()
|
||||
uuid_name = "_".join([self.user_name, self.unique_name, self.memory_type])
|
||||
|
||||
memory_messages = self.recall_memory_dict[uuid_name].dict()
|
||||
memory_messages = {k: [
|
||||
{kkk: vvv for kkk, vvv in vv.items() if kkk in self.save_message_keys}
|
||||
for vv in v ]
|
||||
|
@ -291,18 +312,28 @@ class LocalMemoryManager(BaseMemoryManager):
|
|||
|
||||
def load(self, load_dir: str = "./") -> Memory:
|
||||
file_path = os.path.join(load_dir, f"{self.user_name}/{self.unique_name}/{self.memory_type}/converation.jsonl")
|
||||
uuid_name = "_".join([self.user_name, self.unique_name, self.memory_type])
|
||||
|
||||
if os.path.exists(file_path):
|
||||
self.recall_memory = Memory(**read_json_file(file_path))
|
||||
self.current_memory = Memory(messages=self.recall_memory.filter_by_role_type(["summary"]))
|
||||
self.summary_memory = Memory(messages=self.recall_memory.select_by_role_type(["summary"]))
|
||||
# self.recall_memory = Memory(**read_json_file(file_path))
|
||||
# self.current_memory = Memory(messages=self.recall_memory.filter_by_role_type(["summary"]))
|
||||
# self.summary_memory = Memory(messages=self.recall_memory.select_by_role_type(["summary"]))
|
||||
|
||||
recall_memory = Memory(**read_json_file(file_path))
|
||||
self.recall_memory_dict[uuid_name] = recall_memory
|
||||
self.current_memory_dict[uuid_name] = Memory(messages=recall_memory.filter_by_role_type(["summary"]))
|
||||
self.summary_memory_dict[uuid_name] = Memory(messages=recall_memory.select_by_role_type(["summary"]))
|
||||
else:
|
||||
self.recall_memory_dict[uuid_name] = Memory(messages=[])
|
||||
self.current_memory_dict[uuid_name] = Memory(messages=[])
|
||||
self.summary_memory_dict[uuid_name] = Memory(messages=[])
|
||||
|
||||
def save_new_to_vs(self, messages: List[Message]):
|
||||
if self.embed_config:
|
||||
vb_name = f"{self.user_name}/{self.unique_name}/{self.memory_type}"
|
||||
# default to faiss, todo: add new vstype
|
||||
vb = KBServiceFactory.get_service(vb_name, "faiss", self.embed_config, self.kb_root_path)
|
||||
embeddings = load_embeddings_from_path(self.embed_config.embed_model_path, self.embed_config.model_device,)
|
||||
embeddings = load_embeddings_from_path(self.embed_config.embed_model_path, self.embed_config.model_device, self.embed_config.langchain_embeddings)
|
||||
messages = [
|
||||
{k: v for k, v in m.dict().items() if k in self.save_message_keys}
|
||||
for m in messages]
|
||||
|
@ -311,23 +342,26 @@ class LocalMemoryManager(BaseMemoryManager):
|
|||
vb.do_add_doc(docs, embeddings)
|
||||
|
||||
def save_to_vs(self):
|
||||
vb_name = f"{self.user_name}/{self.unique_name}/{self.memory_type}"
|
||||
# default to recreate a new vb
|
||||
vb = KBServiceFactory.get_service_by_name(vb_name, self.embed_config, self.kb_root_path)
|
||||
if vb:
|
||||
status = vb.clear_vs()
|
||||
# create_kb(vb_name, "faiss", embed_model)
|
||||
'''only after load'''
|
||||
if self.embed_config:
|
||||
vb_name = f"{self.user_name}/{self.unique_name}/{self.memory_type}"
|
||||
uuid_name = "_".join([self.user_name, self.unique_name, self.memory_type])
|
||||
# default to recreate a new vb
|
||||
vb = KBServiceFactory.get_service_by_name(vb_name, self.embed_config, self.kb_root_path)
|
||||
if vb:
|
||||
status = vb.clear_vs()
|
||||
# create_kb(vb_name, "faiss", embed_model)
|
||||
|
||||
# default to faiss, todo: add new vstype
|
||||
vb = KBServiceFactory.get_service(vb_name, "faiss", self.embed_config, self.kb_root_path)
|
||||
embeddings = load_embeddings_from_path(self.embed_config.embed_model_path, self.embed_config.model_device,)
|
||||
messages = self.recall_memory.dict()
|
||||
messages = [
|
||||
{kkk: vvv for kkk, vvv in vv.items() if kkk in self.save_message_keys}
|
||||
for k, v in messages.items() for vv in v]
|
||||
docs = [{"page_content": m["step_content"] or m["role_content"] or m["input_query"] or m["origin_query"], "metadata": m} for m in messages]
|
||||
docs = [Document(**doc) for doc in docs]
|
||||
vb.do_add_doc(docs, embeddings)
|
||||
# default to faiss, todo: add new vstype
|
||||
vb = KBServiceFactory.get_service(vb_name, "faiss", self.embed_config, self.kb_root_path)
|
||||
embeddings = load_embeddings_from_path(self.embed_config.embed_model_path, self.embed_config.model_device, self.embed_config.langchain_embeddings)
|
||||
messages = self.recall_memory_dict[uuid_name].dict()
|
||||
messages = [
|
||||
{kkk: vvv for kkk, vvv in vv.items() if kkk in self.save_message_keys}
|
||||
for k, v in messages.items() for vv in v]
|
||||
docs = [{"page_content": m["step_content"] or m["role_content"] or m["input_query"] or m["origin_query"], "metadata": m} for m in messages]
|
||||
docs = [Document(**doc) for doc in docs]
|
||||
vb.do_add_doc(docs, embeddings)
|
||||
|
||||
# def load_from_vs(self, embed_model=EMBEDDING_MODEL) -> Memory:
|
||||
# vb_name = f"{self.user_name}/{self.unique_name}/{self.memory_type}"
|
||||
|
@ -338,7 +372,12 @@ class LocalMemoryManager(BaseMemoryManager):
|
|||
# docs = vb.get_all_documents()
|
||||
# print(docs)
|
||||
|
||||
def router_retrieval(self, text: str=None, datetime: str = None, n=5, top_k=5, retrieval_type: str = "embedding", **kwargs) -> List[Message]:
|
||||
def get_memory_pool(self, user_name: str, ):
|
||||
self.check_user_name(user_name)
|
||||
uuid_name = "_".join([self.user_name, self.unique_name, self.memory_type])
|
||||
return self.recall_memory_dict[uuid_name]
|
||||
|
||||
def router_retrieval(self, user_name: str = "default", text: str=None, datetime: str = None, n=5, top_k=5, retrieval_type: str = "embedding", **kwargs) -> List[Message]:
|
||||
retrieval_func_dict = {
|
||||
"embedding": self.embedding_retrieval, "text": self.text_retrieval, "datetime": self.datetime_retrieval
|
||||
}
|
||||
|
@ -356,20 +395,22 @@ class LocalMemoryManager(BaseMemoryManager):
|
|||
#
|
||||
return retrieval_func(**params)
|
||||
|
||||
def embedding_retrieval(self, text: str, top_k=1, score_threshold=1.0, **kwargs) -> List[Message]:
|
||||
def embedding_retrieval(self, text: str, top_k=1, score_threshold=1.0, user_name: str = "default", **kwargs) -> List[Message]:
|
||||
if text is None: return []
|
||||
vb_name = f"{self.user_name}/{self.unique_name}/{self.memory_type}"
|
||||
vb_name = f"{user_name}/{self.unique_name}/{self.memory_type}"
|
||||
vb = KBServiceFactory.get_service(vb_name, "faiss", self.embed_config, self.kb_root_path)
|
||||
docs = vb.search_docs(text, top_k=top_k, score_threshold=score_threshold)
|
||||
return [Message(**doc.metadata) for doc, score in docs]
|
||||
|
||||
def text_retrieval(self, text: str, **kwargs) -> List[Message]:
|
||||
def text_retrieval(self, text: str, user_name: str = "default", **kwargs) -> List[Message]:
|
||||
if text is None: return []
|
||||
return self._text_retrieval_from_cache(self.recall_memory.messages, text, score_threshold=0.3, topK=5, **kwargs)
|
||||
uuid_name = "_".join([user_name, self.unique_name, self.memory_type])
|
||||
return self._text_retrieval_from_cache(self.recall_memory_dict[uuid_name].messages, text, score_threshold=0.3, topK=5, **kwargs)
|
||||
|
||||
def datetime_retrieval(self, datetime: str, text: str = None, n: int = 5, **kwargs) -> List[Message]:
|
||||
def datetime_retrieval(self, datetime: str, text: str = None, n: int = 5, user_name: str = "default", **kwargs) -> List[Message]:
|
||||
if datetime is None: return []
|
||||
return self._datetime_retrieval_from_cache(self.recall_memory.messages, datetime, text, n, **kwargs)
|
||||
uuid_name = "_".join([user_name, self.unique_name, self.memory_type])
|
||||
return self._datetime_retrieval_from_cache(self.recall_memory_dict[uuid_name].messages, datetime, text, n, **kwargs)
|
||||
|
||||
def _text_retrieval_from_cache(self, messages: List[Message], text: str = None, score_threshold=0.3, topK=5, tag_topK=5, **kwargs) -> List[Message]:
|
||||
keywords = extract_tags(text, topK=tag_topK)
|
||||
|
@ -427,4 +468,18 @@ class LocalMemoryManager(BaseMemoryManager):
|
|||
)
|
||||
summary_message.parsed_output_list.append({"summary": content})
|
||||
newest_messages.insert(0, summary_message)
|
||||
return newest_messages
|
||||
return newest_messages
|
||||
|
||||
def check_user_name(self, user_name: str):
|
||||
# logger.debug(f"self.user_name is {self.user_name}")
|
||||
if user_name != self.user_name:
|
||||
self.user_name = user_name
|
||||
self.init_vb()
|
||||
|
||||
uuid_name = "_".join([self.user_name, self.unique_name, self.memory_type])
|
||||
if uuid_name not in self.recall_memory_dict:
|
||||
self.recall_memory_dict[uuid_name] = Memory(messages=[])
|
||||
self.current_memory_dict[uuid_name] = Memory(messages=[])
|
||||
self.summary_memory_dict[uuid_name] = Memory(messages=[])
|
||||
|
||||
# logger.debug(f"self.user_name is {self.user_name}")
|
|
@ -1,16 +1,19 @@
|
|||
import re, traceback, uuid, copy, json, os
|
||||
from typing import Union
|
||||
from loguru import logger
|
||||
|
||||
from langchain.schema import BaseRetriever
|
||||
|
||||
# from configs.server_config import SANDBOX_SERVER
|
||||
# from configs.model_config import JUPYTER_WORK_PATH
|
||||
from coagent.connector.schema import (
|
||||
Memory, Role, Message, ActionStatus, CodeDoc, Doc, LogVerboseEnum
|
||||
)
|
||||
from coagent.retrieval.base_retrieval import IMRertrieval
|
||||
from coagent.connector.memory_manager import BaseMemoryManager
|
||||
from coagent.tools import DDGSTool, DocRetrieval, CodeRetrieval
|
||||
from coagent.sandbox import PyCodeBox, CodeBoxResponse
|
||||
from coagent.llm_models.llm_config import LLMConfig, EmbedConfig
|
||||
from coagent.base_configs.env_config import JUPYTER_WORK_PATH
|
||||
|
||||
from .utils import parse_dict_to_dict, parse_text_to_dict
|
||||
|
||||
|
||||
|
@ -19,10 +22,13 @@ class MessageUtils:
|
|||
self,
|
||||
role: Role = None,
|
||||
sandbox_server: dict = {},
|
||||
jupyter_work_path: str = "./",
|
||||
jupyter_work_path: str = JUPYTER_WORK_PATH,
|
||||
embed_config: EmbedConfig = None,
|
||||
llm_config: LLMConfig = None,
|
||||
kb_root_path: str = "",
|
||||
doc_retrieval: Union[BaseRetriever, IMRertrieval] = None,
|
||||
code_retrieval: IMRertrieval = None,
|
||||
search_retrieval: IMRertrieval = None,
|
||||
log_verbose: str = "0"
|
||||
) -> None:
|
||||
self.role = role
|
||||
|
@ -31,6 +37,9 @@ class MessageUtils:
|
|||
self.embed_config = embed_config
|
||||
self.llm_config = llm_config
|
||||
self.kb_root_path = kb_root_path
|
||||
self.doc_retrieval = doc_retrieval
|
||||
self.code_retrieval = code_retrieval
|
||||
self.search_retrieval = search_retrieval
|
||||
self.codebox = PyCodeBox(
|
||||
remote_url=self.sandbox_server.get("url", "http://127.0.0.1:5050"),
|
||||
remote_ip=self.sandbox_server.get("host", "http://127.0.0.1"),
|
||||
|
@ -44,6 +53,7 @@ class MessageUtils:
|
|||
self.log_verbose = os.environ.get("log_verbose", "0") or log_verbose
|
||||
|
||||
def inherit_extrainfo(self, input_message: Message, output_message: Message):
|
||||
output_message.user_name = input_message.user_name
|
||||
output_message.db_docs = input_message.db_docs
|
||||
output_message.search_docs = input_message.search_docs
|
||||
output_message.code_docs = input_message.code_docs
|
||||
|
@ -116,18 +126,45 @@ class MessageUtils:
|
|||
knowledge_basename = message.doc_engine_name
|
||||
top_k = message.top_k
|
||||
score_threshold = message.score_threshold
|
||||
if knowledge_basename:
|
||||
docs = DocRetrieval.run(query, knowledge_basename, top_k, score_threshold, self.embed_config, self.kb_root_path)
|
||||
if self.doc_retrieval:
|
||||
if isinstance(self.doc_retrieval, BaseRetriever):
|
||||
docs = self.doc_retrieval.get_relevant_documents(query)
|
||||
else:
|
||||
# docs = self.doc_retrieval.run(query, search_top=message.top_k, score_threshold=message.score_threshold,)
|
||||
docs = self.doc_retrieval.run(query)
|
||||
docs = [
|
||||
{"index": idx, "snippet": doc.page_content, "title": doc.metadata.get("title_prefix", ""), "link": doc.metadata.get("url", "")}
|
||||
for idx, doc in enumerate(docs)
|
||||
]
|
||||
message.db_docs = [Doc(**doc) for doc in docs]
|
||||
else:
|
||||
if knowledge_basename:
|
||||
docs = DocRetrieval.run(query, knowledge_basename, top_k, score_threshold, self.embed_config, self.kb_root_path)
|
||||
message.db_docs = [Doc(**doc) for doc in docs]
|
||||
return message
|
||||
|
||||
def get_code_retrieval(self, message: Message) -> Message:
|
||||
query = message.input_query
|
||||
query = message.role_content
|
||||
code_engine_name = message.code_engine_name
|
||||
history_node_list = message.history_node_list
|
||||
code_docs = CodeRetrieval.run(code_engine_name, query, code_limit=message.top_k, history_node_list=history_node_list, search_type=message.cb_search_type,
|
||||
llm_config=self.llm_config, embed_config=self.embed_config,)
|
||||
message.code_docs = [CodeDoc(**doc) for doc in code_docs]
|
||||
|
||||
use_nh = message.use_nh
|
||||
local_graph_path = message.local_graph_path
|
||||
|
||||
if self.code_retrieval:
|
||||
code_docs = self.code_retrieval.run(
|
||||
query, history_node_list=history_node_list, search_type=message.cb_search_type,
|
||||
code_limit=1
|
||||
)
|
||||
else:
|
||||
code_docs = CodeRetrieval.run(code_engine_name, query, code_limit=message.top_k, history_node_list=history_node_list, search_type=message.cb_search_type,
|
||||
llm_config=self.llm_config, embed_config=self.embed_config,
|
||||
use_nh=use_nh, local_graph_path=local_graph_path)
|
||||
|
||||
message.code_docs = [CodeDoc(**doc) for doc in code_docs]
|
||||
|
||||
# related_nodes = [doc.get_related_node() for idx, doc in enumerate(message.code_docs) if idx==0],
|
||||
# history_node_list.extend([node[0] for node in related_nodes])
|
||||
return message
|
||||
|
||||
def get_tool_retrieval(self, message: Message) -> Message:
|
||||
|
@ -160,6 +197,7 @@ class MessageUtils:
|
|||
if code_answer.code_exe_type == "error" else f"The return information after executing the above code is {code_answer.code_exe_response}.\n"
|
||||
|
||||
observation_message = Message(
|
||||
user_name=message.user_name,
|
||||
role_name="observation",
|
||||
role_type="function", #self.role.role_type,
|
||||
role_content="",
|
||||
|
@ -190,6 +228,7 @@ class MessageUtils:
|
|||
def tool_step(self, message: Message) -> Message:
|
||||
'''execute tool'''
|
||||
observation_message = Message(
|
||||
user_name=message.user_name,
|
||||
role_name="observation",
|
||||
role_type="function", #self.role.role_type,
|
||||
role_content="\n**Observation:** there is no tool can execute\n",
|
||||
|
@ -226,7 +265,7 @@ class MessageUtils:
|
|||
return message, observation_message
|
||||
|
||||
def parser(self, message: Message) -> Message:
|
||||
''''''
|
||||
'''parse llm output into dict'''
|
||||
content = message.role_content
|
||||
# parse start
|
||||
parsed_dict = parse_text_to_dict(content)
|
||||
|
|
|
@ -5,6 +5,8 @@ import importlib
|
|||
import copy
|
||||
from loguru import logger
|
||||
|
||||
from langchain.schema import BaseRetriever
|
||||
|
||||
from coagent.connector.agents import BaseAgent
|
||||
from coagent.connector.chains import BaseChain
|
||||
from coagent.connector.schema import (
|
||||
|
@ -18,9 +20,6 @@ from coagent.connector.message_process import MessageUtils
|
|||
from coagent.llm_models.llm_config import EmbedConfig, LLMConfig
|
||||
from coagent.base_configs.env_config import JUPYTER_WORK_PATH, KB_ROOT_PATH
|
||||
|
||||
# from configs.model_config import JUPYTER_WORK_PATH, KB_ROOT_PATH
|
||||
# from configs.server_config import SANDBOX_SERVER
|
||||
|
||||
|
||||
role_configs = load_role_configs(AGETN_CONFIGS)
|
||||
chain_configs = load_chain_configs(CHAIN_CONFIGS)
|
||||
|
@ -39,20 +38,24 @@ class BasePhase:
|
|||
kb_root_path: str = KB_ROOT_PATH,
|
||||
jupyter_work_path: str = JUPYTER_WORK_PATH,
|
||||
sandbox_server: dict = {},
|
||||
embed_config: EmbedConfig = EmbedConfig(),
|
||||
llm_config: LLMConfig = LLMConfig(),
|
||||
embed_config: EmbedConfig = None,
|
||||
llm_config: LLMConfig = None,
|
||||
task: Task = None,
|
||||
base_phase_config: Union[dict, str] = PHASE_CONFIGS,
|
||||
base_chain_config: Union[dict, str] = CHAIN_CONFIGS,
|
||||
base_role_config: Union[dict, str] = AGETN_CONFIGS,
|
||||
chains: List[BaseChain] = [],
|
||||
doc_retrieval: Union[BaseRetriever] = None,
|
||||
code_retrieval = None,
|
||||
search_retrieval = None,
|
||||
log_verbose: str = "0"
|
||||
) -> None:
|
||||
#
|
||||
self.phase_name = phase_name
|
||||
self.do_summary = False
|
||||
self.do_search = False
|
||||
self.do_code_retrieval = False
|
||||
self.do_doc_retrieval = False
|
||||
self.do_search = search_retrieval is not None
|
||||
self.do_code_retrieval = code_retrieval is not None
|
||||
self.do_doc_retrieval = doc_retrieval is not None
|
||||
self.do_tool_retrieval = False
|
||||
# memory_pool dont have specific order
|
||||
# self.memory_pool = Memory(messages=[])
|
||||
|
@ -62,12 +65,15 @@ class BasePhase:
|
|||
self.jupyter_work_path = jupyter_work_path
|
||||
self.kb_root_path = kb_root_path
|
||||
self.log_verbose = max(os.environ.get("log_verbose", "0"), log_verbose)
|
||||
|
||||
self.message_utils = MessageUtils(None, sandbox_server, jupyter_work_path, embed_config, llm_config, kb_root_path, log_verbose)
|
||||
# TODO透传
|
||||
self.doc_retrieval = doc_retrieval
|
||||
self.code_retrieval = code_retrieval
|
||||
self.search_retrieval = search_retrieval
|
||||
self.message_utils = MessageUtils(None, sandbox_server, jupyter_work_path, embed_config, llm_config, kb_root_path, doc_retrieval, code_retrieval, search_retrieval, log_verbose)
|
||||
self.global_memory = Memory(messages=[])
|
||||
self.phase_memory: List[Memory] = []
|
||||
# according phase name to init the phase contains
|
||||
self.chains: List[BaseChain] = self.init_chains(
|
||||
self.chains: List[BaseChain] = chains if chains else self.init_chains(
|
||||
phase_name,
|
||||
phase_config,
|
||||
task=task,
|
||||
|
@ -90,7 +96,9 @@ class BasePhase:
|
|||
kb_root_path=kb_root_path
|
||||
)
|
||||
|
||||
def astep(self, query: Message, history: Memory = None) -> Tuple[Message, Memory]:
|
||||
def astep(self, query: Message, history: Memory = None, reinit_memory=False) -> Tuple[Message, Memory]:
|
||||
if reinit_memory:
|
||||
self.memory_manager.re_init(reinit_memory)
|
||||
self.memory_manager.append(query)
|
||||
summary_message = None
|
||||
chain_message = Memory(messages=[])
|
||||
|
@ -139,8 +147,8 @@ class BasePhase:
|
|||
message.role_name = self.phase_name
|
||||
yield message, local_phase_memory
|
||||
|
||||
def step(self, query: Message, history: Memory = None) -> Tuple[Message, Memory]:
|
||||
for message, local_phase_memory in self.astep(query, history=history):
|
||||
def step(self, query: Message, history: Memory = None, reinit_memory=False) -> Tuple[Message, Memory]:
|
||||
for message, local_phase_memory in self.astep(query, history=history, reinit_memory=reinit_memory):
|
||||
pass
|
||||
return message, local_phase_memory
|
||||
|
||||
|
@ -194,6 +202,9 @@ class BasePhase:
|
|||
sandbox_server=self.sandbox_server,
|
||||
jupyter_work_path=self.jupyter_work_path,
|
||||
kb_root_path=self.kb_root_path,
|
||||
doc_retrieval=self.doc_retrieval,
|
||||
code_retrieval=self.code_retrieval,
|
||||
search_retrieval=self.search_retrieval,
|
||||
log_verbose=self.log_verbose
|
||||
)
|
||||
if agent_config.role.agent_type == "SelectorAgent":
|
||||
|
@ -205,7 +216,7 @@ class BasePhase:
|
|||
group_base_agent = baseAgent(
|
||||
role=group_agent_config.role,
|
||||
prompt_config = group_agent_config.prompt_config,
|
||||
prompt_manager_type=agent_config.prompt_manager_type,
|
||||
prompt_manager_type=group_agent_config.prompt_manager_type,
|
||||
task = task,
|
||||
memory = memory,
|
||||
chat_turn=group_agent_config.chat_turn,
|
||||
|
@ -216,6 +227,9 @@ class BasePhase:
|
|||
sandbox_server=self.sandbox_server,
|
||||
jupyter_work_path=self.jupyter_work_path,
|
||||
kb_root_path=self.kb_root_path,
|
||||
doc_retrieval=self.doc_retrieval,
|
||||
code_retrieval=self.code_retrieval,
|
||||
search_retrieval=self.search_retrieval,
|
||||
log_verbose=self.log_verbose
|
||||
)
|
||||
base_agent.group_agents.append(group_base_agent)
|
||||
|
@ -223,13 +237,16 @@ class BasePhase:
|
|||
agents.append(base_agent)
|
||||
|
||||
chain_instance = BaseChain(
|
||||
agents, chain_config.chat_turn,
|
||||
do_checker=chain_configs[chain_name].do_checker,
|
||||
chain_config,
|
||||
agents,
|
||||
jupyter_work_path=self.jupyter_work_path,
|
||||
sandbox_server=self.sandbox_server,
|
||||
embed_config=self.embed_config,
|
||||
llm_config=self.llm_config,
|
||||
kb_root_path=self.kb_root_path,
|
||||
doc_retrieval=self.doc_retrieval,
|
||||
code_retrieval=self.code_retrieval,
|
||||
search_retrieval=self.search_retrieval,
|
||||
log_verbose=self.log_verbose
|
||||
)
|
||||
chains.append(chain_instance)
|
||||
|
|
|
@ -0,0 +1,2 @@
|
|||
from .prompt_manager import PromptManager
|
||||
from .extend_manager import *
|
|
@ -0,0 +1,45 @@
|
|||
|
||||
from coagent.connector.schema import Message
|
||||
from .prompt_manager import PromptManager
|
||||
|
||||
|
||||
class Code2DocPM(PromptManager):
|
||||
def handle_code_snippet(self, **kwargs) -> str:
|
||||
if 'previous_agent_message' not in kwargs:
|
||||
return ""
|
||||
previous_agent_message: Message = kwargs['previous_agent_message']
|
||||
code_snippet = previous_agent_message.customed_kargs.get("Code Snippet", "")
|
||||
current_vertex = previous_agent_message.customed_kargs.get("Current_Vertex", "")
|
||||
instruction = "A segment of code that contains the function or method to be documented.\n"
|
||||
return instruction + "\n" + f"name: {current_vertex}\n{code_snippet}"
|
||||
|
||||
def handle_specific_objective(self, **kwargs) -> str:
|
||||
if 'previous_agent_message' not in kwargs:
|
||||
return ""
|
||||
previous_agent_message: Message = kwargs['previous_agent_message']
|
||||
specific_objective = previous_agent_message.parsed_output.get("Code Path")
|
||||
|
||||
instruction = "Provide the code path of the function or method you wish to document.\n"
|
||||
s = instruction + f"\n{specific_objective}"
|
||||
return s
|
||||
|
||||
|
||||
class CodeRetrievalPM(PromptManager):
|
||||
def handle_code_snippet(self, **kwargs) -> str:
|
||||
if 'previous_agent_message' not in kwargs:
|
||||
return ""
|
||||
previous_agent_message: Message = kwargs['previous_agent_message']
|
||||
code_snippet = previous_agent_message.customed_kargs.get("Code Snippet", "")
|
||||
current_vertex = previous_agent_message.customed_kargs.get("Current_Vertex", "")
|
||||
instruction = "the initial Code or objective that the user wanted to achieve"
|
||||
return instruction + "\n" + f"name: {current_vertex}\n{code_snippet}"
|
||||
|
||||
def handle_retrieval_codes(self, **kwargs) -> str:
|
||||
if 'previous_agent_message' not in kwargs:
|
||||
return ""
|
||||
previous_agent_message: Message = kwargs['previous_agent_message']
|
||||
Retrieval_Codes = previous_agent_message.customed_kargs["Retrieval_Codes"]
|
||||
Relative_vertex = previous_agent_message.customed_kargs["Relative_vertex"]
|
||||
instruction = "the initial Code or objective that the user wanted to achieve"
|
||||
s = instruction + "\n" + "\n".join([f"name: {vertext}\n{code}" for vertext, code in zip(Relative_vertex, Retrieval_Codes)])
|
||||
return s
|
|
@ -0,0 +1,353 @@
|
|||
import random
|
||||
from textwrap import dedent
|
||||
import copy
|
||||
from loguru import logger
|
||||
|
||||
from langchain.agents.tools import Tool
|
||||
|
||||
from coagent.connector.schema import Memory, Message
|
||||
from coagent.connector.utils import extract_section, parse_section
|
||||
|
||||
|
||||
|
||||
class PromptManager:
|
||||
def __init__(self, role_prompt="", prompt_config=None, monitored_agents=[], monitored_fields=[]):
|
||||
self.role_prompt = role_prompt
|
||||
self.monitored_agents = monitored_agents
|
||||
self.monitored_fields = monitored_fields
|
||||
self.field_handlers = {}
|
||||
self.context_handlers = {}
|
||||
self.field_order = [] # 用于普通字段的顺序
|
||||
self.context_order = [] # 单独维护上下文字段的顺序
|
||||
self.field_descriptions = {}
|
||||
self.omit_if_empty_flags = {}
|
||||
self.context_title = "### Context Data\n\n"
|
||||
|
||||
self.prompt_config = prompt_config
|
||||
if self.prompt_config:
|
||||
self.register_fields_from_config()
|
||||
|
||||
def register_field(self, field_name, function=None, title=None, description=None, is_context=True, omit_if_empty=True):
|
||||
"""
|
||||
注册一个新的字段及其处理函数。
|
||||
Args:
|
||||
field_name (str): 字段名称。
|
||||
function (callable): 处理字段数据的函数。
|
||||
title (str, optional): 字段的自定义标题(可选)。
|
||||
description (str, optional): 字段的描述(可选,可以是几句话)。
|
||||
is_context (bool, optional): 指示该字段是否为上下文字段。
|
||||
omit_if_empty (bool, optional): 如果数据为空,是否省略该字段。
|
||||
"""
|
||||
if not function:
|
||||
function = self.handle_custom_data
|
||||
|
||||
# Register the handler function based on context flag
|
||||
if is_context:
|
||||
self.context_handlers[field_name] = function
|
||||
else:
|
||||
self.field_handlers[field_name] = function
|
||||
|
||||
# Store the custom title if provided and adjust the title prefix based on context
|
||||
title_prefix = "####" if is_context else "###"
|
||||
if title is not None:
|
||||
self.field_descriptions[field_name] = f"{title_prefix} {title}\n\n"
|
||||
elif description is not None:
|
||||
# If title is not provided but description is, use description as title
|
||||
self.field_descriptions[field_name] = f"{title_prefix} {field_name.replace('_', ' ').title()}\n\n{description}\n\n"
|
||||
else:
|
||||
# If neither title nor description is provided, use the field name as title
|
||||
self.field_descriptions[field_name] = f"{title_prefix} {field_name.replace('_', ' ').title()}\n\n"
|
||||
|
||||
# Store the omit_if_empty flag for this field
|
||||
self.omit_if_empty_flags[field_name] = omit_if_empty
|
||||
|
||||
if is_context and field_name != 'context_placeholder':
|
||||
self.context_handlers[field_name] = function
|
||||
self.context_order.append(field_name)
|
||||
else:
|
||||
self.field_handlers[field_name] = function
|
||||
self.field_order.append(field_name)
|
||||
|
||||
def generate_full_prompt(self, **kwargs):
|
||||
full_prompt = []
|
||||
context_prompts = [] # 用于收集上下文内容
|
||||
is_pre_print = kwargs.get("is_pre_print", False) # 用于强制打印所有prompt 字段信息,不管有没有空
|
||||
|
||||
# 先处理上下文字段
|
||||
for field_name in self.context_order:
|
||||
handler = self.context_handlers[field_name]
|
||||
processed_prompt = handler(**kwargs)
|
||||
# Check if the field should be omitted when empty
|
||||
if self.omit_if_empty_flags.get(field_name, False) and not processed_prompt and not is_pre_print:
|
||||
continue # Skip this field
|
||||
title_or_description = self.field_descriptions.get(field_name, f"#### {field_name.replace('_', ' ').title()}\n\n")
|
||||
context_prompts.append(title_or_description + processed_prompt + '\n\n')
|
||||
|
||||
# 处理普通字段,同时查找 context_placeholder 的位置
|
||||
for field_name in self.field_order:
|
||||
if field_name == 'context_placeholder':
|
||||
# 在 context_placeholder 的位置插入上下文数据
|
||||
full_prompt.append(self.context_title) # 添加上下文部分的大标题
|
||||
full_prompt.extend(context_prompts) # 添加收集的上下文内容
|
||||
else:
|
||||
handler = self.field_handlers[field_name]
|
||||
processed_prompt = handler(**kwargs)
|
||||
# Check if the field should be omitted when empty
|
||||
if self.omit_if_empty_flags.get(field_name, False) and not processed_prompt and not is_pre_print:
|
||||
continue # Skip this field
|
||||
title_or_description = self.field_descriptions.get(field_name, f"### {field_name.replace('_', ' ').title()}\n\n")
|
||||
full_prompt.append(title_or_description + processed_prompt + '\n\n')
|
||||
|
||||
# 返回完整的提示,移除尾部的空行
|
||||
return ''.join(full_prompt).rstrip('\n')
|
||||
|
||||
def pre_print(self, **kwargs):
|
||||
kwargs.update({"is_pre_print": True})
|
||||
prompt = self.generate_full_prompt(**kwargs)
|
||||
|
||||
input_keys = parse_section(self.role_prompt, 'Response Output Format')
|
||||
llm_predict = "\n".join([f"**{k}:**" for k in input_keys])
|
||||
return prompt + "\n\n" + "#"*19 + "\n<<<<LLM PREDICT>>>>\n" + "#"*19 + f"\n\n{llm_predict}\n"
|
||||
|
||||
def handle_custom_data(self, **kwargs):
|
||||
return ""
|
||||
|
||||
def handle_tool_data(self, **kwargs):
|
||||
if 'previous_agent_message' not in kwargs:
|
||||
return ""
|
||||
|
||||
previous_agent_message = kwargs.get('previous_agent_message')
|
||||
tools: list[Tool] = previous_agent_message.tools
|
||||
|
||||
if not tools:
|
||||
return ""
|
||||
|
||||
tool_strings = []
|
||||
for tool in tools:
|
||||
args_str = f'args: {str(tool.args)}' if tool.args_schema else ""
|
||||
tool_strings.append(f"{tool.name}: {tool.description}, {args_str}")
|
||||
formatted_tools = "\n".join(tool_strings)
|
||||
|
||||
tool_names = ", ".join([tool.name for tool in tools])
|
||||
|
||||
tool_prompt = dedent(f"""
|
||||
Below is a list of tools that are available for your use:
|
||||
{formatted_tools}
|
||||
|
||||
valid "tool_name" value is:
|
||||
{tool_names}
|
||||
""")
|
||||
|
||||
return tool_prompt
|
||||
|
||||
def handle_agent_data(self, **kwargs):
|
||||
if 'agents' not in kwargs:
|
||||
return ""
|
||||
|
||||
agents = kwargs.get('agents')
|
||||
random.shuffle(agents)
|
||||
agent_names = ", ".join([f'{agent.role.role_name}' for agent in agents])
|
||||
agent_descs = []
|
||||
for agent in agents:
|
||||
role_desc = agent.role.role_prompt.split("####")[1]
|
||||
while "\n\n" in role_desc:
|
||||
role_desc = role_desc.replace("\n\n", "\n")
|
||||
role_desc = role_desc.replace("\n", ",")
|
||||
|
||||
agent_descs.append(f'"role name: {agent.role.role_name}\nrole description: {role_desc}"')
|
||||
|
||||
agents = "\n".join(agent_descs)
|
||||
agent_prompt = f'''
|
||||
Please ensure your selection is one of the listed roles. Available roles for selection:
|
||||
{agents}
|
||||
Please ensure select the Role from agent names, such as {agent_names}'''
|
||||
|
||||
return dedent(agent_prompt)
|
||||
|
||||
def handle_doc_info(self, **kwargs) -> str:
|
||||
if 'previous_agent_message' not in kwargs:
|
||||
return ""
|
||||
previous_agent_message: Message = kwargs.get('previous_agent_message')
|
||||
db_docs = previous_agent_message.db_docs
|
||||
search_docs = previous_agent_message.search_docs
|
||||
code_cocs = previous_agent_message.code_docs
|
||||
doc_infos = "\n".join([doc.get_snippet() for doc in db_docs] + [doc.get_snippet() for doc in search_docs] +
|
||||
[doc.get_code() for doc in code_cocs])
|
||||
return doc_infos
|
||||
|
||||
def handle_session_records(self, **kwargs) -> str:
|
||||
|
||||
memory_pool: Memory = kwargs.get('memory_pool', Memory(messages=[]))
|
||||
memory_pool = self.select_memory_by_agent_name(memory_pool)
|
||||
memory_pool = self.select_memory_by_parsed_key(memory_pool)
|
||||
|
||||
return memory_pool.to_str_messages(content_key="parsed_output_list", with_tag=True)
|
||||
|
||||
def handle_current_plan(self, **kwargs) -> str:
|
||||
if 'previous_agent_message' not in kwargs:
|
||||
return ""
|
||||
previous_agent_message = kwargs['previous_agent_message']
|
||||
return previous_agent_message.parsed_output.get("CURRENT_STEP", "")
|
||||
|
||||
def handle_agent_profile(self, **kwargs) -> str:
|
||||
return extract_section(self.role_prompt, 'Agent Profile')
|
||||
|
||||
def handle_output_format(self, **kwargs) -> str:
|
||||
return extract_section(self.role_prompt, 'Response Output Format')
|
||||
|
||||
def handle_response(self, **kwargs) -> str:
|
||||
if 'react_memory' not in kwargs:
|
||||
return ""
|
||||
|
||||
react_memory = kwargs.get('react_memory', Memory(messages=[]))
|
||||
if react_memory is None:
|
||||
return ""
|
||||
|
||||
return "\n".join(["\n".join([f"**{k}:**\n{v}" for k,v in _dict.items()]) for _dict in react_memory.get_parserd_output()])
|
||||
|
||||
def handle_task_records(self, **kwargs) -> str:
|
||||
if 'task_memory' not in kwargs:
|
||||
return ""
|
||||
|
||||
task_memory: Memory = kwargs.get('task_memory', Memory(messages=[]))
|
||||
if task_memory is None:
|
||||
return ""
|
||||
|
||||
return "\n".join(["\n".join([f"**{k}:**\n{v}" for k,v in _dict.items() if k not in ["CURRENT_STEP"]]) for _dict in task_memory.get_parserd_output()])
|
||||
|
||||
def handle_previous_message(self, message: Message) -> str:
|
||||
pass
|
||||
|
||||
def handle_message_by_role_name(self, message: Message) -> str:
|
||||
pass
|
||||
|
||||
def handle_message_by_role_type(self, message: Message) -> str:
|
||||
pass
|
||||
|
||||
def handle_current_agent_react_message(self, message: Message) -> str:
|
||||
pass
|
||||
|
||||
def extract_codedoc_info_for_prompt(self, message: Message) -> str:
|
||||
code_docs = message.code_docs
|
||||
doc_infos = "\n".join([doc.get_code() for doc in code_docs])
|
||||
return doc_infos
|
||||
|
||||
def select_memory_by_parsed_key(self, memory: Memory) -> Memory:
|
||||
return Memory(
|
||||
messages=[self.select_message_by_parsed_key(message) for message in memory.messages
|
||||
if self.select_message_by_parsed_key(message) is not None]
|
||||
)
|
||||
|
||||
def select_memory_by_agent_name(self, memory: Memory) -> Memory:
|
||||
return Memory(
|
||||
messages=[self.select_message_by_agent_name(message) for message in memory.messages
|
||||
if self.select_message_by_agent_name(message) is not None]
|
||||
)
|
||||
|
||||
def select_message_by_agent_name(self, message: Message) -> Message:
|
||||
# assume we focus all agents
|
||||
if self.monitored_agents == []:
|
||||
return message
|
||||
return None if message is None or message.role_name not in self.monitored_agents else self.select_message_by_parsed_key(message)
|
||||
|
||||
def select_message_by_parsed_key(self, message: Message) -> Message:
|
||||
# assume we focus all key contents
|
||||
if message is None:
|
||||
return message
|
||||
|
||||
if self.monitored_fields == []:
|
||||
return message
|
||||
|
||||
message_c = copy.deepcopy(message)
|
||||
message_c.parsed_output = {k: v for k,v in message_c.parsed_output.items() if k in self.monitored_fields}
|
||||
message_c.parsed_output_list = [{k: v for k,v in parsed_output.items() if k in self.monitored_fields} for parsed_output in message_c.parsed_output_list]
|
||||
return message_c
|
||||
|
||||
def get_memory(self, content_key="role_content"):
|
||||
return self.memory.to_tuple_messages(content_key="step_content")
|
||||
|
||||
def get_memory_str(self, content_key="role_content"):
|
||||
return "\n".join([": ".join(i) for i in self.memory.to_tuple_messages(content_key="step_content")])
|
||||
|
||||
def register_fields_from_config(self):
|
||||
|
||||
for prompt_field in self.prompt_config:
|
||||
|
||||
function_name = prompt_field.function_name
|
||||
# 检查function_name是否是self的一个方法
|
||||
if function_name and hasattr(self, function_name):
|
||||
function = getattr(self, function_name)
|
||||
else:
|
||||
function = self.handle_custom_data
|
||||
|
||||
self.register_field(prompt_field.field_name,
|
||||
function=function,
|
||||
title=prompt_field.title,
|
||||
description=prompt_field.description,
|
||||
is_context=prompt_field.is_context,
|
||||
omit_if_empty=prompt_field.omit_if_empty)
|
||||
|
||||
def register_standard_fields(self):
|
||||
self.register_field('agent_profile', function=self.handle_agent_profile, is_context=False)
|
||||
self.register_field('tool_information', function=self.handle_tool_data, is_context=False)
|
||||
self.register_field('context_placeholder', is_context=True) # 用于标记上下文数据部分的位置
|
||||
self.register_field('reference_documents', function=self.handle_doc_info, is_context=True)
|
||||
self.register_field('session_records', function=self.handle_session_records, is_context=True)
|
||||
self.register_field('output_format', function=self.handle_output_format, title='Response Output Format', is_context=False)
|
||||
self.register_field('response', function=self.handle_response, is_context=False, omit_if_empty=False)
|
||||
|
||||
def register_executor_fields(self):
|
||||
self.register_field('agent_profile', function=self.handle_agent_profile, is_context=False)
|
||||
self.register_field('tool_information', function=self.handle_tool_data, is_context=False)
|
||||
self.register_field('context_placeholder', is_context=True) # 用于标记上下文数据部分的位置
|
||||
self.register_field('reference_documents', function=self.handle_doc_info, is_context=True)
|
||||
self.register_field('session_records', function=self.handle_session_records, is_context=True)
|
||||
self.register_field('current_plan', function=self.handle_current_plan, is_context=True)
|
||||
self.register_field('output_format', function=self.handle_output_format, title='Response Output Format', is_context=False)
|
||||
self.register_field('response', function=self.handle_response, is_context=False, omit_if_empty=False)
|
||||
|
||||
def register_fields_from_dict(self, fields_dict):
|
||||
# 使用字典注册字段的函数
|
||||
for field_name, field_config in fields_dict.items():
|
||||
function_name = field_config.get('function', None)
|
||||
title = field_config.get('title', None)
|
||||
description = field_config.get('description', None)
|
||||
is_context = field_config.get('is_context', True)
|
||||
omit_if_empty = field_config.get('omit_if_empty', True)
|
||||
|
||||
# 检查function_name是否是self的一个方法
|
||||
if function_name and hasattr(self, function_name):
|
||||
function = getattr(self, function_name)
|
||||
else:
|
||||
function = self.handle_custom_data
|
||||
|
||||
# 调用已存在的register_field方法注册字段
|
||||
self.register_field(field_name, function=function, title=title, description=description, is_context=is_context, omit_if_empty=omit_if_empty)
|
||||
|
||||
|
||||
|
||||
def main():
|
||||
manager = PromptManager()
|
||||
manager.register_standard_fields()
|
||||
|
||||
manager.register_field('agents_work_progress', title=f"Agents' Work Progress", is_context=True)
|
||||
|
||||
# 创建数据字典
|
||||
data_dict = {
|
||||
"agent_profile": "这是代理配置文件...",
|
||||
# "tool_list": "这是工具列表...",
|
||||
"reference_documents": "这是参考文档...",
|
||||
"session_records": "这是会话记录...",
|
||||
"agents_work_progress": "这是代理工作进展...",
|
||||
"output_format": "这是预期的输出格式...",
|
||||
# "response": "这是生成或继续回应的指令...",
|
||||
"response": "",
|
||||
"test": 'xxxxx'
|
||||
}
|
||||
|
||||
# 组合完整的提示
|
||||
full_prompt = manager.generate_full_prompt(data_dict)
|
||||
print(full_prompt)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
|
@ -215,15 +215,15 @@ class Env(BaseModel):
|
|||
class Role(BaseModel):
|
||||
role_type: str
|
||||
role_name: str
|
||||
role_desc: str
|
||||
agent_type: str = ""
|
||||
role_desc: str = ""
|
||||
agent_type: str = "BaseAgent"
|
||||
role_prompt: str = ""
|
||||
template_prompt: str = ""
|
||||
|
||||
|
||||
class ChainConfig(BaseModel):
|
||||
chain_name: str
|
||||
chain_type: str
|
||||
chain_type: str = "BaseChain"
|
||||
agents: List[str]
|
||||
do_checker: bool = False
|
||||
chat_turn: int = 1
|
||||
|
|
|
@ -131,6 +131,9 @@ class Memory(BaseModel):
|
|||
# logger.debug(f"{message.role_name}: {message.parsed_output_list}")
|
||||
# return [parsed_output for message in self.messages for parsed_output in message.parsed_output_list[1:]]
|
||||
return [parsed_output for message in self.messages for parsed_output in message.parsed_output_list]
|
||||
|
||||
def get_spec_parserd_output(self, ):
|
||||
return [message.spec_parsed_output for message in self.messages]
|
||||
|
||||
def get_rolenames(self, ):
|
||||
''''''
|
||||
|
|
|
@ -7,6 +7,7 @@ from .general_schema import *
|
|||
|
||||
class Message(BaseModel):
|
||||
chat_index: str = None
|
||||
user_name: str = "default"
|
||||
role_name: str
|
||||
role_type: str
|
||||
role_prompt: str = None
|
||||
|
@ -53,6 +54,8 @@ class Message(BaseModel):
|
|||
cb_search_type: str = None
|
||||
search_engine_name: str = None
|
||||
top_k: int = 3
|
||||
use_nh: bool = True
|
||||
local_graph_path: str = ''
|
||||
score_threshold: float = 1.0
|
||||
do_doc_retrieval: bool = False
|
||||
do_code_retrieval: bool = False
|
||||
|
|
|
@ -72,20 +72,25 @@ def parse_text_to_dict(text):
|
|||
def parse_dict_to_dict(parsed_dict) -> dict:
|
||||
code_pattern = r'```python\n(.*?)```'
|
||||
tool_pattern = r'```json\n(.*?)```'
|
||||
java_pattern = r'```java\n(.*?)```'
|
||||
|
||||
pattern_dict = {"code": code_pattern, "json": tool_pattern}
|
||||
pattern_dict = {"code": code_pattern, "json": tool_pattern, "java": java_pattern}
|
||||
spec_parsed_dict = copy.deepcopy(parsed_dict)
|
||||
for key, pattern in pattern_dict.items():
|
||||
for k, text in parsed_dict.items():
|
||||
# Search for the code block
|
||||
if not isinstance(text, str): continue
|
||||
if not isinstance(text, str):
|
||||
spec_parsed_dict[k] = text
|
||||
continue
|
||||
_match = re.search(pattern, text, re.DOTALL)
|
||||
if _match:
|
||||
# Add the code block to the dictionary
|
||||
try:
|
||||
spec_parsed_dict[key] = json.loads(_match.group(1).strip())
|
||||
spec_parsed_dict[k] = json.loads(_match.group(1).strip())
|
||||
except:
|
||||
spec_parsed_dict[key] = _match.group(1).strip()
|
||||
spec_parsed_dict[k] = _match.group(1).strip()
|
||||
break
|
||||
return spec_parsed_dict
|
||||
|
||||
|
|
|
@ -43,7 +43,7 @@ class NebulaHandler:
|
|||
elif self.space_name:
|
||||
cypher = f'USE {self.space_name};{cypher}'
|
||||
|
||||
logger.debug(cypher)
|
||||
# logger.debug(cypher)
|
||||
resp = session.execute(cypher)
|
||||
|
||||
if format_res:
|
||||
|
@ -247,6 +247,24 @@ class NebulaHandler:
|
|||
res = self.execute_cypher(cypher, self.space_name)
|
||||
return self.result_to_dict(res)
|
||||
|
||||
def get_all_vertices(self,):
|
||||
'''
|
||||
get all vertices
|
||||
@return:
|
||||
'''
|
||||
cypher = "MATCH (v) RETURN v;"
|
||||
res = self.execute_cypher(cypher, self.space_name)
|
||||
return self.result_to_dict(res)
|
||||
|
||||
def get_relative_vertices(self, vertice):
|
||||
'''
|
||||
get all vertices
|
||||
@return:
|
||||
'''
|
||||
cypher = f'''MATCH (v1)--(v2) WHERE id(v1) == '{vertice}' RETURN id(v2) as id;'''
|
||||
res = self.execute_cypher(cypher, self.space_name)
|
||||
return self.result_to_dict(res)
|
||||
|
||||
def result_to_dict(self, result) -> dict:
|
||||
"""
|
||||
build list for each column, and transform to dataframe
|
||||
|
|
|
@ -6,6 +6,7 @@ import os
|
|||
import pickle
|
||||
import uuid
|
||||
import warnings
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import (
|
||||
Any,
|
||||
|
@ -22,10 +23,22 @@ import numpy as np
|
|||
|
||||
from langchain.docstore.base import AddableMixin, Docstore
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.docstore.in_memory import InMemoryDocstore
|
||||
# from langchain.docstore.in_memory import InMemoryDocstore
|
||||
from .in_memory import InMemoryDocstore
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.vectorstores.base import VectorStore
|
||||
from langchain.vectorstores.utils import DistanceStrategy, maximal_marginal_relevance
|
||||
from langchain.vectorstores.utils import maximal_marginal_relevance
|
||||
|
||||
|
||||
class DistanceStrategy(str, Enum):
|
||||
"""Enumerator of the Distance strategies for calculating distances
|
||||
between vectors."""
|
||||
|
||||
EUCLIDEAN_DISTANCE = "EUCLIDEAN_DISTANCE"
|
||||
MAX_INNER_PRODUCT = "MAX_INNER_PRODUCT"
|
||||
DOT_PRODUCT = "DOT_PRODUCT"
|
||||
JACCARD = "JACCARD"
|
||||
COSINE = "COSINE"
|
||||
|
||||
|
||||
def dependable_faiss_import(no_avx2: Optional[bool] = None) -> Any:
|
||||
|
@ -219,6 +232,9 @@ class FAISS(VectorStore):
|
|||
if self._normalize_L2:
|
||||
faiss.normalize_L2(vector)
|
||||
scores, indices = self.index.search(vector, k if filter is None else fetch_k)
|
||||
# 经过normalize的结果会超出1
|
||||
if self._normalize_L2:
|
||||
scores = np.array([row / np.linalg.norm(row) if np.max(row) > 1 else row for row in scores])
|
||||
docs = []
|
||||
for j, i in enumerate(indices[0]):
|
||||
if i == -1:
|
||||
|
@ -565,7 +581,7 @@ class FAISS(VectorStore):
|
|||
vecstore = cls(
|
||||
embedding.embed_query,
|
||||
index,
|
||||
InMemoryDocstore(),
|
||||
InMemoryDocstore({}),
|
||||
{},
|
||||
normalize_L2=normalize_L2,
|
||||
distance_strategy=distance_strategy,
|
||||
|
|
|
@ -10,13 +10,14 @@ from loguru import logger
|
|||
# from configs.model_config import EMBEDDING_MODEL
|
||||
from coagent.embeddings.openai_embedding import OpenAIEmbedding
|
||||
from coagent.embeddings.huggingface_embedding import HFEmbedding
|
||||
|
||||
from coagent.llm_models.llm_config import EmbedConfig
|
||||
|
||||
def get_embedding(
|
||||
engine: str,
|
||||
text_list: list,
|
||||
model_path: str = "text2vec-base-chinese",
|
||||
embedding_device: str = "cpu",
|
||||
embed_config: EmbedConfig = None,
|
||||
):
|
||||
'''
|
||||
get embedding
|
||||
|
@ -25,8 +26,12 @@ def get_embedding(
|
|||
@return:
|
||||
'''
|
||||
emb_res = {}
|
||||
|
||||
if engine == 'openai':
|
||||
if embed_config and embed_config.langchain_embeddings:
|
||||
emb_res = embed_config.langchain_embeddings.embed_documents(text_list)
|
||||
emb_res = {
|
||||
text_list[idx]: emb_res[idx] for idx in range(len(text_list))
|
||||
}
|
||||
elif engine == 'openai':
|
||||
oae = OpenAIEmbedding()
|
||||
emb_res = oae.get_emb(text_list)
|
||||
elif engine == 'model':
|
||||
|
|
|
@ -0,0 +1,49 @@
|
|||
"""Simple in memory docstore in the form of a dict."""
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
from langchain.docstore.base import AddableMixin, Docstore
|
||||
from langchain.docstore.document import Document
|
||||
|
||||
|
||||
class InMemoryDocstore(Docstore, AddableMixin):
|
||||
"""Simple in memory docstore in the form of a dict."""
|
||||
|
||||
def __init__(self, _dict: Optional[Dict[str, Document]] = None):
|
||||
"""Initialize with dict."""
|
||||
self._dict = _dict if _dict is not None else {}
|
||||
|
||||
def add(self, texts: Dict[str, Document]) -> None:
|
||||
"""Add texts to in memory dictionary.
|
||||
|
||||
Args:
|
||||
texts: dictionary of id -> document.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
overlapping = set(texts).intersection(self._dict)
|
||||
if overlapping:
|
||||
raise ValueError(f"Tried to add ids that already exist: {overlapping}")
|
||||
self._dict = {**self._dict, **texts}
|
||||
|
||||
def delete(self, ids: List) -> None:
|
||||
"""Deleting IDs from in memory dictionary."""
|
||||
overlapping = set(ids).intersection(self._dict)
|
||||
if not overlapping:
|
||||
raise ValueError(f"Tried to delete ids that does not exist: {ids}")
|
||||
for _id in ids:
|
||||
self._dict.pop(_id)
|
||||
|
||||
def search(self, search: str) -> Union[str, Document]:
|
||||
"""Search via direct lookup.
|
||||
|
||||
Args:
|
||||
search: id of a document to search for.
|
||||
|
||||
Returns:
|
||||
Document if found, else error message.
|
||||
"""
|
||||
if search not in self._dict:
|
||||
return f"ID {search} not found."
|
||||
else:
|
||||
return self._dict[search]
|
|
@ -1,6 +1,8 @@
|
|||
import os
|
||||
from functools import lru_cache
|
||||
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
|
||||
from langchain.embeddings.base import Embeddings
|
||||
|
||||
# from configs.model_config import embedding_model_dict
|
||||
from loguru import logger
|
||||
|
||||
|
@ -12,8 +14,11 @@ def load_embeddings(model: str, device: str, embedding_model_dict: dict):
|
|||
return embeddings
|
||||
|
||||
|
||||
@lru_cache(1)
|
||||
def load_embeddings_from_path(model_path: str, device: str):
|
||||
# @lru_cache(1)
|
||||
def load_embeddings_from_path(model_path: str, device: str, langchain_embeddings: Embeddings = None):
|
||||
if langchain_embeddings:
|
||||
return langchain_embeddings
|
||||
|
||||
embeddings = HuggingFaceEmbeddings(model_name=model_path,
|
||||
model_kwargs={'device': device})
|
||||
return embeddings
|
||||
|
|
|
@ -1,8 +1,8 @@
|
|||
from .openai_model import getChatModel, getExtraModel, getChatModelFromConfig
|
||||
from .openai_model import getExtraModel, getChatModelFromConfig
|
||||
from .llm_config import LLMConfig, EmbedConfig
|
||||
|
||||
|
||||
__all__ = [
|
||||
"getChatModel", "getExtraModel", "getChatModelFromConfig",
|
||||
"getExtraModel", "getChatModelFromConfig",
|
||||
"LLMConfig", "EmbedConfig"
|
||||
]
|
|
@ -1,6 +1,9 @@
|
|||
from dataclasses import dataclass
|
||||
from typing import List, Union
|
||||
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.llms.base import LLM, BaseLLM
|
||||
|
||||
|
||||
|
||||
@dataclass
|
||||
|
@ -12,7 +15,8 @@ class LLMConfig:
|
|||
stop: Union[List[str], str] = None,
|
||||
api_key: str = "",
|
||||
api_base_url: str = "",
|
||||
model_device: str = "cpu",
|
||||
model_device: str = "cpu", # unuse,will delete it
|
||||
llm: LLM = None,
|
||||
**kwargs
|
||||
):
|
||||
|
||||
|
@ -21,7 +25,7 @@ class LLMConfig:
|
|||
self.stop: Union[List[str], str] = stop
|
||||
self.api_key: str = api_key
|
||||
self.api_base_url: str = api_base_url
|
||||
self.model_device: str = model_device
|
||||
self.llm: LLM = llm
|
||||
#
|
||||
self.check_config()
|
||||
|
||||
|
@ -42,6 +46,7 @@ class EmbedConfig:
|
|||
embed_model_path: str = "",
|
||||
embed_engine: str = "",
|
||||
model_device: str = "cpu",
|
||||
langchain_embeddings: Embeddings = None,
|
||||
**kwargs
|
||||
):
|
||||
self.embed_model: str = embed_model
|
||||
|
@ -51,6 +56,8 @@ class EmbedConfig:
|
|||
self.api_key: str = api_key
|
||||
self.api_base_url: str = api_base_url
|
||||
#
|
||||
self.langchain_embeddings = langchain_embeddings
|
||||
#
|
||||
self.check_config()
|
||||
|
||||
def check_config(self, ):
|
||||
|
|
|
@ -1,38 +1,54 @@
|
|||
import os
|
||||
from typing import Union, Optional, List
|
||||
from loguru import logger
|
||||
|
||||
from langchain.callbacks import AsyncIteratorCallbackHandler
|
||||
from langchain.chat_models import ChatOpenAI
|
||||
from langchain.llms.base import LLM
|
||||
|
||||
from .llm_config import LLMConfig
|
||||
# from configs.model_config import (llm_model_dict, LLM_MODEL)
|
||||
|
||||
|
||||
def getChatModel(callBack: AsyncIteratorCallbackHandler = None, temperature=0.3, stop=None):
|
||||
if callBack is None:
|
||||
class CustomLLMModel:
|
||||
|
||||
def __init__(self, llm: LLM):
|
||||
self.llm: LLM = llm
|
||||
|
||||
def __call__(self, prompt: str,
|
||||
stop: Optional[List[str]] = None):
|
||||
return self.llm(prompt, stop)
|
||||
|
||||
def _call(self, prompt: str,
|
||||
stop: Optional[List[str]] = None):
|
||||
return self.llm(prompt, stop)
|
||||
|
||||
def predict(self, prompt: str,
|
||||
stop: Optional[List[str]] = None):
|
||||
return self.llm(prompt, stop)
|
||||
|
||||
def batch(self, prompts: str,
|
||||
stop: Optional[List[str]] = None):
|
||||
return [self.llm(prompt, stop) for prompt in prompts]
|
||||
|
||||
|
||||
def getChatModelFromConfig(llm_config: LLMConfig, callBack: AsyncIteratorCallbackHandler = None, ) -> Union[ChatOpenAI, LLM]:
|
||||
# logger.debug(f"llm type is {type(llm_config.llm)}")
|
||||
if llm_config is None:
|
||||
model = ChatOpenAI(
|
||||
streaming=True,
|
||||
verbose=True,
|
||||
openai_api_key=llm_model_dict[LLM_MODEL]["api_key"],
|
||||
openai_api_base=llm_model_dict[LLM_MODEL]["api_base_url"],
|
||||
model_name=LLM_MODEL,
|
||||
temperature=temperature,
|
||||
stop=stop
|
||||
openai_api_key=os.environ.get("api_key"),
|
||||
openai_api_base=os.environ.get("api_base_url"),
|
||||
model_name=os.environ.get("LLM_MODEL", "gpt-3.5-turbo"),
|
||||
temperature=os.environ.get("temperature", 0.5),
|
||||
stop=os.environ.get("stop", ""),
|
||||
)
|
||||
else:
|
||||
model = ChatOpenAI(
|
||||
streaming=True,
|
||||
verbose=True,
|
||||
callBack=[callBack],
|
||||
openai_api_key=llm_model_dict[LLM_MODEL]["api_key"],
|
||||
openai_api_base=llm_model_dict[LLM_MODEL]["api_base_url"],
|
||||
model_name=LLM_MODEL,
|
||||
temperature=temperature,
|
||||
stop=stop
|
||||
)
|
||||
return model
|
||||
return model
|
||||
|
||||
|
||||
def getChatModelFromConfig(llm_config: LLMConfig, callBack: AsyncIteratorCallbackHandler = None, ):
|
||||
if llm_config and llm_config.llm and isinstance(llm_config.llm, LLM):
|
||||
return CustomLLMModel(llm=llm_config.llm)
|
||||
|
||||
if callBack is None:
|
||||
model = ChatOpenAI(
|
||||
streaming=True,
|
||||
|
|
|
@ -0,0 +1,5 @@
|
|||
# from .base_retrieval import *
|
||||
|
||||
# __all__ = [
|
||||
# "IMRertrieval", "BaseDocRetrieval", "BaseCodeRetrieval", "BaseSearchRetrieval"
|
||||
# ]
|
|
@ -0,0 +1,75 @@
|
|||
|
||||
from coagent.llm_models.llm_config import EmbedConfig, LLMConfig
|
||||
from coagent.base_configs.env_config import KB_ROOT_PATH
|
||||
from coagent.tools import DocRetrieval, CodeRetrieval
|
||||
|
||||
|
||||
class IMRertrieval:
|
||||
|
||||
def __init__(self,):
|
||||
'''
|
||||
init your personal attributes
|
||||
'''
|
||||
pass
|
||||
|
||||
def run(self, ):
|
||||
'''
|
||||
execute interface, and can use init' attributes
|
||||
'''
|
||||
pass
|
||||
|
||||
|
||||
class BaseDocRetrieval(IMRertrieval):
|
||||
|
||||
def __init__(self, knowledge_base_name: str, search_top=5, score_threshold=1.0, embed_config: EmbedConfig=EmbedConfig(), kb_root_path: str=KB_ROOT_PATH):
|
||||
self.knowledge_base_name = knowledge_base_name
|
||||
self.search_top = search_top
|
||||
self.score_threshold = score_threshold
|
||||
self.embed_config = embed_config
|
||||
self.kb_root_path = kb_root_path
|
||||
|
||||
def run(self, query: str, search_top=None, score_threshold=None, ):
|
||||
docs = DocRetrieval.run(
|
||||
query=query, knowledge_base_name=self.knowledge_base_name,
|
||||
search_top=search_top or self.search_top,
|
||||
score_threshold=score_threshold or self.score_threshold,
|
||||
embed_config=self.embed_config,
|
||||
kb_root_path=self.kb_root_path
|
||||
)
|
||||
return docs
|
||||
|
||||
|
||||
class BaseCodeRetrieval(IMRertrieval):
|
||||
|
||||
def __init__(self, code_base_name, embed_config: EmbedConfig, llm_config: LLMConfig, search_type = 'tag', code_limit = 1, local_graph_path: str=""):
|
||||
self.code_base_name = code_base_name
|
||||
self.embed_config = embed_config
|
||||
self.llm_config = llm_config
|
||||
self.search_type = search_type
|
||||
self.code_limit = code_limit
|
||||
self.use_nh: bool = False
|
||||
self.local_graph_path: str = local_graph_path
|
||||
|
||||
def run(self, query, history_node_list=[], search_type = None, code_limit=None):
|
||||
code_docs = CodeRetrieval.run(
|
||||
code_base_name=self.code_base_name,
|
||||
query=query,
|
||||
history_node_list=history_node_list,
|
||||
code_limit=code_limit or self.code_limit,
|
||||
search_type=search_type or self.search_type,
|
||||
llm_config=self.llm_config,
|
||||
embed_config=self.embed_config,
|
||||
use_nh=self.use_nh,
|
||||
local_graph_path=self.local_graph_path
|
||||
)
|
||||
return code_docs
|
||||
|
||||
|
||||
|
||||
class BaseSearchRetrieval(IMRertrieval):
|
||||
|
||||
def __init__(self, ):
|
||||
pass
|
||||
|
||||
def run(self, ):
|
||||
pass
|
|
@ -0,0 +1,6 @@
|
|||
from .json_loader import JSONLoader
|
||||
from .jsonl_loader import JSONLLoader
|
||||
|
||||
__all__ = [
|
||||
"JSONLoader", "JSONLLoader"
|
||||
]
|
|
@ -0,0 +1,61 @@
|
|||
import json
|
||||
from pathlib import Path
|
||||
from typing import AnyStr, Callable, Dict, List, Optional, Union
|
||||
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.document_loaders.base import BaseLoader
|
||||
from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter
|
||||
|
||||
from coagent.utils.common_utils import read_json_file
|
||||
|
||||
|
||||
class JSONLoader(BaseLoader):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
file_path: Union[str, Path],
|
||||
schema_key: str = "all_text",
|
||||
content_key: Optional[str] = None,
|
||||
metadata_func: Optional[Callable[[Dict, Dict], Dict]] = None,
|
||||
text_content: bool = True,
|
||||
):
|
||||
self.file_path = Path(file_path).resolve()
|
||||
self.schema_key = schema_key
|
||||
self._content_key = content_key
|
||||
self._metadata_func = metadata_func
|
||||
self._text_content = text_content
|
||||
|
||||
def load(self, ) -> List[Document]:
|
||||
"""Load and return documents from the JSON file."""
|
||||
docs: List[Document] = []
|
||||
datas = read_json_file(self.file_path)
|
||||
self._parse(datas, docs)
|
||||
return docs
|
||||
|
||||
def _parse(self, datas: List, docs: List[Document]) -> None:
|
||||
for idx, sample in enumerate(datas):
|
||||
metadata = dict(
|
||||
source=str(self.file_path),
|
||||
seq_num=idx,
|
||||
)
|
||||
text = sample.get(self.schema_key, "")
|
||||
docs.append(Document(page_content=text, metadata=metadata))
|
||||
|
||||
def load_and_split(
|
||||
self, text_splitter: Optional[TextSplitter] = None
|
||||
) -> List[Document]:
|
||||
"""Load Documents and split into chunks. Chunks are returned as Documents.
|
||||
|
||||
Args:
|
||||
text_splitter: TextSplitter instance to use for splitting documents.
|
||||
Defaults to RecursiveCharacterTextSplitter.
|
||||
|
||||
Returns:
|
||||
List of Documents.
|
||||
"""
|
||||
if text_splitter is None:
|
||||
_text_splitter: TextSplitter = RecursiveCharacterTextSplitter()
|
||||
else:
|
||||
_text_splitter = text_splitter
|
||||
docs = self.load()
|
||||
return _text_splitter.split_documents(docs)
|
|
@ -0,0 +1,62 @@
|
|||
import json
|
||||
from pathlib import Path
|
||||
from typing import AnyStr, Callable, Dict, List, Optional, Union
|
||||
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.document_loaders.base import BaseLoader
|
||||
from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter
|
||||
|
||||
from coagent.utils.common_utils import read_jsonl_file
|
||||
|
||||
|
||||
class JSONLLoader(BaseLoader):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
file_path: Union[str, Path],
|
||||
schema_key: str = "all_text",
|
||||
content_key: Optional[str] = None,
|
||||
metadata_func: Optional[Callable[[Dict, Dict], Dict]] = None,
|
||||
text_content: bool = True,
|
||||
):
|
||||
self.file_path = Path(file_path).resolve()
|
||||
self.schema_key = schema_key
|
||||
self._content_key = content_key
|
||||
self._metadata_func = metadata_func
|
||||
self._text_content = text_content
|
||||
|
||||
def load(self, ) -> List[Document]:
|
||||
"""Load and return documents from the JSON file."""
|
||||
docs: List[Document] = []
|
||||
datas = read_jsonl_file(self.file_path)
|
||||
self._parse(datas, docs)
|
||||
return docs
|
||||
|
||||
def _parse(self, datas: List, docs: List[Document]) -> None:
|
||||
for idx, sample in enumerate(datas):
|
||||
metadata = dict(
|
||||
source=str(self.file_path),
|
||||
seq_num=idx,
|
||||
)
|
||||
text = sample.get(self.schema_key, "")
|
||||
docs.append(Document(page_content=text, metadata=metadata))
|
||||
|
||||
def load_and_split(
|
||||
self, text_splitter: Optional[TextSplitter] = None
|
||||
) -> List[Document]:
|
||||
"""Load Documents and split into chunks. Chunks are returned as Documents.
|
||||
|
||||
Args:
|
||||
text_splitter: TextSplitter instance to use for splitting documents.
|
||||
Defaults to RecursiveCharacterTextSplitter.
|
||||
|
||||
Returns:
|
||||
List of Documents.
|
||||
"""
|
||||
if text_splitter is None:
|
||||
_text_splitter: TextSplitter = RecursiveCharacterTextSplitter()
|
||||
else:
|
||||
_text_splitter = text_splitter
|
||||
|
||||
docs = self.load()
|
||||
return _text_splitter.split_documents(docs)
|
|
@ -0,0 +1,3 @@
|
|||
from .langchain_splitter import LCTextSplitter
|
||||
|
||||
__all__ = ["LCTextSplitter"]
|
|
@ -0,0 +1,77 @@
|
|||
import os
|
||||
import importlib
|
||||
from loguru import logger
|
||||
|
||||
from langchain.document_loaders.base import BaseLoader
|
||||
from langchain.text_splitter import (
|
||||
SpacyTextSplitter, RecursiveCharacterTextSplitter
|
||||
)
|
||||
|
||||
# from configs.model_config import (
|
||||
# CHUNK_SIZE,
|
||||
# OVERLAP_SIZE,
|
||||
# ZH_TITLE_ENHANCE
|
||||
# )
|
||||
from coagent.utils.path_utils import *
|
||||
|
||||
|
||||
|
||||
class LCTextSplitter:
|
||||
'''langchain textsplitter 执行file2text'''
|
||||
def __init__(
|
||||
self, filepath: str, text_splitter_name: str = None,
|
||||
chunk_size: int = 500,
|
||||
overlap_size: int = 50
|
||||
):
|
||||
self.filepath = filepath
|
||||
self.ext = os.path.splitext(filepath)[-1].lower()
|
||||
self.text_splitter_name = text_splitter_name
|
||||
self.chunk_size = chunk_size
|
||||
self.overlap_size = overlap_size
|
||||
if self.ext not in SUPPORTED_EXTS:
|
||||
raise ValueError(f"暂未支持的文件格式 {self.ext}")
|
||||
self.document_loader_name = get_LoaderClass(self.ext)
|
||||
|
||||
def file2text(self, ):
|
||||
loader = self._load_document()
|
||||
text_splitter = self._load_text_splitter()
|
||||
if self.document_loader_name in ["JSONLoader", "JSONLLoader"]:
|
||||
# docs = loader.load()
|
||||
docs = loader.load_and_split(text_splitter)
|
||||
# logger.debug(f"please check your file can be loaded, docs.lens {len(docs)}")
|
||||
else:
|
||||
docs = loader.load_and_split(text_splitter)
|
||||
|
||||
return docs
|
||||
|
||||
def _load_document(self, ) -> BaseLoader:
|
||||
DocumentLoader = EXT2LOADER_DICT[self.ext]
|
||||
if self.document_loader_name == "UnstructuredFileLoader":
|
||||
loader = DocumentLoader(self.filepath, autodetect_encoding=True)
|
||||
else:
|
||||
loader = DocumentLoader(self.filepath)
|
||||
return loader
|
||||
|
||||
def _load_text_splitter(self, ):
|
||||
try:
|
||||
if self.text_splitter_name is None:
|
||||
text_splitter = SpacyTextSplitter(
|
||||
pipeline="zh_core_web_sm",
|
||||
chunk_size=self.chunk_size,
|
||||
chunk_overlap=self.overlap_size,
|
||||
)
|
||||
self.text_splitter_name = "SpacyTextSplitter"
|
||||
# elif self.document_loader_name in ["JSONLoader", "JSONLLoader"]:
|
||||
# text_splitter = None
|
||||
else:
|
||||
text_splitter_module = importlib.import_module('langchain.text_splitter')
|
||||
TextSplitter = getattr(text_splitter_module, self.text_splitter_name)
|
||||
text_splitter = TextSplitter(
|
||||
chunk_size=self.chunk_size,
|
||||
chunk_overlap=self.overlap_size)
|
||||
except Exception as e:
|
||||
text_splitter = RecursiveCharacterTextSplitter(
|
||||
chunk_size=self.chunk_size,
|
||||
chunk_overlap=self.overlap_size,
|
||||
)
|
||||
return text_splitter
|
|
@ -32,8 +32,8 @@ class PyCodeBox(BaseBox):
|
|||
self.do_check_net = do_check_net
|
||||
self.use_stop = use_stop
|
||||
self.jupyter_work_path = jupyter_work_path
|
||||
asyncio.run(self.astart())
|
||||
# self.start()
|
||||
# asyncio.run(self.astart())
|
||||
self.start()
|
||||
|
||||
# logger.info(f"""remote_url: {self.remote_url},
|
||||
# remote_ip: {self.remote_ip},
|
||||
|
@ -199,13 +199,13 @@ class PyCodeBox(BaseBox):
|
|||
|
||||
async def _aget_kernelid(self, ) -> None:
|
||||
headers = {"Authorization": f'Token {self.token}', 'token': self.token}
|
||||
response = requests.get(f"{self.kernel_url}?token={self.token}", headers=headers)
|
||||
# response = requests.get(f"{self.kernel_url}?token={self.token}", headers=headers)
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(f"{self.kernel_url}?token={self.token}", headers=headers) as resp:
|
||||
async with session.get(f"{self.kernel_url}?token={self.token}", headers=headers, timeout=270) as resp:
|
||||
if len(await resp.json()) > 0:
|
||||
self.kernel_id = (await resp.json())[0]["id"]
|
||||
else:
|
||||
async with session.post(f"{self.kernel_url}?token={self.token}", headers=headers) as response:
|
||||
async with session.post(f"{self.kernel_url}?token={self.token}", headers=headers, timeout=270) as response:
|
||||
self.kernel_id = (await response.json())["id"]
|
||||
|
||||
# if len(response.json()) > 0:
|
||||
|
@ -220,41 +220,45 @@ class PyCodeBox(BaseBox):
|
|||
return False
|
||||
|
||||
try:
|
||||
response = requests.get(f"{self.kernel_url}?token={self.token}", timeout=270)
|
||||
response = requests.get(f"{self.kernel_url}?token={self.token}", timeout=10)
|
||||
return response.status_code == 200
|
||||
except requests.exceptions.ConnectionError:
|
||||
return False
|
||||
except requests.exceptions.ReadTimeout:
|
||||
return False
|
||||
|
||||
async def _acheck_connect(self, ) -> bool:
|
||||
if self.kernel_url == "":
|
||||
return False
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(f"{self.kernel_url}?token={self.token}", timeout=270) as resp:
|
||||
async with session.get(f"{self.kernel_url}?token={self.token}", timeout=10) as resp:
|
||||
return resp.status == 200
|
||||
except aiohttp.ClientConnectorError:
|
||||
pass
|
||||
return False
|
||||
except aiohttp.ServerDisconnectedError:
|
||||
pass
|
||||
return False
|
||||
|
||||
def _check_port(self, ) -> bool:
|
||||
try:
|
||||
response = requests.get(f"{self.remote_ip}:{self.remote_port}", timeout=270)
|
||||
response = requests.get(f"{self.remote_ip}:{self.remote_port}", timeout=10)
|
||||
logger.warning(f"Port is conflict, please check your codebox's port {self.remote_port}")
|
||||
return response.status_code == 200
|
||||
except requests.exceptions.ConnectionError:
|
||||
return False
|
||||
except requests.exceptions.ReadTimeout:
|
||||
return False
|
||||
|
||||
async def _acheck_port(self, ) -> bool:
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(f"{self.remote_ip}:{self.remote_port}", timeout=270) as resp:
|
||||
async with session.get(f"{self.remote_ip}:{self.remote_port}", timeout=10) as resp:
|
||||
# logger.warning(f"Port is conflict, please check your codebox's port {self.remote_port}")
|
||||
return resp.status == 200
|
||||
except aiohttp.ClientConnectorError:
|
||||
pass
|
||||
return False
|
||||
except aiohttp.ServerDisconnectedError:
|
||||
pass
|
||||
return False
|
||||
|
||||
def _check_connect_success(self, retry_nums: int = 2) -> bool:
|
||||
if not self.do_check_net: return True
|
||||
|
@ -263,7 +267,7 @@ class PyCodeBox(BaseBox):
|
|||
try:
|
||||
connect_status = self._check_connect()
|
||||
if connect_status:
|
||||
logger.info(f"{self.remote_url} connection success")
|
||||
# logger.info(f"{self.remote_url} connection success")
|
||||
return True
|
||||
except requests.exceptions.ConnectionError:
|
||||
logger.info(f"{self.remote_url} connection fail")
|
||||
|
@ -301,10 +305,12 @@ class PyCodeBox(BaseBox):
|
|||
else:
|
||||
# TODO 自动检测本地接口
|
||||
port_status = self._check_port()
|
||||
self.kernel_url = self.remote_url + "/api/kernels"
|
||||
connect_status = self._check_connect()
|
||||
logger.info(f"port_status: {port_status}, connect_status: {connect_status}")
|
||||
if os.environ.get("log_verbose", "0") >= "2":
|
||||
logger.info(f"port_status: {port_status}, connect_status: {connect_status}")
|
||||
if port_status and not connect_status:
|
||||
raise BaseException(f"Port is conflict, please check your codebox's port {self.remote_port}")
|
||||
logger.error("Port is conflict, please check your codebox's port {self.remote_port}")
|
||||
|
||||
if not connect_status:
|
||||
self.jupyter = subprocess.Popen(
|
||||
|
@ -321,14 +327,32 @@ class PyCodeBox(BaseBox):
|
|||
stdout=subprocess.PIPE,
|
||||
)
|
||||
|
||||
record = []
|
||||
while True and self.jupyter and len(record)<100:
|
||||
line = self.jupyter.stderr.readline()
|
||||
try:
|
||||
content = line.decode("utf-8")
|
||||
except:
|
||||
content = line.decode("gbk")
|
||||
# logger.debug(content)
|
||||
record.append(content)
|
||||
if "control-c" in content.lower():
|
||||
break
|
||||
|
||||
self.kernel_url = self.remote_url + "/api/kernels"
|
||||
self.do_check_net = True
|
||||
self._check_connect_success()
|
||||
self._get_kernelid()
|
||||
# logger.debug(self.kernel_url.replace("http", "ws") + f"/{self.kernel_id}/channels?token={self.token}")
|
||||
self.wc_url = self.kernel_url.replace("http", "ws") + f"/{self.kernel_id}/channels?token={self.token}"
|
||||
headers = {"Authorization": f'Token {self.token}', 'token': self.token}
|
||||
self.ws = create_connection(self.wc_url, headers=headers)
|
||||
retry_nums = 3
|
||||
while retry_nums>=0:
|
||||
try:
|
||||
self.ws = create_connection(self.wc_url, headers=headers, timeout=5)
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"create ws connection timeout {e}")
|
||||
retry_nums -= 1
|
||||
|
||||
async def astart(self, ):
|
||||
'''判断是从外部service执行还是内部启动notebook执行'''
|
||||
|
@ -369,10 +393,16 @@ class PyCodeBox(BaseBox):
|
|||
cwd=self.jupyter_work_path
|
||||
)
|
||||
|
||||
while True and self.jupyter:
|
||||
record = []
|
||||
while True and self.jupyter and len(record)<100:
|
||||
line = self.jupyter.stderr.readline()
|
||||
# logger.debug(line.decode("gbk"))
|
||||
if "Control-C" in line.decode("gbk"):
|
||||
try:
|
||||
content = line.decode("utf-8")
|
||||
except:
|
||||
content = line.decode("gbk")
|
||||
# logger.debug(content)
|
||||
record.append(content)
|
||||
if "control-c" in content.lower():
|
||||
break
|
||||
self.kernel_url = self.remote_url + "/api/kernels"
|
||||
self.do_check_net = True
|
||||
|
@ -380,7 +410,15 @@ class PyCodeBox(BaseBox):
|
|||
await self._aget_kernelid()
|
||||
self.wc_url = self.kernel_url.replace("http", "ws") + f"/{self.kernel_id}/channels?token={self.token}"
|
||||
headers = {"Authorization": f'Token {self.token}', 'token': self.token}
|
||||
self.ws = create_connection(self.wc_url, headers=headers)
|
||||
|
||||
retry_nums = 3
|
||||
while retry_nums>=0:
|
||||
try:
|
||||
self.ws = create_connection(self.wc_url, headers=headers, timeout=5)
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"create ws connection timeout {e}")
|
||||
retry_nums -= 1
|
||||
|
||||
def status(self,) -> CodeBoxStatus:
|
||||
if not self.kernel_id:
|
||||
|
|
|
@ -17,7 +17,7 @@ from coagent.orm.commands import *
|
|||
from coagent.utils.path_utils import *
|
||||
from coagent.orm.utils import DocumentFile
|
||||
from coagent.embeddings.utils import load_embeddings, load_embeddings_from_path
|
||||
from coagent.text_splitter import LCTextSplitter
|
||||
from coagent.retrieval.text_splitter import LCTextSplitter
|
||||
from coagent.llm_models.llm_config import EmbedConfig
|
||||
|
||||
|
||||
|
@ -46,7 +46,7 @@ class KBService(ABC):
|
|||
|
||||
def _load_embeddings(self) -> Embeddings:
|
||||
# return load_embeddings(self.embed_model, embed_device, embedding_model_dict)
|
||||
return load_embeddings_from_path(self.embed_config.embed_model_path, self.embed_config.model_device)
|
||||
return load_embeddings_from_path(self.embed_config.embed_model_path, self.embed_config.model_device, self.embed_config.langchain_embeddings)
|
||||
|
||||
def create_kb(self):
|
||||
"""
|
||||
|
|
|
@ -20,9 +20,6 @@ from coagent.utils.path_utils import *
|
|||
from coagent.orm.commands import *
|
||||
from coagent.db_handler.graph_db_handler.nebula_handler import NebulaHandler
|
||||
from coagent.db_handler.vector_db_handler.chroma_handler import ChromaHandler
|
||||
# from configs.server_config import NEBULA_HOST, NEBULA_PORT, NEBULA_USER, NEBULA_PASSWORD, NEBULA_STORAGED_PORT
|
||||
# from configs.server_config import CHROMA_PERSISTENT_PATH
|
||||
|
||||
from coagent.base_configs.env_config import (
|
||||
CB_ROOT_PATH,
|
||||
NEBULA_HOST, NEBULA_PORT, NEBULA_USER, NEBULA_PASSWORD, NEBULA_STORAGED_PORT,
|
||||
|
@ -58,10 +55,11 @@ async def create_cb(zip_file,
|
|||
model_name: bool = Body(..., examples=["samples"]),
|
||||
temperature: bool = Body(..., examples=["samples"]),
|
||||
model_device: bool = Body(..., examples=["samples"]),
|
||||
embed_config: EmbedConfig = None,
|
||||
) -> BaseResponse:
|
||||
logger.info('cb_name={}, zip_path={}, do_interpret={}'.format(cb_name, code_path, do_interpret))
|
||||
|
||||
embed_config: EmbedConfig = EmbedConfig(**locals())
|
||||
embed_config: EmbedConfig = EmbedConfig(**locals()) if embed_config is None else embed_config
|
||||
llm_config: LLMConfig = LLMConfig(**locals())
|
||||
|
||||
# Create selected knowledge base
|
||||
|
@ -101,9 +99,10 @@ async def delete_cb(
|
|||
model_name: bool = Body(..., examples=["samples"]),
|
||||
temperature: bool = Body(..., examples=["samples"]),
|
||||
model_device: bool = Body(..., examples=["samples"]),
|
||||
embed_config: EmbedConfig = None,
|
||||
) -> BaseResponse:
|
||||
logger.info('cb_name={}'.format(cb_name))
|
||||
embed_config: EmbedConfig = EmbedConfig(**locals())
|
||||
embed_config: EmbedConfig = EmbedConfig(**locals()) if embed_config is None else embed_config
|
||||
llm_config: LLMConfig = LLMConfig(**locals())
|
||||
# Create selected knowledge base
|
||||
if not validate_kb_name(cb_name):
|
||||
|
@ -143,18 +142,24 @@ def search_code(cb_name: str = Body(..., examples=["sofaboot"]),
|
|||
model_name: bool = Body(..., examples=["samples"]),
|
||||
temperature: bool = Body(..., examples=["samples"]),
|
||||
model_device: bool = Body(..., examples=["samples"]),
|
||||
use_nh: bool = True,
|
||||
local_graph_path: str = '',
|
||||
embed_config: EmbedConfig = None,
|
||||
) -> dict:
|
||||
|
||||
logger.info('cb_name={}'.format(cb_name))
|
||||
logger.info('query={}'.format(query))
|
||||
logger.info('code_limit={}'.format(code_limit))
|
||||
logger.info('search_type={}'.format(search_type))
|
||||
logger.info('history_node_list={}'.format(history_node_list))
|
||||
embed_config: EmbedConfig = EmbedConfig(**locals())
|
||||
|
||||
if os.environ.get("log_verbose", "0") >= "2":
|
||||
logger.info(f'local_graph_path={local_graph_path}')
|
||||
logger.info('cb_name={}'.format(cb_name))
|
||||
logger.info('query={}'.format(query))
|
||||
logger.info('code_limit={}'.format(code_limit))
|
||||
logger.info('search_type={}'.format(search_type))
|
||||
logger.info('history_node_list={}'.format(history_node_list))
|
||||
embed_config: EmbedConfig = EmbedConfig(**locals()) if embed_config is None else embed_config
|
||||
llm_config: LLMConfig = LLMConfig(**locals())
|
||||
try:
|
||||
# load codebase
|
||||
cbh = CodeBaseHandler(codebase_name=cb_name, embed_config=embed_config, llm_config=llm_config)
|
||||
cbh = CodeBaseHandler(codebase_name=cb_name, embed_config=embed_config, llm_config=llm_config,
|
||||
use_nh=use_nh, local_graph_path=local_graph_path)
|
||||
|
||||
# search code
|
||||
context, related_vertices = cbh.search_code(query, search_type=search_type, limit=code_limit)
|
||||
|
@ -179,11 +184,13 @@ def search_related_vertices(cb_name: str = Body(..., examples=["sofaboot"]),
|
|||
# load codebase
|
||||
nh = NebulaHandler(host=NEBULA_HOST, port=NEBULA_PORT, username=NEBULA_USER,
|
||||
password=NEBULA_PASSWORD, space_name=cb_name)
|
||||
|
||||
cypher = f'''MATCH (v1)--(v2) WHERE id(v1) == '{vertex}' RETURN id(v2) as id;'''
|
||||
|
||||
|
||||
if vertex.endswith(".java"):
|
||||
cypher = f'''MATCH (v1)--(v2:package) WHERE id(v1) == '{vertex}' RETURN id(v2) as id;'''
|
||||
else:
|
||||
cypher = f'''MATCH (v1)--(v2) WHERE id(v1) == '{vertex}' RETURN id(v2) as id;'''
|
||||
# cypher = f'''MATCH (v1)--(v2) WHERE id(v1) == '{vertex}' RETURN v2;'''
|
||||
cypher_res = nh.execute_cypher(cypher=cypher, format_res=True)
|
||||
|
||||
related_vertices = cypher_res.get('id', [])
|
||||
related_vertices = [i.as_string() for i in related_vertices]
|
||||
|
||||
|
@ -200,8 +207,8 @@ def search_related_vertices(cb_name: str = Body(..., examples=["sofaboot"]),
|
|||
def search_code_by_vertex(cb_name: str = Body(..., examples=["sofaboot"]),
|
||||
vertex: str = Body(..., examples=['***'])) -> dict:
|
||||
|
||||
logger.info('cb_name={}'.format(cb_name))
|
||||
logger.info('vertex={}'.format(vertex))
|
||||
# logger.info('cb_name={}'.format(cb_name))
|
||||
# logger.info('vertex={}'.format(vertex))
|
||||
|
||||
try:
|
||||
nh = NebulaHandler(host=NEBULA_HOST, port=NEBULA_PORT, username=NEBULA_USER,
|
||||
|
@ -233,7 +240,7 @@ def search_code_by_vertex(cb_name: str = Body(..., examples=["sofaboot"]),
|
|||
return res
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
return {}
|
||||
return {'code': ""}
|
||||
|
||||
|
||||
def cb_exists_api(cb_name: str = Body(..., examples=["sofaboot"])) -> bool:
|
||||
|
|
|
@ -8,17 +8,6 @@ from loguru import logger
|
|||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
|
||||
from langchain.vectorstores.utils import DistanceStrategy
|
||||
|
||||
# from configs.model_config import (
|
||||
# KB_ROOT_PATH,
|
||||
# CACHED_VS_NUM,
|
||||
# EMBEDDING_MODEL,
|
||||
# EMBEDDING_DEVICE,
|
||||
# SCORE_THRESHOLD,
|
||||
# FAISS_NORMALIZE_L2
|
||||
# )
|
||||
# from configs.model_config import embedding_model_dict
|
||||
|
||||
from coagent.base_configs.env_config import (
|
||||
KB_ROOT_PATH,
|
||||
|
@ -52,15 +41,15 @@ def load_vector_store(
|
|||
tick: int = 0, # tick will be changed by upload_doc etc. and make cache refreshed.
|
||||
kb_root_path: str = KB_ROOT_PATH,
|
||||
):
|
||||
print(f"loading vector store in '{knowledge_base_name}'.")
|
||||
# print(f"loading vector store in '{knowledge_base_name}'.")
|
||||
vs_path = get_vs_path(knowledge_base_name, kb_root_path)
|
||||
if embeddings is None:
|
||||
embeddings = load_embeddings_from_path(embed_config.embed_model_path, embed_config.model_device)
|
||||
embeddings = load_embeddings_from_path(embed_config.embed_model_path, embed_config.model_device, embed_config.langchain_embeddings)
|
||||
|
||||
if not os.path.exists(vs_path):
|
||||
os.makedirs(vs_path)
|
||||
|
||||
distance_strategy = DistanceStrategy.EUCLIDEAN_DISTANCE
|
||||
distance_strategy = "EUCLIDEAN_DISTANCE"
|
||||
if "index.faiss" in os.listdir(vs_path):
|
||||
search_index = FAISS.load_local(vs_path, embeddings, normalize_L2=FAISS_NORMALIZE_L2, distance_strategy=distance_strategy)
|
||||
else:
|
||||
|
|
|
@ -9,9 +9,7 @@ from pydantic import BaseModel, Field
|
|||
from loguru import logger
|
||||
|
||||
from coagent.llm_models import LLMConfig, EmbedConfig
|
||||
|
||||
from .base_tool import BaseToolModel
|
||||
|
||||
from coagent.service.cb_api import search_code
|
||||
|
||||
|
||||
|
@ -29,7 +27,17 @@ class CodeRetrieval(BaseToolModel):
|
|||
code: str = Field(..., description="检索代码")
|
||||
|
||||
@classmethod
|
||||
def run(cls, code_base_name, query, code_limit=1, history_node_list=[], search_type="tag", llm_config: LLMConfig=None, embed_config: EmbedConfig=None):
|
||||
def run(cls,
|
||||
code_base_name,
|
||||
query,
|
||||
code_limit=1,
|
||||
history_node_list=[],
|
||||
search_type="tag",
|
||||
llm_config: LLMConfig=None,
|
||||
embed_config: EmbedConfig=None,
|
||||
use_nh: str=True,
|
||||
local_graph_path: str=''
|
||||
):
|
||||
"""excute your tool!"""
|
||||
|
||||
search_type = {
|
||||
|
@ -45,7 +53,8 @@ class CodeRetrieval(BaseToolModel):
|
|||
codes = search_code(code_base_name, query, code_limit, search_type=search_type, history_node_list=history_node_list,
|
||||
embed_engine=embed_config.embed_engine, embed_model=embed_config.embed_model, embed_model_path=embed_config.embed_model_path,
|
||||
model_device=embed_config.model_device, model_name=llm_config.model_name, temperature=llm_config.temperature,
|
||||
api_base_url=llm_config.api_base_url, api_key=llm_config.api_key
|
||||
api_base_url=llm_config.api_base_url, api_key=llm_config.api_key, use_nh=use_nh,
|
||||
local_graph_path=local_graph_path, embed_config=embed_config
|
||||
)
|
||||
return_codes = []
|
||||
context = codes['context']
|
||||
|
|
|
@ -5,6 +5,7 @@
|
|||
@time: 2023/12/14 上午10:24
|
||||
@desc:
|
||||
'''
|
||||
import os
|
||||
from pydantic import BaseModel, Field
|
||||
from loguru import logger
|
||||
|
||||
|
@ -40,10 +41,9 @@ class CodeRetrievalSingle(BaseToolModel):
|
|||
vertex: str = Field(..., description="代码对应 id")
|
||||
|
||||
@classmethod
|
||||
def run(cls, code_base_name, query, embed_config: EmbedConfig, llm_config: LLMConfig, **kargs):
|
||||
def run(cls, code_base_name, query, embed_config: EmbedConfig, llm_config: LLMConfig, search_type="description", **kargs):
|
||||
"""excute your tool!"""
|
||||
|
||||
search_type = 'description'
|
||||
code_limit = 1
|
||||
|
||||
# default
|
||||
|
@ -51,10 +51,11 @@ class CodeRetrievalSingle(BaseToolModel):
|
|||
history_node_list=[],
|
||||
embed_engine=embed_config.embed_engine, embed_model=embed_config.embed_model, embed_model_path=embed_config.embed_model_path,
|
||||
model_device=embed_config.model_device, model_name=llm_config.model_name, temperature=llm_config.temperature,
|
||||
api_base_url=llm_config.api_base_url, api_key=llm_config.api_key
|
||||
api_base_url=llm_config.api_base_url, api_key=llm_config.api_key, embed_config=embed_config, use_nh=kargs.get("use_nh", True),
|
||||
local_graph_path=kargs.get("local_graph_path", "")
|
||||
)
|
||||
|
||||
logger.debug(search_result)
|
||||
if os.environ.get("log_verbose", "0") >= "3":
|
||||
logger.debug(search_result)
|
||||
code = search_result['context']
|
||||
vertex = search_result['related_vertices'][0]
|
||||
# logger.debug(f"code: {code}, vertex: {vertex}")
|
||||
|
@ -83,7 +84,7 @@ class RelatedVerticesRetrival(BaseToolModel):
|
|||
def run(cls, code_base_name: str, vertex: str, **kargs):
|
||||
"""execute your tool!"""
|
||||
related_vertices = search_related_vertices(cb_name=code_base_name, vertex=vertex)
|
||||
logger.debug(f"related_vertices: {related_vertices}")
|
||||
# logger.debug(f"related_vertices: {related_vertices}")
|
||||
|
||||
return related_vertices
|
||||
|
||||
|
@ -110,6 +111,6 @@ class Vertex2Code(BaseToolModel):
|
|||
else:
|
||||
vertex = vertex.strip(' "')
|
||||
|
||||
logger.info(f'vertex={vertex}')
|
||||
# logger.info(f'vertex={vertex}')
|
||||
res = search_code_by_vertex(cb_name=code_base_name, vertex=vertex)
|
||||
return res
|
|
@ -2,11 +2,7 @@ from pydantic import BaseModel, Field
|
|||
from loguru import logger
|
||||
|
||||
from coagent.llm_models.llm_config import EmbedConfig
|
||||
|
||||
from .base_tool import BaseToolModel
|
||||
|
||||
|
||||
|
||||
from coagent.service.kb_api import search_docs
|
||||
|
||||
|
||||
|
|
|
@ -9,8 +9,10 @@ import numpy as np
|
|||
from loguru import logger
|
||||
|
||||
from .base_tool import BaseToolModel
|
||||
|
||||
from duckduckgo_search import DDGS
|
||||
try:
|
||||
from duckduckgo_search import DDGS
|
||||
except:
|
||||
logger.warning("can't find duckduckgo_search, if you need it, please `pip install duckduckgo_search`")
|
||||
|
||||
|
||||
class DDGSTool(BaseToolModel):
|
||||
|
|
|
@ -0,0 +1,89 @@
|
|||
import json
|
||||
|
||||
|
||||
def class_info_decode(data):
|
||||
'''解析class的相关信息'''
|
||||
params_dict = {}
|
||||
|
||||
for i in data:
|
||||
_params_dict = {}
|
||||
for ii in i:
|
||||
for k, v in ii.items():
|
||||
if k=="origin_query": continue
|
||||
|
||||
if k == "Code Path":
|
||||
_params_dict["code_path"] = v.split("#")[0]
|
||||
_params_dict["function_name"] = ".".join(v.split("#")[1:])
|
||||
|
||||
if k == "Class Description":
|
||||
_params_dict["ClassDescription"] = v
|
||||
|
||||
if k == "Class Base":
|
||||
_params_dict["ClassBase"] = v
|
||||
|
||||
if k=="Init Parameters":
|
||||
_params_dict["Parameters"] = v
|
||||
|
||||
|
||||
code_path = _params_dict["code_path"]
|
||||
params_dict.setdefault(code_path, []).append(_params_dict)
|
||||
|
||||
return params_dict
|
||||
|
||||
def method_info_decode(data):
|
||||
params_dict = {}
|
||||
|
||||
for i in data:
|
||||
_params_dict = {}
|
||||
for ii in i:
|
||||
for k, v in ii.items():
|
||||
if k=="origin_query": continue
|
||||
|
||||
if k == "Code Path":
|
||||
_params_dict["code_path"] = v.split("#")[0]
|
||||
_params_dict["function_name"] = ".".join(v.split("#")[1:])
|
||||
|
||||
if k == "Return Value Description":
|
||||
_params_dict["Returns"] = v
|
||||
|
||||
if k == "Return Type":
|
||||
_params_dict["ReturnType"] = v
|
||||
|
||||
if k=="Parameters":
|
||||
_params_dict["Parameters"] = v
|
||||
|
||||
|
||||
code_path = _params_dict["code_path"]
|
||||
params_dict.setdefault(code_path, []).append(_params_dict)
|
||||
|
||||
return params_dict
|
||||
|
||||
def encode2md(data, md_format):
|
||||
md_dict = {}
|
||||
for code_path, params_list in data.items():
|
||||
for params in params_list:
|
||||
params["Parameters_text"] = "\n".join([f"{param['param']}({param['param_type']})-{param['param_description']}"
|
||||
for param in params["Parameters"]])
|
||||
# params.delete("Parameters")
|
||||
text=md_format.format(**params)
|
||||
md_dict.setdefault(code_path, []).append(text)
|
||||
return md_dict
|
||||
|
||||
|
||||
method_text_md = '''> {function_name}
|
||||
|
||||
| Column Name | Content |
|
||||
|-----------------|-----------------|
|
||||
| Parameters | {Parameters_text} |
|
||||
| Returns | {Returns} |
|
||||
| Return type | {ReturnType} |
|
||||
'''
|
||||
|
||||
class_text_md = '''> {code_path}
|
||||
|
||||
Bases: {ClassBase}
|
||||
|
||||
{ClassDescription}
|
||||
|
||||
{Parameters_text}
|
||||
'''
|
|
@ -7,7 +7,7 @@ from pathlib import Path
|
|||
from io import BytesIO
|
||||
from fastapi import Body, File, Form, Body, Query, UploadFile
|
||||
from tempfile import SpooledTemporaryFile
|
||||
|
||||
import json
|
||||
|
||||
|
||||
DATE_FORMAT = "%Y-%m-%d %H:%M:%S"
|
||||
|
@ -109,4 +109,6 @@ def get_uploadfile(file: Union[str, Path, bytes], filename=None) -> UploadFile:
|
|||
temp_file = SpooledTemporaryFile(max_size=10 * 1024 * 1024)
|
||||
temp_file.write(file.read())
|
||||
temp_file.seek(0)
|
||||
return UploadFile(file=temp_file, filename=filename)
|
||||
return UploadFile(file=temp_file, filename=filename)
|
||||
|
||||
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
import os
|
||||
from langchain.document_loaders import CSVLoader, PyPDFLoader, UnstructuredFileLoader, TextLoader, PythonLoader
|
||||
|
||||
from coagent.document_loaders import JSONLLoader, JSONLoader
|
||||
from coagent.retrieval.document_loaders import JSONLLoader, JSONLoader
|
||||
# from configs.model_config import (
|
||||
# embedding_model_dict,
|
||||
# KB_ROOT_PATH,
|
||||
|
|
|
@ -21,17 +21,20 @@ JUPYTER_WORK_PATH = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath
|
|||
# WEB_CRAWL存储路径
|
||||
WEB_CRAWL_PATH = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "knowledge_base")
|
||||
# NEBULA_DATA存储路径
|
||||
NELUBA_PATH = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "data/neluba_data")
|
||||
NEBULA_PATH = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "data/nebula_data")
|
||||
|
||||
for _path in [LOG_PATH, SOURCE_PATH, KB_ROOT_PATH, NLTK_DATA_PATH, JUPYTER_WORK_PATH, WEB_CRAWL_PATH, NELUBA_PATH]:
|
||||
# CHROMA 存储路径
|
||||
CHROMA_PERSISTENT_PATH = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "data/chroma_data")
|
||||
|
||||
for _path in [LOG_PATH, SOURCE_PATH, KB_ROOT_PATH, CB_ROOT_PATH, NLTK_DATA_PATH, JUPYTER_WORK_PATH, WEB_CRAWL_PATH, NEBULA_PATH, CHROMA_PERSISTENT_PATH]:
|
||||
if not os.path.exists(_path):
|
||||
os.makedirs(_path, exist_ok=True)
|
||||
|
||||
#
|
||||
|
||||
path_envt_dict = {
|
||||
"LOG_PATH": LOG_PATH, "SOURCE_PATH": SOURCE_PATH, "KB_ROOT_PATH": KB_ROOT_PATH,
|
||||
"NLTK_DATA_PATH":NLTK_DATA_PATH, "JUPYTER_WORK_PATH": JUPYTER_WORK_PATH,
|
||||
"WEB_CRAWL_PATH": WEB_CRAWL_PATH, "NELUBA_PATH": NELUBA_PATH
|
||||
"WEB_CRAWL_PATH": WEB_CRAWL_PATH, "NEBULA_PATH": NEBULA_PATH,
|
||||
"CHROMA_PERSISTENT_PATH": CHROMA_PERSISTENT_PATH
|
||||
}
|
||||
for path_name, _path in path_envt_dict.items():
|
||||
os.environ[path_name] = _path
|
||||
|
|
|
@ -33,7 +33,7 @@ except:
|
|||
pass
|
||||
|
||||
# add your openai key
|
||||
OPENAI_API_BASE = "http://openai.com/v1/chat/completions"
|
||||
OPENAI_API_BASE = "https://api.openai.com/v1"
|
||||
os.environ["API_BASE_URL"] = OPENAI_API_BASE
|
||||
os.environ["OPENAI_API_KEY"] = "sk-xx"
|
||||
openai.api_key = "sk-xx"
|
||||
|
|
|
@ -58,9 +58,6 @@ NEBULA_GRAPH_SERVER = {
|
|||
"docker_port": NEBULA_PORT
|
||||
}
|
||||
|
||||
# chroma conf
|
||||
CHROMA_PERSISTENT_PATH = '/home/user/chatbot/data/chroma_data'
|
||||
|
||||
# sandbox api server
|
||||
SANDBOX_CONTRAINER_NAME = "devopsgpt_sandbox"
|
||||
SANDBOX_IMAGE_NAME = "devopsgpt:py39"
|
||||
|
|
|
@ -15,11 +15,11 @@ from coagent.connector.schema import Message
|
|||
#
|
||||
tools = toLangchainTools([TOOL_DICT[i] for i in TOOL_SETS if i in TOOL_DICT])
|
||||
# log-level,print prompt和llm predict
|
||||
os.environ["log_verbose"] = "0"
|
||||
os.environ["log_verbose"] = "2"
|
||||
|
||||
phase_name = "baseGroupPhase"
|
||||
llm_config = LLMConfig(
|
||||
model_name=LLM_MODEL, model_device="cpu",api_key=os.environ["OPENAI_API_KEY"],
|
||||
model_name=LLM_MODEL, api_key=os.environ["OPENAI_API_KEY"],
|
||||
api_base_url=os.environ["API_BASE_URL"], temperature=0.3
|
||||
)
|
||||
embed_config = EmbedConfig(
|
||||
|
|
|
@ -17,7 +17,7 @@ os.environ["log_verbose"] = "2"
|
|||
|
||||
phase_name = "baseTaskPhase"
|
||||
llm_config = LLMConfig(
|
||||
model_name="gpt-3.5-turbo", model_device="cpu",api_key=os.environ["OPENAI_API_KEY"],
|
||||
model_name="gpt-3.5-turbo", api_key=os.environ["OPENAI_API_KEY"],
|
||||
api_base_url=os.environ["API_BASE_URL"], temperature=0.3
|
||||
)
|
||||
embed_config = EmbedConfig(
|
||||
|
|
|
@ -0,0 +1,135 @@
|
|||
# encoding: utf-8
|
||||
'''
|
||||
@author: 温进
|
||||
@file: codeChatPhaseLocal_example.py
|
||||
@time: 2024/1/31 下午4:32
|
||||
@desc:
|
||||
'''
|
||||
import os, sys, requests
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from tqdm import tqdm
|
||||
|
||||
import requests
|
||||
from typing import List
|
||||
|
||||
src_dir = os.path.join(
|
||||
os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
)
|
||||
sys.path.append(src_dir)
|
||||
|
||||
from configs.model_config import KB_ROOT_PATH, JUPYTER_WORK_PATH, CB_ROOT_PATH
|
||||
from configs.server_config import SANDBOX_SERVER
|
||||
from coagent.tools import toLangchainTools, TOOL_DICT, TOOL_SETS
|
||||
from coagent.llm_models.llm_config import EmbedConfig, LLMConfig
|
||||
from coagent.connector.phase import BasePhase
|
||||
from coagent.connector.schema import Message, Memory
|
||||
from coagent.codechat.codebase_handler.codebase_handler import CodeBaseHandler
|
||||
|
||||
|
||||
|
||||
# log-level,print prompt和llm predict
|
||||
os.environ["log_verbose"] = "1"
|
||||
|
||||
llm_config = LLMConfig(
|
||||
model_name="gpt-3.5-turbo", model_device="cpu",api_key=os.environ["OPENAI_API_KEY"],
|
||||
api_base_url=os.environ["API_BASE_URL"], temperature=0.3
|
||||
)
|
||||
embed_config = EmbedConfig(
|
||||
embed_engine="model", embed_model="text2vec-base-chinese",
|
||||
embed_model_path=os.path.join(src_dir, "embedding_models/text2vec-base-chinese")
|
||||
)
|
||||
|
||||
|
||||
# delete codebase
|
||||
codebase_name = 'client_local'
|
||||
code_path = '/Users/bingxu/Desktop/工作/大模型/chatbot/test_code_repo/client'
|
||||
code_path = "D://chromeDownloads/devopschat-bot/client_v2/client"
|
||||
use_nh = True
|
||||
# cbh = CodeBaseHandler(codebase_name, code_path, crawl_type='dir', use_nh=use_nh, local_graph_path=CB_ROOT_PATH,
|
||||
# llm_config=llm_config, embed_config=embed_config)
|
||||
cbh = CodeBaseHandler(codebase_name, code_path, crawl_type='dir', use_nh=use_nh, local_graph_path=CB_ROOT_PATH,
|
||||
llm_config=llm_config, embed_config=embed_config)
|
||||
cbh.delete_codebase(codebase_name=codebase_name)
|
||||
|
||||
|
||||
# initialize codebase
|
||||
codebase_name = 'client_local'
|
||||
code_path = '/Users/bingxu/Desktop/工作/大模型/chatbot/test_code_repo/client'
|
||||
code_path = "D://chromeDownloads/devopschat-bot/client_v2/client"
|
||||
code_path = "/home/user/client"
|
||||
use_nh = True
|
||||
do_interpret = True
|
||||
cbh = CodeBaseHandler(codebase_name, code_path, crawl_type='dir', use_nh=use_nh, local_graph_path=CB_ROOT_PATH,
|
||||
llm_config=llm_config, embed_config=embed_config)
|
||||
cbh.import_code(do_interpret=do_interpret)
|
||||
|
||||
|
||||
|
||||
# chat with codebase
|
||||
phase_name = "codeChatPhase"
|
||||
phase = BasePhase(
|
||||
phase_name, sandbox_server=SANDBOX_SERVER, jupyter_work_path=JUPYTER_WORK_PATH,
|
||||
embed_config=embed_config, llm_config=llm_config, kb_root_path=KB_ROOT_PATH,
|
||||
)
|
||||
|
||||
# remove 这个函数是做什么的 => 基于标签
|
||||
# 有没有函数已经实现了从字符串删除指定字符串的功能,使用的话可以怎么使用,写个java代码 => 基于描述
|
||||
# 有根据我以下的需求用 java 开发一个方法:输入为字符串,将输入中的 .java 字符串给删除掉,然后返回新的字符串 => 基于描述
|
||||
|
||||
## 需要启动容器中的nebula,采用use_nh=True来构建代码库,是可以通过cypher来查询
|
||||
# round-1
|
||||
# query_content = "代码一共有多少类"
|
||||
# query = Message(
|
||||
# role_name="human", role_type="user",
|
||||
# role_content=query_content, input_query=query_content, origin_query=query_content,
|
||||
# code_engine_name="client_1", score_threshold=1.0, top_k=3, cb_search_type="cypher"
|
||||
# )
|
||||
#
|
||||
# output_message1, _ = phase.step(query)
|
||||
# print(output_message1)
|
||||
|
||||
# round-2
|
||||
# query_content = "代码库里有哪些函数,返回5个就行"
|
||||
# query = Message(
|
||||
# role_name="human", role_type="user",
|
||||
# role_content=query_content, input_query=query_content, origin_query=query_content,
|
||||
# code_engine_name="client_1", score_threshold=1.0, top_k=3, cb_search_type="cypher"
|
||||
# )
|
||||
# output_message2, _ = phase.step(query)
|
||||
# print(output_message2)
|
||||
|
||||
|
||||
# round-3
|
||||
query_content = "remove 这个函数是做什么的"
|
||||
query = Message(
|
||||
role_name="user", role_type="human",
|
||||
role_content=query_content, input_query=query_content, origin_query=query_content,
|
||||
code_engine_name=codebase_name, score_threshold=1.0, top_k=3, cb_search_type="tag",
|
||||
use_nh=False, local_graph_path=CB_ROOT_PATH
|
||||
)
|
||||
output_message3, output_memory3 = phase.step(query)
|
||||
print(output_memory3.to_str_messages(return_all=True, content_key="parsed_output_list"))
|
||||
|
||||
#
|
||||
# # round-4
|
||||
query_content = "有没有函数已经实现了从字符串删除指定字符串的功能,使用的话可以怎么使用,写个java代码"
|
||||
query = Message(
|
||||
role_name="human", role_type="user",
|
||||
role_content=query_content, input_query=query_content, origin_query=query_content,
|
||||
code_engine_name=codebase_name, score_threshold=1.0, top_k=3, cb_search_type="description",
|
||||
use_nh=False, local_graph_path=CB_ROOT_PATH
|
||||
)
|
||||
output_message4, output_memory4 = phase.step(query)
|
||||
print(output_memory4.to_str_messages(return_all=True, content_key="parsed_output_list"))
|
||||
|
||||
|
||||
# # round-5
|
||||
query_content = "有根据我以下的需求用 java 开发一个方法:输入为字符串,将输入中的 .java 字符串给删除掉,然后返回新的字符串"
|
||||
query = Message(
|
||||
role_name="human", role_type="user",
|
||||
role_content=query_content, input_query=query_content, origin_query=query_content,
|
||||
code_engine_name=codebase_name, score_threshold=1.0, top_k=3, cb_search_type="description",
|
||||
use_nh=False, local_graph_path=CB_ROOT_PATH
|
||||
)
|
||||
output_message5, output_memory5 = phase.step(query)
|
||||
print(output_memory5.to_str_messages(return_all=True, content_key="parsed_output_list"))
|
|
@ -17,13 +17,14 @@ os.environ["log_verbose"] = "2"
|
|||
|
||||
phase_name = "codeChatPhase"
|
||||
llm_config = LLMConfig(
|
||||
model_name="gpt-3.5-turbo", model_device="cpu",api_key=os.environ["OPENAI_API_KEY"],
|
||||
model_name="gpt-3.5-turbo", api_key=os.environ["OPENAI_API_KEY"],
|
||||
api_base_url=os.environ["API_BASE_URL"], temperature=0.3
|
||||
)
|
||||
embed_config = EmbedConfig(
|
||||
embed_engine="model", embed_model="text2vec-base-chinese",
|
||||
embed_engine="model", embed_model="text2vec-base-chinese",
|
||||
embed_model_path=os.path.join(src_dir, "embedding_models/text2vec-base-chinese")
|
||||
)
|
||||
|
||||
phase = BasePhase(
|
||||
phase_name, sandbox_server=SANDBOX_SERVER, jupyter_work_path=JUPYTER_WORK_PATH,
|
||||
embed_config=embed_config, llm_config=llm_config, kb_root_path=KB_ROOT_PATH,
|
||||
|
@ -35,50 +36,56 @@ phase = BasePhase(
|
|||
# 有根据我以下的需求用 java 开发一个方法:输入为字符串,将输入中的 .java 字符串给删除掉,然后返回新的字符串 => 基于描述
|
||||
|
||||
# round-1
|
||||
query_content = "代码一共有多少类"
|
||||
query = Message(
|
||||
role_name="human", role_type="user",
|
||||
role_content=query_content, input_query=query_content, origin_query=query_content,
|
||||
code_engine_name="client", score_threshold=1.0, top_k=3, cb_search_type="cypher"
|
||||
)
|
||||
|
||||
output_message1, _ = phase.step(query)
|
||||
# query_content = "代码一共有多少类"
|
||||
# query = Message(
|
||||
# role_name="human", role_type="user",
|
||||
# role_content=query_content, input_query=query_content, origin_query=query_content,
|
||||
# code_engine_name="client_1", score_threshold=1.0, top_k=3, cb_search_type="cypher"
|
||||
# )
|
||||
#
|
||||
# output_message1, _ = phase.step(query)
|
||||
# print(output_message1)
|
||||
|
||||
# round-2
|
||||
query_content = "代码库里有哪些函数,返回5个就行"
|
||||
query = Message(
|
||||
role_name="human", role_type="user",
|
||||
role_content=query_content, input_query=query_content, origin_query=query_content,
|
||||
code_engine_name="client", score_threshold=1.0, top_k=3, cb_search_type="cypher"
|
||||
)
|
||||
output_message2, _ = phase.step(query)
|
||||
# query_content = "代码库里有哪些函数,返回5个就行"
|
||||
# query = Message(
|
||||
# role_name="human", role_type="user",
|
||||
# role_content=query_content, input_query=query_content, origin_query=query_content,
|
||||
# code_engine_name="client_1", score_threshold=1.0, top_k=3, cb_search_type="cypher"
|
||||
# )
|
||||
# output_message2, _ = phase.step(query)
|
||||
# print(output_message2)
|
||||
|
||||
# round-3
|
||||
#
|
||||
# # round-3
|
||||
query_content = "remove 这个函数是做什么的"
|
||||
query = Message(
|
||||
role_name="user", role_type="human",
|
||||
role_name="user", role_type="human",
|
||||
role_content=query_content, input_query=query_content, origin_query=query_content,
|
||||
code_engine_name="client", score_threshold=1.0, top_k=3, cb_search_type="tag"
|
||||
)
|
||||
output_message3, _ = phase.step(query)
|
||||
print(output_message3)
|
||||
|
||||
# round-4
|
||||
query_content = "有没有函数已经实现了从字符串删除指定字符串的功能,使用的话可以怎么使用,写个java代码"
|
||||
query = Message(
|
||||
role_name="human", role_type="user",
|
||||
role_content=query_content, input_query=query_content, origin_query=query_content,
|
||||
code_engine_name="client", score_threshold=1.0, top_k=3, cb_search_type="description"
|
||||
)
|
||||
output_message4, _ = phase.step(query)
|
||||
|
||||
|
||||
# round-5
|
||||
query_content = "有根据我以下的需求用 java 开发一个方法:输入为字符串,将输入中的 .java 字符串给删除掉,然后返回新的字符串"
|
||||
query = Message(
|
||||
role_name="human", role_type="user",
|
||||
role_content=query_content, input_query=query_content, origin_query=query_content,
|
||||
code_engine_name="client", score_threshold=1.0, top_k=3, cb_search_type="description"
|
||||
)
|
||||
output_message5, output_memory5 = phase.step(query)
|
||||
|
||||
print(output_memory5.to_str_messages(return_all=True, content_key="parsed_output_list"))
|
||||
#
|
||||
# # round-4
|
||||
# query_content = "有没有函数已经实现了从字符串删除指定字符串的功能,使用的话可以怎么使用,写个java代码"
|
||||
# query = Message(
|
||||
# role_name="human", role_type="user",
|
||||
# role_content=query_content, input_query=query_content, origin_query=query_content,
|
||||
# code_engine_name="client_1", score_threshold=1.0, top_k=3, cb_search_type="description"
|
||||
# )
|
||||
# output_message4, _ = phase.step(query)
|
||||
# print(output_message4)
|
||||
#
|
||||
# # round-5
|
||||
# query_content = "有根据我以下的需求用 java 开发一个方法:输入为字符串,将输入中的 .java 字符串给删除掉,然后返回新的字符串"
|
||||
# query = Message(
|
||||
# role_name="human", role_type="user",
|
||||
# role_content=query_content, input_query=query_content, origin_query=query_content,
|
||||
# code_engine_name="client_1", score_threshold=1.0, top_k=3, cb_search_type="description"
|
||||
# )
|
||||
# output_message5, output_memory5 = phase.step(query)
|
||||
# print(output_message5)
|
||||
#
|
||||
# print(output_memory5.to_str_messages(return_all=True, content_key="parsed_output_list"))
|
|
@ -0,0 +1,507 @@
|
|||
import os, sys, json
|
||||
from loguru import logger
|
||||
src_dir = os.path.join(
|
||||
os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
)
|
||||
sys.path.append(src_dir)
|
||||
|
||||
from configs.model_config import KB_ROOT_PATH, JUPYTER_WORK_PATH, CB_ROOT_PATH
|
||||
from configs.server_config import SANDBOX_SERVER
|
||||
from coagent.llm_models.llm_config import EmbedConfig, LLMConfig
|
||||
|
||||
from coagent.connector.phase import BasePhase
|
||||
from coagent.connector.agents import BaseAgent
|
||||
from coagent.connector.schema import Message
|
||||
from coagent.tools import CodeRetrievalSingle
|
||||
from coagent.codechat.codebase_handler.codebase_handler import CodeBaseHandler
|
||||
import importlib
|
||||
|
||||
|
||||
# 定义一个新的agent类
|
||||
class CodeGenDocer(BaseAgent):
|
||||
|
||||
def start_action_step(self, message: Message) -> Message:
|
||||
'''do action before agent predict '''
|
||||
# 根据问题获取代码片段和节点信息
|
||||
action_json = CodeRetrievalSingle.run(message.code_engine_name, message.origin_query,
|
||||
llm_config=self.llm_config, embed_config=self.embed_config, local_graph_path=message.local_graph_path, use_nh=message.use_nh,search_type="tag")
|
||||
current_vertex = action_json['vertex']
|
||||
message.customed_kargs["Code Snippet"] = action_json["code"]
|
||||
message.customed_kargs['Current_Vertex'] = current_vertex
|
||||
return message
|
||||
|
||||
|
||||
# add agent or prompt_manager class
|
||||
agent_module = importlib.import_module("coagent.connector.agents")
|
||||
setattr(agent_module, 'CodeGenDocer', CodeGenDocer)
|
||||
|
||||
|
||||
# log-level,print prompt和llm predict
|
||||
os.environ["log_verbose"] = "1"
|
||||
|
||||
phase_name = "code2DocsGroup"
|
||||
llm_config = LLMConfig(
|
||||
model_name="gpt-4", api_key=os.environ["OPENAI_API_KEY"],
|
||||
api_base_url=os.environ["API_BASE_URL"], temperature=0.3
|
||||
)
|
||||
embed_config = EmbedConfig(
|
||||
embed_engine="model", embed_model="text2vec-base-chinese",
|
||||
embed_model_path=os.path.join(src_dir, "embedding_models/text2vec-base-chinese")
|
||||
)
|
||||
|
||||
# initialize codebase
|
||||
# delete codebase
|
||||
codebase_name = 'client_local'
|
||||
code_path = "D://chromeDownloads/devopschat-bot/client_v2/client"
|
||||
use_nh = False
|
||||
cbh = CodeBaseHandler(codebase_name, code_path, crawl_type='dir', use_nh=use_nh, local_graph_path=CB_ROOT_PATH,
|
||||
llm_config=llm_config, embed_config=embed_config)
|
||||
cbh.delete_codebase(codebase_name=codebase_name)
|
||||
|
||||
|
||||
# load codebase
|
||||
codebase_name = 'client_local'
|
||||
code_path = "D://chromeDownloads/devopschat-bot/client_v2/client"
|
||||
use_nh = True
|
||||
do_interpret = True
|
||||
cbh = CodeBaseHandler(codebase_name, code_path, crawl_type='dir', use_nh=use_nh, local_graph_path=CB_ROOT_PATH,
|
||||
llm_config=llm_config, embed_config=embed_config)
|
||||
cbh.import_code(do_interpret=do_interpret)
|
||||
|
||||
# 根据前面的load过程进行初始化
|
||||
cbh = CodeBaseHandler(codebase_name, code_path, crawl_type='dir', use_nh=use_nh, local_graph_path=CB_ROOT_PATH,
|
||||
llm_config=llm_config, embed_config=embed_config)
|
||||
phase = BasePhase(
|
||||
phase_name, sandbox_server=SANDBOX_SERVER, jupyter_work_path=JUPYTER_WORK_PATH,
|
||||
embed_config=embed_config, llm_config=llm_config, kb_root_path=KB_ROOT_PATH,
|
||||
)
|
||||
|
||||
for vertex_type in ["class", "method"]:
|
||||
vertexes = cbh.search_vertices(vertex_type=vertex_type)
|
||||
logger.info(f"vertexes={vertexes}")
|
||||
|
||||
# round-1
|
||||
docs = []
|
||||
for vertex in vertexes:
|
||||
vertex = vertex.split("-")[0] # -为method的参数
|
||||
query_content = f"为{vertex_type}节点 {vertex}生成文档"
|
||||
query = Message(
|
||||
role_name="human", role_type="user",
|
||||
role_content=query_content, input_query=query_content, origin_query=query_content,
|
||||
code_engine_name="client_local", score_threshold=1.0, top_k=3, cb_search_type="tag", use_nh=use_nh,
|
||||
local_graph_path=CB_ROOT_PATH,
|
||||
)
|
||||
output_message, output_memory = phase.step(query, reinit_memory=True)
|
||||
# print(output_memory.to_str_messages(return_all=True, content_key="parsed_output_list"))
|
||||
docs.append(output_memory.get_spec_parserd_output())
|
||||
|
||||
os.makedirs(f"{CB_ROOT_PATH}/docs", exist_ok=True)
|
||||
with open(f"{CB_ROOT_PATH}/docs/raw_{vertex_type}.json", "w") as f:
|
||||
json.dump(docs, f)
|
||||
|
||||
|
||||
# 下面把生成的文档信息转换成markdown文本
|
||||
from coagent.utils.code2doc_util import *
|
||||
import json
|
||||
with open(f"{CB_ROOT_PATH}/docs/raw_method.json", "r") as f:
|
||||
method_raw_data = json.load(f)
|
||||
|
||||
with open(f"{CB_ROOT_PATH}/docs/raw_class.json", "r") as f:
|
||||
class_raw_data = json.load(f)
|
||||
|
||||
|
||||
method_data = method_info_decode(method_raw_data)
|
||||
class_data = class_info_decode(class_raw_data)
|
||||
method_mds = encode2md(method_data, method_text_md)
|
||||
class_mds = encode2md(class_data, class_text_md)
|
||||
|
||||
|
||||
docs_dict = {}
|
||||
for k,v in class_mds.items():
|
||||
method_textmds = method_mds.get(k, [])
|
||||
for vv in v:
|
||||
# 理论上只有一个
|
||||
text_md = vv
|
||||
|
||||
for method_textmd in method_textmds:
|
||||
text_md += "\n<br>" + method_textmd
|
||||
|
||||
docs_dict.setdefault(k, []).append(text_md)
|
||||
|
||||
with open(f"{CB_ROOT_PATH}//docs/{k}.md", "w") as f:
|
||||
f.write(text_md)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
####################################
|
||||
######## 下面是完整的复现过程 ########
|
||||
####################################
|
||||
|
||||
# import os, sys, requests
|
||||
# from loguru import logger
|
||||
# src_dir = os.path.join(
|
||||
# os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
# )
|
||||
# sys.path.append(src_dir)
|
||||
|
||||
# from configs.model_config import KB_ROOT_PATH, JUPYTER_WORK_PATH, CB_ROOT_PATH
|
||||
# from configs.server_config import SANDBOX_SERVER
|
||||
# from coagent.tools import toLangchainTools, TOOL_DICT, TOOL_SETS
|
||||
# from coagent.llm_models.llm_config import EmbedConfig, LLMConfig
|
||||
|
||||
# from coagent.connector.phase import BasePhase
|
||||
# from coagent.connector.agents import BaseAgent, SelectorAgent
|
||||
# from coagent.connector.chains import BaseChain
|
||||
# from coagent.connector.schema import (
|
||||
# Message, Memory, load_role_configs, load_phase_configs, load_chain_configs, ActionStatus
|
||||
# )
|
||||
# from coagent.connector.memory_manager import BaseMemoryManager
|
||||
# from coagent.connector.configs import AGETN_CONFIGS, CHAIN_CONFIGS, PHASE_CONFIGS, BASE_PROMPT_CONFIGS
|
||||
# from coagent.connector.prompt_manager.prompt_manager import PromptManager
|
||||
# from coagent.codechat.codebase_handler.codebase_handler import CodeBaseHandler
|
||||
|
||||
# import importlib
|
||||
# from loguru import logger
|
||||
|
||||
|
||||
# from coagent.tools import CodeRetrievalSingle, RelatedVerticesRetrival, Vertex2Code
|
||||
|
||||
|
||||
# # update new agent configs
|
||||
# codeGenDocGroup_PROMPT = """#### Agent Profile
|
||||
|
||||
# Your goal is to response according the Context Data's information with the role that will best facilitate a solution, taking into account all relevant context (Context) provided.
|
||||
|
||||
# When you need to select the appropriate role for handling a user's query, carefully read the provided role names, role descriptions and tool list.
|
||||
|
||||
# ATTENTION: response carefully referenced "Response Output Format" in format.
|
||||
|
||||
# #### Input Format
|
||||
|
||||
# #### Response Output Format
|
||||
|
||||
# **Code Path:** Extract the paths for the class/method/function that need to be addressed from the context
|
||||
|
||||
# **Role:** Select the role from agent names
|
||||
# """
|
||||
|
||||
# classGenDoc_PROMPT = """#### Agent Profile
|
||||
# As an advanced code documentation generator, you are proficient in translating class definitions into comprehensive documentation with a focus on instantiation parameters.
|
||||
# Your specific task is to parse the given code snippet of a class, extract information regarding its instantiation parameters.
|
||||
|
||||
# ATTENTION: response carefully in "Response Output Format".
|
||||
|
||||
# #### Input Format
|
||||
|
||||
# **Code Snippet:** Provide the full class definition, including the constructor and any parameters it may require for instantiation.
|
||||
|
||||
# #### Response Output Format
|
||||
# **Class Base:** Specify the base class or interface from which the current class extends, if any.
|
||||
|
||||
# **Class Description:** Offer a brief description of the class's purpose and functionality.
|
||||
|
||||
# **Init Parameters:** List each parameter from construct. For each parameter, provide:
|
||||
# - `param`: The parameter name
|
||||
# - `param_description`: A concise explanation of the parameter's purpose.
|
||||
# - `param_type`: The data type of the parameter, if explicitly defined.
|
||||
|
||||
# ```json
|
||||
# [
|
||||
# {
|
||||
# "param": "parameter_name",
|
||||
# "param_description": "A brief description of what this parameter is used for.",
|
||||
# "param_type": "The data type of the parameter"
|
||||
# },
|
||||
# ...
|
||||
# ]
|
||||
# ```
|
||||
|
||||
|
||||
# If no parameter for construct, return
|
||||
# ```json
|
||||
# []
|
||||
# ```
|
||||
# """
|
||||
|
||||
# funcGenDoc_PROMPT = """#### Agent Profile
|
||||
# You are a high-level code documentation assistant, skilled at extracting information from function/method code into detailed and well-structured documentation.
|
||||
|
||||
# ATTENTION: response carefully in "Response Output Format".
|
||||
|
||||
|
||||
# #### Input Format
|
||||
# **Code Path:** Provide the code path of the function or method you wish to document.
|
||||
# This name will be used to identify and extract the relevant details from the code snippet provided.
|
||||
|
||||
# **Code Snippet:** A segment of code that contains the function or method to be documented.
|
||||
|
||||
# #### Response Output Format
|
||||
|
||||
# **Class Description:** Offer a brief description of the method(function)'s purpose and functionality.
|
||||
|
||||
# **Parameters:** Extract parameter for the specific function/method Code from Code Snippet. For parameter, provide:
|
||||
# - `param`: The parameter name
|
||||
# - `param_description`: A concise explanation of the parameter's purpose.
|
||||
# - `param_type`: The data type of the parameter, if explicitly defined.
|
||||
# ```json
|
||||
# [
|
||||
# {
|
||||
# "param": "parameter_name",
|
||||
# "param_description": "A brief description of what this parameter is used for.",
|
||||
# "param_type": "The data type of the parameter"
|
||||
# },
|
||||
# ...
|
||||
# ]
|
||||
# ```
|
||||
|
||||
# If no parameter for function/method, return
|
||||
# ```json
|
||||
# []
|
||||
# ```
|
||||
|
||||
# **Return Value Description:** Describe what the function/method returns upon completion.
|
||||
|
||||
# **Return Type:** Indicate the type of data the function/method returns (e.g., string, integer, object, void).
|
||||
# """
|
||||
|
||||
# CODE_GENERATE_GROUP_PROMPT_CONFIGS = [
|
||||
# {"field_name": 'agent_profile', "function_name": 'handle_agent_profile', "is_context": False},
|
||||
# {"field_name": 'agent_infomation', "function_name": 'handle_agent_data', "is_context": False, "omit_if_empty": False},
|
||||
# # {"field_name": 'tool_information',"function_name": 'handle_tool_data', "is_context": False},
|
||||
# {"field_name": 'context_placeholder', "function_name": '', "is_context": True},
|
||||
# # {"field_name": 'reference_documents', "function_name": 'handle_doc_info'},
|
||||
# {"field_name": 'session_records', "function_name": 'handle_session_records'},
|
||||
# {"field_name": 'Specific Objective', "function_name": 'handle_specific_objective'},
|
||||
# {"field_name": 'Code Snippet', "function_name": 'handle_code_snippet'},
|
||||
# {"field_name": 'output_format', "function_name": 'handle_output_format', 'title': 'Response Output Format', "is_context": False},
|
||||
# {"field_name": 'begin!!!', "function_name": 'handle_response', "is_context": False, "omit_if_empty": False}
|
||||
# ]
|
||||
|
||||
# CODE_GENERATE_DOC_PROMPT_CONFIGS = [
|
||||
# {"field_name": 'agent_profile', "function_name": 'handle_agent_profile', "is_context": False},
|
||||
# # {"field_name": 'tool_information',"function_name": 'handle_tool_data', "is_context": False},
|
||||
# {"field_name": 'context_placeholder', "function_name": '', "is_context": True},
|
||||
# # {"field_name": 'reference_documents', "function_name": 'handle_doc_info'},
|
||||
# {"field_name": 'session_records', "function_name": 'handle_session_records'},
|
||||
# {"field_name": 'Specific Objective', "function_name": 'handle_specific_objective'},
|
||||
# {"field_name": 'Code Snippet', "function_name": 'handle_code_snippet'},
|
||||
# {"field_name": 'output_format', "function_name": 'handle_output_format', 'title': 'Response Output Format', "is_context": False},
|
||||
# {"field_name": 'begin!!!', "function_name": 'handle_response', "is_context": False, "omit_if_empty": False}
|
||||
# ]
|
||||
|
||||
|
||||
# class CodeGenDocPM(PromptManager):
|
||||
# def handle_code_snippet(self, **kwargs) -> str:
|
||||
# if 'previous_agent_message' not in kwargs:
|
||||
# return ""
|
||||
# previous_agent_message: Message = kwargs['previous_agent_message']
|
||||
# code_snippet = previous_agent_message.customed_kargs.get("Code Snippet", "")
|
||||
# current_vertex = previous_agent_message.customed_kargs.get("Current_Vertex", "")
|
||||
# instruction = "A segment of code that contains the function or method to be documented.\n"
|
||||
# return instruction + "\n" + f"name: {current_vertex}\n{code_snippet}"
|
||||
|
||||
# def handle_specific_objective(self, **kwargs) -> str:
|
||||
# if 'previous_agent_message' not in kwargs:
|
||||
# return ""
|
||||
# previous_agent_message: Message = kwargs['previous_agent_message']
|
||||
# specific_objective = previous_agent_message.parsed_output.get("Code Path")
|
||||
|
||||
# instruction = "Provide the code path of the function or method you wish to document.\n"
|
||||
# s = instruction + f"\n{specific_objective}"
|
||||
# return s
|
||||
|
||||
|
||||
# from coagent.tools import CodeRetrievalSingle
|
||||
|
||||
# # 定义一个新的agent类
|
||||
# class CodeGenDocer(BaseAgent):
|
||||
|
||||
# def start_action_step(self, message: Message) -> Message:
|
||||
# '''do action before agent predict '''
|
||||
# # 根据问题获取代码片段和节点信息
|
||||
# action_json = CodeRetrievalSingle.run(message.code_engine_name, message.origin_query,
|
||||
# llm_config=self.llm_config, embed_config=self.embed_config, local_graph_path=message.local_graph_path, use_nh=message.use_nh,search_type="tag")
|
||||
# current_vertex = action_json['vertex']
|
||||
# message.customed_kargs["Code Snippet"] = action_json["code"]
|
||||
# message.customed_kargs['Current_Vertex'] = current_vertex
|
||||
# return message
|
||||
|
||||
# # add agent or prompt_manager class
|
||||
# agent_module = importlib.import_module("coagent.connector.agents")
|
||||
# prompt_manager_module = importlib.import_module("coagent.connector.prompt_manager")
|
||||
|
||||
# setattr(agent_module, 'CodeGenDocer', CodeGenDocer)
|
||||
# setattr(prompt_manager_module, 'CodeGenDocPM', CodeGenDocPM)
|
||||
|
||||
|
||||
|
||||
|
||||
# AGETN_CONFIGS.update({
|
||||
# "classGenDoc": {
|
||||
# "role": {
|
||||
# "role_prompt": classGenDoc_PROMPT,
|
||||
# "role_type": "assistant",
|
||||
# "role_name": "classGenDoc",
|
||||
# "role_desc": "",
|
||||
# "agent_type": "CodeGenDocer"
|
||||
# },
|
||||
# "prompt_config": CODE_GENERATE_DOC_PROMPT_CONFIGS,
|
||||
# "prompt_manager_type": "CodeGenDocPM",
|
||||
# "chat_turn": 1,
|
||||
# "focus_agents": [],
|
||||
# "focus_message_keys": [],
|
||||
# },
|
||||
# "funcGenDoc": {
|
||||
# "role": {
|
||||
# "role_prompt": funcGenDoc_PROMPT,
|
||||
# "role_type": "assistant",
|
||||
# "role_name": "funcGenDoc",
|
||||
# "role_desc": "",
|
||||
# "agent_type": "CodeGenDocer"
|
||||
# },
|
||||
# "prompt_config": CODE_GENERATE_DOC_PROMPT_CONFIGS,
|
||||
# "prompt_manager_type": "CodeGenDocPM",
|
||||
# "chat_turn": 1,
|
||||
# "focus_agents": [],
|
||||
# "focus_message_keys": [],
|
||||
# },
|
||||
# "codeGenDocsGrouper": {
|
||||
# "role": {
|
||||
# "role_prompt": codeGenDocGroup_PROMPT,
|
||||
# "role_type": "assistant",
|
||||
# "role_name": "codeGenDocsGrouper",
|
||||
# "role_desc": "",
|
||||
# "agent_type": "SelectorAgent"
|
||||
# },
|
||||
# "prompt_config": CODE_GENERATE_GROUP_PROMPT_CONFIGS,
|
||||
# "group_agents": ["classGenDoc", "funcGenDoc"],
|
||||
# "chat_turn": 1,
|
||||
# },
|
||||
# })
|
||||
# # update new chain configs
|
||||
# CHAIN_CONFIGS.update({
|
||||
# "codeGenDocsGroupChain": {
|
||||
# "chain_name": "codeGenDocsGroupChain",
|
||||
# "chain_type": "BaseChain",
|
||||
# "agents": ["codeGenDocsGrouper"],
|
||||
# "chat_turn": 1,
|
||||
# "do_checker": False,
|
||||
# "chain_prompt": ""
|
||||
# }
|
||||
# })
|
||||
|
||||
# # update phase configs
|
||||
# PHASE_CONFIGS.update({
|
||||
# "codeGenDocsGroup": {
|
||||
# "phase_name": "codeGenDocsGroup",
|
||||
# "phase_type": "BasePhase",
|
||||
# "chains": ["codeGenDocsGroupChain"],
|
||||
# "do_summary": False,
|
||||
# "do_search": False,
|
||||
# "do_doc_retrieval": False,
|
||||
# "do_code_retrieval": False,
|
||||
# "do_tool_retrieval": False,
|
||||
# },
|
||||
# })
|
||||
|
||||
|
||||
# role_configs = load_role_configs(AGETN_CONFIGS)
|
||||
# chain_configs = load_chain_configs(CHAIN_CONFIGS)
|
||||
# phase_configs = load_phase_configs(PHASE_CONFIGS)
|
||||
|
||||
# # log-level,print prompt和llm predict
|
||||
# os.environ["log_verbose"] = "1"
|
||||
|
||||
# phase_name = "codeGenDocsGroup"
|
||||
# llm_config = LLMConfig(
|
||||
# model_name="gpt-4", api_key=os.environ["OPENAI_API_KEY"],
|
||||
# api_base_url=os.environ["API_BASE_URL"], temperature=0.3
|
||||
# )
|
||||
# embed_config = EmbedConfig(
|
||||
# embed_engine="model", embed_model="text2vec-base-chinese",
|
||||
# embed_model_path=os.path.join(src_dir, "embedding_models/text2vec-base-chinese")
|
||||
# )
|
||||
|
||||
|
||||
# # initialize codebase
|
||||
# # delete codebase
|
||||
# codebase_name = 'client_local'
|
||||
# code_path = "D://chromeDownloads/devopschat-bot/client_v2/client"
|
||||
# use_nh = False
|
||||
# cbh = CodeBaseHandler(codebase_name, code_path, crawl_type='dir', use_nh=use_nh, local_graph_path=CB_ROOT_PATH,
|
||||
# llm_config=llm_config, embed_config=embed_config)
|
||||
# cbh.delete_codebase(codebase_name=codebase_name)
|
||||
|
||||
|
||||
# # load codebase
|
||||
# codebase_name = 'client_local'
|
||||
# code_path = "D://chromeDownloads/devopschat-bot/client_v2/client"
|
||||
# use_nh = False
|
||||
# do_interpret = True
|
||||
# cbh = CodeBaseHandler(codebase_name, code_path, crawl_type='dir', use_nh=use_nh, local_graph_path=CB_ROOT_PATH,
|
||||
# llm_config=llm_config, embed_config=embed_config)
|
||||
# cbh.import_code(do_interpret=do_interpret)
|
||||
|
||||
|
||||
# phase = BasePhase(
|
||||
# phase_name, sandbox_server=SANDBOX_SERVER, jupyter_work_path=JUPYTER_WORK_PATH,
|
||||
# embed_config=embed_config, llm_config=llm_config, kb_root_path=KB_ROOT_PATH,
|
||||
# )
|
||||
|
||||
# for vertex_type in ["class", "method"]:
|
||||
# vertexes = cbh.search_vertices(vertex_type=vertex_type)
|
||||
# logger.info(f"vertexes={vertexes}")
|
||||
|
||||
# # round-1
|
||||
# docs = []
|
||||
# for vertex in vertexes:
|
||||
# vertex = vertex.split("-")[0] # -为method的参数
|
||||
# query_content = f"为{vertex_type}节点 {vertex}生成文档"
|
||||
# query = Message(
|
||||
# role_name="human", role_type="user",
|
||||
# role_content=query_content, input_query=query_content, origin_query=query_content,
|
||||
# code_engine_name="client_local", score_threshold=1.0, top_k=3, cb_search_type="tag", use_nh=use_nh,
|
||||
# local_graph_path=CB_ROOT_PATH,
|
||||
# )
|
||||
# output_message, output_memory = phase.step(query, reinit_memory=True)
|
||||
# # print(output_memory.to_str_messages(return_all=True, content_key="parsed_output_list"))
|
||||
# docs.append(output_memory.get_spec_parserd_output())
|
||||
|
||||
# import json
|
||||
# os.makedirs("/home/user/code_base/docs", exist_ok=True)
|
||||
# with open(f"/home/user/code_base/docs/raw_{vertex_type}.json", "w") as f:
|
||||
# json.dump(docs, f)
|
||||
|
||||
|
||||
# # 下面把生成的文档信息转换成markdown文本
|
||||
# from coagent.utils.code2doc_util import *
|
||||
|
||||
# import json
|
||||
# with open(f"/home/user/code_base/docs/raw_method.json", "r") as f:
|
||||
# method_raw_data = json.load(f)
|
||||
|
||||
# with open(f"/home/user/code_base/docs/raw_class.json", "r") as f:
|
||||
# class_raw_data = json.load(f)
|
||||
|
||||
|
||||
# method_data = method_info_decode(method_raw_data)
|
||||
# class_data = class_info_decode(class_raw_data)
|
||||
# method_mds = encode2md(method_data, method_text_md)
|
||||
# class_mds = encode2md(class_data, class_text_md)
|
||||
|
||||
# docs_dict = {}
|
||||
# for k,v in class_mds.items():
|
||||
# method_textmds = method_mds.get(k, [])
|
||||
# for vv in v:
|
||||
# # 理论上只有一个
|
||||
# text_md = vv
|
||||
|
||||
# for method_textmd in method_textmds:
|
||||
# text_md += "\n<br>" + method_textmd
|
||||
|
||||
# docs_dict.setdefault(k, []).append(text_md)
|
||||
|
||||
# with open(f"/home/user/code_base/docs/{k}.md", "w") as f:
|
||||
# f.write(text_md)
|
|
@ -0,0 +1,444 @@
|
|||
import os, sys, json
|
||||
from loguru import logger
|
||||
src_dir = os.path.join(
|
||||
os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
)
|
||||
sys.path.append(src_dir)
|
||||
|
||||
from configs.model_config import KB_ROOT_PATH, JUPYTER_WORK_PATH, CB_ROOT_PATH
|
||||
from configs.server_config import SANDBOX_SERVER
|
||||
from coagent.llm_models.llm_config import EmbedConfig, LLMConfig
|
||||
|
||||
from coagent.connector.phase import BasePhase
|
||||
from coagent.connector.agents import BaseAgent
|
||||
from coagent.connector.schema import Message
|
||||
from coagent.codechat.codebase_handler.codebase_handler import CodeBaseHandler
|
||||
|
||||
import importlib
|
||||
from loguru import logger
|
||||
|
||||
|
||||
from coagent.tools import CodeRetrievalSingle, RelatedVerticesRetrival, Vertex2Code
|
||||
|
||||
# 定义一个新的agent类
|
||||
class CodeRetrieval(BaseAgent):
|
||||
|
||||
def start_action_step(self, message: Message) -> Message:
|
||||
'''do action before agent predict '''
|
||||
# 根据问题获取代码片段和节点信息
|
||||
action_json = CodeRetrievalSingle.run(message.code_engine_name, message.origin_query, llm_config=self.llm_config, embed_config=self.embed_config, search_type="tag",
|
||||
local_graph_path=message.local_graph_path, use_nh=message.use_nh)
|
||||
current_vertex = action_json['vertex']
|
||||
message.customed_kargs["Code Snippet"] = action_json["code"]
|
||||
message.customed_kargs['Current_Vertex'] = current_vertex
|
||||
|
||||
# 获取邻近节点
|
||||
action_json = RelatedVerticesRetrival.run(message.code_engine_name, message.customed_kargs['Current_Vertex'])
|
||||
# 获取邻近节点所有代码
|
||||
relative_vertex = []
|
||||
retrieval_Codes = []
|
||||
for vertex in action_json["vertices"]:
|
||||
# 由于代码是文件级别,所以相同文件代码不再获取
|
||||
# logger.debug(f"{current_vertex}, {vertex}")
|
||||
current_vertex_name = current_vertex.replace("#", "").replace(".java", "" ) if current_vertex.endswith(".java") else current_vertex
|
||||
if current_vertex_name.split("#")[0] == vertex.split("#")[0]: continue
|
||||
|
||||
action_json = Vertex2Code.run(message.code_engine_name, vertex)
|
||||
if action_json["code"]:
|
||||
retrieval_Codes.append(action_json["code"])
|
||||
relative_vertex.append(vertex)
|
||||
#
|
||||
message.customed_kargs["Retrieval_Codes"] = retrieval_Codes
|
||||
message.customed_kargs["Relative_vertex"] = relative_vertex
|
||||
return message
|
||||
|
||||
|
||||
# add agent or prompt_manager class
|
||||
agent_module = importlib.import_module("coagent.connector.agents")
|
||||
setattr(agent_module, 'CodeRetrieval', CodeRetrieval)
|
||||
|
||||
|
||||
|
||||
llm_config = LLMConfig(
|
||||
model_name="gpt-4", api_key=os.environ["OPENAI_API_KEY"],
|
||||
api_base_url=os.environ["API_BASE_URL"], temperature=0.3
|
||||
)
|
||||
embed_config = EmbedConfig(
|
||||
embed_engine="model", embed_model="text2vec-base-chinese",
|
||||
embed_model_path=os.path.join(src_dir, "embedding_models/text2vec-base-chinese")
|
||||
)
|
||||
|
||||
## initialize codebase
|
||||
# delete codebase
|
||||
codebase_name = 'client_local'
|
||||
code_path = "D://chromeDownloads/devopschat-bot/client_v2/client"
|
||||
use_nh = False
|
||||
cbh = CodeBaseHandler(codebase_name, code_path, crawl_type='dir', use_nh=use_nh, local_graph_path=CB_ROOT_PATH,
|
||||
llm_config=llm_config, embed_config=embed_config)
|
||||
cbh.delete_codebase(codebase_name=codebase_name)
|
||||
|
||||
|
||||
# load codebase
|
||||
codebase_name = 'client_local'
|
||||
code_path = "D://chromeDownloads/devopschat-bot/client_v2/client"
|
||||
use_nh = True
|
||||
do_interpret = True
|
||||
cbh = CodeBaseHandler(codebase_name, code_path, crawl_type='dir', use_nh=use_nh, local_graph_path=CB_ROOT_PATH,
|
||||
llm_config=llm_config, embed_config=embed_config)
|
||||
cbh.import_code(do_interpret=do_interpret)
|
||||
|
||||
|
||||
cbh = CodeBaseHandler(codebase_name, code_path, crawl_type='dir', use_nh=use_nh, local_graph_path=CB_ROOT_PATH,
|
||||
llm_config=llm_config, embed_config=embed_config)
|
||||
vertexes = cbh.search_vertices(vertex_type="class")
|
||||
|
||||
|
||||
# log-level,print prompt和llm predict
|
||||
os.environ["log_verbose"] = "0"
|
||||
|
||||
phase_name = "code2Tests"
|
||||
phase = BasePhase(
|
||||
phase_name, sandbox_server=SANDBOX_SERVER, jupyter_work_path=JUPYTER_WORK_PATH,
|
||||
embed_config=embed_config, llm_config=llm_config, kb_root_path=KB_ROOT_PATH,
|
||||
)
|
||||
|
||||
# round-1
|
||||
logger.debug(vertexes)
|
||||
test_cases = []
|
||||
for vertex in vertexes:
|
||||
query_content = f"为{vertex}生成可执行的测例 "
|
||||
query = Message(
|
||||
role_name="human", role_type="user",
|
||||
role_content=query_content, input_query=query_content, origin_query=query_content,
|
||||
code_engine_name=codebase_name, score_threshold=1.0, top_k=3, cb_search_type="tag",
|
||||
use_nh=use_nh, local_graph_path=CB_ROOT_PATH,
|
||||
)
|
||||
output_message, output_memory = phase.step(query, reinit_memory=True)
|
||||
# print(output_memory.to_str_messages(return_all=True, content_key="parsed_output_list"))
|
||||
print(output_memory.get_spec_parserd_output())
|
||||
values = output_memory.get_spec_parserd_output()
|
||||
test_code = {k:v for i in values for k,v in i.items() if k in ["SaveFileName", "Test Code"]}
|
||||
test_cases.append(test_code)
|
||||
|
||||
os.makedirs(f"{CB_ROOT_PATH}/tests", exist_ok=True)
|
||||
|
||||
with open(f"{CB_ROOT_PATH}/tests/{test_code['SaveFileName']}", "w") as f:
|
||||
f.write(test_code["Test Code"])
|
||||
break
|
||||
|
||||
|
||||
|
||||
|
||||
# import os, sys, json
|
||||
# from loguru import logger
|
||||
# src_dir = os.path.join(
|
||||
# os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
# )
|
||||
# sys.path.append(src_dir)
|
||||
|
||||
# from configs.model_config import KB_ROOT_PATH, JUPYTER_WORK_PATH, CB_ROOT_PATH
|
||||
# from configs.server_config import SANDBOX_SERVER
|
||||
# from coagent.tools import toLangchainTools, TOOL_DICT, TOOL_SETS
|
||||
# from coagent.llm_models.llm_config import EmbedConfig, LLMConfig
|
||||
|
||||
# from coagent.connector.phase import BasePhase
|
||||
# from coagent.connector.agents import BaseAgent
|
||||
# from coagent.connector.chains import BaseChain
|
||||
# from coagent.connector.schema import (
|
||||
# Message, Memory, load_role_configs, load_phase_configs, load_chain_configs, ActionStatus
|
||||
# )
|
||||
# from coagent.connector.memory_manager import BaseMemoryManager
|
||||
# from coagent.connector.configs import AGETN_CONFIGS, CHAIN_CONFIGS, PHASE_CONFIGS, BASE_PROMPT_CONFIGS
|
||||
# from coagent.connector.prompt_manager.prompt_manager import PromptManager
|
||||
# from coagent.codechat.codebase_handler.codebase_handler import CodeBaseHandler
|
||||
|
||||
# import importlib
|
||||
# from loguru import logger
|
||||
|
||||
# # 下面给定了一份代码片段,以及来自于它的依赖类、依赖方法相关的代码片段,你需要判断是否为这段指定代码片段生成测例。
|
||||
# # 理论上所有代码都需要写测例,但是受限于人的精力不可能覆盖所有代码
|
||||
# # 考虑以下因素进行裁剪:
|
||||
# # 功能性: 如果它实现的是一个具体的功能或逻辑,则通常需要编写测试用例以验证其正确性。
|
||||
# # 复杂性: 如果代码较为,尤其是包含多个条件判断、循环、异常处理等的代码,更可能隐藏bug,因此应该编写测试用例。如果代码涉及复杂的算法或者逻辑,那么编写测试用例可以帮助确保逻辑的正确性,并在未来的重构中防止引入错误。
|
||||
# # 关键性: 如果它是关键路径的一部分或影响到核心功能,那么它就需要被测试。对于核心业务逻辑或者系统的关键组件,应当编写全面的测试用例来确保功能的正确性和稳定性。
|
||||
# # 依赖性: 如果代码有外部依赖,可能需要编写集成测试或模拟这些依赖进行单元测试。
|
||||
# # 用户输入: 如果代码处理用户输入,尤其是来自外部的、非受控的输入,那么创建测试用例来检查输入验证和处理是很重要的。
|
||||
# # 频繁更改:对于经常需要更新或修改的代码,有相应的测试用例可以确保更改不会破坏现有功能。
|
||||
|
||||
|
||||
# # 代码公开或重用:如果代码将被公开或用于其他项目,编写测试用例可以提高代码的可信度和易用性。
|
||||
|
||||
|
||||
# # update new agent configs
|
||||
# judgeGenerateTests_PROMPT = """#### Agent Profile
|
||||
# When determining the necessity of writing test cases for a given code snippet,
|
||||
# it's essential to evaluate its interactions with dependent classes and methods (retrieved code snippets),
|
||||
# in addition to considering these critical factors:
|
||||
# 1. Functionality: If it implements a concrete function or logic, test cases are typically necessary to verify its correctness.
|
||||
# 2. Complexity: If the code is complex, especially if it contains multiple conditional statements, loops, exceptions handling, etc.,
|
||||
# it's more likely to harbor bugs, and thus test cases should be written.
|
||||
# If the code involves complex algorithms or logic, then writing test cases can help ensure the accuracy of the logic and prevent errors during future refactoring.
|
||||
# 3. Criticality: If it's part of the critical path or affects core functionalities, then it needs to be tested.
|
||||
# Comprehensive test cases should be written for core business logic or key components of the system to ensure the correctness and stability of the functionality.
|
||||
# 4. Dependencies: If the code has external dependencies, integration testing may be necessary, or mocking these dependencies during unit testing might be required.
|
||||
# 5. User Input: If the code handles user input, especially from unregulated external sources, creating test cases to check input validation and handling is important.
|
||||
# 6. Frequent Changes: For code that requires regular updates or modifications, having the appropriate test cases ensures that changes do not break existing functionalities.
|
||||
|
||||
# #### Input Format
|
||||
|
||||
# **Code Snippet:** the initial Code or objective that the user wanted to achieve
|
||||
|
||||
# **Retrieval Code Snippets:** These are the associated code segments that the main Code Snippet depends on.
|
||||
# Examine these snippets to understand how they interact with the main snippet and to determine how they might affect the overall functionality.
|
||||
|
||||
# #### Response Output Format
|
||||
# **Action Status:** Set to 'finished' or 'continued'.
|
||||
# If set to 'finished', the code snippet does not warrant the generation of a test case.
|
||||
# If set to 'continued', the code snippet necessitates the creation of a test case.
|
||||
|
||||
# **REASON:** Justify the selection of 'finished' or 'continued', contemplating the decision through a step-by-step rationale.
|
||||
# """
|
||||
|
||||
# generateTests_PROMPT = """#### Agent Profile
|
||||
# As an agent specializing in software quality assurance,
|
||||
# your mission is to craft comprehensive test cases that bolster the functionality, reliability, and robustness of a specified Code Snippet.
|
||||
# This task is to be carried out with a keen understanding of the snippet's interactions with its dependent classes and methods—collectively referred to as Retrieval Code Snippets.
|
||||
# Analyze the details given below to grasp the code's intended purpose, its inherent complexity, and the context within which it operates.
|
||||
# Your constructed test cases must thoroughly examine the various factors influencing the code's quality and performance.
|
||||
|
||||
# ATTENTION: response carefully referenced "Response Output Format" in format.
|
||||
|
||||
# Each test case should include:
|
||||
# 1. clear description of the test purpose.
|
||||
# 2. The input values or conditions for the test.
|
||||
# 3. The expected outcome or assertion for the test.
|
||||
# 4. Appropriate tags (e.g., 'functional', 'integration', 'regression') that classify the type of test case.
|
||||
# 5. these test code should have package and import
|
||||
|
||||
# #### Input Format
|
||||
|
||||
# **Code Snippet:** the initial Code or objective that the user wanted to achieve
|
||||
|
||||
# **Retrieval Code Snippets:** These are the interrelated pieces of code sourced from the codebase, which support or influence the primary Code Snippet.
|
||||
|
||||
# #### Response Output Format
|
||||
# **SaveFileName:** construct a local file name based on Question and Context, such as
|
||||
|
||||
# ```java
|
||||
# package/class.java
|
||||
# ```
|
||||
|
||||
|
||||
# **Test Code:** generate the test code for the current Code Snippet.
|
||||
# ```java
|
||||
# ...
|
||||
# ```
|
||||
|
||||
# """
|
||||
|
||||
# CODE_GENERATE_TESTS_PROMPT_CONFIGS = [
|
||||
# {"field_name": 'agent_profile', "function_name": 'handle_agent_profile', "is_context": False},
|
||||
# # {"field_name": 'tool_information',"function_name": 'handle_tool_data', "is_context": False},
|
||||
# {"field_name": 'context_placeholder', "function_name": '', "is_context": True},
|
||||
# # {"field_name": 'reference_documents', "function_name": 'handle_doc_info'},
|
||||
# {"field_name": 'session_records', "function_name": 'handle_session_records'},
|
||||
# {"field_name": 'code_snippet', "function_name": 'handle_code_snippet'},
|
||||
# {"field_name": 'retrieval_codes', "function_name": 'handle_retrieval_codes', "description": ""},
|
||||
# {"field_name": 'output_format', "function_name": 'handle_output_format', 'title': 'Response Output Format', "is_context": False},
|
||||
# {"field_name": 'begin!!!', "function_name": 'handle_response', "is_context": False, "omit_if_empty": False}
|
||||
# ]
|
||||
|
||||
|
||||
# class CodeRetrievalPM(PromptManager):
|
||||
# def handle_code_snippet(self, **kwargs) -> str:
|
||||
# if 'previous_agent_message' not in kwargs:
|
||||
# return ""
|
||||
# previous_agent_message: Message = kwargs['previous_agent_message']
|
||||
# code_snippet = previous_agent_message.customed_kargs.get("Code Snippet", "")
|
||||
# current_vertex = previous_agent_message.customed_kargs.get("Current_Vertex", "")
|
||||
# instruction = "the initial Code or objective that the user wanted to achieve"
|
||||
# return instruction + "\n" + f"name: {current_vertex}\n{code_snippet}"
|
||||
|
||||
# def handle_retrieval_codes(self, **kwargs) -> str:
|
||||
# if 'previous_agent_message' not in kwargs:
|
||||
# return ""
|
||||
# previous_agent_message: Message = kwargs['previous_agent_message']
|
||||
# Retrieval_Codes = previous_agent_message.customed_kargs["Retrieval_Codes"]
|
||||
# Relative_vertex = previous_agent_message.customed_kargs["Relative_vertex"]
|
||||
# instruction = "the initial Code or objective that the user wanted to achieve"
|
||||
# s = instruction + "\n" + "\n".join([f"name: {vertext}\n{code}" for vertext, code in zip(Relative_vertex, Retrieval_Codes)])
|
||||
# return s
|
||||
|
||||
|
||||
|
||||
# AGETN_CONFIGS.update({
|
||||
# "CodeJudger": {
|
||||
# "role": {
|
||||
# "role_prompt": judgeGenerateTests_PROMPT,
|
||||
# "role_type": "assistant",
|
||||
# "role_name": "CodeJudger",
|
||||
# "role_desc": "",
|
||||
# "agent_type": "CodeRetrieval"
|
||||
# # "agent_type": "BaseAgent"
|
||||
# },
|
||||
# "prompt_config": CODE_GENERATE_TESTS_PROMPT_CONFIGS,
|
||||
# "prompt_manager_type": "CodeRetrievalPM",
|
||||
# "chat_turn": 1,
|
||||
# "focus_agents": [],
|
||||
# "focus_message_keys": [],
|
||||
# },
|
||||
# "generateTests": {
|
||||
# "role": {
|
||||
# "role_prompt": generateTests_PROMPT,
|
||||
# "role_type": "assistant",
|
||||
# "role_name": "generateTests",
|
||||
# "role_desc": "",
|
||||
# "agent_type": "CodeRetrieval"
|
||||
# # "agent_type": "BaseAgent"
|
||||
# },
|
||||
# "prompt_config": CODE_GENERATE_TESTS_PROMPT_CONFIGS,
|
||||
# "prompt_manager_type": "CodeRetrievalPM",
|
||||
# "chat_turn": 1,
|
||||
# "focus_agents": [],
|
||||
# "focus_message_keys": [],
|
||||
# },
|
||||
# })
|
||||
# # update new chain configs
|
||||
# CHAIN_CONFIGS.update({
|
||||
# "codeRetrievalChain": {
|
||||
# "chain_name": "codeRetrievalChain",
|
||||
# "chain_type": "BaseChain",
|
||||
# "agents": ["CodeJudger", "generateTests"],
|
||||
# "chat_turn": 1,
|
||||
# "do_checker": False,
|
||||
# "chain_prompt": ""
|
||||
# }
|
||||
# })
|
||||
|
||||
# # update phase configs
|
||||
# PHASE_CONFIGS.update({
|
||||
# "codeGenerateTests": {
|
||||
# "phase_name": "codeGenerateTests",
|
||||
# "phase_type": "BasePhase",
|
||||
# "chains": ["codeRetrievalChain"],
|
||||
# "do_summary": False,
|
||||
# "do_search": False,
|
||||
# "do_doc_retrieval": False,
|
||||
# "do_code_retrieval": False,
|
||||
# "do_tool_retrieval": False,
|
||||
# },
|
||||
# })
|
||||
|
||||
|
||||
# role_configs = load_role_configs(AGETN_CONFIGS)
|
||||
# chain_configs = load_chain_configs(CHAIN_CONFIGS)
|
||||
# phase_configs = load_phase_configs(PHASE_CONFIGS)
|
||||
|
||||
|
||||
# from coagent.tools import CodeRetrievalSingle, RelatedVerticesRetrival, Vertex2Code
|
||||
|
||||
# # 定义一个新的agent类
|
||||
# class CodeRetrieval(BaseAgent):
|
||||
|
||||
# def start_action_step(self, message: Message) -> Message:
|
||||
# '''do action before agent predict '''
|
||||
# # 根据问题获取代码片段和节点信息
|
||||
# action_json = CodeRetrievalSingle.run(message.code_engine_name, message.origin_query, llm_config=self.llm_config, embed_config=self.embed_config, search_type="tag",
|
||||
# local_graph_path=message.local_graph_path, use_nh=message.use_nh)
|
||||
# current_vertex = action_json['vertex']
|
||||
# message.customed_kargs["Code Snippet"] = action_json["code"]
|
||||
# message.customed_kargs['Current_Vertex'] = current_vertex
|
||||
|
||||
# # 获取邻近节点
|
||||
# action_json = RelatedVerticesRetrival.run(message.code_engine_name, message.customed_kargs['Current_Vertex'])
|
||||
# # 获取邻近节点所有代码
|
||||
# relative_vertex = []
|
||||
# retrieval_Codes = []
|
||||
# for vertex in action_json["vertices"]:
|
||||
# # 由于代码是文件级别,所以相同文件代码不再获取
|
||||
# # logger.debug(f"{current_vertex}, {vertex}")
|
||||
# current_vertex_name = current_vertex.replace("#", "").replace(".java", "" ) if current_vertex.endswith(".java") else current_vertex
|
||||
# if current_vertex_name.split("#")[0] == vertex.split("#")[0]: continue
|
||||
|
||||
# action_json = Vertex2Code.run(message.code_engine_name, vertex)
|
||||
# if action_json["code"]:
|
||||
# retrieval_Codes.append(action_json["code"])
|
||||
# relative_vertex.append(vertex)
|
||||
# #
|
||||
# message.customed_kargs["Retrieval_Codes"] = retrieval_Codes
|
||||
# message.customed_kargs["Relative_vertex"] = relative_vertex
|
||||
# return message
|
||||
|
||||
|
||||
# # add agent or prompt_manager class
|
||||
# agent_module = importlib.import_module("coagent.connector.agents")
|
||||
# prompt_manager_module = importlib.import_module("coagent.connector.prompt_manager")
|
||||
|
||||
# setattr(agent_module, 'CodeRetrieval', CodeRetrieval)
|
||||
# setattr(prompt_manager_module, 'CodeRetrievalPM', CodeRetrievalPM)
|
||||
|
||||
|
||||
# # log-level,print prompt和llm predict
|
||||
# os.environ["log_verbose"] = "0"
|
||||
|
||||
# phase_name = "codeGenerateTests"
|
||||
# llm_config = LLMConfig(
|
||||
# model_name="gpt-4", api_key=os.environ["OPENAI_API_KEY"],
|
||||
# api_base_url=os.environ["API_BASE_URL"], temperature=0.3
|
||||
# )
|
||||
# embed_config = EmbedConfig(
|
||||
# embed_engine="model", embed_model="text2vec-base-chinese",
|
||||
# embed_model_path=os.path.join(src_dir, "embedding_models/text2vec-base-chinese")
|
||||
# )
|
||||
|
||||
# ## initialize codebase
|
||||
# # delete codebase
|
||||
# codebase_name = 'client_local'
|
||||
# code_path = "D://chromeDownloads/devopschat-bot/client_v2/client"
|
||||
# use_nh = False
|
||||
# cbh = CodeBaseHandler(codebase_name, code_path, crawl_type='dir', use_nh=use_nh, local_graph_path=CB_ROOT_PATH,
|
||||
# llm_config=llm_config, embed_config=embed_config)
|
||||
# cbh.delete_codebase(codebase_name=codebase_name)
|
||||
|
||||
|
||||
# # load codebase
|
||||
# codebase_name = 'client_local'
|
||||
# code_path = "D://chromeDownloads/devopschat-bot/client_v2/client"
|
||||
# use_nh = True
|
||||
# do_interpret = True
|
||||
# cbh = CodeBaseHandler(codebase_name, code_path, crawl_type='dir', use_nh=use_nh, local_graph_path=CB_ROOT_PATH,
|
||||
# llm_config=llm_config, embed_config=embed_config)
|
||||
# cbh.import_code(do_interpret=do_interpret)
|
||||
|
||||
|
||||
# cbh = CodeBaseHandler(codebase_name, code_path, crawl_type='dir', use_nh=use_nh, local_graph_path=CB_ROOT_PATH,
|
||||
# llm_config=llm_config, embed_config=embed_config)
|
||||
# vertexes = cbh.search_vertices(vertex_type="class")
|
||||
|
||||
# phase = BasePhase(
|
||||
# phase_name, sandbox_server=SANDBOX_SERVER, jupyter_work_path=JUPYTER_WORK_PATH,
|
||||
# embed_config=embed_config, llm_config=llm_config, kb_root_path=KB_ROOT_PATH,
|
||||
# )
|
||||
|
||||
# # round-1
|
||||
# logger.debug(vertexes)
|
||||
# test_cases = []
|
||||
# for vertex in vertexes:
|
||||
# query_content = f"为{vertex}生成可执行的测例 "
|
||||
# query = Message(
|
||||
# role_name="human", role_type="user",
|
||||
# role_content=query_content, input_query=query_content, origin_query=query_content,
|
||||
# code_engine_name=codebase_name, score_threshold=1.0, top_k=3, cb_search_type="tag",
|
||||
# use_nh=use_nh, local_graph_path=CB_ROOT_PATH,
|
||||
# )
|
||||
# output_message, output_memory = phase.step(query, reinit_memory=True)
|
||||
# # print(output_memory.to_str_messages(return_all=True, content_key="parsed_output_list"))
|
||||
# print(output_memory.get_spec_parserd_output())
|
||||
# values = output_memory.get_spec_parserd_output()
|
||||
# test_code = {k:v for i in values for k,v in i.items() if k in ["SaveFileName", "Test Code"]}
|
||||
# test_cases.append(test_code)
|
||||
|
||||
# os.makedirs(f"{CB_ROOT_PATH}/tests", exist_ok=True)
|
||||
|
||||
# with open(f"{CB_ROOT_PATH}/tests/{test_code['SaveFileName']}", "w") as f:
|
||||
# f.write(test_code["Test Code"])
|
|
@ -17,7 +17,7 @@ os.environ["log_verbose"] = "2"
|
|||
|
||||
phase_name = "codeReactPhase"
|
||||
llm_config = LLMConfig(
|
||||
model_name="gpt-3.5-turbo", model_device="cpu",api_key=os.environ["OPENAI_API_KEY"],
|
||||
model_name="gpt-3.5-turbo",api_key=os.environ["OPENAI_API_KEY"],
|
||||
api_base_url=os.environ["API_BASE_URL"], temperature=0.3
|
||||
)
|
||||
embed_config = EmbedConfig(
|
||||
|
|
|
@ -18,8 +18,7 @@ from coagent.connector.schema import (
|
|||
)
|
||||
from coagent.connector.memory_manager import BaseMemoryManager
|
||||
from coagent.connector.configs import AGETN_CONFIGS, CHAIN_CONFIGS, PHASE_CONFIGS, BASE_PROMPT_CONFIGS
|
||||
from coagent.connector.utils import parse_section
|
||||
from coagent.connector.prompt_manager import PromptManager
|
||||
from coagent.connector.prompt_manager.prompt_manager import PromptManager
|
||||
import importlib
|
||||
|
||||
from loguru import logger
|
||||
|
@ -230,7 +229,7 @@ os.environ["log_verbose"] = "2"
|
|||
|
||||
phase_name = "codeRetrievalPhase"
|
||||
llm_config = LLMConfig(
|
||||
model_name="gpt-3.5-turbo", model_device="cpu",api_key=os.environ["OPENAI_API_KEY"],
|
||||
model_name="gpt-3.5-turbo", api_key=os.environ["OPENAI_API_KEY"],
|
||||
api_base_url=os.environ["API_BASE_URL"], temperature=0.3
|
||||
)
|
||||
embed_config = EmbedConfig(
|
||||
|
@ -246,7 +245,7 @@ query_content = "UtilsTest 这个类中测试了哪些函数,测试的函数代
|
|||
query = Message(
|
||||
role_name="human", role_type="user",
|
||||
role_content=query_content, input_query=query_content, origin_query=query_content,
|
||||
code_engine_name="client", score_threshold=1.0, top_k=3, cb_search_type="tag"
|
||||
code_engine_name="client_1", score_threshold=1.0, top_k=3, cb_search_type="tag"
|
||||
)
|
||||
|
||||
|
||||
|
|
|
@ -24,7 +24,7 @@ os.environ["log_verbose"] = "2"
|
|||
|
||||
phase_name = "codeToolReactPhase"
|
||||
llm_config = LLMConfig(
|
||||
model_name="gpt-3.5-turbo-0613", model_device="cpu",api_key=os.environ["OPENAI_API_KEY"],
|
||||
model_name="gpt-3.5-turbo-0613", api_key=os.environ["OPENAI_API_KEY"],
|
||||
api_base_url=os.environ["API_BASE_URL"], temperature=0.7
|
||||
)
|
||||
embed_config = EmbedConfig(
|
||||
|
|
|
@ -17,7 +17,7 @@ from coagent.connector.schema import Message, Memory
|
|||
|
||||
tools = toLangchainTools([TOOL_DICT[i] for i in TOOL_SETS if i in TOOL_DICT])
|
||||
llm_config = LLMConfig(
|
||||
model_name="gpt-3.5-turbo", model_device="cpu",api_key=os.environ["OPENAI_API_KEY"],
|
||||
model_name="gpt-3.5-turbo",api_key=os.environ["OPENAI_API_KEY"],
|
||||
api_base_url=os.environ["API_BASE_URL"], temperature=0.3
|
||||
)
|
||||
embed_config = EmbedConfig(
|
||||
|
|
|
@ -18,7 +18,7 @@ os.environ["log_verbose"] = "0"
|
|||
|
||||
phase_name = "metagpt_code_devlop"
|
||||
llm_config = LLMConfig(
|
||||
model_name="gpt-3.5-turbo", model_device="cpu",api_key=os.environ["OPENAI_API_KEY"],
|
||||
model_name="gpt-3.5-turbo", api_key=os.environ["OPENAI_API_KEY"],
|
||||
api_base_url=os.environ["API_BASE_URL"], temperature=0.3
|
||||
)
|
||||
embed_config = EmbedConfig(
|
||||
|
|
|
@ -20,7 +20,7 @@ os.environ["log_verbose"] = "2"
|
|||
|
||||
phase_name = "searchChatPhase"
|
||||
llm_config = LLMConfig(
|
||||
model_name="gpt-3.5-turbo", model_device="cpu",api_key=os.environ["OPENAI_API_KEY"],
|
||||
model_name="gpt-3.5-turbo", api_key=os.environ["OPENAI_API_KEY"],
|
||||
api_base_url=os.environ["API_BASE_URL"], temperature=0.3
|
||||
)
|
||||
embed_config = EmbedConfig(
|
||||
|
|
|
@ -18,7 +18,7 @@ os.environ["log_verbose"] = "2"
|
|||
|
||||
phase_name = "toolReactPhase"
|
||||
llm_config = LLMConfig(
|
||||
model_name="gpt-3.5-turbo", model_device="cpu",api_key=os.environ["OPENAI_API_KEY"],
|
||||
model_name="gpt-3.5-turbo",api_key=os.environ["OPENAI_API_KEY"],
|
||||
api_base_url=os.environ["API_BASE_URL"], temperature=0.3
|
||||
)
|
||||
embed_config = EmbedConfig(
|
||||
|
|
|
@ -151,9 +151,9 @@ def create_app():
|
|||
)(delete_cb)
|
||||
|
||||
app.post("/code_base/code_base_chat",
|
||||
tags=["Code Base Management"],
|
||||
summary="删除 code_base"
|
||||
)(delete_cb)
|
||||
tags=["Code Base Management"],
|
||||
summary="code_base 对话"
|
||||
)(search_code)
|
||||
|
||||
app.get("/code_base/list_code_bases",
|
||||
tags=["Code Base Management"],
|
||||
|
|
|
@ -117,7 +117,7 @@ PHASE_CONFIGS.update({
|
|||
|
||||
|
||||
llm_config = LLMConfig(
|
||||
model_name="gpt-3.5-turbo", model_device="cpu",api_key=os.environ["OPENAI_API_KEY"],
|
||||
model_name="gpt-3.5-turbo", api_key=os.environ["OPENAI_API_KEY"],
|
||||
api_base_url=os.environ["API_BASE_URL"], temperature=0.3
|
||||
)
|
||||
|
||||
|
|
|
@ -98,12 +98,6 @@ def start_docker(client, script_shs, ports, image_name, container_name, mounts=N
|
|||
network_name ='my_network'
|
||||
|
||||
def start_sandbox_service(network_name ='my_network'):
|
||||
# networks = client.networks.list()
|
||||
# if any([network_name==i.attrs["Name"] for i in networks]):
|
||||
# network = client.networks.get(network_name)
|
||||
# else:
|
||||
# network = client.networks.create('my_network', driver='bridge')
|
||||
|
||||
mount = Mount(
|
||||
type='bind',
|
||||
source=os.path.join(src_dir, "jupyter_work"),
|
||||
|
@ -114,6 +108,12 @@ def start_sandbox_service(network_name ='my_network'):
|
|||
# 沙盒的启动与服务的启动是独立的
|
||||
if SANDBOX_SERVER["do_remote"]:
|
||||
client = docker.from_env()
|
||||
networks = client.networks.list()
|
||||
if any([network_name==i.attrs["Name"] for i in networks]):
|
||||
network = client.networks.get(network_name)
|
||||
else:
|
||||
network = client.networks.create('my_network', driver='bridge')
|
||||
|
||||
# 启动容器
|
||||
logger.info("start container sandbox service")
|
||||
JUPYTER_WORK_PATH = "/home/user/chatbot/jupyter_work"
|
||||
|
@ -150,7 +150,7 @@ def start_api_service(sandbox_host=DEFAULT_BIND_HOST):
|
|||
client = docker.from_env()
|
||||
logger.info("start container service")
|
||||
check_process("api.py", do_stop=True)
|
||||
check_process("sdfile_api.py", do_stop=True)
|
||||
check_process("llm_api.py", do_stop=True)
|
||||
check_process("sdfile_api.py", do_stop=True)
|
||||
check_process("webui.py", do_stop=True)
|
||||
mount = Mount(
|
||||
|
@ -159,27 +159,28 @@ def start_api_service(sandbox_host=DEFAULT_BIND_HOST):
|
|||
target='/home/user/chatbot/',
|
||||
read_only=False # 如果需要只读访问,将此选项设置为True
|
||||
)
|
||||
mount_database = Mount(
|
||||
type='bind',
|
||||
source=os.path.join(src_dir, "knowledge_base"),
|
||||
target='/home/user/knowledge_base/',
|
||||
read_only=False # 如果需要只读访问,将此选项设置为True
|
||||
)
|
||||
mount_code_database = Mount(
|
||||
type='bind',
|
||||
source=os.path.join(src_dir, "code_base"),
|
||||
target='/home/user/code_base/',
|
||||
read_only=False # 如果需要只读访问,将此选项设置为True
|
||||
)
|
||||
# mount_database = Mount(
|
||||
# type='bind',
|
||||
# source=os.path.join(src_dir, "knowledge_base"),
|
||||
# target='/home/user/knowledge_base/',
|
||||
# read_only=False # 如果需要只读访问,将此选项设置为True
|
||||
# )
|
||||
# mount_code_database = Mount(
|
||||
# type='bind',
|
||||
# source=os.path.join(src_dir, "code_base"),
|
||||
# target='/home/user/code_base/',
|
||||
# read_only=False # 如果需要只读访问,将此选项设置为True
|
||||
# )
|
||||
ports={
|
||||
f"{API_SERVER['docker_port']}/tcp": f"{API_SERVER['port']}/tcp",
|
||||
f"{WEBUI_SERVER['docker_port']}/tcp": f"{WEBUI_SERVER['port']}/tcp",
|
||||
f"{SDFILE_API_SERVER['docker_port']}/tcp": f"{SDFILE_API_SERVER['port']}/tcp",
|
||||
f"{NEBULA_GRAPH_SERVER['docker_port']}/tcp": f"{NEBULA_GRAPH_SERVER['port']}/tcp"
|
||||
}
|
||||
mounts = [mount, mount_database, mount_code_database]
|
||||
# mounts = [mount, mount_database, mount_code_database]
|
||||
mounts = [mount]
|
||||
script_shs = [
|
||||
"mkdir -p /home/user/logs",
|
||||
"mkdir -p /home/user/chatbot/logs",
|
||||
'''
|
||||
if [ -d "/home/user/chatbot/data/nebula_data/data/meta" ]; then
|
||||
cp -r /home/user/chatbot/data/nebula_data/data /usr/local/nebula/
|
||||
|
@ -197,12 +198,12 @@ def start_api_service(sandbox_host=DEFAULT_BIND_HOST):
|
|||
"pip install jieba",
|
||||
"pip install duckduckgo-search",
|
||||
|
||||
"nohup python chatbot/examples/sdfile_api.py > /home/user/logs/sdfile_api.log 2>&1 &",
|
||||
"nohup python chatbot/examples/sdfile_api.py > /home/user/chatbot/logs/sdfile_api.log 2>&1 &",
|
||||
f"export DUCKDUCKGO_PROXY=socks5://host.docker.internal:13659 && export SANDBOX_HOST={sandbox_host} &&\
|
||||
nohup python chatbot/examples/api.py > /home/user/logs/api.log 2>&1 &",
|
||||
nohup python chatbot/examples/api.py > /home/user/chatbot/logs/api.log 2>&1 &",
|
||||
"nohup python chatbot/examples/llm_api.py > /home/user/llm.log 2>&1 &",
|
||||
f"export DUCKDUCKGO_PROXY=socks5://host.docker.internal:13659 && export SANDBOX_HOST={sandbox_host} &&\
|
||||
cd chatbot/examples && nohup streamlit run webui.py > /home/user/logs/start_webui.log 2>&1 &"
|
||||
cd chatbot/examples && nohup streamlit run webui.py > /home/user/chatbot/logs/start_webui.log 2>&1 &"
|
||||
]
|
||||
if check_docker(client, CONTRAINER_NAME, do_stop=True):
|
||||
container = start_docker(client, script_shs, ports, IMAGE_NAME, CONTRAINER_NAME, mounts, network=network_name)
|
||||
|
@ -212,12 +213,9 @@ def start_api_service(sandbox_host=DEFAULT_BIND_HOST):
|
|||
# 关闭之前启动的docker 服务
|
||||
# check_docker(client, CONTRAINER_NAME, do_stop=True, )
|
||||
|
||||
# api_sh = "nohup python ../coagent/service/api.py > ../logs/api.log 2>&1 &"
|
||||
api_sh = "nohup python api.py > ../logs/api.log 2>&1 &"
|
||||
# sdfile_sh = "nohup python ../coagent/service/sdfile_api.py > ../logs/sdfile_api.log 2>&1 &"
|
||||
sdfile_sh = "nohup python sdfile_api.py > ../logs/sdfile_api.log 2>&1 &"
|
||||
notebook_sh = f"nohup jupyter-notebook --NotebookApp.token=mytoken --port={SANDBOX_SERVER['port']} --allow-root --ip=0.0.0.0 --notebook-dir={JUPYTER_WORK_PATH} --no-browser --ServerApp.disable_check_xsrf=True > ../logs/sandbox.log 2>&1 &"
|
||||
# llm_sh = "nohup python ../coagent/service/llm_api.py > ../logs/llm_api.log 2>&1 &"
|
||||
llm_sh = "nohup python llm_api.py > ../logs/llm_api.log 2>&1 &"
|
||||
webui_sh = "streamlit run webui.py" if USE_TTY else "streamlit run webui.py"
|
||||
|
||||
|
|
|
@ -22,7 +22,7 @@ from coagent.service.service_factory import get_cb_details, get_cb_details_by_cb
|
|||
from coagent.orm import table_init
|
||||
|
||||
|
||||
from configs.model_config import EMBEDDING_DEVICE, EMBEDDING_ENGINE, EMBEDDING_MODEL, embedding_model_dict
|
||||
from configs.model_config import EMBEDDING_DEVICE, EMBEDDING_ENGINE, EMBEDDING_MODEL, embedding_model_dict,llm_model_dict
|
||||
# SENTENCE_SIZE = 100
|
||||
|
||||
cell_renderer = JsCode("""function(params) {if(params.value==true){return '✓'}else{return '×'}}""")
|
||||
|
@ -117,6 +117,8 @@ def code_page(api: ApiRequest):
|
|||
embed_model_path=embedding_model_dict[EMBEDDING_MODEL],
|
||||
embedding_device=EMBEDDING_DEVICE,
|
||||
llm_model=LLM_MODEL,
|
||||
api_key=llm_model_dict[LLM_MODEL]["api_key"],
|
||||
api_base_url=llm_model_dict[LLM_MODEL]["api_base_url"],
|
||||
)
|
||||
st.toast(ret.get("msg", " "))
|
||||
st.session_state["selected_cb_name"] = cb_name
|
||||
|
@ -153,6 +155,8 @@ def code_page(api: ApiRequest):
|
|||
embed_model_path=embedding_model_dict[EMBEDDING_MODEL],
|
||||
embedding_device=EMBEDDING_DEVICE,
|
||||
llm_model=LLM_MODEL,
|
||||
api_key=llm_model_dict[LLM_MODEL]["api_key"],
|
||||
api_base_url=llm_model_dict[LLM_MODEL]["api_base_url"],
|
||||
)
|
||||
st.toast(ret.get("msg", "删除成功"))
|
||||
time.sleep(0.05)
|
||||
|
|
|
@ -11,7 +11,7 @@ from coagent.chat.search_chat import SEARCH_ENGINES
|
|||
from coagent.connector import PHASE_LIST, PHASE_CONFIGS
|
||||
from coagent.service.service_factory import get_cb_details_by_cb_name
|
||||
|
||||
from configs.model_config import EMBEDDING_DEVICE, EMBEDDING_MODEL, embedding_model_dict, EMBEDDING_ENGINE, KB_ROOT_PATH
|
||||
from configs.model_config import EMBEDDING_DEVICE, EMBEDDING_MODEL, embedding_model_dict, EMBEDDING_ENGINE, KB_ROOT_PATH, llm_model_dict
|
||||
chat_box = ChatBox(
|
||||
assistant_avatar="../sources/imgs/devops-chatbot2.png"
|
||||
)
|
||||
|
@ -174,7 +174,7 @@ def dialogue_page(api: ApiRequest):
|
|||
is_detailed = st.toggle(webui_configs["dialogue"]["phase_toggle_detailed_name"], False)
|
||||
tool_using_on = st.toggle(
|
||||
webui_configs["dialogue"]["phase_toggle_doToolUsing"],
|
||||
PHASE_CONFIGS[choose_phase]["do_using_tool"])
|
||||
PHASE_CONFIGS[choose_phase].get("do_using_tool", False))
|
||||
tool_selects = []
|
||||
if tool_using_on:
|
||||
with st.expander("工具军火库", True):
|
||||
|
@ -183,7 +183,7 @@ def dialogue_page(api: ApiRequest):
|
|||
TOOL_SETS, ["WeatherInfo"])
|
||||
|
||||
search_on = st.toggle(webui_configs["dialogue"]["phase_toggle_doSearch"],
|
||||
PHASE_CONFIGS[choose_phase]["do_search"])
|
||||
PHASE_CONFIGS[choose_phase].get("do_search", False))
|
||||
search_engine, top_k = None, 3
|
||||
if search_on:
|
||||
with st.expander(webui_configs["dialogue"]["expander_search_name"], True):
|
||||
|
@ -195,7 +195,8 @@ def dialogue_page(api: ApiRequest):
|
|||
|
||||
doc_retrieval_on = st.toggle(
|
||||
webui_configs["dialogue"]["phase_toggle_doDocRetrieval"],
|
||||
PHASE_CONFIGS[choose_phase]["do_doc_retrieval"])
|
||||
PHASE_CONFIGS[choose_phase].get("do_doc_retrieval", False)
|
||||
)
|
||||
selected_kb, top_k, score_threshold = None, 3, 1.0
|
||||
if doc_retrieval_on:
|
||||
with st.expander(webui_configs["dialogue"]["kbase_expander_name"], True):
|
||||
|
@ -215,7 +216,7 @@ def dialogue_page(api: ApiRequest):
|
|||
|
||||
code_retrieval_on = st.toggle(
|
||||
webui_configs["dialogue"]["phase_toggle_doCodeRetrieval"],
|
||||
PHASE_CONFIGS[choose_phase]["do_code_retrieval"])
|
||||
PHASE_CONFIGS[choose_phase].get("do_code_retrieval", False))
|
||||
selected_cb, top_k = None, 1
|
||||
cb_search_type = "tag"
|
||||
if code_retrieval_on:
|
||||
|
@ -296,7 +297,8 @@ def dialogue_page(api: ApiRequest):
|
|||
r = api.chat_chat(
|
||||
prompt, history, no_remote_api=True,
|
||||
embed_model=EMBEDDING_MODEL, embed_model_path=embedding_model_dict[EMBEDDING_MODEL],
|
||||
model_device=EMBEDDING_DEVICE, embed_engine=EMBEDDING_ENGINE,
|
||||
model_device=EMBEDDING_DEVICE, embed_engine=EMBEDDING_ENGINE,api_key=llm_model_dict[LLM_MODEL]["api_key"],
|
||||
api_base_url=llm_model_dict[LLM_MODEL]["api_base_url"],
|
||||
llm_model=LLM_MODEL)
|
||||
for t in r:
|
||||
if error_msg := check_error_msg(t): # check whether error occured
|
||||
|
@ -362,6 +364,8 @@ def dialogue_page(api: ApiRequest):
|
|||
"embed_engine": EMBEDDING_ENGINE,
|
||||
"kb_root_path": KB_ROOT_PATH,
|
||||
"model_name": LLM_MODEL,
|
||||
"api_key": llm_model_dict[LLM_MODEL]["api_key"],
|
||||
"api_base_url": llm_model_dict[LLM_MODEL]["api_base_url"],
|
||||
}
|
||||
text = ""
|
||||
d = {"docs": []}
|
||||
|
@ -405,7 +409,10 @@ def dialogue_page(api: ApiRequest):
|
|||
api.knowledge_base_chat(
|
||||
prompt, selected_kb, kb_top_k, score_threshold, history,
|
||||
embed_model=EMBEDDING_MODEL, embed_model_path=embedding_model_dict[EMBEDDING_MODEL],
|
||||
model_device=EMBEDDING_DEVICE, embed_engine=EMBEDDING_ENGINE, llm_model=LLM_MODEL)
|
||||
model_device=EMBEDDING_DEVICE, embed_engine=EMBEDDING_ENGINE, llm_model=LLM_MODEL,
|
||||
api_key=llm_model_dict[LLM_MODEL]["api_key"],
|
||||
api_base_url=llm_model_dict[LLM_MODEL]["api_base_url"],
|
||||
)
|
||||
):
|
||||
if error_msg := check_error_msg(d): # check whether error occured
|
||||
st.error(error_msg)
|
||||
|
@ -415,11 +422,7 @@ def dialogue_page(api: ApiRequest):
|
|||
# chat_box.update_msg("知识库匹配结果: \n\n".join(d["docs"]), element_index=1, streaming=False, state="complete")
|
||||
chat_box.update_msg(text, element_index=0, streaming=False) # 更新最终的字符串,去除光标
|
||||
chat_box.update_msg("{webui_configs['chat']['chatbox_doc_result']}: \n\n".join(d["docs"]), element_index=1, streaming=False, state="complete")
|
||||
# # 判断是否存在代码, 并提高编辑功能,执行功能
|
||||
# code_text = api.codebox.decode_code_from_text(text)
|
||||
# GLOBAL_EXE_CODE_TEXT = code_text
|
||||
# if code_text and code_exec_on:
|
||||
# codebox_res = api.codebox_chat("```"+code_text+"```", do_code_exe=True)
|
||||
|
||||
elif dialogue_mode == webui_configs["dialogue"]["mode"][2]:
|
||||
logger.info('prompt={}'.format(prompt))
|
||||
logger.info('history={}'.format(history))
|
||||
|
@ -438,7 +441,9 @@ def dialogue_page(api: ApiRequest):
|
|||
cb_search_type=cb_search_type,
|
||||
no_remote_api=True, embed_model=EMBEDDING_MODEL,
|
||||
embed_model_path=embedding_model_dict[EMBEDDING_MODEL],
|
||||
embed_engine=EMBEDDING_ENGINE, llm_model=LLM_MODEL
|
||||
embed_engine=EMBEDDING_ENGINE, llm_model=LLM_MODEL,
|
||||
api_key=llm_model_dict[LLM_MODEL]["api_key"],
|
||||
api_base_url=llm_model_dict[LLM_MODEL]["api_base_url"],
|
||||
)):
|
||||
if error_msg := check_error_msg(d):
|
||||
st.error(error_msg)
|
||||
|
@ -448,6 +453,7 @@ def dialogue_page(api: ApiRequest):
|
|||
chat_box.update_msg(text, element_index=0)
|
||||
|
||||
# postprocess
|
||||
logger.debug(f"d={d}")
|
||||
text = replace_lt_gt(text)
|
||||
chat_box.update_msg(text, element_index=0, streaming=False) # 更新最终的字符串,去除光标
|
||||
logger.debug('text={}'.format(text))
|
||||
|
@ -467,7 +473,9 @@ def dialogue_page(api: ApiRequest):
|
|||
api.search_engine_chat(
|
||||
prompt, search_engine, se_top_k, history, embed_model=EMBEDDING_MODEL,
|
||||
embed_model_path=embedding_model_dict[EMBEDDING_MODEL],
|
||||
model_device=EMBEDDING_DEVICE, embed_engine=EMBEDDING_ENGINE, llm_model=LLM_MODEL)
|
||||
model_device=EMBEDDING_DEVICE, embed_engine=EMBEDDING_ENGINE, llm_model=LLM_MODEL,
|
||||
pi_key=llm_model_dict[LLM_MODEL]["api_key"],
|
||||
api_base_url=llm_model_dict[LLM_MODEL]["api_base_url"],)
|
||||
):
|
||||
if error_msg := check_error_msg(d): # check whether error occured
|
||||
st.error(error_msg)
|
||||
|
@ -477,56 +485,11 @@ def dialogue_page(api: ApiRequest):
|
|||
# chat_box.update_msg("搜索匹配结果: \n\n".join(d["docs"]), element_index=1, streaming=False)
|
||||
chat_box.update_msg(text, element_index=0, streaming=False) # 更新最终的字符串,去除光标
|
||||
chat_box.update_msg(f"{webui_configs['chat']['chatbox_search_result']}: \n\n".join(d["docs"]), element_index=1, streaming=False, state="complete")
|
||||
# # 判断是否存在代码, 并提高编辑功能,执行功能
|
||||
# code_text = api.codebox.decode_code_from_text(text)
|
||||
# GLOBAL_EXE_CODE_TEXT = code_text
|
||||
# if code_text and code_exec_on:
|
||||
# codebox_res = api.codebox_chat("```"+code_text+"```", do_code_exe=True)
|
||||
|
||||
# 将上传文件清空
|
||||
st.session_state["interpreter_file_key"] += 1
|
||||
st.experimental_rerun()
|
||||
|
||||
# if code_interpreter_on:
|
||||
# with st.expander(webui_configs['sandbox']['expander_code_name'], False):
|
||||
# code_part = st.text_area(
|
||||
# webui_configs['sandbox']['textArea_code_name'], code_text, key="code_text")
|
||||
# cols = st.columns(2)
|
||||
# if cols[0].button(
|
||||
# webui_configs['sandbox']['button_modify_code_name'],
|
||||
# use_container_width=True,
|
||||
# ):
|
||||
# code_text = code_part
|
||||
# GLOBAL_EXE_CODE_TEXT = code_text
|
||||
# st.toast(webui_configs['sandbox']['text_modify_code'])
|
||||
|
||||
# if cols[1].button(
|
||||
# webui_configs['sandbox']['button_exec_code_name'],
|
||||
# use_container_width=True
|
||||
# ):
|
||||
# if code_text:
|
||||
# codebox_res = api.codebox_chat("```"+code_text+"```", do_code_exe=True)
|
||||
# st.toast(webui_configs['sandbox']['text_execing_code'],)
|
||||
# else:
|
||||
# st.toast(webui_configs['sandbox']['text_error_exec_code'],)
|
||||
|
||||
# #TODO 这段信息会被记录到history里
|
||||
# if codebox_res is not None and codebox_res.code_exe_status != 200:
|
||||
# st.toast(f"{codebox_res.code_exe_response}")
|
||||
|
||||
# if codebox_res is not None and codebox_res.code_exe_status == 200:
|
||||
# st.toast(f"codebox_chat {codebox_res}")
|
||||
# chat_box.ai_say(Markdown(code_text, in_expander=True, title="code interpreter", unsafe_allow_html=True), )
|
||||
# if codebox_res.code_exe_type == "image/png":
|
||||
# base_text = f"```\n{code_text}\n```\n\n"
|
||||
# img_html = "<img src='data:image/png;base64,{}' class='img-fluid'>".format(
|
||||
# codebox_res.code_exe_response
|
||||
# )
|
||||
# chat_box.update_msg(img_html, streaming=False, state="complete")
|
||||
# else:
|
||||
# chat_box.update_msg('```\n'+code_text+'\n```'+"\n\n"+'```\n'+codebox_res.code_exe_response+'\n```',
|
||||
# streaming=False, state="complete")
|
||||
|
||||
now = datetime.now()
|
||||
with st.sidebar:
|
||||
|
||||
|
|
|
@ -14,7 +14,8 @@ from coagent.orm import table_init
|
|||
|
||||
from configs.model_config import (
|
||||
KB_ROOT_PATH, kbs_config, DEFAULT_VS_TYPE, WEB_CRAWL_PATH,
|
||||
EMBEDDING_DEVICE, EMBEDDING_ENGINE, EMBEDDING_MODEL, embedding_model_dict
|
||||
EMBEDDING_DEVICE, EMBEDDING_ENGINE, EMBEDDING_MODEL, embedding_model_dict,
|
||||
llm_model_dict
|
||||
)
|
||||
|
||||
# SENTENCE_SIZE = 100
|
||||
|
@ -136,6 +137,8 @@ def knowledge_page(
|
|||
embed_engine=EMBEDDING_ENGINE,
|
||||
embedding_device= EMBEDDING_DEVICE,
|
||||
embed_model_path=embedding_model_dict[embed_model],
|
||||
api_key=llm_model_dict[LLM_MODEL]["api_key"],
|
||||
api_base_url=llm_model_dict[LLM_MODEL]["api_base_url"],
|
||||
)
|
||||
st.toast(ret.get("msg", " "))
|
||||
st.session_state["selected_kb_name"] = kb_name
|
||||
|
@ -160,7 +163,10 @@ def knowledge_page(
|
|||
data = [{"file": f, "knowledge_base_name": kb, "not_refresh_vs_cache": True, "embed_model": EMBEDDING_MODEL,
|
||||
"embed_model_path": embedding_model_dict[EMBEDDING_MODEL],
|
||||
"model_device": EMBEDDING_DEVICE,
|
||||
"embed_engine": EMBEDDING_ENGINE}
|
||||
"embed_engine": EMBEDDING_ENGINE,
|
||||
"api_key": llm_model_dict[LLM_MODEL]["api_key"],
|
||||
"api_base_url": llm_model_dict[LLM_MODEL]["api_base_url"],
|
||||
}
|
||||
for f in files]
|
||||
data[-1]["not_refresh_vs_cache"]=False
|
||||
for k in data:
|
||||
|
@ -210,7 +216,9 @@ def knowledge_page(
|
|||
"embed_model": EMBEDDING_MODEL,
|
||||
"embed_model_path": embedding_model_dict[EMBEDDING_MODEL],
|
||||
"model_device": EMBEDDING_DEVICE,
|
||||
"embed_engine": EMBEDDING_ENGINE}]
|
||||
"embed_engine": EMBEDDING_ENGINE,
|
||||
"api_key": llm_model_dict[LLM_MODEL]["api_key"],
|
||||
"api_base_url": llm_model_dict[LLM_MODEL]["api_base_url"],}]
|
||||
for k in data:
|
||||
ret = api.upload_kb_doc(**k)
|
||||
logger.info(ret)
|
||||
|
@ -297,7 +305,9 @@ def knowledge_page(
|
|||
api.update_kb_doc(kb, row["file_name"],
|
||||
embed_engine=EMBEDDING_ENGINE,embed_model=EMBEDDING_MODEL,
|
||||
embed_model_path=embedding_model_dict[EMBEDDING_MODEL],
|
||||
model_device=EMBEDDING_DEVICE
|
||||
model_device=EMBEDDING_DEVICE,
|
||||
api_key=llm_model_dict[LLM_MODEL]["api_key"],
|
||||
api_base_url=llm_model_dict[LLM_MODEL]["api_base_url"],
|
||||
)
|
||||
st.experimental_rerun()
|
||||
|
||||
|
@ -311,7 +321,9 @@ def knowledge_page(
|
|||
api.delete_kb_doc(kb, row["file_name"],
|
||||
embed_engine=EMBEDDING_ENGINE,embed_model=EMBEDDING_MODEL,
|
||||
embed_model_path=embedding_model_dict[EMBEDDING_MODEL],
|
||||
model_device=EMBEDDING_DEVICE)
|
||||
model_device=EMBEDDING_DEVICE,
|
||||
api_key=llm_model_dict[LLM_MODEL]["api_key"],
|
||||
api_base_url=llm_model_dict[LLM_MODEL]["api_base_url"],)
|
||||
st.experimental_rerun()
|
||||
|
||||
if cols[3].button(
|
||||
|
@ -323,7 +335,9 @@ def knowledge_page(
|
|||
ret = api.delete_kb_doc(kb, row["file_name"], True,
|
||||
embed_engine=EMBEDDING_ENGINE,embed_model=EMBEDDING_MODEL,
|
||||
embed_model_path=embedding_model_dict[EMBEDDING_MODEL],
|
||||
model_device=EMBEDDING_DEVICE)
|
||||
model_device=EMBEDDING_DEVICE,
|
||||
api_key=llm_model_dict[LLM_MODEL]["api_key"],
|
||||
api_base_url=llm_model_dict[LLM_MODEL]["api_base_url"],)
|
||||
st.toast(ret.get("msg", " "))
|
||||
st.experimental_rerun()
|
||||
|
||||
|
@ -344,6 +358,8 @@ def knowledge_page(
|
|||
for d in api.recreate_vector_store(
|
||||
kb, vs_type=default_vs_type, embed_model=embedding_model, embedding_device=EMBEDDING_DEVICE,
|
||||
embed_model_path=embedding_model_dict["embedding_model"], embed_engine=EMBEDDING_ENGINE,
|
||||
api_key=llm_model_dict[LLM_MODEL]["api_key"],
|
||||
api_base_url=llm_model_dict[LLM_MODEL]["api_base_url"],
|
||||
):
|
||||
if msg := check_error_msg(d):
|
||||
st.toast(msg)
|
||||
|
|
|
@ -299,7 +299,9 @@ class ApiRequest:
|
|||
stream: bool = True,
|
||||
no_remote_api: bool = None,
|
||||
embed_model: str="", embed_model_path: str="", model_device: str="", embed_engine: str="",
|
||||
llm_model: str ="", temperature: float= 0.2
|
||||
llm_model: str ="", temperature: float= 0.2,
|
||||
api_key: str=os.environ["OPENAI_API_KEY"],
|
||||
api_base_url: str = os.environ["API_BASE_URL"],
|
||||
):
|
||||
'''
|
||||
对应api.py/chat/chat接口
|
||||
|
@ -311,8 +313,8 @@ class ApiRequest:
|
|||
"query": query,
|
||||
"history": history,
|
||||
"stream": stream,
|
||||
"api_key": os.environ["OPENAI_API_KEY"],
|
||||
"api_base_url": os.environ["API_BASE_URL"],
|
||||
"api_key": api_key,
|
||||
"api_base_url": api_base_url,
|
||||
"embed_model": embed_model,
|
||||
"embed_model_path": embed_model_path,
|
||||
"embed_engine": embed_engine,
|
||||
|
@ -339,7 +341,9 @@ class ApiRequest:
|
|||
stream: bool = True,
|
||||
no_remote_api: bool = None,
|
||||
embed_model: str="", embed_model_path: str="", model_device: str="", embed_engine: str="",
|
||||
llm_model: str ="", temperature: float= 0.2
|
||||
llm_model: str ="", temperature: float= 0.2,
|
||||
api_key: str=os.environ["OPENAI_API_KEY"],
|
||||
api_base_url: str = os.environ["API_BASE_URL"],
|
||||
):
|
||||
'''
|
||||
对应api.py/chat/knowledge_base_chat接口
|
||||
|
@ -355,8 +359,8 @@ class ApiRequest:
|
|||
"history": history,
|
||||
"stream": stream,
|
||||
"local_doc_url": no_remote_api,
|
||||
"api_key": os.environ["OPENAI_API_KEY"],
|
||||
"api_base_url": os.environ["API_BASE_URL"],
|
||||
"api_key": api_key,
|
||||
"api_base_url": api_base_url,
|
||||
"embed_model": embed_model,
|
||||
"embed_model_path": embed_model_path,
|
||||
"embed_engine": embed_engine,
|
||||
|
@ -386,7 +390,10 @@ class ApiRequest:
|
|||
stream: bool = True,
|
||||
no_remote_api: bool = None,
|
||||
embed_model: str="", embed_model_path: str="", model_device: str="", embed_engine: str="",
|
||||
llm_model: str ="", temperature: float= 0.2
|
||||
llm_model: str ="", temperature: float= 0.2,
|
||||
api_key: str=os.environ["OPENAI_API_KEY"],
|
||||
api_base_url: str = os.environ["API_BASE_URL"],
|
||||
|
||||
):
|
||||
'''
|
||||
对应api.py/chat/search_engine_chat接口
|
||||
|
@ -400,8 +407,8 @@ class ApiRequest:
|
|||
"top_k": top_k,
|
||||
"history": history,
|
||||
"stream": stream,
|
||||
"api_key": os.environ["OPENAI_API_KEY"],
|
||||
"api_base_url": os.environ["API_BASE_URL"],
|
||||
"api_key": api_key,
|
||||
"api_base_url": api_base_url,
|
||||
"embed_model": embed_model,
|
||||
"embed_model_path": embed_model_path,
|
||||
"embed_engine": embed_engine,
|
||||
|
@ -432,7 +439,9 @@ class ApiRequest:
|
|||
stream: bool = True,
|
||||
no_remote_api: bool = None,
|
||||
embed_model: str="", embed_model_path: str="", model_device: str="", embed_engine: str="",
|
||||
llm_model: str ="", temperature: float= 0.2
|
||||
llm_model: str ="", temperature: float= 0.2,
|
||||
api_key: str=os.environ["OPENAI_API_KEY"],
|
||||
api_base_url: str = os.environ["API_BASE_URL"],
|
||||
):
|
||||
'''
|
||||
对应api.py/chat/knowledge_base_chat接口
|
||||
|
@ -458,8 +467,8 @@ class ApiRequest:
|
|||
"cb_search_type": cb_search_type,
|
||||
"stream": stream,
|
||||
"local_doc_url": no_remote_api,
|
||||
"api_key": os.environ["OPENAI_API_KEY"],
|
||||
"api_base_url": os.environ["API_BASE_URL"],
|
||||
"api_key": api_key,
|
||||
"api_base_url": api_base_url,
|
||||
"embed_model": embed_model,
|
||||
"embed_model_path": embed_model_path,
|
||||
"embed_engine": embed_engine,
|
||||
|
@ -510,6 +519,8 @@ class ApiRequest:
|
|||
embed_model: str="", embed_model_path: str="",
|
||||
model_device: str="", embed_engine: str="",
|
||||
temperature: float=0.2, model_name:str ="",
|
||||
api_key: str=os.environ["OPENAI_API_KEY"],
|
||||
api_base_url: str = os.environ["API_BASE_URL"],
|
||||
):
|
||||
'''
|
||||
对应api.py/chat/chat接口
|
||||
|
@ -541,8 +552,8 @@ class ApiRequest:
|
|||
"isDetailed": isDetailed,
|
||||
"upload_file": upload_file,
|
||||
"kb_root_path": kb_root_path,
|
||||
"api_key": os.environ["OPENAI_API_KEY"],
|
||||
"api_base_url": os.environ["API_BASE_URL"],
|
||||
"api_key": api_key,
|
||||
"api_base_url": api_base_url,
|
||||
"embed_model": embed_model,
|
||||
"embed_model_path": embed_model_path,
|
||||
"embed_engine": embed_engine,
|
||||
|
@ -588,6 +599,8 @@ class ApiRequest:
|
|||
embed_model: str="", embed_model_path: str="",
|
||||
model_device: str="", embed_engine: str="",
|
||||
temperature: float=0.2, model_name: str="",
|
||||
api_key: str=os.environ["OPENAI_API_KEY"],
|
||||
api_base_url: str = os.environ["API_BASE_URL"],
|
||||
):
|
||||
'''
|
||||
对应api.py/chat/chat接口
|
||||
|
@ -620,8 +633,8 @@ class ApiRequest:
|
|||
"isDetailed": isDetailed,
|
||||
"upload_file": upload_file,
|
||||
"kb_root_path": kb_root_path,
|
||||
"api_key": os.environ["OPENAI_API_KEY"],
|
||||
"api_base_url": os.environ["API_BASE_URL"],
|
||||
"api_key": api_key,
|
||||
"api_base_url": api_base_url,
|
||||
"embed_model": embed_model,
|
||||
"embed_model_path": embed_model_path,
|
||||
"embed_engine": embed_engine,
|
||||
|
@ -694,7 +707,9 @@ class ApiRequest:
|
|||
no_remote_api: bool = None,
|
||||
kb_root_path: str =KB_ROOT_PATH,
|
||||
embed_model: str="", embed_model_path: str="",
|
||||
embedding_device: str="", embed_engine: str=""
|
||||
embedding_device: str="", embed_engine: str="",
|
||||
api_key: str=os.environ["OPENAI_API_KEY"],
|
||||
api_base_url: str = os.environ["API_BASE_URL"],
|
||||
):
|
||||
'''
|
||||
对应api.py/knowledge_base/create_knowledge_base接口
|
||||
|
@ -706,8 +721,8 @@ class ApiRequest:
|
|||
"knowledge_base_name": knowledge_base_name,
|
||||
"vector_store_type": vector_store_type,
|
||||
"kb_root_path": kb_root_path,
|
||||
"api_key": os.environ["OPENAI_API_KEY"],
|
||||
"api_base_url": os.environ["API_BASE_URL"],
|
||||
"api_key": api_key,
|
||||
"api_base_url": api_base_url,
|
||||
"embed_model": embed_model,
|
||||
"embed_model_path": embed_model_path,
|
||||
"model_device": embedding_device,
|
||||
|
@ -781,7 +796,9 @@ class ApiRequest:
|
|||
no_remote_api: bool = None,
|
||||
kb_root_path: str = KB_ROOT_PATH,
|
||||
embed_model: str="", embed_model_path: str="",
|
||||
model_device: str="", embed_engine: str=""
|
||||
model_device: str="", embed_engine: str="",
|
||||
api_key: str=os.environ["OPENAI_API_KEY"],
|
||||
api_base_url: str = os.environ["API_BASE_URL"],
|
||||
):
|
||||
'''
|
||||
对应api.py/knowledge_base/upload_docs接口
|
||||
|
@ -810,8 +827,8 @@ class ApiRequest:
|
|||
override,
|
||||
not_refresh_vs_cache,
|
||||
kb_root_path=kb_root_path,
|
||||
api_key=os.environ["OPENAI_API_KEY"],
|
||||
api_base_url=os.environ["API_BASE_URL"],
|
||||
api_key=api_key,
|
||||
api_base_url=api_base_url,
|
||||
embed_model=embed_model,
|
||||
embed_model_path=embed_model_path,
|
||||
model_device=model_device,
|
||||
|
@ -839,7 +856,9 @@ class ApiRequest:
|
|||
no_remote_api: bool = None,
|
||||
kb_root_path: str = KB_ROOT_PATH,
|
||||
embed_model: str="", embed_model_path: str="",
|
||||
model_device: str="", embed_engine: str=""
|
||||
model_device: str="", embed_engine: str="",
|
||||
api_key: str=os.environ["OPENAI_API_KEY"],
|
||||
api_base_url: str = os.environ["API_BASE_URL"],
|
||||
):
|
||||
'''
|
||||
对应api.py/knowledge_base/delete_doc接口
|
||||
|
@ -853,8 +872,8 @@ class ApiRequest:
|
|||
"delete_content": delete_content,
|
||||
"not_refresh_vs_cache": not_refresh_vs_cache,
|
||||
"kb_root_path": kb_root_path,
|
||||
"api_key": os.environ["OPENAI_API_KEY"],
|
||||
"api_base_url": os.environ["API_BASE_URL"],
|
||||
"api_key": api_key,
|
||||
"api_base_url": api_base_url,
|
||||
"embed_model": embed_model,
|
||||
"embed_model_path": embed_model_path,
|
||||
"model_device": model_device,
|
||||
|
@ -878,7 +897,9 @@ class ApiRequest:
|
|||
not_refresh_vs_cache: bool = False,
|
||||
no_remote_api: bool = None,
|
||||
embed_model: str="", embed_model_path: str="",
|
||||
model_device: str="", embed_engine: str=""
|
||||
model_device: str="", embed_engine: str="",
|
||||
api_key: str=os.environ["OPENAI_API_KEY"],
|
||||
api_base_url: str = os.environ["API_BASE_URL"],
|
||||
):
|
||||
'''
|
||||
对应api.py/knowledge_base/update_doc接口
|
||||
|
@ -889,8 +910,8 @@ class ApiRequest:
|
|||
if no_remote_api:
|
||||
response = run_async(update_doc(
|
||||
knowledge_base_name, file_name, not_refresh_vs_cache, kb_root_path=KB_ROOT_PATH,
|
||||
api_key=os.environ["OPENAI_API_KEY"],
|
||||
api_base_url=os.environ["API_BASE_URL"],
|
||||
api_key=api_key,
|
||||
api_base_url=api_base_url,
|
||||
embed_model=embed_model,
|
||||
embed_model_path=embed_model_path,
|
||||
model_device=model_device,
|
||||
|
@ -915,7 +936,9 @@ class ApiRequest:
|
|||
no_remote_api: bool = None,
|
||||
kb_root_path: str =KB_ROOT_PATH,
|
||||
embed_model: str="", embed_model_path: str="",
|
||||
embedding_device: str="", embed_engine: str=""
|
||||
embedding_device: str="", embed_engine: str="",
|
||||
api_key: str=os.environ["OPENAI_API_KEY"],
|
||||
api_base_url: str = os.environ["API_BASE_URL"],
|
||||
):
|
||||
'''
|
||||
对应api.py/knowledge_base/recreate_vector_store接口
|
||||
|
@ -928,8 +951,8 @@ class ApiRequest:
|
|||
"allow_empty_kb": allow_empty_kb,
|
||||
"vs_type": vs_type,
|
||||
"kb_root_path": kb_root_path,
|
||||
"api_key": os.environ["OPENAI_API_KEY"],
|
||||
"api_base_url": os.environ["API_BASE_URL"],
|
||||
"api_key": api_key,
|
||||
"api_base_url": api_base_url,
|
||||
"embed_model": embed_model,
|
||||
"embed_model_path": embed_model_path,
|
||||
"model_device": embedding_device,
|
||||
|
@ -1041,7 +1064,9 @@ class ApiRequest:
|
|||
# code base 相关操作
|
||||
def create_code_base(self, cb_name, zip_file, do_interpret: bool, no_remote_api: bool = None,
|
||||
embed_model: str="", embed_model_path: str="", embedding_device: str="", embed_engine: str="",
|
||||
llm_model: str ="", temperature: float= 0.2
|
||||
llm_model: str ="", temperature: float= 0.2,
|
||||
api_key: str=os.environ["OPENAI_API_KEY"],
|
||||
api_base_url: str = os.environ["API_BASE_URL"],
|
||||
):
|
||||
'''
|
||||
创建 code_base
|
||||
|
@ -1067,8 +1092,8 @@ class ApiRequest:
|
|||
"cb_name": cb_name,
|
||||
"code_path": raw_code_path,
|
||||
"do_interpret": do_interpret,
|
||||
"api_key": os.environ["OPENAI_API_KEY"],
|
||||
"api_base_url": os.environ["API_BASE_URL"],
|
||||
"api_key": api_key,
|
||||
"api_base_url": api_base_url,
|
||||
"embed_model": embed_model,
|
||||
"embed_model_path": embed_model_path,
|
||||
"embed_engine": embed_engine,
|
||||
|
@ -1091,7 +1116,9 @@ class ApiRequest:
|
|||
|
||||
def delete_code_base(self, cb_name: str, no_remote_api: bool = None,
|
||||
embed_model: str="", embed_model_path: str="", embedding_device: str="", embed_engine: str="",
|
||||
llm_model: str ="", temperature: float= 0.2
|
||||
llm_model: str ="", temperature: float= 0.2,
|
||||
api_key: str=os.environ["OPENAI_API_KEY"],
|
||||
api_base_url: str = os.environ["API_BASE_URL"],
|
||||
):
|
||||
'''
|
||||
删除 code_base
|
||||
|
@ -1102,8 +1129,8 @@ class ApiRequest:
|
|||
no_remote_api = self.no_remote_api
|
||||
data = {
|
||||
"cb_name": cb_name,
|
||||
"api_key": os.environ["OPENAI_API_KEY"],
|
||||
"api_base_url": os.environ["API_BASE_URL"],
|
||||
"api_key": api_key,
|
||||
"api_base_url": api_base_url,
|
||||
"embed_model": embed_model,
|
||||
"embed_model_path": embed_model_path,
|
||||
"embed_engine": embed_engine,
|
||||
|
|
Loading…
Reference in New Issue