Merge pull request #28 from codefuse-ai/ma_demo_commit
[feature](coagent)<增加antflow兼容和增加coagent demo>
This commit is contained in:
commit
333f1e97c6
|
@ -26,9 +26,12 @@ JUPYTER_WORK_PATH = os.environ.get("JUPYTER_WORK_PATH", None) or os.path.join(ex
|
||||||
WEB_CRAWL_PATH = os.environ.get("WEB_CRAWL_PATH", None) or os.path.join(executable_path, "knowledge_base")
|
WEB_CRAWL_PATH = os.environ.get("WEB_CRAWL_PATH", None) or os.path.join(executable_path, "knowledge_base")
|
||||||
|
|
||||||
# NEBULA_DATA存储路径
|
# NEBULA_DATA存储路径
|
||||||
NELUBA_PATH = os.environ.get("NELUBA_PATH", None) or os.path.join(executable_path, "data/neluba_data")
|
NEBULA_PATH = os.environ.get("NEBULA_PATH", None) or os.path.join(executable_path, "data/nebula_data")
|
||||||
|
|
||||||
for _path in [LOG_PATH, SOURCE_PATH, KB_ROOT_PATH, NLTK_DATA_PATH, JUPYTER_WORK_PATH, WEB_CRAWL_PATH, NELUBA_PATH]:
|
# CHROMA 存储路径
|
||||||
|
CHROMA_PERSISTENT_PATH = os.environ.get("CHROMA_PERSISTENT_PATH", None) or os.path.join(executable_path, "data/chroma_data")
|
||||||
|
|
||||||
|
for _path in [LOG_PATH, SOURCE_PATH, KB_ROOT_PATH, CB_ROOT_PATH, NLTK_DATA_PATH, JUPYTER_WORK_PATH, WEB_CRAWL_PATH, NEBULA_PATH, CHROMA_PERSISTENT_PATH]:
|
||||||
if not os.path.exists(_path):
|
if not os.path.exists(_path):
|
||||||
os.makedirs(_path, exist_ok=True)
|
os.makedirs(_path, exist_ok=True)
|
||||||
|
|
||||||
|
@ -58,7 +61,8 @@ NEBULA_GRAPH_SERVER = {
|
||||||
}
|
}
|
||||||
|
|
||||||
# CHROMA CONFIG
|
# CHROMA CONFIG
|
||||||
CHROMA_PERSISTENT_PATH = '/home/user/chatbot/data/chroma_data'
|
# CHROMA_PERSISTENT_PATH = '/home/user/chatbot/data/chroma_data'
|
||||||
|
# CHROMA_PERSISTENT_PATH = '/Users/bingxu/Desktop/工作/大模型/chatbot/codefuse-chatbot-antcode/data/chroma_data'
|
||||||
|
|
||||||
|
|
||||||
# 默认向量库类型。可选:faiss, milvus, pg.
|
# 默认向量库类型。可选:faiss, milvus, pg.
|
||||||
|
|
|
@ -7,7 +7,7 @@ from langchain import LLMChain
|
||||||
from langchain.callbacks import AsyncIteratorCallbackHandler
|
from langchain.callbacks import AsyncIteratorCallbackHandler
|
||||||
from langchain.prompts.chat import ChatPromptTemplate
|
from langchain.prompts.chat import ChatPromptTemplate
|
||||||
|
|
||||||
from coagent.llm_models import getChatModel, getChatModelFromConfig
|
from coagent.llm_models import getChatModelFromConfig
|
||||||
from coagent.chat.utils import History, wrap_done
|
from coagent.chat.utils import History, wrap_done
|
||||||
from coagent.llm_models.llm_config import LLMConfig, EmbedConfig
|
from coagent.llm_models.llm_config import LLMConfig, EmbedConfig
|
||||||
# from configs.model_config import (llm_model_dict, LLM_MODEL, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD)
|
# from configs.model_config import (llm_model_dict, LLM_MODEL, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD)
|
||||||
|
|
|
@ -22,7 +22,7 @@ from coagent.connector.configs.prompts import CODE_PROMPT_TEMPLATE
|
||||||
from coagent.chat.utils import History, wrap_done
|
from coagent.chat.utils import History, wrap_done
|
||||||
from coagent.utils import BaseResponse
|
from coagent.utils import BaseResponse
|
||||||
from .base_chat import Chat
|
from .base_chat import Chat
|
||||||
from coagent.llm_models import getChatModel, getChatModelFromConfig
|
from coagent.llm_models import getChatModelFromConfig
|
||||||
|
|
||||||
from coagent.llm_models.llm_config import LLMConfig, EmbedConfig
|
from coagent.llm_models.llm_config import LLMConfig, EmbedConfig
|
||||||
|
|
||||||
|
@ -67,6 +67,7 @@ class CodeChat(Chat):
|
||||||
embed_model_path=embed_config.embed_model_path,
|
embed_model_path=embed_config.embed_model_path,
|
||||||
embed_engine=embed_config.embed_engine,
|
embed_engine=embed_config.embed_engine,
|
||||||
model_device=embed_config.model_device,
|
model_device=embed_config.model_device,
|
||||||
|
embed_config=embed_config
|
||||||
)
|
)
|
||||||
|
|
||||||
context = codes_res['context']
|
context = codes_res['context']
|
||||||
|
|
|
@ -12,7 +12,7 @@ from langchain.schema import (
|
||||||
|
|
||||||
# from configs.model_config import CODE_INTERPERT_TEMPLATE
|
# from configs.model_config import CODE_INTERPERT_TEMPLATE
|
||||||
from coagent.connector.configs.prompts import CODE_INTERPERT_TEMPLATE
|
from coagent.connector.configs.prompts import CODE_INTERPERT_TEMPLATE
|
||||||
from coagent.llm_models.openai_model import getChatModel, getChatModelFromConfig
|
from coagent.llm_models.openai_model import getChatModelFromConfig
|
||||||
from coagent.llm_models.llm_config import LLMConfig
|
from coagent.llm_models.llm_config import LLMConfig
|
||||||
|
|
||||||
|
|
||||||
|
@ -53,9 +53,15 @@ class CodeIntepreter:
|
||||||
message = CODE_INTERPERT_TEMPLATE.format(code=code)
|
message = CODE_INTERPERT_TEMPLATE.format(code=code)
|
||||||
messages.append(message)
|
messages.append(message)
|
||||||
|
|
||||||
chat_ress = chat_model.batch(messages)
|
try:
|
||||||
|
chat_ress = [chat_model(messages) for message in messages]
|
||||||
|
except:
|
||||||
|
chat_ress = chat_model.batch(messages)
|
||||||
for chat_res, code in zip(chat_ress, code_list):
|
for chat_res, code in zip(chat_ress, code_list):
|
||||||
res[code] = chat_res.content
|
try:
|
||||||
|
res[code] = chat_res.content
|
||||||
|
except:
|
||||||
|
res[code] = chat_res
|
||||||
return res
|
return res
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -27,7 +27,7 @@ class DirCrawler:
|
||||||
logger.info(java_file_list)
|
logger.info(java_file_list)
|
||||||
|
|
||||||
for java_file in java_file_list:
|
for java_file in java_file_list:
|
||||||
with open(java_file) as f:
|
with open(java_file, encoding="utf-8") as f:
|
||||||
java_code = ''.join(f.readlines())
|
java_code = ''.join(f.readlines())
|
||||||
java_code_dict[java_file] = java_code
|
java_code_dict[java_file] = java_code
|
||||||
return java_code_dict
|
return java_code_dict
|
||||||
|
|
|
@ -5,6 +5,7 @@
|
||||||
@time: 2023/11/21 下午2:35
|
@time: 2023/11/21 下午2:35
|
||||||
@desc:
|
@desc:
|
||||||
'''
|
'''
|
||||||
|
import json
|
||||||
import time
|
import time
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
|
@ -15,7 +16,7 @@ from coagent.db_handler.vector_db_handler.chroma_handler import ChromaHandler
|
||||||
from coagent.codechat.code_search.cypher_generator import CypherGenerator
|
from coagent.codechat.code_search.cypher_generator import CypherGenerator
|
||||||
from coagent.codechat.code_search.tagger import Tagger
|
from coagent.codechat.code_search.tagger import Tagger
|
||||||
from coagent.embeddings.get_embedding import get_embedding
|
from coagent.embeddings.get_embedding import get_embedding
|
||||||
from coagent.llm_models.llm_config import LLMConfig
|
from coagent.llm_models.llm_config import LLMConfig, EmbedConfig
|
||||||
|
|
||||||
|
|
||||||
# from configs.model_config import EMBEDDING_DEVICE, EMBEDDING_MODEL
|
# from configs.model_config import EMBEDDING_DEVICE, EMBEDDING_MODEL
|
||||||
|
@ -29,7 +30,8 @@ MAX_DISTANCE = 1000
|
||||||
|
|
||||||
|
|
||||||
class CodeSearch:
|
class CodeSearch:
|
||||||
def __init__(self, llm_config: LLMConfig, nh: NebulaHandler, ch: ChromaHandler, limit: int = 3):
|
def __init__(self, llm_config: LLMConfig, nh: NebulaHandler, ch: ChromaHandler, limit: int = 3,
|
||||||
|
local_graph_file_path: str = ''):
|
||||||
'''
|
'''
|
||||||
init
|
init
|
||||||
@param nh: NebulaHandler
|
@param nh: NebulaHandler
|
||||||
|
@ -37,7 +39,13 @@ class CodeSearch:
|
||||||
@param limit: limit of result
|
@param limit: limit of result
|
||||||
'''
|
'''
|
||||||
self.llm_config = llm_config
|
self.llm_config = llm_config
|
||||||
|
|
||||||
self.nh = nh
|
self.nh = nh
|
||||||
|
|
||||||
|
if not self.nh:
|
||||||
|
with open(local_graph_file_path, 'r') as f:
|
||||||
|
self.graph = json.load(f)
|
||||||
|
|
||||||
self.ch = ch
|
self.ch = ch
|
||||||
self.limit = limit
|
self.limit = limit
|
||||||
|
|
||||||
|
@ -51,7 +59,7 @@ class CodeSearch:
|
||||||
tag_list = tagger.generate_tag_query(query)
|
tag_list = tagger.generate_tag_query(query)
|
||||||
logger.info(f'query tag={tag_list}')
|
logger.info(f'query tag={tag_list}')
|
||||||
|
|
||||||
# get all verticex
|
# get all vertices
|
||||||
vertex_list = self.nh.get_vertices().get('v', [])
|
vertex_list = self.nh.get_vertices().get('v', [])
|
||||||
vertex_vid_list = [i.as_node().get_id().as_string() for i in vertex_list]
|
vertex_vid_list = [i.as_node().get_id().as_string() for i in vertex_list]
|
||||||
|
|
||||||
|
@ -81,7 +89,7 @@ class CodeSearch:
|
||||||
# get most prominent package tag
|
# get most prominent package tag
|
||||||
package_score_dict = defaultdict(lambda: 0)
|
package_score_dict = defaultdict(lambda: 0)
|
||||||
|
|
||||||
for vertex, score in vertex_score_dict.items():
|
for vertex, score in vertex_score_dict_final.items():
|
||||||
if '#' in vertex:
|
if '#' in vertex:
|
||||||
# get class name first
|
# get class name first
|
||||||
cypher = f'''MATCH (v1:class)-[e:contain]->(v2) WHERE id(v2) == '{vertex}' RETURN id(v1) as id;'''
|
cypher = f'''MATCH (v1:class)-[e:contain]->(v2) WHERE id(v2) == '{vertex}' RETURN id(v1) as id;'''
|
||||||
|
@ -111,6 +119,53 @@ class CodeSearch:
|
||||||
logger.info(f'ids={ids}')
|
logger.info(f'ids={ids}')
|
||||||
chroma_res = self.ch.get(ids=ids, include=['metadatas'])
|
chroma_res = self.ch.get(ids=ids, include=['metadatas'])
|
||||||
|
|
||||||
|
for vertex, score in package_score_tuple:
|
||||||
|
index = chroma_res['result']['ids'].index(vertex)
|
||||||
|
code_text = chroma_res['result']['metadatas'][index]['code_text']
|
||||||
|
res.append({
|
||||||
|
"vertex": vertex,
|
||||||
|
"code_text": code_text}
|
||||||
|
)
|
||||||
|
if len(res) >= self.limit:
|
||||||
|
break
|
||||||
|
# logger.info(f'retrival code={res}')
|
||||||
|
return res
|
||||||
|
|
||||||
|
def search_by_tag_by_graph(self, query: str):
|
||||||
|
'''
|
||||||
|
search code by tag with graph
|
||||||
|
@param query:
|
||||||
|
@return:
|
||||||
|
'''
|
||||||
|
tagger = Tagger()
|
||||||
|
tag_list = tagger.generate_tag_query(query)
|
||||||
|
logger.info(f'query tag={tag_list}')
|
||||||
|
|
||||||
|
# loop to get package node
|
||||||
|
package_score_dict = {}
|
||||||
|
for code, structure in self.graph.items():
|
||||||
|
score = 0
|
||||||
|
for class_name in structure['class_name_list']:
|
||||||
|
for tag in tag_list:
|
||||||
|
if tag.lower() in class_name.lower():
|
||||||
|
score += 1
|
||||||
|
|
||||||
|
for func_name_list in structure['func_name_dict'].values():
|
||||||
|
for func_name in func_name_list:
|
||||||
|
for tag in tag_list:
|
||||||
|
if tag.lower() in func_name.lower():
|
||||||
|
score += 1
|
||||||
|
package_score_dict[structure['pac_name']] = score
|
||||||
|
|
||||||
|
# get respective code
|
||||||
|
res = []
|
||||||
|
package_score_tuple = list(package_score_dict.items())
|
||||||
|
package_score_tuple.sort(key=lambda x: x[1], reverse=True)
|
||||||
|
|
||||||
|
ids = [i[0] for i in package_score_tuple]
|
||||||
|
logger.info(f'ids={ids}')
|
||||||
|
chroma_res = self.ch.get(ids=ids, include=['metadatas'])
|
||||||
|
|
||||||
# logger.info(chroma_res)
|
# logger.info(chroma_res)
|
||||||
for vertex, score in package_score_tuple:
|
for vertex, score in package_score_tuple:
|
||||||
index = chroma_res['result']['ids'].index(vertex)
|
index = chroma_res['result']['ids'].index(vertex)
|
||||||
|
@ -121,23 +176,22 @@ class CodeSearch:
|
||||||
)
|
)
|
||||||
if len(res) >= self.limit:
|
if len(res) >= self.limit:
|
||||||
break
|
break
|
||||||
logger.info(f'retrival code={res}')
|
# logger.info(f'retrival code={res}')
|
||||||
return res
|
return res
|
||||||
|
|
||||||
def search_by_desciption(self, query: str, engine: str, model_path: str = "text2vec-base-chinese", embedding_device: str = "cpu"):
|
def search_by_desciption(self, query: str, engine: str, model_path: str = "text2vec-base-chinese", embedding_device: str = "cpu", embed_config: EmbedConfig=None):
|
||||||
'''
|
'''
|
||||||
search by perform sim search
|
search by perform sim search
|
||||||
@param query:
|
@param query:
|
||||||
@return:
|
@return:
|
||||||
'''
|
'''
|
||||||
query = query.replace(',', ',')
|
query = query.replace(',', ',')
|
||||||
query_emb = get_embedding(engine=engine, text_list=[query], model_path=model_path, embedding_device= embedding_device,)
|
query_emb = get_embedding(engine=engine, text_list=[query], model_path=model_path, embedding_device= embedding_device, embed_config=embed_config)
|
||||||
query_emb = query_emb[query]
|
query_emb = query_emb[query]
|
||||||
|
|
||||||
query_embeddings = [query_emb]
|
query_embeddings = [query_emb]
|
||||||
query_result = self.ch.query(query_embeddings=query_embeddings, n_results=self.limit,
|
query_result = self.ch.query(query_embeddings=query_embeddings, n_results=self.limit,
|
||||||
include=['metadatas', 'distances'])
|
include=['metadatas', 'distances'])
|
||||||
logger.debug(query_result)
|
|
||||||
|
|
||||||
res = []
|
res = []
|
||||||
for idx, distance in enumerate(query_result['result']['distances'][0]):
|
for idx, distance in enumerate(query_result['result']['distances'][0]):
|
||||||
|
|
|
@ -8,7 +8,7 @@
|
||||||
from langchain import PromptTemplate
|
from langchain import PromptTemplate
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
from coagent.llm_models.openai_model import getChatModel, getChatModelFromConfig
|
from coagent.llm_models.openai_model import getChatModelFromConfig
|
||||||
from coagent.llm_models.llm_config import LLMConfig
|
from coagent.llm_models.llm_config import LLMConfig
|
||||||
from coagent.utils.postprocess import replace_lt_gt
|
from coagent.utils.postprocess import replace_lt_gt
|
||||||
from langchain.schema import (
|
from langchain.schema import (
|
||||||
|
|
|
@ -6,11 +6,10 @@
|
||||||
@desc:
|
@desc:
|
||||||
'''
|
'''
|
||||||
import time
|
import time
|
||||||
|
import json
|
||||||
|
import os
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
# from configs.server_config import NEBULA_HOST, NEBULA_PORT, NEBULA_USER, NEBULA_PASSWORD, NEBULA_STORAGED_PORT
|
|
||||||
# from configs.server_config import CHROMA_PERSISTENT_PATH
|
|
||||||
# from configs.model_config import EMBEDDING_DEVICE, EMBEDDING_MODEL
|
|
||||||
from coagent.db_handler.graph_db_handler.nebula_handler import NebulaHandler
|
from coagent.db_handler.graph_db_handler.nebula_handler import NebulaHandler
|
||||||
from coagent.db_handler.vector_db_handler.chroma_handler import ChromaHandler
|
from coagent.db_handler.vector_db_handler.chroma_handler import ChromaHandler
|
||||||
from coagent.embeddings.get_embedding import get_embedding
|
from coagent.embeddings.get_embedding import get_embedding
|
||||||
|
@ -18,12 +17,14 @@ from coagent.llm_models.llm_config import EmbedConfig
|
||||||
|
|
||||||
|
|
||||||
class CodeImporter:
|
class CodeImporter:
|
||||||
def __init__(self, codebase_name: str, embed_config: EmbedConfig, nh: NebulaHandler, ch: ChromaHandler):
|
def __init__(self, codebase_name: str, embed_config: EmbedConfig, nh: NebulaHandler, ch: ChromaHandler,
|
||||||
|
local_graph_file_path: str):
|
||||||
self.codebase_name = codebase_name
|
self.codebase_name = codebase_name
|
||||||
# self.engine = engine
|
# self.engine = engine
|
||||||
self.embed_config: EmbedConfig= embed_config
|
self.embed_config: EmbedConfig = embed_config
|
||||||
self.nh = nh
|
self.nh = nh
|
||||||
self.ch = ch
|
self.ch = ch
|
||||||
|
self.local_graph_file_path = local_graph_file_path
|
||||||
|
|
||||||
def import_code(self, static_analysis_res: dict, interpretation: dict, do_interpret: bool = True):
|
def import_code(self, static_analysis_res: dict, interpretation: dict, do_interpret: bool = True):
|
||||||
'''
|
'''
|
||||||
|
@ -31,9 +32,14 @@ class CodeImporter:
|
||||||
@return:
|
@return:
|
||||||
'''
|
'''
|
||||||
static_analysis_res = self.filter_out_vertex(static_analysis_res, interpretation)
|
static_analysis_res = self.filter_out_vertex(static_analysis_res, interpretation)
|
||||||
logger.info(f'static_analysis_res={static_analysis_res}')
|
|
||||||
|
|
||||||
self.analysis_res_to_graph(static_analysis_res)
|
if self.nh:
|
||||||
|
self.analysis_res_to_graph(static_analysis_res)
|
||||||
|
else:
|
||||||
|
# persist to local dir
|
||||||
|
with open(self.local_graph_file_path, 'w') as f:
|
||||||
|
json.dump(static_analysis_res, f)
|
||||||
|
|
||||||
self.interpretation_to_db(static_analysis_res, interpretation, do_interpret)
|
self.interpretation_to_db(static_analysis_res, interpretation, do_interpret)
|
||||||
|
|
||||||
def filter_out_vertex(self, static_analysis_res, interpretation):
|
def filter_out_vertex(self, static_analysis_res, interpretation):
|
||||||
|
@ -114,12 +120,12 @@ class CodeImporter:
|
||||||
# create vertex
|
# create vertex
|
||||||
for tag_name, value_dict in vertex_value_dict.items():
|
for tag_name, value_dict in vertex_value_dict.items():
|
||||||
res = self.nh.insert_vertex(tag_name, value_dict)
|
res = self.nh.insert_vertex(tag_name, value_dict)
|
||||||
logger.debug(res.error_msg())
|
# logger.debug(res.error_msg())
|
||||||
|
|
||||||
# create edge
|
# create edge
|
||||||
for tag_name, value_dict in edge_value_dict.items():
|
for tag_name, value_dict in edge_value_dict.items():
|
||||||
res = self.nh.insert_edge(tag_name, value_dict)
|
res = self.nh.insert_edge(tag_name, value_dict)
|
||||||
logger.debug(res.error_msg())
|
# logger.debug(res.error_msg())
|
||||||
|
|
||||||
return
|
return
|
||||||
|
|
||||||
|
@ -132,7 +138,7 @@ class CodeImporter:
|
||||||
if do_interpret:
|
if do_interpret:
|
||||||
logger.info('start get embedding for interpretion')
|
logger.info('start get embedding for interpretion')
|
||||||
interp_list = list(interpretation.values())
|
interp_list = list(interpretation.values())
|
||||||
emb = get_embedding(engine=self.embed_config.embed_engine, text_list=interp_list, model_path=self.embed_config.embed_model_path, embedding_device= self.embed_config.model_device)
|
emb = get_embedding(engine=self.embed_config.embed_engine, text_list=interp_list, model_path=self.embed_config.embed_model_path, embedding_device= self.embed_config.model_device, embed_config=self.embed_config)
|
||||||
logger.info('get embedding done')
|
logger.info('get embedding done')
|
||||||
else:
|
else:
|
||||||
emb = {i: [0] for i in list(interpretation.values())}
|
emb = {i: [0] for i in list(interpretation.values())}
|
||||||
|
@ -161,7 +167,7 @@ class CodeImporter:
|
||||||
|
|
||||||
# add documents to chroma
|
# add documents to chroma
|
||||||
res = self.ch.add_data(ids=ids, embeddings=embeddings, documents=documents, metadatas=metadatas)
|
res = self.ch.add_data(ids=ids, embeddings=embeddings, documents=documents, metadatas=metadatas)
|
||||||
logger.debug(res)
|
# logger.debug(res)
|
||||||
|
|
||||||
def init_graph(self):
|
def init_graph(self):
|
||||||
'''
|
'''
|
||||||
|
@ -169,7 +175,7 @@ class CodeImporter:
|
||||||
@return:
|
@return:
|
||||||
'''
|
'''
|
||||||
res = self.nh.create_space(space_name=self.codebase_name, vid_type='FIXED_STRING(1024)')
|
res = self.nh.create_space(space_name=self.codebase_name, vid_type='FIXED_STRING(1024)')
|
||||||
logger.debug(res.error_msg())
|
# logger.debug(res.error_msg())
|
||||||
time.sleep(5)
|
time.sleep(5)
|
||||||
|
|
||||||
self.nh.set_space_name(self.codebase_name)
|
self.nh.set_space_name(self.codebase_name)
|
||||||
|
@ -179,29 +185,29 @@ class CodeImporter:
|
||||||
tag_name = 'package'
|
tag_name = 'package'
|
||||||
prop_dict = {}
|
prop_dict = {}
|
||||||
res = self.nh.create_tag(tag_name, prop_dict)
|
res = self.nh.create_tag(tag_name, prop_dict)
|
||||||
logger.debug(res.error_msg())
|
# logger.debug(res.error_msg())
|
||||||
|
|
||||||
tag_name = 'class'
|
tag_name = 'class'
|
||||||
prop_dict = {}
|
prop_dict = {}
|
||||||
res = self.nh.create_tag(tag_name, prop_dict)
|
res = self.nh.create_tag(tag_name, prop_dict)
|
||||||
logger.debug(res.error_msg())
|
# logger.debug(res.error_msg())
|
||||||
|
|
||||||
tag_name = 'method'
|
tag_name = 'method'
|
||||||
prop_dict = {}
|
prop_dict = {}
|
||||||
res = self.nh.create_tag(tag_name, prop_dict)
|
res = self.nh.create_tag(tag_name, prop_dict)
|
||||||
logger.debug(res.error_msg())
|
# logger.debug(res.error_msg())
|
||||||
|
|
||||||
# create edge type
|
# create edge type
|
||||||
edge_type_name = 'contain'
|
edge_type_name = 'contain'
|
||||||
prop_dict = {}
|
prop_dict = {}
|
||||||
res = self.nh.create_edge_type(edge_type_name, prop_dict)
|
res = self.nh.create_edge_type(edge_type_name, prop_dict)
|
||||||
logger.debug(res.error_msg())
|
# logger.debug(res.error_msg())
|
||||||
|
|
||||||
# create edge type
|
# create edge type
|
||||||
edge_type_name = 'depend'
|
edge_type_name = 'depend'
|
||||||
prop_dict = {}
|
prop_dict = {}
|
||||||
res = self.nh.create_edge_type(edge_type_name, prop_dict)
|
res = self.nh.create_edge_type(edge_type_name, prop_dict)
|
||||||
logger.debug(res.error_msg())
|
# logger.debug(res.error_msg())
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|
|
@ -5,16 +5,15 @@
|
||||||
@time: 2023/11/21 下午2:25
|
@time: 2023/11/21 下午2:25
|
||||||
@desc:
|
@desc:
|
||||||
'''
|
'''
|
||||||
|
import os
|
||||||
import time
|
import time
|
||||||
|
import json
|
||||||
|
from typing import List
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
# from configs.server_config import NEBULA_HOST, NEBULA_PORT, NEBULA_USER, NEBULA_PASSWORD, NEBULA_STORAGED_PORT
|
|
||||||
# from configs.server_config import CHROMA_PERSISTENT_PATH
|
|
||||||
# from configs.model_config import EMBEDDING_ENGINE
|
|
||||||
|
|
||||||
from coagent.base_configs.env_config import (
|
from coagent.base_configs.env_config import (
|
||||||
NEBULA_HOST, NEBULA_PORT, NEBULA_USER, NEBULA_PASSWORD, NEBULA_STORAGED_PORT,
|
NEBULA_HOST, NEBULA_PORT, NEBULA_USER, NEBULA_PASSWORD, NEBULA_STORAGED_PORT,
|
||||||
CHROMA_PERSISTENT_PATH
|
CHROMA_PERSISTENT_PATH, CB_ROOT_PATH
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -35,7 +34,9 @@ class CodeBaseHandler:
|
||||||
language: str = 'java',
|
language: str = 'java',
|
||||||
crawl_type: str = 'ZIP',
|
crawl_type: str = 'ZIP',
|
||||||
embed_config: EmbedConfig = EmbedConfig(),
|
embed_config: EmbedConfig = EmbedConfig(),
|
||||||
llm_config: LLMConfig = LLMConfig()
|
llm_config: LLMConfig = LLMConfig(),
|
||||||
|
use_nh: bool = True,
|
||||||
|
local_graph_path: str = CB_ROOT_PATH
|
||||||
):
|
):
|
||||||
self.codebase_name = codebase_name
|
self.codebase_name = codebase_name
|
||||||
self.code_path = code_path
|
self.code_path = code_path
|
||||||
|
@ -43,11 +44,28 @@ class CodeBaseHandler:
|
||||||
self.crawl_type = crawl_type
|
self.crawl_type = crawl_type
|
||||||
self.embed_config = embed_config
|
self.embed_config = embed_config
|
||||||
self.llm_config = llm_config
|
self.llm_config = llm_config
|
||||||
|
self.local_graph_file_path = local_graph_path + os.sep + f'{self.codebase_name}_graph.json'
|
||||||
|
|
||||||
self.nh = NebulaHandler(host=NEBULA_HOST, port=NEBULA_PORT, username=NEBULA_USER,
|
if use_nh:
|
||||||
password=NEBULA_PASSWORD, space_name=codebase_name)
|
try:
|
||||||
self.nh.add_host(NEBULA_HOST, NEBULA_STORAGED_PORT)
|
self.nh = NebulaHandler(host=NEBULA_HOST, port=NEBULA_PORT, username=NEBULA_USER,
|
||||||
time.sleep(1)
|
password=NEBULA_PASSWORD, space_name=codebase_name)
|
||||||
|
self.nh.add_host(NEBULA_HOST, NEBULA_STORAGED_PORT)
|
||||||
|
time.sleep(1)
|
||||||
|
except:
|
||||||
|
self.nh = None
|
||||||
|
try:
|
||||||
|
with open(self.local_graph_file_path, 'r') as f:
|
||||||
|
self.graph = json.load(f)
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
elif local_graph_path:
|
||||||
|
self.nh = None
|
||||||
|
try:
|
||||||
|
with open(self.local_graph_file_path, 'r') as f:
|
||||||
|
self.graph = json.load(f)
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
self.ch = ChromaHandler(path=CHROMA_PERSISTENT_PATH, collection_name=codebase_name)
|
self.ch = ChromaHandler(path=CHROMA_PERSISTENT_PATH, collection_name=codebase_name)
|
||||||
|
|
||||||
|
@ -58,9 +76,10 @@ class CodeBaseHandler:
|
||||||
'''
|
'''
|
||||||
# init graph to init tag and edge
|
# init graph to init tag and edge
|
||||||
code_importer = CodeImporter(embed_config=self.embed_config, codebase_name=self.codebase_name,
|
code_importer = CodeImporter(embed_config=self.embed_config, codebase_name=self.codebase_name,
|
||||||
nh=self.nh, ch=self.ch)
|
nh=self.nh, ch=self.ch, local_graph_file_path=self.local_graph_file_path)
|
||||||
code_importer.init_graph()
|
if self.nh:
|
||||||
time.sleep(5)
|
code_importer.init_graph()
|
||||||
|
time.sleep(5)
|
||||||
|
|
||||||
# crawl code
|
# crawl code
|
||||||
st0 = time.time()
|
st0 = time.time()
|
||||||
|
@ -71,7 +90,7 @@ class CodeBaseHandler:
|
||||||
# analyze code
|
# analyze code
|
||||||
logger.info('start analyze')
|
logger.info('start analyze')
|
||||||
st1 = time.time()
|
st1 = time.time()
|
||||||
code_analyzer = CodeAnalyzer(language=self.language, llm_config = self.llm_config)
|
code_analyzer = CodeAnalyzer(language=self.language, llm_config=self.llm_config)
|
||||||
static_analysis_res, interpretation = code_analyzer.analyze(code_dict, do_interpret=do_interpret)
|
static_analysis_res, interpretation = code_analyzer.analyze(code_dict, do_interpret=do_interpret)
|
||||||
logger.debug('analyze done, rt={}'.format(time.time() - st1))
|
logger.debug('analyze done, rt={}'.format(time.time() - st1))
|
||||||
|
|
||||||
|
@ -81,8 +100,12 @@ class CodeBaseHandler:
|
||||||
logger.debug('update codebase done, rt={}'.format(time.time() - st2))
|
logger.debug('update codebase done, rt={}'.format(time.time() - st2))
|
||||||
|
|
||||||
# get KG info
|
# get KG info
|
||||||
stat = self.nh.get_stat()
|
if self.nh:
|
||||||
vertices_num, edges_num = stat['vertices'], stat['edges']
|
stat = self.nh.get_stat()
|
||||||
|
vertices_num, edges_num = stat['vertices'], stat['edges']
|
||||||
|
else:
|
||||||
|
vertices_num = 0
|
||||||
|
edges_num = 0
|
||||||
|
|
||||||
# get chroma info
|
# get chroma info
|
||||||
file_num = self.ch.count()['result']
|
file_num = self.ch.count()['result']
|
||||||
|
@ -95,7 +118,11 @@ class CodeBaseHandler:
|
||||||
@param codebase_name: name of codebase
|
@param codebase_name: name of codebase
|
||||||
@return:
|
@return:
|
||||||
'''
|
'''
|
||||||
self.nh.drop_space(space_name=codebase_name)
|
if self.nh:
|
||||||
|
self.nh.drop_space(space_name=codebase_name)
|
||||||
|
elif self.local_graph_file_path and os.path.isfile(self.local_graph_file_path):
|
||||||
|
os.remove(self.local_graph_file_path)
|
||||||
|
|
||||||
self.ch.delete_collection(collection_name=codebase_name)
|
self.ch.delete_collection(collection_name=codebase_name)
|
||||||
|
|
||||||
def crawl_code(self, zip_file=''):
|
def crawl_code(self, zip_file=''):
|
||||||
|
@ -124,9 +151,15 @@ class CodeBaseHandler:
|
||||||
@param search_type: ['cypher', 'graph', 'vector']
|
@param search_type: ['cypher', 'graph', 'vector']
|
||||||
@return:
|
@return:
|
||||||
'''
|
'''
|
||||||
assert search_type in ['cypher', 'tag', 'description']
|
if self.nh:
|
||||||
|
assert search_type in ['cypher', 'tag', 'description']
|
||||||
|
else:
|
||||||
|
if search_type == 'tag':
|
||||||
|
search_type = 'tag_by_local_graph'
|
||||||
|
assert search_type in ['tag_by_local_graph', 'description']
|
||||||
|
|
||||||
code_search = CodeSearch(llm_config=self.llm_config, nh=self.nh, ch=self.ch, limit=limit)
|
code_search = CodeSearch(llm_config=self.llm_config, nh=self.nh, ch=self.ch, limit=limit,
|
||||||
|
local_graph_file_path=self.local_graph_file_path)
|
||||||
|
|
||||||
if search_type == 'cypher':
|
if search_type == 'cypher':
|
||||||
search_res = code_search.search_by_cypher(query=query)
|
search_res = code_search.search_by_cypher(query=query)
|
||||||
|
@ -134,7 +167,11 @@ class CodeBaseHandler:
|
||||||
search_res = code_search.search_by_tag(query=query)
|
search_res = code_search.search_by_tag(query=query)
|
||||||
elif search_type == 'description':
|
elif search_type == 'description':
|
||||||
search_res = code_search.search_by_desciption(
|
search_res = code_search.search_by_desciption(
|
||||||
query=query, engine=self.embed_config.embed_engine, model_path=self.embed_config.embed_model_path, embedding_device=self.embed_config.model_device)
|
query=query, engine=self.embed_config.embed_engine, model_path=self.embed_config.embed_model_path,
|
||||||
|
embedding_device=self.embed_config.model_device, embed_config=self.embed_config)
|
||||||
|
elif search_type == 'tag_by_local_graph':
|
||||||
|
search_res = code_search.search_by_tag_by_graph(query=query)
|
||||||
|
|
||||||
|
|
||||||
context, related_vertice = self.format_search_res(search_res, search_type)
|
context, related_vertice = self.format_search_res(search_res, search_type)
|
||||||
return context, related_vertice
|
return context, related_vertice
|
||||||
|
@ -160,6 +197,12 @@ class CodeBaseHandler:
|
||||||
for code in search_res:
|
for code in search_res:
|
||||||
context = context + code['code_text'] + '\n'
|
context = context + code['code_text'] + '\n'
|
||||||
related_vertice.append(code['vertex'])
|
related_vertice.append(code['vertex'])
|
||||||
|
elif search_type == 'tag_by_local_graph':
|
||||||
|
context = ''
|
||||||
|
related_vertice = []
|
||||||
|
for code in search_res:
|
||||||
|
context = context + code['code_text'] + '\n'
|
||||||
|
related_vertice.append(code['vertex'])
|
||||||
elif search_type == 'description':
|
elif search_type == 'description':
|
||||||
context = ''
|
context = ''
|
||||||
related_vertice = []
|
related_vertice = []
|
||||||
|
@ -169,17 +212,63 @@ class CodeBaseHandler:
|
||||||
|
|
||||||
return context, related_vertice
|
return context, related_vertice
|
||||||
|
|
||||||
|
def search_vertices(self, vertex_type="class") -> List[str]:
|
||||||
|
'''
|
||||||
|
通过 method/class 来搜索所有的节点
|
||||||
|
'''
|
||||||
|
vertices = []
|
||||||
|
if self.nh:
|
||||||
|
vertices = self.nh.get_all_vertices()
|
||||||
|
vertices = [str(v.as_node().get_id()) for v in vertices["v"] if vertex_type in v.as_node().tags()]
|
||||||
|
# for v in vertices["v"]:
|
||||||
|
# logger.debug(f"{v.as_node().get_id()}, {v.as_node().tags()}")
|
||||||
|
else:
|
||||||
|
if vertex_type == "class":
|
||||||
|
vertices = [str(class_name) for code, structure in self.graph.items() for class_name in structure['class_name_list']]
|
||||||
|
elif vertex_type == "method":
|
||||||
|
vertices = [
|
||||||
|
str(methods_name)
|
||||||
|
for code, structure in self.graph.items()
|
||||||
|
for methods_names in structure['func_name_dict'].values()
|
||||||
|
for methods_name in methods_names
|
||||||
|
]
|
||||||
|
# logger.debug(vertices)
|
||||||
|
return vertices
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
codebase_name = 'testing'
|
from configs.model_config import KB_ROOT_PATH, JUPYTER_WORK_PATH
|
||||||
|
from configs.server_config import SANDBOX_SERVER
|
||||||
|
|
||||||
|
LLM_MODEL = "gpt-3.5-turbo"
|
||||||
|
llm_config = LLMConfig(
|
||||||
|
model_name=LLM_MODEL, model_device="cpu", api_key=os.environ["OPENAI_API_KEY"],
|
||||||
|
api_base_url=os.environ["API_BASE_URL"], temperature=0.3
|
||||||
|
)
|
||||||
|
src_dir = '/Users/bingxu/Desktop/工作/大模型/chatbot/Codefuse-chatbot-antcode'
|
||||||
|
embed_config = EmbedConfig(
|
||||||
|
embed_engine="model", embed_model="text2vec-base-chinese",
|
||||||
|
embed_model_path=os.path.join(src_dir, "embedding_models/text2vec-base-chinese")
|
||||||
|
)
|
||||||
|
|
||||||
|
codebase_name = 'client_local'
|
||||||
code_path = '/Users/bingxu/Desktop/工作/大模型/chatbot/test_code_repo/client'
|
code_path = '/Users/bingxu/Desktop/工作/大模型/chatbot/test_code_repo/client'
|
||||||
cbh = CodeBaseHandler(codebase_name, code_path, crawl_type='dir')
|
use_nh = False
|
||||||
|
local_graph_path = '/Users/bingxu/Desktop/工作/大模型/chatbot/Codefuse-chatbot-antcode/code_base'
|
||||||
|
CHROMA_PERSISTENT_PATH = '/Users/bingxu/Desktop/工作/大模型/chatbot/Codefuse-chatbot-antcode/data/chroma_data'
|
||||||
|
|
||||||
|
cbh = CodeBaseHandler(codebase_name, code_path, crawl_type='dir', use_nh=use_nh, local_graph_path=local_graph_path,
|
||||||
|
llm_config=llm_config, embed_config=embed_config)
|
||||||
|
|
||||||
|
# test import code
|
||||||
|
# cbh.import_code(do_interpret=True)
|
||||||
|
|
||||||
# query = '使用不同的HTTP请求类型(GET、POST、DELETE等)来执行不同的操作'
|
# query = '使用不同的HTTP请求类型(GET、POST、DELETE等)来执行不同的操作'
|
||||||
# query = '代码中一共有多少个类'
|
# query = '代码中一共有多少个类'
|
||||||
|
# query = 'remove 这个函数是用来做什么的'
|
||||||
|
query = '有没有函数是从字符串中删除指定字符串的功能'
|
||||||
|
|
||||||
query = 'intercept 函数作用是什么'
|
search_type = 'description'
|
||||||
search_type = 'graph'
|
|
||||||
limit = 2
|
limit = 2
|
||||||
res = cbh.search_code(query, search_type, limit)
|
res = cbh.search_code(query, search_type, limit)
|
||||||
logger.debug(res)
|
logger.debug(res)
|
||||||
|
|
|
@ -0,0 +1,6 @@
|
||||||
|
from .base_action import BaseAction
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"BaseAction"
|
||||||
|
]
|
|
@ -0,0 +1,16 @@
|
||||||
|
|
||||||
|
from langchain.schema import BaseRetriever, Document
|
||||||
|
|
||||||
|
class BaseAction:
|
||||||
|
|
||||||
|
|
||||||
|
def __init__(self, ):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def step(self, ):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def astep(self, ):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
|
@ -4,25 +4,25 @@ import re, os
|
||||||
import copy
|
import copy
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
|
from langchain.schema import BaseRetriever
|
||||||
|
|
||||||
from coagent.connector.schema import (
|
from coagent.connector.schema import (
|
||||||
Memory, Task, Role, Message, PromptField, LogVerboseEnum
|
Memory, Task, Role, Message, PromptField, LogVerboseEnum
|
||||||
)
|
)
|
||||||
from coagent.connector.memory_manager import BaseMemoryManager
|
from coagent.connector.memory_manager import BaseMemoryManager
|
||||||
from coagent.connector.configs.prompts import BEGIN_PROMPT_INPUT
|
|
||||||
from coagent.connector.message_process import MessageUtils
|
from coagent.connector.message_process import MessageUtils
|
||||||
from coagent.llm_models import getChatModel, getExtraModel, LLMConfig, getChatModelFromConfig, EmbedConfig
|
from coagent.llm_models import getExtraModel, LLMConfig, getChatModelFromConfig, EmbedConfig
|
||||||
from coagent.connector.prompt_manager import PromptManager
|
from coagent.connector.prompt_manager.prompt_manager import PromptManager
|
||||||
from coagent.connector.memory_manager import LocalMemoryManager
|
from coagent.connector.memory_manager import LocalMemoryManager
|
||||||
from coagent.connector.utils import parse_section
|
from coagent.base_configs.env_config import JUPYTER_WORK_PATH, KB_ROOT_PATH
|
||||||
# from configs.model_config import JUPYTER_WORK_PATH
|
|
||||||
# from configs.server_config import SANDBOX_SERVER
|
|
||||||
|
|
||||||
class BaseAgent:
|
class BaseAgent:
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
role: Role,
|
role: Role,
|
||||||
prompt_config: [PromptField],
|
prompt_config: List[PromptField],
|
||||||
prompt_manager_type: str = "PromptManager",
|
prompt_manager_type: str = "PromptManager",
|
||||||
task: Task = None,
|
task: Task = None,
|
||||||
memory: Memory = None,
|
memory: Memory = None,
|
||||||
|
@ -33,8 +33,11 @@ class BaseAgent:
|
||||||
llm_config: LLMConfig = None,
|
llm_config: LLMConfig = None,
|
||||||
embed_config: EmbedConfig = None,
|
embed_config: EmbedConfig = None,
|
||||||
sandbox_server: dict = {},
|
sandbox_server: dict = {},
|
||||||
jupyter_work_path: str = "",
|
jupyter_work_path: str = JUPYTER_WORK_PATH,
|
||||||
kb_root_path: str = "",
|
kb_root_path: str = KB_ROOT_PATH,
|
||||||
|
doc_retrieval: Union[BaseRetriever] = None,
|
||||||
|
code_retrieval = None,
|
||||||
|
search_retrieval = None,
|
||||||
log_verbose: str = "0"
|
log_verbose: str = "0"
|
||||||
):
|
):
|
||||||
|
|
||||||
|
@ -43,7 +46,7 @@ class BaseAgent:
|
||||||
self.sandbox_server = sandbox_server
|
self.sandbox_server = sandbox_server
|
||||||
self.jupyter_work_path = jupyter_work_path
|
self.jupyter_work_path = jupyter_work_path
|
||||||
self.kb_root_path = kb_root_path
|
self.kb_root_path = kb_root_path
|
||||||
self.message_utils = MessageUtils(role, sandbox_server, jupyter_work_path, embed_config, llm_config, kb_root_path, log_verbose)
|
self.message_utils = MessageUtils(role, sandbox_server, jupyter_work_path, embed_config, llm_config, kb_root_path, doc_retrieval, code_retrieval, search_retrieval, log_verbose)
|
||||||
self.memory = self.init_history(memory)
|
self.memory = self.init_history(memory)
|
||||||
self.llm_config: LLMConfig = llm_config
|
self.llm_config: LLMConfig = llm_config
|
||||||
self.embed_config: EmbedConfig = embed_config
|
self.embed_config: EmbedConfig = embed_config
|
||||||
|
@ -82,12 +85,8 @@ class BaseAgent:
|
||||||
llm_config=self.embed_config
|
llm_config=self.embed_config
|
||||||
)
|
)
|
||||||
memory_manager.append(query)
|
memory_manager.append(query)
|
||||||
memory_pool = memory_manager.current_memory
|
memory_pool = memory_manager.get_memory_pool(query.user_name)
|
||||||
else:
|
|
||||||
memory_pool = memory_manager.current_memory
|
|
||||||
|
|
||||||
|
|
||||||
logger.debug(f"memory_pool: {memory_pool}")
|
|
||||||
prompt = self.prompt_manager.generate_full_prompt(
|
prompt = self.prompt_manager.generate_full_prompt(
|
||||||
previous_agent_message=query_c, agent_long_term_memory=self.memory, ui_history=history, chain_summary_messages=background, memory_pool=memory_pool)
|
previous_agent_message=query_c, agent_long_term_memory=self.memory, ui_history=history, chain_summary_messages=background, memory_pool=memory_pool)
|
||||||
content = self.llm.predict(prompt)
|
content = self.llm.predict(prompt)
|
||||||
|
@ -99,6 +98,7 @@ class BaseAgent:
|
||||||
logger.info(f"{self.role.role_name} content: {content}")
|
logger.info(f"{self.role.role_name} content: {content}")
|
||||||
|
|
||||||
output_message = Message(
|
output_message = Message(
|
||||||
|
user_name=query.user_name,
|
||||||
role_name=self.role.role_name,
|
role_name=self.role.role_name,
|
||||||
role_type="assistant", #self.role.role_type,
|
role_type="assistant", #self.role.role_type,
|
||||||
role_content=content,
|
role_content=content,
|
||||||
|
@ -151,10 +151,7 @@ class BaseAgent:
|
||||||
self.memory = self.init_history()
|
self.memory = self.init_history()
|
||||||
|
|
||||||
def create_llm_engine(self, llm_config: LLMConfig = None, temperature=0.2, stop=None):
|
def create_llm_engine(self, llm_config: LLMConfig = None, temperature=0.2, stop=None):
|
||||||
if llm_config is None:
|
return getChatModelFromConfig(llm_config=llm_config)
|
||||||
return getChatModel(temperature=temperature, stop=stop)
|
|
||||||
else:
|
|
||||||
return getChatModelFromConfig(llm_config=llm_config)
|
|
||||||
|
|
||||||
def registry_actions(self, actions):
|
def registry_actions(self, actions):
|
||||||
'''registry llm's actions'''
|
'''registry llm's actions'''
|
||||||
|
@ -212,171 +209,3 @@ class BaseAgent:
|
||||||
|
|
||||||
def get_memory_str(self, content_key="role_content"):
|
def get_memory_str(self, content_key="role_content"):
|
||||||
return "\n".join([": ".join(i) for i in self.memory.to_tuple_messages(content_key="step_content")])
|
return "\n".join([": ".join(i) for i in self.memory.to_tuple_messages(content_key="step_content")])
|
||||||
|
|
||||||
|
|
||||||
def create_prompt(
|
|
||||||
self, query: Message, memory: Memory =None, history: Memory = None, background: Memory = None, memory_pool: Memory=None, prompt_mamnger=None) -> str:
|
|
||||||
'''
|
|
||||||
prompt engineer, contains role\task\tools\docs\memory
|
|
||||||
'''
|
|
||||||
#
|
|
||||||
doc_infos = self.create_doc_prompt(query)
|
|
||||||
code_infos = self.create_codedoc_prompt(query)
|
|
||||||
#
|
|
||||||
formatted_tools, tool_names, _ = self.create_tools_prompt(query)
|
|
||||||
task_prompt = self.create_task_prompt(query)
|
|
||||||
background_prompt = self.create_background_prompt(background, control_key="step_content")
|
|
||||||
history_prompt = self.create_history_prompt(history)
|
|
||||||
selfmemory_prompt = self.create_selfmemory_prompt(memory, control_key="step_content")
|
|
||||||
|
|
||||||
# extra_system_prompt = self.role.role_prompt
|
|
||||||
|
|
||||||
|
|
||||||
prompt = self.role.role_prompt.format(**{"formatted_tools": formatted_tools, "tool_names": tool_names})
|
|
||||||
#
|
|
||||||
memory_pool_select_by_agent_key = self.select_memory_by_agent_key(memory_pool)
|
|
||||||
memory_pool_select_by_agent_key_context = '\n\n'.join([f"*{k}*\n{v}" for parsed_output in memory_pool_select_by_agent_key.get_parserd_output_list() for k, v in parsed_output.items() if k not in ['Action Status']])
|
|
||||||
|
|
||||||
# input_query = query.input_query
|
|
||||||
|
|
||||||
# # logger.debug(f"{self.role.role_name} extra_system_prompt: {self.role.role_prompt}")
|
|
||||||
# # logger.debug(f"{self.role.role_name} input_query: {input_query}")
|
|
||||||
# # logger.debug(f"{self.role.role_name} doc_infos: {doc_infos}")
|
|
||||||
# # logger.debug(f"{self.role.role_name} tool_names: {tool_names}")
|
|
||||||
# if "**Context:**" in self.role.role_prompt:
|
|
||||||
# # logger.debug(f"parsed_output_list: {query.parsed_output_list}")
|
|
||||||
# # input_query = "'''" + "\n".join([f"###{k}###\n{v}" for i in query.parsed_output_list for k,v in i.items() if "Action Status" !=k]) + "'''"
|
|
||||||
# context = "\n".join([f"*{k}*\n{v}" for i in query.parsed_output_list for k,v in i.items() if "Action Status" !=k])
|
|
||||||
# # context = history_prompt or '""'
|
|
||||||
# # logger.debug(f"parsed_output_list: {t}")
|
|
||||||
# prompt += "\n" + QUERY_CONTEXT_PROMPT_INPUT.format(**{"context": context, "query": query.origin_query})
|
|
||||||
# else:
|
|
||||||
# prompt += "\n" + PLAN_PROMPT_INPUT.format(**{"query": input_query})
|
|
||||||
|
|
||||||
task = query.task or self.task
|
|
||||||
if task_prompt is not None:
|
|
||||||
prompt += "\n" + task.task_prompt
|
|
||||||
|
|
||||||
DocInfos = ""
|
|
||||||
if doc_infos is not None and doc_infos!="" and doc_infos!="不存在知识库辅助信息":
|
|
||||||
DocInfos += f"\nDocument Information: {doc_infos}"
|
|
||||||
|
|
||||||
if code_infos is not None and code_infos!="" and code_infos!="不存在代码库辅助信息":
|
|
||||||
DocInfos += f"\nCodeBase Infomation: {code_infos}"
|
|
||||||
|
|
||||||
# if selfmemory_prompt:
|
|
||||||
# prompt += "\n" + selfmemory_prompt
|
|
||||||
|
|
||||||
# if background_prompt:
|
|
||||||
# prompt += "\n" + background_prompt
|
|
||||||
|
|
||||||
# if history_prompt:
|
|
||||||
# prompt += "\n" + history_prompt
|
|
||||||
|
|
||||||
input_query = query.input_query
|
|
||||||
|
|
||||||
# logger.debug(f"{self.role.role_name} extra_system_prompt: {self.role.role_prompt}")
|
|
||||||
# logger.debug(f"{self.role.role_name} input_query: {input_query}")
|
|
||||||
# logger.debug(f"{self.role.role_name} doc_infos: {doc_infos}")
|
|
||||||
# logger.debug(f"{self.role.role_name} tool_names: {tool_names}")
|
|
||||||
|
|
||||||
# extra_system_prompt = self.role.role_prompt
|
|
||||||
input_keys = parse_section(self.role.role_prompt, 'Input Format')
|
|
||||||
prompt = self.role.role_prompt.format(**{"formatted_tools": formatted_tools, "tool_names": tool_names})
|
|
||||||
prompt += "\n" + BEGIN_PROMPT_INPUT
|
|
||||||
for input_key in input_keys:
|
|
||||||
if input_key == "Origin Query":
|
|
||||||
prompt += "\n**Origin Query:**\n" + query.origin_query
|
|
||||||
elif input_key == "Context":
|
|
||||||
context = "\n".join([f"*{k}*\n{v}" for i in query.parsed_output_list for k,v in i.items() if "Action Status" !=k])
|
|
||||||
if history:
|
|
||||||
context = history_prompt + "\n" + context
|
|
||||||
if not context:
|
|
||||||
context = "there is no context"
|
|
||||||
|
|
||||||
if self.focus_agents and memory_pool_select_by_agent_key_context:
|
|
||||||
context = memory_pool_select_by_agent_key_context
|
|
||||||
prompt += "\n**Context:**\n" + context + "\n" + input_query
|
|
||||||
elif input_key == "DocInfos":
|
|
||||||
if DocInfos:
|
|
||||||
prompt += "\n**DocInfos:**\n" + DocInfos
|
|
||||||
else:
|
|
||||||
prompt += "\n**DocInfos:**\n" + "Empty"
|
|
||||||
elif input_key == "Question":
|
|
||||||
prompt += "\n**Question:**\n" + input_query
|
|
||||||
|
|
||||||
# if "**Context:**" in self.role.role_prompt:
|
|
||||||
# # logger.debug(f"parsed_output_list: {query.parsed_output_list}")
|
|
||||||
# # input_query = "'''" + "\n".join([f"###{k}###\n{v}" for i in query.parsed_output_list for k,v in i.items() if "Action Status" !=k]) + "'''"
|
|
||||||
# context = "\n".join([f"*{k}*\n{v}" for i in query.parsed_output_list for k,v in i.items() if "Action Status" !=k])
|
|
||||||
# if history:
|
|
||||||
# context = history_prompt + "\n" + context
|
|
||||||
|
|
||||||
# if not context:
|
|
||||||
# context = "there is no context"
|
|
||||||
|
|
||||||
# # logger.debug(f"parsed_output_list: {t}")
|
|
||||||
# if "DocInfos" in prompt:
|
|
||||||
# prompt += "\n" + QUERY_CONTEXT_DOC_PROMPT_INPUT.format(**{"context": context, "query": query.origin_query, "DocInfos": DocInfos})
|
|
||||||
# else:
|
|
||||||
# prompt += "\n" + QUERY_CONTEXT_PROMPT_INPUT.format(**{"context": context, "query": query.origin_query, "DocInfos": DocInfos})
|
|
||||||
# else:
|
|
||||||
# prompt += "\n" + BASE_PROMPT_INPUT.format(**{"query": input_query})
|
|
||||||
|
|
||||||
# prompt = extra_system_prompt.format(**{"query": input_query, "doc_infos": doc_infos, "formatted_tools": formatted_tools, "tool_names": tool_names})
|
|
||||||
while "{{" in prompt or "}}" in prompt:
|
|
||||||
prompt = prompt.replace("{{", "{")
|
|
||||||
prompt = prompt.replace("}}", "}")
|
|
||||||
|
|
||||||
# logger.debug(f"{self.role.role_name} prompt: {prompt}")
|
|
||||||
return prompt
|
|
||||||
|
|
||||||
def create_doc_prompt(self, message: Message) -> str:
|
|
||||||
''''''
|
|
||||||
db_docs = message.db_docs
|
|
||||||
search_docs = message.search_docs
|
|
||||||
doc_infos = "\n".join([doc.get_snippet() for doc in db_docs] + [doc.get_snippet() for doc in search_docs])
|
|
||||||
return doc_infos or "不存在知识库辅助信息"
|
|
||||||
|
|
||||||
def create_codedoc_prompt(self, message: Message) -> str:
|
|
||||||
''''''
|
|
||||||
code_docs = message.code_docs
|
|
||||||
doc_infos = "\n".join([doc.get_code() for doc in code_docs])
|
|
||||||
return doc_infos or "不存在代码库辅助信息"
|
|
||||||
|
|
||||||
def create_tools_prompt(self, message: Message) -> str:
|
|
||||||
tools = message.tools
|
|
||||||
tool_strings = []
|
|
||||||
tools_descs = []
|
|
||||||
for tool in tools:
|
|
||||||
args_schema = re.sub("}", "}}}}", re.sub("{", "{{{{", str(tool.args)))
|
|
||||||
tool_strings.append(f"{tool.name}: {tool.description}, args: {args_schema}")
|
|
||||||
tools_descs.append(f"{tool.name}: {tool.description}")
|
|
||||||
formatted_tools = "\n".join(tool_strings)
|
|
||||||
tools_desc_str = "\n".join(tools_descs)
|
|
||||||
tool_names = ", ".join([tool.name for tool in tools])
|
|
||||||
return formatted_tools, tool_names, tools_desc_str
|
|
||||||
|
|
||||||
def create_task_prompt(self, message: Message) -> str:
|
|
||||||
task = message.task or self.task
|
|
||||||
return "\n任务目标: " + task.task_prompt if task is not None else None
|
|
||||||
|
|
||||||
def create_background_prompt(self, background: Memory, control_key="role_content") -> str:
|
|
||||||
background_message = None if background is None else background.to_str_messages(content_key=control_key)
|
|
||||||
# logger.debug(f"background_message: {background_message}")
|
|
||||||
if background_message:
|
|
||||||
background_message = re.sub("}", "}}", re.sub("{", "{{", background_message))
|
|
||||||
return "\n背景信息: " + background_message if background_message else None
|
|
||||||
|
|
||||||
def create_history_prompt(self, history: Memory, control_key="role_content") -> str:
|
|
||||||
history_message = None if history is None else history.to_str_messages(content_key=control_key)
|
|
||||||
if history_message:
|
|
||||||
history_message = re.sub("}", "}}", re.sub("{", "{{", history_message))
|
|
||||||
return "\n补充对话信息: " + history_message if history_message else None
|
|
||||||
|
|
||||||
def create_selfmemory_prompt(self, selfmemory: Memory, control_key="role_content") -> str:
|
|
||||||
selfmemory_message = None if selfmemory is None else selfmemory.to_str_messages(content_key=control_key)
|
|
||||||
if selfmemory_message:
|
|
||||||
selfmemory_message = re.sub("}", "}}", re.sub("{", "{{", selfmemory_message))
|
|
||||||
return "\n补充自身对话信息: " + selfmemory_message if selfmemory_message else None
|
|
||||||
|
|
|
@ -2,14 +2,15 @@ from typing import List, Union
|
||||||
import copy
|
import copy
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
|
from langchain.schema import BaseRetriever
|
||||||
|
|
||||||
from coagent.connector.schema import (
|
from coagent.connector.schema import (
|
||||||
Memory, Task, Env, Role, Message, ActionStatus, PromptField, LogVerboseEnum
|
Memory, Task, Env, Role, Message, ActionStatus, PromptField, LogVerboseEnum
|
||||||
)
|
)
|
||||||
from coagent.connector.memory_manager import BaseMemoryManager
|
from coagent.connector.memory_manager import BaseMemoryManager
|
||||||
from coagent.connector.configs.prompts import BEGIN_PROMPT_INPUT
|
|
||||||
from coagent.llm_models import LLMConfig, EmbedConfig
|
from coagent.llm_models import LLMConfig, EmbedConfig
|
||||||
from coagent.connector.memory_manager import LocalMemoryManager
|
from coagent.connector.memory_manager import LocalMemoryManager
|
||||||
|
from coagent.base_configs.env_config import JUPYTER_WORK_PATH, KB_ROOT_PATH
|
||||||
from .base_agent import BaseAgent
|
from .base_agent import BaseAgent
|
||||||
|
|
||||||
|
|
||||||
|
@ -17,7 +18,7 @@ class ExecutorAgent(BaseAgent):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
role: Role,
|
role: Role,
|
||||||
prompt_config: [PromptField],
|
prompt_config: List[PromptField],
|
||||||
prompt_manager_type: str= "PromptManager",
|
prompt_manager_type: str= "PromptManager",
|
||||||
task: Task = None,
|
task: Task = None,
|
||||||
memory: Memory = None,
|
memory: Memory = None,
|
||||||
|
@ -28,14 +29,17 @@ class ExecutorAgent(BaseAgent):
|
||||||
llm_config: LLMConfig = None,
|
llm_config: LLMConfig = None,
|
||||||
embed_config: EmbedConfig = None,
|
embed_config: EmbedConfig = None,
|
||||||
sandbox_server: dict = {},
|
sandbox_server: dict = {},
|
||||||
jupyter_work_path: str = "",
|
jupyter_work_path: str = JUPYTER_WORK_PATH,
|
||||||
kb_root_path: str = "",
|
kb_root_path: str = KB_ROOT_PATH,
|
||||||
|
doc_retrieval: Union[BaseRetriever] = None,
|
||||||
|
code_retrieval = None,
|
||||||
|
search_retrieval = None,
|
||||||
log_verbose: str = "0"
|
log_verbose: str = "0"
|
||||||
):
|
):
|
||||||
|
|
||||||
super().__init__(role, prompt_config, prompt_manager_type, task, memory, chat_turn,
|
super().__init__(role, prompt_config, prompt_manager_type, task, memory, chat_turn,
|
||||||
focus_agents, focus_message_keys, llm_config, embed_config, sandbox_server,
|
focus_agents, focus_message_keys, llm_config, embed_config, sandbox_server,
|
||||||
jupyter_work_path, kb_root_path, log_verbose
|
jupyter_work_path, kb_root_path, doc_retrieval, code_retrieval, search_retrieval, log_verbose
|
||||||
)
|
)
|
||||||
self.do_all_task = True # run all tasks
|
self.do_all_task = True # run all tasks
|
||||||
|
|
||||||
|
@ -45,6 +49,7 @@ class ExecutorAgent(BaseAgent):
|
||||||
task_executor_memory = Memory(messages=[])
|
task_executor_memory = Memory(messages=[])
|
||||||
# insert query
|
# insert query
|
||||||
output_message = Message(
|
output_message = Message(
|
||||||
|
user_name=query.user_name,
|
||||||
role_name=self.role.role_name,
|
role_name=self.role.role_name,
|
||||||
role_type="assistant", #self.role.role_type,
|
role_type="assistant", #self.role.role_type,
|
||||||
role_content=query.input_query,
|
role_content=query.input_query,
|
||||||
|
@ -115,7 +120,7 @@ class ExecutorAgent(BaseAgent):
|
||||||
history: Memory, background: Memory, memory_manager: BaseMemoryManager,
|
history: Memory, background: Memory, memory_manager: BaseMemoryManager,
|
||||||
task_memory: Memory) -> Union[Message, Memory]:
|
task_memory: Memory) -> Union[Message, Memory]:
|
||||||
'''execute the llm predict by created prompt'''
|
'''execute the llm predict by created prompt'''
|
||||||
memory_pool = memory_manager.current_memory
|
memory_pool = memory_manager.get_memory_pool(query.user_name)
|
||||||
prompt = self.prompt_manager.generate_full_prompt(
|
prompt = self.prompt_manager.generate_full_prompt(
|
||||||
previous_agent_message=query, agent_long_term_memory=self_memory, ui_history=history, chain_summary_messages=background, memory_pool=memory_pool,
|
previous_agent_message=query, agent_long_term_memory=self_memory, ui_history=history, chain_summary_messages=background, memory_pool=memory_pool,
|
||||||
task_memory=task_memory)
|
task_memory=task_memory)
|
||||||
|
|
|
@ -3,23 +3,23 @@ import traceback
|
||||||
import copy
|
import copy
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
|
from langchain.schema import BaseRetriever
|
||||||
|
|
||||||
from coagent.connector.schema import (
|
from coagent.connector.schema import (
|
||||||
Memory, Task, Env, Role, Message, ActionStatus, PromptField, LogVerboseEnum
|
Memory, Task, Env, Role, Message, ActionStatus, PromptField, LogVerboseEnum
|
||||||
)
|
)
|
||||||
from coagent.connector.memory_manager import BaseMemoryManager
|
from coagent.connector.memory_manager import BaseMemoryManager
|
||||||
from coagent.connector.configs.agent_config import REACT_PROMPT_INPUT
|
|
||||||
from coagent.llm_models import LLMConfig, EmbedConfig
|
from coagent.llm_models import LLMConfig, EmbedConfig
|
||||||
from .base_agent import BaseAgent
|
from .base_agent import BaseAgent
|
||||||
from coagent.connector.memory_manager import LocalMemoryManager
|
from coagent.connector.memory_manager import LocalMemoryManager
|
||||||
|
from coagent.base_configs.env_config import JUPYTER_WORK_PATH, KB_ROOT_PATH
|
||||||
from coagent.connector.prompt_manager import PromptManager
|
|
||||||
|
|
||||||
|
|
||||||
class ReactAgent(BaseAgent):
|
class ReactAgent(BaseAgent):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
role: Role,
|
role: Role,
|
||||||
prompt_config: [PromptField],
|
prompt_config: List[PromptField],
|
||||||
prompt_manager_type: str = "PromptManager",
|
prompt_manager_type: str = "PromptManager",
|
||||||
task: Task = None,
|
task: Task = None,
|
||||||
memory: Memory = None,
|
memory: Memory = None,
|
||||||
|
@ -30,14 +30,17 @@ class ReactAgent(BaseAgent):
|
||||||
llm_config: LLMConfig = None,
|
llm_config: LLMConfig = None,
|
||||||
embed_config: EmbedConfig = None,
|
embed_config: EmbedConfig = None,
|
||||||
sandbox_server: dict = {},
|
sandbox_server: dict = {},
|
||||||
jupyter_work_path: str = "",
|
jupyter_work_path: str = JUPYTER_WORK_PATH,
|
||||||
kb_root_path: str = "",
|
kb_root_path: str = KB_ROOT_PATH,
|
||||||
|
doc_retrieval: Union[BaseRetriever] = None,
|
||||||
|
code_retrieval = None,
|
||||||
|
search_retrieval = None,
|
||||||
log_verbose: str = "0"
|
log_verbose: str = "0"
|
||||||
):
|
):
|
||||||
|
|
||||||
super().__init__(role, prompt_config, prompt_manager_type, task, memory, chat_turn,
|
super().__init__(role, prompt_config, prompt_manager_type, task, memory, chat_turn,
|
||||||
focus_agents, focus_message_keys, llm_config, embed_config, sandbox_server,
|
focus_agents, focus_message_keys, llm_config, embed_config, sandbox_server,
|
||||||
jupyter_work_path, kb_root_path, log_verbose
|
jupyter_work_path, kb_root_path, doc_retrieval, code_retrieval, search_retrieval, log_verbose
|
||||||
)
|
)
|
||||||
|
|
||||||
def step(self, query: Message, history: Memory = None, background: Memory = None, memory_manager: BaseMemoryManager = None) -> Message:
|
def step(self, query: Message, history: Memory = None, background: Memory = None, memory_manager: BaseMemoryManager = None) -> Message:
|
||||||
|
@ -52,6 +55,7 @@ class ReactAgent(BaseAgent):
|
||||||
react_memory = Memory(messages=[])
|
react_memory = Memory(messages=[])
|
||||||
# insert query
|
# insert query
|
||||||
output_message = Message(
|
output_message = Message(
|
||||||
|
user_name=query.user_name,
|
||||||
role_name=self.role.role_name,
|
role_name=self.role.role_name,
|
||||||
role_type="assistant", #self.role.role_type,
|
role_type="assistant", #self.role.role_type,
|
||||||
role_content=query.input_query,
|
role_content=query.input_query,
|
||||||
|
@ -84,9 +88,7 @@ class ReactAgent(BaseAgent):
|
||||||
llm_config=self.embed_config
|
llm_config=self.embed_config
|
||||||
)
|
)
|
||||||
memory_manager.append(query)
|
memory_manager.append(query)
|
||||||
memory_pool = memory_manager.current_memory
|
memory_pool = memory_manager.get_memory_pool(query_c.user_name)
|
||||||
else:
|
|
||||||
memory_pool = memory_manager.current_memory
|
|
||||||
|
|
||||||
prompt = self.prompt_manager.generate_full_prompt(
|
prompt = self.prompt_manager.generate_full_prompt(
|
||||||
previous_agent_message=query_c, agent_long_term_memory=self.memory, ui_history=history, chain_summary_messages=background, react_memory=react_memory,
|
previous_agent_message=query_c, agent_long_term_memory=self.memory, ui_history=history, chain_summary_messages=background, react_memory=react_memory,
|
||||||
|
@ -142,82 +144,4 @@ class ReactAgent(BaseAgent):
|
||||||
title = f"<<<<{self.role.role_name}'s prompt>>>>"
|
title = f"<<<<{self.role.role_name}'s prompt>>>>"
|
||||||
print("#"*len(title) + f"\n{title}\n"+ "#"*len(title)+ f"\n\n{prompt}\n\n")
|
print("#"*len(title) + f"\n{title}\n"+ "#"*len(title)+ f"\n\n{prompt}\n\n")
|
||||||
|
|
||||||
# def create_prompt(
|
|
||||||
# self, query: Message, memory: Memory =None, history: Memory = None, background: Memory = None, react_memory: Memory = None, memory_manager: BaseMemoryManager= None,
|
|
||||||
# prompt_mamnger=None) -> str:
|
|
||||||
# prompt_mamnger = PromptManager()
|
|
||||||
# prompt_mamnger.register_standard_fields()
|
|
||||||
|
|
||||||
# # input_keys = parse_section(self.role.role_prompt, 'Agent Profile')
|
|
||||||
|
|
||||||
# data_dict = {
|
|
||||||
# "agent_profile": extract_section(self.role.role_prompt, 'Agent Profile'),
|
|
||||||
# "tool_information": query.tools,
|
|
||||||
# "session_records": memory_manager,
|
|
||||||
# "reference_documents": query,
|
|
||||||
# "output_format": extract_section(self.role.role_prompt, 'Response Output Format'),
|
|
||||||
# "response": "\n".join(["\n".join([f"**{k}:**\n{v}" for k,v in _dict.items()]) for _dict in react_memory.get_parserd_output()]),
|
|
||||||
# }
|
|
||||||
# # logger.debug(memory_pool)
|
|
||||||
|
|
||||||
# return prompt_mamnger.generate_full_prompt(data_dict)
|
|
||||||
|
|
||||||
def create_prompt(
|
|
||||||
self, query: Message, memory: Memory =None, history: Memory = None, background: Memory = None, react_memory: Memory = None, memory_pool: Memory= None,
|
|
||||||
prompt_mamnger=None) -> str:
|
|
||||||
'''
|
|
||||||
role\task\tools\docs\memory
|
|
||||||
'''
|
|
||||||
#
|
|
||||||
doc_infos = self.create_doc_prompt(query)
|
|
||||||
code_infos = self.create_codedoc_prompt(query)
|
|
||||||
#
|
|
||||||
formatted_tools, tool_names, _ = self.create_tools_prompt(query)
|
|
||||||
task_prompt = self.create_task_prompt(query)
|
|
||||||
background_prompt = self.create_background_prompt(background)
|
|
||||||
history_prompt = self.create_history_prompt(history)
|
|
||||||
selfmemory_prompt = self.create_selfmemory_prompt(memory, control_key="step_content")
|
|
||||||
#
|
|
||||||
# extra_system_prompt = self.role.role_prompt
|
|
||||||
prompt = self.role.role_prompt.format(**{"formatted_tools": formatted_tools, "tool_names": tool_names})
|
|
||||||
|
|
||||||
# react 流程是自身迭代过程,另外二次触发的是需要作为历史对话信息
|
|
||||||
# input_query = react_memory.to_tuple_messages(content_key="step_content")
|
|
||||||
# # input_query = query.input_query + "\n" + "\n".join([f"{v}" for k, v in input_query if v])
|
|
||||||
# input_query = "\n".join([f"{v}" for k, v in input_query if v])
|
|
||||||
input_query = "\n".join(["\n".join([f"**{k}:**\n{v}" for k,v in _dict.items()]) for _dict in react_memory.get_parserd_output()])
|
|
||||||
# logger.debug(f"input_query: {input_query}")
|
|
||||||
|
|
||||||
prompt += "\n" + REACT_PROMPT_INPUT.format(**{"query": input_query})
|
|
||||||
|
|
||||||
task = query.task or self.task
|
|
||||||
# if task_prompt is not None:
|
|
||||||
# prompt += "\n" + task.task_prompt
|
|
||||||
|
|
||||||
# if doc_infos is not None and doc_infos!="" and doc_infos!="不存在知识库辅助信息":
|
|
||||||
# prompt += f"\n知识库信息: {doc_infos}"
|
|
||||||
|
|
||||||
# if code_infos is not None and code_infos!="" and code_infos!="不存在代码库辅助信息":
|
|
||||||
# prompt += f"\n代码库信息: {code_infos}"
|
|
||||||
|
|
||||||
# if background_prompt:
|
|
||||||
# prompt += "\n" + background_prompt
|
|
||||||
|
|
||||||
# if history_prompt:
|
|
||||||
# prompt += "\n" + history_prompt
|
|
||||||
|
|
||||||
# if selfmemory_prompt:
|
|
||||||
# prompt += "\n" + selfmemory_prompt
|
|
||||||
|
|
||||||
# logger.debug(f"{self.role.role_name} extra_system_prompt: {self.role.role_prompt}")
|
|
||||||
# logger.debug(f"{self.role.role_name} input_query: {input_query}")
|
|
||||||
# logger.debug(f"{self.role.role_name} doc_infos: {doc_infos}")
|
|
||||||
# logger.debug(f"{self.role.role_name} tool_names: {tool_names}")
|
|
||||||
# prompt += "\n" + REACT_PROMPT_INPUT.format(**{"query": input_query})
|
|
||||||
|
|
||||||
# prompt = extra_system_prompt.format(**{"query": input_query, "doc_infos": doc_infos, "formatted_tools": formatted_tools, "tool_names": tool_names})
|
|
||||||
while "{{" in prompt or "}}" in prompt:
|
|
||||||
prompt = prompt.replace("{{", "{")
|
|
||||||
prompt = prompt.replace("}}", "}")
|
|
||||||
return prompt
|
|
||||||
|
|
||||||
|
|
|
@ -3,13 +3,15 @@ import copy
|
||||||
import random
|
import random
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
|
from langchain.schema import BaseRetriever
|
||||||
|
|
||||||
from coagent.connector.schema import (
|
from coagent.connector.schema import (
|
||||||
Memory, Task, Role, Message, PromptField, LogVerboseEnum
|
Memory, Task, Role, Message, PromptField, LogVerboseEnum
|
||||||
)
|
)
|
||||||
from coagent.connector.memory_manager import BaseMemoryManager
|
from coagent.connector.memory_manager import BaseMemoryManager
|
||||||
from coagent.connector.configs.prompts import BEGIN_PROMPT_INPUT
|
|
||||||
from coagent.connector.memory_manager import LocalMemoryManager
|
from coagent.connector.memory_manager import LocalMemoryManager
|
||||||
from coagent.llm_models import LLMConfig, EmbedConfig
|
from coagent.llm_models import LLMConfig, EmbedConfig
|
||||||
|
from coagent.base_configs.env_config import JUPYTER_WORK_PATH, KB_ROOT_PATH
|
||||||
from .base_agent import BaseAgent
|
from .base_agent import BaseAgent
|
||||||
|
|
||||||
|
|
||||||
|
@ -30,14 +32,17 @@ class SelectorAgent(BaseAgent):
|
||||||
llm_config: LLMConfig = None,
|
llm_config: LLMConfig = None,
|
||||||
embed_config: EmbedConfig = None,
|
embed_config: EmbedConfig = None,
|
||||||
sandbox_server: dict = {},
|
sandbox_server: dict = {},
|
||||||
jupyter_work_path: str = "",
|
jupyter_work_path: str = JUPYTER_WORK_PATH,
|
||||||
kb_root_path: str = "",
|
kb_root_path: str = KB_ROOT_PATH,
|
||||||
|
doc_retrieval: Union[BaseRetriever] = None,
|
||||||
|
code_retrieval = None,
|
||||||
|
search_retrieval = None,
|
||||||
log_verbose: str = "0"
|
log_verbose: str = "0"
|
||||||
):
|
):
|
||||||
|
|
||||||
super().__init__(role, prompt_config, prompt_manager_type, task, memory, chat_turn,
|
super().__init__(role, prompt_config, prompt_manager_type, task, memory, chat_turn,
|
||||||
focus_agents, focus_message_keys, llm_config, embed_config, sandbox_server,
|
focus_agents, focus_message_keys, llm_config, embed_config, sandbox_server,
|
||||||
jupyter_work_path, kb_root_path, log_verbose
|
jupyter_work_path, kb_root_path, doc_retrieval, code_retrieval, search_retrieval, log_verbose
|
||||||
)
|
)
|
||||||
self.group_agents = group_agents
|
self.group_agents = group_agents
|
||||||
|
|
||||||
|
@ -56,9 +61,8 @@ class SelectorAgent(BaseAgent):
|
||||||
llm_config=self.embed_config
|
llm_config=self.embed_config
|
||||||
)
|
)
|
||||||
memory_manager.append(query)
|
memory_manager.append(query)
|
||||||
memory_pool = memory_manager.current_memory
|
memory_pool = memory_manager.get_memory_pool(query_c.user_name)
|
||||||
else:
|
|
||||||
memory_pool = memory_manager.current_memory
|
|
||||||
prompt = self.prompt_manager.generate_full_prompt(
|
prompt = self.prompt_manager.generate_full_prompt(
|
||||||
previous_agent_message=query_c, agent_long_term_memory=self.memory, ui_history=history, chain_summary_messages=background, react_memory=None,
|
previous_agent_message=query_c, agent_long_term_memory=self.memory, ui_history=history, chain_summary_messages=background, react_memory=None,
|
||||||
memory_pool=memory_pool, agents=self.group_agents)
|
memory_pool=memory_pool, agents=self.group_agents)
|
||||||
|
@ -90,6 +94,9 @@ class SelectorAgent(BaseAgent):
|
||||||
for agent in self.group_agents:
|
for agent in self.group_agents:
|
||||||
if agent.role.role_name == select_message.parsed_output.get("Role", ""):
|
if agent.role.role_name == select_message.parsed_output.get("Role", ""):
|
||||||
break
|
break
|
||||||
|
|
||||||
|
# 把除了role以外的信息传给下一个agent
|
||||||
|
query_c.parsed_output.update({k:v for k,v in select_message.parsed_output.items() if k!="Role"})
|
||||||
for output_message in agent.astep(query_c, history, background=background, memory_manager=memory_manager):
|
for output_message in agent.astep(query_c, history, background=background, memory_manager=memory_manager):
|
||||||
yield output_message or select_message
|
yield output_message or select_message
|
||||||
# update self_memory
|
# update self_memory
|
||||||
|
@ -103,6 +110,7 @@ class SelectorAgent(BaseAgent):
|
||||||
memory_manager.append(output_message)
|
memory_manager.append(output_message)
|
||||||
|
|
||||||
select_message.parsed_output = output_message.parsed_output
|
select_message.parsed_output = output_message.parsed_output
|
||||||
|
select_message.spec_parsed_output.update(output_message.spec_parsed_output)
|
||||||
select_message.parsed_output_list.extend(output_message.parsed_output_list)
|
select_message.parsed_output_list.extend(output_message.parsed_output_list)
|
||||||
yield select_message
|
yield select_message
|
||||||
|
|
||||||
|
@ -114,77 +122,4 @@ class SelectorAgent(BaseAgent):
|
||||||
print("#"*len(title) + f"\n{title}\n"+ "#"*len(title)+ f"\n\n{prompt}\n\n")
|
print("#"*len(title) + f"\n{title}\n"+ "#"*len(title)+ f"\n\n{prompt}\n\n")
|
||||||
|
|
||||||
for agent in self.group_agents:
|
for agent in self.group_agents:
|
||||||
agent.pre_print(query=query, history=history, background=background, memory_manager=memory_manager)
|
agent.pre_print(query=query, history=history, background=background, memory_manager=memory_manager)
|
||||||
|
|
||||||
# def create_prompt(
|
|
||||||
# self, query: Message, memory: Memory =None, history: Memory = None, background: Memory = None, memory_manager: BaseMemoryManager=None, prompt_mamnger=None) -> str:
|
|
||||||
# '''
|
|
||||||
# role\task\tools\docs\memory
|
|
||||||
# '''
|
|
||||||
# #
|
|
||||||
# doc_infos = self.create_doc_prompt(query)
|
|
||||||
# code_infos = self.create_codedoc_prompt(query)
|
|
||||||
# #
|
|
||||||
# formatted_tools, tool_names, tools_descs = self.create_tools_prompt(query)
|
|
||||||
# agent_names, agents = self.create_agent_names()
|
|
||||||
# task_prompt = self.create_task_prompt(query)
|
|
||||||
# background_prompt = self.create_background_prompt(background)
|
|
||||||
# history_prompt = self.create_history_prompt(history)
|
|
||||||
# selfmemory_prompt = self.create_selfmemory_prompt(memory, control_key="step_content")
|
|
||||||
|
|
||||||
|
|
||||||
# DocInfos = ""
|
|
||||||
# if doc_infos is not None and doc_infos!="" and doc_infos!="不存在知识库辅助信息":
|
|
||||||
# DocInfos += f"\nDocument Information: {doc_infos}"
|
|
||||||
|
|
||||||
# if code_infos is not None and code_infos!="" and code_infos!="不存在代码库辅助信息":
|
|
||||||
# DocInfos += f"\nCodeBase Infomation: {code_infos}"
|
|
||||||
|
|
||||||
# input_query = query.input_query
|
|
||||||
# logger.debug(f"{self.role.role_name} input_query: {input_query}")
|
|
||||||
# prompt = self.role.role_prompt.format(**{"agent_names": agent_names, "agents": agents, "formatted_tools": tools_descs, "tool_names": tool_names})
|
|
||||||
# #
|
|
||||||
# memory_pool_select_by_agent_key = self.select_memory_by_agent_key(memory_manager.current_memory)
|
|
||||||
# memory_pool_select_by_agent_key_context = '\n\n'.join([f"*{k}*\n{v}" for parsed_output in memory_pool_select_by_agent_key.get_parserd_output_list() for k, v in parsed_output.items() if k not in ['Action Status']])
|
|
||||||
|
|
||||||
# input_keys = parse_section(self.role.role_prompt, 'Input Format')
|
|
||||||
# #
|
|
||||||
# prompt += "\n" + BEGIN_PROMPT_INPUT
|
|
||||||
# for input_key in input_keys:
|
|
||||||
# if input_key == "Origin Query":
|
|
||||||
# prompt += "\n**Origin Query:**\n" + query.origin_query
|
|
||||||
# elif input_key == "Context":
|
|
||||||
# context = "\n".join([f"*{k}*\n{v}" for i in query.parsed_output_list for k,v in i.items() if "Action Status" !=k])
|
|
||||||
# if history:
|
|
||||||
# context = history_prompt + "\n" + context
|
|
||||||
# if not context:
|
|
||||||
# context = "there is no context"
|
|
||||||
|
|
||||||
# if self.focus_agents and memory_pool_select_by_agent_key_context:
|
|
||||||
# context = memory_pool_select_by_agent_key_context
|
|
||||||
# prompt += "\n**Context:**\n" + context + "\n" + input_query
|
|
||||||
# elif input_key == "DocInfos":
|
|
||||||
# prompt += "\n**DocInfos:**\n" + DocInfos
|
|
||||||
# elif input_key == "Question":
|
|
||||||
# prompt += "\n**Question:**\n" + input_query
|
|
||||||
|
|
||||||
# while "{{" in prompt or "}}" in prompt:
|
|
||||||
# prompt = prompt.replace("{{", "{")
|
|
||||||
# prompt = prompt.replace("}}", "}")
|
|
||||||
|
|
||||||
# # logger.debug(f"{self.role.role_name} prompt: {prompt}")
|
|
||||||
# return prompt
|
|
||||||
|
|
||||||
# def create_agent_names(self):
|
|
||||||
# random.shuffle(self.group_agents)
|
|
||||||
# agent_names = ", ".join([f'{agent.role.role_name}' for agent in self.group_agents])
|
|
||||||
# agent_descs = []
|
|
||||||
# for agent in self.group_agents:
|
|
||||||
# role_desc = agent.role.role_prompt.split("####")[1]
|
|
||||||
# while "\n\n" in role_desc:
|
|
||||||
# role_desc = role_desc.replace("\n\n", "\n")
|
|
||||||
# role_desc = role_desc.replace("\n", ",")
|
|
||||||
|
|
||||||
# agent_descs.append(f'"role name: {agent.role.role_name}\nrole description: {role_desc}"')
|
|
||||||
|
|
||||||
# return agent_names, "\n".join(agent_descs)
|
|
|
@ -0,0 +1,7 @@
|
||||||
|
from .flow import AgentFlow, PhaseFlow, ChainFlow
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"AgentFlow", "PhaseFlow", "ChainFlow"
|
||||||
|
]
|
|
@ -0,0 +1,255 @@
|
||||||
|
import importlib
|
||||||
|
from typing import List, Union, Dict, Any
|
||||||
|
from loguru import logger
|
||||||
|
import os
|
||||||
|
from langchain.embeddings.base import Embeddings
|
||||||
|
from langchain.agents import Tool
|
||||||
|
from langchain.llms.base import BaseLLM, LLM
|
||||||
|
|
||||||
|
from coagent.retrieval.base_retrieval import IMRertrieval
|
||||||
|
from coagent.llm_models.llm_config import EmbedConfig, LLMConfig
|
||||||
|
from coagent.connector.phase import BasePhase
|
||||||
|
from coagent.connector.agents import BaseAgent
|
||||||
|
from coagent.connector.chains import BaseChain
|
||||||
|
from coagent.connector.schema import Message, Role, PromptField, ChainConfig
|
||||||
|
from coagent.tools import toLangchainTools, TOOL_DICT, TOOL_SETS
|
||||||
|
|
||||||
|
|
||||||
|
class AgentFlow:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
role_name: str,
|
||||||
|
agent_type: str,
|
||||||
|
role_type: str = "assistant",
|
||||||
|
agent_index: int = 0,
|
||||||
|
role_prompt: str = "",
|
||||||
|
prompt_config: List[Dict[str, Any]] = [],
|
||||||
|
prompt_manager_type: str = "PromptManager",
|
||||||
|
chat_turn: int = 3,
|
||||||
|
focus_agents: List[str] = [],
|
||||||
|
focus_messages: List[str] = [],
|
||||||
|
embeddings: Embeddings = None,
|
||||||
|
llm: BaseLLM = None,
|
||||||
|
doc_retrieval: IMRertrieval = None,
|
||||||
|
code_retrieval: IMRertrieval = None,
|
||||||
|
search_retrieval: IMRertrieval = None,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
self.role_type = role_type
|
||||||
|
self.role_name = role_name
|
||||||
|
self.agent_type = agent_type
|
||||||
|
self.role_prompt = role_prompt
|
||||||
|
self.agent_index = agent_index
|
||||||
|
|
||||||
|
self.prompt_config = prompt_config
|
||||||
|
self.prompt_manager_type = prompt_manager_type
|
||||||
|
|
||||||
|
self.chat_turn = chat_turn
|
||||||
|
self.focus_agents = focus_agents
|
||||||
|
self.focus_messages = focus_messages
|
||||||
|
|
||||||
|
self.embeddings = embeddings
|
||||||
|
self.llm = llm
|
||||||
|
self.doc_retrieval = doc_retrieval
|
||||||
|
self.code_retrieval = code_retrieval
|
||||||
|
self.search_retrieval = search_retrieval
|
||||||
|
# self.build_config()
|
||||||
|
# self.build_agent()
|
||||||
|
|
||||||
|
def build_config(self, embeddings: Embeddings = None, llm: BaseLLM = None):
|
||||||
|
self.llm_config = LLMConfig(model_name="test", llm=self.llm or llm)
|
||||||
|
self.embed_config = EmbedConfig(embed_model="test", langchain_embeddings=self.embeddings or embeddings)
|
||||||
|
|
||||||
|
def build_agent(self,
|
||||||
|
embeddings: Embeddings = None, llm: BaseLLM = None,
|
||||||
|
doc_retrieval: IMRertrieval = None,
|
||||||
|
code_retrieval: IMRertrieval = None,
|
||||||
|
search_retrieval: IMRertrieval = None,
|
||||||
|
):
|
||||||
|
# 可注册个性化的agent,仅通过start_action和end_action来注册
|
||||||
|
# class ExtraAgent(BaseAgent):
|
||||||
|
# def start_action_step(self, message: Message) -> Message:
|
||||||
|
# pass
|
||||||
|
|
||||||
|
# def end_action_step(self, message: Message) -> Message:
|
||||||
|
# pass
|
||||||
|
# agent_module = importlib.import_module("coagent.connector.agents")
|
||||||
|
# setattr(agent_module, 'extraAgent', ExtraAgent)
|
||||||
|
|
||||||
|
# 可注册个性化的prompt组装方式,
|
||||||
|
# class CodeRetrievalPM(PromptManager):
|
||||||
|
# def handle_code_packages(self, **kwargs) -> str:
|
||||||
|
# if 'previous_agent_message' not in kwargs:
|
||||||
|
# return ""
|
||||||
|
# previous_agent_message: Message = kwargs['previous_agent_message']
|
||||||
|
# # 由于两个agent共用了同一个manager,所以临时性处理
|
||||||
|
# vertices = previous_agent_message.customed_kargs.get("RelatedVerticesRetrivalRes", {}).get("vertices", [])
|
||||||
|
# return ", ".join([str(v) for v in vertices])
|
||||||
|
|
||||||
|
# prompt_manager_module = importlib.import_module("coagent.connector.prompt_manager")
|
||||||
|
# setattr(prompt_manager_module, 'CodeRetrievalPM', CodeRetrievalPM)
|
||||||
|
|
||||||
|
# agent实例化
|
||||||
|
agent_module = importlib.import_module("coagent.connector.agents")
|
||||||
|
baseAgent: BaseAgent = getattr(agent_module, self.agent_type)
|
||||||
|
role = Role(
|
||||||
|
role_type=self.agent_type, role_name=self.role_name,
|
||||||
|
agent_type=self.agent_type, role_prompt=self.role_prompt,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.build_config(embeddings, llm)
|
||||||
|
self.agent = baseAgent(
|
||||||
|
role=role,
|
||||||
|
prompt_config = [PromptField(**config) for config in self.prompt_config],
|
||||||
|
prompt_manager_type=self.prompt_manager_type,
|
||||||
|
chat_turn=self.chat_turn,
|
||||||
|
focus_agents=self.focus_agents,
|
||||||
|
focus_message_keys=self.focus_messages,
|
||||||
|
llm_config=self.llm_config,
|
||||||
|
embed_config=self.embed_config,
|
||||||
|
doc_retrieval=doc_retrieval or self.doc_retrieval,
|
||||||
|
code_retrieval=code_retrieval or self.code_retrieval,
|
||||||
|
search_retrieval=search_retrieval or self.search_retrieval,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class ChainFlow:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
chain_name: str,
|
||||||
|
chain_index: int = 0,
|
||||||
|
agent_flows: List[AgentFlow] = [],
|
||||||
|
chat_turn: int = 5,
|
||||||
|
do_checker: bool = False,
|
||||||
|
embeddings: Embeddings = None,
|
||||||
|
llm: BaseLLM = None,
|
||||||
|
doc_retrieval: IMRertrieval = None,
|
||||||
|
code_retrieval: IMRertrieval = None,
|
||||||
|
search_retrieval: IMRertrieval = None,
|
||||||
|
# chain_type: str = "BaseChain",
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
self.agent_flows = sorted(agent_flows, key=lambda x:x.agent_index)
|
||||||
|
self.chat_turn = chat_turn
|
||||||
|
self.do_checker = do_checker
|
||||||
|
self.chain_name = chain_name
|
||||||
|
self.chain_index = chain_index
|
||||||
|
self.chain_type = "BaseChain"
|
||||||
|
|
||||||
|
self.embeddings = embeddings
|
||||||
|
self.llm = llm
|
||||||
|
|
||||||
|
self.doc_retrieval = doc_retrieval
|
||||||
|
self.code_retrieval = code_retrieval
|
||||||
|
self.search_retrieval = search_retrieval
|
||||||
|
# self.build_config()
|
||||||
|
# self.build_chain()
|
||||||
|
|
||||||
|
def build_config(self, embeddings: Embeddings = None, llm: BaseLLM = None):
|
||||||
|
self.llm_config = LLMConfig(model_name="test", llm=self.llm or llm)
|
||||||
|
self.embed_config = EmbedConfig(embed_model="test", langchain_embeddings=self.embeddings or embeddings)
|
||||||
|
|
||||||
|
def build_chain(self,
|
||||||
|
embeddings: Embeddings = None, llm: BaseLLM = None,
|
||||||
|
doc_retrieval: IMRertrieval = None,
|
||||||
|
code_retrieval: IMRertrieval = None,
|
||||||
|
search_retrieval: IMRertrieval = None,
|
||||||
|
):
|
||||||
|
# chain 实例化
|
||||||
|
chain_module = importlib.import_module("coagent.connector.chains")
|
||||||
|
baseChain: BaseChain = getattr(chain_module, self.chain_type)
|
||||||
|
|
||||||
|
agent_names = [agent_flow.role_name for agent_flow in self.agent_flows]
|
||||||
|
chain_config = ChainConfig(chain_name=self.chain_name, agents=agent_names, do_checker=self.do_checker, chat_turn=self.chat_turn)
|
||||||
|
|
||||||
|
# agent 实例化
|
||||||
|
self.build_config(embeddings, llm)
|
||||||
|
for agent_flow in self.agent_flows:
|
||||||
|
agent_flow.build_agent(embeddings, llm)
|
||||||
|
|
||||||
|
self.chain = baseChain(
|
||||||
|
chain_config,
|
||||||
|
[agent_flow.agent for agent_flow in self.agent_flows],
|
||||||
|
embed_config=self.embed_config,
|
||||||
|
llm_config=self.llm_config,
|
||||||
|
doc_retrieval=doc_retrieval or self.doc_retrieval,
|
||||||
|
code_retrieval=code_retrieval or self.code_retrieval,
|
||||||
|
search_retrieval=search_retrieval or self.search_retrieval,
|
||||||
|
)
|
||||||
|
|
||||||
|
class PhaseFlow:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
phase_name: str,
|
||||||
|
chain_flows: List[ChainFlow],
|
||||||
|
embeddings: Embeddings = None,
|
||||||
|
llm: BaseLLM = None,
|
||||||
|
tools: List[Tool] = [],
|
||||||
|
doc_retrieval: IMRertrieval = None,
|
||||||
|
code_retrieval: IMRertrieval = None,
|
||||||
|
search_retrieval: IMRertrieval = None,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
self.phase_name = phase_name
|
||||||
|
self.chain_flows = sorted(chain_flows, key=lambda x:x.chain_index)
|
||||||
|
self.phase_type = "BasePhase"
|
||||||
|
self.tools = tools
|
||||||
|
|
||||||
|
self.embeddings = embeddings
|
||||||
|
self.llm = llm
|
||||||
|
|
||||||
|
self.doc_retrieval = doc_retrieval
|
||||||
|
self.code_retrieval = code_retrieval
|
||||||
|
self.search_retrieval = search_retrieval
|
||||||
|
# self.build_config()
|
||||||
|
self.build_phase()
|
||||||
|
|
||||||
|
def __call__(self, params: dict) -> str:
|
||||||
|
|
||||||
|
# tools = toLangchainTools([TOOL_DICT[i] for i in TOOL_SETS if i in TOOL_DICT])
|
||||||
|
# query_content = "帮我确认下127.0.0.1这个服务器的在10点是否存在异常,请帮我判断一下"
|
||||||
|
try:
|
||||||
|
logger.info(f"params: {params}")
|
||||||
|
query_content = params.get("query") or params.get("input")
|
||||||
|
search_type = params.get("search_type")
|
||||||
|
query = Message(
|
||||||
|
role_name="human", role_type="user", tools=self.tools,
|
||||||
|
role_content=query_content, input_query=query_content, origin_query=query_content,
|
||||||
|
cb_search_type=search_type,
|
||||||
|
)
|
||||||
|
# phase.pre_print(query)
|
||||||
|
output_message, output_memory = self.phase.step(query)
|
||||||
|
output_content = "\n\n".join((output_memory.to_str_messages(return_all=True, content_key="parsed_output_list").split("\n\n")[1:])) or output_message.role_content
|
||||||
|
return output_content
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception(e)
|
||||||
|
return f"Error {e}"
|
||||||
|
|
||||||
|
def build_config(self, embeddings: Embeddings = None, llm: BaseLLM = None):
|
||||||
|
self.llm_config = LLMConfig(model_name="test", llm=self.llm or llm)
|
||||||
|
self.embed_config = EmbedConfig(embed_model="test", langchain_embeddings=self.embeddings or embeddings)
|
||||||
|
|
||||||
|
def build_phase(self, embeddings: Embeddings = None, llm: BaseLLM = None):
|
||||||
|
# phase 实例化
|
||||||
|
phase_module = importlib.import_module("coagent.connector.phase")
|
||||||
|
basePhase: BasePhase = getattr(phase_module, self.phase_type)
|
||||||
|
|
||||||
|
# chain 实例化
|
||||||
|
self.build_config(self.embeddings or embeddings, self.llm or llm)
|
||||||
|
os.environ["log_verbose"] = "2"
|
||||||
|
for chain_flow in self.chain_flows:
|
||||||
|
chain_flow.build_chain(
|
||||||
|
self.embeddings or embeddings, self.llm or llm,
|
||||||
|
self.doc_retrieval, self.code_retrieval, self.search_retrieval
|
||||||
|
)
|
||||||
|
|
||||||
|
self.phase: BasePhase = basePhase(
|
||||||
|
phase_name=self.phase_name,
|
||||||
|
chains=[chain_flow.chain for chain_flow in self.chain_flows],
|
||||||
|
embed_config=self.embed_config,
|
||||||
|
llm_config=self.llm_config,
|
||||||
|
doc_retrieval=self.doc_retrieval,
|
||||||
|
code_retrieval=self.code_retrieval,
|
||||||
|
search_retrieval=self.search_retrieval
|
||||||
|
)
|
|
@ -1,9 +1,10 @@
|
||||||
from typing import List
|
from typing import List, Tuple, Union
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
import copy, os
|
import copy, os
|
||||||
|
|
||||||
from coagent.connector.agents import BaseAgent
|
from langchain.schema import BaseRetriever
|
||||||
|
|
||||||
|
from coagent.connector.agents import BaseAgent
|
||||||
from coagent.connector.schema import (
|
from coagent.connector.schema import (
|
||||||
Memory, Role, Message, ActionStatus, ChainConfig,
|
Memory, Role, Message, ActionStatus, ChainConfig,
|
||||||
load_role_configs
|
load_role_configs
|
||||||
|
@ -11,31 +12,32 @@ from coagent.connector.schema import (
|
||||||
from coagent.connector.memory_manager import BaseMemoryManager
|
from coagent.connector.memory_manager import BaseMemoryManager
|
||||||
from coagent.connector.message_process import MessageUtils
|
from coagent.connector.message_process import MessageUtils
|
||||||
from coagent.llm_models.llm_config import LLMConfig, EmbedConfig
|
from coagent.llm_models.llm_config import LLMConfig, EmbedConfig
|
||||||
|
from coagent.base_configs.env_config import JUPYTER_WORK_PATH, KB_ROOT_PATH
|
||||||
from coagent.connector.configs.agent_config import AGETN_CONFIGS
|
from coagent.connector.configs.agent_config import AGETN_CONFIGS
|
||||||
role_configs = load_role_configs(AGETN_CONFIGS)
|
role_configs = load_role_configs(AGETN_CONFIGS)
|
||||||
|
|
||||||
# from configs.model_config import JUPYTER_WORK_PATH
|
|
||||||
# from configs.server_config import SANDBOX_SERVER
|
|
||||||
|
|
||||||
|
|
||||||
class BaseChain:
|
class BaseChain:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
# chainConfig: ChainConfig,
|
chainConfig: ChainConfig,
|
||||||
agents: List[BaseAgent],
|
agents: List[BaseAgent],
|
||||||
chat_turn: int = 1,
|
# chat_turn: int = 1,
|
||||||
do_checker: bool = False,
|
# do_checker: bool = False,
|
||||||
sandbox_server: dict = {},
|
sandbox_server: dict = {},
|
||||||
jupyter_work_path: str = "",
|
jupyter_work_path: str = JUPYTER_WORK_PATH,
|
||||||
kb_root_path: str = "",
|
kb_root_path: str = KB_ROOT_PATH,
|
||||||
llm_config: LLMConfig = LLMConfig(),
|
llm_config: LLMConfig = LLMConfig(),
|
||||||
embed_config: EmbedConfig = None,
|
embed_config: EmbedConfig = None,
|
||||||
|
doc_retrieval: Union[BaseRetriever] = None,
|
||||||
|
code_retrieval = None,
|
||||||
|
search_retrieval = None,
|
||||||
log_verbose: str = "0"
|
log_verbose: str = "0"
|
||||||
) -> None:
|
) -> None:
|
||||||
# self.chainConfig = chainConfig
|
self.chainConfig = chainConfig
|
||||||
self.agents: List[BaseAgent] = agents
|
self.agents: List[BaseAgent] = agents
|
||||||
self.chat_turn = chat_turn
|
self.chat_turn = chainConfig.chat_turn
|
||||||
self.do_checker = do_checker
|
self.do_checker = chainConfig.do_checker
|
||||||
self.sandbox_server = sandbox_server
|
self.sandbox_server = sandbox_server
|
||||||
self.jupyter_work_path = jupyter_work_path
|
self.jupyter_work_path = jupyter_work_path
|
||||||
self.llm_config = llm_config
|
self.llm_config = llm_config
|
||||||
|
@ -45,9 +47,11 @@ class BaseChain:
|
||||||
task = None, memory = None,
|
task = None, memory = None,
|
||||||
llm_config=llm_config, embed_config=embed_config,
|
llm_config=llm_config, embed_config=embed_config,
|
||||||
sandbox_server=sandbox_server, jupyter_work_path=jupyter_work_path,
|
sandbox_server=sandbox_server, jupyter_work_path=jupyter_work_path,
|
||||||
kb_root_path=kb_root_path
|
kb_root_path=kb_root_path,
|
||||||
|
doc_retrieval=doc_retrieval, code_retrieval=code_retrieval,
|
||||||
|
search_retrieval=search_retrieval
|
||||||
)
|
)
|
||||||
self.messageUtils = MessageUtils(None, sandbox_server, self.jupyter_work_path, embed_config, llm_config, kb_root_path, log_verbose)
|
self.messageUtils = MessageUtils(None, sandbox_server, self.jupyter_work_path, embed_config, llm_config, kb_root_path, doc_retrieval, code_retrieval, search_retrieval, log_verbose)
|
||||||
# all memory created by agent until instance deleted
|
# all memory created by agent until instance deleted
|
||||||
self.global_memory = Memory(messages=[])
|
self.global_memory = Memory(messages=[])
|
||||||
|
|
||||||
|
@ -62,13 +66,16 @@ class BaseChain:
|
||||||
for agent in self.agents:
|
for agent in self.agents:
|
||||||
agent.pre_print(query, history, background=background, memory_manager=memory_manager)
|
agent.pre_print(query, history, background=background, memory_manager=memory_manager)
|
||||||
|
|
||||||
def astep(self, query: Message, history: Memory = None, background: Memory = None, memory_manager: BaseMemoryManager = None) -> Message:
|
def astep(self, query: Message, history: Memory = None, background: Memory = None, memory_manager: BaseMemoryManager = None) -> Tuple[Message, Memory]:
|
||||||
'''execute chain'''
|
'''execute chain'''
|
||||||
local_memory = Memory(messages=[])
|
local_memory = Memory(messages=[])
|
||||||
input_message = copy.deepcopy(query)
|
input_message = copy.deepcopy(query)
|
||||||
step_nums = copy.deepcopy(self.chat_turn)
|
step_nums = copy.deepcopy(self.chat_turn)
|
||||||
check_message = None
|
check_message = None
|
||||||
|
|
||||||
|
# if input_message not in memory_manager:
|
||||||
|
# memory_manager.append(input_message)
|
||||||
|
|
||||||
self.global_memory.append(input_message)
|
self.global_memory.append(input_message)
|
||||||
# local_memory.append(input_message)
|
# local_memory.append(input_message)
|
||||||
while step_nums > 0:
|
while step_nums > 0:
|
||||||
|
@ -78,7 +85,7 @@ class BaseChain:
|
||||||
yield output_message, local_memory + output_message
|
yield output_message, local_memory + output_message
|
||||||
output_message = self.messageUtils.inherit_extrainfo(input_message, output_message)
|
output_message = self.messageUtils.inherit_extrainfo(input_message, output_message)
|
||||||
# according the output to choose one action for code_content or tool_content
|
# according the output to choose one action for code_content or tool_content
|
||||||
output_message = self.messageUtils.parser(output_message)
|
# output_message = self.messageUtils.parser(output_message)
|
||||||
yield output_message, local_memory + output_message
|
yield output_message, local_memory + output_message
|
||||||
# output_message = self.step_router(output_message)
|
# output_message = self.step_router(output_message)
|
||||||
input_message = output_message
|
input_message = output_message
|
||||||
|
|
|
@ -1,9 +1,10 @@
|
||||||
from .agent_config import AGETN_CONFIGS
|
from .agent_config import AGETN_CONFIGS
|
||||||
from .chain_config import CHAIN_CONFIGS
|
from .chain_config import CHAIN_CONFIGS
|
||||||
from .phase_config import PHASE_CONFIGS
|
from .phase_config import PHASE_CONFIGS
|
||||||
from .prompt_config import BASE_PROMPT_CONFIGS, EXECUTOR_PROMPT_CONFIGS, SELECTOR_PROMPT_CONFIGS, BASE_NOTOOLPROMPT_CONFIGS
|
from .prompt_config import *
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"AGETN_CONFIGS", "CHAIN_CONFIGS", "PHASE_CONFIGS",
|
"AGETN_CONFIGS", "CHAIN_CONFIGS", "PHASE_CONFIGS",
|
||||||
"BASE_PROMPT_CONFIGS", "EXECUTOR_PROMPT_CONFIGS", "SELECTOR_PROMPT_CONFIGS", "BASE_NOTOOLPROMPT_CONFIGS"
|
"BASE_PROMPT_CONFIGS", "EXECUTOR_PROMPT_CONFIGS", "SELECTOR_PROMPT_CONFIGS", "BASE_NOTOOLPROMPT_CONFIGS",
|
||||||
|
"CODE2DOC_GROUP_PROMPT_CONFIGS", "CODE2DOC_PROMPT_CONFIGS", "CODE2TESTS_PROMPT_CONFIGS"
|
||||||
]
|
]
|
|
@ -1,19 +1,21 @@
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from .prompts import (
|
from .prompts import *
|
||||||
REACT_PROMPT_INPUT, CHECK_PROMPT_INPUT, EXECUTOR_PROMPT_INPUT, CONTEXT_PROMPT_INPUT, QUERY_CONTEXT_PROMPT_INPUT,PLAN_PROMPT_INPUT,
|
# from .prompts import (
|
||||||
RECOGNIZE_INTENTION_PROMPT,
|
# REACT_PROMPT_INPUT, CHECK_PROMPT_INPUT, EXECUTOR_PROMPT_INPUT, CONTEXT_PROMPT_INPUT, QUERY_CONTEXT_PROMPT_INPUT,PLAN_PROMPT_INPUT,
|
||||||
CHECKER_TEMPLATE_PROMPT,
|
# RECOGNIZE_INTENTION_PROMPT,
|
||||||
CONV_SUMMARY_PROMPT,
|
# CHECKER_TEMPLATE_PROMPT,
|
||||||
QA_PROMPT, CODE_QA_PROMPT, QA_TEMPLATE_PROMPT,
|
# CONV_SUMMARY_PROMPT,
|
||||||
EXECUTOR_TEMPLATE_PROMPT,
|
# QA_PROMPT, CODE_QA_PROMPT, QA_TEMPLATE_PROMPT,
|
||||||
REFINE_TEMPLATE_PROMPT,
|
# EXECUTOR_TEMPLATE_PROMPT,
|
||||||
SELECTOR_AGENT_TEMPLATE_PROMPT,
|
# REFINE_TEMPLATE_PROMPT,
|
||||||
PLANNER_TEMPLATE_PROMPT, GENERAL_PLANNER_PROMPT, DATA_PLANNER_PROMPT, TOOL_PLANNER_PROMPT,
|
# SELECTOR_AGENT_TEMPLATE_PROMPT,
|
||||||
PRD_WRITER_METAGPT_PROMPT, DESIGN_WRITER_METAGPT_PROMPT, TASK_WRITER_METAGPT_PROMPT, CODE_WRITER_METAGPT_PROMPT,
|
# PLANNER_TEMPLATE_PROMPT, GENERAL_PLANNER_PROMPT, DATA_PLANNER_PROMPT, TOOL_PLANNER_PROMPT,
|
||||||
REACT_TEMPLATE_PROMPT,
|
# PRD_WRITER_METAGPT_PROMPT, DESIGN_WRITER_METAGPT_PROMPT, TASK_WRITER_METAGPT_PROMPT, CODE_WRITER_METAGPT_PROMPT,
|
||||||
REACT_TOOL_PROMPT, REACT_CODE_PROMPT, REACT_TOOL_AND_CODE_PLANNER_PROMPT, REACT_TOOL_AND_CODE_PROMPT
|
# REACT_TEMPLATE_PROMPT,
|
||||||
)
|
# REACT_TOOL_PROMPT, REACT_CODE_PROMPT, REACT_TOOL_AND_CODE_PLANNER_PROMPT, REACT_TOOL_AND_CODE_PROMPT
|
||||||
from .prompt_config import BASE_PROMPT_CONFIGS, EXECUTOR_PROMPT_CONFIGS, SELECTOR_PROMPT_CONFIGS, BASE_NOTOOLPROMPT_CONFIGS
|
# )
|
||||||
|
from .prompt_config import *
|
||||||
|
# BASE_PROMPT_CONFIGS, EXECUTOR_PROMPT_CONFIGS, SELECTOR_PROMPT_CONFIGS, BASE_NOTOOLPROMPT_CONFIGS
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@ -261,4 +263,68 @@ AGETN_CONFIGS = {
|
||||||
"focus_agents": ["metaGPT_DESIGN", "metaGPT_TASK"],
|
"focus_agents": ["metaGPT_DESIGN", "metaGPT_TASK"],
|
||||||
"focus_message_keys": [],
|
"focus_message_keys": [],
|
||||||
},
|
},
|
||||||
|
"class2Docer": {
|
||||||
|
"role": {
|
||||||
|
"role_prompt": Class2Doc_PROMPT,
|
||||||
|
"role_type": "assistant",
|
||||||
|
"role_name": "class2Docer",
|
||||||
|
"role_desc": "",
|
||||||
|
"agent_type": "CodeGenDocer"
|
||||||
|
},
|
||||||
|
"prompt_config": CODE2DOC_PROMPT_CONFIGS,
|
||||||
|
"prompt_manager_type": "Code2DocPM",
|
||||||
|
"chat_turn": 1,
|
||||||
|
"focus_agents": [],
|
||||||
|
"focus_message_keys": [],
|
||||||
|
},
|
||||||
|
"func2Docer": {
|
||||||
|
"role": {
|
||||||
|
"role_prompt": Func2Doc_PROMPT,
|
||||||
|
"role_type": "assistant",
|
||||||
|
"role_name": "func2Docer",
|
||||||
|
"role_desc": "",
|
||||||
|
"agent_type": "CodeGenDocer"
|
||||||
|
},
|
||||||
|
"prompt_config": CODE2DOC_PROMPT_CONFIGS,
|
||||||
|
"prompt_manager_type": "Code2DocPM",
|
||||||
|
"chat_turn": 1,
|
||||||
|
"focus_agents": [],
|
||||||
|
"focus_message_keys": [],
|
||||||
|
},
|
||||||
|
"code2DocsGrouper": {
|
||||||
|
"role": {
|
||||||
|
"role_prompt": Code2DocGroup_PROMPT,
|
||||||
|
"role_type": "assistant",
|
||||||
|
"role_name": "code2DocsGrouper",
|
||||||
|
"role_desc": "",
|
||||||
|
"agent_type": "SelectorAgent"
|
||||||
|
},
|
||||||
|
"prompt_config": CODE2DOC_GROUP_PROMPT_CONFIGS,
|
||||||
|
"group_agents": ["class2Docer", "func2Docer"],
|
||||||
|
"chat_turn": 1,
|
||||||
|
},
|
||||||
|
"Code2TestJudger": {
|
||||||
|
"role": {
|
||||||
|
"role_prompt": judgeCode2Tests_PROMPT,
|
||||||
|
"role_type": "assistant",
|
||||||
|
"role_name": "Code2TestJudger",
|
||||||
|
"role_desc": "",
|
||||||
|
"agent_type": "CodeRetrieval"
|
||||||
|
},
|
||||||
|
"prompt_config": CODE2TESTS_PROMPT_CONFIGS,
|
||||||
|
"prompt_manager_type": "CodeRetrievalPM",
|
||||||
|
"chat_turn": 1,
|
||||||
|
},
|
||||||
|
"code2Tests": {
|
||||||
|
"role": {
|
||||||
|
"role_prompt": code2Tests_PROMPT,
|
||||||
|
"role_type": "assistant",
|
||||||
|
"role_name": "code2Tests",
|
||||||
|
"role_desc": "",
|
||||||
|
"agent_type": "CodeRetrieval"
|
||||||
|
},
|
||||||
|
"prompt_config": CODE2TESTS_PROMPT_CONFIGS,
|
||||||
|
"prompt_manager_type": "CodeRetrievalPM",
|
||||||
|
"chat_turn": 1,
|
||||||
|
},
|
||||||
}
|
}
|
|
@ -123,5 +123,21 @@ CHAIN_CONFIGS = {
|
||||||
"chat_turn": 1,
|
"chat_turn": 1,
|
||||||
"do_checker": False,
|
"do_checker": False,
|
||||||
"chain_prompt": ""
|
"chain_prompt": ""
|
||||||
|
},
|
||||||
|
"code2DocsGroupChain": {
|
||||||
|
"chain_name": "code2DocsGroupChain",
|
||||||
|
"chain_type": "BaseChain",
|
||||||
|
"agents": ["code2DocsGrouper"],
|
||||||
|
"chat_turn": 1,
|
||||||
|
"do_checker": False,
|
||||||
|
"chain_prompt": ""
|
||||||
|
},
|
||||||
|
"code2TestsChain": {
|
||||||
|
"chain_name": "code2TestsChain",
|
||||||
|
"chain_type": "BaseChain",
|
||||||
|
"agents": ["Code2TestJudger", "code2Tests"],
|
||||||
|
"chat_turn": 1,
|
||||||
|
"do_checker": False,
|
||||||
|
"chain_prompt": ""
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -14,44 +14,24 @@ PHASE_CONFIGS = {
|
||||||
"phase_name": "docChatPhase",
|
"phase_name": "docChatPhase",
|
||||||
"phase_type": "BasePhase",
|
"phase_type": "BasePhase",
|
||||||
"chains": ["docChatChain"],
|
"chains": ["docChatChain"],
|
||||||
"do_summary": False,
|
|
||||||
"do_search": False,
|
|
||||||
"do_doc_retrieval": True,
|
"do_doc_retrieval": True,
|
||||||
"do_code_retrieval": False,
|
|
||||||
"do_tool_retrieval": False,
|
|
||||||
"do_using_tool": False
|
|
||||||
},
|
},
|
||||||
"searchChatPhase": {
|
"searchChatPhase": {
|
||||||
"phase_name": "searchChatPhase",
|
"phase_name": "searchChatPhase",
|
||||||
"phase_type": "BasePhase",
|
"phase_type": "BasePhase",
|
||||||
"chains": ["searchChatChain"],
|
"chains": ["searchChatChain"],
|
||||||
"do_summary": False,
|
|
||||||
"do_search": True,
|
"do_search": True,
|
||||||
"do_doc_retrieval": False,
|
|
||||||
"do_code_retrieval": False,
|
|
||||||
"do_tool_retrieval": False,
|
|
||||||
"do_using_tool": False
|
|
||||||
},
|
},
|
||||||
"codeChatPhase": {
|
"codeChatPhase": {
|
||||||
"phase_name": "codeChatPhase",
|
"phase_name": "codeChatPhase",
|
||||||
"phase_type": "BasePhase",
|
"phase_type": "BasePhase",
|
||||||
"chains": ["codeChatChain"],
|
"chains": ["codeChatChain"],
|
||||||
"do_summary": False,
|
|
||||||
"do_search": False,
|
|
||||||
"do_doc_retrieval": False,
|
|
||||||
"do_code_retrieval": True,
|
"do_code_retrieval": True,
|
||||||
"do_tool_retrieval": False,
|
|
||||||
"do_using_tool": False
|
|
||||||
},
|
},
|
||||||
"toolReactPhase": {
|
"toolReactPhase": {
|
||||||
"phase_name": "toolReactPhase",
|
"phase_name": "toolReactPhase",
|
||||||
"phase_type": "BasePhase",
|
"phase_type": "BasePhase",
|
||||||
"chains": ["toolReactChain"],
|
"chains": ["toolReactChain"],
|
||||||
"do_summary": False,
|
|
||||||
"do_search": False,
|
|
||||||
"do_doc_retrieval": False,
|
|
||||||
"do_code_retrieval": False,
|
|
||||||
"do_tool_retrieval": False,
|
|
||||||
"do_using_tool": True
|
"do_using_tool": True
|
||||||
},
|
},
|
||||||
"codeReactPhase": {
|
"codeReactPhase": {
|
||||||
|
@ -59,55 +39,36 @@ PHASE_CONFIGS = {
|
||||||
"phase_type": "BasePhase",
|
"phase_type": "BasePhase",
|
||||||
# "chains": ["codePlannerChain", "codeReactChain"],
|
# "chains": ["codePlannerChain", "codeReactChain"],
|
||||||
"chains": ["planChain", "codeReactChain"],
|
"chains": ["planChain", "codeReactChain"],
|
||||||
"do_summary": False,
|
|
||||||
"do_search": False,
|
|
||||||
"do_doc_retrieval": False,
|
|
||||||
"do_code_retrieval": False,
|
|
||||||
"do_tool_retrieval": False,
|
|
||||||
"do_using_tool": False
|
|
||||||
},
|
},
|
||||||
"codeToolReactPhase": {
|
"codeToolReactPhase": {
|
||||||
"phase_name": "codeToolReactPhase",
|
"phase_name": "codeToolReactPhase",
|
||||||
"phase_type": "BasePhase",
|
"phase_type": "BasePhase",
|
||||||
"chains": ["codeToolPlanChain", "codeToolReactChain"],
|
"chains": ["codeToolPlanChain", "codeToolReactChain"],
|
||||||
"do_summary": False,
|
|
||||||
"do_search": False,
|
|
||||||
"do_doc_retrieval": False,
|
|
||||||
"do_code_retrieval": False,
|
|
||||||
"do_tool_retrieval": False,
|
|
||||||
"do_using_tool": True
|
"do_using_tool": True
|
||||||
},
|
},
|
||||||
"baseTaskPhase": {
|
"baseTaskPhase": {
|
||||||
"phase_name": "baseTaskPhase",
|
"phase_name": "baseTaskPhase",
|
||||||
"phase_type": "BasePhase",
|
"phase_type": "BasePhase",
|
||||||
"chains": ["planChain", "executorChain"],
|
"chains": ["planChain", "executorChain"],
|
||||||
"do_summary": False,
|
|
||||||
"do_search": False,
|
|
||||||
"do_doc_retrieval": False,
|
|
||||||
"do_code_retrieval": False,
|
|
||||||
"do_tool_retrieval": False,
|
|
||||||
"do_using_tool": False
|
|
||||||
},
|
},
|
||||||
"metagpt_code_devlop": {
|
"metagpt_code_devlop": {
|
||||||
"phase_name": "metagpt_code_devlop",
|
"phase_name": "metagpt_code_devlop",
|
||||||
"phase_type": "BasePhase",
|
"phase_type": "BasePhase",
|
||||||
"chains": ["metagptChain",],
|
"chains": ["metagptChain",],
|
||||||
"do_summary": False,
|
|
||||||
"do_search": False,
|
|
||||||
"do_doc_retrieval": False,
|
|
||||||
"do_code_retrieval": False,
|
|
||||||
"do_tool_retrieval": False,
|
|
||||||
"do_using_tool": False
|
|
||||||
},
|
},
|
||||||
"baseGroupPhase": {
|
"baseGroupPhase": {
|
||||||
"phase_name": "baseGroupPhase",
|
"phase_name": "baseGroupPhase",
|
||||||
"phase_type": "BasePhase",
|
"phase_type": "BasePhase",
|
||||||
"chains": ["baseGroupChain"],
|
"chains": ["baseGroupChain"],
|
||||||
"do_summary": False,
|
|
||||||
"do_search": False,
|
|
||||||
"do_doc_retrieval": False,
|
|
||||||
"do_code_retrieval": False,
|
|
||||||
"do_tool_retrieval": False,
|
|
||||||
"do_using_tool": False
|
|
||||||
},
|
},
|
||||||
|
"code2DocsGroup": {
|
||||||
|
"phase_name": "code2DocsGroup",
|
||||||
|
"phase_type": "BasePhase",
|
||||||
|
"chains": ["code2DocsGroupChain"],
|
||||||
|
},
|
||||||
|
"code2Tests": {
|
||||||
|
"phase_name": "code2Tests",
|
||||||
|
"phase_type": "BasePhase",
|
||||||
|
"chains": ["code2TestsChain"],
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -40,4 +40,41 @@ SELECTOR_PROMPT_CONFIGS = [
|
||||||
{"field_name": 'current_plan', "function_name": 'handle_current_plan'},
|
{"field_name": 'current_plan', "function_name": 'handle_current_plan'},
|
||||||
{"field_name": 'output_format', "function_name": 'handle_output_format', 'title': 'Response Output Format', "is_context": False},
|
{"field_name": 'output_format', "function_name": 'handle_output_format', 'title': 'Response Output Format', "is_context": False},
|
||||||
{"field_name": 'begin!!!', "function_name": 'handle_response', "is_context": False, "omit_if_empty": False}
|
{"field_name": 'begin!!!', "function_name": 'handle_response', "is_context": False, "omit_if_empty": False}
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
CODE2DOC_GROUP_PROMPT_CONFIGS = [
|
||||||
|
{"field_name": 'agent_profile', "function_name": 'handle_agent_profile', "is_context": False},
|
||||||
|
{"field_name": 'agent_infomation', "function_name": 'handle_agent_data', "is_context": False, "omit_if_empty": False},
|
||||||
|
# {"field_name": 'tool_information',"function_name": 'handle_tool_data', "is_context": False},
|
||||||
|
{"field_name": 'context_placeholder', "function_name": '', "is_context": True},
|
||||||
|
# {"field_name": 'reference_documents', "function_name": 'handle_doc_info'},
|
||||||
|
{"field_name": 'session_records', "function_name": 'handle_session_records'},
|
||||||
|
{"field_name": 'Specific Objective', "function_name": 'handle_specific_objective'},
|
||||||
|
{"field_name": 'Code Snippet', "function_name": 'handle_code_snippet'},
|
||||||
|
{"field_name": 'output_format', "function_name": 'handle_output_format', 'title': 'Response Output Format', "is_context": False},
|
||||||
|
{"field_name": 'begin!!!', "function_name": 'handle_response', "is_context": False, "omit_if_empty": False}
|
||||||
|
]
|
||||||
|
|
||||||
|
CODE2DOC_PROMPT_CONFIGS = [
|
||||||
|
{"field_name": 'agent_profile', "function_name": 'handle_agent_profile', "is_context": False},
|
||||||
|
# {"field_name": 'tool_information',"function_name": 'handle_tool_data', "is_context": False},
|
||||||
|
{"field_name": 'context_placeholder', "function_name": '', "is_context": True},
|
||||||
|
# {"field_name": 'reference_documents', "function_name": 'handle_doc_info'},
|
||||||
|
{"field_name": 'session_records', "function_name": 'handle_session_records'},
|
||||||
|
{"field_name": 'Specific Objective', "function_name": 'handle_specific_objective'},
|
||||||
|
{"field_name": 'Code Snippet', "function_name": 'handle_code_snippet'},
|
||||||
|
{"field_name": 'output_format', "function_name": 'handle_output_format', 'title': 'Response Output Format', "is_context": False},
|
||||||
|
{"field_name": 'begin!!!', "function_name": 'handle_response', "is_context": False, "omit_if_empty": False}
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
CODE2TESTS_PROMPT_CONFIGS = [
|
||||||
|
{"field_name": 'agent_profile', "function_name": 'handle_agent_profile', "is_context": False},
|
||||||
|
{"field_name": 'context_placeholder', "function_name": '', "is_context": True},
|
||||||
|
{"field_name": 'session_records', "function_name": 'handle_session_records'},
|
||||||
|
{"field_name": 'code_snippet', "function_name": 'handle_code_snippet'},
|
||||||
|
{"field_name": 'retrieval_codes', "function_name": 'handle_retrieval_codes', "description": ""},
|
||||||
|
{"field_name": 'output_format', "function_name": 'handle_output_format', 'title': 'Response Output Format', "is_context": False},
|
||||||
|
{"field_name": 'begin!!!', "function_name": 'handle_response', "is_context": False, "omit_if_empty": False}
|
||||||
|
]
|
|
@ -14,7 +14,8 @@ from .qa_template_prompt import QA_PROMPT, CODE_QA_PROMPT, QA_TEMPLATE_PROMPT, C
|
||||||
|
|
||||||
from .executor_template_prompt import EXECUTOR_TEMPLATE_PROMPT
|
from .executor_template_prompt import EXECUTOR_TEMPLATE_PROMPT
|
||||||
from .refine_template_prompt import REFINE_TEMPLATE_PROMPT
|
from .refine_template_prompt import REFINE_TEMPLATE_PROMPT
|
||||||
|
from .code2doc_template_prompt import Code2DocGroup_PROMPT, Class2Doc_PROMPT, Func2Doc_PROMPT
|
||||||
|
from .code2test_template_prompt import code2Tests_PROMPT, judgeCode2Tests_PROMPT
|
||||||
from .agent_selector_template_prompt import SELECTOR_AGENT_TEMPLATE_PROMPT
|
from .agent_selector_template_prompt import SELECTOR_AGENT_TEMPLATE_PROMPT
|
||||||
|
|
||||||
from .react_template_prompt import REACT_TEMPLATE_PROMPT
|
from .react_template_prompt import REACT_TEMPLATE_PROMPT
|
||||||
|
@ -37,5 +38,7 @@ __all__ = [
|
||||||
"SELECTOR_AGENT_TEMPLATE_PROMPT",
|
"SELECTOR_AGENT_TEMPLATE_PROMPT",
|
||||||
"PLANNER_TEMPLATE_PROMPT", "GENERAL_PLANNER_PROMPT", "DATA_PLANNER_PROMPT", "TOOL_PLANNER_PROMPT",
|
"PLANNER_TEMPLATE_PROMPT", "GENERAL_PLANNER_PROMPT", "DATA_PLANNER_PROMPT", "TOOL_PLANNER_PROMPT",
|
||||||
"REACT_TEMPLATE_PROMPT",
|
"REACT_TEMPLATE_PROMPT",
|
||||||
"REACT_CODE_PROMPT", "REACT_TOOL_PROMPT", "REACT_TOOL_AND_CODE_PROMPT", "REACT_TOOL_AND_CODE_PLANNER_PROMPT"
|
"REACT_CODE_PROMPT", "REACT_TOOL_PROMPT", "REACT_TOOL_AND_CODE_PROMPT", "REACT_TOOL_AND_CODE_PLANNER_PROMPT",
|
||||||
|
"Code2DocGroup_PROMPT", "Class2Doc_PROMPT", "Func2Doc_PROMPT",
|
||||||
|
"code2Tests_PROMPT", "judgeCode2Tests_PROMPT"
|
||||||
]
|
]
|
|
@ -0,0 +1,95 @@
|
||||||
|
Code2DocGroup_PROMPT = """#### Agent Profile
|
||||||
|
|
||||||
|
Your goal is to response according the Context Data's information with the role that will best facilitate a solution, taking into account all relevant context (Context) provided.
|
||||||
|
|
||||||
|
When you need to select the appropriate role for handling a user's query, carefully read the provided role names, role descriptions and tool list.
|
||||||
|
|
||||||
|
ATTENTION: response carefully referenced "Response Output Format" in format.
|
||||||
|
|
||||||
|
#### Input Format
|
||||||
|
|
||||||
|
#### Response Output Format
|
||||||
|
|
||||||
|
**Code Path:** Extract the paths for the class/method/function that need to be addressed from the context
|
||||||
|
|
||||||
|
**Role:** Select the role from agent names
|
||||||
|
"""
|
||||||
|
|
||||||
|
Class2Doc_PROMPT = """#### Agent Profile
|
||||||
|
As an advanced code documentation generator, you are proficient in translating class definitions into comprehensive documentation with a focus on instantiation parameters.
|
||||||
|
Your specific task is to parse the given code snippet of a class, extract information regarding its instantiation parameters.
|
||||||
|
|
||||||
|
ATTENTION: response carefully in "Response Output Format".
|
||||||
|
|
||||||
|
#### Input Format
|
||||||
|
|
||||||
|
**Code Snippet:** Provide the full class definition, including the constructor and any parameters it may require for instantiation.
|
||||||
|
|
||||||
|
#### Response Output Format
|
||||||
|
**Class Base:** Specify the base class or interface from which the current class extends, if any.
|
||||||
|
|
||||||
|
**Class Description:** Offer a brief description of the class's purpose and functionality.
|
||||||
|
|
||||||
|
**Init Parameters:** List each parameter from construct. For each parameter, provide:
|
||||||
|
- `param`: The parameter name
|
||||||
|
- `param_description`: A concise explanation of the parameter's purpose.
|
||||||
|
- `param_type`: The data type of the parameter, if explicitly defined.
|
||||||
|
|
||||||
|
```json
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"param": "parameter_name",
|
||||||
|
"param_description": "A brief description of what this parameter is used for.",
|
||||||
|
"param_type": "The data type of the parameter"
|
||||||
|
},
|
||||||
|
...
|
||||||
|
]
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
If no parameter for construct, return
|
||||||
|
```json
|
||||||
|
[]
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
|
Func2Doc_PROMPT = """#### Agent Profile
|
||||||
|
You are a high-level code documentation assistant, skilled at extracting information from function/method code into detailed and well-structured documentation.
|
||||||
|
|
||||||
|
ATTENTION: response carefully in "Response Output Format".
|
||||||
|
|
||||||
|
|
||||||
|
#### Input Format
|
||||||
|
**Code Path:** Provide the code path of the function or method you wish to document.
|
||||||
|
This name will be used to identify and extract the relevant details from the code snippet provided.
|
||||||
|
|
||||||
|
**Code Snippet:** A segment of code that contains the function or method to be documented.
|
||||||
|
|
||||||
|
#### Response Output Format
|
||||||
|
|
||||||
|
**Class Description:** Offer a brief description of the method(function)'s purpose and functionality.
|
||||||
|
|
||||||
|
**Parameters:** Extract parameter for the specific function/method Code from Code Snippet. For parameter, provide:
|
||||||
|
- `param`: The parameter name
|
||||||
|
- `param_description`: A concise explanation of the parameter's purpose.
|
||||||
|
- `param_type`: The data type of the parameter, if explicitly defined.
|
||||||
|
```json
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"param": "parameter_name",
|
||||||
|
"param_description": "A brief description of what this parameter is used for.",
|
||||||
|
"param_type": "The data type of the parameter"
|
||||||
|
},
|
||||||
|
...
|
||||||
|
]
|
||||||
|
```
|
||||||
|
|
||||||
|
If no parameter for function/method, return
|
||||||
|
```json
|
||||||
|
[]
|
||||||
|
```
|
||||||
|
|
||||||
|
**Return Value Description:** Describe what the function/method returns upon completion.
|
||||||
|
|
||||||
|
**Return Type:** Indicate the type of data the function/method returns (e.g., string, integer, object, void).
|
||||||
|
"""
|
|
@ -0,0 +1,65 @@
|
||||||
|
judgeCode2Tests_PROMPT = """#### Agent Profile
|
||||||
|
When determining the necessity of writing test cases for a given code snippet,
|
||||||
|
it's essential to evaluate its interactions with dependent classes and methods (retrieved code snippets),
|
||||||
|
in addition to considering these critical factors:
|
||||||
|
1. Functionality: If it implements a concrete function or logic, test cases are typically necessary to verify its correctness.
|
||||||
|
2. Complexity: If the code is complex, especially if it contains multiple conditional statements, loops, exceptions handling, etc.,
|
||||||
|
it's more likely to harbor bugs, and thus test cases should be written.
|
||||||
|
If the code involves complex algorithms or logic, then writing test cases can help ensure the accuracy of the logic and prevent errors during future refactoring.
|
||||||
|
3. Criticality: If it's part of the critical path or affects core functionalities, then it needs to be tested.
|
||||||
|
Comprehensive test cases should be written for core business logic or key components of the system to ensure the correctness and stability of the functionality.
|
||||||
|
4. Dependencies: If the code has external dependencies, integration testing may be necessary, or mocking these dependencies during unit testing might be required.
|
||||||
|
5. User Input: If the code handles user input, especially from unregulated external sources, creating test cases to check input validation and handling is important.
|
||||||
|
6. Frequent Changes: For code that requires regular updates or modifications, having the appropriate test cases ensures that changes do not break existing functionalities.
|
||||||
|
|
||||||
|
#### Input Format
|
||||||
|
|
||||||
|
**Code Snippet:** the initial Code or objective that the user wanted to achieve
|
||||||
|
|
||||||
|
**Retrieval Code Snippets:** These are the associated code segments that the main Code Snippet depends on.
|
||||||
|
Examine these snippets to understand how they interact with the main snippet and to determine how they might affect the overall functionality.
|
||||||
|
|
||||||
|
#### Response Output Format
|
||||||
|
**Action Status:** Set to 'finished' or 'continued'.
|
||||||
|
If set to 'finished', the code snippet does not warrant the generation of a test case.
|
||||||
|
If set to 'continued', the code snippet necessitates the creation of a test case.
|
||||||
|
|
||||||
|
**REASON:** Justify the selection of 'finished' or 'continued', contemplating the decision through a step-by-step rationale.
|
||||||
|
"""
|
||||||
|
|
||||||
|
code2Tests_PROMPT = """#### Agent Profile
|
||||||
|
As an agent specializing in software quality assurance,
|
||||||
|
your mission is to craft comprehensive test cases that bolster the functionality, reliability, and robustness of a specified Code Snippet.
|
||||||
|
This task is to be carried out with a keen understanding of the snippet's interactions with its dependent classes and methods—collectively referred to as Retrieval Code Snippets.
|
||||||
|
Analyze the details given below to grasp the code's intended purpose, its inherent complexity, and the context within which it operates.
|
||||||
|
Your constructed test cases must thoroughly examine the various factors influencing the code's quality and performance.
|
||||||
|
|
||||||
|
ATTENTION: response carefully referenced "Response Output Format" in format.
|
||||||
|
|
||||||
|
Each test case should include:
|
||||||
|
1. clear description of the test purpose.
|
||||||
|
2. The input values or conditions for the test.
|
||||||
|
3. The expected outcome or assertion for the test.
|
||||||
|
4. Appropriate tags (e.g., 'functional', 'integration', 'regression') that classify the type of test case.
|
||||||
|
5. these test code should have package and import
|
||||||
|
|
||||||
|
#### Input Format
|
||||||
|
|
||||||
|
**Code Snippet:** the initial Code or objective that the user wanted to achieve
|
||||||
|
|
||||||
|
**Retrieval Code Snippets:** These are the interrelated pieces of code sourced from the codebase, which support or influence the primary Code Snippet.
|
||||||
|
|
||||||
|
#### Response Output Format
|
||||||
|
**SaveFileName:** construct a local file name based on Question and Context, such as
|
||||||
|
|
||||||
|
```java
|
||||||
|
package/class.java
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
**Test Code:** generate the test code for the current Code Snippet.
|
||||||
|
```java
|
||||||
|
...
|
||||||
|
```
|
||||||
|
|
||||||
|
"""
|
|
@ -1,5 +1,5 @@
|
||||||
from abc import abstractmethod, ABC
|
from abc import abstractmethod, ABC
|
||||||
from typing import List
|
from typing import List, Dict
|
||||||
import os, sys, copy, json
|
import os, sys, copy, json
|
||||||
from jieba.analyse import extract_tags
|
from jieba.analyse import extract_tags
|
||||||
from collections import Counter
|
from collections import Counter
|
||||||
|
@ -10,12 +10,13 @@ from langchain.docstore.document import Document
|
||||||
|
|
||||||
from .schema import Memory, Message
|
from .schema import Memory, Message
|
||||||
from coagent.service.service_factory import KBServiceFactory
|
from coagent.service.service_factory import KBServiceFactory
|
||||||
from coagent.llm_models import getChatModel, getChatModelFromConfig
|
from coagent.llm_models import getChatModelFromConfig
|
||||||
from coagent.llm_models.llm_config import EmbedConfig, LLMConfig
|
from coagent.llm_models.llm_config import EmbedConfig, LLMConfig
|
||||||
from coagent.embeddings.utils import load_embeddings_from_path
|
from coagent.embeddings.utils import load_embeddings_from_path
|
||||||
from coagent.utils.common_utils import save_to_json_file, read_json_file, addMinutesToTime
|
from coagent.utils.common_utils import save_to_json_file, read_json_file, addMinutesToTime
|
||||||
from coagent.connector.configs.prompts import CONV_SUMMARY_PROMPT_SPEC
|
from coagent.connector.configs.prompts import CONV_SUMMARY_PROMPT_SPEC
|
||||||
from coagent.orm import table_init
|
from coagent.orm import table_init
|
||||||
|
from coagent.base_configs.env_config import KB_ROOT_PATH
|
||||||
# from configs.model_config import KB_ROOT_PATH, EMBEDDING_MODEL, EMBEDDING_DEVICE, SCORE_THRESHOLD
|
# from configs.model_config import KB_ROOT_PATH, EMBEDDING_MODEL, EMBEDDING_DEVICE, SCORE_THRESHOLD
|
||||||
# from configs.model_config import embedding_model_dict
|
# from configs.model_config import embedding_model_dict
|
||||||
|
|
||||||
|
@ -70,16 +71,22 @@ class BaseMemoryManager(ABC):
|
||||||
self.unique_name = unique_name
|
self.unique_name = unique_name
|
||||||
self.memory_type = memory_type
|
self.memory_type = memory_type
|
||||||
self.do_init = do_init
|
self.do_init = do_init
|
||||||
self.current_memory = Memory(messages=[])
|
# self.current_memory = Memory(messages=[])
|
||||||
self.recall_memory = Memory(messages=[])
|
# self.recall_memory = Memory(messages=[])
|
||||||
self.summary_memory = Memory(messages=[])
|
# self.summary_memory = Memory(messages=[])
|
||||||
|
self.current_memory_dict: Dict[str, Memory] = {}
|
||||||
|
self.recall_memory_dict: Dict[str, Memory] = {}
|
||||||
|
self.summary_memory_dict: Dict[str, Memory] = {}
|
||||||
self.save_message_keys = [
|
self.save_message_keys = [
|
||||||
'chat_index', 'role_name', 'role_type', 'role_prompt', 'input_query', 'origin_query',
|
'chat_index', 'role_name', 'role_type', 'role_prompt', 'input_query', 'origin_query',
|
||||||
'datetime', 'role_content', 'step_content', 'parsed_output', 'spec_parsed_output', 'parsed_output_list',
|
'datetime', 'role_content', 'step_content', 'parsed_output', 'spec_parsed_output', 'parsed_output_list',
|
||||||
'task', 'db_docs', 'code_docs', 'search_docs', 'phase_name', 'chain_name', 'customed_kargs']
|
'task', 'db_docs', 'code_docs', 'search_docs', 'phase_name', 'chain_name', 'customed_kargs']
|
||||||
self.init_vb()
|
self.init_vb()
|
||||||
|
|
||||||
def init_vb(self):
|
def re_init(self, do_init: bool=False):
|
||||||
|
self.init_vb()
|
||||||
|
|
||||||
|
def init_vb(self, do_init: bool=None):
|
||||||
"""
|
"""
|
||||||
Initializes the vb.
|
Initializes the vb.
|
||||||
"""
|
"""
|
||||||
|
@ -135,13 +142,15 @@ class BaseMemoryManager(ABC):
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def save_to_vs(self, embed_model="", embed_device=""):
|
def save_to_vs(self, ):
|
||||||
"""
|
"""
|
||||||
Saves the memory to the vector space.
|
Saves the memory to the vector space.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
Args:
|
def get_memory_pool(self, user_name: str, ):
|
||||||
- embed_model: A string representing the embedding model. Default is EMBEDDING_MODEL.
|
"""
|
||||||
- embed_device: A string representing the embedding device. Default is EMBEDDING_DEVICE.
|
return memory_pool
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@ -230,7 +239,7 @@ class LocalMemoryManager(BaseMemoryManager):
|
||||||
unique_name: str = "default",
|
unique_name: str = "default",
|
||||||
memory_type: str = "recall",
|
memory_type: str = "recall",
|
||||||
do_init: bool = False,
|
do_init: bool = False,
|
||||||
kb_root_path: str = "",
|
kb_root_path: str = KB_ROOT_PATH,
|
||||||
):
|
):
|
||||||
self.user_name = user_name
|
self.user_name = user_name
|
||||||
self.unique_name = unique_name
|
self.unique_name = unique_name
|
||||||
|
@ -239,16 +248,22 @@ class LocalMemoryManager(BaseMemoryManager):
|
||||||
self.kb_root_path = kb_root_path
|
self.kb_root_path = kb_root_path
|
||||||
self.embed_config: EmbedConfig = embed_config
|
self.embed_config: EmbedConfig = embed_config
|
||||||
self.llm_config: LLMConfig = llm_config
|
self.llm_config: LLMConfig = llm_config
|
||||||
self.current_memory = Memory(messages=[])
|
# self.current_memory = Memory(messages=[])
|
||||||
self.recall_memory = Memory(messages=[])
|
# self.recall_memory = Memory(messages=[])
|
||||||
self.summary_memory = Memory(messages=[])
|
# self.summary_memory = Memory(messages=[])
|
||||||
|
self.current_memory_dict: Dict[str, Memory] = {}
|
||||||
|
self.recall_memory_dict: Dict[str, Memory] = {}
|
||||||
|
self.summary_memory_dict: Dict[str, Memory] = {}
|
||||||
self.save_message_keys = [
|
self.save_message_keys = [
|
||||||
'chat_index', 'role_name', 'role_type', 'role_prompt', 'input_query', 'origin_query',
|
'chat_index', 'role_name', 'role_type', 'role_prompt', 'input_query', 'origin_query',
|
||||||
'datetime', 'role_content', 'step_content', 'parsed_output', 'spec_parsed_output', 'parsed_output_list',
|
'datetime', 'role_content', 'step_content', 'parsed_output', 'spec_parsed_output', 'parsed_output_list',
|
||||||
'task', 'db_docs', 'code_docs', 'search_docs', 'phase_name', 'chain_name', 'customed_kargs']
|
'task', 'db_docs', 'code_docs', 'search_docs', 'phase_name', 'chain_name', 'customed_kargs']
|
||||||
self.init_vb()
|
self.init_vb()
|
||||||
|
|
||||||
def init_vb(self):
|
def re_init(self, do_init: bool=False):
|
||||||
|
self.init_vb(do_init)
|
||||||
|
|
||||||
|
def init_vb(self, do_init: bool=None):
|
||||||
vb_name = f"{self.user_name}/{self.unique_name}/{self.memory_type}"
|
vb_name = f"{self.user_name}/{self.unique_name}/{self.memory_type}"
|
||||||
# default to recreate a new vb
|
# default to recreate a new vb
|
||||||
table_init()
|
table_init()
|
||||||
|
@ -256,31 +271,37 @@ class LocalMemoryManager(BaseMemoryManager):
|
||||||
if vb:
|
if vb:
|
||||||
status = vb.clear_vs()
|
status = vb.clear_vs()
|
||||||
|
|
||||||
if not self.do_init:
|
check_do_init = do_init if do_init else self.do_init
|
||||||
|
if not check_do_init:
|
||||||
self.load(self.kb_root_path)
|
self.load(self.kb_root_path)
|
||||||
self.save_to_vs()
|
self.save_to_vs()
|
||||||
|
|
||||||
def append(self, message: Message):
|
def append(self, message: Message):
|
||||||
self.recall_memory.append(message)
|
self.check_user_name(message.user_name)
|
||||||
|
|
||||||
|
uuid_name = "_".join([self.user_name, self.unique_name, self.memory_type])
|
||||||
|
self.recall_memory_dict[uuid_name].append(message)
|
||||||
#
|
#
|
||||||
if message.role_type == "summary":
|
if message.role_type == "summary":
|
||||||
self.summary_memory.append(message)
|
self.summary_memory_dict[uuid_name].append(message)
|
||||||
else:
|
else:
|
||||||
self.current_memory.append(message)
|
self.current_memory_dict[uuid_name].append(message)
|
||||||
|
|
||||||
self.save(self.kb_root_path)
|
self.save(self.kb_root_path)
|
||||||
self.save_new_to_vs([message])
|
self.save_new_to_vs([message])
|
||||||
|
|
||||||
def extend(self, memory: Memory):
|
# def extend(self, memory: Memory):
|
||||||
self.recall_memory.extend(memory)
|
# self.recall_memory.extend(memory)
|
||||||
self.current_memory.extend(self.recall_memory.filter_by_role_type(["summary"]))
|
# self.current_memory.extend(self.recall_memory.filter_by_role_type(["summary"]))
|
||||||
self.summary_memory.extend(self.recall_memory.select_by_role_type(["summary"]))
|
# self.summary_memory.extend(self.recall_memory.select_by_role_type(["summary"]))
|
||||||
self.save(self.kb_root_path)
|
# self.save(self.kb_root_path)
|
||||||
self.save_new_to_vs(memory.messages)
|
# self.save_new_to_vs(memory.messages)
|
||||||
|
|
||||||
def save(self, save_dir: str = "./"):
|
def save(self, save_dir: str = "./"):
|
||||||
file_path = os.path.join(save_dir, f"{self.user_name}/{self.unique_name}/{self.memory_type}/converation.jsonl")
|
file_path = os.path.join(save_dir, f"{self.user_name}/{self.unique_name}/{self.memory_type}/converation.jsonl")
|
||||||
memory_messages = self.recall_memory.dict()
|
uuid_name = "_".join([self.user_name, self.unique_name, self.memory_type])
|
||||||
|
|
||||||
|
memory_messages = self.recall_memory_dict[uuid_name].dict()
|
||||||
memory_messages = {k: [
|
memory_messages = {k: [
|
||||||
{kkk: vvv for kkk, vvv in vv.items() if kkk in self.save_message_keys}
|
{kkk: vvv for kkk, vvv in vv.items() if kkk in self.save_message_keys}
|
||||||
for vv in v ]
|
for vv in v ]
|
||||||
|
@ -291,18 +312,28 @@ class LocalMemoryManager(BaseMemoryManager):
|
||||||
|
|
||||||
def load(self, load_dir: str = "./") -> Memory:
|
def load(self, load_dir: str = "./") -> Memory:
|
||||||
file_path = os.path.join(load_dir, f"{self.user_name}/{self.unique_name}/{self.memory_type}/converation.jsonl")
|
file_path = os.path.join(load_dir, f"{self.user_name}/{self.unique_name}/{self.memory_type}/converation.jsonl")
|
||||||
|
uuid_name = "_".join([self.user_name, self.unique_name, self.memory_type])
|
||||||
|
|
||||||
if os.path.exists(file_path):
|
if os.path.exists(file_path):
|
||||||
self.recall_memory = Memory(**read_json_file(file_path))
|
# self.recall_memory = Memory(**read_json_file(file_path))
|
||||||
self.current_memory = Memory(messages=self.recall_memory.filter_by_role_type(["summary"]))
|
# self.current_memory = Memory(messages=self.recall_memory.filter_by_role_type(["summary"]))
|
||||||
self.summary_memory = Memory(messages=self.recall_memory.select_by_role_type(["summary"]))
|
# self.summary_memory = Memory(messages=self.recall_memory.select_by_role_type(["summary"]))
|
||||||
|
|
||||||
|
recall_memory = Memory(**read_json_file(file_path))
|
||||||
|
self.recall_memory_dict[uuid_name] = recall_memory
|
||||||
|
self.current_memory_dict[uuid_name] = Memory(messages=recall_memory.filter_by_role_type(["summary"]))
|
||||||
|
self.summary_memory_dict[uuid_name] = Memory(messages=recall_memory.select_by_role_type(["summary"]))
|
||||||
|
else:
|
||||||
|
self.recall_memory_dict[uuid_name] = Memory(messages=[])
|
||||||
|
self.current_memory_dict[uuid_name] = Memory(messages=[])
|
||||||
|
self.summary_memory_dict[uuid_name] = Memory(messages=[])
|
||||||
|
|
||||||
def save_new_to_vs(self, messages: List[Message]):
|
def save_new_to_vs(self, messages: List[Message]):
|
||||||
if self.embed_config:
|
if self.embed_config:
|
||||||
vb_name = f"{self.user_name}/{self.unique_name}/{self.memory_type}"
|
vb_name = f"{self.user_name}/{self.unique_name}/{self.memory_type}"
|
||||||
# default to faiss, todo: add new vstype
|
# default to faiss, todo: add new vstype
|
||||||
vb = KBServiceFactory.get_service(vb_name, "faiss", self.embed_config, self.kb_root_path)
|
vb = KBServiceFactory.get_service(vb_name, "faiss", self.embed_config, self.kb_root_path)
|
||||||
embeddings = load_embeddings_from_path(self.embed_config.embed_model_path, self.embed_config.model_device,)
|
embeddings = load_embeddings_from_path(self.embed_config.embed_model_path, self.embed_config.model_device, self.embed_config.langchain_embeddings)
|
||||||
messages = [
|
messages = [
|
||||||
{k: v for k, v in m.dict().items() if k in self.save_message_keys}
|
{k: v for k, v in m.dict().items() if k in self.save_message_keys}
|
||||||
for m in messages]
|
for m in messages]
|
||||||
|
@ -311,23 +342,26 @@ class LocalMemoryManager(BaseMemoryManager):
|
||||||
vb.do_add_doc(docs, embeddings)
|
vb.do_add_doc(docs, embeddings)
|
||||||
|
|
||||||
def save_to_vs(self):
|
def save_to_vs(self):
|
||||||
vb_name = f"{self.user_name}/{self.unique_name}/{self.memory_type}"
|
'''only after load'''
|
||||||
# default to recreate a new vb
|
if self.embed_config:
|
||||||
vb = KBServiceFactory.get_service_by_name(vb_name, self.embed_config, self.kb_root_path)
|
vb_name = f"{self.user_name}/{self.unique_name}/{self.memory_type}"
|
||||||
if vb:
|
uuid_name = "_".join([self.user_name, self.unique_name, self.memory_type])
|
||||||
status = vb.clear_vs()
|
# default to recreate a new vb
|
||||||
# create_kb(vb_name, "faiss", embed_model)
|
vb = KBServiceFactory.get_service_by_name(vb_name, self.embed_config, self.kb_root_path)
|
||||||
|
if vb:
|
||||||
|
status = vb.clear_vs()
|
||||||
|
# create_kb(vb_name, "faiss", embed_model)
|
||||||
|
|
||||||
# default to faiss, todo: add new vstype
|
# default to faiss, todo: add new vstype
|
||||||
vb = KBServiceFactory.get_service(vb_name, "faiss", self.embed_config, self.kb_root_path)
|
vb = KBServiceFactory.get_service(vb_name, "faiss", self.embed_config, self.kb_root_path)
|
||||||
embeddings = load_embeddings_from_path(self.embed_config.embed_model_path, self.embed_config.model_device,)
|
embeddings = load_embeddings_from_path(self.embed_config.embed_model_path, self.embed_config.model_device, self.embed_config.langchain_embeddings)
|
||||||
messages = self.recall_memory.dict()
|
messages = self.recall_memory_dict[uuid_name].dict()
|
||||||
messages = [
|
messages = [
|
||||||
{kkk: vvv for kkk, vvv in vv.items() if kkk in self.save_message_keys}
|
{kkk: vvv for kkk, vvv in vv.items() if kkk in self.save_message_keys}
|
||||||
for k, v in messages.items() for vv in v]
|
for k, v in messages.items() for vv in v]
|
||||||
docs = [{"page_content": m["step_content"] or m["role_content"] or m["input_query"] or m["origin_query"], "metadata": m} for m in messages]
|
docs = [{"page_content": m["step_content"] or m["role_content"] or m["input_query"] or m["origin_query"], "metadata": m} for m in messages]
|
||||||
docs = [Document(**doc) for doc in docs]
|
docs = [Document(**doc) for doc in docs]
|
||||||
vb.do_add_doc(docs, embeddings)
|
vb.do_add_doc(docs, embeddings)
|
||||||
|
|
||||||
# def load_from_vs(self, embed_model=EMBEDDING_MODEL) -> Memory:
|
# def load_from_vs(self, embed_model=EMBEDDING_MODEL) -> Memory:
|
||||||
# vb_name = f"{self.user_name}/{self.unique_name}/{self.memory_type}"
|
# vb_name = f"{self.user_name}/{self.unique_name}/{self.memory_type}"
|
||||||
|
@ -338,7 +372,12 @@ class LocalMemoryManager(BaseMemoryManager):
|
||||||
# docs = vb.get_all_documents()
|
# docs = vb.get_all_documents()
|
||||||
# print(docs)
|
# print(docs)
|
||||||
|
|
||||||
def router_retrieval(self, text: str=None, datetime: str = None, n=5, top_k=5, retrieval_type: str = "embedding", **kwargs) -> List[Message]:
|
def get_memory_pool(self, user_name: str, ):
|
||||||
|
self.check_user_name(user_name)
|
||||||
|
uuid_name = "_".join([self.user_name, self.unique_name, self.memory_type])
|
||||||
|
return self.recall_memory_dict[uuid_name]
|
||||||
|
|
||||||
|
def router_retrieval(self, user_name: str = "default", text: str=None, datetime: str = None, n=5, top_k=5, retrieval_type: str = "embedding", **kwargs) -> List[Message]:
|
||||||
retrieval_func_dict = {
|
retrieval_func_dict = {
|
||||||
"embedding": self.embedding_retrieval, "text": self.text_retrieval, "datetime": self.datetime_retrieval
|
"embedding": self.embedding_retrieval, "text": self.text_retrieval, "datetime": self.datetime_retrieval
|
||||||
}
|
}
|
||||||
|
@ -356,20 +395,22 @@ class LocalMemoryManager(BaseMemoryManager):
|
||||||
#
|
#
|
||||||
return retrieval_func(**params)
|
return retrieval_func(**params)
|
||||||
|
|
||||||
def embedding_retrieval(self, text: str, top_k=1, score_threshold=1.0, **kwargs) -> List[Message]:
|
def embedding_retrieval(self, text: str, top_k=1, score_threshold=1.0, user_name: str = "default", **kwargs) -> List[Message]:
|
||||||
if text is None: return []
|
if text is None: return []
|
||||||
vb_name = f"{self.user_name}/{self.unique_name}/{self.memory_type}"
|
vb_name = f"{user_name}/{self.unique_name}/{self.memory_type}"
|
||||||
vb = KBServiceFactory.get_service(vb_name, "faiss", self.embed_config, self.kb_root_path)
|
vb = KBServiceFactory.get_service(vb_name, "faiss", self.embed_config, self.kb_root_path)
|
||||||
docs = vb.search_docs(text, top_k=top_k, score_threshold=score_threshold)
|
docs = vb.search_docs(text, top_k=top_k, score_threshold=score_threshold)
|
||||||
return [Message(**doc.metadata) for doc, score in docs]
|
return [Message(**doc.metadata) for doc, score in docs]
|
||||||
|
|
||||||
def text_retrieval(self, text: str, **kwargs) -> List[Message]:
|
def text_retrieval(self, text: str, user_name: str = "default", **kwargs) -> List[Message]:
|
||||||
if text is None: return []
|
if text is None: return []
|
||||||
return self._text_retrieval_from_cache(self.recall_memory.messages, text, score_threshold=0.3, topK=5, **kwargs)
|
uuid_name = "_".join([user_name, self.unique_name, self.memory_type])
|
||||||
|
return self._text_retrieval_from_cache(self.recall_memory_dict[uuid_name].messages, text, score_threshold=0.3, topK=5, **kwargs)
|
||||||
|
|
||||||
def datetime_retrieval(self, datetime: str, text: str = None, n: int = 5, **kwargs) -> List[Message]:
|
def datetime_retrieval(self, datetime: str, text: str = None, n: int = 5, user_name: str = "default", **kwargs) -> List[Message]:
|
||||||
if datetime is None: return []
|
if datetime is None: return []
|
||||||
return self._datetime_retrieval_from_cache(self.recall_memory.messages, datetime, text, n, **kwargs)
|
uuid_name = "_".join([user_name, self.unique_name, self.memory_type])
|
||||||
|
return self._datetime_retrieval_from_cache(self.recall_memory_dict[uuid_name].messages, datetime, text, n, **kwargs)
|
||||||
|
|
||||||
def _text_retrieval_from_cache(self, messages: List[Message], text: str = None, score_threshold=0.3, topK=5, tag_topK=5, **kwargs) -> List[Message]:
|
def _text_retrieval_from_cache(self, messages: List[Message], text: str = None, score_threshold=0.3, topK=5, tag_topK=5, **kwargs) -> List[Message]:
|
||||||
keywords = extract_tags(text, topK=tag_topK)
|
keywords = extract_tags(text, topK=tag_topK)
|
||||||
|
@ -427,4 +468,18 @@ class LocalMemoryManager(BaseMemoryManager):
|
||||||
)
|
)
|
||||||
summary_message.parsed_output_list.append({"summary": content})
|
summary_message.parsed_output_list.append({"summary": content})
|
||||||
newest_messages.insert(0, summary_message)
|
newest_messages.insert(0, summary_message)
|
||||||
return newest_messages
|
return newest_messages
|
||||||
|
|
||||||
|
def check_user_name(self, user_name: str):
|
||||||
|
# logger.debug(f"self.user_name is {self.user_name}")
|
||||||
|
if user_name != self.user_name:
|
||||||
|
self.user_name = user_name
|
||||||
|
self.init_vb()
|
||||||
|
|
||||||
|
uuid_name = "_".join([self.user_name, self.unique_name, self.memory_type])
|
||||||
|
if uuid_name not in self.recall_memory_dict:
|
||||||
|
self.recall_memory_dict[uuid_name] = Memory(messages=[])
|
||||||
|
self.current_memory_dict[uuid_name] = Memory(messages=[])
|
||||||
|
self.summary_memory_dict[uuid_name] = Memory(messages=[])
|
||||||
|
|
||||||
|
# logger.debug(f"self.user_name is {self.user_name}")
|
|
@ -1,16 +1,19 @@
|
||||||
import re, traceback, uuid, copy, json, os
|
import re, traceback, uuid, copy, json, os
|
||||||
|
from typing import Union
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
|
from langchain.schema import BaseRetriever
|
||||||
|
|
||||||
# from configs.server_config import SANDBOX_SERVER
|
|
||||||
# from configs.model_config import JUPYTER_WORK_PATH
|
|
||||||
from coagent.connector.schema import (
|
from coagent.connector.schema import (
|
||||||
Memory, Role, Message, ActionStatus, CodeDoc, Doc, LogVerboseEnum
|
Memory, Role, Message, ActionStatus, CodeDoc, Doc, LogVerboseEnum
|
||||||
)
|
)
|
||||||
|
from coagent.retrieval.base_retrieval import IMRertrieval
|
||||||
from coagent.connector.memory_manager import BaseMemoryManager
|
from coagent.connector.memory_manager import BaseMemoryManager
|
||||||
from coagent.tools import DDGSTool, DocRetrieval, CodeRetrieval
|
from coagent.tools import DDGSTool, DocRetrieval, CodeRetrieval
|
||||||
from coagent.sandbox import PyCodeBox, CodeBoxResponse
|
from coagent.sandbox import PyCodeBox, CodeBoxResponse
|
||||||
from coagent.llm_models.llm_config import LLMConfig, EmbedConfig
|
from coagent.llm_models.llm_config import LLMConfig, EmbedConfig
|
||||||
|
from coagent.base_configs.env_config import JUPYTER_WORK_PATH
|
||||||
|
|
||||||
from .utils import parse_dict_to_dict, parse_text_to_dict
|
from .utils import parse_dict_to_dict, parse_text_to_dict
|
||||||
|
|
||||||
|
|
||||||
|
@ -19,10 +22,13 @@ class MessageUtils:
|
||||||
self,
|
self,
|
||||||
role: Role = None,
|
role: Role = None,
|
||||||
sandbox_server: dict = {},
|
sandbox_server: dict = {},
|
||||||
jupyter_work_path: str = "./",
|
jupyter_work_path: str = JUPYTER_WORK_PATH,
|
||||||
embed_config: EmbedConfig = None,
|
embed_config: EmbedConfig = None,
|
||||||
llm_config: LLMConfig = None,
|
llm_config: LLMConfig = None,
|
||||||
kb_root_path: str = "",
|
kb_root_path: str = "",
|
||||||
|
doc_retrieval: Union[BaseRetriever, IMRertrieval] = None,
|
||||||
|
code_retrieval: IMRertrieval = None,
|
||||||
|
search_retrieval: IMRertrieval = None,
|
||||||
log_verbose: str = "0"
|
log_verbose: str = "0"
|
||||||
) -> None:
|
) -> None:
|
||||||
self.role = role
|
self.role = role
|
||||||
|
@ -31,6 +37,9 @@ class MessageUtils:
|
||||||
self.embed_config = embed_config
|
self.embed_config = embed_config
|
||||||
self.llm_config = llm_config
|
self.llm_config = llm_config
|
||||||
self.kb_root_path = kb_root_path
|
self.kb_root_path = kb_root_path
|
||||||
|
self.doc_retrieval = doc_retrieval
|
||||||
|
self.code_retrieval = code_retrieval
|
||||||
|
self.search_retrieval = search_retrieval
|
||||||
self.codebox = PyCodeBox(
|
self.codebox = PyCodeBox(
|
||||||
remote_url=self.sandbox_server.get("url", "http://127.0.0.1:5050"),
|
remote_url=self.sandbox_server.get("url", "http://127.0.0.1:5050"),
|
||||||
remote_ip=self.sandbox_server.get("host", "http://127.0.0.1"),
|
remote_ip=self.sandbox_server.get("host", "http://127.0.0.1"),
|
||||||
|
@ -44,6 +53,7 @@ class MessageUtils:
|
||||||
self.log_verbose = os.environ.get("log_verbose", "0") or log_verbose
|
self.log_verbose = os.environ.get("log_verbose", "0") or log_verbose
|
||||||
|
|
||||||
def inherit_extrainfo(self, input_message: Message, output_message: Message):
|
def inherit_extrainfo(self, input_message: Message, output_message: Message):
|
||||||
|
output_message.user_name = input_message.user_name
|
||||||
output_message.db_docs = input_message.db_docs
|
output_message.db_docs = input_message.db_docs
|
||||||
output_message.search_docs = input_message.search_docs
|
output_message.search_docs = input_message.search_docs
|
||||||
output_message.code_docs = input_message.code_docs
|
output_message.code_docs = input_message.code_docs
|
||||||
|
@ -116,18 +126,45 @@ class MessageUtils:
|
||||||
knowledge_basename = message.doc_engine_name
|
knowledge_basename = message.doc_engine_name
|
||||||
top_k = message.top_k
|
top_k = message.top_k
|
||||||
score_threshold = message.score_threshold
|
score_threshold = message.score_threshold
|
||||||
if knowledge_basename:
|
if self.doc_retrieval:
|
||||||
docs = DocRetrieval.run(query, knowledge_basename, top_k, score_threshold, self.embed_config, self.kb_root_path)
|
if isinstance(self.doc_retrieval, BaseRetriever):
|
||||||
|
docs = self.doc_retrieval.get_relevant_documents(query)
|
||||||
|
else:
|
||||||
|
# docs = self.doc_retrieval.run(query, search_top=message.top_k, score_threshold=message.score_threshold,)
|
||||||
|
docs = self.doc_retrieval.run(query)
|
||||||
|
docs = [
|
||||||
|
{"index": idx, "snippet": doc.page_content, "title": doc.metadata.get("title_prefix", ""), "link": doc.metadata.get("url", "")}
|
||||||
|
for idx, doc in enumerate(docs)
|
||||||
|
]
|
||||||
message.db_docs = [Doc(**doc) for doc in docs]
|
message.db_docs = [Doc(**doc) for doc in docs]
|
||||||
|
else:
|
||||||
|
if knowledge_basename:
|
||||||
|
docs = DocRetrieval.run(query, knowledge_basename, top_k, score_threshold, self.embed_config, self.kb_root_path)
|
||||||
|
message.db_docs = [Doc(**doc) for doc in docs]
|
||||||
return message
|
return message
|
||||||
|
|
||||||
def get_code_retrieval(self, message: Message) -> Message:
|
def get_code_retrieval(self, message: Message) -> Message:
|
||||||
query = message.input_query
|
query = message.role_content
|
||||||
code_engine_name = message.code_engine_name
|
code_engine_name = message.code_engine_name
|
||||||
history_node_list = message.history_node_list
|
history_node_list = message.history_node_list
|
||||||
code_docs = CodeRetrieval.run(code_engine_name, query, code_limit=message.top_k, history_node_list=history_node_list, search_type=message.cb_search_type,
|
|
||||||
llm_config=self.llm_config, embed_config=self.embed_config,)
|
use_nh = message.use_nh
|
||||||
message.code_docs = [CodeDoc(**doc) for doc in code_docs]
|
local_graph_path = message.local_graph_path
|
||||||
|
|
||||||
|
if self.code_retrieval:
|
||||||
|
code_docs = self.code_retrieval.run(
|
||||||
|
query, history_node_list=history_node_list, search_type=message.cb_search_type,
|
||||||
|
code_limit=1
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
code_docs = CodeRetrieval.run(code_engine_name, query, code_limit=message.top_k, history_node_list=history_node_list, search_type=message.cb_search_type,
|
||||||
|
llm_config=self.llm_config, embed_config=self.embed_config,
|
||||||
|
use_nh=use_nh, local_graph_path=local_graph_path)
|
||||||
|
|
||||||
|
message.code_docs = [CodeDoc(**doc) for doc in code_docs]
|
||||||
|
|
||||||
|
# related_nodes = [doc.get_related_node() for idx, doc in enumerate(message.code_docs) if idx==0],
|
||||||
|
# history_node_list.extend([node[0] for node in related_nodes])
|
||||||
return message
|
return message
|
||||||
|
|
||||||
def get_tool_retrieval(self, message: Message) -> Message:
|
def get_tool_retrieval(self, message: Message) -> Message:
|
||||||
|
@ -160,6 +197,7 @@ class MessageUtils:
|
||||||
if code_answer.code_exe_type == "error" else f"The return information after executing the above code is {code_answer.code_exe_response}.\n"
|
if code_answer.code_exe_type == "error" else f"The return information after executing the above code is {code_answer.code_exe_response}.\n"
|
||||||
|
|
||||||
observation_message = Message(
|
observation_message = Message(
|
||||||
|
user_name=message.user_name,
|
||||||
role_name="observation",
|
role_name="observation",
|
||||||
role_type="function", #self.role.role_type,
|
role_type="function", #self.role.role_type,
|
||||||
role_content="",
|
role_content="",
|
||||||
|
@ -190,6 +228,7 @@ class MessageUtils:
|
||||||
def tool_step(self, message: Message) -> Message:
|
def tool_step(self, message: Message) -> Message:
|
||||||
'''execute tool'''
|
'''execute tool'''
|
||||||
observation_message = Message(
|
observation_message = Message(
|
||||||
|
user_name=message.user_name,
|
||||||
role_name="observation",
|
role_name="observation",
|
||||||
role_type="function", #self.role.role_type,
|
role_type="function", #self.role.role_type,
|
||||||
role_content="\n**Observation:** there is no tool can execute\n",
|
role_content="\n**Observation:** there is no tool can execute\n",
|
||||||
|
@ -226,7 +265,7 @@ class MessageUtils:
|
||||||
return message, observation_message
|
return message, observation_message
|
||||||
|
|
||||||
def parser(self, message: Message) -> Message:
|
def parser(self, message: Message) -> Message:
|
||||||
''''''
|
'''parse llm output into dict'''
|
||||||
content = message.role_content
|
content = message.role_content
|
||||||
# parse start
|
# parse start
|
||||||
parsed_dict = parse_text_to_dict(content)
|
parsed_dict = parse_text_to_dict(content)
|
||||||
|
|
|
@ -5,6 +5,8 @@ import importlib
|
||||||
import copy
|
import copy
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
|
from langchain.schema import BaseRetriever
|
||||||
|
|
||||||
from coagent.connector.agents import BaseAgent
|
from coagent.connector.agents import BaseAgent
|
||||||
from coagent.connector.chains import BaseChain
|
from coagent.connector.chains import BaseChain
|
||||||
from coagent.connector.schema import (
|
from coagent.connector.schema import (
|
||||||
|
@ -18,9 +20,6 @@ from coagent.connector.message_process import MessageUtils
|
||||||
from coagent.llm_models.llm_config import EmbedConfig, LLMConfig
|
from coagent.llm_models.llm_config import EmbedConfig, LLMConfig
|
||||||
from coagent.base_configs.env_config import JUPYTER_WORK_PATH, KB_ROOT_PATH
|
from coagent.base_configs.env_config import JUPYTER_WORK_PATH, KB_ROOT_PATH
|
||||||
|
|
||||||
# from configs.model_config import JUPYTER_WORK_PATH, KB_ROOT_PATH
|
|
||||||
# from configs.server_config import SANDBOX_SERVER
|
|
||||||
|
|
||||||
|
|
||||||
role_configs = load_role_configs(AGETN_CONFIGS)
|
role_configs = load_role_configs(AGETN_CONFIGS)
|
||||||
chain_configs = load_chain_configs(CHAIN_CONFIGS)
|
chain_configs = load_chain_configs(CHAIN_CONFIGS)
|
||||||
|
@ -39,20 +38,24 @@ class BasePhase:
|
||||||
kb_root_path: str = KB_ROOT_PATH,
|
kb_root_path: str = KB_ROOT_PATH,
|
||||||
jupyter_work_path: str = JUPYTER_WORK_PATH,
|
jupyter_work_path: str = JUPYTER_WORK_PATH,
|
||||||
sandbox_server: dict = {},
|
sandbox_server: dict = {},
|
||||||
embed_config: EmbedConfig = EmbedConfig(),
|
embed_config: EmbedConfig = None,
|
||||||
llm_config: LLMConfig = LLMConfig(),
|
llm_config: LLMConfig = None,
|
||||||
task: Task = None,
|
task: Task = None,
|
||||||
base_phase_config: Union[dict, str] = PHASE_CONFIGS,
|
base_phase_config: Union[dict, str] = PHASE_CONFIGS,
|
||||||
base_chain_config: Union[dict, str] = CHAIN_CONFIGS,
|
base_chain_config: Union[dict, str] = CHAIN_CONFIGS,
|
||||||
base_role_config: Union[dict, str] = AGETN_CONFIGS,
|
base_role_config: Union[dict, str] = AGETN_CONFIGS,
|
||||||
|
chains: List[BaseChain] = [],
|
||||||
|
doc_retrieval: Union[BaseRetriever] = None,
|
||||||
|
code_retrieval = None,
|
||||||
|
search_retrieval = None,
|
||||||
log_verbose: str = "0"
|
log_verbose: str = "0"
|
||||||
) -> None:
|
) -> None:
|
||||||
#
|
#
|
||||||
self.phase_name = phase_name
|
self.phase_name = phase_name
|
||||||
self.do_summary = False
|
self.do_summary = False
|
||||||
self.do_search = False
|
self.do_search = search_retrieval is not None
|
||||||
self.do_code_retrieval = False
|
self.do_code_retrieval = code_retrieval is not None
|
||||||
self.do_doc_retrieval = False
|
self.do_doc_retrieval = doc_retrieval is not None
|
||||||
self.do_tool_retrieval = False
|
self.do_tool_retrieval = False
|
||||||
# memory_pool dont have specific order
|
# memory_pool dont have specific order
|
||||||
# self.memory_pool = Memory(messages=[])
|
# self.memory_pool = Memory(messages=[])
|
||||||
|
@ -62,12 +65,15 @@ class BasePhase:
|
||||||
self.jupyter_work_path = jupyter_work_path
|
self.jupyter_work_path = jupyter_work_path
|
||||||
self.kb_root_path = kb_root_path
|
self.kb_root_path = kb_root_path
|
||||||
self.log_verbose = max(os.environ.get("log_verbose", "0"), log_verbose)
|
self.log_verbose = max(os.environ.get("log_verbose", "0"), log_verbose)
|
||||||
|
# TODO透传
|
||||||
self.message_utils = MessageUtils(None, sandbox_server, jupyter_work_path, embed_config, llm_config, kb_root_path, log_verbose)
|
self.doc_retrieval = doc_retrieval
|
||||||
|
self.code_retrieval = code_retrieval
|
||||||
|
self.search_retrieval = search_retrieval
|
||||||
|
self.message_utils = MessageUtils(None, sandbox_server, jupyter_work_path, embed_config, llm_config, kb_root_path, doc_retrieval, code_retrieval, search_retrieval, log_verbose)
|
||||||
self.global_memory = Memory(messages=[])
|
self.global_memory = Memory(messages=[])
|
||||||
self.phase_memory: List[Memory] = []
|
self.phase_memory: List[Memory] = []
|
||||||
# according phase name to init the phase contains
|
# according phase name to init the phase contains
|
||||||
self.chains: List[BaseChain] = self.init_chains(
|
self.chains: List[BaseChain] = chains if chains else self.init_chains(
|
||||||
phase_name,
|
phase_name,
|
||||||
phase_config,
|
phase_config,
|
||||||
task=task,
|
task=task,
|
||||||
|
@ -90,7 +96,9 @@ class BasePhase:
|
||||||
kb_root_path=kb_root_path
|
kb_root_path=kb_root_path
|
||||||
)
|
)
|
||||||
|
|
||||||
def astep(self, query: Message, history: Memory = None) -> Tuple[Message, Memory]:
|
def astep(self, query: Message, history: Memory = None, reinit_memory=False) -> Tuple[Message, Memory]:
|
||||||
|
if reinit_memory:
|
||||||
|
self.memory_manager.re_init(reinit_memory)
|
||||||
self.memory_manager.append(query)
|
self.memory_manager.append(query)
|
||||||
summary_message = None
|
summary_message = None
|
||||||
chain_message = Memory(messages=[])
|
chain_message = Memory(messages=[])
|
||||||
|
@ -139,8 +147,8 @@ class BasePhase:
|
||||||
message.role_name = self.phase_name
|
message.role_name = self.phase_name
|
||||||
yield message, local_phase_memory
|
yield message, local_phase_memory
|
||||||
|
|
||||||
def step(self, query: Message, history: Memory = None) -> Tuple[Message, Memory]:
|
def step(self, query: Message, history: Memory = None, reinit_memory=False) -> Tuple[Message, Memory]:
|
||||||
for message, local_phase_memory in self.astep(query, history=history):
|
for message, local_phase_memory in self.astep(query, history=history, reinit_memory=reinit_memory):
|
||||||
pass
|
pass
|
||||||
return message, local_phase_memory
|
return message, local_phase_memory
|
||||||
|
|
||||||
|
@ -194,6 +202,9 @@ class BasePhase:
|
||||||
sandbox_server=self.sandbox_server,
|
sandbox_server=self.sandbox_server,
|
||||||
jupyter_work_path=self.jupyter_work_path,
|
jupyter_work_path=self.jupyter_work_path,
|
||||||
kb_root_path=self.kb_root_path,
|
kb_root_path=self.kb_root_path,
|
||||||
|
doc_retrieval=self.doc_retrieval,
|
||||||
|
code_retrieval=self.code_retrieval,
|
||||||
|
search_retrieval=self.search_retrieval,
|
||||||
log_verbose=self.log_verbose
|
log_verbose=self.log_verbose
|
||||||
)
|
)
|
||||||
if agent_config.role.agent_type == "SelectorAgent":
|
if agent_config.role.agent_type == "SelectorAgent":
|
||||||
|
@ -205,7 +216,7 @@ class BasePhase:
|
||||||
group_base_agent = baseAgent(
|
group_base_agent = baseAgent(
|
||||||
role=group_agent_config.role,
|
role=group_agent_config.role,
|
||||||
prompt_config = group_agent_config.prompt_config,
|
prompt_config = group_agent_config.prompt_config,
|
||||||
prompt_manager_type=agent_config.prompt_manager_type,
|
prompt_manager_type=group_agent_config.prompt_manager_type,
|
||||||
task = task,
|
task = task,
|
||||||
memory = memory,
|
memory = memory,
|
||||||
chat_turn=group_agent_config.chat_turn,
|
chat_turn=group_agent_config.chat_turn,
|
||||||
|
@ -216,6 +227,9 @@ class BasePhase:
|
||||||
sandbox_server=self.sandbox_server,
|
sandbox_server=self.sandbox_server,
|
||||||
jupyter_work_path=self.jupyter_work_path,
|
jupyter_work_path=self.jupyter_work_path,
|
||||||
kb_root_path=self.kb_root_path,
|
kb_root_path=self.kb_root_path,
|
||||||
|
doc_retrieval=self.doc_retrieval,
|
||||||
|
code_retrieval=self.code_retrieval,
|
||||||
|
search_retrieval=self.search_retrieval,
|
||||||
log_verbose=self.log_verbose
|
log_verbose=self.log_verbose
|
||||||
)
|
)
|
||||||
base_agent.group_agents.append(group_base_agent)
|
base_agent.group_agents.append(group_base_agent)
|
||||||
|
@ -223,13 +237,16 @@ class BasePhase:
|
||||||
agents.append(base_agent)
|
agents.append(base_agent)
|
||||||
|
|
||||||
chain_instance = BaseChain(
|
chain_instance = BaseChain(
|
||||||
agents, chain_config.chat_turn,
|
chain_config,
|
||||||
do_checker=chain_configs[chain_name].do_checker,
|
agents,
|
||||||
jupyter_work_path=self.jupyter_work_path,
|
jupyter_work_path=self.jupyter_work_path,
|
||||||
sandbox_server=self.sandbox_server,
|
sandbox_server=self.sandbox_server,
|
||||||
embed_config=self.embed_config,
|
embed_config=self.embed_config,
|
||||||
llm_config=self.llm_config,
|
llm_config=self.llm_config,
|
||||||
kb_root_path=self.kb_root_path,
|
kb_root_path=self.kb_root_path,
|
||||||
|
doc_retrieval=self.doc_retrieval,
|
||||||
|
code_retrieval=self.code_retrieval,
|
||||||
|
search_retrieval=self.search_retrieval,
|
||||||
log_verbose=self.log_verbose
|
log_verbose=self.log_verbose
|
||||||
)
|
)
|
||||||
chains.append(chain_instance)
|
chains.append(chain_instance)
|
||||||
|
|
|
@ -0,0 +1,2 @@
|
||||||
|
from .prompt_manager import PromptManager
|
||||||
|
from .extend_manager import *
|
|
@ -0,0 +1,45 @@
|
||||||
|
|
||||||
|
from coagent.connector.schema import Message
|
||||||
|
from .prompt_manager import PromptManager
|
||||||
|
|
||||||
|
|
||||||
|
class Code2DocPM(PromptManager):
|
||||||
|
def handle_code_snippet(self, **kwargs) -> str:
|
||||||
|
if 'previous_agent_message' not in kwargs:
|
||||||
|
return ""
|
||||||
|
previous_agent_message: Message = kwargs['previous_agent_message']
|
||||||
|
code_snippet = previous_agent_message.customed_kargs.get("Code Snippet", "")
|
||||||
|
current_vertex = previous_agent_message.customed_kargs.get("Current_Vertex", "")
|
||||||
|
instruction = "A segment of code that contains the function or method to be documented.\n"
|
||||||
|
return instruction + "\n" + f"name: {current_vertex}\n{code_snippet}"
|
||||||
|
|
||||||
|
def handle_specific_objective(self, **kwargs) -> str:
|
||||||
|
if 'previous_agent_message' not in kwargs:
|
||||||
|
return ""
|
||||||
|
previous_agent_message: Message = kwargs['previous_agent_message']
|
||||||
|
specific_objective = previous_agent_message.parsed_output.get("Code Path")
|
||||||
|
|
||||||
|
instruction = "Provide the code path of the function or method you wish to document.\n"
|
||||||
|
s = instruction + f"\n{specific_objective}"
|
||||||
|
return s
|
||||||
|
|
||||||
|
|
||||||
|
class CodeRetrievalPM(PromptManager):
|
||||||
|
def handle_code_snippet(self, **kwargs) -> str:
|
||||||
|
if 'previous_agent_message' not in kwargs:
|
||||||
|
return ""
|
||||||
|
previous_agent_message: Message = kwargs['previous_agent_message']
|
||||||
|
code_snippet = previous_agent_message.customed_kargs.get("Code Snippet", "")
|
||||||
|
current_vertex = previous_agent_message.customed_kargs.get("Current_Vertex", "")
|
||||||
|
instruction = "the initial Code or objective that the user wanted to achieve"
|
||||||
|
return instruction + "\n" + f"name: {current_vertex}\n{code_snippet}"
|
||||||
|
|
||||||
|
def handle_retrieval_codes(self, **kwargs) -> str:
|
||||||
|
if 'previous_agent_message' not in kwargs:
|
||||||
|
return ""
|
||||||
|
previous_agent_message: Message = kwargs['previous_agent_message']
|
||||||
|
Retrieval_Codes = previous_agent_message.customed_kargs["Retrieval_Codes"]
|
||||||
|
Relative_vertex = previous_agent_message.customed_kargs["Relative_vertex"]
|
||||||
|
instruction = "the initial Code or objective that the user wanted to achieve"
|
||||||
|
s = instruction + "\n" + "\n".join([f"name: {vertext}\n{code}" for vertext, code in zip(Relative_vertex, Retrieval_Codes)])
|
||||||
|
return s
|
|
@ -0,0 +1,353 @@
|
||||||
|
import random
|
||||||
|
from textwrap import dedent
|
||||||
|
import copy
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
from langchain.agents.tools import Tool
|
||||||
|
|
||||||
|
from coagent.connector.schema import Memory, Message
|
||||||
|
from coagent.connector.utils import extract_section, parse_section
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class PromptManager:
|
||||||
|
def __init__(self, role_prompt="", prompt_config=None, monitored_agents=[], monitored_fields=[]):
|
||||||
|
self.role_prompt = role_prompt
|
||||||
|
self.monitored_agents = monitored_agents
|
||||||
|
self.monitored_fields = monitored_fields
|
||||||
|
self.field_handlers = {}
|
||||||
|
self.context_handlers = {}
|
||||||
|
self.field_order = [] # 用于普通字段的顺序
|
||||||
|
self.context_order = [] # 单独维护上下文字段的顺序
|
||||||
|
self.field_descriptions = {}
|
||||||
|
self.omit_if_empty_flags = {}
|
||||||
|
self.context_title = "### Context Data\n\n"
|
||||||
|
|
||||||
|
self.prompt_config = prompt_config
|
||||||
|
if self.prompt_config:
|
||||||
|
self.register_fields_from_config()
|
||||||
|
|
||||||
|
def register_field(self, field_name, function=None, title=None, description=None, is_context=True, omit_if_empty=True):
|
||||||
|
"""
|
||||||
|
注册一个新的字段及其处理函数。
|
||||||
|
Args:
|
||||||
|
field_name (str): 字段名称。
|
||||||
|
function (callable): 处理字段数据的函数。
|
||||||
|
title (str, optional): 字段的自定义标题(可选)。
|
||||||
|
description (str, optional): 字段的描述(可选,可以是几句话)。
|
||||||
|
is_context (bool, optional): 指示该字段是否为上下文字段。
|
||||||
|
omit_if_empty (bool, optional): 如果数据为空,是否省略该字段。
|
||||||
|
"""
|
||||||
|
if not function:
|
||||||
|
function = self.handle_custom_data
|
||||||
|
|
||||||
|
# Register the handler function based on context flag
|
||||||
|
if is_context:
|
||||||
|
self.context_handlers[field_name] = function
|
||||||
|
else:
|
||||||
|
self.field_handlers[field_name] = function
|
||||||
|
|
||||||
|
# Store the custom title if provided and adjust the title prefix based on context
|
||||||
|
title_prefix = "####" if is_context else "###"
|
||||||
|
if title is not None:
|
||||||
|
self.field_descriptions[field_name] = f"{title_prefix} {title}\n\n"
|
||||||
|
elif description is not None:
|
||||||
|
# If title is not provided but description is, use description as title
|
||||||
|
self.field_descriptions[field_name] = f"{title_prefix} {field_name.replace('_', ' ').title()}\n\n{description}\n\n"
|
||||||
|
else:
|
||||||
|
# If neither title nor description is provided, use the field name as title
|
||||||
|
self.field_descriptions[field_name] = f"{title_prefix} {field_name.replace('_', ' ').title()}\n\n"
|
||||||
|
|
||||||
|
# Store the omit_if_empty flag for this field
|
||||||
|
self.omit_if_empty_flags[field_name] = omit_if_empty
|
||||||
|
|
||||||
|
if is_context and field_name != 'context_placeholder':
|
||||||
|
self.context_handlers[field_name] = function
|
||||||
|
self.context_order.append(field_name)
|
||||||
|
else:
|
||||||
|
self.field_handlers[field_name] = function
|
||||||
|
self.field_order.append(field_name)
|
||||||
|
|
||||||
|
def generate_full_prompt(self, **kwargs):
|
||||||
|
full_prompt = []
|
||||||
|
context_prompts = [] # 用于收集上下文内容
|
||||||
|
is_pre_print = kwargs.get("is_pre_print", False) # 用于强制打印所有prompt 字段信息,不管有没有空
|
||||||
|
|
||||||
|
# 先处理上下文字段
|
||||||
|
for field_name in self.context_order:
|
||||||
|
handler = self.context_handlers[field_name]
|
||||||
|
processed_prompt = handler(**kwargs)
|
||||||
|
# Check if the field should be omitted when empty
|
||||||
|
if self.omit_if_empty_flags.get(field_name, False) and not processed_prompt and not is_pre_print:
|
||||||
|
continue # Skip this field
|
||||||
|
title_or_description = self.field_descriptions.get(field_name, f"#### {field_name.replace('_', ' ').title()}\n\n")
|
||||||
|
context_prompts.append(title_or_description + processed_prompt + '\n\n')
|
||||||
|
|
||||||
|
# 处理普通字段,同时查找 context_placeholder 的位置
|
||||||
|
for field_name in self.field_order:
|
||||||
|
if field_name == 'context_placeholder':
|
||||||
|
# 在 context_placeholder 的位置插入上下文数据
|
||||||
|
full_prompt.append(self.context_title) # 添加上下文部分的大标题
|
||||||
|
full_prompt.extend(context_prompts) # 添加收集的上下文内容
|
||||||
|
else:
|
||||||
|
handler = self.field_handlers[field_name]
|
||||||
|
processed_prompt = handler(**kwargs)
|
||||||
|
# Check if the field should be omitted when empty
|
||||||
|
if self.omit_if_empty_flags.get(field_name, False) and not processed_prompt and not is_pre_print:
|
||||||
|
continue # Skip this field
|
||||||
|
title_or_description = self.field_descriptions.get(field_name, f"### {field_name.replace('_', ' ').title()}\n\n")
|
||||||
|
full_prompt.append(title_or_description + processed_prompt + '\n\n')
|
||||||
|
|
||||||
|
# 返回完整的提示,移除尾部的空行
|
||||||
|
return ''.join(full_prompt).rstrip('\n')
|
||||||
|
|
||||||
|
def pre_print(self, **kwargs):
|
||||||
|
kwargs.update({"is_pre_print": True})
|
||||||
|
prompt = self.generate_full_prompt(**kwargs)
|
||||||
|
|
||||||
|
input_keys = parse_section(self.role_prompt, 'Response Output Format')
|
||||||
|
llm_predict = "\n".join([f"**{k}:**" for k in input_keys])
|
||||||
|
return prompt + "\n\n" + "#"*19 + "\n<<<<LLM PREDICT>>>>\n" + "#"*19 + f"\n\n{llm_predict}\n"
|
||||||
|
|
||||||
|
def handle_custom_data(self, **kwargs):
|
||||||
|
return ""
|
||||||
|
|
||||||
|
def handle_tool_data(self, **kwargs):
|
||||||
|
if 'previous_agent_message' not in kwargs:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
previous_agent_message = kwargs.get('previous_agent_message')
|
||||||
|
tools: list[Tool] = previous_agent_message.tools
|
||||||
|
|
||||||
|
if not tools:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
tool_strings = []
|
||||||
|
for tool in tools:
|
||||||
|
args_str = f'args: {str(tool.args)}' if tool.args_schema else ""
|
||||||
|
tool_strings.append(f"{tool.name}: {tool.description}, {args_str}")
|
||||||
|
formatted_tools = "\n".join(tool_strings)
|
||||||
|
|
||||||
|
tool_names = ", ".join([tool.name for tool in tools])
|
||||||
|
|
||||||
|
tool_prompt = dedent(f"""
|
||||||
|
Below is a list of tools that are available for your use:
|
||||||
|
{formatted_tools}
|
||||||
|
|
||||||
|
valid "tool_name" value is:
|
||||||
|
{tool_names}
|
||||||
|
""")
|
||||||
|
|
||||||
|
return tool_prompt
|
||||||
|
|
||||||
|
def handle_agent_data(self, **kwargs):
|
||||||
|
if 'agents' not in kwargs:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
agents = kwargs.get('agents')
|
||||||
|
random.shuffle(agents)
|
||||||
|
agent_names = ", ".join([f'{agent.role.role_name}' for agent in agents])
|
||||||
|
agent_descs = []
|
||||||
|
for agent in agents:
|
||||||
|
role_desc = agent.role.role_prompt.split("####")[1]
|
||||||
|
while "\n\n" in role_desc:
|
||||||
|
role_desc = role_desc.replace("\n\n", "\n")
|
||||||
|
role_desc = role_desc.replace("\n", ",")
|
||||||
|
|
||||||
|
agent_descs.append(f'"role name: {agent.role.role_name}\nrole description: {role_desc}"')
|
||||||
|
|
||||||
|
agents = "\n".join(agent_descs)
|
||||||
|
agent_prompt = f'''
|
||||||
|
Please ensure your selection is one of the listed roles. Available roles for selection:
|
||||||
|
{agents}
|
||||||
|
Please ensure select the Role from agent names, such as {agent_names}'''
|
||||||
|
|
||||||
|
return dedent(agent_prompt)
|
||||||
|
|
||||||
|
def handle_doc_info(self, **kwargs) -> str:
|
||||||
|
if 'previous_agent_message' not in kwargs:
|
||||||
|
return ""
|
||||||
|
previous_agent_message: Message = kwargs.get('previous_agent_message')
|
||||||
|
db_docs = previous_agent_message.db_docs
|
||||||
|
search_docs = previous_agent_message.search_docs
|
||||||
|
code_cocs = previous_agent_message.code_docs
|
||||||
|
doc_infos = "\n".join([doc.get_snippet() for doc in db_docs] + [doc.get_snippet() for doc in search_docs] +
|
||||||
|
[doc.get_code() for doc in code_cocs])
|
||||||
|
return doc_infos
|
||||||
|
|
||||||
|
def handle_session_records(self, **kwargs) -> str:
|
||||||
|
|
||||||
|
memory_pool: Memory = kwargs.get('memory_pool', Memory(messages=[]))
|
||||||
|
memory_pool = self.select_memory_by_agent_name(memory_pool)
|
||||||
|
memory_pool = self.select_memory_by_parsed_key(memory_pool)
|
||||||
|
|
||||||
|
return memory_pool.to_str_messages(content_key="parsed_output_list", with_tag=True)
|
||||||
|
|
||||||
|
def handle_current_plan(self, **kwargs) -> str:
|
||||||
|
if 'previous_agent_message' not in kwargs:
|
||||||
|
return ""
|
||||||
|
previous_agent_message = kwargs['previous_agent_message']
|
||||||
|
return previous_agent_message.parsed_output.get("CURRENT_STEP", "")
|
||||||
|
|
||||||
|
def handle_agent_profile(self, **kwargs) -> str:
|
||||||
|
return extract_section(self.role_prompt, 'Agent Profile')
|
||||||
|
|
||||||
|
def handle_output_format(self, **kwargs) -> str:
|
||||||
|
return extract_section(self.role_prompt, 'Response Output Format')
|
||||||
|
|
||||||
|
def handle_response(self, **kwargs) -> str:
|
||||||
|
if 'react_memory' not in kwargs:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
react_memory = kwargs.get('react_memory', Memory(messages=[]))
|
||||||
|
if react_memory is None:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
return "\n".join(["\n".join([f"**{k}:**\n{v}" for k,v in _dict.items()]) for _dict in react_memory.get_parserd_output()])
|
||||||
|
|
||||||
|
def handle_task_records(self, **kwargs) -> str:
|
||||||
|
if 'task_memory' not in kwargs:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
task_memory: Memory = kwargs.get('task_memory', Memory(messages=[]))
|
||||||
|
if task_memory is None:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
return "\n".join(["\n".join([f"**{k}:**\n{v}" for k,v in _dict.items() if k not in ["CURRENT_STEP"]]) for _dict in task_memory.get_parserd_output()])
|
||||||
|
|
||||||
|
def handle_previous_message(self, message: Message) -> str:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def handle_message_by_role_name(self, message: Message) -> str:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def handle_message_by_role_type(self, message: Message) -> str:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def handle_current_agent_react_message(self, message: Message) -> str:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def extract_codedoc_info_for_prompt(self, message: Message) -> str:
|
||||||
|
code_docs = message.code_docs
|
||||||
|
doc_infos = "\n".join([doc.get_code() for doc in code_docs])
|
||||||
|
return doc_infos
|
||||||
|
|
||||||
|
def select_memory_by_parsed_key(self, memory: Memory) -> Memory:
|
||||||
|
return Memory(
|
||||||
|
messages=[self.select_message_by_parsed_key(message) for message in memory.messages
|
||||||
|
if self.select_message_by_parsed_key(message) is not None]
|
||||||
|
)
|
||||||
|
|
||||||
|
def select_memory_by_agent_name(self, memory: Memory) -> Memory:
|
||||||
|
return Memory(
|
||||||
|
messages=[self.select_message_by_agent_name(message) for message in memory.messages
|
||||||
|
if self.select_message_by_agent_name(message) is not None]
|
||||||
|
)
|
||||||
|
|
||||||
|
def select_message_by_agent_name(self, message: Message) -> Message:
|
||||||
|
# assume we focus all agents
|
||||||
|
if self.monitored_agents == []:
|
||||||
|
return message
|
||||||
|
return None if message is None or message.role_name not in self.monitored_agents else self.select_message_by_parsed_key(message)
|
||||||
|
|
||||||
|
def select_message_by_parsed_key(self, message: Message) -> Message:
|
||||||
|
# assume we focus all key contents
|
||||||
|
if message is None:
|
||||||
|
return message
|
||||||
|
|
||||||
|
if self.monitored_fields == []:
|
||||||
|
return message
|
||||||
|
|
||||||
|
message_c = copy.deepcopy(message)
|
||||||
|
message_c.parsed_output = {k: v for k,v in message_c.parsed_output.items() if k in self.monitored_fields}
|
||||||
|
message_c.parsed_output_list = [{k: v for k,v in parsed_output.items() if k in self.monitored_fields} for parsed_output in message_c.parsed_output_list]
|
||||||
|
return message_c
|
||||||
|
|
||||||
|
def get_memory(self, content_key="role_content"):
|
||||||
|
return self.memory.to_tuple_messages(content_key="step_content")
|
||||||
|
|
||||||
|
def get_memory_str(self, content_key="role_content"):
|
||||||
|
return "\n".join([": ".join(i) for i in self.memory.to_tuple_messages(content_key="step_content")])
|
||||||
|
|
||||||
|
def register_fields_from_config(self):
|
||||||
|
|
||||||
|
for prompt_field in self.prompt_config:
|
||||||
|
|
||||||
|
function_name = prompt_field.function_name
|
||||||
|
# 检查function_name是否是self的一个方法
|
||||||
|
if function_name and hasattr(self, function_name):
|
||||||
|
function = getattr(self, function_name)
|
||||||
|
else:
|
||||||
|
function = self.handle_custom_data
|
||||||
|
|
||||||
|
self.register_field(prompt_field.field_name,
|
||||||
|
function=function,
|
||||||
|
title=prompt_field.title,
|
||||||
|
description=prompt_field.description,
|
||||||
|
is_context=prompt_field.is_context,
|
||||||
|
omit_if_empty=prompt_field.omit_if_empty)
|
||||||
|
|
||||||
|
def register_standard_fields(self):
|
||||||
|
self.register_field('agent_profile', function=self.handle_agent_profile, is_context=False)
|
||||||
|
self.register_field('tool_information', function=self.handle_tool_data, is_context=False)
|
||||||
|
self.register_field('context_placeholder', is_context=True) # 用于标记上下文数据部分的位置
|
||||||
|
self.register_field('reference_documents', function=self.handle_doc_info, is_context=True)
|
||||||
|
self.register_field('session_records', function=self.handle_session_records, is_context=True)
|
||||||
|
self.register_field('output_format', function=self.handle_output_format, title='Response Output Format', is_context=False)
|
||||||
|
self.register_field('response', function=self.handle_response, is_context=False, omit_if_empty=False)
|
||||||
|
|
||||||
|
def register_executor_fields(self):
|
||||||
|
self.register_field('agent_profile', function=self.handle_agent_profile, is_context=False)
|
||||||
|
self.register_field('tool_information', function=self.handle_tool_data, is_context=False)
|
||||||
|
self.register_field('context_placeholder', is_context=True) # 用于标记上下文数据部分的位置
|
||||||
|
self.register_field('reference_documents', function=self.handle_doc_info, is_context=True)
|
||||||
|
self.register_field('session_records', function=self.handle_session_records, is_context=True)
|
||||||
|
self.register_field('current_plan', function=self.handle_current_plan, is_context=True)
|
||||||
|
self.register_field('output_format', function=self.handle_output_format, title='Response Output Format', is_context=False)
|
||||||
|
self.register_field('response', function=self.handle_response, is_context=False, omit_if_empty=False)
|
||||||
|
|
||||||
|
def register_fields_from_dict(self, fields_dict):
|
||||||
|
# 使用字典注册字段的函数
|
||||||
|
for field_name, field_config in fields_dict.items():
|
||||||
|
function_name = field_config.get('function', None)
|
||||||
|
title = field_config.get('title', None)
|
||||||
|
description = field_config.get('description', None)
|
||||||
|
is_context = field_config.get('is_context', True)
|
||||||
|
omit_if_empty = field_config.get('omit_if_empty', True)
|
||||||
|
|
||||||
|
# 检查function_name是否是self的一个方法
|
||||||
|
if function_name and hasattr(self, function_name):
|
||||||
|
function = getattr(self, function_name)
|
||||||
|
else:
|
||||||
|
function = self.handle_custom_data
|
||||||
|
|
||||||
|
# 调用已存在的register_field方法注册字段
|
||||||
|
self.register_field(field_name, function=function, title=title, description=description, is_context=is_context, omit_if_empty=omit_if_empty)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
manager = PromptManager()
|
||||||
|
manager.register_standard_fields()
|
||||||
|
|
||||||
|
manager.register_field('agents_work_progress', title=f"Agents' Work Progress", is_context=True)
|
||||||
|
|
||||||
|
# 创建数据字典
|
||||||
|
data_dict = {
|
||||||
|
"agent_profile": "这是代理配置文件...",
|
||||||
|
# "tool_list": "这是工具列表...",
|
||||||
|
"reference_documents": "这是参考文档...",
|
||||||
|
"session_records": "这是会话记录...",
|
||||||
|
"agents_work_progress": "这是代理工作进展...",
|
||||||
|
"output_format": "这是预期的输出格式...",
|
||||||
|
# "response": "这是生成或继续回应的指令...",
|
||||||
|
"response": "",
|
||||||
|
"test": 'xxxxx'
|
||||||
|
}
|
||||||
|
|
||||||
|
# 组合完整的提示
|
||||||
|
full_prompt = manager.generate_full_prompt(data_dict)
|
||||||
|
print(full_prompt)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
|
@ -215,15 +215,15 @@ class Env(BaseModel):
|
||||||
class Role(BaseModel):
|
class Role(BaseModel):
|
||||||
role_type: str
|
role_type: str
|
||||||
role_name: str
|
role_name: str
|
||||||
role_desc: str
|
role_desc: str = ""
|
||||||
agent_type: str = ""
|
agent_type: str = "BaseAgent"
|
||||||
role_prompt: str = ""
|
role_prompt: str = ""
|
||||||
template_prompt: str = ""
|
template_prompt: str = ""
|
||||||
|
|
||||||
|
|
||||||
class ChainConfig(BaseModel):
|
class ChainConfig(BaseModel):
|
||||||
chain_name: str
|
chain_name: str
|
||||||
chain_type: str
|
chain_type: str = "BaseChain"
|
||||||
agents: List[str]
|
agents: List[str]
|
||||||
do_checker: bool = False
|
do_checker: bool = False
|
||||||
chat_turn: int = 1
|
chat_turn: int = 1
|
||||||
|
|
|
@ -131,6 +131,9 @@ class Memory(BaseModel):
|
||||||
# logger.debug(f"{message.role_name}: {message.parsed_output_list}")
|
# logger.debug(f"{message.role_name}: {message.parsed_output_list}")
|
||||||
# return [parsed_output for message in self.messages for parsed_output in message.parsed_output_list[1:]]
|
# return [parsed_output for message in self.messages for parsed_output in message.parsed_output_list[1:]]
|
||||||
return [parsed_output for message in self.messages for parsed_output in message.parsed_output_list]
|
return [parsed_output for message in self.messages for parsed_output in message.parsed_output_list]
|
||||||
|
|
||||||
|
def get_spec_parserd_output(self, ):
|
||||||
|
return [message.spec_parsed_output for message in self.messages]
|
||||||
|
|
||||||
def get_rolenames(self, ):
|
def get_rolenames(self, ):
|
||||||
''''''
|
''''''
|
||||||
|
|
|
@ -7,6 +7,7 @@ from .general_schema import *
|
||||||
|
|
||||||
class Message(BaseModel):
|
class Message(BaseModel):
|
||||||
chat_index: str = None
|
chat_index: str = None
|
||||||
|
user_name: str = "default"
|
||||||
role_name: str
|
role_name: str
|
||||||
role_type: str
|
role_type: str
|
||||||
role_prompt: str = None
|
role_prompt: str = None
|
||||||
|
@ -53,6 +54,8 @@ class Message(BaseModel):
|
||||||
cb_search_type: str = None
|
cb_search_type: str = None
|
||||||
search_engine_name: str = None
|
search_engine_name: str = None
|
||||||
top_k: int = 3
|
top_k: int = 3
|
||||||
|
use_nh: bool = True
|
||||||
|
local_graph_path: str = ''
|
||||||
score_threshold: float = 1.0
|
score_threshold: float = 1.0
|
||||||
do_doc_retrieval: bool = False
|
do_doc_retrieval: bool = False
|
||||||
do_code_retrieval: bool = False
|
do_code_retrieval: bool = False
|
||||||
|
|
|
@ -72,20 +72,25 @@ def parse_text_to_dict(text):
|
||||||
def parse_dict_to_dict(parsed_dict) -> dict:
|
def parse_dict_to_dict(parsed_dict) -> dict:
|
||||||
code_pattern = r'```python\n(.*?)```'
|
code_pattern = r'```python\n(.*?)```'
|
||||||
tool_pattern = r'```json\n(.*?)```'
|
tool_pattern = r'```json\n(.*?)```'
|
||||||
|
java_pattern = r'```java\n(.*?)```'
|
||||||
|
|
||||||
pattern_dict = {"code": code_pattern, "json": tool_pattern}
|
pattern_dict = {"code": code_pattern, "json": tool_pattern, "java": java_pattern}
|
||||||
spec_parsed_dict = copy.deepcopy(parsed_dict)
|
spec_parsed_dict = copy.deepcopy(parsed_dict)
|
||||||
for key, pattern in pattern_dict.items():
|
for key, pattern in pattern_dict.items():
|
||||||
for k, text in parsed_dict.items():
|
for k, text in parsed_dict.items():
|
||||||
# Search for the code block
|
# Search for the code block
|
||||||
if not isinstance(text, str): continue
|
if not isinstance(text, str):
|
||||||
|
spec_parsed_dict[k] = text
|
||||||
|
continue
|
||||||
_match = re.search(pattern, text, re.DOTALL)
|
_match = re.search(pattern, text, re.DOTALL)
|
||||||
if _match:
|
if _match:
|
||||||
# Add the code block to the dictionary
|
# Add the code block to the dictionary
|
||||||
try:
|
try:
|
||||||
spec_parsed_dict[key] = json.loads(_match.group(1).strip())
|
spec_parsed_dict[key] = json.loads(_match.group(1).strip())
|
||||||
|
spec_parsed_dict[k] = json.loads(_match.group(1).strip())
|
||||||
except:
|
except:
|
||||||
spec_parsed_dict[key] = _match.group(1).strip()
|
spec_parsed_dict[key] = _match.group(1).strip()
|
||||||
|
spec_parsed_dict[k] = _match.group(1).strip()
|
||||||
break
|
break
|
||||||
return spec_parsed_dict
|
return spec_parsed_dict
|
||||||
|
|
||||||
|
|
|
@ -43,7 +43,7 @@ class NebulaHandler:
|
||||||
elif self.space_name:
|
elif self.space_name:
|
||||||
cypher = f'USE {self.space_name};{cypher}'
|
cypher = f'USE {self.space_name};{cypher}'
|
||||||
|
|
||||||
logger.debug(cypher)
|
# logger.debug(cypher)
|
||||||
resp = session.execute(cypher)
|
resp = session.execute(cypher)
|
||||||
|
|
||||||
if format_res:
|
if format_res:
|
||||||
|
@ -247,6 +247,24 @@ class NebulaHandler:
|
||||||
res = self.execute_cypher(cypher, self.space_name)
|
res = self.execute_cypher(cypher, self.space_name)
|
||||||
return self.result_to_dict(res)
|
return self.result_to_dict(res)
|
||||||
|
|
||||||
|
def get_all_vertices(self,):
|
||||||
|
'''
|
||||||
|
get all vertices
|
||||||
|
@return:
|
||||||
|
'''
|
||||||
|
cypher = "MATCH (v) RETURN v;"
|
||||||
|
res = self.execute_cypher(cypher, self.space_name)
|
||||||
|
return self.result_to_dict(res)
|
||||||
|
|
||||||
|
def get_relative_vertices(self, vertice):
|
||||||
|
'''
|
||||||
|
get all vertices
|
||||||
|
@return:
|
||||||
|
'''
|
||||||
|
cypher = f'''MATCH (v1)--(v2) WHERE id(v1) == '{vertice}' RETURN id(v2) as id;'''
|
||||||
|
res = self.execute_cypher(cypher, self.space_name)
|
||||||
|
return self.result_to_dict(res)
|
||||||
|
|
||||||
def result_to_dict(self, result) -> dict:
|
def result_to_dict(self, result) -> dict:
|
||||||
"""
|
"""
|
||||||
build list for each column, and transform to dataframe
|
build list for each column, and transform to dataframe
|
||||||
|
|
|
@ -6,6 +6,7 @@ import os
|
||||||
import pickle
|
import pickle
|
||||||
import uuid
|
import uuid
|
||||||
import warnings
|
import warnings
|
||||||
|
from enum import Enum
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import (
|
from typing import (
|
||||||
Any,
|
Any,
|
||||||
|
@ -22,10 +23,22 @@ import numpy as np
|
||||||
|
|
||||||
from langchain.docstore.base import AddableMixin, Docstore
|
from langchain.docstore.base import AddableMixin, Docstore
|
||||||
from langchain.docstore.document import Document
|
from langchain.docstore.document import Document
|
||||||
from langchain.docstore.in_memory import InMemoryDocstore
|
# from langchain.docstore.in_memory import InMemoryDocstore
|
||||||
|
from .in_memory import InMemoryDocstore
|
||||||
from langchain.embeddings.base import Embeddings
|
from langchain.embeddings.base import Embeddings
|
||||||
from langchain.vectorstores.base import VectorStore
|
from langchain.vectorstores.base import VectorStore
|
||||||
from langchain.vectorstores.utils import DistanceStrategy, maximal_marginal_relevance
|
from langchain.vectorstores.utils import maximal_marginal_relevance
|
||||||
|
|
||||||
|
|
||||||
|
class DistanceStrategy(str, Enum):
|
||||||
|
"""Enumerator of the Distance strategies for calculating distances
|
||||||
|
between vectors."""
|
||||||
|
|
||||||
|
EUCLIDEAN_DISTANCE = "EUCLIDEAN_DISTANCE"
|
||||||
|
MAX_INNER_PRODUCT = "MAX_INNER_PRODUCT"
|
||||||
|
DOT_PRODUCT = "DOT_PRODUCT"
|
||||||
|
JACCARD = "JACCARD"
|
||||||
|
COSINE = "COSINE"
|
||||||
|
|
||||||
|
|
||||||
def dependable_faiss_import(no_avx2: Optional[bool] = None) -> Any:
|
def dependable_faiss_import(no_avx2: Optional[bool] = None) -> Any:
|
||||||
|
@ -219,6 +232,9 @@ class FAISS(VectorStore):
|
||||||
if self._normalize_L2:
|
if self._normalize_L2:
|
||||||
faiss.normalize_L2(vector)
|
faiss.normalize_L2(vector)
|
||||||
scores, indices = self.index.search(vector, k if filter is None else fetch_k)
|
scores, indices = self.index.search(vector, k if filter is None else fetch_k)
|
||||||
|
# 经过normalize的结果会超出1
|
||||||
|
if self._normalize_L2:
|
||||||
|
scores = np.array([row / np.linalg.norm(row) if np.max(row) > 1 else row for row in scores])
|
||||||
docs = []
|
docs = []
|
||||||
for j, i in enumerate(indices[0]):
|
for j, i in enumerate(indices[0]):
|
||||||
if i == -1:
|
if i == -1:
|
||||||
|
@ -565,7 +581,7 @@ class FAISS(VectorStore):
|
||||||
vecstore = cls(
|
vecstore = cls(
|
||||||
embedding.embed_query,
|
embedding.embed_query,
|
||||||
index,
|
index,
|
||||||
InMemoryDocstore(),
|
InMemoryDocstore({}),
|
||||||
{},
|
{},
|
||||||
normalize_L2=normalize_L2,
|
normalize_L2=normalize_L2,
|
||||||
distance_strategy=distance_strategy,
|
distance_strategy=distance_strategy,
|
||||||
|
|
|
@ -10,13 +10,14 @@ from loguru import logger
|
||||||
# from configs.model_config import EMBEDDING_MODEL
|
# from configs.model_config import EMBEDDING_MODEL
|
||||||
from coagent.embeddings.openai_embedding import OpenAIEmbedding
|
from coagent.embeddings.openai_embedding import OpenAIEmbedding
|
||||||
from coagent.embeddings.huggingface_embedding import HFEmbedding
|
from coagent.embeddings.huggingface_embedding import HFEmbedding
|
||||||
|
from coagent.llm_models.llm_config import EmbedConfig
|
||||||
|
|
||||||
def get_embedding(
|
def get_embedding(
|
||||||
engine: str,
|
engine: str,
|
||||||
text_list: list,
|
text_list: list,
|
||||||
model_path: str = "text2vec-base-chinese",
|
model_path: str = "text2vec-base-chinese",
|
||||||
embedding_device: str = "cpu",
|
embedding_device: str = "cpu",
|
||||||
|
embed_config: EmbedConfig = None,
|
||||||
):
|
):
|
||||||
'''
|
'''
|
||||||
get embedding
|
get embedding
|
||||||
|
@ -25,8 +26,12 @@ def get_embedding(
|
||||||
@return:
|
@return:
|
||||||
'''
|
'''
|
||||||
emb_res = {}
|
emb_res = {}
|
||||||
|
if embed_config and embed_config.langchain_embeddings:
|
||||||
if engine == 'openai':
|
emb_res = embed_config.langchain_embeddings.embed_documents(text_list)
|
||||||
|
emb_res = {
|
||||||
|
text_list[idx]: emb_res[idx] for idx in range(len(text_list))
|
||||||
|
}
|
||||||
|
elif engine == 'openai':
|
||||||
oae = OpenAIEmbedding()
|
oae = OpenAIEmbedding()
|
||||||
emb_res = oae.get_emb(text_list)
|
emb_res = oae.get_emb(text_list)
|
||||||
elif engine == 'model':
|
elif engine == 'model':
|
||||||
|
|
|
@ -0,0 +1,49 @@
|
||||||
|
"""Simple in memory docstore in the form of a dict."""
|
||||||
|
from typing import Dict, List, Optional, Union
|
||||||
|
|
||||||
|
from langchain.docstore.base import AddableMixin, Docstore
|
||||||
|
from langchain.docstore.document import Document
|
||||||
|
|
||||||
|
|
||||||
|
class InMemoryDocstore(Docstore, AddableMixin):
|
||||||
|
"""Simple in memory docstore in the form of a dict."""
|
||||||
|
|
||||||
|
def __init__(self, _dict: Optional[Dict[str, Document]] = None):
|
||||||
|
"""Initialize with dict."""
|
||||||
|
self._dict = _dict if _dict is not None else {}
|
||||||
|
|
||||||
|
def add(self, texts: Dict[str, Document]) -> None:
|
||||||
|
"""Add texts to in memory dictionary.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
texts: dictionary of id -> document.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
None
|
||||||
|
"""
|
||||||
|
overlapping = set(texts).intersection(self._dict)
|
||||||
|
if overlapping:
|
||||||
|
raise ValueError(f"Tried to add ids that already exist: {overlapping}")
|
||||||
|
self._dict = {**self._dict, **texts}
|
||||||
|
|
||||||
|
def delete(self, ids: List) -> None:
|
||||||
|
"""Deleting IDs from in memory dictionary."""
|
||||||
|
overlapping = set(ids).intersection(self._dict)
|
||||||
|
if not overlapping:
|
||||||
|
raise ValueError(f"Tried to delete ids that does not exist: {ids}")
|
||||||
|
for _id in ids:
|
||||||
|
self._dict.pop(_id)
|
||||||
|
|
||||||
|
def search(self, search: str) -> Union[str, Document]:
|
||||||
|
"""Search via direct lookup.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
search: id of a document to search for.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Document if found, else error message.
|
||||||
|
"""
|
||||||
|
if search not in self._dict:
|
||||||
|
return f"ID {search} not found."
|
||||||
|
else:
|
||||||
|
return self._dict[search]
|
|
@ -1,6 +1,8 @@
|
||||||
import os
|
import os
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
|
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
|
||||||
|
from langchain.embeddings.base import Embeddings
|
||||||
|
|
||||||
# from configs.model_config import embedding_model_dict
|
# from configs.model_config import embedding_model_dict
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
|
@ -12,8 +14,11 @@ def load_embeddings(model: str, device: str, embedding_model_dict: dict):
|
||||||
return embeddings
|
return embeddings
|
||||||
|
|
||||||
|
|
||||||
@lru_cache(1)
|
# @lru_cache(1)
|
||||||
def load_embeddings_from_path(model_path: str, device: str):
|
def load_embeddings_from_path(model_path: str, device: str, langchain_embeddings: Embeddings = None):
|
||||||
|
if langchain_embeddings:
|
||||||
|
return langchain_embeddings
|
||||||
|
|
||||||
embeddings = HuggingFaceEmbeddings(model_name=model_path,
|
embeddings = HuggingFaceEmbeddings(model_name=model_path,
|
||||||
model_kwargs={'device': device})
|
model_kwargs={'device': device})
|
||||||
return embeddings
|
return embeddings
|
||||||
|
|
|
@ -1,8 +1,8 @@
|
||||||
from .openai_model import getChatModel, getExtraModel, getChatModelFromConfig
|
from .openai_model import getExtraModel, getChatModelFromConfig
|
||||||
from .llm_config import LLMConfig, EmbedConfig
|
from .llm_config import LLMConfig, EmbedConfig
|
||||||
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"getChatModel", "getExtraModel", "getChatModelFromConfig",
|
"getExtraModel", "getChatModelFromConfig",
|
||||||
"LLMConfig", "EmbedConfig"
|
"LLMConfig", "EmbedConfig"
|
||||||
]
|
]
|
|
@ -1,6 +1,9 @@
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import List, Union
|
from typing import List, Union
|
||||||
|
|
||||||
|
from langchain.embeddings.base import Embeddings
|
||||||
|
from langchain.llms.base import LLM, BaseLLM
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
@ -12,7 +15,8 @@ class LLMConfig:
|
||||||
stop: Union[List[str], str] = None,
|
stop: Union[List[str], str] = None,
|
||||||
api_key: str = "",
|
api_key: str = "",
|
||||||
api_base_url: str = "",
|
api_base_url: str = "",
|
||||||
model_device: str = "cpu",
|
model_device: str = "cpu", # unuse,will delete it
|
||||||
|
llm: LLM = None,
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
|
|
||||||
|
@ -21,7 +25,7 @@ class LLMConfig:
|
||||||
self.stop: Union[List[str], str] = stop
|
self.stop: Union[List[str], str] = stop
|
||||||
self.api_key: str = api_key
|
self.api_key: str = api_key
|
||||||
self.api_base_url: str = api_base_url
|
self.api_base_url: str = api_base_url
|
||||||
self.model_device: str = model_device
|
self.llm: LLM = llm
|
||||||
#
|
#
|
||||||
self.check_config()
|
self.check_config()
|
||||||
|
|
||||||
|
@ -42,6 +46,7 @@ class EmbedConfig:
|
||||||
embed_model_path: str = "",
|
embed_model_path: str = "",
|
||||||
embed_engine: str = "",
|
embed_engine: str = "",
|
||||||
model_device: str = "cpu",
|
model_device: str = "cpu",
|
||||||
|
langchain_embeddings: Embeddings = None,
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
self.embed_model: str = embed_model
|
self.embed_model: str = embed_model
|
||||||
|
@ -51,6 +56,8 @@ class EmbedConfig:
|
||||||
self.api_key: str = api_key
|
self.api_key: str = api_key
|
||||||
self.api_base_url: str = api_base_url
|
self.api_base_url: str = api_base_url
|
||||||
#
|
#
|
||||||
|
self.langchain_embeddings = langchain_embeddings
|
||||||
|
#
|
||||||
self.check_config()
|
self.check_config()
|
||||||
|
|
||||||
def check_config(self, ):
|
def check_config(self, ):
|
||||||
|
|
|
@ -1,38 +1,54 @@
|
||||||
import os
|
import os
|
||||||
|
from typing import Union, Optional, List
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
from langchain.callbacks import AsyncIteratorCallbackHandler
|
from langchain.callbacks import AsyncIteratorCallbackHandler
|
||||||
from langchain.chat_models import ChatOpenAI
|
from langchain.chat_models import ChatOpenAI
|
||||||
|
from langchain.llms.base import LLM
|
||||||
|
|
||||||
from .llm_config import LLMConfig
|
from .llm_config import LLMConfig
|
||||||
# from configs.model_config import (llm_model_dict, LLM_MODEL)
|
# from configs.model_config import (llm_model_dict, LLM_MODEL)
|
||||||
|
|
||||||
|
|
||||||
def getChatModel(callBack: AsyncIteratorCallbackHandler = None, temperature=0.3, stop=None):
|
class CustomLLMModel:
|
||||||
if callBack is None:
|
|
||||||
|
def __init__(self, llm: LLM):
|
||||||
|
self.llm: LLM = llm
|
||||||
|
|
||||||
|
def __call__(self, prompt: str,
|
||||||
|
stop: Optional[List[str]] = None):
|
||||||
|
return self.llm(prompt, stop)
|
||||||
|
|
||||||
|
def _call(self, prompt: str,
|
||||||
|
stop: Optional[List[str]] = None):
|
||||||
|
return self.llm(prompt, stop)
|
||||||
|
|
||||||
|
def predict(self, prompt: str,
|
||||||
|
stop: Optional[List[str]] = None):
|
||||||
|
return self.llm(prompt, stop)
|
||||||
|
|
||||||
|
def batch(self, prompts: str,
|
||||||
|
stop: Optional[List[str]] = None):
|
||||||
|
return [self.llm(prompt, stop) for prompt in prompts]
|
||||||
|
|
||||||
|
|
||||||
|
def getChatModelFromConfig(llm_config: LLMConfig, callBack: AsyncIteratorCallbackHandler = None, ) -> Union[ChatOpenAI, LLM]:
|
||||||
|
# logger.debug(f"llm type is {type(llm_config.llm)}")
|
||||||
|
if llm_config is None:
|
||||||
model = ChatOpenAI(
|
model = ChatOpenAI(
|
||||||
streaming=True,
|
streaming=True,
|
||||||
verbose=True,
|
verbose=True,
|
||||||
openai_api_key=llm_model_dict[LLM_MODEL]["api_key"],
|
openai_api_key=os.environ.get("api_key"),
|
||||||
openai_api_base=llm_model_dict[LLM_MODEL]["api_base_url"],
|
openai_api_base=os.environ.get("api_base_url"),
|
||||||
model_name=LLM_MODEL,
|
model_name=os.environ.get("LLM_MODEL", "gpt-3.5-turbo"),
|
||||||
temperature=temperature,
|
temperature=os.environ.get("temperature", 0.5),
|
||||||
stop=stop
|
stop=os.environ.get("stop", ""),
|
||||||
)
|
)
|
||||||
else:
|
return model
|
||||||
model = ChatOpenAI(
|
|
||||||
streaming=True,
|
|
||||||
verbose=True,
|
|
||||||
callBack=[callBack],
|
|
||||||
openai_api_key=llm_model_dict[LLM_MODEL]["api_key"],
|
|
||||||
openai_api_base=llm_model_dict[LLM_MODEL]["api_base_url"],
|
|
||||||
model_name=LLM_MODEL,
|
|
||||||
temperature=temperature,
|
|
||||||
stop=stop
|
|
||||||
)
|
|
||||||
return model
|
|
||||||
|
|
||||||
|
if llm_config and llm_config.llm and isinstance(llm_config.llm, LLM):
|
||||||
def getChatModelFromConfig(llm_config: LLMConfig, callBack: AsyncIteratorCallbackHandler = None, ):
|
return CustomLLMModel(llm=llm_config.llm)
|
||||||
|
|
||||||
if callBack is None:
|
if callBack is None:
|
||||||
model = ChatOpenAI(
|
model = ChatOpenAI(
|
||||||
streaming=True,
|
streaming=True,
|
||||||
|
|
|
@ -0,0 +1,5 @@
|
||||||
|
# from .base_retrieval import *
|
||||||
|
|
||||||
|
# __all__ = [
|
||||||
|
# "IMRertrieval", "BaseDocRetrieval", "BaseCodeRetrieval", "BaseSearchRetrieval"
|
||||||
|
# ]
|
|
@ -0,0 +1,75 @@
|
||||||
|
|
||||||
|
from coagent.llm_models.llm_config import EmbedConfig, LLMConfig
|
||||||
|
from coagent.base_configs.env_config import KB_ROOT_PATH
|
||||||
|
from coagent.tools import DocRetrieval, CodeRetrieval
|
||||||
|
|
||||||
|
|
||||||
|
class IMRertrieval:
|
||||||
|
|
||||||
|
def __init__(self,):
|
||||||
|
'''
|
||||||
|
init your personal attributes
|
||||||
|
'''
|
||||||
|
pass
|
||||||
|
|
||||||
|
def run(self, ):
|
||||||
|
'''
|
||||||
|
execute interface, and can use init' attributes
|
||||||
|
'''
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class BaseDocRetrieval(IMRertrieval):
|
||||||
|
|
||||||
|
def __init__(self, knowledge_base_name: str, search_top=5, score_threshold=1.0, embed_config: EmbedConfig=EmbedConfig(), kb_root_path: str=KB_ROOT_PATH):
|
||||||
|
self.knowledge_base_name = knowledge_base_name
|
||||||
|
self.search_top = search_top
|
||||||
|
self.score_threshold = score_threshold
|
||||||
|
self.embed_config = embed_config
|
||||||
|
self.kb_root_path = kb_root_path
|
||||||
|
|
||||||
|
def run(self, query: str, search_top=None, score_threshold=None, ):
|
||||||
|
docs = DocRetrieval.run(
|
||||||
|
query=query, knowledge_base_name=self.knowledge_base_name,
|
||||||
|
search_top=search_top or self.search_top,
|
||||||
|
score_threshold=score_threshold or self.score_threshold,
|
||||||
|
embed_config=self.embed_config,
|
||||||
|
kb_root_path=self.kb_root_path
|
||||||
|
)
|
||||||
|
return docs
|
||||||
|
|
||||||
|
|
||||||
|
class BaseCodeRetrieval(IMRertrieval):
|
||||||
|
|
||||||
|
def __init__(self, code_base_name, embed_config: EmbedConfig, llm_config: LLMConfig, search_type = 'tag', code_limit = 1, local_graph_path: str=""):
|
||||||
|
self.code_base_name = code_base_name
|
||||||
|
self.embed_config = embed_config
|
||||||
|
self.llm_config = llm_config
|
||||||
|
self.search_type = search_type
|
||||||
|
self.code_limit = code_limit
|
||||||
|
self.use_nh: bool = False
|
||||||
|
self.local_graph_path: str = local_graph_path
|
||||||
|
|
||||||
|
def run(self, query, history_node_list=[], search_type = None, code_limit=None):
|
||||||
|
code_docs = CodeRetrieval.run(
|
||||||
|
code_base_name=self.code_base_name,
|
||||||
|
query=query,
|
||||||
|
history_node_list=history_node_list,
|
||||||
|
code_limit=code_limit or self.code_limit,
|
||||||
|
search_type=search_type or self.search_type,
|
||||||
|
llm_config=self.llm_config,
|
||||||
|
embed_config=self.embed_config,
|
||||||
|
use_nh=self.use_nh,
|
||||||
|
local_graph_path=self.local_graph_path
|
||||||
|
)
|
||||||
|
return code_docs
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class BaseSearchRetrieval(IMRertrieval):
|
||||||
|
|
||||||
|
def __init__(self, ):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def run(self, ):
|
||||||
|
pass
|
|
@ -0,0 +1,6 @@
|
||||||
|
from .json_loader import JSONLoader
|
||||||
|
from .jsonl_loader import JSONLLoader
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"JSONLoader", "JSONLLoader"
|
||||||
|
]
|
|
@ -0,0 +1,61 @@
|
||||||
|
import json
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import AnyStr, Callable, Dict, List, Optional, Union
|
||||||
|
|
||||||
|
from langchain.docstore.document import Document
|
||||||
|
from langchain.document_loaders.base import BaseLoader
|
||||||
|
from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter
|
||||||
|
|
||||||
|
from coagent.utils.common_utils import read_json_file
|
||||||
|
|
||||||
|
|
||||||
|
class JSONLoader(BaseLoader):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
file_path: Union[str, Path],
|
||||||
|
schema_key: str = "all_text",
|
||||||
|
content_key: Optional[str] = None,
|
||||||
|
metadata_func: Optional[Callable[[Dict, Dict], Dict]] = None,
|
||||||
|
text_content: bool = True,
|
||||||
|
):
|
||||||
|
self.file_path = Path(file_path).resolve()
|
||||||
|
self.schema_key = schema_key
|
||||||
|
self._content_key = content_key
|
||||||
|
self._metadata_func = metadata_func
|
||||||
|
self._text_content = text_content
|
||||||
|
|
||||||
|
def load(self, ) -> List[Document]:
|
||||||
|
"""Load and return documents from the JSON file."""
|
||||||
|
docs: List[Document] = []
|
||||||
|
datas = read_json_file(self.file_path)
|
||||||
|
self._parse(datas, docs)
|
||||||
|
return docs
|
||||||
|
|
||||||
|
def _parse(self, datas: List, docs: List[Document]) -> None:
|
||||||
|
for idx, sample in enumerate(datas):
|
||||||
|
metadata = dict(
|
||||||
|
source=str(self.file_path),
|
||||||
|
seq_num=idx,
|
||||||
|
)
|
||||||
|
text = sample.get(self.schema_key, "")
|
||||||
|
docs.append(Document(page_content=text, metadata=metadata))
|
||||||
|
|
||||||
|
def load_and_split(
|
||||||
|
self, text_splitter: Optional[TextSplitter] = None
|
||||||
|
) -> List[Document]:
|
||||||
|
"""Load Documents and split into chunks. Chunks are returned as Documents.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text_splitter: TextSplitter instance to use for splitting documents.
|
||||||
|
Defaults to RecursiveCharacterTextSplitter.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of Documents.
|
||||||
|
"""
|
||||||
|
if text_splitter is None:
|
||||||
|
_text_splitter: TextSplitter = RecursiveCharacterTextSplitter()
|
||||||
|
else:
|
||||||
|
_text_splitter = text_splitter
|
||||||
|
docs = self.load()
|
||||||
|
return _text_splitter.split_documents(docs)
|
|
@ -0,0 +1,62 @@
|
||||||
|
import json
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import AnyStr, Callable, Dict, List, Optional, Union
|
||||||
|
|
||||||
|
from langchain.docstore.document import Document
|
||||||
|
from langchain.document_loaders.base import BaseLoader
|
||||||
|
from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter
|
||||||
|
|
||||||
|
from coagent.utils.common_utils import read_jsonl_file
|
||||||
|
|
||||||
|
|
||||||
|
class JSONLLoader(BaseLoader):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
file_path: Union[str, Path],
|
||||||
|
schema_key: str = "all_text",
|
||||||
|
content_key: Optional[str] = None,
|
||||||
|
metadata_func: Optional[Callable[[Dict, Dict], Dict]] = None,
|
||||||
|
text_content: bool = True,
|
||||||
|
):
|
||||||
|
self.file_path = Path(file_path).resolve()
|
||||||
|
self.schema_key = schema_key
|
||||||
|
self._content_key = content_key
|
||||||
|
self._metadata_func = metadata_func
|
||||||
|
self._text_content = text_content
|
||||||
|
|
||||||
|
def load(self, ) -> List[Document]:
|
||||||
|
"""Load and return documents from the JSON file."""
|
||||||
|
docs: List[Document] = []
|
||||||
|
datas = read_jsonl_file(self.file_path)
|
||||||
|
self._parse(datas, docs)
|
||||||
|
return docs
|
||||||
|
|
||||||
|
def _parse(self, datas: List, docs: List[Document]) -> None:
|
||||||
|
for idx, sample in enumerate(datas):
|
||||||
|
metadata = dict(
|
||||||
|
source=str(self.file_path),
|
||||||
|
seq_num=idx,
|
||||||
|
)
|
||||||
|
text = sample.get(self.schema_key, "")
|
||||||
|
docs.append(Document(page_content=text, metadata=metadata))
|
||||||
|
|
||||||
|
def load_and_split(
|
||||||
|
self, text_splitter: Optional[TextSplitter] = None
|
||||||
|
) -> List[Document]:
|
||||||
|
"""Load Documents and split into chunks. Chunks are returned as Documents.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text_splitter: TextSplitter instance to use for splitting documents.
|
||||||
|
Defaults to RecursiveCharacterTextSplitter.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of Documents.
|
||||||
|
"""
|
||||||
|
if text_splitter is None:
|
||||||
|
_text_splitter: TextSplitter = RecursiveCharacterTextSplitter()
|
||||||
|
else:
|
||||||
|
_text_splitter = text_splitter
|
||||||
|
|
||||||
|
docs = self.load()
|
||||||
|
return _text_splitter.split_documents(docs)
|
|
@ -0,0 +1,3 @@
|
||||||
|
from .langchain_splitter import LCTextSplitter
|
||||||
|
|
||||||
|
__all__ = ["LCTextSplitter"]
|
|
@ -0,0 +1,77 @@
|
||||||
|
import os
|
||||||
|
import importlib
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
from langchain.document_loaders.base import BaseLoader
|
||||||
|
from langchain.text_splitter import (
|
||||||
|
SpacyTextSplitter, RecursiveCharacterTextSplitter
|
||||||
|
)
|
||||||
|
|
||||||
|
# from configs.model_config import (
|
||||||
|
# CHUNK_SIZE,
|
||||||
|
# OVERLAP_SIZE,
|
||||||
|
# ZH_TITLE_ENHANCE
|
||||||
|
# )
|
||||||
|
from coagent.utils.path_utils import *
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class LCTextSplitter:
|
||||||
|
'''langchain textsplitter 执行file2text'''
|
||||||
|
def __init__(
|
||||||
|
self, filepath: str, text_splitter_name: str = None,
|
||||||
|
chunk_size: int = 500,
|
||||||
|
overlap_size: int = 50
|
||||||
|
):
|
||||||
|
self.filepath = filepath
|
||||||
|
self.ext = os.path.splitext(filepath)[-1].lower()
|
||||||
|
self.text_splitter_name = text_splitter_name
|
||||||
|
self.chunk_size = chunk_size
|
||||||
|
self.overlap_size = overlap_size
|
||||||
|
if self.ext not in SUPPORTED_EXTS:
|
||||||
|
raise ValueError(f"暂未支持的文件格式 {self.ext}")
|
||||||
|
self.document_loader_name = get_LoaderClass(self.ext)
|
||||||
|
|
||||||
|
def file2text(self, ):
|
||||||
|
loader = self._load_document()
|
||||||
|
text_splitter = self._load_text_splitter()
|
||||||
|
if self.document_loader_name in ["JSONLoader", "JSONLLoader"]:
|
||||||
|
# docs = loader.load()
|
||||||
|
docs = loader.load_and_split(text_splitter)
|
||||||
|
# logger.debug(f"please check your file can be loaded, docs.lens {len(docs)}")
|
||||||
|
else:
|
||||||
|
docs = loader.load_and_split(text_splitter)
|
||||||
|
|
||||||
|
return docs
|
||||||
|
|
||||||
|
def _load_document(self, ) -> BaseLoader:
|
||||||
|
DocumentLoader = EXT2LOADER_DICT[self.ext]
|
||||||
|
if self.document_loader_name == "UnstructuredFileLoader":
|
||||||
|
loader = DocumentLoader(self.filepath, autodetect_encoding=True)
|
||||||
|
else:
|
||||||
|
loader = DocumentLoader(self.filepath)
|
||||||
|
return loader
|
||||||
|
|
||||||
|
def _load_text_splitter(self, ):
|
||||||
|
try:
|
||||||
|
if self.text_splitter_name is None:
|
||||||
|
text_splitter = SpacyTextSplitter(
|
||||||
|
pipeline="zh_core_web_sm",
|
||||||
|
chunk_size=self.chunk_size,
|
||||||
|
chunk_overlap=self.overlap_size,
|
||||||
|
)
|
||||||
|
self.text_splitter_name = "SpacyTextSplitter"
|
||||||
|
# elif self.document_loader_name in ["JSONLoader", "JSONLLoader"]:
|
||||||
|
# text_splitter = None
|
||||||
|
else:
|
||||||
|
text_splitter_module = importlib.import_module('langchain.text_splitter')
|
||||||
|
TextSplitter = getattr(text_splitter_module, self.text_splitter_name)
|
||||||
|
text_splitter = TextSplitter(
|
||||||
|
chunk_size=self.chunk_size,
|
||||||
|
chunk_overlap=self.overlap_size)
|
||||||
|
except Exception as e:
|
||||||
|
text_splitter = RecursiveCharacterTextSplitter(
|
||||||
|
chunk_size=self.chunk_size,
|
||||||
|
chunk_overlap=self.overlap_size,
|
||||||
|
)
|
||||||
|
return text_splitter
|
|
@ -32,8 +32,8 @@ class PyCodeBox(BaseBox):
|
||||||
self.do_check_net = do_check_net
|
self.do_check_net = do_check_net
|
||||||
self.use_stop = use_stop
|
self.use_stop = use_stop
|
||||||
self.jupyter_work_path = jupyter_work_path
|
self.jupyter_work_path = jupyter_work_path
|
||||||
asyncio.run(self.astart())
|
# asyncio.run(self.astart())
|
||||||
# self.start()
|
self.start()
|
||||||
|
|
||||||
# logger.info(f"""remote_url: {self.remote_url},
|
# logger.info(f"""remote_url: {self.remote_url},
|
||||||
# remote_ip: {self.remote_ip},
|
# remote_ip: {self.remote_ip},
|
||||||
|
@ -199,13 +199,13 @@ class PyCodeBox(BaseBox):
|
||||||
|
|
||||||
async def _aget_kernelid(self, ) -> None:
|
async def _aget_kernelid(self, ) -> None:
|
||||||
headers = {"Authorization": f'Token {self.token}', 'token': self.token}
|
headers = {"Authorization": f'Token {self.token}', 'token': self.token}
|
||||||
response = requests.get(f"{self.kernel_url}?token={self.token}", headers=headers)
|
# response = requests.get(f"{self.kernel_url}?token={self.token}", headers=headers)
|
||||||
async with aiohttp.ClientSession() as session:
|
async with aiohttp.ClientSession() as session:
|
||||||
async with session.get(f"{self.kernel_url}?token={self.token}", headers=headers) as resp:
|
async with session.get(f"{self.kernel_url}?token={self.token}", headers=headers, timeout=270) as resp:
|
||||||
if len(await resp.json()) > 0:
|
if len(await resp.json()) > 0:
|
||||||
self.kernel_id = (await resp.json())[0]["id"]
|
self.kernel_id = (await resp.json())[0]["id"]
|
||||||
else:
|
else:
|
||||||
async with session.post(f"{self.kernel_url}?token={self.token}", headers=headers) as response:
|
async with session.post(f"{self.kernel_url}?token={self.token}", headers=headers, timeout=270) as response:
|
||||||
self.kernel_id = (await response.json())["id"]
|
self.kernel_id = (await response.json())["id"]
|
||||||
|
|
||||||
# if len(response.json()) > 0:
|
# if len(response.json()) > 0:
|
||||||
|
@ -220,41 +220,45 @@ class PyCodeBox(BaseBox):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = requests.get(f"{self.kernel_url}?token={self.token}", timeout=270)
|
response = requests.get(f"{self.kernel_url}?token={self.token}", timeout=10)
|
||||||
return response.status_code == 200
|
return response.status_code == 200
|
||||||
except requests.exceptions.ConnectionError:
|
except requests.exceptions.ConnectionError:
|
||||||
return False
|
return False
|
||||||
|
except requests.exceptions.ReadTimeout:
|
||||||
|
return False
|
||||||
|
|
||||||
async def _acheck_connect(self, ) -> bool:
|
async def _acheck_connect(self, ) -> bool:
|
||||||
if self.kernel_url == "":
|
if self.kernel_url == "":
|
||||||
return False
|
return False
|
||||||
try:
|
try:
|
||||||
async with aiohttp.ClientSession() as session:
|
async with aiohttp.ClientSession() as session:
|
||||||
async with session.get(f"{self.kernel_url}?token={self.token}", timeout=270) as resp:
|
async with session.get(f"{self.kernel_url}?token={self.token}", timeout=10) as resp:
|
||||||
return resp.status == 200
|
return resp.status == 200
|
||||||
except aiohttp.ClientConnectorError:
|
except aiohttp.ClientConnectorError:
|
||||||
pass
|
return False
|
||||||
except aiohttp.ServerDisconnectedError:
|
except aiohttp.ServerDisconnectedError:
|
||||||
pass
|
return False
|
||||||
|
|
||||||
def _check_port(self, ) -> bool:
|
def _check_port(self, ) -> bool:
|
||||||
try:
|
try:
|
||||||
response = requests.get(f"{self.remote_ip}:{self.remote_port}", timeout=270)
|
response = requests.get(f"{self.remote_ip}:{self.remote_port}", timeout=10)
|
||||||
logger.warning(f"Port is conflict, please check your codebox's port {self.remote_port}")
|
logger.warning(f"Port is conflict, please check your codebox's port {self.remote_port}")
|
||||||
return response.status_code == 200
|
return response.status_code == 200
|
||||||
except requests.exceptions.ConnectionError:
|
except requests.exceptions.ConnectionError:
|
||||||
return False
|
return False
|
||||||
|
except requests.exceptions.ReadTimeout:
|
||||||
|
return False
|
||||||
|
|
||||||
async def _acheck_port(self, ) -> bool:
|
async def _acheck_port(self, ) -> bool:
|
||||||
try:
|
try:
|
||||||
async with aiohttp.ClientSession() as session:
|
async with aiohttp.ClientSession() as session:
|
||||||
async with session.get(f"{self.remote_ip}:{self.remote_port}", timeout=270) as resp:
|
async with session.get(f"{self.remote_ip}:{self.remote_port}", timeout=10) as resp:
|
||||||
# logger.warning(f"Port is conflict, please check your codebox's port {self.remote_port}")
|
# logger.warning(f"Port is conflict, please check your codebox's port {self.remote_port}")
|
||||||
return resp.status == 200
|
return resp.status == 200
|
||||||
except aiohttp.ClientConnectorError:
|
except aiohttp.ClientConnectorError:
|
||||||
pass
|
return False
|
||||||
except aiohttp.ServerDisconnectedError:
|
except aiohttp.ServerDisconnectedError:
|
||||||
pass
|
return False
|
||||||
|
|
||||||
def _check_connect_success(self, retry_nums: int = 2) -> bool:
|
def _check_connect_success(self, retry_nums: int = 2) -> bool:
|
||||||
if not self.do_check_net: return True
|
if not self.do_check_net: return True
|
||||||
|
@ -263,7 +267,7 @@ class PyCodeBox(BaseBox):
|
||||||
try:
|
try:
|
||||||
connect_status = self._check_connect()
|
connect_status = self._check_connect()
|
||||||
if connect_status:
|
if connect_status:
|
||||||
logger.info(f"{self.remote_url} connection success")
|
# logger.info(f"{self.remote_url} connection success")
|
||||||
return True
|
return True
|
||||||
except requests.exceptions.ConnectionError:
|
except requests.exceptions.ConnectionError:
|
||||||
logger.info(f"{self.remote_url} connection fail")
|
logger.info(f"{self.remote_url} connection fail")
|
||||||
|
@ -301,10 +305,12 @@ class PyCodeBox(BaseBox):
|
||||||
else:
|
else:
|
||||||
# TODO 自动检测本地接口
|
# TODO 自动检测本地接口
|
||||||
port_status = self._check_port()
|
port_status = self._check_port()
|
||||||
|
self.kernel_url = self.remote_url + "/api/kernels"
|
||||||
connect_status = self._check_connect()
|
connect_status = self._check_connect()
|
||||||
logger.info(f"port_status: {port_status}, connect_status: {connect_status}")
|
if os.environ.get("log_verbose", "0") >= "2":
|
||||||
|
logger.info(f"port_status: {port_status}, connect_status: {connect_status}")
|
||||||
if port_status and not connect_status:
|
if port_status and not connect_status:
|
||||||
raise BaseException(f"Port is conflict, please check your codebox's port {self.remote_port}")
|
logger.error("Port is conflict, please check your codebox's port {self.remote_port}")
|
||||||
|
|
||||||
if not connect_status:
|
if not connect_status:
|
||||||
self.jupyter = subprocess.Popen(
|
self.jupyter = subprocess.Popen(
|
||||||
|
@ -321,14 +327,32 @@ class PyCodeBox(BaseBox):
|
||||||
stdout=subprocess.PIPE,
|
stdout=subprocess.PIPE,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
record = []
|
||||||
|
while True and self.jupyter and len(record)<100:
|
||||||
|
line = self.jupyter.stderr.readline()
|
||||||
|
try:
|
||||||
|
content = line.decode("utf-8")
|
||||||
|
except:
|
||||||
|
content = line.decode("gbk")
|
||||||
|
# logger.debug(content)
|
||||||
|
record.append(content)
|
||||||
|
if "control-c" in content.lower():
|
||||||
|
break
|
||||||
|
|
||||||
self.kernel_url = self.remote_url + "/api/kernels"
|
self.kernel_url = self.remote_url + "/api/kernels"
|
||||||
self.do_check_net = True
|
self.do_check_net = True
|
||||||
self._check_connect_success()
|
self._check_connect_success()
|
||||||
self._get_kernelid()
|
self._get_kernelid()
|
||||||
# logger.debug(self.kernel_url.replace("http", "ws") + f"/{self.kernel_id}/channels?token={self.token}")
|
|
||||||
self.wc_url = self.kernel_url.replace("http", "ws") + f"/{self.kernel_id}/channels?token={self.token}"
|
self.wc_url = self.kernel_url.replace("http", "ws") + f"/{self.kernel_id}/channels?token={self.token}"
|
||||||
headers = {"Authorization": f'Token {self.token}', 'token': self.token}
|
headers = {"Authorization": f'Token {self.token}', 'token': self.token}
|
||||||
self.ws = create_connection(self.wc_url, headers=headers)
|
retry_nums = 3
|
||||||
|
while retry_nums>=0:
|
||||||
|
try:
|
||||||
|
self.ws = create_connection(self.wc_url, headers=headers, timeout=5)
|
||||||
|
break
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"create ws connection timeout {e}")
|
||||||
|
retry_nums -= 1
|
||||||
|
|
||||||
async def astart(self, ):
|
async def astart(self, ):
|
||||||
'''判断是从外部service执行还是内部启动notebook执行'''
|
'''判断是从外部service执行还是内部启动notebook执行'''
|
||||||
|
@ -369,10 +393,16 @@ class PyCodeBox(BaseBox):
|
||||||
cwd=self.jupyter_work_path
|
cwd=self.jupyter_work_path
|
||||||
)
|
)
|
||||||
|
|
||||||
while True and self.jupyter:
|
record = []
|
||||||
|
while True and self.jupyter and len(record)<100:
|
||||||
line = self.jupyter.stderr.readline()
|
line = self.jupyter.stderr.readline()
|
||||||
# logger.debug(line.decode("gbk"))
|
try:
|
||||||
if "Control-C" in line.decode("gbk"):
|
content = line.decode("utf-8")
|
||||||
|
except:
|
||||||
|
content = line.decode("gbk")
|
||||||
|
# logger.debug(content)
|
||||||
|
record.append(content)
|
||||||
|
if "control-c" in content.lower():
|
||||||
break
|
break
|
||||||
self.kernel_url = self.remote_url + "/api/kernels"
|
self.kernel_url = self.remote_url + "/api/kernels"
|
||||||
self.do_check_net = True
|
self.do_check_net = True
|
||||||
|
@ -380,7 +410,15 @@ class PyCodeBox(BaseBox):
|
||||||
await self._aget_kernelid()
|
await self._aget_kernelid()
|
||||||
self.wc_url = self.kernel_url.replace("http", "ws") + f"/{self.kernel_id}/channels?token={self.token}"
|
self.wc_url = self.kernel_url.replace("http", "ws") + f"/{self.kernel_id}/channels?token={self.token}"
|
||||||
headers = {"Authorization": f'Token {self.token}', 'token': self.token}
|
headers = {"Authorization": f'Token {self.token}', 'token': self.token}
|
||||||
self.ws = create_connection(self.wc_url, headers=headers)
|
|
||||||
|
retry_nums = 3
|
||||||
|
while retry_nums>=0:
|
||||||
|
try:
|
||||||
|
self.ws = create_connection(self.wc_url, headers=headers, timeout=5)
|
||||||
|
break
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"create ws connection timeout {e}")
|
||||||
|
retry_nums -= 1
|
||||||
|
|
||||||
def status(self,) -> CodeBoxStatus:
|
def status(self,) -> CodeBoxStatus:
|
||||||
if not self.kernel_id:
|
if not self.kernel_id:
|
||||||
|
|
|
@ -17,7 +17,7 @@ from coagent.orm.commands import *
|
||||||
from coagent.utils.path_utils import *
|
from coagent.utils.path_utils import *
|
||||||
from coagent.orm.utils import DocumentFile
|
from coagent.orm.utils import DocumentFile
|
||||||
from coagent.embeddings.utils import load_embeddings, load_embeddings_from_path
|
from coagent.embeddings.utils import load_embeddings, load_embeddings_from_path
|
||||||
from coagent.text_splitter import LCTextSplitter
|
from coagent.retrieval.text_splitter import LCTextSplitter
|
||||||
from coagent.llm_models.llm_config import EmbedConfig
|
from coagent.llm_models.llm_config import EmbedConfig
|
||||||
|
|
||||||
|
|
||||||
|
@ -46,7 +46,7 @@ class KBService(ABC):
|
||||||
|
|
||||||
def _load_embeddings(self) -> Embeddings:
|
def _load_embeddings(self) -> Embeddings:
|
||||||
# return load_embeddings(self.embed_model, embed_device, embedding_model_dict)
|
# return load_embeddings(self.embed_model, embed_device, embedding_model_dict)
|
||||||
return load_embeddings_from_path(self.embed_config.embed_model_path, self.embed_config.model_device)
|
return load_embeddings_from_path(self.embed_config.embed_model_path, self.embed_config.model_device, self.embed_config.langchain_embeddings)
|
||||||
|
|
||||||
def create_kb(self):
|
def create_kb(self):
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -20,9 +20,6 @@ from coagent.utils.path_utils import *
|
||||||
from coagent.orm.commands import *
|
from coagent.orm.commands import *
|
||||||
from coagent.db_handler.graph_db_handler.nebula_handler import NebulaHandler
|
from coagent.db_handler.graph_db_handler.nebula_handler import NebulaHandler
|
||||||
from coagent.db_handler.vector_db_handler.chroma_handler import ChromaHandler
|
from coagent.db_handler.vector_db_handler.chroma_handler import ChromaHandler
|
||||||
# from configs.server_config import NEBULA_HOST, NEBULA_PORT, NEBULA_USER, NEBULA_PASSWORD, NEBULA_STORAGED_PORT
|
|
||||||
# from configs.server_config import CHROMA_PERSISTENT_PATH
|
|
||||||
|
|
||||||
from coagent.base_configs.env_config import (
|
from coagent.base_configs.env_config import (
|
||||||
CB_ROOT_PATH,
|
CB_ROOT_PATH,
|
||||||
NEBULA_HOST, NEBULA_PORT, NEBULA_USER, NEBULA_PASSWORD, NEBULA_STORAGED_PORT,
|
NEBULA_HOST, NEBULA_PORT, NEBULA_USER, NEBULA_PASSWORD, NEBULA_STORAGED_PORT,
|
||||||
|
@ -58,10 +55,11 @@ async def create_cb(zip_file,
|
||||||
model_name: bool = Body(..., examples=["samples"]),
|
model_name: bool = Body(..., examples=["samples"]),
|
||||||
temperature: bool = Body(..., examples=["samples"]),
|
temperature: bool = Body(..., examples=["samples"]),
|
||||||
model_device: bool = Body(..., examples=["samples"]),
|
model_device: bool = Body(..., examples=["samples"]),
|
||||||
|
embed_config: EmbedConfig = None,
|
||||||
) -> BaseResponse:
|
) -> BaseResponse:
|
||||||
logger.info('cb_name={}, zip_path={}, do_interpret={}'.format(cb_name, code_path, do_interpret))
|
logger.info('cb_name={}, zip_path={}, do_interpret={}'.format(cb_name, code_path, do_interpret))
|
||||||
|
|
||||||
embed_config: EmbedConfig = EmbedConfig(**locals())
|
embed_config: EmbedConfig = EmbedConfig(**locals()) if embed_config is None else embed_config
|
||||||
llm_config: LLMConfig = LLMConfig(**locals())
|
llm_config: LLMConfig = LLMConfig(**locals())
|
||||||
|
|
||||||
# Create selected knowledge base
|
# Create selected knowledge base
|
||||||
|
@ -101,9 +99,10 @@ async def delete_cb(
|
||||||
model_name: bool = Body(..., examples=["samples"]),
|
model_name: bool = Body(..., examples=["samples"]),
|
||||||
temperature: bool = Body(..., examples=["samples"]),
|
temperature: bool = Body(..., examples=["samples"]),
|
||||||
model_device: bool = Body(..., examples=["samples"]),
|
model_device: bool = Body(..., examples=["samples"]),
|
||||||
|
embed_config: EmbedConfig = None,
|
||||||
) -> BaseResponse:
|
) -> BaseResponse:
|
||||||
logger.info('cb_name={}'.format(cb_name))
|
logger.info('cb_name={}'.format(cb_name))
|
||||||
embed_config: EmbedConfig = EmbedConfig(**locals())
|
embed_config: EmbedConfig = EmbedConfig(**locals()) if embed_config is None else embed_config
|
||||||
llm_config: LLMConfig = LLMConfig(**locals())
|
llm_config: LLMConfig = LLMConfig(**locals())
|
||||||
# Create selected knowledge base
|
# Create selected knowledge base
|
||||||
if not validate_kb_name(cb_name):
|
if not validate_kb_name(cb_name):
|
||||||
|
@ -143,18 +142,24 @@ def search_code(cb_name: str = Body(..., examples=["sofaboot"]),
|
||||||
model_name: bool = Body(..., examples=["samples"]),
|
model_name: bool = Body(..., examples=["samples"]),
|
||||||
temperature: bool = Body(..., examples=["samples"]),
|
temperature: bool = Body(..., examples=["samples"]),
|
||||||
model_device: bool = Body(..., examples=["samples"]),
|
model_device: bool = Body(..., examples=["samples"]),
|
||||||
|
use_nh: bool = True,
|
||||||
|
local_graph_path: str = '',
|
||||||
|
embed_config: EmbedConfig = None,
|
||||||
) -> dict:
|
) -> dict:
|
||||||
|
|
||||||
logger.info('cb_name={}'.format(cb_name))
|
if os.environ.get("log_verbose", "0") >= "2":
|
||||||
logger.info('query={}'.format(query))
|
logger.info(f'local_graph_path={local_graph_path}')
|
||||||
logger.info('code_limit={}'.format(code_limit))
|
logger.info('cb_name={}'.format(cb_name))
|
||||||
logger.info('search_type={}'.format(search_type))
|
logger.info('query={}'.format(query))
|
||||||
logger.info('history_node_list={}'.format(history_node_list))
|
logger.info('code_limit={}'.format(code_limit))
|
||||||
embed_config: EmbedConfig = EmbedConfig(**locals())
|
logger.info('search_type={}'.format(search_type))
|
||||||
|
logger.info('history_node_list={}'.format(history_node_list))
|
||||||
|
embed_config: EmbedConfig = EmbedConfig(**locals()) if embed_config is None else embed_config
|
||||||
llm_config: LLMConfig = LLMConfig(**locals())
|
llm_config: LLMConfig = LLMConfig(**locals())
|
||||||
try:
|
try:
|
||||||
# load codebase
|
# load codebase
|
||||||
cbh = CodeBaseHandler(codebase_name=cb_name, embed_config=embed_config, llm_config=llm_config)
|
cbh = CodeBaseHandler(codebase_name=cb_name, embed_config=embed_config, llm_config=llm_config,
|
||||||
|
use_nh=use_nh, local_graph_path=local_graph_path)
|
||||||
|
|
||||||
# search code
|
# search code
|
||||||
context, related_vertices = cbh.search_code(query, search_type=search_type, limit=code_limit)
|
context, related_vertices = cbh.search_code(query, search_type=search_type, limit=code_limit)
|
||||||
|
@ -179,11 +184,13 @@ def search_related_vertices(cb_name: str = Body(..., examples=["sofaboot"]),
|
||||||
# load codebase
|
# load codebase
|
||||||
nh = NebulaHandler(host=NEBULA_HOST, port=NEBULA_PORT, username=NEBULA_USER,
|
nh = NebulaHandler(host=NEBULA_HOST, port=NEBULA_PORT, username=NEBULA_USER,
|
||||||
password=NEBULA_PASSWORD, space_name=cb_name)
|
password=NEBULA_PASSWORD, space_name=cb_name)
|
||||||
|
|
||||||
cypher = f'''MATCH (v1)--(v2) WHERE id(v1) == '{vertex}' RETURN id(v2) as id;'''
|
if vertex.endswith(".java"):
|
||||||
|
cypher = f'''MATCH (v1)--(v2:package) WHERE id(v1) == '{vertex}' RETURN id(v2) as id;'''
|
||||||
|
else:
|
||||||
|
cypher = f'''MATCH (v1)--(v2) WHERE id(v1) == '{vertex}' RETURN id(v2) as id;'''
|
||||||
|
# cypher = f'''MATCH (v1)--(v2) WHERE id(v1) == '{vertex}' RETURN v2;'''
|
||||||
cypher_res = nh.execute_cypher(cypher=cypher, format_res=True)
|
cypher_res = nh.execute_cypher(cypher=cypher, format_res=True)
|
||||||
|
|
||||||
related_vertices = cypher_res.get('id', [])
|
related_vertices = cypher_res.get('id', [])
|
||||||
related_vertices = [i.as_string() for i in related_vertices]
|
related_vertices = [i.as_string() for i in related_vertices]
|
||||||
|
|
||||||
|
@ -200,8 +207,8 @@ def search_related_vertices(cb_name: str = Body(..., examples=["sofaboot"]),
|
||||||
def search_code_by_vertex(cb_name: str = Body(..., examples=["sofaboot"]),
|
def search_code_by_vertex(cb_name: str = Body(..., examples=["sofaboot"]),
|
||||||
vertex: str = Body(..., examples=['***'])) -> dict:
|
vertex: str = Body(..., examples=['***'])) -> dict:
|
||||||
|
|
||||||
logger.info('cb_name={}'.format(cb_name))
|
# logger.info('cb_name={}'.format(cb_name))
|
||||||
logger.info('vertex={}'.format(vertex))
|
# logger.info('vertex={}'.format(vertex))
|
||||||
|
|
||||||
try:
|
try:
|
||||||
nh = NebulaHandler(host=NEBULA_HOST, port=NEBULA_PORT, username=NEBULA_USER,
|
nh = NebulaHandler(host=NEBULA_HOST, port=NEBULA_PORT, username=NEBULA_USER,
|
||||||
|
@ -233,7 +240,7 @@ def search_code_by_vertex(cb_name: str = Body(..., examples=["sofaboot"]),
|
||||||
return res
|
return res
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.exception(e)
|
logger.exception(e)
|
||||||
return {}
|
return {'code': ""}
|
||||||
|
|
||||||
|
|
||||||
def cb_exists_api(cb_name: str = Body(..., examples=["sofaboot"])) -> bool:
|
def cb_exists_api(cb_name: str = Body(..., examples=["sofaboot"])) -> bool:
|
||||||
|
|
|
@ -8,17 +8,6 @@ from loguru import logger
|
||||||
from langchain.embeddings.base import Embeddings
|
from langchain.embeddings.base import Embeddings
|
||||||
from langchain.docstore.document import Document
|
from langchain.docstore.document import Document
|
||||||
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
|
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
|
||||||
from langchain.vectorstores.utils import DistanceStrategy
|
|
||||||
|
|
||||||
# from configs.model_config import (
|
|
||||||
# KB_ROOT_PATH,
|
|
||||||
# CACHED_VS_NUM,
|
|
||||||
# EMBEDDING_MODEL,
|
|
||||||
# EMBEDDING_DEVICE,
|
|
||||||
# SCORE_THRESHOLD,
|
|
||||||
# FAISS_NORMALIZE_L2
|
|
||||||
# )
|
|
||||||
# from configs.model_config import embedding_model_dict
|
|
||||||
|
|
||||||
from coagent.base_configs.env_config import (
|
from coagent.base_configs.env_config import (
|
||||||
KB_ROOT_PATH,
|
KB_ROOT_PATH,
|
||||||
|
@ -52,15 +41,15 @@ def load_vector_store(
|
||||||
tick: int = 0, # tick will be changed by upload_doc etc. and make cache refreshed.
|
tick: int = 0, # tick will be changed by upload_doc etc. and make cache refreshed.
|
||||||
kb_root_path: str = KB_ROOT_PATH,
|
kb_root_path: str = KB_ROOT_PATH,
|
||||||
):
|
):
|
||||||
print(f"loading vector store in '{knowledge_base_name}'.")
|
# print(f"loading vector store in '{knowledge_base_name}'.")
|
||||||
vs_path = get_vs_path(knowledge_base_name, kb_root_path)
|
vs_path = get_vs_path(knowledge_base_name, kb_root_path)
|
||||||
if embeddings is None:
|
if embeddings is None:
|
||||||
embeddings = load_embeddings_from_path(embed_config.embed_model_path, embed_config.model_device)
|
embeddings = load_embeddings_from_path(embed_config.embed_model_path, embed_config.model_device, embed_config.langchain_embeddings)
|
||||||
|
|
||||||
if not os.path.exists(vs_path):
|
if not os.path.exists(vs_path):
|
||||||
os.makedirs(vs_path)
|
os.makedirs(vs_path)
|
||||||
|
|
||||||
distance_strategy = DistanceStrategy.EUCLIDEAN_DISTANCE
|
distance_strategy = "EUCLIDEAN_DISTANCE"
|
||||||
if "index.faiss" in os.listdir(vs_path):
|
if "index.faiss" in os.listdir(vs_path):
|
||||||
search_index = FAISS.load_local(vs_path, embeddings, normalize_L2=FAISS_NORMALIZE_L2, distance_strategy=distance_strategy)
|
search_index = FAISS.load_local(vs_path, embeddings, normalize_L2=FAISS_NORMALIZE_L2, distance_strategy=distance_strategy)
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -9,9 +9,7 @@ from pydantic import BaseModel, Field
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
from coagent.llm_models import LLMConfig, EmbedConfig
|
from coagent.llm_models import LLMConfig, EmbedConfig
|
||||||
|
|
||||||
from .base_tool import BaseToolModel
|
from .base_tool import BaseToolModel
|
||||||
|
|
||||||
from coagent.service.cb_api import search_code
|
from coagent.service.cb_api import search_code
|
||||||
|
|
||||||
|
|
||||||
|
@ -29,7 +27,17 @@ class CodeRetrieval(BaseToolModel):
|
||||||
code: str = Field(..., description="检索代码")
|
code: str = Field(..., description="检索代码")
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def run(cls, code_base_name, query, code_limit=1, history_node_list=[], search_type="tag", llm_config: LLMConfig=None, embed_config: EmbedConfig=None):
|
def run(cls,
|
||||||
|
code_base_name,
|
||||||
|
query,
|
||||||
|
code_limit=1,
|
||||||
|
history_node_list=[],
|
||||||
|
search_type="tag",
|
||||||
|
llm_config: LLMConfig=None,
|
||||||
|
embed_config: EmbedConfig=None,
|
||||||
|
use_nh: str=True,
|
||||||
|
local_graph_path: str=''
|
||||||
|
):
|
||||||
"""excute your tool!"""
|
"""excute your tool!"""
|
||||||
|
|
||||||
search_type = {
|
search_type = {
|
||||||
|
@ -45,7 +53,8 @@ class CodeRetrieval(BaseToolModel):
|
||||||
codes = search_code(code_base_name, query, code_limit, search_type=search_type, history_node_list=history_node_list,
|
codes = search_code(code_base_name, query, code_limit, search_type=search_type, history_node_list=history_node_list,
|
||||||
embed_engine=embed_config.embed_engine, embed_model=embed_config.embed_model, embed_model_path=embed_config.embed_model_path,
|
embed_engine=embed_config.embed_engine, embed_model=embed_config.embed_model, embed_model_path=embed_config.embed_model_path,
|
||||||
model_device=embed_config.model_device, model_name=llm_config.model_name, temperature=llm_config.temperature,
|
model_device=embed_config.model_device, model_name=llm_config.model_name, temperature=llm_config.temperature,
|
||||||
api_base_url=llm_config.api_base_url, api_key=llm_config.api_key
|
api_base_url=llm_config.api_base_url, api_key=llm_config.api_key, use_nh=use_nh,
|
||||||
|
local_graph_path=local_graph_path, embed_config=embed_config
|
||||||
)
|
)
|
||||||
return_codes = []
|
return_codes = []
|
||||||
context = codes['context']
|
context = codes['context']
|
||||||
|
|
|
@ -5,6 +5,7 @@
|
||||||
@time: 2023/12/14 上午10:24
|
@time: 2023/12/14 上午10:24
|
||||||
@desc:
|
@desc:
|
||||||
'''
|
'''
|
||||||
|
import os
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
|
@ -40,10 +41,9 @@ class CodeRetrievalSingle(BaseToolModel):
|
||||||
vertex: str = Field(..., description="代码对应 id")
|
vertex: str = Field(..., description="代码对应 id")
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def run(cls, code_base_name, query, embed_config: EmbedConfig, llm_config: LLMConfig, **kargs):
|
def run(cls, code_base_name, query, embed_config: EmbedConfig, llm_config: LLMConfig, search_type="description", **kargs):
|
||||||
"""excute your tool!"""
|
"""excute your tool!"""
|
||||||
|
|
||||||
search_type = 'description'
|
|
||||||
code_limit = 1
|
code_limit = 1
|
||||||
|
|
||||||
# default
|
# default
|
||||||
|
@ -51,10 +51,11 @@ class CodeRetrievalSingle(BaseToolModel):
|
||||||
history_node_list=[],
|
history_node_list=[],
|
||||||
embed_engine=embed_config.embed_engine, embed_model=embed_config.embed_model, embed_model_path=embed_config.embed_model_path,
|
embed_engine=embed_config.embed_engine, embed_model=embed_config.embed_model, embed_model_path=embed_config.embed_model_path,
|
||||||
model_device=embed_config.model_device, model_name=llm_config.model_name, temperature=llm_config.temperature,
|
model_device=embed_config.model_device, model_name=llm_config.model_name, temperature=llm_config.temperature,
|
||||||
api_base_url=llm_config.api_base_url, api_key=llm_config.api_key
|
api_base_url=llm_config.api_base_url, api_key=llm_config.api_key, embed_config=embed_config, use_nh=kargs.get("use_nh", True),
|
||||||
|
local_graph_path=kargs.get("local_graph_path", "")
|
||||||
)
|
)
|
||||||
|
if os.environ.get("log_verbose", "0") >= "3":
|
||||||
logger.debug(search_result)
|
logger.debug(search_result)
|
||||||
code = search_result['context']
|
code = search_result['context']
|
||||||
vertex = search_result['related_vertices'][0]
|
vertex = search_result['related_vertices'][0]
|
||||||
# logger.debug(f"code: {code}, vertex: {vertex}")
|
# logger.debug(f"code: {code}, vertex: {vertex}")
|
||||||
|
@ -83,7 +84,7 @@ class RelatedVerticesRetrival(BaseToolModel):
|
||||||
def run(cls, code_base_name: str, vertex: str, **kargs):
|
def run(cls, code_base_name: str, vertex: str, **kargs):
|
||||||
"""execute your tool!"""
|
"""execute your tool!"""
|
||||||
related_vertices = search_related_vertices(cb_name=code_base_name, vertex=vertex)
|
related_vertices = search_related_vertices(cb_name=code_base_name, vertex=vertex)
|
||||||
logger.debug(f"related_vertices: {related_vertices}")
|
# logger.debug(f"related_vertices: {related_vertices}")
|
||||||
|
|
||||||
return related_vertices
|
return related_vertices
|
||||||
|
|
||||||
|
@ -110,6 +111,6 @@ class Vertex2Code(BaseToolModel):
|
||||||
else:
|
else:
|
||||||
vertex = vertex.strip(' "')
|
vertex = vertex.strip(' "')
|
||||||
|
|
||||||
logger.info(f'vertex={vertex}')
|
# logger.info(f'vertex={vertex}')
|
||||||
res = search_code_by_vertex(cb_name=code_base_name, vertex=vertex)
|
res = search_code_by_vertex(cb_name=code_base_name, vertex=vertex)
|
||||||
return res
|
return res
|
|
@ -2,11 +2,7 @@ from pydantic import BaseModel, Field
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
from coagent.llm_models.llm_config import EmbedConfig
|
from coagent.llm_models.llm_config import EmbedConfig
|
||||||
|
|
||||||
from .base_tool import BaseToolModel
|
from .base_tool import BaseToolModel
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
from coagent.service.kb_api import search_docs
|
from coagent.service.kb_api import search_docs
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -9,8 +9,10 @@ import numpy as np
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
from .base_tool import BaseToolModel
|
from .base_tool import BaseToolModel
|
||||||
|
try:
|
||||||
from duckduckgo_search import DDGS
|
from duckduckgo_search import DDGS
|
||||||
|
except:
|
||||||
|
logger.warning("can't find duckduckgo_search, if you need it, please `pip install duckduckgo_search`")
|
||||||
|
|
||||||
|
|
||||||
class DDGSTool(BaseToolModel):
|
class DDGSTool(BaseToolModel):
|
||||||
|
|
|
@ -0,0 +1,89 @@
|
||||||
|
import json
|
||||||
|
|
||||||
|
|
||||||
|
def class_info_decode(data):
|
||||||
|
'''解析class的相关信息'''
|
||||||
|
params_dict = {}
|
||||||
|
|
||||||
|
for i in data:
|
||||||
|
_params_dict = {}
|
||||||
|
for ii in i:
|
||||||
|
for k, v in ii.items():
|
||||||
|
if k=="origin_query": continue
|
||||||
|
|
||||||
|
if k == "Code Path":
|
||||||
|
_params_dict["code_path"] = v.split("#")[0]
|
||||||
|
_params_dict["function_name"] = ".".join(v.split("#")[1:])
|
||||||
|
|
||||||
|
if k == "Class Description":
|
||||||
|
_params_dict["ClassDescription"] = v
|
||||||
|
|
||||||
|
if k == "Class Base":
|
||||||
|
_params_dict["ClassBase"] = v
|
||||||
|
|
||||||
|
if k=="Init Parameters":
|
||||||
|
_params_dict["Parameters"] = v
|
||||||
|
|
||||||
|
|
||||||
|
code_path = _params_dict["code_path"]
|
||||||
|
params_dict.setdefault(code_path, []).append(_params_dict)
|
||||||
|
|
||||||
|
return params_dict
|
||||||
|
|
||||||
|
def method_info_decode(data):
|
||||||
|
params_dict = {}
|
||||||
|
|
||||||
|
for i in data:
|
||||||
|
_params_dict = {}
|
||||||
|
for ii in i:
|
||||||
|
for k, v in ii.items():
|
||||||
|
if k=="origin_query": continue
|
||||||
|
|
||||||
|
if k == "Code Path":
|
||||||
|
_params_dict["code_path"] = v.split("#")[0]
|
||||||
|
_params_dict["function_name"] = ".".join(v.split("#")[1:])
|
||||||
|
|
||||||
|
if k == "Return Value Description":
|
||||||
|
_params_dict["Returns"] = v
|
||||||
|
|
||||||
|
if k == "Return Type":
|
||||||
|
_params_dict["ReturnType"] = v
|
||||||
|
|
||||||
|
if k=="Parameters":
|
||||||
|
_params_dict["Parameters"] = v
|
||||||
|
|
||||||
|
|
||||||
|
code_path = _params_dict["code_path"]
|
||||||
|
params_dict.setdefault(code_path, []).append(_params_dict)
|
||||||
|
|
||||||
|
return params_dict
|
||||||
|
|
||||||
|
def encode2md(data, md_format):
|
||||||
|
md_dict = {}
|
||||||
|
for code_path, params_list in data.items():
|
||||||
|
for params in params_list:
|
||||||
|
params["Parameters_text"] = "\n".join([f"{param['param']}({param['param_type']})-{param['param_description']}"
|
||||||
|
for param in params["Parameters"]])
|
||||||
|
# params.delete("Parameters")
|
||||||
|
text=md_format.format(**params)
|
||||||
|
md_dict.setdefault(code_path, []).append(text)
|
||||||
|
return md_dict
|
||||||
|
|
||||||
|
|
||||||
|
method_text_md = '''> {function_name}
|
||||||
|
|
||||||
|
| Column Name | Content |
|
||||||
|
|-----------------|-----------------|
|
||||||
|
| Parameters | {Parameters_text} |
|
||||||
|
| Returns | {Returns} |
|
||||||
|
| Return type | {ReturnType} |
|
||||||
|
'''
|
||||||
|
|
||||||
|
class_text_md = '''> {code_path}
|
||||||
|
|
||||||
|
Bases: {ClassBase}
|
||||||
|
|
||||||
|
{ClassDescription}
|
||||||
|
|
||||||
|
{Parameters_text}
|
||||||
|
'''
|
|
@ -7,7 +7,7 @@ from pathlib import Path
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
from fastapi import Body, File, Form, Body, Query, UploadFile
|
from fastapi import Body, File, Form, Body, Query, UploadFile
|
||||||
from tempfile import SpooledTemporaryFile
|
from tempfile import SpooledTemporaryFile
|
||||||
|
import json
|
||||||
|
|
||||||
|
|
||||||
DATE_FORMAT = "%Y-%m-%d %H:%M:%S"
|
DATE_FORMAT = "%Y-%m-%d %H:%M:%S"
|
||||||
|
@ -109,4 +109,6 @@ def get_uploadfile(file: Union[str, Path, bytes], filename=None) -> UploadFile:
|
||||||
temp_file = SpooledTemporaryFile(max_size=10 * 1024 * 1024)
|
temp_file = SpooledTemporaryFile(max_size=10 * 1024 * 1024)
|
||||||
temp_file.write(file.read())
|
temp_file.write(file.read())
|
||||||
temp_file.seek(0)
|
temp_file.seek(0)
|
||||||
return UploadFile(file=temp_file, filename=filename)
|
return UploadFile(file=temp_file, filename=filename)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
import os
|
import os
|
||||||
from langchain.document_loaders import CSVLoader, PyPDFLoader, UnstructuredFileLoader, TextLoader, PythonLoader
|
from langchain.document_loaders import CSVLoader, PyPDFLoader, UnstructuredFileLoader, TextLoader, PythonLoader
|
||||||
|
|
||||||
from coagent.document_loaders import JSONLLoader, JSONLoader
|
from coagent.retrieval.document_loaders import JSONLLoader, JSONLoader
|
||||||
# from configs.model_config import (
|
# from configs.model_config import (
|
||||||
# embedding_model_dict,
|
# embedding_model_dict,
|
||||||
# KB_ROOT_PATH,
|
# KB_ROOT_PATH,
|
||||||
|
|
|
@ -21,17 +21,20 @@ JUPYTER_WORK_PATH = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath
|
||||||
# WEB_CRAWL存储路径
|
# WEB_CRAWL存储路径
|
||||||
WEB_CRAWL_PATH = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "knowledge_base")
|
WEB_CRAWL_PATH = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "knowledge_base")
|
||||||
# NEBULA_DATA存储路径
|
# NEBULA_DATA存储路径
|
||||||
NELUBA_PATH = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "data/neluba_data")
|
NEBULA_PATH = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "data/nebula_data")
|
||||||
|
|
||||||
for _path in [LOG_PATH, SOURCE_PATH, KB_ROOT_PATH, NLTK_DATA_PATH, JUPYTER_WORK_PATH, WEB_CRAWL_PATH, NELUBA_PATH]:
|
# CHROMA 存储路径
|
||||||
|
CHROMA_PERSISTENT_PATH = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "data/chroma_data")
|
||||||
|
|
||||||
|
for _path in [LOG_PATH, SOURCE_PATH, KB_ROOT_PATH, CB_ROOT_PATH, NLTK_DATA_PATH, JUPYTER_WORK_PATH, WEB_CRAWL_PATH, NEBULA_PATH, CHROMA_PERSISTENT_PATH]:
|
||||||
if not os.path.exists(_path):
|
if not os.path.exists(_path):
|
||||||
os.makedirs(_path, exist_ok=True)
|
os.makedirs(_path, exist_ok=True)
|
||||||
|
|
||||||
#
|
|
||||||
path_envt_dict = {
|
path_envt_dict = {
|
||||||
"LOG_PATH": LOG_PATH, "SOURCE_PATH": SOURCE_PATH, "KB_ROOT_PATH": KB_ROOT_PATH,
|
"LOG_PATH": LOG_PATH, "SOURCE_PATH": SOURCE_PATH, "KB_ROOT_PATH": KB_ROOT_PATH,
|
||||||
"NLTK_DATA_PATH":NLTK_DATA_PATH, "JUPYTER_WORK_PATH": JUPYTER_WORK_PATH,
|
"NLTK_DATA_PATH":NLTK_DATA_PATH, "JUPYTER_WORK_PATH": JUPYTER_WORK_PATH,
|
||||||
"WEB_CRAWL_PATH": WEB_CRAWL_PATH, "NELUBA_PATH": NELUBA_PATH
|
"WEB_CRAWL_PATH": WEB_CRAWL_PATH, "NEBULA_PATH": NEBULA_PATH,
|
||||||
|
"CHROMA_PERSISTENT_PATH": CHROMA_PERSISTENT_PATH
|
||||||
}
|
}
|
||||||
for path_name, _path in path_envt_dict.items():
|
for path_name, _path in path_envt_dict.items():
|
||||||
os.environ[path_name] = _path
|
os.environ[path_name] = _path
|
||||||
|
|
|
@ -33,7 +33,7 @@ except:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
# add your openai key
|
# add your openai key
|
||||||
OPENAI_API_BASE = "http://openai.com/v1/chat/completions"
|
OPENAI_API_BASE = "https://api.openai.com/v1"
|
||||||
os.environ["API_BASE_URL"] = OPENAI_API_BASE
|
os.environ["API_BASE_URL"] = OPENAI_API_BASE
|
||||||
os.environ["OPENAI_API_KEY"] = "sk-xx"
|
os.environ["OPENAI_API_KEY"] = "sk-xx"
|
||||||
openai.api_key = "sk-xx"
|
openai.api_key = "sk-xx"
|
||||||
|
|
|
@ -58,9 +58,6 @@ NEBULA_GRAPH_SERVER = {
|
||||||
"docker_port": NEBULA_PORT
|
"docker_port": NEBULA_PORT
|
||||||
}
|
}
|
||||||
|
|
||||||
# chroma conf
|
|
||||||
CHROMA_PERSISTENT_PATH = '/home/user/chatbot/data/chroma_data'
|
|
||||||
|
|
||||||
# sandbox api server
|
# sandbox api server
|
||||||
SANDBOX_CONTRAINER_NAME = "devopsgpt_sandbox"
|
SANDBOX_CONTRAINER_NAME = "devopsgpt_sandbox"
|
||||||
SANDBOX_IMAGE_NAME = "devopsgpt:py39"
|
SANDBOX_IMAGE_NAME = "devopsgpt:py39"
|
||||||
|
|
|
@ -15,11 +15,11 @@ from coagent.connector.schema import Message
|
||||||
#
|
#
|
||||||
tools = toLangchainTools([TOOL_DICT[i] for i in TOOL_SETS if i in TOOL_DICT])
|
tools = toLangchainTools([TOOL_DICT[i] for i in TOOL_SETS if i in TOOL_DICT])
|
||||||
# log-level,print prompt和llm predict
|
# log-level,print prompt和llm predict
|
||||||
os.environ["log_verbose"] = "0"
|
os.environ["log_verbose"] = "2"
|
||||||
|
|
||||||
phase_name = "baseGroupPhase"
|
phase_name = "baseGroupPhase"
|
||||||
llm_config = LLMConfig(
|
llm_config = LLMConfig(
|
||||||
model_name=LLM_MODEL, model_device="cpu",api_key=os.environ["OPENAI_API_KEY"],
|
model_name=LLM_MODEL, api_key=os.environ["OPENAI_API_KEY"],
|
||||||
api_base_url=os.environ["API_BASE_URL"], temperature=0.3
|
api_base_url=os.environ["API_BASE_URL"], temperature=0.3
|
||||||
)
|
)
|
||||||
embed_config = EmbedConfig(
|
embed_config = EmbedConfig(
|
||||||
|
|
|
@ -17,7 +17,7 @@ os.environ["log_verbose"] = "2"
|
||||||
|
|
||||||
phase_name = "baseTaskPhase"
|
phase_name = "baseTaskPhase"
|
||||||
llm_config = LLMConfig(
|
llm_config = LLMConfig(
|
||||||
model_name="gpt-3.5-turbo", model_device="cpu",api_key=os.environ["OPENAI_API_KEY"],
|
model_name="gpt-3.5-turbo", api_key=os.environ["OPENAI_API_KEY"],
|
||||||
api_base_url=os.environ["API_BASE_URL"], temperature=0.3
|
api_base_url=os.environ["API_BASE_URL"], temperature=0.3
|
||||||
)
|
)
|
||||||
embed_config = EmbedConfig(
|
embed_config = EmbedConfig(
|
||||||
|
|
|
@ -0,0 +1,135 @@
|
||||||
|
# encoding: utf-8
|
||||||
|
'''
|
||||||
|
@author: 温进
|
||||||
|
@file: codeChatPhaseLocal_example.py
|
||||||
|
@time: 2024/1/31 下午4:32
|
||||||
|
@desc:
|
||||||
|
'''
|
||||||
|
import os, sys, requests
|
||||||
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
import requests
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
src_dir = os.path.join(
|
||||||
|
os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||||
|
)
|
||||||
|
sys.path.append(src_dir)
|
||||||
|
|
||||||
|
from configs.model_config import KB_ROOT_PATH, JUPYTER_WORK_PATH, CB_ROOT_PATH
|
||||||
|
from configs.server_config import SANDBOX_SERVER
|
||||||
|
from coagent.tools import toLangchainTools, TOOL_DICT, TOOL_SETS
|
||||||
|
from coagent.llm_models.llm_config import EmbedConfig, LLMConfig
|
||||||
|
from coagent.connector.phase import BasePhase
|
||||||
|
from coagent.connector.schema import Message, Memory
|
||||||
|
from coagent.codechat.codebase_handler.codebase_handler import CodeBaseHandler
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# log-level,print prompt和llm predict
|
||||||
|
os.environ["log_verbose"] = "1"
|
||||||
|
|
||||||
|
llm_config = LLMConfig(
|
||||||
|
model_name="gpt-3.5-turbo", model_device="cpu",api_key=os.environ["OPENAI_API_KEY"],
|
||||||
|
api_base_url=os.environ["API_BASE_URL"], temperature=0.3
|
||||||
|
)
|
||||||
|
embed_config = EmbedConfig(
|
||||||
|
embed_engine="model", embed_model="text2vec-base-chinese",
|
||||||
|
embed_model_path=os.path.join(src_dir, "embedding_models/text2vec-base-chinese")
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# delete codebase
|
||||||
|
codebase_name = 'client_local'
|
||||||
|
code_path = '/Users/bingxu/Desktop/工作/大模型/chatbot/test_code_repo/client'
|
||||||
|
code_path = "D://chromeDownloads/devopschat-bot/client_v2/client"
|
||||||
|
use_nh = True
|
||||||
|
# cbh = CodeBaseHandler(codebase_name, code_path, crawl_type='dir', use_nh=use_nh, local_graph_path=CB_ROOT_PATH,
|
||||||
|
# llm_config=llm_config, embed_config=embed_config)
|
||||||
|
cbh = CodeBaseHandler(codebase_name, code_path, crawl_type='dir', use_nh=use_nh, local_graph_path=CB_ROOT_PATH,
|
||||||
|
llm_config=llm_config, embed_config=embed_config)
|
||||||
|
cbh.delete_codebase(codebase_name=codebase_name)
|
||||||
|
|
||||||
|
|
||||||
|
# initialize codebase
|
||||||
|
codebase_name = 'client_local'
|
||||||
|
code_path = '/Users/bingxu/Desktop/工作/大模型/chatbot/test_code_repo/client'
|
||||||
|
code_path = "D://chromeDownloads/devopschat-bot/client_v2/client"
|
||||||
|
code_path = "/home/user/client"
|
||||||
|
use_nh = True
|
||||||
|
do_interpret = True
|
||||||
|
cbh = CodeBaseHandler(codebase_name, code_path, crawl_type='dir', use_nh=use_nh, local_graph_path=CB_ROOT_PATH,
|
||||||
|
llm_config=llm_config, embed_config=embed_config)
|
||||||
|
cbh.import_code(do_interpret=do_interpret)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# chat with codebase
|
||||||
|
phase_name = "codeChatPhase"
|
||||||
|
phase = BasePhase(
|
||||||
|
phase_name, sandbox_server=SANDBOX_SERVER, jupyter_work_path=JUPYTER_WORK_PATH,
|
||||||
|
embed_config=embed_config, llm_config=llm_config, kb_root_path=KB_ROOT_PATH,
|
||||||
|
)
|
||||||
|
|
||||||
|
# remove 这个函数是做什么的 => 基于标签
|
||||||
|
# 有没有函数已经实现了从字符串删除指定字符串的功能,使用的话可以怎么使用,写个java代码 => 基于描述
|
||||||
|
# 有根据我以下的需求用 java 开发一个方法:输入为字符串,将输入中的 .java 字符串给删除掉,然后返回新的字符串 => 基于描述
|
||||||
|
|
||||||
|
## 需要启动容器中的nebula,采用use_nh=True来构建代码库,是可以通过cypher来查询
|
||||||
|
# round-1
|
||||||
|
# query_content = "代码一共有多少类"
|
||||||
|
# query = Message(
|
||||||
|
# role_name="human", role_type="user",
|
||||||
|
# role_content=query_content, input_query=query_content, origin_query=query_content,
|
||||||
|
# code_engine_name="client_1", score_threshold=1.0, top_k=3, cb_search_type="cypher"
|
||||||
|
# )
|
||||||
|
#
|
||||||
|
# output_message1, _ = phase.step(query)
|
||||||
|
# print(output_message1)
|
||||||
|
|
||||||
|
# round-2
|
||||||
|
# query_content = "代码库里有哪些函数,返回5个就行"
|
||||||
|
# query = Message(
|
||||||
|
# role_name="human", role_type="user",
|
||||||
|
# role_content=query_content, input_query=query_content, origin_query=query_content,
|
||||||
|
# code_engine_name="client_1", score_threshold=1.0, top_k=3, cb_search_type="cypher"
|
||||||
|
# )
|
||||||
|
# output_message2, _ = phase.step(query)
|
||||||
|
# print(output_message2)
|
||||||
|
|
||||||
|
|
||||||
|
# round-3
|
||||||
|
query_content = "remove 这个函数是做什么的"
|
||||||
|
query = Message(
|
||||||
|
role_name="user", role_type="human",
|
||||||
|
role_content=query_content, input_query=query_content, origin_query=query_content,
|
||||||
|
code_engine_name=codebase_name, score_threshold=1.0, top_k=3, cb_search_type="tag",
|
||||||
|
use_nh=False, local_graph_path=CB_ROOT_PATH
|
||||||
|
)
|
||||||
|
output_message3, output_memory3 = phase.step(query)
|
||||||
|
print(output_memory3.to_str_messages(return_all=True, content_key="parsed_output_list"))
|
||||||
|
|
||||||
|
#
|
||||||
|
# # round-4
|
||||||
|
query_content = "有没有函数已经实现了从字符串删除指定字符串的功能,使用的话可以怎么使用,写个java代码"
|
||||||
|
query = Message(
|
||||||
|
role_name="human", role_type="user",
|
||||||
|
role_content=query_content, input_query=query_content, origin_query=query_content,
|
||||||
|
code_engine_name=codebase_name, score_threshold=1.0, top_k=3, cb_search_type="description",
|
||||||
|
use_nh=False, local_graph_path=CB_ROOT_PATH
|
||||||
|
)
|
||||||
|
output_message4, output_memory4 = phase.step(query)
|
||||||
|
print(output_memory4.to_str_messages(return_all=True, content_key="parsed_output_list"))
|
||||||
|
|
||||||
|
|
||||||
|
# # round-5
|
||||||
|
query_content = "有根据我以下的需求用 java 开发一个方法:输入为字符串,将输入中的 .java 字符串给删除掉,然后返回新的字符串"
|
||||||
|
query = Message(
|
||||||
|
role_name="human", role_type="user",
|
||||||
|
role_content=query_content, input_query=query_content, origin_query=query_content,
|
||||||
|
code_engine_name=codebase_name, score_threshold=1.0, top_k=3, cb_search_type="description",
|
||||||
|
use_nh=False, local_graph_path=CB_ROOT_PATH
|
||||||
|
)
|
||||||
|
output_message5, output_memory5 = phase.step(query)
|
||||||
|
print(output_memory5.to_str_messages(return_all=True, content_key="parsed_output_list"))
|
|
@ -17,13 +17,14 @@ os.environ["log_verbose"] = "2"
|
||||||
|
|
||||||
phase_name = "codeChatPhase"
|
phase_name = "codeChatPhase"
|
||||||
llm_config = LLMConfig(
|
llm_config = LLMConfig(
|
||||||
model_name="gpt-3.5-turbo", model_device="cpu",api_key=os.environ["OPENAI_API_KEY"],
|
model_name="gpt-3.5-turbo", api_key=os.environ["OPENAI_API_KEY"],
|
||||||
api_base_url=os.environ["API_BASE_URL"], temperature=0.3
|
api_base_url=os.environ["API_BASE_URL"], temperature=0.3
|
||||||
)
|
)
|
||||||
embed_config = EmbedConfig(
|
embed_config = EmbedConfig(
|
||||||
embed_engine="model", embed_model="text2vec-base-chinese",
|
embed_engine="model", embed_model="text2vec-base-chinese",
|
||||||
embed_model_path=os.path.join(src_dir, "embedding_models/text2vec-base-chinese")
|
embed_model_path=os.path.join(src_dir, "embedding_models/text2vec-base-chinese")
|
||||||
)
|
)
|
||||||
|
|
||||||
phase = BasePhase(
|
phase = BasePhase(
|
||||||
phase_name, sandbox_server=SANDBOX_SERVER, jupyter_work_path=JUPYTER_WORK_PATH,
|
phase_name, sandbox_server=SANDBOX_SERVER, jupyter_work_path=JUPYTER_WORK_PATH,
|
||||||
embed_config=embed_config, llm_config=llm_config, kb_root_path=KB_ROOT_PATH,
|
embed_config=embed_config, llm_config=llm_config, kb_root_path=KB_ROOT_PATH,
|
||||||
|
@ -35,50 +36,56 @@ phase = BasePhase(
|
||||||
# 有根据我以下的需求用 java 开发一个方法:输入为字符串,将输入中的 .java 字符串给删除掉,然后返回新的字符串 => 基于描述
|
# 有根据我以下的需求用 java 开发一个方法:输入为字符串,将输入中的 .java 字符串给删除掉,然后返回新的字符串 => 基于描述
|
||||||
|
|
||||||
# round-1
|
# round-1
|
||||||
query_content = "代码一共有多少类"
|
# query_content = "代码一共有多少类"
|
||||||
query = Message(
|
# query = Message(
|
||||||
role_name="human", role_type="user",
|
# role_name="human", role_type="user",
|
||||||
role_content=query_content, input_query=query_content, origin_query=query_content,
|
# role_content=query_content, input_query=query_content, origin_query=query_content,
|
||||||
code_engine_name="client", score_threshold=1.0, top_k=3, cb_search_type="cypher"
|
# code_engine_name="client_1", score_threshold=1.0, top_k=3, cb_search_type="cypher"
|
||||||
)
|
# )
|
||||||
|
#
|
||||||
output_message1, _ = phase.step(query)
|
# output_message1, _ = phase.step(query)
|
||||||
|
# print(output_message1)
|
||||||
|
|
||||||
# round-2
|
# round-2
|
||||||
query_content = "代码库里有哪些函数,返回5个就行"
|
# query_content = "代码库里有哪些函数,返回5个就行"
|
||||||
query = Message(
|
# query = Message(
|
||||||
role_name="human", role_type="user",
|
# role_name="human", role_type="user",
|
||||||
role_content=query_content, input_query=query_content, origin_query=query_content,
|
# role_content=query_content, input_query=query_content, origin_query=query_content,
|
||||||
code_engine_name="client", score_threshold=1.0, top_k=3, cb_search_type="cypher"
|
# code_engine_name="client_1", score_threshold=1.0, top_k=3, cb_search_type="cypher"
|
||||||
)
|
# )
|
||||||
output_message2, _ = phase.step(query)
|
# output_message2, _ = phase.step(query)
|
||||||
|
# print(output_message2)
|
||||||
|
|
||||||
# round-3
|
#
|
||||||
|
# # round-3
|
||||||
query_content = "remove 这个函数是做什么的"
|
query_content = "remove 这个函数是做什么的"
|
||||||
query = Message(
|
query = Message(
|
||||||
role_name="user", role_type="human",
|
role_name="user", role_type="human",
|
||||||
role_content=query_content, input_query=query_content, origin_query=query_content,
|
role_content=query_content, input_query=query_content, origin_query=query_content,
|
||||||
code_engine_name="client", score_threshold=1.0, top_k=3, cb_search_type="tag"
|
code_engine_name="client", score_threshold=1.0, top_k=3, cb_search_type="tag"
|
||||||
)
|
)
|
||||||
output_message3, _ = phase.step(query)
|
output_message3, _ = phase.step(query)
|
||||||
|
print(output_message3)
|
||||||
|
|
||||||
# round-4
|
#
|
||||||
query_content = "有没有函数已经实现了从字符串删除指定字符串的功能,使用的话可以怎么使用,写个java代码"
|
# # round-4
|
||||||
query = Message(
|
# query_content = "有没有函数已经实现了从字符串删除指定字符串的功能,使用的话可以怎么使用,写个java代码"
|
||||||
role_name="human", role_type="user",
|
# query = Message(
|
||||||
role_content=query_content, input_query=query_content, origin_query=query_content,
|
# role_name="human", role_type="user",
|
||||||
code_engine_name="client", score_threshold=1.0, top_k=3, cb_search_type="description"
|
# role_content=query_content, input_query=query_content, origin_query=query_content,
|
||||||
)
|
# code_engine_name="client_1", score_threshold=1.0, top_k=3, cb_search_type="description"
|
||||||
output_message4, _ = phase.step(query)
|
# )
|
||||||
|
# output_message4, _ = phase.step(query)
|
||||||
|
# print(output_message4)
|
||||||
# round-5
|
#
|
||||||
query_content = "有根据我以下的需求用 java 开发一个方法:输入为字符串,将输入中的 .java 字符串给删除掉,然后返回新的字符串"
|
# # round-5
|
||||||
query = Message(
|
# query_content = "有根据我以下的需求用 java 开发一个方法:输入为字符串,将输入中的 .java 字符串给删除掉,然后返回新的字符串"
|
||||||
role_name="human", role_type="user",
|
# query = Message(
|
||||||
role_content=query_content, input_query=query_content, origin_query=query_content,
|
# role_name="human", role_type="user",
|
||||||
code_engine_name="client", score_threshold=1.0, top_k=3, cb_search_type="description"
|
# role_content=query_content, input_query=query_content, origin_query=query_content,
|
||||||
)
|
# code_engine_name="client_1", score_threshold=1.0, top_k=3, cb_search_type="description"
|
||||||
output_message5, output_memory5 = phase.step(query)
|
# )
|
||||||
|
# output_message5, output_memory5 = phase.step(query)
|
||||||
print(output_memory5.to_str_messages(return_all=True, content_key="parsed_output_list"))
|
# print(output_message5)
|
||||||
|
#
|
||||||
|
# print(output_memory5.to_str_messages(return_all=True, content_key="parsed_output_list"))
|
|
@ -0,0 +1,507 @@
|
||||||
|
import os, sys, json
|
||||||
|
from loguru import logger
|
||||||
|
src_dir = os.path.join(
|
||||||
|
os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||||
|
)
|
||||||
|
sys.path.append(src_dir)
|
||||||
|
|
||||||
|
from configs.model_config import KB_ROOT_PATH, JUPYTER_WORK_PATH, CB_ROOT_PATH
|
||||||
|
from configs.server_config import SANDBOX_SERVER
|
||||||
|
from coagent.llm_models.llm_config import EmbedConfig, LLMConfig
|
||||||
|
|
||||||
|
from coagent.connector.phase import BasePhase
|
||||||
|
from coagent.connector.agents import BaseAgent
|
||||||
|
from coagent.connector.schema import Message
|
||||||
|
from coagent.tools import CodeRetrievalSingle
|
||||||
|
from coagent.codechat.codebase_handler.codebase_handler import CodeBaseHandler
|
||||||
|
import importlib
|
||||||
|
|
||||||
|
|
||||||
|
# 定义一个新的agent类
|
||||||
|
class CodeGenDocer(BaseAgent):
|
||||||
|
|
||||||
|
def start_action_step(self, message: Message) -> Message:
|
||||||
|
'''do action before agent predict '''
|
||||||
|
# 根据问题获取代码片段和节点信息
|
||||||
|
action_json = CodeRetrievalSingle.run(message.code_engine_name, message.origin_query,
|
||||||
|
llm_config=self.llm_config, embed_config=self.embed_config, local_graph_path=message.local_graph_path, use_nh=message.use_nh,search_type="tag")
|
||||||
|
current_vertex = action_json['vertex']
|
||||||
|
message.customed_kargs["Code Snippet"] = action_json["code"]
|
||||||
|
message.customed_kargs['Current_Vertex'] = current_vertex
|
||||||
|
return message
|
||||||
|
|
||||||
|
|
||||||
|
# add agent or prompt_manager class
|
||||||
|
agent_module = importlib.import_module("coagent.connector.agents")
|
||||||
|
setattr(agent_module, 'CodeGenDocer', CodeGenDocer)
|
||||||
|
|
||||||
|
|
||||||
|
# log-level,print prompt和llm predict
|
||||||
|
os.environ["log_verbose"] = "1"
|
||||||
|
|
||||||
|
phase_name = "code2DocsGroup"
|
||||||
|
llm_config = LLMConfig(
|
||||||
|
model_name="gpt-4", api_key=os.environ["OPENAI_API_KEY"],
|
||||||
|
api_base_url=os.environ["API_BASE_URL"], temperature=0.3
|
||||||
|
)
|
||||||
|
embed_config = EmbedConfig(
|
||||||
|
embed_engine="model", embed_model="text2vec-base-chinese",
|
||||||
|
embed_model_path=os.path.join(src_dir, "embedding_models/text2vec-base-chinese")
|
||||||
|
)
|
||||||
|
|
||||||
|
# initialize codebase
|
||||||
|
# delete codebase
|
||||||
|
codebase_name = 'client_local'
|
||||||
|
code_path = "D://chromeDownloads/devopschat-bot/client_v2/client"
|
||||||
|
use_nh = False
|
||||||
|
cbh = CodeBaseHandler(codebase_name, code_path, crawl_type='dir', use_nh=use_nh, local_graph_path=CB_ROOT_PATH,
|
||||||
|
llm_config=llm_config, embed_config=embed_config)
|
||||||
|
cbh.delete_codebase(codebase_name=codebase_name)
|
||||||
|
|
||||||
|
|
||||||
|
# load codebase
|
||||||
|
codebase_name = 'client_local'
|
||||||
|
code_path = "D://chromeDownloads/devopschat-bot/client_v2/client"
|
||||||
|
use_nh = True
|
||||||
|
do_interpret = True
|
||||||
|
cbh = CodeBaseHandler(codebase_name, code_path, crawl_type='dir', use_nh=use_nh, local_graph_path=CB_ROOT_PATH,
|
||||||
|
llm_config=llm_config, embed_config=embed_config)
|
||||||
|
cbh.import_code(do_interpret=do_interpret)
|
||||||
|
|
||||||
|
# 根据前面的load过程进行初始化
|
||||||
|
cbh = CodeBaseHandler(codebase_name, code_path, crawl_type='dir', use_nh=use_nh, local_graph_path=CB_ROOT_PATH,
|
||||||
|
llm_config=llm_config, embed_config=embed_config)
|
||||||
|
phase = BasePhase(
|
||||||
|
phase_name, sandbox_server=SANDBOX_SERVER, jupyter_work_path=JUPYTER_WORK_PATH,
|
||||||
|
embed_config=embed_config, llm_config=llm_config, kb_root_path=KB_ROOT_PATH,
|
||||||
|
)
|
||||||
|
|
||||||
|
for vertex_type in ["class", "method"]:
|
||||||
|
vertexes = cbh.search_vertices(vertex_type=vertex_type)
|
||||||
|
logger.info(f"vertexes={vertexes}")
|
||||||
|
|
||||||
|
# round-1
|
||||||
|
docs = []
|
||||||
|
for vertex in vertexes:
|
||||||
|
vertex = vertex.split("-")[0] # -为method的参数
|
||||||
|
query_content = f"为{vertex_type}节点 {vertex}生成文档"
|
||||||
|
query = Message(
|
||||||
|
role_name="human", role_type="user",
|
||||||
|
role_content=query_content, input_query=query_content, origin_query=query_content,
|
||||||
|
code_engine_name="client_local", score_threshold=1.0, top_k=3, cb_search_type="tag", use_nh=use_nh,
|
||||||
|
local_graph_path=CB_ROOT_PATH,
|
||||||
|
)
|
||||||
|
output_message, output_memory = phase.step(query, reinit_memory=True)
|
||||||
|
# print(output_memory.to_str_messages(return_all=True, content_key="parsed_output_list"))
|
||||||
|
docs.append(output_memory.get_spec_parserd_output())
|
||||||
|
|
||||||
|
os.makedirs(f"{CB_ROOT_PATH}/docs", exist_ok=True)
|
||||||
|
with open(f"{CB_ROOT_PATH}/docs/raw_{vertex_type}.json", "w") as f:
|
||||||
|
json.dump(docs, f)
|
||||||
|
|
||||||
|
|
||||||
|
# 下面把生成的文档信息转换成markdown文本
|
||||||
|
from coagent.utils.code2doc_util import *
|
||||||
|
import json
|
||||||
|
with open(f"{CB_ROOT_PATH}/docs/raw_method.json", "r") as f:
|
||||||
|
method_raw_data = json.load(f)
|
||||||
|
|
||||||
|
with open(f"{CB_ROOT_PATH}/docs/raw_class.json", "r") as f:
|
||||||
|
class_raw_data = json.load(f)
|
||||||
|
|
||||||
|
|
||||||
|
method_data = method_info_decode(method_raw_data)
|
||||||
|
class_data = class_info_decode(class_raw_data)
|
||||||
|
method_mds = encode2md(method_data, method_text_md)
|
||||||
|
class_mds = encode2md(class_data, class_text_md)
|
||||||
|
|
||||||
|
|
||||||
|
docs_dict = {}
|
||||||
|
for k,v in class_mds.items():
|
||||||
|
method_textmds = method_mds.get(k, [])
|
||||||
|
for vv in v:
|
||||||
|
# 理论上只有一个
|
||||||
|
text_md = vv
|
||||||
|
|
||||||
|
for method_textmd in method_textmds:
|
||||||
|
text_md += "\n<br>" + method_textmd
|
||||||
|
|
||||||
|
docs_dict.setdefault(k, []).append(text_md)
|
||||||
|
|
||||||
|
with open(f"{CB_ROOT_PATH}//docs/{k}.md", "w") as f:
|
||||||
|
f.write(text_md)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
####################################
|
||||||
|
######## 下面是完整的复现过程 ########
|
||||||
|
####################################
|
||||||
|
|
||||||
|
# import os, sys, requests
|
||||||
|
# from loguru import logger
|
||||||
|
# src_dir = os.path.join(
|
||||||
|
# os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||||
|
# )
|
||||||
|
# sys.path.append(src_dir)
|
||||||
|
|
||||||
|
# from configs.model_config import KB_ROOT_PATH, JUPYTER_WORK_PATH, CB_ROOT_PATH
|
||||||
|
# from configs.server_config import SANDBOX_SERVER
|
||||||
|
# from coagent.tools import toLangchainTools, TOOL_DICT, TOOL_SETS
|
||||||
|
# from coagent.llm_models.llm_config import EmbedConfig, LLMConfig
|
||||||
|
|
||||||
|
# from coagent.connector.phase import BasePhase
|
||||||
|
# from coagent.connector.agents import BaseAgent, SelectorAgent
|
||||||
|
# from coagent.connector.chains import BaseChain
|
||||||
|
# from coagent.connector.schema import (
|
||||||
|
# Message, Memory, load_role_configs, load_phase_configs, load_chain_configs, ActionStatus
|
||||||
|
# )
|
||||||
|
# from coagent.connector.memory_manager import BaseMemoryManager
|
||||||
|
# from coagent.connector.configs import AGETN_CONFIGS, CHAIN_CONFIGS, PHASE_CONFIGS, BASE_PROMPT_CONFIGS
|
||||||
|
# from coagent.connector.prompt_manager.prompt_manager import PromptManager
|
||||||
|
# from coagent.codechat.codebase_handler.codebase_handler import CodeBaseHandler
|
||||||
|
|
||||||
|
# import importlib
|
||||||
|
# from loguru import logger
|
||||||
|
|
||||||
|
|
||||||
|
# from coagent.tools import CodeRetrievalSingle, RelatedVerticesRetrival, Vertex2Code
|
||||||
|
|
||||||
|
|
||||||
|
# # update new agent configs
|
||||||
|
# codeGenDocGroup_PROMPT = """#### Agent Profile
|
||||||
|
|
||||||
|
# Your goal is to response according the Context Data's information with the role that will best facilitate a solution, taking into account all relevant context (Context) provided.
|
||||||
|
|
||||||
|
# When you need to select the appropriate role for handling a user's query, carefully read the provided role names, role descriptions and tool list.
|
||||||
|
|
||||||
|
# ATTENTION: response carefully referenced "Response Output Format" in format.
|
||||||
|
|
||||||
|
# #### Input Format
|
||||||
|
|
||||||
|
# #### Response Output Format
|
||||||
|
|
||||||
|
# **Code Path:** Extract the paths for the class/method/function that need to be addressed from the context
|
||||||
|
|
||||||
|
# **Role:** Select the role from agent names
|
||||||
|
# """
|
||||||
|
|
||||||
|
# classGenDoc_PROMPT = """#### Agent Profile
|
||||||
|
# As an advanced code documentation generator, you are proficient in translating class definitions into comprehensive documentation with a focus on instantiation parameters.
|
||||||
|
# Your specific task is to parse the given code snippet of a class, extract information regarding its instantiation parameters.
|
||||||
|
|
||||||
|
# ATTENTION: response carefully in "Response Output Format".
|
||||||
|
|
||||||
|
# #### Input Format
|
||||||
|
|
||||||
|
# **Code Snippet:** Provide the full class definition, including the constructor and any parameters it may require for instantiation.
|
||||||
|
|
||||||
|
# #### Response Output Format
|
||||||
|
# **Class Base:** Specify the base class or interface from which the current class extends, if any.
|
||||||
|
|
||||||
|
# **Class Description:** Offer a brief description of the class's purpose and functionality.
|
||||||
|
|
||||||
|
# **Init Parameters:** List each parameter from construct. For each parameter, provide:
|
||||||
|
# - `param`: The parameter name
|
||||||
|
# - `param_description`: A concise explanation of the parameter's purpose.
|
||||||
|
# - `param_type`: The data type of the parameter, if explicitly defined.
|
||||||
|
|
||||||
|
# ```json
|
||||||
|
# [
|
||||||
|
# {
|
||||||
|
# "param": "parameter_name",
|
||||||
|
# "param_description": "A brief description of what this parameter is used for.",
|
||||||
|
# "param_type": "The data type of the parameter"
|
||||||
|
# },
|
||||||
|
# ...
|
||||||
|
# ]
|
||||||
|
# ```
|
||||||
|
|
||||||
|
|
||||||
|
# If no parameter for construct, return
|
||||||
|
# ```json
|
||||||
|
# []
|
||||||
|
# ```
|
||||||
|
# """
|
||||||
|
|
||||||
|
# funcGenDoc_PROMPT = """#### Agent Profile
|
||||||
|
# You are a high-level code documentation assistant, skilled at extracting information from function/method code into detailed and well-structured documentation.
|
||||||
|
|
||||||
|
# ATTENTION: response carefully in "Response Output Format".
|
||||||
|
|
||||||
|
|
||||||
|
# #### Input Format
|
||||||
|
# **Code Path:** Provide the code path of the function or method you wish to document.
|
||||||
|
# This name will be used to identify and extract the relevant details from the code snippet provided.
|
||||||
|
|
||||||
|
# **Code Snippet:** A segment of code that contains the function or method to be documented.
|
||||||
|
|
||||||
|
# #### Response Output Format
|
||||||
|
|
||||||
|
# **Class Description:** Offer a brief description of the method(function)'s purpose and functionality.
|
||||||
|
|
||||||
|
# **Parameters:** Extract parameter for the specific function/method Code from Code Snippet. For parameter, provide:
|
||||||
|
# - `param`: The parameter name
|
||||||
|
# - `param_description`: A concise explanation of the parameter's purpose.
|
||||||
|
# - `param_type`: The data type of the parameter, if explicitly defined.
|
||||||
|
# ```json
|
||||||
|
# [
|
||||||
|
# {
|
||||||
|
# "param": "parameter_name",
|
||||||
|
# "param_description": "A brief description of what this parameter is used for.",
|
||||||
|
# "param_type": "The data type of the parameter"
|
||||||
|
# },
|
||||||
|
# ...
|
||||||
|
# ]
|
||||||
|
# ```
|
||||||
|
|
||||||
|
# If no parameter for function/method, return
|
||||||
|
# ```json
|
||||||
|
# []
|
||||||
|
# ```
|
||||||
|
|
||||||
|
# **Return Value Description:** Describe what the function/method returns upon completion.
|
||||||
|
|
||||||
|
# **Return Type:** Indicate the type of data the function/method returns (e.g., string, integer, object, void).
|
||||||
|
# """
|
||||||
|
|
||||||
|
# CODE_GENERATE_GROUP_PROMPT_CONFIGS = [
|
||||||
|
# {"field_name": 'agent_profile', "function_name": 'handle_agent_profile', "is_context": False},
|
||||||
|
# {"field_name": 'agent_infomation', "function_name": 'handle_agent_data', "is_context": False, "omit_if_empty": False},
|
||||||
|
# # {"field_name": 'tool_information',"function_name": 'handle_tool_data', "is_context": False},
|
||||||
|
# {"field_name": 'context_placeholder', "function_name": '', "is_context": True},
|
||||||
|
# # {"field_name": 'reference_documents', "function_name": 'handle_doc_info'},
|
||||||
|
# {"field_name": 'session_records', "function_name": 'handle_session_records'},
|
||||||
|
# {"field_name": 'Specific Objective', "function_name": 'handle_specific_objective'},
|
||||||
|
# {"field_name": 'Code Snippet', "function_name": 'handle_code_snippet'},
|
||||||
|
# {"field_name": 'output_format', "function_name": 'handle_output_format', 'title': 'Response Output Format', "is_context": False},
|
||||||
|
# {"field_name": 'begin!!!', "function_name": 'handle_response', "is_context": False, "omit_if_empty": False}
|
||||||
|
# ]
|
||||||
|
|
||||||
|
# CODE_GENERATE_DOC_PROMPT_CONFIGS = [
|
||||||
|
# {"field_name": 'agent_profile', "function_name": 'handle_agent_profile', "is_context": False},
|
||||||
|
# # {"field_name": 'tool_information',"function_name": 'handle_tool_data', "is_context": False},
|
||||||
|
# {"field_name": 'context_placeholder', "function_name": '', "is_context": True},
|
||||||
|
# # {"field_name": 'reference_documents', "function_name": 'handle_doc_info'},
|
||||||
|
# {"field_name": 'session_records', "function_name": 'handle_session_records'},
|
||||||
|
# {"field_name": 'Specific Objective', "function_name": 'handle_specific_objective'},
|
||||||
|
# {"field_name": 'Code Snippet', "function_name": 'handle_code_snippet'},
|
||||||
|
# {"field_name": 'output_format', "function_name": 'handle_output_format', 'title': 'Response Output Format', "is_context": False},
|
||||||
|
# {"field_name": 'begin!!!', "function_name": 'handle_response', "is_context": False, "omit_if_empty": False}
|
||||||
|
# ]
|
||||||
|
|
||||||
|
|
||||||
|
# class CodeGenDocPM(PromptManager):
|
||||||
|
# def handle_code_snippet(self, **kwargs) -> str:
|
||||||
|
# if 'previous_agent_message' not in kwargs:
|
||||||
|
# return ""
|
||||||
|
# previous_agent_message: Message = kwargs['previous_agent_message']
|
||||||
|
# code_snippet = previous_agent_message.customed_kargs.get("Code Snippet", "")
|
||||||
|
# current_vertex = previous_agent_message.customed_kargs.get("Current_Vertex", "")
|
||||||
|
# instruction = "A segment of code that contains the function or method to be documented.\n"
|
||||||
|
# return instruction + "\n" + f"name: {current_vertex}\n{code_snippet}"
|
||||||
|
|
||||||
|
# def handle_specific_objective(self, **kwargs) -> str:
|
||||||
|
# if 'previous_agent_message' not in kwargs:
|
||||||
|
# return ""
|
||||||
|
# previous_agent_message: Message = kwargs['previous_agent_message']
|
||||||
|
# specific_objective = previous_agent_message.parsed_output.get("Code Path")
|
||||||
|
|
||||||
|
# instruction = "Provide the code path of the function or method you wish to document.\n"
|
||||||
|
# s = instruction + f"\n{specific_objective}"
|
||||||
|
# return s
|
||||||
|
|
||||||
|
|
||||||
|
# from coagent.tools import CodeRetrievalSingle
|
||||||
|
|
||||||
|
# # 定义一个新的agent类
|
||||||
|
# class CodeGenDocer(BaseAgent):
|
||||||
|
|
||||||
|
# def start_action_step(self, message: Message) -> Message:
|
||||||
|
# '''do action before agent predict '''
|
||||||
|
# # 根据问题获取代码片段和节点信息
|
||||||
|
# action_json = CodeRetrievalSingle.run(message.code_engine_name, message.origin_query,
|
||||||
|
# llm_config=self.llm_config, embed_config=self.embed_config, local_graph_path=message.local_graph_path, use_nh=message.use_nh,search_type="tag")
|
||||||
|
# current_vertex = action_json['vertex']
|
||||||
|
# message.customed_kargs["Code Snippet"] = action_json["code"]
|
||||||
|
# message.customed_kargs['Current_Vertex'] = current_vertex
|
||||||
|
# return message
|
||||||
|
|
||||||
|
# # add agent or prompt_manager class
|
||||||
|
# agent_module = importlib.import_module("coagent.connector.agents")
|
||||||
|
# prompt_manager_module = importlib.import_module("coagent.connector.prompt_manager")
|
||||||
|
|
||||||
|
# setattr(agent_module, 'CodeGenDocer', CodeGenDocer)
|
||||||
|
# setattr(prompt_manager_module, 'CodeGenDocPM', CodeGenDocPM)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# AGETN_CONFIGS.update({
|
||||||
|
# "classGenDoc": {
|
||||||
|
# "role": {
|
||||||
|
# "role_prompt": classGenDoc_PROMPT,
|
||||||
|
# "role_type": "assistant",
|
||||||
|
# "role_name": "classGenDoc",
|
||||||
|
# "role_desc": "",
|
||||||
|
# "agent_type": "CodeGenDocer"
|
||||||
|
# },
|
||||||
|
# "prompt_config": CODE_GENERATE_DOC_PROMPT_CONFIGS,
|
||||||
|
# "prompt_manager_type": "CodeGenDocPM",
|
||||||
|
# "chat_turn": 1,
|
||||||
|
# "focus_agents": [],
|
||||||
|
# "focus_message_keys": [],
|
||||||
|
# },
|
||||||
|
# "funcGenDoc": {
|
||||||
|
# "role": {
|
||||||
|
# "role_prompt": funcGenDoc_PROMPT,
|
||||||
|
# "role_type": "assistant",
|
||||||
|
# "role_name": "funcGenDoc",
|
||||||
|
# "role_desc": "",
|
||||||
|
# "agent_type": "CodeGenDocer"
|
||||||
|
# },
|
||||||
|
# "prompt_config": CODE_GENERATE_DOC_PROMPT_CONFIGS,
|
||||||
|
# "prompt_manager_type": "CodeGenDocPM",
|
||||||
|
# "chat_turn": 1,
|
||||||
|
# "focus_agents": [],
|
||||||
|
# "focus_message_keys": [],
|
||||||
|
# },
|
||||||
|
# "codeGenDocsGrouper": {
|
||||||
|
# "role": {
|
||||||
|
# "role_prompt": codeGenDocGroup_PROMPT,
|
||||||
|
# "role_type": "assistant",
|
||||||
|
# "role_name": "codeGenDocsGrouper",
|
||||||
|
# "role_desc": "",
|
||||||
|
# "agent_type": "SelectorAgent"
|
||||||
|
# },
|
||||||
|
# "prompt_config": CODE_GENERATE_GROUP_PROMPT_CONFIGS,
|
||||||
|
# "group_agents": ["classGenDoc", "funcGenDoc"],
|
||||||
|
# "chat_turn": 1,
|
||||||
|
# },
|
||||||
|
# })
|
||||||
|
# # update new chain configs
|
||||||
|
# CHAIN_CONFIGS.update({
|
||||||
|
# "codeGenDocsGroupChain": {
|
||||||
|
# "chain_name": "codeGenDocsGroupChain",
|
||||||
|
# "chain_type": "BaseChain",
|
||||||
|
# "agents": ["codeGenDocsGrouper"],
|
||||||
|
# "chat_turn": 1,
|
||||||
|
# "do_checker": False,
|
||||||
|
# "chain_prompt": ""
|
||||||
|
# }
|
||||||
|
# })
|
||||||
|
|
||||||
|
# # update phase configs
|
||||||
|
# PHASE_CONFIGS.update({
|
||||||
|
# "codeGenDocsGroup": {
|
||||||
|
# "phase_name": "codeGenDocsGroup",
|
||||||
|
# "phase_type": "BasePhase",
|
||||||
|
# "chains": ["codeGenDocsGroupChain"],
|
||||||
|
# "do_summary": False,
|
||||||
|
# "do_search": False,
|
||||||
|
# "do_doc_retrieval": False,
|
||||||
|
# "do_code_retrieval": False,
|
||||||
|
# "do_tool_retrieval": False,
|
||||||
|
# },
|
||||||
|
# })
|
||||||
|
|
||||||
|
|
||||||
|
# role_configs = load_role_configs(AGETN_CONFIGS)
|
||||||
|
# chain_configs = load_chain_configs(CHAIN_CONFIGS)
|
||||||
|
# phase_configs = load_phase_configs(PHASE_CONFIGS)
|
||||||
|
|
||||||
|
# # log-level,print prompt和llm predict
|
||||||
|
# os.environ["log_verbose"] = "1"
|
||||||
|
|
||||||
|
# phase_name = "codeGenDocsGroup"
|
||||||
|
# llm_config = LLMConfig(
|
||||||
|
# model_name="gpt-4", api_key=os.environ["OPENAI_API_KEY"],
|
||||||
|
# api_base_url=os.environ["API_BASE_URL"], temperature=0.3
|
||||||
|
# )
|
||||||
|
# embed_config = EmbedConfig(
|
||||||
|
# embed_engine="model", embed_model="text2vec-base-chinese",
|
||||||
|
# embed_model_path=os.path.join(src_dir, "embedding_models/text2vec-base-chinese")
|
||||||
|
# )
|
||||||
|
|
||||||
|
|
||||||
|
# # initialize codebase
|
||||||
|
# # delete codebase
|
||||||
|
# codebase_name = 'client_local'
|
||||||
|
# code_path = "D://chromeDownloads/devopschat-bot/client_v2/client"
|
||||||
|
# use_nh = False
|
||||||
|
# cbh = CodeBaseHandler(codebase_name, code_path, crawl_type='dir', use_nh=use_nh, local_graph_path=CB_ROOT_PATH,
|
||||||
|
# llm_config=llm_config, embed_config=embed_config)
|
||||||
|
# cbh.delete_codebase(codebase_name=codebase_name)
|
||||||
|
|
||||||
|
|
||||||
|
# # load codebase
|
||||||
|
# codebase_name = 'client_local'
|
||||||
|
# code_path = "D://chromeDownloads/devopschat-bot/client_v2/client"
|
||||||
|
# use_nh = False
|
||||||
|
# do_interpret = True
|
||||||
|
# cbh = CodeBaseHandler(codebase_name, code_path, crawl_type='dir', use_nh=use_nh, local_graph_path=CB_ROOT_PATH,
|
||||||
|
# llm_config=llm_config, embed_config=embed_config)
|
||||||
|
# cbh.import_code(do_interpret=do_interpret)
|
||||||
|
|
||||||
|
|
||||||
|
# phase = BasePhase(
|
||||||
|
# phase_name, sandbox_server=SANDBOX_SERVER, jupyter_work_path=JUPYTER_WORK_PATH,
|
||||||
|
# embed_config=embed_config, llm_config=llm_config, kb_root_path=KB_ROOT_PATH,
|
||||||
|
# )
|
||||||
|
|
||||||
|
# for vertex_type in ["class", "method"]:
|
||||||
|
# vertexes = cbh.search_vertices(vertex_type=vertex_type)
|
||||||
|
# logger.info(f"vertexes={vertexes}")
|
||||||
|
|
||||||
|
# # round-1
|
||||||
|
# docs = []
|
||||||
|
# for vertex in vertexes:
|
||||||
|
# vertex = vertex.split("-")[0] # -为method的参数
|
||||||
|
# query_content = f"为{vertex_type}节点 {vertex}生成文档"
|
||||||
|
# query = Message(
|
||||||
|
# role_name="human", role_type="user",
|
||||||
|
# role_content=query_content, input_query=query_content, origin_query=query_content,
|
||||||
|
# code_engine_name="client_local", score_threshold=1.0, top_k=3, cb_search_type="tag", use_nh=use_nh,
|
||||||
|
# local_graph_path=CB_ROOT_PATH,
|
||||||
|
# )
|
||||||
|
# output_message, output_memory = phase.step(query, reinit_memory=True)
|
||||||
|
# # print(output_memory.to_str_messages(return_all=True, content_key="parsed_output_list"))
|
||||||
|
# docs.append(output_memory.get_spec_parserd_output())
|
||||||
|
|
||||||
|
# import json
|
||||||
|
# os.makedirs("/home/user/code_base/docs", exist_ok=True)
|
||||||
|
# with open(f"/home/user/code_base/docs/raw_{vertex_type}.json", "w") as f:
|
||||||
|
# json.dump(docs, f)
|
||||||
|
|
||||||
|
|
||||||
|
# # 下面把生成的文档信息转换成markdown文本
|
||||||
|
# from coagent.utils.code2doc_util import *
|
||||||
|
|
||||||
|
# import json
|
||||||
|
# with open(f"/home/user/code_base/docs/raw_method.json", "r") as f:
|
||||||
|
# method_raw_data = json.load(f)
|
||||||
|
|
||||||
|
# with open(f"/home/user/code_base/docs/raw_class.json", "r") as f:
|
||||||
|
# class_raw_data = json.load(f)
|
||||||
|
|
||||||
|
|
||||||
|
# method_data = method_info_decode(method_raw_data)
|
||||||
|
# class_data = class_info_decode(class_raw_data)
|
||||||
|
# method_mds = encode2md(method_data, method_text_md)
|
||||||
|
# class_mds = encode2md(class_data, class_text_md)
|
||||||
|
|
||||||
|
# docs_dict = {}
|
||||||
|
# for k,v in class_mds.items():
|
||||||
|
# method_textmds = method_mds.get(k, [])
|
||||||
|
# for vv in v:
|
||||||
|
# # 理论上只有一个
|
||||||
|
# text_md = vv
|
||||||
|
|
||||||
|
# for method_textmd in method_textmds:
|
||||||
|
# text_md += "\n<br>" + method_textmd
|
||||||
|
|
||||||
|
# docs_dict.setdefault(k, []).append(text_md)
|
||||||
|
|
||||||
|
# with open(f"/home/user/code_base/docs/{k}.md", "w") as f:
|
||||||
|
# f.write(text_md)
|
|
@ -0,0 +1,444 @@
|
||||||
|
import os, sys, json
|
||||||
|
from loguru import logger
|
||||||
|
src_dir = os.path.join(
|
||||||
|
os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||||
|
)
|
||||||
|
sys.path.append(src_dir)
|
||||||
|
|
||||||
|
from configs.model_config import KB_ROOT_PATH, JUPYTER_WORK_PATH, CB_ROOT_PATH
|
||||||
|
from configs.server_config import SANDBOX_SERVER
|
||||||
|
from coagent.llm_models.llm_config import EmbedConfig, LLMConfig
|
||||||
|
|
||||||
|
from coagent.connector.phase import BasePhase
|
||||||
|
from coagent.connector.agents import BaseAgent
|
||||||
|
from coagent.connector.schema import Message
|
||||||
|
from coagent.codechat.codebase_handler.codebase_handler import CodeBaseHandler
|
||||||
|
|
||||||
|
import importlib
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
|
||||||
|
from coagent.tools import CodeRetrievalSingle, RelatedVerticesRetrival, Vertex2Code
|
||||||
|
|
||||||
|
# 定义一个新的agent类
|
||||||
|
class CodeRetrieval(BaseAgent):
|
||||||
|
|
||||||
|
def start_action_step(self, message: Message) -> Message:
|
||||||
|
'''do action before agent predict '''
|
||||||
|
# 根据问题获取代码片段和节点信息
|
||||||
|
action_json = CodeRetrievalSingle.run(message.code_engine_name, message.origin_query, llm_config=self.llm_config, embed_config=self.embed_config, search_type="tag",
|
||||||
|
local_graph_path=message.local_graph_path, use_nh=message.use_nh)
|
||||||
|
current_vertex = action_json['vertex']
|
||||||
|
message.customed_kargs["Code Snippet"] = action_json["code"]
|
||||||
|
message.customed_kargs['Current_Vertex'] = current_vertex
|
||||||
|
|
||||||
|
# 获取邻近节点
|
||||||
|
action_json = RelatedVerticesRetrival.run(message.code_engine_name, message.customed_kargs['Current_Vertex'])
|
||||||
|
# 获取邻近节点所有代码
|
||||||
|
relative_vertex = []
|
||||||
|
retrieval_Codes = []
|
||||||
|
for vertex in action_json["vertices"]:
|
||||||
|
# 由于代码是文件级别,所以相同文件代码不再获取
|
||||||
|
# logger.debug(f"{current_vertex}, {vertex}")
|
||||||
|
current_vertex_name = current_vertex.replace("#", "").replace(".java", "" ) if current_vertex.endswith(".java") else current_vertex
|
||||||
|
if current_vertex_name.split("#")[0] == vertex.split("#")[0]: continue
|
||||||
|
|
||||||
|
action_json = Vertex2Code.run(message.code_engine_name, vertex)
|
||||||
|
if action_json["code"]:
|
||||||
|
retrieval_Codes.append(action_json["code"])
|
||||||
|
relative_vertex.append(vertex)
|
||||||
|
#
|
||||||
|
message.customed_kargs["Retrieval_Codes"] = retrieval_Codes
|
||||||
|
message.customed_kargs["Relative_vertex"] = relative_vertex
|
||||||
|
return message
|
||||||
|
|
||||||
|
|
||||||
|
# add agent or prompt_manager class
|
||||||
|
agent_module = importlib.import_module("coagent.connector.agents")
|
||||||
|
setattr(agent_module, 'CodeRetrieval', CodeRetrieval)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
llm_config = LLMConfig(
|
||||||
|
model_name="gpt-4", api_key=os.environ["OPENAI_API_KEY"],
|
||||||
|
api_base_url=os.environ["API_BASE_URL"], temperature=0.3
|
||||||
|
)
|
||||||
|
embed_config = EmbedConfig(
|
||||||
|
embed_engine="model", embed_model="text2vec-base-chinese",
|
||||||
|
embed_model_path=os.path.join(src_dir, "embedding_models/text2vec-base-chinese")
|
||||||
|
)
|
||||||
|
|
||||||
|
## initialize codebase
|
||||||
|
# delete codebase
|
||||||
|
codebase_name = 'client_local'
|
||||||
|
code_path = "D://chromeDownloads/devopschat-bot/client_v2/client"
|
||||||
|
use_nh = False
|
||||||
|
cbh = CodeBaseHandler(codebase_name, code_path, crawl_type='dir', use_nh=use_nh, local_graph_path=CB_ROOT_PATH,
|
||||||
|
llm_config=llm_config, embed_config=embed_config)
|
||||||
|
cbh.delete_codebase(codebase_name=codebase_name)
|
||||||
|
|
||||||
|
|
||||||
|
# load codebase
|
||||||
|
codebase_name = 'client_local'
|
||||||
|
code_path = "D://chromeDownloads/devopschat-bot/client_v2/client"
|
||||||
|
use_nh = True
|
||||||
|
do_interpret = True
|
||||||
|
cbh = CodeBaseHandler(codebase_name, code_path, crawl_type='dir', use_nh=use_nh, local_graph_path=CB_ROOT_PATH,
|
||||||
|
llm_config=llm_config, embed_config=embed_config)
|
||||||
|
cbh.import_code(do_interpret=do_interpret)
|
||||||
|
|
||||||
|
|
||||||
|
cbh = CodeBaseHandler(codebase_name, code_path, crawl_type='dir', use_nh=use_nh, local_graph_path=CB_ROOT_PATH,
|
||||||
|
llm_config=llm_config, embed_config=embed_config)
|
||||||
|
vertexes = cbh.search_vertices(vertex_type="class")
|
||||||
|
|
||||||
|
|
||||||
|
# log-level,print prompt和llm predict
|
||||||
|
os.environ["log_verbose"] = "0"
|
||||||
|
|
||||||
|
phase_name = "code2Tests"
|
||||||
|
phase = BasePhase(
|
||||||
|
phase_name, sandbox_server=SANDBOX_SERVER, jupyter_work_path=JUPYTER_WORK_PATH,
|
||||||
|
embed_config=embed_config, llm_config=llm_config, kb_root_path=KB_ROOT_PATH,
|
||||||
|
)
|
||||||
|
|
||||||
|
# round-1
|
||||||
|
logger.debug(vertexes)
|
||||||
|
test_cases = []
|
||||||
|
for vertex in vertexes:
|
||||||
|
query_content = f"为{vertex}生成可执行的测例 "
|
||||||
|
query = Message(
|
||||||
|
role_name="human", role_type="user",
|
||||||
|
role_content=query_content, input_query=query_content, origin_query=query_content,
|
||||||
|
code_engine_name=codebase_name, score_threshold=1.0, top_k=3, cb_search_type="tag",
|
||||||
|
use_nh=use_nh, local_graph_path=CB_ROOT_PATH,
|
||||||
|
)
|
||||||
|
output_message, output_memory = phase.step(query, reinit_memory=True)
|
||||||
|
# print(output_memory.to_str_messages(return_all=True, content_key="parsed_output_list"))
|
||||||
|
print(output_memory.get_spec_parserd_output())
|
||||||
|
values = output_memory.get_spec_parserd_output()
|
||||||
|
test_code = {k:v for i in values for k,v in i.items() if k in ["SaveFileName", "Test Code"]}
|
||||||
|
test_cases.append(test_code)
|
||||||
|
|
||||||
|
os.makedirs(f"{CB_ROOT_PATH}/tests", exist_ok=True)
|
||||||
|
|
||||||
|
with open(f"{CB_ROOT_PATH}/tests/{test_code['SaveFileName']}", "w") as f:
|
||||||
|
f.write(test_code["Test Code"])
|
||||||
|
break
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# import os, sys, json
|
||||||
|
# from loguru import logger
|
||||||
|
# src_dir = os.path.join(
|
||||||
|
# os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||||
|
# )
|
||||||
|
# sys.path.append(src_dir)
|
||||||
|
|
||||||
|
# from configs.model_config import KB_ROOT_PATH, JUPYTER_WORK_PATH, CB_ROOT_PATH
|
||||||
|
# from configs.server_config import SANDBOX_SERVER
|
||||||
|
# from coagent.tools import toLangchainTools, TOOL_DICT, TOOL_SETS
|
||||||
|
# from coagent.llm_models.llm_config import EmbedConfig, LLMConfig
|
||||||
|
|
||||||
|
# from coagent.connector.phase import BasePhase
|
||||||
|
# from coagent.connector.agents import BaseAgent
|
||||||
|
# from coagent.connector.chains import BaseChain
|
||||||
|
# from coagent.connector.schema import (
|
||||||
|
# Message, Memory, load_role_configs, load_phase_configs, load_chain_configs, ActionStatus
|
||||||
|
# )
|
||||||
|
# from coagent.connector.memory_manager import BaseMemoryManager
|
||||||
|
# from coagent.connector.configs import AGETN_CONFIGS, CHAIN_CONFIGS, PHASE_CONFIGS, BASE_PROMPT_CONFIGS
|
||||||
|
# from coagent.connector.prompt_manager.prompt_manager import PromptManager
|
||||||
|
# from coagent.codechat.codebase_handler.codebase_handler import CodeBaseHandler
|
||||||
|
|
||||||
|
# import importlib
|
||||||
|
# from loguru import logger
|
||||||
|
|
||||||
|
# # 下面给定了一份代码片段,以及来自于它的依赖类、依赖方法相关的代码片段,你需要判断是否为这段指定代码片段生成测例。
|
||||||
|
# # 理论上所有代码都需要写测例,但是受限于人的精力不可能覆盖所有代码
|
||||||
|
# # 考虑以下因素进行裁剪:
|
||||||
|
# # 功能性: 如果它实现的是一个具体的功能或逻辑,则通常需要编写测试用例以验证其正确性。
|
||||||
|
# # 复杂性: 如果代码较为,尤其是包含多个条件判断、循环、异常处理等的代码,更可能隐藏bug,因此应该编写测试用例。如果代码涉及复杂的算法或者逻辑,那么编写测试用例可以帮助确保逻辑的正确性,并在未来的重构中防止引入错误。
|
||||||
|
# # 关键性: 如果它是关键路径的一部分或影响到核心功能,那么它就需要被测试。对于核心业务逻辑或者系统的关键组件,应当编写全面的测试用例来确保功能的正确性和稳定性。
|
||||||
|
# # 依赖性: 如果代码有外部依赖,可能需要编写集成测试或模拟这些依赖进行单元测试。
|
||||||
|
# # 用户输入: 如果代码处理用户输入,尤其是来自外部的、非受控的输入,那么创建测试用例来检查输入验证和处理是很重要的。
|
||||||
|
# # 频繁更改:对于经常需要更新或修改的代码,有相应的测试用例可以确保更改不会破坏现有功能。
|
||||||
|
|
||||||
|
|
||||||
|
# # 代码公开或重用:如果代码将被公开或用于其他项目,编写测试用例可以提高代码的可信度和易用性。
|
||||||
|
|
||||||
|
|
||||||
|
# # update new agent configs
|
||||||
|
# judgeGenerateTests_PROMPT = """#### Agent Profile
|
||||||
|
# When determining the necessity of writing test cases for a given code snippet,
|
||||||
|
# it's essential to evaluate its interactions with dependent classes and methods (retrieved code snippets),
|
||||||
|
# in addition to considering these critical factors:
|
||||||
|
# 1. Functionality: If it implements a concrete function or logic, test cases are typically necessary to verify its correctness.
|
||||||
|
# 2. Complexity: If the code is complex, especially if it contains multiple conditional statements, loops, exceptions handling, etc.,
|
||||||
|
# it's more likely to harbor bugs, and thus test cases should be written.
|
||||||
|
# If the code involves complex algorithms or logic, then writing test cases can help ensure the accuracy of the logic and prevent errors during future refactoring.
|
||||||
|
# 3. Criticality: If it's part of the critical path or affects core functionalities, then it needs to be tested.
|
||||||
|
# Comprehensive test cases should be written for core business logic or key components of the system to ensure the correctness and stability of the functionality.
|
||||||
|
# 4. Dependencies: If the code has external dependencies, integration testing may be necessary, or mocking these dependencies during unit testing might be required.
|
||||||
|
# 5. User Input: If the code handles user input, especially from unregulated external sources, creating test cases to check input validation and handling is important.
|
||||||
|
# 6. Frequent Changes: For code that requires regular updates or modifications, having the appropriate test cases ensures that changes do not break existing functionalities.
|
||||||
|
|
||||||
|
# #### Input Format
|
||||||
|
|
||||||
|
# **Code Snippet:** the initial Code or objective that the user wanted to achieve
|
||||||
|
|
||||||
|
# **Retrieval Code Snippets:** These are the associated code segments that the main Code Snippet depends on.
|
||||||
|
# Examine these snippets to understand how they interact with the main snippet and to determine how they might affect the overall functionality.
|
||||||
|
|
||||||
|
# #### Response Output Format
|
||||||
|
# **Action Status:** Set to 'finished' or 'continued'.
|
||||||
|
# If set to 'finished', the code snippet does not warrant the generation of a test case.
|
||||||
|
# If set to 'continued', the code snippet necessitates the creation of a test case.
|
||||||
|
|
||||||
|
# **REASON:** Justify the selection of 'finished' or 'continued', contemplating the decision through a step-by-step rationale.
|
||||||
|
# """
|
||||||
|
|
||||||
|
# generateTests_PROMPT = """#### Agent Profile
|
||||||
|
# As an agent specializing in software quality assurance,
|
||||||
|
# your mission is to craft comprehensive test cases that bolster the functionality, reliability, and robustness of a specified Code Snippet.
|
||||||
|
# This task is to be carried out with a keen understanding of the snippet's interactions with its dependent classes and methods—collectively referred to as Retrieval Code Snippets.
|
||||||
|
# Analyze the details given below to grasp the code's intended purpose, its inherent complexity, and the context within which it operates.
|
||||||
|
# Your constructed test cases must thoroughly examine the various factors influencing the code's quality and performance.
|
||||||
|
|
||||||
|
# ATTENTION: response carefully referenced "Response Output Format" in format.
|
||||||
|
|
||||||
|
# Each test case should include:
|
||||||
|
# 1. clear description of the test purpose.
|
||||||
|
# 2. The input values or conditions for the test.
|
||||||
|
# 3. The expected outcome or assertion for the test.
|
||||||
|
# 4. Appropriate tags (e.g., 'functional', 'integration', 'regression') that classify the type of test case.
|
||||||
|
# 5. these test code should have package and import
|
||||||
|
|
||||||
|
# #### Input Format
|
||||||
|
|
||||||
|
# **Code Snippet:** the initial Code or objective that the user wanted to achieve
|
||||||
|
|
||||||
|
# **Retrieval Code Snippets:** These are the interrelated pieces of code sourced from the codebase, which support or influence the primary Code Snippet.
|
||||||
|
|
||||||
|
# #### Response Output Format
|
||||||
|
# **SaveFileName:** construct a local file name based on Question and Context, such as
|
||||||
|
|
||||||
|
# ```java
|
||||||
|
# package/class.java
|
||||||
|
# ```
|
||||||
|
|
||||||
|
|
||||||
|
# **Test Code:** generate the test code for the current Code Snippet.
|
||||||
|
# ```java
|
||||||
|
# ...
|
||||||
|
# ```
|
||||||
|
|
||||||
|
# """
|
||||||
|
|
||||||
|
# CODE_GENERATE_TESTS_PROMPT_CONFIGS = [
|
||||||
|
# {"field_name": 'agent_profile', "function_name": 'handle_agent_profile', "is_context": False},
|
||||||
|
# # {"field_name": 'tool_information',"function_name": 'handle_tool_data', "is_context": False},
|
||||||
|
# {"field_name": 'context_placeholder', "function_name": '', "is_context": True},
|
||||||
|
# # {"field_name": 'reference_documents', "function_name": 'handle_doc_info'},
|
||||||
|
# {"field_name": 'session_records', "function_name": 'handle_session_records'},
|
||||||
|
# {"field_name": 'code_snippet', "function_name": 'handle_code_snippet'},
|
||||||
|
# {"field_name": 'retrieval_codes', "function_name": 'handle_retrieval_codes', "description": ""},
|
||||||
|
# {"field_name": 'output_format', "function_name": 'handle_output_format', 'title': 'Response Output Format', "is_context": False},
|
||||||
|
# {"field_name": 'begin!!!', "function_name": 'handle_response', "is_context": False, "omit_if_empty": False}
|
||||||
|
# ]
|
||||||
|
|
||||||
|
|
||||||
|
# class CodeRetrievalPM(PromptManager):
|
||||||
|
# def handle_code_snippet(self, **kwargs) -> str:
|
||||||
|
# if 'previous_agent_message' not in kwargs:
|
||||||
|
# return ""
|
||||||
|
# previous_agent_message: Message = kwargs['previous_agent_message']
|
||||||
|
# code_snippet = previous_agent_message.customed_kargs.get("Code Snippet", "")
|
||||||
|
# current_vertex = previous_agent_message.customed_kargs.get("Current_Vertex", "")
|
||||||
|
# instruction = "the initial Code or objective that the user wanted to achieve"
|
||||||
|
# return instruction + "\n" + f"name: {current_vertex}\n{code_snippet}"
|
||||||
|
|
||||||
|
# def handle_retrieval_codes(self, **kwargs) -> str:
|
||||||
|
# if 'previous_agent_message' not in kwargs:
|
||||||
|
# return ""
|
||||||
|
# previous_agent_message: Message = kwargs['previous_agent_message']
|
||||||
|
# Retrieval_Codes = previous_agent_message.customed_kargs["Retrieval_Codes"]
|
||||||
|
# Relative_vertex = previous_agent_message.customed_kargs["Relative_vertex"]
|
||||||
|
# instruction = "the initial Code or objective that the user wanted to achieve"
|
||||||
|
# s = instruction + "\n" + "\n".join([f"name: {vertext}\n{code}" for vertext, code in zip(Relative_vertex, Retrieval_Codes)])
|
||||||
|
# return s
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# AGETN_CONFIGS.update({
|
||||||
|
# "CodeJudger": {
|
||||||
|
# "role": {
|
||||||
|
# "role_prompt": judgeGenerateTests_PROMPT,
|
||||||
|
# "role_type": "assistant",
|
||||||
|
# "role_name": "CodeJudger",
|
||||||
|
# "role_desc": "",
|
||||||
|
# "agent_type": "CodeRetrieval"
|
||||||
|
# # "agent_type": "BaseAgent"
|
||||||
|
# },
|
||||||
|
# "prompt_config": CODE_GENERATE_TESTS_PROMPT_CONFIGS,
|
||||||
|
# "prompt_manager_type": "CodeRetrievalPM",
|
||||||
|
# "chat_turn": 1,
|
||||||
|
# "focus_agents": [],
|
||||||
|
# "focus_message_keys": [],
|
||||||
|
# },
|
||||||
|
# "generateTests": {
|
||||||
|
# "role": {
|
||||||
|
# "role_prompt": generateTests_PROMPT,
|
||||||
|
# "role_type": "assistant",
|
||||||
|
# "role_name": "generateTests",
|
||||||
|
# "role_desc": "",
|
||||||
|
# "agent_type": "CodeRetrieval"
|
||||||
|
# # "agent_type": "BaseAgent"
|
||||||
|
# },
|
||||||
|
# "prompt_config": CODE_GENERATE_TESTS_PROMPT_CONFIGS,
|
||||||
|
# "prompt_manager_type": "CodeRetrievalPM",
|
||||||
|
# "chat_turn": 1,
|
||||||
|
# "focus_agents": [],
|
||||||
|
# "focus_message_keys": [],
|
||||||
|
# },
|
||||||
|
# })
|
||||||
|
# # update new chain configs
|
||||||
|
# CHAIN_CONFIGS.update({
|
||||||
|
# "codeRetrievalChain": {
|
||||||
|
# "chain_name": "codeRetrievalChain",
|
||||||
|
# "chain_type": "BaseChain",
|
||||||
|
# "agents": ["CodeJudger", "generateTests"],
|
||||||
|
# "chat_turn": 1,
|
||||||
|
# "do_checker": False,
|
||||||
|
# "chain_prompt": ""
|
||||||
|
# }
|
||||||
|
# })
|
||||||
|
|
||||||
|
# # update phase configs
|
||||||
|
# PHASE_CONFIGS.update({
|
||||||
|
# "codeGenerateTests": {
|
||||||
|
# "phase_name": "codeGenerateTests",
|
||||||
|
# "phase_type": "BasePhase",
|
||||||
|
# "chains": ["codeRetrievalChain"],
|
||||||
|
# "do_summary": False,
|
||||||
|
# "do_search": False,
|
||||||
|
# "do_doc_retrieval": False,
|
||||||
|
# "do_code_retrieval": False,
|
||||||
|
# "do_tool_retrieval": False,
|
||||||
|
# },
|
||||||
|
# })
|
||||||
|
|
||||||
|
|
||||||
|
# role_configs = load_role_configs(AGETN_CONFIGS)
|
||||||
|
# chain_configs = load_chain_configs(CHAIN_CONFIGS)
|
||||||
|
# phase_configs = load_phase_configs(PHASE_CONFIGS)
|
||||||
|
|
||||||
|
|
||||||
|
# from coagent.tools import CodeRetrievalSingle, RelatedVerticesRetrival, Vertex2Code
|
||||||
|
|
||||||
|
# # 定义一个新的agent类
|
||||||
|
# class CodeRetrieval(BaseAgent):
|
||||||
|
|
||||||
|
# def start_action_step(self, message: Message) -> Message:
|
||||||
|
# '''do action before agent predict '''
|
||||||
|
# # 根据问题获取代码片段和节点信息
|
||||||
|
# action_json = CodeRetrievalSingle.run(message.code_engine_name, message.origin_query, llm_config=self.llm_config, embed_config=self.embed_config, search_type="tag",
|
||||||
|
# local_graph_path=message.local_graph_path, use_nh=message.use_nh)
|
||||||
|
# current_vertex = action_json['vertex']
|
||||||
|
# message.customed_kargs["Code Snippet"] = action_json["code"]
|
||||||
|
# message.customed_kargs['Current_Vertex'] = current_vertex
|
||||||
|
|
||||||
|
# # 获取邻近节点
|
||||||
|
# action_json = RelatedVerticesRetrival.run(message.code_engine_name, message.customed_kargs['Current_Vertex'])
|
||||||
|
# # 获取邻近节点所有代码
|
||||||
|
# relative_vertex = []
|
||||||
|
# retrieval_Codes = []
|
||||||
|
# for vertex in action_json["vertices"]:
|
||||||
|
# # 由于代码是文件级别,所以相同文件代码不再获取
|
||||||
|
# # logger.debug(f"{current_vertex}, {vertex}")
|
||||||
|
# current_vertex_name = current_vertex.replace("#", "").replace(".java", "" ) if current_vertex.endswith(".java") else current_vertex
|
||||||
|
# if current_vertex_name.split("#")[0] == vertex.split("#")[0]: continue
|
||||||
|
|
||||||
|
# action_json = Vertex2Code.run(message.code_engine_name, vertex)
|
||||||
|
# if action_json["code"]:
|
||||||
|
# retrieval_Codes.append(action_json["code"])
|
||||||
|
# relative_vertex.append(vertex)
|
||||||
|
# #
|
||||||
|
# message.customed_kargs["Retrieval_Codes"] = retrieval_Codes
|
||||||
|
# message.customed_kargs["Relative_vertex"] = relative_vertex
|
||||||
|
# return message
|
||||||
|
|
||||||
|
|
||||||
|
# # add agent or prompt_manager class
|
||||||
|
# agent_module = importlib.import_module("coagent.connector.agents")
|
||||||
|
# prompt_manager_module = importlib.import_module("coagent.connector.prompt_manager")
|
||||||
|
|
||||||
|
# setattr(agent_module, 'CodeRetrieval', CodeRetrieval)
|
||||||
|
# setattr(prompt_manager_module, 'CodeRetrievalPM', CodeRetrievalPM)
|
||||||
|
|
||||||
|
|
||||||
|
# # log-level,print prompt和llm predict
|
||||||
|
# os.environ["log_verbose"] = "0"
|
||||||
|
|
||||||
|
# phase_name = "codeGenerateTests"
|
||||||
|
# llm_config = LLMConfig(
|
||||||
|
# model_name="gpt-4", api_key=os.environ["OPENAI_API_KEY"],
|
||||||
|
# api_base_url=os.environ["API_BASE_URL"], temperature=0.3
|
||||||
|
# )
|
||||||
|
# embed_config = EmbedConfig(
|
||||||
|
# embed_engine="model", embed_model="text2vec-base-chinese",
|
||||||
|
# embed_model_path=os.path.join(src_dir, "embedding_models/text2vec-base-chinese")
|
||||||
|
# )
|
||||||
|
|
||||||
|
# ## initialize codebase
|
||||||
|
# # delete codebase
|
||||||
|
# codebase_name = 'client_local'
|
||||||
|
# code_path = "D://chromeDownloads/devopschat-bot/client_v2/client"
|
||||||
|
# use_nh = False
|
||||||
|
# cbh = CodeBaseHandler(codebase_name, code_path, crawl_type='dir', use_nh=use_nh, local_graph_path=CB_ROOT_PATH,
|
||||||
|
# llm_config=llm_config, embed_config=embed_config)
|
||||||
|
# cbh.delete_codebase(codebase_name=codebase_name)
|
||||||
|
|
||||||
|
|
||||||
|
# # load codebase
|
||||||
|
# codebase_name = 'client_local'
|
||||||
|
# code_path = "D://chromeDownloads/devopschat-bot/client_v2/client"
|
||||||
|
# use_nh = True
|
||||||
|
# do_interpret = True
|
||||||
|
# cbh = CodeBaseHandler(codebase_name, code_path, crawl_type='dir', use_nh=use_nh, local_graph_path=CB_ROOT_PATH,
|
||||||
|
# llm_config=llm_config, embed_config=embed_config)
|
||||||
|
# cbh.import_code(do_interpret=do_interpret)
|
||||||
|
|
||||||
|
|
||||||
|
# cbh = CodeBaseHandler(codebase_name, code_path, crawl_type='dir', use_nh=use_nh, local_graph_path=CB_ROOT_PATH,
|
||||||
|
# llm_config=llm_config, embed_config=embed_config)
|
||||||
|
# vertexes = cbh.search_vertices(vertex_type="class")
|
||||||
|
|
||||||
|
# phase = BasePhase(
|
||||||
|
# phase_name, sandbox_server=SANDBOX_SERVER, jupyter_work_path=JUPYTER_WORK_PATH,
|
||||||
|
# embed_config=embed_config, llm_config=llm_config, kb_root_path=KB_ROOT_PATH,
|
||||||
|
# )
|
||||||
|
|
||||||
|
# # round-1
|
||||||
|
# logger.debug(vertexes)
|
||||||
|
# test_cases = []
|
||||||
|
# for vertex in vertexes:
|
||||||
|
# query_content = f"为{vertex}生成可执行的测例 "
|
||||||
|
# query = Message(
|
||||||
|
# role_name="human", role_type="user",
|
||||||
|
# role_content=query_content, input_query=query_content, origin_query=query_content,
|
||||||
|
# code_engine_name=codebase_name, score_threshold=1.0, top_k=3, cb_search_type="tag",
|
||||||
|
# use_nh=use_nh, local_graph_path=CB_ROOT_PATH,
|
||||||
|
# )
|
||||||
|
# output_message, output_memory = phase.step(query, reinit_memory=True)
|
||||||
|
# # print(output_memory.to_str_messages(return_all=True, content_key="parsed_output_list"))
|
||||||
|
# print(output_memory.get_spec_parserd_output())
|
||||||
|
# values = output_memory.get_spec_parserd_output()
|
||||||
|
# test_code = {k:v for i in values for k,v in i.items() if k in ["SaveFileName", "Test Code"]}
|
||||||
|
# test_cases.append(test_code)
|
||||||
|
|
||||||
|
# os.makedirs(f"{CB_ROOT_PATH}/tests", exist_ok=True)
|
||||||
|
|
||||||
|
# with open(f"{CB_ROOT_PATH}/tests/{test_code['SaveFileName']}", "w") as f:
|
||||||
|
# f.write(test_code["Test Code"])
|
|
@ -17,7 +17,7 @@ os.environ["log_verbose"] = "2"
|
||||||
|
|
||||||
phase_name = "codeReactPhase"
|
phase_name = "codeReactPhase"
|
||||||
llm_config = LLMConfig(
|
llm_config = LLMConfig(
|
||||||
model_name="gpt-3.5-turbo", model_device="cpu",api_key=os.environ["OPENAI_API_KEY"],
|
model_name="gpt-3.5-turbo",api_key=os.environ["OPENAI_API_KEY"],
|
||||||
api_base_url=os.environ["API_BASE_URL"], temperature=0.3
|
api_base_url=os.environ["API_BASE_URL"], temperature=0.3
|
||||||
)
|
)
|
||||||
embed_config = EmbedConfig(
|
embed_config = EmbedConfig(
|
||||||
|
|
|
@ -18,8 +18,7 @@ from coagent.connector.schema import (
|
||||||
)
|
)
|
||||||
from coagent.connector.memory_manager import BaseMemoryManager
|
from coagent.connector.memory_manager import BaseMemoryManager
|
||||||
from coagent.connector.configs import AGETN_CONFIGS, CHAIN_CONFIGS, PHASE_CONFIGS, BASE_PROMPT_CONFIGS
|
from coagent.connector.configs import AGETN_CONFIGS, CHAIN_CONFIGS, PHASE_CONFIGS, BASE_PROMPT_CONFIGS
|
||||||
from coagent.connector.utils import parse_section
|
from coagent.connector.prompt_manager.prompt_manager import PromptManager
|
||||||
from coagent.connector.prompt_manager import PromptManager
|
|
||||||
import importlib
|
import importlib
|
||||||
|
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
@ -230,7 +229,7 @@ os.environ["log_verbose"] = "2"
|
||||||
|
|
||||||
phase_name = "codeRetrievalPhase"
|
phase_name = "codeRetrievalPhase"
|
||||||
llm_config = LLMConfig(
|
llm_config = LLMConfig(
|
||||||
model_name="gpt-3.5-turbo", model_device="cpu",api_key=os.environ["OPENAI_API_KEY"],
|
model_name="gpt-3.5-turbo", api_key=os.environ["OPENAI_API_KEY"],
|
||||||
api_base_url=os.environ["API_BASE_URL"], temperature=0.3
|
api_base_url=os.environ["API_BASE_URL"], temperature=0.3
|
||||||
)
|
)
|
||||||
embed_config = EmbedConfig(
|
embed_config = EmbedConfig(
|
||||||
|
@ -246,7 +245,7 @@ query_content = "UtilsTest 这个类中测试了哪些函数,测试的函数代
|
||||||
query = Message(
|
query = Message(
|
||||||
role_name="human", role_type="user",
|
role_name="human", role_type="user",
|
||||||
role_content=query_content, input_query=query_content, origin_query=query_content,
|
role_content=query_content, input_query=query_content, origin_query=query_content,
|
||||||
code_engine_name="client", score_threshold=1.0, top_k=3, cb_search_type="tag"
|
code_engine_name="client_1", score_threshold=1.0, top_k=3, cb_search_type="tag"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -24,7 +24,7 @@ os.environ["log_verbose"] = "2"
|
||||||
|
|
||||||
phase_name = "codeToolReactPhase"
|
phase_name = "codeToolReactPhase"
|
||||||
llm_config = LLMConfig(
|
llm_config = LLMConfig(
|
||||||
model_name="gpt-3.5-turbo-0613", model_device="cpu",api_key=os.environ["OPENAI_API_KEY"],
|
model_name="gpt-3.5-turbo-0613", api_key=os.environ["OPENAI_API_KEY"],
|
||||||
api_base_url=os.environ["API_BASE_URL"], temperature=0.7
|
api_base_url=os.environ["API_BASE_URL"], temperature=0.7
|
||||||
)
|
)
|
||||||
embed_config = EmbedConfig(
|
embed_config = EmbedConfig(
|
||||||
|
|
|
@ -17,7 +17,7 @@ from coagent.connector.schema import Message, Memory
|
||||||
|
|
||||||
tools = toLangchainTools([TOOL_DICT[i] for i in TOOL_SETS if i in TOOL_DICT])
|
tools = toLangchainTools([TOOL_DICT[i] for i in TOOL_SETS if i in TOOL_DICT])
|
||||||
llm_config = LLMConfig(
|
llm_config = LLMConfig(
|
||||||
model_name="gpt-3.5-turbo", model_device="cpu",api_key=os.environ["OPENAI_API_KEY"],
|
model_name="gpt-3.5-turbo",api_key=os.environ["OPENAI_API_KEY"],
|
||||||
api_base_url=os.environ["API_BASE_URL"], temperature=0.3
|
api_base_url=os.environ["API_BASE_URL"], temperature=0.3
|
||||||
)
|
)
|
||||||
embed_config = EmbedConfig(
|
embed_config = EmbedConfig(
|
||||||
|
|
|
@ -18,7 +18,7 @@ os.environ["log_verbose"] = "0"
|
||||||
|
|
||||||
phase_name = "metagpt_code_devlop"
|
phase_name = "metagpt_code_devlop"
|
||||||
llm_config = LLMConfig(
|
llm_config = LLMConfig(
|
||||||
model_name="gpt-3.5-turbo", model_device="cpu",api_key=os.environ["OPENAI_API_KEY"],
|
model_name="gpt-3.5-turbo", api_key=os.environ["OPENAI_API_KEY"],
|
||||||
api_base_url=os.environ["API_BASE_URL"], temperature=0.3
|
api_base_url=os.environ["API_BASE_URL"], temperature=0.3
|
||||||
)
|
)
|
||||||
embed_config = EmbedConfig(
|
embed_config = EmbedConfig(
|
||||||
|
|
|
@ -20,7 +20,7 @@ os.environ["log_verbose"] = "2"
|
||||||
|
|
||||||
phase_name = "searchChatPhase"
|
phase_name = "searchChatPhase"
|
||||||
llm_config = LLMConfig(
|
llm_config = LLMConfig(
|
||||||
model_name="gpt-3.5-turbo", model_device="cpu",api_key=os.environ["OPENAI_API_KEY"],
|
model_name="gpt-3.5-turbo", api_key=os.environ["OPENAI_API_KEY"],
|
||||||
api_base_url=os.environ["API_BASE_URL"], temperature=0.3
|
api_base_url=os.environ["API_BASE_URL"], temperature=0.3
|
||||||
)
|
)
|
||||||
embed_config = EmbedConfig(
|
embed_config = EmbedConfig(
|
||||||
|
|
|
@ -18,7 +18,7 @@ os.environ["log_verbose"] = "2"
|
||||||
|
|
||||||
phase_name = "toolReactPhase"
|
phase_name = "toolReactPhase"
|
||||||
llm_config = LLMConfig(
|
llm_config = LLMConfig(
|
||||||
model_name="gpt-3.5-turbo", model_device="cpu",api_key=os.environ["OPENAI_API_KEY"],
|
model_name="gpt-3.5-turbo",api_key=os.environ["OPENAI_API_KEY"],
|
||||||
api_base_url=os.environ["API_BASE_URL"], temperature=0.3
|
api_base_url=os.environ["API_BASE_URL"], temperature=0.3
|
||||||
)
|
)
|
||||||
embed_config = EmbedConfig(
|
embed_config = EmbedConfig(
|
||||||
|
|
|
@ -151,9 +151,9 @@ def create_app():
|
||||||
)(delete_cb)
|
)(delete_cb)
|
||||||
|
|
||||||
app.post("/code_base/code_base_chat",
|
app.post("/code_base/code_base_chat",
|
||||||
tags=["Code Base Management"],
|
tags=["Code Base Management"],
|
||||||
summary="删除 code_base"
|
summary="code_base 对话"
|
||||||
)(delete_cb)
|
)(search_code)
|
||||||
|
|
||||||
app.get("/code_base/list_code_bases",
|
app.get("/code_base/list_code_bases",
|
||||||
tags=["Code Base Management"],
|
tags=["Code Base Management"],
|
||||||
|
|
|
@ -117,7 +117,7 @@ PHASE_CONFIGS.update({
|
||||||
|
|
||||||
|
|
||||||
llm_config = LLMConfig(
|
llm_config = LLMConfig(
|
||||||
model_name="gpt-3.5-turbo", model_device="cpu",api_key=os.environ["OPENAI_API_KEY"],
|
model_name="gpt-3.5-turbo", api_key=os.environ["OPENAI_API_KEY"],
|
||||||
api_base_url=os.environ["API_BASE_URL"], temperature=0.3
|
api_base_url=os.environ["API_BASE_URL"], temperature=0.3
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -98,12 +98,6 @@ def start_docker(client, script_shs, ports, image_name, container_name, mounts=N
|
||||||
network_name ='my_network'
|
network_name ='my_network'
|
||||||
|
|
||||||
def start_sandbox_service(network_name ='my_network'):
|
def start_sandbox_service(network_name ='my_network'):
|
||||||
# networks = client.networks.list()
|
|
||||||
# if any([network_name==i.attrs["Name"] for i in networks]):
|
|
||||||
# network = client.networks.get(network_name)
|
|
||||||
# else:
|
|
||||||
# network = client.networks.create('my_network', driver='bridge')
|
|
||||||
|
|
||||||
mount = Mount(
|
mount = Mount(
|
||||||
type='bind',
|
type='bind',
|
||||||
source=os.path.join(src_dir, "jupyter_work"),
|
source=os.path.join(src_dir, "jupyter_work"),
|
||||||
|
@ -114,6 +108,12 @@ def start_sandbox_service(network_name ='my_network'):
|
||||||
# 沙盒的启动与服务的启动是独立的
|
# 沙盒的启动与服务的启动是独立的
|
||||||
if SANDBOX_SERVER["do_remote"]:
|
if SANDBOX_SERVER["do_remote"]:
|
||||||
client = docker.from_env()
|
client = docker.from_env()
|
||||||
|
networks = client.networks.list()
|
||||||
|
if any([network_name==i.attrs["Name"] for i in networks]):
|
||||||
|
network = client.networks.get(network_name)
|
||||||
|
else:
|
||||||
|
network = client.networks.create('my_network', driver='bridge')
|
||||||
|
|
||||||
# 启动容器
|
# 启动容器
|
||||||
logger.info("start container sandbox service")
|
logger.info("start container sandbox service")
|
||||||
JUPYTER_WORK_PATH = "/home/user/chatbot/jupyter_work"
|
JUPYTER_WORK_PATH = "/home/user/chatbot/jupyter_work"
|
||||||
|
@ -150,7 +150,7 @@ def start_api_service(sandbox_host=DEFAULT_BIND_HOST):
|
||||||
client = docker.from_env()
|
client = docker.from_env()
|
||||||
logger.info("start container service")
|
logger.info("start container service")
|
||||||
check_process("api.py", do_stop=True)
|
check_process("api.py", do_stop=True)
|
||||||
check_process("sdfile_api.py", do_stop=True)
|
check_process("llm_api.py", do_stop=True)
|
||||||
check_process("sdfile_api.py", do_stop=True)
|
check_process("sdfile_api.py", do_stop=True)
|
||||||
check_process("webui.py", do_stop=True)
|
check_process("webui.py", do_stop=True)
|
||||||
mount = Mount(
|
mount = Mount(
|
||||||
|
@ -159,27 +159,28 @@ def start_api_service(sandbox_host=DEFAULT_BIND_HOST):
|
||||||
target='/home/user/chatbot/',
|
target='/home/user/chatbot/',
|
||||||
read_only=False # 如果需要只读访问,将此选项设置为True
|
read_only=False # 如果需要只读访问,将此选项设置为True
|
||||||
)
|
)
|
||||||
mount_database = Mount(
|
# mount_database = Mount(
|
||||||
type='bind',
|
# type='bind',
|
||||||
source=os.path.join(src_dir, "knowledge_base"),
|
# source=os.path.join(src_dir, "knowledge_base"),
|
||||||
target='/home/user/knowledge_base/',
|
# target='/home/user/knowledge_base/',
|
||||||
read_only=False # 如果需要只读访问,将此选项设置为True
|
# read_only=False # 如果需要只读访问,将此选项设置为True
|
||||||
)
|
# )
|
||||||
mount_code_database = Mount(
|
# mount_code_database = Mount(
|
||||||
type='bind',
|
# type='bind',
|
||||||
source=os.path.join(src_dir, "code_base"),
|
# source=os.path.join(src_dir, "code_base"),
|
||||||
target='/home/user/code_base/',
|
# target='/home/user/code_base/',
|
||||||
read_only=False # 如果需要只读访问,将此选项设置为True
|
# read_only=False # 如果需要只读访问,将此选项设置为True
|
||||||
)
|
# )
|
||||||
ports={
|
ports={
|
||||||
f"{API_SERVER['docker_port']}/tcp": f"{API_SERVER['port']}/tcp",
|
f"{API_SERVER['docker_port']}/tcp": f"{API_SERVER['port']}/tcp",
|
||||||
f"{WEBUI_SERVER['docker_port']}/tcp": f"{WEBUI_SERVER['port']}/tcp",
|
f"{WEBUI_SERVER['docker_port']}/tcp": f"{WEBUI_SERVER['port']}/tcp",
|
||||||
f"{SDFILE_API_SERVER['docker_port']}/tcp": f"{SDFILE_API_SERVER['port']}/tcp",
|
f"{SDFILE_API_SERVER['docker_port']}/tcp": f"{SDFILE_API_SERVER['port']}/tcp",
|
||||||
f"{NEBULA_GRAPH_SERVER['docker_port']}/tcp": f"{NEBULA_GRAPH_SERVER['port']}/tcp"
|
f"{NEBULA_GRAPH_SERVER['docker_port']}/tcp": f"{NEBULA_GRAPH_SERVER['port']}/tcp"
|
||||||
}
|
}
|
||||||
mounts = [mount, mount_database, mount_code_database]
|
# mounts = [mount, mount_database, mount_code_database]
|
||||||
|
mounts = [mount]
|
||||||
script_shs = [
|
script_shs = [
|
||||||
"mkdir -p /home/user/logs",
|
"mkdir -p /home/user/chatbot/logs",
|
||||||
'''
|
'''
|
||||||
if [ -d "/home/user/chatbot/data/nebula_data/data/meta" ]; then
|
if [ -d "/home/user/chatbot/data/nebula_data/data/meta" ]; then
|
||||||
cp -r /home/user/chatbot/data/nebula_data/data /usr/local/nebula/
|
cp -r /home/user/chatbot/data/nebula_data/data /usr/local/nebula/
|
||||||
|
@ -197,12 +198,12 @@ def start_api_service(sandbox_host=DEFAULT_BIND_HOST):
|
||||||
"pip install jieba",
|
"pip install jieba",
|
||||||
"pip install duckduckgo-search",
|
"pip install duckduckgo-search",
|
||||||
|
|
||||||
"nohup python chatbot/examples/sdfile_api.py > /home/user/logs/sdfile_api.log 2>&1 &",
|
"nohup python chatbot/examples/sdfile_api.py > /home/user/chatbot/logs/sdfile_api.log 2>&1 &",
|
||||||
f"export DUCKDUCKGO_PROXY=socks5://host.docker.internal:13659 && export SANDBOX_HOST={sandbox_host} &&\
|
f"export DUCKDUCKGO_PROXY=socks5://host.docker.internal:13659 && export SANDBOX_HOST={sandbox_host} &&\
|
||||||
nohup python chatbot/examples/api.py > /home/user/logs/api.log 2>&1 &",
|
nohup python chatbot/examples/api.py > /home/user/chatbot/logs/api.log 2>&1 &",
|
||||||
"nohup python chatbot/examples/llm_api.py > /home/user/llm.log 2>&1 &",
|
"nohup python chatbot/examples/llm_api.py > /home/user/llm.log 2>&1 &",
|
||||||
f"export DUCKDUCKGO_PROXY=socks5://host.docker.internal:13659 && export SANDBOX_HOST={sandbox_host} &&\
|
f"export DUCKDUCKGO_PROXY=socks5://host.docker.internal:13659 && export SANDBOX_HOST={sandbox_host} &&\
|
||||||
cd chatbot/examples && nohup streamlit run webui.py > /home/user/logs/start_webui.log 2>&1 &"
|
cd chatbot/examples && nohup streamlit run webui.py > /home/user/chatbot/logs/start_webui.log 2>&1 &"
|
||||||
]
|
]
|
||||||
if check_docker(client, CONTRAINER_NAME, do_stop=True):
|
if check_docker(client, CONTRAINER_NAME, do_stop=True):
|
||||||
container = start_docker(client, script_shs, ports, IMAGE_NAME, CONTRAINER_NAME, mounts, network=network_name)
|
container = start_docker(client, script_shs, ports, IMAGE_NAME, CONTRAINER_NAME, mounts, network=network_name)
|
||||||
|
@ -212,12 +213,9 @@ def start_api_service(sandbox_host=DEFAULT_BIND_HOST):
|
||||||
# 关闭之前启动的docker 服务
|
# 关闭之前启动的docker 服务
|
||||||
# check_docker(client, CONTRAINER_NAME, do_stop=True, )
|
# check_docker(client, CONTRAINER_NAME, do_stop=True, )
|
||||||
|
|
||||||
# api_sh = "nohup python ../coagent/service/api.py > ../logs/api.log 2>&1 &"
|
|
||||||
api_sh = "nohup python api.py > ../logs/api.log 2>&1 &"
|
api_sh = "nohup python api.py > ../logs/api.log 2>&1 &"
|
||||||
# sdfile_sh = "nohup python ../coagent/service/sdfile_api.py > ../logs/sdfile_api.log 2>&1 &"
|
|
||||||
sdfile_sh = "nohup python sdfile_api.py > ../logs/sdfile_api.log 2>&1 &"
|
sdfile_sh = "nohup python sdfile_api.py > ../logs/sdfile_api.log 2>&1 &"
|
||||||
notebook_sh = f"nohup jupyter-notebook --NotebookApp.token=mytoken --port={SANDBOX_SERVER['port']} --allow-root --ip=0.0.0.0 --notebook-dir={JUPYTER_WORK_PATH} --no-browser --ServerApp.disable_check_xsrf=True > ../logs/sandbox.log 2>&1 &"
|
notebook_sh = f"nohup jupyter-notebook --NotebookApp.token=mytoken --port={SANDBOX_SERVER['port']} --allow-root --ip=0.0.0.0 --notebook-dir={JUPYTER_WORK_PATH} --no-browser --ServerApp.disable_check_xsrf=True > ../logs/sandbox.log 2>&1 &"
|
||||||
# llm_sh = "nohup python ../coagent/service/llm_api.py > ../logs/llm_api.log 2>&1 &"
|
|
||||||
llm_sh = "nohup python llm_api.py > ../logs/llm_api.log 2>&1 &"
|
llm_sh = "nohup python llm_api.py > ../logs/llm_api.log 2>&1 &"
|
||||||
webui_sh = "streamlit run webui.py" if USE_TTY else "streamlit run webui.py"
|
webui_sh = "streamlit run webui.py" if USE_TTY else "streamlit run webui.py"
|
||||||
|
|
||||||
|
|
|
@ -22,7 +22,7 @@ from coagent.service.service_factory import get_cb_details, get_cb_details_by_cb
|
||||||
from coagent.orm import table_init
|
from coagent.orm import table_init
|
||||||
|
|
||||||
|
|
||||||
from configs.model_config import EMBEDDING_DEVICE, EMBEDDING_ENGINE, EMBEDDING_MODEL, embedding_model_dict
|
from configs.model_config import EMBEDDING_DEVICE, EMBEDDING_ENGINE, EMBEDDING_MODEL, embedding_model_dict,llm_model_dict
|
||||||
# SENTENCE_SIZE = 100
|
# SENTENCE_SIZE = 100
|
||||||
|
|
||||||
cell_renderer = JsCode("""function(params) {if(params.value==true){return '✓'}else{return '×'}}""")
|
cell_renderer = JsCode("""function(params) {if(params.value==true){return '✓'}else{return '×'}}""")
|
||||||
|
@ -117,6 +117,8 @@ def code_page(api: ApiRequest):
|
||||||
embed_model_path=embedding_model_dict[EMBEDDING_MODEL],
|
embed_model_path=embedding_model_dict[EMBEDDING_MODEL],
|
||||||
embedding_device=EMBEDDING_DEVICE,
|
embedding_device=EMBEDDING_DEVICE,
|
||||||
llm_model=LLM_MODEL,
|
llm_model=LLM_MODEL,
|
||||||
|
api_key=llm_model_dict[LLM_MODEL]["api_key"],
|
||||||
|
api_base_url=llm_model_dict[LLM_MODEL]["api_base_url"],
|
||||||
)
|
)
|
||||||
st.toast(ret.get("msg", " "))
|
st.toast(ret.get("msg", " "))
|
||||||
st.session_state["selected_cb_name"] = cb_name
|
st.session_state["selected_cb_name"] = cb_name
|
||||||
|
@ -153,6 +155,8 @@ def code_page(api: ApiRequest):
|
||||||
embed_model_path=embedding_model_dict[EMBEDDING_MODEL],
|
embed_model_path=embedding_model_dict[EMBEDDING_MODEL],
|
||||||
embedding_device=EMBEDDING_DEVICE,
|
embedding_device=EMBEDDING_DEVICE,
|
||||||
llm_model=LLM_MODEL,
|
llm_model=LLM_MODEL,
|
||||||
|
api_key=llm_model_dict[LLM_MODEL]["api_key"],
|
||||||
|
api_base_url=llm_model_dict[LLM_MODEL]["api_base_url"],
|
||||||
)
|
)
|
||||||
st.toast(ret.get("msg", "删除成功"))
|
st.toast(ret.get("msg", "删除成功"))
|
||||||
time.sleep(0.05)
|
time.sleep(0.05)
|
||||||
|
|
|
@ -11,7 +11,7 @@ from coagent.chat.search_chat import SEARCH_ENGINES
|
||||||
from coagent.connector import PHASE_LIST, PHASE_CONFIGS
|
from coagent.connector import PHASE_LIST, PHASE_CONFIGS
|
||||||
from coagent.service.service_factory import get_cb_details_by_cb_name
|
from coagent.service.service_factory import get_cb_details_by_cb_name
|
||||||
|
|
||||||
from configs.model_config import EMBEDDING_DEVICE, EMBEDDING_MODEL, embedding_model_dict, EMBEDDING_ENGINE, KB_ROOT_PATH
|
from configs.model_config import EMBEDDING_DEVICE, EMBEDDING_MODEL, embedding_model_dict, EMBEDDING_ENGINE, KB_ROOT_PATH, llm_model_dict
|
||||||
chat_box = ChatBox(
|
chat_box = ChatBox(
|
||||||
assistant_avatar="../sources/imgs/devops-chatbot2.png"
|
assistant_avatar="../sources/imgs/devops-chatbot2.png"
|
||||||
)
|
)
|
||||||
|
@ -174,7 +174,7 @@ def dialogue_page(api: ApiRequest):
|
||||||
is_detailed = st.toggle(webui_configs["dialogue"]["phase_toggle_detailed_name"], False)
|
is_detailed = st.toggle(webui_configs["dialogue"]["phase_toggle_detailed_name"], False)
|
||||||
tool_using_on = st.toggle(
|
tool_using_on = st.toggle(
|
||||||
webui_configs["dialogue"]["phase_toggle_doToolUsing"],
|
webui_configs["dialogue"]["phase_toggle_doToolUsing"],
|
||||||
PHASE_CONFIGS[choose_phase]["do_using_tool"])
|
PHASE_CONFIGS[choose_phase].get("do_using_tool", False))
|
||||||
tool_selects = []
|
tool_selects = []
|
||||||
if tool_using_on:
|
if tool_using_on:
|
||||||
with st.expander("工具军火库", True):
|
with st.expander("工具军火库", True):
|
||||||
|
@ -183,7 +183,7 @@ def dialogue_page(api: ApiRequest):
|
||||||
TOOL_SETS, ["WeatherInfo"])
|
TOOL_SETS, ["WeatherInfo"])
|
||||||
|
|
||||||
search_on = st.toggle(webui_configs["dialogue"]["phase_toggle_doSearch"],
|
search_on = st.toggle(webui_configs["dialogue"]["phase_toggle_doSearch"],
|
||||||
PHASE_CONFIGS[choose_phase]["do_search"])
|
PHASE_CONFIGS[choose_phase].get("do_search", False))
|
||||||
search_engine, top_k = None, 3
|
search_engine, top_k = None, 3
|
||||||
if search_on:
|
if search_on:
|
||||||
with st.expander(webui_configs["dialogue"]["expander_search_name"], True):
|
with st.expander(webui_configs["dialogue"]["expander_search_name"], True):
|
||||||
|
@ -195,7 +195,8 @@ def dialogue_page(api: ApiRequest):
|
||||||
|
|
||||||
doc_retrieval_on = st.toggle(
|
doc_retrieval_on = st.toggle(
|
||||||
webui_configs["dialogue"]["phase_toggle_doDocRetrieval"],
|
webui_configs["dialogue"]["phase_toggle_doDocRetrieval"],
|
||||||
PHASE_CONFIGS[choose_phase]["do_doc_retrieval"])
|
PHASE_CONFIGS[choose_phase].get("do_doc_retrieval", False)
|
||||||
|
)
|
||||||
selected_kb, top_k, score_threshold = None, 3, 1.0
|
selected_kb, top_k, score_threshold = None, 3, 1.0
|
||||||
if doc_retrieval_on:
|
if doc_retrieval_on:
|
||||||
with st.expander(webui_configs["dialogue"]["kbase_expander_name"], True):
|
with st.expander(webui_configs["dialogue"]["kbase_expander_name"], True):
|
||||||
|
@ -215,7 +216,7 @@ def dialogue_page(api: ApiRequest):
|
||||||
|
|
||||||
code_retrieval_on = st.toggle(
|
code_retrieval_on = st.toggle(
|
||||||
webui_configs["dialogue"]["phase_toggle_doCodeRetrieval"],
|
webui_configs["dialogue"]["phase_toggle_doCodeRetrieval"],
|
||||||
PHASE_CONFIGS[choose_phase]["do_code_retrieval"])
|
PHASE_CONFIGS[choose_phase].get("do_code_retrieval", False))
|
||||||
selected_cb, top_k = None, 1
|
selected_cb, top_k = None, 1
|
||||||
cb_search_type = "tag"
|
cb_search_type = "tag"
|
||||||
if code_retrieval_on:
|
if code_retrieval_on:
|
||||||
|
@ -296,7 +297,8 @@ def dialogue_page(api: ApiRequest):
|
||||||
r = api.chat_chat(
|
r = api.chat_chat(
|
||||||
prompt, history, no_remote_api=True,
|
prompt, history, no_remote_api=True,
|
||||||
embed_model=EMBEDDING_MODEL, embed_model_path=embedding_model_dict[EMBEDDING_MODEL],
|
embed_model=EMBEDDING_MODEL, embed_model_path=embedding_model_dict[EMBEDDING_MODEL],
|
||||||
model_device=EMBEDDING_DEVICE, embed_engine=EMBEDDING_ENGINE,
|
model_device=EMBEDDING_DEVICE, embed_engine=EMBEDDING_ENGINE,api_key=llm_model_dict[LLM_MODEL]["api_key"],
|
||||||
|
api_base_url=llm_model_dict[LLM_MODEL]["api_base_url"],
|
||||||
llm_model=LLM_MODEL)
|
llm_model=LLM_MODEL)
|
||||||
for t in r:
|
for t in r:
|
||||||
if error_msg := check_error_msg(t): # check whether error occured
|
if error_msg := check_error_msg(t): # check whether error occured
|
||||||
|
@ -362,6 +364,8 @@ def dialogue_page(api: ApiRequest):
|
||||||
"embed_engine": EMBEDDING_ENGINE,
|
"embed_engine": EMBEDDING_ENGINE,
|
||||||
"kb_root_path": KB_ROOT_PATH,
|
"kb_root_path": KB_ROOT_PATH,
|
||||||
"model_name": LLM_MODEL,
|
"model_name": LLM_MODEL,
|
||||||
|
"api_key": llm_model_dict[LLM_MODEL]["api_key"],
|
||||||
|
"api_base_url": llm_model_dict[LLM_MODEL]["api_base_url"],
|
||||||
}
|
}
|
||||||
text = ""
|
text = ""
|
||||||
d = {"docs": []}
|
d = {"docs": []}
|
||||||
|
@ -405,7 +409,10 @@ def dialogue_page(api: ApiRequest):
|
||||||
api.knowledge_base_chat(
|
api.knowledge_base_chat(
|
||||||
prompt, selected_kb, kb_top_k, score_threshold, history,
|
prompt, selected_kb, kb_top_k, score_threshold, history,
|
||||||
embed_model=EMBEDDING_MODEL, embed_model_path=embedding_model_dict[EMBEDDING_MODEL],
|
embed_model=EMBEDDING_MODEL, embed_model_path=embedding_model_dict[EMBEDDING_MODEL],
|
||||||
model_device=EMBEDDING_DEVICE, embed_engine=EMBEDDING_ENGINE, llm_model=LLM_MODEL)
|
model_device=EMBEDDING_DEVICE, embed_engine=EMBEDDING_ENGINE, llm_model=LLM_MODEL,
|
||||||
|
api_key=llm_model_dict[LLM_MODEL]["api_key"],
|
||||||
|
api_base_url=llm_model_dict[LLM_MODEL]["api_base_url"],
|
||||||
|
)
|
||||||
):
|
):
|
||||||
if error_msg := check_error_msg(d): # check whether error occured
|
if error_msg := check_error_msg(d): # check whether error occured
|
||||||
st.error(error_msg)
|
st.error(error_msg)
|
||||||
|
@ -415,11 +422,7 @@ def dialogue_page(api: ApiRequest):
|
||||||
# chat_box.update_msg("知识库匹配结果: \n\n".join(d["docs"]), element_index=1, streaming=False, state="complete")
|
# chat_box.update_msg("知识库匹配结果: \n\n".join(d["docs"]), element_index=1, streaming=False, state="complete")
|
||||||
chat_box.update_msg(text, element_index=0, streaming=False) # 更新最终的字符串,去除光标
|
chat_box.update_msg(text, element_index=0, streaming=False) # 更新最终的字符串,去除光标
|
||||||
chat_box.update_msg("{webui_configs['chat']['chatbox_doc_result']}: \n\n".join(d["docs"]), element_index=1, streaming=False, state="complete")
|
chat_box.update_msg("{webui_configs['chat']['chatbox_doc_result']}: \n\n".join(d["docs"]), element_index=1, streaming=False, state="complete")
|
||||||
# # 判断是否存在代码, 并提高编辑功能,执行功能
|
|
||||||
# code_text = api.codebox.decode_code_from_text(text)
|
|
||||||
# GLOBAL_EXE_CODE_TEXT = code_text
|
|
||||||
# if code_text and code_exec_on:
|
|
||||||
# codebox_res = api.codebox_chat("```"+code_text+"```", do_code_exe=True)
|
|
||||||
elif dialogue_mode == webui_configs["dialogue"]["mode"][2]:
|
elif dialogue_mode == webui_configs["dialogue"]["mode"][2]:
|
||||||
logger.info('prompt={}'.format(prompt))
|
logger.info('prompt={}'.format(prompt))
|
||||||
logger.info('history={}'.format(history))
|
logger.info('history={}'.format(history))
|
||||||
|
@ -438,7 +441,9 @@ def dialogue_page(api: ApiRequest):
|
||||||
cb_search_type=cb_search_type,
|
cb_search_type=cb_search_type,
|
||||||
no_remote_api=True, embed_model=EMBEDDING_MODEL,
|
no_remote_api=True, embed_model=EMBEDDING_MODEL,
|
||||||
embed_model_path=embedding_model_dict[EMBEDDING_MODEL],
|
embed_model_path=embedding_model_dict[EMBEDDING_MODEL],
|
||||||
embed_engine=EMBEDDING_ENGINE, llm_model=LLM_MODEL
|
embed_engine=EMBEDDING_ENGINE, llm_model=LLM_MODEL,
|
||||||
|
api_key=llm_model_dict[LLM_MODEL]["api_key"],
|
||||||
|
api_base_url=llm_model_dict[LLM_MODEL]["api_base_url"],
|
||||||
)):
|
)):
|
||||||
if error_msg := check_error_msg(d):
|
if error_msg := check_error_msg(d):
|
||||||
st.error(error_msg)
|
st.error(error_msg)
|
||||||
|
@ -448,6 +453,7 @@ def dialogue_page(api: ApiRequest):
|
||||||
chat_box.update_msg(text, element_index=0)
|
chat_box.update_msg(text, element_index=0)
|
||||||
|
|
||||||
# postprocess
|
# postprocess
|
||||||
|
logger.debug(f"d={d}")
|
||||||
text = replace_lt_gt(text)
|
text = replace_lt_gt(text)
|
||||||
chat_box.update_msg(text, element_index=0, streaming=False) # 更新最终的字符串,去除光标
|
chat_box.update_msg(text, element_index=0, streaming=False) # 更新最终的字符串,去除光标
|
||||||
logger.debug('text={}'.format(text))
|
logger.debug('text={}'.format(text))
|
||||||
|
@ -467,7 +473,9 @@ def dialogue_page(api: ApiRequest):
|
||||||
api.search_engine_chat(
|
api.search_engine_chat(
|
||||||
prompt, search_engine, se_top_k, history, embed_model=EMBEDDING_MODEL,
|
prompt, search_engine, se_top_k, history, embed_model=EMBEDDING_MODEL,
|
||||||
embed_model_path=embedding_model_dict[EMBEDDING_MODEL],
|
embed_model_path=embedding_model_dict[EMBEDDING_MODEL],
|
||||||
model_device=EMBEDDING_DEVICE, embed_engine=EMBEDDING_ENGINE, llm_model=LLM_MODEL)
|
model_device=EMBEDDING_DEVICE, embed_engine=EMBEDDING_ENGINE, llm_model=LLM_MODEL,
|
||||||
|
pi_key=llm_model_dict[LLM_MODEL]["api_key"],
|
||||||
|
api_base_url=llm_model_dict[LLM_MODEL]["api_base_url"],)
|
||||||
):
|
):
|
||||||
if error_msg := check_error_msg(d): # check whether error occured
|
if error_msg := check_error_msg(d): # check whether error occured
|
||||||
st.error(error_msg)
|
st.error(error_msg)
|
||||||
|
@ -477,56 +485,11 @@ def dialogue_page(api: ApiRequest):
|
||||||
# chat_box.update_msg("搜索匹配结果: \n\n".join(d["docs"]), element_index=1, streaming=False)
|
# chat_box.update_msg("搜索匹配结果: \n\n".join(d["docs"]), element_index=1, streaming=False)
|
||||||
chat_box.update_msg(text, element_index=0, streaming=False) # 更新最终的字符串,去除光标
|
chat_box.update_msg(text, element_index=0, streaming=False) # 更新最终的字符串,去除光标
|
||||||
chat_box.update_msg(f"{webui_configs['chat']['chatbox_search_result']}: \n\n".join(d["docs"]), element_index=1, streaming=False, state="complete")
|
chat_box.update_msg(f"{webui_configs['chat']['chatbox_search_result']}: \n\n".join(d["docs"]), element_index=1, streaming=False, state="complete")
|
||||||
# # 判断是否存在代码, 并提高编辑功能,执行功能
|
|
||||||
# code_text = api.codebox.decode_code_from_text(text)
|
|
||||||
# GLOBAL_EXE_CODE_TEXT = code_text
|
|
||||||
# if code_text and code_exec_on:
|
|
||||||
# codebox_res = api.codebox_chat("```"+code_text+"```", do_code_exe=True)
|
|
||||||
|
|
||||||
# 将上传文件清空
|
# 将上传文件清空
|
||||||
st.session_state["interpreter_file_key"] += 1
|
st.session_state["interpreter_file_key"] += 1
|
||||||
st.experimental_rerun()
|
st.experimental_rerun()
|
||||||
|
|
||||||
# if code_interpreter_on:
|
|
||||||
# with st.expander(webui_configs['sandbox']['expander_code_name'], False):
|
|
||||||
# code_part = st.text_area(
|
|
||||||
# webui_configs['sandbox']['textArea_code_name'], code_text, key="code_text")
|
|
||||||
# cols = st.columns(2)
|
|
||||||
# if cols[0].button(
|
|
||||||
# webui_configs['sandbox']['button_modify_code_name'],
|
|
||||||
# use_container_width=True,
|
|
||||||
# ):
|
|
||||||
# code_text = code_part
|
|
||||||
# GLOBAL_EXE_CODE_TEXT = code_text
|
|
||||||
# st.toast(webui_configs['sandbox']['text_modify_code'])
|
|
||||||
|
|
||||||
# if cols[1].button(
|
|
||||||
# webui_configs['sandbox']['button_exec_code_name'],
|
|
||||||
# use_container_width=True
|
|
||||||
# ):
|
|
||||||
# if code_text:
|
|
||||||
# codebox_res = api.codebox_chat("```"+code_text+"```", do_code_exe=True)
|
|
||||||
# st.toast(webui_configs['sandbox']['text_execing_code'],)
|
|
||||||
# else:
|
|
||||||
# st.toast(webui_configs['sandbox']['text_error_exec_code'],)
|
|
||||||
|
|
||||||
# #TODO 这段信息会被记录到history里
|
|
||||||
# if codebox_res is not None and codebox_res.code_exe_status != 200:
|
|
||||||
# st.toast(f"{codebox_res.code_exe_response}")
|
|
||||||
|
|
||||||
# if codebox_res is not None and codebox_res.code_exe_status == 200:
|
|
||||||
# st.toast(f"codebox_chat {codebox_res}")
|
|
||||||
# chat_box.ai_say(Markdown(code_text, in_expander=True, title="code interpreter", unsafe_allow_html=True), )
|
|
||||||
# if codebox_res.code_exe_type == "image/png":
|
|
||||||
# base_text = f"```\n{code_text}\n```\n\n"
|
|
||||||
# img_html = "<img src='data:image/png;base64,{}' class='img-fluid'>".format(
|
|
||||||
# codebox_res.code_exe_response
|
|
||||||
# )
|
|
||||||
# chat_box.update_msg(img_html, streaming=False, state="complete")
|
|
||||||
# else:
|
|
||||||
# chat_box.update_msg('```\n'+code_text+'\n```'+"\n\n"+'```\n'+codebox_res.code_exe_response+'\n```',
|
|
||||||
# streaming=False, state="complete")
|
|
||||||
|
|
||||||
now = datetime.now()
|
now = datetime.now()
|
||||||
with st.sidebar:
|
with st.sidebar:
|
||||||
|
|
||||||
|
|
|
@ -14,7 +14,8 @@ from coagent.orm import table_init
|
||||||
|
|
||||||
from configs.model_config import (
|
from configs.model_config import (
|
||||||
KB_ROOT_PATH, kbs_config, DEFAULT_VS_TYPE, WEB_CRAWL_PATH,
|
KB_ROOT_PATH, kbs_config, DEFAULT_VS_TYPE, WEB_CRAWL_PATH,
|
||||||
EMBEDDING_DEVICE, EMBEDDING_ENGINE, EMBEDDING_MODEL, embedding_model_dict
|
EMBEDDING_DEVICE, EMBEDDING_ENGINE, EMBEDDING_MODEL, embedding_model_dict,
|
||||||
|
llm_model_dict
|
||||||
)
|
)
|
||||||
|
|
||||||
# SENTENCE_SIZE = 100
|
# SENTENCE_SIZE = 100
|
||||||
|
@ -136,6 +137,8 @@ def knowledge_page(
|
||||||
embed_engine=EMBEDDING_ENGINE,
|
embed_engine=EMBEDDING_ENGINE,
|
||||||
embedding_device= EMBEDDING_DEVICE,
|
embedding_device= EMBEDDING_DEVICE,
|
||||||
embed_model_path=embedding_model_dict[embed_model],
|
embed_model_path=embedding_model_dict[embed_model],
|
||||||
|
api_key=llm_model_dict[LLM_MODEL]["api_key"],
|
||||||
|
api_base_url=llm_model_dict[LLM_MODEL]["api_base_url"],
|
||||||
)
|
)
|
||||||
st.toast(ret.get("msg", " "))
|
st.toast(ret.get("msg", " "))
|
||||||
st.session_state["selected_kb_name"] = kb_name
|
st.session_state["selected_kb_name"] = kb_name
|
||||||
|
@ -160,7 +163,10 @@ def knowledge_page(
|
||||||
data = [{"file": f, "knowledge_base_name": kb, "not_refresh_vs_cache": True, "embed_model": EMBEDDING_MODEL,
|
data = [{"file": f, "knowledge_base_name": kb, "not_refresh_vs_cache": True, "embed_model": EMBEDDING_MODEL,
|
||||||
"embed_model_path": embedding_model_dict[EMBEDDING_MODEL],
|
"embed_model_path": embedding_model_dict[EMBEDDING_MODEL],
|
||||||
"model_device": EMBEDDING_DEVICE,
|
"model_device": EMBEDDING_DEVICE,
|
||||||
"embed_engine": EMBEDDING_ENGINE}
|
"embed_engine": EMBEDDING_ENGINE,
|
||||||
|
"api_key": llm_model_dict[LLM_MODEL]["api_key"],
|
||||||
|
"api_base_url": llm_model_dict[LLM_MODEL]["api_base_url"],
|
||||||
|
}
|
||||||
for f in files]
|
for f in files]
|
||||||
data[-1]["not_refresh_vs_cache"]=False
|
data[-1]["not_refresh_vs_cache"]=False
|
||||||
for k in data:
|
for k in data:
|
||||||
|
@ -210,7 +216,9 @@ def knowledge_page(
|
||||||
"embed_model": EMBEDDING_MODEL,
|
"embed_model": EMBEDDING_MODEL,
|
||||||
"embed_model_path": embedding_model_dict[EMBEDDING_MODEL],
|
"embed_model_path": embedding_model_dict[EMBEDDING_MODEL],
|
||||||
"model_device": EMBEDDING_DEVICE,
|
"model_device": EMBEDDING_DEVICE,
|
||||||
"embed_engine": EMBEDDING_ENGINE}]
|
"embed_engine": EMBEDDING_ENGINE,
|
||||||
|
"api_key": llm_model_dict[LLM_MODEL]["api_key"],
|
||||||
|
"api_base_url": llm_model_dict[LLM_MODEL]["api_base_url"],}]
|
||||||
for k in data:
|
for k in data:
|
||||||
ret = api.upload_kb_doc(**k)
|
ret = api.upload_kb_doc(**k)
|
||||||
logger.info(ret)
|
logger.info(ret)
|
||||||
|
@ -297,7 +305,9 @@ def knowledge_page(
|
||||||
api.update_kb_doc(kb, row["file_name"],
|
api.update_kb_doc(kb, row["file_name"],
|
||||||
embed_engine=EMBEDDING_ENGINE,embed_model=EMBEDDING_MODEL,
|
embed_engine=EMBEDDING_ENGINE,embed_model=EMBEDDING_MODEL,
|
||||||
embed_model_path=embedding_model_dict[EMBEDDING_MODEL],
|
embed_model_path=embedding_model_dict[EMBEDDING_MODEL],
|
||||||
model_device=EMBEDDING_DEVICE
|
model_device=EMBEDDING_DEVICE,
|
||||||
|
api_key=llm_model_dict[LLM_MODEL]["api_key"],
|
||||||
|
api_base_url=llm_model_dict[LLM_MODEL]["api_base_url"],
|
||||||
)
|
)
|
||||||
st.experimental_rerun()
|
st.experimental_rerun()
|
||||||
|
|
||||||
|
@ -311,7 +321,9 @@ def knowledge_page(
|
||||||
api.delete_kb_doc(kb, row["file_name"],
|
api.delete_kb_doc(kb, row["file_name"],
|
||||||
embed_engine=EMBEDDING_ENGINE,embed_model=EMBEDDING_MODEL,
|
embed_engine=EMBEDDING_ENGINE,embed_model=EMBEDDING_MODEL,
|
||||||
embed_model_path=embedding_model_dict[EMBEDDING_MODEL],
|
embed_model_path=embedding_model_dict[EMBEDDING_MODEL],
|
||||||
model_device=EMBEDDING_DEVICE)
|
model_device=EMBEDDING_DEVICE,
|
||||||
|
api_key=llm_model_dict[LLM_MODEL]["api_key"],
|
||||||
|
api_base_url=llm_model_dict[LLM_MODEL]["api_base_url"],)
|
||||||
st.experimental_rerun()
|
st.experimental_rerun()
|
||||||
|
|
||||||
if cols[3].button(
|
if cols[3].button(
|
||||||
|
@ -323,7 +335,9 @@ def knowledge_page(
|
||||||
ret = api.delete_kb_doc(kb, row["file_name"], True,
|
ret = api.delete_kb_doc(kb, row["file_name"], True,
|
||||||
embed_engine=EMBEDDING_ENGINE,embed_model=EMBEDDING_MODEL,
|
embed_engine=EMBEDDING_ENGINE,embed_model=EMBEDDING_MODEL,
|
||||||
embed_model_path=embedding_model_dict[EMBEDDING_MODEL],
|
embed_model_path=embedding_model_dict[EMBEDDING_MODEL],
|
||||||
model_device=EMBEDDING_DEVICE)
|
model_device=EMBEDDING_DEVICE,
|
||||||
|
api_key=llm_model_dict[LLM_MODEL]["api_key"],
|
||||||
|
api_base_url=llm_model_dict[LLM_MODEL]["api_base_url"],)
|
||||||
st.toast(ret.get("msg", " "))
|
st.toast(ret.get("msg", " "))
|
||||||
st.experimental_rerun()
|
st.experimental_rerun()
|
||||||
|
|
||||||
|
@ -344,6 +358,8 @@ def knowledge_page(
|
||||||
for d in api.recreate_vector_store(
|
for d in api.recreate_vector_store(
|
||||||
kb, vs_type=default_vs_type, embed_model=embedding_model, embedding_device=EMBEDDING_DEVICE,
|
kb, vs_type=default_vs_type, embed_model=embedding_model, embedding_device=EMBEDDING_DEVICE,
|
||||||
embed_model_path=embedding_model_dict["embedding_model"], embed_engine=EMBEDDING_ENGINE,
|
embed_model_path=embedding_model_dict["embedding_model"], embed_engine=EMBEDDING_ENGINE,
|
||||||
|
api_key=llm_model_dict[LLM_MODEL]["api_key"],
|
||||||
|
api_base_url=llm_model_dict[LLM_MODEL]["api_base_url"],
|
||||||
):
|
):
|
||||||
if msg := check_error_msg(d):
|
if msg := check_error_msg(d):
|
||||||
st.toast(msg)
|
st.toast(msg)
|
||||||
|
|
|
@ -299,7 +299,9 @@ class ApiRequest:
|
||||||
stream: bool = True,
|
stream: bool = True,
|
||||||
no_remote_api: bool = None,
|
no_remote_api: bool = None,
|
||||||
embed_model: str="", embed_model_path: str="", model_device: str="", embed_engine: str="",
|
embed_model: str="", embed_model_path: str="", model_device: str="", embed_engine: str="",
|
||||||
llm_model: str ="", temperature: float= 0.2
|
llm_model: str ="", temperature: float= 0.2,
|
||||||
|
api_key: str=os.environ["OPENAI_API_KEY"],
|
||||||
|
api_base_url: str = os.environ["API_BASE_URL"],
|
||||||
):
|
):
|
||||||
'''
|
'''
|
||||||
对应api.py/chat/chat接口
|
对应api.py/chat/chat接口
|
||||||
|
@ -311,8 +313,8 @@ class ApiRequest:
|
||||||
"query": query,
|
"query": query,
|
||||||
"history": history,
|
"history": history,
|
||||||
"stream": stream,
|
"stream": stream,
|
||||||
"api_key": os.environ["OPENAI_API_KEY"],
|
"api_key": api_key,
|
||||||
"api_base_url": os.environ["API_BASE_URL"],
|
"api_base_url": api_base_url,
|
||||||
"embed_model": embed_model,
|
"embed_model": embed_model,
|
||||||
"embed_model_path": embed_model_path,
|
"embed_model_path": embed_model_path,
|
||||||
"embed_engine": embed_engine,
|
"embed_engine": embed_engine,
|
||||||
|
@ -339,7 +341,9 @@ class ApiRequest:
|
||||||
stream: bool = True,
|
stream: bool = True,
|
||||||
no_remote_api: bool = None,
|
no_remote_api: bool = None,
|
||||||
embed_model: str="", embed_model_path: str="", model_device: str="", embed_engine: str="",
|
embed_model: str="", embed_model_path: str="", model_device: str="", embed_engine: str="",
|
||||||
llm_model: str ="", temperature: float= 0.2
|
llm_model: str ="", temperature: float= 0.2,
|
||||||
|
api_key: str=os.environ["OPENAI_API_KEY"],
|
||||||
|
api_base_url: str = os.environ["API_BASE_URL"],
|
||||||
):
|
):
|
||||||
'''
|
'''
|
||||||
对应api.py/chat/knowledge_base_chat接口
|
对应api.py/chat/knowledge_base_chat接口
|
||||||
|
@ -355,8 +359,8 @@ class ApiRequest:
|
||||||
"history": history,
|
"history": history,
|
||||||
"stream": stream,
|
"stream": stream,
|
||||||
"local_doc_url": no_remote_api,
|
"local_doc_url": no_remote_api,
|
||||||
"api_key": os.environ["OPENAI_API_KEY"],
|
"api_key": api_key,
|
||||||
"api_base_url": os.environ["API_BASE_URL"],
|
"api_base_url": api_base_url,
|
||||||
"embed_model": embed_model,
|
"embed_model": embed_model,
|
||||||
"embed_model_path": embed_model_path,
|
"embed_model_path": embed_model_path,
|
||||||
"embed_engine": embed_engine,
|
"embed_engine": embed_engine,
|
||||||
|
@ -386,7 +390,10 @@ class ApiRequest:
|
||||||
stream: bool = True,
|
stream: bool = True,
|
||||||
no_remote_api: bool = None,
|
no_remote_api: bool = None,
|
||||||
embed_model: str="", embed_model_path: str="", model_device: str="", embed_engine: str="",
|
embed_model: str="", embed_model_path: str="", model_device: str="", embed_engine: str="",
|
||||||
llm_model: str ="", temperature: float= 0.2
|
llm_model: str ="", temperature: float= 0.2,
|
||||||
|
api_key: str=os.environ["OPENAI_API_KEY"],
|
||||||
|
api_base_url: str = os.environ["API_BASE_URL"],
|
||||||
|
|
||||||
):
|
):
|
||||||
'''
|
'''
|
||||||
对应api.py/chat/search_engine_chat接口
|
对应api.py/chat/search_engine_chat接口
|
||||||
|
@ -400,8 +407,8 @@ class ApiRequest:
|
||||||
"top_k": top_k,
|
"top_k": top_k,
|
||||||
"history": history,
|
"history": history,
|
||||||
"stream": stream,
|
"stream": stream,
|
||||||
"api_key": os.environ["OPENAI_API_KEY"],
|
"api_key": api_key,
|
||||||
"api_base_url": os.environ["API_BASE_URL"],
|
"api_base_url": api_base_url,
|
||||||
"embed_model": embed_model,
|
"embed_model": embed_model,
|
||||||
"embed_model_path": embed_model_path,
|
"embed_model_path": embed_model_path,
|
||||||
"embed_engine": embed_engine,
|
"embed_engine": embed_engine,
|
||||||
|
@ -432,7 +439,9 @@ class ApiRequest:
|
||||||
stream: bool = True,
|
stream: bool = True,
|
||||||
no_remote_api: bool = None,
|
no_remote_api: bool = None,
|
||||||
embed_model: str="", embed_model_path: str="", model_device: str="", embed_engine: str="",
|
embed_model: str="", embed_model_path: str="", model_device: str="", embed_engine: str="",
|
||||||
llm_model: str ="", temperature: float= 0.2
|
llm_model: str ="", temperature: float= 0.2,
|
||||||
|
api_key: str=os.environ["OPENAI_API_KEY"],
|
||||||
|
api_base_url: str = os.environ["API_BASE_URL"],
|
||||||
):
|
):
|
||||||
'''
|
'''
|
||||||
对应api.py/chat/knowledge_base_chat接口
|
对应api.py/chat/knowledge_base_chat接口
|
||||||
|
@ -458,8 +467,8 @@ class ApiRequest:
|
||||||
"cb_search_type": cb_search_type,
|
"cb_search_type": cb_search_type,
|
||||||
"stream": stream,
|
"stream": stream,
|
||||||
"local_doc_url": no_remote_api,
|
"local_doc_url": no_remote_api,
|
||||||
"api_key": os.environ["OPENAI_API_KEY"],
|
"api_key": api_key,
|
||||||
"api_base_url": os.environ["API_BASE_URL"],
|
"api_base_url": api_base_url,
|
||||||
"embed_model": embed_model,
|
"embed_model": embed_model,
|
||||||
"embed_model_path": embed_model_path,
|
"embed_model_path": embed_model_path,
|
||||||
"embed_engine": embed_engine,
|
"embed_engine": embed_engine,
|
||||||
|
@ -510,6 +519,8 @@ class ApiRequest:
|
||||||
embed_model: str="", embed_model_path: str="",
|
embed_model: str="", embed_model_path: str="",
|
||||||
model_device: str="", embed_engine: str="",
|
model_device: str="", embed_engine: str="",
|
||||||
temperature: float=0.2, model_name:str ="",
|
temperature: float=0.2, model_name:str ="",
|
||||||
|
api_key: str=os.environ["OPENAI_API_KEY"],
|
||||||
|
api_base_url: str = os.environ["API_BASE_URL"],
|
||||||
):
|
):
|
||||||
'''
|
'''
|
||||||
对应api.py/chat/chat接口
|
对应api.py/chat/chat接口
|
||||||
|
@ -541,8 +552,8 @@ class ApiRequest:
|
||||||
"isDetailed": isDetailed,
|
"isDetailed": isDetailed,
|
||||||
"upload_file": upload_file,
|
"upload_file": upload_file,
|
||||||
"kb_root_path": kb_root_path,
|
"kb_root_path": kb_root_path,
|
||||||
"api_key": os.environ["OPENAI_API_KEY"],
|
"api_key": api_key,
|
||||||
"api_base_url": os.environ["API_BASE_URL"],
|
"api_base_url": api_base_url,
|
||||||
"embed_model": embed_model,
|
"embed_model": embed_model,
|
||||||
"embed_model_path": embed_model_path,
|
"embed_model_path": embed_model_path,
|
||||||
"embed_engine": embed_engine,
|
"embed_engine": embed_engine,
|
||||||
|
@ -588,6 +599,8 @@ class ApiRequest:
|
||||||
embed_model: str="", embed_model_path: str="",
|
embed_model: str="", embed_model_path: str="",
|
||||||
model_device: str="", embed_engine: str="",
|
model_device: str="", embed_engine: str="",
|
||||||
temperature: float=0.2, model_name: str="",
|
temperature: float=0.2, model_name: str="",
|
||||||
|
api_key: str=os.environ["OPENAI_API_KEY"],
|
||||||
|
api_base_url: str = os.environ["API_BASE_URL"],
|
||||||
):
|
):
|
||||||
'''
|
'''
|
||||||
对应api.py/chat/chat接口
|
对应api.py/chat/chat接口
|
||||||
|
@ -620,8 +633,8 @@ class ApiRequest:
|
||||||
"isDetailed": isDetailed,
|
"isDetailed": isDetailed,
|
||||||
"upload_file": upload_file,
|
"upload_file": upload_file,
|
||||||
"kb_root_path": kb_root_path,
|
"kb_root_path": kb_root_path,
|
||||||
"api_key": os.environ["OPENAI_API_KEY"],
|
"api_key": api_key,
|
||||||
"api_base_url": os.environ["API_BASE_URL"],
|
"api_base_url": api_base_url,
|
||||||
"embed_model": embed_model,
|
"embed_model": embed_model,
|
||||||
"embed_model_path": embed_model_path,
|
"embed_model_path": embed_model_path,
|
||||||
"embed_engine": embed_engine,
|
"embed_engine": embed_engine,
|
||||||
|
@ -694,7 +707,9 @@ class ApiRequest:
|
||||||
no_remote_api: bool = None,
|
no_remote_api: bool = None,
|
||||||
kb_root_path: str =KB_ROOT_PATH,
|
kb_root_path: str =KB_ROOT_PATH,
|
||||||
embed_model: str="", embed_model_path: str="",
|
embed_model: str="", embed_model_path: str="",
|
||||||
embedding_device: str="", embed_engine: str=""
|
embedding_device: str="", embed_engine: str="",
|
||||||
|
api_key: str=os.environ["OPENAI_API_KEY"],
|
||||||
|
api_base_url: str = os.environ["API_BASE_URL"],
|
||||||
):
|
):
|
||||||
'''
|
'''
|
||||||
对应api.py/knowledge_base/create_knowledge_base接口
|
对应api.py/knowledge_base/create_knowledge_base接口
|
||||||
|
@ -706,8 +721,8 @@ class ApiRequest:
|
||||||
"knowledge_base_name": knowledge_base_name,
|
"knowledge_base_name": knowledge_base_name,
|
||||||
"vector_store_type": vector_store_type,
|
"vector_store_type": vector_store_type,
|
||||||
"kb_root_path": kb_root_path,
|
"kb_root_path": kb_root_path,
|
||||||
"api_key": os.environ["OPENAI_API_KEY"],
|
"api_key": api_key,
|
||||||
"api_base_url": os.environ["API_BASE_URL"],
|
"api_base_url": api_base_url,
|
||||||
"embed_model": embed_model,
|
"embed_model": embed_model,
|
||||||
"embed_model_path": embed_model_path,
|
"embed_model_path": embed_model_path,
|
||||||
"model_device": embedding_device,
|
"model_device": embedding_device,
|
||||||
|
@ -781,7 +796,9 @@ class ApiRequest:
|
||||||
no_remote_api: bool = None,
|
no_remote_api: bool = None,
|
||||||
kb_root_path: str = KB_ROOT_PATH,
|
kb_root_path: str = KB_ROOT_PATH,
|
||||||
embed_model: str="", embed_model_path: str="",
|
embed_model: str="", embed_model_path: str="",
|
||||||
model_device: str="", embed_engine: str=""
|
model_device: str="", embed_engine: str="",
|
||||||
|
api_key: str=os.environ["OPENAI_API_KEY"],
|
||||||
|
api_base_url: str = os.environ["API_BASE_URL"],
|
||||||
):
|
):
|
||||||
'''
|
'''
|
||||||
对应api.py/knowledge_base/upload_docs接口
|
对应api.py/knowledge_base/upload_docs接口
|
||||||
|
@ -810,8 +827,8 @@ class ApiRequest:
|
||||||
override,
|
override,
|
||||||
not_refresh_vs_cache,
|
not_refresh_vs_cache,
|
||||||
kb_root_path=kb_root_path,
|
kb_root_path=kb_root_path,
|
||||||
api_key=os.environ["OPENAI_API_KEY"],
|
api_key=api_key,
|
||||||
api_base_url=os.environ["API_BASE_URL"],
|
api_base_url=api_base_url,
|
||||||
embed_model=embed_model,
|
embed_model=embed_model,
|
||||||
embed_model_path=embed_model_path,
|
embed_model_path=embed_model_path,
|
||||||
model_device=model_device,
|
model_device=model_device,
|
||||||
|
@ -839,7 +856,9 @@ class ApiRequest:
|
||||||
no_remote_api: bool = None,
|
no_remote_api: bool = None,
|
||||||
kb_root_path: str = KB_ROOT_PATH,
|
kb_root_path: str = KB_ROOT_PATH,
|
||||||
embed_model: str="", embed_model_path: str="",
|
embed_model: str="", embed_model_path: str="",
|
||||||
model_device: str="", embed_engine: str=""
|
model_device: str="", embed_engine: str="",
|
||||||
|
api_key: str=os.environ["OPENAI_API_KEY"],
|
||||||
|
api_base_url: str = os.environ["API_BASE_URL"],
|
||||||
):
|
):
|
||||||
'''
|
'''
|
||||||
对应api.py/knowledge_base/delete_doc接口
|
对应api.py/knowledge_base/delete_doc接口
|
||||||
|
@ -853,8 +872,8 @@ class ApiRequest:
|
||||||
"delete_content": delete_content,
|
"delete_content": delete_content,
|
||||||
"not_refresh_vs_cache": not_refresh_vs_cache,
|
"not_refresh_vs_cache": not_refresh_vs_cache,
|
||||||
"kb_root_path": kb_root_path,
|
"kb_root_path": kb_root_path,
|
||||||
"api_key": os.environ["OPENAI_API_KEY"],
|
"api_key": api_key,
|
||||||
"api_base_url": os.environ["API_BASE_URL"],
|
"api_base_url": api_base_url,
|
||||||
"embed_model": embed_model,
|
"embed_model": embed_model,
|
||||||
"embed_model_path": embed_model_path,
|
"embed_model_path": embed_model_path,
|
||||||
"model_device": model_device,
|
"model_device": model_device,
|
||||||
|
@ -878,7 +897,9 @@ class ApiRequest:
|
||||||
not_refresh_vs_cache: bool = False,
|
not_refresh_vs_cache: bool = False,
|
||||||
no_remote_api: bool = None,
|
no_remote_api: bool = None,
|
||||||
embed_model: str="", embed_model_path: str="",
|
embed_model: str="", embed_model_path: str="",
|
||||||
model_device: str="", embed_engine: str=""
|
model_device: str="", embed_engine: str="",
|
||||||
|
api_key: str=os.environ["OPENAI_API_KEY"],
|
||||||
|
api_base_url: str = os.environ["API_BASE_URL"],
|
||||||
):
|
):
|
||||||
'''
|
'''
|
||||||
对应api.py/knowledge_base/update_doc接口
|
对应api.py/knowledge_base/update_doc接口
|
||||||
|
@ -889,8 +910,8 @@ class ApiRequest:
|
||||||
if no_remote_api:
|
if no_remote_api:
|
||||||
response = run_async(update_doc(
|
response = run_async(update_doc(
|
||||||
knowledge_base_name, file_name, not_refresh_vs_cache, kb_root_path=KB_ROOT_PATH,
|
knowledge_base_name, file_name, not_refresh_vs_cache, kb_root_path=KB_ROOT_PATH,
|
||||||
api_key=os.environ["OPENAI_API_KEY"],
|
api_key=api_key,
|
||||||
api_base_url=os.environ["API_BASE_URL"],
|
api_base_url=api_base_url,
|
||||||
embed_model=embed_model,
|
embed_model=embed_model,
|
||||||
embed_model_path=embed_model_path,
|
embed_model_path=embed_model_path,
|
||||||
model_device=model_device,
|
model_device=model_device,
|
||||||
|
@ -915,7 +936,9 @@ class ApiRequest:
|
||||||
no_remote_api: bool = None,
|
no_remote_api: bool = None,
|
||||||
kb_root_path: str =KB_ROOT_PATH,
|
kb_root_path: str =KB_ROOT_PATH,
|
||||||
embed_model: str="", embed_model_path: str="",
|
embed_model: str="", embed_model_path: str="",
|
||||||
embedding_device: str="", embed_engine: str=""
|
embedding_device: str="", embed_engine: str="",
|
||||||
|
api_key: str=os.environ["OPENAI_API_KEY"],
|
||||||
|
api_base_url: str = os.environ["API_BASE_URL"],
|
||||||
):
|
):
|
||||||
'''
|
'''
|
||||||
对应api.py/knowledge_base/recreate_vector_store接口
|
对应api.py/knowledge_base/recreate_vector_store接口
|
||||||
|
@ -928,8 +951,8 @@ class ApiRequest:
|
||||||
"allow_empty_kb": allow_empty_kb,
|
"allow_empty_kb": allow_empty_kb,
|
||||||
"vs_type": vs_type,
|
"vs_type": vs_type,
|
||||||
"kb_root_path": kb_root_path,
|
"kb_root_path": kb_root_path,
|
||||||
"api_key": os.environ["OPENAI_API_KEY"],
|
"api_key": api_key,
|
||||||
"api_base_url": os.environ["API_BASE_URL"],
|
"api_base_url": api_base_url,
|
||||||
"embed_model": embed_model,
|
"embed_model": embed_model,
|
||||||
"embed_model_path": embed_model_path,
|
"embed_model_path": embed_model_path,
|
||||||
"model_device": embedding_device,
|
"model_device": embedding_device,
|
||||||
|
@ -1041,7 +1064,9 @@ class ApiRequest:
|
||||||
# code base 相关操作
|
# code base 相关操作
|
||||||
def create_code_base(self, cb_name, zip_file, do_interpret: bool, no_remote_api: bool = None,
|
def create_code_base(self, cb_name, zip_file, do_interpret: bool, no_remote_api: bool = None,
|
||||||
embed_model: str="", embed_model_path: str="", embedding_device: str="", embed_engine: str="",
|
embed_model: str="", embed_model_path: str="", embedding_device: str="", embed_engine: str="",
|
||||||
llm_model: str ="", temperature: float= 0.2
|
llm_model: str ="", temperature: float= 0.2,
|
||||||
|
api_key: str=os.environ["OPENAI_API_KEY"],
|
||||||
|
api_base_url: str = os.environ["API_BASE_URL"],
|
||||||
):
|
):
|
||||||
'''
|
'''
|
||||||
创建 code_base
|
创建 code_base
|
||||||
|
@ -1067,8 +1092,8 @@ class ApiRequest:
|
||||||
"cb_name": cb_name,
|
"cb_name": cb_name,
|
||||||
"code_path": raw_code_path,
|
"code_path": raw_code_path,
|
||||||
"do_interpret": do_interpret,
|
"do_interpret": do_interpret,
|
||||||
"api_key": os.environ["OPENAI_API_KEY"],
|
"api_key": api_key,
|
||||||
"api_base_url": os.environ["API_BASE_URL"],
|
"api_base_url": api_base_url,
|
||||||
"embed_model": embed_model,
|
"embed_model": embed_model,
|
||||||
"embed_model_path": embed_model_path,
|
"embed_model_path": embed_model_path,
|
||||||
"embed_engine": embed_engine,
|
"embed_engine": embed_engine,
|
||||||
|
@ -1091,7 +1116,9 @@ class ApiRequest:
|
||||||
|
|
||||||
def delete_code_base(self, cb_name: str, no_remote_api: bool = None,
|
def delete_code_base(self, cb_name: str, no_remote_api: bool = None,
|
||||||
embed_model: str="", embed_model_path: str="", embedding_device: str="", embed_engine: str="",
|
embed_model: str="", embed_model_path: str="", embedding_device: str="", embed_engine: str="",
|
||||||
llm_model: str ="", temperature: float= 0.2
|
llm_model: str ="", temperature: float= 0.2,
|
||||||
|
api_key: str=os.environ["OPENAI_API_KEY"],
|
||||||
|
api_base_url: str = os.environ["API_BASE_URL"],
|
||||||
):
|
):
|
||||||
'''
|
'''
|
||||||
删除 code_base
|
删除 code_base
|
||||||
|
@ -1102,8 +1129,8 @@ class ApiRequest:
|
||||||
no_remote_api = self.no_remote_api
|
no_remote_api = self.no_remote_api
|
||||||
data = {
|
data = {
|
||||||
"cb_name": cb_name,
|
"cb_name": cb_name,
|
||||||
"api_key": os.environ["OPENAI_API_KEY"],
|
"api_key": api_key,
|
||||||
"api_base_url": os.environ["API_BASE_URL"],
|
"api_base_url": api_base_url,
|
||||||
"embed_model": embed_model,
|
"embed_model": embed_model,
|
||||||
"embed_model_path": embed_model_path,
|
"embed_model_path": embed_model_path,
|
||||||
"embed_engine": embed_engine,
|
"embed_engine": embed_engine,
|
||||||
|
|
Loading…
Reference in New Issue